INF5820: Language technological applications
Gated RNNs (3:2)
Taraka Rama, University of Oslo
30 October 2018
Agenda
◮ Break only for 5 minutes
◮ Lecture ends at 3:45
◮ GRU
◮ LSTM
◮ Connections
◮ An analysis of why GRUs address vanishing gradients
◮ Applications of Gated RNNs
Some external resources
Many online resources on Gated RNNs:
◮ Fourth part on GRU and LSTM: https://tinyurl.com/z9j4ws9
◮ Understanding Long Short-Term Memory networks:
  http://colah.github.io/posts/2015-08-Understanding-LSTMs/
Recap
Equations of basic RNN
◮ h_t = f(W_xh x_t + W_hh h_{t−1} + b_h)
◮ ŷ_t = W_hy h_t + b_y
◮ p_t = softmax(ŷ_t)
◮ x_t is the input vector at time t
◮ h_t is the hidden state, or memory, at time t
◮ f(·) is a non-linearity such as the tanh or sigmoid function
◮ ŷ_t is the output transformation that maps h_t to the number of output classes at each time step
◮ p_t is the output of the softmax function, which transforms the values in ŷ_t into probabilities
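As a concrete illustration, here is a minimal NumPy sketch of one forward step of this basic RNN; the dimensions, weights, and helper names are made up for the example, and f is taken to be tanh.

import numpy as np

def softmax(v):
    e = np.exp(v - v.max())                      # shift for numerical stability
    return e / e.sum()

def rnn_step(x_t, h_prev, W_xh, W_hh, W_hy, b_h, b_y):
    # One step of the basic RNN from the slide, with f = tanh.
    h_t = np.tanh(W_xh @ x_t + W_hh @ h_prev + b_h)   # hidden state / memory
    y_t = W_hy @ h_t + b_y                            # output transformation
    p_t = softmax(y_t)                                # class probabilities
    return h_t, p_t

# Toy sizes (hypothetical): 4-dim input, 3-dim hidden state, 2 output classes.
rng = np.random.default_rng(0)
x_t, h_prev = rng.normal(size=4), np.zeros(3)
W_xh, W_hh, W_hy = rng.normal(size=(3, 4)), rng.normal(size=(3, 3)), rng.normal(size=(2, 3))
b_h, b_y = np.zeros(3), np.zeros(2)
h_t, p_t = rnn_step(x_t, h_prev, W_xh, W_hh, W_hy, b_h, b_y)
print(p_t.sum())                                      # 1.0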
Recap: Vanishing/Exploding Gradient
◮ Fix k = 18 and S = 20
◮ All W matrices and inputs are of size 1 × 1. (A simplification!)
Recap: Vanishing/Exploding Gradient
Basic recursion
◮ dhraw_t = A_t + B_t dhnext_t
◮ dhnext_{t−1} = W_hh dhraw_t
◮ dhraw_18 = A_18 + B_18 dhnext_18
◮ ≈ A_18 + B_18 W_hh dhraw_19
◮ ≈ A_18 + B_18 W_hh (A_19 + B_19 dhnext_19)
◮ ≈ A_18 + B_18 W_hh (A_19 + B_19 W_hh dhraw_20)
◮ The rightmost term, dhraw_20, is always left-multiplied by W_hh
◮ When k = 1, dW_hh is computed using (W_hh)^19
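A tiny numeric sketch of why this repeated multiplication matters; the scalar values of W_hh are made up for illustration.

# The gradient that reaches time step k = 1 has been multiplied by W_hh once
# per intervening step (19 times for S = 20, as on the slide). Scalar toy values:
for w_hh in (0.5, 1.5):
    grad = 1.0
    for _ in range(19):
        grad *= w_hh
    print(f"W_hh = {w_hh}: gradient scaled by {grad:.2e}")
# W_hh = 0.5 gives ~1.9e-06 (vanishing); W_hh = 1.5 gives ~2.2e+03 (exploding).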
Problem with RNNs
◮ Vanishing/exploding gradients
◮ Not sure whether:
  ◮ there is no dependency between different time steps, or
  ◮ the parameters were initialized badly
◮ Many solutions (Pascanu et al. 2012)
◮ Today: we will analyze why gates address the vanishing gradient problem
◮ The main solution is Gated RNNs
Gated RNNs are better than RNNs at MT
Sutskever et al. (2014): Sequence to Sequence Learning with Neural Networks (4903 citations on Google Scholar as of October 2018)
LSTM vs. GRU vs. CNN?
Yin et al. 2017: Comparative Study of CNN and RNN for Natural Language Processing https://arxiv.org/pdf/1702.01923.pdf
Why the name gated?
◮ Analogous to gates in circuits
◮ Allow parts of the input vector or hidden state to pass over to the next state
◮ x = g × y + (1 − g) × z, where g ∈ {0, 1}
◮ In vector notation: x_i = g_i × y_i + (1 − g_i) × z_i, where g_i ∈ {0, 1}
◮ A simpler way to write this: the Hadamard product
Hadamard product
◮ A cool name for elementwise matrix multiplication
◮ Can be written as ⊙ (in machine learning or physics) or ◦ (usually in NLP)
◮ C = A ⊙ B, where C_{i,j} = A_{i,j} × B_{i,j}
◮ Matrix dimensions do not change
◮ Can only be performed between matrices of identical dimensions
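In NumPy the Hadamard product is simply the elementwise * operator on arrays of the same shape; a small sketch:

import numpy as np

A = np.array([[1.0, 2.0], [3.0, 4.0]])
B = np.array([[10.0, 20.0], [30.0, 40.0]])
C = A * B            # elementwise; np.multiply(A, B) is equivalent
print(C)             # [[ 10.  40.]
                     #  [ 90. 160.]]
print(C.shape)       # (2, 2): the dimensions do not change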
0/1 problem
◮ A 0/1 gate variable is not differentiable
◮ A non-differentiable variable blocks backpropagation
◮ No backpropagation implies no gradient descent
◮ What is the solution?
◮ The sigmoid or tanh function
Gated Recurrent Unit (GRU)
◮ Compared to an RNN, a GRU has two gates
◮ Update gate z and reset gate r
◮ Both gates are computed with their own weights at each time step, just like the RNN hidden state
◮ Update gate: z_t = σ(W^(z) x_t + U^(z) h_{t−1})
◮ Reset gate: r_t = σ(W^(r) x_t + U^(r) h_{t−1})
GRU computations
◮ Update gate: z_t = σ(W^(z) x_t + U^(z) h_{t−1})
◮ Reset gate: r_t = σ(W^(r) x_t + U^(r) h_{t−1})
◮ Internal memory content: h̃_t = tanh(W x_t + r_t ◦ U h_{t−1})
◮ Final memory: h_t = z_t ◦ h_{t−1} + (1 − z_t) ◦ h̃_t
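A minimal NumPy sketch of one GRU step following these equations; the shapes and weight names are made up for the example, and biases are omitted as on the slide.

import numpy as np

def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))

def gru_step(x_t, h_prev, W_z, U_z, W_r, U_r, W, U):
    # One GRU step following the slide equations (biases omitted, as on the slide).
    z_t = sigmoid(W_z @ x_t + U_z @ h_prev)           # update gate
    r_t = sigmoid(W_r @ x_t + U_r @ h_prev)           # reset gate
    h_tilde = np.tanh(W @ x_t + r_t * (U @ h_prev))   # internal memory content
    h_t = z_t * h_prev + (1.0 - z_t) * h_tilde        # final memory
    return h_t

# Toy shapes (hypothetical): 4-dim input, 3-dim hidden state.
rng = np.random.default_rng(0)
d_x, d_h = 4, 3
x_t, h_prev = rng.normal(size=d_x), np.zeros(d_h)
W_z, W_r, W = (rng.normal(size=(d_h, d_x)) for _ in range(3))
U_z, U_r, U = (rng.normal(size=(d_h, d_h)) for _ in range(3))
h_t = gru_step(x_t, h_prev, W_z, U_z, W_r, U_r, W, U)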
GRU (Visualization)
(http://web.stanford.edu/class/cs224n/lectures/lecture9.pdf)
Is it complex?
◮ Think of all weights, variables and inputs as 1 × 1 matrices
◮ Easy to analyze
◮ Boils down to single-variable mathematics
Reset Gate and GRU
◮ Update gate: z_t = σ(W^(z) x_t + U^(z) h_{t−1})
◮ Reset gate: r_t = σ(W^(r) x_t + U^(r) h_{t−1})
◮ Internal memory content: h̃_t = tanh(W x_t + r_t ◦ U h_{t−1})
◮ Final memory: h_t = z_t ◦ h_{t−1} + (1 − z_t) ◦ h̃_t

If r_t = 0
◮ Ignore the previous hidden memory h_{t−1}. Equivalently, forget the past.
◮ The current hidden memory h_t is partly dependent on the current input x_t

If r_t = 1
◮ Use everything in the previous hidden memory h_{t−1}. Looks like the RNN is back.
Update Gate and GRU
◮ Update gate: z_t = σ(W^(z) x_t + U^(z) h_{t−1})
◮ Reset gate: r_t = σ(W^(r) x_t + U^(r) h_{t−1})
◮ Internal memory content: h̃_t = tanh(W x_t + r_t ◦ U h_{t−1})
◮ Final memory: h_t = z_t ◦ h_{t−1} + (1 − z_t) ◦ h̃_t

If z_t = 0
◮ Ignore the previous hidden memory h_{t−1}
◮ The current hidden memory h_t depends only on the new memory state h̃_t

If z_t = 1
◮ Copy the previous hidden memory h_{t−1}
◮ The current input is ignored completely
What is it for NLP?
◮ Short-term dependencies mean older history should be ignored (r_t → 0)
◮ Long-term dependencies mean older history should be retained as much as possible (z_t → 1)
Is RNN a special case of GRU?
Can you think of when an RNN is a special case of a GRU?
◮ When r_t = 1 and z_t = 0:
  ◮ h̃_t = tanh(W x_t + r_t ◦ U h_{t−1}) = tanh(W x_t + U h_{t−1})
  ◮ h_t = 0 ∗ h_{t−1} + 1 ∗ h̃_t
  ◮ h_t = h̃_t
  ◮ W is W_xh and U is W_hh
◮ Back to the vanilla RNN... Is that a problem?
◮ r_t = 1 and z_t = 0 would have to happen at all time steps. Very low chance of that happening!
◮ So we are fine...
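A quick numeric check of this reduction, with toy weights made up for illustration: pinning the gates to r_t = 1 and z_t = 0 makes one GRU step coincide with a vanilla RNN step.

import numpy as np

# Pin the gates to r_t = 1 and z_t = 0 and compare one GRU step with a
# vanilla RNN step h_t = tanh(W x_t + U h_prev). Weights are random toys.
rng = np.random.default_rng(1)
d_x, d_h = 4, 3
x_t, h_prev = rng.normal(size=d_x), rng.normal(size=d_h)
W, U = rng.normal(size=(d_h, d_x)), rng.normal(size=(d_h, d_h))

r_t, z_t = np.ones(d_h), np.zeros(d_h)               # forced gate values
h_tilde = np.tanh(W @ x_t + r_t * (U @ h_prev))      # GRU candidate memory
h_gru = z_t * h_prev + (1 - z_t) * h_tilde           # GRU final memory
h_rnn = np.tanh(W @ x_t + U @ h_prev)                # vanilla RNN step
print(np.allclose(h_gru, h_rnn))                     # True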
Does GRU fix the vanishing gradient?
Two ways to analyze this:
◮ What does the hidden state look like in a GRU after a few time steps?
◮ What does the gradient calculation look like?
  ◮ Requires the multivariate chain rule!
GRU: Hidden state calculation
We will expand h_t from t = 20 back to t = 19. Again, all our matrices are 1 × 1 to simplify.
◮ h_20 = z_20 h_19 + (1 − z_20) h̃_20
◮ h_19 = z_19 h_18 + (1 − z_19) h̃_19
◮ Therefore, h_20 = z_20 z_19 h_18 + z_20 (1 − z_19) h̃_19 + (1 − z_20) h̃_20
◮ If z_19 → 0, then only part of h_18 is relevant, through h̃_19
◮ If z_19 → 1, then h̃_19 is ignored completely. It is as if the input x_19 were completely ignored.
◮ There is a jump between time steps 18 and 20.
◮ It is a shortcut across time.
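A quick scalar check of the two-step expansion, with made-up values and everything 1 × 1 as on the slide:

import numpy as np

# Scalar (1 x 1) check of the two-step expansion, with made-up values.
rng = np.random.default_rng(2)
z20, z19 = rng.uniform(), rng.uniform()
h18, h19_tilde, h20_tilde = rng.normal(size=3)

h19 = z19 * h18 + (1 - z19) * h19_tilde
h20 = z20 * h19 + (1 - z20) * h20_tilde
h20_expanded = z20 * z19 * h18 + z20 * (1 - z19) * h19_tilde + (1 - z20) * h20_tilde
print(np.isclose(h20, h20_expanded))                 # True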
Single variable chain rule
Question: If z depends on y and y depends on x, what is dz/dx?

Answer: dz/dx = (dz/dy)(dy/dx)

Example
◮ z = exp(y)
◮ y = x^2
◮ dz/dy = exp(y)
◮ dy/dx = 2x
◮ dz/dx = exp(y) × 2x
Multivariate chain rule (Very Important!)
Multivariate chain rule example
◮ z = x + y
◮ x = t^2, y = exp(t)
◮ dz/dt = (∂z/∂y)(dy/dt) + (∂z/∂x)(dx/dt)
◮ ∂z/∂y = 1
◮ ∂z/∂x = 1
◮ dy/dt = exp(t)
◮ dx/dt = 2t
◮ Finally, dz/dt = exp(t) + 2t
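A small finite-difference sanity check of this result; the test point and step size are arbitrary choices for the example.

import numpy as np

# Central finite difference check of dz/dt = exp(t) + 2t for z = x + y,
# with x = t^2 and y = exp(t).
def z_of_t(t):
    x, y = t**2, np.exp(t)
    return x + y

t, eps = 1.3, 1e-6
numeric = (z_of_t(t + eps) - z_of_t(t - eps)) / (2 * eps)
analytic = np.exp(t) + 2 * t
print(np.isclose(numeric, analytic))                 # True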
Multivariate chain rule in summary
What is the derivative of h_3 with respect to W_hh (∂h_3/∂W_hh)?
◮ Enumerate all paths from W_hh to h_3 and sum the derivatives along them
◮ An arrow between two nodes stands for the derivative between the two nodes
  1. W_hh → h_3 directly
  2. W_hh → h_2 → h_3
  3. W_hh → h_1 → h_2 → h_3
◮ h_2 → h_3 means ∂h_3/∂h_2
What does it mean?
What is the derivative of h_3 with respect to W_hh?
◮ ∂h_3/∂W_hh is the sum of the following:
  1. ∂h_3/∂W_hh (the direct term)
  2. (∂h_3/∂h_2)(∂h_2/∂W_hh)
  3. (∂h_3/∂h_2)(∂h_2/∂h_1)(∂h_1/∂W_hh)
RNN expansion
Equations of basic RNN
◮ h_t = tanh(W_xh x_t + W_hh h_{t−1} + b_h)
◮ ŷ_t = W_hy h_t + b_y
◮ p_t = softmax(ŷ_t)
◮ ∂L/∂W_hh = Σ_{t=1}^{T} ∂L_t/∂W_hh (sum the loss over each time step t)
◮ ∂L_t/∂W_hh = (∂L_t/∂ŷ_t)(∂ŷ_t/∂h_t) Σ_{k=1}^{t} (∂h_t/∂h_k)(∂h_k/∂W_hh) (each h_t is computed using W_hh; apply the multivariate chain rule)
◮ ∂h_t/∂h_k = Π_{j=k+1}^{t} ∂h_j/∂h_{j−1} (each h_j is immediately dependent on h_{j−1}; apply the single-variable chain rule)
◮ ∂h_j/∂h_{j−1} = W_hh (1 − (h_j)^2) (the derivative of tanh(x) is 1 − (tanh(x))^2)
◮ Repeated multiplication by W_hh when t ≫ k causes a vanishing or exploding gradient
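A minimal scalar sketch of the repeated factor ∂h_j/∂h_{j−1} = W_hh(1 − h_j^2); the weights and sequence length are made up for illustration.

import numpy as np

# Scalar (1 x 1) sketch: each factor is dh_j/dh_{j-1} = W_hh * (1 - h_j^2).
# Since |1 - tanh(.)^2| <= 1, the product over many steps shrinks quickly
# unless W_hh is large, which is the vanishing gradient in action.
rng = np.random.default_rng(3)
w_xh, w_hh = 0.5, 0.9
h, prod = 0.0, 1.0
for _ in range(30):
    h = np.tanh(w_xh * rng.normal() + w_hh * h)      # scalar RNN update
    prod *= w_hh * (1.0 - h**2)                      # one chain-rule factor
print(f"product of dh_j/dh_(j-1) over 30 steps: {prod:.2e}")   # very small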
GRU expansion
What does ∂h_j/∂h_{j−1} look like in the case of the GRU?
◮ It is not as simple as in the case of the RNN. Why?
◮ The GRU has two gates that play a role in the computation of the hidden state h_t
◮ How does it look?
◮ [z((1 − z)W^(z)(1 − h) + 1)] + [r(1 − z)(1 − h^2)U(1 + h_{t−1}(1 − r)W^(r))]
Analysis of the partial derivative
[z((1 − z)W^(z)(1 − h) + 1)] + [r(1 − z)(1 − h^2)U(1 + h_{t−1}(1 − r)W^(r))]
◮ When z → 0: the first term vanishes; the derivative depends only on the reset-gate term.
◮ When z → 1: the expression tends to 1. Repeated multiplication by 1 does not affect the gradient. (Good for long-term dependencies)
◮ When r → 0: the second term vanishes. (Good for short-term dependencies)
◮ There is no guarantee that the gradient is always going to be multiplied by a weight matrix.
Conclusion
When r = 1 and z = 0, we get back the RNN derivative, leading to gradient problems. What is the chance of that happening? No need to worry: it would have to happen at every time step, which is very unlikely. Clearly, the GRU addresses the vanishing gradient problem.
Long Short-Term Memory Network
◮ More complex than the GRU: has one extra gate and one extra internal state.
◮ Input gate: i_t = σ(W^(i) x_t + U^(i) h_{t−1})
◮ Forget gate: f_t = σ(W^(f) x_t + U^(f) h_{t−1})
◮ Output gate: o_t = σ(W^(o) x_t + U^(o) h_{t−1})
◮ Candidate gate: g_t = tanh(W^(g) x_t + U^(g) h_{t−1})
◮ Internal memory content: c_t = f_t ◦ c_{t−1} + i_t ◦ g_t
◮ Final hidden state: h_t = o_t ◦ tanh(c_t)
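A minimal NumPy sketch of one LSTM step following these equations; the per-gate weight dictionaries and shapes are made up, and biases are omitted as on the slide.

import numpy as np

def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))

def lstm_step(x_t, h_prev, c_prev, W, U):
    # One LSTM step; W and U are dicts of hypothetical per-gate weight matrices.
    i_t = sigmoid(W['i'] @ x_t + U['i'] @ h_prev)    # input gate
    f_t = sigmoid(W['f'] @ x_t + U['f'] @ h_prev)    # forget gate
    o_t = sigmoid(W['o'] @ x_t + U['o'] @ h_prev)    # output gate
    g_t = np.tanh(W['g'] @ x_t + U['g'] @ h_prev)    # candidate memory
    c_t = f_t * c_prev + i_t * g_t                   # internal memory content
    h_t = o_t * np.tanh(c_t)                         # final hidden state
    return h_t, c_t

# Toy shapes (hypothetical): 4-dim input, 3-dim hidden and cell state.
rng = np.random.default_rng(0)
d_x, d_h = 4, 3
W = {k: rng.normal(size=(d_h, d_x)) for k in 'ifog'}
U = {k: rng.normal(size=(d_h, d_h)) for k in 'ifog'}
x_t, h_prev, c_prev = rng.normal(size=d_x), np.zeros(d_h), np.zeros(d_h)
h_t, c_t = lstm_step(x_t, h_prev, c_prev, W, U)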
Visualizations
◮ http://colah.github.io/posts/2015-08-Understanding-LSTMs/
◮ http://deeplearning.net/tutorial/lstm.html
◮ Chung et al. 2014: https://arxiv.org/abs/1412.3555
◮ Visualization by Tim Rocktäschel
Comparison between LSTM and GRU
LSTM
◮ i_t = σ(W^(i) x_t + U^(i) h_{t−1})
◮ f_t = σ(W^(f) x_t + U^(f) h_{t−1})
◮ o_t = σ(W^(o) x_t + U^(o) h_{t−1})
◮ g_t = tanh(W^(g) x_t + U^(g) h_{t−1})
◮ c_t = f_t ◦ c_{t−1} + i_t ◦ g_t
◮ h_t = o_t ◦ tanh(c_t)

GRU
◮ z_t = σ(W^(z) x_t + U^(z) h_{t−1})
◮ r_t = σ(W^(r) x_t + U^(r) h_{t−1})
◮ h̃_t = tanh(W x_t + r_t ◦ U h_{t−1})
◮ h_t = z_t ◦ h_{t−1} + (1 − z_t) ◦ h̃_t

Differences
◮ The forget gate is like the update gate
◮ The input gate is like the reset gate
◮ The input gate does not act directly on the previous hidden state
◮ The GRU has fewer parameters
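A back-of-the-envelope sketch of the "fewer parameters" point, with made-up input and hidden dimensions and biases ignored as on the slides:

# Rough parameter counts: the LSTM has four W/U pairs, the GRU only three.
d_x, d_h = 300, 512                       # hypothetical input and hidden sizes
per_pair = d_h * d_x + d_h * d_h          # one W (d_h x d_x) plus one U (d_h x d_h)
lstm_params = 4 * per_pair                # gates i, f, o and candidate g
gru_params = 3 * per_pair                 # gates z, r and candidate h-tilde
print(lstm_params, gru_params)            # 1662976 1247232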
Is RNN a special case of LSTM?
◮ Not very straightforward.
◮ Input gate is all 1s
◮ Forget gate is all 0s
◮ Output gate is all 1s
◮ Remove the final tanh covering the internal memory