Improving Transformer Optimization Through Better Initialization

SLIDE 1

Improving Transformer Optimization Through Better Initialization

Xiao Shi Huang*, Felipe Perez*, Jimmy Ba, Maksims Volkovs

SLIDE 2

Agenda

  • Transformer in Detail
  • Removing Warmup: T-Fixup
  • Experimental Results
  • Summary

SLIDE 3

Transformer

  • Encoder-Decoder architecture
  • Residual backbone
  • Multi-Headed Attention in ResBlock
  • LayerNorm after every residual block (post-LN; see the sketch below)
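The post-LN arrangement (LayerNorm applied after each residual addition, as in the bullets above) can be illustrated with a minimal PyTorch sketch; the module and parameter names here are my own, not the paper's code.

```python
import torch
import torch.nn as nn

class PostLNAttentionBlock(nn.Module):
    """Minimal sketch of one post-LN residual block: multi-headed
    self-attention on the residual backbone, with LayerNorm applied
    after the residual addition."""

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (seq_len, batch, d_model); add the sub-layer output back onto
        # the backbone, then normalize (post-LN)
        attn_out, _ = self.attn(x, x, x)
        return self.norm(x + attn_out)
```

The feed-forward sub-layer follows the same pattern: residual addition first, LayerNorm after.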
SLIDE 4

Training

  • Adam optimizer
  • Inverse square root learning rate decay
  • Learning rate warmup (see the schedule sketch below)
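The schedule referenced above is the standard one from the original Transformer recipe: linear warmup followed by inverse square-root decay. A generic sketch (the default values 512 and 4000 are the commonly used ones, not taken from this deck):

```python
def transformer_lr(step: int, d_model: int = 512, warmup_steps: int = 4000) -> float:
    """Learning rate at a given optimizer step: linear warmup for
    `warmup_steps` steps, then inverse square-root decay."""
    step = max(step, 1)  # avoid dividing by zero at step 0
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

# The rate ramps up linearly, peaks at `warmup_steps`, then decays as step ** -0.5.
rates = [transformer_lr(s) for s in (100, 4000, 100_000)]
```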
SLIDE 5

Necessity of Warmup

  • Gradient histogram
SLIDE 6

Necessity of Warmup

Error signal decreases with a large input

  • LayerNorm in Backpropagation [2] (bound restated below)
  • x: input to Layer Normalization
  • d: dimension of x
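The property from [2] behind this slide is that the gradient flowing through LayerNorm is inversely proportional to the norm of its input, so a large-magnitude input shrinks the error signal. In the notation above:

```latex
\left\| \frac{\partial\, \mathrm{LN}(\mathbf{x})}{\partial \mathbf{x}} \right\|
  = \mathcal{O}\!\left( \frac{\sqrt{d}}{\|\mathbf{x}\|} \right)
```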
SLIDE 7

Necessity of Warmup

  • LayerNorm in Backpropagation [2]
SLIDE 8

Removing Warmup

  • Without LayerNorm:
  • Magnitude on backbone grows with layer depth

SLIDE 9

Removing Warmup

  • Without LayerNorm:
  • Magnitude on backbone grows with layer depth
  • With LayerNorm:
  • Reset to unit magnitude (see the sketch below)
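A small numerical illustration of these two bullets (my own sketch, not from the paper): summing independent unit-variance residual branches makes the backbone norm grow roughly like the square root of the depth, while a LayerNorm after every addition resets it to a constant scale.

```python
import torch

def backbone_norm(num_layers: int, d_model: int = 512, use_layernorm: bool = False) -> float:
    """Norm of the residual backbone after `num_layers` blocks, where each
    block adds an independent zero-mean, unit-variance branch output."""
    ln = torch.nn.LayerNorm(d_model)
    x = torch.randn(d_model)
    for _ in range(num_layers):
        x = x + torch.randn(d_model)  # residual branch output
        if use_layernorm:
            x = ln(x)                 # post-LN: reset to unit per-feature scale
    return x.norm().item()

# Without LayerNorm the norm grows ~ sqrt(num_layers);
# with post-LN it stays ~ sqrt(d_model).
print(backbone_norm(64), backbone_norm(64, use_layernorm=True))
```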
SLIDE 10

Removing Warmup

  • Without LayerNorm:
  • Magnitude on backbone grows with layer depth
  • With LayerNorm:
  • Reset to unit magnitude
  • Parameter-Controlled Growth
SLIDE 11

Removing Warmup

Goal: control the total change in the output of the Transformer after a gradient update.
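A hedged restatement of this goal in Fixup-style notation (my formulation, not copied from the slide): after one gradient step with learning rate η, the change in the model output should be of order η, independent of the number of layers N.

```latex
\|\Delta f\| \;=\; \bigl\| f(\mathbf{x};\, \theta - \eta \nabla_{\theta}\mathcal{L}) - f(\mathbf{x};\, \theta) \bigr\| \;=\; \Theta(\eta)
```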

Control output change in residual blocks:

  • Feedforward blocks: as in Fixup [3]
  • Theorem: for attention blocks, this is controlled when the condition shown on the slide holds

SLIDE 12

Removing Warmup

  • T-Fixup Initialization (sketched in code below)
  • Xavier initialization for all projection matrices
  • Gaussian initialization for embedding layers
  • Scale embedding layers and decoder parameters by (9N)^(-1/4)
  • Scale encoder parameters by 0.67·N^(-1/4)
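A minimal PyTorch sketch of this scheme, assuming N is the number of layers and d is the embedding dimension; the function names and the embedding standard deviation are placeholders of this sketch, and the exact constants should be taken from the paper.

```python
import torch

def t_fixup_projection(d_in: int, d_out: int, num_layers: int, in_decoder: bool) -> torch.Tensor:
    """Xavier-initialized projection matrix, rescaled as on this slide:
    decoder-side parameters by (9N)^(-1/4), encoder-side by 0.67 * N^(-1/4)."""
    w = torch.empty(d_out, d_in)
    torch.nn.init.xavier_uniform_(w)
    scale = (9 * num_layers) ** -0.25 if in_decoder else 0.67 * num_layers ** -0.25
    return w * scale

def t_fixup_embedding(vocab_size: int, d_model: int, num_layers: int) -> torch.Tensor:
    """Gaussian-initialized embedding table, rescaled by (9N)^(-1/4).
    The standard deviation d_model ** -0.5 is an assumption of this sketch;
    use the value given in the paper."""
    emb = torch.randn(vocab_size, d_model) * d_model ** -0.5
    return emb * (9 * num_layers) ** -0.25
```

With these rescaled initializations the model is trained without learning rate warmup and without LayerNorm.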

SLIDE 13

Experimental Results

SLIDE 14

T-Fixup on Standard Transformer

  • T-Fixup achieves consistently higher performance with less structure (no warmup and no layer normalization)
SLIDE 15

T-Fixup on Standard Transformer: gradients

  • Gradient and Adam update magnitudes
  • Vanilla Transformer without warmup: vanishing gradients
  • T-Fixup without warmup: stable error signal throughout training

SLIDE 16

T-Fixup on Deeper Transformer

  • T-Fixup outperforms all competitive models with equal or fewer layers
SLIDE 17

T-Fixup on Ultra-Deep Transformer

  • IWSLT’14 De-En dataset; Transformer with 64-dim embeddings, 128-dim MLP hidden layer, 2 attention heads
SLIDE 18

T-Fixup on Large Batch Training

  • WMT’17 En-De dataset, WMT-base Transformer
SLIDE 19

Summary

SLIDE 20

Summary

  • Requirement for learning rate warmup: Adam + LayerNorm
  • T-Fixup Initialization
  • Superior performance on NMT
  • Ultra-Deep Transformer
  • Future Work
SLIDE 21

Acknowledgement

SLIDE 22

Questions?

Thank you!

Contact: Xiao Shi (Gary) Huang gary@layer6.ai

SLIDE 23

References

[1] Liu, L. et al. On the variance of the adaptive learning rate and beyond. In ICLR, 2020.
[2] Xiong, R. et al. On layer normalization in the Transformer architecture. In ICML, 2020.
[3] Zhang, H. et al. Fixup initialization: Residual learning without normalization. In ICLR, 2019.
[4] Wang, Q. et al. Learning deep Transformer models for machine translation. In ACL, 2019.
[5] Zhang, B. et al. Improving deep Transformer with depth-scaled initialization and merged attention. In EMNLP, 2019.
[6] Xu, H. et al. Why deep Transformers are difficult to converge? From computation order to Lipschitz restricted parameter initialization. arXiv preprint.