Graduate Descent

Counterfactual reasoning and learning from logged data

Counterfactual reasoning is reasoning about data that we did not observe. For example, reasoning about the expected reward of a new policy given data collected from an older one.

In this post, I'll discuss some basic techniques for learning from logged data. For the most part, this post is based on things I learned from Peshkin & Shelton (2002) and Bottou et al. (2013) (two of my all-time favorite papers).

After reading, have a look at the Jupyter notebook accompanying this post!

Setup (off-line off-policy optimization): We're trying to optimize a function of the form,

$$ J(\theta) = \underset{p_\theta}{\mathbb{E}} \left[ r(x) \right] = \sum_{x \in \mathcal{X}} p_\theta(x) r(x). $$


But! We only have a fixed sample of size \(m\) from a data collection policy \(q\), \(\{ (r^{(j)}, x^{(j)} ) \}_{j=1}^m \overset{\text{i.i.d.}}{\sim} q.\)

  • Although it's not necessarily the case, you can think of \(q = p_{\theta'}\) for some fixed value \(\theta'.\)

  • \(\mathcal{X}\) is an arbitrary multivariate space, which permits a mix of continuous and discrete components, with appropriate densities, \(p_{\theta}\) and \(q\) defined over it.

  • \(r: \mathcal{X} \mapsto \mathbb{R}\) is a black box that outputs a scalar score.

  • I've used the notation \(r^{(j)}\) instead of \(r(x^{(j)})\) to emphasize that we can't evaluate \(r\) at \(x\) values other than those in the sample.

  • We'll assume that \(q\) assigns positive probability everywhere, \(q(x) > 0\) for all \(x \in \mathcal{X}\). This means that the data collection process must be randomized and eventually sample all possible configurations. Later, I discuss relaxing this assumption.

Each distribution is a product of one or more factors of the following types: policy factors (at least one), which directly depend on \(\theta\), and environment factors (possibly none), which do not depend directly on \(\theta\). Note that environment factors are only accessible via sampling (i.e., we don't know the value they assign to a sample). For example, consider a contextual bandit problem, where \(x\) is a context-action pair, \(x = (s,a)\). Here \(q(x) = q(a|s) p(s)\) and \(p_{\theta}(x) = p_{\theta}(a|s) p(s)\). Note that \(p_{\theta}\) and \(q\) share the environment factor \(p(s)\), the distribution over contexts, and differ only in the action-given-context factor. For now, assume that we can evaluate all environment factors; later, I'll discuss how to cleverly work around that assumption.

The main challenge of this setting is that we don't have controlled experiments to learn from because \(q\) is not (completely) in our control. This manifests itself as high variance ("noise") in estimating \(J(\theta)\). Consider the contextual bandit setting: we receive a context \(s\) and execute a single action; we never get to roll back to that precise context and try an alternative action (to get a paired sample #yolo) because we do not control \(p(s)\). This is an important paradigm for many 'real world' problems, e.g., predicting medical treatments or ad selection.

Estimating \(J(\theta)\) [V1]: We obtain an unbiased estimator of \(J(\theta)\) with importance sampling,

$$ J(\theta) \approx \hat{J}_{\!\text{IS}}(\theta) = \frac{1}{m} \sum_{j=1}^m r^{(j)} \!\cdot\! w^{(j)}_{\theta} \quad \text{ where } w^{(j)}_{\theta} = \frac{p_{\theta}(x^{(j)}) }{ q(x^{(j)}) }. $$


This estimator is remarkable: it uses importance sampling as a function approximator! We have an unbiased estimate of \(J(\theta)\) for any value of \(\theta\) that we like. The catch is that we have to pick \(\theta\) a priori, i.e., with no knowledge of the sample.
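
As a concrete sketch (my own code, not from the papers), suppose each logged record stores the reward, the logging probability \(q(x^{(j)})\), and whatever is needed to evaluate \(p_\theta(x^{(j)})\); the names and log format are assumptions:

import numpy as np

def J_IS(theta, data, p_theta):
    """Unbiased importance sampling estimate of J(theta).

    data: list of (r, q_x, x) triples logged under q, where q_x = q(x).
    p_theta(theta, x): evaluates the new policy's probability of x.
    """
    r = np.array([r_j for (r_j, q_x, x) in data])
    w = np.array([p_theta(theta, x) / q_x for (r_j, q_x, x) in data])
    return np.mean(r * w)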

After we've collected a (large) sample, it's possible to optimize \(\hat{J}_{\!\text{IS}}\) with any optimization algorithm (e.g., L-BFGS). Of course, we risk overfitting to the sample if we optimize \(\hat{J}_{\!\text{IS}}\) directly. Actually, it's a bit worse: this objective tends to favor regions of \(\theta\) that are not well represented in the sample, because the importance sampling estimator has high variance in those regions due to large importance weights (when \(q(x)\) is small and \(p_{\theta}(x)\) is large, \(w_{\theta}(x)\) is large, and consequently so is \(\hat{J}_{\!\text{IS}}\), regardless of whether \(r(x)\) is high!). Thus, we want some type of "regularization" to keep the optimizer in regions that are sufficiently well-represented by the sample.

Better surrogate [V2]: There are many ways to reduce the variance of the estimator and confidently obtain improvements to the system. One of my favorites is Bottou et al.'s lower bound on \(J(\theta)\), which we get by clipping the importance weights: replace \(w^{(j)}_{\theta}\) with \(\min(R, w^{(j)}_{\theta})\).
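
A sketch of the clipped surrogate under the same hypothetical log format as above (assuming \(r \ge 0\), so that clipping can only pull the estimate down):

import numpy as np

def J_clipped(theta, data, p_theta, R):
    "Biased-downward surrogate: clip the importance weights at R."
    r = np.array([r_j for (r_j, q_x, x) in data])
    w = np.array([min(R, p_theta(theta, x) / q_x) for (r_j, q_x, x) in data])
    return np.mean(r * w)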

Confidence intervals [V3]: We can augment the V2 lower bound with confidence intervals derived from the empirical Bernstein bound (EBB). We'll require that \(r\) is bounded and that we know its max/min values. The EBB penalizes hypotheses (values of \(\theta\)) that have higher sample variance. (Note: a Hoeffding bound wouldn't change the shape of the objective, but the EBB does, thanks to the sample-variance penalty; the EBB also tends to be tighter.) The EBB introduces an additional "confidence" hyperparameter, \((1-\delta)\). Bottou et al. recommend maximizing the lower bound as it provides safer improvements. See the original paper for the derivation.
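
Here's a sketch of a confidence lower bound on the clipped objective using one common form of the empirical Bernstein bound; this is my reading of the construction, not necessarily the exact bound Bottou et al. use (see their paper for that):

import numpy as np

def J_ebb_lower(theta, data, p_theta, R, r_max, delta=0.05):
    """Lower confidence bound on the clipped objective, valid w.p. >= 1 - delta.

    Empirical-Bernstein-style bound for i.i.d. terms in [0, b]:
        mean - sqrt(2*var*log(2/delta)/m) - 7*b*log(2/delta)/(3*(m-1)),
    assuming rewards lie in [0, r_max], so each clipped term lies in [0, R*r_max].
    """
    t = np.array([r_j * min(R, p_theta(theta, x) / q_x) for (r_j, q_x, x) in data])
    m, b = len(t), R * r_max
    penalty = (np.sqrt(2 * t.var(ddof=1) * np.log(2 / delta) / m)
               + 7 * b * np.log(2 / delta) / (3 * (m - 1)))
    return t.mean() - penalty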

Both V2 and V3 are biased (since they are lower bounds), but we can mitigate the bias by tuning the hyperparameter \(R\) on a heldout sample (we can even tune \(\delta\), if desired). Additionally, V2 and V3 remain 'valid' when \(q\) has limited support, since they prevent the importance weights from exploding (of course, the bias can be arbitrarily bad, but that is probably unavoidable in the learn-only-from-logged-data setup).

Extensions

Unknown environment factors: Consider the contextual bandit setting (mentioned above). Here \(p_\theta\) and \(q\) share an unknown environment factor: the distribution over contexts. Luckily, we do not need to know the value of this factor in order to apply any of our estimators, because they are all based on likelihood ratios, so the shared unknown factors cancel out! Some specific examples are given below, followed by a small code sketch. Of course, these factors do influence the estimators because they are crucial in sampling; they just aren't needed for evaluation.

  • In the contextual bandit example, \(x\) is a context-action pair, \(w_{\theta}(x) = \frac{p_{\theta}(x)}{q(x)} = \frac{ p_{\theta}(s,a) }{ q(s,a) } = \frac{p_{\theta}(a|s) p(s)}{q(a|s) p(s)} = \frac{p_{\theta}(a|s) }{ q(a|s) }\).

  • In a Markov decision process, \(x\) is a sequence of state-action pairs, \(w_{\theta}(x) = \frac{p_{\theta}(x)}{q(x)} = \frac{ p(s_0) \prod_{t=0}^T p(s_{t+1}|s_t,a_t)\, p_\theta(a_t|s_t) } { p(s_0) \prod_{t=0}^T p(s_{t+1}|s_t,a_t)\, q(a_t|s_t) } = \frac{\prod_{t=0}^T p_\theta(a_t|s_t)} {\prod_{t=0}^T q(a_t|s_t)}.\)
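
For instance, in the MDP case the weight can be computed from logged action probabilities alone. Here's a small sketch under a hypothetical log format that records \(q(a_t|s_t)\) at each step (names are mine):

import numpy as np

def trajectory_weight(theta, trajectory, p_theta_action):
    """Importance weight for one logged trajectory.

    trajectory: list of (s, a, q_a) with q_a = q(a|s) recorded at logging time.
    p_theta_action(theta, s, a): the new policy's action probability.
    The unknown dynamics and initial-state factors cancel, so they never appear.
    """
    logw = sum(np.log(p_theta_action(theta, s, a)) - np.log(q_a)
               for (s, a, q_a) in trajectory)
    return np.exp(logw)   # accumulate in log space for numerical stability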

Variance reduction: These estimators can all be improved with variance reduction techniques. Probably the most effective technique is the use of control variates (of which baseline functions are a special case). These are random variables that are correlated with \(r(x)\) and whose expectations we know (or can at least estimate separately). A great example is how ad clicks depend strongly on time-of-day (fewer people are online late at night, so we get fewer clicks); thus, the time-of-day covariate explains a large part of the variation in \(r(x)\).
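
As a tiny illustration (my own sketch, not from the papers): since \(\mathbb{E}_q[w_\theta(x)] = 1\), the average importance weight is itself a control variate with known expectation; baselines and problem-specific covariates (like time-of-day) follow the same recipe.

import numpy as np

def J_IS_cv(r, w):
    """Importance sampling estimate using the weight itself as a control variate.

    r, w: arrays of logged rewards and importance weights.
    Since E_q[w] = 1, (mean(w) - 1) has known expectation zero; subtracting a
    scaled copy of it reduces variance. Estimating the coefficient from the
    same sample introduces a small bias.
    """
    c = np.cov(r * w, w)[0, 1] / np.var(w, ddof=1)   # variance-optimal coefficient
    return np.mean(r * w) - c * (np.mean(w) - 1.0)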

Estimation instead of optimization: You can use this general setup for estimation instead of optimization, in which case it's fine to let \(r\) have real-valued multivariate output. The confidence intervals are probably useful in that setting too.

Unknown \(q\): Often \(q\) is an existing complex system that does not record its probabilities. It is possible to use regression to estimate \(q\) from the samples; this estimate is called the propensity score (PS). The PS attempts to account for confounding variables, which are hidden causes that control variation in the data. Failing to account for confounding variables may lead to incorrect conclusions. Unfortunately, PS results in a biased estimator because we're using a 'ratio of expectations' (we divide by the PS estimate) instead of an 'expectation of ratios'. PS is only statistically consistent in the (unlikely) event that the density estimate is correctly specified (i.e., we can eventually get \(q\) right). In the unknown-\(q\) setting, it's often better to use the doubly robust estimator (DR), which combines two estimators: a density estimate of \(q\) and a function approximation of \(r\). A great explanation for the bandit case is in Dudík et al. (2011). The DR estimator is also biased, but it has a better bias-variance tradeoff than PS.
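
A sketch of the DR estimator for the contextual-bandit case, roughly following Dudík et al. (2011); here `r_hat` (reward regression) and `q_hat` (estimated propensity) are assumed to be fit separately, and the names are mine:

def J_DR(theta, data, p_theta_action, actions, r_hat, q_hat):
    """Doubly robust estimate of J(theta) for contextual bandits.

    data: list of (s, a, r) triples logged under the unknown policy q.
    p_theta_action(theta, s, a): the new policy's probability of a given s.
    r_hat(s, a): regression estimate of the expected reward.
    q_hat(s, a): estimated propensity score.
    """
    total = 0.0
    for (s, a, r) in data:
        # Model-based term: expected reward of the new policy under r_hat.
        direct = sum(p_theta_action(theta, s, b) * r_hat(s, b) for b in actions)
        # Importance-weighted correction for r_hat's error at the logged action.
        w = p_theta_action(theta, s, a) / q_hat(s, a)
        total += direct + w * (r - r_hat(s, a))
    return total / len(data)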

What if \(q\) doesn't have support everywhere? This is an especially important setting because data collection policies often abide by safety regulations that rule out known bad configurations. In many situations, evaluating \(r(x)\) corresponds to executing an action \(x\) in the real world, so terrible outcomes could occur, such as breaking a system, giving a patient a bad treatment, or losing money. V1 is ok to use as long as we satisfy the importance sampling support conditions, which might mean rejecting certain values of \(\theta\) (possibly non-trivial to enforce) and consequently finding a less-optimal policy. V2 and V3 are ok to use without an explicit constraint, but additional care may be needed to ensure specific safety constraints are satisfied by the learned policy.

What if \(q\) is deterministic? This is related to the point above, and it is a hard problem. Essentially, this is trying to learn without any exploration/experimentation! In general, we need exploration to learn. Randomization isn't the only way to explore; there are many systematic types of experimentation.

  • There are some cases of systematic experimentation that are ok. For example, enumerating all elements of \(\mathcal{X}\) (almost certainly infeasible). Another example is a contextual bandit where \(q\) assigns actions to contexts deterministically via a hash function (this setting is fine because \(q\) is essentially a uniform distribution over actions, which is independent of the state). In other special cases, we may be able to characterize systematic exploration as stratified sampling.

  • A generic solution might be to apply the doubly-robust estimator, which "smooths out" deterministic components (by pretending they are random) and accounts for confounds (by explicitly modeling them in the propensity score).

What if we control data collection (\(q\))? This is an interesting setting. Essentially, it asks "how do we explore/experiment optimally (and safely)?" In general, this is an open question and depends on many considerations, such as how much control we have, the cost of exploration (safety constraints), and prior knowledge (of \(r\) and of unknown factors in the environment). I've seen some papers cleverly design \(q\). The first that comes to mind is Levine & Koltun (2013). Another setting is online contextual bandits, in which algorithms like EXP4 and Thompson sampling prescribe certain types of exploration and work interactively (i.e., they don't have a fixed training sample). Lastly, I'll mention that there are many techniques for variance reduction by importance sampling, which may apply.

Further reading

Léon Bottou, Jonas Peters, Joaquin Quiñonero-Candela, Denis X. Charles, D. Max Chickering, Elon Portugaly, Dipankar Ray, Patrice Simard, Ed Snelson. Counterfactual reasoning in learning systems. JMLR 2013.

The source for the majority of this post. It includes many other interesting ideas and goes more in depth into some of the details.

Miroslav Dudík, John Langford, Lihong Li. Doubly robust policy evaluation and learning. ICML 2011.

Discussed in extensions section.

Philip S. Thomas. Safe reinforcement learning. PhD Thesis 2015.

Covers confidence intervals for policy evaluation similar to Bottou et al., as well as learning algorithms for RL with safety guarantees (e.g., so we don't break the robot).

Peshkin and Shelton. Learning from scarce experience. ICML 2002.

An older RL paper covering learning from logged data; one of the earliest papers on the topic that I could find.

Levine and Koltun. Guided policy search. ICML 2013.

Discusses clever choices for \(q\) to better-guide learning in the RL setting.

Corinna Cortes, Yishay Mansour, Mehryar Mohri. Learning bounds for importance weighting. NIPS 2010.

Discusses generalization bounds for the counterfactual objective. Includes an alternative weighting scheme to keep importance weights from exploding.

Heaps for incremental computation

In this post, I'll describe a neat trick for maintaining a summary quantity (e.g., sum, product, max, log-sum-exp, concatenation, cross-product) under changes to its inputs. The trick and its implementation are inspired by the well-known max-heap datastructure. I'll also describe a really elegant application to fast sampling under an evolving categorical distribution.

Setup: Suppose we'd like to efficiently compute a summary quantity under changes to its \(n\)-dimensional input vector \(\boldsymbol{w}\). The particular form of the quantity we're going to compute is \(z = \bigoplus_{i=1}^n w_i\), where \(\oplus\) is some associative binary operator with identity element \(\boldsymbol{0}\).

The trick: Essentially, the trick boils down to parenthesis placement in the expression that computes \(z\), a freedom granted by the associative property.

I'll demonstrate by example with \(n=8\).

Linear structure: We generally compute something like \(z\) with a simple left-to-right loop. This looks like a left-branching binary tree (a long chain) when we think about the order of operations,

$$ z = (((((((w_1 \oplus w_2) \oplus w_3) \oplus w_4) \oplus w_5) \oplus w_6) \oplus w_7) \oplus w_8). $$


Heap structure: Here the parentheses form a balanced tree, which looks much more like a recursive implementation that computes the left and right halves and \(\oplus\)s the results (divide-and-conquer style),

$$ z = (((w_1 \oplus w_2) \oplus (w_3 \oplus w_4)) \oplus ((w_5 \oplus w_6) \oplus (w_7 \oplus w_8))). $$


The benefit of the heap structure is that there are \(\mathcal{O}(\log n)\) intermediate quantities that depend on any input, whereas the linear structure has \(\mathcal{O}(n)\). The intermediate quantities correspond to the values of each of the parenthesized expressions.

Since fewer intermediate quantities depend on a given input, fewer intermediates need to be adjusted upon a change to the input. Therefore, we get faster algorithms for maintaining the output quantity \(z\) as the inputs change.

Heap datastructure (aka binary index tree or Fenwick tree): We're going to store the values of the intermediate quantities and inputs in a heap datastructure, which is a complete binary tree. In our case, the tree has depth \(1 + \lceil \log_2 n \rceil\), with the values of \(\boldsymbol{w}\) at its leaves (aligned left) and padding with \(\boldsymbol{0}\) for the remaining leaves. Thus, the array's length is \(< 4 n\).

This structure makes our implementation really nice and efficient because we don't need pointers to find the parent or children of a node (i.e., no need to wrap elements into a "node" class like in a general tree data structure). So, we can pack everything into an array, which means our implementation has great memory/cache locality and low storage overhead.

Traversing the tree is pretty simple: Let \(d\) be the number of internal nodes; nodes \(1 \le i \le d\) are internal. For node \(i\), left child \(\rightarrow {2 \cdot i},\) right child \(\rightarrow {2 \cdot i + 1},\) parent \(\rightarrow \lfloor i / 2 \rfloor.\) (Note that these operations assume the array's indices start at \(1\). We generally fake this by adding a dummy node at position \(0\), which makes the implementation simpler.)

Initializing the heap: Here's code that initializes the heap structure we just described.

import numpy as np

def sumheap(w):
    "Create sumheap from weights `w` in O(n) time."
    n = w.shape[0]
    d = int(2**np.ceil(np.log2(n)))  # leaf capacity (power of two); index of first leaf
    S = np.zeros(2*d)                # heap array: dummy + internal nodes + leaves
    S[d:d+n] = w                     # store `w` at the leaves
    for i in reversed(range(1, d)):  # fill internal nodes bottom-up
        S[i] = S[2*i] + S[2*i + 1]
    return S

Updating \(w_k\) boils down to fixing intermediate sums that (transitively) depend on \(w_k.\) I won't go into all of the details here; instead, I'll give code (below). I'd like to quickly point out that the term "parents" is not great for our purposes because they are actually the dependents: when an input changes its value, the parents, grandparents, great-grandparents, etc., become stale and need to be recomputed bottom-up (from the leaf toward the root). The code below implements the update method for changing the value of \(w_k\) and runs in \(\mathcal{O}(\log n)\) time.

def update(S, k, v):
    "Update `w[k] = v` in O(log n) time."
    d = S.shape[0]
    i = d//2 + k   # index of leaf k
    S[i] = v
    while i > 1:   # fix stale ancestors bottom-up (stop before the dummy node 0).
        i //= 2
        S[i] = S[2*i] + S[2*i + 1]

Remarks

  • Numerical stability: If the operations are noisy (e.g., floating-point arithmetic), then the heap version may be better behaved. For example, if each operation incurs an independent, additive error of size \(\varepsilon\), then the error of \(z_{\text{heap}}\) is \(\mathcal{O}(\varepsilon \cdot \log n)\), whereas that of \(z_{\text{linear}}\) is \(\mathcal{O}(\varepsilon \cdot n)\). (Without further assumptions about the underlying operator, I don't believe you can do better than that.) A small demo follows this list.

  • Relationship to max-heap: In the case of a max or min heap, we can avoid allocating extra space for intermediate quantities because all intermediate values are equal to exactly one element of \(\boldsymbol{w}\).

  • Change propagation: Adjusting cached intermediate quantities under changes to the input is a neat general idea. In fact, we encounter it each time we type make at the command line! The general technique goes by many names, including change propagation, incremental maintenance, and functional reactive programming, and applies to basically any side-effect-free computation. However, it's most effective when the dependency structure of the computation is sparse and requires little overhead to find and refresh stale values. In our example of computing \(z\), these considerations manifest themselves as the heap vs. linear structures and our fast array implementation instead of a generic tree datastructure.
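
As a quick illustration of the numerical-stability remark (a demo I added, not from the original post), compare left-to-right float32 accumulation with NumPy's pairwise (tree-structured) summation:

import numpy as np

x = np.full(10**6, 0.1, dtype=np.float32)   # exact sum is 100000

acc = np.float32(0.0)
for v in x:              # linear (left-to-right) accumulation: error grows roughly with n
    acc += v

print(acc)                           # drifts noticeably away from 100000
print(np.sum(x))                     # np.sum uses pairwise summation: much closer
print(x.astype(np.float64).sum())    # higher-precision reference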

Generalizations

  • No zero? No problem. We don't actually require a zero element. It's fair to work over \(\boldsymbol{K} \cup \{ \textsf{null} \}\), where \(\boldsymbol{K}\) is the domain of \(\oplus\) and \(\textsf{null}\) is a distinguished value (i.e., \(\textsf{null} \notin \boldsymbol{K}\)) that acts just like a zero after we overload \(\oplus\) to satisfy the definition of a zero (e.g., by adding an if-statement).

  • Generalization to arbitrary keys instead of fixed vectors is possible with a "locator" map, which is a bijective map from elements to indices in a dense array.

  • Support for growing and shrinking: We support growing by maintaining an underlying array that is always slightly larger than we need (which we're already doing in the heap datastructure). Doubling the size of the underlying array (i.e., rounding up to the next power of two) has the added benefit of letting us grow \(\boldsymbol{w}\) at no asymptotic cost! This is because the resize operation, which requires \(\mathcal{O}(n)\) time to allocate a new array and copy over the old values, happens so infrequently that it is completely amortized. We get the effect of shrinking by replacing the old value with \(\textsf{null}\) (or \(\boldsymbol{0}\)). We can shrink the underlying array when the fraction of nonzeros dips below \(25\%\); this threshold prevents "thrashing" between shrinking and growing. (A small sketch of growing follows.)
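
Here's a minimal sketch of the doubling trick for appends, reusing the `sumheap` and `update` routines from above; this is my own sketch (not from the original gist), and the caller is assumed to track `n`, the number of live weights.

import numpy as np

def append(S, n, v):
    """Append weight number n (0-indexed) with value v; amortized O(log n) time.

    Returns the (possibly reallocated) heap array. Assumes `sumheap` and
    `update` defined above.
    """
    d = S.shape[0] // 2          # current leaf capacity
    if n >= d:                   # out of leaf slots: double the capacity, O(n) but rare
        w = np.zeros(2 * d)
        w[:d] = S[d:]            # copy the existing leaves
        S = sumheap(w)           # rebuild the heap over the larger array
    update(S, n, v)              # place the new weight and fix its ancestors
    return S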

Application

Sampling from an evolving distribution: Suppose that \(\boldsymbol{w}\) corresponds to a categorical distribution over \(\{1, \ldots, n\}\) and that we'd like to sample elements in proportion to this (unnormalized) distribution.

Other methods, like the alias or inverse-CDF methods, are efficient after a somewhat costly initialization step. But! They are not as efficient as the heap sampler when the distribution is being updated. (I'm not sure whether variants of the alias method that support updates exist.)

Method   Sample      Update      Init
alias    O(1)        O(n)?       O(n)
i-CDF    O(log n)    O(n)        O(n)
heap     O(log n)    O(log n)    O(n)

Use cases include

  • Gibbs sampling, where distributions are constantly modified and sampled from (changes may not be sparse so YMMV). The heap sampler is used in this paper.

  • EXP3 (a multi-armed bandit algorithm) is an excellent example of an algorithm that samples from the distribution and then modifies a single weight.

  • Stochastic priority queues, where we sample items in proportion to their priority; the weights of items in the queue may change, elements may be removed after they are sampled (i.e., sampling without replacement), and new elements may be added.

Again, I won't spell out all of the details of these algorithms. Instead, I'll just give the code.

Inverse CDF sampling

from numpy.random import uniform

def sample(w):
    "Ordinary sampling method: O(n) init, O(log n) per sample."
    c = w.cumsum()            # build CDF, O(n)
    p = uniform() * c[-1]     # random probe, p ~ Uniform(0, z)
    return c.searchsorted(p)  # binary search, O(log n)

Heap sampling is essentially the same, except the CDF is stored as a heap, which is perfect for binary search!

from numpy.random import uniform

def hsample(S):
    "Sample from sumheap, O(log n) per sample."
    d = S.shape[0]//2     # number of internal nodes.
    p = uniform() * S[1]  # random probe, p ~ Uniform(0, z)
    # Use binary search to find the index of the largest CDF (represented as a
    # heap) value that is less than a random probe.
    i = 1
    while i < d:
        # Determine if the value is in the left or right subtree.
        i *= 2         # Point at left child
        left = S[i]    # Probability mass under left subtree.
        if p > left:   # Value is in right subtree.
            p -= left  # Subtract mass from left subtree
            i += 1     # Point at right child
    return i - d

Code: Complete code and test cases for heap sampling are available in this gist.

Reversing a sequence with sublinear space

Suppose we have a computation which generates a sequence of states \(s_1 \ldots s_n\) according to \(s_{t} = f(s_{t-1})\), where \(s_0\) is given.

We'd like to devise an algorithm that can efficiently reconstruct each state in the sequence as we traverse it backwards. You can think of this as "hitting undo" from the end of the sequence, or as reversing a singly-linked list.

Obviously, we could just record the entire sequence, but if \(n\) is large or the size of each state is large, this will be infeasible.

Idea 0: Rerun the forward pass \(n\) times. Runtime \(\mathcal{O}(n^2)\), space \(\mathcal{O}(1)\).

Idea 1: Suppose we save \(0 < k \le n\) evenly spaced "checkpoint" states. Clearly, this gives us \(\mathcal{O}(k)\) space, but what does it do to the runtime? Well, if we are at time \(t\), then we have to "replay" the computation from the last recorded checkpoint to get \(s_t\), which takes \(\mathcal{O}(n/k)\) time. Thus, the overall runtime becomes \(\mathcal{O}(n^2/k)\). This runtime is not ideal.

Idea 2: Idea 1 did something kind of silly: within a chunk of size \(n/k\), it redoes the same computations multiple times! Suppose we increase the memory requirement just a little bit to remember the current chunk we're working on, making it \(\mathcal{O}(k + n/k)\). Now, we compute each state at most twice: once in the initial sequence and once in the reverse pass. This implies a linear runtime. Now the question: how should we set \(k\) to minimize the extra space? Easy! Solve the following little optimization problem:

$$ \underset{k}{\textrm{argmin}}\ k+n/k = \sqrt{n} $$

That's nuts! We get away with sublinear space \(\mathcal{O}(\sqrt{n})\) and we only blow up our runtime by a factor of 2. Also, I really love the "introduce a parameter then optimize it out" trick.
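
Here's a minimal sketch of Idea 2 (my own code; it assumes the computation is exposed as a transition function `f` and an initial state `s0`): record every \(k\)-th state on the forward pass with \(k \approx \sqrt{n}\), then replay and reverse one chunk at a time.

import math

def reverse_sqrt_space(f, s0, n):
    """Yield s_n, ..., s_1 using O(sqrt(n)) extra space and O(n) total extra time."""
    k = max(1, math.isqrt(n))           # chunk size ~ sqrt(n)
    # Forward pass: checkpoint the state at the start of each chunk.
    checkpoints = []
    s = s0
    for t in range(n):
        if t % k == 0:
            checkpoints.append((t, s))  # s here is s_t
        s = f(s)
    # Backward pass: replay each chunk from its checkpoint, then emit it in reverse.
    for (t0, c) in reversed(checkpoints):
        chunk, s = [], c
        for t in range(t0, min(t0 + k, n)):
            s = f(s)
            chunk.append(s)             # states s_{t0+1}, ..., s_{min(t0+k, n)}
        yield from reversed(chunk)

The checkpoints take \(\mathcal{O}(n/k)\) space and the replay buffer \(\mathcal{O}(k)\), which is exactly the \(k + n/k\) tradeoff above.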

Idea 3: What if we apply "the remember \(k\) states" trick recursively? I'm going to work this out for \(k=2\) (and then claim that the value of \(k\) doesn't matter).

Run forward from the start of the current interval \([b, e)\) to get the midpoint state \(s_{m}\), where \(m = b + \lfloor (e-b)/2 \rfloor\). Next, recurse on the left and right halves, \([b,m)\) and \([m,e)\). We hit the base case when the width of the interval is one.

Note that we implicitly store midpoints as we recurse (thanks to the stack frame). The max depth of the recursion is \(\mathcal{O}(\log n)\), which gives us a \(\mathcal{O}(\log n)\) space bound.

We can characterize the runtime with the following recurrence relation, \(T(n) = 2 \cdot T(n/2) + \mathcal{O}(n)\). Since we recognize this as the recurrence for mergesort, we know that it flattens to \(\mathcal{O}(n \log n)\) time. Also, just like in the case of sorting, the branching factor doesn't matter, so we're happy with our initial assumption that \(k=2\).
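
And a sketch of the recursive version (again my own code, under the same assumed `f`/`s0` interface):

def reverse_logspace(f, s0, n):
    """Yield s_n, ..., s_1 using O(log n) space and O(n log n) time."""

    def rec(s_b, b, e):
        # Emit s_e, s_{e-1}, ..., s_{b+1}, given the state s_b for the interval [b, e).
        if e == b:
            return
        if e == b + 1:
            yield f(s_b)                # the single state s_{b+1}
            return
        m = b + (e - b) // 2
        s_m = s_b
        for _ in range(m - b):          # run forward to the midpoint state s_m
            s_m = f(s_m)
        yield from rec(s_m, m, e)       # right half first (we are going backwards)
        yield from rec(s_b, b, m)       # then the left half

    yield from rec(s0, 0, n)

# e.g. list(reverse_logspace(lambda s: s + 1, 0, 8)) == [8, 7, 6, 5, 4, 3, 2, 1]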

Remarks

The algorithms described in this post are generic algorithmic tricks that have been used in a number of places, including

  • The classic computer science interview problem of reversing a singly-linked list under a tight budget on additional memory.

  • Backpropagation for computing gradients in sequence models, including HMMs (Zweig & Padmanabhan, 2000) and RNNs (Chen et al., 2016).

  • Memory-efficient omniscient debugging, which allows a user to inspect program state while moving forward and backward in time.

Here's the code gist, complete with test cases.