Taking Advantage of Low Precision to Accelerate Training and Inference Using PyTorch
Presented by: Myle Ott and Sergey Edunov, Facebook AI Research (FAIR)
Talk ID: S9832
Overview: mixed precision training in PyTorch: 3-4x speedups.
Slide credit: Nvidia
* Some operations should still happen in FP32: large reductions, softmax, and the loss computation.
[Diagram: the mixed precision training loop]
- Forward pass: FP16 weights produce an FP16 loss.
- Backward pass: backprop produces FP16 gradients.
- Copy the FP16 gradients into FP32 master gradients.
- Apply the update to FP32 master weights.
- Copy the updated master weights back into the FP16 weights for the next iteration.

This adds overhead! It's only worth it because of the Tensor Cores. Don't use mixed precision without Tensor Cores!
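The FP32 master copy of the weights matters because a small update can vanish entirely in FP16 arithmetic. A minimal numerical sketch, using NumPy's float16/float32 as stand-ins for GPU half and single precision:

```python
import numpy as np

# A weight near 1.0 and a small gradient update, as is common late in training.
w = 1.0
update = 1e-4  # smaller than half the FP16 spacing (ulp) at 1.0, which is ~9.8e-4

# Updating the FP16 weight directly: the update is rounded away entirely.
w_fp16 = np.float16(w) + np.float16(update)
print(w_fp16 == np.float16(w))  # True: the FP16 weight never moves

# Updating an FP32 master weight: the update is retained...
w_fp32 = np.float32(w) + np.float32(update)
print(w_fp32 == np.float32(w))  # False: the master copy accumulates it

# ...and many accumulated small steps do eventually show up once the
# master weights are copied back to FP16 for the next forward pass.
for _ in range(10):
    w_fp32 += np.float32(update)
print(np.float16(w_fp32))
```

This is why the update is applied to FP32 master weights and only then copied back down to FP16.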
Gradient underflow cannot be detected: small values simply flush to zero. But if we scale the loss up by a factor K, then by the chain rule of derivatives the gradients are also K times bigger, which pushes small values back into FP16's representable range.
Gradient overflow, by contrast, is detectable: overflowed gradients become Inf. If an overflow is detected, scale the loss down.
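This policy can be sketched as a small dynamic loss scaler, in pure Python with illustrative names and constants (apex and PyTorch's built-in scalers implement the same idea): halve the scale and skip the batch whenever an Inf/NaN gradient appears, and cautiously double it again after a run of clean steps.

```python
import math

class DynamicLossScaler:
    """Sketch of dynamic loss scaling: scale down on overflow, grow back slowly."""

    def __init__(self, init_scale=2.0**15, growth_interval=2000, factor=2.0):
        self.scale = init_scale
        self.factor = factor
        self.growth_interval = growth_interval
        self._good_steps = 0

    def scale_loss(self, loss):
        return loss * self.scale

    def unscale(self, grad):
        return grad / self.scale

    def step(self, grads):
        """Inspect gradients; return True if the batch should be applied."""
        overflow = any(math.isinf(g) or math.isnan(g) for g in grads)
        if overflow:
            self.scale /= self.factor   # overflow detected: scale the loss down
            self._good_steps = 0
            return False                # throw away the batch
        self._good_steps += 1
        if self._good_steps >= self.growth_interval:
            self.scale *= self.factor   # enough clean steps: try a bigger scale
            self._good_steps = 0
        return True

scaler = DynamicLossScaler(init_scale=8.0, growth_interval=2)
print(scaler.step([0.1, float("inf")]), scaler.scale)  # False 4.0 (overflow, skip)
print(scaler.step([0.1, 0.2]), scaler.scale)           # True 4.0
print(scaler.step([0.3, 0.4]), scaler.scale)           # True 8.0 (2 clean steps, grow)
```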
[Diagram: the training loop with loss scaling]
- Forward pass: FP16 weights produce an FP16 loss.
- Loss scaling: multiply by the scale factor to get a scaled FP16 loss.
- Backward pass: produces scaled FP16 gradients.
- Copy into scaled FP32 gradients. If the gradients overflow (Inf), throw away the batch.
- Remove the scale to recover unscaled FP32 gradients.
- Apply the update to FP32 master weights, then copy back into the FP16 weights.
[Diagram: the full recipe, with the loss computed in FP32]
- Forward pass: FP16 weights, but the loss is computed in FP32.
- Loss scaling: multiply to get a scaled FP32 loss.
- Backward pass: produces scaled FP16 gradients.
- Copy into scaled FP32 gradients. If the gradients overflow (Inf), throw away the batch.
- Remove the scale to recover unscaled FP32 gradients.
- Distributed training: gradient accumulation / all-reduce runs on these FP32 gradients.
- Apply the update to FP32 master weights, then copy back into the FP16 weights.
All of this is automated by NVIDIA's apex library:

    from apex import amp

    model, optim = amp.initialize(model, optim, opt_level="O1")
    (…)
    with amp.scale_loss(loss, optim) as scaled_loss:
        scaled_loss.backward()
    # Compute softmax in FP32 for numerical stability, then cast back to the input dtype:
    x = torch.nn.functional.softmax(x, dim=-1, dtype=torch.float32).type_as(x)
Teng Li, Ailing Zhang, Shubho Sengupta, Myle Ott, Sergey Edunov, David Grangier, Michael Auli
[Chart] Time in minutes to train the "Transformer" translation model on Volta V100 GPUs (WMT En-De):

    Original     1,429
    + 16-bit       495
    + cumul        447
    + 2x lr        311
    16 nodes        37
    + overlap       32

3x faster (wall time) using the same hardware, model architecture and batch size!

[Diagram: gradient accumulation] With "sync after 1", every batch ends in a gradient sync and GPUs sit idle waiting for stragglers; with "sync after 2", gradients are accumulated locally over two batches, halving the number of syncs and the idle time.

[Diagram: overlapped sync] Instead of waiting for the backward pass to finish before syncing gradients, overlap the all-reduce with the backward pass so communication is hidden behind computation.
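The "cumul" trick (accumulate gradients locally over N batches, then sync once) can be sketched in plain Python; `all_reduce` and `compute_grad` here are hypothetical stand-ins for the real collective and for forward+backward:

```python
def train_with_accumulation(batches, cumul, all_reduce):
    """Accumulate local gradients over `cumul` batches, then sync once.

    Sketch only: gradients are plain floats instead of tensors.
    """
    def compute_grad(batch):
        return sum(batch) / len(batch)  # toy "gradient"

    synced = []
    accum = 0.0
    for i, batch in enumerate(batches, start=1):
        accum += compute_grad(batch)          # local accumulation, no communication
        if i % cumul == 0:
            synced.append(all_reduce(accum))  # one all-reduce per `cumul` batches
            accum = 0.0
    return synced

# With cumul=2 we communicate half as often as syncing after every batch.
calls = []
def fake_all_reduce(g):
    calls.append(g)
    return g  # single-worker stand-in

out = train_with_accumulation([[1.0, 3.0], [2.0, 2.0], [4.0, 0.0], [1.0, 1.0]],
                              cumul=2, all_reduce=fake_all_reduce)
print(out)         # [4.0, 3.0]
print(len(calls))  # 2 syncs for 4 batches
```

Fewer syncs also means each sync moves the same amount of data, so the communication-to-computation ratio improves directly with `cumul`.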
Myle Ott, Sergey Edunov, David Grangier, Michael Auli
Back-translation (Bojar & Tamchyna, 2011; Sennrich et al., 2016):
1. Train an intermediate German→English model on the bilingual data.
2. Use it to translate monolingual German into generated (synthetic) English.
3. Train the final English→German model on the bilingual pairs plus the (generated English, monolingual German) pairs.
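As a toy sketch of this pipeline (the data and the "model" below are trivial stand-ins, not real translation systems): an intermediate target→source model converts monolingual target-side text into synthetic source sentences, which are paired with the originals and added to the training data.

```python
# Toy back-translation sketch. In the talk, the intermediate and final models
# are full German->English and English->German Transformers.

bilingual = [("ein Haus", "a house"), ("ein Auto", "a car")]  # (de, en) pairs
monolingual_de = ["ein Baum", "ein Hund", "ein Buch"]         # target-side monolingual

def intermediate_de_en(sentence):
    """Stand-in for the intermediate German->English model."""
    vocab = {"ein": "a", "Haus": "house", "Auto": "car",
             "Baum": "tree", "Hund": "dog", "Buch": "book"}
    return " ".join(vocab.get(tok, tok) for tok in sentence.split())

# Back-translate monolingual German into generated (synthetic) English...
synthetic = [(de, intermediate_de_en(de)) for de in monolingual_de]

# ...then train the final English->German model on bilingual + synthetic pairs.
final_training_data = [(en, de) for (de, en) in bilingual + synthetic]
print(len(final_training_data))  # 5 pairs: 2 real + 3 back-translated
```

The synthetic source side may be noisy, but the target side is real monolingual text, which is what makes the augmentation effective.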
[Chart] BLEU (accuracy) on WMT'14 English-German, comparing systems:
- fairseq + sampled BT (this work; only benchmark bilingual + monolingual data)
- DeepL (2017; high quality, non-benchmark data)
- SAtt + RPR (Google, 2018)
- WTransformer (Salesforce, 2017)
- Transformer (Google, 2017)
- ConvS2S (2017)
- GNMT (RNN, 2016)
- Phrase-based (2014)

Model trains in 22.5h.
Ranked #1 in the human evaluation of the WMT'18 English-German translation task