Transfer Learning Tutorial

Translator: 片刻

Proofreader: cluster

Author: Sasank Chilamkurthy

In this tutorial, you will learn how to train your network using transfer learning. You can read more about transfer learning in the cs231n notes.

Quoting those notes:

In practice, very few people train an entire convolutional network from scratch (with random initialization), because it is relatively rare to have a dataset of sufficient size. Instead, it is common to pretrain a ConvNet on a very large dataset (e.g. ImageNet, which contains 1.2 million images with 1000 categories), and then use that ConvNet either as an initialization or as a fixed feature extractor for the task of interest.

These are the two major transfer learning scenarios:

  • Finetuning the convnet: Instead of random initialization, we initialize the network with a pretrained network, such as one trained on the imagenet 1000 dataset. The rest of the training looks as usual. (This finetuning procedure corresponds to the "initialization" mentioned in the quote above.)
  • ConvNet as fixed feature extractor: Here, we freeze the weights of the entire network except the final fully connected layer. That last fully connected layer is replaced with a new one with random weights, and only this layer is trained. (This corresponds to the "fixed feature extractor" in the quote above.)

# License: BSD
# Author: Sasank Chilamkurthy

from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy

plt.ion()   # interactive mode

Load Data

We will use the torchvision and torch.utils.data packages to load the data.

The problem we are going to solve today is to train a model to classify ants and bees. We have about 120 training images each for ants and bees, and 75 validation images for each class. Usually, this is a very small dataset to generalize upon if training from scratch. Since we are using transfer learning, we should be able to generalize reasonably well.

This dataset is a very small subset of imagenet.

Note

Download the data from here and extract it to the current directory.
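
datasets.ImageFolder (used below) expects one sub-directory per class, so after extraction the directory tree should look roughly like this (the file names shown are just placeholders):

data/hymenoptera_data/
    train/
        ants/
            some_ant_1.jpg
            ...
        bees/
            ...
    val/
        ants/
            ...
        bees/
            ...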

# Data augmentation and normalization for training
# Just normalization for validation
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

data_dir = 'data/hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                              shuffle=True, num_workers=4)
               for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Visualize a few images

Let's visualize a few training images so as to understand the data augmentations.

def imshow(inp, title=None):
    """Imshow for Tensor."""
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated


# Get a batch of training data
inputs, classes = next(iter(dataloaders['train']))

# Make a grid from batch
out = torchvision.utils.make_grid(inputs)

imshow(out, title=[class_names[x] for x in classes])

Transfer Learning Tutorial - Figure 1

Training the model

Now, let's write a general function to train a model. Here, we will illustrate:

  • Scheduling the learning rate
  • Saving the best model

In the function below, the scheduler parameter is an LR scheduler object from torch.optim.lr_scheduler.
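
To make the scheduler's effect concrete, here is a minimal sketch (separate from the tutorial code; the toy model and epoch count are made up purely for illustration) showing how StepLR with step_size=7 and gamma=0.1 decays the learning rate:

# Minimal sketch (not part of the tutorial): how StepLR decays the learning rate
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler

toy_model = nn.Linear(10, 2)  # stand-in model, illustrative only
toy_optimizer = optim.SGD(toy_model.parameters(), lr=0.001, momentum=0.9)
toy_scheduler = lr_scheduler.StepLR(toy_optimizer, step_size=7, gamma=0.1)

for epoch in range(15):
    # ... one epoch of training (optimizer.step() calls) would go here ...
    print(epoch, toy_optimizer.param_groups[0]['lr'])  # 0.001 for epochs 0-6, 0.0001 for 7-13, ...
    toy_scheduler.step()  # multiplies the lr by 0.1 every 7 calls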

def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                scheduler.step()
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

Visualizing the model predictions

A generic function to display predictions for a few images.

def visualize_model(model, num_images=6):
    was_training = model.training
    model.eval()
    images_so_far = 0
    fig = plt.figure()

    with torch.no_grad():
        for i, (inputs, labels) in enumerate(dataloaders['val']):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            for j in range(inputs.size()[0]):
                images_so_far += 1
                ax = plt.subplot(num_images//2, 2, images_so_far)
                ax.axis('off')
                ax.set_title('predicted: {}'.format(class_names[preds[j]]))
                imshow(inputs.cpu().data[j])

                if images_so_far == num_images:
                    model.train(mode=was_training)
                    return
        model.train(mode=was_training)

Finetuning the convnet

Load a pretrained model and reset the final fully connected layer.

model_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, 2)

model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss()

# Observe that all parameters are being optimized
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

Train and evaluate

It should take around 15-25 minutes on CPU. On GPU, though, it takes less than a minute.

model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,
                       num_epochs=25)

Out:

Epoch 0/24
----------
train Loss: 0.6022 Acc: 0.6844
val Loss: 0.1765 Acc: 0.9412
Epoch 1/24
----------
train Loss: 0.4156 Acc: 0.8238
val Loss: 0.2380 Acc: 0.9216
Epoch 2/24
----------
train Loss: 0.5010 Acc: 0.7951
val Loss: 0.2571 Acc: 0.8954
Epoch 3/24
----------
train Loss: 0.7152 Acc: 0.7705
val Loss: 0.2060 Acc: 0.9346
Epoch 4/24
----------
train Loss: 0.5779 Acc: 0.8033
val Loss: 0.4542 Acc: 0.8889
Epoch 5/24
----------
train Loss: 0.5653 Acc: 0.7951
val Loss: 0.3167 Acc: 0.8824
Epoch 6/24
----------
train Loss: 0.4948 Acc: 0.8074
val Loss: 0.3238 Acc: 0.8758
Epoch 7/24
----------
train Loss: 0.3712 Acc: 0.8361
val Loss: 0.2284 Acc: 0.9020
Epoch 8/24
----------
train Loss: 0.2982 Acc: 0.8730
val Loss: 0.3488 Acc: 0.8497
Epoch 9/24
----------
train Loss: 0.2491 Acc: 0.8934
val Loss: 0.2405 Acc: 0.8889
Epoch 10/24
----------
train Loss: 0.3498 Acc: 0.8238
val Loss: 0.2435 Acc: 0.8889
Epoch 11/24
----------
train Loss: 0.3042 Acc: 0.8648
val Loss: 0.3021 Acc: 0.8627
Epoch 12/24
----------
train Loss: 0.2500 Acc: 0.8852
val Loss: 0.2340 Acc: 0.8954
Epoch 13/24
----------
train Loss: 0.3246 Acc: 0.8730
val Loss: 0.2236 Acc: 0.9020
Epoch 14/24
----------
train Loss: 0.2976 Acc: 0.8566
val Loss: 0.2928 Acc: 0.8562
Epoch 15/24
----------
train Loss: 0.2733 Acc: 0.8934
val Loss: 0.2370 Acc: 0.8954
Epoch 16/24
----------
train Loss: 0.3502 Acc: 0.8361
val Loss: 0.2792 Acc: 0.8824
Epoch 17/24
----------
train Loss: 0.2215 Acc: 0.8975
val Loss: 0.2790 Acc: 0.8497
Epoch 18/24
----------
train Loss: 0.3929 Acc: 0.8484
val Loss: 0.2648 Acc: 0.8824
Epoch 19/24
----------
train Loss: 0.3227 Acc: 0.8607
val Loss: 0.2643 Acc: 0.8693
Epoch 20/24
----------
train Loss: 0.3816 Acc: 0.8484
val Loss: 0.2395 Acc: 0.9085
Epoch 21/24
----------
train Loss: 0.2904 Acc: 0.8975
val Loss: 0.2399 Acc: 0.8889
Epoch 22/24
----------
train Loss: 0.3375 Acc: 0.8648
val Loss: 0.2380 Acc: 0.9020
Epoch 23/24
----------
train Loss: 0.2107 Acc: 0.9139
val Loss: 0.2251 Acc: 0.9085
Epoch 24/24
----------
train Loss: 0.3243 Acc: 0.8525
val Loss: 0.2545 Acc: 0.8824
Training complete in 1m 7s
Best val Acc: 0.941176

visualize_model(model_ft)

Transfer Learning Tutorial - Figure 2

ConvNet as fixed feature extractor

Here, we need to freeze all of the network except the final layer. We set requires_grad = False to freeze the parameters so that the gradients are not computed in backward().

You can read more about this in the documentation here.

model_conv = torchvision.models.resnet18(pretrained=True)
for param in model_conv.parameters():
    param.requires_grad = False

# Parameters of newly constructed modules have requires_grad=True by default
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 2)

model_conv = model_conv.to(device)

criterion = nn.CrossEntropyLoss()

# Observe that only parameters of final layer are being optimized as
# opposed to before.
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)
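
If you want to double-check that the freezing worked as intended, one way (a small sketch added here, not part of the original tutorial) is to list the parameters that still require gradients; for model_conv only the new fc layer's weight and bias should appear:

# Sketch: verify that only the new fc layer is trainable after freezing
trainable = [name for name, p in model_conv.named_parameters() if p.requires_grad]
print(trainable)  # expected: ['fc.weight', 'fc.bias']
print(sum(p.numel() for p in model_conv.parameters() if p.requires_grad),
      'trainable parameters out of',
      sum(p.numel() for p in model_conv.parameters()))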

Train and evaluate

On CPU this will take about half the time of the previous scenario. This is expected, since gradients don't need to be computed for most of the network. The forward pass, however, still has to be computed.

model_conv = train_model(model_conv, criterion, optimizer_conv,
                         exp_lr_scheduler, num_epochs=25)

Out:

Epoch 0/24
----------
train Loss: 0.5666 Acc: 0.6967
val Loss: 0.2794 Acc: 0.8824
Epoch 1/24
----------
train Loss: 0.5590 Acc: 0.7582
val Loss: 0.1473 Acc: 0.9477
Epoch 2/24
----------
train Loss: 0.4187 Acc: 0.8156
val Loss: 0.3534 Acc: 0.8693
Epoch 3/24
----------
train Loss: 0.5248 Acc: 0.7459
val Loss: 0.1848 Acc: 0.9477
Epoch 4/24
----------
train Loss: 0.4315 Acc: 0.8115
val Loss: 0.1640 Acc: 0.9477
Epoch 5/24
----------
train Loss: 0.3948 Acc: 0.8238
val Loss: 0.1609 Acc: 0.9542
Epoch 6/24
----------
train Loss: 0.3359 Acc: 0.8648
val Loss: 0.1734 Acc: 0.9608
Epoch 7/24
----------
train Loss: 0.3681 Acc: 0.8443
val Loss: 0.1715 Acc: 0.9477
Epoch 8/24
----------
train Loss: 0.4034 Acc: 0.8361
val Loss: 0.1602 Acc: 0.9477
Epoch 9/24
----------
train Loss: 0.2983 Acc: 0.8811
val Loss: 0.1561 Acc: 0.9542
Epoch 10/24
----------
train Loss: 0.4516 Acc: 0.7992
val Loss: 0.1660 Acc: 0.9477
Epoch 11/24
----------
train Loss: 0.3516 Acc: 0.8484
val Loss: 0.1551 Acc: 0.9542
Epoch 12/24
----------
train Loss: 0.3592 Acc: 0.8238
val Loss: 0.1525 Acc: 0.9477
Epoch 13/24
----------
train Loss: 0.2982 Acc: 0.8648
val Loss: 0.1772 Acc: 0.9542
Epoch 14/24
----------
train Loss: 0.3352 Acc: 0.8484
val Loss: 0.1583 Acc: 0.9542
Epoch 15/24
----------
train Loss: 0.2981 Acc: 0.8770
val Loss: 0.2133 Acc: 0.9412
Epoch 16/24
----------
train Loss: 0.2778 Acc: 0.8811
val Loss: 0.1934 Acc: 0.9542
Epoch 17/24
----------
train Loss: 0.3678 Acc: 0.8156
val Loss: 0.1846 Acc: 0.9477
Epoch 18/24
----------
train Loss: 0.3520 Acc: 0.8197
val Loss: 0.1577 Acc: 0.9542
Epoch 19/24
----------
train Loss: 0.3342 Acc: 0.8402
val Loss: 0.1734 Acc: 0.9542
Epoch 20/24
----------
train Loss: 0.3649 Acc: 0.8361
val Loss: 0.1554 Acc: 0.9412
Epoch 21/24
----------
train Loss: 0.2948 Acc: 0.8566
val Loss: 0.1878 Acc: 0.9542
Epoch 22/24
----------
train Loss: 0.3047 Acc: 0.8811
val Loss: 0.1760 Acc: 0.9477
Epoch 23/24
----------
train Loss: 0.3363 Acc: 0.8648
val Loss: 0.1660 Acc: 0.9542
Epoch 24/24
----------
train Loss: 0.2745 Acc: 0.8770
val Loss: 0.1853 Acc: 0.9542
Training complete in 0m 34s
Best val Acc: 0.960784

visualize_model(model_conv)

plt.ioff()
plt.show()

Transfer Learning Tutorial - Figure 3
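
To reuse either trained model later without retraining, a common pattern is to save its state dict and load it back into a freshly constructed network. The following is a minimal sketch (not part of the original tutorial; the file name is arbitrary):

# Sketch: persist and restore the trained weights (file name is arbitrary)
torch.save(model_conv.state_dict(), 'model_conv.pth')

model_reloaded = torchvision.models.resnet18(pretrained=False)
model_reloaded.fc = nn.Linear(model_reloaded.fc.in_features, 2)
model_reloaded.load_state_dict(torch.load('model_conv.pth'))
model_reloaded = model_reloaded.to(device).eval()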

Total running time of the script: (1 minutes 54.087 seconds)

Download Python source code: transfer_learning_tutorial.py
Download Jupyter notebook: transfer_learning_tutorial.ipynb

Gallery generated by Sphinx-Gallery