はじめに
「ソフトマックス関数の逆伝播の導出【ゼロつく1のノート(数学)】 - からっぽのしょこ」の補足用の記事です。内容は重複しており、またこの記事で完結しています。おそらくこの記事の方が分かりやすいです。
Softmax関数の微分
Softmax関数(ソフトマックス関数)の微分を導出します。
・定義式
Softmax関数は、入力を$\mathbf{x} = (x_1, x_2, \cdots, x_K)$、出力を$\mathbf{y} = (y_1, y_2, \cdots, y_K)$として、次の式で定義されます。
$$
y_k = \frac{
\exp(x_k)
}{
\sum_{k'=1}^K \exp(x_{k'})
}
$$
ここで、$K$は分類するクラス数です。
Softmax関数の微分は、次の2つの式です。入力と出力の成分$i, j$の関係によって異なります。
$$
\begin{aligned}
\frac{\partial y_i}{\partial x_j}
= \begin{cases}
y_i (1 - y_j) &\quad (i = j) \\
- y_i y_j &\quad (i \neq j)
\end{cases}
\end{aligned}
$$
この2つの式を導出します。
・成分が同じ場合
まずは、入力$x_k$と出力$y_k$の成分(クラス)が同じ場合を考えます。
$x_k$に関する$y_k$の微分$\frac{\partial y_k}{\partial x_k}$を求めます。$y_k = \frac{\exp(x_k)}{\sum_{k'=1}^K \exp(x_{k'})}$なので
$$
\frac{\partial y_k}{\partial x_k}
= \frac{\partial}{\partial x_k}
\frac{\exp(x_k)}{\sum_{k'=1}^K \exp(x_{k'})}
$$
を計算します。
商の微分$\Bigl\{\frac{f(x)}{g(x)}\Bigr\}' = \frac{f'(x) g(x) - f(x) g'(x)}{\{g(x)\}^2}$より
$$
\begin{align*}
\frac{\partial y_k}{\partial x_k}
&= \left\{
\frac{\partial}{\partial x_k}
\exp(x_k)
\right\}
\sum_{k'=1}^K \exp(x_{k'})
\frac{1}{\left(\sum_{k'=1}^K \exp(x_{k'})\right)^2} \\
&\quad
- \exp(x_k) \left\{
\frac{\partial}{\partial x_k}
\sum_{k'=1}^K \exp(x_{k'})
\right\}
\frac{1}{\left(\sum_{k'=1}^K \exp(x_{k'})\right)^2}
\tag{1}
\end{align*}
$$
分子を微分した項と分母を微分した項になります。$f(x) = \exp(x_k)$、$g(x) = \sum_{k'=1}^K \exp(x_{k'})$に対応しています。
分子の微分は、指数関数の微分$\frac{d e^{x}}{d x} = e^x$より
$$
\frac{\partial}{\partial x_k} \exp(x_k)
= \exp(x_k)
$$
変化しません。
分母の微分は、$\exp(x_1), \cdots, \exp(x_K)$の内、$x_k$の影響を受ける$\exp(x_k)$の微分だけが残るので
$$
\begin{aligned}
\frac{\partial}{\partial x_k}
\sum_{k'=1}^K \exp(x_{k'})
&= \frac{\partial}{\partial x_k} \Bigl\{
\exp(x_1) + \cdots + \exp(x_k) + \cdots + \exp(x_K)
\Bigr\}
\\
&= 0 + \cdots + \frac{\partial}{\partial x_k} \exp(x_k) + \cdots + 0
\\
&= \exp(x_k)
\end{aligned}
$$
となります。
それぞれ式(1)に代入すると、$\frac{\partial y_k}{\partial x_k}$は
$$
\begin{aligned}
\frac{\partial y_k}{\partial x_k}
&= \exp(x_k)
\sum_{k'=1}^K \exp(x_{k'})
\frac{1}{\left(\sum_{k'=1}^K \exp(x_{k'})\right)^2} \\
&\quad
- \exp(x_k)
\exp(x_k)
\frac{1}{\left(\sum_{k'=1}^K \exp(x_{k'})\right)^2}
\\
&= \frac{\exp(x_k)}{\sum_{k'=1}^K \exp(x_{k'})}
- \left(
\frac{\exp(x_k)}{\sum_{k'=1}^K \exp(x_{k'})}
\right)^2
\\
&= y_k - y_k^2
\\
&= y_k (1 - y_k)
\end{aligned}
$$
となります。
・成分が異なる場合
次は、入力$x_j$と出力$y_i$の成分(クラス)が異なる$i \neq j$の場合を考えます。
$x_j$に関する$y_i$の微分$\frac{\partial y_i}{\partial x_j}$を求めます。$y_i = \frac{\exp(x_i)}{\sum_{k=1}^K \exp(x_k)}$なので
$$
\frac{\partial y_i}{\partial x_j}
= \frac{\partial}{\partial x_j}
\frac{\exp(x_i)}{\sum_{k=1}^K \exp(x_k)}
$$
を計算します。
分子の$\exp(x_i)$は$x_j$と無関係なので、$\exp(x_i)$を係数として分母の逆数の微分
$$
\frac{\partial y_i}{\partial x_j}
= \exp(x_i)
\frac{\partial}{\partial x_j}
\frac{1}{\sum_{k=1}^K \exp(x_k)}
\tag{2}
$$
となります。
分母の逆数の微分は、$\frac{1}{x} = x^{-1}$の変形をして、合成関数の微分$\{f(g(x))\}' = f'(g(x)) g'(x)$より
$$
\begin{aligned}
\frac{\partial}{\partial x_j}
\frac{1}{\sum_{k=1}^K \exp(x_k)}
&= \frac{\partial}{\partial x_j}
\left(\sum_{k=1}^K \exp(x_k)\right)^{-1}
\\
&= - \left(\sum_{k=1}^K \exp(x_k)\right)^{-2}
\frac{\partial}{\partial x_j}
\sum_{k=1}^K \exp(x_k)
\\
&= - \frac{1}{\left(\sum_{k=1}^K \exp(x_k)\right)^2}
\frac{\partial}{\partial x_j}
\sum_{k=1}^K \exp(x_k)
\end{aligned}
$$
となります。$f(g(x)) = \frac{1}{\sum_{k=1}^K \exp(x_k)}$、$g(x) = \sum_{k=1}^K \exp(x_k)$に対応しています。
さらに$\frac{\partial}{\partial x_j} \sum_{k=1}^K \exp(x_k)$は、$\exp(x_1), \cdots, \exp(x_K)$の内、$x_j$の影響を受ける$\exp(x_j)$の微分だけが残るので
$$
\begin{aligned}
\frac{\partial}{\partial x_j} \sum_{k=1}^K \exp(x_k)
&= \frac{\partial}{\partial x_j} \Bigl\{
\exp(x_1) + \cdots + \exp(x_j) + \cdots + \exp(x_K)
\Bigr\}
\\
&= 0 + \cdots + \frac{\partial}{\partial x_j} \exp(x_j) + \cdots + 0
\\
&= \exp(x_j)
\end{aligned}
$$
となります。
それぞれ式(2)に代入すると、$\frac{\partial y_i}{\partial x_j}$は
$$
\begin{aligned}
\frac{\partial y_i}{\partial x_j}
&= \exp(x_i) \left\{
- \frac{1}{\left(\sum_{k=1}^K \exp(x_k)\right)^2}
\right\}
\exp(x_j)
\\
&= - \frac{\exp(x_i)}{\sum_{k=1}^K \exp(x_k)}
\frac{\exp(x_j)}{\sum_{k=1}^K \exp(x_k)}
\\
&= - y_i y_j
\end{aligned}
$$
となります。
・勾配
最後に、各成分の偏微分をまとめた勾配を確認します。
$k$番目の入力$x_k$に関する出力$\mathbf{y}$の勾配$\frac{\partial \mathbf{y}}{\partial x_k}$は
$$
\frac{\partial \mathbf{y}}{\partial x_k}
= \begin{pmatrix}
\frac{\partial y_1}{\partial a_k} \\
\vdots \\
\frac{\partial y_k}{\partial a_k} \\
\vdots \\
\frac{\partial y_K}{\partial a_k}
\end{pmatrix}
= \begin{pmatrix}
- y_1 y_k \\
\vdots \\
y_k (1 - y_k) \\
\vdots \\
- y_K y_k
\end{pmatrix}
$$
となります。
また、入力$\mathbf{x}$に関する出力$\mathbf{y}$の勾配$\frac{\partial \mathbf{y}}{\partial \mathbf{x}}$は
$$
\frac{\partial \mathbf{y}}{\partial \mathbf{x}}
= \begin{pmatrix}
\frac{\partial y_1}{\partial a_1} & \frac{\partial y_1}{\partial a_2} & \cdots & \frac{\partial y_1}{\partial a_K} \\
\frac{\partial y_2}{\partial a_1} & \frac{\partial y_2}{\partial a_2} & \cdots & \frac{\partial y_2}{\partial a_K} \\
\vdots & \vdots & \ddots & \vdots \\
\frac{\partial y_K}{\partial a_1} & \frac{\partial y_K}{\partial a_2} & \cdots & \frac{\partial y_K}{\partial a_K}
\end{pmatrix}
= \begin{pmatrix}
y_1 (1 - y_1) & - y_1 y_2 & \cdots & - y_1 y_K \\
- y_2 y_1 & y_2 (1 - y_2) & \cdots & - y_2 y_K \\
\vdots & \vdots & \ddots & \vdots \\
- y_K y_1 & - y_K y_2 & \cdots & y_K (1 - y_K)
\end{pmatrix}
$$
となります。
ここでは、縦ベクトル・横ベクトルを適当に扱いました。またベクトルの微分に関して、分子レイアウト・分母レイアウトという概念もあるようなので、必要に応じて設定してください。
おわりに
誤差逆伝播法的に解くよりも微分の公式に沿って解く方が分かりやすくなってきた。対数尤度を微分してた成果を感じる。とは言ったものの相変わらず公式を使って数式パズルを解いているだけで、微分そのものの定義は未だよく分かっていないのでちゃんと勉強しないとなぁ(n回目)。
このカバー曲が格好良いので聴いてください!