からっぽのしょこ

読んだら書く!書いたら読む!同じ事は二度調べ(たく)ない

【R】4.4.3:ガウス混合モデルの変分推論【緑ベイズ入門のノート】

はじめに

 この記事は、R Advent Calendar 2020の10日目の記事です。

 『ベイズ推論による機械学習入門』の学習時のノートです。基本的な内容は「数式の行間を読んでみた」とそれを「RとPythonで組んでみた」になります。「数式」と「プログラム」から理解するのが目標です。

 この記事は4.4.3項の内容です。多次元ガウス混合モデルにおける変分推論による近似推論をRで実装します。

 省略してある内容等ありますので、本と併せて読んでください。初学者な自分が理解できるレベルまで落として書き下していますので、分かる人にはかなりくどくなっています。同じような立場の人のお役に立てれば幸いです。

【数式読解編】

www.anarchive-beta.com

【前節の内容】

www.anarchive-beta.com

【他の節一覧】

www.anarchive-beta.com

【この節の内容】

・Rでやってみる

 ガウス混合モデルに従い生成したデータを用いて、パラメータを推定してみましょう。

 利用するパッケージを読み込みます。

# 利用パッケージ
library(tidyverse)
library(mvnfast)

 mvnfastは、多次元ガウス分布に関するパッケージです。多次元ガウス分布に従う乱数の生成関数rmvn()、確率密度の計算関数dmvn()を使います。

・真の観測モデルの設定

 観測モデル$p(\mathbf{X} | \mathbf{S}, \boldsymbol{\mu}, \boldsymbol{\Lambda})$のパラメータを設定します。作図に関しては、2次元のグラフで表現するため$D = 2$のときのみ動作します。

# 真の観測モデルのパラメータを指定
D <- 2 # (固定)
K <- 3
mu_true_kd <- matrix(
  c(0, 4, 
    -5, -5, 
    5, -2.5), nrow = K, ncol = D, byrow = TRUE
)
sigma2_true_ddk <- array(
  c(8, 0, 0, 8, 
    4, -2.5, -2.5, 4, 
    6.5, 4, 4, 6.5), dim = c(D, D, K)
)

# 確認
mu_true_kd
##      [,1] [,2]
## [1,]    0  4.0
## [2,]   -5 -5.0
## [3,]    5 -2.5
sigma2_true_ddk
## , , 1
## 
##      [,1] [,2]
## [1,]    8    0
## [2,]    0    8
## 
## , , 2
## 
##      [,1] [,2]
## [1,]  4.0 -2.5
## [2,] -2.5  4.0
## 
## , , 3
## 
##      [,1] [,2]
## [1,]  6.5  4.0
## [2,]  4.0  6.5

 真の観測モデルにおけるクラスタごとの平均パラメータ$\boldsymbol{\mu} = \{\boldsymbol{\mu_1}, \cdots, \boldsymbol{\mu}_K\}$、$\boldsymbol{\mu}_k = \{\mu_{k,1}, \cdots, \mu_{k,D}\}$をmu_true_kdとし、分散共分散行列(精度行列の逆行列)パラメータ$\boldsymbol{\Lambda}^{-1} = \{\boldsymbol{\Lambda}_1^{-1}, \cdots, \boldsymbol{\Lambda}_K^{-1}\}$、$\boldsymbol{\Lambda}_k^{-1} = \boldsymbol{\Sigma}_k = \{\sigma_{k,1,1}^2, \cdots, \sigma_{k,D,D}^2\}$をsigma2_true_ddkとして値を指定します。ただし$\boldsymbol{\Sigma}_k$は、正定値行列である必要があります。
 mu_true_kdの作成において、引数にbyrow = TRUEを指定すると、クラスタ1から$K$の順に値を指定できます。sigma2_true_ddkにおいては、デフォルトの仕様上$\sigma_{d,d',k}^2$として$(\sigma_{1,1,1}^2, \sigma_{2,1,1}^2, \sigma_{1,2,1}^2, \sigma_{2,2,1}^2, \sigma_{1,1,2}^2, \sigma_{2,1,2}^2, \cdots)$の順に値を指定します。
 この2つのパラメータの値を観測データから求めることが目的となります。

 混合比率も設定します。

# 真の混合比率を指定
pi_true_k <- c(0.5, 0.2, 0.3)

 混合比率$\boldsymbol{\pi} = \{\pi_1, \cdots, \pi_K\}$をpi_true_kとして、値を指定します。ここで$\pi_k$は各データ$\mathbf{x}_n$のクラスタ(各潜在変数$\mathbf{s}_n$)が$k$となる確率であり、$0 \leq \pi_k \leq 1,\ \sum_{k=1}^K \pi_k = 1$の値をとります。このパラメータの値も求めます。

 以上が観測モデルの設定です。続いて、設定したガウス混合モデルに従ってデータを生成します。先に各データが持つクラスタ(潜在変数$\mathbf{S} = \{s_{1,1}, \cdots, s_{N,K}\}$)を生成します。

# 観測データの真のクラスタを生成
N <- 250
s_true_nk <- rmultinom(n = N, size = 1, prob = pi_true_k) %>% 
  t()

# 確認
s_true_nk[1:5, ]
##      [,1] [,2] [,3]
## [1,]    0    1    0
## [2,]    1    0    0
## [3,]    0    0    1
## [4,]    1    0    0
## [5,]    0    0    1

 $\boldsymbol{\pi}$をパラメータとするカテゴリ分布に従い、潜在変数$\mathbf{S}$に1から$K$の値を割り当てます。カテゴリ分布の乱数は、多項分布の乱数生成関数rmultinom()の試行回数引数size1を指定することで生成できます。また確率値引数probpi_true_k、データ数引数nNを指定します。
 転置した出力は、各列がクラスタ1から$K$に対応していて、各行(各データ)に割り当てられたクラスタが1でそれ以外は0となります。これが(本来は観測できない)真のクラスタであり、s_true_nkとします。

 処理の都合上、割り当てられたクラスタを抽出しておきます。

# 各データのクラスタを抽出
res_s <- which(t(s_true_nk) == 1, arr.ind = TRUE)
s_true_n <- res_s[, "row"]

# 確認
s_true_n[1:5]
## [1] 2 1 3 1 3

 s_true_nkは、行がデータ、列がクラスタに対応しています。which()を使って行(データ)ごとに要素が1のインデックスを取得します。s_true_nの各要素が各データに対応し、値がクラスタ番号に対応します。

 生成したクラスタに従い、観測データ$\mathbf{X} = \{x_{11}, \cdots, x_{ND}\}$を生成します。

# 観測データを生成
x_nd <- matrix(0, nrow = N, ncol = D)
for(n in 1:N) {
  k <- s_true_n[n] # クラスタを取得
  x_nd[n, ] = mvnfast::rmvn(n = 1, mu = mu_true_kd[k, ], sigma = sigma2_true_ddk[, , k])
}

# 確認
x_nd[1:5, ]
##            [,1]      [,2]
## [1,] -4.6972335 -2.999227
## [2,]  1.7615493  8.326199
## [3,]  6.4177679 -1.999288
## [4,]  0.1890716  5.306996
## [5,]  4.3615996 -2.614715

 各データに与えられたクラスタのパラメータ$\boldsymbol{\mu}_k,\ \boldsymbol{\Sigma}_k$を持つ多次元ガウス分布に従い、データを生成します。多次元ガウス分布の乱数は、mvnfast::rmvn()で生成できます。
 データごとにクラスタが異なるので、for()で1データずつs_true_nからクラスタの値を取り出し、そのクラスタに従ってデータ$\mathbf{x}_n$を生成します。平均引数mumu_true_kd[k, ]、分散共分散引数sigmasigma2_true_ddk[, , k]を指定します。また1データずつ生成するので、データ数引数n1です。

 グラフを作成して観測モデルと観測データを確認しましょう。作図用のデータフレームを作成します。

# 作図用の点を生成
x_line <- seq(-10, 10, by = 0.1)
point_df <- tibble(
  x1 = rep(x_line, times = length(x_line)), 
  x2 = rep(x_line, each = length(x_line))
)

# 作図用のデータフレームを作成
model_true_df <- tibble()
sample_df <- tibble()
for(k in 1:K) {
  # 真の観測モデルを計算
  tmp_model_df <- cbind(
    point_df, 
    density = mvnfast::dmvn(
      X = as.matrix(point_df), mu = mu_true_kd[k, ], sigma = sigma2_true_ddk[, , k]
    ), 
    cluster = as.factor(k)
  )
  model_true_df <- rbind(model_true_df, tmp_model_df)
  
  # 観測データのデータフレーム
  k_idx <- which(s_true_n == k)
  tmp_sample_df <- tibble(
    x1 = x_nd[k_idx, 1], 
    x2 = x_nd[k_idx, 2], 
    cluster = as.factor(k)
  )
  sample_df <- rbind(sample_df, tmp_sample_df)
}

# 確認
head(model_true_df)
##      x1  x2      density cluster
## 1 -10.0 -10 1.837732e-10       1
## 2  -9.9 -10 2.081122e-10       1
## 3  -9.8 -10 2.353803e-10       1
## 4  -9.7 -10 2.658886e-10       1
## 5  -9.6 -10 2.999760e-10       1
## 6  -9.5 -10 3.380107e-10       1
head(sample_df)
## # A tibble: 6 x 3
##       x1    x2 cluster
##    <dbl> <dbl> <fct>  
## 1  1.76   8.33 1      
## 2  0.189  5.31 1      
## 3 -4.28   5.21 1      
## 4 -3.22   7.16 1      
## 5  5.58   4.22 1      
## 6 -3.67   4.43 1

 等高線グラフ用に格子状の点を用意する必要があります。seq()で描画範囲と間隔を決めて、x軸の点($x_1$がとり得る値)とy軸の点($x_2$がとり得る値)が直交する点(格子状の点)を作成します。
 各交点$(x_1, x_2)$のマトリクスをmvnfast::dmvn()の第1引数に渡すことで、確率密度を計算します。他の引数についてはmvnfast::rmvn()と同じです。これをクラスタごとに行います。
 観測データについても作図用にデータフレームを作成します。which(s_true_n == k)で各クラスタが割り当てられたデータのインデックスを取得しておき、そのインデックスを添字として用いてx_ndから各クラスタのデータを取り出します。

 観測モデル(多次元ガウス分布の確率密度)の等高線図と、観測データの散布図を重ねて作図します。

# 真の観測モデルを作図
ggplot() + 
  geom_contour(data = model_true_df, aes(x1, x2, z = density, color = cluster)) + # 真の観測モデル
  geom_point(data = sample_df, aes(x1, x2, color = cluster)) + # 真の観測データ
  labs(title = "Gaussian Mixture Model", 
       subtitle = paste0('K=', K, ', N=', N), 
       x = expression(x[1]), y = expression(x[2]))

 等高線グラフはgeom_contour()、散布図はgeom_point()で描画します。

f:id:anemptyarchive:20201207010917p:plain
ガウス混合モデル


 ここまでが観測モデル(ガウス混合モデル)に関する設定です。次に事前分布のパラメータの設定を行います。

・事前分布のパラメータの設定

 観測モデルのパラメータは本来知り得ないものです。そこで事前分布を設定し、事後分布を求めます(分布推定します)。

# 事前分布のパラメータを指定
beta <- 1
m_d <- rep(0, D)
nu <- D
w_dd <- diag(D) * 0.05
alpha_k <- rep(1, K)

 観測モデルの平均パラメータ$\boldsymbol{\mu}_k$の事前分布(多次元ガウス分布)の平均パラメータ$\mathbf{m} = \{m_1, \cdots, m_D\}$、精度行列パラメータの係数$\beta$を、それぞれm_dbetaとして値を指定します。
 観測モデルの精度行列パラメータ$\boldsymbol{\Lambda}_k$の事前分布(ウィシャート分布)の自由度$\nu$、パラメータ$\mathbf{W} = \{w_{1,1}, \cdots, w_{D,D}\}$を、それぞれnuw_ddとして値を指定します。ただし、$\nu > D - 1$の値をとり、$\mathbf{W}$は正定値行列です。
 混合比率$\boldsymbol{\pi}$の事前分布(ディリクレ分布)のパラメータ$\boldsymbol{\alpha} = \{\alpha_1, \cdots, \alpha_K\}$をalpha_kとして、$\alpha_k > 0$の値を指定します。

 近似事後分布のパラメータの初期値をランダムに設定します。

# 近似事後分布の初期値をランダムに設定
beta_hat_k <- seq(0.1, 10, by = 0.1) %>% 
  sample(size = K, replace = TRUE)
m_hat_kd <- seq(min(x_nd), max(x_nd), by = 0.1) %>% 
  sample(size = K * D, replace = TRUE) %>% 
  matrix(nrow = K, ncol = D)
nu_hat_k <- rep(nu, K)
w_hat_ddk <- array(rep(w_dd, times = K), dim = c(D, D, K))
alpha_hat_k <- seq(0.1, 10, by = 0.1) %>% 
  sample(size = K, replace = TRUE)

 $\boldsymbol{\mu}_k$の近似事後分布(多次元ガウス分布)の平均パラメータ$\hat{\mathbf{m}} = \{\hat{m}_{1,1}, \cdots, \hat{m}_{K,D}\}$、精度行列パラメータの係数$\hat{\boldsymbol{\beta}} = \{\hat{\beta}_1, \cdots, \hat{\beta}_K\}$を、それぞれm_hat_kdbeta_hat_kとします。
 $\boldsymbol{\Lambda}_k$の近似事後分布(ウィシャート分布)の自由度$\hat{\boldsymbol{\nu}} = \{\hat{\nu}_1, \cdots, \hat{\nu}_K\}$、パラメータ$\hat{\mathbf{W}} = \{\hat{\mathbf{W}}_1, \cdots, \hat{\mathbf{W}}_K\}$、$\hat{\mathbf{W}}_k = \{w_{k,1,1}, \cdots, w_{k,D,D}\}$を、それぞれnu_hat_kw_hat_ddkとします。
 $\boldsymbol{\pi}$の近似事後分布(ディリクレ分布)のパラメータ$\hat{\boldsymbol{\alpha}} = \{\hat{\alpha}_1, \cdots, \hat{\alpha}_K\}$をalpha_hat_kとします。

 m_hat_kdは、観測データx_ndの最小値から最大値の範囲でランダムに値を決めています。beta_hat_kalpha_hat_kは、範囲を指定してランダムに値を決めています。w_hat_ddkは、(値をランダムに生成するのが面倒なので)nu_hat_kと共に事前分布のパラメータの値を複製しています。

 事前分布と事後分布のパラメータのことを超パラメータ(ハイパーパラメータ)と呼びます。以上で推論に必要なモデルに関する設定は完了です。次は変分推論を実装します。

・変分推論の実装

 各データのクラスタ$\mathbf{S}$の近似事後分布$q(\mathbf{S})$のパラメータの計算(超パラメータの更新)と、パラメータ$\boldsymbol{\mu},\ \boldsymbol{\Lambda},\ \boldsymbol{\pi}$の近似事後分布のパラメータの計算(超パラメータの更新)を交互に行います。

 本では縦ベクトルとしているところを、この例では横ベクトルとして扱うため、転置などの処理が異なっている点に注意してください。

・コード全体

・コード(クリックで展開)

 $q(\mathbf{S})$の計算時に使用する中間変数(オブジェクト)に関して、クラスタごとの計算結果を添字を用いて代入していくため、最初に受け皿としてのオブジェクトを作成しておきます。

# 試行回数を指定
MaxIter <- 100

# 途中計算に用いる項の受け皿を作成
ln_eta_nk <- matrix(0, nrow = N, ncol = K)
tmp_eta_nk <- matrix(0, nrow = N, ncol = K)
E_lmd_ddk <- rep(nu_hat_k, each = D * D) * w_hat_ddk
E_ln_det_lmd_k <- rep(0, K)
E_lmd_mu_kd <- matrix(0, nrow = K, ncol = D)
E_mu_lmd_mu_k <- rep(0, K)
E_ln_pi_k <- rep(0, K)

# 推移の確認用の受け皿を作成
trace_E_mu_ikd <- array(0, dim = c(MaxIter+1, K, D))
trace_E_lambda_iddk <- array(0, dim = c(MaxIter+1, D, D, K))
trace_E_s_ink <- array(0, dim = c(MaxIter, N, K))
trace_E_mu_ikd[1, , ] <- m_hat_kd
trace_E_lambda_iddk[1, , , ] <- rep(nu_hat_k, each = D * D) * w_hat_ddk

# 変分推論
for(i in 1:MaxIter) {
  
  # Sの近似事後分布のパラメータを計算:式(4.109)
  for(k in 1:K) {
    E_lmd_ddk[, , k] <- nu_hat_k[k] * w_hat_ddk[, , k]
    E_ln_det_lmd_k[k] <- sum(digamma(0.5 * (nu_hat_k[k] + 1 - 1:D))) + D * log(2) + log(det(w_hat_ddk[, , k]))
    E_lmd_mu_kd[k, ] <- E_lmd_ddk[, , k] %*% matrix(m_hat_kd[k, ])
    E_mu_lmd_mu_k[k] <- t(m_hat_kd[k, ]) %*% matrix(E_lmd_mu_kd[k, ]) + D / beta_hat_k[k]
    E_ln_pi_k[k] <- digamma(alpha_hat_k[k]) - digamma(sum(alpha_hat_k))
    term_eta1_n <- diag(
      -0.5 * x_nd %*% E_lmd_ddk[, , k] %*% t(x_nd)
    )
    term_eta2_n <- x_nd %*% matrix(E_lmd_mu_kd[k, ]) %>% 
      as.vector()
    ln_eta_nk[, k] <- term_eta1_n + term_eta2_n - 0.5 * E_mu_lmd_mu_k[k] + 0.5 * E_ln_det_lmd_k[k] + E_ln_pi_k[k]
  }
  tmp_eta_nk <- exp(ln_eta_nk)
  eta_nk <- (tmp_eta_nk + 1e-7) / rowSums(tmp_eta_nk + 1e-7) # 正規化
  
  # Sの近似事後分布の期待値を計算:式(4.59)
  E_s_nk <- eta_nk
  
  for(k in 1:K) {
    # muの近似事後分布のパラメータを計算:式(4.114)
    beta_hat_k[k] <- sum(E_s_nk[, k]) + beta
    m_hat_kd[k, ] <- (colSums(E_s_nk[, k] * x_nd) + beta * m_d) / beta_hat_k[k]
    
    # lambdaの近似事後分布のパラメータを計算:式(4.118)
    nu_hat_k[k] <- sum(E_s_nk[, k]) + nu
    tmp_w1_dd <- t(E_s_nk[, k] * x_nd) %*% x_nd
    tmp_w2_dd <- beta * matrix(m_d) %*% t(m_d)
    tmp_w3_dd <- beta_hat_k[k] * matrix(m_hat_kd[k, ]) %*% t(m_hat_kd[k, ])
    w_hat_ddk[, , k] <- solve(
      tmp_w1_dd + tmp_w2_dd - tmp_w3_dd + solve(w_dd)
    )
  }
  
  # piの近似事後分布のパラメータを計算:式(4.58)
  alpha_hat_k <- colSums(E_s_nk) + alpha_k
  
  # 観測モデルのパラメータの期待値を記録
  trace_E_mu_ikd[i+1, , ] <- m_hat_kd
  trace_E_lambda_iddk[i+1, , , ] <- rep(nu_hat_k, each = D * D) * w_hat_ddk
  trace_E_s_ink[i, , ] <- E_s_nk
  
  # 動作確認
  print(paste0(i, ' (', round(i / MaxIter * 100, 1), '%)'))
}

 trace_***は、パラメータや分布の推移を確認するためのオブジェクトです。推論自体には不要です。
 $q(\boldsymbol{\mu} | \boldsymbol{\Lambda})$の期待値$\mathbb{E}_{q(\boldsymbol{\mu} | \boldsymbol{\Lambda})}[\boldsymbol{\mu}] = \hat{\mathbf{m}}$、$q(\boldsymbol{\Lambda})$の期待値$\mathbb{E}_{q(\boldsymbol{\Lambda})}[\boldsymbol{\Lambda}] = \hat{\boldsymbol{\nu}} \hat{\mathbf{W}}$、$q(\mathbf{S})$の期待値$\mathbb{E}_{q(\mathbf{S})}[\mathbf{S}] = \boldsymbol{\eta}$を記録しておきます。


・処理の解説

 まずは、潜在変数$\mathbf{s}_n$の近似事後分布(カテゴリ分布)のパラメータ$\boldsymbol{\eta}_n = \{\eta_{n,1}, \cdots, \eta_{n,K}\}$を計算します。$\boldsymbol{\eta}_n$の計算式(4.109)の期待値に関する項の計算から行います。

# Sの近似事後分布のパラメータを計算:式(4.109)
for(k in 1:K) {
  E_lmd_ddk[, , k] <- nu_hat_k[k] * w_hat_ddk[, , k]
  E_ln_det_lmd_k[k] <- sum(digamma(0.5 * (nu_hat_k[k] + 1 - 1:D))) + D * log(2) + log(det(w_hat_ddk[, , k]))
  E_lmd_mu_kd[k, ] <- E_lmd_ddk[, , k] %*% matrix(m_hat_kd[k, ])
  E_mu_lmd_mu_k[k] <- t(m_hat_kd[k, ]) %*% matrix(E_lmd_mu_kd[k, ]) + D / beta_hat_k[k]
  E_ln_pi_k[k] <- digamma(alpha_hat_k[k]) - digamma(sum(alpha_hat_k))
}

 それぞれ次の式で計算します。

$$ \begin{align} \mathbb{E}_{q(\boldsymbol{\Lambda}_k)} [ \boldsymbol{\Lambda}_k ] &= \hat{\nu} \hat{\mathbf{W}}_k \tag{4.119}\\ \mathbb{E}_{q(\boldsymbol{\Lambda}_k)} [ \ln |\boldsymbol{\Lambda}_k| ] &= \sum_{d=1}^D \psi \Bigl( \frac{\hat{\nu}_k + 1 - d}{2} \Bigr) + D \ln 2 + \ln |\hat{\mathbf{W}}_k| \tag{4.120}\\ \mathbb{E}_{q(\boldsymbol{\mu}_k, \boldsymbol{\Lambda}_k)} [ \boldsymbol{\Lambda}_k \boldsymbol{\mu}_k ] &= \hat{\nu} \hat{\mathbf{W}}_k \hat{\mathbf{m}}_k \tag{4.121}\\ &= \mathbb{E}_{q(\boldsymbol{\Lambda}_k)} [ \boldsymbol{\Lambda}_k ] \hat{\mathbf{m}}_k \\ \mathbb{E}_{q(\boldsymbol{\mu}_k, \boldsymbol{\Lambda}_k)} [ \boldsymbol{\mu}_k^{\top} \boldsymbol{\Lambda}_k \boldsymbol{\mu}_k ] &= \hat{\nu} \hat{\mathbf{m}}_k^{\top} \hat{\mathbf{W}}_k \hat{\mathbf{m}}_k + \frac{D}{\hat{\beta}_k} \tag{4.122}\\ &= \hat{\mathbf{m}}_k^{\top} \mathbb{E}_{q(\boldsymbol{\mu}_k, \boldsymbol{\Lambda}_k)} [ \boldsymbol{\Lambda}_k \boldsymbol{\mu}_k ] + \frac{D}{\hat{\beta}_k} \\ \mathbb{E}_{q(\boldsymbol{\pi})} [ \ln \pi_k ] &= \psi(\hat{\alpha}_k) - \psi \Bigl( \sum_{k=1}^K \hat{\alpha}_k \Bigr) \tag{4.62} \end{align} $$

 ここで$\psi(\cdot)$はディガンマ関数です。

 これらを用いて$\boldsymbol{\eta} = \{\boldsymbol{\eta}_1, \cdots, \boldsymbol{\eta}_N\}$を計算します。

# Sの近似事後分布のパラメータを計算:式(4.109)
for(k in 1:K) {
  term_eta1_n <- diag(
    -0.5 * x_nd %*% E_lmd_ddk[, , k] %*% t(x_nd)
  )
  term_eta2_n <- x_nd %*% matrix(E_lmd_mu_kd[k, ]) %>% 
    as.vector()
  ln_eta_nk[, k] <- term_eta1_n + term_eta2_n - 0.5 * E_mu_lmd_mu_k[k] + 0.5 * E_ln_det_lmd_k[k] + E_ln_pi_k[k]
}
tmp_eta_nk <- exp(ln_eta_nk)

# 確認
round(tmp_eta_nk[1:5, ], 4)
##        [,1]   [,2]   [,3]
## [1,] 0.0006 0.0269 0.0000
## [2,] 0.0173 0.0000 0.0000
## [3,] 0.0001 0.0000 0.0449
## [4,] 0.0603 0.0000 0.0000
## [5,] 0.0002 0.0000 0.0569

 $\eta_{n,k}$は、次の式で計算します。

$$ \begin{align} \eta_{n,k} &\propto \exp \Biggl\{ - \frac{1}{2} \mathbf{x}_n^{\top} \mathbb{E}_{q(\boldsymbol{\Lambda}_k)} [ \boldsymbol{\Lambda}_k ] \mathbf{x}_n - \mathbf{x}_n^{\top} \mathbb{E}_{q(\boldsymbol{\mu}_k, \boldsymbol{\Lambda}_k)} [ \boldsymbol{\Lambda}_k \boldsymbol{\mu}_k ] + \frac{1}{2} \mathbb{E}_{q(\boldsymbol{\mu}_k, \boldsymbol{\Lambda}_k)} [ \boldsymbol{\mu}_k^{\top} \boldsymbol{\Lambda}_k \boldsymbol{\mu}_k ] \Biggr.\\ &\qquad \Biggl. + \frac{1}{2} \mathbb{E}_{q(\boldsymbol{\Lambda}_k)} [ \ln |\boldsymbol{\Lambda}_k| ] + \mathbb{E}_{q(\boldsymbol{\pi})} [ \ln \pi_k ] \Biggr\} \tag{4.109} \end{align} $$

 ただし全てのデータ($n = 1, \cdots, N$)を一度に処理するために、$(\mathbf{x}_n - \boldsymbol{\mu}_k)^{\top} \boldsymbol{\Lambda}_k (\mathbf{x}_n - \boldsymbol{\mu}_k)$の計算を1から$N$の全ての組み合わせで行います(for()内の2行目)。この計算結果は、$N \times N$のマトリクスになります。これは例えば、1行$N$列目の要素は$(\mathbf{x}_1 - \boldsymbol{\mu}_k)^{\top} \boldsymbol{\Lambda}_k (\mathbf{x}_N - \boldsymbol{\mu}_k)$の計算結果です。これは不要なので、diag()で対角成分(同じデータによる計算結果)のみ取り出します。
 またこの計算において$\ln 0$とならないように、微小な値1e-7($10^{-7} = 0.0000001$)を加えています。

 さらに$\sum_{k=1}^K \eta_{n,k} = 1$となるように、1から$K$の和をとったもので割って正規化する必要があります。

# 正規化
eta_nk <- eta_nk / rowSums(eta_nk)

# 確認
round(eta_nk[1:5, ], 4)
##        [,1]   [,2]   [,3]
## [1,] 0.0203 0.9794 0.0002
## [2,] 1.0000 0.0000 0.0000
## [3,] 0.0017 0.0000 0.9983
## [4,] 1.0000 0.0000 0.0000
## [5,] 0.0043 0.0000 0.9957
rowSums(eta_nk[1:5, ])
## [1] 1 1 1 1 1

 正規化の計算を式にすると次のようになります。

$$ \eta_{n,k} \leftarrow \frac{ \eta_{n,k} }{ \sum_{k'=1}^K \eta_{n,k'} } $$


 $q(\mathbf{S})$のパラメータが得られたので、$q(\mathbf{S})$の期待値を計算します。

# Sの近似事後分布の期待値を計算:式(4.59)
E_s_nk <- eta_nk

 カテゴリ分布の期待値(2.31)より、パラメータの値そのままです。

$$ \mathbb{E}_{q(\mathbf{s}_n)}[s_{n,k}] = \eta_{n,k} \tag{4.59} $$


 $\mathbf{S}$の近似事後分布の期待値が求まりました。次はこれを用いて、各パラメータの近似事後分布(のパラメータ)を求めます。ここからの内容は、for()でクラスタごとに行います。

 $\boldsymbol{\mu}_k$の近似事後分布(多次元ガウス分布)の平均パラメータ$\hat{\mathbf{m}}_k$、精度行列パラメータの係数$\hat{\beta}_k$を計算します。

# クラスタを指定
k <- 1

# muの近似事後分布のパラメータを計算:式(4.114)
beta_hat_k[k] <- sum(E_s_nk[, k]) + beta
m_hat_kd[k, ] <- (colSums(E_s_nk[, k] * x_nd) + beta * m_d) / beta_hat_k[k]

# 確認
head(beta_hat_k)
## [1] 119.59985  47.73624  85.66391
head(m_hat_kd)
##            [,1]      [,2]
## [1,] -0.3650018  4.364588
## [2,] -4.4876996 -5.041223
## [3,]  4.7854101 -2.267585

 $\hat{\beta}_k,\ \hat{\mathbf{m}}_k$は、それぞれ次の式で計算します。

$$ \begin{align} \hat{\beta}_k &= \sum_{n=1}^N \mathbb{E}_{q(\mathbf{s}_n)} [s_{n,k}] + \beta \\ \hat{\mathbf{m}}_k &= \frac{ \sum_{n=1}^N \mathbb{E}_{q(\mathbf{s}_n)} [s_{n,k}] \mathbf{x}_n + \beta \mathbf{m} }{ \hat{\beta}_k } \tag{4.114} \end{align} $$


 $\boldsymbol{\lambda}_k$の近似事後分布(ウィシャート分布)のパラメータ$\hat{\mathbf{W}}_k$、自由度$\hat{\nu}_k$も計算します。

# lambdaの近似事後分布のパラメータを計算:式(4.118)
nu_hat_k[k] <- sum(E_s_nk[, k]) + nu
tmp_w1_dd <- t(E_s_nk[, k] * x_nd) %*% x_nd
tmp_w2_dd <- beta * matrix(m_d) %*% t(m_d)
tmp_w3_dd <- beta_hat_k[k] * matrix(m_hat_kd[k, ]) %*% t(m_hat_kd[k, ])
w_hat_ddk[, , k] <- solve(
  tmp_w1_dd + tmp_w2_dd - tmp_w3_dd + solve(w_dd)
)

# 確認
head(nu_hat_k)
## [1] 120.59985  48.73624  86.66391
head(w_hat_ddk)
## , , 1
## 
##               [,1]          [,2]
## [1,]  0.0011520354 -0.0001235843
## [2,] -0.0001235843  0.0012023086
## 
## , , 2
## 
##            [,1]        [,2]
## [1,] 0.00558537 0.001625420
## [2,] 0.00162542 0.004915759
## 
## , , 3
## 
##              [,1]         [,2]
## [1,]  0.002495864 -0.001050050
## [2,] -0.001050050  0.002107438

 $\hat{\nu}_k,\ \hat{\mathbf{W}}_k$は、それぞれ次の式で計算します。

$$ \begin{align} \hat{\nu}_k &= \sum_{n=1}^N \mathbb{E}_{q(\mathbf{s}_n)} [s_{n,k}] + \nu \\ \hat{\mathbf{W}}_k^{-1} &= \sum_{n=1}^N \mathbb{E}_{q(\mathbf{s}_n)} [s_{n,k}] \mathbf{x}_n \mathbf{x}_n^{\top} + \beta \mathbf{m} \mathbf{m}^{\top} - \hat{\beta}_k \hat{\mathbf{m}}_k \hat{\mathbf{m}}^{\top} + \mathbf{W}^{-1} \tag{4.103} \end{align} $$


 ここまでの処理を全てのクラスタ($k = 1, \cdots, K$)で行います。

 最後に、$\boldsymbol{\pi}$の近似事後分布(ディリクレ分布)のパラメータ$\hat{\alpha}_k$を計算します。

# piの近似事後分布のパラメータを計算:式(4.58)
alpha_hat_k <- colSums(E_s_nk) + alpha_k

# 確認
head(alpha_hat_k)
## [1] 119.59985  47.73624  85.66391

 $\hat{\alpha}_k$は、次の式で計算します。

$$ \hat{\alpha}_k = \sum_{n=1}^N \mathbb{E}_{q(\boldsymbol{s}_n)} [s_{n,k}] + \alpha_k \tag{4.58} $$


 以上で各パラメータの近似事後分布(のパラメータ)が求まりました。次の試行ではこれを用いて潜在変数の近似事後分布(のパラメータ)を更新します。

 以上が変分推論による近似推論で行う個々の処理です。これを指定した回数くり返し行うことで、徐々に各パラメータの値を真の値に近付けていきます。

・結果の確認

 最後の試行で求めたパラメータの期待値を用いて、観測モデルの近似事後分布(の確率密度)の期待値を計算します。

# 作図用のデータフレームを作成
model_df <- tibble()
sample_df <- tibble()
max_p_idx <- max.col(E_s_nk) # 確率の最大値のインデックスを取得
for(k in 1:K) {
  # 近似事後分布を計算
  tmp_model_df <- cbind(
    point_df, 
    density = mvnfast::dmvn(
      as.matrix(point_df), 
      mu = m_hat_kd[k, ], sigma = solve(nu_hat_k[k] * w_hat_ddk[, , k])
    ), 
    cluster = as.factor(k)
  )
  model_df <- rbind(model_df, tmp_model_df)
  
  # 観測データのクラスタを抽出
  k_idx <- which(max_p_idx == k)
  tmp_sample_df <- tibble(
    x1 = x_nd[k_idx, 1], 
    x2 = x_nd[k_idx, 2], 
    cluster = as.factor(k)
  )
  sample_df <- rbind(sample_df, tmp_sample_df)
}

# 確認
head(model_df)
##      x1  x2      density cluster
## 1 -10.0 -10 8.907920e-11       1
## 2  -9.9 -10 9.961233e-11       1
## 3  -9.8 -10 1.112363e-10       1
## 4  -9.7 -10 1.240442e-10       1
## 5  -9.6 -10 1.381348e-10       1
## 6  -9.5 -10 1.536125e-10       1

 各データのクラスタ割り当て確率の期待値E_s_nkが最も高いクラスタを、そのデータのクラスタとみなして可視化します。max.col(E_s_nk)で、行(データ)ごとに最大値の列番号を返します。この出力は1から$K$の値となり、各データのクラスタ番号に対応します。これを用いて観測データの散布図のときと同様に作図用のデータフレームを作成します。

 真の観測モデルと重ねて近似事後分布の期待値を描画します。

# 近似事後分布を作図
ggplot() + 
  geom_contour(data = model_df, aes(x1, x2, z = density, color = cluster)) + # 近似事後分布
  geom_contour(data = model_true_df, aes(x1, x2, z = density, color = cluster), 
               linetype = "dotted", alpha = 0.6) + # 真の観測モデル
  geom_point(data = sample_df, aes(x1, x2, color = cluster)) + # 観測データ
  labs(title = "Variational Inference", 
       subtitle = paste0('K=', K, ', N=', N, ', iter:', MaxIter), 
       x = expression(x[1]), y = expression(x[2]))

f:id:anemptyarchive:20201207011003p:plain
変分推論によるガウス混合モデルの近似事後分布

 うまく推定できていることが確認できます。ただしクラスタの順番はランダムに決まるため、真のクラスタの番号(色)とは異なります。

 $\boldsymbol{\mu},\ \boldsymbol{\Lambda}$近似事後分布の期待値の推移を確認しましょう。

・コード(クリックで展開)

 trace_E_mu_ikdtrace_E_lambda_iddkから各クラスタと次元の値をそれぞれ取り出して、データフレームにまとめます。

# 作図用のデータフレームを作成
trace_E_mu_df <- tibble()
trace_E_lambda_df <- tibble()
for(k in 1:K) {
  for(d1 in 1:D) {
    # muの値を取得
    tmp_mu_df <- tibble(
      iteration = seq(0, MaxIter), 
      value = trace_E_mu_ikd[, k, d1], 
      label = as.factor(
        paste0("k=", k, ", d=", d1)
      )
    )
    trace_E_mu_df <- rbind(trace_E_mu_df, tmp_mu_df)
    
    for(d2 in 1:D) {
      # lambdaの値を取得
      tmp_lambda_df <- tibble(
        iteration = seq(0, MaxIter), 
        value = trace_E_lambda_iddk[, d1, d2, k], 
        label = as.factor(
          paste0("k=", k, ", d=", d1, ", d'=", d2)
        )
      )
      trace_E_lambda_df <- rbind(trace_E_lambda_df, tmp_lambda_df)
    }
  }
}

# 確認
head(trace_E_mu_df)
## # A tibble: 6 x 3
##   iteration value label   
##       <int> <dbl> <fct>   
## 1         0  9.98 k=1, d=1
## 2         1  4.48 k=1, d=1
## 3         2  4.10 k=1, d=1
## 4         3  3.72 k=1, d=1
## 5         4  3.35 k=1, d=1
## 6         5  3.01 k=1, d=1
head(trace_E_lambda_df)
## # A tibble: 6 x 3
##   iteration  value label         
##       <int>  <dbl> <fct>         
## 1         0 0.1    k=1, d=1, d'=1
## 2         1 0.143  k=1, d=1, d'=1
## 3         2 0.114  k=1, d=1, d'=1
## 4         3 0.0975 k=1, d=1, d'=1
## 5         4 0.0887 k=1, d=1, d'=1
## 6         5 0.0843 k=1, d=1, d'=1


 それぞれ折れ線グラフで可視化します。

# muの推移を確認
ggplot(trace_E_mu_df, aes(x = iteration, y = value, color = label)) + 
  geom_line() + 
  labs(title = "Variational Inference", 
       subtitle = expression(paste(E, "[", bold(mu), "]", sep = "")))
# lambdaの推移を確認
ggplot(trace_E_lambda_df, aes(x = iteration, y = value, color = label)) + 
  geom_line() + 
  labs(title = "Variational Inference", 
       subtitle = expression(paste(E, "[", bold(Lambda), "]", sep = "")))

f:id:anemptyarchive:20201207011157p:plain
$\boldsymbol{\mu}$の推移

f:id:anemptyarchive:20201207011236p:plain
$\boldsymbol{\Lambda}$の推移

 収束していることが確認できます。

 以上が変分推論による近似推論の処理です。

・おまけ

 gganimateパッケージを利用して、gif画像を作成するコードです。

 更新回数(試行回数)が増えるに従って、近似事後分布(の平均)が真の観測モデルに近づいていく様子を確認します。

・コード(クリックで展開)

# 追加パッケージ
library(gganimate)

# 描画する回数を指定(変更)
#MaxIter <- 150

# 作図用のデータフレームを作成
model_df <- tibble()
sample_df <- tibble()
for(i in 1:(MaxIter + 1)) {
  # 確率の最大値のインデックスを取得
  if(i > 1) {
    max_p_idx <- max.col(trace_E_s_ink[i - 1, , ])
  }
  for(k in 1:K) {
    # 近似事後分布を計算
    tmp_model_df <- cbind(
      point_df, 
      density = mvnfast::dmvn(
        as.matrix(point_df), 
        mu = trace_E_mu_ikd[i, k, ], sigma = solve(trace_E_lambda_iddk[i, , , k])
      ), 
      cluster = as.factor(k), 
      iteration = as.factor(i - 1)
    )
    model_df <- rbind(model_df, tmp_model_df)
    
    # クラスタを抽出
    if(i > 1) { # 初期値以外のとき
      k_idx <- which(max_p_idx == k)
      tmp_sample_df <- tibble(
        x1 = x_nd[k_idx, 1], 
        x2 = x_nd[k_idx, 2], 
        cluster = as.factor(k), 
        iteration = as.factor(i - 1)
      )
      sample_df <- rbind(sample_df, tmp_sample_df)
    }
  }
  
  if(i == 1) { # 初期値のとき
    tmp_sample_df <- tibble(
      x1 = x_nd[, 1], 
      x2 = x_nd[, 2], 
      cluster = NA, 
      iteration = as.factor(i - 1)
    )
    sample_df <- rbind(sample_df, tmp_sample_df)
  }
  
  # 動作確認
  print(paste0(i - 1, ' (', round((i - 1) / MaxIter * 100, 1), '%)'))
}

# 近似事後分布を作図
trace_graphe <- ggplot() + 
  geom_contour(data = model_df, aes(x1, x2, z = density, color = cluster)) + # 近似事後分布
  geom_contour(data = model_true_df, aes(x1, x2, z = density, color = cluster), 
               linetype = "dotted", alpha = 0.6) + # 真の観測モデル
  geom_point(data = sample_df, aes(x1, x2, color = cluster)) + # 観測データ
  transition_manual(iteration) + # フレーム
  labs(title = "Variational Inference", 
       subtitle = paste0('K=', K, ', N=', N, ', iter:{current_frame}'), 
       x = expression(x[1]), y = expression(x[2]))

# gif画像を作成
animate(trace_graphe, nframes = MaxIter, fps = 10)

 先ほど同様に、イタレーションごとに各近似事後分布の期待値を用いて率密度を計算し、データフレームに結合していきます。作図処理では、gganimate::transition_manual()に、フレームの切り替えの値となる列を指定します。全てのフレーム(イタレーション)のグラフがtrace_grapheに格納されるので、gganimate::animate()でgif画像として出力します。
 処理に時間がかかる場合は、描画範囲や点の間隔(x_line)、または表示する試行回数(MaxIter)を調整してください。

f:id:anemptyarchive:20201207011302g:plain
変分推論によるガウス混合モデルの近似事後分布の推移

Enjoy!

参考文献

  • 須山敦志『ベイズ推論による機械学習入門』(機械学習スタートアップシリーズ)杉山将監修,講談社,2017年.

おわりに

 2年ほどROM専だったアドベントカレンダーに初参加しました!

 アドカレ経由で初めてこのブログを読む方が増えることを期待して、簡単に自己紹介します。
 ここ1年半ほどは、ベイズ推論系の本を読んで、数式の行間を埋め、スクラッチで組み、アルゴリズムを理解することを目指しております。この記事もその内の1つです。よければ他の記事も読んでみてください。そして何か間違い等あったらご指摘ください!よろしくお願いします。
 ただ本当は、ハロプロ楽曲の歌詞分析をしたいだけのオタクです。NLPerになるべくトピックモデルを勉強し始めてから何かおかしな方向へ進んでいます。

 そんなこんなで今年はベイズ力が少し上がったので、来年はR力を少し上げたいです。
 この記事でもイメージしたグラフを描けず悔しいです。クラスタ(潜在変数)を推定した散布図について、左のようにデータごとに確率が最大のクラスタで色分けしましたが、右のように確率値に応じてグラデーションで表現できます。

f:id:anemptyarchive:20201210022703g:plainf:id:anemptyarchive:20201210022822g:plain
クラスタの表現

なので本当は、確率が最大のクラスタの色で、確率値に応じてグラデーションにして、確率密度の等高線と重ねたかった。クラスタごとにデータフレームを作って、geom_point()を重ねて、色も個別に指定すればできそうな気もしましたが、あまりにスマートじゃないので止めました。何かいい方法があるでしょうか?
 あと推論時のは疑似コードに従ってるのでいいとして、作図時のfor()を何とかしたい。いい加減{purrr}覚えたい。

 そして2020年12月10日は、ハロープロジェクトのグループJuice=Juiceの宮本佳林さんの卒業の日です。

 卒業悲しい、ホールツアー観たかった。でもソロデビュー楽しみ、早くアルバム聴きたい。おめでとうございます!

 卒コン最高でした!!!卒業おめでとうございます。(追記)

【次節の内容】続く