はじめに
『ゼロから作るDeep Learning 3』の初学者向け攻略ノートです。『ゼロつく3』の学習の補助となるように適宜解説を加えていきます。本と一緒に読んでください。
本で登場する数学的な内容をもう少し深堀りして解説していきます。
この記事は、主にステップ47「ソフトマックス関数と交差エントロピー誤差」を補足する内容です。
オーバーフロー対策をしたLogSumExpの計算と、LogSumExpを用いたSoftmax関数と交差エントロピー誤差の計算を確認します。
【前ステップの内容】
【他の記事一覧】
【この記事の内容】
・多値分類の出力層の計算
$K$クラスの多値分類におけるソフトマックス関数(Softmax関数)と交差エントロピー誤差の計算について考えます。
・多値分類の出力層
ニューラルネットワークの(1つの)出力データを要素数が$K$のベクトル$\mathbf{x} = (x_1, x_2, \cdots, x_K)$とします。$N$個のバッチデータ($N \times K$の行列)$\mathbf{X}$の$n$番目のデータ$\mathbf{x}_n = (x_{n,1}, x_{n,2}, \cdots, x_{n,K})$として考えてください。またここでは、ニューラルネットワークの出力を想定していますが、この記事の内容には影響しません。
$\mathbf{x}$の各要素をソフトマックス関数により活性化します。
ソフトマックス関数の出力を$\mathbf{p} = (p_1, p_2, \cdots, p_K)$とします。$\mathbf{x}$がソフトマックス関数により正規化され、$0 \leq p_k \leq 1$、$\sum_{k=1}^K p_k = 1$となります。よって、$p_k$は「入力のクラスが$k$である確率」、$\mathbf{p}$は「$K$個のクラスの確率分布」として解釈できます。
ソフトマックス関数については、「3.5:ソフトマックス関数の実装【ゼロつく1のノート(実装)】 - からっぽのしょこ」を参照してください。
予測値$\mathbf{p}$と教師データ$\mathbf{t} = (t_1, t_2, \cdots, t_K)$から損失を求めます。ここでは、損失として交差エントロピー誤差を用います。
$\mathbf{t}$について、$\mathbf{x}$のクラスが$k$のとき、$k$番目の要素$t_k$が1でそれ以外の要素は0です。このようにして正解のクラスを表す方式をone-hotベクトル(one-hot表現)と言います。
$p_k$は0から1の値なので、$\log p_k$は常に負の値になります。よって、$- \log p_k$は常に正の値になります。また、$- \log p_k$の最小値は$p_k = 1$のときの$\log 1 = 0$です。$p_k$の値が小さくなるほど$- \log p_k$が大きくなるので、この値を誤差として利用します。
つまり、入力のクラスが$k$のとき($t_k = 1$のとき)、$- t_k \log p_k$は$- \log p_k$となりそれ以外の項は0になります。$K$個全ての項の和$L$も$- \log p_k$となります。したがって、正解のクラスを完全に予測できたとき($p_k = 1$のとき)の誤差は0となり、低く予測する($p_k$の値が小さい)ほど誤差が大きくなります。
ちなみに、積や総和の計算をせずに、正解のクラスに対応する$p_k$を取り出す操作によって$L$を求められます。
$N$個のバッチデータの場合は、次の式になります。
全てのデータの交差エントロピー誤差を足し合わせてデータ数$N$で割ることで平均を求めています。
交差エントロピー誤差については、「4.2.1:2乗和誤差の実装【ゼロつく1のノート(実装)】 - からっぽのしょこ」を参照してください。
次は、$\log p_k$の計算を考えます。
・対数Softmax関数
$\log p_k$は、対数をとったSoftmax関数の計算と言えます。
対数の性質$\log \frac{x}{y} = \log x - \log y$より、分母と分子を分解します。
対数と指数は打ち消し合い$\log \{\exp(x)\} = x$となります。
後の項は、$\mathbf{x}$に対して$\log$と$\sum$と$\exp$の計算をしています。この計算をLogSumExpと呼び、$\mathrm{LSE}(\mathbf{x})$で表すことにします。
対数ソフトマックス関数は、$x_k$から$\mathrm{LSE}(\mathbf{x})$を引くことで求められるのが分かりました。
続いて、$\mathrm{LSE}(\mathbf{x})$について考えます。
・LogSumExp
実装において、LogSumExpの計算
は、指数関数$\exp(x)$の計算を含むため$x$が大きいとオーバーフローすることがあります。
そこで、実装上は次のように計算(処理)します。
$\mathbf{x}$の最大値を$x_{\mathrm{max}}$として、$\exp(\cdot)$の中に$- x_{\mathrm{max}} + x_{\mathrm{max}} = 0$を加えます。
$\exp(x) = e^x$であり、指数の性質$x^{n+m} = x^n * x^m$より、$\exp(\cdot)$の項を分解します。
$\exp(x_{\mathrm{max}})$は、$\sum_{k=1}^K$とは無関係なので$\sum$の外に出しました。
対数の性質$\log (x * y) = \log x + \log y$より、$\log (\cdot)$の項を分解します。
$\log \{\exp(x)\} = x$の変形を行います。
$\mathrm{LSE}(\mathbf{x})$の計算は、$\mathbf{x}$の全ての項から最大値$x_{\mathrm{max}}$を引きLogSumExpの計算$\mathrm{LSE}(\mathbf{x} - x_{\mathrm{max}})$をして、$x_{\mathrm{max}}$を加えることで求められるのが分かりました。$x_k - x_{\mathrm{max}}$の最小値が$x_{\mathrm{\min}} - x_{\mathrm{max}}$、最大値が0になるので、オーバーフローが起きにくくなります。
LogSumExpの計算をDeZeroの関数logsumexp()
としてutils.py
に実装します。
・バッチ版の出力層
最後に、バッチデータの場合を確認します。
バッチデータに対する(オーバーフロー対策を行った)ソフトマックス関数と交差エントロピー誤差の計算をまとめると、次の式になります。
また、各データ$\mathbf{x}_n$の正解クラスの要素を$x_{n,*}$で表すと、次の式でも表せます。
(この2つの式と個別の式の中から分かりやすいと思う式と実装例を比べてください。)
logsumexp()
を用いて、ソフトマックス関数と交差エントロピー誤差をSoftmaxCrossEntropy
クラスとしてfunctions.py
に実装します。逆伝播の計算については、「Softmax-with-Lossレイヤの逆伝播の導出【ゼロつく1のノート(数学)】 - からっぽのしょこ」を参照してください。
参考文献
おわりに
ずいぶん前に初めて見たLogSumExpの式が、$a < b$の2項の例
の途中式がないもので、1は何?なぜ$\exp(b)$が消えて$b$に??と投げ出してしまいました。
それが今回、最大値を引いてから総和をとってるんだよ、あと$\exp(0) = 1$なだけ、と分かってスッキリしました。分かってしまうと、大したことではなかったなぁと感じるのもいつものこと。
ところで、記事中の様に式変形と解説文を交互に入れるのと、あとがきの様に式変形ドン!各行の解説ドン!とするのはどちらの方が分かりやすい(読みやすい)ものなのでしょうか?
私の場合は、数式を読むときと文字列を読むときとで使ってる脳のパーツ(あるいは意識)が違うのか、後者の方が好きです。なのでこれまでに書いた記事では、数式をまとめて書くことが多かったです。それと後々読み返すときに、数式だけで理解できるのかを確認できるように分けて書いています。
ゼロつく関連の記事では、数式の塊を見るのがキツい人(初学者時の私)を想定して、それとどちらのパターンでも書けるようになるための練習として、分けて書くようにしています。
解説ノートを作る際にいつも悩むことの1つ。あとは、レベル感とか文体とかどこまでくだけて書くかとか。
投稿日の前日に公開された動画をどうぞ♪
これは一体何のタイアップ?企画?なんだ??続報ぷりーず。
【次ステップの内容】