Metric

class paddle.metric. Metric [源代码]

评估器metric的基类。

用法:

  1. m = SomeMetric()
  2. for prediction, label in ...:
  3. m.update(prediction, label)
  4. m.accumulate()

compute 接口的进阶用法:

在 compute 中可以使用PaddlePaddle内置的算子进行评估器的状态,而不是通过 Python/NumPy, 这样可以加速计算。 update 接口将 compute 的输出作为 输入,内部采用Python/NumPy计算。

Metric 计算流程如下 (在{}中的表示模型和评估器的计算):

  1. inputs & labels || ------------------
  2. | ||
  3. {model} ||
  4. | ||
  5. outputs & labels ||
  6. | || tensor data
  7. {Metric.compute} ||
  8. | ||
  9. metric states(tensor) ||
  10. | ||
  11. {fetch as numpy} || ------------------
  12. | ||
  13. metric states(numpy) || numpy data
  14. | ||
  15. {Metric.update} / ------------------

代码示例

以 计算正确率的 Accuracy 为例,该评估器的输入为 pred 和 label, 可以在 compute 中通过 pred 和 label 先计算正确预测的矩阵。 例如,预测结果包含10类, pred 的shape是[N, 10],

label 的shape是[N, 1], N是batch size,我们需要计算top-1和top-5的准

确率,可以在 compute 中计算每个样本的top-5得分,正确预测的矩阵的shape 是[N, 5].

  1. def compute(pred, label):
  2. # sort prediction and slice the top-5 scores
  3. pred = paddle.argsort(pred, descending=True)[:, :5]
  4. # calculate whether the predictions are correct
  5. correct = pred == label
  6. return paddle.cast(correct, dtype='float32')

在 compute 中的计算,使用内置的算子(可以跑在GPU上,是的速度更快)。 作为 update 的输入,该接口计算如下:

  1. def update(self, correct):
  2. accs = []
  3. for i, k in enumerate(self.topk):
  4. num_corrects = correct[:, :k].sum()
  5. num_samples = len(correct)
  6. accs.append(float(num_corrects) / num_samples)
  7. self.total[i] += num_corrects
  8. self.count[i] += num_samples
  9. return accs

reset()

清空状态和计算结果。

返回:无

update(*args)

更新状态。如果定义了 compute , update 的输入是 compute 的输出。 如果没有定义,则输入是网络的输出output和标签label, 如: update(output1, output2, …, label1, label2,…) .

也可以参考 update 。

accumulate()

累积的统计指标,计算和返回评估结果。

返回:评估结果,一般是个标量 或 多个标量。

name()

返回Metric的名字, 一般通过init构造函数传入。

返回: 评估的名字,string类型。

compute()

此接口可以通过PaddlePaddle内置的算子计算metric的状态,可以加速metric的计算, 为可选的高阶接口。

如果这个接口定义了,输入是网络的输出 outputs 和 标签 labels , 定义如: compute(output1, output2, …, label1, label2,…) 。 如果这个接口没有定义, 默认的行为是直接将输入参数返回给 update ,则其 定义如: update(output1, output2, …, label1, label2,…) 。

也可以参考 compute 。

使用本API的教程文档