训练时验证模型

Linux Ascend GPU CPU 初级 中级 高级 模型导出 模型训练

训练时验证模型 - 图1 训练时验证模型 - 图2

概述

在面对复杂网络时,往往需要进行几十甚至几百次的epoch训练。在训练之前,很难掌握在训练到第几个epoch时,模型的精度能达到满足要求的程度,所以经常会采用一边训练的同时,在相隔固定epoch的位置对模型进行精度验证,并保存相应的模型,等训练完毕后,通过查看对应模型精度的变化就能迅速地挑选出相对最优的模型,本文将采用这种方法,以LeNet网络为样本,进行示例。

流程如下:

  1. 定义回调函数EvalCallBack,实现同步进行训练和验证。

  2. 定义训练网络并执行。

  3. 将不同epoch下的模型精度绘制出折线图并挑选最优模型。

完整示例请参考notebook

定义回调函数EvalCallBack

实现思想:每隔n个epoch验证一次模型精度,由于在自定义函数中实现,如需了解详细用法,请参考API说明

核心实现:回调函数的epoch_end内设置验证点,如下:

cur_epoch % eval_per_epoch == 0:即每eval_per_epoch个epoch结束时,验证一次模型精度。

  • cur_epoch:当前训练过程的epoch数值。

  • eval_per_epoch:用户自定义数值,即验证频次。

其他参数解释:

  • model:即是MindSpore中的Model函数。

  • eval_dataset:验证数据集。

  • epoch_per_eval:记录验证模型的精度和相应的epoch数,其数据形式为{"epoch": [], "acc": []}

  1. from mindspore.train.callback import Callback
  2. class EvalCallBack(Callback):
  3. def __init__(self, model, eval_dataset, eval_per_epoch, epoch_per_eval):
  4. self.model = model
  5. self.eval_dataset = eval_dataset
  6. self.eval_per_epoch = eval_per_epoch
  7. self.epoch_per_eval = epoch_per_eval
  8. def epoch_end(self, run_context):
  9. cb_param = run_context.original_args()
  10. cur_epoch = cb_param.cur_epoch_num
  11. if cur_epoch % self.eval_per_epoch == 0:
  12. acc = self.model.eval(self.eval_dataset, dataset_sink_mode=True)
  13. self.epoch_per_eval["epoch"].append(cur_epoch)
  14. self.epoch_per_eval["acc"].append(acc["Accuracy"])
  15. print(acc)

定义训练网络并执行

在保存模型的参数CheckpointConfig中,需计算好单个epoch中的step数,再根据需要进行验证模型精度的频次对应,本次示例为1875个step/epoch,按照每两个epoch验证一次的思想,这里设置save_checkpoint_steps=eval_per_epoch*1875,其中变量eval_per_epoch等于2。

参数解释:

  • config_ck:定义保存模型信息。

    • save_checkpoint_steps:每多少个step保存一次模型。

    • keep_checkpoint_max:设置保存模型数量的上限。

  • ckpoint_cb:定义模型保存的名称及路径信息。

  • model:定义模型。

  • model.train:模型训练函数。

  • epoch_per_eval:定义收集epoch数和对应模型精度信息的字典。

  1. from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
  2. from mindspore.train import Model
  3. from mindspore import context
  4. from mindspore.nn.metrics import Accuracy
  5. if __name__ == "__main__":
  6. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  7. ckpt_save_dir = "./lenet_ckpt"
  8. eval_per_epoch = 2
  9. ... ...
  10. # need to calculate how many steps are in each epoch,in this example, 1875 steps per epoch
  11. config_ck = CheckpointConfig(save_checkpoint_steps=eval_per_epoch*1875, keep_checkpoint_max=15)
  12. ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet",directory=ckpt_save_dir, config=config_ck)
  13. model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
  14. epoch_per_eval = {"epoch": [], "acc": []}
  15. eval_cb = EvalCallBack(model, eval_data, eval_per_epoch, epoch_per_eval)
  16. model.train(epoch_size, train_data, callbacks=[ckpoint_cb, LossMonitor(375), eval_cb],
  17. dataset_sink_mode=True)

输出结果:

  1. epoch: 1 step: 375, loss is 2.298612
  2. epoch: 1 step: 750, loss is 2.075152
  3. epoch: 1 step: 1125, loss is 0.39205977
  4. epoch: 1 step: 1500, loss is 0.12368304
  5. epoch: 1 step: 1875, loss is 0.20988345
  6. epoch: 2 step: 375, loss is 0.20582482
  7. epoch: 2 step: 750, loss is 0.029070046
  8. epoch: 2 step: 1125, loss is 0.041760832
  9. epoch: 2 step: 1500, loss is 0.067035824
  10. epoch: 2 step: 1875, loss is 0.0050643035
  11. {'Accuracy': 0.9763621794871795}
  12. ... ...
  13. epoch: 9 step: 375, loss is 0.021227183
  14. epoch: 9 step: 750, loss is 0.005586236
  15. epoch: 9 step: 1125, loss is 0.029125651
  16. epoch: 9 step: 1500, loss is 0.00045874066
  17. epoch: 9 step: 1875, loss is 0.023556218
  18. epoch: 10 step: 375, loss is 0.0005807788
  19. epoch: 10 step: 750, loss is 0.02574059
  20. epoch: 10 step: 1125, loss is 0.108463734
  21. epoch: 10 step: 1500, loss is 0.01950589
  22. epoch: 10 step: 1875, loss is 0.10563098
  23. {'Accuracy': 0.979667467948718}

在同一目录找到lenet_ckpt文件夹,文件夹中保存了5个模型,和一个计算图相关数据,其结构如下:

  1. lenet_ckpt
  2. ├── checkpoint_lenet-10_1875.ckpt
  3. ├── checkpoint_lenet-2_1875.ckpt
  4. ├── checkpoint_lenet-4_1875.ckpt
  5. ├── checkpoint_lenet-6_1875.ckpt
  6. ├── checkpoint_lenet-8_1875.ckpt
  7. └── checkpoint_lenet-graph.meta

定义函数绘制不同epoch下模型的精度

定义绘图函数eval_show,将epoch_per_eval载入到eval_show中,绘制出不同epoch下模型的验证精度折线图。

  1. import matplotlib.pyplot as plt
  2. def eval_show(epoch_per_eval):
  3. plt.xlabel("epoch number")
  4. plt.ylabel("Model accuracy")
  5. plt.title("Model accuracy variation chart")
  6. plt.plot(epoch_per_eval["epoch"], epoch_per_eval["acc"], "red")
  7. plt.show()
  8. eval_show(epoch_per_eval)

输出结果:

png

从上图可以一目了然地挑选出需要的最优模型。

总结

本次使用MNIST数据集通过卷积神经网络LeNet5进行训练,着重介绍了在进行模型训练的同时进行模型的验证,保存对应epoch的模型,并从中挑选出最优模型的方法。