Graduate Descent

The optimal proposal distribution is not p

The following is a quick rant about importance sampling (see the importance sampling post below for notation).

I've heard the following incorrect statement one too many times,

We chose \(q \approx p\) because \(q=p\) is the "optimal" proposal distribution.

While it is certainly a good idea to pick \(q\) to be as similar as possible to \(p\), it is by no means optimal because it is oblivious to \(f\)!

With importance sampling, it is possible to achieve a variance reduction over naive Monte Carlo estimation (i.e., sampling directly from \(p\)). The optimal proposal distribution, assuming \(f(x) \ge 0\) for all \(x\), is \(q(x) \propto p(x) f(x).\) This choice of \(q\) gives us a zero-variance estimate with a single sample!
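Here's a tiny numerical check of this claim (a minimal sketch; the discrete \(p\) and \(f\) below are made up for illustration). With \(q \propto p \cdot f\), every single-sample estimate \(\frac{p(x)}{q(x)} f(x)\) already equals \(\mu\), so the estimator has zero variance; with \(q = p\) it does not.

```python
import numpy as np

rng = np.random.default_rng(0)

# Made-up discrete target p and nonnegative integrand f.
p = np.array([0.1, 0.2, 0.3, 0.4])
f = np.array([5.0, 1.0, 2.0, 0.5])
mu = (p * f).sum()                       # true value E_p[f(x)] = 1.5

def single_sample_estimates(q, n=10_000):
    """n independent single-sample importance sampling estimates of E_p[f] under proposal q."""
    x = rng.choice(len(p), size=n, p=q)
    return p[x] / q[x] * f[x]

est_p = single_sample_estimates(p)       # the "obvious" proposal q = p
q_star = p * f / mu                      # optimal proposal q*(x) ∝ p(x) f(x)
est_star = single_sample_estimates(q_star)

print("true mu:", mu)
print("q = p  : mean %.4f, variance %.4f" % (est_p.mean(), est_p.var()))
print("q ∝ pf : mean %.4f, variance %.4f" % (est_star.mean(), est_star.var()))
# Under q*, every single estimate equals mu exactly, hence zero variance.
```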

Of course, this is an unreasonable distribution to use because the normalizing constant is the thing you are trying to estimate, but it is proof that better proposal distributions exist.

The key to doing better than \(q=p\) is to take \(f\) into account. Look up "importance sampling for variance reduction" to learn more.

Importance Sampling

Importance sampling is a powerful and pervasive technique in statistics, machine learning and randomized algorithms.

Basics

Importance sampling is a technique for estimating the expectation \(\mu\) of a random variable \(f(x)\) under distribution \(p\) from samples of a different distribution \(q.\)

The key observation is that \(\mu\) can be expressed as the expectation of a different random variable \(f^*(x)=\frac{p(x)}{q(x)}\! \cdot\! f(x)\) under \(q.\)

$$ \mathbb{E}_{q}\! \left[ f^*(x) \right] = \mathbb{E}_{q}\! \left[ \frac{p(x)}{q(x)} f(x) \right] = \sum_{x} q(x) \frac{p(x)}{q(x)} f(x) = \sum_{x} p(x) f(x) = \mathbb{E}_{p}\! \left[ f(x) \right] = \mu $$


Technical condition: \(q\) must have support everywhere \(p\) does, \(p(x) > 0 \Rightarrow q(x) > 0.\) Without this condition, the estimator is biased! Note: \(q\) can support things that \(p\) doesn't.

Terminology: The quantity \(w(x) = \frac{p(x)}{q(x)}\) is often referred to as the "importance weight" or "importance correction". We often refer to \(p\) as the target density and \(q\) the proposal density.

Now, given samples \(\{ x^{(i)} \}_{i=1}^{n}\) from \(q,\) we can use the Monte Carlo estimate, \(\hat{\mu} = \frac{1}{n} \sum_{i=1}^n f^{*}(x^{(i)}),\) as an unbiased estimator of \(\mu.\)
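As a concrete sketch (the Gaussians and the choice of \(f\) below are arbitrary, purely for illustration), the estimator is just an average of importance-weighted function values:

```python
import numpy as np

rng = np.random.default_rng(0)

def normal_pdf(x, mean, std):
    return np.exp(-0.5 * ((x - mean) / std) ** 2) / (std * np.sqrt(2 * np.pi))

# Arbitrary example: target p = N(0, 1), proposal q = N(1, 2), f(x) = x^2, so mu = E_p[x^2] = 1.
p = lambda x: normal_pdf(x, 0.0, 1.0)
q = lambda x: normal_pdf(x, 1.0, 2.0)
f = lambda x: x ** 2

n = 100_000
x = rng.normal(1.0, 2.0, size=n)    # samples x^(i) drawn from the proposal q
w = p(x) / q(x)                     # importance weights w(x) = p(x) / q(x)
mu_hat = np.mean(w * f(x))          # unbiased Monte Carlo estimate of E_p[f(x)]
print(mu_hat)                       # close to 1.0
```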

Remarks

There are a few reasons we might want to use importance sampling:

  1. Convenience: It might be trickier to sample directly from \(p\) than from \(q.\)

  2. Bias-correction: Suppose we're developing an algorithm which requires samples to satisfy some "safety" condition (e.g., a minimum support threshold) and be unbiased. Importance sampling can be used to remove bias while satisfying the condition.

  3. Variance reduction: It might be the case that sampling directly from \(p\) would require more samples to estimate \(\mu.\) Check out these great notes for more.

  4. Off-policy evaluation and learning: We might want to collect some "exploratory data" from \(q\) and evaluate different "policies" \(p\) (e.g., to pick the best one); a small sketch follows this list. Here's a link to a future post on off-policy evaluation and counterfactual reasoning and some cool papers: counterfactual reasoning, reinforcement learning, contextual bandits, domain adaptation.
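Here's a minimal off-policy evaluation sketch in the contextual-bandit spirit mentioned in item 4 (the policies, rewards, and noise model are all made up for illustration): we log actions and rewards under a behavior policy \(q\), then estimate the expected reward of a target policy \(p\) by reweighting.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical setup: 3 actions, a behavior policy q that logged the data,
# and a target policy p whose expected reward we want without deploying it.
q_probs = np.array([0.5, 0.3, 0.2])        # behavior (logging) policy
p_probs = np.array([0.1, 0.1, 0.8])        # target policy to evaluate
mean_reward = np.array([1.0, 0.5, 2.0])    # per-action expected reward (unknown in practice)

n = 100_000
a = rng.choice(3, size=n, p=q_probs)                 # logged actions drawn from q
r = mean_reward[a] + rng.normal(0.0, 0.1, size=n)    # noisy logged rewards

w = p_probs[a] / q_probs[a]      # importance weights p(a) / q(a)
estimate = np.mean(w * r)        # estimated value of the target policy p

print(estimate, (p_probs * mean_reward).sum())   # estimate vs. the quantity it approximates
```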

There are a few common cases for \(q\) worth separate consideration:

  1. Control over \(q\): This is the case in experimental design, variance reduction, active learning and reinforcement learning. It's often difficult to design a \(q\) that results in an estimator with "reasonable" variance. A very difficult case is in off-policy evaluation because it (essentially) requires a good exploratory distribution for every possible policy. (I have much more to say on this topic.)

  2. Little to no control over \(q\): For example, you're given some dataset (e.g., news articles) and you want to estimate performance on a different dataset (e.g., Twitter).

  3. Unknown \(q\): In this case, we want to estimate \(q\) (typically referred to as the propensity score) and use it in the importance sampling estimator. This technique, as far as I can tell, is widely used to remove selection bias when estimating effects of different treatments.

Drawbacks: The main drawback of importance sampling is variance. A few bad samples with large weights can drastically throw off the estimator. Thus, it's often the case that a biased estimator is preferred, e.g., estimating the partition function, clipping weights, indirect importance sampling. A secondary drawback is that both densities must be normalized, which is often intractable.
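To illustrate the variance problem and one of the biased fixes, here's a sketch (the Gaussians and the clip threshold are arbitrary choices) comparing the plain estimator against one with clipped weights:

```python
import numpy as np

rng = np.random.default_rng(0)

def normal_pdf(x, mean, std):
    return np.exp(-0.5 * ((x - mean) / std) ** 2) / (std * np.sqrt(2 * np.pi))

# A mismatched pair: the proposal is narrower than the target, so the tails
# occasionally produce very large weights.
p = lambda x: normal_pdf(x, 0.0, 1.2)   # target; true mu = E_p[x^2] = 1.44
q = lambda x: normal_pdf(x, 0.0, 1.0)   # proposal (too narrow)
f = lambda x: x ** 2

def estimate(clip=None, n=1_000):
    x = rng.normal(0.0, 1.0, size=n)
    w = p(x) / q(x)
    if clip is not None:
        w = np.minimum(w, clip)         # clipped weights: biased, but tamer
    return np.mean(w * f(x))

plain = [estimate() for _ in range(500)]
clipped = [estimate(clip=2.0) for _ in range(500)]
print("plain:   mean %.3f, std %.3f" % (np.mean(plain), np.std(plain)))
print("clipped: mean %.3f, std %.3f" % (np.mean(clipped), np.std(clipped)))
# The clipped estimator's mean falls below 1.44 (bias), but it fluctuates less.
```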

What's next? I plan to cover "variance reduction" and off-policy evaluation in more detail in future posts.

KL-divergence as an objective function

It's well-known that KL-divergence is not symmetric, but which direction is right for fitting your model?

First, which one is which?

Cheat sheet: If we're fitting \(q\) to \(p\) using

\(\textbf{KL}(q || p)\) - mode-seeking, exclusive - no normalization wrt \(p\)

\(\textbf{KL}(p || q)\) - mean-seeking, inclusive - results in a moment-matching problem when \(q\) is in the exponential family - requires normalization wrt \(p\)

How I remember which is which: The KL notation is a little strange with the "\(||\)". I pretend it's a division symbol. This happens to correspond nicely to a division symbol in the corresponding equation (I'm not sure it's intentional). It helps to remember that you always normalize wrt the "numerator" (and \(q\)).
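To see the mode-seeking vs. mean-seeking behavior from the cheat sheet in action, here's a small sketch (the bimodal target and the grid discretization are made up for illustration) that fits a single Gaussian \(q\) to a two-mode \(p\) in each direction:

```python
import numpy as np
from scipy.optimize import minimize

# Discretized example: a bimodal target p on a grid, approximated by one Gaussian q.
xs = np.linspace(-10, 10, 2001)

def normal(mean, std):
    d = np.exp(-0.5 * ((xs - mean) / std) ** 2)
    return d / d.sum()                     # normalize over the grid

p = 0.5 * normal(-3.0, 1.0) + 0.5 * normal(3.0, 1.0)

def q(params):
    mean, log_std = params
    return normal(mean, np.exp(log_std))

def kl(a, b):
    m = a > 0
    return np.sum(a[m] * (np.log(a[m]) - np.log(np.maximum(b[m], 1e-300))))

# Exclusive / mode-seeking: min_q KL(q || p) latches onto one of the two modes.
excl = minimize(lambda t: kl(q(t), p), x0=[0.5, 0.0], method="Nelder-Mead").x
# Inclusive / mean-seeking: min_q KL(p || q) spreads q out to cover both modes.
incl = minimize(lambda t: kl(p, q(t)), x0=[0.5, 0.0], method="Nelder-Mead").x

print("KL(q||p): mean %+.2f, std %.2f" % (excl[0], np.exp(excl[1])))   # hugs one mode, std near 1
print("KL(p||q): mean %+.2f, std %.2f" % (incl[0], np.exp(incl[1])))   # mean near 0, std near 3.2
```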

Computational perspective

Let's look at what's involved in fitting a model \(q_\theta\) in each direction. In this section, I'll describe the gradient and pay special attention to the issue of normalization.

Notation: \(p,q_\theta\) are probability distributions. \(p = \bar{p} / Z_p\), where \(Z_p\) is the normalization constant. Similarly for \(q\).

The easy direction \(\textbf{KL}(q_\theta || p)\)

\begin{align*} \textbf{KL}(q_\theta || p) &= \sum_d q(d) \log \left( \frac{q(d)}{p(d)} \right) \\ &= \sum_d q(d) \left( \log q(d) - \log p(d) \right) \\ &= \underbrace{\sum_d q(d) \log q(d)}_{-\text{entropy}} - \underbrace{\sum_d q(d) \log p(d)}_{\text{cross-entropy}} \\ \end{align*}

Let's look at the normalization of \(p\). The entropy term is easy because there is no \(p\) in it, so consider the cross-entropy term:

\begin{align*} \sum_d q(d) \log p(d) &= \sum_d q(d) \log (\bar{p}(d) / Z_p) \\ &= \sum_d q(d) \left( \log \bar{p}(d) - \log Z_p \right) \\ &= \sum_d q(d) \log \bar{p}(d) - \sum_d q(d) \log Z_p \\ &= \sum_d q(d) \log \bar{p}(d) - \log Z_p \end{align*}

In this case, \(-\log Z_p\) is an additive constant, which can be dropped because we're optimizing.

This leaves us with the following optimization problem:

\begin{align*} & \text{argmin}_\theta \textbf{KL}(q_\theta || p) \\ &\qquad = \text{argmin}_\theta \sum_d q_\theta(d) \log q_\theta(d) - \sum_d q_\theta(d) \log \bar{p}(d) \end{align*}

Let's work out the gradient:

\begin{align*} & \nabla\left[ \sum_d q_\theta(d) \log q_\theta(d) - \sum_d q_\theta(d) \log \bar{p}(d) \right] \\ &\qquad = \sum_d \nabla \left[ q_\theta(d) \log q_\theta(d) \right] - \sum_d \nabla\left[ q_\theta(d) \right] \log \bar{p}(d) \\ &\qquad = \sum_d \nabla \left[ q_\theta(d) \right] \left( 1 + \log q_\theta(d) \right) - \sum_d \nabla\left[ q_\theta(d) \right] \log \bar{p}(d) \\ &\qquad = \sum_d \nabla \left[ q_\theta(d) \right] \left( 1 + \log q_\theta(d) - \log \bar{p}(d) \right) \\ &\qquad = \sum_d \nabla \left[ q_\theta(d) \right] \left( \log q_\theta(d) - \log \bar{p}(d) \right) \\ \end{align*}

We killed the \(1\) in the last equality because \(\sum_d \nabla \left[ q(d) \right] = \nabla \left[ \sum_d q(d) \right] = \nabla \left[ 1 \right] = 0\) for any \(q\) which is a probability distribution.

This direction is convenient because we don't need to normalize \(p\). Unfortunately, the "easy" direction is nonconvex in general --- unlike the "hard" direction, which is convex.
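Here's a quick sanity-check sketch of this gradient for a categorical \(q_\theta = \mathrm{softmax}(\theta)\) and a made-up unnormalized \(\bar{p}\) (everything below is invented for illustration). The analytic gradient never touches \(Z_p\), yet it matches finite differences of the true \(\textbf{KL}(q_\theta || p)\):

```python
import numpy as np

rng = np.random.default_rng(0)

# Made-up unnormalized target pbar over 5 outcomes; the model is q_theta = softmax(theta).
pbar = np.array([1.0, 3.0, 0.5, 2.0, 4.0])
theta = rng.normal(size=5)

def q(theta):
    e = np.exp(theta - theta.max())
    return e / e.sum()

def kl_qp(theta):
    """KL(q_theta || p) using the *normalized* p -- only needed for the check."""
    p = pbar / pbar.sum()
    qt = q(theta)
    return np.sum(qt * (np.log(qt) - np.log(p)))

def grad(theta):
    """sum_d grad[q_theta(d)] (log q_theta(d) - log pbar(d)); note: no Z_p anywhere."""
    qt = q(theta)
    s = np.log(qt) - np.log(pbar)
    # For softmax, d q(d) / d theta_j = q(d) (1[d = j] - q(j)).
    return qt * (s - np.dot(qt, s))

# Central finite differences of the true KL agree with the Z_p-free analytic gradient.
eps = 1e-6
fd = np.array([(kl_qp(theta + eps * np.eye(5)[j]) - kl_qp(theta - eps * np.eye(5)[j])) / (2 * eps)
               for j in range(5)])
print(np.allclose(grad(theta), fd, atol=1e-5))   # True
```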

The harder direction \(\textbf{KL}(p || q_\theta)\)

\begin{align*} \textbf{KL}(p || q_\theta) &= \sum_d p(d) \log \left( \frac{p(d)}{q(d)} \right) \\ &= \sum_d p(d) \left( \log p(d) - \log q(d) \right) \\ &= \sum_d p(d) \log p(d) - \sum_d p(d) \log q(d) \\ \end{align*}

Clearly the first term (entropy) won't matter if we're just trying to optimize wrt \(\theta\). So, let's focus on the second term (cross-entropy).

\begin{align*} \sum_d p(d) \log q(d) &= \frac{1}{Z_p} \sum_d \bar{p}(d) \log \left( \bar{q}(d)/Z_q \right) \\ &= \frac{1}{Z_p} \sum_d \bar{p}(d) \left( \log \bar{q}(d) - \log Z_q \right) \\ &= \frac{1}{Z_p} \sum_d \bar{p}(d) \log \bar{q}(d) - \frac{1}{Z_p} \sum_d \bar{p}(d) \log Z_q \\ &= \frac{1}{Z_p} \sum_d \bar{p}(d) \log \bar{q}(d) - \log Z_q \frac{1}{Z_p} \sum_d \bar{p}(d) \\ &= \frac{1}{Z_p} \sum_d \bar{p}(d) \log \bar{q}(d) - \log Z_q \end{align*}

Unfortunately, the \(Z_p\) factor is unavoidable. The usual "approximate inference story" is that \(Z_p\) is hard to compute, while the approximating distribution's normalization constant \(Z_q\) is easy.

Nonetheless, optimizing KL in this direction is still useful. Examples include expectation propagation, variational decoding, and maximum likelihood estimation. In the case of maximum likelihood estimation, \(p\) is the empirical distribution, so technically you don't have to compute its normalizing constant, but you do need to sample from it (which can be just as hard).

The gradient, when \(q\) is in the exponential family, is intuitive:

\begin{align*} \nabla \left[ \frac{1}{Z_p} \sum_d \bar{p}(d) \log \bar{q}(d) - \log Z_q \right] &= \frac{1}{Z_p} \sum_d \bar{p}(d) \nabla \left[ \log \bar{q}(d) \right] - \nabla \log Z_q \\ &= \frac{1}{Z_p} \sum_d \bar{p}(d) \phi_q(d) - \mathbb{E}_q \left[ \phi_q \right] \\ &= \mathbb{E}_p \left[ \phi_q \right] - \mathbb{E}_q \left[ \phi_q \right] \end{align*}

The optimization problem is convex when \(q_\theta\) is in the exponential family; \(p\) can be arbitrary. You can think of maximum likelihood estimation as a method which minimizes this KL divergence using samples of \(p\). In this case, \(p\) is the true data distribution! The first term in the gradient is then based on a sample average instead of an exact expectation (often called "observed feature counts").

Downside: computing \(\mathbb{E}_p \left[ \phi_q \right]\) might not be tractable.
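To make the moment-matching picture concrete, here's a sketch (the feature map, "data", step size, and iteration count are all invented) of fitting an exponential-family \(q_\theta(d) \propto \exp(\theta^\top \phi_q(d))\) by following \(\mathbb{E}_{p}[\phi_q] - \mathbb{E}_{q}[\phi_q]\), with \(\mathbb{E}_p[\phi_q]\) replaced by a sample average, i.e., maximum likelihood estimation:

```python
import numpy as np

rng = np.random.default_rng(0)

# Discrete domain of 6 outcomes with a made-up feature map phi: outcome -> R^2.
phi = np.array([[1., 0.], [0., 1.], [1., 1.], [2., 0.], [0., 2.], [1., 2.]])

def q(theta):
    """Exponential-family model q_theta(d) proportional to exp(theta . phi(d))."""
    logits = phi @ theta
    e = np.exp(logits - logits.max())
    return e / e.sum()

# "Data": samples standing in for p (drawn here from an arbitrary true distribution).
p_true = np.array([0.05, 0.1, 0.3, 0.15, 0.2, 0.2])
data = rng.choice(len(p_true), size=5_000, p=p_true)
observed_counts = phi[data].mean(axis=0)          # sample estimate of E_p[phi]

# Gradient descent on KL(p || q_theta) == gradient ascent on the log-likelihood.
theta = np.zeros(2)
for _ in range(2_000):
    expected_counts = q(theta) @ phi              # E_q[phi]
    theta += 0.1 * (observed_counts - expected_counts)   # step along E_p[phi] - E_q[phi]

print("E_p[phi] (sample) :", observed_counts)
print("E_q[phi] (fitted) :", q(theta) @ phi)      # approximately matches the observed moments
```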

Remarks

  • Both directions of KL are special cases of \(\alpha\)-divergence, which provides a unified account of the two directions.

  • Inclusive divergences require \(q > 0\) whenever \(p > 0\). No "false negatives".

  • Exclusive divergences will often favor a single mode.

  • Computing the value of KL (not just the gradient) in either direction requires normalization. However, in the "easy" direction, using unnormalized \(p\) results in only an additive constant difference. So, it's still just as useful, if all you care about is optimization (fitting the model).