

SLIDE 1

A Quantitative Analysis of the Effect of Batch Normalization on Gradient Descent

Yongqiang Cai¹, Qianxiao Li¹,², Zuowei Shen¹
ICML, 9-15 June 2019, Long Beach, CA, USA

¹Department of Mathematics, National University of Singapore, Singapore
²Institute of High Performance Computing, A*STAR, Singapore

SLIDE 2

Batch Normalization

A vanilla fully-connected layer computes z = σ(Wu + b). With batch normalization (Ioffe & Szegedy, 2015), the pre-activation is normalized before the nonlinearity:

z = σ(γ N(Wu) + β),   N(ξ) := (ξ − E[ξ]) / √Var[ξ].

Batch normalization works well in practice: it allows stable training with large learning rates and performs well in high dimensions and on ill-conditioned problems. Related work on BN: Ma & Klabjan (2017); Kohler et al. (2018); Arora et al. (2019).
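As a concrete illustration of the formula above, here is a minimal numpy sketch of a batch-normalized fully-connected layer. It is not the authors' code: the ReLU choice for σ, the batch-wise statistics, and the small eps added for numerical stability are illustrative assumptions.

```python
import numpy as np

def bn_layer(U, W, gamma, beta, eps=1e-5):
    """Batch-normalized layer: z = sigma(gamma * N(W u) + beta), with
    N(xi) = (xi - E[xi]) / sqrt(Var[xi]) estimated over the mini-batch.
    U: (batch, d_in) inputs, W: (d_in, d_out) weights."""
    Z = U @ W                                     # pre-activations W u
    mean = Z.mean(axis=0, keepdims=True)          # E[xi] over the batch
    var = Z.var(axis=0, keepdims=True)            # Var[xi] over the batch
    Z_hat = (Z - mean) / np.sqrt(var + eps)       # N(xi)
    return np.maximum(0.0, gamma * Z_hat + beta)  # sigma taken as ReLU here
```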


SLIDE 3

Batch Normalization


Question: Can we quantify the precise effect of BN on gradient descent (GD)?


SLIDE 4

Batch Normalization on Ordinary Least Squares

Linear regression model:
• Input: x ∈ R^d
• Label: y ∈ R
• Model: y = xᵀw* + noise
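A tiny, hypothetical data generator matching this model; the dimensions, noise level, and distributions are illustrative choices, not from the slides.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n = 10, 5_000
w_star = rng.normal(size=d)                 # ground-truth weights w*
x = rng.normal(size=(n, d))                 # inputs x in R^d
y = x @ w_star + 0.1 * rng.normal(size=n)   # labels y = x^T w* + noise
```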


SLIDE 5

Batch Normalization on Ordinary Least Squares

OLS regression without BN. Optimization problem:

min_w J_0(w) := E_{x,y}[ ½ (y − xᵀw)² ]

Gradient descent dynamics:

w_{k+1} = w_k − ε ∇_w J_0(w_k) = w_k + ε (g − H w_k),

where H := E[xxᵀ], g := E[xy], c := E[y²].

Writing w* := H⁻¹g for the minimizer, the error satisfies w_{k+1} − w* = (I − εH)(w_k − w*), so (I − εH) is the contraction ratio and plain GD converges only when ε < 2 / λ_max(H).
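To make this concrete, here is a minimal numpy sketch (assumed setup, not the authors' code) of the GD iteration above on a synthetic ill-conditioned problem; the eigenvalue spectrum, step count, and noiseless labels are illustrative assumptions.

```python
import numpy as np

# Synthetic quadratic problem: H and g built directly rather than estimated from data.
d = 100
lam = np.logspace(0, 2, d)          # eigenvalues of H spanning two orders of magnitude
H = np.diag(lam)                    # H := E[xx^T] (diagonal for simplicity)
w_star = np.ones(d)
g = H @ w_star                      # g := E[xy] for noiseless y = x^T w*

def gd_excess_loss(eps, steps=300):
    """Run w <- w + eps*(g - H w) and return J0(w) - J0(w*)."""
    w = np.zeros(d)
    for _ in range(steps):
        w = w + eps * (g - H @ w)   # contraction factor (I - eps*H) per step
    return 0.5 * (w - w_star) @ H @ (w - w_star)

# Plain GD converges only for eps < 2/lambda_max(H) = 0.02 here; larger steps diverge.
print(gd_excess_loss(0.01), gd_excess_loss(0.03))
```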


SLIDE 6

Batch Normalization on Ordinary Least Squares

OLS regression with BN. Optimization problem:

min_{a,w} J(a, w) := E_{x,y}[ ½ (y − a N(xᵀw))² ]

Here N(xᵀw) = xᵀw / √(wᵀHw) (taking E[x] = 0 without loss of generality), which gives the gradient descent dynamics (BNGD):

a_{k+1} = a_k − ε_a ∇_a J(a_k, w_k) = a_k + ε_a ( (w_kᵀ g) / √(w_kᵀ H w_k) − a_k ),

w_{k+1} = w_k − ε ∇_w J(a_k, w_k) = w_k + (ε a_k / √(w_kᵀ H w_k)) ( g − ((w_kᵀ g) / (w_kᵀ H w_k)) H w_k ).

How does this compare with the plain GD case, w_{k+1} = w_k − ε ∇_w J_0(w_k) = w_k + ε (g − H w_k)?

Properties of interest: convergence, robustness.
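A minimal numpy sketch (assumed setup, not the authors' code) of the BNGD iteration above, evaluated on the same kind of synthetic problem as before; the initialization, ε_a = 1, and step count are illustrative assumptions. It can be contrasted with the plain GD sketch, which diverges once ε exceeds 2/λ_max(H).

```python
import numpy as np

# Same synthetic quadratic as before: H := E[xx^T], g := E[xy], c := E[y^2].
d = 100
lam = np.logspace(0, 2, d)
H = np.diag(lam)
w_star = np.ones(d)
g = H @ w_star
c = w_star @ H @ w_star                      # noiseless labels, so c = w*^T H w*

def bn_loss(a, w):
    """J(a, w) = 0.5*(c - 2a * w^T g / sqrt(w^T H w) + a^2)."""
    s = np.sqrt(w @ H @ w)
    return 0.5 * (c - 2.0 * a * (w @ g) / s + a * a)

def bngd_loss(eps, eps_a=1.0, steps=300, seed=1):
    a, w = 0.0, np.random.default_rng(seed).normal(size=d)  # w != 0 so w^T H w > 0
    for _ in range(steps):
        s = np.sqrt(w @ H @ w)
        a_next = a + eps_a * (w @ g / s - a)                        # a-update
        w = w + (eps * a / s) * (g - (w @ g / (s * s)) * (H @ w))   # w-update
        a = a_next
    return bn_loss(a, w)

# BNGD remains stable over a wide range of eps (cf. the summary table that follows),
# whereas plain GD on this problem diverges for eps > 0.02.
for eps in (1e-3, 1e-1, 1e1):
    print(eps, bngd_loss(eps))
```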


SLIDE 7

Summary of Theoretical Results

Property                      | Gradient Descent   | Gradient Descent with BN
Convergence                   | only for small ε   | arbitrary ε, provided ε_a ≤ 1
Convergence Rate              | linear             | linear (can be faster)
Robustness to Learning Rates  | small range of ε   | wide range of ε
Robustness to Dimensions      | no effect          | the higher the better

Figure: (a) loss of GD and BNGD (d = 100); (b) effect of dimension on BNGD.


SLIDE 8

Summary of Theoretical Results


  • Those properties are also observed in neural network experiments.

Figure: (a) loss of GD and BNGD (d = 100); (b) effect of dimension on BNGD; (c) accuracy of BNGD on MNIST.


SLIDE 9

Poster: Pacific Ballroom #54
