
Soft Attention Models in Deep Networks

Praveen Krishnan

CVIT, IIIT Hyderabad

June 21, 2017

Everyone knows what attention is. It is the taking possession of the mind, in clear and vivid form, of one out of what seem several simultaneously possible objects or trains of thought. Focalization, concentration of consciousness are of its essence... — William James, 1890


Outline

◮ Motivation
◮ Primer on prediction using RNNs
◮ Handwriting prediction [Read]
◮ Handwriting synthesis [Write]
◮ Deep Recurrent Attentive Writer [DRAW]


Motivation

Few questions to begin

◮ How do we perceive an image and start interpreting it?
◮ Given knowledge of two languages, how do we manually translate between their sentences?
◮ . . .

Figure 1: Left: Source Wikipedia, Right: Bahdanau et al. ICLR’15


Motivation

Why Attention?

◮ You don’t see every pixel!
◮ You remove the clutter and process the salient parts.
◮ You process one step at a time and aggregate the information in your memory.


Attention Mechanism

Definition

In cognitive neuroscience, attention is viewed as a neural system for the selection of information, similar in many ways to the visual, auditory, or motor systems [Posner 1994].

Visual Attention Components [Tsotsos, et al. 1995]

◮ Selection of a region of interest in the visual field.
◮ Selection of the feature dimensions and values of interest.
◮ Control of information flow through the visual cortex.
◮ Shifting from one selected region to the next in time.


Attention Mechanism

Attention in Neural Networks

An architecture-level feature of neural networks that allows the model to attend to different parts of an image sequentially and aggregate information over time.

Types of attention

◮ Hard: pick a discrete location to attend to. However, the resulting model is non-differentiable.
◮ Soft: spread the attention weights over the entire image.

In this talk, we limit ourselves to soft attention models, which are differentiable and can be trained with standard back-propagation. Before we dig deeper, let’s brush up on RNNs.


Recurrent Neural Network

Figure 2: An unrolled recurrent neural network¹

RNNs

◮ A neural network with loops.
◮ Captures temporal information.
◮ The issue of long-term dependencies is addressed by gated units (LSTMs, GRUs, . . .).
◮ A wide range of applications: image captioning, speech processing, language modelling, . . .

¹colah’s blog, Understanding LSTM Networks.


LSTM

Long Short-term Memory Cell

$$\begin{aligned}
f_t &= \sigma(W_{xf} x_t + W_{hf} h_{t-1} + W_{cf} c_{t-1} + b_f) \\
i_t &= \sigma(W_{xi} x_t + W_{hi} h_{t-1} + W_{ci} c_{t-1} + b_i) \\
c_t &= f_t c_{t-1} + i_t \tanh(W_{xc} x_t + W_{hc} h_{t-1} + b_c) \\
o_t &= \sigma(W_{xo} x_t + W_{ho} h_{t-1} + W_{co} c_t + b_o) \\
h_t &= o_t \tanh(c_t)
\end{aligned}$$

◮ Uses memory cells to store information.
◮ The above version uses peephole connections.
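
To make the update rule concrete, here is a minimal NumPy sketch of a single peephole-LSTM step following the equations above; the dictionary keys mirror the weight subscripts, and the peephole weights $W_{cf}, W_{ci}, W_{co}$ are assumed diagonal (applied elementwise), as is standard for this variant.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def lstm_step(x_t, h_prev, c_prev, W, b):
    """One peephole-LSTM step; W['xf'] plays the role of W_xf, etc.
    Peephole weights W['cf'], W['ci'], W['co'] are vectors (diagonal)."""
    f_t = sigmoid(W['xf'] @ x_t + W['hf'] @ h_prev + W['cf'] * c_prev + b['f'])
    i_t = sigmoid(W['xi'] @ x_t + W['hi'] @ h_prev + W['ci'] * c_prev + b['i'])
    c_t = f_t * c_prev + i_t * np.tanh(W['xc'] @ x_t + W['hc'] @ h_prev + b['c'])
    o_t = sigmoid(W['xo'] @ x_t + W['ho'] @ h_prev + W['co'] * c_t + b['o'])
    h_t = o_t * np.tanh(c_t)
    return h_t, c_t
```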

Let us now see how we can use them for prediction.


LSTMs as Prediction Network

Given an input sample $x_t$, predict the next sample $x_{t+1}$.

Prediction Problem

Learn a distribution $P(x_{t+1} \mid y_t)$, where $x = (x_1, \ldots, x_T)$ is the input sequence, passed through $N$ hidden layers $h^n = (h^n_1, \ldots, h^n_T)$ to produce an output sequence $y = (y_1, \ldots, y_T)$.


LSTMs as Prediction Network

Choice of predictive distribution (Density Modeling)

The probability the network assigns to the input sequence $x$ is

$$P(x) = \prod_{t=1}^{T} P(x_{t+1} \mid y_t)$$

and the sequence loss is

$$\mathcal{L}(x) = -\sum_{t=1}^{T} \log P(x_{t+1} \mid y_t)$$

Training is done through back-propagation through time. For example, in text prediction one can parameterize the output distribution using a softmax function.
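
As a concrete illustration of the softmax case, here is a minimal NumPy sketch of the sequence loss for next-symbol prediction; the names and shapes are my own, not from the slides.

```python
import numpy as np

def sequence_loss(logits, targets):
    """logits: (T, V) raw network outputs y_1..y_T over a vocabulary of size V;
    targets: (T,) integer indices of the observed next symbols x_2..x_{T+1}.
    Returns L(x) = -sum_t log P(x_{t+1} | y_t) under a softmax output layer."""
    z = logits - logits.max(axis=1, keepdims=True)   # numerically stable log-softmax
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(targets)), targets].sum()
```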


Handwriting Prediction

Problem

Given online handwriting data (recorded pen-tip locations $x_1, x_2$) at time step $t$, predict the location of the pen at time step $t + 1$, along with an end-of-stroke variable.

Figure 3: Left: Samples of online handwriting data from multiple authors. Right: Demo²

²Carter et al., Experiments in Handwriting with a Neural Network, Distill, 2016.


Handwriting Prediction

Mixture Density Outputs [Graves arxiv’13]

A mixture of bivariate Gaussians is used to predict $(x_1, x_2)$, while a Bernoulli distribution is used for $x_3$:

$$x_t \in \mathbb{R} \times \mathbb{R} \times \{0, 1\}$$

$$y_t = \left(e_t, \{\pi^j_t, \mu^j_t, \sigma^j_t, \rho^j_t\}_{j=1}^{M}\right)$$

◮ $e_t \in (0, 1)$ is the end-of-stroke probability,
◮ $\pi^j \in (0, 1)$ are the mixture weights,
◮ $\mu^j \in \mathbb{R}^2$ are the mean vectors,
◮ $\sigma^j > 0$ are the standard deviations, and
◮ $\rho^j \in (-1, 1)$ are the correlations.

Note that $x_1, x_2$ are now offsets from the previous pen location, and the parameters above are obtained by normalizing the raw network outputs.
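
A sketch of this normalization step; the layout of the raw output vector (one end-of-stroke logit followed by six values per component) is an assumption for illustration, not Graves’ exact ordering.

```python
import numpy as np

def mdn_params(y_hat, M):
    """Map a raw output vector of size 1 + 6*M to valid mixture parameters."""
    e = 1.0 / (1.0 + np.exp(y_hat[0]))      # end-of-stroke probability in (0, 1)
    p = y_hat[1:].reshape(M, 6)
    pi = np.exp(p[:, 0] - p[:, 0].max())
    pi /= pi.sum()                          # softmax: mixture weights sum to 1
    mu = p[:, 1:3]                          # means, unconstrained in R^2
    sigma = np.exp(p[:, 3:5])               # standard deviations > 0
    rho = np.tanh(p[:, 5])                  # correlations in (-1, 1)
    return e, pi, mu, sigma, rho
```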


Handwriting Prediction

Mixture Density Outputs [Graves arxiv’13]

The probability of the next input is

$$P(x_{t+1} \mid y_t) = \left( \sum_{j=1}^{M} \pi^j_t \, \mathcal{N}(x_{t+1} \mid \mu^j_t, \sigma^j_t, \rho^j_t) \right) \times \begin{cases} e_t & \text{if } (x_{t+1})_3 = 1 \\ 1 - e_t & \text{otherwise} \end{cases}$$

As shown earlier, the sequence loss is

$$\mathcal{L}(x) = -\sum_{t=1}^{T} \left[ \log \left( \sum_{j=1}^{M} \pi^j_t \, \mathcal{N}(x_{t+1} \mid \mu^j_t, \sigma^j_t, \rho^j_t) \right) + \begin{cases} \log e_t & \text{if } (x_{t+1})_3 = 1 \\ \log(1 - e_t) & \text{otherwise} \end{cases} \right]$$
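
For completeness, the bivariate Gaussian density $\mathcal{N}(x \mid \mu, \sigma, \rho)$ appearing in the mixture has the standard closed form, sketched below.

```python
import numpy as np

def bivariate_normal(x, mu, sigma, rho):
    """Density of a 2D Gaussian with means mu (2,), std devs sigma (2,),
    and correlation rho (scalar), in the standard closed form."""
    z1 = (x[0] - mu[0]) / sigma[0]
    z2 = (x[1] - mu[1]) / sigma[1]
    z = z1**2 + z2**2 - 2.0 * rho * z1 * z2
    norm = 2.0 * np.pi * sigma[0] * sigma[1] * np.sqrt(1.0 - rho**2)
    return np.exp(-z / (2.0 * (1.0 - rho**2))) / norm
```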


Handwriting Prediction

Visualization

Figure 4: Heat map showing the mixture density outputs for handwriting prediction.


Handwriting Prediction

Demo

Available at: Link

Figure 5: Carter et al., Experiments in Handwriting with a Neural Network, Distill, 2016.


Handwriting synthesis [Graves arxiv’13]

HW Synthesis

Generation of handwriting conditioned on an input text.

Key Question

Ques: How do we resolve the alignment problem between two sequences of varying length?

Sol: Add “attention” as a soft window that is convolved with the input text and given as input to the prediction network, so the model learns to decide which character to write next.


Handwriting synthesis

The soft window $w_t$ into the text $c$ at timestep $t$ is defined as

$$\phi(t, u) = \sum_{k=1}^{K} \alpha^k_t \exp\left(-\beta^k_t (\kappa^k_t - u)^2\right)$$

$$w_t = \sum_{u=1}^{U} \phi(t, u) \, c_u$$

$\phi(t, u)$ acts as the window weight on $c_u$ (the one-hot encoding of the $u$-th character) at time $t$. The soft attention is modelled by a mixture of $K$ Gaussians, where $\kappa_t$ gives the location, $\beta_t$ the width, and $\alpha_t$ the importance of each Gaussian.
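
A NumPy sketch of the window computation, assuming the mixture parameters for step $t$ have already been produced by the network:

```python
import numpy as np

def soft_window(alpha, beta, kappa, c):
    """alpha, beta, kappa: (K,) mixture parameters at step t;
    c: (U, V) matrix of one-hot character encodings. Returns w_t of shape (V,)."""
    u = np.arange(1, c.shape[0] + 1)                 # character positions 1..U
    phi = (alpha[:, None]
           * np.exp(-beta[:, None] * (kappa[:, None] - u) ** 2)).sum(axis=0)
    return phi @ c                                   # w_t = sum_u phi(t, u) c_u
```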


Handwriting synthesis

Window Parameters

$$(\hat{\alpha}_t, \hat{\beta}_t, \hat{\kappa}_t) = W_{h^1 p} h^1_t + b_p$$

$$\alpha_t = \exp(\hat{\alpha}_t), \quad \beta_t = \exp(\hat{\beta}_t), \quad \kappa_t = \kappa_{t-1} + \exp(\hat{\kappa}_t)$$

Figure 6: Alignment between the text sequence and handwriting.
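
A short sketch of the parameter mapping: the exponentials keep $\alpha_t, \beta_t$ positive, and $\kappa_t$ accumulates, so the window can only move forward along the text.

```python
import numpy as np

def window_params(raw, kappa_prev):
    """raw: (3, K) rows (alpha_hat, beta_hat, kappa_hat) from the linear layer;
    kappa_prev: (K,) previous window locations."""
    alpha = np.exp(raw[0])                  # importance of each Gaussian
    beta = np.exp(raw[1])                   # width of each Gaussian
    kappa = kappa_prev + np.exp(raw[2])     # monotonically advancing location
    return alpha, beta, kappa
```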


Handwriting synthesis

Qualitative Results: Questions

◮ How is stochasticity induced in the generation of different samples?
◮ How do we decide when the network has finished writing the text?
◮ How do we control the quality of the writing?
◮ How do we generate handwriting in a particular style?


Handwriting synthesis

Biased Sampling vs. Primed Sampling


Deep Recurrent Attentive Writer (DRAW)

DRAW

Combines a spatial attention mechanism with a sequential variational auto-encoding framework for the iterative construction of complex images.

Figure 7: MNIST digits drawn using recurrent attention model.


Deep Recurrent Attentive Writer (DRAW)

Major contribution

◮ Progressive Refinement (Temporal): Suppose $C$ is the canvas on which the image is drawn. The joint distribution $P(C)$ can be split over latent intermediate canvases $C_1, C_2, \ldots, C_{T-1}$, with $C_T$ the observed variable:

$$P(C) = P(C_T \mid C_{T-1}) \, P(C_{T-1} \mid C_{T-2}) \cdots P(C_1 \mid C_0) \, P(C_0)$$

◮ Spatial Attention (Spatial): Drawing one part of the canvas at a time, which simplifies the drawing process by deciding “where to look” and “where to write”.

Figure 8: Recurrence relation


Deep Recurrent Attentive Writer (DRAW)

Figure 9: Left: Traditional VAE network, Right: DRAW network

◮ Encoder and decoder are both recurrent networks.
◮ The encoder sees the previous output of the decoder so it can tailor its current output, while the decoder’s outputs are successively added to the distribution from which the image is generated.
◮ A dynamic attention mechanism decides “where to read” and “where to write”.


Deep Recurrent Attentive Writer (DRAW)

Training

$$\begin{aligned}
\hat{x}_t &= x - \sigma(c_{t-1}) \\
r_t &= \mathrm{read}(x_t, \hat{x}_t, h^{dec}_{t-1}) \\
h^{enc}_t &= \mathrm{RNN}^{enc}(h^{enc}_{t-1}, [r_t, h^{dec}_{t-1}]) \\
z_t &\sim Q(Z_t \mid h^{enc}_t) \\
h^{dec}_t &= \mathrm{RNN}^{dec}(h^{dec}_{t-1}, z_t) \\
c_t &= c_{t-1} + \mathrm{write}(h^{dec}_t)
\end{aligned}$$

Here $\sigma(x) = \frac{1}{1 + e^{-x}}$ is the logistic sigmoid function, and the latent distribution is taken to be a diagonal Gaussian $\mathcal{N}(Z_t \mid \mu_t, \sigma_t)$, where

$$\mu_t = W(h^{enc}_t), \quad \sigma_t = \exp\left(W(h^{enc}_t)\right)$$
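
A schematic sketch of one training step; read, write, the two RNNs, and the linear maps are stand-ins for learned modules, and the latent sample uses the usual reparameterization trick.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def draw_step(x, c_prev, h_enc_prev, h_dec_prev, m, rng):
    """m is a dict of callables/matrices standing in for learned modules."""
    x_hat = x - sigmoid(c_prev)                          # error image
    r = m['read'](x, x_hat, h_dec_prev)                  # attended glimpse
    h_enc = m['rnn_enc'](h_enc_prev, np.concatenate([r, h_dec_prev]))
    mu = m['W_mu'] @ h_enc
    sigma = np.exp(m['W_sigma'] @ h_enc)                 # diagonal Gaussian Q
    z = mu + sigma * rng.standard_normal(mu.shape)       # z_t ~ Q(Z_t | h_enc_t)
    h_dec = m['rnn_dec'](h_dec_prev, z)
    c = c_prev + m['write'](h_dec)                       # additive canvas update
    return c, h_enc, h_dec, mu, sigma
```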


Deep Recurrent Attentive Writer (DRAW)

Loss Function

The target generative model is $D(X \mid c_T)$, where $c_T$ is the final canvas matrix. The reconstruction loss $\mathcal{L}^x$ is

$$\mathcal{L}^x = -\log D(x \mid c_T)$$

and the latent loss $\mathcal{L}^z$ is defined over the sequence of latent distributions:

$$\mathcal{L}^z = \sum_{t=1}^{T} \mathrm{KL}\left( Q(Z_t \mid h^{enc}_t) \,\|\, P(Z_t) \right)$$

If $P(Z_t)$ is assumed to be $\mathcal{N}(0, 1)$, there is a closed-form solution:

$$\mathcal{L}^z = \frac{1}{2} \sum_{t=1}^{T} \left( \mu_t^2 + \sigma_t^2 - \log \sigma_t^2 \right) - \frac{T}{2}$$
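
The closed form translates directly into code; the sketch below sums over latent dimensions and time steps (for a $d$-dimensional latent the constant becomes $Td/2$, which carries no gradient in any case).

```python
import numpy as np

def latent_loss(mus, sigmas):
    """mus, sigmas: lists of (d,) arrays, one pair per time step t = 1..T."""
    T = len(mus)
    kl = sum((mu**2 + s**2 - np.log(s**2)).sum() for mu, s in zip(mus, sigmas))
    return 0.5 * kl - T / 2.0    # constant term as written on the slide
```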


Deep Recurrent Attentive Writer (DRAW)

Generation and Testing

$$\begin{aligned}
\tilde{z}_t &\sim P(Z_t) \\
\tilde{h}^{dec}_t &= \mathrm{RNN}^{dec}(\tilde{h}^{dec}_{t-1}, \tilde{z}_t) \\
\tilde{c}_t &= \tilde{c}_{t-1} + \mathrm{write}(\tilde{h}^{dec}_t) \\
\tilde{x} &\sim D(X \mid \tilde{c}_T)
\end{aligned}$$
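
Generation only needs the decoder: sample latents from the prior and accumulate the canvas. A sketch, assuming $D$ is Bernoulli with means $\sigma(\tilde{c}_T)$ (a common choice for binary images):

```python
import numpy as np

def generate(T, rnn_dec, write, h0, c0, latent_dim, rng):
    """Run the decoder for T steps with prior samples; no encoder needed."""
    h_dec, c = h0, c0
    for _ in range(T):
        z = rng.standard_normal(latent_dim)   # z_t ~ P(Z_t) = N(0, I)
        h_dec = rnn_dec(h_dec, z)
        c = c + write(h_dec)                  # additive canvas update
    return 1.0 / (1.0 + np.exp(-c))           # Bernoulli means sigmoid(c_T)
```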


Deep Recurrent Attentive Writer (DRAW)

Selective Attention Model

◮ Similar to the differentiable attention mechanisms used in handwriting synthesis (Graves arxiv’13), the Neural Turing Machine (Graves et al. arxiv’14), and Neural Machine Translation (Bahdanau et al. ICLR’15).
◮ A 2D form of attention using an array of 2D Gaussian filters.


Deep Recurrent Attentive Writer (DRAW)

Defining Attention Parameters

Assume an image of size $A \times B$, on which we place an $N \times N$ grid of Gaussian filters, positioned at grid centre $(g_X, g_Y)$ with stride $\delta$, which controls the “zoom” of the patch. The mean location of filter $(i, j)$ is

$$\mu^i_X = g_X + (i - N/2 - 0.5)\,\delta$$

$$\mu^j_Y = g_Y + (j - N/2 - 0.5)\,\delta$$

In addition, we have $\sigma^2$, the variance of the filters, and a scalar intensity $\gamma$ that multiplies each filter response. The attention parameters are produced by a linear transformation:

$$(\tilde{g}_X, \tilde{g}_Y, \log \sigma^2, \log \tilde{\delta}, \log \gamma) = W(h^{dec})$$

The scaling of the parameters is chosen to ensure that the initial patch roughly covers the entire image.


Deep Recurrent Attentive Writer (DRAW)

Defining Filterbank Matrices

$$F_X[i, a] = \frac{1}{Z_X} \exp\left(-\frac{(a - \mu^i_X)^2}{2\sigma^2}\right), \quad F_Y[j, b] = \frac{1}{Z_Y} \exp\left(-\frac{(b - \mu^j_Y)^2}{2\sigma^2}\right)$$

Here $Z_X$ and $Z_Y$ are normalization constants that ensure $\sum_a F_X[i, a] = 1$ and $\sum_b F_Y[j, b] = 1$.

Read

$$\mathrm{read}(x, \hat{x}, h^{dec}_{t-1}) = \gamma \left[ F_Y x F_X^T, \; F_Y \hat{x} F_X^T \right]$$

Write

$$\mathrm{write}(h^{dec}_t) = \frac{1}{\hat{\gamma}} \, \hat{F}_Y^T w_t \hat{F}_X$$
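
A NumPy sketch of the filterbank construction and the read operation for a single-channel $B \times A$ image; the indexing conventions (rows indexed by $b$, columns by $a$) are my own.

```python
import numpy as np

def filterbank(g_x, g_y, sigma2, delta, N, A, B):
    """Build F_X of shape (N, A) and F_Y of shape (N, B)."""
    i = np.arange(1, N + 1)
    mu_x = g_x + (i - N / 2 - 0.5) * delta            # filter centres along x
    mu_y = g_y + (i - N / 2 - 0.5) * delta            # filter centres along y
    a = np.arange(1, A + 1)
    b = np.arange(1, B + 1)
    Fx = np.exp(-(a[None, :] - mu_x[:, None]) ** 2 / (2 * sigma2))
    Fy = np.exp(-(b[None, :] - mu_y[:, None]) ** 2 / (2 * sigma2))
    Fx /= np.maximum(Fx.sum(axis=1, keepdims=True), 1e-8)   # rows sum to 1
    Fy /= np.maximum(Fy.sum(axis=1, keepdims=True), 1e-8)
    return Fx, Fy

def read_attn(x, x_hat, Fx, Fy, gamma):
    """x, x_hat: (B, A) image and error image -> two gamma-scaled N x N glimpses."""
    return gamma * np.concatenate([Fy @ x @ Fx.T, Fy @ x_hat @ Fx.T])
```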


Deep Recurrent Attentive Writer (DRAW)

Some Qualitative Results

Figure 10: Cluttered MNIST classification.

Figure 11: SVHN digit generation.


References

Bob Fisher, CVOnline, http://homepages.inf.ed.ac.uk/rbf/CVonline/LOCAL_COPIES/SUN1/attn.htm

Eric Jang, Understanding and Implementing Deepmind’s DRAW Model, http://blog.evjang.com/2016/06/understanding-and-implementing.html


Thanks for your attention :)
