Reliable Variational Learning for Hierarchical Dirichlet Processes
Erik Sudderth
Brown University Computer Science Collaborators:
- Michael Hughes & Dae Il Kim, Brown University
- Prem Gopalan & David Blei, Princeton University
Goals:
- Learn structure: topics, behaviors, objects, communities…
- Assumptions encoded in the model, not heuristic algorithm initializations
- Good predictive power, not full posterior uncertainty
Hierarchical Dirichlet Process (Teh et al., JASA 2006)
Example document mixing topics: "There are reasons to believe that the genetics of an organism are likely to shift due to the extreme changes in … politicians must pass environmental legislation that can protect our future species from becoming extinct…"
Topics: Genetics, Climate Change, Politics, …
[Figure: stick-breaking proportions v1, v2, v3, … producing cluster weights 0.2, 0.3, 0.5, …]
GOAL: Partition data into an a priori unknown number of discrete clusters.
Stick-Breaking (Sethuraman 1994): $v_k \sim \text{Beta}(1, \alpha)$,
$\pi_k = v_k \prod_{\ell=1}^{k-1} (1 - v_\ell), \qquad 1 - \sum_{k=1}^{K} \pi_k = \prod_{k=1}^{K} (1 - v_k)$
[Figure: sampled stick-breaking weights for concentration α = 1 vs. α = 20.]
Exponential family likelihoods with conjugate priors:
$\log p(x_n \mid \phi_k) = \phi_k^\top t(x_n) - a(\phi_k) + \text{const}$, with conjugate prior parameter $\lambda_0$ and log-normalizer $\bar{a}(\lambda_0)$.
Hyperparameters: stick-breaking concentration $\alpha$, conjugate prior parameter $\lambda_0$.
Can we sample from the posterior distribution over data clusterings?
Iteratively resample cluster assignment for one observation, fixing all others.
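As a concrete illustration, a minimal sketch of one such Gibbs sweep for a DP mixture; `log_predictive` (posterior predictive score of an observation under an existing cluster) and `log_prior_predictive` (score under the prior) are hypothetical helpers assumed to be supplied for whichever conjugate likelihood is in use.

```python
import numpy as np

def gibbs_sweep(X, z, alpha, log_predictive, log_prior_predictive, rng):
    """Resample each cluster assignment z[n] in turn, fixing all others."""
    N = len(X)
    for n in range(N):
        z[n] = -1  # remove observation n from its current cluster
        labels = sorted(k for k in set(z) if k >= 0)
        log_w = []
        for k in labels:
            members = [X[m] for m in range(N) if z[m] == k]
            # CRP prior: existing clusters weighted by their current size
            log_w.append(np.log(len(members)) + log_predictive(X[n], members))
        # a brand-new cluster is weighted by the concentration alpha
        log_w.append(np.log(alpha) + log_prior_predictive(X[n]))
        log_w = np.asarray(log_w)
        p = np.exp(log_w - log_w.max())
        p /= p.sum()
        choice = rng.choice(len(p), p=p)
        z[n] = labels[choice] if choice < len(labels) else (max(labels) + 1 if labels else 0)
    return z
```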
[Figure: Gibbs sampler traces of log-probability and number of clusters over iterations.]
What is the marginal likelihood of our observed data?
$\log p(x \mid \alpha, \lambda_0) = \log \sum_z \iint q(z, v, \phi)\, \frac{p(x, z, v, \phi \mid \alpha, \lambda_0)}{q(z, v, \phi)}\, dv\, d\phi$
for any
variational distribution q(z, v, φ)
Jensen's inequality gives a lower bound, the expected log-likelihood (negative of the "average energy") plus the variational entropy:
$\mathcal{L}(q) = \mathbb{E}_q[\log p(x, z, v, \phi \mid \alpha, \lambda_0)] - \mathbb{E}_q[\log q(z, v, \phi)] \le \log p(x \mid \alpha, \lambda_0)$
Mean field yields tractable algorithms via assumed independence:
$q(z, v, \phi) = q(z)\, q(v, \phi) = \Big[\prod_{n=1}^{N} q(z_n)\Big] \cdot \Big[\prod_{k=1}^{\infty} q(v_k)\, q(\phi_k)\Big]$
Each $q(v_k)$ is a beta distribution; each $q(\phi_k)$ is in the exponential family defined by the conjugate prior.
$q(z_n = k) = r_{nk}$
Categorical distribution with unbounded support, and infinitely many potential clusters!
[Figure: expected cluster weights under truncation, α = 4, K = 10.]
Blei & Jordan, 2006; Ishwaran & James, 2001
Truncate the stick-breaking process:
$q(v, \phi) = \Big[\prod_{k=1}^{K} q(\phi_k)\Big] \cdot \Big[\prod_{k=1}^{K-1} q(v_k)\Big]$, fixing $v_K = 1$ so that $\pi_K = \prod_{k=1}^{K-1} (1 - v_k)$.
$q(z_n) = \text{Cat}(z_n \mid r_{n1}, r_{n2}, \ldots, r_{nK})$
Bryant & Sudderth, 2012; Teh, Kurihara, & Welling, 2008
Alternatively, leave the global factors untruncated:
$q(v, \phi) = \prod_{k=1}^{\infty} q(v_k)\, q(\phi_k)$
and truncate the assignments instead: $r_{nk} = q(z_n = k) = 0$ for all $k > K$, for some $K > 0$. For any $k > K$, the optimal variational distributions equal the prior and need not be explicitly represented.
A Bayesian nonparametric analog of Expectation-Maximization (EM):
For data item n = 1, 2, …, N, and K candidate clusters:
$r_{nk} \propto \exp\{\mathbb{E}_q[\log \pi_k(v)] + \mathbb{E}_q[\log p(x_n \mid \phi_k)]\}$
$\mathbb{E}_q[\log \pi_k(v)] = \mathbb{E}_q[\log v_k] + \sum_{\ell=1}^{k-1} \mathbb{E}_q[\log(1 - v_\ell)]$
$\mathbb{E}_q[\log v_k] = \psi(\alpha_{k1}) - \psi(\alpha_{k1} + \alpha_{k0}), \qquad \mathbb{E}_q[\log(1 - v_k)] = \psi(\alpha_{k0}) - \psi(\alpha_{k1} + \alpha_{k0})$
Match expected sufficient statistics. For cluster k = 1, 2, …, K:
$N^0_k \leftarrow \sum_{n=1}^{N} r_{nk}, \qquad s^0_k \leftarrow \sum_{n=1}^{N} r_{nk}\, t(x_n)$
$q(v_k) = \text{Beta}(\alpha_{k1}, \alpha_{k0}), \qquad \alpha_{k1} = 1 + N^0_k, \qquad \alpha_{k0} = \alpha + N^0_{>k} = \alpha + \sum_{\ell=k+1}^{K} N^0_\ell$
Expected counts and sufficient statistics are only non-zero for the first K clusters.
Evaluate the bound in closed form:
$\mathcal{L}(q) = H[r] + \sum_{k=1}^{K} \big[\bar{a}(s^0_k + \lambda_0) - \bar{a}(\lambda_0) + \log B(1 + N^0_k,\; \alpha + N^0_{>k}) - \log B(1, \alpha)\big]$
combining log-normalizers for cluster shape and beta stick-breaking priors with the assignment entropy $H[r] = -\sum_{n=1}^{N} \sum_{k=1}^{K} r_{nk} \log r_{nk}$.
+ Likelihood bound is monotonically increasing; guaranteed convergence to a posterior mode.
+ Unlike classical EM for MAP estimation, allows Bayesian comparison of hypotheses with varying complexity K, crucial for BNP models.
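To make the loop concrete, here is a minimal sketch of truncated mean-field coordinate ascent for a DP mixture of unit-variance 1-D Gaussians with standard normal priors on cluster means. The Gaussian specialization, the prior choices, and names like `coordinate_ascent` are illustrative assumptions; the digamma identities and sufficient-statistic matching follow the updates above.

```python
import numpy as np
from scipy.special import digamma

def coordinate_ascent(x, K=10, alpha=1.0, n_iters=50, seed=0):
    rng = np.random.default_rng(seed)
    N = len(x)
    r = rng.dirichlet(np.ones(K), size=N)         # responsibilities r_nk
    mu, tau = np.zeros(K), np.ones(K)             # q(phi_k) = Normal(mu_k, 1/tau_k)
    for _ in range(n_iters):
        # M-step: match expected counts and sufficient statistics
        Nk = r.sum(axis=0)                        # N0_k
        sk = r.T @ x                              # s0_k
        a1 = 1.0 + Nk                             # Beta(a1, a0) stick posteriors
        a0 = alpha + np.concatenate([np.cumsum(Nk[::-1])[-2::-1], [0.0]])  # N0_{>k}
        tau = 1.0 + Nk                            # Normal(0, 1) prior on cluster means
        mu = sk / tau
        # E-step: expected log stick weights via the digamma identities
        Elogv = digamma(a1) - digamma(a1 + a0)
        Elog1mv = digamma(a0) - digamma(a1 + a0)
        Elogpi = Elogv + np.concatenate([[0.0], np.cumsum(Elog1mv)[:-1]])
        # expected Gaussian log-likelihood, up to constants
        Eloglik = -0.5 * ((x[:, None] - mu) ** 2 + 1.0 / tau)
        logr = Elogpi + Eloglik
        r = np.exp(logr - logr.max(axis=1, keepdims=True))
        r /= r.sum(axis=1, keepdims=True)
    return r, mu, Nk
```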
Stochastic variational inference (Hoffman, Blei, Paisley, & Wang, JMLR 2013). [Diagram: the data divided into batches.]
Update: for each batch $b$, compute responsibilities for $n \in B_b$ and batch statistics
$s^b_k \leftarrow \sum_{n \in B_b} r_{nk}\, t(x_n)$
For cluster k = 1, 2, …, K:
$\hat{\lambda}_k \leftarrow \lambda_0 + \frac{N}{|B_b|}\, s^b_k, \qquad \lambda_k \leftarrow \rho_t\, \hat{\lambda}_k + (1 - \rho_t)\, \lambda_k$
Apply similar updates to stick weights. Batch statistics give a noisy estimate of the (natural) gradient.
[Figure: learning rate schedules a, b, c decaying over iterations $10^1$–$10^4$.]
Robbins-Monro convergence condition: $\sum_t \rho_t \to \infty$ and $\sum_t \rho_t^2 < \infty$; e.g., $\rho_t = (\tau_0 + t)^{-\kappa}$ with $\kappa \in (0.5, 1]$.
+ Per-iteration cost is low
+ Initial iterations often very effective
− Convergence guarantee is weak
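A minimal sketch of one stochastic update under stated assumptions: global natural parameters stored as arrays, batch statistics computed as above, and the common schedule $\rho_t = (\tau_0 + t)^{-\kappa}$ satisfying the Robbins-Monro condition. The function name and array shapes are illustrative.

```python
import numpy as np

def svi_step(lam, lam0, r_b, t_b, N, t, tau0=1.0, kappa=0.7):
    """lam, lam0: K x D natural parameters; r_b: |B_b| x K; t_b: |B_b| x D."""
    B = r_b.shape[0]                          # batch size |B_b|
    s_b = r_b.T @ t_b                         # s^b_k = sum_{n in B_b} r_nk t(x_n)
    lam_hat = lam0 + (N / B) * s_b            # noisy full-data estimate
    rho = (tau0 + t) ** (-kappa)              # step size, kappa in (0.5, 1]
    return rho * lam_hat + (1.0 - rho) * lam  # natural-gradient step
```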
Memoized variational inference (Hughes & Sudderth, NIPS 2013; Neal & Hinton, 1999). [Diagram: the data divided into batches.]
+ Per-iteration cost is low
+ Initial iterations often very effective
+ Insensitive to the chosen number of batches B (NOT the number of observations); no learning rate
+ Foundation for inferring the number of clusters K
Update: for each batch $b$, subtract the cached batch summary, recompute it from a partial E-step, and add it back:
$s^0_k \leftarrow s^0_k - s^b_k, \qquad s^b_k \leftarrow \sum_{n \in B_b} r_{nk}\, t(x_n), \qquad s^0_k \leftarrow s^0_k + s^b_k$
For cluster k = 1, 2, …, K, update $q(\phi_k)$ from the global summary; apply similar updates to stick weights. Batch statistics allow exact estimation of full-data statistics from partial E-steps.
Batch summaries and global summary:
batch 1: $s^1_1, s^1_2, \ldots, s^1_K$; batch 2: $s^2_1, s^2_2, \ldots, s^2_K$; …; global: $s^0_1, s^0_2, \ldots, s^0_K$, where
$s^0_k = s^1_k + s^2_k + \cdots + s^B_k$
The entropy needed for $\mathcal{L}(q)$ is memoized the same way:
$H^b_k = -\sum_{n \in B_b} r_{nk} \log r_{nk}, \qquad H^0_k = H^1_k + H^2_k + \cdots + H^B_k$
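A minimal sketch of this memoized bookkeeping, assuming identity sufficient statistics $t(x_n) = x_n$ and a hypothetical `local_step` callable that performs the partial E-step for one batch; the invariant maintained is that the global summary always equals the sum of the cached batch summaries.

```python
import numpy as np

class MemoizedSummaries:
    def __init__(self, K, D, B):
        self.s0 = np.zeros((K, D))    # global sufficient statistics
        self.sb = np.zeros((B, K, D)) # cached per-batch statistics

    def visit_batch(self, b, x_b, local_step):
        self.s0 -= self.sb[b]         # subtract stale batch summary
        r_b = local_step(x_b)         # partial E-step on batch b only
        self.sb[b] = r_b.T @ x_b      # s^b_k = sum_{n in B_b} r_nk t(x_n)
        self.s0 += self.sb[b]         # add refreshed summary back
        return self.s0                # exact global stats, no learning rate
```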
Birth moves create new clusters in three steps:
1) Create new components: subsample the data explained by a chosen cluster, learn a fresh DP-GMM on the subsample via VB, and add the fresh components to expand the original model.
2) Adopt in one pass through the data: visiting Batch 1, …, Batch b, Batch b+1, …, Batch B in turn, refresh each batch's summaries $N^b_k, s^b_k$; batches not yet updated at the current position contribute no mass to any new components.
3) Merge to remove redundancy.
[Figure: memoized summary of the expected count of each component, before and after a birth move.]
Merge moves combine a candidate pair of clusters $(k_a, k_b)$ into $k_m$ using only cached summaries:
$N^0_{k_m} \leftarrow N^0_{k_a} + N^0_{k_b}, \qquad s^0_{k_m} \leftarrow s^0_{k_a} + s^0_{k_b}$
Requires memoized entropy sums for candidate pairs of clusters; more efficient alternatives are under development.
A merge is kept only if it improves the exactly computable bound:
$\mathcal{L}(q) = H[r] + \sum_{k=1}^{K} \big[\bar{a}(s^0_k + \lambda_0) - \bar{a}(\lambda_0) + \log B(1 + N^0_k,\; \alpha + N^0_{>k}) - \log B(1, \alpha)\big]$
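A minimal sketch of scoring one candidate merge from cached summaries alone; `elbo_term` is a hypothetical helper returning a single cluster's contribution to $\mathcal{L}(q)$, and the memoized entropy terms noted above are omitted here for brevity.

```python
import numpy as np

def propose_merge(N0, s0, ka, kb, elbo_term):
    """N0: per-cluster expected counts; s0: per-cluster sufficient statistics."""
    before = elbo_term(N0[ka], s0[ka]) + elbo_term(N0[kb], s0[kb])
    after = elbo_term(N0[ka] + N0[kb], s0[ka] + s0[kb])  # merged cluster km
    return after > before  # accept the merge only if the bound improves
```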
Recovered cluster weights on toy data with eight true clusters:
worst MO-BM run: 0.13 0.13 0.12 0.12 0.13 0.13 0.13 0.12
worst MO run:    0.00 0.00 0.13 0.12 0.13 0.25 0.25 0.13 (two clusters not found)
best SO run:     0.25 0.13 0.12 0.12 0.13 0.13 0.13 0.00 (one cluster not found)
worst Batch run: 0.00 0.00 0.25 0.13 0.13 0.25 0.25 0.00 (three clusters not found)
[Figure: log evidence $\mathcal{L}(q)$ (×10^6) over passes through the data. Left: stochastic runs SOa, SOb, SOc (K=25). Right: Full K=25, MO K=25, GreedyMerge, MO-BM K=1.]
Methods: batch, memoized, and memoized birth-merge; stochastic variational with learning rates a, b, c; greedy merges based on single batches.
Memoized birth-merge finds the true clusters every time; stochastic results depend on learning rate and initialization.
[Figure: final log evidence (×10^6) for SOa, SOb, SOc, Full, MO, MO-BM, and Kurihara, each with 20 and 100 batches. Left: likelihood bound with K-means++ initialization. Right: likelihood bound with random initialization.]
Methods: batch, memoized, and memoized birth-merge; stochastic variational with learning rate schedules a, b, c; Kurihara: accelerated variational inference, NIPS 2006.
[Figure: many-to-one alignment accuracy vs. effective number of components K.]
Birth-merge runs reach comparable alignment accuracy while using fewer clusters.
[Figure: Gibbs log-probability and number of clusters vs. memoized birth-merge log-likelihood bound and number of clusters; the remaining gap is due to tiny clusters.]
[Figure: likelihood bound (log evidence ×10^7) over passes for SOa K=100, SOb K=100, Full K=100, MO K=100, and MO-BM K=1, using 25 batches.]
[Figure: likelihood bound (log evidence ×10^8) and number of clusters over passes for MO-BM K=1, MO K=100, and SOa K=100.]
Example document: "There are reasons to believe that the genetics of an organism are likely to shift due to the extreme changes in our … must pass environmental legislation that can protect our future species from becoming extinct…"
Documents are represented as mixtures of topics. [Figure: Document 1 with topic proportions (e.g., 0.5) over the "Genetics", "Climate Change", and "Politics" topics.]
Topics are categorical distributions on a (typically large) discrete vocabulary.
A generalization of latent Dirichlet allocation (LDA; Blei et al., 2003) by Teh et al., JASA 2006: a dependent Dirichlet process (DDP; MacEachern, 1999) with group-specific weights.
Corpus-level stick-breaking, $\beta_k(u) = u_k \prod_{\ell=1}^{k-1} (1 - u_\ell)$, is paired with document-level sticks,
$\pi_{dk}(v_d) = v_{dk} \prod_{\ell=1}^{k-1} (1 - v_{d\ell})$,
a generalized Dirichlet construction (Connor & Mosimann, 1969), truncated at K topics for some K > 0.
$\mathbb{E}_q[\log \pi_{dk}(v_d)] = \mathbb{E}_q[\log v_{dk}] + \sum_{\ell=1}^{k-1} \mathbb{E}_q[\log(1 - v_{d\ell})]$
The prior on the document sticks $v_{dk}$ depends on the corpus-wide topic frequencies $\beta(u)$; evaluating the objective requires an additional bound and numerical optimization.
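A minimal sketch of the expectation above for a single document, assuming the beta parameters of each stick $q(v_{dk}) = \text{Beta}(a_{1k}, a_{0k})$ are already available as arrays:

```python
import numpy as np
from scipy.special import digamma

def expected_log_topic_weights(a1, a0):
    """a1, a0: length-K Beta parameters for one document's sticks v_d."""
    Elogv = digamma(a1) - digamma(a1 + a0)
    Elog1mv = digamma(a0) - digamma(a1 + a0)
    # E[log pi_dk] = E[log v_dk] + sum_{l<k} E[log(1 - v_dl)]
    return Elogv + np.concatenate([[0.0], np.cumsum(Elog1mv)[:-1]])
```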
[Figure: training log-likelihood bound and test log-likelihood (MCMC estimator) on NIPS conference papers (D=1392) and Huffington Post articles (D=3271).]
GOAL: Unsupervised community discovery from observed relationships.
[Figure: edge creation parameter matrix (K=3),
.8 .6 .2
.1 .9 .3
.1 .2 .8
with source and receiver community roles; unobserved links shown in the adjacency diagram.]
Parametric mixed membership stochastic blockmodel, Airoldi et al. JMLR 2008
[Figure: edge creation parameter matrix (K=4),
.8 .1 .5 .2
.7 .7 .5 .3
.9 .2 .7 .8
.1 .1 .9 .7
For each node pair (i, j), sampled source and receiver community indicators (e.g., 1→2 and 2→1) select the matrix entry governing whether edge $y_{ij}$ is present or not present.]
Challenge: storage and computation for the distribution on $K^2$ community pairs.
[Diagram: observed edges divided into Mini-Batch #1, Mini-Batch #2, Mini-Batch #3.]
Pruning: remove community k when its expected usage $\Theta_k = \frac{1}{N} \sum_{i=1}^{N} \mathbb{E}_q[\pi_{ik}]$ falls below the threshold $(\log K)/N$.
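A minimal sketch of a pruning step in this spirit; note the thresholding form, average membership below $(\log K)/N$, is reconstructed from the slide fragment and should be read as an assumption.

```python
import numpy as np

def prune_communities(Epi, K):
    """Epi: N x K array of expected memberships E_q[pi_ik]."""
    N = Epi.shape[0]
    keep = Epi.mean(axis=0) >= np.log(K) / N  # usage threshold per community
    return Epi[:, keep], keep
```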
[Figure: perplexity vs. number of observed edges, and AUC quantiles, on High Energy Physics (Hep, N=11204) and Condensed Matter (N=21363) networks, comparing aMMSB (K=250/300 on Hep, K=400/450 on Condensed Matter), aHDPR-Naive K=500, aHDPR K=500, and aHDPR with pruning.]
AUC: area under the ROC curve (prediction of held-out links). Perplexity: normalized negative log-probability.
Stochastic inference & model variants evaluated on two collaboration networks: High Energy Physics (HEP), N=11,204, and Condensed Matter Physics, N=21,363.
[Figure: top 200 highest-degree nodes of the full network (N=18,831); colors correspond to each node's highest community membership, and pairwise community distances are shown for the same nodes.]
Community distance between nodes i and j:
$d(i, j) = \frac{1}{2} \sum_k |\pi_{ik} - \pi_{jk}|$
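A minimal sketch of this distance, half the L1 distance between two nodes' expected membership vectors:

```python
import numpy as np

def community_distance(pi_i, pi_j):
    """pi_i, pi_j: length-K expected community membership vectors."""
    return 0.5 * np.abs(pi_i - pi_j).sum()
```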