首页 最新 热门 推荐

  • 首页
  • 最新
  • 热门
  • 推荐

BAM: Bottleneck Attention Module__实现

  • 25-03-03 21:22
  • 3543
  • 8721
blog.csdn.net

文章目录

  • BAM: Bottleneck Attention Module
    • 引言
    • 主要思想:
      • channel attention branch
      • Spatial attention branch
      • Combine two attention branches
    • Pytorch实现BAM

BAM: Bottleneck Attention Module

引言

  在此论文中,我们把重心放在了Attention对于一般深度神经网络的影响上, 我们提出了一个简单但是有效的Attention 模型—BAM,它可以结合到任何前向传播卷积神经网络中,我们的模型通过两个分离的路径 channel和spatial, 得到一个Attention Map.
整体结构图如下:
在这里插入图片描述

这里作者将BAM放在了Resnet网络中每个stage之间。有趣的是,通过可视化我们可以看到多层BAMs形成了一个分层的注意力机制,这有点像人类的感知机制。BAM在每个stage之间消除了像背景语义特征这样的低层次特征,然后逐渐聚焦于高级的语义–明确的目标(比如图中的狗).

主要思想:

channel attention branch

对于给定的feature map F ∈ R C ∗ H ∗ W \mathrm{F} \in R^{C * H * W} F∈RC∗H∗W,BAM可以得到一个3D的Attention map M ( F ) ∈ R C ∗ H ∗ W M(F)∈R^{C*H*W} M(F)∈RC∗H∗W.加强的feature map F′;
F ′ = F + F ⊗ M ( F ) \mathrm{F}^{\prime}=\mathrm{F}+\mathrm{F} \otimes \mathrm{M}(\mathrm{F}) F′=F+F⊗M(F)

为了设计一个有效且强大的模型,我们首先计算channel attention,然后计算spatial attention.这时M(F)就变成:
M ( F ) = σ ( M c ( F ) + M s ( F ) ) \mathrm{M}(\mathrm{F})=\sigma\left(\mathrm{M}_{c}(\mathrm{F})+\mathrm{M}_{s}(\mathrm{F})\right) M(F)=σ(Mc​(F)+Ms​(F))

这里σ 代表sigmoid函数,为了聚合feature map在每个通道维度,我们采用全局平均池化得到 F C F_{C} FC​这个向量然后对全局信息在每个通道进行软编码。为了评估Attention在每个通道的效果?我们使用了一个多层感知(MLP)用一层隐藏层。在MLP之后,我们增加了BN去调整规模和空间分支一样的输出,channel attention可以被计算为:

M c ( F ) = B N ( M L P ( AvgPool ( F ) ) ) \mathbf{M}_{\mathbf{c}}(\mathbf{F})=B N(M L P(\text {AvgPool}(\mathbf{F}))) Mc​(F)=BN(MLP(AvgPool(F)))
= B N ( W 1 ( W 0 A v g P o o l ( F ) + b 0 ) + b 1 ) =B N\left(\mathbf{W}_{1}\left(\mathbf{W}_{0} A v g P o o l(\mathbf{F})+\mathbf{b}_{0}\right)+\mathbf{b}_{1}\right) =BN(W1​(W0​AvgPool(F)+b0​)+b1​)
where W 0 ∈ R C / r × C , b 0 ∈ R C / r , W 1 ∈ R C × C / r , b 1 ∈ R C \mathbf{W}_{0} \in \mathbb{R}^{C / r \times C}, \mathbf{b}_{0} \in \mathbb{R}^{C / r}, \mathbf{W}_{1} \in \mathbb{R}^{C \times C / r}, \mathbf{b}_{1} \in \mathbb{R}^{C} W0​∈RC/r×C,b0​∈RC/r,W1​∈RC×C/r,b1​∈RC

Spatial attention branch

这个空间分支产生了空间Attention去增强或者抑制特征在不同的空间位置,众所周知,利用上下文信息是去知道应该关注哪些位置的关键点。在这里我们为了高效性运用空洞卷积去增大感受野。
我们观察到,与标准卷积相比,空洞卷积有助于构造更有效的spatial map.
细节图:

在这里插入图片描述

空洞模型结构 给与中间feature map F,这个module 计算对应的Attention mapM(F)通过两个单独的Attention 分支–channle Mc 和空间 M S \mathrm{M}_{S} MS​.这里有两个超参数 dilation value (d)和reduction ratio®. d参数决定了感受野大小,这对空间分支聚合上下文信息非常重要。这里我们set d=4 r=16.

我们采用空洞卷积来高效扩大感受野。我们观察到空洞卷积有助于构建比标准卷积更有效的空间映射。 我们的空间分支采用了ResNet建议的“瓶颈结构”,既节省了参数数量又节省了计算开销。 具体地,使用1×1卷积将特征 F ∈ R C × H × W \mathbf{F} \in \mathbb{R}^{C \times H \times W} F∈RC×H×W投影到缩小尺寸的 R C / r × H × W \mathbb{R}^{C / r \times H \times W} RC/r×H×W,以在整个通道维度上对特征图进行结合和压缩。 为简单起见,我们使用与通道分支相同的缩减比r。 在减少之后,应用两个3×3扩张卷积以有效地利用上下文信息。 最后,使用1×1卷积将特征再次简化为 R 1 × H × W \mathbb{R}^{1 \times H \times W} R1×H×W空间注意力图。 对于缩放调整,在空间分支的末尾应用批量标准化层。 简而言之,空间注意力计算如下:
M s ( F ) = B N ( f 3 1 × 1 ( f 2 3 × 3 ( f 1 3 × 3 ( f 0 1 × 1 ( F ) ) ) ) ) \mathbf{M}_{\mathbf{s}}(\mathbf{F})=B N\left(f_{3}^{1 \times 1}\left(f_{2}^{3 \times 3}\left(f_{1}^{3 \times 3}\left(f_{0}^{1 \times 1}(\mathbf{F})\right)\right)\right)\right) Ms​(F)=BN(f31×1​(f23×3​(f13×3​(f01×1​(F)))))
其中f表示卷积运算,BN表示批量归一化运算,上标表示卷积滤波器大小。 通道缩减有两个1×1卷积。中间3×3扩张卷积用于聚合具有较大感受野的上下文信息。

Combine two attention branches

在从两个注意力分支中获取通道注意力Mc(F)和空间注意力 M S \mathrm{M}_{S} MS​(F)后,我们将它们组合起来,生成最终的3D注意力mapM(F)。由于这两个注意图的形状不同,我们将注意图扩展到$ R^{CHW}$,然后将它们合并。在逐项求和、乘法、max运算等多种组合方法中,针对梯度流的特点,选择逐项求和。我们通过实证验证了基于元素的求和在三种选择中效果最好。求和后,我们取一个sigmoid函数,得到0到1范围内的最终三维注意映射M(F)。将该三维注意图与输入特征图F巧妙相乘,然后将其添加到原始输入特征图上,得到细化后的特征图F′
F ′ = F + F ⊗ M ( F ) \mathrm{F}^{\prime}=\mathrm{F}+\mathrm{F} \otimes \mathrm{M}(\mathrm{F}) F′=F+F⊗M(F)

Pytorch实现BAM

源码github链接

import torch
import math
import torch.nn as nn
import torch.nn.functional as F

class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)

class ChannelGate(nn.Module):
    def __init__(self, gate_channel, reduction_ratio=16, num_layers=1):
        super(ChannelGate, self).__init__()
        self.gate_c = nn.Sequential()
        # after avg_pool
        self.gate_c.add_module('flatten', Flatten())
        gate_channels = [gate_channel]
        gate_channels += [gate_channel // reduction_ratio] * num_layers
        gate_channels += [gate_channel]
        for i in range(len(gate_channels) - 2):
            # fc->bn
            self.gate_c.add_module('gate_c_fc_%d'%i, nn.Linear(gate_channels[i], gate_channels[i+1]))
            self.gate_c.add_module('gate_c_bn_%d'%(i+1), nn.BatchNorm2d(gate_channels[i+1]))
            self.gate_c.add_module('gate_c_relu_%d'%(i+1), nn.ReLU())
        # final_fc
        self.gate_c.add_module('gate_c_fc_final', nn.Linear(gate_channels[-2], gate_channels[-1]))

    def forward(self, in_tensor):
        # Global avg pool
        avg_pool = F.avg_pool2d(in_tensor, in_tensor.size(2), stride=in_tensor.size(2))
        # C∗H∗W -> C*1*1 -> C*H*W
        return self.gate_c(avg_pool).unsqueeze(2).unsqueeze(3).expand_as(in_tensor)

class SpatiaGate(nn.Module):
    # dilation value and reduction ratio, set d = 4 r = 16
    def __init__(self, gate_channel, reduction_ratio=16, dilation_conv_num=2, dilation_val=4):
        self.gate_s = nn.Sequential()
        # 1x1 + (3x3)*2 + 1x1
        self.gate_s.add_module('gate_s_conv_reduce0', nn.Conv2d(gate_channel, gate_channel // reduction_ratio, kernel_size=1))
        self.gate_s.add_module('gate_s_bn_reduce0', nn.BatchNorm2d(gate_channel // reduction_ratio))
        self.gate_s.add_module('gate_s_relu_reduce0', nn.ReLU())
        for i in range(dilation_conv_num):
            self.gate_s.add_module('gate_s_conv_di_%d' % i, nn.Conv2d(gate_channel // reduction_ratio, gate_channel // reduction_ratio,
                                                             kernel_size=3, padding=dilation_val, dilation=dilation_val))
            self.gate_s.add_module('gate_s_bn_di_%d' % i, nn.BatchNorm2d(gate_channel // reduction_ratio))
            self.gate_s.add_module('gate_s_relu_di_%d' % i, nn.ReLU())
        self.gate_s.add_module('gate_s_conv_final', nn.Conv2d(gate_channel // reduction_ratio, 1, kernel_size=1))  # 1×H×W

    def forward(self, in_tensor):
        return self.gate_s(in_tensor).expand_as(in_tensor)

class BAM(nn.Module):
    def __init__(self, gate_channel):
        super(BAM, self).__init__()
        self.channel_att = ChannelGate(gate_channel)
        self.spatial_att = SpatiaGate(gate_channel)

    def forward(self, in_tensor):
        att = 1 + F.sigmoid( self.channel_att(in_tensor) * self.spatial_att(in_tensor))
        return att * in_tensor
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59

将BAM融入到Resnet中,在下一篇CBAM中展示源码。

注:本文转载自blog.csdn.net的Cpp编程小茶馆的文章"https://blog.csdn.net/xu_fu_yong/article/details/93313451"。版权归原作者所有,此博客不拥有其著作权,亦不承担相应法律责任。如有侵权,请联系我们删除。
复制链接
复制链接
相关推荐
发表评论
登录后才能发表评论和回复 注册

/ 登录

评论记录:

未查询到任何数据!
回复评论:

分类栏目

后端 (14832) 前端 (14280) 移动开发 (3760) 编程语言 (3851) Java (3904) Python (3298) 人工智能 (10119) AIGC (2810) 大数据 (3499) 数据库 (3945) 数据结构与算法 (3757) 音视频 (2669) 云原生 (3145) 云平台 (2965) 前沿技术 (2993) 开源 (2160) 小程序 (2860) 运维 (2533) 服务器 (2698) 操作系统 (2325) 硬件开发 (2491) 嵌入式 (2955) 微软技术 (2769) 软件工程 (2056) 测试 (2865) 网络空间安全 (2948) 网络与通信 (2797) 用户体验设计 (2592) 学习和成长 (2593) 搜索 (2744) 开发工具 (7108) 游戏 (2829) HarmonyOS (2935) 区块链 (2782) 数学 (3112) 3C硬件 (2759) 资讯 (2909) Android (4709) iOS (1850) 代码人生 (3043) 阅读 (2841)

热门文章

101
推荐
关于我们 隐私政策 免责声明 联系我们
Copyright © 2020-2025 蚁人论坛 (iYenn.com) All Rights Reserved.
Scroll to Top