binary_cross_entropy_with_logits

paddle.nn.functional. binary_cross_entropy_with_logits ( logit, label, weight=None, reduction=’mean’, pos_weight=None, name=None ) [源代码]

该OP用于计算输入 logit 和标签 label 间的 binary cross entropy with logits loss 损失。

该OP结合了 sigmoid 操作和 api_nn_loss_BCELoss 操作。同时,我们也可以认为该OP是 sigmoid_cross_entrop_with_logits 和一些 reduce 操作的组合。

在每个类别独立的分类任务中,该OP可以计算按元素的概率误差。可以将其视为预测数据点的标签,其中标签不是互斥的。例如,一篇新闻文章可以同时关于政治,科技,体育或者同时不包含这些内容。

首先,该OP可通过下式计算损失函数:

binary_cross_entropy_with_logits - 图1

其中

binary_cross_entropy_with_logits - 图2

, 代入上方计算公式中:

binary_cross_entropy_with_logits - 图3

为了计算稳定性,防止当

binary_cross_entropy_with_logits - 图4

时,

binary_cross_entropy_with_logits - 图5

溢出,loss将采用以下公式计算:

binary_cross_entropy_with_logits - 图6

然后,当 weight or pos_weight 不为None的时候,该算子会在输出Out上乘以相应的权重。张量 weight 给Batch中的每一条数据赋予不同权重,张量 pos_weight 给每一类的正例添加相应的权重。

最后,该算子会添加 reduce 操作到前面的输出Out上。当 reduction 为 none 时,直接返回最原始的 Out 结果。当 reduction 为 mean 时,返回输出的均值 Out=MEAN(Out)Out=MEAN(Out) 。当 reduction 为 sum 时,返回输出的求和 Out=SUM(Out)Out=SUM(Out) 。

**注意: 因为是二分类任务,所以标签值应该是0或者1。

参数

  • logit (Tensor) - [N,∗][N,∗] , 其中N是batch_size, * 是任意其他维度。输入数据 logit 一般是线性层的输出,不需要经过 sigmoid 层。数据类型是float32、float64。

  • label (Tensor) - [N,∗][N,∗] ,标签 label 的维度、数据类型与输入 logit 相同。

  • weight (Tensor,可选) - 手动指定每个batch二值交叉熵的权重,如果指定的话,维度必须是一个batch的数据的维度。数据类型是float32, float64。默认值是:None。

  • reduction (str,可选) - 指定应用于输出结果的计算方式,可选值有: 'none', 'mean', 'sum' 。默认为 'mean',计算 BCELoss 的均值;设置为 'sum' 时,计算 BCELoss 的总和;设置为 'none' 时,则返回原始loss。

  • pos_weight (Tensor,可选) - 手动指定正类的权重,必须是与类别数相等长度的向量。数据类型是float32, float64。默认值是:None。

  • name (str,可选) - 操作的名称(可选,默认值为None)。更多信息请参见 Name

返回

  • Tensor,输出的Tensor。如果 reduction'none', 则输出的维度为 [N,∗][N,∗] , 与输入 input 的形状相同。如果 reduction'mean''sum', 则输出的维度为 [1][1] 。

代码示例

  1. import paddle
  2. logit = paddle.to_tensor([5.0, 1.0, 3.0], dtype="float32")
  3. label = paddle.to_tensor([1.0, 0.0, 1.0], dtype="float32")
  4. output = paddle.nn.functional.binary_cross_entropy_with_logits(logit, label)
  5. print(output) # [0.45618808]