SLIDE 1

TASKNORM: Rethinking Batch Normalization for Meta-Learning

John Bronskill* (University of Cambridge), Jonathan Gordon* (University of Cambridge), James Requeima (University of Cambridge, Invenia Labs), Sebastian Nowozin (Microsoft Research), Richard E. Turner (University of Cambridge, Microsoft Research)

Department of Engineering, University of Cambridge

Paper: *Bronskill, J., *Gordon, J., Requeima, J., Nowozin, S., and Turner, R.E. "TaskNorm: Rethinking Batch Normalization for Meta-Learning." Proceedings of the 37th International Conference on Machine Learning, PMLR 119 (2020). *Equal contribution.
Code: https://github.com/cambridge-mlg/cnaps

SLIDE 2

TaskNorm: Batch Normalization for Meta-Learning with Images

  • We demonstrate the significant effect of batch normalization (BN) on meta-learning image classification accuracy and training efficiency.
  • We identify issues with the transductive BN schemes used in well-known meta-learning algorithms.
  • We introduce TASKNORM, a normalization algorithm tailored to the meta-learning setting that improves both image classification accuracy and training efficiency.

SLIDE 3

Meta-Learning

SLIDE 4

Meta-Learning

➢ Early Machine Learning: Learn model based on engineered features

SLIDE 5

Meta-Learning

➢ Early Machine Learning: Learn model based on engineered features
➢ Deep learning: Jointly learn features and model

SLIDE 6

Meta-Learning

➢ Early Machine Learning: Learn model based on engineered features
➢ Deep learning: Jointly learn features and model
➢ Meta-Learning: Jointly learn features, model, and algorithm[1]

[1] Hospedales, Timothy, et al. "Meta-learning in neural networks: A survey." arXiv preprint arXiv:2004.05439 (2020).

SLIDE 7

Meta-Learning

➢ Early Machine Learning: Learn model based on engineered features
➢ Deep learning: Jointly learn features and model
➢ Meta-Learning: Jointly learn features, model, and algorithm[1]

[1] Hospedales, Timothy, et al. "Meta-learning in neural networks: A survey." arXiv preprint arXiv:2004.05439 (2020).
[2] Sergey Levine & Chelsea Finn, Meta-Learning: from Few-Shot Learning to Rapid Reinforcement Learning: https://metalearning-cvpr2019.github.io/assets/CVPR_2019_Metalearning_Tutorial_Chelsea_Finn.pdf

Given a task distribution, learn a new task efficiently.[2]

SLIDE 8

Meta-Learning

➢ Early Machine Learning: Learn model based on engineered features
➢ Deep learning: Jointly learn features and model
➢ Meta-Learning: Jointly learn features, model, and algorithm[1]
➢ Focus on utilizing meta-learning in the few-shot classification scenario

[1] Hospedales, Timothy, et al. "Meta-learning in neural networks: A survey." arXiv preprint arXiv:2004.05439 (2020).
[2] Sergey Levine & Chelsea Finn, Meta-Learning: from Few-Shot Learning to Rapid Reinforcement Learning: https://metalearning-cvpr2019.github.io/assets/CVPR_2019_Metalearning_Tutorial_Chelsea_Finn.pdf

Given a task distribution, learn a new task efficiently.[2]

SLIDE 9

Few-Shot Meta-Training / Meta-Testing

SLIDE 10

Few-Shot Meta-Training / Meta-Testing

Hugo Larochelle, Generalizing From Few Examples With Meta-Learning: https://www.dropbox.com/s/sm68skkkbxbob0i/metalearning.pdf?dl=0

[Figure: a task τ consists of a context set (D_τ) of labeled examples and a target set (T_τ) of examples to be predicted.]

SLIDE 11

Few-Shot Meta-Training / Meta-Testing

Hugo Larochelle, Generalizing From Few Examples With Meta-Learning: https://www.dropbox.com/s/sm68skkkbxbob0i/metalearning.pdf?dl=0

[Figure: meta-train task 1, with a context set D_1 and target set T_1 of timepiece images labeled stopwatch, meter, watch, and clock.]

SLIDE 12

Few-Shot Meta-Training / Meta-Testing

Hugo Larochelle, Generalizing From Few Examples With Meta-Learning: https://www.dropbox.com/s/sm68skkkbxbob0i/metalearning.pdf?dl=0

[Figure: during meta-training, the meta-learner is fed the context images and labels of task 1.]

SLIDE 13

Few-Shot Meta-Training / Meta-Testing

Hugo Larochelle, Generalizing From Few Examples With Meta-Learning: https://www.dropbox.com/s/sm68skkkbxbob0i/metalearning.pdf?dl=0

[Figure: the meta-learner outputs parameters for a learner adapted to task 1.]

SLIDE 14

Few-Shot Meta-Training / Meta-Testing

Hugo Larochelle, Generalizing From Few Examples With Meta-Learning: https://www.dropbox.com/s/sm68skkkbxbob0i/metalearning.pdf?dl=0

[Figure: the adapted learner makes predictions for the target images of task 1.]

SLIDE 15

Few-Shot Meta-Training / Meta-Testing

Hugo Larochelle, Generalizing From Few Examples With Meta-Learning: https://www.dropbox.com/s/sm68skkkbxbob0i/metalearning.pdf?dl=0

[Figure: a loss compares the predictions with the target labels and is used to train the meta-learner.]

SLIDE 16

Meta-Training / Meta-Testing

Hugo Larochelle, Generalizing From Few Examples With Meta-Learning: https://www.dropbox.com/s/sm68skkkbxbob0i/metalearning.pdf?dl=0

[Figure: a second meta-train task, with a context set D_2 and target set T_2 of Omniglot Aramaic characters.]

SLIDE 17

Meta-Training / Meta-Testing

Hugo Larochelle, Generalizing From Few Examples With Meta-Learning: https://www.dropbox.com/s/sm68skkkbxbob0i/metalearning.pdf?dl=0

[Figure: the same meta-learner, learner, and loss pipeline is applied to task 2.]

SLIDE 18

Meta-Training / Meta-Testing

Hugo Larochelle, Generalizing From Few Examples With Meta-Learning: https://www.dropbox.com/s/sm68skkkbxbob0i/metalearning.pdf?dl=0

[Figure: meta-training iterates in this way over many tasks (task 1, task 2, ...).]

SLIDE 19

Meta-Training / Meta-Testing

Hugo Larochelle, Generalizing From Few Examples With Meta-Learning: https://www.dropbox.com/s/sm68skkkbxbob0i/metalearning.pdf?dl=0

[Figure: at meta-test time, a held-out task arrives: a context set D*_1 of labeled traffic signs (speed, stop, no trucks, curve) and target images whose labels are unknown.]

SLIDE 20

Meta-Training / Meta-Testing

Hugo Larochelle, Generalizing From Few Examples With Meta-Learning: https://www.dropbox.com/s/sm68skkkbxbob0i/metalearning.pdf?dl=0

[Figure: the trained meta-learner adapts the learner using the meta-test context set, and the adapted learner predicts labels for the meta-test target images.]

SLIDE 21

Batch Normalization

Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating deep network training by reducing internal covariate shift." arXiv preprint arXiv:1502.03167 (2015).

SLIDE 22

Batch Normalization

➢ Goal: Normalize each training batch so that it has:
  • zero mean
  • unit variance

Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating deep network training by reducing internal covariate shift." arXiv preprint arXiv:1502.03167 (2015).

SLIDE 23

Batch Normalization

➢ Goal: Normalize each training batch so that it has:
  • zero mean
  • unit variance

➢ Accelerates neural network training by:
  • Allowing the use of higher learning rates.
  • Decreasing the sensitivity to network initialization.

Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating deep network training by reducing internal covariate shift." arXiv preprint arXiv:1502.03167 (2015).

SLIDE 24

“Conventional” Batch Normalization Algorithm

Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating deep network training by reducing internal covariate shift." arXiv preprint arXiv:1502.03167 (2015).

Training:

SLIDE 25

“Conventional” Batch Normalization Algorithm

Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating deep network training by reducing internal covariate shift." arXiv preprint arXiv:1502.03167 (2015).

Training:
⓪ $B = \{x_1, x_2, \ldots, x_n\}$   # a mini-batch

SLIDE 26

“Conventional” Batch Normalization Algorithm

Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating deep network training by reducing internal covariate shift." arXiv preprint arXiv:1502.03167 (2015).

Training:
⓪ $B = \{x_1, x_2, \ldots, x_n\}$   # a mini-batch
① $\mu_B = \frac{1}{n} \sum_{j=1}^{n} x_j$   # compute batch mean

SLIDE 27

“Conventional” Batch Normalization Algorithm

Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating deep network training by reducing internal covariate shift." arXiv preprint arXiv:1502.03167 (2015).

Training:
⓪ $B = \{x_1, x_2, \ldots, x_n\}$   # a mini-batch
① $\mu_B = \frac{1}{n} \sum_{j=1}^{n} x_j$   # compute batch mean
② $\sigma_B^2 = \frac{1}{n} \sum_{j=1}^{n} (x_j - \mu_B)^2$   # compute batch variance

SLIDE 28

“Conventional” Batch Normalization Algorithm

Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating deep network training by reducing internal covariate shift." arXiv preprint arXiv:1502.03167 (2015).

Training:
⓪ $B = \{x_1, x_2, \ldots, x_n\}$   # a mini-batch
① $\mu_B = \frac{1}{n} \sum_{j=1}^{n} x_j$   # compute batch mean
② $\sigma_B^2 = \frac{1}{n} \sum_{j=1}^{n} (x_j - \mu_B)^2$   # compute batch variance
③ $x_j' = \gamma \, \frac{x_j - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}} + \beta$   # normalize; $\gamma, \beta$ are learned; $\epsilon$ is a small constant to avoid division by 0

SLIDE 29

“Conventional” Batch Normalization Algorithm

Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating deep network training by reducing internal covariate shift." arXiv preprint arXiv:1502.03167 (2015).

Training:
⓪ $B = \{x_1, x_2, \ldots, x_n\}$   # a mini-batch
① $\mu_B = \frac{1}{n} \sum_{j=1}^{n} x_j$   # compute batch mean
② $\sigma_B^2 = \frac{1}{n} \sum_{j=1}^{n} (x_j - \mu_B)^2$   # compute batch variance
③ $x_j' = \gamma \, \frac{x_j - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}} + \beta$   # normalize; $\gamma, \beta$ are learned; $\epsilon$ is a small constant to avoid division by 0
④ Accumulate moving averages of $\mu_B, \sigma_B^2$ over all batches as $\mu_r, \sigma_r^2$.

SLIDE 30

“Conventional” Batch Normalization Algorithm

Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating deep network training by reducing internal covariate shift." arXiv preprint arXiv:1502.03167 (2015).

Training:
⓪ $B = \{x_1, x_2, \ldots, x_n\}$   # a mini-batch
① $\mu_B = \frac{1}{n} \sum_{j=1}^{n} x_j$   # compute batch mean
② $\sigma_B^2 = \frac{1}{n} \sum_{j=1}^{n} (x_j - \mu_B)^2$   # compute batch variance
③ $x_j' = \gamma \, \frac{x_j - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}} + \beta$   # normalize; $\gamma, \beta$ are learned; $\epsilon$ is a small constant to avoid division by 0
④ Accumulate moving averages of $\mu_B, \sigma_B^2$ over all batches as $\mu_r, \sigma_r^2$.

Inference:
$x_j' = \gamma \, \frac{x_j - \mu_r}{\sqrt{\sigma_r^2 + \epsilon}} + \beta$   # use the moving averages to normalize

SLIDE 31

“Conventional” Batch Normalization Algorithm

Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating deep network training by reducing internal covariate shift." arXiv preprint arXiv:1502.03167 (2015).

Training:
⓪ $B = \{x_1, x_2, \ldots, x_n\}$   # a mini-batch
① $\mu_B = \frac{1}{n} \sum_{j=1}^{n} x_j$   # compute batch mean
② $\sigma_B^2 = \frac{1}{n} \sum_{j=1}^{n} (x_j - \mu_B)^2$   # compute batch variance
③ $x_j' = \gamma \, \frac{x_j - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}} + \beta$   # normalize; $\gamma, \beta$ are learned; $\epsilon$ is a small constant to avoid division by 0
④ Accumulate moving averages of $\mu_B, \sigma_B^2$ over all batches as $\mu_r, \sigma_r^2$.

Inference:
$x_j' = \gamma \, \frac{x_j - \mu_r}{\sqrt{\sigma_r^2 + \epsilon}} + \beta$   # use the moving averages to normalize

We call the mean and variance of a batch its moments.
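
As a concrete reference for the rest of the talk, here is a minimal single-channel sketch of the algorithm above in Python/NumPy. It is an illustration under our own naming, not the paper's code: real implementations vectorize over channels and learn γ and β by backprop, and the momentum value here is just illustrative.

```python
import numpy as np

class ConventionalBatchNorm:
    """Single-channel sketch of conventional batch normalization
    (Ioffe & Szegedy, 2015)."""

    def __init__(self, eps=1e-5, momentum=0.1):
        self.gamma, self.beta = 1.0, 0.0    # learned scale and shift
        self.mu_r, self.var_r = 0.0, 1.0    # running (moving-average) moments
        self.eps, self.momentum = eps, momentum

    def train_step(self, x):
        mu_b, var_b = x.mean(), x.var()     # steps 1-2: batch moments
        # step 4: accumulate moving averages of the batch moments
        self.mu_r = (1 - self.momentum) * self.mu_r + self.momentum * mu_b
        self.var_r = (1 - self.momentum) * self.var_r + self.momentum * var_b
        # step 3: normalize with the *batch* moments
        return self.gamma * (x - mu_b) / np.sqrt(var_b + self.eps) + self.beta

    def inference(self, x):
        # normalize with the stored *running* moments
        return self.gamma * (x - self.mu_r) / np.sqrt(self.var_r + self.eps) + self.beta
```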

SLIDE 32

How should batch normalization for meta-learning work?

SLIDE 33

How should batch normalization for meta-learning work?

➢ First idea: Use conventional batch normalization (CBN):

SLIDE 34

How should batch normalization for meta-learning work?

➢ First idea: Use conventional batch normalization (CBN):
  • Meta-Training: Normalize with the computed batch moments (μ_BN, σ²_BN).

[Figure: meta-training; the incoming context activations and the incoming target activations are both normalized with (μ_BN, σ²_BN).]

SLIDE 35

How should batch normalization for meta-learning work?

➢ First idea: Use conventional batch normalization (CBN):
  • Meta-Training: Normalize with the computed batch moments (μ_BN, σ²_BN).
  • Meta-Testing: Normalize with the running averages of the moments (μ_r, σ²_r) that were computed during meta-training.

[Figure: meta-training normalizes context and target activations with (μ_BN, σ²_BN); meta-testing normalizes both with the running moments (μ_r, σ²_r).]
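
In terms of the ConventionalBatchNorm sketch from the batch-normalization slides, this first idea amounts to the usage below (synthetic data, our naming). The point is that the running moments pool statistics across all meta-training tasks, so a meta-test task with different statistics is normalized poorly.

```python
import numpy as np

rng = np.random.default_rng(0)
bn = ConventionalBatchNorm()  # sketch class defined earlier

# Meta-training: each task's activations can have a very different scale,
# yet all of them are folded into a single pair of running moments.
for task_scale in (0.1, 1.0, 10.0):            # three synthetic "tasks"
    _ = bn.train_step(task_scale * rng.standard_normal(32))

# Meta-testing: a novel task is normalized with moments pooled across
# unrelated training tasks, so its normalized activations need not come
# out zero-mean / unit-variance.
out = bn.inference(5.0 * rng.standard_normal(32))
print(out.mean(), out.std())
```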

SLIDE 36

Classification Accuracy (%) of Model-Agnostic Meta-Learning (MAML)

  • on the Omniglot and miniImageNet datasets

Configuration                 CBN
Omniglot 5-way, 1-shot        20.1±0.0
Omniglot 5-way, 5-shot        20.0±0.0
Omniglot 20-way, 1-shot        5.0±0.0
Omniglot 20-way, 5-shot        5.0±0.0
miniImageNet 5-way, 1-shot    20.1±0.0
miniImageNet 5-way, 5-shot    20.2±0.0

These results are terrible: the classification accuracy is no better than chance.

SLIDE 37

MAML uses Transductive Batch Normalization (TBN)

SLIDE 38

MAML uses Transductive Batch Normalization (TBN)

  • TBN ignores the running moments (μ_r, σ²_r).
  • It uses the computed moments (μ_BN, σ²_BN) to normalize during both meta-training and meta-testing.

[Figure: during both meta-training and meta-testing, context and target activations are normalized with computed moments (μ_BN, σ²_BN).]
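
In code, TBN collapses the two branches of conventional BN into one rule (a single-channel sketch in our naming, not MAML's implementation): whatever set is being normalized supplies its own moments, in both phases.

```python
import numpy as np

def transductive_batch_norm(x, gamma=1.0, beta=0.0, eps=1e-5):
    """Single-channel sketch of TBN: the current batch's moments are used at
    meta-training and meta-testing alike; running moments are never consulted."""
    return gamma * (x - x.mean()) / np.sqrt(x.var() + eps) + beta
```

Applied to a target set, every normalized output depends on the moments of the whole target set; this coupling between target examples is exactly what makes the scheme transductive.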

SLIDE 39

MAML uses Transductive Batch Normalization (TBN)

Configuration                 CBN        TBN
Omniglot 5-way, 1-shot        20.1±0.0   98.4±0.7
Omniglot 5-way, 5-shot        20.0±0.0   99.2±0.2
Omniglot 20-way, 1-shot        5.0±0.0   90.9±0.5
Omniglot 20-way, 5-shot        5.0±0.0   96.6±0.2
miniImageNet 5-way, 1-shot    20.1±0.0   45.5±1.8
miniImageNet 5-way, 5-shot    20.2±0.0   59.7±0.9

The TBN accuracies are what we would expect for MAML.

  • TBN ignores the running moments (μ_r, σ²_r).
  • It uses the computed moments (μ_BN, σ²_BN) to normalize during both meta-training and meta-testing.

SLIDE 40

Transductive vs Non-Transductive

SLIDE 41

Transductive vs Non-Transductive

Non-Transductive: $q(y_1^* \mid x_1^*, D_1^*)$

[Figure: a context set D_τ of labeled traffic signs (speed, stop, no trucks, curve) and two target inputs x_1, x_2 with predicted labels y_1, y_2.]

At meta-test time, the prediction of a label y_j* for an input x_j* is conditioned only on x_j* and the context set D_τ*.

SLIDE 42

Transductive vs Non-Transductive

Non-Transductive: $q(y_1^* \mid x_1^*, D_1^*)$

Transductive: $q(y_1^* \mid x_1^*, x_2^*, D_1^*)$

[Figure: the same context set of labeled traffic signs (speed, stop, no trucks, curve) with target inputs x_1, x_2; in the transductive case the prediction y_1 also depends on the other target input x_2.]

Non-transductive: at meta-test time, the prediction of a label y_j* for an input x_j* is conditioned only on x_j* and the context set D_τ*.

Transductive: at meta-test time, the prediction of a label y_j* for an input x_j* is conditioned on all x* in the target set as well as the context set D_τ*.

SLIDE 43

Transductive Batch Normalization Issues

Note: Under normal circumstances, at meta-test time we have no control over the makeup of the target set in terms of the relative proportions of the true labels, as these are unknown. There are two key issues with TBN:

SLIDE 44

Transductive Batch Normalization Issues

Note: Under normal circumstances, at meta-test time we have no control over the makeup of the target set in terms of the relative proportions of the true labels, as these are unknown. There are two key issues with TBN:

1. Transductive learning is sensitive to the target-set distribution seen during meta-training and will fail if required to make good predictions:
  • One example at a time (e.g. online learning).
  • When the target set contains a class balance different from meta-training.
  • While respecting certain privacy constraints.
SLIDE 45

Transductive Batch Normalization Issues

Note: Under normal circumstances, at meta-test time we have no control over the makeup of the target set in terms of the relative proportions of the true labels, as these are unknown. There are two key issues with TBN:

1. Transductive learning is sensitive to the target-set distribution seen during meta-training and will fail if required to make good predictions:
  • One example at a time (e.g. online learning).
  • When the target set contains a class balance different from meta-training.
  • While respecting certain privacy constraints.

2. Transductive learners have more information available to them at prediction time, which may lead to unfair comparisons.

SLIDE 46

Transductive Batch Normalization Issues (cont'd)

➢ TBN accuracy degrades significantly when predictions are made one example at a time (streaming) or one class at a time (class imbalance).

Configuration                 CBN        TBN        TBN (1 example)  TBN (1 class)
Omniglot 5-way, 1-shot        20.1±0.0   98.4±0.7   21.6±1.3         21.6±1.3
Omniglot 5-way, 5-shot        20.0±0.0   99.2±0.2   22.0±0.5         23.2±0.5
Omniglot 20-way, 1-shot        5.0±0.0   90.9±0.5    3.7±0.2          3.7±0.2
Omniglot 20-way, 5-shot        5.0±0.0   96.6±0.2    5.5±0.2         14.5±0.3
miniImageNet 5-way, 1-shot    20.1±0.0   45.5±1.8   26.9±1.5         26.9±1.5
miniImageNet 5-way, 5-shot    20.2±0.0   59.7±0.9   30.3±0.7         27.2±0.6

(1 example) = tested one example at a time; (1 class) = tested one class at a time.

SLIDE 47

Need to Rethink Normalization for Meta-Learning

  • For MAML, CBN doesn't work, and TBN has potentially unwanted side effects.
SLIDE 48

Need to Rethink Normalization for Meta-Learning

  • For MAML, CBN doesn't work, and TBN has potentially unwanted side effects.
  • There are other non-transductive normalization schemes, including Instance Normalization[1] (IN), Layer Normalization[2] (LN), and Group Normalization[3], but they don't work well in the few-shot classification setting.

Configuration                 CBN        TBN        LN         IN
Omniglot 5-way, 1-shot        20.1±0.0   98.4±0.7   83.0±1.3   87.4±1.2
Omniglot 5-way, 5-shot        20.0±0.0   99.2±0.2   91.0±0.8   93.9±0.5
Omniglot 20-way, 1-shot        5.0±0.0   90.9±0.5   78.1±0.7   80.4±0.7
Omniglot 20-way, 5-shot        5.0±0.0   96.6±0.2   92.3±0.2   92.9±0.2
miniImageNet 5-way, 1-shot    20.1±0.0   45.5±1.8   41.2±1.6   40.7±1.7
miniImageNet 5-way, 5-shot    20.2±0.0   59.7±0.9   52.8±0.9   54.3±0.9

[1] Ulyanov et al. "Instance normalization: The missing ingredient for fast stylization." arXiv preprint arXiv:1607.08022 (2016).
[2] Ba et al. "Layer normalization." arXiv preprint arXiv:1607.06450 (2016).
[3] Wu et al. "Group normalization." Proceedings of the European Conference on Computer Vision (ECCV), 2018.

SLIDE 49

Desiderata for Meta-Learning Normalization

SLIDE 50

Desiderata for Meta-Learning Normalization

1. Improves speed and stability of training without harming test performance (accuracy or log-likelihood).

SLIDE 51

Desiderata for Meta-Learning Normalization

1. Improves speed and stability of training without harming test performance (accuracy or log-likelihood).
2. Works well across a range of context set sizes.

SLIDE 52

Desiderata for Meta-Learning Normalization

1. Improves speed and stability of training without harming test performance (accuracy or log-likelihood).
2. Works well across a range of context set sizes.
3. Is non-transductive, thus supporting inference at meta-test time in a variety of circumstances.

SLIDE 53

A Few Principles

SLIDE 54

A Few Principles

➢ Data is i.i.d. only within a task τ, but not across tasks.
  • Hence, the normalization statistics μ, σ should be local at the task level.
SLIDE 55

A Few Principles

➢ Data is i.i.d. only within a task τ, but not across tasks.
  • Hence, the normalization statistics μ, σ should be local at the task level.
➢ To avoid being transductive, normalization of the target set T_τ should only have access to:
  1. the context set D_τ, and
  2. the single example being predicted, x_j*.

SLIDE 56

MetaBN

➢ Simple idea inspired by the previous principles:
  • Use the batch statistics from the context set to normalize both the context set and the target set.

[Figure: the context moments (μ_BN, σ²_BN) are computed from the incoming context activations and used to normalize them.]

SLIDE 57

MetaBN

➢ Simple idea inspired by the previous principles:
  • Use the batch statistics from the context set to normalize both the context set and the target set.

[Figure: the context moments (μ_BN, σ²_BN) are computed from the incoming context activations and used to normalize both the context activations and the target activations.]
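
A single-channel sketch of the idea in our naming (the paper applies it per channel of a convolutional network):

```python
import numpy as np

def meta_batch_norm(context, target, gamma=1.0, beta=0.0, eps=1e-5):
    """Sketch of MetaBN: moments come from the context set only, so each
    target prediction depends on just the context set and that target point."""
    mu_bn, var_bn = context.mean(), context.var()  # context ("batch") moments

    def normalize(a):
        return gamma * (a - mu_bn) / np.sqrt(var_bn + eps) + beta

    return normalize(context), normalize(target)
```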

SLIDE 58

MetaBN

➢ MetaBN works well, but:
  • Classification accuracy suffers when the context set is small (a poor estimate of the true statistics).
  • It doesn't leverage information from the target example under test.
SLIDE 59

TASKNORM

SLIDE 60

TASKNORM

[Figure: the starting point is the MetaBN diagram; the context moments (μ_BN, σ²_BN) normalize both the context activations and the target activations.]

SLIDE 61

TASKNORM

[Figure: additional moments (μ_+, σ²_+) are computed from the context and from the target; they are blended with the context moments (μ_BN, σ²_BN) using weights α and 1 - α to form (μ_TN, σ²_TN), which normalize both the context and target activations.]

SLIDE 62

TASKNORM

$\mu_{TN} = \alpha\,\mu_{BN} + (1 - \alpha)\,\mu_+$

$\sigma_{TN}^2 = \alpha \left( \sigma_{BN}^2 + (\mu_{BN} - \mu_{TN})^2 \right) + (1 - \alpha) \left( \sigma_+^2 + (\mu_+ - \mu_{TN})^2 \right)$

[Figure: the context moments (μ_BN, σ²_BN) and the additional moments (μ_+, σ²_+) are blended with weights α and 1 - α into (μ_TN, σ²_TN), which normalize both the context and target activations.]
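
The variance blend is not ad hoc: it is the moment-matched variance of a two-component mixture. A quick check (our derivation, not from the slides), using $\mu_{TN} = \alpha\mu_{BN} + (1-\alpha)\mu_+$:

$$
\alpha\left(\sigma_{BN}^2 + (\mu_{BN} - \mu_{TN})^2\right) + (1-\alpha)\left(\sigma_+^2 + (\mu_+ - \mu_{TN})^2\right)
= \underbrace{\alpha\left(\sigma_{BN}^2 + \mu_{BN}^2\right) + (1-\alpha)\left(\sigma_+^2 + \mu_+^2\right)}_{\mathbb{E}[x^2]\ \text{of the mixture}} - \mu_{TN}^2,
$$

so $\sigma_{TN}^2 = \mathbb{E}[x^2] - \mu_{TN}^2 \ge 0$ is always a valid variance.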

SLIDE 63

TASKNORM

$\mu_{TN} = \alpha\,\mu_{BN} + (1 - \alpha)\,\mu_+$

$\sigma_{TN}^2 = \alpha \left( \sigma_{BN}^2 + (\mu_{BN} - \mu_{TN})^2 \right) + (1 - \alpha) \left( \sigma_+^2 + (\mu_+ - \mu_{TN})^2 \right)$

$\alpha = \mathrm{SIGMOID}\left(\mathrm{SCALE} \cdot |D_\tau| + \mathrm{OFFSET}\right), \quad 0 \le \alpha \le 1$

SCALE and OFFSET are learned during training.

[Figure: the same blending diagram as the previous slides.]
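
Putting the pieces together, here is a minimal sketch of the TASKNORM-I computation for activations of shape (examples, features), with the feature axis standing in for the spatial locations that instance normalization would pool over in a convolutional network. All names are ours; in the paper, SCALE and OFFSET are trained by backprop alongside γ and β.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def task_norm(context, target, scale, offset, gamma=1.0, beta=0.0, eps=1e-5):
    """Sketch of TASKNORM-I: blend the context ("BN") moments with per-example
    instance ("+") moments, weighted by a learned function of context-set size."""
    alpha = sigmoid(scale * len(context) + offset)  # blend weight, 0 <= alpha <= 1
    mu_bn, var_bn = context.mean(), context.var()   # context ("BN") moments

    def normalize(x):                               # x: one example's activations
        mu_p, var_p = x.mean(), x.var()             # secondary ("+", IN-style) moments
        mu_tn = alpha * mu_bn + (1 - alpha) * mu_p
        var_tn = (alpha * (var_bn + (mu_bn - mu_tn) ** 2)
                  + (1 - alpha) * (var_p + (mu_p - mu_tn) ** 2))
        return gamma * (x - mu_tn) / np.sqrt(var_tn + eps) + beta

    out_context = np.stack([normalize(x) for x in context])
    out_target = np.stack([normalize(x) for x in target])
    return out_context, out_target

# Example: a 5-example context set and a 3-example target set with 16 features.
rng = np.random.default_rng(0)
ctx_out, tgt_out = task_norm(rng.standard_normal((5, 16)),
                             rng.standard_normal((3, 16)),
                             scale=0.1, offset=-1.0)
```

With a large context set, α approaches 1 and this reduces to MetaBN; with a tiny context set it leans on the per-example instance moments, matching the learned behaviour shown on the following slides.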

SLIDE 64

Learned Alpha (α) vs Context Set Size (|D_τ|)

Each curve is the learned value of α in the first TASKNORM of each of the four ResNet-18 layers.

[Figure: learned α as a function of context set size for the four ResNet-18 layers.]

SLIDE 65

Learned Alpha (α) vs Context Set Size (|D_τ|)

Each curve is the learned value of α in the first TASKNORM of each of the four ResNet-18 layers. When the context set size is small (< 30), TASKNORM learns to use a blend of the BN and IN moments.

[Figure: for small context sets, the curves show TASKNORM using a mix of BN and IN moments.]

SLIDE 66

Learned Alpha (α) vs Context Set Size (|D_τ|)

Each curve is the learned value of α in the first TASKNORM of each of the four ResNet-18 layers. When the context set size is small (< 30), TASKNORM learns to use a blend of the BN and IN moments. When the context set size is large (> 30), TASKNORM learns to use only the BN moments.

[Figure: the curves rise toward α = 1 (BN moments only) as the context set grows.]

SLIDE 67

SCALE × (Context Set Size) + OFFSET vs Context Set Size

The slopes are non-zero, indicating that the optimal value of α is a function of the context set size.

SLIDE 68

TASKNORM Fixes the Transductive Issue in MAML

Configuration                 CBN        TBN        TBN (1 example)  TBN (1 class)  TaskNorm
Omniglot 5-way, 1-shot        20.1±0.0   98.4±0.7   21.6±1.3         21.6±1.3       94.4±0.8
Omniglot 5-way, 5-shot        20.0±0.0   99.2±0.2   22.0±0.5         23.2±0.5       98.6±0.2
Omniglot 20-way, 1-shot        5.0±0.0   90.9±0.5    3.7±0.2          3.7±0.2       90.0±0.5
Omniglot 20-way, 5-shot        5.0±0.0   96.6±0.2    5.5±0.2         14.5±0.3       96.3±0.2
miniImageNet 5-way, 1-shot    20.1±0.0   45.5±1.8   26.9±1.5         26.9±1.5       42.4±1.7
miniImageNet 5-way, 5-shot    20.2±0.0   59.7±0.9   30.3±0.7         27.2±0.6       58.7±0.9

TASKNORM accuracy approaches that of TBN.

SLIDE 69

TaskNorm Fixes the Transductive Issue in MAML

Configuration                 CBN        TBN        TBN (1 example)  TBN (1 class)  TaskNorm   TaskNorm (1 example)  TaskNorm (1 class)
Omniglot 5-way, 1-shot        20.1±0.0   98.4±0.7   21.6±1.3         21.6±1.3       94.4±0.8   94.4±0.8              94.4±0.8
Omniglot 5-way, 5-shot        20.0±0.0   99.2±0.2   22.0±0.5         23.2±0.5       98.6±0.2   98.6±0.2              98.6±0.2
Omniglot 20-way, 1-shot        5.0±0.0   90.9±0.5    3.7±0.2          3.7±0.2       90.0±0.5   90.0±0.5              90.0±0.5
Omniglot 20-way, 5-shot        5.0±0.0   96.6±0.2    5.5±0.2         14.5±0.3       96.3±0.2   96.3±0.2              96.3±0.2
miniImageNet 5-way, 1-shot    20.1±0.0   45.5±1.8   26.9±1.5         26.9±1.5       42.4±1.7   42.4±1.7              42.4±1.7
miniImageNet 5-way, 5-shot    20.2±0.0   59.7±0.9   30.3±0.7         27.2±0.6       58.7±0.9   58.7±0.9              58.7±0.9

TASKNORM accuracy approaches that of TBN, and it does not change when tested one example or one class at a time.

SLIDE 70

Meta-Dataset[1]: a Multi-Task, Few-Shot Benchmark

[1] Triantafillou, Eleni, et al. "Meta-Dataset: A dataset of datasets for learning to learn from few examples." arXiv preprint arXiv:1903.03096 (2019).

Datasets: ImageNet, Omniglot, Aircraft, Birds, DTD, Quick Draw, Fungi, VGG Flower, Traffic Signs, MSCOCO

SLIDE 71

Meta-Dataset[1]: a Multi-Task, Few-Shot Benchmark

[1] Triantafillou, Eleni, et al. "Meta-Dataset: A dataset of datasets for learning to learn from few examples." arXiv preprint arXiv:1903.03096 (2019).

Datasets: ImageNet, Omniglot, Aircraft, Birds, DTD, Quick Draw, Fungi, VGG Flower, Traffic Signs, MSCOCO (Traffic Signs and MSCOCO are entirely held out)

SLIDE 72

Meta-Dataset Classification Accuracy Using ProtoNets[1]

[Table: per-dataset accuracy, split into held-out classes and held-out datasets. TASKNORM with Instance Normalization is best on 10 of 13 datasets, and TaskNorm achieves the highest overall rank of all methods, including Transductive Batch Norm (TBN).]

Legend: TBN = Transductive Batch Norm; CBN = Conventional Batch Norm; BRN = Batch Renormalization; LN = Layer Normalization; IN = Instance Normalization; RN = Reptile Norm; MetaBN = Meta Batch Norm; TaskNorm-L = TaskNorm with LN; TaskNorm-I = TaskNorm with IN; TaskNorm-r = TaskNorm with running moments

[1] Snell, Jake, Kevin Swersky, and Richard Zemel. "Prototypical networks for few-shot learning." Advances in Neural Information Processing Systems. 2017.

SLIDE 73

Meta-Dataset Classification Accuracy Using CNAPs[1]

[Table: per-dataset accuracy, split into held-out classes and held-out datasets. TASKNORM with Instance Normalization is best on 11 of 13 datasets, and TaskNorm achieves the highest overall rank of all methods, including Transductive Batch Norm (TBN).]

Legend: TBN = Transductive Batch Norm; CBN = Conventional Batch Norm; BRN = Batch Renormalization; LN = Layer Normalization; IN = Instance Normalization; RN = Reptile Norm; MetaBN = Meta Batch Norm; TaskNorm-L = TaskNorm with LN; TaskNorm-I = TaskNorm with IN; TaskNorm-r = TaskNorm with running moments; Baseline = No Normalization

[1] Requeima, James, et al. "Fast and flexible multi-task classification using conditional neural adaptive processes." Advances in Neural Information Processing Systems. 2019.

SLIDE 74

Meta-Dataset Training Curves

[Figure: training curves on Meta-Dataset.]

TaskNorm-I converges the fastest.

SLIDE 75

Thanks for watching!

  • Paper: https://arxiv.org/pdf/2003.03284.pdf
  • Code: https://github.com/cambridge-mlg/cnaps