OSL 2015: The Wasserstein Barycenter Problem. Marco Cuturi, 13.1.15. (PowerPoint PPT presentation)
SLIDE 1

OSL 2015 The Wasserstein Barycenter Problem

Marco Cuturi (mcuturi@i.kyoto-u.ac.jp). Joint work with G. Peyré, G. Carlier, J.D. Benamou, L. Nenna, A. Gramfort, J. Solomon, ...
SLIDE 2

Motivation

[Figure: scatter plot of four points in the plane]

Four points in $\mathbb{R}^2$: $x_1, x_2, x_3, x_4$.
SLIDE 3

Mean

[Figure: the four points and their mean]

Their mean is $(x_1 + x_2 + x_3 + x_4)/4$.
SLIDE 4

Computing Means

Consider, for each point, the function $\|\cdot - x_i\|_2^2$.
SLIDE 5

Computing Means

The mean is the argmin of $\frac{1}{4} \sum_{i=1}^{4} \|\cdot - x_i\|_2^2$.
SLIDE 6

Means in Metric Spaces

Means can be defined using any distance/divergence/discrepancy.
SLIDE 7

Means in Metric Spaces

Using e.g. geodesic distances. Here $\Delta(\bullet, \bullet) = 0.994$.
SLIDE 8

Means in Metric Spaces

Consider the distance functions $\Delta(\cdot, x_i)$, $i = 1, 2, 3, 4$.
SLIDE 9

Means in Metric Spaces

The mean is the argmin of $\frac{1}{N} \sum_{i=1}^{N} \Delta(\cdot, x_i)$.
SLIDE 10

From points

SLIDE 11

to Probability Measures

Assume that each datum is now an empirical measure. What could be the mean of these 4 measures?
SLIDE 12
1. Naive Averaging

[Figure: the naive mean of all observations, marked on the point cloud]

Mean of 4 measures = a point?
SLIDE 13

Averaging Probability Measures

Same measures, in a 3D perspective.
SLIDE 14
2. Naive Averaging

The Euclidean mean of measures is their sum divided by $N$.

Here, $\Delta(\mu, \nu) = \int_{\mathbb{R}^2} [d\mu - d\nu]^2$.
SLIDE 15

Focus on uncertainty

...but geometric knowledge ignored.
SLIDE 16

Focus on geometry

...but uncertainty is lost.
SLIDE 17

Problem of interest

Given a discrepancy function $\Delta$ between probabilities, compute their mean: $\operatorname{argmin} \sum_i \Delta(\cdot, \nu_i)$.

  • The idea is useful, sometimes tractable, and appears in:
  • Bregman clustering for histograms [Banerjee'05],
  • topic modeling [Blei et al.'03],
  • clustering problems (k-means).
  • Our goal in this talk: study the case $\Delta$ = Wasserstein.
SLIDE 18

Wasserstein Distances

SLIDE 19

Comparing Two Measures

Two measures $\mu, \nu \in \mathcal{P}(\Omega)$.
SLIDE 20

The Optimal Transport Approach

Optimal Transport distances rely on 2 key concepts:

  • a metric $D : \Omega \times \Omega \to \mathbb{R}_+$;
  • $\Pi(\mu, \nu)$: joint probabilities with marginals $\mu, \nu$.
SLIDE 21

Joint Probabilities of (µ, ν)

Consider $\mu, \nu$ two measures on the real line.
SLIDE 22

Joint Probabilities of (µ, ν)

$\Pi(\mu, \nu)$ = probability measures on $\Omega^2$ with marginals $\mu$ and $\nu$.
SLIDE 24

Optimal Transport Distance

The $p$-Wasserstein (or OT) distance, assuming $p \geq 1$, is:

$$W_p(\mu, \nu) = \left( \inf_{P \in \Pi(\mu, \nu)} \mathbb{E}_P[D(X, Y)^p] \right)^{1/p}.$$

slide-25
SLIDE 25

(Historical Parenthesis)

Monge-Kantorovich, Kantorovich-Rubinstein, Wasserstein, Earth Mover's Distance, Mallows

  • Monge 1781: Mémoire sur la théorie des déblais et des remblais
  • Optimization & Operations Research: Kantorovich'42, Dantzig'47, Ford & Fulkerson'55, etc.
  • Probability & Statistical Physics: Rachev'92, Talagrand'96, Villani'09
  • Computer Vision: Rubner et al.'98
SLIDE 26

OT Distance for Empirical Measures

$\nu = \sum_{j=1}^{m} b_j \delta_{y_j}$, $\mu = \sum_{i=1}^{n} a_i \delta_{x_i}$, on $(\Omega, D)$.

$$W_p(\mu, \nu) = \left( \inf_{P \in \Pi(\mu, \nu)} \mathbb{E}_P[D(X, Y)^p] \right)^{1/p}. \qquad \text{Algorithmically?}$$
SLIDE 27

OT Distance for Empirical Measures

$\nu = \sum_{j=1}^{m} \frac{1}{m} \delta_{y_j}$, $\mu = \sum_{i=1}^{n} \frac{1}{n} \delta_{x_i}$, on $(\Omega, D)$.

Suppose $n = m$ and all weights are uniform.
SLIDE 28

OT Distance for Empirical Measures

$\nu = \sum_{j=1}^{m} \frac{1}{m} \delta_{y_j}$, $\mu = \sum_{i=1}^{n} \frac{1}{n} \delta_{x_i}$, on $(\Omega, D)$.

Then $W_p^p$ = optimal matching cost (solved for instance with the Hungarian algorithm; a sketch follows below):

$$W_p = \left( \min_{\sigma \in S_n} \frac{1}{n} \sum_{i=1}^{n} D(x_i, y_{\sigma_i})^p \right)^{1/p}.$$
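As a quick check of the matching formulation, here is a minimal sketch using SciPy's Hungarian solver on hypothetical point clouds (the data and variable names are illustrative, not from the talk):

```python
import numpy as np
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist

# Hypothetical point clouds of equal size n = m, uniform weights 1/n.
rng = np.random.default_rng(0)
n, p = 6, 2
X = rng.normal(size=(n, 2))   # support of mu
Y = rng.normal(size=(n, 2))   # support of nu

# Matching costs D(x_i, y_j)^p with D = Euclidean distance.
C = cdist(X, Y) ** p

# Optimal permutation sigma minimizing (1/n) sum_i D(x_i, y_sigma(i))^p.
rows, cols = linear_sum_assignment(C)
W_p = C[rows, cols].mean() ** (1.0 / p)
print(f"W_p between the two empirical measures: {W_p:.4f}")
```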

SLIDE 29

OT Distance for Empirical Measures

$\nu = \sum_{j=1}^{m} \frac{1}{m} \delta_{y_j}$, $\mu = \sum_{i=1}^{n} \frac{1}{n} \delta_{x_i}$, on $(\Omega, D)$.

As soon as $n \neq m$, or the weights are non-uniform, optimal matching no longer makes sense.
slide-30
SLIDE 30

Computing the OT Distance

$\nu = \sum_{j=1}^{m} b_j \delta_{y_j}$, $\mu = \sum_{i=1}^{n} a_i \delta_{x_i}$, on $(\Omega, D)$.

$W_p^p(\mu, \nu)$ can be cast as a linear program in $\mathbb{R}^{n \times m}$, using:

  • 1. $M_{XY} \stackrel{\text{def}}{=} [D(x_i, y_j)^p]_{ij} \in \mathbb{R}^{n \times m}$ (metric information);
  • 2. the transportation polytope (joint probabilities) $U(a, b) = \{ P \in \mathbb{R}^{n \times m}_{+} \mid P \mathbf{1}_m = a,\ P^T \mathbf{1}_n = b \}$.
SLIDE 31

Computing p-Wasserstein Distances

$$W_p^p(\mu, \nu) = \operatorname{primal}(a, b, M_{XY}) \stackrel{\text{def}}{=} \min_{T \in U(a, b)} \langle T, M_{XY} \rangle = \langle T^\star, M_{XY} \rangle.$$
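For small $n, m$ this primal LP can be solved directly with an off-the-shelf solver. A minimal sketch, assuming SciPy (the histograms and cost matrix below are hypothetical):

```python
import numpy as np
from scipy.optimize import linprog

def ot_primal(a, b, M):
    """Solve min_{T in U(a,b)} <T, M> as a dense LP; only for small n, m."""
    n, m = M.shape
    # Equality constraints: T 1_m = a (row sums) and T^T 1_n = b (column sums).
    A_eq = np.zeros((n + m, n * m))
    for i in range(n):
        A_eq[i, i * m:(i + 1) * m] = 1.0       # row i of T sums to a[i]
    for j in range(m):
        A_eq[n + j, j::m] = 1.0                # column j of T sums to b[j]
    b_eq = np.concatenate([a, b])
    res = linprog(M.ravel(), A_eq=A_eq, b_eq=b_eq, bounds=(0, None))
    return res.fun, res.x.reshape(n, m)        # W_p^p and the optimal T*

# Hypothetical example: histograms on two small 1D grids, squared distances.
a = np.array([0.5, 0.3, 0.2])
b = np.array([0.25, 0.25, 0.25, 0.25])
M = np.subtract.outer(np.arange(3.0), np.arange(4.0)) ** 2
cost, T_star = ot_primal(a, b, M)
print(cost)
```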

SLIDE 32

[Kantorovich’42] Duality

  • This primal problem has an equivalent dual LP:

$$W_p^p(\mu, \nu) = \operatorname{primal}(a, b, M_{XY}) \stackrel{\text{def}}{=} \min_{T \in U(a, b)} \langle T, M_{XY} \rangle = \operatorname{dual}(a, b, M_{XY}) \stackrel{\text{def}}{=} \max_{(\alpha, \beta) \in C_{M_{XY}}} \alpha^T a + \beta^T b,$$

where $C_M = \{ (\alpha, \beta) \in \mathbb{R}^{n+m} \mid \alpha_i + \beta_j \leq M_{ij} \}$.
SLIDE 33

[Kantorovich’42] Duality

(Primal/dual pair as on the previous slide.)

Both problems require $O(n^3 \log(n))$ operations, and are typically solved using the network simplex.
SLIDE 34

Wasserstein Barycenter Problem (WBP)

  • [Agueh'11] introduced the WBP:

$$\operatorname*{argmin}_{\mu \in \mathcal{P}(\Omega)} C(\mu) \stackrel{\text{def}}{=} \sum_{i=1}^{N} W_p^p(\mu, \nu_i).$$

  • It can be solved with a multi-marginal OT problem.
  • Intractable: an LP with $\prod_i \operatorname{card}(\operatorname{supp}(\nu_i))$ variables.
SLIDE 35

Differentiability w.r.t. X or a

$\nu = \sum_{j=1}^{m} b_j \delta_{y_j}$, $\mu = \sum_{i=1}^{n} a_i \delta_{x_i}$, on $(\Omega, D)$.

To solve it numerically, we must understand how $f_\nu(a, X) \stackrel{\text{def}}{=} W_p^p(\mu, \nu)$ varies when $a$ and $X$ vary.
SLIDE 36

Differentiability w.r.t. X or a

$\nu = \sum_{j=1}^{m} b_j \delta_{y_j}$, $\mu = \sum_{i=1}^{n} a'_i \delta_{x_i}$.

  • 1. Infinitesimal variation in weights: $f_\nu(a', X)$?, for $a' \approx a$.
SLIDE 37

Differentiability w.r.t. X or a

$\nu = \sum_{j=1}^{m} b_j \delta_{y_j}$, $\mu = \sum_{i=1}^{n} a_i \delta_{x'_i}$.

  • 2. Infinitesimal variation in locations: $f_\nu(a, X')$?, for $X' \approx X$.
SLIDE 39

Using the dual: $\partial|_a$

$$f_\nu(a, X) = \max_{(\alpha, \beta) \in C_{M_{XY}}} \alpha^T a + \beta^T b$$

$a \mapsto f_\nu(a, X)$ is a convex, non-smooth map. The dual optimum $\alpha^\star$ is a subgradient of $f_\nu(a, X)$.
SLIDE 40

Using the primal: $\partial|_X$

$$f_\nu(a, X) = \min_{T \in U(a, b)} \langle T, M_{XY} \rangle$$

  • More involved computations; tractable when $D$ = Euclidean, $p = 2$.
  • Convex quadratic + piecewise linear concave in $X$.
  • $\partial f_\nu|_X = Y T^{\star T} \operatorname{diag}(a^{-1})$: the optimal transport $T^{\star T}$ yields a subgradient (see the sketch below).
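For $D$ = Euclidean and $p = 2$, the quantity $T^\star Y$ rescaled by $a^{-1}$ is the familiar barycentric update of the support points; a minimal sketch of one such step (a simplification consistent with the subgradient above, not necessarily the talk's exact update rule):

```python
import numpy as np

def support_update(Y, a, T_star):
    """One fixed-point step on the locations X for D = Euclidean, p = 2:
    move each x_i to the barycenter of the mass it ships, i.e. row i of
    diag(a)^{-1} T* Y. T_star is n x m, Y is m x d, a is the weight vector."""
    return (T_star @ Y) / a[:, None]
```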

SLIDE 41

To sum up: (1) the WBP is challenging

$$C(a, X) \stackrel{\text{def}}{=} \frac{1}{N} \sum_{i=1}^{N} W_p^p(\mu, \nu_i) = \frac{1}{N} \sum_{i=1}^{N} f_{\nu_i}(a, X)$$

  • $a \mapsto C(a, X)$ is convex, non-smooth; computing one subgradient requires solving $N$ OT problems!
  • $X \mapsto C(a, X)$ is non-convex, non-smooth.
SLIDE 42

(2) the WBP is unstable

  • Assume $X = Y_1 = \cdots = Y_N$ (fixed grid).

$$C(a) = \frac{1}{N} \sum_{i=1}^{N} \operatorname{primal}(a, b_i, M) = \frac{1}{N} \sum_{i=1}^{N} \min_{T_i \in U(a, b_i)} \langle T_i, M \rangle$$

  • In that case, the WBP can be solved as a large LP (a sketch follows below):

$$\min_{T_1, \cdots, T_N, a} \sum_{i=1}^{N} \langle T_i, M \rangle \quad \text{s.t.} \quad T_i^T \mathbf{1}_d = b_i\ \ \forall i \leq N, \qquad T_1 \mathbf{1}_d = \cdots = T_N \mathbf{1}_d = a.$$
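A minimal sketch of that large LP, assuming SciPy and a tiny hypothetical grid; the shared marginal $a$ is eliminated by tying each $T_i \mathbf{1}_d$ to $T_1 \mathbf{1}_d$:

```python
import numpy as np
from scipy.optimize import linprog

def barycenter_lp(bs, M):
    """Fixed-grid Wasserstein barycenter as one large LP (tiny d only).
    bs: N histograms on a common d-point grid; M: d x d ground costs.
    Variables: vec(T_1), ..., vec(T_N) in row-major order."""
    N, d = len(bs), M.shape[0]
    nv = N * d * d
    A, rhs = [], []
    for i, b in enumerate(bs):               # T_i^T 1_d = b_i (column sums)
        for j in range(d):
            row = np.zeros(nv)
            row[i * d * d + j::d][:d] = 1.0
            A.append(row); rhs.append(b[j])
    for i in range(1, N):                    # T_i 1_d = T_1 1_d (shared a)
        for k in range(d):
            row = np.zeros(nv)
            row[i * d * d + k * d:i * d * d + (k + 1) * d] = 1.0
            row[k * d:(k + 1) * d] -= 1.0
            A.append(row); rhs.append(0.0)
    c = np.tile(M.ravel(), N)                # objective: sum_i <T_i, M>
    res = linprog(c, A_eq=np.array(A), b_eq=np.array(rhs), bounds=(0, None))
    return res.x[:d * d].reshape(d, d).sum(axis=1)   # a* = T_1 1_d

# Hypothetical example: two Diracs on a 5-point grid; with squared costs
# the LP barycenter puts all mass on the midpoint.
grid = np.arange(5.0)
M = np.subtract.outer(grid, grid) ** 2
print(barycenter_lp([np.eye(5)[0], np.eye(5)[4]], M))
```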

SLIDE 43

Averaging Two Gaussians

[Figure: densities of $\mathcal{N}(2, 1)$ and $\mathcal{N}(-2, 1/4)$, and their Wasserstein mean $W = \mathcal{N}(0, 5/8)$]
SLIDE 44

Discretized

[Figure: discretized densities $q_1$, $q_2$ and the discrete barycenter $p_W$]

$p_W$ is the discrete equivalent of the true barycenter.
SLIDE 45

Exact Solution

[Figure: $q_1$, $q_2$, $p_W$, and $p^\star$]

$p^\star$ is the solution to that LP.
SLIDE 46

Does not get much better with large n...

[Figure: $q_1$, $q_2$, $p_W$, and $p^\star$ for $n = 200$, $n = 300$, $n = 500$]
SLIDE 47

Entropic Smoothing of OT

SLIDE 48

Smoothing solves (almost) everything

Original OT primal:

$$\operatorname{primal}(a, b, M_{XY}) = \min_{T \in U(a, b)} \langle T, M_{XY} \rangle$$

Original OT (Kantorovich) dual:

$$\operatorname{dual}(a, b, M_{XY}) = \max_{(\alpha, \beta),\ \alpha_i + \beta_j \leq M_{ij}} \alpha^T a + \beta^T b$$
SLIDE 49

Smoothing solves (almost) everything

Entropy-smoothed ($\gamma > 0$) primal problem:

$$\operatorname{primal}_\gamma(a, b, M_{XY}) = \min_{T \in U(a, b)} \langle T, M_{XY} \rangle - \gamma H(T)$$

Smoothed dual problem:

$$\operatorname{dual}_\gamma(a, b, M_{XY}) = \max_{(\alpha, \beta)} \alpha^T a + \beta^T b - \gamma \sum_{i \leq n, j \leq m} e^{-(M_{ij} - \alpha_i - \beta_j)/\gamma}$$
SLIDE 50

Smoothing solves (almost) everything

Entropy-smoothed ($\gamma > 0$) primal problem, rewritten as a KL projection:

$$\operatorname{primal}_\gamma(a, b, M_{XY}) = \min_{T \in U(a, b)} \operatorname{KL}(T \,\|\, e^{-M_{XY}/\gamma})$$

Smoothed dual problem:

$$\operatorname{dual}_\gamma(a, b, M_{XY}) = \max_{(\alpha, \beta)} \alpha^T a + \beta^T b - \gamma \sum_{i \leq n, j \leq m} e^{-(M_{ij} - \alpha_i - \beta_j)/\gamma}$$
SLIDE 51

Why is entropy a good regularizer for OT?

The penalized problem $T_\gamma = \operatorname*{argmin}_{T \in U(r, c)} \langle T, M \rangle - \gamma H(T)$ implies, by the first-order conditions, that $T_\gamma$ has the form:

$$\exists u \in \mathbb{R}^n_+,\ v \in \mathbb{R}^m_+ \ \mid\ T_\gamma = \operatorname{diag}(u)\, e^{-M/\gamma}\, \operatorname{diag}(v).$$

  • Gravity model in transportation [Wilson'69]
  • Schrödinger problem ['32]
SLIDE 52

Sinkhorn - Matrix Scaling

Theorem 1 (Sinkhorn'62). For any $n \times m$ matrix $A$ with positive entries and any $r$, $c$ in the simplex, there exist $u \in \mathbb{R}^n_+$, $v \in \mathbb{R}^m_+$, unique up to a scaling factor, such that

$$\operatorname{diag}(u)\, A\, \operatorname{diag}(v) \in U(r, c).$$

$u, v$ can be computed in $O(nm)$ time per iteration using the Sinkhorn fixed-point iteration.
SLIDE 53

Sinkhorn Algorithm

  • 1. Set $K = \exp(-M/\gamma)$ (note: if $M$ is a squared Euclidean metric on a grid, applying $K$ is a Gaussian convolution...)
  • 2. Seed initial random values for $u$.
  • 3. Loop until convergence:
      (a) set $v \leftarrow c\,./\,(K^T u)$;
      (b) set $u \leftarrow r\,./\,(K v)$.

At convergence: $T^\star_\gamma = \operatorname{diag}(u^\star)\, K\, \operatorname{diag}(v^\star)$, $\alpha^\star_\gamma = \gamma \log(u^\star)$, and $W_\gamma(\mu, \nu) = \langle T_\gamma, M \rangle = u^{\star T} (K \circ M)\, v^\star$.
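A minimal NumPy sketch of this loop (hypothetical marginals on a shared 1D grid; a fixed iteration count stands in for a real convergence test):

```python
import numpy as np

def sinkhorn(r, c, M, gamma, n_iter=2000):
    """Entropy-smoothed OT via Sinkhorn's fixed-point iteration.
    r, c: marginal histograms; M: cost matrix; gamma: regularization."""
    K = np.exp(-M / gamma)                 # step 1: kernel
    u = np.ones_like(r)                    # step 2: initialization
    for _ in range(n_iter):                # step 3: alternate the scalings
        v = c / (K.T @ u)
        u = r / (K @ v)
    T = u[:, None] * K * v[None, :]        # T_gamma = diag(u) K diag(v)
    return T, u @ ((K * M) @ v)            # <T_gamma, M> = u^T (K o M) v

# Hypothetical example: two discretized Gaussians on [0, 1].
x = np.linspace(0.0, 1.0, 50)
M = np.subtract.outer(x, x) ** 2
r = np.exp(-(x - 0.2) ** 2 / 0.01); r /= r.sum()
c = np.exp(-(x - 0.7) ** 2 / 0.01); c /= c.sum()
T, cost = sinkhorn(r, c, M, gamma=5e-3)
print(cost)   # close to the squared distance between the means, 0.25
```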

SLIDE 54

Benefits of Smoothing [C.’13]

  • These OT problems are strongly convex, vs. LPs: unicity of solutions, differentiability.
  • Considerably more efficient in practice [Nesterov'05].
  • Primal/dual smoothed optima $\alpha^\star_\gamma$, $T^\star_\gamma$ can be solved for:
  • in $O(n^2)$ per iteration with Sinkhorn's (IPFP) algorithm,
  • in parallel on GPGPUs for any metric on finite $\Omega$,
  • millions of times faster than the simplex,
  • handling large dimensions ($\approx 50{,}000$ so far).
SLIDE 55

Our Solution (using regularization)

[Figure: $q_1$, $q_2$, $p_W$, and the regularized solution $p^\star_\gamma$]
SLIDE 56
1. Smoothed primal [C. Doucet'14]

$$C(a) = \frac{1}{N} \sum_{i=1}^{N} \operatorname{primal}_\gamma(a, b_i, M)$$

  • (Projected) gradient descent (a sketch follows below):
  • solve $N$ smoothed (dual) OT problems for the $\alpha^\star_{i,\gamma}$;
  • update $a$ using the gradient $\frac{1}{N} \sum_i \alpha^\star_{i,\gamma}$;
  • each step requires computing the $\alpha^\star_{i,\gamma}$.
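A minimal sketch of this loop, reusing the iteration from the Sinkhorn slide; for simplicity, the projected step is replaced by an exponentiated-gradient (mirror) update, which also keeps $a$ in the simplex. This is a simplification, not the paper's exact scheme:

```python
import numpy as np

def barycenter_smoothed(bs, M, gamma, n_outer=50, lr=1.0, n_inner=200):
    """Fixed-grid smoothed-Wasserstein barycenter by gradient descent.
    Each outer step solves N smoothed OT problems; the average of the
    dual optima alpha*_{i,gamma} = gamma log(u_i) is a gradient of C(a)."""
    d = M.shape[0]
    K = np.exp(-M / gamma)
    a = np.full(d, 1.0 / d)
    for _ in range(n_outer):
        grad = np.zeros(d)
        for b in bs:
            u = np.ones(d)
            for _ in range(n_inner):       # Sinkhorn with marginals (a, b)
                v = b / (K.T @ u)
                u = a / (K @ v)
            grad += gamma * np.log(u) / len(bs)
        a *= np.exp(-lr * grad)            # mirror-descent step on C(a)
        a /= a.sum()                       # renormalize onto the simplex
    return a
```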

SLIDE 57
2. Dual approach [C. Peyré'14]

  • The Fenchel-Legendre conjugate of $f_b(a) = \operatorname{primal}_\gamma(a, b, M)$, namely

$$f^*_b(g) = \max_{p \in \Sigma_n} \langle g, p \rangle - f_b(p),$$

has a closed form:

$$f^*_b(g) = \gamma \left( H(b) + \langle b, \log e^{-M/\gamma} e^{g/\gamma} \rangle \right).$$
SLIDE 58
2. Dual approach [C. Peyré'14], continued

  • The original problem, in split form:

$$\min_{a_1, \cdots, a_N \in \Sigma_n} \sum_{i=1}^{N} f_{b_i}(a_i) \quad \text{subj. to} \quad a_1 = \cdots = a_N,$$

  • can be replaced with an easier problem:

$$\min_{g_1, \cdots, g_N \in \mathbb{R}^n} \sum_{i=1}^{N} f^*_{b_i}(g_i) \quad \text{subj. to} \quad \sum_{i=1}^{N} g_i = 0.$$

The gradient/Hessian are explicit, and the equality constraint suits a truncated Newton method. At convergence, all $\nabla f^*_{b_i}(g_i)$ are equal to the solution $a^\star$.
SLIDE 59
3. Generalized KL Projections [NBCCP'14]

  • Idea: generalize the KL projection for two marginals,

$$\operatorname*{argmin}_{T \in U(a, b_i)} \operatorname{KL}(T \,\|\, e^{-M/\gamma}),$$

  • to alternated KL projections for $N$ couplings sharing a common (unknown) marginal:

$$\operatorname*{argmin}_{T_i^T \mathbf{1} = b_i,\ T_i \mathbf{1} = T_{i+1} \mathbf{1}} \ \sum_i \operatorname{KL}(T_i \,\|\, e^{-M/\gamma}).$$

  • Two lines of Matlab code (a NumPy sketch follows below).
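A minimal NumPy sketch of those alternated projections, in the spirit of [NBCCP'14] (not the authors' Matlab; uniform weights assumed):

```python
import numpy as np

def barycenter_bregman(bs, M, gamma, n_iter=500):
    """Wasserstein barycenter via iterative Bregman/KL projections.
    Each coupling T_i = diag(U[i]) K diag(V[i]) is projected in turn onto
    its two marginal constraints; the shared marginal a is the geometric
    mean of the current row marginals U[i] * (K V[i])."""
    N, d = len(bs), M.shape[0]
    K = np.exp(-M / gamma)
    U = np.ones((N, d))
    for _ in range(n_iter):
        V = np.stack([b / (K.T @ U[i]) for i, b in enumerate(bs)])
        a = np.exp(np.mean([np.log(U[i] * (K @ V[i])) for i in range(N)], axis=0))
        U = np.stack([a / (K @ V[i]) for i in range(N)])
    return a
```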

SLIDE 60

Applications

SLIDE 61

Averaging 30 Measures

30 measures on $\mathbb{R}^2$.
SLIDE 62

Euclidean Mean

SLIDE 63

Symmetric KL Mean

SLIDE 64

2-Wasserstein

SLIDE 65

Averaging Brain Activations (real data)

MEG ERF data, $N = 16$. Left hemisphere, medial view; border of V1.
SLIDE 66

1-Wasserstein

SLIDE 67

Averaging Brain Activations (real data)

Right hemisphere, ventral view.
SLIDE 68

Averaging Brain Activations (real data)

Centered on the fusiform gyrus.
SLIDE 69

Averaging Text Histograms

  • Using GloVe embeddings for words, 2-Wasserstein.
SLIDE 70

Graphics

[Slides 70 through 75 are image-only graphics examples.]
SLIDE 76

End
