

Slide 1

HiGrad: Statistical Inference for Stochastic Approximation and Online Learning

Weijie Su

University of Pennsylvania

Slide 2

Collaborator

  • Yuancheng Zhu (UPenn)

Slides 3–4

Learning by optimization

Sample Z1, . . . , ZN, and f(θ, z) is a cost function. Learn the model by minimizing

  argmin_θ (1/N) ∑_{n=1}^N f(θ, Zn)

  • Maximum likelihood estimation (MLE); more generally, M-estimation
  • Often no closed-form solution
  • Need optimization

Slide 5

Gradient descent

◮ Start at some θ0
◮ Iterate

  θj = θj−1 − (γj/N) ∑_{n=1}^N ∇f(θj−1, Zn),

where γj are step sizes. Dates back to Newton, Gauss, and Cauchy.
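To make the iterate concrete, here is a minimal full-batch sketch for least squares; the function name, data layout, and fixed step size are illustrative assumptions, not part of the talk.

```python
import numpy as np

def gradient_descent(X, y, n_steps=100, gamma=0.1):
    # Full-batch GD for f(theta, z) = (1/2)(y - x'theta)^2:
    # theta_j = theta_{j-1} - (gamma_j / N) * sum_n grad f(theta_{j-1}, Z_n)
    N, d = X.shape
    theta = np.zeros(d)
    for _ in range(n_steps):
        grad = X.T @ (X @ theta - y) / N  # average gradient over all N samples
        theta -= gamma * grad
    return theta
```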

Slides 6–8

Difficulty with gradient descent

Modern machine learning

  • Data arrives in a stream
  • Number of data points N is exceedingly large

Gradient descent often not feasible due to

  • Essentially an offline algorithm
  • Evaluating the full gradient is computationally expensive

Slides 9–12

Stochastic gradient descent (SGD)

Aka incremental gradient descent

◮ Start at some θ0
◮ Iterate

  θj = θj−1 − γj∇f(θj−1, Zj)

SGD resolves these challenges

  • Online in nature
  • One pass over data
  • Optimal properties (Nemirovski & Yudin, 1983; Bertsekas, 1999; Agarwal et al, 2012; Rakhlin et al, 2012; Hardt et al, 2015)
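The corresponding one-pass sketch replaces the full gradient with a single-sample gradient; the step-size schedule γj = 0.5 j^{−0.55} is the one used later in the talk, while the function name and least-squares setting are illustrative assumptions.

```python
import numpy as np

def sgd(X, y, gamma=lambda j: 0.5 * j ** -0.55):
    # One pass of SGD: theta_j = theta_{j-1} - gamma_j * grad f(theta_{j-1}, Z_j)
    N, d = X.shape
    theta = np.zeros(d)
    for j in range(1, N + 1):
        x, yj = X[j - 1], y[j - 1]
        theta -= gamma(j) * (x @ theta - yj) * x  # single-sample gradient
    return theta
```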

Slide 13

SGD in one line

Slide 14

SGD vs GD

[Figure: comparison of SGD and GD iterate paths]

Slide 15

SGD: past and now

Statistics

  • Robbins & Monro (1951); Kiefer & Wolfowitz (1952); Robbins & Siegmund (1971); Ruppert (1988); Polyak & Juditsky (1992)

Machine learning and optimization

  • Nesterov & Vial (2008); Nemirovski et al (2009); Bottou (2010); Bach & Moulines (2011); Duchi et al (2011); Kingma & Ba (2014)

Applications

  • Deep learning, recommender systems, MCMC, Kalman filtering, phase retrieval, networks, and many more

Slides 16–17

Using SGD for prediction

Averaged SGD

An estimator of θ∗ := argmin_θ Ef(θ, Z) is given by averaging:

  θ̄ = (1/N) ∑_{j=1}^N θj

Recall that θj = θj−1 − γj∇f(θj−1, Zj) for j = 1, . . . , N.

Given a new instance z = (x, y) with y unknown, interested in µx(θ̄), where

  • Linear regression: µx(θ) = x′θ
  • Logistic regression: µx(θ) = e^{x′θ}/(1 + e^{x′θ})
  • Generalized linear models: µx(θ) = Eθ(Y | X = x)
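A sketch of the averaged estimator and the predicted mean; the helper names are hypothetical, and only the linear and logistic links from the slide are implemented.

```python
import numpy as np

def averaged_sgd(X, y, gamma=lambda j: 0.5 * j ** -0.55):
    # Ruppert-Polyak averaging: theta_bar = (1/N) * sum_j theta_j
    N, d = X.shape
    theta, theta_bar = np.zeros(d), np.zeros(d)
    for j in range(1, N + 1):
        x, yj = X[j - 1], y[j - 1]
        theta -= gamma(j) * (x @ theta - yj) * x
        theta_bar += (theta - theta_bar) / j  # running mean of the iterates
    return theta_bar

def mu_x(theta, x, model="linear"):
    # Predicted mean at a new x: x'theta (linear) or the logistic link
    z = x @ theta
    return z if model == "linear" else 1.0 / (1.0 + np.exp(-z))
```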

Slides 18–19

How much can we trust SGD predictions?

We would observe a different µx(θ̄) if we

  • Re-sample Z′1, . . . , Z′N
  • Sample with replacement N times from a finite population z1, . . . , zm

Decision-making requires uncertainty quantification

  • Should I invest in Bitcoin?
  • How early to leave to catch a flight?

Slide 20

A real data example

Adult dataset on the UCI repository¹

  • 123 features
  • Y = 1 if an individual’s annual income exceeds $50,000
  • 32,561 instances

Randomly pick 1,000 as a test set. Run SGD 500 times independently, each with 20 epochs and step sizes γj = 0.5 j^{−0.55}. Construct empirical confidence intervals with α = 10%.

¹https://archive.ics.uci.edu/ml/datasets/Adult

Slide 21

High variability of SGD predictions

[Figure: confidence interval length (0%–100%) against predicted probability (0.01%–100%) on the test set]

Slides 22–23

What is desired

Can we construct a confidence interval for µ∗x := µx(θ∗)?

Remarks

  • Bootstrap is computationally infeasible
  • Most existing works concern bounding generalization errors or minimizing regrets (Shalev-Shwartz et al, 2011; Rakhlin et al, 2012)
  • Chen et al (2016) proposed a batch-mean estimator of the SGD covariance, and Fang et al (2017) proposed a perturbation-based resampling procedure

Slides 24–27

This talk: HiGrad

A new method: Hierarchical Incremental GRAdient Descent

Properties of HiGrad

◮ Online in nature with the same computational cost as vanilla SGD
◮ A confidence interval for µ∗x in addition to an estimator
◮ Estimator (almost) as accurate as vanilla SGD

Slides 28–33

Preview of HiGrad

[Figure: a HiGrad tree whose level-0 segment, with average θ̄∅, splits into two level-1 segments with averages θ̄1 and θ̄2, giving two threads]

  • θ^1 = (1/3)θ̄∅ + (2/3)θ̄1, θ^2 = (1/3)θ̄∅ + (2/3)θ̄2
  • µ^1_x := µx(θ^1) = 0.15, µ^2_x := µx(θ^2) = 0.11
  • HiGrad estimator: µ̄x = (µ^1_x + µ^2_x)/2 = 0.13
  • The 90% HiGrad confidence interval for µ∗x is
    [µ̄x − t1,0.95 √0.375 |µ^1_x − µ^2_x|, µ̄x + t1,0.95 √0.375 |µ^1_x − µ^2_x|] = [−0.025, 0.285]
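The preview numbers can be reproduced with a two-line check; scipy's t quantile t.ppf(0.95, df=1) ≈ 6.314 plays the role of t1,0.95 (the script itself is an illustration, not from the talk).

```python
from scipy.stats import t

mu1, mu2 = 0.15, 0.11
mu_bar = (mu1 + mu2) / 2                             # HiGrad estimator: 0.13
half = t.ppf(0.95, df=1) * 0.375 ** 0.5 * abs(mu1 - mu2)
print(mu_bar - half, mu_bar + half)                  # approx [-0.025, 0.285]
```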

Slide 34

Outline

  1. Deriving HiGrad
  2. Constructing Confidence Intervals
  3. Configuring HiGrad
  4. Empirical Performance

Slide 35

Problem statement

Minimize convex f:

  θ∗ = argmin_θ f(θ) ≡ Ef(θ, Z)

Observe i.i.d. Z1, . . . , ZN and can evaluate an unbiased noisy gradient g(θ, Z):

  E g(θ, Z) = ∇f(θ) for all θ

To be fulfilled

◮ Online in nature with the same computational cost as vanilla SGD
◮ A confidence interval for µ∗x in addition to an estimator
◮ Estimator (almost) as accurate as vanilla SGD

Slides 36–37

The idea of contrasting and sharing

  • Need more than one value of µx to quantify variability: contrasting
  • Need to share gradient information to elongate threads: sharing

Slides 38–41

The HiGrad tree

  • K + 1 levels
  • Each level-k segment is of length nk and is split into Bk+1 segments
  • Budget constraint (see the sketch below):

  n0 + B1n1 + B1B2n2 + B1B2B3n3 + · · · + B1B2 · · · BKnK = N

An example of a HiGrad tree: B1 = 2, B2 = 3, K = 2

[Figure: the example tree, with one level-0 segment splitting into two level-1 segments, each splitting into three level-2 segments]
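As referenced above, a small helper makes the budget constraint concrete; the function name is hypothetical.

```python
def total_samples(n, B):
    # Total data consumed by a HiGrad tree with segment lengths n = (n0,...,nK)
    # and branching factors B = (B1,...,BK): sum over levels of n_k * B_1*...*B_k
    total, prod = n[0], 1
    for k in range(1, len(n)):
        prod *= B[k - 1]
        total += prod * n[k]
    return total

# Example tree from the slide: B1 = 2, B2 = 3, K = 2, equal segment lengths:
# n0 + 2*n1 + 6*n2 = N, so with n0 = n1 = n2 = 100 the tree needs 900 samples.
assert total_samples((100, 100, 100), (2, 3)) == 900
```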

Slides 42–45

Iterate along the HiGrad tree

Recall: noisy gradient g(θ, Z) unbiased for ∇f(θ); partition {Z^s} of {Z1, . . . , ZN}; and Lk := n0 + · · · + nk

◮ Iterate along the level-0 segment: θj = θj−1 − γj g(θj−1, Zj) for j = 1, . . . , n0, starting from some θ0

◮ Iterate along each level-1 segment s = (b1) for 1 ≤ b1 ≤ B1:

  θ^s_j = θ^s_{j−1} − γ_{j+L0} g(θ^s_{j−1}, Z^s_j)

for j = 1, . . . , n1, starting from θ_{n0}

◮ Generally, for the segment s = (b1 · · · bk), iterate

  θ^s_j = θ^s_{j−1} − γ_{j+L_{k−1}} g(θ^s_{j−1}, Z^s_j)

for j = 1, . . . , nk, starting from θ^{(b1 ··· b_{k−1})}_{n_{k−1}}

Slides 46–48

A second look at the HiGrad tree

An example of a HiGrad tree: B1 = 2, B2 = 3, K = 2

Fulfilled

  • Online in nature with the same computational cost as vanilla SGD

Bonus

Easier to parallelize than vanilla SGD!

Slide 49

The HiGrad algorithm in action

Require: g(·, ·), Z1, . . . , ZN, (n0, n1, . . . , nK), (B1, . . . , BK), step sizes (γ1, . . . , γ_{LK}), θ0
  θ̄^s = 0 for all segments s
  function NodeTreeSGD(θ, s)
    θ^s_0 = θ
    k = #s
    for j = 1 to nk do
      θ^s_j ← θ^s_{j−1} − γ_{j+L_{k−1}} g(θ^s_{j−1}, Z^s_j)
      θ̄^s ← θ̄^s + θ^s_j / nk
    end for
    if k < K then
      for bk+1 = 1 to Bk+1 do
        s+ ← (s, bk+1)
        execute NodeTreeSGD(θ^s_{nk}, s+)
      end for
    end if
  end function
  execute NodeTreeSGD(θ0, ∅)
Output: θ̄^s for all segments s
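Below is a minimal runnable Python transcription of NodeTreeSGD, a sketch under the assumption that the data stream is simply consumed in recursion order (which yields a valid partition {Z^s} since the data are i.i.d.); all names are hypothetical.

```python
import numpy as np

def higrad(g, Z, n, B, gamma, theta0):
    # n = (n0,...,nK) segment lengths, B = (B1,...,BK) branching factors,
    # gamma(j) step sizes, g(theta, z) the noisy gradient.
    # Returns {segment s: theta_bar_s} keyed by tuples, with () the root.
    K = len(B)
    L = np.cumsum(n)                  # L[k] = n0 + ... + nk
    stream = iter(Z)                  # each segment consumes its own chunk
    averages = {}

    def node_tree_sgd(theta, s):
        k = len(s)                    # level of segment s (k = #s)
        offset = 0 if k == 0 else L[k - 1]
        theta = theta.copy()
        theta_bar = np.zeros_like(theta)
        for j in range(1, n[k] + 1):
            theta = theta - gamma(j + offset) * g(theta, next(stream))
            theta_bar += theta / n[k]
        averages[s] = theta_bar
        if k < K:
            for b in range(1, B[k] + 1):
                node_tree_sgd(theta, s + (b,))

    node_tree_sgd(np.asarray(theta0, dtype=float), ())
    return averages
```

For the example tree (B1 = 2, B2 = 3), the recursion produces 1 + 2 + 6 = 9 segment averages and consumes n0 + 2n1 + 6n2 samples in total.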

Slide 50

Outline

  1. Deriving HiGrad
  2. Constructing Confidence Intervals
  3. Configuring HiGrad
  4. Empirical Performance

Slides 51–53

Estimate µ∗x through each thread

Average over each segment s = (b1, . . . , bk):

  θ̄^s = (1/nk) ∑_{j=1}^{nk} θ^s_j

Given weights w0, w1, . . . , wK that sum up to 1, the weighted average along thread t = (b1, . . . , bK) is

  θ^t = ∑_{k=0}^K wk θ̄^{(b1,...,bk)}

Estimator yielded by thread t:

  µ^t_x := µx(θ^t)

How to construct a confidence interval based on T := B1B2 · · · BK many such µ^t_x estimates?
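A sketch of the thread-averaging step, reusing the {segment: average} dictionary produced by the recursion above (names hypothetical):

```python
from itertools import product

def thread_estimates(averages, w, B, mu_x):
    # theta_t = sum_{k=0}^K w_k * theta_bar_{(b1,...,bk)} for each thread
    # t = (b1,...,bK); then mu_x^t = mu_x(theta_t). Note t[:0] == () is the root.
    K = len(B)
    out = {}
    for t in product(*(range(1, b + 1) for b in B)):
        theta_t = sum(w[k] * averages[t[:k]] for k in range(K + 1))
        out[t] = mu_x(theta_t)
    return out
```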

Slide 54

Assume normality

Denote by µx the T-dimensional vector consisting of all µ^t_x

Normality of µx (to be proved soon)

  √N(µx − µ∗x 1) converges weakly to a normal distribution N(0, Σ) as N → ∞

Slides 55–56

Convert to simple linear regression

From µx approximately ∼ N(µ∗x 1, Σ/N) we get

  Σ^{−1/2} µx ≈ (Σ^{−1/2} 1) µ∗x + z̃, z̃ ∼ N(0, I/N)

Simple linear regression! The least-squares estimator of µ∗x is

  (1′Σ^{−1/2}Σ^{−1/2}1)^{−1} 1′Σ^{−1/2}Σ^{−1/2} µx = (1′Σ^{−1}1)^{−1} 1′Σ^{−1} µx = (1/T) ∑_{t∈T} µ^t_x ≡ µ̄x

HiGrad estimator

Just the sample mean µ̄x

Slides 57–58

A t-based confidence interval

A pivot for µ∗x:

  (µ̄x − µ∗x) / SEx approximately ∼ tT−1,

where the standard error is given as

  SEx = √( (µx − µ̄x1)′ Σ^{−1} (µx − µ̄x1) / (T − 1) ) · √(1′Σ1) / T

HiGrad confidence interval of coverage 1 − α:

  [µ̄x − tT−1,1−α/2 SEx, µ̄x + tT−1,1−α/2 SEx]
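A sketch of the interval computation, given a NumPy vector of the T thread estimates and any Σ known up to a scalar (the scalar cancels in SEx); names hypothetical.

```python
import numpy as np
from scipy.stats import t

def higrad_ci(mu, Sigma, alpha=0.1):
    # SE = sqrt((mu - mu_bar 1)' Sigma^{-1} (mu - mu_bar 1) / (T-1))
    #      * sqrt(1' Sigma 1) / T; rescaling Sigma leaves SE unchanged.
    T = len(mu)
    mu_bar = mu.mean()
    r = mu - mu_bar
    se = np.sqrt(r @ np.linalg.solve(Sigma, r) / (T - 1)) * np.sqrt(Sigma.sum()) / T
    q = t.ppf(1 - alpha / 2, df=T - 1)
    return mu_bar, (mu_bar - q * se, mu_bar + q * se)
```
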
Slide 59

Do we know the covariance Σ?

Slides 60–63

An extension of Ruppert–Polyak normality

Given a thread t = (b1, . . . , bK), denote its segments by sk = (b1, b2, . . . , bk)

Fact (informal)

  √n0(θ̄^{s0} − θ∗), √n1(θ̄^{s1} − θ∗), . . . , √nK(θ̄^{sK} − θ∗) converge to i.i.d. centered normal distributions

  • Hessian H = ∇²f(θ∗) and V = E[g(θ∗, Z)g(θ∗, Z)′]. Ruppert (1988), Polyak (1990), and Polyak and Juditsky (1992) prove √N(θ̄N − θ∗) ⇒ N(0, H⁻¹V H⁻¹)
  • Difficult to estimate the sandwich covariance H⁻¹V H⁻¹ (Chen et al, 2016)
  • To know the covariance of {µx(θ^t)}, do we really need to know H⁻¹V H⁻¹?

Slides 64–65

Covariance determined by the number of shared segments

Consider µx(θ) = T(x)′θ and observe

  • √n0(µx(θ̄^{s0}) − µ∗x), √n1(µx(θ̄^{s1}) − µ∗x), . . . , √nK(µx(θ̄^{sK}) − µ∗x) converge to i.i.d. centered univariate normal distributions
  • µ^t_x − µ∗x = µx(θ^t) − µ∗x = ∑_{k=0}^K wk (µx(θ̄^{sk}) − µ∗x)

Fact (informal)

For any two threads t and t′ that agree at the first k segments and differ henceforth, we have

  Cov(µ^t_x, µ^{t′}_x) = (1 + o(1)) σ² ∑_{i=0}^k w²_i / ni

Slides 66–68

Specify Σ up to a multiplicative factor

If µx(θ) = T(x)′θ, then for any two threads t and t′ that agree only at the first k segments,

  Σ_{t,t′} = (1 + o(1)) C ∑_{i=0}^k w²_i N / ni

  • Do we need to know C as well?
  • No! The standard error of µ̄x is invariant under multiplying Σ by a scalar:

  SEx = √( (µx − µ̄x1)′ Σ^{−1} (µx − µ̄x1) / (T − 1) ) · √(1′Σ1) / T
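A sketch that builds Σ up to the factor C directly from the tree structure (names hypothetical); C is omitted since it cancels in SEx.

```python
import numpy as np
from itertools import product

def sigma_matrix(n, B, w, N):
    # Sigma[t,t'] proportional to sum_{i=0}^{k} w_i^2 * N / n_i, where
    # threads t, t' share the segments s_0, ..., s_k.
    threads = list(product(*(range(1, b + 1) for b in B)))
    T = len(threads)
    Sigma = np.zeros((T, T))
    for a, t in enumerate(threads):
        for c, t2 in enumerate(threads):
            k = 0
            while k < len(B) and t[k] == t2[k]:
                k += 1                      # count shared branching choices
            Sigma[a, c] = sum(w[i] ** 2 * N / n[i] for i in range(k + 1))
    return Sigma
```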

Slide 69

Some remarks

  • In generalized linear models, µx often takes the form µx(θ) = η⁻¹(T(x)′θ) for an increasing η. Construct a confidence interval for η(µ∗x) and then invert it
  • For a general nonlinear but smooth µx(θ), use the delta method
  • Need less than Ruppert–Polyak: the results remain valid whenever √N(θ̄N − θ∗) converges to some centered normal distribution

Slide 70

Formal statement of theoretical results

Slide 71

Assumptions

  1. Local strong convexity. f(θ) ≡ Ef(θ, Z) is convex and differentiable with Lipschitz gradients; the Hessian ∇²f(θ) is locally Lipschitz and positive-definite at θ∗
  2. Noise regularity. V(θ) = E[g(θ, Z)g(θ, Z)′] is Lipschitz and does not grow too fast; the noisy gradient g(θ, Z) has a 2 + o(1) moment locally at θ∗

Slide 72

Examples satisfying the assumptions

  • Linear regression: f(θ, z) = (1/2)(y − x′θ)²
  • Logistic regression: f(θ, z) = −y x′θ + log(1 + e^{x′θ})
  • Penalized regression: add a ridge penalty λ‖θ‖²
  • Huber regression: f(θ, z) = ρλ(y − x′θ), where ρλ(a) = a²/2 for |a| ≤ λ and ρλ(a) = λ|a| − λ²/2 otherwise

Sufficient conditions

  X in generic position, E‖X‖^{4+o(1)} < ∞, and E|Y|^{2+o(1)}‖X‖^{2+o(1)} < ∞

Slides 73–75

Main theoretical results

Theorem (S. and Zhu)

Assume K and B1, . . . , BK are fixed, nk ∝ N as N → ∞, and µx has a nonzero derivative at θ∗. Taking γj ≍ j^{−α} for α ∈ (0.5, 1) gives

  (µ̄x − µ∗x) / SEx ⇒ tT−1

Confidence intervals

  lim_{N→∞} P( µ∗x ∈ [µ̄x − tT−1,1−α/2 SEx, µ̄x + tT−1,1−α/2 SEx] ) = 1 − α

Fulfilled

  • Online in nature with the same computational cost as vanilla SGD
  • A confidence interval for µ∗x in addition to an estimator

Slide 76

How accurate is the HiGrad estimator?

Slides 77–80

Optimal variance with optimal weights

By Cauchy–Schwarz,

  N Var(µ̄x) = (1 + o(1)) σ² [∑_{k=0}^K nk ∏_{i=1}^k Bi] [∑_{k=0}^K w²_k / (nk ∏_{i=1}^k Bi)]
             ≥ (1 + o(1)) σ² (∑_{k=0}^K wk)² = (1 + o(1)) σ²,

with equality if w∗_k = nk ∏_{i=1}^k Bi / N

  • Segments at an early level are weighted less
  • The HiGrad estimator has the same asymptotic variance as vanilla SGD
  • Achieves the Cramér–Rao lower bound when the model is correctly specified
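A sketch of the optimal weights w∗_k = nk ∏_{i=1}^k Bi / N (function name hypothetical):

```python
def optimal_weights(n, B, N):
    # w*_k = n_k * (B_1 * ... * B_k) / N, so earlier levels get less weight.
    prod, w = 1, []
    for k, nk in enumerate(n):
        if k > 0:
            prod *= B[k - 1]
        w.append(nk * prod / N)
    return w  # sums to 1 whenever the budget constraint holds
```

For the default configuration K = 2, B1 = B2 = 2, n0 = n1 = n2 = N/7 (introduced later), this gives w∗ = (1/7, 2/7, 4/7).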

Slide 81

Prediction intervals for vanilla SGD

Theorem (S. and Zhu)

Run vanilla SGD on a fresh dataset of the same size, producing µ^SGD_x. Then, with the optimal weights,

  lim_{N→∞} P( µ^SGD_x ∈ [µ̄x − √2 tT−1,1−α/2 SEx, µ̄x + √2 tT−1,1−α/2 SEx] ) = 1 − α

  • µ^SGD_x can be replaced by a HiGrad estimator with the same structure
  • Interpretable even under model misspecification

Slide 82

HiGrad enjoys three appreciable properties

Under certain assumptions, for example, f being locally strongly convex:

Fulfilled

  • Online in nature with the same computational cost as vanilla SGD
  • A confidence interval for µ∗x in addition to an estimator
  • Estimator (almost) as accurate as vanilla SGD

Slide 83

Outline

  1. Deriving HiGrad
  2. Constructing Confidence Intervals
  3. Configuring HiGrad
  4. Empirical Performance

Slide 84

Which one?

Slides 85–88

Length of confidence intervals

Denote by LCI = 2 tT−1,1−α/2 SEx the length of the HiGrad confidence interval.

Proposition (S. and Zhu)

  √N · E LCI → 2σ √2 tT−1,1−α/2 Γ(T/2) / (√(T − 1) Γ((T − 1)/2))

  • The function tT−1,1−α/2 Γ(T/2) / (√(T − 1) Γ((T − 1)/2)) is decreasing in T ≥ 2
  • The more threads, the shorter the HiGrad confidence interval on average
  • More contrasting leads to shorter confidence intervals
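The proposition's T-dependent factor can be tabulated directly; a sketch using scipy (gammaln keeps the gamma-function ratio numerically stable):

```python
import numpy as np
from scipy.stats import t
from scipy.special import gammaln

def length_factor(T, alpha=0.05):
    # t_{T-1,1-alpha/2} * Gamma(T/2) / (sqrt(T-1) * Gamma((T-1)/2))
    q = t.ppf(1 - alpha / 2, df=T - 1)
    return q * np.exp(gammaln(T / 2) - gammaln((T - 1) / 2)) / np.sqrt(T - 1)

# Decreasing in T: steep drop from T = 2, flattening quickly afterwards
print([round(length_factor(T), 3) for T in range(2, 11)])
```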

Slide 89

Really want to set T = 1000?

Slide 90

T = 4 is sufficient

[Figure: plot of tT−1,0.975 Γ(T/2) / (√(T − 1) Γ(T/2 − 0.5)) against T = 2, . . . , 10]

  • Too many threads result in inaccurate normality (unless N is huge)
  • Large T leads to much contrasting and little sharing

Slides 91–93

How to choose (n0, . . . , nK)?

  n0 + B1n1 + B1B2n2 + B1B2B3n3 + · · · + B1B2 · · · BKnK = N

Length of each thread

  LK := n0 + n1 + · · · + nK

  • Sharing: want a larger LK by setting n0 > n1 > · · · > nK
  • Contrasting: want n0 < n1 < · · · < nK

Slide 94

Outline

  1. Deriving HiGrad
  2. Constructing Confidence Intervals
  3. Configuring HiGrad
  4. Empirical Performance

Slide 95

General simulation setup

X generated as i.i.d. N(0, 1) and Z = (X, Y) ∈ R^d × R. Set N = 10⁶ and use γj = 0.5 j^{−0.55}.

  • Linear regression: Y ∼ N(µX(θ∗), 1), where µx(θ) = x′θ
  • Logistic regression: Y ∼ Bernoulli(µX(θ∗)), where µx(θ) = e^{x′θ}/(1 + e^{x′θ})

Criteria

  • Accuracy: ‖θ̄ − θ∗‖², where θ̄ is averaged over the T threads
  • Coverage probability and length of confidence intervals
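A sketch of the data-generating process (the seed and function name are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)

def simulate(theta_star, N=10**6, model="linear"):
    # X ~ N(0, I_d); Y ~ N(x'theta*, 1) (linear) or
    # Bernoulli(sigmoid(x'theta*)) (logistic), as in the setup above.
    d = len(theta_star)
    X = rng.standard_normal((N, d))
    eta = X @ theta_star
    if model == "linear":
        Y = eta + rng.standard_normal(N)
    else:
        Y = rng.binomial(1, 1.0 / (1.0 + np.exp(-eta)))
    return X, Y
```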

Slide 96

Accuracy

Dimension d = 50. MSE ‖θ̄ − θ∗‖² normalized by that of vanilla SGD

  • Null case: θ∗1 = · · · = θ∗50 = 0
  • Dense case: θ∗1 = · · · = θ∗50 = 1/√50
  • Sparse case: θ∗1 = · · · = θ∗5 = 1/√5, θ∗6 = · · · = θ∗50 = 0

Slide 97

Accuracy

[Figure: six panels of normalized risk (1.00–1.30) against the total number of steps (1e+04 to 5e+05), for linear and logistic regression in the null, sparse, and dense cases]

Slide 98

Coverage and CI length

HiGrad configurations

  • K = 1, with n1 = n0 (r = 1)
  • K = 2, with n1/n0 = n2/n1 = r ∈ {0.75, 1, 1.25, 1.5}

Set θ∗_i = (i − 1)/d for i = 1, . . . , d and α = 5%. Use the measure

  (1/20) ∑_{i=1}^{20} 1(µ_{x_i}(θ∗) ∈ CI_{x_i})

Slide 99

Linear regression: d = 20

(K, B, r)     Coverage   CI length
1, 4, 1       0.9348     0.0621
1, 8, 1       0.9245     0.0618
1, 12, 1      0.9185     0.0606
1, 16, 1      0.925      0.0605
1, 20, 1      0.9378     0.0633
2, 2, 1       0.935      0.062
2, 2, 1.25    0.9318     0.0614
2, 2, 1.5     0.924      0.061
2, 2, 2       0.9448     0.0815
3, 2, 1       0.9452     0.0828
3, 2, 1.25    0.9472     0.0811
3, 2, 1.5     0.9425     0.0801
3, 2, 2       0.8488     0.0637
2, 3, 1       0.887      0.0637
2, 3, 1.25    0.9185     0.0653
2, 3, 1.5     0.938      0.0683
2, 3, 2       0.956      0.0851

Slide 100

Linear regression: d = 100

(K, B, r)     Coverage   CI length
1, 4, 1       0.9115     0.15
1, 8, 1       0.897      0.1491
1, 12, 1      0.8992     0.1466
1, 16, 1      0.894      0.1457
1, 20, 1      0.917      0.1489
2, 2, 1       0.9148     0.1453
2, 2, 1.25    0.9065     0.1428
2, 2, 1.5     0.9        0.1412
2, 2, 2       0.9302     0.1972
3, 2, 1       0.9358     0.1946
3, 2, 1.25    0.9338     0.1927
3, 2, 1.5     0.9312     0.1917
3, 2, 2       0.9125     0.2649
2, 3, 1       0.92       0.2495
2, 3, 1.25    0.9308     0.2312
2, 3, 1.5     0.9478     0.2197
2, 3, 2       0.9472     0.2403

Slide 101

A real data example: setup

From the 1994 census data on the UCI repository. Y indicates whether an individual’s annual income exceeds $50,000

  • 123 features
  • 32,561 instances
  • Randomly pick 1,000 as a test set

Use N = 10⁶, α = 10%, and γj = 0.5 j^{−0.55}. Run HiGrad L = 500 times. Use the measure

  coverage_i = (1/(L(L − 1))) ∑_{ℓ1} ∑_{ℓ2 ≠ ℓ1} 1(p̂_{i,ℓ1} ∈ PI_{i,ℓ2})

Slide 102

A real data example: histogram

[Figure: histogram of coverage probabilities (0.0–1.0) over test points, with counts up to about 300]

Slide 103

Comparisons of HiGrad configurations

[Table: accuracy, coverage, and CI length compared across HiGrad configurations]

Slide 104

Default HiGrad parameters

HiGrad R package default values:

  K = 2, B1 = 2, B2 = 2, n0 = n1 = n2 = N/7
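A quick check that the default configuration exhausts the budget (the value of N here is an arbitrary multiple of 7):

```python
# Default tree: K = 2, B1 = B2 = 2, n0 = n1 = n2 = N/7, so the budget
# n0 + B1*n1 + B1*B2*n2 = (1 + 2 + 4) * N/7 = N holds, with T = 4 threads.
N = 7 * 10**5
n0 = n1 = n2 = N // 7
assert n0 + 2 * n1 + 4 * n2 == N
```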

Slide 105

Concluding Remarks

Slide 106

Straightforward extensions

  • Flexible tree structures: the HiGrad tree can be asymmetric
  • N unknown: grow the tree assuming a lower bound on N
  • Burn-in: get a better initial point
  • A criterion for stopping: need to incorporate selective inference
  • Mini-batch sizes: evaluate the (less) noisy gradient g(θ, Z_{1:m}) = (1/m) ∑_{i=1}^m g(θ, Zi), as sketched below
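A one-line sketch of the mini-batch gradient from the last bullet (function name hypothetical):

```python
def minibatch_gradient(g, theta, batch):
    # g(theta, Z_{1:m}) = (1/m) * sum_i g(theta, Z_i): averaging m i.i.d.
    # single-sample gradients reduces the noise variance by a factor of m.
    return sum(g(theta, z) for z in batch) / len(batch)
```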

Slides 107–110

Future extensions

Improving statistical properties

◮ Finite-sample guarantees: better coverage probability
◮ Extend Ruppert–Polyak to high dimensions: number of unknown variables growing

A new template for online learning

◮ Adaptive step sizes and pre-conditioned SGD: AdaGrad (Duchi et al, 2011) and Adam (Kingma & Ba, 2014)
◮ General convex optimization and non-convex problems: SVM, regularized GLM, and deep learning

Slides 111–113

Take-home messages

Idea

Contrasting and sharing through hierarchical splitting

Properties (under local strong convexity)

◮ Online in nature with the same computational cost as vanilla SGD
◮ A confidence interval for µ∗x in addition to an estimator
◮ Estimator (almost) as accurate as vanilla SGD

Bonus

Easier to parallelize than vanilla SGD!

Slide 114

Thanks!

  • Reference. Statistical Inference for Stochastic Approximation and Online Learning via Hierarchical Incremental Gradient Descent, Weijie Su and Yuancheng Zhu, coming soon
  • Software. R package HiGrad, coming soon