A Generic Optimizer
To build up our accelerated SGD tricks, we’ll need to start with a nice flexible optimizer foundation. No library prior to fastai provided such a foundation, but during fastai’s development we realized that all the optimizer improvements we’d seen in the academic literature could be handled using optimizer callbacks. These are small pieces of code that we can compose, mix and match in an optimizer to build the optimizer step. They are called by fastai’s lightweight Optimizer class. These are the definitions in Optimizer of the two key methods that we’ve been using in this book:
def zero_grad(self):
    for p,*_ in self.all_params():
        p.grad.detach_()
        p.grad.zero_()

def step(self):
    for p,pg,state,hyper in self.all_params():
        for cb in self.cbs:
            state = _update(state, cb(p, **{**state, **hyper}))
        self.state[p] = state
As we saw when training an MNIST model from scratch, zero_grad just loops through the parameters of the model and sets the gradients to zero. It also calls detach_, which removes any history of gradient computation, since it won’t be needed after zero_grad.
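On a single tensor, the effect of those two calls looks like this (a small illustrative snippet, not part of the book’s code):

import torch

p = torch.tensor([3.0], requires_grad=True)
(p * 2).sum().backward()
print(p.grad)        # tensor([2.])

p.grad.detach_()     # drop any computation history attached to the gradient
p.grad.zero_()       # reset the gradient to zero in place
print(p.grad)        # tensor([0.])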
The more interesting method is step, which loops through the callbacks (cbs) and calls them to update the parameters (the _update function just calls state.update if there’s anything returned by cb).
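The _update helper isn’t shown here; a minimal sketch consistent with that description might look like this (an assumption based on the behavior just described, not fastai’s actual source):

def _update(state, new=None):
    # A callback may return nothing; in that case leave the state untouched
    if new is None: return state
    # If the callback returned a dict of new values, merge it into the state
    if isinstance(new, dict): state.update(new)
    return state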
As you can see, Optimizer doesn’t actually do any SGD steps itself. Let’s see how we can add SGD to Optimizer.
Here’s an optimizer callback that does a single SGD step, by multiplying -lr by the gradients and adding that to the parameter (when Tensor.add_ in PyTorch is given an alpha argument, the other tensor is multiplied by alpha before the addition):
In [ ]:
def sgd_cb(p, lr, **kwargs): p.data.add_(p.grad.data, alpha=-lr)
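To sanity-check the callback on its own, we can call it directly on a tensor that has a gradient (a small hypothetical test, not from the book):

import torch

p = torch.tensor([1.0, 2.0], requires_grad=True)
loss = (p ** 2).sum()
loss.backward()            # p.grad is now tensor([2., 4.])

sgd_cb(p, lr=0.1)          # in-place update: p <- p - lr * p.grad
print(p.data)              # tensor([0.8000, 1.6000])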
We can pass this to Optimizer using the cbs parameter; we’ll need to use partial since Learner will call this function to create our optimizer later:
In [ ]:
opt_func = partial(Optimizer, cbs=[sgd_cb])
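To make the role of partial concrete, here’s roughly what happens when the optimizer is eventually built: the parameters and learning rate are only supplied at that point, so the callbacks have to be baked in ahead of time. This is an illustrative sketch (the exact call Learner makes, and the model variable used here, are assumptions):

# Hypothetical illustration -- not fastai's internal code.
params = model.parameters()       # assumes a `model` is already defined
opt = opt_func(params, lr=0.03)   # same as Optimizer(params, cbs=[sgd_cb], lr=0.03)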
Let’s see if this trains:
In [ ]:
learn = get_learner(opt_func=opt_func)
learn.fit(3, 0.03)
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 2.730918 | 2.009971 | 0.332739 | 00:09 |
1 | 2.204893 | 1.747202 | 0.441529 | 00:09 |
2 | 1.875621 | 1.684515 | 0.445350 | 00:09 |
It’s working! So that’s how we create SGD from scratch in fastai. Now let’s see what “momentum” is.