からっぽのしょこ

読んだら書く!書いたら読む!同じ事は二度調べ(たく)ない

3.2.2:カテゴリ分布の学習と予測【緑ベイズ入門のノート】

はじめに

 『ベイズ推論による機械学習入門』の学習時のノートです。基本的な内容は「数式の行間を読んでみた」とそれを「RとPythonで組んでみた」になります。「数式」と「プログラム」から理解するのが目標です。

 この記事は、3.2.2項の内容です。「尤度関数をカテゴリ分布」、「事前分布をディリクレ分布」とした場合の「パラメータの事後分布」と「未観測値の予測分布」を導出します。

 省略してる内容等ありますので、本とあわせて読んでください。初学者な自分が理解できるレベルまで落として書き下していますので、分かる人にはかなりくどくなっています。同じような立場の人のお役に立てば幸いです。

【実装編】

www.anarchive-beta.com

www.anarchive-beta.com

【他の節の内容】

www.anarchive-beta.com

【この節の内容】

3.2.2 カテゴリ分布の学習と予測

 カテゴリ分布に従うと仮定する$N$個の観測データ$\mathbf{S} = \{\mathbf{s}_1, \mathbf{s}_2, \cdots, \mathbf{s}_N\}$を用いて、パラメータ$\boldsymbol{\pi} = (\pi_1, \pi_2, \cdots, \pi_K)$の事後分布と未観測のデータ$\mathbf{s}_{*}$の予測分布を求めていく。
 各データ$\mathbf{s}_n = (s_{n,1}, s_{n,2}, \cdots, s_{n,K})$の各要素$s_{n,k}$は0か1の2値をとり、$\sum_{k=1}^K s_{n,k} = 1$である(1つの要素が1でその他の要素が0である)。また$\pi_k$は、$s_{n,k} = 1$となる確率を表すパラメータで、$0 \leq \pi_k \leq 1$、$\sum_{k=1}^K \pi_k = 1$である。

 $\mathbf{S}$の各データが独立に発生しているとの仮定の下で、尤度$p(\mathbf{S} | \boldsymbol{\pi})$は

$$ \begin{aligned} p(\mathbf{S} | \boldsymbol{\pi}) &= p(\mathbf{s}_1, \mathbf{s}_2, \cdots, \mathbf{s}_N | \boldsymbol{\pi}) \\ &= p(\mathbf{s}_1 | \boldsymbol{\pi}) p(\mathbf{s}_2 | \boldsymbol{\pi}) \cdots p(\mathbf{s}_N | \boldsymbol{\pi}) \\ &= \prod_{n=1}^N p(\mathbf{s}_n | \boldsymbol{\pi}) \end{aligned} $$

と分解できる。さらに、$p(\mathbf{s}_n | \boldsymbol{\pi})$は

$$ \begin{aligned} p(\mathbf{s}_n | \boldsymbol{\pi}) &= p(s_{n,1}, s_{n,2}, \cdots, s_{n,K} | \boldsymbol{\pi}) \\ &= p(s_{n,1} | \boldsymbol{\pi}) p(s_{n,2} | \boldsymbol{\pi}) \cdots p(s_{n,K} | \boldsymbol{\pi}) \\ &= \prod_{k=1}^K p(s_{n,k} | \boldsymbol{\pi}) \end{aligned} $$

と分解できる。従って、尤度は

$$ p(\mathbf{S} | \boldsymbol{\pi}) = \prod_{n=1}^N \prod_{k=1}^K p(s_{n,k} |\boldsymbol{\pi}) $$

である。

・事後分布の計算

 まずは、パラメータ$\boldsymbol{\pi}$の事後分布$p(\boldsymbol{\pi} | \mathbf{S})$導出する。

 $\mathbf{S}$が与えられた下での$\boldsymbol{\pi}$の事後分布は、ベイズの定理より

$$ \begin{align} p(\boldsymbol{\pi} | \mathbf{S}) &= \frac{ p(\mathbf{S} | \boldsymbol{\pi}) p(\boldsymbol{\pi} | \boldsymbol{\alpha}) }{ p(\mathbf{S}) } \\ &= \frac{ \left\{ \prod_{n=1}^N p(\mathbf{s}_n | \boldsymbol{\pi}) \right\} p(\boldsymbol{\pi} | \boldsymbol{\alpha}) }{ p(\mathbf{S}) } \\ &= \frac{ \left\{\prod_{n=1}^N \mathrm{Cat}(\mathbf{s}_n | \boldsymbol{\pi}) \right\} \mathrm{Dir}(\boldsymbol{\pi} | \boldsymbol{\alpha}) }{ p(\mathbf{S}) } \\ &\propto \left\{\prod_{n=1}^N \mathrm{Cat}(\mathbf{s}_n | \boldsymbol{\pi}) \right\} \mathrm{Dir}(\boldsymbol{\pi} | \boldsymbol{\alpha}) \tag{3.25} \end{align} $$

となる。分母の$p(\mathbf{S})$は$\boldsymbol{\pi}$に影響しないため省略して、比例関係のみに注目する。省略した部分については、最後に正規化することで対応できる。

 この分布の具体的な形状を明らかにしていく。対数をとって指数部分の計算を分かりやすくして、$\boldsymbol{\pi}$に関して整理すると

$$ \begin{align} \ln p(\boldsymbol{\pi} | \mathbf{S}) &= \ln \prod_{n=1}^N \mathrm{Cat}(\mathbf{s}_n | \boldsymbol{\pi}) + \ln \mathrm{Dir}(\boldsymbol{\pi} | \boldsymbol{\alpha}) - \ln p(\mathbf{S}) \\ &= \sum_{n=1}^N \ln \mathrm{Cat}(\mathbf{s}_n | \boldsymbol{\pi}) + \ln \mathrm{Dir}(\boldsymbol{\pi} | \boldsymbol{\alpha}) + \mathrm{const.} \\ &= \sum_{n=1}^N \ln \left( \prod_{k=1}^K \pi_k^{s_{n,k}} \right) + \ln \left\{ C_D(\boldsymbol{\alpha}) \prod_{k=1}^K \pi_k^{\alpha_k-1} \right\} + \mathrm{const.} \\ &= \sum_{n=1}^N \sum_{k=1}^K s_{n,k} \ln \pi_k + \ln C_D(\boldsymbol{\alpha}) + \sum_{k=1}^K (\alpha_k - 1) \ln \pi_k + \mathrm{const.} \\ &= \sum_{k=1}^K \left( \sum_{n=1}^N s_{n,k} + \alpha_k - 1 \right) \ln \pi_k + \mathrm{const.} \tag{3.26} \end{align} $$

【途中式の途中式】

  1. 式(3.25)の下から2行目の関係を用いている。
  2. 自然対数の性質より、$\ln x y = \ln x + \ln y$である。また、$\boldsymbol{\pi}$に影響しない$- \ln p(\mathbf{S})$を$\mathrm{const.}$とおく。
  3. 尤度と事前分布に、それぞれ具体的な確率分布の式を代入する。
  4. それぞれ対数をとる。$\ln x^a = a \ln x$である。
  5. $\boldsymbol{\pi}$に影響しない$\ln C_D(\boldsymbol{\alpha})$を$\mathrm{const.}$にまとめる。
  6. $\ln \pi_k$の項をまとめて式を整理する。

となる。

 式(3.26)について

$$ \hat{\boldsymbol{\alpha}} = (\hat{\alpha}_1, \hat{\alpha}_2, \cdots, \hat{\alpha}_K) ,\ \hat{\alpha}_k = \sum_{n=1}^N s_{n,k} + \alpha_k \tag{3.28} $$

とおき

$$ \ln p(\boldsymbol{\pi} | \mathbf{S}) = \sum_{k=1}^K (\hat{\alpha}_k - 1) \ln \pi_k + \mathrm{const.} $$

さらに$\ln$を外して、$\mathrm{const.}$を正規化項に置き換える(正規化する)と

$$ p(\boldsymbol{\pi} | \mathbf{S}) = C_D(\hat{\boldsymbol{\alpha}}) \prod_{k=1}^K \pi_k^{\hat{\alpha}_k-1} = \mathrm{Dir}(\boldsymbol{\pi} | \hat{\boldsymbol{\alpha}}) \tag{3.27} $$

事後分布は式の形状から、パラメータ$\hat{\boldsymbol{\alpha}}$を持つディリクレ分布であることが分かる。
 また、式(3.28)が超パラメータ$\hat{\boldsymbol{\alpha}}$の計算式(更新式)である。

・予測分布の計算

 続いて、カテゴリ分布に従う未観測のデータ$\mathbf{s}_{*} = \{s_{*,1}, s_{*,2}, \cdots, s_{*,K}\}$に対する予測分布を導出する。

 事前分布(観測データによる学習を行っていない分布)$p(\boldsymbol{\pi} | \boldsymbol{\alpha})$を用いて、パラメータ$\boldsymbol{\pi}$を周辺化することで予測分布$p(\mathbf{s}_{*})$となる。

$$ \begin{aligned} p(\mathbf{s}_{*}) &= \int p(\mathbf{s}_{*} | \boldsymbol{\pi}) p(\boldsymbol{\pi} | \boldsymbol{\alpha}) d\boldsymbol{\pi} \\ &= \int \mathrm{Cat}(\mathbf{s}_{*} | \boldsymbol{\pi}) \mathrm{Dir}(\boldsymbol{\pi} | \boldsymbol{\alpha}) d\boldsymbol{\pi} \\ &= \int \prod_{k=1}^K \pi_k^{s_{*,k}} C_D(\boldsymbol{\alpha}) \pi_k^{\alpha_k-1} d\boldsymbol{\pi} \\ &= C_D(\boldsymbol{\alpha}) \int \prod_{k=1}^K \pi_k^{s_{*,k}+\alpha_k-1} d\boldsymbol{\pi} \end{aligned} $$

 積分部分に注目すると、パラメータ$(s_{*,1} + \alpha_1, \cdots, s_{*,K} + \alpha_K)$を持つ正規化項のないディリクレ分布の形をしている。
 これはディリクレ分布の正規化項の逆数に変形できる(そもそもこの部分の逆数が正規化項である。詳しくは「1.2.4:ディリクレ分布【『トピックモデル』の勉強ノート】 - からっぽのしょこ」を参照のこと)。

$$ \int \prod_{k=1}^K \pi_k^{s_{*,k}+\alpha_k-1} d\boldsymbol{\pi} = \frac{1}{C_D((s_{*,k} + \alpha_k)_{k=1}^K)} $$

 これを先ほどの式に代入すると、予測分布は

$$ p(\mathbf{s}_{*}) = \frac{ C_D(\boldsymbol{\alpha}) }{ C_D \Bigl( (s_{*,k} + \alpha_k)_{k=1}^K \Bigr) } \tag{3.29} $$

となる。ここで、$(s_{*,1} + \alpha_1, \cdots, s_{*,K} + \alpha_K)$を$(s_{*,k} + \alpha_k)_{k=1}^K$と表記する。
 さらにディリクレ分布の正規化項(2.49)を用いると

$$ \begin{align} p(\mathbf{s}_{*}) &= C_D(\boldsymbol{\alpha}) \frac{ 1 }{ C_D\Bigl((s_{*,k} + \alpha_k)_{k=1}^K\Bigr) } \\ &= \frac{ \Gamma(\sum_{k=1}^K \alpha_k) }{ \prod_{k=1}^K \Gamma(\alpha_k) } \frac{ \prod_{k=1}^K \Gamma(s_{*,k} + \alpha_k) }{ \Gamma(\sum_{k=1}^K s_{*,k} + \alpha_k) } \tag{3.30} \end{align} $$

となる。

 式(3.30)について、$s_{*,i} = 1$となる($i$番目の項が1となる)場合を考える。このとき$\mathbf{s}_{*}$の$i$番目以外の項($j = 1, \cdots, i - 1, i + 1, \cdots, K$番目の項)は0である。よって、式(3.30)は

$$ \begin{align} p(s_{*,i} = 1) &= \frac{ \Gamma(\sum_{k=1}^K \alpha_k) }{ \prod_{k=1}^K \Gamma(\alpha_k) } \frac{ \Gamma(s_{*,i} + \alpha_i) \prod_{j \neq i} \Gamma(s_{*,j} + \alpha_j) }{ \Gamma(\sum_{k=1}^K s_{*,k} + \alpha_k) } \\ &= \frac{ \Gamma(\sum_{k=1}^K \alpha_k) }{ \prod_{k=1}^K \Gamma(\alpha_k) } \frac{ \Gamma(1 + \alpha_i) \prod_{j \neq i} \Gamma(0 + \alpha_j) }{ \Gamma(1 + \sum_{k=1}^K \alpha_k) } \\ &= \frac{ \Gamma(\sum_{k=1}^K \alpha_k) }{ \prod_{k=1}^K \Gamma(\alpha_k) } \frac{ \alpha_i \Gamma(\alpha_i) \prod_{j \neq i} \Gamma(\alpha_j) }{ (\sum_{k=1}^K \alpha_k) \Gamma(\sum_{k=1}^K \alpha_k) } \\ &= \frac{\alpha_i}{\sum_{k=1}^K \alpha_k} \tag{3.31} \end{align} $$

【途中式の途中式】

  1. $\prod_{k=1}^K$から$i$番目の項のみ取り出して、$\prod_{k=1}^K \Gamma(s_{*,k} + \alpha_k) = \Gamma(s_{*,i} + \alpha_i) \prod_{j \neq i} \Gamma(s_{*,j} + \alpha_j)$とする。
  2. $s_{*,i}$に1、それ以外の$s_{*,1}, \cdots, s_{*,k'-1}, s_{*,k'+1}, \cdots, s_{*,K}$に0を代入する。また、$\sum_{k=1}^K s_{*,k} = 1$である。
  3. ガンマ関数の性質$\Gamma(x + 1) = x \Gamma(x)$を用いて、項を変形する。
  4. $\prod_{k=1}^K \Gamma(\alpha_k) = \Gamma(\alpha_i) \prod_{j \neq i} \Gamma(\alpha_j)$より、約分して式を整理する。

となる。

 $s_{*,i}$以外の項$s_{*,1}, \cdots, s_{*,i-1}, s_{*,i+1}, \cdots, s_{*,K}$が1となる場合も同様に計算できるので、$x^0 = 1$であることを利用して$p(s_{*,1} = 1)$から$p(s_{*,K} = 1)$までをまとめると、式(3.30)は

$$ p(\mathbf{s}_{*}) = \prod_{k=1}^K \Bigl( \frac{\alpha_{k}}{\sum_{k'=1}^K \alpha_{k'}} \Bigr)^{s_{*,k}} $$

と書き換えられる。($s_{*,k} = 1$以外の項は0乗され1となり、式(3.31)が成り立つ。また、分母の$k'$は分子の$k$と区別するためのインデックスである。本ではこちらを$i$と表記することで区別している。)

 この式について

$$ \boldsymbol{\pi}_{*} = (\pi_{*,1}, \pi_{*,2}, \cdots, \pi_{*,K}),\ \pi_{*,k} = \frac{\alpha_k}{\sum_{k'=1}^K \alpha_{k'}} $$

とおくと

$$ p(\mathbf{s}_{*}) = \prod_{k=1}^K \pi_{*,k}^{s_{*,k}} = \mathrm{Cat}(\mathbf{s}_{*} | \boldsymbol{\pi}_{*}) \tag{3.32} $$

予測分布は式の形状から、パラメータ$\boldsymbol{\pi}_{*}$を持つカテゴリ分布であることが分かる。ちなみに、$\frac{\alpha_{k}}{\sum_{k'=1}^K \alpha_{k'}}$はディリクレ分布$\mathrm{Dir}(\boldsymbol{\pi} | \boldsymbol{\alpha})$の期待値$\mathbb{E}[\pi_k]$である。

 予測分布の計算に事前分布$p(\boldsymbol{\pi})$を用いることで、観測データによる学習を行っていない予測分布$p(\mathbf{s}_{*})$(のパラメータ$\boldsymbol{\pi}_{*}$)を求めた。事後分布$p(\boldsymbol{\pi} | \mathbf{S})$を用いると、同様の手順で観測データ$\mathbf{S}$によって学習した予測分布$p(\mathbf{s}_{*} | \mathbf{S})$(のパラメータ)を求められる。

 そこで、$\boldsymbol{\pi}_{*}$を構成する事前分布のパラメータ$\boldsymbol{\alpha}$について、事後分布のパラメータ(3.28)に置き換えたものを$\hat{\boldsymbol{\pi}}_{*}$とおくと

$$ \begin{aligned} \hat{\pi}_{*,k} &= \frac{ \hat{\alpha}_k }{ \sum_{k'=1}^K \hat{\alpha}_{k'} } \\ &= \frac{ \sum_{n=1}^N s_{n,k} + \alpha_k }{ \sum_{k'=1}^K \sum_{n=1}^N s_{n,k'} + \alpha_{k'} } \end{aligned} $$

となり、予測分布

$$ p(\mathbf{s}_{*} | \mathbf{S}) = \prod_{k=1}^K \hat{\pi}_{*,k}^{s_{*,k}} = \mathrm{Cat}(\mathbf{s}_{*} | \hat{\boldsymbol{\pi}}_{*}) $$

が得られる。
 また、上の式が予測分布のパラメータ$\hat{\pi}_{*,k}$の計算式(更新式)である。

参考文献

  • 須山敦志『ベイズ推論による機械学習入門』(機械学習スタートアップシリーズ)杉山将監修,講談社,2017年.

おわりに

 記事を分割したらコメントがなくなっちゃった。

【次節の内容】

www.anarchive-beta.com

2020/03/04:加筆修正しました。
2021/04/04:加筆修正しました。その際にRで実装編と記事を分割しました。途中微妙に間違ってた、、恥ずかしい。