からっぽのしょこ

読んだら書く!書いたら読む!同じ事は二度調べ(たく)ない

4.3.3:ポアソン混合モデルにおける推論:変分推論【緑ベイズ入門のノート】

はじめに

 『ベイズ推論による機械学習入門』の学習時のノートです。基本的な内容は「数式の行間を読んでみた」とそれを「RとPythonで組んでみた」になります。「数式」と「プログラム」から理解するのが目標です。

 この記事は、4.3.3項の内容です。「観測モデルをポアソン混合モデル」、「事前分布をガンマ分布」とする混合モデルを変分推論を用いて推論します。

 省略してある内容等ありますので、本とあわせて読んでください。初学者な自分が理解できるレベルまで落として書き下していますので、分かる人にはかなりくどくなっています。同じような立場の人のお役に立てれば幸いです。

【前節の内容】

www.anarchive-beta.com

【他の節一覧】

www.anarchive-beta.com

【この節の内容】

4.3.3 変分推論

 変分推論を用いて、ポアソン混合モデルの事後分布$p(\mathbf{S}, \boldsymbol{\lambda}, \boldsymbol{\pi} | \mathbf{X})$の近似分布$q(\mathbf{S}, \boldsymbol{\lambda}, \boldsymbol{\pi})$を導出する。

 観測データ$\mathbf{X}$が与えられた下での、潜在変数$\mathbf{S}$、観測モデルのパラメータ$\boldsymbol{\lambda}$、混合比率パラメータ$\boldsymbol{\pi}$の事後分布$p(\mathbf{S}, \boldsymbol{\lambda}, \boldsymbol{\pi} | \mathbf{X})$に対して、分解近似の仮定をおいた

$$ p(\mathbf{S}, \boldsymbol{\lambda}, \boldsymbol{\pi} | \mathbf{X}) \approx q(\mathbf{S}, \boldsymbol{\lambda}, \boldsymbol{\pi}) = q(\mathbf{S}) q(\boldsymbol{\lambda}) q(\boldsymbol{\pi}) \tag{4.46} $$

で近似する。$q(\mathbf{S}, \boldsymbol{\lambda}, \boldsymbol{\pi})$を近似事後分布、または変分事後分布と呼ぶ。

・潜在変数の近似事後分布の導出

 始めに、潜在変数$\mathbf{S}$の近似事後分布$q(\mathbf{S})$を求めていく。

 $\mathbf{S}$の近似事後分布は、事後分布$p(\mathbf{S}, \boldsymbol{\lambda}, \boldsymbol{\pi} | \mathbf{X})$と$q(\boldsymbol{\lambda}, \boldsymbol{\pi})$を固定した近似分布$q(\mathbf{S}, \boldsymbol{\lambda}, \boldsymbol{\pi})$に対して、4.2.2項「変分推論」で求めた変分推論の公式(4.25)を用いて

$$ \begin{align} \ln q(\mathbf{S}) &= \mathbb{E}_{q(\boldsymbol{\lambda}, \boldsymbol{\pi})} \Bigl[ \ln p(\mathbf{S}, \boldsymbol{\lambda}, \boldsymbol{\pi} | \mathbf{X}) \Bigr] + \mathrm{const.} \\ &= \mathbb{E}_{q(\boldsymbol{\lambda}, \boldsymbol{\pi})} \left[ \ln \frac{ p(\mathbf{X}, \mathbf{S}, \boldsymbol{\lambda}, \boldsymbol{\pi}) }{ p(\mathbf{X}) } \right] + \mathrm{const.} \\ &= \mathbb{E}_{q(\boldsymbol{\lambda}, \boldsymbol{\pi})} \Bigl[ \ln p(\mathbf{X} | \mathbf{S}, \boldsymbol{\lambda}) + \ln p(\mathbf{S} | \boldsymbol{\pi}) + \ln p(\boldsymbol{\lambda}) + \ln p(\boldsymbol{\pi}) - \ln p(\mathbf{X}) \Bigr] + \mathrm{const.} \\ &= \mathbb{E}_{q(\boldsymbol{\lambda}, \boldsymbol{\pi})} \Bigl[ \ln p(\mathbf{X} | \mathbf{S}, \boldsymbol{\lambda}) \Bigr] + \mathbb{E}_{q(\boldsymbol{\lambda}, \boldsymbol{\pi})} \Bigl[ \ln p(\mathbf{S} | \boldsymbol{\pi}) \Bigr] + \mathbb{E}_{q(\boldsymbol{\lambda}, \boldsymbol{\pi})} \Bigl[ \ln p(\boldsymbol{\lambda}) \Bigr] + \mathbb{E}_{q(\boldsymbol{\lambda}, \boldsymbol{\pi})} \Bigl[ \ln p(\boldsymbol{\pi}) \Bigr] - \mathbb{E}_{q(\boldsymbol{\lambda}, \boldsymbol{\pi})} \Bigl[ \ln p(\mathbf{X}) \Bigr] + \mathrm{const.} \\ &= \mathbb{E}_{q(\boldsymbol{\lambda})} \left[ \sum_{n=1}^N \ln p(\mathbf{x}_n | \mathbf{s}_n, \boldsymbol{\lambda}) \right] + \mathbb{E}_{q(\boldsymbol{\pi})} \left[ \sum_{n=1}^N \ln p(\mathbf{s}_n | \boldsymbol{\pi}) \right] + \mathrm{const.} \\ &= \sum_{n=1}^N \Bigl\{ \mathbb{E}_{q(\boldsymbol{\lambda})} \Bigl[ \ln p(x_n | \mathbf{s}_n, \boldsymbol{\lambda}) \Bigr] + \mathbb{E}_{q(\boldsymbol{\pi})} \Bigl[ \ln p(\mathbf{s}_n | \boldsymbol{\pi}) \Bigr] \Bigr\} + \mathrm{const.} \tag{4.47}\\ &= \sum_{n=1}^N \left\{ \mathbb{E}_{q(\boldsymbol{\lambda})} \left[ \sum_{k=1}^K \ln \mathrm{Poi}(x_n | \lambda_k)^{s_{n,k}} \right] + \mathbb{E}_{q(\boldsymbol{\pi})} \Bigl[ \ln \mathrm{Cat}(\mathbf{s}_n | \boldsymbol{\pi}) \Bigr] \right\} + \mathrm{const.} \end{align} $$

で求められる。4.3.1項「ポアソン混合モデル」で確認した各変数の生成過程(依存関係)に従い項を分解している。また、適宜$\mathbf{S}$に影響しない項を$\mathrm{const.}$にまとめて比例関係に注目する。省略した部分については、最後に正規化することで対応できる。
 連続値の期待値の定義$\mathbb{E}_{q(\boldsymbol{\lambda})}[\boldsymbol{\lambda}] = \int q(\boldsymbol{\lambda}) \boldsymbol{\lambda} d\boldsymbol{\lambda}$より、$\boldsymbol{\lambda}$に影響しない項は$\mathbb{E}_{q(\boldsymbol{\lambda})}[\cdot]$の外に出せる。また、連続値の確率分布の定義より$\int q(\boldsymbol{\lambda}) d\boldsymbol{\lambda} = 1$なので、期待値の括弧内($\int$の中)の項がなくなると1となり消える。$\boldsymbol{\pi}$についても同様である。

 $n$番目の潜在変数(ある1つのデータのクラスタ)$\mathbf{s}_n$の近似事後分布の具体的な形状を明らかにしていく。前の項は

$$ \begin{align} \mathbb{E}_{q(\boldsymbol{\lambda})} \Bigl[ \ln p(x_n | \mathbf{s}_n, \boldsymbol{\lambda}) \Bigr] &= \mathbb{E}_{q(\lambda_k)} \left[ \sum_{k=1}^K s_{n,k} \ln \mathrm{Poi}(x_n | \lambda_k) \right] \\ &= \sum_{k=1}^K s_{n,k} \mathbb{E}_{q(\lambda_k)} \left[ \ln \frac{\lambda_k^{x_n}}{x_n!} \exp(- \lambda_k) \right] \\ &= \sum_{k=1}^K s_{n,k} \mathbb{E}_{q(\lambda_k)} \Bigl[ x_n \ln \lambda_k - \ln x_n! - \lambda_k \Bigr] \\ &= \sum_{k=1}^K s_{n,k} \Bigl( x_n \mathbb{E}_{q(\lambda_k)} [ \ln \lambda_k ] - \mathbb{E}_{q(\lambda_k)} [ \lambda_k ] \Bigr) + \mathrm{const.} \tag{4.48} \end{align} $$

となる。$\sum_{k=1}^K s_{n,k} = 1$なので、$\sum_{k=1}^K - s_{n,k} \ln x_n! = - \ln x_n!$となり$\mathbf{s}_n$の影響を受けなくなるので$\mathrm{const.}$に含める。

 後の項は

$$ \begin{align} \mathbb{E}_{q(\boldsymbol{\pi})} \Bigl[ \ln p(\mathbf{s}_n | \boldsymbol{\pi}) \Bigr] &= \mathbb{E}_{q(\boldsymbol{\pi})} \Bigl[ \ln \mathrm{Cat}(\mathbf{s}_n | \boldsymbol{\pi}) \Bigr] \\ &= \mathbb{E}_{q(\boldsymbol{\pi})} \left[ \ln \prod_{k=1}^K \pi_k^{s_{n,k}} \right] \\ &= \mathbb{E}_{q(\boldsymbol{\pi})} \left[ \sum_{k=1}^K s_{n,k} \ln \pi_k \right] \\ &= \sum_{k=1}^K s_{n,k} \mathbb{E}_{q(\boldsymbol{\pi})} [ \ln \pi_k ] \tag{4.49} \end{align} $$

となる。

 よって、式(4.48)と式(4.49)を$n$番目のデータに関する項を取り出した式(4.47)に代入すると

$$ \begin{align} \ln q(\mathbf{s}_n) &= \mathbb{E}_{q(\boldsymbol{\lambda})} \Bigl[ \ln p(x_n | \mathbf{s}_n, \boldsymbol{\lambda}) \Bigr] + \mathbb{E}_{q(\boldsymbol{\pi})} \Bigl[ \ln p(\mathbf{s}_n | \boldsymbol{\pi}) \Bigr] + \mathrm{const.} \tag{4.47'}\\ &= \sum_{k=1}^K s_{n,k} \Bigl( x_n \mathbb{E}_{q(\lambda_k)} [ \ln \lambda_k ] - \mathbb{E}_{q(\lambda_k)} [ \lambda_k ] \Bigr) + \sum_{k=1}^K s_{n,k} \mathbb{E}_{q(\boldsymbol{\pi})} [ \ln \pi_k ] + \mathrm{const.} \\ &= \sum_{k=1}^K s_{n,k} \Bigl( x_n \mathbb{E}_{q(\lambda_k)} [ \ln \lambda_k ] - \mathbb{E}_{q(\lambda_k)} [ \lambda_k ] + \mathbb{E}_{q(\boldsymbol{\pi})} [ \ln \pi_k ] \Bigr) + \mathrm{const.} \end{align} $$

となる。

 この式について

$$ \eta_{n,k} \propto \exp \Bigl\{ x_n \mathbb{E}_{q(\lambda_k)} [ \ln \lambda_k ] - \mathbb{E}_{q(\lambda_k)} [ \lambda_k ] + \mathbb{E}_{q(\boldsymbol{\pi})} [ \ln \pi_k ] \Bigr\} \tag{4.51} $$

とおき

$$ \ln q(\mathbf{s}_n) = \sum_{k=1}^K s_{n,k} \ln \eta_{n,k} + \mathrm{const.} $$

さらに$\ln$を外し、$\sum_{k=1}^K \eta_{n,k} = 1$となるように正規化する($\mathrm{const.}$を正規化項に置き換える)と

$$ q(\mathbf{s}_n) = \prod_{k=1}^K \eta_{n,k}^{s_{n,k}} = \mathrm{Cat}(\mathbf{s}_n | \boldsymbol{\eta}_n) \tag{4.50} $$

$\mathbf{s}_n$の近似事後分布は、パラメータ$\boldsymbol{\eta}_n = (\eta_{n,1}, \eta_{n,2}, \cdots, \eta_{n,K})$を持つカテゴリ分布になることが分かる。

 $\eta_{n,k}$の計算式(更新式)(4.51)については、$q(\lambda_k),\ q(\boldsymbol{\pi})$の形状を明らかにしてから確認する。

・パラメータの近似事後分布の導出

 次に、パラメータ$\boldsymbol{\lambda},\ \boldsymbol{\pi}$の(同時)近似事後分布$q(\boldsymbol{\lambda}, \boldsymbol{\pi})$から、各パラメータの近似事後分布$q(\boldsymbol{\lambda}),\ q(\boldsymbol{\pi})$を求めていく。

 $\boldsymbol{\lambda},\ \boldsymbol{\pi}$の近似事後分布は、事後分布$p(\mathbf{S}, \boldsymbol{\lambda}, \boldsymbol{\pi} | \mathbf{X})$と$q(\mathbf{S})$を固定した近似分布$q(\mathbf{S}, \boldsymbol{\lambda}, \boldsymbol{\pi})$に対して、変分推論の公式(4.25)を用いて

$$ \begin{align} \ln q(\boldsymbol{\lambda}, \boldsymbol{\pi}) &= \mathbb{E}_{q(\mathbf{S})} \Bigl[ \ln p(\mathbf{S}, \boldsymbol{\lambda}, \boldsymbol{\pi} | \mathbf{X}) \Bigr] + \mathrm{const.} \\ &= \mathbb{E}_{q(\mathbf{S})} \left[ \ln \frac{ p(\mathbf{X}, \mathbf{S}, \boldsymbol{\lambda}, \boldsymbol{\pi}) }{ p(\mathbf{X}) } \right] + \mathrm{const.} \\ &= \mathbb{E}_{q(\mathbf{S})} \Bigl[ \ln p(\mathbf{X} | \mathbf{S}, \boldsymbol{\lambda}) + \ln p(\boldsymbol{\lambda}) + \ln p(\mathbf{S} | \boldsymbol{\pi}) + \ln p(\boldsymbol{\pi}) - \ln p(\mathbf{X}) \Bigr] + \mathrm{const.} \\ &= \mathbb{E}_{q(\mathbf{S})} \Bigl[ \ln p(\mathbf{X} | \mathbf{S}, \boldsymbol{\lambda}) \Bigr] + \mathbb{E}_{q(\mathbf{S})} \Bigl[ \ln p(\boldsymbol{\lambda}) \Bigr] + \mathbb{E}_{q(\mathbf{S})} \Bigl[ \ln p(\mathbf{S} | \boldsymbol{\pi}) \Bigr] + \mathbb{E}_{q(\mathbf{S})} \Bigl[ \ln p(\boldsymbol{\pi}) \Bigr] - \mathbb{E}_{q(\mathbf{S})} \Bigl[ \ln p(\mathbf{X}) \Bigr] + \mathrm{const.} \\ &= \mathbb{E}_{q(\mathbf{S})} \Bigl[ \ln p(\mathbf{X} | \mathbf{S}, \boldsymbol{\lambda}) \Bigr] + \ln p(\boldsymbol{\lambda}) + \mathbb{E}_{q(\mathbf{S})} \Bigl[ \ln p(\mathbf{S} | \boldsymbol{\pi}) \Bigr] + \ln p(\boldsymbol{\pi}) + \mathrm{const.} \tag{4.52}\\ &= \mathbb{E}_{q(\mathbf{S})} \left[ \sum_{n=1}^N \ln p(\mathbf{x}_n | \mathbf{s}_n, \boldsymbol{\lambda}) \right] + \sum_{k=1}^K \ln p(\lambda_k) + \mathbb{E}_{q(\mathbf{S})} \left[ \sum_{n=1}^N \ln p(\mathbf{s}_n | \boldsymbol{\pi}) \right] + \ln p(\boldsymbol{\pi}) + \mathrm{const.} \\ &= \sum_{n=1}^N \mathbb{E}_{q(\mathbf{s}_n)} \left[ \sum_{k=1}^K \ln \mathrm{Poi}(x_n | \lambda_k)^{s_{n,k}} \right] + \sum_{k=1}^K \ln \mathrm{Gam}(\lambda_k | a, b) + \sum_{n=1}^N \mathbb{E}_{q(\mathbf{s}_n)} \Bigl[ \ln \mathrm{Cat}(\mathbf{s}_n | \boldsymbol{\pi}) \Bigr] + \ln \mathrm{Dir}(\boldsymbol{\pi} | \boldsymbol{\alpha}) + \mathrm{const.} \end{align} $$

で求められる。こちらも生成過程に従い項を分解して、$\boldsymbol{\lambda},\ \boldsymbol{\pi}$に影響しない項を省く。$\mathbf{S}$と無関係な項は$\mathbb{E}_{q(\mathbf{S})} [\cdot]$の外に出せる。

 また、左辺の(対数をとった同時)近似事後分布は

$$ \ln q(\boldsymbol{\lambda}, \boldsymbol{\pi}) = \ln q(\boldsymbol{\lambda}) + \ln q(\boldsymbol{\pi}) $$

と分解できる。

 この式を用いて、$\boldsymbol{\lambda},\ \boldsymbol{\pi}$それぞれの近似事後分布の具体的な形状を明らかにしていく。

・観測モデルのパラメータの近似事後分布

 式(4.52)を$\boldsymbol{\lambda}$に関して整理する($\boldsymbol{\lambda}$に影響しない項を$\mathrm{const.}$にまとめる)と

$$ \begin{align} \ln q(\boldsymbol{\lambda}) &= \mathbb{E}_{q(\mathbf{S})} \Bigl[ \ln p(\mathbf{X} | \mathbf{S}, \boldsymbol{\lambda}) \Bigr] + \ln p(\boldsymbol{\lambda}) - \ln q(\boldsymbol{\pi}) + \mathrm{const.} \\ &= \sum_{n=1}^N \mathbb{E}_{q(\mathbf{s}_n)} \left[ \sum_{k=1}^K s_{n,k} \ln \mathrm{Poi}(x_n | \lambda_k) \right] + \sum_{k=1}^K \ln \mathrm{Gam}(\lambda_k | a, b) + \mathrm{const.} \\ &= \sum_{n=1}^N \sum_{k=1}^K \mathbb{E}_{q(\mathbf{s}_n)} \left[ s_{n,k} \ln \frac{\lambda_k^{x_n}}{x_n!} \exp(- \lambda_k) \right] + \sum_{k=1}^K \ln C_G(a, b) \lambda_k^{a-1} \exp(- b \lambda_k) + \mathrm{const.} \\ &= \sum_{k=1}^K \left\{ \sum_{n=1}^N \mathbb{E}_{q(\mathbf{s}_n)} [s_{n,k}] (x_n \ln \lambda_k - \ln x_n! - \lambda_k) + \ln C_G(a, b) + (a - 1) \ln \lambda_k - b \lambda_k \right\} + \mathrm{const.} \\ &= \sum_{k=1}^K \left\{ \left( \sum_{n=1}^N \mathbb{E}_{q(\mathbf{s}_n)} [s_{n,k}] x_n + a - 1 \right) \ln \lambda_k - \left( \sum_{n=1}^N \mathbb{E}_{q(\mathbf{s}_n)} [s_{n,k}] + b \right) \lambda_k \right\} + \mathrm{const.} \tag{4.53} \end{align} $$

となる。$\ln q(\boldsymbol{\pi})$は左辺から移項したものである。

 式(4.53)について

$$ \begin{aligned} \hat{a}_k &= \sum_{n=1}^N \mathbb{E}_{q(\mathbf{s}_n)} [s_{n,k}] x_n + a \\ \hat{b}_k &= \sum_{n=1}^N \mathbb{E}_{q(\mathbf{s}_n)} [s_{n,k}] + b \end{aligned} \tag{4.55} $$

とおき

$$ \ln q(\boldsymbol{\lambda}) = \sum_{k=1}^K \Bigl\{ (\hat{a}_k - 1) \ln \lambda_k - \hat{b}_k \lambda_k \Bigr\} + \mathrm{const.} $$

さらに$\ln$を外し、$\mathrm{const.}$を正規化項に置き換える(正規化する)と

$$ q(\boldsymbol{\lambda}) = \prod_{k=1}^K C_G(\hat{a}_k, \hat{b}_k) \lambda_k^{\hat{a}_k-1} \exp(- \hat{b}_k \lambda_k) = \prod_{k=1}^K \mathrm{Gam}(\lambda_k | \hat{a}_k, \hat{b}_k) \tag{4.54} $$

$\lambda_k$の近似事後分布は、パラメータ$\hat{a}_k,\ \hat{b}_k$を持つガンマ分布になることが分かる。

・混合比率の近似事後分布

 同様に、式(4.52)を$\boldsymbol{\pi}$に関して整理する($\boldsymbol{\pi}$に影響しない項を$\mathrm{const.}$にまとめる)と

$$ \begin{align} \ln q(\boldsymbol{\pi}) &= \mathbb{E}_{q(\mathbf{S})} \Bigl[ \ln p(\mathbf{S} | \boldsymbol{\pi}) \Bigr] + \ln p(\boldsymbol{\pi}) - \ln q(\boldsymbol{\lambda}) + \mathrm{const.} \\ &= \sum_{n=1}^N \mathbb{E}_{q(\mathbf{s}_n)} \Bigl[ \ln \mathrm{Cat}(\mathbf{s}_n | \boldsymbol{\pi}) \Bigr] + \ln \mathrm{Dir}(\boldsymbol{\pi} | \boldsymbol{\alpha}) + \mathrm{const.} \\ &= \sum_{n=1}^N \mathbb{E}_{q(\mathbf{s}_n)} \left[ \ln \prod_{k=1}^K \pi_k^{s_{n,k}} \right] + \ln C_D(\boldsymbol{\alpha}) \prod_{k=1}^K \pi_k^{\alpha_k-1} + \mathrm{const.} \\ &= \sum_{n=1}^N \mathbb{E}_{q(\mathbf{s}_n)} \left[ \sum_{k=1}^K s_{n,k} \ln \pi_k \right] + \ln C_D(\boldsymbol{\alpha}) + \sum_{k=1}^K (\alpha_k - 1) \ln \pi_k + \mathrm{const.} \\ &= \sum_{k=1}^K \left\{ \sum_{n=1}^N \mathbb{E}_{q(\mathbf{s}_n)} [s_{n,k}] \ln \pi_k + (\alpha_k - 1) \ln \pi_k \right\} + \mathrm{const.} \\ &= \sum_{k=1}^K \left( \sum_{n=1}^N \mathbb{E}_{q(\mathbf{s}_n)} [s_{n,k}] + \alpha_k - 1 \right) \ln \pi_k + \mathrm{const.} \tag{4.56} \end{align} $$

となる。$\ln q(\boldsymbol{\lambda})$は左辺から移項したものである。

 式(4.56)について

$$ \hat{\alpha}_k = \sum_{n=1}^N \mathbb{E}_{q(\mathbf{s}_n)} [s_{n,k}] + \alpha_k \tag{4.58} $$

とおき

$$ \ln q(\boldsymbol{\pi}) = \sum_{k=1}^K (\hat{\alpha}_k - 1) \ln \pi_k + \mathrm{const.} $$

さらに$\ln$を外し、$\mathrm{const.}$を正規化項に置き換える(正規化する)と

$$ q(\boldsymbol{\pi}) = C_D(\boldsymbol{\alpha}) \prod_{k=1}^K \pi_k^{\hat{\alpha}_k-1} = \mathrm{Dir}(\boldsymbol{\pi} | \hat{\boldsymbol{\alpha}}) \tag{4.57} $$

$\boldsymbol{\pi}$の近似事後分布は、パラメータ$\hat{\boldsymbol{\alpha}} = (\hat{\alpha}_1, \cdots, \hat{\alpha}_K)$を持つディリクレ分布になることが分かる。

 $\hat{\alpha}_k$の計算式(更新式)(4.58)について、$q(\mathbf{s}_n) = \mathrm{Cat}(\mathbf{s}_n | \boldsymbol{\eta}_n)$なので、カテゴリ分布の期待値(2.31)より

$$ \mathbb{E}_{q(\mathbf{s}_n)} [s_{n,k}] = \eta_{n,k} \tag{4.59} $$

で計算できる。

・潜在変数の近似事後分布のパラメータの計算

 各分布が明らかになったので、最後に$\eta_{n,k}$の計算式(更新式)(4.51)の各項について確認する。

 $q(\lambda_k) = \mathrm{Gam}(\lambda_k | \hat{a}_k, \hat{b}_k)$、$q(\boldsymbol{\pi}) = \mathrm{Dir}(\boldsymbol{\pi} | \hat{\boldsymbol{\alpha}})$なので、ガンマ分布の期待値(2.59)、ガンマ分布の対数の期待値(2.60)、ディリクレ分布の期待値(2.52)より

$$ \begin{align} \mathbb{E}_{q(\lambda_k)} [\lambda_k] &= \frac{\hat{a}_k}{\hat{b}_k} \tag{4.60} \\ \mathbb{E}_{q(\lambda_k)} [\ln \lambda_k] &= \psi(\hat{a}_k) - \ln \hat{b}_k \tag{4.61} \\ \mathbb{E}_{q(\boldsymbol{\pi})} [\ln \pi_k] &= \psi(\hat{\alpha}_k) - \psi \left(\sum_{k'=1}^K \hat{\alpha}_{k'} \right) \tag{4.62} \end{align} $$

で計算できる。

・Rでやってみる

 では、手を動かしてプログラムからアルゴリズムの理解を深めましょう。アルゴリズム4.2を参考に実装します。

# 利用パッケージ
library(tidyverse)

 必要なパッケージを読み込みます。

・真の観測モデルの設定
# (観測)データ数を指定
N <- 100

# 真のパラメータを指定
lambda_truth <- c(5, 25)
pi_truth <- c(0.3, 0.7)

# クラスタ数
K <- length(lambda_truth)

# クラスタ(潜在変数)を生成
s_nk <- rmultinom(n =  N, size = 1, prob = pi_truth) %>% 
  t()

# (観測)データXを生成
x_n <- rpois(n = N, lambda = apply(lambda_truth^t(s_nk), 2, prod))

 詳細は4.3.2項を参照ください。生成したデータを用いてlambda_truthpi_truthに設定した値を推定します。

 観測データを確認してみましょう。

# 観測データを確認
summary(x_n)
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
##    2.00    8.00   23.00   18.93   26.25   39.00

 ヒストグラムにすると次のようになります。

tibble(x = x_n) %>% 
  ggplot(aes(x = x)) + 
    geom_bar(fill = "#56256E") + 
    labs(title = "Histogram")

f:id:anemptyarchive:20200423155922p:plain
観測データ$\boldsymbol{X}$のヒストグラム

・パラメータの設定
# 試行回数
MaxIter <- 50

 繰り返し回数を指定します。

# ハイパーパラメータa,bの初期値を指定
a <- 1
b <- 1

 $\boldsymbol{\lambda}$のパラメータ$a,\ b$の値を指定します。

# lambda(の期待値)の初期値をランダムに設定
tmp_lambda <- seq(0, 1, by = 0.01) %>% 
  sample(size = K, replace = TRUE)
E_lambda_k    <- tmp_lambda / sum(tmp_lambda) # 正規化
E_ln_lambda_k <- log(E_lambda_k) # 対数をとる

 $\boldsymbol{\lambda}$の値をランダムに設定します。ギブスサンプリングではa, bを用いて乱数を生成しましたが、こちらは分布に関係なくランダムに値を決めます。

 ここでは期待値ではありませんが、対数をとったものとそれぞれ変数名をE_lambda_kE_ln_lambda_kとしておきます。

# ハイパーパラメータalphaの初期値を指定
alpha_k <- rep(2, K)

 $\boldsymbol{\pi}$のパラメータ$\boldsymbol{\alpha}$の値を指定します。

# pi(の対数をとった期待値)の初期値をランダムに設定
tmp_pi <- seq(0, 1, by = 0.01) %>% 
  sample(size = K, replace = TRUE)
E_ln_pi_k <- tmp_pi / sum(tmp_pi) %>% # 正規化
  log() # 対数をとる

 $\boldsymbol{\lambda}$と同様に$\boldsymbol{\pi}$もランダムに設定します。

・変分推論

・コード全体(クリックで展開)

# 受け皿を用意
eta_nk <- matrix(0, nrow = N, ncol = K)
E_s_nk <- matrix(0, nrow = N, ncol = K)
hat_a_k <- rep(0, K)
hat_b_k <- rep(0, K)

# ハイパーパラメータの推定値の推移の確認用
trace_a <- matrix(0, nrow = MaxIter + 1, ncol = K)
trace_b <- matrix(0, nrow = MaxIter + 1, ncol = K)
trace_alpha <- matrix(0, nrow = MaxIter + 1, ncol = K)
# 初期値を代入
trace_a[1, ] <- a
trace_b[1, ] <- b
trace_alpha[1, ] <- alpha_k

for(i in 1:MaxIter) {
  
  for(n in 1:N) {
    
    # パラメータeta_nを計算:式(4.51)
    tmp_eta <- exp(x_n[n] * E_ln_lambda_k - E_lambda_k + E_ln_pi_k)
    eta_nk[n, ] <- tmp_eta / sum(tmp_eta) # 正規化
    
    # s_nkの期待値を計算:式(4.59)
    E_s_nk[n, ] <- eta_nk[n, ]
  }
  
  for(k in 1:K) {
    
    # ハイパーパラメータa_hat_k,b_hat_kを計算:式(4.55)
    hat_a_k[k] <- sum(E_s_nk[, k] * x_n) + a
    hat_b_k[k] <- sum(E_s_nk[, k]) + b
    
    # (対数をとった)lambda_kの期待値を計算:式(4.60),(4.61)
    E_lambda_k[k] <- hat_a_k[k] / hat_b_k[k]
    E_ln_lambda_k[k] <- digamma(hat_a_k[k]) - log(hat_b_k[k])
  }
  
  # ハイパーパラメータalpha_hatを計算:式(4.58)
  hat_alpha_k <- apply(E_s_nk, 2, sum) + alpha_k
  
  # 対数をとったpiの期待値を計算:式(4.62)
  E_ln_pi_k <- digamma(hat_alpha_k) - digamma(sum(hat_alpha_k))
  
  # 推移の確認用に推定結果を保存
  trace_a[i + 1, ] <- hat_a_k
  trace_b[i + 1, ] <- hat_b_k
  trace_alpha[i + 1, ] <- hat_alpha_k
  
}

 添字を使って代入するために、予め変数を用意しておく必要があります。
 また各パラメータの更新値の推移を確認するため、学習する度に値を保存しておきます。


・$q(\boldsymbol{S})$の更新

for(n in 1:N) {
  
  # パラメータeta_nを計算:式(4.51)
  tmp_eta <- exp(x_n[n] * E_ln_lambda_k - E_lambda_k + E_ln_pi_k)
  eta_nk[n, ] <- tmp_eta / sum(tmp_eta) # 正規化
  
  # s_nkの期待値を計算:式(4.59)
  E_s_nk[n, ] <- eta_nk[n, ]
}

 式(4.51)の計算を行い、$\boldsymbol{\eta}_n$を求めます。

 eta_nk[n, ]の値が$\mathbb{E}[\boldsymbol{s}_n]$となります(式(4.59))。

 この処理を1から$N$まで繰り返すことで、近似事後分布$q(\boldsymbol{S})$を更新できます。

・$q(\boldsymbol{\lambda})$の更新

for(k in 1:K) {
  
  # ハイパーパラメータa_hat_k,b_hat_kを計算:式(4.55)
  hat_a_k[k] <- sum(E_s_nk[, k] * x_n) + a
  hat_b_k[k] <- sum(E_s_nk[, k]) + b
  
  # (対数をとった)lambda_kの期待値を計算:式(4.60),(4.61)
  E_lambda_k[k] <- hat_a_k[k] / hat_b_k[k]
  E_ln_lambda_k[k] <- digamma(hat_a_k[k]) - log(hat_b_k[k])
}

 式(4.55)の計算を行い、$\hat{a}_k,\ \hat{b}_k$を求めます。

 hat_a_k[k], hat_b_k[k]を使って式(4.60)、(4.61)の計算を行い、$\mathbb{E}_{q(\lambda_k)}[\lambda_k]$と$\mathbb{E}_{q(\lambda_k)}[\ln \lambda_k]$を求めます。

 この処理を1から$K$まで繰り返すことで、近似事後分布$q(\boldsymbol{\lambda})$を更新できます。

・$q(\boldsymbol{\pi})$の更新

# ハイパーパラメータalpha_hatを計算:式(4.58)
hat_alpha_k <- apply(E_s_nk, 2, sum) + alpha_k

# 対数をとったpiの期待値を計算:式(4.62)
E_ln_pi_k <- digamma(hat_alpha_k) - digamma(sum(hat_alpha_k))

 式(4.58)の計算を行い、$\hat{\boldsymbol{\alpha}}$を求めます。

 hat_alpha_kを使って式(4.62)の計算を行い、$\mathbb{E}_{q(\boldsymbol{\pi})}[\ln \pi_k]$を求めます。

 近似事後分布$q(\boldsymbol{\pi})$については1度に全ての要素を更新するため、繰り返し処理は行いません。

・推定結果の確認

 ggplot2パッケージを利用して、パラメータ$\boldsymbol{\lambda},\ \boldsymbol{\pi}$の事後分布を可視化します。

・コード(クリックで展開)

## lambdaの近似事後分布
# 作図用のデータフレームを作成
lambda_df <- tibble()
for(k in 1:K) {
  # データフレームに変換
  tmp_lambda_df <- tibble(
    lambda = seq(0, max(x_n), by = 0.01), 
    density = dgamma(lambda, shape = hat_a_k[k], rate = hat_b_k[k]), 
    cluster = as.factor(k)
  )
  # 結合
  lambda_df <- rbind(lambda_df, tmp_lambda_df)
}

# 作図
ggplot(lambda_df, aes(lambda, density, color = cluster)) + 
  geom_line() + # 折れ線グラフ
  scale_color_manual(values = c("#00A968", "orange")) + # グラフの色(不必要)
  geom_vline(xintercept = lambda_truth, color = "pink", linetype = "dashed") + # 垂直線
  labs(title = "Poisson Mixture Model:Variational Inference", 
       subtitle = paste0("a_hat=(", paste0(round(hat_a_k, 1), collapse = ", "), 
                         "), b_hat=(", paste0(round(hat_b_k, 1), collapse = ", "), ")")) # ラベル

 scale_color_manual()は使わない方が無難です(私の趣味の色です)。色を指定する場合は、クラスタ数$K$と同じ数の色を指定する必要があります。

f:id:anemptyarchive:20200423162200p:plain
$q(\boldsymbol{\lambda})$:ガンマ分布

・コード(クリックで展開)

## piの近似事後分布(K=2のときのみ可能)
# 作図用のデータフレームを作成
pi_df <- tibble()
for(k in 1:K) {
  # データフレームに変換
  tmp_pi_df <- tibble(
    pi = seq(0, 1, by = 0.001), 
    density = dbeta(pi, shape1 = hat_alpha_k[k], shape2 = hat_alpha_k[2 / k]), 
    cluster = as.factor(k)
  )
  # 結合
  pi_df <- rbind(pi_df, tmp_pi_df)
}

# 作図
ggplot(pi_df, aes(pi, density, color = cluster)) + 
  geom_line() + # 折れ線グラフ
  scale_color_manual(values = c("#00A968", "orange")) + # グラフの色(不必要)
  geom_vline(xintercept = pi_truth, color = "pink", linetype = "dashed") + # 垂直線
  labs(title = "Poisson Mixture Model:Variational Inference", 
       subtitle = paste0("alpha_hat=(", paste0(round(hat_alpha_k, 1), collapse = ", "), ")")) # ラベル

 $\boldsymbol{\pi}$については、このコードでグラフ化できるのは$K = 2$のときだけになります。(詳しくは3.2.2:カテゴリ分布の学習と予測【緑ベイズ入門のノート】 - からっぽのしょこをご参照ください。)

f:id:anemptyarchive:20200423162338p:plain
$q(\boldsymbol{\pi})$:ディリクレ分布

・推移の確認

・事後分布の推移

 gganimateパッケージを利用して、分布の推移をgif画像として出力するためのコードです。

・コード(クリックで展開)

# 利用パッケージ
library(gganimate)

## lambdaの近似事後分布
# 作図用のデータフレームを作成
trace_lambda_long <- tibble()
for(i in 1:(MaxIter + 1)) {
  for(k in 1:K) {
    # データフレームに変換
    tmp_lambda_df <- tibble(
      lambda = seq(0, max(x_n), by = 0.01), 
      density = dgamma(lambda, shape = trace_a[i, k], rate = trace_b[i, k]), 
      cluster = as.factor(k), 
      Iteration = i - 1
    )
    # 結合
    trace_lambda_long <- rbind(trace_lambda_long, tmp_lambda_df)
  }
}

# 作図
graph_lambda <- ggplot(trace_lambda_long, aes(lambda, density, color = cluster)) + 
  geom_line() + # 折れ線グラフ
  scale_color_manual(values = c("#00A968", "orange")) + # グラフの色(不必要)
  geom_vline(xintercept = lambda_truth, color = "pink", linetype = "dashed") + # 垂直線
  transition_manual(Iteration) + # フレーム
  labs(title = "Poisson Mixture Model:Variational Inference", 
       subtitle = "i={current_frame}") # ラベル
  
# gif画像を作成
animate(graph_lambda, nframes = MaxIter + 1, fps = 5)

 gganimate::transition_manual()に時系列を指定して、gganimate::animate()でgifファイルに変換します。フレーム数は初期値の0を含むためイタレーション数+1となります。

f:id:anemptyarchive:20200423161512g:plain
$q(\boldsymbol{\lambda})$の推移

・コード(クリックで展開)

## piの近似事後分布(K=2のときのみ可)
# 作図用のデータフレームを作成
trace_pi_long <- tibble()
for(i in 1:(MaxIter + 1)) {
  for(k in 1:K) {
    # データフレームに変換
    tmp_pi_df <- tibble(
      pi = seq(0, 1, by = 0.001), 
      density = dbeta(pi, shape1 = trace_alpha[i, k], shape2 = trace_alpha[i, 2 / k]), 
      cluster = as.factor(k), 
      Iteration = i - 1
    )
    # 結合
    trace_pi_long <- rbind(trace_pi_long, tmp_pi_df)
  }
}

# 作図
graph_pi <- ggplot(trace_pi_long, aes(pi, density, color = cluster)) + 
  geom_line() + # 折れ線グラフ
  scale_color_manual(values = c("#00A968", "orange")) + # グラフの色(不必要)
  geom_vline(xintercept = pi_truth, color = "pink", linetype = "dashed") + # 垂直線
  transition_manual(Iteration) + # フレーム
  labs(title = "Poisson Mixture Model:Variational Inference", 
       subtitle = "i={current_frame}") # ラベル

# gif画像を作成
animate(graph_pi, nframes = MaxIter + 1, fps = 5)

f:id:anemptyarchive:20200423161636g:plain
$q(\boldsymbol{\pi})$の推移

・ハイパーパラメータの推移

・コード(クリックで展開)

## lambdaのパラメータa
# データフレームに変換
trace_a_wide <- cbind(
  as.data.frame(trace_a), 
  Iteration = 1:(MaxIter + 1)
)

# long型に変換
trace_a_long <- pivot_longer(
  trace_a_wide, 
  cols = -Iteration, 
  names_to = "cluster", 
  names_prefix = "V", 
  names_ptypes = list(cluster = factor()), 
  values_to = "value"
)

# 作図
ggplot(trace_a_long, aes(Iteration, value, color = cluster)) + 
  geom_line() + 
  labs(title = "Poisson Mixture Model:Variational Inference", 
       subtitle = expression(hat(a)))


## lambdaのパラメータb
# データフレームに変換
trace_b_wide <- cbind(
  as.data.frame(trace_a), 
  Iteration = 1:(MaxIter + 1)
)

# long型に変換
trace_b_long <- pivot_longer(
  trace_b_wide, 
  cols = -Iteration, 
  names_to = "cluster", 
  names_prefix = "V", 
  names_ptypes = list(cluster = factor()), 
  values_to = "value"
)

# 作図
ggplot(trace_b_long, aes(Iteration, value, color = cluster)) + 
  geom_line() + 
  labs(title = "Poisson Mixture Model:Variational Inference", 
       subtitle = expression(hat(b)))

f:id:anemptyarchive:20200423162959p:plainf:id:anemptyarchive:20200423163017p:plain
$\hat{a},\ \hat{b}$の推移

・コード(クリックで展開)

## piのパラメータ
# データフレームに変換
trace_alpha_wide <- cbind(
  as.data.frame(trace_alpha), 
  Iteration = 1:(MaxIter + 1)
)

# long型に変換
trace_alpha_long <- pivot_longer(
  trace_alpha_wide, 
  cols = -Iteration, 
  names_to = "cluster", 
  names_prefix = "V", 
  names_ptypes = list(cluster = factor()), 
  values_to = "value"
)

# 作図
ggplot(trace_alpha_long, aes(Iteration, value, color = cluster)) + 
  geom_line() + 
  labs(title = "Poisson Mixture Model:Variational Inference", 
       subtitle = expression(hat(alpha)))

f:id:anemptyarchive:20200423163205p:plain
$\hat{\alpha}$の推移

参考文献

  • 須山敦志『ベイズ推論による機械学習入門』(機械学習スタートアップシリーズ)杉山将監修,講談社,2017年.

おわりに

 前回に引き続きこの記事もOsaka.Rのリモート朝もく会中に書いたものになります。

Osaka.Rが主催する毎平日8時から10時のもくもく会です。 参加するには下記URLからOsaka.Rのslackに登録してください。 #mokumoku チャンネルのトピックを見るともくもく会場への案内をご覧頂けます。

https://join.slack.com/t/osakar/shared_invite/zt-dgjyfztf-AVYDIx~P8Ncl6deigOOarA

イベントとしても公開していますので、そちらに登録いただいてもかまいません。

朝のちょっとした時間を活かしませんか? 途中参加途中離脱OK。 チャットでその日の課題を宣言してもくもく開始。困ったことがあればチャットや画面共有で協力し合いましょう。 大阪府外の方も大歓迎です!

osaka-r.connpass.com

暫く開催されるとのことですので、(Rユーザーの)皆様ぜひぜひ一緒に参加しましょー。私は毎日参加するつもりです!

【次節の内容】

www.anarchive-beta.com


  • 2021/04/27:加筆修正しました。

 今でもほとんど毎日朝もくに参加してますー。この修正も朝もくを中心に進めました!