CoTLayer与CoTNet-50的Pytorch实现
CoTLayer部分
从上面的结构图中我们可以看到,CoT模块包含三个部分,我们需要先构建这三个基本模块,比较简单,用最基本的卷积操作可以搞定。
keys_embedding
Values_embedding
Attention_embedding
self.key_embed = nn.Sequential(
# 通过K*K的卷积提取邻近上下文信息,视作输入X的静态上下文表达
nn.Conv2d(dim, dim, kernel_size=kernel_size, padding=1, stride=1, bias=False),
nn.BatchNorm2d(dim),
nn.ReLU()
)
self.value_embed = nn.Sequential(
nn.Conv2d(dim, dim, kernel_size = 1, stride=1, bias=False), # 1*1的卷积进行Value的编码
nn.BatchNorm2d(dim)
)
factor = 4
self.attention_embed = nn.Sequential( # 通过连续两个1*1的卷积计算注意力矩阵
nn.Conv2d(2 * dim, 2 * dim // factor, 1, bias=False), # 输入concat后的特征矩阵 Channel = 2*C
nn.BatchNorm2d(2 * dim // factor),
nn.ReLU(),
nn.Conv2d(2 * dim // factor, kernel_size * kernel_size * dim, 1, stride=1) # out: H * W * (K*K*C)
)
之后就是重写forward方法,这才是CoTLayer的关键。
首先得到Key和Value的编码
bs, c, h, w = x.shape
k1 = self.key_embed(x) # shape:bs,c,h,w 提取静态上下文信息得到key
v = self.value_embed(x).view(bs, c, -1) # shape:bs,c,h*w 得到value编码
使用torch.cat操作将key与输入x在channel纬度进行拼接,并得到注意力矩阵
y = torch.cat([k1, x], dim=1) # shape:bs,2c,h,w Key与Query在channel维度上进行拼接进行拼接
att = self.attention_embed(y) # shape:bs,c*k*k,h,w 计算注意力矩阵
为了进行之后的静动态上下文信息的融合,需要把注意力矩阵进行reshape
att = att.reshape(bs, c, self.kernel_size * self.kernel_size, h, w)
att = att.mean(2, keepdim=False).view(bs, c, -1) # shape:bs,c,h*w 求平均降低维度
k2 = F.softmax(att, dim=-1) * v # 对每一个Channel进行softmax后
k2 = k2.view(bs, c, h, w)
最后return k1 + k2所谓此结构的输出
CoTNet替换ResNet-50
根据原文中的信息,我们只需要将Bottleneck中3*3的卷积替换为CoTLayer即可,但是具体实现起来还是会有一些问题,主要是涉及到图片的大小调整。
Figure4: CoTNet-50结构图
我们运行上面实现的CoTLayer会发现输入特征和输出特征的大小是不会改变的,但是ResNet-50会逐渐从224减小到1。为了解决此问题,我们需要加入额外的downsample操作。
if stride > 1:
self.avd = nn.AvgPool2d(3, 2, padding=1)
else:
self.avd = None
其他部分和ResNet-50结构相同,至此我们就实现了对CoTNet-50的复现,我将复现的代码放到了下面的链接中,可以自取之后替换自己项目中的ResNet-50查看效果。
版权所有:江苏和讯自动化设备有限公司所有 备案号:苏ICP备2022010314号-1
技术支持: 易动力网络