Machine Learning Optimal Transport Sayas Numerics Seminar Lars - - PowerPoint PPT Presentation



SLIDE 1

Ruthotto ML meet OT @ Oct 2020

Machine Learning ↔ Optimal Transport

Sayas Numerics Seminar Lars Ruthotto

Departments of Mathematics and Computer Science Emory University lruthotto@emory.edu @lruthotto

Title ML → OT Lag NN Exp OT→CNF Σ 1

SLIDE 2

Agenda: Machine Learning meets Optimal Transport

◮ ML → OT: New Tricks from Learning

  ◮ based on relaxed dynamical optimal transport
  ◮ combine macroscopic / microscopic / HJB equations
  ◮ neural networks for value function
  ◮ combine analytic gradients and automatic differentiation
  ◮ generalization to mean field games and control problems

◮ OT → ML: Learning from Old Tricks

  ◮ variational inference via continuous normalizing flows
  ◮ applications: density estimation, generative modeling
  ◮ OT uniqueness and regularity of dynamics
  ◮ HJB, solid numerics, and efficient implementation
  ◮ orders of magnitude speedup in training and inference

LR, S Osher, W Li, L Nurbekyan, S Wu Fung. A ML Framework for Solving High-Dimensional MFG and MFC. PNAS 117 (17), 9183-9193, 2020.
D Onken, S Wu Fung, X Li, LR. OT-Flow: Fast and Accurate CNF via OT. arXiv:2006.00104, 2020.

SLIDE 3

Collaborators and Funding

Emory Funding:
◮ DMS 1751636
◮ BSF 2018209
◮ FA9550-20-1-0372

Special Thanks:
◮ Organizers and staff of IPAM Long Program MLP 2019.
◮ Osher's funding AFOSR MURI and ONR

Collaborators: Stan Osher, Onken, Wu Fung, Li, Nurbekyan

SLIDE 5

[Figure: initial density ρ0; target density ρ1; density evolution; ρ(·, 1), the push-forward of ρ0]

Dynamic Optimal Transport (Benamou and Brenier, ’00)

Given the initial density, ρ0, and the target density, ρ1, find the velocity v that renders the push-forward of ρ0 equal to ρ1 and minimizes the transport costs, i.e.,

\[
\min_{v,\rho} \int_0^1 \int_{\Omega} \tfrac{1}{2} \|v(x,t)\|^2 \rho(x,t) \, dx \, dt
\]
subject to
\[
\partial_t \rho + \nabla \cdot (\rho v) = 0, \quad \rho(\cdot,0) = \rho_0(\cdot), \quad \rho(\cdot,1) = \rho_1(\cdot).
\]
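A small sanity check of the transport cost, for a case where the answer is known in closed form. The setup is an assumption for illustration and does not come from the slides: ρ0 = N(0, 1) and ρ1 = N(m, s²) in one dimension, where the optimal map is T(x) = m + s·x and every particle moves on a straight line with constant velocity T(x) − x. Under these assumptions the Monte Carlo estimate of the kinetic energy should approach ½(m² + (s − 1)²).

```python
import random

def dynamic_ot_cost_1d(m, s, n_samples=20000, seed=0):
    """Monte Carlo estimate of int_0^1 int 0.5*v^2*rho dx dt for N(0,1) -> N(m, s^2)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        x = rng.gauss(0.0, 1.0)      # sample from rho_0
        v = (m + s * x) - x          # constant velocity along the straight characteristic
        total += 0.5 * v * v         # time integral over [0, 1] of a constant integrand
    return total / n_samples

def exact_cost(m, s):
    """Closed-form value 0.5 * (m^2 + (s - 1)^2) for the Gaussian example."""
    return 0.5 * (m * m + (s - 1.0) ** 2)

estimate = dynamic_ot_cost_1d(m=2.0, s=0.5)
print(estimate, exact_cost(2.0, 0.5))
```

The estimate only samples initial positions, which is the same Lagrangian viewpoint the later slides use to avoid a grid.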

SLIDE 7

[Figure: initial density ρ0; target density ρ1; density evolution; ρ(·, 1), the push-forward of ρ0]

Relaxed Dynamical Optimal Transport

Given the initial density, ρ0, and the target density, ρ1, find the velocity v that minimizes the discrepancy between the push-forward of ρ0 and ρ1 and the transport costs, i.e.,

\[
\min_{v,\rho} J_{\rm MFG}(\rho, v) \overset{\text{def}}{=} \int_0^1 \int_{\Omega} \tfrac{1}{2} \|v(x,t)\|^2 \rho(x,t) \, dx \, dt + G(\rho(\cdot,1), \rho_1)
\]
subject to
\[
\partial_t \rho + \nabla \cdot (\rho v) = 0, \quad \rho(\cdot,0) = \rho_0(\cdot). \qquad \text{(CE)}
\]

Examples for terminal cost G: L2, Kullback-Leibler divergence, . . .

Side note: relaxed OT problem is a potential mean field game (MFG)

SLIDE 9
Relaxed Dynamic Optimal Transport: A Microscopic View

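A single agent with initial position x ∈ Ω aims at choosing v that minimizes

\[
J_{x,0}(v) = \int_0^1 \tfrac{1}{2} \|v(s)\|^2 \, ds + G\big(z(1), \rho(z(1), 1)\big),
\]
where their position changes according to
\[
\partial_t z(s) = v(s), \quad 0 \le s \le 1, \quad z(0) = x.
\]

◮ \( G(x, \rho) = \frac{\delta G(\rho, \rho_1)}{\delta \rho}(x) \) (variational derivative of G; on the left G is the agent's pointwise terminal cost, on the right the terminal cost functional)
◮ agent interacts with the population through ρ and G
◮ z(·) is characteristic curve of (CE) starting at x

Useful to define the value of an agent's state (x, t) as
\[
\Phi(x, t) = \inf_{v} J_{x,t}(v).
\]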

SLIDE 10

Hamilton-Jacobi-Bellman (HJB) Equation

[Figure: initial density ρ0; value function; density evolution; target density ρ1]

Lasry & Lions '06: First-order optimality conditions of relaxed OT are
\[
-\partial_t \Phi(x,t) + \tfrac{1}{2} \|\nabla \Phi(x,t)\|^2 = 0, \quad \Phi(x,1) = G(x, \rho(x,1)) \qquad \text{(HJB)}
\]
and optimal strategy is v(x, t) = −∇Φ(x, t), which gives
\[
\partial_t \rho(x,t) - \nabla \cdot (\rho(x,t) \nabla \Phi(x,t)) = 0, \quad \rho(x,0) = \rho_0(x). \qquad \text{(CE)}
\]

Challenges: forward-backward structure and high-dimensionality of PDE system
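To make (HJB) concrete, here is a minimal sketch that checks a closed-form solution numerically. The potential Φ(x, t) = ‖x‖² / (2(2 − t)) is an illustrative choice, not taken from the slides; it satisfies −∂tΦ + ½‖∇Φ‖² = 0 on [0, 1] and matches the terminal condition Φ(x, 1) = ½‖x‖². The residual is evaluated with central finite differences, and the optimal strategy v = −∇Φ is read off from the same gradient.

```python
def phi(x, t):
    """Illustrative value function Phi(x, t) = |x|^2 / (2 (2 - t)); an assumption."""
    return sum(xi * xi for xi in x) / (2.0 * (2.0 - t))

def grad_phi(x, t, eps=1e-5):
    """Central finite-difference gradient of Phi with respect to x."""
    g = []
    for i in range(len(x)):
        xp, xm = list(x), list(x)
        xp[i] += eps
        xm[i] -= eps
        g.append((phi(xp, t) - phi(xm, t)) / (2.0 * eps))
    return g

def hjb_residual(x, t, eps=1e-5):
    """Residual of -dPhi/dt + 0.5 |grad Phi|^2 = 0 at the point (x, t)."""
    dphi_dt = (phi(x, t + eps) - phi(x, t - eps)) / (2.0 * eps)
    g = grad_phi(x, t, eps)
    return -dphi_dt + 0.5 * sum(gi * gi for gi in g)

x, t = [0.3, -1.2, 0.7], 0.4
residual = hjb_residual(x, t)
velocity = [-gi for gi in grad_phi(x, t)]   # optimal strategy v = -grad Phi
print(residual, velocity)
```

For this potential the velocity is −x / (2 − t), so particles travel on straight lines, as expected for optimal transport with quadratic cost.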

SLIDE 13

Machine Learning for High-Dimensional OT: Overview

Three options for solving the problem

1. minimize JMFG w.r.t. (ρ, v), or (ρ, −∇Φ) (variational problem)
2. minimize Jx,t w.r.t. v or −∇Φ for some points x (microscopic view)
3. compute value function by solving (HJB) and (CE) (high-dimensional PDEs)

Idea: Combine advantages of the above to tackle curse of dimensionality
◮ formulate as variational problem: minimize JMFG(ρ, −∇Φ)
◮ eliminate (CE) with Lagrangian PDE solver: mesh-free, parallel
◮ parameterize Φ with NN: universal approximator, mesh-free, cheap(?)
◮ penalize violations of (HJB): regularity, global convergence(?)

SLIDE 15

Lagrangian Method for Continuity Equation

Assume Φ given. Then, the solution to
\[
\partial_t \rho(x,t) - \nabla \cdot (\rho(x,t) \nabla \Phi(x,t)) = 0, \quad \rho(x,0) = \rho_0(x)
\]
satisfies
\[
\rho(z(x,t), t) \, \det \nabla z(x,t) = \rho_0(x)
\]
along the characteristic curve
\[
\partial_t z(x,t) = -\nabla \Phi(z(x,t), t), \quad z(x,0) = x.
\]

Instead of computing det ∇z(x, t) (cost O(d^3) flops), use
\[
l(x,t) \overset{\text{def}}{=} \log \det(\nabla z(x,t)) = -\int_0^t \Delta \Phi(z(x,s), s) \, ds.
\]

Hint: Compute z and l in one ODE solve (parallelize over x1, x2, . . .).
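The hint above can be sketched as follows: z and l are stacked into one state vector and advanced by a single ODE solve. The potential Φ(x) = ½·A·‖x‖² and the value of A are assumptions chosen so that the result can be checked by hand (∇Φ = A·x, ΔΦ = A·d, hence z(x, t) = x·e^(−At) and l(x, t) = −A·d·t). The hand-written RK4 stepper is illustrative only.

```python
import math

def rk4(f, y0, t0, t1, nt):
    """Classical Runge-Kutta-4 with nt equal steps; the state y is a list of floats."""
    h = (t1 - t0) / nt
    y, t = list(y0), t0
    for _ in range(nt):
        k1 = f(y, t)
        k2 = f([yi + 0.5 * h * k for yi, k in zip(y, k1)], t + 0.5 * h)
        k3 = f([yi + 0.5 * h * k for yi, k in zip(y, k2)], t + 0.5 * h)
        k4 = f([yi + h * k for yi, k in zip(y, k3)], t + h)
        y = [yi + h / 6.0 * (q1 + 2 * q2 + 2 * q3 + q4)
             for yi, q1, q2, q3, q4 in zip(y, k1, k2, k3, k4)]
        t += h
    return y

A = 0.8  # curvature of the illustrative potential Phi(x) = 0.5 * A * |x|^2

def rhs(state, t):
    """state = (z_1, ..., z_d, l): dz/dt = -grad Phi(z), dl/dt = -Laplacian Phi(z)."""
    z = state[:-1]
    d = len(z)
    return [-A * zi for zi in z] + [-A * d]

x = [1.0, -2.0, 0.5]
state = rk4(rhs, x + [0.0], 0.0, 1.0, nt=8)
z1, l1 = state[:-1], state[-1]
print(z1, l1, math.exp(l1))   # exp(l1) approximates det(grad z(x, 1))
```

Because each starting point x has its own small ODE system, many points can be integrated independently, which is what makes the approach easy to parallelize.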

SLIDE 17

Lagrangian Method for Optimal Transport

\[
\min_{\Phi} \; \mathbb{E}_{\rho_0} \Big[ c_L(x,1) + G(z(x,1)) + \alpha_1 c_H(x,1) + \alpha_2 \big| \Phi(z(x,1),1) - G(z(x,1)) \big| \Big]
\]
subject to
\[
\partial_t \begin{pmatrix} z(x,t) \\ l(x,t) \\ c_L(x,t) \\ c_H(x,t) \end{pmatrix}
=
\begin{pmatrix}
-\nabla \Phi(z(x,t),t) \\
-\Delta \Phi(z(x,t),t) \\
\tfrac{1}{2} \|\nabla \Phi(z(x,t),t)\|^2 \\
\big| -\partial_t \Phi(z(x,t),t) + \tfrac{1}{2} \|\nabla \Phi(z(x,t),t)\|^2 \big|
\end{pmatrix},
\quad t \in (0,1],
\]
\[
z(x,0) = x, \quad l(x,0) = c_L(x,0) = c_H(x,0) = 0.
\]

◮ z and l = log det needed to solve continuity eq. (CE)
◮ cL and cH accumulate cost along characteristic
◮ α1, α2: penalty parameters for HJB violation
◮ discretize dynamics with nt steps of Runge-Kutta-4
◮ discretize E with Monte Carlo
◮ can use SA (SGD, ADAM, . . . ) or SAA (BFGS, Newton, . . . ) methods
◮ no grid needed and computation can be parallelized over x

Next, parameterize Φ with NN. Needed: ∇Φ and ∆Φ
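A minimal sketch of the discretization described above: the four quantities (z, l, cL, cH) are integrated with RK4 and the expectation over ρ0 is replaced by a Monte Carlo average. Everything specific here is an assumption for illustration: ρ0 is a standard normal, the potential has the form Φ(x, t) = ½·a(t)·‖x‖², and no neural network or terminal cost is involved. With a(t) = 1/(2 − t) the potential solves the HJB equation, so the accumulated penalty cH vanishes; with a constant a(t) it does not.

```python
import random

def rk4(f, y0, nt):
    """Classical Runge-Kutta-4 on [0, 1] with nt equal steps."""
    h = 1.0 / nt
    y, t = list(y0), 0.0
    for _ in range(nt):
        k1 = f(y, t)
        k2 = f([yi + 0.5 * h * k for yi, k in zip(y, k1)], t + 0.5 * h)
        k3 = f([yi + 0.5 * h * k for yi, k in zip(y, k2)], t + 0.5 * h)
        k4 = f([yi + h * k for yi, k in zip(y, k3)], t + h)
        y = [yi + h / 6.0 * (q1 + 2 * q2 + 2 * q3 + q4)
             for yi, q1, q2, q3, q4 in zip(y, k1, k2, k3, k4)]
        t += h
    return y

def make_rhs(a, da, d):
    """Dynamics of (z, l, c_L, c_H) for Phi(x, t) = 0.5 * a(t) * |x|^2 (illustrative)."""
    def rhs(y, t):
        z = y[:d]
        zz = sum(zi * zi for zi in z)
        grad = [a(t) * zi for zi in z]               # grad Phi
        lap = a(t) * d                               # Laplacian Phi
        kinetic = 0.5 * sum(g * g for g in grad)     # 0.5 |grad Phi|^2
        dphi_dt = 0.5 * da(t) * zz                   # time derivative of Phi
        return [-g for g in grad] + [-lap, kinetic, abs(-dphi_dt + kinetic)]
    return rhs

def mean_costs(a, da, d=2, nt=4, n_samples=2000, seed=0):
    """Monte Carlo averages of c_L(x, 1) and c_H(x, 1) over x ~ N(0, I)."""
    rng = random.Random(seed)
    rhs = make_rhs(a, da, d)
    sum_cl = sum_ch = 0.0
    for _ in range(n_samples):
        x = [rng.gauss(0.0, 1.0) for _ in range(d)]
        y = rk4(rhs, x + [0.0, 0.0, 0.0], nt)
        sum_cl += y[d + 1]
        sum_ch += y[d + 2]
    return sum_cl / n_samples, sum_ch / n_samples

# Potential that satisfies HJB: a(t) = 1 / (2 - t), a'(t) = a(t)^2.
cl_ot, ch_ot = mean_costs(lambda t: 1.0 / (2.0 - t), lambda t: 1.0 / (2.0 - t) ** 2)
# Time-independent potential that violates HJB.
cl_bad, ch_bad = mean_costs(lambda t: 0.5, lambda t: 0.0)
print(cl_ot, ch_ot, cl_bad, ch_bad)
```

In the actual method Φ is a trainable model and these averages enter the objective; the sketch only shows how the penalty term distinguishes an HJB-consistent potential from an inconsistent one.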

SLIDE 19

Deep Learning Revolution (?)

\[
\begin{cases}
Y_{j+1} = \sigma(K_j Y_j + b_j) \\
Y_{j+1} = Y_j + \sigma(K_j Y_j + b_j) \\
Y_{j+1} = Y_j + \sigma\big(K_{j,2} \, \sigma(K_{j,1} Y_j + b_{j,1}) + b_{j,2}\big) \\
\quad \vdots
\end{cases}
\]

(Notation: Yj: features, Kj, bj: weights, σ: activation)

◮ deep learning: use neural networks (from ≈ 1950's) with many hidden layers
◮ able to "learn" complicated patterns from data
◮ applications: classification, face recognition, segmentation, driverless cars, . . .
◮ recent success fueled by: massive data sets, computing power
◮ A few recent references:
  ◮ Data Scientist: Sexiest Job of the 21st Century, Harvard Business Rev '17
  ◮ A radical new neural network design could overcome big challenges in AI, MIT Tech Review '18

SLIDE 22

Neural Network Model for Value Function

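Let s = (x, t) ∈ R^{d+1} and use (NN + quadratic) model for value function
\[
\Phi(s, \theta) = w^\top N(s, \theta_N) + \tfrac{1}{2} s^\top A s + c^\top s + b, \quad \theta = (w, \theta_N, \mathrm{vec}(A), c, b).
\]
N(s, θN) is an M-layer ResNet with weights θN = (vec(K0), . . . , vec(KM), b0, . . . , bM).

Forward propagation:
\[
\begin{aligned}
u_{-1} &= s \\
u_0 &= \sigma(K_0 u_{-1} + b_0) \\
u_1 &= u_0 + h \sigma(K_1 u_0 + b_1) \\
&\;\;\vdots \\
u_M &= u_{M-1} + h \sigma(K_M u_{M-1} + b_M)
\end{aligned}
\]
Output: \( w^\top u_M = w^\top N(s, \theta_N) \)

Backward propagation:
\[
\begin{aligned}
z_{M+1} &= w \\
z_M &= z_{M+1} + h K_M^\top \mathrm{diag}(\sigma'(K_M u_{M-1} + b_M)) z_{M+1} \\
&\;\;\vdots \\
z_1 &= z_2 + h K_1^\top \mathrm{diag}(\sigma'(K_1 u_0 + b_1)) z_2 \\
z_0 &= K_0^\top \mathrm{diag}(\sigma'(K_0 s + b_0)) z_1
\end{aligned}
\]
Output: \( z_0 = \nabla_s (w^\top N(s, \theta_N)) \)

Next: Compute
\[
\Delta \Phi(s, \theta) = \mathrm{tr}\big( E^\top (\nabla_s^2 (N(s, \theta_N) w) + A) E \big).
\]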
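The forward and backward recursions above can be sketched in a few lines and checked against finite differences. This is a minimal illustration, not the authors' implementation: the activation σ = tanh, the layer sizes, the step size h, and the random weights are all assumptions, and only the network part wᵀN(s, θN) is differentiated (the quadratic part of Φ has a closed-form gradient).

```python
import math
import random

def matvec(K, u):
    return [sum(kij * uj for kij, uj in zip(row, u)) for row in K]

def mat_t_vec(K, v):
    return [sum(K[i][j] * v[i] for i in range(len(K))) for j in range(len(K[0]))]

def forward(s, Ks, bs, h):
    """Return the features [u_0, ..., u_M] of the ResNet."""
    us = [[math.tanh(p + b) for p, b in zip(matvec(Ks[0], s), bs[0])]]
    for K, b in zip(Ks[1:], bs[1:]):
        u = us[-1]
        act = [math.tanh(p + bi) for p, bi in zip(matvec(K, u), b)]
        us.append([ui + h * ai for ui, ai in zip(u, act)])
    return us

def value_and_grad(s, w, Ks, bs, h):
    """Return w^T N(s) and its gradient with respect to s via the backward recursion."""
    us = forward(s, Ks, bs, h)
    value = sum(wi * ui for wi, ui in zip(w, us[-1]))
    z = list(w)                                    # z_{M+1} = w
    for i in range(len(Ks) - 1, 0, -1):            # i = M, ..., 1
        pre = [p + bi for p, bi in zip(matvec(Ks[i], us[i - 1]), bs[i])]
        dz = [(1.0 - math.tanh(p) ** 2) * zi for p, zi in zip(pre, z)]
        z = [zi + h * g for zi, g in zip(z, mat_t_vec(Ks[i], dz))]
    pre0 = [p + bi for p, bi in zip(matvec(Ks[0], s), bs[0])]
    dz0 = [(1.0 - math.tanh(p) ** 2) * zi for p, zi in zip(pre0, z)]
    return value, mat_t_vec(Ks[0], dz0)            # z_0

rng = random.Random(0)
n_in, m, M, h = 3, 4, 2, 0.5                       # hypothetical sizes: d + 1 = 3, width 4
Ks = [[[rng.uniform(-1, 1) for _ in range(n_in)] for _ in range(m)]]
Ks += [[[rng.uniform(-1, 1) for _ in range(m)] for _ in range(m)] for _ in range(M)]
bs = [[rng.uniform(-1, 1) for _ in range(m)] for _ in range(M + 1)]
w = [rng.uniform(-1, 1) for _ in range(m)]
s = [0.2, -0.4, 0.7]

value, grad = value_and_grad(s, w, Ks, bs, h)

# Compare with central finite differences.
eps, fd = 1e-6, []
for i in range(n_in):
    sp, sm = list(s), list(s)
    sp[i] += eps
    sm[i] -= eps
    fd.append((value_and_grad(sp, w, Ks, bs, h)[0]
               - value_and_grad(sm, w, Ks, bs, h)[0]) / (2 * eps))
max_err = max(abs(g - f) for g, f in zip(grad, fd))
print(max_err)
```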

SLIDE 26

Computing the Laplacian of Value Function

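\[
\Delta \Phi(s, \theta) = \mathrm{tr}\big( E^\top (\nabla_s^2 (N(s, \theta_N) w) + A) E \big) \quad \text{for} \quad E = \texttt{eye(d+1,d)}
\]

Second term trivial. Focus on NN part and use forward mode for first layer
\[
t_0 = \mathrm{tr}\big( E^\top \nabla_s (K_0^\top \mathrm{diag}(\sigma''(K_0 s + b_0)) z_1) E \big)
= (\sigma''(K_0 s + b_0) \odot z_1)^\top ((K_0 E) \odot (K_0 E)) \mathbf{1},
\]
(⊙: Hadamard product, 1 = ones(d,1))

Get
\[
\Delta (N(s, \theta_N) w) = t_0 + h \sum_{i=1}^{M} t_i,
\]
where for i ≥ 1
\[
t_i = \mathrm{tr}\big( J_{i-1}^\top \nabla_s (K_i^\top \mathrm{diag}(\sigma''(K_i u_{i-1}(s) + b_i)) z_{i+1}) J_{i-1} \big)
= (\sigma''(K_i u_{i-1} + b_i) \odot z_{i+1})^\top ((K_i J_{i-1}) \odot (K_i J_{i-1})) \mathbf{1}.
\]

Here, \( J_{i-1} = \nabla_s u_{i-1}^\top \in \mathbb{R}^{m \times d} \) is a Jacobian matrix (update during forward pass)

Overall cost when \( K_0 \in \mathbb{R}^{m \times (d+1)} \) is O(m^2 · d) FLOPS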
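A minimal sketch of the first-layer term t0, under simplifying assumptions that are not from the slides: the network has only the opening layer (M = 0, so z1 = w), the activation is σ = tanh, and the weights are made-up numbers. The Hadamard-product formula is written out as a loop and compared with second-order finite differences over the spatial coordinates only (the role of E is to drop the time coordinate).

```python
import math

def value(s, w, K0, b0):
    """w^T sigma(K0 s + b0) with sigma = tanh (the activation is an assumption)."""
    return sum(wj * math.tanh(sum(k * si for k, si in zip(row, s)) + bj)
               for wj, row, bj in zip(w, K0, b0))

def laplacian_first_layer(s, w, K0, b0, d):
    """t_0 = (sigma''(K0 s + b0) * z_1)^T ((K0 E) * (K0 E)) 1 with z_1 = w."""
    t0 = 0.0
    for wj, row, bj in zip(w, K0, b0):
        th = math.tanh(sum(k * si for k, si in zip(row, s)) + bj)
        d2 = -2.0 * th * (1.0 - th * th)               # second derivative of tanh
        t0 += d2 * wj * sum(k * k for k in row[:d])    # row[:d] is the matching row of K0 E
    return t0

d = 2                                                  # spatial dimension; s = (x, t) has d + 1 entries
K0 = [[0.5, -0.3, 0.8], [1.0, 0.2, -0.5], [-0.7, 0.4, 0.1]]
b0 = [0.1, -0.2, 0.3]
w = [1.0, -0.5, 2.0]
s = [0.3, -0.6, 0.5]

t0 = laplacian_first_layer(s, w, K0, b0, d)

# Finite-difference Laplacian over the first d coordinates only.
eps, fd = 1e-4, 0.0
for i in range(d):
    sp, sm = list(s), list(s)
    sp[i] += eps
    sm[i] -= eps
    fd += (value(sp, w, K0, b0) - 2.0 * value(s, w, K0, b0) + value(sm, w, K0, b0)) / eps ** 2
print(t0, fd)
```

The formula needs only elementwise products and one pass over the rows of K0, which is why no Hessian matrix is ever formed.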

SLIDE 27

Experiment 1: Benefit of HJB Penalty

[Figure: ρ0 (initial density) and ρ1 (target density); pull back, push forward, and characteristics for three settings (with CHJB, nt = 2; no CHJB, nt = 2; no CHJB, nt = 8); plot of JMFG (mean field objective) versus iteration for the same three settings]

HJB penalty improves accuracy and(!) lowers computational costs

SLIDE 29

Experiment 3: Comparison with Eulerian Solver

Eulerian scheme:
◮ dynamical OT formulation
◮ conservative finite volume
◮ leads to convex optimization
◮ solved to high accuracy with Newton's method

Comparison:

                    # parameters   JMFG
Eulerian, fine         3,080,448   1.066e+01 (100.00%)
Eulerian, coarse         376,960   1.082e+01 (101.47%)
MFGnet (nt = 2)              637   1.072e+01 (100.59%)
MFGnet (nt = 8)              637   1.063e+01 (99.69%)

[Figure: ρ0 (initial density) and ρ1 (target density); pull back, push forward, and characteristics for the Lagrangian ML solver and the Eulerian finite volume solver]

E Haber, R Horesh. A Multilevel Method for the Solution of Time Dependent Optimal Transport. NM-TMA 8(1), 2015.

SLIDE 30

Experiment 3: Comparison of Value Functions

[Figure: ρ0 (initial density) and ρ1 (target density); value functions ΦLag(·, t) (Lagrangian ML) and ΦEul(·, t) (Eulerian FV) and the error |ΦLag(·, t) − ΦEul(·, t)| at initial time t = 0 and final time t = 1]

Take away: Eulerian (≈ 3M parameters) and Lagrangian-ML (637 parameters) give comparable accuracy.

SLIDE 33

Extension: Mean Field Games / Mean Field Control

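Model large populations of rational agents playing non-cooperative differential game.

\[
\min_{v,\rho} J_{\rm MFG}(v, \rho) \overset{\text{def}}{=} \int_0^1 \int_{\mathbb{R}^d} L(x, v(x,t)) \, \rho(x,t) \, dx \, dt + \int_0^1 F(\rho(\cdot,t)) \, dt + G(\rho(\cdot,1))
\]
subject to
\[
\partial_t \rho(x,t) + \nabla \cdot (\rho(x,t) v(x,t)) = 0, \quad \rho(x,0) = \rho_0(x).
\]

Use running costs F to model, e.g.,
◮ congestion: \( F_E(\rho) = \int_{\mathbb{R}^d} \rho(x) \log(\rho(x)) \, dx \)
◮ spatio-temporal preference: \( F_P(\rho) = \int_{\mathbb{R}^d} Q(x) \rho(x,t) \, dx \)

[Figure: space-time plot (axes: time, space)]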

SLIDE 34

More To Watch

Levon Nurbekyan @ IPAM Opening Workshop
Computational methods for mean-field games
https://bit.ly/3cELBmW

Samy Wu Fung @ Emory Scientific Computing Seminar
A GAN-based Approach for High-Dimensional Stochastic Mean Field Games
https://bit.ly/2TcqvVp

SLIDE 35

Continuous Normalizing Flows (CNF)

Likelihood Maximization

Given samples x1, x2, . . . , xN ∈ R^d, find a velocity v that maximizes the likelihood of the samples w.r.t. the push-forward of the standard normal distribution ρ1, i.e.,

\[
\max_{v,z} \frac{1}{N} \sum_{k=1}^{N} \rho_1(z(x_k, 1)) \cdot \det \nabla (z(x_k, 1))
\]
subject to
\[
\partial_t z(x_k, t) = v(z(x_k, t), t), \quad z(x_k, 0) = x_k \ \text{for all } k.
\]

[Figure: samples z(x1, 0), . . . , z(xN, 0) and the standard normal density ρ1]

Collaborators: D. Onken, S. Wu Fung, X. Li

W Grathwohl et al. FFJORD: Free-form Continuous Dynamics for Scalable Reversible Generative Models. arXiv, 2018.

SLIDE 38

Continuous Normalizing Flows (CNF)

Likelihood Maximization

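Given samples x1, x2, . . . , xN ∈ R^d, find a velocity v that maximizes the likelihood of the samples w.r.t. the push-forward of the standard normal distribution ρ1. For standard normal ρ1, the negative log-likelihood of one sample is −log(ρ1(z) det ∇z) = ½‖z‖² − l + const, so the problem reads

\[
\min_{v,z} G_{\rm CNF}(v, z) := \frac{1}{N} \sum_{k=1}^{N} \Big[ \tfrac{1}{2} \|z(x_k, 1)\|^2 - l(x_k, 1) \Big]
\]
subject to
\[
\partial_t \begin{pmatrix} z(x_k, t) \\ l(x_k, t) \end{pmatrix}
= \begin{pmatrix} v(z(x_k, t), t) \\ \mathrm{trace}(\nabla v(z(x_k, t), t)) \end{pmatrix},
\quad z(x_k, 0) = x_k \ \text{and} \ l(x_k, 0) = 0 \ \text{for all } k.
\]

Recall: l(xk, 1) = log det(∇z(xk, 1))

[Figure: samples z(x1, 0), . . . , z(xN, 0), standard normal density ρ1, mapped samples z(xN, 1), and estimated density ρ̂0]

W Grathwohl et al. FFJORD: Free-form Continuous Dynamics for Scalable Reversible Generative Models. arXiv, 2018.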
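A minimal sketch of evaluating the CNF objective on toy data. All specifics are assumptions for illustration: the data are one-dimensional samples from N(0, s²), the velocity is restricted to the linear family v(z) = −a·z (so trace(∇v) = −a), and a is picked by a coarse grid search instead of gradient-based training. For this family the objective is minimized near a = log s, where the flow maps the data to a standard normal.

```python
import math
import random

def flow(x, a, nt=8):
    """RK4 for dz/dt = v(z) = -a z and dl/dt = trace(grad v) = -a, from t = 0 to t = 1."""
    h = 1.0 / nt
    z, l = x, 0.0
    for _ in range(nt):
        k1 = -a * z
        k2 = -a * (z + 0.5 * h * k1)
        k3 = -a * (z + 0.5 * h * k2)
        k4 = -a * (z + h * k3)
        z += h / 6.0 * (k1 + 2 * k2 + 2 * k3 + k4)
        l += h * (-a)        # the trace is constant here, so all RK4 stages coincide
    return z, l

def g_cnf(samples, a):
    """Average of 0.5 * z(x_k, 1)^2 - l(x_k, 1); the constant 0.5*log(2*pi) is omitted."""
    total = 0.0
    for x in samples:
        z, l = flow(x, a)
        total += 0.5 * z * z - l
    return total / len(samples)

rng = random.Random(0)
s = 3.0
samples = [rng.gauss(0.0, s) for _ in range(5000)]
candidates = [0.1 * j for j in range(31)]             # a in {0.0, 0.1, ..., 3.0}
best_a = min(candidates, key=lambda a: g_cnf(samples, a))
print(best_a, math.log(s))
```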

SLIDE 41

OT-Regularized Continuous Normalizing Flow

OT-Flow: Regularized Continuous Normalizing Flow

Given samples x1, x2, . . . , xN ∈ R^d, find the value function Φ such that the flow given by v = −∇Φ maximizes the likelihood of the samples w.r.t. the standard normal distribution ρ1, i.e.,

\[
\min_{v,z} \frac{1}{N} \sum_{k=1}^{N} \Big[ \tfrac{1}{2} \|z(x_k, 1)\|^2 - l(x_k, 1) + \beta_1 c_L(x_k, 1) + \beta_2 c_H(x_k, 1) \Big]
\]
subject to
\[
\partial_t z(x_k, t) = v(z(x_k, t), t), \quad z(x_k, 0) = x_k \quad \forall k.
\]

◮ provides uniqueness
◮ more efficient time integration

[Figure: sample z(x1, 0), standard normal density ρ1, mapped sample z(x1, 1), and estimated density ρ̂0]

L Yang, GE Karniadakis. Potential Flow Generator with L2 OT Regularity for Generative Models. arXiv:1908.11462v1, 2019.
L Zhang, Weinan E, L Wang. Monge-Ampère Flow for Generative Modeling. arXiv:1809.10188v1, 2018.
C Finlay, JH Jacobsen, L Nurbekyan, AM Oberman. How to train your neural ODE. arXiv:2002.02798, 2020.

SLIDE 44

Trace Computation: Runtime and Accuracy

◮ Exact computation with automatic differentiation (AD): exact, O(m · d^2) FLOPS
\[
\mathrm{trace}(\nabla v(x)) = \sum_{i=1}^{d} e_i^\top (\nabla v(x)^\top e_i)
\]
◮ Trace estimator with AD: inexact, O(m · S · d) FLOPS
\[
\mathrm{trace}(\nabla v(x)) = \mathbb{E}_w \big[ w^\top (\nabla v(x)^\top w) \big] \approx \frac{1}{S} \sum_{k=1}^{S} (w^k)^\top (\nabla v(x)^\top w^k)
\]

OT-Flow: exact trace computation (highly parallel) using O(m^2 · d) FLOPS.
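The two trace formulas above can be compared on a toy problem. This sketch makes several assumptions that are not in the slides: the velocity is linear, v(x) = A·x, so that the vector-Jacobian product (∇v)ᵀw is simply Aᵀw and no AD library is needed; the probes w are Rademacher (±1) vectors; and the matrix A is made up. The last lines average over all 2^d sign vectors, which removes the sampling error and recovers the exact trace.

```python
import itertools
import random

def vjp(A, w):
    """Return (grad v)^T w for the linear field v(x) = A x, i.e. A^T w."""
    return [sum(A[i][j] * w[i] for i in range(len(A))) for j in range(len(A[0]))]

def exact_trace(A):
    """Sum of e_i^T (A^T e_i) over the d unit vectors: d vector-Jacobian products."""
    d = len(A)
    total = 0.0
    for i in range(d):
        e = [1.0 if j == i else 0.0 for j in range(d)]
        total += sum(ej * gj for ej, gj in zip(e, vjp(A, e)))
    return total

def hutchinson_trace(A, n_probes, rng):
    """Average of w^T (A^T w) over n_probes random Rademacher vectors w."""
    d = len(A)
    total = 0.0
    for _ in range(n_probes):
        w = [rng.choice((-1.0, 1.0)) for _ in range(d)]
        total += sum(wj * gj for wj, gj in zip(w, vjp(A, w)))
    return total / n_probes

A = [[2.0, 1.0, 0.0], [0.5, -1.0, 3.0], [4.0, 0.0, 0.5]]
exact = exact_trace(A)
estimate = hutchinson_trace(A, n_probes=10, rng=random.Random(0))

# Averaging over all 2^d sign vectors cancels the off-diagonal terms exactly.
full = sum(sum(wj * gj for wj, gj in zip(w, vjp(A, w)))
           for w in itertools.product((-1.0, 1.0), repeat=3)) / 8.0
print(exact, estimate, full)
```

With few probes the estimate fluctuates around the exact value because of the off-diagonal entries of A, which is the inexactness the slide refers to.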

SLIDE 45

OT-Flow: Two-Dimensional Examples

[Figure: samples and density estimates for four two-dimensional data sets: moons, circles, pinwheel, checkerboard]

SLIDE 46

OT-Flow vs. FFJORD, RNODE: UCI Datasets

[Figure: two log-log plots with axes network parameters vs. max mean discrepancy, and training time [hours] vs. testing time [sec]]

◮ OT-Flow yields competitive accuracy w.r.t. MMD
◮ FFJORD, RNODE: between 2× and 22× more weights
◮ OT-Flow considerably faster in training and testing.

SLIDE 47

OT-Flow Example: Generative Modeling MNIST

◮ let y1, y2, . . . ∈ R^784 be MNIST images
◮ train encoder E : R^784 → R^128 and decoder D : R^128 → R^784 s.t. D(E(y)) ≈ y
◮ latent space representation of data: xj = E(yj) for all j
◮ train OT-Flow f that maps {xj}j to ρ1 ∼ N(0, I128)
◮ interpolate between two images y1, y2 in latent space and get new image
\[
y(\lambda) = D\big( f^{-1}( \lambda f(E(y_1)) + (1 - \lambda) f(E(y_2)) ) \big)
\]

[Figure: grid of digit images with corners y1, y2, y3, y4; red boxed images are original, the others are interpolated in ρ1 space]
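The interpolation formula composes four maps in a fixed order, which the following sketch spells out. The encoder, decoder, and flow here are hypothetical affine stand-ins on short vectors, chosen only so that the code runs without trained models; with a trained nonlinear flow the interpolation path in image space is generally not a straight line.

```python
def encode(y):
    """Hypothetical stand-in for the encoder E."""
    return [2.0 * yi for yi in y]

def decode(x):
    """Hypothetical stand-in for the decoder D, with D(E(y)) = y."""
    return [0.5 * xi for xi in x]

def flow(x):
    """Hypothetical stand-in for the trained flow f (latent space -> rho_1 space)."""
    return [(xi - 1.0) / 4.0 for xi in x]

def flow_inverse(z):
    """Inverse of the stand-in flow, f^{-1}."""
    return [4.0 * zi + 1.0 for zi in z]

def interpolate(y1, y2, lam):
    """y(lam) = D(f^{-1}(lam * f(E(y1)) + (1 - lam) * f(E(y2))))."""
    z1, z2 = flow(encode(y1)), flow(encode(y2))
    z = [lam * a + (1.0 - lam) * b for a, b in zip(z1, z2)]
    return decode(flow_inverse(z))

y1, y2 = [0.0, 1.0], [2.0, 3.0]
path = [interpolate(y1, y2, lam) for lam in (1.0, 0.5, 0.0)]
print(path)
```

The endpoints λ = 1 and λ = 0 reproduce D(E(y1)) and D(E(y2)), so the quality of the original images in the interpolation grid depends on the autoencoder alone.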

SLIDE 48

OT-Flow - Fast Continuous Normalizing Flows in PyTorch

https://github.com/EmoryMLIP/OT-Flow

Julia implementation for more general MFGs: https://github.com/EmoryMLIP/MFGnet.jl

SLIDE 50

Σ: Machine Learning meets Optimal Transport

Machine Learning → Optimal Transport

◮ ML attractive for high-dimensional PDEs, control, . . .
◮ MFGnet: mesh-free solver for variational problem and combine. . .
  ◮ microscopic: Lagrangian method for continuity and HJB eqs.
  ◮ macroscopic: variational problem, new penalties for HJB eq.
◮ details matter: models, numerics, architecture, training, . . .
◮ surprise: ML solution competitive to convex programming

Optimal Transport → Continuous Normalizing Flows

◮ OT regularization: well-posed, simplifies time integration
◮ discretize-then-optimize + HJB penalty → very few time steps
◮ don't take chances: use exact trace computation
◮ OT-Flow speeds up training and testing by ≈ 10x

LR, S Osher, W Li, L Nurbekyan, S Wu Fung. A ML Framework for Solving High-Dimensional MFG and MFC. PNAS 117 (17), 9183-9193, 2020.
D Onken, S Wu Fung, X Li, LR. OT-Flow: Fast and Accurate CNF via OT. arXiv:2006.00104, 2020.
