LeNet在眼疾识别数据集iChallenge-PM上的应用

iChallenge-PM是百度大脑和中山大学中山眼科中心联合举办的iChallenge比赛中,提供的关于病理性近视(Pathologic Myopia,PM)的医疗类数据集,包含1200个受试者的眼底视网膜图片,训练、验证和测试数据集各400张。下面我们详细介绍LeNet在iChallenge-PM上的训练过程。


说明:

如今近视已经成为困扰人们健康的一项全球性负担,在近视人群中,有超过35%的人患有重度近视。近似将会导致眼睛的光轴被拉长,有可能引起视网膜或者络网膜的病变。随着近似度数的不断加深,高度近似有可能引发病理性病变,这将会导致以下几种症状:视网膜或者络网膜发生退化、视盘区域萎缩、漆裂样纹损害、Fuchs斑等。因此,及早发现近似患者眼睛的病变并采取治疗,显得非常重要。

数据可以从AIStudio下载


数据集准备

/home/aistudio/data/data19065 目录包括如下三个文件,解压缩后存放在/home/aistudio/work/palm目录下。

  • training.zip:包含训练中的图片和标签
  • validation.zip:包含验证集的图片
  • valid_gt.zip:包含验证集的标签

注意

valid_gt.zip文件解压缩之后,需要将/home/aistudio/work/palm/PALM-Validation-GT/目录下的PM_Label_and_Fovea_Location.xlsx文件转存成csv格式,本节代码示例中已经提前转成文件labels.csv。


  1. # 初次运行时将注释取消,以便解压文件
  2. # 如果已经解压过了,则不需要运行此段代码,否则文件已经存在解压会报错
  3. !unzip -o -q -d /home/aistudio/work/palm /home/aistudio/data/data19065/training.zip
  4. %cd /home/aistudio/work/palm/PALM-Training400/
  5. !unzip -o -q PALM-Training400.zip
  6. !unzip -o -q -d /home/aistudio/work/palm /home/aistudio/data/data19065/validation.zip
  7. !unzip -o -q -d /home/aistudio/work/palm /home/aistudio/data/data19065/valid_gt.zip
  1. /home/aistudio/work/palm/PALM-Training400

查看数据集图片

iChallenge-PM中既有病理性近视患者的眼底图片,也有非病理性近视患者的图片,命名规则如下:

  • 病理性近视(PM):文件名以P开头

  • 非病理性近视(non-PM):

    • 高度近似(high myopia):文件名以H开头

    • 正常眼睛(normal):文件名以N开头

我们将病理性患者的图片作为正样本,标签为1; 非病理性患者的图片作为负样本,标签为0。从数据集中选取两张图片,通过LeNet提取特征,构建分类器,对正负样本进行分类,并将图片显示出来。代码如下所示:

  1. import os
  2. import numpy as np
  3. import matplotlib.pyplot as plt
  4. %matplotlib inline
  5. from PIL import Image
  6. DATADIR = '/home/aistudio/work/palm/PALM-Training400/PALM-Training400'
  7. # 文件名以N开头的是正常眼底图片,以P开头的是病变眼底图片
  8. file1 = 'N0012.jpg'
  9. file2 = 'P0095.jpg'
  10. # 读取图片
  11. img1 = Image.open(os.path.join(DATADIR, file1))
  12. img1 = np.array(img1)
  13. img2 = Image.open(os.path.join(DATADIR, file2))
  14. img2 = np.array(img2)
  15. # 画出读取的图片
  16. plt.figure(figsize=(16, 8))
  17. f = plt.subplot(121)
  18. f.set_title('Normal', fontsize=20)
  19. plt.imshow(img1)
  20. f = plt.subplot(122)
  21. f.set_title('PM', fontsize=20)
  22. plt.imshow(img2)
  23. plt.show()
  1. 2020-03-25 19:44:41,518-INFO: font search path ['/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/mpl-data/fonts/ttf', '/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/mpl-data/fonts/afm', '/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/mpl-data/fonts/pdfcorefonts']
  2. 2020-03-25 19:44:41,916-INFO: generated new fontManager

LeNet在眼疾识别数据集iChallenge-PM上的应用 - 图1

  1. <Figure size 1152x576 with 2 Axes>
  1. # 查看图片形状
  2. img1.shape, img2.shape
  1. ((2056, 2124, 3), (2056, 2124, 3))

定义数据读取器

使用OpenCV从磁盘读入图片,将每张图缩放到

LeNet在眼疾识别数据集iChallenge-PM上的应用 - 图2 大小,并且将像素值调整到 LeNet在眼疾识别数据集iChallenge-PM上的应用 - 图3 之间,代码如下所示:

  1. import cv2
  2. import random
  3. import numpy as np
  4. # 对读入的图像数据进行预处理
  5. def transform_img(img):
  6. # 将图片尺寸缩放道 224x224
  7. img = cv2.resize(img, (224, 224))
  8. # 读入的图像数据格式是[H, W, C]
  9. # 使用转置操作将其变成[C, H, W]
  10. img = np.transpose(img, (2,0,1))
  11. img = img.astype('float32')
  12. # 将数据范围调整到[-1.0, 1.0]之间
  13. img = img / 255.
  14. img = img * 2.0 - 1.0
  15. return img
  16. # 定义训练集数据读取器
  17. def data_loader(datadir, batch_size=10, mode = 'train'):
  18. # 将datadir目录下的文件列出来,每条文件都要读入
  19. filenames = os.listdir(datadir)
  20. def reader():
  21. if mode == 'train':
  22. # 训练时随机打乱数据顺序
  23. random.shuffle(filenames)
  24. batch_imgs = []
  25. batch_labels = []
  26. for name in filenames:
  27. filepath = os.path.join(datadir, name)
  28. img = cv2.imread(filepath)
  29. img = transform_img(img)
  30. if name[0] == 'H' or name[0] == 'N':
  31. # H开头的文件名表示高度近似,N开头的文件名表示正常视力
  32. # 高度近视和正常视力的样本,都不是病理性的,属于负样本,标签为0
  33. label = 0
  34. elif name[0] == 'P':
  35. # P开头的是病理性近视,属于正样本,标签为1
  36. label = 1
  37. else:
  38. raise('Not excepted file name')
  39. # 每读取一个样本的数据,就将其放入数据列表中
  40. batch_imgs.append(img)
  41. batch_labels.append(label)
  42. if len(batch_imgs) == batch_size:
  43. # 当数据列表的长度等于batch_size的时候,
  44. # 把这些数据当作一个mini-batch,并作为数据生成器的一个输出
  45. imgs_array = np.array(batch_imgs).astype('float32')
  46. labels_array = np.array(batch_labels).astype('float32').reshape(-1, 1)
  47. yield imgs_array, labels_array
  48. batch_imgs = []
  49. batch_labels = []
  50. if len(batch_imgs) > 0:
  51. # 剩余样本数目不足一个batch_size的数据,一起打包成一个mini-batch
  52. imgs_array = np.array(batch_imgs).astype('float32')
  53. labels_array = np.array(batch_labels).astype('float32').reshape(-1, 1)
  54. yield imgs_array, labels_array
  55. return reader
  56. # 定义验证集数据读取器
  57. def valid_data_loader(datadir, csvfile, batch_size=10, mode='valid'):
  58. # 训练集读取时通过文件名来确定样本标签,验证集则通过csvfile来读取每个图片对应的标签
  59. # 请查看解压后的验证集标签数据,观察csvfile文件里面所包含的内容
  60. # csvfile文件所包含的内容格式如下,每一行代表一个样本,
  61. # 其中第一列是图片id,第二列是文件名,第三列是图片标签,
  62. # 第四列和第五列是Fovea的坐标,与分类任务无关
  63. # ID,imgName,Label,Fovea_X,Fovea_Y
  64. # 1,V0001.jpg,0,1157.74,1019.87
  65. # 2,V0002.jpg,1,1285.82,1080.47
  66. # 打开包含验证集标签的csvfile,并读入其中的内容
  67. filelists = open(csvfile).readlines()
  68. def reader():
  69. batch_imgs = []
  70. batch_labels = []
  71. for line in filelists[1:]:
  72. line = line.strip().split(',')
  73. name = line[1]
  74. label = int(line[2])
  75. # 根据图片文件名加载图片,并对图像数据作预处理
  76. filepath = os.path.join(datadir, name)
  77. img = cv2.imread(filepath)
  78. img = transform_img(img)
  79. # 每读取一个样本的数据,就将其放入数据列表中
  80. batch_imgs.append(img)
  81. batch_labels.append(label)
  82. if len(batch_imgs) == batch_size:
  83. # 当数据列表的长度等于batch_size的时候,
  84. # 把这些数据当作一个mini-batch,并作为数据生成器的一个输出
  85. imgs_array = np.array(batch_imgs).astype('float32')
  86. labels_array = np.array(batch_labels).astype('float32').reshape(-1, 1)
  87. yield imgs_array, labels_array
  88. batch_imgs = []
  89. batch_labels = []
  90. if len(batch_imgs) > 0:
  91. # 剩余样本数目不足一个batch_size的数据,一起打包成一个mini-batch
  92. imgs_array = np.array(batch_imgs).astype('float32')
  93. labels_array = np.array(batch_labels).astype('float32').reshape(-1, 1)
  94. yield imgs_array, labels_array
  95. return reader
  1. # 查看数据形状
  2. DATADIR = '/home/aistudio/work/palm/PALM-Training400/PALM-Training400'
  3. train_loader = data_loader(DATADIR,
  4. batch_size=10, mode='train')
  5. data_reader = train_loader()
  6. data = next(data_reader)
  7. data[0].shape, data[1].shape
  1. ((10, 3, 224, 224), (10, 1))

启动训练

  1. # -*- coding: utf-8 -*-
  2. # LeNet 识别眼疾图片
  3. import os
  4. import random
  5. import paddle
  6. import paddle.fluid as fluid
  7. import numpy as np
  8. DATADIR = '/home/aistudio/work/palm/PALM-Training400/PALM-Training400'
  9. DATADIR2 = '/home/aistudio/work/palm/PALM-Validation400'
  10. CSVFILE = '/home/aistudio/work/palm/PALM-Validation-GT/labels.csv'
  11. # 定义训练过程
  12. def train(model):
  13. with fluid.dygraph.guard():
  14. print('start training ... ')
  15. model.train()
  16. epoch_num = 5
  17. # 定义优化器
  18. opt = fluid.optimizer.Momentum(learning_rate=0.001, momentum=0.9, parameter_list=model.parameters())
  19. # 定义数据读取器,训练数据读取器和验证数据读取器
  20. train_loader = data_loader(DATADIR, batch_size=10, mode='train')
  21. valid_loader = valid_data_loader(DATADIR2, CSVFILE)
  22. for epoch in range(epoch_num):
  23. for batch_id, data in enumerate(train_loader()):
  24. x_data, y_data = data
  25. img = fluid.dygraph.to_variable(x_data)
  26. label = fluid.dygraph.to_variable(y_data)
  27. # 运行模型前向计算,得到预测值
  28. logits = model(img)
  29. # 进行loss计算
  30. loss = fluid.layers.sigmoid_cross_entropy_with_logits(logits, label)
  31. avg_loss = fluid.layers.mean(loss)
  32. if batch_id % 10 == 0:
  33. print("epoch: {}, batch_id: {}, loss is: {}".format(epoch, batch_id, avg_loss.numpy()))
  34. # 反向传播,更新权重,清除梯度
  35. avg_loss.backward()
  36. opt.minimize(avg_loss)
  37. model.clear_gradients()
  38. model.eval()
  39. accuracies = []
  40. losses = []
  41. for batch_id, data in enumerate(valid_loader()):
  42. x_data, y_data = data
  43. img = fluid.dygraph.to_variable(x_data)
  44. label = fluid.dygraph.to_variable(y_data)
  45. # 运行模型前向计算,得到预测值
  46. logits = model(img)
  47. # 二分类,sigmoid计算后的结果以0.5为阈值分两个类别
  48. # 计算sigmoid后的预测概率,进行loss计算
  49. pred = fluid.layers.sigmoid(logits)
  50. loss = fluid.layers.sigmoid_cross_entropy_with_logits(logits, label)
  51. # 计算预测概率小于0.5的类别
  52. pred2 = pred * (-1.0) + 1.0
  53. # 得到两个类别的预测概率,并沿第一个维度级联
  54. pred = fluid.layers.concat([pred2, pred], axis=1)
  55. acc = fluid.layers.accuracy(pred, fluid.layers.cast(label, dtype='int64'))
  56. accuracies.append(acc.numpy())
  57. losses.append(loss.numpy())
  58. print("[validation] accuracy/loss: {}/{}".format(np.mean(accuracies), np.mean(losses)))
  59. model.train()
  60. # save params of model
  61. fluid.save_dygraph(model.state_dict(), 'mnist')
  62. # save optimizer state
  63. fluid.save_dygraph(opt.state_dict(), 'mnist')
  64. # 定义评估过程
  65. def evaluation(model, params_file_path):
  66. with fluid.dygraph.guard():
  67. print('start evaluation .......')
  68. #加载模型参数
  69. model_state_dict, _ = fluid.load_dygraph(params_file_path)
  70. model.load_dict(model_state_dict)
  71. model.eval()
  72. eval_loader = load_data('eval')
  73. acc_set = []
  74. avg_loss_set = []
  75. for batch_id, data in enumerate(eval_loader()):
  76. x_data, y_data = data
  77. img = fluid.dygraph.to_variable(x_data)
  78. label = fluid.dygraph.to_variable(y_data)
  79. # 计算预测和精度
  80. prediction, acc = model(img, label)
  81. # 计算损失函数值
  82. loss = fluid.layers.cross_entropy(input=prediction, label=label)
  83. avg_loss = fluid.layers.mean(loss)
  84. acc_set.append(float(acc.numpy()))
  85. avg_loss_set.append(float(avg_loss.numpy()))
  86. # 求平均精度
  87. acc_val_mean = np.array(acc_set).mean()
  88. avg_loss_val_mean = np.array(avg_loss_set).mean()
  89. print('loss={}, acc={}'.format(avg_loss_val_mean, acc_val_mean))
  90. # 导入需要的包
  91. import paddle
  92. import paddle.fluid as fluid
  93. import numpy as np
  94. from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear
  95. # 定义 LeNet 网络结构
  96. class LeNet(fluid.dygraph.Layer):
  97. def __init__(self, name_scope, num_classes=1):
  98. super(LeNet, self).__init__(name_scope)
  99. # 创建卷积和池化层块,每个卷积层使用Sigmoid激活函数,后面跟着一个2x2的池化
  100. self.conv1 = Conv2D(num_channels=3, num_filters=6, filter_size=5, act='sigmoid')
  101. self.pool1 = Pool2D(pool_size=2, pool_stride=2, pool_type='max')
  102. self.conv2 = Conv2D(num_channels=6, num_filters=16, filter_size=5, act='sigmoid')
  103. self.pool2 = Pool2D(pool_size=2, pool_stride=2, pool_type='max')
  104. # 创建第3个卷积层
  105. self.conv3 = Conv2D(num_channels=16, num_filters=120, filter_size=4, act='sigmoid')
  106. # 创建全连接层,第一个全连接层的输出神经元个数为64, 第二个全连接层输出神经元个数为分裂标签的类别数
  107. self.fc1 = Linear(input_dim=300000, output_dim=64, act='sigmoid')
  108. self.fc2 = Linear(input_dim=64, output_dim=num_classes)
  109. # 网络的前向计算过程
  110. def forward(self, x):
  111. x = self.conv1(x)
  112. x = self.pool1(x)
  113. x = self.conv2(x)
  114. x = self.pool2(x)
  115. x = self.conv3(x)
  116. x = fluid.layers.reshape(x, [x.shape[0], -1])
  117. x = self.fc1(x)
  118. x = self.fc2(x)
  119. return x
  120. if __name__ == '__main__':
  121. # 创建模型
  122. with fluid.dygraph.guard():
  123. model = LeNet("LeNet_", num_classes=1)
  124. train(model)

通过运行结果可以看出,在眼疾筛查数据集iChallenge-PM上,LeNet的loss很难下降,模型没有收敛。这是因为MNIST数据集的图片尺寸比较小(

LeNet在眼疾识别数据集iChallenge-PM上的应用 - 图4 ),但是眼疾筛查数据集图片尺寸比较大(原始图片尺寸约为 LeNet在眼疾识别数据集iChallenge-PM上的应用 - 图5 ,经过缩放之后变成 LeNet在眼疾识别数据集iChallenge-PM上的应用 - 图6 ),LeNet模型很难进行有效分类。这说明在图片尺寸比较大时,LeNet在图像分类任务上存在局限性。