SLIDE 1

Generative Models and Model Criticism via Optimized Maximum Mean Discrepancy

Dougal J. Sutherland, Gatsby Unit, UCL

SLIDE 2

Two-sample testing

Observe two different datasets: X ∼ P vs. Y ∼ Q.

Question we want to answer: is P = Q?

SLIDE 3

Two-sample testing

Applications:

  • Do cigarette smokers and non-smokers have different distributions of cancers?
  • Do these neurons behave differently when the subject is looking at background image A instead of B?
  • Do these columns from different databases mean the same thing?
  • Did my generative model actually learn the distribution I wanted it to?

SLIDE 4

Standard approaches

  • (Unpaired) t-test, Wilcoxon rank-sum test, etc.
    – Only test differences in location (mean)
  • Kolmogorov–Smirnov test
    – Tests for all differences
    – Nonparametric
    – Hard to extend to > 1d
  • Want a test that looks for all possible differences, without parametric assumptions, in multiple dimensions

SLIDE 5

Defining a two-sample test

  • 1. Choose a distance ρ(P, Q) between distributions.
    – Ideally, ρ(P, Q) = 0 if and only if P = Q.
  • 2. Estimate the distribution distance from data: ρ̂(X, Y).
  • 3. Choose a rejection threshold c_α such that Pr_{X,Y ∼ P}(ρ̂(X, Y) > c_α) ≤ α; say P ≠ Q when ρ̂ > c_α.

Reminder:

  • The level of a test is the probability of false rejection.
  • The power of a test is the probability of true rejection.
SLIDE 6

A Kernel Distance on Distributions

Quick reminder about kernels:

  • Our data lives in a space 𝒳.
  • The kernel is a similarity function k : 𝒳 × 𝒳 → ℝ.
  • Corresponds to a reproducing kernel Hilbert space (RKHS) ℋ, with feature map ϕ : 𝒳 → ℋ, by ⟨ϕ(x), ϕ(y)⟩_ℋ = k(x, y).

We'll use a base kernel on 𝒳 to make a distance between distributions over 𝒳, e.g. the Gaussian kernel k(x, y) = exp(−‖x − y‖² / (2σ²)).
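To make the base kernel concrete, here is a minimal NumPy sketch (not from the talk) of the Gaussian kernel matrix between two samples; `sigma` is the bandwidth that the rest of the talk is about choosing:

```python
import numpy as np

def gaussian_kernel(X, Y, sigma=1.0):
    """Gaussian (RBF) kernel matrix: K[i, j] = exp(-||X_i - Y_j||^2 / (2 sigma^2)).

    X: (m, d) array, Y: (n, d) array."""
    # ||x - y||^2 = ||x||^2 - 2 x.y + ||y||^2, clipped at 0 for numerical safety
    sq_dists = (np.sum(X**2, axis=1)[:, None]
                - 2 * X @ Y.T
                + np.sum(Y**2, axis=1)[None, :])
    return np.exp(-np.maximum(sq_dists, 0) / (2 * sigma**2))
```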

SLIDE 7

Mean embeddings of distributions

The mean embedding of a distribution in an RKHS:

µ_P = E_{x∼P}[ϕ(x)]

Remember ⟨ϕ(x), ϕ(y)⟩ = k(x, y), so we can think of ϕ(x) as k(x, ·).

[Figure: the mean embedding x ↦ E[k(X, x)] compared with the CDF x ↦ Pr(X ≤ x).]


SLIDE 12

Maximum Mean Discrepancy (MMD)

The MMD is the distance between mean embeddings:

mmd²(P, Q) = ‖µ_P − µ_Q‖²_ℋ = ⟨µ_P, µ_P⟩ + ⟨µ_Q, µ_Q⟩ − 2⟨µ_P, µ_Q⟩

where, with µ_P = E_{X∼P}[ϕ(X)],

⟨µ_P, µ_Q⟩ = ⟨E_{X∼P}[ϕ(X)], E_{Y∼Q}[ϕ(Y)]⟩ = E_{X∼P, Y∼Q}[⟨ϕ(X), ϕ(Y)⟩] = E_{X∼P, Y∼Q}[k(X, Y)]

Equivalently, mmd(P, Q) = sup_{f ∈ ℋ, ‖f‖_ℋ ≤ 1} E_{X∼P} f(X) − E_{Y∼Q} f(Y).

SLIDE 13

MMD estimator

mmd²(P, Q) = ‖µ_P − µ_Q‖²_ℋ = ⟨µ_P, µ_P⟩ + ⟨µ_Q, µ_Q⟩ − 2⟨µ_P, µ_Q⟩, where ⟨µ_P, µ_Q⟩ = E_{X∼P, Y∼Q}[k(X, Y)].

Estimate each inner product by its empirical version:

⟨µ̂_P, µ̂_Q⟩ = (1 / mn) Σ_{ij} k(X_i, Y_j)
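A minimal sketch of the resulting plug-in estimator, reusing the `gaussian_kernel` helper above (this is the biased version; the unbiased variant on slide 32 drops the i = j terms):

```python
def mmd2_biased(X, Y, kernel):
    """Plug-in MMD^2 estimate: each inner product of mean embeddings becomes
    an average of kernel values, e.g. <mu_P, mu_Q> ~ (1/mn) sum_ij k(X_i, Y_j)."""
    return (kernel(X, X).mean()         # <mu_P, mu_P>
            + kernel(Y, Y).mean()       # <mu_Q, mu_Q>
            - 2 * kernel(X, Y).mean())  # -2 <mu_P, mu_Q>
```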


SLIDE 22

Permutation testing

Still need a rejection threshold: Pr_{X,Y ∼ P}(ρ̂(X, Y) > c_α) ≤ α.

When P = Q, the MMD asymptotics depend on P, so it's hard to find a threshold that way.

Permutation test: randomly re-split X ∪ Y to estimate the MMD when P = Q, giving m·mmd̂²(X⁽¹⁾, Y⁽¹⁾), m·mmd̂²(X⁽²⁾, Y⁽²⁾), …

ĉ_α: the (1 − α)th quantile of the m·mmd̂²(X⁽ⁱ⁾, Y⁽ⁱ⁾).
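A sketch of the permutation threshold, reusing `mmd2_biased` from above. (It compares raw mmd̂² values rather than m·mmd̂²; since m is the same for every re-split, the test is unchanged.)

```python
import numpy as np

def permutation_threshold(X, Y, kernel, alpha=0.05, n_perms=1000, seed=0):
    """(1 - alpha)-quantile of MMD^2 estimates under random re-splits of the
    pooled sample, which simulates the null hypothesis P = Q."""
    rng = np.random.default_rng(seed)
    Z = np.concatenate([X, Y])
    m = len(X)
    null_stats = [
        mmd2_biased(Z[p[:m]], Z[p[m:]], kernel)
        for p in (rng.permutation(len(Z)) for _ in range(n_perms))
    ]
    return np.quantile(null_stats, 1 - alpha)
```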

SLIDE 23

Computing permutation tests

[Figure: time (s) vs. number of threads; legend: "Our", "perm.", "MKL", "spectr.".]

K = ⎡ k(X₁, X₁) ⋯ k(X₁, Xₘ)   k(X₁, Y₁) ⋯ k(X₁, Yₘ) ⎤
    ⎢     ⋮     ⋱      ⋮           ⋮     ⋱      ⋮    ⎥
    ⎢ k(Xₘ, X₁) ⋯ k(Xₘ, Xₘ)   k(Xₘ, Y₁) ⋯ k(Xₘ, Yₘ) ⎥
    ⎢ k(Y₁, X₁) ⋯ k(Y₁, Xₘ)   k(Y₁, Y₁) ⋯ k(Y₁, Yₘ) ⎥
    ⎣ k(Yₘ, X₁) ⋯ k(Yₘ, Xₘ)   k(Yₘ, Y₁) ⋯ k(Yₘ, Yₘ) ⎦

Each element of K is added to or subtracted from a term of each permutation estimate. So, do it all in one pass.

Original Matlab code: 381 s. Better Python code: 182 s.
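One way to exploit this in code: evaluate the joint kernel matrix K once, then read each permutation's estimate off it by indexing, so no kernel values are ever recomputed. This is a simple sketch of the idea; the talk's one-pass implementation fuses the loop further.

```python
import numpy as np

def permutation_null_from_K(K, m, n_perms=1000, seed=0):
    """Permutation MMD^2 estimates from a precomputed joint kernel matrix K,
    whose first m rows/columns correspond to X and the rest to Y."""
    rng = np.random.default_rng(seed)
    stats = np.empty(n_perms)
    for i in range(n_perms):
        p = rng.permutation(K.shape[0])
        a, b = p[:m], p[m:]          # re-split the pooled sample
        stats[i] = (K[np.ix_(a, a)].mean()
                    + K[np.ix_(b, b)].mean()
                    - 2 * K[np.ix_(a, b)].mean())
    return stats
```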

SLIDE 24

Example two-sample test

X ∼ P = N(0, 1) vs. Y ∼ Q = Laplace(0, 1/√2)

[Figure: the two densities, which share mean 0 and variance 1.]

  • 1. Choose a kernel k.
  • 2. Estimate the MMD for the true division and for many permutations.
  • 3. Reject if m·mmd̂²_k(X, Y) > ĉ_α.
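Putting the helpers sketched earlier together on this example (the bandwidth σ = 0.5 is an arbitrary illustrative choice, not the talk's):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(0.0, 1.0, size=(500, 1))                # P = N(0, 1)
Y = rng.laplace(0.0, 1.0 / np.sqrt(2), size=(500, 1))  # Q: same mean and variance

kernel = lambda A, B: gaussian_kernel(A, B, sigma=0.5)
stat = mmd2_biased(X, Y, kernel)
c_alpha = permutation_threshold(X, Y, kernel, alpha=0.05)
print("reject P = Q" if stat > c_alpha else "fail to reject")
```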

SLIDE 25

The kernel matters!

The witness function f helps compare samples:

mmd(P, Q) = E_{X∼P} f(X) − E_{Y∼Q} f(Y)

f(x) = µ_P(x) − µ_Q(x) = E_{X∼P} k(x, X) − E_{Y∼Q} k(x, Y)

[Figure, built up over three frames: witness functions and permutation p-values for several Gaussian bandwidths: σ = 0.75, p = 0.0; σ = 1, p = 0.3; σ = 0.1, p = 0.16.]

SLIDE 28

Choosing a kernel

So we need a way to pick a kernel to do the test.

[Diagram: data (X, Y) → choose a kernel k → use the chosen k in the MMD test.]

SLIDE 29

Choosing a kernel

So we need a way to pick a kernel to do the test. Split the data: choose a kernel k on one part, then use the chosen k in the MMD test on the rest.

How to pick k? We want the (asymptotically) most powerful test. Typically: maximize the MMD.

SLIDE 30

Asymptotic power of MMD

When P ≠ Q, the MMD estimator is asymptotically normal:

(mmd̂² − mmd²) / √Vm →_D N(0, 1),   Vm = Var_{X∼Pᵐ, Y∼Qᵐ}[mmd̂²(X, Y)]

and we can analyze the power:

Pr_{H₁}(m·mmd̂² > ĉ_α) = Pr_{H₁}( (mmd̂² − mmd²)/√Vm > ĉ_α/(m√Vm) − mmd²/√Vm ) → Φ( mmd²/√Vm − ĉ_α/(m√Vm) )

SLIDE 31

MMD t-statistic

Pr_{H₁}(m·mmd̂² > ĉ_α) → Φ( mmd²/√Vm − ĉ_α/(m√Vm) )

So we can maximize the power by maximizing

τ_U = mmd²/√Vm − c_α/(m√Vm),   estimated by   τ̂_U = mmd̂²/√V̂m − ĉ_α/(m√V̂m).

But Vm is O(1/m), so the first term dominates for large m, and we should be able to get away with maximizing

t_U = mmd²/√Vm,   estimated by   t̂_U = mmd̂²/√V̂m.

SLIDE 32

t-statistic estimator

t̂_U = mmd̂² / √V̂m,   τ̂_U = mmd̂²/√V̂m − ĉ_α/(m√V̂m)

mmd̂² := (1 / (m(m − 1))) Σ_{i≠j} [ k(X_i, X_j) + k(Y_i, Y_j) − k(X_i, Y_j) − k(X_j, Y_i) ]

V̂m := (2 / (m²(m − 1)²)) [ 2‖K̃_XX e‖² − ‖K̃_XX‖²_F + 2‖K̃_YY e‖² − ‖K̃_YY‖²_F ]
    − ((4m − 6) / (m³(m − 1)³)) [ (eᵀK̃_XX e)² + (eᵀK̃_YY e)² ]
    + (4(m − 2) / (m³(m − 1)²)) [ ‖K_XY e‖² + ‖K_XYᵀ e‖² ]
    − (4(m − 3) / (m³(m − 1)²)) ‖K_XY‖²_F
    − ((8m − 12) / (m⁵(m − 1))) (eᵀK_XY e)²
    + (8 / (m³(m − 1))) [ (1/m)(eᵀK̃_XX e + eᵀK̃_YY e) eᵀK_XY e − eᵀK̃_XX K_XY e − eᵀK̃_YY K_XYᵀ e ]

(Here e is the all-ones vector and K̃ denotes a kernel matrix with its diagonal set to zero.)

ĉ_α is from a permutation test, so it's a quantile of a bunch of MMD estimates.

SLIDE 33

t-statistic estimator

Can even get gradients of t̂_U and (with some more effort) τ̂_U, to maximize them. (Automatic differentiation is your friend.)
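For illustration, a PyTorch sketch of a differentiable t-statistic with a learnable Gaussian bandwidth. To keep it short it uses a simplified first-order variance estimate rather than the full V̂m formula above, and it assumes equal sample sizes m; treat it as a sketch of the idea, not the talk's exact estimator. `X_train`, `Y_train` are an assumed held-out split.

```python
import torch

def t_stat(X, Y, log_sigma, eps=1e-8):
    """mmd_hat^2 / sqrt(V_hat) with a Gaussian kernel of learnable bandwidth.
    V_hat is a simplified first-order variance proxy, not the full V_m."""
    m = X.shape[0]
    Z = torch.cat([X, Y], dim=0)
    K = torch.exp(-torch.cdist(Z, Z) ** 2 / (2 * torch.exp(2 * log_sigma)))
    K_XY = K[:m, m:]
    H = K[:m, :m] + K[m:, m:] - K_XY - K_XY.T  # h(i, j) terms of the U-statistic
    H = H - torch.diag(torch.diag(H))          # drop the i == j terms
    mmd2 = H.sum() / (m * (m - 1))
    row_means = H.sum(dim=1) / (m - 1)
    var = (4.0 / m) * (row_means.pow(2).mean() - mmd2 ** 2)  # first-order term
    return mmd2 / torch.sqrt(var.clamp_min(eps))

# Maximize the statistic over the bandwidth by gradient ascent:
log_sigma = torch.zeros(1, requires_grad=True)
opt = torch.optim.Adam([log_sigma], lr=0.05)
for _ in range(200):
    loss = -t_stat(X_train, Y_train, log_sigma)  # held-out split (assumed given)
    opt.zero_grad(); loss.backward(); opt.step()
```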

SLIDE 34

Kernel choice on Blobs

Blobs dataset: mixture of N(µ_ij, [[1, 0], [0, 1]]) vs. mixture of N(µ_ij, [[1, (ε−1)/(ε+1)], [(ε−1)/(ε+1), 1]]).

When ε = 1, P = Q; this picture has ε = 6.

Only consider choosing the bandwidth of a Gaussian kernel.
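A sketch of sampling Blobs data under the covariance reconstruction above; the grid size and spacing are illustrative guesses, not values from the talk:

```python
import numpy as np

def sample_blobs(m, eps, grid=5, spacing=10, seed=0):
    """Mixture of 2-d Gaussians on a grid: X from identity-covariance blobs,
    Y from blobs whose components have correlation (eps - 1) / (eps + 1)."""
    rng = np.random.default_rng(seed)
    rho = (eps - 1) / (eps + 1)                # eps = 1 gives rho = 0, so P = Q
    cov_q = np.array([[1.0, rho], [rho, 1.0]])
    mu_x = rng.integers(0, grid, size=(m, 2)) * spacing  # random blob centers
    mu_y = rng.integers(0, grid, size=(m, 2)) * spacing
    X = mu_x + rng.standard_normal((m, 2))
    Y = mu_y + rng.multivariate_normal([0.0, 0.0], cov_q, size=m)
    return X, Y
```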

SLIDE 35

Kernel choice on Blobs

[Figure: results with m = 500.]

SLIDE 36

Deep kernels

Map through the layers of a deep network:

[Diagram: x₁, x₂, …, xₘ and y₁, y₂, …, yₘ pass through convolutional layers to give f(x₁), …, f(xₘ) and f(y₁), …, f(yₘ), which feed into t̂_U or τ̂_U.]

SLIDE 37

ARD kernel

Simple scaling function:

f(x₁, …, x_d) = (w₁x₁, …, w_d x_d)

[Diagram: as before, with this f in place of the convolutional layers, feeding into t̂_U or τ̂_U.]
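As a sketch, this scaling composes directly with the Gaussian base kernel from earlier; the per-dimension weights w are what gets optimized (e.g. by maximizing the t-statistic):

```python
def ard_kernel(X, Y, w, sigma=1.0):
    """Gaussian kernel applied after per-dimension scaling f(x) = w * x."""
    return gaussian_kernel(X * w, Y * w, sigma=sigma)
```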

SLIDE 38

Generative model criticism

[Figure: model samples vs. MNIST digits.]

SLIDE 39

Generative model criticism

Model samples vs. MNIST digits. [Figure: learned ARD weights.]

100 times, with 2k samples and 1k permutations:

  • ARD with t̂_U: 98 times p = .000, 2 times p = .001
  • Just bandwidth with t̂_U: 43 times p > .01, max = .135
  • Median bandwidth: 58 times p > .01, 3 times p = 1.000

SLIDE 40

Generative model criticism

[Figure.]

SLIDE 41

GANs

Generative Adversarial Networks

Generator: G(z) = x ∼ P_G, with z ∼ Unif([0, 1]¹⁰⁰). Trained to trick the classifier.

Discriminator: D(x) = Pr(x came from G). A classifier, trained on samples x ∼ P_G and x′ ∼ P_data.

With the optimal discriminator, the generator minimizes 2 JS(P_G, P_data) − 2 log 2, where

JS(p, q) = ½ KL(p ‖ (p + q)/2) + ½ KL(q ‖ (p + q)/2)

SLIDES 42–51

GAN example on MNIST

[Figures: samples from P_G vs. P_data at epochs 1, 2, 3, 4, 5, 6, 100, 200, 500, and 900.]

SLIDE 52

Problems with GANs

[Figure.]

SLIDE 53

Problems with GANs

Some possible reasons:

  • 1. For a fixed discriminator, the optimal generator puts a point mass at one point the discriminator thinks is good.
  • 2. For a fixed generator, with probability 1 there is a perfect discriminator with flat gradients (Arjovsky and Bottou 2017).

SLIDE 54

Generative Moment-Matching Networks

We can solve problem 1 by caring about whole sets of samples at a time… Replace the discriminator with an MMD two-sample test!

Li, Swersky, & Zemel, UAI 2015; Dziugaite, Roy, & Ghahramani, UAI 2015

[Figure panels: (a) single Gaussian kernel, minimize MMD; (b) mix of Gaussian kernels, minimize MMD; (c) mix of Gaussian kernels, minimize the t-statistic.]
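A minimal PyTorch-style sketch of the GMMN idea: the generator is trained to minimize a differentiable MMD² estimate between a generated batch and a data batch, with no discriminator. The architecture, sizes, and `data_loader` are illustrative assumptions, not the papers' setups.

```python
import torch
import torch.nn as nn

def mmd2(X, Y, sigma=1.0):
    """Biased MMD^2 with a fixed Gaussian kernel; differentiable in X."""
    m = X.shape[0]
    Z = torch.cat([X, Y], dim=0)
    K = torch.exp(-torch.cdist(Z, Z) ** 2 / (2 * sigma ** 2))
    return K[:m, :m].mean() + K[m:, m:].mean() - 2 * K[:m, m:].mean()

G = nn.Sequential(  # toy generator for flattened 28x28 images
    nn.Linear(100, 256), nn.ReLU(), nn.Linear(256, 784), nn.Sigmoid())
opt = torch.optim.Adam(G.parameters(), lr=1e-3)
for batch in data_loader:                 # batch: (m, 784) real samples (assumed)
    z = torch.rand(batch.shape[0], 100)   # z ~ Unif[0, 1]^100, as on the GAN slide
    loss = mmd2(G(z), batch)              # match whole sets, not single points
    opt.zero_grad(); loss.backward(); opt.step()
```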

SLIDE 55

Generative Moment-Matching Networks

This also solves problem 2, for a fixed kernel:

[Figure: JS(P_G, P_data) and mmd²(P_G, P_data) plotted as the generator varies.]

But once we try optimizing the kernel, it breaks again.
