Graduate Descent

Gumbel-Max Trick



Goal: Sampling from a discrete distribution parametrized by unnormalized log-probabilities:

$$ \pi_k = \frac{1}{z} \exp(x_k) \ \ \ \text{where } z = \sum_{j=1}^K \exp(x_j) $$

The usual way: Exponentiate and normalize (using the exp-normalize trick), then use an algorithm for sampling from a discrete distribution (aka categorical):

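from numpy import exp
from numpy.random import uniform

def usual(x):
    # x: 1-D array of unnormalized log-probabilities
    cdf = exp(x - x.max()).cumsum()     # the exp-normalize trick (unnormalized CDF)
    z = cdf[-1]                         # normalization constant
    u = uniform(0, 1)
    return cdf.searchsorted(u * z)      # first index k with cdf[k] >= u * z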
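
To see why the shift by the maximum matters, here is a small standard-library sketch. The helper name `exp_normalize` and the example values are illustrative assumptions, not part of the code above. Exponentiating large log-probabilities directly overflows a double, while subtracting the maximum first keeps every term in $(0, 1]$ and leaves the normalized result unchanged.

```python
import math

def exp_normalize(x):
    """Turn unnormalized log-probabilities into probabilities without overflow."""
    m = max(x)
    e = [math.exp(v - m) for v in x]   # the largest term is exp(0) = 1
    total = sum(e)
    return [v / total for v in e]

x = [1000.0, 1001.0, 1002.0]

try:
    naive = [math.exp(v) for v in x]   # math.exp(1000.0) is too large for a double
except OverflowError:
    naive = None

probs = exp_normalize(x)               # same result as for x = [0.0, 1.0, 2.0]
```

Note that `usual` never divides by the normalizer; it scales the uniform draw by `z` instead, which gives the same sample.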

The Gumbel-max trick:

$$ y = \underset{ i \in \{1,\cdots,K\} }{\operatorname{argmax}} \ (x_i + z_i) $$

where $z_1, \ldots, z_K$ are i.i.d. $\text{Gumbel}(0,1)$ random variates (unrelated to the normalizer $z$ above). It turns out that $y$ is distributed according to $\pi$. (See the short derivations in this blog post.)
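
For reference, here is a sketch of one standard argument (not quoted from the linked post). It uses the $\text{Gumbel}(0,1)$ density $f(g) = \exp(-g - e^{-g})$ and CDF $F(g) = \exp(-e^{-g})$; the integration variable $g$ is notation introduced here. Condition on the noise at index $k$ taking the value $g$ and require every other index to lose, i.e. $x_j + z_j < x_k + g$ for all $j \ne k$:

$$ P(y = k) = \int_{-\infty}^{\infty} f(g) \prod_{j \ne k} F(g + x_k - x_j) \, dg = \int_{-\infty}^{\infty} e^{-g} \exp\Big( -e^{-g} \sum_{j=1}^K e^{x_j - x_k} \Big) \, dg $$

Substituting $t = e^{-g}$ turns this into an exponential integral:

$$ P(y = k) = \int_0^{\infty} \exp\Big( -t \sum_{j=1}^K e^{x_j - x_k} \Big) \, dt = \frac{\exp(x_k)}{\sum_{j=1}^K \exp(x_j)} = \pi_k $$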

Implementing the Gumbel-max trick is remarkably easy:

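from numpy.random import gumbel

def gumbel_max_sample(x):
    z = gumbel(loc=0, scale=1, size=x.shape)   # i.i.d. Gumbel(0,1) noise, one per entry of x
    return (x + z).argmax(axis=-1)             # last axis, so 1-D and batched x both work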
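
A quick way to convince yourself is an empirical check. The sketch below is self-contained and makes a few assumptions of its own: the helper name `gumbel_max_batch`, the seed, the example log-probabilities, and the sample count are all illustrative. It draws many samples at once and compares the observed frequencies with $\pi$.

```python
import numpy as np

def gumbel_max_batch(x, n, rng):
    """Draw n independent Gumbel-max samples for the 1-D log-probability vector x."""
    z = rng.gumbel(loc=0.0, scale=1.0, size=(n, x.shape[0]))  # fresh noise per draw
    return (x + z).argmax(axis=1)                             # one index per row

rng = np.random.default_rng(0)
x = np.array([1.0, 2.0, 3.0])

samples = gumbel_max_batch(x, 100_000, rng)
freq = np.bincount(samples, minlength=x.shape[0]) / samples.size

pi = np.exp(x - x.max())   # exp-normalize trick
pi /= pi.sum()

print(freq)   # empirical frequencies
print(pi)     # target distribution
```

With this many draws, the two printed vectors should agree to roughly two decimal places.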

If you don't have access to a Gumbel random variate generator, you can use $-\log(-\log(\text{Uniform}(0,1)))$.
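
As a minimal sketch of that transform using only the standard library (the function name `gumbel_from_uniform` and the seed are assumptions made for illustration): the one practical detail is that `random.random()` can return exactly `0.0`, which would make the inner logarithm fail, so the sketch redraws in that case. The sample mean serves as a sanity check, since a $\text{Gumbel}(0,1)$ variate has mean equal to the Euler-Mascheroni constant, about $0.5772$.

```python
import math
import random

def gumbel_from_uniform(rng=random):
    """Return one Gumbel(0,1) variate via -log(-log(u)) with u ~ Uniform(0,1)."""
    u = rng.random()          # uniform on [0, 1)
    while u == 0.0:           # redraw to avoid log(0)
        u = rng.random()
    return -math.log(-math.log(u))

rng = random.Random(0)
n = 200_000
mean = sum(gumbel_from_uniform(rng) for _ in range(n)) / n
print(mean)   # should land close to 0.5772
```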

Comparison:

  1. Number of calls to the random number generator: Gumbel-max requires $K$ samples from a uniform, whereas the usual algorithm only requires $1$.

  2. Gumbel-max is a one-pass algorithm: It does not need to see all of the data (e.g., to normalize) before it can start partially sampling. Thus, Gumbel-max can be used for weighted sampling from a stream.

  3. Low-level efficiency: The Gumbel-max trick requires $2K$ calls to $\log$, whereas the usual algorithm requires $K$ calls to $\exp$. Since $\exp$ and $\log$ are expensive functions, we'd like to avoid calling them. What gives? Well, Gumbel-max's calls to $\log$ do not depend on the data, so they can be precomputed; this is handy for implementations which rely on vectorization for efficiency, e.g. python+numpy.
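
The one-pass property in point 2 can be sketched as follows. This is an illustration under stated assumptions, not a reference implementation: the function name `gumbel_max_stream`, the example weights, and the seed are all made up here. The sampler keeps only the best perturbed score seen so far, so it never needs the normalizer or a second look at the data, and it draws one uniform per item, as in point 1.

```python
import math
import random

def gumbel_max_stream(log_weights, rng=random):
    """Sample an index from a stream of unnormalized log-probabilities in one pass."""
    best_index, best_score = None, -math.inf
    for i, x in enumerate(log_weights):
        u = rng.random()
        while u == 0.0:                          # redraw to avoid log(0)
            u = rng.random()
        score = x - math.log(-math.log(u))       # x_i + Gumbel(0,1) noise
        if score > best_score:                   # running argmax
            best_index, best_score = i, score
    return best_index

# Usage: weights 1, 2, 7 correspond to probabilities 0.1, 0.2, 0.7.
rng = random.Random(0)
weights = [1.0, 2.0, 7.0]
n = 50_000
counts = [0, 0, 0]
for _ in range(n):
    stream = (math.log(w) for w in weights)      # a generator: consumed exactly once
    counts[gumbel_max_stream(stream, rng)] += 1
freq = [c / n for c in counts]
print(freq)
```

Because the running maximum at any point is itself a valid sample from the items seen so far, the stream can be stopped early and still yield a correctly weighted draw over that prefix.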

Further reading: I have a few posts relating to the Gumbel-max trick. Have a look at posts tagged with Gumbel.