

SLIDE 1

Direction Matters: On the Implicit Regularization Effect of Stochastic Gradient Descent with Moderate Learning Rate

Jingfeng Wu, Difan Zou, Vladimir Braverman, Quanquan Gu Johns Hopkins University & UCLA

November 2020

SLIDE 2

Overview

  • Background
  • SGD vs. GD: Different Convergence Directions
  • Small Learning Rate
  • Moderate Learning Rate
  • Direction Matters: SGD + Moderate LR is Good
  • Proof Sketches
SLIDE 3

Implicit Regularization: SGD >> GD

CIFAR-10, ResNet-18, w/o weight decay, w/o data augmentation

SGD: $w \leftarrow w - \eta \nabla \ell_k(w)$        GD: $w \leftarrow w - \eta \nabla L_S(w)$

Training loss: $L_S(w) = \frac{1}{n} \sum_{i=1}^{n} \ell_i(w)$

Wu, Jingfeng, et al. "On the Noisy Gradient Descent that Generalizes as SGD." ICML 2020.
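To make the two update rules concrete, here is a minimal NumPy sketch (not the slide's CIFAR-10 experiment; the quadratic per-sample losses, sizes, and learning rate are illustrative) contrasting one SGD step on a single sampled loss $\ell_k$ with one GD step on the averaged loss $L_S$.

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative per-sample losses: l_i(w) = (x_i^T w - y_i)^2, so L_S(w) = (1/n) sum_i l_i(w).
n, d = 8, 3
X = rng.normal(size=(n, d))
y = rng.normal(size=n)

def grad_l(w, i):
    """Gradient of the single-sample loss l_i(w) = (x_i^T w - y_i)^2."""
    return 2.0 * (X[i] @ w - y[i]) * X[i]

def grad_LS(w):
    """Gradient of the averaged training loss L_S(w) = (1/n) * sum_i l_i(w)."""
    return np.mean([grad_l(w, i) for i in range(n)], axis=0)

eta = 0.05
w = np.zeros(d)

k = rng.integers(n)                 # SGD uses one sampled index k per step
w_sgd = w - eta * grad_l(w, k)      # SGD: w <- w - eta * grad l_k(w)
w_gd = w - eta * grad_LS(w)         # GD:  w <- w - eta * grad L_S(w)

print("one SGD step:", w_sgd)
print("one GD  step:", w_gd)
```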

SLIDE 4

Two More Figures about SGD (Less Relevant)

[Figure: test accuracy (%) vs. iteration] Legend: GD (66.96), GLD const (66.66), GLD dynamic (69.25), GLD diag (67.96), SGD (75.21)

Wilson, Ashia C., et al. "The marginal value of adaptive gradient methods in machine learning." NIPS 2017. Zhu, Zhanxing, et al. "The Anisotropic Noise in Stochastic Gradient Descent: Its Behavior of Escaping from Sharp Minima and Regularization Effects." ICML 2019.

SLIDE 5

SGD vs. GD: Learning Rate Matters!

              Small LR              Moderate LR
  GD          generalizes poorly    generalizes poorly
  SGD         generalizes poorly    generalizes well

Q1: Small LR, SGD ≈ GD?
Q2: Moderate LR, SGD >> GD?
Q3: GD is bad anyhow?

SLIDE 6

In Theory, SGD ≈ GD or SGD ≠ GD ??

Theory disagrees with practice ☹

  • “Easy” to prove SGD ≈ GD by concentration <= e.g., small LR
  • “Hard” to prove an inverse result <= no concentration!

GD:  $w \leftarrow w - \eta \nabla L_S(w)$
SGD: $w \leftarrow w - \eta \nabla L_S(w) + \eta \left( \nabla L_S(w) - \nabla \ell_k(w) \right)$

Unbiased noise (scales with $\eta$)
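The decomposition above says an SGD step is a GD step plus a noise term $\eta(\nabla L_S(w) - \nabla \ell_k(w))$ that is mean-zero over the random index $k$ and scales with $\eta$. A minimal numerical check, with illustrative quadratic per-sample losses rather than anything from the paper:

```python
import numpy as np

rng = np.random.default_rng(1)

n, d = 10, 4
X = rng.normal(size=(n, d))
y = rng.normal(size=n)

def grad_l(w, i):
    # gradient of the illustrative per-sample loss l_i(w) = (x_i^T w - y_i)^2
    return 2.0 * (X[i] @ w - y[i]) * X[i]

def grad_LS(w):
    # gradient of L_S(w) = (1/n) sum_i l_i(w)
    return np.mean([grad_l(w, i) for i in range(n)], axis=0)

eta = 0.1
w = rng.normal(size=d)

# Noise term of an SGD step, for each possible index k: eta * (grad L_S(w) - grad l_k(w)).
noise = np.array([eta * (grad_LS(w) - grad_l(w, k)) for k in range(n)])

print("mean of the noise term over k:", noise.mean(axis=0))       # ~ 0: unbiased
print("typical noise magnitude (scales with eta):",
      np.linalg.norm(noise, axis=1).mean())
```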

SLIDE 7

Small Learning Rate: SGD ≈ GD

GD:  $w \leftarrow w - \eta \nabla L_S(w)$  →  Gradient Flow (GF):  $\mathrm{d}W_t = -\nabla L_S(W_t)\,\mathrm{d}t$

SGD: $w \leftarrow w - \eta \nabla L_S(w) + \eta \left( \nabla L_S(w) - \nabla \ell_k(w) \right)$  →  Stochastic Modified Equation (SME):  $\mathrm{d}W_t = -\nabla L_S(W_t)\,\mathrm{d}t + \sqrt{\eta}\, \Sigma(W_t)^{1/2}\, \mathrm{d}B_t$

As $\eta = \mathrm{d}t \to 0$ the noise is a higher-order term: each step's noise has variance $O(\eta^2)$, so the noise accumulated over the $O(1/\eta)$ steps of a unit time interval has variance $O(\eta)$, i.e. magnitude $O(\sqrt{\eta}) \to 0$.
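A quick sanity check of the small-LR regime (the problem data, horizon, and sampling are all illustrative): run GD and SGD from the same start for a fixed amount of "time" $t = \eta T$ and watch the gap between them shrink as $\eta$ decreases, consistent with the $O(\sqrt{\eta})$ noise scale in the SME view.

```python
import numpy as np

rng = np.random.default_rng(2)

n, d = 20, 5
X = rng.normal(size=(n, d)) / np.sqrt(d)
y = rng.normal(size=n)

def grad_l(w, i):
    # gradient of the illustrative per-sample loss l_i(w) = (x_i^T w - y_i)^2
    return 2.0 * (X[i] @ w - y[i]) * X[i]

def grad_LS(w):
    return np.mean([grad_l(w, i) for i in range(n)], axis=0)

w0 = rng.normal(size=d)

for eta in (0.1, 0.01, 0.001):
    T = int(1.0 / eta)                       # fixed horizon in "time": eta * T = 1
    w_gd, w_sgd = w0.copy(), w0.copy()
    step_rng = np.random.default_rng(3)      # same sampling sequence for each eta
    for _ in range(T):
        w_gd = w_gd - eta * grad_LS(w_gd)
        k = step_rng.integers(n)
        w_sgd = w_sgd - eta * grad_l(w_sgd, k)
    print(f"eta = {eta:6.3f}   ||w_SGD - w_GD|| = {np.linalg.norm(w_sgd - w_gd):.4f}")
```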

SLIDE 8

Effects of Non-Small Learning Rate

SGD: $w \leftarrow w - \eta \nabla \ell_k(w)$, which diverges on component $k$ when $\eta \ge 2/\|H_k\|_2$
GD:  $w \leftarrow w - \eta \nabla L_S(w)$, which converges when $\eta < 2/\|H\|_2$

(For a quadratic $L(w) = \tfrac{1}{2}\, w^\top H w$, the iterate is $w_{t+1} = (I - \eta H)\, w_t$.)

The averaged loss $L_S(w) = \frac{1}{n} \sum_{i=1}^{n} \ell_i(w)$ is always smooth relative to such a learning rate, but an individual $\ell_i$ can be non-smooth.

SGD + moderate and annealing LR

  • Phase 1: moderate LR => fits smooth losses
  • Phase 2: small LR => fits non-smooth losses
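A minimal numeric illustration of the threshold used above (the matrix $H$ and the step count are arbitrary): iterating $w_{t+1} = (I - \eta H)\, w_t$ converges when $\eta < 2/\|H\|_2$ and diverges when $\eta \ge 2/\|H\|_2$.

```python
import numpy as np

H = np.diag([10.0, 1.0])                        # an arbitrary quadratic L(w) = 0.5 * w^T H w
threshold = 2.0 / np.linalg.eigvalsh(H).max()   # 2 / ||H||_2 = 0.2

w0 = np.array([1.0, 1.0])
for eta in (0.19, 0.21):                        # just below / just above the threshold
    w = w0.copy()
    for _ in range(200):
        w = (np.eye(2) - eta * H) @ w           # w_{t+1} = (I - eta * H) w_t
    verdict = "converges" if eta < threshold else "diverges"
    print(f"eta = {eta}: ||w_200|| = {np.linalg.norm(w):.3e}  ({verdict})")
```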

SLIDE 9

A 2-D Example

$\ell_1(w) = 0.5\, w^\top H_1 w$, $H_1 = \mathrm{diag}(2\kappa, 0)$
$\ell_2(w) = 0.5\, w^\top H_2 w$, $H_2 = \mathrm{diag}(0, 2)$
$L(w) = 0.5\, w^\top H w$, $H = \mathrm{diag}(\kappa, 1)$, $\kappa > 2$

Moderate LR: $\eta_t = 1.1/\kappa$ for $t = 1, \dots, T_1$, then $\eta_t = 0.1/\kappa$ for $t = T_1 + 1, \dots, T_2$
Small LR: $\eta_t = 0.1/\kappa$ for $t = 1, \dots, T_2$

Same limits, but different convergence directions.
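A minimal simulation of this 2-D example as reconstructed above (the value of $\kappa$, the phase lengths $T_1, T_2$, and the cyclic sampling of $\ell_1, \ell_2$ are illustrative choices): both schedules drive $w_t$ toward the same limit $0$, but the normalized iterate ends up aligned with a different coordinate axis.

```python
import numpy as np

kappa = 4.0                                      # any kappa > 2
H1 = np.diag([2.0 * kappa, 0.0])                 # Hessian of l_1
H2 = np.diag([0.0, 2.0])                         # Hessian of l_2

def run(schedule):
    """SGD cycling through l_1, l_2; the gradient of 0.5 * w^T H w is H w."""
    w = np.array([1.0, 1.0])
    for t, eta in enumerate(schedule):
        H = H1 if t % 2 == 0 else H2
        w = w - eta * (H @ w)
    return w

T1, T2 = 40, 140                                 # illustrative phase lengths
moderate = [1.1 / kappa] * T1 + [0.1 / kappa] * (T2 - T1)
small = [0.1 / kappa] * T2

for name, schedule in [("moderate then annealed", moderate), ("small throughout", small)]:
    w = run(schedule)
    print(f"{name:>22}: ||w_T|| = {np.linalg.norm(w):.2e}, "
          f"direction = {np.round(w / np.linalg.norm(w), 3)}")
```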

SLIDE 10

A High Dimensional Linear Regression

Setups

  • Test data $x = z \cdot v \in \mathbb{R}^d$, where $z \in [0, 1]$ and $v \sim \mathrm{Unif}(S^{d-1})$
  • Loss $\ell(x; w) = (w - w^*)^\top x x^\top (w - w^*)$
  • Training data $X = (x_1, \dots, x_n)$, i.i.d., $d \gg n$

WLOG

  • Let $\lambda_i = \|x_i\|^2 \in (0, 1]$
  • Assume $\lambda_1 \ge \lambda_2 \ge \dots \ge \lambda_n$
  • Let $P$ be the projection onto the column space of $X$
  • $P^{\perp} = I - P$

Theorem 0:

  • There are multiple global minima of $L_S(w) = \frac{1}{n}\, (w - w^*)^\top X X^\top (w - w^*)$
  • The iterates of gradient methods belong to the hypothesis class $\mathcal{H}_S = \{w : P^{\perp} w = P^{\perp} w_0\}$
  • If a gradient method finds a global minimum, it is the one closest to the initialization

Remark: this is also known as the “minimal-norm solution”, since the initialization is usually zero
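A quick numerical check of the hypothesis-class claim in Theorem 0 (the dimensions, data, and step count are arbitrary): the gradient of $L_S$ always lies in the column space of $X$, so gradient descent never changes the component $P^{\perp} w$, i.e. $P^{\perp} w_t = P^{\perp} w_0$.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n = 50, 5                                     # overparameterized: d >> n
X = rng.normal(size=(d, n)) / np.sqrt(d)         # columns x_1, ..., x_n
w_star = rng.normal(size=d)

def grad_LS(w):
    # L_S(w) = (1/n) (w - w*)^T X X^T (w - w*), so grad L_S(w) = (2/n) X X^T (w - w*)
    return (2.0 / n) * X @ (X.T @ (w - w_star))

P = X @ np.linalg.pinv(X)                        # projection onto the column space of X
P_perp = np.eye(d) - P

w0 = rng.normal(size=d)
w = w0.copy()
eta = 0.5
for _ in range(2000):
    w = w - eta * grad_LS(w)

print("training loss L_S(w):", (w - w_star) @ X @ X.T @ (w - w_star) / n)  # ~ 0
print("||P_perp (w - w0)||: ", np.linalg.norm(P_perp @ (w - w0)))          # ~ 0
```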

SLIDE 11

A High Dimensional Linear Regression

Theorem 1 (informal): Consider SGD with a moderate then annealed LR,

  $\eta_t = \eta$ in a moderate range, roughly $\left( \tfrac{1}{2\lambda_n},\ \tfrac{1}{\lambda_1} \right)$, for $t = 1, \dots, T_0$, and $\eta_t = o(1)$ for $t = T_0 + 1, \dots, T_1$;

then $\frac{P(w - w^*)}{\| P(w - w^*) \|} \to v_{\max} \pm o(1)$.

Theorem 2 (informal): Consider GD with moderate or small LR,

  $\eta_t \in \left( 0,\ \tfrac{n}{2 \lambda_1} - o(1) \right)$ for $t = 1, \dots, T_1$;

then $\frac{P(w - w^*)}{\| P(w - w^*) \|} \to v_{\min} \pm o(1)$.

Rayleigh quotient: $R(XX^\top, u) = \frac{u^\top X X^\top u}{u^\top u}$

Remark: $v_{\max}$ ($v_{\min}$) is the eigenvector of $XX^\top$ with the largest (smallest nonzero) eigenvalue
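The Rayleigh quotient above is the natural yardstick for these directional statements: it equals the largest (smallest nonzero) eigenvalue of $XX^\top$ exactly when $u$ is $v_{\max}$ ($v_{\min}$). A small sketch of the quantity on random, purely illustrative data:

```python
import numpy as np

rng = np.random.default_rng(0)
d, n = 30, 4
X = rng.normal(size=(d, n)) / np.sqrt(d)
A = X @ X.T                                      # the matrix X X^T from the theorems

def rayleigh(A, u):
    """Rayleigh quotient R(A, u) = u^T A u / (u^T u)."""
    return (u @ A @ u) / (u @ u)

vals, vecs = np.linalg.eigh(A)                   # eigenvalues in ascending order
nonzero = vals > 1e-10
v_min = vecs[:, np.argmax(nonzero)]              # eigenvector of the smallest nonzero eigenvalue
v_max = vecs[:, -1]                              # eigenvector of the largest eigenvalue

print("R(XX^T, v_max) =", rayleigh(A, v_max), " largest eigenvalue =", vals[-1])
print("R(XX^T, v_min) =", rayleigh(A, v_min), " smallest nonzero eigenvalue =", vals[nonzero].min())
```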

SLIDE 12

Convergence Direction Matters

$L_D(w_{\mathrm{alg}}) - \inf_w L_D(w) = \underbrace{L_D(w_{\mathrm{alg}}) - \inf_{w' \in \mathcal{H}_S} L_D(w')}_{\Delta(w_{\mathrm{alg}}),\ \text{estimation error}} + \underbrace{\inf_{w' \in \mathcal{H}_S} L_D(w') - \inf_{w} L_D(w)}_{\text{approximation error}}$

The approximation error is intrinsic and not improvable; the estimation error is determined by the algorithm and its hyperparameters.

  • $\alpha$-level set: $\mathcal{W}_{\alpha} = \{w \in \mathcal{H}_S : L_S(w) = \alpha\}$
  • Optimal estimation error within a level set: $\Delta^*_{\alpha} = \min_{w \in \mathcal{W}_{\alpha}} \Delta(w)$

Theorem 3:

  • For SGD with moderate LR, $\Delta(w_{\mathrm{SGD}}) < (1 + o(1)) \cdot \Delta^*_{\alpha}$
  • For GD with moderate or small LR, $\Delta(w_{\mathrm{GD}}) > \left( \frac{\gamma_{\max}}{\gamma_{\min}} - o(1) \right) \cdot \Delta^*_{\alpha}$

Remark: $\gamma_{\max}$ ($\gamma_{\min}$) is the largest (smallest nonzero) eigenvalue of $XX^\top$

SLIDE 13

Proof Sketch of Theorem 3

[Figure: level sets of the test loss vs. the training loss]

Within a level set of the training loss, residual along a larger eigenvalue direction ⇒ smaller test loss
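A tiny numeric instance of this intuition (the eigenvalues, the level value, and the isotropic-like test loss proportional to $\|w - w^*\|^2$ are illustrative assumptions): two residuals on the same training-loss level set, one along the top eigendirection and one along the bottom, differ in test loss by the eigenvalue ratio.

```python
import numpy as np

# Restricted to span(X) and written in the eigenbasis of X X^T, take two eigenvalues:
gamma_max, gamma_min, n, alpha = 10.0, 1.0, 2, 0.5
G = np.diag([gamma_max, gamma_min])

# Two residuals r = w - w* with the same training loss  L_S = (1/n) r^T G r = alpha:
r_top = np.array([np.sqrt(n * alpha / gamma_max), 0.0])   # along the top eigendirection
r_bot = np.array([0.0, np.sqrt(n * alpha / gamma_min)])   # along the bottom eigendirection

for name, r in [("top eigendirection", r_top), ("bottom eigendirection", r_bot)]:
    train = r @ G @ r / n
    test = r @ r               # isotropic-like test loss ~ ||w - w*||^2 (up to constants)
    print(f"{name:>22}: training loss = {train:.2f}, test loss ~ {test:.2f}")

# Same training-loss level set, but the test losses differ by gamma_max / gamma_min = 10x.
```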

SLIDE 14

Proof Sketch of Theorem 2

  • $L_S(w) = \frac{1}{n}\, w^\top X X^\top w$  (taking $w^* = 0$ for simplicity)
  • GD: $w_{t+1} = \left( I - \frac{2\eta}{n} X X^\top \right) w_t$  ⇒  $w_t = \left( I - \frac{2\eta}{n} X X^\top \right)^{t} w_0$  ⇒  in the eigenbasis of $XX^\top$, $v_t = \left( I - \frac{2\eta}{n} \Gamma \right)^{t} v_0$
  • Moderate/small LR: $0 \le \left( 1 - \frac{2\eta \gamma_{\max}}{n} \right)^{t} \ll \dots \ll \left( 1 - \frac{2\eta \gamma_{\min}}{n} \right)^{t} \ll 1$
  • Hence GD converges slower along the small eigenvalue directions
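A minimal numerical version of this sketch (random data, $w^* = 0$ as above, and an arbitrary step count): writing the GD iterate in the eigenbasis of $XX^\top$ shows the coordinate along the smallest nonzero eigenvalue decaying slowest, so the normalized iterate drifts toward that eigenvector.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n = 40, 4
X = rng.normal(size=(d, n)) / np.sqrt(d)
A = X @ X.T

vals, vecs = np.linalg.eigh(A)
nz = vals > 1e-10                                # the n nonzero eigenvalues (ascending)
gammas, U = vals[nz], vecs[:, nz]                # columns of U are the eigenvectors

eta = 0.5
w = U @ rng.normal(size=n)                       # start inside span(X); w* = 0 as above
M = np.eye(d) - (2.0 * eta / n) * A              # GD map: w_{t+1} = (I - (2 eta / n) X X^T) w_t
for _ in range(300):
    w = M @ w

print("nonzero eigenvalues (ascending):", np.round(gammas, 3))
print("|coordinates of w_t| in that basis:", np.abs(U.T @ w))
print("normalized iterate in that basis:", np.round(np.abs(U.T @ (w / np.linalg.norm(w))), 3))
```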
SLIDE 15

Proof Sketch of Theorem 1

  • $\ell_i(w) = w^\top x_i x_i^\top w$
  • SGD in one epoch: $w_{s, k+1} = \left( I - 2\eta\, x_{\sigma(k)} x_{\sigma(k)}^\top \right) w_{s, k}$  ⇒  $w_{s+1} = \prod_{k=1}^{n} \left( I - 2\eta\, x_{\sigma(k)} x_{\sigma(k)}^\top \right) w_s$
  • Bound the spectrum of $\prod_{k=1}^{n} \left( I - 2\eta\, x_{\sigma(k)} x_{\sigma(k)}^\top \right)$ projected onto $X_{-1} = (x_2, \dots, x_n)$ and onto its complement within the span of $X$
  • Bound the updates of $w_s$ projected onto $X_{-1}$ and onto that complement
  • Reverse-engineer Phase 2 and Phase 1…

Remark: with moderate LR, concentration of matrix products turns out to be vacuous

Henriksen, Amelia, and Rachel Ward. "Concentration inequalities for random matrix products." Linear Algebra and its Applications 594 (2020): 81-94.
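To make the central object of this sketch concrete, here is a small construction of one epoch's product of rank-one update matrices (random unit-norm data, a single fixed ordering, and $\eta = 0.4$ are all illustrative): the product acts as the identity off $\mathrm{span}(X)$ and, generically, contracts on $\mathrm{span}(X)$.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n = 30, 4
X = rng.normal(size=(d, n))
X = X / np.linalg.norm(X, axis=0)                # unit-norm samples, so lambda_i = 1
eta = 0.4                                        # per-sample multiplier 1 - 2*eta = 0.2 (stable)

# One SGD epoch on l_i(w) = w^T x_i x_i^T w, in a fixed order:
#   w_{s+1} = prod_k (I - 2 eta x_k x_k^T) w_s
M = np.eye(d)
for k in range(n):
    x = X[:, k]
    M = (np.eye(d) - 2.0 * eta * np.outer(x, x)) @ M

# The product acts as the identity off span(X) ...
P = X @ np.linalg.pinv(X)
u_perp = (np.eye(d) - P) @ rng.normal(size=d)
print("||M u_perp - u_perp|| =", np.linalg.norm(M @ u_perp - u_perp))   # ~ 0

# ... and, restricted to span(X), it is (generically) a strict contraction.
Q, _ = np.linalg.qr(X)                           # orthonormal basis of span(X)
sv = np.linalg.svd(Q.T @ M @ Q, compute_uv=False)
print("singular values of M restricted to span(X):", np.round(sv, 3))
```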

SLIDE 16

Take Home

  • SGD + moderate LR: converges along large eigenvalue directions
  • GD or SGD + small LR: converges along small eigenvalue directions
  • The former directional bias benefits generalization
  • The analysis is “anti-concentration”

Get the paper ->