Tutorial - Using fastai on a custom new task

How to use the mid-level API for data collection, model creation and training

In this tutorial, we will see how to deal with a new type of task using the middle layer of the fastai library. The example we will use is a Siamese network, which takes two images and determines whether they are of the same class or not. In particular we will see:

  • how to quickly get DataLoaders from standard PyTorch Datasets
  • how to adapt this into a Transform to get some of the show features of fastai
  • how to add some new behavior to show_batch/show_results for a custom task
  • how to write a custom DataBlock
  • how to create your own model from a pretrained model
  • how to pass along a custom splitter to Learner to take advantage of transfer learning

Preparing the data

To make our data ready for training a model, we need to create a DataLoaders object in fastai. It is just a wrapper around a training DataLoader and a validation DataLoader, so if you already have your own PyTorch dataloaders, you can create such an object directly.
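
For instance, here is a minimal sketch of that direct route (assuming hypothetical PyTorch datasets my_train_ds and my_valid_ds that you have built elsewhere):

    from torch.utils.data import DataLoader
    from fastai.data.core import DataLoaders

    # my_train_ds and my_valid_ds are hypothetical PyTorch Datasets you already have
    train_dl = DataLoader(my_train_ds, batch_size=64, shuffle=True)
    valid_dl = DataLoader(my_valid_ds, batch_size=64)

    # DataLoaders simply wraps the loaders; the first is used for training, the second for validation
    dls = DataLoaders(train_dl, valid_dl)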

Here we don’t have anything ready yet. Usually, when using PyTorch, the first step is to create a Dataset that is then wrapped inside a DataLoader. We will do this first, then see how to change this Dataset into a Transform that will let us take advantage of fastai’s functionality for showing a batch or using data augmentation on the GPU. Lastly we will see how we can customize the data block API and create our own new TransformBlock.

Purely in PyTorch

To begin with, we will only use PyTorch and PIL to create a Dataset and see how to get this inside fastai. The only helper functions from fastai we will use are untar_data (to download and untar the dataset) and get_image_files (that looks for all images in a folder recursively). Here, we will use the Oxford-IIIT Pet Dataset.

    from fastai.data.external import untar_data, URLs
    from fastai.data.transforms import get_image_files

untar_data returns a pathlib.Path object with the location of the decompressed dataset, and in this case, all the images are in an images subfolder:

    path = untar_data(URLs.PETS)
    files = get_image_files(path/"images")
    files[0]

    Path('/home/jhoward/.fastai/data/oxford-iiit-pet/images/great_pyrenees_173.jpg')

We can open the first image with PIL and have a look at it:

    import PIL

    img = PIL.Image.open(files[0])
    img

[Figure 2: the first image of the dataset, a Great Pyrenees]

Let’s wrap all the standard preprocessing (resize, conversion to tensor, dividing by 255 and reordering of the channels) in one helper function:

    import torch
    import numpy as np

    def open_image(fname, size=224):
        img = PIL.Image.open(fname).convert('RGB')
        img = img.resize((size, size))
        t = torch.Tensor(np.array(img))
        return t.permute(2,0,1).float()/255.0

    open_image(files[0]).shape

    torch.Size([3, 224, 224])

We can see that the label of our image is in the filename, before the last _ and some number. We can then use a regular expression to create a label function:

    import re

    def label_func(fname):
        return re.match(r'^(.*)_\d+.jpg$', fname.name).groups()[0]

    label_func(files[0])

    'great_pyrenees'

Now let's gather all the unique labels:

    labels = list(set(files.map(label_func)))
    len(labels)

    37

So we have 37 different breeds of pets. To create our Siamese datasets, we will need to create tuples of images as inputs, and the target will be True if the images are of the same class, False otherwise. It will be useful to have a mapping from class to list of filenames of that class, to quickly pick a random image for any class.

    lbl2files = {l: [f for f in files if label_func(f) == l] for l in labels}

Now we are ready to create our datasets. For the training set, we will go through all our training filenames for the first image, then pick randomly:

  • a filename of the same class for the second image (with probability 0.5)
  • a filename of a different class for the second image (with probability 0.5)

We will go through that random draw each time we access an item, to have as many samples as possible. For the validation set, however, we will fix that random draw once and for all (otherwise we would validate on a different dataset at each epoch).

    import random

    class SiameseDataset(torch.utils.data.Dataset):
        def __init__(self, files, is_valid=False):
            self.files,self.is_valid = files,is_valid
            if is_valid: self.files2 = [self._draw(f) for f in files]
        def __getitem__(self, i):
            file1 = self.files[i]
            (file2,same) = self.files2[i] if self.is_valid else self._draw(file1)
            img1,img2 = open_image(file1),open_image(file2)
            return (img1, img2, torch.Tensor([same]).squeeze())
        def __len__(self): return len(self.files)
        def _draw(self, f):
            same = random.random() < 0.5
            cls = label_func(f)
            if not same: cls = random.choice([l for l in labels if l != cls])
            return random.choice(lbl2files[cls]),same

We just need to split our filenames between a training and validation set to use it.

    idxs = np.random.permutation(range(len(files)))
    cut = int(0.8 * len(files))
    train_files = files[idxs[:cut]]
    valid_files = files[idxs[cut:]]

We can then use it to create datasets.

    train_ds = SiameseDataset(train_files)
    valid_ds = SiameseDataset(valid_files, is_valid=True)
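
As a quick illustrative check (not part of the original tutorial), a training sample should be two image tensors plus a 0./1. target:

    x1, x2, y = train_ds[0]
    x1.shape, x2.shape, y
    # (torch.Size([3, 224, 224]), torch.Size([3, 224, 224]), tensor(1.) or tensor(0.))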

All of the above would be different for your custom problem; the main point is that as soon as you have some Datasets, you can create a fastai DataLoaders with the following factory method:

    from fastai.data.core import DataLoaders

    dls = DataLoaders.from_dsets(train_ds, valid_ds)

You can then use this DataLoaders object in a Learner and start training. Most methods that don't rely on showing something should work; methods like DataLoaders.show_batch and Learner.show_results, on the other hand, will need the extra work we do later in this tutorial. For instance, you can get and inspect a batch with:

    b = dls.one_batch()
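
The batch is simply our dataset's three elements, collated together; as an illustrative check (with the default batch size of 64), the shapes look like this:

    x1, x2, y = b
    x1.shape, x2.shape, y.shape
    # (torch.Size([64, 3, 224, 224]), torch.Size([64, 3, 224, 224]), torch.Size([64]))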

If you want to use the GPU, you can just write:

    dls = dls.cuda()

Now, what is a bit annoying is that we would have to rewrite everything that is already in fastai if we wanted to normalize our images or apply data augmentation. With minimal changes to the code we wrote, we can still access all of that and get all the show methods to work as a cherry on top. Let's see how.

Using the mid-level API

When you have a custom dataset like before, you can easily convert it into a fastai Transform by just changing the __getitem__ function to encodes. In general, a Transform in fastai calls the encodes method when you apply it on an item (a bit like PyTorch modules call forward when applied on something), so this will turn your Python dataset into a function that transforms integers into your data.
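
As a quick standalone illustration (not part of the Siamese example), here is a tiny Transform with an encodes and a decodes method:

    from fastcore.transform import Transform

    class Negate(Transform):
        def encodes(self, x): return -x   # called when the transform is applied
        def decodes(self, x): return -x   # called to reverse it, e.g. for display

    neg = Negate()
    neg(3)          # -3: calling the transform dispatches to encodes
    neg.decode(-3)  # 3: decode dispatches to decodes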

If you then return a tuple (or a subclass of a tuple), and use fastai's semantic types, you can then apply any other fastai transform on your data and it will be dispatched properly. Let's see how that works:

    from fastai.vision.all import *

    class SiameseTransform(Transform):
        def __init__(self, files, is_valid=False):
            self.files,self.is_valid = files,is_valid
            if is_valid: self.files2 = [self._draw(f) for f in files]
        def encodes(self, i):
            file1 = self.files[i]
            (file2,same) = self.files2[i] if self.is_valid else self._draw(file1)
            img1,img2 = open_image(file1),open_image(file2)
            return (TensorImage(img1), TensorImage(img2), torch.Tensor([same]).squeeze())
        def _draw(self, f):
            same = random.random() < 0.5
            cls = label_func(f)
            if not same: cls = random.choice([l for l in labels if l != cls])
            return random.choice(lbl2files[cls]),same

So three things changed:

  • the __len__ disappeared, we won’t need it
  • __getitem__ became encodes
  • we return TensorImage for our images

How do we build a dataset with this? We will use TfmdLists. It’s just an object that lazily applies a collection of Transforms on a list. Here since our transform takes integers, we will pass simple ranges for this list.

    train_tl = TfmdLists(range(len(train_files)), SiameseTransform(train_files))
    valid_tl = TfmdLists(range(len(valid_files)), SiameseTransform(valid_files, is_valid=True))
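
A TfmdLists can be indexed like a dataset, with the transform applied lazily on access. As an illustrative check (not in the original tutorial), an item is now a tuple of two TensorImages and a target:

    x1, x2, y = train_tl[0]
    type(x1), type(x2), y
    # (fastai.torch_core.TensorImage, fastai.torch_core.TensorImage, tensor(1.) or tensor(0.))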

Then, when we create a DataLoader, we can add any transform we like. fastai replaces the PyTorch DataLoader with its own version that has more hooks (but is fully compatible with PyTorch). The transforms we would like to have applied to the items should be passed to after_item, the ones we would like to have applied to a batch of data should be passed to after_batch.

    dls = DataLoaders.from_dsets(train_tl, valid_tl,
                                 after_batch=[Normalize.from_stats(*imagenet_stats), *aug_transforms()])
    dls = dls.cuda()

So with very little change, we can use fastai's normalization and data augmentation. If we are ready to do a bit of additional coding, we can even get the show behavior to work properly.

Making show work

The show methods in fastai all rely on some types being able to show themselves. Additionally, some transforms that need to be reversed for showing purposes (like changing a category to an index, or normalizing) have a decodes method to undo what their encodes did. In general, fastai will call those decodes methods until it arrives at a type that knows how to show itself, then call the show method on that type.
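
To make this concrete, here is an illustrative example (not part of the Siamese code) using the built-in Categorize transform, whose decodes undoes what its encodes did:

    cat = Categorize(vocab=['not similar', 'similar'])
    idx = cat('similar')   # TensorCategory(1): encodes turns the label into an index
    cat.decode(idx)        # 'similar': decodes turns it back into something showable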

So to make this work, let’s first create a new type with a show method!

    class SiameseImage(fastuple):
        def show(self, ctx=None, **kwargs):
            if len(self) > 2:
                img1,img2,similarity = self
            else:
                img1,img2 = self
                similarity = 'Undetermined'
            if not isinstance(img1, Tensor):
                if img2.size != img1.size: img2 = img2.resize(img1.size)
                t1,t2 = tensor(img1),tensor(img2)
                t1,t2 = t1.permute(2,0,1),t2.permute(2,0,1)
            else: t1,t2 = img1,img2
            line = t1.new_zeros(t1.shape[0], t1.shape[1], 10)
            return show_image(torch.cat([t1,line,t2], dim=2), title=similarity, ctx=ctx, **kwargs)

There is a bit of code in the first part of the show method that you can ignore; it is mostly there to make show work on PIL Images as well as tensors. The main work happens in the last two lines: we create a black line of 10 pixels and show the tensor made of our two images concatenated, with that black line in the middle. In general, ctx represents the object where we will show our thing. In this case, it could be a given matplotlib axis.

Let’s see an example:

    img = PILImage.create(files[0])
    img1 = PILImage.create(files[1])
    s = SiameseImage(img, img1, False)
    s.show();

[Figure 3: the two pet images shown side by side, separated by a black line, with the title False]
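
Since ctx can be any matplotlib axis, we can also draw several SiameseImages on one figure; here is a small illustrative sketch (not part of the original tutorial):

    import matplotlib.pyplot as plt

    # draw two pairs side by side, each on its own axis
    fig, axs = plt.subplots(1, 2, figsize=(12, 4))
    SiameseImage(img, img1, False).show(ctx=axs[0])
    SiameseImage(img, img, True).show(ctx=axs[1]);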

Note that we used the fastai type PILImage instead of PIL.Image. That is to get access to fastai's transforms. For instance, we can use Resize and ToTensor directly on our SiameseImage. Since it subclasses tuple, those transforms are dispatched and applied to the parts that make sense (the PILImages, not the bool).

    tst = Resize(224)(s)
    tst = ToTensor()(tst)
    tst.show();

[Figure 4: the resized pair shown after Resize and ToTensor]

Now let's rewrite our previous transform a bit. Instead of taking integers, we can take files directly, for instance. Also, in fastai, splits are usually handled by helper functions that return two lists of integers (the indices in the training set and the ones in the validation set), so let's adapt the code from before so that the validation images are drawn once and for all. We also need the mapping dictionaries from class to list of filenames of that class to be built separately for the train and valid splits, so that there is total separation between the training and validation sets: 'train' files should only draw their second image from the train split, and 'valid' files only from the valid split.

    class SiameseTransform(Transform):
        def __init__(self, files, splits):
            self.splbl2files = [{l: [f for f in files[splits[i]] if label_func(f) == l] for l in labels}
                                for i in range(2)]
            self.valid = {f: self._draw(f,1) for f in files[splits[1]]}
        def encodes(self, f):
            f2,same = self.valid.get(f, self._draw(f,0))
            img1,img2 = PILImage.create(f),PILImage.create(f2)
            return SiameseImage(img1, img2, same)
        def _draw(self, f, split=0):
            same = random.random() < 0.5
            cls = label_func(f)
            if not same: cls = random.choice(L(l for l in labels if l != cls))
            return random.choice(self.splbl2files[split][cls]),same

Then we create our splits using a RandomSplitter:

    splits = RandomSplitter()(files)
    tfm = SiameseTransform(files, splits)

And we test that tfm.valid does not contain items from the train split:

    valids = [v[0] for k,v in tfm.valid.items()]
    assert not [v for v in valids if v in files[splits[0]]]

And we can pass those splits to TfmdLists, which will then create the validation and the training set for us.

    tls = TfmdLists(files, tfm, splits=splits)

We can now use methods like show_at:

    show_at(tls.valid, 0)

    <AxesSubplot:title={'center':'True'}>

[Figure 5: the first validation pair, shown with the title True]

And we can create a DataLoaders like before, by adding our custom transforms for after_item and after_batch.

    dls = tls.dataloaders(after_item=[Resize(224), ToTensor],
                          after_batch=[IntToFloatTensor, Normalize.from_stats(*imagenet_stats)])

If we try right now, show_batch will not quite work: the default behavior relies on data built using data blocks, whereas we used one big transform for everything. As a consequence, instead of having an input with a certain type and an output with a certain type, we have one big type for the whole data. If we look at a batch, we can see that the fastai library has propagated that type for us through every transform and batching operation:

    b = dls.one_batch()
    type(b)

    __main__.SiameseImage

When we call show_batch, the fastai library will notice that the batch as a whole has a show method, meaning it knows how to show itself, and it will send that batch directly to the type-dispatched function show_batch. The signature of this function is the following:

    show_batch(x, y, samples, ctxs=None, **kwargs)

where the kwargs are specific to the application (here we will have nrows, ncols and figsize for instance). In our case, the batch will be sent as a whole to x, while y and samples will be None (those arguments are only used when the batch does not have a type that knows how to show itself; see the next section).

To write our custom show_batch we just need to use the type annotation on x like this:

    @typedispatch
    def show_batch(x:SiameseImage, y, samples, ctxs=None, max_n=6, nrows=None, ncols=2, figsize=None, **kwargs):
        if figsize is None: figsize = (ncols*6, max_n//ncols * 3)
        if ctxs is None: ctxs = get_grid(min(x[0].shape[0], max_n), nrows=None, ncols=ncols, figsize=figsize)
        for i,ctx in enumerate(ctxs): SiameseImage(x[0][i], x[1][i], ['Not similar','Similar'][x[2][i].item()]).show(ctx=ctx)

We will see in the next section that the behavior is different when we have a batch that does not have a show method (which is the case most of the time; usually only the input and target of the batch have those show methods). In that case, the arguments y and samples are useful. Here, everything is in x: since the batch knows how to show itself as a whole, it is sent as a whole.

Here, we create a list of matplotlib axes with the utility function get_grid, then pass each of them along to SiameseImage.show. Let's see how this looks in practice:

    b = dls.one_batch()
    dls._types

    {__main__.SiameseImage: [fastai.torch_core.TensorImage,
      fastai.torch_core.TensorImage,
      torch.Tensor]}

    dls.show_batch()

[Figure 6: show_batch output, a grid of image pairs titled Similar or Not similar]

And we will see in the training section that it is just as easy to write a custom show_results. Now let's see how we could have written our own data block.

Writing your custom data block

The Siamese problem is just a particular case of a problem where the input is a tuple of images and the target is a category. If the type "tuple of images" comes up again in other problems with a different target, it might be useful to create a custom block for it, to be able to leverage the power of the data block API.

NB: if your problem only has one particular setup and you don’t need the modular aspect for various targets, what we did before is perfectly fine and you should look no further.

Let’s create a type to represent our tuple of two images:

    class ImageTuple(fastuple):
        @classmethod
        def create(cls, fns): return cls(tuple(PILImage.create(f) for f in fns))
        def show(self, ctx=None, **kwargs):
            t1,t2 = self
            if not isinstance(t1, Tensor) or not isinstance(t2, Tensor) or t1.shape != t2.shape: return ctx
            line = t1.new_zeros(t1.shape[0], t1.shape[1], 10)
            return show_image(torch.cat([t1,line,t2], dim=2), ctx=ctx, **kwargs)

Since it’s a subclass of fastuple, Transforms will be applied over each part of the tuple. For instance ToTensor will convert this ImageTuple to a tuple of TensorImages:

    img = ImageTuple.create((files[0], files[1]))
    tst = ToTensor()(img)
    type(tst[0]),type(tst[1])

    (fastai.torch_core.TensorImage, fastai.torch_core.TensorImage)

In the show method, we did not bother with non-tensor elements this time (we could have copied and pasted the same code as before). Showing assumes we have a resize transform and that we convert the images to tensors in our processing pipeline:

    img1 = Resize(224)(img)
    tst = ToTensor()(img1)
    tst.show();

[Figure 7: the resized ImageTuple shown side by side]

We can now define a block associated to ImageTuple that we will use in the data block API. A block is basically a set of default transforms; here we specify how to create the ImageTuple and the IntToFloatTensor transform necessary for image preprocessing:

    def ImageTupleBlock(): return TransformBlock(type_tfms=ImageTuple.create, batch_tfms=IntToFloatTensor)

To gather our data with the data block API we will use the following functions:

    splits_files = [files[splits[i]] for i in range(2)]
    splits_sets = mapped(set, splits_files)

    def get_split(f):
        for i,s in enumerate(splits_sets):
            if f in s: return i
        raise ValueError(f'File {f} is not present in any split.')

    splbl2files = [{l: [f for f in s if label_func(f) == l] for l in labels} for s in splits_sets]

    def splitter(items):
        def get_split_files(i): return [j for j,(f1,f2,same) in enumerate(items) if get_split(f1)==i]
        return get_split_files(0),get_split_files(1)

    def draw_other(f):
        same = random.random() < 0.5
        cls = label_func(f)
        split = get_split(f)
        if not same: cls = random.choice(L(l for l in labels if l != cls))
        return random.choice(splbl2files[split][cls]),same

    def get_tuples(files): return [[f, *draw_other(f)] for f in files]

And we are ready to define our block:

    def get_x(t): return t[:2]
    def get_y(t): return t[2]

    siamese = DataBlock(
        blocks=(ImageTupleBlock, CategoryBlock),
        get_items=get_tuples,
        get_x=get_x, get_y=get_y,
        splitter=splitter,
        item_tfms=Resize(224),
        batch_tfms=[Normalize.from_stats(*imagenet_stats)]
    )

    dls = siamese.dataloaders(files)

We can check the types of the elements in one batch with the explode_types function. Here we have a tuple with one ImageTuple of two TensorImages and one TensorCategory. The transforms properly kept the types of everything, even after collating the samples together!

    b = dls.one_batch()
    explode_types(b)

    {tuple: [{__main__.ImageTuple: [fastai.torch_core.TensorImage,
        fastai.torch_core.TensorImage]},
      fastai.torch_core.TensorCategory]}

The show_batch method works out of the box here, but if we want to customize how things are organized, we can define a dispatched show_batch function. Here the whole batch is just a tuple, so it doesn't have a show method. The fastai library will dispatch on the first part of the tuple (x) and the second part of the tuple (y), with the actual samples being in the samples variable.

Here we only dispatch on the x (which means this method will be used for any x that is an ImageTuple, whatever the y), but we could have custom behaviors depending on the targets.

    @typedispatch
    def show_batch(x:ImageTuple, y, samples, ctxs=None, max_n=6, nrows=None, ncols=2, figsize=None, **kwargs):
        if figsize is None: figsize = (ncols*6, max_n//ncols * 3)
        if ctxs is None: ctxs = get_grid(min(len(samples), max_n), nrows=nrows, ncols=ncols, figsize=figsize)
        ctxs = show_batch[object](x, y, samples, ctxs=ctxs, max_n=max_n, **kwargs)
        return ctxs

As a side note, x and y are not actually used here (everything that needs to be shown is in the samples list). They are only passed along for type-dispatching because they carry the types of our inputs and targets.

We can now have a look:

    dls.show_batch()

[Figure 8: show_batch output for the data block version]

Training a model

The model

We are now at the stage where we can train a model on this data. We will use a very simple approach: take the body of a pretrained model and make the two images pass through it. Then build a head the usual way, with just twice as many features. The model in itself can be written like this:

    class SiameseModel(Module):
        def __init__(self, encoder, head):
            self.encoder,self.head = encoder,head
        def forward(self, x1, x2):
            ftrs = torch.cat([self.encoder(x1), self.encoder(x2)], dim=1)
            return self.head(ftrs)

For our encoder, we use the fastai function create_body. It takes an architecture and the index at which to cut it. By default it will use the pretrained version of the model we picked. If we want to check where fastai usually cuts the model, we can have a look at the model_meta dictionary:

    model_meta[resnet34]

    {'cut': -2,
     'split': <function fastai.vision.learner._resnet_split(m)>,
     'stats': ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])}

So we need to cut at -2:

    encoder = create_body(resnet34, cut=-2)

Let’s have a look at the last block of this encoder:

    encoder[-1]

    Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (downsample): Sequential(
          (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
          (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): BasicBlock(
        (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (2): BasicBlock(
        (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )

It ends up with 512 features, so our custom head will actually receive 4 times that (i.e. 2*2): 2 because we have two images concatenated, and another 2 because of the fastai concat-pool trick (we concatenate the average pool and the max pool of the features). The create_head function will give us the head that is usually used in fastai's transfer learning models; since it applies the concat-pool doubling itself, we only pass 512*2 below and still end up with 2048 input features in the generated head.

We also need to define the number of outputs of our head, n_out; in our case it is 2: one to predict that both images are from the same class, and the other to predict the contrary.

    head = create_head(512*2, 2, ps=0.5)
    model = SiameseModel(encoder, head)

Let’s have a look at the generated head:

    head

    Sequential(
      (0): AdaptiveConcatPool2d(
        (ap): AdaptiveAvgPool2d(output_size=1)
        (mp): AdaptiveMaxPool2d(output_size=1)
      )
      (1): Flatten(full=False)
      (2): BatchNorm1d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): Dropout(p=0.25, inplace=False)
      (4): Linear(in_features=2048, out_features=512, bias=False)
      (5): ReLU(inplace=True)
      (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (7): Dropout(p=0.5, inplace=False)
      (8): Linear(in_features=512, out_features=2, bias=False)
    )
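
As a quick sanity check (not in the original tutorial), we can pass two small batches of random images through the model and confirm we get one score per class for each sample:

    model.eval()                 # avoid updating BatchNorm statistics during this check
    with torch.no_grad():
        x1 = torch.randn(2, 3, 224, 224)
        x2 = torch.randn(2, 3, 224, 224)
        out = model(x1, x2)
    out.shape                    # torch.Size([2, 2])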

Train the model

We are almost ready to train our model. The last piece missing is a custom splitter: in order to use transfer learning efficiently, we will want to freeze the pretrained model at first, and train only the head. A splitter is a function that takes a model and returns lists of parameters. The params function is useful to return all parameters of the model, so we can create a simple splitter like so:

    def siamese_splitter(model):
        return [params(model.encoder), params(model.head)]

Then we use the traditional CrossEntropyLossFlat loss function from fastai (the same as nn.CrossEntropyLoss, but flattened). The only catch is that, when using the data built by the mid-level API, we have a tensor of booleans for our targets, so we need to convert it to integers or PyTorch will throw an error.

    def loss_func(out, targ):
        return CrossEntropyLossFlat()(out, targ.long())
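
A quick illustrative check (not in the original tutorial) of why the cast matters: the underlying nn.CrossEntropyLoss expects integer class indices, so .long() turns True/False targets into 1/0:

    out  = torch.randn(4, 2)                        # 4 fake predictions over 2 classes
    targ = torch.tensor([True, False, True, True])  # boolean targets, like our `same` flag
    loss_func(out, targ)                            # works thanks to the targ.long() cast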

Let’s grab the data as built by the mid-level API:

    class SiameseTransform(Transform):
        def __init__(self, files, splits):
            self.splbl2files = [{l: [f for f in files[splits[i]] if label_func(f) == l] for l in labels}
                                for i in range(2)]
            self.valid = {f: self._draw(f,1) for f in files[splits[1]]}
        def encodes(self, f):
            f2,same = self.valid.get(f, self._draw(f,0))
            img1,img2 = PILImage.create(f),PILImage.create(f2)
            return SiameseImage(img1, img2, int(same))
        def _draw(self, f, split=0):
            same = random.random() < 0.5
            cls = label_func(f)
            if not same: cls = random.choice(L(l for l in labels if l != cls))
            return random.choice(self.splbl2files[split][cls]),same

    splits = RandomSplitter()(files)
    tfm = SiameseTransform(files, splits)
    tls = TfmdLists(files, tfm, splits=splits)
    dls = tls.dataloaders(after_item=[Resize(224), ToTensor],
                          after_batch=[IntToFloatTensor, Normalize.from_stats(*imagenet_stats)])

Again, test that tfm.valid does not contain items from the train split:

    valids = [v[0] for k,v in tfm.valid.items()]
    assert not [v for v in valids if v in files[splits[0]]]

We can then create our Learner:

    learn = Learner(dls, model, loss_func=CrossEntropyLossFlat(), splitter=siamese_splitter, metrics=accuracy)

Since we are not using a convenience function that directly creates the Learner for us, we need to freeze it manually:

    learn.freeze()

Then we can use the learning rate finder:

    learn.lr_find()

    SuggestedLRs(lr_min=0.0019054606556892395, lr_steep=1.737800812406931e-05)

[Figure 9: learning rate finder plot]

Train the head for a bit:

    learn.fit_one_cycle(4, 3e-3)

    epoch  train_loss  valid_loss  accuracy  time
    0      0.543907    0.378830    0.836942  00:30
    1      0.389485    0.263416    0.889716  00:35
    2      0.289101    0.199503    0.920162  00:27
    3      0.244186    0.176951    0.932341  00:40

Unfreeze and train the whole model for a little longer:

    learn.unfreeze()
    learn.fit_one_cycle(4, slice(1e-6,1e-4))

    epoch  train_loss  valid_loss  accuracy  time
    0      0.235934    0.175252    0.933694  00:53
    1      0.218259    0.164884    0.933018  00:36
    2      0.228709    0.164789    0.933694  00:58
    3      0.203605    0.160317    0.935724  00:58

Making show_results work

    @typedispatch
    def show_results(x:SiameseImage, y, samples, outs, ctxs=None, max_n=6, nrows=None, ncols=2, figsize=None, **kwargs):
        if figsize is None: figsize = (ncols*6, max_n//ncols * 3)
        if ctxs is None: ctxs = get_grid(min(x[0].shape[0], max_n), nrows=None, ncols=ncols, figsize=figsize)
        for i,ctx in enumerate(ctxs):
            title = f'Actual: {["Not similar","Similar"][x[2][i].item()]} \n Prediction: {["Not similar","Similar"][y[2][i].item()]}'
            SiameseImage(x[0][i], x[1][i], title).show(ctx=ctx)

    learn.show_results()

[Figure 10: show_results output with Actual/Prediction titles]

Patch a siampredict method into Learner, to automatically show the images and the prediction:

    @patch
    def siampredict(self:Learner, item, rm_type_tfms=None, with_input=False):
        res = self.predict(item, rm_type_tfms=None, with_input=False)
        if res[0] == tensor(0):
            SiameseImage(item[0], item[1], 'Prediction: Not similar').show()
        else:
            SiameseImage(item[0], item[1], 'Prediction: Similar').show()
        return res

    imgtest = PILImage.create(files[0])
    imgval = PILImage.create(files[100])
    siamtest = SiameseImage(imgval, imgtest)
    siamtest.show();

[Figure 11: the test pair shown with the title Undetermined]

    res = learn.siampredict(siamtest)

[Figure 12: the test pair shown with the model's prediction]

