Advanced inference in probabilistic programs
Brooks Paige
Inference thus far:
- Likelihood weighting / importance sampling
- MCMC (single-dimension, coded by hand)
- Lightweight Metropolis-Hastings (update one random choice at a time, by re-running the remainder of the program)
How can we make inference more computationally efficient?
This session covers:
- Sequential Monte Carlo, a building block for an inference algorithm that can succeed in models with higher-dimensional latent spaces
- Asynchronous SMC
- What distribution are we simulating from in these methods? Can we learn importance sampling proposals automatically?
- Anglican (JW will show you this afternoon)
Anglican separates model code from inference code: what inference algorithms can we develop and implement using this interface? Model code interacts with inference only through

(sample ...)
(observe ...)

and inference is invoked with

(doquery :algorithm model [args] options)
If we write programs in such a way that we see early, incremental evidence, then we can use more efficient inference algorithms: sample statements which come after observe statements can be informed by the data.
(defquery monolithic-observe []
  ... ;; many sample statements
  (sample ...)
  (sample ...)
  (sample ...)
  ... ;; single observe /
  ;; conditioning statement
  ;; at the end
  (observe ...))

(loop ...
  ;; interleaved sample and
  ;; observe statements
  (sample ...)
  (observe ...)
  (recur ...))
[Figure: a state-space graphical model with latent states x0, x1, x2, x3, … and observations y1, y2, y3, …]

Place a massive observe statement at the end: there is no “feedback” until all random variables have been sampled.
[Figure: the same graphical model, with each observation yn attached to the latent states sampled so far]

Place observe statements as early as possible: does y1 have high probability given x0 and x1? Does y2 have high probability given x0, x1, and x2?
Incremental evidence == computational efficiency? Many models can incorporate evidence incrementally:
- models with discrete timesteps
- models with per-datapoint observes, such as mixture models and many multilevel Bayesian models
- models where the “canvas” is always visible and can be evaluated according to a fitness function at any time

Sequential Monte Carlo performs inference in state-space models: given a model of observations and latent state, we can estimate the latent state, predict future data, and estimate the marginal likelihood.
A state-space model unrolls in “time” (n), with latent “space” (x) states x_n and observations y_n:

p(x_{0:N}, y_{0:N}) = ∏_{n=0}^{N} g(y_n | x_{0:n}) f(x_n | x_{0:n−1})
Approximate the distribution using a weighted set of K particles,

p(x_{0:n} | y_{0:n}) ≈ ∑_{k=1}^{K} w_n^k δ_{x_{0:n}^{(k)}}(x_{0:n})

Each particle x_{0:n}^{(k)} is assigned an (unnormalized) weight W_n^k based on its likelihood, with normalized weights w_n^k ∝ W_n^k.

At each step the K particles are resampled according to their weights, then simulated forward. Particles with low weight are discarded, and particles with high weight are replicated: higher-weight particles are replicated more often, with the expected number of offspring M_n^k proportional to the weight W_n^k,

E[M_n^k | W_n^{1:K}] = W_n^k / W̄_n

where W̄_n is the average weight at step n. Iteratively, resampling and simulating forward advances all K particles from n = 1 to n = 2, n = 3, and so on.
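A minimal sketch of these steps, assuming a linear-Gaussian state-space model; the model, number of particles, and parameter values below are illustrative, not taken from the slides:

```python
import numpy as np

rng = np.random.default_rng(0)
K = 1000                    # number of particles
N = 25                      # number of timesteps
trans_std, obs_std = 1.0, 0.5

# Simulate a synthetic dataset from the joint p(x_{0:N}, y_{1:N})
x_true = np.cumsum(rng.normal(0.0, trans_std, size=N + 1))   # Gaussian random walk
y = x_true[1:] + rng.normal(0.0, obs_std, size=N)

def log_g(y_n, x_n):
    """Observation log-density g(y_n | x_n): Gaussian noise around the latent state."""
    return -0.5 * ((y_n - x_n) / obs_std) ** 2 - np.log(obs_std * np.sqrt(2 * np.pi))

# Bootstrap particle filter: simulate forward via f, weight by g, resample
particles = rng.normal(0.0, trans_std, size=K)     # x_0^(k) drawn from the prior
log_Z = 0.0                                        # running log marginal likelihood estimate
for n in range(N):
    particles = particles + rng.normal(0.0, trans_std, size=K)  # x_n^(k) ~ f(. | x_{n-1}^(k))
    log_W = log_g(y[n], particles)                               # unnormalized log-weights W_n^k
    m = log_W.max()
    log_Z += m + np.log(np.mean(np.exp(log_W - m)))              # log of (1/K) sum_k W_n^k
    w = np.exp(log_W - m)
    w /= w.sum()                                                 # normalized weights w_n^k
    ancestors = rng.choice(K, size=K, p=w)                       # multinomial resampling:
    particles = particles[ancestors]                             # E[# offspring of k] = K w_n^k

print("E[x_N | y_1:N] approx", particles.mean(), " log Z-hat:", log_Z)
```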
SMC in action: slowed down for clarity
In a probabilistic program, given the values drawn at sample statements, the rest of the execution is deterministic. Running the program, we encounter N observe statements {(g_i, φ_i, y_i)}_{i=1}^{N} and M sample statements {(f_j, θ_j)}_{j=1}^{M}, yielding the sampled values {x_j}_{j=1}^{M}.
Running the program defines an unnormalized density over the sampled values,

γ(x) ≜ p(x, y) = ∏_{i=1}^{N} g_i(y_i | φ_i) ∏_{j=1}^{M} f_j(x_j | θ_j)

where, writing x_j = x_1 × · · · × x_j for the sampled values so far (and x_{n_i} for those made before the i-th observe), both the densities and their parameters may depend on earlier samples:

γ(x) = p(x, y) = ∏_{i=1}^{N} g̃_i(y_i | φ_i(x_{n_i})) ∏_{j=1}^{M} f̃_j(x_j | θ_j(x_{j−1}))

[Figure: the sampled values x_1, …, x_6 of one execution, interleaved with the observations y_1, y_2, etc.]
Importance sampling from the prior (likelihood weighting): each execution k draws all of its sampled values from the prior,

q(x^k) = ∏_{j=1}^{M_k} f_j(x_j^k | θ_j^k)

so its weight is simply the product of the observe densities,

w(x^k) = γ(x^k) / q(x^k) = ∏_{i=1}^{N_k} g_i^k(y_i^k | φ_i^k),    W^k = w(x^k) / ∑_{ℓ=1}^{K} w(x^ℓ)

and posterior expectations are estimated as

Ê_π[R(x)] = ∑_{k=1}^{K} W^k R(x^k)
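As a concrete toy example of this scheme — the model, data, and constants below are assumptions for illustration, not from the slides — infer the mean of a Gaussian by running the generative program forward and weighting by the observe densities:

```python
import numpy as np

rng = np.random.default_rng(1)
y = np.array([0.9, 1.3, 1.1])     # observed data (illustrative)
K = 5000                          # number of independent program executions

def log_normal_pdf(v, mean, std):
    return -0.5 * ((v - mean) / std) ** 2 - np.log(std * np.sqrt(2 * np.pi))

# q(x^k): run the program forward, drawing every sample statement from its prior f
mu = rng.normal(0.0, 1.0, size=K)            # (sample ...): mu ~ Normal(0, 1)

# w(x^k) = gamma(x^k) / q(x^k): the product of the observe densities g
log_w = sum(log_normal_pdf(y_i, mu, 0.5) for y_i in y)   # (observe ...) terms

W = np.exp(log_w - log_w.max())
W /= W.sum()                                 # self-normalized weights W^k

# E_pi[R(x)] estimated as sum_k W^k R(x^k), here with R(mu) = mu
print("posterior mean of mu approx", np.sum(W * mu))
```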
Breaking the program at each observe statement yields a sequence of incremental targets,

γ_n(x̃_{1:n}) = ∏_{m=1}^{n} g(y_m | x̃_{1:m}) p(x̃_m | x̃_{1:m−1}),    π_n(x̃_{1:n}) = (1/Z_n) γ_n(x̃_{1:n})

where x̃_n is the subspace of x sampled before the n-th observe, with x̃_{1:n} = x̃_1 × · · · × x̃_n.
SMC can also be used as a building block inside MCMC (particle MCMC): “particle independent Metropolis-Hastings”, “conditional SMC”, and “ancestral sampling”.
Particle independent Metropolis-Hastings (PIMH): each sweep s proposes an entirely new particle set by running a fresh SMC sweep, and accepts it with probability

α^s_{PIMH} = min(1, Ẑ* / Ẑ^{s−1})

where the marginal likelihood estimate is computed from the particle weights,

Ẑ = ∏_{n=1}^{N} Ẑ_n = ∏_{n=1}^{N} (1/K) ∑_{k=1}^{K} w(x̃^k_{1:n})

Averaging over S sweeps gives the estimator

Ê_{PIMH}[R(x)] = (1/S) ∑_{s=1}^{S} ∑_{k=1}^{K} W^{s,k} R(x^{s,k})

[Figure: successive SMC sweeps with marginal likelihood estimates Ẑ_1, Ẑ_2, and a proposed Ẑ*]
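A sketch of the PIMH loop on a toy model. To keep it short, the inner sweep here is a single-observation importance sampler (i.e. SMC with N = 1), which still produces the unbiased Ẑ that the acceptance ratio relies on; the model, data, and constants are assumptions for illustration:

```python
import numpy as np

rng = np.random.default_rng(2)
y = 1.2          # a single observation (illustrative)
K = 100          # particles per sweep
S = 2000         # number of MCMC sweeps

def smc_sweep():
    """One sweep: K particles from the prior, weighted by the likelihood.
    Returns Z_hat and the sweep's weighted estimate sum_k W^k x^k."""
    x = rng.normal(0.0, 1.0, size=K)                       # prior over the latent x
    log_w = -0.5 * ((y - x) / 0.5) ** 2 - np.log(0.5 * np.sqrt(2 * np.pi))
    m = log_w.max()
    Z_hat = np.exp(m) * np.mean(np.exp(log_w - m))         # (1/K) sum_k w(x^k)
    W = np.exp(log_w - m)
    W /= W.sum()
    return Z_hat, np.sum(W * x)

Z_curr, est_curr = smc_sweep()                             # initialize with one sweep
estimates = []
for s in range(S):
    Z_prop, est_prop = smc_sweep()                         # propose an entirely new particle set
    if rng.uniform() < min(1.0, Z_prop / Z_curr):          # alpha_PIMH = min(1, Z*/Z_{s-1})
        Z_curr, est_curr = Z_prop, est_prop
    estimates.append(est_curr)

# E_PIMH[R(x)] with R(x) = x: average the per-sweep weighted estimates
print("posterior mean of x approx", np.mean(estimates))
```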
Simulating particles forward is embarrassingly parallel — (e.g.) parfor in MATLAB, or pmap in functional programming languages — but resampling (computing the child counts) is a global synchronous operation. Idea: assign each particle a number of offspring based only on the particles which have arrived so far.
[Figure: particles flowing asynchronously through f(x_n | x_{1:n−1}) and g(y_n | x_{1:n}) at n = 1, 2, 3, …]

Maintain a running average weight W̄_n^k at each n, based only on the first k particles to arrive — there is no need to wait for other particles. The expected number of offspring is

E[M_n^k | W_n^{1:k}] = W_n^k / W̄_n^k

Offspring are launched while other particles continue moving forward through the system, so the number of particles in the system may vary over the course of execution. Particles have identical weight after resampling, which is set to the current running average W̄_n^k.
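A small sketch of this asynchronous branching decision, under the simplified rule just described. The stochastic rounding of the weight ratio and the even weight split among replicated children are assumptions made for this sketch, and all of the queueing machinery of a real particle cascade implementation is omitted:

```python
import numpy as np

rng = np.random.default_rng(3)

class Branchpoint:
    """Running statistics kept at one observe statement n."""
    def __init__(self):
        self.k = 0          # particles that have arrived so far
        self.mean_w = 0.0   # running average weight

    def offspring(self, W):
        """Child count for a particle arriving with weight W.
        Expected count is W / mean_w, using only the particles seen so far."""
        self.k += 1
        self.mean_w += (W - self.mean_w) / self.k          # update running average
        ratio = W / self.mean_w
        if ratio < 1.0:
            # low weight: survive with probability `ratio`, weight set to the running average
            m = int(rng.uniform() < ratio)
            return m, self.mean_w
        # high weight: stochastically round the ratio; split the weight among the children
        m = int(np.floor(ratio)) + int(rng.uniform() < ratio - np.floor(ratio))
        return m, W / m

# Particles trickle in one at a time; no barrier, no global resampling step
bp = Branchpoint()
for W in rng.lognormal(mean=0.0, sigma=1.0, size=10):
    m, w_child = bp.offspring(W)
    print(f"weight {W:.2f} -> {m} offspring with weight {w_child:.2f}")
```

In both branches the expected total outgoing weight equals the incoming weight W.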
Among particle-based methods, comparing approaches by the raw speed of drawing samples, the particle cascade can be used as a replacement for SMC, with the added benefits of letting us keep running inference indefinitely and stop when satisfied with the current estimate, and of improved particle throughput and parallel scalability as compared to traditional SMC.
Can we compile away the runtime costs of inference? Inference is fundamentally not a feed-forward computation! Instead: learn to approximate optimal importance sampling proposals.
Importance sampling and SMC approximate the posterior as weighted samples:

p̂(x|y) = ∑_{k=1}^{K} W_k δ_{x_k}(x),    W_k = w(x_k) / ∑_{j=1}^{K} w(x_j)

with the samples x_k drawn from a proposal distribution q(x|λ).
Goal: posterior inference in generative models with latent variables x and observed variables y:
p(x, y) ≜ ∏_{i=1}^{N} p(x_i | pa(x_i)) ∏_{j=1}^{M} p(y_j | pa(y_j)),    w(x) = p(x, y) / q(x|λ)
Learning an importance sampling proposal for a single dataset: choose an approximating family q(x|λ). The target density is π(x) ≜ p(x|y), and in general q(x|λ) ≠ π(x). For a single dataset y, fit λ to learn an importance sampling proposal by solving

argmin_λ D_KL(π || q_λ)
A probabilistic model generates data; an inverse model generates latents. Can we learn how to sample from the inverse model?

[Figure: plate diagrams of a generative model and its inverse model, with latent weights w0, w1, w2, per-datapoint variables t_n, z_n, and inverse-model parameters ϕ_w]

Idea: amortize inference by learning a map from data to the target. Instead of fitting a separate λ for each dataset, learn a mapping from arbitrary datasets to λ: set λ = ϕ(η, y) and, averaging over all possible datasets — an expectation over any data we might observe — solve

argmin_η E_{p(y)}[ D_KL(π || q_{ϕ(η,y)}) ]

This learns to invert the generative model before seeing data.
This gives a new objective function, with upper-level parameters η:

J(η) = ∫ D_KL(π || q_λ) p(y) dy = ∫ p(y) ∫ p(x|y) log [ p(x|y) / q(x|ϕ(η, y)) ] dx dy = E_{p(x,y)}[ −log q(x|ϕ(η, y)) ] + const.

Gradient:

∇_η J(η) = E_{p(x,y)}[ −∇_η log q(x|ϕ(η, y)) ]

A tractable gradient! The expectation is over the (tractable) joint distribution, so it can be approximated with samples from the model, and we can train entirely offline — learning to invert the generative model before seeing data. For q we choose a known parametric family; ϕ can be any differentiable function.
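A compact sketch of this offline training loop in PyTorch. The generative model (x ~ Normal(0, 1), y ~ Normal(x, 0.5)), the network architecture, and all hyperparameters below are assumptions made for illustration; q here is a single Gaussian whose parameters λ = ϕ(η, y) come from a small network:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# phi(eta, y): maps an observation y to proposal parameters lambda = (mean, log_std)
phi = nn.Sequential(nn.Linear(1, 32), nn.Tanh(), nn.Linear(32, 2))
opt = torch.optim.Adam(phi.parameters(), lr=1e-3)

def sample_joint(batch_size):
    """Draw (x, y) ~ p(x, y): x ~ Normal(0, 1), y ~ Normal(x, 0.5)."""
    x = torch.randn(batch_size, 1)
    y = x + 0.5 * torch.randn(batch_size, 1)
    return x, y

for step in range(5000):
    x, y = sample_joint(256)                   # samples from the (tractable) joint
    mean, log_std = phi(y).chunk(2, dim=-1)    # lambda = phi(eta, y)
    q = torch.distributions.Normal(mean, log_std.exp())
    loss = -q.log_prob(x).mean()               # Monte Carlo estimate of J(eta), up to a constant
    opt.zero_grad()
    loss.backward()                            # grad_eta J = E[-grad_eta log q(x | phi(eta, y))]
    opt.step()

# At test time a single forward pass gives an importance sampling proposal
# q(x | lambda) for any observed y, with no inference-time optimization.
y_obs = torch.tensor([[1.2]])
mean, log_std = phi(y_obs).chunk(2, dim=-1)
print(mean.item(), log_std.exp().item())
```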
Univariate x: mixture density network. Multivariate x: autoregressive neural density estimator, proposing each next dimension conditioned on the previous dimensions.
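In the univariate case, the single Gaussian in the sketch above would be replaced by a mixture whose weights, means, and scales are all network outputs. A minimal sketch (component count and layer sizes are arbitrary choices made here for illustration):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MDN(nn.Module):
    """Maps an observation y to a C-component Gaussian mixture density over a scalar x."""
    def __init__(self, n_components=5, hidden=32):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(1, hidden), nn.Tanh(),
                                 nn.Linear(hidden, 3 * n_components))

    def log_prob(self, x, y):
        logits, means, log_stds = self.net(y).chunk(3, dim=-1)
        log_mix = F.log_softmax(logits, dim=-1)               # mixture weights
        comp = torch.distributions.Normal(means, log_stds.exp())
        return torch.logsumexp(log_mix + comp.log_prob(x), dim=-1)   # log q(x | phi(eta, y))

# Trained exactly as before: maximize log_prob on (x, y) drawn from the joint.
mdn = MDN()
x, y = torch.randn(8, 1), torch.randn(8, 1)
print(mdn.log_prob(x, y).shape)   # one log-density per (x, y) pair
```

For multivariate x, the same construction is applied one dimension at a time, each conditioned on the dimensions already sampled.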
[Figure: samples from the prior, samples from the learned proposal, and samples after importance weighting, each compared against Metropolis-Hastings]
(1) There is an algorithm [Stuhlmüller et al., 2013] which takes a model and constructs an inverse model, in which the observed nodes come first. (2) Property: this inverse model does not introduce any additional conditional independencies. That is, if two random variables are independent given a third in the inverse model, this was also true in the original generative model.
[Figure: generative model and inverse model plate diagrams; the inverse model, parameterized by ϕ_w, uses a single multivariate proposal for the weights w0, w1, w2]
[Figure: a hierarchical generative model with global parameters α, β and per-datapoint variables θn, tn, yn, alongside its inverse model with learned factors ϕθn and ϕαβ]

The partial model can be evaluated before simulating all random variables: the likelihood p(y1 | θ1, t1) is available as soon as θ1 is sampled, and is paired with the inverse-model factor q(θ1 | y1, t1). Local random variables can be evaluated independently, and share learned factors, giving a reusable, lower-dimensional approximation instead of one big joint proposal.
Orders of magnitude fewer samples required
Factorial HMM (partial figure): inverting the factorial HMM. For models which are actually sequential, this learns approximations to the optimal filtering proposal, as a reusable approximation.

Example: energy usage disaggregation. Combinatorial space: 2^20 (on the order of a million) possible states at each timestep; many diverse plausible interpretations.
We’d like to be able to completely automate this process!
Write the model in a probabilistic programming language (BUGS, STAN, Anglican, …). Can we compile the model to an approximate inverse model? What makes a “good” inverse model? (1) It structures the neural network so that training is easier (fewer overall parameters); (2) it structures the sequence of target densities for SMC such that inference is easier.