からっぽのしょこ

読んだら書く!書いたら読む!読書読読書読書♪同じ事は二度調べ(たく)ない

ソフトマックス関数のオーバーフロー対策【ゼロつく1のノート(数学)】

はじめに

 「機械学習・深層学習」学習初手『ゼロから作るDeep Learning』民のための数学攻略ノートです。『ゼロつく1』学習の補助となるように適宜解説を加えています。本と一緒に読んでください。

 NumPy関数を使って実装できてしまう計算について、数学的背景を1つずつ確認していきます。

 この記事は、主に3.5.2項「ソフトマックス関数の実装上の注意」を補足するための内容になります。ソフトマックス関数のオーバーフロー回避策について説明します。

【関連する記事】

www.anarchive-beta.com

【他の記事一覧】

www.anarchive-beta.com

【この記事の内容】

・ソフトマックス関数

 ソフトマックス関数は次の式です。

$$ y_k = \frac{\exp(a_k)}{\sum_{i=1}^n \exp(a_i)} \tag{3.10} $$

 $a_k$は$a_1, a_2, \cdots, a_n$の内の1つです。ただしどれか1つというよりも、どれだったとしてもというようなニュアンスです。1つ目の要素に注目すると$y_1 = \frac{\exp(a_1)}{\sum_{i=1}^n \exp(a_i)}$ですし、5つ目だと$y_5 = \frac{\exp(a_5)}{\sum_{i=1}^n \exp(a_i)}$です。

 分子は$a_k$の指数関数で、分母は全ての要素の指数関数の総和です。

$$ \sum_{i=1}^n \exp(a_i) = \exp(a_1) + \exp(a_2) + \cdots + \exp(a_n) $$

 図3-22の矢印が複雑になっているのは、$k$番目の要素を総和で割る(全ての要素が関わっている)ことを視覚的に表現しているためです。

 ソフトマックス関数の出力は、全ての要素が0から1の値になり、またその総和が1となります。

$$ 0 \leq y_k \leq 1,\ \sum_{k=1}^n y_k = 1 $$

 この性質は確率の定義を満たしていますね。入力信号を確率の値(のよう)に変換できることから、活性化関数にソフトマックス関数が用いられます。

 ちなみにこの性質には次の2つのことが関係しています。(飛ばしていいよ!)1つは、総和で割った値の総和は1になります。

# NumPyを読み込む
import numpy as np

# 値を適当に設定
x = np.array([5, 7, 8, 7, 6, 8, 9])

# 総和で割る
y = x / np.sum(x)
print(y)
print(np.sum(y))
[0.1  0.14 0.16 0.14 0.12 0.16 0.18]
1.0

 総和が1になりました!そして、各要素が0から1の・・・いえ、xに負の値を設定してみましょう!出力も負の値となります、、、

 そこでもう1つが、指数関数の結果は正の値となります。また単調増加する関数なので、要素間の大小関係が変化しません。

# 段々大きくなる値を設定
x = np.arange(-3, 4)

# 指数関数の計算
y = np.exp(x)
print(np.round(y, 3))
[ 0.05   0.135  0.368  1.     2.718  7.389 20.086]

 全て正の値になっていますね。

 ソフトマックス関数の性質はこういう理由でした。

 一旦3.5.1項に戻ってソフトマックス関数を実装してみましょう。そしてすぐ3.5.2で次の話になります。

・オーバーフロー

 メモリの都合上桁を無限に持つことができません。その結果オーバーフロー(アンダーフロー)が起きます。

 指数関数は値が特に大きくなるのでした。よってそのまま実行するとオーバーフローが起こりやすいです。実際にやってみます。

# 利用するライブラリを読み込み
import numpy as np
# (いい感じにダメになる)値を生成
x = np.arange(705, 715)

# 指数をとる
exp_x = np.exp(x)
print(exp_x)
[1.50525383e+306 4.09170414e+306 1.11224050e+307 3.02338314e+307
 8.21840746e+307             inf             inf             inf
             inf             inf]

 値を扱いきれずinf(infinity:無限大)となります。

 更に和をとると

# 和をとる
sum_exp_x = np.sum(exp_x)
print(sum_exp_x)
inf

これもinfですね。infの和もinfになります。

 ではinfで割ると

# infで割る
y = exp_x / sum_exp_x
print(y)
[ 0.  0.  0.  0.  0. nan nan nan nan nan]

値をinfで割ると0になり、infで割るinfnan(not a number:非数)になります。。

 ちなみにnanでも計算できません。

# nanの計算
print(x * y)
print(x / y)
[ 0.  0.  0.  0.  0. nan nan nan nan nan]
[inf inf inf inf inf nan nan nan nan nan]


・オーバーフロー対策

 オーバーフロー対策として式(3.11)のように計算します。この式の実装については3.5.1項の記事で行います。ここでは式変形に少しだけ解説を加えます。

 ソフトマックス関数は次の式でした。

$$ y_k = \frac{\exp(a_k)}{\sum_{i=1}^n \exp(a_i)} \tag{3.10} $$

 本では$\exp(\boldsymbol{\mathrm{a}}) = (\exp(a_1), \cdots, \exp(a_n))$の最大値を$C$とおきます(数式で表現するなら$C = \max(\exp(\boldsymbol{\mathrm{a}}))$です)。そして$C$を分母分子に掛けます。$y_k$に$\frac{C}{C} = 1$を掛けるだけなので、$y_k$に影響しません。(この方法だと$\log$の計算が出てくるので、次のように進めます。$\log$については「自然対数」節で確認してください。)

 同じことをここでは、$\exp(a_l) = \max(\exp(\boldsymbol{\mathrm{a}}))$とおきます($\exp(a_1)$から$\exp(a_n)$の最大値は$l$番目の$\exp(a_l)$だったという意味です)。式(3.10)に$\frac{\exp(a_l)}{\exp(a_l)} = 1$を掛けます。分母の$\sum_{i=1}^n$を展開すると次のようになります(分かりやすい方を見てくれればいいです)。

$$ y_k = \frac{\exp(a_l) \exp(a_k)}{\exp(a_l) \sum_{i=1}^n \exp(a_i)} = \frac{ \exp(a_k) \exp(a_l) }{ \exp(a_1) \exp(a_l) + \cdots + \exp(a_n) \exp(a_l) } $$

 $\exp(a) \exp(b) = \exp(a + b)$なので(詳しくは次の「指数関数(2)」で確認します)、この式は次のようになります。

$$ y_k = \frac{\exp(a_k + a_l)}{\sum_{i=1}^n \exp(a_i + a_l)} = \frac{ \exp(a_k + a_l) }{ \exp(a_1 + a_l) + \cdots + \exp(a_n + a_l) } $$

 本の$C'$と$a_l$は同じものです。

 ちなみに(指数関数は単調増加するので)、$\exp(a_1), \cdots, \exp(a_n)$の最大値は$a_1, \cdots, a_n$の最大値の指数をとったものと同じです。

 この中で登場した指数関数の計算と自然対数について次で確認しますが、ここまでの情報で十分ソフトマックス関数の実装に移れると思います。

・指数関数(2):指数法則(その1)

 指数関数同士の掛け算について確認します。

 まずは定数$a$の$2$乗と$3$乗の掛け算をしてみます。

$$ a^2 * a^3 = (a * a) * (a * a * a) = a^5 $$

 ここから、$a^m * a^n = a^{m+n}$であることが分かります。

 これはネイピア数$e$の場合も同じで

$$ e^2 * e^3 = (e * e) * (e * e * e) = e^5 $$

となるので、$e^m * e^n = e^{m+n}$であることが分かります。

 また、$\exp(\cdot)$の表記でも(当然)同じで

$$ \exp(m) * \exp(n) = \exp(m + n) $$

と書けます。

・自然対数(1):自然対数とは

 指数関数とセットで理解しておきたい自然対数について説明します。(ただしこれも殆ど使いませんよ。)

 まずは対数について確認します。$a > 0,\ a \neq 1,\ b > 0$として$a^c = b$となるとき、次のように定義します。(1は何乗しても1なので省きます。)

$$ \log_a b = c $$

 この$\log_a b$(と$c$の値)を、$a$を底とする$b$の対数と言います。また$b$を真数と言います。

 真数の変化に対する指数の変化に注目すると、次のように書けます。

$$ y = \log_a x $$

 これを$a$を底とする$x$の対数関数と言います。$\log_a x$は$a$を何乗すると$x$になるかを表しています。

# 利用するライブラリを読み込む
import numpy as np
import matplotlib.pyplot as plt


 簡単な例として、2を何乗すれば1024になるのかをnp.log2()で調べてみます。

# 対数の計算
np.log2(1024)
10.0

 2を10乗すれば1024になることが分かりました。式にすると$\log_2 1024 = 10$であり、また$2^{10} = 1024$です。

 ではこれと同じことを、今度は0.1から100まで0.1間隔で調べてグラフにしましょう。

# x軸の値を生成
x = np.arange(0.1, 100, 0.1)

# 対数の計算
log2_x = np.log2(x)

# 作図
plt.plot(x, log2_x)
plt.title("Binary Logarithm", fontsize = 20)
plt.show()

f:id:anemptyarchive:20200605154556p:plain
底が2の対数関数

 これが2を底とする対数関数のグラフになります。

 次は底を、指数関数でも使ったネイピア数$e$とします。

$$ \log_e x = b $$

 ネイピア数を底とする場合の対数関数を自然対数と呼びます(あるいはネイピア数を自然対数の底とも呼びます)。自然対数は$\log_e$ですが$e$を省略して$\log$と書くのが一般的です($\ln$と表記することもあります)。

 先ほどと同じことを自然対数でもやってみましょう。自然対数の計算はnp.log()で行えます(NumPy関数でもeが省略されていますね)。

# xの値を生成
x = np.arange(1, 1000, 0.1)

# 自然対数を計算
log_x = np.log(x)

# 作図
plt.plot(x, log_x)
plt.title("Natural Logarithm", fontsize = 20)
plt.show()

f:id:anemptyarchive:20200605171656p:plain
自然対数


 ところで、$e$の何乗かといえば指数関数$e^x$を思い出しますよね。対数関数と指数関数は、次の関係が成り立ちます。

$$ e^x = b \Leftrightarrow \log_e x = b $$

 ここから次のことも分かります。

$$ \log_e e^x = x $$

であり、また

$$ \exp(\log_e x) = x $$

です。ここでは強調するためあえて$\log_e x$や$e^x$と書きましたが、この資料では基本的に$\log x$、$\exp(x)$の表記を使います。

 ではこの関係をNumPyを使って確認してみましょう。

# xの値を生成
x = np.arange(1, 10, 2)
print(x)
[1 3 5 7 9]

 この各要素の対数をとります。この資料で対数をとると書くとき、特に説明がなければ自然対数の計算を行います。

# 対数をとる
log_x = np.log(x)
print(log_x)
[0.         1.09861229 1.60943791 1.94591015 2.19722458]

 その結果をnp.exp()に渡すと

# 対数をとった値を指数をとる
exp_log_x = np.exp(log_x)
print(exp_log_x)
[1. 3. 5. 7. 9.]

 元の値に戻ることを確認できました。ただし整数型から浮動小数点型になっています(なっているため数値の表記が変わっていますね)。

 np.log()np.exp()の順番を逆にしても当然元の値になります。

# 指数をとる
exp_x = np.exp(x)
print(exp_x)

# 指数をとった値の対数をとる
log_exp_x = np.log(exp_x)
print(log_exp_x)
[2.71828183e+00 2.00855369e+01 1.48413159e+02 1.09663316e+03
 8.10308393e+03]
[1. 3. 5. 7. 9.]

 ここまでの知識で、式(3.11)の変形を理解できると思います。ネイピア数の利便性はもう1つありますが、それは5.3節「逆伝播」で利用します。

参考文献

  • 斎藤康毅『ゼロから作るDeep Learning』オライリー・ジャパン,2016年.

おわりに

 文献によって$\log_a b = x$だったり$\log_a x = b$だったりしない?そういうのすごく困る、、、はたまた$y = \log_a x$だったり??

 という訳でまだしっかりと理解しきれてません?、、、だってそれで困ってなかったので。。一応頭の中で整理はできたつもりなので記事になりました。

 対数とるのって、対数尤度にして式(計算)を分かりやすくするためでしょ!?

【元の記事】

www.anarchive-beta.com