Variational Inference for Monte Carlo Objectives
WHY?
Recent variational training methods require sampling from the variational posterior to estimate gradients. The NVIL estimator provides a way to estimate the gradient of the loss function with respect to the parameters. Since the score-function estimator is known to have high variance, a baseline is used as a variance-reduction technique. However, this technique is insufficient to reduce variance in the multi-sample setting, as in IWAE.
WHAT?
We want to fit a latent-variable model P(x,h) with an intractable marginal likelihood P(x) to data. The simplest estimator \hat{I}(h^{1:K}) of P(x) samples h from the prior and averages the likelihood.
\hat{I}(h^{1:K}) = \frac{1}{K}\sum^K_{i=1}P(x|h^i), h^i \sim P(h)
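A minimal sketch of this naive estimator on a toy conjugate Gaussian model (prior h \sim N(0,1), likelihood x|h \sim N(h, 0.5^2)); the model is an assumption for illustration, not from the paper.

```python
import numpy as np

rng = np.random.default_rng(0)

def naive_estimate(x, K):
    """I_hat(h^{1:K}) = (1/K) * sum_i P(x | h^i), with h^i drawn from the prior P(h)."""
    h = rng.normal(0.0, 1.0, size=K)                                        # h^i ~ P(h) = N(0, 1)
    lik = np.exp(-0.5 * ((x - h) / 0.5) ** 2) / (0.5 * np.sqrt(2 * np.pi))  # P(x | h^i)
    return lik.mean()

print(naive_estimate(x=2.0, K=1000))
```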
However, this estimate has large variance, since typically only a small region of P(h) explains the data well. Instead, we introduce a proposal distribution Q(h^i|x) conditional on the observation and perform importance sampling.
\hat{I}(h^{1:K}) = \frac{1}{K}\sum^K_{i=1}\frac{P(x, h^i)}{Q(h^i|x)}, h^{1:K} \sim Q(h^{1:K}|x) \equiv \prod^K_{i=1}Q(h^i|x)
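Continuing the same toy model, here is a sketch with a hand-picked Gaussian proposal Q(h|x); in practice Q would be an inference network, and the proposal mean and standard deviation below are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)

def log_normal_pdf(z, mean, std):
    return -0.5 * ((z - mean) / std) ** 2 - np.log(std) - 0.5 * np.log(2 * np.pi)

def importance_estimate(x, K, q_mean, q_std):
    h = rng.normal(q_mean, q_std, size=K)                                # h^i ~ Q(h | x)
    log_joint = log_normal_pdf(h, 0.0, 1.0) + log_normal_pdf(x, h, 0.5)  # log P(x, h^i)
    log_q = log_normal_pdf(h, q_mean, q_std)                             # log Q(h^i | x)
    return np.mean(np.exp(log_joint - log_q))                            # (1/K) sum_i P(x,h^i)/Q(h^i|x)

# For this conjugate toy model the posterior is roughly N(0.8 x, 0.45^2),
# so the proposal is placed near it.
print(importance_estimate(x=2.0, K=1000, q_mean=1.6, q_std=0.45))
```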
Since this estimator is unbiased, taking its logarithm gives a stochastic lower bound on the log marginal likelihood by Jensen's inequality. E_{Q(h^{1:K}|x)}[\log \hat{I}(h^{1:K})] \leq \log E_{Q(h^{1:K}|x)}[\hat{I}(h^{1:K})] = \log P(x)\\ \hat{L}(h^{1:K}) \equiv \log \hat{I}(h^{1:K})
The single-sample and multi-sample objectives have the following forms, where f(x, h^i) \equiv \frac{P(x, h^i)}{Q(h^i|x)}. \mathcal{L}(x) = E_{Q(h|x)}[\log\frac{P(x, h)}{Q(h|x)}]\\ \mathcal{L}^K(x) = E_{Q(h^{1:K}|x)}[\log\frac{1}{K}\sum^K_{i=1}f(x, h^i)]
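A small sketch of how an estimate of \mathcal{L}^K would be computed from K log importance weights \log f(x, h^i), working in log space with logsumexp for numerical stability; the function name and example values are assumptions.

```python
import numpy as np
from scipy.special import logsumexp

def multi_sample_bound(log_f):
    """Estimate of L^K(x): log (1/K) sum_i f(x, h^i), given log f(x, h^i) for K samples."""
    K = log_f.shape[0]
    return logsumexp(log_f) - np.log(K)

# Example with K = 5 made-up log importance weights log f(x, h^i).
print(multi_sample_bound(np.array([-3.1, -2.4, -5.0, -2.9, -3.3])))
```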
The gradient of the multi-sample objective can be divided into two terms. \nabla_{\theta}\mathcal{L}^K(x) = E_{Q(h^{1:K}|x)}[\sum_j \hat{L}(h^{1:K})\nabla_{\theta}\log Q(h^j|x)] + E_{Q(h^{1:K}|x)}[\sum_j \tilde{w}^j \nabla_{\theta}\log f(x, h^j)]\\ \tilde{w}^j \equiv \frac{f(x, h^j)}{\sum_{i=1}^K f(x,h^i)}
The two terms describe the effect of \theta on the bound through the proposal distribution and through f, respectively. The second term is more stable than the first, since its weights \tilde{w}^j are normalized and sum to one. The first term is problematic for two reasons: the learning signal \hat{L}(h^{1:K}) is the same for every sample, so it does not assign credit to individual samples, and its magnitude is unbounded, so it can overwhelm the second term.
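A numeric illustration of the two kinds of coefficients: the single scalar \hat{L}(h^{1:K}) multiplying every score-function term versus the normalized weights \tilde{w}^j; the log f values are made up.

```python
import numpy as np
from scipy.special import logsumexp, softmax

log_f = np.array([-3.1, -2.4, -5.0, -2.9, -3.3])   # made-up log f(x, h^j) values, K = 5
K = log_f.shape[0]

L_hat = logsumexp(log_f) - np.log(K)   # learning signal: one scalar shared by all j
w_tilde = softmax(log_f)               # w~^j = f(x, h^j) / sum_i f(x, h^i), sums to 1

# The first gradient term multiplies every grad log Q(h^j | x) by the same
# unbounded scalar L_hat, so it cannot tell good samples from bad ones;
# the second term's coefficients w_tilde lie in [0, 1].
print(L_hat, w_tilde, w_tilde.sum())
```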
While estimating the gradient with respect to the model parameters \psi is relatively simple, estimating the gradient with respect to the proposal parameters \theta is the issue. The simplest choice is the naive Monte Carlo estimator; a more elaborate choice is NVIL, which subtracts an input-dependent baseline b(x) and a constant baseline b from the learning signal. However, even NVIL cannot address the fact that the learning signal is shared across samples. \nabla_{\psi}\mathcal{L}^K(x) \simeq \sum_j \tilde{w}^j\nabla_{\psi}\log f(x,h^j)\\ \nabla_{\theta}\mathcal{L}^K(x) \simeq \sum_j \hat{L}(h^{1:K})\nabla_{\theta}\log Q(h^j|x) + \sum_j \tilde{w}^j\nabla_{\theta}\log f(x,h^j)\\ \nabla_{\theta}\mathcal{L}^K(x) \simeq \sum_j (\hat{L}(h^{1:K}) - b(x) - b)\nabla_{\theta}\log Q(h^j|x) + \sum_j \tilde{w}^j\nabla_{\theta}\log f(x,h^j)
This paper suggests reducing variance by introducing local learning signals instead of a global learning signal. This is achieved by substituting, for each sample j, the term f(x, h^j) inside the bound with a quantity that does not depend on h^j: either another mapping f(x) trained to predict f(x, h^j), or a mean of the other samples. The paper found the geometric mean worked slightly better than the arithmetic mean. E_{Q(h^{1:K}|x)}[\hat{L}(h^{1:K})\nabla_{\theta}\log Q(h^j|x)] = E_{Q(h^{-j}|x)}[E_{Q(h^{j}|x)}[\hat{L}(h^{1:K})\nabla_{\theta}\log Q(h^j|x)\mid h^{-j}]]\\ \hat{L}(h^j|h^{-j}) = \hat{L}(h^{1:K}) - \log \frac{1}{K}(\sum_{i\neq j}f(x, h^i) + f(x))\\ \hat{L}(h^j|h^{-j}) = \hat{L}(h^{1:K}) - \log \frac{1}{K}(\sum_{i\neq j}f(x, h^i) + \hat{f}(x, h^{-j}))\\ \hat{f}(x,h^{-j}) = \frac{1}{K-1}\sum_{i\neq j}f(x, h^i) \text{ or } \exp(\frac{1}{K-1}\sum_{i\neq j}\log f(x, h^i))\\ \nabla_{\theta}\mathcal{L}^K(x) \simeq \sum_j \hat{L}(h^j|h^{-j})\nabla_{\theta}\log Q(h^j|x) + \sum_j \tilde{w}^j\nabla_{\theta}\log f(x,h^j)
The final estimator is called Variational Inference for Monte Carlo Objectives (VIMCO).
SO?
Compared to NVIL and RWS (Reweighted Wake-Sleep), VIMCO benefits from using more samples, yielding a tighter bound when training sigmoid belief networks (SBNs). VIMCO also showed better performance on structured output prediction on MNIST.
CRITIQUE
Seems like the ultimate version of gradient estimation!