

slide-1
SLIDE 1

Variational Autoencoders

slide-2
SLIDE 2

Recap: Story so far

  • A classification MLP actually comprises two components
  • A “feature extraction network” that converts the inputs into linearly separable features
  • Or nearly linearly separable features
  • A final linear classifier that operates on the linearly separable features
  • Neural networks can be used to perform linear or non-linear PCA
  • “Autoencoders”
  • Can also be used to compose constructive dictionaries for data
  • Which, in turn, can be used to model data distributions
slide-3
SLIDE 3


Recap: The penultimate layer

  • The network up to the output layer may be viewed as a transformation that transforms data from non-linear classes to linearly separable features
  • We can now attach any linear classifier above it for perfect classification
  • Need not be a perceptron
  • In fact, slapping an SVM on top of the features may be more generalizable!

slide-4
SLIDE 4

Recap: The behavior of the layers

slide-5
SLIDE 5

Recap: Auto-encoders and PCA

Training: Learning W by minimizing the L2 divergence

x̂ = WᵀWx
div(x̂, x) = ‖x − x̂‖² = ‖x − WᵀWx‖²
Ŵ = argmin_W E[ ‖x − WᵀWx‖² ] = argmin_W E[ div(x̂, x) ]

slide-6
SLIDE 6

Recap: Auto-encoders and PCA

  • The autoencoder finds the direction of maximum energy
  • Variance, if the input is a zero-mean RV
  • All input vectors are mapped onto a point on the principal axis

slide-7
SLIDE 7

Recap: Auto-encoders and PCA

  • Varying the hidden layer value only generates data along the learned manifold
  • May be poorly learned
  • Any input will result in an output along the learned manifold
slide-8
SLIDE 8


Recap: Learning a data-manifold

  • The decoder represents a source-specific generative dictionary
  • Exciting it will produce typical data from the source!

(Figure: a decoder used as a “Sax dictionary”)

slide-9
SLIDE 9

Overview

  • Just as autoencoders can be viewed as performing a non-linear PCA, variational autoencoders can be viewed as performing a non-linear Factor Analysis (FA)
  • Variational autoencoders (VAEs) get their name from variational inference, a technique that can be used for parameter estimation
  • We will introduce Factor Analysis, variational inference and expectation maximization, and finally VAEs

slide-10
SLIDE 10

Why Generative Models? Training data

  • Unsupervised/Semi-supervised learning: More training data available
  • E.g. all of the videos on YouTube
slide-11
SLIDE 11

Why generative models? Many right answers

  • Caption -> Image

A man in an orange jacket with sunglasses and a hat skis down a hill

  • Outline -> Image

https://openreview.net/pdf?id=Hyvw0L9el https://arxiv.org/abs/1611.07004

slide-12
SLIDE 12

Why generative models? Intrinsic to task

Example: Super resolution https://arxiv.org/abs/1609.04802

slide-13
SLIDE 13

Why generative models? Insight

https://bmcbioinformatics.biomedcentral.com/articles/10.1186/1471-2105-12-327

  • What kind of structure can we find in complex observations (MEG recording of brain activity above, gene-expression network to the left)?
  • Is there a low dimensional manifold underlying these complex observations?
  • What can we learn about the brain, cellular function, etc. if we know more about these manifolds?

slide-14
SLIDE 14

Factor Analysis

  • Generative model: Assumes that data are generated from real-valued latent variables

Bishop – Pattern Recognition and Machine Learning

slide-15
SLIDE 15

Factor Analysis model

Factor analysis assumes a generative model where the i-th observation, x_i ∈ R^D, is conditioned on a vector of real-valued latent variables z_i ∈ R^L.

Here we assume the prior distribution is Gaussian:
p(z_i) = N(z_i | μ_0, Σ_0)

We also will use a Gaussian for the data likelihood:
p(x_i | z_i, W, μ, Ψ) = N(W z_i + μ, Ψ)

where W ∈ R^(D×L), Ψ ∈ R^(D×D), and Ψ is diagonal.

slide-16
SLIDE 16

Marginal distribution of observed x_i

p(x_i | W, μ, Ψ) = ∫ N(W z_i + μ, Ψ) N(z_i | μ_0, Σ_0) dz_i = N(x_i | W μ_0 + μ, Ψ + W Σ_0 Wᵀ)

Note that we can rewrite this as:
p(x_i | Ŵ, μ̂, Ψ) = N(x_i | μ̂, Ψ + Ŵ Ŵᵀ)
where μ̂ = W μ_0 + μ and Ŵ = W Σ_0^(1/2)

Thus, without loss of generality (since μ_0, Σ_0 are absorbed into learnable parameters), we let:
p(z_i) = N(z_i | 0, I)
and find:
p(x_i | W, μ, Ψ) = N(x_i | μ, Ψ + W Wᵀ)
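
As a quick sanity check of this marginal, the following numpy sketch samples from the FA generative model and compares the empirical covariance of x against Ψ + WWᵀ (the dimensions, W, μ and diagonal Ψ below are arbitrary illustrative choices, not values from the slides):

    import numpy as np

    rng = np.random.default_rng(0)
    D, L, N = 5, 2, 200_000                        # data dim, latent dim, number of samples
    W = rng.normal(size=(D, L))                    # factor loading matrix
    mu = rng.normal(size=D)                        # data mean
    psi = np.diag(rng.uniform(0.1, 0.5, size=D))   # diagonal "unshared" variance

    # Generative model: z ~ N(0, I), x = W z + mu + noise, noise ~ N(0, Psi)
    z = rng.normal(size=(N, L))
    noise = rng.normal(size=(N, D)) * np.sqrt(np.diag(psi))
    x = z @ W.T + mu + noise

    # The marginal covariance predicted by the model is Psi + W W^T
    empirical_cov = np.cov(x, rowvar=False)
    model_cov = psi + W @ W.T
    print(np.abs(empirical_cov - model_cov).max())   # small, and shrinks as N grows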

slide-17
SLIDE 17

Marginal distribution interpretation

  • We can see from p(x_i | W, μ, Ψ) = N(x_i | μ, Ψ + W Wᵀ) that the covariance matrix of the data distribution is broken into 2 terms
  • A diagonal part Ψ: variance not shared between variables
  • A low rank matrix W Wᵀ: shared variance due to latent factors
slide-18
SLIDE 18

Special Case: Probabilistic PCA (PPCA)

  • Probabilistic PCA is a special case of Factor Analysis
  • We further restrict Ψ = σ²I (assume isotropic independent variance)
  • Possible to show that when the data are centered (μ = 0), the limiting case where σ → 0 gives back the same solution for W as PCA
  • Factor analysis is a generalization of PCA that models non-shared variance (can think of this as noise in some situations, or individual variation in others)

slide-19
SLIDE 19

Inference in FA

  • To find the parameters of the FA model, we use the Expectation Maximization (EM) algorithm
  • EM is very similar to variational inference
  • We’ll derive EM by first finding a lower bound on the log-likelihood we want to maximize, and then maximizing this lower bound

slide-20
SLIDE 20

Evidence Lower Bound decomposition

  • For any distributions q(z), p(z) we have:

KL( q(z) || p(z) ) ≜ ∫ q(z) log [ q(z) / p(z) ] dz

  • Consider the KL divergence of an arbitrary weighting distribution q(z) from a conditional distribution p(z|x, θ):

KL( q(z) || p(z|x, θ) ) ≜ ∫ q(z) log [ q(z) / p(z|x, θ) ] dz = ∫ q(z) [ log q(z) − log p(z|x, θ) ] dz

slide-21
SLIDE 21

Applying Bayes

log p(z|x, θ) = log [ p(x|z, θ) p(z|θ) / p(x|θ) ] = log p(x|z, θ) + log p(z|θ) − log p(x|θ)

Then:
KL( q(z) || p(z|x, θ) ) = ∫ q(z) [ log q(z) − log p(z|x, θ) ] dz
  = ∫ q(z) [ log q(z) − log p(x|z, θ) − log p(z|θ) + log p(x|θ) ] dz

slide-22
SLIDE 22

Rewriting the divergence

  • Since the last term does not depend on z, and we know ∫ q(z) dz = 1, we can pull it out of the integration:

∫ q(z) [ log q(z) − log p(x|z, θ) − log p(z|θ) + log p(x|θ) ] dz
  = ∫ q(z) [ log q(z) − log p(x|z, θ) − log p(z|θ) ] dz + log p(x|θ)
  = ∫ q(z) log [ q(z) / ( p(x|z, θ) p(z|θ) ) ] dz + log p(x|θ)
  = ∫ q(z) log [ q(z) / p(x, z|θ) ] dz + log p(x|θ)

Then we have:
KL( q(z) || p(z|x, θ) ) = KL( q(z) || p(x, z|θ) ) + log p(x|θ)

slide-23
SLIDE 23

Evidence Lower Bound

  • From basic probability we have:

KL( q(z) || p(z|x, θ) ) = KL( q(z) || p(x, z|θ) ) + log p(x|θ)

  • We can rearrange the terms to get the following decomposition:

log p(x|θ) = KL( q(z) || p(z|x, θ) ) − KL( q(z) || p(x, z|θ) )

  • We define the evidence lower bound (ELBO) as:

ℒ(q, θ) ≜ −KL( q(z) || p(x, z|θ) )

Then:
log p(x|θ) = KL( q(z) || p(z|x, θ) ) + ℒ(q, θ)

slide-24
SLIDE 24

Why the name evidence lower bound?

  • Rearranging the decomposition

log p(x|θ) = KL( q(z) || p(z|x, θ) ) + ℒ(q, θ)

  • we have

ℒ(q, θ) = log p(x|θ) − KL( q(z) || p(z|x, θ) )

  • Since KL( q(z) || p(z|x, θ) ) ≥ 0, ℒ(q, θ) is a lower bound on the log-likelihood we want to maximize
  • p(x|θ) is sometimes called the evidence
  • When is this bound tight? When q(z) = p(z|x, θ)
  • The ELBO is also sometimes called the variational bound
slide-25
SLIDE 25

Visualizing ELBO decomposition

  • Note: all we have done so far is decompose the log probability of the data; we still have exact equality
  • This holds for any distribution q

Bishop – Pattern Recognition and Machine Learning

slide-26
SLIDE 26

Expectation Maximization

  • Expectation Maximization alternately optimizes the ELBO, ℒ(q, θ), with respect to q (the E step) and θ (the M step)
  • Initialize θ^(0)
  • At each iteration t = 1, …
  • E step: Hold θ^(t−1) fixed, find q^(t) which maximizes ℒ(q, θ^(t−1))
  • M step: Hold q^(t) fixed, find θ^(t) which maximizes ℒ(q^(t), θ)
slide-27
SLIDE 27

The E step

  • Suppose we are at iteration t of our algorithm. How do we maximize ℒ(q, θ^(t−1)) with respect to q? We know that:

argmax_q ℒ(q, θ^(t−1)) = argmax_q [ log p(x|θ^(t−1)) − KL( q(z) || p(z|x, θ^(t−1)) ) ]

Bishop – Pattern Recognition and Machine Learning

slide-28
SLIDE 28

The E step

  • Suppose we are at iteration t of our algorithm. How do we maximize ℒ(q, θ^(t−1)) with respect to q? We know that:

argmax_q ℒ(q, θ^(t−1)) = argmax_q [ log p(x|θ^(t−1)) − KL( q(z) || p(z|x, θ^(t−1)) ) ]

  • The first term does not involve q, and we know the KL divergence must be non-negative
  • The best we can do is to make the KL divergence 0
  • Thus the solution is to set q^(t)(z) ← p(z|x, θ^(t−1))

Bishop – Pattern Recognition and Machine Learning

slide-29
SLIDE 29

The E step

  • Suppose we are at iteration t of our algorithm. How do we maximize ℒ(q, θ^(t−1)) with respect to q?

q^(t)(z) ← p(z|x, θ^(t−1))

Bishop – Pattern Recognition and Machine Learning

slide-30
SLIDE 30

The M step

  • Fixing q^(t)(z) we now solve:

argmax_θ ℒ(q^(t), θ) = argmax_θ −KL( q^(t)(z) || p(x, z|θ) )
  = argmax_θ −∫ q^(t)(z) log [ q^(t)(z) / p(x, z|θ) ] dz
  = argmax_θ ∫ q^(t)(z) [ log p(x, z|θ) − log q^(t)(z) ] dz
  = argmax_θ ∫ q^(t)(z) log p(x, z|θ) dz − ∫ q^(t)(z) log q^(t)(z) dz      (second integral is constant w.r.t. θ)
  = argmax_θ ∫ q^(t)(z) log p(x, z|θ) dz
  = argmax_θ E_{q^(t)(z)}[ log p(x, z|θ) ]

slide-31
SLIDE 31

The M step

  • After applying the E step, we increase the likelihood of the data by finding better parameters according to:

θ^(t) ← argmax_θ E_{q^(t)(z)}[ log p(x, z|θ) ]

Bishop – Pattern Recognition and Machine Learning

slide-32
SLIDE 32

EM algorithm

  • Initialize θ^(0)
  • At each iteration t = 1, …
  • E step: Update q^(t)(z) ← p(z|x, θ^(t−1))
  • M step: Update θ^(t) ← argmax_θ E_{q^(t)(z)}[ log p(x, z|θ) ]
slide-33
SLIDE 33

Why does EM work?

  • EM does coordinate ascent on the ELBO, ℒ(q, θ)
  • Each iteration increases the log-likelihood until q^(t) converges (i.e. we reach a local maximum)!
  • Simple to prove

Notice after the E step:
ℒ(q^(t), θ^(t−1)) = log p(x|θ^(t−1)) − KL( p(z|x, θ^(t−1)) || p(z|x, θ^(t−1)) ) = log p(x|θ^(t−1))
The ELBO is tight!

By definition of argmax in the M step:
ℒ(q^(t), θ^(t)) ≥ ℒ(q^(t), θ^(t−1))

By simple substitution:
ℒ(q^(t), θ^(t)) ≥ log p(x|θ^(t−1))

Rewriting the left hand side:
log p(x|θ^(t)) − KL( p(z|x, θ^(t−1)) || p(z|x, θ^(t)) ) ≥ log p(x|θ^(t−1))

Noting that KL is non-negative:
log p(x|θ^(t)) ≥ log p(x|θ^(t−1))

slide-34
SLIDE 34

Why does EM work?

  • This proof is saying the same thing we saw in pictures: make the KL 0, then improve our parameter estimates to get a better likelihood

Bishop – Pattern Recognition and Machine Learning

slide-35
SLIDE 35

A different perspective

  • Consider the log-likelihood of a marginal distribution of the data x in a generic latent variable model with latent variable z parameterized by θ:

ℓ(θ) ≜ Σ_{i=1}^N log p(x_i|θ) = Σ_{i=1}^N log ∫ p(x_i, z_i|θ) dz_i

  • Estimating θ is difficult because we have a log outside of the integral, so it does not act directly on the probability distribution (frequently in the exponential family)
  • If we observed z_i, then our log-likelihood would be:

ℓ_c(θ) ≜ Σ_{i=1}^N log p(x_i, z_i|θ)

This is called the complete log-likelihood

slide-36
SLIDE 36

Expected Complete Log-Likelihood

  • We can take the expectation of this likelihood over a distribution of the latent variable q(z):

E_{q(z)}[ ℓ_c(θ) ] = Σ_{i=1}^N ∫ q(z_i) log p(x_i, z_i|θ) dz_i

  • This looks similar to marginalizing, but now the log is inside the integral, so it’s easier to deal with
  • We can treat the latent variables as observed and solve this more easily than directly solving the log-likelihood
  • Finding the q that maximizes this is the E step of EM
  • Finding the θ that maximizes this is the M step of EM
slide-37
SLIDE 37

Back to Factor Analysis

  • For simplicity, assume data is centered. We want:

argmax_{W,Ψ} log p(X | W, Ψ) = argmax_{W,Ψ} Σ_{i=1}^N log p(x_i | W, Ψ) = argmax_{W,Ψ} Σ_{i=1}^N log N(x_i | 0, Ψ + W Wᵀ)

  • No closed form solution in general (PPCA can be solved in closed form)
  • Ψ and W get coupled together in the derivative and we can’t solve for them analytically

slide-38
SLIDE 38

EM for Factor Analysis

argmax_{W,Ψ} E_{q^(t)(Z)}[ log p(X, Z | W, Ψ) ]
  = argmax_{W,Ψ} Σ_{i=1}^N ( E_{q^(t)(z_i)}[ log p(x_i | z_i, W, Ψ) ] + E_{q^(t)(z_i)}[ log p(z_i) ] )
  = argmax_{W,Ψ} Σ_{i=1}^N E_{q^(t)(z_i)}[ log p(x_i | z_i, W, Ψ) ]
  = argmax_{W,Ψ} Σ_{i=1}^N E_{q^(t)(z_i)}[ log N(x_i | W z_i, Ψ) ]
  = argmax_{W,Ψ} ( const − (N/2) log det(Ψ) − Σ_{i=1}^N E_{q^(t)(z_i)}[ (1/2) (x_i − W z_i)ᵀ Ψ⁻¹ (x_i − W z_i) ] )
  = argmax_{W,Ψ} ( −(N/2) log det(Ψ) − Σ_{i=1}^N [ (1/2) x_iᵀ Ψ⁻¹ x_i − x_iᵀ Ψ⁻¹ W E_{q^(t)(z_i)}[z_i] + (1/2) tr( Wᵀ Ψ⁻¹ W E_{q^(t)(z_i)}[z_i z_iᵀ] ) ] )

  • We only need the two sufficient statistics, E_{q^(t)(z_i)}[z_i] and E_{q^(t)(z_i)}[z_i z_iᵀ], to enable the M step.
  • In practice, sufficient statistics are often what we compute in the E step
slide-39
SLIDE 39

Factor Analysis E step

E_{q^(t)(z_i)}[z_i] = G W^(t−1)ᵀ (Ψ^(t−1))⁻¹ x_i
E_{q^(t)(z_i)}[z_i z_iᵀ] = G + E_{q^(t)(z_i)}[z_i] E_{q^(t)(z_i)}[z_i]ᵀ

where G = ( I + W^(t−1)ᵀ (Ψ^(t−1))⁻¹ W^(t−1) )⁻¹

This is derived via the Bayes rule for Gaussians
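
A minimal numpy sketch of this E step (an illustration, not the lecture’s code; it assumes centered data X of shape (N, D), the current loading matrix W of shape (D, L), and the diagonal of Ψ stored as a vector):

    import numpy as np

    def fa_e_step(X, W, psi_diag):
        """Posterior sufficient statistics E[z_i] and E[z_i z_i^T] for each row of X."""
        D, L = W.shape
        psi_inv_W = W / psi_diag[:, None]                  # Psi^{-1} W, using the diagonal structure
        G = np.linalg.inv(np.eye(L) + W.T @ psi_inv_W)     # G = (I + W^T Psi^{-1} W)^{-1}
        Ez = X @ psi_inv_W @ G.T                           # rows are E[z_i] = G W^T Psi^{-1} x_i
        Ezz = G[None, :, :] + Ez[:, :, None] * Ez[:, None, :]   # (N, L, L): G + E[z_i] E[z_i]^T
        return Ez, Ezz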

slide-40
SLIDE 40

Factor Analysis M step

W^(t) ← [ Σ_{i=1}^N x_i E_{q^(t)(z_i)}[z_i]ᵀ ] [ Σ_{i=1}^N E_{q^(t)(z_i)}[z_i z_iᵀ] ]⁻¹

Ψ^(t) ← diag( (1/N) Σ_{i=1}^N x_i x_iᵀ − W^(t) (1/N) Σ_{i=1}^N E_{q^(t)(z_i)}[z_i] x_iᵀ )
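
Continuing the same sketch (fa_e_step as defined above, same assumed shapes), the M step and a small EM loop might look like this:

    import numpy as np

    def fa_m_step(X, Ez, Ezz):
        """Update W and the diagonal of Psi from the E-step statistics."""
        N = X.shape[0]
        W_new = (X.T @ Ez) @ np.linalg.inv(Ezz.sum(axis=0))   # [sum_i x_i E[z_i]^T][sum_i E[z_i z_i^T]]^{-1}
        S = (X.T @ X) / N                                     # (1/N) sum_i x_i x_i^T
        psi_new = np.diag(S - W_new @ (Ez.T @ X) / N)         # keep only the diagonal
        return W_new, psi_new

    def fa_em(X, L, n_iter=100, seed=0):
        """EM for factor analysis on centered data X with L latent factors."""
        rng = np.random.default_rng(seed)
        W = rng.normal(scale=0.1, size=(X.shape[1], L))
        psi_diag = X.var(axis=0)
        for _ in range(n_iter):
            Ez, Ezz = fa_e_step(X, W, psi_diag)
            W, psi_diag = fa_m_step(X, Ez, Ezz)
        return W, psi_diag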

slide-41
SLIDE 41

From EM to Variational Inference

  • In EM we alternately maximize the ELBO with respect to θ and the probability distribution (functional) q
  • In variational inference, we drop the distinction between hidden variables and parameters of a distribution
  • I.e. we replace p(x, z|θ) with p(x, z). Effectively this puts a probability distribution on the parameters θ, then absorbs them into z
  • Fully Bayesian treatment instead of a point estimate for the parameters

slide-42
SLIDE 42

Variational Inference

  • Now the ELBO is just a function of our weighting distribution: ℒ(q)
  • We assume a form for q that we can optimize
  • For example, mean field theory assumes q factorizes:

q(z) = Π_{i=1}^M q_i(z_i)

  • Then we optimize ℒ(q) with respect to one of the terms while holding the others constant, and repeat for all terms
  • By assuming a form for q we approximate a (typically) intractable true posterior

slide-43
SLIDE 43

Mean Field update derivation

ℒ(q) = ∫ q(z) log [ p(X, z) / q(z) ] dz
  = ∫ [ q(z) log p(X, z) − q(z) log q(z) ] dz
  = ∫ Π_i q_i(z_i) [ log p(X, z) − Σ_k log q_k(z_k) ] dz
  = ∫ q_j(z_j) { ∫ Π_{i≠j} q_i(z_i) [ log p(X, z) − Σ_k log q_k(z_k) ] dz_{i≠j} } dz_j
  = ∫ q_j(z_j) { ∫ log p(X, z) Π_{i≠j} q_i(z_i) dz_{i≠j} } dz_j − ∫ q_j(z_j) log q_j(z_j) { ∫ Π_{i≠j} q_i(z_i) dz_{i≠j} } dz_j + const
  = ∫ q_j(z_j) { ∫ log p(X, z) Π_{i≠j} q_i(z_i) dz_{i≠j} } dz_j − ∫ q_j(z_j) log q_j(z_j) dz_j + const
  = ∫ q_j(z_j) E_{i≠j}[ log p(X, z) ] dz_j − ∫ q_j(z_j) log q_j(z_j) dz_j + const

slide-44
SLIDE 44

Mean Field update

q_j^(t)(z_j) ← argmax_{q_j(z_j)} { ∫ q_j(z_j) E_{i≠j}[ log p(X, z) ] dz_j − ∫ q_j(z_j) log q_j(z_j) dz_j }

  • The point of this is not the update equations themselves, but the general idea:
  • freeze some of the variables, compute expectations over those
  • update the rest using these expectations
slide-45
SLIDE 45

Why does Variational Inference work?

  • The argument is similar to the argument for EM
  • When expectations are computed using the current values for the variables not being updated, we implicitly set the KL divergence between the weighting distributions and the posterior distributions to 0
  • The update then pushes up the data likelihood
Bishop – Pattern Recognition and Machine Learning

slide-46
SLIDE 46

Variational Autoencoder

  • Kingma & Welling: Auto-Encoding Variational Bayes proposes maximizing the ELBO with a trick to make it differentiable
  • It discusses both the variational autoencoder model using parametric distributions and fully Bayesian variational inference, but we will only discuss the variational autoencoder

slide-47
SLIDE 47

Problem Setup

  • Assume a generative model with a latent variable distributed according to some distribution p(z_i)
  • The observed variable is distributed according to a conditional distribution p(x_i|z_i, θ)
  • Note the similarity to the Factor Analysis (FA) setup so far

(Figure: encoder q(z_i|x_i, φ), a sample z_i ~ q(z_i|x_i, φ), and decoder p(x_i|z_i, θ))

slide-48
SLIDE 48

Problem Setup

  • We also create a weighting distribution q(z_i|x_i, φ)
  • This will play the same role as q(z_i) in the EM algorithm, as we will see.
  • Note that when we discussed EM, this weighting distribution could be arbitrary: we choose to condition on x_i here. This is a choice.
  • Why does this make sense?

slide-49
SLIDE 49

Using a conditional weighting distribution

  • There are many values of the latent variables that don’t matter in practice – by conditioning on the observed variables, we emphasize the latent variable values we actually care about: the ones most likely given the observations
  • We would like to be able to encode our data into the latent variable space. This conditional weighting distribution enables that encoding
slide-50
SLIDE 50

Problem setup

  • Implement p(x_i|z_i, θ) as a neural network; this can also be seen as a probabilistic decoder
  • Implement q(z_i|x_i, φ) as a neural network; we can also see this as a probabilistic encoder
  • Sample z_i from q(z_i|x_i, φ) in the middle

slide-51
SLIDE 51

Unpacking the encoder

  • We choose a family of distributions for our conditional distribution q. For example, Gaussian with diagonal covariance:

q(z_i | x_i, φ) = N( z_i | μ = u(x_i; W_1), Σ = diag( s(x_i; W_2) ) )

slide-52
SLIDE 52

Unpacking the encoder

  • We create neural networks to predict the parameters of q from our data
  • In this case, the outputs of our networks are μ and Σ

slide-53
SLIDE 53

Unpacking the encoder

  • We refer to the parameters of our networks, W_1 and W_2, collectively as φ
  • Together, networks u and s parameterize a distribution, q(z_i|x_i, φ), of the latent variable z_i that depends in a complicated, non-linear way on x_i
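
A minimal PyTorch sketch of such an encoder (hypothetical layer sizes and MLP structure; the two output heads play the roles of u(·; W_1) and s(·; W_2), predicting μ and the log of the diagonal of Σ):

    import torch
    import torch.nn as nn

    class Encoder(nn.Module):
        """Maps an observation x to the parameters (mu, log-variance) of q(z | x, phi)."""
        def __init__(self, data_dim=784, hidden_dim=400, latent_dim=20):
            super().__init__()
            self.hidden = nn.Sequential(nn.Linear(data_dim, hidden_dim), nn.ReLU())
            self.mu_head = nn.Linear(hidden_dim, latent_dim)       # role of u(x; W_1)
            self.logvar_head = nn.Linear(hidden_dim, latent_dim)   # role of s(x; W_2): log diag(Sigma)

        def forward(self, x):
            h = self.hidden(x)
            return self.mu_head(h), self.logvar_head(h)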

slide-54
SLIDE 54

Unpacking the decoder

  • The decoder follows the same logic, just swapping x_i and z_i
  • We refer to the parameters of our networks, W_3 and W_4, collectively as θ
  • Together, networks u_d and s_d parameterize a distribution, p(x_i|z_i, θ), of the observed variable x_i that depends in a complicated, non-linear way on z_i

μ = u_d(z_i; W_3), Σ = diag( s_d(z_i; W_4) ), with z_i ~ q(z_i|x_i, φ)

slide-55
SLIDE 55

Understanding the setup

  • Note that p and q do not have to use the same distribution family; this was just an example
  • This basically looks like an autoencoder, but the outputs of both the encoder and decoder are parameters of the distributions of the latent and observed variables respectively
  • We also have a sampling step in the middle

slide-56
SLIDE 56

Using EM for training

  • Initialize θ^(0)
  • At each iteration t = 1, …, T
  • E step: Hold θ^(t−1) fixed, find q^(t) which maximizes ℒ(q, θ^(t−1))
  • M step: Hold q^(t) fixed, find θ^(t) which maximizes ℒ(q^(t), θ)
  • We will use a modified EM to train the model, but we will transform it so we can use standard back propagation!

slide-57
SLIDE 57

Using EM for training

  • Initialize θ^(0)
  • At each iteration t = 1, …, T
  • E step: Hold θ^(t−1) fixed, find φ^(t) which maximizes ℒ(φ, θ^(t−1), x)
  • M step: Hold φ^(t) fixed, find θ^(t) which maximizes ℒ(φ^(t), θ, x)
  • First we modify the notation to account for our choice of using a parametric, conditional distribution q

slide-58
SLIDE 58

Using EM for training

  • Initialize θ^(0)
  • At each iteration t = 1, …, T
  • E step: Hold θ^(t−1) fixed, find ∂ℒ/∂φ to increase ℒ(φ, θ^(t−1), x)
  • M step: Hold φ^(t) fixed, find ∂ℒ/∂θ to increase ℒ(φ^(t), θ, x)
  • Instead of fully maximizing at each iteration, we just take a step in the direction that increases ℒ

slide-59
SLIDE 59

Computing the loss

  • We need to compute the gradient for each mini-batch of B data samples, using the ELBO/variational bound ℒ(φ, θ, x_i) as the loss:

Σ_{i=1}^B ℒ(φ, θ, x_i) = Σ_{i=1}^B −KL( q(z_i|x_i, φ) || p(x_i, z_i|θ) )
  = Σ_{i=1}^B −E_{q(z_i|x_i, φ)}[ log ( q(z_i|x_i, φ) / p(x_i, z_i|θ) ) ]

  • Notice that this involves an intractable integral over all values of z
  • We can use Monte Carlo sampling to approximate the expectation using L samples from q(z_i|x_i, φ):

E_{q(z_i|x_i, φ)}[ f(z_i) ] ≃ (1/L) Σ_{l=1}^L f(z_{i,l})

ℒ(φ, θ, x_i) ≃ ℒ̃^A(φ, θ, x_i) = (1/L) Σ_{l=1}^L [ log p(x_i, z_{i,l}|θ) − log q(z_{i,l}|x_i, φ) ]
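
As a concrete illustration of ℒ̃^A with a single sample, here is a PyTorch sketch; it assumes binarized data with a Bernoulli decoder (a common simplification; a Gaussian decoder as in the earlier example would just use a Gaussian log-likelihood instead), with mu and logvar the encoder outputs, z a sample from q(z|x, φ), and x_logits the decoder outputs:

    import math
    import torch
    import torch.nn.functional as F

    def elbo_estimate_A(x, mu, logvar, z, x_logits):
        """Single-sample Monte Carlo estimate of L~^A, averaged over a mini-batch."""
        # log p(x | z, theta): Bernoulli decoder on binarized data
        log_px_z = -F.binary_cross_entropy_with_logits(x_logits, x, reduction="none").sum(dim=1)
        # log p(z): standard normal prior N(0, I)
        log_pz = -0.5 * (z ** 2 + math.log(2 * math.pi)).sum(dim=1)
        # log q(z | x, phi): diagonal Gaussian with parameters (mu, logvar)
        log_qz_x = -0.5 * (logvar + (z - mu) ** 2 / logvar.exp() + math.log(2 * math.pi)).sum(dim=1)
        return (log_px_z + log_pz - log_qz_x).mean()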

slide-60
SLIDE 60

A lower variance estimator of the loss

  • We can rewrite

ℒ(φ, θ, x) = −KL( q(z|x, φ) || p(x, z|θ) )
  = −∫ q(z|x, φ) log [ q(z|x, φ) / ( p(x|z, θ) p(z) ) ] dz
  = −∫ q(z|x, φ) [ log ( q(z|x, φ) / p(z) ) − log p(x|z, θ) ] dz
  = −KL( q(z|x, φ) || p(z) ) + E_{q(z|x, φ)}[ log p(x|z, θ) ]

  • The first term can be computed analytically for some families of distributions (e.g. Gaussian); only the second term must be estimated:

ℒ(φ, θ, x_i) ≃ ℒ̃^B(φ, θ, x_i) = −KL( q(z_i|x_i, φ) || p(z_i) ) + (1/L) Σ_{l=1}^L log p(x_i|z_{i,l}, θ)
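
For a diagonal-Gaussian encoder and a standard normal prior, the KL term has a well-known closed form, so ℒ̃^B can be sketched as follows (same assumptions and tensor names as the previous snippet):

    import torch
    import torch.nn.functional as F

    def elbo_estimate_B(x, mu, logvar, x_logits):
        """L~^B: analytic KL( q(z|x, phi) || N(0, I) ) plus a one-sample reconstruction term."""
        kl = 0.5 * (mu ** 2 + logvar.exp() - logvar - 1.0).sum(dim=1)   # closed-form KL per data point
        log_px_z = -F.binary_cross_entropy_with_logits(x_logits, x, reduction="none").sum(dim=1)
        return (log_px_z - kl).mean()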

slide-61
SLIDE 61

Full EM training procedure (not really used)

  • For t = 1 : B : T
    • Estimate ∂ℒ/∂φ (How do we do this? We’ll get to it shortly)
    • Update φ
    • Estimate ∂ℒ/∂θ:
      • Initialize Δθ = 0
      • For i = t : t + B − 1
        • Compute the outputs of the encoder (parameters of q) for x_i
        • For l = 1, …, L
          • Sample z_{i,l} ~ q(z_i|x_i, φ)
          • Δθ_{i,l} ← Run forward/backward pass on the decoder (standard back propagation) using either ℒ̃^A or ℒ̃^B as the loss
          • Δθ ← Δθ + Δθ_{i,l}
      • Update θ

slide-62
SLIDE 62

Full EM training procedure (not really used)

  • For t = 1 : B : T
    • Estimate ∂ℒ/∂φ (How do we do this? We’ll get to it shortly)
    • Update φ
    • Estimate ∂ℒ/∂θ:
      • Initialize Δθ = 0
      • For i = t : t + B − 1
        • Compute the outputs of the encoder (parameters of q) for x_i
        • Sample z_i ~ q(z_i|x_i, φ)
        • Δθ_i ← Run forward/backward pass on the decoder (standard back propagation) using either ℒ̃^A or ℒ̃^B as the loss
        • Δθ ← Δθ + Δθ_i
      • Update θ

  • First simplification: Let L = 1. We just want a stochastic estimate of the gradient. With a large enough B, we get enough samples from q(z_i|x_i, φ)

slide-63
SLIDE 63

The E step

  • We can use standard back propagation to estimate ∂ℒ/∂θ
  • How do we estimate ∂ℒ/∂φ?
  • The sampling step blocks the gradient flow
  • Computing the derivatives through q via the chain rule gives a very high variance estimate of the gradient

slide-64
SLIDE 64

Reparameterization

  • Instead of drawing z_i ~ q(z_i|x_i, φ), let z_i = g(ε_i, x_i, φ), and draw ε_i ~ p(ε)
  • z_i is still a random variable but depends on φ deterministically
  • Replace E_{q(z_i|x_i, φ)}[ f(z_i) ] with E_{p(ε)}[ f( g(ε_i, x_i, φ) ) ]
  • Example – univariate normal: a ~ N(μ, σ²) is equivalent to a = g(ε), ε ~ N(0, 1), g(ε) ≜ μ + σε
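
In code, the reparameterized sample for a diagonal-Gaussian encoder is a deterministic function of (μ, log σ²) and the noise ε, so gradients flow back to φ (a sketch; tensor names follow the earlier encoder example):

    import torch

    def reparameterize(mu, logvar):
        """Draw z = mu + sigma * eps with eps ~ N(0, I); differentiable w.r.t. mu and logvar."""
        std = torch.exp(0.5 * logvar)    # sigma = exp(logvar / 2)
        eps = torch.randn_like(std)      # eps ~ N(0, I); no gradient flows through the sampling itself
        return mu + std * eps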

slide-65
SLIDE 65

Reparameterization

(Figure: on the left, the original setup with encoder q(z_i|x_i, φ), a sampling step z_i ~ q(z_i|x_i, φ) that blocks gradients, and decoder p(x_i|z_i, θ); on the right, the reparameterized setup in which z_i = g(ε_i, x_i, φ) with ε_i ~ p(ε) feeds the decoder p(x_i|z_i, θ).)

slide-66
SLIDE 66

Full EM training procedure (not really used)

  • For t = 1 : B : T
    • E Step
      • Estimate ∂ℒ/∂φ using standard back propagation with either ℒ̃^A or ℒ̃^B as the loss
      • Update φ
    • M Step
      • Estimate ∂ℒ/∂θ using standard back propagation with either ℒ̃^A or ℒ̃^B as the loss
      • Update θ

slide-67
SLIDE 67

Full training procedure

  • For t = 1 : B : T
    • Estimate ∂ℒ/∂φ and ∂ℒ/∂θ with either ℒ̃^A or ℒ̃^B as the loss
    • Update φ, θ
  • Final simplification: update all of the parameters at the same time instead of using separate E, M steps
  • This is standard back propagation. Just use −ℒ̃^A or −ℒ̃^B as the loss, and run your favorite SGD variant
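
Putting the pieces together, a compact end-to-end sketch of this training loop (hypothetical architecture and hyperparameters, MNIST-sized binarized inputs, a Bernoulli decoder, and −ℒ̃^B as the loss; an illustration, not the lecture’s reference code):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class VAE(nn.Module):
        def __init__(self, data_dim=784, hidden_dim=400, latent_dim=20):
            super().__init__()
            self.enc = nn.Sequential(nn.Linear(data_dim, hidden_dim), nn.ReLU())
            self.mu_head = nn.Linear(hidden_dim, latent_dim)
            self.logvar_head = nn.Linear(hidden_dim, latent_dim)
            self.dec = nn.Sequential(nn.Linear(latent_dim, hidden_dim), nn.ReLU(),
                                     nn.Linear(hidden_dim, data_dim))   # Bernoulli logits

        def forward(self, x):
            h = self.enc(x)
            mu, logvar = self.mu_head(h), self.logvar_head(h)
            z = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)     # reparameterized sample
            return self.dec(z), mu, logvar

    def train_step(model, optimizer, x):
        """One SGD step on -L~^B: analytic KL plus a single-sample reconstruction term."""
        x_logits, mu, logvar = model(x)
        recon = F.binary_cross_entropy_with_logits(x_logits, x, reduction="none").sum(dim=1)
        kl = 0.5 * (mu ** 2 + logvar.exp() - logvar - 1.0).sum(dim=1)
        loss = (recon + kl).mean()        # = -L~^B averaged over the mini-batch
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss.item()

    # Usage sketch:
    # model = VAE(); opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    # for x in data_loader:
    #     train_step(model, opt, x.view(x.size(0), -1))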

slide-68
SLIDE 68

Running the model on new data

  • To get a MAP estimate of the latent variables, just use the mean output by the encoder (for a Gaussian distribution)
  • No need to take a sample
  • Give the mean to the decoder
  • At test time, this is used just as an auto-encoder
  • You can optionally take multiple samples of the latent variables to estimate the uncertainty
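
With the sketched VAE above, a test-time pass might look like this: no sampling, just hand the posterior mean to the decoder.

    import torch

    @torch.no_grad()
    def encode_and_reconstruct(model, x):
        """Deterministic test-time pass using the posterior mean as the latent code."""
        h = model.enc(x)
        z_map = model.mu_head(h)               # MAP estimate of z under the Gaussian q(z | x, phi)
        x_logits = model.dec(z_map)            # give the mean directly to the decoder
        return z_map, torch.sigmoid(x_logits)  # latent code and reconstruction probabilities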

slide-69
SLIDE 69

Relationship to Factor Analysis

  • The VAE performs probabilistic, non-linear dimensionality reduction
  • It uses a generative model with a latent variable distributed according to some prior distribution p(z_i)
  • The observed variable is distributed according to a conditional distribution p(x_i|z_i, θ)
  • Training is approximately running expectation maximization to maximize the data likelihood
  • This can be seen as a non-linear version of Factor Analysis

slide-70
SLIDE 70

Regularization by a prior

  • Looking at the form of ℒ we used to justify ℒ̃^B gives us additional insight:

ℒ(φ, θ, x) = −KL( q(z|x, φ) || p(z) ) + E_{q(z|x, φ)}[ log p(x|z, θ) ]

  • We are making the latent distribution as close as possible to a prior on z
  • While maximizing the conditional likelihood of the data under our model
  • In other words, this is an approximation to Maximum Likelihood Estimation regularized by a prior on the latent space

slide-71
SLIDE 71

Practical advantages of a VAE vs. an AE

  • The prior on the latent space:
  • Allows you to inject domain knowledge
  • Can make the latent space more interpretable
  • The VAE also makes it possible to estimate the variance/uncertainty in the predictions

slide-72
SLIDE 72

Interpreting the latent space

https://arxiv.org/pdf/1610.00291.pdf

slide-73
SLIDE 73

Requirements of the VAE

  • Note that the VAE requires 2 tractable distributions to be used:
  • The prior distribution p(z) must be easy to sample from
  • The conditional likelihood p(x|z, θ) must be computable
  • In practice this means that the 2 distributions of interest are often simple, for example uniform, Gaussian, or even isotropic Gaussian

slide-74
SLIDE 74

The blurry image problem

https://blog.openai.com/generative-models/

  • The samples from the VAE look blurry
  • Three plausible explanations for this:
  • Maximizing the likelihood
  • Restrictions on the family of distributions
  • The lower bound approximation

slide-75
SLIDE 75

The maximum likelihood explanation

https://arxiv.org/pdf/1701.00160.pdf

  • Recent evidence suggests that this is not actually the problem
  • GANs can be trained with maximum likelihood and still generate sharp examples

slide-76
SLIDE 76

Investigations of blurriness

  • Recent investigations suggest that both the simple probability distributions and the variational approximation lead to blurry images
  • Kingma & colleagues: Improving Variational Inference with Inverse Autoregressive Flow
  • Zhao & colleagues: Towards a Deeper Understanding of Variational Autoencoding Models
  • Nowozin & colleagues: f-GAN: Training Generative Neural Samplers Using Variational Divergence Minimization