summary

paddle. summary ( net, input_size, dtypes=None ) [源代码]

summary 函数能够打印网络的基础结构和参数信息。

参数:

  • net (Layer) - 网络实例,必须是 Layer 的子类。

  • input_size (tuple|InputSpec|list[tuple|InputSpec) - 输入张量的大小。如果网络只有一个输入,那么该值需要设定为tuple或InputSpec。如果模型有多个输入。那么该值需要设定为list[tuple|InputSpec],包含每个输入的shape。

  • dtypes (str,可选) - 输入张量的数据类型,如果没有给定,默认使用 float32 类型。默认值:None。

返回:字典,包含了总的参数量和总的可训练的参数量。

代码示例

  1. import paddle
  2. import paddle.nn as nn
  3. class LeNet(nn.Layer):
  4. def __init__(self, num_classes=10):
  5. super(LeNet, self).__init__()
  6. self.num_classes = num_classes
  7. self.features = nn.Sequential(
  8. nn.Conv2D(
  9. 1, 6, 3, stride=1, padding=1),
  10. nn.ReLU(),
  11. nn.MaxPool2D(2, 2),
  12. nn.Conv2D(
  13. 6, 16, 5, stride=1, padding=0),
  14. nn.ReLU(),
  15. nn.MaxPool2D(2, 2))
  16. if num_classes > 0:
  17. self.fc = nn.Sequential(
  18. nn.Linear(400, 120),
  19. nn.Linear(120, 84),
  20. nn.Linear(
  21. 84, 10))
  22. def forward(self, inputs):
  23. x = self.features(inputs)
  24. if self.num_classes > 0:
  25. x = paddle.flatten(x, 1)
  26. x = self.fc(x)
  27. return x
  28. lenet = LeNet()
  29. params_info = paddle.summary(lenet, (1, 1, 28, 28))
  30. print(params_info)
  31. # ---------------------------------------------------------------------------
  32. # Layer (type) Input Shape Output Shape Param #
  33. # ===========================================================================
  34. # Conv2D-11 [[1, 1, 28, 28]] [1, 6, 28, 28] 60
  35. # ReLU-11 [[1, 6, 28, 28]] [1, 6, 28, 28] 0
  36. # MaxPool2D-11 [[1, 6, 28, 28]] [1, 6, 14, 14] 0
  37. # Conv2D-12 [[1, 6, 14, 14]] [1, 16, 10, 10] 2,416
  38. # ReLU-12 [[1, 16, 10, 10]] [1, 16, 10, 10] 0
  39. # MaxPool2D-12 [[1, 16, 10, 10]] [1, 16, 5, 5] 0
  40. # Linear-16 [[1, 400]] [1, 120] 48,120
  41. # Linear-17 [[1, 120]] [1, 84] 10,164
  42. # Linear-18 [[1, 84]] [1, 10] 850
  43. # ===========================================================================
  44. # Total params: 61,610
  45. # Trainable params: 61,610
  46. # Non-trainable params: 0
  47. # ---------------------------------------------------------------------------
  48. # Input size (MB): 0.00
  49. # Forward/backward pass size (MB): 0.11
  50. # Params size (MB): 0.24
  51. # Estimated Total Size (MB): 0.35
  52. # ---------------------------------------------------------------------------
  53. # {'total_params': 61610, 'trainable_params': 61610}

使用本API的教程文档