Loss Functions
Custom fastai loss functions
class BaseLoss [source]
BaseLoss(loss_cls, *args, axis=-1, flatten=True, floatify=False, is_2d=True, **kwargs)
Same as loss_cls, but flattens input and target.
Wrapping a general loss function inside of BaseLoss provides extra functionality to your loss functions:
- flattens the tensors before trying to take the losses since it's more convenient (with a potential transpose to put axis at the end)
- a potential activation method that tells the library if there is an activation fused in the loss (useful for inference and methods such as Learner.get_preds or Learner.predict)
- a potential decodes method that is used on predictions in inference (for instance, an argmax in classification)
The args and kwargs will be passed to loss_cls during the initialization to instantiate a loss function. axis is put at the end for losses like softmax that are often performed on the last axis. If floatify=True, the targs will be converted to floats (useful for losses that only accept float targets like BCEWithLogitsLoss), and is_2d determines if we flatten while keeping the first dimension (batch size) or completely flatten the input. We want the first for losses like Cross Entropy, and the second for pretty much anything else.
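As an illustration, here is a minimal sketch (not part of fastai) that wraps a PyTorch loss the library does not wrap itself, assuming the usual fastai imports so that BaseLoss, torch and nn are in scope:
#A sketch only: wrap nn.PoissonNLLLoss so inputs and targets are completely
#flattened to 1D before the loss is computed; log_input is forwarded to the
#underlying loss through **kwargs.
loss = BaseLoss(nn.PoissonNLLLoss, log_input=True, floatify=True, is_2d=False)
output = torch.randn(32, 5, 10)
target = torch.randint(0, 4, (32, 5, 10))
_ = loss(output, target)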
class CrossEntropyLossFlat [source]
CrossEntropyLossFlat(*args, axis=-1, weight=None, ignore_index=-100, reduction='mean', flatten=True, floatify=False, is_2d=True) :: BaseLoss
Same as nn.CrossEntropyLoss, but flattens input and target.
tst = CrossEntropyLossFlat()
output = torch.randn(32, 5, 10)
target = torch.randint(0, 10, (32,5))
#nn.CrossEntropyLoss would fail with those two tensors, but not our flattened version.
_ = tst(output, target)
test_fail(lambda x: nn.CrossEntropyLoss()(output,target))
#Associated activation is softmax
test_eq(tst.activation(output), F.softmax(output, dim=-1))
#This loss function has a decodes which is argmax
test_eq(tst.decodes(output), output.argmax(dim=-1))
tst = CrossEntropyLossFlat(axis=1)
output = torch.randn(32, 5, 128, 128)
target = torch.randint(0, 5, (32, 128, 128))
_ = tst(output, target)
test_eq(tst.activation(output), F.softmax(output, dim=1))
test_eq(tst.decodes(output), output.argmax(dim=1))
Focal Loss is the same as cross entropy except easy-to-classify observations are down-weighted in the loss calculation. The strength of down-weighting is proportional to the size of the gamma parameter. Put another way, the larger the gamma, the less the easy-to-classify observations contribute to the loss.
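As a rough sketch of the idea (not fastai's implementation, and the helper name is made up for illustration), the focal weighting can be written directly on top of an unreduced cross entropy:
#Illustrative only: down-weight each example's cross entropy by (1 - p_t)**gamma,
#so confidently-correct (easy) examples contribute little to the mean loss.
def focal_loss_sketch(logits, targets, gamma=2.):
    ce = F.cross_entropy(logits, targets, reduction='none')
    p_t = torch.exp(-ce)  #probability assigned to the true class
    return ((1 - p_t)**gamma * ce).mean()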
class FocalLossFlat [source]
FocalLossFlat(*args, gamma=2, axis=-1, weight=None, ignore_index=-100, reduction='mean', **kwargs) :: CrossEntropyLossFlat
Same as CrossEntropyLossFlat but with a focal parameter, gamma. Focal loss was introduced by Lin et al. (https://arxiv.org/pdf/1708.02002.pdf). Note that the class weighting factor in the paper, alpha, can be implemented through the PyTorch weight argument in nn.CrossEntropyLoss.
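For example, per-class weights (standing in for the paper's alpha) can be passed straight through the weight argument; the weights below are arbitrary and only for illustration:
#Hypothetical per-class weights playing the role of alpha in the paper
class_weights = torch.tensor([1., 2., 1., 1., 3.])
fl_weighted = FocalLossFlat(gamma=2, weight=class_weights)
output = torch.randn(16, 5)
target = torch.randint(0, 5, (16,))
_ = fl_weighted(output, target)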
fl = FocalLossFlat(gamma=0)
ce = CrossEntropyLossFlat()
output = torch.randn(32, 5, 10)
target = torch.randint(0, 10, (32,5))
test_close(fl(output, target), ce(output, target))
#Test focal loss with gamma > 0 is different than cross entropy
fl = FocalLossFlat(gamma=2)
test_ne(fl(output, target), ce(output, target))
fl = FocalLossFlat(gamma=0, axis=1)
ce = CrossEntropyLossFlat(axis=1)
output = torch.randn(32, 5, 128, 128)
target = torch.randint(0, 5, (32, 128, 128))
test_close(fl(output, target), ce(output, target), eps=1e-4)
test_eq(fl.activation(output), F.softmax(output, dim=1))
test_eq(fl.decodes(output), output.argmax(dim=1))
class BCEWithLogitsLossFlat [source]
BCEWithLogitsLossFlat(*args, axis=-1, floatify=True, thresh=0.5, weight=None, reduction='mean', pos_weight=None, flatten=True, is_2d=True) :: BaseLoss
Same as nn.BCEWithLogitsLoss, but flattens input and target.
tst = BCEWithLogitsLossFlat()
output = torch.randn(32, 5, 10)
target = torch.randn(32, 5, 10)
#nn.BCEWithLogitsLoss would fail with those two tensors, but not our flattened version.
_ = tst(output, target)
test_fail(lambda x: nn.BCEWithLogitsLoss()(output,target))
output = torch.randn(32, 5)
target = torch.randint(0,2,(32, 5))
#nn.BCEWithLogitsLoss would fail with int targets but not our flattened version.
_ = tst(output, target)
test_fail(lambda x: nn.BCEWithLogitsLoss()(output,target))
tst = BCEWithLogitsLossFlat(pos_weight=torch.ones(10))
output = torch.randn(32, 5, 10)
target = torch.randn(32, 5, 10)
_ = tst(output, target)
test_fail(lambda x: nn.BCEWithLogitsLoss()(output,target))
#Associated activation is sigmoid
test_eq(tst.activation(output), torch.sigmoid(output))
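The thresh argument sets the cut-off used by decodes when turning outputs into binary (multi-label) predictions. A small sketch of that behaviour, assuming decodes simply compares against thresh as in recent fastai versions:
tst = BCEWithLogitsLossFlat(thresh=0.7)
output = torch.randn(32, 5)
#decodes compares against thresh to produce boolean predictions
test_eq(tst.decodes(output), output > 0.7)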
BCELossFlat [source]
BCELossFlat(*args, axis=-1, floatify=True, weight=None, reduction='mean')
Same as nn.BCELoss, but flattens input and target.
tst = BCELossFlat()
output = torch.sigmoid(torch.randn(32, 5, 10))
target = torch.randint(0,2,(32, 5, 10))
_ = tst(output, target)
test_fail(lambda x: nn.BCELoss()(output,target))
MSELossFlat [source]
MSELossFlat(*args, axis=-1, floatify=True, reduction='mean')
Same as nn.MSELoss, but flattens input and target.
tst = MSELossFlat()
output = torch.sigmoid(torch.randn(32, 5, 10))
target = torch.randint(0,2,(32, 5, 10))
_ = tst(output, target)
test_fail(lambda x: nn.MSELoss()(output,target))
L1LossFlat [source]
L1LossFlat(*args, axis=-1, floatify=True, reduction='mean')
Same as nn.L1Loss, but flattens input and target.
class LabelSmoothingCrossEntropy [source]
LabelSmoothingCrossEntropy(eps:float=0.1, weight=None, reduction='mean') :: Module
Same as nn.Module, but no need for subclasses to call super().__init__
lmce = LabelSmoothingCrossEntropy()
output = torch.randn(32, 5, 10)
target = torch.randint(0, 10, (32,5))
test_eq(lmce(output.flatten(0,1), target.flatten()), lmce(output.transpose(-1,-2), target))
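The formula referred to below blends the usual cross entropy with a uniform penalty over all classes: with smoothing eps and c classes, the loss is (1-eps) times the cross entropy of the true class plus eps times the average negative log-probability over all classes. A minimal sketch of that computation (not fastai's exact code; the function name is made up):
def label_smoothing_ce_sketch(output, target, eps=0.1):
    c = output.size(-1)
    log_preds = F.log_softmax(output, dim=-1)
    #uniform part: average -log p over every class, then over the batch
    uniform = -log_preds.mean(dim=-1).mean()
    #standard cross entropy on the true class
    nll = F.nll_loss(log_preds, target)
    return (1 - eps) * nll + eps * uniform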
On top of the formula we define:
- a reduction attribute, that will be used when we call Learner.get_preds
- a weight attribute to pass to the underlying cross entropy
- an activation function that represents the activation fused in the loss (since we use cross entropy behind the scenes). It will be applied to the output of the model when calling Learner.get_preds or Learner.predict
- a decodes function that converts the output of the model to a format similar to the target (here indices). This is used in Learner.predict and Learner.show_results to decode the predictions
class LabelSmoothingCrossEntropyFlat [source]
LabelSmoothingCrossEntropyFlat(*args, axis=-1, eps=0.1, reduction='mean', flatten=True, floatify=False, is_2d=True) :: BaseLoss
Same as LabelSmoothingCrossEntropy, but flattens input and target.
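A usage sketch in the same style as the other flattened losses (the shapes mirror the LabelSmoothingCrossEntropy test above):
tst = LabelSmoothingCrossEntropyFlat()
output = torch.randn(32, 5, 10)
target = torch.randint(0, 10, (32, 5))
#the flattened version handles the extra sequence dimension for us
_ = tst(output, target)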