Graduate Descent

Fast sigmoid sampling

In this notebook, we describe a simple trick for efficiently sampling a Bernoulli random variable $Y$ from a sigmoid-defined distribution, $p(Y = 1) = (1 + \exp(-x))^{-1}$, where $x \in \mathbb{R}$ is the only parameter of the distribution ($x$ is often defined as the dot product of features and weights).

The "slow" method for sampling from a sigmoid is

$$ u \sim \textrm{Uniform}(0,1) $$

$$ Y = \textrm{sigmoid}(x) > u $$

This method is slow because it evaluates the sigmoid function, and therefore $\exp$, for every value of $x$, and $\exp$ is 2-3x slower than basic arithmetic operations.

In this post, I'll describe a simple trick that is well-suited to vectorized computations (e.g., numpy, matlab). It works by precomputing the expensive stuff (i.e., calls to expensive functions like $\exp$), using the following equivalence:

$$ \textrm{sigmoid}(x) > u \Leftrightarrow \textrm{logit}(\textrm{sigmoid}(x)) > \textrm{logit}(u) \Leftrightarrow x > \textrm{logit}(u). $$

Some details worth mentioning: (a) logit is the inverse of sigmoid, and (b) logit is strictly monotonically increasing, so you can apply it to both sides of the inequality and preserve the ordering (there's a plot in the appendix).
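As a quick sanity check of the equivalence above, here is a minimal sketch in plain Python. The helpers `sigmoid` and `logit` are defined locally for illustration (they stand in for the scipy functions used in the experiment), and the test values are arbitrary choices, not taken from the experiment.

```python
import math

def sigmoid(x):
    # Logistic function: maps a real number to a probability in (0, 1).
    return 1.0 / (1.0 + math.exp(-x))

def logit(p):
    # Inverse of sigmoid, defined for 0 < p < 1.
    return math.log(p / (1.0 - p))

xs = [-3.0, -1.0, 0.0, 0.5, 2.0]   # arbitrary parameters
us = [0.1, 0.3, 0.6, 0.9]          # arbitrary stand-ins for uniform draws

# The slow comparison and the fast comparison agree for every (x, u) pair.
agree = all((sigmoid(x) > u) == (x > logit(u)) for x in xs for u in us)
print(agree)
```

Note that `logit(u)` never looks at `x`, which is what lets it be computed ahead of time.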

The "fast" method derives its advantage from the fact that the expensive computation can be done independently of the data (i.e., specific values of $x$). The fast method is also interesting as just cute math. In the bonus section of this post, we'll make a connection to the Gumbel max trick.

How fast is it in practice? Below, we run a quick experiment to check that the method is correct and to measure its speed.

In [1]:
%matplotlib inline
import numpy as np
import pylab as pl
from numpy.random import uniform
from numpy import exp
from scipy.special import expit as sigmoid, logit
from arsenal.timer import timers
In [3]:
T = timers()

# These are the sigmoid parameters we're going to sample from.
n = 10000
X = np.linspace(-5,5,n)

# number of runs to average over.
R = 1000

# Used for plotting average p(Y=1)
F = np.zeros_like(X)

# Temporary array for saving on memory allocation, cf. method slow-2.
tmp = np.empty(n)                     

for _ in range(R):

    # Let's use the same random variables for all methods. This allows
    # for a lower variance comparison and equivalence testing.
    u = uniform(0,1,size=n)
    z = logit(u)       # used in fast method: precompute expensive stuff.

    # Requires computing sigmoid for each x.
    with T['slow1']:
        s1 = sigmoid(X) > u           
    # Avoid memory allocation in slow-1 by using the out option to sigmoid
    # function. It's a little bit faster than slow-1.
    with T['slow2']:
        sigmoid(X, out=tmp)           
        s2 = tmp > u

    # Rolling our sigmoid is a bit slower than using the library function.
    # Not to mention this implementation isn't as numerically stable.
    with T['slow3']:
        s3 = 1/(1+exp(-X)) > u
    # The fast method.
    with T['fast']:
        f = X > z
    F += f / R    
    assert (s1 == f).all()
    assert (s2 == f).all()
    assert (s3 == f).all()

pl.plot(X, F)
pl.plot(X, sigmoid(X), c='r', lw=2)
fast is 28.4239x faster than slow1 (avg: slow1: 0.00114061 fast: 4.01285e-05)
slow2 is 1.0037x faster than slow1 (avg: slow1: 0.00114061 slow2: 0.0011364)
slow1 is 1.0840x faster than slow3 (avg: slow3: 0.00123637 slow1: 0.00114061)

It looks like our trick is about $28$x faster than the fastest competing slow method!

We also see that the assert statements passed, which means that the methods tested produce precisely the same samples.

The final plot demonstrates that we get the right expected value (red curve) as we sweep the distribution's parameter (x-axis).


We could alternatively use the Gumbel max trick to derive a similar algorithm. If we work out the trick for a sigmoid instead of a general multinomial distribution, we end up with

$$ Z_0 \sim \textrm{Gumbel}(0,1) $$

$$ Z_1 \sim \textrm{Gumbel}(0,1) $$

$$ Y = x > Z_0 - Z_1 $$

Much like our new trick, this one benefits from the fact that all expensive stuff is done independent of the data (i.e., the value of $x$). However, it seems silly that we "need" to generate two Gumbel RVs to get one sample from the sigmoid. With a little bit of Googling, we discover that the difference of $\textrm{Gumbel}(0,1)$ RVs is a logistic RV (specifically $\textrm{Logistic}(0,1)$).

It turns out that $\textrm{logit}(\textrm{Uniform}(0,1))$ is a $\textrm{Logistic}(0,1)$ RV.

Voila! Our fast sampling trick and the Gumbel max trick are connected!
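The connection can be checked empirically. The following is a rough sketch in plain Python, not the notebook's numpy code; the function names, the sample size, and the seed are illustrative assumptions. Both samplers should produce $Y = 1$ with frequency close to $\textrm{sigmoid}(x)$.

```python
import math
import random

def uniform_open(rng):
    # Uniform draw on (0, 1); redraw on exactly 0 to avoid log(0).
    u = rng.random()
    while u == 0.0:
        u = rng.random()
    return u

def gumbel(rng):
    # Gumbel(0,1) variate via inverse transform sampling.
    return -math.log(-math.log(uniform_open(rng)))

def sample_gumbel_pair(x, rng):
    # Two-class Gumbel max: Y = 1 exactly when x > Z_0 - Z_1.
    return x > gumbel(rng) - gumbel(rng)

def sample_logit_uniform(x, rng):
    # Fast method: compare x against logit(u), a Logistic(0,1) variate.
    u = uniform_open(rng)
    return x > math.log(u / (1.0 - u))

def frequency(sampler, x, n=20000, seed=0):
    # Empirical estimate of p(Y = 1) for a given sampler.
    rng = random.Random(seed)
    return sum(sampler(x, rng) for _ in range(n)) / n

x = 1.0
print(frequency(sample_gumbel_pair, x),
      frequency(sample_logit_uniform, x),
      1.0 / (1.0 + math.exp(-x)))
```

The Gumbel version draws two uniforms per sample, while the logit version draws one, which is the "silly" overhead mentioned above.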

Another trick is Justin Domke's trick to reduce calls to $\exp$ by $\approx 88\%$. The disadvantage of this approach is that it's harder to implement with vectorization. The advantage is that we don't need to precompute any expensive things.


Logit plot

In [3]:
xs = np.linspace(0,1,100)
ys = logit(xs)
pl.plot(xs, ys);

Logistic random variable

Check that our sampling method is equivalent to sampling from a logistic distribution.

In [5]:
from scipy.stats import logistic
u = uniform(0,1,size=10000)
z = logit(u)
pl.hist(z, bins=100, density=True)   # `normed` was removed from newer matplotlib
xs = np.linspace(-6,6,100)
ys = logistic.pdf(xs)
pl.plot(xs, ys, c='r', lw=2);

Gumbel max trick and weighted reservoir sampling

A while back, Hanna and I stumbled upon the following blog post: Algorithms Every Data Scientist Should Know: Reservoir Sampling, which got us excited about reservoir sampling.

Around the same time, I attended a talk by Tamir Hazan about some of his work on perturb-and-MAP (Hazan & Jaakkola, 2012), which is inspired by the Gumbel-max-trick (see previous post). The apparent similarity between weighted reservoir sampling and the Gumbel max trick led us to make some cute connections, which I'll describe in this post.

The problem: We're given a stream of unnormalized probabilities, \(x_1, x_2, \cdots\). At any point in time \(t\) we'd like to have a sampled index \(i\) available, where the probability of \(i\) is given by \(\pi_t(i) = \frac{x_i}{ \sum_{j=1}^t x_j}\).

Assume, without loss of generality, that \(x_i > 0\) for all \(i\). (If any element has a zero weight we can safely ignore it since it should never be sampled.)

Streaming Gumbel-max sampler: I came up with the following algorithm, which is a simple "modification" of the Gumbel-max-trick for handling streaming data:

\(a = -\infty; b = \text{null} \ \ \text{# maximum value and index}\)
for \(i=1,2,\cdots;\) do:
# Compute log-unnormalized probabilities
\(w_i = \log(x_i)\)
# Additively perturb each weight by a Gumbel random variate
\(z_i \sim \text{Gumbel}(0,1)\)
\(k_i = w_i + z_i\)
# Keep around the largest \(k_i\) (i.e. the argmax)
if \(k_i > a\):
\(\ \ \ \ a = k_i\)
\(\ \ \ \ b = i\)

If we interrupt this algorithm at any point, we have a sample \(b\).
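The pseudocode above can be sketched in plain Python as a generator, which makes the "interrupt at any point" property explicit. This is an illustrative sketch: the function name is made up, indices are 0-based rather than 1-based, and the Gumbel variate is drawn as $-\log(-\log(u))$ from a uniform $u$.

```python
import math
import random

def streaming_gumbel_max(stream, rng):
    # Yields the current sample after each item, so the stream can be
    # interrupted at any point. Indices are 0-based here.
    a, b = -math.inf, None                  # maximum key and its index
    for i, x in enumerate(stream):
        u = rng.random()
        while u == 0.0:                     # avoid log(0)
            u = rng.random()
        # log-weight plus a Gumbel(0,1) perturbation
        k = math.log(x) - math.log(-math.log(u))
        if k > a:
            a, b = k, i
        yield b

# Usage: the last yielded value is a sample from the full stream.
*_, sample = streaming_gumbel_max([1.0, 2.0, 7.0], random.Random(0))
print(sample)
```

Only the running maximum key and its index are stored, so memory use does not grow with the length of the stream.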

After convincing myself this algorithm was correct, I sat down to try to understand the algorithm in the blog post, which is due to Efraimidis and Spirakis (2005) (paywall, free summary). The two algorithms looked similar in many ways but used different sorting keys / perturbations.

Efraimidis and Spirakis (2005): Here is the ES algorithm for weighted reservoir sampling:

\(a = -\infty; b = \text{null}\)
for \(i=1,2,\cdots;\) do:
# Compute a random key for each weight from a uniform random variate
\(u_i \sim \text{Uniform}(0,1)\)
\(e_i = u_i^{(\frac{1}{x_i})}\)
# Keep around the largest \(e_i\)
if \(e_i > a\):
\(\ \ \ \ a = e_i\)
\(\ \ \ \ b = i\)

Again, if we interrupt this algorithm at any point, we have our sample \(b\).

Relationship: Let's try to relate these algorithms. At a high level, both algorithms compute a randomized key and take an argmax. What's the relationship between the keys?

First, note that a \(\text{Gumbel}(0,1)\) variate can be generated via \(-\log(-\log(\text{Uniform}(0,1)))\). This is a straightforward application of the inverse transform sampling method for random number generation. This means that if we use the same sequence of uniform random variates, then \(z_i = -\log(-\log(u_i))\).

However, this does not give us equality between \(k_i\) and \(e_i\), but it does turn out that \(k_i = -\log(-\log(e_i))\), which is useful because this is a monotonic transformation on the interval \((0,1)\). Since monotonic transformations preserve ordering, the sequences \(k\) and \(e\) result in the same comparison decisions, as well as the same argmax. In summary, the algorithms are the same!
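The relationship between the two keys is easy to verify numerically. Here is a small sketch in plain Python; the helper names and the particular values of \(u_i\) and \(x_i\) are arbitrary choices for illustration.

```python
import math

def gumbel_key(u, x):
    # k = log(x) + z, where z = -log(-log(u)) is a Gumbel(0,1) variate.
    return math.log(x) - math.log(-math.log(u))

def es_key(u, x):
    # e = u ** (1 / x), the Efraimidis-Spirakis key.
    return u ** (1.0 / x)

us = [0.12, 0.57, 0.93, 0.31]   # shared uniform variates (arbitrary)
xs = [0.5, 2.0, 1.0, 4.0]       # positive weights (arbitrary)

ks = [gumbel_key(u, x) for u, x in zip(us, xs)]
es = [es_key(u, x) for u, x in zip(us, xs)]

# The keys are related by the monotonic map e -> -log(-log(e)),
related = all(
    math.isclose(k, -math.log(-math.log(e)), rel_tol=1e-9, abs_tol=1e-9)
    for k, e in zip(ks, es)
)
# so both sequences select the same argmax.
same_argmax = ks.index(max(ks)) == es.index(max(es))
print(related, same_argmax)
```

The identity follows from \(\log(e_i) = \log(u_i) / x_i\), so \(-\log(-\log(e_i)) = \log(x_i) - \log(-\log(u_i)) = w_i + z_i\).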

Extensions: After reading a little further along in the ES paper, we see that the same algorithm can be used to perform sampling without replacement by sorting and taking the elements with the highest keys. This same modification is applicable to the Gumbel-max-trick because the keys have exactly the same ordering as ES. In practice we don't sort the keys, but instead use a bounded priority queue.
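One way to sketch the bounded priority queue is with Python's `heapq` module, keeping a min-heap of at most \(m\) (key, index) pairs. The function name, the 0-based indices, and the choice of Gumbel keys rather than ES keys are assumptions made for illustration.

```python
import heapq
import math
import random

def sample_without_replacement(stream, m, rng):
    # Keep the m items with the largest Gumbel-perturbed keys in a min-heap,
    # so the smallest retained key is always at heap[0].
    heap = []
    for i, x in enumerate(stream):
        u = rng.random()
        while u == 0.0:                     # avoid log(0)
            u = rng.random()
        k = math.log(x) - math.log(-math.log(u))
        if len(heap) < m:
            heapq.heappush(heap, (k, i))
        elif k > heap[0][0]:
            # New key beats the smallest retained key: swap it in.
            heapq.heapreplace(heap, (k, i))
    # Return indices ordered from largest key to smallest.
    return [i for _, i in sorted(heap, reverse=True)]

print(sample_without_replacement([5.0, 1.0, 1.0, 3.0], 2, random.Random(0)))
```

Each item costs at most one heap operation, so a stream of \(n\) items takes \(O(n \log m)\) time and \(O(m)\) memory.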

Closing: To the best of my knowledge, the connection between the Gumbel-max-trick and ES is undocumented. Furthermore, the Gumbel-max-trick is not known as a streaming algorithm, much less known to perform sampling without replacement! If you know of a reference let me know. Perhaps, we'll finish turning these connections into a short tech report.

Gumbel max trick

Goal: Sampling from a discrete distribution parametrized by unnormalized log-probabilities:

$$ \pi_k = \frac{1}{z} \exp(x_k) \ \ \ \text{where } z = \sum_{j=1}^K \exp(x_j) $$

The usual way: Exponentiate and normalize (using the exp-normalize trick), then use an algorithm for sampling from a discrete distribution (aka categorical):

def usual(x):
    cdf = exp(x - x.max()).cumsum()     # the exp-normalize trick
    z = cdf[-1]
    u = uniform(0,1)
    return cdf.searchsorted(u * z)

The Gumbel max trick:

$$ y = \underset{ i \in \{1,\cdots,K\} }{\operatorname{argmax}} x_i + z_i $$

where \(z_1 \cdots z_K\) are i.i.d. \(\text{Gumbel}(0,1)\) random variates. It turns out that \(y\) is distributed according to \(\pi\). (See the short derivations in this blog post.)

Implementing the Gumbel max trick is remarkably easy:

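from numpy.random import gumbel

def gumbel_max_sample(x):
    # x: 2-D array of unnormalized log-probabilities, one distribution per row.
    z = gumbel(loc=0, scale=1, size=x.shape)
    return (x + z).argmax(axis=1)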

If you don't have access to a Gumbel random variate generator, you can use \(-\log(-\log(\text{Uniform}(0,1)))\).
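As a rough check of this fallback, the sketch below draws variates via \(-\log(-\log(u))\) in plain Python and compares two summary statistics against known properties of \(\text{Gumbel}(0,1)\): its mean is the Euler-Mascheroni constant (about 0.5772) and its CDF at zero is \(\exp(-1)\) (about 0.3679). The function name, sample size, and seed are illustrative assumptions.

```python
import math
import random

def gumbel_from_uniform(rng):
    # Inverse transform sampling: the Gumbel(0,1) CDF is exp(-exp(-z)),
    # so z = -log(-log(u)) for u uniform on (0, 1).
    u = rng.random()
    while u == 0.0:                         # avoid log(0)
        u = rng.random()
    return -math.log(-math.log(u))

rng = random.Random(0)
draws = [gumbel_from_uniform(rng) for _ in range(50000)]

mean = sum(draws) / len(draws)
frac_nonpositive = sum(z <= 0 for z in draws) / len(draws)
print(mean, frac_nonpositive)
```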


Comparison:

  1. Gumbel max requires \(K\) samples from a uniform. Usual requires only \(1\).

  2. Gumbel is one-pass because it does not require normalization (or a pass to compute the max for use in the exp-normalize trick). More on this in a later post!

  3. The Gumbel max trick requires \(2K\) calls to \(\log\), whereas the usual method requires \(K\) calls to \(\exp\). Since \(\exp\) and \(\log\) are expensive functions, we'd like to avoid calling them. What gives? Well, Gumbel's calls to \(\log\) do not depend on the data, so they can be precomputed; this is handy for implementations that rely on vectorization for efficiency, e.g., python+numpy.

Further reading: I have a few posts relating to the Gumbel-max trick. Have a look at posts tagged with Gumbel.