GridSearchCV

功能介绍

GridSearch 在由各参数取值数组组成的网格上,对其中每一组参数组合分别进行训练、预测和评估,取评估指标最优的模型作为最终返回的模型。

CV 即交叉验证(cross validation):将数据切分为 k 份(k-folds),每次用其中 k-1 份数据训练,用剩余的一份数据做预测和评估,得到一个评估结果。

本组件用交叉验证方法得到网格中每一组参数的评估结果,并据此选出最优模型。

参数说明

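| 名称 | 中文名称 | 描述 | 类型 | 是否必须? | 默认值 |
| --- | --- | --- | --- | --- | --- |
| NumFolds | 折数 | 交叉验证的参数,数据的折数(大于等于2) | Integer | 否 | 10 |
| ParamGrid | 参数网格 | 指定参数的网格 | ParamGrid | 是 | - |
| Estimator | Estimator | 用于调优的Estimator | Estimator | 是 | - |
| TuningEvaluator | 评估指标 | 用于选择最优模型的评估指标 | TuningEvaluator | 是 | - |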

脚本示例

脚本代码

  def adult(url):
      """Build a CSV batch source for the Adult dataset located at ``url``.

      Args:
          url: Location of the Adult CSV file; adult_train() and adult_test()
              pass the train and test split URLs respectively.

      Returns:
          A ``CsvSourceBatchOp`` reading ``url`` with the 15-column Adult schema
          (14 feature columns plus the string ``label`` column).
      """
      # Adjacent string literals are concatenated into one schema string.
      schema = (
          'age bigint, workclass string, fnlwgt bigint,'
          'education string, education_num bigint,'
          'marital_status string, occupation string,'
          'relationship string, race string, sex string,'
          'capital_gain bigint, capital_loss bigint,'
          'hours_per_week bigint, native_country string,'
          'label string'
      )
      # Read from the caller-supplied url (not a hard-coded path) so that
      # adult_test() actually loads the test split.
      data = (
          CsvSourceBatchOp()
          .setFilePath(url)
          .setSchemaStr(schema)
      )
      return data
  def adult_train():
      """Return the Adult training-split source built by ``adult()``."""
      train_url = 'http://alink-dataset.cn-hangzhou.oss.aliyun-inc.com/csv/adult_train.csv'
      return adult(url=train_url)
  def adult_test():
      """Return the source built by ``adult()`` for the Adult test-split URL."""
      test_url = 'http://alink-dataset.cn-hangzhou.oss.aliyun-inc.com/csv/adult_test.csv'
      return adult(url=test_url)
  def adult_numerical_feature_strs():
      """Return a new list with the names of the Adult dataset's numeric (bigint) feature columns."""
      # str.split() builds a fresh list on every call, so callers may mutate it.
      return "age fnlwgt education_num capital_gain capital_loss hours_per_week".split()
  def adult_categorical_feature_strs():
      """Return a new list with the names of the Adult dataset's categorical (string) feature columns."""
      # str.split() builds a fresh list on every call, so callers may mutate it.
      return (
          "workclass education marital_status occupation "
          "relationship race sex native_country"
      ).split()
  def adult_features_strs():
      """Return all Adult feature column names: numeric columns first, then categorical ones."""
      numeric = adult_numerical_feature_strs()
      categorical = adult_categorical_feature_strs()
      return numeric + categorical
  def _rf_tuning_parts(featureCols, categoryFeatureCols, label, metric):
      """Build the random-forest estimator, parameter grid and evaluator shared by both searches.

      Args:
          featureCols: Names of all feature columns.
          categoryFeatureCols: Subset of ``featureCols`` to treat as categorical.
          label: Name of the label column.
          metric: Metric name passed to the binary-classification evaluator (e.g. 'AUC').

      Returns:
          A ``(rf, paramGrid, tuningEvaluator)`` tuple.
      """
      # The evaluator reads the same detail column the classifier writes.
      detail_col = 'prediction_detail'
      rf = (
          RandomForestClassifier()
          .setFeatureCols(featureCols)
          .setCategoricalCols(categoryFeatureCols)
          .setLabelCol(label)
          .setPredictionCol('prediction')
          .setPredictionDetailCol(detail_col)
      )
      # 3 subsampling ratios x 3 tree counts = 9 candidate parameter combinations.
      paramGrid = (
          ParamGrid()
          .addGrid(rf, 'SUBSAMPLING_RATIO', [1.0, 0.99, 0.98])
          .addGrid(rf, 'NUM_TREES', [3, 6, 9])
      )
      tuningEvaluator = (
          BinaryClassificationTuningEvaluator()
          .setLabelCol(label)
          .setPredictionDetailCol(detail_col)
          .setMetricName(metric)
      )
      return rf, paramGrid, tuningEvaluator


  def rf_grid_search_cv(featureCols, categoryFeatureCols, label, metric, numFolds=2):
      """Return an unfitted GridSearchCV over the shared random-forest grid.

      Args:
          featureCols: Names of all feature columns.
          categoryFeatureCols: Subset of ``featureCols`` to treat as categorical.
          label: Name of the label column.
          metric: Metric used to pick the best parameter combination.
          numFolds: Number of cross-validation folds. Defaults to 2, which is
              smaller than the component default of 10 and keeps the example fast.
      """
      rf, paramGrid, tuningEvaluator = _rf_tuning_parts(
          featureCols, categoryFeatureCols, label, metric)
      return (
          GridSearchCV()
          .setEstimator(rf)
          .setParamGrid(paramGrid)
          .setTuningEvaluator(tuningEvaluator)
          .setNumFolds(numFolds)
      )


  def rf_grid_search_tv(featureCols, categoryFeatureCols, label, metric):
      """Return an unfitted GridSearchTVSplit over the shared random-forest grid.

      Uses the same estimator, grid and evaluator as ``rf_grid_search_cv``; the
      tuning strategy is GridSearchTVSplit (presumably a single train/validation
      split instead of k folds; confirm against the Alink docs).
      """
      rf, paramGrid, tuningEvaluator = _rf_tuning_parts(
          featureCols, categoryFeatureCols, label, metric)
      return (
          GridSearchTVSplit()
          .setEstimator(rf)
          .setParamGrid(paramGrid)
          .setTuningEvaluator(tuningEvaluator)
      )
  def tuningcv(cv_estimator, input):
      """Fit the cross-validation search on ``input`` and return whatever ``fit`` produces."""
      fitted_model = cv_estimator.fit(input)
      return fitted_model
  def tuningtv(tv_estimator, input):
      """Fit the train/validation-split search on ``input`` and return whatever ``fit`` produces."""
      fitted_model = tv_estimator.fit(input)
      return fitted_model
  def main():
      """Run the CV and TV grid searches on the Adult training data and print each report."""
      # Grid search evaluated with k-fold cross validation.
      print('rf cv tuning')
      cv_search = rf_grid_search_cv(
          adult_features_strs(), adult_categorical_feature_strs(), 'label', 'AUC')
      cv_model = tuningcv(cv_search, adult_train())
      print(cv_model.getReport())

      # Same grid, tuned with GridSearchTVSplit.
      print('rf tv tuning')
      tv_search = rf_grid_search_tv(
          adult_features_strs(), adult_categorical_feature_strs(), 'label', 'AUC')
      tv_model = tuningtv(tv_search, adult_train())
      print(tv_model.getReport())


  main()

脚本结果

  1. rf cv tuning
  2. com.alibaba.alink.pipeline.tuning.GridSearchCV
  3. [ {
  4. "param" : [ {
  5. "stage" : "RandomForestClassifier",
  6. "paramName" : "numTrees",
  7. "paramValue" : 3
  8. }, {
  9. "stage" : "RandomForestClassifier",
  10. "paramName" : "subsamplingRatio",
  11. "paramValue" : 1.0
  12. } ],
  13. "metric" : 0.8922549257899725
  14. }, {
  15. "param" : [ {
  16. "stage" : "RandomForestClassifier",
  17. "paramName" : "numTrees",
  18. "paramValue" : 3
  19. }, {
  20. "stage" : "RandomForestClassifier",
  21. "paramName" : "subsamplingRatio",
  22. "paramValue" : 0.99
  23. } ],
  24. "metric" : 0.8920255970548456
  25. }, {
  26. "param" : [ {
  27. "stage" : "RandomForestClassifier",
  28. "paramName" : "numTrees",
  29. "paramValue" : 3
  30. }, {
  31. "stage" : "RandomForestClassifier",
  32. "paramName" : "subsamplingRatio",
  33. "paramValue" : 0.98
  34. } ],
  35. "metric" : 0.8944982480437225
  36. }, {
  37. "param" : [ {
  38. "stage" : "RandomForestClassifier",
  39. "paramName" : "numTrees",
  40. "paramValue" : 6
  41. }, {
  42. "stage" : "RandomForestClassifier",
  43. "paramName" : "subsamplingRatio",
  44. "paramValue" : 1.0
  45. } ],
  46. "metric" : 0.8923867598288401
  47. }, {
  48. "param" : [ {
  49. "stage" : "RandomForestClassifier",
  50. "paramName" : "numTrees",
  51. "paramValue" : 6
  52. }, {
  53. "stage" : "RandomForestClassifier",
  54. "paramName" : "subsamplingRatio",
  55. "paramValue" : 0.99
  56. } ],
  57. "metric" : 0.9012141767959505
  58. }, {
  59. "param" : [ {
  60. "stage" : "RandomForestClassifier",
  61. "paramName" : "numTrees",
  62. "paramValue" : 6
  63. }, {
  64. "stage" : "RandomForestClassifier",
  65. "paramName" : "subsamplingRatio",
  66. "paramValue" : 0.98
  67. } ],
  68. "metric" : 0.8993774036693788
  69. }, {
  70. "param" : [ {
  71. "stage" : "RandomForestClassifier",
  72. "paramName" : "numTrees",
  73. "paramValue" : 9
  74. }, {
  75. "stage" : "RandomForestClassifier",
  76. "paramName" : "subsamplingRatio",
  77. "paramValue" : 1.0
  78. } ],
  79. "metric" : 0.8981738808130779
  80. }, {
  81. "param" : [ {
  82. "stage" : "RandomForestClassifier",
  83. "paramName" : "numTrees",
  84. "paramValue" : 9
  85. }, {
  86. "stage" : "RandomForestClassifier",
  87. "paramName" : "subsamplingRatio",
  88. "paramValue" : 0.99
  89. } ],
  90. "metric" : 0.9029671873892725
  91. }, {
  92. "param" : [ {
  93. "stage" : "RandomForestClassifier",
  94. "paramName" : "numTrees",
  95. "paramValue" : 9
  96. }, {
  97. "stage" : "RandomForestClassifier",
  98. "paramName" : "subsamplingRatio",
  99. "paramValue" : 0.98
  100. } ],
  101. "metric" : 0.905228896323363
  102. } ]
  103. rf tv tuning
  104. com.alibaba.alink.pipeline.tuning.GridSearchTVSplit
  105. [ {
  106. "param" : [ {
  107. "stage" : "RandomForestClassifier",
  108. "paramName" : "numTrees",
  109. "paramValue" : 3
  110. }, {
  111. "stage" : "RandomForestClassifier",
  112. "paramName" : "subsamplingRatio",
  113. "paramValue" : 1.0
  114. } ],
  115. "metric" : 0.9022694229691741
  116. }, {
  117. "param" : [ {
  118. "stage" : "RandomForestClassifier",
  119. "paramName" : "numTrees",
  120. "paramValue" : 3
  121. }, {
  122. "stage" : "RandomForestClassifier",
  123. "paramName" : "subsamplingRatio",
  124. "paramValue" : 0.99
  125. } ],
  126. "metric" : 0.8963559966080328
  127. }, {
  128. "param" : [ {
  129. "stage" : "RandomForestClassifier",
  130. "paramName" : "numTrees",
  131. "paramValue" : 3
  132. }, {
  133. "stage" : "RandomForestClassifier",
  134. "paramName" : "subsamplingRatio",
  135. "paramValue" : 0.98
  136. } ],
  137. "metric" : 0.9041948454957178
  138. }, {
  139. "param" : [ {
  140. "stage" : "RandomForestClassifier",
  141. "paramName" : "numTrees",
  142. "paramValue" : 6
  143. }, {
  144. "stage" : "RandomForestClassifier",
  145. "paramName" : "subsamplingRatio",
  146. "paramValue" : 1.0
  147. } ],
  148. "metric" : 0.8982021117392784
  149. }, {
  150. "param" : [ {
  151. "stage" : "RandomForestClassifier",
  152. "paramName" : "numTrees",
  153. "paramValue" : 6
  154. }, {
  155. "stage" : "RandomForestClassifier",
  156. "paramName" : "subsamplingRatio",
  157. "paramValue" : 0.99
  158. } ],
  159. "metric" : 0.9031851535310546
  160. }, {
  161. "param" : [ {
  162. "stage" : "RandomForestClassifier",
  163. "paramName" : "numTrees",
  164. "paramValue" : 6
  165. }, {
  166. "stage" : "RandomForestClassifier",
  167. "paramName" : "subsamplingRatio",
  168. "paramValue" : 0.98
  169. } ],
  170. "metric" : 0.9034443322241488
  171. }, {
  172. "param" : [ {
  173. "stage" : "RandomForestClassifier",
  174. "paramName" : "numTrees",
  175. "paramValue" : 9
  176. }, {
  177. "stage" : "RandomForestClassifier",
  178. "paramName" : "subsamplingRatio",
  179. "paramValue" : 1.0
  180. } ],
  181. "metric" : 0.8993474753000145
  182. }, {
  183. "param" : [ {
  184. "stage" : "RandomForestClassifier",
  185. "paramName" : "numTrees",
  186. "paramValue" : 9
  187. }, {
  188. "stage" : "RandomForestClassifier",
  189. "paramName" : "subsamplingRatio",
  190. "paramValue" : 0.99
  191. } ],
  192. "metric" : 0.9090250137144916
  193. }, {
  194. "param" : [ {
  195. "stage" : "RandomForestClassifier",
  196. "paramName" : "numTrees",
  197. "paramValue" : 9
  198. }, {
  199. "stage" : "RandomForestClassifier",
  200. "paramName" : "subsamplingRatio",
  201. "paramValue" : 0.98
  202. } ],
  203. "metric" : 0.9129786771786127
  204. } ]