Tracking callbacks
Callbacks that make decisions depending on how a monitored metric/loss behaves
class TerminateOnNaNCallback [source]

TerminateOnNaNCallback(after_create=None, before_fit=None, before_epoch=None, before_train=None, before_batch=None, after_pred=None, after_loss=None, before_backward=None, before_step=None, after_cancel_step=None, after_step=None, after_cancel_batch=None, after_batch=None, after_cancel_train=None, after_train=None, before_validate=None, after_cancel_validate=None, after_validate=None, after_cancel_epoch=None, after_epoch=None, after_cancel_fit=None, after_fit=None) :: Callback

A Callback that terminates training if the loss is NaN.
```python
learn = synth_learner()
learn.fit(10, lr=100, cbs=TerminateOnNaNCallback())
```
epoch | train_loss | valid_loss | time |
---|---|---|---|
0 | 1914263325772146366332652801648230400.000000 | | 00:00 |
```python
assert len(learn.recorder.losses) < 10 * len(learn.dls.train)
for l in learn.recorder.losses:
    assert not torch.isinf(l) and not torch.isnan(l)
```
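The callback's logic is simple: after each batch it checks whether the loss is infinite or NaN and, if so, cancels the fit. A minimal sketch of that check in plain Python (the helper name is illustrative, not the fastai API; in fastai the cancellation is done by raising `CancelFitException`):

```python
import math

def consume_until_nan(losses):
    """Illustrative sketch of TerminateOnNaNCallback's after_batch check:
    stop consuming batches as soon as the loss becomes NaN or infinite."""
    kept = []
    for loss in losses:
        if math.isnan(loss) or math.isinf(loss):
            break  # training would be cancelled here
        kept.append(loss)
    return kept
```

This is why the assertions above hold: every recorded loss is finite, and fewer batches than 10 full epochs were run.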
class TrackerCallback [source]

TrackerCallback(monitor='valid_loss', comp=None, min_delta=0.0, reset_on_fit=True) :: Callback

A Callback that keeps track of the best value in monitor.
When implementing a Callback whose behavior depends on the best value of a metric or loss, subclass this Callback and use its best (the best value so far) and new_best (there was a new best value this epoch) attributes. If you want to maintain best over subsequent calls to fit (e.g. Learner.fit_one_cycle), set reset_on_fit=False.

comp is the comparison operator used to determine if a value is better than another (defaults to np.less if 'loss' is in the name passed in monitor, np.greater otherwise) and min_delta is an optional float that requires a new value to beat the current best (in the direction given by comp) by at least that amount.
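The bookkeeping described above can be sketched in plain Python/NumPy. The Tracker class below is illustrative, not the fastai implementation; it mirrors the best/new_best/min_delta behavior under those assumptions:

```python
import numpy as np

class Tracker:
    """Illustrative sketch of TrackerCallback's bookkeeping (not the fastai API)."""
    def __init__(self, monitor='valid_loss', comp=None, min_delta=0.0):
        # lower is better for losses, higher otherwise
        if comp is None:
            comp = np.less if 'loss' in monitor else np.greater
        # flip the sign of min_delta so `val - min_delta` moves the bar
        # in the right direction for both comparison operators
        if comp == np.less:
            min_delta = -min_delta
        self.comp, self.min_delta = comp, min_delta
        self.best = float('inf') if comp == np.less else -float('inf')
        self.new_best = False

    def update(self, val):
        """Record one epoch's monitored value and flag whether it is a new best."""
        self.new_best = bool(self.comp(val - self.min_delta, self.best))
        if self.new_best:
            self.best = val
        return self.new_best
```

With min_delta=0.1, a value must beat the stored best by more than 0.1 to count as an improvement; a metric without 'loss' in its name is compared with np.greater instead.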
class EarlyStoppingCallback [source]

EarlyStoppingCallback(monitor='valid_loss', comp=None, min_delta=0.0, patience=1, reset_on_fit=True) :: TrackerCallback

A TrackerCallback that terminates training when the monitored quantity stops improving.
comp is the comparison operator used to determine if a value is better than another (defaults to np.less if 'loss' is in the name passed in monitor, np.greater otherwise) and min_delta is an optional float that requires a new value to beat the current best (in the direction given by comp) by at least that amount. patience is the number of epochs you're willing to wait without improvement.
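The stopping rule amounts to counting epochs since the last improvement that cleared min_delta. A sketch of that rule for a monitored loss (the helper and its name are illustrative, not the fastai API):

```python
def epochs_before_stop(values, min_delta=0.0, patience=1):
    """Illustrative sketch of EarlyStoppingCallback's rule for a loss:
    stop once `patience` consecutive epochs fail to beat the best by min_delta."""
    best, since_best = float('inf'), 0
    for epoch, val in enumerate(values):
        if val < best - min_delta:
            best, since_best = val, 0  # real improvement: reset the counter
        else:
            since_best += 1
            if since_best >= patience:
                return epoch + 1  # number of epochs actually run
    return len(values)
```

With min_delta=0.1 and patience=2, the tiny improvements in the runs below don't count, so training stops after three epochs.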
```python
learn = synth_learner(n_trn=2, metrics=F.mse_loss)
learn.fit(n_epoch=200, lr=1e-7, cbs=EarlyStoppingCallback(monitor='mse_loss', min_delta=0.1, patience=2))
```
epoch | train_loss | valid_loss | mse_loss | time |
---|---|---|---|---|
0 | 25.913376 | 28.702148 | 28.702148 | 00:00 |
1 | 25.952229 | 28.702074 | 28.702074 | 00:00 |
2 | 25.970026 | 28.701965 | 28.701965 | 00:00 |
No improvement since epoch 0: early stopping
```python
learn.validate()
```

(#2) [28.70196533203125,28.70196533203125]
```python
learn = synth_learner(n_trn=2)
learn.fit(n_epoch=200, lr=1e-7, cbs=EarlyStoppingCallback(monitor='valid_loss', min_delta=0.1, patience=2))
```
epoch | train_loss | valid_loss | time |
---|---|---|---|
0 | 15.580492 | 8.504006 | 00:00 |
1 | 15.592066 | 8.503983 | 00:00 |
2 | 15.603076 | 8.503948 | 00:00 |
No improvement since epoch 0: early stopping
class SaveModelCallback [source]

SaveModelCallback(monitor='valid_loss', comp=None, min_delta=0.0, fname='model', every_epoch=False, with_opt=False, reset_on_fit=True) :: TrackerCallback

A TrackerCallback that saves the best model during training and loads it at the end.
comp is the comparison operator used to determine if a value is better than another (defaults to np.less if 'loss' is in the name passed in monitor, np.greater otherwise) and min_delta is an optional float that requires a new value to beat the current best (in the direction given by comp) by at least that amount. The model is saved in learn.path/learn.model_dir/fname.pth, either at every epoch if every_epoch=True, or at each improvement of the monitored quantity otherwise.
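Which epochs actually trigger a save can be sketched as follows (illustrative helper, not the fastai API; with every_epoch=False only a new best value writes the file, which is then overwritten at the next improvement):

```python
def epochs_that_save(losses, every_epoch=False):
    """Illustrative sketch of SaveModelCallback's save decision for a loss:
    save at every epoch, or only when the monitored value improves."""
    best, saved = float('inf'), []
    for epoch, loss in enumerate(losses):
        if every_epoch or loss < best:
            saved.append(epoch)
        best = min(best, loss)
    return saved
```

This matches the run below: with the default settings, both epochs improve valid_loss, so both print "Better model found".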
```python
learn = synth_learner(n_trn=2, path=Path.cwd()/'tmp')
learn.fit(n_epoch=2, cbs=SaveModelCallback())
assert (Path.cwd()/'tmp/models/model.pth').exists()
learn.fit(n_epoch=2, cbs=SaveModelCallback(every_epoch=True))
for i in range(2): assert (Path.cwd()/f'tmp/models/model_{i}.pth').exists()
shutil.rmtree(Path.cwd()/'tmp')
```
epoch | train_loss | valid_loss | time |
---|---|---|---|
0 | 10.488046 | 10.307009 | 00:00 |
1 | 10.410013 | 10.064041 | 00:00 |
Better model found at epoch 0 with valid_loss value: 10.307008743286133.
Better model found at epoch 1 with valid_loss value: 10.064041137695312.
epoch | train_loss | valid_loss | time |
---|---|---|---|
0 | 10.038021 | 9.718258 | 00:00 |
1 | 9.838678 | 9.300011 | 00:00 |
ReduceLROnPlateau

class ReduceLROnPlateau [source]

ReduceLROnPlateau(monitor='valid_loss', comp=None, min_delta=0.0, patience=1, factor=10.0, min_lr=0, reset_on_fit=True) :: TrackerCallback

A TrackerCallback that reduces the learning rate when a metric has stopped improving.
```python
learn = synth_learner(n_trn=2)
learn.fit(n_epoch=4, lr=1e-7, cbs=ReduceLROnPlateau(monitor='valid_loss', min_delta=0.1, patience=2))
```
epoch | train_loss | valid_loss | time |
---|---|---|---|
0 | 11.299067 | 16.745235 | 00:00 |
1 | 11.289301 | 16.745203 | 00:00 |
2 | 11.276413 | 16.745152 | 00:00 |
3 | 11.267982 | 16.745146 | 00:00 |
Epoch 2: reducing lr to 1e-08
```python
learn = synth_learner(n_trn=2)
learn.fit(n_epoch=6, lr=5e-8, cbs=ReduceLROnPlateau(monitor='valid_loss', min_delta=0.1, patience=2, min_lr=1e-8))
```
epoch | train_loss | valid_loss | time |
---|---|---|---|
0 | 21.629301 | 15.617614 | 00:00 |
1 | 21.608873 | 15.617589 | 00:00 |
2 | 21.620173 | 15.617556 | 00:00 |
3 | 21.619131 | 15.617546 | 00:00 |
4 | 21.615915 | 15.617537 | 00:00 |
5 | 21.606327 | 15.617526 | 00:00 |
Epoch 2: reducing lr to 1e-08
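Both runs above end with "reducing lr to 1e-08": once patience runs out, the current learning rate is divided by factor and floored at min_lr. That update can be sketched as (illustrative helper name):

```python
def reduced_lr(lr, factor=10.0, min_lr=0.0):
    """Illustrative sketch of the update ReduceLROnPlateau applies
    once `patience` epochs pass without sufficient improvement."""
    return max(lr / factor, min_lr)
```

In the first run, 1e-7 / 10 gives 1e-8; in the second, 5e-8 / 10 would fall below min_lr=1e-8, so the learning rate is clamped there instead.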
©2021 fast.ai. All rights reserved.
Site last generated: Mar 31, 2021