Towards Verified Stochastic Variational Inference for Probabilistic Programs
Wonyeol Lee¹ Hangyeol Yu¹ Xavier Rival² Hongseok Yang¹
¹KAIST, South Korea ²INRIA/ENS/CNRS, France
POPL 2020
Towards Verified Stochastic Variational Inference for Probabilistic Programs – PowerPoint PPT Presentation
Towards Verified Stochastic Variational Inference for Probabilistic Programs. Wonyeol Lee¹, Hangyeol Yu¹, Xavier Rival², Hongseok Yang¹ (¹KAIST, South Korea; ²INRIA/ENS/CNRS, France). POPL 2020. Probabilistic Programming. Example 1: def p():
Wonyeol Lee¹ Hangyeol Yu¹ Xavier Rival² Hongseok Yang¹
¹KAIST, South Korea ²INRIA/ENS/CNRS, France
POPL 2020
def p(): # model_eg1 z = pyro.sample(“z”, Normal(0., 5.)) if (z > 0): pyro.sample(“x”, Normal( 1., 1.), obs=0.) else: pyro.sample(“x”, Normal(-2., 1.), obs=0.) 2
def p(): # model_eg1 z = pyro.sample(“z”, Normal(0., 5.)) if (z > 0): pyro.sample(“x”, Normal( 1., 1.), obs=0.) else: pyro.sample(“x”, Normal(-2., 1.), obs=0.) 3
def p(): # model_eg1 z = pyro.sample(“z”, Normal(0., 5.)) if (z > 0): pyro.sample(“x”, Normal( 1., 1.), obs=0.) else: pyro.sample(“x”, Normal(-2., 1.), obs=0.) 4
def p(): # model_eg1 z = pyro.sample(“z”, Normal(0., 5.)) if (z > 0): pyro.sample(“x”, Normal( 1., 1.), obs=0.) else: pyro.sample(“x”, Normal(-2., 1.), obs=0.) 5
def p(): # model_eg1 z = pyro.sample(“z”, Normal(0., 5.)) if (z > 0): pyro.sample(“x”, Normal( 1., 1.), obs=0.) else: pyro.sample(“x”, Normal(-2., 1.), obs=0.)
density over $z$
prior $p(z)$, posterior $p(z \mid x)$
6
def p(): # model_eg1 z = pyro.sample(“z”, Normal(0., 5.)) if (z > 0): pyro.sample(“x”, Normal( 1., 1.), obs=0.) else: pyro.sample(“x”, Normal(-2., 1.), obs=0.)
density over $z$
prior $p(z)$, posterior $p(z \mid x)$
7
def p(): # model_eg1 z = pyro.sample(“z”, Normal(0., 5.)) if (z > 0): pyro.sample(“x”, Normal( 1., 1.), obs=0.) else: pyro.sample(“x”, Normal(-2., 1.), obs=0.) def qθ(): # guide_eg1 θ = pyro.param(“θ”, 0.) z = pyro.sample(“z”, Normal(θ, 1.))
density over $z$
prior $p(z)$, posterior $p(z \mid x)$
8
def p(): # model_eg1 z = pyro.sample(“z”, Normal(0., 5.)) if (z > 0): pyro.sample(“x”, Normal( 1., 1.), obs=0.) else: pyro.sample(“x”, Normal(-2., 1.), obs=0.) def qθ(): # guide_eg1 θ = pyro.param(“θ”, 0.) z = pyro.sample(“z”, Normal(θ, 1.))
density over $z$
prior $p(z)$, posterior $p(z \mid x)$
9
def p(): # model_eg1 z = pyro.sample(“z”, Normal(0., 5.)) if (z > 0): pyro.sample(“x”, Normal( 1., 1.), obs=0.) else: pyro.sample(“x”, Normal(-2., 1.), obs=0.) def qθ(): # guide_eg1 θ = pyro.param(“θ”, 0.) z = pyro.sample(“z”, Normal(θ, 1.))
density over $z$
prior $p(z)$, posterior $p(z \mid x)$
10
def p(): # model_eg1 z = pyro.sample(“z”, Normal(0., 5.)) if (z > 0): pyro.sample(“x”, Normal( 1., 1.), obs=0.) else: pyro.sample(“x”, Normal(-2., 1.), obs=0.) def qθ(): # guide_eg1 θ = pyro.param(“θ”, 0.) z = pyro.sample(“z”, Normal(θ, 1.))
guide $q_\theta(z)$ with optimal $\theta$
density over $z$
prior $p(z)$, posterior $p(z \mid x)$
11
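$\operatorname{argmin}_\theta \mathrm{KL}\big(q_\theta(z) \,\|\, p(z \mid x)\big)$, where $\mathrm{KL}\big(q_\theta(z) \,\|\, p(z \mid x)\big) \triangleq \mathbb{E}_{q_\theta(z)}\!\left[\log \frac{q_\theta(z)}{p(z \mid x)}\right]$.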
12
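$\mathrm{KL}_\theta = \mathrm{KL}\big(q_\theta(z) \,\|\, p(z \mid x)\big)$.
Goal: $\operatorname{argmin}_\theta \mathrm{KL}_\theta$, where $\mathrm{KL}_\theta \triangleq \mathbb{E}_{q_\theta(z)}\!\left[\log \frac{q_\theta(z)}{p(z \mid x)}\right]$.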
13
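$\mathrm{KL}_\theta = \mathrm{KL}\big(q_\theta(z) \,\|\, p(z \mid x)\big)$.
$\theta_{n+1} \leftarrow \theta_n - 0.01 \times \nabla_\theta \mathrm{KL}_\theta \big|_{\theta = \theta_n}$.
Goal: $\operatorname{argmin}_\theta \mathrm{KL}_\theta$, where $\mathrm{KL}_\theta \triangleq \mathbb{E}_{q_\theta(z)}\!\left[\log \frac{q_\theta(z)}{p(z \mid x)}\right]$.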
14
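$\mathrm{KL}_\theta = \mathrm{KL}\big(q_\theta(z) \,\|\, p(z \mid x)\big)$.
$\theta_{n+1} \leftarrow \theta_n - 0.01 \times \nabla_\theta \mathrm{KL}_\theta \big|_{\theta = \theta_n}$.
Goal: $\operatorname{argmin}_\theta \mathrm{KL}_\theta$, where $\mathrm{KL}_\theta \triangleq \mathbb{E}_{q_\theta(z)}\!\left[\log \frac{q_\theta(z)}{p(z \mid x)}\right]$.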
15
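$\mathrm{KL}_\theta = \mathrm{KL}\big(q_\theta(z) \,\|\, p(z \mid x)\big)$.
$\theta_{n+1} \leftarrow \theta_n - 0.01 \times \nabla_\theta \mathrm{KL}_\theta \big|_{\theta = \theta_n}$.
Goal: $\operatorname{argmin}_\theta \mathrm{KL}_\theta$, where $\mathrm{KL}_\theta \triangleq \mathbb{E}_{q_\theta(z)}\!\left[\log \frac{q_\theta(z)}{p(z \mid x)}\right]$.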
16
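$\mathrm{KL}_\theta = \mathrm{KL}\big(q_\theta(z) \,\|\, p(z \mid x)\big)$.
$\theta_{n+1} \leftarrow \theta_n - 0.01 \times \nabla_\theta \mathrm{KL}_\theta \big|_{\theta = \theta_n}$.
Goal: $\operatorname{argmin}_\theta \mathrm{KL}_\theta$, where $\mathrm{KL}_\theta \triangleq \mathbb{E}_{q_\theta(z)}\!\left[\log \frac{q_\theta(z)}{p(z \mid x)}\right]$.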
17
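$\mathrm{KL}_\theta = \mathrm{KL}\big(q_\theta(z) \,\|\, p(z \mid x)\big)$.
$\theta_{n+1} \leftarrow \theta_n - 0.01 \times \nabla_\theta \mathrm{KL}_\theta \big|_{\theta = \theta_n}$.
Goal: $\operatorname{argmin}_\theta \mathrm{KL}_\theta$, where $\mathrm{KL}_\theta \triangleq \mathbb{E}_{q_\theta(z)}\!\left[\log \frac{q_\theta(z)}{p(z \mid x)}\right]$.
Issue 1: Undefined $\mathrm{KL}_\theta$.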
18
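$\mathrm{KL}_\theta = \mathrm{KL}\big(q_\theta(z) \,\|\, p(z \mid x)\big)$.
$\theta_{n+1} \leftarrow \theta_n - 0.01 \times \nabla_\theta \mathrm{KL}_\theta \big|_{\theta = \theta_n}$.
Goal: $\operatorname{argmin}_\theta \mathrm{KL}_\theta$, where $\mathrm{KL}_\theta \triangleq \mathbb{E}_{q_\theta(z)}\!\left[\log \frac{q_\theta(z)}{p(z \mid x)}\right]$.
Issue 2: Undefined $\nabla_\theta \mathrm{KL}_\theta$.
Issue 1: Undefined $\mathrm{KL}_\theta$.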
19
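$\mathrm{KL}_\theta = \mathrm{KL}\big(q_\theta(z) \,\|\, p(z \mid x)\big)$.
$\theta_{n+1} \leftarrow \theta_n - 0.01 \times \nabla_\theta \mathrm{KL}_\theta \big|_{\theta = \theta_n}$.
Goal: $\operatorname{argmin}_\theta \mathrm{KL}_\theta$, where $\mathrm{KL}_\theta \triangleq \mathbb{E}_{q_\theta(z)}\!\left[\log \frac{q_\theta(z)}{p(z \mid x)}\right]$.
Issue 3: Wrong estimate.
Issue 2: Undefined $\nabla_\theta \mathrm{KL}_\theta$.
Issue 1: Undefined $\mathrm{KL}_\theta$.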
20
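$\mathrm{KL}_\theta = \mathbb{E}_{q_\theta(z)}\!\left[\log \frac{q_\theta(z)}{p(z \mid x)}\right] = \int dz\, q_\theta(z) \log \frac{q_\theta(z)}{p(z \mid x)}$
$\mathrm{KL}_\theta$ could be undefined for two reasons.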
21
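$\mathrm{KL}_\theta = \mathbb{E}_{q_\theta(z)}\!\left[\log \frac{q_\theta(z)}{p(z \mid x)}\right] = \int dz\, q_\theta(z) \log \frac{q_\theta(z)}{p(z \mid x)}$
$\mathrm{KL}_\theta$ could be undefined for two reasons.
(a) Undefined integrand: $q_\theta(z) \neq 0$ and $p(z \mid x) = 0$ for some $z$.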
22
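$\mathrm{KL}_\theta = \mathbb{E}_{q_\theta(z)}\!\left[\log \frac{q_\theta(z)}{p(z \mid x)}\right] = \int dz\, q_\theta(z) \log \frac{q_\theta(z)}{p(z \mid x)}$
$\mathrm{KL}_\theta$ could be undefined for two reasons.
(a) Undefined integrand: $q_\theta(z) \neq 0$ and $p(z \mid x) = 0$ for some $z$.
(b) Undefined integral: the integrand of $\int dz \cdots$ is not integrable.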
23
def p(): # model_eg2 ... sigma = pyro.sample(“sigma”, Uniform(0., 10.)) ... pyro.sample(“obs”, Normal(..., sigma), obs=...) def qθ(): # guide_eg2 ... sigma = pyro.sample(“sigma”, Normal(θ, 0.05)) 24
def p(): # model_eg2 ... sigma = pyro.sample(“sigma”, Uniform(0., 10.)) ... pyro.sample(“obs”, Normal(..., sigma), obs=...) def qθ(): # guide_eg2 ... sigma = pyro.sample(“sigma”, Normal(θ, 0.05)) 25
def p(): # model_eg2 ... sigma = pyro.sample(“sigma”, Uniform(0., 10.)) ... pyro.sample(“obs”, Normal(..., sigma), obs=...) def qθ(): # guide_eg2 ... sigma = pyro.sample(“sigma”, Normal(θ, 0.05)) 26
def p(): # model_eg2 ... sigma = pyro.sample(“sigma”, Uniform(0., 10.)) ... pyro.sample(“obs”, Normal(..., sigma), obs=...) def qθ(): # guide_eg2 ... sigma = pyro.sample(“sigma”, Normal(θ, 0.05))
$q_\theta(z) \neq 0$ and $p(z \mid x) = 0$ for $z < 0$.
27
def p(): # model_eg2 ... sigma = pyro.sample(“sigma”, Uniform(0., 10.)) ... pyro.sample(“obs”, Normal(..., sigma), obs=...) def qθ(): # guide_eg2 ... sigma = pyro.sample(“sigma”, Normal(θ, 0.05)) 28
def p(): # model_eg2’ ... sigma = pyro.sample(“sigma”, Uniform(0., 10.)) ... pyro.sample(“obs”, Normal(..., sigma), obs=...) def qθ(): # guide_eg2 ... sigma = pyro.sample(“sigma”, Normal(θ, 0.05)) Normal(5., 5.) abs(sigma) 29
def p(): # model_eg2’ ... sigma = pyro.sample(“sigma”, Uniform(0., 10.)) ... pyro.sample(“obs”, Normal(..., sigma), obs=...) def qθ(): # guide_eg2 ... sigma = pyro.sample(“sigma”, Normal(θ, 0.05)) Normal(5., 5.) abs(sigma) 30
def p(): # model_eg2’ ... sigma = pyro.sample(“sigma”, Uniform(0., 10.)) ... pyro.sample(“obs”, Normal(..., sigma), obs=...) def qθ(): # guide_eg2 ... sigma = pyro.sample(“sigma”, Normal(θ, 0.05)) Normal(5., 5.) abs(sigma)
$\int_{-\infty}^{\infty} \frac{1}{z^2} \times \mathcal{N}(z; \theta, 0.05)\, dz = \infty$
31
$\mathrm{KL}_\theta$ could be non-differentiable w.r.t. $\theta$.
32
$\mathrm{KL}_\theta$ could be non-differentiable w.r.t. $\theta$.
def p(): # model_eg1 z = pyro.sample(“z”, Normal(0., 5.)) if (z > 0): pyro.sample(“x”, Normal( 1., 1.), obs=0.) else: pyro.sample(“x”, Normal(-2., 1.), obs=0.) def qθ(): # guide_eg1’ θ = pyro.param(“θ”, 0.) z = pyro.sample(“z”, Normal(θ, 1.)) Uniform(θ-1.,θ+1.) 33
$\mathrm{KL}_\theta$ could be non-differentiable w.r.t. $\theta$.
def p(): # model_eg1 z = pyro.sample(“z”, Normal(0., 5.)) if (z > 0): pyro.sample(“x”, Normal( 1., 1.), obs=0.) else: pyro.sample(“x”, Normal(-2., 1.), obs=0.) def qθ(): # guide_eg1’ θ = pyro.param(“θ”, 0.) z = pyro.sample(“z”, Normal(θ, 1.)) Uniform(θ-1.,θ+1.)
$\mathrm{KL}_\theta$ (as a function of $\theta$): not differentiable
34
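$\widehat{\nabla_\theta \mathrm{KL}_\theta} = \nabla_\theta \log q_\theta(z') \times \log \frac{q_\theta(z')}{p(z', x)}$,
where $z'$ is sampled from $q_\theta$.
$\nabla_\theta \mathrm{KL}_\theta = \mathbb{E}_{z' \sim q_\theta}\big[\widehat{\nabla_\theta \mathrm{KL}_\theta}\big]$ if some requirements are met.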
35
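$\widehat{\nabla_\theta \mathrm{KL}_\theta} = \nabla_\theta \log q_\theta(z') \times \log \frac{q_\theta(z')}{p(z', x)}$,
where $z'$ is sampled from $q_\theta$.
$\nabla_\theta \mathrm{KL}_\theta = \mathbb{E}_{z' \sim q_\theta}\big[\widehat{\nabla_\theta \mathrm{KL}_\theta}\big]$ if some requirements are met.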
36
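$\widehat{\nabla_\theta \mathrm{KL}_\theta} = \nabla_\theta \log q_\theta(z') \times \log \frac{q_\theta(z')}{p(z', x)}$,
where $z'$ is sampled from $q_\theta$.
$\nabla_\theta \mathrm{KL}_\theta = \mathbb{E}_{z' \sim q_\theta}\big[\widehat{\nabla_\theta \mathrm{KL}_\theta}\big]$ if some requirements are met.
$\widehat{\nabla_\theta \mathrm{KL}_\theta}$ could be an incorrect estimator if the requirements are not met.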
37
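$\nabla_\theta \mathrm{KL}_\theta = \nabla_\theta \int q_\theta(z) \log \frac{q_\theta(z)}{p(z \mid x)}\, dz = \int \nabla_\theta\!\left[q_\theta(z) \log \frac{q_\theta(z)}{p(z \mid x)}\right] dz = \cdots = \mathbb{E}_{z'}\!\left[\nabla_\theta \log q_\theta(z') \times \log \frac{q_\theta(z')}{p(z', x)}\right] = \mathbb{E}_{z'}\big[\widehat{\nabla_\theta \mathrm{KL}_\theta}\big]$.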
38
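$\nabla_\theta \mathrm{KL}_\theta = \nabla_\theta \int q_\theta(z) \log \frac{q_\theta(z)}{p(z \mid x)}\, dz = \int \nabla_\theta\!\left[q_\theta(z) \log \frac{q_\theta(z)}{p(z \mid x)}\right] dz = \cdots = \mathbb{E}_{z'}\!\left[\nabla_\theta \log q_\theta(z') \times \log \frac{q_\theta(z')}{p(z', x)}\right] = \mathbb{E}_{z'}\big[\widehat{\nabla_\theta \mathrm{KL}_\theta}\big]$.
This might fail.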
39
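$\nabla_\theta \mathrm{KL}_\theta = \nabla_\theta \int q_\theta(z) \log \frac{q_\theta(z)}{p(z \mid x)}\, dz = \int \nabla_\theta\!\left[q_\theta(z) \log \frac{q_\theta(z)}{p(z \mid x)}\right] dz = \cdots = \mathbb{E}_{z'}\!\left[\nabla_\theta \log q_\theta(z') \times \log \frac{q_\theta(z')}{p(z', x)}\right] = \mathbb{E}_{z'}\big[\widehat{\nabla_\theta \mathrm{KL}_\theta}\big]$.
This might fail: $\nabla_\theta \int_{\Omega(\theta)} \cdots \, dz \neq \int_{\Omega(\theta)} \nabla_\theta \cdots \, dz$ (here the integration domain $\Omega(\theta)$ depends on $\theta$).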
40
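$\nabla_\theta \mathrm{KL}_\theta = \nabla_\theta \int q_\theta(z) \log \frac{q_\theta(z)}{p(z \mid x)}\, dz = \int \nabla_\theta\!\left[q_\theta(z) \log \frac{q_\theta(z)}{p(z \mid x)}\right] dz = \cdots = \mathbb{E}_{z'}\!\left[\nabla_\theta \log q_\theta(z') \times \log \frac{q_\theta(z')}{p(z', x)}\right] = \mathbb{E}_{z'}\big[\widehat{\nabla_\theta \mathrm{KL}_\theta}\big]$.
This might fail: $\nabla_\theta \int_{\Omega(\theta)} \cdots \, dz \neq \int_{\Omega(\theta)} \nabla_\theta \cdots \, dz$ (here the integration domain $\Omega(\theta)$ depends on $\theta$).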
41
(a) $q_\theta(z) \neq 0$ and $p(z \mid x) = 0$ for some $z$. (b) The integrand of $\int \cdots \, dz$ is not integrable.
42
(a) $q_\theta(z) \neq 0$ and $p(z \mid x) = 0$ for some $z$. (b) The integrand of $\int \cdots \, dz$ is not integrable.
43
(a) $q_\theta(z) \neq 0$ and $p(z \mid x) = 0$ for some $z$. (b) The integrand of $\int \cdots \, dz$ is not integrable.
E.g., $\mathcal{N}(z; \_, \_) > 0$ for all $z \in \mathbb{R}$.
44
(a) $q_\theta(z) \neq 0$ and $p(z \mid x) = 0$ for some $z$. (b) The integrand of $\int \cdots \, dz$ is not integrable.
E.g., $\mathcal{N}(z; \_, \_) > 0$ for all $z \in \mathbb{R}$. A familiar type of static analysis problem.
45
(a) $q_\theta(z) \neq 0$ and $p(z \mid x) = 0$ for some $z$. (b) The integrand of $\int \cdots \, dz$ is not integrable.
46
(a) $q_\theta(z) \neq 0$ and $p(z \mid x) = 0$ for some $z$. (b) The integrand of $\int \cdots \, dz$ is not integrable.
familiar type of static analysis problems
47
(a) $q_\theta(z) \neq 0$ and $p(z \mid x) = 0$ for some $z$. (b) The integrand of $\int \cdots \, dz$ is not integrable.
Sufficient conditions about
familiar type of static analysis problems
48
(a) $q_\theta(z) \neq 0$ and $p(z \mid x) = 0$ for some $z$. (b) The integrand of $\int \cdots \, dz$ is not integrable.
Sufficient conditions about
familiar type of static analysis problems
49
(a) $q_\theta(z) \neq 0$ and $p(z \mid x) = 0$ for some $z$. (b) The integrand of $\int \cdots \, dz$ is not integrable.
Sufficient conditions about
familiar type of static analysis problems
Density semantics is used for formalization.
50
51
$q_\theta(z) \neq 0$ and $p(z \mid x) = 0$ for some $z$.
52
$q_\theta(z) \neq 0$ and $p(z \mid x) = 0$ for some $z$.
(e.g., tensor broadcasting). github.com/wonyeol/static-analysis-for-support-match
53
54
55
Example 2: $p$ uses Uniform; $q_\theta$ uses Normal.
56
Example 2: $p$ uses Uniform; $q_\theta$ uses Normal. $p$ uses Dirichlet; $q_\theta$ uses Delta.
57
$p$ uses Uniform; $q_\theta$ uses Normal. $p$ uses Dirichlet; $q_\theta$ uses Delta. Performs a different inference algorithm using the variational inference engine. Example 2
58
They can be violated in some cases, so they need to be checked carefully.
59
They can be violated in some cases, so they need to be checked carefully.
E.g., how can these assumptions be checked automatically?
60
They can be violated in some cases, so they need to be checked carefully.
E.g., how can these assumptions be checked automatically?
61
def p(): # model_eg1 z = pyro.sample(“z”, Normal(0., 5.)) if (z > 0): pyro.sample(“x”, Normal( 1., 1.), obs=0.) else: pyro.sample(“x”, Normal(-2., 1.), obs=0.) def qθ(): # guide_eg1’ θ = pyro.param(“θ”, 0.) z = pyro.sample(“z”, Uniform(θ-1., θ+1.))
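$\mathrm{KL}_\theta$ as a function of $\theta$
$\nabla_\theta \mathrm{KL}_\theta \neq 0$ for all $\theta$; $\widehat{\nabla_\theta \mathrm{KL}_\theta} = 0$ for all $\theta, z'$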
def p(): # model_eg1 z = pyro.sample(“z”, Normal(0., 5.)) if (z > 0): pyro.sample(“x”, Normal( 1., 1.), obs=0.) else: pyro.sample(“x”, Normal(-2., 1.), obs=0.) def qθ(): # guide_eg1’ θ = pyro.param(“θ”, 0.) z = pyro.sample(“z”, Uniform(θ-1., θ+1.))
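$\mathrm{KL}_\theta$ as a function of $\theta$
$\nabla_\theta \mathrm{KL}_\theta \neq 0$ for all $\theta$; $\widehat{\nabla_\theta \mathrm{KL}_\theta} = 0$ for all $\theta, z'$
$\nabla_\theta \int_{\theta-1}^{\theta+1} \cdots \, dz \neq \int_{\theta-1}^{\theta+1} \nabla_\theta \cdots \, dz$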
$\mu, \sigma$: mean and standard deviation in $p(z, x)$. $\mu', \sigma'$: mean and standard deviation in $q_\theta(z)$.
• $|\mu(z)| \le \exp(f(\|z\|))$ for some affine $f$.
• $\exp(g(\|z\|)) \le \sigma(z) \le \exp(h(\|z\|))$ for some affine $g, h$.
• $\mu', \sigma'$ are continuously differentiable w.r.t. $\theta$.
$\mu, \sigma$: mean and standard deviation in $p(z, x)$. $\mu', \sigma'$: mean and standard deviation in $q_\theta(z)$.
• $|\mu(z)| \le \exp(f(\|z\|))$ for some affine $f$.
• $\exp(g(\|z\|)) \le \sigma(z) \le \exp(h(\|z\|))$ for some affine $g, h$.
• $\mu', \sigma'$ are continuously differentiable w.r.t. $\theta$.
$\int_{-\infty}^{\infty} \exp(f(|z|)) \times \mathcal{N}(z; \cdots)\, dz < \infty$ for any affine $f$.
$\mu, \sigma$: mean and standard deviation in $p(z, x)$. $\mu', \sigma'$: mean and standard deviation in $q_\theta(z)$.
• $|\mu(z)| \le \exp(f(\|z\|))$ for some affine $f$.
• $\exp(g(\|z\|)) \le \sigma(z) \le \exp(h(\|z\|))$ for some affine $g, h$.
• $\mu', \sigma'$ are continuously differentiable w.r.t. $\theta$.
well-behaved function of $\theta, z$