からっぽのしょこ

読んだら書く!書いたら読む!同じ事は二度調べ(たく)ない

8.1.1-3:Decoderの改良1【ゼロつく2のノート(実装)】

はじめに

 『ゼロから作るDeep Learning 2――自然言語処理編』の初学者向け【実装】攻略ノートです。『ゼロつく2』学習の補助となるように適宜解説を加えています。本と一緒に読んでください。

 本の内容を1つずつ確認しながらゆっくりと組んでいきます。

 この記事は、8.1.3項「Decoderの改良①」の内容です。Encoderの隠れ状態から必要な情報を抽出するWeight Sumレイヤの処理を解説して、Pythonで実装します。

【前節の内容】

www.anarchive-beta.com

【他の節の内容】

www.anarchive-beta.com

【この節の内容】

8.1.1 seq2seqの問題点

 7章で実装したseq2seqでは、「Encoder側の最後の隠れ状態$\mathbf{h}_{T-1} = (h_{0,0}^{(T-1)}, \cdots, h_{N-1,H-1}^{(T-1)})$」を「Decoder側の最初のLSTMレイヤ」に入力しました。しかしこれでは、時系列サイズ$T$が大きいと全ての単語の情報がDecoderに伝わらない可能性があります。そこで、Encoderの全ての隠れ状態$\mathbf{hs} = (\mathbf{h}_0, \cdots, \mathbf{h}_{T-1})$をDecoderで利用することを考えます。(本に載っている図は、バッチデータではない(バッチサイズ$N = 1$の)場合です。この資料ではバッチデータを扱う場合を想定しているため、次元の数が1つ異なっている点に注意してください。)

8.1.2 Encoderの改良

 ある時刻の隠れ状態$\mathbf{h}_t$は、時刻に対応した次元を持ちません。そのため、例えば最後の時刻においては、Encoderに入力する1つの文章($T$個の単語)$(x_{n,0}, \cdots, x_{n,T-1})$は、次元数$H$の隠れ状態ベクトル$(h_{n,0}, \cdots, h_{n,H-1})$にエンコードされます。文の長さ(時系列サイズ$T$)に関わらず、$H$個の要素を持つ(次元数$H$で固定された)ベクトルで表現しなければなりません。
 それに対して$\mathbf{hs}$は、入力する単語数と同じ数の隠れ状態を持ちます。なので、入力する文章の長さ(単語の数)に応じて時系列方向の要素数が変化します。そのため、全ての単語の情報をDecoderで利用できることを期待できます。

 Attention付きseq2seqのEncoderは、8.2.1項で実装します。

 次項からは、Encoderが出力する全ての隠れ状態をDecoderでどう利用するのかについて解説していきます。(個人的な印象ですが、何ができるのかに関心が強い人は8.1.3項を先に、どうやってするのかに関心が強い人は3.1.4項を先に読む方が分かりやすい気がします。)

8.1.3 Decoderの改良1

 この項では、Weight Sumレイヤを実装します。Weight Sumレイヤでは、次項で実装するAttention Weightレイヤで求める重みを用いて、Encoderの隠れ状態から必要な情報を抽出します。

 ちなみにステップ関数のような計算によって、出力が0なら選ばない、1なら選ぶということはできます。しかしニューラルネットワークにおける学習では、微分によって得られた傾き(勾配)を用いてパラメータを更新するのでした(確率的勾配降下法)。ステップ関数はほとんどの範囲で傾きが0です。勾配の情報が0では学習できません。また誤差逆伝播法では、各レイヤの勾配の積を伝播するのでした。そのためあるレイヤの勾配が0だと、そのレイヤまでの勾配情報に0を掛けるため0になってしまい、情報を以降のレイヤに伝播できません。それだと、Lossが小さくなるようにパラメータを更新できません。
 他にも、Dropoutレイヤではニューロンごとに伝播する・しないを選びました。ただし、その選び方は学習せずに終始ランダムに選びました。しかしここでは、学習によってより良い情報を選べるようにする必要があります。

# 8.1.3項で利用するライブラリ
import numpy as np


・処理の確認

 図8-11を参考にして、Weight Sumレイヤの処理を確認していきます。

・順伝播の計算

 データとパラメータの形状に関する値を設定します。また、Encoderから入力する「Encoderの隠れ状態$\mathbf{hs}^{(\mathrm{Enc})} = (\mathbf{h}_0^{(\mathrm{Enc})}, \cdots, \mathbf{h}_{T-1}^{(\mathrm{Enc})})$」を処理結果が分かりやすくなるように作成しておきます。この$T$は、Encoder側の時系列サイズです。

# データとパラメータの形状に関する値を指定
N = 3 # バッチサイズ(入力する文章数)
T = 4 # Encoderの時系列サイズ(入力する単語数)
H = 5 # 隠れ状態のサイズ(LSTMレイヤの中間層のニューロン数)

# (簡易的に)EncoderのT個の隠れ状態を作成
hs = np.arange(N * T * H).reshape((N, T, H)) + 1
print(hs)
print(hs.shape)
[[[ 1  2  3  4  5]
  [ 6  7  8  9 10]
  [11 12 13 14 15]
  [16 17 18 19 20]]

 [[21 22 23 24 25]
  [26 27 28 29 30]
  [31 32 33 34 35]
  [36 37 38 39 40]]

 [[41 42 43 44 45]
  [46 47 48 49 50]
  [51 52 53 54 55]
  [56 57 58 59 60]]]
(3, 4, 5)

 この章では、NumPy配列の表示形式に合わせて配列を表記することにします。$\mathbf{hs}^{(\mathrm{Enc})}$は、$(N \times T \times H)$の3次元配列

$$ \mathbf{hs}^{(\mathrm{Enc})} = \begin{pmatrix} \begin{pmatrix} h_{0,0,0} & \cdots & h_{0,0,H-1} \\ \vdots & \ddots & \vdots \\ h_{0,T-1,0} & \cdots & h_{0,T-1,H-1} \end{pmatrix} & \cdots & \begin{pmatrix} h_{N-1,0,0} & \cdots & h_{N-1,0,H-1} \\ \vdots & \ddots & \vdots \\ h_{N-1,T-1,0} & \cdots & h_{N-1,T-1,H-1} \end{pmatrix} \end{pmatrix} $$

と表示されています。(各要素にまで${(\mathrm{Enc})}$を付けるとごちゃごちゃするので極力省略します。)
 各列(2次元方向に並ぶ$H$個の要素)は隠れ状態ベクトルの次元を、各行(1次元方向に並ぶ$T$個の要素)は時刻を表します。また、$T$行$H$列のまとまりで横(NumPy配列の出力としては縦)に並ぶ要素が、0次元方向に並ぶ$N$個の要素で、何番目のバッチデータなのかを表します。例えば、$(h_{0,0,0}, \cdots, h_{0,0,H-1})$は0番目のバッチデータにおける時刻0の隠れ状態ベクトルです。

 また、次項で実装するAttention Weightレイヤからも「Attentionの重み$\mathbf{a}$」が入力します。こちらも処理が分かりやすいように作成します。

# (簡易的に)重Attentionのみを作成
a = np.arange(N * T).reshape((N, T))
print(a)
print(a.shape)
[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]]
(3, 4)

 $\mathbf{a}$は、$(N \times T)$の2次元配列

$$ \mathbf{a} = \begin{pmatrix} a_{0,0} & \cdots & a_{0,T-1} \\ \vdots & \ddots & \vdots \\ a_{N-1,0} & \cdots & a_{N-1,T-1} \end{pmatrix} $$

です。文章(行)ごとに和をとると1になります。

$$ \sum_{t=0}^{T-1} a_{n,t} = 1 $$

 $n$番目の重み$(a_{n,0}, \cdots, a_{n,T-1})$の各要素は、$n$番目の文書の隠れ状態$(h_{n,0,0}, \cdots, h_{n,T-1,H-1})$の各時刻に対応していて、時刻ごとに重み付けします。これは各時刻の重要度に応じて情報を割り引いています。

 全てのデータの重み付け処理を、ブロードキャスト等の機能を使わず明示的に行うために、$\mathbf{a}$と$\mathbf{hs}^{(\mathrm{Enc})}$の形状と一致させます。$\mathbf{a}$を要素数はそのまま3次元配列に変換します。

# 3次元配列に変換
tmp_a = a.reshape((N, T, 1))
print(tmp_a)
print(tmp_a.shape)
[[[ 0]
  [ 1]
  [ 2]
  [ 3]]

 [[ 4]
  [ 5]
  [ 6]
  [ 7]]

 [[ 8]
  [ 9]
  [10]
  [11]]]
(3, 4, 1)

 $(N \times H \times 1)$の3次元配列

$$ \mathbf{a} = \begin{pmatrix} \begin{pmatrix} a_{0,0} \\ \vdots \\ a_{0,T-1} \end{pmatrix} & \cdots & \begin{pmatrix} a_{N-1,0} \\ \vdots \\ a_{N-1,T-1} \end{pmatrix} \end{pmatrix} $$

となりました。

 さらに、各要素を(0から数えて)2次元方向に$H$個複製したものを$\mathbf{ar}$とします(rはリピートのことだと思います)。

# 2次元方向に複製
ar = tmp_a.repeat(H, axis=2)
print(ar)
print(ar.shape)
[[[ 0  0  0  0  0]
  [ 1  1  1  1  1]
  [ 2  2  2  2  2]
  [ 3  3  3  3  3]]

 [[ 4  4  4  4  4]
  [ 5  5  5  5  5]
  [ 6  6  6  6  6]
  [ 7  7  7  7  7]]

 [[ 8  8  8  8  8]
  [ 9  9  9  9  9]
  [10 10 10 10 10]
  [11 11 11 11 11]]]
(3, 4, 5)

 $\mathbf{ar}$は、$(N \times T \times H)$の3次元配列

$$ \mathbf{ar} = \begin{pmatrix} \begin{pmatrix} a_{0,0} & \cdots & a_{0,0} \\ \vdots & \ddots & \vdots \\ a_{0,T-1} & \cdots & a_{0,T-1} \end{pmatrix} & \cdots & \begin{pmatrix} a_{N-1,0} & \cdots & a_{N-1,0} \\ \vdots & \ddots & \vdots \\ a_{N-1,T-1} & \cdots & a_{N-1,T-1} \end{pmatrix} \end{pmatrix} $$

となり、$\mathbf{hs}^{(\mathrm{Enc})}$と同じ形状になりました。($a_{n,t,0} = \cdots = a_{n,t,H-1}$である$(a_{n,0,0}, \cdots, a_{n,T-1,H-1})$が横(表示上は縦)に$N$個並んでいるイメージです。)

 「Encoderの隠れ状態$\mathbf{hs}^{(\mathrm{Enc})}$」と「形状を調整した重み$\mathbf{ar}$」を要素ごとに掛けます。計算結果を$\mathbf{t}$とします(何由来の$t$でしょうか?時間インデックスの$t$とは別物です)。

# 重み付け
t = hs * ar
print(t)
print(t.shape)
[[[  0   0   0   0   0]
  [  6   7   8   9  10]
  [ 22  24  26  28  30]
  [ 48  51  54  57  60]]

 [[ 84  88  92  96 100]
  [130 135 140 145 150]
  [186 192 198 204 210]
  [252 259 266 273 280]]

 [[328 336 344 352 360]
  [414 423 432 441 450]
  [510 520 530 540 550]
  [616 627 638 649 660]]]
(3, 4, 5)

 要素ごとの積はアダマール積と呼び、$\odot$を使って(行列の積と区別できるように)表します。同じ形状の配列を要素ごとに掛けただけなので、形状は変わらず$(N \times T \times H)$の3次元配列

$$ \begin{aligned} \mathbf{t} &= \mathbf{hs} \odot \mathbf{ar} \\ &= \begin{pmatrix} \begin{pmatrix} h_{0,0,0} a_{0,0} & \cdots & h_{0,0,H-1} a_{0,0} \\ \vdots & \ddots & \vdots \\ h_{0,T-1,0} a_{0,T-1} & \cdots & h_{0,T-1,H-1} a_{0,T-1} \end{pmatrix} & \cdots & \begin{pmatrix} h_{N-1,0,0} a_{N-1,0} & \cdots & h_{N-1,0,H-1} a_{N-1,0} \\ \vdots & \ddots & \vdots \\ h_{N-1,T-1,0} a_{N-1,T-1} & \cdots & h_{N-1,T-1,H-1} a_{N-1,T-1} \end{pmatrix} \end{pmatrix} \\ &= \begin{pmatrix} \begin{pmatrix} t_{0,0,0} & \cdots & t_{0,0,H-1} \\ \vdots & \ddots & \vdots \\ t_{0,T-1,0} & \cdots & t_{0,T-1,H-1} \end{pmatrix} & \cdots & \begin{pmatrix} t_{N-1,0,0} & \cdots & t_{N-1,0,H-1} \\ \vdots & \ddots & \vdots \\ h_{N-1,T-1,0} & \cdots & h_{N-1,T-1,H-1} \end{pmatrix} \end{pmatrix} \end{aligned} $$

です。以降の計算を分かりやすくするために、$\mathbf{t}$の要素を$t_{n,t,h} = h_{n,t,h} a_{n,t}$で表すことにします。

 $\mathbf{t}$を時系列方向に和をとって、コンテキスト$\mathbf{c}$とします。

# 重み付け和を計算
c = np.sum(t, axis=1)
print(c)
print(c.shape)
[[  76   82   88   94  100]
 [ 652  674  696  718  740]
 [1868 1906 1944 1982 2020]]
(3, 5)

 1つの次元の和をとることで次元が1つ減り、$(N \times H)$の2次元配列

$$ \begin{aligned} \mathbf{c} &= \begin{pmatrix} \sum_{t=0}^{T-1} t_{0,t,0} & \cdots & \sum_{t=0}^{T-1} t_{0,t,H-1} \\ \vdots & \ddots & \vdots \\ \sum_{t=0}^{T-1} t_{N-1,t,0} & \cdots & \sum_{t=0}^{T-1} t_{N-1,t,H-1} \end{pmatrix} \\ &= \begin{pmatrix} c_{0,0} & \cdots & c_{0,H-1} \\ \vdots & \ddots & \vdots \\ c_{N-1,0} & \cdots & c_{N-1,H-1} \end{pmatrix} \end{aligned} $$

になります。$c_{n,h} = \sum_{t=0}^{T-1} t_{n,t,h} = \sum_{t=0}^{T-1} h_{n,t,h} a_{n,t}$とおきます。

 Encoderの隠れ状態$\mathbf{hs}^{(\mathrm{Enc})}$から注目すべき情報を抽出したコンテキスト$\mathbf{c}$が得られました。$\mathbf{c}$は、同じ時刻のAffineレイヤに入力します。

 以上が順伝播の処理です。続いて、逆伝播の処理を確認します。

・逆伝播の計算

 同じ時刻のAffineレイヤからコンテキストの勾配$\frac{\partial L}{\partial \mathbf{c}}$が入力します。ここでも処理が分かりやすいように作成します。

# (簡易的に)コンテキストの勾配を作成
dc = np.arange(N * H).reshape((N, H)) + 1
print(dc)
print(dc.shape)
[[ 1  2  3  4  5]
 [ 6  7  8  9 10]
 [11 12 13 14 15]]
(3, 5)

 $\frac{\partial L}{\partial \mathbf{c}}$は、$(N \times H)$の2次元配列

$$ \frac{\partial L}{\partial \mathbf{c}} = \begin{pmatrix} \frac{\partial L}{\partial c_{0,0}} & \cdots & \frac{\partial L}{\partial c_{0,H-1}} \\ \vdots & \ddots & \vdots \\ \frac{\partial L}{\partial c_{N-1,0}} & \cdots & \frac{\partial L}{\partial c_{N-1,H-1}} \end{pmatrix} $$

で、$\mathbf{c}$と同じ形状です。

 $\mathbf{t}$から$\mathbf{c}$への順伝播の計算では、時系列方向に和をとりました。和の計算はSumノードです。Sumノードの逆伝播では、(要素数は同じまま3次元配列に変換した上で)時系列方向に要素を$T$個複製します(1.3.4.4項「Sumノード」)。要素数はそのまま3次元配列に変換します。

# 3次元配列に変換
tmp_dc = dc.reshape((N, 1, H))
print(tmp_dc)
print(tmp_dc.shape)
[[[ 1  2  3  4  5]]

 [[ 6  7  8  9 10]]

 [[11 12 13 14 15]]]
(3, 1, 5)

 $(N \times 1 \times H)$の3次元配列

$$ \frac{\partial L}{\partial \mathbf{c}} = \begin{pmatrix} \begin{pmatrix} \frac{\partial L}{\partial c_{0,0}} & \cdots & \frac{\partial L}{\partial c_{0,H-1}} \end{pmatrix} & \cdots & \begin{pmatrix} \frac{\partial L}{\partial c_{N-1,0}} & \cdots & \frac{\partial L}{\partial c_{N-1,H-1}} \end{pmatrix} \end{pmatrix} $$

となりました。

 時系列方向に要素を$T$個複製します。

# Sumノードの逆伝播
dt = tmp_dc.repeat(T, axis=1)
print(dt)
print(dt.shape)
[[[ 1  2  3  4  5]
  [ 1  2  3  4  5]
  [ 1  2  3  4  5]
  [ 1  2  3  4  5]]

 [[ 6  7  8  9 10]
  [ 6  7  8  9 10]
  [ 6  7  8  9 10]
  [ 6  7  8  9 10]]

 [[11 12 13 14 15]
  [11 12 13 14 15]
  [11 12 13 14 15]
  [11 12 13 14 15]]]
(3, 4, 5)

 計算結果は$\mathbf{t}$の勾配$\frac{\partial L}{\partial \mathbf{t}}$です。$\frac{\partial L}{\partial \mathbf{t}}$は、$(N \times T \times H)$の3次元配列

$$ \begin{aligned} \frac{\partial L}{\partial \mathbf{t}} &= \begin{pmatrix} \begin{pmatrix} \frac{\partial L}{\partial c_{0,0}} & \cdots & \frac{\partial L}{\partial c_{0,H-1}} \\ \vdots & \ddots & \vdots \\ \frac{\partial L}{\partial c_{0,0}} & \cdots & \frac{\partial L}{\partial c_{0,H-1}} \end{pmatrix} & \cdots & \begin{pmatrix} \frac{\partial L}{\partial c_{N-1,0}} & \cdots & \frac{\partial L}{\partial c_{N-1,H-1}} \\ \vdots & \ddots & \vdots \\ \frac{\partial L}{\partial c_{N-1,0}} & \cdots & \frac{\partial L}{\partial c_{N-1,H-1}} \end{pmatrix} \end{pmatrix} \\ &= \begin{pmatrix} \begin{pmatrix} \frac{\partial L}{\partial t_{0,0,0}} & \cdots & \frac{\partial L}{\partial t_{0,0,H-1}} \\ \vdots & \ddots & \vdots \\ \frac{\partial L}{\partial t_{0,T-1,0}} & \cdots & \frac{\partial L}{\partial t_{0,T-1,H-1}} \end{pmatrix} & \cdots & \begin{pmatrix} \frac{\partial L}{\partial t_{N-1,0,0}} & \cdots & \frac{\partial L}{\partial t_{N-1,0,H-1}} \\ \vdots & \ddots & \vdots \\ \frac{\partial L}{\partial h_{N-1,T-1,0}} & \cdots & \frac{\partial L}{\partial h_{N-1,T-1,H-1}} \end{pmatrix} \end{pmatrix} \end{aligned} $$

で、$\mathbf{t}$と同じ形状です。また、$\frac{\partial L}{\partial t_{n,t,h}} = \frac{\partial L}{\partial c_{n,h}}$です。

 $\mathbf{t}$(の各要素)を求める順伝播の計算は乗算ノードです。乗算ノードの逆伝播を計算します(1.3.4.1項「乗算ノード」)。

# 乗算ノードの逆伝播
dhs = dt * ar
print(dhs)
print(dhs.shape)
[[[  0   0   0   0   0]
  [  1   2   3   4   5]
  [  2   4   6   8  10]
  [  3   6   9  12  15]]

 [[ 24  28  32  36  40]
  [ 30  35  40  45  50]
  [ 36  42  48  54  60]
  [ 42  49  56  63  70]]

 [[ 88  96 104 112 120]
  [ 99 108 117 126 135]
  [110 120 130 140 150]
  [121 132 143 154 165]]]
(3, 4, 5)

 計算結果は、Encoderの隠れ状態の勾配$\frac{\partial L}{\partial \mathbf{hs}^{(\mathrm{Enc})}}$です。$\frac{\partial L}{\partial \mathbf{hs}^{(\mathrm{Enc})}}$は、$(N \times T \times H)$の3次元配列

$$ \begin{aligned} \frac{\partial L}{\partial \mathbf{hs}^{(\mathrm{Enc})}} &= \frac{\partial L}{\partial \mathbf{t}} \odot \mathbf{ar} \\ &= \begin{pmatrix} \begin{pmatrix} \frac{\partial L}{\partial t_{0,0,0}} a_{0,0} & \cdots & \frac{\partial L}{\partial t_{0,0,H-1} } a_{0,0} \\ \vdots & \ddots & \vdots \\ \frac{\partial L}{\partial t_{0,T-1,0}} a_{0,T-1} & \cdots & \frac{\partial L}{\partial t_{0,T-1,H-1}} a_{0,T-1} \end{pmatrix} & \cdots & \begin{pmatrix} \frac{\partial L}{\partial t_{N-1,0,0}} a_{N-1,0} & \cdots & \frac{\partial L}{\partial t_{N-1,0,H-1}} a_{N-1,0} \\ \vdots & \ddots & \vdots \\ \frac{\partial L}{\partial t_{N-1,T-1,0}} a_{N-1,T-1} & \cdots & \frac{\partial L}{\partial t_{N-1,T-1,H-1}} a_{N-1,T-1} \end{pmatrix} \end{pmatrix} \\ &= \begin{pmatrix} \begin{pmatrix} \frac{\partial L}{\partial h_{0,0,0}} & \cdots & \frac{\partial L}{\partial h_{0,0,H-1}} \\ \vdots & \ddots & \vdots \\ \frac{\partial L}{\partial h_{0,T-1,0}} & \cdots & \frac{\partial L}{\partial h_{0,T-1,H-1}} \end{pmatrix} & \cdots & \begin{pmatrix} \frac{\partial L}{\partial h_{N-1,0,0}} & \cdots & \frac{\partial L}{\partial h_{N-1,0,H-1}} \\ \vdots & \ddots & \vdots \\ \frac{\partial L}{\partial h_{N-1,T-1,0}} & \cdots & \frac{\partial L}{\partial h_{N-1,T-1,H-1}} \end{pmatrix} \end{pmatrix} \end{aligned} $$

で、$\mathbf{hs}^{(\mathrm{Enc})}$と同じ形状です。また、$\frac{\partial L}{\partial h_{n,t,h}^{(\mathrm{Enc})}} = \frac{\partial L}{\partial t_{n,t,h}} a_{n,t}$です。。

 $\frac{\partial L}{\partial \mathbf{hs}^{(\mathrm{Enc})}}$は、EncoderのTime LSTMレイヤに入力します。

 同様に、もう1つの変数についても計算します。

# 乗算ノードの逆伝播
dar = dt * hs
print(dar)
print(dar.shape)
[[[  1   4   9  16  25]
  [  6  14  24  36  50]
  [ 11  24  39  56  75]
  [ 16  34  54  76 100]]

 [[126 154 184 216 250]
  [156 189 224 261 300]
  [186 224 264 306 350]
  [216 259 304 351 400]]

 [[451 504 559 616 675]
  [506 564 624 686 750]
  [561 624 689 756 825]
  [616 684 754 826 900]]]
(3, 4, 5)

 計算結果は、複製した重みの勾配$\frac{\partial L}{\partial \mathbf{ar}}$です。$\frac{\partial L}{\partial \mathbf{ar}}$は、$(N \times T \times H)$の3次元配列

$$ \begin{aligned} \frac{\partial L}{\partial \mathbf{ar}} &= \frac{\partial L}{\partial \mathbf{t}} \odot \mathbf{hs}^{(\mathrm{Enc})} \\ &= \begin{pmatrix} \begin{pmatrix} \frac{\partial L}{\partial t_{0,0,0}} h_{0,0,0} & \cdots & \frac{\partial L}{\partial t_{0,0,H-1}} h_{0,0,H-1} \\ \vdots & \ddots & \vdots \\ \frac{\partial L}{\partial t_{0,T-1,0}} h_{0,T-1,0} & \cdots & \frac{\partial L}{\partial t_{0,T-1,H-1}} h_{0,T-1,H-1} \end{pmatrix} & \cdots & \begin{pmatrix} \frac{\partial L}{\partial t_{N-1,0,0}} h_{N-1,0,0} & \cdots & \frac{\partial L}{\partial t_{N-1,0,H-1}} h_{N-1,0,H-1} \\ \vdots & \ddots & \vdots \\ \frac{\partial L}{\partial t_{N-1,T-1,0}} h_{N-1,T-1,0} & \cdots & \frac{\partial L}{\partial t_{N-1,T-1,H-1}} h_{N-1,T-1,H-1} \end{pmatrix} \end{pmatrix} \\ &= \begin{pmatrix} \begin{pmatrix} \frac{\partial L}{\partial a_{0,0}} & \cdots & \frac{\partial L}{\partial a_{0,0}} \\ \vdots & \ddots & \vdots \\ \frac{\partial L}{\partial a_{0,T-1}} & \cdots & \frac{\partial L}{\partial a_{0,T-1}} \end{pmatrix} & \cdots & \begin{pmatrix} \frac{\partial L}{\partial a_{N-1,0}} & \cdots & \frac{\partial L}{\partial a_{N-1,0}} \\ \vdots & \ddots & \vdots \\ \frac{\partial L}{\partial a_{N-1,T-1}} & \cdots & \frac{\partial L}{\partial a_{N-1,T-1}} \end{pmatrix} \end{pmatrix} \end{aligned} $$

で、$\mathbf{ar}$と同じ形状です。また、$\frac{\partial L}{\partial a_{n,t}} = \frac{\partial L}{\partial t_{n,t,h}} h_{n,t,h}^{(\mathrm{Enc})}$です。

 $\mathbf{a}$から$\mathrm{ar}$への順伝播では、2次元方向に$H$個複製しました。これはRepeatノードです。Repeatノードの逆伝播では、複製した要素の勾配の和をとります(1.3.4.3項「Repeatノード」)。

# Repeatノードの逆伝播
da = np.sum(dar, axis=2)
print(da)
print(da.shape)
[[  55  130  205  280]
 [ 930 1130 1330 1530]
 [2805 3130 3455 3780]]
(3, 4)

 計算結果は、Attentionの重みの勾配$\frac{\partial L}{\partial \mathbf{a}}$です。$\frac{\partial L}{\partial \mathbf{a}}$は、$(N \times T)$の2次元配列

$$ \frac{\partial L}{\partial \mathbf{a}} = \begin{pmatrix} \frac{\partial L}{\partial a_{0,0}} & \cdots & \frac{\partial L}{\partial a_{0,T-1}} \\ \vdots & \ddots & \vdots \\ \frac{\partial L}{\partial a_{N-1,0}} & \cdots & \frac{\partial L}{\partial a_{N-1,T-1}} \end{pmatrix} $$

で、$\mathbf{a}$と同じ形状です。また、$\frac{\partial L}{\partial a_{n,t}} = \sum_{h=0}^{H-1} \frac{\partial L}{\partial t_{n,t,h}} h_{n,t,h}^{(\mathrm{Enc})}$です。

 $\frac{\partial L}{\partial \mathbf{a}}$は、同じ時刻のAttention Weightレイヤに入力します。

 以上がWeight Sumレイヤの処理です。

・実装

 処理の確認ができたので、Weight Sumレイヤをクラスとして実装しましょう。

# Weight Sumレイヤの実装
class WeightSum:
    # 初期化メソッド
    def __init__(self):
        # 他のレイヤと対応させるための空のリストを作成
        self.params = [] # パラメータ
        self.grads = []  # 勾配
        
        # 中間変数の受け皿を初期化
        self.cache = None
    
    # 順伝播メソッド
    def forward(self, hs, a):
        # 変数の形状に関する値を取得
        N, T, H = hs.shape
        
        # Encoderの隠れ状態と同じ形状に複製
        ar = a.reshape(N, T, 1).repeat(H, axis=2)
        
        # コンテキスト(重み付き和)を計算
        t = hs * ar
        c = np.sum(t, axis=1)
        
        # 逆伝播の計算用に変数を保存
        self.cache = (hs, ar)
        return c
    
    # 逆伝播メソッド
    def backward(self, dc):
        # 変数を取得
        hs, ar = self.cache
        
        # 変数の形状に関する値を取得
        N, T, H = hs.shape
        
        # Sumノードの逆伝播を計算
        dt = dc.reshape(N, 1, H).repeat(T, axis=1)
        
        # 乗算ノードの逆伝播を計算
        dar = dt * hs
        dhs = dt * ar # Encoderの隠れ状態の勾配
        
        # Repeatノードの逆伝播を計算
        da = np.sum(dar, axis=2) # Attentionの重みの勾配
        return dhs, da


 実装したクラスを試してみましょう。

 簡易的な順伝播の入力$\mathbf{hs}^{(\mathrm{Enc})},\ \mathbf{a}$と、Weight Sumレイヤのインスタンスを作成します。

# (簡易的に)Encoderの隠れ状態を作成
hs = np.random.randn(N, T, H)
print(np.round(hs, 2))
print(hs.shape)

# (簡易的に)Attentionの重みを作成
a = np.random.rand(N, T)
a /= np.sum(a, axis=1, keepdims=True) # 正規化
print(np.round(a, 2))
print(np.sum(a, axis=1))
print(a.shape)

# Weight Sumレイヤのインスタンスを作成
weightsum_layer = WeightSum()
[[[ 1.36  0.13  2.56 -0.82 -2.19]
  [-1.04 -1.09  1.73 -1.17  1.64]
  [ 0.93  0.74 -0.35  1.64  0.34]
  [-0.74 -0.06 -0.37 -0.44  0.32]]

 [[-0.02  0.05 -0.09  0.09 -1.36]
  [ 0.39  0.72  0.1   0.9  -0.52]
  [-0.05  1.62 -0.33 -2.53 -1.41]
  [ 0.69 -1.83 -1.26  0.07 -0.05]]

 [[ 0.88  0.94  0.14 -0.87 -1.1 ]
  [ 1.49  0.97  0.06  1.36  0.11]
  [ 2.37  0.68  0.45 -0.06  0.17]
  [-1.05 -0.17 -0.17 -0.86 -0.98]]]
(3, 4, 5)
[[0.42 0.08 0.2  0.3 ]
 [0.22 0.31 0.4  0.08]
 [0.41 0.16 0.22 0.21]]
[1. 1. 1.]
(3, 4)


 入力を渡して、順伝播を計算します。

# コンテキストを計算
c = weightsum_layer.forward(hs, a)
print(np.round(c, 2))
print(c.shape)
[[ 0.46  0.1   1.03 -0.23 -0.64]
 [ 0.15  0.73 -0.22 -0.7  -1.02]
 [ 0.89  0.65  0.13 -0.34 -0.61]]
(3, 5)

 Encoderによってエンコードされた入力情報$\mathbf{hs}^{(\mathrm{Enc})}$から必要な情報を抽出した$\mathbf{c}$が得られました。$t$番目のWeight Sumレイヤで計算された$\mathbf{c}_t$は、Decoderの同じ時刻の隠れ状態$\mathbf{h}_t^{(\mathrm{Dec})}$と結合してAffineレイヤに入力します。

 続いて、逆伝播の入力(Affineレイヤの出力)$\frac{\partial L}{\partial \mathbf{c}}$を簡易的に作成して、逆伝播を計算します。

# (簡易的に)コンテキストの勾配を作成
dc = np.ones((N, H))
print(dc.shape)

# 逆伝播を計算
dhs, da = weightsum_layer.backward(dc)
print(np.round(dhs, 2))
print(dhs.shape)
print(np.round(da, 2))
print(da.shape)
(3, 5)
[[[0.42 0.42 0.42 0.42 0.42]
  [0.08 0.08 0.08 0.08 0.08]
  [0.2  0.2  0.2  0.2  0.2 ]
  [0.3  0.3  0.3  0.3  0.3 ]]

 [[0.22 0.22 0.22 0.22 0.22]
  [0.31 0.31 0.31 0.31 0.31]
  [0.4  0.4  0.4  0.4  0.4 ]
  [0.08 0.08 0.08 0.08 0.08]]

 [[0.41 0.41 0.41 0.41 0.41]
  [0.16 0.16 0.16 0.16 0.16]
  [0.22 0.22 0.22 0.22 0.22]
  [0.21 0.21 0.21 0.21 0.21]]]
(3, 4, 5)
[[ 1.04  0.06  3.29 -1.29]
 [-1.32  1.58 -2.71 -2.37]
 [-0.02  3.99  3.61 -3.22]]
(3, 4)

 得られた$\frac{\partial L}{\partial \mathbf{hs}^{(\mathrm{Enc})}}$hはEncoderのTime LSTMレイヤに、$\frac{\partial L}{\partial \mathbf{a}}$は同じ時刻のDecoderのLSTMレイヤに入力します。

 以上でEncoderの隠れ状態から必要な情報を抽出できました。次項では、情報を抽出する際に用いる重み$\mathbf{a}$を作成するレイヤを実装します。

参考文献

  • 斎藤康毅『ゼロから作るDeep Learning 2――自然言語処理編』オライリー・ジャパン,2018年.

おわりに

 最後の章の1つ目!現時点でほぼほぼ書き上がってるんで気楽なもんよ。

 微妙に無視して過ごしてきたNumPy配列の構造を意識せざるを得なくなりました。記事では数式上の表現とプログラム上の表現を統一したかったのですが、本当は次元ではなく軸と呼ぶべきなんでしょうか。これまでは0軸が縦だったり奥行きだったりでぬぁっ???ってなってたんですが、今回ようやくなんとなく掴めました。何次元であろうと行列っぽく表示されてる部分は後から2つの軸で、全体としては多次元配列というよりは入れ子になってるんだと理解しました。たぶん。以前買ったPython入門書をそろそろ読まねば。

【次節の内容】

www.anarchive-beta.com