からっぽのしょこ

読んだら書く!書いたら読む!同じ事は二度調べ(たく)ない

Affineレイヤの逆伝播の可視化【ゼロつく1のノート(数学)】

はじめに

 「機械学習・深層学習」初学者のための『ゼロから作るDeep Learning』の攻略ノートです。『ゼロつくシリーズ』学習の補助となるように適宜解説を加えています。本と一緒に読んでください。

 ニューラルネットワーク内部の計算について、数学的背景の解説や計算式の導出を行い、また実際の計算結果やグラフで確認していきます。

 この記事は、5.6節「Affineレイヤの実装」の内容です。Affineレイヤの逆伝播の計算過程を可視化することで理解を深めます。

【元の記事】

www.anarchive-beta.com

【他の記事一覧】

www.anarchive-beta.com

【この記事の内容】

Affineレイヤの逆伝播の可視化

 Affineレイヤの逆伝播の計算過程を可視化して確認します。ただし話を簡単にするために、バイアスを省略します。Affineレイヤについては「5.6.2:Affineレイヤの実装【ゼロつく1のノート(実装)】 - からっぽのしょこ」、順伝播については「Affineレイヤの順伝播の可視化【ゼロつく1のノート(数学)】 - からっぽのしょこ」を参照してください。

 利用するライブラリを読み込みます。

# 利用するライブラリ
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

 animationモジュールのFuncAnimation()を使って、解説用のアニメーションを作成します。アニメーションの作成自体は目的ではないので、不要であれば省略してください。

・順伝播の入力と重みの設定

 順伝播の可視化ではMNISTデータを利用しましたが、ここでは人工的に作成したデータを使います。

 バッチサイズ(1試行当たりのデータ数)$N$、1データの要素数(前の層のニューロン数)$D$、次の層のニューロン数$H$を指定して、順伝播の入力$\mathbf{X} = (x_{0,0}, \cdots, x_{N-1,D-1})$、重み$\mathbf{W} = (w_{0,0}, \cdots, w_{D-1,H-1})$を作成します。Pythonのインデックスに合わせて添字を0から割り当てています。

# 形状に関する値を指定
N = 10
D = 6
H = 4

# (仮の)入力を作成
X = np.arange(N).reshape((N, 1)).repeat(D, axis=1) + 1.0
print(X)

# (仮の)パラメータを作成
W = np.arange(D * H).reshape((D, H)) + 1.0
print(W)
[[ 1.  1.  1.  1.  1.  1.]
 [ 2.  2.  2.  2.  2.  2.]
 [ 3.  3.  3.  3.  3.  3.]
 [ 4.  4.  4.  4.  4.  4.]
 [ 5.  5.  5.  5.  5.  5.]
 [ 6.  6.  6.  6.  6.  6.]
 [ 7.  7.  7.  7.  7.  7.]
 [ 8.  8.  8.  8.  8.  8.]
 [ 9.  9.  9.  9.  9.  9.]
 [10. 10. 10. 10. 10. 10.]]
[[ 1.  2.  3.  4.]
 [ 5.  6.  7.  8.]
 [ 9. 10. 11. 12.]
 [13. 14. 15. 16.]
 [17. 18. 19. 20.]
 [21. 22. 23. 24.]]

 この例では、入力Xはデータごとに値が1増えるように、重みWは要素ごとに値が1増えるように作成します(深い意味はありません)。

 $\mathbf{X}$を描画します。

# 順伝播の入力を描画
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(D, N)) # 図の設定
ax.pcolor(X) # ヒートマップ
ax.set_xlabel('d') # x軸ラベル
ax.set_xticks(np.arange(D) + 0.5) # x軸の目盛位置
ax.set_xticklabels(np.arange(D)) # x軸目盛
ax.set_ylabel('n') # y軸ラベル
ax.set_yticks(np.arange(N) + 0.5) # y軸の目盛位置
ax.set_yticklabels(np.arange(N)) # y軸目盛
ax.set_title('$X$', fontsize=15) # 全体のタイトル
ax.invert_yaxis() # y軸を反転
ax.set_aspect('equal', adjustable='box') # アスペクト比
plt.show()

順伝播の入力


 $\mathbf{W}$を描画します。

# 重みを描画
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(H, D)) # 図の設定
ax.pcolor(W) # ヒートマップ
ax.set_xlabel('h') # x軸ラベル
ax.set_xticks(np.arange(H) + 0.5) # x軸の目盛位置
ax.set_xticklabels(np.arange(H)) # x軸目盛
ax.set_ylabel('d') # y軸ラベル
ax.set_yticks(np.arange(D) + 0.5) # y軸の目盛位置
ax.set_yticklabels(np.arange(D)) # y軸目盛
ax.set_title('$W$', fontsize=15) # 全体のタイトル
ax.invert_yaxis() # y軸を反転
ax.set_aspect('equal', adjustable='box') # アスペクト比
plt.show()

重み


・順伝播の計算

 順伝播の入力と重みを用意できたので、順伝播の出力(重み付き和)を計算します。

 順伝播の出力$\mathbf{Z} = (z_{0,0}, \cdots, z_{N-1,H-1})$を計算します。

# 順伝播を計算
Z = np.dot(X, W)
print(Z)
print(Z.shape)
[[ 66.  72.  78.  84.]
 [132. 144. 156. 168.]
 [198. 216. 234. 252.]
 [264. 288. 312. 336.]
 [330. 360. 390. 420.]
 [396. 432. 468. 504.]
 [462. 504. 546. 588.]
 [528. 576. 624. 672.]
 [594. 648. 702. 756.]
 [660. 720. 780. 840.]]
(10, 4)

 ドット積np.dot()を使って、$\mathbf{Z} = \mathbf{X} \cdot \mathbf{W}$を計算します。
 $\mathbf{Z}$は、$N \times D$と$D \times H$の行列の積なので、$N \times H$の行列(2次元配列)になります。

 $\mathbf{Z}$を描画します。

# 順伝播の出力を描画
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(H, N)) # 図の設定
ax.pcolor(Z) # ヒートマップ
ax.set_xlabel('h') # x軸ラベル
ax.set_xticks(np.arange(H) + 0.5) # x軸の目盛位置
ax.set_xticklabels(np.arange(H)) # x軸目盛
ax.set_ylabel('n') # y軸ラベル
ax.set_yticks(np.arange(N) + 0.5) # y軸の目盛位置
ax.set_yticklabels(np.arange(N)) # y軸目盛
ax.set_title('$Z$', fontsize=15) # 全体のタイトル
ax.invert_yaxis() # y軸を反転
ax.set_aspect('equal', adjustable='box') # アスペクト比
plt.show()

順伝播の出力

 ここまでがAffineレイヤの順伝播で行う計算です。

 順伝播における各項の計算を確認しましょう。

 $\mathbf{Z}$の各項$z_{n,h}$は、「$\mathbf{X}$の$n$行目」と「$\mathbf{W}$の$h$列目」の内積です。

$$ z_{n,h} = \sum_{d=0}^{D-1} z_{n,d} w_{d,h} $$

 この計算を図で確認します。

・作図コード(クリックで展開)

# 出力のインデックスを指定
n, h = 0, 0

## 網掛け範囲の設定
# 順伝播の入力用の配列を作成
X_mesh = np.zeros_like(X)
X_mesh[n, :] = 1
X_mesh = np.ma.masked_where(X_mesh != 1, X) # 指定した範囲以外をマスク

# 重み用の配列を作成
W_mesh = np.zeros_like(W)
W_mesh[:, h] = 1
W_mesh = np.ma.masked_where(W_mesh != 1, W) # 指定した範囲以外をマスク

# 順伝播の出力用の配列を作成
Z_mesh = np.zeros_like(Z)
Z_mesh[n, h] = 1
Z_mesh = np.ma.masked_where(Z_mesh != 1, Z) # 指定した範囲以外をマスク

## 作図
fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(16, 9)) # 図の設定

# 順伝播の入力を描画
ax0 = axs[0]
ax0.pcolor(X) # ヒートマップ
ax0.pcolor(X_mesh, hatch='//', edgecolor='turquoise', linewidth=3, vmin=np.min(X), vmax=np.max(X)) # 確認用の枠
ax0.set_xlabel('d') # x軸ラベル
ax0.set_xticks(np.arange(D) + 0.5) # x軸の目盛位置
ax0.set_xticklabels(np.arange(D)) # x軸目盛
ax0.set_ylabel('n') # y軸ラベル
ax0.set_yticks(np.arange(N) + 0.5) # y軸の目盛位置
ax0.set_yticklabels(np.arange(N)) # y軸目盛
ax0.set_title('$X$', fontsize=15) # 全体のタイトル
ax0.invert_yaxis() # y軸を反転
ax0.set_aspect('equal', adjustable='box') # アスペクト比

# 重みを描画
ax1 = axs[1]
ax1.pcolor(W) # ヒートマップ
ax1.pcolor(W_mesh, hatch='//', edgecolor='turquoise', linewidth=3, vmin=np.min(W), vmax=np.max(W)) # 確認用の枠
ax1.set_xlabel('h') # x軸ラベル
ax1.set_xticks(np.arange(H) + 0.5) # x軸の目盛位置
ax1.set_xticklabels(np.arange(H)) # x軸目盛
ax1.set_ylabel('d') # y軸ラベル
ax1.set_yticks(np.arange(D) + 0.5) # y軸の目盛位置
ax1.set_yticklabels(np.arange(D)) # y軸目盛
ax1.set_title('$W$', fontsize=15) # 全体のタイトル
ax1.invert_yaxis() # y軸を反転
ax1.set_aspect('equal', adjustable='box') # アスペクト比

# 順伝播の出力を描画
ax2 = axs[2]
ax2.pcolor(Z) # ヒートマップ
ax2.pcolor(Z_mesh, hatch='//', edgecolor='turquoise', linewidth=3, vmin=np.min(Z), vmax=np.max(Z)) # 確認用の枠
ax2.set_xlabel('h') # x軸ラベル
ax2.set_xticks(np.arange(H) + 0.5) # x軸の目盛位置
ax2.set_xticklabels(np.arange(H)) # x軸目盛
ax2.set_ylabel('n') # y軸ラベル
ax2.set_yticks(np.arange(N) + 0.5) # y軸の目盛位置
ax2.set_yticklabels(np.arange(N)) # y軸目盛
ax2.set_title('$Z = X \cdot W$', fontsize=15) # 全体のタイトル
ax2.invert_yaxis() # y軸を反転
ax2.set_aspect('equal', adjustable='box') # アスペクト比
plt.show()


順伝播の計算

 XWの網掛けされたピクセルを掛け合わせた総和が、Zの網掛けされたピクセルの値になります。

 アニメーションで全ての項について確認しましょう。

・作図コード(クリックで展開)

# 図の設定
fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(16, 9))

# 作図処理を関数として定義
def update(i):
    # 前フレームのグラフを初期化
    for ax in axs.flat:
        ax.cla()
    
    # インデックスを計算
    n = i // H
    h = i % H
    
    ## i回目の図を作成
    
    # 順伝播の入力用の配列を作成
    X_mesh = np.zeros_like(X)
    X_mesh[n, :] = 1
    X_mesh = np.ma.masked_where(X_mesh != 1, X) # 指定した範囲以外をマスク
    
    # 重み用の配列を作成
    W_mesh = np.zeros_like(W)
    W_mesh[:, h] = 1
    W_mesh = np.ma.masked_where(W_mesh != 1, W) # 指定した範囲以外をマスク
    
    # 順伝播の出力用の配列を作成
    Z_mesh = np.zeros_like(Z)
    Z_mesh[n, h] = 1
    Z_mesh = np.ma.masked_where(Z_mesh != 1, Z) # 指定した範囲以外をマスク
    
    # 順伝播の入力を描画
    ax0 = axs[0]
    ax0.pcolor(X) # ヒートマップ
    ax0.pcolor(X_mesh, hatch='//', edgecolor='turquoise', linewidth=3, vmin=np.min(X), vmax=np.max(X)) # 確認用の枠
    ax0.set_xlabel('d') # x軸ラベル
    ax0.set_xticks(np.arange(D) + 0.5) # x軸の目盛位置
    ax0.set_xticklabels(np.arange(D)) # x軸目盛
    ax0.set_ylabel('n') # y軸ラベル
    ax0.set_yticks(np.arange(N) + 0.5) # y軸の目盛位置
    ax0.set_yticklabels(np.arange(N)) # y軸目盛
    ax0.set_title('$X$', fontsize=15) # 全体のタイトル
    ax0.invert_yaxis() # y軸を反転
    ax0.set_aspect('equal', adjustable='box') # アスペクト比
    
    # 重みを描画
    ax1 = axs[1]
    ax1.pcolor(W) # ヒートマップ
    ax1.pcolor(W_mesh, hatch='//', edgecolor='turquoise', linewidth=3, vmin=np.min(W), vmax=np.max(W)) # 確認用の枠
    ax1.set_xlabel('h') # x軸ラベル
    ax1.set_xticks(np.arange(H) + 0.5) # x軸の目盛位置
    ax1.set_xticklabels(np.arange(H)) # x軸目盛
    ax1.set_ylabel('d') # y軸ラベル
    ax1.set_yticks(np.arange(D) + 0.5) # y軸の目盛位置
    ax1.set_yticklabels(np.arange(D)) # y軸目盛
    ax1.set_title('$W$', fontsize=15) # 全体のタイトル
    ax1.invert_yaxis() # y軸を反転
    ax1.set_aspect('equal', adjustable='box') # アスペクト比
    
    # 順伝播の出力を描画
    ax2 = axs[2]
    ax2.pcolor(Z) # ヒートマップ
    ax2.pcolor(Z_mesh, hatch='//', edgecolor='turquoise', linewidth=3, vmin=np.min(Z), vmax=np.max(Z)) # 確認用の枠
    ax2.set_xlabel('h') # x軸ラベル
    ax2.set_xticks(np.arange(H) + 0.5) # x軸の目盛位置
    ax2.set_xticklabels(np.arange(H)) # x軸目盛
    ax2.set_ylabel('n') # y軸ラベル
    ax2.set_yticks(np.arange(N) + 0.5) # y軸の目盛位置
    ax2.set_yticklabels(np.arange(N)) # y軸目盛
    ax2.set_title('$Z = X \cdot W$', fontsize=15) # 全体のタイトル
    ax2.invert_yaxis() # y軸を反転
    ax2.set_aspect('equal', adjustable='box') # アスペクト比

# gif画像を作成
dot_anime = FuncAnimation(fig, update, frames=N * H, interval=200)

# gif画像を保存
dot_anime.save('ch5_6_dot_Z.gif')


順伝播の計算

 続いて、逆伝播の計算過程を確認します。

・逆伝播の入力の設定

 逆伝播の入力も簡易的に作成します。ニューラルネットワークにおいては、次のレイヤから伝播してきます。

 逆伝播の入力$\frac{\partial L}{\partial \mathbf{Z}} = (\frac{\partial L}{\partial z_{0,0}}, \cdots, \frac{\partial L}{\partial z_{N-1,H-1}})$を作成します。

# (仮の)逆伝播の入力を作成
dZ = np.ones_like(Z)
#dZ = np.arange(N).reshape((N, 1)).repeat(H, axis=1) + 1.0
print(dZ)
print(dZ.shape)
[[1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]]
(10, 4)

 $\frac{\partial L}{\partial \mathbf{Z}}$は$\mathbf{Z}$と同じ形状です。この例 では、全ての値を1とします。

 $\frac{\partial L}{\partial \mathbf{Z}}$を描画します。

# 逆伝播の入力を描画
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(H, N)) # 図の設定
ax.pcolor(dZ) # ヒートマップ
ax.set_xlabel('h') # x軸ラベル
ax.set_xticks(np.arange(H) + 0.5) # x軸の目盛位置
ax.set_xticklabels(np.arange(H)) # x軸目盛
ax.set_ylabel('n') # y軸ラベル
ax.set_yticks(np.arange(N) + 0.5) # y軸の目盛位置
ax.set_yticklabels(np.arange(N)) # y軸目盛
ax.set_title('$dZ$', fontsize=15) # 全体のタイトル
ax.invert_yaxis() # y軸を反転
ax.set_aspect('equal', adjustable='box') # アスペクト比
plt.show()

逆伝播の入力


 転置した重み$\mathbf{W}^{\mathrm{T}}$を描画します。

# 重みを描画
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(D, H)) # 図の設定
ax.pcolor(W.T) # ヒートマップ
ax.set_xlabel('d') # x軸ラベル
ax.set_xticks(np.arange(D) + 0.5) # x軸の目盛位置
ax.set_xticklabels(np.arange(D)) # x軸目盛
ax.set_ylabel('h') # y軸ラベル
ax.set_yticks(np.arange(H) + 0.5) # y軸の目盛位置
ax.set_yticklabels(np.arange(H)) # y軸目盛
ax.set_title('$W^T$', fontsize=15) # 全体のタイトル
ax.invert_yaxis() # y軸を反転
ax.set_aspect('equal', adjustable='box') # アスペクト比
plt.show()

転置した重み

 転置した重み$\mathbf{W}^{\mathrm{T}}$は、元の重み$\mathbf{W}$のx軸とy軸を入れ変えたグラフになります。

・逆伝播の計算

 では、逆伝播の計算を行います。逆伝播の計算については「バッチデータ版Affineレイヤの逆伝播の導出【ゼロつく1のノート(数学)】 - からっぽのしょこ」を参照してください。

 逆伝播の出力$\frac{\partial L}{\partial \mathbf{X}} = (\frac{\partial L}{\partial x_{0,0}}, \cdots, \frac{\partial L}{\partial x_{N-1,D-1}})$を計算します。

# 逆伝播を計算
dX = np.dot(dZ, W.T)
print(dX)
print(dX.shape)
[[10. 26. 42. 58. 74. 90.]
 [10. 26. 42. 58. 74. 90.]
 [10. 26. 42. 58. 74. 90.]
 [10. 26. 42. 58. 74. 90.]
 [10. 26. 42. 58. 74. 90.]
 [10. 26. 42. 58. 74. 90.]
 [10. 26. 42. 58. 74. 90.]
 [10. 26. 42. 58. 74. 90.]
 [10. 26. 42. 58. 74. 90.]
 [10. 26. 42. 58. 74. 90.]]
(10, 6)

 $\frac{\partial L}{\partial \mathbf{X}} = \frac{\partial L}{\partial \mathbf{Z}} \cdot \mathbf{W}^{\mathrm{T}}$を計算します。
 $\frac{\partial L}{\partial \mathbf{X}}$は、$N \times H$と$H \times D$の行列の積なので$N \times D$の行列になり、$\mathbf{X}$と同じ形状です。

 $\frac{\partial L}{\partial \mathbf{X}}$を描画します。

# 逆伝播の出力を描画
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(D, N)) # 図の設定
ax.pcolor(dX) # ヒートマップ
ax.set_xlabel('d') # x軸ラベル
ax.set_xticks(np.arange(D) + 0.5) # x軸の目盛位置
ax.set_xticklabels(np.arange(D)) # x軸目盛
ax.set_ylabel('n') # y軸ラベル
ax.set_yticks(np.arange(N) + 0.5) # y軸の目盛位置
ax.set_yticklabels(np.arange(N)) # y軸目盛
ax.set_title('$dX$', fontsize=15) # 全体のタイトル
ax.invert_yaxis() # y軸を反転
ax.set_aspect('equal', adjustable='box') # アスペクト比
plt.show()

逆伝播の出力

 ここまでがAffineレイヤの逆伝播で行う計算です。

 逆伝播における各項の計算を確認しましょう。

 $\frac{\partial L}{\partial \mathbf{X}}$の各項$\frac{\partial L}{\partial x_{n,d}}$は、「$\frac{\partial L}{\partial \mathbf{Z}}$の$n$行目」と「$\mathbf{W}^{\mathrm{T}}$の$h$列目($\mathbf{W}$の$h$行目)」の内積です。

$$ \frac{\partial L}{\partial x_{n,d}} = \sum_{h=0}^{H-1} \frac{\partial L}{\partial z_{n,h}} w_{d,h} $$

 この計算を図で確認します。

・作図コード(クリックで展開)

# 出力のインデックスを指定
n, d = 5, 1

## 網掛け範囲の設定
# 逆伝播の入力用の配列を作成
dZ_mesh = np.zeros_like(dZ)
dZ_mesh[n, :] = 1
dZ_mesh = np.ma.masked_where(dZ_mesh != 1, dZ) # 指定した範囲以外をマスク

# 重み用の配列を作成
W_mesh = np.zeros_like(W)
W_mesh[d, :] = 1
W_mesh = np.ma.masked_where(W_mesh != 1, W) # 指定した範囲以外をマスク

# 逆伝播の出力用の配列を作成
dX_mesh = np.zeros_like(dX)
dX_mesh[n, d] = 1
dX_mesh = np.ma.masked_where(dX_mesh != 1, dX) # 指定した範囲以外をマスク

## 作図
fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(16, 9)) # 図の設定

# 逆伝播の入力を描画
ax0 = axs[0]
ax0.pcolor(dZ) # ヒートマップ
ax0.pcolor(dZ_mesh, hatch='//', edgecolor='blue', linewidth=3, vmin=np.min(dZ), vmax=np.max(dZ)) # 確認用の枠
ax0.set_xlabel('h') # x軸ラベル
ax0.set_xticks(np.arange(H) + 0.5) # x軸の目盛位置
ax0.set_xticklabels(np.arange(H)) # x軸目盛
ax0.set_ylabel('n') # y軸ラベル
ax0.set_yticks(np.arange(N) + 0.5) # y軸の目盛位置
ax0.set_yticklabels(np.arange(N)) # y軸目盛
ax0.set_title('$dZ$', fontsize=15) # 全体のタイトル
ax0.invert_yaxis() # y軸を反転
ax0.set_aspect('equal', adjustable='box') # アスペクト比

# 転置した重みを描画
ax1 = axs[1]
ax1.pcolor(W.T) # ヒートマップ
ax1.pcolor(W_mesh.T, hatch='//', edgecolor='blue', linewidth=3, vmin=np.min(W), vmax=np.max(W)) # 確認用の枠
ax1.set_xlabel('d') # x軸ラベル
ax1.set_xticks(np.arange(D) + 0.5) # x軸の目盛位置
ax1.set_xticklabels(np.arange(D)) # x軸目盛
ax1.set_ylabel('h') # y軸ラベル
ax1.set_yticks(np.arange(H) + 0.5) # y軸の目盛位置
ax1.set_yticklabels(np.arange(H)) # y軸目盛
ax1.set_title('$W^T$', fontsize=15) # 全体のタイトル
ax1.invert_yaxis() # y軸を反転
ax1.set_aspect('equal', adjustable='box') # アスペクト比

# 逆伝播の出力を描画
ax2 = axs[2]
ax2.pcolor(dX) # ヒートマップ
ax2_mesh = ax2.pcolor(dX_mesh, hatch='//', edgecolor='blue', linewidth=3, vmin=np.min(dX), vmax=np.max(dX)) # 確認用の枠
ax2.set_xlabel('d') # x軸ラベル
ax2.set_xticks(np.arange(D) + 0.5) # x軸の目盛位置
ax2.set_xticklabels(np.arange(D)) # x軸目盛
ax2.set_ylabel('n') # y軸ラベル
ax2.set_yticks(np.arange(N) + 0.5) # y軸の目盛位置
ax2.set_yticklabels(np.arange(N)) # y軸目盛
ax2.set_title('$dX$', fontsize=15) # 全体のタイトル
ax2.invert_yaxis() # y軸を反転
ax2.set_aspect('equal', adjustable='box') # アスペクト比
plt.show()


逆伝播の計算

 dZW.Tの網掛けされたピクセルを掛け合わせた総和が、dXの網掛けされたピクセルの値になります。
 この例では、dZの全ての要素が1なので、dZの全ての行がW.Tの列方向の和にっています。

 アニメーションで全ての項について確認しましょう。

・作図コード(クリックで展開)

# 図の設定
fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(16, 9))

# 作図処理を関数として定義
def update(i):
    # 前フレームのグラフを初期化
    for ax in axs.flat:
        ax.cla()
    
    # インデックスを計算
    n = i // D
    d = i % D
    
    ## i回目の図を作成
    
    # 逆伝播の入力用の配列を作成
    dZ_mesh = np.zeros_like(dZ)
    dZ_mesh[n, :] = 1
    dZ_mesh = np.ma.masked_where(dZ_mesh != 1, dZ) # 指定した範囲以外をマスク
    
    # 重み用の配列を作成
    W_mesh = np.zeros_like(W)
    W_mesh[d, :] = 1
    W_mesh = np.ma.masked_where(W_mesh != 1, W) # 指定した範囲以外をマスク
    
    # 逆伝播の出力用の配列を作成
    dX_mesh = np.zeros_like(dX)
    dX_mesh[n, d] = 1
    dX_mesh = np.ma.masked_where(dX_mesh != 1, dX) # 指定した範囲以外をマスク
    
    # 逆伝播の入力を描画
    ax0 = axs[0]
    ax0.pcolor(dZ) # ヒートマップ
    ax0.pcolor(dZ_mesh, hatch='//', edgecolor='blue', linewidth=3, vmin=np.min(dZ), vmax=np.max(dZ)) # 確認用の枠
    ax0.set_xlabel('h') # x軸ラベル
    ax0.set_xticks(np.arange(H) + 0.5) # x軸の目盛位置
    ax0.set_xticklabels(np.arange(H)) # x軸目盛
    ax0.set_ylabel('n') # y軸ラベル
    ax0.set_yticks(np.arange(N) + 0.5) # y軸の目盛位置
    ax0.set_yticklabels(np.arange(N)) # y軸目盛
    ax0.set_title('$dZ$', fontsize=15) # 全体のタイトル
    ax0.invert_yaxis() # y軸を反転
    ax0.set_aspect('equal', adjustable='box') # アスペクト比
    
    # 転置した重みを描画
    ax1 = axs[1]
    ax1.pcolor(W.T) # ヒートマップ
    ax1.pcolor(W_mesh.T, hatch='//', edgecolor='blue', linewidth=3, vmin=np.min(W), vmax=np.max(W)) # 確認用の枠
    ax1.set_xlabel('d') # x軸ラベル
    ax1.set_xticks(np.arange(D) + 0.5) # x軸の目盛位置
    ax1.set_xticklabels(np.arange(D)) # x軸目盛
    ax1.set_ylabel('h') # y軸ラベル
    ax1.set_yticks(np.arange(H) + 0.5) # y軸の目盛位置
    ax1.set_yticklabels(np.arange(H)) # y軸目盛
    ax1.set_title('$W^T$', fontsize=15) # 全体のタイトル
    ax1.invert_yaxis() # y軸を反転
    ax1.set_aspect('equal', adjustable='box') # アスペクト比
    
    # 逆伝播の出力を描画
    ax2 = axs[2]
    ax2.pcolor(dX) # ヒートマップ
    ax2.pcolor(dX_mesh, hatch='//', edgecolor='blue', linewidth=3, vmin=np.min(dX), vmax=np.max(dX)) # 確認用の枠
    ax2.set_xlabel('d') # x軸ラベル
    ax2.set_xticks(np.arange(D) + 0.5) # x軸の目盛位置
    ax2.set_xticklabels(np.arange(D)) # x軸目盛
    ax2.set_ylabel('n') # y軸ラベル
    ax2.set_yticks(np.arange(N) + 0.5) # y軸の目盛位置
    ax2.set_yticklabels(np.arange(N)) # y軸目盛
    ax2.set_title('$dX$', fontsize=15) # 全体のタイトル
    ax2.invert_yaxis() # y軸を反転
    ax2.set_aspect('equal', adjustable='box') # アスペクト比

# gif画像を作成
dot_anime = FuncAnimation(fig, update, frames=N * D, interval=200)

# gif画像を保存
dot_anime.save('ch5_6_dot_dX.gif')


逆伝播の計算


 最後に、逆伝播の計算をもう少し深掘りします。

 逆伝播の出力($\mathbf{X}$の勾配)$\frac{\partial L}{\partial \mathbf{X}}$の各項$\frac{\partial L}{\partial x_{n,d}}$は、順伝播の入力$x_{n,d}$に関する損失$L$の微分です。

 順伝播の計算において、$x_{n,d}$は$H$個の要素$\mathbf{z}_n = (z_{n,0}, z_{n,1}, \cdots, z_{n,H-1})$に影響しています。これは、$x_{n,d}$のノードから分岐して$H$個のノードに入力していると言えます。つまり、Affineレイヤ内の$x_{n,d}$に関するノードは分岐ノードです(2巻の1.3.4.3項)。
 そのため逆伝播では、$H$個のノードから$\frac{\partial L}{\partial \mathbf{z}_n} = (\frac{\partial L}{\partial z_{n,0}}, \frac{\partial L}{\partial z_{n,1}}, \cdots, \frac{\partial L}{\partial z_{n,H-1}})$が$x_{n,d}$のノードに入力します。

 $z_{n,h} = \sum_{d=0}^{D-1} x_{n,d} w_{d,h}$なので、$x_{n,d}$に関する$z_{n,h}$の微分は$\frac{\partial z_{n,h}}{\partial x_{n,d}} = w_{d,h}$です。
 よって、$x_{n,d}$のノードの「逆伝播の入力$\frac{\partial L}{\partial z_{n,0}}, \frac{\partial L}{\partial z_{n,1}}, \cdots, \frac{\partial L}{\partial z_{n,H-1}}$」と「$x_{n,d}$のノードの微分$\frac{\partial z_{n,0}}{\partial x_{n,d}} = w_{d,0}, \frac{\partial z_{n,1}}{\partial x_{n,d}} = w_{d,1}, \cdots, \frac{\partial z_{n,H-1}}{\partial x_{n,d}} = w_{d,H-1}$」の積は$\frac{\partial L}{\partial z_{n,0}} w_{d,0}, \frac{\partial L}{\partial z_{n,1}} w_{d1}, \cdots, \frac{\partial L}{\partial z_{n,H-1}} w_{d,H-1}$となります。
 分岐ノードの逆伝播では分岐した全ての要素の和を求めるので、$x_{n,d}$のノードの逆伝播の出力は、これら$H$個の要素の和$\frac{\partial L}{\partial x_{n,d}} = \sum_{h=0}^{H-1} \frac{\partial L}{\partial z_{n,h}} w_{d,h}$となります。

 行列の積の計算$\frac{\partial L}{\partial \mathbf{X}} = \frac{\partial L}{\partial \mathbf{Z}} \cdot \mathbf{W}^{\mathrm{T}}$によって、全ての項でこの計算が行われます。

 $x_{n,d}$が影響する$\mathbf{Z}$の要素(ピクセル)をアニメーションで確認しましょう。

・作図コード(クリックで展開)

# 図の設定
fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(16, 9))

# 作図処理を関数として定義
def update(h):
    # 前フレームのグラフを初期化
    #plt.cla()
    for ax in axs.flat:
        ax.cla()
    
    # インデックスを指定
    n, d = 5, 1
    
    ## i回目の図を作成
    
    # 順伝播の入力用の配列を作成
    X_mesh = np.zeros_like(X)
    X_mesh[n, :] = 1
    X_mesh = np.ma.masked_where(X_mesh != 1, X) # 指定した範囲以外をマスク
    
    # 注目する入力x_nd用の配列を作成
    X_mesh_nd = np.zeros_like(X)
    X_mesh_nd[n, d] = 1
    X_mesh_nd = np.ma.masked_where(X_mesh_nd != 1, X) # 指定した範囲以外をマスク
    
    # 重み用の配列を作成
    W_mesh = np.zeros_like(W)
    W_mesh[:, h] = 1
    W_mesh = np.ma.masked_where(W_mesh != 1, W) # 指定した範囲以外をマスク
    
    # x_ndに対応した重みw_dh用の配列を作成
    W_mesh_dh = np.zeros_like(W)
    W_mesh_dh[d, h] = 1
    W_mesh_dh = np.ma.masked_where(W_mesh_dh != 1, W) # 指定した範囲以外をマスク
    
    # 順伝播の出力用の配列を作成
    Z_mesh = np.zeros_like(Z)
    Z_mesh[n, h] = 1
    Z_mesh = np.ma.masked_where(Z_mesh != 1, Z) # 指定した範囲以外をマスク
    
    # 順伝播の入力を描画
    ax0 = axs[0]
    ax0.pcolor(X) # ヒートマップ
    ax0.pcolor(X_mesh, hatch='//', edgecolor='turquoise', linewidth=3, vmin=np.min(X), vmax=np.max(X)) # 確認用の枠
    ax0.pcolor(X_mesh_nd, hatch='//', edgecolor='blue', linewidth=3, vmin=np.min(X), vmax=np.max(X)) # 確認用の枠
    ax0.set_xlabel('d') # x軸ラベル
    ax0.set_xticks(np.arange(D) + 0.5) # x軸の目盛位置
    ax0.set_xticklabels(np.arange(D)) # x軸目盛
    ax0.set_ylabel('n') # y軸ラベル
    ax0.set_yticks(np.arange(N) + 0.5) # y軸の目盛位置
    ax0.set_yticklabels(np.arange(N)) # y軸目盛
    ax0.set_title('$X$', fontsize=15) # 全体のタイトル
    ax0.invert_yaxis() # y軸を反転
    ax0.set_aspect('equal', adjustable='box') # アスペクト比
    
    # 重みを描画
    ax1 = axs[1]
    ax1.pcolor(W) # ヒートマップ
    ax1.pcolor(W_mesh, hatch='//', edgecolor='turquoise', linewidth=3, vmin=np.min(W), vmax=np.max(W)) # 確認用の枠
    ax1.pcolor(W_mesh_dh, hatch='//', edgecolor='blue', linewidth=3, vmin=np.min(W), vmax=np.max(W)) # 確認用の枠
    ax1.set_xlabel('h') # x軸ラベル
    ax1.set_xticks(np.arange(H) + 0.5) # x軸の目盛位置
    ax1.set_xticklabels(np.arange(H)) # x軸目盛
    ax1.set_ylabel('d') # y軸ラベル
    ax1.set_yticks(np.arange(D) + 0.5) # y軸の目盛位置
    ax1.set_yticklabels(np.arange(D)) # y軸目盛
    ax1.set_title('$W$', fontsize=15) # 全体のタイトル
    ax1.invert_yaxis() # y軸を反転
    ax1.set_aspect('equal', adjustable='box') # アスペクト比
    
    # 順伝播の出力を描画
    ax2 = axs[2]
    ax2.pcolor(Z) # ヒートマップ
    ax2.pcolor(Z_mesh, hatch='//', edgecolor='turquoise', linewidth=4.5, vmin=np.min(Z), vmax=np.max(Z)) # 確認用の枠
    ax2.pcolor(Z_mesh, hatch='//', edgecolor='blue', linewidth=1.5, vmin=np.min(Z), vmax=np.max(Z)) # 確認用の枠
    ax2.set_xlabel('h') # x軸ラベル
    ax2.set_xticks(np.arange(H) + 0.5) # x軸の目盛位置
    ax2.set_xticklabels(np.arange(H)) # x軸目盛
    ax2.set_ylabel('n') # y軸ラベル
    ax2.set_yticks(np.arange(N) + 0.5) # y軸の目盛位置
    ax2.set_yticklabels(np.arange(N)) # y軸目盛
    ax2.set_title('$Z = X \cdot W$', fontsize=15) # 全体のタイトル
    ax2.invert_yaxis() # y軸を反転
    ax2.set_aspect('equal', adjustable='box') # アスペクト比

# gif画像を作成
dot_anime = FuncAnimation(fig, update, frames=H, interval=200)

# gif画像を保存
dot_anime.save('ch5_6_dot_Xnd.gif')


順伝播と逆伝播の計算の対応

 青色の網掛けが移動したピクセル全体が、先ほどの青色の網掛け部分に対応しています。
 (枠線とメッシュ線とタイルの色を分けて指定したいのだが。)

 以上で、Affineレイヤの逆伝播の計算過程を確認できました。重みの勾配とバイアスの勾配についても同様に、順伝播において各要素が分岐したノードに関して和をとっています。

参考文献

おわりに

 Affineレイヤの可視化の記事を2つ書いて私の頭の中はスッキリ整理できたのだけど、読んだ人の頭の中もスッキリするのかはいつも疑問。

 いつもやりたいことに対してPyPlot力が足りずブロリー状態。

【関連する記事】

www.anarchive-beta.com

www.anarchive-beta.com

www.anarchive-beta.com