四、超参数优化
4.1 GridSearchCV
GridSearchCV
用于实现超参数优化,其原型为:class sklearn.model_selection.GridSearchCV(estimator, param_grid, scoring=None,
fit_params=None, n_jobs=1, iid=True, refit=True, cv=None,verbose=0,
pre_dispatch='2*n_jobs', error_score='raise',return_train_score='warn')
estimator
:一个学习器对象。它必须有.fit
方法用于学习,.predict
方法用于预测,有.score
方法用于性能评分。param_grid
:字典或者字典的列表。每个字典都给出了学习器的一个超参数,其中:- 字典的键就是超参数名。
- 字典的值是一个列表,指定了超参数对应的候选值序列。
fit_params
:一个字典,用来给学习器的.fit
方法传递参数。iid
:如果为True
,则表示数据是独立同分布的。refit
:一个布尔值。如果为True
,则在参数优化之后使用整个数据集来重新训练该最优的estimator
。error_score
:一个数值或者字符串'raise'
,指定当estimator
训练发生异常时,如何处理:- 如果为
'raise'
,则抛出异常。 - 如果为数值,则将该数值作为本轮
estimator
的预测得分。
- 如果为
return_train_score
: 一个布尔值,指示是否返回训练集的预测得分。如果为
'warn'
,则等价于True
并抛出一个警告。其它参数参考
cross_val_score
。
属性:
cv_results_
:一个数组的字典。可以直接用于生成pandas DataFrame
。其中键为超参数名,值为超参数的数组。另外额外多了一些键:
mean_fit_time
、mean_score_time
、std_fit_time
、std_score_time
:给出了训练时间、评估时间的均值和方差,单位为秒。xx_score
:给出了各种评估得分。
best_estimator_
:一个学习器对象,代表了根据候选参数组合筛选出来的最佳的学习器。如果
refit=False
,则该属性不可用。best_score_
:最佳学习器的性能评分。best_params_
:最佳参数组合。best_index_
:cv_results_
中,第几组参数对应着最佳参数组合。scorer_
:评分函数。n_splits_
:交叉验证的k
值。
方法:
fit(X[, y,groups])
:执行参数优化。predict(X)
:使用学到的最佳学习器来预测数据。predict_log_proba(X)
:使用学到的最佳学习器来预测数据为各类别的概率的对数值。predict_proba(X)
:使用学到的最佳学习器来预测数据为各类别的概率。score(X[, y])
:通过给定的数据集来判断学到的最佳学习器的预测性能。transform(X)
:对最佳学习器执行transform
。inverse_transform(X)
:对最佳学习器执行逆transform
。decision_function(X)
:对最佳学习器调用决策函数。
GridSearchCV
实现了estimator
的.fit
、.score
方法。这些方法内部会调用estimator
的对应的方法。在调用
GridSearchCV.fit
方法时,首先会将训练集进行 折交叉,然后在每次划分的集合上进行多轮的训练和验证(每一轮都采用一种参数组合),然后调用最佳学习器的.fit
方法。
4.2 RandomizedSearchCV
GridSearchCV
采用的是暴力寻找的方法来寻找最优参数。当待优化的参数是离散的取值的时候,GridSearchCV
能够顺利找出最优的参数。但是当待优化的参数是连续取值的时候,暴力寻找就有心无力。GridSearchCV
的做法是从这些连续值中挑选几个值作为代表,从而在这些代表中挑选出最佳的参数。RandomizedSearchCV
采用随机搜索所有的候选参数对的方法来寻找最优的参数组合。其原型为:class sklearn.model_selection.RandomizedSearchCV(estimator, param_distributions,
n_iter=10, scoring=None, fit_params=None, n_jobs=1, iid=True, refit=True,
cv=None, verbose=0, pre_dispatch='2*n_jobs', random_state=None,
error_score='raise',return_train_score='warn')
param_distributions
:字典或者字典的列表。每个字典都给出了学习器的一个参数,其中:字典的键就是参数名。
字典的值是一个分布类,分布类必须提供
.rvs
方法。通常你可以使用
scipy.stats
模块中提供的分布类,比如scipy.expon
(指数分布)、scipy.gamma
(gamma分布)、scipy.uniform
(均匀分布)、randint
等等。字典的值也可以是一个数值序列,此时就在该序列中均匀采样。
n_iter
:一个整数,指定每个参数采样的数量。通常该值越大,参数优化的效果越好。但是参数越大,运行时间也更长。其它参数参考
GridSearchCV
。
属性:参考
GridSearchCV
。方法:参考
GridSearchCV
。