Graduate Descent

Counterfactual reasoning and learning from logged data

Counterfactual reasoning is reasoning about data that we did not observe. For example, reasoning about the expected reward of a new policy given data collected from an older one.

In this post, I'll discuss some basic techniques for learning from logged data. For the most part, this post is based on things I learned from Peshkin & Shelton (2002) and Bottou et al. (2013) (two of my all-time favorite papers).

After reading, have a look at the Jupyter notebook accompanying this post!

Setup (off-line off-policy optimization): We're trying to optimize a function of the form,

$$ J(\theta) = \underset{p_\theta}{\mathbb{E}} \left[ r(x) \right] = \sum_{x \in \mathcal{X}} p_\theta(x) r(x). $$


But! We only have a fixed sample of size \(m\) from a data collection policy \(q\), \(\{ (r^{(j)}, x^{(j)} ) \}_{j=1}^m \overset{\text{i.i.d.}}{\sim} q.\)

  • Although it's not necessarily the case, you can think of \(q = p_{\theta'}\) for a fixed value \(\theta'.\)

  • \(\mathcal{X}\) is an arbitrary multivariate space, which permits a mix of continuous and discrete components, with appropriate densities, \(p_{\theta}\) and \(q\) defined over it.

  • \(r: \mathcal{X} \to \mathbb{R}\) is a black box that outputs a scalar score.

  • I've used the notation \(r^{(j)}\) instead of \(r(x^{(j)})\) to emphasize that we can't evaluate \(r\) at \(x\) values other than those in the sample.

  • We'll assume that \(q\) assigns positive probability everywhere, \(q(x) > 0\) for all \(x \in \mathcal{X}\). This means that the data collection process must be randomized and eventually sample all possible configurations. Later, I discuss relaxing this assumption.

Each distribution is a product of one or more factors of the following types: policy factors (at least one), which directly depend on \(\theta\), and environment factors (possibly none), which do not depend directly on \(\theta\). Note that environment factors are only accessible via sampling (i.e., we don't know the value they assign to a sample). For example, consider a contextual bandit problem, where \(x\) is a context-action pair, \(x = (s,a)\). Here \(q(x) = q(a|s) p(s)\) and \(p_{\theta}(x) = p_{\theta}(a|s) p(s)\). Note that \(p_{\theta}\) and \(q\) share the environment factor \(p(s)\), the distribution over contexts, and only differ in the action-given-context factor. For now, assume that we can evaluate all environment factors; later, I'll discuss how we cleverly work around this assumption.

The main challenge of this setting is that we don't have controlled experiments to learn from because \(q\) is not (completely) in our control. This manifests itself as high variance ("noise") in estimating \(J(\theta)\). In the contextual bandit setting, we receive a context \(s\) and execute a single action; we never get to roll back to that precise context and try an alternative action (to get a paired sample #yolo) because we do not control \(p(s)\). This is an important paradigm for many 'real world' problems, e.g., predicting medical treatments or ad selection.

Estimating \(J(\theta)\) [V1]: We obtain an unbiased estimator of \(J(\theta)\) with importance sampling,

$$ J(\theta) \approx \hat{J}_{\!\text{IS}}(\theta) = \frac{1}{m} \sum_{j=1}^m r^{(j)} \!\cdot\! w^{(j)}_{\theta} \quad \text{ where } w^{(j)}_{\theta} = \frac{p_{\theta}(x^{(j)}) }{ q(x^{(j)}) }. $$


This estimator is remarkable: it uses importance sampling as a function approximator! We have an unbiased estimate of \(J(\theta)\) for any value of \(\theta\) that we like. The catch is that we have to pick \(\theta\) a priori, i.e., with no knowledge of the sample.
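To make V1 concrete, here is a minimal Python sketch on a simulated contextual bandit. Everything about the environment (two contexts, three actions, Bernoulli rewards with the means in the hypothetical `MEAN_REWARD` table, a tabular softmax policy) is an illustrative assumption; only `j_is` corresponds to the estimator above. Because the toy environment is known, `j_true` can compute the exact \(J(\theta)\) for comparison, something we could not do with real logged data.

```python
import math
import random

# Hypothetical toy environment (assumption, not from the post):
# 2 contexts, 3 actions, Bernoulli rewards with the means below.
CONTEXTS = [0, 1]
ACTIONS = [0, 1, 2]
P_CONTEXT = [0.7, 0.3]              # environment factor p(s)
MEAN_REWARD = [[0.2, 0.5, 0.9],     # E[r | s, a]
               [0.8, 0.3, 0.1]]


def softmax_policy(theta, s):
    """p_theta(a | s) for a tabular softmax policy; theta[s] holds one logit per action."""
    top = max(theta[s])
    exps = [math.exp(v - top) for v in theta[s]]
    z = sum(exps)
    return [e / z for e in exps]


def collect_log(q, m, rng):
    """Run logging policy q for m rounds; record (s, a, r, q(a|s))."""
    log = []
    for _ in range(m):
        s = rng.choices(CONTEXTS, weights=P_CONTEXT)[0]
        probs = q(s)
        a = rng.choices(ACTIONS, weights=probs)[0]
        r = 1.0 if rng.random() < MEAN_REWARD[s][a] else 0.0
        log.append((s, a, r, probs[a]))
    return log


def j_is(theta, log):
    """V1: importance-sampling estimate of J(theta) from the logged sample.
    The weight uses p_theta(a|s) / q(a|s); the shared factor p(s) cancels."""
    total = 0.0
    for s, a, r, q_a in log:
        total += r * softmax_policy(theta, s)[a] / q_a
    return total / len(log)


def j_true(theta):
    """Exact J(theta); only computable because this toy environment is known."""
    return sum(
        P_CONTEXT[s] * sum(p * MEAN_REWARD[s][a]
                           for a, p in enumerate(softmax_policy(theta, s)))
        for s in CONTEXTS)


rng = random.Random(0)
log = collect_log(lambda s: [1 / 3, 1 / 3, 1 / 3], 20000, rng)  # uniform q
theta = [[0.0, 0.0, 2.0], [2.0, 0.0, 0.0]]  # chosen a priori, before seeing the log
print(j_is(theta, log), j_true(theta))
```

The two printed numbers should agree up to sampling noise; note that \(\theta\) was fixed before the log was drawn, matching the caveat above.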

After we've collected a (large) sample, it's possible to optimize \(\hat{J}_{\!\text{IS}}\) using any optimization algorithm (e.g., L-BFGS). Of course, we risk overfitting to the sample if we evaluate the optimized \(\theta\) with \(\hat{J}_{\!\text{IS}}\) on that same sample. Actually, it's a bit worse: this objective tends to favor regions of \(\theta\) which are not well represented in the sample, because the importance sampling estimator has high variance in these regions due to large importance weights (when \(q(x)\) is small and \(p_{\theta}(x)\) is large, \(w(x)\) is large and consequently so is \(\hat{J}_{\!\text{IS}}\), regardless of whether \(r(x)\) is high!). Thus, we want some type of "regularization" to keep the optimizer in regions which are sufficiently well represented by the sample.
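The following hand-built toy illustrates the problem (hypothetical numbers; a simple grid search stands in for L-BFGS). The logging policy \(q\) tries action 1 only 5% of the time, and the log happens to contain a single success on it. Here \(\hat{J}_{\!\text{IS}}\) is increasing in \(p_\theta(a{=}1)\), so maximizing it drives \(\theta\) toward the rarely logged action, and the estimated value ends up above 1, the largest reward that is actually possible.

```python
import math

Q = [0.95, 0.05]  # hypothetical logging policy: action 1 is rarely tried

# Hypothetical log of (action, reward) pairs, rewards in [0, 1].
# Action 1 appears once, and it happened to succeed.
LOG = [(0, 1.0)] * 8 + [(0, 0.0)] + [(1, 1.0)]


def policy(theta):
    """Two-action policy with p_theta(a=1) = sigmoid(theta)."""
    p1 = 1.0 / (1.0 + math.exp(-theta))
    return [1.0 - p1, p1]


def j_is(theta, log):
    """V1 estimate of J(theta) on the fixed log."""
    p = policy(theta)
    return sum(r * p[a] / Q[a] for a, r in log) / len(log)


grid = [i / 2 for i in range(-12, 13)]  # theta in [-6, 6]
best = max(grid, key=lambda t: j_is(t, LOG))
print(best, j_is(best, LOG))  # best is the grid edge 6.0; the estimate is about 2
```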

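Better surrogate [V2]: There are many ways to reduce the variance of the estimator and confidently obtain improvements to the system. One of my favorites is Bottou et al.'s lower bound on \(J(\theta)\), which we get by clipping the importance weights, i.e., replacing \(w^{(j)}_{\theta}\) with \(\min(R, w^{(j)}_{\theta})\). (When \(r \ge 0\), this is a lower bound in expectation: since \(\min(R, w) \le w\), each clipped term \(r^{(j)} \min(R, w^{(j)}_{\theta})\) is at most the unclipped term \(r^{(j)} w^{(j)}_{\theta}\).)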
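A minimal sketch of the V2 clipped estimator, assuming nonnegative rewards and that the target and logging probabilities of each logged \(x^{(j)}\) are available. The tiny log (10 rounds, one of them on a rarely logged action) is hypothetical. Setting \(R = \infty\) recovers V1, and smaller \(R\) caps the influence of the single large weight.

```python
def clipped_is(rewards, p_target, q_logging, R):
    """V2: importance sampling with weights clipped at R.

    rewards[j] = r^(j) >= 0, p_target[j] = p_theta(x^(j)), q_logging[j] = q(x^(j)).
    Since min(R, w) <= w, each term is at most the unclipped V1 term,
    so the estimate's expectation lower-bounds J(theta).
    """
    terms = [r * min(R, p / q) for r, p, q in zip(rewards, p_target, q_logging)]
    return sum(terms) / len(terms)


# Hypothetical log: 10 rounds, q = [0.95, 0.05] over two actions; action 1 logged once.
actions = [0] * 9 + [1]
rewards = [1.0] * 8 + [0.0, 1.0]
q = [0.95, 0.05]
p = [0.01, 0.99]  # target policy that almost always plays the rarely logged action

p_target = [p[a] for a in actions]
q_logging = [q[a] for a in actions]
for R in (float("inf"), 5.0, 2.0):  # R = inf recovers V1
    print(R, clipped_is(rewards, p_target, q_logging, R))
```

With \(R = \infty\) the lone weight of 19.8 pushes the estimate near 2; clipping shrinks it, at the price of bias.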

Confidence intervals [V3]: We can augment the V2 lower bound with confidence intervals derived from the empirical Bernstein bound (EBB). We'll require that \(r\) is bounded and that we know its max/min values. The EBB penalizes hypotheses (values of \(\theta\)) which have higher sample variance. (Note: a Hoeffding bound wouldn't change the shape of the objective, but EBB does thanks to the sample variance penalty. EBB tends to be tighter.) The EBB introduces an additional "confidence" hyperparameter, \((1-\delta)\). Bottou et al. recommend maximizing the lower bound as it provides safer improvements. See the original paper for the derivation.

Both V2 and V3 are biased (as they are lower bounds), but we can mitigate the bias by tuning the hyperparameter \(R\) on a held-out sample (we can even tune \(\delta\), if desired). Additionally, V2 and V3 are 'valid' when \(q\) has limited support since they prevent the importance weights from exploding (of course, the bias can be arbitrarily bad, but this is probably unavoidable given the learning-only-from-logged-data setup).
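For intuition about how a variance penalty enters the objective, here is a sketch of one standard one-sided empirical Bernstein bound, the form in Theorem 4 of Maurer & Pontil (2009), rescaled from \([0,1]\) to a known range \([lo, hi]\). Bottou et al.'s derivation may differ in form or constants, so treat this as an illustration rather than their exact V3 objective. With \(r \in [0, r_{\max}]\), the clipped V2 terms \(r^{(j)} \min(R, w^{(j)}_{\theta})\) lie in \([0, R\, r_{\max}]\), which supplies the boundedness the bound needs.

```python
import math
import statistics


def ebb_lower_bound(samples, lo, hi, delta):
    """One-sided empirical Bernstein lower bound on the mean.

    Assumes i.i.d. samples in [lo, hi] and n >= 2. With probability at least
    1 - delta, the true mean is >= the returned value. Form follows
    Maurer & Pontil (2009, Thm. 4), rescaled from [0, 1] to [lo, hi].
    """
    n = len(samples)
    if n < 2:
        raise ValueError("need at least two samples")
    mean = statistics.mean(samples)
    var = statistics.variance(samples)  # unbiased sample variance (n - 1 denominator)
    log_term = math.log(2.0 / delta)
    variance_penalty = math.sqrt(2.0 * var * log_term / n)
    range_penalty = 7.0 * (hi - lo) * log_term / (3.0 * (n - 1))
    return mean - variance_penalty - range_penalty


# Same sample mean, different spread: the noisier sample gets the lower bound.
steady = [0.5] * 100
noisy = [0.0, 1.0] * 50
print(ebb_lower_bound(steady, 0.0, 1.0, 0.05), ebb_lower_bound(noisy, 0.0, 1.0, 0.05))
```

The range penalty would be identical for every \(\theta\) given fixed \(n\), \(R\), and \(\delta\); only the variance term changes the shape of the objective, which is the point made above about EBB versus Hoeffding.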

Extensions

Unknown environment factors: Consider the contextual bandit setting (mentioned above). Here \(p\) and \(q\) share an unknown environment factor: the distribution of contexts. Luckily, we do not need to know the value of this factor in order to apply any of our estimators because they are all based on likelihood ratios, so the shared unknown factors cancel out! Some specific examples are given below. Of course, these factors do influence the estimators because they are crucial in sampling; they just aren't necessary in evaluation.

  • In the contextual bandit example, \(x\) is a state-action pair, \(w_{\theta}(x) = \frac{p_{\theta}(x)}{q(x)} = \frac{ p_{\theta}(s,a) }{ q(s,a) } = \frac{p_{\theta}(a|s) p(s)}{q(a|s) p(s)} = \frac{p_{\theta}(a|s) }{ q(a|s) }\).

  • In a Markov decision process, \(x\) is a sequence of state-action pairs, \(w_{\theta}(x) = \frac{p_{\theta}(x)}{q(x)} = \frac{ p(s_0) \prod_{t=0}^T p(s_{t+1}|s_t,a_t) p_\theta(a_t|s_t) } { p(s_0) \prod_{t=0}^T p(s_{t+1}|s_t,a_t) q(a_t|s_t) } = \frac{\prod_{t=0}^T p_\theta(a_t|s_t)} {\prod_{t=0}^T q(a_t|s_t)}.\)
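The MDP case as a minimal sketch: the weight of a logged trajectory needs only the two policies' action probabilities, never the initial-state or transition probabilities. The two-state policies below are hypothetical. Keep in mind that a product of many per-step ratios can become very large or very small as \(T\) grows, which is one more reason the variance concerns above matter.

```python
def trajectory_weight(trajectory, target_prob, logging_prob):
    """Importance weight w_theta(x) for one logged trajectory x = [(s_0, a_0), ..., (s_T, a_T)].

    target_prob(a, s) = p_theta(a|s), logging_prob(a, s) = q(a|s).
    The initial-state and transition factors appear in both p_theta(x) and q(x)
    and cancel, so only the policy probabilities are needed.
    """
    w = 1.0
    for s, a in trajectory:
        w *= target_prob(a, s) / logging_prob(a, s)
    return w


# Hypothetical policies over two states/actions.
def target(a, s):
    return 0.9 if a == s else 0.1  # mostly picks the action matching the state


def uniform(a, s):
    return 0.5  # logging policy: uniformly random


print(trajectory_weight([(0, 0), (1, 1), (1, 0)], target, uniform))  # 1.8 * 1.8 * 0.2
```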

Variance reduction: These estimators can all be improved with variance reduction techniques. Probably the most effective technique is using control variates (of which baseline functions are a special case). These are random variables correlated with \(r(x)\) whose expectations we know (or can at least estimate separately). A great example is how ad clicks depend strongly on time-of-day (fewer people are online late at night, so we get fewer clicks); thus the time-of-day covariate explains a large part of the variation in \(r(x)\).
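A minimal sketch of one simple control variate for the contextual bandit: a baseline \(b(s)\) that depends only on the context (think time-of-day). The estimator \(\frac{1}{m}\sum_j \big[(r^{(j)} - b(s^{(j)}))\, w^{(j)}_{\theta} + b(s^{(j)})\big]\) stays unbiased because \(b\) ignores the action, so \(\mathbb{E}_{p_\theta}[b(s)] = \mathbb{E}_{p(s)}[b(s)]\), and averaging \(b\) over the logged contexts estimates the latter without bias. The log format and the baseline table are hypothetical; in practice \(b\) would be fit on separate data so it doesn't depend on the sample.

```python
def baseline_is(log, target_prob, baseline):
    """IS estimate of J(theta) with a context-only baseline as a control variate.

    log: list of (s, a, r, q_a) tuples with q_a = q(a|s).
    target_prob(a, s): p_theta(a|s).
    baseline(s): any fixed function of the context alone (hypothetical; fit it on
    separate data so it does not depend on this sample).
    """
    total = 0.0
    for s, a, r, q_a in log:
        w = target_prob(a, s) / q_a
        b = baseline(s)
        total += (r - b) * w + b
    return total / len(log)


# Hypothetical log where the reward depends only on the context (e.g., time of day).
log = [(0, 0, 0.2, 0.5), (0, 1, 0.2, 0.5), (1, 0, 0.9, 0.5)]
CONTEXT_MEAN = {0: 0.2, 1: 0.9}


def always_action_0(a, s):
    return 1.0 if a == 0 else 0.0


plain = baseline_is(log, always_action_0, lambda s: 0.0)  # b = 0 recovers V1
with_b = baseline_is(log, always_action_0, CONTEXT_MEAN.get)
print(plain, with_b)
```

In this extreme case the baseline explains all of the reward, so every importance weight is multiplied by zero and the action randomness contributes no noise; V1, by contrast, is thrown off by the weights of 0 and 2.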

Estimation instead of optimization: You can use this general setup for estimation instead of optimization, in which case it's fine to let \(r\) have real-valued multivariate output. The confidence intervals are probably useful in that setting too.

Unknown \(q\): Often \(q\) is an existing complex system, which does not record its probabilities. It is possible to use regression to estimate \(q\) from the samples; the estimate is called the propensity score (PS). PS attempts to account for confounding variables, which are hidden causes that control variation in the data. Failing to account for confounding variables may lead to incorrect conclusions. Unfortunately, PS results in a biased estimator because we're using a 'ratio of expectations' (we'll divide by the PS estimate) instead of an 'expectation of ratios'. PS is only statistically consistent in the (unlikely) event that the density estimate is correctly specified (i.e., we can eventually get \(q\) correct). In the unknown \(q\) setting, it's often better to use the doubly-robust estimator (DR), which combines two estimators: a density estimator for \(q\) and a function approximator for \(r\). A great explanation for the bandit case is in Dudík et al. (2011). The DR estimator is also biased, but it has a better bias-variance tradeoff than PS.

What if \(q\) doesn't have support everywhere? This is an especially important setting because it is often the case that data collection policies abide by some safety regulations, which prevent known bad configurations. In many situations, evaluating \(r(x)\) corresponds to executing an action \(x\) in the real world, so terrible outcomes could occur, such as breaking a system, giving a patient a bad treatment, or losing money. V1 is ok to use as long as we satisfy the importance sampling support conditions, which might mean rejecting certain values of \(\theta\) (which might be nontrivial to enforce) and consequently finding a less-optimal policy. V2 and V3 are ok to use without an explicit constraint, but additional care may be needed to ensure that the learned policy satisfies specific safety constraints.

What if \(q\) is deterministic? This is related to the point above. This is a hard problem. Essentially, this is trying to learn without any exploration / experimentation! In general, we need exploration to learn. Randomization isn't the only way to perform exploration; there are many systematic types of experimentation.
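A minimal sketch of the DR estimator for the bandit case, written for a stochastic target policy: a direct-method term \(\sum_{a'} p_\theta(a'|s)\,\hat{r}(s,a')\) plus an importance-weighted correction \((r - \hat{r}(s,a))\,p_\theta(a|s)/\hat{q}(a|s)\). (Dudík et al. (2011) present it for deterministic policies; this is the usual stochastic-policy version.) The propensity model is an add-one-smoothed frequency count, an assumption that keeps \(\hat{q} > 0\), and the reward model is passed in. Both are hypothetical stand-ins for whatever regression you'd actually use, and in practice they'd be fit on data separate from the evaluation sample.

```python
from collections import Counter


def fit_propensity(log, actions):
    """Estimate q(a|s) from logged (s, a, r) triples by add-one-smoothed counts.

    A hypothetical stand-in for propensity regression; smoothing keeps q_hat > 0.
    """
    ctx_counts = Counter(s for s, _, _ in log)
    pair_counts = Counter((s, a) for s, a, _ in log)

    def q_hat(a, s):
        return (pair_counts[(s, a)] + 1) / (ctx_counts[s] + len(actions))

    return q_hat


def doubly_robust(log, target_prob, actions, q_hat, r_hat):
    """DR estimate of J(theta) for a contextual bandit.

    target_prob(a, s) = p_theta(a|s); q_hat(a, s) estimates q(a|s);
    r_hat(s, a) is a reward model. If r_hat equals E[r|s,a], the correction has
    zero mean; if q_hat is exact, the correction removes r_hat's bias in expectation.
    """
    total = 0.0
    for s, a, r in log:
        direct = sum(target_prob(b, s) * r_hat(s, b) for b in actions)
        correction = (r - r_hat(s, a)) * target_prob(a, s) / q_hat(a, s)
        total += direct + correction
    return total / len(log)


# Hypothetical logged (context, action, reward) triples.
log = [(0, 0, 1.0), (0, 1, 0.0), (1, 1, 1.0), (1, 1, 1.0)]
actions = [0, 1]
q_hat = fit_propensity(log, actions)


def always_action_1(a, s):
    return 1.0 if a == 1 else 0.0


def zero_model(s, a):
    return 0.0  # with r_hat = 0, DR reduces to IS with estimated propensities


print(doubly_robust(log, always_action_1, actions, q_hat, zero_model))
```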

  • There are some cases of systematic experimentation that are ok. For example, enumerating all elements of \(\mathcal{X}\) (almost certainly infeasible). Another example is a contextual bandit where \(q\) assigns actions to contexts deterministically via a hash function (this setting is fine because \(q\) is essentially a uniform distribution over actions, which is independent of the state). In other special cases, we may be able to characterize systematic exploration as stratified sampling.

  • A generic solution might be to apply the doubly-robust estimator, which "smooths out" deterministic components (by pretending they are random) and accounts for confounds (by explicitly modeling them in the propensity score).

What if we control data collection (\(q\))? This is an interesting setting. Essentially, it asks "how do we explore/experiment optimally (and safely)?". In general, this is an open question and depends on many considerations, such as how much control we have, exploration cost (safety constraints), and prior knowledge (of \(r\) and unknown factors in the environment). I've seen some papers cleverly design \(q\). The first that comes to mind is Levine & Koltun (2013). Another setting is online contextual bandits, in which algorithms like EXP4 and Thompson sampling prescribe certain types of exploration and work interactively (i.e., they don't have a fixed training sample). Lastly, I'll mention that there are many techniques for variance reduction by importance sampling, which may apply.

Further reading

Léon Bottou, Jonas Peters, Joaquin Quiñonero-Candela, Denis X. Charles, D. Max Chickering, Elon Portugaly, Dipankar Ray, Patrice Simard, Ed Snelson. Counterfactual reasoning in learning systems. JMLR 2013.

The source for the majority of this post. It includes many other interesting ideas and goes more in depth into some of the details.

Miroslav Dudík, John Langford, Lihong Li. Doubly robust policy evaluation and learning. ICML 2011.

Discussed in extensions section.

Philip S. Thomas. Safe reinforcement learning. PhD Thesis 2015.

Covers confidence intervals for policy evaluation similar to Bottou et al., as well as learning algorithms for RL with safety guarantees (e.g., so we don't break the robot).

Leonid Peshkin, Christian R. Shelton. Learning from scarce experience. ICML 2002.

An older RL paper on learning from logged data; one of the earliest such papers I could find.

Sergey Levine, Vladlen Koltun. Guided policy search. ICML 2013.

Discusses clever choices for \(q\) to better guide learning in the RL setting.

Corinna Cortes, Yishay Mansour, Mehryar Mohri. Learning bounds for importance weighting. NIPS 2010.

Discusses generalization bounds for the counterfactual objective. Includes an alternative weighting scheme to keep importance weights from exploding.

The optimal proposal distribution is not p

The following is a quick rant about importance sampling (see my importance sampling post for notation).

I've heard the following incorrect statement one too many times,

We chose \(q \approx p\) because \(q=p\) is the "optimal" proposal distribution.

While it is certainly a good idea to pick \(q\) to be as similar as possible to \(p\), it is by no means optimal because it is oblivious to \(f\)!

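With importance sampling, it is possible to achieve a variance reduction over plain Monte Carlo estimation (i.e., sampling directly from \(p\)). The optimal proposal distribution, assuming \(f(x) \ge 0\) for all \(x\), is \(q(x) \propto p(x) f(x).\) This choice of \(q\) gives us a zero-variance estimate with a single sample, because normalizing gives \(q(x) = p(x) f(x) / \mu\), so every sample has \(\frac{p(x)}{q(x)} f(x) = \mu\)!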

Of course, this is an unreasonable distribution to use because the normalizing constant is the thing you are trying to estimate, but it is proof that better proposal distributions exist.
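A quick numerical check of the zero-variance claim on a hypothetical four-outcome example. Building \(q^*(x) = p(x) f(x) / \mu\) makes \(f^*(x) = \frac{p(x)}{q^*(x)} f(x) = \mu\) for every outcome, so any single draw recovers \(\mu\); note that constructing \(q^*\) required knowing \(\mu\) in the first place.

```python
import random

# Hypothetical target distribution p over four outcomes and a positive f.
p = [0.1, 0.2, 0.3, 0.4]
f = [5.0, 1.0, 3.0, 2.0]

mu = sum(px * fx for px, fx in zip(p, f))        # the quantity we want (2.4 here)
q_opt = [px * fx / mu for px, fx in zip(p, f)]   # q*(x) proportional to p(x) f(x), normalized by mu

# Every outcome has the same importance-weighted value f*(x) = p(x) f(x) / q*(x) = mu.
f_star = [p[x] * f[x] / q_opt[x] for x in range(len(p))]

# So a single sample from q* gives the exact answer (up to floating-point rounding).
rng = random.Random(0)
x = rng.choices(range(len(p)), weights=q_opt)[0]
print(mu, f_star, p[x] * f[x] / q_opt[x])
```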

The key to doing better than \(q=p\) is to take \(f\) into account. Look up "importance sampling for variance reduction" to learn more.

Importance Sampling

Importance sampling is a powerful and pervasive technique in statistics, machine learning and randomized algorithms.

Basics

Importance sampling is a technique for estimating the expectation \(\mu\) of a random variable \(f(x)\) under distribution \(p\) from samples of a different distribution \(q.\)

The key observation is that \(\mu\) can be expressed as the expectation of a different random variable \(f^*(x)=\frac{p(x)}{q(x)}\! \cdot\! f(x)\) under \(q.\)

$$ \mathbb{E}_{q}\! \left[ f^*(x) \right] = \mathbb{E}_{q}\! \left[ \frac{p(x)}{q(x)} f(x) \right] = \sum_{x} q(x) \frac{p(x)}{q(x)} f(x) = \sum_{x} p(x) f(x) = \mathbb{E}_{p}\! \left[ f(x) \right] = \mu $$


Technical condition: \(q\) must have support everywhere \(p\) does, \(p(x) > 0 \Rightarrow q(x) > 0.\) Without this condition, the estimator is biased! Note: \(q\) can support things that \(p\) doesn't.

Terminology: The quantity \(w(x) = \frac{p(x)}{q(x)}\) is often referred to as the "importance weight" or "importance correction". We often refer to \(p\) as the target density and to \(q\) as the proposal density.
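A small exact check (hypothetical numbers) of what goes wrong without the support condition: if \(q\) never proposes an outcome that \(p\) can produce, the expectation of \(f^*\) under \(q\) silently drops that outcome's contribution \(p(x) f(x)\). Exact fractions keep the arithmetic free of rounding.

```python
from fractions import Fraction

p = [Fraction(1, 3)] * 3                            # target: uniform over {0, 1, 2}
q = [Fraction(1, 2), Fraction(1, 2), Fraction(0)]   # proposal never produces x = 2
f = [1, 2, 3]

mu = sum(p[x] * f[x] for x in range(3))             # E_p[f] = 2

# E_q[f*] only sums over outcomes q can produce, so p(2) f(2) = 1 is lost.
is_expectation = sum(q[x] * (p[x] / q[x]) * f[x] for x in range(3) if q[x] > 0)

print(mu, is_expectation, mu - is_expectation)      # 2 1 1
```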

Now, given samples \(\{ x^{(i)} \}_{i=1}^{n}\) from \(q,\) we can use the Monte Carlo estimate, \(\hat{\mu} = \frac{1}{n} \sum_{i=1}^n f^{*}(x^{(i)}),\) as an unbiased estimator of \(\mu.\)
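A minimal sketch of this Monte Carlo estimator. The target, proposal, and \(f\) are hypothetical: on \(\{0, \dots, 9\}\), \(p(x) = (x+1)/55\), \(q\) is uniform, and \(f(x) = x\), so the exact answer is \(\sum_x x(x+1)/55 = 330/55 = 6\).

```python
import random


def importance_sampling_estimate(f, p, q, sample_q, n, rng):
    """Monte Carlo estimate of E_p[f(x)] from n draws of the proposal q.

    f, p, q: functions of x (p and q are normalized probabilities);
    sample_q(rng) draws one x from q.
    """
    total = 0.0
    for _ in range(n):
        x = sample_q(rng)
        total += p(x) / q(x) * f(x)   # f*(x) = w(x) f(x)
    return total / n


def p(x):
    return (x + 1) / 55   # hypothetical target on {0, ..., 9}


def q(x):
    return 1 / 10         # uniform proposal


def f(x):
    return x


def sample_q(rng):
    return rng.randrange(10)


rng = random.Random(0)
estimate = importance_sampling_estimate(f, p, q, sample_q, 100_000, rng)
print(estimate)  # close to 6
```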

Remarks

There are a few reasons we might want to use importance sampling:

  1. Convenience: It might be trickier to sample directly from \(p.\)

  2. Bias-correction: Suppose we're developing an algorithm which requires samples to satisfy some "safety" condition (e.g., a minimum support threshold) and be unbiased. Importance sampling can be used to remove bias, while satisfying the condition.

  3. Variance reduction: It might be the case that sampling directly from \(p\) would require more samples to estimate \(\mu.\) Check out these great notes for more.

  4. Off-policy evaluation and learning: We might want to collect some "exploratory data" from \(q\) and evaluate different "policies", \(p\) (e.g., to pick the best one). Here's a link to a future post on off-policy evaluation and counterfactual reasoning and some cool papers: counterfactual reasoning, reinforcement learning, contextual bandits, domain adaptation.

There are a few common cases for \(q\) worth separate consideration:

  1. Control over \(q\): This is the case in experimental design, variance reduction, active learning and reinforcement learning. It's often difficult to design a \(q\) that results in an estimator with "reasonable" variance. A very difficult case is off-policy evaluation because it (essentially) requires a good exploratory distribution for every possible policy. (I have much more to say on this topic.)

  2. Little to no control over \(q\): For example, you're given some dataset (e.g., news articles) and you want to estimate performance on a different dataset (e.g., Twitter).

  3. Unknown \(q\): In this case, we want to estimate \(q\) (typically referred to as the propensity score) and use it in the importance sampling estimator. This technique, as far as I can tell, is widely used to remove selection bias when estimating effects of different treatments.

Drawbacks: The main drawback of importance sampling is variance. A few bad samples with large weights can drastically throw off the estimator. Thus, it's often the case that a biased estimator is preferred, e.g., estimating the partition function, clipping weights, indirect importance sampling. A secondary drawback is that both densities must be normalized, which is often intractable.

What's next? I plan to cover "variance reduction" and off-policy evaluation in more detail in future posts.