Composing graphical models with neural networks for structured representations and fast inference
Matt Johnson, David Duvenaud, Alex Wiltschko, Bob Datta, Ryan Adams
[Depth video frames of mouse behavior (axes in mm), segmented into syllables such as pause, rear, and dart]
[1] Lee and Glass. A Nonparametric Bayesian Approach to Acoustic Model Discovery. ACL 2012.
[2] Lee. Discovering Linguistic Structures in Speech: Models and Applications. MIT Ph.D. Thesis 2014.
Phonetic segmentation example [1,2]: /b/ /ax/ /n/ /ae/ /n/ /ax/
[Depth video frames of mouse behavior; Alexander Wiltschko, Matthew Johnson, et al., Neuron 2015]
[Figure: frames of the depth video lie on an image manifold; each frame is summarized by low-dimensional manifold coordinates. Example syllables: rear, dart.]
[1] Srivastava, Mansimov, and Salakhutdinov. Unsupervised Learning of Video Representations using LSTMs. ICML 2015.
[2] Ranzato, Marc'Aurelio, et al. Video (language) modeling: a baseline for generative models of natural videos. Preprint 2015.
[3] Sutskever, Hinton, and Taylor. The Recurrent Temporal Restricted Boltzmann Machine. NIPS 2008.
Recurrent neural networks? [1,2,3]
[Figure 1. LSTM unit]
[Figure 2. LSTM Autoencoder Model: input frames v1, v2, v3 are encoded into a learned representation, then decoded to reconstructions v̂3, v̂2, v̂1]
Probabilistic graphical models? [4,5,6]
[4] Fox, Sudderth, Jordan, and Willsky. Bayesian nonparametric inference of switching dynamic linear models. IEEE TSP 2011.
[5] Johnson and Willsky. Bayesian nonparametric hidden semi-Markov models. JMLR 2013.
[6] Murphy. Machine learning: a probabilistic perspective. MIT Press 2012.
[Diagram: unsupervised learning vs. supervised learning]
Probabilistic graphical models
+ structured representations
+ priors and uncertainty
+ data and computational efficiency
– rigid assumptions may not fit
– feature engineering
– top-down inference

Deep learning
+ flexible
+ feature learning
+ recognition networks
– neural net "goo"
– difficult parameterization
– can require lots of data
Modeling idea: graphical models on latent variables, neural network models for observations
Application: learn syllable representation of behavior from video
Inference: recognition networks output conjugate potentials, then apply fast graphical model inference
Modeling idea: graphical models on latent variables, neural network models for observations
Generative model: switching linear dynamical system
discrete states: zt+1 ∼ π(zt), with transition distributions π = (π(1), π(2), π(3))
continuous states: xt+1 = A(zt) xt + B(zt) ut, with ut ∼ N(0, I) iid and per-state dynamics parameters A(1), A(2), A(3) and B(1), B(2), B(3)
[Graphical model: discrete chain z1, …, z7 driving continuous chain x1, …, x7]
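As a concrete illustration (not part of the original slides), here is a minimal numpy sketch of ancestral sampling from this switching linear dynamical system; the dimensions and parameter values are made up for the example.

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_slds(pi_init, pi, A, B, T):
    """Ancestral sampling from an SLDS:
    z_{t+1} ~ pi[z_t],  x_{t+1} = A[z_t] x_t + B[z_t] u_t,  u_t ~ N(0, I)."""
    K, n = A.shape[0], A.shape[1]
    z = np.zeros(T, dtype=int)
    x = np.zeros((T, n))
    z[0] = rng.choice(K, p=pi_init)
    x[0] = rng.standard_normal(n)
    for t in range(T - 1):
        z[t + 1] = rng.choice(K, p=pi[z[t]])        # discrete Markov step
        u = rng.standard_normal(B.shape[2])         # u_t ~ N(0, I)
        x[t + 1] = A[z[t]] @ x[t] + B[z[t]] @ u     # switching linear dynamics
    return z, x

# Toy parameters: K = 3 discrete states, 2-dimensional continuous state.
K, n = 3, 2
pi_init = np.full(K, 1.0 / K)
pi = 0.9 * np.eye(K) + 0.1 / (K - 1) * (1 - np.eye(K))   # sticky transitions
A = np.stack([0.95 * np.eye(n) for _ in range(K)])        # stable dynamics per state
B = np.stack([0.1 * np.eye(n) for _ in range(K)])
z, x = sample_slds(pi_init, pi, A, B, T=100)
```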
[Graphical model: global parameters θ; discrete states z1, …, z7; continuous states x1, …, x7; observations y1, …, y7]
yt | xt, γ ∼ N(µ(xt; γ), Σ(xt; γ))
[Neural network with parameters γ maps xt to µ(xt; γ) and diag(Σ(xt; γ))]
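For concreteness, a minimal sketch of such an observation model (the architecture is an assumption for illustration, not the authors' exact network): a small fully connected network with parameters γ maps a latent xt to a Gaussian mean and diagonal covariance.

```python
import numpy as np

def init_decoder(rng, n_latent, n_hidden, n_obs):
    """Randomly initialize decoder parameters gamma = (W1, b1, W2, b2)."""
    return dict(
        W1=0.1 * rng.standard_normal((n_latent, n_hidden)), b1=np.zeros(n_hidden),
        W2=0.1 * rng.standard_normal((n_hidden, 2 * n_obs)), b2=np.zeros(2 * n_obs))

def decode(gamma, x):
    """Map a latent x_t to the mean and diagonal covariance of p(y_t | x_t, gamma)."""
    h = np.tanh(x @ gamma["W1"] + gamma["b1"])
    out = h @ gamma["W2"] + gamma["b2"]
    mu, log_sigmasq = np.split(out, 2, axis=-1)
    return mu, np.exp(log_sigmasq)                 # mean and diag(Sigma)

rng = np.random.default_rng(0)
gamma = init_decoder(rng, n_latent=2, n_hidden=32, n_obs=10)
mu, sigmasq = decode(gamma, rng.standard_normal(2))
y = mu + np.sqrt(sigmasq) * rng.standard_normal(10)   # sample y_t | x_t, gamma
```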
[Graphical model with global parameters θ (latent dynamics) and γ (observation network), local latents zn, xn, and observations yn, with a plate over data points n]
p(θ): conjugate prior on global variables
p(x | θ): exponential family on local variables
p(γ): any prior on observation parameters
p(y | x, γ): neural network observation model
Example latent graphical models in this family: Gaussian mixture model [1], linear dynamical system [2], hidden Markov model [3], switching LDS [4], mixture of experts [5], IO-HMM [6], factorial HMM [7], driven LDS, canonical correlation analysis [8,9], admixture / LDA / NMF [10].

[1] Palmer, Wipf, Kreutz-Delgado, and Rao. Variational EM algorithms for non-Gaussian latent variable models. NIPS 2005.
[2] Ghahramani and Beal. Propagation algorithms for variational Bayesian learning. NIPS 2001.
[3] Beal. Variational algorithms for approximate Bayesian inference, Ch. 3. U of London Ph.D. Thesis 2003.
[4] Ghahramani and Hinton. Variational learning for switching state-space models. Neural Computation 2000.
[5] Jordan and Jacobs. Hierarchical Mixtures of Experts and the EM algorithm. Neural Computation 1994.
[6] Bengio and Frasconi. An Input Output HMM Architecture. NIPS 1995.
[7] Ghahramani and Jordan. Factorial Hidden Markov Models. Machine Learning 1997.
[8] Bach and Jordan. A probabilistic interpretation of Canonical Correlation Analysis. Tech. Report 2005.
[9] Archambeau and Bach. Sparse probabilistic projections. NIPS 2008.
[10] Hoffman, Bach, and Blei. Online learning for Latent Dirichlet Allocation. NIPS 2010.
Inference?
Mean field family: q(θ) q(x) ≈ p(θ, x | y), maximizing the evidence lower bound
L[ q(θ)q(x) ] ≜ E_{q(θ)q(x)}[ log ( p(θ, x, y) / ( q(θ) q(x) ) ) ]
with natural parameters q(θ) ↔ ηθ and q(x) ↔ ηx.

Conjugate case:
p(x | θ) is a linear dynamical system
p(y | x, θ) is linear-Gaussian
p(θ) is a conjugate prior
In terms of the natural parameters, write
L(ηθ, ηx) ≜ E_{q(θ)q(x)}[ log ( p(θ, x, y) / ( q(θ) q(x) ) ) ]
Proposition (natural gradient SVI of Hoffman et al. 2013)
With η∗x(ηθ) ≜ arg max_{ηx} L(ηθ, ηx) and LSVI(ηθ) ≜ L(ηθ, η∗x(ηθ)), the natural gradient is
∇̃ LSVI(ηθ) = ηθ⁰ + E_{q∗(x)}[ (txy(x, y), 1) ] − ηθ
With N observation sequences (plate over n = 1, …, N), the same assumptions and objective L(ηθ, ηx) apply, and the natural gradient sums over the data:
Proposition (natural gradient SVI of Hoffman et al. 2013)
∇̃ LSVI(ηθ) = ηθ⁰ + Σ_{n=1}^{N} E_{q∗(xn)}[ (txy(xn, yn), 1) ] − ηθ
where η∗x(ηθ) ≜ arg max_{ηx} L(ηθ, ηx) and LSVI(ηθ) ≜ L(ηθ, η∗x(ηθ)).
Step 1: compute evidence potentials
Step 2: run fast message passing
Step 3: compute natural gradient

[1] Johnson and Willsky. Stochastic variational inference for Bayesian time series models. ICML 2014.
[2] Foti, Xu, Laird, and Fox. Stochastic variational inference for hidden Markov models. NIPS 2014.
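To illustrate the natural gradient update itself, here is a runnable toy sketch (not the authors' code) for the simplest conjugate case, a Gaussian mean with known variance. There are no local latent variables here, so Steps 1–2 are trivial and only the Step 3 update is shown; in the SLDS setting the evidence potentials and expected statistics would instead come from Kalman smoothing and HMM message passing.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy conjugate model: y_n ~ N(mu, sigma^2) with known sigma, prior mu ~ N(mu0, sigma0^2).
# Natural parameters of a Gaussian over mu: eta = (mu / var, -1 / (2 var)).
sigma, mu0, sigma0 = 1.0, 0.0, 10.0
eta_prior = np.array([mu0 / sigma0**2, -0.5 / sigma0**2])

N = 10_000
data = rng.normal(3.0, sigma, size=N)          # true mean is 3.0

def suff_stats(y):
    """Minibatch contribution to the global natural parameters."""
    return np.array([y.sum() / sigma**2, -0.5 * len(y) / sigma**2])

eta = eta_prior.copy()
for step in range(1, 501):
    batch = rng.choice(data, size=50)                                    # minibatch
    nat_grad = eta_prior + (N / len(batch)) * suff_stats(batch) - eta    # natural gradient
    rho = (step + 10.0) ** -0.7                                          # decaying step size
    eta = eta + rho * nat_grad                                           # SVI update

post_var = -0.5 / eta[1]
post_mean = eta[0] * post_var
print(post_mean, post_var)   # posterior mean near 3.0, posterior variance near sigma^2 / N
```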
Natural gradient SVI: q∗(x) ≜ arg max_{q(x)} L[ q(θ)q(x) ]
+ optimal local factor
+ exploits conjugate graph structure
+ natural gradients
– expensive for general observations

Variational autoencoders [1,2]: q∗(x) ≜ N(x | µ(y; φ), Σ(y; φ))
+ fast for general observations
– suboptimal local factor
– does all local inference
– no natural gradients

Structured VAEs: q∗(x) ≜ ?
± optimal given conjugate evidence
+ fast for general observations
+ exploits conjugate graph structure
+ natural gradients
[1] Kingma and Welling. Auto-encoding variational Bayes. ICLR 2014.
[2] Rezende, Mohamed, and Wierstra. Stochastic backpropagation and approximate inference in deep generative models. ICML 2014.
Inference: recognition networks output conjugate potentials, then apply fast graphical model inference
Factors and natural parameters: q(θ) ↔ ηθ, q(γ) ↔ ηγ, q(x) ↔ ηx
L[ q(θ)q(γ)q(x) ] ≜ E_{q(θ)q(γ)q(x)}[ log ( p(θ, γ, x) p(y | x, γ) / ( q(θ) q(γ) q(x) ) ) ]
In terms of natural parameters,
L(ηθ, ηγ, ηx) ≜ E_{q(θ)q(γ)q(x)}[ log ( p(θ, γ, x) p(y | x, γ) / ( q(θ) q(γ) q(x) ) ) ]

Surrogate objective: replace the non-conjugate likelihood term E_{q(γ)} log p(yt | xt, γ) with a recognition-network potential ψ(xt; yt, φ):
L̂(ηθ, ηx, φ) ≜ E_{q(θ)q(γ)q(x)}[ log ( p(θ, γ, x) exp{ψ(x; y, φ)} / ( q(θ) q(γ) q(x) ) ) ]
where ψ(x; y, φ) is a conjugate potential for p(x | θ).

η∗x(ηθ, φ) ≜ arg max_{ηx} L̂(ηθ, ηx, φ)
LSVAE(ηθ, ηγ, φ) ≜ L(ηθ, ηγ, η∗x(ηθ, φ))
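As an illustration of what a recognition network that outputs conjugate potentials might look like, here is a small sketch (the architecture and parameterization are assumptions, not taken from the slides): a network maps each yt to the natural parameters of a diagonal Gaussian potential on xt.

```python
import numpy as np

def init_recognition(rng, n_obs, n_hidden, n_latent):
    """Recognition network parameters phi."""
    return dict(
        W1=0.1 * rng.standard_normal((n_obs, n_hidden)), b1=np.zeros(n_hidden),
        W2=0.1 * rng.standard_normal((n_hidden, 2 * n_latent)), b2=np.zeros(2 * n_latent))

def recognition_potential(phi, y):
    """Map an observation y_t to Gaussian natural parameters (h, J) on x_t, so that
    psi(x_t; y_t, phi) = h^T x_t - 0.5 * x_t^T diag(J) x_t (up to a constant)."""
    hid = np.tanh(y @ phi["W1"] + phi["b1"])
    out = hid @ phi["W2"] + phi["b2"]
    h, log_J = np.split(out, 2, axis=-1)
    return h, np.exp(log_J)    # linear term and diagonal precision (J > 0 by construction)
```

These Gaussian potentials play the role of the conjugate evidence terms, so the graphical model machinery (message passing for the LDS/HMM prior) can be reused unchanged.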
Proposition (log evidence lower bound)
LSVAE(ηθ, ηγ, φ) ≤ max_{ηx} L(ηθ, ηγ, ηx) ≤ log p(y)   for all ηθ, ηγ
Fact (conjugate graphical models are easy)
The local variational parameter η∗x(ηθ, φ) is easy to compute.
Moreover, if ∃ φ ∈ Rᵐ with ψ(x; y, φ) = E_{q(γ)} log p(y | x, γ), the first bound is tight:
max_φ LSVAE(ηθ, ηγ, φ) = max_{ηx} L(ηθ, ηγ, ηx) ≤ log p(y)   for all ηθ, ηγ
Proposition (easy natural gradient)
∇̃ηθ LSVAE(ηθ, ηγ, φ) = ( ηθ⁰ + E_{q∗(x | φ)}[ (tx(x), 1) ] − ηθ ) + ( ∇ηx L(ηθ, ηγ, η∗x(ηθ, φ)), 0 )

Proposition (reparameterization trick)
Estimate ∇ηγ,φ LSVAE(ηθ, ηγ, φ) with samples γ̂ ∼ q(γ) and x̂ ∼ q∗(x | φ) via
LSVAE(ηθ, ηγ, φ) ≈ log p(y | x̂, γ̂) − KL( q(θ) q(γ) q∗(x | φ) ‖ p(θ, γ, x) )
Step 1: apply recognition network
Step 2: run fast PGM algorithms
Step 3: sample, compute flat gradients
Step 4: compute natural gradient
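To make the four steps concrete, here is a minimal runnable sketch using autograd (the library linked at the end of these slides); it is not the authors' implementation. The latent graphical model is reduced to an independent N(0, I) prior on each xn, so the "message passing" of Step 2 collapses to a one-node conjugate Gaussian update; in the SLDS application Step 2 would instead run Kalman smoothing and HMM message passing, and Step 4 would apply the natural gradient from the proposition above.

```python
import autograd.numpy as np
import autograd.numpy.random as npr
from autograd import grad

rng = npr.RandomState(0)
n_obs, n_latent, n_hidden = 10, 2, 32

def init(shapes):
    return [0.1 * rng.randn(*s) for s in shapes]

phi = init([(n_obs, n_hidden), (n_hidden, 2 * n_latent)])    # recognition network
gamma = init([(n_latent, n_hidden), (n_hidden, 2 * n_obs)])  # observation (decoder) network

def mlp(params, inp):
    W1, W2 = params                       # biases omitted for brevity
    return np.dot(np.tanh(np.dot(inp, W1)), W2)

def svae_elbo(params, y, eps):
    phi, gamma = params
    # Step 1: recognition network outputs a conjugate (Gaussian) potential on x.
    out = mlp(phi, y)
    h, log_J = out[:, :n_latent], out[:, n_latent:]
    J = np.exp(log_J)
    # Step 2: "message passing" -- here a one-node conjugate update against the N(0, I) prior.
    prec = 1.0 + J                        # posterior precision (diagonal)
    mu, var = h / prec, 1.0 / prec        # posterior mean and variance of q*(x)
    # Step 3: reparameterized sample and flat-gradient objective.
    x = mu + np.sqrt(var) * eps
    dec = mlp(gamma, x)
    y_mu, y_logvar = dec[:, :n_obs], dec[:, n_obs:]
    loglik = -0.5 * np.sum((y - y_mu) ** 2 / np.exp(y_logvar) + y_logvar + np.log(2 * np.pi))
    kl = 0.5 * np.sum(mu ** 2 + var - np.log(var) - 1.0)     # KL(q(x) || N(0, I))
    return loglik - kl

elbo_grad = grad(svae_elbo)

y = rng.randn(5, n_obs)                             # a minibatch of 5 observations
eps = rng.randn(5, n_latent)
g_phi, g_gamma = elbo_grad((phi, gamma), y, eps)    # flat gradients w.r.t. phi and gamma
# Step 4 (not shown): natural gradient update of the global latent-model parameters eta_theta.
```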
[Figure: learned correspondence between data space and latent space]
[Figure: observed data, latent states, and predictions over frame index]
[Figure: optimization with natural gradients vs. flat gradients]
Application: learn syllable representation of behavior from video
start rear
fall from rear
grooming
Discovery of Heterozygous Phenotypes in Ror1b Mice
Alexander Wiltschko, Matthew Johnson, et al., Neuron 2015.
… and high and low doses of each drug
from Alex Wiltschko preprint
Modeling idea: graphical models on latent variables, neural network models for observations
Application: learn syllable representation of behavior from video
Inference: recognition networks output conjugate potentials, then apply fast graphical model inference
[1] Hashimoto, Alvarez-Melis, and Jaakkola. Word, graph and manifold embedding from Markov processes. Preprint 2015.
[2] Grosse et al. Exploiting compositionality to explore a large space of model structures. UAI 2012.
[3] Duvenaud et al. Structure discovery in nonparametric regression through compositional kernel search. ICML 2013.
Future work
[Diagram: model complexity vs. capacity]
github.com/hips/autograd
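For reference, autograd (linked above) differentiates ordinary numpy code, e.g.:

```python
import autograd.numpy as np
from autograd import grad

tanh_grad = grad(np.tanh)
print(tanh_grad(1.0))   # 0.41997..., i.e. 1 - tanh(1)^2
```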