CS 6355: Structured Prediction
L2S: Learning to Search
Some slides adapted from Daumé and Ross
Inference
What is inference?
– An overview of what we have seen before
– Combinatorial optimization
– Different views of inference
– Dynamic programming, greedy algorithms, search
– Sampling
x1 x2 x3 → y1 y2 y3

Suppose each y can be one of A, B, C or D, and the true label is 𝐳 = (z1 = B, z2 = C, z3 = D). (The cost tables below are consistent with this gold output.)

The cost vector for this input x can be, for example:

0-1 cost: d(B,B,B) = 1, d(B,B,C) = 1, d(B,B,D) = 1, …, d(B,C,D) = 0, …, d(D,D,C) = 1, d(D,D,D) = 1
Hamming distance: d(B,B,B) = 2, d(B,B,C) = 2, d(B,B,D) = 1, …, d(B,C,D) = 0, …, d(D,D,C) = 3, d(D,D,D) = 2

The goal: learn a classifier that has the lowest cost.
What is the dimension of the cost vector c? (One entry per possible output: with L labels and n slots, L^n; here 4^3 = 64.)
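As a concrete sketch (an illustrative helper of our own, not from the slides), the Hamming cost vector can be enumerated over all possible outputs:

```python
from itertools import product

def hamming_cost_vector(gold, labels):
    """Map every possible output to its Hamming distance from gold
    (illustrative helper; names are ours, not the slides')."""
    return {y: sum(yi != zi for yi, zi in zip(y, gold))
            for y in product(labels, repeat=len(gold))}

# 4 labels, 3 slots: 4**3 = 64 entries, one cost per possible output
costs = hamming_cost_vector(("B", "C", "D"), ("A", "B", "C", "D"))
```

This makes the dimension question concrete: the cost vector has one entry per complete output, which is why enumerating it explicitly quickly becomes infeasible.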
Initial State:    Goal State:
7 2 4             _ 1 2
5 _ 6             3 4 5
8 3 1             6 7 8
A search problem has five components:
– Initial state: s0
– Actions: Actions(s)
– Transition model: Result(s, a)
– Goal test
– Path cost / score
What are these five components for the 8-puzzle?
How do we solve a search problem? Answer: by starting at the initial state and navigating the state space until we reach a goal state.
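The five components and the "navigate the state space" idea can be sketched for the 8-puzzle with plain breadth-first search (a minimal illustration; names like `neighbors` are our own, not the slides'):

```python
from collections import deque

GOAL = (0, 1, 2, 3, 4, 5, 6, 7, 8)  # 0 stands for the blank

def neighbors(state):
    """Actions + transition model: slide the blank up/down/left/right."""
    i = state.index(0)
    r, c = divmod(i, 3)
    for dr, dc in ((-1, 0), (1, 0), (0, -1), (0, 1)):
        nr, nc = r + dr, c + dc
        if 0 <= nr < 3 and 0 <= nc < 3:
            j = nr * 3 + nc
            s = list(state)
            s[i], s[j] = s[j], s[i]
            yield tuple(s)

def bfs(start):
    """Navigate the state space from the initial state to the goal."""
    frontier, parent = deque([start]), {start: None}
    while frontier:
        state = frontier.popleft()
        if state == GOAL:                     # goal test
            path = []
            while state is not None:
                path.append(state)
                state = parent[state]
            return path[::-1]                 # path cost = len(path) - 1
        for nxt in neighbors(state):
            if nxt not in parent:
                parent[nxt] = state
                frontier.append(nxt)
    return None
```

BFS is only one choice of navigation strategy; the later slides replace it with greedy and beam search guided by a learned scoring function.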
Predicting an output 𝐳 as a sequence of decisions
x1 x2 x3 → y1 y2 y3
Suppose each y can be one of A, B or C. The search space over partial outputs:
(-,-,-) → (A,-,-), (B,-,-), (C,-,-) → (A,A,-), …, (C,C,-) → (A,A,A), …, (C,C,C)
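"Predicting an output 𝐳 as a sequence of decisions" can be sketched as greedy left-to-right decoding over this search space (a minimal sketch; the local scoring function `score` is a stand-in, not from the slides):

```python
# Build the output left to right, greedily picking the best label at
# each slot according to a local scoring function.
LABELS = ("A", "B", "C")

def greedy_decode(x, score):
    """x: input sequence; score(x, prefix, label) -> float."""
    prefix = []
    for _ in range(len(x)):
        best = max(LABELS, key=lambda lab: score(x, tuple(prefix), lab))
        prefix.append(best)
    return tuple(prefix)

# Toy score that simply prefers the gold label at each position.
gold = ("A", "B", "C")
toy_score = lambda x, prefix, lab: 1.0 if lab == gold[len(prefix)] else 0.0
greedy_decode(("x1", "x2", "x3"), toy_score)  # -> ("A", "B", "C")
```

Each call to `score` is one decision; beam search generalizes this by keeping the top-k prefixes instead of one.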
[Hal Daumé III and Daniel Marcu, ICML 2005]
The goal is to learn w. How?
Suppose each y can be one of A, B or C, and the true label is y = (y1 = A, y2 = B, y3 = C).
(-,-,-) → (A,-,-), (-,B,-), (C,-,-) → (A,A,-), …, (C,C,-) → (A,A,A), …, (C,C,C)
[Diagram: a search beam with y-good nodes 1, 2, 3, 4 and 5, and the current node, which is not y-good; node 4 is a sibling of the current node]

Say we found an error (of either type) at the current node. Then we should have chosen node 4 instead of the current node: node 4 is the y-good sibling of the current node.
It comes with the usual perceptron-style mistake bound and generalization bound. (See references)
w ← w + (1/|sibs|) ∑_{n ∈ sibs} Φ(x, n) − (1/|nodes|) ∑_{n ∈ nodes} Φ(x, n)
Hal Daumé III, John Langford, Daniel Marcu (2007)
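The update above can be sketched in a few lines (our reconstruction of the LaSO-style step; `phi` is a hypothetical node-feature function, not from the slides):

```python
import numpy as np

# When search errs, move the weights toward the averaged features of
# the y-good siblings and away from the averaged features of the
# nodes currently expanded.
def laso_update(w, phi, sibs, nodes, lr=1.0):
    good = np.mean([phi(n) for n in sibs], axis=0)
    bad = np.mean([phi(n) for n in nodes], axis=0)
    return w + lr * (good - bad)
```

Averaging (rather than summing) keeps the update scale independent of the beam width.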
π^ref: the reference policy

For example, if we are using Hamming distance for the cost vector 𝐝, the reference policy is trivial to compute. Why?
Just make the right decision at every step.
Suppose the gold state is (A, B, C, A) and we are at the state (A, C, -, -). The reference policy tells us the next action is to assign C to the third slot.
Multiclass classification:
– Examples: (y, z) ∈ Y × [L]
– Goal: min_h Pr[h(y) ≠ z]

Cost-sensitive (multiclass) classification:
– Examples: (y, 𝐝) ∈ Y × [0, ∞)^L
– Goal: min_h E[d_{h(y)}]

Exercise: How would you design a cost-sensitive learner?
SEARN uses a cost-sensitive learner to learn a policy
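One possible answer to the exercise (a common reduction, sketched here with closed-form least-squares regressors; this is an illustration, not the only design):

```python
import numpy as np

# Reduce cost-sensitive classification to regression: fit one cost
# regressor per label, then predict the label with the smallest
# estimated cost.
class CostSensitiveClassifier:
    def fit(self, X, C):
        """X: (n, d) features; C: (n, L) cost vectors."""
        X1 = np.hstack([X, np.ones((len(X), 1))])  # add a bias column
        self.W, *_ = np.linalg.lstsq(X1, C, rcond=None)
        return self

    def predict(self, X):
        X1 = np.hstack([X, np.ones((len(X), 1))])
        return np.argmin(X1 @ self.W, axis=1)  # lowest predicted cost
```

Other reductions exist (e.g. weighted binary classification); any of them can serve as the cost-sensitive learner inside SEARN.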
Roll in: at each state, use some policy to move to a new state.
What is the cost of deviating from the policy at this step? (Assume there are three possible actions at this state.)
Roll out: once we make the one-step deviation, we can use some policy to get to a goal state again.
[Diagram: a roll-in trajectory; at one state, three one-step deviations, each rolled out to an end state, with losses 0.2, 0, and 0.8]

Roll-in with the current policy h. Roll-out with the current policy h.
The cost of deviating to action a′ at state s and then following h: c_y(s, a′, h).
The loss used to train the classifier:
ℓ_h(c, s, a) = E_{y∼(s,a,h)}[c_y] − min_{a′} E_{y∼(s,a′,h)}[c_y]
The loss defined this way is called regret.
After each round, the policy is updated by stochastic mixing: h ← β h′ + (1 − β) h.
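The roll-in / one-step-deviation / roll-out procedure can be sketched directly (a minimal sketch; `step`, `actions`, and `loss` are hypothetical environment hooks, not from the slides):

```python
# Roll in with the current policy to step t, try each one-step
# deviation, roll out with the policy, and record the end-state loss.
def deviation_costs(x, policy, horizon, actions, step, loss, t):
    state = ()
    for _ in range(t):                  # roll in with the policy
        state = step(state, policy(x, state))
    costs = {}
    for a in actions(state):            # one-step deviations
        s = step(state, a)
        while len(s) < horizon:         # roll out with the policy
            s = step(s, policy(x, s))
        costs[a] = loss(s)              # loss at the end state
    return costs
```

Subtracting the minimum entry of the returned dictionary from each entry gives exactly the regret values used as training losses.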
[Stéphane Ross, Geoffrey J. Gordon, J. Andrew Bagnell, 2011]
[Diagram: DAgger iterations — roll in with the reference policy π^ref to collect states, train π_1; roll in with π_1, aggregate the data, train π_2; and so on]
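The DAgger loop sketched above can be written in a few lines (a high-level sketch; `rollin` and `train` are hypothetical hooks supplied by the caller, and for simplicity we roll in with the single current policy rather than a stochastic mixture):

```python
# Roll in with the current policy, label every visited state with the
# reference policy's action, aggregate all data, and retrain.
def dagger(pi_ref, rollin, train, examples, rounds):
    data, policy = [], pi_ref              # round 0 imitates the reference
    for _ in range(rounds):
        for x in examples:
            for state in rollin(x, policy):            # states we visit
                data.append((x, state, pi_ref(x, state)))  # expert label
        policy = train(data)               # train on the aggregated set
    return policy
```

The key point is that later rounds collect states visited by the *learned* policy, so the classifier is trained on the distribution it will actually face at test time.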