からっぽのしょこ

読んだら書く!書いたら読む!同じ事は二度調べ(たく)ない

1.4.1:スロットマシンの実装【ゼロつく4のノート】

はじめに

 『ゼロから作るDeep Learning 4 ――強化学習編』の独学時のまとめノートです。初学者の補助となるようにゼロつくシリーズの4巻の内容に解説を加えていきます。本と一緒に読んでください。

 この記事は、1.4.1節の内容です。簡単なスロットマシンを実装します。

【前節の内容】

www.anarchive-beta.com

【他の記事一覧】

www.anarchive-beta.com

【この記事の内容】

1.4.1 スロットマシンの実装

 バンディット(スロットマシン)の機能を持つBanditクラスの処理を確認します。

 利用するライブラリを読み込みます。

# 利用するライブラリ
import numpy as np
import matplotlib.pyplot as plt


・処理の確認

 Banditクラスの内部で行われる処理を確認します。

 スロットマシンの数を指定して、マシンごとに確率を設定します。

# マシンの数を指定
arms = 10

# 各マシンの確率をランダムに設定
rates = np.random.rand(arms)
print(np.round(rates, 3))
[0.793 0.812 0.363 0.337 0.141 0.845 0.989 0.989 0.328 0.692]

 np.random.rand()で0から1の一様乱数を生成して、各マシンの当たりの確率として使います。

 乱数をヒストグラムにすると次のようになります。

# データ数を指定
N = 10000

# 一様乱数を生成
r_vals = np.random.rand(N)

# 乱数のヒストグラムを作成
plt.figure(figsize=(8, 6))
plt.hist(r_vals) # 乱数
plt.xlabel('r')
plt.ylabel('count')
plt.title('N='+str(N), loc='left') # データ数
plt.suptitle('np.random.rand(N)', fontsize=20)
plt.grid()
plt.show()

0から1の一様乱数

 0から1の範囲で偏りなく生成されるのが分かります。

 プレイするマシン番号を指定して、そのマシンの確率を取り出します。

# マシン番号を指定
arm = 0

# arm番目のマシンの確率を抽出
rate = rates[arm]
print(rate)
0.7925828509113335

 マシン番号をarmとして整数を指定します。マシン番号は、各マシンの確率ratesの要素番号(インデックス)と対応します。
 ratesからarm番目の要素を取り出して、指定したマシンの当たりの確率rateとします。

 乱数を生成して、当たり外れを判定します。

# 乱数を生成
r = np.random.rand()
print(r)

# 当たり外れを判定
if rate > r:
    # 当たりの場合の報酬を出力
    print(1)
else:
    # 外れの場合の報酬を出力
    print(0)
0.15586509410556726
1

 この例では、当たりであれば報酬が1、外れであれば報酬が0のマシンとします。
 0から1の一様乱数を生成して、当たりの確率rateと比較します。rateの方が大きければ1を、rate未満であれば0を返します。

 以上が、Banditクラスで行う処理です。

 当たり外れの結果が、設定した確率に対応しているのか確認します。

# データ数を指定
N = 10000

# 一様乱数を生成
r_vals = np.random.rand(N)
print(np.round(r_vals[:10], 3))

# 確率と乱数の値を比較
result = rate > r_vals
print(result[:10])

# 当たりの数をカウント
true_num = np.sum(result)
print(true_num)
[0.1   0.167 0.666 0.333 0.712 0.216 0.437 0.806 0.156 0.918]
[ True  True  True  True  True  True  True False  True False]
7948

 複数の乱数を生成してrateと比較します。当たり(rateの方が値が大きい)であればTrue、外れ(rateが乱数未満)であればFalseになります。
 True1False0として扱われるので、np.sum()で当たりの数をカウントできます。

 当たりと外れの比率を棒グラフで確認します。

# 当たりと外れの比率の棒グラフを作成
plt.figure(figsize=(8, 6))
plt.bar([0, 1], np.round([N-true_num, true_num]) / N) # 当たり外れの比率
plt.xlabel('r')
plt.ylabel('count')
plt.xticks(ticks=[0, 1], labels=['false', 'true']) # 当たり外れのラベル
plt.title('N='+str(N) + ', rate='+str(np.round(rate, 3)), loc='left') # データ数と確率
plt.suptitle('rate > np.random.rand(N)', fontsize=20)
plt.grid()
plt.ylim(0, 1)
plt.show()

当たり外れの比率

 マシンの確率rateの値に従う結果が得られたのが分かります。

・実装

 Banditクラスは、次のページを参照してください。

github.com


 実装したクラスを試してみます。

 バンディットクラスのインスタンスを作成します。

# マシンの数を指定
arms = 10

# インスタンスを作成
bandit = Bandit(arms)

# 各マシンの確率を確認
print(np.round(bandit.rates, 3))
[0.132 0.619 0.839 0.2   0.547 0.993 0.259 0.797 0.059 0.06 ]

 インスタンスを作成する度に、ランダムに値が設定されるので、各マシンの確率が変わります。

 マシンのプレイ回数と、プレイするマシン番号を指定して、マシンをプレイします。

# 試行回数を指定
N = 5

# マシン番号を指定
arm = 0

# 繰り返し試行
for n in range(N):
    # arm番目のマシンをプレイ
    reward = bandit.play(arm)
    print(reward)
0
1
0
0
0

 当たりなら報酬が1、外れなら報酬が0のスロットマシンの実行結果が得られました。

 以上で、スロットマシンの機能を持つクラスを実装できました。次は、エージェントの機能を持つクラスを実装します。

参考文献


おわりに

 手元のノートでは1.3節についても書いたのですが、追加要素が少なくて単なる写経の転載感が強くてブログに上げるのが躊躇われます。日記として上げてもいいものか。

【次節の内容】

www.anarchive-beta.com