搭建神经网络

​神经网络的各层,可以使用 oneflow.nn 名称空间下的 API 搭建,它提供了构建神经网络所需的常见 Module(如 oneflow.nn.Conv2doneflow.nn.ReLU 等等)。 用于搭建网络的所有 Module 类都继承自 oneflow.nn.Module,多个简单的 Module 可以组合在一起构成更复杂的 Module,用这种方式,用户可以轻松地搭建和管理复杂的神经网络。

  1. import oneflow as flow
  2. import oneflow.nn as nn

定义 Module 类

oneflow.nn 下提供了常见的 Module 类,我们可以直接使用它们,或者在它们的基础上,通过自定义 Module 类搭建神经网络。搭建过程包括:

  • 写一个继承自 oneflow.nn.Module 的类
  • 实现类的 __init__ 方法,在其中构建神经网络的结构
  • 实现类的 forward 方法,这个方法针对 Module 的输入进行计算
  1. class NeuralNetwork(nn.Module):
  2. def __init__(self):
  3. super(NeuralNetwork, self).__init__()
  4. self.flatten = nn.Flatten()
  5. self.linear_relu_stack = nn.Sequential(
  6. nn.Linear(28*28, 512),
  7. nn.ReLU(),
  8. nn.Linear(512, 512),
  9. nn.ReLU(),
  10. nn.Linear(512, 10),
  11. nn.ReLU()
  12. )
  13. def forward(self, x):
  14. x = self.flatten(x)
  15. logits = self.linear_relu_stack(x)
  16. return logits
  17. net = NeuralNetwork()
  18. print(net)

以上代码,会输出刚刚搭建的 NeuralNetwork 网络的结构:

  1. NeuralNetwork(
  2. (flatten): Flatten(start_dim=1, end_dim=-1)
  3. (linear_relu_stack): Sequential(
  4. (0): Linear(in_features=784, out_features=512, bias=True)
  5. (1): ReLU()
  6. (2): Linear(in_features=512, out_features=512, bias=True)
  7. (3): ReLU()
  8. (4): Linear(in_features=512, out_features=10, bias=True)
  9. (5): ReLU()
  10. )
  11. )

接着,调用 net (注意:不推荐显式调用 forward)即可完成前向传播:

  1. X = flow.ones(1, 28, 28)
  2. logits = net(X)
  3. pred_probab = nn.Softmax(dim=1)(logits)
  4. y_pred = pred_probab.argmax(1)
  5. print(f"Predicted class: {y_pred}")

会得到类似以下的输出结果:

  1. Predicted class: tensor([1], dtype=oneflow.int32)

以上从数据输入、到网络计算,最终推理输出的流程,如下图所示:

todo

flow.nn.functional

除了 oneflow.nn 外,oneflow.nn.functional 名称空间下也提供了不少 API。它与 oneflow.nn 在功能上有一定的重叠。比如 nn.functional.relunn.ReLU 都可用于神经网络做 activation 操作。

两者的区别主要有:

  • nn 下的 API 是类,需要先构造实例化对象,再调用;nn.functional 下的 API 是作为函数直接调用
  • nn 下的类内部自己管理了网络参数;而 nn.functional 下的函数,需要我们自己定义参数,每次调用时手动传入

实际上,OneFlow 提供的大部分 Module 是通过封装 nn.functional 下的方法得到的。nn.functional 提供了更加细粒度管理网络的可能。

以下的例子,使用 nn.functional 中的方法,构建与上文中 NeuralNetwork 类等价的 Module FunctionalNeuralNetwork,读者可以体会两者的异同:

  1. class FunctionalNeuralNetwork(nn.Module):
  2. def __init__(self):
  3. super(FunctionalNeuralNetwork, self).__init__()
  4. self.weight1 = nn.Parameter(flow.randn(28*28, 512))
  5. self.bias1 = nn.Parameter(flow.randn(512))
  6. self.weight2 = nn.Parameter(flow.randn(512, 512))
  7. self.bias2 = nn.Parameter(flow.randn(512))
  8. self.weight3 = nn.Parameter(flow.randn(512, 10))
  9. self.bias3 = nn.Parameter(flow.randn(10))
  10. def forward(self, x):
  11. x = x.reshape(1, 28*28)
  12. out = flow.matmul(x, self.weight1)
  13. out = out + self.bias1
  14. out = nn.functional.relu(out)
  15. out = flow.matmul(out, self.weight2)
  16. out = out + self.bias2
  17. out = nn.functional.relu(out)
  18. out = flow.matmul(out, self.weight3)
  19. out = out + self.bias3
  20. out = nn.functional.relu(out)
  21. return out
  22. net = FunctionalNeuralNetwork()
  23. X = flow.ones(1, 28, 28)
  24. logits = net(X)
  25. pred_probab = nn.Softmax(dim=1)(logits)
  26. y_pred = pred_probab.argmax(1)
  27. print(f"Predicted class: {y_pred}")

Module 容器

比较以上 NeuralNetworkFunctionalNeuralNetwork 实现的异同,可以发现 nn.Sequential 对于简化代码起到了重要作用。

nn.Sequential 是一种特殊容器,只要是继承自 nn.Module 的类都可以放置放置到其中。

它的特殊之处在于:当 Sequential 进行前向传播时,Sequential 会自动地将容器中包含的各层“串联”起来。具体来说,会按照各层加入 Sequential 的顺序,自动地将上一层的输出,作为下一层的输入传递,直到得到整个 Module 的最后一层的输出。

以下是不使用 Sequential 构建网络的例子(不推荐):

  1. class MyModel(nn.Module):
  2. def __init__(self):
  3. super(MyModel, self).__init__()
  4. self.conv1 = nn.Conv2d(1,20,5)
  5. self.relu1 = nn.ReLU()
  6. self.conv2 = nn.Conv2d(20,64,5)
  7. self.relu2 = nn.ReLU()
  8. def forward(self, x):
  9. out = self.conv1(x)
  10. out = self.relu1(out)
  11. out = self.conv2(out)
  12. out = self.relu2(out)
  13. return out

如果使用 Sequential,则看起来是这样,会显得更简洁。

  1. class MySeqModel(nn.Module):
  2. def __init__(self):
  3. super(MySeqModel, self).__init__()
  4. self.seq = nn.Sequential(
  5. nn.Conv2d(1,20,5),
  6. nn.ReLU(),
  7. nn.Conv2d(20,64,5),
  8. nn.ReLU()
  9. )
  10. def forward(self, x):
  11. return self.seq(x)

除了 Sequential 外,还有 nn.Modulelistnn.ModuleDict,除了会自动注册参数到整个网络外,他们的其它行为类似 Python list、Python dict,只是常用简单的容器,不会自动进行前后层的前向传播,需要自己手工遍历完成各层的计算。

为正常使用来必力评论功能请激活JavaScript