からっぽのしょこ

読んだら書く!書いたら読む!同じ事は二度調べ(たく)ない

3.2.2:カテゴリ分布の学習と予測の導出【緑ベイズ入門のノート】

はじめに

 『ベイズ推論による機械学習入門』の学習時のノートです。「数式の行間を読んでみた」とそれを「RとPythonで組んでみた」によって、「数式」と「プログラム」から理解するのが目標です。
 省略してある内容等ありますので、本とあわせて読んでください。

 この記事では、「尤度関数をカテゴリ分布」、「事前分布をディリクレ分布」とした場合の「パラメータの事後分布」と「未観測値の予測分布」を導出します。

【実装編】

www.anarchive-beta.com

www.anarchive-beta.com

【前の節の内容】

www.anarchive-beta.com

【他の節の内容】

www.anarchive-beta.com

【この節の内容】

3.2.2 カテゴリ分布の学習と予測

 尤度関数をカテゴリ分布(Categorical Distribution)、事前分布をディリクレ分布(Dirichlet Distribution)とするモデルに対するベイズ推論を導出する。
 カテゴリ分布については「カテゴリ分布の定義式 - からっぽのしょこ」、ディリクレ分布については「ディリクレ分布の定義式 - からっぽのしょこ」を参照のこと。

尤度関数の確認

 カテゴリ分布に従うと仮定する$N$個の観測データ$\mathbf{S} = \{\mathbf{s}_1, \mathbf{s}_2, \cdots, \mathbf{s}_N\}$、$\mathbf{s}_n = (s_{n,1}, s_{n,2}, \cdots, s_{n,K})$を用いて、パラメータ$\boldsymbol{\pi} = (\pi_1, \pi_2, \cdots, \pi_K)$の事後分布と未観測のデータ$\mathbf{s}_{*}$の予測分布を求めていく。
 まずは、尤度関数を確認する。

 $\mathbf{S}$の各データが独立に発生しているとの仮定の下で、尤度$p(\mathbf{S} | \boldsymbol{\pi})$は

$$ \begin{aligned} p(\mathbf{S} | \boldsymbol{\pi}) &= p(\mathbf{s}_1, \mathbf{s}_2, \cdots, \mathbf{s}_N | \boldsymbol{\pi}) \\ &= p(\mathbf{s}_1 | \boldsymbol{\pi}) p(\mathbf{s}_2 | \boldsymbol{\pi}) \cdots p(\mathbf{s}_N | \boldsymbol{\pi}) \\ &= \prod_{n=1}^N p(\mathbf{s}_n | \boldsymbol{\pi}) \end{aligned} $$

と分解できる。さらに、$p(\mathbf{s}_n | \boldsymbol{\pi})$は

$$ \begin{aligned} p(\mathbf{s}_n | \boldsymbol{\pi}) &= p(s_{n,1}, s_{n,2}, \cdots, s_{n,K} | \boldsymbol{\pi}) \\ &= p(s_{n,1} | \boldsymbol{\pi}) p(s_{n,2} | \boldsymbol{\pi}) \cdots p(s_{n,K} | \boldsymbol{\pi}) \\ &= \prod_{k=1}^K p(s_{n,k} | \boldsymbol{\pi}) \end{aligned} $$

と分解できる。従って、尤度は

$$ p(\mathbf{S} | \boldsymbol{\pi}) = \prod_{n=1}^N \prod_{k=1}^K p(s_{n,k} |\boldsymbol{\pi}) $$

である。
 各データ$\mathbf{s}_n$の各要素$s_{n,k}$は0か1の2値をとり、$\sum_{k=1}^K s_{n,k} = 1$である(1つの要素が1でその他の要素が0である)。また$\pi_k$は、$s_{n,k} = 1$となる確率を表すパラメータで、$0 \leq \pi_k \leq 1$、$\sum_{k=1}^K \pi_k = 1$である。

事後分布の計算

 次に、パラメータ$\boldsymbol{\pi}$の事後分布$p(\boldsymbol{\pi} | \mathbf{S})$導出する。

 $\mathbf{S}$が与えられた下での$\boldsymbol{\pi}$の事後分布は、ベイズの定理より

$$ \begin{align} p(\boldsymbol{\pi} | \mathbf{S}) &= \frac{ p(\mathbf{S} | \boldsymbol{\pi}) p(\boldsymbol{\pi} | \boldsymbol{\alpha}) }{ p(\mathbf{S}) } \\ &= \frac{ \left\{ \prod_{n=1}^N p(\mathbf{s}_n | \boldsymbol{\pi}) \right\} p(\boldsymbol{\pi} | \boldsymbol{\alpha}) }{ p(\mathbf{S}) } \\ &= \frac{ \left\{\prod_{n=1}^N \mathrm{Cat}(\mathbf{s}_n | \boldsymbol{\pi}) \right\} \mathrm{Dir}(\boldsymbol{\pi} | \boldsymbol{\alpha}) }{ p(\mathbf{S}) } \\ &\propto \left\{\prod_{n=1}^N \mathrm{Cat}(\mathbf{s}_n | \boldsymbol{\pi}) \right\} \mathrm{Dir}(\boldsymbol{\pi} | \boldsymbol{\alpha}) \tag{3.25} \end{align} $$

となる。分母の$p(\mathbf{S})$は$\boldsymbol{\pi}$に影響しないため省略して、比例関係のみに注目する。省略した部分については、最後に正規化することで対応できる。

 この分布の具体的な形状を明らかにしていく。対数をとって指数部分の計算を分かりやすくして、$\boldsymbol{\pi}$に関して整理すると

$$ \begin{align} \ln p(\boldsymbol{\pi} | \mathbf{S}) &= \ln \prod_{n=1}^N \mathrm{Cat}(\mathbf{s}_n | \boldsymbol{\pi}) + \ln \mathrm{Dir}(\boldsymbol{\pi} | \boldsymbol{\alpha}) - \ln p(\mathbf{S}) \\ &= \sum_{n=1}^N \ln \mathrm{Cat}(\mathbf{s}_n | \boldsymbol{\pi}) + \ln \mathrm{Dir}(\boldsymbol{\pi} | \boldsymbol{\alpha}) + \mathrm{const.} \\ &= \sum_{n=1}^N \ln \left( \prod_{k=1}^K \pi_k^{s_{n,k}} \right) + \ln \left\{ C_D(\boldsymbol{\alpha}) \prod_{k=1}^K \pi_k^{\alpha_k-1} \right\} + \mathrm{const.} \\ &= \sum_{n=1}^N \sum_{k=1}^K s_{n,k} \ln \pi_k + \ln C_D(\boldsymbol{\alpha}) + \sum_{k=1}^K (\alpha_k - 1) \ln \pi_k + \mathrm{const.} \\ &= \sum_{k=1}^K \left( \sum_{n=1}^N s_{n,k} + \alpha_k - 1 \right) \ln \pi_k + \mathrm{const.} \tag{3.26} \end{align} $$

【途中式の途中式】

  1. 式(3.25)の下から2行目の関係を用いている。
  2. 自然対数の性質より、$\ln x y = \ln x + \ln y$である。また、$\boldsymbol{\pi}$に影響しない$- \ln p(\mathbf{S})$を$\mathrm{const.}$とおく。
  3. 尤度と事前分布に、それぞれ具体的な確率分布の式を代入する。
  4. それぞれ対数をとる。$\ln x^a = a \ln x$である。
  5. $\boldsymbol{\pi}$に影響しない$\ln C_D(\boldsymbol{\alpha})$を$\mathrm{const.}$にまとめる。
  6. $\ln \pi_k$の項をまとめて式を整理する。

となる。

 式(3.26)について

$$ \hat{\boldsymbol{\alpha}} = (\hat{\alpha}_1, \hat{\alpha}_2, \cdots, \hat{\alpha}_K) ,\ \hat{\alpha}_k = \sum_{n=1}^N s_{n,k} + \alpha_k \tag{3.28} $$

とおき

$$ \ln p(\boldsymbol{\pi} | \mathbf{S}) = \sum_{k=1}^K (\hat{\alpha}_k - 1) \ln \pi_k + \mathrm{const.} $$

さらに$\ln$を外して、$\mathrm{const.}$を正規化項に置き換える(正規化する)と

$$ p(\boldsymbol{\pi} | \mathbf{S}) = C_D(\hat{\boldsymbol{\alpha}}) \prod_{k=1}^K \pi_k^{\hat{\alpha}_k-1} = \mathrm{Dir}(\boldsymbol{\pi} | \hat{\boldsymbol{\alpha}}) \tag{3.27} $$

事後分布は式の形状から、パラメータ$\hat{\boldsymbol{\alpha}}$を持つディリクレ分布であることが分かる。
 また、式(3.28)が超パラメータ$\hat{\boldsymbol{\alpha}}$の計算式(更新式)である。

予測分布の計算

 続いて、カテゴリ分布に従う未観測のデータ$\mathbf{s}_{*} = \{s_{*,1}, s_{*,2}, \cdots, s_{*,K}\}$に対する予測分布を導出する。

事前分布による予測分布

 事前分布(観測データによる学習を行っていない分布)$p(\boldsymbol{\pi} | \boldsymbol{\alpha})$を用いて、パラメータ$\boldsymbol{\pi}$を周辺化することで予測分布$p(\mathbf{s}_{*})$となる。

$$ \begin{aligned} p(\mathbf{s}_{*}) &= \int p(\mathbf{s}_{*} | \boldsymbol{\pi}) p(\boldsymbol{\pi} | \boldsymbol{\alpha}) d\boldsymbol{\pi} \\ &= \int \mathrm{Cat}(\mathbf{s}_{*} | \boldsymbol{\pi}) \mathrm{Dir}(\boldsymbol{\pi} | \boldsymbol{\alpha}) d\boldsymbol{\pi} \\ &= \int \prod_{k=1}^K \pi_k^{s_{*,k}} C_D(\boldsymbol{\alpha}) \pi_k^{\alpha_k-1} d\boldsymbol{\pi} \\ &= C_D(\boldsymbol{\alpha}) \int \prod_{k=1}^K \pi_k^{s_{*,k}+\alpha_k-1} d\boldsymbol{\pi} \end{aligned} $$

 積分部分に注目すると、パラメータ$(s_{*,1} + \alpha_1, \cdots, s_{*,K} + \alpha_K)$を持つ正規化項のないディリクレ分布の形をしている。
 これはディリクレ分布の正規化項の逆数に変形できる(そもそもこの部分の逆数が正規化項である。詳しくは「ディリクレ分布の統計量の導出 - からっぽのしょこ」を参照のこと)。

$$ \int \prod_{k=1}^K \pi_k^{s_{*,k}+\alpha_k-1} d\boldsymbol{\pi} = \frac{ 1 }{ C_D\Bigl( (s_{*,k} + \alpha_k)_{k=1}^K \Bigr) } $$

 ここで、$(s_{*,1} + \alpha_1, \cdots, s_{*,K} + \alpha_K)$を$(s_{*,k} + \alpha_k)_{k=1}^K$と表記する。
 これを先ほどの式に代入すると、予測分布は

$$ p(\mathbf{s}_{*}) = \frac{ C_D(\boldsymbol{\alpha}) }{ C_D \Bigl( (s_{*,k} + \alpha_k)_{k=1}^K \Bigr) } \tag{3.29} $$

となる。
 さらにディリクレ分布の正規化項(2.49)を用いると

$$ \begin{align} p(\mathbf{s}_{*}) &= C_D(\boldsymbol{\alpha}) \frac{ 1 }{ C_D\Bigl((s_{*,k} + \alpha_k)_{k=1}^K\Bigr) } \\ &= \frac{ \Gamma(\sum_{k=1}^K \alpha_k) }{ \prod_{k=1}^K \Gamma(\alpha_k) } \frac{ \prod_{k=1}^K \Gamma(s_{*,k} + \alpha_k) }{ \Gamma(\sum_{k=1}^K s_{*,k} + \alpha_k) } \tag{3.30} \end{align} $$

となる。

 式(3.30)について、$s_{*,i} = 1$となる($i$番目の項が1となる)場合を考える。このとき$\mathbf{s}_{*}$の$i$番目以外の項($j = 1, \cdots, i - 1, i + 1, \cdots, K$番目の項)は0である。よって、式(3.30)は

$$ \begin{align} p(s_{*,i} = 1) &= \frac{ \Gamma(\sum_{k=1}^K \alpha_k) }{ \prod_{k=1}^K \Gamma(\alpha_k) } \frac{ \Gamma(s_{*,i} + \alpha_i) \prod_{j \neq i} \Gamma(s_{*,j} + \alpha_j) }{ \Gamma(\sum_{k=1}^K s_{*,k} + \alpha_k) } \\ &= \frac{ \Gamma(\sum_{k=1}^K \alpha_k) }{ \prod_{k=1}^K \Gamma(\alpha_k) } \frac{ \Gamma(1 + \alpha_i) \prod_{j \neq i} \Gamma(0 + \alpha_j) }{ \Gamma(1 + \sum_{k=1}^K \alpha_k) } \\ &= \frac{ \Gamma(\sum_{k=1}^K \alpha_k) }{ \prod_{k=1}^K \Gamma(\alpha_k) } \frac{ \alpha_i \Gamma(\alpha_i) \prod_{j \neq i} \Gamma(\alpha_j) }{ (\sum_{k=1}^K \alpha_k) \Gamma(\sum_{k=1}^K \alpha_k) } \\ &= \frac{\alpha_i}{\sum_{k=1}^K \alpha_k} \tag{3.31} \end{align} $$

【途中式の途中式】

  1. $\prod_{k=1}^K$から$i$番目の項のみ取り出して、$\prod_{k=1}^K \Gamma(s_{*,k} + \alpha_k) = \Gamma(s_{*,i} + \alpha_i) \prod_{j \neq i} \Gamma(s_{*,j} + \alpha_j)$とする。
  2. $s_{*,i}$に1、それ以外の$s_{*,1}, \cdots, s_{*,k'-1}, s_{*,k'+1}, \cdots, s_{*,K}$に0を代入する。また、$\sum_{k=1}^K s_{*,k} = 1$である。
  3. ガンマ関数の性質$\Gamma(x + 1) = x \Gamma(x)$を用いて、項を変形する。
  4. $\prod_{k=1}^K \Gamma(\alpha_k) = \Gamma(\alpha_i) \prod_{j \neq i} \Gamma(\alpha_j)$より、約分して式を整理する。

となる。

 $s_{*,i}$以外の項$s_{*,1}, \cdots, s_{*,i-1}, s_{*,i+1}, \cdots, s_{*,K}$が1となる場合も同様に計算できるので、$x^0 = 1$であることを利用して$p(s_{*,1} = 1)$から$p(s_{*,K} = 1)$までをまとめると、式(3.30)は

$$ p(\mathbf{s}_{*}) = \prod_{k=1}^K \Bigl( \frac{\alpha_{k}}{\sum_{k'=1}^K \alpha_{k'}} \Bigr)^{s_{*,k}} $$

と書き換えられる。($s_{*,k} = 1$以外の項は0乗され1となり、式(3.31)が成り立つ。また、分母の$k'$は分子の$k$と区別するためのインデックスである。本ではこちらを$i$と表記することで区別している。)

 この式について

$$ \boldsymbol{\pi}_{*} = (\pi_{*,1}, \pi_{*,2}, \cdots, \pi_{*,K}),\ \pi_{*,k} = \frac{\alpha_k}{\sum_{k'=1}^K \alpha_{k'}} $$

とおくと

$$ p(\mathbf{s}_{*}) = \prod_{k=1}^K \pi_{*,k}^{s_{*,k}} = \mathrm{Cat}(\mathbf{s}_{*} | \boldsymbol{\pi}_{*}) \tag{3.32} $$

予測分布は式の形状から、パラメータ$\boldsymbol{\pi}_{*}$を持つカテゴリ分布であることが分かる。ちなみに、$\frac{\alpha_{k}}{\sum_{k'=1}^K \alpha_{k'}}$はディリクレ分布$\mathrm{Dir}(\boldsymbol{\pi} | \boldsymbol{\alpha})$の期待値$\mathbb{E}[\pi_k]$である。また、ディリクレ分布のパラメータは$\alpha_k > 0$の条件を持つので、式(3.31)の計算により、カテゴリ分布のパラメータの条件$0 \leq \pi_{*,k} \leq 1, \sum_{k=1}^K \pi_{*,k} = 1$を満たす。

事後分布による予測分布

 予測分布の計算に事前分布$p(\boldsymbol{\pi})$を用いることで、観測データによる学習を行っていない予測分布$p(\mathbf{s}_{*})$(のパラメータ$\boldsymbol{\pi}_{*}$)を求めた。事後分布$p(\boldsymbol{\pi} | \mathbf{S})$を用いると、同様の手順で観測データ$\mathbf{S}$によって学習した予測分布$p(\mathbf{s}_{*} | \mathbf{S})$(のパラメータ)を求められる。

 そこで、$\boldsymbol{\pi}_{*}$を構成する事前分布のパラメータ$\boldsymbol{\alpha}$について、事後分布のパラメータ(3.28)に置き換えたものを$\hat{\boldsymbol{\pi}}_{*}$とおくと

$$ \begin{aligned} \hat{\pi}_{*,k} &= \frac{ \hat{\alpha}_k }{ \sum_{k'=1}^K \hat{\alpha}_{k'} } \\ &= \frac{ \sum_{n=1}^N s_{n,k} + \alpha_k }{ \sum_{k'=1}^K \left\{ \sum_{n=1}^N s_{n,k'} + \alpha_{k'} \right\} } \end{aligned} $$

となり、予測分布

$$ p(\mathbf{s}_{*} | \mathbf{S}) = \prod_{k=1}^K \hat{\pi}_{*,k}^{s_{*,k}} = \mathrm{Cat}(\mathbf{s}_{*} | \hat{\boldsymbol{\pi}}_{*}) $$

が得られる。
 また、上の式が予測分布のパラメータ$\hat{\pi}_{*,k}$の計算式(更新式)である。

参考文献

  • 須山敦志『ベイズ推論による機械学習入門』(機械学習スタートアップシリーズ)杉山将監修,講談社,2017年.

おわりに

 記事を分割したらコメントがなくなっちゃった。

2020/03/04:加筆修正しました。
2021/04/04:加筆修正しました。その際にRで実装編と記事を分割しました。途中微妙に間違ってた、、恥ずかしい。

【次の節の内容】

www.anarchive-beta.com