callbacks.general_sched

Implementation of a flexible training API

TrainingPhase and General scheduler

Creates a scheduler that lets you train a model by following a sequence of different TrainingPhases.

class TrainingPhase[source][test]

TrainingPhase(length:int)

Schedule hyper-parameters for a phase of length iterations.

You can then schedule any hyper-parameter you want by using the following method.

schedule_hp[source][test]

schedule_hp(name, vals, anneal=None)

Adds a schedule for name between vals using anneal.

The phase makes the hyper-parameter vary from the first value in vals to the second, following anneal. If an annealing function is specified but vals is a float, the value will decay from vals to 0. If no annealing function is specified, the default is linear annealing when vals is a tuple, and a constant value when it is a float.

Note: If you want to use discriminative values, you can pass a numpy array in vals (or a tuple of them for start and stop).
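For instance, the different forms of vals can be mixed within one phase. Here is a minimal sketch; the phase length and the values are placeholders, and the explicit imports assume fastai v1, where the annealing functions live in fastai.callback:

```python
import numpy as np
from fastai.callback import annealing_cos, annealing_linear

phase = (TrainingPhase(100)                                        # a phase of 100 iterations
         .schedule_hp('lr', (1e-3, 1e-5), anneal=annealing_cos)    # tuple: cosine-anneal from 1e-3 to 1e-5
         .schedule_hp('mom', 0.9)                                  # float, no anneal: constant momentum
         .schedule_hp('wd', 1e-2, anneal=annealing_linear))        # float + anneal: decays from 1e-2 to 0

# Discriminative values: one start/stop value per layer group (here assuming 3 layer groups).
disc_phase = TrainingPhase(100).schedule_hp(
    'lr', (np.array([1e-4, 1e-3, 1e-2]), np.array([1e-6, 1e-5, 1e-4])), anneal=annealing_cos)
```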

The basic hyper-parameters are named:

  • ‘lr’ for learning rate
  • ‘mom’ for momentum (or beta1 in Adam)
  • ‘beta’ for the beta2 in Adam or the alpha in RMSprop
  • ‘wd’ for weight decay

You can also add any hyper-parameter that is in your optimizer (even if it’s custom or a GeneralOptimizer), like ‘eps’ if you’re using Adam.
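As a sketch of that last point, you could anneal Adam's epsilon alongside the learning rate; the values here are purely illustrative and assume the Learner's optimizer is Adam (fastai's default):

```python
phase = (TrainingPhase(200)
         .schedule_hp('lr', (1e-3, 1e-4), anneal=annealing_cos)
         .schedule_hp('eps', (1e-8, 1e-6)))   # linear annealing by default for a tuple
```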

Let's work through an example, using this to implement SGD with warm restarts.

```python
def fit_sgd_warm(learn, n_cycles, lr, mom, cycle_len, cycle_mult):
    # One phase per cycle: cosine-anneal the lr over the cycle, keep momentum constant.
    n = len(learn.data.train_dl)
    phases = [(TrainingPhase(n * (cycle_len * cycle_mult**i))
                 .schedule_hp('lr', lr, anneal=annealing_cos)
                 .schedule_hp('mom', mom)) for i in range(n_cycles)]
    sched = GeneralScheduler(learn, phases)
    learn.callbacks.append(sched)
    # Total epochs is the sum of the cycle lengths (a geometric series when cycle_mult != 1).
    if cycle_mult != 1:
        total_epochs = int(cycle_len * (1 - cycle_mult**n_cycles)/(1 - cycle_mult))
    else: total_epochs = n_cycles * cycle_len
    learn.fit(total_epochs)
```
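When cycle_mult != 1, total_epochs is simply the geometric-series sum of the cycle lengths, cycle_len * (cycle_mult**n_cycles - 1) / (cycle_mult - 1). With cycle_len=1, cycle_mult=2 and n_cycles=3 as in the call below, that gives 1 + 2 + 4 = 7 epochs.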
```python
path = untar_data(URLs.MNIST_SAMPLE)
data = ImageDataBunch.from_folder(path)
learn = Learner(data, simple_cnn((3,16,16,2)), metrics=accuracy)
fit_sgd_warm(learn, 3, 1e-3, 0.9, 1, 2)
```
| epoch | train_loss | valid_loss | accuracy | time |
|-------|------------|------------|----------|------|
| 0 | 0.162146 | 0.153532 | 0.942100 | 00:02 |
| 1 | 0.126112 | 0.117267 | 0.960255 | 00:02 |
| 2 | 0.112045 | 0.110586 | 0.962218 | 00:02 |
| 3 | 0.097603 | 0.090838 | 0.967615 | 00:02 |
| 4 | 0.086883 | 0.081375 | 0.973013 | 00:02 |
| 5 | 0.083673 | 0.076160 | 0.973994 | 00:02 |
| 6 | 0.084835 | 0.076211 | 0.973994 | 00:02 |
```python
learn.recorder.plot_lr()
```

[Figure 1: the learning rate plotted over the training iterations, showing the three cosine-annealed cycles with warm restarts]

class GeneralScheduler[source][test]

GeneralScheduler(learn:Learner, phases:Collection[TrainingPhase], start_epoch:int=None) :: LearnerCallback

Schedule multiple TrainingPhases for a Learner.
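As a minimal sketch of direct use (reusing the MNIST learn from above; the phase lengths and learning rates are placeholders), the scheduler can also be passed to fit through its callbacks argument instead of being appended to learn.callbacks:

```python
n = len(learn.data.train_dl)   # iterations per epoch
phases = [TrainingPhase(n).schedule_hp('lr', (1e-5, 1e-3), anneal=annealing_cos),      # 1 epoch of warmup
          TrainingPhase(2 * n).schedule_hp('lr', (1e-3, 1e-6), anneal=annealing_cos)]  # 2 epochs of annealing
sched = GeneralScheduler(learn, phases)
learn.fit(3, callbacks=[sched])
```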

Callback methods

You don’t call these yourself - they’re called by fastai’s Callback system automatically to enable the class’s functionality.

on_batch_end[source][test]

on_batch_end(train, **kwargs:Any)

Takes a step in the current phase and prepares the hyper-parameters for the next batch.

on_train_begin[source][test]

on_train_begin(epoch:int, **kwargs:Any)

Initializes the hyper-parameters to the start values of the first phase.

