はじめに
『ゼロから作るDeep Learning 4 ――強化学習編』の独学時のまとめノートです。初学者の補助となるようにゼロつくシリーズの4巻の内容に解説を加えていきます。本と一緒に読んでください。
この記事は、4.2.1節の内容です。3×4マスのグリッドワールドのクラスについて確認します。
【前節の内容】
【他の記事一覧】
【この記事の内容】
4.2.1 GridWorldクラスの実装
3×4マスのグリッドワールドのクラスGridWorld
の内部で行われる処理と使い方を確認します。
利用するライブラリを読み込みます。
# 利用するライブラリ import numpy as np import matplotlib.pyplot as plt import matplotlib
実装済みのGridWorld
クラスは、次のようにして読み込めます。
# 読み込み用のライブラリ import sys # フォルダパスを指定 sys.path.append('../deep-learning-from-scratch-4-master') # 実装済みクラスを読み込み from common.gridworld import GridWorld
実装済みクラスの読み込みについては「3.6.1:MNISTデータセットの読み込み【ゼロつく1のノート(Python)】 - からっぽのしょこ」を参照してください。
GridWorld
クラスのインスタンスを作成しておきます。
# インスタンスを作成
env = GridWorld()
・initメソッドとインスタンス変数
インスタンスの作成時に実行される「初期化メソッド」の処理と、「インスタンス変数」の処理を確認します。
行動番号を格納したリストと、行動番号がキーで行動内容が値のディクショナリを作成します。
# 行動番号を指定 action_space = [0, 1, 2, 3] # 行動番号と行動内容の対応ディクショナリを指定 action_meaning = {0: 'UP', 1: 'DOWN', 2: 'LEFT', 3: 'RIGHT'}
縦横方向に広がるグリッドワールドなので、エージェントは、上下左右の4つの行動を取ります。
次のようにして、行動を順番に取り出せます。
# 行動を抽出 for action in action_space: print(action, action_meaning[action])
0 UP
1 DOWN
2 LEFT
3 RIGHT
ランダムに行動を取るような場面でも同様です。
# ランダムに行動を生成 actions = np.random.choice(action_space, size=10) # 行動を抽出 for action in actions: print(action, action_meaning[action])
1 DOWN
1 DOWN
2 LEFT
2 LEFT
1 DOWN
2 LEFT
0 UP
1 DOWN
2 LEFT
3 RIGHT
(私の進捗状況ではまだ登場していないので、今後どういう風に使うのか分かってません。)
グリッドワールドの設定用のNumPy配列とタプルを作成します。
# 各状態(マス)の報酬を指定 reward_map = np.array( [[0, 0, 0, 1.0], [0, None, 0, -1.0], [0, 0, 0, 0]] ) # スタートの位置を指定 start_state = (2, 0) # ゴールの位置を指定 goal_state = (0, 3) # 壁の位置を指定 wall_state = (1, 1) # 現在のエージェントの状態を初期化 agent_state = start_state
各マス(状態)の報酬をreward_map
として、配列に指定します。
スタートの位置をstart_state
、ゴールの位置をgoal_state
、壁の位置をwall_state
として、タプルに指定します。壁の位置と報酬がNone
の位置(インデックス)が対応します。
エージェントの位置(現在の状態)をagent_state
として、初期値をスタートの位置にします。
グリッドワールドのサイズは、次のようにして得られます。
# 縦方向のマスの数を取得 height = len(reward_map) print(height) # 横方向のマスの数を取得 width = len(reward_map[0]) print(width) # グリッドワールドの形状を取得 shape = reward_map.shape print(shape)
3
4
(3, 4)
reward_map
の0番目の軸の要素数が縦方向のマスの数、1番目の軸の要素数が横方向のマスの数です。
以上が、初期化メソッドとインスタンス変数で行われる処理です。続いて、実装済みのクラスを試してみましょう。
各状態に関する値は、インスタンス変数として保存されます。
# 各状態(マス)の報酬を出力 print(env.reward_map) # スタートの位置を出力 print(env.start_state) # ゴールの位置を出力 print(env.goal_state) # 壁の位置を出力 print(env.wall_state) # 現在の状態(エージェントの位置)を出力 print(env.agent_state)
[[0 0 0 1.0]
[0 None 0 -1.0]
[0 0 0 0]]
(2, 0)
(0, 3)
(1, 1)
(2, 0)
グリッドワールドの形状に関する値も、インスタンスとして出力できます。
# 縦方向のマスの数を出力 print(env.height) # 横方向のマスの数を出力 print(env.width) # グリッドワールドの形状を出力 print(env.shape)
3
4
(3, 4)
以上が、初期化メソッドとインスタンス変数の処理です。次は、全ての行動メソッドを確認します。
・actionsメソッド
エージェントの全ての行動を出力する「全ての行動メソッド」の処理を確認します。
行動番号は、インスタンス変数action_space
に保存されています。
# 行動番号を出力 print(env.action_space)
[0, 1, 2, 3]
これをメソッドにしたのがactions()
です。
# 行動番号を出力 print(env.actions())
[0, 1, 2, 3]
以上が、全ての行動メソッドの処理です。次は、状態のリセットメソッドを確認します。
・resetメソッド
現在の状態を初期化して出力する「状態のリセットメソッド」の処理を確認します。
現在の状態(エージェントの位置)は、インスタンス変数agent_state
に保存されていて、書き換えられます(遷移します)。
# 状態を指定 state = (0, 0) # 現在の状態を変更 env.agent_state = state print(env.agent_state)
(0, 0)
状態の遷移についてはstepメソッドで確認します。
agent_state
にスタート位置start_state
を代入することで、現在の状態を初期化できます。
# 現在の状態を初期化 env.agent_state = env.start_state print(env.agent_state)
(2, 0)
エージェントの位置をスタートの位置に戻すことを意味します。
これをメソッドにしたのがreset()
です。
# 現在の状態を変更 env.agent_state = (0, 0) print(env.agent_state) # 初期状態に戻して出力 state = env.reset() print(state) print(env.agent_state)
(0, 0)
(2, 0)
(2, 0)
現在の状態を初期化して、初期状態を出力します。
以上が、状態のリセットメソッドの処理です。次は、全ての状態メソッドを確認します。
・statesメソッド
グリッドワールドの全ての状態(全てのマスのインデックス)を順番に出力する「全ての状態メソッド」の処理を確認します。
states()
では、yield
を使って繰り返し値を出力します。まずは、0
からn-1
の整数を出力する関数を作成して、return
とyeild
の違いを確認します。
return
を使って、次のように定義してみます。
# returnにより出力する関数 def f_return(n=3): for i in range(n): return i
実装した関数を実行します。
# 実行 print(f_return())
0
最初の値しか出力されません。これは、return
が実行されると関数内部の処理が終了するためです。
続いて、yield
を使って定義してみます。
# yieldにより出力する関数 def f_yield(n=3): for i in range(n): yield i
そのまま実行すると、generator
オブジェクトが出力されます。
# そのまま実行 print(f_yield())
<generator object f_yield at 0x0000027062A4EAC0>
generator
オブジェクトの説明は省略しますが、次のようにして使えます。
# for文を使って実行 for i in f_yield(): print(i)
0
1
2
0
からn-1
の整数を出力できました。for
文のような繰り返し処理の中で、順番に値が出力されます。
続いて、メソッド内部の処理を確認します。
縦と横のサイズを使って、全ての状態を作成します。
# 全ての状態を作成 for h in range(env.height): for w in range(env.width): state = (h, w) print(state)
(0, 0)
(0, 1)
(0, 2)
(0, 3)
(1, 0)
(1, 1)
(1, 2)
(1, 3)
(2, 0)
(2, 1)
(2, 2)
(2, 3)
縦方向のマス番号(y軸の値)をh
、横方向のマス番号(x軸の値)をw
として、各マス番号(2次元配列のインデックス)をタプルに格納して出力します。
以上が、全ての状態メソッドで行われる処理です。続いて、実装済みのクラスを試してみましょう。
states()
メソッドで全ての状態を出力します。
# 各状態を順番に出力 for state in env.states(): print(state)
(0, 0)
(0, 1)
(0, 2)
(0, 3)
(1, 0)
(1, 1)
(1, 2)
(1, 3)
(2, 0)
(2, 1)
(2, 2)
(2, 3)
例えば、出力した状態をインデックスとして使って、対応する報酬を取り出せます。
# 各状態の報酬を順番に出力 for state in env.states(): print('state', state, ': reward', env.reward_map[state])
state (0, 0) : reward 0
state (0, 1) : reward 0
state (0, 2) : reward 0
state (0, 3) : reward 1.0
state (1, 0) : reward 0
state (1, 1) : reward None
state (1, 2) : reward 0
state (1, 3) : reward -1.0
state (2, 0) : reward 0
state (2, 1) : reward 0
state (2, 2) : reward 0
state (2, 3) : reward 0
以上が、全ての状態メソッドの処理です。次は、次の状態メソッドを確認します。
・next_stateメソッド
エージェントの行動により遷移した次の状態を出力する「次の状態メソッド」の処理を確認します。
まずは、4方向への移動に対応した値を作成します。
# 行動に対応した移動量を設定 action_move_map = [(-1, 0), (1, 0), (0, -1), (0, 1)] print(action_move_map)
[(-1, 0), (1, 0), (0, -1), (0, 1)]
現在の位置から左のマスに移動するには、縦軸方向には変化せず(0
変化し)、横軸方向に-1
変化します。この変化を(0, -1)
で表します。右に移動する場合は(0, 1)
です。
ただし、上下の移動については直感的でない変化になります。これは、図4-9の座標系や、報酬の配列reward_map
のインデックスに対応するためです。
# 報酬マップを確認 print(env.reward_map)
[[0 0 0 1.0]
[0 None 0 -1.0]
[0 0 0 0]]
上のマスに移動するには、縦軸方向の値が-1
変化します。下に移動する場合は+1
変化します。
これら4つの変化量をリストに格納します。
リストに格納する順番は、行動内容action_meaning
に対応します。
# 行動番号と行動内容の対応を確認 print(env.action_meaning)
{0: 'UP', 1: 'DOWN', 2: 'LEFT', 3: 'RIGHT'}
上下左右の順です。
準備ができたので、行動を指定して、次の状態を計算します。
# エージェントの位置(現在の状態)を指定 state = (2, 0) # 行動を指定 #action = 0 # 上に行動(移動できる) #action = 1 # 下に行動(移動できない) action = 2 # 左に行動(移動できない) #action = 3 # 右に行動(移動できる) # 行動に対応した変化量を抽出 move = action_move_map[action] print(move) # 行動後の位置(次の状態)の候補を計算 next_state = (state[0] + move[0], state[1] + move[1]) print(next_state)
(0, -1)
(2, -1)
行動に対応した変化量をaction_move_map
から取り出して、現在の状態に足します。
ただし、壁が存在するため、行動の通りに移動するとは限りません。よって、この時点のnext_state
は、次の状態の候補と言えます。
そこで、next_state
が、グリッドワールド内のマスであるか、壁のマスでないかを判定します。
# 次の状態のx軸・y軸の値を抽出 ny, nx = next_state print(ny) print(nx) # グリッドワールドの外に移動する場合 if nx < 0 or nx >= env.width or ny < 0 or ny >= env.height: # 元の位置のまま next_state = state # 壁のマスに移動する場合 elif next_state == env.wall_state: # 元の位置のまま next_state = state print(next_state)
2
-1
(2, 0)
次の状態の候補next_state
からy軸の値ny
とx軸の値nx
を取り出して、0
より小さい、または横幅width
・高さheight
以上であれば壁の外なので、元の状態state
に変更します。また、壁のマスwall_state
でも元の状態にします。
以上が、次の状態メソッドで行われる処理です。続いて、実装済みのクラスを試してみましょう。
next_state()
に現在の状態と行動を指定して、次の状態を出力します。
# 現在の状態を初期化 state = env.reset() print(state) # 行動を指定 #action = 0 # 上に行動(移動できる) #action = 1 # 下に行動(移動できない) action = 2 # 左に行動(移動できない) #action = 3 # 右に行動(移動できる) # 次の状態を出力 next_state = env.next_state(state, action) print(next_state)
(2, 0)
(2, 0)
以上が、次の状態メソッドの処理です。次は、報酬メソッドを確認します。
・rewardメソッド
指定した状態に対応する報酬を出力する「報酬メソッド」の処理を確認します。
各状態の報酬(NumPy配列)reward_map
に状態(インデックス)を指定すると、対応する報酬を抽出できます。
# 状態を指定 #state = (0, 0) # 通常マス state = (0, 3) # リンゴの位置 #state = (1, 3) # 爆弾の位置 # 報酬を出力 reward = env.reward_map[state] print(reward)
1.0
これをメソッドにしたのがreward()
です。
# ダミーの状態を作成 state = '_' # ダミーの行動を作成 action = '_' # 状態を指定 next_state = (0, 3) # 報酬を取得 reward = env.reward(state, action, next_state) print(reward)
1.0
報酬メソッドは、報酬関数(数式での表記)$r(s, a, s')$に対応させるために、現在の状態state
・行動action
・次の状態next_state
の3つの引数を持ちますが、処理に利用するのはnext_state
のみです。
以上が、報酬メソッドの処理です。次は、ステップメソッドを確認します。
・stepメソッド
エージェントの行動により、報酬を受け取り状態を遷移する「ステップメソッド」の処理を確認します。
行動を指定して、状態を遷移(エージェントを移動)し、報酬と次の状態を出力します。
# 現在の状態を指定 state = (0, 2) # 現在の状態を設定 env.agent_state = state print(env.agent_state) # 行動を指定 action = 3 # 次の状態を出力 next_state = env.next_state(state, action) print(next_state) # 報酬を抽出 reward = env.reward_map[next_state] print(reward) # 現在の状態を更新 env.agent_state = next_state print(env.agent_state)
(0, 2)
(0, 3)
1.0
(0, 3)
現在の状態と行動をnext_state()
に指定して、次の状態next_state
を出力します。
next_state
をインデックスとして使って、次の状態の報酬をreward_map
から抽出します。
現在の状態agent_state
の値を、次の状態の値に変更します。
この例の問題設定では、ゴールに辿り着くとエピソードが終了します。そのため、次の状態がゴールの位置なのかを判定します。
# ゴールしたかを判定 done = (next_state == env.goal_state) print(done)
True
ゴールであればTrue
、ゴールでなければFalse
になります。
以上が、ステップメソッドで行われる処理です。続いて、実装済みのクラスを試してみましょう。
step()
に行動を指定して、次の状態・報酬・ゴールかの判定結果を出力します。
# 現在の状態を設定 env.agent_state = (0, 2) print(env.agent_state) # 行動を指定 action = 3 # 1ステップの結果を出力 next_step, reward, done = env.step(action) print(next_step) print(reward) print(done) print(env.agent_state)
(0, 2)
(0, 3)
1.0
True
(0, 3)
報酬を受け取り、状態が遷移(エージェントが移動)しました。
続いて、複数ステップを処理してみます。
# 状態を初期化 state = env.reset() print('start state', state) # 行動を指定 actions = [3, 3, 0, 2, 3, 3, 0] # ステップごとに処理 for action in actions: # 報酬を得て状態を遷移 next_step, reward, done = env.step(action) print( 'action', env.action_meaning[action], ':', 'state', next_step, ',', 'reward', reward, ',', 'goal', done )
start state (2, 0)
action RIGHT : state (2, 1) , reward 0 , goal False
action RIGHT : state (2, 2) , reward 0 , goal False
action UP : state (1, 2) , reward 0 , goal False
action LEFT : state (1, 2) , reward 0 , goal False
action RIGHT : state (1, 3) , reward -1.0 , goal False
action RIGHT : state (1, 3) , reward -1.0 , goal False
action UP : state (0, 3) , reward 1.0 , goal True
以上が、ステップメソッドの処理です。次は、状態価値関数の可視化メソッドを確認します。
・render_vメソッド
状態価値関数のヒートマップを作成する「状態価値関数の可視化メソッド」の処理を確認します。ただし、作図コードがとても複雑なので、ここでは基本的な処理のみ確認します。
アルゴリズム自体とは直接関係しないので、飛ばしていいと思います。詳しくは「common」フォルダの「gridworld_render.py」ファイルを参照してください(見ない方がいいよ)。
グリッドワールドのマスの作図、ヒートマップによる状態価値関数の可視化、矢印による方策の可視化の3つの段階に分けて解説します。
・グリッドワールド
まずは、土台となるグリッドワールドを作成します。
実装済みのメソッドを使って完成形を確認しましょう。
# グリッドワールドを作図
env.render_v()
グリッド線を引いてマスを描画し、報酬のラベルを付けて、壁のマスを黒塗りで表現します。
グリッドワールドの縦・横のサイズ(マスの数)を取得します。
# 縦軸のサイズを取得 ys = len(env.reward_map) print(ys) # 横軸のサイズを取得 xs = len(env.reward_map[0]) print(xs)
3
4
3×4のマスを作図します。
## グリッドワールドのマスを作成 # マスを描画 plt.figure(figsize=(9, 6)) # 図の設定 plt.xticks(ticks=np.arange(xs)) # x軸の目盛位置 plt.yticks(ticks=np.arange(ys)) # y軸の目盛位置 plt.xlim(xmin=0, xmax=xs) # x軸の範囲 plt.ylim(ymin=0, ymax=ys) # y軸の範囲 plt.tick_params(labelbottom=False, labelleft=False, labelright=False, labeltop=False) # 軸ラベル plt.grid() # グリッド線 plt.title('base', fontsize=20) # タイトル plt.show()
グリッド線を引く位置をplt.xticks()
とplt.yticks()
に指定します。また、描画範囲をplt.xlim()
とplt.ylim()
に指定します。グリッド線は、plt.grid()
で表示されます。
plt.tick_params()
の各引数をFalse
にして、軸ラベルを非表示にします。
マスの装飾をするために、各マスのy軸・x軸の値を確認しておきます。
## マスの座標を確認 # マスを描画 plt.figure(figsize=(9, 6)) # 図の設定 plt.xticks(ticks=np.arange(xs)) # x軸の目盛位置 plt.yticks(ticks=np.arange(ys)) # y軸の目盛位置 plt.xlim(xmin=0, xmax=xs) # x軸の範囲 plt.ylim(ymin=0, ymax=ys) # y軸の範囲 #plt.tick_params(labelbottom=False, labelleft=False, labelright=False, labeltop=False) # 軸ラベル plt.grid() # グリッド線 # マスごとに処理 for y in range(ys): for x in range(xs): # 装飾するマス(状態)を設定 state = (y, x) # 座標ラベルをそのまま描画 plt.text(x=x+0.5, y=y+0.5, s=str(state), fontsize=20, ha='center', va='center') plt.xlabel('x (width)') # x軸ラベル plt.ylabel('y (height)') # y軸ラベル plt.title('coordinate (y, x)', fontsize=20) # タイトル plt.show()
plt.text()
で図の中にテキストを描画できます。x, y
引数にプロット位置、s
引数に描画する文字列を指定します。
マスの中心に描画するために、x, y
引数にはそれぞれ0.5
を加えた値を指定ます。
ここで注意が必要なのが、通常のグラフでは、横軸は右に行くほど値が大きく、縦軸は上に行くほど値が大きくなります。しかし、図4-9のように下に行くほど値が大きくなるようにしたいです。
そこで、y軸に関しては、縦のサイズys
から各マスの値y+1
を引いた値(あるいはys-1
からy
引いた値)を使います。
## マスのインデックスを描画 # マスを描画 plt.figure(figsize=(9, 6)) # 図の設定 plt.xticks(ticks=np.arange(xs)) # x軸の目盛位置 plt.yticks(ticks=np.arange(ys), labels=ys-np.arange(ys)-1) # y軸の目盛位置 plt.xlim(xmin=0, xmax=xs) # x軸の範囲 plt.ylim(ymin=0, ymax=ys) # y軸の範囲 #plt.tick_params(labelbottom=False, labelleft=False, labelright=False, labeltop=False) # 軸ラベル plt.grid() # グリッド線 # マスごとに処理 for y in range(ys): for x in range(xs): # 装飾するマス(状態) state = (y, x) # 座標ラベルをマスに対応させて描画 plt.text(x=x+0.5, y=ys-y-0.5, s=str(state), fontsize=20, ha='center', va='center') plt.xlabel('x (width)') # x軸ラベル plt.ylabel('y (height)') # y軸ラベル plt.title('index (y, x)', fontsize=20) # タイトル plt.show()
マスの中心は、ys-y-1+0.5なので、ys-y-0.5になります。
各マスのインデックスと座標の関係を確認できました。
では、マスごとに報酬の値を描画します。
## 報酬を描画 # マスを描画 plt.figure(figsize=(9, 6)) # 図の設定 plt.xticks(ticks=np.arange(xs)) # x軸の目盛位置 plt.yticks(ticks=np.arange(ys)) # y軸の目盛位置 plt.xlim(xmin=0, xmax=xs) # x軸の範囲 plt.ylim(ymin=0, ymax=ys) # y軸の範囲 plt.tick_params(labelbottom=False, labelleft=False, labelright=False, labeltop=False) # 軸ラベル plt.grid() # グリッド線 # マスごとに処理 for y in range(ys): for x in range(xs): # 報酬を抽出 r = env.reward_map[y, x] # 報酬ラベルを描画 plt.text(x=x+0.5, y=ys-y-0.5, s=str(r), fontsize=20, ha='center', va='center') plt.title('reward', fontsize=20) # タイトル plt.show()
reward_map
から各マス(状態)の報酬を取り出して描画します。
reward_map
に設定した報酬と一致しています。
# 報酬を確認 print(env.reward_map)
[[0 0 0 1.0]
[0 None 0 -1.0]
[0 0 0 0]]
報酬の無い0, None
は省略して、ゴールの位置も描画します。
## 報酬とゴールを重ねて作図 # マスを描画 plt.figure(figsize=(9, 6)) # 図の設定 plt.xticks(ticks=np.arange(xs)) # x軸の目盛位置 plt.yticks(ticks=np.arange(ys)) # y軸の目盛位置 plt.xlim(xmin=0, xmax=xs) # x軸の範囲 plt.ylim(ymin=0, ymax=ys) # y軸の範囲 plt.tick_params(labelbottom=False, labelleft=False, labelright=False, labeltop=False) # 軸ラベル plt.grid() # グリッド線 # マスごとに処理 for y in range(ys): for x in range(xs): # 装飾するマス(状態)を設定 state = (y, x) # 報酬を抽出 r = env.reward_map[y, x] # 報酬がある場合 if r != 0 and r is not None: # 報酬ラベル用の文字列を作成 txt = 'R ' + str(r) # ゴールの場合 if state == env.goal_state: # 報酬ラベルにゴールを追加 txt = txt + ' (GOAL)' # 報酬ラベルを描画 plt.text(x=x+0.1, y=ys-y-0.9, s=txt) plt.title('reward and goal', fontsize=20) # タイトル plt.show()
報酬r
が0
でもNone
でもない場合に、報酬ラベルを描画します。
また、ゴールのマスの場合には、ゴールを示す文字列を追加します。+
演算子で、文字列を結合できます。
表示位置を中心から調整しました。
続いて、壁のマスを黒く塗りつぶします。
## 壁を重ねて作図 # マスを描画 plt.figure(figsize=(9, 6)) # 図の設定 plt.xticks(ticks=np.arange(xs)) # x軸の目盛位置 plt.yticks(ticks=np.arange(ys)) # y軸の目盛位置 plt.xlim(xmin=0, xmax=xs) # x軸の範囲 plt.ylim(ymin=0, ymax=ys) # y軸の範囲 plt.tick_params(labelbottom=False, labelleft=False, labelright=False, labeltop=False) # 軸ラベル plt.grid() # グリッド線 # マスごとに処理 for y in range(ys): for x in range(xs): # 装飾するマス(状態)を設定 state = (y, x) # 報酬を抽出 r = env.reward_map[y, x] # 報酬がある場合 if r != 0 and r is not None: # 報酬ラベル用の文字列を作成 txt = 'R ' + str(r) # ゴールの場合 if state == env.goal_state: # 報酬ラベルにゴールを追加 txt = txt + ' (GOAL)' # 報酬ラベルを描画 plt.text(x=x+0.1, y=ys-y-0.9, s=txt) # 壁のマスの場合 if state == env.wall_state: # 壁を描画 plt.gca().add_patch(plt.Rectangle(xy=(x, ys-y-1), width=1, height=1, fc=(0.4, 0.4, 0.4, 1.0))) # 長方形を重ねる plt.title('wall', fontsize=20) # タイトル plt.show()
plt.Rectangle()
で長方形を描画できます。xy
引数に長方形の左下の頂点の位置、width
引数に横幅、height
引数に高さ、fc
引数に色を指定します。この例では、濃いグレーになるように値を指定しています。
作成した長方形を、ax.add_patch()
でグリッドワールドに重ねて描画します。
以上で、基本となるグリッドワールドを作図できました。
・状態価値のヒートマップ
次に、状態価値関数のヒートマップを作成します。
例として、ダミーの状態価値関数のディクショナリを作成しておきます。
# 仮の状態価値のディクショナリを作成 V = {state: np.random.randn() for state in env.states()} print(list(V.keys())) print(np.round(list(V.values()), 2))
[(0, 0), (0, 1), (0, 2), (0, 3), (1, 0), (1, 1), (1, 2), (1, 3), (2, 0), (2, 1), (2, 2), (2, 3)]
[-0.75 -0.48 0.07 -0.48 -0.77 -0.35 0.32 0.93 -0.79 1.69 -0.58 0.14]
状態(マスのインデックス)をキー、状態価値を値とします。(リスト内包表記を使っていますが、処理の内容は101ページと同じです。)
完成形を確認しましょう。
# 状態価値関数のヒートマップを作図
env.render_v(v=V)
状態価値関数の値が、負の値なら赤色、0なら白色、正の値なら緑色になります。また、値が小さいほど濃い赤色、大きいほど濃い緑色になります。
さらに、状態価値のラベルを表示します。
状態価値関数のディクショナリをNumPy配列に変換します。
# 状態価値の受け皿(配列)を作成 v = np.zeros(env.reward_map.shape) # 要素ごとに処理 for state, value in V.items(): # 状態をインデックスとして値を格納 v[state] = value print(np.round(v, 2))
[[-0.75 -0.48 0.07 -0.48]
[-0.77 -0.35 0.32 0.93]
[-0.79 1.69 -0.58 0.14]]
状態(マスの位置)を示すタプル型のキーを、配列のインデックスとして利用して、対応する要素に値を格納します。
状態価値の最小値と最大値を作成します。また、カラーマップを作成します。
# 最小値・最大値を取得 vmin = v.min() vmax = v.max() print(vmin) print(vmax) # カラーマップを設定 color_list = ['red', 'white', 'green'] cmap = matplotlib.colors.LinearSegmentedColormap.from_list('colormap_name', color_list)
-0.7949579351253576
1.6897214538929588
最小値と最大値は、色の濃淡を付けるのに利用します。
カラーマップについては省略します。
ヒートマップを作成して、状態価値のラベルを付けます。
## 状態価値関数のヒートマップを作図 # マスを描画 plt.figure(figsize=(9, 6)) # 図の設定 plt.xticks(ticks=np.arange(xs)) # x軸の目盛位置 plt.yticks(ticks=np.arange(ys)) # y軸の目盛位置 plt.xlim(xmin=0, xmax=xs) # x軸の範囲 plt.ylim(ymin=0, ymax=ys) # y軸の範囲 plt.tick_params(labelbottom=False, labelleft=False, labelright=False, labeltop=False) # 軸ラベル plt.grid() # グリッド線 # 状態価値のヒートマップを描画 plt.pcolormesh(np.flipud(v), cmap=cmap, vmin=vmin, vmax=vmax) # ヒートマップマップ # マスごとに処理 for y in range(ys): for x in range(xs): # 装飾するマス(状態)を設定 state = (y, x) # 報酬を抽出 r = env.reward_map[y, x] # 報酬がある場合 if r != 0 and r is not None: # 報酬ラベル用の文字列を作成 txt = 'R ' + str(r) # ゴールの場合 if state == env.goal_state: # 報酬ラベルにゴールを追加 txt = txt + ' (GOAL)' # 報酬ラベルを描画 plt.text(x=x+0.1, y=ys-y-0.9, s=txt) # 壁以外の場合 if state != env.wall_state: # 状態価値ラベルを描画 plt.text(x=x+0.6, y=ys-y-0.15, s=str(np.round(v[y, x], 2))) # 壁の場合 if state == env.wall_state: # 壁を描画 plt.gca().add_patch(plt.Rectangle(xy=(x, ys-y-1), width=1, height=1, fc=(0.4, 0.4, 0.4, 1.0))) # 長方形を重ねる plt.title('heatmap', fontsize=20) # タイトル plt.show()
pcolormesh()
でヒートマップを描画できます。第1引数に値、cmap
にカラーマップ、vmin, vmax
引数に最小値・最大値を指定します。
最小値・最大値によって各マス(値)の濃淡が決まります。
ただし、配列の位置(インデックス)とグラフ上の位置(座標)には、ズレがあるのでした。そこで、配列の要素をnp.flipud()
で縦方向に反転させます。
# 配列を確認 print(np.round(v, 2)) # 縦方向に反転 print(np.round(np.flipud(v), 2))
[[-0.75 -0.48 0.07 -0.48]
[-0.77 -0.35 0.32 0.93]
[-0.79 1.69 -0.58 0.14]]
[[-0.79 1.69 -0.58 0.14]
[-0.77 -0.35 0.32 0.93]
[-0.75 -0.48 0.07 -0.48]]
ラベルの表示上は元の配列と一致しますが、描画には反転させた配列を利用します。
赤・白・緑の色付けは、最小値と最大値により自動で決まります。負の値なら赤色、0なら白色、正の値なら緑色になるように、最小値と最大値を調整します。
# 最小値・最大値の絶対値に応じて再設定 vmax = max(vmax, abs(vmin)) vmin = -1 * vmax print(vmin) print(vmax) # 最大値が1より小さい場合は1に再設定 if vmax < 1: vmax = 1 # 最小値が-1より大きい場合は-1に再設定 if vmin > -1: vmin = -1 print(vmin) print(vmax)
-1.6897214538929588
1.6897214538929588
-1.6897214538929588
1.6897214538929588
最小値の絶対値と最大値を比較して、大きい方の値を最大値として利用します。また、その最大値をマイナスにした値を最小値とします。
例えば、状態価値が-2から3の範囲であれば、ヒートマップの設定範囲を-3から3にします。これにより、中央値の0が白色になり、また赤色と緑色の境界になります。
さらに、最大値が1未満のときは1に、最小値が-1より大きい場合は-1にします。
先ほどの作図コード(タイトルは変えました)を使って、再設定した範囲でヒートマップを作成します。
以上で、状態価値関数のヒートマップを作図できました。
・方策ラベル
最後に、方策を矢印で可視化します。
例として、ダミーの方策のディクショナリを作成しておきます。
# 仮の方策を作成 pi = {state: {0: 0.4, 1: 0.15, 2: 0.4, 3: 0.05} for state in env.states()}
状態(マスのインデックス)をキー、確率論的方策のディクショナリを値とします。ディクショナリの値として、ディクショナリを格納します。
完成形を確認しましょう。
# 方策を重ねた状態価値関数のヒートマップを作図
env.render_v(v=V, policy=pi)
マスごとに、確率が最大の行動を矢印で表現します。最大値が複数ある場合は、複数の行動を描画します。ただし、エージェントがゴールに辿り着くとエピソードが終了するという問題設定なので、ゴールのマスでは描画されません。
ここでは簡単な例として全ての状態(マス)で同じ方策にしていますが、実際には状態ごとに確率分布が異なります。
確率が最大の行動を抽出する処理を確認します。
# 状態を指定 state = (0, 0) # 確率論的方策を抽出 actions = pi[state] print(actions) # 確率が最大の行動を抽出 max_actions = [key for key, value in actions.items() if value == max(actions.values())] print(max_actions)
{0: 0.4, 1: 0.15, 2: 0.4, 3: 0.05}
[0, 2]
全ての方策pi
から、指定した状態の確率論的方策actions
を抽出します。
actions
から、キー(行動番号)key
と値(確率)value
を順番に取り出して、値が最大値であればキーをリストに格納します。
この処理を各状態で行います。
行動に対応した矢印を描画するのに利用するリストを作成します。
# 矢印の描画用のリストを作成 arrows = ['↑', '↓', '←', '→'] offsets = [(0, 0.1), (0, -0.1), (-0.1, 0), (0.1, 0)]
ラベルの表示位置を調整するときと、エージェントの移動のときと同じ要領で、行動(矢印)ごとに表示位置を矢印と同じ方向に0.1
ズラします。そのためのx軸とy軸の値をリストに格納します。
確率が最大の行動を矢印で表示します。
## 確率的方策を作図 # マスを描画 plt.figure(figsize=(9, 6)) # 図の設定 plt.xticks(ticks=np.arange(xs)) # x軸の目盛位置 plt.yticks(ticks=np.arange(ys)) # y軸の目盛位置 plt.xlim(xmin=0, xmax=xs) # x軸の範囲 plt.ylim(ymin=0, ymax=ys) # y軸の範囲 plt.tick_params(labelbottom=False, labelleft=False, labelright=False, labeltop=False) # 軸ラベル plt.grid() # グリッド線 # 状態価値のヒートマップを描画 plt.pcolormesh(np.flipud(v), cmap=cmap, vmin=vmin, vmax=vmax) # ヒートマップマップ # マスごとに処理 for y in range(ys): for x in range(xs): # 装飾するマス(状態)を設定 state = (y, x) # 報酬を抽出 r = env.reward_map[y, x] # 報酬がある場合 if r != 0 and r is not None: # 報酬ラベル用の文字列を作成 txt = 'R ' + str(r) # ゴールの場合 if state == env.goal_state: # 報酬ラベルにゴールを追加 txt = txt + ' (GOAL)' # 報酬ラベルを描画 plt.text(x=x+0.1, y=ys-y-0.9, s=txt) # 壁以外の場合 if state != env.wall_state: # 状態価値ラベルを描画 plt.text(x=x+0.6, y=ys-y-0.15, s=str(np.round(v[y, x], 2))) # 確率論的方策を抽出 actions = pi[state] # 確率が最大の行動を抽出 max_actions = [k for k, v in actions.items() if v == max(actions.values())] # 行動ごとに処理 for action in max_actions: # 矢印の描画用の値を抽出 arrow = arrows[action] offset = offsets[action] # ゴールの場合 if state == env.goal_state: # 描画せず次の状態へ continue # 方策ラベル(矢印)を描画 plt.text(x=x+0.45+offset[0], y=ys-y-0.5+offset[1], s=arrow) # 壁の場合 if state == env.wall_state: # 壁を描画 plt.gca().add_patch(plt.Rectangle(xy=(x, ys-y-1), width=1, height=1, fc=(0.4, 0.4, 0.4, 1.0))) # 長方形を重ねる plt.title('Policy', fontsize=20) # タイトル plt.show()
以上で、方策のラベルを描画できました。
以上が、状態価値関数の可視化メソッドの処理です。実際の実装では、状態価値関数や方策が指定されていないときの場合分けや、状態価値ラベルの有無を設定するprint_value
引数、サイズによってグラフを調整するなどの処理が含まれています。また、グラフをインスタンス変数として扱うため、オブジェクト指向で作図しています。
・render_qメソッド
「行動価値関数の可視化メソッド」の処理を確認するつもりですが、5章まで登場しないようなので、5章を読んでから書き足します。
この項では、GridWorld
クラスを確認しました。以降の節で利用します。
参考文献
- 斎藤康毅『ゼロから作るDeep Learning 4 ――強化学習編』オライリー・ジャパン,2022年.
- サポートページ:GitHub - oreilly-japan/deep-learning-from-scratch-4
おわりに
思ったより大変でした。でもよく理解できたと思います。最後に書くのもなんですが、各メソッドを使うタイミングで確認するぐらいでいいのではないでしょうか。
投稿日の前日に公開された新MVをどうぞ♪
【次節の内容】