SLIDE 1

Class Probabilities and the Log-sum-exp Trick

Oren Freifeld Computer Science, Ben-Gurion University May 14, 2017

Oren Freifeld (BGU CS) May 14, 2017 1 / 10

SLIDE 2

Disclaimer

Both the problem and the solution described in these slides are widely known. I don't remember where I saw the solution the first time, and couldn't find out who should be credited with discovering it. Some of my derivations below are based on Ryan Adams' post at https://hips.seas.harvard.edu/blog/2013/01/09/computing-log-sum-exp/


SLIDE 3

Numerical Issues with Computing Class Probabilities

We often need to compute, for a data point x, expressions such as:

$$p(z = k \mid \theta, x) \propto w_k \exp(l_k) \quad\text{and}\quad \frac{w_k \exp(l_k)}{\sum_{k'=1}^{K} w_{k'} \exp(l_{k'})}$$

where $l_k \in \mathbb{R} \;\; \forall k \in \{1, \dots, K\}$.

Here, $l_k$ does not necessarily stand for a log-likelihood; rather, it stands for the nominal value of the exponent of the $k$-th term of interest.


SLIDE 4

Example

In EM for a GMM, the E step involves

$$r_{i,k} = \frac{\pi_k \,\mathcal{N}(x_i; \mu_k, \Sigma_k)}{\sum_{k'=1}^{K} \pi_{k'} \,\mathcal{N}(x_i; \mu_{k'}, \Sigma_{k'})} = \frac{\pi_k (2\pi)^{-n/2} |\Sigma_k|^{-1/2} \exp\!\left(-\tfrac{1}{2}(x_i - \mu_k)^T \Sigma_k^{-1} (x_i - \mu_k)\right)}{\sum_{k'=1}^{K} \pi_{k'} (2\pi)^{-n/2} |\Sigma_{k'}|^{-1/2} \exp\!\left(-\tfrac{1}{2}(x_i - \mu_{k'})^T \Sigma_{k'}^{-1} (x_i - \mu_{k'})\right)}$$

$$= \frac{\underbrace{\pi_k |\Sigma_k|^{-1/2}}_{w_k} \exp\!\big(\underbrace{-\tfrac{1}{2}(x_i - \mu_k)^T \Sigma_k^{-1} (x_i - \mu_k)}_{l_k}\big)}{\sum_{k'=1}^{K} \pi_{k'} |\Sigma_{k'}|^{-1/2} \exp\!\big(-\tfrac{1}{2}(x_i - \mu_{k'})^T \Sigma_{k'}^{-1} (x_i - \mu_{k'})\big)}$$

Remark

Here, the π in the 2π term (which cancels out anyway) is the number π, while π_k is the weight of the k-th component; confusing, but it is fairly standard notation, especially in Bayesian statistics.
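The E-step expression above can be computed directly in log space. The sketch below is mine, not from the slides: the function and variable names (`log_gaussian`, `responsibilities`) are illustrative, and the shift by the maximum anticipates the log-sum-exp trick developed in the following slides.

```python
import numpy as np

def log_gaussian(x, mu, Sigma):
    """Log of N(x; mu, Sigma) up to the constant -(n/2) log(2*pi),
    which cancels in the responsibilities anyway."""
    d = x - mu
    _, logdet = np.linalg.slogdet(Sigma)
    return -0.5 * logdet - 0.5 * d @ np.linalg.solve(Sigma, d)

def responsibilities(x, pis, mus, Sigmas):
    """E-step responsibilities r_{i,k} for one data point x,
    computed in log space with a max-shift before exponentiating."""
    # l_k + log w_k, where w_k = pi_k |Sigma_k|^{-1/2} is absorbed into the log
    logits = np.array([np.log(p) + log_gaussian(x, m, S)
                       for p, m, S in zip(pis, mus, Sigmas)])
    a = logits.max()              # shift so the largest exponent becomes 0
    e = np.exp(logits - a)        # at least one entry equals exp(0) = 1
    return e / e.sum()
```

For benign parameter values this agrees with the naive ratio of weighted Gaussian densities; its advantage appears when the quadratic forms are large in magnitude.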


SLIDE 5

Numerical Issues with Computing Class Probabilities

If $l_k < 0$ and $|l_k|$ is too large, we might have situations where (on a computer) $\exp(l_k) = 0$ for all $k$; thus $\sum_{k'=1}^{K} w_{k'} \exp(l_{k'})$ will be zero, and normalizing divides by zero.

Similarly, if $l_k > 0$ (which can happen, for example, for some non-Gaussian conditional class probabilities), we might get $+\infty$ (and/or overflow) if $l_k$ is too large.

These issues appear in many clustering problems, including (either Bayesian or non-Bayesian) mixture models.
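Both failure modes are easy to reproduce; the magnitudes below are chosen for illustration, not taken from the slides.

```python
import numpy as np

# Underflow: large-magnitude negative exponents
l = np.array([-1100.0, -1050.0, -1000.0])
w = np.array([0.2, 0.3, 0.5])

naive = w * np.exp(l)      # every exp(l_k) underflows to exactly 0.0
denominator = naive.sum()  # 0.0 -> the normalization would divide by zero

# Overflow: large positive exponents
with np.errstate(over='ignore'):
    too_big = np.exp(np.array([800.0, 900.0]))  # both entries become inf
```

In double precision, `exp` underflows to 0 roughly below −745 and overflows to `inf` roughly above 709, so exponents of these magnitudes are well past both thresholds.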


SLIDE 6

The Log-sum-exp Trick

Fact. $\forall a \in \mathbb{R}$ and $\forall \{l_k\}_{k=1}^{K} \subset \mathbb{R}$:

$$\log \sum_{k=1}^{K} \exp(l_k) = a + \log \sum_{k=1}^{K} \exp(l_k - a)$$
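A quick numerical check of the fact, for several arbitrary choices of the shift $a$ (the values are illustrative):

```python
import math

l = [-2.0, 0.5, 3.0]                          # arbitrary exponents
lse = math.log(sum(math.exp(v) for v in l))   # direct log-sum-exp

# the identity holds for every a, including a = max(l)
for a in (-10.0, 0.0, 5.0, max(l)):
    shifted = a + math.log(sum(math.exp(v - a) for v in l))
    assert abs(shifted - lse) < 1e-12
```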


SLIDE 7

The Log-sum-exp Trick

Proof.

$$\log \sum_{k=1}^{K} \exp(l_k) = \log \sum_{k=1}^{K} \exp(l_k - a + a) = \log \sum_{k=1}^{K} \exp(l_k - a)\exp(a)$$
$$= \log\left(\exp(a) \sum_{k=1}^{K} \exp(l_k - a)\right) = \log \exp(a) + \log \sum_{k=1}^{K} \exp(l_k - a) = a + \log \sum_{k=1}^{K} \exp(l_k - a)$$


SLIDE 8

The Log-sum-exp Trick

Fact. $\forall a \in \mathbb{R}$ and $\forall \{l_k\}_{k=1}^{K} \subset \mathbb{R}$:

$$\frac{\exp(l_k - a)}{\sum_{k'=1}^{K} \exp(l_{k'} - a)} = \frac{\exp(l_k)}{\sum_{k'=1}^{K} \exp(l_{k'})}$$
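This fact says the normalized probabilities (a softmax) are invariant to shifting all exponents by the same $a$; a small check with illustrative values (the helper name `softmax` is mine):

```python
import math

def softmax(ls, a=0.0):
    """Normalize exp(l_k - a) over k; the shift a should not change the result."""
    e = [math.exp(v - a) for v in ls]
    s = sum(e)
    return [x / s for x in e]

l = [1.0, 2.5, -0.5]
p_unshifted = softmax(l)          # a = 0: the plain, potentially unstable form
p_shifted = softmax(l, a=max(l))  # a = max_k l_k: the stable form
```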


SLIDE 9

Proof.

(1) $\log \sum_{k=1}^{K} \exp(l_k) = a + \log \sum_{k=1}^{K} \exp(l_k - a)$ (by the previous fact)

(2) $\exp(l_k) = \exp(\log \exp(l_k)) \overset{(1)\text{ with } K=1}{=} \exp(a + \log \exp(l_k - a))$

(3) $\sum_{k=1}^{K} \exp(l_k) = \exp\left(\log \sum_{k=1}^{K} \exp(l_k)\right) \overset{(1)}{=} \exp\left(a + \log \sum_{k=1}^{K} \exp(l_k - a)\right)$

(4)
$$\frac{\exp(l_k - a)}{\sum_{k'=1}^{K} \exp(l_{k'} - a)} = \frac{\exp(\log \exp(l_k - a))}{\exp\left(\log \sum_{k'=1}^{K} \exp(l_{k'} - a)\right)} = \frac{\exp(a)}{\exp(a)} \cdot \frac{\exp(\log \exp(l_k - a))}{\exp\left(\log \sum_{k'=1}^{K} \exp(l_{k'} - a)\right)}$$
$$= \frac{\exp(a + \log \exp(l_k - a))}{\exp\left(a + \log \sum_{k'=1}^{K} \exp(l_{k'} - a)\right)} \overset{(2)\,\&\,(3)}{=} \frac{\exp(l_k)}{\sum_{k'=1}^{K} \exp(l_{k'})}$$


SLIDE 10

The Log-sum-exp Trick

$$\frac{\exp(l_k - a)}{\sum_{k'=1}^{K} \exp(l_{k'} - a)} = \frac{\exp(l_k)}{\sum_{k'=1}^{K} \exp(l_{k'})}$$

Choose $a = \max_k l_k$ and compute the LHS, not the problematic RHS. This will prevent $+\infty$, and even if some values vanish, we will have at least one survivor ($e^{\max_k l_k - a} = e^0 = 1 > 0$), so the denominator will be strictly positive (and finite).

More generally, instead of computing

$$\frac{w_k \exp(l_k)}{\sum_{k'=1}^{K} w_{k'} \exp(l_{k'})}$$

use

$$\frac{\exp(l_k + \log w_k - a)}{\sum_{k'=1}^{K} \exp(l_{k'} + \log w_{k'} - a)}$$

where $a = \max_k (l_k + \log w_k)$ and where we also used the fact that $w_k \exp(l_k) = \exp(\log w_k)\exp(l_k) = \exp(l_k + \log w_k)$.
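The weighted recipe above can be sketched in a few lines of NumPy; the function name `class_probabilities` is mine, and the test exponents are the same illustrative underflow-inducing magnitudes as earlier.

```python
import numpy as np

def class_probabilities(l, w):
    """w_k exp(l_k) / sum_{k'} w_{k'} exp(l_{k'}), computed stably.

    Folds each weight into its exponent via w_k exp(l_k) = exp(l_k + log w_k),
    then shifts by a = max_k (l_k + log w_k) before exponentiating, so the
    largest term becomes exp(0) = 1 and the denominator stays positive.
    """
    logits = np.asarray(l, dtype=float) + np.log(np.asarray(w, dtype=float))
    a = logits.max()
    e = np.exp(logits - a)
    return e / e.sum()

# Exponents that would all underflow to 0 if exponentiated directly:
p = class_probabilities([-1100.0, -1050.0, -1000.0], [0.2, 0.3, 0.5])
```

The naive version of this computation returns 0/0 = NaN for these inputs; the shifted version returns a proper probability vector dominated by the largest exponent.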
