からっぽのしょこ

読んだら書く!書いたら読む!同じ事は二度調べ(たく)ない

【R】3.2.2:カテゴリ分布の学習と予測の実装【緑ベイズ入門のノート】

はじめに

 『ベイズ推論による機械学習入門』の学習時のノートです。「数式の行間を読んでみた」とそれを「RとPythonで組んでみた」によって、「数式」と「プログラム」から理解するのが目標です。
 省略してある内容等ありますので、本とあわせて読んでください。

 この記事では、「尤度関数をカテゴリ分布」、「事前分布をディリクレ分布」とした場合の「パラメータの事後分布」と「未観測値の予測分布」の計算をR言語で実装します。

【数式読解編】

www.anarchive-beta.com

【前の節の内容】

www.anarchive-beta.com

【他の節の内容】

www.anarchive-beta.com

【この節の内容】

3.2.2 カテゴリ分布の学習と予測

 尤度関数をカテゴリ分布(Categorical Distribution)、事前分布をディリクレ分布(Dirichlet Distribution)とするモデルに対するベイズ推論を実装します。人工的に生成したデータを用いて、カテゴリ分布のパラメータを推定し、また未観測データに対する予測分布を求めます。
 カテゴリ分布については「カテゴリ分布の定義式 - からっぽのしょこ」、ディリクレ分布については「ディリクレ分布の定義式 - からっぽのしょこ」を参照してください。

 利用するパッケージを読み込みます。

# 利用パッケージ
library(tidyverse)
library(MCMCpack)
library(gganimate)

 この記事では、基本的にパッケージ名::関数名()の記法を使うので、パッケージを読み込む必要はありません。ただし、作図コードがごちゃごちゃしないようにパッケージ名を省略しているため、ggplot2は読み込む必要があります。
 magrittrパッケージのパイプ演算子%>%ではなく、ベースパイプ(ネイティブパイプ)演算子|>を使っています。%>%に置き換えても処理できます。
 分布の変化をアニメーション(gif画像)で確認するのにgganimateパッケージを利用します。不要であれば省略してください。

三角座標の準備

 ディリクレ分布を三角図により可視化するために、三角座標を描画するための準備をします。詳しくは「ggplot2で三角グラフを作図したい - からっぽのしょこ」と「ggplot2で三角グラフの等高線を作図したい - からっぽのしょこ」を参照してください。

・コード(クリックで展開)

 軸目盛の間隔を設定して、三角座標を描画するためのデータフレームを作成します。

# 軸目盛の位置を指定
axis_vals <- seq(from = 0, to = 1, by = 0.1)

# 枠線用の値を作成
ternary_axis_df <- tibble::tibble(
  y_1_start = c(0.5, 0, 1),         # 始点のx軸の値
  y_2_start = c(0.5*sqrt(3), 0, 0), # 始点のy軸の値
  y_1_end = c(0, 1, 0.5),           # 終点のx軸の値
  y_2_end = c(0, 0, 0.5*sqrt(3)),   # 終点のy軸の値
  axis = c("x_1", "x_2", "x_3")     # 元の軸
)

# グリッド線用の値を作成
ternary_grid_df <- tibble::tibble(
  y_1_start = c(
    0.5 * axis_vals, 
    axis_vals, 
    0.5 * axis_vals + 0.5
  ), # 始点のx軸の値
  y_2_start = c(
    sqrt(3) * 0.5 * axis_vals, 
    rep(0, times = length(axis_vals)), 
    sqrt(3) * 0.5 * (1 - axis_vals)
  ), # 始点のy軸の値
  y_1_end = c(
    axis_vals, 
    0.5 * axis_vals + 0.5, 
    0.5 * rev(axis_vals)
  ), # 終点のx軸の値
  y_2_end = c(
    rep(0, times = length(axis_vals)), 
    sqrt(3) * 0.5 * (1 - axis_vals), 
    sqrt(3) * 0.5 * rev(axis_vals)
  ), # 終点のy軸の値
  axis = c("x_1", "x_2", "x_3") |> 
    rep(each = length(axis_vals)) # 元の軸
)

# 軸ラベル用の値を作成
ternary_axislabel_df <- tibble::tibble(
  y_1 = c(0.25, 0.5, 0.75),               # x軸の値
  y_2 = c(0.25*sqrt(3), 0, 0.25*sqrt(3)), # y軸の値
  label = c("phi[1]", "phi[2]", "phi[3]"),      # 軸ラベル
  h = c(3, 0.5, -2),  # 水平方向の調整用の値
  v = c(0.5, 3, 0.5), # 垂直方向の調整用の値
  axis = c("x_1", "x_2", "x_3") # 元の軸
)

# 軸目盛ラベル用の値を作成
ternary_ticklabel_df <- tibble::tibble(
  y_1 = c(
    0.5 * axis_vals, 
    axis_vals, 
    0.5 * axis_vals + 0.5
  ), # x軸の値
  y_2 = c(
    sqrt(3) * 0.5 * axis_vals, 
    rep(0, times = length(axis_vals)), 
    sqrt(3) * 0.5 * (1 - axis_vals)
  ), # y軸の値
  label = c(
    rev(axis_vals), 
    axis_vals, 
    rev(axis_vals)
  ), # 軸目盛ラベル
  h = c(
    rep(1.5, times = length(axis_vals)), 
    rep(1.5, times = length(axis_vals)), 
    rep(-0.5, times = length(axis_vals))
  ), # 水平方向の調整用の値
  v = c(
    rep(0.5, times = length(axis_vals)), 
    rep(0.5, times = length(axis_vals)), 
    rep(0.5, times = length(axis_vals))
  ), # 垂直方向の調整用の値
  angle = c(
    rep(-60, times = length(axis_vals)), 
    rep(60, times = length(axis_vals)), 
    rep(0, times = length(axis_vals))
  ), # ラベルの表示角度
  axis = c("x_1", "x_2", "x_3") |> 
    rep(each = length(axis_vals)) # 元の軸
)


 格子点の数を指定して、確率密度を計算するためのマトリクスを作成します。

# 三角座標の値を作成
y_1_vals <- seq(from = 0, to = 1, length.out = 151)
y_2_vals <- seq(from = 0, to = 0.5*sqrt(3), length.out = 150)

# 格子点を作成
y_mat <- tidyr::expand_grid(
  y_1 = y_1_vals, 
  y_2 = y_2_vals
) |> # 格子点を作成
  as.matrix() # マトリクスに変換

# 3次元変数に変換
phi_mat <- tibble::tibble(
  phi_2 = y_mat[, 1] - y_mat[, 2] / sqrt(3), 
  phi_3 = 2 * y_mat[, 2] / sqrt(3)
) |> # 元の座標に変換
  dplyr::mutate(
    phi_2 = dplyr::if_else(phi_2 >= 0 & phi_2 <= 1, true = phi_2, false = as.numeric(NA)), 
    phi_3 = dplyr::if_else(phi_3 >= 0 & phi_3 <= 1 & !is.na(phi_2), true = phi_3, false = as.numeric(NA)), 
    phi_1 = 1 - phi_2 - phi_3, 
    phi_1 = dplyr::if_else(phi_1 >= 0 & phi_1 <= 1, true = phi_1, false = as.numeric(NA))
  ) |> # 範囲外の値をNAに置換
  dplyr::select(phi_1, phi_2, phi_3) |> # 順番を変更
  as.matrix() # マトリクスに変換


 4つのデータフレームと2つのマトリクスを使って以降の作図を行います。

ベイズ推論の実装

 まずは、モデルを設定して、データを生成します。生成したデータを用いて、事後分布のパラメータを計算します。さらに、事後分布のパラメータを用いて、予測分布のパラメータを計算します。

生成分布の設定

 データ生成分布(真の分布)として、カテゴリ分布$\mathrm{Cat}(\mathbf{s} | \boldsymbol{\pi})$を設定します。

 真の分布(カテゴリ分布)のパラメータ$\boldsymbol{\pi}$を設定します。この例では三角図で表現するため、次元数を$K = 3$とします。パラメータの計算自体は次元数に関わらず行えます。

# 真のパラメータを指定
pi_truth_k <- c(0.3, 0.5, 0.2)

# 次元数を設定:(固定)
K <- length(pi_truth_k)
K
## [1] 3

 各クラスタが割り当てられる($s_{n,k} = 1$となる)確率$\boldsymbol{\pi} = (\pi_1, \pi_2, \cdots, \pi_K)$をpi_truth_kとして、$0 \leq \pi_k \leq 1$、$\sum_{k=1}^K \pi_k = 1$の値を指定します。pi_truth_kが真のパラメータであり、この値を求めるのがここでの目的です。

 真の分布を計算して、作図用のデータフレームを作成します。

# 真の分布を格納
model_df <- tibble::tibble(
  k = 1:K, # 次元番号
  prob = pi_truth_k # 確率
)
model_df
## # A tibble: 3 × 2
##       k  prob
##   <int> <dbl>
## 1     1   0.3
## 2     2   0.5
## 3     3   0.2

 次元番号(1からKの整数)と対応するパラメータをデータフレームに格納します。

 真の分布のグラフを作成します。

# パラメータラベル用の文字列を作成
model_param_text <- paste0("pi==(list(", paste0(pi_truth_k, collapse = ", "), "))")

# 真の分布を作図
ggplot() + 
  geom_bar(data = model_df, mapping = aes(x = k, y = prob, fill = "model"), 
           stat = "identity") + # 真の分布
  scale_fill_manual(breaks = "model", values = "purple", labels = "true model", name = "") + # バーの色:(凡例表示用)
  scale_x_continuous(breaks = 1:K, minor_breaks = FALSE) + # x軸目盛
  coord_cartesian(ylim = c(0, 1)) + # 表示範囲
  labs(title = "Categorical Distribution", 
       subtitle = parse(text = model_param_text), 
       x = "k", y = "probability")

真の分布(カテゴリ分布)のグラフ

 真のパラメータを求めることは、真の分布を求めることを意味します。

データの生成

 続いて、設定した生成分布から観測データ$\mathbf{S} = \{\mathbf{s}_1, \mathbf{s}_2, \cdots, \mathbf{s}_N\}$、$\mathbf{s}_n = (s_{n,1}, s_{n,2}, \cdots, s_{n,K})$を生成します。

 カテゴリ分布に従う$N$個のデータをランダムに生成します。

# データ数を指定
N <- 100

# カテゴリ分布に従うデータを生成
s_nk <- rmultinom(n = N, size = 1, prob = pi_truth_k) |> 
  t()
head(s_nk)
##      [,1] [,2] [,3]
## [1,]    1    0    0
## [2,]    1    0    0
## [3,]    0    1    0
## [4,]    0    0    1
## [5,]    0    1    0
## [6,]    0    1    0

 生成するデータ数$N$を指定します。
 カテゴリ分布に従う乱数は、多項分布に従う乱数生成関数rmultinom()の試行回数の引数size1にすることで生成できます。データ数の引数nN、確率の引数probpi_truth_kを指定します。生成したN個のデータをs_nkとします。各試行の結果を列としたマトリクスが返ってくるので、t()で転置します。

 観測データ$\mathbf{S}$の度数を集計します。

# 観測データを集計
freq_df <- tibble::tibble(
  k = 1:K, # 次元番号
  freq = colSums(s_nk) # 度数
)
freq_df
## # A tibble: 3 × 2
##       k  freq
##   <int> <dbl>
## 1     1    32
## 2     2    52
## 3     3    16

 カテゴリ分布の乱数$\mathbf{s}_n$は1つの要素(次元)が1でそれ以外は0をとるので、各クラスタとなったデータ数は次元ごとの総和$\sum_{n=1}^N s_{n,k}$で得られます。列ごとの和はcolSums()で計算できます。

 観測データのヒストグラムを真の分布と重ねて確認します。

# パラメータラベル用の文字列を作成
sample_param_text <- paste0(
  "list(", 
  "N==", N, "~(list(", paste0(freq_df[["freq"]], collapse = ", "), "))", 
  ", pi==(list(", paste0(pi_truth_k, collapse = ", "), "))", 
  ")"
)

# 観測データのヒストグラムを作成
ggplot() + 
  geom_bar(data = freq_df, mapping = aes(x = k, y = freq/N, fill = "data"), 
           stat = "identity") + # 観測データ:(相対度数)
  geom_bar(data = model_df, mapping = aes(x = k, y = prob, fill = "model", color = "model"), 
           stat = "identity", size = 1, linetype = "dashed") + # 真の分布
  scale_fill_manual(values = c(model = NA, data = "pink"), na.value = NA, 
                    labels = c(model = "true model", data = "observation data"), name = "") + # バーの色:(凡例表示用)
  scale_color_manual(values = c(model = "red", data = "pink"), 
                     labels = c(model = "true model", data = "observation data"), name = "") + # 線の色:(凡例表示用)
  scale_x_continuous(breaks = 1:K, minor_breaks = FALSE) + # x軸目盛
  coord_cartesian(ylim = c(0, 1)) + # 表示範囲
  labs(title = "Categorical Distribution", 
       subtitle = parse(text = sample_param_text), 
       x = "k", y = "relative frequency")

観測データ(カテゴリ分布の乱数)のヒストグラム

 各値の度数(freq列)をデータ数Nで割って、相対度数をバーの高さとします。

 データ数が十分に大きいと、ヒストグラムの形状が真の分布に近付きます。

事前分布の設定

 次は、尤度に対する共役事前分布を設定します。カテゴリ分布のパラメータ$\boldsymbol{\pi}$の事前分布としてディリクレ分布$\mathrm{Dir}(\boldsymbol{\pi} | \boldsymbol{\alpha})$を設定します。

 事前分布(ディリクレ分布)のパラメータ(超パラメータ)$\boldsymbol{\alpha}$を設定します。

# 事前分布のパラメータを指定
alpha_k <- rep(1, times = K)
alpha_k
## [1] 1 1 1

 $\boldsymbol{\alpha} = (\alpha_1, \alpha_2, \cdots, \alpha_K)$をalpha_kとして、$\alpha_k > 0$の値を指定します。

 事前分布を計算します。

# 事前分布を計算:式(2.48)
prior_df <- tibble::tibble(
  y_1 = y_mat[, 1], # x軸の値
  y_2 = y_mat[, 2], # y軸の値
  density = MCMCpack::ddirichlet(x = phi_mat, alpha = alpha_k) # 確率密度
) |> 
  dplyr::mutate(
    fill_flg = !is.na(rowSums(phi_mat)), 
    density = dplyr::if_else(fill_flg, true = density, false = as.numeric(NA))
  ) # 範囲外の値をNAに置換
prior_df
## # A tibble: 22,650 × 4
##      y_1     y_2 density fill_flg
##    <dbl>   <dbl>   <dbl> <lgl>   
##  1     0 0           NaN TRUE    
##  2     0 0.00581      NA FALSE   
##  3     0 0.0116       NA FALSE   
##  4     0 0.0174       NA FALSE   
##  5     0 0.0232       NA FALSE   
##  6     0 0.0291       NA FALSE   
##  7     0 0.0349       NA FALSE   
##  8     0 0.0407       NA FALSE   
##  9     0 0.0465       NA FALSE   
## 10     0 0.0523       NA FALSE   
## # … with 22,640 more rows

 グラフ用の値y_matと、計算用の値x_matにより求めた確率密度をデータフレームに格納します。ただし、三角座標外の要素(phi_matの欠損値を含む行)については欠損値NAに置き換えます。
 ディリクレ分布の確率密度は、MCMCpack::dderichlet()で計算できます。確率変数の引数xphi_mat、パラメータの引数alphaalpha_kを指定します。

 真のパラメータの位置を三角図に表示するために、変換してデータフレームに格納しておきます。

# 真のパラメータを格納
parameter_df <- tibble::tibble(
  x = pi_truth_k[2] + 0.5 * pi_truth_k[3], # 三角座標への変換
  y = sqrt(3) * 0.5 * pi_truth_k[3], # 三角座標への変換
)
parameter_df
## # A tibble: 1 × 2
##       x     y
##   <dbl> <dbl>
## 1   0.6 0.173

 2次元座標におけるx軸の値は$x = \phi_2 + \frac{\phi_3}{2}$、y軸の値は$y = \frac{\sqrt{3} \phi_3}{2}$で計算できます。

 事前分布の等高線グラフを作成します。

# 真のパラメータを格納
parameter_df <- tibble::tibble(
  x = pi_truth_k[2] + 0.5 * pi_truth_k[3], # 三角座標への変換
  y = sqrt(3) * 0.5 * pi_truth_k[3], # 三角座標への変換
)

# パラメータラベル用の文字列を作成
prior_param_text <- paste0("alpha==(list(", paste0(alpha_k, collapse = ", "), "))")

# 事前分布を作図
ggplot() + 
  geom_segment(data = ternary_grid_df, 
               mapping = aes(x = y_1_start, y = y_2_start, xend = y_1_end, yend = y_2_end), 
               color = "gray50", linetype = "dashed") + # 三角図のグリッド線
  geom_segment(data = ternary_axis_df, 
               mapping = aes(x = y_1_start, y = y_2_start, xend = y_1_end, yend = y_2_end), 
               color = "gray50") + # 三角図の枠線
  geom_text(data = ternary_ticklabel_df, 
            mapping = aes(x = y_1, y = y_2, label = label, hjust = h, vjust = v, angle = angle)) + # 三角図の軸目盛ラベル
  geom_text(data = ternary_axislabel_df, 
            mapping = aes(x = y_1, y = y_2, label = label, hjust = h, vjust = v), 
            parse = TRUE, size = 6) + # 三角図の軸ラベル
  geom_contour_filled(data = prior_df, 
                      mapping = aes(x = y_1, y = y_2, z = density, fill = ..level..), 
                      alpha = 0.8) + # 事前分布
  geom_point(data = parameter_df, mapping = aes(x = x, y = y, color = "param"), 
             shape = 4, size = 6) + # 真のパラメータ
  scale_x_continuous(breaks = c(0, 0.5, 1), labels = NULL) + # x軸
  scale_y_continuous(breaks = c(0, 0.25*sqrt(3), 0.5*sqrt(3)), labels = NULL) + # y軸
  scale_color_manual(breaks = "param", values = "red", labels = "true parameter", name = "") + # 線の色:(凡例表示用)
  coord_fixed(ratio = 1, clip = "off") + # アスペクト比
  theme(axis.ticks = element_blank(), 
        panel.grid.minor = element_blank()) + # 図の体裁
  labs(title = "Dirichlet Distribution", 
       subtitle = parse(text = prior_param_text), 
       fill = "density", 
       x = "", y = "")

事前分布(ディリクレ分布)のグラフ

 真のパラメータの値を赤色のバツ印で示します。
 無情報事前分布として、全ての確率密度が等しい分布を設定しました。

事後分布の計算

 観測データ$\mathbf{S}$からパラメータ$\boldsymbol{\pi}$の事後分布を求めます(パラメータ$\boldsymbol{\pi}$を分布推定します)。

 観測データ$\mathbf{X}$を用いて、事後分布(ディリクレ分布)のパラメータ$\hat{\boldsymbol{\alpha}}$を計算します。

# 事後分布のパラメータを計算:式(3.28)
alpha_hat_k <- colSums(s_nk) + alpha_k
alpha_hat_k
## [1] 33 53 17

 事後分布のパラメータ$\hat{\boldsymbol{\alpha}} = (\hat{\alpha}_1, \hat{\alpha}_2, \cdots, \hat{\alpha}_K)$は、次の式で計算できます。

$$ \hat{\alpha}_k = \sum_{n=1}^N s_{n,k} + \alpha_k \tag{3.28} $$

 事前分布のパラメータ$\alpha_k$に、それぞれ$s_{n,k} = 1$となった数を加えています。

 事後分布を計算します。

# 事後分布を計算:式(2.48)
posterior_df <- tibble::tibble(
  y_1 = y_mat[, 1], # x軸の値
  y_2 = y_mat[, 2], # y軸の値
  density = MCMCpack::ddirichlet(x = phi_mat, alpha = alpha_hat_k) # 確率密度
) |> 
  dplyr::mutate(
    fill_flg = !is.na(rowSums(phi_mat)), 
    density = dplyr::if_else(fill_flg, true = density, false = as.numeric(NA))
  ) # 範囲外の値をNAに置換
posterior_df
## # A tibble: 22,650 × 4
##      y_1     y_2 density fill_flg
##    <dbl>   <dbl>   <dbl> <lgl>   
##  1     0 0             0 TRUE    
##  2     0 0.00581      NA FALSE   
##  3     0 0.0116       NA FALSE   
##  4     0 0.0174       NA FALSE   
##  5     0 0.0232       NA FALSE   
##  6     0 0.0291       NA FALSE   
##  7     0 0.0349       NA FALSE   
##  8     0 0.0407       NA FALSE   
##  9     0 0.0465       NA FALSE   
## 10     0 0.0523       NA FALSE   
## # … with 22,640 more rows

 更新した超パラメータalpha_hat_kを使って、事前分布のときと同様にして処理します。

 事後分布の等高線グラフを作成します。

# パラメータラベル用の文字列を作成
posterior_param_text <- paste0(
  "list(", 
  "N==", N, "~(list(", paste0(freq_df[["freq"]], collapse = ", "), "))", 
  ", hat(alpha)==(list(", paste0(alpha_hat_k, collapse = ", "), "))", 
  ")"
)

# 事後分布を作図
ggplot() + 
  geom_segment(data = ternary_grid_df, 
               mapping = aes(x = y_1_start, y = y_2_start, xend = y_1_end, yend = y_2_end), 
               color = "gray50", linetype = "dashed") + # 三角図のグリッド線
  geom_segment(data = ternary_axis_df, 
               mapping = aes(x = y_1_start, y = y_2_start, xend = y_1_end, yend = y_2_end), 
               color = "gray50") + # 三角図の枠線
  geom_text(data = ternary_ticklabel_df, 
            mapping = aes(x = y_1, y = y_2, label = label, hjust = h, vjust = v, angle = angle)) + # 三角図の軸目盛ラベル
  geom_text(data = ternary_axislabel_df, 
            mapping = aes(x = y_1, y = y_2, label = label, hjust = h, vjust = v), 
            parse = TRUE, size = 6) + # 三角図の軸ラベル
  geom_contour_filled(data = posterior_df, 
                      mapping = aes(x = y_1, y = y_2, z = density, fill = ..level..), 
                      alpha = 0.8) + # 事後分布
  geom_point(data = parameter_df, mapping = aes(x = x, y = y, color = "param"), 
             shape = 4, size = 6) + # 真のパラメータ
  scale_x_continuous(breaks = c(0, 0.5, 1), labels = NULL) + # x軸
  scale_y_continuous(breaks = c(0, 0.25*sqrt(3), 0.5*sqrt(3)), labels = NULL) + # y軸
  scale_color_manual(breaks = "param", values = "red", labels = "true parameter", name = "") + # 線の色:(凡例表示用)
  coord_fixed(ratio = 1, clip = "off") + # アスペクト比
  theme(axis.ticks = element_blank(), 
        panel.grid.minor = element_blank()) + # 図の体裁
  labs(title = "Dirichlet Distribution", 
       subtitle = parse(text = posterior_param_text), 
       fill = "density", 
       x = "", y = "")

事後分布(ディリクレ分布)のグラフ

 $\boldsymbol{\pi}$の真の値付近をピークとする分布になっています。

予測分布の計算

 最後に、観測データ$\mathbf{S}$または(観測データから求めた)事後分布のパラメータ$\hat{\boldsymbol{\alpha}}$から未観測のデータ$\mathbf{s}_{*}$の予測分布を求めます。

 事後分布のパラメータ$\hat{\boldsymbol{\alpha}}$、または観測データ$\mathbf{S}$と事前分布のパラメータ$\boldsymbol{\alpha}$を用いて、予測分布(カテゴリ分布)のパラメータ$\boldsymbol{\pi}_{*}$を計算します。

# 予測分布のパラメータを計算:式(3.31')
pi_star_hat_k <- alpha_hat_k / sum(alpha_hat_k)
#pi_star_hat_k <- (colSums(s_nk) + alpha_k) / sum(colSums(s_nk) + alpha_k)
pi_star_hat_k
## [1] 0.3203883 0.5145631 0.1650485

 予測分布のパラメータは、次の式で計算できます。

$$ \hat{\pi}_{*,k} = \frac{ \hat{\alpha}_k }{ \sum_{k'=1}^K \hat{\alpha}_{k'} } = \frac{ \sum_{n=1}^N s_{n,k} + \alpha_k }{ \sum_{k'=1}^K \Bigl\{ \sum_{n'=1}^N s_{n',k'} + \alpha_{k'} \Bigr\} } \tag{3.31'} $$

 1つ目の式だと、事後分布のパラメータalpha_hat_kを使って計算できます。2つ目の式だと、観測データs_nkと事前分布のパラメータalpha_kを使って計算できます。
 $\hat{\pi}_{*}$は、$s_{*,k} = 1$となる確率を表し、$s_{n,k} = 1$の数が多いほど値が1に近付きます。

 予測分布を計算します。

# 予測分布を格納
predict_df <- tibble::tibble(
  k = 1:K, # 次元番号
  prob = pi_star_hat_k # 確率
)
predict_df
## # A tibble: 3 × 2
##       k  prob
##   <int> <dbl>
## 1     1 0.320
## 2     2 0.515
## 3     3 0.165

 予測分布のパラメータpi_star_hatを使って、真の分布のときと同様にして処理します。

 予測分布のグラフを真の分布と重ねて作成します。

# パラメータラベル用の文字列を作成
predict_param_text <- paste0(
  "list(", 
  "N==", N, "~(list(", paste0(freq_df[["freq"]], collapse = ", "), "))", 
  ", pi[s]==(list(", paste0(round(pi_star_hat_k, digits = 2), collapse = ", "), "))", 
  ")"
)

# 予測分布を作図
ggplot() + 
  geom_bar(data = predict_df, mapping = aes(x = k, y = prob, fill = "predict"), 
           stat = "identity") + # 予測分布
  geom_bar(data = model_df, aes(x = k, y = prob, fill = "model", color = "model"), 
           stat = "identity", size = 1, linetype = "dashed") + # 真の分布
  scale_fill_manual(values = c(model = NA, predict ="purple"), na.value = NA, 
                    labels = c(model = "true model", predict = "predict"), name = "") + # バーの色:(凡例表示用)
  scale_color_manual(values = c(model = "red", predict ="purple"), 
                     labels = c(model = "true model", predict = "predict"), name = "") + # 線の色:(凡例表示用)
  scale_x_continuous(breaks = 1:K, minor_breaks = FALSE) + # x軸目盛
  guides(fill = guide_legend(override.aes = list(fill = c(NA, "purple"))), 
         color = guide_legend(override.aes = list(size = c(0.5, 0.5), linetype = c("dashed", "blank")))) + # 凡例の体裁:(凡例表示用)
  coord_cartesian(ylim = c(0, 1)) + # 表示範囲
  labs(title = "Categorical Distribution", 
       subtitle = parse(text = predict_param_text), 
       x = "k", y = "probability")

予測分布(カテゴリ分布)のグラフ

 観測データが増えると、予測分布が真の分布に近付きます。

 以上で、カテゴリ分布のベイズ推論を実装できました。

学習推移の可視化

 gganimateパッケージを利用して、パラメータの推定値(更新値)の推移(分布の変化)をアニメーション(gif画像)で確認します。

モデルの設定

 gganimateパッケージを利用して、パラメータの推定値の推移のアニメーション(gif画像)を作成するためのコードです。

 真の分布のパラメータ$\boldsymbol{\pi}$と、事前分布のパラメータ(の初期値)$\boldsymbol{\alpha}$を設定します。

# 真のパラメータを指定
pi_truth_k <- c(0.3, 0.5, 0.2)

# 事前分布のパラメータを指定
alpha_k <- rep(1, times = K)

# 次元数を設定:(固定)
K <- length(pi_truth_k)

# データ数(試行回数)を指定
N <- 100

 実装時と同じく、パラメータを指定します。

 事後分布と予測分布のパラメータの計算(更新)処理について、2通りの方法を載せます。目的に応じて使い分けてください。

推論処理:for関数による処理

 1つ目の処理方法では、for()を使って、1データずつ生成してパラメータの更新を繰り返し処理します。こちらの方が、前ステップで求めた事後分布(のパラメータ)を次ステップの事前分布(のパラメータ)として用いる逐次学習をイメージしやすいです。

・コード(クリックで展開)

 超パラメータの初期値を使って、事前分布を計算します。

# 事前分布(ディリクレ分布)を計算:式(2.48)
anime_posterior_df <- tibble::tibble(
  y_1 = y_mat[, 1], # x軸の値
  y_2 = y_mat[, 2], # y軸の値
  density = MCMCpack::ddirichlet(x = phi_mat, alpha = alpha_k) # 確率密度
) |> 
  dplyr::mutate(
    fill_flg = !is.na(rowSums(phi_mat)), 
    density = dplyr::if_else(fill_flg, true = density, false = as.numeric(NA)), # 範囲外の値をNAに置換
    param = paste0(
      "n=", 0, " (", paste0(rep(0, times = K), collapse = ", "), ")", 
      ", α=(", paste0(alpha_k, collapse = ", "), ")"
    ) |> 
      as.factor() # フレーム切替用のラベル
  )
anime_posterior_df
## # A tibble: 22,650 × 5
##      y_1     y_2 density fill_flg param                     
##    <dbl>   <dbl>   <dbl> <lgl>    <fct>                     
##  1     0 0           NaN TRUE     n=0 (0, 0, 0), α=(1, 1, 1)
##  2     0 0.00581      NA FALSE    n=0 (0, 0, 0), α=(1, 1, 1)
##  3     0 0.0116       NA FALSE    n=0 (0, 0, 0), α=(1, 1, 1)
##  4     0 0.0174       NA FALSE    n=0 (0, 0, 0), α=(1, 1, 1)
##  5     0 0.0232       NA FALSE    n=0 (0, 0, 0), α=(1, 1, 1)
##  6     0 0.0291       NA FALSE    n=0 (0, 0, 0), α=(1, 1, 1)
##  7     0 0.0349       NA FALSE    n=0 (0, 0, 0), α=(1, 1, 1)
##  8     0 0.0407       NA FALSE    n=0 (0, 0, 0), α=(1, 1, 1)
##  9     0 0.0465       NA FALSE    n=0 (0, 0, 0), α=(1, 1, 1)
## 10     0 0.0523       NA FALSE    n=0 (0, 0, 0), α=(1, 1, 1)
## # … with 22,640 more rows

 現在のパラメータを文字列結合し因子型に変換して、フレーム切替用のラベル列とします。

 また、事前分布のパラメータを使って、予測分布のパラメータと予測分布を計算します。

# 初期値による予測分布のパラメーターを計算:式(3.31)
pi_star_k <- alpha_k / sum(alpha_k)

# 初期値による予測分布を格納
anime_predict_df <- tibble::tibble(
  k = 1:K, # 次元番号
  prob = pi_star_k, # 確率
  param = paste0(
      "N=", 0, " (", paste0(rep(0, times = K), collapse = ", "), ")", 
      ", π=(", paste0(round(pi_star_k, digits = 2), collapse = ", "), ")"
    ) |> 
    as.factor() # フレーム切替用ラベル
)
anime_predict_df
## # A tibble: 3 × 3
##       k  prob param                              
##   <int> <dbl> <fct>                              
## 1     1 0.333 N=0 (0, 0, 0), π=(0.33, 0.33, 0.33)
## 2     2 0.333 N=0 (0, 0, 0), π=(0.33, 0.33, 0.33)
## 3     3 0.333 N=0 (0, 0, 0), π=(0.33, 0.33, 0.33)

 同様に、フレーム切替用のラベル列を作成します。

 パラメータの更新処理をN回繰り返します。

# 観測データの受け皿を作成
s_nk <- matrix(NA, nrow = N, ncol = K)

# ベイズ推論
for(n in 1:N){
  
  # カテゴリ分布に従うデータを生成
  s_nk[n, ] <- rmultinom(n = 1, size = 1, prob = pi_truth_k) |> 
    as.vector()
  
  # 事後分布のパラメータを更新:式(3.28)
  alpha_k <- s_nk[n, ] + alpha_k
  
  # 事後分布を計算:式(2.48)
  tmp_posterior_df <- tibble::tibble(
    y_1 = y_mat[, 1], # x軸の値
    y_2 = y_mat[, 2], # y軸の値
    density = MCMCpack::ddirichlet(x = phi_mat, alpha = alpha_k) # 確率密度
  ) |> 
    dplyr::mutate(
      fill_flg = !is.na(rowSums(phi_mat)), 
      density = dplyr::if_else(fill_flg, true = density, false = as.numeric(NA)), # 範囲外の値をNAに置換
      param = paste0(
        "n=", n, " (", paste0(colSums(s_nk), collapse = ", "), ")", 
        ", α=(", paste0(alpha_k, collapse = ", "), ")"
      ) |> 
        as.factor() # フレーム切替用のラベル
    )
  
  # 予測分布のパラメーターを更新:式(3.31)
  pi_star_k <- alpha_k / sum(alpha_k)
  
  # 予測分布を格納
  tmp_predict_df <- tibble::tibble(
    k = 1:K, # 次元番号
    prob = pi_star_k, # 確率
    param = paste0(
      "N=", n, " (", paste0(colSums(s_nk), collapse = ", "), ")", 
      ", π=(", paste0(round(pi_star_k, digits = 2), collapse = ", "), ")"
    ) |> 
      as.factor() # フレーム切替用ラベル
  )
  
  # n回目の結果を結合
  anime_posterior_df <- dplyr::bind_rows(anime_posterior_df, tmp_posterior_df)
  anime_predict_df   <- dplyr::bind_rows(anime_predict_df, tmp_predict_df)
  
  # 途中経過を表示
  message("\r", n, " / ", N, appendLF = FALSE)
}

 超パラメータに関して、$\hat{\alpha}$に対応するalpha_hat_kを新たに作るのではなく、alpha_kを繰り返し更新(上書き)していきます。これにより、事後分布のパラメータの計算式(3.28)の$\sum_{n=1}^N s_{n,k}$の計算は、ループ処理によってN回繰り返しs_nk[n, ]を加えることで行います。n回目のループ処理のときには、n-1回分のs_nk[n, ]が既にalpha_kに加えられています。
 更新したパラメータを使って事後分布と予測分布を計算して、それぞれ計算結果をanime_***_dfに結合していきます。

 結果を確認します。

# 確認
head(s_nk); anime_posterior_df; anime_predict_df
##      [,1] [,2] [,3]
## [1,]    1    0    0
## [2,]    0    1    0
## [3,]    1    0    0
## [4,]    0    1    0
## [5,]    1    0    0
## [6,]    0    0    1
## # A tibble: 2,287,650 × 5
##      y_1     y_2 density fill_flg param                     
##    <dbl>   <dbl>   <dbl> <lgl>    <fct>                     
##  1     0 0           NaN TRUE     n=0 (0, 0, 0), α=(1, 1, 1)
##  2     0 0.00581      NA FALSE    n=0 (0, 0, 0), α=(1, 1, 1)
##  3     0 0.0116       NA FALSE    n=0 (0, 0, 0), α=(1, 1, 1)
##  4     0 0.0174       NA FALSE    n=0 (0, 0, 0), α=(1, 1, 1)
##  5     0 0.0232       NA FALSE    n=0 (0, 0, 0), α=(1, 1, 1)
##  6     0 0.0291       NA FALSE    n=0 (0, 0, 0), α=(1, 1, 1)
##  7     0 0.0349       NA FALSE    n=0 (0, 0, 0), α=(1, 1, 1)
##  8     0 0.0407       NA FALSE    n=0 (0, 0, 0), α=(1, 1, 1)
##  9     0 0.0465       NA FALSE    n=0 (0, 0, 0), α=(1, 1, 1)
## 10     0 0.0523       NA FALSE    n=0 (0, 0, 0), α=(1, 1, 1)
## # … with 2,287,640 more rows
## # A tibble: 303 × 3
##        k  prob param                                
##    <int> <dbl> <fct>                                
##  1     1 0.333 N=0 (0, 0, 0), π=(0.33, 0.33, 0.33)  
##  2     2 0.333 N=0 (0, 0, 0), π=(0.33, 0.33, 0.33)  
##  3     3 0.333 N=0 (0, 0, 0), π=(0.33, 0.33, 0.33)  
##  4     1 0.5   N=1 (NA, NA, NA), π=(0.5, 0.25, 0.25)
##  5     2 0.25  N=1 (NA, NA, NA), π=(0.5, 0.25, 0.25)
##  6     3 0.25  N=1 (NA, NA, NA), π=(0.5, 0.25, 0.25)
##  7     1 0.4   N=2 (NA, NA, NA), π=(0.4, 0.4, 0.2)  
##  8     2 0.4   N=2 (NA, NA, NA), π=(0.4, 0.4, 0.2)  
##  9     3 0.2   N=2 (NA, NA, NA), π=(0.4, 0.4, 0.2)  
## 10     1 0.5   N=3 (NA, NA, NA), π=(0.5, 0.33, 0.17)
## # … with 293 more rows

 それぞれ「y_matの行数・K」掛ける「N+1」行のデータフレームになります。行数が増えすぎないように注意してくださいしてください。


推論処理:tidyverseパッケージによる処理

 2つ目の処理方法では、tidyverseパッケージの関数を使って、一度の処理でN+1回分の計算をします。こちらの方が、処理時間が短いです。

・コード(クリックで展開)

 カテゴリ分布に従う$N$個のデータを生成します。

# カテゴリ分布に従うデータを生成
s_nk <- rmultinom(n = N, size = 1, prob = pi_truth_k) |> 
  t()
head(s_nk)
##      [,1] [,2] [,3]
## [1,]    1    0    0
## [2,]    0    1    0
## [3,]    1    0    0
## [4,]    0    1    0
## [5,]    1    0    0
## [6,]    0    0    1


 ラベル用のテキストとして、観測データを集計し超パラメータを計算します。

# レベルを設定用の文字列を作成
freq_vec <- rbind(rep(0, times = K), s_nk) |> 
  apply(MARGIN = 2, FUN = cumsum) |> # 累積和を計算
  apply(MARGIN = 1, FUN = paste0, collapse = ", ") # 試行ごとに結合
param_vec <- rbind(rep(0, times = K), s_nk) |> 
  apply(MARGIN = 2, FUN = cumsum) |> # 累積和を計算
  (\(.) {t(.) + alpha_k})() |> # 事後分布のパラメータを計算
  apply(MARGIN = 2, FUN = paste0, collapse = ", ") # 試行ごとに結合
posterior_level_vec <- paste0("n=", 0:N, " (", freq_vec, "), α=(", param_vec, ")")
head(posterior_level_vec)
## [1] "n=0 (0, 0, 0), α=(1, 1, 1)" "n=1 (1, 0, 0), α=(2, 1, 1)"
## [3] "n=2 (1, 1, 0), α=(2, 2, 1)" "n=3 (2, 1, 0), α=(3, 2, 1)"
## [5] "n=4 (2, 2, 0), α=(3, 3, 1)" "n=5 (3, 2, 0), α=(4, 3, 1)"

 試行ごとに、次元ごとの度数と事後分布のパラメータをそれぞれ文字列にまとめておき、さらに2つを文字列結合します。

 事前分布を含めたN+1回分の事後分布を計算します。

# 試行ごとに事後分布を計算
anime_posterior_df <- tidyr::expand_grid(
  n = 0:N, # 試行回数
  tibble::tibble(
    y_1 = y_mat[, 1], # x軸の値
    y_2 = y_mat[, 2]  # y軸の値
  )
) |> # 試行ごとに格子点を複製
  dplyr::group_by(n) |> # 試行ごとの計算用
  dplyr::mutate(
    # 事後分布のパラメータを計算:式(3.28)
    sum_s_k_lt = colSums(s_nk[0:unique(n), , drop = FALSE]) |> 
      list(), 
    alpha_k_lt = (colSums(s_nk[0:unique(n), , drop = FALSE]) + alpha_k) |> 
      list(), 
    # 事後分布を計算:式(2.48)
    density = MCMCpack::ddirichlet(x = phi_mat, alpha = alpha_k_lt[[1]]), # 確率密度
    fill_flg = !is.na(rowSums(phi_mat)), 
    density = dplyr::if_else(fill_flg, true = density, false = as.numeric(NA)), # 範囲外の値をNAに置換
    param = paste0(
      "n=", unique(n), " (", paste0(sum_s_k_lt[[1]], collapse = ", "), ")", 
      ", α=(", paste0(alpha_k_lt[[1]], collapse = ", "), ")"
    ) |> 
      factor(levels = posterior_level_vec[unique(n)+1]) # フレーム切替用ラベル
  ) |> 
  dplyr::ungroup()
anime_posterior_df
## # A tibble: 2,287,650 × 8
##        n   y_1     y_2 sum_s_k_lt alpha_k_lt density fill_flg param            
##    <int> <dbl>   <dbl> <list>     <list>       <dbl> <lgl>    <fct>            
##  1     0     0 0       <dbl [3]>  <dbl [3]>      NaN TRUE     n=0 (0, 0, 0), α…
##  2     0     0 0.00581 <dbl [3]>  <dbl [3]>       NA FALSE    n=0 (0, 0, 0), α…
##  3     0     0 0.0116  <dbl [3]>  <dbl [3]>       NA FALSE    n=0 (0, 0, 0), α…
##  4     0     0 0.0174  <dbl [3]>  <dbl [3]>       NA FALSE    n=0 (0, 0, 0), α…
##  5     0     0 0.0232  <dbl [3]>  <dbl [3]>       NA FALSE    n=0 (0, 0, 0), α…
##  6     0     0 0.0291  <dbl [3]>  <dbl [3]>       NA FALSE    n=0 (0, 0, 0), α…
##  7     0     0 0.0349  <dbl [3]>  <dbl [3]>       NA FALSE    n=0 (0, 0, 0), α…
##  8     0     0 0.0407  <dbl [3]>  <dbl [3]>       NA FALSE    n=0 (0, 0, 0), α…
##  9     0     0 0.0465  <dbl [3]>  <dbl [3]>       NA FALSE    n=0 (0, 0, 0), α…
## 10     0     0 0.0523  <dbl [3]>  <dbl [3]>       NA FALSE    n=0 (0, 0, 0), α…
## # … with 2,287,640 more rows

 試行回数(観測データ数)を表す0からNまでのN+1個の整数と、x軸の値mu_vecの全ての組み合わせをexpand_grid()で作成します。これにより、mu_vecの各要素をN+1フレーム分に複製できます。
 事前分布とN回分の事後分布のパラメータa, b(のベクトル)を計算(作成)して、n列の値をインデックスとして使って各試行に対応する値を抽出します。
 確率変数とパラメータの組み合わせごとに確率密度を計算して、パラメータごとにフレーム切替用のラベルを作成します。

 同様に、観測データの度数と予測分布のパラメータを結合します。

# レベルを設定用の文字列を作成
freq_vec <- rbind(rep(0, times = K), s_nk) |> 
  apply(MARGIN = 2, FUN = cumsum) |> # 累積和を計算
  apply(MARGIN = 1, FUN = paste0, collapse = ", ") # 試行ごとに結合
param_vec <- rbind(rep(0, times = K), s_nk) |> 
  apply(MARGIN = 2, FUN = cumsum) |> # 累積和を計算
  (\(.) {t(.) + alpha_k})() |> # 事後分布のパラメータを計算
  (\(.) {t(.) / colSums(.)})() |> # 予測分布のパラメータを計算
  round(digits = 2) |> 
  apply(MARGIN = 1, FUN = paste0, collapse = ", ") # 試行ごとに結合
predict_level_vec <- paste0("n=", 0:N, " (", freq_vec, "), π=(", param_vec, ")")
head(predict_level_vec)
## [1] "n=0 (0, 0, 0), π=(0.33, 0.33, 0.33)"
## [2] "n=1 (1, 0, 0), π=(0.5, 0.25, 0.25)" 
## [3] "n=2 (1, 1, 0), π=(0.4, 0.4, 0.2)"   
## [4] "n=3 (2, 1, 0), π=(0.5, 0.33, 0.17)" 
## [5] "n=4 (2, 2, 0), π=(0.43, 0.43, 0.14)"
## [6] "n=5 (3, 2, 0), π=(0.5, 0.38, 0.12)"


 初期値を含めたN+1回分の予測分布を計算します。

# 試行ごとに予測分布を格納
anime_predict_df <- tidyr::expand_grid(
  n = 0:N, # 試行回数
  k = 1:K  # 次元番号
) |> # 試行ごとにx軸の値を複製
  dplyr::group_by(n) |> # 
  dplyr::mutate(
    # 予測分布のパラメータを計算:式(3.31')
    sum_s_k_lt = colSums(s_nk[0:unique(n), , drop = FALSE]) |> 
      list(), 
    alpha_k_lt = (colSums(s_nk[0:unique(n), , drop = FALSE]) + alpha_k) |> 
      list(), 
    pi_k_lt = (alpha_k_lt[[1]] / sum(alpha_k_lt[[1]])) |> 
      list(), 
    # 予測分布を格納
    prob = pi_k_lt[[1]], 
    param = paste0(
      "n=", unique(n), " (", paste0(sum_s_k_lt[[1]], collapse = ", "), ")", 
      ", π=(", paste0(round(pi_k_lt[[1]], digits = 2), collapse = ", "), ")"
    ) |> 
      factor(levels = predict_level_vec[unique(n)+1]) # フレーム切替用ラベル
  ) |> 
  dplyr::ungroup()
anime_predict_df
## # A tibble: 303 × 7
##        n     k sum_s_k_lt alpha_k_lt pi_k_lt    prob param                     
##    <int> <int> <list>     <list>     <list>    <dbl> <fct>                     
##  1     0     1 <dbl [3]>  <dbl [3]>  <dbl [3]> 0.333 n=0 (0, 0, 0), π=(0.33, 0…
##  2     0     2 <dbl [3]>  <dbl [3]>  <dbl [3]> 0.333 n=0 (0, 0, 0), π=(0.33, 0…
##  3     0     3 <dbl [3]>  <dbl [3]>  <dbl [3]> 0.333 n=0 (0, 0, 0), π=(0.33, 0…
##  4     1     1 <dbl [3]>  <dbl [3]>  <dbl [3]> 0.5   n=1 (1, 0, 0), π=(0.5, 0.…
##  5     1     2 <dbl [3]>  <dbl [3]>  <dbl [3]> 0.25  n=1 (1, 0, 0), π=(0.5, 0.…
##  6     1     3 <dbl [3]>  <dbl [3]>  <dbl [3]> 0.25  n=1 (1, 0, 0), π=(0.5, 0.…
##  7     2     1 <dbl [3]>  <dbl [3]>  <dbl [3]> 0.4   n=2 (1, 1, 0), π=(0.4, 0.…
##  8     2     2 <dbl [3]>  <dbl [3]>  <dbl [3]> 0.4   n=2 (1, 1, 0), π=(0.4, 0.…
##  9     2     3 <dbl [3]>  <dbl [3]>  <dbl [3]> 0.2   n=2 (1, 1, 0), π=(0.4, 0.…
## 10     3     1 <dbl [3]>  <dbl [3]>  <dbl [3]> 0.5   n=3 (2, 1, 0), π=(0.5, 0.…
## # … with 293 more rows


作図処理

 事後分布と予測分布の推移をそれぞれアニメーションで可視化します。

・コード(クリックで展開)

 観測データと対応するラベルをデータフレームに格納します。

# 観測データを格納
anime_data_df <- tibble::tibble(
  x = c(NA, s_nk[, 2] + 0.5 * s_nk[, 3]), # 三角座標への変換
  y = c(NA, sqrt(3) * 0.5 * s_nk[, 3]), # 三角座標への変換
  param = unique(anime_posterior_df[["param"]]) # フレーム切替用ラベル
)
anime_data_df
## # A tibble: 101 × 3
##        x      y param                     
##    <dbl>  <dbl> <fct>                     
##  1  NA   NA     n=0 (0, 0, 0), α=(1, 1, 1)
##  2   0    0     n=1 (1, 0, 0), α=(2, 1, 1)
##  3   1    0     n=2 (1, 1, 0), α=(2, 2, 1)
##  4   0    0     n=3 (2, 1, 0), α=(3, 2, 1)
##  5   1    0     n=4 (2, 2, 0), α=(3, 3, 1)
##  6   0    0     n=5 (3, 2, 0), α=(4, 3, 1)
##  7   0.5  0.866 n=6 (3, 2, 1), α=(4, 3, 2)
##  8   1    0     n=7 (3, 3, 1), α=(4, 4, 2)
##  9   0.5  0.866 n=8 (3, 3, 2), α=(4, 4, 3)
## 10   0    0     n=9 (4, 3, 2), α=(5, 4, 3)
## # … with 91 more rows

 事前分布には観測データが影響しないので欠損値NAとして、anime_posterior_dfのラベル列と合わせて格納します。
 参考として、各試行における観測データを表示するのに使います。

 事後分布の推移のアニメーションを作成します。

# 真のパラメータを格納
parameter_df <- tibble::tibble(
  x = pi_truth_k[2] + 0.5 * pi_truth_k[3], # 三角座標への変換
  y = sqrt(3) * 0.5 * pi_truth_k[3], # 三角座標への変換
)

# 事後分布のアニメーションを作図
anime_posterior_graph <- ggplot() + 
  geom_segment(data = ternary_grid_df, 
               mapping = aes(x = y_1_start, y = y_2_start, xend = y_1_end, yend = y_2_end), 
               color = "gray50", linetype = "dashed") + # 三角図のグリッド線
  geom_segment(data = ternary_axis_df, 
               mapping = aes(x = y_1_start, y = y_2_start, xend = y_1_end, yend = y_2_end), 
               color = "gray50") + # 三角図の枠線
  geom_text(data = ternary_ticklabel_df, 
            mapping = aes(x = y_1, y = y_2, label = label, hjust = h, vjust = v, angle = angle)) + # 三角図の軸目盛ラベル
  geom_text(data = ternary_axislabel_df, 
            mapping = aes(x = y_1, y = y_2, label = label, hjust = h, vjust = v), 
            parse = TRUE, size = 6) + # 三角図の軸ラベル
  geom_contour_filled(data = anime_posterior_df, 
                      mapping = aes(x = y_1, y = y_2, z = density, fill = ..level..), 
                      alpha = 0.8) + # 事後分布
  geom_point(data = anime_data_df, mapping = aes(x = x, y = y, color = "data"), 
             size = 6) + # 観測データ
  geom_point(data = parameter_df, mapping = aes(x = x, y = y, color = "param"), 
             shape = 4, size = 6) + # 真のパラメータ
  gganimate::transition_manual(frames = param) + # フレーム
  scale_x_continuous(breaks = c(0, 0.5, 1), labels = NULL) + # x軸
  scale_y_continuous(breaks = c(0, 0.25*sqrt(3), 0.5*sqrt(3)), labels = NULL) + # y軸
  scale_color_manual(breaks = c("param", "data"), 
                     values = c("red", "pink"), 
                     labels = c("true parameter", "observation data"), name = "") + # 線の色:(凡例表示用)
  guides(color = guide_legend(override.aes = list(size = c(5, 5), shape = c(4, 19)))) + # 凡例の体裁:(凡例表示用)
  coord_fixed(ratio = 1, clip = "off") + # アスペクト比
  theme(axis.ticks = element_blank(), 
        panel.grid.minor = element_blank()) + # 図の体裁
  labs(title = "Dirichlet Distribution", 
       subtitle = "{current_frame}", 
       fill = "density", 
       x = "", y = "")

# gif画像を作成
gganimate::animate(
  plot = anime_posterior_graph, nframes = N+1+10, end_pause = 10, fps = 10, 
  width = 800, height = 800
)

 フレームの順番を示す列をtransition_manual()に指定して、animate()でgif画像を作成します。事前分布(初期値)を含むため、フレーム数はN+1です。

 観測データと対応するラベルをデータフレームに格納します。

# 観測データを格納
anime_data_df <- tibble::tibble(
  k = c(NA, which(t(s_nk) == 1, arr.ind = TRUE)[, "row"]), # クラスタ番号
  param = unique(anime_predict_df[["param"]]) # フレーム切替用ラベル
)
anime_data_df
## # A tibble: 101 × 2
##        k param                              
##    <int> <fct>                              
##  1    NA n=0 (0, 0, 0), π=(0.33, 0.33, 0.33)
##  2     1 n=1 (1, 0, 0), π=(0.5, 0.25, 0.25) 
##  3     2 n=2 (1, 1, 0), π=(0.4, 0.4, 0.2)   
##  4     1 n=3 (2, 1, 0), π=(0.5, 0.33, 0.17) 
##  5     2 n=4 (2, 2, 0), π=(0.43, 0.43, 0.14)
##  6     1 n=5 (3, 2, 0), π=(0.5, 0.38, 0.12) 
##  7     3 n=6 (3, 2, 1), π=(0.44, 0.33, 0.22)
##  8     2 n=7 (3, 3, 1), π=(0.4, 0.4, 0.2)   
##  9     3 n=8 (3, 3, 2), π=(0.36, 0.36, 0.27)
## 10     1 n=9 (4, 3, 2), π=(0.42, 0.33, 0.25)
## # … with 91 more rows

 こちらは、anime_predict_dfのラベル列を使います。

 真の分布を複製して、対応するラベルとデータフレームに格納します。

# 真の分布を複製
anime_model_df <- tidyr::expand_grid(
  param = unique(anime_predict_df[["param"]]), # フレーム切替用ラベル
  k = 1:K # 次元番号
) |> # 試行ごとにx軸の値を複製
  dplyr::mutate(
    prob = pi_truth_k[k] # 確率
  )
anime_model_df
## # A tibble: 303 × 3
##    param                                   k  prob
##    <fct>                               <int> <dbl>
##  1 n=0 (0, 0, 0), π=(0.33, 0.33, 0.33)     1   0.3
##  2 n=0 (0, 0, 0), π=(0.33, 0.33, 0.33)     2   0.5
##  3 n=0 (0, 0, 0), π=(0.33, 0.33, 0.33)     3   0.2
##  4 n=1 (1, 0, 0), π=(0.5, 0.25, 0.25)      1   0.3
##  5 n=1 (1, 0, 0), π=(0.5, 0.25, 0.25)      2   0.5
##  6 n=1 (1, 0, 0), π=(0.5, 0.25, 0.25)      3   0.2
##  7 n=2 (1, 1, 0), π=(0.4, 0.4, 0.2)        1   0.3
##  8 n=2 (1, 1, 0), π=(0.4, 0.4, 0.2)        2   0.5
##  9 n=2 (1, 1, 0), π=(0.4, 0.4, 0.2)        3   0.2
## 10 n=3 (2, 1, 0), π=(0.5, 0.33, 0.17)      1   0.3
## # … with 293 more rows

 棒グラフの場合は、フレームごとにデータを用意する必要があります。
 真の分布は全てのフレームで変化しないので、N+1個に複製します。

 予測分布の推移のアニメーションを作成します。

# 予測分布を作図
anime_predict_graph <- ggplot() + 
  geom_bar(data = anime_predict_df, mapping = aes(x = k, y = prob, fill = "predict"), 
           stat = "identity") + # 予測分布
  geom_bar(data = anime_model_df, aes(x = k, y = prob, fill = "model", color = "model"), 
           stat = "identity", size = 1, linetype = "dashed") + # 真の分布
  geom_point(data = anime_data_df, aes(x = k, y = 0, color = "data"), 
             size = 6) + # 観測データ
  gganimate::transition_manual(frames = param) + # フレーム
  scale_fill_manual(values = c(model = NA, predict = "purple", data = NA), na.value = NA, 
                    labels = c(model = "true model", predict = "predict", data = "observation data"), name = "") + # バーの色:(凡例表示用)
  scale_color_manual(values = c(model = "red", predict = "purple", data = "pink"), 
                     labels = c(model = "true model", predict = "predict", data = "observation data"), name = "") + # 線の色:(凡例表示用)
  scale_x_continuous(breaks = 1:K, minor_breaks = FALSE) + # x軸目盛
  guides(fill = guide_legend(override.aes = list(fill = c(NA, "purple", NA))), 
         color = guide_legend(override.aes = list(size = c(0.5, 0.5, 5), linetype = c("dashed", "blank", "blank"), shape = c(NA, NA, 19)))) + # 凡例の体裁:(凡例表示用)
  coord_cartesian(ylim = c(0, 1)) + # 表示範囲
  labs(title = "Categorical Distribution", 
       subtitle = "{current_frame}", 
       x = "k", y = "probability")

# gif画像を作成
gganimate::animate(
  plot = anime_predict_graph, nframes = N+1+10, end_pause = 10, fps = 10, 
  width = 800, height = 600
)


事後分布の推移

 データが増えるにつれて、真のパラメータ付近の確率密度が大きくなっていくのを確認できます。

予測分布の推移

 予測分布が真の分布に近付いていくのを確認できます。

参考文献

  • 須山敦志『ベイズ推論による機械学習入門』(機械学習スタートアップシリーズ)杉山将監修,講談社,2017年.

おわりに

 加筆修正の際に記事を分割しました。

 {gganimate}の詳しい使い方はいつかまた別の機会に…

  • 2023.01.25:加筆修正しました。

 gganimateパッケージの記事も書いているのでR > パッケージ > gganimateのカテゴリを見てみてください。

【次の節の内容】

www.anarchive-beta.com