The Non-IID Data Quagmire of Decentralized Machine Learning
ICML 2020
Kevin Hsieh, Amar Phanishayee, Onur Mutlu, Phillip Gibbons
ML Training with Decentralized Data
Geo-Distributed Learning, Federated Learning, Data Sovereignty and Privacy
Major Challenges in Decentralized ML (Federated Learning, Geo-Distributed Learning)
Challenge 1: Communication bottlenecks
Solutions: Federated Averaging, Gaia, Deep Gradient Compression
Challenge 2: Data are often highly skewed (non-IID data)
Solutions: Understudied! Is it a real problem?
Our Work in a Nutshell
Real-World Dataset
- Geographical mammal images from Flickr
- 736K pictures in 42 mammal classes
- Highly skewed labels among geographic regions
Experimental Study
- Skewed data labels are a fundamental and pervasive problem
- The problem is even worse for DNNs with batch normalization
- The degree of skew determines the difficulty of the problem
Proposed Solution
- Replace batch normalization with group normalization
- SkewScout: communication-efficient decentralized learning over arbitrarily skewed data
Real-World Dataset
Flickr-Mammal Dataset
- 42 mammal classes from Open Images and ImageNet
- 40,000 images per class
- Clean images with PNAS [Liu et al., ’18]
- Reverse geocoding to country, subcontinent, and continent
736K Pictures with Labels and Geographic Information
https://doi.org/10.5281/zenodo.3676081
Top-3 Mammals in Each Continent
Each continent holds a 44-92% share of the global images of each of its top-3 mammals
Label Distribution Across Continents
[Figure: share of each mammal class's images (0-100%) by continent (Africa, Americas, Asia, Europe, Oceania). Classes: alpaca, antelope, armadillo, brown bear, bull, camel, cat, cattle, cheetah, deer, dolphin, elephant, fox, goat, hamster, harbor seal, hedgehog, hippopotamus, jaguar, kangaroo, koala, leopard, lion, lynx, monkey, mule, otter, panda, pig, polar bear, porcupine, rabbit, red panda, sea lion, sheep, skunk, squirrel, teddy bear, tiger, whale, zebra]
The vast majority of mammals are dominated by 2-3 continents. The labels are even more skewed among subcontinents.
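To make "dominated by a few regions" concrete, here is a minimal sketch that computes, for each label, the share of its images that falls in its top-k regions. The counts are made-up toy numbers, not Flickr-Mammal data, and `top_k_share` is a hypothetical helper.

```python
from collections import defaultdict


def top_k_share(counts, k=3):
    """Return {label: fraction of that label's images in its k largest regions}.

    counts: iterable of (label, region, n_images) tuples.
    """
    per_label = defaultdict(dict)
    for label, region, n in counts:
        per_label[label][region] = per_label[label].get(region, 0) + n
    shares = {}
    for label, regions in per_label.items():
        total = sum(regions.values())
        top = sorted(regions.values(), reverse=True)[:k]
        shares[label] = sum(top) / total if total else 0.0
    return shares


# Toy counts (hypothetical, for illustration only).
toy_counts = [
    ("kangaroo", "Oceania", 900), ("kangaroo", "Europe", 60),
    ("kangaroo", "Americas", 30), ("kangaroo", "Asia", 10),
    ("cat", "Africa", 200), ("cat", "Americas", 200), ("cat", "Asia", 200),
    ("cat", "Europe", 200), ("cat", "Oceania", 200),
]

for label, share in sorted(top_k_share(toy_counts, k=2).items()):
    print(f"{label}: top-2 regions hold {share:.0%} of its images")
```

A skewed label such as the toy "kangaroo" concentrates almost all of its images in one or two regions, while a uniformly spread label does not.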
Experimental Study
Scope of Experimental Study
ML Application × Decentralized Learning Algorithms × Skewness of Data Label Partitions
ML Application:
- Image classification (with various DNNs and datasets)
- Face recognition
Decentralized Learning Algorithms:
- Gaia [NSDI’17]
- FederatedAveraging [AISTATS’17]
- DeepGradientCompression [ICLR’18]
Skewness of Data Label Partitions:
- 2-5 partitions (more partitions are worse)
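A minimal sketch of how fully label-skewed partitions can be produced for such an experiment, next to the shuffled (IID) baseline. It assumes a simple scheme in which every example of a class goes to one partition, with classes dealt round-robin; this is an illustration, not necessarily the study's exact procedure, and both helper names are hypothetical.

```python
import random


def partition_by_label(labels, num_partitions, seed=0):
    """Fully skewed split: all examples of a class go to the same partition.

    labels: list of class labels, one per example.
    Returns a list of index lists, one per partition.
    """
    classes = sorted(set(labels))
    random.Random(seed).shuffle(classes)  # randomize which classes land together
    owner = {c: i % num_partitions for i, c in enumerate(classes)}
    parts = [[] for _ in range(num_partitions)]
    for idx, y in enumerate(labels):
        parts[owner[y]].append(idx)
    return parts


def shuffled_partitions(n_examples, num_partitions, seed=0):
    """IID baseline: deal a random permutation of the examples evenly."""
    idx = list(range(n_examples))
    random.Random(seed).shuffle(idx)
    return [idx[p::num_partitions] for p in range(num_partitions)]


labels = [i % 10 for i in range(1000)]  # toy: 10 classes, 100 examples each
skewed = partition_by_label(labels, 2)
iid = shuffled_partitions(len(labels), 2)
print([sorted({labels[i] for i in p}) for p in skewed])  # disjoint label sets
print([len({labels[i] for i in p}) for p in iid])        # each partition sees all classes
```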
Results: GoogLeNet over CIFAR-10
[Figure: Top-1 validation accuracy (0-80%) on shuffled vs. skewed data; annotated accuracy losses on skewed data: -12%, -15%, -69%]
- BSP (Bulk Synchronous Parallel)
- Gaia (20X faster than BSP)
- FederatedAveraging (20X faster than BSP)
- DeepGradientCompression (30X faster than BSP)
All decentralized learning algorithms lose significant accuracy. Tight synchronization (BSP) is accurate but too slow.
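The speed/accuracy trade-off above comes from how often workers synchronize. As a minimal sketch, assuming a toy one-parameter least-squares model and plain gradient steps (not the paper's setup): BSP synchronizes after every step, while Federated Averaging lets each partition take several local steps and then averages the models weighted by partition size. The helper names are hypothetical.

```python
def grad(w, data):
    """Gradient of the mean squared error of the model y = w * x on (x, y) pairs."""
    return sum(2 * (w * x - y) * x for x, y in data) / len(data)


def bsp_step(w, partitions, lr):
    """BSP: one synchronized step on the gradient of the pooled data."""
    total = sum(len(p) for p in partitions)
    g = sum(grad(w, p) * len(p) for p in partitions) / total
    return w - lr * g


def fedavg_round(w, partitions, lr, local_steps):
    """FedAvg: local steps on each partition, then a size-weighted model average."""
    total = sum(len(p) for p in partitions)
    new_w = 0.0
    for p in partitions:
        w_local = w
        for _ in range(local_steps):
            w_local -= lr * grad(w_local, p)
        new_w += w_local * len(p) / total
    return new_w


# Two skewed partitions whose local optima differ: y = 1 * x versus y = 3 * x.
part_a = [(1.0, 1.0), (2.0, 2.0)]
part_b = [(1.0, 3.0), (2.0, 6.0)]
w_bsp = w_fed = 0.0
for _ in range(50):  # 50 synchronizations for each method
    w_bsp = bsp_step(w_bsp, [part_a, part_b], lr=0.05)
    w_fed = fedavg_round(w_fed, [part_a, part_b], lr=0.05, local_steps=10)
print(round(w_bsp, 3), round(w_fed, 3))
```

With `local_steps=1` and equal partition sizes, a FedAvg round reduces to a BSP step; more local steps do more work per synchronization, which is where the communication savings come from. On this symmetric toy problem both reach the pooled optimum, but with skewed, asymmetric partitions the averaged local models need not match it.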
Similar Results across the Board
[Figure: Top-1 validation accuracy on shuffled vs. skewed data. Panels: Image Classification (CIFAR-10) with AlexNet, LeNet, and ResNet20 (BSP, Gaia, FederatedAveraging, DeepGradientCompression); Image Classification (ImageNet) with GoogLeNet and ResNet10; Image Classification (Flickr-Mammal) and Face Recognition (CASIA, tested with LFW), each with BSP, Gaia, and FedAvg]
Skewed data is a pervasive and fundamental problem. Even BSP loses accuracy for DNNs with Batch Normalization layers.
Degree of Skew is a Key Factor
[Figure: CIFAR-10 with GN-LeNet. Top-1 validation accuracy (60-80%) for BSP, Gaia, Federated Averaging, and Deep Gradient Compression at 20%, 40%, 60%, and 80% skewed data; annotated accuracy losses, in chart order: -1.3%, -0.5%, -1.1%, -3.0%, -1.5%, -2.6%, -4.8%, -3.5%, -6.5%, -5.3%, -5.1%, -8.5%]
The degree of skew can determine the difficulty of the problem
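One way to dial the degree of skew, assuming (as a sketch) that "x% skewed data" means x% of the examples are assigned to partitions by label and the rest are spread uniformly at random. The paper's exact procedure may differ, and `skewed_partitions` is a hypothetical helper.

```python
import random


def skewed_partitions(labels, num_partitions, skew, seed=0):
    """Split example indices into partitions with a tunable degree of label skew.

    A fraction `skew` (0.0-1.0) of the examples is assigned by label (each class
    owned by one partition); the rest is shuffled and dealt out round-robin.
    """
    rng = random.Random(seed)
    idx = list(range(len(labels)))
    rng.shuffle(idx)
    n_skewed = int(round(skew * len(idx)))
    classes = sorted(set(labels))
    owner = {c: i % num_partitions for i, c in enumerate(classes)}
    parts = [[] for _ in range(num_partitions)]
    for i in idx[:n_skewed]:                 # label-partitioned portion
        parts[owner[labels[i]]].append(i)
    for j, i in enumerate(idx[n_skewed:]):   # uniformly spread portion
        parts[j % num_partitions].append(i)
    return parts


labels = [i % 10 for i in range(1000)]  # toy: 10 classes, 100 examples each
for skew in (0.2, 0.6, 1.0):
    parts = skewed_partitions(labels, 2, skew)
    # With 2 partitions, partition 0 owns the even-numbered classes.
    own = sum(labels[i] % 2 == 0 for i in parts[0]) / len(parts[0])
    print(f"{skew:.0%} skewed: {own:.0%} of partition 0 comes from its own classes")
```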
Batch Normalization: Problem and Solution
Background: Batch Normalization
[Diagram: Prev Layer → W → BN → Next Layer]
- Batch normalization [Ioffe & Szegedy, 2015] normalizes activations to a standard normal distribution (μ = 0, σ = 1) using the statistics of each minibatch at training time
- It enables larger learning rates and helps avoid sharp local minima (better generalization)
- At test time, it normalizes with the estimated global μ and σ
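A minimal NumPy sketch of the batch normalization forward pass described above, for inputs of shape (N, C, H, W): at training time each channel is normalized with the current minibatch's μ and σ, which also update running estimates; at test time those running (estimated global) μ and σ are used instead. The learnable scale and shift of the real layer are omitted, and the class name and the `momentum`/`eps` defaults are assumptions.

```python
import numpy as np


class BatchNorm2dSketch:
    """Per-channel batch normalization over (N, H, W), without scale/shift."""

    def __init__(self, num_channels, momentum=0.1, eps=1e-5):
        self.momentum = momentum
        self.eps = eps
        self.running_mean = np.zeros(num_channels)
        self.running_var = np.ones(num_channels)

    def __call__(self, x, training):
        if training:
            # Statistics of this minibatch only, one value per channel.
            mean = x.mean(axis=(0, 2, 3))
            var = x.var(axis=(0, 2, 3))
            # Exponential moving averages act as the "global" estimates.
            m = self.momentum
            self.running_mean = (1 - m) * self.running_mean + m * mean
            self.running_var = (1 - m) * self.running_var + m * var
        else:
            mean, var = self.running_mean, self.running_var
        return (x - mean[None, :, None, None]) / np.sqrt(var[None, :, None, None] + self.eps)


rng = np.random.default_rng(0)
bn = BatchNorm2dSketch(num_channels=3)
x = rng.normal(loc=5.0, scale=2.0, size=(8, 3, 4, 4))
y = bn(x, training=True)
print(y.mean(axis=(0, 2, 3)).round(6), y.std(axis=(0, 2, 3)).round(3))
```

The training-time output depends on every other example in the minibatch, which is exactly what becomes fragile when each partition's minibatches contain different labels.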
Batch Normalization with Skewed Data
[Figure: CIFAR-10 with BN-LeNet (2 partitions). Minibatch mean divergence (0-70%) per channel (channels 1-30) for shuffled vs. skewed data]
Minibatch Mean Divergence: ||Mean1 - Mean2|| / AVG(Mean1, Mean2)
Minibatch μ and σ vary significantly among partitions, so the global μ and σ do not work for all partitions.
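The divergence metric from the chart can be computed per channel as in this minimal NumPy sketch. Mean1 and Mean2 are the per-channel minibatch means from two partitions; we assume the means are taken over the (N, H, W) axes, that for a single channel the norm reduces to an absolute value, and that activations are non-negative (e.g. after ReLU) so the average is positive. The activations below are synthetic.

```python
import numpy as np


def minibatch_mean_divergence(batch1, batch2):
    """Per-channel ||Mean1 - Mean2|| / AVG(Mean1, Mean2) for two (N, C, H, W) minibatches."""
    mean1 = batch1.mean(axis=(0, 2, 3))
    mean2 = batch2.mean(axis=(0, 2, 3))
    return np.abs(mean1 - mean2) / ((mean1 + mean2) / 2)


rng = np.random.default_rng(0)
# Hypothetical activations: some channels respond more strongly on partition 2,
# mimicking partitions that hold different labels.
part1 = rng.uniform(0.0, 1.0, size=(32, 4, 8, 8))
scale = np.array([1.0, 1.5, 2.0, 3.0])[None, :, None, None]
part2 = rng.uniform(0.0, 1.0, size=(32, 4, 8, 8)) * scale
print(minibatch_mean_divergence(part1, part2).round(2))
```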
Solution: Use Group Normalization [Wu and He, ECCV’18]
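[Diagram: batch normalization computes statistics per channel (C) across the N and H, W axes; group normalization computes them per example (N) across a group of channels and H, W]
- Designed for small minibatches
- We apply it as a solution for skewed data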
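To see why group normalization sidesteps the problem, here is a minimal NumPy sketch (scale and shift omitted; `num_groups` and `eps` chosen arbitrarily; function names hypothetical). Group normalization computes statistics per example over a group of channels and H, W, so an example's output does not depend on which other examples, or which partition's labels, share its minibatch. A batch-statistics normalization is included for contrast.

```python
import numpy as np


def group_norm(x, num_groups, eps=1e-5):
    """Normalize an (N, C, H, W) array per example over channel groups and H, W."""
    n, c, h, w = x.shape
    assert c % num_groups == 0, "channels must divide evenly into groups"
    g = x.reshape(n, num_groups, c // num_groups, h, w)
    mean = g.mean(axis=(2, 3, 4), keepdims=True)
    var = g.var(axis=(2, 3, 4), keepdims=True)
    return ((g - mean) / np.sqrt(var + eps)).reshape(n, c, h, w)


def batch_norm_stats_only(x, eps=1e-5):
    """Batch-statistics normalization per channel over (N, H, W), for contrast."""
    mean = x.mean(axis=(0, 2, 3), keepdims=True)
    var = x.var(axis=(0, 2, 3), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


rng = np.random.default_rng(0)
sample = rng.normal(size=(1, 8, 4, 4))
# The same example placed in two minibatches with very different companions.
batch_a = np.concatenate([sample, rng.normal(loc=0.0, size=(7, 8, 4, 4))])
batch_b = np.concatenate([sample, rng.normal(loc=5.0, size=(7, 8, 4, 4))])
print(np.allclose(group_norm(batch_a, 4)[0], group_norm(batch_b, 4)[0]))  # prints True
print(np.allclose(batch_norm_stats_only(batch_a)[0], batch_norm_stats_only(batch_b)[0]))  # prints False
```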
Results with Group Normalization
[Figure: validation accuracy (0-80%) on shuffled vs. skewed data for BSP, Gaia, Federated Averaging, and Deep Gradient Compression, with BatchNorm and with GroupNorm]
Accuracy loss on skewed data vs. shuffled data:
Algorithm                    BatchNorm   GroupNorm
BSP                          -12%        0%
Gaia                         -26%        -15%
Federated Averaging          -29%        -10%
Deep Gradient Compression    -70%        -9%
GroupNorm recovers the accuracy loss for BSP and reduces the accuracy losses for the decentralized algorithms.
SkewScout: Decentralized Learning over Arbitrarily Skewed Data
Overview of SkewScout
- Recall that the degree of data skew determines the difficulty
- SkewScout: adapts communication to the skew-induced accuracy loss
[Diagram components: Model Traveling, Accuracy Loss Estimation, Communication Control]
- Minimizes communication when the accuracy loss is acceptable
- Works with different decentralized learning algorithms
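The control idea can be pictured with a deliberately simplified sketch. Everything here is an assumption for illustration: a model trained on one partition is "traveled" to another partition and evaluated there, the accuracy drop serves as the estimate of skew-induced loss, and a single communication knob (a hypothetical number of local steps between synchronizations) is relaxed while the estimate stays under a target and tightened otherwise. This is not SkewScout's actual algorithm; all names and constants are hypothetical.

```python
from dataclasses import dataclass


@dataclass
class CommunicationController:
    """Hypothetical feedback controller for a communication-reducing knob.

    `interval` stands for the number of local steps between synchronizations:
    a larger interval means less communication.
    """
    target_loss: float  # acceptable accuracy loss, e.g. 0.02 for 2 points
    interval: int = 1
    min_interval: int = 1
    max_interval: int = 256

    def update(self, estimated_loss: float) -> int:
        if estimated_loss <= self.target_loss:
            # Loss is acceptable: communicate less often.
            self.interval = min(self.interval * 2, self.max_interval)
        else:
            # Loss is too high: synchronize more often.
            self.interval = max(self.interval // 2, self.min_interval)
        return self.interval


def estimate_accuracy_loss(acc_at_home: float, acc_after_travel: float) -> float:
    """Accuracy drop observed when a model is evaluated on another partition."""
    return max(0.0, acc_at_home - acc_after_travel)


ctrl = CommunicationController(target_loss=0.02)
# Hypothetical measurements: (accuracy at home, accuracy after traveling).
for home, visited in [(0.70, 0.69), (0.72, 0.715), (0.74, 0.70), (0.75, 0.74)]:
    loss = estimate_accuracy_loss(home, visited)
    print(f"estimated loss {loss:.3f} -> sync every {ctrl.update(loss)} steps")
```

Multiplicative increase and decrease is just one simple choice of control rule; the slide only states that communication is minimized while the estimated loss is acceptable.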
Evaluation of SkewScout
Communication saving over BSP (times); all data points achieve the same validation accuracy
CIFAR-10 with AlexNet
             20% Skewed   60% Skewed   100% Skewed
SkewScout    34.1         19.9         9.6
Oracle       51.8         24.9         10.6
CIFAR-10 with GoogLeNet
             20% Skewed   60% Skewed   100% Skewed
SkewScout    29.6         19.1         9.9
Oracle       42.1         23.6         11.0
- Significant savings over BSP
- Only within 1.5X more communication than Oracle
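The "within 1.5X of Oracle" claim can be checked from the reported savings. If saving = BSP traffic / method traffic, then Oracle saving / SkewScout saving equals SkewScout traffic / Oracle traffic. A short sketch with the numbers copied from the evaluation charts:

```python
# Communication savings over BSP (times), from the evaluation charts.
savings = {
    "AlexNet": {"SkewScout": [34.1, 19.9, 9.6], "Oracle": [51.8, 24.9, 10.6]},
    "GoogLeNet": {"SkewScout": [29.6, 19.1, 9.9], "Oracle": [42.1, 23.6, 11.0]},
}
skew_levels = ["20%", "60%", "100%"]

for model, s in savings.items():
    for level, ours, oracle in zip(skew_levels, s["SkewScout"], s["Oracle"]):
        # Ratio of savings = SkewScout traffic relative to Oracle traffic.
        print(f"{model}, {level} skewed: {oracle / ours:.2f}X the Oracle's communication")
```

The largest ratio (AlexNet at 20% skew) comes out at about 1.5X, and the gap shrinks as the skew grows.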
Key Takeaways
- Flickr-Mammal dataset: highly skewed label distribution in the real world
- Skewed data is a pervasive problem
- Batch normalization is particularly problematic
- SkewScout: adapts decentralized learning over arbitrarily skewed data
- Group normalization is a good alternative to