Advanced inference in probabilistic programs, by Brooks Paige (PowerPoint presentation)



SLIDE 1

Advanced inference in probabilistic programs

Brooks Paige

SLIDE 2
Inference thus far

  • Likelihood weighting / importance sampling
  • MCMC (single-dimension, coded by hand)
  • “Lightweight” Metropolis-Hastings (update one random choice at a time, by re-running the remainder of the program)

SLIDE 3

Inference: this talk

How can we make inference more computationally efficient?

  • Sequential Monte Carlo uses importance sampling as a building block for an inference algorithm that can succeed in models with higher-dimensional latent spaces
  • Algorithms which extend SMC: particle MCMC and asynchronous SMC
  • What sort of proposal distributions should we be simulating from in these methods? Can we learn importance sampling proposals automatically?

SLIDE 4
Inference in Anglican

  • How do you implement an inference algorithm in Anglican? (JW will show you this afternoon)
  • Two important special forms are the interface between model code and inference code:

      (observe ...)
      (sample ...)

  • Q: what kinds of inference algorithms can we develop and implement using this interface?

      (doquery :algorithm model [args] options)

SLIDE 5
Incremental evidence

  • If we can write our programs in such a way that we see early, incremental evidence, then we can use more efficient inference algorithms.
  • Intuition: sample statements which come after observe statements can be informed by the data

      (defquery monolithic-observe []
        ... ;; many sample statements
        (sample ...)
        (sample ...)
        (sample ...)
        ... ;; single observe /
            ;; conditioning statement
            ;; at the end
        (observe ...))

      (defquery incremental-observe []
        (loop ...
          ;; interleaved sample and
          ;; observe statements
          (sample ...)
          (observe ...)
          (recur ...)))

SLIDE 6

Hidden Markov model

[Figure: HMM graphical model with latent states x0, x1, x2, x3, … and observations y1, y2, y3]

SLIDES 7-11

Hidden Markov model

[Figure: HMM graphical model, built up step by step across slides]

Place a massive observe statement at the end

No “feedback” until all random variables have been sampled

SLIDES 12-17

Hidden Markov model

[Figure: HMM graphical model, with observe statements interleaved]

Place observe statements as early as possible

Does y1 have high probability given x0 and x1? Does y2 have high probability given x0, x1, and x2?

Incremental evidence == computational efficiency?
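The point can be sketched in Python (a hypothetical toy trace, not Anglican code): interleaving observes exposes a per-step weight that inference can act on, while the monolithic version only reports the same total at the very end.

```python
import math, random

def run_trace(ys, incremental, rng, std=1.0):
    """Toy HMM trace: sample x_n ~ Normal(x_{n-1}, std), then score
    y_n under Normal(x_n, std). Returns log-weights: one per observe
    (incremental) or a single total at the end (monolithic)."""
    x, step_logw = 0.0, []
    for y in ys:
        x = rng.gauss(x, std)                           # sample statement
        step_logw.append(-(y - x) ** 2 / (2 * std**2))  # observe statement
    if incremental:
        return step_logw          # feedback after every observe
    return [sum(step_logw)]       # one massive observe at the end
```

Both traces carry the same total log-weight; the incremental one just exposes it early enough for resampling to act on.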

SLIDE 18
Incremental evidence

  • Many models and settings are naturally written incrementally!
  • Canonical example: time series models (observe at discrete timesteps)
  • Planning problems (observe at discrete timesteps)
  • Models which factor into global and “local” (per-datapoint) observes, such as mixture models and many multilevel Bayesian models
  • Models such as image synthesis, where the entire “canvas” is always visible and can be evaluated according to a fitness function at any time

SLIDE 19
State-space models

  • Running example: inference in state-space models
  • Observed data y_n and latent state x_n
  • Inference goals: estimate latent state; predict future data; estimate marginal likelihood

The joint density factorizes over “time” (n), with latent “space” (x):

    p(x_{0:N}, y_{0:N}) = ∏_{n=0}^{N} g(y_n | x_{0:n}) f(x_n | x_{0:n-1})
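As a concrete sketch in Python (assuming, purely for illustration, Gaussian transition f and likelihood g), the joint density accumulates one f-term and one g-term per timestep:

```python
import math, random

def simulate_ssm(N, f_std=1.0, g_std=0.5, seed=0):
    """Simulate from a simple linear-Gaussian state-space model:
    x_n ~ Normal(x_{n-1}, f_std), y_n ~ Normal(x_n, g_std)."""
    rng = random.Random(seed)
    x, xs, ys = 0.0, [], []
    for _ in range(N):
        x = rng.gauss(x, f_std)   # transition f(x_n | x_{n-1})
        y = rng.gauss(x, g_std)   # observation g(y_n | x_n)
        xs.append(x)
        ys.append(y)
    return xs, ys

def log_joint(xs, ys, f_std=1.0, g_std=0.5):
    """log p(x_{0:N}, y_{0:N}) = sum_n [log f + log g]."""
    def log_norm(v, mu, std):
        return -0.5 * math.log(2 * math.pi * std**2) - (v - mu)**2 / (2 * std**2)
    lp, prev = 0.0, 0.0
    for x, y in zip(xs, ys):
        lp += log_norm(x, prev, f_std) + log_norm(y, x, g_std)
        prev = x
    return lp
```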

SLIDE 20

Sequential Monte Carlo

  • Basic idea: approximate the posterior distribution using a weighted set of K particles (n = 1, K total particles):

    p(x_{0:n} | y_{0:n}) ≈ ∑_{k=1}^{K} w_n^{(k)} δ_{x_{0:n}^{(k)}}(x_{0:n})

SLIDES 21-22

Sequential Monte Carlo (n = 1, K total particles)

  • Each particle is assigned an (unnormalized) weight W_n^k based on its likelihood; normalized weights satisfy w_n^k ∝ W_n^k

    p(x_{0:n} | y_{0:n}) ≈ ∑_{k=1}^{K} w_n^{(k)} δ_{x_{0:n}^{(k)}}(x_{0:n})

SLIDE 23

Sequential Monte Carlo (n = 1, 2; K total particles)

  • Particles are resampled according to their weights, then simulated forward
  • Each particle has zero or more children
  • The number of children M_n^k is proportional to the weight W_n^k

SLIDE 24

Sequential Monte Carlo (n = 1, 2; K total particles)

  • Particles with low weight are discarded, and particles with high weight are replicated
  • Better-than-average particles are replicated more often:

    E[M_n^k | W_n^{1:K}] = W_n^k / W̄_n

where W̄_n denotes the average weight at step n.

SLIDES 25-26

Sequential Monte Carlo (n = 1, 2; K total particles)

Iteratively,
  • simulate
  • weight
  • resample
SLIDE 27

Sequential Monte Carlo (n = 1, 2, 3; K total particles)
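The simulate / weight / resample loop can be sketched in a few lines of Python (a minimal bootstrap-SMC sketch for the linear-Gaussian model above; `f_std` and `g_std` are illustrative parameters):

```python
import math, random

def smc(ys, K=100, f_std=1.0, g_std=0.5, seed=1):
    """Minimal bootstrap SMC: simulate, weight, resample at each step.
    Returns the final particle set and a log marginal likelihood estimate."""
    rng = random.Random(seed)
    particles = [0.0] * K
    log_Z = 0.0
    for y in ys:
        # simulate: propose each particle forward from the transition prior
        particles = [rng.gauss(x, f_std) for x in particles]
        # weight: unnormalized likelihood g(y_n | x_n)
        logw = [-0.5 * math.log(2 * math.pi * g_std**2)
                - (y - x)**2 / (2 * g_std**2) for x in particles]
        m = max(logw)
        w = [math.exp(l - m) for l in logw]
        log_Z += m + math.log(sum(w) / K)  # running marginal likelihood estimate
        # resample: draw K children with probability proportional to weight
        particles = rng.choices(particles, weights=w, k=K)
    return particles, log_Z
```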

SLIDE 28

Sequential Monte Carlo

SMC in action: slowed down for clarity

SLIDE 29

Probabilistic programs as state spaces?

SLIDE 30

Trace

  • Sequence of N observe’s
  • Sequence of M sample’s
  • Sequence of M sampled values
  • Conditioned on these sampled values the entire computation is deterministic

Running the program, we encounter N observe statements {(g_i, φ_i, y_i)}_{i=1}^{N} and M sample statements {(f_j, θ_j)}_{j=1}^{M}, with sampled values {x_j}_{j=1}^{M}.
SLIDE 31

Trace Probability

  • Defined as (up to a normalization constant)
  • Hides true dependency structure

    γ(x) ≜ p(x, y) = ∏_{i=1}^{N} g_i(y_i | φ_i) ∏_{j=1}^{M} f_j(x_j | θ_j)

Making the dependence on earlier sampled values explicit (with x_{1:j} denoting the sampled values up to the j-th sample statement):

    γ(x) = p(x, y) = ∏_{i=1}^{N} g̃_i(y_i | φ̃_i(x_{n_i})) ∏_{j=1}^{M} f̃_j(x_j | θ̃_j(x_{j-1}))

SLIDE 32
Likelihood Weighting

  • Run K independent copies of the program, simulating from the prior:

    q(x^k) = ∏_{j=1}^{M^k} f_j(x_j^k | θ_j^k)

  • Accumulate unnormalized weights (likelihoods):

    w(x^k) = γ(x^k) / q(x^k) = ∏_{i=1}^{N^k} g_i^k(y_i^k | φ_i^k)

  • Use in approximate (Monte Carlo) integration:

    W^k = w(x^k) / ∑_{ℓ=1}^{K} w(x^ℓ)

    Ê_π[R(x)] = ∑_{k=1}^{K} W^k R(x^k)
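The three steps above can be sketched in Python (again using the illustrative linear-Gaussian model; the test statistic R is, arbitrarily, the final latent state):

```python
import math, random

def likelihood_weighting(ys, K=500, f_std=1.0, g_std=0.5, seed=2):
    """Run K independent simulations from the prior, weight each full
    trajectory by its likelihood, and form a self-normalized estimate."""
    rng = random.Random(seed)
    samples, weights = [], []
    for _ in range(K):
        x, lw, xs = 0.0, 0.0, []
        for y in ys:
            x = rng.gauss(x, f_std)                # q simulates from the prior f
            lw += -(y - x)**2 / (2 * g_std**2)     # log g(y | x), up to a constant
            xs.append(x)
        samples.append(xs)
        weights.append(math.exp(lw))
    total = sum(weights)
    W = [w / total for w in weights]               # normalized weights W^k
    # self-normalized estimate of E_pi[R(x)], here with R(x) = final state
    return sum(Wk * xs[-1] for Wk, xs in zip(W, samples))
```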

SLIDE 33

Probabilistic programs as state spaces

  • Notation: partition the sampled values into disjoint blocks x̃_{1:n} = x̃_1 × ··· × x̃_n, where each x̃_n is the subspace of x associated with the n-th observe
  • Incrementalized joint:

    γ_n(x̃_{1:n}) = ∏_{i=1}^{n} g(y_i | x̃_{1:i}) f(x̃_i | x̃_{1:i-1})

  • Incrementalized target:

    π_n(x̃_{1:n}) = γ_n(x̃_{1:n}) / Z_n
SLIDE 34

Particle Markov chain Monte Carlo

SLIDE 35

Particle Markov Chain Monte Carlo

  • Iterable SMC
  • PIMH: “particle independent Metropolis-Hastings”
  • PGIBBS: “iterated conditional SMC”
  • PGAS: “particle Gibbs with ancestor sampling”

[Figure: particle sets across successive SMC sweeps]

SLIDE 36

PIMH Math

  • Each sweep of SMC can compute a marginal likelihood estimate:

    Ẑ = ∏_{n=1}^{N} Ẑ_n = ∏_{n=1}^{N} (1/K) ∑_{k=1}^{K} w(x̃_{1:n}^k)

  • PIMH is Metropolis-Hastings that accepts an entire new particle set with probability

    α_PIMH = min(1, Ẑ* / Ẑ^{s-1})

  • And all particles can be used:

    Ê_PIMH[R(x)] = (1/S) ∑_{s=1}^{S} ∑_{k=1}^{K} W^{s,k} R(x^{s,k})

[Figure: sweeps with marginal likelihood estimates Ẑ1, Ẑ2, Ẑ*]
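A minimal Python sketch of the PIMH loop, assuming only that `run_smc` is some callable returning a pair (log Ẑ, particle set), as the SMC sketch earlier in the deck does:

```python
import math, random

def pimh_accept(log_Z_new, log_Z_old, rng):
    """MH accept step: accept a whole new particle set w.p. min(1, Z_new/Z_old)."""
    return rng.random() < math.exp(min(0.0, log_Z_new - log_Z_old))

def pimh(run_smc, S, rng):
    """Particle independent MH: rerun SMC S times; each sweep proposes an
    entire new particle set, accepted via its marginal likelihood estimate."""
    log_Z, particles = run_smc()
    chain = [particles]
    for _ in range(S - 1):
        log_Z_new, particles_new = run_smc()
        if pimh_accept(log_Z_new, log_Z, rng):
            log_Z, particles = log_Z_new, particles_new
        chain.append(particles)
    return chain
```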

SLIDE 37

Asynchronous anytime sequential Monte Carlo

SLIDE 38

Parallelization in SMC

  • Forward simulation trivially parallelizes
    • this is the sort of parallelization achieved through (e.g.) parfor in MATLAB, or pmap in functional programming languages
  • The resampling step (normalizing weights, sampling child counts) is a global synchronous operation
    • cannot resample until all particles finish simulation
SLIDE 39

Particle Cascade

  • Replace resampling step with branching step
  • Launch particles asynchronously
  • As each particle arrives at an observation, choose a number of offspring based only on the particles which have arrived so far
  • … don’t need to wait for all particles to arrive
  • … only need to track average weights at each observation, which we compute online
SLIDES 40-41

Particle Cascade (n = 1)

  • Start by simulating particles, one at a time, from f(x_n | x_{1:n-1})
  • Weight by likelihood g(y_n | x_{1:n})

SLIDE 42

Particle Cascade (n = 1, 2)

  • Keep track of the running average weight W̄_n^k at each n, based only on the first k particles to arrive
  • Choose the number of offspring immediately, no need to wait for other particles:

    E[M_n^k | W_n^{1:k}] = W_n^k / W̄_n^k
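A simplified branching rule can be sketched in Python. This is a hedged illustration, not the exact scheme from the particle cascade paper: below-average particles survive probabilistically, above-average particles get roughly weight/average children, all using only the running average seen so far.

```python
import math, random

def branch_count(weight, running_avg, rng):
    """Simplified cascade branching sketch: choose offspring count from
    this particle's weight and the running average of arrivals so far.
    Returns (num_children, outgoing_weight_per_child)."""
    ratio = weight / running_avg
    if ratio < 1.0:
        n = 1 if rng.random() < ratio else 0   # keep with probability ratio
        return n, running_avg                  # survivor carries the average weight
    low = int(ratio)
    n = low + (1 if rng.random() < ratio - low else 0)  # floor or ceil children
    return n, weight / n                       # split the weight among children
```

No global synchronization is needed: the decision for the k-th arrival never waits on particles that have not yet reached this observation.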

SLIDE 43

Particle Cascade (n = 1, 2)

  • Launch new particles while other particles continue moving forward through the system
  • Total size of particle system may vary over the course of execution

SLIDE 44

Particle Cascade (n = 1, 2)

  • Particles do not have identical weight after resampling
  • The “outgoing” weight is set to the current running average W̄_n^k

SLIDES 45-46

Particle Cascade (n = 1, 2)

Asynchronously,
  • simulate
  • weight
  • branch

SLIDE 47

Particle Cascade (n = 1, 2, 3)

SLIDES 48-55

Particle Cascade

[Animation frames: the particle cascade in action]

SLIDE 56

Scalability: Particle Count

  • Comparison across particle-based inference approaches: raw speed of drawing samples
  • Each particle runs as a separate CPU process
SLIDE 57

Scalability: Multiple Cores

  • More cores == faster inference
  • Scales to multiple cores more efficiently than other particle-based methods

SLIDE 58

Particle cascade summary

  • Particle cascade is an asynchronous anytime drop-in replacement for SMC, with the added benefits of
  • … an anytime property similar to MCMC methods; keep running inference indefinitely, stop when satisfied with the current estimate
  • … no barrier synchronizations, yielding increased particle throughput and parallel scalability as compared to traditional SMC

SLIDE 59

Inference networks for sequential Monte Carlo

SLIDE 60

Executive Summary

We want to make model-based Bayesian inference efficient.

  • In general: what artifacts can we learn offline to compile away the runtime costs of inference?
  • Outside of specific (probably wrong) models, inference is fundamentally not a feed-forward computation!
  • Sequential Monte Carlo for graphical models: approximate optimal importance sampling proposals

SLIDE 61

Inference in Graphical Models

Goal: posterior inference in generative models with latent variables x and observed variables y:

    p(x, y) ≜ ∏_{i=1}^{N} p(x_i | pa(x_i)) ∏_{j=1}^{M} p(y_j | pa(y_j))

with target π(x) ≡ p(x|y).

Importance sampling and SMC approximate the posterior as weighted samples from a proposal q(x|λ):

    p̂(x|y) = ∑_{k=1}^{K} W^k δ_{x^k}(x),   W^k = w(x^k) / ∑_{j=1}^{K} w(x^j),   w(x) = p(x, y) / q(x|λ)

  • Performance depends on the quality of the proposal q(x|λ)!

SLIDE 62

Inference Networks for Graphical Models

A probabilistic model generates data. An inverse model generates latents. Can we learn how to sample from the inverse model?

Learning an importance sampling proposal for a single dataset: choose an approximating family q(x|λ), generally not equal to the target density π(x) = p(x|y), and fit λ to learn an importance sampling proposal:

    argmin_λ D_KL(π || q_λ)

[Figure: generative model, inverse model, and inference network ϕ_w, for a model with latents z_n, w_0, w_1, w_2 and observations t_n]

SLIDE 63

Inference Networks for Graphical Models

Idea: amortize inference by learning a map from data to target. Averaging over all possible datasets, learn a mapping λ = ϕ(η, y) from arbitrary datasets to proposal parameters:

    argmin_η E_{p(y)}[ D_KL(π || q_{ϕ(η,y)}) ]

[Figure: generative model, inverse model, and inference network, as on the previous slide]

SLIDE 64

Compiling away runtime costs of inference

Learn to invert the generative model, before seeing data. Averaging over all possible datasets, with λ = ϕ(η, y):

    argmin_η E_{p(y)}[ D_KL(π || q_{ϕ(η,y)}) ]

The expectation is over any data we might observe.

SLIDE 65

Compiling away runtime costs of inference

New objective function, with upper-level parameters η and λ = ϕ(η, y):

    J(η) = ∫ D_KL(π || q_λ) p(y) dy
         = ∫ p(y) ∫ p(x|y) log [ p(x|y) / q(x|ϕ(η, y)) ] dx dy
         = E_{p(x,y)}[ −log q(x|ϕ(η, y)) ] + const.

The expectation is over the (tractable) joint distribution.

SLIDE 66

Compiling away runtime costs of inference

    J(η) = E_{p(x,y)}[ −log q(x|ϕ(η, y)) ] + const.

    ∇_η J(η) = −E_{p(x,y)}[ ∇_η log q(x|ϕ(η, y)) ]

Tractable gradient! Can train entirely offline: learn to invert the generative model before seeing data, approximating the expectation with samples from the joint distribution.
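The offline training loop fits in a few lines of Python. This is a deliberately tiny sketch under assumed toy choices: the model is x ~ N(0,1), y ~ N(x,1); the “inference network” ϕ is just a linear map λ = η₀ + η₁·y giving the mean of q(x|λ) = N(λ, 1); training is SGD on samples (x, y) drawn from the joint, with no real data involved.

```python
import math, random

def train_inference_network(steps=5000, lr=0.02, seed=3):
    """Minimize E_{p(x,y)}[-log q(x | phi(eta, y))] by stochastic gradient
    descent, sampling (x, y) from the joint at every step."""
    rng = random.Random(seed)
    eta0, eta1 = 0.0, 0.0
    for _ in range(steps):
        x = rng.gauss(0, 1)        # sample latent from the prior
        y = rng.gauss(x, 1)        # sample data from the likelihood
        mu = eta0 + eta1 * y       # proposal mean phi(eta, y)
        g = mu - x                 # d/dmu of -log N(x; mu, 1)
        eta0 -= lr * g
        eta1 -= lr * g * y
    return eta0, eta1
```

For this toy model the true posterior is p(x|y) = N(y/2, 1/2), so training should drive η₁ toward 0.5 and η₀ toward 0.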

SLIDE 67

Choice of approximating family

Expected KL divergence:

    J(η) = E_{p(y)}[ D_KL(π || q_{ϕ(η,y)}) ]

Gradient:

    ∇_η J(η) = −E_{p(x,y)}[ ∇_η log q(x|ϕ(η, y)) ]

Approximate with samples from the model; choose a known parametric family… and any differentiable function.

SLIDE 68

Choice of approximating family

Univariate x: mixture density network
  • Neural network outputs parameters of a parametric model for the next dimension, conditioned on previous dimensions
  • e.g. mixture of Gaussians, categorical, …

Multivariate x: autoregressive neural density estimator
  • MADE: efficient weight sharing for multivariate densities
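The autoregressive idea can be sketched in Python. Here each conditional is Gaussian and the mean is a linear function of the previous dimensions (in practice a neural network plays that role); `params` and its (bias, weights, log_std) triples are illustrative names, not part of any real library.

```python
import math

def autoregressive_log_q(xs, params):
    """Evaluate log q(x) = sum_j log q(x_j | x_{<j}), where dimension j
    gets a Gaussian whose mean depends on the earlier dimensions.
    params: one (bias, weights, log_std) triple per dimension."""
    lp = 0.0
    for j, (bias, weights, log_std) in enumerate(params):
        mu = bias + sum(w * x for w, x in zip(weights, xs[:j]))
        std = math.exp(log_std)
        lp += (-0.5 * math.log(2 * math.pi * std**2)
               - (xs[j] - mu)**2 / (2 * std**2))
    return lp
```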

SLIDE 69

Non-conjugate polynomial regression

Samples from prior

SLIDE 70

Non-conjugate polynomial regression

Samples from prior Metropolis-Hastings

SLIDE 71

Non-conjugate polynomial regression

Samples from proposal Metropolis-Hastings

SLIDE 72

Non-conjugate polynomial regression

After importance weighting Metropolis-Hastings

SLIDE 73

Non-conjugate polynomial regression

SLIDE 74

Bigger models: 
 Exploiting structure

SLIDE 75

Factorization of inverse models

(1) There is an algorithm [Stuhlmüller et al., 2013] which takes a model and constructs an inverse model, in which the observed nodes come first.
 
 (2) Property: this inverse model does not introduce any additional conditional independencies. 
 
 That is, if two random variables are independent given a third in the inverse model, this was also true in the original generative model.

SLIDE 76

Factorization of inverse models

[Figures: generative model, inverse model, and a single multivariate proposal ϕ_w, for a model with latents z_n, w_0, w_1, w_2 and observations t_n]

SLIDE 77

Inverting a multilevel model

[Figure: generative model with local variables y_n, θ_n, t_n and global parameters α, β, alongside its inverse model]

SLIDE 78

Inverting a multilevel model

[Figure: generative and inverse models]

The partial model can be evaluated before simulating all random variables: the factor p(y_1 | θ_1, t_1) in the generative model corresponds to the proposal factor q(θ_1 | y_1, t_1) in the inverse model.

SLIDE 79

Inverting a multilevel model

[Figure: generative and inverse models]

Local random variables can be evaluated independently, and share learned factors.

SLIDE 80

Inverting a multilevel model

[Figure: generative model and inverse model]

SLIDE 81

Inverting a multilevel model

[Figure: inverse model with a reusable approximation ϕ_θn for the local variables and a lower-dimensional approximation ϕ_αβ for the global parameters]

SLIDE 82

Hierarchical Poisson model

Orders of magnitude fewer samples required

SLIDE 83

In sequential models

Factorial HMM (partial figure) Inverting the factorial HMM

For models which are actually sequential, this learns approximations to the optimal filtering proposal

SLIDE 84

In sequential models

Factorial HMM (partial figure) Inverting the factorial HMM

For models which are actually sequential, this learns approximations to the optimal filtering proposal

reusable approximation

SLIDE 85

Additive Factorial HMM

Example: energy usage disaggregation.

Combinatorial space: 2^20, or about a million, possible states at each timestep.

Many diverse plausible interpretations.


SLIDE 86

Discussion

We’d like to be able to completely automate this process!

  • Ideally: here’s a model, in some model specification language (BUGS, Stan, Anglican, …). Can we compile the model to an approximate inverse model?
  • Open problems: topological sort is not unique! What makes a “good” inverse model? (1) structures the neural network so that training is easier (fewer overall parameters); (2) structures the sequence of target densities for SMC such that inference is easier