Point Cloud Processing: Implementing PointNet for Point Cloud Classification

Author: Zhihao Cao
Date: 2021.05
Abstract: This example demonstrates how to implement PointNet with Paddle 2.1 and perform point cloud classification on the ShapeNet dataset.

1. Environment Setup

This tutorial is written for Paddle 2.1. If your environment is a different version, please follow the official installation guide to install Paddle 2.1 first.

    import os
    import numpy as np
    import random
    import h5py
    import paddle
    import paddle.nn as nn
    import paddle.nn.functional as F
    print(paddle.__version__)

    2.1.0

2. Dataset

2.1 Dataset Introduction

ShapeNet is a richly annotated, large-scale 3D shape dataset released jointly by Stanford University, Princeton University, and the Toyota Technological Institute at Chicago in 2015.
Official ShapeNet link: https://vision.princeton.edu/projects/2014/3DShapeNets/
The dataset is stored as HDF5 (.h5) files, each containing the following keys:

  • data: the xyz coordinates of every point in the sample,

  • label: the category the sample belongs to, e.g. airplane,

  • pid: the part type of each point in the sample; for example, if the sample belongs to the airplane class, its points carry part labels such as wing and fuselage.
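To make those keys concrete, here is a small NumPy mock of one file's contents. The shapes mirror this dump (2048 points per cloud, one class label per cloud, one part id per point), but the arrays themselves are illustrative stand-ins, not real ShapeNet data:

```python
import numpy as np

# Illustrative mock of one HDF5 file's arrays (random stand-ins, not real data)
num_samples, num_point = 4, 2048
data = np.random.rand(num_samples, num_point, 3).astype('float32')        # xyz per point
label = np.random.randint(0, 16, (num_samples, 1)).astype('int64')        # class id per cloud
pid = np.random.randint(0, 50, (num_samples, num_point)).astype('int64')  # part id per point

print(data.shape, label.shape, pid.shape)  # (4, 2048, 3) (4, 1) (4, 2048)
```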

2.2 Extract the Dataset

    !unzip data/data70460/shapenet_part_seg_hdf5_data.zip
    !mv hdf5_data dataset

    Archive:  data/data70460/shapenet_part_seg_hdf5_data.zip
       creating: hdf5_data/
      inflating: hdf5_data/ply_data_train5.h5
      inflating: hdf5_data/ply_data_train1.h5
      inflating: hdf5_data/ply_data_train3.h5
      inflating: hdf5_data/ply_data_val0.h5
      inflating: hdf5_data/ply_data_train0.h5
      inflating: hdf5_data/ply_data_test1.h5
      inflating: hdf5_data/ply_data_test0.h5
      inflating: hdf5_data/ply_data_train4.h5
      inflating: hdf5_data/ply_data_train2.h5

2.3 Data Lists

The lists below name all data files in the ShapeNet dataset.

    train_list = ['ply_data_train0.h5', 'ply_data_train1.h5', 'ply_data_train2.h5', 'ply_data_train3.h5', 'ply_data_train4.h5', 'ply_data_train5.h5']
    test_list = ['ply_data_test0.h5', 'ply_data_test1.h5']
    val_list = ['ply_data_val0.h5']

2.4 Build the Data Generator

Note: read the entire ShapeNet dataset into memory.

    def make_data(mode='train', path='./dataset/', num_point=2048):
        datas = []
        labels = []
        if mode == 'train':
            file_list = train_list
        elif mode == 'test':
            file_list = test_list
        else:
            file_list = val_list
        for file_name in file_list:
            f = h5py.File(os.path.join(path, file_name), 'r')
            datas.extend(f['data'][:, :num_point, :])
            labels.extend(f['label'])
            f.close()
        return datas, labels

Note: construct the dataset by subclassing paddle.io.Dataset.

    class PointDataset(paddle.io.Dataset):
        def __init__(self, datas, labels):
            super(PointDataset, self).__init__()
            self.datas = datas
            self.labels = labels

        def __getitem__(self, index):
            data = paddle.to_tensor(self.datas[index].T.astype('float32'))
            label = paddle.to_tensor(self.labels[index].astype('int64'))
            return data, label

        def __len__(self):
            return len(self.datas)

Note: use the Paddle API paddle.io.DataLoader to load the data and yield mini-batches of the configured batch size.

    # Load the data
    datas, labels = make_data(mode='train', num_point=2048)
    train_dataset = PointDataset(datas, labels)
    datas, labels = make_data(mode='val', num_point=2048)
    val_dataset = PointDataset(datas, labels)
    datas, labels = make_data(mode='test', num_point=2048)
    test_dataset = PointDataset(datas, labels)
    # Instantiate the data loaders
    train_loader = paddle.io.DataLoader(
        train_dataset,
        batch_size=128,
        shuffle=True,
        drop_last=False
    )
    val_loader = paddle.io.DataLoader(
        val_dataset,
        batch_size=32,
        shuffle=False,
        drop_last=False
    )
    test_loader = paddle.io.DataLoader(
        test_dataset,
        batch_size=128,
        shuffle=False,
        drop_last=False
    )

3. Network Definition

PointNet is a point cloud network proposed by researchers at Stanford University. The paper introduces a spatial transformer network (T-Net) to handle point cloud rotation (an object's point cloud is still the same object after rotation, so a sub-network is needed to learn and normalize that rotation), and uses max pooling to extract a global feature of the point cloud to great effect.
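That max-pooling property is easy to check in isolation: a feature-wise maximum over the point axis is a symmetric function, so reordering the points leaves the global feature unchanged. A small NumPy sketch, independent of the model code:

```python
import numpy as np

rng = np.random.default_rng(0)
features = rng.random((64, 128))   # 64 points, 128 features per point
perm = rng.permutation(64)         # an arbitrary reordering of the points

global_feat = features.max(axis=0)            # max over the point axis
global_feat_perm = features[perm].max(axis=0)

# The two global features are identical: max over points ignores point order.
print(np.allclose(global_feat, global_feat_perm))  # True
```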

3.1 Define the Network Structure

    class PointNet(nn.Layer):
        def __init__(self, name_scope='PointNet_', num_classes=16, num_point=2048):
            super(PointNet, self).__init__()
            self.input_transform_net = nn.Sequential(
                nn.Conv1D(3, 64, 1),
                nn.BatchNorm(64),
                nn.ReLU(),
                nn.Conv1D(64, 128, 1),
                nn.BatchNorm(128),
                nn.ReLU(),
                nn.Conv1D(128, 1024, 1),
                nn.BatchNorm(1024),
                nn.ReLU(),
                nn.MaxPool1D(num_point)
            )
            self.input_fc = nn.Sequential(
                nn.Linear(1024, 512),
                nn.ReLU(),
                nn.Linear(512, 256),
                nn.ReLU(),
                nn.Linear(
                    256, 9,
                    # initialize to the identity transform: zero weights, identity bias
                    weight_attr=paddle.ParamAttr(initializer=paddle.nn.initializer.Assign(paddle.zeros((256, 9)))),
                    bias_attr=paddle.ParamAttr(initializer=paddle.nn.initializer.Assign(paddle.reshape(paddle.eye(3), [-1])))
                )
            )
            self.mlp_1 = nn.Sequential(
                nn.Conv1D(3, 64, 1),
                nn.BatchNorm(64),
                nn.ReLU(),
                nn.Conv1D(64, 64, 1),
                nn.BatchNorm(64),
                nn.ReLU()
            )
            self.feature_transform_net = nn.Sequential(
                nn.Conv1D(64, 64, 1),
                nn.BatchNorm(64),
                nn.ReLU(),
                nn.Conv1D(64, 128, 1),
                nn.BatchNorm(128),
                nn.ReLU(),
                nn.Conv1D(128, 1024, 1),
                nn.BatchNorm(1024),
                nn.ReLU(),
                nn.MaxPool1D(num_point)
            )
            self.feature_fc = nn.Sequential(
                nn.Linear(1024, 512),
                nn.ReLU(),
                nn.Linear(512, 256),
                nn.ReLU(),
                nn.Linear(256, 64 * 64)
            )
            self.mlp_2 = nn.Sequential(
                nn.Conv1D(64, 64, 1),
                nn.BatchNorm(64),
                nn.ReLU(),
                nn.Conv1D(64, 128, 1),
                nn.BatchNorm(128),
                nn.ReLU(),
                nn.Conv1D(128, 1024, 1),
                nn.BatchNorm(1024),
                nn.ReLU()
            )
            self.fc = nn.Sequential(
                nn.Linear(1024, 512),
                nn.ReLU(),
                nn.Linear(512, 256),
                nn.ReLU(),
                nn.Dropout(p=0.7),
                nn.Linear(256, num_classes),
                nn.LogSoftmax(axis=-1)
            )

        def forward(self, inputs):
            batchsize = inputs.shape[0]
            # predict a 3x3 input transform and apply it to the raw points
            t_net = self.input_transform_net(inputs)
            t_net = paddle.squeeze(t_net, axis=-1)
            t_net = self.input_fc(t_net)
            t_net = paddle.reshape(t_net, [batchsize, 3, 3])
            x = paddle.transpose(inputs, (0, 2, 1))
            x = paddle.matmul(x, t_net)
            x = paddle.transpose(x, (0, 2, 1))
            x = self.mlp_1(x)
            # predict a 64x64 feature transform and apply it to the point features
            t_net = self.feature_transform_net(x)
            t_net = paddle.squeeze(t_net, axis=-1)
            t_net = self.feature_fc(t_net)
            t_net = paddle.reshape(t_net, [batchsize, 64, 64])
            x = paddle.transpose(x, (0, 2, 1))
            x = paddle.matmul(x, t_net)
            x = paddle.transpose(x, (0, 2, 1))
            x = self.mlp_2(x)
            # symmetric max pooling over the point axis yields the global feature
            x = paddle.max(x, axis=-1)
            x = self.fc(x)
            return x
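One detail worth calling out in input_fc: the final Linear(256, 9) starts with zero weights and a bias equal to the flattened 3x3 identity, so at initialization the predicted transform is exactly the identity matrix and the input points pass through unchanged. A NumPy sketch of that initialization trick (illustrative, outside Paddle):

```python
import numpy as np

# Final T-Net layer at initialization: zero weight matrix, identity bias
W = np.zeros((256, 9))
b = np.eye(3).reshape(-1)

h = np.random.rand(4, 256)          # mock activations from the previous layer
t = (h @ W + b).reshape(4, 3, 3)    # predicted per-sample 3x3 transforms

points = np.random.rand(4, 2048, 3)
transformed = points @ t            # apply each sample's transform to its cloud

# At init every predicted transform is the identity, so points are unchanged.
print(np.allclose(transformed, points))  # True
```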

3.2 Network Structure Visualization

Note: use the Paddle API paddle.summary to visualize the model structure.

    pointnet = PointNet()
    paddle.summary(pointnet, (64, 3, 2048))

    ---------------------------------------------------------------------------
     Layer (type)      Input Shape           Output Shape          Param #
    ===========================================================================
     Conv1D-1          [[64, 3, 2048]]       [64, 64, 2048]        256
     BatchNorm-1       [[64, 64, 2048]]      [64, 64, 2048]        256
     ReLU-1            [[64, 64, 2048]]      [64, 64, 2048]        0
     Conv1D-2          [[64, 64, 2048]]      [64, 128, 2048]       8,320
     BatchNorm-2       [[64, 128, 2048]]     [64, 128, 2048]       512
     ReLU-2            [[64, 128, 2048]]     [64, 128, 2048]       0
     Conv1D-3          [[64, 128, 2048]]     [64, 1024, 2048]      132,096
     BatchNorm-3       [[64, 1024, 2048]]    [64, 1024, 2048]      4,096
     ReLU-3            [[64, 1024, 2048]]    [64, 1024, 2048]      0
     MaxPool1D-1       [[64, 1024, 2048]]    [64, 1024, 1]         0
     Linear-1          [[64, 1024]]          [64, 512]             524,800
     ReLU-4            [[64, 512]]           [64, 512]             0
     Linear-2          [[64, 512]]           [64, 256]             131,328
     ReLU-5            [[64, 256]]           [64, 256]             0
     Linear-3          [[64, 256]]           [64, 9]               2,313
     Conv1D-4          [[64, 3, 2048]]       [64, 64, 2048]        256
     BatchNorm-4       [[64, 64, 2048]]      [64, 64, 2048]        256
     ReLU-6            [[64, 64, 2048]]      [64, 64, 2048]        0
     Conv1D-5          [[64, 64, 2048]]      [64, 64, 2048]        4,160
     BatchNorm-5       [[64, 64, 2048]]      [64, 64, 2048]        256
     ReLU-7            [[64, 64, 2048]]      [64, 64, 2048]        0
     Conv1D-6          [[64, 64, 2048]]      [64, 64, 2048]        4,160
     BatchNorm-6       [[64, 64, 2048]]      [64, 64, 2048]        256
     ReLU-8            [[64, 64, 2048]]      [64, 64, 2048]        0
     Conv1D-7          [[64, 64, 2048]]      [64, 128, 2048]       8,320
     BatchNorm-7       [[64, 128, 2048]]     [64, 128, 2048]       512
     ReLU-9            [[64, 128, 2048]]     [64, 128, 2048]       0
     Conv1D-8          [[64, 128, 2048]]     [64, 1024, 2048]      132,096
     BatchNorm-8       [[64, 1024, 2048]]    [64, 1024, 2048]      4,096
     ReLU-10           [[64, 1024, 2048]]    [64, 1024, 2048]      0
     MaxPool1D-2       [[64, 1024, 2048]]    [64, 1024, 1]         0
     Linear-4          [[64, 1024]]          [64, 512]             524,800
     ReLU-11           [[64, 512]]           [64, 512]             0
     Linear-5          [[64, 512]]           [64, 256]             131,328
     ReLU-12           [[64, 256]]           [64, 256]             0
     Linear-6          [[64, 256]]           [64, 4096]            1,052,672
     Conv1D-9          [[64, 64, 2048]]      [64, 64, 2048]        4,160
     BatchNorm-9       [[64, 64, 2048]]      [64, 64, 2048]        256
     ReLU-13           [[64, 64, 2048]]      [64, 64, 2048]        0
     Conv1D-10         [[64, 64, 2048]]      [64, 128, 2048]       8,320
     BatchNorm-10      [[64, 128, 2048]]     [64, 128, 2048]       512
     ReLU-14           [[64, 128, 2048]]     [64, 128, 2048]       0
     Conv1D-11         [[64, 128, 2048]]     [64, 1024, 2048]      132,096
     BatchNorm-11      [[64, 1024, 2048]]    [64, 1024, 2048]      4,096
     ReLU-15           [[64, 1024, 2048]]    [64, 1024, 2048]      0
     Linear-7          [[64, 1024]]          [64, 512]             524,800
     ReLU-16           [[64, 512]]           [64, 512]             0
     Linear-8          [[64, 512]]           [64, 256]             131,328
     ReLU-17           [[64, 256]]           [64, 256]             0
     Dropout-1         [[64, 256]]           [64, 256]             0
     Linear-9          [[64, 256]]           [64, 16]              4,112
     LogSoftmax-1      [[64, 16]]            [64, 16]              0
    ===========================================================================
    Total params: 3,476,825
    Trainable params: 3,461,721
    Non-trainable params: 15,104
    ---------------------------------------------------------------------------
    Input size (MB): 1.50
    Forward/backward pass size (MB): 11333.40
    Params size (MB): 13.26
    Estimated Total Size (MB): 11348.16
    ---------------------------------------------------------------------------

    {'total_params': 3476825, 'trainable_params': 3461721}

4. Training

Note: training uses the paddle.optimizer.Adam optimizer and computes the loss with F.nll_loss.
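Because the network ends in LogSoftmax, F.nll_loss on its output is equivalent to cross-entropy on raw logits: it picks out the negative log-probability of each true class and averages. A NumPy sketch of that arithmetic, independent of Paddle:

```python
import numpy as np

logits = np.array([[2.0, 0.5, -1.0],
                   [0.1, 1.5,  0.3]])
labels = np.array([0, 1])

# log-softmax over the class axis
log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))

# NLL loss: mean negative log-probability of the true class per sample
nll = -log_probs[np.arange(len(labels)), labels].mean()
print(nll)
```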

    def train():
        model = PointNet(num_classes=16, num_point=2048)
        model.train()
        optim = paddle.optimizer.Adam(parameters=model.parameters(), weight_decay=0.001)
        epoch_num = 10
        for epoch in range(epoch_num):
            # train
            print("===================================train===========================================")
            for batch_id, data in enumerate(train_loader()):
                inputs, labels = data
                predicts = model(inputs)
                loss = F.nll_loss(predicts, labels)
                acc = paddle.metric.accuracy(predicts, labels)
                if batch_id % 20 == 0:
                    print("train: epoch: {}, batch_id: {}, loss is: {}, accuracy is: {}".format(epoch, batch_id, loss.numpy(), acc.numpy()))
                loss.backward()
                optim.step()
                optim.clear_grad()
            if epoch % 2 == 0:
                paddle.save(model.state_dict(), './model/PointNet.pdparams')
                paddle.save(optim.state_dict(), './model/PointNet.pdopt')
            # validation
            print("===================================val===========================================")
            model.eval()
            accuracies = []
            losses = []
            for batch_id, data in enumerate(val_loader()):
                inputs, labels = data
                predicts = model(inputs)
                loss = F.nll_loss(predicts, labels)
                acc = paddle.metric.accuracy(predicts, labels)
                losses.append(loss.numpy())
                accuracies.append(acc.numpy())
            avg_acc, avg_loss = np.mean(accuracies), np.mean(losses)
            print("validation: loss is: {}, accuracy is: {}".format(avg_loss, avg_acc))
            model.train()

    if __name__ == '__main__':
        train()
    ===================================train===========================================
    train: epoch: 0, batch_id: 0, loss is: [4.9248095], accuracy is: [0.1484375]
    train: epoch: 0, batch_id: 20, loss is: [1.2662703], accuracy is: [0.6875]
    train: epoch: 0, batch_id: 40, loss is: [0.79613143], accuracy is: [0.7734375]
    train: epoch: 0, batch_id: 60, loss is: [0.5991034], accuracy is: [0.8203125]
    train: epoch: 0, batch_id: 80, loss is: [0.6732792], accuracy is: [0.8125]
    ===================================val===========================================
    validation: loss is: 0.41937708854675293, accuracy is: 0.8569915294647217
    ===================================train===========================================
    train: epoch: 1, batch_id: 0, loss is: [0.5992371], accuracy is: [0.8046875]
    train: epoch: 1, batch_id: 20, loss is: [0.65309006], accuracy is: [0.84375]
    train: epoch: 1, batch_id: 40, loss is: [0.4504382], accuracy is: [0.8671875]
    train: epoch: 1, batch_id: 60, loss is: [0.5059966], accuracy is: [0.828125]
    train: epoch: 1, batch_id: 80, loss is: [0.28158492], accuracy is: [0.875]
    ===================================val===========================================
    validation: loss is: 0.30018773674964905, accuracy is: 0.90625
    ===================================train===========================================
    train: epoch: 2, batch_id: 0, loss is: [0.27240375], accuracy is: [0.9375]
    train: epoch: 2, batch_id: 20, loss is: [0.4211054], accuracy is: [0.9140625]
    train: epoch: 2, batch_id: 40, loss is: [0.3876957], accuracy is: [0.890625]
    train: epoch: 2, batch_id: 60, loss is: [0.27216607], accuracy is: [0.9140625]
    train: epoch: 2, batch_id: 80, loss is: [0.29216224], accuracy is: [0.9296875]
    ===================================val===========================================
    validation: loss is: 6.236376762390137, accuracy is: 0.07642251998186111
    ===================================train===========================================
    train: epoch: 3, batch_id: 0, loss is: [2.7821736], accuracy is: [0.1015625]
    train: epoch: 3, batch_id: 20, loss is: [2.5189795], accuracy is: [0.2109375]
    train: epoch: 3, batch_id: 40, loss is: [2.2503586], accuracy is: [0.2109375]
    train: epoch: 3, batch_id: 60, loss is: [2.1081736], accuracy is: [0.328125]
    train: epoch: 3, batch_id: 80, loss is: [1.9944972], accuracy is: [0.375]
    ===================================val===========================================
    validation: loss is: 1505.8250732421875, accuracy is: 0.2111077606678009
    ===================================train===========================================
    train: epoch: 4, batch_id: 0, loss is: [2.041934], accuracy is: [0.3203125]
    train: epoch: 4, batch_id: 20, loss is: [2.013423], accuracy is: [0.3203125]
    train: epoch: 4, batch_id: 40, loss is: [2.1235168], accuracy is: [0.328125]
    train: epoch: 4, batch_id: 60, loss is: [2.0412865], accuracy is: [0.2578125]
    train: epoch: 4, batch_id: 80, loss is: [2.0814402], accuracy is: [0.328125]
    ===================================val===========================================
    validation: loss is: 2.312401533126831, accuracy is: 0.3141646683216095
    ===================================train===========================================
    train: epoch: 5, batch_id: 0, loss is: [2.343786], accuracy is: [0.21875]
    train: epoch: 5, batch_id: 20, loss is: [2.121391], accuracy is: [0.296875]
    train: epoch: 5, batch_id: 40, loss is: [1.9093541], accuracy is: [0.3828125]
    train: epoch: 5, batch_id: 60, loss is: [2.181718], accuracy is: [0.296875]
    train: epoch: 5, batch_id: 80, loss is: [1.9160612], accuracy is: [0.3359375]
    ===================================val===========================================
    validation: loss is: 2.033637046813965, accuracy is: 0.3141646683216095
    ===================================train===========================================
    train: epoch: 6, batch_id: 0, loss is: [2.0708382], accuracy is: [0.3203125]
    train: epoch: 6, batch_id: 20, loss is: [1.8967582], accuracy is: [0.3515625]
    train: epoch: 6, batch_id: 40, loss is: [1.9126657], accuracy is: [0.390625]
    train: epoch: 6, batch_id: 60, loss is: [1.9337373], accuracy is: [0.3125]
    train: epoch: 6, batch_id: 80, loss is: [1.9214094], accuracy is: [0.34375]
    ===================================val===========================================
    validation: loss is: 2.130458116531372, accuracy is: 0.3141646683216095
    ===================================train===========================================
    train: epoch: 7, batch_id: 0, loss is: [2.105646], accuracy is: [0.2734375]
    train: epoch: 7, batch_id: 20, loss is: [2.063758], accuracy is: [0.28125]
    train: epoch: 7, batch_id: 40, loss is: [2.0195878], accuracy is: [0.328125]
    train: epoch: 7, batch_id: 60, loss is: [2.061315], accuracy is: [0.265625]
    train: epoch: 7, batch_id: 80, loss is: [2.0420928], accuracy is: [0.2578125]
    ===================================val===========================================
    validation: loss is: 2.0498859882354736, accuracy is: 0.3141646683216095
    ===================================train===========================================
    train: epoch: 8, batch_id: 0, loss is: [1.9741396], accuracy is: [0.3515625]
    train: epoch: 8, batch_id: 20, loss is: [2.0922964], accuracy is: [0.28125]
    train: epoch: 8, batch_id: 40, loss is: [1.926232], accuracy is: [0.3125]
    train: epoch: 8, batch_id: 60, loss is: [2.044017], accuracy is: [0.3828125]
    train: epoch: 8, batch_id: 80, loss is: [2.0386844], accuracy is: [0.296875]
    ===================================val===========================================
    validation: loss is: 1.9772093296051025, accuracy is: 0.3141646683216095
    ===================================train===========================================
    train: epoch: 9, batch_id: 0, loss is: [2.0979326], accuracy is: [0.3125]
    train: epoch: 9, batch_id: 20, loss is: [2.0070634], accuracy is: [0.2734375]
    train: epoch: 9, batch_id: 40, loss is: [2.0131662], accuracy is: [0.3046875]
    train: epoch: 9, batch_id: 60, loss is: [2.030023], accuracy is: [0.3046875]
    train: epoch: 9, batch_id: 80, loss is: [2.210213], accuracy is: [0.2890625]
    ===================================val===========================================
    validation: loss is: 2.011220693588257, accuracy is: 0.3141646683216095

5. Evaluation and Testing

Note: load the trained model via model.load_dict and evaluate it on the test set.

    def evaluation():
        model = PointNet()
        model_state_dict = paddle.load('./model/PointNet.pdparams')
        model.load_dict(model_state_dict)
        model.eval()
        accuracies = []
        losses = []
        for batch_id, data in enumerate(test_loader()):
            inputs, labels = data
            predicts = model(inputs)
            loss = F.nll_loss(predicts, labels)
            acc = paddle.metric.accuracy(predicts, labels)
            losses.append(loss.numpy())
            accuracies.append(acc.numpy())
        avg_acc, avg_loss = np.mean(accuracies), np.mean(losses)
        print("test: loss is: {}, accuracy is: {}".format(avg_loss, avg_acc))

    if __name__ == '__main__':
        evaluation()