callbacks.general_sched
Implementation of a flexible training API
TrainingPhase and General scheduler
Creates a scheduler that lets you train a model following a sequence of different TrainingPhases.
class TrainingPhase
TrainingPhase(length:int)
Schedule hyper-parameters for a phase of length iterations.
You can then schedule any hyper-parameter you want by using the following method.
schedule_hp
schedule_hp(name, vals, anneal=None)
Adds a schedule for name between vals using anneal.
The phase will make the hyper-parameter vary from the first value in vals to the second, following anneal. If an annealing function is specified but vals is a float, it will decay to 0. If no annealing function is specified, the default is a linear annealing for a tuple, and a constant parameter if it’s a float; the sketch after the hyper-parameter list below illustrates each case.
Note: If you want to use discriminative values (one value per layer group), you can pass a numpy array in vals (or a tuple of two numpy arrays for start and stop).
The basic hyper-parameters are named:
- ‘lr’ for learning rate
- ‘mom’ for momentum (or beta1 in Adam)
- ‘beta’ for the beta2 in Adam or the alpha in RMSprop
- ‘wd’ for weight decay
You can also add any hyper-parameter that is in your optimizer (even if it’s custom or a GeneralOptimizer), like ‘eps’ if you’re using Adam.
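Here is a minimal sketch of these rules; the phase length, the values, and the extra ‘eps’ schedule are arbitrary choices for illustration, and the explicit imports assume fastai v1’s module layout:
import numpy as np
from fastai.callbacks.general_sched import TrainingPhase
from fastai.callback import annealing_cos, annealing_linear

n = 100  # hypothetical number of iterations in this phase

phase = (TrainingPhase(n)
         .schedule_hp('lr', (1e-3, 1e-5), anneal=annealing_cos)  # tuple + anneal: cosine from 1e-3 down to 1e-5
         .schedule_hp('mom', 0.9)                                # float, no anneal: constant at 0.9
         .schedule_hp('wd', 1e-2, anneal=annealing_linear)       # float + anneal: decays (here linearly) to 0
         .schedule_hp('eps', (1e-8, 1e-6)))                      # any optimizer hyper-parameter can be scheduled

# Discriminative values: a numpy array gives one (constant) value per layer group.
disc_phase = TrainingPhase(n).schedule_hp('lr', np.array([1e-4, 1e-3, 1e-2]))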
Let’s make an example by using this to code SGD with warm restarts.
def fit_sgd_warm(learn, n_cycles, lr, mom, cycle_len, cycle_mult):
    n = len(learn.data.train_dl)  # number of batches in one epoch
    # One phase per cycle: the learning rate follows a cosine annealing down to 0
    # while the momentum stays constant; each cycle is cycle_mult times longer than the previous one.
    phases = [(TrainingPhase(n * (cycle_len * cycle_mult**i))
                 .schedule_hp('lr', lr, anneal=annealing_cos)
                 .schedule_hp('mom', mom)) for i in range(n_cycles)]
    sched = GeneralScheduler(learn, phases)
    learn.callbacks.append(sched)
    # Total length of all the cycles, in epochs (a geometric sum when cycle_mult != 1).
    if cycle_mult != 1:
        total_epochs = int(cycle_len * (1 - (cycle_mult)**n_cycles)/(1-cycle_mult))
    else: total_epochs = n_cycles * cycle_len
    learn.fit(total_epochs)
path = untar_data(URLs.MNIST_SAMPLE)
data = ImageDataBunch.from_folder(path)
learn = Learner(data, simple_cnn((3,16,16,2)), metrics=accuracy)
fit_sgd_warm(learn, 3, 1e-3, 0.9, 1, 2)
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 0.162146 | 0.153532 | 0.942100 | 00:02 |
1 | 0.126112 | 0.117267 | 0.960255 | 00:02 |
2 | 0.112045 | 0.110586 | 0.962218 | 00:02 |
3 | 0.097603 | 0.090838 | 0.967615 | 00:02 |
4 | 0.086883 | 0.081375 | 0.973013 | 00:02 |
5 | 0.083673 | 0.076160 | 0.973994 | 00:02 |
6 | 0.084835 | 0.076211 | 0.973994 | 00:02 |
learn.recorder.plot_lr()
class GeneralScheduler
GeneralScheduler(learn:Learner, phases:Collection[TrainingPhase], start_epoch:int=None) :: LearnerCallback
Schedule multiple TrainingPhase for a Learner.
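For situations not covered by a helper like fit_sgd_warm above, a minimal sketch of wiring a GeneralScheduler up by hand might look like this (the phase lengths and values are arbitrary, and learn is the Learner built earlier):
n = len(learn.data.train_dl)  # batches per epoch
phases = [TrainingPhase(2 * n).schedule_hp('lr', (1e-3, 1e-5), anneal=annealing_cos),
          TrainingPhase(3 * n).schedule_hp('lr', (1e-5, 1e-6), anneal=annealing_linear)]
sched = GeneralScheduler(learn, phases)
learn.fit(5, callbacks=[sched])  # 2 + 3 epochs, covered by the two phases
The start_epoch argument is left at its default here; it is presumably intended for resuming an interrupted run part-way through an existing schedule.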
Callback methods
You don’t call these yourself - they’re called by fastai’s Callback system automatically to enable the class’s functionality.
on_batch_end
on_batch_end(train, **kwargs:Any)
Takes a step in the current phase and prepares the hyper-parameters for the next batch.
on_train_begin
on_train_begin(epoch:int, **kwargs:Any)
Initializes the hyper-parameters to the start values of the first phase.