
State Space Expectation Propagation

Efficient Inference Schemes for Temporal Gaussian Processes

William Wilkinson∗, Paul Chang∗, Michael Riis Andersen†, Arno Solin∗

Aalto University∗, Technical University of Denmark†

ICML 2020



Motivation

  • We’re interested in long temporal and spatio-temporal data with interesting non-conjugate GP models (e.g. classification, log-Gaussian Cox processes).
  • Idea: We should treat the temporal dimension in a fundamentally different manner to other dimensions.


Approximate Inference in Temporal GPs


There exists a dual kernel / SDE form for most popular Gaussian process (GP) models:

Kernel form:
  f(t) ∼ GP(0, K_θ(t, t′)),   y_k ∼ p(y_k | f(t_k))

State-space form:
  f_k = A_{θ,k} f_{k−1} + q_k,   q_k ∼ N(0, Q_k)
  y_k = h(f_k, σ_k),   σ_k ∼ N(0, Σ_k)

Inference is O(n) via Kalman filtering and smoothing.
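For concreteness, here is a minimal sketch of this duality for the Matérn-3/2 kernel (our own illustration; the names are not the kalman-jax API). Its SDE representation has a two-dimensional state (f, df/dt), and the discrete-time A_k and Q_k follow from the matrix exponential of the SDE feedback matrix and the stationary covariance:

```python
import jax.numpy as jnp
from jax.scipy.linalg import expm

def matern32_state_space(lengthscale, variance, dt):
    """Discrete-time state-space form of the Matern-3/2 kernel.

    The state is (f, df/dt); A and Q are such that
    f_k = A f_{k-1} + q_k with q_k ~ N(0, Q).
    """
    lam = jnp.sqrt(3.0) / lengthscale
    F = jnp.array([[0.0, 1.0],
                   [-lam**2, -2.0 * lam]])        # SDE feedback matrix
    Pinf = jnp.array([[variance, 0.0],
                      [0.0, lam**2 * variance]])  # stationary state covariance
    A = expm(F * dt)                              # transition over a gap of size dt
    Q = Pinf - A @ Pinf @ A.T                     # exact discrete-time process noise
    return A, Q, Pinf

# f(t_k) is read off the first state component: f(t_k) = [1, 0] @ f_k.
A, Q, Pinf = matern32_state_space(lengthscale=1.0, variance=1.0, dt=0.1)
```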


Approximate Inference

[Figure: GP posterior over f(t) against time t, built up one observation at a time.]

Kalman filter update step:

  p(f_k | y_{1:k}) ∝ N(f_k; m_k^predict, P_k^predict) p(y_k | f(t_k))
                   ≈ N(f_k; m_k^predict, P_k^predict) N(f_k; m_k^site, P_k^site)

The second Gaussian is the “site”: a Gaussian approximation to the effect of the (generally non-conjugate) likelihood p(y_k | f(t_k)).

  • Approximate inference amounts to selecting the site parameters m_k^site, P_k^site.
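In code, one filter step with a Gaussian site in place of the likelihood is just the standard conjugate update. A minimal JAX sketch (our own notation, not the kalman-jax API):

```python
import jax.numpy as jnp

def filter_step(m_prev, P_prev, A, Q, H, m_site, P_site):
    """One Kalman filter step, with the likelihood p(y_k | f(t_k))
    replaced by a Gaussian site N(m_site, P_site) on H f_k."""
    m_pred = A @ m_prev                     # predict mean
    P_pred = A @ P_prev @ A.T + Q           # predict covariance
    S = H @ P_pred @ H.T + P_site           # innovation covariance
    K = jnp.linalg.solve(S, H @ P_pred).T   # gain K = P_pred Hᵀ S⁻¹
    m = m_pred + K @ (m_site - H @ m_pred)  # conjugate (Gaussian) update
    P = P_pred - K @ S @ K.T
    return m, P
```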



Smoothing:

  • Update the posterior with future observations:

  p(f_k | y_{1:N}) = N(f_k; m_k^post, P_k^post)

Our Contribution: Given the marginal posterior N(f_k; m_k^post, P_k^post), we show how approximate inference amounts to a simple site parameter update rule during smoothing. This encompasses:

  • Power Expectation Propagation
  • Variational Inference (with natural gradients)
  • Extended Kalman Smoothing
  • Unscented / Gauss-Hermite Kalman Smoothing
  • Posterior Linearisation
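Schematically, the whole scheme is a filter-smoother loop with a pluggable site refresh. Below is a minimal self-contained sketch for a scalar random-walk state (a toy illustration of the structure, not the paper's algorithm); `site_update` stands in for any of the rules on the next slide:

```python
def iterated_smoothing(y, sites, site_update, q=0.1, v0=10.0, num_iters=5):
    """Iterated inference for a scalar random walk f_k = f_{k-1} + N(0, q).

    `sites` is a list of (m_site, v_site) pairs; `site_update` maps a
    marginal posterior (m_post, v_post) and y_k to a new (m_site, v_site).
    """
    n = len(y)
    for _ in range(num_iters):
        # Forward pass: Kalman filter with Gaussian sites as pseudo-likelihoods.
        fm, fv, m, v = [], [], 0.0, v0
        for (ms, vs) in sites:
            v = v + q                               # predict
            k = v / (v + vs)                        # gain
            m, v = m + k * (ms - m), (1.0 - k) * v  # update with the site
            fm.append(m)
            fv.append(v)
        # Backward pass: Rauch-Tung-Striebel smoother for the marginals.
        sm, sv = [fm[-1]], [fv[-1]]
        for t in range(n - 2, -1, -1):
            g = fv[t] / (fv[t] + q)                 # smoother gain
            sm.insert(0, fm[t] + g * (sm[0] - fm[t]))
            sv.insert(0, fv[t] + g**2 * (sv[0] - (fv[t] + q)))
        # Site refresh: the method-specific update rule.
        sites = [site_update(sm[t], sv[t], y[t]) for t in range(n)]
    return sm, sv, sites
```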


Parameter Update Rules

Power Expectation Propagation:

  q_cavity(f_k) = q_post(f_k) / q_site^α(f_k)
  L_k = log E_{q_cavity}[ p^α(y_k | f_k) ]
  P_k^site = −α ( P_k^cavity + (∇²L_k)^{−1} )
  m_k^site = m_k^cavity − (∇²L_k)^{−1} ∇L_k,   for ∇L_k = dL_k / dm_k

Variational Inference:

  L_k = E_{q_post}[ log p(y_k | f_k) ]
  P_k^site = −(∇²L_k)^{−1}
  m_k^site = m_k^post − (∇²L_k)^{−1} ∇L_k

Extended Kalman Smoother:

  v_k = y_k − h(m_k^post, 0)
  S_k = H_f P_k^post H_f^⊤ + H_σ Σ_k H_σ^⊤
  P_k^site = ( H_f^⊤ (H_σ Σ_k H_σ^⊤)^{−1} H_f )^{−1}
  m_k^site = m_k^post + (P_k^site + P_k^post) H_f^⊤ S_k^{−1} v_k

for H_f = dh/df and H_σ = dh/dσ evaluated at the mean, with σ_k ∼ N(0, Σ_k).
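For a scalar latent these rules are a few lines of JAX. The following sketch (our own, assuming a user-supplied log_lik(y, f); it is not the kalman-jax API) computes L_k by Gauss-Hermite quadrature and the site parameters by automatic differentiation:

```python
import jax
import jax.numpy as jnp
import numpy as np

# Gauss-Hermite rule: E_{N(m,v)}[g(f)] ≈ sum_i (w_i/√π) g(m + sqrt(2v) x_i).
gh_x, gh_w = np.polynomial.hermite.hermgauss(20)

def log_Z(m_cav, v_cav, y, alpha, log_lik):
    """L_k = log E_{q_cavity}[ p^alpha(y_k | f_k) ] via quadrature."""
    f = m_cav + jnp.sqrt(2.0 * v_cav) * gh_x
    return jax.scipy.special.logsumexp(alpha * log_lik(y, f),
                                       b=gh_w / jnp.sqrt(jnp.pi))

def pep_site_update(m_cav, v_cav, y, alpha, log_lik):
    """Power EP site update (scalar case), following the rules above."""
    dL = jax.grad(log_Z)(m_cav, v_cav, y, alpha, log_lik)             # ∇L_k
    d2L = jax.grad(jax.grad(log_Z))(m_cav, v_cav, y, alpha, log_lik)  # ∇²L_k
    v_site = -alpha * (v_cav + 1.0 / d2L)
    m_site = m_cav - dL / d2L
    return m_site, v_site

# Example: Poisson likelihood with log link (log-Gaussian Cox process).
log_lik = lambda y, f: y * f - jnp.exp(f) - jax.scipy.special.gammaln(y + 1.0)
m_site, v_site = pep_site_update(0.0, 1.0, 3.0, 1.0, log_lik)
```

The VI rule has the same shape with the cavity replaced by the posterior and L_k = E_{q_post}[log p(y_k | f_k)]; only P_k^site = −(∇²L_k)^{−1} changes.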

A Unifying Perspective

  • For sequential data, the EKF / UKF / GHKF are equivalent to single-sweep EP where the moment matching is solved via linearisation.
  • The iterated Kalman smoothers (EKS / UKS / GHKS) can also be recovered under certain parameter choices, but note that they optimise a different objective to EP (see the paper for details).
  • We show how natural gradient VI updates are surprisingly similar to the EP updates (when using a similar parametrisation).

New Algorithms

  • We propose to mix the beneficial properties of EP with the efficiency of classical smoothers.
  • For example, using linearisation to speed up the updates, whilst also introducing the EP cavity and fractional updates.
  • We call this Extended Kalman Expectation Propagation (EK-EP); a minimal sketch follows this list.
  • It has clear computational benefits when the parameter updates are high-dimensional, e.g., in spatio-temporal problems.
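Our reading of the idea as code (a hedged sketch for the scalar case, not the paper's implementation, and omitting damping): form the cavity by removing a fraction α of the current site, linearise h about the cavity mean, and compute the new site from the linearised Gaussian likelihood.

```python
import jax

def ekep_site_update(m_post, v_post, m_site, v_site, y, h, noise_var, alpha=0.5):
    """EK-EP site update, scalar case: EP cavity + EKF-style linearisation.
    `h(f, sigma)` is the measurement model; `noise_var` plays Σ_k."""
    # Cavity: remove a fraction alpha of the site (natural parameters).
    v_cav = 1.0 / (1.0 / v_post - alpha / v_site)
    m_cav = v_cav * (m_post / v_post - alpha * m_site / v_site)
    # Linearise h about the cavity mean, as in the EKS update.
    Hf = jax.grad(h, argnums=0)(m_cav, 0.0)  # dh/df
    Hs = jax.grad(h, argnums=1)(m_cav, 0.0)  # dh/dsigma
    R = Hs**2 * noise_var                    # linearised noise variance
    v = y - h(m_cav, 0.0)                    # innovation
    # New site from the linearised Gaussian likelihood.
    v_site_new = R / Hf**2
    m_site_new = m_cav + v / Hf
    return m_site_new, v_site_new
```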

Spatio-Temporal Classification

  • We show that our smoothing methods can be applied to tasks with more than one input dimension. We treat the first dimension (x-axis) as time, and run iterated spatio-temporal smoothing (this demo uses EP).

Fast Learning Using JAX

  • Temporal GP methods have been limited by a lack of appropriate software for hyperparameter learning: automatic differentiation + massive for loops.
  • We provide a temporal GP framework in JAX with all inference methods implemented. It:
    i) avoids loop “unrolling” to reduce compilation overheads
    ii) uses JIT compilation to avoid graph retracing
    iii) exploits accelerated linear algebra (XLA) ops

  https://github.com/AaltoML/kalman-jax
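For instance, writing the filter as a jax.lax.scan (a sketch under our own naming, not the kalman-jax code) keeps the compiled graph size independent of the number of time steps, and jax.jit compiles it once:

```python
import jax
import jax.numpy as jnp

@jax.jit
def kalman_filter(m0, P0, A, Q, H, site_means, site_covs):
    """Kalman filter over all Gaussian sites via lax.scan (no loop unrolling)."""
    def step(carry, site):
        m, P = carry
        m_site, P_site = site
        m_pred, P_pred = A @ m, A @ P @ A.T + Q     # predict
        S = H @ P_pred @ H.T + P_site               # innovation covariance
        K = jnp.linalg.solve(S, H @ P_pred).T       # gain K = P_pred Hᵀ S⁻¹
        m_new = m_pred + K @ (m_site - H @ m_pred)  # update with the site
        P_new = P_pred - K @ S @ K.T
        return (m_new, P_new), (m_new, P_new)
    _, (ms, Ps) = jax.lax.scan(step, (m0, P0), (site_means, site_covs))
    return ms, Ps
```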


Results

We run extensive analysis on synthetic and real-world data:

  • Heteroscedastic Noise
  • 1D & 2D Log-Gaussian Cox Process
  • 1D & 2D Classification
  • Audio Amplitude Demodulation

  • No consistently best inference method or EP power value:
    • EK-EP is the only practical method when updates are high-dimensional (rainforest)
    • EP or VI is needed when the likelihood is highly nonlinear
  • We compare against non-sequential baselines (SVGP and EP).
  • Sequential learning methods match the performance of batch methods, whilst scaling to larger data.
  • See the paper for the full results table.

Thanks for Listening

Take-home messages:

  • Any approximate inference method can be framed as a simple parameter update rule during Kalman smoothing.
  • Sequential methods match the performance of batch methods, and can be extended to multiple dimensions.
  • We provide fast JAX code for all methods.

Contact: william.wilkinson@aalto.fi
JAX code: https://github.com/AaltoML/kalman-jax