からっぽのしょこ

読んだら書く!書いたら読む!同じ事は二度調べ(たく)ない

ステップ40:ブロードキャストの逆伝播に利用する関数の実装【ゼロつく3のノート(実装)】

はじめに

 『ゼロから作るDeep Learning 3』の初学者向け攻略ノートです。『ゼロつく3』の学習の補助となるように適宜解説を加えていきます。本と一緒に読んでください。

 本で省略されているクラスや関数の内部の処理を1つずつ解説していきます。

 この記事は、主にステップ40「ブロードキャストを行う関数」を補足する内容です。
 ブロードキャストの逆伝播の計算を行う関数sum_to()の処理を確認します。

【前ステップの内容】

www.anarchive-beta.com

【他の記事一覧】

www.anarchive-beta.com

【この記事の内容】

・ブロードキャストの逆伝播用の関数の実装

 ブロードキャストを伴う計算では、入力の要素が複数に分岐します。逆伝播では、分岐した要素に対応する勾配を足し合わせる必要があります。この記事では、その際に利用する関数sum_to()の内部の処理を確認します。

・処理の確認

 dezeroフォルダ内の(functions.pyではなく)utils.pyに実装されている関数sum_to()で行う処理を確認していきます。sum_to()は、ブロードキャスト(分岐ノード)を伴うクラスの逆伝播メソッドbackward()の中で勾配の合算を行います。

 次のライブラリを利用します。

# 利用するライブラリ
import numpy as np


・順伝播(ブロードキャスト)における処理

 ブロードキャストは、足し算Add、引き算Sub、掛け算Mul、割り算Divクラスで行われています。ここでは足し算を例として、ブロードキャストを行う計算処理を確認していきます。

 計算に用いる2つの変数(順伝播の入力)$x_0,\ x_1$を作成します。ここでは処理を確認したいだけなので、Variableインスタンスではなく、NumPy配列のまま扱います。

# 入力データ0を作成
x0 = np.arange(3 * 4 * 5).reshape((3, 4, 5))
print(x0)
print(x0.shape)
print(x0.ndim)
[[[ 0  1  2  3  4]
  [ 5  6  7  8  9]
  [10 11 12 13 14]
  [15 16 17 18 19]]

 [[20 21 22 23 24]
  [25 26 27 28 29]
  [30 31 32 33 34]
  [35 36 37 38 39]]

 [[40 41 42 43 44]
  [45 46 47 48 49]
  [50 51 52 53 54]
  [55 56 57 58 59]]]
(3, 4, 5)
3
# 入力データ1を作成
x1 = np.array([10])
#x1 = np.array([10, 20, 30, 40, 50])
#x1 = np.array([[10, 20, 30, 40, 50]])
#x1 = np.array([[10], [20], [30], [40]])
print(x1)
print(x1.shape)
print(x1.ndim)
[10]
(1,)
1

 $x_0,\ x_1$をそれぞれx0, x1とします。

 どちらか一方の配列の形状に関するオブジェクトを作成します。この例では、ブロードキャストを行うx1にします。ブロードキャストを行わないx0の場合の各処理の結果も確認しておくとより理解が深まります。

# 順伝播の入力の形状を保存
#x_shape = x0.shape
#x_ndim = x0.ndim
x_shape = x1.shape
x_ndim = x1.ndim

 順伝播の入力$x_1$(または$x_0$)の各次元(軸)の要素数をx_shape、次元数をx_ndimとします。

 順伝播の計算(この例では足し算)$y = x_0 + x_1$を行い、計算結果を「順伝播の出力y」とします。

# 足し算の順伝播を計算
y = x0 + x1
print(y)
print(y.shape)
[[[10 11 12 13 14]
  [15 16 17 18 19]
  [20 21 22 23 24]
  [25 26 27 28 29]]

 [[30 31 32 33 34]
  [35 36 37 38 39]
  [40 41 42 43 44]
  [45 46 47 48 49]]

 [[50 51 52 53 54]
  [55 56 57 58 59]
  [60 61 62 63 64]
  [65 66 67 68 69]]]
(3, 4, 5)

 yは、x0と同じ形状になります。計算時にx1の対応する要素を複製(ブロードキャスト)してx0と同じ形状にしてから計算されています。

 どのように複製されているのかを確認しておきましょう。

# ブロードキャスト
x0_broad = np.broadcast_to(x1, x0.shape)
print(x0_broad)
print(x0_broad.shape)
[[[10 10 10 10 10]
  [10 10 10 10 10]
  [10 10 10 10 10]
  [10 10 10 10 10]]

 [[10 10 10 10 10]
  [10 10 10 10 10]
  [10 10 10 10 10]
  [10 10 10 10 10]]

 [[10 10 10 10 10]
  [10 10 10 10 10]
  [10 10 10 10 10]
  [10 10 10 10 10]]]
(3, 4, 5)

 ブロードキャストでは、計算に用いる2つの配列の内、小さい方を大きい方の形状にします。

 この例だと、0次元配列のx1を2次元配列のx0に対応させます(Pythonルールに合わせて0から数えています。実質は1次元配列と3次元配列です)。

 まず、x1に0番目と1番目の軸を追加して、元々あった軸を2番目の軸とした2次元配列に変換されます。
 (x1の作成時にコメントアウトしてある2つ目の例)x1が要素数5の0次元配列(5,)であれば、(1, 1, 5)の2次元配列になります。入れ子状の配列で表現すると、外側に配列を増やします。

 x0x1の軸の数(次元数)を合わせた上で、各軸(次元)の要素を複製して要素数を合わせます。このとき、複製を行う軸の要素数は1である必要があります。
 2つ目の例だと、(1, 1, 5)の0軸と1軸の要素(入れ子の配列)を複製して(3, 4, 5)の2次元配列になります。
 要素が1でない場合、例えば(2,)から(4,)へのブロードキャストはできずエラーになります。ただし、(2, 1)から(2, 5)へは行えます。

 ややこしいですが、ここがブロードキャストのポイントです。逆伝播では、この処理の逆の操作を行います。

 ここまでが順伝播の処理で、Addクラスの順伝播メソッドforward()で行われます。

・逆伝播における処理

 続いて、ブロードキャストを伴う足し算の逆伝播の処理を確認していきます。

 「逆伝播の入力$\frac{\partial y}{\partial y}$」を作成します。$\frac{\partial y}{\partial y}$は、$y$に関する$y$の勾配で、$y$と同じ形状です。「順伝播の出力y」と同じ形状で全ての要素が1の配列を作成してgyとします。

# 逆伝播の入力データを作成
gy = np.ones_like(y)
print(gy)
print(gy.shape)
[[[1 1 1 1 1]
  [1 1 1 1 1]
  [1 1 1 1 1]
  [1 1 1 1 1]]

 [[1 1 1 1 1]
  [1 1 1 1 1]
  [1 1 1 1 1]
  [1 1 1 1 1]]

 [[1 1 1 1 1]
  [1 1 1 1 1]
  [1 1 1 1 1]
  [1 1 1 1 1]]]
(3, 4, 5)

 順伝播の計算において更に次の計算$L = f(y)$がある場合は、この計算の逆伝播の入力は$\frac{\partial L}{\partial y}$になります。その場合は、次の計算の逆伝播メソッドから出力される値がgyとなります。
 これらの処理は、Variableクラスの逆伝播メソッドbackward()で行われます。

 「ブロードキャストされた$x_1$の勾配$\frac{\partial y}{\partial x_1}$」を計算して、gxとします。

# 足し算の逆伝播の計算
gx = gy

 足し算の逆伝播は$\frac{\partial y}{\partial x_1} = \frac{\partial y}{\partial y}$です。この時点では、gxは「順伝播の出力y」と同じ形状です。
 gxが(utils.pyfunctions.py両方の)sum_to()の第1引数xに対応します。

 次からがブロードキャストにおける逆伝播の処理です。

 「現在の次元数gx.ndim」と「最終的な次元数x_ndim」の差を求めてleadとします。(この変数名leadは訳すと何?)

# 順伝播の出入力の次元数の差を計算
lead = gx.ndim - x_ndim
print(lead)
2

 「gx.ndimは順伝播の出力y(逆伝播の入力gy)の次元数」、「x_ndimは順伝播の入力x1の次元数」でもあります。

 ブロードキャスト時に追加された軸番号を作成します。

# 追加された軸番号を作成
lead_axis = tuple(range(lead))
print(lead_axis)
(0, 1)

 ブロードキャストでは、値が小さい側に軸を追加することで形状を調整しました。この例だと、0番目と1番目の軸を追加しました。lead_axisは、その追加された軸番号に対応します。

 ブロードキャスト時に要素を複製した軸番号を抽出します。

# 要素を複製された軸番号を抽出
axis = tuple([i + lead for i, sx in enumerate(x_shape) if sx == 1])
print(axis)
(2,)

 リスト内包表記を使って処理しています。for文とif文を分けると次になります。

# リストを初期化
axis_list = []

# 要素を複製された軸番号を抽出
for i, sx in enumerate(x_shape):
    if sx == 1: # 各次元の要素数が1のとき
        # ブロードキャスト後の軸番号をリストに追加
        axis_list.append(i + lead)

# タプルに変換
axis = tuple(axis_list)
print(axis)
(2,)

 x_shapeは元の形状です。iは軸番号、sxi軸の要素数を表します。
 つまりこの処理は、元々x1にあった軸であり、要素が複製された軸を抽出しています。ただし、ブロードキャスト時にlead個の軸が追加されるのでした。そのため、i + leadがブロードキャスト後における要素が複製された軸番号です。
 抽出した軸番号をタプルに追加していきます。要素を複製した軸がない場合は、空のタプル()になります。

 ブロードキャストによって複製した要素に対応する勾配を足し合わせます。

# 和をとる軸を確認
print(lead_axis + axis)

# 複製した要素の和をとる
gx_sum = gx.sum(lead_axis + axis, keepdims=True)
print(gx_sum)
print(gx_sum.shape)
(0, 1, 2)
[[[60]]]
(1, 1, 1)

 追加された軸lead_axisと要素を複製した軸axisの要素を全て足します。これがブロードキャスト(分岐ノード)の逆伝播の計算です。詳しくは、2巻の1.3.4.3項を参照してください。
 gx_sumutils.pysum_to()におけるyに対応します。

 lead_axis番目の軸を取り除くことで、「順伝播の入力$x_1$」の形状にします。

# 順伝播の入力の形状に整形
if lead > 0:
    # 不要な軸を消去
    gx_sum = gx_sum.squeeze(lead_axis)
print(gx_sum)
print(gx_sum.shape)
[60]
(1,)

 np.squeeze()に指定した軸を取り除きます。lead_axisが空()のときは、要素数が1の軸を全て消去します。lead0の場合は、順伝播の入出力(逆伝播の出入力)の形状に差がないので軸を消去する必要がありません。

 これで「順伝播の入力x1(またはx0)」と同じ形状の「逆伝播の出力gx」が得られました。

 以上がutils.pysum_to()に関連する処理です。

・実装した関数の確認

 実装した関数を試してみましょう。

 dezeroフォルダのutils.pyを読み込みます。dezeroフォルダの親フォルダまでのパスをsys.path.append()に指定します。

# 実装済み関数の読み込み用の設定
import sys
sys.path.append('..')

# 実装済み関数の読み込み
from dezero import utils


 入力データを作成します。

# 入力データ0を作成
x0 = np.arange(3 * 4 * 5).reshape((3, 4, 5))
print(x0)
print(x0.shape)

# 入力データ1を作成
x1 = np.array([10])
#x1 = np.array([10, 20, 30, 40, 50])
#x1 = np.array([[10, 20, 30, 40, 50]])
#x1 = np.array([[10], [20], [30], [40]])
print(x1)
print(x1.shape)
[[[ 0  1  2  3  4]
  [ 5  6  7  8  9]
  [10 11 12 13 14]
  [15 16 17 18 19]]

 [[20 21 22 23 24]
  [25 26 27 28 29]
  [30 31 32 33 34]
  [35 36 37 38 39]]

 [[40 41 42 43 44]
  [45 46 47 48 49]
  [50 51 52 53 54]
  [55 56 57 58 59]]]
(3, 4, 5)
[10]
(1,)


 順伝播(足し算)を計算します。

# 順伝播を計算
y = x0 + x1
print(y)
print(y.shape)
[[[10 11 12 13 14]
  [15 16 17 18 19]
  [20 21 22 23 24]
  [25 26 27 28 29]]

 [[30 31 32 33 34]
  [35 36 37 38 39]
  [40 41 42 43 44]
  [45 46 47 48 49]]

 [[50 51 52 53 54]
  [55 56 57 58 59]
  [60 61 62 63 64]
  [65 66 67 68 69]]]
(3, 4, 5)


 逆伝播の入力を作成します。

# 逆伝播の入力データを作成
gy = np.ones_like(y)
print(gy)
[[[1 1 1 1 1]
  [1 1 1 1 1]
  [1 1 1 1 1]
  [1 1 1 1 1]]

 [[1 1 1 1 1]
  [1 1 1 1 1]
  [1 1 1 1 1]
  [1 1 1 1 1]]

 [[1 1 1 1 1]
  [1 1 1 1 1]
  [1 1 1 1 1]
  [1 1 1 1 1]]]


 足し算とブロードキャストの逆伝播を計算します。

# 逆伝播の出力を計算
gx = utils.sum_to(gy, x1.shape)
print(gx)
print(gx.shape)
[60]
(1,)


 以上で、ブロードキャストの逆伝播の関数(utils.pyの)sum_to()の内部で行われる処理を確認できました。この関数を用いて、SumToクラスを実装します。SumToクラスを関数として扱えるようにしたfunctions.pysum_to()を用いて、四則演算のクラスAdd, Mul, Sub, Divを実装します。また、今回確認した一連の処理が、Addクラスで行う処理です。

参考文献

おわりに

 何をやっているのか(和をとりつつ入力の形状にする)は明確なのですが、どうやってそれを実現しているのかを理解するのに丸一日かかりました。そしてそれを言葉で説明するのも中々難しい。とは言いつつ、理解してしまうとそこまで難しいことはやってないなと思ったり。
 とにかく、ブロードキャストの処理を理解できれば、その逆の処理をしていると捉えれば理解できるはずです。

【次ステップの内容】

つづく