pyTorch中Upsample和ConvTranspose区分

在pytorch中,nn.UpsampleConvTranspose(包含nn.ConvTranspose1dnn.ConvTranspose2dnn.ConvTranspose3d)均可以实现上采样。那么是模型设计上应该使用哪一种呢?

实际上,在模型中应该使用哪一种并无严格的约束条件,习惯上可以视设计的网络层的作用而定。

nn.Upsample仅通过插值实现,没有参数也不需要模型训练学习。如果只是想单纯实现特征图上采样或者比较在意模型参数量,那么nn.Upsample必然是很好的选择。

ConvTranspose则是通过转置卷积实现,会引入一定的参数量,故需要进行训练,相对于nn.Upsample其能获得更加细粒度的高频信息。如果想让模型学会如何上采样那么就可以考虑ConvTranspose,如在GAN中生成图像。

当然,两者的适用场景不是绝对的。以UNet为例,在原文中上采样用转置卷积完成,而后续很多实现则是用Upsample+\(1\times 1\)conv实现,这样的一个明显好处是相对于转置卷积来说参数量更少,而在性能上两者几乎是没有区别的。(参考Github作者@jvanvugt的讨论)。

区别总结

Upsample ConvTranspose
实现机制 插值 转置卷积,可训练学习
参数
可处理数据维度 1D、2D、3D ConvTranspose1d:1D ConvTranspose2d:2D ConvTranspose3d:3D
场景 分割、检测等特征上采样 GAN、高分辨率