Recent work in Truncated Statistics
Andrew Ilyas
Motivation: Poincaré and the Baker
Claimed weight: 1 kg/loaf
Average weight: 950 g/loaf
After Poincaré complained, average weight: 1.05 kg/loaf
[Figure: histogram of loaf weights (frequency vs. weight), reference line at 1 kg]
The truncation process:
- Sample x ∼ 𝒩(μ, Σ)
- If x ∈ S: observe x
- If x ∉ S: throw away x and restart

Goal: obtain estimates (μ̂, Σ̂) ≈ (μ, Σ) from the observed samples.
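The process above is easy to simulate. A minimal sketch (the 1-D set S = (0.5, ∞) and the standard normal are illustrative choices, not from the talk) shows why naive empirical estimates are biased:

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma = 0.0, 1.0
in_S = lambda x: x > 0.5        # illustrative truncation set S = (0.5, inf)

# Simulate truncation: draw from N(mu, sigma^2), observe x only if x is in S
draws = rng.normal(mu, sigma, size=500_000)
samples = draws[in_S(draws)]    # x not in S is thrown away

# Naive empirical estimates are biased for the *untruncated* (mu, sigma)
print(samples.mean())   # ~1.14, far above mu = 0
print(samples.std())    # ~0.52, far below sigma = 1
```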
Samples from 𝒩([0,1], I) and from 𝒩([0,1], 4I), truncated to [−0.5,0.5] × [1.5,2.5]. Which is which?
Projected Gradient Descent on the Negative Log-Likelihood (NLL)

Without truncation, maximum likelihood has a closed form:

(μ̂, Σ̂) = arg max_{(μ,Σ)} Σᵢ log f𝒩(xᵢ; μ, Σ) = arg min_{(μ,Σ)} Σᵢ [(xᵢ − μ)ᵀΣ⁻¹(xᵢ − μ) + log det Σ]

μ̂ = (1/n) Σᵢ xᵢ,   Σ̂ = (1/n) Σᵢ (xᵢ − μ̂)(xᵢ − μ̂)ᵀ
With truncation there is no closed form; instead, run projected gradient descent on the NLL. In the natural parameters v = Σ⁻¹μ, T = Σ⁻¹:

∇v [−log f(x; v, T, S)] = 𝔼_{z∼𝒩(μ,Σ)}[z | z ∈ S] − x
∇T [−log f(x; v, T, S)] = ½ xxᵀ − ½ 𝔼_{z∼𝒩(μ,Σ)}[zzᵀ | z ∈ S]

i.e., the expected truncated mean/covariance under the current parameters minus the empirical (batch) mean/covariance.
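The truncated expectation in the gradient need not be computed in closed form: it can be estimated by sampling from the current model and rejecting points outside S. A minimal 1-D sketch (known unit variance; the set S = (0.5, ∞), step size, and batch size are illustrative assumptions):

```python
import numpy as np

rng = np.random.default_rng(0)

# Data: N(mu_true, 1) truncated to the (illustrative) set S = (0.5, inf)
mu_true = 0.0
in_S = lambda x: x > 0.5
draws = rng.normal(mu_true, 1.0, size=200_000)
data = draws[in_S(draws)]

def truncated_mean(mu, n=20_000):
    """Monte-Carlo estimate of E_{z ~ N(mu,1)}[z | z in S] by rejection."""
    z = rng.normal(mu, 1.0, size=n)
    return z[in_S(z)].mean()

# SGD on the truncated NLL: grad ≈ E_model[z | z in S] - mean(batch)
mu_hat, lr = 1.0, 0.1
for _ in range(300):
    batch = rng.choice(data, size=512)
    mu_hat -= lr * (truncated_mean(mu_hat) - batch.mean())

print(mu_hat)   # close to mu_true = 0, even though every observed x > 0.5
```

Note the fixed point: the update stops moving exactly when the model's truncated mean matches the truncated mean of the data, which happens at μ̂ = μ.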
Key assumption: the truncation set S has non-trivial Gaussian measure, i.e., 𝒩(μ, Σ; S) ≥ α for some constant α > 0.
Truncated regression, by example: model a player's ability zᵢ as a function of height xᵢ plus noise ε. What we expect: a fit through the whole population. What we get: a biased fit, because only players who are "good enough for the NBA!" appear in the data — NBA? Yes: observe yᵢ; No: player unobserved. Truncation is based on the value of yᵢ.
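This bias is easy to reproduce. A sketch of the NBA story (the coefficient 0.5 and the cutoff 1.5 are arbitrary stand-ins, in standardized units):

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative model: ability = 0.5 * height + noise (standardized units)
n = 200_000
height = rng.normal(size=n)
ability = 0.5 * height + rng.normal(size=n)
slope_full = np.polyfit(height, ability, 1)[0]   # what we expect: ~0.5

# What we get: only players above an (arbitrary) NBA cutoff are observed
nba = ability > 1.5
slope_nba = np.polyfit(height[nba], ability[nba], 1)[0]

print(slope_full, slope_nba)   # the truncated fit is strongly attenuated
```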
Not a hypothetical problem (or a new one!)
- Fig 1 [Hausman and Wise 1977]: corrected previous findings about education (x) vs. income (y) that were affected by truncation on income (y)
- Table 1 [Lin et al 1999]: found bias in income (x) vs. child support (y), because response rates differ based on y

Has inspired lots of prior work in statistics/econometrics. Our goal: a unified, efficient (polynomial in dimension) algorithm.
[Galton 1897; Pearson 1902; Lee 1914; Fisher 1931; Hotelling 1948; Tukey 1949; Tobin 1958; Amemiya 1973; Breen 1996; Balakrishnan, Cramer 2014]
Truncated regression: the generative process
- Sample a covariate x
- Sample noise ε ∼ DN and compute the latent z = hθ*(x) + ε
- W.p. 1 − φ(z): throw away (x, z) and restart
- W.p. φ(z): project z to a label y := π(z), and add (x, y) to the training set, T ← T ∪ {(x, y)}
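The process can be sketched directly in code; hθ, φ, and π below are illustrative choices (a linear model, a logistic survival probability, and the identity projection), not the paper's:

```python
import numpy as np

rng = np.random.default_rng(0)

theta_star = np.array([1.0, -2.0])          # hypothetical true parameters
h = lambda theta, x: theta @ x              # linear model h_theta(x)
phi = lambda z: 1.0 / (1.0 + np.exp(-z))    # illustrative survival probability
pi_proj = lambda z: z                       # identity projection: y := pi(z)

def sample_training_set(n):
    T = []
    while len(T) < n:
        x = rng.normal(size=2)              # sample a covariate x
        z = h(theta_star, x) + rng.normal() # latent z = h_theta*(x) + eps
        if rng.random() < phi(z):           # survive w.p. phi(z): observe (x, y)
            T.append((x, pi_proj(z)))
        # otherwise: throw away (x, z) and restart
    return T

T = sample_training_set(1000)
print(len(T))                               # 1000 surviving pairs
```

Because φ here favors large z, the surviving labels are systematically shifted upward relative to the population.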
Given samples (xᵢ, yᵢ) where yᵢ ∼ π(hθ*(xᵢ) + ε), ε ∼ DN, we want an estimate θ̂ for θ*.

The (untruncated) likelihood is
p(θ; x, y) = ∫_{z∈π⁻¹(y)} DN(z − hθ(x)) dz,
where the domain π⁻¹(y) is all possible latent variables corresponding to the label, and the integrand is the likelihood of the latent under the model.

If hθ is a linear function, then:
- π(z) = z and ε ∼ 𝒩(0,1): the MLE is ordinary least squares regression
- π(z) = 1_{z≥0} and ε ∼ 𝒩(0,1): the MLE is probit regression
- π(z) = 1_{z≥0} and ε ∼ Logistic(0,1): the MLE is logistic regression
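As a sanity check on the probit case: integrating the Gaussian noise density over π⁻¹(1) = [0, ∞) should give Φ(hθ(x)). A small numerical sketch (the value h = 0.7 is arbitrary):

```python
import numpy as np
from math import erf, sqrt

h = 0.7                                        # arbitrary value of h_theta(x)
Phi = lambda t: 0.5 * (1 + erf(t / sqrt(2)))   # standard normal CDF

# pi(z) = 1_{z>=0}, eps ~ N(0,1): p(theta; x, y=1) integrates the Gaussian
# noise density over pi^{-1}(1) = [0, inf) -- which is the probit likelihood
zs = np.linspace(0.0, 12.0, 200_001)           # [0, 12] captures all the mass
vals = np.exp(-(zs - h) ** 2 / 2) / np.sqrt(2 * np.pi)
dz = zs[1] - zs[0]
p_y1 = float(np.sum(vals[1:] + vals[:-1]) * dz / 2)   # trapezoid rule

print(p_y1, Phi(h))   # both ~0.758
```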
Main idea: maximization of the truncated log-likelihood

Without truncation: p(θ; x, y) = ∫_{z∈π⁻¹(y)} DN(z − hθ(x)) dz
With truncation: p(θ; x, y) = ∫_{z∈π⁻¹(y)} DN(z − hθ(x)) φ(z) dz / ∫_z DN(z − hθ(x)) φ(z) dz

Known φ ⟹ leads to another SGD-based algorithm.
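For a concrete instance, take the binary label π(z) = 1_{z≥0}, Gaussian noise, and an indicator survival function φ(z) = 1_{a≤z≤b} with a < 0 < b; then both integrals reduce to differences of normal CDFs. A sketch (the values of h, a, b are arbitrary):

```python
from math import erf, sqrt

Phi = lambda t: 0.5 * (1 + erf(t / sqrt(2)))   # standard normal CDF

def p_y1_truncated(h, a, b):
    """p(theta; x, y=1) for pi(z)=1_{z>=0}, eps ~ N(0,1), phi(z)=1_{a<=z<=b}.
    For a < 0 < b: the numerator integrates over pi^{-1}(1) ∩ [a,b] = [0,b],
    the denominator over [a,b]; both are differences of normal CDFs."""
    return (Phi(b - h) - Phi(0 - h)) / (Phi(b - h) - Phi(a - h))

h, a, b = 0.7, -1.0, 2.0                       # arbitrary illustrative values
p = p_y1_truncated(h, a, b)
print(Phi(h), p)   # ~0.758 untruncated vs. ~0.770 truncated
```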
Main idea: maximization of the truncated log-likelihood
θ ℓ(θ)
θ ℓ(θ)
θ ℓ(θ)
θ ℓ(θ)
Definition (Quasi-convexity): For all , we have
f(y) ≤ f(x) ⟨∇f(x), y − x⟩ ≤ 0
θ ℓ(θ)
Definition (Quasi-convexity): For all , we have
f(y) ≤ f(x) ⟨∇f(x), y − x⟩ ≤ 0
[Hazan et al, 2015] define strict local quasi-convexity (SLQC) property: both stronger (inner product bounded away from zero) and weaker ( is constrained to a ball around ) than just QC
y x*
θ ℓ(θ)
Definition (Quasi-convexity): For all , we have
f(y) ≤ f(x) ⟨∇f(x), y − x⟩ ≤ 0
[Hazan et al, 2015] define strict local quasi-convexity (SLQC) property: both stronger (inner product bounded away from zero) and weaker ( is constrained to a ball around ) than just QC
y x*
Their result: normalized SGD with minimum batch size converges to global optimum for SLQC functions
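Normalized gradient descent simply steps along the gradient's direction, discarding its magnitude. A toy sketch on a quasi-convex but non-convex objective (my own example, not the truncated NLL):

```python
import numpy as np

# Normalized gradient descent: step along the gradient *direction* only.
# Toy quasi-convex (not convex) objective: f(w) = log(1 + ||w - w*||^2);
# its gradient vanishes far from w*, but its direction always points away.
w_star = np.array([3.0, -1.0])
grad = lambda w: 2 * (w - w_star) / (1 + np.sum((w - w_star) ** 2))

w, lr = np.zeros(2), 0.05
for _ in range(2000):
    g = grad(w)
    w = w - lr * g / np.linalg.norm(g)   # normalized step
print(w)   # ends up oscillating within ~lr of w_star
```

With a constant step size the iterate hovers within a step length of the optimum, which is why the convergence guarantees are stated up to an ε ball.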
There exists a projection set on which linear, probit, and logistic regression are all SLQC ⟹ NSGD converges to the maximizer of the (population) log-likelihood. (The linear case was shown strongly convex by [Daskalakis et al, 2019].)

Running example, with binary labels:
- Sample a covariate x
- Pass to a linear model and sample normal/logistic noise: z = hθ*(x) + ε = w*ᵀx + ε
- Truncate to an interval [a, b]: φ(z) = 1_{z∈[a,b]}
- Project to get a label: y = π(z) = 1 if z ≥ 0, and y = 0 otherwise

Theorem (informal): if for every x ∈ ℝᵈ, each label y ∈ {0,1} occurs with non-zero probability (at least α > 0), then NSGD finds an ε-minimizer of the NLL in poly(1/α, 1/ε, d) steps.
Synthetic data: Y = θ*ᵀX + ε with ε ∼ DN, truncated with parameter C.

[Figure: cosine similarity with θ* (0.2–1) vs. truncation parameter C (−2 to 1), for standard regression vs. truncated regression]
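An end-to-end sketch of such an experiment (1-D covariates, known unit noise variance, lower truncation y > C; all constants are illustrative): gradient ascent on the truncated log-likelihood recovers θ* where ordinary least squares does not.

```python
import numpy as np
from math import erf, exp, sqrt, pi

rng = np.random.default_rng(0)
Phi = lambda t: 0.5 * (1 + erf(t / sqrt(2)))   # standard normal CDF
pdf = lambda t: exp(-t * t / 2) / sqrt(2 * pi) # standard normal density

# Synthetic truncated regression: y = theta* . x + eps, observe only y > C
theta_star, C, n = 2.0, 1.0, 20_000
x_all = rng.normal(size=n)
y_all = theta_star * x_all + rng.normal(size=n)
x, y = x_all[y_all > C], y_all[y_all > C]

theta_ols = float(np.polyfit(x, y, 1)[0])      # standard regression: biased

# Gradient ascent on the truncated log-likelihood (unit noise variance):
#   d/d(mu) log p = (y - mu) - pdf(C - mu) / (1 - Phi(C - mu)),  mu = theta * x
theta_hat = theta_ols
for _ in range(200):
    mu = theta_hat * x
    a = np.clip(C - mu, None, 6.0)             # guard the numerical tail
    lam = np.array([pdf(t) / (1 - Phi(t)) for t in a])
    theta_hat += 0.5 * float(np.mean(((y - mu) - lam) * x))

print(theta_ols, theta_hat)   # OLS underestimates theta*; truncated MLE ~ 2.0
```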
UCI MSD dataset: covariates X, latent Z, labels Y.

[Figure: test set accuracy (45–75%) vs. truncation parameter C (1985–2000), for standard regression vs. truncated regression]
Mixture of two Gaussians [Nagarajan & Panageas, 2019]: an expectation-maximization method; gives a local improvement guarantee.
Unknown truncation set [Kontonis et al, 2019]: works when the space of possible sets has bounded VC dimension or bounded Gaussian surface area (both measures of complexity).
High-dimensional (sparse) setting [Daskalakis et al, 2020]: covariates are very high dimensional but sparse; gives an algorithm for dealing with truncation (applications: school studies, non-response in surveys).