Loss Valleys and Generalization in Deep Learning
Andrew Gordon Wilson
Assistant Professor, Cornell University
https://people.orie.cornell.edu/andrew
The Robotic Vision Probabilistic Object Detection Challenge, CVPR, Long Beach, CA, June 17, 2019
[Figure: airline passenger counts (thousands) by year, 1949–1961.]
◮ The ability of a system to learn is determined by its support (which solutions are a priori possible) and its inductive biases (which solutions are a priori likely).
◮ An influx of new massive datasets provides great opportunities to automatically discover rich structure from data.
◮ A powerful framework for model construction and understanding generalization
◮ Uncertainty representation and calibration (crucial for decision making)
◮ Better point estimates
◮ Interpretably incorporate prior knowledge and domain expertise
◮ It was the most successful approach at the end of the second wave of neural networks
◮ Neural nets are much less mysterious when viewed through the lens of probability theory.
◮ Can be computationally intractable (but doesn't have to be).
◮ Can involve a lot of moving parts (but doesn't have to).
◮ Bayesian integration will give very different predictions in deep learning than classical training, because many complementary settings of the weights are consistent with the data.
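Concretely, the Bayesian model average forms predictions by integrating over the posterior on the weights, p(y∗|x∗, D) = ∫ p(y∗|x∗, w) p(w|D) dw, whereas classical training approximates this integral with a single point estimate, p(y∗|x∗, ŵ).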
[Figure, built up over three slides: a two-dimensional slice of the loss surface containing three independent SGD solutions W1, W2, W3 and their average WSWA (contour levels 19.95 to >50).]
◮ Use a learning rate that does not decay to zero (cyclical or constant)
◮ Average the weights:
  ◮ Cyclical LR: at the end of each cycle
  ◮ Constant LR: at the end of each epoch
◮ Recompute batch normalization statistics at the end of training; in practice, run one additional pass over the training data with the averaged weights (a minimal sketch of the recipe follows)
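A minimal PyTorch-style sketch of this recipe with a constant learning rate; the function, names like model and loader, and the hyperparameter values are illustrative assumptions, not the released implementation:

    import copy
    import torch

    def train_swa(model, loader, loss_fn, epochs=150, swa_start=125, lr=0.05):
        # Standard SGD with a constant (non-decaying) learning rate.
        opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
        swa_model = copy.deepcopy(model)   # holds the running weight average
        n_avg = 0
        for epoch in range(epochs):
            for x, y in loader:
                opt.zero_grad()
                loss_fn(model(x), y).backward()
                opt.step()
            if epoch >= swa_start:
                # Constant LR: update the running average at the end of each epoch.
                n_avg += 1
                for p_avg, p in zip(swa_model.parameters(), model.parameters()):
                    p_avg.data += (p.data - p_avg.data) / n_avg
        # Recompute batch-norm statistics with the averaged weights
        # (one extra pass over the data; update_bn is available in newer PyTorch).
        torch.optim.swa_utils.update_bn(loader, swa_model)
        return swa_model

With a cyclical learning rate, the average would instead be updated once at the end of each cycle.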
[Figures: the loss surface in the plane of W1, W2, W3, WSWA, followed by two panels at epoch 125 locating the conventional SGD solution WSGD and the averaged solution WSWA on surfaces with contour levels 19.62–50 and 0.009–0.88.]
◮ Simple drop-in replacement for SGD or other optimizers
◮ Works by finding flat regions of the loss surface
◮ No runtime overhead, but often significant improvements in generalization across many architectures and tasks
◮ Available in PyTorch contrib (call optim.swa)
◮ https://people.orie.cornell.edu/andrew/code
[Figure: reliability diagrams, plotting confidence (max probability) minus accuracy against confidence, for WideResNet28x10 on CIFAR-100, WideResNet28x10 on CIFAR-10 → STL-10, DenseNet-161 on ImageNet, and ResNet-152 on ImageNet.]
◮ Construct a low-dimensional subspace of the network's high-dimensional parameter space
◮ Perform inference directly in the subspace
◮ Sample from the approximate posterior for Bayesian model averaging
◮ Choose a shift ŵ and directions d1, . . . , dk in the full parameter space
◮ Define the subspace S = {w | w = ŵ + t1 d1 + · · · + tk dk, t ∈ R^k}
◮ Likelihood p(D|t) = pM(D|w = ŵ + t1 d1 + · · · + tk dk)
◮ Approximate inference over parameters t
◮ MCMC, Variational Inference, Normalizing Flows, . . .
◮ Bayesian model averaging at test time:
p(y∗|x∗, D) ≈ (1/J) Σj=1..J pM(y∗|x∗, w = ŵ + tj,1 d1 + · · · + tj,k dk), with t1, . . . , tJ sampled from the approximate posterior over t
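A minimal sketch of this model average in code; the helper names, and the assumption that posterior_samples already holds draws of t from approximate inference, are illustrative and not the authors' released implementation:

    import torch

    @torch.no_grad()
    def subspace_bma(model, x, w_hat, directions, posterior_samples):
        # w_hat: flat vector of shift weights, shape (p,)
        # directions: matrix with rows d_1, ..., d_k, shape (k, p)
        # posterior_samples: draws t_1, ..., t_J of the subspace parameters, shape (J, k)
        probs = 0.0
        for t in posterior_samples:
            w = w_hat + directions.t() @ t                    # w = w_hat + sum_i t_i d_i
            torch.nn.utils.vector_to_parameters(w, model.parameters())
            probs = probs + torch.softmax(model(x), dim=-1)   # p_M(y* | x*, w)
        return probs / len(posterior_samples)                 # average of J predictive distributions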
◮ Contains diverse models which give rise to different predictions
◮ Cheap to construct
◮ Directions d1, . . . , dk ∼ N(0, Ip)
◮ Use a pre-trained solution as the shift ŵ
◮ Subspace S = {w | w = ŵ + t1 d1 + · · · + tk dk}
◮ Run SGD with a high constant learning rate from a pre-trained solution
◮ Collect snapshots of the weights wi
◮ Use the SWA solution ŵSWA = (1/M) Σi=1..M wi as the shift ŵ
◮ {d1, . . . , dk} are the first k PCA components of the deviation vectors wi − ŵSWA (see the sketch below)
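A sketch of this construction, assuming the SGD snapshots have already been flattened into vectors; the function and variable names are illustrative, not the authors' released code:

    import torch

    def pca_subspace(snapshots, k=5):
        W = torch.stack(snapshots)        # (M, p): one row per weight snapshot w_i
        w_swa = W.mean(dim=0)             # shift: the SWA solution (1/M) * sum_i w_i
        deviations = W - w_swa            # rows are w_i - w_swa
        # First k principal directions of the deviations via a truncated SVD.
        _, _, Vh = torch.linalg.svd(deviations, full_matrices=False)
        directions = Vh[:k]               # (k, p): rows d_1, ..., d_k
        return w_swa, directions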
◮ Make label predictions using structure discovered in unlabelled data
◮ Can quantify recent advances in semi-supervised learning
◮ Crucial for reducing the dependency of deep learning on large labelled datasets
◮ End-to-end training entirely in low precision.
◮ Can outperform full-precision SGD even with all numbers quantized down to 8 bits.
◮ Averaging combines weights that have been rounded up with those that have been rounded down.
◮ Quantizing in a flat region does not hurt the loss.
◮ SWALP converges arbitrarily close to the optimal solution.
◮ Special relevance to new GPU architectures.
[Figure: representable points in low precision, the low-precision SGD (SGD-LP) trajectory, and the SWALP solution obtained by averaging the low-precision iterates.]
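A minimal sketch of the underlying idea (not the authors' released code): low-precision weight updates via stochastic rounding, with the weight average accumulated in full precision. The bit width and scale are illustrative assumptions.

    import torch

    def quantize_stochastic(x, bits=8, scale=2 ** -6):
        # Fixed-point quantization with stochastic rounding: round up with
        # probability equal to the fractional part, down otherwise.
        levels = x / scale
        floor = torch.floor(levels)
        rounded = floor + (torch.rand_like(x) < (levels - floor)).float()
        max_level = 2 ** (bits - 1) - 1
        return torch.clamp(rounded, -max_level - 1, max_level) * scale

    def swalp_step(w_lp, grad, w_avg, n_avg, lr=0.01):
        # One low-precision SGD step, then an update of the running average,
        # which is kept in full precision and can fall between grid points.
        w_lp = quantize_stochastic(w_lp - lr * grad)
        n_avg += 1
        w_avg = w_avg + (w_lp - w_avg) / n_avg
        return w_lp, w_avg, n_avg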
◮ Develop optimization procedures which provide better generalization and good calibration
◮ Develop scalable approaches to Bayesian deep learning, which both provide calibrated uncertainty and better point predictions
◮ Run GPs on millions of points in seconds, vs. thousands of points in hours.
◮ Outperforms stand-alone deep neural networks by learning deep kernels.
◮ Approach accelerated by kernel approximations which admit fast matrix-vector multiplications.
◮ Harmonizes with GPU acceleration.
◮ O(n) training and O(1) testing (instead of O(n^3) training and O(n^2) testing).
◮ Implemented in our new library GPyTorch: gpytorch.ai (a minimal usage sketch follows)
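A minimal GPyTorch regression sketch on toy data, using an exact GP with an RBF kernel; the data and hyperparameters are illustrative:

    import torch
    import gpytorch

    class ExactGPModel(gpytorch.models.ExactGP):
        def __init__(self, train_x, train_y, likelihood):
            super().__init__(train_x, train_y, likelihood)
            self.mean_module = gpytorch.means.ConstantMean()
            self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

        def forward(self, x):
            return gpytorch.distributions.MultivariateNormal(
                self.mean_module(x), self.covar_module(x))

    # Toy 1D regression data.
    train_x = torch.linspace(0, 1, 100)
    train_y = torch.sin(train_x * 6.28) + 0.1 * torch.randn(100)

    likelihood = gpytorch.likelihoods.GaussianLikelihood()
    model = ExactGPModel(train_x, train_y, likelihood)
    mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

    model.train(); likelihood.train()
    for _ in range(50):
        optimizer.zero_grad()
        loss = -mll(model(train_x), train_y)   # marginal log likelihood objective
        loss.backward()
        optimizer.step()

Under the hood, GPyTorch performs these computations with iterative methods built on fast matrix-vector multiplications, which is what gives the scaling described above.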
◮ We derive kernels which have recurrent LSTM inductive biases, and apply them to autonomous driving tasks.
[Figure: vehicle trajectories plotted as East vs. North position (miles), colored by speed (mi/s).]
[Figure: predicted trajectories in vehicle coordinates, front distance vs. side distance (m).]