Unraveling the mysteries of stochastic gradient descent on deep neural networks - PowerPoint PPT Presentation



SLIDE 1

Unraveling the mysteries of stochastic gradient descent on deep neural networks

Pratik Chaudhari

UCLA VISION LAB

SLIDE 2


The question

x∗ = argmin_x f(x)

where f measures the disagreement of the predictions with the ground truth (e.g., Cat vs. Dog) and x are the weights, aka the parameters.

Stochastic gradient descent:

x_{k+1} = x_k − η ∇f_b(x_k)

Many, many variants: AdaGrad, RMSProp, Adam, SAG, SVRG, Catalyst, APPA, Natasha, Katyusha...

Why is SGD so special?
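A minimal sketch of the SGD update above on a toy least-squares problem; the data, loss, and hyper-parameters are illustrative assumptions, not anything from the talk.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy least-squares loss: f(x) = 1/(2N) * sum_i (a_i . x - y_i)^2
N, d = 1000, 10
A = rng.normal(size=(N, d))
y = A @ rng.normal(size=d) + 0.1 * rng.normal(size=N)
x_star = np.linalg.lstsq(A, y, rcond=None)[0]        # full-batch minimizer, for reference

def grad_minibatch(x, batch):
    """Mini-batch gradient, an unbiased estimate of the full gradient."""
    Ab, yb = A[batch], y[batch]
    return Ab.T @ (Ab @ x - yb) / len(batch)

eta, b = 0.05, 32                                    # learning rate and batch size
x = np.zeros(d)
for k in range(2000):
    batch = rng.choice(N, size=b, replace=True)      # sample a mini-batch
    x = x - eta * grad_minibatch(x, batch)           # x_{k+1} = x_k - eta * grad f_b(x_k)

print("distance to the full-batch minimizer:", np.linalg.norm(x - x_star))
```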

SLIDE 3

Empirical evidence: wide “minima”

[Figure: two histograms of the eigenvalues found at the end of training, with frequency on a log scale. One panel covers a positive range up to about 40; the other zooms into a small negative range down to about −0.5, labeled "short negative tail".]
SLIDE 4
A bit of statistical physics

  • Energy landscape of a binary perceptron
  • Wide minima are a large-deviations phenomenon
  • Few wide minima, but they generalize better; many sharp minima [Baldassi et al., '15]
SLIDE 5

Tilting the Gibbs measure

x∗ = argmin_x f(x) = argmax_x e^{−f(x)}

  • Local Entropy [Chaudhari et al., ICLR '17]:

x∗ ≈ argmin_x −log( G_γ ∗ e^{−f} )(x)

    where G_γ is a Gaussian kernel of variance γ.
SLIDE 6

Parle: parallelization of SGD

  • State-of-the-art performance [Chaudhari et al., SysML '18]

[Figure: results for Wide-ResNet on CIFAR-10 and All-CNN on CIFAR-10 (25% data).]
SLIDE 7


The question

Why is SGD so special?

SLIDE 8

A continuous-time view of SGD

  • Diffusion matrix D(x): the variance of the mini-batch gradients
  • Temperature: the ratio of the learning rate and the batch size,

β⁻¹ = η / (2b)

  • For a mini-batch of size b drawn from N samples,

var( ∇f_b(x) ) = D(x) / b,  where  D(x) = (1/N) Σ_{k=1}^{N} ∇f_k(x) ∇f_k(x)ᵀ − ∇f(x) ∇f(x)ᵀ
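A small numerical check of these definitions, assuming a toy per-sample loss so that the per-sample gradients are available in closed form; everything below is an illustrative sketch rather than the talk's setup.

```python
import numpy as np

rng = np.random.default_rng(1)

# Toy per-sample losses f_k(x) = 0.5 * (a_k . x - y_k)^2, so grad f_k(x) = (a_k . x - y_k) * a_k
N, d = 500, 5
A, y = rng.normal(size=(N, d)), rng.normal(size=N)
x = rng.normal(size=d)

per_sample_grads = (A @ x - y)[:, None] * A            # row k is grad f_k(x)
full_grad = per_sample_grads.mean(axis=0)              # grad f(x)

# Diffusion matrix: D(x) = (1/N) sum_k grad f_k grad f_k^T  -  grad f grad f^T
D = per_sample_grads.T @ per_sample_grads / N - np.outer(full_grad, full_grad)

# Monte-Carlo check that var(grad f_b(x)) is approximately D(x) / b
b, trials = 25, 20000
mb = np.stack([per_sample_grads[rng.choice(N, b, replace=True)].mean(0) for _ in range(trials)])
emp_cov = np.cov(mb.T, bias=True)
print("relative error of D/b:", np.linalg.norm(emp_cov - D / b) / np.linalg.norm(D / b))

eta = 0.1
print("temperature beta^-1 =", eta / (2 * b))
```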

SLIDE 9

A continuous-time view of SGD

  • Continuous-time limit of the discrete-time updates (we will assume x ∈ Ω ⊂ ℝᵈ):

dx = −∇f(x) dt + √(2 β⁻¹ D(x)) dW(t)

  • The Fokker-Planck (FP) equation gives the distribution ρ(t) induced by SGD on the weight space, where x(t) ∼ ρ(t):

ρ_t = div( ∇f ρ + β⁻¹ div(D ρ) )

    with ∇f ρ the drift term and β⁻¹ div(D ρ) the diffusion term.
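A minimal Euler-Maruyama simulation of the SDE above in two dimensions, under simplifying assumptions (a quadratic loss and a constant, hand-picked diffusion matrix D); it is a sketch of the continuous-time picture, not the experiments in the talk.

```python
import numpy as np

rng = np.random.default_rng(2)

# Quadratic loss f(x) = 0.5 * x.T H x, so grad f(x) = H x (illustrative choice)
H = np.array([[3.0, 0.0], [0.0, 1.0]])
grad_f = lambda x: H @ x

D = np.array([[1.0, 0.3], [0.3, 0.5]])       # constant, non-isotropic diffusion (assumption)
beta_inv = 0.05                              # temperature eta / (2b)
L = np.linalg.cholesky(2 * beta_inv * D)     # L @ L.T = 2 * beta^-1 * D

dt, steps = 1e-3, 100_000
x = np.array([2.0, -1.0])
traj = np.empty((steps, 2))
for t in range(steps):
    # Euler-Maruyama step: dx = -grad f(x) dt + sqrt(2 beta^-1 D) dW
    x = x - grad_f(x) * dt + L @ rng.normal(size=2) * np.sqrt(dt)
    traj[t] = x

# The second half of the trajectory samples (approximately) the steady state
print("empirical steady-state covariance:\n", np.cov(traj[steps // 2:].T))
```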

SLIDE 10

Wasserstein gradient flow

  • The heat equation ρ_t = div( I ∇ρ ) performs steepest descent on the Dirichlet energy

(1/2) ∫ |∇ρ(x)|² dx

  • It is also the steepest descent in the Wasserstein metric for the negative entropy

−H(ρ) = ∫ ρ log ρ dx

  • The JKO discretization

ρ^τ_{k+1} ∈ argmin_ρ { −H(ρ) + W₂²(ρ, ρ^τ_k) / (2τ) }

    converges to trajectories of the heat equation, with steady state ρ^ss_heat = argmin_ρ −H(ρ)

  • Negative entropy is a Lyapunov functional for Brownian motion
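A small sketch of the Lyapunov property in the last bullet: evolve a 1-D density with an explicit finite-difference heat equation and watch the negative entropy decrease. The grid, initial condition, and step sizes are arbitrary choices for illustration.

```python
import numpy as np

# 1-D heat equation rho_t = rho_xx on a periodic grid, explicit finite differences.
n = 256
dx = 1.0 / n
xgrid = np.arange(n) * dx
rho = np.exp(-200 * (xgrid - 0.3) ** 2) + 0.5 * np.exp(-400 * (xgrid - 0.7) ** 2)
rho /= rho.sum() * dx                      # normalize to a probability density

dt = 0.2 * dx ** 2                         # stable explicit time step (dt <= dx^2 / 2)
neg_entropy = lambda r: np.sum(r * np.log(r + 1e-30)) * dx

for step in range(5001):
    if step % 1000 == 0:
        # These values should decrease monotonically: -H(rho) is a Lyapunov functional.
        print(f"step {step:5d}   -H(rho) = {neg_entropy(rho):.6f}")
    lap = (np.roll(rho, 1) - 2 * rho + np.roll(rho, -1)) / dx ** 2
    rho = rho + dt * lap
```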
SLIDE 11

Wasserstein gradient flow: with drift

  • If D = I, the Fokker-Planck equation

ρ_t = div( ∇f ρ + β⁻¹ I ∇ρ )

    has the Jordan-Kinderlehrer-Otto (JKO) functional [Jordan et al., '97] as its Lyapunov functional:

ρ^ss(x) = argmin_ρ  E_{x∼ρ}[ f(x) ] − β⁻¹ H(ρ),

    with E_{x∼ρ}[ f(x) ] the energetic term and −β⁻¹ H(ρ) the entropic term.

  • FP is the steepest descent on the JKO functional in the Wasserstein metric
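A quick check, not on the slide but consistent with the steady state shown two slides later: complete the JKO functional to a KL divergence,

F(ρ) = E_{x∼ρ}[ f(x) ] − β⁻¹ H(ρ) = ∫ f ρ dx + β⁻¹ ∫ ρ log ρ dx = β⁻¹ KL( ρ ∥ e^{−βf} / Z_β ) − β⁻¹ log Z_β,

so, since KL ≥ 0 with equality only at ρ = e^{−βf} / Z_β, the minimizer of the JKO functional is exactly the Gibbs measure of f.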

SLIDE 12

What happens for non-isotropic noise?

ρ_t = div( ∇f ρ + β⁻¹ div(D ρ) )        (drift + diffusion)

  • FP monotonically minimizes the free energy

ρ^ss(x) = argmin_ρ  E_{x∼ρ}[ f(x) ] − β⁻¹ H(ρ)

  • Rewrite it as

F(ρ) = β⁻¹ KL( ρ ∥ ρ^ss );

    compare with |x − x∗| for deterministic optimization.

SLIDE 13

SGD performs variational inference

Theorem [Chaudhari & Soatto, ICLR '18]. The functional

F(ρ) = β⁻¹ KL( ρ ∥ ρ^ss )

is minimized monotonically by trajectories of the Fokker-Planck equation

ρ_t = div( ∇f ρ + β⁻¹ div(D ρ) )

with ρ^ss as the steady-state distribution. Moreover, Φ = −β⁻¹ log ρ^ss up to a constant.

SLIDE 14

Some implications

  • The learning rate should scale linearly with the batch size:

β⁻¹ = η / (2b) should not be small.

  • Sampling with replacement regularizes better than sampling without replacement:

β⁻¹_{w/o replacement} = η / (2b) · (1 − b/N),

    a lower temperature; sampling with replacement also generalizes better.
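A tiny numeric illustration of the two temperatures above; the values of η, b, and N are made up.

```python
# Temperatures for with- vs. without-replacement sampling (illustrative numbers).
eta, b, N = 0.1, 128, 50_000          # learning rate, batch size, dataset size (assumptions)

beta_inv_with = eta / (2 * b)
beta_inv_without = eta / (2 * b) * (1 - b / N)

print(f"with replacement:    beta^-1 = {beta_inv_with:.6f}")
print(f"without replacement: beta^-1 = {beta_inv_without:.6f}")
# Doubling b while doubling eta leaves beta^-1 unchanged, which is the
# linear learning-rate / batch-size scaling rule above.
```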

SLIDE 15

Information Bottleneck Principle

  • Minimize the mutual information of the representation with the training data [Tishby '99, Achille & Soatto '17]:

IB_β(θ) = E_{x∼ρ_θ}[ f(x) ] − β⁻¹ KL( ρ_θ ∥ prior )

  • Minimizing such functionals is hard; SGD does it naturally.

SLIDE 16

Potential Φ vs. original loss f

  • The solution of the variational problem is

ρ^ss(x) = (1/Z_β) e^{−β Φ(x)},

    which is in general not the Gibbs distribution of the original loss, ∝ e^{−β f(x)}.

  • The two losses agree if and only if the noise is isotropic:

D(x) = I  ⇔  Φ(x) = f(x)

  • Key point: the most likely locations of SGD are not the critical points of the original loss.

SLIDE 17

Deep networks have highly non-isotropic noise

  • CIFAR-10:  λ(D) = 0.27 ± 0.84,  rank(D) = 0.34%
  • CIFAR-100: λ(D) = 0.98 ± 2.16,  rank(D) = 0.47%

  • Evaluate neural architectures using the diffusion matrix
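A sketch of how such statistics could be read off a matrix of per-sample gradients. The gradients below are random placeholders rather than gradients of a trained network, and the rank is reported as the fraction of eigenvalues above a small threshold, which may differ from the convention behind the numbers on the slide.

```python
import numpy as np

rng = np.random.default_rng(3)

# Placeholder per-sample gradients: N samples, d parameters (stand-ins for grad f_k(x)).
N, d = 200, 1000
G = 0.1 * rng.normal(size=(N, d))

g_bar = G.mean(axis=0)
D = G.T @ G / N - np.outer(g_bar, g_bar)          # diffusion matrix D(x)

eig = np.linalg.eigvalsh(D)
rank_frac = np.mean(eig > 1e-6 * eig.max())       # fraction of "non-zero" eigenvalues

print(f"lambda(D) = {eig.mean():.4f} +/- {eig.std():.4f}")
print(f"rank(D)  ~ {100 * rank_frac:.2f}% of dimensions")
```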
SLIDE 18


How different are cats and dogs, really?

SLIDE 19

SGD converges to limit cycles

Theorem [Chaudhari & Soatto, ICLR '18]. The most likely trajectories of SGD follow

ẋ = j(x),

where the "leftover" vector field

j(x) = −∇f(x) + D(x) ∇Φ(x) − β⁻¹ div D(x)

is such that div j(x) = 0.

SLIDE 20

Trajectories of SGD

  • Run SGD for 10⁵ epochs
  • Plot the FFT of x^i_{k+1} − x^i_k for individual weights i (a toy version of this diagnostic is sketched below)
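A minimal sketch of that diagnostic. A real run would record an actual weight trajectory; here a synthetic noisy oscillation stands in for it, so the numbers are purely illustrative.

```python
import numpy as np

rng = np.random.default_rng(4)

# Synthetic stand-in for one recorded weight trajectory x_k: a slow oscillation plus noise.
K = 100_000
k = np.arange(K)
x = 0.05 * np.sin(2 * np.pi * k / 500) + 0.01 * rng.normal(size=K)

increments = np.diff(x)                            # x_{k+1} - x_k
spectrum = np.abs(np.fft.rfft(increments))
freqs = np.fft.rfftfreq(increments.size, d=1.0)    # cycles per step

peak = freqs[np.argmax(spectrum[1:]) + 1]          # skip the zero-frequency bin
print(f"dominant frequency ~ {peak:.5f} cycles/step (period ~ {1 / peak:.0f} steps)")
```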

SLIDE 21

An example

[Figure: a two-dimensional force field with labeled regions: points where the gradient vanishes and j(x) = 0, a region where j(x) is very large, a saddle point, and a region where j(x) is small.]
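To make the picture concrete, here is a toy simulation, not from the talk: a two-dimensional diffusion whose drift is −∇f plus a hand-picked divergence-free rotational field standing in for j(x). The trajectory keeps circulating around the minimum instead of settling there; f, the rotation strength, and the noise level are all arbitrary assumptions.

```python
import numpy as np

rng = np.random.default_rng(5)

# Toy 2-D dynamics: -grad f with f(x) = 0.5 * |x|^2, plus a divergence-free rotation
# j(x) = c * (-x_2, x_1). The rotation leaves the steady-state density unchanged but
# makes the most likely paths orbit the minimum.
grad_f = lambda x: x
j = lambda x: 4.0 * np.array([-x[1], x[0]])

beta_inv, dt, steps = 0.01, 1e-3, 50_000
x = np.array([1.0, 0.0])
angles = np.empty(steps)
for t in range(steps):
    x = x + (-grad_f(x) + j(x)) * dt + np.sqrt(2 * beta_inv * dt) * rng.normal(size=2)
    angles[t] = np.arctan2(x[1], x[0])

# Number of full revolutions the trajectory makes around the minimum at the origin.
revs = abs(np.unwrap(angles)[-1] - angles[0]) / (2 * np.pi)
print(f"revolutions around the minimum: {revs:.1f}")
```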

SLIDE 22

Most likely locations are not the critical points of the original loss

Theorem [Chaudhari & Soatto, ICLR '18]. The Ito SDE

dx = −∇f dt + √(2 β⁻¹ D) dW(t)

is equivalent to an A-type SDE

dx = −(D + Q) ∇Φ dt + √(2 β⁻¹ D) dW(t)

with the same steady state ρ^ss ∝ e^{−β Φ(x)}, if

∇f = (D + Q) ∇Φ − β⁻¹ div(D + Q).

SLIDE 23


Knots in our understanding

ARCHITECTURE OPTIMIZATION GENERALIZATION

SLIDE 24


Punchline

Is SGD special?

SLIDE 25

Thank you, questions?


www.pratikac.info

Pratik Chaudhari and Stefano Soatto. Stochastic gradient descent performs variational inference, converges to limit cycles for deep networks. ICLR '18, arXiv:1710.11029.