Backpropagation with Callbacks: Foundations for Efficient and Expressive Differentiable Programming
2018
Training of deep learning models depends on gradient descent and end-to-end
differentiation. Under the slogan of differentiable programming, there is an
increasing demand for efficient automatic gradient computation for emerging
network architectures that incorporate dynamic control flow, especially in NLP.
In this paper we propose an implementation of backpropagation using functions
with callbacks, where the forward pass is executed as a sequence of function
calls, and the backward pass as a corresponding sequence of function returns.
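A minimal Python sketch of the idea follows. It assumes continuations are passed as explicit function arguments. The names `Num`, `add`, `mul`, `grad` and `f` are hypothetical and are not the paper's API. Each operator computes its result and then calls the continuation `k` with it, so `k` runs the rest of the forward pass. Only after `k` returns does the operator propagate the adjoint back to its inputs. No tape is recorded: the pending backward steps live in the stack frames of the calls that have not yet returned.

```python
class Num:
    """A forward value paired with its adjoint (gradient accumulator)."""

    def __init__(self, x):
        self.x = x    # value computed in the forward pass
        self.d = 0.0  # adjoint, accumulated in the backward pass


def add(a, b, k):
    y = Num(a.x + b.x)  # forward step
    k(y)                # run the rest of the program (a function call)
    a.d += y.d          # backward step: executes when k returns
    b.d += y.d


def mul(a, b, k):
    y = Num(a.x * b.x)
    k(y)
    a.d += b.x * y.d
    b.d += a.x * y.d


def grad(f, x0):
    """Derivative at x0 of a one-argument function f written in CPS."""
    x = Num(x0)

    def seed(y):
        y.d = 1.0  # end of the forward pass: d(out)/d(out) = 1

    f(x, seed)
    return x.d


def f(x, k):
    # f(x) = x*x + x, written with explicit continuations
    mul(x, x, lambda t: add(t, x, k))


print(grad(f, 3.0))  # 7.0, since f'(x) = 2x + 1
```

In this toy, every primitive keeps its stack frame alive until the whole forward pass has finished. Call depth therefore grows with program length, and long programs will hit Python's default recursion limit.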
A key realization is that this technique of chaining callbacks is well known in the
programming languages community as continuation-passing style (CPS). Any
program can be converted to this form using standard techniques, and hence,
any program can be mechanically converted to compute gradients.
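The standard CPS transformation can be illustrated generically, without reference to the paper itself. Below is a recursive function with a data-dependent branch, written first in direct style and then in CPS. The names `power` and `power_cps` are hypothetical. In the CPS version, the work that remains after the recursive call (the multiplication) moves into a new continuation instead of being left pending on return.

```python
def power(x, n):
    """Direct style: the result is returned to the caller."""
    if n == 0:
        return 1.0
    return x * power(x, n - 1)


def power_cps(x, n, k):
    """CPS: the result is passed to the continuation k instead of returned."""
    if n == 0:
        return k(1.0)
    # The pending multiplication becomes part of the new continuation.
    return power_cps(x, n - 1, lambda r: k(x * r))


print(power(2.0, 10))                   # 1024.0
print(power_cps(2.0, 10, lambda r: r))  # 1024.0
```

Ordinary control flow such as the `if` and the recursion carries over unchanged. This is what makes programs with dynamic control flow amenable to the callback-based approach.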
Our approach achieves the same flexibility as other reverse-mode automatic
differentiation (AD) techniques, but it can be implemented without any auxiliary
data structures besides the function call stack, and it can easily be combined
with graph construction and native code generation techniques through forms of
multi-stage programming, leading to a highly efficient implementation that
combines the performance benefits of define-then-run software frameworks such
as TensorFlow with the expressiveness of define-by-run frameworks such as PyTorch.
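The combination with code generation can be shown with a toy. Here, generating Python source and loading it with `exec` stands in for graph construction or native code generation. The callback-style operators emit statements instead of computing numbers. Running the CPS program once on a symbolic input, at staging time, then yields a flat gradient function with no callbacks or tape left at run time. All names (`new_var`, `stage_grad`, `compile_grad`, `grad_f`) are hypothetical, and the paper's actual staging machinery is not reproduced here.

```python
stmts = []   # generated statements, in emission order
counter = 0  # used to create fresh variable names


def new_var(expr):
    """Emit `name = expr` plus a zero-initialized adjoint; return name."""
    global counter
    name = f"v{counter}"
    counter += 1
    stmts.append(f"{name} = {expr}")
    stmts.append(f"{name}_d = 0.0")
    return name


def add(a, b, k):
    y = new_var(f"{a} + {b}")        # emit forward code
    k(y)                             # stage the rest of the program
    stmts.append(f"{a}_d += {y}_d")  # emit backward code after it
    stmts.append(f"{b}_d += {y}_d")


def mul(a, b, k):
    y = new_var(f"{a} * {b}")
    k(y)
    stmts.append(f"{a}_d += {b} * {y}_d")
    stmts.append(f"{b}_d += {a} * {y}_d")


def stage_grad(f):
    """Run the CPS program once on a symbolic input; return gradient source."""
    global counter
    stmts.clear()
    counter = 0
    stmts.append("x_d = 0.0")
    f("x", lambda y: stmts.append(f"{y}_d = 1.0"))  # seed the output adjoint
    body = "\n".join("    " + s for s in stmts)
    return f"def grad_f(x):\n{body}\n    return x_d\n"


def compile_grad(f):
    namespace = {}
    exec(stage_grad(f), namespace)  # define grad_f from the generated source
    return namespace["grad_f"]


def f(x, k):
    # f(x) = x*x + x; here x and intermediates are variable names, not numbers
    mul(x, x, lambda t: add(t, x, k))


print(stage_grad(f))
print(compile_grad(f)(3.0))  # 7.0
```

The printed source is straight-line code. It lists the forward statements first, followed by the adjoint updates in reverse order. The order of function returns during staging has been frozen into the program text.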