Almost everyone I know says that "backprop is just the chain rule." Although that's basically true, there are some subtle and beautiful things about automatic differentiation techniques (including backprop) that will not be appreciated with this dismissive attitude.
This dismissive attitude leads to a poor understanding. As I have ranted before, people do not understand basic facts about autodiff:
- Evaluating ∇f(x) is provably about as fast as evaluating f(x) (within a small constant factor).
- Code for ∇f(x) can be derived by a rote program transformation, even if the code has intermediate variables and control-flow structures like loops (as long as the control flow is independent of x). You can even do this "automatic" transformation by hand!
Autodiff ≠ what you learned in calculus
Let's try to understand the difference between autodiff and the type of differentiation that you learned in calculus, which is called symbolic differentiation.
I'm going to use an example from Justin Domke's notes:

f(x) = exp(exp(x) + exp(x)²) + sin(exp(x) + exp(x)²)
If we were writing a program (e.g., in Python) to compute f, we'd take advantage of the fact that it has a lot of repeated subexpressions for efficiency.
def f(x):
    a = exp(x)
    b = a**2
    c = a + b
    d = exp(c)
    e = sin(c)
    return d + e
Symbolic differentiation would have to use the "flat" version of this function, with no intermediate variables, so the derivative expression blows up in size ⇒ slow.
Automatic differentiation lets us differentiate a program with intermediate variables.
- The rules for transforming the code for a function into code for the gradient are really minimal (fewer things to memorize!). Additionally, the rules are more general than in the symbolic case because they handle a superset of expressions: programs with intermediate variables.
- Quite beautifully, the program for the gradient has exactly the same structure as the function, which implies that we get the same runtime (up to constant factors).
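To see the contrast concretely, here's a quick sketch using sympy (my illustration, not part of the original example). Differentiating the "flat" expression re-derives shared subexpressions over and over; common-subexpression elimination has to rediscover the intermediate variables that the program gave us for free.

import sympy as sp

x = sp.Symbol('x')
# The "flat" version of f, with all intermediate variables inlined.
flat_f = sp.exp(sp.exp(x) + sp.exp(x)**2) + sp.sin(sp.exp(x) + sp.exp(x)**2)
d_flat = sp.diff(flat_f, x)

# The symbolic derivative repeats subexpressions like exp(x) and
# exp(x) + exp(x)**2 instead of sharing them.
print(d_flat)
print(sp.count_ops(d_flat))   # operation count of the unshared expression

# Common-subexpression elimination effectively rediscovers the intermediate
# variables (a, b, c, ...) of the program.
print(sp.cse(d_flat))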
I won't give the details of how to execute the backpropagation transform to the program. You can get that from Justin Domke's notes and many other good resources. Here's some code that I wrote that accompanies the f(x) example, which has a bunch of comments describing the manual "automatic" differentiation process on f(x).
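To give the flavor, here's a minimal sketch of that transformation applied to f (my condensed illustration, not the linked code). The forward pass runs the program as written; the reverse pass visits the intermediate variables in reverse order, accumulating an adjoint for each. Note that the gradient program has exactly the same structure as f.

from math import exp, sin, cos

def grad_f(x):
    # Forward pass: the original program, unchanged.
    a = exp(x)
    b = a**2
    c = a + b
    d = exp(c)
    e = sin(c)
    f = d + e
    # Reverse pass: one adjoint per variable, visited in reverse order.
    d_d = 1.0                           # f = d + e
    d_e = 1.0
    d_c = d_d * exp(c) + d_e * cos(c)   # d = exp(c); e = sin(c)
    d_a = d_c                           # c = a + b
    d_b = d_c
    d_a += d_b * 2 * a                  # b = a**2
    d_x = d_a * exp(x)                  # a = exp(x)
    return f, d_x

# Sanity check against a finite-difference approximation.
eps = 1e-8
fx, g = grad_f(1.0)
fd = (grad_f(1.0 + eps)[0] - fx) / eps
assert abs(g - fd) / abs(g) < 1e-3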
Autodiff by the method of Lagrange multipliers
Let's view the intermediate variables in our optimization problem as simple equality constraints in an equivalent constrained optimization problem. It turns out that the de facto method for handling constraints, the method of Lagrange multipliers, recovers exactly the adjoints (intermediate derivatives) in the backprop algorithm!
Here's our example from earlier written in this constraint form:

maximize d + e  (over x)
subject to  a = exp(x),  b = a²,  c = a + b,  d = exp(c),  e = sin(c)
The general formulation
maximize zn  (over x)
subject to  zi = xi for i = 1, …, d
            zi = fi(zα(i)) for i = d+1, …, n

The first set of constraints (i = 1, …, d) are a little silly. They are only there to keep our formulation tidy. The variables in the program fall into three categories:
- input variables (x): x1, …, xd
- intermediate variables (z): zi = fi(zα(i)) for 1 ≤ i ≤ n, where α(i) is a list of indices from {1, …, n−1} and zα(i) is the subvector of variables needed to evaluate fi(⋅). Minor detail: take f1:d to be the identity function.
- output variable (zn): We assume that our program has a single scalar output variable, zn, which represents the quantity we'd like to maximize.
The relation α is a dependency graph among variables. Thus, α(i) is the list of incoming edges to node i and β(j)={i:j∈α(i)} is the set of outgoing edges. For now, we'll assume that the dependency graph given by α is ① acyclic: no zi can transitively depend on itself. ② single-assignment: each zi appears on the left-hand side of exactly one equation. We'll discuss relaxing these assumptions in § Generalizations.
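To make α and β concrete, here's a tiny sketch for the example program (the variable numbering is mine):

from collections import defaultdict

# Number the variables of the example program: x=0, a=1, b=2, c=3, d=4, e=5, f=6.
# α(i) lists the arguments of f_i, i.e., the incoming edges of node i.
alpha = {1: [0], 2: [1], 3: [1, 2], 4: [3], 5: [3], 6: [4, 5]}

# β(j) = {i : j ∈ α(i)} lists the outgoing edges of node j.
beta = defaultdict(list)
for i, args in alpha.items():
    for j in args:
        beta[j].append(i)

assert beta[1] == [2, 3]   # a feeds into b and c
assert beta[3] == [4, 5]   # c feeds into d and e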
The standard way to solve a constrained optimization is to use the method of Lagrange multipliers, which converts a constrained optimization problem into an unconstrained problem with a few more variables λ (one per constraint), called Lagrange multipliers.
The Lagrangian
To handle constraints, let's dig up a tool from our calculus class, the method of Lagrange multipliers, which converts a constrained optimization problem into an unconstrained one. The unconstrained version is called "the Lagrangian" of the constrained problem. Here is its form for our task:

L(x, z, λ) = zn + Σi=1…d λi (xi − zi) + Σi=d+1…n λi (fi(zα(i)) − zi)
Optimizing the Lagrangian amounts to solving the following nonlinear system of equations, which give necessary, but not sufficient, conditions for optimality:

∇L(x, z, λ) = 0
Let's look a little closer at the Lagrangian conditions by breaking up the system of equations into salient parts, corresponding to which variable types are affected.
Intermediate variables (z): Optimizing the multipliers—i.e., setting the gradient of the Lagrangian w.r.t. λ to zero—ensures that the constraints on intermediate variables are satisfied:

0 = ∇λi L  ⟺  zi = xi (for 1 ≤ i ≤ d)  and  zi = fi(zα(i)) (for d < i ≤ n)

We can use forward propagation to satisfy these equations, which we may regard as a block-coordinate step in the context of optimizing L.
Lagrange multipliers (λ, excluding λn): Setting the gradient of L w.r.t. the intermediate variables equal to zero tells us what to do with the intermediate multipliers. For 1 ≤ j < n,

0 = ∂L/∂zj = Σi∈β(j) λi ∂fi(zα(i))/∂zj − λj  ⟹  λj = Σi∈β(j) λi ∂fi(zα(i))/∂zj

Clearly, ∂fi(zα(i))/∂zj = 0 for j ∉ α(i), which is why the β(j) notation came in handy. By assumption, the local derivatives, ∂fi(zα(i))/∂zj for j ∈ α(i), are easy to calculate—we don't even need the chain rule to compute them because they are simple function applications without composition. Similar to the equations for z, solving this linear system is another block-coordinate step.
Key observation: The last equation for λj should look very familiar: it is exactly the equation used in backpropagation! It says that λj is the sum of the adjoints λi of the nodes i that immediately depend on j, where each λi is scaled by the derivative of the function that directly relates i and j. You should think of the scaling as a "unit conversion" from derivatives of type i to derivatives of type j.
Output multiplier (λn): Here we follow the same pattern as for the intermediate multipliers:

0 = ∇zn L = 1 − λn  ⟹  λn = 1
Input multipliers (λ1:d): Our dummy constraints give us λ1:d, which are conveniently equal to the gradient of the function we're optimizing:

λ1:d = ∇x f(x)
Of course, this interpretation is only precise when ① the constraints are satisfied (z equations) and ② the linear system on multipliers is satisfied (λ equations).
Input variables (x): Unfortunately, there is no closed-form solution for how to set x. For this we resort to something like gradient ascent. Conveniently, ∇x f(x) = λ1:d, which we can use to optimize x!
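As a sketch of that outer loop (reusing the grad_f from the earlier sketch; tiny steps and few iterations, since this particular f grows very fast):

x, step = 0.0, 1e-4
f_old = grad_f(x)[0]
for _ in range(25):
    fx, dx = grad_f(x)   # dx plays the role of λ_{1:d}
    x += step * dx       # gradient ascent on the objective
assert grad_f(x)[0] > f_old   # the objective improved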
Generalizations
We can think of these equations for λ as a simple linear system of equations, which we are solving by back-substitution when we use the backpropagation method. The reason back-substitution is sufficient for the linear system (i.e., we don't need a full linear-system solver) is that the dependency graph induced by the α relation is acyclic. If we had needed a full linear-system solver, the solution would take O(n³) time instead of linear time, seriously blowing up our nice runtime!
This connection to linear systems is interesting: It tells us that we can compute global gradients in cyclic graphs. All we'd need is to run a linear system solver to stitch together local gradients! That is exactly what the implicit function theorem says!
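Here's a self-contained sketch of the linear-system view for our example program (mirroring what the gist code below does with scipy.linalg.solve_triangular): materialize the matrix of local gradients and solve for all the λ's at once.

import numpy as np
from math import exp, sin, cos
from scipy.linalg import solve_triangular

x0 = 1.0
a = exp(x0); b = a**2; c = a + b   # forward pass (only a and c are needed below)

# D[j, i] = ∂f_i/∂z_j holds the local gradients; the diagonal is -1.
# Variables are numbered x=0, a=1, b=2, c=3, d=4, e=5, f=6.
n = 7
D = np.zeros((n, n))
np.fill_diagonal(D, -1.0)
D[0, 1] = exp(x0)               # a = exp(x)
D[1, 2] = 2 * a                 # b = a**2
D[1, 3] = D[2, 3] = 1.0         # c = a + b
D[3, 4] = exp(c)                # d = exp(c)
D[3, 5] = cos(c)                # e = sin(c)
D[4, 6] = D[5, 6] = 1.0         # f = d + e

# Solve D λ = -v, where v is the gradient of the objective z_n w.r.t. z.
v = np.zeros(n)
v[6] = 1.0
lam = solve_triangular(D, -v)   # back-substitution == backprop

# λ_0 is ∇f(x); compare against the closed-form derivative.
assert np.isclose(lam[0], (exp(c) + cos(c)) * (1 + 2*a) * a)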
Cyclic constraints add some expressive power to our "constraint language," and it's interesting that we can still efficiently compute gradients in this setting. An example of a general type of cyclic constraint is

z = g(x, z)
where g can be any smooth multivariate function of the inputs and intermediate variables! Of course, allowing cyclic constraints comes at the cost of a more difficult analogue of "the forward pass" to satisfy the z equations (if we want to keep it a block-coordinate step). The λ equations are now a linear system that requires a linear solver (e.g., Gaussian elimination).
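Here's a sketch of the cyclic setting on a toy fixed-point constraint (my own example: z = tanh(Wz + x) with W small enough that fixed-point iteration converges). The "forward pass" is now an iterative solver, and the gradient comes from a linear solve, exactly as the implicit function theorem prescribes.

import numpy as np

rng = np.random.RandomState(0)
n = 5
W = 0.1 * rng.randn(n, n)       # small weights keep the map contractive
x = rng.randn(n)

def solve_z(x, iters=200):
    # The cyclic analogue of "the forward pass": fixed-point iteration
    # to satisfy the constraint z = g(x, z) = tanh(W z + x).
    z = np.zeros(n)
    for _ in range(iters):
        z = np.tanh(W @ z + x)
    return z

z = solve_z(x)

# Implicit function theorem: differentiating z = g(x, z) and solving for dz/dx,
#   dz/dx = (I - ∂g/∂z)⁻¹ ∂g/∂x,   with tanh'(u) = 1 - tanh(u)².
s = 1 - z**2
dz_dx = np.linalg.solve(np.eye(n) - s[:, None] * W, np.diag(s))

# Gradient of the output f(z) = sum(z), checked by finite differences.
grad = dz_dx.T @ np.ones(n)
eps = 1e-6
fd = np.array([(solve_z(x + eps * np.eye(n)[i]).sum() - z.sum()) / eps
               for i in range(n)])
assert np.allclose(grad, fd, atol=1e-4)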
Example use cases:
- Bi-level optimization: solving an optimization problem with another one inside it. For example, gradient-based hyperparameter optimization in machine learning. The implicit function theorem manages to get gradients of hyperparameters without needing to store any of the intermediate states of the optimization algorithm used in the inner optimization! This is a huge memory saver since direct backprop on the inner gradient-descent algorithm would require caching all intermediate states. Yikes! (See the sketch after this list.)
- Cyclic constraints are useful in many graph algorithms. For example, computing gradients of edge weights in a general finite-state machine or, similarly, computing the value function in a Markov decision process.
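For the bi-level case, here's a sketch with a toy instance (ridge regression with a validation loss; my example, not from the post). The inner solve happens in closed form, and the hyperparameter gradient comes from the implicit function theorem applied to the inner optimality condition, with no backprop through an inner solver.

import numpy as np

rng = np.random.RandomState(0)
A, b = rng.randn(20, 5), rng.randn(20)      # training data
Av, bv = rng.randn(10, 5), rng.randn(10)    # validation data

def inner(c):
    # argmin_w ||Aw - b||² + c||w||², in closed form (stands in for an
    # iterative inner optimizer whose states we'd otherwise have to cache).
    return np.linalg.solve(A.T @ A + c * np.eye(5), A.T @ b)

def outer(c):
    w = inner(c)
    return ((Av @ w - bv)**2).sum()          # validation loss

c = 0.5
w = inner(c)
# Inner optimality condition: s(c, w) = 2Aᵀ(Aw - b) + 2cw = 0.
# Implicit function theorem: dw/dc = -(∂s/∂w)⁻¹ ∂s/∂c = -(AᵀA + cI)⁻¹ w.
dw_dc = -np.linalg.solve(A.T @ A + c * np.eye(5), w)
dL_dc = dw_dc @ (2 * Av.T @ (Av @ w - bv))   # chain with the validation loss

eps = 1e-6
fd = (outer(c + eps) - outer(c)) / eps
assert abs(dL_dc - fd) / abs(dL_dc) < 1e-3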
Other methods for optimization?
The connection to Lagrangians brings tons of algorithms for constrained optimization into the mix! We can imagine using more general algorithms for optimizing our function and other ways of enforcing the constraints. We see immediately that we could run optimization with adjoints set to values other than those that backprop would set them to (i.e., we can optimize them like we'd do in other algorithms for optimizing general Lagrangians).
Summary
- Backprop does not directly fall out of the rules for differentiation that you learned in calculus (e.g., the chain rule).
  - This is because it operates on a more general family of functions: programs with intermediate variables. Supporting intermediate variables is crucial for implementing both functions and their gradients efficiently.
- I described how we could use something we did learn from calculus 101, the method of Lagrange multipliers, to support optimization with intermediate variables.
  - It turned out that backprop is a particular instantiation of the method of Lagrange multipliers, involving block-coordinate steps for solving for the intermediates and multipliers.
  - I also described a neat generalization to support cyclic programs, and I hinted at ideas for doing optimization a little differently, deviating from the de facto block-coordinate strategy.
Further reading
After working out the connection between backprop and the method of Lagrange multipliers, I discovered the following paper, which beat me to it. I don't think my version is too redundant.
Yann LeCun. (1988) A Theoretical Framework for Back-Propagation.
Ben Recht has a great blog post that uses the implicit function theorem to derive the method of Lagrange multipliers. He also touches on the connection to backpropagation.
Ben Recht. (2016) Mechanics of Lagrangians.
Tom Goldstein's group took the Lagrangian view of backprop and used it to design an ADMM approach for optimizing neural nets. The ADMM approach can run massively in parallel and can leverage highly optimized solvers for subproblems. This work nicely demonstrates that understanding automatic differentiation—in the broader sense that I described in this post—facilitates the development of novel optimization algorithms.
Gavin Taylor, Ryan Burmeister, Zheng Xu, Bharat Singh, Ankit Patel, Tom Goldstein. (2016) Training Neural Networks Without Gradients: A Scalable ADMM Approach.
The backpropagation algorithm can be cleanly generalized from values to functionals!
Alexander Grubb and J. Andrew Bagnell. (2010) Boosted Backpropagation Learning for Training Deep Modular Networks.
Code
I have coded up and tested the Lagrangian perspective on automatic differentiation that I presented in this article. The code is available in this gist.
# -*- coding: utf-8 -*-
"""
Backprop as the method of Lagrange multipliers (and even the implicit function
theorem).
"""
import numpy as np
from arsenal.alphabet import Alphabet
from arsenal.math.checkgrad import finite_difference

# Implementation choice: I decided to separate the input-copy and intermediate
# constraints to avoid annoyances with having two namespaces (x and z). I
# suppose writing all constraints as functions of x and z seems more general,
# but with input-copy constraints we don't lose any expressivity; we just have
# to handle them with special cases (easy enough).


class Computation:

    def __init__(self, f, inputs, constraints, df):
        self.d = len(inputs)
        self.n = self.d + len(constraints)
        self.constraints = constraints
        self.inputs = inputs
        self.f = f
        self.df = df

    def forward(self, x):
        "Evaluate `f(x)` by running forward propagation."
        return self.f(self.solve_constraints(x))

    def solve_constraints(self, x):
        "Find a feasible solution to the constraints given `x`."
        z = np.zeros(self.n)
        z[self.inputs] = x
        for c in self.constraints:
            c.solve(z)
        return z

    def lagrangian(self, x, z, l):
        "Evaluate the Lagrangian at (x, z, λ)."
        return (self.f(z)
                + l[:self.d].dot(x[:self.d] - z[:self.d])
                + l[self.d:].dot(self.constraints.penalties(z)))

    def dlagrangian(self, x, z, l):
        "Compute gradients of the Lagrangian wrt each argument."
        dx = np.zeros_like(x)
        dx += l[:self.d]
        dz = self.df(z) + self.dconstraints(z).dot(l)
        dl = np.zeros_like(l)
        dl[:self.d] += x[self.inputs] - z[self.inputs]
        dl[self.d:] += self.constraints.penalties(z)
        return dx, dz, dl

    def dconstraints(self, z):
        "Evaluate the constraint-Jacobian matrix at `z`."
        # Note: The linear-system approach builds a massive, highly structured,
        # sparse matrix that represents the local gradients. This is really
        # inefficient in most cases because we can aggregate gradients as we
        # go, which avoids the need to materialize this monster matrix.
        D = np.zeros((self.n, self.n))
        D[self.inputs, self.inputs] = -1
        for c in self.constraints:
            c.grad(z, D[:, c.i])
        return D

    def reverse_mode(self, D, v):
        "Solve the upper-triangular linear system `D x = -v` by back-substitution."
        l = v.copy()
        for c in reversed(self.constraints):
            for a in c.args:
                l[a] += l[c.i] * D[a, c.i]
        return l

    def forward_mode(self, D, v):
        "Solve the lower-triangular linear system `Dᵀ x = -v` by forward-substitution."
        l = v.copy()
        for c in self.constraints:
            for a in c.args:
                l[c.i] += l[a] * D[a, c.i]
        return l


class Constraint:

    def __init__(self, i, f, args, df=None):
        self.args = args
        self.i = i
        self.f = f
        self.df = df
        if df is None:
            # Use a finite-difference approximation if the user didn't pass in df.
            self.df = lambda x: finite_difference(f)(x).flatten()

    def solve(self, z):
        # Closed-form solution to the constraint; more generally, we could take
        # a gradient step or solve a block-coordinate subproblem.
        z[self.i] = self.f(z[self.args])

    def penalty(self, z):
        return float(self.f(z[self.args]) - z[self.i])

    def grad(self, z, dz, adj=1):
        # Note: the adjoint is a scalar because the constraint is scalar.
        dz[self.args] += adj * self.df(z[self.args])
        dz[self.i] += -adj


class Constraints(list):
    """Handles string-valued names and certain conventions, e.g., inputs need
    to be the first variables."""

    def __init__(self, inputs):
        self.A = Alphabet()
        self.inputs = self.A.map(inputs)   # inputs need to be the first vars
        super().__init__()

    def add_constraint(self, lhs, f, rhs, df=None):
        self.append(Constraint(self.A[lhs], f, self.A.map(rhs), df))

    def penalties(self, z):
        return np.array([c.penalty(z) for c in self])
# -*- coding: utf-8 -*-
"""
Tests for backprop as the method of Lagrange multipliers (and even the
implicit function theorem). Exercises the `lagrangeprop` module above.
"""
import numpy as np
import scipy.linalg

from lagrangeprop import Computation, Constraints
from arsenal.math.checkgrad import finite_difference, fdcheck
from arsenal.math import onehot, compare
from arsenal import colors


def test_implicit_diff_view(L):
    """
    Test connections between the Lagrangian and implicit differentiation.

    If you have the Lagrangian view of backprop, then implicit functions should
    really pop out!

    Think of forward propagation as a smooth blackbox function h that maps
    inputs (x) to intermediates (z):

        maximize f(z)
        subject to h(x) = z

    Rewriting slightly, let g(x,z) = h(x) - z:

        maximize f(z)
        subject to g(x,z) = 0

    With forward propagation we always satisfy the constraints, so g(x,z) = 0.
    Thus, we also have "equilibrium" under little perturbations,

        g(x+Δx, z+Δz) ≈ g(x,z) + Δx ⋅ ∂g/∂x + Δz ⋅ ∂g/∂z = 0.

    Since g(x,z) = 0,

        Δx ⋅ ∂g/∂x + Δz ⋅ ∂g/∂z = 0.

    Solve for Δz/Δx,

        Δz/Δx = -(∂g/∂z)⁻¹ ∂g/∂x    ← there's your linear system of equations!

    Combine with the objective ∂f/∂z,

        ∂f/∂z ⋅ Δz/Δx = ∂f/∂x.
    """
    print(colors.magenta % 'Implicit differentiation:', end=' ')
    x = np.random.randn(L.d)
    # Important! The connection only holds when the constraints are satisfied!
    z = L.solve_constraints(x)
    f_dz_dx = finite_difference(L.solve_constraints)(x)
    dC_dx = np.zeros((L.n, L.d))
    dC_dx[L.inputs, L.inputs] = 1
    dC_dz = L.dconstraints(z)
    dz_dx = -scipy.linalg.solve(dC_dz.T, dC_dx).T
    assert np.allclose(f_dz_dx, dz_dx)
    df_dz = L.df(z)
    f_df_dx = finite_difference(L.forward)(x)
    assert np.allclose(f_df_dx, dz_dx.dot(df_dz))
    print(colors.green % 'ok')


def test_forward_mode(L):
    print(colors.magenta % 'Forward-mode:', end=' ')
    x = np.random.randn(L.d)
    z = L.solve_constraints(x)
    D = L.dconstraints(z)
    # Compare methods to a finite-difference approximation of ∇f(x).
    f_df_dx = finite_difference(L.forward)(x)
    # In forward mode, λ is interpreted as a vector of "tangents" pertaining to
    # a single input, instead of "adjoints" of the single output. Tangents are
    # equal to rows (cols?) of the Jacobian of the constraints.
    f_dz_dx = finite_difference(L.solve_constraints)(x)
    for i in range(L.d):   # loop over each input
        l = L.forward_mode(D, onehot(i, L.n))
        assert np.allclose(f_dz_dx[i], l)
        # df/dz * dz/dx[i] => df/dx[i]
        gi = L.df(z).dot(l)
        assert np.allclose(f_df_dx[i], gi)
    print(colors.green % 'ok')


def test_dlagrangian(L):
    print(colors.magenta % 'Finite-difference Lagrangian:', end=' ')
    x = np.random.randn(L.d)
    z = np.random.uniform(-1, 1, size=L.n)
    l = np.random.uniform(-1, 1, size=L.n)
    dx, dz, dl = L.dlagrangian(x, z, l)
    assert fdcheck(lambda: L.lagrangian(x, z, l), z, dz, quiet=1).mean_relative_error < 0.01
    assert fdcheck(lambda: L.lagrangian(x, z, l), x, dx, quiet=1).mean_relative_error < 0.01
    assert fdcheck(lambda: L.lagrangian(x, z, l), l, dl, quiet=1).mean_relative_error < 0.01
    print(colors.green % 'ok')


def test_reverse_mode(L):
    print(colors.magenta % 'Reverse-mode:', end=' ')
    x = np.random.randn(L.d)
    # Compare methods to a finite-difference approximation of ∇f(x).
    f_df_dx = finite_difference(L.forward)(x)
    # Run forward to cache all the relevant stuff.
    z = L.solve_constraints(x)
    l = L.reverse_mode(L.dconstraints(z), L.df(z))
    assert np.allclose(f_df_dx, l[:L.d])
    print(colors.green % 'ok')


def test_linear_system(L):
    print(colors.magenta % 'Linear solve:', end=' ')
    x = np.random.randn(L.d)
    f_df_dx = finite_difference(L.forward)(x)
    z = L.solve_constraints(x)
    D = L.dconstraints(z)
    l = L.reverse_mode(D, L.df(z))
    # Run a general-purpose linear solver. Note that `linalg.solve` is
    # generally worse here than `linalg.solve_triangular` (or, equivalently,
    # reverse mode) because the generic solver doesn't realize that the system
    # is upper triangular, so it can't exploit that structure.
    sol = scipy.linalg.solve(D, -L.df(z))
    assert np.allclose(l, sol)
    assert np.allclose(f_df_dx, sol[:L.d])
    # Test the upper-triangular solver.
    sol = scipy.linalg.solve_triangular(D, -L.df(z))
    assert np.allclose(f_df_dx, sol[:L.d])
    assert np.allclose(l, sol)
    print(colors.green % 'ok')


def test_blockcoordinate(L):
    print(colors.magenta % 'Block-coordinate updates for z and λ:', end=' ')
    x = np.random.randn(L.d)
    z = L.solve_constraints(x)
    l = L.reverse_mode(L.dconstraints(z), L.df(z))
    dx, dz, dl = L.dlagrangian(x, z, l)
    assert np.allclose(dx, l[:L.d])
    assert np.abs(dz).max() <= 1e-5
    assert np.allclose(dl, 0)
    print(colors.green % 'ok')


def main():
    C = Constraints(['x', 'y'])
    C.add_constraint('a', np.exp, ['x'], df=np.exp)
    C.add_constraint('b', lambda x: x**2, ['a'], df=lambda x: 2*x)
    C.add_constraint('c', np.sum, ['a', 'b', 'y'], df=np.ones_like)
    # C.add_constraint('c', np.product, ['a','b','y'])
    # C.add_constraint('d', np.exp, ['c'], df=np.exp)
    C.add_constraint('d', np.tanh, ['c'])
    C.add_constraint('e', np.sin, ['c'], df=np.cos)
    C.add_constraint('f', np.sum, ['d', 'e'], df=np.ones_like)

    n = len(C.inputs) + len(C)
    _r = np.random.randn(n)   # random linear function of the intermediate nodes
    f = _r.dot
    df = lambda z: _r.copy()

    L = Computation(f, C.inputs, C, df=df)

    test_dlagrangian(L)
    test_reverse_mode(L)
    test_forward_mode(L)
    test_linear_system(L)
    test_blockcoordinate(L)
    test_implicit_diff_view(L)


if __name__ == '__main__':
    main()