Direct Optimization (CSC2547)
Adamo Young, Dami Choi, Sepehr Abbasi Zadeh
Direct Optimization
- A way to obtain gradient estimates that directly optimize a non-differentiable objective.
- It first appeared in structured prediction problems.
Structured Prediction
Whenever the outputs to be predicted have inter-dependencies (e.g. sequences, alignments, graphs).
Images from Wikipedia and http://dbmsnotes-ritu.blogspot.com/
Structured Prediction
- Scoring function s(x, y; w), with the output y discrete
- Inference: y_w = argmax_y s(x, y; w)
- Training: choose w to minimize the expected loss L(y*, y_w)
Gradient Estimator
- Gradient descent on the discrete y_w is not directly possible: the argmax is piecewise constant in w, so its gradient is zero almost everywhere.
- Option 1: continuous relaxation
- Option 2: estimate the gradient of the expected loss directly
Loss Gradient Theorem (McAllester et al., 2010; Song et al., 2016)
- Inference: y_w = argmax_y s(x, y; w)
- Loss-augmented inference: y_direct = argmax_y [ s(x, y; w) ± ε L(y*, y) ]
- Then, under mild conditions:
  ∇_w E[L(y*, y_w)] = ± lim_{ε→0} (1/ε) ( ∇_w s(x, y_direct; w) − ∇_w s(x, y_w; w) )
- The + sign moves “away from worse” configurations; the − sign moves “towards better” ones.
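To make the theorem concrete, here is a hedged toy sketch (illustrative only; the one-hot feature setup and ε value are assumptions, not from the slides): a linear scorer over three discrete outputs, where the finite-ε "towards better" estimate is just a scaled difference of feature vectors.

```python
import numpy as np

# Toy direct loss minimization ("towards better" variant), a sketch only.
# Output space: y in {0, 1, 2}; one-hot features phi(y); linear score w . phi(y).
w = np.array([1.0, 0.95, 0.9])            # current parameters (score of each y)
y_star = 2                                 # ground-truth output
loss = np.array([1.0, 1.0, 0.0])           # task loss L(y*, y): 0 iff correct
eps = 0.2

def phi(y):
    e = np.zeros(3)
    e[y] = 1.0
    return e

score = w                                   # phi is one-hot, so score(y) = w[y]
y_w = int(np.argmax(score))                 # inference: plain argmax
y_direct = int(np.argmax(score - eps * loss))  # loss-augmented ("towards better")

# Finite-eps estimate of the loss gradient:
# grad ~= -(phi(y_direct) - phi(y_w)) / eps
grad = -(phi(y_direct) - phi(y_w)) / eps
print(y_w, y_direct, grad)  # gradient descent raises the score of y_direct
```

Gradient descent with this estimate increases the score of the low-loss competitor y_direct = 2 and decreases the score of the current argmax y_w = 0, exactly the "towards better" behavior.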
Limitations
- Existence of the limit as ε → 0 is assumed
  ○ Bias/variance trade-off in the choice of a finite ε
- Solving the argmax of loss-augmented inference can be expensive
Applications
- Phoneme-to-speech alignment (McAllester et al. 2010)
- Maximizing average precision for ranking (Song et al. 2016)
- Discrete structured VAE (Lorberbom et al. 2018)
- RL with discrete action spaces (Lorberbom et al. 2019)
Direct Optimization through arg max for Discrete Variational Auto-Encoder
Guy Lorberbom, Andreea Gane, Tommi Jaakkola, Tamir Hazan
Probability Background
- Gumbel Distribution
- Various sampling “tricks”
  ○ Reparameterization
  ○ Gumbel-Max
  ○ Gumbel-Softmax
Gumbel Distribution
Intuitively: the distribution of the maximum of a large number of i.i.d. samples (e.g. normally distributed).
https://en.wikipedia.org/wiki/Gumbel_distribution
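As a quick numerical illustration (a sketch, not from the slides), a standard Gumbel(0, 1) sample can be drawn with the inverse-CDF trick:

```python
import numpy as np

rng = np.random.default_rng(0)

# Inverse-CDF trick: if U ~ Uniform(0, 1), then -log(-log(U)) ~ Gumbel(0, 1).
u = rng.uniform(size=200_000)
g = -np.log(-np.log(u))

# Sanity check: the mean of a standard Gumbel is the Euler-Mascheroni
# constant (~0.5772) and its variance is pi^2 / 6 (~1.6449).
print(g.mean(), g.var())
```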
Gradient Estimators for Stochastic Computation Graphs (Schulman et al 2016)
Legend: dot = parameter node; rectangle = deterministic node; circle = stochastic node; line = functional dependency; red line = gradient propagation.
Reparameterization Trick (Kingma et al 2015)
Reparameterization vs. REINFORCE/REBAR/RELAX (Williams 1988; Tucker et al 2017; Grathwohl et al 2018)
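A minimal sketch of the Gaussian reparameterization trick (illustrative only; the mu/sigma values are made up): sampling z ~ N(mu, sigma^2) as z = mu + sigma * eps moves the randomness into eps, so the sample is a differentiable function of the parameters.

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma = 2.0, 0.5

# Reparameterization: z = mu + sigma * eps with eps ~ N(0, 1).
# The sample path is a deterministic, differentiable function of
# (mu, sigma): dz/dmu = 1 and dz/dsigma = eps, so gradients flow through z.
eps = rng.standard_normal(200_000)
z = mu + sigma * eps

print(z.mean(), z.std())  # ~2.0 and ~0.5
```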
Gumbel-Max Trick
REINFORCE/REBAR/RELAX, Direct Optimization
Gumbel-Softmax Trick
REINFORCE/REBAR/RELAX, CONCRETE (Jang et al 2017; Maddison et al 2017)
Gumbel-Softmax Distribution (Jang et al 2017)
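Both tricks can be sketched in a few lines (a hedged illustration; the probabilities and temperatures are arbitrary choices): the Gumbel-max trick draws exact categorical samples via an argmax, and Gumbel-softmax replaces the argmax with a temperature-tau softmax to get a differentiable relaxation.

```python
import numpy as np

rng = np.random.default_rng(0)
p = np.array([0.5, 0.3, 0.2])
logits = np.log(p)
n = 200_000

# Gumbel-Max: argmax_i (log p_i + G_i), with G_i ~ Gumbel(0, 1), is an
# exact categorical sample with probabilities p.
g = rng.gumbel(size=(n, 3))
samples = np.argmax(logits + g, axis=1)
freqs = np.bincount(samples, minlength=3) / n
print(freqs)  # close to [0.5, 0.3, 0.2]

# Gumbel-Softmax: replace the (non-differentiable) argmax with a softmax
# at temperature tau; as tau -> 0 the relaxed sample approaches one-hot.
def gumbel_softmax(logits, g, tau):
    y = (logits + g) / tau
    y = y - y.max(axis=-1, keepdims=True)   # numerical stability
    e = np.exp(y)
    return e / e.sum(axis=-1, keepdims=True)

soft = gumbel_softmax(logits, g[0], tau=1.0)
hard = gumbel_softmax(logits, g[0], tau=0.01)
print(soft, hard)  # lower tau sharpens the sample towards one-hot
```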
Why discrete latent variables?
- Stronger inductive bias
- Interpretability
- Allow structural relations in the encoder
Standard (Gaussian) VAE (Kingma et al 2013)
Naive Categorical VAE
We can apply standard gradient estimators (REINFORCE/REBAR/RELAX)
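A hedged sketch of the simplest such estimator on a categorical latent (illustrative; the logits and objective values are made up): REINFORCE estimates the gradient of E[f(z)] with respect to the softmax logits, and on a small problem it can be checked against the exact gradient by enumeration.

```python
import numpy as np

rng = np.random.default_rng(0)

theta = np.array([0.2, -0.1, 0.3])        # logits of q(z)
f = np.array([0.0, 0.5, 1.0])             # an arbitrary objective f(z)

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

p = softmax(theta)

# Exact gradient of E[f(z)] w.r.t. logits: dE/dtheta_j = p_j * (f_j - E[f]).
exact = p * (f - p @ f)

# REINFORCE: grad ~= mean over samples of f(z) * grad_theta log q(z),
# where grad_theta log q(z = i) = onehot(i) - p.
n = 200_000
z = rng.choice(3, size=n, p=p)
onehot = np.eye(3)[z]
estimate = (f[z][:, None] * (onehot - p)).mean(axis=0)

print(exact, estimate)  # agree up to Monte Carlo noise
```

The estimator is unbiased but noisy, which is why variance-reduced variants like REBAR and RELAX exist.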
Gumbel-Max VAE + Direct Optimization
Algorithm:
1) Sample Gumbel noise for the categorical latent
2) Compute the argmax and the ε-perturbed (loss-augmented) argmax
3) Estimate the gradient from their difference, as in the loss gradient theorem
Structured Encoder
- No structure: independent categorical latents.
- Pairwise relationships: solve the argmax with QIP (e.g. CPLEX) or Max-Flow.
- Not practical with Gumbel-Softmax: exponential number of terms to sum over in the denominator.
Structured encoders may help.
Gradient Bias-Variance Tradeoff
Compared: Direct Gumbel-Max VAE (with its parameter epsilon) vs. Gumbel-Softmax VAE (with its parameter tau).
Direct Gumbel-Max VAE trains faster (K = 10).
VAE Comparison
Standard (Gaussian): + unbiased, low-variance gradients; - continuous latent variables; - limited structural relations
Gumbel-Softmax: + discrete latent variables; - biased gradients; - limited structural relations; - extra parameter (tau)
Naive Categorical + standard gradient estimator: + discrete latent variables; + unbiased gradients; - limited structural relations
Gumbel-Max + Direct: + discrete latent variables; + allows structural relations; - biased gradients; - extra parameter (epsilon); - optimization subproblem to get gradients
Direct Policy Gradients: Direct Optimization of Policies in Discrete Action Spaces
Guy Lorberbom, Chris J. Maddison, Nicolas Heess, Tamir Hazan, Daniel Tarlow
Reinforcement Learning
The agent sends actions to the environment; the environment returns a reward and the next state.
Goal: maximize cumulative reward.
Policy Gradient Method
- Goal: maximize the expected cumulative reward J(θ) = E_{τ∼π_θ}[R(τ)]
- Want: ∇_θ J
- REINFORCE: ∇_θ J = E[ R(τ) ∇_θ log π_θ(τ) ]
- Direct Policy Gradient: estimate ∇_θ J from a pair of trajectories, one from perturbed inference and one from reward-augmented inference, in analogy with the loss gradient theorem.
State Reward Tree
- Tree of all possible trajectories (fix the seed of the environment).
- Separates environment stochasticity from policy stochasticity.
- Given the tree, we can sample trajectories.
Reparameterize the Policy
- Instead of sampling per time step, we sample per trajectory: given action sequences a, define the trajectory log-probability log π_θ(a).
Gumbel-max reparameterization
- Let G(a) ∼ Gumbel(log π_θ(a)) for each trajectory a, and let a_opt = argmax_a G(a).
- Then, under this reparameterization, a_opt is distributed according to π_θ.
Structured Prediction vs. RL
- Discrete configurations: structured outputs y ↔ trajectories (action sequences) a
- Scoring function: s(x, y; w) ↔ log π_θ(a)
- Loss: L(y*, y) ↔ negative reward −R(a)
- Inference: argmax of the (perturbed) score ↔ a_opt = argmax_a G(a)
- Loss-augmented inference: argmax of [ s ± ε L ] ↔ a_direct = argmax_a [ G(a) + ε R(a) ]
Direct Policy Gradient (DirPG)
Algorithm
For every training step:
1. Sample Gumbel noise and obtain a_opt (the perturbed-argmax trajectory)
2. Obtain a_direct (the reward-augmented argmax) ⇐ How to obtain this?
3. Compute gradients from the pair
Solution: A* sampling (Maddison et al., 2014)
Use heuristic search to find a trajectory with a direct objective better than that of a_opt.
Complete Algorithm
For every training step:
1. Sample Gumbel noise and compute a_opt
2. While the search budget is not exceeded:
   a. Obtain a candidate trajectory from heuristic search
   b. End the search if its direct objective exceeds that of a_opt
3. Compute gradients from the pair (a_opt, a_direct)
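One update step can be sketched on a toy problem (hedged illustration with several assumptions not from the slides: a 3-step binary-action bandit, full enumeration of all 8 trajectories in place of the lazy A* search, and an arbitrary reward):

```python
import numpy as np
from itertools import product

rng = np.random.default_rng(0)

T = 3                                      # horizon; two actions per step
theta = np.zeros((T, 2))                   # per-step logits of the policy
target = np.array([1, 0, 1])               # reward = # actions matching target
eps = 1.0

def log_pi(actions):
    # log-probability of an action sequence under per-step softmax policies
    logp = theta - np.logaddexp(theta[:, 0], theta[:, 1])[:, None]
    return sum(logp[t, a] for t, a in enumerate(actions))

trajs = list(product([0, 1], repeat=T))    # all 8 trajectories (toy only)
g = rng.gumbel(size=len(trajs))            # one Gumbel per trajectory
perturbed = np.array([log_pi(a) + g[i] for i, a in enumerate(trajs)])
reward = np.array([sum(int(a == t) for a, t in zip(tr, target)) for tr in trajs])

a_opt = int(np.argmax(perturbed))                    # sampled trajectory
a_direct = int(np.argmax(perturbed + eps * reward))  # reward-augmented argmax

# By construction, a_direct's direct objective is at least a_opt's,
# and its reward is at least a_opt's reward.
obj = perturbed + eps * reward
print(trajs[a_opt], trajs[a_direct], obj[a_direct] - obj[a_opt])
```

The parameter update would then follow the direction ∇_θ[log π_θ(a_direct) − log π_θ(a_opt)] / ε, shifting probability mass toward the higher-reward trajectory.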
Limitations
- Must be able to reset the environment to previously visited states.
- The search terminates on the first improvement.
Combinatorial bandits
The number of trajectories searched to find an improving trajectory increases as training progresses for combinatorial bandits.
MiniGrid
Comparisons between different heuristics for DirPG and REINFORCE on MiniGrid.
MiniGrid
Evidence of “pulling up” on MiniGrid.
Related Work
- Gradient Estimators
  ○ REINFORCE (Williams 1988)
  ○ REBAR (Tucker et al 2017)
  ○ RELAX (Grathwohl et al 2018)
  ○ Gumbel-Softmax (Jang et al 2017; Maddison et al 2017)
- Discrete Deep Generative Models
  ○ VQ-VAE (Oord et al 2017)
  ○ Discrete VAE (Rolfe 2017)
  ○ Gumbel-Sinkhorn (Mena et al 2018)
- Reinforcement Learning
Top-Down sampling using A* Sampling
Non-starters
- Compute the perturbed objective for all possible trajectories (exponentially many).
- Roll out many trajectories and select the best.
Gumbel Process
We know: the maximum of independent Gumbel perturbations over a set of configurations is itself Gumbel-distributed, with location given by the log of the set's total (unnormalized) probability mass (max-stability).
Therefore: we can sample the maximum for a whole region first, then recursively partition the region (e.g. into subsets A and B) and sample the maxima of the parts, conditioned on the already-sampled value.
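The max-stability property is easy to check numerically (a sketch; the probabilities are an arbitrary choice): the max of Gumbel(log p_i) variables is distributed as Gumbel(log Σ p_i).

```python
import numpy as np

rng = np.random.default_rng(0)
p = np.array([0.5, 0.3, 0.2])            # sums to 1, so logsumexp(log p) = 0
log_p = np.log(p)

# Perturb each log-probability with independent Gumbel noise, take the max.
n = 200_000
m = (log_p + rng.gumbel(size=(n, 3))).max(axis=1)

# Max-stability: m ~ Gumbel(logsumexp(log_p)) = Gumbel(0), whose mean is
# the Euler-Mascheroni constant (~0.5772).
print(m.mean())
```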
Trajectory Generation
- Lazily create partitions of trajectories.
- Recursion rule:
  ○ For the action along the sampled (max) path, copy the parent node’s value.
  ○ For the remaining choices of actions, group them and compute a truncated value.
- Repeat until a terminating state is found.
- Yield the trajectory and its value.
(The slides animate an example tree with node values 1.3, 1.3, 1.1, 1.3, 0.19.)
Recall the goal: find a trajectory with a large direct objective. How to prioritize the search?
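The "truncated value" in the recursion above can be sketched as follows (a hedged illustration based on the A* sampling construction; the bound value 1.3 echoes the slide's example tree): a Gumbel(mu) sample conditioned to lie below a bound b is obtained from an unconditioned Gumbel sample.

```python
import numpy as np

rng = np.random.default_rng(0)

def truncated_gumbel(mu, b, size):
    """Sample Gumbel(mu) conditioned on being <= b."""
    g = mu + rng.gumbel(size=size)            # unconditioned Gumbel(mu)
    return -np.log(np.exp(-g) + np.exp(-b))   # exact truncation at b

# Children of a node whose max value is b get truncated samples below b,
# while the child on the argmax path copies b itself.
b = 1.3
t = truncated_gumbel(mu=0.0, b=b, size=200_000)
print(t.max())  # never exceeds the parent's value b

# Check against the analytic CDF of a Gumbel(0) truncated at b:
# P(T <= x) = exp(-(e^{-x} - e^{-b})) for x <= b.
x = 0.0
expected = np.exp(-(np.exp(-x) - np.exp(-b)))
print((t <= x).mean(), expected)
```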
Search for trajectories with a large direct objective using A* Sampling
- Lower bound on accumulated reward (L)
- Upper bound on reward-to-go (U)
- In practice: combine the bounds (e.g. L + U) as the search priority