はじめに
機械学習で登場する確率分布について色々な角度から理解したいシリーズです。
この記事では、Pythonでディリクレ分布のグラフを作成します。
【前の内容】
【他の記事一覧】
【この記事の内容】
ディリクレ分布の作図
ディリクレ分布(Dirichlet Distribution)のグラフを作成します。ディリクレ分布については「ディリクレ分布の定義式 - からっぽのしょこ」を参照してください。
利用するライブラリを読み込みます。
# 利用ライブラリ import numpy as np from scipy.stats import dirichlet import matplotlib.pyplot as plt from matplotlib.animation import FuncAnimation
定義式の確認
まずは、ディリクレ分布の定義式を確認します。
ディリクレ分布は、次の式で定義されます。
ここで、$V$は次元数で、$\boldsymbol{\beta} = (\beta_1, \beta_2, \cdots, \beta_V)$はパラメータです。$\beta_v > 0$を満たす必要があります。確率変数の実現値$\boldsymbol{\phi} = (\phi_1, \phi_2, \cdots, \phi_V)$は、$0 \leq \phi_v \leq 1$、$\sum_{v=1}^V \phi_v = 1$となります。
この計算を行いグラフを作成します。
三角座標の準備
ディリクレ分布を三角図により可視化するために、三角座標を描画するための準備をします。詳しくは「Matplotlibで三角グラフを作図したい - からっぽのしょこ」・「Matplotlibで三角グラフの等高線を作図したい - からっぽのしょこ」・Matplotlibで3D三角グラフを作図したい - からっぽのしょこを参照してください。
軸目盛の間隔を設定して、三角座標を描画するための配列を作成します。
# 軸目盛の位置を指定 axis_vals = np.arange(start=0.0, stop=1.1, step=0.1) # 軸線用の値を作成 axis_x = np.array([0.5, 0.0, 1.0]) axis_y = np.array([0.5*np.sqrt(3.0), 0.0, 0.0]) axis_u = np.array([-0.5, 1.0, -0.5]) axis_v = np.array([-0.5*np.sqrt(3.0), 0.0, 0.5*np.sqrt(3.0)]) # グリッド線用の値を作成 grid_x = np.hstack([ 0.5 * axis_vals, axis_vals, 0.5 * axis_vals + 0.5 ]) grid_y = np.hstack([ 0.5 * axis_vals * np.sqrt(3.0), np.zeros_like(axis_vals), 0.5 * (1.0 - axis_vals) * np.sqrt(3.0) ]) grid_u = np.hstack([ 0.5 * axis_vals, 0.5 * (1.0 - axis_vals), -axis_vals ]) grid_v = np.hstack([ -0.5 * axis_vals * np.sqrt(3.0), 0.5 * (1.0 - axis_vals) * np.sqrt(3.0), np.zeros_like(axis_vals) ])
2つの配列を使って以降の作図を行います。
グラフの作成
Matplotlib
ライブラリのPyplot
モジュールを利用して、ディリクレ分布のグラフを作成します。ディリクレ分布の確率密度の計算については「分布の計算」を参照してください。
パラメータの設定
ディリクレ分布のパラメータ$\boldsymbol{\beta}$を設定します。この例では、三角図で描画するため、次元数を$V = 3$とします。
# パラメータを指定 beta_v = np.array([4.0, 2.0, 3.0])
$V$次元ベクトル$\boldsymbol{\beta} = (\beta_1, \beta_2, \beta_3)$、$\beta_v > 0$の値を指定します。
2つの処理方法でグラフを作成します。
散布図によるヒートマップ
1つ目の方法は、散布図によって簡易的にヒートマップを作成します。こちらの方が直感的に処理できます。
ディリクレ分布の確率変数がとり得る値$\boldsymbol{\phi}$の各要素$\phi_v$の値を作成します。
# Φがとり得る値を作成 phi_vals = np.linspace(start=0.0, stop=1.0, num=51) print(phi_vals[:10])
[0. 0.02 0.04 0.06 0.08 0.1 0.12 0.14 0.16 0.18]
$0 \leq \phi_v \leq 1$の値をphi_vals
とします。グラフが粗い場合や処理が重い場合は、phi_vals
の間隔(np.arange()
のstep
(第3)引数)や要素数(np.linspace()
のnum
(第3)引数)を調整してください。
$\boldsymbol{\phi}$の値を作成します。
# 格子点を作成 phi_0_grid, phi_1_grid, phi_2_grid = np.meshgrid(phi_vals, phi_vals, phi_vals) # Φがとり得る点を作成 phi_points = np.stack([phi_0_grid.flatten(), phi_1_grid.flatten(), phi_2_grid.flatten()], axis=1) # 配列に格納 phi_points = phi_points[1:, :] # (0, 0, 0)の行を除去 phi_points /= np.sum(phi_points, axis=1, keepdims=True) # 正規化 phi_points = np.unique(phi_points, axis=0) # 重複を除去 print(np.round(phi_points[:5], 2)) print(phi_points.shape)
[[0. 0. 1. ]
[0. 0.02 0.98]
[0. 0.02 0.98]
[0. 0.02 0.98]
[0. 0.02 0.98]]
(113222, 3)
3つの要素分のphi_vals
の全ての組み合わせ(格子状の点)をnp.meshgrid()
で作成します。出力される配列をそれぞれ列とする配列を作成してphi_points
とします。phi_points
の各行が点$\boldsymbol{\phi} = (\phi_1, \phi_2, \phi_3)$に対応します。
ただし、$\sum_{v=1}^V \phi_v = 1$を満たす必要があるため、行ごとに総和で割って正規化します。全ての要素が0
の行はゼロ除算になるため取り除きます。正規化によって重複する組み合わせ(行)ができるので、np.unique()
で取り除きます。
phi_points
の値(点)を三角座標に変換します。
# 三角座標に変換 y_0_vals = phi_points[:, 1] + 0.5 * phi_points[:, 2] y_1_vals = 0.5 * phi_points[:, 2] * np.sqrt(3.0) print(np.round(y_0_vals[:10], 2)) print(np.round(y_1_vals[:10], 2))
[0.5 0.51 0.51 0.51 0.51 0.51 0.51 0.51 0.51 0.51]
[0.87 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85]
$y_0 = \phi_1 + \frac{\phi_2}{2}$、$y_1 = \frac{\sqrt{3} \phi_2}{2}$で変換後の座標を計算します。
$\boldsymbol{\phi}$の点ごとの確率密度を計算します。
# ディリクレ分布の確率密度を計算 dens_vals = np.array( [dirichlet.pdf(x=phi_v, alpha=beta_v) for phi_v in phi_points] ) print(np.round(dens_vals, 2))
[0. 0. 0. ... 0. 0. 0.]
ディリクレ分布の確率密度は、SciPy
ライブラリのstats
モジュールのdirichlet
で計算できます。確率変数の引数x
にphi_points
の各行、パラメータの引数alpha
にbeta_v
を指定します。
リスト内包表記を使って、phi_points
の行ごとに確率密度を計算します。
ただし、全てのパラメータが1未満で確率変数に0を含むとエラーになります。その場合は、次のようにして0.0
や欠損値np.nan
に置き換えます。
# ディリクレ分布の確率密度を計算 dens_vals = np.array( [dirichlet.pdf(x=phi_v, alpha=beta_v) if all(phi_v != 0.0) else np.nan for phi_v in phi_points] )
散布図によって、ディリクレ分布のヒートマップを作成します。
# ディリクレ分布の散布図によるヒートマップを作成 plt.figure(figsize=(12, 10), facecolor='white') # 図の設定 sct = plt.scatter(x=y_0_vals, y=y_1_vals, c=dens_vals, alpha=0.8) # 確率密度のヒートマップ plt.quiver(grid_x, grid_y, grid_u, grid_v, scale_units='xy', scale=1, units='dots', width=0.1, headwidth=0.1, fc='none', ec='gray', linewidth=1.5, linestyle=':') # 三角座標のグリッド線 plt.quiver(axis_x, axis_y, axis_u, axis_v, scale_units='xy', scale=1, units='dots', width=1.5, headwidth=1.5, fc='gray', linestyle='-') # 三角座標の枠線 for val in axis_vals: plt.text(x=0.5*val, y=0.5*val*np.sqrt(3.0), s=str(np.round(1.0-val, 1))+' '*2, ha='right', va='bottom', rotation=-60) # 三角座標のx軸目盛 plt.text(x=val, y=0.0, s=str(np.round(val, 1))+' '*10, ha='center', va='center', rotation=60) # 三角座標のy軸目盛 plt.text(x=0.5*val+0.5, y=0.5*(1.0-val)*np.sqrt(3.0), s=' '*3+str(np.round(1.0-val, 1)), ha='left', va='center') # 三角座標のz軸目盛 plt.text(x=0.25, y=0.25*np.sqrt(3.0), s='$\phi_1$'+' '*5, ha='right', va='center', size=25) # 三角座標のx軸ラベル plt.text(x=0.5, y=0.0, s='\n'+'$\phi_2$', ha='center', va='top', size=25) # 三角座標のy軸ラベル plt.text(x=0.75, y=0.25*np.sqrt(3.0), s=' '*4+'$\phi_3$', ha='left', va='center', size=25) # 三角図のz軸ラベル plt.xticks(ticks=[0.0, 0.5, 1.0], labels='') # 2次元座標のx軸目盛 plt.yticks(ticks=[0.0, 0.25*np.sqrt(3.0), 0.5*np.sqrt(3.0)], labels='') # 2次元座標のy軸目盛 plt.grid() # 2次元座標のグリッド線 plt.axis('equal') # アスペクト比 plt.suptitle(t='Dirichlet Distribution', fontsize=20) # 全体のタイトル plt.title(label='$\\beta=('+', '.join([str(beta) for beta in beta_v])+')$', loc='left') # パラメータラベル plt.colorbar(sct, label='density') # カラーバー plt.show() # 描画
phi_vals
の要素数を増やす(間隔を狭くする)と隅の空白が埋まりますが、処理が重くなります。
ヒートマップ・等高線図・曲面図
2つ目の方法は、三角座標を含めた2次元座標上の格子点を作成し、元の3次元座標に戻して確率密度を計算します。こちらの方が綺麗なグラフを作成できます。
作図用と計算用の$\boldsymbol{\phi}$の値を作成します。
# 2次元座標の値を作成 y_0_vals = np.linspace(start=0.0, stop=1.0, num=201) y_1_vals = np.linspace(start=0.0, stop=0.5*np.sqrt(3.0), num=201) # 2次元座標の格子点を作成 y_0_grid, y_1_grid = np.meshgrid(y_0_vals, y_1_vals) # 格子点の形状を保存 y_shape = y_0_grid.shape # 3次元座標の値に変換 phi_1_vals = y_0_grid.flatten() - y_1_grid.flatten() / np.sqrt(3.0) phi_2_vals = 2.0 * y_1_grid.flatten() / np.sqrt(3.0) # 範囲外の点を欠損値に置換 phi_1_vals = np.where( (phi_1_vals >= 0.0) & (phi_1_vals <= 1.0), phi_1_vals, np.nan ) phi_2_vals = np.where( (phi_2_vals >= 0.0) & (phi_2_vals <= 1.0), phi_2_vals, np.nan ) # 3次元座標の値に変換 phi_0_vals = 1.0 - phi_1_vals - phi_2_vals # 範囲外の点を欠損値に置換 phi_0_vals = np.where( (phi_0_vals >= 0.0) & (phi_0_vals <= 1.0), phi_0_vals, np.nan ) # 計算用の3次元座標の点を作成 phi_points = np.stack([phi_0_vals, phi_1_vals, phi_2_vals], axis=1)
三角座標を含めた2次元座標上の格子点を作成してy_0_grid, y_1_grid
とします。この2つの配列は作図に使います。
y_0_grid, y_1_grid
を、$\phi_0 = 1 - \phi_1 - \phi_2$、$\phi_1 = y_0 - \frac{y_1}{\sqrt{3}}$、$\phi_2 = \frac{2 y_1}{\sqrt{3}}$で、3次元座標上の点($\boldsymbol{\phi}$の点)に変換してphi_points
とします。ただし、三角座標外の点については、総和が1の値($\boldsymbol{\phi}$を満たす値)にならないので欠損値np.nan
に置き換えます。この配列は計算に使います。
ディリクレ分布を計算します。
# ディリクレ分布の確率密度を計算 dens_vals = np.array( [dirichlet.pdf(x=phi_v, alpha=beta_v) if all(phi_v != np.nan) else np.nan for phi_v in phi_points] ) print(np.round(dens_vals, 2))
[ 0. 0. 0. ... nan nan nan]
先ほどと同様に計算します。ただし、三角座標外の要素(phi_points
の欠損値を含む行)については、リスト内包表記の内部でif
文を使って、欠損値np.nan
を格納します。
また、全てのパラメータが1未満で確率変数に0を含む場合は、0.0
を含む行も除きます。
# ディリクレ分布の確率密度を計算 dens_vals = np.array( [dirichlet.pdf(x=phi_v, alpha=beta_v) if all(phi_v != np.nan) & all(phi_v != 0.0) else np.nan for phi_v in phi_points] )
ディリクレ分布のヒートマップを作成します。
# ディリクレ分布のヒートマップを作成 plt.figure(figsize=(12, 10), facecolor='white') # 図の設定 plt.quiver(grid_x, grid_y, grid_u, grid_v, scale_units='xy', scale=1, units='dots', width=0.1, headwidth=0.1, fc='none', ec='gray', linewidth=1.5, linestyle=':') # 三角座標のグリッド線 plt.quiver(axis_x, axis_y, axis_u, axis_v, scale_units='xy', scale=1, units='dots', width=1.5, headwidth=1.5, fc='black', linestyle='-') # 三角座標の枠線 for val in axis_vals: plt.text(x=0.5*val, y=0.5*val*np.sqrt(3.0), s=str(np.round(1.0-val, 1))+' '*2, ha='right', va='bottom', rotation=-60) # 三角座標のx軸目盛 plt.text(x=val, y=0.0, s=str(np.round(val, 1))+' '*10, ha='center', va='center', rotation=60) # 三角座標のy軸目盛 plt.text(x=0.5*val+0.5, y=0.5*(1.0-val)*np.sqrt(3.0), s=' '*3+str(np.round(1.0-val, 1)), ha='left', va='center') # 三角座標のz軸目盛 plt.text(x=0.25, y=0.25*np.sqrt(3.0), s='$\phi_1$'+' '*5, ha='right', va='center', size=25) # 三角座標のx軸ラベル plt.text(x=0.5, y=0.0, s='\n'+'$\phi_2$', ha='center', va='top', size=25) # 三角座標のy軸ラベル plt.text(x=0.75, y=0.25*np.sqrt(3.0), s=' '*4+'$\phi_3$', ha='left', va='center', size=25) # 三角図のz軸ラベル pcl = plt.pcolor(y_0_grid, y_1_grid, dens_vals.reshape(y_shape), alpha = 0.8) # 確率密度のヒートマップ plt.xticks(ticks=[0.0, 0.5, 1.0], labels='') # 2次元座標のx軸目盛 plt.yticks(ticks=[0.0, 0.25*np.sqrt(3.0), 0.5*np.sqrt(3.0)], labels='') # 2次元座標のy軸目盛 plt.grid() # 2次元座標のグリッド線 plt.axis('equal') # アスペクト比 plt.colorbar(pcl, label='density') # カラーバー plt.suptitle(t='Dirichlet Distribution', fontsize=20) # タイトル plt.title(label='$\\beta=('+', '.join([str(val) for val in beta_v])+')$', loc='left') # パラメータラベル plt.show() # 描画
plt.pcolor()
でヒートマップを描画します。
続いて、等高線図を作成します。
# ディリクレ分布の等高線図を作成 plt.figure(figsize=(12, 10), facecolor='white') # 図の設定 plt.quiver(grid_x, grid_y, grid_u, grid_v, scale_units='xy', scale=1, units='dots', width=0.1, headwidth=0.1, fc='none', ec='gray', linewidth=1.5, linestyle=':') # 三角座標のグリッド線 plt.quiver(axis_x, axis_y, axis_u, axis_v, scale_units='xy', scale=1, units='dots', width=1.5, headwidth=1.5, fc='black', linestyle='-') # 三角座標の枠線 for val in axis_vals: plt.text(x=0.5*val, y=0.5*val*np.sqrt(3.0), s=str(np.round(1.0-val, 1))+' '*2, ha='right', va='bottom', rotation=-60) # 三角座標のx軸目盛 plt.text(x=val, y=0.0, s=str(np.round(val, 1))+' '*10, ha='center', va='center', rotation=60) # 三角座標のy軸目盛 plt.text(x=0.5*val+0.5, y=0.5*(1.0-val)*np.sqrt(3.0), s=' '*3+str(np.round(1.0-val, 1)), ha='left', va='center') # 三角座標のz軸目盛 plt.text(x=0.25, y=0.25*np.sqrt(3.0), s='$\phi_1$'+' '*5, ha='right', va='center', size=25) # 三角座標のx軸ラベル plt.text(x=0.5, y=0.0, s='\n'+'$\phi_2$', ha='center', va='top', size=25) # 三角座標のy軸ラベル plt.text(x=0.75, y=0.25*np.sqrt(3.0), s=' '*4+'$\phi_3$', ha='left', va='center', size=25) # 三角図のz軸ラベル cnf = plt.contourf(y_0_grid, y_1_grid, dens_vals.reshape(y_shape), alpha = 0.8) # 確率密度の等高線 plt.xticks(ticks=[0.0, 0.5, 1.0], labels='') # 2次元座標のx軸目盛 plt.yticks(ticks=[0.0, 0.25*np.sqrt(3.0), 0.5*np.sqrt(3.0)], labels='') # 2次元座標のy軸目盛 plt.grid() # 2次元座標のグリッド線 plt.axis('equal') # アスペクト比 plt.colorbar(cnf, label='density') # カラーバー plt.suptitle(t='Dirichlet Distribution', fontsize=20) # タイトル plt.title(label='$\\beta=('+', '.join([str(val) for val in beta_v])+')$', loc='left') # パラメータラベル plt.show() # 描画
plt.contour()
またはcontourf()
で等高線図を描画します。
曲面図を作成します。
# ディリクレ分布の曲面図を作成 fig = plt.figure(figsize=(12, 10), facecolor='white') # 図の設定 ax = fig.add_subplot(projection='3d') # 3D用の設定 ax.quiver(grid_x, grid_y, np.zeros_like(grid_x), grid_u, grid_v, np.zeros_like(grid_x), arrow_length_ratio=0.0, ec='gray', linewidth=1.5, linestyle=':') # 三角座標のグリッド線 ax.quiver(axis_x, axis_y, np.zeros_like(axis_x), axis_u, axis_v, np.zeros_like(axis_x), arrow_length_ratio=0.0, ec='black', linestyle='-') # 三角座標の枠線 for val in axis_vals: ax.text(x=0.5*val-0.05, y=0.5*val*np.sqrt(3.0), z=0.0, s=str(np.round(1.0-val, 1)), ha='center', va='center') # 三角座標のx軸目盛 ax.text(x=val, y=0.0-0.05, z=0.0, s=str(np.round(val, 1)), ha='center', va='center') # 三角座標のy軸目盛 ax.text(x=0.5*val+0.5+0.05, y=0.5*(1.0-val)*np.sqrt(3.0), z=0.0, s=str(np.round(1.0-val, 1)), ha='center', va='center') # 三角座標のz軸目盛 ax.text(x=0.25-0.1, y=0.25*np.sqrt(3.0), z=0.0, s='$\phi_1$', ha='right', va='center', size=25) # 三角座標のx軸ラベル ax.text(x=0.5, y=0.0-0.1, z=0.0-0.1, s='$\phi_2$', ha='center', va='top', size=25) # 三角座標のy軸ラベル ax.text(x=0.75+0.1, y=0.25*np.sqrt(3.0), z=0.0, s='$\phi_3$', ha='left', va='center', size=25) # 三角図のz軸ラベル ax.contour(y_0_grid, y_1_grid, dens_vals.reshape(y_shape), offset=0.0) # 確率密度の等高線 ax.plot_surface(y_0_grid, y_1_grid, dens_vals.reshape(y_shape), cmap='viridis', alpha=0.8) # 確率密度の曲面 ax.set_xticks(ticks=[0.0, 0.5, 1.0], labels='') # 2次元座標のx軸目盛 ax.set_yticks(ticks=[0.0, 0.25*np.sqrt(3.0), 0.5*np.sqrt(3.0)], labels='') # 2次元座標のy軸目盛 ax.set_zlabel(zlabel='density') # z軸ラベル ax.set_box_aspect(aspect=(1, 1, 1)) # アスペクト比 fig.suptitle(t='Dirichlet Distribution', fontsize=20) # タイトル ax.set_title(label='$\\beta=('+', '.join([str(beta) for beta in beta_v])+')$', loc='left') # パラメータラベル #ax.view_init(elev=90, azim=-90) # 表示角度 plt.show() # 描画
ax.plot_surface()
で曲面図を描画します。
ここまでで、ディリクレ分布のグラフを描画できました。以降は、ここまでの作図処理を用いて、パラメータの影響を確認していきます。
パラメータと分布の形状の関係をアニメーションで可視化
パラメータの値を少しずつ変化させて、分布の形状の変化をアニメーションで確認します。
第1成分の影響
まずは、$\beta_1$の値を変化させ、$\beta_2, \beta_3$を固定します。
# パラメータとして利用する値を指定 beta_1_vals = np.arange(start=1.0, stop=10.1, step=0.1).round(decimals=1) # 固定するパラメータを指定 beta_2 = 2.0 beta_3 = 3.0 # フレーム数を設定 frame_num = len(beta_1_vals) print(frame_num)
91
値の間隔が一定になるように$\beta_1$の値をbeta_1_vals
として作成します。パラメータごとにフレームを切り替えるので、beta_1_vals
の要素数がアニメーションのフレーム数になります。
また$\beta_2, \beta_3$をbeta_2, beta_3
として値を指定します。
・作図コード(クリックで展開)
全てのフレームで共通のグラデーションと等高線を引くための値を設定します。
# z軸の最小値と最大値を設定 dens_min = 0.0 dens_max = 25.0 # 等高線を引く値を設定 dens_levels = np.linspace(dens_min, dens_max, num=11) print(dens_levels)
[ 0. 2.5 5. 7.5 10. 12.5 15. 17.5 20. 22.5 25. ]
ディリクレ分布の等高線図のアニメーションをします。
# 図を初期化 fig = plt.figure(figsize=(12, 10), facecolor='white') # 図の設定 fig.suptitle(t='Dirichlet Distribution', fontsize=20) # タイトル tmp = plt.contourf(y_0_grid, y_1_grid, np.zeros(y_shape), vmin=dens_min, vmax=dens_max, levels=dens_levels, alpha = 0.8) # カラーバー用のダミー fig.colorbar(tmp, label='density') # カラーバー # 作図処理を関数として定義 def update(i): # 前フレームのグラフを初期化 plt.cla() # i番目のパラメータを取得 beta_1 = beta_1_vals[i] #beta_2 = beta_2_vals[i] #beta_3 = beta_3_vals[i] # パラメータを設定 beta_v = np.array([beta_1, beta_2, beta_3]) # ディリクレ分布の確率密度を計算 dens_vals = np.array( [dirichlet.pdf(x=phi_v, alpha=beta_v) if all(phi_v != np.nan) else np.nan for phi_v in phi_points] ) # パラメータラベル用の文字列を作成 param_text = '$\\beta=('+', '.join([str(beta) for beta in beta_v])+')$' # 三角座標上の等高線図を作成 plt.quiver(grid_x, grid_y, grid_u, grid_v, scale_units='xy', scale=1, units='dots', width=0.1, headwidth=0.1, fc='none', ec='gray', linewidth=1.5, linestyle=':') # 三角座標のグリッド線 plt.quiver(axis_x, axis_y, axis_u, axis_v, scale_units='xy', scale=1, units='dots', width=1.5, headwidth=1.5, fc='black', linestyle='-') # 三角座標の枠線 for val in axis_vals: plt.text(x=0.5*val, y=0.5*val*np.sqrt(3.0), s=str(np.round(1.0-val, 1))+' '*2, ha='right', va='bottom', rotation=-60) # 三角座標のx軸目盛 plt.text(x=val, y=0.0, s=str(np.round(val, 1))+' '*10, ha='center', va='center', rotation=60) # 三角座標のy軸目盛 plt.text(x=0.5*val+0.5, y=0.5*(1.0-val)*np.sqrt(3.0), s=' '*3+str(np.round(1.0-val, 1)), ha='left', va='center') # 三角座標のz軸目盛 plt.text(x=0.25, y=0.25*np.sqrt(3.0), s='$\phi_1$'+' '*5, ha='right', va='center', size=25) # 三角座標のx軸ラベル plt.text(x=0.5, y=0.0, s='\n'+'$\phi_2$', ha='center', va='top', size=25) # 三角座標のy軸ラベル plt.text(x=0.75, y=0.25*np.sqrt(3.0), s=' '*4+'$\phi_3$', ha='left', va='center', size=25) # 三角図のz軸ラベル plt.contourf(y_0_grid, y_1_grid, dens_vals.reshape(y_shape), vmin=dens_min, vmax=dens_max, levels=dens_levels, alpha = 0.8) # 確率密度の等高線 plt.xticks(ticks=[0.0, 0.5, 1.0], labels='') # 2次元座標のx軸目盛 plt.yticks(ticks=[0.0, 0.25*np.sqrt(3.0), 0.5*np.sqrt(3.0)], labels='') # 2次元座標のy軸目盛 plt.grid() # 2次元座標のグリッド線 plt.axis('equal') # アスペクト比 plt.title(label=param_text, loc='left') # パラメータラベル # gif画像を作成 ani = FuncAnimation(fig=fig, func=update, frames=frame_num, interval=100) # gif画像を保存 ani.save('../../figure/Python/dirichlet_cnf.gif')
各フレームの確率密度の計算と作図処理を関数として定義して、FuncAnimation()
を使ってアニメーション(gif画像)を作成する。
同様に、曲面図のアニメーションを作成します。
# ディリクレ分布の曲面図のアニメーションを作成 fig = plt.figure(figsize=(12, 10), facecolor='white') # 図の設定 ax = fig.add_subplot(projection='3d') # 3D用の設定 fig.suptitle(t='Dirichlet Distribution', fontsize=20) # タイトル # 作図処理を関数として定義 def update(i): # 前フレームのグラフを初期化 plt.cla() # i番目のパラメータを取得 beta_1 = beta_1_vals[i] #beta_2 = beta_2_vals[i] #beta_3 = beta_3_vals[i] # パラメータを設定 beta_v = np.array([beta_1, beta_2, beta_3]) # ディリクレ分布の確率密度を計算 dens_vals = np.array( [dirichlet.pdf(x=phi_v, alpha=beta_v) if all(phi_v != np.nan) else np.nan for phi_v in phi_points] ) # パラメータラベル用の文字列を作成 param_text = '$\\beta=('+', '.join([str(beta) for beta in beta_v])+')$' # 三角座標上の曲面図を作成 ax.quiver(grid_x, grid_y, np.zeros_like(grid_x), grid_u, grid_v, np.zeros_like(grid_x), arrow_length_ratio=0.0, ec='gray', linewidth=1.5, linestyle=':') # 三角座標のグリッド線 ax.quiver(axis_x, axis_y, np.zeros_like(axis_x), axis_u, axis_v, np.zeros_like(axis_x), arrow_length_ratio=0.0, ec='black', linestyle='-') # 三角座標の枠線 for val in axis_vals: ax.text(x=0.5*val-0.05, y=0.5*val*np.sqrt(3.0), z=0.0, s=str(np.round(1.0-val, 1)), ha='center', va='center') # 三角座標のx軸目盛 ax.text(x=val, y=0.0-0.05, z=0.0, s=str(np.round(val, 1)), ha='center', va='center') # 三角座標のy軸目盛 ax.text(x=0.5*val+0.5+0.05, y=0.5*(1.0-val)*np.sqrt(3.0), z=0.0, s=str(np.round(1.0-val, 1)), ha='center', va='center') # 三角座標のz軸目盛 ax.text(x=0.25-0.1, y=0.25*np.sqrt(3.0), z=0.0, s='$\phi_1$', ha='right', va='center', size=25) # 三角座標のx軸ラベル ax.text(x=0.5, y=0.0-0.1, z=0.0-0.1, s='$\phi_2$', ha='center', va='top', size=25) # 三角座標のy軸ラベル ax.text(x=0.75+0.1, y=0.25*np.sqrt(3.0), z=0.0, s='$\phi_3$', ha='left', va='center', size=25) # 三角図のz軸ラベル ax.contour(y_0_grid, y_1_grid, dens_vals.reshape(y_shape), vmin=dens_min, vmax=dens_max, levels=dens_levels, offset=0.0) # 確率密度の等高線 ax.plot_surface(y_0_grid, y_1_grid, dens_vals.reshape(y_shape), cmap='viridis', alpha=0.8) # 確率密度の曲面 ax.set_xticks(ticks=[0.0, 0.5, 1.0], labels='') # 2次元座標のx軸目盛 ax.set_yticks(ticks=[0.0, 0.25*np.sqrt(3.0), 0.5*np.sqrt(3.0)], labels='') # 2次元座標のy軸目盛 ax.set_zlabel(zlabel='density') # z軸ラベル ax.set_zlim(bottom=dens_min, top=dens_max) # z軸の表示範囲 ax.set_box_aspect(aspect=(1, 1, 1)) # アスペクト比 ax.set_title(label=param_text, loc='left') # パラメータラベル #ax.view_init(elev=90, azim=-90) # 表示角度 # gif画像を作成 ani = FuncAnimation(fig=fig, func=update, frames=frame_num, interval=100) # gif画像を保存 ani.save('../../figure/Python/dirichlet_srf.gif')
第2成分の影響
続いて、$\beta_2$の値を変化させ、$\beta_1, \beta_3$を固定します。
# パラメータとして利用する値を指定 beta_2_vals = np.arange(start=1.0, stop=10.1, step=0.1).round(decimals=1) # 固定するパラメータを指定 beta_1 = 4.0 beta_3 = 3.0 # フレーム数を設定 frame_num = len(beta_2_vals) print(frame_num)
91
値の間隔が一定になるように$\beta_2$の値をbeta_2_vals
として作成し、$\beta_1, \beta_3$をbeta_1, beta_3
として値を指定します。
「第1成分の影響」のコードで作図できます。
第3成分の影響
$\beta_3$の値を変化させ、$\beta_1, \beta_2$を固定します。
# パラメータとして利用する値を指定 beta_3_vals = np.arange(start=1.0, stop=10.1, step=0.1).round(decimals=1) # 固定するパラメータを指定 beta_1 = 4.0 beta_2 = 2.0 # フレーム数を設定 frame_num = len(beta_3_vals) print(frame_num)
91
同様に指定します。
「第1成分の影響」のコードで作図できます。
3つの成分の影響
最後に、$\boldsymbol{\beta}$の値を変化させます。
# パラメータとして利用する値を指定 beta_1_vals = np.arange(start=1.0, stop=10.1, step=0.1).round(decimals=1) beta_2_vals = np.arange(start=1.0, stop=10.1, step=0.1).round(decimals=1) beta_3_vals = np.arange(start=1.0, stop=10.1, step=0.1).round(decimals=1) # フレーム数を設定 frame_num = len(beta_1_vals) print(frame_num)
91
3つのbeta_*_vals
の要素数が同じになるように値を指定します。
全てのフレームで共通のグラデーションと等高線を引くための値を設定します。
# z軸の最小値と最大値を設定 dens_min = 0.0 beta_max_v = np.array([beta_1_vals.max(), beta_2_vals.max(), beta_3_vals.max()]) dens_max = np.ceil( dirichlet.pdf(x=(beta_max_v-1.0)/(np.sum(beta_max_v)-3.0), alpha=beta_max_v) ) # 等高線を引く値を設定 dens_levels = np.linspace(dens_min, dens_max, num=11) print(dens_levels)
[ 0. 2.5 5. 7.5 10. 12.5 15. 17.5 20. 22.5 25. ]
今回は最大値を直接指定せず、上手いこと最頻値の確率密度を求めます。
「第1成分の影響」のコードで作図できます。
この記事では、ディリクレ分布のグラフを作成しました。
参考文献
- 岩田具治『トピックモデル』(機械学習プロフェッショナルシリーズ)講談社,2015年.
おわりに
$\phi_0$なのか$\phi_1$なのか三角グラフの記事との兼ね合いで色々アレなんですが頑張って読んでください。
Pythonでも乱数生成とかやりたいのですが他にもやりたいことががが。
【次の内容】
つづくはず