からっぽのしょこ

読んだら書く!書いたら読む!同じ事は二度調べ(たく)ない

【R】3.2.2:カテゴリ分布の学習と予測【緑ベイズ入門のノート】

はじめに

 『ベイズ推論による機械学習入門』の学習時のノートです。基本的な内容は「数式の行間を読んでみた」とそれを「RとPythonで組んでみた」になります。「数式」と「プログラム」から理解するのが目標です。

 この記事は、3.2.2項の内容です。尤度関数をカテゴリ分布、事前分布をディリクレ分布とした場合のパラメータの事後分布と未観測値の予測分布の計算をR言語で実装します。

 省略してある内容等ありますので、本とあわせて読んでください。初学者な自分が理解できるレベルまで落として書き下していますので、分かる人にはかなりくどくなっています。同じような立場の人のお役に立てれば幸いです。

【数式読解編】

www.anarchive-beta.com

【他の節の内容】

www.anarchive-beta.com

【この節の内容】

・Rでやってみよう

 人工的に生成したデータを用いて、ベイズ推論を行ってみましょう。

 利用するパッケージを読み込みます。

# 3.2.2項で利用するパッケージ
library(tidyverse)


・モデルの構築

 まずは、モデルの設定を行います。

 尤度(カテゴリ分布)$p(\mathbf{S} | \boldsymbol{\pi})$のパラメータ$\boldsymbol{\pi} = (\pi_1, \pi_2, \cdots, \pi_K)$を設定します。

# 次元数:(固定)
K <- 3

# 真のパラメータを指定
pi_truth_k <- c(0.3, 0.5, 0.2)

 次元数$K$をKとします。この例では三角図で可視化するため、一部のプログラムは$K = 3$の場合だけ動作します。超パラメータの推定自体は、3以外でも動作します。

 $\pi_k$は、各データ$\mathbf{s}_n = (s_{n,1}, s_{n,2}, \cdots, s_{n,K})$において$s_{n,k} = 1$となる確率です。$\boldsymbol{\pi}$をpi_truth_kとして、$0 \leq \pi_k \leq 1$、$\sum_{k=1}^K \pi_k = 1$の値を指定します。これが真のパラメータであり、この値を求めるのがここでの目的です。

 作図用に、尤度の次元番号と対応する確率を持つデータフレームを作成します。

# 尤度(カテゴリ分布)のデータフレームを作成
model_df <- tibble(
  k = 1:K, # 次元番号
  prob = pi_truth_k # 確率
)

 ここでは簡単に、次元番号とパラメータをデータフレームに格納します。
 次元番号に対応する確率は、カテゴリ分布の定義式

$$ \mathrm{Cat}(\mathbf{s}_n | \boldsymbol{\pi}) = \prod_{k=1}^K \pi_k^{s_{n,k}} \tag{2.29} $$

や、多項分布の確率計算関数dmultinom()を使っても計算できます。

# 全てのパターンのデータを作成
s_kk <- diag(K)

# 確率を計算:式(2.29)
apply(pi_truth_k^t(s_kk), 2, prod)
## [1] 0.3 0.5 0.2

 この計算には転置t()は不要ですが、s_nkを使って計算するときには必要です。dmultinom()は、1データs_kk[k, ]ずつ処理する必要があるので(私の実装力上の問題で)省略します。
 ただし、$\mathbf{s}_n$における$s_{n,k} = 1$以外の$s_{n,1}, \cdots, s_{n,k-1}, s_{n,k+1}, \cdots, s_{n,K}$は0であり、$x^0 = 1$なので、計算結果はpi_truth_kになります。よって、pi_truth_kをそのまま格納します。

 作成したデータフレームを確認しましょう。

# 確認
head(model_df)
## # A tibble: 3 x 2
##       k  prob
##   <int> <dbl>
## 1     1   0.3
## 2     2   0.5
## 3     3   0.2

 ggplot2パッケージを利用して作図するには、データフレームを渡す必要があります。

 尤度を作図します。

# 尤度を作図
ggplot(model_df, aes(x = k, y = prob)) + 
  geom_bar(stat = "identity", position = "dodge", fill = "purple") + # 観測モデル
  ylim(c(0, 1)) + # y軸の表示範囲
  labs(title = "Catgorical Distribution", 
       subtitle = paste0("pi=(", paste0(pi_truth_k, collapse = ", "), ")"))

f:id:anemptyarchive:20210222111724p:plain
尤度:カテゴリ分布

 真のパラメータを求めることは、この真の分布を求めることを意味します。

・データの生成

 続いて、構築したモデルに従って観測データ$\mathbf{S} = \{\mathbf{s}_1, \mathbf{s}_2, \cdots, \mathbf{s}_N\}$を生成します。

 カテゴリ分布に従う$N$個のデータをランダムに生成します。

# データ数を指定
N <- 50

# (観測)データを生成
s_nk <- rmultinom(n = N, size = 1, prob = pi_truth_k) %>% 
  t()

 生成するデータ数$N$をNとして、値を指定します。

 カテゴリ分布に従う乱数は、多項分布に従う乱数生成関数rmultinom()size引数を1にすることで生成できます。また、試行回数の引数nN、確率の引数probpi_truth_kを指定します。生成したN個のデータをs_nkとします。
 試行ごとの結果を列としたマトリクスが返ってくるので、t()で転置します。

 観測したデータ$\mathbf{S}$を確認しましょう。

# 観測データを確認
s_nk[1:5, ]
colSums(s_nk)
##      [,1] [,2] [,3]
## [1,]    0    1    0
## [2,]    1    0    0
## [3,]    0    1    0
## [4,]    1    0    0
## [5,]    0    1    0

## [1] 17 24  9

 各データ$\mathbf{s}_n = (s_{n,1}, s_{n,1}, \cdots, s_{n,K})$は、1つの項を1、それ以外を0とする$K$次元ベクトルです。
 $k$番目の次元において1となったデータ数は、$\sum_{n=1}^N s_{n,k}$で得られます。

 $\mathbf{S}$をヒストグラムでも確認します。

# 観測データのヒストグラムを作図
tibble(k = 1:K, count = colSums(s_nk)) %>% 
  ggplot(aes(x = k, y = count)) + 
    geom_bar(stat = "identity", position = "dodge") + # (簡易)ヒストグラム
    labs(title = "Observation Data", 
         subtitle = paste0("N=", N, ", pi=(", paste0(pi_truth_k, collapse = ", "), ")"))

f:id:anemptyarchive:20210222111751p:plain
観測データのヒストグラム:カテゴリ分布

 データ数が十分に大きいと、分布の形状が真の分布に近づきます。

・事前分布の設定

 尤度に対する共役事前分布を設定します。

 事前分布(ディリクレ分布)$p(\boldsymbol{\pi} | \boldsymbol{\alpha})$のパラメータ(超パラメータ)を設定します。

# 事前分布のパラメータ
alpha_k <- c(1, 1, 1)

 ディリクレ分布のパラメータ$\boldsymbol{\alpha} = (\alpha_1, \alpha_2, \cdots, \alpha_K)$をalpha_kとして、$\alpha_k > 0$の値を指定します。

 事前分布の作図用に、$\boldsymbol{\pi}$がとり得る値を用意します。

# 作図用のpiの値を満遍なく生成
point_vec <- seq(0, 1, by = 0.025) # piがとり得る値
n_point <- length(point_vec) # 点の数
pi_point_df <- tibble(
  pi_1 = rep(rep(point_vec, times = n_point), times = n_point), 
  pi_2 = rep(rep(point_vec, each = n_point), times = n_point), 
  pi_3 = rep(point_vec, each = n_point^2)
)
pi_point_df <- pi_point_df / rowSums(pi_point_df) # 正規化

# 点を間引く(ハイスぺ機なら不要…)
pi_point <- pi_point_df[-1, ] %>%  # (0, 0, 0)の行を除去
  round(3) %>% # 値を丸め込み
  dplyr::distinct(pi_1, pi_2, pi_3) %>% # 重複を除去
  as.matrix() # マトリクスに変換

 seq(0, 1)で、$\pi_k$がとり得る0から1までの値を用意してpoint_vecとします。by引数で間隔を指定できるので、グラフが粗かったり処理が重かったりする場合はこの値を調整してください。

 事前分布(と事後分布)を三角図で描画するために、パラメータは3次元$\boldsymbol{\pi} = (\pi_1, \pi_2, \pi_3)$に限ります。point_vecの要素を使って3つの次元で全ての組み合わせ持つようにデータフレームを作成します。

 最初の行が$\boldsymbol{\pi} = (0, 0, 0)$になってしまうので取り除き、値を丸めます。最後に、正規化したことで重複した行を取り除き、マトリクスに変換します。

# 確認
head(pi_point)
##       pi_1  pi_2 pi_3
## [1,] 1.000 0.000    0
## [2,] 0.000 1.000    0
## [3,] 0.500 0.500    0
## [4,] 0.667 0.333    0
## [5,] 0.750 0.250    0
## [6,] 0.800 0.200    0

 簡単な例にすると次のような処理です。

tibble(
  v1 = rep(rep(seq(1, 2), times = 2), times = 2), 
  v2 = rep(rep(seq(1, 2), each = 2), times = 2), 
  v3 = rep(seq(1, 2), each = 4)
)
## # A tibble: 8 x 3
##      v1    v2    v3
##   <int> <int> <int>
## 1     1     1     1
## 2     2     1     1
## 3     1     2     1
## 4     2     2     1
## 5     1     1     2
## 6     2     1     2
## 7     1     2     2
## 8     2     2     2

 面倒でしたら次のようにランダムに生成してしまいます。

# 作図用のpiの値をランダムに生成
pi_point <- seq(0, 1, 0.001) %>% # piがとり得る値
  sample(size = 10000 * K, replace = TRUE) %>% # ランダムに値を生成
  matrix(ncol = K) # マトリクスに変換
pi_point <- pi_point / rowSums(pi_point) # 正規化

 Kの倍数個の値をランダムに生成して、K列のマトリクスを作成します。点の数が少なすぎるとグラフが疎らになり、多いすぎると処理が重くなります。
 こちらの方法でも正規化する必要があります。

 事前分布の確率密度を計算します。

# 事後分布(ディリクレ分布)を計算
prior_df <- tibble(
  x = pi_point[, 2] + (pi_point[, 3] / 2), # 三角座標への変換
  y = sqrt(3) * (pi_point[, 3] / 2), # 三角座標への変換
  ln_C_dir = lgamma(sum(alpha_k)) - sum(lgamma(alpha_k)), # 正規化項(対数)
  density = exp(ln_C_dir) * apply(t(pi_point)^(alpha_k - 1), 2, prod) # 確率密度
)

 3次元の値を2次元の図に落とし込むために、三角図に落とし込みます。

 pi_pointの各行に対して、確率密度を計算します。ディリクレ分布の確率密度は、定義式

$$ \mathrm{Dir}(\boldsymbol{\pi} | \boldsymbol{\alpha}) = \frac{\Gamma(\sum_{k=1}^K \alpha_k)}{\prod_{k=1}^K \Gamma(\alpha_k)} \pi_k^{\alpha-1} \tag{2.48} $$

で計算します。ここで、$\Gamma(\cdot)$はガンマ関数です。
 ガンマ関数の計算はgamma()で行えますが、値が大きくなると発散してしまします。そこで、対数をとったガンマ関数lgamma()で計算した後にexp()で戻します。
 ディリクレ分布の確率密度は、MCMCpack::dderichlet()でも計算できます(が、第1引数に0を含む場合は計算できないようなので省略しましした)。

 計算結果は次のようになります。

# 確認
head(prior_df)
## # A tibble: 6 x 4
##       x     y ln_C_dir density
##   <dbl> <dbl>    <dbl>   <dbl>
## 1 0         0    0.693       2
## 2 1         0    0.693       2
## 3 0.5       0    0.693       2
## 4 0.333     0    0.693       2
## 5 0.25      0    0.693       2
## 6 0.2       0    0.693       2


 事前分布を作図します。

# 事前分布を作図
ggplot() + 
  geom_point(data = prior_df, aes(x = x, y = y, color = density)) + # 事前分布
  scale_color_gradientn(colors = c("blue", "green", "yellow", "red")) + # 散布図のグラデーション
  scale_x_continuous(breaks = c(0, 1), 
                     labels = c("(1, 0, 0)", "(0, 1, 0)")) + # x軸目盛
  scale_y_continuous(breaks = c(0, 0.87), 
                     labels = c("(1, 0, 0)", "(0, 0, 1)")) + # y軸目盛
  coord_fixed(ratio = 1) + # 縦横比
  labs(title = "Dirichlet Distribution", 
       subtitle = paste0("alpha=(", paste(alpha_k, collapse = ", "), ")"), 
       x = expression(paste(pi[1], ", ", pi[2], sep = "")), 
       y = expression(paste(pi[1], ", ", pi[3], sep = "")))

f:id:anemptyarchive:20210222111835p:plain
事前分布:ディリクレ分布

 alpha_kの値を変更することで、ディリクレ分布におけるパラメータと形状の関係を確認できます。

・事後分布の計算

 観測データ$\mathbf{S}$からパラメータ$\boldsymbol{\pi}$の事後分布を求めます(パラメータ$\boldsymbol{\pi}$を分布推定します)。

 観測データs_nkを用いて、事後分布(ディリクレ分布)のパラメータを計算します。

# 事後分布のパラメータを計算:式(3.28)
alpha_hat_k <- colSums(s_nk) + alpha_k

 事後分布のパラメータは

$$ \hat{\alpha}_k = \sum_{n=1}^N s_{n,k} + \alpha_k \tag{3.28} $$

で計算して、結果をalpha_hat_kとします。

# 確認
alpha_hat_k
## [1] 18 25 10

 事前分布のパラメータ$\alpha_1, \cdots, \alpha_K$に、それぞれs_nkの次元ごとに1となったデータ数を加えています。

 事後分布の確率密度を計算します。

# 事後分布(ディリクレ分布)を計算
posterior_df <- tibble(
  x = pi_point[, 2] + (pi_point[, 3] / 2), # 三角座標への変換
  y = sqrt(3) * (pi_point[, 3] / 2), # 三角座標への変換
  ln_C_dir = lgamma(sum(alpha_hat_k)) - sum(lgamma(alpha_hat_k)), # 正規化項(対数)
  density = exp(ln_C_dir) * apply(t(pi_point)^(alpha_hat_k - 1), 2, prod) # 確率密度
)

 更新した超パラメータalpha_hat_kを用いて、事前分布のときと同様にして計算します。

 計算結果は次のようになります。

# 確認
head(posterior_df)
## # A tibble: 6 x 4
##       x     y ln_C_dir density
##   <dbl> <dbl>    <dbl>   <dbl>
## 1 0         0     55.3       0
## 2 1         0     55.3       0
## 3 0.5       0     55.3       0
## 4 0.333     0     55.3       0
## 5 0.25      0     55.3       0
## 6 0.2       0     55.3       0


 真のパラメータの位置を三角図に表示するために、変換してデータフレームに格納しておきます。

# 真のパラメータのデータフレームを作成
parameter_df <- tibble(
  x = pi_truth_k[2] + (pi_truth_k[3] / 2), # 三角座標への変換
  y = sqrt(3) * (pi_truth_k[3] / 2), # 三角座標への変換
)


 事後分布を作図します。

# 事後分布を作図
ggplot() + 
  geom_point(data = posterior_df, aes(x = x, y = y, color = density)) + # 事後分布
  geom_point(data = parameter_df, aes(x = x, y = y), shape = 4, size = 5) + # 真のパラメータ
  scale_color_gradientn(colors = c("blue", "green", "yellow", "red")) + # 散布図のグラデーション
  scale_x_continuous(breaks = c(0, 1), 
                     labels = c("(1, 0, 0)", "(0, 1, 0)")) + # x軸目盛
  scale_y_continuous(breaks = c(0, 0.87), 
                     labels = c("(1, 0, 0)", "(0, 0, 1)")) + # y軸目盛
  coord_fixed(ratio = 1) + # 縦横比
  labs(title = "Dirichlet Distribution", 
       subtitle = paste0("N=", N, ", alpha_hat=(", paste(alpha_hat_k, collapse = ", "), ")"), 
       x = expression(paste(pi[1], ", ", pi[2], sep = "")), 
       y = expression(paste(pi[1], ", ", pi[3], sep = "")))

f:id:anemptyarchive:20210222111903p:plain
事後分布:ディリクレ分布

 パラメータ$\boldsymbol{\pi}$の真の値付近をピークとする分布を推定できています。

・予測分布の計算

 最後に、$\mathbf{S}$から未観測のデータ$\mathbf{s}_{*}$の予測分布を求めます。

 事後分布のパラメータalpha_hat_k、または観測データs_nkと事前分布のパラメータalpha_kを用いて予測分布(カテゴリ分布)のパラメータを計算します。

# 予測分布のパラメータを計算
pi_hat_star_k <- alpha_hat_k / sum(alpha_hat_k)
pi_hat_star_k <- (colSums(s_nk) + alpha_k) / sum(colSums(s_nk) + alpha_k)

 予測分布のパラメータの計算式

$$ \begin{aligned} \hat{\pi}_{*,k} &= \frac{ \hat{\alpha}_k }{ \sum_{k'=1}^K \hat{\alpha}_{k'} } \\ &= \frac{ \sum_{n=1}^N s_{n,k} + \alpha_k }{ \sum_{k'=1}^K \sum_{n'=1}^N s_{n',k'} + \alpha_{k'} } \end{aligned} $$

の結果をpi_hat_star_kとします。
 上の式だと、事後分布のパラメータalpha_hat_kを使って計算できます。下の式だと、観測データs_nkと事前分布のパラメータalpha_kを使って計算できます。

# 確認
pi_hat_star_k
## [1] 0.3396226 0.4716981 0.1886792

 $\hat{\pi}_{*}$は、$s_{*,k} = 1$となる確率を表し、$\mathbf{S}$から学習しているのが式からも分かります。

 予測分布を計算します。

# 予測分布(カテゴリ分布)のデータフレームを作成
predict_df <- tibble(
  k = 1:K, # 次元番号
  prob = pi_hat_star_k # 確率
)

 尤度のときと同様に処理できます。

 作成したデータフレームは次のようになります。

# 確認
head(predict_df)
## # A tibble: 3 x 2
##       k  prob
##   <int> <dbl>
## 1     1 0.340
## 2     2 0.472
## 3     3 0.189


 予測分布を尤度と重ねて作図します。

# 予測分布を作図
ggplot() + 
  geom_bar(data = predict_df, aes(x = k, y = prob), stat = "identity", position = "dodge", 
           fill = "purple") + # 予測分布
  geom_bar(data = model_df, aes(x = k, y = prob), stat = "identity", position = "dodge", 
           alpha = 0, color = "red", linetype = "dashed") + # 真の分布
  ylim(c(0, 1)) + # y軸の表示範囲
  labs(title = "Categorical Distribution", 
       subtitle = paste0("N=", N, ", pi_hat=(", paste(round(pi_hat_star_k, 2), collapse = ", "), ")"))

f:id:anemptyarchive:20210222111926p:plain
予測分布:カテゴリ分布

 観測データが増えると、予測分布が真の分布に近づきます。(このグラフylim()忘れた、、)

・おまけ:推移の確認

 gganimateパッケージを利用して、パラメータの推定値の推移のアニメーション(gif画像)を作成するためのコードです。

・コード(クリックで展開)

# 利用するパッケージ
library(tidyverse)
library(gganimate)


 異なる点のみを簡単に解説します。

### モデルの設定

# 次元数:(固定)
K <- 3

# 真のパラメータを指定
pi_truth_k <- c(0.3, 0.5, 0.2)

# 事前分布のパラメータを指定
alpha_k <- c(1, 1, 1)

# 作図用のpiの値を満遍なく生成
point_vec <- seq(0, 1, by = 0.025) # piがとり得る値
n_point <- length(point_vec) # 点の数
pi_point <- tibble(
  pi_1 = rep(rep(point_vec, times = n_point), times = n_point), 
  pi_2 = rep(rep(point_vec, each = n_point), times = n_point), 
  pi_3 = rep(point_vec, each = n_point^2)
)
pi_point <- pi_point / rowSums(pi_point) # 正規化

# 点を間引く(ハイスぺ機なら不要…)
pi_point <- pi_point[-1, ] %>%  # (0, 0, 0)の行を除去
  round(3) %>% # 値を丸め込み
  dplyr::as_tibble() %>% # データフレームに変換
  dplyr::distinct(pi_1, pi_2, pi_3) %>% # 重複を除去
  as.matrix() # マトリクスに再変換


# 事前分布(ディリクレ分布)を計算
posterior_df <- tibble(
  x = pi_point[, 2] + (pi_point[, 3] / 2), # 三角座標への変換
  y = sqrt(3) * (pi_point[, 3] / 2), # 三角座標への変換
  ln_C_dir = lgamma(sum(alpha_k)) - sum(lgamma(alpha_k)), # 正規化項(対数)
  density = exp(ln_C_dir) * apply(t(pi_point)^(alpha_k - 1), 2, prod), # 確率密度
  label = as.factor(paste0("N=", 0, ", alpha=(", paste0(alpha_k, collapse = ", "), ")")) # 試行回数とパラメータのラベル
)

# 初期値による予測分布のパラメーターを計算
pi_star_k <- alpha_k / sum(alpha_k)

# 初期値による予測分布のデータフレームを作成
predict_df <- tibble(
  k = 1:K, # 次元番号
  prob = pi_star_k, # 確率
  label = as.factor(paste0("N=", 0, ", pi_star=(", paste0(round(pi_star_k, 2), collapse = ", "), ")")) # 試行回数とパラメータのラベル
)

 試行ごとの結果を同じデータフレームに格納していく必要があります。事後分布をposterior_df、予測分布をpredict_dfとして、初期値の結果を持つように作成しておきます。

### 推論

# データ数(試行回数)を指定
N <- 100

# パラメータを推定
s_nk <- matrix(0, nrow = N, ncol = 3) # 受け皿を初期化
for(n in 1:N){
  
  # カテゴリ分布に従うデータを生成
  s_nk[n, ] <- rmultinom(n = 1, size = 1, prob = pi_truth_k) %>% 
    as.vector()
  
  # 事後分布のパラメータを更新
  alpha_k <- s_nk[n, ] + alpha_k
  
  # 事後分布を計算
  tmp_posterior_df <- tibble(
    x = pi_point[, 2] + (pi_point[, 3] / 2), # 三角座標への変換
    y = sqrt(3) * (pi_point[, 3] / 2), # 三角座標への変換
    ln_C_dir = lgamma(sum(alpha_k)) - sum(lgamma(alpha_k)), # 正規化項(対数)
    density = exp(ln_C_dir) * apply(t(pi_point)^(alpha_k - 1), 2, prod), # 確率密度
    label = as.factor(paste0("N=", n, ", alpha_hat=(", paste0(alpha_k, collapse = ", "), ")")) # 試行回数とパラメータのラベル
  )
  
  # 予測分布のパラメーターを更新
  pi_star_k <- alpha_k / sum(alpha_k)
  
  # 予測分布のデータフレームを作成
  tmp_predict_df <- tibble(
    k = 1:K, # 作図用の値
    prob = pi_star_k, # 確率
    label = as.factor(paste0(
      "N=", n, ", pi_hat_star=(", paste0(round(pi_star_k, 2), collapse = ", "), ")"
    )) # 試行回数とパラメータのラベル
  )
  
  # 結果を結合
  posterior_df <- rbind(posterior_df, tmp_posterior_df)
  predict_df <- rbind(predict_df, tmp_predict_df)
}

 観測された各データによってどのように学習する(分布が変化する)のかを確認するため、forループで1データずつ処理します。よって、データ数Nがイタレーション数になります。

 パラメータの推定値に関して、$\hat{\alpha}$に対応するalpha_hat_kを新たに作るのではなく、alpha_kをイタレーションごとに更新(上書き)していきます。
 それに伴い、事後分布のパラメータの計算式(3.28)の$\sum_{n=1}^N$の計算は、forループによってN回繰り返しs_nk[n, ]を加えることで行います。n回目のループ処理のときには、n-1回分のs_nk[n, ]が既にalpha_kに加えられているわけです。

 結果は次のようになります。

# 確認
head(posterior_df)
head(predict_df)
## # A tibble: 6 x 5
##       x     y ln_C_dir density label               
##   <dbl> <dbl>    <dbl>   <dbl> <fct>               
## 1 0         0    0.693       2 N=0, alpha=(1, 1, 1)
## 2 1         0    0.693       2 N=0, alpha=(1, 1, 1)
## 3 0.5       0    0.693       2 N=0, alpha=(1, 1, 1)
## 4 0.333     0    0.693       2 N=0, alpha=(1, 1, 1)
## 5 0.25      0    0.693       2 N=0, alpha=(1, 1, 1)
## 6 0.2       0    0.693       2 N=0, alpha=(1, 1, 1)

## # A tibble: 6 x 3
##       k  prob label                             
##   <int> <dbl> <fct>                             
## 1     1 0.333 N=0, pi_star=(0.33, 0.33, 0.33)   
## 2     2 0.333 N=0, pi_star=(0.33, 0.33, 0.33)   
## 3     3 0.333 N=0, pi_star=(0.33, 0.33, 0.33)   
## 4     1 0.5   N=1, pi_hat_star=(0.5, 0.25, 0.25)
## 5     2 0.25  N=1, pi_hat_star=(0.5, 0.25, 0.25)
## 6     3 0.25  N=1, pi_hat_star=(0.5, 0.25, 0.25)


・事後分布の推移

### 作図

# 真のパラメータのデータフレームを作成
parameter_df <- tibble(
  x = pi_truth_k[2] + (pi_truth_k[3] / 2), # 三角座標への変換
  y = sqrt(3) * (pi_truth_k[3] / 2), # 三角座標への変換
)

# 事後分布を作図
posterior_graph <- ggplot() + 
  geom_point(data = posterior_df, aes(x = x, y = y, color = density)) + # 事後分布
  geom_point(data = parameter_df, aes(x = x, y = y), shape = 4, size = 5) + # 真のパラメータ
  scale_color_gradientn(colors = c("blue", "green", "yellow", "red")) + # 散布図のグラデーション
  scale_x_continuous(breaks = c(0, 1), 
                     labels = c("(1, 0, 0)", "(0, 1, 0)")) + # x軸目盛
  scale_y_continuous(breaks = c(0, 0.87), 
                     labels = c("(1, 0, 0)", "(0, 0, 1)")) + # y軸目盛
  coord_fixed(ratio = 1) + # 縦横比
  gganimate::transition_manual(label) + # フレーム
  labs(title = "Dirichlet Distribution", 
       subtitle = "{current_frame}", 
       x = expression(paste(pi[1], ", ", pi[2], sep = "")), 
       y = expression(paste(pi[1], ", ", pi[3], sep = "")))

# gif画像を出力
gganimate::animate(posterior_graph, nframes = (N + 1), fps = 10)

 各フレームの順番を示す列をgganimate::transition_manual()に指定します。初期値(事前分布)を含むため、フレーム数はN + 1です。

・予測分布の推移

# Nフレーム分の真のモデルを格納したデータフレームを作成
label_vec <- unique(predict_df[["label"]]) # 各試行のラベルを抽出
model_df <- tibble()
for(n in 1:(N + 1)) {
  # n番目のフレーム用に作成
  tmp_df <- tibble(
    k = 1:K, 
    prob = pi_truth_k, 
    label = label_vec[n]
  )
  
  # 結果を結合
  model_df <- rbind(model_df, tmp_df)
}

# 予測分布を作図
predict_graph <- ggplot() + 
  geom_bar(data = predict_df, aes(x = k, y = prob), stat = "identity", position = "dodge", 
           fill = "purple") + # 予測分布
  geom_bar(data = model_df, aes(x = k, y = prob), stat = "identity", position = "dodge", 
           alpha = 0, color = "red", linetype = "dashed") + # 真の分布
  gganimate::transition_manual(label) + # フレーム
  ylim(c(0, 1)) + # y軸の表示範囲
  labs(title = "Categorical Distribution", 
       subtitle = "{current_frame}")

# gif画像を出力
gganimate::animate(predict_graph, nframes = (N + 1), fps = 10)

 真の分布についても予測分布と同じフレーム数分用意する必要があります(たぶん?)。そのため、ラベルを対応させて値を複製しています。


f:id:anemptyarchive:20210222111956g:plain
事後分布の推移:ディリクレ分布

f:id:anemptyarchive:20210222112035g:plain
予測分布の推移:カテゴリ分布


参考文献

  • 須山敦志『ベイズ推論による機械学習入門』(機械学習スタートアップシリーズ)杉山将監修,講談社,2017年.

おわりに

 加筆修正の際に記事を分割しました。

 {gganimate}の詳しい使い方はいつかまた別の機会に…

【次節の内容】

www.anarchive-beta.com