Optimizing Federated Learning on Non-IID Data with Reinforcement Learning

INFOCOM’20: Optimizing Federated Learning on Non-IID Data with Reinforcement Learning. Hao Wang*, Zakhary Kaplan*, Di Niu^, Baochun Li* (*University of Toronto, ^University of Alberta)


  1. INFOCOM’20: Optimizing Federated Learning on Non-IID Data with Reinforcement Learning. Hao Wang*, Zakhary Kaplan*, Di Niu^, Baochun Li* (*University of Toronto, ^University of Alberta)

  2. [Illustration: Alexa, Siri]

  3. Machine Learning

  4. Federated Learning

  5. Federated Averaging Algorithm (FedAvg)
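
The slide only names FedAvg, so the following is a minimal sketch of its server-side aggregation step. It assumes each device's model is flattened into a vector and that the server weights each local model by that device's sample count. The function name `fedavg` and the toy numbers are illustrative and do not come from the slides.

```python
import numpy as np

def fedavg(local_weights, num_samples):
    """Average local weight vectors, weighted by each device's sample count."""
    weights = np.asarray(local_weights, dtype=float)  # shape: (devices, params)
    counts = np.asarray(num_samples, dtype=float)     # shape: (devices,)
    return (counts[:, None] * weights).sum(axis=0) / counts.sum()

# Two hypothetical devices holding 100 and 300 samples.
global_w = fedavg([[1.0, 2.0], [3.0, 4.0]], [100, 300])
print(global_w)  # [2.5 3.5]
```

The device with more samples pulls the global model toward its own weights, which is one way skewed local data can bias the result.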

  6. Random selection [Diagram labels: Local model, Local data]

  8. Thank you for the feedback [Diagram labels: Local model, Local data]

  9. ML algorithms assume the training data is independent and identically distributed (IID)

  10. Federated Learning reuses the existing ML algorithms, but on non-IID data

  13. Non-IID data introduces bias into training, leading to slow convergence and training failures

  14. MNIST: http://yann.lecun.com/exdb/mnist/

  15. [Figure: accuracy (%) versus communication round (#) for FedAvg-IID and FedAvg-non-IID]

  16. Build IID training data? No, we don’t have any access to the data on your phone.

  17. [Figure labels: α × Shared Data, Private Data] Figure 6: Illustration of the data-… Source: Zhao, Yue, et al. "Federated Learning with Non-IID Data." arXiv preprint arXiv:1806.00582 (2018).

  18. Optimizing Federated Learning on Non-IID Data with Reinforcement Learning [INFOCOM’20]

  19. Build IID training data? No. Peek into the data distribution on each device without violating data privacy: probe the bias of non-IID data.

  20. Carefully select devices to balance the bias introduced by non-IID data

  21. Probing the data distribution

  22. Setup: 100 devices, each with 600 samples. Non-IID data: 80% of the data has the same label, e.g., “6”. Initial model: a two-layer CNN with 431,080 parameters. [Diagram label: Local model]
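
A minimal sketch of how label sets with this skew could be generated. It produces labels only, not actual images. The assumption that the remaining 20% is drawn uniformly from the other nine classes, along with the helper name `make_non_iid_device` and the seeds, is illustrative and not stated on the slide.

```python
import random

def make_non_iid_device(dominant_label, num_samples=600, dominant_frac=0.8,
                        num_classes=10, rng=None):
    """Return one device's labels: dominant_frac of them share a single label."""
    rng = rng or random.Random(0)
    n_dominant = round(num_samples * dominant_frac)
    labels = [dominant_label] * n_dominant
    # Assumption: the rest is drawn uniformly from the other classes.
    others = [c for c in range(num_classes) if c != dominant_label]
    labels += [rng.choice(others) for _ in range(num_samples - n_dominant)]
    rng.shuffle(labels)
    return labels

rng = random.Random(42)
# 100 devices, each with 600 labels and a randomly chosen dominant label.
devices = [make_non_iid_device(rng.randrange(10), rng=rng) for _ in range(100)]
print(len(devices), len(devices[0]))  # 100 600
```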

  23. We apply Principal Component Analysis (PCA) to reduce dimensionality: 431,080-dimensional model weights → 2-dimensional space
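
The reduction described above can be sketched with a plain SVD-based PCA. The random matrix below is a small stand-in for real local model weights (1,000 dimensions rather than 431,080), and the function name `pca_project` is hypothetical.

```python
import numpy as np

def pca_project(weights, n_components=2):
    """Project the rows of `weights` onto their top principal components."""
    X = np.asarray(weights, dtype=float)
    X_centered = X - X.mean(axis=0)  # center each dimension
    # Rows of Vt are principal directions, ordered by explained variance.
    _, _, Vt = np.linalg.svd(X_centered, full_matrices=False)
    return X_centered @ Vt[:n_components].T

rng = np.random.default_rng(0)
local_weights = rng.normal(size=(100, 1000))  # stand-in: 100 devices
points = pca_project(local_weights)
print(points.shape)  # (100, 2)
```

Each row of `points` is one device's local model in the 2-D space, which is what makes the models easy to plot and cluster.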

  24. [Figure: scatter plot in the 2-dimensional PCA space, axes C0 and C1]

  25. An implicit connection between model weights and data distribution

  26. Probing the data distribution; selecting devices for federated learning

  28. [Figure: scatter plot in the 2-dimensional PCA space, axes C0 and C1]

  29. K-Center Clustering

  30. Random Selection from Groups
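
A rough sketch of the two steps named above: cluster the projected models with K-center, then pick one device at random from each group. The slides do not say which K-center algorithm is used, so this uses the common greedy farthest-first heuristic as an assumption. The function names and the random 2-D points are illustrative.

```python
import math
import random

def k_center(points, k, rng):
    """Greedy farthest-first traversal: choose k centers, group points by nearest center."""
    centers = [rng.randrange(len(points))]
    # Distance from every point to its nearest chosen center so far.
    dist = [math.dist(p, points[centers[0]]) for p in points]
    while len(centers) < k:
        nxt = max(range(len(points)), key=dist.__getitem__)
        centers.append(nxt)
        dist = [min(d, math.dist(p, points[nxt])) for d, p in zip(dist, points)]
    groups = [[] for _ in centers]
    for i, p in enumerate(points):
        nearest = min(range(k), key=lambda c: math.dist(p, points[centers[c]]))
        groups[nearest].append(i)
    return groups

def select_from_groups(groups, rng):
    """Pick one device uniformly at random from each non-empty group."""
    return [rng.choice(g) for g in groups if g]

rng = random.Random(0)
points = [(rng.random(), rng.random()) for _ in range(100)]  # stand-in for PCA output
groups = k_center(points, k=10, rng=rng)
selected = select_from_groups(groups, rng)
print(len(selected))  # 10
```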

  31. [Figure: accuracy (%) versus communication round (#) for FedAvg-IID, FedAvg-non-IID, and K-Center-non-IID]

  32. Probing the data distribution; selecting devices for federated learning. How to select devices to speed up training?

  33. It is difficult to select the appropriate subset of devices: (1) model weights → device selection choice; (2) a dynamic and non-deterministic problem. Hence: Reinforcement Learning (RL)

  34. [Diagram labels: Agent, Environment, FL server, State, Action, Reward] Episode: (…, state, action, reward, state', action', …, end)

  35. Many episodes of the form (…, state, action, reward, state', action', …, end): learn to maximize sum(reward)

  36. States: global weights and local model weights … (100-dimension vector)

  37. Actions: select K devices from a pool of N devices, a huge action space. Selecting 10 devices from a pool of 100 devices leads to 1.7310309e+13 possible actions.
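
The count quoted above is the binomial coefficient C(100, 10), which can be checked with the standard library:

```python
import math

# Number of ways to choose 10 devices out of 100, ignoring order.
n_actions = math.comb(100, 10)
print(n_actions)           # 17310309456440
print(f"{n_actions:.7e}")  # 1.7310309e+13
```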

  38. Modify the RL training algorithm

  39. Selecting the Top K Devices: only one device is selected during the RL training. Now the action space is {1, 2, …, N}, instead of selecting K devices from N devices.

  40. Evaluating Each Device: each device receives a score (e.g., 0.3, 0.5, 0.1, …, 0.2); select the top K.
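
A minimal sketch of the top-K step, assuming the agent has already produced one score per device. The helper name `select_top_k` is hypothetical, and the four scores simply reuse the example values from the slide.

```python
def select_top_k(scores, k):
    """Return the indices of the k highest-scoring devices, best first."""
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return ranked[:k]

scores = [0.3, 0.5, 0.1, 0.2]   # one score per device
print(select_top_k(scores, 2))  # [1, 0]
```

Scoring devices individually keeps the action space at N entries while still yielding K devices per round.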

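  41. Rewards: $r_t = \Xi^{(\omega_t - \Omega)} - 1$, where $\Xi$ is a positive constant, $\omega_t$ is the training accuracy at communication round $t$, and $\Omega$ is the target accuracy, with $0 \leqslant \omega_t \leqslant \Omega \leqslant 1$ and $r_t \in (-1, 0]$. Accuracy increase: $\omega_t$ ↑ → $r_t$ ↑. More communication rounds: $t$ ↑ → sum($r_t$) ↓.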
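
The reward formula above can be sketched directly. The value Ξ = 64 and the target accuracy 0.99 are assumptions chosen only for illustration; the slide says only that Ξ is a positive constant.

```python
def reward(accuracy, target, xi=64.0):
    """r_t = xi ** (accuracy - target) - 1: zero at the target, negative below it."""
    return xi ** (accuracy - target) - 1.0

target = 0.99  # hypothetical target accuracy
for acc in (0.50, 0.90, 0.99):
    print(acc, reward(acc, target))
```

With Ξ > 1 the reward rises steeply as accuracy nears the target, and every round spent below the target adds a negative term to the total.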

  42. Training the DRL Agent: look for a function that points out the actions leading to the maximum cumulative return under a particular state. Maximize $R = \sum_{t=1}^{T} \gamma^{t-1} r_t = \sum_{t=1}^{T} \gamma^{t-1} \left( \Xi^{(\omega_t - \Omega)} - 1 \right)$, with discount factor $\gamma \in (0,1)$.
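
A small sketch of the cumulative discounted return, comparing a run that reaches the target quickly with one that takes more rounds. The accuracy traces, Ξ = 64, and γ = 0.9 are made-up values for illustration only.

```python
def discounted_return(accuracies, target, xi=64.0, gamma=0.9):
    """R = sum over t of gamma**(t-1) * (xi**(acc_t - target) - 1), with t from 1."""
    return sum(gamma ** (t - 1) * (xi ** (acc - target) - 1.0)
               for t, acc in enumerate(accuracies, start=1))

target = 0.99
fast = [0.50, 0.90, 0.99]                    # reaches the target in 3 rounds
slow = [0.50, 0.60, 0.70, 0.80, 0.90, 0.99]  # needs 6 rounds
print(discounted_return(fast, target))
print(discounted_return(slow, target))
```

The slower run accumulates more negative rewards, so maximizing R favors device selections that reach the target accuracy in fewer communication rounds.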

  43. [Diagram labels: Agent (DDQN), Environment, FL server, Features, softmax, State $s_{t-1}$, Action $a_t$, Reward $r_t$]

  44. [Figure: Training the DRL agent. Cumulative discounted reward versus episode]

  45. [Workflow diagram labels: Check-in, Probing, Selection, Update, Update weight, DRL agent]

  46. Evaluating Our Solution. Benchmarks: MNIST, FashionMNIST, CIFAR-10. Non-IID levels: 1, half-and-half, 80%, 50%. [Illustrations: half-and-half, 80%]

  47. [Figure: communication rounds for FedAvg, K-Center, and Favor on MNIST, FashionMNIST, and CIFAR-10; non-IID level: 1]

  48. [Figure: communication rounds for FedAvg, K-Center, and Favor on MNIST, FashionMNIST, and CIFAR-10; non-IID level: half-and-half]

  49. [Figure: communication rounds for FedAvg, K-Center, and Favor on MNIST, FashionMNIST, and CIFAR-10; non-IID level: 80%]

  50. [Figure: communication rounds for FedAvg, K-Center, and Favor on MNIST, FashionMNIST, and CIFAR-10; non-IID level: 50%]

  51. [Figure: local weights and global weights plotted on components C1 and C2, starting from w_init. FedAvg panel: w_1 through w_5. Favor panel: w_1 through w_4]

  52. Summary: indirect data distribution probing; DRL-based device selection. Communication rounds can be reduced by up to 49% on MNIST, 23% on FashionMNIST, and 42% on CIFAR-10.
