Dropout in RNNs Following a VI Interpretation - Yarin Gal - PowerPoint PPT Presentation



SLIDE 1

Dropout in RNNs Following a VI Interpretation

Yarin Gal

yg279@cam.ac.uk

Unless specified otherwise, photos are either original work or taken from Wikimedia, under Creative Commons license

SLIDE 2

Recurrent Neural Networks

Recurrent neural networks (RNNs) are damn useful.

Figure : RNN structure

Image Source: karpathy.github.io/2015/05/21/rnn-effectiveness

SLIDE 3

Recurrent Neural Networks

But these also overfit very quickly...

Figure : Overfitting

This means...

◮ We can’t use large models
◮ We have to use early stopping
◮ We can’t use small data
◮ We have to waste data for validation sets...

SLIDE 8

Dropout in recurrent neural networks

Let’s use dropout then. But much previous research has claimed that this is a bad idea:

◮ Pachitariu & Sahani, 2013: noise added in the recurrent connections of an RNN leads to model instabilities

◮ Bayer et al., 2013: with dropout, the RNN’s dynamics change dramatically

◮ Pham et al., 2014: dropout in recurrent layers disrupts the RNN’s ability to model sequences

◮ Zaremba et al., 2014: applying dropout to the non-recurrent connections alone results in improved performance

◮ Bluche et al., 2015: exploratory analysis of the performance of dropout before, inside, and after the RNNs

SLIDE 9

Dropout in recurrent neural networks

→ Existing research has settled on using dropout for the inputs and outputs alone.

Figure : Naive application of dropout in RNNs (colours = different dropout masks; inputs xt and hidden states ht at times t−1, t, t+1)

SLIDE 13

Dropout in recurrent neural networks

Why not use dropout with recurrent layers?

◮ It doesn’t work
◮ Noise drowns the signal
◮ Because it’s not used correctly?

First, some background on Bayesian modelling and VI in Bayesian neural networks.

SLIDE 14

Bayesian modelling and inference

◮ Observed inputs X = {xi}_{i=1}^{N} and outputs Y = {yi}_{i=1}^{N}
◮ Capture the stochastic process believed to have generated the outputs
◮ Define the model parameters ω as random variables
◮ Prior distribution over ω: p(ω)
◮ Likelihood: p(Y | ω, X)
◮ Posterior (Bayes’ theorem):

p(ω | X, Y) = p(Y | ω, X) p(ω) / p(Y | X)

◮ Predictive distribution given a new input x∗, integrating over the posterior:

p(y∗ | x∗, X, Y) = ∫ p(y∗ | x∗, ω) p(ω | X, Y) dω

◮ But... p(ω | X, Y) is often intractable

SLIDE 17

Approximate inference

◮ Approximate p(ω | X, Y) with a simple distribution qθ(ω)
◮ Minimise the divergence from the posterior w.r.t. θ:

KL(qθ(ω) || p(ω | X, Y))

◮ Identical to minimising (likelihood term plus KL to the prior)

LVI(θ) := − ∫ qθ(ω) log p(Y | X, ω) dω + KL(qθ(ω) || p(ω))

◮ We can approximate the predictive distribution

qθ(y∗ | x∗) = ∫ p(y∗ | x∗, ω) qθ(ω) dω.

SLIDE 19

Bayesian neural networks

◮ Place a prior p(Wi): Wi ∼ N(0, I) for i ≤ L (and write ω := {Wi}_{i=1}^{L}).
◮ The output is a random variable

f(x, ω) = WL σ(...W2 σ(W1 x + b1)...).

◮ Softmax likelihood for classification: p(y | x, ω) = softmax(f(x, ω)), or a Gaussian for regression: p(y | x, ω) = N(y; f(x, ω), τ⁻¹ I).
◮ But it is difficult to evaluate the posterior p(ω | X, Y).
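As an illustration of the construction above, a minimal NumPy sketch: sample the weights from the prior N(0, I) and push an input through the network, giving a random output with a softmax likelihood. All sizes and the tanh non-linearity are hypothetical choices, not from the slides.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes: input dim D, hidden width H, number of classes C.
D, H, C = 3, 16, 4
x = rng.standard_normal(D)

# Sample ω = {W1, W2} from the prior N(0, I).
W1 = rng.standard_normal((H, D))
b1 = np.zeros(H)
W2 = rng.standard_normal((C, H))

# f(x, ω) = W2 σ(W1 x + b1): a random output, since ω is random.
f = W2 @ np.tanh(W1 @ x + b1)

# Softmax likelihood for classification (shifted for numerical stability).
p_y = np.exp(f - f.max())
p_y /= p_y.sum()
```

Each fresh draw of W1, W2 gives a different p_y, which is exactly why the output is a random variable.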

SLIDE 22

Approximate inference in Bayesian NNs

◮ Define qθ(ω) to approximate the posterior p(ω | X, Y)
◮ KL divergence to minimise:

KL(qθ(ω) || p(ω | X, Y)) ∝ − ∫ qθ(ω) log p(Y | X, ω) dω + KL(qθ(ω) || p(ω)) =: L(θ)

◮ Approximate the integral with MC integration, ω̂ ∼ qθ(ω):

L̂(θ) := − log p(Y | X, ω̂) + KL(qθ(ω) || p(ω))
SLIDE 25

Stochastic approx. inf. in Bayesian NNs

◮ Unbiased estimator:

E_{ω̂ ∼ qθ(ω)}[ L̂(θ) ] = L(θ)

◮ Converges to the same optima as L(θ)
◮ For inference, repeat:

◮ Sample ω̂ ∼ qθ(ω)
◮ And minimise (one step)

L̂(θ) = − log p(Y | X, ω̂) + KL(qθ(ω) || p(ω))

w.r.t. θ.

SLIDE 28

Specifying qθ(ω)

◮ Given variational parameters θ = {[mi1, ..., miK]}_{i=1}^{L}:

qθ(ω) = ∏i qθ(Wi)
qθ(Wi) = ∏k q_{mik}(wik)
q_{mik}(wik) = p N(0, σ²) + (1 − p) N(mik, σ²)

→ the k’th column of the i’th layer is a multivariate mixture of Gaussians

◮ With small enough σ², in practice equivalent to

zi,j ∼ Bernoulli(pi) for i = 1, ..., L, j = 1, ..., K_{i−1}
Wi = Mi · diag([zi,j]_{j=1}^{Ki})

with the zi,j Bernoulli random variables.
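The Bernoulli construction above can be sketched in a few lines of NumPy (the layer sizes and Bernoulli parameter are illustrative assumptions): multiplying Mi on the right by diag(z) zeroes whole columns, i.e. drops the corresponding input units.

```python
import numpy as np

rng = np.random.default_rng(42)

# Hypothetical layer sizes: K_prev input units, K output units.
K_prev, K = 6, 3
M = rng.standard_normal((K, K_prev))   # variational parameter M_i
p = 0.5                                # Bernoulli parameter (illustrative)

# One sample from q: z_j ~ Bernoulli(p), W = M · diag(z).
z = rng.binomial(1, p, size=K_prev)
W = M @ np.diag(z)                     # columns with z_j = 0 are zeroed out
```

With column j of W zeroed, W @ x ignores x[j] entirely, which is dropout on input unit j.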

SLIDE 29

Deep learning as approx. inference

In summary:

Minimise divergence between qθ(ω) and p(ω|X, Y):

◮ Repeat:

◮ Sample

ẑi,j ∼ Bernoulli(pi) and set Ŵi = Mi · diag([ẑi,j]_{j=1}^{Ki}), ω̂ = {Ŵi}_{i=1}^{L}

◮ Minimise (one step)

L̂(θ) = − log p(Y | X, ω̂) + KL(qθ(ω) || p(ω))

w.r.t. θ = {Mi}_{i=1}^{L} (set of matrices).

SLIDE 30

Deep learning as approx. inference

In summary:

Minimise divergence between qθ(ω) and p(ω|X, Y):

◮ Repeat:

◮ = Randomly set columns of Mi to zero
◮ Minimise (one step)

L̂(θ) = − log p(Y | X, ω̂) + KL(qθ(ω) || p(ω))

w.r.t. θ = {Mi}_{i=1}^{L} (set of matrices).

SLIDE 31

Deep learning as approx. inference

In summary:

Minimise divergence between qθ(ω) and p(ω|X, Y):

◮ Repeat:

◮ = Randomly set units of the network to zero
◮ Minimise (one step)

L̂(θ) = − log p(Y | X, ω̂) + KL(qθ(ω) || p(ω))

w.r.t. θ = {Mi}_{i=1}^{L} (set of matrices).

SLIDE 32

Deep learning as approx. inference

Sounds familiar?

L̂(θ) = − log p(Y | X, ω̂) + KL(qθ(ω) || p(ω))
        [= loss]             [= L2 reg]

Implementing VI with the qθ(·) above = implementing dropout in a deep network
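The correspondence can be made concrete with a minimal NumPy sketch of one stochastic step: sample Bernoulli masks over units (dropout), then evaluate a loss plus an L2 penalty standing in for the KL term. The network sizes, squared-error likelihood, and weight-decay constant are all illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical data and sizes.
D, H, N = 5, 16, 32
X = rng.standard_normal((N, D))
y = rng.standard_normal((N, 1))

M1 = rng.standard_normal((D, H)) * 0.1   # variational parameters (mean weights)
M2 = rng.standard_normal((H, 1)) * 0.1
p_keep = 0.8            # 1 - dropout probability (illustrative)
weight_decay = 1e-4     # plays the role of the KL term

def sample_loss():
    # Sample Bernoulli masks over input and hidden units: with X @ W1,
    # zeroing rows of M1 / M2 drops the corresponding units.
    z1 = rng.binomial(1, p_keep, size=D)
    z2 = rng.binomial(1, p_keep, size=H)
    W1 = np.diag(z1) @ M1
    W2 = np.diag(z2) @ M2
    h = np.tanh(X @ W1)
    out = h @ W2
    nll = np.mean((out - y) ** 2)          # stands in for -log p(Y|X, ω̂)
    kl = weight_decay * (np.sum(M1**2) + np.sum(M2**2))  # KL ≈ L2 reg
    return nll + kl

loss = sample_loss()
```

One gradient step on this sampled loss w.r.t. M1, M2 is exactly one step of dropout training with weight decay.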

SLIDE 33

Other stochastic reg. techniques

◮ Multiplicative Gaussian noise (Srivastava et al. 2014):
◮ Multiply network units by N(1, 1)
◮ Same performance as dropout

◮ Multiplicative Gaussian noise as approximate inference¹:

zi,j ∼ N(1, 1) for i = 1, ..., L, j = 1, ..., K_{i−1}
Wi = Mi · diag([zi,j]_{j=1}^{Ki})
qθ(ω) = ∏i q_{Mi}(Wi)

Similarly for drop-connect (Wan et al., 2013), etc.

¹See Gal and Ghahramani (2015) and Kingma et al. (2015)
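The Gaussian variant simply swaps the Bernoulli masks for N(1, 1) noise; a tiny sketch (the matrix sizes are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)

M = rng.standard_normal((3, 4))             # variational parameter M_i
z = rng.normal(loc=1.0, scale=1.0, size=4)  # z_{i,j} ~ N(1, 1), one per unit
W = M @ np.diag(z)                          # W_i = M_i · diag(z): columns scaled
```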

SLIDE 34

Back to recurrent neural networks

Figure : A Recurrent Neural Network (inputs xt and hidden states ht at times t−1, t, t+1)

SLIDE 35

Now, in RNNs...

◮ Input sequence of vectors x = {x1, ..., xT} with T time steps
◮ Let ω = {all weight matrices in the model}
◮ Define ht = f_h^ω(xt, h_{t−1})

◮ a single recurrent unit transition, e.g. tanh of an affine transformation: tanh(W xt + U h_{t−1} + b)

◮ Set f_y^ω(hT) = f_y^ω(f_h^ω(xT, ...f_h^ω(x1, h0)...))

◮ the model output (e.g. an affine transformation of the last state, or a function of all states)

◮ Lastly, define p(y | f_y^ω(hT))

◮ the model likelihood, e.g. N(y; f_y^ω(hT), σ²)

◮ Similarly for LSTM, GRU
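The recursion above can be sketched directly in NumPy (the sizes, tanh transition, and affine output map are illustrative assumptions):

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical sizes: input dim D, state dim K, sequence length T.
D, K, T = 4, 8, 6
W = rng.standard_normal((K, D)) * 0.1
U = rng.standard_normal((K, K)) * 0.1
b = np.zeros(K)
V = rng.standard_normal((1, K)) * 0.1   # output map f_y

def f_h(x_t, h_prev):
    # Single recurrent transition: tanh of an affine transformation.
    return np.tanh(W @ x_t + U @ h_prev + b)

def f_y(h_T):
    # Model output: affine transformation of the last state.
    return V @ h_T

xs = [rng.standard_normal(D) for _ in range(T)]
h = np.zeros(K)          # h0
for x_t in xs:
    h = f_h(x_t, h)      # h_t = f_h(x_t, h_{t-1})
out = f_y(h)             # f_y(h_T)
```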

SLIDE 36

Now, in RNNs...

◮ Looking at the variational lower bound, we have:

∫ q(ω) log p(y | f_y^ω(hT)) dω = ∫ q(ω) log p(y | f_y^ω(f_h^ω(xT, ...f_h^ω(x1, h0)...))) dω,

◮ Using MC integration with ω̂ ∼ q(ω):

L̂VI ≈ − log p(y | f_y^ω̂(f_h^ω̂(xT, ...f_h^ω̂(x1, h0)...))) + KL(q(ω) || p(ω)).

SLIDE 37

Dropout in RNNs

Objective:

− log p(y | f_y^ω̂(f_h^ω̂(xT, ...f_h^ω̂(x1, h0)...))) + KL(q(ω) || p(ω)), with ω̂ ∼ q(ω)

◮ In practice, use the same dropout mask at each time step

Figure : Bayesian motivated dropout in RNNs (colours = dropout masks)
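A minimal sketch of this tied-mask dropout, assuming illustrative sizes and a tanh transition: the key point is that the masks z_x and z_h are sampled once per sequence (one draw of ω̂) and reused at every time step, rather than resampled per step.

```python
import numpy as np

rng = np.random.default_rng(2)

D, K, T = 4, 8, 6                      # hypothetical sizes
W = rng.standard_normal((K, D)) * 0.1
U = rng.standard_normal((K, K)) * 0.1
b = np.zeros(K)
p_keep = 0.8                           # illustrative keep probability

# Sample ONE set of dropout masks per sequence, reused at every time step.
z_x = rng.binomial(1, p_keep, size=D)  # input mask
z_h = rng.binomial(1, p_keep, size=K)  # recurrent mask

xs = [rng.standard_normal(D) for _ in range(T)]
h = np.zeros(K)
for x_t in xs:
    # Same z_x and z_h at each step == one weight sample for the whole sequence.
    h = np.tanh(W @ (z_x * x_t) + U @ (z_h * h) + b)
```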

SLIDE 38

Word embedding dropout

◮ With continuous inputs we apply dropout to the input layer (place a distribution over the weight matrix)
◮ But not for models with discrete inputs...
◮ Word embeddings: the input can be seen as either the word embedding itself, or a “one-hot” encoding times an embedding matrix
◮ Optimising the embedding matrix can lead to overfitting...
◮ Let’s apply dropout to the one-hot encoded vectors

SLIDE 39

Word embedding dropout

◮ In practice, drop words at random throughout the sentence
◮ Randomly set embedding matrix rows to zero – entire word embeddings
◮ The mask is repeated at each time step → drop the same words throughout the sequence
◮ i.e. drop word types at random rather than word tokens
◮ For example, “the dog and the cat” might become “— dog and — cat” or “the — and the cat”, but never “— dog and the cat”.
◮ Can be interpreted as encouraging the model not to “depend” on single words.
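The word-type behaviour can be sketched as follows (the toy vocabulary, embedding size, and drop probability are illustrative): one mask over the rows of the embedding matrix is shared across the whole sequence, so every occurrence of a dropped word is dropped together.

```python
import numpy as np

rng = np.random.default_rng(3)

vocab = {"the": 0, "dog": 1, "and": 2, "cat": 3}
E = rng.standard_normal((len(vocab), 8))  # embedding matrix, rows = word vectors
p_drop = 0.5                              # illustrative drop probability

# One Bernoulli mask over word TYPES (rows of E), shared across the sequence.
keep = rng.binomial(1, 1 - p_drop, size=len(vocab))

sentence = ["the", "dog", "and", "the", "cat"]
ids = [vocab[w] for w in sentence]
embedded = np.stack([keep[i] * E[i] for i in ids])  # zero rows = dropped types
```

Both occurrences of “the” share the same mask entry, so they are dropped or kept together, matching the word-type example above.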

SLIDE 40

Working dropout in recurrent layers

Some results (much more in paper):

◮ Sentiment analysis (Pang & Lee, 2005)

Figure : LSTM test error

◮ Language model (Penn Treebank)

SLIDE 41

Working dropout in recurrent layers

Some results (much more in paper):

◮ Sentiment analysis (Pang & Lee, 2005)

Figure : GRU test error

◮ Language model (Penn Treebank)

SLIDE 43

Working dropout in recurrent layers

Some results (much more in paper):

◮ Sentiment analysis (Pang & Lee, 2005) ◮ Language model (Penn Treebank)

Figure : 2-layer LSTM, 200 units

SLIDE 44

Many unanswered questions left

◮ Practical deep learning uncertainty?
◮ Capture language ambiguity? (Image Source: cs224d.stanford.edu/lectures/CS224d-Lecture8.pdf)
◮ Weight uncertainty for model debugging?
◮ Principled extensions of deep learning?
◮ New approximating distributions = new stochastic regularisation techniques?
◮ Model compression: Wi ∼ discrete distribution with continuous base measure?

SLIDE 48

Many unanswered questions left


Work in progress!

SLIDE 50

New horizons

Most exciting is work to come:

◮ Practical uncertainty in deep learning applications
◮ Principled extensions to deep learning tools
◮ Hybrid deep learning – Bayesian models

and much, much, more. Thank you for listening.
