Learning with Differentiable Perturbed Optimizers - Quentin Berthet (PowerPoint presentation)


SLIDE 1

Learning with Differentiable Perturbed Optimizers

Quentin Berthet Youth in High-dimensions - ICTP - 2020

SLIDE 2
Q. Berthet, M. Blondel, O. Teboul, M. Cuturi, J.-P. Vert, F. Bach

Learning with Differentiable Perturbed Optimizers

Preprint: arXiv:2002.08676

SLIDE 3

[A lot of] Machine learning these days

Supervised learning: pairs of inputs/responses (Xi, yi), and a model gw

[Figure: CIFAR-10 images Xi passed through the model gw to produce scores θ = gw(Xi), compared to labels yi ('bird', 'deer', 'ship', 'horse', 'truck') through the loss L]

Goal: Optimize parameters w ∈ Rd of a function gw such that gw(Xi) ≈ yi:

min_w Σ_i L(gw(Xi), yi) .

Workhorse: first-order methods, based on ∇wL(gw(Xi), yi), computed by backpropagation.

Problem: What if these models contain nondifferentiable∗ operations?

Q.Berthet - ICTP - 2020 1/12

SLIDE 4

Discrete decisions in Machine learning

[Diagram: X → gw → θ → y∗(θ), compared to the response y through the loss L]

Examples: discrete operations (e.g. max, rankings) break autodifferentiation

  • θ = scores for k products, y∗ = vector of ranks, e.g. [5, 2, 4, 3, 1]
  • θ = edge costs, y∗ = shortest path between two points
  • θ = classification scores for each class, y∗ = one-hot vector

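As a minimal illustration (not from the slides) of why such discrete operations break autodifferentiation: a ranking operator is piecewise constant, so a small change in the scores θ leaves the ranks unchanged and the Jacobian is zero almost everywhere. A sketch, with illustrative function names:

```python
import numpy as np

def ranks(theta):
    """Map scores to ranks (1 = smallest, ..., k = largest): piecewise constant in theta."""
    return np.argsort(np.argsort(theta)) + 1

theta = np.array([2.0, -1.0, 0.5, 0.0, 3.0])
y_star = ranks(theta)            # a discrete output: array([4, 1, 3, 2, 5])

# An infinitesimal change in theta does not move the ranks:
# zero gradient almost everywhere, nothing for backpropagation to use.
same = np.array_equal(ranks(theta + 1e-6), y_star)
```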


SLIDE 6

Perturbed maximizer

Discrete decisions: optimizers of a linear program over C, the convex hull of Y ⊆ Rd:

F(θ) = max_{y∈C} ⟨y, θ⟩ ,  and  y∗(θ) = argmax_{y∈C} ⟨y, θ⟩ = ∇θF(θ) .

[Diagram: polytope C with the maximizing vertex y∗(θ) in direction θ]

Perturbed maximizer: average of solutions for inputs with noise εZ:

Fε(θ) = E[max_{y∈C} ⟨y, θ+εZ⟩] ,  y∗ε(θ) = E[y∗(θ+εZ)] = E[argmax_{y∈C} ⟨y, θ+εZ⟩] = ∇θFε(θ) .

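A minimal Monte Carlo sketch of the perturbed maximizer, assuming Gaussian noise Z and the unit simplex as C (function names are illustrative): averaging hard argmax solutions over perturbed inputs yields a smooth point in the interior of C.

```python
import numpy as np

rng = np.random.default_rng(0)

def argmax_vertex(theta):
    """y*(theta): the vertex of the simplex maximizing <y, theta> (nondifferentiable)."""
    y = np.zeros_like(theta)
    y[np.argmax(theta)] = 1.0
    return y

def perturbed_argmax(theta, eps=0.5, n_samples=10_000):
    """Monte Carlo estimate of y*_eps(theta) = E[y*(theta + eps Z)], Z ~ N(0, I)."""
    Z = rng.standard_normal((n_samples, theta.size))
    return np.mean([argmax_vertex(theta + eps * z) for z in Z], axis=0)

theta = np.array([1.0, 0.2, -0.3])
y_eps = perturbed_argmax(theta)   # lies strictly inside C, varies smoothly with theta
```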

SLIDE 7

Perturbed maximizer

Discrete decisions: optimizers of a linear program over C, the convex hull of Y ⊆ Rd.

[Diagram: polytope C showing the perturbed inputs θ+εZ, the perturbed solutions y∗(θ+εZ), and their average y∗ε(θ) in the interior of C]

Perturbed maximizer: average of solutions for inputs with noise εZ:

Fε(θ) = E[max_{y∈C} ⟨y, θ+εZ⟩] ,  y∗ε(θ) = E[y∗(θ+εZ)] = E[argmax_{y∈C} ⟨y, θ+εZ⟩] = ∇θFε(θ) .

SLIDE 8

Perturbed model

Model of optimal decision under uncertainty, Luce (1959), McFadden et al. (1973):

Y = argmax_{y∈C} ⟨y, θ + εZ⟩

follows a perturbed model with Y ∼ pθ(y), and expectation y∗ε(θ) = E_{pθ}[Y].

Perturb-and-MAP, Papandreou & Yuille (2011); Follow-the-Perturbed-Leader, Kalai & Vempala (2003).

[Figure: Warcraft features, costs, shortest path, and perturbed paths for ε = 0.5 and ε = 2.0]

  • Example. Over the unit simplex C = Δd with Gumbel noise Z, this is the Gibbs distribution:

Fε(θ) = ε log Σ_{i∈[d]} e^{θi/ε} ,  pθ(ei) ∝ exp(⟨θ, ei⟩/ε) ,  [y∗ε(θ)]i = e^{θi/ε} / Σ_{j∈[d]} e^{θj/ε} .
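This closed form can be checked numerically: the empirical frequency of the argmax under Gumbel perturbations matches the softmax. A rough sketch (names are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)

def softmax(theta, eps):
    """Closed form [y*_eps(theta)]_i = e^{theta_i/eps} / sum_j e^{theta_j/eps}."""
    x = theta / eps
    x = x - x.max()                      # shift for numerical stability
    p = np.exp(x)
    return p / p.sum()

def monte_carlo_argmax(theta, eps, n_samples=20_000):
    """Empirical distribution of argmax(theta + eps * G) over Gumbel samples G."""
    g = rng.gumbel(size=(n_samples, theta.size))
    winners = np.argmax(theta + eps * g, axis=1)
    return np.bincount(winners, minlength=theta.size) / n_samples

theta = np.array([1.0, 0.0, -0.5])
est = monte_carlo_argmax(theta, eps=1.0)
exact = softmax(theta, eps=1.0)          # the two should agree up to sampling error
```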

SLIDE 9

Properties

Link with regularization: εΩ = Fε∗ is a convex function with domain C, and

y∗ε(θ) = argmax_{y∈C} { ⟨y, θ⟩ − εΩ(y) } .

Consequence of duality and y∗ε(θ) = ∇θFε(θ); Ω is a generalized entropy.

[Figure: regularized solutions over C for ε = 0, tiny ε, small ε, and large ε]

Extreme temperatures. When ε → 0, y∗ε(θ) → y∗(θ) for a unique maximizer. When ε → ∞, y∗ε(θ) → argmin_y Ω(y). Nonasymptotic results.

  • Differentiability. Smoothness in the inputs, Jacobian as simple expectations.
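The two temperature limits can be seen concretely in the simplex/Gumbel case, where y∗ε has the softmax closed form (a sketch using that closed form rather than sampling; here argmin_y Ω(y) is the uniform point):

```python
import numpy as np

def softmax(theta, eps):
    """y*_eps(theta) on the simplex with Gumbel noise (closed form)."""
    x = theta / eps
    x = x - x.max()
    p = np.exp(x)
    return p / p.sum()

theta = np.array([2.0, 1.0, -1.0])

cold = softmax(theta, eps=1e-3)   # eps -> 0: concentrates on the argmax vertex y*(theta)
hot  = softmax(theta, eps=1e3)    # eps -> inf: tends to argmin of Omega, uniform here
```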

SLIDE 10

Learning and Fenchel-Young losses

Learning from Y1, . . . , Yn for a model pθ. Gibbs distribution ∝ exp(⟨θ, Y⟩): minimize the negative log-likelihood

LGibbs(θ; Y) = −(1/n) Σ_{i=1}^n ⟨θ, Yi⟩ + log Z(θ) .

Stochastic gradient and full (batch) gradient: moment matching

∇θLGibbs(θ; Yi) = EGibbs,θ[Y] − Yi ,  ∇θLGibbs(θ; Y) = EGibbs,θ[Y] − Ȳn .

Algorithmic challenge: replace by the perturbed model, Papandreou & Yuille (2011):

∇θLi(θ) = E_{pθ}[Y] − Yi = y∗ε(θ) − Yi .

This is the stochastic gradient of a modified functional in θ, not a log-likelihood:

Lε(θ; Y) = −(1/n) Σ_{i=1}^n ⟨θ, Yi⟩ + Fε(θ) .

Fenchel-Young loss, Blondel et al. (2019), with good properties (convexity, randomness).
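In the simplex/Gumbel case Fε is the log-sum-exp, so the per-example Fenchel-Young loss and its gradient y∗ε(θ) − Yi can be written down and verified by finite differences. A sketch with illustrative function names:

```python
import numpy as np

def fy_loss(theta, y, eps):
    """Per-example Fenchel-Young loss F_eps(theta) - <theta, y>,
    with F_eps(theta) = eps * log sum_i e^{theta_i/eps} (simplex + Gumbel case)."""
    m = theta.max()                                  # shift for numerical stability
    F = m + eps * np.log(np.exp((theta - m) / eps).sum())
    return F - theta @ y

def fy_grad(theta, y, eps):
    """Its gradient in theta: y*_eps(theta) - y = softmax(theta/eps) - y."""
    x = theta / eps
    p = np.exp(x - x.max())
    return p / p.sum() - y

theta = np.array([0.5, -0.2, 0.1])
y = np.array([0.0, 1.0, 0.0])          # observed one-hot response Y_i
eps = 0.7

g = fy_grad(theta, y, eps)
h = 1e-6
fd = np.array([(fy_loss(theta + h * e, y, eps) - fy_loss(theta - h * e, y, eps)) / (2 * h)
               for e in np.eye(3)])    # central finite differences, one coordinate at a time
```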

SLIDE 11

Learning with perturbations and F-Y losses

Within the same framework, possible to virtually bypass the optimization block

[Diagram: X → gw → θ → y∗ε(θ), compared to the response y through the loss L]

Easier to implement: no Jacobian of y∗ε is needed.

Population loss minimized at the ground truth for the perturbed generative model.


SLIDE 12

Learning with perturbations and F-Y losses

Within the same framework, possible to virtually bypass the optimization block

[Diagram: X → gw → θ → Fenchel-Young loss Lε(θ; y)]

Easier to implement: no Jacobian of y∗ε is needed.

Population loss minimized at the ground truth for the perturbed generative model.

SLIDE 13

Computations

Monte Carlo estimates. Perturbed maximizer and derivatives as expectations. For θ ∈ Rd and Z(1), . . . , Z(M) i.i.d. copies of Z, let y(ℓ) = y∗(θ + εZ(ℓ)). An unbiased estimate of y∗ε(θ) is

ȳε,M(θ) = (1/M) Σ_{ℓ=1}^M y(ℓ) .

[Diagram: polytope C with perturbed solutions y∗(θ + εZ) averaging to y∗ε(θ)]

Supervised learning: features Xi, model output θw = gw(Xi), prediction ypred = y∗ε(θw).

Stochastic gradient in w: ∇wFi(w) = Jwgw(Xi)⊤ (y∗ε(θw) − Yi) .

SLIDE 14

Computations

Stochastic gradient in w (doubly stochastic scheme): replacing the expectation by the Monte Carlo average,

∇wFi(w) = Jwgw(Xi)⊤ ( (1/M) Σ_{ℓ=1}^M y∗(θw + εZ(ℓ)) − Yi ) .
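A sketch of the doubly stochastic scheme for a linear model gw(X) = W X on the simplex, with Gaussian noise (the linear model and all names are illustrative assumptions): for gw(X) = W X the Jacobian-transpose product reduces to an outer product.

```python
import numpy as np

rng = np.random.default_rng(0)

def argmax_vertex(theta):
    """Hard solver y*(theta) over the simplex (one-hot of the argmax)."""
    y = np.zeros_like(theta)
    y[np.argmax(theta)] = 1.0
    return y

def doubly_stochastic_grad(W, X_i, Y_i, eps=0.5, M=100):
    """Stochastic F-Y gradient in w: J_w g_w(X_i)^T (ybar_{eps,M}(theta_w) - Y_i).
    Doubly stochastic: one data point i, M Monte Carlo perturbations."""
    theta = W @ X_i
    ybar = np.mean([argmax_vertex(theta + eps * rng.standard_normal(theta.size))
                    for _ in range(M)], axis=0)
    # For g_w(X) = W X, the chain rule gives an outer product with X_i.
    return np.outer(ybar - Y_i, X_i)

W = rng.standard_normal((3, 4))
X_i = rng.standard_normal(4)
Y_i = np.array([0.0, 1.0, 0.0])
G = doubly_stochastic_grad(W, X_i, Y_i)   # same shape as W, ready for an SGD step
```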

SLIDE 15

Experiments

Classification: CIFAR-10 dataset of images with 10 classes - Toy comparison

[Figure: CIFAR-10 images Xi, model gw, scores θ = gw(Xi), loss Lε against labels yi ('bird', 'deer', 'ship', 'horse', 'truck')]

Architecture: a vanilla CNN made of 4 convolutional and 2 fully connected layers. Training: 600 epochs with minibatches of size 32; study of the influence of M and ε.

[Plots: train accuracy (0.95 to 1.00) and loss over 600 epochs for perturbed Fenchel-Young with M = 1 and M = 1000, against a cross-entropy baseline; train and test loss as a function of ε from 10⁻⁴ to 10⁴]

SLIDE 16

Experiments

Learning from shortest paths: from 10k examples of Warcraft 96 × 96 RGB images representing 12 × 12 costs, with the matrix of shortest paths (Vlastelica et al., 2019).

[Figure: Warcraft features, costs, shortest path, and perturbed paths for ε = 0.5 and ε = 2.0]

Train a CNN for 50 epochs to learn the costs; evaluate recovery of optimal paths.

[Plots: shortest-path perfect accuracy (0% to 100%) and cost ratio to optimal (1.00 to 1.10) over 50 epochs, for perturbed FY, blackbox loss, and squared loss]

SLIDE 17

Thank you!