Basic support for Generative Adversarial Networks

GAN stands for Generative Adversarial Nets; GANs were invented by Ian Goodfellow. The concept is that we train two models at the same time: a generator and a critic. The generator tries to make new images similar to the ones in a dataset, and the critic tries to distinguish real images from the ones the generator makes. The generator returns images, the critic a single number (usually a probability: 0. for fake images and 1. for real ones).

We train them against each other in the sense that at each step (more or less), we:

  1. Freeze the generator and train the critic for one step by:

    • getting one batch of true images (let’s call that real)
    • generating one batch of fake images (let’s call that fake)
    • having the critic evaluate each batch and computing a loss function from that; the important part is that it rewards the critic for recognizing real images and penalizes it for scoring fake ones as real
    • updating the weights of the critic with the gradients of this loss
  2. Freeze the critic and train the generator for one step by:

    • generating one batch of fake images
    • evaluating the critic on it
    • returning a loss that rewards the generator when the critic thinks those are real images
    • updating the weights of the generator with the gradients of this loss
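The two losses in these steps can be sketched without any framework. This is a minimal illustration assuming the critic outputs probabilities and using binary cross-entropy; the helper names are hypothetical, the generator loss is the common "non-saturating" form (an assumption, other choices exist), and the freezing and gradient updates are left out.

```python
import math

def bce(preds, target):
    """Mean binary cross-entropy of probabilities `preds` against one constant label."""
    eps = 1e-7  # keep log() away from 0
    total = 0.0
    for p in preds:
        p = min(max(p, eps), 1 - eps)
        total += -(target * math.log(p) + (1 - target) * math.log(1 - p))
    return total / len(preds)

def critic_loss(real_preds, fake_preds):
    # Step 1: low when real images score near 1 and fake images score near 0.
    return bce(real_preds, 1.0) + bce(fake_preds, 0.0)

def generator_loss(fake_preds):
    # Step 2: low when the critic scores the fake images as real.
    return bce(fake_preds, 1.0)

real_p = [0.9, 0.8]  # critic is fairly sure these are real
fake_p = [0.2, 0.3]  # and fairly sure these are fake
print(critic_loss(real_p, fake_p), generator_loss(fake_p))
```

Here the critic is doing well, so its loss is small while the generator's loss is large; training alternates between pushing each of them down.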

Note: The fastai library provides support for training GANs through the GANTrainer, but doesn’t include more than basic models.

Wrapping the modules

class GANModule

GANModule(generator=None, critic=None, gen_mode=False) :: Module

Wrapper around a generator and a critic to create a GAN.

This is just a shell to contain the two models. When called, it will either delegate the input to the generator or the critic depending on the value of gen_mode.

GANModule.switch(gen_mode=None)

Put the module in generator mode if gen_mode is True, in critic mode if it is False.

By default (leaving gen_mode as None), this will put the module in the other mode (critic mode if it was in generator mode and vice versa).
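A plain-Python sketch of the routing and switching behavior described above. ToyGANModule is a hypothetical stand-in (no torch involved), with ordinary functions playing the role of the two models.

```python
class ToyGANModule:
    """Hypothetical stand-in for GANModule: it only routes calls by mode."""

    def __init__(self, generator, critic, gen_mode=False):
        self.generator, self.critic, self.gen_mode = generator, critic, gen_mode

    def __call__(self, *args):
        # Delegate to whichever sub-model the current mode selects.
        model = self.generator if self.gen_mode else self.critic
        return model(*args)

    def switch(self, gen_mode=None):
        # None toggles the mode; True/False set it explicitly.
        self.gen_mode = (not self.gen_mode) if gen_mode is None else gen_mode

m = ToyGANModule(generator=lambda z: f"image from {z}", critic=lambda x: 0.5)
print(m("img"))    # critic mode by default
m.switch()         # toggle to generator mode
print(m("noise"))
m.switch(False)    # explicitly back to critic mode
```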


basic_critic(in_size, n_channels, n_features=64, n_extra_layers=0, norm_type=<NormType.Batch: 1>, ks=3, stride=1, padding=None, bias=None, ndim=2, bn_1st=True, act_cls=ReLU, transpose=False, init='auto', xtra=None, bias_std=0.01, dilation:Union[int, Tuple[int, int]]=1, groups:int=1, padding_mode:str='zeros')

A basic critic for images n_channels x in_size x in_size.
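Assuming basic_critic follows the usual DCGAN recipe (4x4, stride-2 convolutions that halve the image until it is 4x4, doubling the features each time, then a final 4x4 convolution down to one score), the spatial sizes can be traced with plain arithmetic. This is a sketch of that assumption rather than fastai's code; dcgan_critic_plan is a hypothetical helper. The final 1x1 map, flattened, is what would give one score per image, matching the [2, 1] shape checked in the example further down.

```python
def conv_out(size, ks, stride, padding):
    # Standard output-size formula for a convolution.
    return (size + 2 * padding - ks) // stride + 1

def dcgan_critic_plan(in_size, n_features=64):
    """Hypothetical helper: (spatial size, channels) after each layer of a DCGAN-style critic."""
    size, ftrs = conv_out(in_size, 4, 2, 1), n_features  # first conv halves the size
    plan = [(size, ftrs)]
    while size > 4:  # keep halving (and doubling features) until 4x4
        size, ftrs = conv_out(size, 4, 2, 1), ftrs * 2
        plan.append((size, ftrs))
    plan.append((conv_out(size, 4, 1, 0), 1))  # final 4x4 conv -> a single score
    return plan

print(dcgan_critic_plan(64))
```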

class AddChannels

AddChannels(n_dim) :: Module

Add n_dim singleton dimensions at the end of the input's shape.


basic_generator(out_size, n_channels, in_sz=100, n_features=64, n_extra_layers=0, ks=3, stride=1, padding=None, bias=None, ndim=2, norm_type=<NormType.Batch: 1>, bn_1st=True, act_cls=ReLU, transpose=False, init='auto', xtra=None, bias_std=0.01, dilation:Union[int, Tuple[int, int]]=1, groups:int=1, padding_mode:str='zeros')

A basic generator from in_sz to images n_channels x out_size x out_size.
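Going the other way, a sketch (assuming a DCGAN-style generator made of 4x4 transposed convolutions; helper names are hypothetical) suggests why AddChannels is useful here: a noise batch of shape (bs, in_sz) is reshaped to (bs, in_sz, 1, 1) so it can be treated as a 1x1 image, which the transposed convolutions then grow to out_size.

```python
def add_channels(shape, n_dim):
    # Shape effect of AddChannels: append n_dim singleton dimensions.
    return tuple(shape) + (1,) * n_dim

def tconv_out(size, ks, stride, padding):
    # Output-size formula for a transposed convolution (no output_padding).
    return (size - 1) * stride - 2 * padding + ks

def dcgan_generator_sizes(out_size, bs=2, in_sz=100):
    """Hypothetical helper: spatial sizes through a DCGAN-style generator."""
    shape = add_channels((bs, in_sz), 2)  # (bs, 100) -> (bs, 100, 1, 1)
    size = tconv_out(shape[-1], 4, 1, 0)  # 1x1 -> 4x4
    sizes = [size]
    while size < out_size:                # each 4x4, stride-2 layer doubles the size
        size = tconv_out(size, 4, 2, 1)
        sizes.append(size)
    return shape, sizes

print(dcgan_generator_sizes(64))
```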

```python
from fastai.vision.all import *
from fastai.vision.gan import *
from fastcore.test import test_eq

critic = basic_critic(64, 3)
generator = basic_generator(64, 3)
tst = GANModule(critic=critic, generator=generator)  # starts in critic mode
real = torch.randn(2, 3, 64, 64)
real_p = tst(real)
test_eq(real_p.shape, [2,1])
tst.switch()  # tst is now in generator mode
noise = torch.randn(2, 100)
fake = tst(noise)
test_eq(fake.shape, real.shape)
tst.switch()  # tst is back in critic mode
fake_p = tst(fake)
test_eq(fake_p.shape, [2,1])
```


DenseResBlock(nf, norm_type=<NormType.Batch: 1>, ks=3, stride=1, padding=None, bias=None, ndim=2, bn_1st=True, act_cls=ReLU, transpose=False, init='auto', xtra=None, bias_std=0.01, dilation:Union[int, Tuple[int, int]]=1, groups:int=1, padding_mode:str='zeros')

ResNet block of nf features. Additional keyword arguments are passed to ConvLayer.
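Assuming "Dense" refers to a DenseNet-style merge, where the block concatenates its output with its input instead of adding them, the block would return 2*nf channels rather than nf. The toy sketch below contrasts the two merges on a plain list of per-channel values; conv_like is a hypothetical stand-in for the block's convolutions.

```python
def conv_like(x):
    # Stand-in for the block's convolutions: any map that keeps the channel count.
    return [2 * v for v in x]

def residual_merge(x):
    # Classic ResNet merge: add the input back, channel count unchanged.
    return [a + b for a, b in zip(conv_like(x), x)]

def dense_merge(x):
    # Dense merge: concatenate the block output with its input, channels double.
    return conv_like(x) + x

features = [1.0, 2.0, 3.0]  # one value per channel, so nf = 3
print(len(residual_merge(features)), len(dense_merge(features)))
```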


gan_critic(n_channels=3, nf=128, n_blocks=3, p=0.15)

Critic to train a GAN.

class GANLoss

GANLoss(gen_loss_func, crit_loss_func, gan_model) :: GANModule

Wrapper around crit_loss_func and gen_loss_func.

In generator mode, this loss function expects the output of the generator and some target (a batch of real images). It will evaluate whether the generator successfully fooled the critic using gen_loss_func. This loss function has the following signature:

```python
def gen_loss_func(fake_pred, output, target): ...
```

so that it can combine the output of the critic on output (the first argument, fake_pred) with output and target (for instance, if you want to mix the GAN loss with other losses).

In critic mode, this loss function expects the real_pred given by the critic and some input (the noise fed to the generator). It will evaluate the critic using crit_loss_func. This loss function has the following signature:

```python
def crit_loss_func(real_pred, fake_pred): ...
```

where real_pred is the output of the critic on a batch of real images and fake_pred is the output of the critic on a batch of fake images generated from the noise by the generator.
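A plain-Python sketch of how the arguments are routed in the two modes, with single numbers standing in for images and a hypothetical critic that scores numbers close to 1.0 as most real. ToyGANLoss, gen_loss_func and crit_loss_func are illustrative names; the losses use raw critic scores plus an extra L1 term, which is one possible choice and not necessarily what fastai uses.

```python
def mean(xs):
    return sum(xs) / len(xs)

def gen_loss_func(fake_pred, output, target, l1_weight=0.1):
    # GAN term: lower when the critic scores the fakes higher (as "real").
    gan_term = -mean(fake_pred)
    # Extra pixel-wise L1 term comparing generated images with real ones.
    l1_term = mean([abs(o - t) for o, t in zip(output, target)])
    return gan_term + l1_weight * l1_term

def crit_loss_func(real_pred, fake_pred):
    # Lower when real images score high and fakes score low.
    return mean(fake_pred) - mean(real_pred)

class ToyGANLoss:
    """Hypothetical stand-in showing how GANLoss routes its arguments by mode."""

    def __init__(self, gen_loss_func, crit_loss_func, generator, critic):
        self.gen_loss_func, self.crit_loss_func = gen_loss_func, crit_loss_func
        self.generator, self.critic = generator, critic
        self.gen_mode = False

    def __call__(self, pred, other):
        if self.gen_mode:
            # pred: generator output; other: the batch of real images (target).
            fake_pred = [self.critic(img) for img in pred]
            return self.gen_loss_func(fake_pred, pred, other)
        # pred: critic scores on real images; other: the noise fed to the generator.
        fake_pred = [self.critic(self.generator(z)) for z in other]
        return self.crit_loss_func(pred, fake_pred)

generator = lambda z: 2 * z            # toy "image" = one number
critic = lambda img: -abs(img - 1.0)   # highest score for images near 1.0
loss = ToyGANLoss(gen_loss_func, crit_loss_func, generator, critic)
real, noise = [1.0, 1.0], [0.1, 0.2]
crit_l = loss([critic(x) for x in real], noise)  # critic mode
loss.gen_mode = True
gen_l = loss([generator(z) for z in noise], real)  # generator mode
```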

class AdaptiveLoss

AdaptiveLoss(crit) :: Module

Expand the target to match the output size before applying crit.
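To illustrate what expanding the target means, here is a sketch on nested lists, assuming one label per sample and several scores per sample. ToyAdaptiveLoss, expand_target and bce_with_logits are hypothetical names; the hand-written BCE-with-logits plays the role of crit.

```python
import math

def expand_target(target, output):
    # Repeat each sample's label across that sample's outputs: (bs,) -> (bs, n).
    return [[float(t)] * len(row) for t, row in zip(target, output)]

def bce_with_logits(logits, labels):
    # Mean binary cross-entropy on raw scores, averaged over every element.
    pairs = [(x, y) for xs, ys in zip(logits, labels) for x, y in zip(xs, ys)]
    return sum(y * math.log1p(math.exp(-x)) + (1 - y) * math.log1p(math.exp(x))
               for x, y in pairs) / len(pairs)

class ToyAdaptiveLoss:
    """Hypothetical plain-Python analogue of AdaptiveLoss."""

    def __init__(self, crit):
        self.crit = crit

    def __call__(self, output, target):
        return self.crit(output, expand_target(target, output))

output = [[2.0, 0.0], [-1.0, 3.0]]  # 2 samples, 2 scores each
target = [1, 0]                      # one label per sample
loss = ToyAdaptiveLoss(bce_with_logits)(output, target)
```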


accuracy_thresh_expand(y_pred, y_true, thresh=0.5, sigmoid=True)

Compute accuracy after expanding y_true to the size of y_pred.
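The same expansion applies to this metric. A hypothetical plain-Python version, assuming one 0/1 label per sample and a row of raw scores per sample:

```python
import math

def accuracy_thresh_expand_sketch(y_pred, y_true, thresh=0.5, sigmoid=True):
    """Hypothetical sketch: y_pred is a list of score rows, y_true one label per row."""
    total = correct = 0
    for row, label in zip(y_pred, y_true):
        for score in row:  # the sample's label is reused for every column
            p = 1 / (1 + math.exp(-score)) if sigmoid else score
            correct += int((p > thresh) == bool(label))
            total += 1
    return correct / total

y_pred = [[2.0, -1.0], [-3.0, -0.5]]  # raw scores, 2 per sample
y_true = [1, 0]
acc = accuracy_thresh_expand_sketch(y_pred, y_true)  # 3 of the 4 expanded entries match
```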

Callbacks for GAN training


set_freeze_model(m, rg)

Set requires_grad to rg for all the parameters of m (rg=False freezes the model, rg=True unfreezes it).

class GANTrainer

GANTrainer(switch_eval=False, clip=None, beta=0.98, gen_first=False, show_img=True) :: Callback

Handles GAN training.

Warning: The GANTrainer is useless on its own; you need to complete it with one of the following switchers.

class FixedGANSwitcher

FixedGANSwitcher(n_crit=1, n_gen=1) :: Callback

Switcher to do n_crit iterations of the critic then n_gen iterations of the generator.
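The resulting schedule can be pictured as an endless sequence of phases. The sketch below uses a hypothetical helper (not fastai's callback) and assumes training starts with the critic; n_crit=5, n_gen=1 is the ratio commonly used for WGANs.

```python
from itertools import islice

def fixed_schedule(n_crit=1, n_gen=1):
    """Hypothetical helper: n_crit critic batches, then n_gen generator batches, repeated."""
    while True:
        for _ in range(n_crit):
            yield "critic"
        for _ in range(n_gen):
            yield "generator"

print(list(islice(fixed_schedule(n_crit=5, n_gen=1), 12)))
```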

class AdaptiveGANSwitcher

AdaptiveGANSwitcher(gen_thresh=None, critic_thresh=None) :: Callback

Switcher that switches from the generator to the critic when the generator loss goes below gen_thresh, and back to the generator when the critic loss goes below critic_thresh.
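A sketch of this switching rule as a pure function, assuming (hypothetically) that a threshold left at None means switching after every batch. adaptive_next_mode is an illustrative name, with True standing for generator mode.

```python
def adaptive_next_mode(gen_mode, loss, gen_thresh=None, critic_thresh=None):
    """Hypothetical rule: return the mode (True = generator) for the next batch."""
    thresh = gen_thresh if gen_mode else critic_thresh
    # Switch once the current model's loss is low enough (always, if no threshold).
    if thresh is None or loss < thresh:
        return not gen_mode
    return gen_mode

mode, history = True, []
for loss in [0.9, 0.6, 0.4, 0.3, 0.1]:
    history.append("gen" if mode else "crit")
    mode = adaptive_next_mode(mode, loss, gen_thresh=0.5, critic_thresh=0.2)
```

The generator keeps training until its loss drops to 0.4 (below 0.5); the critic then trains until its loss drops to 0.1 (below 0.2).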

class GANDiscriminativeLR

GANDiscriminativeLR(mult_lr=5.0) :: Callback

Callback that handles multiplying the learning rate by mult_lr for the critic.

GAN data

class InvisibleTensor

InvisibleTensor(x, **kwargs) :: TensorBase

A Tensor that supports subclass pickling and maintains metadata when casting or after methods. Its show method displays nothing, so the noise inputs do not appear in show_batch.


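generate_noise(fn, size=100)

Return a noise vector of size values drawn from a standard normal distribution, as an InvisibleTensor. The file name fn is ignored; it is only there so the function can be used as get_x.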

```python
from fastai.vision.all import *
from fastai.vision.gan import *

bs = 128
size = 64
# Inputs are noise vectors (get_x ignores the file), targets are the images themselves.
dblock = DataBlock(blocks=(TransformBlock, ImageBlock),
                   get_x=generate_noise,
                   get_items=get_image_files,
                   splitter=IndexSplitter([]),  # empty validation set
                   item_tfms=Resize(size, method=ResizeMethod.Crop),
                   batch_tfms=Normalize.from_stats(torch.tensor([0.5,0.5,0.5]), torch.tensor([0.5,0.5,0.5])))
path = untar_data(URLs.LSUN_BEDROOMS)
dls = dblock.dataloaders(path, path=path, bs=bs)
dls.show_batch(max_n=16)
```
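With mean and std both set to 0.5, the normalization maps pixel values in [0, 1] onto [-1, 1]. Assuming the generator ends with a tanh (as DCGAN-style generators usually do), that is exactly the range of its outputs, so real and fake images live on the same scale. A minimal arithmetic check:

```python
def normalize(x, mean=0.5, std=0.5):
    # Per-channel normalization with the stats passed to Normalize.from_stats above.
    return (x - mean) / std

def denormalize(y, mean=0.5, std=0.5):
    # Inverse mapping, e.g. to display generated images.
    return y * std + mean

pixels = [0.0, 0.25, 0.5, 1.0]  # intensities assumed already scaled to [0, 1]
print([normalize(p) for p in pixels])
```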

GAN Learner


gan_loss_from_func(loss_gen, loss_crit, weights_gen=None)

Define loss functions for a GAN from loss_gen and loss_crit.
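A plain-Python sketch of the idea, under stated assumptions: loss_crit compares critic predictions with all-ones (real) or all-zeros (fake) labels, loss_gen compares generated images with targets, and weights_gen is assumed to be a pair (weight of the adversarial term, weight of the loss_gen term) defaulting to (1, 1). The function and helper names are hypothetical.

```python
import math

def gan_loss_from_func_sketch(loss_gen, loss_crit, weights_gen=None):
    """Hypothetical analogue: build (generator loss, critic loss) from two base losses."""
    w_crit, w_gen = weights_gen if weights_gen is not None else (1.0, 1.0)

    def loss_G(fake_pred, output, target):
        # Fool the critic (fakes labelled real = 1) plus a weighted reconstruction term.
        return w_crit * loss_crit(fake_pred, [1.0] * len(fake_pred)) + w_gen * loss_gen(output, target)

    def loss_C(real_pred, fake_pred):
        # Average of "real images are real" and "fake images are fake".
        return (loss_crit(real_pred, [1.0] * len(real_pred))
                + loss_crit(fake_pred, [0.0] * len(fake_pred))) / 2

    return loss_G, loss_C

def bce(preds, labels):
    # Binary cross-entropy on probabilities.
    return -sum(y * math.log(p) + (1 - y) * math.log(1 - p) for p, y in zip(preds, labels)) / len(preds)

def l1(output, target):
    return sum(abs(o - t) for o, t in zip(output, target)) / len(output)

loss_G, loss_C = gan_loss_from_func_sketch(l1, bce, weights_gen=(1.0, 0.5))
```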

class GANLearner

GANLearner(dls, generator, critic, gen_loss_func, crit_loss_func, switcher=None, gen_first=False, switch_eval=True, show_img=True, clip=None, cbs=None, metrics=None, loss_func=None, opt_func=Adam, lr=0.001, splitter=trainable_params, path=None, model_dir='models', wd=None, wd_bn_bias=False, train_bn=True, moms=(0.95, 0.85, 0.95)) :: Learner

A Learner suitable for GANs.

```python
from fastai.callback.all import *
# DCGAN-style models for 64x64 RGB images; the critic uses LeakyReLU activations.
generator = basic_generator(64, n_channels=3, n_extra_layers=1)
critic    = basic_critic(64, n_channels=3, n_extra_layers=1, act_cls=partial(nn.LeakyReLU, negative_slope=0.2))
learn = GANLearner.wgan(dls, generator, critic, opt_func=RMSProp)
# Record metrics on the training set only (the validation set is empty).
learn.recorder.train_metrics=True
learn.recorder.valid_metrics=False
```
```python
learn.fit(1, 2e-4, wd=0.)
learn.show_results(max_n=9, ds_idx=0)
```
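GANLearner.wgan, as its name suggests, sets the learner up as a Wasserstein GAN. As a framework-free sketch of the textbook WGAN objective (raw critic scores without a sigmoid, plus weight clipping of the critic, with clip value 0.01 as in the WGAN paper), assuming the usual sign convention, which a given library may flip; all names are hypothetical:

```python
def mean(xs):
    return sum(xs) / len(xs)

def wgan_critic_loss(real_pred, fake_pred):
    # The critic minimizes this, i.e. maximizes the score gap between real and fake.
    return mean(fake_pred) - mean(real_pred)

def wgan_generator_loss(fake_pred):
    # The generator minimizes this, i.e. pushes the critic's scores on fakes up.
    return -mean(fake_pred)

def clip_weights(weights, clip=0.01):
    # Applied to the critic's weights after each critic step to keep it roughly Lipschitz.
    return [max(-clip, min(clip, w)) for w in weights]

real_pred, fake_pred = [1.5, 0.5], [-0.5, 0.5]
print(wgan_critic_loss(real_pred, fake_pred), wgan_generator_loss(fake_pred))
```

The WGAN paper also recommends RMSProp over momentum-based optimizers, which is consistent with opt_func=RMSProp in the example above.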

