callbacks.rnn
Implementation of a callback for RNN training
Training tweaks for an RNN
This callback groups together a few tweaks for training RNNs properly. They all come from this article by Stephen Merity et al.
Activation Regularization (AR): on top of weight decay, we apply another, closely related form of regularization, which consists of adding to the loss a scaled sum of the squares of the outputs (with dropout applied) of the various layers of the RNN. Intuitively, whereas weight decay pushes the network to learn small weights, AR pushes the model to produce small activations.
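As a rough illustration, here is the AR term computed in plain Python for one layer's post-dropout activations of a single sequence. The helper name `ar_penalty` and the list-of-lists layout (time steps by hidden units, batch dimension dropped) are assumptions of this sketch, not part of fastai's API.

```python
from typing import List

# One layer's post-dropout activations for a single sequence:
# dropped[t][j] is hidden unit j at time step t (batch dimension omitted).
Activations = List[List[float]]


def ar_penalty(dropped: Activations, alpha: float) -> float:
    """Activation Regularization: alpha times the sum of the squared activations."""
    return alpha * sum(x * x for step in dropped for x in step)


dropped = [[1.0, -2.0], [0.0, 3.0]]    # 2 time steps, 2 hidden units
print(ar_penalty(dropped, alpha=2.0))  # squares 1 + 4 + 0 + 9 = 14, times 2 -> 28.0
```

Averaging the squares instead of summing them, as some implementations do, would only rescale `alpha`.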
Temporal Activation Regularization (TAR): lastly, we add to the loss a scaled sum of the squares of $h_{t+1} - h_t$, where $h_t$ is the output (before dropout is applied) of one layer of the RNN at time step $t$ (word $t$ of the sentence). This encourages the model to produce activations that don't change too quickly between two consecutive words of the sentence.
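The TAR term can be sketched the same way: pair each time step with the next one and sum the squared differences. `tar_penalty` is a hypothetical helper, and the pre-dropout activations are assumed to be a list of time steps, each a list of hidden-unit values.

```python
from typing import List

# One layer's pre-dropout activations: raw[t][j] is unit j at time step t.
Activations = List[List[float]]


def tar_penalty(raw: Activations, beta: float) -> float:
    """Temporal Activation Regularization: beta times the sum of the
    squared differences h_{t+1} - h_t over consecutive time steps."""
    total = 0.0
    for h_t, h_next in zip(raw, raw[1:]):  # consecutive pairs (h_t, h_{t+1})
        total += sum((b - a) ** 2 for a, b in zip(h_t, h_next))
    return beta * total


raw = [[0.0, 1.0], [1.0, 1.0], [1.0, 3.0]]  # 3 time steps, 2 units
print(tar_penalty(raw, beta=1.0))           # diffs (1, 0) and (0, 2) -> 1 + 4 = 5.0
```

A sequence with a single time step has no consecutive pair, so its penalty is 0.0.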
class `RNNTrainer`
`RNNTrainer(learn:Learner, alpha:float=0.0, beta:float=0.0) :: LearnerCallback`
`Callback` that groups together lr adjustment to `seq_len`, AR and TAR.
Create a `Callback` that adds to the `Learner` `learn` the RNN tweaks for training on data with `bptt` (backpropagation through time). `alpha` is the scale for AR, `beta` is the scale for TAR.
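Putting the AR and TAR descriptions together, the training loss for one RNN layer can be written as below. The notation is ours, not fastai's: $\mathcal{L}$ is the original loss, $T$ the sequence length, $h_t$ the layer's output at time step $t$ before dropout and $\tilde{h}_t$ the same output after dropout. With the defaults `alpha=0.0` and `beta=0.0`, both extra terms vanish.

$$
\mathcal{L}_{\text{total}} = \mathcal{L} + \alpha \sum_{t=1}^{T} \lVert \tilde{h}_t \rVert_2^2 + \beta \sum_{t=1}^{T-1} \lVert h_{t+1} - h_t \rVert_2^2
$$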
Callback methods
You don't call these yourself; they're called automatically by fastai's `Callback` system to enable the class's functionality.
on_epoch_begin
`on_epoch_begin(**kwargs)`
Reset the hidden state of the model.
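A minimal sketch of why the hidden state is reset, using a hypothetical toy model (`ToyRNN`, with a hypothetical `reset()` method) rather than a fastai model. In this toy, the state left by one call is reused by the next, so without a reset the end of one epoch would leak into the start of the next. The class names and the toy recurrence are assumptions made for illustration.

```python
from typing import List, Optional


class ToyRNN:
    """Hypothetical stateful model: the hidden state left by one call is
    the starting state of the next call."""

    def __init__(self) -> None:
        self.hidden: Optional[float] = None

    def reset(self) -> None:
        # Forget the carried-over state; the next call starts from zero.
        self.hidden = None

    def forward(self, xs: List[float]) -> float:
        h = 0.0 if self.hidden is None else self.hidden
        for x in xs:
            h = 0.5 * h + x  # toy recurrence standing in for an RNN cell
        self.hidden = h      # carried into the next call
        return h


class ResetOnEpochBegin:
    """Sketch of the hook: reset the model's hidden state at each epoch start."""

    def __init__(self, model: ToyRNN) -> None:
        self.model = model

    def on_epoch_begin(self, **kwargs) -> None:
        self.model.reset()


model = ToyRNN()
cb = ResetOnEpochBegin(model)
print(model.forward([1.0, 1.0]))  # 1.5
print(model.forward([1.0]))       # 1.75: starts from the carried-over 1.5
cb.on_epoch_begin()
print(model.forward([1.0]))       # 1.0: starts from zero again
```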
on_loss_begin
`on_loss_begin(last_output:Tuple[Tensor, Sequence[Tensor], Sequence[Tensor]], **kwargs)`
Save the extra outputs for later and return only the true output.
The fastai RNNs return a `last_output` that is a tuple of three elements: the true output (which is returned) and the hidden states before and after dropout (which are saved internally for use by `on_backward_begin`).
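The split can be sketched in plain Python, with strings standing in for tensors. `SplitOutputs` is a hypothetical name, and how the returned value is fed back into fastai's training loop is not modeled here; the sketch only shows the unpacking described above.

```python
from typing import Any, Sequence, Tuple


class SplitOutputs:
    """Hypothetical sketch of on_loss_begin: stash the extra outputs and
    return only the true output, which is what the loss should see."""

    def __init__(self) -> None:
        self.raw_out: Sequence[Any] = ()  # hidden states before dropout
        self.out: Sequence[Any] = ()      # hidden states after dropout

    def on_loss_begin(self, last_output: Tuple[Any, Sequence[Any], Sequence[Any]],
                      **kwargs: Any) -> Any:
        true_output, self.raw_out, self.out = last_output
        return true_output


cb = SplitOutputs()
result = cb.on_loss_begin(("logits", ["raw_layer1", "raw_layer2"],
                                     ["drop_layer1", "drop_layer2"]))
print(result)          # logits
print(cb.raw_out[-1])  # raw_layer2
```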
on_backward_begin
`on_backward_begin(last_loss:Rank0Tensor, last_input:Tensor, **kwargs)`
Apply AR and TAR to `last_loss`.
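A plain-Python sketch of this step for a single layer's activations; which layers are penalized is not specified here. `apply_ar_tar` is a hypothetical helper, not a fastai function. Following the descriptions at the top of this page, AR uses the post-dropout outputs and TAR the pre-dropout ones; skipping a term when its scale is 0.0 (the default) is a choice of this sketch.

```python
from typing import List

# acts[t][j]: unit j at time step t (one layer, batch dimension omitted)
Activations = List[List[float]]


def apply_ar_tar(last_loss: float, dropped: Activations, raw: Activations,
                 alpha: float = 0.0, beta: float = 0.0) -> float:
    """Add the AR and TAR terms to the loss."""
    loss = last_loss
    if alpha != 0.0:
        # AR: squares of the outputs after dropout
        loss += alpha * sum(x * x for step in dropped for x in step)
    if beta != 0.0:
        # TAR: squared change between consecutive time steps, before dropout
        loss += beta * sum((b - a) ** 2
                           for h_t, h_next in zip(raw, raw[1:])
                           for a, b in zip(h_t, h_next))
    return loss


print(apply_ar_tar(1.0,
                   dropped=[[1.0, 0.0], [0.0, 2.0]],  # AR: 0.5 * (1 + 4) = 2.5
                   raw=[[0.0, 1.0], [1.0, 1.0]],      # TAR: 2.0 * 1 = 2.0
                   alpha=0.5, beta=2.0))              # 1.0 + 2.5 + 2.0 = 5.5
```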
©2021 fast.ai. All rights reserved.
Site last generated: Jan 5, 2021