Wrappers for the Scikit-Learn API
You can use Sequential Keras models (single-input only) as part of your Scikit-Learn workflow via the wrappers found at keras.wrappers.scikit_learn.py.
There are two wrappers available:
- keras.wrappers.scikit_learn.KerasClassifier(build_fn=None, **sk_params), which implements the Scikit-Learn classifier interface,
- keras.wrappers.scikit_learn.KerasRegressor(build_fn=None, **sk_params), which implements the Scikit-Learn regressor interface.
Arguments
- build_fn: callable function or class instance
- sk_params: model parameters & fitting parameters
build_fn should construct, compile and return a Keras model, which will then be used to fit/predict. One of the following three values can be passed to build_fn:
- A function
- An instance of a class that implements the __call__ method
- None. This means you implement a class that inherits from either KerasClassifier or KerasRegressor. The __call__ method of the present class will then be treated as the default build_fn.
sk_params takes both model parameters and fitting parameters. Legal model parameters are the arguments of build_fn. Note that, like all other estimators in scikit-learn, build_fn should provide default values for its arguments, so that you can create the estimator without passing any values to sk_params.
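The relationship between build_fn's arguments and the model parameters in sk_params can be illustrated without Keras. The following is a minimal sketch, not the wrappers' actual code: build_fn here is a hypothetical stand-in that returns a dict instead of a compiled model, and select_model_params is a made-up helper that keeps only the entries matching build_fn's signature.

```python
import inspect


def build_fn(hidden_units=32, dropout=0.5):
    """Hypothetical stand-in for a real build_fn.

    A real build_fn would construct, compile and return a Keras model;
    here a plain dict is returned so the example runs without Keras.
    """
    return {"hidden_units": hidden_units, "dropout": dropout}


def select_model_params(fn, sk_params):
    """Keep only the sk_params entries that name an argument of fn."""
    accepted = inspect.signature(fn).parameters
    return {name: value for name, value in sk_params.items() if name in accepted}


# sk_params mixes model parameters with fitting parameters.
sk_params = {"hidden_units": 64, "epochs": 10, "batch_size": 128}

model_params = select_model_params(build_fn, sk_params)
model = build_fn(**model_params)

# Because every argument has a default, build_fn also works with no parameters.
default_model = build_fn()
```

Only hidden_units reaches build_fn in this sketch; epochs and batch_size are left over as fitting parameters, and dropout falls back to its default.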
sk_params can also accept parameters for calling the fit, predict, predict_proba, and score methods (e.g., epochs, batch_size). Fitting (predicting) parameters are selected in the following order:
- Values passed to the dictionary arguments of the fit, predict, predict_proba, and score methods
- Values passed to sk_params
- The default values of the keras.models.Sequential fit, predict, predict_proba and score methods
When using scikit-learn's grid_search API, legal tunable parameters are those you could pass to sk_params, including fitting parameters. In other words, you could use grid_search to search for the best batch_size or epochs as well as the model parameters.
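The selection order for fitting parameters, and the way a grid over sk_params can mix model and fitting parameters, can be sketched in plain Python. This is an illustration under stated assumptions rather than the wrappers' implementation: fit is a hypothetical stand-in whose default values are chosen for the example, and resolve_params is a made-up helper name.

```python
import inspect
from itertools import product


def fit(x, y, batch_size=32, epochs=1, verbose=1):
    """Hypothetical stand-in for Sequential.fit; only its defaults are used."""


def resolve_params(method, sk_params, call_kwargs):
    """Merge parameters for method, lowest priority first."""
    signature = inspect.signature(method)
    # Lowest priority: defaults declared by the method itself.
    resolved = {
        name: param.default
        for name, param in signature.parameters.items()
        if param.default is not inspect.Parameter.empty
    }
    # Next: values from sk_params that the method accepts.
    resolved.update(
        {name: value for name, value in sk_params.items() if name in signature.parameters}
    )
    # Highest priority: values passed directly to the method call.
    resolved.update(call_kwargs)
    return resolved


sk_params = {"hidden_units": 64, "epochs": 10, "batch_size": 64}
resolved = resolve_params(fit, sk_params, {"epochs": 3})

# A grid over sk_params may mix model parameters and fitting parameters.
param_grid = {"hidden_units": [32, 64], "epochs": [5, 10]}
candidates = [dict(zip(param_grid, values)) for values in product(*param_grid.values())]
```

In this sketch epochs=3 from the call overrides epochs=10 from sk_params, batch_size comes from sk_params, and verbose keeps the method default. Each entry in candidates is one sk_params combination that a grid search would try.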