SLIDE 1

Direct Optimization

CSC2547 Adamo Young, Dami Choi, Sepehr Abbasi Zadeh

SLIDE 2

Direct Optimization

  • A way to obtain gradient estimates that directly optimize a non-differentiable objective.
  • It first appeared in structured prediction problems.
SLIDE 3

Structured Prediction

Arises whenever the target output has interdependent components.

Images from Wikipedia and http://dbmsnotes-ritu.blogspot.com/

SLIDE 4

Structured Prediction

Scoring function $s_\theta(x, y)$, with discrete outputs $y \in \mathcal{Y}$.

SLIDE 5

Structured Prediction

Scoring function $s_\theta(x, y)$, with discrete outputs $y \in \mathcal{Y}$.

Inference: $y_\theta(x) = \arg\max_{y \in \mathcal{Y}} s_\theta(x, y)$

SLIDE 6

Structured Prediction

Scoring function $s_\theta(x, y)$, with discrete outputs $y \in \mathcal{Y}$.

Inference: $y_\theta(x) = \arg\max_{y \in \mathcal{Y}} s_\theta(x, y)$

Training: $\min_\theta \, \mathbb{E}\big[L(y^*, y_\theta(x))\big]$

SLIDES 7-10

Gradient Estimator

  • Gradient descent on the discrete $y_\theta(x)$: the loss $L(y^*, y_\theta(x))$ is piecewise constant in $\theta$, so its gradient is zero almost everywhere.
  • Option 1: continuous relaxation.
  • Option 2: estimate the gradient directly.
SLIDES 11-15

Loss Gradient Theorem (McAllester et al., 2010; Song et al., 2016)

Inference: $y_\theta(x) = \arg\max_{y \in \mathcal{Y}} s_\theta(x, y)$

Loss-augmented inference: $y_\epsilon(x) = \arg\max_{y \in \mathcal{Y}} \{ s_\theta(x, y) + \epsilon \, L(y^*, y) \}$

Theorem: under mild regularity conditions,
$\nabla_\theta \, \mathbb{E}[L(y^*, y_\theta(x))] = \lim_{\epsilon \to 0} \frac{1}{\epsilon} \, \mathbb{E}\big[\nabla_\theta s_\theta(x, y_\epsilon(x)) - \nabla_\theta s_\theta(x, y_\theta(x))\big]$

With $\epsilon > 0$ the update pushes the score “away from worse” configurations; with $\epsilon < 0$ it pulls “towards better” ones.
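A minimal numerical sketch of the finite-$\epsilon$ version of this estimator, assuming a linear scoring function over a small label set; the feature map `phi`, the Hamming loss, and the step size are illustrative choices, not taken from the papers:

```python
import numpy as np

NUM_LABELS = 3

def phi(x, y):
    # Toy joint feature map: place x in the block belonging to label y.
    f = np.zeros(NUM_LABELS * len(x))
    f[y * len(x):(y + 1) * len(x)] = x
    return f

def direct_loss_grad(theta, x, y_star, loss, eps=0.1):
    scores = [theta @ phi(x, y) for y in range(NUM_LABELS)]
    y_pred = int(np.argmax(scores))                           # inference
    aug = [scores[y] + eps * loss(y_star, y) for y in range(NUM_LABELS)]
    y_eps = int(np.argmax(aug))                               # loss-augmented inference
    # Finite-eps surrogate for the limit ("away from worse" direction).
    return (phi(x, y_eps) - phi(x, y_pred)) / eps

hamming = lambda a, b: float(a != b)
theta, x = np.zeros(NUM_LABELS * 4), np.random.randn(4)
theta -= 0.5 * direct_loss_grad(theta, x, y_star=1, loss=hamming)  # descent step
```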

SLIDE 16

Limitations

  • The limit must be approximated with a finite $\epsilon$:
    ○ Bias/variance trade-off
  • Solving the argmax of loss-augmented inference can be expensive.
SLIDES 17-18

Applications

  • Phoneme-to-speech alignment (McAllester et al. 2010)
  • Maximizing average precision for ranking (Song et al. 2016)
  • Discrete structured VAE (Lorberbom et al. 2018)
  • RL with discrete action spaces (Lorberbom et al. 2019)
SLIDE 19

Direct Optimization through arg max for Discrete Variational Auto-Encoder

Guy Lorberbom, Andreea Gane, Tommi Jaakkola, Tamir Hazan

SLIDE 20

Probability Background

  • Gumbel Distribution
  • Various Sampling “Tricks”
    ○ Reparameterization
    ○ Gumbel-Max
    ○ Gumbel-Softmax

SLIDE 21

Gumbel Distribution

Intuitively: the distribution of the maximum of a large number of i.i.d. samples (e.g., normally distributed ones).

https://en.wikipedia.org/wiki/Gumbel_distribution
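A quick simulation of this intuition, assuming only NumPy. The slide's picture uses normal samples; for a crisp numerical match this sketch uses exponential samples, whose recentred maxima converge to a standard Gumbel especially fast:

```python
import numpy as np

u = np.random.uniform(size=100_000)
gumbel = -np.log(-np.log(u))              # standard Gumbel via inverse-CDF

n = 1_000                                 # recentred max of n i.i.d. exponentials
maxes = np.random.exponential(size=(10_000, n)).max(axis=1) - np.log(n)

print(gumbel.mean(), maxes.mean())        # both near 0.577 (Euler-Mascheroni)
print(gumbel.std(), maxes.std())          # both near pi/sqrt(6), about 1.28
```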

SLIDES 22-23

Gradient Estimators for Stochastic Computation Graphs

Schulman et al 2015

Legend: dot = parameter node; rectangle = deterministic node; circle = stochastic node; line = functional dependency; red line = gradient propagation.

SLIDE 24

Reparameterization Trick

Kingma et al 2015
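A minimal PyTorch sketch of the trick for a diagonal Gaussian: the sample is a deterministic function of the parameters and parameter-free noise, so gradients flow through it (the downstream loss here is an arbitrary placeholder):

```python
import torch

mu = torch.zeros(4, requires_grad=True)
log_sigma = torch.zeros(4, requires_grad=True)

eps = torch.randn(4)                      # noise independent of the parameters
z = mu + log_sigma.exp() * eps            # z ~ N(mu, sigma^2), differentiable in both
loss = (z ** 2).sum()                     # placeholder downstream loss
loss.backward()                           # gradients reach mu and log_sigma
```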

SLIDE 25

Reparameterization Trick

Figure: gradient paths for REINFORCE/REBAR/RELAX (Williams 1988; Tucker et al 2017; Grathwohl et al 2018) vs. the reparameterized sampler.

SLIDES 26-27

Gumbel-Max Trick

Perturb the log-probabilities with i.i.d. Gumbel noise and take the argmax: for $g_i \sim \mathrm{Gumbel}(0, 1)$, $\arg\max_i \{\log \pi_i + g_i\} \sim \mathrm{Categorical}(\pi)$.

Figure: gradient paths for REINFORCE/REBAR/RELAX vs. direct optimization.
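A sanity-check sketch of the trick, assuming only NumPy: perturb the log-probabilities with Gumbel noise, take the argmax, and compare empirical frequencies with the target distribution:

```python
import numpy as np

pi = np.array([0.2, 0.5, 0.3])
g = -np.log(-np.log(np.random.uniform(size=(100_000, 3))))   # Gumbel(0, 1)
samples = (np.log(pi) + g).argmax(axis=1)                    # Gumbel-max samples
print(np.bincount(samples) / len(samples))                   # ~ [0.2, 0.5, 0.3]
```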

SLIDE 28

Gumbel-Softmax Trick

Figure: gradient paths for REINFORCE/REBAR/RELAX vs. the Concrete relaxation (Jang et al 2017; Maddison et al 2017).

SLIDE 29

Gumbel-Softmax Distribution

Jang et al 2017
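A minimal sketch of a Gumbel-Softmax (Concrete) sample: the hard argmax of the Gumbel-Max trick is replaced by a temperature-$\tau$ softmax, making the sample differentiable in the logits; $\tau \to 0$ recovers near-one-hot samples at the cost of gradient variance:

```python
import torch
import torch.nn.functional as F

def gumbel_softmax_sample(logits, tau):
    g = -torch.log(-torch.log(torch.rand_like(logits)))   # Gumbel(0, 1) noise
    return F.softmax((logits + g) / tau, dim=-1)          # relaxed one-hot sample

logits = torch.log(torch.tensor([0.2, 0.5, 0.3]))
print(gumbel_softmax_sample(logits, tau=0.1))             # near one-hot
print(gumbel_softmax_sample(logits, tau=5.0))             # smoothed toward uniform
```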

SLIDE 30

Why discrete latent variables?

  • Stronger inductive bias
  • Interpretability
  • Allow structural relations in the encoder
SLIDES 31-34

Standard (Gaussian) VAE

Kingma et al 2013
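A compact sketch of the standard VAE objective for one datapoint: a single-sample reparameterized estimate of the ELBO with a Bernoulli decoder. The linear encoder/decoder and the sizes are placeholders, not the paper's architecture:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

enc = nn.Linear(784, 2 * 8)                  # outputs (mu, log sigma), 8 latents
dec = nn.Linear(8, 784)

x = torch.rand(1, 784)
mu, log_sigma = enc(x).chunk(2, dim=1)
z = mu + log_sigma.exp() * torch.randn_like(mu)        # reparameterized sample
log_px = -F.binary_cross_entropy_with_logits(dec(z), x, reduction="sum")
kl = 0.5 * (mu ** 2 + (2 * log_sigma).exp() - 2 * log_sigma - 1).sum()
elbo = log_px - kl                                     # maximize this
(-elbo).backward()                                     # gradients via the reparam path
```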

SLIDES 35-38

Naive Categorical VAE

We can apply standard gradient estimators (REINFORCE/REBAR/RELAX).
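A sketch of the score-function (REINFORCE) route for a categorical latent, where no reparameterization path exists: the learning signal multiplies the log-probability of the sampled category. The `reward_like` function stands in for the per-sample ELBO term and is purely illustrative:

```python
import torch

logits = torch.zeros(10, requires_grad=True)        # q(z | x) for one datapoint

def reward_like(z):
    # Stand-in for log p(x, z) - log q(z | x) evaluated at the sample.
    return -((z.float() - 3.0) ** 2)

q = torch.distributions.Categorical(logits=logits)
z = q.sample()                                      # discrete, non-differentiable
surrogate = reward_like(z).detach() * q.log_prob(z) # REINFORCE surrogate
surrogate.backward()                                # unbiased, high-variance gradient
```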

SLIDE 39

Gumbel-Max VAE

SLIDES 40-42

Gumbel-Max VAE + Direct Optimization

Algorithm (a sketch follows below):
1) Sample Gumbel noise.
2) Compute the argmax and the loss-augmented argmax.
3) Estimate the gradient from their scaled score difference.
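A sketch of one such step for a single categorical latent, in the spirit of Lorberbom et al. (2018): perturb the logits with Gumbel noise, take the argmax and an $\epsilon$-loss-augmented argmax, and form the scaled difference of score gradients. `per_category_loss` is a hypothetical stand-in for the decoder loss at each category:

```python
import torch

K, eps = 10, 1.0
logits = torch.zeros(K, requires_grad=True)
per_category_loss = torch.randn(K)            # hypothetical f(z) per category

g = -torch.log(-torch.log(torch.rand(K)))     # 1) Gumbel(0, 1) noise
z_star = torch.argmax(logits + g)             # 2) Gumbel-max sample
z_eps = torch.argmax(logits + g + eps * per_category_loss)  # loss-augmented argmax

# 3) Finite-eps direct gradient estimate through the two argmax solutions.
surrogate = (logits[z_eps] - logits[z_star]) / eps
surrogate.backward()                          # gradient w.r.t. the logits
```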

SLIDES 43-45

Structured Encoder

No structure: the argmax decomposes over independent latent dimensions.

Pairwise relationships: solve the argmax as a quadratic integer program (e.g., with CPLEX) or via max-flow.

Not practical with Gumbel-Softmax: the normalizing denominator sums over an exponential number of terms.

SLIDE 46

Structured Encoder may help

SLIDE 47

Gradient Bias-Variance Tradeoff

Figure: direct Gumbel-Max VAE (sweeping its $\epsilon$) vs. Gumbel-Softmax VAE (sweeping its $\tau$).

SLIDE 48

Direct Gumbel-Max VAE trains faster

Figure: training curves, K = 10.

SLIDE 49

VAE Comparison

Standard (Gaussian):
  + Unbiased, low-variance gradients
  - Continuous latent variables
  - Limited structural relations

Gumbel-Softmax:
  + Discrete latent variables
  - Biased gradients
  - Limited structural relations
  - Extra parameter (tau)

Naive Categorical + standard gradient estimator:
  + Discrete latent variables
  + Unbiased gradients
  - Limited structural relations

Gumbel-Max + Direct:
  + Discrete latent variables
  + Allows structural relations
  - Biased gradients
  - Extra parameter (epsilon)
  - Optimization subproblem to get gradients

SLIDE 50

Direct Policy Gradients: Direct Optimization of Policies in Discrete Action Spaces

Guy Lorberbom, Chris J. Maddison, Nicolas Heess, Tamir Hazan, Daniel Tarlow

SLIDE 51

Reinforcement Learning

Figure: the agent sends an action to the environment; the environment returns a reward and the next state.

Goal: maximize cumulative reward.

SLIDE 52

Policy Gradient Method

Goal: $\max_\theta \, \mathbb{E}_{\tau \sim \pi_\theta}[R(\tau)]$

SLIDES 53-54

Policy Gradient Method

Want: $\nabla_\theta \, \mathbb{E}_{\tau \sim \pi_\theta}[R(\tau)]$

REINFORCE: $\mathbb{E}_{\tau \sim \pi_\theta}[R(\tau) \, \nabla_\theta \log \pi_\theta(\tau)]$

Direct Policy Gradient: $\lim_{\epsilon \to 0} \frac{1}{\epsilon} \, \mathbb{E}\big[\nabla_\theta \log \pi_\theta(\tau_\epsilon) - \nabla_\theta \log \pi_\theta(\tau^*)\big]$, where $\tau^*$ is a Gumbel-max sample from the policy and $\tau_\epsilon$ its reward-augmented counterpart (defined on the next slides).
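A one-step bandit sketch contrasting the two estimators, with K actions and a fixed reward vector; all names and values are illustrative. The direct surrogate compares the Gumbel-max sample with the reward-augmented argmax:

```python
import torch

K, eps = 5, 1.0
theta = torch.zeros(K, requires_grad=True)
reward = torch.tensor([0.0, 0.2, 1.0, 0.1, 0.3])

log_pi = torch.log_softmax(theta, dim=0)
g = -torch.log(-torch.log(torch.rand(K)))             # Gumbel(0, 1) noise

# REINFORCE: score function weighted by the sampled reward
# (shown for contrast; not differentiated here).
a = torch.argmax(log_pi + g)                          # a ~ pi via Gumbel-max
reinforce = reward[a] * log_pi[a]

# Direct policy gradient: compare with the reward-augmented argmax.
a_direct = torch.argmax(log_pi + g + eps * reward)
direct = (log_pi[a_direct] - log_pi[a]) / eps

direct.backward()                                     # ascent on expected reward
```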

SLIDE 55

State Reward Tree

Tree of all possible trajectories (with the environment seed fixed). This separates environment stochasticity from policy stochasticity.

SLIDE 56

State Reward Tree

Given the tree, we can sample trajectories by sampling an action at each visited node.

SLIDE 57

Reparameterize the Policy

Instead of sampling per-timestep, we sample per-trajectory. Given an action sequence $\tau = (a_1, \dots, a_T)$, define $\log \pi_\theta(\tau) = \sum_t \log \pi_\theta(a_t \mid s_t)$.

SLIDES 58-60

Gumbel-max reparameterization

Now that we have $\log \pi_\theta(\tau)$, let $G(\tau) = \log \pi_\theta(\tau) + g(\tau)$ with i.i.d. $g(\tau) \sim \mathrm{Gumbel}(0, 1)$ for each trajectory $\tau$, and let $\tau^* = \arg\max_\tau G(\tau)$. Then, under this reparameterization, $\tau^* \sim \pi_\theta$.

SLIDES 61-66

Structured Prediction vs. RL:

  • Discrete configurations: outputs $y \in \mathcal{Y}$ / trajectories $\tau$
  • Scoring function: $s_\theta(x, y)$ / $\log \pi_\theta(\tau) + g(\tau)$
  • Loss: $L(y^*, y)$ / negative reward $-R(\tau)$
  • Inference: $\arg\max_y s_\theta(x, y)$ / $\arg\max_\tau G(\tau)$
  • Loss-augmented inference: $\arg\max_y \{s_\theta(x, y) + \epsilon L(y^*, y)\}$ / $\arg\max_\tau \{G(\tau) + \epsilon R(\tau)\}$

SLIDES 67-69

Direct Policy Gradient (DirPG)

SLIDE 70

Algorithm

For every training step:
1. Sample Gumbel noise and roll out $\tau^* = \arg\max_\tau G(\tau)$.
2. Find $\tau_{direct} = \arg\max_\tau \{G(\tau) + \epsilon R(\tau)\}$.
3. Compute gradients: $\frac{1}{\epsilon} \nabla_\theta \big(\log \pi_\theta(\tau_{direct}) - \log \pi_\theta(\tau^*)\big)$.

SLIDE 71

Problem

For every training step:
1. Sample Gumbel noise and roll out $\tau^*$.
2. Find $\tau_{direct} = \arg\max_\tau \{G(\tau) + \epsilon R(\tau)\}$. ⇐ How to obtain this?
3. Compute gradients.

SLIDE 72

Solution: A* sampling (Maddison et al., 2014)

Use heuristic search to find a trajectory whose direct objective $G(\tau) + \epsilon R(\tau)$ is better than that of $\tau^*$.

SLIDE 73

Complete Algorithm

For every training step (a toy example follows below):
1. Sample Gumbel noise, roll out $\tau^*$, and compute its direct objective.
2. While the search budget is not exceeded:
   a. Obtain a candidate $\tau$ from heuristic search.
   b. End the search if $G(\tau) + \epsilon R(\tau) > G(\tau^*) + \epsilon R(\tau^*)$.
3. Compute gradients.
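A runnable toy of one DirPG step on a two-step environment whose trajectory set is small enough to enumerate, so the "heuristic search" is exhaustive here; in the paper it is A* sampling over a lazily built tree (later slides). The reward function and sizes are illustrative:

```python
import itertools
import torch

T, A, eps = 2, 3, 1.0
theta = torch.zeros(T, A, requires_grad=True)           # per-step logits
log_pi = torch.log_softmax(theta, dim=1)

trajs = list(itertools.product(range(A), repeat=T))     # all A**T trajectories
reward = {tau: float(sum(tau)) for tau in trajs}        # hypothetical reward
gumbel = {tau: -torch.log(-torch.log(torch.rand(()))) for tau in trajs}
logp = lambda tau: sum(log_pi[t, a] for t, a in enumerate(tau))

tau_star = max(trajs, key=lambda t: (logp(t) + gumbel[t]).item())     # step 1
tau_dir = max(trajs,                                                  # step 2
              key=lambda t: (logp(t) + gumbel[t] + eps * reward[t]).item())

surrogate = (logp(tau_dir) - logp(tau_star)) / eps      # step 3
surrogate.backward()                                    # ascent on expected reward
```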

SLIDES 74-75

Limitations

  • Must be able to reset the environment to previously visited states.
  • The search terminates on the first improvement.

SLIDE 76

Combinatorial bandits

The number of trajectories searched to find $\tau_{direct}$ increases as training progresses.

SLIDE 77

MiniGrid

Comparisons between different heuristics for DirPG and REINFORCE on MiniGrid.

SLIDE 78

MiniGrid

Evidence of “pulling up” on MiniGrid.

SLIDE 79

Related Work

  • Gradient Estimators
    ○ REINFORCE (Williams 1988)
    ○ REBAR (Tucker et al 2017)
    ○ RELAX (Grathwohl et al 2018)
    ○ Gumbel-Softmax (Jang et al 2017; Maddison et al 2017)
  • Discrete Deep Generative Models
    ○ VQ-VAE (Oord et al 2017)
    ○ Discrete VAE (Rolfe 2017)
    ○ Gumbel-Sinkhorn (Mena et al 2018)
  • Reinforcement Learning

SLIDE 80

Top-Down sampling using A* Sampling

SLIDE 81

Non-starters

  • Compute the direct objective for all possible trajectories (exponentially many).
  • Roll out many trajectories and select the best.
SLIDES 82-85

Gumbel Process

We know: for i.i.d. $g(\tau) \sim \mathrm{Gumbel}(0, 1)$ and any region $A$, $\max_{\tau \in A} \{\log \pi(\tau) + g(\tau)\} \sim \mathrm{Gumbel}\big(\log \sum_{\tau \in A} \pi(\tau)\big)$.

Therefore: we can sample the maximum Gumbel value of an entire region without instantiating its elements, and refine regions lazily.
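A sketch of the two facts this rests on, assuming only NumPy: (1) the max over a region of Gumbels with locations $\log \pi(\tau)$ is itself Gumbel with location $\log \sum \pi(\tau)$; (2) an element's value conditioned on the region's max can be drawn as a truncated Gumbel:

```python
import numpy as np

def gumbel(loc, size=None):
    return loc - np.log(-np.log(np.random.uniform(size=size)))

def truncated_gumbel(loc, bound):
    # Gumbel(loc) conditioned to be <= bound, via the inverse-CDF identity.
    return -np.log(np.exp(-gumbel(loc)) + np.exp(-bound))

w = np.array([0.1, 0.6, 0.3])                       # region probabilities, sum to 1
region_max = gumbel(np.log(w.sum()), size=100_000)  # fact (1): one draw per region
element_max = gumbel(np.log(w), size=(100_000, 3)).max(axis=1)
print(region_max.mean(), element_max.mean())        # the two distributions agree

print(truncated_gumbel(np.log(0.3), bound=1.0))     # fact (2): always <= 1.0
```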

SLIDES 86-88

Gumbel Process

Figure: the trajectory space is split into regions A and B; each region's maximum Gumbel value can be sampled directly, so the search can descend into one region without enumerating the other.

SLIDES 89-105

Trajectory Generation

  • Lazily create partitions of trajectories.
  • Recursion rule:
    ○ For the maximizing child, copy the parent node’s value.
    ○ For the remaining choices of actions, group them and compute a truncated Gumbel value.
  • Repeat until a terminating state is found.
  • Yield the trajectory and its Gumbel value (a worked sketch follows below).

Recall the goal: find $\tau$ with a large direct objective $G(\tau) + \epsilon R(\tau)$. How should the search queue be prioritized?
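A runnable sketch of the lazy top-down recursion on a depth-2 binary action tree with hypothetical probabilities: the maximizing child inherits its parent's Gumbel value, the sibling group gets a truncated Gumbel, and a best-first queue pops the exact Gumbel-max trajectory first (prioritized by $G$ alone here, i.e., $\epsilon = 0$):

```python
import heapq
import numpy as np

def gumbel(loc):
    return loc - np.log(-np.log(np.random.uniform()))

def truncated_gumbel(loc, bound):
    return -np.log(np.exp(-gumbel(loc)) + np.exp(-bound))

log_p = {(0,): np.log(0.4), (1,): np.log(0.6),          # log-mass of each prefix
         (0, 0): np.log(0.2), (0, 1): np.log(0.2),
         (1, 0): np.log(0.3), (1, 1): np.log(0.3)}

def expand(prefix, g_parent):
    children = [prefix + (a,) for a in (0, 1)]
    mass = np.array([np.exp(log_p[c]) for c in children])
    winner = np.random.choice(2, p=mass / mass.sum())    # argmax child ~ its mass
    return [(c, g_parent if i == winner else truncated_gumbel(log_p[c], g_parent))
            for i, c in enumerate(children)]

queue = [(-gumbel(0.0), ())]                             # whole space: location log 1
while queue:
    neg_g, prefix = heapq.heappop(queue)                 # best-first by Gumbel value
    if len(prefix) == 2:                                 # completed trajectory
        print("sampled trajectory:", prefix, "G =", round(-neg_g, 3))
        break
    for child, g in expand(prefix, -neg_g):
        heapq.heappush(queue, (-g, child))
```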

SLIDE 106

Search for large $G(\tau) + \epsilon R(\tau)$ using A* Sampling

  • Lower bound on the accumulated reward (L)
  • Upper bound on the reward-to-go (U)
  • In practice:
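One plausible way to combine these bounds into an A*-style priority, as a hedged sketch (the exact form used in the paper may differ): rank a region by an optimistic value of its direct objective, using its Gumbel value plus $\epsilon$ times the reward bound $L + U$:

```python
def priority(region_gumbel, L, U, eps):
    # Optimistic bound on G(tau) + eps * R(tau) over the region:
    # region_gumbel = max Gumbel in the region; L + U >= total reward.
    return region_gumbel + eps * (L + U)
```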