首页 最新 热门 推荐

  • 首页
  • 最新
  • 热门
  • 推荐

Batch Normalization梯度反向传播推导

  • 25-03-03 22:02
  • 4343
  • 7208
blog.csdn.net

最近在看CS231N的课程,同时也顺带做配套的作业,在Assignment2 中关于Batch Normalization的具体数学过程则困惑了很久,通过参看一些博客自己推导了一遍,供大家参考。

Batch Normalization

首先,关于Batch Normalization的具体实现过程就不在此介绍了,想了解的可以参看论文或者博客。
对于Batch Normalization的前向传播可以参看下图的过程,它主要思路就是将每个Batch的输入根据均值μBμB 和方差2B2B 进行归一化,然后再进行尺度缩放到yiyi

对于前向传播网络,可以很直观的给出实现代码

def batchnorm_forward(x, gamma, beta, bn_param):
  """
  Input:
  - x: (N, D)维输入数据
  - gamma: (D,)维尺度变化参数 
  - beta: (D,)维尺度变化参数
  - bn_param: Dictionary with the following keys:
    - mode: 'train' 或者 'test'
    - eps: 一般取1e-8~1e-4
    - momentum: 计算均值、方差的更新参数
    - running_mean: (D,)动态变化array存储训练集的均值
    - running_var:(D,)动态变化array存储训练集的方差

  Returns a tuple of:
  - out: 输出y_i(N,D)维
  - cache: 存储反向传播所需数据
  """
  mode = bn_param['mode']
  eps = bn_param.get('eps', 1e-5)
  momentum = bn_param.get('momentum', 0.9)

  N, D = x.shape
  # 动态变量,存储训练集的均值方差
  running_mean = bn_param.get('running_mean', np.zeros(D, dtype=x.dtype))
  running_var = bn_param.get('running_var', np.zeros(D, dtype=x.dtype))

  out, cache = None, None
  # TRAIN 对每个batch操作
  if mode == 'train':
    sample_mean = np.mean(x, axis = 0)
    sample_var = np.var(x, axis = 0)
    x_hat = (x - sample_mean) / np.sqrt(sample_var + eps)
    out = gamma * x_hat + beta
    cache = (x, gamma, beta, x_hat, sample_mean, sample_var, eps)
    running_mean = momentum * running_mean + (1 - momentum) * sample_mean
    running_var = momentum * running_var + (1 - momentum) * sample_var
  # TEST:要用整个训练集的均值、方差
  elif mode == 'test':
    x_hat = (x - running_mean) / np.sqrt(running_var + eps)
    out = gamma * x_hat + beta
  else:
    raise ValueError('Invalid forward batchnorm mode "%s"' % mode)

  bn_param['running_mean'] = running_mean
  bn_param['running_var'] = running_var

  return out, cache
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47

上述代码基于CS231N Assignment2,值得注意的是Batch Normalization对于在训练和测试阶段的计算方法不一样,因为训练阶段的均值和方差是基于一个Batch的数据,而测试阶段是基于整个训练集求得。

梯度反向传播

Batch Normalization最让人头疼的就是理清楚反向传播梯度并写成代码,当然它依然遵循链式求导法则。首先我们基于上图,将变量定义如下:

  • σσ 为一个batch所有样本的方差
  • μμ 为样本均值
  • ˆxxˆ 为归一化后的样本数据
  • yi 为输入样本xi 经过尺度变化的输出量
  • γ 和β 为尺度变化系数
  • ∂L∂y 为已知,并假设x 和y 都为(N,D)维,即有N个维度为D的样本

由于网络正向传播是根据γ β 和ˆx 将xi 变换为yi ,那么反向传播则是根据∂L∂yi 求得∂L∂γ ∂L∂β 和∂L∂xi 。

∂L∂γ=∑i∂L∂yi∂yi∂γ=∑i∂L∂yiˆxi

∂L∂β=∑i∂L∂yi∂yi∂β=∑i∂L∂yi

上面两个式子都涉及到Batch中的N个样本的累加,因为N个样本的 yi 对 β γ 都有影响。

直接求∂L∂xi 步骤比较长,不直观,且μ(x) 、σ(x) 、ˆx(x) ,因此我们首先求∂L∂ˆx 、∂L∂μ 和∂L∂σ :

∂L∂ˆx=∂L∂y∂y∂ˆx=∂L∂yγ

∂L∂σ=∑i∂L∂yi∂yi∂ˆxi∂ˆxi∂σ=−12∑i∂L∂^xi(xi−μ)(σ+ε)−1.5

∂L∂μ=∂L∂ˆx∂ˆx∂μ+∂L∂σ∂σ∂μ=∑i∂L∂ˆxi−1√σ+ε+∂L∂σ−2Σi(xi−μ)N

下面,就可以求 ∂L∂xi 啦:

∂L∂xi=∂L∂^xi∂^xi∂xi+∂L∂σ∂σ∂xi+∂L∂μ∂μ∂xi=∂L∂ˆxi1√σ+ε+∂L∂σ2(xi−μ)N+∂L∂μ1N

在上面的式子中我写成∂L∂xi 而不是∂L∂x 是为了方便理解,当然在代码中我们会表示成后者以提高计算速度。至此,我们就完成了Batch Normalization的梯度反向传播的全过程,并得到论文给出的结果:

这里写图片描述

下面,我们就根据上面的步骤来完成代码:

def batchnorm_backward(dout, cache):
  """
  Inputs:
  - dout: 上一层的梯度,维度(N, D),即 dL/dy
  - cache: 所需的中间变量,来自于前向传播

  Returns a tuple of:
  - dx: (N, D)维的 dL/dx
  - dgamma: (D,)维的dL/dgamma
  - dbeta: (D,)维的dL/dbeta
  """
    x, gamma, beta, x_hat, sample_mean, sample_var, eps = cache
  N = x.shape[0]

  dgamma = np.sum(dout * x_hat, axis = 0)
  dbeta = np.sum(dout, axis = 0)

  dx_hat = dout * gamma
  dsigma = -0.5 * np.sum(dx_hat * (x - sample_mean), axis=0) * np.power(sample_var + eps, -1.5)
  dmu = -np.sum(dx_hat / np.sqrt(sample_var + eps), axis=0) - 2 * dsigma*np.sum(x-sample_mean, axis=0)/ N
  dx = dx_hat /np.sqrt(sample_var + eps) + 2.0 * dsigma * (x - sample_mean) / N + dmu / N

  return dx, dgamma, dbeta
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23

附:两个有用的博客 这里 和 这里

注:本文转载自blog.csdn.net的ycsun_的文章"https://blog.csdn.net/yuechuen/article/details/71502503"。版权归原作者所有,此博客不拥有其著作权,亦不承担相应法律责任。如有侵权,请联系我们删除。
复制链接
复制链接
相关推荐
发表评论
登录后才能发表评论和回复 注册

/ 登录

评论记录:

未查询到任何数据!
回复评论:

分类栏目

后端 (14832) 前端 (14280) 移动开发 (3760) 编程语言 (3851) Java (3904) Python (3298) 人工智能 (10119) AIGC (2810) 大数据 (3499) 数据库 (3945) 数据结构与算法 (3757) 音视频 (2669) 云原生 (3145) 云平台 (2965) 前沿技术 (2993) 开源 (2160) 小程序 (2860) 运维 (2533) 服务器 (2698) 操作系统 (2325) 硬件开发 (2491) 嵌入式 (2955) 微软技术 (2769) 软件工程 (2056) 测试 (2865) 网络空间安全 (2948) 网络与通信 (2797) 用户体验设计 (2592) 学习和成长 (2593) 搜索 (2744) 开发工具 (7108) 游戏 (2829) HarmonyOS (2935) 区块链 (2782) 数学 (3112) 3C硬件 (2759) 资讯 (2909) Android (4709) iOS (1850) 代码人生 (3043) 阅读 (2841)

热门文章

101
推荐
关于我们 隐私政策 免责声明 联系我们
Copyright © 2020-2025 蚁人论坛 (iYenn.com) All Rights Reserved.
Scroll to Top