ModelCheckpoint

class paddle.callbacks. ModelCheckpoint ( save_freq=1, save_dir=None ) [源代码]

ModelCheckpoint 回调类和model.fit联合使用,在训练阶段,保存模型权重和优化器状态信息。当前仅支持在固定的epoch间隔保存模型,不支持按照batch的间隔保存。

子方法可以参考基类。

参数:

  • save_freq (int,可选) - 间隔多少个epoch保存模型。默认值:1。

  • save_dir (int,可选) - 保存模型的文件夹。如果不设定,将不会保存模型。默认值:None。

代码示例

  1. import paddle
  2. import paddle.vision.transforms as T
  3. from paddle.static import InputSpec
  4. inputs = [InputSpec([-1, 1, 28, 28], 'float32', 'image')]
  5. labels = [InputSpec([None, 1], 'int64', 'label')]
  6. transform = T.Compose([
  7. T.Transpose(),
  8. T.Normalize([127.5], [127.5])
  9. ])
  10. train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
  11. lenet = paddle.vision.LeNet()
  12. model = paddle.Model(lenet,
  13. inputs, labels)
  14. optim = paddle.optimizer.Adam(0.001, parameters=lenet.parameters())
  15. model.prepare(optimizer=optim,
  16. loss=paddle.nn.CrossEntropyLoss(),
  17. metrics=paddle.metric.Accuracy())
  18. callback = paddle.callbacks.ModelCheckpoint(save_dir='./temp')
  19. model.fit(train_dataset, batch_size=64, callbacks=callback)