Reading Notes: MLAPP Chapter 21: Variational Inference
在Bayesian inference中,我们需要计算形如$p(\bm{\theta}, \bm{z} \mid \bm{x})$的posterior,其中$\bm{x}$代表observed variables,$\bm{z}$代表latent variables,$\bm{\theta}$为模型参数。在一定条件下,posterior有closed-form representation,如用Gaussian作为prior和likelihood。然而在一般情况下,我们只能根据Bayes theorem暴力求解:
\[p(\bm{\theta}, \bm{z} \mid \bm{x}) = \frac{p(\bm{\theta}, \bm{z}, \bm{x})}{\int_{\bm{\theta}' \bm{z}'}{p(\bm{\theta}', \bm{z}', \bm{x})}}\]其中分母marginal probability的计算复杂度会随$\bm{\theta}$和$\bm{z}$的维度指数增长,高维情况下exact inference是不可能的,我们只能寻求approximate inference,其中常见的两类算法分别是Variational Inference (VI)和Markov Chain Monte Carlo (MCMC) Inference。在本篇notes中我们介绍VI,并讨论VI和MCMC的优缺点及适用范围。
Variational Inference (VI)
VI用一系列简单易求解的distribution来近似一个困难的目标distribution(即我们这里的posterior),通过迭代法来逼近这个目标distribution。换句话说,VI用optimization代替了inference。Jordan大神曾这样定义VI:
Any procedure which uses optimization to approximate a density can be termed ‘‘variational inference’’.
KL Divergence
记$p^{\star}(x)$为困难的目标distribution,$q(x)$为用来近似的distribution,我们用KL Divergence来量化这两个distribution的近似程度:
\[\mathbb{KL}(q \| p^{\star}) = \int_x{q(x) \log{\frac{q(x)}{p^{\star}(x)}}}\]注意到这里使用的是reverse KL,它用对$q(x)$求expectation简化了计算。由于KL Divergence的非对称性,还有一种forwards KL,关于这两种KL我们做如下总结:
Forwards KL | Reverse KL | |
---|---|---|
Formula | $ \mathbb{KL}(p | q) = \int{p(x) \log{\frac{p(x)}{q(x)}}} $ | $ \mathbb{KL}(q | p) = \int{q(x) \log{\frac{q(x)}{p(x)}}} $ |
Other Name | moment projection (M-projection) | information projection (I-projection) |
Other Name | mode-seeking | mean-seeking |
Other Name | exclusive | inclusive |
Property | zero-avoiding | zero-forcing |
Support Estimation | overestimate | underestimate |
然而,计算$p^{\star}(x) = p(x \mid D)$依然是困难的,因为需要计算marginal distribution $Z = p(D)$。我们用joint distribution $\tilde{p}(x) = p(x, D) = p^{\star}(x) Z$替代$p^{\star}(x)$,重新定义目标函数
\[L(q) = -\mathbb{KL}(q \| \tilde{p}) = -\mathbb{KL}(q \| p^{\star}) + \log{Z} \le \log{Z} = \log{p(D)}\]这里$L(q)$的数学含义是evidence lower bound (ELBO),当$q$完全等价于$p^{\star}$时它取到log marginal probability,所以优化它即可使$q$逼近$p^{\star}$。
The mean field method
Mean field将用来近似的posterior $q$限定在fully factorized的形式:
\[q(\bm{x}) = \prod{q_i(x_i)}\]则可以推导出一个coordinated descent算法,每步中更新一个子分布
\[\log{q_j(x_j)} = \mathbb{E}_{-q_j}{[\log{\tilde{p}(\bm{x})}]} + C\]其实在更新$q_j$时,我们只需考虑在$j$的Markov blanket中的variables。由于在每步中我们把每个node的distribution更新为它邻居们的“平均值”,因而这个算法得名mean field method。书中举例将此算法用在Ising model上用于image denoising。
当然,fully factorized是个很强的假设。Structured mean field可以用来处理variable之间有一定dependency的情况,如factorized HMM。
Variational Bayes (VB)
当我们用mean field来对$\theta$做inference时,便得到了VB。书中的例子将VB应用在univariate Gaussian和linear regression的posterior inference。
Variational Bayes EM (VBEM)
当我们用mean field来同时对$\theta$和$z$做inference时,便得到了VBEM。对于经典的EM,我们在E-step中计算latent variable的posterior $p(z_i \mid x_i, \theta)$,在M-step中计算parameter $\theta$的MAP estimation。对于latent variable用Bayesian inference,而对于parameter用point estimation的理由是,$z_i$只能从每个$x_i$估计,而$\theta$可以从所有数据估计,相对而言方差较小。VBEM更进一步,对parameter也用posterior inference,相应的E-step和M-step变为variational E-step和variational M-step。在variational E-step中,我们需要代入parameter的posterior mean;在variational M-step中,我们需要计算parameter的posterior而不是MAP estimation。
VBEM的优点是,我们可以得到marginal likelihood的下界,用来做模型选择。VBEM还将latent variables和parameters一视同仁地处理,避免了EM将两者特意区分的做法。
Variational Auto-Encoder (VAE)
TODO
VI vs. MCMC
一般来讲,VI适合大数据集,MCMC适合小数据集。MCMC对计算的消耗更大,但可以保证产生和target distribution相同的sample。对于multimodal的distribution,MCMC可能过于注重某一个峰值而不能很好描述其它峰值。