SLIDE 1 Reparameterization Gradient for Non-differentiable Models
Wonyeol Lee Hangyeol Yu Hongseok Yang KAIST
Published at NeurIPS 2018
SLIDE 5
Background
SLIDE 6 Posterior inference
- Latent variable z ∈ ℝ^n.
- Observed variable x ∈ ℝ^m.
- Joint density p(x, z).
- Want to infer the posterior p(z|x0) given a particular value x0 of x.
SLIDES 7–10 Variational inference
- 1. Fix a family of variational distributions {qθ(z)}θ
  (each qθ(z): differentiable & easy-to-sample).
- 2. Find the qθ(z) that approximates p(z|x0) well.
- Typically, by solving
  argmaxθ ELBOθ where ELBOθ = E_{qθ(z)}[ log( p(x0,z) / qθ(z) ) ].
SLIDES 11–13 Gradient ascent
θn+1 = θn + η × ∇θ ELBOθ|θ=θn
- Difficult to compute ∇θ ELBOθ exactly.
- Use an estimated gradient instead.
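The update rule above can be sketched in a few lines. This is a toy illustration, not the paper's ELBO: the objective L(θ) = -0.5·(θ - 2)² is made up, and the gradient oracle returns a noisy but unbiased estimate of its true gradient (2 - θ).

```python
import random

def noisy_grad(theta, rng):
    # Unbiased, Monte Carlo-style noisy estimate of the true gradient (2 - theta).
    return (2.0 - theta) + rng.gauss(0.0, 0.1)

def gradient_ascent(theta0, eta, steps, seed=0):
    # theta_{n+1} = theta_n + eta * (estimated gradient)
    rng = random.Random(seed)
    theta = theta0
    for _ in range(steps):
        theta += eta * noisy_grad(theta, rng)
    return theta

print(gradient_ascent(theta0=-3.0, eta=0.1, steps=500))  # ends up near the optimum theta = 2
```

With small enough step size, the noise averages out and the iterates hover around the maximizer; this is the same scheme the talk applies to the ELBO, with the noisy oracle replaced by a gradient estimator.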
SLIDES 14–19 Reparameterization estimator
- Works if p(x0,z) is differentiable wrt. z.
- Need a distribution q(ε) & a smooth function fθ(ε) s.t.
  fθ(ε) for ε ~ q(ε) has the distribution qθ(z).
- Derived from the equation:
  ∇θ ELBOθ = ∇θ E_{qθ(z)}[ log( p(x0,z) / qθ(z) ) ]
           = ∇θ E_{q(ε)}[ log( p(x0,fθ(ε)) / qθ(fθ(ε)) ) ]
           = E_{q(ε)}[ ∇θ log( p(x0,fθ(ε)) / qθ(fθ(ε)) ) ].
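A minimal sketch of the estimator, with an illustrative choice not taken from the slides: qθ(z) = N(θ, 1), reparameterized as fθ(ε) = θ + ε with ε ~ N(0, 1), and test function z² in place of the ELBO integrand. Then E_{qθ}[z²] = θ² + 1, so the true gradient is 2θ, and the estimator averages ∇θ(θ + ε)² = 2(θ + ε) over samples of ε.

```python
import random

def reparam_grad_estimate(theta, num_samples=200_000, seed=0):
    # Monte Carlo estimate of d/dtheta E_{z ~ N(theta,1)}[z^2]
    # via the reparameterization z = theta + eps, eps ~ N(0,1).
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_samples):
        eps = rng.gauss(0.0, 1.0)
        total += 2.0 * (theta + eps)  # gradient taken inside the expectation
    return total / num_samples

theta = 1.5
print(reparam_grad_estimate(theta))  # close to the analytic gradient 2*theta = 3.0
```

The key move is the same as on the slide: the θ-dependence is transferred from the distribution to the integrand, after which ∇θ can be pushed inside the expectation, provided the integrand is differentiable.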
SLIDE 20
Non-differentiable models from probabilistic programming
SLIDES 21–25
(let [z (sample (normal 0 1))]
  (if (> z 0)
    (observe (normal 3 1) 0)
    (observe (normal -2 1) 0))
  z)
SLIDES 26–29
(let [z (sample (normal 0 1))]
  (if (> z 0)
    (observe (normal 3 1) 0)
    (observe (normal -2 1) 0))
  z)
r1(z) = N(z|0,1) · N(x=0|3,1)
r2(z) = N(z|0,1) · N(x=0|-2,1)
p(z, x=0) = [z>0]·r1(z) + [z≤0]·r2(z)
SLIDES 30–37
Model:
(let [z (sample (normal 0 1))]
  (if (> z 0)
    (observe (normal 3 1) 0)
    (observe (normal -2 1) 0))
  z)
≈
Variational guide:
(let [ε (sample (normal 0 1))
      z (+ ε θ)]
  z)
q(ε) = N(ε|0,1), z = ε + θ
How to find a good θ? By gradient ascent on ELBOθ:
θn+1 ← θn + η × ∇θ ELBOθ|θ=θn
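The joint density p(z, x=0) defined above can be evaluated directly. A minimal Python sketch in log-space, with the `if` mirroring the program's branch:

```python
import math

def log_normal_pdf(v, mean, std):
    # log of the Normal(mean, std) density at v
    return -0.5 * ((v - mean) / std) ** 2 - math.log(std * math.sqrt(2 * math.pi))

def log_joint(z):
    # p(z, x=0) = [z>0]*r1(z) + [z<=0]*r2(z), with
    # r1(z) = N(z|0,1)*N(0|3,1) and r2(z) = N(z|0,1)*N(0|-2,1)
    prior = log_normal_pdf(z, 0.0, 1.0)
    if z > 0:
        return prior + log_normal_pdf(0.0, 3.0, 1.0)   # branch: (observe (normal 3 1) 0)
    else:
        return prior + log_normal_pdf(0.0, -2.0, 1.0)  # branch: (observe (normal -2 1) 0)

# The two branches carry different likelihoods, so log_joint jumps at z = 0.
print(log_joint(0.5), log_joint(-0.5))
```

The jump at z = 0 is log N(0|3,1) - log N(0|-2,1) = -4.5 + 2 = -2.5; this discontinuity is exactly what makes the model non-differentiable.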
SLIDES 38–46
How to find a good θ? By gradient ascent on ELBOθ.
∇θ ELBOθ = ∇θ E_{q(ε)}[ [ε>-θ]·log r1(ε+θ) + [ε≤-θ]·log r2(ε+θ) ]
         = E_{q(ε)}[ [ε>-θ]·∇θ log r1(ε+θ) + [ε≤-θ]·∇θ log r2(ε+θ) ]
         = E_{q(ε)}[ -θ - ε ] = -θ
(the second step wishfully exchanges ∇θ and E; this turns out to be invalid here).
SLIDE 47
The correct version keeps an extra term:
∇θ ELBOθ = E_{q(ε)}[ [ε>-θ]·∇θ log r1(ε+θ) + [ε≤-θ]·∇θ log r2(ε+θ) ] + Correction Term
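For this one-dimensional model the correction term has a closed form; the calculation below is ours, derived from the densities on the earlier slides, not code from the paper. The branch boundary z = 0 sits at ε = -θ, and differentiating the two moving half-line integrals (Leibniz rule) leaves a boundary contribution (log r1(0) - log r2(0))·q(-θ) = -2.5·N(θ|0,1). So the naive estimator converges to -θ, while the true gradient is -θ - 2.5·N(θ|0,1):

```python
import math
import random

def normal_pdf(v, mean=0.0, std=1.0):
    return math.exp(-0.5 * ((v - mean) / std) ** 2) / (std * math.sqrt(2 * math.pi))

def naive_grad_estimate(theta, n=200_000, seed=0):
    # E_{q(eps)}[ grad_theta log r_i(eps+theta) ]: in both branches the observe
    # factor is a constant, so the derivative is -(eps + theta) either way.
    rng = random.Random(seed)
    return sum(-(rng.gauss(0.0, 1.0) + theta) for _ in range(n)) / n

def corrected_grad(theta, n=200_000, seed=0):
    # Naive term plus the boundary (surface-integral) correction; for a single
    # boundary point in 1D the "surface integral" is one evaluation at eps = -theta:
    # (log r1(0) - log r2(0)) * q(-theta).
    jump = (-0.5 * 3.0**2) - (-0.5 * 2.0**2)  # log N(0|3,1) - log N(0|-2,1) = -2.5
    return naive_grad_estimate(theta, n, seed) + jump * normal_pdf(-theta)

theta = 0.5
print(naive_grad_estimate(theta))               # near -theta = -0.5 (biased)
print(corrected_grad(theta))                    # near the true gradient below
print(-theta - 2.5 * normal_pdf(theta))         # true gradient in closed form
```

Note the correction pulls the gradient downward near θ = 0, reflecting that moving mass across z = 0 trades the cheap branch for the expensive one.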
SLIDES 48–50 Why doesn’t it work?
- Be careful when exchanging gradient and integration.
- The exchange may fail unexpectedly.
- It may also hold unexpectedly, but only with a correction term.
SLIDE 51
Our results formally
SLIDES 52–55 Non-differentiable models
- Away from a set of boundary points, the density is differentiable.
- The set of non-differentiable points has Lebesgue measure zero.
SLIDES 56–63
Wishful thinking
SLIDES 74–75 Correction
- A surface integral over the boundary of the regions.
- Accounts for the impact of moving the boundaries.
- Can be estimated by manifold sampling.
SLIDES 76–78 Two ingredients
- Differentiation under a moving domain (Leibniz-type rule).
- Divergence theorem.
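The two ingredients are standard results. In textbook form (the paper adapts them to the ELBO setting; the 1-D Leibniz rule shown here is the special case used in the running example) they read:

```latex
% Differentiation under a moving domain (Leibniz rule, 1-D form):
\frac{d}{d\theta}\int_{a(\theta)}^{b(\theta)} g(\theta,\varepsilon)\,d\varepsilon
  = \int_{a(\theta)}^{b(\theta)} \partial_\theta g(\theta,\varepsilon)\,d\varepsilon
    + g(\theta, b(\theta))\, b'(\theta)
    - g(\theta, a(\theta))\, a'(\theta)

% Divergence theorem:
\int_{\Omega} (\nabla \cdot F)\, dV \;=\; \oint_{\partial\Omega} (F \cdot n)\, dS
```

The extra boundary terms of the Leibniz rule are what become the surface-integral correction; the divergence theorem converts the resulting region integrals into integrals over the region boundaries.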
SLIDES 79–82 Surface integral over the boundary
- The correction term requires manifold sampling; hard to estimate in general cases.
- Easy to estimate if each boundary is a hyperplane.
- So assume the branch condition of each if-statement is linear in z.
SLIDES 83–84 Subsampling
- The correction is a sum of surface integrals, one per boundary.
- For computational efficiency, we subsample the surface integrals.
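The subsampling step is generic: replace the sum of per-boundary surface integrals by a rescaled sum over a random subset, which stays unbiased. A sketch with made-up placeholder numbers standing in for the per-boundary integral values:

```python
import random

# Correction term = sum of one surface-integral value per if-boundary.
# These values are placeholders for illustration, not actual surface integrals.
boundary_terms = [0.7, -1.2, 0.3, 2.1, -0.4, 0.9]  # K = 6 boundaries (made up)

def subsampled_correction(terms, k, rng):
    # Unbiased estimate of sum(terms): sample k of the K terms
    # without replacement and rescale by K/k.
    chosen = rng.sample(terms, k)
    return (len(terms) / k) * sum(chosen)

rng = random.Random(0)
full = sum(boundary_terms)
est = sum(subsampled_correction(boundary_terms, 2, rng) for _ in range(50_000)) / 50_000
print(full, est)  # the subsampled estimator averages to the full sum
```

The trade-off is variance for compute: each step touches only k boundaries instead of all K, at the cost of a noisier correction estimate.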
SLIDE 85
Experiments
SLIDE 86 Implementation
- Implemented a black-box variational inference engine for a simple probabilistic programming language.
- Supports sample, observe, if, ...
- Written in Python, using the autograd package.
SLIDE 87 Benchmarks
textmsg
- Models the numbers of per-day SNS messages, where the SNS-usage pattern changes on some day.
- Non-differentiable part: the day of change in the SNS-usage pattern.
- Given the numbers of per-day SNS messages over 2 months, infer the day when the pattern changes.
SLIDE 88 Benchmarks
temperature
- Models the random dynamics of a controller that tries to keep room temperature stable.
- Non-differentiable part: the on/off status of the air conditioner, on which the evolution of the room temperature depends.
- Given noisy observations of the temperature at each step, infer the on/off status of the controller at each step.
SLIDE 89 ELBO
{dotted, dashed, solid} lines: {N = 1, N = 8, N = 16}
SLIDE 91
Computation time
SLIDES 92–94 High-level message
- Be careful when exchanging gradient and integration.
- The exchange may fail unexpectedly.
- It may also hold unexpectedly, but only with a correction term.
SLIDE 95
Any questions?