

SLIDE 1

Disentangling Trainability and Generalization in Deep Neural Networks

Lechao Xiao, Jeffrey Pennington and Samuel S. Schoenholz
Google Brain Team, Google Research

Colab Tutorial

SLIDE 2

Two Fundamental Theoretical Questions in Deep Learning

  • Trainability / Optimization
    • Is there an efficient algorithm that reaches a global minimum?
  • Generalization
    • Does the trained model perform well on unseen data?
  • The dream
    • A (model, algorithm) pair: fast training + fantastic generalization
    • Solves AGI

SLIDE 3

A trade-off between Trainability and Generalization for very deep and very wide NNs

  • Trains fast, but does NOT generalize
    • Large weight initialization (Chaotic Phase)
  • Trains slowly, but generalizes
    • Small weight initialization (Ordered Phase)

Deep Neural Networks

SLIDE 4

Neural Networks Initialization

[Diagram: an input x is mapped through the network to an output f(x); weights and biases are initialized with variances σ²_w and σ²_b.]
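As a concrete sketch (not from the slides), such a network can be written with the neural_tangents stax API that the Colab tutorial builds on; the depth, width, and variance defaults below are illustrative assumptions:

```python
from neural_tangents import stax

def make_fcn(depth=8, w_var=2.0, b_var=0.0, width=512, num_classes=10):
    """Erf FCN with weights W ~ N(0, w_var / fan_in) and biases b ~ N(0, b_var)."""
    layers = []
    for _ in range(depth):
        layers += [stax.Dense(width, W_std=w_var ** 0.5, b_std=b_var ** 0.5),
                   stax.Erf()]
    layers += [stax.Dense(num_classes, W_std=w_var ** 0.5, b_std=b_var ** 0.5)]
    # stax.serial returns (init_fn, apply_fn, kernel_fn): the finite network
    # plus the corresponding infinite-width NNGP/NTK kernel function.
    return stax.serial(*layers)
```

Here kernel_fn(x1, x2, 'ntk') evaluates the infinite-width NTK used throughout the talk.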

SLIDE 5

Training Dynamics and NTK

Gradient descent (gradient flow) dynamics with mean squared error loss:

  dfₜ/dt = −η Θ (fₜ − Y_train)

  • In the infinite-width setting, the NTK Θ is deterministic and remains constant through training (Jacot et al., 2018).
  • The above ODE has a closed-form solution: fₜ = Y_train + e^{−η Θ t} (f₀ − Y_train) on the training set.

[Diagram: the Neural Tangent Kernel (NTK) and the function-space view of training.]
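A minimal sketch of that closed-form solution via an eigendecomposition of the train-train NTK; the function name and the shapes (f0, y of shape (n_train, n_outputs)) are my own conventions:

```python
import jax.numpy as jnp

def mse_flow_solution(theta, f0, y, t, lr=1.0):
    """Closed-form solution of df/dt = -lr * theta @ (f - y):
    f_t = y + expm(-lr * t * theta) @ (f0 - y) on the training set."""
    evals, evecs = jnp.linalg.eigh(theta)    # theta: symmetric PSD, (n_train, n_train)
    decay = jnp.exp(-lr * t * evals)         # mode i decays as e^{-lr * lambda_i * t}
    return y + evecs @ (decay[:, None] * (evecs.T @ (f0 - y)))
```

The slowest mode decays at rate e^{−η λ_min t}, which motivates the condition-number metric on the next slides.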

SLIDE 6

Training and Learning Dynamics

[Figure: training dynamics and learning dynamics, showing agreement between finite- and infinite-width networks. Credit: Roman Novak]
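A sketch of how such infinite-width predictions are typically produced, assuming the neural_tangents nt.predict.gradient_descent_mse_ensemble API; the data here are placeholders:

```python
import jax
import jax.numpy as jnp
import neural_tangents as nt
from neural_tangents import stax

# Placeholder data; substitute a real train/test split in practice.
x_train = jax.random.normal(jax.random.PRNGKey(0), (32, 16))
y_train = jnp.ones((32, 1))
x_test = jax.random.normal(jax.random.PRNGKey(1), (8, 16))

_, _, kernel_fn = stax.serial(stax.Dense(512, W_std=1.5), stax.Erf(),
                              stax.Dense(1, W_std=1.5))

# Mean infinite-width prediction after training time t under MSE gradient flow.
predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, x_train, y_train)
y_test_at_t = predict_fn(t=100.0, x_test=x_test, get='ntk')
```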

SLIDE 7

Metric for Trainability: Condition Number

Eigendecomposing the training dynamics, Θ = Σᵢ λᵢ uᵢuᵢᵀ, each eigenmode converges independently; the mode along the eigenvector with the smallest eigenvalue converges at rate e^{−η λ_min t}.

Trainability metric: the condition number κ = λ_max / λ_min.

[Plot: training curves of an 8-layer finite-width FCN on CIFAR-10; blue: σ²_w = 25, orange: σ²_w = 0.5.]
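A minimal sketch of the metric, given the train-train NTK matrix (the function name is my own):

```python
import jax.numpy as jnp

def condition_number(theta):
    """Trainability metric kappa = lambda_max / lambda_min of the train-train NTK."""
    evals = jnp.linalg.eigvalsh(theta)   # ascending eigenvalues of a symmetric matrix
    return evals[-1] / evals[0]
```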

SLIDE 8

Metric for Generalization: Mean Prediction

Learning dynamics on the test points: the generalization metric is the mean prediction,

  P(Θ)Y_train = Θ(X_test, X_train) Θ(X_train, X_train)⁻¹ Y_train

The network cannot generalize if P(Θ)Y_train becomes completely independent of the test inputs.
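A sketch of this metric, assuming a kernel_fn as above; the small diagonal regularizer is my own numerical-stability choice, not from the slides:

```python
import jax.numpy as jnp

def mean_prediction(kernel_fn, x_train, y_train, x_test, diag_reg=1e-6):
    """Mean prediction P(Theta) @ y_train of the infinite-width network at t -> inf."""
    theta_tt = kernel_fn(x_train, x_train, 'ntk')
    theta_st = kernel_fn(x_test, x_train, 'ntk')
    reg = diag_reg * jnp.eye(theta_tt.shape[0])   # stabilizes the solve
    return theta_st @ jnp.linalg.solve(theta_tt + reg, y_train)
```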

SLIDE 9

Evolution of the Metrics with depth

[Plots: the NTK, its condition number, and the mean prediction as functions of depth.]

Analyzing Induced Dynamical Systems

SLIDE 10

Convergence of NTK and Phase Diagram

Convergence of Θ(l) is determined by a bivariate function χ1 defined on the (σ²_w, σ²_b)-plane.

  • Ordered Phase: χ1 < 1
    • Θ(l) → Θ* = C·11ᵀ,  κ(l) → κ* = ∞,  P(Θ(l))Y_train → P(Θ*)Y_train = C_test
  • Chaotic Phase: χ1 > 1
    • Θ(l) → ∞,  κ(l) → κ* = 1,  P(Θ(l))Y_train → P(Θ*)Y_train = 0
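For illustration, χ1 can be estimated by Monte Carlo from the standard mean-field formula χ1 = σ²_w · E_{z∼N(0,q*)}[φ′(z)²]; the fixed point q* of the pre-activation variance map is assumed to be given here, and q* = 1.0 below is a placeholder:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import erf

def chi1(phi, q_star, w_var, num_samples=100_000, seed=0):
    """Monte Carlo estimate of chi_1 = w_var * E_{z ~ N(0, q_star)}[phi'(z)^2].
    chi_1 < 1: ordered phase; chi_1 > 1: chaotic phase."""
    z = jnp.sqrt(q_star) * jax.random.normal(jax.random.PRNGKey(seed), (num_samples,))
    dphi = jax.vmap(jax.grad(phi))(z)    # phi'(z) at each sample
    return w_var * jnp.mean(dphi ** 2)

print(chi1(erf, q_star=1.0, w_var=25.0))   # > 1: chaotic
print(chi1(erf, q_star=1.0, w_var=0.5))    # < 1: ordered
```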

SLIDE 11

Chaotic Phase χ1 > 1

Easy to train, but does not generalize

[Plots: trainability / generalization metrics and the depth dynamics of the NTK entries in the chaotic phase.]
SLIDE 12

Chaotic Phase / Memorization

Easy to train, but does not generalize

  • 10k train / 2k test examples, CIFAR-10 (10 classes)
  • Full-batch gradient descent
  • σ²_w = 25, σ²_b = 0, l = 8
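An end-to-end sketch of this chaotic-phase setting (data loading elided; a small random array stands in for the 10k flattened CIFAR-10 images):

```python
import jax
import jax.numpy as jnp
from neural_tangents import stax

# 8-layer erf FCN in the chaotic phase: sigma_w^2 = 25 (W_std = 5), sigma_b^2 = 0.
layers = []
for _ in range(8):
    layers += [stax.Dense(512, W_std=5.0, b_std=0.0), stax.Erf()]
layers += [stax.Dense(10, W_std=5.0, b_std=0.0)]
init_fn, apply_fn, kernel_fn = stax.serial(*layers)

x_train = jax.random.normal(jax.random.PRNGKey(0), (100, 3072))  # substitute real data
theta = kernel_fn(x_train, x_train, 'ntk')
evals = jnp.linalg.eigvalsh(theta)
print('condition number:', evals[-1] / evals[0])  # near 1: easy to train, poor generalization
```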

SLIDE 13

Ordered Phase χ1 < 1

Difficult to train, but generalizes

[Plots: trainability / generalization metrics and the entries of the NTK in the ordered phase.]
SLIDE 14

Ordered Phase / Generalization

[Plots: σ²_w = 0.5 (ordered phase): difficult to train, but generalizes; σ²_w = 25 (chaotic phase): easy to train, but does not generalize.]

SLIDE 15

Summary

  • A trade-off between trainability and generalization for deep and wide networks

  • Fast training + memorization (Chaotic Phase)
  • Slow training + generalization (Ordered Phase)
  • More results
  • Pooling, Dropout, Skip Connection, LayerNorm, etc.
  • Conjugate Kernels

Colab Tutorial