首页 最新 热门 推荐

  • 首页
  • 最新
  • 热门
  • 推荐

复现DeepSeek V3——在V3官方代码库对MoE、MLA的推理代码之外,补充我对多token预测MTP训练代码的实现

  • 25-02-15 07:21
  • 4233
  • 9080
blog.csdn.net

前言

虽然我司从23年起,便逐步从教育为主转型到了科技为主,但不代表教育业务便没有了

随着DeepSeek特别是R1、其次V3模型的大火,我司七月在线的大模型线上营群里一学员朋友DIFY问道:校长好,deepseek 的课程目前有多少内容啦,我想要参与学习,想请问一下关于v3和r1复现的课程有吗,不用那么大参数量,小尺寸就好

实话讲,我一开始确实没咋重点考虑R1和V3复现的问题,一来,想着毕竟人家开源了,二来,即便有诸如Open R1这种复现,但效果和原装的相比还是差太多

但后来有三点改变了我的看法

  1. 对于V3、R1都没有开源他们最核心的训练数据、训练代码
    比如V3只是开源了模型权重、模型结构和推理脚本——比如本文前两个部分重点分析的作为推理时实例化模型用的model.py,它的整个文件 中的代码,都只是推理代码
  2. 虽然Open-R1 只是复现了R1正式版的前两个阶段(如此文所述,R1正式版 有4个阶段)
    虽然效果上 不会太好「所以之前没咋关注 因为对于作商用项目的我司来讲,其落地潜力有限」
    但毕竟只是一个从零开始的开源小项目 也没法要求太高,所以放到课程中 还是有一定的科研价值的
  3. 如此,综上可得,或如DIFY所说

加之,我已经 把deepseek各个模型的原理 写透彻了,接下来,确实准备抠下他们已经对外开源的部分代码,然后再带头组织我司部分同事及相关朋友,填补一下无论是V3、R1还是Open R1缺失的代码与流程

以上种种,使得本文来了

  1. 本文先做第一步:在V3官方代码库MoE、MLA的推理实现之外,补充我个人对多token预测MTP的训练实现(过程中AI打了30%的辅助)
  2. 下一篇在V3的基础上基于Open R1复现正式版的R1

最后,我特别强调一下,如果对deepseek各类模型及各类算法还不熟悉的话,强烈建议先看对应的原理:《火爆全球的DeepSeek系列模型》,可以看到

  1. 24年1.5日,DeepSeek LLM发布,没太多创新
    类似llama那一套「llama1的RoPE/RMSNorm/SwiGLU + llama2 70B或llama3的GQA」
  2. 24年1.11日,DeepSeekMoE,开启创新之路
    提出细粒度专家分割和共享专家隔离,以及一系列负载均衡
  3. 24年1.25,发布DeepSeek-Coder
    24年2月,发布DeepSeekMath
    提出了Group Relative Policy Optimization(简称GRPO),以替代PPO——舍弃critic模型
  4. 24年5.7日,DeepSeek-V2
    提出多头潜在注意力MLA且改进MoE
    其中的这个MLA是整个deepseek系列最大的几个创新之一,且由此引发了各大厂商百万token的大幅降价
  5. 24年12.26日,DeepSeek-V3发布
    在MoE、GRPO、MLA基础上提出Multi-Token预测,且含FP8训练
    大家纷纷把它和Llama 3.1 405B对比,V3以极低的训练成本造就超强的效果,再度出圈
  6. 25年1.20日,DeepSeek R1发布
    一方面,提出舍弃SFT、纯RL训练大模型的范式,且效果不错
    二方面,性能比肩o1甚至略微超越之
    三方面,直接公布思维链且免费,不藏着掖着,相比o1,对用户极度友好

    至此爆了,火爆全球

总之,原理熟悉之后,再看本文的源码实现,事半功倍——当然,我相信还是有「一帮」朋友就想直接看本文,所以我也在本文中会介绍部分原理,以尽可能让「这帮」朋友可以硬着头皮读下去

第一部分 V3对DeepSeekMoE的推理实现:涉及RoPE、MoE层、Norm层

通过此文《一文通透让Meta恐慌的DeepSeek-V3:在MoE、GRPO、MLA基础上提出Multi-Token预测(含FP8训练详解)》可知,在模型的架构层面,V3主要就在MoE、GRPO、MLA的基础上提出了Multi-Token预测

故先看V3对MoE的实现

根据MoE的结构可知,需要实现Norm层、attention层、MoE层,考虑到V3中的attention是潜在多头注意力——即MLA类实现了多头注意力层,支持低秩查询投影和键值投影,并根据配置选项选择不同的注意力实现,故放到下一部分中介绍(下图来源于Switch Transformers)

在本第一部分中,我们结合V3代码库中的model.py看下这几个部分的实现

  • precompute_freqs_cis函数预计算了用于旋转位置嵌入的频率复数指数值
  • apply_rotary_emb函数将旋转位置嵌入应用于输入张量
  • MLP类实现了一个多层感知机,用于前馈网络层
  • Gate类实现了一个门控机制,用于在专家模型中路由输入
  • Expert类实现了专家模型中的专家层
  • MoE类实现了专家模型模块,包含多个专家和一个共享专家
  • RMSNorm类实现了均方根层归一化,用于对输入张量进行归一化处理
  • Block类实现了Transformer块,结合了注意力层和前馈网络层

1.1 RoPE的推理实现

model.py中,关于RoPE的实现涉及以下两个函数

  • precompute_freqs_cis函数预计算了用于旋转位置嵌入的频率复数指数值
  • apply_rotary_emb函数将旋转位置嵌入应用于输入张量

关于RoPE的更多细节,详见此文《一文通透位置编码:从标准位置编码、旋转位置编码RoPE到ALiBi、LLaMA 2 Long(含NTK-aware简介)》

1.1.1 precompute_freqs_cis函数

precompute_freqs_cis函数用于预计算旋转位置嵌入的基于频率的复数指数值。该函数接收一个ModelArgs类型的参数args,其中包含了位置嵌入的相关参数。函数返回一个预计算的复数指数值的张量,用于位置嵌入

  1. def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
  2. """
  3. 预计算用于旋转位置嵌入的基于频率的复数指数值。
  4. 参数:
  5. args (ModelArgs): 包含位置嵌入参数的模型参数。
  6. 返回:
  7. torch.Tensor: 预计算的用于位置嵌入的复数指数值。
  8. """

函数首先从args中提取相关参数,包括嵌入维度dim、最大序列长度seqlen、快速和慢速beta修正因子beta_fast和beta_slow、基数base和缩放因子factor

  1. dim = args.qk_rope_head_dim # 获取查询键旋转嵌入的维度
  2. seqlen = args.max_seq_len # 获取最大序列长度
  3. beta_fast = args.beta_fast # 获取快速beta修正因子
  4. beta_slow = args.beta_slow # 获取慢速beta修正因子
  5. base = args.rope_theta # 获取旋转位置编码的基数
  6. factor = args.rope_factor # 获取扩展序列长度的缩放因子

接着,定义了三个辅助函数:find_correction_dim、find_correction_range和linear_ramp_factor

  1. find_correction_dim函数计算旋转位置嵌入中给定旋转次数的修正维度
    它使用输入参数计算修正维度,并返回该值
    1. def find_correction_dim(num_rotations, dim, base, max_seq_len):
    2. """
    3. 计算旋转位置嵌入中给定旋转次数的修正维度。
    4. 参数:
    5. num_rotations (float): 要计算修正的旋转次数
    6. dim (int): 嵌入空间的维度
    7. base (float): 指数计算的基数
    8. max_seq_len (int): 最大序列长度
    9. 返回:
    10. float: 基于输入参数的修正维度
    11. """
    12. return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base)) # 计算修正维度
  2. find_correction_range函数计算旋转位置嵌入的修正维度范围
    它接收旋转次数的上下界、嵌入维度、基数和最大序列长度作为参数,返回修正维度的范围
    1. def find_correction_range(low_rot, high_rot, dim, base, max_seq_len):
    2. """
    3. 计算旋转位置嵌入的修正维度范围
    4. 参数:
    5. low_rot (float): 旋转次数的下界
    6. high_rot (float): 旋转次数的上界
    7. dim (int): 嵌入空间的维度
    8. base (float): 指数计算的基数
    9. max_seq_len (int): 最大序列长度
    10. 返回:
    11. Tuple[int, int]: 修正维度的范围(低,高),并限制在有效索引范围内
    12. """
    13. low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len)) # 计算低修正维度
    14. high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len)) # 计算高修正维度
    15. return max(low, 0), min(high, dim-1) # 返回修正维度范围
  3. linear_ramp_factor函数计算用于在最小值和最大值之间平滑值的线性斜坡函数
    它返回一个张量,该张量的值在0和1之间线性插值,并限制在[0, 1]范围内
    1. def linear_ramp_factor(min, max, dim):
    2. """
    3. 计算用于在最小值和最大值之间平滑值的线性斜坡函数
    4. 参数:
    5. min (float): 斜坡函数的最小值
    6. max (float): 斜坡函数的最大值
    7. dim (int): 斜坡张量的维度
    8. 返回:
    9. torch.Tensor: 形状为(dim,)的张量,值在0和1之间线性插值,并限制在[0, 1]范围内。
    10. """
    11. if min == max: # 如果最小值等于最大值
    12. max += 0.001 # 增加最大值以避免除零错误
    13. linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) # 计算线性函数
    14. ramp_func = torch.clamp(linear_func, 0, 1) # 限制线性函数的值在0到1之间
    15. return ramp_func # 返回线性斜坡函数

接下来,函数计算频率值freqs,这些值是基于嵌入维度和基数的指数函数。如果序列长度大于原始序列长度,则应用修正范围和平滑因子来调整频率值

  1. # 计算频率值
  2. freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
  3. if seqlen > args.original_seq_len: # 如果序列长度大于原始序列长度
  4. low, high = find_correction_range(beta_fast, beta_slow, dim, base, args.original_seq_len) # 计算修正范围
  5. smooth = 1 - linear_ramp_factor(low, high, dim // 2) # 计算平滑因子
  6. freqs = freqs / factor * (1 - smooth) + freqs * smooth # 调整频率值

最后,函数计算时间步长t,并使用外积计算频率值的复数指数表示,返回预计算的复数指数值张量freqs_cis

  1. t = torch.arange(seqlen) # 生成时间步长
  2. freqs = torch.outer(t, freqs) # 计算频率值的外积
  3. freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # 计算频率值的复数指数表示
  4. return freqs_cis # 返回预计算的复数指数值

1.1.2 apply_rotary_emb的实现

apply_rotary_emb函数用于将旋转位置嵌入应用到输入张量x上。该函数接收两个参数:x是包含位置嵌入的输入张量,freqs_cis是预计算的复数指数值张量,用于位置嵌入

  1. def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
  2. """
  3. 将旋转位置嵌入应用于输入张量
  4. 参数:
  5. x (torch.Tensor): 包含要应用位置嵌入的输入张量
  6. freqs_cis (torch.Tensor): 预计算的用于位置嵌入的复数指数值
  7. 返回:
  8. torch.Tensor: 应用了旋转嵌入的张量
  9. """
  1. 首先,函数保存输入张量的原始数据类型dtype
        dtype = x.dtype  # 获取输入张量的数据类型
  2. 然后,将输入张量x转换为浮点类型,并重新调整其形状,使其最后一个维度的大小变为2,以便视为复数
        x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2))  # 将输入张量视为复数
  3. 接着,函数将x视为复数张量函数将freqs_cis调整形状,使其与输入张量的形状匹配。具体来说,freqs_cis的形状调整为(1, 序列长度, 1, 嵌入维度/2),以便在后续计算中进行广播
        freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))  # 调整频率值的形状
  4. 然后,函数将输入张量x与freqs_cis相乘,得到应用了旋转位置嵌入的复数张量。接着,将结果转换回实数张量,并将其形状调整为原始形状
        y = torch.view_as_real(x * freqs_cis).flatten(3)  # 计算应用旋转嵌入后的张量
  5. 最后,函数将结果张量转换回原始数据类型,并返回该张量。这样,输入张量x就应用了旋转位置嵌入
        return y.to(dtype)  # 返回转换为原始数据类型的张量

1.2 对MoE层的推理实现:包含MLP类、Gate类、Expert类、MoE类

接下来,我们来看MoE的实现

涉及如下这几个函数的实现

  • MLP类实现了一个多层感知机,用于前馈网络层
  • Gate类实现了一个门控机制,用于在专家模型中路由输入
  • Expert类实现了专家模型中的专家层
  • MoE类实现了专家模型模块,包含多个专家和一个共享专家

1.2.1 MLP类的实现——多层感知机,用于前馈层

MLP类实现了一个多层感知机(MLP),用于前馈层。该类继承自nn.Module,并包含三个线性层:w1、w2和w3。这些线性层分别用于输入到隐藏层的转换、隐藏层到输出层的转换以及特征转换

  1. class MLP(nn.Module):
  2. """
  3. 多层感知机(MLP),用于前馈层
  4. 属性:
  5. w1 (nn.Module): 输入到隐藏层的线性层
  6. w2 (nn.Module): 隐藏层到输出层的线性层
  7. w3 (nn.Module): 额外的特征转换线性层
  8. """
  1. 在初始化方法__init__中
    MLP类接收两个参数:dim表示输入和输出的维度,inter_dim表示隐藏层的维度
    1. def __init__(self, dim: int, inter_dim: int):
    2. """
    3. 初始化MLP层。
    4. 参数
    5. dim (int): 输入和输出的维度
    6. inter_dim (int): 隐藏层的维度
    7. """
    w1和w3是列并行线性层(ColumnParallelLinear),用于将输入维度转换为隐藏层维度
    w2是行并行线性层(RowParallelLinear),用于将隐藏层维度转换回输入维度
    1. self.w1 = ColumnParallelLinear(dim, inter_dim) # 定义输入到隐藏层的列并行线性层
    2. self.w2 = RowParallelLinear(inter_dim, dim) # 定义隐藏层到输出层的行并行线性层
    3. self.w3 = ColumnParallelLinear(dim, inter_dim) # 定义额外的特征转换列并行线性层

1.2.2 门控网络Gate类的实现——输入路由的门控机制

Gate类实现了一个用于混合专家(MoE)模型中的输入路由的门控机制

一般就两个计算公式

类似此文《一文速览DeepSeekMoE:从Mixtral 8x7B到DeepSeekMoE(含DeepSeek LLM的简介)》所述,如果每个token选择2个专家,则门控网络的权重矩阵计算对应2个专家的权重,比如w1,w2,然后做softmax,最后与2个专家的输出expert1、expert做加权求和


类似
softmax(X × w1) × expert1 + softmax(X× w2) × expert2

该类继承自nn.Module,并包含多个属性

  1. class Gate(nn.Module):
  2. """
  3. 混合专家(MoE)模型中用于路由输入的门控机制。
  4. 属性:
  5. dim (int): 输入特征的维度
  6. topk (int): 每个输入激活的顶级专家数量
  7. n_groups (int): 路由组的数量
  8. topk_groups (int): 路由输入的组数
  9. score_func (str): 评分函数('softmax'或'sigmoid')
  10. route_scale (float): 路由权重的缩放因子
  11. weight (torch.nn.Parameter): 门控机制的可学习权重
  12. bias (Optional[torch.nn.Parameter]): 门控机制的可选偏置项
  13. """
  1. 在初始化方法__init__中,Gate类接收一个ModelArgs类型的参数args,其中包含了门控机制的参数
    1. def __init__(self, args: ModelArgs):
    2. """
    3. 初始化门控模块。
    4. 参数:
    5. args (ModelArgs): 包含门控参数的模型参数。
    6. """
    7. super().__init__() # 调用父类的初始化方法
    8. self.dim = args.dim # 设置输入特征的维度
    9. self.topk = args.n_activated_experts # 设置每个输入激活的顶级专家数量
    10. self.n_groups = args.n_expert_groups # 设置路由组的数量
    11. self.topk_groups = args.n_limited_groups # 设置路由输入的组数
    12. self.score_func = args.score_func # 设置评分函数
    13. self.route_scale = args.route_scale # 设置路由权重的缩放因子
    14. self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim)) # 初始化可学习权重
    15. self.bias = nn.Parameter(torch.empty(args.n_routed_experts)) if self.dim == 7168 else None # 初始化可选偏置项
    根据这些参数,类初始化了各个属性,并创建了权重和偏置项的量
  2. 在前向传播方法forward中,Gate类接收一个输入张量x
    1. def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    2. """
    3. 门控机制的前向传播。
    4. 参数:
    5. x (torch.Tensor): 输入张量。
    6. 返回:
    7. Tuple[torch.Tensor, torch.Tensor]: 路由权重和选择的专家索引。
    8. """
    首先,输入张量通过线性变换函数linear与权重weight相乘,得到评分`score`
            scores = linear(x, self.weight)  # 计算输入张量与权重的线性变换,得到评分
    根据评分函数score_func的不同,评分可以通过softmax或sigmoid函数进行归一化
    1. if self.score_func == "softmax": # 如果评分函数是softmax
    2. scores = scores.softmax(dim=-1, dtype=torch.float32) # 对评分进行softmax归一化
    3. else:
    4. scores = scores.sigmoid() # 对评分进行sigmoid归一化
    然后,如果存在偏置项bias,则将其加到评分上
    1. original_scores = scores # 保存原始评分
    2. if self.bias is not None: # 如果存在偏置项
    3. scores = scores + self.bias # 将偏置项加到评分上
    接下来,如果路由组的数量n_groups大于1,评分将被重新调整形状,并计算每组的最大评分或前两个评分的和
    1. if self.n_groups > 1: # 如果路由组的数量大于1
    2. scores = scores.view(x.size(0), self.n_groups, -1) # 调整评分的形状
    3. if self.bias is None: # 如果没有偏置项
    4. group_scores = scores.amax(dim=-1) # 计算每组的最大评分
    5. else:
    6. group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1) # 计算每组前两个评分的和
    然后,选择顶级组的索引,并创建一个掩码,将评分与掩码相乘并展平
    1. indices = group_scores.topk(self.topk_groups, dim=-1)[1] # 选择顶级组的索引
    2. mask = torch.zeros_like(scores[..., 0]).scatter_(1, indices, True) # 创建掩码
    3. scores = (scores * mask.unsqueeze(-1)).flatten(1) # 将评分与掩码相乘并展平

1.2.3 Expert类的实现:MoE模型中的专家层

Expert类实现了混合专家(MoE)模型中的专家层。该类继承自nn.Module,并包含三个线性层:w1、w2和w3。这些线性层分别用于输入到隐藏层的转换、隐藏层到输出层的转换以及特征转换。

  1. class Expert(nn.Module):
  2. """
  3. 混合专家(MoE)模型中的专家层
  4. 属性:
  5. w1 (nn.Module): 输入到隐藏层的线性层
  6. w2 (nn.Module): 隐藏层到输出层的线性层
  7. w3 (nn.Module): 额外的特征转换线性层
  8. """
  1. 在初始化方法__init__中,Expert类接收两个参数:dim表示输入和输出的维度,inter_dim表示隐藏层的维度
    1. def __init__(self, dim: int, inter_dim: int):
    2. """
    3. 初始化专家层。
    4. 参数:
    5. dim (int): 输入和输出的维度
    6. inter_dim (int): 隐藏层的维度
    7. """
    8. super().__init__() # 调用父类的初始化方法
    w1是一个线性层,用于将输入维度转换为隐藏层维度
            self.w1 = Linear(dim, inter_dim)  # 定义输入到隐藏层的线性层
    w2是另一个线性层,用于将隐藏层维度转换回输入维度
            self.w2 = Linear(inter_dim, dim)  # 定义隐藏层到输出层的线性层
    w3是一个额外的线性层,用于特征转换
            self.w3 = Linear(dim, inter_dim)  # 定义额外的特征转换线性层
  2. 在前向传播方法forward中,Expert类接收一个输入张量x
    1. def forward(self, x: torch.Tensor) -> torch.Tensor:
    2. """
    3. 专家层的前向传播。
    4. 参数:
    5. x (torch.Tensor): 输入张量
    6. 返回:
    7. torch.Tensor: 经过专家层计算后的输出张量
    8. """
    首先,输入张量通过w1线性层,并应用SiLU激活函数(F.silu)
    然后,结果与通过w3线性层的输入张量相乘
    最后,乘积通过w2线性层,得到输出张量
    1. # 计算前向传播,应用SiLU激活函数并进行特征转换
    2. return self.w2(F.silu(self.w1(x)) * self.w3(x))

1.2.4 MoE类:实现了专家模型模块,包含多个专家和一个共享专家

首先,关于什么是共享专家,可以详见此文 《一文速览DeepSeekMoE:从Mixtral 8x7B到DeepSeekMoE(含DeepSeek LLM的简介)》所述

其次,我们来看V3代码库里的model.py中对这一部分的实现

  1. 首先定义MoE类
    1. class MoE(nn.Module):
    2. """
    3. 混合专家(MoE)模块。
    4. 属性:
    5. dim (int): 输入特征的维度。
    6. n_routed_experts (int): 模型中的专家总数。
    7. n_local_experts (int): 分布式系统中本地处理的专家数量。
    8. n_activated_experts (int): 每个输入激活的专家数量。
    9. gate (nn.Module): 用于将输入路由到专家的门控机制。
    10. experts (nn.ModuleList): 专家模块列表。
    11. shared_experts (nn.Module): 应用于所有输入的共享专家。
    12. """
  2. 其次,初始化MoE模块
    在初始化方法__init__中,MoE类接收一个ModelArgs类型的参数args,其中包含了MoE模块的参数
    1. def __init__(self, args: ModelArgs):
    2. """
    3. 初始化MoE模块。
    4. 参数:
    5. args (ModelArgs): 包含MoE参数的模型参数
    6. """
    首先,类初始化了各个属性,并断言专家总数n_routed_experts必须能被世界大小world_size整除
    1. super().__init__() # 调用父类的初始化方法
    2. self.dim = args.dim # 设置输入特征的维度
    3. assert args.n_routed_experts % world_size == 0, f"Number of experts must be divisible by world size (world_size={world_size})" # 确保专家数量可以被世界大小整除
    4. self.n_routed_experts = args.n_routed_experts # 设置模型中的专家总数
    然后,计算本地专家数量n_local_experts和专家的起始和结束索引
    1. # 计算本地处理的专家数量
    2. self.n_local_experts = args.n_routed_experts // world_size
    3. # 设置每个输入激活的专家数量
    4. self.n_activated_experts = args.n_activated_experts
    5. # 计算本地专家的起始索引
    6. self.experts_start_idx = rank * self.n_local_experts
    7. # 计算本地专家的结束索引
    8. self.experts_end_idx = self.experts_start_idx + self.n_local_experts
    接着,初始化门控机制gate,并创建专家模块列表experts和共享专家shared_experts
    1. # 初始化门控机制
    2. self.gate = Gate(args)
    3. self.experts = nn.ModuleList([Expert(args.dim, args.moe_inter_dim) if self.experts_start_idx <= i < self.experts_end_idx else None
    4. # 初始化专家模块列表
    5. for i in range(self.n_routed_experts)])
    6. # 初始化共享专家
    7. self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim)
  3. 最后,前向传播
    在前向传播方法forward中,MoE类接收一个输入张量x
    1. def forward(self, x: torch.Tensor) -> torch.Tensor:
    2. """
    3. MoE模块的前向传播。
    4. 参数:
    5. x (torch.Tensor): 输入张量。
    6. 返回:
    7. torch.Tensor: 经过专家路由和计算后的输出张量。
    8. """
    首先,将输入张量调整为二维形状,并通过门控机制gate计算路由权重和选择的专家索引
    1. shape = x.size() # 获取输入张量的形状
    2. x = x.view(-1, self.dim) # 调整输入张量的形状
    3. weights, indices = self.gate(x) # 通过门控机制计算路由权重和专家索引
    然后,初始化一个与输入张量形状相同的零张量y,并计算每个专家的计数
    1. y = torch.zeros_like(x) # 初始化输出张量
    2. counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist() # 计算每个专家的激活次数
    对于每个本地专家,如果计数不为零,则通过专家模块计算输出,并根据路由权重进行加权求和
    1. for i in range(self.experts_start_idx, self.experts_end_idx): # 遍历本地专家
    2. if counts[i] == 0: # 如果专家没有被激活
    3. continue # 跳过该专家
    4. expert = self.experts[i] # 获取专家模块
    5. idx, top = torch.where(indices == i) # 获取激活该专家的输入索引
    6. y[idx] += expert(x[idx]) * weights[idx, top, None] # 计算专家输出并加权累加到输出张量
    接着,通过共享专家shared_experts计算额外的输出z。如果世界大小world_size大于1,则对输出张量y进行全归约操作
    1. z = self.shared_experts(x) # 计算共享专家的输出
    2. if world_size > 1: # 如果是分布式系统
    3. dist.all_reduce(y) # 聚合所有进程的输出
    最后,将输出张量y和z相加,并调整回原始形状,返回最终输出
            return (y + z).view(shape)  # 返回专家输出和共享专家输出的和,并调整回原始形状

总结一下,这种设计的三个好处是

  1. 分布式效率:每个进程只负责部分专家的计算,使用all_reduce实现结果同步
  2. 负载均衡:通过门控机制动态分配计算任务,确保计算资源的高效利用
  3. 内存优化:使用`None`占位未分配的专家,按需计算,跳过未使用的专家

1.3 Norm层的推理实现:RMSNorm

推理脚本中 还有关于均方根层归一化(RMSNorm)的推理实现

  1. 首先,定义RMSNorm类
    1. class RMSNorm(nn.Module):
    2. """
    3. 均方根层归一化(RMSNorm)。
    4. 参数:
    5. dim (int): 输入张量的维度。
    6. eps (float): 用于数值稳定性的epsilon值,默认为1e-6。
    7. """
  2. 其次,定义__init__方法
    1. def __init__(self, dim: int, eps: float = 1e-6):
    2. # 调用父类的初始化方法
    3. super().__init__()
    4. # 设置输入张量的维度
    5. self.dim = dim
    6. # 设置用于数值稳定性的epsilon值
    7. self.eps = eps
    8. # 初始化权重参数,初始值为全1
    9. self.weight = nn.Parameter(torch.ones(dim))
  3. 最后,定义forward方法
    1. def forward(self, x: torch.Tensor):
    2. """
    3. RMSNorm的前向传播
    4. 参数:
    5. x (torch.Tensor): 输入张量
    6. 返回:
    7. torch.Tensor: 归一化后的张量,形状与输入相同
    8. """
    9. # 调用F.rms_norm函数进行归一化处理
    10. return F.rms_norm(x, (self.dim,), self.weight, self.eps)

第二部分 V3对多头潜在注意力MLA的推理代码实现

2.1 对多头潜在注意力MLA原理的回顾

关于对MLA原理的介绍,我已经在这篇《一文通透DeepSeek V2——通俗理解多头潜在注意力MLA:改进MHA,从而压缩KV缓存,提高推理速度》文章中做了详尽、深入、细致的解读

这篇针对MLA的解读,我花了很大的心思、精力,建议好好看看,当你反复琢磨我解读的该文及其中的MLA后,也可以和我一样:脱离v2论文,手绘其图、手推其图背后的公式

2.2 对MLA推理代码的逐行分析

这段代码实现了一个多头注意力层(Multi-Headed Attention Layer, MLA),用于处理输入特征并生成注意力权重

2.2.1 初始化方法__init__的实现

在初始化方法__init__中,类接收一个ModelArgs类型的参数args,其中包含了MLA模块的参数

  1. def __init__(self, args: ModelArgs):
  2. super().__init__() # 调用父类的初始化方法
  3. self.dim = args.dim # 设置输入特征的维度
  4. self.n_heads = args.n_heads # 设置注意力头的数量
  5. self.n_local_heads = args.n_heads // world_size # 计算本地处理的注意力头数量
  6. self.q_lora_rank = args.q_lora_rank # 设置低秩查询投影的秩
  7. self.kv_lora_rank = args.kv_lora_rank # 设置低秩键值投影的秩
  8. # 设置无位置嵌入的查询键投影的维度
  9. self.qk_nope_head_dim = args.qk_nope_head_dim
  10. # 设置旋转位置嵌入的查询键投影的维度
  11. self.qk_rope_head_dim = args.qk_rope_head_dim
  12. # 计算查询键投影的总维度
  13. self.qk_head_dim = args.qk_nope_head_dim + args.qk_rope_head_dim
  14. # 设置值投影的维度
  15. self.v_head_dim = args.v_head_dim

接下来分别是查询投影、键值投影、输出投影、softmax缩放因子、缓存的初始化

  1. 查询投影
    根据self.q_lora_rank的值选择不同的查询投影实现

    这里得解释一下,论文中明明说的要对查询向量做低秩,因为可以降低计算成本,但在具体实现的时候,为何V3官方代码库还允许对查询向量不做低秩呢?
    原因很简单,即凡事有利有弊,做低秩的好处是降低计算成本,但不太好的是没法保留更多的特征信息,当然 实际情况一般还是会选择做低秩,毕竟降低成本带来的好处更有用


    故才有
    \rightarrow  如果self.q_lora_rank为0,则使用ColumnParallelLinear进行查询投影,初始化self.wq
    1. if self.q_lora_rank == 0:
    2. # 初始化列并行查询投影层
    3. self.wq = ColumnParallelLinear(self.dim, self.n_heads * self.qk_head_dim)
    \rightarrow  否则,先通过Linear进行低秩查询投影,初始化self.wq_a,再通过RMSNorm进行归一化,初始化self.q_norm
    1. else:
    2. # 初始化低秩查询投影层
    3. self.wq_a = Linear(self.dim, self.q_lora_rank)
    4. # 初始化查询投影的归一化层
    5. self.q_norm = RMSNorm(self.q_lora_rank)
          最后通过ColumnParallelLinear进行查询投影,初始化self.wq_b
    1. # 初始化列并行查询投影层
    2. self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.qk_head_dim)
  2. 键值投影
    先后通过Linear进行键值投影,初始化self.wkv_a,然后通过RMSNorm进行键值投影归一化,初始化self.kv_norm,最后通过ColumnParallelLinear进行键值投影,初始化self.wkv_b
    1. # 初始化键值投影层
    2. self.wkv_a = Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)
    3. # 初始化键值投影的归一化层
    4. self.kv_norm = RMSNorm(self.kv_lora_rank)
    5. # 初始化列并行键值投影层
    6. self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))
  3. 输出投影
    通过RowParallelLinear进行输出投影,初始化self.wo
    1. # 初始化行并行输出投影层
    2. self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim)
  4. Softmax缩放因子
    计算Softmax的缩放因子,初始化self.softmax_scale
    如果最大序列长度大于原始序列长度,则调整缩放因子
    1. # 计算softmax的缩放因子
    2. self.softmax_scale = self.qk_head_dim ** -0.5
    3. if args.max_seq_len > args.original_seq_len:
    4. # 计算缩放因子
    5. mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
    6. # 调整softmax的缩放因子
    7. self.softmax_scale = self.softmax_scale * mscale * mscale
  5. 缓存初始化
    根据注意力实现类型(attn_impl),选择不同的缓存策略
    如果使用`naive`实现,则初始化键缓存self.k_cache和值缓存self.v_cache——本质就是直接缓存健和值的中间结果
    1. if attn_impl == "naive":
    2. # 初始化键缓存
    3. self.register_buffer("k_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.qk_head_dim), persistent=False)
    4. # 初始化值缓存
    5. self.register_buffer("v_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.v_head_dim), persistent=False)
    否则,初始化键值缓存self.kv_cache和位置嵌入缓存self.pe_cache——本质是对健值进行了低秩投影优化
    1. else:
    2. # 初始化键值缓存
    3. self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), persistent=False)
    4. # 初始化位置嵌入缓存
    5. self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False)

总之,MLA这套初始化的设计,可以

  1. 通过列并行和行并行的线性层,实现分布式计算。
  2. 支持低秩查询投影和键值投影,适应不同的模型配置
  3. 根据注意力实现类型,选择不同的缓存策略,减少内存占用

2.2.2 前向传播方法forward方法的实现

在前向传播方法forward中,其接收输入张量,并通过一系列计算生成输出张量

  1. def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
  2. """
  3. Multi-Headed Attention Layer (MLA) 的前向传播
  4. 参数:
  5. x (torch.Tensor): 输入张量,形状为 (batch_size, seq_len, dim)
  6. start_pos (int): 序列中用于缓存的起始位置
  7. freqs_cis (torch.Tensor): 预计算的旋转位置嵌入的复数指数值
  8. mask (Optional[torch.Tensor]): 可选的掩码张量,用于排除某些位置的注意力计算
  9. 返回:
  10. torch.Tensor: 输出张量,形状与输入相同

以下是对这段代码的详细解读:

  1. 输入张量的形状
    获取输入张量的批次大小 (bsz)、序列长度 (seqlen) 和特征维度 (_)
    计算序列的结束位置 (end_pos)
    1. # 获取输入张量的批次大小、序列长度和特征维度
    2. bsz, seqlen, _ = x.size()
    3. # 计算序列的结束位置
    4. end_pos = start_pos + seqlen
  2. 查询投影
    根据 q_lora_rank 的值选择不同的查询投影实现——至于为何这么做的原因,上文已经说明过了,故此处不再赘述
    如果 q_lora_rank为 0,则使用 wq 进行查询投影,否则,先通过 wq_a 进行低秩查询投影,再通过 q_norm 进行归一化,最后通过 wq_b 进行查询投影
    1. # 根据 q_lora_rank 的值选择不同的查询投影实现
    2. if self.q_lora_rank == 0:
    3. # 使用全秩投影
    4. q = self.wq(x)
    5. else:
    6. # 使用低秩投影
    7. q = self.wq_b(self.q_norm(self.wq_a(x)))
    将查询投影结果调整为四维张量,并拆分为无位置嵌入部分 (q_nope) 和旋转位置嵌入部分 (q_pe)
    且对其中的旋转位置嵌入部分q_pe:应用旋转位置嵌入 (apply_rotary_emb)
    1. # 将查询投影结果调整为四维张量
    2. q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
    3. # 拆分查询投影结果为无位置嵌入部分和旋转位置嵌入部分
    4. q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
    5. # 对旋转位置嵌入部分应用旋转位置嵌入
    6. q_pe = apply_rotary_emb(q_pe, freqs_cis)
  3. 键值投影
    通过 wkv_a进行键值投影,并拆分为键值部分 (kv) 和旋转位置嵌入部分 (k_pe)
    并对其中的旋转位置嵌入部分k_pe:应用旋转位置嵌入 (apply_rotary_emb)
    1. # 进行键值投影
    2. kv = self.wkv_a(x)
    3. # 拆分键值投影结果为键值部分和旋转位置嵌入部分
    4. kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
    5. # 对旋转位置嵌入部分应用旋转位置嵌入
    6. k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)
  4. 注意力计算
    根据注意力实现类型 (attn_impl),选择不同的注意力计算方法
    \rightarrow  如果使用 `naive` 实现:
            将查询的无位置嵌入部分和旋转位置嵌入部分拼接
            通过 wkv_b进行键值投影归一化
            将键值投影结果调整为四维张量,并拆分为键值部分 (k_nope) 和值部分 (v)
            将键值部分和旋转位置嵌入部分拼接,并缓存键值和值
           计算查询和键值的点积,得到注意力得分 (scores)
    1. # 根据注意力实现类型选择不同的注意力计算方法
    2. if attn_impl == "naive":
    3. # 将查询的无位置嵌入部分和旋转位置嵌入部分拼接
    4. q = torch.cat([q_nope, q_pe], dim=-1)
    5. # 进行键值投影归一化
    6. kv = self.wkv_b(self.kv_norm(kv))
    7. # 将键值投影结果调整为四维张量
    8. kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
    9. # 拆分键值投影结果为键值部分和值部分
    10. k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
    11. # 将键值部分和旋转位置嵌入部分拼接
    12. k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
    13. # 缓存键和值
    14. self.k_cache[:bsz, start_pos:end_pos] = k
    15. self.v_cache[:bsz, start_pos:end_pos] = v
    16. # 计算查询和键的点积,得到注意力得分
    17. scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) * self.softmax_scale
    \rightarrow  否则:
            对键值投影结果进行权重反量化,并调整为三维张量
            计算查询和键值的点积,得到注意力得分 (scores)
    1. else:
    2. # 对键值投影结果进行权重反量化
    3. wkv_b = self.wkv_b.weight if self.wkv_b.scale is None else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size)
    4. # 调整为三维张量
    5. wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
    6. # 计算查询和键的点积
    7. q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
    8. # 缓存键值
    9. self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv)
    10. # 缓存位置嵌入
    11. self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
    12. # 计算注意力得分
    13. scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
    14. torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale
  5. 掩码应用
    如果存在掩码张量,则将其加到注意力得分上
    1. # 如果存在掩码张量,则将其加到注意力得分上
    2. if mask is not None:
    3. scores += mask.unsqueeze(1)
  6. 注意力权重计算
    对注意力得分应用 softmax
    1. # 对注意力得分应用softmax
    2. scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)
    然后根据注意力实现类型计算输出张量
    \rightarrow  如果使用 `naive` 实现,属于直接实现的注意力机制,计算简单,但在大规模数据上效率偏低
            计算注意力权重和值的点积,得到输出张量
    1. # 根据注意力实现类型计算输出张量
    2. if attn_impl == "naive":
    3. # 计算注意力权重和值的点积
    4. x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos])
    \rightarrow  否则:考虑优化过的注意力机制,比如低秩注意力
            计算注意力权重和键值的点积,再计算与值的点积,得到输出张量
    1. else:
    2. # 计算注意力权重和键值的点积
    3. x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
    4. # 计算与值的点积
    5. x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
  7. 输出投影
    通过 wo 进行输出投影,计算最终输出张量,并返回
    1. # 进行输出投影
    2. x = self.wo(x.flatten(2))
    3. # 返回最终输出张量
    4. return x

第三部分 我个人对多token预测MTP的训练代码实现:严格按照V3技术报告来

比较遗憾的是,V3官方代码库里 并没有对MTP技术的完整实现

  1. 如我司大模型同事阿荀所说,MTP只是属于训练期间设定的损失函数和额外结构,官方没有提供训练代码,这里边应该也意味着不提供MTP的实现
  2. meta 倒是有个mtp实现,但如此文 《一文通透让Meta恐慌的DeepSeek-V3:在MoE、GRPO、MLA基础上提出Multi-Token预测(含FP8训练详解)》的「1.2.3 多token预测:Multi-Token Prediction——显著加快模型的解码速度」的开头所说
    “受Gloeckle等人「其对应的论文为《Better & Faster Large Language Models via Multi-token Prediction》,这是由Meta团队发在ICML 2024的一篇Poster」的启发,他们为DeepSeek-V3研究并设置了一个多token预测(MTP)目标,该目标将预测范围扩展到每个位置的多个未来token”

    相当于ds的mtp实现和meta的mtp实现 有点区别

故咱们得自己来实现下,但实现的过程中要尽可能和V3官方代码库的风格一致——毕竟 我们最终希望可以实地用起来,避免只是做个示例展示而已

3.1 对多token预测MTP原理的回顾

实现之前,首先通过此文《一文通透让Meta恐慌的DeepSeek-V3:在MoE、GRPO、MLA基础上提出Multi-Token预测(含FP8训练详解)》的「1.2.3 多token预测:Multi-Token Prediction——显著加快模型的解码速度」来回顾下MTP的核心原理

3.1.1 对MTP核心原理的理解

我个人觉得啊,无论是V3技术报告中,还是Gloeckle等人(2024年)原始论文中对Multi-Token Prediction的描述对初学者都不友好,很容易看晕——就快到谁看谁晕乎的程度了,我一开始看 也晕乎了一会,为了更好的理解,我还是给大家举个例子吧

据我所知,截止到25年1.7日之前,下面这个例子在全网也是首例了,过程中还和同事阿荀做了深入的讨论/确认


比如下图所示,完整序列是t1-t7,当前主模块考虑的输入序列为t1,​t2​,t3​,t4,然后预测t5,t6,t7

由于当k = 1 时,\mathbf{h}_{i}^{k-1}指的是由主模型给出的表示,故有

对于输入token t1​,主模型生成表示 h_{1}^{0}

对于输入token t2​,主模型生成表示 h_{2}^{0}

对于输入token t3,主模型生成表示 h_{3}^{0}

对于输入token t4,主模型生成表示 h_{4}^{0}

  • 对于MTP Module 1的预测(注,是如下图第2个模块所示),k = 1
    h_{1}^{0}并t2预测t3(或者说,t2辅助h_{1}^{0}预测t3)
    h_{2}^{0}并t3预测t4(或者说,t3辅助h_{2}^{0}预测t4)
    h_{3}^{0}并t4预测t5
    h_{4}^{0}并t5预测t6

    根据公式21(记住一点,\mathbf{h}的下标 i 永远和主模型的输入下标一致,即 i 一直等于1 或2 或3 或4)
    \mathbf{h}_{i}^{\prime k}=M_{k}\left[\operatorname{RMSNorm}\left(\mathbf{h}_{i}^{k-1}\right) ; \operatorname{RMSNorm}\left(\operatorname{Emb}\left(t_{i+k}\right)\right)\right]
    可以得到各个token的输入表示
    将 t1的主模型表示 h_{1}^{0} 和 t2​ 的嵌入 Emb(t2)结合,通过公式 21 计算得到 \mathbf{h}_{1}^{\prime 1}
    将 t2的主模型表示 h_{2}^{0}​ 和 t3 的嵌入 Emb(t3)结合,通过公式 21 计算得到\mathbf{h}_{2}^{\prime 1}
    将 t3的主模型表示 h_{3}^{0} 和 t4​ 的嵌入 Emb(t4)结合,通过公式 21 计算得到 \mathbf{h}_{3}^{\prime 1}
    将 t4的主模型表示 h_{4}^{0} 和 t5​ 的嵌入 Emb(t5)结合,通过公式 21 计算得到 \mathbf{h}_{4}^{\prime 1}

    根据公式22\mathbf{h}_{1: T-k}^{k}=\operatorname{TRM}_{k}\left(\mathbf{h}_{1: T-k}^{\prime k}\right),可得,对于transformer处理
    将 \mathbf{h}_{1}^{\prime 1} 输入到 Transformer 块 TRM1 中,得到 h_{1}^{1}
    将 \mathbf{h}_{2}^{\prime 1} 输入到 Transformer 块 TRM1​ 中,得到 h_{2}^{1}
    将 \mathbf{h}_{3}^{\prime 1} 输入到 Transformer 块 TRM1 中,得到 h_{3}^{1}
    将 \mathbf{h}_{4}^{\prime 1} 输入到 Transformer 块 TRM1 中,得到 h_{4}^{1}

    根据公式23P_{i+k+1}^{k}=\operatorname{OutHead}\left(\mathbf{h}_{i}^{k}\right),可得,对于输出头预测
    将 h_{1}^{1}​ 输入到输出头 OutHead 中,得到 t3​ 的预测概率 P_{3}^{1}
    将 h_{2}^{1}​ 输入到输出头 OutHead 中,得到 t4​ 的预测概率 P_{4}^{1}
    将 h_{3}^{1} 输入到输出头 OutHead 中,得到 t5​ 的预测概率 P_{5}^{1}
    将 h_{4}^{1} 输入到输出头 OutHead 中,得到 t6​ 的预测概率 P_{6}^{1}
  •  对于MTP Module 2的预测(注,如下图第3个模块所示),k = 2
    h_{1}^{1}并t3预测t4(或者说,t3辅助h_{1}^{1}预测t4)
    h_{2}^{1}并t4预测t5
    h_{3}^{1}并t5预测t6
    h_{4}^{1}并t6预测t7

    输入表示:
    \mathbf{h}_{i}^{\prime k}=M_{k}\left[\operatorname{RMSNorm}\left(\mathbf{h}_{i}^{k-1}\right) ; \operatorname{RMSNorm}\left(\operatorname{Emb}\left(t_{i+k}\right)\right)\right]
    将  h_{1}^{1}​ 和 t3​ 的嵌入 Emb(t3) 结合,通过公式 21 计算得到 \mathbf{h}_{1}^{\prime 2}
    将  h_{2}^{1}​ 和 t4​ 的嵌入 Emb(t4) 结合,通过公式 21 计算得到 \mathbf{h}_{2}^{\prime 2}
    将  h_{3}^{1}​ 和 t5​ 的嵌入 Emb(t5) 结合,通过公式 21 计算得到 \mathbf{h}_{3}^{\prime 2}
    将  h_{4}^{1}​ 和 t6​ 的嵌入 Emb(t6) 结合,通过公式 21 计算得到 \mathbf{h}_{4}^{\prime 2}

    Transformer 处理:
    \mathbf{h}_{1: T-k}^{k}=\operatorname{TRM}_{k}\left(\mathbf{h}_{1: T-k}^{\prime k}\right)
    将 \mathbf{h}_{1}^{\prime 2} 输入到 Transformer 块 TRM2​ 中,得到 h_{1}^{2}
    将 \mathbf{h}_{2}^{\prime 2} 输入到 Transformer 块 TRM2​ 中,得到 h_{2}^{2}
    将 \mathbf{h}_{3}^{\prime 2} 输入到 Transformer 块 TRM2​ 中,得到 h_{3}^{2}
    将 \mathbf{h}_{4}^{\prime 2} 输入到 Transformer 块 TRM2​ 中,得到 h_{4}^{2}

    输出头预测:
    P_{i+k+1}^{k}=\operatorname{OutHead}\left(\mathbf{h}_{i}^{k}\right)
    将 h_{1}^{2}​ 输入到输出头 OutHead 中,得到 t4 的预测概率 P_{4}^{2}
    将 h_{2}^{2}​ 输入到输出头 OutHead 中,得到 t5 的预测概率 P_{5}^{2}
    将 h_{3}^{2}​ 输入到输出头 OutHead 中,得到 t6 的预测概率 P_{6}^{2}
    将 h_{4}^{2}​ 输入到输出头 OutHead 中,得到 t7 的预测概率 P_{7}^{2}

我们再把上面这整个过程

弄到一个统一的大表格里下,以示一目了然

主模型表示对于MTP Module 1的预测(注,是如下图第2个模块所示),k = 1对于MTP Module 2的预测(注,如下图第3个模块所示),k = 2

由于当k = 1 时,\mathbf{h}_{i}^{k-1}指的是由主模型给出的表示,故有

对于输入token t1​,主模型生成表示 h_{1}^{0}

对于输入token t2​,主模型生成表示 h_{2}^{0}

对于输入token t3,主模型生成表示 h_{3}^{0}

对于输入token t4,主模型生成表示 h_{4}^{0}

输入表示
\mathbf{h}_{i}^{\prime k}=M_{k}\left[\operatorname{RMSNorm}\left(\mathbf{h}_{i}^{k-1}\right) ; \operatorname{RMSNorm}\left(\operatorname{Emb}\left(t_{i+k}\right)\right)\right]
将 t1的主模型表示 h_{1}^{0} 和 t2​ 的嵌入 Emb(t2)结合,通过公式 21 计算得到 \mathbf{h}_{1}^{\prime 1}
将 t2的主模型表示 h_{2}^{0}​ 和 t3 的嵌入 Emb(t3)结合,通过公式 21 计算得到\mathbf{h}_{2}^{\prime 1}
将 t3的主模型表示 h_{3}^{0} 和 t4​ 的嵌入 Emb(t4)结合,通过公式 21 计算得到 \mathbf{h}_{3}^{\prime 1}
将 t4的主模型表示 h_{4}^{0} 和 t5​ 的嵌入 Emb(t5)结合,通过公式 21 计算得到 \mathbf{h}_{4}^{\prime 1}

输入表示:
\mathbf{h}_{i}^{\prime k}=M_{k}\left[\operatorname{RMSNorm}\left(\mathbf{h}_{i}^{k-1}\right) ; \operatorname{RMSNorm}\left(\operatorname{Emb}\left(t_{i+k}\right)\right)\right]
将  h_{1}^{1}​ 和 t3​ 的嵌入 Emb(t3) 结合,通过公式 21 计算得到 \mathbf{h}_{1}^{\prime 2}
将  h_{2}^{1}​ 和 t4​ 的嵌入 Emb(t4) 结合,通过公式 21 计算得到 \mathbf{h}_{2}^{\prime 2}
将  h_{3}^{1}​ 和 t5​ 的嵌入 Emb(t5) 结合,通过公式 21 计算得到 \mathbf{h}_{3}^{\prime 2}
将  h_{4}^{1}​ 和 t6​ 的嵌入 Emb(t6) 结合,通过公式 21 计算得到 \mathbf{h}_{4}^{\prime 2}

Transformer 处理:\mathbf{h}_{1: T-k}^{k}=\operatorname{TRM}_{k}\left(\mathbf{h}_{1: T-k}^{\prime k}\right)
将 \mathbf{h}_{1}^{\prime 1} 输入到 Transformer 块 TRM1 中,得到 h_{1}^{1}
将 \mathbf{h}_{2}^{\prime 1} 输入到 Transformer 块 TRM1​ 中,得到 h_{2}^{1}
将 \mathbf{h}_{3}^{\prime 1} 输入到 Transformer 块 TRM1 中,得到 h_{3}^{1}
将 \mathbf{h}_{4}^{\prime 1} 输入到 Transformer 块 TRM1 中,得到 h_{4}^{1}

Transformer 处理:
\mathbf{h}_{1: T-k}^{k}=\operatorname{TRM}_{k}\left(\mathbf{h}_{1: T-k}^{\prime k}\right)
将 \mathbf{h}_{1}^{\prime 2} 输入到 Transformer 块 TRM2​ 中,得到 h_{1}^{2}
将 \mathbf{h}_{2}^{\prime 2} 输入到 Transformer 块 TRM2​ 中,得到 h_{2}^{2}
将 \mathbf{h}_{3}^{\prime 2} 输入到 Transformer 块 TRM2​ 中,得到 h_{3}^{2}
将 \mathbf{h}_{4}^{\prime 2} 输入到 Transformer 块 TRM2​ 中,得到 h_{4}^{2}

输出头预测:P_{i+k+1}^{k}=\operatorname{OutHead}\left(\mathbf{h}_{i}^{k}\right)

将 h_{1}^{1}​ 输入到输出头 OutHead 中,得到 t3​ 的预测概率 P_{3}^{1}
将 h_{2}^{1}​ 输入到输出头 OutHead 中,得到 t4​ 的预测概率 P_{4}^{1}
将 h_{3}^{1} 输入到输出头 OutHead 中,得到 t5​ 的预测概率 P_{5}^{1}
将 h_{4}^{1} 输入到输出头 OutHead 中,得到 t6​ 的预测概率 P_{6}^{1}

输出头预测:
P_{i+k+1}^{k}=\operatorname{OutHead}\left(\mathbf{h}_{i}^{k}\right)
将 h_{1}^{2}​ 输入到输出头 OutHead 中,得到 t4 的预测概率 P_{4}^{2}
将 h_{2}^{2}​ 输入到输出头 OutHead 中,得到 t5 的预测概率 P_{5}^{2}
将 h_{3}^{2}​ 输入到输出头 OutHead 中,得到 t6 的预测概率 P_{6}^{2}
将 h_{4}^{2}​ 输入到输出头 OutHead 中,得到 t7 的预测概率 P_{7}^{2}

3.1.2 MTP的训练目标

对于每个预测深度,他们计算一个交叉熵损失\mathcal{L}_{\mathrm{MTP}}^{k} :

\mathcal{L}_{\mathrm{MTP}}^{k}=\operatorname{CrossEntropy}\left(P_{2+k: T+1}^{k}, t_{2+k: T+1}\right)=-\frac{1}{T} \sum_{i=2+k}^{T+1} \log P_{i}^{k}\left[t_{i}\right]

其中T 表示输入序列长度,t_i表示第i 个位置的真实token,P_{i}^{k}\left[t_{i}\right]表示由第k 个MTP 模块给出的t_i 的相应预测概率

最后,他们计算所有深度上的MTP 损失的平均值,并将其乘以一个权重因子\lambda,以获得总体MTP 损失\mathcal{L}_{\mathrm{MTP}} ,这作为DeepSeek-V3 的附加训练目标

\mathcal{L}_{\mathrm{MTP}}=\frac{\lambda}{D} \sum_{k=1}^{D} \mathcal{L}_{\mathrm{MTP}}^{k}

3.2 对MTP技术的多轮实现——coding By July和AI

3.2.1 小试牛刀:先做一轮简单实现

正如R1解答用户问题之前,会先经过一轮长时间的推理/思考、拆解/分析,而这个推理/思考的过程,可以很好的帮助很多人提高分析问题、解决问题的能力

为了更好的和大家一块成长,我也没必要一上来就给大家一个完美的实现——毕竟所有的强大与伟大都不是一蹴而就的 包括2年多前的ChatGPT以及本文的R1(看本文开头便知,R1发布之前,deepseek已经经历了不少大大小小的创新)

  1. 那就先小试牛刀,先不考虑V3已有的官方代码库,先对MTP做一轮简单的实现,以让对原理有个更好的了解「当我们对原理有更好的理解,然后对V3官方代码库已有的结构有更好的研究之后,我们便能写出完美匹配官方库的实现 」
  2. 过程中有30%的部分得到了AI的辅助,相当于代码是由我个人和AI完成的

具体步骤如下

  1. 引入相关库
    1. import torch
    2. import torch.nn as nn
    3. from transformers import RMSNorm
    4. class MTPModule(nn.Module):

    先做初始化——注意,这里暂时没考虑V3的MoE架构,而是简单粗暴的先暂用标准的transformer架构,即我先故意一切从简,但下一节会修改

    1. def __init__(self, d_model, vocab_size, num_layers, n_head):
    2. super().__init__()
    3. # 预测深度
    4. self.D = num_layers
    5. // 从主模型共享嵌入、输出头
    6. self.shared_emb = nn.Embedding(vocab_size, d_model)
    7. self.shared_out = nn.Linear(d_model, vocab_size)
    8. # 初始化MTP组件
    9. self.M = nn.ModuleList([
    10. nn.Linear(2*d_model, d_model, bias=False) for _ in range(num_layers)
    11. ])
    12. self.trm_blocks = nn.ModuleList([
    13. nn.TransformerEncoderLayer(d_model, n_head) for _ in range(num_layers)
    14. ])
    15. # 使用RMSNorm
    16. self.rms_norm = RMSNorm(d_model)
  2. 然后是前向传播函数的实现

    1. def forward(self, hidden_states, token_ids):
    2. """
    3. hidden_states: [T, B, D] 主模型输出
    4. token_ids: [T, B] 输入token IDs
    5. """
    6. T, B = token_ids.shape
    7. total_loss = 0.0

    根据MTP结构图

    与公式21
    \mathbf{h}_{i}^{\prime k}=M_{k}\left[\operatorname{RMSNorm}\left(\mathbf{h}_{i}^{k-1}\right) ; \operatorname{RMSNorm}\left(\operatorname{Emb}\left(t_{i+k}\right)\right)\right]
    比如 将 t1的主模型表示 h_{1}^{0} 和 t2​ 的嵌入 Emb(t2)结合,通过公式 21 计算得到 \mathbf{h}_{1}^{\prime 1}

    可得代码应该如下编写——实现的时候,要注意,将h_i^{k-1}和Emb(t_{i+k})先各自进行RMSNorm后,再拼接

    1. for k in range(1, self.D+1):
    2. # 步骤1:组合表示(公式21)
    3. prev_hidden = self.rms_norm(hidden_states) # h_i^{k-1}
    4. next_emb = self.shared_emb(token_ids[k:]) # Emb(t_{i+k})
    5. next_emb = self.rms_norm(next_emb)
    6. combined = torch.cat([prev_hidden[:-k], next_emb], dim=-1) # [T-k, B, 2D]

    拼接之后,再做投影

                projected = self.M[k-1](combined)  # M_k投影
  3. 接着,做Transformer 处理:\mathbf{h}_{1: T-k}^{k}=\operatorname{TRM}_{k}\left(\mathbf{h}_{1: T-k}^{\prime k}\right)
    将 \mathbf{h}_{1}^{\prime 1} 输入到 Transformer 块 TRM1 中,得到 h_{1}^{1}

    1. # 步骤2:Transformer处理(公式22)
    2. trm_out = self.trm_blocks[k-1](
    3. projected.permute(1,0,2) # 转换为[B, T-k, D]
    4. ).permute(1,0,2) # 恢复[T-k, B, D]
  4.  最后,输出头预测:P_{i+k+1}^{k}=\operatorname{OutHead}\left(\mathbf{h}_{i}^{k}\right)

    1. # 步骤3:计算预测(公式23)
    2. logits = self.shared_out(trm_out) # [T-k, B, V]
  5. 损失计算
    根据V3技术报告可知,对于每个预测深度,他们计算一个交叉熵损失\mathcal{L}_{\mathrm{MTP}}^{k} (如下公式24所示)

    \mathcal{L}_{\mathrm{MTP}}^{k}=\operatorname{CrossEntropy}\left(P_{2+k: T+1}^{k}, t_{2+k: T+1}\right)=-\frac{1}{T} \sum_{i=2+k}^{T+1} \log P_{i}^{k}\left[t_{i}\right]

    其中T 表示输入序列长度,t_i表示第i 个位置的真实token,P_{i}^{k}\left[t_{i}\right]表示由第k 个MTP 模块给出的t_i 的相应预测概率
    可得

    1. # 计算损失(公式24)
    2. targets = token_ids[k+1:].reshape(-1) # 预测目标为i+k+1
    3. loss = nn.functional.cross_entropy(
    4. logits.view(-1, logits.size(-1)),
    5. targets,
    6. reduction='mean'
    7. )
    8. total_loss += loss
  6. 最后,他们计算所有深度上的MTP 损失的平均值,并将其乘以一个权重因子\lambda,以获得总体MTP 损失​ ,这作为DeepSeek-V3 的附加训练目标​

    \mathcal{L}_{\mathrm{MTP}}=\frac{\lambda}{D} \sum_{k=1}^{D} \mathcal{L}_{\mathrm{MTP}}^{k}

    相当于再做加权
    1. # 最终加权损失(公式25)
    2. return total_loss * (0.3 / self.D) # λ=0.3

3.2.2 完美融合:匹配V3官方代码库已有结构的MTP实现

根据DeepSeek-V3官方实现代码的架构风格,需要进行以下关键修改来实现无缝集成:

  1. 库的引入
    1. import torch
    2. import torch.nn as nn
    3. from deepseek_v3_modules import (
    4. DeepseekRMSNorm,
    5. MoETransformerLayer, # 使用项目中的MoE层代替标准Transformer
    6. RotaryEmbedding, # 使用项目自实现的RoPE
    7. FP8Linear # 使用项目中的FP8量化层
    8. )
  2. 初始化
    1. class MTPModule(nn.Module):
    2. def __init__(self, config):
    3. super().__init__()
    4. # 对齐项目参数命名规范
    5. self.depth = config.mtp_depth # 从config读取D值
    6. self.hidden_size = config.hidden_size
    7. # 使用项目自实现的组件 (与model.py保持一致)
    8. self.rms_norm = DeepseekRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
    9. self.rope = RotaryEmbedding(dim=self.hidden_size // config.num_attention_heads)
    10. # 与主模型共享参数——即共享嵌入、共享输出头 (参考model.py的Embedding实现)
    11. self.shared_emb = None # 将在外部绑定
    12. self.shared_out = None
    下面这里 得改动了,如上面所说的,毕竟V3是MoE架构,非标准的transformer架构
    1. # 使用项目中的MoE层 (替换原始Transformer层)
    2. self.mtp_layers = nn.ModuleList([
    3. MoETransformerLayer(
    4. config,
    5. layer_idx=layer_idx,
    6. is_mtp=True # 添加特殊标记
    7. ) for layer_idx in range(self.depth)
    8. ])
    且使用项目中的FP8线性层——以匹配原V3报告的3.3节实现
    1. # 使用项目中的FP8线性层
    2. self.proj_layers = nn.ModuleList([
    3. FP8Linear(
    4. 2 * self.hidden_size,
    5. self.hidden_size,
    6. fp8_params=config.fp8_params
    7. ) for _ in range(self.depth)
    8. ])
  3. 对于前向传播而言
    1. def forward(self, hidden_states, input_ids):
    2. """
    3. 对齐项目输入输出格式:
    4. hidden_states: [batch_size, seq_len, hidden_size]
    5. input_ids: [batch_size, seq_len]
    6. """
    7. batch_size, seq_len = input_ids.shape
    8. total_loss = 0.0
    匹配V3中model.py相关格式的前提下,先分别对hidden_states和next_emb做RMSNorm,然后应用RoPE
    1. for k in range(1, self.depth + 1):
    2. # 1. 组合表示 (适配项目维度格式)
    3. prev_hidden = self.rms_norm(hidden_states[:, :-k, :]) # [B, T-k, D]
    4. next_emb = self.shared_emb(input_ids[:, k:]) # [B, T-k, D]
    5. next_emb = self.rms_norm(next_emb)
    6. # 2. 应用RoPE (与model.py中的处理一致)
    7. prev_hidden = self.rope(prev_hidden)
    8. next_emb = self.rope(next_emb)
    然后做拼接,做完拼接做投影
    1. # 3. 先拼接,后线性投影
    2. combined = torch.cat([prev_hidden, next_emb], dim=-1) # [B, T-k, 2D]
    3. projected = self.proj_layers[k-1](combined)
    接下来
    1. # 4. 使用MoE层 (对齐项目实现)
    2. trm_out = self.mtp_layers[k-1](
    3. projected,
    4. attention_mask=None, # 假设因果掩码在外部处理
    5. position_ids=None # 与model.py中处理一致
    6. )[0]
    再其次,输出头做预测,且计算对应的损失
    1. # 5. 计算损失
    2. logits = self.shared_out(trm_out) # [B, T-k, V]
    3. targets = input_ids[:, k+1:].reshape(-1)
    4. loss = nn.functional.cross_entropy(
    5. logits.view(-1, logits.size(-1)),
    6. targets,
    7. reduction='mean'
    8. )
    9. total_loss += loss
    最后
    1. # 动态lambda处理 (匹配4.3节训练策略)
    2. lambda_weight = 0.3 if self.training else 0.0
    3. return total_loss * (lambda_weight / self.depth)

至于如何与V3官方代码库中的推理文件model.py搭配,以及如何验证是否正确(上面的实现还是有些小问题的),暂见 《DeepSeek原理与项目实战营》中,本文后续再考虑是否更新

最后我说一下,虽然AI在上述的实现中只占了30%,但确实帮我省心了,可能有的同学好奇这个AI到底是哪个模型,嗯,非常非常的不难猜到:没错,过程中我主要就用的R1——通过Google账号登录

// 待更

注:本文转载自blog.csdn.net的v_JULY_v的文章"https://blog.csdn.net/v_JULY_v/article/details/145552931"。版权归原作者所有,此博客不拥有其著作权,亦不承担相应法律责任。如有侵权,请联系我们删除。
复制链接
复制链接
相关推荐
发表评论
登录后才能发表评论和回复 注册

/ 登录

评论记录:

未查询到任何数据!
回复评论:

分类栏目

后端 (14832) 前端 (14280) 移动开发 (3760) 编程语言 (3851) Java (3904) Python (3298) 人工智能 (10119) AIGC (2810) 大数据 (3499) 数据库 (3945) 数据结构与算法 (3757) 音视频 (2669) 云原生 (3145) 云平台 (2965) 前沿技术 (2993) 开源 (2160) 小程序 (2860) 运维 (2533) 服务器 (2698) 操作系统 (2325) 硬件开发 (2492) 嵌入式 (2955) 微软技术 (2769) 软件工程 (2056) 测试 (2865) 网络空间安全 (2948) 网络与通信 (2797) 用户体验设计 (2592) 学习和成长 (2593) 搜索 (2744) 开发工具 (7108) 游戏 (2829) HarmonyOS (2935) 区块链 (2782) 数学 (3112) 3C硬件 (2759) 资讯 (2909) Android (4709) iOS (1850) 代码人生 (3043) 阅读 (2841)

热门文章

101
推荐
关于我们 隐私政策 免责声明 联系我们
Copyright © 2020-2024 蚁人论坛 (iYenn.com) All Rights Reserved.
Scroll to Top