Improving the RNN

Looking at the code for our RNN, one thing that seems problematic is that we are initializing our hidden state to zero for every new input sequence. Why is that a problem? We made our sample sequences short so they would fit easily into batches. But if we order the samples correctly, those sample sequences will be read in order by the model, exposing the model to long stretches of the original sequence.

Another thing we can look at is having more signal: why only predict the fourth word when we could use the intermediate predictions to also predict the second and third words?

Let’s see how we can implement those changes, starting with adding some state.

Maintaining the State of an RNN

Because we initialize the model’s hidden state to zero for each new sample, we are throwing away all the information we have about the sentences we have seen so far, which means that our model doesn’t actually know how far along we are in the overall counting sequence. This is easily fixed; we can simply move the initialization of the hidden state to __init__.

But this fix will create its own subtle, but important, problem. It effectively makes our neural network as deep as the entire number of tokens in our document. For instance, if there were 10,000 tokens in our dataset, we would be creating a 10,000-layer neural network.

To see why this is the case, consider the original pictorial representation of our recurrent neural network shown earlier, before we refactored it with a for loop. You can see that each layer corresponds to one input token. The representation of a recurrent neural network before refactoring with the for loop is called the unrolled representation. It is often helpful to consider the unrolled representation when trying to understand an RNN.
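To make the unrolled view concrete, here is a minimal sketch that uses plain Python floats instead of PyTorch layers. The scalar weights W_IH and W_HH are hypothetical stand-ins for the embedding and the hidden-to-hidden linear layer (biases omitted), and relu is written by hand. The looped and unrolled versions perform exactly the same computation:

```python
# Hypothetical scalar weights standing in for i_h and h_h (no biases).
W_IH, W_HH = 0.5, 0.8


def relu(v):
    return max(0.0, v)


def step(h, x):
    # One time step: add the input's contribution, then apply the hidden-to-hidden layer.
    return relu(W_HH * (h + W_IH * x))


def rnn_loop(xs, h=0.0):
    # Refactored form: the same layer applied repeatedly in a for loop.
    for x in xs:
        h = step(h, x)
    return h


def rnn_unrolled(x0, x1, x2):
    # Unrolled form: each input token gets its own explicit layer.
    h = step(0.0, x0)
    h = step(h, x1)
    h = step(h, x2)
    return h


print(rnn_loop([1.0, 2.0, 3.0]) == rnn_unrolled(1.0, 2.0, 3.0))  # True
```

The loop is only a compact way of writing one layer per token, so both forms return the same value.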

The problem with a 10,000-layer neural network is that if and when you get to the 10,000th word of the dataset, you will still need to calculate the derivatives all the way back to the first layer. This is going to be very slow indeed, and very memory-intensive. It is unlikely that you’ll be able to store even one mini-batch on your GPU.

The solution to this problem is to tell PyTorch that we do not want to backpropagate the derivatives through the entire implicit neural network. Instead, we will just keep the last three layers of gradients. To remove all of the gradient history in PyTorch, we use the detach method.
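A toy scalar autograd node can illustrate how detach affects the depth of the computation graph. The Value class below is hypothetical and only records which operations produced each number; it is not how PyTorch is implemented. The helper function imitates LMModel3's loop by feeding tokens in chunks of three. For each chunk, it reports the length of the chain of operations behind that chunk's output:

```python
class Value:
    """Toy scalar node that remembers the nodes it was computed from."""

    def __init__(self, data, parents=()):
        self.data = data
        self.parents = parents

    def __add__(self, other):
        return Value(self.data + other.data, (self, other))

    def __mul__(self, other):
        return Value(self.data * other.data, (self, other))

    def detach(self):
        # Same number, but no parents: gradients could not flow past this point.
        return Value(self.data)

    def depth(self):
        # Length of the longest chain of operations leading to this node.
        return 1 + max((p.depth() for p in self.parents), default=0)


def chunk_output_depths(tokens, w, chunk_len=3, detach=True):
    h = Value(0.0)
    depths = []
    for start in range(0, len(tokens), chunk_len):
        for x in tokens[start:start + chunk_len]:
            h = (h + Value(x)) * w      # one "layer" per token
        depths.append(h.depth())        # graph length behind this chunk's output
        if detach:
            h = h.detach()              # keep the activation, drop its history
    return depths


w = Value(0.5)
tokens = [1.0] * 30
print(chunk_output_depths(tokens, w))                # [7, 7, 7, 7, 7, 7, 7, 7, 7, 7]
print(chunk_output_depths(tokens, w, detach=False))  # [7, 13, 19, 25, 31, 37, 43, 49, 55, 61]
```

With detach, every chunk's output sits on a chain of the same short length. Without detach, the chain grows with every token read, which is the "one layer per token" problem described above.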

Here is the new version of our RNN. It is now stateful: it remembers its activations between successive calls to forward, each of which processes a new batch of samples:

In [ ]:

    from fastai.text.all import *   # Module, nn, F, Learner, L, DataLoaders, ...

    class LMModel3(Module):
        def __init__(self, vocab_sz, n_hidden):
            self.i_h = nn.Embedding(vocab_sz, n_hidden)  # input to hidden
            self.h_h = nn.Linear(n_hidden, n_hidden)     # hidden to hidden
            self.h_o = nn.Linear(n_hidden, vocab_sz)     # hidden to output
            self.h = 0                                   # state kept between batches

        def forward(self, x):
            for i in range(3):
                self.h = self.h + self.i_h(x[:,i])
                self.h = F.relu(self.h_h(self.h))
            out = self.h_o(self.h)
            # keep the activations but drop their gradient history (truncated BPTT)
            self.h = self.h.detach()
            return out

        def reset(self): self.h = 0

This model will have the same activations whatever sequence length we pick, because the hidden state will remember the last activation from the previous batch. The only thing that will be different is the gradients computed at each step: they will only be calculated over the tokens of the current sequence (here, the last three), instead of over the whole stream. This approach is called backpropagation through time (BPTT).

jargon: Backpropagation through time (BPTT): Treating a neural net with effectively one layer per time step (usually refactored using a loop) as one big model, and calculating gradients on it in the usual way. To avoid running out of memory and time, we usually use truncated BPTT, which “detaches” the history of computation steps in the hidden state every few time steps.
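The claim that the activations do not depend on the sequence length can be checked with a one-unit toy RNN in plain Python. The weights are hypothetical, and carrying the float h from one chunk to the next plays the role of the stored hidden state. Detaching would not change these values, because it only affects gradients. For contrast, the flag reset_each_chunk mimics the old behavior of zeroing the state for every sample:

```python
W_IH, W_HH = 0.5, 0.8  # hypothetical scalar weights for a one-unit RNN


def step(h, x):
    # h = relu(W_HH * (h + W_IH * x)), the per-token update used by LMModel3
    return max(0.0, W_HH * (h + W_IH * x))


def run_stateful(stream, chunk_len, reset_each_chunk=False):
    """Feed the stream in consecutive chunks and return the final hidden state."""
    h = 0.0                         # like reset(): clean state at the start
    for start in range(0, len(stream), chunk_len):
        if reset_each_chunk:
            h = 0.0                 # old behavior: forget everything for each sample
        for x in stream[start:start + chunk_len]:
            h = step(h, x)
    return h


stream = [float(i % 7) for i in range(60)]
print({run_stateful(stream, k) for k in (1, 3, 4, 16)})  # a set with a single value
print(run_stateful(stream, 3, reset_each_chunk=True))    # smaller: earlier context was lost
```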

To use LMModel3, we need to make sure the samples are going to be seen in a certain order. As we saw earlier, if the first line of the first batch is our dset[0], then the first line of the second batch should be dset[1], so that the model sees the text flow continuously.

Previously, LMDataLoader did this for us. This time we’re going to do it ourselves.

To do this, we are going to rearrange our dataset. First we divide the samples into bs groups of m = len(dset) // bs samples each. This is equivalent to splitting the whole concatenated dataset into, for example, 64 equally sized pieces, since we’re using bs=64 here. m is the length of each of these pieces. For instance, if we’re using our whole dataset (although we’ll actually split it into train versus valid in a moment), that will be:

In [ ]:

    m = len(seqs)//bs
    m,bs,len(seqs)

Out[ ]:

    (328, 64, 21031)
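Here is the same arithmetic worked out, as if the whole dataset were grouped at once. In the code that follows, the training and validation sets are actually grouped separately. Because group_chunks only reads indices below m*bs, any remainder is not used:

```python
n_samples, bs = 21031, 64      # len(seqs) and bs from the output above
m = n_samples // bs            # length of each of the bs pieces
print(m)                       # 328
print(m * bs)                  # 20992 samples fit into the bs pieces
print(n_samples - m * bs)      # 39 samples left over
print(3 * m)                   # 984 contiguous tokens per batch row, per epoch
```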

The first batch will be composed of the samples:

    (0, m, 2*m, ..., (bs-1)*m)

the second batch of the samples:

    (1, m+1, 2*m+1, ..., (bs-1)*m+1)

and so forth. This way, at each epoch, the model will see a chunk of contiguous text of size 3*m on each line of the batch (since each sample contains three tokens).

The following function does that reindexing:

In [ ]:

    def group_chunks(ds, bs):
        m = len(ds) // bs          # length of each of the bs pieces
        new_ds = L()
        # batch i takes the i-th sample of every piece: ds[i], ds[i+m], ds[i+2*m], ...
        for i in range(m): new_ds += L(ds[i + m*j] for j in range(bs))
        return new_ds
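To see the reindexing on a small case, here is a plain-list version of group_chunks. It is a hypothetical copy that uses a Python list in place of fastai's L. It is applied to 12 sample indices with bs=3, and the result is then cut into consecutive batches of 3:

```python
def group_chunks_list(ds, bs):
    """Plain-list version of group_chunks (a list replaces fastai's L)."""
    m = len(ds) // bs
    new_ds = []
    for i in range(m):
        new_ds += [ds[i + m*j] for j in range(bs)]
    return new_ds


samples = list(range(12))                # sample indices 0..11, so m = 12 // 3 = 4
grouped = group_chunks_list(samples, 3)
batches = [grouped[k:k + 3] for k in range(0, len(grouped), 3)]
for b in batches:
    print(b)
# [0, 4, 8]
# [1, 5, 9]
# [2, 6, 10]
# [3, 7, 11]
```

Reading any one line of the batch across successive batches gives consecutive samples: 0, 1, 2, 3 for the first line, 4, 5, 6, 7 for the second, and so on.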

Then we just pass drop_last=True when building our DataLoaders to drop the last batch if it does not contain exactly bs samples. We also pass shuffle=False to make sure the texts are read in order:

In [ ]:

    cut = int(len(seqs) * 0.8)
    dls = DataLoaders.from_dsets(
        group_chunks(seqs[:cut], bs),
        group_chunks(seqs[cut:], bs),
        bs=bs, drop_last=True, shuffle=False)
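The following hypothetical batcher, written with plain lists rather than fastai's DataLoader, shows what the two flags do. It keeps the dataset order, as with shuffle=False, and optionally discards a short final batch. With drop_last=True, every batch has exactly bs rows, so each batch lines up with the rows of the stored hidden state:

```python
def make_batches(ds, bs, drop_last=False):
    # Consecutive batches in dataset order, as with shuffle=False.
    batches = [ds[k:k + bs] for k in range(0, len(ds), bs)]
    if drop_last and batches and len(batches[-1]) < bs:
        batches.pop()  # discard the final, incomplete batch
    return batches


print(make_batches(list(range(10)), 4))                  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
print(make_batches(list(range(10)), 4, drop_last=True))  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```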

The last thing we add is a little tweak of the training loop via a Callback. We will talk more about callbacks in a later chapter; this one will call the reset method of our model at the beginning of each epoch and before each validation phase. Since we implemented that method to zero the hidden state of the model, this will make sure we start with a clean state before reading those continuous chunks of text. We can also train a bit longer:

In [ ]:

    learn = Learner(dls, LMModel3(len(vocab), 64), loss_func=F.cross_entropy,
                    metrics=accuracy, cbs=ModelResetter)
    learn.fit_one_cycle(10, 3e-3)

    epoch  train_loss  valid_loss  accuracy  time
    0      1.677074    1.827367    0.467548  00:02
    1      1.282722    1.870913    0.388942  00:02
    2      1.090705    1.651793    0.462500  00:02
    3      1.005092    1.613794    0.516587  00:02
    4      0.965975    1.560775    0.551202  00:02
    5      0.916182    1.595857    0.560577  00:02
    6      0.897657    1.539733    0.574279  00:02
    7      0.836274    1.585141    0.583173  00:02
    8      0.805877    1.629808    0.586779  00:02
    9      0.795096    1.651267    0.588942  00:02

This is already better! The next step is to use more targets and compare them to the intermediate predictions.
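The reset logic used above can be sketched without fastai. The classes and the fit loop below are hypothetical simplifications, not fastai's ModelResetter or Learner. The toy model's "hidden state" just accumulates the sum of its inputs. That is enough to show that resetting before training and before validation makes every validation pass start from the same clean state:

```python
class ToyStatefulModel:
    """Stand-in for LMModel3 that only keeps a running 'hidden state'."""

    def __init__(self):
        self.h = 0

    def reset(self):
        self.h = 0

    def forward(self, batch):
        self.h = self.h + sum(batch)  # the state accumulates everything seen so far
        return self.h


class ResetCallback:
    """Hypothetical callback: reset the model before training and before validation."""

    def before_train(self, model):
        model.reset()

    def before_validate(self, model):
        model.reset()


class NoOpCallback:
    """Does nothing, to show what happens without resets."""

    def before_train(self, model):
        pass

    def before_validate(self, model):
        pass


def fit(model, train_batches, valid_batches, n_epoch, cb):
    # Toy loop: no optimizer, it only runs forward passes and calls the hooks.
    valid_outputs = []
    for _ in range(n_epoch):
        cb.before_train(model)
        for b in train_batches:
            model.forward(b)
        cb.before_validate(model)
        for b in valid_batches:
            valid_outputs.append(model.forward(b))
    return valid_outputs


train, valid = [[1], [2]], [[10], [20]]
print(fit(ToyStatefulModel(), train, valid, 2, ResetCallback()))  # [10, 30, 10, 30]
print(fit(ToyStatefulModel(), train, valid, 2, NoOpCallback()))   # [13, 33, 46, 66]
```

Without the resets, the validation outputs depend on whatever the model happened to read last. With them, both epochs validate from the same starting point.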

Creating More Signal

Another problem with our current approach is that we only predict one output word for each three input words. That means that the amount of signal that we are feeding back to update weights with is not as large as it could be. It would be better if we predicted the next word after every single word, rather than every three words, as shown in the following figure.

[Figure: RNN predicting after every token]

This is easy enough to add. We first need to change our data so that the dependent variable contains, for each input word, the word that follows it. Instead of 3, we use a variable, sl (for sequence length), and make it a bit bigger:

In [ ]:

    sl = 16
    seqs = L((tensor(nums[i:i+sl]), tensor(nums[i+1:i+sl+1]))
             for i in range(0,len(nums)-sl-1,sl))
    cut = int(len(seqs) * 0.8)
    dls = DataLoaders.from_dsets(group_chunks(seqs[:cut], bs),
                                 group_chunks(seqs[cut:], bs),
                                 bs=bs, drop_last=True, shuffle=False)

Looking at the first element of seqs, we can see that it contains two lists of the same size. The second list is the same as the first, but offset by one element:

In [ ]:

    [L(vocab[o] for o in s) for s in seqs[0]]

Out[ ]:

    [(#16) ['one','.','two','.','three','.','four','.','five','.'...],
     (#16) ['.','two','.','three','.','four','.','five','.','six'...]]
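The same windowing can be tried on a plain Python list. Here nums is a hypothetical stand-in for the numericalized token stream (the integers 0 to 39), and plain lists replace tensor and L:

```python
nums = list(range(40))  # hypothetical stand-in for the numericalized tokens
sl = 16

# Non-overlapping windows; each target is its input shifted one token to the right.
seqs = [(nums[i:i+sl], nums[i+1:i+sl+1])
        for i in range(0, len(nums) - sl - 1, sl)]

x, y = seqs[0]
print(len(seqs))        # 2
print(x[:4], y[:4])     # [0, 1, 2, 3] [1, 2, 3, 4]
print(len(x), len(y))   # 16 16
```

Each sample now provides sl = 16 targets instead of one, which is the extra signal described above.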

Now we need to modify our model so that it outputs a prediction after every word, rather than just at the end of a three-word sequence:

In [ ]:

    class LMModel4(Module):
        def __init__(self, vocab_sz, n_hidden):
            self.i_h = nn.Embedding(vocab_sz, n_hidden)
            self.h_h = nn.Linear(n_hidden, n_hidden)
            self.h_o = nn.Linear(n_hidden, vocab_sz)
            self.h = 0

        def forward(self, x):
            outs = []
            for i in range(sl):                     # sl is the sequence length defined above
                self.h = self.h + self.i_h(x[:,i])
                self.h = F.relu(self.h_h(self.h))
                outs.append(self.h_o(self.h))       # one prediction after every input token
            self.h = self.h.detach()                # truncated BPTT
            return torch.stack(outs, dim=1)         # shape bs x sl x vocab_sz

        def reset(self): self.h = 0
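Stacking on dim=1 places the time steps on the second axis. The following hypothetical helper mimics torch.stack(outs, dim=1) with nested lists. It takes sl items of shape bs x vocab_sz and returns a bs x sl x vocab_sz structure, using toy sizes and entries tagged with their (time step, batch row) so the placement is easy to read:

```python
def stack_dim1(outs):
    """Mimic torch.stack(outs, dim=1) for nested lists.

    outs: list of sl items, each a bs x vocab_sz list of lists.
    returns: bs x sl x vocab_sz nested list.
    """
    sl, bs = len(outs), len(outs[0])
    return [[outs[t][b] for t in range(sl)] for b in range(bs)]


# Toy sizes: sl=3 time steps, bs=2 rows, vocab_sz=4; each entry is (time step, batch row).
outs = [[[(t, b)] * 4 for b in range(2)] for t in range(3)]
stacked = stack_dim1(outs)
print(len(stacked), len(stacked[0]), len(stacked[0][0]))  # 2 3 4
```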

This model will return outputs of shape bs x sl x vocab_sz (since we stacked on dim=1). Our targets are of shape bs x sl, so we need to flatten both before passing them to F.cross_entropy:

In [ ]:

    def loss_func(inp, targ):
        # (bs, sl, vocab_sz) -> (bs*sl, vocab_sz) and (bs, sl) -> (bs*sl,)
        return F.cross_entropy(inp.view(-1, len(vocab)), targ.view(-1))
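The flattening can also be mimicked with nested lists. The sketch below is a plain-Python stand-in for F.cross_entropy with its default mean reduction, without class weights or ignore_index. It flattens bs x sl x vocab_sz predictions and bs x sl targets in row-major order, which is the order view(-1, ...) uses for a contiguous tensor:

```python
import math


def cross_entropy(logits, targets):
    """Mean of -log_softmax(row)[target] over all rows."""
    total = 0.0
    for row, t in zip(logits, targets):
        m = max(row)                                        # for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_sum_exp - row[t]
    return total / len(targets)


def flat_loss(inp, targ):
    """inp: bs x sl x vocab_sz, targ: bs x sl; flatten both, then average."""
    flat_inp = [vec for seq in inp for vec in seq]  # (bs*sl) x vocab_sz
    flat_targ = [t for seq in targ for t in seq]    # bs*sl
    return cross_entropy(flat_inp, flat_targ)


# bs=2, sl=3, vocab_sz=4, with all-zero logits (every word equally likely).
inp = [[[0.0] * 4 for _ in range(3)] for _ in range(2)]
targ = [[0, 1, 2], [3, 0, 1]]
print(flat_loss(inp, targ), math.log(4))  # both are log(4)
```

With uniform logits, every one of the bs*sl predictions assigns probability 1/4 to its target. The averaged loss is therefore log(vocab_sz).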

We can now use this loss function to train the model:

In [ ]:

    learn = Learner(dls, LMModel4(len(vocab), 64), loss_func=loss_func,
                    metrics=accuracy, cbs=ModelResetter)
    learn.fit_one_cycle(15, 3e-3)

    epoch  train_loss  valid_loss  accuracy  time
    0      3.103298    2.874341    0.212565  00:01
    1      2.231964    1.971280    0.462158  00:01
    2      1.711358    1.813547    0.461182  00:01
    3      1.448516    1.828176    0.483236  00:01
    4      1.288630    1.659564    0.520671  00:01
    5      1.161470    1.714023    0.554932  00:01
    6      1.055568    1.660916    0.575033  00:01
    7      0.960765    1.719624    0.591064  00:01
    8      0.870153    1.839560    0.614665  00:01
    9      0.808545    1.770278    0.624349  00:01
    10     0.758084    1.842931    0.610758  00:01
    11     0.719320    1.799527    0.646566  00:01
    12     0.683439    1.917928    0.649821  00:01
    13     0.660283    1.874712    0.628581  00:01
    14     0.646154    1.877519    0.640055  00:01

We need to train for longer, since the task has changed a bit and is more complicated now. But we end up with a good result… At least, sometimes. If you run it a few times, you’ll see that you can get quite different results on different runs. That’s because effectively we have a very deep network here, which can result in very large or very small gradients. We’ll see in the next part of this chapter how to deal with this.
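A deliberately stripped-down recurrence gives a rough idea of why depth makes gradients unstable. Take the scalar recurrence h_t = w * h_{t-1}, with no input and no activation; this is a simplifying assumption, not our model. The chain rule then gives d h_n / d h_0 = w^n, one factor of w per time step. In the real model the factors are matrices rather than a single number, so this only illustrates the trend:

```python
def grad_through_time(w, n_steps):
    # For h_t = w * h_{t-1}, the derivative d h_n / d h_0 is w multiplied n_steps times.
    g = 1.0
    for _ in range(n_steps):
        g *= w  # chain rule: one factor of w per time step
    return g


for w in (0.9, 1.0, 1.1):
    print(w, grad_through_time(w, 100))  # tiny for 0.9, exactly 1.0 for 1.0, huge for 1.1
```

After only 100 steps, a factor slightly below 1 makes the gradient vanish, and a factor slightly above 1 makes it explode.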

Now, the obvious way to get a better model is to go deeper: we only have one linear layer between the hidden state and the output activations in our basic RNN, so maybe we’ll get better results with more.