はじめに
『ゼロから作るDeep Learning 4 ――強化学習編』の独学時のまとめノートです。初学者の補助となるようにゼロつくシリーズの4巻の内容に解説を加えていきます。本と一緒に読んでください。
この記事は、6.1節の内容です。TD法により状態価値関数を推定します。
【前節の内容】
【他の記事一覧】
【この記事の内容】
6.1 TD法による方策評価
TD法により状態価値関数を推定(方策を評価)します。
6.1.1 TD法の導出
まずは、TD法による状態価値関数の更新式を導出します。
数式の確認
状態価値関数の定義式とMC法とDP法による計算式を確認して、TD法による計算式を導出します。
状態価値関数の定義式
各時刻の収益は、割り引き報酬和で定義されました(2.3.2項・3.1.2項)。
ここで、$\gamma$は割引率で$0 \leq \gamma \leq 1$の値を指定します。
また、状態価値関数は、状態が$s$のときの収益$G_t$の期待値で定義されました(2.3.3項・3.1.2項)。
アルゴリズムごとに状態価値関数(収益の期待値)を推定(近似)する方法(計算式)が異なります。
MC法とDP法の更新式
MC法では、収益と状態のサンプルを用いて、状態価値関数(6.3)を近似するのでした(5.2.1項)。
状態価値関数(収益の期待値)の推定値(近似値)として、指数移動平均を用いました(5.4.2項)。
ここで、$\alpha$は学習率で$0 < \alpha < 1$の値を指定します。また、更新後の状態価値関数を$V'_{\pi}$で表します。
繰り返し収益のサンプルを生成して状態価値関数を更新することで、推定値$V_{\pi}(s)$を真の値$v_{\pi}(s)$に近付けます。
DP法では、ベルマン方程式を用いて、状態価値関数(6.4)を計算するのでした(4.1-2節)。
全ての状態でこの計算を行います。
状態ごとに更新を繰り返すことで真の値$v_{\pi}(s)$が得られます。
MC法とDP法の更新式を組み合わせて、TD法の更新式を求めます。
TD法の更新式
ベルマン方程式(6.6)を、確率論的方策$\pi(a | s)$と状遷移確率$p(s' | s, a)$の同時分布を状態$a$について周辺化して、状態$s$を条件とする期待値の項に変形します(3.1.2項)。
$r(s, a, s') + \gamma v_{\pi}(s')$の期待値を、「次の状態のサンプル$S_{t+1}$」と「報酬のサンプル$R_t$」
を用いて求めます。
$R_t + \gamma V_{\pi}(S_{t+1})$の期待値(6.8)を、指数移動平均で近似します。
DP法による状態価値関数の更新式が得られました。
DP法による更新式(6.9)を展開してみます。$k$回更新した(更新前の)状態関数を$V_k$、$k+1$回更新した(更新後の)状態価値を$V_{k+1}$、状態$S_t$における$k$個目の報酬のサンプルを$R^{(k)}$、状態$S_t$における$k$個目の次の状態のサンプルを$S^{(k)}$で表します。
後の項について
で置き換えられます。
同様に繰り返すと、次の式になります。
過去のTDターゲット$R^{(n)} + \gamma V_{\pi}(S^{(n)})$ほど$(1 - \alpha)^{k-n}$が小さくなるため、推定値への影響が小さくなります。
6.1.3 TD法の実装
次は、TD法により状態価値関数の推定を行うエージェントを実装します。
利用するライブラリを読み込みます。
# ライブラリを読み込み import numpy as np from collections import defaultdict # 追加ライブラリ import matplotlib.pyplot as plt from matplotlib.colors import LinearSegmentedColormap from matplotlib.animation import FuncAnimation
更新推移をアニメーションで確認するのにmatplotlib
のモジュールを利用します。不要であれば省略してください。
また、3×4マスのグリッドワールドのクラスGridWorld
を読み込みます。
# 実装済みのクラスと関数を読み込み import sys sys.path.append('../deep-learning-from-scratch-4-master') from common.gridworld import GridWorld
実装済みクラスの読み込みについては「3.6.1:MNISTデータセットの読み込み【ゼロつく1のノート(Python)】 - からっぽのしょこ」、GridWorld
クラスについては「4.2.1:GridWorldクラスの実装:評価と改善に関するメソッド【ゼロつく4のノート】 - からっぽのしょこ」「4.2.1:GridWorldクラスの実装:可視化に関するメソッド【ゼロつく4のノート】 - からっぽのしょこ」を参照してください。
処理の確認
TdAgent
クラスのeval
メソッドの内部で行う処理を確認します。他のメソッドについては「5.3:モンテカルロ法による方策評価の実装【ゼロつく4のノート】 - からっぽのしょこ」を参照してください。
例として、ランダムな値の状態価値関数を作成しておきます。
# 環境のインスタンスを作成 env = GridWorld() # (仮の)状態価値関数を作成 V = {state: np.random.rand() for state in env.states()} print(list(V.keys())) print(np.round(list(V.values()), 3))
[(0, 0), (0, 1), (0, 2), (0, 3), (1, 0), (1, 1), (1, 2), (1, 3), (2, 0), (2, 1), (2, 2), (2, 3)]
[0.836 0.491 0.267 0.768 0.247 0.414 0.437 0.781 0.858 0.104 0.156 0.353]
全てのマスに対応する状態価値をディクショナリに格納しておきます(ゴールマスと壁マスにも値が設定されています)。
次の状態の状態価値を取り出します。
# (仮の)次の状態を設定 next_state = (0, 2) # (仮の)ゴールフラグを設定 done = False #done = True # 次の状態の状態価値を取得 if done: next_V = 0 else: next_V = V[next_state] print(next_V)
0.26690616281474955
「次の状態next_state
」をキーとして、「状態価値関数V
」から値を取り出して「次の状態の状態価値next_V
」とします。ただし、次の状態がゴールマスのときは、状態価値を0
とします。
現在の状態の状態価値を計算して、値を更新します。
# 収益の計算用の割引率を指定 gamma = 0.9 # 状態価値の計算用の学習率 alpha = 0.01 # (仮の)現在の状態を設定 state = (0, 1) print(V[state]) # (仮の)報酬を設定 reward = 1 # TDターゲットを計算 target = reward + gamma * next_V # 現在の状態の状態価値を更新:式(6.9) V[state] += (target - V[state]) * alpha print(V[state])
0.49067951041983626
0.49817487078097067
「現在の状態state
」をキーとして、式(6.9)により状態価値を計算して、「現在の状態の状態価値V[state]
」を更新します。
以上が、TD法による方策評価を行うエージェントの処理です。
実装
処理の確認ができたので、TD法におけるエージェントをクラスとして実装します。
# TD法によるエージェントの実装 class TdAgent: # 初期化メソッドの定義 def __init__(self): # パラメータを指定 self.gamma = 0.9 # 収益の計算用の割引率 self.alpha = 0.01 # 状態価値の計算用の学習率 self.action_size = 4 # 行動の種類数 # オブジェクトを初期化 random_actions = {0: 0.25, 1: 0.25, 2: 0.25, 3: 0.25} # 確率論的方策 self.pi = defaultdict(lambda: random_actions) # ターゲット方策・挙動方策 self.V = defaultdict(lambda: 0) # 状態価値関数 # 行動メソッドの定義 def get_action(self, state): # 現在の状態の挙動方策の確率分布を取得 action_probs = self.pi[state] # 確率分布 actions = list(action_probs.keys()) # 行動番号 probs = list(action_probs.values()) # 行動確率 # 確率論的方策に従う行動を出力 return np.random.choice(actions, p=probs) # 更新メソッドの定義 def eval(self, state, reward, next_state, done): # 次の状態の状態価値を取得 next_V = 0 if done else self.V[next_state] # ゴールの場合は0を設定 # 現在の状態の状態価値を更新:式(6.9) target = reward + self.gamma * next_V self.V[state] += (target - self.V[state]) * self.alpha
実装したクラスを試してみましょう。
環境(グリッドワールド)とエージェントのインスタンスを作成して、1エピソードの処理を行います。
# 環境・エージェントのインスタンスを作成 env = GridWorld() agent = TdAgent() # 行動の表示用のリストを作成 arrows = ['↑', '↓', '←', '→'] # 最初の状態を設定 state = env.start_state # 時刻(試行回数)を初期化 t = 0 # 1エピソードのシミュレーション while True: # 時刻をカウント t += 1 # 確率論的方策に従い行動を決定 action = agent.get_action(state) # サンプルデータを取得 next_state, reward, done = env.step(action) # 現在の状態の状態価値関数を更新:式(6.9) agent.eval(state, reward, next_state, done) # サンプルデータを表示 print( 't=' + str(t) + ', S_t=' + str(state) + ', A_t=' + arrows[action] + ', S_t+1=' + str(next_state) + ', R_t=' + str(reward) ) # ゴールに着いたらエピソードを終了 if done: break # 状態を更新 state = next_state
t=1, S_t=(2, 0), A_t=↑, S_t+1=(1, 0), R_t=0
t=2, S_t=(1, 0), A_t=↑, S_t+1=(0, 0), R_t=0
t=3, S_t=(0, 0), A_t=↓, S_t+1=(1, 0), R_t=0
t=4, S_t=(1, 0), A_t=→, S_t+1=(1, 0), R_t=0
t=5, S_t=(1, 0), A_t=↑, S_t+1=(0, 0), R_t=0
t=6, S_t=(0, 0), A_t=→, S_t+1=(0, 1), R_t=0
t=7, S_t=(0, 1), A_t=→, S_t+1=(0, 2), R_t=0
t=8, S_t=(0, 2), A_t=↓, S_t+1=(1, 2), R_t=0
t=9, S_t=(1, 2), A_t=→, S_t+1=(1, 3), R_t=-1.0
t=10, S_t=(1, 3), A_t=←, S_t+1=(1, 2), R_t=0
t=11, S_t=(1, 2), A_t=↓, S_t+1=(2, 2), R_t=0
t=12, S_t=(2, 2), A_t=→, S_t+1=(2, 3), R_t=0
t=13, S_t=(2, 3), A_t=↑, S_t+1=(1, 3), R_t=-1.0
t=14, S_t=(1, 3), A_t=←, S_t+1=(1, 2), R_t=0
t=15, S_t=(1, 2), A_t=→, S_t+1=(1, 3), R_t=-1.0
t=16, S_t=(1, 3), A_t=←, S_t+1=(1, 2), R_t=0
t=17, S_t=(1, 2), A_t=↑, S_t+1=(0, 2), R_t=0
t=18, S_t=(0, 2), A_t=↓, S_t+1=(1, 2), R_t=0
t=19, S_t=(1, 2), A_t=←, S_t+1=(1, 2), R_t=0
t=20, S_t=(1, 2), A_t=↑, S_t+1=(0, 2), R_t=0
t=21, S_t=(0, 2), A_t=→, S_t+1=(0, 3), R_t=1.0
agent
のget_action()
で方策に従い行動して、env
のstep()
で状態を遷移し報酬を出力します。
得られたサンプルデータ(現在の状態・報酬・次の状態)を使って、agent
のeval()
で現在の状態の状態価値関数を計算します。
ゴールマスに着くとdone
がTrue
に設定されるので、break
でループ処理を終了します。
状態価値関数をヒートマップで確認します。
# 状態価値関数のヒートマップを作成
env.render_v(v=agent.V, policy=agent.pi)
得られたサンプルデータの数によって、状態ごとに更新された回数が異なります。
以上で、TD法のエージェントを実装できました。
・TD法による方策評価
最後に、TD法により状態価値関数を推定して、更新の推移を確認します。
推定
TD法により状態価値関数を繰り返し更新します。
# 環境・エージェントのインスタンスを作成 env = GridWorld() agent = TdAgent() # エピソード数を指定 episodes = 1000 # 推移の可視化用のリストを初期化 trace_V = [{state: agent.V[state] for state in env.states()}] # 初期値を記録 # 繰り返しシミュレーション for episode in range(episodes): # 状態を初期化 state = env.reset() # 時刻(試行回数)を初期化 t = 0 # 1エピソードのシミュレーション while True: # 時刻をカウント t += 1 # ランダムに行動を決定 action = agent.get_action(state) # サンプルデータを取得 next_state, reward, done = env.step(action) # 現在の状態の状態価値関数を更新:式(6.9) agent.eval(state, reward, next_state, done) # ゴールに着いた場合 if done: # 更新値を記録 trace_V.append(agent.V.copy()) # 総時刻を表示 print('episode '+str(episode+1) + ': T='+str(t)) # エピソードを終了 break # 状態を更新 state = next_state
episode 1: T=14
episode 2: T=37
episode 3: T=61
episode 4: T=41
episode 5: T=21
(省略)
episode 996: T=77
episode 997: T=27
episode 998: T=22
episode 999: T=11
episode 1000: T=15
スタートマスからランダムに行動し、ゴールマスに着くまでを1エピソードとします。エピソードごとに、GridWorld
クラスのreset()
メソッドで状態を初期化し(エージェントをスタートマスに戻し)ます。
episodes
に指定した回数のシミュレーションを行い、時刻(状態)ごとにagent
のevals()
で状態価値関数を更新します。
推移の確認用に、状態価値関数の更新値をtrace_V
に格納していきます。
推定した状態価値関数をヒートマップで確認します。
# 状態価値関数のヒートマップを作図
env.render_v(v=agent.V, policy=agent.pi)
結果の解釈については本を参照してください。
更新推移の可視化
ここまでで、繰り返しの更新処理を確認しました。続いて、途中経過をアニメーションで確認します。
状態価値関数のヒートマップのアニメーションを作成します。
・作図コード(クリックで展開)
# グリッドマップのサイズを取得 xs = env.width ys = env.height # 状態価値の最大値・最小値を取得 vmax = max([max(trace_V[i].values()) for i in range(len(trace_V))]) vmin = min([min(trace_V[i].values()) for i in range(len(trace_V))]) # 色付け用に最大値・最小値を再設定 vmax = max(vmax, abs(vmin)) vmin = -1 * vmax vmax = 1 if vmax < 1 else vmax vmin = -1 if vmin > -1 else vmin # カラーマップを設定 color_list = ['red', 'white', 'green'] cmap = LinearSegmentedColormap.from_list('colormap_name', color_list) # 図を初期化 fig = plt.figure(figsize=(10, 7.5), facecolor='white') # 図の設定 plt.suptitle('TD Method', fontsize=20) # 全体のタイトル # 作図処理を関数として定義 def update(i): # 前フレームのグラフを初期化 plt.cla() # i回目の更新値を取得 pi = agent.pi V = trace_V[i] # ディクショナリを配列に変換 v = np.zeros((env.shape)) for state, value in V.items(): v[state] = value # 状態価値のヒートマップを描画 plt.pcolormesh(np.flipud(v), cmap=cmap, vmin=vmin, vmax=vmax) # ヒートマップ # マス(状態)ごとに処理 for state in env.states(): # インデックスを取得 y, x = state # 報酬を抽出 r = env.reward_map[state] # 報酬がある場合 if r != 0 and r is not None: # 報酬ラベル用の文字列を作成 txt = 'R ' + str(r) # ゴールの場合 if state == env.goal_state: # 報酬ラベルにゴールを追加 txt = txt + ' (GOAL)' # 報酬ラベルを描画 plt.text(x=x+0.1, y=ys-y-0.9, s=txt, ha='left', va='bottom', fontsize=15) # 壁以外の場合 if state != env.wall_state: # 状態価値ラベルを描画 plt.text(x=x+0.9, y=ys-y-0.1, s=str(np.round(v[y, x], 3)), ha='right', va='top', fontsize=15) # 確率論的方策を抽出 actions = pi[state] # 確率が最大の行動を抽出 max_actions = [k for k, v in actions.items() if v == max(actions.values())] # 矢印の描画用のリストを作成 arrows = ['↑', '↓', '←', '→'] offsets = [(0, 0.1), (0, -0.1), (-0.1, 0), (0.1, 0)] # 行動ごとに処理 for action in max_actions: # 矢印の描画用の値を抽出 arrow = arrows[action] offset = offsets[action] # ゴールの場合 if state == env.goal_state: # 描画せず次の状態へ continue # 方策ラベル(矢印)を描画 plt.text(x=x+0.5+offset[0], y=ys-y-0.5+offset[1], s=arrow, ha='center', va='center', size=20) # 壁の場合 if state == env.wall_state: # 壁を描画 rect = plt.Rectangle(xy=(x, ys-y-1), width=1, height=1, fc=(0.4, 0.4, 0.4, 1.0)) # 長方形を作成 plt.gca().add_patch(rect) # 重ねて描画 # グラフの設定 plt.xticks(ticks=np.arange(xs)) # x軸の目盛位置 plt.yticks(ticks=np.arange(ys), labels=ys-np.arange(ys)-1) # y軸の目盛位置 plt.xlim(xmin=0, xmax=xs) # x軸の範囲 plt.ylim(ymin=0, ymax=ys) # y軸の範囲 plt.tick_params(labelbottom=False, labelleft=False, labelright=False, labeltop=False) # 軸ラベル plt.grid() # グリッド線 plt.title('episode:'+str(i), loc='left') # タイトル # gif画像を作成 anime = FuncAnimation(fig=fig, func=update, frames=len(trace_V), interval=50) # gif画像を保存 anime.save('ch6_1.gif')
各エピソードで更新した状態価値関数をtrace_V
から取り出してヒートマップを描画する処理を関数update()
として定義して、FuncAnimation()
でアニメーション(gif画像)を作成します。
状態価値関数の更新値の推移を折れ線グラフで確認します。
・作図コード(クリックで展開)
# 状態価値関数の推移を作図 plt.figure(figsize=(12, 9), facecolor='white') # 状態ごとに推移を作図 for state in env.states(): # マスのインデックスを取得 h, w = state # 更新値を抽出 v_vals = [trace_V[i][state] for i in range(episodes+1)] # 推移を描画 plt.plot(np.arange(episodes+1), v_vals, label='$V_i(L_{'+str(h)+','+str(w)+'})$') # 各状態の価値の推移 plt.xlabel('episode') plt.ylabel('state-value') plt.suptitle('TD Method', fontsize=20) plt.title('$\gamma='+str(agent.gamma) + ', \\alpha='+str(agent.alpha)+'$', loc='left') plt.grid() plt.legend(loc='upper left', bbox_to_anchor=(1, 1)) plt.show()
行番号を$h$、列番号を$w$として各マスを$L_{h,w}$で表します(図4-9)。また、$i$回目のエピソード終了時点の状態価値を$V_i(L_{h,w})$で表します。
各曲線の縦軸の値が、ヒートマップの色に対応します。
この節では、TD法により方策評価を行うエージェントを実装して、状態価値関数を求めました。次節では、方策オフ型のSARSAを実装して、行動価値関数を求めます。
参考文献
おわりに
更新式の導出で少し悩みましたが、現時点で6章の内容は概ね書けてて順調に進めそうです。
投稿日に公開された新曲をどうぞ🍵
大先輩も大変だ。
【次節の内容】