Computing Metrics Using Broadcasting
Recall that a metric is a number that is calculated based on the predictions of our model and the correct labels in our dataset, in order to tell us how good our model is. For instance, we could use either of the functions we saw in the previous section, mean squared error or mean absolute error, and take the average of them over the whole dataset. However, neither of these is a number that is very understandable to most people; in practice, we normally use accuracy as the metric for classification models.
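Here is a toy comparison on four made-up predictions that makes the contrast concrete. The probabilities and labels are invented for illustration. Treating a probability above 0.5 as "this is a 3" is our own assumption about how a prediction becomes a class.

```python
import torch

# Hypothetical toy data: predicted probability that each image is a 3,
# and the true labels (1 = "is a 3", 0 = "is a 7").
preds = torch.tensor([0.9, 0.4, 0.8, 0.3])
labels = torch.tensor([1., 1., 0., 0.])

# Loss-style metrics: easy to compute, hard for people to interpret.
mae = (preds - labels).abs().mean()
mse = ((preds - labels) ** 2).mean()

# Accuracy: threshold at 0.5 (an assumption), then take the fraction correct.
accuracy = ((preds > 0.5).float() == labels).float().mean()
print(mae, mse, accuracy)
```

Two of the four calls are correct, so the accuracy is 0.5. Anyone can read that figure, while a mean absolute error of 0.45 is much harder to interpret.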
As we’ve discussed, we want to calculate our metric over a validation set. This is so that we don’t inadvertently overfit—that is, train a model to work well only on our training data. This is not really a risk with the pixel similarity model we’re using here as a first try, since it has no trained components, but we’ll use a validation set anyway to follow normal practices and to be ready for our second try later.
To get a validation set we need to remove some of the data from training entirely, so it is not seen by the model at all. As it turns out, the creators of the MNIST dataset have already done this for us. Do you remember how there was a whole separate directory called valid? That’s what this directory is for!
So to start with, let’s create tensors for our 3s and 7s from that directory. These are the tensors we will use to calculate a metric measuring the quality of our first-try model, which measures distance from an ideal image:
In [ ]:
valid_3_tens = torch.stack([tensor(Image.open(o))
                            for o in (path/'valid'/'3').ls()])
valid_3_tens = valid_3_tens.float()/255
valid_7_tens = torch.stack([tensor(Image.open(o))
                            for o in (path/'valid'/'7').ls()])
valid_7_tens = valid_7_tens.float()/255
valid_3_tens.shape,valid_7_tens.shape
Out[ ]:
(torch.Size([1010, 28, 28]), torch.Size([1028, 28, 28]))
It’s good to get in the habit of checking shapes as you go. Here we see two tensors, one representing the 3s validation set of 1,010 images of size 28×28, and one representing the 7s validation set of 1,028 images of size 28×28.
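The sketch below isolates the stacking and rescaling step so you can check the shapes without the MNIST files. Random 28×28 `uint8` tensors are hypothetical stand-ins for the images loaded with `Image.open`. `torch.stack` adds a new leading axis, and `.float()/255` maps pixel values into the range 0 to 1.

```python
import torch

# Hypothetical stand-ins for images loaded from disk: random 28x28 uint8
# "images" with pixel values 0..255, like the MNIST files.
images = [torch.randint(0, 256, (28, 28), dtype=torch.uint8) for _ in range(5)]

# torch.stack adds a new leading axis: five [28,28] tensors -> [5,28,28].
stacked = torch.stack(images)

# Convert from uint8 to float and rescale pixels into the range [0, 1].
stacked = stacked.float() / 255

print(stacked.shape, stacked.dtype)  # torch.Size([5, 28, 28]) torch.float32
```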
We ultimately want to write a function, is_3, that will decide if an arbitrary image is a 3 or a 7. It will do this by deciding which of our two “ideal digits” this arbitrary image is closer to. For that we need to define a notion of distance: a function that calculates the distance between two images.
We can write a simple function that calculates the mean absolute error using an expression very similar to the one we wrote in the last section:
In [ ]:
def mnist_distance(a,b): return (a-b).abs().mean((-1,-2))
mnist_distance(a_3, mean3)
Out[ ]:
tensor(0.1114)
This is the same value we previously calculated for the distance between these two images, the ideal 3 mean3 and the arbitrary sample 3 a_3, which are both single-image tensors with a shape of [28,28].
But in order to calculate a metric for overall accuracy, we will need to calculate the distance to the ideal 3 for every image in the validation set. How do we do that calculation? We could write a loop over all of the single-image tensors that are stacked within our validation set tensor, valid_3_tens, which has a shape of [1010,28,28] representing 1,010 images. But there is a better way.
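For comparison, the loop approach might look like the sketch below. Random tensors stand in for valid_3_tens and mean3. They use the assumed shapes [10,28,28] and [28,28], smaller than the real set so the example runs quickly. The loop works, but every image costs a separate Python-level call.

```python
import torch

# Hypothetical stand-ins: 10 random "validation images" and one "ideal 3".
valid = torch.rand(10, 28, 28)
ideal = torch.rand(28, 28)

def mnist_distance(a, b): return (a - b).abs().mean((-1, -2))

# The loop approach: iterating over a [10,28,28] tensor yields [28,28] slices,
# each compared to the ideal image one at a time, in Python.
loop_dists = torch.stack([mnist_distance(img, ideal) for img in valid])
print(loop_dists.shape)  # torch.Size([10])
```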
Something very interesting happens when we take this exact same distance function, designed for comparing two single images, but pass in as an argument valid_3_tens, the tensor that represents the 3s validation set:
In [ ]:
valid_3_dist = mnist_distance(valid_3_tens, mean3)
valid_3_dist, valid_3_dist.shape
Out[ ]:
(tensor([0.1050, 0.1526, 0.1186, ..., 0.1122, 0.1170, 0.1086]),
torch.Size([1010]))
Instead of complaining about shapes not matching, it returned the distance for every single image as a vector (i.e., a rank-1 tensor) of length 1,010 (the number of 3s in our validation set). How did that happen?
Take another look at our function mnist_distance, and you’ll see we have there the subtraction (a-b). The magic trick is that PyTorch, when it tries to perform a simple subtraction operation between two tensors of different ranks, will use broadcasting. That is, it will automatically expand the tensor with the smaller rank to have the same size as the one with the larger rank. This works whenever the shapes are compatible: comparing dimensions from the last one backward, each pair must be equal or one of them must be 1, and any missing leading dimensions count as 1. Broadcasting is an important capability that makes tensor code much easier to write.
After broadcasting so the two argument tensors have the same rank, PyTorch applies its usual logic for two tensors of the same rank: it performs the operation on each corresponding element of the two tensors, and returns the tensor result. For instance:
In [ ]:
tensor([1,2,3]) + tensor(1)
Out[ ]:
tensor([2, 3, 4])
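The shape-compatibility rule can be written out in a few lines of plain Python. `broadcast_shape` is a hypothetical helper written for illustration and is not part of PyTorch. Recent versions of PyTorch provide `torch.broadcast_shapes`, which you could use to cross-check it.

```python
from itertools import zip_longest

def broadcast_shape(a, b):
    """Return the broadcast result shape of shapes a and b, or raise ValueError."""
    result = []
    # Compare dimensions from the last axis backward; a missing axis counts as 1.
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x != y and x != 1 and y != 1:
            raise ValueError(f"incompatible dimensions {x} and {y}")
        # A size-1 dimension stretches to match the other one.
        result.append(y if x == 1 else x)
    return tuple(reversed(result))

print(broadcast_shape((3,), ()))                  # (3,)  like tensor([1,2,3]) + tensor(1)
print(broadcast_shape((1010, 28, 28), (28, 28)))  # (1010, 28, 28)
```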
So in this case, PyTorch treats mean3, a rank-2 tensor representing a single image, as if it were 1,010 copies of the same image, and then subtracts each of those copies from the corresponding 3 in our validation set. What shape would you expect this tensor to have? Try to figure it out yourself before you look at the answer below:
In [ ]:
(valid_3_tens-mean3).shape
Out[ ]:
torch.Size([1010, 28, 28])
We are calculating the difference between our “ideal 3” and each of the 1,010 3s in the validation set, for each of the 28×28 pixels, resulting in the shape [1010,28,28].
There are a couple of important points about how broadcasting is implemented, which make it valuable not just for expressivity but also for performance:
- PyTorch doesn’t actually copy mean3 1,010 times. It pretends it is a tensor of that shape, but doesn’t actually allocate any additional memory.
- It does the whole calculation in C (or, if you’re using a GPU, in CUDA, the equivalent of C on the GPU), tens of thousands of times faster than pure Python (up to millions of times faster on a GPU!).
This is true of all broadcasting and elementwise operations and functions done in PyTorch. It’s the most important technique for you to know to create efficient PyTorch code.
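The broadcasting inside a subtraction can't be inspected directly. However, `Tensor.expand` builds the same kind of virtual repetition explicitly, which lets us check the no-copy claim. In this sketch a random tensor is a hypothetical stand-in for mean3. The leading axis of the expanded view has a stride of 0, so every "copy" reads the same underlying memory.

```python
import torch

mean3 = torch.rand(28, 28)  # hypothetical stand-in for the ideal 3

# expand returns a view that *looks* like 1,010 copies of mean3...
expanded = mean3.expand(1010, 28, 28)
print(expanded.shape)     # torch.Size([1010, 28, 28])

# ...but the new leading axis has stride 0: stepping to the next "copy"
# moves 0 elements in memory, so nothing was duplicated.
print(expanded.stride())  # (0, 28, 1)
print(expanded.data_ptr() == mean3.data_ptr())  # True: same storage
```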
Next in mnist_distance we see abs. You might be able to guess now what this does when applied to a tensor. It applies the method to each individual element in the tensor, and returns a tensor of the results (that is, it applies the method “elementwise”). So in this case, we’ll get back 1,010 matrices of absolute values.
Finally, our function calls mean((-1,-2)). The tuple (-1,-2) lists the axes to take the mean over. In Python, -1 refers to the last element, and -2 refers to the second-to-last. So in this case, this tells PyTorch that we want to take the mean over the values indexed by the last two axes of the tensor. The last two axes are the horizontal and vertical dimensions of an image. After taking the mean over the last two axes, we are left with just the first tensor axis, which indexes over our images, which is why our final size was [1010]. In other words, for every image, we averaged the absolute pixel differences over all the pixels in that image.
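Tracing mnist_distance on a tiny, hand-checkable batch makes each stage visible. The sketch uses three hypothetical 2×2 "images" in place of 28×28 ones. We can check the last image by hand: an all-zeros image differs from the ideal in two of its four pixels, by 1 each, giving a distance of 0.5.

```python
import torch

# Hypothetical tiny batch: three 2x2 "images" and one 2x2 "ideal" image.
batch = torch.tensor([[[0., 1.], [1., 0.]],
                      [[1., 1.], [1., 1.]],
                      [[0., 0.], [0., 0.]]])  # shape [3, 2, 2]
ideal = torch.tensor([[0., 1.], [1., 0.]])   # shape [2, 2]

diff = batch - ideal           # broadcast subtraction -> shape [3, 2, 2]
absdiff = diff.abs()           # elementwise absolute value, same shape
dist = absdiff.mean((-1, -2))  # average over the last two axes -> shape [3]
print(dist)                    # distances 0.0, 0.5, 0.5
```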
We’ll be learning lots more about broadcasting throughout this book, especially in a later chapter, and will be practicing it regularly too.
We can use mnist_distance to figure out whether an image is a 3 or not by using the following logic: if the distance between the digit in question and the ideal 3 is less than the distance to the ideal 7, then it’s a 3. This function will automatically do broadcasting and be applied elementwise, just like all PyTorch functions and operators:
In [ ]:
def is_3(x): return mnist_distance(x,mean3) < mnist_distance(x,mean7)
Let’s test it on our example case:
In [ ]:
is_3(a_3), is_3(a_3).float()
Out[ ]:
(tensor(True), tensor(1.))
Note that when we convert the Boolean response to a float, we get 1.0 for True and 0.0 for False. Thanks to broadcasting, we can also test it on the full validation set of 3s:
In [ ]:
is_3(valid_3_tens)
Out[ ]:
tensor([True, True, True, ..., True, True, True])
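This sketch uses synthetic data to show why is_3 works on a whole batch. The 2×2 "ideal digits" are hypothetical stand-ins: an all-dark 3 and an all-bright 7. Each test image is uniformly gray, built by broadcasting a [4,1,1] tensor of brightness values against a [2,2] tensor of ones. Each distance call returns a vector of length 4, and `<` compares the two vectors elementwise.

```python
import torch

# Hypothetical 2x2 "ideal digits": an all-dark 3 and an all-bright 7.
mean3 = torch.zeros(2, 2)
mean7 = torch.ones(2, 2)

def mnist_distance(a, b): return (a - b).abs().mean((-1, -2))
def is_3(x): return mnist_distance(x, mean3) < mnist_distance(x, mean7)

# Four uniform test "images": [4,1,1] * [2,2] broadcasts to [4,2,2].
batch = torch.tensor([0.1, 0.4, 0.6, 0.9])[:, None, None] * torch.ones(2, 2)
print(is_3(batch))          # the two darker images are closer to mean3
print(is_3(batch).float())  # True -> 1.0, False -> 0.0
```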
Now we can calculate the accuracy for each of the 3s and 7s by taking the average of that function for all 3s and its inverse for all 7s:
In [ ]:
accuracy_3s = is_3(valid_3_tens).float().mean()
accuracy_7s = (1 - is_3(valid_7_tens).float()).mean()
accuracy_3s,accuracy_7s,(accuracy_3s+accuracy_7s)/2
Out[ ]:
(tensor(0.9168), tensor(0.9854), tensor(0.9511))
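The two-line accuracy calculation above can be wrapped in a reusable function. `accuracy_3_vs_7` is a hypothetical name chosen for this sketch, and the toy data are invented 2×2 "digits". The final value averages the two per-class accuracies, so each class counts equally regardless of how many images it has.

```python
import torch

def mnist_distance(a, b): return (a - b).abs().mean((-1, -2))

def accuracy_3_vs_7(valid_3, valid_7, mean3, mean7):
    """Per-class and averaged accuracy of the pixel-similarity classifier."""
    def is_3(x): return mnist_distance(x, mean3) < mnist_distance(x, mean7)
    acc_3 = is_3(valid_3).float().mean()        # fraction of 3s labeled 3
    acc_7 = (1 - is_3(valid_7).float()).mean()  # fraction of 7s labeled 7
    return acc_3, acc_7, (acc_3 + acc_7) / 2

# Hypothetical toy data: dark "3s" and bright "7s" as uniform 2x2 images.
mean3, mean7 = torch.zeros(2, 2), torch.ones(2, 2)
valid_3 = torch.tensor([0.1, 0.2, 0.7])[:, None, None] * torch.ones(2, 2)
valid_7 = torch.tensor([0.8, 0.9])[:, None, None] * torch.ones(2, 2)
print(accuracy_3_vs_7(valid_3, valid_7, mean3, mean7))
```

Here the 0.7 "3" is closer to the ideal 7, so two of three 3s and both 7s are classified correctly.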
This looks like a pretty good start! We’re getting over 90% accuracy on both 3s and 7s, and we’ve seen how to define a metric conveniently using broadcasting.
But let’s be honest: 3s and 7s are very different-looking digits. And we’re only classifying 2 out of the 10 possible digits so far. So we’re going to need to do better!
To do better, perhaps it is time to try a system that does some real learning—that is, that can automatically modify itself to improve its performance. In other words, it’s time to talk about the training process, and SGD.