からっぽのしょこ

読んだら書く!書いたら読む!読書読読書読書♪同じ事は二度調べ(たく)ない

3.4.3:多次元ガウス分布の学習と予測:平均・精度が未知の場合【緑ベイズ入門のノート】

はじめに

 『ベイズ推論による機械学習入門』の学習時のノートです。基本的な内容は「数式の行間を読んでみた」とそれを「Rで組んでみた」になります。「数式」と「プログラム」から理解するのが目標です。

 この記事は3.4.3項の内容です。尤度関数を多次元ガウス分布(多変量正規分布)、事前分布をガウス・ウィシャート分布とした場合の平均パラメータと精度パラメータの事後分布を導出し、また学習した事後分布を用いた予測分布を導出します。またその推論過程をR言語で実装します。

 省略してある内容等ありますので、本と併せて読んでください。初学者な自分が理解できるレベルまで落として書き下していますので、分かる人にはかなりくどくなっています。同じような立場の人のお役に立てれば幸いです。

【前節の内容】

www.anarchive-beta.com

【他の節一覧】

www.anarchive-beta.com

【この節の内容】

3.4.3 平均・精度が未知の場合

・観測モデルの設定

 多次元ガウス分布に従うと仮定する$N$個の観測データ$\mathbf{X} = \{\mathbf{x}_1, \mathbf{x}_2, \cdots, \mathbf{x}_N\}$を用いて、平均パラメータ$\boldsymbol{\mu}$と精度パラメータ$\boldsymbol{\Lambda}$の事後分布を求めていく。

 観測データ$\mathbf{X}$、未知の平均パラメータ$\boldsymbol{\mu}$、未知の共分散行列$\boldsymbol{\Sigma}$、精度行列(パラメータ)$\boldsymbol{\Lambda}$をそれぞれ

$$ \mathbf{X} = \begin{bmatrix} x_{11} & x_{12} & \cdots & x_{1D} \\ x_{21} & x_{22} & \cdots & x_{2D} \\ \vdots & \vdots & \ddots & \vdots \\ x_{N1} & x_{N2} & \cdots & x_{ND} \end{bmatrix} ,\ \boldsymbol{\mu} = \begin{bmatrix} \mu_1 \\ \mu_2 \\ \vdots \\ \mu_D \end{bmatrix} ,\ \boldsymbol{\Sigma} = \begin{bmatrix} \sigma_{11}^2 & \sigma_{12}^2 & \cdots & \sigma_{1D}^2 \\ \sigma_{21}^2 & \sigma_{22}^2 & \cdots & \sigma_{2D}^2 \\ \vdots & \vdots & \ddots & \vdots \\ \sigma_{D1}^2 & \sigma_{D2}^2 & \cdots & \sigma_{DD}^2 \end{bmatrix} ,\ \boldsymbol{\Lambda} = \boldsymbol{\Sigma}^{-1} $$

としたとき、観測モデルは次のようになる。

$$ p(\mathbf{x} | \boldsymbol{\mu}, \boldsymbol{\Lambda}) = \mathcal{N}(\mathbf{x} | \boldsymbol{\mu}, \boldsymbol{\Lambda}^{-1}) $$

 また$\boldsymbol{\mu},\ \boldsymbol{\Lambda}$の事前分布$p(\boldsymbol{\mu}, \boldsymbol{\Lambda})$を、ガウス・ウィシャート分布

$$ \begin{align} p(\boldsymbol{\mu}, \boldsymbol{\Lambda}) &= \mathrm{NW}(\boldsymbol{\mu}, \boldsymbol{\Lambda} | \mathbf{m}, \beta, \nu, \mathbf{W}) \\ &= \mathcal{N}(\boldsymbol{\mu} | \mathbf{m}, (\beta \boldsymbol{\Lambda})^{-1}) \mathcal{W}(\boldsymbol{\Lambda} | \nu, \mathbf{W}) \tag{3.125} \end{align} $$

とする。$\mathbf{m}$はガウス分布の平均、$\nu$はウィシャート分布の自由度であり$\nu > D - 1$の値をとる。

・事後分布の導出

 観測データ$\mathbf{X}$によって学習した$\boldsymbol{\mu},\ \boldsymbol{\Lambda}$の事後分布$p(\boldsymbol{\mu}, \boldsymbol{\Lambda} | \mathbf{X})$は、観測モデルに対してベイズの定理を用いて

$$ \begin{align} p(\boldsymbol{\mu}, \boldsymbol{\Lambda} | \mathbf{X}) &= \frac{ p(\mathbf{X} | \boldsymbol{\mu}, \boldsymbol{\Lambda}) p(\boldsymbol{\boldsymbol{\mu}, \Lambda}) }{ p(\mathbf{X}) } \tag{3.126}\\ &\propto p(\mathbf{X} | \boldsymbol{\mu}, \boldsymbol{\Lambda}) p(\boldsymbol{\mu}, \boldsymbol{\Lambda}) \\ &= \left( \prod_{n=1}^N p(\mathbf{x}_n | \boldsymbol{\mu}, \boldsymbol{\Lambda}) \right) p(\boldsymbol{\boldsymbol{\mu}, \Lambda}) \\ &= \left( \prod_{n=1}^N \mathcal{N}(\mathbf{x}_n | \boldsymbol{\mu}, \boldsymbol{\Lambda}^{-1}) \right) \mathcal{N}(\boldsymbol{\mu} | \mathbf{m}, (\beta \boldsymbol{\Lambda})^{-1}) \mathcal{W}(\boldsymbol{\Lambda} | \nu, \mathbf{W}) \end{align} $$

となる。分母の$p(\mathbf{X})$は$\boldsymbol{\Lambda}$に影響しないため省略して、比例関係にのみ注目する。省略した部分については、最後に正規化することで対応できる。
 また左辺の事後分布は

$$ p(\boldsymbol{\mu}, \boldsymbol{\Lambda} | \mathbf{X}) = p(\boldsymbol{\mu} | \boldsymbol{\Lambda}, \mathbf{X}) p(\boldsymbol{\Lambda} | \mathbf{X}) \tag{3.127} $$

と分解できる。

 次にこの分布の具体的な形状を明らかにしていく。式(3.127)を式(3.126)に代入して対数をとり指数部分の計算を分かりやすくして、$\boldsymbol{\mu}$に関して整理すると$\boldsymbol{\mu}$の事後分布は

$$ \begin{aligned} \ln p(\boldsymbol{\mu} | \boldsymbol{\Lambda}, \mathbf{X}) &= \sum_{n=1}^N \ln \mathcal{N}(\mathbf{x}_n | \boldsymbol{\mu}, \boldsymbol{\Lambda}^{-1}) + \ln \mathcal{N}(\boldsymbol{\mu} | \mathbf{m}, (\beta \boldsymbol{\Lambda})^{-1}) + \mathrm{const.} \end{aligned} $$

で求められる。$\ln \mathcal{W}(\boldsymbol{\Lambda} | \nu, \mathbf{W})$や(右辺に移項した)$\ln p(\boldsymbol{\Lambda} | \mathbf{X})$は、$\boldsymbol{\mu}$に影響しないため$\mathrm{const.}$に含めている。

 この式は、平均未知(3.4.1項)のときの精度パラメータ$\boldsymbol{\Lambda}_{\boldsymbol{\mu}}$を$\beta \boldsymbol{\Lambda}$に置き換えたものと言える。つまり3.4.1項と同じ手順で求められる。そこで$\boldsymbol{\mu}$の事後分布の精度パラメータを$\hat{\beta} \boldsymbol{\Lambda}$とおき、3.4.1項における事後分布の精度パラメータ$\hat{\boldsymbol{\Lambda}}_{\boldsymbol{\mu}}$の計算式(3.102)の$\boldsymbol{\Lambda}_{\boldsymbol{\mu}}$を$\beta \boldsymbol{\Lambda}$に置き換えると

$$ \begin{align} \hat{\beta} \boldsymbol{\Lambda} &= N \boldsymbol{\Lambda} + \beta \boldsymbol{\Lambda} \tag{3.102'}\\ \hat{\beta} &= N + \beta \tag{3.129.a} \end{align} $$

が得られる。事後分布の平均パラメータ$\hat{\mathbf{m}}$の計算式(3.103)についても同様に

$$ \begin{align} \hat{\mathbf{m}} &= (\hat{\beta} \boldsymbol{\Lambda})^{-1} \left( \boldsymbol{\Lambda} \sum_{n=1}^N \mathbf{x}_n + \beta \boldsymbol{\Lambda} \mathbf{m} \right) \tag{3.103'}\\ &= \frac{1}{\hat{\beta}} \boldsymbol{\Lambda}^{-1} \left( \boldsymbol{\Lambda} \sum_{n=1}^N \mathbf{x}_n + \beta \boldsymbol{\Lambda} \mathbf{m} \right) \\ &= \frac{1}{\hat{\beta}} \left( \sum_{n=1}^N \mathbf{x}_n + \beta \mathbf{m} \right) \tag{3.129.b} \end{align} $$

と求められる。

 従って$\boldsymbol{\mu}$の事後分布は

$$ p(\boldsymbol{\mu} | \mathbf{X}) = \mathcal{N}(\boldsymbol{\mu} | \hat{\mathbf{m}}, (\hat{\beta} \boldsymbol{\Lambda})^{-1}) $$

となる。

 $\boldsymbol{\Lambda}$の事後分布についても同様に、式(3.126)の対数をとり具体的な式に置き換え

$$ \begin{aligned} \ln p(\boldsymbol{\Lambda} | \mathbf{X}) &= \sum_{n=1}^N \ln \mathcal{N}(\mathbf{x}_n | \boldsymbol{\mu}, \boldsymbol{\Lambda}^{-1}) + \ln \mathcal{N}(\boldsymbol{\mu} | \mathbf{m}, (\beta \boldsymbol{\Lambda})^{-1}) - \ln \mathcal{N}(\boldsymbol{\mu} | \hat{\mathbf{m}}, (\hat{\beta} \boldsymbol{\Lambda})^{-1}) + \ln \mathcal{W}(\boldsymbol{\Lambda} | \nu, \mathbf{W}) + \mathrm{const.} \\ &= \sum_{n=1}^N - \frac{1}{2} \Bigl\{ (\mathbf{x}_n - \boldsymbol{\mu})^{\top} \boldsymbol{\Lambda} (\mathbf{x}_n - \boldsymbol{\mu}) + \ln |\boldsymbol{\Lambda}^{-1}| + D \ln 2 \pi \Bigr\} \\ &\qquad - \frac{1}{2} \Bigl\{ (\boldsymbol{\mu} - \mathbf{m})^{\top} \beta \boldsymbol{\Lambda} (\boldsymbol{\mu} - \mathbf{m}) + \ln |(\beta \boldsymbol{\Lambda})^{-1}| + D \ln 2 \pi \Bigr\} \\ &\qquad + \frac{1}{2} \Bigl\{ (\boldsymbol{\mu} - \hat{\mathbf{m}})^{\top} \hat{\beta} \boldsymbol{\Lambda} (\boldsymbol{\mu} - \hat{\mathbf{m}}) + \ln |(\hat{\beta} \boldsymbol{\Lambda})^{-1}| + D \ln 2 \pi \Bigr\} \\ &\qquad + \frac{\nu - D - 1}{2} \ln |\boldsymbol{\Lambda}| - \frac{1}{2} \mathrm{Tr}(\mathbf{W}^{-1} \Lambda) + \ln C_{\mathcal{W}}(\nu, \mathbf{W}) + \mathrm{const.} \end{aligned} $$

括弧を展開して$\boldsymbol{\Lambda}$に関して整理すると

$$ \begin{align} \ln p(\boldsymbol{\Lambda} | \mathbf{X}) &= - \frac{1}{2} \left\{ \sum_{n=1}^N \mathbf{x}_n^{\top} \boldsymbol{\Lambda} \mathbf{x}_n - 2 \sum_{n=1}^N \mathbf{x}_n^{\top} \boldsymbol{\Lambda} \boldsymbol{\mu} + N \boldsymbol{\mu}^{\top} \boldsymbol{\Lambda} \boldsymbol{\mu} + N \ln |\boldsymbol{\Lambda}|^{-1} \right\} \\ &\qquad - \frac{1}{2} \Bigl\{ \beta \boldsymbol{\mu}^{\top} \boldsymbol{\Lambda} \boldsymbol{\mu} - 2 \beta \mathbf{m}^{\top} \boldsymbol{\Lambda} \boldsymbol{\mu} + \beta \mathbf{m}^{\top} \boldsymbol{\Lambda} \mathbf{m} + \ln |\beta \boldsymbol{\Lambda}|^{-1} \Bigr\} \\ &\qquad + \frac{1}{2} \Bigl\{ \hat{\beta} \boldsymbol{\mu}^{\top} \boldsymbol{\Lambda} \boldsymbol{\mu} - 2 \hat{\beta} \hat{\mathbf{m}}^{\top} \boldsymbol{\Lambda} \boldsymbol{\mu} + \hat{\beta} \hat{\mathbf{m}}^{\top} \boldsymbol{\Lambda} \hat{\mathbf{m}} + \ln |\hat{\beta} \boldsymbol{\Lambda}|^{-1} \Bigr\} \\ &\qquad + \frac{\nu - D - 1}{2} \ln |\boldsymbol{\Lambda}| - \frac{1}{2} \mathrm{Tr}(\mathbf{W}^{-1} \Lambda) + \mathrm{const.} \\ &= - \frac{1}{2} \left\{ \sum_{n=1}^N \mathbf{x}_n^{\top} \boldsymbol{\Lambda} \mathbf{x}_n - 2 \sum_{n=1}^N \mathbf{x}_n^{\top} \boldsymbol{\Lambda} \boldsymbol{\mu} + N \boldsymbol{\mu}^{\top} \boldsymbol{\Lambda} \boldsymbol{\mu} - N \ln |\boldsymbol{\Lambda}| \right\} \\ &\qquad - \frac{1}{2} \Bigl\{ \beta \boldsymbol{\mu}^{\top} \boldsymbol{\Lambda} \boldsymbol{\mu} - 2 \beta \mathbf{m}^{\top} \boldsymbol{\Lambda} \boldsymbol{\mu} + \beta \mathbf{m}^{\top} \boldsymbol{\Lambda} \mathbf{m} - \ln \beta^D |\boldsymbol{\Lambda}| \Bigr\} \\ &\qquad + \frac{1}{2} \Bigl\{ (N + \beta) \boldsymbol{\mu}^{\top} \boldsymbol{\Lambda} \boldsymbol{\mu} - 2 \Bigl( \sum_{n=1}^N \mathbf{x}_n + \beta \mathbf{m} \Bigr)^{\top} \boldsymbol{\Lambda} \boldsymbol{\mu} + \hat{\beta} \hat{\mathbf{m}}^{\top} \boldsymbol{\Lambda} \hat{\mathbf{m}} - \ln \hat{\beta}^D |\boldsymbol{\Lambda}| \Bigr\} \\ &\qquad + \frac{\nu - D - 1}{2} \ln |\boldsymbol{\Lambda}| - \frac{1}{2} \mathrm{Tr}(\mathbf{W}^{-1} \Lambda) + \mathrm{const.} \\ &= \frac{N}{2} \ln |\boldsymbol{\Lambda}| - \frac{1}{2} \left\{ \sum_{n=1}^N \mathbf{x}_n^{\top} \boldsymbol{\Lambda} \mathbf{x}_n + \beta \mathbf{m}^{\top} \boldsymbol{\Lambda} \mathbf{m} - \hat{\beta} \hat{\mathbf{m}}^{\top} \boldsymbol{\Lambda} \hat{\mathbf{m}} - \ln \beta^D - \ln |\boldsymbol{\Lambda}| + \ln \hat{\beta}^D + \ln |\boldsymbol{\Lambda}| \right\} \\ &\qquad + \frac{\nu - D - 1}{2} \ln |\boldsymbol{\Lambda}| - \frac{1}{2} \mathrm{Tr}(\mathbf{W}^{-1} \Lambda) + \mathrm{const.} \\ &= \frac{N + \nu - D - 1}{2} \ln |\boldsymbol{\Lambda}| - \frac{1}{2} \mathrm{Tr} \left[ \left\{ \sum_{n=1}^N \mathbf{x}_n \mathbf{x}_n^{\top} + \beta \mathbf{m} \mathbf{m}^{\top} + \hat{\beta} \hat{\mathbf{m}} \hat{\mathbf{m}}^{\top} + \mathbf{W}^{^-1} \right\} \boldsymbol{\Lambda} \right] + \mathrm{const.} \tag{3.131} \end{align} $$

【途中式の途中式】(クリックで展開)

  1. 式を整理する。
    • 3行目の1つ目の項について、$\hat{\beta}$に式(3.129.a)を代入する。
    • 3行目の2つ目の項について、$\hat{\mathbf{m}}$に式(3.129.b)を代入する。
    • 行列式の性質より、$|\mathbf{A}^{-1}| = |\mathbf{A}|^{-1}$、$|c \mathbf{A}| = c^D |\mathbf{A}|$である。
  2. 対数の性質より、$\ln a^{-1} = - \ln a$、$\ln a b = \ln ab = \ln a + \ln b$である。
  3. $\mathrm{Tr}(\cdot)$に置き換えて式を整理する。

 精度行列を

$$ \boldsymbol{\Lambda} = \begin{bmatrix} \lambda_{11} & \cdots & \lambda_{1D} \\ \vdots & \ddots & \vdots \\ \lambda_{D1} & \cdots & \lambda_{DD} \end{bmatrix} $$

とすると、1行目の括弧内の1つ目の項は

$$ \begin{aligned} \sum_{n=1}^N \mathbf{x}_n^{\top} \boldsymbol{\Lambda} \mathbf{x}_n &= \sum_{n=1}^N \sum_{d=1}^D \sum_{d'=1}^D x_{nd} \lambda_{dd'} x_{nd'} \\ &= \mathrm{Tr} \left( \sum_{n=1}^N \begin{bmatrix} \sum_{d=1}^D x_{n1} x_{nd} \lambda_{d1} & \cdots & \sum_{d=1}^D x_{n1} x_{nd} \lambda_{dD} \\ \vdots & \ddots & \vdots \\ \sum_{d=1}^D x_{nD} x_{nd} \lambda_{d1} & \cdots & \sum_{d=1}^D x_{nD} x_{nd} \lambda_{dD} \end{bmatrix} \right) \\ &= \mathrm{Tr} \left( \sum_{n=1}^N \begin{bmatrix} x_{n1} x_{n1} & \cdots & x_{n1} x_{nD} \\ \vdots & \ddots & \vdots \\ x_{nD} x_{n1} & \cdots & x_{nD} x_{nD} \end{bmatrix} \begin{bmatrix} \lambda_{11} & \cdots & \lambda_{1D} \\ \vdots & \ddots & \vdots \\ \lambda_{D1} & \cdots & \lambda_{DD} \end{bmatrix} \right) \\ &= \mathrm{Tr} \left( \sum_{n=1}^N \begin{bmatrix} x_{n1} \\ \vdots \\ x_{nD} \end{bmatrix} \begin{bmatrix} x_{n1} & \cdots & x_{nD} \end{bmatrix} \begin{bmatrix} \lambda_{11} & \cdots & \lambda_{1D} \\ \vdots & \ddots & \vdots \\ \lambda_{D1} & \cdots & \lambda_{DD} \end{bmatrix} \right) \\ &= \mathrm{Tr} \left( \sum_{n=1}^N \mathbf{x}_n \mathbf{x}_n^{\top} \boldsymbol{\Lambda} \right) \end{aligned} $$

となる。同様に2つ目の項は

$$ \begin{aligned} \beta \mathbf{m}^{\top} \boldsymbol{\Lambda} \mathbf{m} &= \beta \sum_{d=1}^D \sum_{d'=1}^D m_d \lambda_{dd'} m_{d'} \\ &= \mathrm{Tr} \left( \beta \begin{bmatrix} \sum_{d=1}^D m_1 m_d \lambda_{d1} & \cdots & \sum_{d=1}^D m_1 m_d \lambda_{dD} \\ \vdots & \ddots & \vdots \\ \sum_{d=1}^D m_D m_d \lambda_{d1} & \cdots & \sum_{d=1}^D m_D m_d \lambda_{dD} \end{bmatrix} \right) \\ &= \mathrm{Tr}( \beta \mathbf{m} \mathbf{m}^{\top} \boldsymbol{\Lambda} ) \end{aligned} $$

となり、3つ目の項も$\hat{\beta} \hat{\mathbf{m}}^{\top} \boldsymbol{\Lambda} \hat{\mathbf{m}} = \mathrm{Tr}(\hat{\beta} \hat{\mathbf{m}} \hat{\mathbf{m}}^{\top} \boldsymbol{\Lambda})$となる。よって式(A.12)より

$$ - \frac{1}{2} \left( \sum_{n=1}^N \mathbf{x}_n^{\top} \boldsymbol{\Lambda} \mathbf{x}_n + \beta \mathbf{m}^{\top} \boldsymbol{\Lambda} \mathbf{m} - \hat{\beta} \hat{\mathbf{m}}^{\top} \boldsymbol{\Lambda} \hat{\mathbf{m}} \right) - \frac{1}{2} \mathrm{Tr}(\mathbf{W}^{-1} \Lambda) = - \frac{1}{2} \mathrm{Tr} \left( \sum_{n=1}^N \mathbf{x}_n \mathbf{x}_n^{\top} \boldsymbol{\Lambda} + \beta \mathbf{m} \mathbf{m}^{\top} \boldsymbol{\Lambda} + \hat{\beta} \hat{\mathbf{m}} \hat{\mathbf{m}}^{\top} \boldsymbol{\Lambda} + \mathbf{W}^{^-1} \boldsymbol{\Lambda} \right) $$

で置き換えられる。



となる。適宜$\boldsymbol{\Lambda}$に影響しない項を$\mathrm{const.}$にまとめている。

 式の形から$\boldsymbol{\Lambda}$の事後分布はウィシャート分布となることが分かる。そこで事後分布を次のようにおき

$$ p(\boldsymbol{\Lambda} | \mathbf{X}) = \mathcal{W}(\boldsymbol{\Lambda} | \hat{\nu}, \hat{\mathbf{W}}) \tag{3.132} $$

この式の対数をとり、$\boldsymbol{\Lambda}$に関して整理すると

$$ \ln p(\boldsymbol{\Lambda} | \mathbf{X}) = \frac{\hat{\nu} - D - 1}{2} \ln |\boldsymbol{\Lambda}| - \frac{1}{2} \mathrm{Tr}(\hat{\mathbf{W}}^{-1} \boldsymbol{\Lambda}) + \mathrm{const.} $$

となる。

 従って式(3.131)との対応関係から、事後分布のパラメータは

$$ \begin{align} \hat{\mathbf{W}}^{-1} &= \sum_{n=1}^N \mathbf{x}_n \mathbf{x}_n^{\top} + \beta \mathbf{m} \mathbf{m}^{\top} - \hat{\beta} \hat{\mathbf{m}} \hat{\mathbf{m}}^{\top} + \mathbf{W}^{-1} \\ \hat{\nu} &= N + \nu \tag{3.133} \end{align} $$

と求められる。

・予測分布の導出

 続いて未観測のデータ$\mathbf{x}_{*}$に対する予測分布$p(\mathbf{x}_{*})$を求めていく。

 1次元ガウス分布のときと同様に、パラメータ$\boldsymbol{\mu},\ \boldsymbol{\Lambda}$を周辺化(積分計算)して予測分布

$$ p(\mathbf{x}_{*}) = \int \int p(\mathbf{x}_{*} | \boldsymbol{\mu}, \boldsymbol{\Lambda}) p(\boldsymbol{\mu}, \boldsymbol{\Lambda}) d\boldsymbol{\mu} d\boldsymbol{\Lambda} \tag{3.134} $$

を求めることは避け、ベイズの定理を用いて

$$ p(\boldsymbol{\mu}, \boldsymbol{\Lambda} | \mathbf{x}_{*}) = \frac{ p(\mathbf{x}_{*} | \boldsymbol{\mu}, \boldsymbol{\Lambda}) p(\boldsymbol{\mu}, \boldsymbol{\Lambda}) }{ p(\mathbf{x}_{*}) } $$

予測分布を求める。

 この式の両辺の対数をとり、$p(\mathbf{x}_{*})$に関して式を整理すると

$$ \begin{align} \ln p(\boldsymbol{\mu}, \boldsymbol{\Lambda} | \mathbf{x}_{*}) &= \ln p(\mathbf{x}_{*} | \boldsymbol{\mu}, \boldsymbol{\Lambda}) + \ln p(\boldsymbol{\mu}, \boldsymbol{\Lambda}) - \ln p(\mathbf{x}_{*}) \\ \ln p(\mathbf{x}_{*}) &= \ln p(\mathbf{x}_{*} | \boldsymbol{\mu}, \boldsymbol{\Lambda}) - \ln p(\boldsymbol{\mu}, \boldsymbol{\Lambda} | \mathbf{x}_{*}) + \mathrm{const.} \tag{3.135} \end{align} $$

で計算できることが分かる。

 2つ目の項は、データが1つ($N = 1$)の事後分布(3.132)と捉えられることから、式(3.133)よりパラメータを

$$ \begin{align} \mathbf{m}_{*} &= \frac{1}{1 + \beta} (\mathbf{x}_{*} + \beta \mathbf{m}) \\ \mathbf{W}_{*}^{-1} &= \mathbf{x}_{*} \mathbf{x}_{*}^{\top} + \beta \mathbf{m} \mathbf{m}^{\top} - \hat{\beta} \hat{\mathbf{m}} \hat{\mathbf{m}}^{\top} + \mathbf{W}^{-1} \\ &= \mathbf{x}_{*} \mathbf{x}_{*}^{\top} + \beta \mathbf{m} \mathbf{m}^{\top} - \frac{1}{1 + \beta} (\mathbf{x}_{*} + \beta \mathbf{m}) (\mathbf{x}_{*} + \beta \mathbf{m})^{\top} + \mathbf{W}^{-1} \\ &= \mathbf{x}_{*} \mathbf{x}_{*}^{\top} + \beta \mathbf{m} \mathbf{m}^{\top} - \frac{1}{1 + \beta} \mathbf{x}_{*} \mathbf{x}_{*}^{\top} - \frac{2 \beta}{1 + \beta} \mathbf{x}_{*} \mathbf{m}^{\top} - \frac{\beta^2}{1 + \beta} \mathbf{m} \mathbf{m}^{\top} + \mathbf{W}^{-1} \\ &= \frac{\beta}{1 + \beta} \mathbf{x}_{*} \mathbf{x}_{*}^{\top} - \frac{2 \beta}{1 + \beta} \mathbf{x}_{*} \mathbf{m}^{\top} + \frac{\beta}{1 + \beta} \mathbf{m} \mathbf{m}^{\top} + \mathbf{W}^{-1} \\ &= \frac{\beta}{1 + \beta} (\mathbf{x}_{*} - \mathbf{m}) (\mathbf{x}_{*} - \mathbf{m})^{\top} + \mathbf{W}^{-1} \tag{3.137} \end{align} $$

とおくと

$$ p(\boldsymbol{\mu}, \boldsymbol{\Lambda} | \mathbf{m}_{*}) = \mathcal{N}(\boldsymbol{\mu} | \mathbf{m}_{*}, \{(1 + \beta) \boldsymbol{\Lambda}\}^{-1}) \mathcal{W}(\boldsymbol{\lambda} | 1 + \nu, \mathbf{W}_{*}) \tag{1.136} $$

となる。分かりにくければ更に$\beta_{*} = 1 + \beta$、$\nu_{*} = 1 + \nu$とおいてもよい。

 この式を式(3.135)に代入して、対数をとり具体的な式に置き換え

$$ \begin{aligned} \ln p(\mathbf{x}_{*}) &= \ln \mathcal{N}(\mathbf{x}_{*} | \boldsymbol{\mu}, \boldsymbol{\Lambda}) - \ln \mathcal{N}(\boldsymbol{\mu} | \mathbf{m}_{*}, \{(1 + \beta) \boldsymbol{\Lambda}\}^{-1}) - \ln \mathcal{W}(\boldsymbol{\Lambda} | 1 + \nu, \mathbf{W}_{*}) + \mathrm{const.} \\ &= - \frac{1}{2} \Bigl\{ (\mathbf{x}_{*} - \boldsymbol{\mu})^{\top} \boldsymbol{\Lambda} (\mathbf{x}_{*} - \boldsymbol{\mu}) + \ln |\boldsymbol{\Lambda}^{-1}| + D \ln 2 \pi \Bigr\} \\ &\qquad + \frac{1}{2} \Bigl\{ (\boldsymbol{\mu} - \mathbf{m}_{*})^{\top} (1 + \beta) \boldsymbol{\Lambda} (\boldsymbol{\mu} - \mathbf{m}_{*}) + \ln |\{(1 + \beta) \boldsymbol{\Lambda}\}^{-1}| + D \ln 2 \pi \Bigr\} \\ &\qquad - \frac{1 + \nu - D - 1}{2} \ln |\boldsymbol{\Lambda}| + \frac{1}{2} \mathrm{Tr}(\mathbf{W}_{*}^{-1} \boldsymbol{\Lambda}) \\ &\qquad + \frac{1 + \nu}{2} \ln |\mathbf{W}_{*}| + \frac{(1 + \nu) D}{2} \ln 2 + \frac{D (D - 1)}{4} \ln \pi + \sum_{d=1}^D \ln \Gamma \Bigl(\frac{1 + \nu + 1 - d}{2} \Bigr) + \mathrm{const.} \end{aligned} $$

括弧を展開して$\mathbf{x}_{*}$に関して式を整理すると

$$ \begin{align} \ln p(\mathbf{x}_{*}) &= - \frac{1}{2} \Bigl\{ \mathbf{x}_{*}^{\top} \boldsymbol{\Lambda} \mathbf{x}_{*} - 2 \mathbf{x}_{*}^{\top} \boldsymbol{\Lambda} \boldsymbol{\mu} + \boldsymbol{\mu}^{\top} \boldsymbol{\Lambda} \boldsymbol{\mu} \Bigr\} \\ &\qquad + \frac{1}{2} \Bigl\{ (1 + \beta) \boldsymbol{\mu}^{\top} \boldsymbol{\Lambda} \boldsymbol{\mu} - 2 (1 + \beta) \boldsymbol{\mu}^{\top} \boldsymbol{\Lambda} \mathbf{m}_{*} + (1 + \beta) \mathbf{m}_{*}^{\top} \boldsymbol{\Lambda} \mathbf{m}_{*} \Bigr\} \\ &\qquad + \frac{1}{2} \mathrm{Tr} \left( \frac{\beta}{1 + \beta} (\mathbf{x}_{*} - \mathbf{m}) (\mathbf{x}_{*} - \mathbf{m})^{\top} \boldsymbol{\Lambda} + \mathbf{W}^{-1} \boldsymbol{\Lambda} \right) \\ &\qquad + \frac{1 + \nu}{2} \ln \Bigl| \mathbf{W}^{-1} + \frac{\beta}{1 + \beta} (\mathbf{x}_{*} - \mathbf{m}) (\mathbf{x}_{*} - \mathbf{m})^{\top} \Bigr|^{-1} + \mathrm{const.} \\ &= - \frac{1}{2} \Bigl\{ \mathbf{x}_{*}^{\top} \boldsymbol{\Lambda} \mathbf{x}_{*} - 2 \mathbf{x}_{*}^{\top} \boldsymbol{\Lambda} \boldsymbol{\mu} \Bigr. \\ &\qquad \Bigl. + 2 \boldsymbol{\mu}^{\top} \boldsymbol{\Lambda} (\mathbf{x}_{*} + \beta \mathbf{m}) - \frac{1}{1 + \beta} (\mathbf{x}_{*} + \beta \mathbf{m})^{\top} \boldsymbol{\Lambda} (\mathbf{x}_{*} + \beta \mathbf{m}) \Bigr\} \\ &\qquad + \frac{1}{2} \frac{\beta}{1 + \beta} (\mathbf{x}_{*} - \mathbf{m})^{\top} \boldsymbol{\Lambda} (\mathbf{x}_{*} - \mathbf{m}) + \frac{1}{2} \mathrm{Tr}(\mathbf{W}^{-1} \boldsymbol{\Lambda}) \\ &\qquad - \frac{1 + \nu}{2} \ln |\mathbf{W}^{-1}| \left| \mathbf{I}_D + \frac{\beta}{1 + \beta} \mathbf{W} (\mathbf{x}_{*} - \mathbf{m}) (\mathbf{x}_{*} - \mathbf{m})^{\top} \right| + \mathrm{const.} \\ &= - \frac{1}{2} \Bigl\{ \mathbf{x}_{*}^{\top} \boldsymbol{\Lambda} \mathbf{x}_{*} - \frac{1}{1 + \beta} \mathbf{x}_{*}^{\top} \boldsymbol{\Lambda} \mathbf{x}_{*} - \frac{2 \beta}{1 + \beta} \mathbf{x}_{*}^{\top} \boldsymbol{\Lambda} \mathbf{m} - \frac{\beta^2}{1 + \beta} \mathbf{m}^{\top} \boldsymbol{\Lambda} \mathbf{m} \Bigr\} \\ &\qquad + \frac{1}{2} \left\{ \frac{\beta}{1 + \beta} \mathbf{x}_{*}^{\top} \boldsymbol{\Lambda} \mathbf{x}_{*} - \frac{2 \beta}{1 + \beta} \mathbf{x}_{*}^{\top} \boldsymbol{\Lambda} \mathbf{m} + \frac{\beta}{1 + \beta} \mathbf{m}^{\top} \boldsymbol{\Lambda} \mathbf{m} \right\} \\ &\qquad - \frac{1 + \nu}{2} \ln \left| \mathbf{I}_1 + \frac{\beta}{1 + \beta} (\mathbf{x}_{*} - \mathbf{m})^{\top} \mathbf{W}^{\top} (\mathbf{x}_{*} - \mathbf{m}) \right| - \frac{1 + \nu}{2} \ln |\mathbf{W}^{-1}| + \mathrm{const.} \\ &= - \frac{1 + \nu}{2} \ln \left\{ 1 + \frac{\beta}{1 + \beta} (\mathbf{x}_{*} - \mathbf{m})^{\top} \mathbf{W} (\mathbf{x}_{*} - \mathbf{m}) \right\} + \mathrm{const.} \tag{3.138} \end{align} $$

【途中式の途中式】

  1. 式(A.18)より$|\mathbf{W}_{*}| = |\mathbf{W}_{*}^{-1}|^{-1}$として、$\mathbf{W}_{*}$に式(3.137)を代入する。
  2. 式を整理する。
    • $\mathbf{m}_{*}$に式(3.137)を代入する。
    • 3行目について式(A.12)より分解し、また3.4.2項で確認した$\mathrm{Tr}(\mathbf{a} \mathbf{a}^{\top} \mathbf{B}) = \mathbf{a}^{\top} \mathbf{B} \mathbf{a}$より変形する。
    • 4行目について式(A.17)より分解する。
  3. $\mathbf{W} (\mathbf{x}_{*} - \mathbf{m})$を1つの行列とみると、$D \times 1$、$1 \times D$の行列の積なので、式(A.19)の変形を行う。
  4. 式を整理する。
    • 行列式の定義より、スカラの行列式はスカラ$|a| = a$である。
    • $(\mathbf{x}_{*} - \boldsymbol{\mu})^{\top} \mathbf{W}^{\top} (\mathbf{x}_{*} - \boldsymbol{\mu})$はスカラなので転置しても影響を受けないため、$(\mathbf{x}_{*} - \boldsymbol{\mu})^{\top} \mathbf{W}^{\top} (\mathbf{x}_{*} - \boldsymbol{\mu}) = {(\mathbf{x}_{*} - \boldsymbol{\mu})^{\top} \mathbf{W}^{\top} (\mathbf{x}_{*} - \boldsymbol{\mu})}^{\top}$である。また式(A.2)と同様に、$(A B C)^{\top} = C^{\top} B^{\top} A^{\top}$である。


となる。適宜$\mathbf{x}_{*}$に影響しない項を$\mathrm{const.}$にまとめている。

 式の形から予測分布は多次元のスチューデントのt分布となることが分かる。そこで次のようにおき

$$ p(\mathbf{x}_{*}) = \mathrm{St}(\mathbf{x}_{*} | \boldsymbol{\mu}_s, \boldsymbol{\Lambda}_s, \nu_s) \tag{3.193} $$

この式の対数をとり、$\mathbf{x}_{*}$に関して整理すると

$$ \ln \mathrm{St}(\mathbf{x} | \boldsymbol{\mu}_s, \boldsymbol{\Lambda}_s, \nu_s) = - \frac{\nu_s + D}{2} \ln \left\{ 1 + \frac{1}{\nu_s} (\mathbf{x}_{*} - \boldsymbol{\mu}_s)^{\top} \boldsymbol{\Lambda}_s (\mathbf{x}_{*} - \boldsymbol{\mu}_s) \right\} + \mathrm{const.} \tag{3.122} $$

となる。

 従って式(3.138)との対応関係から、予測分布のパラメータは

$$ \begin{align} \boldsymbol{\mu}_s &= \mathbf{m} \\ \boldsymbol{\Lambda}_s &= \frac{ \nu_s \beta }{ 1 + \beta } \mathbf{W} \\ &= \frac{ (1 - D + \nu) \beta }{ 1 + \beta } \mathbf{W} \\ \nu_s &= 1 - D + \nu \tag{3.140} \end{align} $$

と求まる。

 事前分布のパラメータ$\beta,\ \mathbf{m},\ \mathbf{W},\ \nu$を事後分布のパラメータ$\hat{\beta},\ \hat{\mathbf{m}},\ \hat{\mathbf{W}},\ \hat{\nu}$に置き換えると、観測データによって学習を行った予測分布$p(\mathbf{x}_{*} | \mathbf{X})$のパラメータ

$$ \begin{aligned} \hat{\boldsymbol{\mu}}_s &= \hat{\mathbf{m}} \\ &= \frac{1}{N + \beta} \left( \sum_{n=1}^N \mathbf{x}_n + \beta \mathbf{m} \right) \\ \hat{\boldsymbol{\Lambda}}_s &= \frac{ (1 - D + \hat{\nu}) \hat{\beta} }{ 1 + \hat{\beta} } \hat{\mathbf{W}} \\ &= \frac{ (1 - D + \hat{\nu}) (N + \beta) }{ N + 1 + \beta } \hat{\mathbf{W}} \\ \hat{\nu}_s &= 1 - D + \hat{\nu} \\ &= N + 1 - D + \nu \end{aligned} $$

が得られる。

・Rでやってみよう

 多次元ガウス分布に従いランダムに生成したデータを用いて、パラメータを推定してみましょう。

 利用するパッケージを読み込みます。

# 利用パッケージ
library(tidyverse)
library(mvtnorm)
library(mvnfast)

 mvtnormパッケージは、多次元ガウス分布に関するパッケージです。多次元ガウス分布に従う乱数生成関数rmvnorm()を使います。確率密度の計算には、mvnfastパッケージを利用します。多次元ガウス分布の確率密度はdmvn()、多次元スチューデントのt分布の確率密度関数はdmvt()です。

・観測モデルの設定

 観測モデルのパラメータを指定します。この例では2次元のグラフで表現するため、$D = 2$のときのみ動作します。

# データ数を指定
N <- 100
D <- 2 # (固定)

# 観測モデルのパラメータを指定
mu_truth_d <- c(25, 50)
sigma_truth_dd <- matrix(c(50, 30, 30, 50), nrow = 2, ncol = 2)
lambda_truth_dd <- solve(sigma_truth_dd^2)

 観測モデルの平均パラメータ$\boldsymbol{\mu}$をmu_truth_dとします。
 分散共分散行列$\boldsymbol{\Sigma}$ではなく、分散共分散の平方根$(\sigma_{11}, \cdots, \sigma_{DD})$をsigma_truth_ddとします。これは作図時に標準偏差を利用するためです。matrix()のデフォルトの仕様上、$(\sigma_{11}, \sigma_{12}, \sigma_{21}, \sigma_{22})$の順番に値を指定します。
 sigma_truth_ddの各要素を2乗したものが$\boldsymbol{\Sigma}$です。更にその逆行列が精度パラメータ$\boldsymbol{\Lambda}$でした。逆行列の計算はsolve()で行い、その計算結果をlambda_truth_ddとします。
 この平均と精度行列の値を推論するのがこの項の目的になります。

 次に事前分布のパラメータを指定します。

# 事前分布のパラメータを指定
m_d <- c(0, 0)
beta <- 1
W_dd <- matrix(c(0.00005, 0, 0, 0.00005), nrow = 2, ncol = 2)
nu <- D

 $\boldsymbol{\mu}$の事前分布(多次元ガウス分布)のパラメータ$\mathbf{m},\ \beta$をそれぞれm_dbetaとします。
 $\boldsymbol{\Lambda}$の事前分布(ウィシャート事前分布)のパラメータ$\mathbf{W},\ \nu$をそれぞれW_ddnuとします。$\nu > D - 1$である必要があります。

 作図時に利用する格子状の点(データ)と真の平均の点のデータフレームを作成しておきます。

# 作図用の点を生成
x_vec <- seq(mu_truth_d[1] - 2 * sigma_truth_dd[1, 1], mu_truth_d[1] + 2 * sigma_truth_dd[1, 1], by = 0.5)
y_vec <- seq(mu_truth_d[2] - 2 * sigma_truth_dd[2, 2], mu_truth_d[2] + 2 * sigma_truth_dd[2, 2], by = 0.5)
point_df <- tibble(
  x = rep(x_vec, times = length(y_vec)), 
  y = rep(y_vec, each = length(x_vec))
)
mu_df <- tibble(
  x = mu_truth_d[1], 
  y = mu_truth_d[2]
)

 この例では、平均から標準偏差の2倍の範囲をグラフ化することにします。

 設定した観測モデルのパラメータに従って乱数を生成します。

# 2次元ガウス分布に従うデータを生成
x_nd <- mvtnorm::rmvnorm(n = N, mean = mu_truth_d, sigma = sigma_truth_dd^2)
summary(x_nd)
##        V1                V2        
##  Min.   :-71.964   Min.   :-46.09  
##  1st Qu.: -6.562   1st Qu.: 23.21  
##  Median : 28.341   Median : 54.85  
##  Mean   : 29.014   Mean   : 55.10  
##  3rd Qu.: 63.876   3rd Qu.: 86.32  
##  Max.   :130.706   Max.   :177.49

 mvtnorm::rmvnorm()で多次元ガウス分布に従う乱数を生成し、観測データ$\mathbf{X}$(x_nd)とします。
 引数nにはデータ数N、引数meanには平均パラメータmu_truth_d、引数sigmaにはsimga_truth_ddの2乗を指定します。精度行列の逆行列solve(lambda_truth_dd)を引数sigmaに指定することもできます。

 観測データを散布図で確認しましょう。

# 観測データのデータフレーム
sample_df <- tibble(
  x = x_nd[, 1], 
  y = x_nd[, 2]
)

# 観測モデルのデータフレーム
model_df <- cbind(
  point_df, 
  density = mvnfast::dmvn(
    X = as.matrix(point_df), mu = mu_truth_d, sigma = sigma_truth_dd^2
  ) # 確率密度
)

# 観測データの散布図を作成
ggplot() + 
  geom_point(data = sample_df, aes(x = x, y = y)) + # 観測データ
  geom_contour(data = model_df, aes(x, y, z = density, color = ..level..)) + # 観測モデル
  geom_point(data = mu_df, aes(x = x, y = y), color = "red", shape = 3, size = 5) + # 平均パラメータ
  labs(title = "Multivariate Gaussian Distribution", 
       subtitle = paste0("N=", N, ", mu=(", paste(round(mu_truth_d, 1), collapse = ", "), ")", 
                         ", sigma=(", paste(round(sigma_truth_dd, 1), collapse = ", "), ")"), 
       x = expression(x[1]), y = expression(x[2]), 
       color = "density") # ラベル

f:id:anemptyarchive:20201031223456p:plain
観測データの散布図と観測モデル:多次元ガウス分布

 観測モデルの分布と重ねて描画します。
 多次元ガウス分布の確率密度は、mvnfast::dmvn()で計算します。引数は、データの値X、平均mu、分散共分散行列sigmaです。第1引数Xには、マトリクス形式の複数データ(x_nd)を渡すことができます。
 等高線グラフgeom_contour()には、格子状の点を渡す必要があります。

 では事後分布を求めていきます。

・事後分布

 観測データを使って、事後分布のパラメータを計算します。

# 事後分布のパラメータを計算
beta_hat <- N + beta
m_hat_d <- (colSums(x_nd) + beta * m_d) / beta_hat

 $\boldsymbol{\mu}$の事後分布のパラメータ$\hat{\beta},\ \hat{\mathbf{m}}$をそれぞれbeta_hatm_hat_dとします。計算式(更新式)は次の式です。

$$ \begin{align} \hat{\beta} &= N + \beta \\ \hat{\mathbf{m}} &= \frac{1}{\hat{\beta}} \left( \sum_{n=1}^N \mathbf{x}_n + \beta \mathbf{m} \right) \tag{3.129} \end{align} $$

 続いて$\boldsymbol{\Lambda}$の事後分布のパラメータを計算します。

# 事後分布のパラメータを計算
tmp_x <- t(x_nd) %*% as.matrix(x_nd)
tmp_m <- beta * as.matrix(m_d) %*% t(m_d)
tmp_m_hat <- beta_hat * as.matrix(m_hat_d) %*% t(m_hat_d)
W_hat_dd <- solve(
  tmp_x + tmp_m - tmp_m_hat + solve(W_dd)
)
nu_hat <- N + nu

 パラメータ$\hat{\mathbf{W}},\ \hat{\nu}$をそれぞれW_hat_ddnu_hatとします。計算式(更新式)は次の式です。

$$ \begin{align} \hat{\mathbf{W}}^{-1} &= \sum_{n=1}^N \mathbf{x}_n \mathbf{x}_n^{\top} + \beta \mathbf{m} \mathbf{m}^{\top} - \hat{\beta} \hat{\mathbf{m}} \hat{\mathbf{m}}^{\top} + \mathbf{W}^{-1} \\ \hat{\nu} &= N + \nu \tag{3.133} \end{align} $$

 ただし上の実装例では、$\sum_{n=1}^N \mathbf{x}_n \mathbf{x}_n^{\top}$の計算を効率よく処理するために転置して計算しています。この式の通りに計算するには、次のような処理になります。

sum_dot_x <- matrix(0, D, D)
for(n in 1:N) {
  dot_x <- matrix(x_nd[n, ] - mu_d) %*% t(x_nd[n, ] - mu_d)
  sum_dot_x <- sum_dot_x + dot_x
}
W_hat_dd <- solve(sum_dot_x + solve(W_dd))

 for()によってN回加算する処理が$\sum_{n=1}^N$の計算に対応します。

 このパラメータを用いて$\boldsymbol{\Lambda}$の期待値を計算します。

# 精度パラメータの期待値を計算
lambda_E_dd <- nu_hat * W_hat_dd

 ウィシャート分布の期待値(2.89)より

$$ \mathbb{E}[\hat{\boldsymbol{\Lambda}}] = \hat{\nu} \hat{\mathbf{W}} $$

で計算できます。

 求めたパラメータを使って、事後分布(の確率密度)を計算します。

# 事後分布を計算
posterior_df <- cbind(
  point_df, 
  density = mvnfast::dmvn(
    X = as.matrix(point_df), mu = m_hat_d, sigma = solve(lambda_E_dd)
  ) # 確率密度
)
head(posterior_df)
##       x   y      density
## 1 -75.0 -50 1.925845e-06
## 2 -74.5 -50 1.958757e-06
## 3 -74.0 -50 1.992010e-06
## 4 -73.5 -50 2.025603e-06
## 5 -73.0 -50 2.059533e-06
## 6 -72.5 -50 2.093800e-06

 作図用にデータフレームとしておきます。

 多次元ガウス分布を作図します。

# 作図
ggplot() + 
  geom_contour(data = posterior_df, aes(x, y, z = density, color = ..level..)) + # 事後分布
  geom_contour(data = model_df, aes(x, y, z = density, color = ..level..), 
               alpha = 0.5, linetype = "dashed") + # 観測モデル
  geom_point(data = mu_df, aes(x = x, y = y), color = "red", shape = 3, size = 5) + # 平均値
  labs(title = "Multivariate Gaussian Distribution", 
       subtitle = paste0("N=", N, ", mu_hat=(", paste(round(m_hat_d, 1), collapse = ", "), ")", 
                         ", E_sigma_hat=(", paste(round(sqrt(solve(lambda_E_dd)), 1), collapse = ", "), ")"), 
       x = expression(x[1]), y = expression(x[2]), 
       color = "density") # ラベル

f:id:anemptyarchive:20201031223549p:plain
事後分布:多次元ガウス分布

 観測データの散布図のときと同様に、観測モデル(真の分布)と重ねて描画します。

 続いて予測分布を求めていきます。

・予測分布

 事後分布のパラメータを使って、予測分布のパラメータを計算します。

# 予測分布のパラメータを計算
mu_s_hat_d <- m_hat_d
lambda_s_hat_dd <- (1 - D + nu_hat) * beta_hat / (1 + beta_hat) * W_hat_dd
nu_s_hat <- 1 - D + nu_hat

 予測分布(多次元スチューデントのt分布)のパラメータ$\hat{\boldsymbol{\mu}}_s,\ \hat{\boldsymbol{\Lambda}}_s,\ \hat{\nu}_s$をそれぞれmu_s_dlambda_s_hat_ddnu_s_hatとします。計算式(更新式)は次の式です。

$$ \begin{aligned} \hat{\boldsymbol{\mu}}_s &= \hat{\mathbf{m}} \\ \hat{\boldsymbol{\Lambda}}_s &= \frac{ (1 - D + \hat{\nu}) \hat{\beta} }{ 1 + \hat{\beta} } \hat{\mathbf{W}} \\ \hat{\nu}_s &= 1 - D + \hat{\nu} \end{aligned} $$

 求めたパラメータを使って、予測分布(の確率密度)を計算します。

# 予測分布を計算
predict_df <- cbind(
  point_df, 
  density = mvnfast::dmvt(
    X = as.matrix(point_df), mu = mu_s_hat_d, sigma = solve(lambda_s_hat_dd), df = nu_s_hat
  ) # 確率密度
)
head(predict_df)
##       x   y      density
## 1 -75.0 -50 2.125815e-06
## 2 -74.5 -50 2.159774e-06
## 3 -74.0 -50 2.194058e-06
## 4 -73.5 -50 2.228666e-06
## 5 -73.0 -50 2.263595e-06
## 6 -72.5 -50 2.298843e-06

 多次元スチューデントのt分布の確率密度は、mvnfast::dmvt()で計算できます。dmvt()も複数のデータを一度に処理できます。また平均引数muにはmu_s_d、スケール引数sigmaにはlambda_s_hat_ddの逆行列、自由度引数dfにはnu_s_hatを指定します。

 予測分布を作図します。

# 作図
ggplot() + 
  geom_contour(data = predict_df, aes(x, y, z = density, color = ..level..)) + # 予測分布
  geom_contour(data = model_df, aes(x, y, z = density, color = ..level..), 
               alpha = 0.5, linetype = "dashed") + # 観測モデル
  geom_point(data = mu_df, aes(x = x, y = y), color = "red", shape = 3, size = 5) + # 平均パラメータ
  labs(title = "Multivariate Student's t Distribution", 
       subtitle = paste0("N=", N, ", mu_s_hat=(", paste(round(mu_s_hat_d, 1), collapse = ", "), ")", 
                         ", lambda_s_hat=(", paste(round(lambda_s_hat_dd, 1), collapse = ", "), ")", 
                         ", df=", nu_s_hat), 
       x = expression(x[1]), y = expression(x[2]), 
       color = "density") # ラベル

f:id:anemptyarchive:20201031223652p:plain
予測分布:多次元スチューデントのt分布

 これまでと同様に作図できます。

・おまけ

 gganimateパッケージを利用して、パラメータの推定値の推移を確認するためのgif画像を作成するコードです。

・コード(クリックで展開)

# 追加パッケージ
library(gganimate)

# データ数を指定
N <- 100
D <- 2 # (固定)

# 観測モデルのパラメータを指定
mu_truth_d <- c(25, 50)
sigma_truth_dd <- matrix(c(50, 30, 30, 50), nrow = 2, ncol = 2)
lambda_truth_dd <- solve(sigma_truth_dd^2)

# 事前分布のパラメータを指定
m_d <- c(0, 0)
beta <- 1
W_dd <- matrix(c(0.00005, 0, 0, 0.00005), nrow = 2, ncol = 2)
nu <- D
lambda_E_dd <- nu * W_dd

# 作図用の点を生成
x_vec <- seq(
   mu_truth_d[1] - 3 * sigma_truth_dd[1, 1], 
   mu_truth_d[1] + 3 * sigma_truth_dd[1, 1], 
   by = 1
)
y_vec <- seq(
   mu_truth_d[2] - 3 * sigma_truth_dd[2, 2], 
   mu_truth_d[2] + 3 * sigma_truth_dd[2, 2], 
   by = 1
)
point_df <- tibble(
  x = rep(x_vec, times = length(y_vec)), 
  y = rep(y_vec, each = length(x_vec))
)
mu_df <- tibble(
  x = mu_truth_d[1], 
  y = mu_truth_d[2]
)

# 観測モデルを計算
model_df <- cbind(
  point_df, 
  density = mvnfast::dmvn(
    X = as.matrix(point_df), mu = mu_truth_d, sigma = sigma_truth_dd^2
  ) # 確率密度
)

# 事前分布を計算
posterior_df <- cbind(
  point_df, 
  density = mvnfast::dmvn(
    X = as.matrix(point_df), mu = m_d, sigma = solve(lambda_E_dd)
  ), # 確率密度
  iteration = 0 # 試行回数
)

# 予測分布のパラメータを計算
mu_s_d <- m_d
lambda_s_dd <- (1 - D + nu) * beta / (1 + beta) * W_dd
nu_s <- 1 - D + nu

# 予測分布を計算
predict_df <- cbind(
  point_df, 
  density = mvnfast::dmvt(
    X = as.matrix(point_df), mu = mu_s_d, sigma = solve(lambda_s_dd), df = nu_s
  ), # 確率密度
  iteration = 0 # 試行回数
)

# ベイズ推論
for(n in 1:N) {
  
  # 2次元ガウス分布に従うデータを生成
  x_nd <- mvtnorm::rmvnorm(n = 1, mean = mu_truth_d, sigma = sigma_truth_dd^2)
  
  # 観測データを記録
  if(n > 1) { # 初回以外
    # オブジェクトを結合
    sample_mat <- rbind(sample_mat, x_nd)
    sample_df <- tibble(
      x = sample_mat[, 1],
      y = sample_mat[, 2], 
      iteration = n
    ) %>% 
      rbind(sample_df, .)
  } else if(n == 1){ # 初回
    # オブジェクトを作成
    sample_mat <- x_nd
    sample_df <- tibble(
      x = sample_mat[, 1],
      y = sample_mat[, 2], 
      iteration = n
    )
  }
  
  # 事後分布のパラメータを更新
  old_beta <- beta
  old_m_d <- m_d
  beta <- 1 + beta
  m_d <- as.vector(
    (x_nd + old_beta * m_d) / beta
  )
  tmp_x <- t(x_nd) %*% as.matrix(x_nd)
  tmp_m <- old_beta * as.matrix(old_m_d) %*% t(old_m_d)
  tmp_m_hat <- beta * as.matrix(m_d) %*% t(m_d)
  W_dd <- solve(
    tmp_x + tmp_m - tmp_m_hat + solve(W_dd)
  )
  nu <- 1 + nu
  lambda_E_dd <- nu * W_dd
  
  # 事後分布を計算
  tmp_posterior_df <- cbind(
    point_df, 
    density = mvnfast::dmvn(
      X = as.matrix(point_df), mu = m_d, sigma = solve(lambda_E_dd)
    ), # 確率密度
    iteration = n # 試行回数
  )
  
  # 予測分布のパラメータを計算
  mu_s_d <- m_d
  lambda_s_dd <- (1 - D + nu) * beta / (1 + beta) * W_dd
  nu_s <- 1 - D + nu
  
  # 予測分布を計算
  tmp_predict_df <- cbind(
    point_df, 
    density = mvnfast::dmvt(
      X = as.matrix(point_df), mu = mu_s_d, sigma = solve(lambda_s_dd), df = nu_s
    ), # 確率密度
    iteration = n # 試行回数
  )
  
  # 推論結果を結合
  posterior_df <- rbind(posterior_df, tmp_posterior_df)
  predict_df <- rbind(predict_df, tmp_predict_df)
  
  # 動作確認
  print(n)
}

# 事後分布の期待値を用いた分布を作図
posterior_graph <- ggplot() + 
  geom_contour(data = posterior_df, aes(x, y, z = density, color = ..level..)) + # 精度の期待値を用いた分布
  geom_point(data = sample_df, aes(x = x, y = y)) + # 観測データ
  geom_contour(data = model_df, aes(x, y, z = density, color = ..level..), 
               alpha = 0.5, linetype = "dashed") + # 観測モデル
  geom_point(data = mu_df, aes(x = x, y = y), color = "red", shape = 3, size = 5) + # 平均パラメータ
  transition_manual(iteration) +  # フレーム
  labs(title = "Multivariate Gaussian Distribution", 
       subtitle = "N={current_frame}", 
       x = expression(x[1]), y = expression(x[2]), 
       color = "density") # ラベル

# gif画像を作成
animate(posterior_graph, nframes = N + 1, fps = 5)

# 予測分布を作図
predict_graph <- ggplot() + 
  geom_contour(data = predict_df, aes(x, y, z = density, color = ..level..)) + # 予測分布
  geom_point(data = sample_df, aes(x = x, y = y)) + # 観測データ
  geom_contour(data = model_df, aes(x, y, z = density, color = ..level..), 
               alpha = 0.5, linetype = "dashed") + # 観測モデル
  geom_point(data = mu_df, aes(x = x, y = y), color = "red", shape = 3, size = 5) +  # 平均パラメータ
  transition_manual(iteration) +  # フレーム
  labs(title = "Multivariate Student's t Distribution", 
       subtitle = "N={current_frame}", 
       x = expression(x[1]), y = expression(x[2]), 
       color = "density") # ラベル

# gif画像を作成
animate(predict_graph, nframes = N + 1, fps = 5)

 異なる点のみを簡単に解説します。

 各データによってどのように学習する(推定値が変化する)のかを確認するため、こちらはfor()で1データずつ処理します。よって観測データ数Nがイタレーション数になります。

 一度の処理で事後分布のパラメータを計算するのではなく、事前分布のパラメータに対して繰り返し観測データの情報を与えることで更新(上書き)していきます。$\sum_{n=1}^N$の計算は、N回繰り返し計算することで実行されます。
 ただし事後分布のパラメータの計算(3.129.b)の計算において、更新前(事前分布)のパラメータ$\beta,\ \mathbf{m}$を使うため、それぞれold_betaold_m_dとして値を一時的に保存しておきます。

f:id:anemptyarchive:20201031223713g:plain
事後分布の推移

f:id:anemptyarchive:20201031223802g:plain
予測分布の推移


参考文献

  • 須山敦志『ベイズ推論による機械学習入門』(機械学習スタートアップシリーズ)杉山将監修,講談社,2017年.

おわりに

 多次元ガウス分布のベイズ推論完了!ただただ愚直に解きました。もっとスマートに導出できたりするのでしょうか?

 そろそろPRMLに移ってもいいもんでしょうかね。次の線形回帰ができたら試しに読んでみる(2度目の正直)。

【次節の内容】

www.anarchive-beta.com