set_gradient_clip

注意:该API仅支持【静态图】模式

  • paddle.fluid.clip.set_gradient_clip(clip, param_list=None, program=None)[源代码]

给指定参数做梯度裁剪。

  • 参数:
    • clip (BaseGradientClipAttr) - BaseGradientClipAttr子类的实例,如 GradientClipByGlobalNorm 等,用于描述具体的裁剪方法和属性。
    • param_list (list(Variable),可选) - 需要裁剪的参数列表,可以是参数或参数名称列表。默认值为None,表示裁剪 program 中的所有参数。
    • program (Program,可选) - 参数所在的Program。默认值为None,表示使用 default_main_program

返回: 无。

代码示例

  1. import paddle.fluid as fluid
  2.  
  3. def network():
  4. image = fluid.layers.data(name='image', shape=[28], dtype='float32')
  5. param_attr1 = fluid.ParamAttr("fc1_param")
  6. fc1 = fluid.layers.fc(image, size=10, param_attr=param_attr1)
  7. param_attr2 = fluid.ParamAttr("fc2_param")
  8. fc2 = fluid.layers.fc(fc1, size=10, param_attr=param_attr2)
  9. loss = fluid.layers.reduce_mean(fc2)
  10. return loss
  11.  
  12. # network 1: clip all parameter gradient
  13. with fluid.program_guard(fluid.Program(), fluid.Program()):
  14. loss = network()
  15. fluid.clip.set_gradient_clip(
  16. fluid.clip.GradientClipByGlobalNorm(clip_norm=2.0))
  17. sgd = fluid.optimizer.SGD(learning_rate=1e-3)
  18. sgd.minimize(loss)
  19.  
  20. # network 2: clip parameter gradient by name
  21. with fluid.program_guard(fluid.Program(), fluid.Program()):
  22. loss = network()
  23. fluid.clip.set_gradient_clip(
  24. fluid.clip.GradientClipByValue(min=-1.0, max=1.0),
  25. param_list=["fc1_param", "fc2_param"])
  26. sgd = fluid.optimizer.SGD(learning_rate=1e-3)
  27. sgd.minimize(loss)
  28.  
  29. # network 3: clip parameter gradient by var
  30. with fluid.program_guard(fluid.Program(), fluid.Program()):
  31. loss = network()
  32. param_var1 = fluid.default_main_program().global_block().var("fc1_param")
  33. param_var2 = fluid.default_main_program().global_block().var("fc2_param")
  34. fluid.clip.set_gradient_clip(
  35. fluid.clip.GradientClipByValue(min=-1.0, max=1.0),
  36. param_list=[param_var1, param_var2])
  37. sgd = fluid.optimizer.SGD(learning_rate=1e-3)
  38. sgd.minimize(loss)