SLIDE 1

Generative Models and Optimal Transport

Marco Cuturi

Joint work / work in progress with

  • G. Peyré, A. Genevay (ENS), F. Bach (INRIA),
  • G. Montavon, K.-R. Müller (TU Berlin)
SLIDE 2

Statistics 0.1: Density Fitting

We collect data and summarize it as the empirical measure

$\nu_{\text{data}} = \frac{1}{N}\sum_{i=1}^{N} \delta_{x_i}$

SLIDE 3

Statistics 0.1: Density Fitting

We fit a parametric family of densities $\{p_\theta,\ \theta \in \Theta\}$ to the collected data $\nu_{\text{data}} = \frac{1}{N}\sum_{i=1}^{N}\delta_{x_i}$, e.g. $\theta = (m, \Sigma)$ and $p_\theta = \mathcal{N}(m, \Sigma)$.

SLIDES 4-6

Density Fitting

Successive fits $p_{\theta_1}, p_{\theta_2}, \dots, p_{\theta_{\text{done}}}$ of $\nu_{\text{data}}$: we stop when there is a good fit.

SLIDE 7

Maximum Likelihood Estimation

$\max_{\theta \in \Theta}\ \frac{1}{N}\sum_{i=1}^{N} \log p_\theta(x_i)$

SLIDE 8

Maximum Likelihood Estimation

$\max_{\theta \in \Theta}\ \frac{1}{N}\sum_{i=1}^{N} \log p_\theta(x_i)$

Since $\log 0 = -\infty$, $p_\theta(x_i)$ must be $> 0$ at every data point.
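To make this concrete, here is a minimal NumPy/SciPy sketch (not from the talk; data and names are illustrative) of MLE for the Gaussian example $\theta = (m, \Sigma)$, where the maximizer has a closed form:

```python
import numpy as np
from scipy.stats import multivariate_normal

# Minimal sketch (illustrative data) of MLE for the Gaussian family
# p_theta = N(m, Sigma): the maximizer of the average log-likelihood is the
# sample mean and (biased) sample covariance, in closed form.
rng = np.random.default_rng(0)
x = rng.normal(loc=2.0, scale=1.5, size=(1000, 2))   # N = 1000 data points in R^2

m_hat = x.mean(axis=0)                               # argmax over m
Sigma_hat = np.cov(x, rowvar=False, bias=True)       # argmax over Sigma (1/N normalization)

# Average log-likelihood: finite because a Gaussian density is > 0 everywhere,
# so the log 0 = -inf issue never arises for this family.
avg_ll = multivariate_normal(mean=m_hat, cov=Sigma_hat).logpdf(x).mean()
print(m_hat, avg_ll)
```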

SLIDE 9

Maximum Likelihood Estimation

Equivalent to a KL projection onto $\{p_\theta,\ \theta \in \Theta\}$ in the space of probability measures:

$\min_{\theta \in \Theta}\ \mathrm{KL}(\nu_{\text{data}} \,\|\, p_\theta)$


SLIDES 11-12

In higher-dimensional spaces…

$\min_{\theta \in \Theta}\ \mathrm{KL}(\nu_{\text{data}} \,\|\, p_\theta)$

SLIDE 13

In higher-dimensional spaces…

The data space has dimension $100 \times 100 \times 256 \times 256 \times 256 \approx 167 \times 10^9$.

SLIDES 14-19

Generative Models

A generative model pushes a fixed measure $\mu$ on a latent space through a parameterized map

$f_\theta : \text{latent space} \to \text{data space}$

A latent code, e.g. $z = (.32,\ .8,\ .34,\ \dots,\ .01)^T$, is mapped to a point $f_\theta(z)$ in the data space.

SLIDES 20-21

Generative Models

The model distribution on the data space is the push-forward measure $f_{\theta\sharp}\mu$.

Push-forward: $\forall B \subset \Omega,\ f_\sharp\mu(B) := \mu(f^{-1}(B))$
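The push-forward is easy to sample from: draw $z \sim \mu$ and return $f_\theta(z)$. A minimal sketch, with a toy two-layer map standing in for a neural network $f_\theta$ (all names and sizes are illustrative):

```python
import numpy as np

# Minimal sketch of sampling from a push-forward measure f_theta#mu:
# draw z ~ mu on the latent space, then return f_theta(z).
rng = np.random.default_rng(0)

def f_theta(z, W1, b1, W2, b2):
    h = np.tanh(z @ W1 + b1)        # hidden layer
    return h @ W2 + b2              # point in the data space

d_latent, d_hidden, d_data = 4, 16, 2
W1 = rng.normal(size=(d_latent, d_hidden)); b1 = np.zeros(d_hidden)
W2 = rng.normal(size=(d_hidden, d_data));   b2 = np.zeros(d_data)

z = rng.normal(size=(1000, d_latent))   # z ~ mu = N(0, I) on the latent space
x = f_theta(z, W1, b1, W2, b2)          # 1000 samples from f_theta#mu
print(x.shape)                          # (1000, 2)
```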

SLIDES 22-23

Generative Models

Goal: find $\theta$ such that $f_{\theta\sharp}\mu$ fits $\nu_{\text{data}}$.

SLIDE 24

Generative Models

What is the difference between fitting a push-forward measure $f_{\theta\sharp}\mu$ and fitting a density $p_\theta$?

SLIDE 25

Generative Models

Recall that MLE

$\max_{\theta \in \Theta}\ \frac{1}{N}\sum_{i=1}^{N} \log p_\theta(x_i) \;=\; \min_{\theta \in \Theta}\ \mathrm{KL}(\nu_{\text{data}} \,\|\, p_\theta)$

SLIDES 26-27

Generative Models

Applying MLE to the push-forward model would read

$\max_{\theta \in \Theta}\ \frac{1}{N}\sum_{i=1}^{N} \log f_{\theta\sharp}\mu(x_i) \;=\; \min_{\theta \in \Theta}\ \mathrm{KL}(\nu_{\text{data}} \,\|\, f_{\theta\sharp}\mu),$

but $f_{\theta\sharp}\mu$ typically has no density on the data space (its support is low-dimensional), so these quantities are degenerate.

SLIDE 28

Generative Models

We need a more flexible discrepancy function to compare $\nu_{\text{data}}$ and $f_{\theta\sharp}\mu$.

SLIDE 29

Workarounds?

  • Formulation as an adversarial problem [GPM…’14]:

$\min_{\theta \in \Theta}\ \max_{\text{classifiers } g}\ \text{Accuracy}_g\big((f_{\theta\sharp}\mu, +1), (\nu_{\text{data}}, -1)\big)$

  • Use a richer metric for probability measures, able to handle measures with non-overlapping supports:

$\min_{\theta \in \Theta}\ \Delta(\nu_{\text{data}}, p_\theta), \quad \text{not} \quad \min_{\theta \in \Theta}\ \mathrm{KL}(\nu_{\text{data}} \,\|\, p_\theta)$


SLIDE 31

Minimum Kantorovich Estimation

  • Use optimal transport theory, namely Wasserstein distances, to define the discrepancy:

$\min_{\theta \in \Theta}\ W(\nu_{\text{data}}, f_{\theta\sharp}\mu)$

  • Optimal transport is a fertile field in mathematics: Monge, Kantorovich and Koopmans (Nobel ’75), Dantzig, Brenier, McCann, Otto, Villani (Fields ’10).

SLIDE 32

What is Optimal Transport?

A geometric toolbox to compare probability measures supported on a metric space.

[Figure: examples of such measures: empirical measures (i.e. data) $\mu, \nu$; color histograms $h_1, h_2$; bags of features; statistical models $p_\theta, p_{\theta'}$; brain activation maps.]


SLIDES 34-36

Optimal Transport Geometry

A geometric toolbox to compare probability measures supported on a metric space: in $\mathcal{P}(\Omega)$, OT provides the Wasserstein distance $W(p_\theta, p_{\theta'})$ and the interpolant between $p_\theta$ and $p_{\theta'}$ [McCann’95].

SLIDES 37-38

Optimal Transport Geometry

It also provides the Wasserstein barycenter [Agueh’11] of several measures $p_{\theta'}, p_{\theta''} \in \mathcal{P}(\Omega)$.

SLIDES 39-41

Optimal Transport Geometry

A geometric toolbox to compare probability measures supported on a metric space.

[Figure: illustrations from [SDPC..’15].]

SLIDES 42-52

Origins: Monge’s Problem

[Figure, built up across several slides: a point $x$ is transported to $y = T(x)$ at cost $D(x, T(x))$.]

SLIDE 53

Origins: Monge’s Problem

Let $\Omega$ be a probability space, $c : \Omega \times \Omega \to \mathbb{R}$, and $\mu, \nu$ two probability measures in $\mathcal{P}(\Omega)$.

[Monge’81] problem: find a map $T : \Omega \to \Omega$ achieving

$\inf_{T_\sharp\mu = \nu} \int c(x, T(x))\, \mu(dx)$

SLIDE 54

Origins: Monge’s Problem

[Brenier’87]: if $\Omega = \mathbb{R}^d$, $c(x, y) = \|x - y\|^2$, and $\mu, \nu$ are absolutely continuous, then the optimal map is $T = \nabla u$ with $u$ convex.

SLIDES 55-56

Monge’s Problem

$\inf_{T_\sharp\mu = \nu} \int c(x, T(x))\, \mu(dx)$

A Monge map need not exist: if $\mu = \delta_x$, then $T_\sharp\delta_x = \delta_{T(x)}$, so no map can split the mass of a Dirac and the constraint $T_\sharp\mu = \nu$ may be infeasible.

SLIDE 57

[Kantorovich’42] Relaxation

  • Instead of maps $T : \Omega \to \Omega$, consider probabilistic maps, i.e. couplings $P \in \mathcal{P}(\Omega \times \Omega)$:

$\Pi(\mu, \nu) \stackrel{\text{def}}{=} \{P \in \mathcal{P}(\Omega \times \Omega) \mid \forall A, B \subset \Omega,\ P(A \times \Omega) = \mu(A),\ P(\Omega \times B) = \nu(B)\}$

SLIDES 58-59

[Kantorovich’42] Relaxation

$\Pi(\mu, \nu) \stackrel{\text{def}}{=} \{P \in \mathcal{P}(\Omega \times \Omega) \mid \forall A, B \subset \Omega,\ P(A \times \Omega) = \mu(A),\ P(\Omega \times B) = \nu(B)\}$

[Figure: two different discrete couplings $P(x, y)$ with the same marginals $\mu(x)$, $\nu(y)$.]

SLIDES 60-62

Wasserstein Distances

  • Def. For $p \geq 1$, the $p$-Wasserstein distance between $\mu, \nu$ in $\mathcal{P}(\Omega)$, defined by a metric $D$ on $\Omega$:

PRIMAL

$W_p^p(\mu, \nu) \stackrel{\text{def}}{=} \inf_{P \in \Pi(\mu,\nu)} \iint D(x, y)^p\, P(dx, dy)$

DUAL

$W_p^p(\mu, \nu) = \sup_{\substack{\varphi \in L_1(\mu),\, \psi \in L_1(\nu) \\ \varphi(x) + \psi(y) \leq D^p(x,y)}} \int \varphi\, d\mu + \int \psi\, d\nu$

SLIDES 63-64

W is versatile

It applies to discrete-discrete, discrete-continuous, and continuous-continuous pairs of measures. Available tools:

  • Network flow solvers
  • Entropic regularization
  • Stochastic optimization [GCPB’16]
  • Low-dim. solvers [M’11][KMB’16][L’15]

SLIDE 65

Minimum Kantorovich Estimators

$\min_{\theta \in \Theta}\ W(\nu_{\text{data}}, f_{\theta\sharp}\mu)$

  • [Bassetti’06]: first reference discussing this approach.
  • [MMC’16]: uses regularization in a finite setting.
  • [ACB’17] (WGAN), [BJGR’17] (Wasserstein ABC).
  • Hot topic: approximating & differentiating W efficiently.
  • Today: ideas from our recent preprint [GPC’17].
SLIDE 66

Wasserstein between 2 Diracs

On $(\Omega, D)$:

$W_p^p(\delta_x, \delta_y) = D(x, y)^p$

SLIDES 67-68

Wasserstein on Uniform Measures

Take $\mu = \sum_{i=1}^n \frac{1}{n}\delta_{x_i}$ and $\nu = \sum_{j=1}^n \frac{1}{n}\delta_{y_j}$ on $(\Omega, D)$. A permutation $\sigma \in S_n$ has cost

$C(\sigma) = \frac{1}{n}\sum_{i=1}^n D(x_i, y_{\sigma_i})^p$

SLIDE 69

Optimal Assignment ⊂ Wasserstein

For two uniform measures on $n$ points each,

$W_p^p(\mu, \nu) = \min_{\sigma \in S_n} C(\sigma)$

(sketched in code below).
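A minimal sketch of this reduction using SciPy's Hungarian solver (`linear_sum_assignment`); points and parameters are illustrative:

```python
import numpy as np
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist

# Minimal sketch: W_p^p between two uniform empirical measures on n points
# reduces to an optimal assignment, solved here with the Hungarian algorithm.
rng = np.random.default_rng(0)
n, p = 50, 2
x = rng.normal(size=(n, 2))
y = rng.normal(loc=1.0, size=(n, 2))

M = cdist(x, y) ** p                      # M_ij = D(x_i, y_j)^p
row, col = linear_sum_assignment(M)       # optimal permutation sigma
W_pp = M[row, col].mean()                 # (1/n) sum_i D(x_i, y_sigma(i))^p
print(W_pp)
```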

SLIDES 70-71

OT on Two Empirical Measures

More generally, consider weighted empirical measures on $(\Omega, D)$:

$\mu = \sum_{i=1}^n a_i \delta_{x_i} \quad \text{and} \quad \nu = \sum_{j=1}^m b_j \delta_{y_j}$

SLIDES 72-74

Wasserstein on Empirical Measures

Consider $\mu = \sum_{i=1}^n a_i \delta_{x_i}$ and $\nu = \sum_{j=1}^m b_j \delta_{y_j}$, and define the transportation polytope and cost matrix

$U(a, b) \stackrel{\text{def}}{=} \{P \in \mathbb{R}_+^{n \times m} \mid P\mathbf{1}_m = a,\ P^T\mathbf{1}_n = b\}, \qquad M_{XY} \stackrel{\text{def}}{=} [D(x_i, y_j)^p]_{ij}.$

  • Def. Optimal Transport Problem (a linear program; a sketch follows):

$W_p^p(\mu, \nu) = \min_{P \in U(a,b)} \langle P, M_{XY}\rangle$
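A minimal sketch of the discrete OT problem as a linear program over the flattened coupling, using `scipy.optimize.linprog` (toy sizes; a dedicated OT solver would be used in practice):

```python
import numpy as np
from scipy.optimize import linprog
from scipy.spatial.distance import cdist

# Minimal sketch: min_{P in U(a,b)} <P, M_XY> as an LP over vec(P).
rng = np.random.default_rng(0)
n, m, p = 5, 7, 2
x, y = rng.normal(size=(n, 2)), rng.normal(size=(m, 2))
a = np.full(n, 1 / n)
b = np.full(m, 1 / m)
M = cdist(x, y) ** p

# Equality constraints: P 1_m = a (row sums) and P^T 1_n = b (column sums).
A_rows = np.kron(np.eye(n), np.ones(m))   # picks out row sums of vec(P)
A_cols = np.kron(np.ones(n), np.eye(m))   # picks out column sums of vec(P)
A_eq = np.vstack([A_rows, A_cols])
b_eq = np.concatenate([a, b])

res = linprog(M.ravel(), A_eq=A_eq, b_eq=b_eq, bounds=(0, None))
print(res.fun)                            # W_p^p(mu, nu)
```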

SLIDES 75-76

Discrete OT Problem

[Figure: the polytope $U(a, b)$, the cost matrix $M_{XY}$, and an optimal solution $P^\star$.]

SLIDE 77

Discrete OT Problem

  • Def. Dual OT problem

$W_p^p(\mu, \nu) = \max_{\substack{\alpha \in \mathbb{R}^n,\, \beta \in \mathbb{R}^m \\ \alpha_i + \beta_j \leq D(x_i, y_j)^p}} \alpha^T a + \beta^T b$

SLIDES 78-84

Discrete OT Problem

An $O(n^3 \log n)$ network flow solver is used in practice.

Note: flow/PDE formulations [Beckman’61]/[Benamou’98] can be used for $p = 1$/$p = 2$ with a sparse-graph metric/Euclidean metric, respectively.

However, the optimal solution $P^\star$ is unstable and not always unique, and $W_p^p(\mu, \nu)$ is not differentiable.


SLIDE 88

Solution: Modify the OT Problem

Wishlist: faster & scalable, more stable, differentiable.

SLIDES 89-90

Entropic Regularization [Wilson’62]

$E(P) \stackrel{\text{def}}{=} -\sum_{i,j=1}^{n,m} P_{ij} \log P_{ij}$

  • Def. Regularized Wasserstein, $\gamma \geq 0$:

$W_\gamma(\mu, \nu) \stackrel{\text{def}}{=} \min_{P \in U(a,b)} \langle P, M_{XY}\rangle - \gamma E(P)$

Note: the optimal solution $P_\gamma$ is unique because the entropy is strongly concave.

SLIDES 91-93

Fast & Scalable Algorithm

  • Prop. If $P_\gamma \stackrel{\text{def}}{=} \operatorname{argmin}_{P \in U(a,b)} \langle P, M_{XY}\rangle - \gamma E(P)$, then $\exists!\ u \in \mathbb{R}^n_+,\ v \in \mathbb{R}^m_+$ such that

$P_\gamma = \operatorname{diag}(u)\, K \operatorname{diag}(v), \qquad K \stackrel{\text{def}}{=} e^{-M_{XY}/\gamma}$

Sketch: with multipliers $\alpha, \beta$ for the marginal constraints,

$L(P, \alpha, \beta) = \sum_{ij} P_{ij} M_{ij} + \gamma P_{ij}\log P_{ij} + \alpha^T(P\mathbf{1} - a) + \beta^T(P^T\mathbf{1} - b),$

$\partial L/\partial P_{ij} = M_{ij} + \gamma(\log P_{ij} + 1) + \alpha_i + \beta_j, \quad (\partial L/\partial P_{ij} = 0) \Rightarrow P_{ij} = e^{\frac{\alpha_i}{\gamma} + \frac{1}{2}}\, e^{-\frac{M_{ij}}{\gamma}}\, e^{\frac{\beta_j}{\gamma} + \frac{1}{2}} = u_i K_{ij} v_j.$

  • [Sinkhorn’64] fixed-point iterations for $(u, v)$: $u \leftarrow a / (Kv),\ v \leftarrow b / (K^T u)$ (elementwise divisions); a NumPy sketch follows this list.
  • $O(nm)$ complexity per iteration, GPGPU-parallel [C’13].
  • $O(n^{d+1})$ per iteration if $\Omega = \{1, \dots, n\}^d$ and $D^p$ is separable [S..C..’15].
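A minimal NumPy sketch of the iterations (toy data; γ and the iteration count are illustrative, and a practical implementation would monitor marginal violations and work in the log domain for small γ):

```python
import numpy as np
from scipy.spatial.distance import cdist

# Minimal sketch of Sinkhorn's fixed-point iterations.
# K = exp(-M_XY/gamma); u <- a/(K v), v <- b/(K^T u), elementwise.
rng = np.random.default_rng(0)
n, m, p, gamma = 200, 300, 2, 0.1
x = rng.normal(size=(n, 2))
y = rng.normal(loc=1.0, size=(m, 2))
a = np.full(n, 1 / n)
b = np.full(m, 1 / m)

M = cdist(x, y) ** p
K = np.exp(-M / gamma)

v = np.ones(m)
for _ in range(500):          # fixed iteration budget; could test marginals instead
    u = a / (K @ v)
    v = b / (K.T @ u)

P_gamma = u[:, None] * K * v[None, :]    # P_gamma = diag(u) K diag(v)
print((P_gamma * M).sum())               # transport cost <P_gamma, M_XY>
```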

SLIDE 94

Very Fast EMD Approx. Solver

[Figure: average execution time per distance (in s.) vs. histogram dimension (64 to 4096), comparing FastEMD, Rubner’s emd, and Sinkhorn on CPU/GPU with γ = 0.02 and γ = 0.1.]

  • Note. $(\Omega, D)$ is a random graph with the shortest-path metric; histograms are sampled uniformly on the simplex; Sinkhorn tolerance is $10^{-2}$.

SLIDES 95-99

Regularization ⤑ Differentiability

With $\mu = \sum_{i=1}^n a_i\delta_{x_i}$ and $\nu = \sum_{j=1}^m b_j\delta_{y_j}$ on $(\Omega, D)$, write

$W_\gamma((a, X), (b, Y)) = \min_{P \in U(a,b)} \langle P, M_{XY}\rangle - \gamma E(P).$

How does the value respond to perturbations $a \leftarrow a + \Delta a$ of the weights, and $X \leftarrow X + \Delta X$ of the support?

$W_\gamma((a + \Delta a, X), (b, Y)) = W_\gamma((a, X), (b, Y)) +\ ??$

$W_\gamma((a, X + \Delta X), (b, Y)) = W_\gamma((a, X), (b, Y)) +\ ??$

SLIDE 100

1. Differentiability of Regularized OT

  • Def. Dual regularized OT problem [CD’14]:

$W_\gamma(\mu, \nu) = \max_{\alpha, \beta}\ \alpha^T a + \beta^T b - \gamma\, (e^{\alpha/\gamma})^T K\, e^{\beta/\gamma}$

  • Prop. $W_\gamma(\mu, \nu)$ is
  • 1. convex w.r.t. $a$, with $\nabla_a W = \alpha^\star = \gamma \log(u)$ (illustrated below);
  • 2. decreased, when $p = 2$ and $\Omega = \mathbb{R}^d$, by the update $X \leftarrow Y P^T \operatorname{diag}(a^{-1})$.
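A minimal sketch of item 1: run Sinkhorn, then read off $\nabla_a W_\gamma = \gamma \log u$, defined up to an additive constant since $a$ lives on the simplex (toy data, illustrative parameters):

```python
import numpy as np
from scipy.spatial.distance import cdist

# Minimal sketch of the gradient formula from [CD'14]: after Sinkhorn,
# grad_a W_gamma = alpha* = gamma * log(u), up to an additive constant.
rng = np.random.default_rng(1)
n, m, gamma = 40, 50, 0.2
x, y = rng.normal(size=(n, 2)), rng.normal(size=(m, 2))
a, b = np.full(n, 1 / n), np.full(m, 1 / m)
K = np.exp(-cdist(x, y) ** 2 / gamma)

v = np.ones(m)
for _ in range(1000):
    u = a / (K @ v)
    v = b / (K.T @ u)

grad_a = gamma * np.log(u)
grad_a -= grad_a.mean()     # fix the additive constant on the simplex
print(grad_a[:5])
```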

SLIDE 101

2. Duality for Discrete Regularized OT

  • Prop. [CP’16] Writing $H_\nu : a \mapsto W_\gamma(\mu, \nu)$:
  • 1. $H_\nu$ has a simple Legendre transform:

$H_\nu^* : g \in \mathbb{R}^n \mapsto \gamma\left(E(b) + b^T \log(K^T e^{g/\gamma})\right)$

  • 2. If $A \in \mathbb{R}^{n \times d}$ and $f$ is convex on $\mathbb{R}^d$:

$\min_{a \in \Sigma_n} H_\nu(a) + f(Aa) \;=\; \max_{g \in \mathbb{R}^d} -H_\nu^*(A^T g) - f^*(-g)$
SLIDES 102-103

3. Stochastic Formulation [GCPB’16]

DUAL

$W_p^p(\mu, \nu) = \sup_{\varphi, \psi} \int \varphi\, d\mu + \int \psi\, d\nu - \iota_C(\varphi, \psi), \qquad C = \{(\varphi, \psi) \mid \varphi \oplus \psi \leq D^p\}$

REGULARIZED DUAL: for $\gamma > 0$, regularize the dual constraints with a smooth penalty,

$W_\gamma(\mu, \nu) = \sup_{\varphi, \psi} \int \varphi\, d\mu + \int \psi\, d\nu - \iota_C^\gamma(\varphi, \psi), \qquad \iota_C^\gamma(\varphi, \psi) = \gamma \iint e^{(\varphi \oplus \psi - D^p)/\gamma}\, d\mu\, d\nu$

SLIDE 104

Smoothed D-transforms

SEMI-DUAL

$W_p^p(\mu, \nu) = \sup_{\varphi} \int \varphi\, d\mu + \int \varphi^D\, d\nu$

REGULARIZED SEMI-DUAL ($\gamma > 0$; a small sketch of the smoothed transform follows)

$W_\gamma(\mu, \nu) = \sup_{\varphi} \int \varphi\, d\mu + \int \varphi^{D,\gamma}\, d\nu, \qquad \varphi^{D,\gamma} = -\gamma \log \int e^{\frac{\varphi(x) - D(x,\cdot)^p}{\gamma}}\, d\mu(x)$
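A minimal sketch of the smoothed transform when $\mu$ is discrete, where the integral becomes a weighted log-sum-exp (illustrative names and data):

```python
import numpy as np
from scipy.special import logsumexp
from scipy.spatial.distance import cdist

# Minimal sketch of the smoothed D-transform for mu = sum_i a_i delta_{x_i}:
# phi^{D,gamma}(y) = -gamma * log sum_i a_i exp((phi_i - D(x_i, y)^p)/gamma).
rng = np.random.default_rng(0)
n, gamma, p = 30, 0.1, 2
x = rng.normal(size=(n, 2))
a = np.full(n, 1 / n)
phi = rng.normal(size=n)          # dual potential, one value per support point of mu

def smoothed_D_transform(y):
    Dp = cdist(x, np.atleast_2d(y)).ravel() ** p
    return -gamma * logsumexp((phi - Dp) / gamma, b=a)   # stable log-sum-exp

print(smoothed_D_transform(np.array([0.5, -0.3])))
```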

SLIDE 105

Regularized Semi-dual Wasserstein

Substituting $\varphi^{D,\gamma}$ into the regularized semi-dual:

$W_\gamma(\mu, \nu) = \sup_{\varphi} \int_y \left[\int_x \varphi(x)\, d\mu(x) - \gamma \log \int_x e^{\frac{\varphi(x) - D(x,y)^p}{\gamma}}\, d\mu(x)\right] d\nu(y).$
SLIDES 106-108

Stochastic Regularized Semi-dual

What if $\mu$ is a discrete measure, $\mu = \sum_{i=1}^n a_i \delta_{x_i}$? Then $\varphi \in L_1(\mu)$ is just a vector $\alpha \in \mathbb{R}^n$:

$\sup_{\alpha \in \mathbb{R}^n} \int_y \left[\sum_{i=1}^n \alpha_i a_i - \gamma \log \sum_{i=1}^n e^{\frac{\alpha_i - D(x_i, y)^p}{\gamma}} a_i\right] d\nu(y) \;=\; \sup_{\alpha \in \mathbb{R}^n} \mathbb{E}_\nu[f(\alpha, y)]$

STOCHASTIC REGULARIZED SEMI-DUAL: an expectation over $y \sim \nu$, so it can be maximized by stochastic gradient ascent on samples (sketched below).
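A minimal sketch of stochastic gradient ascent on $\mathbb{E}_\nu[f(\alpha, y)]$, drawing one $y \sim \nu$ per step, in the spirit of [GCPB’16]; the sampler, step size, and data are illustrative:

```python
import numpy as np
from scipy.special import logsumexp
from scipy.spatial.distance import cdist

# Minimal sketch of SGA on the semi-dual E_nu[f(alpha, y)].
rng = np.random.default_rng(0)
n, gamma, p = 30, 0.1, 2
x = rng.normal(size=(n, 2))
a = np.full(n, 1 / n)

def sample_nu():                        # stand-in for sampling the (possibly
    return rng.normal(loc=1.0, size=2)  # continuous) measure nu

alpha = np.zeros(n)
for t in range(1, 5001):
    y = sample_nu()
    Dp = cdist(x, y[None, :]).ravel() ** p
    # f(alpha, y) = <alpha, a> - gamma * log sum_i a_i e^{(alpha_i - Dp_i)/gamma},
    # whose gradient is a minus a softmax-weighted vector chi:
    logits = (alpha - Dp) / gamma + np.log(a)
    chi = np.exp(logits - logsumexp(logits))
    alpha += (1.0 / np.sqrt(t)) * (a - chi)   # SGA step with decaying rate
print(alpha[:5])
```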
SLIDE 109

4. Sinkhorn Divergence

  • Def. For $\gamma > 0$, let $W_\gamma(\mu, \nu) \stackrel{\text{def}}{=} \langle P_\gamma, M_{XY}\rangle$.
  • Prop. $W_\gamma(\mu, \mu) > 0$, hence the need to normalize.
  • Def. Normalized Sinkhorn Divergence (sketched in code below):

$\bar{W}_\gamma(\mu, \nu) \stackrel{\text{def}}{=} W_\gamma(\mu, \nu) - \tfrac{1}{2}\left(W_\gamma(\mu, \mu) + W_\gamma(\nu, \nu)\right)$

  • Prop. If $p = 1$, $\bar{W}_\gamma(\mu, \nu) \to \mathrm{ED}(\mu, \nu)$ (the energy distance) as $\gamma \to \infty$.
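A minimal sketch of the normalized divergence, reusing a plain Sinkhorn solver for the three terms (illustrative data and parameters):

```python
import numpy as np
from scipy.spatial.distance import cdist

# Minimal sketch: W̄_gamma(mu, nu) = W_gamma(mu, nu)
#                 - (W_gamma(mu, mu) + W_gamma(nu, nu)) / 2,
# where W_gamma(mu, nu) = <P_gamma, M_XY> is the sharp transport cost.
def sinkhorn_cost(x, a, y, b, gamma=0.1, p=2, iters=1000):
    M = cdist(x, y) ** p
    K = np.exp(-M / gamma)
    v = np.ones(len(b))
    for _ in range(iters):
        u = a / (K @ v)
        v = b / (K.T @ u)
    return ((u[:, None] * K * v[None, :]) * M).sum()

def sinkhorn_divergence(x, a, y, b, **kw):
    return (sinkhorn_cost(x, a, y, b, **kw)
            - 0.5 * (sinkhorn_cost(x, a, x, a, **kw)
                     + sinkhorn_cost(y, b, y, b, **kw)))

rng = np.random.default_rng(0)
x, y = rng.normal(size=(50, 2)), rng.normal(loc=1.0, size=(60, 2))
a, b = np.full(50, 1 / 50), np.full(60, 1 / 60)
print(sinkhorn_divergence(x, a, y, b))
```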

SLIDE 110

Algorithmic Formulation

  • Def. For $L \geq 1$, define

$W_L(\mu, \nu) \stackrel{\text{def}}{=} \langle P_L, M_{XY}\rangle, \quad \text{where } P_L \stackrel{\text{def}}{=} \operatorname{diag}(u_L) K \operatorname{diag}(v_L),$

$v_0 = \mathbf{1}_m; \quad \text{for } l \geq 0,\ u_l \stackrel{\text{def}}{=} a / (K v_l),\ v_{l+1} \stackrel{\text{def}}{=} b / (K^T u_l).$

  • Prop. $\frac{\partial W_L}{\partial X}$ and $\frac{\partial W_L}{\partial a}$ can be computed recursively, in $O(L)$ kernel $K$ × vector products.

SLIDES 111-112

Algorithmic Formulation of Reg. OT

Example: differentiability w.r.t. $a$. The transposed Jacobian-vector products obey the recursions

$\left(\frac{\partial v_0}{\partial a}\right)^{T} = 0_{m \times n}, \qquad \left(\frac{\partial u_l}{\partial a}\right)^{T} x = \frac{x}{K v_l} - \left(\frac{\partial v_l}{\partial a}\right)^{T} K^T\!\left(\frac{x \odot a}{(K v_l)^2}\right), \qquad \left(\frac{\partial v_{l+1}}{\partial a}\right)^{T} y = -\left(\frac{\partial u_l}{\partial a}\right)^{T} K\left(\frac{y \odot b}{(K^T u_l)^2}\right).$

With $N = K \odot M_{XY}$,

$\nabla_a W_L(\mu, \nu) = \left(\frac{\partial u_L}{\partial a}\right)^{T} N v_L + \left(\frac{\partial v_L}{\partial a}\right)^{T} N^T u_L.$

SLIDE 113

Wasserstein Barycenters

Wasserstein barycenter [Agueh’11] of $\nu_1, \dots, \nu_N \in \mathcal{P}(\Omega)$:

$\min_{\mu \in \mathcal{P}(\Omega)}\ \sum_{i=1}^N \lambda_i W_p^p(\mu, \nu_i)$

SLIDES 114-115

Multimarginal Formulation

  • Exact solution ($W_2$) using multimarginal OT [Agueh’11].

[Figure: point clouds and their exact $W_2$ barycenter.]

If $|\operatorname{supp}\nu_i| = n_i$, this is an LP of size $\left(\prod_i n_i,\ \sum_i n_i\right)$.

SLIDES 116-117

Finite Case, LP Formulation

  • When $\Omega$ is a finite set with metric matrix $M$, $\min_\mu \sum_i \lambda_i W_p^p(\mu, \nu_i)$ is another LP:

$\min_{P_1, \dots, P_N, a}\ \sum_{i=1}^N \lambda_i \langle P_i, M\rangle \quad \text{s.t.}\quad P_i^T \mathbf{1}_n = b_i\ \ \forall i \leq N, \qquad P_1\mathbf{1}_n = \dots = P_N\mathbf{1}_n = a.$

If $|\Omega| = n$, this LP has size $(Nn^2, (2N-1)n)$; unstable.

SLIDES 118-120

Primal Descent on Regularized W

$\min_{\mu \in Q \subset \mathcal{P}(\Omega)}\ \sum_{i=1}^N \lambda_i W_\gamma(\mu, \nu_i)$

[CD’14] Fast Computation of Wasserstein Barycenters, International Conference on Machine Learning, 2014.

SLIDES 121-123

Primal Descent on Algorithmic W

$\min_{\mu \in Q \subset \mathcal{P}(\Omega)}\ \sum_{i=1}^N \lambda_i W_L(\mu, \nu_i)$

Not a convex problem.

SLIDE 124

Inverse Wasserstein Problems

  • Consider the barycenter operator

$b(\lambda) \stackrel{\text{def}}{=} \operatorname{argmin}_{a} \sum_{i=1}^N \lambda_i W_\gamma(a, b_i),$

  • and address now Wasserstein inverse problems: given $a$, find

$\operatorname{argmin}_{\lambda \in \Sigma_N}\ E(\lambda) \stackrel{\text{def}}{=} \mathrm{Loss}(a, b(\lambda)).$

SLIDE 125

The Wasserstein Simplex

[Figure.]

SLIDE 126

Barycenters = Fixed Points

  • Prop. [BCCNP’15] Consider $B \in (\Sigma_d)^N$ (one histogram per column), let $U_0 = \mathbf{1}_{d \times N}$, and for $l \geq 0$:

$b_l \stackrel{\text{def}}{=} \exp\left(\log(K^T U_l)\,\lambda\right); \qquad V_{l+1} \stackrel{\text{def}}{=} \frac{b_l \mathbf{1}_N^T}{K^T U_l}, \quad U_{l+1} \stackrel{\text{def}}{=} \frac{B}{K V_{l+1}}$

(divisions elementwise). The iterates $b_l$ converge to the regularized barycenter $b(\lambda)$; a NumPy sketch follows.

SLIDE 127

Using Truncated Barycenters

  • Instead of using the exact barycenter,

$\operatorname{argmin}_{\lambda \in \Sigma_N}\ E(\lambda) \stackrel{\text{def}}{=} \mathrm{Loss}(a, b(\lambda)),$

  • use the $L$-iterate barycenter,

$\operatorname{argmin}_{\lambda \in \Sigma_N}\ E^{(L)}(\lambda) \stackrel{\text{def}}{=} \mathrm{Loss}(a, b^{(L)}(\lambda)),$

  • and differentiate using the chain rule:

$\nabla E^{(L)}(\lambda) = [\partial b^{(L)}]^T(g), \qquad g \stackrel{\text{def}}{=} \nabla \mathrm{Loss}(a, \cdot)\big|_{b^{(L)}(\lambda)}.$

SLIDE 128

Gradient / Barycenter Computation

[Figure.]

SLIDE 129

Application: Volume Reconstruction

[BPC’16] Wasserstein Barycentric Coordinates: Histogram Regression Using Optimal Transport, SIGGRAPH’16.

SLIDES 130-133

Application: Color Grading

[BPC’16] Wasserstein Barycentric Coordinates: Histogram Regression Using Optimal Transport, SIGGRAPH’16.

SLIDES 134-135

Application: Brain Mapping

[Figure: original activation map vs. its Euclidean projection and its Wasserstein projection.]

SLIDE 136

At Last: Application to Generative Models [GPC’17]

[Figure: computational graph of the Sinkhorn generative model: latent samples $(z_1, \dots, z_m)$ are mapped through the generator (parameters $\theta_1, \theta_2, \dots$) to $(x_1, \dots, x_m)$ and compared to input data $(y_1, \dots, y_n)$ via the cost matrix $C = (c(x_i, y_j))_{i,j}$ and kernel $K = e^{-C/\varepsilon}$; $L$ unrolled Sinkhorn iterations ($\ell = 1, \dots, L-1$) produce scalings $a_\ell, b_\ell$ and the loss $\hat{E}_L(\theta) = \langle (C \odot K) b_L, a_L \rangle$.]

Approximate the W loss by the transport cost $\bar{W}_L$ after $L$ Sinkhorn iterations.
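A minimal PyTorch sketch of this training loop: unroll $L$ Sinkhorn iterations on a minibatch cost matrix and backpropagate through them (the architecture, ε, and all hyper-parameters are illustrative assumptions, not those of [GPC’17]):

```python
import torch

# Minimal sketch (illustrative architecture and hyper-parameters): fit a
# generator f_theta by minimizing the transport cost after L unrolled Sinkhorn
# iterations; autodiff differentiates through the loop, as in the algorithmic
# formulation above.
torch.manual_seed(0)
d_latent, d_data, L, eps = 4, 2, 20, 0.5

f_theta = torch.nn.Sequential(
    torch.nn.Linear(d_latent, 64), torch.nn.Tanh(), torch.nn.Linear(64, d_data))
opt = torch.optim.Adam(f_theta.parameters(), lr=1e-3)

y = torch.randn(256, d_data) + 2.0          # stand-in for the input data
b = torch.full((y.shape[0],), 1.0 / y.shape[0])

for step in range(1000):
    z = torch.randn(128, d_latent)          # z ~ mu on the latent space
    x = f_theta(z)                          # minibatch sampled from f_theta#mu
    C = torch.cdist(x, y) ** 2              # cost matrix (c(x_i, y_j))_ij
    K = torch.exp(-C / eps)
    a = torch.full((x.shape[0],), 1.0 / x.shape[0])
    v = torch.ones_like(b)
    for _ in range(L):                      # L unrolled Sinkhorn iterations
        u = a / (K @ v)
        v = b / (K.T @ u)
    loss = ((u[:, None] * K * v[None, :]) * C).sum()   # <P_L, C>
    opt.zero_grad()
    loss.backward()
    opt.step()
```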

SLIDE 137

Example: MNIST, Learning $f_\theta$

SLIDE 138

Example: Generation of Images

[Figure: CIFAR-10 samples from MMD-GAN, γ = 1000, and γ = 10.]

  • CIFAR-10 images.
  • In these examples the cost function is also learned adversarially, as a NN mapping onto feature vectors.

SLIDE 139

Concluding Remarks

  • Regularized OT is much faster than OT.
  • Regularized OT can interpolate between W and the MMD / energy distance metrics.
  • The solution of regularized OT is “auto-differentiable”.
  • Many open problems remain!

NIPS’17 WORKSHOP · NIPS’17 TUTORIAL