
Reparameterization Gradient for Non-differentiable Models - PowerPoint PPT Presentation

Reparameterization Gradient for Non-differentiable Models. Wonyeol Lee, Hangyeol Yu, Hongseok Yang (KAIST). Published at NeurIPS 2018.


  1. Reparameterization Gradient for Non-differentiable Models. Wonyeol Lee, Hangyeol Yu, Hongseok Yang (KAIST). Published at NeurIPS 2018.

  5. Background

  6. Posterior inference • Latent variable z ∈ ℝ^n. • Observed variable x ∈ ℝ^m. • Joint density p(x, z). • Want to infer the posterior p(z | x_0) given a particular value x_0 of x.

  7. Variational inference 1. Fix a family of variational distributions {q_θ(z)}_θ. 2. Find q_θ(z) that approximates p(z | x_0) well. • Typically, by solving argmax_θ ELBO_θ, where ELBO_θ = 𝔼_{q_θ(z)}[ log( p(x_0, z) / q_θ(z) ) ].

  8. Variational inference (the family {q_θ(z)}_θ: differentiable & easy-to-sample) 1. Fix a family of variational distributions {q_θ(z)}_θ. 2. Find q_θ(z) that approximates p(z | x_0) well. • Typically, by solving argmax_θ ELBO_θ, where ELBO_θ = 𝔼_{q_θ(z)}[ log( p(x_0, z) / q_θ(z) ) ].

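To make the ELBO concrete in code, here is a minimal plain-Clojure sketch (no probabilistic-programming library) of a Monte Carlo estimate of ELBO_θ = 𝔼_{q_θ(z)}[ log p(x_0, z) - log q_θ(z) ]: draw z ~ q_θ(z) and average the integrand. The Gaussian family q_θ(z) = 𝒩(z | θ, 1) and the name log-joint (a caller-supplied function z -> log p(x_0, z)) are illustrative assumptions, not from the slides.

;; Monte Carlo estimate of ELBO_θ for an assumed family q_θ(z) = N(z | θ, 1).
(def ^java.util.Random rng (java.util.Random. 0))

(defn log-normal-pdf
  "log N(x | mu, sigma)"
  [x mu sigma]
  (let [d (/ (- x mu) sigma)]
    (- (* -0.5 d d)
       (Math/log (* sigma (Math/sqrt (* 2.0 Math/PI)))))))

(defn sample-normal [mu sigma]
  (+ mu (* sigma (.nextGaussian rng))))

(defn elbo-estimate
  "Average of log p(x0, z) - log q_θ(z) over n samples z ~ N(θ, 1).
   log-joint: z -> log p(x0, z), assumed given by the caller."
  [log-joint theta n]
  (/ (reduce +
             (repeatedly n
                         #(let [z (sample-normal theta 1.0)]
                            (- (log-joint z) (log-normal-pdf z theta 1.0)))))
     n))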

  11. Gradient ascent θ_{n+1} = θ_n + η × (∇_θ ELBO_θ)|_{θ=θ_n} • Difficult to compute ∇_θ ELBO_θ. • Use an estimated gradient instead.

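As a concrete sketch of this update rule, the loop below performs gradient ascent driven by any estimator of ∇_θ ELBO_θ. The names grad-estimate, theta0, eta (the step size η), and steps are caller-supplied assumptions; the estimator itself could be, for example, the reparameterization estimator described on the next slides.

;; Gradient ascent θ_{n+1} = θ_n + η × (estimate of ∇_θ ELBO_θ at θ_n),
;; for a scalar parameter θ. grad-estimate: θ -> noisy gradient estimate.
(defn gradient-ascent
  [grad-estimate theta0 eta steps]
  (reduce (fn [theta _]
            (+ theta (* eta (grad-estimate theta))))
          theta0
          (range steps)))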

  14. Reparameterization estimator • Works if p(x_0, z) is differentiable w.r.t. z. • Need a distribution q(ε) & a smooth function f_θ(ε) s.t. f_θ(ε) for ε ~ q(ε) has the distribution q_θ(z). • Derived from the equation: ∇_θ ELBO_θ = 𝔼_{q(ε)}[ ∇_θ ( .. f_θ(ε) .. f_θ(ε) .. ) ], where ( .. z .. z .. ) abbreviates the ELBO integrand log( p(x_0, z) / q_θ(z) ).

  18. Reparameterization estimator: ∇_θ ELBO_θ = ∇_θ 𝔼_{q_θ(z)}[ .. z .. z .. ] = ∇_θ 𝔼_{q(ε)}[ .. f_θ(ε) .. f_θ(ε) .. ] = 𝔼_{q(ε)}[ ∇_θ ( .. f_θ(ε) .. f_θ(ε) .. ) ] • Works if p(x_0, z) is differentiable w.r.t. z. • Need a distribution q(ε) & a smooth function f_θ(ε) s.t. f_θ(ε) for ε ~ q(ε) has the distribution q_θ(z). • Derived from the equation: ∇_θ ELBO_θ = 𝔼_{q(ε)}[ ∇_θ ( .. f_θ(ε) .. f_θ(ε) .. ) ].

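A minimal sketch of a single-sample reparameterization estimate under an assumed Gaussian location family q_θ(z) = 𝒩(z | θ, 1), i.e. f_θ(ε) = θ + ε with q(ε) = 𝒩(0, 1). For this family log q_θ(f_θ(ε)) = -ε²/2 - log√(2π) does not depend on θ, so ∇_θ of the integrand reduces to (d/dz) log p(x_0, z) evaluated at z = θ + ε. The function grad-log-joint (that derivative) and the rng argument are assumptions supplied by the caller, and the estimate is valid only under the differentiability condition stated above.

;; Single-sample reparameterization estimate of ∇_θ ELBO_θ for
;; q_θ(z) = N(z | θ, 1), i.e. z = f_θ(ε) = θ + ε with ε ~ N(0, 1).
(defn reparam-grad-estimate
  "grad-log-joint: z -> (d/dz) log p(x0, z). rng: a java.util.Random."
  [^java.util.Random rng grad-log-joint theta]
  (let [eps (.nextGaussian rng)   ; ε ~ q(ε) = N(0, 1)
        z   (+ theta eps)]        ; z = f_θ(ε)
    (grad-log-joint z)))          ; ∇_θ [ log p(x0, f_θ(ε)) - log q_θ(f_θ(ε)) ]

Averaging several such samples and plugging the result into the gradient-ascent loop sketched earlier gives the usual reparameterized variational-inference procedure.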

  20. Non-differentiable models from probabilistic programming

  21. (let [z (sample (normal 0 1))]
        (if (> z 0)
          (observe (normal 3 1) 0)
          (observe (normal -2 1) 0))
        z)

  29. (let [z (sample (normal 0 1))]
        (if (> z 0)
          (observe (normal 3 1) 0)
          (observe (normal -2 1) 0))
        z)
      r_1(z) = 𝒩(z | 0, 1) · 𝒩(x=0 | 3, 1)
      r_2(z) = 𝒩(z | 0, 1) · 𝒩(x=0 | -2, 1)
      p(z, x=0) = [z > 0] r_1(z) + [z ≤ 0] r_2(z)
      [Plot: the joint density p(z, x=0) as a function of z.]
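
The joint density on this slide can be written out directly; the sketch below is a plain-Clojure transcription of p(z, x=0) = [z>0] r_1(z) + [z≤0] r_2(z), with a hand-rolled Gaussian density (normal-pdf is not a library call). Note that the two branch constants 𝒩(0 | 3, 1) and 𝒩(0 | -2, 1) differ, so the density has a jump at z = 0.

;; p(z, x=0) = [z > 0]·N(z|0,1)·N(0|3,1) + [z ≤ 0]·N(z|0,1)·N(0|-2,1)
(defn normal-pdf [x mu sigma]
  (let [d (/ (- x mu) sigma)]
    (/ (Math/exp (* -0.5 d d))
       (* sigma (Math/sqrt (* 2.0 Math/PI))))))

(defn joint-density [z]
  (* (normal-pdf z 0.0 1.0)            ; N(z | 0, 1), the prior factor
     (if (> z 0)
       (normal-pdf 0.0 3.0 1.0)        ; branch observing 0 from N(3, 1): gives r_1
       (normal-pdf 0.0 -2.0 1.0))))    ; branch observing 0 from N(-2, 1): gives r_2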

  30. (let [z (sample (normal 0 1))]
        (if (> z 0)
          (observe (normal 3 1) 0)
          (observe (normal -2 1) 0))
        z)
      ≈
      (let [ε (sample (normal 0 1))
            z (+ ε θ)]
        z)
      r_1(z) = 𝒩(z | 0, 1) · 𝒩(x=0 | 3, 1)
      r_2(z) = 𝒩(z | 0, 1) · 𝒩(x=0 | -2, 1)
      p(z, x=0) = [z > 0] r_1(z) + [z ≤ 0] r_2(z)

  36. (let [z (sample (normal 0 1))]
        (if (> z 0)
          (observe (normal 3 1) 0)
          (observe (normal -2 1) 0))
        z)
      ≈
      (let [ε (sample (normal 0 1))
            z (+ ε θ)]
        z)
      q(ε) = 𝒩(ε | 0, 1), z = ε + θ
      r_1(z) = 𝒩(z | 0, 1) · 𝒩(x=0 | 3, 1)
      r_2(z) = 𝒩(z | 0, 1) · 𝒩(x=0 | -2, 1)
      p(z, x=0) = [z > 0] r_1(z) + [z ≤ 0] r_2(z)
      How to find a good θ? By gradient ascent on ELBO_θ:
      θ_{n+1} ← θ_n + η × (∇_θ ELBO_θ)|_{θ=θ_n}

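Putting the pieces together: a minimal sketch of the update θ_{n+1} ← θ_n + η × ∇_θ ELBO_θ for this model, reusing gradient-ascent and reparam-grad-estimate from the earlier sketches (z = ε + θ, so q_θ(z) = 𝒩(z | θ, 1)). The derivative d-log-joint, the step size, and the step count are illustrative assumptions. Note that log p(z, x=0) is not differentiable at z = 0, so the differentiability condition behind the standard estimator fails there; handling exactly this situation is what the paper is about, and the sketch below is the naive baseline, not the paper's estimator.

;; Away from the jump at z = 0, log p(z, x=0) = log N(z|0,1) + (a branch constant),
;; so (d/dz) log p(z, x=0) = -z for z ≠ 0.
(defn d-log-joint [z]
  (- z))

(def theta-hat
  (let [rng (java.util.Random. 42)]
    (gradient-ascent (fn [theta] (reparam-grad-estimate rng d-log-joint theta))
                     0.0      ; θ_0
                     0.05     ; step size η
                     1000)))  ; number of ascent steps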