On The Complexity of Training a Neural Network - Santosh Vempala


SLIDE 1

On The Complexity of Training a Neural Network

Santosh Vempala

Algorithms and Randomness Center, Georgia Tech

SLIDE 2

Deep learning’s successes are incredible

Do you want to beat 9-dan Go players?

SLIDE 3

Deep learning’s successes are incredible

Do you want to beat 9-dan Go players? classify images more accurately than humans?

SLIDE 4

Deep learning’s successes are incredible

Do you want to beat 9-dan Go players? classify images more accurately than humans? recognize speech? recommend movies? drive an autonomous vehicle?

SLIDE 5

Deep learning conquers the world

Do you want to beat 9-dan Go players? classify images more accurately than humans? recognize speech? recommend movies? drive an autonomous vehicle? publish a paper in NIPS? Try deep learning!

SLIDE 6

The Learning Problem

Problem: Given labeled samples (x, f(x)), where x ∼ D and f : R^n → R, find g : R^n → R s.t.

E_x (f(x) − g(x))² ≤ ε

SLIDE 7

Deep Learning for a theoretician

Want to approximate concept f. How to choose model g? A simple "neural network" (NN): x ∼ D(R^n); y1 = σ(W1 · x + b1); y2 = σ(W2 · y1 + b2); g(x) = W3 · y2 + b3. [Diagram: input layer, hidden layer, output layer.]

SLIDE 8

Deep Learning for a theoretician

Want to approximate concept f. How to choose model g? A simple "neural network" (NN): x ∼ D(R^n); y1 = σ(W1 · x + b1); y2 = σ(W2 · y1 + b2); g(x) = W3 · y2 + b3.

e.g., sigmoid: σ(x) = 1 / (1 + e^{−x})

[Diagram labels: activation function, weights.]
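To make the model concrete, here is a minimal numpy sketch of such a network; the layer widths, seed, and Gaussian input are illustrative choices, not from the talk:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# Illustrative sizes: n-dim input, hidden layers of width m.
n, m = 10, 32
rng = np.random.default_rng(0)
W1, b1 = rng.normal(size=(m, n)), np.zeros(m)
W2, b2 = rng.normal(size=(m, m)), np.zeros(m)
W3, b3 = rng.normal(size=m), 0.0

def g(x):
    y1 = sigmoid(W1 @ x + b1)   # hidden layer 1
    y2 = sigmoid(W2 @ y1 + b2)  # hidden layer 2
    return W3 @ y2 + b3         # linear output layer

x = rng.normal(size=n)          # x ~ D(R^n); Gaussian is a stand-in for D
print(g(x))
```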

SLIDE 9

Deep Learning for a theoretician

How to "train" the network, i.e., choose W, b? x ∼ D(R^n); y1 = σ(W1 · x + b1); y2 = σ(W2 · y1 + b2); g(x) = W3 · y2 + b3. Gradient descent: from labeled data (x1, f(x1)), (x2, f(x2)), ..., estimate the gradient E_x ∇_W (f − g)², update the weights W ← W − ∇_W, repeat.
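A toy version of this training loop for a single-hidden-layer sigmoid network with mean-squared loss; the target f, batch size, and step size below are made up for illustration:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(1)
n, m, eta, batch = 10, 32, 0.1, 64
W, b = rng.normal(size=(m, n)), np.zeros(m)  # hidden-layer weights, biases
a = rng.normal(size=m) / m                   # output-layer weights

def f(x):
    return np.sin(x[0])                      # stand-in for the unknown concept

for step in range(1000):
    X = rng.normal(size=(batch, n))          # samples x ~ D
    y = np.array([f(x) for x in X])          # labels f(x)
    H = sigmoid(X @ W.T + b)                 # hidden activations
    r = H @ a - y                            # residuals g(x) - f(x)
    # empirical gradient of E_x (f - g)^2, per weight block
    grad_a = 2 * H.T @ r / batch
    s = r[:, None] * H * (1 - H) * a         # shared factor; sigmoid' = H(1-H)
    grad_b = 2 * s.mean(axis=0)
    grad_W = 2 * s.T @ X / batch
    a -= eta * grad_a                        # update weights, repeat
    b -= eta * grad_b
    W -= eta * grad_W
```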

SLIDE 10

Guarantees for deep learning?

Goal: Provable guarantees for NN training algorithms.

SLIDE 11

Guarantees for deep learning?

Goal: Provable guarantees for NN training algorithms when the data is generated by a NN.

SLIDE 12

Guarantees for deep learning?

Goal: Provable guarantees for NN training algorithms when the data is generated by a NN. Theorem (Cybenko 1989): Continuous functions can be approximated by 1-hidden-layer NNs with sigmoids.

SLIDE 13

Guarantees for deep learning?

Goal: Provable guarantees for NN training algorithms when the data is generated by a small one-hidden-layer NN. Theorem (Cybenko 1989): Continuous functions can be approximated by 1-hidden-layer NNs with sigmoids.

SLIDE 14

Provable guarantees for deep learning?

What could the form of such guarantees be? Under what conditions (on the input distribution and the function) does stochastic gradient descent work? Does it help if the data is generated by a NN? (Is the "realizable" case easier?)

SLIDE 15

Guarantees for deep learning?

Lower bounds for the realizable case: NP-hard to train a neural network with 3 threshold neurons (A. Blum–Rivest 1993). Complexity/crypto assumptions ⇒ cannot efficiently learn small-depth networks, even improperly (Klivans–Sherstov, 2006), (Daniely–Linial–Shalev-Shwartz, 2014). Even with nice input distributions, some deep learning algorithms don't work (Shamir, 2016).

SLIDE 16

Outline

  • 1. Lower bounds for learning neural networks

"You can't efficiently learn functions computed by small, single-hidden-layer neural networks, even over nice input distributions."

  • 2. Polynomial-time analysis of gradient descent

"Gradient descent can efficiently train single-hidden-layer neural networks of unbiased activation units."

  • 3. Open questions

SLIDE 17

A nice family of neural networks

[Diagram: n-dim input with i.i.d. N(0,1) coordinates → one hidden layer of Õ(n) sigmoid units → linear output.]

SLIDE 18

Smooth activation functions

Sigmoid function of sharpness s is e^{sx} / (1 + e^{sx}).

[Plot: curves for s = 1, 4, 12; larger s gives a sharper step.]
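As a quick sketch, the s-sharp sigmoid from the slide (note e^{sx}/(1+e^{sx}) = 1/(1+e^{−sx})):

```python
import numpy as np

def sigmoid_sharp(x, s):
    """Sigmoid of sharpness s: e^{sx} / (1 + e^{sx})."""
    return 1.0 / (1.0 + np.exp(-s * x))

# Larger s -> steeper transition at 0 (cf. the s = 1, 4, 12 curves).
for s in (1, 4, 12):
    print(s, sigmoid_sharp(0.2, s))
```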

SLIDE 19

A little more generality

[Diagram: n-dim input with i.i.d. logconcave coordinates → one hidden layer of ReLU, PReLU, softplus, sigmoid, ... units → linear output.]

SLIDE 20

Use deep learning!

Choose your favorite network architecture, activation units, loss function, gradient descent variant, regularization scheme...

SLIDE 21

A computational lower bound

Choose your favorite network architecture, activation units, loss function, gradient descent variant, regularization scheme, etc. "Theorem": If you use only "black box" functions of the input (e.g., gradients via TensorFlow, Hessians via Autograd), you need 2^{Ω(n/s²)} function evaluations of accuracy at most 1/(s·√n).

SLIDE 22

A little more context

[Diagram: n-dim input with i.i.d. N(0,1) coordinates → Õ(n) sigmoid units → linear output.]

Janzamin–Sedghi–Anandkumar (2015): tensor decomposition algorithm with additional assumptions; sample size = poly(n, condition number of weights). Shamir (2016) gives exponential lower bounds against "vanilla" SGD with mean-squared loss, ReLU units, Gaussian input (nonrealizable, but similar construction). More recent improvements on upper bounds (coming up).

SLIDE 23

A little more generality

[Diagram: n-dim input with i.i.d. logconcave coordinates → one hidden layer of ReLU, PReLU, softplus, sigmoid, ... units → linear output.]

The lower bound applies to algorithms of the following form:

  • 1. Estimate v = E_{(x,y)∼D}[h(W, x, y)], where W: current weights; (x, y): labeled example from input dist. D; h: arbitrary [0, 1]-valued function.

  • 2. Use v to update W (see the schematic sketch below).
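A schematic of this algorithm class, to fix ideas; the function names and the finite-sample estimate of the expectation are my own illustration:

```python
import numpy as np

def sq_style_training(sample_batch, h, update, W0, rounds=100):
    """Schematic of the algorithm class covered by the lower bound:
    each round sees the data only through an (empirical) expectation
    v = E_{(x,y)~D} h(W, x, y) of a [0,1]-valued query function h."""
    W = W0
    for _ in range(rounds):
        xs, ys = sample_batch()                        # labeled examples
        v = np.mean([h(W, x, y) for x, y in zip(xs, ys)])
        W = update(W, v)                               # any update rule at all
    return W
```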
SLIDE 24

The hard family of functions

φ(x) = σ(1/s + x) + σ(1/s − x) − 1

F_σ(x) = φ(x) + φ(x − 2/s) + φ(x + 2/s) + · · ·

F_σ : R → R is an affine combination of σ-units, almost periodic on [−Õ(n), Õ(n)], period 1/s. [Plot: σ(1/s + x), σ(1/s − x), and the bump φ.]
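A small numpy sketch of this construction; truncating the infinite sum to finitely many shifts is my choice (the slides use Õ(n) units):

```python
import numpy as np

def sigma(x, s):
    return 1.0 / (1.0 + np.exp(-s * x))      # s-sharp sigmoid

def phi(x, s):
    # bump made of two sigmoids: sigma(1/s + x) + sigma(1/s - x) - 1
    return sigma(1.0 / s + x, s) + sigma(1.0 / s - x, s) - 1.0

def F(x, s, terms=50):
    # almost-periodic sum of bumps shifted by multiples of 2/s,
    # truncated here to `terms` shifts on each side
    return sum(phi(x - 2.0 * k / s, s) for k in range(-terms, terms + 1))
```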

SLIDE 25

The hard family of functions

φ₀(x) = r(x + 1/(2s)) − r(x − 1/(2s))

φ(x) = φ₀(1/(2s) + x) + φ₀(1/(2s) − x) − 1

F_r(x) = φ(x) + φ(x − 2/s) + φ(x + 2/s) + · · ·

[Plot: the ramps r(x + 1/(2s)) and −r(x − 1/(2s)), and the resulting bump.]

SLIDE 26

The hard family of functions

[Diagram: n-dim input with i.i.d. logconcave coordinates → hidden layer computing F_σ with ±1 weights → linear output.]

F_σ : R → R is an affine combination of σ-units, almost periodic on [−Õ(n), Õ(n)], period 1/s.

For every S ⊆ {1, ..., n} with |S| = n/2:

f_S(x) = F_σ( Σ_{i∈S} x_i )
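Putting the pieces together, a sketch of one member of the hard family, reusing F from the sketch above; the dimension and sharpness below are arbitrary examples:

```python
import numpy as np

def f_S(x, S, s):
    """One hard function: F applied to the sum of the coordinates in S."""
    return F(np.sum(x[S]), s)

n = 10
rng = np.random.default_rng(2)
S = rng.choice(n, size=n // 2, replace=False)  # one of ~2^n subsets, |S| = n/2
x = rng.normal(size=n)                         # stand-in logconcave product dist.
print(f_S(x, S, s=4.0))
```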

SLIDE 27

Throw some “deep learning” at it!

Choose your favorite: network architecture, activation units, loss function, gradient descent variant, regularization scheme...

SLIDE 28

Theory vs. practice, revisited

[Plot: training error (0.000–0.040) vs. s·√n (5–20) for sigmoid networks with n = 50, 100, 200, 500, 1000.]

SLIDE 29

Theory vs. practice, revisited

SLIDE 30

Throw some “deep learning” at it!

Choose your favorite: network architecture, activation units, loss function, gradient descent variant, regularization scheme... Were you successful?

SLIDE 31

Try “deep learning”!

Choose your favorite: network architecture, activation units, loss function, gradient descent variant, regularization scheme... Were you successful? "Theorem": No!

SLIDE 32

Statistical Query Algorithms

Recall the gradient descent training algorithm: labeled data (x1, f(x1)), (x2, f(x2)), ... → estimate gradient E_x ∇_W (f − g)² → update weights W ← W − ∇_W.

SLIDE 33

Statistical Query Algorithms

Recall the gradient descent training algorithm: query gradient E_x ∇_W (f − g)², update weights W ← W − ∇_W. We need a gradient estimate, not necessarily labeled examples.

SLIDE 34

Statistical Query Algorithms

[Diagram: algorithm ↔ oracle.]

Statistical query (SQ) algorithms were introduced by Kearns in 1993. No direct access to samples; the algorithm queries expectations of functions over the labeled-example distribution.

Query (h, τ): h : R^n × R → [0, 1], τ > 0. Response v: |E h − v| < τ. E.g., for gradient descent, query h = ∇_W (f − g)².

SLIDE 35

Statistical Query Algorithms

SQ algorithms are extremely general. Almost all "robust" machine learning guarantees can be achieved with SQ algorithms.

SLIDE 36

Statistical Query Algorithms

Statistical algorithms: no direct interaction with labeled examples.

Definition [Feldman–Grigorescu–Reyzin–V.–Xiao 2013]: A statistical query algorithm interacts with the input distribution (X, D) only by evaluating expectations of bounded functions. More precisely, for any query function h : X → [0, 1], with p = E[h(X)], and any integer t > 0, the oracle VSTAT(t) returns v s.t.

|v − p| ≤ max( 1/t, √( p(1 − p)/t ) )

(the error of t Bernoulli coin tosses, each of bias p).
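A toy simulation of this oracle, to make the guarantee concrete; using an empirical mean as a stand-in for the true expectation, and a random perturbation within the allowed tolerance, are both my choices:

```python
import numpy as np

def vstat(h, samples, t, rng):
    """Simulated VSTAT(t) answer for a query h : X -> [0, 1]: return some v
    with |v - p| <= max(1/t, sqrt(p(1-p)/t)), where p = E h(X)."""
    p = np.mean([h(x) for x in samples])       # empirical stand-in for E h(X)
    tol = max(1.0 / t, np.sqrt(p * (1.0 - p) / t))
    return p + rng.uniform(-tol, tol)          # any value within tolerance
```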

SLIDE 37

Statistical Query Algorithms

Framework introduced by Kearns (1993) with the STAT oracle. VSTAT, introduced by (FGRVX, 2013), gives better bounds. Known algorithms for supervised/unsupervised learning can be viewed as SQ algorithms. Only known exception: parity learning via Gaussian elimination. Nearly tight SQ lower bounds for: learning (noisy) parity (Kearns, 1993), (Blum–Furst–Jackson–Kearns–Mansour–Rudich, 1994); detecting bipartite planted cliques (Feldman–Grigorescu–Reyzin–Vempala–Xiao, 2012); planted SAT/CSP (Feldman–Perkins–Vempala, 2015); robustly learning a Gaussian (Diakonikolas–Kane–Stewart, 2017).

SLIDE 38

Lower bound

Theorem [Song-V.-Wilmes-Xie 2017]

There exist functions f : R^n → R, realizable as NNs with one hidden layer of s-sharp sigmoids, such that for any logconcave product distribution on inputs, any statistical query algorithm that learns f needs

2^{Ω(n/s²)}

queries to VSTAT(s²·n).

SLIDE 39

Statistical dimension

Informal definition: Function family C has dimension d with average correlation γ if for every C′ ⊆ C with |C′| ≥ |C|/d, a random pair f, h ∈ C′ has expected correlation |Corr(f, h)| < γ. Statistical dimension d ⇒ need Ω(d) queries. The correlation γ corresponds to the query tolerance.
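In symbols, the informal definition above can be written as follows; this follows the FGRVX-style statistical dimension with average correlation, and the exact normalization of Corr is the paper's, not spelled out on the slide:

```latex
% Statistical dimension with average correlation gamma
% (the slide's informal definition, written out symbolically).
\[
\mathrm{SDA}(\mathcal{C},\gamma) \;=\;
\max\Bigl\{\, d \;:\;
  \forall\,\mathcal{C}'\subseteq\mathcal{C},\;
  |\mathcal{C}'| \ge |\mathcal{C}|/d
  \;\Longrightarrow\;
  \mathbb{E}_{f,h\sim\mathcal{C}'}\bigl[\,\lvert\mathrm{Corr}(f,h)\rvert\,\bigr] < \gamma
\,\Bigr\}
\]
% A bound SDA >= d then forces Omega(d) queries at tolerance ~ gamma.
```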

SLIDE 40

Statistical dimension

[Diagram: hypothesis space and a "good" query h.]

SLIDE 41

Statistical dimension

[Diagram: hypothesis space; a "good" query h vs. any query when the statistical dimension is large.]

SLIDE 42

Statistical dimension

Statistical dimension bounds ⇒ query complexity bounds. Here we extend the concept and the proof to regression problems.

SLIDE 43

The lower bound

Not enough to make generic assumptions about network size and input distribution. What are candidate sufficient conditions?

SLIDE 44

Upper bounds

Can learn neural nets in polynomial time. Non-GD algorithms: single layer, with strong non-degeneracy assumptions and access to the density function (Janzamin–Sedghi–Anandkumar 2015); two layers, with weaker non-degeneracy assumptions (Goel–Klivans 2017, Goel–Klivans–Meka 2018). Using GD: with careful regularization (Ge–Lee–Ma 2017); for non-overlapping convolutional nets (Zhong–Song–Dhillon 2017) and (Du–Lee–Tian–Póczos–Singh 2017); for ResNet-style architectures (Li–Yuan 2017); for learning polynomials using exponential activation units (Andoni–Panigrahy–Valiant–Zhang 2014). Vanilla GD with weak assumptions? A kernel framework (Daniely 2017).

SLIDE 45

What is “non-degenerate” enough?

f(x) = Σ_u a_u · σ(u · x), where ‖a‖₁ ≤ C

x ∼ uniform over S^{n−1} or Gaussian (or spherically symmetric)

[Diagram: x ∈ S^{n−1} uniform → unbiased σ units → linear output.]

The assumptions guarantee f is approximately low-degree.

SLIDE 46

Learning low-degree NNs via gradient descent

Notation: f^{(≤k)} = f projected to polynomials of degree ≤ k over the sphere.

Problem: Fix k ∈ N, let OPT = E(f − f^{(≤k)})². Goal: find g : S^{n−1} → R with E(g − f)² ≤ OPT + ε.

SLIDE 47

Learning low-degree NNs via gradient descent

Notation: f^{(≤k)} = f projected to degree-≤k spherical harmonics.

Problem: Fix k ∈ N, let OPT = E(f − f^{(≤k)})². Goal: find g : S^{n−1} → R with E(g − f)² ≤ OPT + ε.

Theorem (V.–Wilmes 2018): For m = n^{O(k)}/poly(ε) units, if g is a randomly initialized 1-hidden-layer NN with m gates and a linear output layer, then n^{O(k)} · log(1/ε) rounds of GD give E(f − g)² < OPT + ε whp.

SLIDE 48

The Funk transform

Gradient for the model g(x) = a · σ(u · x): E_x ∇_a (f − g)² = E_x[(f − g) σ(u · x)] (up to constants). Definition (Funk transform): T_σ h(u) = E_{x ∈ S^{n−1}}[h(x) σ(u · x)]


σ(u · x) = 1_{u·x ≥ 0}: T_σ is the hemispherical transform.

σ(u · x) = 1_{u·x = 0}: T_σ is the Radon transform.

With h = f − g: T_σ(f − g) = E_x ∇_a (f − g)².
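A Monte Carlo sketch of the Funk transform as defined above; the sample count, vectorized query h, and the hemispherical example are my choices:

```python
import numpy as np

def funk_transform(h, u, sigma, n, num_samples=100_000, seed=0):
    """Monte Carlo estimate of T_sigma h(u) = E_{x ~ S^{n-1}} h(x) sigma(u . x).
    h and sigma are assumed vectorized over their inputs."""
    rng = np.random.default_rng(seed)
    X = rng.normal(size=(num_samples, n))
    X /= np.linalg.norm(X, axis=1, keepdims=True)   # uniform on the sphere
    return np.mean(h(X) * sigma(X @ u))

# Hemispherical transform of the degree-1 harmonic h(x) = x_1:
est = funk_transform(lambda X: X[:, 0], np.eye(3)[0],
                     lambda z: (z >= 0).astype(float), n=3)
print(est)
```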

SLIDE 49

The Funk transform

[Figure: an original function and its Funk transform, with their harmonic spectra at degrees 1, 2, 3.]

SLIDE 50

Funk–Hecke Theorem

Theorem (Funk–Hecke formula): If P is a degree-k spherical harmonic on S^{n−1}, then T_σ P(u) = α_k(σ) · P(u) for some constant α_k(σ) = n^{−Θ(k)}. A classic result on spherical harmonics, over a century old.

SLIDE 51

Double Funk

Lemma: g_{i+1} − g_i ≈ T_σ T_σ (f − g_i)

Corollary (from Funk–Hecke): If f − g_i ≈ (f − g_i)^{(≤k)}, then

‖g_{i+1} − g_i‖₂² ≈ α_{n,k}(σ)² · ‖f − g_i‖₂²

Update rule: a_u ← a_u + T_σ(f − g_i)(u), i.e.,

g_{i+1} = g_i + (1/m) Σ_u T_σ(f − g_i)(u) · σ(u · x)

SLIDE 52

Updating only the top-level weights suffices

Pick u at random and fix it. Update: a_u ← a_u + T_σ(f − g)(u).

Reminiscent of "random kitchen sinks" [Rahimi–Recht 2009]. There the guarantee is w.r.t. functions of similar support as the sampling density of u; here it is w.r.t. the best low-degree approximation. And it is nearly optimal for every k, for all SQ algorithms! [VW18] [Proof idea: pick random harmonics; they are highly uncorrelated even over small intervals of the range.]
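A random-features-style sketch of this procedure, estimating T_σ(f − g)(u) from a finite sample; the function name, step size, and round count are illustrative:

```python
import numpy as np

def train_top_layer(X, y, sigma, m, eta=0.5, rounds=200, seed=0):
    """Fix m random unit vectors u; train only the output weights via
    a_u <- a_u + eta * (empirical) T_sigma(f - g)(u)."""
    rng = np.random.default_rng(seed)
    U = rng.normal(size=(m, X.shape[1]))
    U /= np.linalg.norm(U, axis=1, keepdims=True)   # random fixed units
    H = sigma(X @ U.T)                              # features sigma(u . x)
    a = np.zeros(m)
    for _ in range(rounds):
        r = y - H @ a                               # residual f - g on the sample
        a += eta * H.T @ r / len(X)                 # estimate of T_sigma(f-g)(u)
    return U, a
```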

SLIDE 53

Research Directions

1. Is gradient descent faster in the unbiased setting, realizable by a small NN?
2. Does having many layers help computational efficiency? (We know it can help representation size.)
3. Why does training work even with less data than the number of parameters?
4. What (distribution, function) pairs can NNs learn provably efficiently?
5. What real-world structure makes learning easier? E.g., do generic weights in a NN make the corresponding loss function "close" to convex?
6. How does the brain learn? (Is it actually robust?)

SLIDE 54

Thank you!

SLIDE 55

Learning low-degree NNs via gradient descent

Corollary: Data labeled by an unbiased 1-hidden-layer sigmoid NN, with hidden-layer weights of 2-norm ≤ a and output-layer weights of 1-norm ≤ b, can be learned to accuracy ε by training a NN with m = n^{O(b·log(ab)/ε)} units via gradient descent.
