Learning To Stop While Learning To Predict - PowerPoint PPT Presentation
SLIDE 1

Learning To Stop While Learning To Predict

Xinshi Chen1, Hanjun Dai2, Yu Li3, Xin Gao3, Le Song1,4

1Georgia Tech, 2Google Brain, 3KAUST, 4Ant Financial

ICML 2020

SLIDE 2

Dynamic Depth

π’šπŸ π’šπŸ‘

stop? stop, output π’šπŸ“

π’šπŸ’

no

stop?

no

stop?

no

π’šπŸ“

stop?

yes

π’šπŸ π’šπŸ‘

stop? stop, output π’šπŸ‘

no

stop?

yes stop at different depths for different input samples.

stopped depth=4 stopped depth=2

5-minute Core Message
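Conceptually, dynamic-depth inference is a loop that runs one layer at a time and halts the first time the stopping policy fires. A minimal sketch, assuming a generic list of layer functions and a stop policy that returns a probability (all names here are illustrative, not the paper's code):

```python
def dynamic_depth_forward(y0, layer_fns, stop_policy, threshold=0.5):
    """Run layers sequentially; halt the first time the stop policy fires.

    layer_fns   : list of U functions, y_u = layer_fns[u](y_{u-1})
    stop_policy : maps (input, current state) -> stop probability in [0, 1]
    Returns (output_state, stopped_depth).
    """
    y = y0
    for u, g in enumerate(layer_fns, start=1):
        y = g(y)
        # stop when the policy fires; the last layer always stops
        if stop_policy(y0, y) >= threshold or u == len(layer_fns):
            return y, u

# toy example: each layer adds 0.3; stop once the state exceeds 1.0,
# so this particular input stops at depth 4
layers = [lambda y: y + 0.3] * 5
policy = lambda y0, y: 1.0 if y > 1.0 else 0.0
out, depth = dynamic_depth_forward(0.0, layers, policy)
```

Different inputs (here, different starting values) would exit the loop at different depths, which is exactly the behavior the figure illustrates.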

SLIDE 3

Motivation

  • 1. Task-imbalanced Meta Learning

[Figure: a meta-parameter θ is adapted by gradient steps ∇θℒ1 and ∇θℒ2 into task-specific parameters θ_task1 and θ_task2.]

Task 1: fewer samples; Task 2: more samples. Different tasks need different numbers of gradient steps for adaptation.

SLIDE 4

Motivation

  • 2. Data-driven Algorithm Design

π’šπ’–

(output)

π’š π’šπ’–

not satisfied

stop criteria

hand-designed update step Traditional algorithms have certain stop criteria to determine the number of iterations for each problem. E.g.,

  • iterate until convergence
  • early stopping to avoid over-fitting

Deep learning based algorithms usually have a fixed number of iterations in the architecture.

5-minute Core Message

SLIDE 5

Motivation

  • 3. Others

Image Denoising

  • Images with different noise levels may need different numbers of denoising steps.

Image Recognition

  • 'Early exits' have been proposed to improve computational efficiency and avoid 'over-thinking'.

[Teerapittayanon et al., 2016; Zamir et al., 2017; Huang et al., 2018; Kaya et al., 2019]

SLIDE 6

Predictive Model with Stopping Policy

Predictive model F_θ

  • Transforms the input y to generate a path of states y_1, …, y_U

Stopping policy ρ_φ

  • Sequentially observes the states y_u and determines the probability of stopping at layer u

[Figure: the predictive model maps the input y through layers with parameters θ_2, …, θ_5 to states y_2, …, y_5; after each state the stop policy ρ_φ decides whether to stop; here it stops and outputs y_5.]

Variational stop time distribution r_φ

  • Stop time distribution induced by the stopping policy ρ_φ:

    r_φ(u) = ρ_φ(y_u) ∏_{τ=1}^{u−1} (1 − ρ_φ(y_τ))
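Concretely, the induced distribution can be computed from the per-layer stop probabilities. An illustrative numpy sketch (our naming, not the authors' code), with the last layer absorbing the leftover mass so that r_φ sums to one:

```python
import numpy as np

def stop_time_distribution(rho):
    """r(u) = rho_u * prod_{tau<u} (1 - rho_tau); the last layer
    takes the remaining mass so the result sums to 1."""
    rho = np.asarray(rho, dtype=float)
    survive = np.cumprod(1.0 - rho)               # Pr[not stopped by layer u]
    r = np.empty_like(rho)
    r[0] = rho[0]
    r[1:] = rho[1:] * survive[:-1]                # stop exactly at layer u
    r[-1] = survive[-2] if len(rho) > 1 else 1.0  # force stop at the last layer
    return r

rho = [0.1, 0.3, 0.5, 0.2, 0.9]                  # per-layer stop probabilities
r = stop_time_distribution(rho)
assert np.isclose(r.sum(), 1.0)
```

The cumulative product is exactly the "Pr[not stopped before u]" term annotated on the slide.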

SLIDE 7

How to learn the optimal (F_θ, ρ_φ) efficiently?

  • Design a joint training objective:

    ℒ(F_θ, r_φ)

  • Introduce an oracle stop time distribution:

    r*|F_θ := argmin_{r ∈ 𝒫} ℒ(F_θ, r)

  • Then we decompose the learning procedure into two stages:

    (i) the oracle model learning stage;
    (ii) the imitation learning stage.

[Diagram: stage (i) learns the optimal F_θ* by optimizing ℒ(F_θ, r*|F_θ); stage (ii) learns the optimal r_φ* by minimizing the KL divergence between the oracle r*|F_θ* and r_φ.]
SLIDE 8

Advantages of our training procedure

✓ Principled

  • The two components are optimized towards a joint objective.

✓ Tuning-free

  • Weights of the different layers in the loss are given automatically by the oracle distribution.
  • For different input samples, the weights on the layers can be different.

✓ Efficient

  • Instead of updating θ and φ alternately, θ is optimized in the 1st stage, and then φ is optimized in the 2nd stage.

✓ Generic

  • Can be applied to a diverse range of applications.

✓ Better understanding

  • A variational Bayes perspective, for better understanding the proposed model and joint training.
  • A reinforcement learning perspective, for better understanding the learning of the stop policy.
SLIDE 9

Experiments

  • Learning to optimize: sparse recovery
  • Task-imbalanced meta learning: few-shot learning
  • Image denoising
  • Some observations on image recognition tasks

SLIDE 10

Problem Formulation - Models

Predictive model F_θ

  • y_u = g_θu(y_{u−1}), for u = 1, 2, …, U

Stopping policy ρ_φ

  • ρ_u = ρ_φ(y, y_u), for u = 1, 2, …, U

Variational stop time distribution r_φ (induced by ρ_φ)

  • r_φ(u) = ρ_u ∏_{τ=1}^{u−1} (1 − ρ_τ) for u < U, where the product is Pr[not stopped before u]; the last layer U takes the remaining probability mass.
  • Helps design the training objective and the algorithm.
SLIDE 11

Problem Formulation – Optimization Objective

ℒ(F_θ, r_φ; y, z) = 𝔼_{u∼r_φ}[ m(z, y_u; θ) ] − γ H(r_φ)

(the loss in expectation over the stop time u, minus an entropy term)

  • Variational Bayes Perspective

    min_{θ,φ} ℒ(F_θ, r_φ; y, z)  is equivalent to  max_{θ,φ} ELBO_γ(F_θ, r_φ; y, z)

    (i.e., a γ-VAE / ELBO interpretation)
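Numerically, the objective is just an r-weighted average of per-layer losses minus γ times the entropy of r. A small sketch under our own naming (the loss values and ε-smoothing are illustrative):

```python
import numpy as np

def joint_objective(per_layer_losses, r, gamma=1.0):
    """L(F, r) = E_{u ~ r}[m(z, y_u)] - gamma * H(r)."""
    r = np.asarray(r, dtype=float)
    m = np.asarray(per_layer_losses, dtype=float)
    entropy = -np.sum(r * np.log(r + 1e-12))   # H(r), with a small epsilon
    return float(np.dot(r, m) - gamma * entropy)

# with gamma = 0 and all mass on one layer, L is just that layer's loss
loss_point = joint_objective([2.0, 1.0, 0.5], [0.0, 0.0, 1.0], gamma=0.0)
```

With γ > 0, the entropy bonus keeps r from collapsing onto a single layer too early in training.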

SLIDE 12

Training Algorithm – Stage I

Oracle stop time distribution:

    r_θ*(·|z, y) := argmax_{r ∈ 𝒫} ELBO_γ(F_θ, r; y, z) = q_θ(z|u, y)^{1/γ} / Σ_{u'=1}^{U} q_θ(z|u', y)^{1/γ}

Interpretation:

  • It is the optimal stop time distribution given a predictive model F_θ.
  • When γ = 1, the oracle is the true posterior: r_θ*(u|z, y) = q_θ(u|z, y).
  • This posterior is computationally tractable, but it requires knowledge of the true label z.

Stage I. Oracle model learning:

    max_θ (1/|𝒟|) Σ_{(y,z)∈𝒟} ELBO_γ(F_θ, r_θ*; y, z) = max_θ (1/|𝒟|) Σ_{(y,z)∈𝒟} Σ_{u=1}^{U} r_θ*(u|z, y) log q_θ(z|u, y)

    where q_θ(z|u, y) is the likelihood of the output at the u-th layer.
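For a fixed F_θ, the oracle above is a tempered softmax over the per-layer log-likelihoods log q_θ(z|u, y). A sketch with hypothetical inputs, computed in log space for stability:

```python
import numpy as np

def oracle_stop_distribution(log_q, gamma=1.0):
    """r*(u|z,y) proportional to q(z|u,y)^{1/gamma}, i.e. softmax(log_q / gamma)."""
    s = np.asarray(log_q, dtype=float) / gamma
    s -= s.max()                       # stabilize the exponentials
    w = np.exp(s)
    return w / w.sum()

log_q = np.log([0.1, 0.2, 0.4, 0.3])  # per-layer likelihoods of the true label
r_star = oracle_stop_distribution(log_q, gamma=1.0)
```

With γ = 1 this recovers the normalized likelihoods (the true posterior over u); as γ shrinks, the oracle concentrates on the best-performing layer.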
SLIDE 13

Training Algorithm – Stage II

Stage II. Imitation with a sequential policy

Recall: the variational stop time distribution r_φ(u|y) is induced by the sequential policy ρ_φ.

Hope: r_φ(u|y) mimics the oracle distribution r*(u|z, y) (computed with the stage-I optimal F_θ), by minimizing the forward KL divergence:

    KL(r* ‖ r_φ) = − Σ_{u=1}^{U} r*(u|z, y) log r_φ(u|y) − H(r*)

Note: if we use the reverse KL divergence instead, it is equivalent to solving a maximum-entropy RL problem.
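Since H(r*) does not depend on φ, minimizing this forward KL amounts to a cross-entropy between the oracle and the induced distribution. An illustrative sketch (names are ours):

```python
import numpy as np

def forward_kl(r_star, r_phi):
    """KL(r* || r_phi) = -sum_u r*(u) log r_phi(u) - H(r*)."""
    p = np.asarray(r_star, dtype=float)
    q = np.asarray(r_phi, dtype=float)
    cross_entropy = -np.sum(p * np.log(q + 1e-12))
    entropy = -np.sum(p * np.log(p + 1e-12))   # H(r*), constant in phi
    return float(cross_entropy - entropy)

p = np.array([0.2, 0.5, 0.3])
assert abs(forward_kl(p, p)) < 1e-6            # KL is zero when they match
```

In practice the gradient w.r.t. φ flows only through the cross-entropy term.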

SLIDE 14

Experiment I - Learning To Optimize: Sparse Recovery

  • Task: Recover y* from its noisy measurements c = B y* + ε
  • Traditional approach:
    – LASSO formulation: min_y ½ ‖c − B y‖² + λ ‖y‖₁
    – Solved by iterative algorithms such as ISTA
  • Learning-based algorithm:
    – Learned ISTA (LISTA) is a deep architecture designed based on ISTA update steps
  • Ablation study: whether LISTA with adaptive depth (LISTA-stop) is better than LISTA.
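For context, one ISTA iteration is a gradient step on the quadratic term followed by soft-thresholding (the proximal step for the ℓ1 penalty). A minimal sketch with a fixed step size, not LISTA's learned weights:

```python
import numpy as np

def soft_threshold(v, t):
    """Proximal operator of t * ||.||_1."""
    return np.sign(v) * np.maximum(np.abs(v) - t, 0.0)

def ista(B, c, lam, num_iters=100):
    """Minimize 0.5 * ||c - B y||^2 + lam * ||y||_1 by iterative
    shrinkage-thresholding: gradient step, then soft-threshold."""
    step = 1.0 / np.linalg.norm(B, 2) ** 2   # 1 / Lipschitz constant of grad
    y = np.zeros(B.shape[1])
    for _ in range(num_iters):
        grad = B.T @ (B @ y - c)
        y = soft_threshold(y - step * grad, lam * step)
    return y

# toy problem: with B = I the LASSO solution is soft_threshold(c, lam)
y_hat = ista(np.eye(3), np.array([2.0, 0.05, -1.0]), lam=0.1)
```

LISTA unrolls a fixed number of such updates into network layers with learned matrices; LISTA-stop additionally learns when to stop unrolling per input.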

SLIDE 15

Experiment II – Task-imbalanced Meta Learning

  • Task: Task-imbalanced few-shot learning. Each task contains k shots for each class, where k can vary.
  • Our variant, MAML-stop:

    – Built on top of MAML, but MAML-stop learns how many adaptation gradient-descent steps are needed for each task.

[Figure: results under the task-imbalanced setting and the vanilla setting.]

SLIDE 16

Experiment III – Image Denoising

  • Our variant, DnCNN-stop:

    – Built on top of DnCNN, one of the most popular models for the denoising task.

*Noise levels 65 and 75 are not observed during training.