Composing graphical models with neural networks for structured representations and fast inference


SLIDE 1

Composing graphical models with neural networks for structured representations and fast inference

Matt Johnson, David Duvenaud, Alex Wiltschko, Bob Datta, Ryan Adams

SLIDE 2

SLIDE 3

pause rear


dart

[1] Lee and Glass. A Nonparametric Bayesian Approach to Acoustic Model Discovery. ACL 2012. [2] Lee. Discovering Linguistic Structures in Speech: Models and Applications. MIT Ph.D. Thesis 2014. [1,2]

/b/ /ax/ /n/ /ae/ /n/ /ax/

SLIDE 4

SLIDE 5

Alexander Wiltschko, Matthew Johnson, et al., Neuron 2015.

SLIDE 6

image manifold

SLIDE 7

image manifold depth video

SLIDE 8

image manifold depth video

SLIDE 9

image manifold depth video

SLIDE 10

image manifold depth video

SLIDE 11

image manifold depth video

SLIDE 12

image manifold depth video

SLIDE 13

rear dart

manifold coordinates image manifold depth video

SLIDE 14

[1] Srivastava, Mansimov, Salakhutdinov. Unsupervised learning of video representations using LSTMs. ICML 2015. [2] Ranzato et al. Video (language) modeling: a baseline for generative models of natural videos. Preprint 2015. [3] Sutskever, Hinton, and Taylor. The Recurrent Temporal Restricted Boltzmann Machine. NIPS 2008.

Recurrent neural networks?

Figure 1. LSTM unit

[Figure: LSTM autoencoder in which inputs v1, v2, v3 are encoded with weights W1 into a learned representation, then copied and decoded with weights W2 into reconstructions v̂3, v̂2, v̂1]

Figure 2. LSTM Autoencoder Model

[1,2,3]

Probabilistic graphical models? [4,5,6]

[4] Fox, Sudderth, Jordan, Willsky. Bayesian nonparametric inference of switching dynamic linear models. IEEE TSP 2011. [5] Johnson and Willsky. Bayesian nonparametric hidden semi-Markov models. JMLR 2013. [6] Murphy. Machine learning: a probabilistic perspective. MIT Press 2012.

SLIDE 15

SLIDE 16

SLIDE 17

SLIDE 18

SLIDE 19

SLIDE 20

SLIDE 21

unsupervised learning supervised learning

SLIDE 22

Probabilistic graphical models
  + structured representations
  + priors and uncertainty
  + data and computational efficiency
  – rigid assumptions may not fit
  – feature engineering
  – top-down inference

Deep learning
  + flexible
  + feature learning
  + recognition networks
  – neural net “goo”
  – difficult parameterization
  – can require lots of data

SLIDE 23

SLIDE 24

Modeling idea: graphical models on latent variables, neural network models for observations


Application: learn syllable representation of behavior from video

Inference: recognition networks output conjugate potentials, then apply fast graphical model inference

SLIDE 25

Modeling idea: graphical models on latent variables, neural network models for observations

SLIDE 26

Switching linear dynamical system, with per-state dynamics A^(k) and noise matrices B^(k), k = 1, 2, 3:

  π = ( π^(1), π^(2), π^(3) )                discrete transition distributions
  z_{t+1} ∼ π^(z_t)                          discrete latent states z_1, z_2, …, z_7
  x_{t+1} = A^(z_t) x_t + B^(z_t) u_t        continuous latent states, with u_t ∼ N(0, I) i.i.d.
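To make the generative process concrete, here is a minimal numpy sketch (not the talk's code) of sampling from such a switching LDS; the dimensions, sticky transition matrix, rotation-style dynamics, and noise scales are my own illustrative choices.

```python
# Sampling a toy switching linear dynamical system.
import numpy as np

K, T, D = 3, 200, 2                                # discrete states, timesteps, latent dimension
rng = np.random.default_rng(0)

pi = np.full((K, K), 0.05) + 0.85 * np.eye(K)      # sticky transition distributions pi^(k)
pi /= pi.sum(axis=1, keepdims=True)
angles = [0.05, 0.3, -0.2]
A = np.array([[[np.cos(a), -np.sin(a)],
               [np.sin(a),  np.cos(a)]] for a in angles])    # per-state dynamics A^(k)
B = np.array([0.1 * np.eye(D) for _ in range(K)])            # per-state noise matrices B^(k)

z = np.zeros(T, dtype=int)                         # discrete latent states z_t
x = np.zeros((T, D))                               # continuous latent states x_t
x[0] = rng.normal(size=D)
for t in range(T - 1):
    z[t + 1] = rng.choice(K, p=pi[z[t]])           # z_{t+1} ~ pi^(z_t)
    u = rng.normal(size=D)                         # u_t ~ N(0, I)
    x[t + 1] = A[z[t]] @ x[t] + B[z[t]] @ u        # x_{t+1} = A^(z_t) x_t + B^(z_t) u_t
```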

SLIDE 27

[Graphical model: parameters π, A^(k), B^(k) generating the discrete chain z_1, …, z_7 and the continuous states x_1, …, x_7]

SLIDE 28

[Graphical model: global parameters θ over the latent chain z_1, …, z_7 and continuous states x_1, …, x_7]

SLIDE 29

[Graphical model: global parameters θ; latent states z_1, …, z_7 and x_1, …, x_7; observations y_1, …, y_7]

SLIDE 30

Neural network observation model:

  y_t | x_t, γ ∼ N( μ(x_t; γ), Σ(x_t; γ) )

where a neural network with parameters γ maps x_t to the mean μ(x_t; γ) and the diagonal covariance diag(Σ(x_t; γ)).

[Graphical model: global parameters θ and γ; latent states z_1, …, z_7 and x_1, …, x_7; observations y_1, …, y_7]
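A hedged sketch of what such an observation model might look like: a small MLP with hypothetical layer sizes and randomly initialized weights maps each latent x_t to a mean and a diagonal covariance.

```python
# Toy decoder: x_t -> (mu(x_t; gamma), diag(Sigma(x_t; gamma))).
import numpy as np

def init_decoder(rng, latent_dim=2, hidden=32, obs_dim=10):
    scale = 0.1
    return {"W1": scale * rng.normal(size=(latent_dim, hidden)), "b1": np.zeros(hidden),
            "W2": scale * rng.normal(size=(hidden, 2 * obs_dim)), "b2": np.zeros(2 * obs_dim)}

def decode(gamma, x):
    """Return mu(x; gamma) and the diagonal of Sigma(x; gamma)."""
    h = np.tanh(x @ gamma["W1"] + gamma["b1"])
    out = h @ gamma["W2"] + gamma["b2"]
    mu, log_sigma_diag = np.split(out, 2, axis=-1)
    return mu, np.exp(log_sigma_diag)        # y_t | x_t, gamma ~ N(mu, diag(exp(log_sigma_diag)))

rng = np.random.default_rng(0)
gamma = init_decoder(rng)
mu, sigma_diag = decode(gamma, rng.normal(size=(5, 2)))   # means and diagonal variances for 5 latents
```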

SLIDE 31

[Graphical models: per-datapoint latents z_n, x_n and observations y_n, with global parameters θ and γ]

  p(θ)          conjugate prior on global variables
  p(x | θ)      exponential family on local latent variables
  p(γ)          any prior on observation parameters
  p(y | x, γ)   neural network observation model

SLIDE 32

Example models: Gaussian mixture model [1], linear dynamical system [2], hidden Markov model [3], switching LDS [4], mixture of experts [5], driven LDS [2], IO-HMM [6], factorial HMM [7], canonical correlation analysis [8,9], admixture / LDA / NMF [10].

[1] Palmer, Wipf, Kreutz-Delgado, and Rao. Variational EM algorithms for non-Gaussian latent variable models. NIPS 2005.
[2] Ghahramani and Beal. Propagation algorithms for variational Bayesian learning. NIPS 2001.
[3] Beal. Variational algorithms for approximate Bayesian inference, Ch. 3. University of London Ph.D. Thesis 2003.
[4] Ghahramani and Hinton. Variational learning for switching state-space models. Neural Computation 2000.
[5] Jordan and Jacobs. Hierarchical mixtures of experts and the EM algorithm. Neural Computation 1994.
[6] Bengio and Frasconi. An input output HMM architecture. NIPS 1995.
[7] Ghahramani and Jordan. Factorial hidden Markov models. Machine Learning 1997.
[8] Bach and Jordan. A probabilistic interpretation of canonical correlation analysis. Tech. report 2005.
[9] Archambeau and Bach. Sparse probabilistic projections. NIPS 2008.
[10] Hoffman, Bach, Blei. Online learning for latent Dirichlet allocation. NIPS 2010.

SLIDE 33


Inference?

SLIDE 34

Mean field variational inference for the conjugate case:

  q(θ) q(x) ≈ p(θ, x | y)

  L[ q(θ) q(x) ] ≜ E_{q(θ) q(x)} [ log p(θ, x, y) − log q(θ) q(x) ]

with natural-parameter coordinates q(θ) ↔ η_θ and q(x) ↔ η_x, where

  p(x | θ)      is a linear dynamical system
  p(y | x, θ)   is linear-Gaussian
  p(θ)          is a conjugate prior

SLIDE 35

  q(θ) q(x) ≈ p(θ, x | y),        L(η_θ, η_x) ≜ E_{q(θ) q(x)} [ log p(θ, x, y) − log q(θ) q(x) ]

  η*_x(η_θ) ≜ argmax_{η_x} L(η_θ, η_x),        L_SVI(η_θ) ≜ L(η_θ, η*_x(η_θ))

Proposition (natural gradient SVI of Hoffman et al. 2013):

  ∇̃ L_SVI(η_θ) = η_θ^0 + E_{q*(x)} [ (t_{xy}(x, y), 1) ] − η_θ

where p(x | θ) is a linear dynamical system, p(y | x, θ) is linear-Gaussian, and p(θ) is a conjugate prior.

SLIDE 36

With N independent sequences, q(θ) ∏_n q(x_n) ≈ p(θ, x | y):

  L(η_θ, η_x) ≜ E_{q(θ) q(x)} [ log p(θ, x, y) − log q(θ) q(x) ]

  η*_x(η_θ) ≜ argmax_{η_x} L(η_θ, η_x),        L_SVI(η_θ) ≜ L(η_θ, η*_x(η_θ))

Proposition (natural gradient SVI of Hoffman et al. 2013):

  ∇̃ L_SVI(η_θ) = η_θ^0 + Σ_{n=1}^{N} E_{q*(x_n)} [ (t_{xy}(x_n, y_n), 1) ] − η_θ

where p(x | θ) is a linear dynamical system, p(y | x, θ) is linear-Gaussian, and p(θ) is a conjugate prior.
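To make the update concrete, here is a hedged sketch of the natural-gradient step for the simplest conjugate model I can keep runnable: a two-component Gaussian mixture with known weights and known observation precision, standing in for the LDS on the slide. All parameters and data are illustrative.

```python
# Natural-gradient SVI sketch for a toy conjugate model (2-component Gaussian
# mixture, known weights and precision), not the talk's LDS.
# q(mu_k) is kept in natural form: q(mu_k) ∝ exp(h_k mu - J_k mu^2 / 2).
import numpy as np

rng = np.random.default_rng(0)
lam = 4.0                               # known observation precision
log_pi = np.log(np.array([0.5, 0.5]))   # known mixture weights
h0, J0 = 0.0, 0.1                       # conjugate Gaussian prior on each mu_k, natural form

y = np.concatenate([rng.normal(-2.0, 0.5, size=100), rng.normal(2.0, 0.5, size=100)])

h = np.array([-0.5, 0.5])               # initial natural parameters of q(mu_1), q(mu_2)
J = np.array([0.1, 0.1])

for step in range(200):
    rho = 1.0 / (10 + step)             # step-size schedule
    # Optimal local factors q*(x_n): responsibilities under the expected log-likelihood.
    Emu, Emu2 = h / J, 1.0 / J + (h / J) ** 2
    logr = log_pi - 0.5 * lam * (y[:, None] ** 2 - 2 * y[:, None] * Emu + Emu2)
    r = np.exp(logr - logr.max(axis=1, keepdims=True))
    r /= r.sum(axis=1, keepdims=True)
    # Expected sufficient statistics, summed over the data.
    stats_h = lam * (r * y[:, None]).sum(axis=0)
    stats_J = lam * r.sum(axis=0)
    # Natural gradient = prior + expected stats - current natural parameters, so a
    # step of size rho is a convex combination with the full conjugate update.
    h = (1 - rho) * h + rho * (h0 + stats_h)
    J = (1 - rho) * J + rho * (J0 + stats_J)

print("posterior means E[mu_k]:", h / J)   # should land near -2 and +2
```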

SLIDE 37

Step 1: compute evidence potentials

[1] Johnson and Willsky. Stochastic variational inference for Bayesian time series models. ICML 2014. [2] Foti, Xu, Laird, and Fox. Stochastic variational inference for hidden Markov models. NIPS 2014.

SLIDE 38

Step 1: compute evidence potentials


SLIDE 39

Step 1: compute evidence potentials. Step 2: run fast message passing. Step 3: compute natural gradient.

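The message-passing step is ordinary conjugate inference in the latent chain. As an illustration (not the talk's implementation), a scalar Kalman filter / RTS smoother over made-up evidence potentials, the kind of quantities step 1 would supply, can be written in a few lines.

```python
# A hedged scalar sketch of "run fast message passing" for an LDS
# x_{t+1} = a x_t + noise, given per-timestep Gaussian evidence potentials
# (pseudo-observation means and variances); all numbers are made up.
import numpy as np

def kalman_smoother(a, q, obs_mean, obs_var, m0=0.0, v0=1.0):
    T = len(obs_mean)
    fm, fv = np.zeros(T), np.zeros(T)          # filtered means and variances
    pm, pv = m0, v0                            # predictive mean and variance
    for t in range(T):                         # forward pass (filtering)
        k = pv / (pv + obs_var[t])             # gain against the evidence potential
        fm[t] = pm + k * (obs_mean[t] - pm)
        fv[t] = (1 - k) * pv
        pm, pv = a * fm[t], a * a * fv[t] + q  # predict the next timestep
    sm, sv = fm.copy(), fv.copy()
    for t in range(T - 2, -1, -1):             # backward pass (RTS smoothing)
        g = a * fv[t] / (a * a * fv[t] + q)
        sm[t] = fm[t] + g * (sm[t + 1] - a * fm[t])
        sv[t] = fv[t] + g * g * (sv[t + 1] - (a * a * fv[t] + q))
    return sm, sv                              # marginals of q*(x_t), ready for expected statistics

sm, sv = kalman_smoother(a=0.95, q=0.1,
                         obs_mean=np.sin(np.linspace(0.0, 6.0, 50)),
                         obs_var=np.full(50, 0.5))
```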

SLIDE 40

Natural gradient SVI:
  q*(x) ≜ argmax_{q(x)} L[ q(θ) q(x) ]
  + optimal local factor
  + exploits conjugate graph structure
  + natural gradients
  – expensive for general observation models

Variational autoencoders [1,2]:
  q*(x) ≜ N( x | μ(y; φ), Σ(y; φ) )
  + fast for general observation models
  – suboptimal local factor
  – does all local inference
  – no natural gradients

Structured VAEs:
  q*(x) ≜ ?
  ± optimal given conjugate evidence
  + fast for general observation models
  + exploits conjugate graph structure
  + natural gradients

[1] Kingma and Welling. Auto-encoding variational Bayes. ICLR 2014.
[2] Rezende, Mohamed, and Wierstra. Stochastic backpropagation and approximate inference in deep generative models. ICML 2014.
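For contrast with the structured-VAE column, a hedged sketch of the middle column: a plain VAE recognition network with parameters φ (hypothetical shapes and weights) emits the local factor q*(x) directly, and a latent is drawn by reparameterization.

```python
# Plain-VAE recognition network sketch: q*(x | y) = N(mu(y; phi), diag(var(y; phi))).
import numpy as np

def recognize(phi, y):
    h = np.tanh(y @ phi["W1"] + phi["b1"])
    out = h @ phi["W2"] + phi["b2"]
    mu, log_var = np.split(out, 2, axis=-1)
    return mu, np.exp(log_var)

rng = np.random.default_rng(1)
obs_dim, hidden, latent_dim = 10, 32, 2
phi = {"W1": 0.1 * rng.normal(size=(obs_dim, hidden)), "b1": np.zeros(hidden),
       "W2": 0.1 * rng.normal(size=(hidden, 2 * latent_dim)), "b2": np.zeros(2 * latent_dim)}

y = rng.normal(size=(5, obs_dim))
mu, var = recognize(phi, y)
x_sample = mu + np.sqrt(var) * rng.normal(size=mu.shape)   # reparameterized draw from q*(x | y)
```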

SLIDE 41

Inference: recognition networks output conjugate potentials, then apply fast graphical model inference

SLIDE 42

Mean field family q(θ) q(γ) q(x), with q(θ) ↔ η_θ, q(γ) ↔ η_γ, q(x) ↔ η_x:

  L[ q(θ) q(γ) q(x) ] ≜ E_{q(θ) q(γ) q(x)} [ log p(θ, γ, x) p(y | x, γ) − log q(θ) q(γ) q(x) ]

[Graphical models: global parameters θ and γ, per-datapoint latents x_n, observations y_n]

SLIDE 43

  L(η_θ, η_γ, η_x) ≜ E_{q(θ) q(γ) q(x)} [ log p(θ, γ, x) p(y | x, γ) − log q(θ) q(γ) q(x) ]

Replace the neural-net likelihood term E_{q(γ)} log p(y_t | x_t, γ) with a recognition potential ψ(x_t; y_t, φ) to get a surrogate objective:

  L̂(η_θ, η_x, φ) ≜ E_{q(θ) q(γ) q(x)} [ log p(θ, γ, x) exp{ψ(x; y, φ)} − log q(θ) q(γ) q(x) ]

  η*_x(η_θ, φ) ≜ argmax_{η_x} L̂(η_θ, η_x, φ)

  L_SVAE(η_θ, η_γ, φ) ≜ L(η_θ, η_γ, η*_x(η_θ, φ))

where ψ(x; y, φ) is a conjugate potential for p(x | θ).
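A hedged sketch of what these recognition potentials might look like for a scalar-latent LDS: a network with hypothetical sizes outputs per-timestep natural parameters (h_t, J_t) of a Gaussian potential, which conjugate message passing (for instance the Kalman smoother sketched above) then combines with the latent dynamics to obtain q*(x).

```python
# Recognition network emitting conjugate Gaussian potentials psi(x_t; y_t, phi)
# in natural form, i.e. exp{h_t x_t - J_t x_t^2 / 2}; sizes and weights are hypothetical.
import numpy as np

def recognition_potentials(phi, y):
    hidden = np.tanh(y @ phi["W1"] + phi["b1"])
    out = hidden @ phi["W2"] + phi["b2"]
    h, log_J = np.split(out, 2, axis=-1)
    J = np.exp(log_J)                      # positive precision keeps the potential Gaussian
    return h, J                            # equivalently a pseudo-observation h/J with variance 1/J

rng = np.random.default_rng(2)
obs_dim, hidden_dim = 10, 32
phi = {"W1": 0.1 * rng.normal(size=(obs_dim, hidden_dim)), "b1": np.zeros(hidden_dim),
       "W2": 0.1 * rng.normal(size=(hidden_dim, 2)), "b2": np.zeros(2)}

y = rng.normal(size=(50, obs_dim))         # one observation vector per timestep
h, J = recognition_potentials(phi, y)      # per-timestep potentials to feed into message passing
```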

SLIDE 44

Proposition (log evidence lower bound)
  L_SVAE(η_θ, η_γ, φ)  ≤  max_{η_x} L(η_θ, η_γ, η_x)  ≤  log p(y)        for all η_θ, η_γ, φ,
  and if ∃ φ ∈ R^m with ψ(x; y, φ) = E_{q(γ)} log p(y | x, γ), then
  max_φ L_SVAE(η_θ, η_γ, φ)  =  max_{η_x} L(η_θ, η_γ, η_x)               for all η_θ, η_γ.

Fact (conjugate graphical models are easy)
  The local variational parameter η*_x(η_θ, φ) is easy to compute.

Proposition (easy natural gradient)
  ∇̃_{η_θ} L_SVAE(η_θ, η_γ, φ) = ( η_θ^0 + E_{q*(x | φ)} [ (t_x(x), 1) ] − η_θ ) + ( ∇_{η_x} L(η_θ, η_γ, η*_x(η_θ, φ)), 0 )

Proposition (reparameterization trick)
  Estimate ∇_{η_γ, φ} L_SVAE(η_θ, η_γ, φ) with samples γ̂ ∼ q(γ) and x̂ ∼ q*(x | φ) via
  L_SVAE(η_θ, η_γ, φ) ≈ log p(y | x̂, γ̂) − KL( q(θ) q(γ) q*(x | φ) ‖ p(θ, γ, x) )
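A toy illustration of the reparameterization-trick estimator in the last proposition, stripped down to a standard-normal prior on x, a diagonal Gaussian q*(x | φ), and a fixed stand-in decoder; the θ and γ terms are omitted, so this shows only the shape of the computation, not the full bound.

```python
# Monte Carlo estimate of a simplified bound: log p(y | x_hat) - KL(q*(x | phi) || N(0, I)).
import numpy as np

def gaussian_kl(mu, var):
    """KL( N(mu, diag(var)) || N(0, I) ), summed over dimensions."""
    return 0.5 * np.sum(var + mu ** 2 - 1.0 - np.log(var))

def log_lik(y, x):
    """Stand-in for log p(y | x, gamma): unit-variance Gaussian centered at x."""
    return -0.5 * np.sum((y - x) ** 2) - 0.5 * y.size * np.log(2.0 * np.pi)

rng = np.random.default_rng(3)
mu, var = np.array([0.3, -0.1]), np.array([0.5, 0.8])   # parameters of q*(x | phi)
y = np.array([0.2, 0.0])

eps = rng.normal(size=mu.shape)
x_hat = mu + np.sqrt(var) * eps                          # reparameterized sample x_hat ~ q*(x | phi)
bound_estimate = log_lik(y, x_hat) - gaussian_kl(mu, var)
```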

SLIDE 45

Step 1: apply recognition network

SLIDE 46

Step 1: apply recognition network

SLIDE 47

Step 1: apply recognition network. Step 2: run fast PGM algorithms. Step 3: sample, compute flat gradients. Step 4: compute natural gradient.
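Putting the four steps together in the smallest conjugate setting I can keep runnable: i.i.d. scalar latents x_n ~ N(θ, 1) with a conjugate Gaussian prior on θ, a fixed linear recognition network, and a stand-in decoder. This is an illustrative toy, not the talk's model, and the flat gradients of step 3 are only noted (autodiff would supply them), not computed.

```python
# Toy end-to-end sketch of the four SVAE steps; every modeling choice is a stand-in.
import numpy as np

rng = np.random.default_rng(0)
N, rho = 200, 0.5
y = 3.0 + 0.5 * rng.normal(size=N)                # data

h0, J0 = 0.0, 0.01                                # conjugate prior on theta in natural form
h, J = 0.0, 0.01                                  # current q(theta) natural parameters
phi = np.array([1.0, 0.0, 0.0])                   # fixed recognition net: h_n = phi0*y_n + phi1, log J_n = phi2

for it in range(50):
    # Step 1: apply recognition network -> conjugate potentials psi(x_n; y_n, phi).
    h_n = phi[0] * y + phi[1]
    J_n = np.exp(phi[2]) * np.ones(N)
    # Step 2: fast PGM inference: combine E_q[log p(x_n | theta)] (natural params
    # (E[theta], 1)) with the potentials to get q*(x_n) in closed form.
    Etheta = h / J
    post_J = 1.0 + J_n
    post_mu = (Etheta + h_n) / post_J
    # Step 3: sample by reparameterization and evaluate the stand-in decoder
    # log-likelihood; flat gradients wrt phi (and gamma) would come from autodiff.
    x_hat = post_mu + rng.normal(size=N) / np.sqrt(post_J)
    loglik = -0.5 * np.sum((y - x_hat) ** 2)
    # Step 4: natural-gradient step on the global natural parameters of q(theta).
    stats_h, stats_J = post_mu.sum(), float(N)    # expected sufficient statistics under q*(x)
    h = (1 - rho) * h + rho * (h0 + stats_h)
    J = (1 - rho) * J + rho * (J0 + stats_J)

print("E[theta] =", h / J, " last MC log-lik =", round(loglik, 1))   # E[theta] settles near the data mean
```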

SLIDE 48

SLIDE 49

SLIDE 50

SLIDE 51

data space latent space

SLIDE 52

data frame index predictions latent states

SLIDE 53

natural gradient flat gradient

SLIDE 54

Application: learn syllable representation of behavior from video

SLIDE 55

SLIDE 56

SLIDE 57

SLIDE 58

SLIDE 59

start rear

SLIDE 60

fall from rear

SLIDE 61

grooming

SLIDE 62

Discovery of Heterozygous Phenotypes in Ror1b Mice

Alexander Wiltschko, Matthew Johnson, et al., Neuron 2015.

SLIDE 63

… and high and low doses of each drug

from Alex Wiltschko preprint

SLIDE 64

Modeling idea: graphical models on latent variables, neural network models for observations


Application: learn syllable representation of behavior from video

Inference: recognition networks output conjugate potentials, then apply fast graphical model inference

SLIDE 65

Limitations and future work

  • How expressive is latent linear structure?
  • word embeddings [1], analogical reasoning in image models
  • SVAE can use nonlinear latent structure

future work
  • model-based reinforcement learning
  • automatic structure search [2,3]
  • semi-supervised applications

complexity / capacity
  • PGMs get complicated
  • SVAE keeps complexity modular

[1] Hashimoto, Alvarez-Melis, and Jaakkola. Word, graph and manifold embedding from Markov processes. Preprint 2015.
[2] Grosse et al. Exploiting compositionality to explore a large space of model structures. UAI 2012.
[3] Duvenaud et al. Structure discovery in nonparametric regression through compositional kernel search. ICML 2013.

SLIDE 66

github.com/hips/autograd
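For reference, a minimal example of the autograd library linked above (numpy wrapped so that gradients can be taken automatically), presumably the autodiff tool underlying the implementation.

```python
# Minimal autograd usage: differentiate a plain numpy function.
import autograd.numpy as np    # thinly wrapped numpy
from autograd import grad      # builds gradient functions

def tanh(x):
    return (1.0 - np.exp(-2.0 * x)) / (1.0 + np.exp(-2.0 * x))

grad_tanh = grad(tanh)         # function computing d/dx tanh(x)
print(grad_tanh(1.0))          # ~0.4199743
print(1.0 - tanh(1.0) ** 2)    # analytic derivative agrees
```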

SLIDE 67

Thanks!