Wrappers for the Scikit-Learn API

You can use Sequential Keras models (single-input only) as part of your Scikit-Learn workflow via the wrappers found at keras.wrappers.scikit_learn.py.

There are two wrappers available:

keras.wrappers.scikit_learn.KerasClassifier(build_fn=None, **sk_params), which implements the Scikit-Learn classifier interface,

keras.wrappers.scikit_learn.KerasRegressor(build_fn=None, **sk_params), which implements the Scikit-Learn regressor interface.

Arguments

  • build_fn: callable function or class instance
  • sk_params: model parameters & fitting parameters

build_fn should construct, compile, and return a Keras model, which will then be used to fit and predict. One of the following three values can be passed to build_fn:

  • A function
  • An instance of a class that implements the __call__ method
  • None. This means you implement a class that inherits from either KerasClassifier or KerasRegressor. The __call__ method of that class will then be treated as the default build_fn.

sk_params takes both model parameters and fitting parameters. Legal model parameters are the arguments of build_fn. Note that, like all other estimators in scikit-learn, build_fn should provide default values for its arguments, so that you can create the estimator without passing any values to sk_params.
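To make the three build_fn options and the idea of "legal model parameters" concrete, here is a minimal sketch in plain Python. It does not import Keras: BaseWrapperSketch, build_model, Builder, and MySketchWrapper are hypothetical names invented for illustration, and each builder returns a dict where a real build_fn would return a compiled Keras model. The resolution logic is an assumption modeled on the behaviour described above, not the wrappers' actual source.

```python
import inspect


class BaseWrapperSketch:
    """Hypothetical stand-in for the wrappers' build_fn handling (no Keras)."""

    def __init__(self, build_fn=None, **sk_params):
        self.build_fn = build_fn
        self.sk_params = sk_params

    def _resolve_build_fn(self):
        if self.build_fn is None:
            # Option 3: a subclass's own __call__ is the default build_fn.
            return self.__call__
        if inspect.isfunction(self.build_fn) or inspect.ismethod(self.build_fn):
            # Option 1: a plain function (or bound method).
            return self.build_fn
        # Option 2: an instance of a class that implements __call__.
        return self.build_fn.__call__

    def model_params(self):
        # Legal model parameters are exactly the arguments of build_fn;
        # anything else in sk_params (e.g. epochs) is left for fitting.
        accepted = inspect.signature(self._resolve_build_fn()).parameters
        return {k: v for k, v in self.sk_params.items() if k in accepted}

    def build(self):
        return self._resolve_build_fn()(**self.model_params())


# Hypothetical builders: each returns a dict where a real one would
# construct, compile, and return a Keras model.
def build_model(hidden_units=16, dropout=0.0):
    return {"hidden_units": hidden_units, "dropout": dropout}


class Builder:
    def __call__(self, hidden_units=8):
        return {"hidden_units": hidden_units}


class MySketchWrapper(BaseWrapperSketch):
    def __call__(self, hidden_units=32):
        return {"hidden_units": hidden_units}


w = BaseWrapperSketch(build_fn=build_model, hidden_units=64, epochs=5)
print(w.model_params())  # {'hidden_units': 64}
print(w.build())  # {'hidden_units': 64, 'dropout': 0.0}
print(MySketchWrapper().build())  # {'hidden_units': 32}
```

Because every builder argument has a default, each estimator can be created with no sk_params at all, which is the scikit-learn convention mentioned above.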

sk_params can also accept parameters for calling the fit, predict, predict_proba, and score methods (e.g., epochs, batch_size). Fitting (predicting) parameters are selected in the following order:

  • Values passed as keyword arguments to the fit, predict, predict_proba, and score methods
  • Values passed to sk_params
  • The default values of the keras.models.Sequential fit, predict, predict_proba, and score methods

When using scikit-learn's grid_search API, legal tunable parameters are those you could pass to sk_params, including fitting parameters. In other words, you can use grid_search to search for the best batch_size or epochs as well as the model parameters.
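The selection order above behaves like a layered dictionary merge, and grid search simply tries different sk_params dicts. The sketch below is a plain-Python illustration under stated assumptions: FIT_DEFAULTS is a hypothetical dict whose values are illustrative rather than read from Keras, and resolve_fit_params and candidate_sk_params are hypothetical helpers, not part of the Keras or scikit-learn API.

```python
import itertools

# Assumed, illustrative defaults standing in for Sequential.fit's defaults.
FIT_DEFAULTS = {"batch_size": 32, "epochs": 1, "verbose": 1}


def resolve_fit_params(method_defaults, sk_params, call_kwargs):
    """Lowest to highest priority: method defaults, sk_params, call kwargs."""
    resolved = dict(method_defaults)
    # sk_params only override arguments the method actually accepts,
    # so model parameters such as hidden_units are ignored here.
    resolved.update({k: v for k, v in sk_params.items() if k in method_defaults})
    # Keyword arguments given to fit()/predict()/... win over everything.
    resolved.update(call_kwargs)
    return resolved


def candidate_sk_params(base, grid):
    """Yield one sk_params dict per grid combination (model or fit params)."""
    keys = sorted(grid)
    for values in itertools.product(*(grid[k] for k in keys)):
        params = dict(base)
        params.update(zip(keys, values))
        yield params


sk_params = {"hidden_units": 64, "epochs": 10}
print(resolve_fit_params(FIT_DEFAULTS, sk_params, {}))
# {'batch_size': 32, 'epochs': 10, 'verbose': 1}
print(resolve_fit_params(FIT_DEFAULTS, sk_params, {"epochs": 3}))
# {'batch_size': 32, 'epochs': 3, 'verbose': 1}

grid = {"batch_size": [16, 64], "epochs": [5, 10]}
print(len(list(candidate_sk_params(sk_params, grid))))  # 4
```

In scikit-learn itself, a search such as GridSearchCV applies each combination through the estimator's set_params; the generator here only mimics that enumeration to show that batch_size and epochs are tuned exactly like model parameters.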