Goal: Sampling from a discrete distribution parametrized by unnormalized log-probabilities:
$$ \pi_k = \frac{1}{z} \exp(x_k) \ \ \ \text{where } z = \sum_{j=1}^K \exp(x_j) $$
The usual way: Exponentiate and normalize (using the exp-normalize trick), then use an algorithm for sampling from a discrete distribution (aka categorical):
import numpy as np

def usual(x):
    cdf = np.exp(x - x.max()).cumsum()  # the exp-normalize trick
    z = cdf[-1]                         # normalization constant
    u = np.random.uniform(0, 1)
    return cdf.searchsorted(u * z)      # inverse-CDF lookup
The Gumbel-max trick:
$$ y = \underset{ i \in \{1,\cdots,K\} }{\operatorname{argmax}} \; (x_i + z_i) $$
where $z_1 \cdots z_K$ are i.i.d. $\text{Gumbel}(0,1)$ random variates. It turns out that $y$ is distributed according to $\pi$. (See the short derivations in this blog post.)
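For completeness, here is one standard way to see this, sketched briefly (see the linked post for fuller derivations). The Gumbel$(0,1)$ density is $e^{-g} e^{-e^{-g}}$ and its CDF is $e^{-e^{-g}}$, so

$$ \Pr(y = k) = \int_{-\infty}^{\infty} e^{-g} e^{-e^{-g}} \prod_{j \ne k} e^{-e^{-(x_k - x_j + g)}} \, dg = \int_{-\infty}^{\infty} e^{-g} \exp\!\left( -e^{-g} \sum_{j=1}^K e^{x_j - x_k} \right) dg. $$

Substituting $u = e^{-g}$ gives

$$ \Pr(y = k) = \int_0^{\infty} \exp\!\left( -u \, z \, e^{-x_k} \right) du = \frac{e^{x_k}}{z} = \pi_k. $$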
Implementing the Gumbel-max trick is remarkably easy:
import numpy as np

def gumbel_max_sample(x):
    z = np.random.gumbel(loc=0, scale=1, size=x.shape)  # i.i.d. Gumbel(0, 1) noise
    return (x + z).argmax()  # index of the max is the sample
If you don't have access to a Gumbel random variate generator, you can use $-\log(-\log(\text{Uniform}(0,1)))$.
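Putting the pieces together, here is a self-contained sketch that builds the Gumbel noise from uniforms, with an empirical sanity check (the variable names and the explicit `rng` argument are mine, not from the post):

```python
import numpy as np

def gumbel_max_sample(x, rng):
    """Draw one categorical sample from unnormalized log-probabilities x."""
    u = rng.uniform(size=x.shape)   # K uniform variates
    z = -np.log(-np.log(u))         # transform to Gumbel(0, 1)
    return int(np.argmax(x + z))

# Sanity check: empirical frequencies should approach pi = softmax(x).
rng = np.random.default_rng(0)
x = np.array([1.0, 2.0, 0.5])
p = np.exp(x - x.max())
p /= p.sum()
counts = np.bincount([gumbel_max_sample(x, rng) for _ in range(50000)], minlength=3)
freq = counts / counts.sum()
```

With 50,000 draws the empirical frequencies land within a couple of percentage points of $\pi$.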
Comparison:
- Number of calls to the random number generator: Gumbel-max requires $K$ samples from a uniform, whereas the usual algorithm requires only $1$.
- Gumbel-max is a one-pass algorithm: it does not need to see all of the data (e.g., to normalize) before it can start partially sampling. Thus, Gumbel-max can be used for weighted sampling from a stream.
- Low-level efficiency: The Gumbel-max trick requires $2K$ calls to $\log$, whereas the usual algorithm requires $K$ calls to $\exp$. Since $\exp$ and $\log$ are expensive functions, we'd like to avoid calling them. What gives? Well, Gumbel's calls to $\log$ do not depend on the data, so they can be precomputed; this is handy for implementations that rely on vectorization for efficiency, e.g., Python+NumPy.
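To illustrate the streaming point: a one-pass weighted sampler only has to remember the running argmax of $x_i + z_i$. A minimal sketch under that idea (the function and variable names are my own, not from the post):

```python
import numpy as np

def sample_from_stream(stream, rng):
    """Single pass over (item, log_weight) pairs; returns an item sampled
    with probability proportional to exp(log_weight)."""
    best_key, best_item = -np.inf, None
    for item, log_w in stream:
        key = log_w - np.log(-np.log(rng.uniform()))  # x_i + Gumbel(0, 1) noise
        if key > best_key:
            best_key, best_item = key, item
    return best_item

rng = np.random.default_rng(0)
logw = [0.0, 1.0, 2.0]
draws = [sample_from_stream(enumerate(logw), rng) for _ in range(30000)]
```

Only one key and one item are kept in memory at a time, so this works no matter how long the stream is.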
Further reading: I have a few posts relating to the Gumbel-max trick. Have a look at posts tagged with Gumbel.