はじめに
『ゼロから作るDeep Learning 4 ――強化学習編』の独学時のまとめノートです。初学者の補助となるようにゼロつくシリーズの4巻の内容に解説を加えていきます。本と一緒に読んでください。
この記事は、1.4.1節の内容です。簡単なスロットマシンを実装します。
【前節の内容】
【他の記事一覧】
【この記事の内容】
1.4.1 スロットマシンの実装
バンディット(スロットマシン)の機能を持つBandit
クラスの処理を確認します。
利用するライブラリを読み込みます。
# 利用するライブラリ import numpy as np import matplotlib.pyplot as plt
・処理の確認
Bandit
クラスの内部で行われる処理を確認します。
スロットマシンの数を指定して、マシンごとに確率を設定します。
# マシンの数を指定 arms = 10 # 各マシンの確率をランダムに設定 rates = np.random.rand(arms) print(np.round(rates, 3))
[0.793 0.812 0.363 0.337 0.141 0.845 0.989 0.989 0.328 0.692]
np.random.rand()
で0から1の一様乱数を生成して、各マシンの当たりの確率として使います。
乱数をヒストグラムにすると次のようになります。
# データ数を指定 N = 10000 # 一様乱数を生成 r_vals = np.random.rand(N) # 乱数のヒストグラムを作成 plt.figure(figsize=(8, 6)) plt.hist(r_vals) # 乱数 plt.xlabel('r') plt.ylabel('count') plt.title('N='+str(N), loc='left') # データ数 plt.suptitle('np.random.rand(N)', fontsize=20) plt.grid() plt.show()
0から1の範囲で偏りなく生成されるのが分かります。
プレイするマシン番号を指定して、そのマシンの確率を取り出します。
# マシン番号を指定 arm = 0 # arm番目のマシンの確率を抽出 rate = rates[arm] print(rate)
0.7925828509113335
マシン番号をarm
として整数を指定します。マシン番号は、各マシンの確率rates
の要素番号(インデックス)と対応します。
rates
からarm
番目の要素を取り出して、指定したマシンの当たりの確率rate
とします。
乱数を生成して、当たり外れを判定します。
# 乱数を生成 r = np.random.rand() print(r) # 当たり外れを判定 if rate > r: # 当たりの場合の報酬を出力 print(1) else: # 外れの場合の報酬を出力 print(0)
0.15586509410556726
1
この例では、当たりであれば報酬が1、外れであれば報酬が0のマシンとします。
0から1の一様乱数を生成して、当たりの確率rate
と比較します。rate
の方が大きければ1
を、rate
未満であれば0
を返します。
以上が、Bandit
クラスで行う処理です。
当たり外れの結果が、設定した確率に対応しているのか確認します。
# データ数を指定 N = 10000 # 一様乱数を生成 r_vals = np.random.rand(N) print(np.round(r_vals[:10], 3)) # 確率と乱数の値を比較 result = rate > r_vals print(result[:10]) # 当たりの数をカウント true_num = np.sum(result) print(true_num)
[0.1 0.167 0.666 0.333 0.712 0.216 0.437 0.806 0.156 0.918]
[ True True True True True True True False True False]
7948
複数の乱数を生成してrate
と比較します。当たり(rate
の方が値が大きい)であればTrue
、外れ(rate
が乱数未満)であればFalse
になります。
True
は1
、False
は0
として扱われるので、np.sum()
で当たりの数をカウントできます。
当たりと外れの比率を棒グラフで確認します。
# 当たりと外れの比率の棒グラフを作成 plt.figure(figsize=(8, 6)) plt.bar([0, 1], np.round([N-true_num, true_num]) / N) # 当たり外れの比率 plt.xlabel('r') plt.ylabel('count') plt.xticks(ticks=[0, 1], labels=['false', 'true']) # 当たり外れのラベル plt.title('N='+str(N) + ', rate='+str(np.round(rate, 3)), loc='left') # データ数と確率 plt.suptitle('rate > np.random.rand(N)', fontsize=20) plt.grid() plt.ylim(0, 1) plt.show()
マシンの確率rate
の値に従う結果が得られたのが分かります。
・実装
Bandit
クラスは、次のページを参照してください。
実装したクラスを試してみます。
バンディットクラスのインスタンスを作成します。
# マシンの数を指定 arms = 10 # インスタンスを作成 bandit = Bandit(arms) # 各マシンの確率を確認 print(np.round(bandit.rates, 3))
[0.132 0.619 0.839 0.2 0.547 0.993 0.259 0.797 0.059 0.06 ]
インスタンスを作成する度に、ランダムに値が設定されるので、各マシンの確率が変わります。
マシンのプレイ回数と、プレイするマシン番号を指定して、マシンをプレイします。
# 試行回数を指定 N = 5 # マシン番号を指定 arm = 0 # 繰り返し試行 for n in range(N): # arm番目のマシンをプレイ reward = bandit.play(arm) print(reward)
0
1
0
0
0
当たりなら報酬が1、外れなら報酬が0のスロットマシンの実行結果が得られました。
以上で、スロットマシンの機能を持つクラスを実装できました。次は、エージェントの機能を持つクラスを実装します。
参考文献
おわりに
手元のノートでは1.3節についても書いたのですが、追加要素が少なくて単なる写経の転載感が強くてブログに上げるのが躊躇われます。日記として上げてもいいものか。
【次節の内容】