Differential Programming
Gabriel Peyré
École Normale Supérieure
www.numerical-tours.com
Mathematical Coffees: https://mathematical-coffees.github.io
Organized by: Mérouane Debbah & Gabriel Peyré
Topics: Optimization, Deep Learning, Optimal Transport, Artificial Intelligence, Compressed Sensing, Quantum Computing, Mean Field Games, Topos
Speakers: Yves Achdou (Paris 6), Daniel Bennequin (Paris 7), Marco Cuturi (ENSAE), Jalal Fadili (ENSICaen), Alexandre Gramfort (INRIA), Olivier Grisel (INRIA), Olivier Guéant (Paris 1), Iordanis Kerenidis (CNRS and Paris 7), Guillaume Lecué (CNRS and ENSAE), Frédéric Magniez (CNRS and Paris 7), Edouard Oyallon (CentraleSupelec), Gabriel Peyré (CNRS and ENS), Joris Van den Bossche (INRIA)
The generic problem: fit the parameter θ of a model f so that its output matches the data,

min_θ E(θ) def= L(f(x, θ), y)

where x is the input, y the output, f the model, θ the parameter and L the loss.

Examples:
Deep learning: f(·, θ) is a deep network with layer parameters θ_1, θ_2, θ_3, θ_4, ...; input x, output y = class probabilities.
Super-resolution: θ is the unknown image, y the observation, f(x, ·) the degradation.
Medical imaging registration: f(·, θ) is a diffeomorphism mapping the image x onto the image y.
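As a concrete illustration of this template, here is a minimal Python/NumPy sketch; the linear model f and the squared-error loss L are placeholder choices for illustration only, not the models of the examples above:

    import numpy as np

    # Sketch of E(theta) = L(f(x, theta), y) with illustrative choices of f and L.

    def f(x, theta):
        # model: a simple linear map parameterized by the matrix theta
        return theta @ x

    def L(prediction, y):
        # loss: squared Euclidean distance to the target output
        return 0.5 * np.sum((prediction - y) ** 2)

    def E(theta, x, y):
        return L(f(x, theta), y)

    x = np.random.randn(5)         # input
    y = np.random.randn(3)         # output (target)
    theta = np.random.randn(3, 5)  # parameter
    print(E(theta, x, y))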
min_θ E(θ) def= L(f(x, θ), y)

Gradient descent:
θ_{ℓ+1} = θ_ℓ − τ_ℓ ∇E(θ_ℓ)

(Figure: behaviour of the iterates for a small τ_ℓ, a large τ_ℓ, and the optimal τ_ℓ = τ_ℓ^*.)

Many generalizations: (quasi-)Newton, Nesterov / heavy-ball, stochastic / incremental methods, proximal splitting (for non-smooth E), ...
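A minimal Python/NumPy sketch of this update, on a placeholder quadratic energy E(θ) = ½‖Aθ − y‖² with a constant step size (A, y and τ are illustrative choices, not from the slides):

    import numpy as np

    # Gradient descent theta_{l+1} = theta_l - tau * grad E(theta_l)
    # on the toy energy E(theta) = 0.5 * ||A theta - y||^2.
    A = np.random.randn(20, 5)
    y = np.random.randn(20)

    def grad_E(theta):
        return A.T @ (A @ theta - y)

    tau = 1.0 / np.linalg.norm(A, 2) ** 2   # constant step below 1/Lipschitz(grad E)
    theta = np.zeros(5)
    for l in range(200):
        theta = theta - tau * grad_E(theta)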
Setup: E : R^n → R computable in K operations.
Hypothesis: elementary operations (a × b, log(a), √a, ...) and their derivatives cost O(1).
Question: what is the complexity of computing ∇E : R^n → R^n?

Finite differences:
∇E(θ) ≈ (1/ε) (E(θ + εδ_1) − E(θ), ..., E(θ + εδ_n) − E(θ))
requires K(n + 1) operations, intractable for large n.

Theorem [Seppo Linnainmaa, 1970]: there is an algorithm to compute ∇E in O(K) operations.
This algorithm is reverse-mode automatic differentiation.
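A Python/NumPy sketch of the finite-difference approximation (the toy energy at the end is only there to check the formula):

    import numpy as np

    # Finite differences: n+1 evaluations of E, hence ~ K*(n+1) operations.
    def finite_difference_grad(E, theta, eps=1e-6):
        n = theta.size
        E0 = E(theta)
        g = np.zeros(n)
        for i in range(n):
            delta = np.zeros(n)
            delta[i] = 1.0                      # canonical direction delta_i
            g[i] = (E(theta + eps * delta) - E0) / eps
        return g

    # usage on a toy energy (illustrative)
    E = lambda t: np.sum(np.log(1 + t ** 2))
    theta = np.random.randn(4)
    print(finite_difference_grad(E, theta))
    print(2 * theta / (1 + theta ** 2))          # exact gradient, for comparison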
Composition of functions:
x = x_0 → g_0 → x_1 → g_1 → x_2 → ... → g_R → x_{R+1} ∈ R
x_{r+1} = g_r(x_r), with g_r : R^{n_r} → R^{n_{r+1}} and Jacobian ∂g_r(x_r) ∈ R^{n_{r+1} × n_r}.
Since x_{R+1} ∈ R, the last factor is a gradient: ∇g_R(x_R) = [∂g_R(x_R)]^⊤ ∈ R^{n_R × 1}.

Chain rule:
∂g(x) = ∂g_R(x_R) × ∂g_{R−1}(x_{R−1}) × ... × ∂g_1(x_1) × ∂g_0(x_0)
Writing A_r def= [∂g_r(x_r)]^⊤ ∈ R^{n_r × n_{r+1}}, this reads ∇g(x) = A_0 × A_1 × ... × A_{R−1} × A_R.

Forward evaluation (left to right):
∇g(x) = ((... ((A_0 × A_1) × A_2) ... × A_{R−1}) × A_R
matrix × matrix products; if n_r = n for r = 0, ..., R, the cost is (R − 1)n³ + n², i.e. O(n³).

Backward evaluation (right to left):
∇g(x) = A_0 × (A_1 × (A_2 × ... × (A_{R−1} × A_R) ...))
matrix × vector products; if n_r = n for r = 0, ..., R, the cost is Rn², i.e. O(n²).
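A Python/NumPy sketch of the two evaluation orders on random Jacobian factors (the sizes R and n are illustrative); both orders give the same gradient, only the cost differs:

    import numpy as np

    # Multiplying the chain A_0 x A_1 x ... x A_R, with A_r of size n_r x n_{r+1}
    # and n_{R+1} = 1.
    R, n = 10, 50
    dims = [n] * (R + 1) + [1]                       # n_0, ..., n_R, n_{R+1} = 1
    A = [np.random.randn(dims[r], dims[r + 1]) for r in range(R + 1)]

    # forward: left to right, (R-1) matrix-matrix products of cost ~ n^3 each
    P = A[0]
    for r in range(1, R + 1):
        P = P @ A[r]

    # backward: right to left, R matrix-vector products of cost ~ n^2 each
    v = A[R]
    for r in range(R - 1, -1, -1):
        v = A[r] @ v

    print(np.allclose(P, v))                         # same result, different cost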
Adding parameters: each function g_r now depends on a parameter θ_r,
x_{r+1} = g_r(x_r, θ_r), E(x) = L(x_{R+1}, y)
(x = x_0 → g_0 → x_1 → g_1 → x_2 → ... → g_R → x_{R+1}, then the loss L(·, y) gives E.)

Example: deep neural network (here fully connected):
x_{r+1} = ρ(A_r x_r + b_r), with x_r ∈ R^{d_r}, A_r ∈ R^{d_{r+1} × d_r}, b_r ∈ R^{d_{r+1}}, θ_r = (A_r, b_r),
where ρ is a pointwise nonlinearity (figure: graph of ρ(u)).

Logistic loss (classification):
L(x_{R+1}, y) def= log Σ_i exp(x_{R+1,i}) − Σ_i x_{R+1,i} y_i
∇_{x_{R+1}} L(x_{R+1}, y) = exp(x_{R+1}) / Σ_i exp(x_{R+1,i}) − y

Proposition: for all r = R, ..., 0,
∇_{θ_r} E = [∂_{θ_r} g_r(x_r, θ_r)]^⊤ (∇_{x_{r+1}} E)
∇_{x_r} E = [∂_{x_r} g_r(x_r, θ_r)]^⊤ (∇_{x_{r+1}} E)

Example: deep neural network: for all r = R, ..., 0, setting M_r def= ρ'(A_r x_r + b_r) ⊙ (∇_{x_{r+1}} E),
∇_{x_r} E = A_r^⊤ M_r, ∇_{A_r} E = M_r x_r^⊤, ∇_{b_r} E = M_r.
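A Python/NumPy sketch of these backpropagation formulas for a small fully connected network with the logistic loss; ρ = tanh and the layer sizes are illustrative choices, not fixed by the slides:

    import numpy as np

    def rho(u):  return np.tanh(u)          # placeholder nonlinearity
    def drho(u): return 1.0 - np.tanh(u) ** 2

    dims = [8, 16, 16, 4]                    # d_0, ..., d_{R+1} (illustrative)
    A = [np.random.randn(dims[r + 1], dims[r]) * 0.3 for r in range(len(dims) - 1)]
    b = [np.zeros(dims[r + 1]) for r in range(len(dims) - 1)]

    def energy_and_gradients(x0, y):
        # forward pass: store the x_r and the pre-activations A_r x_r + b_r
        xs, pres = [x0], []
        for Ar, br in zip(A, b):
            pres.append(Ar @ xs[-1] + br)
            xs.append(rho(pres[-1]))
        xout = xs[-1]
        E = np.log(np.sum(np.exp(xout))) - xout @ y     # logistic loss
        # backward pass
        g = np.exp(xout) / np.sum(np.exp(xout)) - y     # grad_{x_{R+1}} E
        grads_A, grads_b = [], []
        for r in range(len(A) - 1, -1, -1):
            M = drho(pres[r]) * g                       # M_r = rho'(A_r x_r + b_r) ⊙ grad_{x_{r+1}} E
            grads_A.insert(0, np.outer(M, xs[r]))       # grad_{A_r} E = M_r x_r^T
            grads_b.insert(0, M)                        # grad_{b_r} E = M_r
            g = A[r].T @ M                              # grad_{x_r} E = A_r^T M_r
        return E, grads_A, grads_b

    x0 = np.random.randn(8)
    y = np.zeros(4); y[2] = 1.0                          # one-hot class label
    E, gA, gb = energy_and_gradients(x0, y)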
Shared parameters: the same parameter θ is used by all the functions,
x_{r+1} = g_r(x_r, θ), E(x) = L(x_{R+1}, y)

Example: recurrent networks for natural language processing (diagram: the same block g, with the shared parameter θ, is applied at every time step, mapping the previous state a_{t−1} and the input x_t to the new state a_t and the output b_t, for t = 1, ..., T).

Take home message: for complicated computational architectures, you do not want to do the computation/implementation by hand.
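With a shared parameter, the gradient accumulates one contribution per time step. A Python sketch on a toy scalar recurrence a_t = tanh(θ a_{t−1} + x_t) with E = a_T (an illustrative stand-in, not the slides' NLP architecture):

    import numpy as np

    def grad_shared_theta(theta, x):
        a = [0.0]
        for xt in x:                                   # forward through time
            a.append(np.tanh(theta * a[-1] + xt))
        g_a, g_theta = 1.0, 0.0                        # dE/da_T = 1
        for t in range(len(x), 0, -1):                 # backward through time
            d_pre = g_a * (1.0 - a[t] ** 2)
            g_theta += d_pre * a[t - 1]                # contribution of step t to the shared theta
            g_a = d_pre * theta
        return g_theta

    def E(theta, x):
        a = 0.0
        for xt in x:
            a = np.tanh(theta * a + xt)
        return a

    x = np.random.randn(6)
    print(grad_shared_theta(0.7, x))
    print((E(0.7 + 1e-6, x) - E(0.7, x)) / 1e-6)       # finite-difference check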
Computer program ⇔ directed acyclic graph ⇔ linear ordering of the nodes (θ_r)_r
(diagram: input nodes θ_1, θ_2, internal nodes θ_3, θ_4, θ_5 computed by g_3, g_4, g_5).

Forward pass (computing ℓ):
function ℓ(θ_1, ..., θ_M)
    for r = M + 1, ..., R: θ_r = g_r(θ_{Parents(r)})
    return θ_R

Example: ℓ(θ_1, θ_2) def= θ_2 e^{θ_1} √(θ_1 + θ_2 e^{θ_1})
Inputs: θ_1, θ_2. Intermediate nodes:
θ_3 def= e^{θ_1} (g_3)
θ_4 def= θ_2 θ_3 (g_4)
θ_5 def= θ_1 + θ_4 (g_5)
θ_6 def= √θ_5 (g_6)
θ_7 def= θ_4 θ_6 (g_7) = ℓ(θ_1, θ_2)
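A direct Python transcription of this forward pass, following the node ordering of the graph:

    import numpy as np

    # Forward evaluation of l(theta1, theta2) = theta2 e^{theta1} sqrt(theta1 + theta2 e^{theta1}).
    def forward(theta1, theta2):
        theta3 = np.exp(theta1)       # g_3
        theta4 = theta2 * theta3      # g_4
        theta5 = theta1 + theta4      # g_5
        theta6 = np.sqrt(theta5)      # g_6
        theta7 = theta4 * theta6      # g_7 = l(theta1, theta2)
        return theta7

    print(forward(1.0, 2.0))
    print(2.0 * np.exp(1.0) * np.sqrt(1.0 + 2.0 * np.exp(1.0)))   # direct formula, same value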
Chain rule on the graph: each node is θ_j = g_j((θ_i)_{i ∈ Parents(j)}), with partial derivatives ∂θ_j/∂θ_i = ∂_i g_j(θ).

Forward ("classical") evaluation: propagate the derivatives with respect to an input θ_1,
∂θ_j/∂θ_1 = Σ_{i ∈ Parents(j)} (∂θ_j/∂θ_i)(∂θ_i/∂θ_1)
Complexity ∼ #inputs.

Backward evaluation: propagate the derivatives of the output θ_N,
∂θ_N/∂θ_j = Σ_{k ∈ Child(j)} (∂θ_N/∂θ_k)(∂θ_k/∂θ_j), i.e. ∇_j ℓ(θ) = Σ_{k ∈ Child(j)} ∂_j g_k(θ) ∇_k ℓ(θ)
Complexity ∼ #outputs (1 for a gradient).
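A Python sketch of the forward rule on the running example: the derivatives ∂θ_r/∂θ_1 are propagated alongside the values, and one such pass is needed per input (hence cost ∼ #inputs); the seeding via d1, d2 is the usual forward-mode convention, stated here as an illustration:

    import numpy as np

    # Forward-mode differentiation of the example: propagate d theta_r / d input.
    def forward_mode(theta1, theta2, d1=1.0, d2=0.0):
        theta3 = np.exp(theta1);      d3 = theta3 * d1
        theta4 = theta2 * theta3;     d4 = d2 * theta3 + theta2 * d3
        theta5 = theta1 + theta4;     d5 = d1 + d4
        theta6 = np.sqrt(theta5);     d6 = d5 / (2.0 * theta6)
        theta7 = theta4 * theta6;     d7 = d4 * theta6 + theta4 * d6
        return theta7, d7

    # one pass per input: derivative w.r.t. theta_1, then w.r.t. theta_2
    print(forward_mode(1.0, 2.0, 1.0, 0.0)[1])
    print(forward_mode(1.0, 2.0, 0.0, 1.0)[1])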
Reverse-mode automatic differentiation:

function ℓ(θ_1, ..., θ_M)                   [forward: computing ℓ]
    for r = M + 1, ..., R: θ_r = g_r(θ_{Parents(r)})
    return θ_R

function ∇ℓ(θ_1, ..., θ_M)                  [backward: computing ∇ℓ]
    ∇_R ℓ = 1
    for r = R − 1, ..., 1: ∇_r ℓ = Σ_{s ∈ Child(r)} [∂_r g_s(θ)]^⊤ (∇_s ℓ)
    return (∇_1 ℓ, ..., ∇_M ℓ)
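A Python sketch of this backward recursion on the running example, with a finite-difference check of the resulting gradient:

    import numpy as np

    # Reverse mode on the example: grad_7 = 1, then grad_r = sum over children s
    # of d_r g_s(theta) * grad_s, visiting the nodes in reverse order.
    def grad_l(theta1, theta2):
        # forward pass: compute and store all the nodes
        theta3 = np.exp(theta1)
        theta4 = theta2 * theta3
        theta5 = theta1 + theta4
        theta6 = np.sqrt(theta5)
        theta7 = theta4 * theta6                 # = l(theta1, theta2)
        # backward pass
        g7 = 1.0
        g6 = theta4 * g7                         # child of node 6: node 7
        g5 = g6 / (2.0 * theta6)                 # child of node 5: node 6
        g4 = g5 + theta6 * g7                    # children of node 4: nodes 5 and 7
        g3 = theta2 * g4                         # child of node 3: node 4
        g2 = theta3 * g4                         # child of node 2: node 4
        g1 = theta3 * g3 + g5                    # children of node 1: nodes 3 and 5
        return theta7, np.array([g1, g2])

    val, grad = grad_l(1.0, 2.0)
    l = lambda t1, t2: t2 * np.exp(t1) * np.sqrt(t1 + t2 * np.exp(t1))
    eps = 1e-6                                   # finite-difference check
    print(grad)
    print((l(1.0 + eps, 2.0) - l(1.0, 2.0)) / eps, (l(1.0, 2.0 + eps) - l(1.0, 2.0)) / eps)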