Counterfactual reasoning and learning from logged data

Counterfactual reasoning is reasoning about data that we did not observe. For example, reasoning about the expected reward of a new policy given data collected from an older one.

In this post, I'll discuss some basic techniques for learning from logged data. For the most part, this post is based on things I learned from Peshkin & Shelton (2002) and Bottou et al. (2013) (two of my all-time favorite papers).

After reading, have a look at the Jupyter notebook accompanying this post!

Setup (off-line off-policy optimization): We're trying to optimize a function of the form,

$$J(\theta) = \underset{p_\theta}{\mathbb{E}} \left[ r(x) \right] = \sum_{x \in \mathcal{X}} p_\theta(x) r(x).$$

But! We only have a fixed sample of size $m$ from a data collection policy $q$, $\{ (r^{(j)}, x^{(j)} ) \}_{j=1}^m \overset{\text{i.i.d.}} \sim q.$

• Although it's not necessarily the case, you can think of $q$ as $p_{\theta'}$ for some fixed value $\theta'.$

• $\mathcal{X}$ is an arbitrary multivariate space, which permits a mix of continuous and discrete components, with appropriate densities $p_{\theta}$ and $q$ defined over it.

• $r: \mathcal{X} \to \mathbb{R}$ is a black box that outputs a scalar score.

• I've used the notation $r^{(j)}$ instead of $r(x^{(j)})$ to emphasize that we can't evaluate $r$ at $x$ values other than those in the sample.

• We'll assume that $q$ assigns positive probability everywhere, $q(x) > 0$ for all $x \in \mathcal{X}$. This means that the data collection process must be randomized and eventually sample all possible configurations. Later, I discuss relaxing this assumption.

Each distribution is a product of one or more factors of the following types: policy factors (at least one), which directly depend on $\theta$, and environment factors (possibly none), which do not depend directly on $\theta$. Note that environment factors are only accessible via sampling (i.e., we don't know the value they assign to a sample). For example, consider a contextual bandit problem, where $x$ is a context-action pair, $x = (s,a)$. Here $q(x) = q(a|s) p(s)$ and $p_{\theta}(x) = p_{\theta}(a|s) p(s)$. Note that $p_{\theta}$ and $q$ share the environment factor $p(s)$, the distribution over contexts, and differ only in the action-given-context factor. For now, assume that we can evaluate all environment factors; later, I'll discuss how we can cleverly work around this assumption.

The main challenge of this setting is that we don't have controlled experiments to learn from because $q$ is not (completely) in our control. This manifests itself as high variance ("noise") in estimating $J(\theta)$. Consider the contextual bandit setting: we receive a context $s$ and execute a single action; we never get to roll back to that precise context and try an alternative action (to get a paired sample #yolo) because we do not control $p(s)$. This is an important paradigm for many 'real world' problems, e.g., predicting medical treatments or ad selection.

Estimating $J(\theta)$ [V1]: We obtain an unbiased estimator of $J(\theta)$ with importance sampling,

$$J(\theta) \approx \hat{J}_{\!\text{IS}}(\theta) = \frac{1}{m} \sum_{j=1}^m r^{(j)} \!\cdot\! w^{(j)}_{\theta} \quad \text{ where } w^{(j)}_{\theta} = \frac{p_{\theta}(x^{(j)}) }{ q(x^{(j)}) }.$$

This estimator is remarkable: it uses importance sampling as a function approximator! We have an unbiased estimate of $J(\theta)$ for any value of $\theta$ that we like. The catch is that we have to pick $\theta$ a priori, i.e., with no knowledge of the sample.
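Here's a minimal NumPy sketch of V1, assuming a small finite $\mathcal{X}$, a softmax policy $p_\theta(x) \propto \exp(\theta_x)$, and a made-up reward table (all hypothetical choices for illustration). The reward table is used only to simulate the log and to compute the exact $J(\theta)$ for comparison; the estimator itself sees only the logged pairs.

```python
import numpy as np

def softmax(theta):
    # p_theta(x) proportional to exp(theta[x]); subtracting the max avoids overflow.
    z = np.exp(theta - np.max(theta))
    return z / z.sum()

def j_is(theta, xs, rs, q):
    """V1: (1/m) * sum_j r^{(j)} * p_theta(x^{(j)}) / q(x^{(j)}).

    xs: logged configurations x^{(j)} (indices into a finite X), drawn from q
    rs: logged rewards r^{(j)}
    q:  logging distribution over X, with q(x) > 0 everywhere
    """
    w = softmax(theta)[xs] / q[xs]   # importance weights w_theta^{(j)}
    return np.mean(rs * w)

# Hypothetical toy problem: X = {0, ..., 4} with a fixed reward table.
rng = np.random.default_rng(0)
r_table = np.array([0.1, 0.5, 0.9, 0.2, 0.0])
q = np.full(5, 1 / 5)                     # uniform logging policy
m = 100_000
xs = rng.choice(5, size=m, p=q)           # x^{(j)} ~ q, i.i.d.
rs = r_table[xs]                          # the only rewards we ever get to see

theta = np.array([0.0, 1.0, 2.0, 0.0, -1.0])   # picked before looking at the sample
exact = softmax(theta) @ r_table               # J(theta); computable only because X is tiny
estimate = j_is(theta, xs, rs, q)
print(exact, estimate)
```

Evaluating `j_is` at a different `theta` reuses the same log, which is the "function approximator" view described above.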

After we've collected a (large) sample, it's possible to optimize $\hat{J}_{\!\text{IS}}$ using any optimization algorithm (e.g., L-BFGS). Of course, we risk overfitting to the sample if we optimize $\hat{J}_{\!\text{IS}}$ directly. Actually, it's a bit worse: this objective tends to favor regions of $\theta$ that are not well represented in the sample, because the importance sampling estimator has high variance in these regions resulting from large importance weights (when $q(x)$ is small and $p_{\theta}(x)$ is large, $w_{\theta}(x)$ is large and consequently so is $\hat{J}_{\!\text{IS}}$, regardless of whether $r(x)$ is high!). Thus, we want some type of "regularization" to keep the optimizer in regions that are sufficiently well represented by the sample.
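To make the optimization step concrete, here is a sketch that maximizes $\hat{J}_{\!\text{IS}}$ for a softmax policy over a finite $\mathcal{X}$ using plain gradient ascent (a simple stand-in for L-BFGS). As one rough diagnostic of how well the sample represents a given $\theta$, it also reports Kish's effective sample size of the importance weights, $(\sum_j w_j)^2 / \sum_j w_j^2$, which equals $m$ when all weights are equal and shrinks as a few large weights dominate. The logging policy, rewards, step size, and iteration count are all hypothetical.

```python
import numpy as np

def softmax(theta):
    z = np.exp(theta - np.max(theta))
    return z / z.sum()

def j_is_and_grad(theta, xs, rs, q):
    """V1 objective for a softmax policy over a finite X, its gradient in theta, and the weights."""
    p = softmax(theta)
    w = p[xs] / q[xs]                         # importance weights w_theta^{(j)}
    onehot = np.eye(len(theta))[xs]           # row j is the indicator vector of x^{(j)}
    # grad w_j = w_j * (onehot_j - p), because grad log p_theta(x) = onehot(x) - p
    grad = np.mean((rs * w)[:, None] * (onehot - p), axis=0)
    return np.mean(rs * w), grad, w

def effective_sample_size(w):
    # Kish's effective sample size: m for equal weights, near 1 when one weight dominates.
    return w.sum() ** 2 / np.sum(w ** 2)

# Hypothetical toy problem: the logging policy rarely tries configurations 3 and 4.
rng = np.random.default_rng(0)
r_table = np.array([0.2, 0.4, 0.6, 0.5, 0.3])
q = np.array([0.3, 0.3, 0.3, 0.09, 0.01])
m = 2_000
xs = rng.choice(5, size=m, p=q)
rs = r_table[xs]

theta = np.zeros(5)
J0, _, w0 = j_is_and_grad(theta, xs, rs, q)
for _ in range(3_000):                       # plain gradient ascent on the IS objective
    _, g, _ = j_is_and_grad(theta, xs, rs, q)
    theta += 0.1 * g
J1, _, w1 = j_is_and_grad(theta, xs, rs, q)
print(f"J_IS: {J0:.3f} -> {J1:.3f}")
print(f"ESS:  {effective_sample_size(w0):.0f} -> {effective_sample_size(w1):.0f} (m = {m})")
```

Nothing in the objective itself stops the optimizer from moving into a region with a tiny effective sample size; that is the gap the bounds below try to fill.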

Better surrogate [V2]: There are many ways to reduce the variance of the estimator and confidently obtain improvements to the system. One of my favorites is Bottou et al.'s lower bound on $J(\theta)$, which we get by clipping the importance weights: replace $w^{(j)}_{\theta}$ with $\min(R, w^{(j)}_{\theta})$. Since clipping can only shrink each weight, the clipped estimator's expectation is a lower bound on $J(\theta)$ whenever $r \ge 0$.
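A sketch of the clipped estimator (V2) in the same notation: `rs` holds the logged rewards, `w` the importance weights $w^{(j)}_\theta$, and `R` the clipping threshold. The numbers are made up so the effect of clipping is easy to check by hand.

```python
import numpy as np

def j_clipped(rs, w, R):
    """V2: (1/m) * sum_j r_j * min(R, w_j).

    With r >= 0, clipping can only shrink each term, so the expectation is <= J(theta).
    """
    return np.mean(rs * np.minimum(R, w))

rs = np.array([1.0, 0.5, 0.0, 2.0])    # made-up logged rewards
w = np.array([0.5, 3.0, 10.0, 8.0])    # made-up importance weights
print(j_clipped(rs, w, np.inf))        # no clipping, same as V1: 4.5
print(j_clipped(rs, w, 5.0))           # weights 10 and 8 clipped to 5: 3.0
```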

Confidence intervals [V3]: We can augment the V2 lower bound with confidence intervals derived from the empirical Bernstein bound (EBB). We'll require that $r$ is bounded and that we know its max/min values. The EBB penalizes hypotheses (values of $\theta$) that have higher sample variance. (Note: a Hoeffding bound wouldn't change the shape of the objective, but EBB does, thanks to the sample variance penalty. EBB also tends to be tighter.) The EBB introduces an additional "confidence" hyperparameter, $(1-\delta)$. Bottou et al. recommend maximizing the lower bound as it provides safer improvements. See the original paper for the derivation.
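Below is a sketch of a V3-style lower confidence bound. It is not necessarily Bottou et al.'s exact expression (see their paper); it plugs the clipped terms $z_j = r^{(j)} \min(R, w^{(j)}_\theta)$ into the empirical Bernstein bound of Maurer & Pontil (2009), assuming $0 \le r \le r_{\max}$ so that each $z_j$ lies in $[0, R\, r_{\max}]$. The weights in the usage example are synthetic (log-normal with mean 1).

```python
import numpy as np

def ebb_lower_bound(rs, w, R, r_max, delta):
    """Lower confidence bound (holds with probability >= 1 - delta) on E[r * min(R, w)].

    Assumes 0 <= r <= r_max, so each clipped term lies in [0, b] with b = R * r_max.
    Empirical Bernstein bound (Maurer & Pontil, 2009), rescaled from [0, 1] to [0, b]:
        mean(z) - sqrt(2 * var(z) * ln(2/delta) / m) - 7 * b * ln(2/delta) / (3 * (m - 1)),
    where var(z) is the unbiased sample variance.
    """
    z = rs * np.minimum(R, w)
    m = len(z)
    b = R * r_max
    log_term = np.log(2.0 / delta)
    sample_var = np.var(z, ddof=1)
    return z.mean() - np.sqrt(2.0 * sample_var * log_term / m) - 7.0 * b * log_term / (3.0 * (m - 1))

# Hypothetical logged data: rewards in [0, 1] and log-normal weights with mean 1.
rng = np.random.default_rng(0)
m = 5_000
rs = rng.random(m)
w = rng.lognormal(mean=-0.5, sigma=1.0, size=m)
for R in [1.0, 5.0, 20.0]:
    clipped = np.mean(rs * np.minimum(R, w))
    print(R, clipped, ebb_lower_bound(rs, w, R, r_max=1.0, delta=0.05))
```

A larger $R$ means less clipping bias but a wider range $b$ (and usually a larger sample variance), so the bound's penalty grows with it; comparing the bound across a few values of $R$ makes that trade-off visible.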

Both V2 and V3 are biased (as they are lower bounds), but we can mitigate the bias by tuning the hyperparameter $R$ on a held-out sample (we can even tune $\delta$, if desired). Additionally, V2 and V3 are 'valid' when $q$ has limited support since they prevent the importance weights from exploding (of course, the bias can be arbitrarily bad, but this is probably unavoidable when learning only from logged data).

Extensions

Unknown environment factors: Consider the contextual bandit setting (mentioned above). Here $p_{\theta}$ and $q$ share an unknown environment factor: the distribution of contexts. Luckily, we do not need to know the value of this factor in order to apply any of our estimators: they are all based on likelihood ratios, so shared unknown factors cancel out! Some specific examples are given below. Of course, these factors still influence the estimators because they determine which samples we get; they just aren't needed to evaluate the weights.

• In the contextual bandit example, $x$ is a context-action pair, $w_{\theta}(x) = \frac{p_{\theta}(x)}{q(x)} = \frac{ p_{\theta}(s,a) }{ q(s,a) } = \frac{p_{\theta}(a|s) p(s)}{q(a|s) p(s)} = \frac{p_{\theta}(a|s) }{ q(a|s) }$.

• In a Markov decision process, $x$ is a sequence of state-action pairs, $w_{\theta}(x) = \frac{p_{\theta}(x)}{q(x)} = \frac{ p(s_0) \prod_{t=0}^T p(s_{t+1}|s_t,a_t) p_\theta(a_t|s_t) } { p(s_0) \prod_{t=0}^T p(s_{t+1}|s_t,a_t) q(a_t|s_t) } = \frac{\prod_{t=0}^T p_\theta(a_t|s_t)} {\prod_{t=0}^T q(a_t|s_t)}.$
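A quick numerical check of the cancellation in the MDP case, assuming a small tabular MDP whose initial-state distribution, transitions, and both policies are random (hypothetical) arrays. `trajectory_weight` touches only the policy tables; the ratio of full trajectory probabilities, which does use the dynamics, should agree with it. Working in log space avoids underflow for long trajectories.

```python
import numpy as np

def trajectory_weight(states, actions, pi_theta, q):
    """w_theta(x) = prod_t pi_theta(a_t|s_t) / prod_t q(a_t|s_t) for x = (s_0, a_0, ..., a_T, s_{T+1}).

    pi_theta, q: arrays of shape (num_states, num_actions) of action probabilities.
    Only these policy factors are needed; p(s_0) and p(s_{t+1}|s_t, a_t) cancel.
    """
    s = states[:len(actions)]                  # the states in which each action was taken
    return np.exp(np.sum(np.log(pi_theta[s, actions]) - np.log(q[s, actions])))

# Hypothetical tabular MDP, used only to simulate a trajectory and to check the cancellation.
rng = np.random.default_rng(1)
S, A, T = 3, 2, 5
p0 = rng.dirichlet(np.ones(S))                 # p(s_0)
P = rng.dirichlet(np.ones(S), size=(S, A))     # P[s, a, s'] = p(s' | s, a)
q = rng.dirichlet(np.ones(A), size=S)          # logging policy q(a|s)
pi = rng.dirichlet(np.ones(A), size=S)         # target policy p_theta(a|s)

states = [rng.choice(S, p=p0)]
actions = []
for t in range(T + 1):                         # a_t ~ q(.|s_t), then s_{t+1} ~ p(.|s_t, a_t)
    a = rng.choice(A, p=q[states[-1]])
    actions.append(a)
    states.append(rng.choice(S, p=P[states[-1], a]))
states, actions = np.array(states), np.array(actions)

def log_traj_prob(policy):
    # Full log-probability of the trajectory, including the environment factors.
    lp = np.log(p0[states[0]])
    for t, a in enumerate(actions):
        lp += np.log(policy[states[t], a]) + np.log(P[states[t], a, states[t + 1]])
    return lp

w_full = np.exp(log_traj_prob(pi) - log_traj_prob(q))
w_policy_only = trajectory_weight(states, actions, pi, q)
print(w_full, w_policy_only)
```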

Variance reduction: These estimators can all be improved with variance reduction techniques. Probably the most effective technique is using control variates (of which baseline functions are a special case). These are random variables correlated with $r(x)$ whose expectations we know (or can at least estimate separately). A great example is how ad clicks depend strongly on time of day (fewer people are online late at night, so we get fewer clicks); thus the time-of-day covariate explains a large part of the variation in $r(x)$.
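A sketch of a baseline (control variate) for the contextual bandit, assuming a hypothetical context-only baseline $b(s)$ that is known ahead of time (e.g., average clicks by time of day from historical traffic). The estimator $\frac{1}{m}\sum_j \big[(r^{(j)} - b(s^{(j)}))\, w^{(j)}_\theta + b(s^{(j)})\big]$ stays unbiased because $\mathbb{E}_q[b(s)\, w_\theta] = \mathbb{E}_{p(s)}[b(s)]$, and the logged contexts are themselves draws from $p(s)$. The simulated "night vs. day" rewards are made up.

```python
import numpy as np

def j_baseline(rs, w, b):
    """(1/m) * sum_j [(r_j - b_j) * w_j + b_j], where b_j = b(s_j) depends only on the context."""
    return np.mean((rs - b) * w + b)

# Hypothetical simulation: context s = 0 (night) or 1 (day), two actions.
rng = np.random.default_rng(2)
m = 100_000
base = np.array([1.0, 10.0])     # mean reward by context: far more clicks during the day
effect = np.array([0.0, 0.5])    # small extra reward for action 1
q_a = np.array([0.5, 0.5])       # logging policy q(a|s), same in both contexts
p_a = np.array([0.2, 0.8])       # target policy p_theta(a|s), same in both contexts

s = rng.integers(0, 2, size=m)                   # s ~ p(s), uniform here
a = rng.choice(2, size=m, p=q_a)                 # a ~ q(a|s)
rs = base[s] + effect[a] + rng.normal(0.0, 0.1, size=m)
w = p_a[a] / q_a[a]
b = base[s]                                      # baseline b(s), assumed known in advance

exact = base.mean() + p_a @ effect               # J(theta) = 5.5 + 0.4
plain, with_baseline = np.mean(rs * w), j_baseline(rs, w, b)
# Per-sample variance of each estimator's summands. With the baseline, the large
# day/night swings in r are no longer multiplied by the importance weights.
var_plain, var_baseline = np.var(rs * w), np.var((rs - b) * w + b)
print(exact, plain, with_baseline, var_plain, var_baseline)
```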

Estimation instead of optimization: You can use this general setup for estimation instead of optimization, in which case it's fine to let $r$ have real-valued multivariate output. The confidence intervals are probably useful in that setting too.

Unknown $q$: Often $q$ is an existing complex system that does not record its probabilities. It is possible to use regression to estimate $q$ from the samples; this estimate is called the propensity score (PS). PS attempts to account for confounding variables, which are hidden causes that control variation in the data. Failing to account for confounding variables may lead to incorrect conclusions. Unfortunately, PS results in a biased estimator because we're using a 'ratio of expectations' (we divide by the PS estimate) instead of an 'expectation of ratios'. PS is only statistically consistent in the (unlikely) event that the density estimate is correctly specified (i.e., we can eventually get $q$ correct). In the unknown-$q$ setting, it's often better to use the doubly-robust estimator (DR), which combines two estimators: a density estimator for $q$ and a function approximator for $r$. A great explanation for the bandit case is in Dudík et al. (2011). The DR estimator is also biased, but it has a better bias-variance tradeoff than PS.
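A sketch of the PS and DR estimators for a small discrete contextual bandit, where `pi` denotes the target policy $p_\theta(a|s)$. Assumptions (all hypothetical): the propensity model $\hat{q}(a|s)$ is just the empirical action frequency in each context, and the reward model $\hat{r}(s,a)$ is a deliberately crude per-action average that ignores the context. The DR formula follows the bandit form in Dudík et al. (2011): the reward model's prediction for the target policy plus an importance-weighted correction of its residual.

```python
import numpy as np

def estimate_propensities(s, a, S, A):
    """Hypothetical propensity model: empirical frequency of each action within each context."""
    counts = np.zeros((S, A))
    np.add.at(counts, (s, a), 1.0)
    return counts / counts.sum(axis=1, keepdims=True)

def ps_estimate(s, a, rs, pi, q_hat):
    """Importance sampling with the estimated propensities q_hat in place of the unknown q."""
    return np.mean(rs * pi[s, a] / q_hat[s, a])

def dr_estimate(s, a, rs, pi, q_hat, r_hat):
    """Doubly robust: model prediction for pi plus an importance-weighted residual correction."""
    direct = np.sum(pi[s] * r_hat[s], axis=1)              # sum_a' pi(a'|s_j) r_hat(s_j, a')
    correction = (rs - r_hat[s, a]) * pi[s, a] / q_hat[s, a]
    return np.mean(direct + correction)

# Hypothetical bandit with 3 contexts and 3 actions; the true q is used only to simulate the log.
rng = np.random.default_rng(3)
S, A, m = 3, 3, 100_000
p_s = np.array([0.5, 0.3, 0.2])
q = np.array([[0.6, 0.2, 0.2], [0.2, 0.6, 0.2], [0.2, 0.2, 0.6]])
pi = np.array([[0.1, 0.1, 0.8], [0.8, 0.1, 0.1], [0.1, 0.8, 0.1]])
mean_reward = np.array([[0.1, 0.5, 0.9], [0.7, 0.2, 0.4], [0.3, 0.8, 0.6]])

s = rng.choice(S, size=m, p=p_s)
a = np.empty(m, dtype=int)
for k in range(S):                                         # a ~ q(.|s) within each context
    idx = np.flatnonzero(s == k)
    a[idx] = rng.choice(A, size=idx.size, p=q[k])
rs = mean_reward[s, a] + rng.normal(0.0, 0.1, size=m)

q_hat = estimate_propensities(s, a, S, A)
action_means = np.array([rs[a == k].mean() for k in range(A)])
r_hat = np.tile(action_means, (S, 1))                      # crude reward model: ignores the context
exact = np.sum(p_s[:, None] * pi * mean_reward)
ps = ps_estimate(s, a, rs, pi, q_hat)
dr = dr_estimate(s, a, rs, pi, q_hat, r_hat)
print(exact, ps, dr)
```

With count-based propensities in a problem this small, PS and DR land on essentially the same number; their differences show up when $\hat{q}$ and $\hat{r}$ have to generalize across contexts or some context-action pairs are never logged.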

What if $q$ doesn't have support everywhere? This is an especially important setting because data collection policies often abide by safety regulations that prevent known bad configurations. In many situations, evaluating $r(x)$ corresponds to executing an action $x$ in the real world, so terrible outcomes could occur, such as breaking a system, giving a patient a bad treatment, or losing money. V1 is ok to use as long as we satisfy the importance sampling support conditions, which might mean rejecting certain values of $\theta$ (this might be non-trivial to enforce) and consequently finding a suboptimal policy. V2 and V3 are ok to use without an explicit constraint, but additional care may be needed to ensure specific safety constraints are satisfied by the learned policy.

What if $q$ is deterministic? This is related to the point above. This is a hard problem. Essentially, this is trying to learn without any exploration / experimentation! In general, we need exploration to learn. Randomization isn't the only way to perform exploration; there are many systematic types of experimentation.

• There are some cases of systematic experimentation that are ok. For example, enumerating all elements of $\mathcal{X}$ (almost certainly infeasible). Another example is a contextual bandit where $q$ assigns actions to contexts deterministically via a hash function (this setting is fine because $q$ is essentially a uniform distribution over actions, which is independent of the state). In other special cases, we may be able to characterize systematic exploration as stratified sampling.

• A generic solution might be to apply the doubly-robust estimator, which "smooths out" deterministic components (by pretending they are random) and accounts for confounds (by explicitly modeling them in the propensity score).

What if we control data collection ($q$)? This is an interesting setting. Essentially, it asks "how do we explore/experiment optimally (and safely)?". In general, this is an open question that depends on many considerations, such as how much control we have, the cost of exploration (safety constraints), and prior knowledge (of $r$ and of unknown factors in the environment). I've seen some papers cleverly design $q$; the first that comes to mind is Levine & Koltun (2013). Another setting is online contextual bandits, in which algorithms like EXP4 and Thompson sampling prescribe certain types of exploration and work interactively (i.e., they don't have a fixed training sample). Lastly, I'll mention that there are many techniques for variance reduction by importance sampling that may apply.

Léon Bottou, Jonas Peters, Joaquin Quiñonero-Candela, Denis X. Charles, D. Max Chickering, Elon Portugaly, Dipankar Ray, Patrice Simard, Ed Snelson. Counterfactual reasoning in learning systems. JMLR 2013.

The source for the majority of this post. It includes many other interesting ideas and goes more in depth into some of the details.

Miroslav Dudík, John Langford, Lihong Li. Doubly robust policy evaluation and learning. ICML 2011.

Discussed in extensions section.

Philip S. Thomas. Safe reinforcement learning. PhD Thesis 2015.

Covers confidence intervals for policy evaluation similar to Bottou et al., as well as learning algorithms for RL with safety guarantees (e.g., so we don't break the robot).

Leonid Peshkin, Christian R. Shelton. Learning from scarce experience. ICML 2002.

An older RL paper that covers learning from logged data; it is one of the earliest papers on the topic that I could find.

Sergey Levine, Vladlen Koltun. Guided policy search. ICML 2013.

Discusses clever choices for $q$ to better guide learning in the RL setting.

Corinna Cortes, Yishay Mansour, Mehryar Mohri. Learning bounds for importance weighting. NIPS 2010.

Discusses generalization bounds for the counterfactual objective. Includes an alternative weighting scheme to keep importance weights from exploding.

Multiclass logistic regression and conditional random fields are the same thing

A short rant: multiclass logistic regression and conditional random fields (CRF) are the same thing. This comes as a surprise to many people because CRFs tend to be surrounded by additional "stuff."

Multiclass logistic regression is simple. The goal is to predict the correct label $y^*$ from a handful of labels $\mathcal{Y}$ given the observation $x$ based on features $\phi(x,y)$. Training this model typically requires computing the gradient of the log-likelihood $\log p(y^*|x)$:

$$\phi(x,y^*) - \sum_{y \in \mathcal{Y}} p(y|x) \phi(x,y)$$

where

$$p(y|x) = \frac{1}{Z(x)} \exp(\theta^\top \phi(x,y)) \qquad \text{and} \qquad Z(x) = \sum_{y \in \mathcal{Y}} \exp(\theta^\top \phi(x,y))$$

At test-time, we often take the highest-scoring label under the model.

$$\hat{y}(x) = \textbf{argmax}_{y \in \mathcal{Y}} \theta^\top \phi(x,y)$$
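A NumPy sketch of the gradient and the test-time prediction, assuming the features for one example are stacked into a matrix `feats` whose row $y$ is $\phi(x,y)$ (a hypothetical layout). Both the sum in $Z(x)$ and the argmax are computed by enumerating $\mathcal{Y}$, which is fine when there are only a handful of labels. The last few lines check the gradient formula against central finite differences.

```python
import numpy as np

def log_probs(theta, feats):
    """log p(y|x) for every label; feats[y] = phi(x, y), shape (|Y|, d)."""
    scores = feats @ theta                           # theta^T phi(x, y) for each y
    return scores - np.logaddexp.reduce(scores)      # subtract log Z(x)

def grad_log_likelihood(theta, feats, y_star):
    """phi(x, y*) - sum_y p(y|x) phi(x, y), the gradient of log p(y*|x)."""
    p = np.exp(log_probs(theta, feats))
    return feats[y_star] - p @ feats

def predict(theta, feats):
    """argmax_y theta^T phi(x, y), by enumeration."""
    return int(np.argmax(feats @ theta))

# Hypothetical example: 4 labels, 3-dimensional features.
rng = np.random.default_rng(0)
feats = rng.normal(size=(4, 3))
theta = rng.normal(size=3)
g = grad_log_likelihood(theta, feats, y_star=2)

# Central-difference check of the gradient formula.
eps = 1e-6
g_num = np.array([
    (log_probs(theta + eps * e, feats)[2] - log_probs(theta - eps * e, feats)[2]) / (2 * eps)
    for e in np.eye(3)
])
print(np.allclose(g, g_num, atol=1e-6), predict(theta, feats))
```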

A conditional random field is exactly multiclass logistic regression. The only difference is that the sum and argmax are inefficient to compute naively (i.e., by enumeration). This point is often lost when people first learn about CRFs. Some people never make this connection.

Here's some stuff you'll see once we start talking about CRFs:

1. Inference algorithms (e.g., Viterbi decoding, forward-backward, Junction tree)

2. Graphical models (factor graphs, Bayes nets, Markov random fields)

3. Model templates (i.e., repeated feature functions)

In the logistic regression case, we'd never use the term "inference" to describe the "sum" and "max" over a handful of categories. Once we move to a structured label space, this term gets thrown around. (BTW, this isn't "statistical inference," just algorithms to compute sum and max over $\mathcal{Y}$.)

Graphical models establish a notation and structural properties which allow efficient inference -- things like cycles and treewidth.

Model templating is the only essential trick to move from logistic regression to a CRF. Templating "solves" the problem that not all training examples have the same "size" -- the set of outputs $\mathcal{Y}(x)$ now depends on $x$. A model template specifies how to compute the features for an entire output, by looking at interactions between subsets of variables.

$$\phi(x,\boldsymbol{y}) = \sum_{\alpha \in A(x)} \phi_\alpha(x, \boldsymbol{y}_\alpha)$$

where $\alpha$ is a labeled subset of variables, often called a factor, and $\boldsymbol{y}_\alpha$ is the subvector containing the values of the variables in $\alpha$. Basically, each local feature function $\phi_\alpha$ gets to look at a subset of the predicted variables $\boldsymbol{y}$ and the entire input $x$. The ability to look at more of $\boldsymbol{y}$ allows the model to make more coherent predictions.
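To see the "sum and argmax are expensive" point concretely, here is a sketch of a linear-chain CRF, the most common template. It assumes (hypothetically) that $\theta^\top\phi_\alpha$ has already been evaluated into per-position label scores `U[t, k]` (factors that look at $x$ and $y_t$) and transition scores `W[j, k]` (factors that look at $y_{t-1}, y_t$). Brute force enumerates all $K^n$ labelings of a length-$n$ input with $K$ labels per position, exactly as in multiclass logistic regression; the forward algorithm and Viterbi decoding compute the same log-sum and max in $O(n K^2)$ time.

```python
import itertools
import numpy as np

def sequence_score(U, W, y):
    """theta^T phi(x, y) for a linear-chain template.

    U[t, k]: score of label k at position t (the factors that look at x and y_t)
    W[j, k]: score of label j followed by label k (the factors that look at y_{t-1}, y_t)
    """
    emit = U[np.arange(len(y)), y].sum()
    trans = sum(W[y[t - 1], y[t]] for t in range(1, len(y)))
    return emit + trans

def log_z_brute(U, W):
    # The multiclass-logistic-regression sum, by enumerating all K^n labelings.
    n, K = U.shape
    scores = [sequence_score(U, W, np.array(y)) for y in itertools.product(range(K), repeat=n)]
    return np.logaddexp.reduce(scores)

def log_z_forward(U, W):
    # Forward algorithm: the same log-sum in O(n K^2) by exploiting the chain structure.
    alpha = U[0].copy()                      # alpha[k]: log-sum over prefixes ending in label k
    for t in range(1, U.shape[0]):
        alpha = U[t] + np.logaddexp.reduce(alpha[:, None] + W, axis=0)
    return np.logaddexp.reduce(alpha)

def viterbi(U, W):
    # The same dynamic program with max in place of log-sum, plus back-pointers.
    n, K = U.shape
    delta = U[0].copy()
    back = np.zeros((n, K), dtype=int)
    for t in range(1, n):
        cand = delta[:, None] + W           # cand[j, k]: best prefix ending in j, then label k
        back[t] = np.argmax(cand, axis=0)
        delta = U[t] + np.max(cand, axis=0)
    y = [int(np.argmax(delta))]
    for t in range(n - 1, 0, -1):
        y.append(int(back[t, y[-1]]))
    return y[::-1]

# Hypothetical scores for a length-5 input with 3 labels per position.
rng = np.random.default_rng(0)
U = rng.normal(size=(5, 3))
W = rng.normal(size=(3, 3))
print(log_z_brute(U, W), log_z_forward(U, W))
best = max(itertools.product(range(3), repeat=5), key=lambda y: sequence_score(U, W, np.array(y)))
print(list(best), viterbi(U, W))
```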

Anywho, it's often useful to take a step back and think about what you are trying to compute instead of how you're computing it. In this post, this allowed us to see the similarity between logistic regression and CRFs even though they seem quite different.

Conditional random fields as deep learning models?

This post is intended to convince conditional random field (CRF) lovers that deep learning might not be as crazy as it seems. And maybe even convince some deep learning lovers that graphical models might have interesting things to offer.

In the world of structured prediction, we are plagued by the high-treewidth problem -- models with loopy factors are "bad" because exact inference is intractable. There are three common approaches for dealing with this problem:

1. Limit the expressiveness of the model (i.e., don't use the model you want)

2. Change the training objective

3. Approximate inference

Approximate inference is tricky. Things can easily go awry.

For example, structured perceptron training with loopy max-product BP instead of exact max-product can diverge (Kulesza & Pereira, 2007). Another example: using approximate marginals from sum-product loopy BP in place of the true marginals in the gradient of the log-likelihood. This results in a different nonconvex objective function. (Note: sometimes these loopy BP approximations work fine.)

It looks like using approximate inference during training changes the training objective.

So, here's a simple idea: learn a model that makes accurate predictions given the approximate inference algorithm that will be used at test time. Furthermore, we should minimize empirical risk instead of log-likelihood because it is robust to model misspecification and approximate inference. In other words, make training conditions as close as possible to test-time conditions.

Now, as long as everything is differentiable, you can apply automatic differentiation (backprop) to train the end-to-end system. This idea appears in a few publications, including a handful of papers by Justin Domke, and a few by Stoyanov & Eisner.

Unsurprisingly, it works pretty well.

I first saw this idea in Stoyanov & Eisner (2011). They use loopy belief propagation as their approximate inference algorithm. At the end of the day, their model is essentially a deep recurrent network, which came from unrolling inference in a graphical model. This idea really struck me because it's clearly right in the middle between graphical models and deep learning.

You can immediately imagine swapping in other approximate inference algorithms in place of loopy BP.
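Here's a sketch of the "unrolled inference" view, assuming a pairwise MRF over small discrete variables (the potentials, graph, and iteration count are hypothetical). Loopy sum-product BP runs for a fixed number of synchronous iterations, so each iteration is one layer of a recurrent network whose output is the approximate marginals. Training end-to-end, as described above, would additionally require backpropagating a loss on these marginals through the loop with an automatic differentiation library; this sketch only does the forward pass and compares it with exact marginals from brute-force enumeration.

```python
import itertools
import numpy as np

def unrolled_loopy_bp(unary, pairwise, edges, n_iters):
    """Sum-product loopy BP, unrolled for a fixed number of synchronous iterations.

    unary[i]: positive potential vector for variable i, shape (K,)
    pairwise[(i, j)]: positive potential matrix psi[x_i, x_j] for each edge (i, j) in `edges`
    Returns approximate marginals (beliefs), shape (num_vars, K).
    """
    K = len(unary[0])
    directed = list(edges) + [(j, i) for (i, j) in edges]
    msgs = {e: np.ones(K) / K for e in directed}
    for _ in range(n_iters):                 # each pass is one "layer"
        new_msgs = {}
        for (i, j) in directed:
            psi = pairwise[(i, j)] if (i, j) in pairwise else pairwise[(j, i)].T
            # Unary at i times all messages into i except the one coming from j.
            prod = unary[i].copy()
            for (k, l) in directed:
                if l == i and k != j:
                    prod = prod * msgs[(k, l)]
            msg = prod @ psi                  # sum over x_i of prod[x_i] * psi[x_i, x_j]
            new_msgs[(i, j)] = msg / msg.sum()
        msgs = new_msgs
    beliefs = []
    for i in range(len(unary)):
        b = unary[i].copy()
        for (k, l) in directed:
            if l == i:
                b = b * msgs[(k, l)]
        beliefs.append(b / b.sum())
    return np.array(beliefs)

def exact_marginals(unary, pairwise):
    """Exact marginals by enumerating every joint configuration (exponential cost)."""
    n, K = len(unary), len(unary[0])
    marg = np.zeros((n, K))
    for x in itertools.product(range(K), repeat=n):
        p = np.prod([unary[i][x[i]] for i in range(n)])
        p *= np.prod([psi[x[i], x[j]] for (i, j), psi in pairwise.items()])
        for i in range(n):
            marg[i, x[i]] += p
    return marg / marg.sum(axis=1, keepdims=True)

# Hypothetical 3-cycle of binary variables with attractive couplings.
rng = np.random.default_rng(0)
unary = [rng.uniform(0.5, 2.0, size=2) for _ in range(3)]
edges = [(0, 1), (1, 2), (0, 2)]
pairwise = {e: np.array([[2.0, 1.0], [1.0, 2.0]]) for e in edges}
print(unrolled_loopy_bp(unary, pairwise, edges, n_iters=10))
print(exact_marginals(unary, pairwise))
```

On a tree, enough iterations make these beliefs exact; on the cycle above they are only an approximation, which is exactly the setting where training through the approximation matters.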

Deep learning approaches get a bad reputation because there are a lot of "tricks" needed to get nonconvex optimization to work and because model structures are more open-ended: unlike graphical models, deep learning models vary widely in structure. Maybe being more open-minded about model structures is a good thing. We seem to have hit a brick wall with likelihood-based training. At the same time, maybe we can port over some of the good work on approximate inference as deep architectures.