在pytorch中,nn.Upsample
和ConvTranspose
(包含nn.ConvTranspose1d
、nn.ConvTranspose2d
和nn.ConvTranspose3d
)均可以实现上采样。那么是模型设计上应该使用哪一种呢?
实际上,在模型中应该使用哪一种并无严格的约束条件,习惯上可以视设计的网络层的作用而定。
nn.Upsample
仅通过插值实现,没有参数也不需要模型训练学习。如果只是想单纯实现特征图上采样或者比较在意模型参数量,那么nn.Upsample
必然是很好的选择。
ConvTranspose
则是通过转置卷积实现,会引入一定的参数量,故需要进行训练,相对于nn.Upsample
其能获得更加细粒度的高频信息。如果想让模型学会如何上采样那么就可以考虑ConvTranspose
,如在GAN中生成图像。
当然,两者的适用场景不是绝对的。以UNet为例,在原文中上采样用转置卷积完成,而后续很多实现则是用Upsample+\(1\times 1\)conv实现,这样的一个明显好处是相对于转置卷积来说参数量更少,而在性能上两者几乎是没有区别的。(参考Github作者@jvanvugt的讨论)。
区别总结:
Upsample | ConvTranspose | |
---|---|---|
实现机制 | 插值 | 转置卷积,可训练学习 |
参数 | 无 | 有 |
可处理数据维度 | 1D、2D、3D | ConvTranspose1d:1D ConvTranspose2d:2D ConvTranspose3d:3D |
场景 | 分割、检测等特征上采样 | GAN、高分辨率 |