We are interested in designing an efficient algorithm for sampling from a categorical distribution over \(n\) items with weights \(w_i > 0\). Define the target sampling distribution \(p\) as
\[
p_i = \frac{w_i}{\sum_{j=1}^n w_j}.
\]
The following is a very simple and relatively famous algorithm due to Efraimidis and Spirakis (2006). It has several useful properties: it is a one-pass "streaming" algorithm, it separates the data from the noise, and it extends easily to streaming sampling without replacement. It is also very closely related to the Gumbel-max trick (Vieira, 2014).
def weighted_reservoir_sampling(stream):
    return np.argmin([Exponential.sample(w) for w in stream])
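The snippets in this post assume a small `Exponential` utility with rate-parameterized samplers: `sample`, and `sample_truncated`, which is used later by the jump algorithm. This is not a standard numpy API, so here is a minimal sketch of what I have in mind:

import numpy as np

class Exponential:
    @staticmethod
    def sample(rate):
        # Exponential variate with the given rate (numpy parameterizes by scale = 1/rate).
        return np.random.exponential(1.0 / rate)

    @staticmethod
    def sample_truncated(rate, lo, hi):
        # Exponential(rate) variate conditioned to lie in [lo, hi], via inverse-CDF sampling.
        # F(x) = 1 - exp(-rate*x); draw u uniformly between F(lo) and F(hi), then invert F.
        f_lo = -np.expm1(-rate * lo)   # CDF at lo
        f_hi = -np.expm1(-rate * hi)   # CDF at hi (equals 1 when hi = inf)
        u = np.random.uniform(f_lo, f_hi)
        return -np.log1p(-u) / rate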
Some cosmetic differences from E&S'06: we use exponential random variates and take the \(\min\) instead of the \(\max\). E&S'06 use a less-elegant and rather-mysterious (IMO) random key, \(u_i^{1/w_i}\) with \(u_i \sim \mathrm{Uniform}(0,1)\).
Why does it work? The weighted-reservoir sampling algorithm exploits the following well-known properties of exponential random variates: When \(X_i \sim \mathrm{Exponential}(w_i)\), \(R = {\mathrm{argmin}}_i X_i\), and \(T = \min_i X_i\) then \(R \sim p\) and \(T \sim \mathrm{Exponential}\left( \sum_i w_i \right)\).
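For completeness, here is a one-line derivation of both facts (standard, but worth writing out):
\[
\Pr(R = k) = \int_0^\infty w_k\, e^{-w_k t} \prod_{i \ne k} e^{-w_i t}\, \mathrm{d}t
= \int_0^\infty w_k\, e^{-t \sum_i w_i}\, \mathrm{d}t
= \frac{w_k}{\sum_i w_i} = p_k,
\qquad
\Pr(T > t) = \prod_i e^{-w_i t} = e^{-t \sum_i w_i}.
\]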
Fewer random variates by waiting
One downside of this one-pass algorithm is that it requires \(\mathcal{O}(n)\) uniform random variates. Contrast that with the usual two-pass methods for sampling from a categorical distribution, which only need \(\mathcal{O}(1)\) random variates per sample. E&S'06 also present a much less well-known algorithm, called the "Exponential jumps" algorithm, which is a one-pass algorithm that only requires \(\mathcal{O}(\log(n))\) random variates (in expectation). That's way fewer variates than the naive one-pass algorithm, and the extra ones (relative to the two-pass methods) are a small price to pay if you are trying to avoid paging in data from disk a second time.
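For concreteness, the two-pass method I have in mind is inverse-CDF sampling against the cumulative weights; a minimal sketch (the function name is mine):

def two_pass_sample(w):
    # Pass 1: cumulative (hence total) weight; pass 2: find where a single uniform variate lands.
    c = np.cumsum(w)
    u = np.random.uniform(0.0, c[-1])   # one uniform variate per sample
    return int(np.searchsorted(c, u))

Note that it needs the whole weight vector available (or a second pass over the stream), which is exactly what the one-pass setting disallows.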
Here is my take on their algorithm. There is no substantive difference, but I believe my version is more instructive since it makes the connection to exponential variates and truncated generation explicit (i.e., no mysterious random keys).
def jump(stream):
    "Weighted-reservoir sampling by jumping"
    R = None
    T = np.inf
    J = 0.0
    for i, w in enumerate(stream):
        J -= w
        if J <= 0:
            # Sample the key for item i, given that it is smaller than the current threshold
            T = Exponential.sample_truncated(w, 0, T)
            # i enters the reservoir
            R = i
            # sample the waiting time (size of the jump)
            J = Exponential.sample(T)
    return R
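To get a feel for the savings, here is a quick, hypothetical experiment of my own: `count_updates` mirrors the control flow of jump but returns the number of reservoir updates, each of which costs two exponential variates. With uniform weights, the expected number of updates on a stream of length \(n\) is roughly \(\ln n\), versus the \(n\) variates drawn by weighted_reservoir_sampling.

def count_updates(stream):
    # Same jumping logic as above, but count reservoir updates instead of returning the sample.
    T, J, updates = np.inf, 0.0, 0
    for w in stream:
        J -= w
        if J <= 0:
            T = Exponential.sample_truncated(w, 0, T)
            J = Exponential.sample(T)
            updates += 1
    return updates

n = 100_000
print(np.mean([count_updates(np.ones(n)) for _ in range(10)]))  # roughly ln(n) + 0.577, about 12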
Why does the exponential jumps algorithm work?
Let me first rewrite the weighted_reservoir_sampling algorithm to be much more similar to the jump algorithm. For fun, I'm going to refer to it as the walk algorithm.
def walk(stream):
    "Weighted-reservoir sampling by walking"
    R = None
    T = np.inf
    for i, w in enumerate(stream):
        X = Exponential.sample(w)
        if X < T:
            R = i   # i enters the reservoir
            T = X   # threshold to enter the reservoir
    return R
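As a quick sanity check (my own, relying on the definitions above), walk and jump should both produce samples distributed according to \(p\):

w = np.array([1.0, 2.0, 3.0, 4.0])
p = w / w.sum()
trials = 100_000
walk_freq = np.bincount([walk(w) for _ in range(trials)], minlength=len(w)) / trials
jump_freq = np.bincount([jump(w) for _ in range(trials)], minlength=len(w)) / trials
print(p)          # [0.1 0.2 0.3 0.4]
print(walk_freq)  # close to p
print(jump_freq)  # close to p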
The key idea of the exponential jumps algorithm is to sample waiting times between new-minimum events. In particular, if the algorithm is at step \(i\), the probability that it sees its next minimum at step \(j \in \{ i+1, \ldots \}\) can be reasoned about without actually sampling the various \(X_j\) variables.

Rather than going into a full-blown tutorial on waiting times of exponential variates, I will get to the point and show that the jump algorithm simulates the walk algorithm. The key to doing this is showing that the probability of jumping from \(i\) to \(k\) is the same as "walking" from \(i\) to \(k\). Let \(W_{i,k} = \sum_{j=i}^k w_j\).
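Concretely, fix the current position \(i\) and threshold \(T\), and take any later position \(k\) (a sketch of the calculation, in the notation above). Walking from \(i\) to \(k\) means \(X_j \ge T\) for \(j = i+1, \ldots, k-1\) and \(X_k < T\):
\[
\Pr(\text{walk } i \to k)
= \left( \prod_{j=i+1}^{k-1} e^{-w_j T} \right) \left( 1 - e^{-w_k T} \right)
= e^{-T\, W_{i+1,k-1}} - e^{-T\, W_{i+1,k}}.
\]
Jumping from \(i\) to \(k\) means the waiting time \(J \sim \mathrm{Exponential}(T)\) runs out somewhere inside item \(k\)'s weight, i.e., \(W_{i+1,k-1} < J \le W_{i+1,k}\):
\[
\Pr(\text{jump } i \to k)
= e^{-T\, W_{i+1,k-1}} - e^{-T\, W_{i+1,k}},
\]
so the two algorithms place the next reservoir entry at \(k\) with the same probability.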
This proof is adapted from the original proof in E&S'06.
Given that the waiting time correctly matches the walking algorithm, the remaining detail is to check that \(X_k\) has the right distribution given that it enters the reservoir. This conditioning is why the jumping algorithm must generate a truncated random variate: a random variate that is guaranteed to be less than the previous minimum. In the Gumbel-max world, this is used in the top-down generative story.
Closing thoughts
Pros:
- The jump algorithm saves a ton of random variates, and that translates into practical savings (at least, in my limited experiments).
Cons:
- The jump algorithm is harder to parallelize or vectorize, but it seems possible.
- If you aren't in a setting that requires a one-pass algorithm (or some of its other special properties), you are probably better served by the two-pass algorithms: they have lower overhead since they don't call expensive functions like \(\log\) and they use a single random variate per sample.
Further reading:
- I have several posts on the topic of fast sampling algorithms (1, 2, 3).
- Jake VanderPlas (2018), The Waiting Time Paradox, or, Why Is My Bus Always Late?