L2Decay

paddle.fluid.regularizer.L2Decay

L2Decay实现L2权重衰减正则化,用于模型训练,有助于防止模型对训练数据过拟合。

该类生成的实例对象,需要设置在 ParamAttr 或者 optimizer (例如 SGDOptimizer )中,在 ParamAttr 中设置时, 只对该网络层中的参数生效;在 optimizer 中设置时,会对所有的参数生效;如果同时设置, 在 ParamAttr 中设置的优先级会高于在 optimizer 中设置。

具体实现中,L2权重衰减正则化的计算公式如下:

L2Decay - 图1

参数

  • regularization_coeff (float) – 正则化系数,默认值为0.0。

代码示例 1

  1. import paddle.fluid as fluid
  2. main_prog = fluid.Program()
  3. startup_prog = fluid.Program()
  4. with fluid.program_guard(main_prog, startup_prog):
  5. data = fluid.layers.data(name='image', shape=[3, 28, 28], dtype='float32')
  6. label = fluid.layers.data(name='label', shape=[1], dtype='int64')
  7. hidden = fluid.layers.fc(input=data, size=128, act='relu')
  8. prediction = fluid.layers.fc(input=hidden, size=10, act='softmax')
  9. loss = fluid.layers.cross_entropy(input=prediction, label=label)
  10. avg_loss = fluid.layers.mean(loss)
  11. optimizer = fluid.optimizer.Adagrad(
  12. learning_rate=1e-4,
  13. regularization=fluid.regularizer.L2Decay(
  14. regularization_coeff=0.1))
  15. optimizer.minimize(avg_loss)

代码示例 2

  1. # 在 ParamAttr 和 optimizer 中同时设置正则化
  2. import paddle.fluid as fluid
  3. l1 = fluid.regularizer.L1Decay(regularization_coeff=0.1)
  4. l2 = fluid.regularizer.L2Decay(regularization_coeff=0.1)
  5. x = fluid.layers.uniform_random([3,4])
  6. # 在ParamAttr中设置L1正则化
  7. w_param = fluid.ParamAttr(regularizer=l1)
  8. hidden1 = fluid.layers.fc(x, 8, param_attr=w_param) # fc_0.w_0(L1), fc_0.b_0
  9. hidden2 = fluid.layers.fc(hidden1, 16, param_attr=w_param) # fc_1.w_0(L1), fc_1.b_0
  10. predict = fluid.layers.fc(hidden2, 32) # fc_3.w_0, fc_3.b_0
  11. avg_loss = fluid.layers.mean(predict)
  12. # 在optimizer中设置L2正则化
  13. optimizer = fluid.optimizer.SGD(learning_rate=1e-4, regularization=l2)
  14. optimizer.minimize(avg_loss)
  15. # 将会打印出提示信息:
  16. # Regularization of [fc_0.w_0, fc_1.w_0] have been set by ParamAttr or WeightNormParamAttr already.
  17. # So, the Regularization of Optimizer will not take effect for these parameters!