Graduate Descent

Reversing a sequence with sublinear space

Suppose we have a computation which generates a sequence of states \(s_1 \ldots s_n\) according to \(s_{t} = f(s_{t-1})\), where \(s_0\) is given.

We'd like to devise an algorithm that can reconstruct each point in the sequence efficiently as we traverse it backwards. You can think of this as "hitting undo" from the end of the sequence or reversing a singly-linked list.

Obviously, we could just record the entire sequence, but if \(n\) is large or the size of each state is large, this will be infeasible.

Idea 0: Rerun the forward pass from \(s_0\) once for each of the \(n\) states we need to visit. Runtime \(\mathcal{O}(n^2)\), space \(\mathcal{O}(1)\).
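Here is a minimal Python sketch of Idea 0. It assumes `f` is a pure function on ordinary Python values, and `reverse_naive` is a hypothetical name. It yields \(s_n, \ldots, s_1\) while holding only one state at a time.

```python
def reverse_naive(f, s0, n):
    """Yield s_n, s_{n-1}, ..., s_1 using O(1) stored states.

    For each t (from n down to 1) we replay the forward pass from s_0,
    so the total number of calls to f is n(n+1)/2, i.e., O(n^2).
    """
    for t in range(n, 0, -1):
        s = s0
        for _ in range(t):      # replay s_1, ..., s_t from scratch
            s = f(s)
        yield s
```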

Idea 1: Suppose we save \(0 < k \le n\) evenly spaced "checkpoint" states. Clearly, this gives us \(\mathcal{O}(k)\) space, but what does it do to the runtime? Well, if we are at time \(t\), then we have to "replay" the computation from the last recorded checkpoint to get \(s_t\), which takes \(\mathcal{O}(n/k)\) time. Thus, the overall runtime becomes \(\mathcal{O}(n^2/k)\). This runtime is not ideal.
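Idea 1 can be sketched the same way. The name `reverse_checkpoints` is hypothetical. The checkpoints live in a dict keyed by time step, and they are spaced \(\lceil n/k \rceil\) apart, which is one assumed way of rounding "evenly spaced".

```python
def reverse_checkpoints(f, s0, n, k):
    """Yield s_n, ..., s_1 while storing about k evenly spaced checkpoints.

    Each state is recovered by replaying from the nearest checkpoint at or
    before it, which costs O(n/k) calls to f per step and O(n^2/k) in total.
    """
    stride = max(1, -(-n // k))         # ceil(n / k): spacing between checkpoints
    checkpoints = {0: s0}
    s = s0
    for t in range(1, n + 1):           # forward pass: keep every stride-th state
        s = f(s)
        if t % stride == 0:
            checkpoints[t] = s
    for t in range(n, 0, -1):           # reverse pass
        c = (t // stride) * stride      # last checkpoint at or before t
        s = checkpoints[c]
        for _ in range(t - c):          # replay from the checkpoint to s_t
            s = f(s)
        yield s
```

With \(k = n\) this degenerates to storing every state. With \(k = 1\) it is essentially Idea 0.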

Idea 2: Idea 1 did something kind of silly: within a chunk of size \(n/k\), it does each computation multiple times! Suppose we increase the memory requirement just a little bit to remember the current chunk we're working on, which makes it \(\mathcal{O}(k + n/k)\). Now, we compute each state at most \(2\) times: once in the initial forward pass and once when we replay its chunk during the reverse pass. This implies a linear runtime. Now, the question: how should we set \(k\) so that we minimize extra space? Easy! Solve the following little optimization problem:

$$ \underset{k}{\textrm{argmin}}\ k+n/k = \sqrt{n} $$
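For completeness, the minimizer follows from setting the derivative to zero (treating \(k\) as continuous):

$$ \frac{d}{dk}\left(k + \frac{n}{k}\right) = 1 - \frac{n}{k^2} = 0 \;\Longrightarrow\; k^\star = \sqrt{n}, \qquad k^\star + \frac{n}{k^\star} = 2\sqrt{n}. $$

The second derivative \(2n/k^3\) is positive for \(k > 0\), so this stationary point is the minimum.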

That's nuts! We get away with sublinear space \(\mathcal{O}(\sqrt{n})\) and we only blow up our runtime by a factor of 2. Also, I really love the "introduce a parameter then optimize it out" trick.
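Here is a sketch of Idea 2 with \(k \approx \sqrt{n}\), under the same assumptions as before. The name `reverse_sqrt` is hypothetical. The first pass stores every \(\lfloor\sqrt{n}\rfloor\)-th state. The reverse pass rebuilds one chunk at a time into a small buffer, so `f` is called at most \(2n\) times.

```python
import math

def reverse_sqrt(f, s0, n):
    """Yield s_n, ..., s_1 with O(sqrt(n)) stored states and at most 2n calls to f.

    Pass 1 saves every `stride`-th state as a checkpoint (about sqrt(n) of them).
    Pass 2 walks the chunks right to left. For each chunk it replays forward
    from the chunk's checkpoint into a buffer of at most `stride` states, then
    emits that buffer in reverse.
    """
    stride = max(1, math.isqrt(n))
    checkpoints = [s0]                  # checkpoints[i] == s_{i * stride}
    s = s0
    for t in range(1, n + 1):
        s = f(s)
        if t % stride == 0:
            checkpoints.append(s)
    for i in range(len(checkpoints) - 1, -1, -1):
        start = i * stride
        end = min(start + stride, n)    # this chunk covers s_{start+1} .. s_{end}
        s = checkpoints[i]
        buffer = []
        for _ in range(start, end):     # second (and last) computation of these states
            s = f(s)
            buffer.append(s)
        yield from reversed(buffer)
```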

Idea 3: What if we apply the "remember \(k\) states" trick recursively? I'm going to work this out for \(k=2\) (and then claim that the value of \(k\) doesn't matter).

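Consider an interval \([b,e)\) of time steps, starting from the known state \(s_b\). Run forward to get the midpoint state \(s_{m}\), where \(m=b + \lfloor (e-b)/2 \rfloor\). Next, recurse on the right chunk \([m,e)\) and then on the left chunk \([b,m)\). The right chunk goes first because we are traversing backwards. We hit the base case when the width of the interval is one.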

Note that we implicitly store midpoints as we recurse (thanks to the stack frames). The maximum depth of the recursion is \(\mathcal{O}(\log n)\), and each frame holds a constant number of states, which gives us an \(\mathcal{O}(\log n)\) space bound.

We can characterize the runtime with the recurrence \(T(n) = 2 \cdot T(n/2) + \mathcal{O}(n)\). We recognize this as the recurrence for mergesort, so we know that it flattens to \(\mathcal{O}(n \log n)\) time. Also, just like in the case of sorting, the branching factor doesn't matter (for any constant \(k\), \(\log_k n = \mathcal{O}(\log n)\)), so we're happy with our initial assumption that \(k=2\).
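Here is a recursive sketch of Idea 3. The name `reverse_recursive` is hypothetical. The inner `visit` takes the state \(s_b\) for the interval \([b,e)\) and emits \(s_{e-1}, \ldots, s_b\). The midpoint states live only in the generator frames on the stack, which matches the \(\mathcal{O}(\log n)\) space argument.

```python
def reverse_recursive(f, s0, n):
    """Yield s_n, ..., s_1 keeping O(log n) states alive on the call stack."""

    def visit(s_b, b, e):
        # Emit s_{e-1}, ..., s_b given only s_b (the interval [b, e)).
        if e - b == 1:                  # base case: width-one interval
            yield s_b
            return
        m = b + (e - b) // 2
        s_m = s_b
        for _ in range(m - b):          # run forward to the midpoint
            s_m = f(s_m)
        yield from visit(s_m, m, e)     # right chunk [m, e) first ...
        yield from visit(s_b, b, m)     # ... then the left chunk [b, m)

    if n > 0:
        yield from visit(f(s0), 1, n + 1)
```

Each level of the recursion makes at most \(n/2\) calls to `f` in total. There are \(\lceil \log_2 n \rceil\) levels, which gives the \(\mathcal{O}(n \log n)\) runtime.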

Remarks

The algorithms described in this post are generic algorithmic tricks. They have been used in a number of places, including:

  • The classic computer science interview problem of reversing a singly-linked list under a tight budget on additional memory.

  • Backpropagation for computing gradients in sequence models, including HMMs (Zweig & Padmanabhan, 2000) and RNNs (Chen et al., 2016).

  • Memory-efficient omniscient debugging, which allows a user to inspect program state while moving forward and backward in time.

Here's the code gist, complete with test cases.

Evaluating ∇f(x) is as fast as f(x)

Automatic differentiation ('autodiff' or 'backprop') is great. This is not just because it makes it easy to rapidly prototype deep networks with plenty of doodads and gewgaws, but because it means that evaluating the gradient \(\nabla f(x)\) is as fast as computing \(f(x)\). In fact, the gradient provably requires at most a small constant factor more arithmetic operations than the function itself. Furthermore, autodiff tells us how to derive and implement the gradient efficiently. This is a fascinating result that is perhaps not emphasized enough in machine learning.
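To make the constant-factor claim concrete, here is a minimal scalar reverse-mode sketch. The class `Var` and the functions `sin` and `backward` are hypothetical and are not any library's API. Each operation records at most two local partial derivatives. The backward pass does one multiply-add per recorded edge, so its cost is proportional to the cost of the forward computation.

```python
import math

class Var:
    """A scalar node: its value, its parents, and d(self)/d(parent) for each."""

    def __init__(self, value, parents=()):
        self.value = value
        self.parents = parents          # tuple of (parent Var, local partial)
        self.grad = 0.0                 # filled in by backward()

    def __add__(self, other):
        return Var(self.value + other.value, ((self, 1.0), (other, 1.0)))

    def __mul__(self, other):
        return Var(self.value * other.value,
                   ((self, other.value), (other, self.value)))

def sin(x):
    return Var(math.sin(x.value), ((x, math.cos(x.value)),))

def backward(output):
    """Accumulate d(output)/d(node) into node.grad for every node in the graph."""
    order, seen = [], set()

    def visit(v):                       # post-order DFS gives a topological order
        if id(v) not in seen:
            seen.add(id(v))
            for parent, _ in v.parents:
                visit(parent)
            order.append(v)

    visit(output)
    output.grad = 1.0
    for v in reversed(order):           # each edge is processed exactly once
        for parent, local in v.parents:
            parent.grad += local * v.grad

# f(x, y) = sin(x * y) + x, so df/dx = y cos(xy) + 1 and df/dy = x cos(xy).
x, y = Var(0.5), Var(2.0)
f = sin(x * y) + x
backward(f)
print(x.grad, y.grad)
```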

The gradient should never be asymptotically slower than the function. In my recent EMNLP'16 paper, my coauthors and I found a line of work on variable-order CRFs (Ye+'09; Cuong+'14) that used an unnecessarily slow and complicated algorithm for computing gradients. That algorithm was asymptotically (and practically) slower than their forward algorithm. Without breaking a sweat, we derived a simpler and more efficient gradient algorithm by simply applying backprop to the forward algorithm (and made some other contributions).

Many algorithms are just backprop. For example, forward-backward and inside-outside are actually just instances of automatic differentiation (Eisner,'16) (i.e., outside is just backprop on inside). This shouldn't be a surprise because these algorithms are used to compute gradients. Basically, if you know backprop and the inside algorithm, then you can derive the outside algorithm by applying the backprop transform manually. I find it easier to understand the outside algorithm via its connection to backprop than via the usual presentation. Note that inside-outside and forward-backward pre-date backpropagation and have additional uses beyond computing gradients.
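Here is a sketch of "forward-backward is backprop on forward" for a linear-chain model. It assumes log-space unary scores `theta[t, y]` and transition scores `A[y_prev, y]`, and all function names are hypothetical. Hand-applying backprop to the forward algorithm's \(\log Z\) yields \(\partial \log Z / \partial \theta[t,y]\), which equals the marginal \(p(y_t = y)\). The brute-force enumeration is only there as a check.

```python
import itertools
import numpy as np
from scipy.special import logsumexp

def forward_logZ(theta, A):
    """Forward algorithm in log space.

    theta[t, y]: score of label y at position t (shape T x K).
    A[y_prev, y]: transition score (shape K x K).
    Returns log Z and alpha[t, y] = log-sum-exp over all prefixes ending in y.
    """
    T, K = theta.shape
    alpha = np.empty((T, K))
    alpha[0] = theta[0]
    for t in range(1, T):
        alpha[t] = theta[t] + logsumexp(alpha[t - 1][:, None] + A, axis=0)
    return logsumexp(alpha[-1]), alpha

def marginals_by_backprop(theta, A):
    """Backprop through forward_logZ by hand; returns d log Z / d theta."""
    logZ, alpha = forward_logZ(theta, A)
    d_alpha = np.zeros(alpha.shape)
    d_theta = np.zeros(alpha.shape)
    d_alpha[-1] = np.exp(alpha[-1] - logZ)       # gradient of logsumexp is softmax
    for t in range(len(theta) - 1, 0, -1):
        d_theta[t] += d_alpha[t]                  # alpha[t] = theta[t] + ...
        # P[y_prev, y] = d alpha[t, y] / d alpha[t-1, y_prev], a softmax over y_prev
        P = np.exp(alpha[t - 1][:, None] + A + theta[t][None, :] - alpha[t][None, :])
        d_alpha[t - 1] += P @ d_alpha[t]
    d_theta[0] += d_alpha[0]
    return d_theta

def brute_marginals(theta, A):
    """Exact marginals p(y_t = y) by enumerating all K**T label sequences."""
    T, K = theta.shape
    mu = np.zeros((T, K))
    for ys in itertools.product(range(K), repeat=T):
        score = sum(theta[t, y] for t, y in enumerate(ys))
        score += sum(A[ys[t - 1], ys[t]] for t in range(1, T))
        for t, y in enumerate(ys):
            mu[t, y] += np.exp(score)
    return mu / mu[0].sum()                       # mu[0].sum() == Z
```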

Once you've grokked backprop, the world is your oyster! You can backprop through many approximate inference algorithms, e.g., Stoyanov+'11 and many of Justin Domke's papers, to avoid issues I've mentioned before. You can even backprop through optimization algorithms to get gradients of dev loss wrt hyperparameters, e.g., Domke'12 and Maclaurin+'15.

There's at least one catch! Although the time complexity of computing the gradient is as good as that of the function, the space complexity may be much larger. This is because the autodiff recipe (at least the default reverse-mode one) requires memoizing all intermediate quantities (e.g., the quantities you overwrite in a loop). There are generic methods for balancing the time-space tradeoff in autodiff, since you can (at least in theory) reconstruct the intermediate quantities by replaying the forward computation from intermediate checkpoints (at a cost to runtime, of course). A recent example is Gruslys+'16.
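Here is a small sketch of this time-space tradeoff applied to a hand-derived backward pass. It is an illustrative \(\sqrt{n}\)-checkpointing scheme (the Idea 2 trick from the sequence-reversal post), not the method of Gruslys+'16. The recurrence \(h_t = \tanh(w\, h_{t-1})\) and the name `grad_w` are hypothetical. Only every \(\lfloor\sqrt{n}\rfloor\)-th hidden state is kept, and each chunk is replayed from its checkpoint just before backprop needs it.

```python
import math

def grad_w(w, h0, n):
    """d h_n / d w for h_t = tanh(w * h_{t-1}), storing O(sqrt(n)) states.

    Forward pass: keep only every `stride`-th h as a checkpoint.
    Backward pass: for each chunk (right to left), replay forward from its
    checkpoint to rebuild the chunk's states, then backprop through them.
    """
    stride = max(1, math.isqrt(n))
    checkpoints = [h0]                  # checkpoints[i] == h_{i * stride}
    h = h0
    for t in range(1, n + 1):
        h = math.tanh(w * h)
        if t % stride == 0:
            checkpoints.append(h)
    g, dw = 1.0, 0.0                    # g = d h_n / d h_t; dw accumulates d h_n / d w
    for i in range(len(checkpoints) - 1, -1, -1):
        start, end = i * stride, min(i * stride + stride, n)
        chunk = [checkpoints[i]]        # chunk[j] == h_{start + j}
        for _ in range(start, end):
            chunk.append(math.tanh(w * chunk[-1]))
        for j in range(len(chunk) - 1, 0, -1):
            local = 1.0 - chunk[j] ** 2 # tanh'(w * h_{t-1}) = 1 - h_t^2
            dw += g * local * chunk[j - 1]
            g *= local * w
    return dw
```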

A final remark. Despite the name "automatic" differentiation, there is no need to rely on software to "automatically" give you gradient routines. Applying the backprop transformation is generally easy to do manually and sometimes more efficient than using a library. Many autodiff libraries lack good support for dynamic computation graphs, i.e., when the structure depends on quantities that vary with the input (e.g., sentence length).

Fast sigmoid sampling

In this notebook, we describe a simple trick for efficiently sampling a Bernoulli random variable $Y$ from a sigmoid-defined distribution, $p(Y = 1) = (1 + \exp(-x))^{-1}$, where $x \in \mathbb{R}$ is the only parameter of the distribution ($x$ is often defined as the dot product of features and weights).

The "slow" method for sampling from a sigmoid:

$$ u \sim \textrm{Uniform}(0,1) $$

$$ Y = \textrm{sigmoid}(x) > u $$

This method is slow because it calls the sigmoid function, and hence $\exp$, for every value of $x$, and $\exp$ is 2-3x slower than basic arithmetic operations.

In this post, I'll describe a simple trick that is well-suited to vectorized computations (e.g., numpy, matlab). It works by precomputing the expensive stuff (i.e., calls to expensive functions like $\exp$), based on the following equivalence:

$$ \textrm{sigmoid}(x) > u \Leftrightarrow \textrm{logit}(\textrm{sigmoid}(x)) > \textrm{logit}(u) \Leftrightarrow x > \textrm{logit}(u). $$

Some details worth mentioning: (a) logit is the inverse of sigmoid, and (b) logit is strictly monotonically increasing, so applying it to both sides of the inequality preserves the ordering (there's a plot in the appendix).
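Here is a self-contained sketch of both samplers. The function names are hypothetical, and `expit` is SciPy's name for the sigmoid. The noise `z = logit(u)` does not depend on `x`, so it can be computed before any data arrives. The comparison `x > z` then needs no calls to `exp`.

```python
import numpy as np
from scipy.special import expit, logit   # expit is SciPy's sigmoid

def sample_slow(x, u):
    """Y = [sigmoid(x) > u]: evaluates exp once per element of x."""
    return expit(x) > u

def sample_fast(x, z):
    """Y = [x > z] with z = logit(u) computed ahead of time: no exp at all."""
    return x > z

rng = np.random.default_rng(0)
u = rng.uniform(size=10_000)
z = logit(u)                              # data-independent, so it can be precomputed
x = rng.normal(scale=3.0, size=10_000)

same = (sample_slow(x, u) == sample_fast(x, z)).all()
roundtrip = np.allclose(expit(logit(u)), u)                         # logit inverts sigmoid
increasing = np.all(np.diff(logit(np.linspace(0.01, 0.99, 99))) > 0)  # monotone on (0, 1)
print(same, roundtrip, increasing)
```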

The "fast" method derives its advantage from the fact that the expensive computation can be done independently of the data (i.e., specific values of $x$). The fast method is also interesting as just cute math. In the bonus section of this post, we'll make a connection to the Gumbel max trick.

How fast is it in practice? Below, we run a quick experiment to check that the method is correct and to measure how fast it is.

In [1]:
%matplotlib inline
import numpy as np
import pylab as pl
from numpy.random import uniform
from numpy import exp
from scipy.special import expit as sigmoid, logit
from arsenal.timer import timers    # https://github.com/timvieira/arsenal
In [3]:
T = timers()

# These are the sigmoid parameters we're going to sample from.
n = 10000
X = np.linspace(-5,5,n)

# number of runs to average over.
R = 1000

# Used for plotting average p(Y=1)
F = np.zeros_like(X)

# Temporary array for saving on memory allocation, cf. method slow-2.
tmp = np.empty(n)                     

for _ in range(R):

    # Let's use the same random variables for all methods. This allows
    # for a lower-variance comparison and equivalence testing.
    u = uniform(0,1,size=n)
    z = logit(u)       # used in fast method: precompute expensive stuff.

    # Requires computing sigmoid for each x.
    with T['slow1']:
        s1 = sigmoid(X) > u           
        
    # Avoid memory allocation in slow-1 by using the out option to sigmoid
    # function. It's a little bit faster than slow-1.
    with T['slow2']:
        sigmoid(X, out=tmp)           
        s2 = tmp > u

    # Rolling our own sigmoid is a bit slower than using the library function.
    # Not to mention that this implementation isn't as numerically stable.
    with T['slow3']:
        s3 = 1/(1+exp(-X)) > u
        
    # The fast method.
    with T['fast']:
        f = X > z
    
    F += f / R    
    assert (s1 == f).all()
    assert (s2 == f).all()
    assert (s3 == f).all()

pl.plot(X, F)
pl.plot(X, sigmoid(X), c='r', lw=2)
T.compare()
fast is 28.4239x faster than slow1 (avg: slow1: 0.00114061 fast: 4.01285e-05)
slow2 is 1.0037x faster than slow1 (avg: slow1: 0.00114061 slow2: 0.0011364)
slow1 is 1.0840x faster than slow3 (avg: slow3: 0.00123637 slow1: 0.00114061)

It looks like our trick is about $28$x faster than the fastest competing slow method!

We also see that the assert statements passed, which means that the methods tested produce precisely the same samples.

The final plot demonstrates that the empirical average of the samples tracks the expected value, $\textrm{sigmoid}(x)$ (red curve), as we sweep the distribution's parameter (x-axis).

Bonus

We could alternatively use the Gumbel max trick to derive a similar algorithm. If we ground out the trick for a sigmoid (i.e., two classes with logits $0$ and $x$) instead of a general multinomial distribution, we end up with

$$ Z_0 \sim \textrm{Gumbel}(0,1) $$

$$ Z_1 \sim \textrm{Gumbel}(0,1) $$

$$ Y = x > Z_0 - Z_1 $$
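Here is a quick Monte Carlo sketch of the Gumbel-max version, using NumPy's `Generator.gumbel`. The parameter value `x = 1.5` and the sample size are arbitrary choices. With logits $0$ and $x$, class 1 wins exactly when $x + Z_1 > Z_0$, i.e., when $x > Z_0 - Z_1$.

```python
import numpy as np

rng = np.random.default_rng(0)
x = 1.5                                   # an arbitrary sigmoid parameter
n = 200_000
z0 = rng.gumbel(0.0, 1.0, size=n)         # Gumbel(0,1) noise for class 0 (logit 0)
z1 = rng.gumbel(0.0, 1.0, size=n)         # Gumbel(0,1) noise for class 1 (logit x)
y = x > z0 - z1                           # class 1 wins iff x + z1 > 0 + z0

p_hat = y.mean()                          # Monte Carlo estimate of p(Y = 1)
p_true = 1.0 / (1.0 + np.exp(-x))         # sigmoid(x)
print(p_hat, p_true)
```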

Much like our new trick, this one benefits from the fact that all the expensive stuff is done independently of the data (i.e., the value of $x$). However, it seems silly that we "need" to generate two Gumbel RVs to get one sample from the sigmoid. With a little bit of Googling, we discover that the difference of two independent $\textrm{Gumbel}(0,1)$ RVs is a logistic RV (specifically $\textrm{Logistic}(0,1)$).

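It turns out that $\textrm{logit}(\textrm{Uniform}(0,1))$ is also a $\textrm{Logistic}(0,1)$ RV. The sigmoid is the CDF of $\textrm{Logistic}(0,1)$, so logit is its inverse CDF, and this is just inverse transform sampling.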
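Both routes to logistic noise can be checked with a Kolmogorov-Smirnov statistic against SciPy's standard logistic distribution (`scipy.stats.kstest` with the `"logistic"` name). The seed and sample size are arbitrary. Small statistics for both samples are consistent with $\textrm{Logistic}(0,1)$.

```python
import numpy as np
from scipy.special import logit
from scipy.stats import kstest

rng = np.random.default_rng(1)
n = 100_000
z_logit = logit(rng.uniform(size=n))                 # logit(Uniform(0,1))
z_gumbel = rng.gumbel(size=n) - rng.gumbel(size=n)   # Gumbel(0,1) - Gumbel(0,1)

d_logit = kstest(z_logit, "logistic").statistic      # max |empirical CDF - sigmoid|
d_gumbel = kstest(z_gumbel, "logistic").statistic
print(d_logit, d_gumbel)
```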

Voila! Our fast sampling trick and the Gumbel max trick are connected!

Another trick is Justin Domke's trick to reduce calls to $\exp$ by $\approx 88\%$. The disadvantage of this approach is that it's harder to implement with vectorization. The advantage is that we don't need to precompute any expensive things.

Appendix

Logit plot

In [3]:
xs = np.linspace(0,1,100)
ys = logit(xs)
pl.plot(xs, ys);

Logistic random variable

Check that our sampling method is equivalent to sampling from a logistic distribution.

In [5]:
from scipy.stats import logistic
u = uniform(0,1,size=10000)
z = logit(u)
pl.hist(z, bins=100, density=True)
xs = np.linspace(-6,6,100)
ys = logistic.pdf(xs)
pl.plot(xs, ys, c='r', lw=2);