CoCoA: Communication-Efficient Coordinate Ascent
Virginia Smith
Martin Jaggi, Martin Takáč, Jonathan Terhorst, Sanjay Krishnan, Thomas Hofmann, & Michael I. Jordan

LARGE-SCALE OPTIMIZATION
image/music/video tagging, document categorization, item recommendation, click-through rate prediction, sequence tagging, protein structure prediction, sensor data prediction, spam classification, fraud detection
DATA & PROBLEM
classification, regression, collaborative filtering, …

MACHINE LEARNING MODEL
logistic regression, lasso, support vector machines, …

OPTIMIZATION ALGORITHM
gradient descent, coordinate descent, Newton’s method, …
[Figure: linear SVM margin, with hyperplanes w·x = 1 and w·x = −1 and margin width 2/‖w‖]

$$\min_{w \in \mathbb{R}^d} \; \frac{1}{n} \sum_{i=1}^{n} \ell_{\mathrm{hinge}}\!\left(y_i\, w^\top x_i\right)$$
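A minimal sketch (not from the slides) of evaluating this hinge-loss objective; NumPy, the synthetic data, and the function name hinge_objective are assumptions for illustration.

```python
# Sketch: average hinge loss (1/n) * sum_i max(0, 1 - y_i * w^T x_i)
import numpy as np

def hinge_objective(w, X, y):
    """X has shape (n, d); y holds labels in {-1, +1}."""
    margins = y * (X @ w)                      # y_i * w^T x_i for every example
    return np.maximum(0.0, 1.0 - margins).mean()

rng = np.random.default_rng(0)
X = rng.standard_normal((100, 5))
y = np.sign(rng.standard_normal(100))
print(hinge_objective(np.zeros(5), X, y))      # equals 1.0 at w = 0
```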
Optimization algorithms: descent algorithms and line search methods; acceleration, momentum, and conjugate gradients; Newton and quasi-Newton methods; coordinate descent; stochastic and incremental gradient methods
Solvers: SMO, SVMlight, LIBLINEAR
$$\min_{w \in \mathbb{R}^d} \; \frac{1}{n} \sum_{i=1}^{n} \ell_i\!\left(w^\top x_i\right)$$
support vector machines, logistic regression, lasso regression, ridge regression, etc…
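The same template, with a different loss ℓ_i plugged in, gives each of these models. The hedged sketch below uses standard textbook loss formulas; the specific forms and synthetic data are assumptions, not taken from the slides, and lasso additionally changes the regularizer rather than only the loss.

```python
# Sketch: one objective template, (1/n) * sum_i loss(w^T x_i, y_i),
# instantiated with standard losses (assumed forms, not from the slides).
import numpy as np

def objective(w, X, y, loss):
    return loss(X @ w, y).mean()

losses = {
    "svm (hinge)":   lambda z, y: np.maximum(0.0, 1.0 - y * z),
    "logistic":      lambda z, y: np.log1p(np.exp(-y * z)),
    "least squares": lambda z, y: 0.5 * (z - y) ** 2,   # ridge adds a ||w||^2 term
}

rng = np.random.default_rng(1)
X, y = rng.standard_normal((50, 3)), np.sign(rng.standard_normal(50))
w = rng.standard_normal(3)
for name, loss in losses.items():
    print(f"{name}: {objective(w, X, y, loss):.3f}")
```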
SYSTEMS SETTING
multi-core, cluster, cloud, supercomputer, …
“always communicate”
reduce: w ← w − α Σ_k ∆w_k
✔ convergence guarantees
✗ high communication

“never communicate”
average: w := (1/K) Σ_k w_k   (ZDWJ, 2012)
✔ low communication
✗ convergence not guaranteed
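To make the two extremes concrete, here is a toy simulation sketch on synthetic least-squares data (the data, step size, and worker count are all assumptions); it illustrates only the communication patterns, not the convergence behavior of either extreme.

```python
# Sketch: "always communicate" (reduce a gradient every round) vs.
# "never communicate" (solve locally, average once). Synthetic least squares.
import numpy as np

rng = np.random.default_rng(0)
K, n_per, d = 4, 200, 10
w_true = rng.standard_normal(d)
Xs = [rng.standard_normal((n_per, d)) for _ in range(K)]
ys = [X @ w_true + 0.1 * rng.standard_normal(n_per) for X in Xs]

def grad(w, X, y):                       # gradient of (1/2n) * ||Xw - y||^2
    return X.T @ (X @ w - y) / len(y)

# "always communicate": every round, w = w - alpha * sum_k dw_k
w, alpha = np.zeros(d), 0.1
for _ in range(100):
    w = w - alpha * sum(grad(w, X, y) for X, y in zip(Xs, ys))

# "never communicate": each worker solves its local problem, average once
w_avg = sum(np.linalg.lstsq(X, y, rcond=None)[0] for X, y in zip(Xs, ys)) / K

print("reduce error:          ", np.linalg.norm(w - w_true))
print("one-shot average error:", np.linalg.norm(w_avg - w_true))
```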
a natural middle-ground
reduce: w ← w − (α/|b|) Σ_{i∈b} ∆w_i
✔ convergence guarantees
✔ tunable communication
CoCoA:
Use a primal-dual framework
Immediately apply local updates
Average over K << batch size
PRIMAL ↔ DUAL

$$\min_{w \in \mathbb{R}^d} \; \Big[\, P(w) := \frac{\lambda}{2}\|w\|^2 + \frac{1}{n}\sum_{i=1}^{n} \ell_i(w^\top x_i) \,\Big]$$

$$\max_{\alpha \in \mathbb{R}^n} \; \Big[\, D(\alpha) := -\frac{\lambda}{2}\|A\alpha\|^2 - \frac{1}{n}\sum_{i=1}^{n} \ell_i^*(-\alpha_i) \,\Big], \qquad A_i = \frac{1}{\lambda n}\, x_i$$

Stopping criteria given by the duality gap
Good performance in practice
Default in software packages, e.g. liblinear
STALE updates:
for i ∈ b:
    ∆w ← ∆w − α ∇_i P(w)
end
w ← w + ∆w

FRESH updates:
for i ∈ b:
    ∆w ← ∆w − α ∇_i P(w)
    w ← w + ∆w
end

reduce: w ← w + (1/K) Σ_k ∆w_k
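A hedged sketch of these two local-update patterns; grad_i, the block b, and the step size are placeholders, not CoCoA's actual local solver.

```python
# Sketch: "stale" accumulates updates against a fixed w, while "fresh" applies
# each update to the local copy of w immediately, as CoCoA's local solver does.
import numpy as np

def local_round_stale(w, b, grad_i, alpha):
    dw = np.zeros_like(w)
    for i in b:
        dw -= alpha * grad_i(w, i)        # every step sees the same stale w
    return dw                             # caller then applies w + dw

def local_round_fresh(w, b, grad_i, alpha):
    w, dw = w.copy(), np.zeros_like(w)
    for i in b:
        step = -alpha * grad_i(w, i)      # each step sees the latest local w
        dw += step
        w += step
    return dw                             # only dw is communicated and averaged
```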
Algorithm 1: CoCoA
Input: T ≥ 1, scaling parameter 1 ≤ β_K ≤ K (default: β_K := 1)
Data: {(x_i, y_i)}_{i=1}^n distributed over K machines
Initialize: α^(0)_[k] ← 0 for all machines k, and w^(0) ← 0
for t = 1, 2, …, T
    for all machines k = 1, 2, …, K in parallel
        (∆α_[k], ∆w_k) ← LocalDualMethod(α^(t−1)_[k], w^(t−1))
        α^(t)_[k] ← α^(t−1)_[k] + (β_K / K) ∆α_[k]
    end
    reduce: w^(t) ← w^(t−1) + (β_K / K) Σ_{k=1}^K ∆w_k
end

Procedure A: LocalDualMethod: dual algorithm on machine k
Input: local α_[k] ∈ R^{n_k}, and w ∈ R^d consistent with the other coordinate blocks of α, s.t. w = Aα
Data: local {(x_i, y_i)}_{i=1}^{n_k}
Output: ∆α_[k] and ∆w := A_[k] ∆α_[k]
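A hedged Python sketch of Algorithm 1 for the hinge-loss SVM, with randomized dual coordinate ascent (SDCA) as LocalDualMethod; the coordinate step below is the standard closed-form SDCA hinge update, and the data partitioning, H, and all constants are assumptions rather than the reference implementation.

```python
# Sketch of CoCoA (Algorithm 1) with SDCA as the local dual method, hinge loss.
import numpy as np

def local_sdca(alpha_k, w, X_k, y_k, lam, n, H, rng):
    """H randomized dual coordinate steps on machine k; returns (d_alpha, d_w)."""
    alpha_k, w = alpha_k.copy(), w.copy()
    d_alpha, d_w = np.zeros_like(alpha_k), np.zeros_like(w)
    for _ in range(H):
        i = rng.integers(len(y_k))
        x, yi = X_k[i], y_k[i]
        # standard closed-form SDCA maximization for the hinge loss
        q = (x @ x) / (lam * n)
        new_ai = yi * np.clip((1.0 - yi * (x @ w)) / q + yi * alpha_k[i], 0.0, 1.0)
        delta = new_ai - alpha_k[i]
        alpha_k[i] += delta
        d_alpha[i] += delta
        step = delta * x / (lam * n)      # keeps w = A @ alpha locally
        w += step                         # "fresh": apply the update immediately
        d_w += step
    return d_alpha, d_w

def cocoa(Xs, ys, lam, T=20, H=100, beta_K=1.0, seed=0):
    K, n = len(Xs), sum(len(y) for y in ys)
    w = np.zeros(Xs[0].shape[1])
    alphas = [np.zeros(len(y)) for y in ys]
    rng = np.random.default_rng(seed)
    for _ in range(T):
        # "for all machines k in parallel" (run sequentially in this sketch)
        updates = [local_sdca(alphas[k], w, Xs[k], ys[k], lam, n, H, rng)
                   for k in range(K)]
        for k, (d_alpha, _) in enumerate(updates):
            alphas[k] += (beta_K / K) * d_alpha
        w = w + (beta_K / K) * sum(d_w for _, d_w in updates)   # reduce step
    return w, alphas
```

With β_K = 1 (the default above), the reduce step averages the K local updates; the duality-gap sketch earlier could serve as the stopping criterion.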
Convergence

Assumptions: the ℓ_i are (1/γ)-smooth, and LocalDualMethod makes improvement Θ per step; e.g. for SDCA with H local iterations, $\Theta = \big(1 - \frac{\lambda n \gamma}{1 + \lambda n \gamma}\cdot\frac{1}{\tilde n}\big)^H$, where $\tilde n = n/K$ is the local problem size.

$$\mathbb{E}\big[D(\alpha^*) - D(\alpha^{(T)})\big] \;\le\; \Big(1 - (1 - \Theta)\,\frac{1}{K}\,\frac{\lambda n \gamma}{\sigma + \lambda n \gamma}\Big)^T \Big(D(\alpha^*) - D(\alpha^{(0)})\Big)$$

The same rate applies also to the duality gap.
σ is a measure of the difficulty of the data partition, with 0 ≤ σ ≤ n/K.
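A small illustrative evaluation of this bound (every constant below is an arbitrary assumption, and σ is set to its worst-case value n/K): the absolute round counts are conservative, but the trend is the point, since increasing the local work H drives Θ toward 0, which enlarges the per-round progress and cuts the required number of communication rounds T accordingly.

```python
# Sketch: how local work H trades off against outer rounds T in the bound.
import math

lam, n, gamma, K = 1e-4, 500_000, 1.0, 8      # arbitrary illustrative values
sigma = n / K                                 # worst-case partition difficulty
n_local = n / K                               # local problem size (n-tilde)
for H in (1, 100, 10_000):
    theta = (1 - (lam * n * gamma) / (1 + lam * n * gamma) / n_local) ** H
    progress = (1 - theta) / K * (lam * n * gamma) / (sigma + lam * n * gamma)
    T = math.log(1e-6) / math.log(1 - progress)   # rounds for a 1e-6 reduction factor
    print(f"H={H:>6}:  per-round progress={progress:.3e},  T ≈ {T:.3e}")
```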
Dataset    Training (n)   Features (d)   Sparsity   Workers (K)
Cov        522,911        54             22.22%     4
Rcv1       677,399        47,236         0.16%      8
Imagenet   32,751         160,000        100%       32
[Figure: log primal suboptimality vs. time (s) on Imagenet, comparing COCOA (H=1e3), mini-batch-CD (H=1), local-SGD (H=1e3), and mini-batch-SGD (H=10)]
[Figure: log primal suboptimality vs. time (s) on Cov, comparing COCOA (H=1e5), mini-batch-CD (H=100), local-SGD (H=1e5), and batch-SGD (H=1)]
[Figure: log primal suboptimality vs. time (s) on RCV1, comparing COCOA (H=1e5), mini-batch-CD (H=100), local-SGD (H=1e4), and batch-SGD (H=100)]
[Figure: log primal suboptimality vs. time (s) for COCOA with varying H: 1e5, 1e4, 1e3, 100, 1]