SLIDE 1
Class Probabilities and the Log-sum-exp Trick
Oren Freifeld Computer Science, Ben-Gurion University May 14, 2017
Oren Freifeld (BGU CS) May 14, 2017 1 / 10
SLIDE 2
Disclaimer
Both the problem and the solution described in these slides are widely known. I don't remember where I first saw the solution, and I couldn't find out who should be credited with discovering it. Some of my derivations below are based on Ryan Adams' post at https://hips.seas.harvard.edu/blog/2013/01/09/computing-log-sum-exp/
SLIDE 3
Numerical Issues with Computing Class Probabilities
We often need to compute, for a data point x, expressions such as:

p(z = k | \theta, x) \propto w_k \exp(l_k) \quad\text{and}\quad \frac{w_k \exp(l_k)}{\sum_{k'=1}^{K} w_{k'} \exp(l_{k'})}

where l_k \in \mathbb{R} \;\forall k \in \{1, \ldots, K\}.
Here, l_k does not necessarily stand for a log-likelihood; rather, it stands for the nominal value of the exponent of the k-th term of interest.
SLIDE 4
Example
In EM for a GMM, the E step involves

r_{i,k} = \frac{\pi_k \, \mathcal{N}(x_i; \mu_k, \Sigma_k)}{\sum_{k'=1}^{K} \pi_{k'} \, \mathcal{N}(x_i; \mu_{k'}, \Sigma_{k'})}
= \frac{\pi_k (2\pi)^{-n/2} |\Sigma_k|^{-1/2} \exp\left(-\frac{1}{2}(x_i - \mu_k)^T \Sigma_k^{-1} (x_i - \mu_k)\right)}{\sum_{k'=1}^{K} \pi_{k'} (2\pi)^{-n/2} |\Sigma_{k'}|^{-1/2} \exp\left(-\frac{1}{2}(x_i - \mu_{k'})^T \Sigma_{k'}^{-1} (x_i - \mu_{k'})\right)}

Since the (2\pi)^{-n/2} factor cancels, this matches the general form with

w_k = \pi_k |\Sigma_k|^{-1/2} \quad\text{and}\quad l_k = -\frac{1}{2}(x_i - \mu_k)^T \Sigma_k^{-1} (x_i - \mu_k).
Here, the π in the 2π term (which cancels out anyway) is the number π, while π_k is the weight of the k-th component; confusing, but this is fairly standard notation, especially in Bayesian statistics.
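The E-step computation above can be sketched in NumPy, working entirely with log-densities and applying the shift-by-the-max normalization developed in the following slides. This is a minimal sketch; the function names and the use of `slogdet`/`solve` are my own choices, not from the slides:

```python
import numpy as np

def log_gaussian_unnorm(x, mu, Sigma):
    """Log of |Sigma|^{-1/2} exp(-0.5 (x-mu)^T Sigma^{-1} (x-mu)).
    The (2*pi)^{-n/2} factor is dropped, since it cancels in r_{i,k}."""
    d = x - mu
    _, logdet = np.linalg.slogdet(Sigma)
    return -0.5 * logdet - 0.5 * d @ np.linalg.solve(Sigma, d)

def responsibilities(x, pis, mus, Sigmas):
    """r_{i,k} for a single data point x, via the log-sum-exp trick:
    work with l_k + log w_k and subtract the max before exponentiating."""
    logits = np.array([np.log(pi) + log_gaussian_unnorm(x, mu, S)
                       for pi, mu, S in zip(pis, mus, Sigmas)])
    a = logits.max()            # a = max_k (l_k + log w_k)
    unnorm = np.exp(logits - a) # largest term is exp(0) = 1, so no underflow of all terms
    return unnorm / unnorm.sum()
```

By symmetry, a point equidistant from two identical components with equal weights gets responsibilities (0.5, 0.5).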
SLIDE 5
Numerical Issues with Computing Class Probabilities
If l_k < 0 and |l_k| is too large, we might have situations where (on a computer) \exp(l_k) = 0 for all k. Thus, \sum_{k'=1}^{K} w_{k'} \exp(l_{k'}) will be zero, and normalizing means dividing by zero.

Similarly, if l_k > 0 (which can happen, for example, for some non-Gaussian conditional class probabilities), we might get +∞ (and/or overflow) if l_k is too large.

These issues appear in many clustering problems, including (either Bayesian or non-Bayesian) mixture models.
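The underflow problem is easy to trigger in double precision, where exp(l) flushes to exactly zero for l below roughly -745. A minimal NumPy demonstration (the variable names are my own):

```python
import numpy as np

l = np.array([-1000.0, -1001.0, -1002.0])  # plausible log-scale exponents

naive = np.exp(l)     # every entry underflows to exactly 0.0
total = naive.sum()   # 0.0 -- normalizing would divide by zero

# The shift-by-max remedy from the next slides:
a = l.max()
shifted = np.exp(l - a)      # largest term is exp(0) = 1 > 0
p = shifted / shifted.sum()  # well-defined probabilities
```

After shifting, the exponents are 0, -1, -2, which are perfectly representable.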
SLIDE 6
The Log-sum-exp Trick
Fact. \forall a \in \mathbb{R} and \forall \{l_k\}_{k=1}^{K} \subset \mathbb{R}:

\log \sum_{k=1}^{K} \exp(l_k) = a + \log \sum_{k=1}^{K} \exp(l_k - a)
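The fact above, with a = max_k l_k, is exactly how a stable log-sum-exp is usually implemented (SciPy ships one as `scipy.special.logsumexp`). A minimal sketch with a name of my own choosing:

```python
import numpy as np

def logsumexp(l):
    """log(sum_k exp(l_k)), computed via the identity
    log sum_k exp(l_k) = a + log sum_k exp(l_k - a), with a = max_k l_k."""
    l = np.asarray(l, dtype=float)
    a = l.max()
    return a + np.log(np.sum(np.exp(l - a)))
```

For moderate inputs this agrees with the naive formula; for inputs like [-2000, -2000] the naive formula returns log(0) = -inf while the shifted version correctly returns -2000 + log 2.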
SLIDE 7
The Log-sum-exp Trick
Proof.

\log \sum_{k=1}^{K} \exp(l_k) = \log \sum_{k=1}^{K} \exp(l_k - a + a)
= \log \sum_{k=1}^{K} \exp(l_k - a) \exp(a)
= \log \left( \exp(a) \sum_{k=1}^{K} \exp(l_k - a) \right)
= \log \exp(a) + \log \sum_{k=1}^{K} \exp(l_k - a)
= a + \log \sum_{k=1}^{K} \exp(l_k - a)
SLIDE 8
The Log-sum-exp Trick
Fact. \forall a \in \mathbb{R} and \forall \{l_k\}_{k=1}^{K} \subset \mathbb{R}:

\frac{\exp(l_k - a)}{\sum_{k'=1}^{K} \exp(l_{k'} - a)} = \frac{\exp(l_k)}{\sum_{k'=1}^{K} \exp(l_{k'})}
SLIDE 9
Proof.

(1) \log \sum_{k=1}^{K} \exp(l_k) = a + \log \sum_{k=1}^{K} \exp(l_k - a) (by the previous fact)

(2) \exp(l_k) = \exp(\log \exp(l_k)) = \exp(a + \log \exp(l_k - a)) (by (1) with K = 1)

(3) \sum_{k=1}^{K} \exp(l_k) = \exp\left(\log \sum_{k=1}^{K} \exp(l_k)\right) = \exp\left(a + \log \sum_{k=1}^{K} \exp(l_k - a)\right)

Thus,

\frac{\exp(l_k - a)}{\sum_{k'=1}^{K} \exp(l_{k'} - a)}
= \frac{\exp(\log \exp(l_k - a))}{\exp\left(\log \sum_{k'=1}^{K} \exp(l_{k'} - a)\right)}
= \frac{\exp(a) \exp(\log \exp(l_k - a))}{\exp(a) \exp\left(\log \sum_{k'=1}^{K} \exp(l_{k'} - a)\right)}
= \frac{\exp(a + \log \exp(l_k - a))}{\exp\left(a + \log \sum_{k'=1}^{K} \exp(l_{k'} - a)\right)}
= \frac{\exp(l_k)}{\sum_{k'=1}^{K} \exp(l_{k'})} (by (2) and (3))
SLIDE 10
The Log-sum-exp Trick
\frac{\exp(l_k - a)}{\sum_{k'=1}^{K} \exp(l_{k'} - a)} = \frac{\exp(l_k)}{\sum_{k'=1}^{K} \exp(l_{k'})}

Choose a = \max_k l_k and compute the LHS, not the problematic RHS. This prevents +∞, and even if some terms vanish, at least one survives (\exp(\max_k l_k - a) = \exp(0) = 1 > 0), so the denominator is strictly positive (and finite).

More generally, instead of computing

\frac{w_k \exp(l_k)}{\sum_{k'=1}^{K} w_{k'} \exp(l_{k'})}

use

\frac{\exp(l_k + \log w_k - a)}{\sum_{k'=1}^{K} \exp(l_{k'} + \log w_{k'} - a)}

where a = \max_k (l_k + \log w_k), and where we also used the fact that w_k \exp(l_k) = \exp(\log w_k) \exp(l_k) = \exp(l_k + \log w_k).
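The general weighted recipe above can be sketched as a short NumPy function (a minimal sketch; the function name is my own, and w is assumed to be strictly positive so that log w is defined):

```python
import numpy as np

def class_probabilities(w, l):
    """Stable evaluation of w_k exp(l_k) / sum_{k'} w_{k'} exp(l_{k'}),
    using w_k exp(l_k) = exp(l_k + log w_k) and shifting by
    a = max_k (l_k + log w_k) before exponentiating."""
    w = np.asarray(w, dtype=float)
    l = np.asarray(l, dtype=float)
    logits = l + np.log(w)          # l_k + log w_k
    a = logits.max()                # a = max_k (l_k + log w_k)
    u = np.exp(logits - a)          # largest term is exp(0) = 1
    return u / u.sum()
```

For moderate values this matches the naive formula exactly; for extreme l it still returns a valid probability vector where the naive formula would produce 0/0 or inf/inf.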