からっぽのしょこ

読んだら書く!書いたら読む!同じ事は二度調べ(たく)ない

6.2:SARSA【ゼロつく4のノート】

はじめに

 『ゼロから作るDeep Learning 4 ――強化学習編』の独学時のまとめノートです。初学者の補助となるようにゼロつくシリーズの4巻の内容に解説を加えていきます。本と一緒に読んでください。

 この記事は、6.2節の内容です。SARSAにより行動価値関数を推定します。

【前節の内容】

www.anarchive-beta.com

【他の記事一覧】

www.anarchive-beta.com

【この記事の内容】

6.2 SARSA

 方策オン型のSARSAにより行動価値関数と方策を推定(方策を制御)します。

6.2.1 方策オン型のSARSA

 まずは、方策オン型のSARSAによる行動価値関数と方策の更新式を導出します。

数式の確認

 行動価値関数の定義式とMC法の計算式、ベルマン方程式を確認して、TD法による計算式を導出します。

行動価値関数の定義式

 各時刻の収益は、割り引き報酬和で定義されました(2.3.2項・3.1.2項)。

$$ \begin{align} G_t &= R_t + \gamma R_{t+1} + \gamma^2 R_{t+2} + \cdots \tag{6.1}\\ &= R_t + \gamma G_{t+1} \tag{6.2} \end{align} $$

 ここで、$\gamma$は割引率で$0 \leq \gamma \leq 1$の値を指定します。
 また、行動価値関数は、状態$s$で行動$a$を行った後の収益$G_t$の期待値で定義されました(3.3.1項)。

$$ \begin{align} q_{\pi}(s, a) &= \mathbb{E}_{\pi}[G_t | S_t = s, A_t = a] \tag{1}\\ &= \mathbb{E}_{\pi}[R_t + \gamma G_{t+1} | S_t = s, A_t = a] \tag{2} \end{align} $$

 アルゴリズムごとに行動価値関数(収益の期待値)を推定(近似)する方法(計算式)が異なります。

MC法の更新式とベルマン方程式

 MC法では、収益と状態・行動のサンプルを用いて、行動価値関数(1)を近似するのでした(5.4.1項)。

$$ \begin{align} q_{\pi}(s, a) &= \mathbb{E}_{\pi}[G_t | S_t = s, A_t = a] \tag{1}\\ &\simeq Q_{\pi}(s, a) \end{align} $$

 行動価値関数(収益の期待値)の推定値(近似値)として、指数移動平均を用いました(5.4.2項)。

$$ Q'_{\pi}(S_t, A_t) = Q_{\pi}(S_t, A_t) + \alpha \Bigl\{ G_t - Q_{\pi}(S_t, A_t) \Bigr\} $$

 ここで、$\alpha$は学習率で$0 < \alpha < 1$の値を指定します。また、更新後の行動価値関数を$Q'_{\pi}$で表します。
 繰り返し収益のサンプルを生成して行動価値関数を更新することで、推定値$Q_{\pi}(s, a)$を真の値$q_{\pi}(s, a)$に近付けます。

 行動価値関数のベルマン方程式は、行動価値関数(2)より次の式になりました(3.3.2項)。

$$ \begin{align} q_{\pi}(s, a) &= \mathbb{E}_{\pi}[R_t + \gamma G_{t+1} | S_t = s, A_t = a] \tag{2}\\ &= \sum_{s'} p(s' | s, a) \left\{ r(s, a, s') + \gamma \sum_{a'} \pi(a' | s') q_{\pi}(s', a') \right\} \tag{3.14} \end{align} $$

 次の時刻の行動$a'$が式に含まれます。

 MC法の更新式とベルマン方程式を組み合わせて、TD法の更新式を求めます。

SARSAの更新式

 ベルマン方程式(3.14)を、状態遷移確率$p(s' | s, a)$による期待値の項に変形します。

$$ \begin{align} q_{\pi}(s, a) &= \sum_{s'} p(s' | s, a) \left\{ r(s, a, s') + \gamma \sum_{a'} \pi(a' | s') q_{\pi}(s', a') \right\} \tag{3.14}\\ &= \mathbb{E}_{\pi} \left[ r(s, a, s') + \gamma \sum_{a'} \pi(a' | s') q_{\pi}(s', a') \middle| S_t = s, A_t = a \right] \end{align} $$

 この期待値計算をサンプリングによって求めます。
 報酬関数$r(s, a, s')$について、「次の状態のサンプル$S_{t+1}$」によって得られる「現在の報酬のサンプル$R_t$」

$$ \begin{aligned} S_{t+1} &\sim p(s' | s, a) \\ R_t &= r(S_t, A_t, S_{t+1}) \end{aligned} $$

を用います。
 また、「次の時刻の確率論的方策$\pi(a' | s')$」による「次の時刻の行動価値関数$q_{\pi}(s', a')$」の期待値(の因子)$\sum_{a'} \pi(a' | s') q_{\pi}(s', a')$について、「次の状態のサンプル$S_{t+1}$」と「次の時刻の行動のサンプル$A_{t+1}$」

$$ \begin{aligned} &A_{t+1} \sim \pi(a' | s') \\ &Q_{\pi}(S_{t+1} = s', A_{t+1} = a') \end{aligned} $$

を用います。

$$ \begin{align} Q_{\pi}(S_t, A_t) &= \mathbb{E}_{\pi}[ R_t + \gamma Q_{\pi}(S_{t+1}, A_{t+1}) | S_t = s, A_t = a ] \end{align} $$

 $R_t + \gamma Q_{\pi}(S_{t+1}, A_{t+1})$の期待値を、指数移動平均で近似します。

$$ Q_{\pi}'(S_t, A_t) = Q_{\pi}(S_t, A_t) + \alpha \Bigl\{ R_t + \gamma Q_{\pi}(S_{t+1}, A_{t+1}) - Q_{\pi}(S_t, A_t) \Bigr\} \tag{6.10} $$

 SARSAによる行動価値関数の更新式が得られました。この式は、DP法による状態価値関数の更新式(6.9)と同様にして展開できます。
 「現在の状態$S_t$に関する行動価値関数$Q_{\pi}(S_t, A_t)$」の計算に、「次の時刻の行動$A_{t+1}$」を使っているのが分かります。$S_t$は時刻$t-1$、$A_t, R_t, S_{t+1}$は時刻$t$、$A_{t+1}$は時刻$t+1$において、環境またはエージェントによって得られるサンプルです。
 つまり、新たに得たサンプルデータを使って、1つ前の時刻に関する行動価値関数を更新する処理を繰り返します。

方策の改善

 行動価値関数$Q_{\pi}(s, a)$をgreedy化して、決定論的方策を計算します(5.4.1項)。

$$ \mu(S_t) = \mathop{\mathrm{argmax}}\limits_a Q_{\pi}(S_t, a) $$

 ターゲット方策として利用します。
 または、ε-greedy化して、確率論的方策を計算します(5.4.3項)。

$$ \pi'(a | S_t) = \begin{cases} \mathop{\mathrm{argmax}}\limits_a Q_{\pi}(S_t, a) &\quad (1 - \epsilon) \\ \mathrm{random} &\quad (\epsilon) \end{cases} \tag{6.11} $$

 こちらは、挙動方策として利用します。ただし、方策オン型ではターゲット方策としても利用します。

6.2.2 SARSAの実装

 次は、方策オン型のSARSAにより行動価値関数と方策の推定を行うエージェントを実装します。

 利用するライブラリを読み込みます。

# ライブラリを読み込み
import numpy as np
from collections import defaultdict, deque

# 追加ライブラリ
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.animation import FuncAnimation

 更新推移をアニメーションで確認するのにmatplotlibのモジュールを利用します。不要であれば省略してください。

 また、3×4マスのグリッドワールドのクラスGridWorldを読み込みます。

# 実装済みのクラスと関数を読み込み
import sys
sys.path.append('../deep-learning-from-scratch-4-master')
from common.gridworld import GridWorld
from common.utils import greedy_probs

 実装済みクラスの読み込みについては「3.6.1:MNISTデータセットの読み込み【ゼロつく1のノート(Python)】 - からっぽのしょこ」、GridWorldクラスについては「4.2.1:GridWorldクラスの実装:評価と改善に関するメソッド【ゼロつく4のノート】 - からっぽのしょこ」「4.2.1:GridWorldクラスの実装:可視化に関するメソッド【ゼロつく4のノート】 - からっぽのしょこ」、greedy_probs関数については「5.4.3-5:モンテカルロ法による方策反復法の実装【ゼロつく4のノート】 - からっぽのしょこ」を参照してください。

処理の確認

 SarsaAgentクラスのupdateメソッドの内部で行う処理を確認します。他のメソッドについては「5.3:モンテカルロ法による方策評価の実装【ゼロつく4のノート】 - からっぽのしょこ」を参照してください。

 updateメソッドの実行時点を現在の時刻として、1つ前の時刻の行動価値を計算します。ただし、メソッド内のオブジェクト名は、更新する時刻(前の時刻)の状態・行動がstate, action、現在の時刻の状態・行動がnext_state, next_actionになっています。

dequeの使い方

 最大要素数を設定できるリストを確認しておきます。

# 最大要素数を指定したリストを初期化
memory = deque(maxlen=2)
print(memory)
deque([], maxlen=2)

 collectionsライブラリのdeque()を使ってリストを作成します。最大要素数の引数maxlen2を指定します。

 リストに値を繰り返し格納してみます。

# 1要素ずつ値を格納
for n in range(10):
    # 値を格納
    memory.append(n)
    
    # リストを確認
    print(memory)
deque([0], maxlen=2)
deque([0, 1], maxlen=2)
deque([1, 2], maxlen=2)
deque([2, 3], maxlen=2)
deque([3, 4], maxlen=2)
deque([4, 5], maxlen=2)
deque([5, 6], maxlen=2)
deque([6, 7], maxlen=2)
deque([7, 8], maxlen=2)
deque([8, 9], maxlen=2)

 3つ目の要素を格納すると1つ目の(一番古い)要素が削除され、要素数が2のままなのが分かります。

 同様に、リストにダミーのサンプルデータを繰り返し格納します。

# 最大要素数を指定したリストを初期化
memory = deque(maxlen=2)

# ダミーのサンプルデータを格納
for n in range(10):
    # タプルを格納
    memory.append(((n, n), n, n, False))
    
    # 初回は処理しない
    if len(memory) < 2:
        continue
    
    # リストを確認
    print(memory)
deque([((0, 0), 0, 0, False), ((1, 1), 1, 1, False)], maxlen=2)
deque([((1, 1), 1, 1, False), ((2, 2), 2, 2, False)], maxlen=2)
deque([((2, 2), 2, 2, False), ((3, 3), 3, 3, False)], maxlen=2)
deque([((3, 3), 3, 3, False), ((4, 4), 4, 4, False)], maxlen=2)
deque([((4, 4), 4, 4, False), ((5, 5), 5, 5, False)], maxlen=2)
deque([((5, 5), 5, 5, False), ((6, 6), 6, 6, False)], maxlen=2)
deque([((6, 6), 6, 6, False), ((7, 7), 7, 7, False)], maxlen=2)
deque([((7, 7), 7, 7, False), ((8, 8), 8, 8, False)], maxlen=2)
deque([((8, 8), 8, 8, False), ((9, 9), 9, 9, False)], maxlen=2)

 ダミーのサンプルデータ(状態・行動・報酬・ゴールフラグ)に対応する4つの値をまとめたタプルを、append()でリストに格納していきます。
 リストの要素数を条件とすることで、初回の処理を飛ばせます。

 リストからサンプルデータを取り出します。

# 前の時刻のサンプルデータを取得
state, action, reward, done = memory[0]
print(state, action, reward, done)

# 現在の時刻のサンプルデータを取得
next_state, next_action, _, _ = memory[1]
print(next_state, next_action)
(8, 8) 8 8 False
(9, 9) 9

 0番目の要素が前の時刻(更新を行う時刻)のサンプルデータ、1番目の要素が現在の時刻(更新時点からみて次の時刻)のサンプルデータです。現在の時刻の報酬とゴールフラグは計算に利用しないので、オブジェクト名を_としておきます。

 続いて、取り出したサンプルデータを使って、行動価値関数を更新する処理を確認します。

更新の計算

 例として、ランダムな値の行動価値関数を作成しておきます。

# (仮の)前の時刻のサンプルデータを設定
state, action, reward, done = (0, 1), 1, 1, False

# (仮の)現在の時刻のサンプルを設定
next_state, next_action = (0, 2), 3

# 行動の種類数を指定
action_size = 4

# (仮の)行動価値関数を作成
Q = {(s, a): np.random.rand() for s in [state, next_state] for a in range(action_size)}
print(list(Q.keys()))
print(np.round(list(Q.values()), 3))
[((0, 1), 0), ((0, 1), 1), ((0, 1), 2), ((0, 1), 3), ((0, 2), 0), ((0, 2), 1), ((0, 2), 2), ((0, 2), 3)]
[0.594 0.955 0.427 0.831 0.944 0.18  0.637 0.97 ]

 前の状態stateと現在の状態next_stateごとに、上下左右の4つの行動に対する値をディクショナリに格納します。

 現在の状態・行動の行動価値を取り出します。

# 現在の状態・行動の行動価値を取得
if done:
    next_q = 0 
else:
    next_q = Q[next_state, next_action]
print(next_q)
0.9699877748547381

 「現在の時刻の状態next_stateと行動next_action」をキーとして、「行動価値関数Q」から値を取り出して「現在の時刻の行動価値next_q」とします。ただし、現在の状態がゴールマスのときは、行動価値を0にします。

 前の時刻の行動価値を計算して、値を更新します。

# 収益の計算用の割引率を指定
gamma = 0.9

# 状態価値の計算用の学習率
alpha = 0.01

# TDターゲットを計算
target = reward + gamma * next_q

# 前の状態・行動の行動価値を更新:式(6.10)
Q[state, action] += (target - Q[state, action]) * alpha
print(Q[state, action])
0.9639834062967971

 「前の時刻の状態stateと行動action」をキーとして、式(6.10)により行動価値を計算して、「前の時刻の行動価値Q[state, action]」を更新します。

 以上が、SARSAによる方策制御を行うエージェントの処理です。

実装

 処理の確認ができたので、方策オン型のSARSAにおけるエージェントをクラスとして実装します。

# 方策オン型のSARSAによるエージェントの実装
class SarsaAgent:
    # 初期化メソッドの定義
    def __init__(self):
        # パラメータを指定
        self.gamma = 0.9 # 収益の計算用の割引率
        self.alpha = 0.8 # 状態価値の計算用の学習率
        self.epsilon = 0.1 # ランダムに行動する確率
        self.action_size = 4 # 行動の種類数
        
        # オブジェクトを初期化
        random_actions = {0: 0.25, 1: 0.25, 2: 0.25, 3: 0.25} # 確率論的方策
        self.pi = defaultdict(lambda: random_actions) # ターゲット方策・挙動方策
        self.Q = defaultdict(lambda: 0) # 行動価値関数
        self.memory = deque(maxlen=2) # サンプルデータ
    
    # 行動メソッドの定義
    def get_action(self, state):
        # 現在の状態の挙動方策の確率分布を取得
        action_probs = self.pi[state] # 確率分布
        actions = list(action_probs.keys()) # 行動番号
        probs = list(action_probs.values()) # 行動確率
        
        # 確率論的方策に従う行動を出力
        return np.random.choice(actions, p=probs)
    
    # サンプルデータの初期化メソッドの定義
    def reset(self):
        # リストを初期化
        self.memory.clear()
    
    # 更新メソッドの定義
    def update(self, state, action, reward, done):
        # 現在の時刻のサンプルデータを格納
        self.memory.append((state, action, reward, done))
        
        # 初回は更新しない
        if len(self.memory) < 2:
            return
        
        # サンプルデータを取得
        state, action, reward, done = self.memory[0] # 前の時刻
        next_state, next_action, _, _ = self.memory[1] # 現在の時刻
        
        # 現在の時刻の行動価値を取得
        next_q = 0 if done else self.Q[next_state, next_action] # ゴールの場合は0を設定
        
        # 前の時刻の行動価値関数を更新:式(6.10)
        target = reward + self.gamma * next_q
        self.Q[state, action] += (target - self.Q[state, action]) * self.alpha
        
        # ε-greedy法により方策を更新:式(6.11)
        self.pi[state] = greedy_probs(self.Q, state, self.epsilon)


 実装したクラスを試してみましょう。

 環境(グリッドワールド)とエージェントのインスタンスを作成して、1エピソードの処理を行います。

# 環境・エージェントのインスタンスを作成
env = GridWorld()
agent = SarsaAgent()

# 行動の表示用のリストを作成
arrows = ['↑', '↓', '←', '→']

# 最初の状態を設定
state = env.start_state

# 時刻(試行回数)を初期化
t = 0

# 1エピソードのシミュレーション
while True:
    # 時刻をカウント
    t += 1

    # ε-greedy法により行動を決定
    action = agent.get_action(state)

    # サンプルデータを取得
    next_state, reward, done = env.step(action)

    # 前の時刻の状態・行動の行動価値関数・方策を更新:式(6.10-11)
    agent.update(state, action, reward, done)

    # サンプルデータを表示
    print(
        't=' + str(t) + 
        ', S_t=' + str(state) + 
        ', A_t=' + arrows[action] + 
        ', S_t+1=' + str(next_state) + 
        ', R_t=' + str(reward)
    )
    
    # ゴールに着いた場合
    if done:
        # 現在の状態・行動の行動価値関数・方策を更新:式(6.10-11)
        agent.update(next_state, None, None, None)

        # エピソードを終了
        break

    # 状態を更新
    state = next_state
t=1, S_t=(2, 0), A_t=→, S_t+1=(2, 1), R_t=0
t=2, S_t=(2, 1), A_t=↑, S_t+1=(2, 1), R_t=0
t=3, S_t=(2, 1), A_t=↑, S_t+1=(2, 1), R_t=0
t=4, S_t=(2, 1), A_t=←, S_t+1=(2, 0), R_t=0
t=5, S_t=(2, 0), A_t=↓, S_t+1=(2, 0), R_t=0
t=6, S_t=(2, 0), A_t=↑, S_t+1=(1, 0), R_t=0
t=7, S_t=(1, 0), A_t=↓, S_t+1=(2, 0), R_t=0
t=8, S_t=(2, 0), A_t=↓, S_t+1=(2, 0), R_t=0
t=9, S_t=(2, 0), A_t=↓, S_t+1=(2, 0), R_t=0
t=10, S_t=(2, 0), A_t=→, S_t+1=(2, 1), R_t=0
t=11, S_t=(2, 1), A_t=↑, S_t+1=(2, 1), R_t=0
t=12, S_t=(2, 1), A_t=↑, S_t+1=(2, 1), R_t=0
t=13, S_t=(2, 1), A_t=→, S_t+1=(2, 2), R_t=0
t=14, S_t=(2, 2), A_t=→, S_t+1=(2, 3), R_t=0
t=15, S_t=(2, 3), A_t=↓, S_t+1=(2, 3), R_t=0
t=16, S_t=(2, 3), A_t=↑, S_t+1=(1, 3), R_t=-1.0
t=17, S_t=(1, 3), A_t=→, S_t+1=(1, 3), R_t=-1.0
t=18, S_t=(1, 3), A_t=↓, S_t+1=(2, 3), R_t=0
t=19, S_t=(2, 3), A_t=↓, S_t+1=(2, 3), R_t=0
t=20, S_t=(2, 3), A_t=↓, S_t+1=(2, 3), R_t=0
t=21, S_t=(2, 3), A_t=↑, S_t+1=(1, 3), R_t=-1.0
t=22, S_t=(1, 3), A_t=↑, S_t+1=(0, 3), R_t=1.0

 agentget_action()で方策に従い行動して、envstep()で状態を遷移し報酬を出力します。
 得られたサンプルデータ(1つ前の状態・行動・報酬と現在の状態・行動)を使って、agentupdate()で1つ前の状態と行動の行動価値関数と方策を計算します。ただし、2時刻分のサンプルデータが必要なので、初回は更新されません。
 ゴールマスに着くとdoneTrueに設定されるので、ダミーの次のサンプルデータを用意して最後の時刻の更新を行い、breakでループ処理を終了します。

 行動価値関数をヒートマップで確認します。

# 行動価値関数のヒートマップと方策ラベルを作図
env.render_q(q=agent.Q)

1エピソード更新した行動価値関数のヒートマップと方策ラベル

 render_q()内部のnp.argmax()の仕様で、行動価値が等しいとインデックスが最小の行動がラベルで表示されます。

 以上で、方策オン型のSARSAのエージェントを実装できました。

・方策オン型のSARSAによる方策制御

 最後に、方策オン型のSARSAにより行動価値関数を推定して、更新の推移を確認します。

推定

 方策オン型のSARSAにより行動価値関数と方策を繰り返し更新します。

# 環境・エージェントのインスタンスを作成
env = GridWorld()
agent = SarsaAgent()

# エピソード数を指定
episodes = 1000

# 推移の可視化用のリストを初期化
trace_Q = [{(state, action): agent.Q[(state, action)] for state in env.states() for action in env.action_space}] # 初期値を記録

# 繰り返しシミュレーション
for episode in range(episodes):
    # 状態を初期化
    state = env.reset()
    
    # サンプルデータを初期化
    agent.reset()
    
    # 時刻(試行回数)を初期化
    t = 0
    
    # 1エピソードのシミュレーション
    while True:
        # 時刻をカウント
        t += 1
        
        # ε-greedy法により行動を決定
        action = agent.get_action(state)
        
        # サンプルデータを取得
        next_state, reward, done = env.step(action)
        
        # 前の時刻の状態・行動の行動価値関数・方策を更新:式(6.10-11)
        agent.update(state, action, reward, done)
        
        # ゴールに着いた場合
        if done:
            # 現在の状態・行動の行動価値関数・方策を更新:式(6.10-11)
            agent.update(next_state, None, None, None)
            
            # 更新値を記録
            trace_Q.append(agent.Q.copy())
            
            # 総時刻を表示
            print('episode '+str(episode+1) + ': T='+str(t))
            
            # エピソードを終了
            break
        
        # 状態を更新
        state = next_state
episode 1: T=183
episode 2: T=21
episode 3: T=26
episode 4: T=57
episode 5: T=5
(省略)
episode 996: T=7
episode 997: T=6
episode 998: T=17
episode 999: T=7
episode 1000: T=10

 スタートマスからε-greedy法により行動し、ゴールマスに着くまでを1エピソードとします。エピソードごとに、GridWorldクラスのreset()メソッドで状態を初期化し(エージェントをスタートマスに戻し)、SarsaAgentクラスのreset()メソッドでサンプルデータを初期化(過去のデータを削除)します。
 episodesに指定した回数のシミュレーションを行い、時刻ごとに繰り返しagentupdate()で現在の状態と行動の行動価値関数と方策を更新します。ただし、更新には2時刻分のサンプルデータを利用するので、初回は更新されず、最後の時刻はダミーデータを使って更新します。

 推移の確認用に、行動価値関数の更新値をtrace_Qに格納していきます。

 推定した行動価値関数をヒートマップと方策ラベルで確認します。

# 行動価値関数のヒートマップを作図
env.render_q(q=agent.Q)

状態価値関数のヒートマップと方策ラベル

 結果の解釈については本を参照してください。

更新推移の可視化

 ここまでで、繰り返しの更新処理を確認しました。続いて、途中経過をアニメーションで確認します。

 行動価値関数のヒートマップのアニメーションを作成します。

・作図コード(クリックで展開)

# グリッドマップのサイズを取得
xs = env.width
ys = env.height

# 状態価値の最大値・最小値を取得
qmax = max([max(trace_Q[i].values()) for i in range(len(trace_Q))])
qmin = min([min(trace_Q[i].values()) for i in range(len(trace_Q))])

# 色付け用に最大値・最小値を再設定
qmax = max(qmax, abs(qmin))
qmin = -1 * qmax
qmax = 1 if qmax < 1 else qmax
qmin = -1 if qmin > -1 else qmin

# カラーマップを設定
color_list = ['red', 'white', 'green']
cmap = LinearSegmentedColormap.from_list('colormap_name', color_list)

# 図を初期化
fig = plt.figure(figsize=(12, 9), facecolor='white') # 図の設定
plt.suptitle('On-policy SARSA', fontsize=20) # 全体のタイトル

# 作図処理を関数として定義
def update(i):
    # 前フレームのグラフを初期化
    plt.cla()
    
    # i回目の更新値を取得
    Q = trace_Q[i]
    
    # マス(状態)ごとに処理
    for state in env.states():
        # 行動ごとに処理
        for action in env.action_space:
            # インデックスを取得
            y, x = state
            
            # 報酬を抽出
            r = env.reward_map[y, x]
            
            # 報酬がある場合
            if r != 0 and r is not None:
                # 報酬ラベル用の文字列を作成
                txt = 'R ' + str(r)

                # ゴールの場合
                if state == env.goal_state:
                    # 報酬ラベルにゴールを追加
                    txt = txt + ' (GOAL)'

                # 報酬ラベルを描画
                plt.text(x=x+0.1, y=ys-y-0.9, s=txt, 
                         ha='left', va='bottom', fontsize=15)
            
            # ゴールの場合
            if state == env.goal_state:
                # 描画せず次の状態へ
                continue

            # 作図用のx軸・y軸の値を設定
            tx, ty = x, ys-y-1

            # 行動ごとの三角形の頂点を設定
            action_map = {
                0: ((0.5+tx, 0.5+ty), (1.0+tx, 1.0+ty), (tx, 1.0+ty)), # 上
                1: ((tx, ty), (1.0+tx, ty), (0.5+tx, 0.5+ty)), # 下
                2: ((tx, ty), (0.5+tx, 0.5+ty), (tx, 1.0+ty)), # 左
                3: ((0.5+tx, 0.5+ty), (1.0+tx, ty), (1.0+tx, 1.0+ty)) # 右
            }

            # 行動ごとの価値ラベルのプロット位置を設定
            offset_map = {
                0: (0.5, 0.75), # 上
                1: (0.5, 0.25), # 下
                2: (0.25, 0.5), # 左
                3: (0.75, 0.5)  # 右
            }

            # 壁の場合
            if state == env.wall_state:
                # 壁を描画
                rect = plt.Rectangle(xy=(tx, ty), width=1, height=1, 
                                     fc=(0.4, 0.4, 0.4, 1.0)) # 長方形を作成
                plt.gca().add_patch(rect) # 重ねて描画
            # (よく分からない)
            elif state in env.goal_state:
                plt.gca().add_patch(plt.Rectangle(xy=(tx, ty), width=1, height=1, fc=(0.0, 1.0, 0.0, 1.0)))
            # 壁以外の場合
            else:
                # 行動価値を抽出
                tq = Q[(state, action)]

                # 行動価値を0から1に正規化
                color_scale = 0.5 + (tq / qmax) / 2

                # 三角形を描画
                poly = plt.Polygon(action_map[action],fc=cmap(color_scale)) # 三角形を作成
                plt.gca().add_patch(poly) # 重ねて描画

                # プロット位置の調整値を取得
                offset = offset_map[action]

                # 行動価値ラベルを描画
                plt.text(x=tx+offset[0], y=ty+offset[1], s=str(np.round(tq, 3)), 
                         ha='center', va='center', size=15) # 行動価値ラベル
    
    # グラフの設定
    plt.xticks(ticks=np.arange(xs)) # x軸の目盛位置
    plt.yticks(ticks=np.arange(ys), labels=ys-np.arange(ys)-1) # y軸の目盛位置
    plt.xlim(xmin=0, xmax=xs) # x軸の範囲
    plt.ylim(ymin=0, ymax=ys) # y軸の範囲
    plt.tick_params(labelbottom=False, labelleft=False, labelright=False, labeltop=False) # 軸ラベル
    plt.grid() # グリッド線
    plt.title('episode:'+str(i), loc='left') # タイトル

# gif画像を作成
anime = FuncAnimation(fig=fig, func=update, frames=len(trace_Q), interval=50)

# gif画像を保存
anime.save('ch6_2.gif')

 各エピソードで更新した行動価値をtrace_Qから取り出してヒートマップを描画する処理を関数update()として定義して、FuncAnimation()でアニメーション(gif画像)を作成します。


行動価値関数のヒートマップの推移

 行動価値関数の更新値の推移を折れ線グラフで確認します。

・作図コード(クリックで展開)

# 行動ラベルを設定
arrows = ['↑', '↓', '←', '→']

# 状態価値関数の推移を作図
plt.figure(figsize=(15, 10), facecolor='white')
for state in env.states():
    for action in range(agent.action_size):
        # 更新値を抽出
        q_vals = [trace_Q[i][(state, action)] for i in range(episodes+1)]
        
        # 各状態の価値の推移を描画
        plt.plot(np.arange(episodes+1), q_vals, 
                 alpha=0.5, label='$Q_i(L_{'+str(state[0])+','+str(state[1])+'},'+arrows[action]+')$')
plt.xlabel('episode')
plt.ylabel('action-value')
plt.suptitle('On-policy SARSA', fontsize=20)
plt.title('$\gamma='+str(agent.gamma) + ', \\alpha='+str(agent.alpha)+'$', loc='left')
plt.grid()
plt.legend(loc='upper left', bbox_to_anchor=(1, 1), ncol=2)
plt.show()


行動価値関数の推移

 行番号を$h$、列番号を$w$として各マスを$L_{h,w}$で表します(図4-9)。また、$i$回目の行動価値を$Q_i(L_{h,w}, A)$で表します。
 各曲線の縦軸の値が、ヒートマップの色に対応します。

 この節では、方策オン型のSARSAにより方策制御を行うエージェントを実装して、最適方策を求めました。次節では、方策オフ型のSARSAを実装して、最適方策を求めます。

参考文献


おわりに

 ベルマン方程式をサンプリングで代用する計算(操作)について悩んでたんですけど、Q学習の節で図解されてました。

【次節の内容】

www.anarchive-beta.com