Learning To Stop While Learning To Predict
Xinshi Chen1, Hanjun Dai2, Yu Li3, Xin Gao3, Le Song1,4
1Georgia Tech, 2Google Brain, 3KAUST, 4Ant Financial
ICML 2020
[Figure: the model asks "stop?" after every update step; on "no" it continues to the next step, and on "yes" it stops and outputs the current state x_t. Different input samples stop at different depths (one sample at depth 4, another at depth 2).]
Dynamic depth: stop at different depths for different input samples.
5-minute Core Message
Motivating example from meta learning: a task with fewer samples and a task with more samples need different numbers of gradient steps for adaptation.
[Figure: a traditional iterative algorithm applies a hand-designed update step, checks a stop criterion after each iteration, and outputs x_t once the criterion is satisfied.]
Traditional algorithms have stop criteria that determine the number of iterations for each problem instance. Deep-learning-based algorithms, in contrast, usually fix the number of iterations in the architecture.
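As a concrete illustration (my own sketch, not from the slides), here is a minimal classical iterative algorithm whose hand-designed stop criterion yields a different iteration count for each problem instance; the function names and tolerances are hypothetical:

```python
import numpy as np

def gradient_descent(grad, x0, lr=0.1, tol=1e-6, max_iter=1000):
    """Classical iterative algorithm: the stop criterion (small update
    norm) decides a different iteration count per problem instance."""
    x = x0
    for t in range(max_iter):
        step = lr * grad(x)
        x = x - step
        if np.linalg.norm(step) < tol:  # hand-designed stop criterion
            break
    return x, t + 1  # solution and number of iterations actually used

# Two quadratic problems with different curvature stop at different depths.
x_a, t_a = gradient_descent(lambda x: 2.0 * x, np.array([1.0]))
x_b, t_b = gradient_descent(lambda x: 0.2 * x, np.array([1.0]))
```

The flatter problem needs many more iterations before the criterion fires, which is exactly the per-instance adaptivity a fixed-depth unrolled network lacks.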
Further examples: image denoising (a noisy input becomes less noisy step by step) and image recognition [Teerapittayanon et al., 2016; Zamir et al., 2017; Huang et al., 2018; Kaya et al., 2019].
[Figure: the proposed architecture pairs a predictive model F_θ (layers f_θ^1, …, f_θ^T producing states x_1, x_2, …) with a stopping policy π_φ that computes a stop probability π_t at each step; when the policy decides to stop, the model outputs the current state x_t.]
The predictive model and the stop policy are learned jointly. The stopping policy induces a variational stop time distribution q_φ:
q_φ(t) = π_t ∏_{τ<t} (1 − π_τ).
Training proceeds in two stages on the objective ℒ(F_θ, q_φ): (i) the oracle model learning stage, which learns F_θ together with the optimal stop time distribution q* | F_θ := argmin_q ℒ(F_θ, q); (ii) the imitation learning stage, which trains π_φ so that q_φ mimics the oracle q* | F_θ.
[Figure: stage (i) learns F_θ and the oracle q* | F_θ by optimizing ℒ(F_θ, q* | F_θ); stage (ii) distills q* | F_θ into q_φ by minimizing a KL divergence.]
✓ Principled
✓ Tuning-free
✓ Efficient
✓ Generic
✓ Better understanding
Experiments:
• Learning to optimize: sparse recovery
• Task-imbalanced meta learning: few-shot learning
• Image denoising
• Some observations on image recognition tasks
Predictive model F_θ:  x_t = f_θ^t(x_{t−1}),  for t = 1, 2, …, T
Stopping policy π_φ:  π_t = π_φ(x_t)
Variational stop time distribution q_φ (induced by π_φ):
q_φ(t) = π_t ∏_{τ<t} (1 − π_τ)  for t < T,  and  q_φ(T) = ∏_{τ<T} (1 − π_τ),
where ∏_{τ<t} (1 − π_τ) = Pr[not stopped before t].
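A minimal numpy sketch (mine, not the authors' code) of how per-step stop probabilities π_t induce the stop time distribution q_φ(t), with the last step forced to absorb the remaining probability mass:

```python
import numpy as np

def stop_time_distribution(pi):
    """Map per-step stop probabilities pi_1..pi_T (outputs of the stopping
    policy) to the induced stop time distribution q(t):
    q(t) = pi_t * prod_{tau<t} (1 - pi_tau) for t < T, and the final step
    absorbs the remaining mass so that q sums to one."""
    pi = np.asarray(pi, dtype=float)
    survive = np.cumprod(1.0 - pi[:-1])   # Pr[not stopped before step t]
    q = np.empty_like(pi)
    q[0] = pi[0]
    q[1:] = pi[1:] * survive
    q[-1] = survive[-1]                   # forced stop at t = T
    return q

# Hypothetical policy outputs for T = 4 steps.
q = stop_time_distribution([0.1, 0.5, 0.3, 0.2])
```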
ℒ(F_θ, q_φ; y, x) = E_{t∼q_φ} [ℓ(y, x_t; θ)] − β H(q_φ)
(the prediction loss in expectation over the stop time t, plus a β-weighted entropy term on q_φ).
min_{θ,φ} ℒ(F_θ, q_φ; y, x)  is equivalent to  max_{θ,φ} ℒ^{oracle}(F_θ, q_φ; y, x)
(i.e., a β-VAE-style ELBO).
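The objective can be sketched numerically as follows; the per-step losses and the value of β below are hypothetical placeholders:

```python
import numpy as np

def stop_objective(q, step_losses, beta):
    """L(F_theta, q_phi; y, x) = E_{t~q}[loss at step t] - beta * H(q):
    expected prediction loss under the stop time distribution q, minus a
    beta-weighted entropy bonus that keeps q from collapsing."""
    q = np.asarray(q, dtype=float)
    step_losses = np.asarray(step_losses, dtype=float)
    expected_loss = float(q @ step_losses)
    entropy = float(-np.sum(q * np.log(q + 1e-12)))
    return expected_loss - beta * entropy

# Hypothetical per-step losses for T = 4 unrolled steps.
val = stop_objective([0.1, 0.45, 0.135, 0.315], [0.9, 0.4, 0.3, 0.35], beta=0.1)
```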
Oracle stop time distribution:
q*_θ(· | y, x) ∈ argmax_{q ∈ Δ} ℒ^{oracle}(F_θ, q; y, x),
which has the closed form
q*_θ(t | y, x) = p_θ(y | x, t)^{1/β} / Σ_{τ=1}^{T} p_θ(y | x, τ)^{1/β}.
Interpretation: q*_θ(t | y, x) puts more mass on steps whose predictions better explain the true label, so it requires knowledge of the true label y.
Stage I. Oracle model learning:
max_θ (1/|D|) Σ_{(x,y)∈D} ℒ^{oracle}(F_θ, q*_θ; y, x)
  = max_θ (1/|D|) Σ_{(x,y)∈D} Σ_{t=1}^{T} q*_θ(t | y, x) log p_θ(y | x, t),
a q*-weighted log-likelihood of the label under each step's prediction.
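Since q*_θ(t | y, x) ∝ p_θ(y | x, t)^{1/β}, the oracle is simply a softmax over the per-step log-likelihoods with temperature β. A small sketch with an assumed interface (not the authors' code):

```python
import numpy as np

def oracle_stop_distribution(log_lik, beta):
    """q*_theta(t | y, x) proportional to p_theta(y | x, t)^(1/beta):
    a softmax over the per-step log-likelihoods of the true label,
    with temperature beta."""
    z = np.asarray(log_lik, dtype=float) / beta
    z -= z.max()               # subtract the max for numerical stability
    w = np.exp(z)
    return w / w.sum()

# Hypothetical per-step log-likelihoods: step 3 explains the label best,
# so the oracle concentrates its mass there.
q_star = oracle_stop_distribution([-2.0, -1.0, -0.2, -0.5], beta=0.5)
```

Shrinking β sharpens the oracle toward the single best step; growing β spreads its mass, mirroring the entropy term in the objective.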
Stage II. Imitation with sequential policy.
Recall: the variational stop time distribution q_φ(t | x) is induced by the sequential policy π_φ.
Hope: q_φ(t | x) mimics the oracle distribution q*_θ(t | y, x), which is achieved by optimizing the forward KL divergence:
KL( q*_θ(· | y, x) ‖ q_φ(· | x) ) = − Σ_{t=1}^{T} q*_θ(t | y, x) log q_φ(t | x) − H(q*_θ).
Note: if we use the reverse KL divergence instead, the problem is equivalent to solving maximum-entropy RL.
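The forward KL between the oracle and the policy-induced distribution can be computed directly; a small numpy sketch under the same notation (note that only the cross-entropy term depends on the policy parameters φ):

```python
import numpy as np

def forward_kl(q_star, q_phi, eps=1e-12):
    """KL(q* || q_phi) = -sum_t q*(t) log q_phi(t) - H(q*).
    H(q*) is constant in phi, so minimizing this KL over phi reduces to
    minimizing the cross-entropy term."""
    q_star = np.asarray(q_star, dtype=float)
    q_phi = np.asarray(q_phi, dtype=float)
    cross_entropy = -np.sum(q_star * np.log(q_phi + eps))
    entropy = -np.sum(q_star * np.log(q_star + eps))
    return float(cross_entropy - entropy)

# Hypothetical oracle vs. a uniform policy-induced distribution.
kl = forward_kl([0.1, 0.2, 0.6, 0.1], [0.25, 0.25, 0.25, 0.25])
```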
• LASSO formulation: min_x ½ ‖b − Ax‖² + ρ ‖x‖₁
• Solved by iterative algorithms such as ISTA.
• Learned ISTA (LISTA) is a deep architecture designed based on the ISTA update steps.
• LISTA-stop, the proposed method built on LISTA, is better than LISTA.
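For reference, a minimal ISTA implementation for the LASSO objective above (a standard algorithm; the variable names are mine). LISTA unrolls exactly these updates into a fixed-depth network with learned weights:

```python
import numpy as np

def soft_threshold(v, lam):
    # Proximal operator of lam * ||.||_1 (elementwise shrinkage).
    return np.sign(v) * np.maximum(np.abs(v) - lam, 0.0)

def ista(A, b, rho, n_iter=500):
    """ISTA for the LASSO  min_x 0.5 * ||b - A x||^2 + rho * ||x||_1:
    a gradient step on the quadratic term followed by soft-thresholding.
    LISTA turns these updates into network layers with learned weights;
    LISTA-stop additionally learns when to stop the unrolling."""
    L = np.linalg.norm(A, 2) ** 2      # Lipschitz constant of the gradient
    x = np.zeros(A.shape[1])
    for _ in range(n_iter):
        x = soft_threshold(x - A.T @ (A @ x - b) / L, rho / L)
    return x
```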
• Built on top of MAML, but MAML-stop learns how many adaptation gradient-descent steps are needed for each task.
[Table: results under the task-imbalanced setting and the vanilla setting.]
• Built on top of one of the most popular denoising models, DnCNN. (Noise levels 65 and 75 are not observed during training.)