Optimizing Federated Learning on Non-IID Data with Reinforcement Learning
Hao Wang*, Zakhary Kaplan*, Di Niu^, Baochun Li*
*University of Toronto, ^University of Alberta
INFOCOM’20
[Slide: machine intelligence on personal devices, e.g., Siri and Alexa.]
Federated Averaging Algorithm (FedAvg)
In each communication round, the server randomly selects a subset of devices. Each selected device trains a local model on its local data and uploads the resulting weights, which the server averages into the global model.
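The round just described can be sketched in a few lines; `local_train` is an assumed placeholder for any on-device training routine, and weights are plain lists for simplicity.

```python
import random

def fedavg_round(global_weights, devices, k, local_train):
    """One FedAvg communication round (sketch): randomly select k
    devices, let each train from the current global weights on its
    local data, then average the returned local weights."""
    selected = random.sample(devices, k)
    local_models = [local_train(dev, global_weights) for dev in selected]
    dim = len(global_weights)
    # element-wise average of the selected devices' weights
    return [sum(w[i] for w in local_models) / k for i in range(dim)]
```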
Dataset: MNIST (http://yann.lecun.com/exdb/mnist/)
[Figure: test accuracy (%) vs. communication round, comparing FedAvg-IID and FedAvg-non-IID; FedAvg converges more slowly on non-IID data.]
One existing remedy: the server distributes a globally shared dataset, and each device mixes an α fraction of the shared data into its private data.
Zhao, Yue, et al. "Federated Learning with Non-IID Data." arXiv preprint arXiv:1806.00582 (2018).
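Zhao et al.'s mitigation can be sketched as follows (a minimal illustration; the function name is mine, not from the paper):

```python
import random

def augment_with_shared(private_data, shared_data, alpha):
    """Mix a random alpha fraction of the server's globally shared
    dataset into a device's private training data."""
    k = int(alpha * len(shared_data))
    return private_data + random.sample(shared_data, k)
```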
Experiment setup: a two-layer CNN with 431,080 parameters; 100 devices, each holding 600 samples; non-IID data in which 80% of each device's samples share the same label (e.g., “6”).
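A sketch of how such a skewed split can be generated (the helper name and round-robin assignment of dominant classes are my assumptions):

```python
import random

def non_iid_partition(labels, n_devices=100, samples_per_device=600,
                      dominant_frac=0.8, n_classes=10):
    """Give each device `dominant_frac` of its samples from one
    dominant class and the rest from the remaining classes."""
    by_class = {c: [i for i, y in enumerate(labels) if y == c]
                for c in range(n_classes)}
    partitions = []
    for d in range(n_devices):
        dom = d % n_classes  # assign dominant classes round-robin
        n_dom = int(dominant_frac * samples_per_device)
        idx = random.choices(by_class[dom], k=n_dom)
        rest = [i for c in range(n_classes) if c != dom
                for i in by_class[c]]
        idx += random.choices(rest, k=samples_per_device - n_dom)
        partitions.append(idx)
    return partitions
```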
To visualize training, the 431,080-dimensional model weights are projected into a 2-dimensional space.
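The projection can be sketched with a plain SVD-based PCA (an illustration; the paper's exact projection method is not reproduced here):

```python
import numpy as np

def project_weights_2d(weight_snapshots):
    """Project rows of high-dimensional weight vectors onto their
    top two principal components for 2-D visualization."""
    W = weight_snapshots - weight_snapshots.mean(axis=0)
    # rows of Vt are the principal axes of the centered data
    _, _, Vt = np.linalg.svd(W, full_matrices=False)
    return W @ Vt[:2].T  # (n_snapshots, 2) coordinates
```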
[Figure: local and global model weights plotted in the projected space, axes C0 and C1.]
[Figure: test accuracy (%) vs. communication round for FedAvg-IID, FedAvg-non-IID, and K-Center-non-IID.]
It is difficult to select the appropriate subset of devices
[Diagram: candidate devices and the FL server.]
The DRL agent learns from episodes, i.e., trajectories of the form (state, action, reward, state′, action′, …, end), collected over many FL training runs.
State: the global model weights and the devices' local model weights, represented as a 100-dimension vector.
Action: naively selecting 10 devices from a pool of 100 yields C(100, 10) ≈ 1.73 × 10^13 possible actions.
Simplification: only one device is selected per step during RL training, so the action space becomes {1, 2, …, N} rather than all K-element subsets of N devices.
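The two action-space sizes are easy to verify:

```python
from math import comb

# joint selection of K=10 devices out of N=100
joint_actions = comb(100, 10)    # 17,310,309,456,440 ≈ 1.731e13
# one-device-per-step simplification
single_actions = 100
```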
At selection time, the agent assigns each device a score (e.g., 0.3, 0.5, 0.1, …, 0.2) and selects the top K devices.
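The final selection step is an ordinary top-K over the agent's per-device scores:

```python
def select_top_k(scores, k):
    """Return the indices of the k highest-scoring devices."""
    order = sorted(range(len(scores)), key=lambda i: scores[i],
                   reverse=True)
    return order[:k]

select_top_k([0.3, 0.5, 0.1, 0.2], 2)  # -> [1, 0]
```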
Reward: r_t = Ξ^(ω_t − Ω) − 1, where Ξ is a positive constant, ω_t is the training accuracy at communication round t, and Ω is the target accuracy; higher accuracy yields a higher reward.
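The reward shaping can be sketched as below; Ξ = 64 is an assumed value for the positive constant, not taken from the slides.

```python
def favor_reward(accuracy, target, xi=64.0):
    """Per-round reward r = xi**(accuracy - target) - 1: zero at the
    target accuracy, negative below it, growing quickly above it."""
    return xi ** (accuracy - target) - 1.0
```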
The agent maximizes the expected cumulative discounted reward Σ_{t=1}^{T} γ^(t−1) r_t, where γ ∈ (0, 1) is the discount factor. The goal is to find a function that, for a given state, points out the actions leading to the maximum cumulative return.
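The cumulative discounted return is computed as below (γ = 0.9 is an arbitrary default for the example):

```python
def discounted_return(rewards, gamma=0.9):
    """Sum_{t=1}^{T} gamma**(t-1) * r_t over one episode's rewards."""
    return sum(gamma ** t * r for t, r in enumerate(rewards))

discounted_return([1.0, 1.0, 1.0], gamma=0.5)  # -> 1.75
```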
[Diagram: the DDQN agent observes the state (features from the FL server), outputs an action through a softmax layer over devices, and receives a reward from the environment.]
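The agent in the diagram is a double DQN; its defining trick is the decoupled bootstrap target, sketched here (network outputs are plain lists of Q-values for illustration):

```python
def ddqn_target(reward, next_q_online, next_q_target, gamma=0.9):
    """Double-DQN target: the online network *selects* the best next
    action, the target network *evaluates* it, reducing the
    overestimation bias of vanilla Q-learning."""
    best = max(range(len(next_q_online)), key=next_q_online.__getitem__)
    return reward + gamma * next_q_target[best]
```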
[Figure: cumulative discounted reward vs. training episode (1–171).]
Workflow with the DRL agent: device check-in, probing (collecting local model weights), device selection, local update, and global weight update.
Benchmarks: MNIST, FashionMNIST, CIFAR-10. Non-IID levels: 1, half-and-half, 80%, 50%.
[Figures: communication rounds required to reach the target accuracy on MNIST, FashionMNIST, and CIFAR-10 for FedAvg, K-Center, and Favor, at non-IID levels 1 (up to 2200 rounds), half-and-half (up to 1600), 80% (up to 240), and 50% (up to 70).]
[Figure: trajectories of local and global weights (w_init, w1, …, w5) in the projected space, axes C1 and C2.]
Conclusions: indirect data distribution probing; DRL-based device selection; communication rounds can be reduced by up to …