からっぽのしょこ

読んだら書く!書いたら読む!読書読読書読書♪同じ事は二度調べ(たく)ない

3.2.2:カテゴリ分布の学習と予測【緑ベイズ入門のノート】

はじめに

 『ベイズ推論による機械学習入門』の学習時のノートです。基本的な内容は「数式の行間を読んでみた」とそれを「Rで組んでみた」になります。「数式」と「プログラム」から理解するのが目標です。

 この記事は3.2.2項の内容になります。尤度関数をカテゴリ分布、事前分布をディリクレ分布とした場合のパラメータの事後分布を導出し、その学習した事後分布を用いた予測分布を導出します。またその学習過程をR言語で実装します。

 省略してる内容等ありますので、本と併せて読んでください。初学者な自分が理解できるレベルまで落として書き下していますので、分かる人にはかなりくどくなっています。同じような立場の人のお役に立てば幸いです。

【前節の内容】

www.anarchive-beta.com

【他の節一覧】

www.anarchive-beta.com

【この節の内容】

3.2.2 カテゴリ分布の学習と予測

・事後分布の計算

 カテゴリ分布に従うと仮定するN個の観測データ$\boldsymbol{S} = \{\boldsymbol{s}_1, \boldsymbol{s}_2, \cdots, \boldsymbol{s}_N \}$を用いて、パラメータ$\boldsymbol{\pi}$の事後分布を求めていく。

 まずは観測モデル$p(\boldsymbol{S} | \boldsymbol{\pi})$について確認する。離散値データ$\boldsymbol{S} = \{\boldsymbol{s}_1, \boldsymbol{s}_2, \cdots, \boldsymbol{s}_N \}$はそれぞれ独立に生成されているとの仮定の下で、観測モデルは

$$ \begin{aligned} p(\boldsymbol{S} | \boldsymbol{\pi}) &= p(\boldsymbol{s}_1, \boldsymbol{s}_2, \cdots, \boldsymbol{s}_N | \boldsymbol{\pi}) \\ &= p(\boldsymbol{s}_1 | \boldsymbol{\pi}) p(\boldsymbol{s}_2 | \boldsymbol{\pi}) \cdots p(\boldsymbol{s}_N | \boldsymbol{\pi}) \\ &= \prod_{n=1}^N p(\boldsymbol{s}_n | \boldsymbol{\pi}) \end{aligned} $$

と分解できる。更に$\boldsymbol{s}_n,\ \boldsymbol{\pi}$はK次元ベクトルであるため

$$ \begin{aligned} p(\boldsymbol{s}_n | \boldsymbol{\pi}) &= p(s_{n,1}, s_{n,2}, \cdots, s_{n,K} | \boldsymbol{\pi}) \\ &= p(s_{n,1} | \pi_1) p(s_{n,2} | \pi_2) \cdots p(s_{n,K} | \pi_K) \\ &= \prod_{k=1}^K p(s_{n,k} | \pi_k) \end{aligned} $$

と分解できる。従って観測モデルは

$$ p(\boldsymbol{S} | \boldsymbol{\pi}) = \prod_{n=1}^N \prod_{k=1}^K p(s_{n,k} | \pi_k) $$

である。

 これを用いて、$\boldsymbol{\pi}$の事後分布$p(\boldsymbol{\pi} | \boldsymbol{S})$はベイズの定理より

$$ \begin{align} p(\boldsymbol{\pi} | \boldsymbol{S}) &= \frac{ p(\boldsymbol{S} | \boldsymbol{\pi}) p(\boldsymbol{\pi} | \boldsymbol{\alpha}) }{ p(\boldsymbol{S}) } \\ &= \frac{ \left\{ \prod_{n=1}^N p(\boldsymbol{s}_n | \boldsymbol{\pi}) \right\} p(\boldsymbol{\pi} | \boldsymbol{\alpha}) }{ p(\boldsymbol{S}) } \\ &= \frac{ \left\{\prod_{n=1}^N \mathrm{Cat}(\boldsymbol{s}_n | \boldsymbol{\pi}) \right\} \mathrm{Dir}(\boldsymbol{\pi} | \boldsymbol{\alpha}) }{ p(\boldsymbol{S}) } \\ &\propto \left\{\prod_{n=1}^N \mathrm{Cat}(\boldsymbol{s}_n | \boldsymbol{\pi}) \right\} \mathrm{Dir}(\boldsymbol{\pi} | \boldsymbol{\alpha}) \tag{3.25} \end{align} $$

となる。分母の$p(\boldsymbol{S})$は$\boldsymbol{\pi}$に影響しないため省略して、比例関係にのみ注目する。省略した部分については、最後に正規化することで対応できる。

 次にこの分布の具体的な形状を明らかにしていく。対数をとって指数部分の計算を分かりやすくして進めると

$$ \begin{align} \ln p(\boldsymbol{\pi} | \boldsymbol{S}) &= \ln \prod_{n=1}^N \mathrm{Cat}(\boldsymbol{s}_n | \boldsymbol{\pi}) + \ln \mathrm{Dir}(\boldsymbol{\pi} | \boldsymbol{\alpha}) - \ln p(\boldsymbol{S}) \\ &= \sum_{n=1}^N \ln \mathrm{Cat}(\boldsymbol{s}_n | \boldsymbol{\pi}) + \ln \mathrm{Dir}(\boldsymbol{\pi} | \boldsymbol{\alpha}) + \mathrm{const.} \\ &= \sum_{n=1}^N \ln \left( \prod_{k=1}^K \pi_k^{s_{n,k}} \right) + \ln \left\{ C_D(\boldsymbol{\alpha}) \prod_{k=1}^K \pi_k^{\alpha_k-1} \right\} + \mathrm{const.} \\ &= \sum_{n=1}^N \sum_{k=1}^K s_{n,k} \ln \pi_k + \ln C_D(\boldsymbol{\alpha}) + \sum_{k=1}^K (\alpha_k - 1) \ln \pi_k + \mathrm{const.} \\ &= \sum_{k=1}^K \left( \sum_{n=1}^N s_{n,k} + \alpha_k - 1 \right) \ln \pi_k + \mathrm{const.} \tag{3.26} \end{align} $$

【途中式の途中式】

  1. 式(3.25)の下から2行目の式を用いている。
  2. 対数をとると積が和に変わる。また、$\boldsymbol{\pi}$に影響しない項($- \ln p(\boldsymbol{S})$)を$\mathrm{const.}$とおく。
  3. $\mathrm{Cat}(\boldsymbol{s}_n | \boldsymbol{\pi}),\ \mathrm{Dir}(\boldsymbol{\pi} | \boldsymbol{\alpha})$に、それぞれカテゴリ分布の定義式(2.29)、ディリクレ分布の定義式(2.48)を用いて具体的な確率分布を代入する。
  4. それぞれ対数をとる。$\ln (x^a y^b) = a \ln x + b \ln y$である。
  5. $\boldsymbol{\pi}$に影響しない項($\ln C_D(\boldsymbol{\alpha})$)を$\mathrm{const.}$にまとめる。
  6. $\ln \pi_k$の項をまとめて式を整理する。


となる。

 式(3.26)について

$$ \hat{\alpha}_k = \sum_{n=1}^N s_{n,k} + \alpha_k \tag{3.28} $$

とおき

$$ \ln p(\boldsymbol{\pi} | \boldsymbol{S}) = \sum_{k=1}^K (\hat{\alpha}_k - 1) \ln \pi_k + \mathrm{const.} $$

更に$\ln$を外して、$\mathrm{const.}$を正規化項に置き換える(正規化する)と

$$ p(\boldsymbol{\pi} | \boldsymbol{S}) = C_D(\hat{\boldsymbol{\alpha}}) \prod_{k=1}^K \pi_k^{\hat{\alpha}_k-1} = \mathrm{Dir}(\boldsymbol{\pi} | \hat{\boldsymbol{\alpha}}) \tag{3.27} $$

となる。式の形状から、事後分布$p(\boldsymbol{\pi} | \boldsymbol{S})$がパラメータ$\hat{\boldsymbol{\alpha}}$を持つディリクレ分布であることが確認できる。
 また事後分布のパラメータの計算式(3.28)によって、ハイパーパラメータ$\boldsymbol{\alpha}$が更新される。

・予測分布の計算

 続いて未観測のデータ$\boldsymbol{s}_{*} = {s_{*,1}, s_{*,2}, \cdots, s_{*,K}}$に対する予測分布を求めていく。

 (観測データによる学習を行っていない)事前分布$p(\boldsymbol{\pi} | \boldsymbol{\alpha})$を用いて、パラメータ$\boldsymbol{\pi}$を周辺化することで予測分布$p(\boldsymbol{s}_{*})$となる。

$$ \begin{aligned} p(\boldsymbol{s}_{*}) &= \int p(\boldsymbol{s}_{*} | \boldsymbol{\pi}) p(\boldsymbol{\pi} | \boldsymbol{\alpha}) d\boldsymbol{\pi} \\ &= \int \mathrm{Cat}(\boldsymbol{s}_{*} | \boldsymbol{\pi}) \mathrm{Dir}(\boldsymbol{\pi} | \boldsymbol{\alpha}) d\boldsymbol{\pi} \\ &= \int \prod_{k=1}^K \pi_k^{s_{*,k}} C_D(\boldsymbol{\alpha}) \pi_k^{\alpha_k-1} d\boldsymbol{\pi} \\ &= C_D(\boldsymbol{\alpha}) \int \prod_{k=1}^K \pi_k^{s_{*,k}+\alpha_k-1} d\boldsymbol{\pi} \end{aligned} $$

 積分部分に注目すると、パラメータ$s_{*,k} + \alpha_k$を持つ正規化項のないディリクレ分布の形をしている。よって、ディリクレ分布の定義式(2.48)を用いて

$$ \int \prod_{k=1}^K \pi_k^{s_{*,k}+\alpha_k-1} d\boldsymbol{\pi} = \frac{1}{C_D((s_{*,k} + \alpha_k)_{k=1}^K)} $$

正規化項の逆数に変形できる(そもそもこの部分の逆数が正規化項である)。これを先ほどの式に代入すると、予測分布$p(\boldsymbol{s}_{*})$は

$$ p(\boldsymbol{s}_{*}) = \frac{ C_D(\boldsymbol{\alpha}) }{ C_D\Bigl((s_{*,k} + \alpha_k)_{k=1}^K\Bigr) } \tag{3.29} $$

となる。更にディリクレ分布の正規化項(2.49)を用いると

$$ \begin{align} p(\boldsymbol{s}_{*}) &= C_D(\boldsymbol{\alpha}) \frac{ 1 }{ C_D\Bigl((s_{*,k} + \alpha_k)_{k=1}^K\Bigr) } \\ &= \frac{ \Gamma(\sum_{k=1}^K \alpha_k) }{ \prod_{k=1}^K \Gamma(\alpha_k) } \frac{ \prod_{k=1}^K \Gamma(s_{*,k} + \alpha_k) }{ \Gamma(\sum_{k=1}^K s_{*,k} + \alpha_k) } \tag{3.30} \end{align} $$

となる。

 ある$k'$に対して$s_{*,k'} = 1$となる場合のみを取り出して考えてみると、$\sum_{k=1}^K s_{*,k} = 1$より

$$ \begin{align} p(s_{*,k'} = 1) &= \frac{ \Gamma(\sum_{k=1}^K \alpha_k) }{ \Gamma(\alpha_{k'}) } \frac{ \Gamma(s_{*,k'} + \alpha_{k'}) }{ \Gamma(\sum_{k=1}^K s_{*,k} + \alpha_k) } \\ &= \frac{ \Gamma(\sum_{k=1}^K \alpha_k) }{ \Gamma(\alpha_{k'}) } \frac{ \Gamma(1 + \alpha_{k'}) }{ \Gamma(1 + \sum_{k=1}^K \alpha_k) } \\ &= \frac{ \Gamma(\sum_{k=1}^K \alpha_k) }{ \Gamma(\alpha_{k'}) } \frac{ \alpha_{k'} \Gamma(\alpha_{k'}) }{ (\sum_{k=1}^K \alpha_k) \Gamma(\sum_{k=1}^K \alpha_k) } \\ &= \frac{\alpha_{k'}}{\sum_{k=1}^K \alpha_k} \tag{3.31} \end{align} $$

となる。

 従って、$s_{*,k} = 1$のときそれ以外の$s_{*,1}, \cdots, s_{*,k'-1}, s_{*,k'+1}, \cdots, s_{*,K}$は全て0であるため$x^0 = 1$であることを利用して、$s_{*,1}$から$s_{*,K}$までをまとめると式(3.30)は

$$ p(s_{*,k}) = \prod_{k=1}^K \Bigl( \frac{\alpha_{k}}{\sum_{k’=1}^K \alpha_{k'}} \Bigr)^{s_{*,k}} $$

と書き換えられる。(分母の$k'$はこれまでのある$k'$とは別物。あくまで分子の$k$と区別するためのもの。本では$i$と表記することで区別している。)

 この式について

$$ \boldsymbol{\pi}_{*} = (\pi_{*,1}, \pi_{*,2}, \cdots, \pi_{*,K}),\ \pi_{*,k} = \frac{\alpha_{k}}{\sum_{k'=1}^K \alpha_{k'}} $$

とおくと、予測分布$p(s_{*,k})$は

$$ p(s_{*,k}) = \prod_{k=1}^K \pi_{*,k}^{s_{*,k}} = \mathrm{Cat}(\boldsymbol{s}_{*} | \boldsymbol{\pi}_{*}) \tag{3.32} $$

となる。式の形状から、$p(x_{*})$はパラメータ$\boldsymbol{\pi}_{*}$を持つカテゴリ分布であることが分かる。
 ちなみに$\frac{\alpha_{k}}{\sum_{k’=1}^K \alpha_{k'}}$は、ディリクレ分布$\mathrm{Dir}(\boldsymbol{\pi} | \boldsymbol{\alpha})$の期待値$\mathbb{E}[\pi_k]$である。

 予測分布のパラメータ$\boldsymbol{\pi}_{*}$を構成する$\alpha_k$について、事後分布のパラメータ(3.28)を用いると

$$ \begin{aligned} \hat{\pi}_{*,k} &= \frac{ \hat{\alpha}_k }{ \sum_{k'=1}^K \hat{\alpha}_{k'} } \\ &= \frac{ \sum_{n=1}^N s_{n,k} + \alpha_k }{ \sum_{k'=1}^K \sum_{n=1}^N s_{n,k'} + \alpha_{k'} } \end{aligned} $$

となり、(観測データ$\boldsymbol{S}$によって学習した)事後分布$p(\boldsymbol{\pi} | \boldsymbol{S})$を用いた予測分布$p(s_{*,k} | \boldsymbol{S})$

$$ p(s_{*,k} | \boldsymbol{S}) = \prod_{k=1}^K \hat{\pi}_{*,k}^{s_{*,k}} = \mathrm{Cat}(\boldsymbol{s}_{*} | \hat{\boldsymbol{\pi}}_{*}) \tag{3.32} $$

が得られる。

・Rでやってみよう

 各確率分布に従いランダムに生成したデータを用いて、パラメータを推定してみましょう。

# 利用パッケージ
library(tidyverse)

 利用するパッケージを読み込みます。

・事後分布
## パラメーターの初期値を指定
# 観測モデルのパラメータ
pi_k_truth <- c(0.3, 0.5, 0.2)

# 事前分布のパラメータ
alpha_k <- c(2, 2, 2)

# 試行回数
N <- 50

 このプログラムでは三角座標で作図するため、各パラメータの次元数(K)は3で固定です。

 データの発生確率$\boldsymbol{\pi} = (\pi_1, \pi_2, \pi_3)$をpi_k_truthとします。この値を推定するのが目的です。

 事前分布のパラメータ(ハイパーパラメータの初期値)$\boldsymbol{\alpha} = (\alpha_1, \alpha_2, \alpha_3)$をalpha_kとします。

 生成するデータ数$N$をNとします。

# 作図用のpiの値を満遍なく生成
pi <- tibble(
  pi_1 = rep(rep(seq(0, 1, by = 0.02), times = 51), times = 51), 
  pi_2 = rep(rep(seq(0, 1, by = 0.02), each = 51), times = 51), 
  pi_3 = rep(seq(0, 1, by = 0.02), each = 2601)
)

# 正規化
pi <- pi / apply(pi, 1, sum)

# 重複した組み合わせを除去(ハイスぺ機なら不要…)
pi <- pi %>% 
  mutate(pi_1 = round(pi_1, 3), pi_2 = round(pi_2, 3), pi_3 = round(pi_3, 3)) %>% 
  count(pi_1, pi_2, pi_3) %>% 
  select(-n) %>% 
  as.matrix()

 $\pi$が取り得る0から1までの値を3次元分用意します。ただし重複した組み合わせなく値を作るには少し工夫がいります。

 まずは各引数の作用を確認しましょう。seq()byは最小値(第1引数)から最大値(第2引数)までを刻む幅です。rep()の第1引数にベクトル(複数の値)を渡した場合、timesはベクトルを繰り返す回数、eachはベクトルの各要素を複製する個数です。
 またbyに指定した値によって他の引数に指定すべき値が変わります。byに指定した値をnとすると、seq()によって、1/n+1個(+1は0の分)の要素のベクトルが返ってきます。この1/n+1を他の引数に指定します。
 ただし、データフレームでいう3列目(サンプルコードでいうと5行目)のeachは他で指定した値の2乗の値((1/n+1)2)になります。
 入れ子関係や使う引数自体にも注意しましょう。

 簡単な例で確認すると次のようにしたいわけです。

tibble(
  v1 = rep(rep(seq(1, 2), times = 2), times = 2), 
  v2 = rep(rep(seq(1, 2), each = 2), times = 2), 
  v3 = rep(seq(1, 2), each = 4)
)
## # A tibble: 8 x 3
##      v1    v2    v3
##   <int> <int> <int>
## 1     1     1     1
## 2     2     1     1
## 3     1     2     1
## 4     2     2     1
## 5     1     1     2
## 6     2     1     2
## 7     1     2     2
## 8     2     2     2

何言ってんだ??となった場合は、次のランダムに点を打つ方法を使いましょう。

 用意した値は、列ごとの総和が1となるように総和で割って正規化します。

 私の環境だと、データが多くて作図時に固まってしまったので、重複した箇所を間引くことにしました…。
 この処理をしない場合は、正規化の前後どちらかでas.matrix(pi)の処理を加えてマトリクスに変換しておく必要があります。

# 作図用のpiの値をランダムに生成
pi <- matrix(
  sample(seq(0, 1, 0.01), size = 90000, replace = TRUE), 
  nrow = 3
)

# 正規化
pi <- pi / apply(pi, 1, sum)

 満遍なく点で埋めつくす設定は少しややこしいため、ランダムに点を生成してもいいです。

 点の数が少なすぎると疎らになり、多すぎると処理が重くなります。点の数はsize引数に指定する値で調整できます。ただしマトリクスに欠損値が出ないように、値を3の倍数にする必要があるので注意してください。

 こちらのやり方でも、各行の値の和が1となるように正規化します。

# カテゴリ分布に従うデータを生成
s_nk <- rmultinom(n = N, size = 1, prob = pi_k_truth) %>% 
  t()

 多項分布に従う乱数を発生させる関数rmultinom()を使って、ランダムにデータを生成します。これをカテゴリ分布とするためには、size引数に1を指定します。
 試行回数の引数nにはN、確率の引数probにはpi_k_truthを指定します。
 試行ごとの結果を列としたマトリクスが返ってくるので、t()で転置して式と合わせます。

 サンプルを確認してみましょう。

# 観測データを確認
apply(s_nk, 2, sum)
## [1] 13 27 10

このデータを用いて事後分布のパラメータを計算します。

# 事後分布のパラメータを計算
alpha_k_hat <- apply(s_nk, 2, sum) + alpha_k

 事後分布のパラメータの計算式(3.28)の計算を行い、$\hat{\alpha}_k$をalpha_kとします。

# 事後分布を計算
posterior_df <- tibble(
  x = pi[, 2] + (pi[, 3] / 2),  # 三角座標への変換
  y = sqrt(3) * (pi[, 3] / 2),  # 三角座標への変換
  C_D = lgamma(sum(alpha_k_hat)) - sum(lgamma(alpha_k_hat)),  # 正規化項(対数)
  density = exp(C_D + apply((alpha_k_hat - 1) * log(t(pi)), 2, sum)) # 確率密度
)

 最初に用意したpiを使って各値の確率密度を計算します。ただし3次元の情報を2次元の図で表現するために、三角座標に変換します。(よく解ってないゆえ解説なし)\  piの各値に対してディリクレ分布定義式(2.48)の計算を行い、確率密度を求めます。ただし値が大きくなるとgamma()で計算できなくなるため、対数をとって計算することにします。なので最後にexp()で値を戻します。

 計算結果は作図用にデータフレームにまとめておきます。推定結果を確認してみましょう。

head(posterior_df)
## # A tibble: 6 x 4
##       x     y   C_D density
##   <dbl> <dbl> <dbl>   <dbl>
## 1 0.5   0.866  57.7       0
## 2 0.51  0.849  57.7       0
## 3 0.510 0.848  57.7       0
## 4 0.511 0.847  57.7       0
## 5 0.511 0.846  57.7       0
## 6 0.512 0.845  57.7       0

このデータフレームを用いてグラフを描きます。

# piの真の値のプロット用データフレームを作成
pi_truth_df <- tibble(
  x = pi_k_truth[2] + (pi_k_truth[3] / 2),  # 三角座標への変換
  y = sqrt(3) * (pi_k_truth[3] / 2),        # 三角座標への変換
)

 最初に設定した$\boldsymbol{\pi}$の値を推定できているのか確認するために、真の値の位置もプロットしましょう。
 pi_k_truthの値も三角座標に変換します。ggplot()にはデータフレームで渡す必要があるため、データフレームで保存しておきます。

# 描画
ggplot() + 
  geom_point(data = posterior_df, aes(x, y, color = density)) + # 散布図
  geom_point(data = pi_truth_df, aes(x, y), shape = 3, size = 5) + # piの真の値
  scale_color_gradientn(colors = c("blue", "green", "yellow", "red")) + # プロットの色
  scale_x_continuous(breaks = c(0, 1), 
                     labels = c("(1, 0, 0)", "(0, 1, 0)")) + # x軸目盛
  scale_y_continuous(breaks = c(0, 0.87), 
                     labels = c("(1, 0, 0)", "(0, 1, 0)")) + # y軸目盛
  coord_fixed(ratio = 1) + # 縦横比
  labs(title = "Dirichlet Distribution", 
       subtitle = paste0("N=", N, ", alpha=(", paste(alpha_k_hat, collapse = ", "), ")"), 
       x = expression(paste(pi[1], ", ", pi[2], sep = "")), 
       y = expression(paste(pi[1], ", ", pi[3], sep = ""))) # ラベル

 ggplot2パッケージを利用してグラフを作成します。

 推定値は散布図geom_point()を用いて描きます。
 パラメータ$\pi$の真の値もgeom_point()で示します。

f:id:anemptyarchive:20200305015746p:plain
$\pi$の事後分布:ディリクレ分布


・予測分布
# 予測分布のパラメータを計算
pi_k_hat <- alpha_k_hat / sum(alpha_k_hat)

 予測分布のパラメータの計算式の計算を行い、$\hat{\pi}_k$をpi_k_hatとします。

pi_k_hat <- (apply(s_nk, 2, sum) + alpha_k) / sum(apply(s_nk, 2, sum) + alpha_k)

 このように、観測データs_nkと事前分布のパラメータalpha_kを使って計算することもできます。

# 作図用のsの値
s_sk <- matrix(c(1, 0, 0, 0, 1, 0, 0, 0, 1), ncol = 3)

 $s_{*,k}$が取り得る値を用意します。

# 予測分布を計算
predict_df <- tibble(
  k = seq(1, 3),  # 作図用の値
  prob = apply(pi_k_hat^s_sk, 1, prod) # 確率
)

 kの各値となる確率は、カテゴリ分布の定義式(2.29)で計算します。

 推定結果を確認してみましょう。

head(predict_df)
## # A tibble: 3 x 2
##       k  prob
##   <int> <dbl>
## 1     1 0.268
## 2     2 0.518
## 3     3 0.214

 こちらもggplot2パッケージを利用して作図します。

# 作図
ggplot(predict_df, aes(k, prob)) + 
  geom_bar(stat = "identity", position = "dodge", fill = "#56256E") + # 棒グラフ
  labs(title = "Categorical Distribution", 
       subtitle = paste0("N=", N, ", pi_hat=(", paste(round(pi_k_hat, 2), collapse = ", "), ")")) # ラベル

 棒グラフはgeom_bar()を使います。

f:id:anemptyarchive:20200305015852p:plain
$\boldsymbol{s}_{*}$の予測分布:カテゴリ分布


・おまけ

 gganimateパッケージを利用して、パラメータの推定値の推移のgif画像を作成するためのコードです。

・コード(クリックで展開)

# 利用パッケージ
library(tidyverse)
library(gganimate)

## パラメーターの初期値を指定
# 観測モデルのパラメータ
pi_k_truth <- c(0.3, 0.5, 0.2)

# 事前分布のパラメータ
alpha_k <- c(2, 2, 2)

# 試行回数
N <- 50

# 作図用のpiの値
pi <- tibble(
  pi_1 = rep(rep(seq(0, 1, by = 0.025), times = 41), times = 41), 
  pi_2 = rep(rep(seq(0, 1, by = 0.025), each = 41), times = 41), 
  pi_3 = rep(seq(0, 1, by = 0.025), each = 1681)
)

# 正規化
pi <- pi / apply(pi, 1, sum)

# 重複した組み合わせを除去(ハイスぺ機なら不要…)
pi <- pi %>% 
  mutate(pi_1 = round(pi_1, 3), pi_2 = round(pi_2, 3), pi_3 = round(pi_3, 3)) %>% 
  count(pi_1, pi_2, pi_3) %>% 
  select(-n) %>% 
  as.matrix()

# 事前分布を計算
posterior_df <- tibble(
  x = pi[, 2] + (pi[, 3] / 2),  # 三角座標への変換
  y = sqrt(3) * (pi[, 3] / 2),  # 三角座標への変換
  C_D = lgamma(sum(alpha_k)) - sum(lgamma(alpha_k)),  # 正規化項(対数)
  density = exp(C_D + apply((alpha_k - 1) * log(t(pi)), 2, sum)), # 確率密度
  N = 0 # 試行回数
)

# 初期値による予測分布のパラメーターを計算
pi_k_hat <- alpha_k / sum(alpha_k)

# 初期値による予測分布を計算
predict_df <- tibble(
  k = seq(1, 3),  # 作図用の値
  prob = apply(pi_k_hat^s_sk, 1, prod),  # 確率
  N = 0 # 試行回数
)

# パラメーターを推定
s_nk <- matrix(0, nrow = N, ncol = 3) # 受け皿
for(n in 1:N){
  
  # カテゴリ分布に従うデータを生成
  s_nk[n, ] <- rmultinom(n = 1, size = 1, prob = pi_k_truth) %>% 
    as.vector()
  
  # ハイパーパラメータを更新
  alpha_k <- s_nk[n, ] + alpha_k
  
  # 事後分布を計算
  tmp_posterior_df <- tibble(
    x = pi[, 2] + (pi[, 3] / 2),  # 三角座標への変換
    y = sqrt(3) * (pi[, 3] / 2),  # 三角座標への変換
    C_D = lgamma(sum(alpha_k)) - sum(lgamma(alpha_k)),  # 正規化項(対数)
    density = exp(C_D + apply((alpha_k - 1) * log(t(pi)), 2, sum)), # 確率密度
    N = n # 試行回数
  )
  
  # 予測分布のパラメーターを計算
  pi_k_hat <- alpha_k / sum(alpha_k)
  
  # 予測分布を計算
  tmp_predict_df <- tibble(
    k = seq(1, 3),  # 作図用の値
    prob = apply(pi_k_hat^s_sk, 1, prod),  # 確率
    N = n # 試行回数
  )
  
  # 結果を結合
  posterior_df <- rbind(posterior_df, tmp_posterior_df)
  predict_df <- rbind(predict_df, tmp_predict_df)
}

# 観測データを確認
apply(s_nk, 2, sum)

# piの真の値のプロット用データフレームを作成
pi_truth_df <- tibble(
  x = pi_k_truth[2] + (pi_k_truth[3] / 2),  # 三角座標への変換
  y = sqrt(3) * (pi_k_truth[3] / 2),  # 三角座標への変換
  N = seq(0, N)
)

## 事後分布
# 作図
posterior_graph <- ggplot() + 
  geom_point(data = posterior_df, aes(x, y, color = density)) + # 散布図
  geom_point(data = pi_truth_df, aes(x, y), shape = 3, size = 5) + # piの真の値
  scale_color_gradientn(colors = c("blue", "green", "yellow", "red")) + # プロットの色
  scale_x_continuous(breaks = c(0, 1), 
                     labels = c("(1, 0, 0)", "(0, 1, 0)")) + # x軸目盛
  scale_y_continuous(breaks = c(0, 0.87), 
                     labels = c("(1, 0, 0)", "(0, 1, 0)")) + # y軸目盛
  coord_fixed(ratio = 1) + # 縦横比
  transition_manual(N) + # フレーム
  labs(title = "Dirichlet Distribution", 
       subtitle = "N= {current_frame}", 
       x = expression(paste(pi[1], ", ", pi[2], sep = "")), 
       y = expression(paste(pi[1], ", ", pi[3], sep = ""))) # ラベル

# 描画
animate(posterior_graph)

## 予測分布
# 作図
predict_graph <- ggplot(predict_df, aes(k, prob)) + 
  geom_bar(stat = "identity", position = "dodge", fill = "#56256E") + # 棒グラフ
  transition_manual(N) + # フレーム
  labs(title = "Categorical Distribution", 
       subtitle = "N= {current_frame}") # ラベル

# 描画
animate(predict_graph)

 異なる点のみを簡単に解説します。

 各データによってどのように学習する(推定値が変化する)のかを確認するため、こちらはforループで1データずつ処理します。
 よって生成するデータ数として設定したNがイタレーション数になります。

 パラメータの推定値について、$\hat{\alpha}_k$に対応するalpha_k_hatを新たに作るのではなく、alpha_kをイタレーションごとに更新(上書き)していきます。
 それに伴い、事後分布のパラメータの計算式(3.28)の$\sum_{n=1}^N$の計算は、forループによってN回繰り返しs_nk[n, ]を加えることで行います。n回目のループ処理のときには、n-1回分のs_nk[n, ]が既にalpha_kに加えられているわけです。

f:id:anemptyarchive:20200305020119g:plain
$\pi$の事後分布の推移

f:id:anemptyarchive:20200305020238g:plain
$\boldsymbol{s}_{*}$の予測分布の推移


参考文献

  • 須山敦志『ベイズ推論による機械学習入門』(機械学習スタートアップシリーズ)杉山将監修,講談社,2017年.

おわりに

 {gganimate}の詳しい使い方はいつかまた別の機会に…

【次節の内容】

www.anarchive-beta.com

2020/03/04:加筆修正しました。