Graduate Descent

How to test gradient implementations

Setup: Suppose we have a function, \(f: \mathbb{R}^n \rightarrow \mathbb{R}\), and we want to test code that computes \(\nabla f\). (Note that these techniques also apply when \(f\) has multivariate output.)

Finite-difference approximation

The main way that people test gradient computation is by comparing it against a finite-difference (FD) approximation to the gradient:

$$ \boldsymbol{d}^\top\! \nabla f(\boldsymbol{x}) \approx \frac{1}{2 \varepsilon}(f(\boldsymbol{x} + \varepsilon \cdot \boldsymbol{d}) - f(\boldsymbol{x} - \varepsilon \cdot \boldsymbol{d})) $$

where \(\boldsymbol{d} \in \mathbb{R}^n\) is an arbitrary "direction" in parameter space. We will look at many directions when we test. Generally, people take the \(n\) elementary vectors as the directions, but random directions are just as good (and you can catch bugs in all dimensions with fewer than \(n\) of them).

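Always use the two-sided difference formula. There is a one-sided version, which only adds (or only subtracts) \(\varepsilon \cdot \boldsymbol{d}\) and compares against \(f(\boldsymbol{x})\). Never use it: its truncation error is \(\mathcal{O}(\varepsilon)\), whereas the two-sided formula's is \(\mathcal{O}(\varepsilon^2)\).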
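A quick way to see the difference is to compare both formulas against an exact directional derivative. The sketch below assumes NumPy and a made-up smooth test function `f` with a hand-derived gradient `grad_f`; the helper names `fd_two_sided` and `fd_one_sided` are illustrative, not from any library. Shrinking \(\varepsilon\) by 10x should shrink the two-sided error by roughly 100x but the one-sided error by only about 10x.

```python
import numpy as np

def f(x):
    # Made-up smooth test function with a known gradient.
    return np.sum(np.sin(x)) + np.dot(x, x)

def grad_f(x):
    return np.cos(x) + 2 * x

def fd_two_sided(f, x, d, eps):
    # Central difference: truncation error O(eps**2).
    return (f(x + eps * d) - f(x - eps * d)) / (2 * eps)

def fd_one_sided(f, x, d, eps):
    # Forward difference: truncation error O(eps).
    return (f(x + eps * d) - f(x)) / eps

rng = np.random.default_rng(0)
x = rng.standard_normal(5)
d = rng.standard_normal(5)
exact = d @ grad_f(x)

for eps in [1e-2, 1e-3, 1e-4]:
    err_two = abs(fd_two_sided(f, x, d, eps) - exact)
    err_one = abs(fd_one_sided(f, x, d, eps) - exact)
    print(f"eps={eps:.0e}  two-sided={err_two:.1e}  one-sided={err_one:.1e}")
```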

Make sure you test multiple inputs (values of \(\boldsymbol{x}\)) and anything else the function depends on (e.g., the minibatch).

What directions to use: When debugging, I tend to use elementary directions because they tell me something about which dimensions are wrong... this doesn't always help, though. Random directions are best when you want the test cases to run really quickly. In that case, you can switch to checking a few random directions drawn from a spherical distribution; do not sample them from a multivariate uniform!
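Here is a minimal sketch of checking a few random directions, assuming NumPy. Normalizing a standard Gaussian vector gives a direction that is uniformly distributed on the unit sphere; normalizing a sample from a uniform distribution on a cube would not (it over-represents directions toward the cube's corners). The quadratic `f`, the deliberately buggy `bad_grad`, and the helper names are made up for illustration.

```python
import numpy as np

def random_direction(n, rng):
    # A normalized standard Gaussian is uniform on the unit sphere.
    d = rng.standard_normal(n)
    return d / np.linalg.norm(d)

def check_random_directions(f, grad, x, num_dirs=5, eps=1e-5, seed=0):
    """Return (fd, ad): FD and analytic directional derivatives along random directions."""
    rng = np.random.default_rng(seed)
    g = grad(x)
    fd, ad = [], []
    for _ in range(num_dirs):
        d = random_direction(x.shape[0], rng)
        fd.append((f(x + eps * d) - f(x - eps * d)) / (2 * eps))
        ad.append(d @ g)
    return np.array(fd), np.array(ad)

# Made-up example: f(x) = 0.5 x^T A x with symmetric A, so grad f(x) = A x.
A = np.array([[3.0, 1.0, 0.0],
              [1.0, 2.0, 1.0],
              [0.0, 1.0, 4.0]])
f = lambda x: 0.5 * x @ A @ x
good_grad = lambda x: A @ x
bad_grad = lambda x: np.diag(A) * x   # bug: ignores the off-diagonal terms

x0 = np.array([1.0, -2.0, 0.5])
print(np.allclose(*check_random_directions(f, good_grad, x0)))
print(np.allclose(*check_random_directions(f, bad_grad, x0)))
```

Five directions are enough here to expose the bug in every coordinate at once, which is the point of using random directions instead of all \(n\) elementary ones.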

Always test your implementation of \(f\)! It's very easy to correctly compute the gradient of the wrong function. The FD approximation is a "self-consistency" test: it does not validate \(f\), only the relationship between \(f\) and \(\nabla\! f\).

Obviously, how you test \(f\) depends strongly on what it's supposed to compute.

  • Example: For a conditional random field (CRF), you can also test that your dynamic program for computing \(\log Z_\theta(x)\) is correct by comparing it against brute-force enumeration of \(\mathcal{Y}(x)\) on small examples.

Similarly, you can directly test the gradient code if you know a different way to compute it.

  • Example: In a CRF, we know that \(\nabla \log Z_\theta(x)\) is a feature expectation, which you can also test against brute-force enumeration on small examples.
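To make both CRF examples concrete, here is a sketch for a tiny linear-chain model, assuming NumPy only. The scoring model (unary scores `U[t, y]` plus transition scores `W[y_prev, y]`) and all function names are my own illustrative choices. The gradient of \(\log Z\) with respect to a unary score `U[t, k]` is the expected value of that indicator feature, i.e., the marginal \(P(y_t = k)\), which brute force computes directly. Brute force iterates over all \(K^T\) label sequences, so keep \(T\) and \(K\) tiny.

```python
import itertools
import numpy as np

def lse(a, axis=0):
    # Numerically stable log(sum(exp(a))) along `axis`.
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m + np.log(np.sum(np.exp(a - m), axis=axis, keepdims=True)), axis=axis)

def log_z_forward(U, W):
    """log Z by the forward algorithm, O(T K^2)."""
    alpha = U[0].copy()                  # alpha[k]: log-mass of prefixes ending in label k
    for t in range(1, U.shape[0]):
        alpha = lse(alpha[:, None] + W, axis=0) + U[t]   # sum out the previous label
    return float(lse(alpha, axis=0))

def score(y, U, W):
    return (sum(U[t, y[t]] for t in range(len(y)))
            + sum(W[y[t - 1], y[t]] for t in range(1, len(y))))

def brute_force(U, W):
    """log Z and unary marginals P(y_t = k) by enumerating all K**T label sequences."""
    T, K = U.shape
    ys = list(itertools.product(range(K), repeat=T))
    scores = np.array([score(y, U, W) for y in ys])
    log_z = float(lse(scores, axis=0))
    marginals = np.zeros((T, K))
    for p, y in zip(np.exp(scores - log_z), ys):
        marginals[np.arange(T), y] += p
    return log_z, marginals

def fd_grad_unary(U, W, eps=1e-5):
    """Central-difference gradient of log_z_forward with respect to U."""
    g = np.zeros_like(U)
    for t in range(U.shape[0]):
        for k in range(U.shape[1]):
            E = np.zeros_like(U)
            E[t, k] = eps
            g[t, k] = (log_z_forward(U + E, W) - log_z_forward(U - E, W)) / (2 * eps)
    return g

rng = np.random.default_rng(0)
U = rng.standard_normal((4, 3))     # T = 4 positions, K = 3 labels
W = rng.standard_normal((3, 3))
log_z_bf, marg = brute_force(U, W)
print(np.isclose(log_z_forward(U, W), log_z_bf))          # DP vs. brute force
print(np.allclose(fd_grad_unary(U, W), marg, atol=1e-6))  # gradient = feature expectation
```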

Why not just use the FD approximation as your gradient?

For low-dimensional functions, you can straight-up use the finite-difference approximation instead of rolling code to compute the gradient. (Take \(n\) axis-aligned unit vectors for \(\boldsymbol{d}\).) The FD approximation is very accurate. Of course, specialized code is probably a little more accurate, but that's not really why we bother to do it! The reason we write specialized gradient code is not to improve numerical accuracy; it's to improve efficiency. As I've ranted before, automatic differentiation techniques guarantee that evaluating the gradient \(\nabla f(x)\) costs only a small constant factor more than computing \(f(x)\) (with the caveat that space complexity may increase substantially, i.e., space-time tradeoffs exist). FD is \(\mathcal{O}(n \cdot \textrm{runtime } f(x))\), whereas (reverse-mode) autodiff is \(\mathcal{O}(\textrm{runtime } f(x))\).
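For small \(n\), the full FD gradient is only a few lines. The sketch below (NumPy assumed; the name `fd_gradient` is hypothetical) builds it from the \(n\) axis-aligned directions, which makes the \(\mathcal{O}(n \cdot \textrm{runtime } f(x))\) cost explicit: two evaluations of \(f\) per dimension.

```python
import numpy as np

def fd_gradient(f, x, eps=1e-5):
    """Central-difference gradient along the n axis-aligned unit vectors.

    Needs 2n evaluations of f, so its cost is O(n * runtime of f).
    """
    x = np.asarray(x, dtype=float)
    g = np.zeros_like(x)
    for i in range(x.shape[0]):
        e = np.zeros_like(x)
        e[i] = eps
        g[i] = (f(x + e) - f(x - e)) / (2 * eps)
    return g

# Example: the gradient of log-sum-exp is the softmax.
f = lambda x: np.log(np.sum(np.exp(x)))
x = np.array([0.5, -1.0, 2.0])
softmax = np.exp(x) / np.sum(np.exp(x))
print(np.max(np.abs(fd_gradient(f, x) - softmax)))
```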

How to compare vectors

Absolute difference is the devil. You should never compare vectors by absolute difference (this is Lecture 1 of any numerical methods course). In this case, the problem is that gradients depend strongly on the scale of \(f\): if \(f\) takes tiny values, then it's easy for differences to fall below a tiny threshold even when the gradient is wrong.

Most people use relative error \(= \frac{|\textbf{want} - \textbf{got}|}{|\textbf{want}|}\), to get a scale-free error measure, but unfortunately relative error chokes when \(\textbf{want}\) is zero.

I compute several error measures with a script, fdcheck from arsenal.math.checkgrad, which you can import from my GitHub.

I use two metrics to test gradients:

  1. Relative error (skipping zeros): Wherever the reference value is zero, relative error is undefined, so I skip that entry and rely on the other measure.

  2. Pearson correlation: Checks the direction of the gradient, but is invariant to scale and shift transformations. This measure doesn't have trouble with zeros, but it lets scale and shift bugs pass undetected, so make sure you fix those errors! (E.g., in the CRF example, you might have forgotten to divide by \(Z(x)\), which is not really a constant... I've made this exact mistake a few times.)

I also look at some diagnostics, which help me debug stuff:

  • Accuracy of predicting the sign {+,-,0} of each dimension (or of each random dot product).

  • Absolute error (just as a diagnostic)

  • Scatter plot: When debugging, I like to scatter plot the elements of FD vs. my implementation.

All these measurements (and the scatter plot) can be computed with compare, which I find super useful when debugging absolutely anything numerical.
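The measures above fit in a few lines of NumPy. The sketch below is a hypothetical stand-in, not the arsenal code: the name `compare_sketch`, the zero threshold `tol`, and the returned keys are my own choices. The usage shows why both metrics matter: a gradient that is off by a constant factor has perfect Pearson correlation but a large relative error.

```python
import numpy as np

def compare_sketch(want, got, tol=1e-10):
    """Hypothetical helper: error measures between reference `want` and test `got`."""
    want = np.asarray(want, dtype=float)
    got = np.asarray(got, dtype=float)

    def sign(v):
        # Three-way sign {+1, -1, 0}, treating tiny magnitudes as zero.
        return np.where(np.abs(v) <= tol, 0.0, np.sign(v))

    nonzero = np.abs(want) > tol                  # relative error is undefined at zeros
    rel = np.abs(want[nonzero] - got[nonzero]) / np.abs(want[nonzero])
    return {
        "max_rel_err": rel.max() if rel.size else 0.0,
        "pearson": np.corrcoef(want, got)[0, 1],  # invariant to scale and shift
        "sign_acc": np.mean(sign(want) == sign(got)),
        "max_abs_err": np.abs(want - got).max(),  # diagnostic only
    }

want = np.array([0.5, -1.0, 0.0, 2.0])
scaled = 3 * want          # e.g., forgot to divide by a normalizer
for key, value in compare_sketch(want, scaled).items():
    print(key, value)
```

Note that `np.corrcoef` returns `nan` if either vector is constant, so a real tool would want to special-case that.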

Bonus tests

Testing modules: You can test the different modules of your code as well (assuming you have a composable module-based setup). E.g., I test my DP algorithm independently of how the features and downstream loss are computed. You can also test feature and downstream loss modules independently of one another. Note that reverse-mode autodiff (implicitly) computes vector-Jacobian products, because modules are multivariate in general. We can reduce to the scalar case by taking a dot product of the outputs with a (fixed) random vector.

Something like this:

r = spherical(m)  # fixed random unit vector; m is the module's output dimension
h = lambda x: module.fprop(x).dot(r)   # scalar function for use in fd

module.fprop(x)  # propagate
module.outputs.adjoint = r  # set output adjoint to r (for a scalar output we usually set it to 1)
module.bprop()   # back-propagate to fill in the input adjoint (method name assumed)
ad = module.input.adjoint  # grab the gradient, i.e., the vector-Jacobian product r^T J
fd = fdgrad(h, x)
compare(fd, ad)
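The pseudocode above depends on a module API I haven't shown, so here is a self-contained sketch under my own assumptions (NumPy assumed): a hypothetical `TanhLayer` computes \(\boldsymbol{y} = \tanh(W \boldsymbol{x})\), and its `bprop` maps an output adjoint \(\boldsymbol{r}\) to the input adjoint \(J^\top \boldsymbol{r}\). The FD side only ever sees the scalar function \(h(\boldsymbol{x}) = \boldsymbol{r}^\top \tanh(W\boldsymbol{x})\).

```python
import numpy as np

class TanhLayer:
    """Hypothetical module: y = tanh(W x), with a reverse-mode backward pass."""
    def __init__(self, W):
        self.W = W

    def fprop(self, x):
        self.y = np.tanh(self.W @ x)
        return self.y

    def bprop(self, output_adjoint):
        # Input adjoint = J^T r, where J = diag(1 - y**2) W is the Jacobian at the last x.
        return self.W.T @ ((1 - self.y ** 2) * output_adjoint)

def fd_gradient(h, x, eps=1e-5):
    # Central differences along each coordinate of x.
    g = np.zeros_like(x)
    for i in range(x.shape[0]):
        e = np.zeros_like(x)
        e[i] = eps
        g[i] = (h(x + e) - h(x - e)) / (2 * eps)
    return g

rng = np.random.default_rng(0)
n, m = 4, 3
module = TanhLayer(rng.standard_normal((m, n)))
x = rng.standard_normal(n)

r = rng.standard_normal(m)
r /= np.linalg.norm(r)                 # fixed random direction on the output sphere
h = lambda z: module.fprop(z) @ r      # scalar function for FD

module.fprop(x)                        # forward pass at x
ad = module.bprop(r)                   # vector-Jacobian product r^T J
fd = fd_gradient(h, x)
print(np.max(np.abs(fd - ad)))
```

The backward pass is computed before running FD because each call to `h` overwrites the cached `module.y`.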

Integration tests: Test that running a gradient-based optimization algorithm is successful with your gradient implementation. Use smaller versions of your problem if possible. A related test for machine learning applications is to make sure that your model and learning procedure can (over)fit small datasets.

Test that batch = minibatch (if applicable). It's very easy to get this bit wrong. Broadcasting rules (in numpy, for example) make it easy to hide matrix conformability mishaps. So make sure you get the same results as manual minibatching. (Of course, you should only do minibatching if you get a speed-up from vectorization or parallelism. You should probably test that it's actually faster.)
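A minimal sketch of this test, assuming NumPy and a made-up squared-loss model: run the vectorized code on the whole batch, run the same code on small slices, and add up the results. The last check shows the kind of silent broadcasting bug this catches: passing targets as an (m, 1) column makes the residual an (m, m) matrix, and the loss is still a scalar, so nothing crashes.

```python
import numpy as np

def loss_batch(w, X, y):
    # Vectorized squared loss summed over the batch.
    return 0.5 * np.sum((X @ w - y) ** 2)

def grad_batch(w, X, y):
    return X.T @ (X @ w - y)

def minibatched(fn, w, X, y, size=1):
    # Reference: apply the same code to small slices and add up the results.
    return sum(fn(w, X[i:i + size], y[i:i + size]) for i in range(0, X.shape[0], size))

rng = np.random.default_rng(0)
X = rng.standard_normal((5, 3))
y = rng.standard_normal(5)
w = rng.standard_normal(3)

print(np.isclose(loss_batch(w, X, y), minibatched(loss_batch, w, X, y)))
print(np.allclose(grad_batch(w, X, y), minibatched(grad_batch, w, X, y)))

# A conformability mishap: y as an (m, 1) column broadcasts against the (m,)
# predictions, so the full-batch residual silently becomes an (m, m) matrix.
y_col = y[:, None]
print(np.isclose(loss_batch(w, X, y_col), minibatched(loss_batch, w, X, y_col)))
```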

Further reading: I've written about gradient approximations before; you might like these articles: gradient-vector products, complex-step method. I strongly recommend learning how automatic differentiation works; I learned it from Justin Domke's course notes.

Counterfactual reasoning and learning from logged data

Counterfactual reasoning is reasoning about data that we did not observe. For example, reasoning about the expected reward of a new policy given data collected from an older one.

In this post, I'll discuss some basic techniques for learning from logged data. For the most part, this post is based on things I learned from Peshkin & Shelton (2002) and Bottou et al. (2013) (two of my all-time favorite papers).

After reading, have a look at the Jupyter notebook accompanying this post!

Setup (off-line off-policy optimization): We're trying to optimize a function of the form,

$$ J(\theta) = \underset{p_\theta}{\mathbb{E}} \left[ r(x) \right] = \sum_{x \in \mathcal{X}} p_\theta(x) r(x). $$

But! We only have a fixed sample of size \(m\) from a data collection policy \(q\), \(\{ (r^{(j)}, x^{(j)} ) \}_{j=1}^m \overset{\text{i.i.d.}} \sim q.\)

  • Although it's not necessarily the case, you can think of \(q\) as \(p_{\theta'}\) for some fixed value \(\theta'.\)

  • \(\mathcal{X}\) is an arbitrary multivariate space, which permits a mix of continuous and discrete components, with appropriate densities, \(p_{\theta}\) and \(q\) defined over it.

  • \(r: \mathcal{X} \to \mathbb{R}\) is a black box that outputs a scalar score.

  • I've used the notation \(r^{(j)}\) instead of \(r(x^{(j)})\) to emphasize that we can't evaluate \(r\) at \(x\) values other than those in the sample.

  • We'll assume that \(q\) assigns positive probability everywhere, \(q(x) > 0\) for all \(x \in \mathcal{X}\). This means that the data collection process must be randomized and eventually sample all possible configurations. Later, I discuss relaxing this assumption.

Each distribution is a product of one or more factors of the following types: policy factors (at least one), which directly depend on \(\theta\), and environment factors (possibly none), which do not depend directly on \(\theta\). Note that environment factors are only accessible via sampling (i.e., we don't know the value they assign to a sample). For example, consider a contextual bandit problem, where \(x\) is a context-action pair, \(x = (s,a)\). Here \(q(x) = q(a|s) p(s)\) and \(p_{\theta}(x) = p_{\theta}(a|s) p(s)\). Note that \(p_{\theta}\) and \(q\) share the environment factor \(p(s)\), the distribution over contexts, and differ only in the action-given-context factor. For now, assume that we can evaluate all environment factors; later, I'll discuss how we cleverly work around this assumption.

The main challenge of this setting is that we don't have controlled experiments to learn from, because \(q\) is not (completely) in our control. This manifests itself as high variance ("noise") in estimating \(J(\theta)\). Consider the contextual bandit setting: we receive a context \(s\) and execute a single action; we never get to roll back to that precise context and try an alternative action (to get a paired sample #yolo) because we do not control \(p(s)\). This is an important paradigm for many 'real world' problems, e.g., predicting medical treatments or ad selection.

Estimating \(J(\theta)\) [V1]: We obtain an unbiased estimator of \(J(\theta)\) with importance sampling,

$$ J(\theta) \approx \hat{J}_{\!\text{IS}}(\theta) = \frac{1}{m} \sum_{j=1}^m r^{(j)} \!\cdot\! w^{(j)}_{\theta} \quad \text{ where } w^{(j)}_{\theta} = \frac{p_{\theta}(x^{(j)}) }{ q(x^{(j)}) }. $$

This estimator is remarkable: it uses importance sampling as a function approximator! We have an unbiased estimate of \(J(\theta)\) for any value of \(\theta\) that we like. The catch is that we have to pick \(\theta\) a priori, i.e., with no knowledge of the sample.
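Here is a small simulation of V1 on a made-up contextual bandit, assuming NumPy. Everything (the context distribution `p_s`, the mean rewards `R`, the uniform logging policy `q`, and a softmax target policy) is invented for illustration. \(\theta\) is drawn independently of the logged sample, matching the caveat above, and the weights use only the action probabilities because the shared \(p(s)\) factor cancels in the ratio.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
n_ctx, n_act, m = 3, 4, 200_000
p_s = np.array([0.5, 0.3, 0.2])           # environment factor p(s)
R = rng.uniform(size=(n_ctx, n_act))      # mean reward for each (s, a)
q = np.full((n_ctx, n_act), 1 / n_act)    # logging policy q(a|s), uniform here

# Logged data: s ~ p(s), a ~ q(a|s), r ~ Bernoulli(R[s, a]).
s = rng.choice(n_ctx, size=m, p=p_s)
a = rng.integers(n_act, size=m)           # valid because q is uniform
r = rng.binomial(1, R[s, a])

def J_true(theta):
    # Exact J(theta) = sum_s p(s) sum_a p_theta(a|s) R[s, a]; not available in practice.
    return np.sum(p_s[:, None] * softmax(theta) * R)

def J_is(theta):
    # Importance weights p_theta(x) / q(x); the p(s) factors cancel.
    w = softmax(theta)[s, a] / q[s, a]
    return np.mean(r * w)

theta = rng.standard_normal((n_ctx, n_act))   # chosen without looking at the sample
print(J_true(theta), J_is(theta))
```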

After we've collected a (large) sample, it's possible to optimize \(\hat{J}_{\!\text{IS}}\) using any optimization algorithm (e.g., L-BFGS). Of course, we risk overfitting to the sample if we optimize \(\hat{J}_{\!\text{IS}}\). Actually, it's a bit worse: this objective tends to favor regions of \(\theta\) that are not well represented in the sample, because the importance sampling estimator has high variance in these regions due to large importance weights (when \(q(x)\) is small and \(p_{\theta}(x)\) is large, \(w(x)\) is large and consequently so is \(\hat{J}_{\!\text{IS}}\), regardless of whether \(r(x)\) is high!). Thus, we want some type of "regularization" to keep the optimizer in regions that are sufficiently well represented by the sample.

Better surrogate [V2]: There are many ways to reduce the variance of the estimator and confidently obtain improvements to the system. One of my favorites is Bottou et al.'s lower bound on \(J(\theta)\), which we get by clipping the importance weights: replace \(w^{(j)}_{\theta}\) with \(\min(R, w^{(j)}_{\theta})\). (With nonnegative rewards, clipping can only shrink each term, which is why the result is a lower bound.)
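A sketch of the clipped estimator, assuming NumPy and synthetic data: the rewards are uniform on \([0, 1]\), and the weights are log-normal with mean one, mimicking heavy-tailed importance weights with \(\mathbb{E}_q[w] = 1\). With nonnegative rewards, the estimate on any fixed sample is nondecreasing in \(R\) and matches V1 as \(R \to \infty\).

```python
import numpy as np

def j_is(r, w):
    # V1: plain importance sampling.
    return np.mean(r * w)

def j_clipped(r, w, R):
    # V2: replace each weight w with min(R, w).
    return np.mean(r * np.minimum(R, w))

rng = np.random.default_rng(0)
m = 1000
r = rng.uniform(size=m)                                # nonnegative rewards
w = rng.lognormal(mean=-0.5, sigma=1.0, size=m)        # heavy-tailed weights with E[w] = 1
for R in [1.0, 5.0, 50.0, np.inf]:
    print(R, j_clipped(r, w, R))
print("unclipped", j_is(r, w))
```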

Confidence intervals [V3]: We can augment the V2 lower bound with confidence intervals derived from the empirical Bernstein bound (EBB). We'll require that \(r\) is bounded and that we know its max/min values. The EBB penalizes hypotheses (values of \(\theta\)) that have higher sample variance. (Note: a Hoeffding bound wouldn't change the shape of the objective, but EBB does thanks to the sample variance penalty. EBB also tends to be tighter.) The EBB introduces an additional "confidence" hyperparameter, \((1-\delta)\). Bottou et al. recommend maximizing the lower bound, as it provides safer improvements. See the original paper for the derivation.
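I haven't checked the constants against Bottou et al.'s derivation, so take this as a sketch of the general idea rather than their formula. It uses the empirical Bernstein bound of Maurer and Pontil (2009), rescaled to i.i.d. variables in \([0, b]\): with probability at least \(1-\delta\), \(\mathbb{E}[z] \ge \bar{z} - \sqrt{2 V \ln(2/\delta)/m} - 7 b \ln(2/\delta) / (3(m-1))\), where \(V\) is the unbiased sample variance. For V3 one would take \(z_j = r^{(j)} \min(R, w^{(j)}_\theta)\) and \(b = R \cdot \max r\) when rewards lie in \([0, \max r]\). The example shows the variance penalty at work: two samples with the same mean get different bounds.

```python
import numpy as np

def ebb_lower_bound(z, b, delta):
    """Empirical Bernstein lower confidence bound on E[z], for i.i.d. z in [0, b].

    Holds with probability at least 1 - delta (Maurer & Pontil, 2009, rescaled to [0, b]).
    """
    z = np.asarray(z, dtype=float)
    m = z.shape[0]
    log_term = np.log(2 / delta)
    var = z.var(ddof=1)                      # unbiased sample variance
    return z.mean() - np.sqrt(2 * var * log_term / m) - 7 * b * log_term / (3 * (m - 1))

# Two samples in [0, 1] with the same mean but very different variance.
low_var = np.full(1000, 0.5)
high_var = np.tile([0.0, 1.0], 500)
print(ebb_lower_bound(low_var, b=1.0, delta=0.05))
print(ebb_lower_bound(high_var, b=1.0, delta=0.05))
```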

Both V2 and V3 are biased (as they are lower bounds), but we can mitigate the bias by tuning the hyperparameter \(R\) on a held-out sample (we can even tune \(\delta\), if desired). Additionally, V2 and V3 are 'valid' when \(q\) has limited support, since they prevent the importance weights from exploding (of course, the bias can be arbitrarily bad, but that is probably unavoidable in the learning-only-from-logged-data setup).


Unknown environment factors: Consider the contextual bandit setting (mentioned above). Here \(p_\theta\) and \(q\) share an unknown environment factor: the distribution of contexts. Luckily, we do not need to know the value of this factor in order to apply any of our estimators, because they are all based on likelihood ratios, so the shared unknown factors cancel out! Some specific examples are given below. Of course, these factors do influence the estimators because they are crucial in sampling; they just aren't necessary in evaluation.

  • In the contextual bandit example, \(x\) is a context-action pair, \(w_{\theta}(x) = \frac{p_{\theta}(x)}{q(x)} = \frac{ p_{\theta}(s,a) }{ q(s,a) } = \frac{p_{\theta}(a|s) p(s)}{q(a|s) p(s)} = \frac{p_{\theta}(a|s) }{ q(a|s) }\).

  • In a Markov decision process, \(x\) is a sequence of state-action pairs, \(w_{\theta}(x) = \frac{p_{\theta}(x)}{q(x)} = \frac{ p(s_0) \prod_{t=0}^T p(s_{t+1}|s_t,a_t) p_\theta(a_t|s_t) } { p(s_0) \prod_{t=0}^T p(s_{t+1}|s_t,a_t) q(a_t|s_t) } = \frac{\prod_{t=0}^T p_\theta(a_t|s_t)} {\prod_{t=0}^T q(a_t|s_t)}.\)
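In code, this means the weight can be computed from the policies' action probabilities alone. The sketch below (NumPy assumed; the function name and data are made up) works in log space to avoid underflow on long trajectories, and checks that including arbitrary shared environment terms leaves the weight unchanged. A contextual bandit is the special case of a length-one trajectory.

```python
import numpy as np

def trajectory_weight(logp_target, logp_logging):
    """w = prod_t p_theta(a_t|s_t) / prod_t q(a_t|s_t), computed in log space.

    logp_target[t]  = log p_theta(a_t | s_t)
    logp_logging[t] = log q(a_t | s_t)
    """
    return np.exp(np.sum(logp_target) - np.sum(logp_logging))

rng = np.random.default_rng(0)
T = 5
logp_target = np.log(rng.uniform(0.1, 1.0, size=T))
logp_logging = np.log(rng.uniform(0.1, 1.0, size=T))
# Shared environment terms: log p(s_0) and log p(s_{t+1} | s_t, a_t). Unknown in practice.
log_env = rng.normal(size=T + 1)

with_env = np.exp((log_env.sum() + logp_target.sum()) - (log_env.sum() + logp_logging.sum()))
print(np.isclose(with_env, trajectory_weight(logp_target, logp_logging)))
```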

Variance reduction: These estimators can all be improved with variance reduction techniques. Probably the most effective technique is using control variates (of which baseline functions are a special case). These are random variables correlated with \(r(x)\) whose expectations we know (or at least estimate separately). A great example is how ad clicks depend strongly on time-of-day (fewer people are online late at night, so we get fewer clicks); thus, the time-of-day covariate explains a large part of the variation in \(r(x)\).
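One common special case, sketched below with made-up numbers (NumPy assumed): since \(\mathbb{E}_q[w_\theta] = 1\) when \(q\) has full support, the weight itself is a control variate with known expectation, and for any constant baseline \(b\) the estimator \(\frac{1}{m}\sum_j (r^{(j)} - b)\, w^{(j)}_\theta + b\) is still unbiased. When rewards share a large offset, subtracting it shrinks the variance. The simulation repeats the experiment many times to compare the spread of the two estimators.

```python
import numpy as np

rng = np.random.default_rng(0)
n_act, m, reps = 4, 100, 2000
q = np.full(n_act, 1 / n_act)              # logging policy
p = np.array([0.1, 0.2, 0.3, 0.4])         # target policy p_theta (fixed here)
R = 10 + np.array([0.0, 0.5, 1.0, 1.5])    # rewards share a large offset
J = p @ R                                  # true value

a = rng.integers(n_act, size=(reps, m))    # q is uniform
r = R[a] + rng.normal(scale=0.1, size=(reps, m))
w = p[a] / q[a]

plain = np.mean(r * w, axis=1)                    # V1, one estimate per repetition
b = 10.0                                          # baseline: a constant guess of the typical reward
with_baseline = np.mean((r - b) * w, axis=1) + b  # unbiased because E_q[w] = 1

print(J, plain.mean(), with_baseline.mean())
print(plain.std(), with_baseline.std())
```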

Estimation instead of optimization: You can use this general setup for estimation instead of optimization, in which case it's fine to let \(r\) have real-valued multivariate output. The confidence intervals are probably useful in that setting too.

Unknown \(q\): Often \(q\) is an existing complex system, which does not record its probabilities. It is possible to use regression to estimate \(q\) from the samples; this estimate is called the propensity score (PS). PS attempts to account for confounding variables, which are hidden causes that control variation in the data. Failing to account for confounding variables may lead to incorrect conclusions. Unfortunately, PS results in a biased estimator because we're using a 'ratio of expectations' (we divide by the PS estimate) instead of an 'expectation of ratios'. PS is only statistically consistent in the (unlikely) event that the density estimate is correctly specified (i.e., we can eventually get \(q\) correct). In the unknown \(q\) setting, it's often better to use the doubly-robust (DR) estimator, which combines two estimators: a density estimator for \(q\) and a function approximation for \(r\). A great explanation for the bandit case is in Dudík et al. (2011). The DR estimator is also biased, but it has a better bias-variance tradeoff than PS.
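For the contextual bandit case, the DR estimator can be sketched as below. This is my own rendering in the style of Dudík et al. (2011), with made-up data and NumPy assumed. It adds a model-based term, \(\sum_a p_\theta(a|s)\, \hat{r}(s,a)\), to an importance-weighted correction of the reward model's residual. Its expectation is correct if either the propensity estimate \(\hat{q}\) or the reward model \(\hat{r}\) is exact, which the usage checks by breaking one at a time.

```python
import numpy as np

def dr_estimate(s, a, r, pi, q_hat, r_hat):
    """Doubly-robust estimate of J for a contextual bandit.

    s, a, r : logged contexts, actions, and rewards (arrays of length m)
    pi      : pi[s, a], target policy probabilities p_theta(a|s)
    q_hat   : q_hat[s, a], estimated logging probabilities (propensity scores)
    r_hat   : r_hat[s, a], regression estimate of the expected reward
    """
    direct = np.sum(pi[s] * r_hat[s], axis=1)                 # model-based term
    correction = pi[s, a] / q_hat[s, a] * (r - r_hat[s, a])    # weighted residual
    return np.mean(direct + correction)

rng = np.random.default_rng(0)
n_ctx, n_act, m = 3, 4, 200_000
p_s = np.array([0.5, 0.3, 0.2])
R = rng.uniform(size=(n_ctx, n_act))                                  # true expected rewards
q = 0.5 * rng.dirichlet(np.ones(n_act), size=n_ctx) + 0.5 / n_act     # unrecorded logging policy
pi = rng.dirichlet(np.ones(n_act), size=n_ctx)                        # target policy
J = np.sum(p_s[:, None] * pi * R)

s = rng.choice(n_ctx, size=m, p=p_s)
cdf = q[s].cumsum(axis=1)
a = np.minimum((rng.random(m)[:, None] > cdf).sum(axis=1), n_act - 1)  # a ~ q(.|s)
r = R[s, a] + rng.normal(scale=0.1, size=m)

wrong_q = np.full_like(q, 1 / n_act)    # misspecified propensities
wrong_r = np.zeros_like(R)              # misspecified reward model
print(J)
print(dr_estimate(s, a, r, pi, wrong_q, R))   # reward model correct
print(dr_estimate(s, a, r, pi, q, wrong_r))   # propensities correct
```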

What if \(q\) doesn't have support everywhere? This is an especially important setting because data collection policies often abide by safety regulations, which prevent known bad configurations. In many situations, evaluating \(r(x)\) corresponds to executing an action \(x\) in the real world, so terrible outcomes could occur, such as breaking a system, giving a patient a bad treatment, or losing money. V1 is ok to use as long as we satisfy the importance sampling support conditions, which might mean rejecting certain values of \(\theta\) (which might be non-trivial to enforce) and consequently finding a less-optimal policy. V2 and V3 are ok to use without an explicit constraint, but additional care may be needed to ensure that the learned policy satisfies specific safety constraints.

What if \(q\) is deterministic? This is related to the point above, and it is a hard problem. Essentially, this is trying to learn without any exploration / experimentation! In general, we need exploration to learn. Randomization isn't the only way to explore; there are many systematic types of experimentation.

  • There are some cases of systematic experimentation that are ok. For example, enumerating all elements of \(\mathcal{X}\) (almost certainly infeasible). Another example is a contextual bandit where \(q\) assigns actions to contexts deterministically via a hash function (this setting is fine because \(q\) is essentially a uniform distribution over actions, which is independent of the state). In other special cases, we may be able to characterize systematic exploration as stratified sampling.

  • A generic solution might be to apply the doubly-robust estimator, which "smooths out" deterministic components (by pretending they are random) and accounts for confounds (by explicitly modeling them in the propensity score).

What if we control data collection (\(q\))? This is an interesting setting. Essentially, it asks "how do we explore/experiment optimally (and safely)?". In general, this is an open question that depends on many considerations, such as how much control we have, the cost of exploration (safety constraints), and prior knowledge (of \(r\) and unknown factors in the environment). I've seen some papers cleverly design \(q\). The first that comes to mind is Levine & Koltun (2013). Another setting is online contextual bandits, in which algorithms like EXP4 and Thompson sampling prescribe certain types of exploration and work interactively (i.e., they don't have a fixed training sample). Lastly, I'll mention that there are many variance reduction techniques for importance sampling, which may apply.

Further reading

Léon Bottou, Jonas Peters, Joaquin Quiñonero-Candela, Denis X. Charles, D. Max Chickering, Elon Portugaly, Dipankar Ray, Patrice Simard, Ed Snelson. Counterfactual reasoning in learning systems. JMLR 2013.

The source for the majority of this post. It includes many other interesting ideas and goes more in depth into some of the details.

Miroslav Dudík, John Langford, Lihong Li. Doubly robust policy evaluation and learning. ICML 2011.

Discussed above, in the paragraph on unknown \(q\).

Philip S. Thomas. Safe reinforcement learning. PhD Thesis 2015.

Covers confidence intervals for policy evaluation similar to Bottou et al., as well as learning algorithms for RL with safety guarantees (e.g., so we don't break the robot).

Leonid Peshkin, Christian R. Shelton. Learning from scarce experience. ICML 2002.

An older RL paper on learning from logged data; it is one of the earliest papers on the topic that I could find.

Sergey Levine, Vladlen Koltun. Guided policy search. ICML 2013.

Discusses clever choices for \(q\) to better guide learning in the RL setting.

Corinna Cortes, Yishay Mansour, Mehryar Mohri. Learning bounds for importance weighting. NIPS 2010.

Discusses generalization bounds for the counterfactual objective. Includes an alternative weighting scheme to keep importance weights from exploding.

Heaps for incremental computation

In this post, I'll describe a neat trick for maintaining a summary quantity (e.g., sum, product, max, log-sum-exp, concatenation, cross-product) under changes to its inputs. The trick and its implementation are inspired by the well-known max-heap datastructure. I'll also describe a really elegant application to fast sampling under an evolving categorical distribution.

Setup: Suppose we'd like to efficiently compute a summary quantity under changes to its \(n\)-dimensional input vector \(\boldsymbol{w}\). The particular form of the quantity we're going to compute is \(z = \bigoplus_{i=1}^n w_i\), where \(\oplus\) is some associative binary operator with identity element \(\boldsymbol{0}\).

The trick: Essentially, the trick boils down to parenthesis placement in the expression that computes \(z\), a freedom granted to us by the associative property.

I'll demonstrate by example with \(n=8\).

Linear structure: We generally compute something like \(z\) with a simple loop. This looks like a left-branching binary tree when we think about the order of operations,

$$ z = (((((((w_1 \oplus w_2) \oplus w_3) \oplus w_4) \oplus w_5) \oplus w_6) \oplus w_7) \oplus w_8). $$

Heap structure: Here the parentheses form a balanced tree, which looks much more like a recursive implementation that computes the left and right halves and \(\oplus\)s the results (divide-and-conquer style),

$$ z = (((w_1 \oplus w_2) \oplus (w_3 \oplus w_4)) \oplus ((w_5 \oplus w_6) \oplus (w_7 \oplus w_8))). $$

The benefit of the heap structure is that there are \(\mathcal{O}(\log n)\) intermediate quantities that depend on any input, whereas the linear structure has \(\mathcal{O}(n)\). The intermediate quantities correspond to the values of each of the parenthesized expressions.

Since fewer intermediate quantities depend on a given input, fewer intermediates need to be adjusted upon a change to the input. Therefore, we get faster algorithms for maintaining the output quantity \(z\) as the inputs change.
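The two parenthesizations can be written as two small reducers. The sketch below (function names are mine; NumPy assumed) uses float32 addition on \(2^{16}\) copies of the same value, a deliberately contrived case: the balanced tree only ever adds two equal numbers, which is exact in binary floating point, while the linear loop incurs a rounding error at almost every step.

```python
import operator
from functools import reduce
import numpy as np

def linear_reduce(op, w):
    # ((((w1 op w2) op w3) op ...) op wn): a simple left-to-right loop.
    return reduce(op, w)

def tree_reduce(op, w):
    # Balanced parenthesization: reduce each half recursively, then combine.
    if len(w) == 1:
        return w[0]
    mid = len(w) // 2
    return op(tree_reduce(op, w[:mid]), tree_reduce(op, w[mid:]))

w = np.full(2**16, 0.1, dtype=np.float32)
exact = np.float64(w[0]) * len(w)          # exact sum of the float32 inputs
lin = linear_reduce(operator.add, w)       # float32 arithmetic throughout
tree = tree_reduce(operator.add, w)
print(abs(np.float64(lin) - exact), abs(np.float64(tree) - exact))
```

Any associative binary function works as `op`; for example, Python's built-in `max` gives the same answer under either parenthesization.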

Heap datastructure (closely related to the binary indexed tree, aka Fenwick tree): We're going to store the values of the intermediate quantities and inputs in a heap datastructure, which is a complete binary tree. In our case, the tree has depth \(1 + \lceil \log_2 n \rceil\), with the values of \(\boldsymbol{w}\) at its leaves (aligned left) and \(\boldsymbol{0}\) padding the remaining leaves. Thus, the array's length is \(< 4 n\).

This structure makes our implementation really nice and efficient because we don't need pointers to find the parent or children of a node (i.e., no need to wrap elements into a "node" class like in a general tree data structure). So, we can pack everything into an array, which means our implementation has great memory/cache locality and low storage overhead.

Traversing the tree is pretty simple: Let \(d\) be the number of leaves (a power of two), so nodes \(1 \le i < d\) are internal and nodes \(d \le i < 2d\) are leaves. For node \(i\), left child \(\rightarrow {2 \cdot i},\) right child \(\rightarrow {2 \cdot i + 1},\) parent \(\rightarrow \lfloor i / 2 \rfloor.\) (Note that these operations assume the array's indices start at \(1\). We generally fake this by adding a dummy node at position \(0\), which makes implementation simpler.)

Initializing the heap: Here's code that initializes the heap structure we just described.

import numpy as np

def sumheap(w):
    "Create sumheap from weights `w` in O(n) time."
    n = w.shape[0]
    d = int(2**np.ceil(np.log2(n)))  # number of leaves (power of two); internal nodes are 1..d-1
    S = np.zeros(2*d)                # index 0 is a dummy, then internal nodes, then leaves
    S[d:d+n] = w                     # store `w` at leaves.
    for i in reversed(range(1, d)):  # fill internal nodes bottom up
        S[i] = S[2*i] + S[2*i + 1]
    return S

Updating \(w_k\) boils down to fixing the intermediate sums that (transitively) depend on \(w_k.\) I won't go into all of the details here; instead, I'll give code (below). I'd like to quickly point out that the term "parents" is not great for our purposes because they are actually the dependents: when an input changes value, its parent, grandparent, great-grandparent, etc., become stale and need to be recomputed bottom up (from the leaves). The code below implements the update method for changing the value of \(w_k\) and runs in \(\mathcal{O}(\log n)\) time.

def update(S, k, v):
    "Update `w[k] = v` in time O(log n)."
    d = S.shape[0]//2  # number of leaves
    i = d + k
    S[i] = v
    while i > 1:   # fix parents in the tree, stopping at the root (index 1).
        i //= 2
        S[i] = S[2*i] + S[2*i + 1]
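Nothing in sumheap or update is specific to addition. Here is a sketch (class name and interface are my own; NumPy assumed) that takes any associative `op` with identity element `identity`, such as `np.add` with 0, `np.maximum` with \(-\infty\), or `np.logaddexp` (log-sum-exp) with \(-\infty\), using the same array layout and the same \(\mathcal{O}(\log n)\) update.

```python
import numpy as np

class OpHeap:
    """Maintain z = w[0] (+) w[1] (+) ... (+) w[n-1] under point updates.

    `op` must be associative with op(identity, x) == op(x, identity) == x.
    Layout: index 0 unused, internal nodes 1..d-1, leaves d..2d-1.
    """
    def __init__(self, w, op, identity):
        self.op = op
        n = len(w)
        self.d = 1 << max(n - 1, 0).bit_length()   # smallest power of two >= n
        self.S = np.full(2 * self.d, identity, dtype=float)
        self.S[self.d:self.d + n] = w
        for i in reversed(range(1, self.d)):
            self.S[i] = op(self.S[2 * i], self.S[2 * i + 1])

    def update(self, k, v):
        # Set w[k] = v, then recompute the O(log n) ancestors of leaf k.
        i = self.d + k
        self.S[i] = v
        while i > 1:
            i //= 2
            self.S[i] = self.op(self.S[2 * i], self.S[2 * i + 1])

    def total(self):
        return self.S[1]

w = np.array([0.5, -1.0, 2.0, 0.25, 1.5])
lse = OpHeap(w, np.logaddexp, -np.inf)
lse.update(2, -3.0)
v = w.copy()
v[2] = -3.0
print(lse.total(), np.log(np.sum(np.exp(v))))
```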


  • Numerical stability: If the operations are noisy (e.g., floating-point arithmetic), then the heap version may be better behaved. For example, if operations have an independent, additive noise rate \(\varepsilon\), then the noise of \(z_{\text{heap}}\) is \(\mathcal{O}(\varepsilon \cdot \log n)\), whereas that of \(z_{\text{linear}}\) is \(\mathcal{O}(\varepsilon \cdot n)\). (Without further assumptions about the underlying operator, I don't believe you can do better than that.)

  • Relationship to max-heap: In the case of a max or min heap, we can avoid allocating extra space for intermediate quantities because every intermediate value is equal to exactly one element of \(\boldsymbol{w}\).

  • Change propagation: Adjusting cached intermediate quantities is a neat general idea. In fact, we encounter it each time we type make at the command line! The general technique goes by many names (including change propagation, incremental maintenance, and functional reactive programming) and applies to basically any side-effect-free computation. However, it's most effective when the dependency structure of the computation is sparse and finding and refreshing stale values requires little overhead. In our example of computing \(z\), these considerations manifest themselves as the heap vs. linear structures and our fast array implementation instead of a generic tree datastructure.


  • No zero? No problem. We don't actually require a zero element: we can augment the set of values \(\boldsymbol{K}\) that \(\oplus\) operates on to \(\boldsymbol{K} \cup \{ \textsf{null} \}\), where \(\textsf{null}\) is a distinguished value (i.e., \(\textsf{null} \notin \boldsymbol{K}\)) that acts just like a zero after we overload \(\oplus\) to satisfy the definition of a zero (e.g., by adding an if-statement).

  • Generalizing from fixed vectors to arbitrary maps is possible with a "locator" map, which is a bijective map from keys to indices in a dense array.

  • Support for growing and shrinking: We support growing by maintaining an underlying array that is always slightly larger than we need, which we're already doing in the heap datastructure. Doubling the size of the underlying array (i.e., rounding up to the next power of two) has the added benefit of allowing us to grow \(\boldsymbol{w}\) at no asymptotic cost! This is because the resize operation, which takes \(\mathcal{O}(n)\) time to allocate a new array and copy the old values, happens so infrequently that its cost is completely amortized. We get the effect of shrinking by replacing the old value with \(\textsf{null}\) (or \(\boldsymbol{0}\)). We can shrink the underlying array when the fraction of nonzeros dips below \(25\%\). This prevents "thrashing" between shrinking and growing.
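Putting the last two bullets together, here is a sketch (NumPy assumed; class and method names are hypothetical) of a sum heap keyed by arbitrary hashable keys. The dict `loc` plays the role of the locator map, the array doubles when it fills up, and removal simply zeroes a weight. For brevity, this sketch never shrinks the array and never reuses a removed key's slot for a new key.

```python
import numpy as np

class KeyedSumHeap:
    """Sum heap over hashable keys; grows by doubling, removes by zeroing."""
    def __init__(self, capacity=1):
        self.d = 1
        while self.d < capacity:
            self.d *= 2
        self.S = np.zeros(2 * self.d)
        self.loc = {}                      # locator map: key -> leaf offset

    def _fix(self, i):
        # Recompute the ancestors of node i, bottom up.
        while i > 1:
            i //= 2
            self.S[i] = self.S[2 * i] + self.S[2 * i + 1]

    def _grow(self):
        # O(n) rebuild, but it only happens when the size doubles (amortized O(1) per insert).
        old_leaves = self.S[self.d:2 * self.d].copy()
        self.d *= 2
        self.S = np.zeros(2 * self.d)
        self.S[self.d:self.d + len(old_leaves)] = old_leaves
        for i in reversed(range(1, self.d)):
            self.S[i] = self.S[2 * i] + self.S[2 * i + 1]

    def __setitem__(self, key, value):
        if key not in self.loc:
            if len(self.loc) == self.d:
                self._grow()
            self.loc[key] = len(self.loc)
        i = self.d + self.loc[key]
        self.S[i] = value
        self._fix(i)

    def remove(self, key):
        # "Shrink" by zeroing the weight; the slot stays allocated.
        self[key] = 0.0

    def total(self):
        return self.S[1]

h = KeyedSumHeap()
for i, key in enumerate("abcdefghij"):
    h[key] = i + 1
h.remove("c")
print(h.total(), h.d)
```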


Sampling from an evolving distribution: Suppose that \(\boldsymbol{w}\) corresponds to a categorical distribution over \(\{1, \ldots, n\}\) and that we'd like to sample elements in proportion to this (unnormalized) distribution.

Other methods, like the alias or inverse CDF methods, are efficient after a somewhat costly initialization step. But! They are not as efficient as the heap sampler when the distribution is being updated. (I'm not sure whether variants of the alias method that support updates exist.)

Method    Sample      Update      Init
alias     O(1)        O(n)?       O(n)
i-CDF     O(log n)    O(n)        O(n)
heap      O(log n)    O(log n)    O(n)

Use cases include

  • Gibbs sampling, where distributions are constantly modified and sampled from (changes may not be sparse so YMMV). The heap sampler is used in this paper.

  • EXP3 (a multi-armed bandit algorithm) is an excellent example of an algorithm that samples and modifies a single weight in the distribution.

  • Stochastic priority queues, where we sample in proportion to priority, the weights on items in the queue may change, elements may be removed after they are sampled (i.e., sampling without replacement), and new elements may be added.
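As an example of the last use case, here is a sketch of sampling without replacement, assuming NumPy. To stay self-contained, it re-implements small versions of the build, update, and descent routines under new, hypothetical names rather than reusing the sumheap and update functions above; removing an item is just an \(\mathcal{O}(\log n)\) update that sets its weight to zero.

```python
import numpy as np

def build(w):
    # Sum heap: index 0 unused, internal nodes 1..d-1, leaves d..2d-1.
    d = 1 << max(len(w) - 1, 0).bit_length()
    S = np.zeros(2 * d)
    S[d:d + len(w)] = w
    for i in reversed(range(1, d)):
        S[i] = S[2 * i] + S[2 * i + 1]
    return S

def set_weight(S, k, v):
    d = S.shape[0] // 2
    i = d + k
    S[i] = v
    while i > 1:
        i //= 2
        S[i] = S[2 * i] + S[2 * i + 1]

def draw(S, u):
    # Descend from the root: go right when the probe is at least the left subtree's mass.
    d = S.shape[0] // 2
    p = u * S[1]
    i = 1
    while i < d:
        i *= 2
        if p >= S[i]:
            p -= S[i]
            i += 1
    return i - d

def sample_without_replacement(w, k, rng):
    S = build(w)
    out = []
    for _ in range(k):
        j = draw(S, rng.random())
        out.append(j)
        set_weight(S, j, 0.0)   # remove the item in O(log n)
    return out

rng = np.random.default_rng(0)
print(sample_without_replacement([1.0, 2.0, 3.0, 4.0, 5.0], 3, rng))
```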

Again, I won't spell out all of the details of these algorithms. Instead, I'll just give the code.

Inverse CDF sampling

from numpy.random import uniform

def sample(w):
    "Ordinary sampling method, O(n) init, O(log n) per sample."
    c = w.cumsum()            # build cdf, O(n)
    p = uniform() * c[-1]     # random probe, p ~ Uniform(0, z)
    return c.searchsorted(p, side='right')  # binary search, O(log n); 'right' never picks zero-weight entries

Heap sampling is essentially the same, except the CDF is stored as a heap, which is perfect for binary search!

from numpy.random import uniform

def hsample(S):
    "Sample from sumheap, O(log n) per sample."
    d = S.shape[0]//2     # number of leaves; internal nodes are 1..d-1.
    p = uniform() * S[1]  # random probe, p ~ Uniform(0, z)
    # Binary search down the heap for the leaf whose CDF interval contains
    # the random probe.
    i = 1
    while i < d:
        # Determine if the value is in the left or right subtree.
        i *= 2         # Point at left child
        left = S[i]    # Probability mass under left subtree.
        if p >= left:  # Value is in right subtree (>= so zero-mass subtrees are never chosen).
            p -= left  # Subtract mass from left subtree
            i += 1     # Point at right child
    return i - d

Code: Complete code and test cases for heap sampling are available in this gist.