はじめに
『ゼロから作るDeep Learning 4 ――強化学習編』の独学時のまとめノートです。初学者の補助となるようにゼロつくシリーズの4巻の内容に解説を加えていきます。本と一緒に読んでください。
この記事は、4.1節の内容です。簡単な例を使って反復方策評価を行います。
【前節の内容】
【他の記事一覧】
【この記事の内容】
4.1 動的計画法と方策評価
反復方策評価により状態価値関数を評価します。状態価値関数についてのベルマン方程式については「3.1.2:状態価値関数のベルマン方程式の導出【ゼロつく4のノート】 - からっぽのしょこ」を参照してください。
4.1.1 動的計画法の概要
動的計画法における状態価値関数の更新式を確認します。
状態価値関数は、確率論的方策$\pi(a | s)$を取るときの収益$G_t$の期待値として定義され(2.3.3項)、さらにベルマン方程式に変形しました(3.1.2項)。
しかし、真の状態価値関数$v_{\pi}(s)$を直接得ることはできません。
そこで、状態価値関数の推定値を$V(s)$で表し、次の更新式で定義します。
ここで、$V_k(s')$は$k$回目に更新された状態価値関数の推定値です。
この式は、「$k$回目に更新された状態$s'$の価値関数$V_k(s')$」と「$k+1$回目に更新された状態$s$の価値関数$V_{k+1}(s)$」の関係を表します。つまり、$V_k(s')$を使って$V_{k+1}(s)$を求めます。
4.1.2-3 反復方策評価を試す
簡単な例として、2マスのグリッドワールド(図4-2)を使って、反復方策評価アルゴリズムを確認します。問題設定と数式上の計算については「3.2.1:状態価値関数のベルマン方程式の例【ゼロつく4のノート】 - からっぽのしょこ」を参照してください。
利用するライブラリを読み込みます。
# 利用するライブラリ import numpy as np import matplotlib.pyplot as plt
・実装1
まずは、アルゴリズムに従ってシンプルに実装します。
試行回数max_iter
を指定して、L1とL2それぞれの状態価値を繰り返し更新(計算)します。
# 試行回数を指定 max_iter = 100 # 割引率を指定 gamma = 0.9 # 左に行動する確率を指定 p_left = 0.5 # 状態価値を初期化 V = {'L1': 0.0, 'L2': 0.0} # 更新用の状態価値を作成 new_V = V.copy() # 推移の記録用のリストを作成 trace_V = [list(V.values())] # 初期値を記録 # 繰り返し試行 for i in range(max_iter): # 状態価値関数を計算:式(4.2) new_V['L1'] = p_left * (-1 + gamma * V['L1']) + (1-p_left) * (1 + gamma * V['L2']) new_V['L2'] = p_left * (0 + gamma * V['L1']) + (1-p_left) * (-1 + gamma * V['L2']) # 状態価値を更新 V = new_V.copy() # 更新値を記録 trace_V.append(list(V.values())) # 途中経過を表示 print(['V_'+str(i+1)+'('+state+')='+str(np.round(value, 3)) for state, value in V.items()]) # 作図用にNumPy配列に変換 trace_V = np.array(trace_V).T
['V_1(L1)=0.0', 'V_1(L2)=-0.5']
['V_2(L1)=-0.225', 'V_2(L2)=-0.725']
['V_3(L1)=-0.427', 'V_3(L2)=-0.927']
(省略)
['V_99(L1)=-2.25', 'V_99(L2)=-2.75']
['V_100(L1)=-2.25', 'V_100(L2)=-2.75']
L2の状態価値関数$V_{k+1}(L2)$の計算(更新)に、1試行前(更新前)のL1の状態価値関数$V_k(L1)$を用います。そこで、更新前の状態価値関数の値をV
として計算を行い、更新後の状態価値関数の値(計算結果)をnew_V
として一時的に保存しておき、2つの状態価値関数の計算後にnew_V
でV
を上書きします。
状態価値関数の推定値(更新値)の推移をグラフで確認します。
# 真の状態価値を指定 v1 = -2.25 v2 = -2.75 # 状態価値関数の推移を作図 plt.figure(figsize=(8, 6)) plt.plot(trace_V[0], color='blue', label='$V_k(L1)$') # L1の推定状態価値 plt.plot(trace_V[1], color='orange', label='$V_k(L2)$') # L2の推定状態価値 plt.hlines(y=v1, xmin=0, xmax=max_iter, color='blue', linestyle=':', label='$v_{\pi}(L1)$') # L1の真の状態価値 plt.hlines(y=v2, xmin=0, xmax=max_iter, color='orange', linestyle=':', label='$v_{\pi}(L2)$') # L2の真の状態価値 plt.xlabel('iteration (k)') plt.ylabel('state-value') plt.suptitle('Iterative Policy Evaluation', fontsize=20) plt.title('$\gamma='+str(gamma) + ', \pi='+str([p_left, 1-p_left])+'$', loc='left') plt.grid() plt.legend() plt.show()
試行回数が増えるに従って、真の状態価値に近付いているのが分かります。
・実装2
実装1では、指定した試行回数に従い更新を繰り返しました。次は、閾値を指定して更新幅が閾値未満になるまで更新を繰り返します。
閾値threshold
を指定して、1回の試行で更新された値delta
がthreshold
未満になるまで処理を繰り返します。
# 閾値を指定 threshold = 0.0001 # 割引率を指定 gamma = 0.9 # 左に行動する確率を指定 p_left = 0.5 # 状態価値を初期化 V = {'L1': 0.0, 'L2': 0.0} # 更新用の状態価値を作成 new_V = V.copy() # 推移の記録用のリストを作成 trace_V = [list(V.values())] # 初期値を記録 # 試行回数のカウントを初期化 cnt = 0 # 繰り返し試行 while True: # 状態価値関数を計算:式(4.2) new_V['L1'] = p_left * (-1 + gamma * V['L1']) + (1-p_left) * (1 + gamma * V['L2']) new_V['L2'] = p_left * (0 + gamma * V['L1']) + (1-p_left) * (-1 + gamma * V['L2']) # 更新幅の最大値を計算 delta_L1 = abs(new_V['L1'] - V['L1']) delta_L2 = abs(new_V['L2'] - V['L2']) delta = max(delta_L1, delta_L2) # 状態価値を更新 V = new_V.copy() # 更新値を記録 trace_V.append(list(V.values())) # 試行回数をカウント cnt += 1 # 途中経過を表示 print(['V_'+str(cnt)+'('+state+')='+str(np.round(value, 3)) for state, value in V.items()]) # 更新幅が閾値未満になると終了 if delta < threshold: break # 作図用にNumPy配列に変換 trace_V = np.array(trace_V).T
['V_1(L1)=0.0', 'V_1(L2)=-0.5']
['V_2(L1)=-0.225', 'V_2(L2)=-0.725']
['V_3(L1)=-0.427', 'V_3(L2)=-0.927']
(省略)
['V_75(L1)=-2.249', 'V_75(L2)=-2.749']
['V_76(L1)=-2.249', 'V_76(L2)=-2.749']
L1・L2の状態価値について、それぞれ更新前後の値の差をabs()
で絶対値(正の値)にします。
L1・L2それぞれの更新値の絶対値をmax()
で比較して、大きい方の値をdelta
とします。
更新幅の最大値delta
を閾値threshold
と比較して、閾値未満になればbreak
で処理を終了します。while
文は、条件がTrue
であれば(条件がFalse
になるまで)処理を繰り返します。while True:
とすることで、無限に処理を繰り返します。
更新値の計算自体は実験1と変わらないので、先ほどと同じ推移になります。
実験1の作図コードのmax_iter
をcnt
に置き換えて、状態価値関数の推移をグラフで確認します。
# 真の状態価値を指定 v1 = -2.25 v2 = -2.75 # 状態価値関数の推移を作図 plt.figure(figsize=(8, 6)) plt.plot(trace_V[0], color='blue', label='$V_k(L1)$') # L1の推定状態価値 plt.plot(trace_V[1], color='orange', label='$V_k(L2)$') # L2の推定状態価値 plt.hlines(y=v1, xmin=0, xmax=cnt, color='blue', linestyle=':', label='$v_{\pi}(L1)$') # L1の真の状態価値 plt.hlines(y=v2, xmin=0, xmax=cnt, color='orange', linestyle=':', label='$v_{\pi}(L2)$') # L2の真の状態価値 plt.xlabel('iteration (k)') plt.ylabel('state-value') plt.suptitle('Iterative Policy Evaluation', fontsize=20) plt.title('$\gamma='+str(gamma) + ', \pi='+str([p_left, 1-p_left])+'$', loc='left') plt.grid() plt.legend() plt.show()
実装1のときよりも試行回数が減っています。実装2では、無駄に処理を繰り返さず、収束前に処理を止めません。ただし、収束しない場合にはいつまでも処理が終わらない点に注意が必要です。
・実装3
これまでは、2つのディクショナリを使って、L1とL2それぞれの状態価値関数を計算してセットで更新しました。次は、1つのディクショナリと中間変数を使って、それぞれ計算と更新を行います。
中間変数t
を使って、更新幅が閾値未満になるまで処理を繰り返します。
# 閾値を指定 threshold = 0.0001 # 割引率を指定 gamma = 0.9 # 左に行動する確率を指定 p_left = 0.5 # 状態価値を初期化 V = {'L1': 0.0, 'L2': 0.0} # 推移の記録用のリストを作成 trace_V = [list(V.values())] # 初期値を記録 # 試行回数のカウントを初期化 cnt = 0 # 繰り返し試行 while True: # L1の状態価値関数を計算:式(4.2) t = p_left * (-1 + gamma * V['L1']) + (1-p_left) * (1 + gamma * V['L2']) # L1の更新幅を計算 delta = abs(t - V['L1']) # L1の状態価値を更新 V['L1'] = t # L2の状態価値関数を計算:式(4.2) t = p_left * (0 + gamma * V['L1']) + (1-p_left) * (-1 + gamma * V['L2']) # 更新幅の最大値を計算 delta = max(delta, abs(t - V['L2'])) # L2の状態価値を更新 V['L2'] = t # 試行回数をカウント cnt += 1 # 更新値を記録 trace_V.append(list(V.values())) # 途中経過を表示 print(['V_'+str(cnt)+'('+state+')='+str(np.round(value, 3)) for state, value in V.items()]) # 更新幅が閾値未満になると終了 if delta < threshold: break # 作図用にNumPy配列に変換 trace_V = np.array(trace_V).T
['V_1(L1)=0.0', 'V_1(L2)=-0.5']
['V_2(L1)=-0.225', 'V_2(L2)=-0.826']
['V_3(L1)=-0.473', 'V_3(L2)=-1.085']
(省略)
['V_59(L1)=-2.249', 'V_59(L2)=-2.749']
['V_60(L1)=-2.249', 'V_60(L2)=-2.749']
L1・L2どちらの更新値(状態価値関数の計算結果)もt
として保持します。
L2の状態価値関数の計算の前に、L1の状態価値V['L1']
をt
で上書きします。それにより、更新後のL1の状態価値を使ってL2の状態価値を更新します。
また、「L1・L2それぞれの更新値t
」を使って「更新幅の最大値delta
」を計算し、delta
が閾値threshold
未満になると処理を終了します。
これまでと計算に用いる状態価値関数が変わっているので、計算結果(推移)も変わります。
実装2の作図コードで、状態価値関数の推移をグラフで確認します。
実験2のときよりも早く収束しています。1つのディクショナリを使う(メモリ効率が良い)だけでなく、計算処理の効率も良くなったのが分かります。
この節では、2マスのグリッドワールドにおける反復方策評価を行いました。次節からは、3×4マスのグリッドワールドを扱います。
参考文献
おわりに
ここからは(プログラム的な意味で)手を動かしながら進めます!
【次節の内容】