SLIDE 1

Communication-efficient Distributed SGD with Sketching

Nikita Ivkin*, Daniel Rothchild*, Enayat Ullah*, Vladimir Braverman, Ion Stoica, Raman Arora

* equal contribution

SLIDE 2

Going distributed: why?

  • Large-scale machine learning is moving to the distributed setting due to the growing size of datasets and models, which no longer fit on a single GPU, and modern learning paradigms like Federated learning.

  • Master-workers topology: workers compute gradients and communicate them to the master; the master aggregates these gradients, updates the model, and communicates the updated parameters back.

  • Problem - Slow communication overwhelms local computations.
  • Resolution(s) - Compress the gradients

    ○ Intrinsic low-dimensional structure
    ○ Trade off communication against convergence

  • Examples of compression: sparsification, quantization

SLIDE 3

Going distributed: how?

[Diagram: ways to distribute - data parallelism (most popular), model parallelism, hybrid]

SLIDE 4

Going distributed: how?

[Diagram: synchronization strategies - parameter server sync, all-gather, hybrid topology; data split into batch 1 … batch m]

SLIDE 5

Going distributed: how?

[Diagram: parameter server and workers 1…m; the data is split into mini-batches, with batch i on worker i]

Synchronization with the parameter server:

SLIDE 6

Going distributed: how?

[Diagram: parameter server and workers 1…m; the data is split into mini-batches, with batch i on worker i]

Synchronization with the parameter server:

  • mini-batches distributed among workers
SLIDE 7

Going distributed: how?

[Diagram: parameter server and workers 1…m; the data is split into mini-batches, with batch i on worker i]

Synchronization with the parameter server:

  • mini-batches distributed among workers
  • each worker makes a forward-backward pass and computes its gradient

g1 g2 gm

SLIDE 8

Going distributed: how?

[Diagram: parameter server and workers 1…m; the data is split into mini-batches, with batch i on worker i]

Synchronization with the parameter server:

  • mini-batches distributed among workers
  • each worker makes a forward-backward pass and computes its gradient
  • workers send gradients to the parameter server

g1 g2 gm

SLIDE 9

Going distributed: how?

[Diagram: parameter server and workers 1…m; the data is split into mini-batches, with batch i on worker i]

Synchronization with the parameter server:

  • mini-batches distributed among workers
  • each worker makes a forward-backward pass and computes its gradient
  • workers send gradients to the parameter server

g1, g2, …, gm

SLIDE 10

Going distributed: how?

[Diagram: parameter server and workers 1…m; the data is split into mini-batches, with batch i on worker i]

Synchronization with the parameter server:

  • mini-batches distributed among workers
  • each worker makes a forward-backward pass and computes its gradient
  • workers send gradients to the parameter server
  • parameter server sums them up and sends the result back to all workers

G = g1+ g2+ … + gm

SLIDE 11

Going distributed: how?

[Diagram: parameter server and workers 1…m; the data is split into mini-batches, with batch i on worker i]

Synchronization with the parameter server:

  • mini-batches distributed among workers
  • each worker makes a forward-backward pass and computes its gradient
  • workers send gradients to the parameter server
  • parameter server sums them up and sends the result back to all workers

G

SLIDE 12

Going distributed: how?

[Diagram: parameter server and workers 1…m; the data is split into mini-batches, with batch i on worker i]

Synchronization with the parameter server:

  • mini-batches distributed among workers
  • each worker makes a forward-backward pass and computes its gradient
  • workers send gradients to the parameter server
  • parameter server sums them up and sends the result back to all workers

G G G

SLIDE 13

Going distributed: how?

[Diagram: parameter server and workers 1…m; the data is split into mini-batches, with batch i on worker i]

Synchronization with the parameter server:

  • mini-batches distributed among workers
  • each worker makes a forward-backward pass and computes its gradient
  • workers send gradients to the parameter server
  • parameter server sums them up and sends the result back to all workers

  • each worker makes a step
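
Putting the whole round together, here is a minimal sketch of one synchronization step in this master-workers setup (plain NumPy; the `workers` collection and its `compute_gradient` method are illustrative names, not from the talk):

```python
import numpy as np

def sync_sgd_round(workers, params, lr):
    """One parameter-server round of distributed SGD (illustrative).

    Each worker computes a gradient on its own mini-batch, the server
    sums the gradients, and every worker applies the same update.
    """
    grads = [w.compute_gradient(params) for w in workers]  # g1, ..., gm
    G = np.sum(grads, axis=0)                              # server aggregates
    return params - lr * G                                 # each worker makes a step
```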
SLIDE 14

Going distributed: what’s the problem?

[Diagram: parameter server and workers 1…m; the data is split into mini-batches, with batch i on worker i]

  • Slow communication overwhelms local computation:
    ○ the parameter vector for large models can weigh up to 0.5 GB
    ○ workers synchronize every fraction of a second
    ○ the entire parameter vector is sent at every synchronization

SLIDE 15

Going distributed: what’s the problem?

[Diagram: parameter server and workers 1…m; the data is split into mini-batches, with batch i on worker i]

  • Slow communication overwhelms local computation:
    ○ the parameter vector for large models can weigh up to 0.5 GB
    ○ workers synchronize every fraction of a second
    ○ the entire parameter vector is sent at every synchronization
  • Mini-batch size has a limit to its growth, so computation resources are wasted

SLIDE 16

Going distributed: how do others deal with it?

  • Compressing the gradients:
    ○ Quantization
    ○ Sparsification

SLIDE 17

Quantization

[1] Wen, Wei, et al. "TernGrad: Ternary gradients to reduce communication in distributed deep learning." Advances in Neural Information Processing Systems. 2017.
[2] Bernstein, Jeremy, et al. "signSGD: Compressed optimisation for non-convex problems." arXiv preprint arXiv:1802.04434 (2018).
[3] Karimireddy, Sai Praneeth, et al. "Error Feedback Fixes SignSGD and Other Gradient Compression Schemes." arXiv preprint arXiv:1901.09847 (2019).

  • Quantizing gradients can give a constant-factor decrease in communication cost.
  • The simplest quantization is to 16 bits, but quantization all the way down to 2 bits (TernGrad [1]) and 1 bit (signSGD [2]) has been successful.
  • Quantization techniques can in principle be combined with gradient sparsification.
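
As a rough illustration of the 1-bit end of this spectrum, a signSGD-style quantizer can be sketched as below (NumPy; scaling by the mean absolute value is one common choice and is an assumption here, not necessarily the exact scheme of [1] or [2]):

```python
import numpy as np

def quantize_sign(g):
    """1-bit quantization: keep only the sign of each coordinate plus one
    scalar scale, so d coordinates cost roughly d bits instead of 32d bits."""
    scale = np.abs(g).mean()
    return np.sign(g).astype(np.int8), scale

def dequantize_sign(signs, scale):
    """Reconstruct an approximate gradient from the signs and the scale."""
    return signs.astype(np.float32) * scale
```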

SLIDE 18

Sparsification

[1] Stich, Sebastian U., Jean-Baptiste Cordonnier, and Martin Jaggi. "Sparsified SGD with memory." Advances in Neural Information Processing Systems. 2018.
[2] Alistarh, Dan, et al. "The convergence of sparsified gradient methods." Advances in Neural Information Processing Systems. 2018.
[3] Lin, Yujun, et al. "Deep gradient compression: Reducing the communication bandwidth for distributed training." arXiv preprint arXiv:1712.01887 (2017).

  • Existing techniques either communicate Ω(Wd) in the worst case or are heuristics; W is the number of workers, d is the dimension of the gradient.
  • [1] showed that SGD (on one machine) with top-k gradient updates and error accumulation has desirable convergence properties.
  • Q: Can we extend top-k to the distributed setting?
    ○ MEM-SGD [1] (for one machine; the extension to the distributed setting is sequential)
    ○ top-k SGD [2] (assumes the global top-k is close to the sum of the local top-k's)
    ○ Deep gradient compression [3] (no theoretical guarantees)
  • We resolve the above using sketches!
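
For intuition, a single-machine top-k step with error accumulation in the spirit of MEM-SGD [1] might look like this minimal NumPy sketch (variable names are illustrative):

```python
import numpy as np

def topk_with_error_feedback(g, error, k):
    """Top-k sparsification with error accumulation (MEM-SGD style).

    Coordinates that are not transmitted this round are remembered in
    `error` and added back to the gradient on the next round.
    """
    acc = g + error                              # fold in previously dropped mass
    idx = np.argpartition(np.abs(acc), -k)[-k:]  # k largest-magnitude coordinates
    sparse = np.zeros_like(acc)
    sparse[idx] = acc[idx]                       # only these k values are sent
    return sparse, acc - sparse                  # new error = what was dropped
```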
SLIDE 19

SLIDE 20

[Figure: balls of several colors with frequencies 9, 4, 2, 5, 2, 3; want to find the most frequent elements (heavy hitters)]

SLIDE 21

SLIDE 22

[Figure: stream items are multiplied by random signs, +1 or -1 equiprobably and independently, and accumulated in a single counter]

SLIDE 23

Count Sketch

[Figure: each coordinate update is routed to one counter per row by a bucket hash and multiplied by a +1/-1 sign hash before being added]

SLIDE 24

Count Sketch
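
A minimal Count Sketch in NumPy, following the sign-hash/bucket-hash picture above (this is an illustrative sketch, not the authors' implementation; for simplicity the hashes are stored as explicit random tables):

```python
import numpy as np

class CountSketch:
    """r rows of c counters; coordinate i goes to one bucket per row
    (bucket hash) and is multiplied by a random +/-1 (sign hash)."""

    def __init__(self, d, r, c, seed=0):
        rng = np.random.default_rng(seed)          # shared seed => shared hashes
        self.bucket = rng.integers(0, c, size=(r, d))
        self.sign = rng.choice([-1.0, 1.0], size=(r, d))
        self.table = np.zeros((r, c))

    def update(self, g):
        """Add a d-dimensional vector g into the sketch."""
        for row in range(self.table.shape[0]):
            np.add.at(self.table[row], self.bucket[row], self.sign[row] * g)

    def estimate(self, i):
        """Estimate coordinate i as the median over rows of its
        sign-corrected counter."""
        ests = [self.sign[row, i] * self.table[row, self.bucket[row, i]]
                for row in range(self.table.shape[0])]
        return np.median(ests)
```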

SLIDE 25

Mergeability
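
Mergeability is what makes sketches fit the parameter-server pattern: the sketch is a linear map, so adding the workers' tables entry-wise gives exactly the sketch of the summed gradient. A small check using the hypothetical CountSketch class sketched above (all sketches must share the same hash functions, here via a shared seed):

```python
import numpy as np

d, r, c = 1000, 5, 100
g1, g2 = np.random.randn(d), np.random.randn(d)

s1, s2, s12 = CountSketch(d, r, c), CountSketch(d, r, c), CountSketch(d, r, c)
s1.update(g1)          # sketch on worker 1
s2.update(g2)          # sketch on worker 2
s12.update(g1 + g2)    # sketch of the summed gradient

# Linearity: S(g1) + S(g2) == S(g1 + g2), so the server merges by addition.
assert np.allclose(s1.table + s2.table, s12.table)
```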

SLIDE 26

Compression scheme

[Diagram: parameter server and workers 1…m; the data is split into mini-batches, with batch i on worker i]

Synchronization with the parameter server:

SLIDE 27

Compression scheme

[Diagram: parameter server and workers 1…m; the data is split into mini-batches, with batch i on worker i]

Synchronization with the parameter server:

  • mini-batches distributed among workers
SLIDE 28

Compression scheme

[Diagram: parameter server and workers 1…m; the data is split into mini-batches, with batch i on worker i]

Synchronization with the parameter server:

  • mini-batches distributed among workers
  • each worker makes a forward-backward pass, computes its gradient, and sketches it

g1 g2 gm

SLIDE 29

Compression scheme

[Diagram: parameter server and workers 1…m; the data is split into mini-batches, with batch i on worker i]

Synchronization with the parameter server:

  • mini-batches distributed among workers
  • each worker makes a forward-backward pass, computes its gradient, and sketches it

S(g1) S(g2) S(gm)

SLIDE 30

Compression scheme

[Diagram: parameter server and workers 1…m; the data is split into mini-batches, with batch i on worker i]

Synchronization with the parameter server:

  • mini-batches distributed among workers
  • each worker makes a forward-backward pass, computes its gradient, and sketches it
  • workers send sketches to the parameter server

S(g1) S(g2) S(gm)

SLIDE 31

Compression scheme

[Diagram: parameter server and workers 1…m; the data is split into mini-batches, with batch i on worker i]

Synchronization with the parameter server:

  • mini-batches distributed among workers
  • each worker makes a forward-backward pass, computes its gradient, and sketches it
  • workers send sketches to the parameter server

S1, S2, …, Sm

SLIDE 32

Compression scheme

[Diagram: parameter server and workers 1…m; the data is split into mini-batches, with batch i on worker i]

Synchronization with the parameter server:

  • mini-batches distributed among workers
  • each worker makes a forward-backward pass, computes its gradient, and sketches it
  • workers send sketches to the parameter server
  • parameter server merges the sketches, extracts the top-k, and sends it back

S = S1+ S2+ … + Sm

SLIDE 33

Compression scheme

[Diagram: parameter server and workers 1…m; the data is split into mini-batches, with batch i on worker i]

Synchronization with the parameter server:

  • mini-batches distributed among workers
  • each worker makes a forward-backward pass, computes its gradient, and sketches it
  • workers send sketches to the parameter server
  • parameter server merges the sketches, extracts the top-k, and sends it back

G’ = topk(S)

SLIDE 34

Compression scheme

[Diagram: parameter server and workers 1…m; the data is split into mini-batches, with batch i on worker i]

G’ G’ G’

Synchronization with the parameter server:

  • mini-batches distributed among workers
  • each worker makes a forward-backward pass, computes its gradient, and sketches it
  • workers send sketches to the parameter server
  • parameter server merges the sketches, extracts the top-k, and sends it back

SLIDE 35

Compression scheme

[Diagram: parameter server and workers 1…m; the data is split into mini-batches, with batch i on worker i]

Synchronization with the parameter server:

  • mini-batches distributed among workers
  • each worker makes a forward-backward pass, computes its gradient, and sketches it
  • workers send sketches to the parameter server
  • parameter server merges the sketches, extracts the top-k, and sends it back

  • each worker makes a step
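
One round of this scheme, sketched end to end in NumPy on top of the hypothetical CountSketch class from above. Recovering the top-k by estimating every coordinate, as done here, is a brute-force simplification for illustration; the actual algorithm recovers heavy coordinates more efficiently and also uses error accumulation.

```python
import numpy as np

def sketched_sgd_round(grads, d, k, r=5, c=10_000):
    """One SKETCHED-SGD-style round (illustrative simplification).

    Each worker sketches its gradient; the server merges the sketches by
    addition and recovers an approximate top-k of the summed gradient.
    """
    # --- on the workers ---
    sketches = []
    for g in grads:
        s = CountSketch(d, r, c)       # shared seed => identical hashes everywhere
        s.update(g)
        sketches.append(s)             # only s.table needs to be communicated

    # --- on the parameter server ---
    merged = CountSketch(d, r, c)
    merged.table = sum(s.table for s in sketches)          # merge = entry-wise add

    est = np.array([merged.estimate(i) for i in range(d)])  # brute-force recovery
    top = np.argpartition(np.abs(est), -k)[-k:]
    G = np.zeros(d)
    G[top] = est[top]                  # G' = topk(S), broadcast to all workers
    return G
```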
SLIDE 36

Algorithm and theory

❑ Theoretical guarantees

  • Converges at an O(1/(WT)) rate, on par with SGD, for smooth strongly convex functions, where W is the number of workers and T the number of iterations.
  • Communicates O(k log² d) per round (the size of the sketch), with 0 < k < d, where d is the dimension of the model.

❑ Scalability

  • More workers: increasing the number of workers W increases the rate of convergence (suitable for Federated learning).
  • Bigger models: increasing the model size d increases the compression ratio d / (k log² d).
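
To make the "bigger models" point concrete, here is a constant-free back-of-the-envelope calculation (the specific numbers are assumptions for illustration, not from the slides):

```python
import math

def compression_ratio(d, k):
    """Crude estimate of d / (k log^2 d), ignoring constants hidden in the O()."""
    return d / (k * math.log2(d) ** 2)

# For a fixed k, the ratio improves as the model grows:
print(compression_ratio(10_000_000, 1_000))    # ~18x for a 10M-parameter model
print(compression_ratio(100_000_000, 1_000))   # ~142x for a 100M-parameter model
```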

SLIDE 37

Empirical Results

[Figure: BLEU scores on the test data achieved by vanilla distributed SGD, top-k SGD, and SKETCHED-SGD with 20x and 40x compression; a larger BLEU score is better]

SLIDE 38

Empirical Results

Comparison between SKETCHED-SGD and local top-k SGD on CIFAR10. The best overall compression that local top-k can achieve for many workers is 2x.

SLIDE 39

Computational overhead

[Diagram: the sketching step splits into computing hashes and updating counters]

Simple to parallelize the sketching part: 100x acceleration on a modern GPU.

Specifics of the distributed SGD application:

  • the gradient vector is already on the GPU
  • for reasonable d, all hashes can be precomputed
  • one-liner to parallelize using the PyTorch framework (20x speed-up)
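
A sketch of what that one-liner might look like in PyTorch with precomputed hash tables (illustrative; the shapes, names, and the scatter_add formulation are assumptions, not the authors' code):

```python
import torch

d, r, c = 1_000_000, 5, 10_000
device = "cuda" if torch.cuda.is_available() else "cpu"

g = torch.randn(d, device=device)                    # gradient is already on the GPU
bucket = torch.randint(0, c, (r, d), device=device)  # precomputed bucket hashes
sign = torch.randint(0, 2, (r, d), device=device).float() * 2 - 1  # precomputed +/-1 signs

# The whole sketch update is effectively a one-liner:
table = torch.zeros(r, c, device=device).scatter_add_(1, bucket, sign * g)
```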
SLIDE 40

Thanks a lot!