からっぽのしょこ

読んだら書く!書いたら読む!読書読読書読書♪同じ事は二度調べ(たく)ない

3.4.2:多次元ガウス分布の学習と予測:精度が未知の場合【緑ベイズ入門のノート】

はじめに

 『ベイズ推論による機械学習入門』の学習時のノートです。基本的な内容は「数式の行間を読んでみた」とそれを「Rで組んでみた」になります。「数式」と「プログラム」から理解するのが目標です。

 この記事は3.4.2項の内容です。尤度関数を多次元ガウス分布、事前分布をウィシャート分布とした場合の精度パラメータの事後分布を導出し、また学習した事後分布を用いた予測分布を導出します。またその推論過程をR言語で実装します。

 省略してある内容等ありますので、本と併せて読んでください。初学者な自分が理解できるレベルまで落として書き下していますので、分かる人にはかなりくどくなっています。同じような立場の人のお役に立てれば幸いです。

【前節の内容】

www.anarchive-beta.com

【他の節一覧】

www.anarchive-beta.com

【この節の内容】

3.4.2 精度が未知の場合

・観測モデルの設定

 多次元ガウス分布に従うと仮定する$N$個の観測データ$\mathbf{X} = \{\mathbf{x}_1, \mathbf{x}_2, \cdots, \mathbf{x}_N\}$を用いて精度パラメータ$\boldsymbol{\Lambda}$の事後分布を求めていく。

 観測データ$\mathbf{X}$、平均パラメータ$\boldsymbol{\mu}$、未知の共分散行列$\boldsymbol{\Sigma}$、精度行列(パラメータ)$\boldsymbol{\Lambda}$をそれぞれ

$$ \mathbf{X} = \begin{bmatrix} x_{11} & x_{12} & \cdots & x_{1D} \\ x_{21} & x_{22} & \cdots & x_{2D} \\ \vdots & \vdots & \ddots & \vdots \\ x_{N1} & x_{N2} & \cdots & x_{ND} \end{bmatrix} ,\ \boldsymbol{\mu} = \begin{bmatrix} \mu_1 \\ \mu_2 \\ \vdots \\ \mu_D \end{bmatrix} ,\ \boldsymbol{\Sigma} = \begin{bmatrix} \sigma_{11}^2 & \sigma_{12}^2 & \cdots & \sigma_{1D}^2 \\ \sigma_{21}^2 & \sigma_{22}^2 & \cdots & \sigma_{2D}^2 \\ \vdots & \vdots & \ddots & \vdots \\ \sigma_{D1}^2 & \sigma_{D2}^2 & \cdots & \sigma_{DD}^2 \end{bmatrix} ,\ \boldsymbol{\Lambda} = \boldsymbol{\Sigma}^{-1} $$

としたとき、観測モデルは次のようになる。

$$ p(\mathbf{x} | \boldsymbol{\Lambda}) = \mathcal{N}(\mathbf{x} | \boldsymbol{\mu}, \boldsymbol{\Lambda}^{-1}) \tag{3.111} $$

 また$\boldsymbol{\Lambda}$の事前分布$p(\boldsymbol{\Lambda})$を、ウィシャート分布

$$ p(\boldsymbol{\Lambda}) = \mathcal{W}(\boldsymbol{\Lambda} | \nu, \mathbf{W}) \tag{3.112} $$

とする。$\nu$は自由度であり、$\nu > D - 1$の値をとる($\mathbf{W}$は何?)。

・事後分布の導出

 観測データ$\mathbf{X}$によって学習した$\boldsymbol{\Lambda}$の事後分布$p(\boldsymbol{\Lambda} | \mathbf{X})$は、観測モデルに対してベイズの定理を用いて

$$ \begin{aligned} p(\boldsymbol{\Lambda} | \mathbf{X}) &= \frac{ p(\mathbf{X} | \boldsymbol{\Lambda}) p(\boldsymbol{\Lambda}) }{ p(\mathbf{X}) } \\ &\propto p(\mathbf{X} | \boldsymbol{\Lambda}) p(\boldsymbol{\Lambda}) \\ &= \left( \prod_{n=1}^N p(\mathbf{x}_n | \boldsymbol{\Lambda}) \right) p(\boldsymbol{\Lambda}) \\ &= \left( \prod_{n=1}^N \mathcal{N}(\mathbf{x}_n | \boldsymbol{\mu}, \boldsymbol{\Lambda}^{-1}) \right) \mathcal{W}(\boldsymbol{\Lambda} | \nu, \mathbf{W}) \end{aligned} $$

となる。分母の$p(\mathbf{X})$は$\boldsymbol{\Lambda}$に影響しないため省略して、比例関係にのみ注目する。省略した部分については、最後に正規化することで対応できる。

 次にこの分布の具体的な形状を明らかにしていく。対数をとって指数部分の計算を分かりやすくして、$\boldsymbol{\Lambda}$に関して整理すると

$$ \begin{align} \ln p(\boldsymbol{\Lambda} | \mathbf{X}) &= \sum_{n=1}^N \ln \mathcal{N}(\mathbf{x}_n | \boldsymbol{\mu}, \boldsymbol{\Lambda}^{-1}) + \ln \mathcal{W}(\boldsymbol{\Lambda} | \nu, \mathbf{W}) + \mathrm{const.} \tag{3.113}\\ &= \sum_{n=1}^N - \frac{1}{2} \Bigl\{ (\mathbf{x}_n - \boldsymbol{\mu})^{\top} \boldsymbol{\Lambda} (\mathbf{x}_n - \boldsymbol{\mu}) + \ln |\boldsymbol{\Lambda}^{-1}| + D \ln 2 \pi \Bigr\} \\ &\qquad + \frac{\nu - D - 1}{2} \ln |\boldsymbol{\Lambda}| - \frac{1}{2} \mathrm{Tr}(\mathbf{W}^{-1} \Lambda) + \ln C_{\mathcal{W}}(\nu, \mathbf{W}) + \mathrm{const.} \\ &= - \frac{1}{2} \Biggl\{ \sum_{n=1}^N \sum_{d=1}^D \sum_{d'=1}^D (x_{nd} - \mu_d) \lambda_{dd'} (x_{nd'} - \mu_{d'}) \Biggr\} - \frac{N}{2} \ln |\boldsymbol{\Lambda}|^{-1} \\ &\qquad + \frac{\nu - D - 1}{2} \ln |\boldsymbol{\Lambda}| - \frac{1}{2} \mathrm{Tr}(\mathbf{W}^{-1} \Lambda) + \mathrm{const.} \\ &= - \frac{1}{2} \mathrm{Tr} \Bigl( \sum_{n=1}^N (\mathbf{x}_n - \boldsymbol{\mu}) (\mathbf{x}_n - \boldsymbol{\mu})^{\top} \boldsymbol{\Lambda} \Bigr) + \frac{N}{2} \ln |\boldsymbol{\Lambda}| \\ &\qquad + \frac{\nu - D - 1}{2} \ln |\boldsymbol{\Lambda}| - \frac{1}{2} \mathrm{Tr}(\mathbf{W}^{-1} \Lambda) + \mathrm{const.} \\ &= \frac{N + \nu - D - 1}{2} \ln |\boldsymbol{\Lambda}| - \frac{1}{2} \mathrm{Tr} \left( \Bigl\{ \sum_{n=1}^N (\mathbf{x}_n - \mu) (\mathbf{x}_n - \mu)^{\top} + \mathbf{W}^{-1} \Bigr\} \boldsymbol{\Lambda} \right) + \mathrm{const.} \tag{3.114} \end{align} $$

【途中式の途中式】

  1. それぞれ具体的な式に置き換える。
  2. 1つ目の因子をトレース$\mathrm{Tr}(\cdot)$を使って置き換える。

 精度パラメータを

$$ \boldsymbol{\Lambda} = \begin{bmatrix} \lambda_{11} & \cdots & \lambda_{1D} \\ \vdots & \ddots & \vdots \\ \lambda_{D1} & \cdots & \lambda_{DD} \end{bmatrix} $$

として、また計算を分かりやすくするため$(\mathbf{x}_n - \mu)^{\top} = (x_{n1} - \mu_1, x_{n2} - \mu_2, \cdots, x_{nD} - \mu_D) = (\tilde{x}_{n1}, \tilde{x}_{n2}, \cdots, \tilde{x}_{nD})$とおくと(他に適切な記号があれば教えて)、$(\mathbf{x}_n - \boldsymbol{\mu})^{\top} \boldsymbol{\Lambda} (\mathbf{x}_n - \boldsymbol{\mu})$は

$$ \begin{aligned} (\mathbf{x}_n - \boldsymbol{\mu})^{\top} \boldsymbol{\Lambda} (\mathbf{x}_n - \boldsymbol{\mu}) &= \begin{bmatrix} \tilde{x}_{n1} & \tilde{x}_{n2} & \cdots & \tilde{x}_{nD} \end{bmatrix} \begin{bmatrix} \lambda_{11} & \lambda_{12} & \cdots & \lambda_{1D} \\ \lambda_{21} & \lambda_{22} & \cdots & \lambda_{2D} \\ \vdots & \vdots & \ddots & \vdots \\ \lambda_{D1} & \lambda_{D2} & \cdots & \lambda_{DD} \end{bmatrix} \begin{bmatrix} \tilde{x}_{n1} \\ \tilde{x}_{n2} \\ \vdots \\ \tilde{x}_{nD} \end{bmatrix} \\ &= \begin{bmatrix} \sum_{d=1}^D \tilde{x}_{nd} \lambda_{d1} & \sum_{d=1}^D \tilde{x}_{nd} \lambda_{d2} & \cdots & \sum_{d=1}^D \tilde{x}_{nd} \lambda_{dD} \end{bmatrix} \begin{bmatrix} \tilde{x}_{n1} \\ \tilde{x}_{n2} \\ \vdots \\ \tilde{x}_{nD} \end{bmatrix} \\ &= \sum_{d'=1}^D \sum_{d=1}^D \tilde{x}_{nd} \lambda_{dd'} \tilde{x}_{nd'} \end{aligned} $$

となる。

 この式を整理するために、$(\mathbf{x}_n - \boldsymbol{\mu}) (\mathbf{x}_n - \boldsymbol{\mu})^{\top} \boldsymbol{\Lambda}$を考える。

$$ \begin{aligned} (\mathbf{x}_n - \boldsymbol{\mu}) (\mathbf{x}_n - \boldsymbol{\mu})^{\top} \boldsymbol{\Lambda} &= \begin{bmatrix} \tilde{x}_{n1} \\ \tilde{x}_{n2} \\ \vdots \\ \tilde{x}_{nD} \end{bmatrix} \begin{bmatrix} \tilde{x}_{n1} & \tilde{x}_{n2} & \cdots & \tilde{x}_{nD} \end{bmatrix} \begin{bmatrix} \lambda_{11} & \lambda_{12} & \cdots & \lambda_{1D} \\ \lambda_{21} & \lambda_{22} & \cdots & \lambda_{2D} \\ \vdots & \vdots & \ddots & \vdots \\ \lambda_{D1} & \lambda_{D2} & \cdots & \lambda_{DD} \end{bmatrix} \\ &= \begin{bmatrix} \tilde{x}_{n1}^2 & \tilde{x}_{n1} \tilde{x}_{n2} & \cdots & \tilde{x}_{n1} \tilde{x}_{nD} \\ \tilde{x}_{n2} \tilde{x}_{n1} & \tilde{x}_{n2}^2 & \cdots & \tilde{x}_{n2} \tilde{x}_{nD} \\ \vdots & \vdots & \ddots & \vdots \\ \tilde{x}_{nD} \tilde{x}_{n1} & \tilde{x}_{nD} \tilde{x}_{n2} & \cdots & \tilde{x}_{nD}^2 \end{bmatrix} \begin{bmatrix} \lambda_{11} & \lambda_{12} & \cdots & \lambda_{1D} \\ \lambda_{21} & \lambda_{22} & \cdots & \lambda_{2D} \\ \vdots & \vdots & \ddots & \vdots \\ \lambda_{D1} & \lambda_{D2} & \cdots & \lambda_{DD} \end{bmatrix} \\ &= \begin{bmatrix} \sum_{d=1}^D \tilde{x}_{n1} \tilde{x}_{nd} \lambda_{d1} & \sum_{d=1}^D \tilde{x}_{n1} \tilde{x}_{nd} \lambda_{d2} & \cdots & \sum_{d=1}^D \tilde{x}_{n1} \tilde{x}_{nd} \lambda_{dD} \\ \sum_{d=1}^D \tilde{x}_{n2} \tilde{x}_{nd} \lambda_{d1} & \sum_{d=1}^D \tilde{x}_{n2} \tilde{x}_{nd} \lambda_{d2} & \cdots & \sum_{d=1}^D \tilde{x}_{n2} \tilde{x}_{nd} \lambda_{dD} \\ \vdots & \vdots & \ddots & \vdots \\ \sum_{d=1}^D \tilde{x}_{nD} \tilde{x}_{nd} \lambda_{d1} & \sum_{d=1}^D \tilde{x}_{nD} \tilde{x}_{nd} \lambda_{d2} & \cdots & \sum_{d=1}^D \tilde{x}_{nD} \tilde{x}_{nd} \lambda_{dD} \end{bmatrix} \end{aligned} $$

 $(\mathbf{x}_n - \boldsymbol{\mu}) (\mathbf{x}_n - \boldsymbol{\mu})^{\top} \boldsymbol{\Lambda}$は、$D \times D$の正方行列になる。この行列の対角成分の和$\mathrm{Tr} \Bigl((\mathbf{x}_n - \boldsymbol{\mu}) (\mathbf{x}_n - \boldsymbol{\mu})^{\top} \boldsymbol{\Lambda}\Bigr)$は

$$ \begin{aligned} \mathrm{Tr} \Bigl( (\mathbf{x}_n - \boldsymbol{\mu}) (\mathbf{x}_n - \boldsymbol{\mu})^{\top} \boldsymbol{\Lambda} \Bigr) &= \sum_{d=1}^D \tilde{x}_{n1} \tilde{x}_{nd} \lambda_{d1} + \sum_{d=1}^D \tilde{x}_{n2} \tilde{x}_{nd} \lambda_{d2} + \cdots + \sum_{d=1}^D \tilde{x}_{nD} \tilde{x}_{nd} \lambda_{dD} \\ &= \sum_{d'=1}^D \sum_{d=1}^D \tilde{x}_{nd'} \tilde{x}_{nd} \lambda_{dd'} \end{aligned} $$

となる。

 よって

$$ (\mathbf{x}_n - \boldsymbol{\mu})^{\top} \boldsymbol{\Lambda} (\mathbf{x}_n - \boldsymbol{\mu}) = \mathrm{Tr} \Bigl( (\mathbf{x}_n - \boldsymbol{\mu}) (\mathbf{x}_n - \boldsymbol{\mu})^{\top} \boldsymbol{\Lambda} \Bigr) $$

であることが分かる。

  1. 式(A.12)より、$\mathrm{Tr} \Bigl(\sum_{n=1}^N (\mathbf{x}_n - \mu) (\mathbf{x}_n - \mu)^{\top} \boldsymbol{\Lambda}\Bigr) + \mathrm{Tr}(\mathbf{W}^{-1} \boldsymbol{\Lambda})= \mathrm{Tr} \Bigl(\sum_{n=1}^N (\mathbf{x}_n - \mu) (\mathbf{x}_n - \mu)^{\top} \boldsymbol{\Lambda} + \mathbf{W}^{-1} \boldsymbol{\Lambda}\Bigr)$であることから、式を整理する。


となる。適宜$\boldsymbol{\Lambda}$に影響しない項を$\mathrm{const.}$にまとめている。

 式の形から事後分布もウィシャート分布となることが分かる。そこで事後分布を次のようにおき

$$ p(\boldsymbol{\Lambda} | \mathbf{X}) = \mathcal{W}(\boldsymbol{\Lambda} | \hat{\nu}, \hat{\mathbf{W}}) \tag{3.115} $$

この式の対数をとり、$\boldsymbol{\Lambda}$に関して整理すると

$$ \ln p(\boldsymbol{\Lambda} | \mathbf{X}) = \frac{\hat{\nu} - D - 1}{2} \ln |\boldsymbol{\Lambda}| - \frac{1}{2} \mathrm{Tr}(\hat{\mathbf{W}}^{-1} \boldsymbol{\Lambda}) + \mathrm{const.} $$

となる。

 従って式(3.114)との対応関係から、事後分布のパラメータは

$$ \begin{align} \hat{\mathbf{W}} &= \sum_{n=1}^N (\mathbf{x}_n - \mu) (\mathbf{x}_n - \mu)^{\top} + \mathbf{W}^{-1} \\ \hat{\nu} &= N + \nu \tag{3.116} \end{align} $$

と求められる。
 ちなみに分散共分散行列$\boldsymbol{\Sigma}$の事後分布の場合、逆ウィシャート分布になる。

・予測分布の導出

 続いて未観測のデータ$\mathbf{x}_{*}$に対する予測分布$p(\mathbf{x}_{*})$を求めていく。

 1次元ガウス分布のときと同様に、パラメータ$\boldsymbol{\mu}$を周辺化(積分計算)して予測分布

$$ p(\mathbf{x}_{*}) = \int p(\mathbf{x}_{*} | \boldsymbol{\Lambda}) p(\boldsymbol{\Lambda}) d\boldsymbol{\Lambda} $$

を求めることは避け、ベイズの定理を用いて

$$ p(\boldsymbol{\Lambda} | \mathbf{x}_{*}) = \frac{ p(\mathbf{x}_{*} | \boldsymbol{\Lambda}) p(\boldsymbol{\Lambda}) }{ p(\mathbf{x}_{*}) } $$

予測分布を求める。

 この式の両辺の対数をとり、$p(\mathbf{x}_{*})$に関して式を整理すると

$$ \begin{align} \ln p(\boldsymbol{\Lambda} | \mathbf{x}_{*}) &= \ln p(\mathbf{x}_{*} | \boldsymbol{\Lambda}) - \ln p(\mathbf{x}_{*}) + \mathrm{const.} \\ \ln p(\mathbf{x}_{*}) &= \ln p(\mathbf{x}_{*} | \boldsymbol{\Lambda}) - \ln p(\boldsymbol{\Lambda} | \mathbf{x}_{*}) + \mathrm{const.} \tag{3.117} \end{align} $$

で計算できることが分かる。

 2つ目の項は、データが1つ($N = 1$)の事後分布(3.115)と捉えられることから、式(3.116)よりパラメータを

$$ \begin{align} \mathbf{W}_{\mathbf{x}_{*}}^{-1} &= (\mathbf{x}_{*} - \mu) (\mathbf{x}_{*} - \mu)^{\top} + \mathbf{W}^{-1} \\ \nu_{\mathbf{x}_{*}} &= 1 + \nu \tag{3.119} \end{align} $$

とおくと

$$ p(\mathbf{x}_{*} | \boldsymbol{\Lambda}) = \mathcal{W}(\boldsymbol{\Lambda} | \nu_{\mathbf{x}_{*}}, \mathbf{W}_{\mathbf{x}_{*}}) \tag{3.118} $$

となる。

 この式を式(3.117)に代入して、$\mathbf{x}_{*}$に関して式を整理すると

$$ \begin{align} \ln p(\mathbf{x}_{*}) &= \mathcal{N}(\mathbf{x}_{*} | \boldsymbol{\Lambda}) - \mathcal{W}(\boldsymbol{\Lambda} | \nu_{\mathbf{x}_{*}}, \mathbf{W}_{\mathbf{x}_{*}}) + \mathrm{const.} \\ &= - \frac{1}{2} \Bigl\{ (\mathbf{x}_{*} - \boldsymbol{\mu})^{\top} \boldsymbol{\Lambda} (\mathbf{x}_{*} - \boldsymbol{\mu}) + \ln |\boldsymbol{\Lambda}^{-1}| + D \ln 2 \pi \Bigr\} \\ &\qquad - \frac{\nu_{\mathbf{x}_{*}} - D - 1}{2} \ln |\boldsymbol{\Lambda}| + \frac{1}{2} \mathrm{Tr}(\mathbf{W}_{\mathbf{x}_{*}}^{-1} \boldsymbol{\Lambda}) \\ &\qquad + \frac{\nu_{\mathbf{x}_{*}}}{2} \ln |\mathbf{W}_{\mathbf{x}_{*}}| + \frac{\nu_{\mathbf{x}_{*}} D}{2} \ln 2 + \frac{D (D - 1)}{4} \ln \pi + \sum_{d=1}^D \ln \Gamma \Bigl(\frac{\nu_{\mathbf{x}_{*}} + 1 - d}{2} \Bigr) + \mathrm{const.} \\ &= - \frac{1}{2} (\mathbf{x}_{*} - \boldsymbol{\mu})^{\top} \boldsymbol{\Lambda} (\mathbf{x}_{*} - \boldsymbol{\mu}) \\ &\qquad + \frac{1}{2} \mathrm{Tr} \Bigl( (\mathbf{x}_{*} - \mu) (\mathbf{x}_{*} - \mu)^{\top} \boldsymbol{\Lambda} + \mathbf{W}^{-1} \boldsymbol{\Lambda} \Bigr) \\ &\qquad + \frac{1 + \nu}{2} \ln \Bigl| \mathbf{W}^{-1} + (\mathbf{x}_{*} - \mu) (\mathbf{x}_{*} - \mu)^{\top} \Bigr|^{-1} + \mathrm{const.} \\ &= - \frac{1}{2} (\mathbf{x}_{*} - \boldsymbol{\mu})^{\top} \boldsymbol{\Lambda} (\mathbf{x}_{*} - \boldsymbol{\mu}) \\ &\qquad + \frac{1}{2} \mathrm{Tr} \Bigl( (\mathbf{x}_{*} - \mu) (\mathbf{x}_{*} - \mu)^{\top} \boldsymbol{\Lambda} \Bigr) + \frac{1}{2} \mathrm{Tr} \Bigl( \mathbf{W}^{-1} \boldsymbol{\Lambda} \Bigr) \\ &\qquad - \frac{1 + \nu}{2} \ln |\mathbf{W}^{-1}| \Bigl| \mathbf{I}_D + \mathbf{W} (\mathbf{x}_{*} - \mu) (\mathbf{x}_{*} - \mu)^{\top} \Bigr| + \mathrm{const.} \\ &= - \frac{1 + \nu}{2} \ln \Bigl| \mathbf{I}_1 + \{ \mathbf{W} (\mathbf{x}_{*} - \boldsymbol{\mu}) \}^{\top} (\mathbf{x}_{*} - \boldsymbol{\mu}) \Bigr| - \frac{1 + \nu}{2} \ln |\mathbf{W}^{-1}| + \mathrm{const.} \\ &= - \frac{1 + \nu}{2} \ln \Bigl\{ 1 + (\mathbf{x}_{*} - \boldsymbol{\mu})^{\top} \mathbf{W}^{\top} (\mathbf{x}_{*} - \boldsymbol{\mu}) \Bigr\} + \mathrm{const.} \\ &= - \frac{1 + \nu}{2} \ln \Bigl\{ 1 + (\mathbf{x}_{*} - \boldsymbol{\mu})^{\top} \mathbf{W} (\mathbf{x}_{*} - \boldsymbol{\mu}) \Bigr\} + \mathrm{const.} \tag{3.120} \end{align} $$

【途中式の途中式】

  1. 具体的な式に置き換える。
  2. $\nu_{\mathbf{x}_{*}},\ \mathbf{W_{\mathbf{x}_{*}}}$にそれぞれ式(3.117)を代入する。このとき式(A.18)より、$|\mathbf{W}_{\mathbf{x}_{*}}| = |\mathbf{W}_{\mathbf{x}_{*}}^{-1}|^{-1}$である。
  3. 2行目の項は式(A.12)より、3行目の項は$\mathbf{W}^{-1} + (\mathbf{x}_{*} - \mu) (\mathbf{x}_{*} - \mu)^{\top} = \mathbf{W}^{-1} \{\mathbf{I}_D + \mathbf{W} (\mathbf{x}_{*} - \mu) (\mathbf{x}_{*} - \mu)^{\top}\}$とできるので式(A.17)より、それぞれ分解する。
  4. 式を整理する。
    • 事後分布の導出時より、$(\mathbf{x}_n - \boldsymbol{\mu})^{\top} \boldsymbol{\Lambda} (\mathbf{x}_n - \boldsymbol{\mu}) = \mathrm{Tr} ((\mathbf{x}_n - \boldsymbol{\mu}) (\mathbf{x}_n - \boldsymbol{\mu})^{\top}\boldsymbol{\Lambda})$である。
    • $\mathbf{W} (\mathbf{x}_{*} - \boldsymbol{\mu})$を1つの行列とみると、$(D \times 1)$、$(1 \times D)$の行列の積なので、式(A.19)の変形を行う。
  5. 行列式の定義より、スカラの行列式はスカラ$|a| = a$である。また式(A.2)より、転置する。
  6. $(\mathbf{x}_{*} - \boldsymbol{\mu})^{\top} \mathbf{W}^{\top} (\mathbf{x}_{*} - \boldsymbol{\mu})$はスカラなので転置しても影響を受けないため、$(\mathbf{x}_{*} - \boldsymbol{\mu})^{\top} \mathbf{W}^{\top} (\mathbf{x}_{*} - \boldsymbol{\mu}) = \{(\mathbf{x}_{*} - \boldsymbol{\mu})^{\top} \mathbf{W}^{\top} (\mathbf{x}_{*} - \boldsymbol{\mu})\}^{\top}$である。また式(A.2)と同様に、$(A B C)^{\top} = C^{\top} B^{\top} A^{\top}$である。


となる。適宜$\mathbf{x}_{*}$に影響しない項を$\mathrm{const.}$にまとめている。

 式の形から予測分布は多次元のスチューデントのt分布となることが分かる。そこで次のようにおき

$$ p(\mathbf{x}_{*}) = \mathrm{St}(\mathbf{x}_{*} | \boldsymbol{\mu}_s, \boldsymbol{\Lambda}_s, \nu_s) \tag{3.123} $$

この式の対数をとり、$\mathbf{x}_{*}$に関して整理すると

$$ \ln \mathrm{St}(\mathbf{x} | \boldsymbol{\mu}_s, \boldsymbol{\Lambda}_s, \nu_s) = - \frac{\nu_s + D}{2} \ln \left\{ 1 + \frac{1}{\nu_s} (\mathbf{x}_{*} - \boldsymbol{\mu}_s)^{\top} \boldsymbol{\Lambda}_s (\mathbf{x}_{*} - \boldsymbol{\mu}_s) \right\} + \mathrm{const.} \tag{3.122} $$

となる。

 従って式(3.120)との対応関係から、予測分布のパラメータは

$$ \begin{align} \boldsymbol{\mu}_s &= \boldsymbol{\mu} \\ \boldsymbol{\Lambda}_s &= \nu_s \mathbf{W} \\ &= (1 - D + \nu) \mathbf{W} \\ \nu_s &= 1 - D + \nu \tag{3.124} \end{align} $$

と求まる。

 事前分布のパラメータ$\nu,\ \mathbf{W}$を事後分布のパラメータ$\hat{\nu},\ \hat{\mathbf{W}}$に置き換えると、観測データによって学習を行った予測分布$p(\mathbf{x}_{*} | \mathbf{X})$のパラメータ

$$ \begin{aligned} \boldsymbol{\mu}_s &= \boldsymbol{\mu} \\ \hat{\boldsymbol{\Lambda}}_s &= (1 - D + \hat{\nu}) \hat{\mathbf{W}} \\ &= (1 - D + N + \nu) \left\{ \sum_{n=1}^N (\mathbf{x}_n - \mu) (\mathbf{x}_n - \mu)^{\top} + \mathbf{W}^{-1} \right\} \\ \hat{\nu}_s &= 1 - D + \hat{\nu} \\ &= 1 - D + N + \nu \end{aligned} $$

が得られる。

・Rでやってみよう

 多次元ガウス分布に従いランダムに生成したデータを用いて、パラメータを推定してみましょう。

 利用するパッケージを読み込みます。

# 利用パッケージ
library(tidyverse)
library(mvtnorm)
library(mvnfast)

 mvtnormパッケージは、多次元ガウス分布に関するパッケージです。多次元ガウス分布に従う乱数生成関数rmvnorm()と確率密度計算関数dmvnorm()を使います。
 `mvnfastパッケージは、多次元スチューデントのt分布の確率密度関数dmvt()を使います。

・観測モデルの設定

 観測モデルのパラメータを指定します。この例では2次元のグラフで表現するため、$D = 2$のときのみ動作します。

# データ数を指定
N <- 100
D <- 2 # (固定)

# 観測モデルのパラメータを指定
mu_d <- c(25, 50)
sigma_truth_dd <- matrix(c(50, 30, 30, 50), nrow = 2, ncol = 2)
lambda_truth_dd <- solve(sigma_truth_dd^2)

 観測モデルの平均パラメータ$\boldsymbol{\mu}$をmu_dとします。
 分散共分散行列$\boldsymbol{\Sigma}$ではなく、分散共分散の平方根$(\sigma_{11}, \cdots, \sigma_{DD})$をsigma_truth_ddとします。これは作図時に標準偏差を利用するためです。matrix()のデフォルトの仕様上、$(\sigma_{11}, \sigma_{12}, \sigma_{21}, \sigma_{22})$の順番に値を指定します。
 sigma_truth_ddの各要素を2乗したものが$\boldsymbol{\Sigma}$です。更にその逆行列が精度パラメータ$\boldsymbol{\Lambda}$でした。逆行列の計算はsolve()で行い、その計算結果をlambda_truth_ddとします。この値を推論するのがこの項の目的になります。

 次に事前分布のパラメータを指定します。

# 事前分布のパラメータを指定
W_dd <- matrix(c(0.00005, 0, 0, 0.00005), nrow = 2, ncol = 2)
nu <- 2

 ウィシャート事前分布のパラメータ$\mathbf{W},\ \nu$をそれぞれW_ddnuとします。

 作図時に利用する格子状の点(データ)と平均の点のデータフレームを作成しておきます。

# 作図用の点を生成
x_vec <- seq(mu_d[1] - 2 * sigma_truth_dd[1, 1], mu_d[1] + 2 * sigma_truth_dd[1, 1], by = 0.5)
y_vec <- seq(mu_d[2] - 2 * sigma_truth_dd[2, 2], mu_d[2] + 2 * sigma_truth_dd[2, 2], by = 0.5)
point_df <- tibble(
  x = rep(x_vec, times = length(y_vec)), 
  y = rep(y_vec, each = length(x_vec))
)
mu_df <- tibble(
  x = mu_d[1], 
  y = mu_d[2]
)

 この例では、平均から標準偏差の2倍の範囲をグラフ化することにします。

 設定した観測モデルのパラメータに従って乱数を生成します。

# 2次元ガウス分布に従うデータを生成
x_nd <- mvtnorm::rmvnorm(n = N, mean = mu_d, sigma = sigma_truth_dd^2)
summary(x_nd)
##        V1               V2        
##  Min.   :-95.30   Min.   :-49.47  
##  1st Qu.: -9.20   1st Qu.: 22.05  
##  Median : 29.89   Median : 62.35  
##  Mean   : 33.45   Mean   : 61.04  
##  3rd Qu.: 75.12   3rd Qu.: 99.86  
##  Max.   :158.02   Max.   :166.78

 mvtnorm::rmvnorm()で多次元ガウス分布に従う乱数を生成し、観測データ$\mathbf{X}$(x_nd)とします。
 引数nにはデータ数N、引数meanには平均パラメータmu_d、引数sigmaにはsimga_truth_ddの2乗を指定します。精度行列の逆行列solve(lambda_truth_dd)を引数sigmaに指定することもできます。

 観測データを散布図で確認しましょう。

# 観測データのデータフレーム
sample_df <- tibble(
  x = x_nd[, 1], 
  y = x_nd[, 2]
)

# 観測モデルのデータフレーム
model_df <- tibble(
  xy = point_df, 
  density = mvtnorm::dmvnorm(x = xy, mean = mu_d, sigma = sigma_truth_dd^2), # 確率密度
) %>% 
  dplyr::select(density) %>% 
  cbind(point_df, .)

# 観測データの散布図を作成
ggplot() + 
  geom_point(data = sample_df, aes(x = x, y = y)) + # 観測データ
  geom_contour(data = model_df, aes(x, y, z = density, color = ..level..)) + # 観測モデル
  geom_point(data = mu_df, aes(x = x, y = y), color = "red", shape = 3, size = 5) + # 平均パラメータ
  labs(title = "Multivariate Gaussian Distribution", 
       subtitle = paste0("N=", N, ", mu=(", paste(round(mu_d, 1), collapse = ", "), ")", 
                         ", sigma=(", paste(round(sigma_truth_dd, 1), collapse = ", "), ")"), 
       x = expression(x[1]), y = expression(x[2]), 
       color = "density") # ラベル

f:id:anemptyarchive:20201021090753p:plain
観測データの散布図

 観測モデルの分布と重ねて描画します。多次元ガウス分布の確率密度は、mvtnorm::dmvnorm()で計算します。等高線グラフgeom_contour()には、格子状の点を渡す必要があります。

 では事後分布を求めていきます。

・事後分布

 観測データを使って、事後分布のパラメータを計算します。

# 事後分布のパラメータを計算
W_hat_dd <- solve(
  (t(x_nd) - mu_d) %*% t(t(x_nd) - mu_d) + solve(W_dd)
)
nu_hat <- N + nu

 事後分布のパラメータ$\hat{\mathbf{W}},\ \hat{\nu}$をそれぞれW_hat_ddnu_hatとします。計算式(更新式)はそれぞれ次の式です。

$$ \begin{align} \hat{\mathbf{W}} &= \sum_{n=1}^N (\mathbf{x}_n - \mu) (\mathbf{x}_n - \mu)^{\top} + \mathbf{W}^{-1} \\ \hat{\nu} &= N + \nu \tag{3.116} \end{align} $$

 ただし上の実装例では、効率よく処理するために転置して計算しています。この式の通りに計算するには、次のような処理となります。

sum_dot_x <- matrix(0, D, D)
for(n in 1:N) {
  dot_x <- matrix(x_nd[n, ] - mu_d) %*% t(x_nd[n, ] - mu_d)
  sum_dot_x <- sum_dot_x + dot_x
}
W_hat_dd <- solve(sum_dot_x + solve(W_dd))

 for()によってN回加算する処理が$\sum_{n=1}^N$の計算に対応します。

 精度行列$\boldsymbol{\Lambda}$の事後分布$p(\boldsymbol{\Lambda} | \mathbf{X})$のパラメータを計算できました。これまでのように、このパラメータを用いて事後分布(ウィシャート分布)を計算できます。しかし精度パラメータの分布はイメージしにくいため、精度パラメータの期待値$\mathbb{E}[\hat{\boldsymbol{\Lambda}}]$を用いた多次元ガウス分布を可視化することで、観測モデルと比較しましょう。
 ちなみにウィシャート分布の確率密度はMCMCpack::dwish()で計算できます。

# 精度パラメータの期待値を計算
lambda_E_dd <- nu_hat * W_hat_dd

 ウィシャート分布の期待値(2.89)より

$$ \mathbb{E}[\hat{\boldsymbol{\Lambda}}] = \hat{\nu} \hat{\mathbf{W}} $$

で計算できます。

 求めたパラメータを使って、確率密度を計算します。

# 事後分布の期待値を用いた分布を計算
posterior_df <- tibble(
  xy = point_df, 
  density = mvtnorm::dmvnorm(x = xy, mean = mu_d, sigma = solve(lambda_E_dd)) # 確率密度
) %>% 
  dplyr::select(density) %>% 
  cbind(point_df, .)
head(posterior_df)
##       x   y      density
## 1 -75.0 -50 5.682318e-06
## 2 -74.5 -50 5.742626e-06
## 3 -74.0 -50 5.802989e-06
## 4 -73.5 -50 5.863398e-06
## 5 -73.0 -50 5.923840e-06
## 6 -72.5 -50 5.984302e-06

 作図用にデータフレームとしておきます。

 多次元ガウス分布を作図します。

# 作図
ggplot() + 
  geom_contour(data = posterior_df, aes(x, y, z = density, color = ..level..)) + # 精度の期待値を用いた分布
  geom_contour(data = model_df, aes(x, y, z = density, color = ..level..), 
               alpha = 0.5, linetype = "dashed") + # 観測モデル
  geom_point(data = mu_df, aes(x = x, y = y), color = "red", shape = 3, size = 5) + # 平均値
  labs(title = "Multivariate Gaussian Distribution", 
       subtitle = paste0("N=", N, ", mu=(", paste(round(mu_d, 1), collapse = ", "), ")", 
                         ", E_sigma_hat=(", paste(round(sqrt(solve(lambda_E_dd)), 1), collapse = ", "), ")"), 
       x = expression(x[1]), y = expression(x[2]), 
       color = "density") # ラベル

f:id:anemptyarchive:20201021090952p:plain
精度パラメータの事後分布の期待値$\hat{\boldsymbol{\Lambda}}$を用いた分布

 観測データの散布図のときと同様に、観測モデル(真の分布)と重ねて描画します。

 続いて予測分布を求めていきます。

・予測分布

 事後分布のパラメータを使って、予測分布のパラメータを計算します。

# 予測分布のパラメータを計算
mu_s_d <- mu_d
lambda_s_hat_dd <- (1 - D + nu_hat) * W_hat_dd
nu_s_hat <- 1 - D + nu_hat

 予測分布(多次元スチューデントのt分布)のパラメータ$\boldsymbol{\mu}_s,\ \hat{\boldsymbol{\Lambda}}_s,\ \hat{\nu}_s$をそれぞれmu_s_dlambda_s_hat_ddnu_s_hatとします。計算式(更新式)はそれぞれ次の式です。

$$ \begin{aligned} \boldsymbol{\mu}_s &= \boldsymbol{\mu} \\ \hat{\boldsymbol{\Lambda}}_s &= (1 - D + \hat{\nu}) \hat{\mathbf{W}} \\ \hat{\nu}_s &= 1 - D + \hat{\nu} \end{aligned} $$

 求めたパラメータを使って、予測分布(の確率密度)を計算します。

# 予測分布を計算
predict_df <- cbind(
  point_df, 
  density = mvnfast::dmvt(
    X = as.matrix(point_df), mu = mu_s_d, sigma = solve(lambda_s_hat_dd), df = nu_s_hat
  ) # 確率密度
)
head(predict_df)
##       x   y      density
## 1 -75.0 -50 5.801869e-06
## 2 -74.5 -50 5.861283e-06
## 3 -74.0 -50 5.920743e-06
## 4 -73.5 -50 5.980238e-06
## 5 -73.0 -50 6.039755e-06
## 6 -72.5 -50 6.099284e-06

 多次元スチューデントのt分布の確率密度は、mvnfast::dmvt()で計算できます。dmvt()の第1引数には、複数のデータをマトリクスで渡すことができます。また平均引数muにはmu_s_d、スケール引数sigmaにはlambda_s_hat_ddの逆行列、自由度引数dfにはnu_s_hatを指定します。
 mvtnorm::dmvt()でも計算できるはずですが、ちょっとよく分からなかった…。

 予測分布を作図します。

# 作図
ggplot() + 
  geom_contour(data = predict_df, aes(x, y, z = density, color = ..level..)) + # 予測分布
  geom_contour(data = model_df, aes(x, y, z = density, color = ..level..), 
               alpha = 0.5, linetype = "dashed") + # 観測モデル
  geom_point(data = mu_df, aes(x = x, y = y), color = "red", shape = 3, size = 5) + # 平均パラメータ
  labs(title = "Multivariate Student's t Distribution", 
       subtitle = paste0("N=", N, ", mu_s=(", paste(round(mu_s_d, 1), collapse = ", "), ")", 
                         ", lambda_s_hat=(", paste(round(lambda_s_hat_dd, 1), collapse = ", "), ")", 
                         ", df=", nu_s_hat), 
       x = expression(x[1]), y = expression(x[2]), 
       color = "density") # ラベル

f:id:anemptyarchive:20201021091116p:plain
未知データ$\mathbf{x}_{*}$の予測分布

 これまでと同様に作図できます。

・おまけ

 gganimateパッケージを利用して、パラメータの推定値の推移を確認するためのgif画像を作成するコードです。

・コード(クリックで展開)

# 追加パッケージ
library(gganimate)

# データ数を指定
N <- 100
D <- 2 # (固定)

# 観測モデルのパラメータを指定
mu_d <- c(25, 50)
sigma_truth_dd <- matrix(c(50, 30, 30, 50), nrow = 2, ncol = 2)
lambda_truth_dd <- solve(sigma_truth_dd^2)

# 事前分布のパラメータを指定
nu <- 2
W_dd <- matrix(c(0.00005, 0, 0, 0.00005), nrow = 2, ncol = 2)

# 精度パラメータの期待値を計算
lambda_E_dd <- nu * W_dd

# 作図用の点を生成
x_vec <- seq(mu_d[1] - 3 * sigma_truth_dd[1, 1], mu_d[1] + 3 * sigma_truth_dd[1, 1], by = 1)
y_vec <- seq(mu_d[2] - 3 * sigma_truth_dd[2, 2], mu_d[2] + 3 * sigma_truth_dd[2, 2], by = 1)
point_df <- tibble(
  x = rep(x_vec, times = length(y_vec)), 
  y = rep(y_vec, each = length(x_vec))
)
mu_df <- tibble(
  x = mu_d[1], 
  y = mu_d[2]
)

# 観測モデルを計算
model_df <- tibble(
  xy = point_df, 
  density = mvtnorm::dmvnorm(x = xy, mean = mu_d, sigma = sigma_truth_dd^2), # 確率密度
) %>% 
  dplyr::select(density) %>% 
  cbind(point_df, .)


# 事前分布の期待値を用いた分布を計算
posterior_df <- tidyr::tibble(
  xy = point_df, 
  density = mvtnorm::dmvnorm(x = xy, mean = mu_d, sigma = solve(lambda_E_dd)), # 確率密度
  iteration = 0 # 試行回数
) %>% 
  dplyr::select(density, iteration) %>% 
  cbind(point_df, .)

# 予測分布のパラメータを計算
mu_s_d <- mu_d
lambda_s_dd <- (1 - D + nu) * W_dd
nu_s <- 1 - D + nu

# 初期値による予測分布を計算
predict_df <- cbind(
  point_df, 
  density = mvnfast::dmvt(
    X = as.matrix(point_df), mu = mu_s_d, sigma = solve(lambda_s_dd), df = nu_s
  ), # 確率密度
  iteration = 0 # 試行回数
)

# ベイズ推論
for(i in 1:N) {
  
  # 2次元ガウス分布に従うデータを生成
  x_nd <- mvtnorm::rmvnorm(n = 1, mean = mu_d, sigma = sigma_truth_dd^2)
  
  # 観測データを記録
  if(i > 1) { # 初回以外
    # オブジェクトを結合
    x_mat <- rbind(x_mat, x_nd)
    sample_df <- tibble(
      x = x_mat[, 1],
      y = x_mat[, 2], 
      iteration = i
    ) %>% 
      rbind(sample_df, .)
  } else if(i == 1){ # 初回
    # オブジェクトを作成
    x_mat <- x_nd
    sample_df <- tibble(
      x = x_mat[, 1],
      y = x_mat[, 2], 
      iteration = i
    )
  }
  
  # 事後分布のパラメータを更新
  W_dd <- solve(
    (t(x_nd) - mu_d) %*% t(t(x_nd) - mu_d) + solve(W_dd)
  )
  nu <- 1 + nu
  
  # 精度パラメータの期待値を計算
  lambda_E_dd <- nu * W_dd
  
  # 事後分布の期待値を用いた分布を計算
  tmp_posterior_df <- tidyr::tibble(
    xy = point_df, 
    density = mvtnorm::dmvnorm(x = xy, mean = mu_d, sigma = solve(lambda_E_dd)), # 確率密度
    iteration = i # 試行回数
  ) %>% 
    dplyr::select(density, iteration) %>% 
    cbind(point_df, .)
  
  # 予測分布のパラメータを更新
  mu_s_d <- mu_d
  lambda_s_dd <- (1 - D + nu) * W_dd
  nu_s <- 1 - D + nu
  
  # 予測分布を計算
  tmp_predict_df <- cbind(
    point_df, 
    density = mvnfast::dmvt(
      X = as.matrix(point_df), mu = mu_s_d, sigma = solve(lambda_s_dd), df = nu_s
    ), # 確率密度
    iteration = i # 試行回数
  )
  
  # 推論結果を結合
  posterior_df <- rbind(posterior_df, tmp_posterior_df)
  predict_df <- rbind(predict_df, tmp_predict_df)
  
  # 動作確認
  print(i)
}

# 事後分布の期待値を用いた分布を作図
posterior_graph <- ggplot() + 
  geom_contour(data = posterior_df, aes(x, y, z = density, color = ..level..)) +  # 精度の期待値を用いた分布
  geom_point(data = sample_df, aes(x = x, y = y)) + # 観測データ
  geom_contour(data = model_df, aes(x, y, z = density, color = ..level..), 
               alpha = 0.5, linetype = "dashed") + # 観測モデル
  geom_point(data = mu_df, aes(x = x, y = y), color = "red", shape = 3, size = 5) +  # 平均パラメータ
  transition_manual(iteration) +  # フレーム
  labs(title = "Multivariate Gaussian Distribution", 
       subtitle = "N={current_frame}", 
       x = expression(x[1]), y = expression(x[2]), 
       color = "density") # ラベル

# gif画像を作成
animate(posterior_graph, nframes = N + 1, fps = 5)


# 予測分布を作図
predict_graph <- ggplot() + 
  geom_contour(data = predict_df, aes(x, y, z = density, color = ..level..)) +  # 予測分布
  geom_point(data = sample_df, aes(x = x, y = y)) + # 観測データ
  geom_contour(data = model_df, aes(x, y, z = density, color = ..level..), 
               alpha = 0.5, linetype = "dashed") + # 観測モデル
  geom_point(data = mu_df, aes(x = x, y = y), color = "red", shape = 3, size = 5) +  # 平均パラメータ
  transition_manual(iteration) +  # フレーム
  labs(title = "Multivariate Student's t Distribution", 
       subtitle = "N={current_frame}", 
       x = expression(x[1]), y = expression(x[2]), 
       color = "density") # ラベル

# gif画像を作成
animate(predict_graph, nframes = N + 1, fps = 5)

 異なる点のみを簡単に解説します。

 各データによってどのように学習する(推定値が変化する)のかを確認するため、こちらはfor()で1データずつ処理します。よって観測データ数Nがイタレーション数になります。

 一度の処理で事後分布のパラメータを計算するのではなく、事前分布のパラメータに対して繰り返し観測データの情報を与えることで更新(上書き)していきます。$\sum_{n=1}^N$の計算は、N回繰り返し計算することで実行されます。

f:id:anemptyarchive:20201021091224g:plain
事後分布の推移

f:id:anemptyarchive:20201021091308g:plain
予測分布の推移


参考文献

  • 須山敦志『ベイズ推論による機械学習入門』(機械学習スタートアップシリーズ)杉山将監修,講談社,2017年.

おわりに

 やっぱり今はベイズ推論が楽しい!ところでスケール行列って何?

 mvnfastパッケージにも多次元ガウス分布関連の関数がありますが、書き換えるのが面倒だったので止めました。

【次節の内容】続く