Training Binary Neural Networks Using the Bayesian Learning Rule
1
Thirty-seventh International Conference on Machine Learning (ICML 2020)
Xiangming Meng (RIKEN AIP, Presenter), Roman Bachmann (EPFL), Mohammad Emtiyaz Khan (RIKEN AIP)
Binary neural networks: the weights are restricted to {−1, +1} [1, 2]. Previous methods train them with a heuristic called the "straight-through estimator (STE)"! [3]
2
1. Courbariaux et al. Training deep neural networks with binary weights during propagations. NeurIPS, 2015.
2. Courbariaux et al. Binarized neural networks. arXiv:1602.02830, 2016.
3. Yin, P. et al. Understanding straight-through estimator in training activation quantized neural nets. arXiv, 2019.
By using the Bayesian Learning Rule [1, 2] (natural-gradient variational inference), we can justify such previous approaches.
The idea: optimize a distribution over the binary weights (a continuous optimization problem).
The learned distribution is a posterior approximation, which can be used for continual learning [3].
3
1. Khan, M. E. and Rue, H. Learning-algorithms from Bayesian principles. arXiv, 2019.
2. Khan, M. E. and Lin, W. Conjugate-computation variational inference. AISTATS, 2017.
3. Kirkpatrick, J. et al. Overcoming catastrophic forgetting in neural networks. PNAS, 114(13):3521–3526, 2017.
4–7
[Diagram: Input → Neural Network with binary weights → Output → Loss]
Previous approaches keep real-valued "latent" weights, binarize them in the forward pass, and backpropagate through the binarization with the "straight-through estimator (STE)" [1].
It has also been argued that these latent weights are not weights but "inertia", motivating the Binary Optimizer (Bop).
Our approach: optimize a distribution over the binary weights instead of the weights themselves (a continuous optimization problem):

$\min_{q(w)} \; \mathbb{E}_{q(w)}\Big[\sum_{i} \ell\big(y_i, f_w(x_i)\big)\Big] + \mathrm{KL}\big(q(w)\,\|\,p(w)\big)$

Here $q(w)$ is the posterior approximation, $p(w)$ is the prior distribution, the first term is the expected loss, and the KL divergence keeps $q(w)$ close to the prior.
8–9
We use a mean-field Bernoulli distribution over the binary weights $w_i \in \{-1, +1\}$:

$q(w) = \prod_{i=1}^{D} p_i^{\frac{1+w_i}{2}} (1 - p_i)^{\frac{1-w_i}{2}}$

where $p_i$ is the probability of $w_i = +1$. Equivalently, in exponential-family form,

$q(w) = \prod_{i=1}^{D} \exp\big[\lambda_i \phi(w_i) - A(\lambda_i)\big], \qquad \lambda_i := \frac{1}{2}\log\frac{p_i}{1-p_i}$

with natural parameters $\lambda_i$.
10
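For concreteness, a minimal NumPy sketch (variable names are illustrative, not from the authors' code) relating the natural parameters $\lambda_i$, the probabilities $p_i$, the expectation parameter $\mu_i = \tanh(\lambda_i)$, and samples of the binary weights:

```python
import numpy as np

rng = np.random.default_rng(0)

# Natural parameters lambda_i of a mean-field Bernoulli over w_i in {-1, +1}.
lam = rng.normal(size=5)

# Probability of w_i = +1, recovered from lambda_i = 0.5 * log(p_i / (1 - p_i)).
p = 1.0 / (1.0 + np.exp(-2.0 * lam))        # sigmoid(2 * lambda)

# Expectation parameter mu_i = E[w_i] = p_i * (+1) + (1 - p_i) * (-1) = tanh(lambda_i).
mu = 2.0 * p - 1.0
assert np.allclose(mu, np.tanh(lam))

# Draw binary weight samples w in {-1, +1} from q(w).
w = np.where(rng.uniform(size=lam.shape) < p, 1.0, -1.0)
print(p, mu, w)
```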
The Bayesian Learning Rule gives a simple update of the natural parameter $\lambda$, with learning rate $\alpha$:

$\lambda \leftarrow (1-\alpha)\,\lambda - \alpha\, \nabla_{\mu}\, \mathbb{E}_{q(w)}\Big[\sum_{i} \ell\big(y_i, f_w(x_i)\big)\Big]$

where $\mu = \mathbb{E}_{q(w)}[w]$ is the expectation parameter. How to compute the gradient with respect to $\mu$? BayesBiNN approximates this natural gradient by using the mini-batch gradient: $\nabla_{\mu}\, \mathbb{E}_{q(w)}[\,\cdot\,] \approx s \odot g$, where $g$ is a minibatch gradient (easy to compute!) and $s$ is a scale vector.
11–13
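A minimal sketch of one such update, assuming the Gumbel-softmax relaxation $w_r = \tanh((\lambda + \delta)/\tau)$ with logistic noise $\delta$; the scale vector $s$ written below follows from the chain rule under this assumed relaxation, and the exact expression used by the authors should be taken from the paper and code (function names and the toy loss are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)

def bayesbinn_step(lam, grad_fn, alpha=0.05, tau=1.0):
    """Sketch of one BayesBiNN-style update of the natural parameters lam.

    grad_fn(w_r) returns the minibatch gradient of the loss w.r.t. the relaxed
    weights w_r. The scale vector s is an assumed form based on the chain rule
    for w_r = tanh((lam + delta) / tau); see the paper/code for the exact form.
    """
    eps = rng.uniform(1e-10, 1.0 - 1e-10, size=lam.shape)
    delta = 0.5 * np.log(eps / (1.0 - eps))                     # logistic noise
    w_r = np.tanh((lam + delta) / tau)                          # relaxed binary weights
    g = grad_fn(w_r)                                            # minibatch gradient
    s = (1.0 - w_r ** 2) / (tau * (1.0 - np.tanh(lam) ** 2))    # scale vector (assumed form)
    return (1.0 - alpha) * lam - alpha * s * g                  # Bayesian learning rule step

# Toy usage: a quadratic loss 0.5 * (w - 1)^2 whose gradient pulls the weights towards +1.
lam = np.zeros(4)
for _ in range(100):
    lam = np.clip(bayesbinn_step(lam, grad_fn=lambda w: w - 1.0), -10.0, 10.0)  # clip for toy stability
print(np.sign(lam))   # -> [1. 1. 1. 1.]
```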
Note that in the limit $\tau \to 0$, the real-valued "latent" weights $w_r$ of previous methods correspond to the natural parameter $\lambda$ in BayesBiNN, which justifies such STE-style updates.
14–16
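A quick numerical illustration (values are arbitrary): without noise, the relaxation $\tanh(\lambda/\tau)$ hardens to $\mathrm{sign}(\lambda)$ as $\tau \to 0$, matching the deterministic binarization that latent-weight/STE methods apply:

```python
import numpy as np

lam = np.array([-1.3, 0.4, 2.0, -0.2])
for tau in [1.0, 0.1, 0.01, 0.001]:
    print(tau, np.tanh(lam / tau))
print("sign:", np.sign(lam))
# The relaxed weights approach sign(lam) as tau -> 0.
```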
Classification on the two moons dataset. Predictions are averaged over binary networks sampled from the learned distribution:

$\hat{p}_k \leftarrow \frac{1}{C}\sum_{c=1}^{C} p\big(y = k \mid x, w^{(c)}\big), \qquad w^{(c)} \sim q(w), \; C = 10$

[Figure: classification boundaries on the two moons dataset]
17
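A minimal sketch of this Monte Carlo prediction in NumPy, assuming a user-supplied forward(x, w) that returns class probabilities under binary weights w; the two-class toy network below is purely illustrative:

```python
import numpy as np

rng = np.random.default_rng(0)

def predict(x, lam, forward, num_samples=10):
    """Average class probabilities over binary weight samples w ~ q(w)."""
    p_plus = 1.0 / (1.0 + np.exp(-2.0 * lam))              # P(w_i = +1) from natural params
    probs = []
    for _ in range(num_samples):
        w = np.where(rng.uniform(size=lam.shape) < p_plus, 1.0, -1.0)
        probs.append(forward(x, w))
    return np.mean(probs, axis=0)                          # averaged \hat{p}_k

# Illustrative two-class "network": softmax over [x @ w, -(x @ w)].
def toy_forward(x, w):
    z = np.array([x @ w, -(x @ w)])
    e = np.exp(z - z.max())
    return e / e.sum()

print(predict(np.ones(4), lam=np.array([2.0, -1.0, 0.5, 0.0]), forward=toy_forward))
```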
Overcoming catastrophic forgetting
18
Common method: regularizing the weights [1]. Our method instead uses the KL divergence that is intrinsic to the variational formulation as the regularization.

Independent learning: each task $t$ with data $\mathcal{D}_t$ is trained against the (uniform) prior $p(w)$:

$\min_{q_t(w)} \; \mathbb{E}_{q_t(w)}\Big[\sum_{i \in \mathcal{D}_t} \ell\big(y_i^t, f_w(x_i^t)\big)\Big] + \mathrm{KL}\big(q_t(w)\,\|\,p(w)\big)$

Continual learning: replace the prior with $q_{t-1}(w)$, the posterior approximation obtained after task $t-1$:

$\min_{q_t(w)} \; \mathbb{E}_{q_t(w)}\Big[\sum_{i \in \mathcal{D}_t} \ell\big(y_i^t, f_w(x_i^t)\big)\Big] + \mathrm{KL}\big(q_t(w)\,\|\,q_{t-1}(w)\big)$

In BayesBiNN this only changes the natural-parameter update, where $\lambda_{t-1}$ is the learned natural parameter after task $t-1$:

$\lambda_t \leftarrow (1-\alpha)\,\lambda_t - \alpha\,\big[s \odot g - \lambda_{t-1}\big]$

1. Kirkpatrick, J. et al. Overcoming catastrophic forgetting in neural networks. PNAS, 114(13):3521–3526, 2017.
19–21
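A minimal sketch showing that only the prior's natural parameter changes between the two settings (names are illustrative; s and g denote the scale vector and minibatch gradient from the BayesBiNN step above):

```python
import numpy as np

def natural_param_step(lam, s, g, lam_prior, alpha=0.05):
    """Natural-parameter update with a general prior natural parameter.

    lam_prior = 0 corresponds to the uniform prior (independent learning);
    lam_prior = lambda_{t-1}, the value learned after task t-1, gives the
    continual-learning update.
    """
    return (1.0 - alpha) * lam - alpha * (s * g - lam_prior)

# Independent learning on task t:
#   lam = natural_param_step(lam, s, g, lam_prior=np.zeros_like(lam))
# Continual learning on task t (prior = posterior from task t-1):
#   lam = natural_param_step(lam, s, g, lam_prior=lam_prev_task)
```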
22–23
Permuted MNIST.
[Figure: test accuracy on task 1 while training sequentially on task 1, task 2, and task 3; curves annotated "Our method" and "Catastrophic forgetting of task 1". For other tasks, refer to the paper.]
The distribution over the binary weights becomes more and more deterministic.
github.com/team-approx-bayes/BayesBiNN
24
Summary: we train binary neural networks using the Bayesian Learning Rule, which can justify some previous approaches, and we obtain a posterior approximation which can be used for continual learning.
25