Introduction to Super Learning
Ted Westling, PhD Postdoctoral Researcher Center for Causal Inference Perelman School of Medicine University of Pennsylvania September 25, 2018
Learning Goals
– Conceptual understanding of Super Learning
We observe i.i.d. data (Y1, X1), . . . , (Yn, Xn) from the joint distribution of (Y, X).

We are often interested in summaries of the conditional distribution of Y given X:
– Conditional mean (regression) function
– Conditional quantile function
– Conditional density function
– Conditional hazard function

Here we focus on the conditional mean function µ(x) := E[Y | X = x].
We want to estimate µ(x) = E[Y | X = x]. How should we do it?
– GLM
– GAM
– Random Forest
– Neural network
How do we choose which algorithm to use?
Super Learning is: An ensemble method for combining predictions from many candidate machine learning algorithms
Suppose ˆµ1, . . . , ˆµK are candidate estimators of µ.

For each ˆµk, MSE(ˆµk) = E[(Y − ˆµk(X))^2] measures the performance of ˆµk as an estimator of µ.

If we knew MSE(ˆµk), we could choose the ˆµk with the smallest MSE(ˆµk).
Can we estimate MSE(ˆµk) = E[(Y − ˆµk(X))^2] with its empirical counterpart,

ˆMSE(ˆµk) = (1/n) Σ_{i=1}^{n} [Yi − ˆµk(Xi)]^2 ?

This favors ˆµk which are overfit, because the ˆµk are trained on the same data used to evaluate the MSE. (Like giving students the answers before the exam!)
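The danger of the in-sample MSE can be seen with a deliberately overfit candidate. Below is a minimal sketch; the 1-nearest-neighbor candidate, the simulated data, and all names are illustrative, not from the slides. The in-sample MSE of 1-NN is exactly zero, while its held-out MSE is not.

```python
import random

random.seed(1)

# Toy data: Y = X + noise. The setup is illustrative only.
def simulate(n):
    xs = [random.uniform(0, 1) for _ in range(n)]
    ys = [x + random.gauss(0, 0.5) for x in xs]
    return xs, ys

def one_nn(train_x, train_y):
    """1-nearest-neighbor regressor: a deliberately overfit candidate."""
    def predict(x):
        j = min(range(len(train_x)), key=lambda i: abs(train_x[i] - x))
        return train_y[j]
    return predict

def mse(pred, xs, ys):
    return sum((y - pred(x)) ** 2 for x, y in zip(xs, ys)) / len(xs)

train_x, train_y = simulate(100)
test_x, test_y = simulate(100)
fit = one_nn(train_x, train_y)

print(mse(fit, train_x, train_y))  # 0.0: each point is its own nearest neighbor
print(mse(fit, test_x, test_y))    # well above 0 on fresh data
```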
[Figure: schematic of 10-fold cross-validation. Gray: training sets. Yellow: validation sets.]
For each fold v = 1, . . . , V, we:
– train ˆµk,v using the training set;
– obtain predictions ˆµk,v(Xi) for Xi in the validation set Vv.

The cross-validated MSE is

MSECV(ˆµk) = (1/V) Σ_{v=1}^{V} (1/|Vv|) Σ_{i ∈ Vv} [Yi − ˆµk,v(Xi)]^2.

We average the MSEs of the V validation sets.
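The cross-validated MSE above can be sketched in a few lines. The data, the constant-mean candidate, and all names here are illustrative, not from the slides.

```python
import random

def cv_mse(xs, ys, fit, V=10):
    n = len(xs)
    idx = list(range(n))
    random.Random(0).shuffle(idx)
    folds = [idx[v::V] for v in range(V)]          # V roughly equal validation sets
    fold_mses = []
    for val in folds:
        val_set = set(val)
        tr = [i for i in idx if i not in val_set]  # training set for this fold
        pred = fit([xs[i] for i in tr], [ys[i] for i in tr])
        fold_mses.append(sum((ys[i] - pred(xs[i])) ** 2 for i in val) / len(val))
    return sum(fold_mses) / V                      # average over the V folds

# Candidate "algorithm": predict the training-set mean, ignoring x.
def fit_mean(train_x, train_y):
    m = sum(train_y) / len(train_y)
    return lambda x: m

random.seed(2)
xs = [random.uniform(0, 1) for _ in range(50)]
ys = [2 * x + random.gauss(0, 0.1) for x in xs]
print(cv_mse(xs, ys, fit_mean))
```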
[Figure: schematic of 10-fold cross-validation with the cross-validated predictions collected across the folds. Gray: training sets. Yellow: validation sets.]
Larger V:
– more training data, so better for small n
– more computation time
– well-suited to high-dimensional covariates
– well-suited to complicated or non-smooth µ

Smaller V:
– more test data
– less computation time.

(People typically use V = 5 or V = 10.)
We compute MSECV(ˆµ1), . . . , MSECV(ˆµK) for each of our candidate algorithms.

The discrete Super Learner is the ˆµk minimizing these cross-validated MSEs.
Let SK denote the K-dimensional simplex: each λk ∈ [0, 1] and Σ_k λk = 1. Consider all convex combinations ˆµλ := Σ_{k=1}^{K} λk ˆµk.

The Super Learner is ˆµ_λ̂, where

λ̂ = argmin_{λ ∈ SK} MSECV(Σ_{k=1}^{K} λk ˆµk).

(We use constrained optimization to compute the argmin.)
To compute λ̂ = argmin_{λ ∈ SK} MSECV(Σ_{k=1}^{K} λk ˆµk), we write out the cross-validated MSE of the combination:

MSECV(Σ_{k=1}^{K} λk ˆµk) = (1/V) Σ_{v=1}^{V} (1/|Vv|) Σ_{i ∈ Vv} [Yi − Σ_{k=1}^{K} λk ˆµk,v(Xi)]^2.
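With K = 2 candidates the simplex is one-dimensional, so the weight-fitting step can be illustrated with a simple grid search over λ ∈ [0, 1]. The cross-validated predictions and outcomes below are made up for illustration.

```python
# Cross-validated predictions for two candidates, plus the outcomes.
cv_preds = [
    [0.9, 2.1, 2.9, 4.2, 5.1],  # candidate 1: tracks the outcomes closely
    [3.0, 3.0, 3.0, 3.0, 3.0],  # candidate 2: a constant predictor
]
ys = [1.0, 2.0, 3.0, 4.0, 5.0]

def cv_mse_combo(lam):
    """MSECV of the convex combination lam * mu1 + (1 - lam) * mu2."""
    combo = [lam * p1 + (1 - lam) * p2 for p1, p2 in zip(*cv_preds)]
    return sum((y - c) ** 2 for y, c in zip(ys, combo)) / len(ys)

grid = [i / 1000 for i in range(1001)]  # the simplex {(lam, 1 - lam)} as a grid
lam_hat = min(grid, key=cv_mse_combo)
print(lam_hat)  # close to 1: nearly all the weight goes to the better candidate
```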
Putting it all together:
1. Choose candidate algorithms ˆµ1, . . . , ˆµK.
2. Obtain cross-validated predictions ˆµk,v(Xi) for all k, v and i ∈ Vv.
3. Compute the weights λ̂ minimizing MSECV(Σ_{k=1}^{K} λk ˆµk) over the simplex SK.
4. The Super Learner is ˆµSL = Σ_{k=1}^{K} λ̂k ˆµk.
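The steps above can be strung together on toy data. This is a sketch only: two simple candidates (a constant-mean fit and a one-variable least-squares fit), a deterministic fold assignment, and a grid search for the weights (possible because K = 2). None of these choices come from the slides.

```python
import random

def fit_mean(xs, ys):      # candidate 1: constant predictor
    m = sum(ys) / len(ys)
    return lambda x: m

def fit_ols(xs, ys):       # candidate 2: simple least-squares line
    n = len(xs)
    xbar, ybar = sum(xs) / n, sum(ys) / n
    sxx = sum((x - xbar) ** 2 for x in xs)
    b = sum((x - xbar) * (y - ybar) for x, y in zip(xs, ys)) / sxx
    a = ybar - b * xbar
    return lambda x: a + b * x

random.seed(3)
xs = [random.uniform(0, 1) for _ in range(100)]
ys = [1 + 2 * x + random.gauss(0, 0.3) for x in xs]

# Step 2: cross-validated predictions for each candidate (equal-sized folds,
# so pooling squared errors equals averaging the per-fold MSEs).
V, n = 5, len(xs)
folds = [list(range(v, n, V)) for v in range(V)]
cv_pred = [[0.0] * n, [0.0] * n]
for val in folds:
    vs = set(val)
    tr_x = [xs[i] for i in range(n) if i not in vs]
    tr_y = [ys[i] for i in range(n) if i not in vs]
    for k, fit in enumerate((fit_mean, fit_ols)):
        mu_kv = fit(tr_x, tr_y)
        for i in val:
            cv_pred[k][i] = mu_kv(xs[i])

# Step 3: weights by grid search over the one-dimensional simplex (K = 2).
def mse_cv(lam):
    return sum((ys[i] - (lam * cv_pred[1][i] + (1 - lam) * cv_pred[0][i])) ** 2
               for i in range(n)) / n

lam_hat = min((i / 200 for i in range(201)), key=mse_cv)

# Step 4: combine the candidates, refit on all the data (the usual convention).
mu1, mu2 = fit_mean(xs, ys), fit_ols(xs, ys)
mu_sl = lambda x: lam_hat * mu2(x) + (1 - lam_hat) * mu1(x)
print(lam_hat)  # near 1: the least-squares candidate dominates here
```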
Recall the construction of SL for a continuous outcome:
1. Choose candidate algorithms ˆµ1, . . . , ˆµK.
2. Obtain cross-validated predictions ˆµk,v(Xi) for all k, v and i ∈ Vv.
3. Compute the weights λ̂ minimizing MSECV(Σ_{k=1}^{K} λk ˆµk).
4. Set ˆµSL = Σ_{k=1}^{K} λ̂k ˆµk.
In this section, we generalize this procedure to estimation of other summaries of the conditional distribution by choosing an appropriate loss for the summary of interest.
Given a loss function L(O, θ) and parameter space Θ, the target of estimation is the minimizer of the oracle risk:

θ0 = argmin_{θ ∈ Θ} EP0 [L(O, θ)].

These notions of loss and risk come from the statistical learning literature (see, e.g., Vapnik, 1992, 1999, 2013) and are not to be confused with loss and risk from the decision theory literature (e.g., Ferguson, 2014).
MSE is the oracle risk corresponding to the squared-error loss function L(O, µ) = [Y − µ(X)]^2.
In general, θ0 = argmin_{θ ∈ Θ} R0(θ), where R0(θ) = EP0[L(O, θ)].

Suppose ˆθ1, . . . , ˆθK are candidate estimators of θ0. The oracle risk R0(ˆθk) measures the performance of ˆθk.

The empirical risk is

ˆR(ˆθk) = (1/n) Σ_{i=1}^{n} L(Oi, ˆθk).

The cross-validated risk is

RCV(ˆθk) = (1/V) Σ_{v=1}^{V} (1/|Vv|) Σ_{i ∈ Vv} L(Oi, ˆθk,v).
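The cross-validated risk can be sketched with a pluggable loss; the same routine then covers squared error, negative log-likelihood, and so on. The fold scheme, data, and names below are illustrative, not from the slides.

```python
def cv_risk(data, fit, loss, V=5):
    """R_CV: data is a list of observations Oi, fit maps data -> theta-hat,
    and loss maps (Oi, theta) -> a number."""
    n = len(data)
    folds = [list(range(v, n, V)) for v in range(V)]
    fold_risks = []
    for val in folds:
        vs = set(val)
        theta_kv = fit([data[i] for i in range(n) if i not in vs])
        fold_risks.append(sum(loss(data[i], theta_kv) for i in val) / len(val))
    return sum(fold_risks) / V

# Example: theta is a scalar mean, with squared-error loss.
data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
fit_mean = lambda d: sum(d) / len(d)
sq_loss = lambda o, theta: (o - theta) ** 2
print(cv_risk(data, fit_mean, sq_loss))
```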
Using this framework, we can generalize the SL recipe:
1. Choose candidate algorithms ˆθ1, . . . , ˆθK.
2. Compute cross-validated risks RCV(ˆθk), k = 1, . . . , K.
3. Compute the weights λ̂ minimizing RCV(Σ_{k=1}^{K} λk ˆθk).
4. Set ˆθSL = Σ_{k=1}^{K} λ̂k ˆθk.
van der Vaart et al. (2006) showed that, under some conditions, the oracle risk of the SL estimator is as good as the oracle risk of the oracle minimizer up to a multiple of (log n)/n, as long as the number of candidate algorithms is polynomial in n.
We return to O = (Y, X), θ = µ.

There are several sensible loss functions for a binary outcome:
– Negative log-likelihood loss: L(O, µ) = −Y log µ(X) − [1 − Y] log[1 − µ(X)].
– AUC loss.
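The negative log-likelihood loss above is easy to evaluate pointwise. The sketch below, with made-up numbers, shows its defining behavior: confident correct predictions incur small loss, and confident wrong ones are penalized heavily.

```python
import math

def nll_loss(y, p):
    """L(O, mu) = -Y log mu(X) - (1 - Y) log(1 - mu(X)) for one observation,
    where p = mu(X) is the predicted probability."""
    return -y * math.log(p) - (1 - y) * math.log(1 - p)

print(nll_loss(1, 0.99))  # confident and correct: small loss
print(nll_loss(0, 0.99))  # confident and wrong: large loss
```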
In this section, we will introduce three of the add-ons to SL that are frequently useful in practice: variable screens, observation weights, and losses for censored outcomes.
Although variable screening could be built into each candidate algorithm, the SuperLearner package has built-in functionality to ease this process. Screening algorithms allow us to guide the SL using our domain knowledge.
Screens also let us compare different ways of reducing the dimensionality.
– With many related measurements, we might try providing a smaller number of summary measures, e.g. mean, median, min, max.
– With a covariate measured at many time points, we might try providing just baseline, or just the last time point, or some summaries of the trajectory.
Some sampling designs require the use of observation weights in the procedure, e.g. case-control sampling. Most candidate algorithms and methods in SuperLearner accept observation weights, but method.AUC does not make correct use of weights! Check that each of your candidate algorithms makes correct use of observation weights.
Example: in a two-phase biomarker study, all cases are assayed, but only a subsample of the controls (out of ncontrol total controls) are assayed. We want to predict case status using the results of the assay and other covariates. Observation weights can be obtained from a regression of the indicator of inclusion in the control cohort.
Now suppose we are interested in the probability of an event before time t0. With event time T and censoring time C, we observe Y = min{T, C} and ∆ = I(T ≤ C). The target is µ(x) = P(T ≤ t0 | X = x) = E[I(T ≤ t0) | X = x].
Using an inverse-probability-of-censoring-weighted (IPCW) loss, µ0 satisfies

µ0 = argmin_µ EP0 [ (∆ / G0(Y | X)) L((Y, X), µ) ],

where G0 is the conditional survival function of the censoring time, so that ∆ / G0(Y | X) is the inverse-probability-of-censoring weight. If C ⊥ T, we can use a Kaplan-Meier estimator for G0.
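The IPCW risk can be sketched as a weighted empirical risk: censored observations (∆ = 0) drop out, and uncensored ones are upweighted by 1/G(Y | X). The censoring survival function G, the data, and t0 below are all made up for illustration; in practice G would be estimated, e.g. by Kaplan-Meier when C ⊥ T.

```python
t0 = 2.0

def G(y):
    """Illustrative censoring survival function P(C >= y); not estimated."""
    return max(0.05, 1.0 - 0.2 * y)

def ipcw_risk(obs, mu):
    """Empirical IPCW risk with squared-error loss on I(T <= t0).
    obs is a list of (y, delta, x) with y = min(T, C), delta = I(T <= C)."""
    total = 0.0
    for y, delta, x in obs:
        if delta == 1:                       # uncensored: T = y is observed
            event = 1.0 if y <= t0 else 0.0  # I(T <= t0)
            total += (event - mu(x)) ** 2 / G(y)
    return total / len(obs)

obs = [(1.0, 1, 0.1), (3.0, 1, 0.9), (1.5, 0, 0.5)]  # third subject is censored
mu_half = lambda x: 0.5
print(ipcw_risk(obs, mu_half))
```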
The discrete SL and SL are themselves candidate algorithms. However, they are trained on all of the data, so their estimated risks will be optimistic. We can evaluate their performance honestly using an additional layer of cross-validation.
We split the data into V1 outer folds and, within each outer training set, run the SL with V2-fold CV. We then predict on the outer validation set for fold v. This yields cross-validated risk estimates for the discrete SL and SL.
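The outer layer of cross-validation can be sketched as follows: the selection step (here, a discrete choice between two trivial candidates by inner CV, standing in for the discrete SL) happens entirely inside each outer training set, so the outer risk estimate is honest. The candidates, data, and fold schemes are illustrative.

```python
import statistics

def inner_cv_pick(train, fits, V2=2):
    """Discrete selection by inner V2-fold CV, then refit the winner on train."""
    def inner_risk(fit):
        n = len(train)
        folds = [list(range(v, n, V2)) for v in range(V2)]
        total = 0.0
        for val in folds:
            vs = set(val)
            theta = fit([train[i] for i in range(n) if i not in vs])
            total += sum((train[i] - theta) ** 2 for i in val) / len(val)
        return total / V2
    best = min(fits, key=inner_risk)
    return best(train)

def outer_cv_risk(data, fits, V1=5):
    """Honest risk of the whole selection procedure via V1 outer folds."""
    n = len(data)
    outer = [list(range(v, n, V1)) for v in range(V1)]
    risks = []
    for val in outer:
        vs = set(val)
        theta = inner_cv_pick([data[i] for i in range(n) if i not in vs], fits)
        risks.append(sum((data[i] - theta) ** 2 for i in val) / len(val))
    return sum(risks) / V1

fits = [lambda d: sum(d) / len(d), lambda d: statistics.median(d)]
data = [float(i % 7) for i in range(35)]
print(outer_cv_risk(data, fits))
```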
Example: a trial with three arms:
– Fluzone – inactivated influenza vaccine (IIV)
– FluMist – live-attenuated influenza vaccine (LAIV)
– placebo.
Immune responses were measured via a variety of markers (HAI, NAI, MN, AM titers; protein/virus/peptide magnitude/breadth). Candidate predictors:
– Demographics: age, vaccinated in last year (EVERVAX)
– Day 0 markers
– Day 30 markers
– Difference markers = Day 30 markers − Day 0 markers
We used SL to assess the value of different sets of variables for predicting flu status in the placebo and Fluzone arms separately. Screens included marginal-association screens and both IgA + IgG measurements.
[Figure: cross-validated AUC for each learner (SL, Discrete SL, SL.glm, SL.bayesglm, SL.glmnet, SL.earth, SL.gam, SL.xgboost, SL.ranger, SL.mean) and screen (screen.marginal.05, screen.marginal.10; IgA, IgG, Neither), across the variable sets Baseline, Day 0, Day 30, EV x Day 0, EV x Day 30, Day 30 − Day 0, EV x (Day 0, Day 30), and EV x (Day 0, Diff).]

[Figure: cross-validated AUC for the SL, the discrete SL, and the best-performing single learner within each variable set.]
References

Ferguson, T. S. (2014). Mathematical Statistics: A Decision Theoretic Approach. Academic Press.

van der Vaart, A. W., Dudoit, S., and van der Laan, M. J. (2006). Oracle inequalities for multi-fold cross validation. Statistics & Decisions, 24(3):351–371.

Vapnik, V. (1992). Principles of risk minimization for learning theory. In Advances in Neural Information Processing Systems, pages 831–838.

Vapnik, V. N. (1999). An overview of statistical learning theory. IEEE Transactions on Neural Networks, 10(5):988–999.

Vapnik, V. (2013). The Nature of Statistical Learning Theory. Springer Science & Business Media.