vision.gan
All the modules and callbacks necessary to train a GAN
GANs
GAN stands for Generative Adversarial Nets, which were invented by Ian Goodfellow. The idea is to train two models at the same time: a generator and a critic (also called the discriminator). The generator tries to make new images similar to the ones in our dataset, and the critic's job is to tell real images apart from the fake ones the generator produces. The generator returns images, and the critic returns a feature map (which can be a single number, depending on the input size). Usually the critic is trained to return 0. everywhere for fake images and 1. everywhere for real ones.
This module contains all the functions necessary to create a GAN.
We train them against each other in the sense that at each step (more or less), we:
1. Freeze the generator and train the critic for one step by:
   - getting one batch of true images (let's call that real)
   - generating one batch of fake images (let's call that fake)
   - having the critic evaluate each batch and computing a loss function from that; the important part is that it rewards the detection of real images and penalizes the fake ones
   - updating the weights of the critic with the gradients of this loss
2. Freeze the critic and train the generator for one step by:
   - generating one batch of fake images
   - evaluating the critic on it
   - returning a loss that rewards the generator when the critic thinks those are real images
   - updating the weights of the generator with the gradients of this loss
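The two alternating phases above can be illustrated without any deep learning library. This is a toy sketch only, not how fastai implements it: a hypothetical one-dimensional GAN whose generator is G(z) = a*z + b and whose critic is D(x) = sigmoid(w*x + c), with the gradients of the usual binary cross-entropy losses written out by hand. "Freezing" a model here simply means its parameters are left untouched during the other model's step.

```python
import math
import random
from typing import List


def sigmoid(u: float) -> float:
    return 1.0 / (1.0 + math.exp(-u))


class Toy1DGAN:
    """Hypothetical 1-D GAN: generator G(z) = a*z + b, critic D(x) = sigmoid(w*x + c)."""

    def __init__(self) -> None:
        self.a, self.b = 1.0, 0.0  # generator parameters
        self.w, self.c = 0.1, 0.0  # critic parameters

    def G(self, z: float) -> float:
        return self.a * z + self.b

    def D(self, x: float) -> float:
        return sigmoid(self.w * x + self.c)

    def critic_loss(self, reals: List[float], zs: List[float]) -> float:
        # Low when D(real) is close to 1 and D(fake) is close to 0.
        fakes = [self.G(z) for z in zs]
        return (-sum(math.log(self.D(r)) for r in reals) / len(reals)
                - sum(math.log(1 - self.D(f)) for f in fakes) / len(fakes))

    def gen_loss(self, zs: List[float]) -> float:
        # Low when the critic scores the generated samples as real.
        return -sum(math.log(self.D(self.G(z))) for z in zs) / len(zs)

    def critic_step(self, reals: List[float], zs: List[float], lr: float = 0.1) -> None:
        # Generator frozen: fakes use the current a, b, but only w and c are updated.
        fakes = [self.G(z) for z in zs]
        gw = (sum(-(1 - self.D(r)) * r for r in reals) / len(reals)
              + sum(self.D(f) * f for f in fakes) / len(fakes))
        gc = (sum(-(1 - self.D(r)) for r in reals) / len(reals)
              + sum(self.D(f) for f in fakes) / len(fakes))
        self.w -= lr * gw
        self.c -= lr * gc

    def gen_step(self, zs: List[float], lr: float = 0.1) -> None:
        # Critic frozen: only a and b are updated.
        ga = sum(-(1 - self.D(self.G(z))) * self.w * z for z in zs) / len(zs)
        gb = sum(-(1 - self.D(self.G(z))) * self.w for z in zs) / len(zs)
        self.a -= lr * ga
        self.b -= lr * gb


rng = random.Random(0)
reals = [rng.gauss(3.0, 0.5) for _ in range(64)]  # "real" data centred on 3
zs = [rng.gauss(0.0, 1.0) for _ in range(64)]     # noise fed to the generator

gan = Toy1DGAN()
gan.critic_step(reals, zs)  # phase 1: train the critic
gan.gen_step(zs)            # phase 2: train the generator
```

Each step lowers its own model's loss while leaving the other model's parameters unchanged, which is the whole point of the freeze/train alternation.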
class GANLearner
GANLearner(data:DataBunch, generator:Module, critic:Module, gen_loss_func:LossFunction, crit_loss_func:LossFunction, switcher:Callback=None, gen_first:bool=False, switch_eval:bool=True, show_img:bool=True, clip:float=None, **learn_kwargs) :: Learner
A Learner suitable for GANs.
This is the general constructor to create a GAN; you might want to use one of the factory methods, which are easier to use. Create a GAN from data, a generator and a critic. The data should have the inputs the generator will expect and the images wanted as targets.
gen_loss_func is the loss function that will be applied to the generator. It takes three arguments, fake_pred, target and output, and should return a rank 0 tensor. output is the result of the generator applied to the input (the xs of the batch), target is the ys of the batch and fake_pred is the result of the critic being given output. output and target can be used to add a specific loss to the GAN loss (pixel loss, feature loss), and for good training of the GAN, the loss should encourage fake_pred to be as close to 1 as possible (the generator is trained to fool the critic).
crit_loss_func is the loss function that will be applied to the critic. It takes two arguments, real_pred and fake_pred. real_pred is the result of the critic on the target images (the ys of the batch) and fake_pred is the result of the critic applied to a batch of fakes, generated by the generator from the xs of the batch.
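To make the two signatures concrete, here is a minimal sketch of a gen_loss_func and a crit_loss_func that follow the argument order described above. It uses plain Python lists in place of tensors (so it returns a float rather than a rank 0 tensor), assumes the critic outputs raw scores (logits) that go through a sigmoid inside the loss, and uses an L1 pixel loss as the optional extra term; the pixel_weight parameter is hypothetical.

```python
import math
from typing import List


def bce_with_logits(logit: float, target: float) -> float:
    # Numerically stable binary cross-entropy computed from a raw score.
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def gen_loss_func(fake_pred: List[float], target: List[float], output: List[float],
                  pixel_weight: float = 1.0) -> float:
    # Adversarial part: push the critic's scores on the fakes towards "real" (1).
    adv = sum(bce_with_logits(p, 1.0) for p in fake_pred) / len(fake_pred)
    # Optional pixel part: L1 distance between the generated output and the target image.
    pix = sum(abs(o - t) for o, t in zip(output, target)) / len(output)
    return adv + pixel_weight * pix


def crit_loss_func(real_pred: List[float], fake_pred: List[float]) -> float:
    # Reward real_pred -> 1 and fake_pred -> 0, averaging the two halves.
    real_part = sum(bce_with_logits(p, 1.0) for p in real_pred) / len(real_pred)
    fake_part = sum(bce_with_logits(p, 0.0) for p in fake_pred) / len(fake_pred)
    return (real_part + fake_part) / 2
```

Replacing the L1 term with a feature loss, or setting pixel_weight to 0, leaves the adversarial part unchanged.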
switcher is a Callback that should tell the GAN when to switch from critic to generator and vice versa. By default it does 5 iterations of the critic for 1 iteration of the generator. The model begins training with the generator if gen_first=True. If switch_eval=True, the model that isn't being trained is put in eval mode (otherwise it is left in training mode, which means some statistics, like the running mean in batchnorm layers, are updated, and dropout is applied).
clip should be set to a value if you want to clip the weights (see the Wasserstein GAN, for instance).
If show_img=True, one image generated by the GAN is shown at the end of each epoch.
Factory methods
from_learners
from_learners(learn_gen:Learner, learn_crit:Learner, switcher:Callback=None, weights_gen:Point=None, **learn_kwargs)
Create a GAN from learn_gen and learn_crit.
Directly creates a GANLearner from two Learner objects: one for the generator and one for the critic. The switcher and all kwargs will be passed to the initialization of GANLearner along with the following loss functions:
- loss_func_crit is the mean of learn_crit.loss_func applied to real_pred with a target of ones and learn_crit.loss_func applied to fake_pred with a target of zeros
- loss_func_gen combines learn_crit.loss_func applied to fake_pred with a target of ones (to fool the critic) and learn_gen.loss_func applied to output and target. The weights of each of those two contributions can be passed in weights_gen (default is 1. and 1.).
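The way the two GAN losses are assembled from the two Learner loss functions can be sketched as follows. This is an illustration with plain lists and stand-ins (mse in place of learn_crit.loss_func, l1 in place of learn_gen.loss_func), not the library's code; gan_losses_from is a hypothetical name, and combining the generator terms as a weighted sum is an assumption that matches the default weights of 1. and 1.

```python
from typing import Callable, List, Sequence

Loss = Callable[[List[float], List[float]], float]


def mse(pred: List[float], targ: List[float]) -> float:
    return sum((p - t) ** 2 for p, t in zip(pred, targ)) / len(pred)


def l1(pred: List[float], targ: List[float]) -> float:
    return sum(abs(p - t) for p, t in zip(pred, targ)) / len(pred)


def gan_losses_from(crit_loss: Loss, gen_loss: Loss, weights_gen: Sequence[float] = (1.0, 1.0)):
    def loss_func_crit(real_pred: List[float], fake_pred: List[float]) -> float:
        # Real images should score 1 and fakes 0; average the two halves.
        ones, zeros = [1.0] * len(real_pred), [0.0] * len(fake_pred)
        return (crit_loss(real_pred, ones) + crit_loss(fake_pred, zeros)) / 2

    def loss_func_gen(fake_pred: List[float], target: List[float], output: List[float]) -> float:
        # Fool the critic (fakes should score 1), plus the generator's own loss.
        ones = [1.0] * len(fake_pred)
        return weights_gen[0] * crit_loss(fake_pred, ones) + weights_gen[1] * gen_loss(output, target)

    return loss_func_gen, loss_func_crit


loss_gen, loss_crit = gan_losses_from(mse, l1, weights_gen=(1.0, 0.5))
```

Lowering the second weight makes the generator care less about matching the target pixel by pixel and more about fooling the critic.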
wgan
wgan(data:DataBunch, generator:Module, critic:Module, switcher:Callback=None, clip:float=0.01, **learn_kwargs)
Create a WGAN from data, generator and critic.
The Wasserstein GAN is detailed in [this article]. switcher and the learn_kwargs will be passed to the GANLearner init; clip is the value used for weight clipping.
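A minimal sketch of the two ingredients that distinguish a WGAN, following the standard formulation rather than fastai's own code (whose sign convention may differ): the critic outputs an unbounded score instead of a probability, its loss is the difference between its mean score on fakes and on reals, and its weights are clipped to [-clip, clip] after each update. Plain lists stand in for tensors, and all function names are hypothetical.

```python
from typing import List


def wgan_critic_loss(real_pred: List[float], fake_pred: List[float]) -> float:
    # Minimizing this pushes the scores of real images up and of fakes down.
    return sum(fake_pred) / len(fake_pred) - sum(real_pred) / len(real_pred)


def wgan_gen_loss(fake_pred: List[float]) -> float:
    # The generator wants the critic to give its fakes high scores.
    return -sum(fake_pred) / len(fake_pred)


def clip_weights(weights: List[float], clip: float = 0.01) -> List[float]:
    # Clamp every critic weight to [-clip, clip] (weight-clipped WGAN training).
    return [min(max(w, -clip), clip) for w in weights]
```

With the default clip of 0.01, clip_weights([0.5, -0.5, 0.003]) returns [0.01, -0.01, 0.003].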
Switchers
In any GAN training, you will need to tell the Learner when to switch from generator to critic and vice versa. The following two Callback classes are examples to help you with that.
As usual, don't call the on_something methods directly; the fastai library will do it for you during training.
class FixedGANSwitcher
FixedGANSwitcher(learn:Learner, n_crit:Union[int, Callable]=1, n_gen:Union[int, Callable]=1) :: LearnerCallback
Switcher to do n_crit iterations of the critic then n_gen iterations of the generator.
on_train_begin
on_train_begin(**kwargs)
Initialize the iteration counts.
on_batch_end
on_batch_end(iteration, **kwargs)
Switch the model if necessary.
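The schedule produced by a fixed switcher can be simulated without fastai. The class below is a hypothetical stand-in that only tracks a gen_mode flag and an iteration counter, and it handles integer n_crit/n_gen only (the real callback also accepts callables, which this sketch leaves out). With the GANLearner default of 5 critic iterations per generator iteration, the critic (C) and generator (G) alternate like this:

```python
class FixedSwitchSchedule:
    """Hypothetical stand-in for FixedGANSwitcher: n_crit critic steps, then n_gen generator steps."""

    def __init__(self, n_crit: int = 1, n_gen: int = 1, gen_first: bool = False) -> None:
        self.n_crit, self.n_gen = n_crit, n_gen
        self.gen_mode = gen_first
        self.count = 0  # iterations done since the last switch

    def on_batch_end(self) -> None:
        self.count += 1
        limit = self.n_gen if self.gen_mode else self.n_crit
        if self.count == limit:
            self.gen_mode = not self.gen_mode  # switch generator <-> critic
            self.count = 0


def trace(schedule: FixedSwitchSchedule, n_batches: int) -> str:
    # Record which model is trained on each batch: "G" (generator) or "C" (critic).
    modes = []
    for _ in range(n_batches):
        modes.append("G" if schedule.gen_mode else "C")
        schedule.on_batch_end()
    return "".join(modes)


print(trace(FixedSwitchSchedule(n_crit=5, n_gen=1), 12))  # CCCCCGCCCCCG
```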
class AdaptiveGANSwitcher
AdaptiveGANSwitcher(learn:Learner, gen_thresh:float=None, critic_thresh:float=None) :: LearnerCallback
Switcher that switches from the generator to the critic (or from the critic to the generator) when the current loss goes below gen_thresh (respectively critic_thresh).
on_batch_end
on_batch_end(last_loss, **kwargs)
Switch the model if necessary.
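One way to read the adaptive rule is sketched below, under assumptions not spelled out in this page: while a model is training, the batch loss is compared with that model's threshold (gen_thresh for the generator, critic_thresh for the critic), and once it drops below, training hands over to the other model; a threshold of None is taken to mean "switch after every batch". The class is a hypothetical stand-in, not the fastai callback.

```python
from typing import List, Optional


class AdaptiveSwitchSchedule:
    """Hypothetical stand-in for AdaptiveGANSwitcher (assumed semantics)."""

    def __init__(self, gen_thresh: Optional[float] = None,
                 critic_thresh: Optional[float] = None, gen_first: bool = False) -> None:
        self.gen_thresh, self.critic_thresh = gen_thresh, critic_thresh
        self.gen_mode = gen_first

    def on_batch_end(self, last_loss: float) -> None:
        thresh = self.gen_thresh if self.gen_mode else self.critic_thresh
        # Assumption: no threshold means switching after every batch.
        if thresh is None or last_loss < thresh:
            self.gen_mode = not self.gen_mode


def modes_for(losses: List[float], schedule: AdaptiveSwitchSchedule) -> str:
    # Which model trained on each batch, given the loss that batch produced.
    out = []
    for loss in losses:
        out.append("G" if schedule.gen_mode else "C")
        schedule.on_batch_end(loss)
    return "".join(out)


# The critic keeps training until its loss is below 0.3, the generator until below 0.5.
print(modes_for([0.7, 0.2, 0.6, 0.4, 0.9],
                AdaptiveSwitchSchedule(gen_thresh=0.5, critic_thresh=0.3)))  # CCGGC
```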
Discriminative LR
If you want to train your critic at a different learning rate than the generator, this will let you do it automatically (even if you have a learning rate schedule).
class GANDiscriminativeLR
GANDiscriminativeLR(learn:Learner, mult_lr:float=5.0) :: LearnerCallback
Callback that handles multiplying the learning rate by mult_lr for the critic.
on_batch_begin
on_batch_begin(train, **kwargs)
Multiply the current lr if necessary.
on_step_end
on_step_end(**kwargs)
Put the LR back to its original value if necessary.
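The multiply-then-restore pattern of on_batch_begin and on_step_end can be sketched like this. A small hypothetical Opt class stands in for the optimizer, gen_mode says which model is training, and applying the multiplier only to training batches in critic mode is an assumption based on the train argument of on_batch_begin. Because the learning rate is restored after every step, whatever value a learning rate schedule sets for the next batch is the one that gets multiplied, which is why this works alongside a schedule.

```python
from dataclasses import dataclass


@dataclass
class Opt:
    lr: float  # hypothetical stand-in for the optimizer's learning rate


class CriticLRMultiplier:
    """Hypothetical sketch of the GANDiscriminativeLR behaviour."""

    def __init__(self, opt: Opt, mult_lr: float = 5.0) -> None:
        self.opt, self.mult_lr = opt, mult_lr
        self.multiplied = False

    def on_batch_begin(self, gen_mode: bool, train: bool) -> None:
        # Scale the LR up only for critic training steps.
        self.multiplied = train and not gen_mode
        if self.multiplied:
            self.opt.lr *= self.mult_lr

    def on_step_end(self) -> None:
        # Undo the scaling so the scheduled value is back in place.
        if self.multiplied:
            self.opt.lr /= self.mult_lr
            self.multiplied = False
```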
Specific models
basic_critic
basic_critic(in_size:int, n_channels:int, n_features:int=64, n_extra_layers:int=0, **conv_kwargs)
Tests for basic_critic:
pytest -sv tests/test_vision_gan.py::test_basic_critic
Other tests where basic_critic is used:
pytest -sv tests/test_vision_gan.py::test_gan_module
A basic critic for images n_channels x in_size x in_size.
This model contains a first 4 by 4 convolutional layer of stride 2 from n_channels to n_features, followed by n_extra_layers 3 by 3 convolutional layers of stride 1. Then we add as many 4 by 4 convolutional layers of stride 2 as needed, with the number of features multiplied by 2 at each stage, so that the in_size becomes 1. kwargs can be used to customize the convolutional layers and are passed to conv_layer.
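The size arithmetic behind this description can be checked with the standard convolution output formula, floor((size + 2*padding - kernel) / stride) + 1. The sketch below only tracks (input features, output features, output size) per layer; the paddings (1 for the 4 by 4 stride 2 and 3 by 3 stride 1 layers) and a final 4 by 4 layer without padding that turns the last 4 by 4 map into a single score are DCGAN-style assumptions, not read from the fastai source.

```python
from typing import List, Tuple


def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    # Spatial output size of a standard convolution.
    return (size + 2 * padding - kernel) // stride + 1


def critic_plan(in_size: int, n_channels: int, n_features: int = 64,
                n_extra_layers: int = 0) -> List[Tuple[int, int, int]]:
    plan = []
    size, ftrs = conv_out(in_size, 4, 2, 1), n_features  # first 4x4 stride-2 layer
    plan.append((n_channels, ftrs, size))
    for _ in range(n_extra_layers):                       # 3x3 stride-1 layers keep the size
        size = conv_out(size, 3, 1, 1)
        plan.append((ftrs, ftrs, size))
    while size > 4:                                       # halve the size, double the features
        size = conv_out(size, 4, 2, 1)
        plan.append((ftrs, ftrs * 2, size))
        ftrs *= 2
    size = conv_out(size, 4, 1, 0)                        # assumed final 4x4 layer: 4 -> 1
    plan.append((ftrs, 1, size))
    return plan


for layer in critic_plan(in_size=64, n_channels=3):
    print(layer)
```

For 64 by 64 RGB images this gives (3, 64, 32), (64, 128, 16), (128, 256, 8), (256, 512, 4) and finally (512, 1, 1).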
basic_generator
basic_generator(in_size:int, n_channels:int, noise_sz:int=100, n_features:int=64, n_extra_layers=0, **conv_kwargs)
Tests for basic_generator:
pytest -sv tests/test_vision_gan.py::test_basic_generator
Other tests where basic_generator is used:
pytest -sv tests/test_vision_gan.py::test_gan_module
A basic generator from noise_sz to images n_channels x in_size x in_size.
This model contains a first 4 by 4 transposed convolutional layer of stride 1 from noise_sz to the last number of features of the corresponding critic. Then we add as many 4 by 4 transposed convolutional layers of stride 2 as needed, with the number of features divided by 2 at each stage, so that the image ends up with height and width in_size//2. At the end, we add n_extra_layers 3 by 3 convolutional layers of stride 1. The last layer is a 4 by 4 transposed convolution of stride 2 followed by tanh. kwargs can be used to customize the convolutional layers and are passed to conv_layer.
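The generator mirrors the critic; the sketch below tracks its layer sizes with the transposed convolution formula (size - 1)*stride - 2*padding + kernel, starting from noise assumed to be shaped noise_sz x 1 x 1. As in a DCGAN, the paddings (0 for the first layer, 1 for the others) are assumptions, and the feature count of the first layer is computed so that it matches the widest critic layer for the same in_size and n_features.

```python
from typing import List, Tuple


def tconv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    # Spatial output size of a transposed convolution (no output padding, dilation 1).
    return (size - 1) * stride - 2 * padding + kernel


def generator_plan(in_size: int, n_channels: int, noise_sz: int = 100, n_features: int = 64,
                   n_extra_layers: int = 0) -> List[Tuple[int, int, int]]:
    # Double the features once per doubling from 4 up to in_size (mirrors the critic).
    size, ftrs = 4, n_features // 2
    while size < in_size:
        size, ftrs = size * 2, ftrs * 2
    size = tconv_out(1, 4, 1, 0)                   # noise_sz x 1 x 1 -> 4 x 4
    plan = [(noise_sz, ftrs, size)]
    while size < in_size // 2:                     # double the size, halve the features
        size = tconv_out(size, 4, 2, 1)
        plan.append((ftrs, ftrs // 2, size))
        ftrs //= 2
    for _ in range(n_extra_layers):                # 3x3 stride-1 layers keep the size
        size = tconv_out(size, 3, 1, 1)
        plan.append((ftrs, ftrs, size))
    size = tconv_out(size, 4, 2, 1)                # last 4x4 stride-2 layer, then tanh
    plan.append((ftrs, n_channels, size))
    return plan


for layer in generator_plan(in_size=64, n_channels=3):
    print(layer)
```

For in_size=64 the first layer goes from 100 noise features to 512, the same width as the last stride 2 layer of the 64-pixel critic, and the output is back to 3 channels at 64 by 64.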
gan_critic
gan_critic(n_channels:int=3, nf:int=128, n_blocks:int=3, p:int=0.15)
Critic to train a GAN.
class GANTrainer
GANTrainer(learn:Learner, switch_eval:bool=False, clip:float=None, beta:float=0.98, gen_first:bool=False, show_img:bool=True) :: LearnerCallback
Tests for GANTrainer:
pytest -sv tests/test_vision_gan.py::test_gan_trainer
Handles GAN Training.
LearnerCallback responsible for handling the two different optimizers (one for the generator and one for the critic) and for doing all the work behind the scenes so that the generator (or the critic) is in training mode, with its parameters requiring gradients, each time we switch.
switch_eval=True means that the GANTrainer will put the model that isn't training into eval mode (if it's False, running statistics, like those in batchnorm layers, will be updated and dropout will be applied). clip is the clipping applied to the weights (if not None). beta is the coefficient for the moving averages, as the GANTrainer tracks the generator loss and the critic loss separately. gen_first=True means the training begins with the generator (with the critic if it's False). If show_img=True, we show a generated image at the end of each epoch.
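Tracking the two losses separately with moving averages of coefficient beta can be sketched as one exponential moving average per model, each fed only with the losses of the batches that model trained on. The debiasing step (dividing by 1 - beta**n so that early values are not pulled towards 0) is an assumption about how the smoothing is done; SmoothedLoss and the variable names are hypothetical.

```python
class SmoothedLoss:
    """Debiased exponential moving average of a loss (hypothetical sketch)."""

    def __init__(self, beta: float = 0.98) -> None:
        self.beta, self.n, self.mov_avg = beta, 0, 0.0
        self.smooth = 0.0

    def add_value(self, val: float) -> None:
        self.n += 1
        self.mov_avg = self.beta * self.mov_avg + (1 - self.beta) * val
        self.smooth = self.mov_avg / (1 - self.beta ** self.n)


# One smoother per model; gen_mode tells which model produced each loss.
gen_smooth, crit_smooth = SmoothedLoss(), SmoothedLoss()
for gen_mode, loss in [(False, 0.9), (False, 0.7), (True, 2.0), (False, 0.5)]:
    (gen_smooth if gen_mode else crit_smooth).add_value(loss)
```

Keeping the two averages apart matters because generator and critic losses live on different scales, so a single running average would mix them meaninglessly.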
switch
switch(gen_mode:bool=None)
Tests where switch is used:
pytest -sv tests/test_vision_gan.py::test_gan_module
Switch the model; if gen_mode is provided, put it in the desired mode.
If gen_mode is left as None, just put the model in the other mode (critic if it was in generator mode and vice versa).
on_train_begin
on_train_begin(**kwargs)
Create the optimizers for the generator and critic if necessary, initialize smootheners.
on_epoch_begin
on_epoch_begin(epoch, **kwargs)
Put the critic or the generator back to eval if necessary.
on_batch_begin
on_batch_begin(last_input, last_target, **kwargs)
Clamp the weights with self.clip if it's not None, and return the correct input.
on_backward_begin
on_backward_begin(last_loss, last_output, **kwargs)
Record last_loss in the proper list.
on_epoch_end
on_epoch_end(pbar, epoch, last_metrics, **kwargs)
Put the various losses in the recorder and show a sample image.
on_train_end
on_train_end(**kwargs)
Switch to generator mode to show results.
Specific modules
class GANModule
GANModule(generator:Module=None, critic:Module=None, gen_mode:bool=False) :: PrePostInitMeta :: Module
Tests for GANModule:
pytest -sv tests/test_vision_gan.py::test_gan_module
Wrapper around a generator and a critic to create a GAN.
If gen_mode is left as None, just put the model in the other mode (critic if it was in generator mode and vice versa).
switch
switch(gen_mode:bool=None)
Tests where switch is used:
pytest -sv tests/test_vision_gan.py::test_gan_module
Put the model in generator mode if gen_mode, in critic mode otherwise.
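A minimal sketch of such a wrapper, assuming (as the descriptions above suggest) that calling it runs whichever sub-model matches gen_mode and that switch toggles the mode when gen_mode is None. Plain functions stand in for the generator and critic modules, and GANModuleSketch is a hypothetical name.

```python
from typing import Any, Callable, Optional


class GANModuleSketch:
    """Holds a generator and a critic and forwards calls to the active one."""

    def __init__(self, generator: Callable, critic: Callable, gen_mode: bool = False) -> None:
        self.generator, self.critic, self.gen_mode = generator, critic, gen_mode

    def switch(self, gen_mode: Optional[bool] = None) -> None:
        # None toggles the mode; True/False set it explicitly.
        self.gen_mode = (not self.gen_mode) if gen_mode is None else gen_mode

    def __call__(self, *args: Any) -> Any:
        return self.generator(*args) if self.gen_mode else self.critic(*args)


gan = GANModuleSketch(generator=lambda z: [2 * v for v in z], critic=lambda x: sum(x))
print(gan([1.0, 2.0]))  # critic mode: 3.0
gan.switch()
print(gan([1.0, 2.0]))  # generator mode: [2.0, 4.0]
```

Having a single callable object lets the rest of the training loop stay unaware of which model is currently active.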
class GANLoss
GANLoss(loss_funcG:Callable, loss_funcC:Callable, gan_model:GANModule) :: PrePostInitMeta :: GANModule
Wrapper around loss_funcC (for the critic) and loss_funcG (for the generator).
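One plausible way such a wrapper routes between the two loss functions, consistent with the argument descriptions of gen_loss_func and crit_loss_func given for GANLearner, is sketched below. Plain callables stand in for the modules and loss functions; the class and the method names gen_loss and crit_loss are hypothetical, not the fastai API.

```python
from typing import Any, Callable


class GANLossSketch:
    """Routes to loss_funcG or loss_funcC depending on gen_mode (hypothetical sketch)."""

    def __init__(self, loss_funcG: Callable, loss_funcC: Callable,
                 generator: Callable, critic: Callable) -> None:
        self.loss_funcG, self.loss_funcC = loss_funcG, loss_funcC
        self.generator, self.critic = generator, critic
        self.gen_mode = False

    def gen_loss(self, output: Any, target: Any) -> Any:
        # Score the generated output with the critic, then apply loss_funcG.
        fake_pred = self.critic(output)
        return self.loss_funcG(fake_pred, target, output)

    def crit_loss(self, real_pred: Any, inputs: Any) -> Any:
        # Generate fakes from the batch inputs, score them, then apply loss_funcC.
        fake_pred = self.critic(self.generator(inputs))
        return self.loss_funcC(real_pred, fake_pred)

    def __call__(self, *args: Any) -> Any:
        return self.gen_loss(*args) if self.gen_mode else self.crit_loss(*args)
```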
class AdaptiveLoss
AdaptiveLoss(crit) :: PrePostInitMeta :: Module
Expand the target to match the output size before applying crit.
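When the critic returns a feature map rather than a single number, each sample has several scores but only one label. A minimal sketch of the expansion, with nested lists standing in for a batch of flattened critic outputs and mse as a stand-in for crit (both assumptions):

```python
from typing import Callable, List


def mse(pred: List[float], targ: List[float]) -> float:
    return sum((p - t) ** 2 for p, t in zip(pred, targ)) / len(pred)


def adaptive_loss(crit: Callable[[List[float], List[float]], float],
                  output: List[List[float]], target: List[float]) -> float:
    # Repeat each sample's label once per score in that sample's output row.
    flat_out = [v for row in output for v in row]
    flat_tgt = [float(t) for row, t in zip(output, target) for _ in row]
    return crit(flat_out, flat_tgt)


# Two samples, each with a 2-value feature map from the critic; labels: real (1), fake (0).
print(adaptive_loss(mse, [[1.0, 0.5], [0.0, 0.5]], [1, 0]))  # 0.125
```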
accuracy_thresh_expand
accuracy_thresh_expand(y_pred:Tensor, y_true:Tensor, thresh:float=0.5, sigmoid:bool=True) → Rank0Tensor
Compute accuracy after expanding y_true to the size of y_pred.
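A plain-Python sketch of the same idea as a metric: predictions are optionally passed through a sigmoid, thresholded at thresh, and compared with each sample's label repeated across that sample's scores. The name accuracy_thresh_expand_sketch is hypothetical, and lists stand in for tensors.

```python
import math
from typing import List


def accuracy_thresh_expand_sketch(y_pred: List[List[float]], y_true: List[int],
                                  thresh: float = 0.5, sigmoid: bool = True) -> float:
    correct = total = 0
    for row, label in zip(y_pred, y_true):
        for score in row:
            p = 1 / (1 + math.exp(-score)) if sigmoid else score
            correct += int((p > thresh) == bool(label))  # label expanded to every score
            total += 1
    return correct / total


print(accuracy_thresh_expand_sketch([[2.0, -1.0], [-3.0, -0.5]], [1, 0]))  # 0.75
```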
Data Block API
class NoisyItem
NoisyItem(noise_sz) :: ItemBase
Tests for NoisyItem:
pytest -sv tests/test_vision_gan.py::test_noisy_item
A random ItemBase of size noise_sz.
class GANItemList
GANItemList(items, noise_sz:int=100, **kwargs) :: ImageList
Tests where GANItemList is used:
pytest -sv tests/test_vision_gan.py::test_gan_datasets
ItemList suitable for GANs.
Inputs will be NoisyItem of noise_sz while the default class for target is ImageList.
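The pairing this creates (a fresh noise vector as input, a real image as target) can be sketched without the data block API. The standard normal noise and the noise_sz x 1 x 1 shape, which would let a generator's first transposed convolution treat the noise as a 1 by 1 image, are assumptions; noisy_item and make_gan_pairs are hypothetical helpers, and file names stand in for images.

```python
import random
from typing import List, Sequence, Tuple

Noise = List[List[List[float]]]


def noisy_item(noise_sz: int, rng: random.Random) -> Noise:
    # noise_sz values drawn from N(0, 1), shaped noise_sz x 1 x 1.
    return [[[rng.gauss(0.0, 1.0)]] for _ in range(noise_sz)]


def make_gan_pairs(images: Sequence[str], noise_sz: int = 100,
                   seed: int = 0) -> List[Tuple[Noise, str]]:
    # Inputs are noise, targets are the real images.
    rng = random.Random(seed)
    return [(noisy_item(noise_sz, rng), img) for img in images]


pairs = make_gan_pairs(["img0.png", "img1.png"], noise_sz=100)
x, y = pairs[0]
print(len(x), len(x[0]), len(x[0][0]), y)  # 100 1 1 img0.png
```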
show_xys
show_xys(xs, ys, imgsize:int=4, figsize:Optional[Tuple[int, int]]=None, **kwargs)
Shows ys (target images) on a figure of figsize.
show_xyzs
show_xyzs(xs, ys, zs, imgsize:int=4, figsize:Optional[Tuple[int, int]]=None, **kwargs)
Shows zs (generated images) on a figure of figsize.
©2021 fast.ai. All rights reserved.
Site last generated: Jan 5, 2021