Learning with Differentiable Perturbed Optimizers - Quentin Berthet
Youth in High-dimensions - ICTP - 2020

Q. Berthet, M. Blondel, O. Teboul, M. Cuturi, J-P. Vert, F. Bach
Learning with Differentiable Perturbed Optimizers
Preprint: arXiv:2002.08676
[A lot of] Machine learning these days

Supervised learning: couples of inputs/responses (Xi, yi), a model gw

[Figure: images Xi with labels 'bird', 'deer', 'ship', 'horse', 'truck', mapped by gw to scores θ = gw(Xi), compared to yi by a loss L]

Goal: optimize parameters w ∈ Rd of a function gw such that gw(Xi) ≈ yi

    min_w Σ_i L(gw(Xi), yi) .

Workhorse: first-order methods based on ∇w L(gw(Xi), yi), via backpropagation.
Problem: what if these models contain nondifferentiable∗ operations?
Q.Berthet - ICTP - 2020 1/12
Discrete decisions in Machine learning

[Figure: X mapped by gw to θ, then through the optimizer to y∗(θ), compared to y by a loss L]

Examples: discrete operations (e.g. max, rankings) break autodifferentiation

- θ = scores for k products, y∗ = vector of ranks, e.g. [5, 2, 4, 3, 1]
- θ = edge costs, y∗ = shortest path between two points
- θ = classification scores for each class, y∗ = one-hot vector
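The break in autodifferentiation can be seen directly: the argmax output is piecewise constant in θ, so its Jacobian is zero almost everywhere. A minimal numpy sketch for the one-hot case (illustration only, not from the slides):

```python
import numpy as np

def one_hot_argmax(theta):
    """y*(theta): vertex of the simplex maximizing <y, theta>."""
    y = np.zeros_like(theta)
    y[np.argmax(theta)] = 1.0
    return y

theta = np.array([1.0, 3.0, 2.0])
y = one_hot_argmax(theta)
# A small change in theta leaves the output unchanged,
# so the Jacobian of y* is zero almost everywhere (undefined at ties):
y_shift = one_hot_argmax(theta + 1e-6)
```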
Perturbed maximizer

Discrete decisions: optimizers of a linear program over C, the convex hull of Y ⊆ Rd

    F(θ) = max_{y∈C} ⟨y, θ⟩ ,  and  y∗(θ) = argmax_{y∈C} ⟨y, θ⟩ = ∇θF(θ) .

[Figure: polytope C, direction θ, and the maximizing vertex y∗(θ)]

Perturbed maximizer: average of solutions for inputs with noise εZ

    Fε(θ) = E[max_{y∈C} ⟨y, θ + εZ⟩] ,  y∗ε(θ) = E[y∗(θ + εZ)] = E[argmax_{y∈C} ⟨y, θ + εZ⟩] = ∇θFε(θ) .
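The identity y∗ε(θ) = ∇θFε(θ) can be checked numerically by Monte Carlo over the simplex vertices. A sketch, assuming Gaussian noise Z and common random numbers for the finite difference:

```python
import numpy as np

rng = np.random.default_rng(0)
theta, eps, M = np.array([1.0, 0.5, -0.5]), 1.0, 200_000
Z = rng.standard_normal((M, 3))   # fixed noise sample: common random numbers

def F_eps(t):
    # F_eps(t) = E[max_{y in simplex} <y, t + eps*Z>] = E[max_i (t + eps*Z)_i]
    return np.max(t + eps * Z, axis=1).mean()

def y_eps(t):
    # y*_eps(t): probability that each vertex attains the perturbed maximum
    idx = np.argmax(t + eps * Z, axis=1)
    return np.bincount(idx, minlength=3) / M

# finite-difference check of y*_eps(theta) = grad F_eps(theta), first coordinate
h, e0 = 1e-3, np.array([1.0, 0.0, 0.0])
fd = (F_eps(theta + h * e0) - F_eps(theta - h * e0)) / (2 * h)
```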
Perturbed model

Model of optimal decision under uncertainty, Luce (1959), McFadden et al. (1973)

    Y = argmax_{y∈C} ⟨y, θ + εZ⟩

Follows a perturbed model with Y ∼ pθ(y), expectation y∗ε(θ) = Epθ[Y].

Perturb-and-MAP, Papandreou & Yuille (2011); Follow The Perturbed Leader, Kalai & Vempala (2003)

[Figure: features, costs, shortest path, and perturbed paths for ε = 0.5 and ε = 2.0]

- Example. Over the unit simplex C = Δd with Gumbel noise Z: Gibbs distribution.

    Fε(θ) = ε log Σ_{i∈[d]} e^{θi/ε} ,  pθ(ei) ∝ exp(⟨θ, ei⟩/ε) ,  [y∗ε(θ)]i = e^{θi/ε} / Σ_{j∈[d]} e^{θj/ε}
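The closed form above for the simplex/Gumbel example is the softmax; a quick Monte Carlo check (a sketch, with illustrative values):

```python
import numpy as np

rng = np.random.default_rng(0)
theta, eps = np.array([1.0, 0.5, -0.5]), 1.0

def softmax(t):
    e = np.exp(t - t.max())
    return e / e.sum()

# Monte Carlo estimate of y*_eps(theta) = E[argmax over simplex vertices]
M = 200_000
Z = rng.gumbel(size=(M, 3))                  # Gumbel(0, 1) perturbations
idx = np.argmax(theta + eps * Z, axis=1)
y_mc = np.bincount(idx, minlength=3) / M
# y_mc agrees with the closed form softmax(theta / eps) up to Monte Carlo error
```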
Properties

- Link with regularization: Fε∗ = εΩ, a convex function with domain C, and

      y∗ε(θ) = argmax_{y∈C} { ⟨y, θ⟩ − εΩ(y) } .

  Consequence of duality and y∗ε(θ) = ∇θFε(θ). Ω is a generalized entropy.

[Figure: distributions y∗ε(θ) for ε = 0, tiny ε, small ε, large ε]

- Extreme temperatures. When ε → 0, y∗ε(θ) → y∗(θ) for a unique maximizer.
  When ε → ∞, y∗ε(θ) → argmin_{y∈C} Ω(y). Nonasymptotic results.

- Differentiability. Smoothness in the inputs, Jacobian as simple expectations.
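For the simplex/Gumbel case, where y∗ε(θ) has the softmax closed form, the two temperature limits can be checked numerically (a sketch):

```python
import numpy as np

def softmax(t):
    e = np.exp(t - t.max())
    return e / e.sum()

theta = np.array([1.0, 3.0, 2.0])

# eps -> 0: the perturbed maximizer recovers the argmax vertex
y_cold = softmax(theta / 1e-3)
# eps -> infinity: it collapses to the argmin of the entropy regularizer (uniform)
y_hot = softmax(theta / 1e3)
```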
Learning and Fenchel-Young losses

Learning from Y1, . . . , Yn for a model pθ. Gibbs distribution ∝ exp(⟨θ, Y⟩): minimize the negative log-likelihood

    LGibbs(θ; Y) = −(1/n) Σ_{i=1}^n ⟨θ, Yi⟩ + log Z(θ) .

Stochastic gradient and full (batch) gradient: moment matching

    ∇θLGibbs(θ; Yi) = EGibbs,θ[Y] − Yi ,   ∇θLGibbs(θ; Y) = EGibbs,θ[Y] − Ȳn .

Algorithmic challenge: replace by the perturbed model, Papandreou & Yuille (2011)

    ∇θLi(θ) = Epθ[Y] − Yi = y∗ε(θ) − Yi .

Stochastic gradient of a modified functional in θ, not a log-likelihood:

    Lε(θ) = −(1/n) Σ_{i=1}^n ⟨θ, Yi⟩ + Fε(θ) .

Fenchel-Young loss, Blondel et al. (2019), with good properties (convexity, randomness).
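The perturbed gradient y∗ε(θ) − Yi only needs calls to the solver itself. A sketch for the one-hot case with Gaussian noise (function names are hypothetical, not from the paper):

```python
import numpy as np

rng = np.random.default_rng(0)

def one_hot_argmax(t):
    y = np.zeros_like(t)
    y[np.argmax(t)] = 1.0
    return y

def fy_grad(theta, Y, eps=1.0, M=1000):
    """Stochastic gradient of the perturbed F-Y loss: y*_eps(theta) - Y."""
    Z = rng.standard_normal((M, theta.size))
    y_eps = np.mean([one_hot_argmax(theta + eps * z) for z in Z], axis=0)
    return y_eps - Y

theta = np.array([2.0, 0.0, 1.0])
Y = np.array([0.0, 0.0, 1.0])   # observed one-hot label
g = fy_grad(theta, Y)
# g is positive on over-scored coordinates and negative on the label coordinate,
# pushing theta toward the observed vertex; no Jacobian of y* is needed.
```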
Learning with perturbations and F-Y losses

Within the same framework, it is possible to virtually bypass the optimization block

[Figure: X mapped by gw to θ, then to y∗ε(θ), compared to y by the loss L]

- Easier to implement: no Jacobian of y∗ε required.
- Population loss minimized at the ground truth for the perturbed generative model.
Computations

Monte Carlo estimates. Perturbed maximizer and derivatives as expectations. For θ ∈ Rd, Z(1), . . . , Z(M) i.i.d. copies, y(ℓ) = y∗(θ + εZ(ℓ)). An unbiased estimate of y∗ε(θ) is given by

    ȳε,M(θ) = (1/M) Σ_{ℓ=1}^M y(ℓ) .

[Figure: polytope C with y∗(θ), perturbed solutions y∗(θ + εZ), and their average y∗ε(θ)]

Supervised learning: features Xi, model output θw = gw(Xi), prediction ypred = y∗ε(θw).

Stochastic gradient in w:

    ∇wFi(w) = Jwgw(Xi) · (y∗ε(θw) − Yi)
Computations

Stochastic gradient in w (doubly stochastic scheme), replacing y∗ε(θw) by its Monte Carlo estimate:

    ∇wFi(w) = Jwgw(Xi) · ( (1/M) Σ_{ℓ=1}^M y∗(θw + εZ(ℓ)) − Yi ) .
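One doubly stochastic update can be sketched for a toy linear model θw = W X, where applying Jwgw(Xi) reduces to an outer product (all names and dimensions hypothetical):

```python
import numpy as np

rng = np.random.default_rng(0)

def one_hot_argmax(t):
    y = np.zeros_like(t)
    y[np.argmax(t)] = 1.0
    return y

d, k, eps, M, lr = 5, 3, 1.0, 100, 0.1
W = rng.standard_normal((k, d))      # linear model: theta_w = W X
X = rng.standard_normal(d)
Y = np.array([1.0, 0.0, 0.0])        # observed one-hot decision

theta = W @ X                        # model scores
Z = rng.standard_normal((M, k))
y_bar = np.mean([one_hot_argmax(theta + eps * z) for z in Z], axis=0)
grad_W = np.outer(y_bar - Y, X)      # chain rule: (y_bar - Y) applied through J_w g_w(X)
W = W - lr * grad_W                  # one SGD step
```

The step moves the score of the observed vertex upward, since the coordinate of (ȳ − Y) on the label is nonpositive.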
Experiments

Classification: CIFAR-10 dataset of images with 10 classes - toy comparison.

[Figure: images Xi with labels 'bird', 'deer', 'ship', 'horse', 'truck', mapped by gw to scores θ = gw(Xi), compared to yi by the loss]

Architecture: vanilla CNN made of 4 convolutional and 2 fully connected layers.
Training: 600 epochs with minibatches of size 32 - influence of M and ε.

[Figure: train accuracy and loss over 600 epochs for perturbed Fenchel-Young with M = 1 and M = 1000 vs. a cross-entropy baseline; train/test loss as a function of ε from 10^-4 to 10^4 for M = 1 and M = 1000]
Experiments

Learning from shortest paths: from 10k examples of Warcraft 96 × 96 RGB images, representing 12 × 12 costs, and the matrix of shortest paths. (Vlastelica et al. '19)

[Figure: features, costs, shortest path, and perturbed paths for ε = 0.5 and ε = 2.0]

Train a CNN for 50 epochs to learn the costs and recover optimal paths.

[Figure: shortest-path perfect accuracy (0-100%) and cost ratio to optimal (1.00-1.10) over 50 epochs, for perturbed F-Y, blackbox loss, and squared loss]