Adaptive Adversarial Multi-task Representation Learning
Yuren Mao¹  Weiwei Liu²  Xuemin Lin¹
1. University of New South Wales, Australia.  2. Wuhan University, China.
ICML 2020
[Figure: AMTRL architecture. Shared layers feed task-specific layers (Task 1 … Task T) and, via a gradient reversal layer, a task discriminator. Forward propagation computes the task losses and the min-max adversarial loss; in backward propagation the task-specific layers receive ∂L(θ_sh, θ_t)/∂θ_t, while the shared layers receive ∂L(θ_sh, θ_t)/∂θ_sh and the reversed discriminator gradient −∂L_adv(θ_sh)/∂θ_sh.]
Overview:
- AMTRL algorithm and a PAC bound on L_D(h) − L_S(h): the number of tasks does not matter for the dominant terms, and the T-dependent term is negligible.
- Task relatedness for AMTRL, measured through the discriminator.
- Adaptive AMTRL: an augmented Lagrangian method combined with a task weighting strategy, giving better performance.
Adversarial Multi-task Representation Learning (AMTRL) has achieved success in various applications, ranging from sentiment analysis to question answering systems.
AMTRL solves the min-max problem

min_h L(h, λ) = L_S(h) + λ L_adv,

where the empirical loss and the loss of the adversarial module are

L_S(h) = (1/(nT)) Σ_{t=1}^{T} Σ_{i=1}^{n} ℓ^t( f^t(g(x_i^t)), y_i^t ),

L_adv = max_Φ (1/(nT)) Σ_{t=1}^{T} Σ_{i=1}^{n} e_t Φ(g(x_i^t)).
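The two losses can be made concrete with a small numerical sketch. The squared-error task loss and the linear-softmax discriminator used here are illustrative assumptions, not the paper's exact choices:

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def amtrl_losses(reps, preds, targets, W, b):
    """Compute L_S and L_adv for T tasks with n samples each.
    reps[t]: (n, d) shared representations g(x_i^t) for task t.
    preds[t], targets[t]: (n,) task outputs f^t(g(x_i^t)) and labels y_i^t.
    (W, b): parameters of an assumed linear-softmax discriminator Phi."""
    T, n = len(reps), reps[0].shape[0]
    # Empirical loss L_S: squared error stands in for the task losses l^t.
    L_S = sum(((p - y) ** 2).sum() for p, y in zip(preds, targets)) / (n * T)
    # Adversarial loss L_adv: average probability the discriminator assigns
    # to the true task identity e_t (the inner max is over Phi's parameters).
    L_adv = sum(softmax(r @ W + b)[:, t].sum() for t, r in enumerate(reps)) / (n * T)
    return L_S, L_adv
```

In training, the shared representation g is updated to *decrease* L_adv (via the gradient reversal layer) while Φ is updated to increase it.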
Adaptive AMTRL aims to minimize the task-averaged empirical risk while enforcing the representations of all tasks to share an identical distribution. We formulate this as a constrained optimization problem
min_h L_S(h)   s.t.   L_adv − c = 0,

and propose to solve it with an augmented Lagrangian method:

min_h (1/T) L_S(h) + λ(L_adv − c) + (r/2)(L_adv − c)².
λ and r are updated during the training process.
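One plausible schedule for updating the multiplier λ and the penalty r after each primal minimization: these are the textbook augmented-Lagrangian updates, assumed here rather than quoted from the paper.

```python
def augmented_lagrangian_step(lam, r, L_adv, c, r_growth=2.0, r_max=1e4):
    """One dual update after minimizing
    L_S/T + lam*(L_adv - c) + (r/2)*(L_adv - c)**2 over h.
    The geometric growth schedule for r is a common heuristic."""
    violation = L_adv - c
    lam = lam + r * violation      # gradient-ascent step on the multiplier
    r = min(r * r_growth, r_max)   # tighten the penalty, capped at r_max
    return lam, r
```

λ moves in the direction of the constraint violation and r grows until the constraint L_adv = c is satisfied.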
[Figure: (a) three 2-d Gaussian distributions; (b) the discriminator softmax(WX + b); (c) relatedness changing curve.]
Relatedness between task i and task j:

R_ij = min{ [ Σ_{n=1}^{N} e_j Φ(g(x_n^i)) + e_i Φ(g(x_n^j)) ] / [ Σ_{n=1}^{N} e_i Φ(g(x_n^i)) + e_j Φ(g(x_n^j)) ], 1 }.

Relatedness matrix:

R = [ R_11 R_12 ⋯ R_1T ; R_21 R_22 ⋯ R_2T ; ⋮ ; R_T1 R_T2 ⋯ R_TT ].
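Reading e_t Φ(·) as the probability the discriminator assigns to task t, R_ij can be computed directly from discriminator outputs; a sketch under that assumption:

```python
import numpy as np

def relatedness(probs_i, probs_j, i, j):
    """R_ij from discriminator outputs; probs_i, probs_j are (N, T) arrays
    whose rows are Phi(g(x_n)) for samples of task i and task j."""
    # Numerator: e_j Phi on task-i samples plus e_i Phi on task-j samples.
    cross = probs_i[:, j].sum() + probs_j[:, i].sum()
    # Denominator: probability mass on the correct task identities.
    own = probs_i[:, i].sum() + probs_j[:, j].sum()
    return min(cross / own, 1.0)
```

When the two tasks' representations are indistinguishable the discriminator outputs are symmetric and R_ij = 1; well-separated representations drive R_ij toward 0.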
In multi-task learning, tasks regularize each other, which can improve generalization. The weight assigned to each task controls the strength of this regularization. This paper proposes a weighting strategy for AMTRL based on the proposed task relatedness:

w = (1 / (1R1ᵀ)) · 1R,

where 1 is a 1×T vector of all ones and R is the relatedness matrix. Combining the augmented Lagrangian method with the weighting strategy, the optimization problem becomes

min_h (1/T) Σ_{t=1}^{T} w_t L_{S_t}(f^t ∘ g) + λ(L_adv − c) + (r/2)(L_adv − c)².
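The weight vector is just the column sums of R normalized by their total; a minimal numpy sketch:

```python
import numpy as np

def task_weights(R):
    """w = (1R) / (1R1'): normalize the column sums of the relatedness
    matrix so the task weights sum to one."""
    ones = np.ones(R.shape[0])
    col_sums = ones @ R                  # the 1xT row vector 1R
    return col_sums / (ones @ R @ ones)  # scalar normalizer 1R1'
```

Tasks with higher total relatedness to the others receive larger weights.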
Assuming the representations of all tasks share an identical distribution, we have the following generalization error bound:

L_D(h) − L_S(h) ≤ c₁ρ G_a(G*(X₁))/n + c₂ Q sup_{g∈G*} ‖g(X₁)‖/√n + √( ln(1/δ) / (2nT) ).

The first two terms bound the generalization error and do not depend on the number of tasks; the last term is negligible.
Experiments: Sentiment Analysis and Topic Classification.
Mean relatedness of task t:

R̄_t = (1/T) Σ_{k=1}^{T} R_{tk}.
Sentiment Analysis.
Relative error:

er_rel = (1/T) Σ_{t=1}^{T} er_t^MTL / er_t^STL.

[Figure: error rate for the task 'appeal'.]
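The relative-error metric is a one-liner; note that the per-task subscript on er^MTL is an assumption, since the extraction dropped it:

```python
def relative_error(er_mtl, er_stl):
    """er_rel = (1/T) * sum_t er_t^MTL / er_t^STL.
    Values below 1 mean the multi-task model beats
    single-task learning on average."""
    return sum(m / s for m, s in zip(er_mtl, er_stl)) / len(er_mtl)
```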