5.12. 稠密连接网络(DenseNet)
ResNet中的跨层连接设计引申出了数个后续工作。本节我们介绍其中的一个:稠密连接网络(DenseNet)[1]。 它与ResNet的主要区别如图5.10所示。
图 5.10 ResNet(左)与DenseNet(右)在跨层连接上的主要区别:使用相加和使用连结
图5.10中将部分前后相邻的运算抽象为模块
和模块 。与ResNet的主要区别在于,DenseNet里模块 的输出不是像ResNet那样和模块 的输出相加,而是在通道维上连结。这样模块 的输出可以直接传入模块 后面的层。在这个设计里,模块 直接跟模块 后面的所有层连接在了一起。这也是它被称为“稠密连接”的原因。
DenseNet的主要构建模块是稠密块(dense block)和过渡层(transitionlayer)。前者定义了输入和输出是如何连结的,后者则用来控制通道数,使之不过大。
5.12.1. 稠密块
DenseNet使用了ResNet改良版的“批量归一化、激活和卷积”结构(参见上一节的练习),我们首先在conv_block
函数里实现这个结构。
- In [1]:
- import d2lzh as d2l
- from mxnet import gluon, init, nd
- from mxnet.gluon import nn
- def conv_block(num_channels):
- blk = nn.Sequential()
- blk.add(nn.BatchNorm(), nn.Activation('relu'),
- nn.Conv2D(num_channels, kernel_size=3, padding=1))
- return blk
稠密块由多个conv_block
组成,每块使用相同的输出通道数。但在前向计算时,我们将每块的输入和输出在通道维上连结。
- In [2]:
- class DenseBlock(nn.Block):
- def __init__(self, num_convs, num_channels, **kwargs):
- super(DenseBlock, self).__init__(**kwargs)
- self.net = nn.Sequential()
- for _ in range(num_convs):
- self.net.add(conv_block(num_channels))
- def forward(self, X):
- for blk in self.net:
- Y = blk(X)
- X = nd.concat(X, Y, dim=1) # 在通道维上将输入和输出连结
- return X
在下面的例子中,我们定义一个有2个输出通道数为10的卷积块。使用通道数为3的输入时,我们会得到通道数为
的输出。卷积块的通道数控制了输出通道数相对于输入通道数的增长,因此也被称为增长率(growthrate)。
- In [3]:
- blk = DenseBlock(2, 10)
- blk.initialize()
- X = nd.random.uniform(shape=(4, 3, 8, 8))
- Y = blk(X)
- Y.shape
- Out[3]:
- (4, 23, 8, 8)
5.12.2. 过渡层
由于每个稠密块都会带来通道数的增加,使用过多则会带来过于复杂的模型。过渡层用来控制模型复杂度。它通过
卷积层来减小通道数,并使用步幅为2的平均池化层减半高和宽,从而进一步降低模型复杂度。
- In [4]:
- def transition_block(num_channels):
- blk = nn.Sequential()
- blk.add(nn.BatchNorm(), nn.Activation('relu'),
- nn.Conv2D(num_channels, kernel_size=1),
- nn.AvgPool2D(pool_size=2, strides=2))
- return blk
对上一个例子中稠密块的输出使用通道数为10的过渡层。此时输出的通道数减为10,高和宽均减半。
- In [5]:
- blk = transition_block(10)
- blk.initialize()
- blk(Y).shape
- Out[5]:
- (4, 10, 4, 4)
5.12.3. DenseNet模型
我们来构造DenseNet模型。DenseNet首先使用同ResNet一样的单卷积层和最大池化层。
- In [6]:
- net = nn.Sequential()
- net.add(nn.Conv2D(64, kernel_size=7, strides=2, padding=3),
- nn.BatchNorm(), nn.Activation('relu'),
- nn.MaxPool2D(pool_size=3, strides=2, padding=1))
类似于ResNet接下来使用的4个残差块,DenseNet使用的是4个稠密块。同ResNet一样,我们可以设置每个稠密块使用多少个卷积层。这里我们设成4,从而与上一节的ResNet-18保持一致。稠密块里的卷积层通道数(即增长率)设为32,所以每个稠密块将增加128个通道。
ResNet里通过步幅为2的残差块在每个模块之间减小高和宽。这里我们则使用过渡层来减半高和宽,并减半通道数。
- In [7]:
- num_channels, growth_rate = 64, 32 # num_channels为当前的通道数
- num_convs_in_dense_blocks = [4, 4, 4, 4]
- for i, num_convs in enumerate(num_convs_in_dense_blocks):
- net.add(DenseBlock(num_convs, growth_rate))
- # 上一个稠密块的输出通道数
- num_channels += num_convs * growth_rate
- # 在稠密块之间加入通道数减半的过渡层
- if i != len(num_convs_in_dense_blocks) - 1:
- num_channels //= 2
- net.add(transition_block(num_channels))
同ResNet一样,最后接上全局池化层和全连接层来输出。
- In [8]:
- net.add(nn.BatchNorm(), nn.Activation('relu'), nn.GlobalAvgPool2D(),
- nn.Dense(10))
5.12.4. 获取数据并训练模型
由于这里使用了比较深的网络,本节里我们将输入高和宽从224降到96来简化计算。
- In [9]:
- lr, num_epochs, batch_size, ctx = 0.1, 5, 256, d2l.try_gpu()
- net.initialize(ctx=ctx, init=init.Xavier())
- trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': lr})
- train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)
- d2l.train_ch5(net, train_iter, test_iter, batch_size, trainer, ctx,
- num_epochs)
- training on gpu(0)
- epoch 1, loss 0.5423, train acc 0.810, test acc 0.836, time 14.8 sec
- epoch 2, loss 0.3123, train acc 0.886, test acc 0.885, time 13.2 sec
- epoch 3, loss 0.2610, train acc 0.904, test acc 0.906, time 13.1 sec
- epoch 4, loss 0.2328, train acc 0.915, test acc 0.906, time 13.1 sec
- epoch 5, loss 0.2106, train acc 0.923, test acc 0.918, time 13.1 sec
5.12.5. 小结
- 在跨层连接上,不同于ResNet中将输入与输出相加,DenseNet在通道维上连结输入与输出。
- DenseNet的主要构建模块是稠密块和过渡层。
5.12.6. 练习
- DenseNet论文中提到的一个优点是模型参数比ResNet的更小,这是为什么?
- DenseNet被人诟病的一个问题是内存或显存消耗过多。真的会这样吗?可以把输入形状换成 ,来看看实际的消耗。
- 实现DenseNet论文中的表1提出的不同版本的DenseNet [1]。
5.12.7. 参考文献
[1] Huang, G., Liu, Z., Weinberger, K. Q., & van der Maaten, L. (2017).Densely connected convolutional networks. In Proceedings of the IEEEconference on computer vision and pattern recognition (Vol. 1, No. 2).