からっぽのしょこ

読んだら書く!書いたら読む!同じ事は二度調べ(たく)ない

Batch Normレイヤの逆伝播【ゼロつく1のノート(数学)】

はじめに

 「機械学習・深層学習」学習初手『ゼロから作るDeep Learning』民のための数学攻略ノートです。『ゼロつく1』学習の補助となるように適宜解説を加えています。本と一緒に読んでください。

 NumPy関数を使って実装できてしまう計算について、数学的背景を1つずつ確認していきます。

 この記事は、主に6.3節「Batch Normalization」を補足するための内容になります。Batch Normalizationの逆伝播を導出します。

【関連する記事】

www.anarchive-beta.com

【他の記事一覧】

www.anarchive-beta.com

【この記事の内容】

Batch Normレイヤの逆伝播

・順伝播の確認

 ミニバッチサイズ$N$、1データのサイズ$M$のBatch Normレイヤの入力データを

$$ \mathbf{X} = \begin{pmatrix} x_{11} & \cdots & x_{1M} \\ \vdots & \ddots & \cdots \\ x_{N1} & \cdots & x_{NM} \end{pmatrix} $$

とします。ちなみにややこしいですが、これはAffineレイヤの出力のことなので、これまでの表記だと$(a_{11}, \cdots, a_{NM})$のことで手書き文字画像のことではありません。

 $i = 1, 2, \cdots, N$、また$k = 1, 2, \cdots, M$として、$i,k$要素に注目して考えます。Batch Normレイヤの順伝播の入力$x_{ik}$は

$$ \begin{align} \mu_k &\leftarrow \frac{1}{N} \sum_{i=1}^N x_{ik} \\ \sigma_k^2 &\leftarrow \frac{1}{N} \sum_{i=1}^N (x_{ik} - \mu_k)^2 \\ \sigma_k &\leftarrow \sqrt{\sigma_k^2 + \epsilon} \\ \hat{x}_{ik} &\leftarrow \frac{ x_{ik} - \mu_k }{ \sigma_k } \tag{6.7'} \end{align} $$

の4つの式計算によって、正規化(標準化)されます。そして

$$ y_{ik} \leftarrow \gamma_k \hat{x}_{ik} + \beta_k \tag{6.8'} $$

の式の計算によって分布を再調整します。

 従って、Batch Normレイヤの順伝播の出力は

$$ \begin{aligned} \mathbf{Y} &= \begin{pmatrix} y_{11} & \cdots & y_{1M} \\ \vdots & \ddots & \cdots \\ y_{N1} & \cdots & y_{NM} \end{pmatrix} \\ &= \begin{pmatrix} \gamma_1 \hat{x}_{11} + \beta_1 & \cdots & \gamma_M \hat{x}_{1M} + \beta_M \\ \vdots & \ddots & \cdots \\ \gamma_1 \hat{x}_{N1} + \beta_1 & \cdots & \gamma_M \hat{x}_{NM} + \beta_M \end{pmatrix} \end{aligned} $$

になります。

・逆伝播の導出

 ではBatch Normレイヤを逆伝播について考えます。逆伝播では、入力データ$x_{ik}$やハイパーパラメータ$\gamma_k,\ \beta_k$などの変数に関する出力データ$y_{ik}$の微分$\frac{\partial y_{ik}}{\partial x_{ik}},\ \frac{\partial y_{ik}}{\partial \gamma_k},\ \frac{\partial y_{ik}}{\partial \beta_k}$などを求めます。これらの微分は連鎖率(5.2節)より、各計算の入力に関する出力の微分の積で計算できるのでした。

 よってこの項では順伝播の5つ計算式を、後の計算からステップ1から7に分けてそれぞれ微分を求めていきます。また、Batch Normレイヤの逆伝播の$i,k$要素の入力を$\frac{\partial L}{\partial y_{ik}}$とします。

・ステップ1

 (後から)1つ目の計算は

$$ y_{ik} = \gamma_k \hat{x}_{ik} + \beta_k \tag{6.8'} $$

です。

 Batch Normレイヤの順伝播の出力$y_{ik}$を各変数で微分すると

$$ \begin{aligned} \frac{\partial y_{ik}}{\partial \hat{x}_{ik}} &= \gamma_k \\ \frac{\partial y_{ik}}{\partial \gamma_k} &= \hat{x}_{ik} \\ \frac{\partial y_{ik}}{\partial \beta_k} &= 1 \end{aligned} $$

となります。

 従って、Batch Normレイヤの逆伝播の入力$\frac{\partial L}{\partial y_{ik}}$との積

$$ \begin{aligned} \frac{\partial L}{\partial \hat{x}_{ik}} &= \frac{\partial L}{\partial y_{ik}} \gamma_k \end{aligned} $$

を次のノードに伝播します。

・ステップ2

 2つ目の正規化(標準化)の計算式は

$$ \begin{aligned} \hat{x}_{ik} &= \frac{ x_{ik} - \mu_k }{ \sigma_k } \\ &= (x_{ik} - \mu_k) \sigma_k^{-1} \end{aligned} $$

と変形できます。分数$\frac{1}{x}$は負の指数を使って$x^{-1}$と表せます。

 正規化されたデータ$\hat{x}_{ik}$を偏差$x_{ik} - \mu_k$と標準偏差$\sigma_k$で微分すると

$$ \begin{aligned} \frac{\partial \hat{x}_{ik}}{\partial (x_{ik} - \mu_k)} &= \frac{1}{\sigma_k} \\ \frac{\partial \hat{x}_{ik}}{\partial \sigma_k} &= - (x_{ik} - \mu_k) \sigma_k^{-2} \\ &= - \frac{ x_{ik} - \mu_k }{ \sigma_k^2 } \end{aligned} $$

となります。

 従って、ステップ1の出力$\frac{\partial L}{\partial \hat{x}_{ik}}$との積

$$ \begin{aligned} \frac{\partial L}{\partial (x_{ik} - \mu_k)} &= \frac{\partial L}{\partial y_{ik}} \gamma_k \frac{1}{\sigma_k} \end{aligned} $$

を(図6-17の上の)次のノードに

$$ \begin{aligned} \frac{\partial L}{\partial \sigma_k} &= - \frac{\partial L}{\partial y_{ik}} \gamma_k \frac{ x_{ik} - \mu_k }{ \sigma_k^2 } \end{aligned} $$

を(図6-17の下の)次のノードに伝播します。

・ステップ3

 3つ目の標準偏差の計算式は

$$ \begin{aligned} \sigma_k &= \sqrt{\sigma_k^2} \\ &= (\sigma_k^2)^{\frac{1}{2}} \end{aligned} $$

と変形できます。平方根$\sqrt{x}$は分数を指数を使って$x^{\frac{1}{2}}$と表せます。

 標準偏差$\sigma_k$を分散$\sigma_k^2$で微分すると

$$ \begin{aligned} \frac{\partial \sigma_k}{\partial \sigma_k^2} &= \frac{1}{2} (\sigma_k^2)^{-\frac{1}{2}} \\ &= \frac{ 1 }{ 2 \sqrt{\sigma_k^2} } \\ &= \frac{1}{2 \sigma_k} \end{aligned} $$

となります。

 従って、ステップ2の出力$\frac{\partial L}{\partial \sigma_k}$との積

$$ \begin{aligned} \frac{\partial L}{\partial \sigma_k^2} &= - \frac{\partial L}{\partial y_{ik}} \gamma_k \frac{ x_{ik} - \mu_k }{ 2 \sigma_k^3 } \end{aligned} $$

を次のノードに伝播します。

・ステップ4

 4つ目の分散を求める計算式は

$$ \begin{aligned} \sigma_k^2 &= \frac{1}{N} \sum_{i=1}^N (x_{ik} - \mu_k)^2 \\ &= \frac{1}{N} (x_{1k} - \mu_k)^2 + \cdots + \frac{1}{N} (x_{ik} - \mu_k)^2 + \cdots + \frac{1}{N} (x_{Nk} - \mu_k)^2 \end{aligned} $$

と分解できます。

 分散$\sigma_k^2$を偏差$x_{ik} - \mu_k$で微分すると

$$ \frac{\partial \sigma_k^2}{\partial (x_{ik} - \mu_k)} = \frac{2}{N} (x_{ik} - \mu_k) $$

となります。

 従って、ステップ3の出力$\frac{\partial L}{\partial \sigma_k^2}$との積

$$ \begin{aligned} \frac{\partial L}{\partial (x_{ik} - \mu_k)} &= - \frac{\partial L}{\partial y_{ik}} \gamma_k \frac{x_{ik} - \mu_k}{2 \sigma_k^3} \frac{2}{N} (x_{ik} - \mu_k) \\ &= - \frac{\partial L}{\partial y_{ik}} \gamma_k \frac{(x_{ik} - \mu_k)^2}{N \sigma_k^3} \end{aligned} $$

を次のノードに伝播します。

・ステップ5

 4つ目の式の偏差を求める計算

$$ x_{ik} - \mu_k $$

に注目します。

 この偏差を入力$x_{ik}$と平均$\mu_k$で微分すると

$$ \begin{aligned} \frac{\partial (x_{ik} - \mu_k)}{\partial x_{ik}} &= 1 \\ \frac{\partial (x_{ik} - \mu_k)}{\partial \mu_k} &= - 1 \end{aligned} $$

となります。

 このノードには、ステップ2の出力$\frac{\partial L}{\partial (x_{ik} - \mu_k)}$とステップ4の出力$\frac{\partial L}{\partial (x_{ik} - \mu_k)}$が伝播してきます。よってこの2つを合わせた

$$ \begin{aligned} \frac{\partial L}{\partial (x_{ik} - \mu_k)} &= \frac{\partial L}{\partial y_{ik}} \gamma_k \frac{1}{\sigma_k} - \frac{\partial L}{\partial y_{ik}} \gamma_k \frac{(x_{ik} - \mu_k)^2}{N \sigma_k^3} \end{aligned} $$

が、このノードの逆伝播の入力になります。

 従って、この$\frac{\partial L}{\partial (x_{ik} - \mu_k)}$との積

$$ \begin{aligned} \frac{\partial L}{\partial x_{ik}} &= \frac{\partial L}{\partial y_{ik}} \gamma_k \frac{1}{\sigma_k} - \frac{\partial L}{\partial y_{ik}} \gamma_k \frac{(x_{ik} - \mu_k)^2}{N \sigma_k^3} \end{aligned} $$

を(図6-17の上の)次のノードに

$$ \begin{aligned} \frac{\partial L}{\partial \mu_k} &= - \frac{\partial L}{\partial y_{ik}} \gamma_k \frac{1}{\sigma_k} + \frac{\partial L}{\partial y_{ik}} \gamma_k \frac{(x_{ik} - \mu_k)^2}{N \sigma_k^3} \end{aligned} $$

を(図6-17の下の)次のノードに伝播します。

・ステップ6

 5つ目の平均を求める式は

$$ \begin{aligned} \mu_k &= \frac{1}{N} \sum_{i=1}^N x_{ik} \\ &= \frac{1}{N} x_{1k} + \cdots + \frac{1}{N} x_{ik} + \cdots + \frac{1}{N} x_{Nk} \end{aligned} $$

と分解できます。

 平均$\mu_k$を入力$x_{ik}$で偏微分すると、$x_{ik}$以外の$x_{1k}$から$x_{Nk}$は0になるので

$$ \frac{\partial \mu_k}{\partial x_{ik}} = \frac{1}{N} $$

となります。

 従って、ステップ5の出力$\frac{\partial L}{\partial \mu_k}$との積

$$ \begin{aligned} \frac{\partial L}{\partial x_{ik}} &= \left( - \frac{\partial L}{\partial y_{ik}} \gamma_k \frac{1}{\sigma_k} + \frac{\partial L}{\partial y_{ik}} \gamma_k \frac{(x_{ik} - \mu_k)^2}{N \sigma_k^3} \right) \frac{1}{N} \\ &= - \frac{\partial L}{\partial y_{ik}} \gamma_k \frac{1}{N \sigma_k} + \frac{\partial L}{\partial y_{ik}} \gamma_k \frac{(x_{ik} - \mu_k)^2}{N^2 \sigma_k^3} \end{aligned} $$

を次のノードに伝播します。

・ステップ7

 最終的に、図6-17の「$x$」ノードに伝播してくるステップ5と6の$\frac{\partial L}{\partial x_{ik}}$の和

$$ \begin{aligned} \frac{\partial L}{\partial x_{ik}} &= \frac{\partial L}{\partial y_{ik}} \gamma_k \frac{1}{\sigma_k} - \frac{\partial L}{\partial y_{ik}} \gamma_k \frac{(x_{ik} - \mu_k)^2}{N \sigma_k^3} - \frac{\partial L}{\partial y_{ik}} \gamma_k \frac{1}{N \sigma_k} + \frac{\partial L}{\partial y_{ik}} \gamma_k \frac{(x_{ik} - \mu_k)^2}{N^2 \sigma_k^3} \end{aligned} $$

が、Batch Normレイヤの逆伝播の出力となります。

・「Batch Norm」ノード

 $i,k$要素に関する微分をまとめます。逆伝播の入力を$\frac{\partial L}{\partial y_{ik}}$とすると、各変数に関する微分は

$$ \begin{aligned} \frac{\partial L}{\partial \beta_k} &= \frac{\partial L}{\partial y_{ik}} \\ \frac{\partial L}{\partial \gamma_k} &= \frac{\partial L}{\partial y_{ik}} \hat{x}_{ik} \\ \frac{\partial L}{\partial \hat{x}_{ik}} &= \frac{\partial L}{\partial y_{ik}} \gamma_k \\ \frac{\partial L}{\partial \sigma_k} &= - \frac{\partial L}{\partial y_{ik}} \gamma_k \frac{ x_{ik} - \mu_k }{ \sigma_k^2 } \\ \frac{\partial L}{\partial \sigma_k^2} &= - \frac{\partial L}{\partial y_{ik}} \gamma_k \frac{ x_{ik} - \mu_k }{ 2 \sigma_k^3 } \\ \frac{\partial L}{\partial (x_{ik} - \mu_k)} &= \frac{\partial L}{\partial y_{ik}} \gamma_k \frac{1}{\sigma_k} - \frac{\partial L}{\partial y_{ik}} \gamma_k \frac{(x_{ik} - \mu_k)^2}{N \sigma_k^3} \\ \frac{\partial L}{\partial \mu_k} &= - \frac{\partial L}{\partial y_{ik}} \gamma_k \frac{1}{\sigma_k} + \frac{\partial L}{\partial y_{ik}} \gamma_k \frac{(x_{ik} - \mu_k)^2}{N \sigma_k^3} \\ \frac{\partial y_{ik}}{\partial x_{ik}} &= \frac{\partial L}{\partial y_{ik}} \gamma_k \frac{1}{\sigma_k} - \frac{\partial L}{\partial y_{ik}} \gamma_k \frac{(x_{ik} - \mu_k)^2}{N \sigma_k^3} - \frac{\partial L}{\partial y_{ik}} \gamma_k \frac{1}{N \sigma_k} + \frac{\partial L}{\partial y_{ik}} \gamma_k \frac{(x_{ik} - \mu_k)^2}{N^2 \sigma_k^3} \end{aligned} $$

で計算できます。

 続いて、ミニバッチデータ$\mathbf{X},\ \mathbf{Y}$について考えます。$k = 1, 2, \cdots, M$についてはそのまま計算できるので、データ$i = 1, 2, \cdots, N$について考えます。また逆伝播の入力もミニバッチデータになるので、$\frac{\partial L}{\partial \mathbf{Y}}$となります。

 ハイパーパラメータ$\beta_k,\ \gamma_k$は、式(6.8')と$\mathbf{Y}$の計算式から分かる通り、$y_{1k}, y_{2k}, \cdots,y_{Nk}$の計算(ノード)に含まれ(入力し)ます。よって逆伝播では、$\beta_k$のノードには$\frac{\partial L}{\partial y_{1k}}, \frac{\partial L}{\partial y_{2k}}, \cdots, \frac{\partial L}{\partial y_{Nk}}$が、$\gamma_k$のノードには$\frac{\partial L}{\partial y_{1k}} \hat{x}_{1k}, \frac{\partial L}{\partial y_{2k}} \hat{x}_{2k}, \cdots, \frac{\partial L}{\partial y_{Nk}} \hat{x}_{Nk}$が伝播してきます。図6-17のときと同様に、1つのノードに複数の入力がある場合は和をとればいいので

$$ \begin{aligned} \frac{\partial L}{\partial \boldsymbol{\beta}} &= \sum_{i=1}^N \frac{\partial L}{\partial \mathbf{Y}} \\ \frac{\partial L}{\partial \boldsymbol{\gamma}} &= \sum_{i=1}^N \frac{\partial L}{\partial \mathbf{Y}} \odot \hat{\mathbf{X}} \end{aligned} $$

となることが分かります。ここで$\boldsymbol{\beta} = (\beta_1, \beta_2, \cdots, \beta_M),\ \boldsymbol{\gamma} = (\gamma_1, \gamma_2, \cdots, \gamma_M)$とします。また$\odot$は行列の要素ごとの掛け算を表します。

 ステップ2と5の計算でも同じことが起こります。つまり$x_1, x_2, \cdots, x_M$の$M$個のノードと$\mu_k,\ \sigma_k$が1つのノードに伝播するわけです。よって$\frac{\partial L}{\partial \boldsymbol{\mu}},\ \frac{\partial L}{\partial \boldsymbol{\sigma}}$でも$\sum_{i=1}^N$の計算を行う必要があります。が、残りの計算式も含めて書くのを止めておきます。というか数式としてどう表現していいのか分からないんですよね。例えば$\frac{\partial L}{\partial \hat{\mathbf{X}}} = \frac{\partial L}{\partial \mathbf{Y}} \odot \mathbf{1}^{\mathrm{T}} \boldsymbol{\gamma}$といった書き方をする必要があると思うんですけど、数学的な知識が不足していて分かりません。あ、$\mathbf{1}$は1が$M$個並んだ横ベクトルのつもりです。$K$次元ベクトルの$\boldsymbol{\gamma}$を$N$行$M$列の行列にする必要を感じました。

 このレジュメは本にある数式やコードの行間を埋めるのが精一杯で、ゴールが明確でないものを導き出すには力不足です。というわけで、しっかり書いてある資料が見付かれば追記します(というか教えてほしいです色々)。

 では最後に、$i,k$要素の計算式を実装用に整理します。(これもソースコードをゴールとしてそこまでの行間を埋めているわけです。)

$$ \begin{aligned} \frac{\partial L}{\partial \beta_k} &= \frac{\partial L}{\partial y_{ik}} \\ \frac{\partial L}{\partial \gamma_k} &= \frac{\partial L}{\partial y_{ik}} \hat{x}_{ik} \\ \frac{\partial L}{\partial \hat{x}_{ik}} &= \frac{\partial L}{\partial y_{ik}} \gamma_k \\ \frac{\partial L}{\partial \sigma_k} &= - \frac{\partial L}{\partial \hat{x}_{ik}} \frac{x_{ik} - \mu_k}{\sigma_k^2} \\ \frac{\partial L}{\partial \sigma_k^2} &= - \frac{\partial L}{\partial y_{ik}} \gamma_k \frac{x_{ik} - \mu_k}{2 \sigma_k^3} \\ &= \frac{\partial L}{\partial \sigma_k} \frac{1}{2 \sigma_k} \\ \frac{\partial L}{\partial (x_{ik} - \mu_k)} &= \frac{\partial L}{\partial y_{ik}} \gamma_k \frac{1}{\sigma_k} - \frac{\partial L}{\partial y_{ik}} \gamma_k \frac{(x_{ik} - \mu_k)^2}{N \sigma_k^3} \\ &= \frac{\partial L}{\partial \hat{x}_{ik}} \frac{1}{\sigma_k} + \frac{\partial L}{\partial \sigma_k^2} \frac{2 (x_{ik} - \mu_k)}{N} \\ \frac{\partial L}{\partial \mu_k} &= - \frac{\partial L}{\partial y_{ik}} \gamma_k \frac{1}{\sigma_k} + \frac{\partial L}{\partial y_{ik}} \gamma_k \frac{(x_{ik} - \mu_k)^2}{N \sigma_k^3} \\ &= - \frac{\partial L}{\partial (x_{ik} - \mu_k)} \\ \frac{\partial y_{ik}}{\partial x_{ik}} &= \frac{\partial L}{\partial y_{ik}} \gamma_k \frac{1}{\sigma_k} - \frac{\partial L}{\partial y_{ik}} \gamma_k \frac{(x_{ik} - \mu_k)^2}{N \sigma_k^3} - \frac{\partial L}{\partial y_{ik}} \gamma_k \frac{1}{N \sigma_k} + \frac{\partial L}{\partial y_{ik}} \gamma_k \frac{(x_{ik} - \mu_k)^2}{N^2 \sigma_k^3} \\ &= \frac{\partial L}{\partial (x_{ik} - \mu_k)} + \frac{\partial L}{\partial \mu_k} \frac{1}{N} \end{aligned} $$

 $\frac{\partial L}{\partial (x_{ik} - \mu_k)}$の計算について、一部の項を$\frac{\partial L}{\partial \sigma_k^2}$に置き換えたことで$\frac{1}{2}$が余計に含まれるため、2を掛けることで打ち消しています。

 実装とdmudxの後の項の符号が違う、、、あとこれ処理の軽減などを意図していると思うのですけど、$\frac{x_{ik} - \mu_k}{\sigma_k}$を$\hat{x}$に置き換えたりする方が計算量が減ったりしませんか?よく意図を推し量れませんでした。。。

・実装イメージ

 実装自体は【実装ノート】で行いますが、複雑なため逆伝播メソッドの定義に関してはこちらで確認することにします。

 各変数に関する微分の計算を行います。

 ハイパーパラメータ$\boldsymbol{\gamma},\ \boldsymbol{\beta}$については$\frac{\partial L}{\partial \boldsymbol{\gamma}},\ \frac{\partial L}{\partial \boldsymbol{\beta}}$を用いた勾配降下法により学習を行うため、dgammabetaの値はインスタンス変数として保存しておきます。

# Batch Normalizationの実装
class BatchNormalization:
    
    # インスタンス変数の定義
    #def __init__(self, gamma, beta, momentum=0.9, running_mean=None, running_var=None):
        # (省略)
    
    # 順伝播メソッドの定義
    #def forward(self, x, train_flg=True):
        # (省略)
    
    # 逆伝播メソッドの定義
    def backward(self, dout):
        
        # 微分の計算
        dbeta = dout.sum(axis=0) # 調整後の平均
        dgamma = np.sum(self.xn * dout, axis=0) # 調整後の標準偏差
        dxn = self.gamma * dout # 正規化後のデータ
        dxc = dxn / self.std # 偏差
        dstd = -np.sum((dxn * self.xc) / (self.std * self.std), axis=0) # 標準偏差
        dvar = 0.5 * dstd / self.std # 分散
        dxc += (2.0 / self.batch_size) * self.xc * dvar # 偏差
        dmu = np.sum(dxc, axis=0) # 平均
        dx = dxc - dmu / self.batch_size # 入力データ
        
        # インスタンス変数に保存
        self.dgamma = dgamma
        self.dbeta = dbeta
        
        return dx

 axis=0だと1次元配列($N = 1$のとき)を入力したとき、$k = 1, \cdots, M$の和をとることにならない?1行$M$列の2次元配列に変換する処理を組み込んだっけ??

参考文献

  • 斎藤康毅『ゼロから作るDeep Learning』オライリー・ジャパン,2016年.

おわりに

 最終的な計算式が載ってない(かつ見付けられなかった)のでソースコードから逆算もしつつ導出しましたが、最終的に微妙に食い違う点が出てきたままです。6章は説明が(私にとって)足りない(またこれ以上時間をかけられない)ので、自分の理解度への要求水準を少し下げることにしました。その影響は解説にも当然出ていますが、まぁ今のところ特に読まれてるわけでもないのでいいかなと。

【元の記事】

www.anarchive-beta.com