全国服务热线:

15861139266

CoTLayer与CoTNet-50的Pytorch实现,苏州PLC培训,苏州上位机培训,苏州机器视觉培训,苏州工业机器人培训
发布时间:2023-05-04 11:49:50 点击次数:308

CoTLayer与CoTNet-50的Pytorch实现

CoTLayer部分

从上面的结构图中我们可以看到,CoT模块包含三个部分,我们需要先构建这三个基本模块,比较简单,用最基本的卷积操作可以搞定。


keys_embedding

Values_embedding

Attention_embedding

self.key_embed = nn.Sequential(

    # 通过K*K的卷积提取邻近上下文信息,视作输入X的静态上下文表达

    nn.Conv2d(dim, dim, kernel_size=kernel_size, padding=1, stride=1, bias=False),

    nn.BatchNorm2d(dim),

    nn.ReLU()

)

self.value_embed = nn.Sequential(

    nn.Conv2d(dim, dim, kernel_size = 1, stride=1, bias=False),  # 1*1的卷积进行Value的编码

    nn.BatchNorm2d(dim)

)


factor = 4

self.attention_embed = nn.Sequential(  # 通过连续两个1*1的卷积计算注意力矩阵

    nn.Conv2d(2 * dim, 2 * dim // factor, 1, bias=False),  # 输入concat后的特征矩阵 Channel = 2*C

    nn.BatchNorm2d(2 * dim // factor),

    nn.ReLU(),

    nn.Conv2d(2 * dim // factor, kernel_size * kernel_size * dim, 1, stride=1)  # out: H * W * (K*K*C)

)

之后就是重写forward方法,这才是CoTLayer的关键。


首先得到Key和Value的编码


bs, c, h, w = x.shape

k1 = self.key_embed(x)  # shape:bs,c,h,w  提取静态上下文信息得到key

v = self.value_embed(x).view(bs, c, -1)  # shape:bs,c,h*w  得到value编码

使用torch.cat操作将key与输入x在channel纬度进行拼接,并得到注意力矩阵


y = torch.cat([k1, x], dim=1)  # shape:bs,2c,h,w  Key与Query在channel维度上进行拼接进行拼接

att = self.attention_embed(y)  # shape:bs,c*k*k,h,w  计算注意力矩阵

为了进行之后的静动态上下文信息的融合,需要把注意力矩阵进行reshape


att = att.reshape(bs, c, self.kernel_size * self.kernel_size, h, w)

att = att.mean(2, keepdim=False).view(bs, c, -1)  # shape:bs,c,h*w  求平均降低维度

k2 = F.softmax(att, dim=-1) * v  # 对每一个Channel进行softmax后

k2 = k2.view(bs, c, h, w)

最后return k1 + k2所谓此结构的输出


CoTNet替换ResNet-50

根据原文中的信息,我们只需要将Bottleneck中3*3的卷积替换为CoTLayer即可,但是具体实现起来还是会有一些问题,主要是涉及到图片的大小调整。



Figure4: CoTNet-50结构图

我们运行上面实现的CoTLayer会发现输入特征和输出特征的大小是不会改变的,但是ResNet-50会逐渐从224减小到1。为了解决此问题,我们需要加入额外的downsample操作。


if stride > 1:

    self.avd = nn.AvgPool2d(3, 2, padding=1)

else:

    self.avd = None

其他部分和ResNet-50结构相同,至此我们就实现了对CoTNet-50的复现,我将复现的代码放到了下面的链接中,可以自取之后替换自己项目中的ResNet-50查看效果。






1.png


立即咨询
  • 品质服务

    服务贴心周到

  • 快速响应

    全天24小时随时沟通

  • 专业服务

    授权率高,保密性强

  • 完善售后服务

    快速响应需求,及时性服务

直播课程
软件开发基础课程
上位机软件开发课
机器视觉软件开发课
专题课
联系方式
电话:15861139266
邮箱:75607082@qq.com
地址:苏州吴中区木渎镇尧峰路69号
关注我们

版权所有:江苏和讯自动化设备有限公司所有 备案号:苏ICP备2022010314号-1

技术支持: 易动力网络