PyLayerContext

class paddle.autograd.PyLayerContext

A PyLayerContext object helps a PyLayer implement certain features, such as passing Tensors from forward to backward.

Example code

    import paddle
    from paddle.autograd import PyLayer

    class cus_tanh(PyLayer):
        @staticmethod
        def forward(ctx, x):
            # ctx is an object of PyLayerContext.
            y = paddle.tanh(x)
            ctx.save_for_backward(y)
            return y

        @staticmethod
        def backward(ctx, dy):
            # ctx is an object of PyLayerContext.
            y, = ctx.saved_tensor()
            # d/dx tanh(x) = 1 - tanh(x)^2
            grad = dy * (1 - paddle.square(y))
            return grad

save_for_backward(self, *tensors)

Saves the Tensors that backward needs. Call saved_tensor in backward to retrieve them.

Note

This API can be called at most once, and only in forward.

Parameters

  • tensors (list of Tensor) - The Tensors to save.

Returns: None

Example code

    import paddle
    from paddle.autograd import PyLayer

    class cus_tanh(PyLayer):
        @staticmethod
        def forward(ctx, x):
            # ctx is a context object that stores some objects for backward.
            y = paddle.tanh(x)
            # Pass tensors to backward.
            ctx.save_for_backward(y)
            return y

        @staticmethod
        def backward(ctx, dy):
            # Get the tensors passed by forward.
            y, = ctx.saved_tensor()
            grad = dy * (1 - paddle.square(y))
            return grad

saved_tensor(self)

Retrieves the Tensors saved by save_for_backward.

Returns: If save_for_backward was called to save some Tensors, returns those Tensors; otherwise, returns None.

Example code

    import paddle
    from paddle.autograd import PyLayer

    class cus_tanh(PyLayer):
        @staticmethod
        def forward(ctx, x):
            # ctx is a context object that stores some objects for backward.
            y = paddle.tanh(x)
            # Pass tensors to backward.
            ctx.save_for_backward(y)
            return y

        @staticmethod
        def backward(ctx, dy):
            # Get the tensors passed by forward.
            y, = ctx.saved_tensor()
            grad = dy * (1 - paddle.square(y))
            return grad