Graduate Descent

Heaps for incremental computation

In this post, I'll describe a neat trick for maintaining a summary quantity (e.g., sum, product, max, log-sum-exp, concatenation, cross-product) under changes to its inputs. The trick and its implementation are inspired by the well-known max-heap data structure. I'll also describe a really elegant application to fast sampling under an evolving categorical distribution.

Setup: Suppose we'd like to efficiently compute a summary quantity under changes to its \(n\)-dimensional input vector \(\boldsymbol{w}\). The particular form of the quantity we're going to compute is \(z = \bigoplus_{i=1}^n w_i\), where \(\oplus\) is some associative binary operator with identity element \(\boldsymbol{0}\).

The trick: Essentially, the trick boils down to parenthesis placement in the expression that computes \(z\), a freedom that associativity grants us.

I'll demonstrate by example with \(n=8\).

Linear structure: We generally compute something like \(z\) with a simple loop. In terms of the order of operations, this looks like a maximally unbalanced (left-deep) binary tree,

$$ z = (((((((w_1 \oplus w_2) \oplus w_3) \oplus w_4) \oplus w_5) \oplus w_6) \oplus w_7) \oplus w_8). $$

Heap structure: Here the parentheses form a balanced tree, which looks much more like a recursive implementation that computes the left and right halves and \(\oplus\)s the results (divide-and-conquer style),

$$ z = (((w_1 \oplus w_2) \oplus (w_3 \oplus w_4)) \oplus ((w_5 \oplus w_6) \oplus (w_7 \oplus w_8))). $$
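To make the two parenthesizations concrete, here is a minimal sketch that evaluates \(z\) both ways for an arbitrary associative op. The function names linear_reduce and tree_reduce are illustrative (not from the original code), and both assume a nonempty input. String concatenation is a useful test case because it is associative but not commutative, so it checks that the balanced version preserves the left-to-right order of the inputs.

```python
from functools import reduce
import operator

def linear_reduce(op, w):
    "Loop-style evaluation: (((w[0] op w[1]) op w[2]) op ...)."
    return reduce(op, w)

def tree_reduce(op, w):
    "Divide-and-conquer evaluation: reduce each half, then combine the two results."
    if len(w) == 1:
        return w[0]
    mid = len(w) // 2
    return op(tree_reduce(op, w[:mid]), tree_reduce(op, w[mid:]))

w = ["a", "b", "c", "d", "e", "f", "g", "h"]
print(linear_reduce(operator.add, w))              # abcdefgh
print(tree_reduce(operator.add, w))                # abcdefgh
print(tree_reduce(max, [3, 1, 4, 1, 5, 9, 2, 6]))  # 9
```

For \(n = 8\), tree_reduce splits exactly as in the balanced expression above.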

The benefit of the heap structure is that there are \(\mathcal{O}(\log n)\) intermediate quantities that depend on any input, whereas the linear structure has \(\mathcal{O}(n)\). The intermediate quantities correspond to the values of each of the parenthesized expressions.

Since fewer intermediate quantities depend on a given input, fewer intermediates need to be adjusted upon a change to the input. Therefore, we get faster algorithms for maintaining the output quantity \(z\) as the inputs change.

Heap data structure (closely related to the segment tree and to the binary indexed, or Fenwick, tree): We're going to store the values of the intermediate quantities and inputs in a heap data structure, which is a complete binary tree. In our case, the tree has depth \(1 + \lceil \log_2 n \rceil\), with the values of \(\boldsymbol{w}\) at its leaves (aligned left) and the remaining leaves padded with \(\boldsymbol{0}\). Since the number of leaves is the smallest power of two \(\ge n\), which is \(< 2n\), the array's length is \(< 4 n\).

This structure makes our implementation really nice and efficient because we don't need pointers to find the parent or children of a node (i.e., no need to wrap elements into a "node" class like in a general tree data structure). So, we can pack everything into an array, which means our implementation has great memory/cache locality and low storage overhead.

Traversing the tree is pretty simple: Let \(d\) be the number of leaves (a power of two). Nodes \(1 \le i < d\) are internal and nodes \(d \le i < 2d\) are leaves. For node \(i\): left child \(\rightarrow {2 \cdot i},\) right child \(\rightarrow {2 \cdot i + 1},\) parent \(\rightarrow \lfloor i / 2 \rfloor.\) (Note that these operations assume the array's indices start at \(1\). We generally fake this by adding a dummy node at position \(0\), which makes the implementation simpler.)
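As a quick illustration of this index arithmetic (and of the \(\mathcal{O}(\log n)\) claim above), the sketch below lists the internal nodes that depend on input \(w_k\). It assumes the layout used by sumheap below: \(w_k\) (0-indexed) sits at position \(d + k\). The helper names left, right, parent, and dependents are hypothetical.

```python
def left(i):
    return 2 * i

def right(i):
    return 2 * i + 1

def parent(i):
    return i // 2

def dependents(k, d):
    "Internal nodes, bottom-up, whose values depend on leaf w[k]; d = number of leaves."
    i = d + k          # array position of w[k]
    path = []
    while i > 1:       # stop once we have recorded the root (index 1)
        i = parent(i)
        path.append(i)
    return path

print(dependents(0, 8))  # [4, 2, 1]
print(dependents(5, 8))  # [6, 3, 1]
```

Each input has exactly \(\log_2 d\) dependents, one per level of the tree.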

Initializing the heap: Here's code that initializes the heap structure we just described.

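```python
import numpy as np

def sumheap(w):
    "Create sumheap from weights `w` in O(n) time."
    n = w.shape[0]
    d = int(2**np.ceil(np.log2(n)))  # number of leaves (smallest power of two >= n)
    S = np.zeros(2*d)                # S[0] is a dummy; internal nodes 1..d-1; leaves d..2d-1
    S[d:d+n] = w                     # store `w` at the leaves (padding leaves stay 0)
    for i in reversed(range(1, d)):  # fill in internal nodes bottom-up
        S[i] = S[2*i] + S[2*i + 1]
    return S
```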

Updating \(w_k\) boils down to fixing the intermediate sums that (transitively) depend on \(w_k.\) I won't go into all of the details here; instead, I'll give code (below). I'd like to quickly point out that the term "parents" is not great for our purposes because they are actually the dependents: when an input changes value, its parent, grandparent, great-grandparent, etc., become stale and need to be recomputed bottom up (from the leaves). The code below implements the update method for changing the value of \(w_k\) and runs in \(\mathcal{O}(\log n)\) time.

```python
def update(S, k, v):
    "Update `w[k] = v` in O(log n) time."
    d = S.shape[0]
    i = d//2 + k      # array position of leaf w[k]
    S[i] = v
    while i > 1:      # fix ancestors, stopping at the root (index 1), not the dummy S[0].
        i //= 2
        S[i] = S[2*i] + S[2*i + 1]
```
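A small usage check, assuming NumPy: the root S[1] should track sum(w) across updates. The two functions are repeated from above so that the snippet runs on its own; the weights are arbitrary.

```python
import numpy as np

def sumheap(w):
    "Create sumheap from weights `w` in O(n) time."
    n = w.shape[0]
    d = int(2**np.ceil(np.log2(n)))   # number of leaves
    S = np.zeros(2*d)                 # S[0] dummy; internal 1..d-1; leaves d..2d-1
    S[d:d+n] = w
    for i in reversed(range(1, d)):
        S[i] = S[2*i] + S[2*i + 1]
    return S

def update(S, k, v):
    "Update `w[k] = v` in O(log n) time."
    i = S.shape[0]//2 + k
    S[i] = v
    while i > 1:
        i //= 2
        S[i] = S[2*i] + S[2*i + 1]

w = np.array([3., 1., 4., 1., 5.])
S = sumheap(w)
print(S[1])          # 14.0
update(S, 2, 10.)    # w is now [3, 1, 10, 1, 5]
print(S[1])          # 20.0
```

Only the three internal nodes on the path from leaf 2 to the root are recomputed by the update.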


  • Numerical stability: If the operations are noisy (e.g., floating-point addition), then the heap version may be better behaved. For example, if each operation introduces independent, additive noise at rate \(\varepsilon\), then the noise in \(z_{\text{heap}}\) is \(\mathcal{O}(\varepsilon \cdot \log n)\), whereas the noise in \(z_{\text{linear}}\) is \(\mathcal{O}(\varepsilon \cdot n)\). (Without further assumptions about the underlying operator, I don't believe you can do better than that.)

  • Relationship to max-heap: In the case of a max or min heap, we can avoid allocating extra space for intermediate quantities because every intermediate value is equal to exactly one element of \(\boldsymbol{w}\).

  • Change propagation: Adjusting cached intermediate quantities is a neat idea in general. In fact, we encounter it each time we type make at the command line! The general technique goes by many names (including change propagation, incremental maintenance, and functional reactive programming) and applies to basically any side-effect-free computation. However, it's most effective when the dependency structure of the computation is sparse and finding and refreshing stale values requires little overhead. In our example of computing \(z\), these considerations manifest themselves as the choice of the heap over the linear structure and of our fast array implementation over a generic tree data structure.


  • No zero? No problem. We don't actually require a zero element. It's fair to augment the set of values \(\boldsymbol{K}\) that \(\oplus\) operates on to \(\boldsymbol{K} \cup \{ \textsf{null} \}\), where \(\textsf{null}\) is a distinguished value (i.e., \(\textsf{null} \notin \boldsymbol{K}\)) that acts just like a zero once we overload \(\oplus\) to satisfy the definition of a zero (e.g., by adding an if-statement).

  • Generalization to arbitrary maps (instead of fixed vectors) is possible with a "locator" map, which is a bijective map from elements to indices in a dense array.

  • Support for growing and shrinking: We support growing by maintaining an underlying array that is always slightly larger than we need, which we're already doing in the heap data structure. Doubling the size of the underlying array (i.e., rounding up to the next power of two) has the added benefit of allowing us to grow \(\boldsymbol{w}\) at no asymptotic cost! This is because the resize operation, which requires \(\mathcal{O}(n)\) time to allocate a new array and copy the old values, happens so infrequently that its cost amortizes to \(\mathcal{O}(1)\) per insertion. We get the effect of shrinking by replacing the old value with \(\textsf{null}\) (or \(\boldsymbol{0}\)). We can shrink the underlying array when the fraction of nonzeros dips below \(25\%\). This prevents "thrashing" between shrinking and growing.
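To see the numerical-stability remark in action, here is a small sketch comparing a float32 running sum (the linear structure) with a balanced, pairwise reduction (the heap structure). The input, \(2^{20}\) copies of 0.1 in float32, is an arbitrary choice that makes the exact answer easy to obtain; math.fsum on the float64-converted values serves as the reference.

```python
import math
import numpy as np

def linear_sum(x):
    "Left-to-right running sum; cumsum accumulates sequentially in x's dtype."
    return np.cumsum(x, dtype=x.dtype)[-1]

def tree_sum(x):
    "Balanced pairwise reduction, level by level, like filling in the heap bottom-up."
    n = len(x)
    d = 1 << max(n - 1, 0).bit_length()   # pad to a power of two
    level = np.zeros(d, dtype=x.dtype)
    level[:n] = x
    while len(level) > 1:
        level = level[0::2] + level[1::2]  # each parent = left child + right child
    return level[0]

x = np.full(2**20, 0.1, dtype=np.float32)
exact = math.fsum(x.astype(np.float64))
print(abs(float(linear_sum(x)) - exact))  # noticeably nonzero
print(abs(float(tree_sum(x)) - exact))    # 0.0 for this particular input
```

The tree sum is exact here only because every pairwise addition adds two equal float32 values; on general inputs it is not error-free, but its error grows with the depth of the tree rather than with \(n\).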

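Two of the remarks above, a general associative operator and a null element standing in for zero, fit in one small class. OpHeap is a hypothetical name, not from the original code: it takes any associative op, uses Python's None as \(\textsf{null}\) via an if-statement, and keeps the same array layout as sumheap. A plain Python list is used so that values can be arbitrary objects.

```python
class OpHeap:
    "Cached intermediates for z = w[0] op w[1] op ... op w[n-1]; None acts as the zero."

    def __init__(self, op, w):
        self.op = op
        n = len(w)
        self.d = 1
        while self.d < n:
            self.d *= 2                     # number of leaves (power of two >= n)
        self.S = [None] * (2 * self.d)      # S[0] unused; internal 1..d-1; leaves d..2d-1
        self.S[self.d:self.d + n] = list(w)
        for i in reversed(range(1, self.d)):
            self.S[i] = self._combine(self.S[2*i], self.S[2*i + 1])

    def _combine(self, a, b):
        "Overload op so that None is an identity element."
        if a is None:
            return b
        if b is None:
            return a
        return self.op(a, b)

    def update(self, k, v):
        "Set w[k] = v (v may be None to 'delete' it) and refresh its dependents."
        i = self.d + k
        self.S[i] = v
        while i > 1:
            i //= 2
            self.S[i] = self._combine(self.S[2*i], self.S[2*i + 1])

    @property
    def value(self):
        return self.S[1]

h = OpHeap(lambda a, b: a + b, ["a", "b", "c", "d", "e"])
print(h.value)     # abcde
h.update(1, None)  # 'delete' w[1]
print(h.value)     # acde
m = OpHeap(max, [3, 1, 4, 1, 5])
m.update(4, 0)
print(m.value)     # 4
```

Because the left child always precedes the right child, the class also works for non-commutative operators like concatenation.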

Sampling from an evolving distribution: Suppose that \(\boldsymbol{w}\) corresponds to a categorical distribution over \(\{1, \ldots, n\}\) and that we'd like to sample elements in proportion to this (unnormalized) distribution.

Other methods, like the alias method or the inverse CDF method, are efficient after a somewhat costly initialization step. But! They are not as efficient as the heap sampler when the distribution is being updated. (I'm not sure whether variants of the alias method that support updates exist.)

| Method | Sample | Update | Init |
|---|---|---|---|
| alias | O(1) | O(n)? | O(n) |
| i-CDF | O(log n) | O(n) | O(n) |
| heap | O(log n) | O(log n) | O(n) |

Use cases include

  • Gibbs sampling, where distributions are constantly modified and sampled from (changes may not be sparse, so YMMV). The heap sampler is used in this paper.

  • EXP3 (a multi-armed bandit algorithm) is an excellent example of an algorithm that repeatedly samples from the distribution and then modifies a single weight.

  • Stochastic priority queues, in which we sample in proportion to priority, the weights of items in the queue may change, elements may be removed after they are sampled (i.e., sampling without replacement), and new elements may be added.

Again, I won't spell out all of the details of these algorithms. Instead, I'll just give the code.

Inverse CDF sampling

```python
from numpy.random import uniform

def sample(w):
    "Ordinary sampling method, O(n) init, O(log n) per sample."
    c = w.cumsum()            # build cdf, O(n) (can be cached while w is unchanged)
    p = uniform() * c[-1]     # random probe, p ~ Uniform(0, z)
    return c.searchsorted(p)  # binary search, O(log n)
```

Heap sampling is essentially the same, except that the cdf is stored implicitly in the heap (as subtree sums), which is perfect for binary search!

```python
from numpy.random import uniform

def hsample(S):
    "Sample from sumheap, O(log n) per sample."
    d = S.shape[0]//2     # number of leaves; internal nodes are 1..d-1.
    p = uniform() * S[1]  # random probe, p ~ Uniform(0, z)
    # Binary search over the implicit CDF: walk down from the root, using the
    # subtree sums to decide which child's CDF interval contains the probe.
    i = 1
    while i < d:
        # Determine if the value is in the left or right subtree.
        i *= 2         # Point at left child
        left = S[i]    # Probability mass under left subtree.
        if p > left:   # Value is in right subtree.
            p -= left  # Subtract mass from left subtree
            i += 1     # Point at right child
    return i - d
```
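A quick sanity check, assuming NumPy and repeating sumheap and hsample so that the snippet runs on its own: the empirical frequencies of heap samples should approach \(w / \sum_i w_i\). The weights, seed, and sample count are arbitrary choices.

```python
import numpy as np
from numpy.random import uniform

def sumheap(w):
    n = w.shape[0]
    d = int(2**np.ceil(np.log2(n)))
    S = np.zeros(2*d)
    S[d:d+n] = w
    for i in reversed(range(1, d)):
        S[i] = S[2*i] + S[2*i + 1]
    return S

def hsample(S):
    d = S.shape[0]//2
    p = uniform() * S[1]
    i = 1
    while i < d:
        i *= 2
        left = S[i]
        if p > left:
            p -= left
            i += 1
    return i - d

np.random.seed(0)
w = np.array([1., 2., 0., 3., 4.])
S = sumheap(w)
m = 100000
counts = np.bincount([hsample(S) for _ in range(m)], minlength=len(w))
print(counts / m)   # close to w / w.sum() = [0.1, 0.2, 0.0, 0.3, 0.4]
```

The zero-weight element and the padding leaves are never returned, since the probe only moves right when it exceeds the left subtree's mass.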

Code: Complete code and test cases for heap sampling are available in this gist.
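For the stochastic priority queue use case, sampling without replacement can be sketched by alternating hsample with update(S, k, 0). The helper sample_without_replacement is hypothetical, and it assumes we draw no more items than there are strictly positive weights; the earlier helpers are repeated so the snippet is self-contained.

```python
import numpy as np
from numpy.random import uniform

def sumheap(w):
    "Create sumheap from weights `w` in O(n) time."
    n = w.shape[0]
    d = int(2**np.ceil(np.log2(n)))
    S = np.zeros(2*d)
    S[d:d+n] = w
    for i in reversed(range(1, d)):
        S[i] = S[2*i] + S[2*i + 1]
    return S

def update(S, k, v):
    "Set w[k] = v and refresh the O(log n) sums that depend on it."
    i = S.shape[0]//2 + k
    S[i] = v
    while i > 1:
        i //= 2
        S[i] = S[2*i] + S[2*i + 1]

def hsample(S):
    "Draw an index in proportion to the current leaf weights."
    d = S.shape[0]//2
    p = uniform() * S[1]
    i = 1
    while i < d:
        i *= 2
        if p > S[i]:
            p -= S[i]
            i += 1
    return i - d

def sample_without_replacement(w, m):
    """Hypothetical helper: draw m distinct indices, each in proportion to the
    remaining weights. Assumes m <= number of strictly positive weights."""
    S = sumheap(np.asarray(w, dtype=float))
    out = []
    for _ in range(m):
        k = hsample(S)
        out.append(k)
        update(S, k, 0.0)   # "remove" k by zeroing its weight, O(log n)
    return out

np.random.seed(1)
print(sample_without_replacement([5., 1., 3., 0.5], 4))  # some ordering of [0, 1, 2, 3]
```

Each draw costs \(\mathcal{O}(\log n)\) for the sample plus \(\mathcal{O}(\log n)\) for the removal, so drawing \(m\) items costs \(\mathcal{O}(n + m \log n)\) including initialization.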

Fast sigmoid sampling

In this notebook, we describe a simple trick for efficiently sampling a Bernoulli random variable $Y$ from a sigmoid-defined distribution, $p(Y = 1) = (1 + \exp(-x))^{-1}$, where $x \in \mathbb{R}$ is the only parameter of the distribution ($x$ is often defined as the dot product of features and weights).

The "slow" method for sampling from a sigmoid,

$$ u \sim \textrm{Uniform}(0,1) $$
$$ Y = \textrm{sigmoid}(x) > u $$

This method is slow because it calls the sigmoid function, and hence $\exp$, for every value of $x$, and $\exp$ is 2-3x slower than basic arithmetic operations.

In this post, I'll describe a simple trick, which is well-suited to vectorized computations (e.g., numpy, matlab). It works by precomputing the expensive stuff (i.e., calls to expensive functions like $\exp$).

$$ \textrm{sigmoid}(x) > u \Leftrightarrow \textrm{logit}(\textrm{sigmoid}(x)) > \textrm{logit}(u) \Leftrightarrow x > \textrm{logit}(u). $$

Some details worth mentioning: (a) logit is the inverse of sigmoid, and (b) logit is strictly monotonically increasing, so applying it to both sides of the inequality preserves the ordering (there's a plot in the appendix).

The "fast" method derives its advantage from the fact that the expensive computation can be done independently of the data (i.e., the specific values of $x$). The fast method is also interesting as just cute math. In the bonus section of this post, we'll make a connection to the Gumbel max trick.

How fast is it in practice? Below, we run a quick experiment to check that the method is correct and to measure how fast it is.

```python
%matplotlib inline
import numpy as np
import pylab as pl
from numpy.random import uniform
from numpy import exp
from scipy.special import expit as sigmoid, logit
from arsenal.timer import timers

T = timers()

# These are the sigmoid parameters we're going to sample from.
n = 10000
X = np.linspace(-5,5,n)

# number of runs to average over.
R = 1000

# Used for plotting average p(Y=1)
F = np.zeros_like(X)

# Temporary array for saving on memory allocation, cf. method slow-2.
tmp = np.empty(n)

for _ in range(R):

    # Let's use the same random variables for all methods. This allows
    # for a lower variance comparison and equivalence testing.
    u = uniform(0,1,size=n)
    z = logit(u)       # used in fast method: precompute expensive stuff.

    # Requires computing sigmoid for each x.
    with T['slow1']:
        s1 = sigmoid(X) > u
    # Avoid memory allocation in slow-1 by using the out option to sigmoid
    # function. It's a little bit faster than slow-1.
    with T['slow2']:
        sigmoid(X, out=tmp)
        s2 = tmp > u

    # Rolling our own sigmoid is a bit slower than using the library function.
    # Not to mention this implementation isn't as numerically stable.
    with T['slow3']:
        s3 = 1/(1+exp(-X)) > u
    # The fast method.
    with T['fast']:
        f = X > z
    F += f / R
    assert (s1 == f).all()
    assert (s2 == f).all()
    assert (s3 == f).all()

pl.plot(X, F)
pl.plot(X, sigmoid(X), c='r', lw=2)
```
fast is 28.4239x faster than slow1 (avg: slow1: 0.00114061 fast: 4.01285e-05)
slow2 is 1.0037x faster than slow1 (avg: slow1: 0.00114061 slow2: 0.0011364)
slow1 is 1.0840x faster than slow3 (avg: slow3: 0.00123637 slow1: 0.00114061)

It looks like our trick is about $28$x faster than the fastest competing slow method!

We also see that the assert statements passed, which means that the methods tested produce precisely the same samples.

The final plot demonstrates that the empirical frequency of $Y = 1$ matches the expected value, $\textrm{sigmoid}(x)$ (red curve), as we sweep the distribution's parameter (x-axis).


We could alternatively use the Gumbel max trick to derive a similar algorithm. If we ground out the trick for a sigmoid instead of a general multinomial distribution, we end up with

$$ Z_0 \sim \textrm{Gumbel}(0,1) $$
$$ Z_1 \sim \textrm{Gumbel}(0,1) $$
$$ Y = x > Z_0 - Z_1 $$

Much like our new trick, this one benefits from the fact that all expensive stuff is done independently of the data (i.e., the value of $x$). However, it seems silly that we "need" to generate two Gumbel RVs to get one sample from the sigmoid. With a little bit of Googling, we discover that the difference of two independent $\textrm{Gumbel}(0,1)$ RVs is a logistic RV (specifically $\textrm{Logistic}(0,1)$).

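It turns out that $\textrm{logit}(\textrm{Uniform}(0,1))$ is a $\textrm{Logistic}(0,1)$ RV, because logit is the inverse CDF of the $\textrm{Logistic}(0,1)$ distribution (i.e., this is inverse transform sampling).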

Voila! Our fast sampling trick and the Gumbel max trick are connected!
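To check the connection numerically, the sketch below draws Bernoulli samples for a single, arbitrarily chosen $x$ in three ways: the slow threshold $\textrm{sigmoid}(x) > u$, the fast threshold $x > \textrm{logit}(u)$, and the Gumbel form $x > Z_0 - Z_1$. All three empirical means should be close to $\textrm{sigmoid}(x)$. It assumes NumPy's Generator API (uniform, gumbel) and SciPy's expit and logit.

```python
import numpy as np
from scipy.special import expit as sigmoid, logit

rng = np.random.default_rng(0)
x = 0.7          # arbitrary sigmoid parameter
m = 200_000      # number of samples per method

u = rng.uniform(0, 1, size=m)
slow = sigmoid(x) > u               # threshold the probability
fast = x > logit(u)                 # threshold a Logistic(0,1) sample
z0 = rng.gumbel(0, 1, size=m)
z1 = rng.gumbel(0, 1, size=m)
gumbel = x > z0 - z1                # Gumbel max trick with logits (0, x)

print(sigmoid(x), slow.mean(), fast.mean(), gumbel.mean())
```

The slow and fast columns use the same u, so they agree sample by sample; the Gumbel column uses independent noise and only agrees in distribution.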

Another trick is Justin Domke's trick to reduce calls to $\exp$ by $\approx 88\%$. The disadvantage of this approach is that it's harder to implement with vectorization. The advantage is that we don't need to precompute any expensive things.


Logit plot

```python
xs = np.linspace(0,1,100)
ys = logit(xs)
pl.plot(xs, ys);
```

Logistic random variable

Check that our sampling method is equivalent to sampling from a logistic distribution.

```python
from scipy.stats import logistic
u = uniform(0,1,size=10000)
z = logit(u)
pl.hist(z, bins=100, density=True)   # `normed` was removed from matplotlib; use `density`
xs = np.linspace(-6,6,100)
ys = logistic.pdf(xs)
pl.plot(xs, ys, c='r', lw=2);
```

Sqrt-biased sampling

The following post is about an instance of "sampling in proportion to \(p\) is not optimal, but you probably think it is." It's surprising how few people seem to know this trick. Myself included! It was brought to my attention recently by Nikos Karampatziakis. (Thanks, Nikos!)

The paper credited for this trick is Press (2008). I'm borrowing heavily from that paper as well as an email exchange with Nikos.

Setting: Suppose you're an aspiring chef with a severe head injury affecting your long- and short-term memory. You're trying to find a special recipe in a cookbook; you made it one time but just can't remember exactly which recipe it was. So, based on the ingredients of each recipe, you come up with a prior probability \(p_i\) that recipe \(i\) is the one you're looking for. In total, the cookbook has \(n\) recipes and \(\sum_{i=1}^n p_i = 1.\)

A good strategy would be to sort recipes by \(p_i\) and cook the most promising ones first. Unfortunately, you're not a great chef, so there is some probability that you'll mess up the recipe. So, it's a good idea to try recipes multiple times. Also, you have no short-term memory...

This suggests a sampling with replacement strategy, where we sample a recipe from the cookbook to try independently of whether we've tried it before (called a memoryless strategy). Let's give this strategy the name \(\boldsymbol{q}.\) Note that \(\boldsymbol{q}\) is a probability distribution over the recipes in the cookbook, just like \(\boldsymbol{p}.\)

How many recipes until we find the special one? To start, suppose the special recipe is \(j.\) Then, the expected number of recipes we have to make until we find \(j\) under the strategy \(\boldsymbol{q}\) is

$$ \sum_{t=1}^\infty t \cdot (1 - q_j)^{t-1} q_{j} = 1/q_{j}. $$

Each term of the sum is the probability that we first sample \(j\) at step \(t\): the probability that we didn't sample it in the first \((t-1)\) steps times the probability that we sample it at step \(t.\) Multiplying each such probability by \(t\) and summing over \(t\) gives the expected time.
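The \(1/q_j\) formula can be checked by simulation. In this sketch (assuming NumPy; the strategy \(\boldsymbol{q}\) over three recipes and the choice \(j = 2\) are arbitrary), we repeatedly draw recipes from \(\boldsymbol{q}\) with replacement and count the draws until \(j\) first comes up.

```python
import numpy as np

rng = np.random.default_rng(0)
q = np.array([0.5, 0.3, 0.2])   # an arbitrary memoryless strategy over 3 recipes
j = 2                           # suppose recipe j is the special one

def time_to_find(j, q, rng):
    "Number of memoryless draws from q until recipe j first comes up."
    t = 1
    while rng.choice(len(q), p=q) != j:
        t += 1
    return t

times = [time_to_find(j, q, rng) for _ in range(10_000)]
print(np.mean(times), 1 / q[j])   # both should be close to 5
```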

Note that this equation assumes that we know \(j\) is the special recipe with certainty when we sample it. We'll revisit this assumption later when we consider potential errors in executing the recipe.

Since we don't know which \(j\) is the right one, we take an expectation over it according to the prior distribution, which yields the following equation,

$$ f(\boldsymbol{q}) = \sum_{i=1}^n \frac{p_i}{q_i}. $$

The first surprising thing: Uniform is just as good as \(\boldsymbol{p}\), yikes! \(f(\boldsymbol{p}) = \sum_{i=1}^n \frac{p_i}{p_i} = n\) and \(f(\text{uniform}(n)) = \sum_{i=1}^n \frac{p_i }{ 1/n } = n.\) (Assume, without loss of generality, that \(p_i > 0\), since recipes with \(p_i = 0\) can simply be dropped from \(\boldsymbol{p}.\))

What's the optimal \(\boldsymbol{q}\)? We can address this question by solving the following optimization (which will have a nice closed form solution),

$$ \begin{eqnarray*} && \boldsymbol{q}^* = \underset{\boldsymbol{q}}{\operatorname{argmin}} \sum_{i=1}^n \frac{p_i}{q_i} \\ && \ \ \ \ \ \ \ \ \text{ s.t. } \sum_{i=1}^n q_i = 1 \\ && \ \ \ \ \ \ \ \ \ \ \ \ \, q_1 \ldots q_n \ge 0. \end{eqnarray*} $$

The optimization problem says minimize the expected time to find the special recipe. The constraints enforce that \(\boldsymbol{q}\) be a valid probability distribution.

The optimal strategy, which we get via Lagrange multipliers, turns out to be,

$$ q^*_i = \frac{ \sqrt{p_i} }{ \sum_{j=1}^n \sqrt{p_j} }. $$
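For completeness, here is the Lagrange-multiplier step behind this solution, written out as a standard derivation rather than quoted from Press (2008). Ignoring the nonnegativity constraints (they turn out to be inactive when every \(p_i > 0\)), we form the Lagrangian and set its partial derivatives to zero:

$$ \mathcal{L}(\boldsymbol{q}, \lambda) = \sum_{i=1}^n \frac{p_i}{q_i} + \lambda \left( \sum_{i=1}^n q_i - 1 \right), \qquad \frac{\partial \mathcal{L}}{\partial q_i} = -\frac{p_i}{q_i^2} + \lambda = 0 \ \Rightarrow\ q_i = \frac{\sqrt{p_i}}{\sqrt{\lambda}}. $$

The constraint \(\sum_i q_i = 1\) then forces \(\sqrt{\lambda} = \sum_j \sqrt{p_j}\), which gives \(q^*_i\) above. Since each term \(p_i / q_i\) is convex for \(q_i > 0\), this stationary point is the global minimum.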

How much better is \(q^*\)?

$$ f(q^*) = \sum_i \frac{p_i}{q^*_i} = \sum_i \frac{p_i}{ \frac{\sqrt{p_i} }{ \sum_j \sqrt{p_j}} } = \left( \sum_i \frac{p_i}{ \sqrt{p_i} } \right) \left( \sum_j \sqrt{p_j} \right) = \left( \sum_i \sqrt{p_i} \right)^2 $$

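which equals \(n\) when \(\boldsymbol{p}\) is uniform but is never bigger than \(n\): by the Cauchy–Schwarz inequality, \(\left( \sum_i \sqrt{p_i} \cdot 1 \right)^2 \le \left( \sum_i p_i \right) \left( \sum_i 1 \right) = n\), with equality exactly when \(\boldsymbol{p}\) is uniform.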
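A small numerical check of these claims, assuming NumPy and an arbitrary randomly drawn prior \(\boldsymbol{p}\) over \(n = 10\) recipes. The helper f just transcribes the objective above.

```python
import numpy as np

def f(p, q):
    "Expected number of memoryless draws from q until the special recipe (prior p) comes up."
    return np.sum(p / q)

rng = np.random.default_rng(0)
n = 10
p = rng.dirichlet(np.ones(n))             # arbitrary prior over n recipes
q_unif = np.full(n, 1 / n)                # uniform strategy
q_star = np.sqrt(p) / np.sqrt(p).sum()    # sqrt-biased strategy

print(f(p, p), f(p, q_unif))              # both equal n = 10 (up to rounding)
print(f(p, q_star), np.sqrt(p).sum()**2)  # equal to each other, and below n
```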

What's the intuition? The \(\sqrt{p}\)-scheme is preferred because it saves on redundant cooking experiments. For example, if a recipe has \(k\) times higher prior probability than the average recipe, then we will try that recipe \(\sqrt{k}\) times more often, compared to \(k\) times more often under \(\boldsymbol{p}.\) Those additional cooking experiments on an already-favored recipe are not so advantageous.

Allowing for noise in the cooking process: Suppose that for each recipe we had a prior belief about how hard that recipe is for us to cook. Denote that belief \(s_i\); these beliefs are between zero (never get it right) and one (perfect every time) and need not sum to one over the cookbook.

Following a similar derivation to before, the expected time until we both sample the special recipe \(j\) and cook it correctly is,

$$ \sum_{t=1}^\infty t \cdot (1 - \color{red}{s_j} q_j)^{t-1} q_{j} \color{red}{s_j} = \frac{1}{s_j \cdot q_j} $$

That gives rise to a modified objective,

$$ f'(\boldsymbol{q}) = \sum_{i=1}^n \frac{p_i}{\color{red}{s_i} \cdot q_i} $$

This is exactly the same as the previous objective, except we've replaced \(p_i\) with \(p_i/s_i.\) Thus, we can reuse our previous derivation to get the optimal strategy, \(q^*_i \propto \sqrt{p_i / s_i}.\) If every recipe has the same success rate (\(s_i = s\)), then we recover the original solution, \(q^*_i \propto \sqrt{p_i}.\)
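The same kind of check extends to the noisy objective \(f'\). In this sketch (assuming NumPy), the prior and the success probabilities \(s_i \sim \textrm{Uniform}(0.1, 1)\) are arbitrary draws, and we compare the noise-aware strategy \(q_i \propto \sqrt{p_i / s_i}\) against the plain \(q_i \propto \sqrt{p_i}\) strategy.

```python
import numpy as np

def f_noisy(p, s, q):
    "Expected draws until the special recipe is both sampled and cooked correctly."
    return np.sum(p / (s * q))

def sqrt_strategy(weights):
    "Normalize sqrt(weights) into a sampling distribution."
    r = np.sqrt(weights)
    return r / r.sum()

rng = np.random.default_rng(1)
n = 10
p = rng.dirichlet(np.ones(n))       # prior over recipes
s = rng.uniform(0.1, 1.0, size=n)   # per-recipe success probabilities

q_aware = sqrt_strategy(p / s)      # accounts for noise
q_plain = sqrt_strategy(p)          # ignores noise
print(f_noisy(p, s, q_aware), f_noisy(p, s, q_plain))  # the first is never larger
```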

Extension to finding multiple tasty recipes: Suppose we're trying to find several tasty recipes, not just a single special one. Now, \(p_i\) is our prior belief that we'll like the recipe at all. How do we minimize the time until we find a tasty one? It turns out the same trick works without modification because all derivations apply to each recipe independently. In particular, it still works when \(p_i\) does not sum to one over the \(n\) recipes, e.g., when \(p_i\) is the independent probability that you'll like recipe \(i\) at all, not the probability that it's the special one.

Beyond memoryless policies: Clearly, our choice of a memoryless policy can be beaten by a policy family that balances exploration (trying new recipes) and exploitation (trying our best guess).

  • Overall, the problem we've posed is similar to a multi-armed bandit. In our case, the arms are the recipes, pulling the arm is trying the recipe and the reward is whether or not we liked the recipe (possibly noisy). The key difference between our setup and multi-armed bandits is that we trust our prior distribution \(\boldsymbol{p}\) and noise model \(\boldsymbol{s}.\)

  • If the amount of noise \(s_i\) is known and we trust the prior \(p_i\) then there is an optimal deterministic (without-replacement) strategy that we can get by sorting the recipes by \(p_i\) accounting for the error rates \(s_i.\) This approach is described in the original paper.

A more realistic application: In certain language modeling applications, we avoid computing normalization constants (which require summing over a massive vocabulary) by using importance sampling, negative sampling, or noise contrastive estimation techniques (e.g., Ji+,16; Levy+,15). These techniques depend on a proposal distribution, which folks often take to be the unigram distribution. Unfortunately, this gives too many samples of stop words (e.g., "the", "an", "a"), so practitioners "anneal" the unigram distribution (to increase the entropy), that is, sample from \(q_i \propto p_{\text{unigram},i}^\alpha.\) Typically, \(\alpha\) is set by grid search and (no surprise) \(\alpha \approx 1/2\) tends to work best! The \(\sqrt{p}\)-sampling trick is possibly a reverse-engineered justification in favor of annealing as "the right thing to do" (e.g., why not do additive smoothing?), and it even tells us how to set the annealing parameter \(\alpha.\) The key assumption is that we want to sample the actual word at a given position as often as possible while still being diverse thanks to the coverage of the unigram prior. (Furthermore, memoryless sampling leads to simpler algorithms.)
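Under the memoryless objective \(f\), the annealed proposal \(q_i \propto p_i^\alpha\) can be scored directly. The sketch below assumes NumPy and uses a made-up Zipf-like \(\boldsymbol{p}\) over a 1,000-word vocabulary (an illustrative assumption, not real unigram counts) and scans \(\alpha\) on a grid. Since \(\alpha = 1/2\) reproduces \(\boldsymbol{q}^*\), it should come out on top, while \(\alpha = 0\) (uniform) and \(\alpha = 1\) (unigram) both score exactly the vocabulary size.

```python
import numpy as np

V = 1000
ranks = np.arange(1, V + 1)
p = 1.0 / ranks                  # Zipf-like "unigram" distribution (hypothetical)
p /= p.sum()

def f(p, q):
    "Expected draws from q until the target word (prior p) comes up."
    return np.sum(p / q)

alphas = np.linspace(0, 1, 11)   # 0.0, 0.1, ..., 1.0
scores = []
for a in alphas:
    q = p**a / np.sum(p**a)      # annealed proposal
    scores.append(f(p, q))

best = alphas[int(np.argmin(scores))]
print(best)                      # 0.5
```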