Learning with Differentiable Perturbed Optimizers, Quentin Berthet


  1. Learning with Differentiable Perturbed Optimizers. Quentin Berthet, Youth in High-dimensions, ICTP, 2020.

  2. Q. Berthet, M. Blondel, O. Teboul, M. Cuturi, J-P. Vert, F. Bach. Learning with Differentiable Perturbed Optimizers. Preprint: arXiv:2002.08676

  3. [A lot of] Machine learning these days. Supervised learning: pairs of inputs/responses $(X_i, y_i)$ and a model $g_w$.
     [Figure: example images $X_i$ with labels $y_i$ such as 'deer', 'ship', 'bird', 'horse', 'truck'; the model output $\theta = g_w(X_i)$ is compared with $y_i$ through a loss $L$.]
     Goal: optimize the parameters $w \in \mathbb{R}^d$ of a function $g_w$ such that $g_w(X_i) \approx y_i$, i.e. $\min_w \sum_i L(g_w(X_i), y_i)$.
     Workhorse: first-order methods based on $\nabla_w L(g_w(X_i), y_i)$, computed by backpropagation.
     Problem: what if these models contain nondifferentiable operations?

  4. Discrete decisions in Machine learning.
     [Figure: pipeline $X \to g_w \to \theta \to y^*(\theta)$, compared with $y$ through the loss $L$.]
     Examples: discrete operations (e.g. max, rankings) break autodifferentiation.
     • $\theta$ = scores for $k$ products, $y^*$ = vector of ranks, e.g. $[5, 2, 4, 3, 1]$
     • $\theta$ = edge costs, $y^*$ = shortest path between two points
     • $\theta$ = classification scores for each class, $y^*$ = one-hot vector
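
To make "break autodifferentiation" concrete, here is a small sketch of two of the discrete maps above: the one-hot argmax and a rank vector. The function names are hypothetical, and the rank convention (rank 1 for the largest score) is an assumption chosen so that the slide's example $[5, 2, 4, 3, 1]$ is reproduced. A finite-difference check shows that the outputs are piecewise constant, so every derivative is zero away from ties.

```python
def one_hot_argmax(theta):
    """y*(theta) over the simplex: one-hot vector of the largest score."""
    k = max(range(len(theta)), key=lambda i: theta[i])
    return [1.0 if i == k else 0.0 for i in range(len(theta))]


def ranks(theta):
    """Rank vector, assuming rank 1 marks the largest score (hypothetical convention)."""
    order = sorted(range(len(theta)), key=lambda i: -theta[i])
    r = [0] * len(theta)
    for position, i in enumerate(order):
        r[i] = position + 1
    return r


def finite_difference(f, theta, h=1e-4):
    """Central finite differences of f at theta; row j holds df/dtheta_j."""
    jac = []
    for j in range(len(theta)):
        up = list(theta)
        up[j] += h
        down = list(theta)
        down[j] -= h
        jac.append([(a - b) / (2 * h) for a, b in zip(f(up), f(down))])
    return jac


theta = [0.2, 1.5, 0.7]
print(one_hot_argmax(theta))             # [0.0, 1.0, 0.0]
print(ranks([0.1, 0.8, 0.3, 0.5, 0.9]))  # [5, 2, 4, 3, 1]
# Piecewise constant output: all derivatives vanish away from ties.
print(finite_difference(one_hot_argmax, theta))
```

A gradient that is zero almost everywhere (and undefined at ties) gives backpropagation nothing to work with, which is the problem the perturbed maximizer addresses.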

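  6. Perturbed maximizer. Discrete decisions: optimizers of a linear program over $C$, the convex hull of $\mathcal{Y} \subseteq \mathbb{R}^d$:
     $y^*(\theta) = \operatorname{argmax}_{y \in C} \langle y, \theta \rangle$, $F(\theta) = \max_{y \in C} \langle y, \theta \rangle$, and $y^*(\theta) = \nabla_\theta F(\theta)$.
     [Figure: the polytope $C$, a direction $\theta$, and the maximizing vertex $y^*(\theta)$.]
     Perturbed maximizer: average of solutions for inputs with noise $\varepsilon Z$:
     $y^*_\varepsilon(\theta) = \mathbb{E}[y^*(\theta + \varepsilon Z)] = \mathbb{E}[\operatorname{argmax}_{y \in C} \langle y, \theta + \varepsilon Z \rangle]$,
     $F_\varepsilon(\theta) = \mathbb{E}[\max_{y \in C} \langle y, \theta + \varepsilon Z \rangle]$, and $y^*_\varepsilon(\theta) = \nabla_\theta F_\varepsilon(\theta)$.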
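
The identity $y^*_\varepsilon(\theta) = \nabla_\theta F_\varepsilon(\theta)$ can be checked numerically. The sketch below assumes $C$ is the unit simplex (so $y^*$ is the one-hot argmax) and $Z$ is standard Gaussian; both choices, and all function names, are illustrative assumptions. It estimates $F_\varepsilon$ and $y^*_\varepsilon$ with the same fixed noise draws (common random numbers) and compares central finite differences of the estimated $F_\varepsilon$ with the estimated $y^*_\varepsilon$.

```python
import random


def argmax_one_hot(v):
    k = max(range(len(v)), key=lambda i: v[i])
    return [1.0 if i == k else 0.0 for i in range(len(v))]


def sample_noise(d, num_samples, seed=0):
    """Fixed standard Gaussian draws Z^(1..M), reused across calls (common random numbers)."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(num_samples)]


def perturbed_max(theta, eps, noise):
    """Monte Carlo estimate of F_eps(theta) = E[max_i (theta_i + eps * Z_i)]."""
    return sum(max(t + eps * z for t, z in zip(theta, zs)) for zs in noise) / len(noise)


def perturbed_argmax(theta, eps, noise):
    """Monte Carlo estimate of y*_eps(theta) = E[y*(theta + eps * Z)]."""
    total = [0.0] * len(theta)
    for zs in noise:
        y = argmax_one_hot([t + eps * z for t, z in zip(theta, zs)])
        total = [a + b for a, b in zip(total, y)]
    return [a / len(noise) for a in total]


theta, eps = [0.3, 0.5, -0.2], 0.5
noise = sample_noise(len(theta), 1000)
y_eps = perturbed_argmax(theta, eps, noise)

# Gradient identity y*_eps = grad F_eps, checked by central differences.
h = 1e-7
grad = []
for j in range(len(theta)):
    up = list(theta)
    up[j] += h
    down = list(theta)
    down[j] -= h
    grad.append((perturbed_max(up, eps, noise) - perturbed_max(down, eps, noise)) / (2 * h))
print(y_eps)
print(grad)
```

With fixed noise, the estimated $F_\varepsilon$ is piecewise linear and its slope in coordinate $j$ is exactly the fraction of draws whose argmax is $j$, so the two printed vectors agree up to floating-point error.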

  7. Perturbed maximizer (build of slide 6, same definitions).
     [Figure: a perturbed input $\theta + \varepsilon Z$ with its maximizer $y^*(\theta + \varepsilon Z)$, and the perturbed maximizer $y^*_\varepsilon(\theta)$ lying inside $C$, next to $y^*(\theta)$.]

  8. Perturbed model. Model of optimal decision under uncertainty, Luce (1959), McFadden et al. (1973):
     $Y = \operatorname{argmax}_{y \in C} \langle y, \theta + \varepsilon Z \rangle$.
     It follows a perturbed model $Y \sim p_\theta(y)$, with expectation $y^*_\varepsilon(\theta) = \mathbb{E}_{p_\theta}[Y]$. Related: Perturb-and-MAP, Papandreou & Yuille (2011); Follow the Perturbed Leader, Kalai & Vempala (2003).
     [Figure: features, costs, shortest path, and perturbed paths for $\varepsilon = 0.5$ and $\varepsilon = 2.0$.]
     Example. Over the unit simplex $C = \Delta^d$ with Gumbel noise $Z$, this is the Gibbs distribution:
     $p_\theta(e_i) \propto \exp(\langle \theta, e_i \rangle / \varepsilon)$, $F_\varepsilon(\theta) = \varepsilon \log \sum_{i \in [d]} e^{\theta_i / \varepsilon}$, $[y^*_\varepsilon(\theta)]_i = e^{\theta_i / \varepsilon} / \sum_{j} e^{\theta_j / \varepsilon}$.
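
The Gumbel example can be verified by simulation. This sketch (function names hypothetical) samples standard Gumbel noise as $-\log E$ with $E \sim \mathrm{Exp}(1)$, averages the one-hot argmax of $\theta + \varepsilon Z$, and compares it with the softmax closed form. It also compares the Monte Carlo average of the max with $F_\varepsilon$; with standard (uncentered) Gumbel noise the expected max carries an extra constant $\varepsilon\gamma$, where $\gamma$ is the Euler-Mascheroni constant, which does not affect gradients.

```python
import math
import random


def softmax(theta, eps):
    """Closed form [y*_eps(theta)]_i = exp(theta_i/eps) / sum_j exp(theta_j/eps)."""
    m = max(theta)
    w = [math.exp((t - m) / eps) for t in theta]
    s = sum(w)
    return [v / s for v in w]


def smoothed_max(theta, eps):
    """Closed form F_eps(theta) = eps * log sum_i exp(theta_i / eps), computed stably."""
    m = max(theta)
    return m + eps * math.log(sum(math.exp((t - m) / eps) for t in theta))


def gumbel_monte_carlo(theta, eps, num_samples, seed=0):
    """Average one-hot argmax and max of theta + eps * G with standard Gumbel G."""
    rng = random.Random(seed)
    counts = [0] * len(theta)
    max_total = 0.0
    for _ in range(num_samples):
        # -log(E) with E ~ Exp(1) is a standard Gumbel draw.
        perturbed = [t - eps * math.log(rng.expovariate(1.0)) for t in theta]
        k = max(range(len(theta)), key=lambda i: perturbed[i])
        counts[k] += 1
        max_total += perturbed[k]
    return [c / num_samples for c in counts], max_total / num_samples


theta, eps = [1.0, 0.5, -0.3], 1.0
y_mc, f_mc = gumbel_monte_carlo(theta, eps, 20000)
euler_gamma = 0.5772156649015329
print(softmax(theta, eps), y_mc)
# Uncentered Gumbel noise shifts the expected max by the constant eps * gamma.
print(smoothed_max(theta, eps) + eps * euler_gamma, f_mc)
```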

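  9. Properties. Link with regularization: $\varepsilon \Omega = F_\varepsilon^*$ is a convex function with domain $C$, and
     $y^*_\varepsilon(\theta) = \operatorname{argmax}_{y \in C} \{ \langle y, \theta \rangle - \varepsilon \Omega(y) \}$,
     a consequence of duality and of $y^*_\varepsilon(\theta) = \nabla_\theta F_\varepsilon(\theta)$. $\Omega$ acts as a generalized entropy.
     [Figure: $y^*_\varepsilon(\theta)$ for $\varepsilon = 0$, tiny $\varepsilon$, small $\varepsilon$ and large $\varepsilon$.]
     Extreme temperatures. When $\varepsilon \to 0$, $y^*_\varepsilon(\theta) \to y^*(\theta)$ if the maximizer is unique. When $\varepsilon \to \infty$, $y^*_\varepsilon(\theta) \to \operatorname{argmin}_y \Omega(y)$. Nonasymptotic results are also available.
     Differentiability. $y^*_\varepsilon$ is smooth in the inputs, with a Jacobian given by simple expectations.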
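
The two extreme temperatures are easy to see in the Gumbel case of slide 8, where $y^*_\varepsilon$ is the softmax and $\Omega(y) = \sum_i y_i \log y_i$ is the negative entropy on the simplex, whose minimizer is the uniform vector. The sketch below (names hypothetical) sweeps $\varepsilon$ from very small to very large.

```python
import math


def softmax(theta, eps):
    """y*_eps(theta) for Gumbel noise over the simplex (closed form of slide 8)."""
    m = max(theta)
    w = [math.exp((t - m) / eps) for t in theta]
    s = sum(w)
    return [v / s for v in w]


def neg_entropy(y):
    """Omega(y) = sum_i y_i log y_i, the regularizer paired with Gumbel noise."""
    return sum(v * math.log(v) for v in y if v > 0)


theta = [1.0, 0.5, -0.3]
for eps in [1e-3, 0.1, 1.0, 10.0, 1e3]:
    y = softmax(theta, eps)
    print(eps, [round(v, 3) for v in y], round(neg_entropy(y), 3))
```

As $\varepsilon$ shrinks the output approaches the one-hot $y^*(\theta)$; as it grows the output approaches the uniform vector, where $\Omega$ reaches its minimum $-\log 3$.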

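  10. Learning and Fenchel-Young losses. Learning from $Y_1, \dots, Y_n$ for a model $p_\theta$.
     Gibbs distribution $\propto \exp(\langle \theta, Y \rangle)$: minimize the negative log-likelihood
     $L_{\text{Gibbs}}(\theta; Y) = -\frac{1}{n} \sum_{i=1}^{n} \langle \theta, Y_i \rangle + \log Z(\theta)$, with $Z(\theta)$ the partition function.
     Stochastic gradient and full (batch) gradient, i.e. moment matching:
     $\nabla_\theta L_{\text{Gibbs}}(\theta; Y_i) = \mathbb{E}_{\text{Gibbs},\theta}[Y] - Y_i$, $\nabla_\theta L_{\text{Gibbs}}(\theta; Y) = \mathbb{E}_{\text{Gibbs},\theta}[Y] - \bar{Y}_n$.
     Algorithmic challenge: the Gibbs expectation is generally hard to compute over combinatorial sets. Replace it by the perturbed model, Papandreou & Yuille (2011):
     $\nabla_\theta L_i(\theta) = \mathbb{E}_{p_\theta}[Y] - Y_i = y^*_\varepsilon(\theta) - Y_i$.
     This is the stochastic gradient of a modified functional in $\theta$, not a log-likelihood:
     $L_\varepsilon(\theta; Y) = -\frac{1}{n} \sum_{i=1}^{n} \langle \theta, Y_i \rangle + F_\varepsilon(\theta)$.
     It is a Fenchel-Young loss, Blondel et al. (2019), with good properties (convexity, randomness).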
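
For Gumbel noise over the simplex, the perturbed model is exactly the Gibbs distribution (slide 8), so $y^*_\varepsilon$ is the softmax and minimizing $L_\varepsilon$ performs moment matching. The sketch below uses hypothetical data: $n$ one-hot observations drawn with made-up class frequencies. It runs full-batch gradient descent with gradient $y^*_\varepsilon(\theta) - \bar{Y}_n$. For other noise distributions, the softmax would be replaced by a Monte Carlo estimate of $y^*_\varepsilon$.

```python
import math
import random


def softmax(theta, eps):
    m = max(theta)
    w = [math.exp((t - m) / eps) for t in theta]
    s = sum(w)
    return [v / s for v in w]


def fy_loss(theta, y_bar, eps):
    """L_eps(theta; Y) = F_eps(theta) - <theta, Y_bar>, with F_eps = eps * logsumexp(theta / eps)."""
    m = max(theta)
    f_eps = m + eps * math.log(sum(math.exp((t - m) / eps) for t in theta))
    return f_eps - sum(t * y for t, y in zip(theta, y_bar))


# Hypothetical data: n one-hot observations Y_i drawn with class frequencies p_true.
rng = random.Random(0)
p_true = [0.6, 0.3, 0.1]
n, d, eps = 5000, 3, 1.0
samples = rng.choices(range(d), weights=p_true, k=n)
y_bar = [samples.count(k) / n for k in range(d)]

# Full-batch gradient descent: grad L_eps = y*_eps(theta) - Y_bar (moment matching).
theta, lr = [0.0] * d, 1.0
for _ in range(500):
    grad = [a - b for a, b in zip(softmax(theta, eps), y_bar)]
    theta = [t - lr * g for t, g in zip(theta, grad)]

print(y_bar, softmax(theta, eps), fy_loss(theta, y_bar, eps))
```

At convergence the model's expectation $y^*_\varepsilon(\theta)$ matches the empirical mean $\bar{Y}_n$.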

  11. Learning with perturbations and F-Y losses. Within the same framework, it is possible to virtually bypass the optimization block.
     [Figure: pipeline $X \to g_w \to \theta \to y^*_\varepsilon(\theta)$, compared with $y$ through the loss $L$.]
     Easier to implement: no Jacobian of $y^*_\varepsilon$ is needed. The population loss is minimized at the ground truth for the perturbed generative model.

  12. Learning with perturbations and F-Y losses (build of slide 11, same text).
     [Figure: the loss $L_\varepsilon$ is applied directly to $\theta = g_w(X)$ and $y$, with no separate $y^*_\varepsilon$ block.]

  13. Computations. Monte Carlo estimates: the perturbed maximizer and its derivatives are expectations.
     For $\theta \in \mathbb{R}^d$ and $Z^{(1)}, \dots, Z^{(M)}$ i.i.d. copies of $Z$, let $y^{(\ell)} = y^*(\theta + \varepsilon Z^{(\ell)})$. An unbiased estimate of $y^*_\varepsilon(\theta)$ is given by
     $\bar{y}_{\varepsilon,M}(\theta) = \frac{1}{M} \sum_{\ell=1}^{M} y^{(\ell)}$.
     [Figure: sampled maximizers $y^*(\theta + \varepsilon Z)$ and their average $y^*_\varepsilon(\theta)$ inside $C$.]
     Supervised learning: features $X_i$, model output $\theta_w = g_w(X_i)$, prediction $y_{\text{pred}} = y^*_\varepsilon(\theta_w)$.
     Stochastic gradient in $w$: $\nabla_w F_i(w) = J_w g_w(X_i)^\top (y^*_\varepsilon(\theta_w) - Y_i)$.
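
A sketch of "derivatives as expectations", assuming standard Gaussian noise $Z \sim \mathcal{N}(0, I)$ and $C$ the simplex. For Gaussian noise, integration by parts gives $J_\theta y^*_\varepsilon(\theta) = \mathbb{E}[y^*(\theta + \varepsilon Z) Z^\top] / \varepsilon$, so the same $M$ samples yield both $\bar{y}_{\varepsilon,M}$ and a Jacobian estimate. In $d = 2$ there is a closed form, $[y^*_\varepsilon(\theta)]_1 = \Phi((\theta_1 - \theta_2) / (\varepsilon\sqrt{2}))$, which the code uses as a check. The function name is hypothetical.

```python
import math
import random


def argmax_one_hot(v):
    k = max(range(len(v)), key=lambda i: v[i])
    return [1.0 if i == k else 0.0 for i in range(len(v))]


def perturbed_argmax_and_jacobian(theta, eps, num_samples, seed=0):
    """Monte Carlo estimates of y*_eps(theta) and of its Jacobian for Gaussian Z.

    For Z ~ N(0, I), J_theta y*_eps(theta)[i][j] = E[y*_i(theta + eps Z) Z_j] / eps.
    """
    rng = random.Random(seed)
    d = len(theta)
    y_bar = [0.0] * d
    jac = [[0.0] * d for _ in range(d)]
    for _ in range(num_samples):
        z = [rng.gauss(0.0, 1.0) for _ in range(d)]
        y = argmax_one_hot([t + eps * zi for t, zi in zip(theta, z)])
        for i in range(d):
            y_bar[i] += y[i] / num_samples
            for j in range(d):
                jac[i][j] += y[i] * z[j] / (eps * num_samples)
    return y_bar, jac


# Check in d = 2, where y*_eps(theta)_1 = Phi((theta_1 - theta_2) / (eps * sqrt(2))).
theta, eps = [0.5, 0.0], 1.0
y_bar, jac = perturbed_argmax_and_jacobian(theta, eps, 50000)
x = (theta[0] - theta[1]) / (eps * math.sqrt(2.0))
p_exact = 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))
dp_exact = math.exp(-x * x / 2.0) / math.sqrt(2.0 * math.pi) / (eps * math.sqrt(2.0))
print(y_bar[0], p_exact)
print(jac[0][0], dp_exact)
```

Only evaluations of the discrete solver $y^*$ are needed; no derivative of the solver itself is ever computed.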

  14. Computations (build of slide 13, same Monte Carlo setup). Stochastic gradient in $w$, doubly stochastic scheme (one data point $X_i$ and $M$ noise draws):
     $\nabla_w F_i(w) \approx J_w g_w(X_i)^\top \Big( \frac{1}{M} \sum_{\ell=1}^{M} y^*(\theta_w + \varepsilon Z^{(\ell)}) - Y_i \Big)$, an unbiased estimate.
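
A minimal sketch of the doubly stochastic scheme, assuming a hypothetical linear model $\theta_w = W x$ (so $J_w g_w(x)^\top r$ is the outer product $r x^\top$), Gaussian noise, $C$ the simplex, and made-up toy data with three classes. Each step uses one example and $M$ fresh noise draws; only calls to the argmax are needed, never its Jacobian. The prototypes, learning rate and other settings are illustrative assumptions, not the settings of the experiments.

```python
import random


def argmax_one_hot(v):
    k = max(range(len(v)), key=lambda i: v[i])
    return [1.0 if i == k else 0.0 for i in range(len(v))]


def perturbed_argmax(theta, eps, num_samples, rng):
    """Fresh Monte Carlo estimate of y*_eps(theta) with Gaussian noise (M = num_samples)."""
    total = [0.0] * len(theta)
    for _ in range(num_samples):
        y = argmax_one_hot([t + eps * rng.gauss(0.0, 1.0) for t in theta])
        total = [a + b for a, b in zip(total, y)]
    return [a / num_samples for a in total]


# Hypothetical toy data: 3 classes around fixed 2-D prototypes, plus a bias feature.
rng = random.Random(0)
prototypes = [(3.0, 0.0), (-3.0, 0.0), (0.0, 3.0)]
data = []
for _ in range(300):
    k = rng.randrange(3)
    x = [prototypes[k][0] + rng.gauss(0, 0.5), prototypes[k][1] + rng.gauss(0, 0.5), 1.0]
    data.append((x, [1.0 if i == k else 0.0 for i in range(3)]))

# Linear model theta_w = W x: the gradient J_w g_w(x)^T r is the outer product r x^T.
W = [[0.0] * 3 for _ in range(3)]
eps, num_samples, lr = 1.0, 1, 0.05
for _ in range(20):
    for x, y_true in data:
        theta = [sum(w * xi for w, xi in zip(row, x)) for row in W]
        y_hat = perturbed_argmax(theta, eps, num_samples, rng)
        residual = [a - b for a, b in zip(y_hat, y_true)]
        for i in range(3):
            for j in range(3):
                W[i][j] -= lr * residual[i] * x[j]


def predict(x):
    """Unperturbed prediction y*(theta_w)."""
    return argmax_one_hot([sum(w * xi for w, xi in zip(row, x)) for row in W])


accuracy = sum(predict(x) == y for x, y in data) / len(data)
print(accuracy)
```

Even with $M = 1$ the gradient estimate is unbiased, which is what makes the scheme workable; larger $M$ only reduces its variance.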

  15. Experiments. Classification: CIFAR-10 dataset of images with 10 classes, a toy comparison.
     [Figure: example images $X_i$ labeled 'deer', 'ship', 'bird', 'horse', 'truck'; model output $\theta = g_w(X_i)$ trained with the loss $L_\varepsilon$.]
     Architecture: vanilla CNN made of 4 convolutional and 2 fully connected layers.
     Training: 600 epochs with minibatches of size 32; influence of $M$ and $\varepsilon$.
     [Figure: train and test accuracy and loss over 600 epochs for the perturbed Fenchel-Young loss with $M = 1$ and $M = 1000$, against a cross-entropy baseline; a further panel uses a logarithmic horizontal axis from $10^{-4}$ to $10^{4}$.]

  16. Experiments. Learning from shortest paths: from 10k examples of Warcraft 96 × 96 RGB images representing 12 × 12 costs, with the matrix of their shortest paths (Vlastelica et al., 2019).
     [Figure: features, costs, shortest path, and perturbed paths for $\varepsilon = 0.5$ and $\varepsilon = 2.0$.]
     Train a CNN for 50 epochs to learn the costs, evaluated by recovery of the optimal paths.
     [Figure, two panels over 50 epochs: shortest-path perfect accuracy (0% to 100%) and cost ratio to optimal (1.00 to 1.10), comparing the perturbed FY loss, the Blackbox loss and a squared loss.]
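
To illustrate what a "perturbed path" is, here is a sketch of $y^*$ for a grid of vertex costs and its perturbed average. It is a simplification, not the setup of the experiment: the path is assumed to run from the top-left to the bottom-right cell moving only right or down, so dynamic programming solves it even when noisy costs are negative. The 3 × 3 cost grid, Gaussian noise and function names are hypothetical.

```python
import random


def shortest_path(costs):
    """y*(theta) for a grid of vertex costs: 0/1 matrix of the cheapest monotone path.

    Simplifying assumption: moves go only right or down, from top-left to bottom-right.
    """
    n, m = len(costs), len(costs[0])
    best = [[0.0] * m for _ in range(n)]
    for i in range(n):
        for j in range(m):
            if i == 0 and j == 0:
                prev = 0.0
            elif i == 0:
                prev = best[i][j - 1]
            elif j == 0:
                prev = best[i - 1][j]
            else:
                prev = min(best[i - 1][j], best[i][j - 1])
            best[i][j] = costs[i][j] + prev
    # Backtrack from the bottom-right cell to recover the path.
    path = [[0.0] * m for _ in range(n)]
    i, j = n - 1, m - 1
    path[i][j] = 1.0
    while (i, j) != (0, 0):
        if i == 0 or (j > 0 and best[i][j - 1] <= best[i - 1][j]):
            j -= 1
        else:
            i -= 1
        path[i][j] = 1.0
    return path


def perturbed_path(costs, eps, num_samples, seed=0):
    """Average of shortest paths for costs + eps * Z, with Z standard Gaussian per cell."""
    rng = random.Random(seed)
    n, m = len(costs), len(costs[0])
    total = [[0.0] * m for _ in range(n)]
    for _ in range(num_samples):
        noisy = [[c + eps * rng.gauss(0.0, 1.0) for c in row] for row in costs]
        for i, row in enumerate(shortest_path(noisy)):
            for j, v in enumerate(row):
                total[i][j] += v / num_samples
    return total


costs = [[1.0, 9.0, 9.0],
         [1.0, 1.0, 9.0],
         [9.0, 1.0, 1.0]]
print(shortest_path(costs))
for eps in (0.5, 2.0):
    print(eps, [[round(v, 2) for v in row] for row in perturbed_path(costs, eps, 500)])
```

Larger $\varepsilon$ spreads the averaged path over more cells, which mirrors the qualitative difference between the $\varepsilon = 0.5$ and $\varepsilon = 2.0$ panels.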

  17. Thank you.
