THOUGHTS ON PROGRESS MADE AND CHALLENGES AHEAD IN FEW-SHOT LEARNING
Hugo Larochelle Google Brain
People are good at it
Human-level concept learning through probabilistic program induction
Brenden M. Lake, Ruslan Salakhutdinov, Joshua B. Tenenbaum
Fei-Fei Li, Rob Fergus and Pietro Perona
Fei-Fei Li
Michael Fink
(2005)
Evgeniy Bart and Shimon Ullman
representation and algorithm better suited for few-shot learning
Dtrain Dtest
p(y|x, Dtrain)
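$$\hat{y} = \sum_{i=1}^{k} a(\hat{x}, x_i)\, y_i, \qquad a(\hat{x}, x_i) = \frac{e^{c(f(\hat{x}),\, g(x_i))}}{\sum_{j=1}^{k} e^{c(f(\hat{x}),\, g(x_j))}}$$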
Oriol Vinyals, Charles Blundell, Timothy P. Lillicrap, Koray Kavukcuoglu, and Daan Wierstra
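As a rough illustration of the attention rule credited to Vinyals et al., the following Python sketch predicts a label distribution as a softmax-weighted vote over the support set. It assumes that c is cosine similarity and that the embeddings f(x̂) and g(x_i) have already been computed. The helper names `cosine` and `matching_predict` and the toy vectors are hypothetical and do not come from the slides.

```python
import math


def cosine(u, v):
    """Cosine similarity, used here as the (assumed) similarity c."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def matching_predict(query_emb, support_embs, support_labels, num_classes):
    """Return class probabilities for one query embedding f(x_hat).

    support_embs holds the embeddings g(x_i) and support_labels the
    integer labels y_i of the k support examples.
    """
    scores = [cosine(query_emb, g) for g in support_embs]
    # Softmax over the k support examples (shifted for numerical stability).
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    attention = [e / total for e in exps]
    # Attention-weighted sum of one-hot labels.
    probs = [0.0] * num_classes
    for a, y in zip(attention, support_labels):
        probs[y] += a
    return probs


# Toy episode: two support examples of class 0, one of class 1.
support_embs = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0]]
support_labels = [0, 0, 1]
probs = matching_predict([1.0, 0.2], support_embs, support_labels, num_classes=2)
print(probs)
```

Because the prediction is a weighted vote over labelled support points, no parameters need to be updated when a new episode arrives; only the embeddings are learned.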
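$$c_k = \frac{1}{|S_k|} \sum_{(x_i, y_i) \in S_k} f_\phi(x_i), \qquad p_\phi(y = k \mid x) = \frac{\exp(-d(f_\phi(x), c_k))}{\sum_{k'} \exp(-d(f_\phi(x), c_{k'}))}$$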
Jake Snell, Kevin Swersky and Richard Zemel
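The prototype rule credited to Snell et al. can be sketched in a few lines of Python: each class prototype is the mean of its support embeddings, and a query is classified by a softmax over negative distances to the prototypes. This sketch assumes d is the squared Euclidean distance and that the embeddings f_φ(x) are precomputed. The function names and toy data are hypothetical.

```python
import math


def squared_euclidean(u, v):
    """Squared Euclidean distance, used here as the (assumed) distance d."""
    return sum((a - b) ** 2 for a, b in zip(u, v))


def prototypes(support_embs, support_labels):
    """Compute c_k as the mean embedding of the support examples of class k."""
    groups = {}
    for emb, label in zip(support_embs, support_labels):
        groups.setdefault(label, []).append(emb)
    return {
        label: [sum(column) / len(embs) for column in zip(*embs)]
        for label, embs in groups.items()
    }


def proto_predict(query_emb, protos):
    """Softmax over negative distances from the query to each prototype."""
    logits = {k: -squared_euclidean(query_emb, c) for k, c in protos.items()}
    top = max(logits.values())
    exps = {k: math.exp(v - top) for k, v in logits.items()}
    total = sum(exps.values())
    return {k: e / total for k, e in exps.items()}


# Toy episode: two classes with two support embeddings each.
support_embs = [[0.0, 0.0], [0.0, 2.0], [4.0, 4.0], [6.0, 4.0]]
support_labels = [0, 0, 1, 1]
protos = prototypes(support_embs, support_labels)
probs = proto_predict([1.0, 1.0], protos)
print(protos, probs)
```

The query at (1, 1) lies much closer to the class-0 prototype, so nearly all of the probability mass goes to class 0.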
Sachin Ravi and Hugo Larochelle
Chelsea Finn, Pieter Abbeel and Sergey Levine
p(y|x, Dtrain)
to map to a vector space
[Figure: Supervised Learning. Input sequence of past examples and labels (..., y_{t-3}, x_{t-2}, y_{t-2}, x_{t-1}, y_{t-1}) followed by x_t; output: Predicted Label at step t.]
Nikhil Mishra, Mostafa Rohaninejad, Xi Chen and Pieter Abbeel
(a) Dense Block (dilation rate R, D filters)
    inputs, shape [T, C]
    causal conv, kernel 2, dilation R, D filters
    concatenate
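To make the dense block concrete, here is a minimal Python sketch of a kernel-2 causal convolution with dilation R whose output is concatenated to the inputs. The slide does not state which nonlinearity follows the convolution, so none is applied here; that omission, the zero padding for missing past steps, and all function names and weights are assumptions made for illustration.

```python
import random


def causal_conv(x, w_prev, w_curr, bias, dilation):
    """Kernel-2 causal convolution over time.

    x is [T, C]; w_prev and w_curr are [C, D]; bias is [D]; output is [T, D].
    Step t sees x[t] and x[t - dilation]; missing past steps count as zeros.
    """
    num_channels, num_filters = len(x[0]), len(bias)
    out = []
    for t in range(len(x)):
        prev = x[t - dilation] if t - dilation >= 0 else [0.0] * num_channels
        out.append([
            bias[j] + sum(prev[i] * w_prev[i][j] + x[t][i] * w_curr[i][j]
                          for i in range(num_channels))
            for j in range(num_filters)
        ])
    return out


def dense_block(x, w_prev, w_curr, bias, dilation):
    """Concatenate the inputs with the convolution output: [T, C] -> [T, C + D]."""
    conv = causal_conv(x, w_prev, w_curr, bias, dilation)
    return [row + features for row, features in zip(x, conv)]


def random_matrix(rows, cols, rng):
    return [[rng.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
T, C, D, R = 5, 3, 4, 2
x = random_matrix(T, C, rng)
w_prev, w_curr, bias = random_matrix(C, D, rng), random_matrix(C, D, rng), [0.0] * D
y = dense_block(x, w_prev, w_curr, bias, dilation=R)
print(len(y), len(y[0]))  # T rows, C + D columns
```

Since each block appends D features, stacking blocks with growing dilation lets later steps see progressively further into the past.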
(b) Attention Block (key size K, value size V)
    inputs, shape [T, C]
    affine, output size K (query); affine, output size K (keys); affine, output size V (values)
    matmul, masked softmax, matmul
    concatenate
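The attention block can be sketched in the same style: three affine maps produce queries, keys, and values, a causally masked softmax turns query-key products into weights, and the weighted values are concatenated to the inputs. Scaling the logits by sqrt(K) is an assumption borrowed from common practice, and the function names and random weights are hypothetical.

```python
import math
import random


def affine(x, weights, bias):
    """Row-wise x @ weights + bias; x is [T, C], weights is [C, out]."""
    return [[bias[j] + sum(v * weights[i][j] for i, v in enumerate(row))
             for j in range(len(bias))] for row in x]


def masked_softmax(logits):
    """Row-wise softmax in which step t only attends to steps <= t."""
    out = []
    for t, row in enumerate(logits):
        visible = row[:t + 1]
        top = max(visible)
        exps = [math.exp(v - top) for v in visible]
        total = sum(exps)
        out.append([e / total for e in exps] + [0.0] * (len(row) - t - 1))
    return out


def attention_block(x, params):
    """Map inputs [T, C] to [T, C + V]; also return the attention weights."""
    wq, bq, wk, bk, wv, bv = params
    query, keys, values = affine(x, wq, bq), affine(x, wk, bk), affine(x, wv, bv)
    scale = math.sqrt(len(bq))  # assumed scaling by sqrt(K)
    logits = [[sum(a * b for a, b in zip(q, k)) / scale for k in keys]
              for q in query]
    probs = masked_softmax(logits)
    read = [[sum(p * v[j] for p, v in zip(weights, values))
             for j in range(len(bv))] for weights in probs]
    return [row + r for row, r in zip(x, read)], probs


def random_matrix(rows, cols, rng):
    return [[rng.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
T, C, K, V = 4, 3, 2, 5
x = random_matrix(T, C, rng)
params = (random_matrix(C, K, rng), [0.0] * K,
          random_matrix(C, K, rng), [0.0] * K,
          random_matrix(C, V, rng), [0.0] * V)
out, probs = attention_block(x, params)
print(len(out), len(out[0]))  # T rows, C + V columns
```

The mask is what keeps the model causal: the prediction at step t can read from earlier (example, label) pairs but never from later ones.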
p(y|x, Dtrain)
Model                        5-class 1-shot    5-class 5-shot
Baseline-finetune            28.86 ± 0.54%     49.79 ± 0.79%
Baseline-nearest-neighbor    41.08 ± 0.70%     51.04 ± 0.65%
Matching Network             43.40 ± 0.78%     51.09 ± 0.71%
Matching Network FCE         43.56 ± 0.84%     55.31 ± 0.73%
Meta-Learner LSTM (OURS)     43.44 ± 0.77%     60.60 ± 0.71%
MAML (Finn et al.) Prototypical Nets (Snell et al.) SNAIL (Mishra et al.)
Eleni Triantafillou, Tyler Zhu, Vincent Dumoulin, Pascal Lamblin, Kelvin Xu, Ross Goroshin, Carles Gelada, Kevin Swersky, Pierre-Antoine Manzagol, Hugo Larochelle Google
(a) ImageNet (b) Omniglot (c) Aircraft (d) Birds (e) DTD (f) Quick Draw (g) Fungi (h) VGG Flower (i) Traffic Signs (j) MSCOCO
Table 1: Results on META-DATASET using models trained on ILSVRC-2012 only. Accuracy ± confidence, by test source and method.

Test Source      k-NN          Finetune      MatchingNet   ProtoNet      MAML
ILSVRC           34.70±0.95    38.34±1.12    40.89±1.08    43.37±1.17    38.10±1.13
Omniglot         59.84±0.96    59.19±1.18    61.85±1.00    66.18±1.12    54.00±1.47
Aircraft         36.47±0.93    41.18±1.07    41.91±0.96    42.14±0.97    42.52±1.16
Birds            40.38±1.09    45.82±1.25    54.26±1.16    57.85±1.23    50.78±1.32
Textures         56.45±0.78    58.06±0.88    61.70±0.84    60.95±0.80    61.26±0.93
Quick Draw       36.09±1.19    38.43±1.39    38.52±1.12    44.02±1.35    30.71±1.51
Fungi            23.70±0.97    22.20±0.92    27.21±0.97    31.18±1.15    20.35±0.87
VGG Flower       66.16±0.99    69.32±1.13    75.05±0.91    79.89±0.90    65.12±1.15
Traffic Signs    44.81±1.47    39.36±1.28    45.36±1.31    44.04±1.24    31.10±1.20
MSCOCO           29.69±1.00    30.25±1.17    32.32±1.08    36.44±1.23    25.17±1.15
Avg. rank        4             3.4           2.2           1.35          4.05
Table 2: Results on META-DATASET using models trained on all datasets. Accuracy ± confidence, by test source and method.

Test Source      k-NN          Finetune      MatchingNet   ProtoNet      MAML
ILSVRC           25.88±0.83    25.84±0.83    35.88±0.98    38.51±1.01    30.56±1.00
Omniglot         92.45±0.41    85.20±0.73    90.21±0.46    91.32±0.50    78.05±0.98
Aircraft         54.60±0.97    58.22±1.02    70.71±0.78    71.54±0.84    68.62±0.90
Birds            36.74±1.01    38.56±1.08    59.28±1.06    61.81±1.13    54.59±1.24
Textures         50.06±0.77    48.37±0.82    60.61±0.82    59.31±0.75    59.25±0.80
Quick Draw       59.54±1.08    54.05±1.30    57.44±1.17    60.99±1.21    44.48±1.41
Fungi            24.60±0.95    22.90±0.95    31.10±1.04    35.96±1.25    21.12±0.88
VGG Flower       62.49±0.91    59.72±1.17    76.72±0.83    81.06±0.87    66.05±1.09
Traffic Signs    41.68±1.46    30.02±1.13    43.20±1.33    39.95±1.18    30.23±1.24
MSCOCO           23.55±0.99    23.01±0.96    26.87±1.00    30.81±1.13    21.13±1.06
Avg. rank        3.4           4.3           2.15          1.4           3.75
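Accuracy difference between training on all datasets and training on ILSVRC only (Table 2 minus Table 1), ± confidence, by test source and method. Only the cells that survive in the source are shown; the others are marked with a dash.

Test Source      k-NN          Finetune      MatchingNet   ProtoNet      MAML
ILSVRC           -             -             -             -             -
Omniglot         32.61±1.04    26.01±1.39    28.36±1.10    25.14±1.23    24.05±1.77
Aircraft         18.13±1.34    17.04±1.48    28.80±1.24    29.40±1.28    26.10±1.47
Birds            -             -             5.02±1.57     3.96±1.67     3.81±1.81
Textures         -             -             -             -             -
Quick Draw       23.45±1.61    15.62±1.90    18.92±1.62    16.97±1.81    13.77±2.07
Fungi            0.90±1.36     0.70±1.32     3.89±1.42     4.78±1.70     0.77±1.24
VGG Flower       -             -             1.67±1.23     1.17±1.25     0.93±1.58
Traffic Signs    -             -             -             -             -
MSCOCO           -             -             -             -             -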
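The relationship between the three tables can be checked with a little arithmetic. The Python sketch below takes the ProtoNet accuracies from Table 1 and Table 2 and subtracts them per test source; the variable names are hypothetical and the confidence intervals are left out because they cannot be derived from the point estimates alone.

```python
# ProtoNet accuracies copied from Table 1 (trained on ILSVRC only)
# and Table 2 (trained on all datasets).
ilsvrc_only = {
    "ILSVRC": 43.37, "Omniglot": 66.18, "Aircraft": 42.14, "Birds": 57.85,
    "Textures": 60.95, "Quick Draw": 44.02, "Fungi": 31.18,
    "VGG Flower": 79.89, "Traffic Signs": 44.04, "MSCOCO": 36.44,
}
all_datasets = {
    "ILSVRC": 38.51, "Omniglot": 91.32, "Aircraft": 71.54, "Birds": 61.81,
    "Textures": 59.31, "Quick Draw": 60.99, "Fungi": 35.96,
    "VGG Flower": 81.06, "Traffic Signs": 39.95, "MSCOCO": 30.81,
}

# Gain from adding the other training sources, rounded to two decimals.
gain = {name: round(all_datasets[name] - ilsvrc_only[name], 2)
        for name in ilsvrc_only}
improved = [name for name, value in gain.items() if value > 0]

for name, value in gain.items():
    print(f"{name:14s} {value:+.2f}")
print("improved:", improved)
```

On these numbers, training on all datasets helps ProtoNet on six of the ten test sources (Omniglot, Aircraft, Birds, Quick Draw, Fungi, VGG Flower) and hurts on the remaining four, including ILSVRC itself.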