Variational Autoencoders
Recap: Story so far
- A classification MLP actually comprises two components
- A “feature extraction network” that converts the inputs into linearly
separable features
- Or nearly linearly separable features
- A final linear classifier that operates on the linearly separable features
- Neural networks can be used to perform linear or non-linear PCA
- “Autoencoders”
- Can also be used to compose constructive dictionaries for data
- Which, in turn can be used to model data distributions
Recap: The penultimate layer
- The network up to the output layer may be viewed as a transformation that
transforms data from non-linear classes to linearly separable features
- We can now attach any linear classifier above it for perfect classification
- Need not be a perceptron
- In fact, slapping an SVM on top of the features may be more generalizable!
Recap: The behavior of the layers
Recap: Auto-encoders and PCA
- Training: learning W by minimizing the L2 divergence between the input x and the reconstruction x̂:
  x̂ = WᵀWx
  div(x̂, x) = ‖x − x̂‖² = ‖x − WᵀWx‖²
  Ŵ = argmin_W E[div(x̂, x)] = argmin_W E‖x − WᵀWx‖²
Recap: Auto-encoders and PCA
- The autoencoder finds the direction of maximum energy
- Variance if the input is a zero-mean RV
- All input vectors are mapped onto a point on the principal
axis
Recap: Auto-encoders and PCA
- Varying the hidden layer value only generates data along
the learned manifold
- May be poorly learned
- Any input will result in an output along the learned manifold
Recap: Learning a data-manifold
- The decoder represents a source-specific generative
dictionary
- Exciting it will produce typical data from the source!
(figure: a "sax dictionary" learned by the decoder)
Overview
- Just as autoencoders can be viewed as performing a non-linear PCA,
variational autoencoders can be viewed as performing a non-linear Factor Analysis (FA)
- Variational autoencoders (VAEs) get their name from variational
inference, a technique that can be used for parameter estimation
- We will introduce Factor Analysis, variational inference and
expectation maximization, and finally VAEs
Why Generative Models? Training data
- Unsupervised/Semi-supervised learning: More training data available
- E.g. all of the videos on YouTube
Why generative models? Many right answers
- Caption -> Image
A man in an orange jacket with sunglasses and a hat skis down a hill
- Outline -> Image
https://openreview.net/pdf?id=Hyvw0L9el https://arxiv.org/abs/1611.07004
Why generative models? Intrinsic to task
Example: Super resolution https://arxiv.org/abs/1609.04802
Why generative models? Insight
https://bmcbioinformatics.biomedcentral.com/articles/10.1186/1471-2105-12-327
- What kind of structure can we find in complex observations (MEG recording of brain activity above, gene-expression network to the left)?
- Is there a low dimensional manifold underlying
these complex observations?
- What can we learn about the brain, cellular
function, etc. if we know more about these manifolds?
Factor Analysis
- Generative model: Assumes that data are generated from real valued
latent variables
Bishop – Pattern Recognition and Machine Learning
Factor Analysis model
Factor analysis assumes a generative model
- where the i-th observation, x_i ∈ ℝ^D, is conditioned on
- a vector of real-valued latent variables z_i ∈ ℝ^M.
Here we assume the prior distribution is Gaussian: p(z_i) = N(z_i | μ_0, Σ_0)
We also use a Gaussian for the data likelihood: p(x_i | z_i, W, μ, Ψ) = N(W z_i + μ, Ψ)
where W ∈ ℝ^(D×M), Ψ ∈ ℝ^(D×D), and Ψ is diagonal
Marginal distribution of observed x_i
p(x_i | W, μ, Ψ) = ∫ N(W z_i + μ, Ψ) N(z_i | μ_0, Σ_0) dz_i = N(x_i | W μ_0 + μ, Ψ + W Σ_0 Wᵀ)
Note that we can rewrite this as:
p(x_i | W̃, μ̂, Ψ) = N(x_i | μ̂, Ψ + W̃ W̃ᵀ), where μ̂ = W μ_0 + μ and W̃ = W Σ_0^(1/2).
Thus, without loss of generality (since μ_0, Σ_0 are absorbed into learnable parameters), we let p(z_i) = N(z_i | 0, I) and find:
p(x_i | W, μ, Ψ) = N(x_i | μ, Ψ + W Wᵀ)
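As a quick illustration (not from the slides), the following NumPy sketch draws samples from this generative model with made-up dimensions and parameter values, and checks that the empirical covariance of x approaches Ψ + WWᵀ, matching the marginal derived above:

```python
import numpy as np

rng = np.random.default_rng(0)
D, M, N = 5, 2, 200_000            # data dim, latent dim, number of samples (arbitrary)

W = rng.normal(size=(D, M))        # factor loading matrix
mu = rng.normal(size=D)            # data mean
Psi = np.diag(rng.uniform(0.1, 0.5, size=D))   # diagonal noise covariance

# Generative process: z_i ~ N(0, I), x_i | z_i ~ N(W z_i + mu, Psi)
Z = rng.normal(size=(N, M))
noise = rng.multivariate_normal(np.zeros(D), Psi, size=N)
X = Z @ W.T + mu + noise

# The empirical covariance of x should approach Psi + W W^T
print(np.round(np.cov(X, rowvar=False), 2))
print(np.round(Psi + W @ W.T, 2))
```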
Marginal distribution interpretation
- We can see from p(x_i | W, μ, Ψ) = N(x_i | μ, Ψ + W Wᵀ) that the covariance matrix of the data distribution is broken into 2 terms
- A diagonal part Ψ: variance not shared between variables
- A low-rank matrix W Wᵀ: shared variance due to latent factors
Special Case: Probabilistic PCA (PPCA)
- Probabilistic PCA is a special case of Factor Analysis
- We further restrict Ψ = σ²I (assume isotropic independent variance)
- Possible to show that when the data are centered (μ = 0), the limiting case where σ → 0 gives back the same solution for W as PCA
- Factor analysis is a generalization of PCA that models non-shared
variance (can think of this as noise in some situations, or individual variation in others)
Inference in FA
- To find the parameters of the FA model, we use the Expectation
Maximization (EM) algorithm
- EM is very similar to variational inference
- We’ll derive EM by first finding a lower bound on the log-likelihood
we want to maximize, and then maximizing this lower bound
Evidence Lower Bound decomposition
- For any distributions q(z) and p(z) we have:
KL(q(z) ‖ p(z)) ≜ ∫ q(z) log [ q(z) / p(z) ] dz
- Consider the KL divergence of an arbitrary weighting distribution q(z) from a conditional distribution p(z | x, θ):
KL(q(z) ‖ p(z | x, θ)) ≜ ∫ q(z) log [ q(z) / p(z | x, θ) ] dz = ∫ q(z) [ log q(z) − log p(z | x, θ) ] dz
Applying Bayes' rule
log p(z | x, θ) = log [ p(x | z, θ) p(z | θ) / p(x | θ) ] = log p(x | z, θ) + log p(z | θ) − log p(x | θ)
Then:
KL(q(z) ‖ p(z | x, θ)) = ∫ q(z) [ log q(z) − log p(z | x, θ) ] dz
= ∫ q(z) [ log q(z) − log p(x | z, θ) − log p(z | θ) + log p(x | θ) ] dz
Rewriting the divergence
- Since the last term does not depend on z, and we know ∫ q(z) dz = 1, we can pull it out of the integration:
∫ q(z) [ log q(z) − log p(x | z, θ) − log p(z | θ) + log p(x | θ) ] dz
= ∫ q(z) [ log q(z) − log p(x | z, θ) − log p(z | θ) ] dz + log p(x | θ)
= ∫ q(z) log [ q(z) / ( p(x | z, θ) p(z | θ) ) ] dz + log p(x | θ)
= ∫ q(z) log [ q(z) / p(x, z | θ) ] dz + log p(x | θ)
Then we have:
KL(q(z) ‖ p(z | x, θ)) = KL(q(z) ‖ p(x, z | θ)) + log p(x | θ)
Evidence Lower Bound
- From basic probability we have:
KL(q(z) ‖ p(z | x, θ)) = KL(q(z) ‖ p(x, z | θ)) + log p(x | θ)
- We can rearrange the terms to get the following decomposition:
log p(x | θ) = KL(q(z) ‖ p(z | x, θ)) − KL(q(z) ‖ p(x, z | θ))
- We define the evidence lower bound (ELBO) as:
ℒ(q, θ) ≜ −KL(q(z) ‖ p(x, z | θ))
Then: log p(x | θ) = KL(q(z) ‖ p(z | x, θ)) + ℒ(q, θ)
Why the name evidence lower bound?
- Rearranging the decomposition
log p(x | θ) = KL(q(z) ‖ p(z | x, θ)) + ℒ(q, θ)
- we have
ℒ(q, θ) = log p(x | θ) − KL(q(z) ‖ p(z | x, θ))
- Since KL(q(z) ‖ p(z | x, θ)) ≥ 0, ℒ(q, θ) is a lower bound on the log-likelihood we want to maximize
- p(x | θ) is sometimes called the evidence
- When is this bound tight? When q(z) = p(z | x, θ)
- The ELBO is also sometimes called the variational bound
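To make the decomposition concrete, here is a small numeric check (a sketch, not from the slides) using a toy model with a binary latent variable, so every integral becomes a two-term sum; the distributions and the observed value are arbitrary choices:

```python
import numpy as np

def normal_pdf(x, mu, sigma):
    return np.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))

# Toy model: z in {0, 1}, prior p(z), Gaussian likelihood p(x|z); one observed x
p_z = np.array([0.3, 0.7])
mus, sigmas = np.array([-1.0, 2.0]), np.array([1.0, 0.5])
x = 0.5

p_x_given_z = normal_pdf(x, mus, sigmas)
p_xz = p_x_given_z * p_z                       # joint p(x, z)
log_evidence = np.log(p_xz.sum())              # log p(x)
posterior = p_xz / p_xz.sum()                  # p(z | x)

q = np.array([0.5, 0.5])                       # an arbitrary weighting distribution q(z)
elbo = np.sum(q * (np.log(p_xz) - np.log(q)))          # ELBO = -KL(q || p(x, z))
kl_post = np.sum(q * (np.log(q) - np.log(posterior)))  # KL(q || p(z | x))

print(log_evidence, elbo + kl_post)   # identical: log p(x) = ELBO + KL
print(elbo <= log_evidence)           # the bound holds: True
```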
Visualizing ELBO decomposition
- Note: all we have done so far is decompose the log
probability of the data, we still have exact equality
- This holds for any distribution q
Bishop – Pattern Recognition and Machine Learning
Expectation Maximization
- Expectation Maximization alternately optimizes the ELBO, ℒ(q, θ), with respect to q (the E step) and θ (the M step)
- Initialize θ^(0)
- At each iteration t = 1, …
- E step: Hold θ^(t−1) fixed, find q^(t) which maximizes ℒ(q, θ^(t−1))
- M step: Hold q^(t) fixed, find θ^(t) which maximizes ℒ(q^(t), θ)
The E step
- Suppose we are at iteration t of our algorithm. How do we maximize ℒ(q, θ^(t−1)) with respect to q? We know that:
argmax_q ℒ(q, θ^(t−1)) = argmax_q [ log p(x | θ^(t−1)) − KL(q(z) ‖ p(z | x, θ^(t−1))) ]
Bishop – Pattern Recognition and Machine Learning
The E step
- Suppose we are at iteration t of our algorithm. How do we maximize ℒ(q, θ^(t−1)) with respect to q? We know that:
argmax_q ℒ(q, θ^(t−1)) = argmax_q [ log p(x | θ^(t−1)) − KL(q(z) ‖ p(z | x, θ^(t−1))) ]
- The first term does not involve q, and we know the KL divergence must be non-negative
- The best we can do is to make the KL divergence 0
- Thus the solution is to set q^(t)(z) ← p(z | x, θ^(t−1))
Bishop – Pattern Recognition and Machine Learning
The E step
- Suppose we are at iteration t of our algorithm. How do we maximize ℒ(q, θ^(t−1)) with respect to q?
q^(t)(z) ← p(z | x, θ^(t−1))
Bishop – Pattern Recognition and Machine Learning
The M step
- Fixing q^(t)(z), we now solve:
argmax_θ ℒ(q^(t), θ) = argmax_θ −KL(q^(t)(z) ‖ p(x, z | θ))
= argmax_θ −∫ q^(t)(z) log [ q^(t)(z) / p(x, z | θ) ] dz
= argmax_θ ∫ q^(t)(z) [ log p(x, z | θ) − log q^(t)(z) ] dz
= argmax_θ ∫ q^(t)(z) log p(x, z | θ) dz − ∫ q^(t)(z) log q^(t)(z) dz   (the second term is constant w.r.t. θ)
= argmax_θ ∫ q^(t)(z) log p(x, z | θ) dz
= argmax_θ E_{q^(t)(z)}[ log p(x, z | θ) ]
The M step
- After applying the E step, we increase the likelihood of the data by finding better parameters according to:
θ^(t) ← argmax_θ E_{q^(t)(z)}[ log p(x, z | θ) ]
Bishop – Pattern Recognition and Machine Learning
EM algorithm
- Initialize θ^(0)
- At each iteration t = 1, …
- E step: Update q^(t)(z) ← p(z | x, θ^(t−1))
- M step: Update θ^(t) ← argmax_θ E_{q^(t)(z)}[ log p(x, z | θ) ]
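The alternation above fits in a few lines of Python. The skeleton below is a sketch (not from the slides); e_step and m_step are placeholder callables that a specific model, such as the Factor Analysis updates later in the lecture, would supply:

```python
def em(x, theta0, e_step, m_step, n_iters=100):
    """Generic EM loop.

    e_step(x, theta) should return a representation of q(z) = p(z | x, theta),
    e.g. its sufficient statistics.
    m_step(x, q) should return the theta maximizing E_q[ log p(x, z | theta) ].
    """
    theta = theta0
    for _ in range(n_iters):
        q = e_step(x, theta)      # E step: make the ELBO tight at the current theta
        theta = m_step(x, q)      # M step: push the ELBO (and hence the likelihood) up
    return theta
```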
Why does EM work?
- EM does coordinate ascent on the ELBO, ℒ(q, θ)
- Each iteration increases the log-likelihood until q^(t) converges (i.e. we reach a local maximum)!
- Simple to prove
Notice after the E step:
ℒ(q^(t), θ^(t−1)) = log p(x | θ^(t−1)) − KL(p(z | x, θ^(t−1)) ‖ p(z | x, θ^(t−1))) = log p(x | θ^(t−1))
The ELBO is tight!
By definition of argmax in the M step:
ℒ(q^(t), θ^(t)) ≥ ℒ(q^(t), θ^(t−1))
By simple substitution:
ℒ(q^(t), θ^(t)) ≥ log p(x | θ^(t−1))
Rewriting the left-hand side:
log p(x | θ^(t)) − KL(p(z | x, θ^(t−1)) ‖ p(z | x, θ^(t))) ≥ log p(x | θ^(t−1))
Noting that KL is non-negative:
log p(x | θ^(t)) ≥ log p(x | θ^(t−1))
Why does EM work?
- This proof is saying the same thing we saw in pictures. Make the KL 0,
then improve our parameter estimates to get a better likelihood
Bishop – Pattern Recognition and Machine Learning
A different perspective
- Consider the log-likelihood of the marginal distribution of the data x in a generic latent variable model with latent variables z, parameterized by θ:
ℓ(θ) ≜ Σ_{i=1}^N log p(x_i | θ) = Σ_{i=1}^N log ∫ p(x_i, z_i | θ) dz_i
- Estimating θ is difficult because we have a log outside of the integral, so it does not act directly on the probability distribution (frequently in the exponential family)
- If we observed z_i, then our log-likelihood would be:
ℓ_c(θ) ≜ Σ_{i=1}^N log p(x_i, z_i | θ)
This is called the complete log-likelihood
Expected Complete Log-Likelihood
- We can take the expectation of this likelihood over a distribution of the latent variables q(z):
E_{q(z)}[ ℓ_c(θ) ] = Σ_{i=1}^N ∫ q(z_i) log p(x_i, z_i | θ) dz_i
- This looks similar to marginalizing, but now the log is inside the integral, so
it’s easier to deal with
- We can treat the latent variables as observed and solve this more easily
than directly solving the log-likelihood
- Finding the q that maximizes this is the E step of EM
- Finding the θ that maximizes this is the M step of EM
Back to Factor Analysis
- For simplicity, assume data is centered. We want:
argmax_{W,Ψ} log p(X | W, Ψ) = argmax_{W,Ψ} Σ_{i=1}^N log p(x_i | W, Ψ) = argmax_{W,Ψ} Σ_{i=1}^N log N(x_i | 0, Ψ + W Wᵀ)
- No closed form solution in general (PPCA can be solved in closed
form)
- Ψ and W get coupled together in the derivative and we can't solve for them analytically
EM for Factor Analysis
argmax_{W,Ψ} E_{q^(t)(Z)}[ log p(X, Z | W, Ψ) ]
= argmax_{W,Ψ} Σ_{i=1}^N E_{q^(t)(z_i)}[ log p(x_i | z_i, W, Ψ) ] + E_{q^(t)(z_i)}[ log p(z_i) ]
= argmax_{W,Ψ} Σ_{i=1}^N E_{q^(t)(z_i)}[ log p(x_i | z_i, W, Ψ) ]
= argmax_{W,Ψ} Σ_{i=1}^N E_{q^(t)(z_i)}[ log N(x_i | W z_i, Ψ) ]
= argmax_{W,Ψ} const − (N/2) log det(Ψ) − Σ_{i=1}^N E_{q^(t)(z_i)}[ ½ (x_i − W z_i)ᵀ Ψ⁻¹ (x_i − W z_i) ]
= argmax_{W,Ψ} −(N/2) log det(Ψ) − Σ_{i=1}^N [ ½ x_iᵀ Ψ⁻¹ x_i − x_iᵀ Ψ⁻¹ W E_{q^(t)(z_i)}[z_i] + ½ tr( Wᵀ Ψ⁻¹ W E_{q^(t)(z_i)}[z_i z_iᵀ] ) ]
- We only need the two sufficient statistics, E_{q^(t)(z_i)}[z_i] and E_{q^(t)(z_i)}[z_i z_iᵀ], to enable the M step.
- In practice, sufficient statistics are often what we compute in the E step
Factor Analysis E step
E_{q^(t)(z_i)}[z_i] = G W^(t−1)ᵀ Ψ^(t−1)⁻¹ x_i
E_{q^(t)(z_i)}[z_i z_iᵀ] = G + E_{q^(t)(z_i)}[z_i] E_{q^(t)(z_i)}[z_i]ᵀ
where G = ( I + W^(t−1)ᵀ Ψ^(t−1)⁻¹ W^(t−1) )⁻¹
This is derived via the Bayes rule for Gaussians
Factor Analysis M step
W^(t) ← [ Σ_{i=1}^N x_i E_{q^(t)(z_i)}[z_i]ᵀ ] [ Σ_{i=1}^N E_{q^(t)(z_i)}[z_i z_iᵀ] ]⁻¹
Ψ^(t) ← diag( (1/N) Σ_{i=1}^N x_i x_iᵀ − W^(t) (1/N) Σ_{i=1}^N E_{q^(t)(z_i)}[z_i] x_iᵀ )
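A compact NumPy sketch of these E and M updates, assuming centered data and the notation above (initialization and iteration count are arbitrary choices), could look like:

```python
import numpy as np

def fa_em(X, M, n_iters=100, seed=0):
    """EM for Factor Analysis on centered data X of shape (N, D).
    Returns the loading matrix W (D x M) and diagonal noise covariance Psi (length D)."""
    rng = np.random.default_rng(seed)
    N, D = X.shape
    W = rng.normal(scale=0.1, size=(D, M))
    Psi = np.var(X, axis=0) + 1e-6            # start from the empirical variances
    S = X.T @ X / N                           # (1/N) sum_i x_i x_i^T
    for _ in range(n_iters):
        # E step: posterior moments of z_i given x_i under the current parameters
        Psi_inv_W = W / Psi[:, None]                      # Psi^{-1} W
        G = np.linalg.inv(np.eye(M) + W.T @ Psi_inv_W)    # posterior covariance of z_i
        Ez = X @ Psi_inv_W @ G                            # rows are E[z_i]^T
        Ezz_sum = N * G + Ez.T @ Ez                       # sum_i E[z_i z_i^T]
        # M step: closed-form updates from the sufficient statistics
        XEz_sum = X.T @ Ez                                # sum_i x_i E[z_i]^T
        W = XEz_sum @ np.linalg.inv(Ezz_sum)
        Psi = np.diag(S - W @ XEz_sum.T / N)
    return W, Psi

# Example usage on random stand-in data: W_hat, Psi_hat = fa_em(np.random.randn(500, 5), M=2)
```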
From EM to Variational Inference
- In EM we alternately maximize the ELBO with respect to θ and the weighting distribution (a functional) q
- In variational inference, we drop the distinction between hidden variables and parameters of a distribution
- I.e. we replace p(x, z | θ) with p(x, z). Effectively this puts a probability distribution on the parameters θ, then absorbs them into z
- Fully Bayesian treatment instead of a point estimate for the
parameters
Variational Inference
- Now the ELBO is just a function of our weighting distribution: ℒ(q)
- We assume a form for q that we can optimize
- For example, mean field theory assumes q factorizes:
q(z) = ∏_{i=1}^M q_i(z_i)
- Then we optimize ℒ(q) with respect to one of the terms while holding the others constant, and repeat for all terms
- By assuming a form for q we approximate a (typically) intractable true posterior
Mean Field update derivation
ℒ(q) = ∫ q(z) log [ p(X, z) / q(z) ] dz
= ∫ q(z) log p(X, z) dz − ∫ q(z) log q(z) dz
= ∫ ∏_i q_i(z_i) [ log p(X, z) − Σ_l log q_l(z_l) ] dz
= ∫ q_k(z_k) { ∫ ∏_{i≠k} q_i(z_i) [ log p(X, z) − Σ_l log q_l(z_l) ] dz_{≠k} } dz_k
= ∫ q_k(z_k) [ ∫ log p(X, z) ∏_{i≠k} q_i(z_i) dz_{≠k} ] dz_k − ∫ q_k(z_k) log q_k(z_k) dz_k + const
  (all the log q_l terms with l ≠ k do not involve q_k, so they are absorbed into the constant)
= ∫ q_k(z_k) E_{i≠k}[ log p(X, z) ] dz_k − ∫ q_k(z_k) log q_k(z_k) dz_k + const
where E_{i≠k}[·] denotes an expectation over all factors q_i with i ≠ k.
Mean Field update
q_k^(t)(z_k) ← argmax_{q_k(z_k)} [ ∫ q_k(z_k) E_{i≠k}[ log p(X, z) ] dz_k − ∫ q_k(z_k) log q_k(z_k) dz_k ]
- The point of this is not the update equations themselves, but the
general idea:
- freeze some of the variables, compute expectations over those
- update the rest using these expectations (a small numeric sketch follows below)
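A toy illustration of this idea (not from the slides): approximate a correlated 2-D Gaussian p(z_1, z_2) with a factorized q(z_1)q(z_2). Each factor's optimal mean depends on the expectation of the other variable under its current factor, so we alternate the two updates:

```python
import numpy as np

# Target: a correlated 2-D Gaussian with mean mu and precision matrix Lam
mu = np.array([1.0, -2.0])
Lam = np.linalg.inv(np.array([[1.0, 0.8],
                              [0.8, 1.0]]))   # precision = inverse covariance

m = np.zeros(2)                # current means of the factors q1(z1), q2(z2)
for _ in range(50):
    # Update q1 holding q2 fixed: E_{q2}[log p(z1, z2)] is quadratic in z1
    m[0] = mu[0] - (Lam[0, 1] / Lam[0, 0]) * (m[1] - mu[1])
    # Update q2 holding q1 fixed, using the freshly updated expectation of z1
    m[1] = mu[1] - (Lam[1, 0] / Lam[1, 1]) * (m[0] - mu[0])

print(m)                       # converges to the true mean [1, -2]
print(1 / np.diag(Lam))        # factor variances 1/Lam_ii (underestimate the true variances)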
Why does Variational Inference work?
- The argument is similar to the argument for EM
- When expectations are computed using the current values for the variables not being updated, we implicitly set the KL divergence between the weighting distributions and the posterior distributions to 0
- The update then pushes up the data likelihood
Bishop – Pattern Recognition and Machine Learning
Variational Autoencoder
- Kingma & Welling: Auto-Encoding Variational Bayes proposes
maximizing the ELBO with a trick to make it differentiable
- Discusses both the variational autoencoder model using parametric
distributions and fully Bayesian variational inference, but we will only discuss the variational autoencoder
Problem Setup
- Assume a generative model with a latent variable distributed according to some distribution p(z_i)
- The observed variable is distributed according to a conditional distribution p(x_i | z_i, θ)
- Note the similarity to the Factor Analysis (FA) setup so far
(figure: encoder q(z_i | x_i, φ), sample z_i ~ q(z_i | x_i, φ), decoder p(x_i | z_i, θ))
Problem Setup
- We also create a weighting distribution q(z_i | x_i, φ)
- This will play the same role as q(z_i) in the EM algorithm, as we will see.
- Note that when we discussed EM, this weighting distribution could be arbitrary: we choose to condition on x_i here. This is a choice.
- Why does this make sense?
(figure: encoder q(z_i | x_i, φ), sample z_i ~ q(z_i | x_i, φ), decoder p(x_i | z_i, θ))
Using a conditional weighting distribution
- There are many values of the latent variables that don’t matter in
practice – by conditioning on the observed variables, we emphasize the latent variable values we actually care about: the ones most likely given the observations
- We would like to be able to encode our data into the latent variable space. This conditional weighting distribution enables that encoding
Problem setup
- Implement p(x_i | z_i, θ) as a neural network; this can also be seen as a probabilistic decoder
- Implement q(z_i | x_i, φ) as a neural network; we can also see this as a probabilistic encoder
- Sample z_i from q(z_i | x_i, φ) in the middle
(figure: encoder q(z_i | x_i, φ), sample z_i ~ q(z_i | x_i, φ), decoder p(x_i | z_i, θ))
Unpacking the encoder
- We choose a family of distributions for our conditional distribution q. For example, a Gaussian with diagonal covariance:
q(z_i | x_i, φ) = N(z_i | μ = u(x_i, W_1), Σ = diag(s(x_i, W_2)))
(figure: encoder network mapping x_i to μ = u(x_i, W_1) and Σ = diag(s(x_i, W_2)))
Unpacking the encoder
- We create neural networks to predict the parameters of q from our data
- In this case, the outputs of our networks are μ and Σ
(figure: encoder network mapping x_i to μ = u(x_i, W_1) and Σ = diag(s(x_i, W_2)))
Unpacking the encoder
- We refer to the parameters of our networks, W_1 and W_2, collectively as φ
- Together, networks u and s parameterize a distribution q(z_i | x_i, φ) of the latent variable z_i that depends in a complicated, non-linear way on x_i
(figure: encoder network mapping x_i to μ = u(x_i, W_1) and Σ = diag(s(x_i, W_2)))
Unpacking the decoder
- The decoder follows the same logic, just swapping x_i and z_i
- We refer to the parameters of our networks, W_3 and W_4, collectively as θ
- Together, networks u_d and s_d parameterize a distribution p(x_i | z_i, θ) of the observed variable x_i that depends in a complicated, non-linear way on z_i
(figure: decoder network mapping z_i ~ q(z_i | x_i, φ) to μ = u_d(z_i, W_3) and Σ = diag(s_d(z_i, W_4)))
Understanding the setup
- Note that p and q do not have to use the same distribution family; this was just an example
- This basically looks like an autoencoder, but the outputs of both the encoder and decoder are parameters of the distributions of the latent and observed variables respectively
- We also have a sampling step in the middle
(figure: encoder q(z_i | x_i, φ), sample z_i ~ q(z_i | x_i, φ), decoder p(x_i | z_i, θ))
Using EM for training
- Initialize θ^(0)
- At each iteration t = 1, …, T
- E step: Hold θ^(t−1) fixed, find q^(t) which maximizes ℒ(q, θ^(t−1))
- M step: Hold q^(t) fixed, find θ^(t) which maximizes ℒ(q^(t), θ)
- We will use a modified EM to train the model, but we will transform it
so we can use standard back propagation!
Using EM for training
- Initialize θ^(0)
- At each iteration t = 1, …, T
- E step: Hold θ^(t−1) fixed, find φ^(t) which maximizes ℒ(φ, θ^(t−1), x)
- M step: Hold φ^(t) fixed, find θ^(t) which maximizes ℒ(φ^(t), θ, x)
- First we modify the notation to account for our choice of using a parametric, conditional distribution q
Using EM for training
- Initialize θ^(0)
- At each iteration t = 1, …, T
- E step: Hold θ^(t−1) fixed, find ∂ℒ/∂φ to increase ℒ(φ, θ^(t−1), x)
- M step: Hold φ^(t) fixed, find ∂ℒ/∂θ to increase ℒ(φ^(t), θ, x)
- Instead of fully maximizing at each iteration, we just take a step in the
direction that increases ℒ
Computing the loss
- We need to compute the gradient for each mini-batch of B data samples, using the ELBO/variational bound ℒ(φ, θ, x_i) as the loss:
Σ_{i=1}^B ℒ(φ, θ, x_i) = Σ_{i=1}^B −KL(q(z_i | x_i, φ) ‖ p(x_i, z_i | θ)) = Σ_{i=1}^B −E_{q(z_i | x_i, φ)}[ log ( q(z_i | x_i, φ) / p(x_i, z_i | θ) ) ]
- Notice that this involves an intractable integral over all values of z
- We can use Monte Carlo sampling to approximate the expectation using L samples from q(z_i | x_i, φ):
E_{q(z_i | x_i, φ)}[ f(z_i) ] ≃ (1/L) Σ_{l=1}^L f(z_{i,l})
ℒ(φ, θ, x_i) ≃ ℒ̃^A(φ, θ, x_i) = (1/L) Σ_{l=1}^L [ log p(x_i, z_{i,l} | θ) − log q(z_{i,l} | x_i, φ) ]
A lower variance estimator of the loss
- We can rewrite
ℒ(φ, θ, x) = −KL(q(z | x, φ) ‖ p(x, z | θ))
= −∫ q(z | x, φ) log [ q(z | x, φ) / ( p(x | z, θ) p(z) ) ] dz
= −∫ q(z | x, φ) [ log ( q(z | x, φ) / p(z) ) − log p(x | z, θ) ] dz
= −KL(q(z | x, φ) ‖ p(z)) + E_{q(z | x, φ)}[ log p(x | z, θ) ]
- The first term can be computed analytically for some families of distributions (e.g. Gaussian); only the second term must be estimated:
ℒ(φ, θ, x_i) ≃ ℒ̃^B(φ, θ, x_i) = −KL(q(z_i | x_i, φ) ‖ p(z_i)) + (1/L) Σ_{l=1}^L log p(x_i | z_{i,l}, θ)
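For the common choice of a diagonal-Gaussian encoder and a standard normal prior p(z) = N(0, I), the KL term has the closed form KL = ½ Σ_m (μ_m² + σ_m² − log σ_m² − 1). A PyTorch sketch of ℒ̃^B under those assumptions, with L = 1 and a unit-variance Gaussian decoder (so log p(x | z, θ) reduces to a squared error up to an additive constant), might look like the following; the function and argument names are illustrative:

```python
import torch

def elbo_estimate_B(x, mu_z, logvar_z, x_recon):
    """One-sample estimate of L~B for a diagonal-Gaussian encoder, N(0, I) prior,
    and a unit-variance Gaussian decoder (additive constants dropped).

    x        : (B, D) batch of observations
    mu_z     : (B, M) encoder means
    logvar_z : (B, M) encoder log-variances
    x_recon  : (B, D) decoder means computed from a reparameterized sample z
    """
    # Analytic term: -KL( q(z|x, phi) || N(0, I) )
    neg_kl = 0.5 * torch.sum(1 + logvar_z - mu_z.pow(2) - logvar_z.exp(), dim=1)
    # Monte Carlo term (L = 1): log p(x | z, theta) up to an additive constant
    log_px_given_z = -0.5 * torch.sum((x - x_recon) ** 2, dim=1)
    return (neg_kl + log_px_given_z).mean()   # average over the mini-batch
```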
Full EM training procedure (not really used)
- For t = 1 : B : T
  - Estimate ∂ℒ/∂φ (How do we do this? We'll get to it shortly)
  - Update φ
  - Estimate ∂ℒ/∂θ:
    - Initialize Δθ = 0
    - For i = t : t + B − 1
      - Compute the outputs of the encoder (the parameters of q) for x_i
      - For l = 1, …, L
        - Sample z_i ~ q(z_i | x_i, φ)
        - Δθ_{i,l} ← run a forward/backward pass on the decoder (standard back propagation) using either ℒ̃^A or ℒ̃^B as the loss
        - Δθ ← Δθ + Δθ_{i,l}
  - Update θ
(figure: encoder q(z_i | x_i, φ), sample z_i ~ q(z_i | x_i, φ), decoder p(x_i | z_i, θ))
Full EM training procedure (not really used)
- For t = 1 : B : T
  - Estimate ∂ℒ/∂φ (How do we do this? We'll get to it shortly)
  - Update φ
  - Estimate ∂ℒ/∂θ:
    - Initialize Δθ = 0
    - For i = t : t + B − 1
      - Compute the outputs of the encoder (the parameters of q) for x_i
      - Sample z_i ~ q(z_i | x_i, φ)
      - Δθ_i ← run a forward/backward pass on the decoder (standard back propagation) using either ℒ̃^A or ℒ̃^B as the loss
      - Δθ ← Δθ + Δθ_i
  - Update θ
(figure: encoder q(z_i | x_i, φ), sample z_i ~ q(z_i | x_i, φ), decoder p(x_i | z_i, θ))
- First simplification: Let L = 1. We just want a stochastic estimate of the gradient. With a large enough B, we get enough samples from q(z_i | x_i, φ)
The E step
- We can use standard back propagation to estimate ∂ℒ/∂θ
- How do we estimate ∂ℒ/∂φ?
  - The sampling step blocks the gradient flow
  - Computing the derivatives through q via the chain rule gives a very high variance estimate of the gradient
(figure: encoder q(z_i | x_i, φ), sample z_i ~ q(z_i | x_i, φ), decoder p(x_i | z_i, θ); the "?" marks the gradient path through the sampling step)
Reparameterization
- Instead of drawing z_i ~ q(z_i | x_i, φ), let z_i = g(ε_i, x_i, φ), and draw ε_i ~ p(ε)
- z_i is still a random variable but depends on φ deterministically
- Replace E_{q(z_i | x_i, φ)}[ f(z_i) ] with E_{p(ε)}[ f(g(ε_i, x_i, φ)) ]
- Example – univariate normal:
a ~ N(μ, σ²) is equivalent to a = g(ε), ε ~ N(0, 1), with g(c) ≜ μ + σc
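A tiny PyTorch illustration of the trick (not from the slides): with z = μ + σ·ε the samples are differentiable functions of μ and σ, so gradients can reach the encoder parameters; the scalar values below are arbitrary:

```python
import torch

mu = torch.tensor(0.5, requires_grad=True)
log_sigma = torch.tensor(-1.0, requires_grad=True)

eps = torch.randn(1000)                  # eps ~ N(0, 1), involves no parameters
z = mu + torch.exp(log_sigma) * eps      # z ~ N(mu, sigma^2), differentiable in mu, log_sigma

loss = (z ** 2).mean()                   # any downstream function of the samples
loss.backward()
print(mu.grad, log_sigma.grad)           # both gradients exist and are finite
```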
Reparameterization
(figure: before reparameterization — encoder q(z_i | x_i, φ), sample z_i ~ q(z_i | x_i, φ), decoder p(x_i | z_i, θ), with a "?" at the sampling step; after — z_i = g(ε_i, x_i, φ) with ε_i ~ p(ε), decoder p(x_i | z_i, θ))
Full EM training procedure (not really used)
- For t = 1 : B : T
  - E step
    - Estimate ∂ℒ/∂φ using standard back propagation with either ℒ̃^A or ℒ̃^B as the loss
    - Update φ
  - M step
    - Estimate ∂ℒ/∂θ using standard back propagation with either ℒ̃^A or ℒ̃^B as the loss
    - Update θ
(figure: z_i = g(ε_i, x_i, φ), ε_i ~ p(ε), decoder p(x_i | z_i, θ))
Full training procedure
- For t = 1 : B : T
  - Estimate ∂ℒ/∂φ and ∂ℒ/∂θ with either ℒ̃^A or ℒ̃^B as the loss
  - Update φ and θ
- Final simplification: update all of the parameters at the same time instead of using separate E, M steps
- This is standard back propagation. Just use −ℒ̃^A or −ℒ̃^B as the loss, and run your favorite SGD variant
(figure: z_i = g(ε_i, x_i, φ), ε_i ~ p(ε), decoder p(x_i | z_i, θ))
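Putting the pieces together, a minimal PyTorch sketch of this procedure is shown below. The architecture (one hidden layer, a Bernoulli decoder for inputs in [0, 1]), the dimensions, and the hyperparameters are illustrative assumptions rather than the lecture's prescription; the loss is −ℒ̃^B with the closed-form Gaussian KL and a single sample per data point:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class VAE(nn.Module):
    def __init__(self, d_in=784, d_hidden=400, d_latent=20):
        super().__init__()
        # Encoder q(z|x, phi): predicts mu and log-variance of a diagonal Gaussian
        self.enc = nn.Sequential(nn.Linear(d_in, d_hidden), nn.ReLU())
        self.enc_mu = nn.Linear(d_hidden, d_latent)
        self.enc_logvar = nn.Linear(d_hidden, d_latent)
        # Decoder p(x|z, theta): here a Bernoulli likelihood over [0, 1]-valued inputs
        self.dec = nn.Sequential(nn.Linear(d_latent, d_hidden), nn.ReLU(),
                                 nn.Linear(d_hidden, d_in))

    def forward(self, x):
        h = self.enc(x)
        mu, logvar = self.enc_mu(h), self.enc_logvar(h)
        eps = torch.randn_like(mu)                  # reparameterization: z = mu + sigma * eps
        z = mu + torch.exp(0.5 * logvar) * eps
        return self.dec(z), mu, logvar              # decoder outputs logits of p(x|z)

def neg_elbo(x, logits, mu, logvar):
    # -log p(x|z, theta) for a Bernoulli decoder (single sample, L = 1)
    recon = F.binary_cross_entropy_with_logits(logits, x, reduction='none').sum(dim=1)
    # KL( q(z|x, phi) || N(0, I) ), closed form for diagonal Gaussians
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
    return (recon + kl).mean()

# Illustrative training loop on random data standing in for a real dataset
model = VAE()
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
data = torch.rand(256, 784)                         # placeholder for real [0, 1] data
for epoch in range(100):
    for i in range(0, data.shape[0], 64):           # mini-batches of B = 64
        x = data[i:i + 64]
        loss = neg_elbo(x, *model(x))               # -L~B: update phi and theta together
        opt.zero_grad()
        loss.backward()
        opt.step()
```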
Running the model on new data
- To get a MAP estimate of the latent variables, just use the mean output by the encoder (for a Gaussian distribution)
- No need to take a sample
- Give the mean to the decoder
- At test time, this is used just as an auto-encoder
- You can optionally take multiple samples of the latent variables to
estimate the uncertainty
Relationship to Factor Analysis
- VAE performs probabilistic, non-linear
dimensionality reduction
- It uses a generative model with a latent variable distributed according to some prior distribution p(z_i)
- The observed variable is distributed according to a conditional distribution p(x_i | z_i, θ)
- Training is approximately running expectation maximization to maximize the data likelihood
- This can be seen as a non-linear version of Factor Analysis
(figure: encoder q(z_i | x_i, φ), sample z_i ~ q(z_i | x_i, φ), decoder p(x_i | z_i, θ))
Regularization by a prior
- Looking at the form of ℒ we used to justify ℒ̃^B gives us additional insight:
ℒ(φ, θ, x) = −KL(q(z | x, φ) ‖ p(z)) + E_{q(z | x, φ)}[ log p(x | z, θ) ]
- We are making the latent distribution as close as possible to a prior on z
- While maximizing the conditional likelihood of the data under our model
- In other words, this is an approximation to Maximum Likelihood Estimation regularized by a prior on the latent space
Practical advantages of a VAE vs. an AE
- The prior on the latent space:
- Allows you to inject domain knowledge
- Can make the latent space more interpretable
- The VAE also makes it possible to estimate the variance/uncertainty in
the predictions
Interpreting the latent space
https://arxiv.org/pdf/1610.00291.pdf
Requirements of the VAE
- Note that the VAE requires 2 tractable distributions to be used:
- The prior distribution p(z) must be easy to sample from
- The conditional likelihood p(x | z, θ) must be computable
- In practice this means that the 2 distributions of interest are often
simple, for example uniform, Gaussian, or even isotropic Gaussian
The blurry image problem
https://blog.openai.com/generative-models/
- The samples from the VAE
look blurry
- Three plausible
explanations for this
- Maximizing the
likelihood
- Restrictions on the
family of distributions
- The lower bound
approximation
The maximum likelihood explanation
https://arxiv.org/pdf/1701.00160.pdf
- Recent evidence
suggests that this is not actually the problem
- GANs can be trained
with maximum likelihood and still generate sharp examples
Investigations of blurriness
- Recent investigations suggest that both the simple probability
distributions and the variational approximation lead to blurry images
- Kingma & colleagues: Improving Variational Inference with Inverse
Autoregressive Flow
- Zhao & colleagues: Towards a Deeper Understanding of Variational
Autoencoding Models
- Nowozin & colleagues: f-GAN: Training Generative Neural Samplers Using Variational Divergence Minimization