Efficient Meta Learning via Minibatch Proximal Update - Pan Zhou (PowerPoint PPT Presentation)



SLIDE 1

Efficient Meta Learning via Minibatch Proximal Update

Pan Zhou

Joint work with Xiao-Tong Yuan, Huan Xu, Shuicheng Yan, Jiashi Feng National University of Singapore pzhou@u.nus.edu Dec 11, 2019


SLIDE 2

Meta Learning via Minibatch Proximal Update (Meta-MinibatchProx)

Meta-MinibatchProx learns a good prior model initialization $w^*$ from observed tasks such that $w^*$ is close to the optimal models of new, similar tasks, which promotes learning on new tasks.

SLIDE 3

Meta Learning via Minibatch Proximal Update (Meta-MinibatchProx)

  • Training model: given a task distribution $P(\mathcal{T})$, we minimize the bi-level meta learning model

$$\min_{w} \; F(w) := \mathbb{E}_{\mathcal{T}_i \sim P(\mathcal{T})} \Big[ \min_{w_i} \; \mathcal{L}_{S_i}(w_i) + \frac{\lambda}{2}\|w_i - w\|^2 \Big],$$

where each task $\mathcal{T}_i$ has $K$ training samples $S_i = \{(x_j, y_j)\}_{j=1}^{K}$, and $\mathcal{L}_{S_i}(w_i) = \frac{1}{K}\sum_{j=1}^{K} \ell(f(w_i; x_j), y_j)$ is the empirical loss with predictor $f$ and loss $\ell$.

SLIDE 4

Meta Learning via Minibatch Proximal Update (Meta-MinibatchProx)

The inner minimization updates the task-specific solution: for each observed task $\mathcal{T}_i$, $w_i$ solves $\min_{w_i} \mathcal{L}_{S_i}(w_i) + \frac{\lambda}{2}\|w_i - w\|^2$.

SLIDE 5

Meta Learning via Minibatch Proximal Update (Meta-MinibatchProx)

The outer minimization updates the prior model: $w$ is trained to minimize the expectation of the proximally regularized task losses over the task distribution.

SLIDE 6

Meta Learning via Minibatch Proximal Update (Meta-MinibatchProx)

Intuition: minimizing this objective forces the learnt prior $w$ to have a small average distance, in expectation, to the optimum models of all tasks.
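To make the bi-level structure concrete, below is a minimal PyTorch sketch of the inner proximal step for one task, assuming a linear predictor and squared loss; the name proximal_adapt and the hyperparameters (lam, lr, steps) are illustrative choices, not taken from the paper.

```python
import torch

def proximal_adapt(w_prior, X, y, lam=1.0, lr=0.05, steps=50):
    # Approximately solve the inner problem of the bi-level model:
    #   min_{w_i}  L_{S_i}(w_i) + (lam / 2) * ||w_i - w_prior||^2
    # with a linear predictor f(w; x) = <w, x> and squared loss on K samples.
    w = w_prior.clone().requires_grad_(True)
    opt = torch.optim.SGD([w], lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        loss = ((X @ w - y) ** 2).mean()               # empirical loss L_{S_i}(w)
        prox = 0.5 * lam * (w - w_prior).pow(2).sum()  # pulls w toward the prior
        (loss + prox).backward()
        opt.step()
    return w.detach()
```

The proximal term makes the inner objective strongly convex around the prior, which is why a handful of SGD steps on only $K$ samples already yields a useful task-specific solution.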

SLIDE 7

Meta Learning via Minibatch Proximal Update (Meta-MinibatchProx)

  • Test model: given a randomly sampled new task $\mathcal{T}$ consisting of $K$ samples $S = \{(x_j, y_j)\}_{j=1}^{K}$, we adapt by solving

$$w_{\mathcal{T}} = \arg\min_{w'} \; \mathcal{L}_{S}(w') + \frac{\lambda}{2}\|w' - w^*\|^2,$$

where $w^*$ denotes the learnt prior initialization.

SLIDE 8

Meta Learning via Minibatch Proximal Update (Meta-MinibatchProx)

  • Benefit: a few samples are sufficient for adaptation, because the learnt prior initialization $w^*$ is close, in expectation, to the optimum of a new task when training and test tasks are sampled from the same distribution.
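As a self-contained illustration of test-time adaptation, the sketch below adapts a placeholder prior to a synthetic 5-sample task; the data, the dimension d, and the hyperparameters are invented for the example, and w_star merely stands in for a learnt prior.

```python
import torch

torch.manual_seed(0)

# Hypothetical new few-shot task: K = 5 samples in dimension d = 10.
K, d = 5, 10
X_new, y_new = torch.randn(K, d), torch.randn(K)
w_star = torch.zeros(d)  # placeholder for the learnt prior initialization w*

lam, lr = 1.0, 0.05
w = w_star.clone().requires_grad_(True)
opt = torch.optim.SGD([w], lr=lr)
for _ in range(100):
    opt.zero_grad()
    loss = ((X_new @ w - y_new) ** 2).mean()      # L_S(w) on the K samples
    prox = 0.5 * lam * (w - w_star).pow(2).sum()  # stay close to the learnt prior
    (loss + prox).backward()
    opt.step()

w_T = w.detach()  # adapted task-specific model w_T
```

Note that no second-order information is needed at test time: adaptation is just a few steps of SGD on a proximally regularized loss.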

SLIDE 9

Optimization Algorithm

We use an SGD-based algorithm to solve the bi-level training model:

SLIDE 10

Optimization Algorithm

  • Step 1. Select a mini-batch of tasks $\{\mathcal{T}_i\}_{i=1}^{b}$ of size $b$.

SLIDE 11

Optimization Algorithm

  • Step 2. For each sampled task $\mathcal{T}_i$, compute an approximate minimizer of the inner problem, e.g. by a few steps of SGD from $w^t$:

$$w_i^t \approx \arg\min_{w'} \; \mathcal{L}_{S_i}(w') + \frac{\lambda}{2}\|w' - w^t\|^2.$$
SLIDE 12

Optimization Algorithm

  • Step 3. Update the prior initialization model with the averaged proximal gradient:

$$w^{t+1} = w^t - \eta \cdot \frac{\lambda}{b} \sum_{i=1}^{b} \big(w^t - w_i^t\big).$$
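Putting Steps 1-3 together, below is a minimal end-to-end sketch of the training loop on synthetic linear-regression tasks. The outer step uses the fact that the gradient of the inner (Moreau-envelope) value with respect to $w$ is $\lambda(w^t - w_i^t)$; all names, the synthetic task generator, and the hyperparameters are illustrative.

```python
import torch

def inner_solve(w_prior, X, y, lam, lr=0.05, steps=25):
    # Step 2: approximate minimizer of L_{S_i}(w') + (lam / 2) * ||w' - w_prior||^2.
    w = w_prior.clone().requires_grad_(True)
    opt = torch.optim.SGD([w], lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        obj = ((X @ w - y) ** 2).mean() + 0.5 * lam * (w - w_prior).pow(2).sum()
        obj.backward()
        opt.step()
    return w.detach()

d, K, b = 10, 5, 4      # dimension, shots per task, mini-batch size
lam, eta = 1.0, 0.5     # proximal weight, outer step size
w_t = torch.zeros(d)    # prior initialization w^t

for t in range(200):
    grad = torch.zeros(d)
    for _ in range(b):                      # Step 1: sample a mini-batch of b tasks
        w_true = torch.randn(d)             # hidden ground-truth model of the task
        X = torch.randn(K, d)
        y = X @ w_true
        w_i = inner_solve(w_t, X, y, lam)   # Step 2: task-specific solution w_i^t
        grad += lam * (w_t - w_i)           # gradient of the proximal envelope
    w_t = w_t - eta * grad / b              # Step 3: update the prior w^{t+1}
```

Because the outer gradient only needs the inner solutions $w_i^t$, the update avoids the second-order derivatives that make MAML-style training expensive.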
SLIDE 13

Optimization Algorithm


Theorem 1 (convergence guarantees, informal).

(1) Convex setting, i.e. convex $\ell$: we prove that the algorithm converges to a global optimum of the bi-level training objective $F$. (2) Nonconvex setting, i.e. smooth (possibly nonconvex) $\ell$: we prove that the algorithm converges to a stationary point of $F$.
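For orientation, the informal statements above are consistent with the standard rates for stochastic methods of this type; the display below is an assumed sketch of their shape ($T$ is the number of outer iterations, $\bar{w}_T$ the averaged iterate), not the paper's exact constants or conditions.

```latex
% Convex loss: the averaged iterate approaches the optimum (assumed standard rate).
\mathbb{E}\big[F(\bar{w}_T)\big] - \min_{w} F(w)
  \;\le\; \mathcal{O}\!\left(\tfrac{1}{\sqrt{T}}\right),
  \qquad \bar{w}_T = \tfrac{1}{T}\sum_{t=1}^{T} w^t .

% Smooth (possibly nonconvex) loss: the smallest gradient norm vanishes.
\min_{1 \le t \le T} \; \mathbb{E}\,\big\|\nabla F(w^t)\big\|^2
  \;\le\; \mathcal{O}\!\left(\tfrac{1}{\sqrt{T}}\right).
```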

SLIDE 14

Generalization Performance Guarantee

  • Ideally, for a given task $\mathcal{T}$, one should train the model on the population risk $\mathcal{L}_{\mathcal{T}}(w') = \mathbb{E}_{(x,y)\sim \mathcal{T}}\,\ell(f(w'; x), y)$.
  • In practice, we have only $K$ samples and adapt the learnt prior model $w^*$ to the new task: $w_{\mathcal{T}} = \arg\min_{w'} \mathcal{L}_{S}(w') + \frac{\lambda}{2}\|w' - w^*\|^2$.
  • Since $K$ is small, why is $w_{\mathcal{T}}$ good for generalization in the few-shot learning problem?
SLIDE 15

Generalization Performance Guarantee


Theorem 2 (generalization performance guarantee, informal).

Suppose each loss $\ell$ is convex and smooth, and let $w_{\mathcal{T}}^* = \arg\min_{w'} \mathcal{L}_{\mathcal{T}}(w')$ denote the optimum model of task $\mathcal{T}$. Then

$$\mathbb{E}\big[\mathcal{L}_{\mathcal{T}}(w_{\mathcal{T}})\big] - \mathcal{L}_{\mathcal{T}}(w_{\mathcal{T}}^*) \;\le\; \mathcal{O}\Big(\lambda\,\|w^* - w_{\mathcal{T}}^*\|^2 + \frac{1}{\lambda K}\Big).$$

Remark: this implies strong generalization performance, as our training model guarantees the learnt prior $w^*$ is close to the optimum model $w_{\mathcal{T}}^*$.
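One way to see why a bound of this shape holds is the usual bias-stability decomposition for proximally regularized empirical risk minimization; the display below is a sketch of that standard argument, not necessarily the paper's own proof.

```latex
\mathbb{E}\big[\mathcal{L}_{\mathcal{T}}(w_{\mathcal{T}})\big] - \mathcal{L}_{\mathcal{T}}(w_{\mathcal{T}}^*)
 \;=\; \underbrace{\mathbb{E}\big[\mathcal{L}_{\mathcal{T}}(w_{\mathcal{T}}) - \mathcal{L}_{S}(w_{\mathcal{T}})\big]}_{
   \text{generalization gap: } \mathcal{O}(1/(\lambda K)) \text{ by uniform stability}}
 \;+\; \underbrace{\mathbb{E}\big[\mathcal{L}_{S}(w_{\mathcal{T}}) - \mathcal{L}_{\mathcal{T}}(w_{\mathcal{T}}^*)\big]}_{
   \le\, \frac{\lambda}{2}\|w^* - w_{\mathcal{T}}^*\|^2 \text{ by optimality of } w_{\mathcal{T}}}
```

The proximal weight $\lambda$ trades the two terms off: a larger $\lambda$ stabilizes learning on the $K$ samples but leans more heavily on the prior $w^*$ being close to $w_{\mathcal{T}}^*$.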
SLIDE 16

Experimental results

Few-shot regression: smaller mean square error (MSE) between prediction and ground truth. Few-shot classification: higher classification accuracy.

[Bar charts: few-shot classification accuracy of MAML, FOMAML, Reptile, and Ours on miniImageNet and tieredImageNet. Left panel: 1-shot and 5-shot 5-way, with gains of 0.8%, 1.15%, 3.31%, and 1.44% for Ours. Right panel: 1-shot and 5-shot 20-way and 10-way, with gains of 2.41%, 1.18%, 1.12%, and 5.15%.]

SLIDE 17

POSTER #26, 05:00--07:00 PM @ East Exhibition Hall B + C. Thanks!