SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient
Lantao Yu†, Weinan Zhang†, Jun Wang‡, Yong Yu†
†Shanghai Jiao Tong University, ‡University College London
Attribution: multiple slides taken from
The GAN Zoo: more than 500 species in the zoo, e.g. GAN, ACGAN, BGAN, DCGAN, EBGAN, fGAN, GoGAN, CGAN, ...
https://github.com/hindupuravinash/the-gan-zoo (not updated since 2018.09)
Mihaela Rosca, Balaji Lakshminarayanan, David Warde-Farley, Shakir Mohamed, “Variational Approaches for Auto-Encoding Generative Adversarial Networks”, arXiv, 2017
[Figure: a generator maps a random vector, e.g. (−0.3, 0.1, ..., 0.9), to an image; a conditional generator additionally takes an input such as the text "Girl with red hair".]
Text-to-image: paired data, e.g. the caption "blue eyes, red hair, short hair" paired with its image.
Style transfer: unpaired data between domain x (photos) and domain y (Vincent van Gogh's style).
Generator
It is a neural network (NN), or a function.
[Figure: the generator maps an input vector, e.g. (0.1, −3, ..., 2.4, 0.9), to a high-dimensional output such as an image; changing the input vector changes the image.]
Each dimension of the input vector represents some characteristic, e.g. longer hair, blue hair, open mouth. Demo powered by: http://mattya.github.io/chainer-DCGAN/
Discriminator
It is a neural network (NN), or a function. It takes an image as input and outputs a scalar: a larger value means real, a smaller value means fake (e.g. real images score 1.0, a generated image scores 0.1).
Basic GAN training alternates between the discriminator D and the generator G.
Step 1: Fix generator G, and update discriminator D. Feed D real objects randomly sampled from the database (label 1) and objects generated by G from random vectors (label 0). The discriminator learns to assign high scores to real objects and low scores to generated objects.
Step 2: Fix discriminator D, and update generator G. Treat G followed by D as one large network, fix the discriminator part, and update the generator part by gradient ascent on the discriminator's output. The generator learns to "fool" the discriminator.
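A minimal sketch of this two-step loop in PyTorch. The MLP architectures, sizes, and learning rates are illustrative assumptions, not the slides' actual setup:

```python
import torch
import torch.nn as nn

latent_dim, data_dim = 64, 784          # illustrative sizes (e.g. flattened 28x28 images)
G = nn.Sequential(nn.Linear(latent_dim, 256), nn.ReLU(), nn.Linear(256, data_dim))
D = nn.Sequential(nn.Linear(data_dim, 256), nn.ReLU(), nn.Linear(256, 1), nn.Sigmoid())
opt_G = torch.optim.Adam(G.parameters(), lr=2e-4)
opt_D = torch.optim.Adam(D.parameters(), lr=2e-4)
bce = nn.BCELoss()

def train_step(real):                    # real: (batch, data_dim) sampled from the database
    batch = real.size(0)
    z = torch.randn(batch, latent_dim)   # randomly sampled input vectors

    # Step 1: fix G, update D -- high scores for real, low scores for generated
    fake = G(z).detach()                 # detach: do not backprop into G here
    loss_D = bce(D(real), torch.ones(batch, 1)) + bce(D(fake), torch.zeros(batch, 1))
    opt_D.zero_grad(); loss_D.backward(); opt_D.step()

    # Step 2: fix D, update G -- G learns to "fool" D (push D(G(z)) toward 1)
    loss_G = bce(D(G(z)), torch.ones(batch, 1))
    opt_G.zero_grad(); loss_G.backward(); opt_G.step()
    return loss_D.item(), loss_G.item()
```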
[Figure: generated anime faces after 100, 1,000, 2,000, 5,000, 10,000, 20,000, and 50,000 updates. Source of training data: https://zhuanlan.zhihu.com/p/24767059. Source of video: https://www.gwern.net/Faces]
[Ian J. Goodfellow, et al., NIPS, 2014]
Outline:
1. Reinforcement Learning
2. GAN + RL
NLP tasks usually involve sequence generation. How can we use GAN to improve sequence generation?
A chatbot is an encoder-decoder (En-De) model: a human gives an input sentence c, and the chatbot produces a response sentence x. [Li, et al., EMNLP, 2016]
A human gives a reward R(c, x) for the chatbot's response, e.g. for the input "How are you?" the response "Not bad" receives a positive reward (+1) while "I'm John" does not.
Learn to maximize expected reward, e.g. by Policy Gradient.
Sample N dialogues (c^1, x^1), (c^2, x^2), ..., (c^N, x^N) from the current policy θ^t and collect rewards R(c^1, x^1), R(c^2, x^2), ..., R(c^N, x^N). The policy-gradient estimate is
∇R̄_{θ^t} ≈ (1/N) Σ_{i=1}^{N} R(c^i, x^i) ∇ log P_{θ^t}(x^i | c^i)
θ^{t+1} ← θ^t + η ∇R̄_{θ^t}
If R(c^i, x^i) is positive, update θ to increase P_θ(x^i | c^i); if R(c^i, x^i) is negative, update θ to decrease P_θ(x^i | c^i).
Maximum Likelihood vs. Reinforcement Learning (Policy Gradient):

                Maximum Likelihood                         Reinforcement Learning (Policy Gradient)
Training data:  (c^1, x̂^1), ..., (c^N, x̂^N)                (c^1, x^1), ..., (c^N, x^N) (sampled from the policy)
Objective:      (1/N) Σ_{i=1}^{N} log P_θ(x̂^i | c^i)       (1/N) Σ_{i=1}^{N} R(c^i, x^i) log P_θ(x^i | c^i)
Gradient:       (1/N) Σ_{i=1}^{N} ∇ log P_θ(x̂^i | c^i)     (1/N) Σ_{i=1}^{N} R(c^i, x^i) ∇ log P_θ(x^i | c^i)

Maximum likelihood is the special case where every training pair has R(c^i, x̂^i) = 1; in policy gradient each term is weighted by R(c^i, x^i).
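A sketch of this contrast in code: the policy-gradient update is the maximum-likelihood update with each sampled pair weighted by its reward. Here `model.log_prob`, `model.sample_with_log_prob`, and `reward` are assumed interfaces of a hypothetical seq2seq model, not anything defined in the slides:

```python
import torch

def mle_step(model, optimizer, cs, xs_gold):
    # Maximum likelihood: every (c^i, x̂^i) pair has implicit weight 1.
    log_probs = model.log_prob(cs, xs_gold)          # per-example log P_θ(x̂^i | c^i), shape (N,)
    loss = -log_probs.mean()
    optimizer.zero_grad(); loss.backward(); optimizer.step()

def policy_gradient_step(model, optimizer, cs):
    # Sample responses from the current policy, then weight by reward.
    xs, log_probs = model.sample_with_log_prob(cs)   # x^i ~ P_θ(x | c^i)
    rewards = torch.tensor([reward(c, x) for c, x in zip(cs, xs)])
    # Ascend (1/N) Σ_i R(c^i, x^i) ∇ log P_θ(x^i | c^i)
    loss = -(rewards * log_probs).mean()
    optimizer.zero_grad(); loss.backward(); optimizer.step()
```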
Outline:
1. Reinforcement Learning
2. GAN + RL
Seq2seq (encoder-decoder): input sentence c, output sentence x. Training data: A: "How are you?" B: "I'm good." The training criterion is to maximize the likelihood of "I'm good." given "How are you?". At test time the model may output "Not bad" or "I'm John."; a human judges "Not bad" as the better response, but maximizing likelihood can favor "I'm John."
Conditional GAN for dialogue: replace human evaluation with machine evaluation. The discriminator takes the input sentence c and the response sentence x (e.g. "I am busy.") and outputs a reward R(c, x) for the chatbot (En-De). [Li, et al., EMNLP, 2017]
However, there is an issue when you train your generator.
Three categories of solutions:
1. Gumbel-softmax
2. Continuous Input for Discriminator [Xu, et al., EMNLP, 2017] [Alex Lamb, et al., NIPS, 2016] [Yizhe Zhang, et al., ICML, 2017]
3. Reinforcement Learning [..., 2017] [Tong Che, et al., arXiv, 2017] [Jiaxian Guo, et al., AAAI, 2018] [Kevin Lin, et al., NIPS, 2017] [William Fedus, et al., ICLR, 2018]
[Figure: the generator decodes token by token from <BOS>, producing a distribution over the vocabulary (e.g. A, B) at each step.]
Continuous Input for Discriminator: use the output distribution itself as the input of the discriminator, avoiding the sampling process. The discriminator outputs a scalar, and we can do backpropagation to update the generator's parameters now.
Problem: real sequences are one-hot vectors (e.g. 1, 0, 0), while the generated distributions (e.g. 0.9, 0.1, 0.1) can never be 1-hot, so the discriminator can immediately find the difference. A discriminator with a constraint (e.g. WGAN) can be helpful.
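A sketch of the related continuous-relaxation idea using PyTorch's built-in Gumbel-softmax; the vocabulary size and the example token ids are illustrative assumptions:

```python
import torch
import torch.nn.functional as F

vocab_size = 3                                  # e.g. tokens A, B, <EOS>
logits = torch.randn(5, vocab_size)             # generator outputs, one row per time step

# Soft samples: differentiable, but never exactly one-hot.
soft = F.gumbel_softmax(logits, tau=1.0, hard=False)

# Straight-through variant: one-hot forward pass, soft gradients backward.
hard = F.gumbel_softmax(logits, tau=1.0, hard=True)

real = F.one_hot(torch.tensor([0, 1, 0, 2, 1]), vocab_size).float()
print(soft.sum(dim=-1))   # each row sums to 1, yet differs visibly from the one-hot rows of `real`
```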
Among these three categories (Gumbel-softmax, Continuous Input for Discriminator, Reinforcement Learning), we focus on Reinforcement Learning here.
RL is difficult to train. GAN is difficult to train. Sequence generation GAN (RL + GAN) is therefore doubly difficult to train.
With a sequence-level reward, the chatbot (En-De) gets a single discriminator score for the whole response (e.g. "You is good" scores 0.1): I don't know which part is wrong. Better: give a reward to every generation step (e.g. "You" 0.9, "You is" 0.1, "You is good" 0.1).
Method 1: Monte Carlo (MC) Search [Yu, et al., AAAI, 2017]
Method 2: Discriminator For Partially Decoded Sequences [Li, et al., EMNLP, 2017]
Method 3: Step-wise evaluation [Tuan, Lee, TASLP, 2019] [Xu, et al., EMNLP, 2018] [William Fedus, et al., ICLR, 2018]
SeqGAN problem setting: train a generative model G_θ to produce sequences that mimic the real ones. The true distribution p_true(y_t | Y_{1:t−1}) is only revealed by the given dataset D = {Y_{1:T}}.
The learned model conditions on real prefixes during training but on its own generated prefixes in the inference stage: exposure bias.
Training vs. inference. The model is updated by maximum likelihood, conditioning on the real prefix Y_{1:t−1}:
max_θ (1/|D|) Σ_{Y_{1:T} ∈ D} Σ_t log G_θ(y_t | Y_{1:t−1})
But when generating the next token ŷ_t at inference time, we sample from G_θ(ŷ_t | Ŷ_{1:t−1}), conditioning on the guessed prefix Ŷ_{1:t−1}.
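A minimal sketch of this maximum-likelihood update with teacher forcing, assuming a hypothetical autoregressive model `G` that maps a prefix batch to per-step logits:

```python
import torch
import torch.nn.functional as F

def mle_step(G, optimizer, batch):
    # batch: (B, T) integer token ids of real sequences Y_{1:T} from D.
    inputs, targets = batch[:, :-1], batch[:, 1:]   # condition on the REAL prefix
    logits = G(inputs)                              # assumed shape: (B, T-1, vocab)
    # Minimize -(1/|D|) Σ_{Y} Σ_t log G_θ(y_t | Y_{1:t-1})
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
    optimizer.zero_grad(); loss.backward(); optimizer.step()
    return loss.item()
```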
GAN idea: a discriminator learns to distinguish real-world data from the fake model-generated data; when the discriminator can no longer tell them apart, G nicely fits the true underlying data distribution.
[Figure: Real World → Data and Generator → generated samples both feed the Discriminator.]
[Goodfellow I, Pouget-Abadie J, Mirza M, et al. Generative adversarial nets. In NIPS 2014.]
The discriminator D (e.g. a multilayer perceptron) outputs P(true) for its input, and the generator G is trained with guidance from it:
min_G max_D E_{x∼p_data(x)}[log D(x)] + E_{z∼p_z(z)}[log(1 − D(G(z)))]
Why is this hard for text? Ian Goodfellow: if you output the word "penguin", you "can't change that to 'penguin + .001' on the next step, because there is no such word as 'penguin + .001'. You have to go all the way from 'penguin' to 'ostrich'."
[https://www.reddit.com/r/MachineLearning/comments/40ldq6/generative_adversarial_networks_for_text/]
With continuous outputs, the generator is updated by the gradient
∇_{θ_G} (1/m) Σ_{i=1}^{m} log(1 − D(G(z^(i))))
but this gradient cannot be backpropagated through discrete tokens.
SeqGAN treats generation as RL: the generated prefix Y_{1:t−1} serves as the state, the next token y_t is the action, and the generator G_θ(y_t | Y_{1:t−1}) is the policy. The discriminator's output D_φ(Y_{1:T}) (the estimated probability of being true data) is the reward, given only for the complete sequence (no immediate reward).
The generator's objective is the expected end reward:
J(θ) = E[R_T | s_0, θ] = Σ_{y_1 ∈ Y} G_θ(y_1 | s_0) · Q^{G_θ}_{D_φ}(s_0, y_1)
where Q^{G_θ}_{D_φ}(s, a) is the expected accumulative reward of taking action a in state s and then following G_θ. For the final token,
Q^{G_θ}_{D_φ}(s = Y_{1:T−1}, a = y_T) = D_φ(Y_{1:T})
For intermediate steps, estimate Q with an N-time Monte Carlo search using a roll-out policy G_β:
Q^{G_θ}_{D_φ}(s = Y_{1:t−1}, a = y_t) =
  (1/N) Σ_{n=1}^{N} D_φ(Y^n_{1:T}),  Y^n_{1:T} ∈ MC^{G_β}(Y_{1:t}; N)   for t < T
  D_φ(Y_{1:t})                                                           for t = T
where {Y^1_{1:T}, ..., Y^N_{1:T}} = MC^{G_β}(Y_{1:t}; N) are N complete sequences rolled out from the prefix Y_{1:t}.
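A sketch of this roll-out estimate; the roll-out policy interface `G_beta.rollout` and the discriminator callable `D_phi` are assumed, not the paper's code:

```python
import torch

def q_value(prefix, t, T, D_phi, G_beta, N=16):
    """Estimate Q^{G_θ}_{D_φ}(Y_{1:t-1}, y_t) for the prefix Y_{1:t}."""
    if t == T:                              # finished sequence: use the discriminator directly
        return D_phi(prefix)
    # Roll out N complete sequences Y^n_{1:T} from Y_{1:t} with the roll-out policy G_β
    rollouts = [G_beta.rollout(prefix, steps=T - t) for _ in range(N)]
    return torch.stack([D_phi(y) for y in rollouts]).mean()
```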
The discriminator is trained to distinguish real sequences from generated ones:
min_φ −E_{Y∼p_data}[log D_φ(Y)] − E_{Y∼G_θ}[log(1 − D_φ(Y))]
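A minimal sketch of one discriminator update under this objective, assuming a sequence sampler `G.sample` and a discriminator `D_phi` that outputs probabilities:

```python
import torch
import torch.nn.functional as F

def discriminator_step(D_phi, opt_D, real_seqs, G, T):
    # min_φ −E_{Y∼p_data}[log D_φ(Y)] − E_{Y∼G_θ}[log(1 − D_φ(Y))]
    fake_seqs = G.sample(real_seqs.size(0), T)        # sampled token ids carry no gradient to G
    score_real = D_phi(real_seqs)                     # D_φ outputs probabilities in (0, 1)
    score_fake = D_phi(fake_seqs)
    loss = F.binary_cross_entropy(score_real, torch.ones_like(score_real)) \
         + F.binary_cross_entropy(score_fake, torch.zeros_like(score_fake))
    opt_D.zero_grad(); loss.backward(); opt_D.step()
    return loss.item()
```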
[Richard Sutton et al. Policy Gradient Methods for Reinforcement Learning with Function Approximation. NIPS 1999.]
Following the policy gradient theorem, the gradient of J(θ) is
∇_θ J(θ) = E_{Y_{1:t−1}∼G_θ}[ Σ_{y_t ∈ Y} ∇_θ G_θ(y_t | Y_{1:t−1}) · Q^{G_θ}_{D_φ}(Y_{1:t−1}, y_t) ]
≈ (1/T) Σ_{t=1}^{T} Σ_{y_t ∈ Y} ∇_θ G_θ(y_t | Y_{1:t−1}) · Q^{G_θ}_{D_φ}(Y_{1:t−1}, y_t)
= (1/T) Σ_{t=1}^{T} Σ_{y_t ∈ Y} G_θ(y_t | Y_{1:t−1}) ∇_θ log G_θ(y_t | Y_{1:t−1}) · Q^{G_θ}_{D_φ}(Y_{1:t−1}, y_t)
= (1/T) Σ_{t=1}^{T} E_{y_t∼G_θ(y_t | Y_{1:t−1})}[ ∇_θ log G_θ(y_t | Y_{1:t−1}) · Q^{G_θ}_{D_φ}(Y_{1:t−1}, y_t) ]
and the generator is updated by θ ← θ + α_h ∇_θ J(θ).
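A sketch of one generator update under this gradient, reusing `q_value` from above; `G.sample_with_log_prob` is an assumed interface, not the paper's code:

```python
import torch

def generator_pg_step(G, opt_G, D_phi, G_beta, T, batch_size=8, N=16):
    # One g-step: θ ← θ + α ∇_θ J(θ), with Q estimated by MC roll-outs (q_value above).
    loss = torch.zeros(())
    for _ in range(batch_size):
        # Assumed interface: returns a sampled sequence Y_{1:T} and the per-step
        # log G_θ(y_t | Y_{1:t-1}) with gradients attached.
        seq, step_log_probs = G.sample_with_log_prob(T)
        for t in range(1, T + 1):
            q = q_value(seq[:t], t, T, D_phi, G_beta, N).detach()  # treat Q as a constant
            loss = loss - step_log_probs[t - 1] * q / (T * batch_size)
    opt_G.zero_grad(); loss.backward(); opt_G.step()
    return loss.item()
```

In a full run, the paper's schedule alternates g-steps of this form with d-steps of the discriminator update above, after pretraining G with maximum likelihood.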
[Hochreiter, S., and Schmidhuber, J. 1997. Long short-term memory. Neural computation 9(8):1735–1780.]
[Figure: the generator is an LSTM that emits tokens by softmax sampling, e.g. "Shanghai is incredibly ...".]
[Kim, Y. 2014. Convolutional neural networks for sentence classification. EMNLP 2014.]
Evaluation. Training maximizes max_θ (1/|D|) Σ_{x∈D} log G_θ(x) ≈ E_{x∼p_true(x)}[log G_θ(x)]. For evaluation we instead want E_{x∼G_θ(x)}[log p_true(x)]: the model-generated data is considered as real as possible, i.e. it falls where the true data is with a high mass density. But p_true(x) is unknown for real data, so this is hard or impossible to directly calculate.
Solution: use an oracle model with a certain generalization ability as p_true. The oracle supplies the training data for the generative model and, like a human observer, can accurately evaluate the perceptual quality of the generative model.
NLL_oracle = −E_{Y_{1:T}∼G_θ}[ Σ_{t=1}^{T} log G_oracle(y_t | Y_{1:t−1}) ]
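A sketch of computing NLL_oracle by sampling from G_θ and scoring under the oracle; both `G.sample` and `oracle.log_prob` are assumed interfaces:

```python
import torch

def nll_oracle(G, oracle, num_samples=1000, T=20):
    # NLL_oracle = -E_{Y_{1:T} ~ G_θ} [ Σ_t log G_oracle(y_t | Y_{1:t-1}) ]
    with torch.no_grad():
        seqs = G.sample(num_samples, T)      # (num_samples, T) token ids
        log_p = oracle.log_prob(seqs)        # (num_samples,) per-sequence Σ_t log G_oracle(...)
        return -log_p.mean().item()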
Obama political speech generation, machine vs. human (excerpts as shown):
Machine:
- and most important thing that not on violence throughout the horizon is OTHERS american fire and OTHERS but we need you are a strong source
- remember now i can't afford to start with just the way our european support for the right thing to protect those american story from the world and
- were going to be an student medical education and warm the republicans who like my times if he said is that brought the
Human:
- extraordinary honor that he was the most trusted man in america
- celebrate the journalism that walter practiced a standard of honesty and integrity and responsibility to which so many
- little bit harder to find today
- tribute to the life and times of the man who chronicled our time.
Conclusions:
- SeqGAN effectively trains Generative Adversarial Nets for discrete structured sequence generation via policy gradient.
- The synthetic oracle provides a metric to accurately evaluate the "perceptual quality" of generated sequences.
- Experiments on real-world tasks show convincing results.
Discussion (seminar comments):
- ... training, GAN parameters, g-steps, d-steps, MC tree depth, etc.
- ... estimate, as very few samples are used in each episode. [Jigyasa]
- Very creative idea that provides a nice way to automatically compare how close the generator distribution is to the actual model of the world. [Rajas]
- [Vipul]
- [Siddhant, Saransh, Rajas]
- [Atishya, Vipul, Saransh]
- [Atishya, Jigyasa, Siddhant, Rajas]
- Extensions: 1. K discriminators trained with partial/complete sequences [Keshav] 2. K distinct Ds are expensive; use weight sharing [Atishya] 3. Use an LM for intermediate rewards, a "surprise" value [Rajas/Soumya/Saransh]
- 1. Won't work [Saransh]
Sequence Generation: three categories of solutions.
1. Gumbel-softmax
2. Continuous Input for Discriminator [Xu, et al., EMNLP, 2017] [Alex Lamb, et al., NIPS, 2016] [Yizhe Zhang, et al., ICML, 2017]
3. Reinforcement Learning [..., 2017] [Jiaxian Guo, et al., AAAI, 2018] [Kevin Lin, et al., NIPS, 2017] [William Fedus, et al., ICLR, 2018]