ResNet残差块详解

论文链接:Deep Residual Learning for Image Recognition

ResNet在深度学习领域可以说是一个人尽皆知的模型,在各种模型的backbone上得到了广泛的应用。其理论简洁、结构简单,因此往往会让人忽视其巧妙的设计。我在最近设计模型结构时逐渐发现这一点,故写此文总结一下。

再看一遍残差块

这里简单回顾一下残差块和跳跃链接,如果足够熟悉的话可以跳过。

在ResNet出现之前,人们发现随着神经网络层数的增加,梯度消失现象越来越严重,模型的性能经常不再增加甚至还有所下降,这使得模型的深度受到限制。

ResNet使用残差模块引入跳跃连接(skip connection)结构解决了这一问题。然而这不是什么复杂的运算,其公式简单地令人咂舌: 这里表示一个残差块的输入特征图,表示经过若干卷积块后的结果。

以前的模型直接将作为最终输出结果,而ResNet创造性的将其与相加的结果作为最终结果。对应的意义是最终想要的知识是,已经学到了知识,那么残差块里边的网络只需要学习到知识即可

Residual learning: a building block.

图1:一个ResNet残差块结构图

残差块结构

很多人(其实是我)可能没有注意到的一点是ResNet提供了两种残差块。第一种被称为building block,主要用于浅层网络,第二种被称为bottleneck,主要用于深层网络。两者结构如图2所示。

A deeper residual function

图2:左图为building block,右图为bottleneck

值得注意的是,上一部分的公式只是一个简化模型,并不严格与实现一致。实际上,残差块的最后一层的激活函数是在与相加后才使用的。

代码实现图2中的卷积块:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
class ConvBlock(nn.Module):
def __init__(self, in_channels, out_channels, k_size=3, activation=True):
'''
k_size: kernel size, by default is 3
activation: whether use activation
'''
super().__init__()
self.blocks = [
nn.Conv2d(in_channels, out_channels, kernel=k_size, stride=1, padding=(k_size-1)//2, bias=False),
nn.BatchNorm2d(out_channels)
]

if activation:
self.blocks.append(
nn.ReLU(inplace=True)
)

self.conv = nn.Sequential(*self.blocks)

def forward(self, x):
return self.conv(x)

building block

building block只堆叠了两个相同的卷积块,这样简单的结构适用于通道比较窄的情况。所以其主要用于浅层网络,如ResNet18和ResNet34。

代码实现:

1
2
3
4
5
6
7
8
9
10
11
class BuildingBlock(nn.Module):
def __init__(self, channels):
super().__init__()
self.block = nn.Sequential(
ConvBlock(channels, channels),
ConvBlock(channels, channels, activation=False)
)
self.act = nn.ReLU(inplace=True)

def forward(self, x):
return self.act(x + self.block(x))

bottleneck

相比于building block的简单结构,bottleneck设计则非常巧妙。在通道非常宽的情况下(如1024)直接堆叠卷积块将会使参数量非常巨大。

因此bottleneck堆叠了三个卷积块,其中第一个卷积块用于降低通道数量,第三个卷积块用于复原通道数量。第一个和第三个卷积块涉及到的通道比较宽因此使用卷积,第二个卷积块涉及到的通道比较窄因此使用卷积。

代码实现:

1
2
3
4
5
6
7
8
9
10
11
12
class Bottleneck(nn.Module):
def __init__(self, in_channels, mid_channels):
super().__init__()
self.block = nn.Sequential(
ConvBlock(in_channels, mid_channels, k_size=1),
ConvBlock(mid_channels, mid_channels, k_size=3),
ConvBlock(mid_channels, in_channels, k_size=1, activation=False),
)
self.act = nn.ReLU(inplace=True)

def forward(self, x):
return self.act(x + self.block(x))

一些思考

ResNet的残差块设计个人觉得真是非常巧妙,所谓大道至简,最有效的东西未必总是要很复杂。

此外,bottleneck的设计也非常精彩,其可以在不影响性能的情况下有效减少参数量。在实际模型设计的时候我们经常要考虑参数量、运算量等,bottleneck的结构可以作为一个不错的参考。