Learning Neural Causal Models From Unknown Interventions


  1. Learning Neural Causal Models From Unknown Interventions

  2. Summary • The relationship between each variable and its parents is modeled by a neural network, modulated by structural meta-parameters which capture the overall topology of a directed graphical model. • A random intervention is assumed on a single unknown variable of an unknown ground-truth causal model. • To disentangle the slow-changing aspects of each conditional from the fast-changing adaptations to each intervention, the neural network is parameterized into fast parameters and slow meta-parameters.

  3. Summary • A meta-learning objective that favors solutions robust to frequent but sparse interventional distribution change. • A challenging aspect of this setting is not only to learn the causal graph structure, but also to predict the intervention accurately.

  4. Task Description • At most one intervention is performed at a time. • Soft/imperfect interventions: the conditional distribution of the variable on which the intervention is performed is changed. • Provided data: data from the original ground-truth model, and data from a modified ground-truth model with a random intervention applied. • The learner is aware that an intervention has occurred, but not which node was intervened upon (each run is an episode). • Over a large number of episodes, the learner will experience all nodes being intervened upon, and should be able to infer the SCM from these interventions.

  5. Objectives • Avoid an exponential search over all possible DAGs. • Handle unknown interventions. • Model the effect of interventions. • Model the underlying causal structure.

  6. Model • Function approximation: one neural network per variable. • Belief over each edge i → j: the drop-out probability for the i-th input of the network for node j. • Represents all 2^(M²) possible graphs. • Learning the drop-out probabilities prevents a discrete search.

  7. Model • SCM with M categorical random variables, each with N categories. • Configuration C ∈ {0,1}^(M×M). • Element (i, j) of C^n counts the number of length-n walks from node i to node j in the graph. • Tr(C^n) counts the number of length-n cycles in the graph (see the sketch below). • C = ?
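The walk/cycle-counting property above is easy to verify numerically. A minimal NumPy sketch; the 3-node graph is a hypothetical example, not one from the slides:

```python
import numpy as np

# Hypothetical 3-node graph containing the cycle 0 -> 1 -> 2 -> 0.
C = np.array([[0, 1, 0],
              [0, 0, 1],
              [1, 0, 0]])

C3 = np.linalg.matrix_power(C, 3)  # (C^3)[i, j] counts length-3 walks from i to j
print(C3)                          # identity matrix: one length-3 walk from each node to itself
print(np.trace(C3))                # 3, the number of length-3 cycles (one rooted at each node)
```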

  8. Model • Consider a two-node graph with variables A and B (M = 2). • The possible edges are c_AB = 1 or c_BA = 1 (versus 0). • P(c_AB = 1) = σ(γ_AB) and P(c_BA = 1) = σ(γ_BA). • The problem becomes simultaneously learning the structural meta-parameters γ and the functional meta-parameters θ. • θ: easily learned by maximum likelihood (backpropagation). • γ: more difficult (Bengio et al., 2019).

  9. Problems • An M-variable SCM over random variables X_i can induce a super-exponential number of adjacency matrices C. • This super-exponential explosion in the number of potential graph connectivity patterns, together with the super-exponentially growing storage requirements of the defining conditional probability tables, makes CPT-based parametrization of the structural assignments f_i increasingly unwieldy as M scales.

  10. Solution • Neural networks with c_ij-masked inputs can provide a more manageable parametrization.

  11. Proposed Method • Disentangle θ: • θ_slow: the slow-changing meta-parameters, which reflect the stationary properties discovered by the learner. • θ_fast: the fast-changing parameters, which adapt in response to interventional changes in distribution.

  12. Proposed Method • Two kinds of meta-parameters: the causal graph structure γ and the model's slow weights θ_slow. • The model's fast weights θ_fast. • θ = θ_slow + θ_fast: the sum of the slow, stationary meta-parameters and the fast, adaptational parameters (see the sketch below).
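As a tiny illustration of this decomposition (a sketch only; the tensor shape is arbitrary), the weights the conditional MLP actually uses are the sum of persistent slow parameters and per-episode fast parameters:

```python
import torch

theta_slow = torch.randn(10, requires_grad=True)  # persists and is meta-updated across episodes
theta_fast = torch.zeros(10, requires_grad=True)  # reset at the start of every episode
theta = theta_slow + theta_fast                   # parameters actually used by the conditional MLP
```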

  13. Optimization Problem • Considering all possible structural graphs as separate hypotheses is not feasible, because it would require maintaining O(2^(M²)) models of the data. • Instead, each c_ij is sampled independently from a Bernoulli distribution. • We then only need to learn the M² coefficients γ_ij. • A slight dependency between the c_ij is induced if we require the causal graph to be acyclic: a regularizer acting on the γ.

  14. Optimization Problem • Each random variable is modeled as X_i = f_θi(c_i0 × X_0, c_i1 × X_1, …, c_iM × X_M, ε_i). • f_θi is a neural network (MLP) with parameters θ_i, and c_ij ∼ Bernoulli(σ(γ_ij)). • θ is optimized to maximize the likelihood of the data under the model. • γ is optimized with respect to a meta-learning objective arising from changes in distribution caused by interventions. • Analogous to an ensemble of neural nets differing by their binary input dropout masks, which select which variables are used as predictors of another variable (see the sketch below).
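A minimal PyTorch sketch of one such conditional, assuming categorical variables encoded as one-hot vectors; the class and variable names (ConditionalMLP, hidden, etc.) are illustrative choices, not the authors' code:

```python
import torch
import torch.nn as nn

class ConditionalMLP(nn.Module):
    """Models P(X_i | masked parents): all M one-hot variables are fed in,
    each multiplied by a binary mask entry c_ij ~ Bernoulli(sigmoid(gamma_ij))."""
    def __init__(self, M, N, hidden=32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(M * N, hidden), nn.LeakyReLU(),
            nn.Linear(hidden, N),              # logits over the N categories of X_i
        )

    def forward(self, X_onehot, c_i):
        # X_onehot: (batch, M, N) one-hot encodings of all variables
        # c_i:      (M,) sampled adjacency row for variable i
        masked = X_onehot * c_i.view(1, -1, 1)  # zero out non-parents
        return self.net(masked.flatten(1))

M, N = 5, 10
gamma_i = torch.zeros(M)                        # structural meta-parameters (row i)
c_i = torch.bernoulli(torch.sigmoid(gamma_i))   # c_ij in {0, 1}
x = torch.eye(N)[torch.randint(N, (4, M))]      # a batch of 4 random one-hot samples
logits = ConditionalMLP(M, N)(x, c_i)           # (4, N) logits for X_i
```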

  15. Learning • To disentangle: the environment's stable, unchanging properties (the causal structure) from its unstable, changing properties (the effects of an intervention). • Each conditional is an MLP P_i(X_i | X_pa(i); θ_i), where θ = θ_slow + θ_fast.

  16. Learning • θ_fast are reset after each episode of transfer-distribution adaptation, since an intervention is generally not persistent from one transfer distribution to another. • The meta-parameters (θ_slow, γ) are preserved, then updated after each episode. • The meta-objective for each meta-example over some intervention distribution D_int is the following:

  17. Learning • The meta-objective for each meta-example over some intervention distribution D_int is the following meta-transfer loss: • X is an example sampled from the intervention distribution D_int, and C is an adjacency matrix drawn from our belief distribution (parametrized by γ) over graph-structure configurations. • The loss involves the likelihood of the i-th variable X_i of the sample X when predicting it under the configuration C from the set of its putative parents.

  18. Learning • A discrete Bernoulli random sampling process is used to produce the configurations under which the log-likelihood of data samples is obtained. • A gradient estimator is therefore required to propagate gradients through to the structural meta-parameters γ (see the sketch below). • The superscript (k) indicates the values obtained for the k-th draw of C. • This gradient is estimated solely with θ_slow, because estimates employing θ have much greater variance.
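As a hedged illustration, a score-function-style estimator in the spirit of Bengio et al. (2019), cited on slide 8, could look as follows; the exact weighting used by the authors may differ, and the function and array names are assumptions:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gamma_gradient(gamma, masks, loglik):
    """gamma:  (M, M) structural meta-parameters
       masks:  (K, M, M) drawn configurations C^(k), entries c_ij in {0, 1}
       loglik: (K, M) per-variable log-likelihood of the intervention data under
               each configuration, computed with theta_slow only (lower variance)."""
    lik = np.exp(loglik - loglik.max(axis=0, keepdims=True))  # stabilised likelihoods
    w = lik / lik.sum(axis=0, keepdims=True)                  # normalise over the K draws
    # Edge-belief gradient: (sigma(gamma_ij) - c_ij^(k)), weighted by how well
    # variable i was predicted under the k-th drawn configuration.
    return np.einsum('kij,ki->ij', sigmoid(gamma)[None] - masks, w)
```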

  19. Acyclic Constraint

  20. Acyclic Constraint

  21. Predicting Interventions • After an intervention on X_i, the gradients into γ_i and the slow weights of the i-th conditional are misleading, because they do not bear the blame for X_i's outcome (which lies with the intervener). • The conditional likelihood of the intervened variable tends to be relatively poorer under D_int. • Hence, the variable with the greatest deterioration in likelihood is picked as a good guess.

  22. Model Description • The MLPs are identical in shape but do not share any parameters, since they model independent causal mechanisms. • Inputs: M one-hot vectors (one per node) of length N each. • X_i = f_θi(c_i0 × X_0, c_i1 × X_1, …, c_iM × X_M, ε_i), with c_ij ∼ Bernoulli(σ(γ_ij)).

  23. Stability of Training • The structural and the functional meta-parameters are trained simultaneously. • They are not independent and influence each other, which leads to instability in training. • Therefore, pre-train the model on observational data (the distribution before interventions) using dropout on the inputs, • so that the functional meta-parameters θ_slow are not too biased towards particular configurations of the structural meta-parameters γ.

  24. Regularizers • DAG constraint. • Sparsity: an L1 regularizer encourages a sparse representation of edges in the causal graph and gives slightly faster convergence (see the sketch below).
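The slides do not reproduce the exact penalties, so the sketch below should be read as one plausible differentiable form: a trace-of-matrix-exponential acyclicity penalty (consistent with the cycle-counting argument on slide 7) plus an L1 term on the edge beliefs; the lambda values are placeholders:

```python
import torch

def regularizers(gamma, lambda_dag=1.0, lambda_sparse=0.1):
    p = torch.sigmoid(gamma)                       # (M, M) edge-belief probabilities
    p = p * (1 - torch.eye(p.shape[0]))            # ignore self-loops
    # trace(exp(P)) - M sums weighted cycle counts of all lengths; it is 0 iff P is acyclic
    dag_penalty = torch.trace(torch.matrix_exp(p)) - p.shape[0]
    sparsity = p.sum()                             # L1 on edge beliefs (entries are >= 0)
    return lambda_dag * dag_penalty + lambda_sparse * sparsity
```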

  25. Temperature • A temperature hyperparameter is used to make the ground-truth model generate some very rare events in the conditional probability tables (CPTs) more frequently. • The near-ground-truth MLP model's logit outputs are divided by the temperature before being used for sampling (see the sketch below).
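A minimal sketch of this trick; the logits and temperature value below are made up for illustration:

```python
import torch

def sample_with_temperature(logits, T):
    probs = torch.softmax(logits / T, dim=-1)    # T > 1 flattens the CPT, boosting rare events
    return torch.multinomial(probs, num_samples=1)

logits = torch.tensor([[6.0, 2.0, -3.0]])        # the third category is very rare at T = 1
print(sample_with_temperature(logits, T=5.0))    # rare category now sampled noticeably more often
```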

  26. All in all: Algorithm • Pre-train on observational data. • Predict the intervened node.

  27. Simulations: Synthetic Data • The results are, however, sensitive to some hyperparameters, notably the DAG penalty and the sparsity penalty.

  28. Simulations: Real-World Data • BNLearn datasets: Earthquake, Cancer, Asia. • M = 5, 5, 8 variables respectively, with a maximum of 2 parents per node. • Learn a near-ground-truth MLP from each dataset's CPTs and use it as the ground-truth data generator. • Despite sharing the same causal graph, the CPTs were different; hence different SCMs.

  29. Simulations: Comparison • Peters et al. (2016): ICP. • Eaton & Murphy (2007a): uncertain interventions. • Peters et al. (2016): unknown interventions. • However, none of these methods attempts to predict the intervention.

  30. Importance of Dropout • Dropout is used for the initial training on observational data. • Without it, the inputs are fully connected off-diagonal (the most degrees of freedom). • Pre-training cannot be carried out this way.

  31. Importance of Intervention Prediction • After the intervention has been performed, the learner draws data samples from the intervention distribution and computes the per-variable average log-probability under sampled adjacency matrices. • The variable consistently producing the least-likely outputs is predicted to be the intervention node.
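A NumPy sketch of this prediction rule; the array shapes and names are assumptions:

```python
import numpy as np

def predict_intervened_node(loglik):
    """loglik: (S, K, M) log P(X_i | putative parents) for S samples drawn from the
       intervention distribution, K sampled adjacency matrices, and M variables."""
    per_variable = loglik.mean(axis=(0, 1))   # average over samples and sampled graphs
    return int(np.argmin(per_variable))       # least-likely variable is the guess

# Toy check: make variable 3 fit poorly, as it would after an unmodeled intervention.
loglik = np.random.randn(100, 20, 5) - 0.5
loglik[:, :, 3] -= 2.0
print(predict_intervened_node(loglik))        # -> 3
```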

  32. Importance of Intervention Prediction
