加载模型用于推理或迁移学习

Linux Ascend GPU CPU 模型加载 初级 中级 高级

加载模型用于推理或迁移学习 - 图1

概述

在模型训练过程中保存在本地的CheckPoint文件,或从MindSpore Hub下载的CheckPoint文件,都可以帮助用户进行推理或迁移学习使用。

以下通过示例来介绍如何通过本地加载或Hub加载模型,用于推理验证和迁移学习。

本地加载模型

用于推理验证

针对仅推理场景可以使用load_checkpoint把参数直接加载到网络中,以便进行后续的推理验证。

示例代码如下:

  1. resnet = ResNet50()
  2. load_checkpoint("resnet50-2_32.ckpt", net=resnet)
  3. dateset_eval = create_dataset(os.path.join(mnist_path, "test"), 32, 1) # define the test dataset
  4. loss = CrossEntropyLoss()
  5. model = Model(resnet, loss, metrics={"accuracy"})
  6. acc = model.eval(dataset_eval)
  • load_checkpoint方法会把参数文件中的网络参数加载到模型中。加载后,网络中的参数就是CheckPoint保存的。

  • eval方法会验证训练后模型的精度。

用于迁移学习

针对任务中断再训练及微调(Fine Tune)场景,可以加载网络参数和优化器参数到模型中。

示例代码如下:

  1. # return a parameter dict for model
  2. param_dict = load_checkpoint("resnet50-2_32.ckpt")
  3. resnet = ResNet50()
  4. opt = Momentum()
  5. # load the parameter into net
  6. load_param_into_net(resnet, param_dict)
  7. # load the parameter into optimizer
  8. load_param_into_net(opt, param_dict)
  9. loss = SoftmaxCrossEntropyWithLogits()
  10. model = Model(resnet, loss, opt)
  11. model.train(epoch, dataset)
  • load_checkpoint方法会返回一个参数字典。

  • load_param_into_net会把参数字典中相应的参数加载到网络或优化器中。

从Hub加载模型

用于推理验证

mindspore_hub.load API用于加载预训练模型,可以实现一行代码完成模型的加载。主要的模型加载流程如下:

  1. MindSpore Hub官网上搜索感兴趣的模型。

    例如,想使用GoogleNet对CIFAR-10数据集进行分类,可以在MindSpore Hub官网上使用关键词GoogleNet进行搜索。页面将会返回与GoogleNet相关的所有模型。进入相关模型页面之后,获得详情页url

  2. 使用url完成模型的加载,示例代码如下:

    1. import mindspore_hub as mshub
    2. import mindspore
    3. from mindspore import context, Tensor, nn
    4. from mindspore.train.model import Model
    5. from mindspore.common import dtype as mstype
    6. import mindspore.dataset.vision.py_transforms as py_transforms
    7. context.set_context(mode=context.GRAPH_MODE,
    8. device_target="Ascend",
    9. device_id=0)
    10. model = "mindspore/ascend/0.7/googlenet_v1_cifar10"
    11. # Initialize the number of classes based on the pre-trained model.
    12. network = mshub.load(model, num_classes=10)
    13. network.set_train(False)
    14. # ...
  3. 完成模型加载后,可以使用MindSpore进行推理,参考这里

用于迁移学习

通过mindspore_hub.load完成模型加载后,可以增加一个额外的参数项只加载神经网络的特征提取部分,这样我们就能很容易地在之后增加一些新的层进行迁移学习。当模型开发者将额外的参数(例如 include_top)添加到模型构造中时,可以在模型的详情页中找到这个功能。include_top取值为True或者False,表示是否保留顶层的全连接网络。

下面我们以GoogleNet为例,说明如何加载一个基于ImageNet的预训练模型,并在特定的子任务数据集上进行迁移学习(重训练)。主要的步骤如下:

  1. MindSpore Hub官网上搜索感兴趣的模型,并从网站上获取特定的url

  2. 使用url进行MindSpore Hub模型的加载,注意:include_top参数需要模型开发者提供

    1. import mindspore
    2. from mindspore import nn, context, Tensor
    3. from mindspore.train.serialization import save_checkpoint
    4. from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
    5. import mindspore.ops as ops
    6. from mindspore.nn import Momentum
    7. import math
    8. import numpy as np
    9. import mindspore_hub as mshub
    10. from src.dataset import create_dataset
    11. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",
    12. save_graphs=False)
    13. model_url = "mindspore/ascend/0.7/googlenet_v1_cifar10"
    14. network = mshub.load(model_url, include_top=False, num_classes=1000)
    15. network.set_train(False)
  3. 在现有模型结构基础上,增加一个与新任务相关的分类层。

    1. class ReduceMeanFlatten(nn.Cell):
    2. def __init__(self):
    3. super(ReduceMeanFlatten, self).__init__()
    4. self.mean = ops.ReduceMean(keep_dims=True)
    5. self.flatten = nn.Flatten()
    6. def construct(self, x):
    7. x = self.mean(x, (2, 3))
    8. x = self.flatten(x)
    9. return x
    10. # Check MindSpore Hub website to conclude that the last output shape is 1024.
    11. last_channel = 1024
    12. # The number of classes in target task is 26.
    13. num_classes = 26
    14. reducemean_flatten = ReduceMeanFlatten()
    15. classification_layer = nn.Dense(last_channel, num_classes)
    16. classification_layer.set_train(True)
    17. train_network = nn.SequentialCell([network, reducemean_flatten, classification_layer])
  4. 为模型训练选择损失函数和优化器。

    1. epoch_size = 60
    2. # Wrap the backbone network with loss.
    3. loss_fn = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    4. loss_net = nn.WithLossCell(train_network, loss_fn)
    5. lr = get_lr(global_step=0,
    6. lr_init=0,
    7. lr_max=0.05,
    8. lr_end=0.001,
    9. warmup_epochs=5,
    10. total_epochs=epoch_size)
    11. # Create an optimizer.
    12. optim = Momentum(filter(lambda x: x.requires_grad, loss_net.get_parameters()), Tensor(lr), 0.9, 4e-5)
    13. train_net = nn.TrainOneStepCell(loss_net, optim)
  5. 构建数据集,开始重训练。

    如下所示,进行微调任务的数据集为垃圾分类数据集,存储位置为/ssd/data/garbage/train

    1. dataset = create_dataset("/ssd/data/garbage/train",
    2. do_train=True,
    3. batch_size=32,
    4. platform="Ascend",
    5. repeat_num=1)
    6. for epoch in range(epoch_size):
    7. for i, items in enumerate(dataset):
    8. data, label = items
    9. data = mindspore.Tensor(data)
    10. label = mindspore.Tensor(label)
    11. loss = train_net(data, label)
    12. print(f"epoch: {epoch}/{epoch_size}, loss: {loss}")
    13. # Save the ckpt file for each epoch.
    14. ckpt_path = f"./ckpt/garbage_finetune_epoch{epoch}.ckpt"
    15. save_checkpoint(train_network, ckpt_path)
  6. 在测试集上测试模型精度。

    1. from mindspore.train.serialization import load_checkpoint, load_param_into_net
    2. network = mshub.load('mindspore/ascend/0.7/googlenet_v1_cifar10', pretrained=False,
    3. include_top=False, num_classes=1000)
    4. reducemean_flatten = ReduceMeanFlatten()
    5. classification_layer = nn.Dense(last_channel, num_classes)
    6. classification_layer.set_train(False)
    7. softmax = nn.Softmax()
    8. network = nn.SequentialCell([network, reducemean_flatten,
    9. classification_layer, softmax])
    10. # Load a pre-trained ckpt file.
    11. ckpt_path = "./ckpt/garbage_finetune_epoch59.ckpt"
    12. trained_ckpt = load_checkpoint(ckpt_path)
    13. load_param_into_net(network, trained_ckpt)
    14. # Define loss and create model.
    15. model = Model(network, metrics={'acc'}, eval_network=network)
    16. eval_dataset = create_dataset("/ssd/data/garbage/test",
    17. do_train=True,
    18. batch_size=32,
    19. platform="Ascend",
    20. repeat_num=1)
    21. res = model.eval(eval_dataset)
    22. print("result:", res, "ckpt=", ckpt_path)