SLIDE 1 On The Complexity of Training a Neural Network
Santosh Vempala
Algorithms and Randomness Center Georgia Tech
The Complexity of Learning Neural Networks
SLIDE 2 Deep learning’s successes are incredible
Do you want to beat 9-dan Go players?
SLIDE 3 Deep learning’s successes are incredible
Do you want to beat 9-dan Go players? classify images more accurately than humans?
SLIDE 4 Deep learning’s successes are incredible
Do you want to beat 9-dan Go players? classify images more accurately than humans? recognize speech? recommend movies? drive an autonomous vehicle?
SLIDE 5 Deep learning conquers the world
Do you want to beat 9-dan Go players? classify images more accurately than humans? recognize speech? recommend movies? drive an autonomous vehicle? publish a paper in NIPS? Try deep learning!
SLIDE 6 The Learning Problem
Problem: Given labeled samples (x, f(x)) where x ∼ D over R^n and f : R^n → R, find g : R^n → R s.t. E_D[(f(x) − g(x))²] ≤ ε.
SLIDE 7 Deep Learning for a theoretician
Want to approximate concept f. How to choose model g? A simple “neural network” (NN):
x ∼ D over R^n, y1 = σ(W1 · x + b1), y2 = σ(W2 · x + b2), g(x) = W3 · y + b3
(Diagram: input layer, hidden layer, output layer.)
SLIDE 8 Deep Learning for a theoretician
Want to approximate concept f. How to choose model g? A simple “neural network” (NN):
x ∼ D over R^n, y1 = σ(W1 · x + b1), y2 = σ(W2 · x + b2), g(x) = W3 · y + b3
e.g., sigmoid activation: σ(x) = 1 / (1 + e^(−x)); the W’s and b’s are the weights.
SLIDE 9 Deep Learning for a theoretician
How to “train” the network, i.e., choose W, b? Gradient descent: estimate the gradient from samples, update weights, repeat.
Labeled data (x1, f(x1)), (x2, f(x2)), …; estimate the gradient E_x ∇_W (f − g)²; update W ← W − ∇_W.
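To make the loop above concrete, here is a minimal numerical sketch (not from the talk): minibatch gradient descent on the squared loss for a one-hidden-layer sigmoid network, with data labeled by a small target network. The dimensions, step size, and target are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

n, m = 20, 50                                  # input dimension, hidden units
Wt = rng.normal(size=(5, n)) / np.sqrt(n)      # hidden weights of an unknown target network

def f(x):
    """Target concept; here itself a small sigmoid network (the realizable case)."""
    return sigmoid(x @ Wt.T).sum(axis=1)

W1 = rng.normal(size=(m, n)) / np.sqrt(n)      # model weights to be trained
b1 = np.zeros(m)
W2 = rng.normal(size=m) / np.sqrt(m)
b2 = 0.0

eta = 0.1
for step in range(2000):
    x = rng.normal(size=(64, n))               # fresh samples x ~ D = N(0, I_n)
    y = f(x)                                   # labels f(x)
    h = sigmoid(x @ W1.T + b1)                 # hidden activations
    r = h @ W2 + b2 - y                        # residual g(x) - f(x)
    # gradients of the empirical squared loss mean((g - f)^2)
    gW2 = (2 / len(x)) * (h.T @ r)
    gb2 = 2 * r.mean()
    gpre = (2 / len(x)) * np.outer(r, W2) * h * (1 - h)   # grad w.r.t. pre-activations
    gW1, gb1 = gpre.T @ x, gpre.sum(axis=0)
    W1 -= eta * gW1; b1 -= eta * gb1
    W2 -= eta * gW2; b2 -= eta * gb2

print(float(np.mean((sigmoid(x @ W1.T + b1) @ W2 + b2 - y) ** 2)))   # final minibatch loss
```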
SLIDE 10 Guarantees for deep learning?
Goal Provable guarantees for NN training algorithms
SLIDE 11 Guarantees for deep learning?
Goal Provable guarantees for NN training algorithms when data generated by a NN
SLIDE 12 Guarantees for deep learning?
Goal Provable guarantees for NN training algorithms when data generated by a NN
Theorem (Cybenko 1989) Continuous functions can be approximated by 1-hidden-layer NNs with sigmoids
SLIDE 13 Guarantees for deep learning?
Goal Provable guarantees for NN training algorithms when data generated by a small one-hidden-layer NN
Theorem (Cybenko 1989) Continuous functions can be approximated by 1-hidden-layer NNs with sigmoids
SLIDE 14
Provable guarantees for deep learning?
What could the form of such guarantees be? Under what conditions (on input distribution, function) does Stochastic Gradient Descent work? Does it help if the data is generated by a NN? (Is the “realizable” case easier?)
SLIDE 15
Guarantees for deep learning?
Lower bounds for the realizable case:
- NP-hard to train a neural network with 3 threshold neurons (A. Blum–Rivest 1993)
- Complexity/crypto assumptions ⇒ cannot efficiently learn small-depth networks (even improperly) (Klivans–Sherstov, 2006), (Daniely–Linial–Shalev-Schwartz, 2014)
- Even with nice input distributions, some deep learning algorithms don’t work (Shamir, 2016)
SLIDE 16 Outline
- 1. Lower bounds for learning neural networks
“You can’t efficiently learn functions computed by small, single-hidden-layer neural networks, even over nice input distributions.”
- 2. Polynomial-time analysis of gradient descent
“Gradient descent can efficiently train single-hidden-layer neural networks of unbiased activation units.”
SLIDE 17 A nice family of neural networks
(Diagram: n-dim input with i.i.d. N(0,1) coordinates → one hidden layer of Õ(n) sigmoid units → linear output.)
SLIDE 18
Smooth activation functions
The sigmoid function of sharpness s is σ_s(x) = e^(sx) / (1 + e^(sx)).
(Plot: σ_s for s = 1, 4, 12.)
SLIDE 19 A little more generality
(Diagram: n-dim input with independent logconcave coordinates → hidden layer of ReLU, PReLU, softplus, sigmoid, ... units → linear output.)
SLIDE 20
Use deep learning!
Choose your favorite network architecture, activation units, loss function, gradient descent variant, regularization scheme. . .
SLIDE 21
A computational lower bound
Choose your favorite network architecture, activation units, loss function, gradient descent variant, regularization scheme, etc.
“Theorem” If using only “black box” functions of the input (e.g., gradients via TensorFlow, Hessians via Autograd), you need 2^Ω(n/s²) function evaluations of accuracy at most 1/(s√n).
SLIDE 22 A little more context
(Diagram: n-dim input with i.i.d. N(0,1) coordinates → Õ(n) sigmoid units → linear output.)
Janzamin–Sedghi–Anandkumar (2015): tensor decomposition algorithm with additional assumptions; sample size = poly(n, condition number of weights).
Shamir (2016) gives exponential lower bounds against “vanilla” SGD with mean-squared loss, ReLU units, Gaussian input (nonrealizable, but similar construction).
More recent improvements on upper bounds (coming up).
SLIDE 23 A little more generality
(Diagram: n-dim input with independent logconcave coordinates → ReLU, PReLU, softplus, sigmoid, ... units → linear output.)
Lower bound applies to algorithms of the following form:
1. Estimate v = E_{(x,y)∼D}[h(W, x, y)], where W: current weights, (x, y): labeled example from input distribution D, h: arbitrary [0, 1]-valued function.
SLIDE 24
The hard family of functions
φ(x) = σ(1/s + x) + σ(1/s − x) − 1
F_σ(x) = φ(x) + φ(x − 2/s) + φ(x + 2/s) + ···
F_σ : R → R is an affine combination of σ-units, almost periodic on [−Õ(n), Õ(n)] with period 1/s.
SLIDE 25 The hard family of functions
φ₀(x) = r(x + 1/(2s)) − r(x − 1/(2s))
φ(x) = φ₀(1/(2s) + x) + φ₀(1/(2s) − x) − 1
F_r(x) = φ(x) + φ(x − 2/s) + φ(x + 2/s) + ···
SLIDE 26 The hard family of functions
(Diagram: input coordinates drawn from a logconcave product distribution, ±1 weights selecting a subset S, feeding into F_σ.)
F_σ : R → R is an affine combination of σ-units, almost periodic on [−Õ(n), Õ(n)] with period 1/s.
For every S ⊆ {1, ..., n} with |S| = n/2, define f_S(x) = F_σ( Σ_{i∈S} x_i ).
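A small numerical sketch (not from the talk) of this construction: φ is built from a pair of s-sharp sigmoids, F_σ sums shifted copies of φ, and f_S applies F_σ to the sum of the coordinates in S. The specific n, s, and number of copies K below are illustrative.

```python
import numpy as np

def sigma(z, s):
    """s-sharp sigmoid: exp(s*z) / (1 + exp(s*z))."""
    return 1.0 / (1.0 + np.exp(-s * z))

def phi(z, s):
    """Bump made from two sigmoids: sigma(1/s + z) + sigma(1/s - z) - 1."""
    return sigma(1.0 / s + z, s) + sigma(1.0 / s - z, s) - 1.0

def F(z, s, K):
    """Almost-periodic F_sigma: copies of phi shifted by multiples of 2/s
    (truncated to 2K+1 terms; enough copies to cover the relevant range)."""
    return sum(phi(z - 2.0 * k / s, s) for k in range(-K, K + 1))

def f_S(x, S, s, K):
    """Hard function for a subset S: F_sigma applied to the sum of coordinates in S."""
    return F(x[:, S].sum(axis=1), s, K)

rng = np.random.default_rng(1)
n, s = 100, 2.0
K = int(s * n)                                  # enough periods for the range of the sum
S = rng.choice(n, size=n // 2, replace=False)   # the hidden subset
x = rng.normal(size=(5, n))                     # logconcave product distribution, e.g. N(0, I)
print(f_S(x, S, s, K))
```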
SLIDE 27 Throw some “deep learning” at it!
Choose your favorite: network architecture, activation units, loss function, gradient descent variant, regularization scheme, ...
SLIDE 28 Theory vs. practice, revisited
(Plot: training error vs. s·√n, for networks with 50, 100, 200, 500, and 1000 sigmoid units; error axis 0.000–0.040.)
SLIDE 29
Theory vs. practice, revisited
SLIDE 30 Throw some “deep learning” at it!
Choose your favorite: network architecture, activation units, loss function, gradient descent variant, regularization scheme, ... Were you successful?
SLIDE 31 Try “deep learning”!
Choose your favorite: network architecture, activation units, loss function, gradient descent variant, regularization scheme, ... Were you successful? “Theorem” No!
SLIDE 32 Statistical Query Algorithms
Recall the gradient descent training algorithm: labeled data (x1, f(x1)), (x2, f(x2)), …; estimate the gradient E_x ∇_W (f − g)²; update weights W ← W − ∇_W.
SLIDE 33 Statistical Query Algorithms
Recall the gradient descent training algorithm: query the gradient E_x ∇_W (f − g)²; update weights W ← W − ∇_W. Only a gradient estimate is needed, not the labeled examples themselves.
SLIDE 34 Statistical Query Algorithms
(Diagram: the algorithm interacts with an oracle.)
Statistical query (SQ) algorithms introduced by Kearns in 1993. No direct access to samples; the algorithm queries expectations of functions on the labeled-example distribution.
Query (h, τ): h : R^n × R → [0, 1], τ > 0. Response v: |E[h] − v| < τ.
E.g., for gradient descent, query h = ∇_W (f − g)².
SLIDE 35 Statistical Query Algorithms
SQ algorithms are extremely general. Almost all “robust” machine learning guarantees can be achieved with SQ algorithms.
SLIDE 36 Statistical Query Algorithms
Statistical algorithms: no direct interaction with labeled examples.
Definition [Feldman–Grigorescu–Reyzin–V.–Xiao 2013]: A statistical query algorithm interacts with the input distribution (X, D) only by evaluating expectations of bounded functions. More precisely, for any query function h : X → [0, 1], with p = E[h(X)], and any integer t > 0, the oracle VSTAT(t) returns v s.t. |v − p| ≤ max( 1/t, √(p(1−p)/t) ) (the error of t Bernoulli coin tosses, each of bias p).
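As a rough illustration (not part of the original formalism), a VSTAT(t) oracle can be simulated from data by answering each query with an empirical mean over about t labeled examples; the sampling distribution and query below are made up for the example.

```python
import numpy as np

rng = np.random.default_rng(2)

def vstat(h, t, sample_xy):
    """Illustrative VSTAT(t) oracle: answer E[h(x, y)] for h : X x Y -> [0, 1] with the
    empirical mean over t fresh labeled examples.  By standard concentration this is
    within roughly max(1/t, sqrt(p(1-p)/t)) of the true expectation p, matching the
    guarantee in the definition up to constants (with high probability)."""
    x, y = sample_xy(t)
    return float(np.mean(h(x, y)))

def sample_xy(t, n=10):
    """Stand-in labeled-example distribution: x ~ N(0, I_n), y a 0/1 label."""
    x = rng.normal(size=(t, n))
    y = (x.sum(axis=1) > 0).astype(float)
    return x, y

h = lambda x, y: (y > x[:, 0]).astype(float)    # an example [0,1]-valued query
print(vstat(h, t=10_000, sample_xy=sample_xy))
```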
SLIDE 37
Statistical Query Algorithms
Framework introduced by (Kearns, 1993) with the STAT oracle. VSTAT, introduced by (FGRVX, 2013), gives better bounds. Known algorithms for supervised/unsupervised learning can be viewed as SQ algorithms. Only known exception: parity learning via Gaussian elimination.
Nearly tight SQ lower bounds for:
- Learning (noisy) parity (Kearns, 1993), (Blum–Furst–Jackson–Kearns–Mansour–Rudich, 1994)
- Detecting bipartite planted cliques (Feldman–Grigorescu–Reyzin–Vempala–Xiao, 2012)
- Planted SAT/CSP (Feldman–Perkins–Vempala, 2015)
- Robustly learning a Gaussian (Diakonikolas–Kane–Stewart, 2017)
SLIDE 38
Lower bound
Theorem [Song-V.-Wilmes-Xie 2017]
There exist functions f : R^n → R realizable as NNs with one hidden layer of s-sharp sigmoids, such that for any logconcave product distribution on inputs, any statistical algorithm that learns f needs 2^Ω(n/s²) queries to VSTAT(s²n).
SLIDE 39 Statistical dimension
Informal definition: A function family C has dimension d with average correlation γ if for every C′ ⊆ C with |C′| ≥ |C|/d, a random pair f, f′ ∈ C′ has expected correlation Correlation(f, f′) < γ.
Statistical dimension d ⇒ need Ω(d) queries. The correlation γ corresponds to the query tolerance.
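A quick empirical sanity check (not from the talk) of the “nearly uncorrelated family” picture, reusing the F_σ construction from the earlier slides; the parameters and sample size are illustrative.

```python
import numpy as np

rng = np.random.default_rng(3)
n, s, K = 50, 2.0, 100

def F(z):
    """Truncated almost-periodic F_sigma, as in the construction above."""
    sig = lambda u: 1.0 / (1.0 + np.exp(-s * u))
    phi = lambda u: sig(1.0 / s + u) + sig(1.0 / s - u) - 1.0
    return sum(phi(z - 2.0 * k / s) for k in range(-K, K + 1))

def corr(S, T, x):
    """Empirical correlation of f_S and f_T over a sample from the input distribution."""
    a, b = F(x[:, S].sum(axis=1)), F(x[:, T].sum(axis=1))
    a, b = a - a.mean(), b - b.mean()
    return float((a * b).mean() / (a.std() * b.std()))

x = rng.normal(size=(5000, n))                              # x ~ N(0, I_n)
S, T = (rng.choice(n, n // 2, replace=False) for _ in range(2))
print(corr(S, S, x))   # 1.0: any function is perfectly correlated with itself
print(corr(S, T, x))   # close to 0: two random members of the family are nearly uncorrelated
```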
SLIDE 40 Statistical dimension
(Diagram: the hypothesis space and a “good” query h.)
SLIDE 41 Statistical dimension
(Diagram: a “good” query h vs. an arbitrary query when the statistical dimension is large.)
SLIDE 42 Statistical dimension
Statistical dimension bounds ⇒ query complexity bounds. Here we extend the concept and proof to regression problems.
SLIDE 43 The lower bound
Not enough to make generic assumptions about network size and input distribution. What are candidate sufficient conditions?
SLIDE 44 Upper bounds
Can learn neural nets in polynomial time:
- non-GD algorithms: single layer, with strong non-degeneracy assumptions and access to the density function (Janzamin–Sedghi–Anandkumar 2015); two layers, with weaker non-degeneracy assumptions (Goel–Klivans 2017, Goel–Klivans–Meka 2018)
- using GD: with careful regularization (Ge–Lee–Ma 2017); for non-overlapping convolutional nets (Zhong–Song–Dhillon 2017) and (Du–Lee–Tian–Póczos–Singh 2017); for ResNet-style architectures (Li–Yuan 2017); for learning polynomials using e^z activation units (Andoni–Panigrahy–Valiant–Zhang 2014)
- Vanilla GD with weak assumptions? A kernel framework (Daniely 2017)
SLIDE 45 What is “non-degenerate” enough?
f(x) = Σ_u a_u σ(u · x), where ||a||₁ ≤ C and x ∼ uniform over S^(n−1) (or spherically symmetric).
(Diagram: x ∈ S^(n−1) uniform → unbiased σ units → linear output.)
Assumptions guarantee f is approximately low-degree
SLIDE 46 Learning low-degree NNs via gradient descent
Notation: f^(≤k) = f projected to a degree-≤k polynomial over the sphere.
Problem: Fix k ∈ N, let OPT = E(f − f^(≤k))². Goal: find g : S^(n−1) → R with E(g − f)² ≤ OPT + ε.
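Operationally, one can approximate the degree-≤k projection by least-squares regression of f onto all monomials of degree ≤ k over sampled sphere points; the sketch below does exactly that to estimate OPT. The target f and the parameters are stand-ins, and this is only an empirical surrogate for the exact harmonic projection.

```python
import numpy as np
from itertools import combinations_with_replacement

rng = np.random.default_rng(4)
n, k, N = 8, 2, 20000

def sphere(N, n):
    x = rng.normal(size=(N, n))
    return x / np.linalg.norm(x, axis=1, keepdims=True)    # uniform on S^(n-1)

def monomials(x, k):
    """Design matrix of all monomials of degree <= k in the coordinates of x."""
    cols = [np.ones(len(x))]
    for d in range(1, k + 1):
        for idx in combinations_with_replacement(range(x.shape[1]), d):
            cols.append(np.prod(x[:, idx], axis=1))
    return np.column_stack(cols)

f = lambda x: np.tanh(3 * x[:, 0]) + x[:, 1] * x[:, 2]     # stand-in target f

x = sphere(N, n)
A, y = monomials(x, k), f(x)
coef, *_ = np.linalg.lstsq(A, y, rcond=None)               # least-squares fit = empirical f^(<=k)
opt = float(np.mean((y - A @ coef) ** 2))                  # empirical OPT = E (f - f^(<=k))^2
print(opt)
```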
SLIDE 47 Learning low-degree NNs via gradient descent
Notation: f^(≤k) = f projected to degree-≤k spherical harmonics.
Problem: Fix k ∈ N, let OPT = E(f − f^(≤k))². Goal: find g : S^(n−1) → R with E(g − f)² ≤ OPT + ε.
Theorem (V.–Wilmes 2018): For m = n^O(k)/poly(ε) units, if g is a randomly initialized 1-hidden-layer NN with m gates and a linear output layer, then n^O(k) log(1/ε) rounds of GD give E(f − g)² < OPT + ε whp.
SLIDE 48 The Funk transform
Gradient for the model g(x) = a · σ(u · x): E_x ∇_a (f − g)² = E_x[(f − g) σ(u · x)] (up to a constant factor).
Definition (Funk transform): T_σ h(u) = E_{x∈S^(n−1)}[h(x) σ(u · x)].
If σ(u · x) = 1_{u·x ≥ 0}: T_σ is the hemispherical transform.
If σ(u · x) = 1_{u·x = 0}: T_σ is the Radon transform.
With h = f − g: T_σ(f − g)(u) = E_x ∇_a (f − g)², i.e., the gradient coordinate for unit u is the Funk transform of the residual.
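A Monte Carlo sketch (not from the talk) of the Funk transform and of the special case just mentioned: with σ(u · x) = 1_{u·x ≥ 0} (the hemispherical transform), a degree-1 harmonic P(x) = x₁ is simply rescaled, previewing the Funk–Hecke scaling described below. Dimensions and sample size are illustrative.

```python
import numpy as np

rng = np.random.default_rng(5)
n, N = 10, 200_000

def sphere(N, n):
    x = rng.normal(size=(N, n))
    return x / np.linalg.norm(x, axis=1, keepdims=True)

def funk(h, sigma, u, x):
    """Monte Carlo estimate of (T_sigma h)(u) = E_{x ~ unif(S^(n-1))}[ h(x) * sigma(u . x) ]."""
    return float(np.mean(h(x) * sigma(x @ u)))

x = sphere(N, n)
u = np.zeros(n); u[0] = 1.0

step = lambda z: (z >= 0).astype(float)     # sigma(u.x) = 1_{u.x >= 0}: hemispherical transform
P = lambda x: x[:, 0]                       # P(x) = x_1, a degree-1 spherical harmonic
print(funk(P, step, u, x))                  # ~  lambda_1 * P(u) =  lambda_1
print(funk(P, step, -u, x))                 # ~ -lambda_1: the transform just rescales P
```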
SLIDE 49 The Funk transform
(Figure: an example function and its harmonic spectrum, alongside its Funk transform and that transform’s harmonic spectrum, for degrees 1, 2, 3.)
SLIDE 50 Funk–Hecke Theorem
Theorem (Funk–Hecke Formula): If P is a degree-k spherical harmonic on S^(n−1), then T_σ P(u) = α_k(σ) P(u) for some constant α_k(σ) = n^(−Θ(k)). Classic result on spherical harmonics, over a century old.
SLIDE 51
Double Funk
Update: a_u ← a_u + T_σ(f − g_i)(u), so g_{i+1} = g_i + (1/m) Σ_u T_σ(f − g_i)(u) σ(u · x).
Lemma: g_{i+1} − g_i ≈ T_σ T_σ (f − g_i).
Corollary (from Funk–Hecke): If f − g_i ≈ (f − g_i)^(≤k), then ||g_{i+1} − g_i||₂² ≈ α_{n,k}(σ)² ||f − g_i||₂².
SLIDE 52
Updating only the top-level weights suffices
Pick the directions u at random and fix them. Update: a_u ← a_u + T_σ(f − g)(u).
Reminiscent of “random kitchen sinks” [Rahimi–Recht 2009]. There the guarantee is w.r.t. functions of similar support as the sampling density of u; here it is w.r.t. the best low-degree approximation. And it is nearly optimal for every k for all SQ algorithms! [VW18]
[Proof idea: pick random harmonics; they are highly uncorrelated even over small intervals of the range.]
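A minimal random-features sketch of this update rule: hidden directions u are picked at random and frozen, and gradient descent is run only on the top-level weights a_u over a fixed sample. The activation (a centered sigmoid), width, step size, and target are illustrative choices, not the ones from [VW18].

```python
import numpy as np

rng = np.random.default_rng(6)
n, m, N = 10, 500, 2000

def sphere(N, n):
    x = rng.normal(size=(N, n))
    return x / np.linalg.norm(x, axis=1, keepdims=True)

sigma = lambda z: 1.0 / (1.0 + np.exp(-z)) - 0.5     # centered ("unbiased") sigmoid unit

x = sphere(N, n)                                     # inputs uniform on S^(n-1)
f = np.tanh(2 * x[:, 0]) * x[:, 1]                   # stand-in target f(x)

U = sphere(m, n)                                     # random directions u, picked once and fixed
H = sigma(x @ U.T)                                   # frozen hidden features sigma(u . x)
a = np.zeros(m)                                      # only the top-level weights are trained

eta = 2.0 / m
for _ in range(2000):
    r = H @ a - f                                    # residual g - f on the sample
    a -= eta * (2 / N) * (H.T @ r)                   # gradient step on the empirical E (g - f)^2
print(float(np.mean((H @ a - f) ** 2)))              # training squared error
```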
SLIDE 53 Research Directions
1. Is gradient descent faster in the unbiased setting, realizable by a small NN?
2. Does having many layers help computational efficiency? (We know it can help representation size.)
3. Why does training work even with less data than the number of parameters?
4. What (distribution, function) pairs can NNs learn provably efficiently?
5. What real-world structure makes learning easier? E.g., do generic weights in a NN make the corresponding loss function “close” to convex?
6. How does the brain learn? (Is it actually robust?)
SLIDE 54 Thank you!
SLIDE 55 Learning low-degree NNs via gradient descent
Corollary: Data labeled by an unbiased 1-hidden-layer sigmoid NN, with hidden-layer weights of 2-norm ≤ a and output-layer weights of 1-norm ≤ b, can be learned to accuracy ε by training a NN with m = n^O(b log(ab)/ε) units via gradient descent.