Graduate Descent

Reversing a sequence with sublinear space

algorithms automatic-differentiation data-structures

Suppose we have a computation which generates a sequence of states \(s_1 \ldots s_n\) according to \(s_{t} = f(s_{t-1})\), where \(s_0\) is given.

We'd like to devise an algorithm that can reconstruct each point in the sequence efficiently as we traverse it backwards. You can think of this as "hitting undo" from the end of the sequence, or as reversing a singly linked list.

Obviously, we could just record the entire sequence, but if \(n\) is large or the size of each state is large, this will be infeasible.

Idea 0: Rerun the forward pass \(n\) times. Runtime \(\mathcal{O}(n^2)\), space \(\mathcal{O}(1)\).
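Here's a minimal sketch of Idea 0, mostly to pin down the interface (the particular transition function and the numbers in the toy check are stand-ins of my own, not anything from the post):

```python
def reverse_rerun(f, s0, n):
    """Yield s_n, ..., s_1 in O(1) extra space and O(n^2) time."""
    for t in range(n, 0, -1):
        s = s0
        for _ in range(t):       # replay the forward pass from scratch up to time t
            s = f(s)
        yield s

# toy usage: f doubles the state, so s_t = s0 * 2**t
assert list(reverse_rerun(lambda s: 2 * s, 1, 4)) == [16, 8, 4, 2]
```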

Idea 1: Suppose we save \(0 < k \le n\) evenly spaced "checkpoint" states. Clearly, this gives us \(\mathcal{O}(k)\) space, but what does it do to the runtime? Well, if we are at time \(t\), then we have to "replay" the computation from the last recorded checkpoint to get \(s_t\), which takes \(\mathcal{O}(n/k)\) time. Thus, the overall runtime becomes \(\mathcal{O}(n^2/k)\). This runtime is not ideal.
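A sketch of Idea 1, using the same stand-in interface as above (the dictionary bookkeeping and the toy check are my own choices):

```python
def reverse_checkpointed(f, s0, n, k):
    """Yield s_n, ..., s_1 in O(k) space and O(n^2/k) time."""
    stride = max(1, n // k)
    # Forward pass: remember a checkpoint every `stride` steps.
    checkpoints = {0: s0}
    s = s0
    for t in range(1, n + 1):
        s = f(s)
        if t % stride == 0:
            checkpoints[t] = s
    # Backward pass: replay from the nearest checkpoint at or before t.
    for t in range(n, 0, -1):
        c = (t // stride) * stride
        s = checkpoints[c]
        for _ in range(t - c):
            s = f(s)
        yield s

assert list(reverse_checkpointed(lambda s: 2 * s, 1, 6, 3)) == [64, 32, 16, 8, 4, 2]
```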

Idea 2: Idea 1 did something kind of silly: within a chunk of size \(n/k\), it replays the same computations over and over! Suppose we increase the memory requirement just a little bit so that we also remember the current chunk we're working on, making it \(\mathcal{O}(k + n/k)\). Now, we compute each state at most \(2\) times: once in the initial sequence and once in the reverse. This implies a linear runtime. Now, the question: how should we set \(k\) so that we minimize the extra space? Easy! Solve the following little optimization problem:

$$ \underset{k}{\textrm{argmin}}\ k+n/k = \sqrt{n} $$
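(To see this, set the derivative to zero: \(\frac{d}{dk}\left(k + \frac{n}{k}\right) = 1 - \frac{n}{k^2} = 0\) gives \(k = \sqrt{n}\), at which point the space is \(k + n/k = 2\sqrt{n}\).)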

That's nuts! We get away with sublinear space \(\mathcal{O}(\sqrt{n})\) and we only blow up our runtime by a factor of 2. Also, I really love the "introduce a parameter then optimize it out" trick.
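A sketch of Idea 2 under the same stand-in interface (the chunking details below, e.g. using `math.isqrt` for the chunk size and buffering one chunk in a list, are my own assumptions):

```python
from math import isqrt

def reverse_sqrt(f, s0, n):
    """Yield s_n, ..., s_1 in O(sqrt(n)) space and O(n) time."""
    k = max(1, isqrt(n))                 # chunk size ~ sqrt(n)
    # Forward pass: keep a checkpoint at every multiple of k (~ sqrt(n) of them).
    checkpoints = {0: s0}
    s = s0
    for t in range(1, n + 1):
        s = f(s)
        if t % k == 0:
            checkpoints[t] = s
    # Backward pass: rebuild one chunk at a time, then emit it in reverse.
    end = n
    while end > 0:
        start = ((end - 1) // k) * k     # checkpoint just below `end`
        buf = [checkpoints[start]]
        for _ in range(start + 1, end + 1):
            buf.append(f(buf[-1]))
        yield from reversed(buf[1:])     # s_end, ..., s_{start+1}
        end = start

assert list(reverse_sqrt(lambda s: 2 * s, 1, 5)) == [32, 16, 8, 4, 2]
```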

Idea 3: What if we apply "the remember \(k\) states" trick recursively? I'm going to work this out for \(k=2\) (and then claim that the value of \(k\) doesn't matter).

Let \([b, e)\) denote the current interval, initially \([0, n)\). Run forward from \(s_b\) to get the midpoint state \(s_{m}\), where \(m = b + \lfloor (e-b)/2 \rfloor\). Next, recurse on the left and right chunks \([b,m)\) and \([m,e)\). We hit the base case when the width of the interval is one.

Note that we implicitly store midpoints as we recurse (thanks to the call stack). The max depth of the recursion is \(\mathcal{O}(\log n)\), which gives us a \(\mathcal{O}(\log n)\) space bound.

We can characterize the runtime with the following recurrence relation: \(T(n) = 2 \cdot T(n/2) + \mathcal{O}(n)\). We recognize this as the mergesort recurrence, so it solves to \(\mathcal{O}(n \log n)\) time. Also, just like in the case of sorting, the branching factor doesn't matter, so we're happy with our initial assumption that \(k=2\).
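A sketch of Idea 3 as a recursive generator (the interval convention \((b, e]\) and the toy check are my own; the recursion depth, and hence the number of saved states, is \(\mathcal{O}(\log n)\)):

```python
def reverse_recursive(f, s0, n):
    """Yield s_n, ..., s_1 in O(log n) space and O(n log n) time."""
    def rec(s_b, b, e):
        # Given s_b, emit s_e, ..., s_{b+1}, i.e. the interval (b, e] in reverse.
        if e <= b:
            return
        if e == b + 1:
            yield f(s_b)
            return
        m = b + (e - b) // 2
        s_m = s_b
        for _ in range(m - b):          # run forward to the midpoint
            s_m = f(s_m)
        yield from rec(s_m, m, e)       # right half first (we're going backwards)
        yield from rec(s_b, b, m)       # then the left half
    return rec(s0, 0, n)

assert list(reverse_recursive(lambda s: 2 * s, 1, 5)) == [32, 16, 8, 4, 2]
```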

Remarks

The algorithms described in this post are generic algorithmic tricks which have been used in a number of places, including

  • The classic computer science interview problem of reversing a singly linked list under a tight budget on additional memory.

  • Backpropagation for computing gradients in sequence models, including HMMs (Zweig & Padmanabhan, 2000) and RNNs (Chen et al., 2016). I have sample code that illustrates the basic idea below.

  • Memory-efficient omniscient debugging, which allows a user to inspect program state while moving forward and backward in time.

Sample code
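The post's original sample code isn't reproduced here. As a stand-in, here is a hedged sketch of the backpropagation use case from the remarks: computing \(\partial s_n / \partial s_0\) for a scalar recurrence by visiting the states in reverse with the \(\mathcal{O}(\sqrt{n})\)-space traversal from Idea 2. It assumes the `reverse_sqrt` sketch above is in scope; the choice of \(f = \sin\) and the brute-force check are mine.

```python
import math
from itertools import chain

def grad_through_recurrence(f, df, s0, n):
    """Return d s_n / d s_0 for s_t = f(s_{t-1}), where df is f's derivative.

    The chain rule needs s_{n-1}, ..., s_0 in reverse order; we get them
    from reverse_sqrt (Idea 2 above) instead of storing the whole sequence.
    """
    adjoint = 1.0
    for s in chain(reverse_sqrt(f, s0, n - 1), [s0]):
        adjoint *= df(s)                # accumulate f'(s_t) for t = n-1, ..., 0
    return adjoint

# toy check against a brute-force reference that stores every state
f, df = math.sin, math.cos
g = grad_through_recurrence(f, df, 0.3, 50)

states = [0.3]
for _ in range(50):
    states.append(f(states[-1]))
ref = 1.0
for t in range(50):
    ref *= df(states[t])

assert abs(g - ref) < 1e-12
```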
