GAN
Basic support for Generative Adversarial Networks
GAN stands for Generative Adversarial Nets, which were invented by Ian Goodfellow. The idea is to train two models at the same time: a generator and a critic. The generator tries to make new images similar to the ones in a dataset, and the critic tries to distinguish real images from the ones the generator produces. The generator returns images; the critic returns a single number (usually a probability: 0. for fake images and 1. for real ones).
We train them against each other in the sense that at each step (more or less), we:
Freeze the generator and train the critic for one step by:
- getting one batch of true images (let’s call that real)
- generating one batch of fake images (let’s call that fake)
- having the critic evaluate each batch and computing a loss function from that; the important part is that it rewards the detection of real images and penalizes the fake ones
- updating the weights of the critic with the gradients of this loss
Freeze the critic and train the generator for one step by:
- generating one batch of fake images
- evaluating the critic on it
- returning a loss that rewards the critic thinking those are real images
- updating the weights of the generator with the gradients of this loss
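The alternating procedure above can be sketched in plain PyTorch. This is a minimal illustration under stated assumptions, not fastai's implementation: the tiny linear `generator` and `critic`, the helper names `set_requires_grad`, `critic_step` and `generator_step`, the SGD optimizers and the binary cross-entropy losses are all hypothetical choices made for the example.

```python
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy models: the generator maps 8-d noise to 16-d "images",
# the critic maps a 16-d "image" to a single logit.
generator = nn.Linear(8, 16)
critic = nn.Linear(16, 1)
opt_g = torch.optim.SGD(generator.parameters(), lr=0.1)
opt_c = torch.optim.SGD(critic.parameters(), lr=0.1)

def set_requires_grad(model, flag):
    "Freeze (flag=False) or unfreeze (flag=True) all parameters of `model`."
    for p in model.parameters(): p.requires_grad_(flag)

def critic_step(real):
    "Freeze the generator and update the critic once."
    set_requires_grad(generator, False)
    set_requires_grad(critic, True)
    fake = generator(torch.randn(real.shape[0], 8))
    real_pred, fake_pred = critic(real), critic(fake)
    # Reward detecting real images (target 1) and penalize fake ones (target 0).
    loss = (F.binary_cross_entropy_with_logits(real_pred, torch.ones_like(real_pred))
            + F.binary_cross_entropy_with_logits(fake_pred, torch.zeros_like(fake_pred))) / 2
    opt_c.zero_grad()
    loss.backward()
    opt_c.step()
    return loss.item()

def generator_step(bs):
    "Freeze the critic and update the generator once."
    set_requires_grad(critic, False)
    set_requires_grad(generator, True)
    fake = generator(torch.randn(bs, 8))
    fake_pred = critic(fake)
    # Reward the critic thinking the fake images are real (target 1).
    loss = F.binary_cross_entropy_with_logits(fake_pred, torch.ones_like(fake_pred))
    opt_g.zero_grad()
    loss.backward()
    opt_g.step()
    return loss.item()

real = torch.randn(4, 16)   # stand-in for one batch of true images
crit_loss = critic_step(real)
gen_loss = generator_step(4)
```

Each step only changes the weights of the model being trained; the gradients still flow through the frozen critic during the generator step, which is what lets the generator learn from the critic's judgement.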
Note: The fastai library provides support for training GANs through the GANTrainer, but doesn’t include more than basic models.
Wrapping the modules
class GANModule
GANModule(generator=None, critic=None, gen_mode=False) :: Module
Wrapper around a generator and a critic to create a GAN.
This is just a shell to contain the two models. When called, it will delegate the input to either the generator or the critic, depending on the value of gen_mode.
GANModule.switch
GANModule.switch(gen_mode=None)
Put the module in generator mode if gen_mode, in critic mode otherwise.
By default (leaving gen_mode to None), this will put the module in the other mode (critic mode if it was in generator mode and vice versa).
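The delegation and toggling behaviour described above can be illustrated without any deep learning dependency. The class `ToyGANModule` below is a hypothetical stand-in written for this example (it is not the fastai class), and the two lambdas are placeholders for real models.

```python
class ToyGANModule:
    "Hypothetical stand-in for a two-model wrapper with a `gen_mode` flag."
    def __init__(self, generator=None, critic=None, gen_mode=False):
        self.generator, self.critic, self.gen_mode = generator, critic, gen_mode

    def __call__(self, *args):
        # Delegate to the generator in generator mode, to the critic otherwise.
        return self.generator(*args) if self.gen_mode else self.critic(*args)

    def switch(self, gen_mode=None):
        # With no argument, toggle the mode; otherwise set the requested mode.
        self.gen_mode = (not self.gen_mode) if gen_mode is None else gen_mode

tst = ToyGANModule(generator=lambda x: f"generated from {x}",
                   critic=lambda x: f"score for {x}")
out_critic = tst("image")      # starts in critic mode
tst.switch()                   # toggles to generator mode
out_gen = tst("noise")
tst.switch(gen_mode=False)     # explicitly back to critic mode
```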
basic_critic
basic_critic(in_size, n_channels, n_features=64, n_extra_layers=0, norm_type=<NormType.Batch: 1>, ks=3, stride=1, padding=None, bias=None, ndim=2, bn_1st=True, act_cls=ReLU, transpose=False, init='auto', xtra=None, bias_std=0.01, dilation:Union[int, Tuple[int, int]]=1, groups:int=1, padding_mode:str='zeros')
A basic critic for images n_channels x in_size x in_size.
class AddChannels
AddChannels(n_dim) :: Module
Add n_dim channels at the end of the input.
basic_generator
basic_generator(out_size, n_channels, in_sz=100, n_features=64, n_extra_layers=0, ks=3, stride=1, padding=None, bias=None, ndim=2, norm_type=<NormType.Batch: 1>, bn_1st=True, act_cls=ReLU, transpose=False, init='auto', xtra=None, bias_std=0.01, dilation:Union[int, Tuple[int, int]]=1, groups:int=1, padding_mode:str='zeros')
A basic generator from in_sz to images n_channels x out_size x out_size.
import torch
from fastcore.test import test_eq
from fastai.vision.gan import GANModule, basic_critic, basic_generator

critic = basic_critic(64, 3)
generator = basic_generator(64, 3)
tst = GANModule(critic=critic, generator=generator)
# tst starts in critic mode: a batch of images gives one score per image
real = torch.randn(2, 3, 64, 64)
real_p = tst(real)
test_eq(real_p.shape, [2,1])
tst.switch() #tst is now in generator mode
noise = torch.randn(2, 100)
fake = tst(noise)
test_eq(fake.shape, real.shape)
tst.switch() #tst is back in critic mode
fake_p = tst(fake)
test_eq(fake_p.shape, [2,1])
DenseResBlock
DenseResBlock(nf, norm_type=<NormType.Batch: 1>, ks=3, stride=1, padding=None, bias=None, ndim=2, bn_1st=True, act_cls=ReLU, transpose=False, init='auto', xtra=None, bias_std=0.01, dilation:Union[int, Tuple[int, int]]=1, groups:int=1, padding_mode:str='zeros')
Resnet block of nf features. conv_kwargs are passed to conv_layer.
gan_critic
gan_critic(n_channels=3, nf=128, n_blocks=3, p=0.15)
Critic to train a GAN.
class GANLoss
GANLoss(gen_loss_func, crit_loss_func, gan_model) :: GANModule
Wrapper around crit_loss_func and gen_loss_func.
In generator mode, this loss function expects the output of the generator and some target (a batch of real images). It will evaluate if the generator successfully fooled the critic using gen_loss_func. This loss function has the following signature
def gen_loss_func(fake_pred, output, target):
to be able to combine the output of the critic on output (which is the first argument, fake_pred) with output and target (if you want to mix the GAN loss with other losses, for instance).
In critic mode, this loss function expects the real_pred given by the critic and some input (the noise fed to the generator). It will evaluate the critic using crit_loss_func. This loss function has the following signature
def crit_loss_func(real_pred, fake_pred):
where real_pred is the output of the critic on a batch of real images and fake_pred is the output of the critic on a batch of fake images generated from the noise using the generator.
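To make the two signatures concrete, here is a minimal sketch of a pair of loss functions that fit them. The choice of binary cross-entropy on logits and the optional `l1_weight` parameter are assumptions made for this example; they are not fastai defaults.

```python
import torch
import torch.nn.functional as F

def gen_loss_func(fake_pred, output, target, l1_weight=0.0):
    "Generator loss: reward the critic scoring generated images as real."
    adv = F.binary_cross_entropy_with_logits(fake_pred, torch.ones_like(fake_pred))
    # Optionally mix in a pixel loss between generated (output) and real (target) images.
    return adv + l1_weight * F.l1_loss(output, target)

def crit_loss_func(real_pred, fake_pred):
    "Critic loss: real images should score high, fake images low."
    real = F.binary_cross_entropy_with_logits(real_pred, torch.ones_like(real_pred))
    fake = F.binary_cross_entropy_with_logits(fake_pred, torch.zeros_like(fake_pred))
    return (real + fake) / 2

# Critic logits for two real and two fake images (made-up values).
real_pred = torch.tensor([[4.0], [3.0]])
fake_pred = torch.tensor([[-4.0], [-3.0]])
images = torch.zeros(2, 3, 8, 8)

good_critic = crit_loss_func(real_pred, fake_pred)    # critic is right: low loss
fooled_critic = crit_loss_func(fake_pred, real_pred)  # critic is wrong: high loss
gen_loss = gen_loss_func(fake_pred, images, images)   # generator failed to fool the critic
```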
class AdaptiveLoss
AdaptiveLoss(crit) :: Module
Expand the target to match the output size before applying crit.
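As a rough illustration of expanding the target, the sketch below broadcasts a per-item label to the shape of the output before applying a criterion. The class name `ExpandTargetLoss` is hypothetical, and the sketch assumes the target has shape `(bs,)` and the output shape `(bs, n)`.

```python
import torch
from torch import nn

class ExpandTargetLoss(nn.Module):
    "Hypothetical sketch: broadcast `target` to the shape of `output`, then apply `crit`."
    def __init__(self, crit):
        super().__init__()
        self.crit = crit

    def forward(self, output, target):
        # target: (bs,) -> (bs, 1) -> expanded to (bs, n) to match output.
        return self.crit(output, target[:, None].expand_as(output).float())

loss_func = ExpandTargetLoss(nn.BCEWithLogitsLoss())
output = torch.zeros(4, 3)   # e.g. several critic logits per image
target = torch.ones(4)       # one label per image
loss = loss_func(output, target)
```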
accuracy_thresh_expand
accuracy_thresh_expand(y_pred, y_true, thresh=0.5, sigmoid=True)
Compute accuracy after expanding y_true to the size of y_pred.
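A small worked example may help here. The function below is a sketch written for this page under the assumption that y_true holds one label per item and y_pred holds several logits per item; the `_sketch` suffix marks it as illustrative rather than the library function.

```python
import torch

def accuracy_thresh_expand_sketch(y_pred, y_true, thresh=0.5, sigmoid=True):
    "Thresholded accuracy after broadcasting `y_true` to the shape of `y_pred`."
    if sigmoid: y_pred = y_pred.sigmoid()
    y_true = y_true[:, None].expand_as(y_pred)
    return ((y_pred > thresh) == y_true.bool()).float().mean()

y_pred = torch.tensor([[2.0, -1.0], [-3.0, -0.5]])  # logits, two scores per item
y_true = torch.tensor([1, 0])                        # one label per item
acc = accuracy_thresh_expand_sketch(y_pred, y_true)
```

With these values, three of the four thresholded predictions agree with the expanded labels.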
Callbacks for GAN training
set_freeze_model
set_freeze_model(m, rg)
class GANTrainer
GANTrainer(switch_eval=False, clip=None, beta=0.98, gen_first=False, show_img=True) :: Callback
Handles GAN Training.
Warning: The GANTrainer is useless on its own; you need to complete it with one of the following switchers.
class FixedGANSwitcher
FixedGANSwitcher(n_crit=1, n_gen=1) :: Callback
Switcher to do n_crit iterations of the critic then n_gen iterations of the generator.
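The resulting schedule can be pictured with a small generator function. `fixed_schedule` is a hypothetical helper for illustration only (it is not a fastai callback) and assumes training starts in critic mode.

```python
from itertools import islice

def fixed_schedule(n_crit=1, n_gen=1):
    "Yield 'critic' `n_crit` times, then 'generator' `n_gen` times, repeating forever."
    while True:
        for _ in range(n_crit): yield "critic"
        for _ in range(n_gen): yield "generator"

# First eight iterations with three critic steps per generator step.
steps = list(islice(fixed_schedule(n_crit=3, n_gen=1), 8))
```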
class AdaptiveGANSwitcher
AdaptiveGANSwitcher(gen_thresh=None, critic_thresh=None) :: Callback
Switcher that goes back to generator/critic when the loss goes below gen_thresh/critic_thresh.
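The threshold rule can be sketched as a pure function over a sequence of losses. The helpers `should_switch` and `run` are hypothetical, and treating a threshold of None as "switch after every batch" is an assumption of this sketch.

```python
def should_switch(gen_mode, loss, gen_thresh=None, critic_thresh=None):
    "True if the model being trained has a loss below its threshold (or has no threshold)."
    thresh = gen_thresh if gen_mode else critic_thresh
    return thresh is None or loss < thresh

def run(losses, gen_thresh, critic_thresh, gen_mode=False):
    "Return which model is trained on each batch, given the loss observed on that batch."
    modes = []
    for loss in losses:
        modes.append("generator" if gen_mode else "critic")
        if should_switch(gen_mode, loss, gen_thresh, critic_thresh):
            gen_mode = not gen_mode
    return modes

# Made-up losses: the critic trains until its loss drops below 0.5, then the generator takes over.
modes = run([0.9, 0.7, 0.4, 0.8, 0.3, 0.6], gen_thresh=0.5, critic_thresh=0.5)
```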
class GANDiscriminativeLR
GANDiscriminativeLR(mult_lr=5.0) :: Callback
Callback that handles multiplying the learning rate by mult_lr for the critic.
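In effect, the critic trains with a larger learning rate than the generator. The helper `lr_for_mode` below is a hypothetical one-liner that only illustrates the arithmetic; it is not how the callback is wired into the training loop.

```python
def lr_for_mode(base_lr, gen_mode, mult_lr=5.0):
    "Learning rate used for the current step: scaled by `mult_lr` when training the critic."
    return base_lr if gen_mode else base_lr * mult_lr

gen_lr = lr_for_mode(2e-4, gen_mode=True)     # generator keeps the base learning rate
crit_lr = lr_for_mode(2e-4, gen_mode=False)   # critic uses base_lr * mult_lr
```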
GAN data
class InvisibleTensor
InvisibleTensor(x, **kwargs) :: TensorBase
A Tensor which supports subclass pickling, and maintains metadata when casting or after methods.
generate_noise
generate_noise(fn, size=100)
from fastai.vision.all import *
from fastai.vision.gan import *

bs = 128
size = 64
# Inputs are noise vectors, targets are the images found under `path`
dblock = DataBlock(blocks = (TransformBlock, ImageBlock),
                   get_x = generate_noise,
                   get_items = get_image_files,
                   splitter = IndexSplitter([]),
                   item_tfms=Resize(size, method=ResizeMethod.Crop),
                   batch_tfms = Normalize.from_stats(torch.tensor([0.5,0.5,0.5]), torch.tensor([0.5,0.5,0.5])))
path = untar_data(URLs.LSUN_BEDROOMS)
dls = dblock.dataloaders(path, path=path, bs=bs)
dls.show_batch(max_n=16)
GAN Learner
gan_loss_from_func
gan_loss_from_func(loss_gen, loss_crit, weights_gen=None)
Define loss functions for a GAN from loss_gen and loss_crit.
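One plausible way to build the two GAN losses from simpler pieces is sketched below. This is an illustration under assumptions, not a copy of the library code: the name `gan_loss_from_func_sketch`, the interpretation of `weights_gen` as a pair (weight of the adversarial term, weight of `loss_gen`) and the averaging in the critic loss are choices made for the example.

```python
import torch
from torch import nn

def gan_loss_from_func_sketch(loss_gen, loss_crit, weights_gen=None):
    "Hypothetical sketch: build (generator_loss, critic_loss) from two simpler losses."
    w_crit, w_gen = (1.0, 1.0) if weights_gen is None else weights_gen

    def _loss_G(fake_pred, output, target):
        # The generator wants the critic to label its images as real (ones).
        ones = torch.ones_like(fake_pred)
        return w_crit * loss_crit(fake_pred, ones) + w_gen * loss_gen(output, target)

    def _loss_C(real_pred, fake_pred):
        # The critic should answer ones on real images and zeros on fake ones.
        ones, zeros = torch.ones_like(real_pred), torch.zeros_like(fake_pred)
        return (loss_crit(real_pred, ones) + loss_crit(fake_pred, zeros)) / 2

    return _loss_G, _loss_C

loss_G, loss_C = gan_loss_from_func_sketch(nn.L1Loss(), nn.BCEWithLogitsLoss(),
                                           weights_gen=(1.0, 10.0))
fake_pred = torch.zeros(2, 1)                                   # undecided critic
output, target = torch.zeros(2, 3, 4, 4), torch.ones(2, 3, 4, 4)
g = loss_G(fake_pred, output, target)
c = loss_C(torch.zeros(2, 1), fake_pred)
```

The returned pair has the `gen_loss_func(fake_pred, output, target)` and `crit_loss_func(real_pred, fake_pred)` signatures described for GANLoss.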
class GANLearner
GANLearner(dls, generator, critic, gen_loss_func, crit_loss_func, switcher=None, gen_first=False, switch_eval=True, show_img=True, clip=None, cbs=None, metrics=None, loss_func=None, opt_func=Adam, lr=0.001, splitter=trainable_params, path=None, model_dir='models', wd=None, wd_bn_bias=False, train_bn=True, moms=(0.95, 0.85, 0.95)) :: Learner
A Learner suitable for GANs.
from fastai.callback.all import *
generator = basic_generator(64, n_channels=3, n_extra_layers=1)
critic = basic_critic (64, n_channels=3, n_extra_layers=1, act_cls=partial(nn.LeakyReLU, negative_slope=0.2))
learn = GANLearner.wgan(dls, generator, critic, opt_func = RMSProp)
learn.recorder.train_metrics=True
learn.recorder.valid_metrics=False
learn.fit(1, 2e-4, wd=0.)
learn.show_results(max_n=9, ds_idx=0)
©2021 fast.ai. All rights reserved.
Site last generated: Mar 31, 2021