Graduate Descent

Gumbel-max trick

Goal: Sampling from a discrete distribution parametrized by unnormalized log-probabilities:

$$ \pi_k = \frac{1}{z} \exp(x_k) \ \ \ \text{where } z = \sum_{j=1}^K \exp(x_j) $$

The usual way: Exponentiate and normalize (using the exp-normalize trick), then use an algorithm for sampling from a discrete (aka categorical) distribution:

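from numpy import exp
from numpy.random import uniform

def usual(x):
    cdf = exp(x - x.max()).cumsum()     # the exp-normalize trick; cdf is still unnormalized
    z = cdf[-1]                         # normalizing constant (after the max shift)
    u = uniform(0,1)
    return cdf.searchsorted(u * z)      # first index k with cdf[k] >= u * z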
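
To see why the max is subtracted before exponentiating, here is a small sketch comparing a direct softmax with the shifted version on large log-probabilities. The helper names `naive_probs` and `stable_probs` and the example values are mine, chosen only for illustration.

```python
import numpy as np

def naive_probs(x):
    # Direct exponentiation: exp overflows to inf for large x,
    # and inf / inf then yields nan.
    with np.errstate(over="ignore", invalid="ignore"):
        e = np.exp(x)
        return e / e.sum()

def stable_probs(x):
    # exp-normalize: shifting by the max leaves all ratios unchanged
    # but keeps every exponent <= 0, so exp cannot overflow.
    e = np.exp(x - x.max())
    return e / e.sum()

x = np.array([1000.0, 1001.0, 1002.0])
print(naive_probs(x))   # all nan, because exp(1000) overflows
print(stable_probs(x))  # a valid probability vector
```

The shift cancels in the ratio, so `stable_probs` computes the same \(\pi\) as the formula above whenever the naive version does not overflow.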

The Gumbel-max trick:

$$ y = \underset{ i \in \{1,\ldots,K\} }{\operatorname{argmax}} \, (x_i + z_i) $$

where \(z_1, \ldots, z_K\) are i.i.d. \(\text{Gumbel}(0,1)\) random variates (not to be confused with the normalizer \(z\) above). It turns out that \(y\) is distributed according to \(\pi\). (See the short derivations in this blog post.)
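
A short sketch of why this holds, using the standard facts that \(\text{Gumbel}(0,1)\) has CDF \(F(g) = \exp(-\exp(-g))\) and density \(f(g) = \exp(-g - \exp(-g))\). Here \(m\) is an integration variable I introduce for the value of the winning perturbed score, and \(z\) without a subscript is the normalizer from the goal above, not a Gumbel variate. Index \(k\) wins exactly when its perturbed score equals \(m\) and every other perturbed score falls below \(m\):

$$ \Pr(y = k) = \int_{-\infty}^{\infty} f(m - x_k) \prod_{j \ne k} F(m - x_j) \, dm = \int_{-\infty}^{\infty} \exp(x_k - m) \exp\Big( -\sum_{j=1}^K \exp(x_j - m) \Big) \, dm $$

Pulling out \(\exp(x_k)\) and substituting \(t = e^{-m}\) gives

$$ \Pr(y = k) = \exp(x_k) \int_{-\infty}^{\infty} e^{-m} \exp(-z \, e^{-m}) \, dm = \exp(x_k) \int_0^{\infty} \exp(-z \, t) \, dt = \frac{\exp(x_k)}{z} = \pi_k $$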

Implementing the Gumbel-max trick is remarkably easy:

from numpy.random import gumbel

def gumbel_max_sample(x):
    z = gumbel(loc=0, scale=1, size=x.shape)
    return (x + z).argmax(axis=-1)      # last axis: works for one vector or a batch of rows
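
As a quick sanity check of the claim that the samples follow \(\pi\), the sketch below draws many Gumbel-max samples and compares their empirical frequencies with the normalized probabilities. The log-probabilities, the seed, and the sample count are made-up values for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)          # seed chosen arbitrarily
x = np.array([0.5, -1.0, 2.0, 0.0])     # made-up unnormalized log-probabilities

# Target distribution pi, computed with the exp-normalize trick.
p = np.exp(x - x.max())
p /= p.sum()

# Draw n Gumbel-max samples at once: one row of noise per sample.
n = 200_000
noise = rng.gumbel(loc=0.0, scale=1.0, size=(n, x.size))
samples = (x + noise).argmax(axis=1)

# Empirical frequencies should be close to pi.
freq = np.bincount(samples, minlength=x.size) / n
print(np.round(p, 3))
print(np.round(freq, 3))
```

With this many samples the two printed vectors should agree to about two decimal places.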

If you don't have access to a Gumbel random variate generator, you can use \(-\log(-\log(\text{Uniform}(0,1)))\).
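
A minimal sketch of that transform in NumPy. The helper name `gumbel_from_uniform` is hypothetical, and the clipping step is my own precaution: `Generator.uniform` draws from \([0, 1)\), so an exact zero would otherwise produce an infinite value.

```python
import numpy as np

def gumbel_from_uniform(size, rng=None):
    # Hypothetical helper: inverse-CDF sampling of Gumbel(0,1) variates.
    rng = np.random.default_rng() if rng is None else rng
    u = rng.uniform(0.0, 1.0, size=size)
    # Keep u strictly inside (0, 1) so that both logs stay finite.
    # The upper clip is purely defensive, since u < 1 already.
    u = np.clip(u, np.finfo(float).tiny, 1.0 - np.finfo(float).eps)
    return -np.log(-np.log(u))

g = gumbel_from_uniform(100_000, rng=np.random.default_rng(0))
print(g.mean())  # should be close to 0.5772, the mean of Gumbel(0,1)
```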

Comparison with the usual algorithm:

  1. Number of calls to the random number generator: Gumbel-max requires \(K\) samples from a uniform, whereas the usual algorithm only requires \(1\).

  2. Gumbel-max is a one-pass algorithm: It does not need to see all of the data (e.g., to normalize) before it can start sampling. It only has to track the running argmax of the perturbed values, so Gumbel-max can be used for weighted sampling from a stream.

  3. Low-level efficiency: The Gumbel-max trick requires \(2K\) calls to \(\log\), whereas the usual algorithm requires \(K\) calls to \(\exp\). Since \(\exp\) and \(\log\) are expensive functions, we'd like to avoid calling them. What gives? Well, Gumbel-max's calls to \(\log\) do not depend on the data, so they can be precomputed; this is handy for implementations which rely on vectorization for efficiency, e.g., python+numpy.
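
The one-pass property from point 2 can be sketched as follows. This is an illustration under my own assumptions, not code from the post: `stream_sample` and `log_weight_stream` are hypothetical names, and the stream is a toy generator with made-up values. The sampler keeps only the best perturbed score seen so far, so it never needs the normalizer or the stream length.

```python
import numpy as np

def stream_sample(log_weights, rng=None):
    # Hypothetical helper: one-pass weighted sampling via Gumbel-max.
    # log_weights may be any iterable, including a generator of unknown length.
    rng = np.random.default_rng() if rng is None else rng
    best_index, best_key = None, -np.inf
    for i, x in enumerate(log_weights):
        key = x + rng.gumbel()          # perturb this item only
        if key > best_key:              # keep the running argmax
            best_index, best_key = i, key
    return best_index

def log_weight_stream():
    # Toy stream of unnormalized log-probabilities (made-up values).
    yield from [0.5, -1.0, 2.0, 0.0]

rng = np.random.default_rng(0)
draws = np.array([stream_sample(log_weight_stream(), rng) for _ in range(20_000)])
freq = np.bincount(draws, minlength=4) / draws.size

x = np.array(list(log_weight_stream()))
p = np.exp(x - x.max())
p /= p.sum()
print(np.round(p, 3))
print(np.round(freq, 3))
```

Memory use is constant in the stream length: only `best_index` and `best_key` are retained between items.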

Further reading: I have a few posts relating to the Gumbel-max trick. Have a look at posts tagged with Gumbel.