callbacks.lr_finder

Implementation of the LR Range test from Leslie Smith

Learning Rate Finder

The learning rate finder plots the learning rate vs. loss relationship for a Learner. The idea is to reduce the amount of guesswork in picking a good starting learning rate.

Overview:

  1. First run the learning rate finder: learn.lr_find()
  2. Plot the learning rate vs. loss: learn.recorder.plot()
  3. Pick a learning rate before the point where the loss diverges, then start training (a combined sketch follows below)
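
Putting the three steps together, here is a minimal sketch. It simply mirrors the worked example further down this page (the 1e-2 rate is the value that example ends up choosing):

    from fastai.vision import *          # assumed standard fastai v1 import

    path = untar_data(URLs.MNIST_SAMPLE)
    data = ImageDataBunch.from_folder(path)
    learn = Learner(data, simple_cnn((3, 16, 16, 2)), metrics=[accuracy])

    learn.lr_find()                      # 1. run the LR range test
    learn.recorder.plot()                # 2. plot loss vs. learning rate
    learn.fit(2, 1e-2)                   # 3. train with a rate picked from the plot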

Technical Details: (first described by Leslie Smith)

Train the Learner over a few iterations: start with a very low start_lr and increase it at each mini-batch until it reaches a very high end_lr. The Recorder records the loss at each iteration. Plot those losses against the learning rate to find the best value before the loss diverges.
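
Conceptually, the learning rate is increased multiplicatively at each mini-batch so that it covers several orders of magnitude. A minimal sketch of one common way to space the rates between start_lr and end_lr (fastai's own scheduler may differ in detail):

    import numpy as np

    def lr_range(start_lr=1e-7, end_lr=10, num_it=100):
        """Exponentially spaced learning rates from start_lr to end_lr."""
        return start_lr * (end_lr / start_lr) ** (np.arange(num_it) / (num_it - 1))

    lrs = lr_range()
    print(lrs[0], lrs[-1])   # ~1e-07 and ~10.0: each step multiplies the LR by a constant factor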

Choosing a good learning rate

For a more intuitive explanation, please check out Sylvain Gugger’s post

    from fastai.vision import *   # import assumed to be needed for the symbols below (fastai v1)

    path = untar_data(URLs.MNIST_SAMPLE)
    data = ImageDataBunch.from_folder(path)
    def simple_learner(): return Learner(data, simple_cnn((3,16,16,2)), metrics=[accuracy])
    learn = simple_learner()

First we run this command to launch the search:

lr_find[source][test]

lr_find(learn:Learner, start_lr:Floats=1e-07, end_lr:Floats=10, num_it:int=100, stop_div:bool=True, wd:float=None)

Tests found for lr_find:

  • pytest -sv tests/test_train.py::test_lr_find [source]
  • pytest -sv tests/test_vision_train.py::test_lrfind [source]

To run tests please refer to this guide.

Explores the learning rate from start_lr to end_lr over num_it iterations in learn. If stop_div, it stops as soon as the loss diverges.
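
All of these can be overridden. For instance, a hedged example sweeping a narrower range with an explicit weight-decay override (the values here are arbitrary, for illustration only):

    learn.lr_find(start_lr=1e-6, end_lr=1, num_it=300, wd=1e-2)
    learn.recorder.plot()

In the run below, the default range is kept, but stop_div is disabled and the sweep is lengthened to 200 iterations: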

    learn.lr_find(stop_div=False, num_it=200)

    LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.

Then we plot the loss versus the learning rates. We’re interested in finding a good order of magnitude of learning rate, so we plot with a log scale.

    learn.recorder.plot()

LRFinder - Figure 1: loss vs. learning rate plot

Then, we choose a value that is approximately in the middle of the sharpest downward slope. This is given as an indication by the LR Finder tool, so let’s try 1e-2.

    simple_learner().fit(2, 1e-2)

    epoch  train_loss  valid_loss  accuracy  time
    1      0.127434    0.070243    0.973013  00:02
    2      0.050703    0.039493    0.984789  00:02

Don’t just pick the minimum value from the plot!

    learn = simple_learner()
    simple_learner().fit(2, 1e-0)

    epoch  train_loss  valid_loss  accuracy  time
    1      0.727221    0.693147    0.495584  00:02
    2      0.693826    0.693147    0.495584  00:02

Picking a value before the downward slope results in slow training:

    learn = simple_learner()
    simple_learner().fit(2, 1e-3)

    epoch  train_loss  valid_loss  accuracy  time
    1      0.152897    0.134366    0.950932  00:02
    2      0.120961    0.117550    0.960746  00:02

Suggested LR

If you pass suggestion=True in learn.recorder.plot, you will see the point where the gradient is steepest marked with a red dot on the graph. We can use that point as a first guess for a learning rate.

    learn.lr_find(stop_div=False, num_it=200)

    LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.

    learn.recorder.plot(suggestion=True)

    Min numerical gradient: 5.25E-03

LRFinder - Figure 2: the same plot with the suggested learning rate marked by a red dot

You can access the corresponding learning rate like this:

    min_grad_lr = learn.recorder.min_grad_lr
    min_grad_lr

    0.005248074602497722

    learn = simple_learner()
    simple_learner().fit(2, min_grad_lr)

    epoch  train_loss  valid_loss  accuracy  time
    1      0.109475    0.081607    0.970559  00:02
    2      0.070303    0.050977    0.982826  00:02

class LRFinder[source][test]

LRFinder(learn:Learner, start_lr:float=1e-07, end_lr:float=10, num_it:int=100, stop_div:bool=True) :: LearnerCallback

No tests found for LRFinder. To contribute a test please refer to this guide and this discussion.

Causes learn to run a mock training from start_lr to end_lr for num_it iterations.
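
You normally reach it through learn.lr_find(), which builds the callback and runs a short fit with it. A hedged sketch of roughly how that wiring might look (the epoch count and the starting LR passed to fit are assumptions about the internals, not confirmed source):

    import numpy as np

    # A sketch of how lr_find might attach LRFinder to a short fit (assumptions, not actual source)
    cb = LRFinder(learn, start_lr=1e-7, end_lr=10, num_it=100, stop_div=True)
    epochs = int(np.ceil(100 / len(learn.data.train_dl)))   # enough epochs to cover num_it mini-batches
    learn.fit(epochs, 1e-7, callbacks=[cb])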

Callback methods

You don’t call these yourself - they’re called by fastai’s Callback system automatically to enable the class’s functionality.

on_train_begin[source][test]

on_train_begin(pbar, **kwargs:Any)

No tests found for on_train_begin. To contribute a test please refer to this guide and this discussion.

Initialize optimizer and learner hyperparameters.

on_batch_end[source][test]

on_batch_end(iteration:int, smooth_loss:TensorOrNumber, **kwargs:Any)

No tests found for on_batch_end. To contribute a test please refer to this guide and this discussion.

Determine whether the loss has run away and training should stop.
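
A minimal sketch of what such a divergence test can look like; the specific multiplicative threshold here is an assumption, not fastai's documented value:

    import math

    def loss_has_diverged(smooth_loss, best_loss, factor=4.0):
        """Stop when the smoothed loss is NaN or blows up past the best loss seen so far."""
        return math.isnan(smooth_loss) or smooth_loss > factor * best_loss

    print(loss_has_diverged(3.2, best_loss=0.7))   # True  -- loss has run away
    print(loss_has_diverged(0.6, best_loss=0.7))   # False -- still improving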

on_epoch_end[source][test]

on_epoch_end(**kwargs:Any)

No tests found for on_epoch_end. To contribute a test please refer to this guide and this discussion.

Called at the end of an epoch.

on_train_end[source][test]

on_train_end(epoch:int, num_batch:int, **kwargs:Any)

No tests found for on_train_end. To contribute a test please refer to this guide and this discussion.

Clean up the learn model weights disturbed during the LRFinder exploration.
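
Conceptually, the exploration is made non-destructive by checkpointing the weights before the mock training and restoring them here, so the range test leaves the Learner unchanged. A minimal sketch (the checkpoint name is hypothetical):

    learn.save('tmp_lr_find')    # at the start of the exploration (on_train_begin)
    ...                          # exploratory mini-batches with an ever-growing LR
    learn.load('tmp_lr_find')    # here, at the end (on_train_end)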

