はじめに
『ゼロから作るDeep Learning 4 ――強化学習編』の独学時のまとめノートです。初学者の補助となるようにゼロつくシリーズの4巻の内容に解説を加えていきます。本と一緒に読んでください。
この記事は、6.3節の内容です。方策オフ型のSARSAにより行動価値関数を推定します。
【前節の内容】
【他の記事一覧】
【この記事の内容】
6.3 方策オフ型のSARSA
方策オフ型のSARSAにより行動価値関数と方策を推定(方策を制御)します。SARSAについては「6.2:SARSA【ゼロつく4のノート】 - からっぽのしょこ」、重点サンプリングについては「5.5:重点サンプリング【ゼロつく4のノート】 - からっぽのしょこ」を参照してください。
6.3.1 方策オフ型と重点サンプリング
まずは、方策オフ型のSARSAによる行動価値関数の更新式を導出します。
数式の確認
方策オン型のSARSAによる行動価値関数の更新式は、次の式でした(6.2.1項)。
ここで、$\gamma$は収益の割引率、$\alpha$は指数移動平均の学習率です。
状態$S_t = s$と行動$A_t = a$は与えられています。サンプリングした報酬$R_t$・状態$S_{t+1}$・行動$A_{t+1}$を用いて、「TDターゲット$R_t + \gamma Q_{\pi}(S_{t+1}, A_{t+1})$の期待値」を指数移動平均で近似します。
方策オフ型では、ターゲット方策$\pi(a | s)$ではなく、挙動方策$b(a | s)$から行動をサンプリングすることを考えます(5.5.2項)。
「ターゲット方策による期待値$\mathbb{E}_{\pi}[\cdot]$」と「挙動方策による期待値$\mathbb{E}_b[\cdot]$」を補正する重み(重点サンプリングの重み)$\rho$は、サンプル$R_t, S_{t+1}, A_{t+1}$が得られる確率の比率です。
ターゲット方策と挙動方策それぞれで$A_{t+1}$が得られる確率の比で計算できます。
TDターゲットに重みを掛けて、指数移動平均を計算します。
方策オフ型のSARSAによる行動価値関数の更新式が得られました。
方策の更新式は、方策オン型のSARSAと同じです。
6.3.2 方策オフ型のSARSAの実装
次は、方策オフ型のSARSAにより行動価値関数と方策の推定を行うエージェントを実装します。
利用するライブラリを読み込みます。
# ライブラリを読み込み import numpy as np from collections import defaultdict, deque # 追加ライブラリ import matplotlib.pyplot as plt from matplotlib.colors import LinearSegmentedColormap from matplotlib.animation import FuncAnimation
更新推移をアニメーションで確認するのにmatplotlib
のモジュールを利用します。不要であれば省略してください。
また、3×4マスのグリッドワールドのクラスGridWorld
を読み込みます。
# 実装済みのクラスと関数を読み込み import sys sys.path.append('../deep-learning-from-scratch-4-master') from common.gridworld import GridWorld from common.utils import greedy_probs
実装済みクラスの読み込みについては「3.6.1:MNISTデータセットの読み込み【ゼロつく1のノート(Python)】 - からっぽのしょこ」、GridWorld
クラスについては「4.2.1:GridWorldクラスの実装:評価と改善に関するメソッド【ゼロつく4のノート】 - からっぽのしょこ」「4.2.1:GridWorldクラスの実装:可視化に関するメソッド【ゼロつく4のノート】 - からっぽのしょこ」、greedy_probs
関数については「5.4.3-5:モンテカルロ法による方策反復法の実装【ゼロつく4のノート】 - からっぽのしょこ」を参照してください。
処理の確認
SarsaOffPolicyAgent
クラスのupdate
メソッドの内部で行う処理を確認します。SARSAの処理については「6.2:SARSA【ゼロつく4のノート】 - からっぽのしょこ」も参照してください。
例として、ランダムな値の行動価値関数を作成しておきます。
# (仮の)前の時刻のサンプルデータを設定 state, action, reward, done = (0, 1), 1, 1, False # (仮の)現在の時刻のサンプルを設定 next_state, next_action = (0, 2), 3 # 行動の種類数を指定 action_size = 4 # (仮の)行動価値関数を作成 Q = {(s, a): np.random.rand() for s in [state, next_state] for a in range(action_size)} print(list(Q.keys())) print(np.round(list(Q.values()), 3))
[((0, 1), 0), ((0, 1), 1), ((0, 1), 2), ((0, 1), 3), ((0, 2), 0), ((0, 2), 1), ((0, 2), 2), ((0, 2), 3)]
[0.193 0.98 0.745 0.648 0.325 0.641 0.322 0.872]
前の状態state
と現在の状態next_state
ごとに、上下左右の4つの行動に対する値をディクショナリに格納します。
また、ターゲット方策と挙動方策を作成します。
# (仮の)ターゲット方策を作成 pi = {s: {a: r for a, r in zip(range(action_size), np.random.dirichlet(alpha=np.repeat(1, action_size)))} for s in [state, next_state]} print(pi[state]) print(np.sum(list(pi[state].values()))) # (仮の)挙動方策を作成 b = {s: {a: r for a, r in zip(range(action_size), np.random.dirichlet(alpha=np.repeat(1, action_size)))} for s in [state, next_state]} print(b[state]) print(np.sum(list(b[state].values())))
{0: 0.00402079485288214, 1: 0.35592461935354064, 2: 0.09874699844883401, 3: 0.5413075873447433}
1.0
{0: 0.29595268016002374, 1: 0.4312824333875585, 2: 0.005577649438852409, 3: 0.26718723701356545}
1.0000000000000002
ランダムな確率分布を生成するのにディリクレ分布の乱数np.random.dirichlet()
を使っていますが、実装には無関係なので説明を省略します。
現在の状態・行動の行動価値を取り出します。また、重点サンプリングの重みを計算します。
# ゴールの場合 if done: # 現在の時刻の行動価値を0に設定 next_q = 0 # 重みを1に設定 rho = 1 # ゴール以外の場合 else: # 現在の時刻の行動価値を取得 next_q = Q[next_state, next_action] # 重みを計算 rho = pi[next_state][next_action] / b[next_state][next_action] print(next_q) print(rho)
0.871771237208903
1.8829214456962553
「現在の時刻の状態next_state
と行動next_action
」をキーとして、「行動価値関数Q
」から値を取り出して「現在の時刻の行動価値next_q
」とします。ただし、現在の状態がゴールマスのときは、行動価値を0
にします。
「現在の状態next_state
」をキーとしてそれぞれの方策から「状態next_state
の確率論的方策」を取り出し、さらに「行動next_action
」キーとして「状態next_state
で行動next_action
を取る確率」を取り出して、重みを計算します。ただし、現在の状態がゴールマスのときは、重みを1
にします。
前の時刻の行動価値を計算して、値を更新します。
# 収益の計算用の割引率を指定 gamma = 0.9 # 状態価値の計算用の学習率 alpha = 0.01 # TDターゲットを計算 target = rho * (reward + gamma * next_q) # 前の時刻の行動価値関数を更新:式(6.13) Q[state, action] += (target - Q[state, action]) * alpha print(Q[state, action])
1.003785563768994
「前の時刻の状態state
と行動action
」をキーとして、式(6.13)により行動価値を計算して、「前の時刻の行動価値Q[state, action]
」を更新します。
以上が、方策オフ型のSARSAによる方策制御を行うエージェントの処理です。
実装
処理の確認ができたので、方策オフ型のSARSAにおけるエージェントをクラスとして実装します。
# 方策オフ型のSARSAによるエージェントの実装 class SarsaOffPolicyAgent: # 初期化メソッドの定義 def __init__(self): # パラメータを指定 self.gamma = 0.9 # 収益の計算用の割引率 self.alpha = 0.8 # 状態価値の計算用の学習率 self.epsilon = 0.1 # ランダムに行動する確率 self.action_size = 4 # 行動の種類数 # オブジェクトを初期化 random_actions = {0: 0.25, 1: 0.25, 2: 0.25, 3: 0.25} # 確率論的方策 self.pi = defaultdict(lambda: random_actions) # ターゲット方策 self.b = defaultdict(lambda: random_actions) # 挙動方策 self.Q = defaultdict(lambda: 0) # 行動価値関数 self.memory = deque(maxlen=2) # サンプルデータ # 行動メソッドの定義 def get_action(self, state): # 現在の状態の挙動方策の確率分布を取得 action_probs = self.b[state] # 確率分布 actions = list(action_probs.keys()) # 行動番号 probs = list(action_probs.values()) # 行動確率 # 確率論的方策に従う行動を出力 return np.random.choice(actions, p=probs) # サンプルデータの初期化メソッドの定義 def reset(self): # リストを初期化 self.memory.clear() # 更新メソッドの定義 def update(self, state, action, reward, done): # 現在の時刻のサンプルデータを格納 self.memory.append((state, action, reward, done)) # 初回は更新しない if len(self.memory) < 2: return # サンプルデータを取得 state, action, reward, done = self.memory[0] # 前の時刻 next_state, next_action, _, _ = self.memory[1] # 現在の時刻 # ゴールの場合 if done: # 現在の時刻の行動価値を0に設定 next_q = 0 # 重みを1に設定 rho = 1 # ゴール以外の場合 else: # 現在の時刻の行動価値を取得 next_q = self.Q[next_state, next_action] # 重みを計算 rho = self.pi[next_state][next_action] / self.b[next_state][next_action] # 前の時刻の行動価値関数を更新:式(6.13) target = rho * (reward + self.gamma * next_q) self.Q[state, action] += (target - self.Q[state, action]) * self.alpha # greedy法によりターゲット方策を更新 self.pi[state] = greedy_probs(self.Q, state, 0) # ε-greedy法により挙動方策を更新:式(6.11) self.b[state] = greedy_probs(self.Q, state, self.epsilon)
実装したクラスを試してみましょう。
環境(グリッドワールド)とエージェントのインスタンスを作成して、1エピソードの処理を行います。
# 環境・エージェントのインスタンスを作成 env = GridWorld() agent = SarsaOffPolicyAgent() # 行動の表示用のリストを作成 arrows = ['↑', '↓', '←', '→'] # 最初の状態を設定 state = env.start_state # 時刻(試行回数)を初期化 t = 0 # 1エピソードのシミュレーション while True: # 時刻をカウント t += 1 # 挙動方策(ε-greedy法)により行動を決定 action = agent.get_action(state) # サンプルデータを取得 next_state, reward, done = env.step(action) # 前の状態・行動の行動価値関数・方策を更新:式(6.13,11) agent.update(state, action, reward, done) # サンプルデータを表示 print( 't=' + str(t) + ', S_t=' + str(state) + ', A_t=' + arrows[action] + ', S_t+1=' + str(next_state) + ', R_t=' + str(reward) ) # ゴールに着いた場合 if done: # 現在の状態・行動の行動価値関数・方策を更新:式(6.13,11) agent.update(next_state, None, None, None) # エピソードを終了 break # 状態を更新 state = next_state
t=1, S_t=(2, 0), A_t=→, S_t+1=(2, 1), R_t=0
t=2, S_t=(2, 1), A_t=→, S_t+1=(2, 2), R_t=0
t=3, S_t=(2, 2), A_t=↓, S_t+1=(2, 2), R_t=0
t=4, S_t=(2, 2), A_t=↓, S_t+1=(2, 2), R_t=0
t=5, S_t=(2, 2), A_t=↑, S_t+1=(1, 2), R_t=0
t=6, S_t=(1, 2), A_t=↓, S_t+1=(2, 2), R_t=0
t=7, S_t=(2, 2), A_t=↓, S_t+1=(2, 2), R_t=0
t=8, S_t=(2, 2), A_t=↓, S_t+1=(2, 2), R_t=0
t=9, S_t=(2, 2), A_t=↑, S_t+1=(1, 2), R_t=0
t=10, S_t=(1, 2), A_t=↓, S_t+1=(2, 2), R_t=0
t=11, S_t=(2, 2), A_t=→, S_t+1=(2, 3), R_t=0
t=12, S_t=(2, 3), A_t=→, S_t+1=(2, 3), R_t=0
t=13, S_t=(2, 3), A_t=→, S_t+1=(2, 3), R_t=0
t=14, S_t=(2, 3), A_t=↑, S_t+1=(1, 3), R_t=-1.0
t=15, S_t=(1, 3), A_t=←, S_t+1=(1, 2), R_t=0
t=16, S_t=(1, 2), A_t=↑, S_t+1=(0, 2), R_t=0
t=17, S_t=(0, 2), A_t=→, S_t+1=(0, 3), R_t=1.0
agent
のget_action()
で挙動方策に従い行動して、env
のstep()
で状態を遷移し報酬を出力します。
得られたサンプルデータ(1つ前の状態・行動・報酬と現在の状態・行動)を使って、agent
のupdate()
で1つ前の状態と行動の行動価値関数と方策を計算します。ただし、2時刻分のサンプルデータが必要なので、初回は更新されません。
ゴールマスに着くとdone
がTrue
に設定されるので、ダミーの次のサンプルデータを用意して最後の時刻の更新を行い、break
でループ処理を終了します。
行動価値関数をヒートマップで確認します。
# 行動価値関数のヒートマップと方策ラベルを作図
env.render_q(q=agent.Q)
render_q()
内部のnp.argmax()
の仕様で、行動価値が等しいとインデックスが最小の行動がラベルで表示されます。
以上で、方策オフ型のSARSAのエージェントを実装できました。
・方策オフ型のSARSAによる方策制御
最後に、方策オフ型のSARSAにより行動価値関数を推定して、更新の推移を確認します。
推定
方策オフ型のSARSAにより行動価値関数と方策を繰り返し更新します。
# 環境・エージェントのインスタンスを作成 env = GridWorld() agent = SarsaOffPolicyAgent() # エピソード数を指定 episodes = 1000 # 推移の可視化用のリストを初期化 trace_Q = [{(state, action): agent.Q[(state, action)] for state in env.states() for action in env.action_space}] # 初期値を記録 # 繰り返しシミュレーション for episode in range(episodes): # 状態を初期化 state = env.reset() # サンプルデータを初期化 agent.reset() # 時刻(試行回数)を初期化 t = 0 # 1エピソードのシミュレーション while True: # 時刻をカウント t += 1 # 挙動方策(ε-greedy法)により行動を決定 action = agent.get_action(state) # サンプルデータを取得 next_state, reward, done = env.step(action) # 前の状態・行動の行動価値関数・方策を更新:式(6.13,11) agent.update(state, action, reward, done) # ゴールに着いた場合 if done: # 現在の状態・行動の行動価値関数・方策を更新:式(6.13,11) agent.update(next_state, None, None, None) # 更新値を記録 trace_Q.append(agent.Q.copy()) # 総時刻を表示 print('episode '+str(episode+1) + ': T='+str(t)) # エピソードを終了 break # 状態を更新 state = next_state
episode 1: T=41
episode 2: T=57
episode 3: T=59
episode 4: T=22
episode 5: T=68
(省略)
episode 996: T=5
episode 997: T=5
episode 998: T=5
episode 999: T=7
episode 1000: T=7
スタートマスからε-greedy法により行動し、ゴールマスに着くまでを1エピソードとします。エピソードごとに、GridWorld
クラスのreset()
メソッドで状態を初期化し(エージェントをスタートマスに戻し)、SarsaAgent
クラスのreset()
メソッドでサンプルデータを初期化(過去のデータを削除)します。
episodes
に指定した回数のシミュレーションを行い、時刻ごとに繰り返しagent
のupdate()
で現在の状態と行動の行動価値関数と方策を更新します。ただし、更新には2時刻分のサンプルデータを利用するので、初回は更新されず、最後の時刻はダミーデータを使って更新します。
推移の確認用に、行動価値関数の更新値をtrace_Q
に格納していきます。
推定した行動価値関数をヒートマップと方策ラベルで確認します。
# 行動価値関数のヒートマップを作図
env.render_q(q=agent.Q)
結果の解釈については本を参照してください。
更新推移の可視化
ここまでで、繰り返しの更新処理を確認しました。続いて、途中経過をアニメーションで確認します。
行動価値関数のヒートマップのアニメーションを作成します。
・作図コード(クリックで展開)
# グリッドマップのサイズを取得 xs = env.width ys = env.height # 状態価値の最大値・最小値を取得 qmax = max([max(trace_Q[i].values()) for i in range(len(trace_Q))]) qmin = min([min(trace_Q[i].values()) for i in range(len(trace_Q))]) # 色付け用に最大値・最小値を再設定 qmax = max(qmax, abs(qmin)) qmin = -1 * qmax qmax = 1 if qmax < 1 else qmax qmin = -1 if qmin > -1 else qmin # カラーマップを設定 color_list = ['red', 'white', 'green'] cmap = LinearSegmentedColormap.from_list('colormap_name', color_list) # 図を初期化 fig = plt.figure(figsize=(12, 9), facecolor='white') # 図の設定 plt.suptitle('Off-policy SARSA', fontsize=20) # 全体のタイトル # 作図処理を関数として定義 def update(i): # 前フレームのグラフを初期化 plt.cla() # i回目の更新値を取得 Q = trace_Q[i] # マス(状態)ごとに処理 for state in env.states(): # 行動ごとに処理 for action in env.action_space: # インデックスを取得 y, x = state # 報酬を抽出 r = env.reward_map[y, x] # 報酬がある場合 if r != 0 and r is not None: # 報酬ラベル用の文字列を作成 txt = 'R ' + str(r) # ゴールの場合 if state == env.goal_state: # 報酬ラベルにゴールを追加 txt = txt + ' (GOAL)' # 報酬ラベルを描画 plt.text(x=x+0.1, y=ys-y-0.9, s=txt, ha='left', va='bottom', fontsize=15) # ゴールの場合 if state == env.goal_state: # 描画せず次の状態へ continue # 作図用のx軸・y軸の値を設定 tx, ty = x, ys-y-1 # 行動ごとの三角形の頂点を設定 action_map = { 0: ((0.5+tx, 0.5+ty), (1.0+tx, 1.0+ty), (tx, 1.0+ty)), # 上 1: ((tx, ty), (1.0+tx, ty), (0.5+tx, 0.5+ty)), # 下 2: ((tx, ty), (0.5+tx, 0.5+ty), (tx, 1.0+ty)), # 左 3: ((0.5+tx, 0.5+ty), (1.0+tx, ty), (1.0+tx, 1.0+ty)) # 右 } # 行動ごとの価値ラベルのプロット位置を設定 offset_map = { 0: (0.5, 0.75), # 上 1: (0.5, 0.25), # 下 2: (0.25, 0.5), # 左 3: (0.75, 0.5) # 右 } # 壁の場合 if state == env.wall_state: # 壁を描画 rect = plt.Rectangle(xy=(tx, ty), width=1, height=1, fc=(0.4, 0.4, 0.4, 1.0)) # 長方形を作成 plt.gca().add_patch(rect) # 重ねて描画 # (よく分からない) elif state in env.goal_state: plt.gca().add_patch(plt.Rectangle(xy=(tx, ty), width=1, height=1, fc=(0.0, 1.0, 0.0, 1.0))) # 壁以外の場合 else: # 行動価値を抽出 tq = Q[(state, action)] # 行動価値を0から1に正規化 color_scale = 0.5 + (tq / qmax) / 2 # 三角形を描画 poly = plt.Polygon(action_map[action],fc=cmap(color_scale)) # 三角形を作成 plt.gca().add_patch(poly) # 重ねて描画 # プロット位置の調整値を取得 offset = offset_map[action] # 行動価値ラベルを描画 plt.text(x=tx+offset[0], y=ty+offset[1], s=str(np.round(tq, 3)), ha='center', va='center', size=15) # 行動価値ラベル # グラフの設定 plt.xticks(ticks=np.arange(xs)) # x軸の目盛位置 plt.yticks(ticks=np.arange(ys), labels=ys-np.arange(ys)-1) # y軸の目盛位置 plt.xlim(xmin=0, xmax=xs) # x軸の範囲 plt.ylim(ymin=0, ymax=ys) # y軸の範囲 plt.tick_params(labelbottom=False, labelleft=False, labelright=False, labeltop=False) # 軸ラベル plt.grid() # グリッド線 plt.title('episode:'+str(i), loc='left') # タイトル # gif画像を作成 anime = FuncAnimation(fig=fig, func=update, frames=len(trace_Q), interval=50) # gif画像を保存 anime.save('ch6_3.gif')
各エピソードで更新した行動価値をtrace_Q
から取り出してヒートマップを描画する処理を関数update()
として定義して、FuncAnimation()
でアニメーション(gif画像)を作成します。
行動価値関数の更新値の推移を折れ線グラフで確認します。
・作図コード(クリックで展開)
# 行動ラベルを設定 arrows = ['↑', '↓', '←', '→'] # 状態価値関数の推移を作図 plt.figure(figsize=(15, 10), facecolor='white') for state in env.states(): for action in range(agent.action_size): # 更新値を抽出 q_vals = [trace_Q[i][(state, action)] for i in range(episodes+1)] # 各状態の価値の推移を描画 plt.plot(np.arange(episodes+1), q_vals, alpha=0.5, label='$Q_i(L_{'+str(state[0])+','+str(state[1])+'},'+arrows[action]+')$') plt.xlabel('episode') plt.ylabel('action-value') plt.suptitle('Off-policy SARSA', fontsize=20) plt.title('$\gamma='+str(agent.gamma) + ', \\alpha='+str(agent.alpha)+'$', loc='left') plt.grid() plt.legend(loc='upper left', bbox_to_anchor=(1, 1), ncol=2) plt.show()
行番号を$h$、列番号を$w$として各マスを$L_{h,w}$で表します(図4-9)。また、$i$回目の行動価値を$Q_i(L_{h,w}, A)$で表します。
各曲線の縦軸の値が、ヒートマップの色に対応します。
この節では、方策オフ型のSARSAにより方策制御を行うエージェントを実装して、最適方策を求めました。次節では、Q学習を実装して、最適方策を求めます。
参考文献
おわりに
色々な内容が込み入ってきましたが、なんのためにどの手法を使っているのか整理できていますでしょうか?私はなんとか理解できているつもりです。次の手法がメインテーマのようですね。
【次節の内容】