からっぽのしょこ

読んだら書く!書いたら読む!読書読読書読書♪同じ事は二度調べ(たく)ない

4.3.3:ポアソン混合モデルの変分推論【緑ベイズ入門のノート】

はじめに

 『ベイズ推論による機械学習入門』の学習時のノートです。基本的な内容は「数式の行間を読んでみた」とそれを「Rで組んでみた」になります。「数式」と「プログラム」から理解するのが目標です。

 この記事は4.3.3項の内容になります。ポアソン混合モデルにおける変分推論を導出し、Rで実装します。

 省略してある内容等ありますので、本と併せて読んでください。初学者な自分が理解できるレベルまで落として書き下していますので、分かる人にはかなりくどくなっています。同じような立場の人のお役に立てれば幸いです。

【前節の内容】

www.anarchive-beta.com

【他の節一覧】

www.anarchive-beta.com

【この節の内容】

4.3.3 変分推論

 ポアソン分布に対する変分推論アルゴリズムを導出する。

・潜在変数の近似事後分布の更新式の導出

 潜在変数$\boldsymbol{S}$とパラメータ$\boldsymbol{\lambda,\ \boldsymbol{\pi}}$の事後分布について

$$ p(\boldsymbol{S}, \boldsymbol{\lambda, \boldsymbol{\pi}} | \boldsymbol{X}) \approx q(\boldsymbol{S}) q(\boldsymbol{\lambda, \boldsymbol{\pi}}) \tag{4.46} $$

分解近似の仮定をおき、変分推論の公式(4.25)を用いると

$$ \begin{align*} \ln q(\boldsymbol{S}) &= \mathbb{E}_{q(\boldsymbol{\lambda}, \boldsymbol{\pi})} [ \ln p(\boldsymbol{S}, \boldsymbol{\lambda, \boldsymbol{\pi}} | \boldsymbol{X}) ] + \mathrm{const.} \\ &= \mathbb{E}_{q(\boldsymbol{\lambda}, \boldsymbol{\pi})} \left[ \ln \frac{ p(\boldsymbol{X}, \boldsymbol{S}, \boldsymbol{\lambda}, \boldsymbol{\pi}) }{ p(\boldsymbol{X}) } \right] + \mathrm{const.} \\ &= \mathbb{E}_{q(\boldsymbol{\lambda}, \boldsymbol{\pi})} [ \ln p(\boldsymbol{X}, \boldsymbol{S}, \boldsymbol{\lambda}, \boldsymbol{\pi}) ] + \mathrm{const.} \\ &= \mathbb{E}_{q(\boldsymbol{\lambda}, \boldsymbol{\pi})} \Bigl[ \ln p(\boldsymbol{X} | \boldsymbol{S}, \boldsymbol{\lambda}) + \ln p(\boldsymbol{S} | \boldsymbol{\pi}) + \ln p(\boldsymbol{\lambda}) + \ln p(\boldsymbol{\pi}) \Bigr] + \mathrm{const.} \\ &= \mathbb{E}_{q(\boldsymbol{\lambda})} [ \ln p(\boldsymbol{X} | \boldsymbol{S}, \boldsymbol{\lambda}) ] + \mathbb{E}_{q(\boldsymbol{\pi})} [ \ln p(\boldsymbol{S} | \boldsymbol{\pi}) ] \\ &\qquad + \mathbb{E}_{q(\boldsymbol{\lambda})} [ \ln p(\boldsymbol{\lambda}) ] + \mathbb{E}_{q(\boldsymbol{\pi})} [ \ln p(\boldsymbol{\pi}) ] + \mathrm{const.} \\ &= \sum_{n=1}^N \Bigl\{ \mathbb{E}_{q(\boldsymbol{\lambda})} [ \ln p(x_n | \boldsymbol{s}_n, \boldsymbol{\lambda}) ] + \mathbb{E}_{q(\boldsymbol{\pi})} [ \ln p(\boldsymbol{s}_n | \boldsymbol{\pi}) ] \Bigr\} + \mathrm{const.} \tag{4.47} \end{align*} $$

となる。$\boldsymbol{S}$と無関係な項は適宜$\mathrm{const.}$に含めている。

 この式の具体的な計算を行っていく。前の項は

$$ \begin{align*} \mathbb{E}_{q(\boldsymbol{\lambda})} [ \ln p(x_n | \boldsymbol{s}_n, \boldsymbol{\lambda}) ] &= \mathbb{E}_{q(\lambda_k)} \left[ \ln \prod_{k=1}^K \mathrm{Poi}(x_n | \lambda_k)^{s_{n,k}} \right] \\ &= \sum_{k=1}^K \mathbb{E}_{q(\lambda_k)} \Bigl[ s_{n,k} \ln \frac{\lambda_k^{x_n}}{x_n!} \exp(- \lambda_k) \Bigr] \\ &= \sum_{k=1}^K s_{n,k} \mathbb{E}_{q(\lambda_k)} \Bigl[ x_n \ln \lambda_k - \lambda_k \Bigr] + \mathrm{const.} \\ &= \sum_{k=1}^K s_{n,k} \Bigl( x_n \mathbb{E}_{q(\lambda_k)} [ \ln \lambda_k ] - \mathbb{E}_{q(\lambda_k)} [ \lambda_k ] \Bigr) + \mathrm{const.} \tag{4.48} \end{align*} $$

となり、後の項は

$$ \begin{align*} \mathbb{E}_{q(\boldsymbol{\pi})} [ \ln p(\boldsymbol{s}_n | \boldsymbol{\pi}) ] &= \mathbb{E}_{q(\boldsymbol{\pi})} [ \ln \mathrm{Cat}(\boldsymbol{s}_n | \boldsymbol{\pi}) ] \\ &= \mathbb{E}_{q(\boldsymbol{\pi})} \left[ \ln \prod_{k=1}^K \pi_k^{s_{n,k}} \right] \\ &= \sum_{k=1}^K s_{n,k} \mathbb{E}_{q(\boldsymbol{\pi})} [ \ln \pi_k ] \tag{4.49} \end{align*} $$

となる。

 よって2つの式を式(4.47)に代入すると

$$ \ln q(\boldsymbol{S}) = \sum_{n=1}^N \sum_{k=1}^K s_{n,k} \Bigl( x_n \mathbb{E}_{q(\lambda_k)} [ \ln \lambda_k ] - \mathbb{E}_{q(\lambda_k)} [ \lambda_k ] + \mathbb{E}_{q(\boldsymbol{\pi})} [ \ln \pi_k ] \Bigr) + \mathrm{const.} $$

となる。

 この式について

$$ \eta_{n,k} \propto x_n \mathbb{E}_{q(\lambda_k)} [ \ln \lambda_k ] - \mathbb{E}_{q(\lambda_k)} [ \lambda_k ] + \mathbb{E}_{q(\boldsymbol{\pi})} [ \ln \pi_k ] ,\ \sum_{k=1}^K \eta_{n,k} = 1 \tag{4.51} $$

とおくと、$\boldsymbol{s}_n$の近似事後分布$q(\boldsymbol{s}_n)$は式の形から

$$ \ln q(\boldsymbol{S}) = \sum_{n=1}^N \sum_{k=1}^K s_{n,k} \ln \eta_{n,k} = \ln \prod_{n=1}^N \mathrm{Cat}(\boldsymbol{s}_n | \boldsymbol{\eta}_n) \tag{4.50} $$

パラメータ$\boldsymbol{\eta}_n$を持つカテゴリ分布になることが分かる。ここで$\boldsymbol{\eta}_n = (\eta_{n,1}, \cdots, \eta_{n,K})$とする。
 $\eta_{n,k}$の計算式(4.51)については、他の分布が明らかになってから詳しくみる。

・パラメータの近似事後分布の更新式の導出

 続いて、パラメータ$\boldsymbol{\lambda},\ \boldsymbol{\pi}$についても同様に変分推論の公式(4.25)を用いて

$$ \begin{align*} \ln q(\boldsymbol{\lambda}, \boldsymbol{\pi}) &= \mathbb{E}_{q(\boldsymbol{S})} \Bigl[ \ln p(\boldsymbol{S}, \boldsymbol{\lambda}, \boldsymbol{\pi} | \boldsymbol{X}) \Bigr] + \mathrm{const.} \\ &= \mathbb{E}_{q(\boldsymbol{S})} \left[ \ln \frac{ p(\boldsymbol{X}, \boldsymbol{S}, \boldsymbol{\lambda}, \boldsymbol{\pi}) }{ p(\boldsymbol{X}) } \right] + \mathrm{const.} \\ &= \mathbb{E}_{q(\boldsymbol{S})} \Bigl[ \ln p(\boldsymbol{X}, \boldsymbol{S}, \boldsymbol{\lambda}, \boldsymbol{\pi}) \Bigr] + \mathrm{const.} \\ &= \mathbb{E}_{q(\boldsymbol{S})} \Bigl[ \ln p(\boldsymbol{X} | \boldsymbol{S}, \boldsymbol{\lambda}) + \ln p(\boldsymbol{\lambda}) + \ln p(\boldsymbol{S} | \boldsymbol{\pi}) + \ln p(\boldsymbol{\pi}) \Bigr] + \mathrm{const.} \\ &= \mathbb{E}_{q(\boldsymbol{S})} \Bigl[ \ln p(\boldsymbol{X} | \boldsymbol{S}, \boldsymbol{\lambda}) \Bigr] + \ln p(\boldsymbol{\lambda}) \\ &\qquad + \mathbb{E}_{q(\boldsymbol{S})} \Bigl[ \ln p(\boldsymbol{S} | \boldsymbol{\pi}) \Bigr] + \ln p(\boldsymbol{\pi}) + \mathrm{const.} \tag{4.52} \end{align*} $$

となる。

 ここから$\boldsymbol{\lambda}$に関係する項を取り出して、$q(\boldsymbol{\lambda})$の具体的な計算を行っていく。

$$ \begin{align*} \ln q(\boldsymbol{\lambda}) &= \mathbb{E}_{q(\boldsymbol{s}_n)} \left[ \ln \prod_{n=1}^N p(x_n | \boldsymbol{s}_n, \boldsymbol{\lambda}) \right] + \ln \prod_{k=1}^K p(\lambda_k) + \mathrm{const.} \\ &= \sum_{n=1}^N \mathbb{E}_{q(\boldsymbol{s}_n)} \left[ \ln \prod_{k=1}^K \mathrm{Poi}(x_n | \lambda_k)^{s_{n,k}} \right] + \sum_{k=1}^K \ln \mathrm{Gam}(\lambda_k | a, b) + \mathrm{const.} \\ &= \sum_{k=1}^K \left\{ \sum_{n=1}^N \mathbb{E}_{q(\boldsymbol{s}_n)} \left[ s_{n,k} \ln \frac{\lambda_k^{x_n}}{x_n!} \exp(- \lambda_k) \right] + \ln C_G(a, b) \lambda_k^{a-1} \exp(- b \lambda_k) \right\} + \mathrm{const.} \\ &= \sum_{k=1}^K \left\{ \sum_{n=1}^N \mathbb{E}_{q(\boldsymbol{s}_n)} \Bigl[ s_{n,k} ( x_n \ln \lambda_k - \lambda_k ) \Bigr] + (a - 1) \ln \lambda_k - b \lambda_k \right\} + \mathrm{const.} \\ &= \sum_{k=1}^K \left\{ \Bigl( \sum_{n=1}^N \mathbb{E}_{q(\boldsymbol{s}_n)} [s_{n,k}] x_n + a - 1 \Bigr) \ln \lambda_k - \Bigl( \sum_{n=1}^N \mathbb{E}_{q(\boldsymbol{s}_n)} [s_{n,k}] + b \Bigr) \lambda_k \right\} + \mathrm{const.} \tag{4.53} \end{align*} $$

 この式について

$$ \begin{align*} \hat{a}_k &= \sum_{n=1}^N \mathbb{E}_{q(\boldsymbol{s}_n)} [s_{n,k}] x_n + a \\ \hat{b}_k &= \sum_{n=1}^N \mathbb{E}_{q(\boldsymbol{s}_n)} [s_{n,k}] + b \tag{4.55} \end{align*} $$

とおくと、$\lambda_k$の近似事後分布$q(\lambda_k)$は

$$ \ln q(\boldsymbol{\lambda}) = \sum_{k=1}^K \Bigl\{ \hat{a}_k \ln \lambda_k - \hat{b}_k \lambda_k \Bigr\} + \mathrm{const.} = \ln \prod_{k=1}^K \mathrm{Gam}(\lambda_k | \hat{a}_k, \hat{b}_k) \tag{4.54} $$

パラメータ$\hat{a}_k,\ \hat{b}_k$を持つガンマ分布になることが分かる。

 $q(\boldsymbol{\pi})$についても同様に、式(4.52)から$\boldsymbol{\pi}$に関係する項を取り出すと

$$ \begin{align*} q(\boldsymbol{\pi}) &= \mathbb{E}_{q(\boldsymbol{s}_n)} \left[ \ln \prod_{n=1}^N p(\boldsymbol{s}_n | \boldsymbol{\pi}) \right] + \ln p(\boldsymbol{\pi}) \\ &= \sum_{n=1}^N \mathbb{E}_{q(\boldsymbol{s}_n)} \left[ \ln \mathrm{Cat}(\boldsymbol{s}_n | \boldsymbol{\pi}) \right] + \ln \mathrm{Dir}(\boldsymbol{\pi} | \boldsymbol{\alpha}) \\ &= \sum_{n=1}^N \mathbb{E}_{q(\boldsymbol{s}_n)} \left[ \ln \prod_{k=1}^K \pi_k^{s_{n,k}} \right] + \ln C_D(\boldsymbol{\alpha}) \prod_{k=1}^K \pi_k^{\alpha_k-1} \\ &= \sum_{k=1}^K \left\{ \sum_{n=1}^N \mathbb{E}_{q(\boldsymbol{s}_n)} [ s_{n,k} \ln \pi_k ] + (\alpha_k - 1) \ln \pi_k \right\} + \mathrm{const.} \\ &= \sum_{k=1}^K \left( \sum_{n=1}^N \mathbb{E}_{q(\boldsymbol{s}_n)} [s_{n,k}] + \alpha_k - 1 \right) \ln \pi_k + \mathrm{const.} \tag{4.56} \end{align*} $$

となる。

 この式について

$$ \hat{\alpha}_k = \sum_{n=1}^N \mathbb{E}_{q(\boldsymbol{s}_n)} [s_{n,k}] + \alpha_k \tag{4.58} $$

とおくと、$\boldsymbol{\pi}$の近似事後分布$q(\boldsymbol{\pi})$は

$$ q(\boldsymbol{\pi}) = \sum_{k=1}^K \hat{\alpha}_k \ln \pi_k + \mathrm{const.} = \mathrm{Dir}(\boldsymbol{\pi} | \hat{\boldsymbol{\alpha}}) \tag{4.57} $$

パラメータ$\hat{\boldsymbol{\alpha}}$を持つディリクレ分布になることが分かる。ここで$\hat{\boldsymbol{\alpha}} = (\hat{\alpha}_1, \cdots, \hat{\alpha}_K)$とする。
 $\hat{\alpha}_k$の計算式(4.58)について、$q(\boldsymbol{s}_n) = \mathrm{Cat}(\boldsymbol{s}_n | \boldsymbol{\eta}_n)$より、カテゴリ分布の期待値(2.31)を用いて

$$ \mathbb{E}_{q(\boldsymbol{s}_n)} [s_{n,k}] = \eta_{n,k} \tag{4.59} $$

となる。

 最後に$\eta_{n,k}$の計算式(4.51)の各項について確認する。$q(\lambda_k) = \mathrm{Gam}(\lambda_k | \hat{a}_k, \hat{b}_k)$、$q(\boldsymbol{\pi}) = \mathrm{Dir}(\boldsymbol{\pi} | \hat{\boldsymbol{\alpha}})$なので、それぞれガンマ分布の期待値(2.59)、(2.60)、ディリクレ分布の期待値(2.52)より

$$ \begin{align*} \mathbb{E}_{q(\lambda_k)} [\lambda_k] &= \frac{\hat{a}_k}{\hat{b}_k} \tag{4.60} \\ \mathbb{E}_{q(\lambda_k)} [\ln \lambda_k] &= \psi(\hat{a}_k) - \ln \hat{b}_k \tag{4.61} \\ \mathbb{E}_{q(\boldsymbol{\pi})} [\ln \pi_k] &= \psi(\hat{\alpha}_k) - \psi \left(\sum_{k'=1}^K \hat{\alpha}_{k'} \right) \tag{4.62} \end{align*} $$

で計算できる。

・Rでやってみる

 では、手を動かしてプログラムからアルゴリズムの理解を深めましょう。アルゴリズム4.2を参考に実装します。

# 利用パッケージ
library(tidyverse)

 必要なパッケージを読み込みます。

・真の観測モデルの設定
# (観測)データ数を指定
N <- 100

# 真のパラメータを指定
lambda_truth <- c(5, 25)
pi_truth <- c(0.3, 0.7)

# クラスタ数
K <- length(lambda_truth)

# クラスタ(潜在変数)を生成
s_nk <- rmultinom(n =  N, size = 1, prob = pi_truth) %>% 
  t()

# (観測)データXを生成
x_n <- rpois(n = N, lambda = apply(lambda_truth^t(s_nk), 2, prod))

 詳細は4.3.2項を参照ください。生成したデータを用いてlambda_truthpi_truthに設定した値を推定します。

 観測データを確認してみましょう。

# 観測データを確認
summary(x_n)
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
##    2.00    8.00   23.00   18.93   26.25   39.00

 ヒストグラムにすると次のようになります。

tibble(x = x_n) %>% 
  ggplot(aes(x = x)) + 
    geom_bar(fill = "#56256E") + 
    labs(title = "Histogram")

f:id:anemptyarchive:20200423155922p:plain
観測データ$\boldsymbol{X}$のヒストグラム

・パラメータの設定
# 試行回数
MaxIter <- 50

 繰り返し回数を指定します。

# ハイパーパラメータa,bの初期値を指定
a <- 1
b <- 1

 $\boldsymbol{\lambda}$のパラメータ$a,\ b$の値を指定します。

# lambda(の期待値)の初期値をランダムに設定
tmp_lambda <- seq(0, 1, by = 0.01) %>% 
  sample(size = K, replace = TRUE)
E_lambda_k    <- tmp_lambda / sum(tmp_lambda) # 正規化
E_ln_lambda_k <- log(E_lambda_k) # 対数をとる

 $\boldsymbol{\lambda}$の値をランダムに設定します。ギブスサンプリングではa, bを用いて乱数を生成しましたが、こちらは分布に関係なくランダムに値を決めます。

 ここでは期待値ではありませんが、対数をとったものとそれぞれ変数名をE_lambda_kE_ln_lambda_kとしておきます。

# ハイパーパラメータalphaの初期値を指定
alpha_k <- rep(2, K)

 $\boldsymbol{\pi}$のパラメータ$\boldsymbol{\alpha}$の値を指定します。

# pi(の対数をとった期待値)の初期値をランダムに設定
tmp_pi <- seq(0, 1, by = 0.01) %>% 
  sample(size = K, replace = TRUE)
E_ln_pi_k <- tmp_pi / sum(tmp_pi) %>% # 正規化
  log() # 対数をとる

 $\boldsymbol{\lambda}$と同様に$\boldsymbol{\pi}$もランダムに設定します。

・変分推論

・コード全体(クリックで展開)

# 受け皿を用意
eta_nk <- matrix(0, nrow = N, ncol = K)
E_s_nk <- matrix(0, nrow = N, ncol = K)
hat_a_k <- rep(0, K)
hat_b_k <- rep(0, K)

# ハイパーパラメータの推定値の推移の確認用
trace_a <- matrix(0, nrow = MaxIter + 1, ncol = K)
trace_b <- matrix(0, nrow = MaxIter + 1, ncol = K)
trace_alpha <- matrix(0, nrow = MaxIter + 1, ncol = K)
# 初期値を代入
trace_a[1, ] <- a
trace_b[1, ] <- b
trace_alpha[1, ] <- alpha_k

for(i in 1:MaxIter) {
  
  for(n in 1:N) {
    
    # パラメータeta_nを計算:式(4.51)
    tmp_eta <- exp(x_n[n] * E_ln_lambda_k - E_lambda_k + E_ln_pi_k)
    eta_nk[n, ] <- tmp_eta / sum(tmp_eta) # 正規化
    
    # s_nkの期待値を計算:式(4.59)
    E_s_nk[n, ] <- eta_nk[n, ]
  }
  
  for(k in 1:K) {
    
    # ハイパーパラメータa_hat_k,b_hat_kを計算:式(4.55)
    hat_a_k[k] <- sum(E_s_nk[, k] * x_n) + a
    hat_b_k[k] <- sum(E_s_nk[, k]) + b
    
    # (対数をとった)lambda_kの期待値を計算:式(4.60),(4.61)
    E_lambda_k[k] <- hat_a_k[k] / hat_b_k[k]
    E_ln_lambda_k[k] <- digamma(hat_a_k[k]) - log(hat_b_k[k])
  }
  
  # ハイパーパラメータalpha_hatを計算:式(4.58)
  hat_alpha_k <- apply(E_s_nk, 2, sum) + alpha_k
  
  # 対数をとったpiの期待値を計算:式(4.62)
  E_ln_pi_k <- digamma(hat_alpha_k) - digamma(sum(hat_alpha_k))
  
  # 推移の確認用に推定結果を保存
  trace_a[i + 1, ] <- hat_a_k
  trace_b[i + 1, ] <- hat_b_k
  trace_alpha[i + 1, ] <- hat_alpha_k
  
}

 添字を使って代入するために、予め変数を用意しておく必要があります。
 また各パラメータの更新値の推移を確認するため、学習する度に値を保存しておきます。


・$q(\boldsymbol{S})$の更新

for(n in 1:N) {
  
  # パラメータeta_nを計算:式(4.51)
  tmp_eta <- exp(x_n[n] * E_ln_lambda_k - E_lambda_k + E_ln_pi_k)
  eta_nk[n, ] <- tmp_eta / sum(tmp_eta) # 正規化
  
  # s_nkの期待値を計算:式(4.59)
  E_s_nk[n, ] <- eta_nk[n, ]
}

 式(4.51)の計算を行い、$\boldsymbol{\eta}_n$を求めます。

 eta_nk[n, ]の値が$\mathbb{E}[\boldsymbol{s}_n]$となります(式(4.59))。

 この処理を1から$N$まで繰り返すことで、近似事後分布$q(\boldsymbol{S})$を更新できます。

・$q(\boldsymbol{\lambda})$の更新

for(k in 1:K) {
  
  # ハイパーパラメータa_hat_k,b_hat_kを計算:式(4.55)
  hat_a_k[k] <- sum(E_s_nk[, k] * x_n) + a
  hat_b_k[k] <- sum(E_s_nk[, k]) + b
  
  # (対数をとった)lambda_kの期待値を計算:式(4.60),(4.61)
  E_lambda_k[k] <- hat_a_k[k] / hat_b_k[k]
  E_ln_lambda_k[k] <- digamma(hat_a_k[k]) - log(hat_b_k[k])
}

 式(4.55)の計算を行い、$\hat{a}_k,\ \hat{b}_k$を求めます。

 hat_a_k[k], hat_b_k[k]を使って式(4.60)、(4.61)の計算を行い、$\mathbb{E}_{q(\lambda_k)}[\lambda_k]$と$\mathbb{E}_{q(\lambda_k)}[\ln \lambda_k]$を求めます。

 この処理を1から$K$まで繰り返すことで、近似事後分布$q(\boldsymbol{\lambda})$を更新できます。

・$q(\boldsymbol{\pi})$の更新

# ハイパーパラメータalpha_hatを計算:式(4.58)
hat_alpha_k <- apply(E_s_nk, 2, sum) + alpha_k

# 対数をとったpiの期待値を計算:式(4.62)
E_ln_pi_k <- digamma(hat_alpha_k) - digamma(sum(hat_alpha_k))

 式(4.58)の計算を行い、$\hat{\boldsymbol{\alpha}}$を求めます。

 hat_alpha_kを使って式(4.62)の計算を行い、$\mathbb{E}_{q(\boldsymbol{\pi})}[\ln \pi_k]$を求めます。

 近似事後分布$q(\boldsymbol{\pi})$については1度に全ての要素を更新するため、繰り返し処理は行いません。

・推定結果の確認

 ggplot2パッケージを利用して、パラメータ$\boldsymbol{\lambda},\ \boldsymbol{\pi}$の事後分布を可視化します。

・コード(クリックで展開)

## lambdaの近似事後分布
# 作図用のデータフレームを作成
lambda_df <- tibble()
for(k in 1:K) {
  # データフレームに変換
  tmp_lambda_df <- tibble(
    lambda = seq(0, max(x_n), by = 0.01), 
    density = dgamma(lambda, shape = hat_a_k[k], rate = hat_b_k[k]), 
    cluster = as.factor(k)
  )
  # 結合
  lambda_df <- rbind(lambda_df, tmp_lambda_df)
}

# 作図
ggplot(lambda_df, aes(lambda, density, color = cluster)) + 
  geom_line() + # 折れ線グラフ
  scale_color_manual(values = c("#00A968", "orange")) + # グラフの色(不必要)
  geom_vline(xintercept = lambda_truth, color = "pink", linetype = "dashed") + # 垂直線
  labs(title = "Poisson Mixture Model:Variational Inference", 
       subtitle = paste0("a_hat=(", paste0(round(hat_a_k, 1), collapse = ", "), 
                         "), b_hat=(", paste0(round(hat_b_k, 1), collapse = ", "), ")")) # ラベル

 scale_color_manual()は使わない方が無難です(私の趣味の色です)。色を指定する場合は、クラスタ数$K$と同じ数の色を指定する必要があります。

f:id:anemptyarchive:20200423162200p:plain
$q(\boldsymbol{\lambda})$:ガンマ分布

・コード(クリックで展開)

## piの近似事後分布(K=2のときのみ可能)
# 作図用のデータフレームを作成
pi_df <- tibble()
for(k in 1:K) {
  # データフレームに変換
  tmp_pi_df <- tibble(
    pi = seq(0, 1, by = 0.001), 
    density = dbeta(pi, shape1 = hat_alpha_k[k], shape2 = hat_alpha_k[2 / k]), 
    cluster = as.factor(k)
  )
  # 結合
  pi_df <- rbind(pi_df, tmp_pi_df)
}

# 作図
ggplot(pi_df, aes(pi, density, color = cluster)) + 
  geom_line() + # 折れ線グラフ
  scale_color_manual(values = c("#00A968", "orange")) + # グラフの色(不必要)
  geom_vline(xintercept = pi_truth, color = "pink", linetype = "dashed") + # 垂直線
  labs(title = "Poisson Mixture Model:Variational Inference", 
       subtitle = paste0("alpha_hat=(", paste0(round(hat_alpha_k, 1), collapse = ", "), ")")) # ラベル

 $\boldsymbol{\pi}$については、このコードでグラフ化できるのは$K = 2$のときだけになります。(詳しくは3.2.2:カテゴリ分布の学習と予測【緑ベイズ入門のノート】 - からっぽのしょこをご参照ください。)

f:id:anemptyarchive:20200423162338p:plain
$q(\boldsymbol{\pi})$:ディリクレ分布

・推移の確認

・事後分布の推移

 gganimateパッケージを利用して、分布の推移をgif画像として出力するためのコードです。

・コード(クリックで展開)

# 利用パッケージ
library(gganimate)

## lambdaの近似事後分布
# 作図用のデータフレームを作成
trace_lambda_long <- tibble()
for(i in 1:(MaxIter + 1)) {
  for(k in 1:K) {
    # データフレームに変換
    tmp_lambda_df <- tibble(
      lambda = seq(0, max(x_n), by = 0.01), 
      density = dgamma(lambda, shape = trace_a[i, k], rate = trace_b[i, k]), 
      cluster = as.factor(k), 
      Iteration = i - 1
    )
    # 結合
    trace_lambda_long <- rbind(trace_lambda_long, tmp_lambda_df)
  }
}

# 作図
graph_lambda <- ggplot(trace_lambda_long, aes(lambda, density, color = cluster)) + 
  geom_line() + # 折れ線グラフ
  scale_color_manual(values = c("#00A968", "orange")) + # グラフの色(不必要)
  geom_vline(xintercept = lambda_truth, color = "pink", linetype = "dashed") + # 垂直線
  transition_manual(Iteration) + # フレーム
  labs(title = "Poisson Mixture Model:Variational Inference", 
       subtitle = "i={current_frame}") # ラベル
  
# gif画像を作成
animate(graph_lambda, nframes = MaxIter + 1, fps = 5)

 gganimate::transition_manual()に時系列を指定して、gganimate::animate()でgifファイルに変換します。フレーム数は初期値の0を含むためイタレーション数+1となります。

f:id:anemptyarchive:20200423161512g:plain
$q(\boldsymbol{\lambda})$の推移

・コード(クリックで展開)

## piの近似事後分布(K=2のときのみ可)
# 作図用のデータフレームを作成
trace_pi_long <- tibble()
for(i in 1:(MaxIter + 1)) {
  for(k in 1:K) {
    # データフレームに変換
    tmp_pi_df <- tibble(
      pi = seq(0, 1, by = 0.001), 
      density = dbeta(pi, shape1 = trace_alpha[i, k], shape2 = trace_alpha[i, 2 / k]), 
      cluster = as.factor(k), 
      Iteration = i - 1
    )
    # 結合
    trace_pi_long <- rbind(trace_pi_long, tmp_pi_df)
  }
}

# 作図
graph_pi <- ggplot(trace_pi_long, aes(pi, density, color = cluster)) + 
  geom_line() + # 折れ線グラフ
  scale_color_manual(values = c("#00A968", "orange")) + # グラフの色(不必要)
  geom_vline(xintercept = pi_truth, color = "pink", linetype = "dashed") + # 垂直線
  transition_manual(Iteration) + # フレーム
  labs(title = "Poisson Mixture Model:Variational Inference", 
       subtitle = "i={current_frame}") # ラベル

# gif画像を作成
animate(graph_pi, nframes = MaxIter + 1, fps = 5)

f:id:anemptyarchive:20200423161636g:plain
$q(\boldsymbol{\pi})$の推移

・ハイパーパラメータの推移

・コード(クリックで展開)

## lambdaのパラメータa
# データフレームに変換
trace_a_wide <- cbind(
  as.data.frame(trace_a), 
  Iteration = 1:(MaxIter + 1)
)

# long型に変換
trace_a_long <- pivot_longer(
  trace_a_wide, 
  cols = -Iteration, 
  names_to = "cluster", 
  names_prefix = "V", 
  names_ptypes = list(cluster = factor()), 
  values_to = "value"
)

# 作図
ggplot(trace_a_long, aes(Iteration, value, color = cluster)) + 
  geom_line() + 
  labs(title = "Poisson Mixture Model:Variational Inference", 
       subtitle = expression(hat(a)))


## lambdaのパラメータb
# データフレームに変換
trace_b_wide <- cbind(
  as.data.frame(trace_a), 
  Iteration = 1:(MaxIter + 1)
)

# long型に変換
trace_b_long <- pivot_longer(
  trace_b_wide, 
  cols = -Iteration, 
  names_to = "cluster", 
  names_prefix = "V", 
  names_ptypes = list(cluster = factor()), 
  values_to = "value"
)

# 作図
ggplot(trace_b_long, aes(Iteration, value, color = cluster)) + 
  geom_line() + 
  labs(title = "Poisson Mixture Model:Variational Inference", 
       subtitle = expression(hat(b)))

f:id:anemptyarchive:20200423162959p:plainf:id:anemptyarchive:20200423163017p:plain
$\hat{a},\ \hat{b}$の推移

・コード(クリックで展開)

## piのパラメータ
# データフレームに変換
trace_alpha_wide <- cbind(
  as.data.frame(trace_alpha), 
  Iteration = 1:(MaxIter + 1)
)

# long型に変換
trace_alpha_long <- pivot_longer(
  trace_alpha_wide, 
  cols = -Iteration, 
  names_to = "cluster", 
  names_prefix = "V", 
  names_ptypes = list(cluster = factor()), 
  values_to = "value"
)

# 作図
ggplot(trace_alpha_long, aes(Iteration, value, color = cluster)) + 
  geom_line() + 
  labs(title = "Poisson Mixture Model:Variational Inference", 
       subtitle = expression(hat(alpha)))

f:id:anemptyarchive:20200423163205p:plain
$\hat{\alpha}$の推移

参考文献

  • 須山敦志『ベイズ推論による機械学習入門』(機械学習スタートアップシリーズ)杉山将監修,講談社,2017年.

おわりに

 前回に引き続きこの記事もOsaka.Rのリモート朝もく会中に書いたものになります。

Osaka.Rが主催する毎平日8時から10時のもくもく会です。 参加するには下記URLからOsaka.Rのslackに登録してください。 #mokumoku チャンネルのトピックを見るともくもく会場への案内をご覧頂けます。

https://join.slack.com/t/osakar/shared_invite/zt-dgjyfztf-AVYDIx~P8Ncl6deigOOarA

イベントとしても公開していますので、そちらに登録いただいてもかまいません。

朝のちょっとした時間を活かしませんか? 途中参加途中離脱OK。 チャットでその日の課題を宣言してもくもく開始。困ったことがあればチャットや画面共有で協力し合いましょう。 大阪府外の方も大歓迎です!

osaka-r.connpass.com

暫く開催されるとのことですので、(Rユーザーの)皆様ぜひぜひ一緒に参加しましょー。私は毎日参加するつもりです!

【次節の内容】

www.anarchive-beta.com