SLIDE 1

Learning Accurate Low-bit Deep Neural Networks with Stochastic Quantization

Yinpeng Dong1, Renkun Ni2, Jianguo Li3, Yurong Chen3, Jun Zhu1, Hang Su1

1Department of CST, Tsinghua University 2University of Virginia 3Intel Labs China

SLIDE 2

Deep Learning is Everywhere

Self-driving cars, AlphaGo, Dota, machine translation

SLIDE 3

Limitations

• More data + deeper models → more FLOPs + larger memory
• Computation intensive
• Memory intensive
• Hard to deploy on mobile devices

SLIDE 4

Low-bit DNNs for Efficient Inference

• High redundancy in DNNs;
• Quantize full-precision (32-bit) weights to binary (1-bit) or ternary (2-bit) weights;
• Replace multiplications (convolutions) by additions and subtractions;

SLIDE 5

Typical Low-bit DNNs

• BinaryConnect (stochastic binarization):

  $B_i = \begin{cases} +1 & \text{with probability } q = \sigma(X_i) \\ -1 & \text{with probability } 1 - q \end{cases}$

• BWN: minimize $\|X - \beta B\|$, with

  $B_i = \operatorname{sign}(X_i), \qquad \beta = \frac{1}{n}\sum_{i=1}^{n} |X_i|$

• TWN: minimize $\|X - \beta T\|$, with

  $T_i = \begin{cases} +1 & \text{if } X_i > \Delta \\ 0 & \text{if } |X_i| \le \Delta \\ -1 & \text{if } X_i < -\Delta \end{cases}, \qquad \beta = \frac{1}{|I_\Delta|}\sum_{i \in I_\Delta} |X_i|, \quad I_\Delta = \{i : |X_i| > \Delta\}, \quad \Delta = \frac{0.7}{n}\sum_{i=1}^{n} |X_i|$

SLIDE 6

Training & Inference of Low-bit DNN

• Let $X$ be the full-precision weights and $Q$ be the low-bit weights ($B$, $T$, $\alpha B$, $\alpha T$).
• Forward propagation: quantize $X$ to $Q$ and perform the convolution or multiplication.
• Backward propagation: use $Q$ to calculate the gradients.
• Parameter update: $X_{t+1} = X_t - \eta_t \frac{\partial L}{\partial Q_t}$
• Inference: only the low-bit weights $Q$ need to be kept.
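
A toy NumPy sketch of one such training step for a single linear layer with a squared-error loss (all names here are illustrative assumptions, not the paper's code): the forward and backward passes use the quantized weights Q, while the update is applied to the full-precision copy X:

```python
import numpy as np

def train_step(x_full, inputs, targets, lr=0.01):
    """One update of a linear layer trained with quantized (BWN-style) weights."""
    beta = np.mean(np.abs(x_full))
    q = beta * np.where(x_full >= 0, 1.0, -1.0)   # quantize X -> Q

    # Forward propagation with the low-bit weights Q
    preds = inputs @ q
    loss = 0.5 * np.mean((preds - targets) ** 2)

    # Backward propagation: gradient of the loss w.r.t. Q
    grad_q = inputs.T @ (preds - targets) / len(inputs)

    # Parameter update applied to the full-precision weights X
    x_full -= lr * grad_q
    return loss

# Usage on random data
x_full = np.random.randn(8)              # full-precision weights X
inputs = np.random.randn(32, 8)
targets = np.random.randn(32)
for _ in range(20):
    loss = train_step(x_full, inputs, targets)
```

At inference time only Q (a sign pattern plus one scaling factor per filter) would be stored, which is what makes deployment memory- and compute-friendly.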

SLIDE 7

Motivations

• Existing methods quantize all weights simultaneously;
• The quantization error $\|X - Q\|$ may be large for some elements/filters;
• Large errors induce inappropriate gradient directions.
• Our idea: quantize only a portion of the weights at each iteration
  - Stochastic selection of which weights to quantize
  - Applicable to any low-bit setting

SLIDE 8

Roulette Selection Algorithm

[Figure: roulette selection over the channels of a weight matrix. Panels: weight matrix → per-channel quantization error → stochastic partition with r = 50% → hybrid weight matrix. Each channel (C1–C4) gets a selection probability derived from its quantization error (e.g. 0.2, 0.05, 0.2, 0.1); a selection point is drawn and the wheel rotates after each pick (1st selection: v = 0.58 → C2 selected; 2nd selection: v = 0.37 → C3 selected). Selected channels are quantized, the rest keep their full-precision values, yielding the hybrid weight matrix.]

๐‘“" = ๐‘‹

" โˆ’ ๐‘…" B

๐‘‹

" B

Quantization Error: Quantization Probability: Larger quantization error means smaller quantization probability, e.g. ๐‘ž" โˆ B

]^

Quantization Ratio r: Gradually increase to 100%
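
A small NumPy sketch of this roulette selection (my own illustrative implementation, not the released code): channels with larger quantization error are less likely to be quantized, and a fraction r of the channels is drawn:

```python
import numpy as np

def roulette_select(weights, quantized, r=0.5, eps=1e-8):
    """Pick round(r * num_channels) channels to quantize, with probability
    inversely proportional to the per-channel quantization error."""
    # f_i = ||X_i - Q_i||_1 / ||X_i||_1, computed per channel (row)
    err = (np.abs(weights - quantized).sum(axis=1)
           / (np.abs(weights).sum(axis=1) + eps))
    prob = 1.0 / (err + eps)                 # q_i proportional to 1 / f_i
    prob /= prob.sum()
    num_selected = int(round(r * weights.shape[0]))
    return np.random.choice(weights.shape[0], size=num_selected,
                            replace=False, p=prob)

# Usage: 4 channels of 8 weights, quantized BWN-style
weights = np.random.randn(4, 8)
scale = np.abs(weights).mean(axis=1, keepdims=True)
quantized = scale * np.sign(weights)
selected = roulette_select(weights, quantized, r=0.5)   # e.g. array([2, 0])
```

Drawing without replacement via np.random.choice is a close stand-in for the spin-and-rotate procedure in the figure, where the wheel is re-normalized after each pick.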

SLIDE 9

Training & Inference

• Hybrid weight matrix $\tilde{Q}$:

  $\tilde{Q}_i = \begin{cases} Q_i & \text{if channel } i \text{ is selected} \\ X_i & \text{otherwise} \end{cases}$

• Parameter update:

  $X_{t+1} = X_t - \eta_t \frac{\partial L}{\partial \tilde{Q}_t}$

• Inference: all weights are quantized ($r = 100\%$); use $Q$ to perform inference.
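
Continuing the sketches above (names are illustrative), assembling the hybrid matrix is a row-wise mix of quantized and full-precision channels; the gradient computed w.r.t. this hybrid matrix is then applied to the full-precision weights:

```python
import numpy as np

def hybrid_weights(weights, quantized, selected):
    """Rows in `selected` take the quantized values; the rest stay full-precision."""
    hybrid = weights.copy()
    hybrid[selected] = quantized[selected]
    return hybrid

# Usage with a BWN-style per-channel quantization
weights = np.random.randn(4, 8)                 # full-precision X
scale = np.abs(weights).mean(axis=1, keepdims=True)
quantized = scale * np.sign(weights)            # low-bit Q
selected = np.array([0, 2])                     # channels chosen by roulette selection
q_tilde = hybrid_weights(weights, quantized, selected)

lr = 0.01
grad_q_tilde = np.random.randn(4, 8)            # stand-in for dL/dQ~ from backprop
weights -= lr * grad_q_tilde                    # update the full-precision X
```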

SLIDE 10

Ablation Studies

• Selection granularity:
  - Filter-level > element-level
• Selection/partition algorithms:
  - Stochastic (roulette) > deterministic (sorting) ≈ fixed (selection only at the first iteration)
• Quantization probability functions:
  - Linear > sigmoid > constant ≈ softmax
  - Softmax: $q_i = \frac{\exp(g_i)}{\sum_j \exp(g_j)}$, where $g_i = \frac{1}{f_i}$
• Quantization ratio update scheme:
  - Exponential > fine-tune > uniform
  - e.g. 50% → 75% → 87.5% → 100%
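
For illustration, a hedged NumPy sketch of the four probability functions compared above, each mapping per-channel errors f_i to quantization probabilities through g_i = 1/f_i (the exact functional forms in the paper may differ; this only shows the idea):

```python
import numpy as np

def quantization_probs(f, kind="linear", eps=1e-8):
    """Map per-channel quantization errors f_i to selection probabilities q_i."""
    g = 1.0 / (f + eps)                  # larger error -> smaller probability
    if kind == "constant":
        q = np.ones_like(f)              # ignore the errors entirely
    elif kind == "linear":
        q = g
    elif kind == "sigmoid":
        q = 1.0 / (1.0 + np.exp(-g))
    elif kind == "softmax":
        q = np.exp(g - g.max())          # subtract the max for numerical stability
    else:
        raise ValueError(f"unknown kind: {kind}")
    return q / q.sum()                   # normalize to a probability distribution

f = np.array([0.3, 0.05, 0.2, 0.1])
print(quantization_probs(f, "softmax"))
```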

SLIDE 11

Results -- CIFAR

Test error (%) of VGG-9 and ResNet-56 trained with 5 different methods on CIFAR-10 and CIFAR-100:

                      CIFAR-10              CIFAR-100
Method    Bits    VGG-9   ResNet-56     VGG-9   ResNet-56
FWN        32      9.00      6.69       30.68     29.49
BWN         1     10.67     16.42       37.68     35.01
SQ-BWN      1      9.40      7.15       35.25     31.56
TWN         2      9.87      7.64       34.80     32.09
SQ-TWN      2      8.37      6.20       34.24     28.90

[Figure: training loss vs. iterations (k). Left: FWN vs. BWN vs. SQ-BWN; right: FWN vs. TWN vs. SQ-TWN.]

SLIDE 12

Results -- ImageNet

Test error (%) of AlexNet-BN and ResNet-18 trained with 5 different methods on ImageNet:

                    AlexNet-BN            ResNet-18
Method    Bits    top-1    top-5       top-1    top-5
FWN        32     44.18    20.83       34.80    13.60
BWN         1     51.22    27.18       45.20    21.08
SQ-BWN      1     48.78    24.86       41.64    18.35
TWN         2     47.54    23.81       39.83    17.02
SQ-TWN      2     44.70    21.40       36.18    14.26

SLIDE 13

Conclusions

• We propose a stochastic quantization algorithm for low-bit DNN training;
• The algorithm can be flexibly applied to all low-bit settings;
• It helps to consistently improve performance;
• We release our code to the public for future development:
  - https://github.com/dongyp13/Stochastic-Quantization

SLIDE 14

Q & A