Gradient Estimation for Implicit Models with Stein's Method


SLIDE 1

Gradient Estimation for Implicit Models with Stein’s Method

Yingzhen Li Microsoft Research Cambridge

Joint work with Rich Turner, Wenbo Gong, and José Miguel Hernández-Lobato.

SLIDE 2

A little about my research...

Bayesian Deep Learning

[Diagram: current methods (MCMC, VI + Gaussian, VI + implicit distributions) positioned along scalability and accuracy axes.]

SLIDE 3

Examples for implicit (generative) models

Implicit distributions:
  + easy to sample: z ∼ q(z|x) ⇔ ǫ ∼ π(ǫ), z = f(ǫ, x)
  + super flexible

SLIDE 4

Examples for implicit (generative) models

  • Bayesian inference goal: compute E_{p(z|x)}[F(z)]
  • Approximate inference: find q(z|x) in some family Q such that q(z|x) ≈ p(z|x)
  • At inference time, Monte Carlo integration:

E_{p(z|x)}[F(z)] ≈ (1/K) ∑_{k=1}^K F(z_k),  z_k ∼ q(z|x)

SLIDE 5

Examples for implicit (generative) models

  • Bayesian inference goal: compute E_{p(z|x)}[F(z)]
  • Approximate inference: find q(z|x) in some family Q such that q(z|x) ≈ p(z|x)
  • At inference time, Monte Carlo integration:

E_{p(z|x)}[F(z)] ≈ (1/K) ∑_{k=1}^K F(z_k),  z_k ∼ q(z|x)

Tractability requirement: fast sampling from q

(no need for point-wise density evaluation)
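As a minimal sketch of the Monte Carlo estimator above (the Gaussian q and the test function F are illustrative stand-ins, not from the slides):

```python
import numpy as np

# Monte Carlo estimate of E_q[F(z)] with K samples z_k ~ q(z|x).
# Illustrative choices: q = N(1, 0.5^2) and F(z) = z^2, for which
# the exact answer is E_q[F(z)] = mean^2 + var = 1.25.
rng = np.random.default_rng(0)
K = 100_000
z = 1.0 + 0.5 * rng.standard_normal(K)   # z_k ~ q(z|x): sampling is all we need
estimate = np.mean(z ** 2)               # (1/K) sum_k F(z_k)
print(estimate)                          # close to the exact value 1.25
```

Note that only samples from q are used; the density of q is never evaluated, which is exactly the tractability requirement above.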

SLIDE 6

Examples for implicit (generative) models

  • Fig. source: Mescheder et al. (2017)

Implicit distributions:
  + easy to sample: z ∼ q(z|x) ⇔ ǫ ∼ π(ǫ), z = f(ǫ, x)
  + super flexible
  + better approximate posterior

SLIDE 7

Examples for implicit (generative) models

  • Fig. source: Mescheder et al. (2017)

Implicit distributions:
  + easy to sample: z ∼ q(z|x) ⇔ ǫ ∼ π(ǫ), z = f(ǫ, x)
  + super flexible
  + better approximate posterior
  − hard to evaluate density; need some tricks for training

SLIDE 8

Loss approximation vs gradient approximation

To train the implicit generative model p_φ(x), e.g. with the generative adversarial network (GAN) method (Goodfellow et al. 2014):

min_φ JS[p_D || p_φ] = min_φ max_D E_{p_D}[log D(x)] + E_{p_φ}[log(1 − D(x))]

[Figure: the true loss vs. an approximate loss; the approximate loss minima need not coincide with the true minimum.]

SLIDE 9

Loss approximation vs gradient approximation

Often we use gradient-based optimisation methods to train machine learning models. ... which only require evaluating the gradient, rather than the loss function!

[Figure: true loss vs. approximate loss (with their minima), and true gradient vs. approximate gradient around the true minimum.]

SLIDE 10

Gradient approximation for VI

Variational inference with the q distribution parameterised by φ:

φ* = arg min_φ KL[q_φ(z|x) || p(z|x)] = arg max_φ L_VI(q_φ)

L_VI(q_φ) = log p(x) − KL[q_φ(z|x) || p(z|x)]
          = log p(x) − E_q[log (q_φ(z|x) / p(z|x))]
          = E_q[log (p(x, z) / q_φ(z|x))]
          = E_q[log p(x, z)] + H[q_φ(z|x)]

L_VI(q_φ) is also called the variational lower-bound.

SLIDE 11

Gradient approximation for VI

Variational lower-bound: assume z ∼ q_φ ⇔ ǫ ∼ π(ǫ), z = f_φ(ǫ, x). Then

L_VI(q_φ) = E_q[log p(x, z)] + H[q_φ(z|x)]
          = E_π[log p(x, f_φ(ǫ, x))] + H[q_φ(z|x)]   // reparam. trick

SLIDE 12

Gradient approximation for VI

Variational lower-bound: assume z ∼ q_φ ⇔ ǫ ∼ π(ǫ), z = f_φ(ǫ, x). Then

L_VI(q_φ) = E_q[log p(x, z)] + H[q_φ(z|x)]
          = E_π[log p(x, f_φ(ǫ, x))] + H[q_φ(z|x)]   // reparam. trick

If you use gradient descent for optimisation, then you only need gradients!

SLIDE 13

Gradient approximation for VI

Variational lower-bound: assume z ∼ q_φ ⇔ ǫ ∼ π(ǫ), z = f_φ(ǫ, x). Then

L_VI(q_φ) = E_q[log p(x, z)] + H[q_φ(z|x)]
          = E_π[log p(x, f_φ(ǫ, x))] + H[q_φ(z|x)]   // reparam. trick

If you use gradient descent for optimisation, then you only need gradients! The gradient of the variational lower-bound:

∇_φ L_VI(q_φ) = E_π[∇_f log p(x, f_φ(ǫ, x))^T ∇_φ f_φ(ǫ, x)] + ∇_φ H[q_φ(z|x)]
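The reparameterisation-trick gradient above can be checked numerically in one dimension (all choices below are illustrative assumptions: q_φ = N(φ, 1) and log p(x, z) taken as −z²/2 up to a constant):

```python
import numpy as np

# Reparameterisation-trick gradient in a toy 1-D setting:
# q_phi(z) = N(phi, 1), so z = f_phi(eps) = phi + eps, eps ~ N(0, 1),
# and we take log p(x, z) = -z^2/2 (dropping constants). Then
# grad_phi E_pi[log p] = E_pi[(d/dz log p)(z) * (d/dphi f_phi)] = E[-z] = -phi.
rng = np.random.default_rng(0)
phi = 1.5
eps = rng.standard_normal(200_000)
z = phi + eps                  # f_phi(eps)
grad = np.mean(-z * 1.0)       # MC estimate of the gradient; analytic value -1.5
print(grad)
```

The entropy term still needs ∇_z log q(z|x), which is exactly what the rest of the talk addresses.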

SLIDE 14

Gradient approximation for VI

Variational lower-bound: assume z ∼ q_φ ⇔ ǫ ∼ π(ǫ), z = f_φ(ǫ, x). Then

L_VI(q_φ) = E_q[log p(x, z)] + H[q_φ(z|x)]
          = E_π[log p(x, f_φ(ǫ, x))] + H[q_φ(z|x)]   // reparam. trick

If you use gradient descent for optimisation, then you only need gradients! The gradient of the variational lower-bound:

∇_φ L_VI(q_φ) = E_π[∇_f log p(x, f_φ(ǫ, x))^T ∇_φ f_φ(ǫ, x)] + ∇_φ H[q_φ(z|x)]

The gradient of the entropy term:

∇_φ H[q_φ(z|x)] = −E_π[∇_f log q(f_φ(ǫ, x)|x)^T ∇_φ f_φ(ǫ, x)] − E_q[∇_φ log q_φ(z|x)],

where the crossed-out second term E_q[∇_φ log q_φ(z|x)] is 0. It remains to approximate ∇_z log q(z|x)!

SLIDE 15

Stein gradient estimator

Goal: approximate ∇x log q(x) for a given distribution q(x)

SLIDE 16

Stein gradient estimator

Goal: approximate ∇_x log q(x) for a given distribution q(x)

Stein's identity: define h(x), a (column vector) test function satisfying the boundary condition lim_{x→∞} q(x)h(x) = 0. Then we can derive Stein's identity using integration by parts:

E_q[h(x) ∇_x log q(x)^T + ∇_x h(x)] = 0
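Stein's identity is easy to verify by Monte Carlo in a case where everything is known in closed form (the choices q = N(0, 1) and h(x) = x below are illustrative):

```python
import numpy as np

# Monte Carlo check of Stein's identity E_q[h(x) d/dx log q(x) + h'(x)] = 0
# for q = N(0, 1), where d/dx log q(x) = -x, with test function h(x) = x
# (which satisfies the boundary condition lim_{x->inf} q(x) h(x) = 0).
rng = np.random.default_rng(0)
x = rng.standard_normal(500_000)
lhs = np.mean(x * (-x) + 1.0)   # h(x) * score(x) + h'(x), averaged over q
print(lhs)                      # close to 0
```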

SLIDE 17

Stein gradient estimator

Goal: approximate ∇_x log q(x) for a given distribution q(x)

Stein's identity: define h(x), a (column vector) test function satisfying the boundary condition lim_{x→∞} q(x)h(x) = 0. Then we can derive Stein's identity using integration by parts:

E_q[h(x) ∇_x log q(x)^T + ∇_x h(x)] = 0

Invert Stein's identity to obtain ∇_x log q(x)!

SLIDE 18

Stein gradient estimator (kernel based)

Goal: approximate ∇_x log q(x) for a given distribution q(x). Main idea: invert Stein's identity

E_q[h(x) ∇_x log q(x)^T + ∇_x h(x)] = 0

  • 1. Monte Carlo (MC) approximation to Stein's identity:

−(1/K) ∑_{k=1}^K h(x_k) ∇_{x_k} log q(x_k)^T + err = (1/K) ∑_{k=1}^K ∇_{x_k} h(x_k),  x_k ∼ q(x)

SLIDE 19

Stein gradient estimator (kernel based)

Goal: approximate ∇_x log q(x) for a given distribution q(x). Main idea: invert Stein's identity

E_q[h(x) ∇_x log q(x)^T + ∇_x h(x)] = 0

  • 1. Monte Carlo (MC) approximation to Stein's identity:

−(1/K) ∑_{k=1}^K h(x_k) ∇_{x_k} log q(x_k)^T + err = (1/K) ∑_{k=1}^K ∇_{x_k} h(x_k),  x_k ∼ q(x)

  • 2. Rewrite the MC equation in matrix form: denoting

H = (h(x_1), · · · , h(x_K)),  ∇̄_x h = (1/K) ∑_{k=1}^K ∇_{x_k} h(x_k),  G := (∇_{x_1} log q(x_1), · · · , ∇_{x_K} log q(x_K))^T,

then −(1/K) H G + err = ∇̄_x h.

SLIDE 20

Stein gradient estimator (kernel based)

Goal: approximate ∇_x log q(x) for a given distribution q(x). Main idea: invert Stein's identity

E_q[h(x) ∇_x log q(x)^T + ∇_x h(x)] = 0

Matrix form (MC): −(1/K) H G + err = ∇̄_x h.

  • 3. Now solve a ridge regression problem:

Ĝ_V^Stein := arg min_{Ĝ ∈ R^{K×d}} ||∇̄_x h + (1/K) H Ĝ||_F² + (η/K²) ||Ĝ||_F²

SLIDE 21

Stein gradient estimator (kernel based)

Goal: approximate ∇_x log q(x) for a given distribution q(x). Main idea: invert Stein's identity

E_q[h(x) ∇_x log q(x)^T + ∇_x h(x)] = 0

Matrix form (MC): −(1/K) H G + err = ∇̄_x h.

  • 3. Now solve a ridge regression problem:

Ĝ_V^Stein := arg min_{Ĝ ∈ R^{K×d}} ||∇̄_x h + (1/K) H Ĝ||_F² + (η/K²) ||Ĝ||_F²

Analytic solution:

Ĝ_V^Stein = −(K + ηI)^{−1} ⟨∇, K⟩,

with K := H^T H, K_ij = K(x_i, x_j) := h(x_i)^T h(x_j), and ⟨∇, K⟩ := K H^T ∇̄_x h, i.e. ⟨∇, K⟩_ij = ∑_{k=1}^K ∇_{x_k^j} K(x_i, x_k). (Here K is overloaded as both the kernel matrix and the sample count, as on the slide.)
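A minimal sketch of the analytic solution above, assuming an RBF kernel with a median-heuristic bandwidth (the kernel choice and bandwidth rule are assumptions, not fixed by the slides), tested on samples from a standard Gaussian where the true score is −x:

```python
import numpy as np

# Kernelised Stein gradient estimator: G_hat = -(K + eta I)^{-1} <nabla, K>,
# with K_ij = K(x_i, x_j) an RBF kernel and
# <nabla, K>_ij = sum_k d/dx_k^j K(x_i, x_k).
def stein_gradient_estimator(x, eta=0.1):
    """x: (K, d) samples from q; returns (K, d) estimates of grad_x log q."""
    Ksamp, d = x.shape
    diff = x[:, None, :] - x[None, :, :]        # (K, K, d): x_i - x_k
    sq_dist = np.sum(diff ** 2, axis=-1)        # ||x_i - x_k||^2
    sigma2 = 0.5 * np.median(sq_dist) + 1e-12   # median heuristic (assumed)
    Kmat = np.exp(-sq_dist / (2.0 * sigma2))    # kernel matrix
    # d/dx_k^j K(x_i, x_k) = K(x_i, x_k) * (x_i - x_k)_j / sigma2, summed over k:
    nablaK = np.einsum('ik,ikj->ij', Kmat, diff) / sigma2
    return -np.linalg.solve(Kmat + eta * np.eye(Ksamp), nablaK)

rng = np.random.default_rng(0)
x = rng.standard_normal((200, 2))               # q = N(0, I): true score is -x
g_hat = stein_gradient_estimator(x)
mse = np.mean((g_hat - (-x)) ** 2)              # trivial zero estimate has MSE ~ 1
print(mse)
```

The estimator only touches the samples, never the density of q, so it applies directly to implicit distributions.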

SLIDE 22

Stein gradient estimator (kernel based)

Kernelised Stein discrepancy (KSD):

S²(q, q̂) = E_{x,x′∼q}[ĝ(x)^T K_{xx′} ĝ(x′) + ĝ(x)^T ∇_{x′} K_{xx′} + ∇_x K_{xx′}^T ĝ(x′) + Tr(∇_{x,x′} K_{xx′})],

with g(x) = ∇_x log q(x), ĝ(x) = ∇_x log q̂(x), K_{xx′} = K(x, x′). One can show that the V-statistic of the KSD is

S²_V(q, q̂) = (1/K²) Tr(Ĝ^T K Ĝ + 2 Ĝ^T ⟨∇, K⟩) + C

This means

Ĝ_V^Stein = arg min_{Ĝ ∈ R^{K×d}} S²_V(q, q̂) + (η/K²) ||Ĝ||_F²

SLIDE 23

Comparisons to existing approaches

[Diagram: existing approaches organised along parametric vs. non-parametric and direct vs. indirect axes: Stein (our approach), the KDE plug-in, denoising auto-encoders, score matching, and NN-based density estimators, with annotations "improved sample efficiency", "special case", and "kernelise".]

KDE plug-in estimator: Singh (1977). Score matching estimator: Hyvärinen (2005), Sasaki et al. (2014), Strathmann et al. (2015). Denoising auto-encoder: Vincent et al. (2008), Alain and Bengio (2014).

SLIDE 24

Comparisons to existing approaches

Compare to the denoising auto-encoder (DAE):

  • DAE: for x ∼ q(x), denoise x̂ = x + σǫ back to x (by minimising the ℓ2 loss, ǫ ∼ N(0, I))
  • When σ → 0, DAE*(x) ≈ x + σ² ∇_x log q(x)
  − unstable estimate: depends on σ
  + functional gradient in an RKHS: ||∇ DAE loss||²_H ∝ KSD

Vincent et al. (2008), Alain and Bengio (2014); with Wenbo Gong and José Miguel Hernández-Lobato
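The small-noise relation DAE*(x) ≈ x + σ²∇_x log q(x) can be checked numerically in a case where the optimal denoiser is known in closed form (the choice q = N(0, 1) below is an illustrative assumption; for it, the ℓ2-optimal denoiser is the posterior mean x̂/(1 + σ²)):

```python
import numpy as np

# For q = N(0, 1) and x_hat = x + sigma*eps, the l2-optimal denoiser is
# DAE*(x_hat) = x_hat / (1 + sigma^2) (the posterior mean of x given x_hat).
# The slide's relation says (DAE*(x_hat) - x_hat) / sigma^2 ~ grad log q(x_hat) = -x_hat.
sigma = 0.05
x_hat = np.linspace(-2.0, 2.0, 9)
dae_out = x_hat / (1.0 + sigma ** 2)          # optimal denoiser output
score_est = (dae_out - x_hat) / sigma ** 2    # score recovered from the DAE
true_score = -x_hat                           # exact score for q = N(0, 1)
max_err = np.max(np.abs(score_est - true_score))
print(max_err)                                # O(sigma^2), i.e. small
```

Shrinking σ tightens the approximation but, as the slide notes, makes the finite-sample estimate increasingly unstable.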

SLIDE 25

Example: entropy regularised GANs

  • Addressing mode collapse: train your generator using entropy regularisation:

min L_gen(p_gen) − H[p_gen]

  • L_gen(p_gen) is the generator loss of your favourite GAN method
  • Again, the gradient of H[p_gen] is approximated by the gradient estimators

SLIDE 26

Example: entropy regularised GANs

Significant improvement on sample diversity with the Stein approach

BEGAN: Berthelot et al. (2017)

SLIDE 27

Meta learning for posterior samplers

Many existing posterior samplers in the literature...

  • Which sampler should I use?
  • How do I tune the hyper-parameters?

Learn a sampler from data!

  • Want a general solution for similar tasks
  • Train on low-dim, generalise to high-dim

Salimans et al. (2015), Song et al. (2017), Levy et al. (2018)

SLIDE 28

Learning to learn

Meta-learning for SG-MCMC

  • Define a sampler with parameters φ:

z_{t+1} = z_t − η f_φ(z_t, H(·), ǫ),  ǫ ∼ N(0, I)

  • Run it on some training distributions π(z) ∝ exp[−H(z)], and provide learning signals to train φ
  • Once learned, apply this sampler to test distributions

Andrychowicz et al. (2016), Li and Malik (2017), Wichrowska et al. (2017), Li and Turner (2018)

SLIDE 29

The complete framework: Ma et al. NIPS 2015

  • Itô diffusion:

dz = µ(z) dt + √(2 D(z)) dW(t)    (1)

  • To make sure π(z) ∝ exp[−H(z)] is a stationary distribution:

µ(z) = −[D(z) + Q(z)] ∇_z H(z) + Γ(z),  Γ(z)_i = ∑_{j=1}^d ∂/∂z_j [D_ij(z) + Q_ij(z)]    (2)

  • D(z): diffusion matrix, PSD
  • Q(z): curl matrix, skew-symmetric
  • Γ(z): correction vector

Ma et al. (2015) completeness result: under some mild conditions, "any Itô diffusion that has the unique stationary distribution π(z) is governed by (1)+(2)".

SLIDE 30

The complete framework: Ma et al. NIPS 2015

[Cartoon: Langevin, Ma et al.'s framework, and "any SDE" arranged along a flexibility axis, with speech bubbles "Is it a valid sampler?", "I know how to pick the best one!", and "Any better solutions?".]

  • Searching for the best sampler within the complete framework:
  • Guaranteed to be correct
  • Retains the most flexibility
  • Only needs to learn how to parameterise the D(z) and Q(z) matrices!

SLIDE 31

Our recipe: dynamics design

  • Goal: train an SG-MCMC sampler to sample from p(θ|D) ∝ exp[−U(θ)]
  • We augment the state space with a momentum variable p:

z = (θ, p),  π(z) ∝ exp[−H(z)],  H(z) = U(θ) + (1/2) p^T p

  • Recall the complete recipe:

dz = −[D(z) + Q(z)] ∇_z H(z) dt + Γ(z) dt + √(2 D(z)) dW(t)
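To make the recipe concrete: taking constant Q_f = 1 and D_f = c (so Γ = 0) and discretising the dynamics recovers an SGHMC-style update. A minimal sketch, with an illustrative toy target U(θ) = θ²/2 (so the θ-marginal of π is N(0, 1)) and exact gradients; step size and friction are arbitrary choices:

```python
import numpy as np

# Discretised complete-recipe dynamics with constant Q_f = 1, D_f = c
# (hence Gamma = 0): this is underdamped Langevin / SGHMC-style sampling.
rng = np.random.default_rng(0)
eta, c = 0.05, 1.0
theta, p = 0.0, 0.0
samples = []
for t in range(200_000):
    theta = theta + eta * p                                  # theta step: Q_f * p
    grad_U = theta                                           # d/dtheta U, U = theta^2/2
    p = (p - eta * c * p                                     # friction: D_f * p
           - eta * grad_U                                    # force:    Q_f * grad U
           + np.sqrt(2.0 * eta * c) * rng.standard_normal()) # injected noise
    if t >= 1_000:                                           # discard burn-in
        samples.append(theta)
samples = np.asarray(samples)
print(samples.mean(), samples.var())  # roughly 0 and 1 (the target moments)
```

The meta-learning idea replaces the constants Q_f and D_f with learned functions f_φQ and f_φD of the state, as described on the next slides.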

SLIDE 32

Our recipe: dynamics design

  • Our recipe: with z = (θ, p),

Q(z) = [ 0, −Q_f(z); Q_f(z), 0 ],  D(z) = [ 0, 0; 0, D_f(z) ],  Γ(z) = (Γ_θ(z), Γ_p(z)),

Q_f(z) = diag[f_φQ(z)],  D_f(z) = diag[α f_φQ(z) ⊙ f_φQ(z) + f_φD(z) + c],  α, c > 0

  • Resulting update rules (rearrange terms, discretise, and use stochastic gradients):

θ_{t+1} = θ_t + η Q_f(z_t) p_t   (momentum SGD)   + η Γ_θ(z_t)   (correction)

p_{t+1} = p_t − η D_f(z_t) p_t   (friction)   − η Q_f(z_t) ∇_θ Ũ(θ_t) + η Γ_p(z_t) + √(Σ(z_t)) ǫ,  ǫ ∼ N(0, I)

Σ(z_t) = 2η D_f(z_t) − η² Q_f(z_t) B(θ_t) Q_f(z_t),  B(θ_t) = V[∇_θ Ũ(θ_t)]

SLIDE 33

Our recipe: dynamics design

Designing f_φQ(z) (responsible for the drift): the i-th element is defined as f_φQ,i(z) = β + f_φQ(Ũ(θ), p_i)

  • We want f_φQ(z) to depend on the energy landscape:
  • Fast traversal through low-density regions
  • Better exploration in high-density regions
  • But we don't want Γ_θ(z) to be too expensive! (Using ∇_θ U(θ) as an input here would lead to an extra second-derivative term ∇_θ ∇_θ U(θ) in Γ_θ(z).)

SLIDE 34

Our recipe: dynamics design

Designing f_φD(z) (responsible for the friction): the i-th element is defined as f_φD,i(z) = f_φD(Ũ(θ), p_i, ∂_θ_i Ũ(θ))

  • Γ_p(z) only requires computing ∇_p D_f(z)
  • ...so we can use the gradient information ∇_θ U(θ)
  • Prevent overshoot by "comparing" p and ∇_θ U(θ)

SLIDE 35

Our recipe: loss function design

Use the KL divergence KL[q(θ)||p(θ|D)] to define the loss. Define q(θ) implicitly: run parallel chains for several steps, then

  • Cross-chain loss: at time t, collect samples across chains
  • In-chain loss: for each chain, collect samples by thinning

The gradient of the KL is approximated by the Stein gradient estimator!

SLIDE 36

A toy example

  • trained on factorised Gaussians, tested on correlated Gaussians
  • manually injected Gaussian noise to the gradients

(and assume we don't know the noise variance B(θ))

SLIDE 37

Bayesian NN on MNIST

Goal: sample from the BNN posterior.

Training: the meta sampler is trained to sample from the posterior of a BNN (1 hidden layer, 20 hidden units, ReLU).

Three generalisation tests:

  • to bigger network architecture: 2-hidden layer MLP (40 units, ReLU)
  • to different activation function: 1-hidden layer MLP (20 units, Sigmoid)
  • to different dataset: train on MNIST 0-4, test on MNIST 5-9

Also consider long-time horizon generalisation

SLIDE 38

Bayesian NN on MNIST: speed improvements

[Figure: test error ("Network Generalization" and "Sigmoid Generalization" panels) and negative log-likelihood vs. epoch/iteration, comparing Adam, SGD-M, SGHMC, NNSGHMC, and SGLD.]

SLIDE 39

Bayesian NN on MNIST: long-time generalisation

[Figure: long-time generalisation results.]

SLIDE 40

Bayesian NN on MNIST: understanding the learned sampler

Q_f(z) = diag[f_φQ(z)],  D_f(z) = diag[α f_φQ(z) ⊙ f_φQ(z) + f_φD(z) + c]

  • f_φQ (left): nearly linear w.r.t. the energy (fast traversal, better exploration)
  • f_φD (middle): decreased friction around high-energy regions
  • f_φD (right): increased friction when gradient and momentum "disagree" (prevents overshoot)

SLIDE 41

Stein gradient estimator: summary

  • We derived a non-parametric gradient estimator
  • We used the gradient estimator for entropy-regularised GANs
  • We explored meta-learning for approximate inference

[Diagram: current methods (MCMC, VI + Gaussian, VI + implicit distributions) positioned along scalability and accuracy axes.]

SLIDE 42

Related work by colleagues

Doucet et al. (2013). Derivative-Free Estimation of the Score Vector and Observed Information Matrix with Application to State-Space Models. arXiv:1304.5768.

Shi et al. (2018). A Spectral Approach to Gradient Estimation for Implicit Distributions. ICML 2018.

Song et al. (2019). Sliced Score Matching: A Scalable Approach to Density and Score Estimation. UAI 2019.

Andrew Duncan's talk: Minimum Stein Discrepancy Estimators. This workshop.

Spectral Estimators for Gradient Fields of Log-Densities. This workshop.

SLIDE 43

Thank you!

  • Y. Li and R.E. Turner. Gradient Estimators for Implicit Models. ICLR 2018.
  • W. Gong*, Y. Li* and J.M. Hernández-Lobato. Meta-Learning for Stochastic Gradient MCMC. ICLR 2019.