Learning Neural Causal Models From Unknown Interventions
Summary
- The relationship between each variable and its parents is modeled by a neural network, modulated by structural meta-parameters that capture the overall topology of a directed graphical model.
- A random intervention is assumed to act on a single unknown variable of an unknown ground-truth causal model.
- To disentangle the slow-changing aspects of each conditional from the fast-changing adaptations to each intervention, the neural network is parameterized into fast parameters and slow meta-parameters.
Summary
- A meta-learning objective favors solutions that are robust to frequent but sparse changes in the interventional distribution.
- A challenging aspect of this setting is not only to learn the causal graph structure, but also to predict the intervention accurately.
Task Description
- At most one intervention is performed at a time.
- Soft/imperfect interventions: the conditional distribution of the intervened variable is changed (rather than the variable being clamped to a fixed value).
- Provided data:
  - data from the original ground-truth model;
  - data from a modified ground-truth model with a random intervention applied.
- The learner knows that an intervention occurred, but not which node was hit (each run is an episode); see the toy episode sketched below.
- Over a large number of episodes, the learner experiences interventions on every node and should be able to infer the SCM from them.
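To make the episode structure concrete, here is a minimal runnable sketch of one episode on a toy two-variable ground truth ($A \to B$); the helper name and the particular probabilities are illustrative, not taken from the paper:

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_ground_truth(p_a, p_b_given_a, n):
    """Toy 2-variable SCM A -> B over binary variables (illustrative only)."""
    A = rng.random(n) < p_a
    B = rng.random(n) < np.where(A, p_b_given_a[1], p_b_given_a[0])
    return np.stack([A, B], axis=1).astype(int)

# Data from the original ground-truth model.
D_obs = sample_ground_truth(0.7, (0.2, 0.9), 1000)

# One episode: a soft intervention changes the conditional of one randomly
# chosen node; the learner sees D_int but not which node was intervened on.
target = rng.integers(2)
if target == 0:
    D_int = sample_ground_truth(0.1, (0.2, 0.9), 1000)  # new marginal for A
else:
    D_int = sample_ground_truth(0.7, (0.6, 0.6), 1000)  # new conditional for B
```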
Objectives
- Avoid an exponential search over all possible DAGs.
- Handle unknown interventions.
- Model the effect of interventions.
- Model the underlying causal structure.
- Function approximation: one neural network per variable.
- Belief over each edge $i \to j$: a drop-out probability for the $i$-th input of network $j$.
- Together, these beliefs represent all $2^{M^2}$ possible graphs.
- Learning the drop-out probabilities prevents a discrete search.
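A minimal sketch of how these beliefs define a distribution over graphs: each entry of an adjacency matrix is drawn independently from $\operatorname{Bernoulli}(\sigma(\gamma_{ij}))$. The helper below is a hypothetical illustration, not the paper's code:

```python
import numpy as np

def sample_configuration(gamma, rng):
    """Draw one adjacency matrix C from the edge beliefs.

    gamma : (M, M) array of structural meta-parameters; sigma(gamma[i, j])
    is the believed probability that the edge i -> j exists.
    """
    probs = 1.0 / (1.0 + np.exp(-gamma))              # sigma(gamma), elementwise
    C = (rng.random(gamma.shape) < probs).astype(np.int8)
    np.fill_diagonal(C, 0)                            # no self-loops
    return C

rng = np.random.default_rng(0)
gamma = np.zeros((3, 3))          # gamma = 0 means belief 0.5 for every edge
C = sample_configuration(gamma, rng)
```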
Model
- SCM with $M$ categorical random variables, each taking one of $N$ categories.
- Configuration: an adjacency matrix $C \in \{0,1\}^{M \times M}$.
- Element $c_{ij}$ of $C^n$ counts the number of length-$n$ walks from node $i$ to node $j$ in the graph.
- $\operatorname{Tr}(C^n)$ counts the number of length-$n$ cycles in the graph.
- $C = \,?$: the configuration itself is unknown and must be learned.
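The walk- and cycle-counting facts are standard linear algebra and easy to check numerically, e.g. on a directed 3-cycle:

```python
import numpy as np

# C[i, j] = 1 iff there is an edge i -> j; here a 3-cycle 0 -> 1 -> 2 -> 0.
C = np.array([[0, 1, 0],
              [0, 0, 1],
              [1, 0, 0]])

C3 = np.linalg.matrix_power(C, 3)
print(C3[0, 0])       # 1: the single length-3 walk from node 0 back to itself
print(np.trace(C3))   # 3: total number of length-3 closed walks in the graph
```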
Model
- Consider a two-node graph over $A, B$ ($M = 2$).
- The edge $A \to B$ exists iff $c_{AB} = 1$ (versus 0), and similarly for $c_{BA}$.
- Beliefs are parameterized as $P(c_{AB} = 1) = \sigma(\gamma_{AB})$ and $P(c_{BA} = 1) = \sigma(\gamma_{BA})$.
- The problem becomes simultaneously learning the structural meta-parameters $\gamma$ and the functional meta-parameters $\theta$.
- $\theta$: easily learned by maximum likelihood (backpropagation).
- $\gamma$: more difficult (Bengio et al., 2019).
Model
Problems
- An $M$-variable SCM over random variables $X_i$ can induce a super-exponential number of adjacency matrices $C$.
- The number of potential graph connectivity patterns explodes super-exponentially in $M$.
- The super-exponentially growing storage requirements of the defining conditional probability tables make CPT-based parametrization of the structural assignments $f_i$ increasingly unwieldy as $M$ scales.
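A quick numeric illustration of both blow-ups (the concrete numbers are an example, not from the slides):

```python
# With M variables of N categories each:
M, N = 8, 10

# 2^(M^2) binary adjacency matrices (including cyclic ones).
print(2 ** (M * M))             # 18446744073709551616

# CPT entries needed for one conditional P(X_i | parents) in the worst
# case where all M-1 other variables are parents.
k = M - 1
print(N ** k * (N - 1))         # 90000000
```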
Solution
- Neural networks with masked inputs (masks $c_{ij}$) provide a more manageable parametrization.
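A minimal sketch of what "masked inputs" means for one conditional: each candidate parent's encoding is multiplied by its mask bit before entering the MLP. The weight shapes and names here are illustrative assumptions, not the paper's implementation:

```python
import numpy as np

def masked_mlp_logits(X, mask, W1, b1, W2, b2):
    """Logits for one conditional P(X_i | masked parents).

    X    : (M, N) one-hot encodings of all M variables
    mask : (M,) binary vector of c_ij values for this network
    """
    inp = (mask[:, None] * X).reshape(-1)    # zero out non-parents, flatten
    h = np.maximum(0.0, W1 @ inp + b1)       # ReLU hidden layer
    return W2 @ h + b2                       # unnormalized logits over N categories
```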
Proposed Method
- Disentangle $\theta$:
  - $\theta_{\text{slow}}$: the slow-changing meta-parameters, which reflect the stationary properties discovered by the learner.
  - $\theta_{\text{fast}}$: the fast-changing parameters, which adapt in response to interventional changes in distribution.
Proposed Method
- Two kinds of meta-parameters: the causal graph structure $\gamma$ and the model's slow weights $\theta_{\text{slow}}$.
- In addition: the model's fast weights $\theta_{\text{fast}}$.
- $\theta = \theta_{\text{slow}} + \theta_{\text{fast}}$: the sum of the slow, stationary meta-parameters and the fast, adaptational parameters.
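A toy container showing the slow/fast split; the class and method names are illustrative (the paper applies this split per conditional MLP):

```python
import numpy as np

class Conditional:
    """Parameters of one variable's MLP, split into slow and fast parts."""

    def __init__(self, shape):
        self.theta_slow = np.zeros(shape)  # meta-parameters, kept across episodes
        self.theta_fast = np.zeros(shape)  # adaptation, reset every episode

    @property
    def theta(self):
        # The parameters actually used by the MLP.
        return self.theta_slow + self.theta_fast

    def reset_episode(self):
        # Interventions are not persistent across episodes.
        self.theta_fast[:] = 0.0
```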
Optimization Problem
- Treating all possible structural graphs as separate hypotheses is not feasible, because it would require maintaining $O(2^{M^2})$ models of the data.
- Instead, each $c_{ij}$ is sampled independently from a Bernoulli distribution.
- We then only need to learn the $M^2$ coefficients $\gamma_{ij}$.
- A slight dependency between the $c_{ij}$ is induced if we require the causal graph to be acyclic: a regularizer acting on $\gamma$.
Optimization Problem
- Each random variable is modeled as $X_i = f_{\theta_i}(c_{i0} X_0,\, c_{i1} X_1,\, \ldots,\, c_{iM} X_M,\, \epsilon_i)$, with $c_{ij} \sim \operatorname{Bernoulli}(\sigma(\gamma_{ij}))$.
- $f_{\theta_i}$ is a neural network (MLP) with parameters $\theta_i$.
- $\theta$ is optimized to maximize the likelihood of the data under the model.
- $\gamma$ is optimized with respect to a meta-learning objective arising from the changes in distribution caused by interventions.
- Analogous to an ensemble of neural nets differing in their binary input dropout masks, which select which variables are used as predictors of another variable.
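Putting the pieces together, a hedged sketch of scoring one sample under one sampled configuration $C$; the list of MLP callables and the row-wise masking convention follow the formula above and are assumptions, not the paper's API:

```python
import numpy as np

def log_likelihood(X, C, mlps):
    """Sum over variables of log P(X_i | masked inputs) under configuration C.

    X    : (M,) observed category indices for a single sample
    C    : (M, M) binary matrix; row i masks the inputs of network i
    mlps : list of M callables mapping (X, mask) -> logits over N categories
    """
    total = 0.0
    for i, mlp in enumerate(mlps):
        logits = mlp(X, C[i])
        logits = logits - logits.max()                 # stabilize the softmax
        logp = logits - np.log(np.exp(logits).sum())   # log-softmax
        total += logp[X[i]]
    return total
```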
Learning
- To disentangle:
  - the environment's stable, unchanging properties (the causal structure)
  - from its unstable, changing properties (the effects of an intervention).
- MLP conditionals: $P_i(X_i \mid X_{\operatorname{pa}(i)}; \theta_i)$, where $\theta = \theta_{\text{slow}} + \theta_{\text{fast}}$.
Learning
- $\theta_{\text{fast}}$ are reset after each episode of transfer-distribution adaptation, since an intervention is generally not persistent from one transfer distribution to another.
- The meta-parameters ($\theta_{\text{slow}}, \gamma$) are preserved and updated after each episode.
- The meta-objective for each meta-example over some intervention distribution $D_{\text{int}}$ is a meta-transfer loss, defined on the next slide.
Learning
- The meta-objective for each meta-example over some intervention distribution $D_{\text{int}}$ is a meta-transfer loss (the slide's formula is reconstructed below).
- $X$ is an example sampled from the intervention distribution $D_{\text{int}}$; $C$ is an adjacency matrix drawn from our belief distribution (parametrized by $\gamma$) over graph-structure configurations.
- The loss involves the likelihood of the $i$-th variable $X_i$ of the sample $X$ when predicting it under the configuration $C$ from the set of its putative parents.
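The loss formula itself did not survive the slide export. Based on the definitions just given and the meta-transfer objective of Bengio et al. (2019), a plausible reconstruction (an assumption, not a quote from the slides) is:

$$\mathcal{R}(D_{\text{int}}) = -\,\mathbb{E}_{X \sim D_{\text{int}}}\!\left[\log\, \mathbb{E}_{C \sim \operatorname{Bernoulli}(\sigma(\gamma))} \prod_{i=1}^{M} P_i\!\left(X_i \mid X_{\operatorname{pa}(i;C)};\, \theta\right)\right]$$

i.e. the negative log of the configuration-averaged likelihood of interventional samples.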
Learning
- A discrete Bernoulli sampling process produces the configurations $C$ under which the log-likelihood of the data samples is obtained.
- A gradient estimator is therefore required to propagate gradients through to the structural meta-parameters $\gamma$.
- The superscript $(k)$ indicates the values obtained for the $k$-th draw of $C$.
- This gradient is estimated solely with $\theta_{\text{slow}}$, because estimates employing $\theta$ have much greater variance.
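A sketch of such an estimator in the spirit of the paper: a likelihood-weighted score-function update for a single $\gamma_{ij}$, where draws that explained the data well pull $\sigma(\gamma_{ij})$ toward their value under gradient descent. The exact weighting used in the paper may differ in detail:

```python
import numpy as np

def gamma_gradient(gamma_ij, c_draws, log_liks):
    """Gradient estimate (to be descended) for one structural parameter.

    c_draws  : (K,) binary draws c_ij^(k) ~ Bernoulli(sigma(gamma_ij))
    log_liks : (K,) log-likelihood of the relevant variable under each draw,
               computed with theta_slow only (lower variance, per the slides)
    """
    sig = 1.0 / (1.0 + np.exp(-gamma_ij))
    w = np.exp(log_liks - log_liks.max())   # likelihood weights, stabilized
    w /= w.sum()
    # Descending this pulls sigma(gamma_ij) toward the well-performing draws.
    return np.sum(w * (sig - c_draws))
```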
Acyclic Constraint
Predicting Interventions
- After an intervention on $X_i$, the gradients into $\gamma_i$ and the slow weights of the $i$-th conditional are false, because they do not bear the blame for $X_i$'s outcome (which lies with the intervener).
- The conditional likelihood of the intervened variable tends to be relatively poorer under $D_{\text{int}}$.
- Hence, the variable with the greatest deterioration in likelihood is picked as a good guess.
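A minimal sketch of this decision rule; comparing against observational likelihoods as a baseline is an illustrative choice, since the slides only specify "greatest deterioration":

```python
import numpy as np

def predict_intervened_node(loglik_int, loglik_obs):
    """Guess the intervened variable.

    loglik_int : (M,) per-variable average log-likelihood on D_int samples
    loglik_obs : (M,) the same averages on observational data
    """
    # The variable whose conditional deteriorates most is the best guess.
    return int(np.argmin(loglik_int - loglik_obs))
```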
Model Description
- The MLPs are identical in shape but share no parameters, since they model independent causal mechanisms.
- Inputs: $M$ one-hot vectors (one per node) of length $N$ each.
- $c_{ij} \sim \operatorname{Bernoulli}(\sigma(\gamma_{ij}))$
- $X_i = f_{\theta_i}(c_{i0} X_0,\, c_{i1} X_1,\, \ldots,\, c_{iM} X_M,\, \epsilon_i)$
Stability of Training
- The structural and the functional meta-parameters are trained simultaneously.
- They are not independent and do influence each other, which leads to instability in training.
- Remedy: pre-train the model on observational data (the data distribution before interventions) using dropout on the inputs.
- This way, the functional meta-parameters $\theta_{\text{slow}}$ are not too biased towards particular configurations of the structural meta-parameters $\gamma$.
Regularizers
- DAG constraint: a penalty encouraging acyclicity of the belief graph (see the sketch after this list).
- Sparsity:
  - a sparse representation of edges in the causal graph;
  - an L1 regularizer;
  - slightly faster convergence.
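A sketch of both regularizers on the edge beliefs $\sigma(\gamma)$. The exponential-trace form of the DAG penalty (NOTEARS, Zheng et al., 2018) is suggested by the earlier $\operatorname{Tr}(C^n)$ cycle-counting discussion, since $\operatorname{tr}(e^A) - M = \sum_{n \ge 1} \operatorname{tr}(A^n)/n!$; whether the paper uses exactly this form is an assumption:

```python
import numpy as np
from scipy.linalg import expm

def dag_penalty(gamma):
    """NOTEARS-style soft acyclicity penalty on the edge beliefs.

    tr(exp(A)) - M is a factorially weighted count of cycles of every
    length, so minimizing it pushes the belief graph toward acyclicity.
    """
    A = 1.0 / (1.0 + np.exp(-gamma))   # sigma(gamma)
    np.fill_diagonal(A, 0.0)           # ignore self-loops
    return np.trace(expm(A)) - A.shape[0]

def sparsity_penalty(gamma):
    """L1 penalty; sigma(gamma) is nonnegative, so its sum is the L1 norm."""
    return (1.0 / (1.0 + np.exp(-gamma))).sum()
```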
Temperature
- A temperature hyperparameter encourages the ground-truth model to generate some very rare events in the conditional probability tables (CPTs) more frequently.
- The near-ground-truth MLP model's logit outputs are divided by the temperature before being used for sampling.
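A generic sketch of the mechanism (standard temperature sampling, not code from the paper):

```python
import numpy as np

def sample_with_temperature(logits, T, rng):
    """Sample a category after dividing the logits by temperature T.

    T > 1 flattens the distribution, so rare CPT events appear more often;
    T < 1 sharpens it.
    """
    z = logits / T
    p = np.exp(z - z.max())            # stabilized softmax
    p /= p.sum()
    return rng.choice(len(p), p=p)

rng = np.random.default_rng(0)
print(sample_with_temperature(np.array([2.0, 0.0, -3.0]), T=2.0, rng=rng))
```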
All in all: Algorithm
(Algorithm flowchart; recoverable labels: "Pre-train on Observational Data", "Predict Intervened Node".)
Simulations: Synthetic Data
The results are, however, sensitive to some hyperparameters, notably the DAG penalty and the sparsity penalty.
Simulations: Real-World Data
- BNLearn networks: Earthquake, Cancer, Asia.
- $M = 5, 5, 8$ variables respectively, with a maximum of 2 parents per node.
- A near-ground-truth MLP is learned from each dataset's CPTs and used as the ground-truth data generator.
- It shares the same causal graph, but its CPTs differ; hence it is a different SCM.
Simulations: Comparison
- Peters et al. (2016): ICP.
- Eaton & Murphy (2007a): uncertain interventions.
- Peters et al. (2016): unknown interventions.
- However, none of these methods attempts to predict the intervention.
Importance of Dropout
- Dropout is used for the initial training on observational data.
- Without it, pre-training uses the fully-connected off-diagonal graph, the configuration with the most degrees of freedom.
- Pre-training cannot be carried out effectively this way.
Importance of Intervention Prediction
- After the intervention has been performed, the learner draws data samples from the intervention distribution and computes the per-variable average log-probability under sampled adjacency matrices.
- The variable consistently producing the least-likely outputs is predicted to be the intervened node.