首页 最新 热门 推荐

  • 首页
  • 最新
  • 热门
  • 推荐

Vision Transformer(ViT)

  • 25-03-02 13:44
  • 2782
  • 6287
blog.csdn.net

文章目录

  • 一、ViT整体结构
    • 结构简单说明
  • 二、ViT分解说明
    • Embedding层
    • Encoder
    • MLP Head
  • 三、ViT简洁实现
    • Attention
    • transformer
    • ViT
    • 完整代码

论文链接: An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale

一、ViT整体结构

ViT总体结构

结构简单说明

首先关注ViT的输入
一张图片会被分成一个个小的patch,如ViT-L/16表示每个patch大小为16×16,然后将每个patch输入到Embedding层(Linear Projection of Flattened Patches),通过Embedding层后可以得到其对应的向量,称为token,图中一张图片划分为9个patches,在经过Embedding层后得到了9个token embedding。
紧接着,我们会在这一系列token的最前面加上一个新的token *(class token),它的维度与前面得到的是一样的。
原始的token加入class token与位置信息后,将其输入到transformer encoder。

输入部分
transformer encoder的结构如下:
transformer encoder
ViT的transformer encoder的操作是将Encoder Block重复堆叠了L次,然后提取class token对应的输出输入到如下的MLP Head中进行分类,最后得到分类结果。
MLP Head输出

二、ViT分解说明

根据上面的阐述,可以看出整个ViT可以分为三大部分

  • Embedding层
  • Encoder
  • MLP Head 用于分类

处理流程为

  1. 将图片切分为patch
  2. patch转换为embedding
  3. 位置embedding和token embedding相加
  4. 输入到ViT模型
  5. CLS输出做多分类

Embedding层

对于标准的transformer模块,它接收的是token embedding向量,变化过程如下图1、2、3标注

embedding操作
对于编码部分,共有三个操作

  1. 生成class符号的token(图中*标记)
  2. 生成所有序列的位置编码(图中淡紫色)
  3. token embedding + 位置编码

图中首先将原始图片变换为多个patch,每个patch大小为3×16×16。再将其展平为token embedding,维度为768,patch转换为embedding需要两个操作:

  • 将patch拉平
  • 将patch拉平后的维度映射到 encoder需要的维度

在这一系列embedding的首部加入cls token,然后生成位置编码,并将位置编码与token embedding相加得到最终的输入embedding。

关于位置编码:
在transformer中,编码器是并行输入的,不会等待之前信息的输出情况,所以需要位置编码提供信息的位置信息,在ViT中,表示图像patch的前后信息。

Encoder

ViT的Encoder模块与原始transformer中的类似。
根据论文中Encoder,结合具体实现可得出Encoder Block

Encoder详细结构
与原始transformer的Encoder输入比较

transformer Encoder输入
将LN操作提前了,同时,因为将图片切分为patches,保证patch的大小一致,所以没有了padding操作。

MLP Head

MLP Head结合代码实现来理解会清晰很多

class Mlp(nn.Module):
   """
   MLP as used in Vision Transformer, MLP-Mixer and related networks
   """
   def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
       super().__init__()
       out_features = out_features or in_features
       hidden_features = hidden_features or in_features
       self.fc1 = nn.Linear(in_features, hidden_features)
       self.act = act_layer()
       self.fc2 = nn.Linear(hidden_features, out_features)
       self.drop = nn.Dropout(drop)

   def forward(self, x):
       x = self.fc1(x)
       x = self.act(x)
       x = self.drop(x)
       x = self.fc2(x)
       x = self.drop(x)
       return x
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20

可以看到,MLP仅由GELU激活函数、全连接层和DropOut层组成,作用就是对Encoder的输出进行多分类处理。

三、ViT简洁实现

对几个关键模块结构进行解释。

Attention

attention模块与transformer类似,实现多头注意力multi head机制,在forward函数中,通过to_qkv和chunk函数一次生成总体的Q、K、V矩阵,再划分为多头注意力的q、k、v,这一点与原始transformer不同,原始transformer是通过Linear层各自生成Q、K、V,这个差别的原因在于ViT无需解码。

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        # 对tensor张量分块 eg. x :1 197 1024
        # 通过to_qkv操作将维度提升至原维度的3倍
        # qkv 最后是一个元组,tuple,长度是3,每个元素形状:1 197 1024
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        print('qkv is ', qkv)
        # 将q,k,v矩阵分头(multi head)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
        # attention计算
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        # 与V矩阵相乘
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33

对比原始transformer的多头注意力机制:

class MultiHeadAttention(nn.Module):
    def __init__(self):
        super(MultiHeadAttention, self).__init__()
        # 输入进来的QKV是相等的,使用映射Linear做一个映射分别得到参数矩阵Wq, Wk,Wv
        self.W_Q = nn.Linear(d_model, d_k * n_heads)
        self.W_K = nn.Linear(d_model, d_k * n_heads)
        self.W_V = nn.Linear(d_model, d_v * n_heads)
        self.linear = nn.Linear(n_heads * d_v, d_model)
        self.layer_norm = nn.LayerNorm(d_model)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

transformer

遵循论文中架构,堆叠L个Encoder。
transformer encoder

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])

        # 堆叠depth个encoder
        for _ in range(depth):
            # 根据论文结构中搭建
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17

ViT

ViT实现图片分割为patch、patch展平并添加位置编码,同时映射为输入embedding,并对各个模块进行组装,代码见下方完整代码部分。

完整代码

import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# classes

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        # 对tensor张量分块 x :1 197 1024
        # 通过to_qkv操作将维度提升至原维度的3倍
        # qkv 最后是一个元组,tuple,长度是3,每个元素形状:1 197 1024
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        print('qkv is ', qkv)
        # 将q,k,v矩阵分头(multi head)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
        # attention计算
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        # 与V矩阵相乘
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])

        # 堆叠depth个encoder
        for _ in range(depth):
            # 根据论文结构中搭建
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()

        # 原图像的大小
        image_height, image_width = pair(image_size) ## 224*224
        # patch的大小
        patch_height, patch_width = pair(patch_size)## 16 * 16

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        # 对应论文中提到的patch数目:num_patches=H*W/P^2
        num_patches = (image_height // patch_height) * (image_width // patch_width)
        # 对应论文中,将patch展平
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        # 将拉平后的patch映射为Encoder需要的维度dim
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.Linear(patch_dim, dim),
        )

        # 生成位置编码,包括cls token和所有patch对应token的位置编码
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        # 生成cls token的初始化参数
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        # 完成最后的分类
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        # img 1 3 224 224 ——> 输出形状x : 1 196 1024
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        # 复制cls符号使每个batch_size都有一个cls符号
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        # 拼接cls_token与patch embedding
        x = torch.cat((cls_tokens, x), dim=1)
        # 拼接后每个token加上位置信息
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        # 若pool为mean,所有token输出池化;若为cls符号,取cls符号(切片第0个元素)
        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        # 全连接分类
        return self.mlp_head(x)



v = ViT(
    image_size = 224,
    patch_size = 16,
    num_classes = 1000,
    dim = 1024,
    depth = 6, # Encoder堆叠个数
    heads = 16, # multi head的head数目
    mlp_dim = 2048, # feed forward维度
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 224, 224)

preds = v(img) # (1, 1000)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166

这里使用了einops库帮助解决张量的运算,提升张量运算的可读性,具体可参看以下文章或者官方文档来学习使用。

博文:einops 理解
官方文档:einops GitHub

这里只是一个简单的ViT模型帮助理解,实际训练的ViT要更复杂,可以参看此代码
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py

本文参考:
Vision Transformer详解

文章知识点与官方知识档案匹配,可进一步学习相关知识
Python入门技能树人工智能深度学习394431 人正在系统学习中
注:本文转载自blog.csdn.net的yizhi_hao的文章"https://blog.csdn.net/qq_41533576/article/details/121107247"。版权归原作者所有,此博客不拥有其著作权,亦不承担相应法律责任。如有侵权,请联系我们删除。
复制链接
复制链接
相关推荐
发表评论
登录后才能发表评论和回复 注册

/ 登录

评论记录:

未查询到任何数据!
回复评论:

分类栏目

后端 (14832) 前端 (14280) 移动开发 (3760) 编程语言 (3851) Java (3904) Python (3298) 人工智能 (10119) AIGC (2810) 大数据 (3499) 数据库 (3945) 数据结构与算法 (3757) 音视频 (2669) 云原生 (3145) 云平台 (2965) 前沿技术 (2993) 开源 (2160) 小程序 (2860) 运维 (2533) 服务器 (2698) 操作系统 (2325) 硬件开发 (2492) 嵌入式 (2955) 微软技术 (2769) 软件工程 (2056) 测试 (2865) 网络空间安全 (2948) 网络与通信 (2797) 用户体验设计 (2592) 学习和成长 (2593) 搜索 (2744) 开发工具 (7108) 游戏 (2829) HarmonyOS (2935) 区块链 (2782) 数学 (3112) 3C硬件 (2759) 资讯 (2909) Android (4709) iOS (1850) 代码人生 (3043) 阅读 (2841)

热门文章

101
推荐
关于我们 隐私政策 免责声明 联系我们
Copyright © 2020-2024 蚁人论坛 (iYenn.com) All Rights Reserved.
Scroll to Top