はじめに
「機械学習・深層学習」初学者のための『ゼロから作るDeep Learning』の攻略ノートです。『ゼロつくシリーズ』学習の補助となるように適宜解説を加えています。本と一緒に読んでください。
ニューラルネットワーク内部の計算について、数学的背景の解説や計算式の導出を行い、また実際の計算結果やグラフで確認していきます。
この記事は、付録A「Softmax-with-Lossレイヤの計算グラフ」の内容です。交差エントロピー誤差を微分します。
【元の記事】
【他の記事一覧】
【この記事の内容】
交差エントロピー誤差の逆伝播の導出
交差エントロピー誤差の微分(逆伝播)を導出します。順伝播の計算については「4.2.2-4:交差エントロピー誤差の実装【ゼロつく1のノート(実装)】 - からっぽのしょこ」を参照してください。
・定義式の確認
交差エントロピー誤差は、入力を$\mathbf{y} = (y_1, y_2, \cdots, y_K)$、教師ラベルを$\mathbf{t} = (t_1, t_2, \cdots, t_K)$として、次の式で定義されます。
ここで、$K$は分類するクラス数です。この式が順伝播の計算式です。
各項$y_k$に関する逆伝播(微分)は、次の式になります。
また、$\mathbf{y}$の勾配(逆伝播の出力)$\frac{\partial L}{\partial \mathbf{y}}$は、次のベクトルになります。
この式を導出します。
・順伝播の確認
交差エントロピー誤差(A.2)の計算は、各入力$y_k$に対して「対数をとる」「$t_k$を掛ける」「総和をとる」「$-1$を掛ける」の4つのノード(ステップ)に分解できます(図A-3)。
まずは、この順伝播の計算を確認します。
・ステップ1
最初の計算は、「入力の対数をとる」です。「$\log$」ノード(対数ノード)に交差エントロピー誤差の入力を入力します。
1つ目のノードの入力を$y_k^{(0)}$、出力を$y_k^{(1)}$とすると、「入力の対数をとる」は次の式で表せます。
$y_k^{(1)}$を次のノードに入力します。
・ステップ2
2つ目の計算は、「入力に対応する教師ラベル$t_k$を掛ける」です。「$\times$」ノード(乗算ノード)に$y_k^{(1)}$と$t_k$を入力します。
2つ目のノードの出力を$y_k^{(2)}$とすると、「入力に$t_k$を掛ける」は次の式で表せます。
$y_k^{(2)}$を次のノードに入力します。
・ステップ3
3つ目の計算は、「総和をとる」です。「$\mathrm{sum}$」ノード(総和ノード)に$\mathbf{y}^{(2)} = (y_1^{(2)}, y_2^{(2)}, \cdots, y_K^{(2)})$を入力します。「$\mathrm{sum}$」ノードについては2巻の1.3.4.4項を参照してください。ここでは「$+$」ノードとほとんど同じです。
3つ目のノードの出力を$y^{(3)}$とすると、「総和をとる」は次の式で表せます。
$K$個の要素を足し合わせたので、ベクトルではなくスカラになり、下付き文字$k$がなくなりました。$y^{(3)}$を次のノードに入力します。
・ステップ4
最後の計算は、「入力に$-1$を掛ける」です。「$\times$」ノード(乗算ノード)に$y^{(3)}$と$-1$を入力します。
4つ目のノードの出力を$L$とすると、「入力に$-1$を掛ける」は次の式で表せます。
$L$が交差エントロピー誤差(Cross Entropy Errorレイヤの出力)です。
以上が、交差エントロピー誤差で行う計算です。4つのノードをまとめてCross Entropy Errorレイヤとして扱います。
・Cross Entropy Errorレイヤ
4つのノードをそれぞれ関数$f_1(x), \cdots, f_4(x)$とすると、Cross Entropy Errorレイヤの計算は合成関数で表現できます。
4つの関数が入れ子になっています。それぞれ分けて書くと
と同じ意味です。
ここまでは、順伝播の計算を確認しました。次からは、逆伝播の計算を確認します。
・逆伝播の導出
逆伝播では、「順伝播の各入力$y_k^{(0)}$」に関する「順伝播の出力$L$」の微分$\frac{\partial L}{\partial y_k^{(0)}}$を求めます。
$\frac{\partial L}{\partial y_k^{(0)}}$は、合成関数の微分と言えます。よって、連鎖律より、4つの関数(ノード)の微分の積で求められます。連鎖律については「5.2:連鎖率【ゼロつく1のノート(数学)】 - からっぽのしょこ」を参照してください。
次のように表記しても同じ意味です。
ここでは、上の表記で統一します。
各ノードの微分$\frac{\partial L}{\partial y^{(3)}}, \cdots, \frac{\partial y_k^{(1)}}{\partial y_k^{(0)}}$を求めていきます。
・ステップ4
4つ目のノード(「$\times$」ノード)の順伝播は、次の計算でした。
$L$を$y_k^{(3)}$で微分すると
になります。
・ステップ3
3つ目のノード(「$\mathrm{sum}$」ノード)の順伝播は、次の計算でした。
$y^{(3)}$を$y_k^{(2)}$で微分すると
$y_k^{(2)}$の項だけが1となり、それ以外の項は0になります。
・ステップ2
2つ目のノード(「$\times$」ノード)の順伝播は、次の計算でした。
$y_k^{(2)}$を$y_k^{(1)}$で微分すると
になります。
・ステップ1
1つ目のノード(「$\log$」ノード)の順伝播は、次の計算でした。
$y_k^{(1)}$を$y_k^{(0)}$で微分すると
になります。対数関数$\log x$の微分$\frac{d \log x}{d x}$は分数($x$の逆数)$\frac{d \log x}{d x} = \frac{1}{x}$です。
以上で、各ノードの微分が求まりました。続いて、Cross Entropy Errorレイヤの微分を考えます。
・Cross Entropy Errorレイヤ
各ノードの微分をそれぞれ連鎖律の式に代入します。
$k$番目の入力$y_k^{(0)}$に関する微分$\frac{\partial L}{\partial y_k^{(0)}}$が求まりました。
他の項も同様に求められるので、順伝播の入力$\mathbf{y}$に関する勾配$\frac{\partial L}{\partial \mathbf{y}}$は
となります。
このレイヤの前にレイヤがある場合は、「逆伝播の入力$\frac{\partial L}{\partial L} = 1$」と「このレイヤの勾配$\frac{\partial L}{\partial \mathbf{y}}$」の積$\frac{\partial L}{\partial \mathbf{y}} = \frac{\partial L}{\partial L} \frac{\partial L}{\partial \mathbf{y}}$を出力します(前のレイヤに入力します)。
バッチ版交差エントロピー誤差の逆伝播の導出
前節は、1つのデータを扱う場合でした。この節では、複数データに対する交差エントロピー誤差の微分(逆伝播)を導出します。順伝播の計算については4.2.2-4項を参照してください。
・定義式の確認
バッチデータに対する交差エントロピー誤差は、入力を$\mathbf{Y} = (y_{1,1}, \cdots, y_{N,K})$、教師ラベルを$\mathbf{T} = (t_{1,1}, \cdots, t_{N,K})$として、次の式で定義されます。
ここで、$N$はバッチサイズ(1試行当たりのデータ数)、$K$は分類するクラス数です。$L$は、1データ当たりの平均交差エントロピー誤差と言えるのでした。
各項$y_{n,k}$に関する逆伝播(微分)は、次の式になります。
また、$\mathbf{Y}$の勾配(逆伝播の出力)$\frac{\partial L}{\partial \mathbf{Y}}$は、次の行列になります。
この式を導出します。
・順伝播の確認
バッチデータに対する交差エントロピー誤差の計算は、1データに対する交差エントロピー誤差の計算に続けて「$n$について総和をとる」「$\frac{1}{N}$を掛ける」の2つのノード(ステップ)が追加されます。
まずは、この順伝播の計算を確認します。
・ステップ1-4
1から4番目のノードでは、各データに関する交差エントロピー誤差$L_n$の求めます。よって、1データに対するCross Entropy Errorレイヤ(の4つのノード)と同様に計算します。
1つ目のノードの入力を$y_{n,k}^{(0)}$、4つ目のノードの出力を$L_n^{(4)}$とすると、4つのノードは次の式で表せます。
この式は、1データに対する交差エントロピー誤差の定義式(A.1)です。$L_n^{(4)}$を次のノードに入力します。
・ステップ5
5つ目の計算は、「データ$n$について総和をとる」です。「$\mathrm{sum}$」ノード(総和ノード)に$L_1^{(4)}, L_2^{(4)}, \cdots, L_K^{(4)}$を入力します。「$\mathrm{sum}$」ノードについては2巻の1.3.4.4項を参照してください。ここでは「$+$」ノードとほとんど同じです。
5つ目のノードの出力を$L^{(5)}$とすると、「$n$について総和をとる」は次の式で表せます。
$L^{(5)}$を次のノードに入力します。
・ステップ6
最後の計算は、「入力に$\frac{1}{N}$を掛ける」です。「$\times$」ノード(乗算ノード)に$L^{(5)}$と$\frac{1}{N}$を入力します。
6つ目のノードの出力を$L$とすると、「入力に$\frac{1}{N}$を掛ける」は次の式で表せます。
$L$が交差エントロピー誤差(Cross Entropy Errorレイヤの出力)です。
以上が、バッチデータに対する交差エントロピー誤差で行う計算です。6つのノードをまとめてCross Entropy Errorレイヤとして扱います。
・Cross Entropy Errorレイヤ
1から4番目のノードを1つの関数$f_{1 \cdots 4}(x)$、5と6番目のノードをそれぞれ関数$f_5(x), f_6(x)$とすると、Cross Entropy Errorレイヤの計算は合成関数で表現できます。
3つの関数が入れ子になっています。それぞれ分けて書くと
と同じ意味です。
ここまでは、順伝播の計算を確認しました。次からは、逆伝播の計算を確認します。
・逆伝播の導出
逆伝播では、「順伝播の各入力$y_{n,k}^{(0)}$」に関する「順伝播の出力$L$」の微分$\frac{\partial L}{\partial y_{n,k}^{(0)}}$を求めます。
$\frac{\partial L}{\partial y_{n,k}^{(0)}}$は、合成関数の微分と言えます。よって、連鎖律(5.2節)より、3つの関数(ノード)の微分の積で求められます。
$\frac{\partial L_n^{(4)}}{\partial y_{n,k}^{(0)}}$は、前節のように4つのノードの微分の積
に分解できますが、ここでは1つのまとまりとして扱います。(5.1.2項や5.3.1項の図5-10で書かれている「局所的な微分を伝播することで全体の微分を求める」雰囲気をここから感じられれば、誤差逆伝播法の嬉しさがが分かるかも。)
各ノードの微分$\frac{\partial L}{\partial L^{(5)}}, \frac{\partial L^{(5)}}{\partial L_n^{(4)}}, \frac{\partial L_n^{(4)}}{\partial y_{n,k}^{(0)}}$を求めていきます。
・ステップ6
6つ目のノード(「$\times$」ノード)の順伝播は、次の計算でした。
$L$を$L^{(5)}$で微分すると
になります。
・ステップ5
5つ目のノード(「$\mathrm{sum}$」ノード)の順伝播は、次の計算でした。
$L^{(5)}$を$L_n^{(4)}$で微分すると
になります。
・ステップ1-4
1から4番目のノードは、1データに対するCross Entropy Errorレイヤ(の4つのノード)と同様に計算できます。
したがって、$L_n^{(4)}$を$y_{n,k}^{(0)}$で微分すると
になります。
以上で、各ノードの微分が求まりました。続いて、Cross Entropy Errorレイヤの微分を考えます。
・Cross Entropy Errorレイヤ
各ノードの微分をそれぞれ連鎖律の式に代入します。
$n$番目のデータの$k$番目の入力$y_{n,k}^{(0)}$に関する微分$\frac{\partial L}{\partial y_{n,k}^{(0)}}$が求まりました。
他の項も同様に求められるので、順伝播の入力$\mathbf{Y}$に関する勾配$\frac{\partial L}{\partial \mathbf{Y}}$は
となります。
このレイヤの前にレイヤがある場合は、「逆伝播の入力$\frac{\partial L}{\partial L} = 1$」と「このレイヤの勾配$\frac{\partial L}{\partial \mathbf{Y}}$」の積$\frac{\partial L}{\partial \mathbf{Y}} = \frac{\partial L}{\partial L} \frac{\partial L}{\partial \mathbf{Y}}$を出力します(前のレイヤに入力します)。
参考文献
おわりに
加筆修正の際に記事を分割しました。
逆伝播シリーズ、これまでとはまた違った大変さだった。
【関連する記事】