callbacks

Callbacks implemented in the fastai library

List of callbacks

fastai’s training loop is highly extensible, with a rich callback system. See the callback docs if you’re interested in writing your own callback. See below for a list of callbacks that are provided with fastai, grouped by the module they’re defined in.

Every callback that is passed to Learner with the callback_fns parameter will be automatically stored as an attribute. The attribute name is snake-cased, so for instance ActivationStats will appear as learn.activation_stats (assuming your object is named learn).
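
For example, passing ActivationStats through callback_fns makes its collected statistics available as learn.activation_stats after training. A minimal sketch, assuming a DataBunch named data and a PyTorch model named model are already defined:

    learn = Learner(data, model, callback_fns=ActivationStats)
    learn.fit(1)
    learn.activation_stats.stats  # activation statistics gathered by the callback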

Callback

This sub-package contains more sophisticated callbacks that are each in their own module. They are (follow the links for more details):

LRFinder

Use Leslie Smith’s learning rate finder to find a good learning rate for training your model. Let’s see an example of its use on the MNIST dataset with a simple CNN.

    path = untar_data(URLs.MNIST_SAMPLE)
    data = ImageDataBunch.from_folder(path)
    def simple_learner(): return Learner(data, simple_cnn((3,16,16,2)), metrics=[accuracy])
    learn = simple_learner()

The fastai library already has a Learner method called lr_find that uses LRFinder to plot the loss as a function of the learning rate.

    learn.lr_find()
    LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.

    learn.recorder.plot()

[Figure 1: loss plotted against learning rate, from learn.recorder.plot()]

In this example, a learning rate around 2e-2 seems like the right fit.

    lr = 2e-2

OneCycleScheduler

Train with Leslie Smith’s 1cycle annealing method. Let’s train our simple learner using the one cycle policy.

    learn.fit_one_cycle(3, lr)

Total time: 00:07

epoch  train_loss  valid_loss  accuracy  time
0      0.109439    0.059349    0.980864  00:02
1      0.039582    0.023152    0.992149  00:02
2      0.019009    0.021239    0.991659  00:02

The learning rate and the momentum were changed during the epochs as follows (more info on the dedicated documentation page).

    learn.recorder.plot_lr(show_moms=True)

[Figure 2: learning rate and momentum schedules across the epochs, from learn.recorder.plot_lr(show_moms=True)]

MixUpCallback

Data augmentation using the method from mixup: Beyond Empirical Risk Minimization. It is very simple to add mixup in fastai:

    learn = Learner(data, simple_cnn((3, 16, 16, 2)), metrics=[accuracy]).mixup()

CSVLogger

Log the results of training in a csv file. Simply pass the CSVLogger callback to the Learner.

    learn = Learner(data, simple_cnn((3, 16, 16, 2)), metrics=[accuracy, error_rate], callback_fns=[CSVLogger])
    learn.fit(3)

Total time: 00:07

epoch  train_loss  valid_loss  accuracy  error_rate  time
0      0.127259    0.098069    0.969578  0.030422    00:02
1      0.084601    0.068024    0.974975  0.025025    00:02
2      0.055074    0.047266    0.983317  0.016683    00:02

You can then read the CSV file back:

    learn.csv_logger.read_logged_file()

epoch  train_loss  valid_loss  accuracy  error_rate
0      0.127259    0.098069    0.969578  0.030422
1      0.084601    0.068024    0.974975  0.025025
2      0.055074    0.047266    0.983317  0.016683

GeneralScheduler

Create your own multi-stage annealing schemes with a convenient API. To illustrate, let’s implement a 2 phase schedule.

    def fit_odd_schedule(learn, lr):
        # phase 1: one epoch with cosine annealing; phase 2: two epochs with polynomial annealing
        n = len(learn.data.train_dl)
        phases = [TrainingPhase(n).schedule_hp('lr', lr, anneal=annealing_cos),
                  TrainingPhase(n*2).schedule_hp('lr', lr, anneal=annealing_poly(2))]
        sched = GeneralScheduler(learn, phases)
        learn.callbacks.append(sched)
        total_epochs = 3
        learn.fit(total_epochs)

    learn = Learner(data, simple_cnn((3,16,16,2)), metrics=accuracy)
    fit_odd_schedule(learn, 1e-3)

Total time: 00:07

epoch  train_loss  valid_loss  accuracy  time
0      0.176607    0.157229    0.946025  00:02
1      0.140903    0.133690    0.954367  00:02
2      0.130910    0.131156    0.956820  00:02

    learn.recorder.plot_lr()

[Figure 3: learning rate schedule over the two phases, from learn.recorder.plot_lr()]

MixedPrecision

Use fp16 to take advantage of tensor cores on recent NVIDIA GPUs for a 200% or more speedup.
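
In practice the simplest way to enable this is the Learner.to_fp16 convenience method, which adds the MixedPrecision callback for you. A minimal sketch, reusing the data and simple_cnn model from the examples above:

    # train with fp16 activations while keeping an fp32 master copy of the weights
    learn = Learner(data, simple_cnn((3,16,16,2)), metrics=accuracy).to_fp16()
    learn.fit_one_cycle(1, 1e-2)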

HookCallback

Convenient wrapper for registering and automatically deregistering PyTorch hooks. Also contains pre-defined hook callback: ActivationStats.
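
To write your own, subclass HookCallback and implement a hook function; whatever it returns is stored by the hooks. A minimal sketch modelled on ActivationStats (the class name and the means attribute are illustrative, not part of the library):

    class MeanActivations(HookCallback):
        "Record the mean activation of every hooked module at each training batch."
        def on_train_begin(self, **kwargs):
            super().on_train_begin(**kwargs)
            self.means = []
        def hook(self, m, i, o): return o.mean().item()
        def on_batch_end(self, train, **kwargs):
            if train: self.means.append(self.hooks.stored)

    learn = Learner(data, simple_cnn((3,16,16,2)), callback_fns=MeanActivations)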

RNNTrainer

Callback taking care of all the tweaks to train an RNN.

TerminateOnNaNCallback

Stop training if the loss reaches NaN.
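
For instance, training with an absurdly high learning rate quickly makes the loss diverge to NaN, and the callback stops the run instead of wasting the remaining epochs. A sketch reusing the data and simple_cnn model from above:

    learn = Learner(data, simple_cnn((3,16,16,2)), metrics=[accuracy])
    learn.fit(2, 1e4, callbacks=[TerminateOnNaNCallback()])  # huge lr forces the loss to NaN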

EarlyStoppingCallback

Stop training if a given metric/validation loss doesn’t improve.
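
The callback is typically registered through callback_fns with functools.partial so you can choose what to monitor and how patient to be. A sketch reusing the data and simple_cnn model from above (the monitor, min_delta and patience values are just for illustration):

    learn = Learner(data, simple_cnn((3,16,16,2)), metrics=[accuracy],
                    callback_fns=[partial(EarlyStoppingCallback, monitor='accuracy',
                                          min_delta=0.01, patience=3)])
    learn.fit(50, 1e-42)  # a tiny learning rate: nothing improves, so training stops early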

SaveModelCallback

Save the model at every epoch, or the best model for a given metric/validation loss.

    learn = Learner(data, simple_cnn((3,16,16,2)), metrics=accuracy)
    learn.fit_one_cycle(3, 1e-4, callbacks=[SaveModelCallback(learn, every='epoch', monitor='accuracy')])

Total time: 00:07

epoch  train_loss  valid_loss  accuracy  time
0      0.679189    0.646599    0.804220  00:02
1      0.527475    0.497290    0.908243  00:02
2      0.464756    0.462471    0.917076  00:02

    !ls ~/.fastai/data/mnist_sample/models

    best.pth         bestmodel_2.pth  model_1.pth  model_4.pth    stage-1.pth
    bestmodel_0.pth  bestmodel_3.pth  model_2.pth  model_5.pth    tmp.pth
    bestmodel_1.pth  model_0.pth      model_3.pth  one_epoch.pth  trained_model.pth

ReduceLROnPlateauCallback

Reduce the learning rate by a given factor each time the monitored metric/validation loss doesn’t improve.
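
Like the other tracker callbacks, it is usually registered through callback_fns with partial. A sketch reusing the data and simple_cnn model from above (the monitor, patience and factor values are illustrative):

    learn = Learner(data, simple_cnn((3,16,16,2)), metrics=[accuracy],
                    callback_fns=[partial(ReduceLROnPlateauCallback, monitor='valid_loss',
                                          patience=2, factor=0.5)])
    learn.fit(10, 1e-2)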

PeakMemMetric

A callback that profiles GPU and general RAM usage during training.
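
It is added via callback_fns and reports per-epoch memory figures alongside the usual metrics. A sketch, assuming the callback lives in fastai.callbacks.mem and reusing the data and simple_cnn model from above:

    from fastai.callbacks.mem import PeakMemMetric
    learn = Learner(data, simple_cnn((3,16,16,2)), metrics=[accuracy], callback_fns=PeakMemMetric)
    learn.fit_one_cycle(1, 1e-2)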

StopAfterNBatches

Stop training after n batches of the first epoch.
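
This is handy for checking that a training setup runs end to end without paying for a full epoch. A sketch, assuming the callback accepts an n_batches argument and reusing the learner from above:

    learn.fit(1, callbacks=[StopAfterNBatches(n_batches=2)])  # bail out after two batches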

LearnerTensorboardWriter

A broadly useful callback for Learners that writes to TensorBoard: model histograms, losses/metrics, embeddings for the projector, and gradient stats.
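
It is attached through callback_fns with a log directory and a run name. A sketch, assuming the callback lives in fastai.callbacks.tensorboard and takes base_dir and name arguments (the path and run name are illustrative):

    from fastai.callbacks.tensorboard import LearnerTensorboardWriter
    learn = Learner(data, simple_cnn((3,16,16,2)), metrics=[accuracy],
                    callback_fns=[partial(LearnerTensorboardWriter,
                                          base_dir=Path('data/tensorboard'), name='run1')])
    learn.fit(1)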

train and basic_train

Recorder

Track per-batch and per-epoch smoothed losses and metrics.
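
Every Learner gets a Recorder automatically, and its contents can be plotted after training:

    learn.recorder.plot_losses()   # smoothed training loss and validation loss
    learn.recorder.plot_metrics()  # metrics recorded at each epoch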

ShowGraph

Dynamically display a learning chart during training.
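
Passing it through callback_fns updates a loss plot after every epoch. A sketch reusing the data and simple_cnn model from above:

    learn = Learner(data, simple_cnn((3,16,16,2)), metrics=accuracy, callback_fns=ShowGraph)
    learn.fit(3)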

BnFreeze

Freeze batchnorm layer moving average statistics for non-trainable layers.
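
This is mostly useful in transfer learning, where the frozen part of a pretrained model should also keep its batchnorm running statistics fixed. A sketch, assuming a pretrained resnet18 built with cnn_learner:

    learn = cnn_learner(data, models.resnet18, metrics=[accuracy], callback_fns=[BnFreeze])
    learn.fit(1)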

GradientClipping

Clip gradients during training.
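
The clipping threshold is set through partial when the callback is registered. A sketch reusing the data and simple_cnn model from above (the clip value is illustrative):

    learn = Learner(data, simple_cnn((3,16,16,2)), metrics=[accuracy],
                    callback_fns=partial(GradientClipping, clip=0.1))
    learn.fit(1)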

