1/14
Stein Point Markov Chain Monte Carlo
Wilson Chen
Institute of Statistical Mathematics, Japan
June 15, 2019 @ ICML Stein's Method Workshop, Long Beach
2/14
Collaborators
Alessandro Barp, François-Xavier Briol, Jackson Gorham, Mark Girolami, Lester Mackey, Chris Oates
3/14
Empirical Approximation Problem
A major problem in machine learning and modern statistics is to approximate a difficult-to-compute density p defined on a domain X ⊆ R^d whose normalisation constant is unknown, i.e., p(x) = p̃(x)/Z with Z > 0 unknown.
We consider an empirical approximation of p with points {x_i}_{i=1}^n:

$$\hat{p}_n(x) = \frac{1}{n} \sum_{i=1}^{n} \delta(x - x_i),$$

so that for a test function f : X → R,

$$\int_{X} f(x)\, p(x)\, \mathrm{d}x \approx \frac{1}{n} \sum_{i=1}^{n} f(x_i).$$

A popular approach is Markov chain Monte Carlo.
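As a concrete illustration (my own, not from the talk): below, a standard normal stands in for p, so the equal-weight empirical average of a test function approximates its expectation under p. In the settings SP-MCMC targets, p is known only up to Z and such samples are not directly available.

```python
import numpy as np

# A minimal illustration (not from the talk): an equal-weight empirical
# measure approximating E_p[f(X)]. Here p is a standard normal so samples
# are easy to draw; in general p is known only up to the constant Z.
rng = np.random.default_rng(0)
x = rng.standard_normal(10_000)   # points {x_i} defining p-hat_n
f = lambda t: t**2                # test function f
print(f(x).mean())                # ≈ ∫ f(x) p(x) dx = 1
```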
4/14
Discrepancy
Idea: construct a measure of discrepancy D(p̂_n, p) with desirable features:
- Detects (non)convergence, i.e., D(p̂_n, p) → 0 only if p̂_n →* p.
- Efficiently computable with limited access to p.
Unfortunately, this is not the case for many popular discrepancy measures:
- Kullback-Leibler divergence,
- Wasserstein distance,
- Maximum mean discrepancy (MMD).
5/14
Kernel Embedding and MMD
The kernel embedding of a distribution p is

$$\mu_p(\cdot) = \int k(x, \cdot)\, p(x)\, \mathrm{d}x$$

(a function in the RKHS K).
Consider the maximum mean discrepancy (MMD) as an option for D:

$$D(\hat{p}_n, p) := \|\mu_{\hat{p}_n} - \mu_p\|_{K} =: D_{k,p}(\{x_i\}_{i=1}^n).$$

Expanding the squared norm,

$$D_{k,p}(\{x_i\}_{i=1}^n)^2 = \|\mu_{\hat{p}_n} - \mu_p\|_{K}^2 = \langle \mu_{\hat{p}_n} - \mu_p,\, \mu_{\hat{p}_n} - \mu_p \rangle = \langle \mu_{\hat{p}_n}, \mu_{\hat{p}_n} \rangle - 2\langle \mu_{\hat{p}_n}, \mu_p \rangle + \langle \mu_p, \mu_p \rangle.$$

We are faced with intractable integrals with respect to p!
For a Stein kernel k_0, however,

$$\mu_p(\cdot) = \int k_0(x, \cdot)\, p(x)\, \mathrm{d}x = 0,$$

so

$$\|\mu_{\hat{p}_n} - \mu_p\|_{K_0}^2 = \|\mu_{\hat{p}_n}\|_{K_0}^2 =: D_{k_0,p}(\{x_i\}_{i=1}^n)^2 =: \mathrm{KSD}^2.$$
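To make the intractability concrete, a minimal sketch (my own illustration, 1-D for brevity): the V-statistic estimate of MMD² with a Gaussian kernel. The cross term and the final term require samples y ~ p, which is exactly what the Stein kernel eliminates.

```python
import numpy as np

# A minimal sketch (my own illustration): V-statistic estimate of MMD^2
# between points x (the empirical measure) and samples y ~ p, under a
# Gaussian kernel. The last two terms involve y, i.e. integrals w.r.t. p.
def gaussian_kernel(a, b, ell=1.0):
    return np.exp(-(a[:, None] - b[None, :]) ** 2 / (2 * ell**2))

def mmd2(x, y, ell=1.0):
    return (gaussian_kernel(x, x, ell).mean()        # <mu_phat, mu_phat>
            - 2 * gaussian_kernel(x, y, ell).mean()  # -2<mu_phat, mu_p>
            + gaussian_kernel(y, y, ell).mean())     # <mu_p, mu_p>

rng = np.random.default_rng(0)
print(mmd2(rng.standard_normal(500), rng.standard_normal(500)))  # ≈ 0
```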
6/14
Kernel Stein Discrepancy (KSD)
The kernel Stein discrepancy (KSD) is given by

$$D_{k_0,p}(\{x_i\}_{i=1}^n) = \frac{1}{n} \sqrt{\sum_{i=1}^{n} \sum_{j=1}^{n} k_0(x_i, x_j)},$$

where k_0 is the Stein kernel

$$k_0(x, x') := \mathcal{T}_p \mathcal{T}'_p k(x, x') = \nabla_x \cdot \nabla_{x'} k(x, x') + \langle \nabla_x \log p(x),\, \nabla_{x'} k(x, x') \rangle + \langle \nabla_{x'} \log p(x'),\, \nabla_x k(x, x') \rangle + \langle \nabla_x \log p(x),\, \nabla_{x'} \log p(x') \rangle\, k(x, x'),$$

with T_p f = ∇ · (p f)/p. (T_p is a Stein operator.)
- This is computable without the normalisation constant.
- It requires gradient information ∇ log p(x_i).
- It detects (non)convergence for an appropriately chosen k (e.g., the IMQ kernel). (A code sketch follows below.)
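A hedged implementation sketch (my own, with illustrative parameter choices): the Stein kernel built from the IMQ base kernel k(x, x′) = (c² + ‖x − x′‖²)^β, β ∈ (−1, 0), and the resulting KSD for a standard Gaussian target, where ∇ log p(x) = −x.

```python
import numpy as np

# A sketch (my own): IMQ-based Stein kernel k_0 and the KSD, following
# the formula above. The gradients of k(x, x') = (c^2 + ||x - x'||^2)^beta
# are computed in closed form; all parameter choices are illustrative.
def stein_kernel_imq(x, xp, grad_logp, c=1.0, beta=-0.5):
    r = x - xp
    s = r @ r
    base = c**2 + s
    k = base**beta
    gx = 2 * beta * base**(beta - 1) * r          # grad_x k
    gxp = -gx                                     # grad_x' k
    trace = (-2 * beta * len(x) * base**(beta - 1)          # div_x grad_x' k
             - 4 * beta * (beta - 1) * s * base**(beta - 2))
    bx, bxp = grad_logp(x), grad_logp(xp)
    return trace + bx @ gxp + bxp @ gx + (bx @ bxp) * k

def ksd(points, grad_logp):
    n = len(points)
    total = sum(stein_kernel_imq(xi, xj, grad_logp)
                for xi in points for xj in points)
    return np.sqrt(total) / n

# Example: KSD of i.i.d. Gaussian samples against the Gaussian target.
rng = np.random.default_rng(0)
pts = rng.standard_normal((100, 2))
print(ksd(pts, grad_logp=lambda x: -x))
```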
7/14
Stein Points (SP)
The main idea of Stein Points is the greedy minimisation of the KSD:

$$x_j \mid x_1, \ldots, x_{j-1} \;\leftarrow\; \operatorname*{arg\,min}_{x \in X} D_{k_0,p}(\{x_i\}_{i=1}^{j-1} \cup \{x\}) = \operatorname*{arg\,min}_{x \in X}\; k_0(x, x) + 2 \sum_{i=1}^{j-1} k_0(x, x_i).$$

A global optimisation step is needed at each iteration. (A sketch of this greedy step follows below.)
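A minimal sketch of the greedy step (my own; a fixed candidate grid stands in for the global optimiser, and k0 is a 1-D specialisation of the IMQ Stein kernel above for a standard Gaussian target).

```python
import numpy as np

# A sketch (my own): greedy Stein Points over a candidate grid, for a
# 1-D standard Gaussian target with grad log p(x) = -x and IMQ base kernel.
def k0(x, xp, c=1.0, beta=-0.5):
    s = (x - xp) ** 2
    base = c**2 + s
    gx = 2 * beta * base ** (beta - 1) * (x - xp)  # dk/dx; dk/dx' = -gx
    trace = (-2 * beta * base ** (beta - 1)
             - 4 * beta * (beta - 1) * s * base ** (beta - 2))
    return trace + (-x) * (-gx) + (-xp) * gx + x * xp * base**beta

def stein_points(n, grid):
    points = []
    for _ in range(n):
        # greedy objective: k0(x, x) + 2 * sum_i k0(x, x_i)
        obj = [k0(x, x) + 2 * sum(k0(x, xi) for xi in points) for x in grid]
        points.append(grid[int(np.argmin(obj))])
    return points

print(stein_points(10, np.linspace(-4, 4, 401)))
```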
8/14
Stein Point Markov Chain Monte Carlo (SP-MCMC)
We propose to replace the global minimisation at each iteration j of the SP method with a local search based on a p-invariant Markov chain of length m_j. The proposed SP-MCMC method proceeds as follows:
1. Fix an initial point x_1 ∈ X.
2. For j = 2, ..., n:
   a. Select i* ∈ {1, ..., j−1} according to a criterion crit({x_i}_{i=1}^{j−1}).
   b. Generate (y_{j,i})_{i=1}^{m_j} from a p-invariant Markov chain with y_{j,1} = x_{i*}.
   c. Set x_j ← arg min_{x ∈ {y_{j,i}}_{i=1}^{m_j}} D_{k_0,p}({x_i}_{i=1}^{j−1} ∪ {x}).
For crit, three different approaches are considered (a code sketch follows this list):
- LAST selects the point last added: i* := j − 1.
- RAND selects i* uniformly at random from {1, ..., j−1}.
- INFL selects i* to be the index of the most influential point in {x_i}_{i=1}^{j−1}, where we call x_{i*} the most influential point if removing it from the point set creates the greatest increase in KSD.
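A sketch of SP-MCMC (my own, under the same illustrative assumptions as before): the LAST criterion, random-walk Metropolis as the p-invariant local chain, a constant chain length m, and the 1-D Gaussian k0 from the previous sketch passed in as an argument.

```python
import numpy as np

# A sketch (my own): SP-MCMC with crit = LAST and a random-walk
# Metropolis local chain, for a 1-D standard Gaussian target.
rng = np.random.default_rng(0)

def log_p(x):
    return -0.5 * x**2               # unnormalised log target

def rwm_chain(x0, m, step=1.0):
    # p-invariant local chain of length m started at x0
    xs, x = [x0], x0
    for _ in range(m - 1):
        prop = x + step * rng.standard_normal()
        if np.log(rng.uniform()) < log_p(prop) - log_p(x):
            x = prop
        xs.append(x)
    return xs

def sp_mcmc(n, m, k0):
    points = [0.0]                    # step 1: initial point x_1
    for _ in range(2, n + 1):
        ys = rwm_chain(points[-1], m)  # steps a+b; LAST restarts at x_{j-1}
        # step c: candidate minimising KSD of the augmented point set,
        # equivalently the greedy objective k0(y, y) + 2 * sum_i k0(y, x_i)
        obj = [k0(y, y) + 2 * sum(k0(y, xi) for xi in points) for y in ys]
        points.append(ys[int(np.argmin(obj))])
    return points

# usage, reusing k0 from the Stein Points sketch above:
# print(sp_mcmc(n=50, m=20, k0=k0))
```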
9/14
Gaussian Mixture Model Experiment
[Figure: rows for MCMC and for SP-MCMC with the LAST, RAND and INFL criteria; panels show log KSD against the number of points j (up to 1000), the selected point sets, and the density of the squared jump distance (Jump²), with SP-MCMC overlaid against MCMC.]
10/14
IGARCH Experiment (d = 2)
[Figure: log E_P against log n_eval for MALA, RWM, SVGD, MED, SP, SP-MALA LAST, SP-MALA INFL, SP-RWM LAST and SP-RWM INFL.]
SP-MCMC methods are compared against the original SP (Chen et al., 2018), MED (Roshan Joseph et al., 2015) and SVGD (Liu & Wang, 2016), as well as the Metropolis-adjusted Langevin algorithm (MALA) and random-walk Metropolis (RWM).
11/14
ODE Experiment (d = 4)
[Figure: log KSD against log n_eval for the same methods.]
12/14
ODE Experiment (d = 10)
[Figure: log KSD against log n_eval for the same methods.]
13/14
Theoretical Guarantees
The convergence of the proposed SP-MCMC method is established, with an explicit bound on the KSD in terms of the V-uniform ergodicity of the Markov transition kernel.

Example: SP-MALA convergence. Let (m_j)_{j=1}^n ⊂ N be a fixed sequence and let {x_i}_{i=1}^n denote the SP-MALA output, based on Markov chains (Y_{j,l})_{l=1}^{m_j}, j ∈ N. Under certain regularity conditions, MALA is V-uniformly ergodic for V(x) = 1 + ‖x‖², and there exists C > 0 such that

$$\mathbb{E}\left[D_{k_0,p}(\{x_i\}_{i=1}^n)^2\right] \le \frac{C}{n} \sum_{i=1}^{n} \frac{\log(n \wedge m_i)}{n \wedge m_i}.$$
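For intuition, a direct instantiation of the bound (my own arithmetic, not a claim from the talk): if every chain is at least as long as the point set, so m_i ≥ n for all i, then n ∧ m_i = n and

$$\mathbb{E}\left[D_{k_0,p}(\{x_i\}_{i=1}^n)^2\right] \le \frac{C}{n} \sum_{i=1}^{n} \frac{\log n}{n} = \frac{C \log n}{n},$$

so the KSD itself decays at a near root-n rate.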
14/14
Paper, Code and Poster
- Paper is available at:
https://arxiv.org/pdf/1905.03673.pdf
- Code is available at:
https://github.com/wilson-ye-chen/sp-mcmc
- Check out the poster at the Lunch and Poster Session!