
SLIDE 1

Learning Neural Causal Models From Unknown Interventions

SLIDE 2

Summary

  • The relationship between each variable and its parents is modeled by a neural network, modulated by structural meta-parameters which capture the overall topology of a directed graphical model.
  • Assumes a random intervention on a single unknown variable of an unknown ground-truth causal model.
  • To disentangle the slow-changing aspects of each conditional from the fast-changing adaptations to each intervention, the neural network is parameterized into fast parameters and slow meta-parameters.

SLIDE 3

Summary

  • A meta-learning objective that favors solutions robust to frequent but sparse interventional distribution changes.
  • A challenging aspect of this setting is to not only learn the causal graph structure, but also predict the intervention accurately.

SLIDE 4

Task Description

  • At most one intervention is performed at a time.
  • Soft/imperfect interventions: the conditional distribution of the variable on which the intervention is performed is changed.
  • Provided data:
  • data from the original ground-truth model.
  • data from a modified ground-truth model with a random intervention applied.
  • The learner is aware that an intervention occurred, but not of which node was intervened on (each run is an episode).
  • Over a large number of episodes, the learner will experience all nodes being intervened upon, and should be able to infer the SCM from these interventions.

SLIDE 5

Objectives

  • Avoid an exponential search over all possible DAGs
  • Handle unknown interventions
  • Model the effect of interventions
  • Model the underlying causal structure.
SLIDE 6

Model

  • Function approximation: one neural network per variable.
  • Belief over edge i → j: the drop-out probability for the i-th input of network j.
  • Represents all 2^(M²) possible graphs.
  • Learning the drop-out probabilities prevents a discrete search (see the sketch below).
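
A minimal sketch (not the authors' code; names are illustrative) of how an M × M matrix of structural meta-parameters γ represents a belief over all 2^(M²) graphs: each entry gives an independent edge probability σ(γ_ij), from which binary input-dropout masks are sampled.

    import numpy as np

    rng = np.random.default_rng(0)
    M = 3                                    # number of variables
    gamma = rng.normal(size=(M, M))          # structural meta-parameters (edge logits)

    def edge_probs(gamma):
        """P(c_ij = 1) = sigmoid(gamma_ij): belief that edge i -> j is present."""
        return 1.0 / (1.0 + np.exp(-gamma))

    def sample_mask(gamma, rng):
        """Draw one dropout/adjacency configuration from the edge beliefs."""
        return (rng.random(gamma.shape) < edge_probs(gamma)).astype(np.float32)

    C = sample_mask(gamma, rng)              # one of the 2^(M*M) possible graphs
    print(edge_probs(gamma).round(2))
    print(C)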

SLIDE 7

Model

  • SCM (M categorical random variables and N categories):
  • Configuration C ∈ {0,1}^(M×M).
  • The (i, j) element of C^n counts the number of length-n walks from node i to node j of the graph.
  • Tr(C^n) counts the number of length-n cycles in the graph (worked example below).
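
A small worked example (numpy, values illustrative) of the walk- and cycle-counting facts above: the trace of C^n is zero for every n on a DAG and positive when the graph contains a cycle.

    import numpy as np

    # Acyclic 3-node graph: 0 -> 1 -> 2
    C_dag = np.array([[0, 1, 0],
                      [0, 0, 1],
                      [0, 0, 0]])

    # Cyclic 3-node graph: 0 -> 1 -> 2 -> 0
    C_cyc = np.array([[0, 1, 0],
                      [0, 0, 1],
                      [1, 0, 0]])

    for n in range(1, 4):
        Wd = np.linalg.matrix_power(C_dag, n)   # (i, j) entry: number of length-n walks i -> j
        Wc = np.linalg.matrix_power(C_cyc, n)
        print(n, np.trace(Wd), np.trace(Wc))    # Tr is 0 for the DAG; Tr(C^3) = 3 for the 3-cycle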

SLIDE 8

Model

  • Consider a two-node graph over A, B (M = 2).
  • Edges: c_AB and c_BA, with c_AB = 1 (versus 0) indicating that the edge A → B is present.
  • P(c_AB = 1) = σ(γ_AB) and P(c_BA = 1) = σ(γ_BA).
  • The problem becomes simultaneously learning the structural meta-parameters γ and the functional meta-parameters θ.
  • θ: easily learned by maximum likelihood (back-propagation).
  • γ: more difficult (Bengio et al. 2019).

SLIDE 9

Problems

  • An M-variable SCM over random variables X_i can induce a super-exponential number of adjacency matrices C.
  • The super-exponential explosion in the number of potential graph connectivity patterns, and
  • the super-exponentially growing storage requirements of their defining conditional probability tables, make CPT-based parametrization of the structural assignments f_i increasingly unwieldy as M scales.

SLIDE 10

Solution

  • Neural networks with inputs masked by c_ij can provide a more manageable parametrization.

SLIDE 11

Proposed Method

  • Disentangle θ:
  • θ_slow: the slow-changing meta-parameters, which reflect the stationary properties discovered by the learner.
  • θ_fast: the fast-changing parameters, which adapt in response to interventional changes in distribution.

SLIDE 12

Proposed Method

  • Two kinds of meta-parameters: the causal graph structure γ and the model’s slow weights θ_slow.
  • The model’s fast weights θ_fast.
  • θ = θ_slow + θ_fast, the sum of the slow, stationary meta-parameters and the fast, adaptational parameters (see the sketch below).
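
A minimal sketch (assumed structure, not the authors' code) of this decomposition: the effective weights are the sum of the slow meta-parameters and the fast parameters, and the fast part is reset at the start of each episode while θ_slow and γ persist (Slide 16).

    import numpy as np

    class SlowFastParam:
        """One weight tensor stored as a slow component plus a fast component."""

        def __init__(self, shape, rng):
            self.slow = 0.1 * rng.normal(size=shape)   # θ_slow: preserved across episodes
            self.fast = np.zeros(shape)                # θ_fast: adapts within an episode

        def value(self):
            return self.slow + self.fast               # θ = θ_slow + θ_fast

        def reset_fast(self):
            self.fast[:] = 0.0                         # called at the start of every episode

    rng = np.random.default_rng(0)
    w = SlowFastParam((4, 4), rng)
    w.fast += 0.05        # stand-in for adaptation to an intervention distribution
    w.reset_fast()        # new episode: the adaptation is discarded, θ_slow is kept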

SLIDE 13

Optimization Problem

  • The strategy of considering all the possible structural graphs as separate hypotheses is not feasible, because it would require maintaining O(2^(M²)) models of the data.
  • Sample each c_ij independently from a Bernoulli distribution.
  • We only need to learn the M² coefficients γ_ij.
  • A slight dependency between the c_ij is induced if we require the causal graph to be acyclic.
  • A regularizer acting on the γ.

SLIDE 14

Optimization Problem

  • Each random variable X_i = f_θi(c_i0 × X_0, c_i1 × X_1, …, c_im × X_m, ε_i).
  • f_θi is a neural network (MLP) with parameters θ_i, and c_ij ∼ Bernoulli(σ(γ_ij)).
  • θ: optimized to maximize the likelihood of the data under the model.
  • γ: optimized with respect to a meta-learning objective arising from changes in distribution caused by interventions.
  • Analogous to an ensemble of neural nets differing by their binary input dropout masks, which select which variables are used as predictors of another variable (see the sketch below).
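
A sketch of one conditional f_θi (assumed architecture: a single hidden layer of size H; the noise input ε_i is omitted): the inputs are the one-hot codes of all variables, each gated by the sampled mask entry, exactly like binary input dropout.

    import numpy as np

    rng = np.random.default_rng(0)
    M, N, H = 3, 4, 16            # variables, categories per variable, hidden units (assumed)

    def one_hot(value, N):
        return np.eye(N)[value]

    def softmax(z):
        z = z - z.max()
        e = np.exp(z)
        return e / e.sum()

    # Parameters θ_i of the MLP predicting variable X_i.
    W1 = 0.1 * rng.normal(size=(H, M * N))
    W2 = 0.1 * rng.normal(size=(N, H))

    def conditional_probs(i, x, c_row):
        """P_i(X_i | masked parents): every variable's one-hot code is multiplied by its
        mask entry c_ij before entering the network; the variable's own slot is zeroed."""
        inp = np.concatenate([c_row[j] * one_hot(x[j], N) for j in range(M)])
        inp[i * N:(i + 1) * N] = 0.0
        h = np.tanh(W1 @ inp)
        return softmax(W2 @ h)            # categorical distribution over X_i's N values

    x = np.array([0, 2, 1])               # one sample of the M variables
    c_row = np.array([0.0, 1.0, 1.0])     # sampled mask row: X_1 and X_2 feed into X_0
    print(conditional_probs(0, x, c_row))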

SLIDE 15

Learning

  • To disentangle:
  • the environment’s stable, unchanging properties (the causal structure)
  • from its unstable, changing properties (the effects of an intervention).
  • MLP: P_i(X_i | X_pa(i); θ_i), where θ = θ_slow + θ_fast.

SLIDE 16

Learning

  • θ_fast are reset after each episode of transfer-distribution adaptation, since an intervention is generally not persistent from one transfer distribution to another.
  • The meta-parameters (θ_slow, γ) are preserved, then updated after each episode.
  • The meta-objective for each meta-example over some intervention distribution D_int is the following:

SLIDE 17

Learning

  • The meta-objective for each meta-example over some intervention distribution D_int is the following meta-transfer loss (a reconstruction is sketched below).
  • X is an example sampled from the intervention distribution D_int; C is an adjacency matrix drawn from our belief distribution (parametrized by γ) about graph-structure configurations.
  • The likelihood of the i-th variable X_i of the sample X when predicting it under the configuration C from the set of its putative parents.
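
One plausible way to write this meta-transfer loss, consistent with the bullets above and with the gradient estimator on Slide 18 (the notation is assumed, not copied from the paper):

    \mathcal{L}_{C,i}(X) \;=\; P_i\big(X_i \mid X_{\mathrm{pa}(i;C)};\ \theta_i\big)

    \mathcal{R}(D_{\mathrm{int}}) \;=\; -\sum_{i=1}^{M} \log\,
        \mathbb{E}_{C \sim \mathrm{Ber}(\sigma(\gamma))}\big[\,\mathcal{L}_{C,i}(X)\,\big],
        \qquad X \sim D_{\mathrm{int}}

In practice the expectation over C is approximated with K Bernoulli draws C^(1), …, C^(K), as described on Slide 18.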

SLIDE 18

Learning

  • A discrete Bernoulli random sampling process is used to produce the configurations under which the log-likelihood of data samples is obtained.
  • A gradient estimator is required to propagate gradients through to the structural meta-parameters γ (a sketch follows).
  • The superscript (k) indicates the values obtained for the k-th draw of C.
  • This gradient is estimated solely with θ_slow, because estimates employing θ have much greater variance.
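
A sketch of one such estimator, written in the spirit of the bullets above and of the loss reconstructed on Slide 17 (the paper's exact estimator may differ; the per-variable likelihoods are assumed to be computed with θ_slow):

    import numpy as np

    def sigmoid(z):
        return 1.0 / (1.0 + np.exp(-z))

    def gamma_gradient(gamma, C_samples, lik):
        """Likelihood-weighted score-function estimate of the structural gradient.

        gamma     : (M, M) structural meta-parameters
        C_samples : (K, M, M) binary masks; C_samples[k] is the k-th draw C^(k)
        lik       : (K, M) per-variable likelihoods L^(k)_i under C^(k), using θ_slow
        Entry (i, j) is weighted by how well variable i was predicted under each draw."""
        K = C_samples.shape[0]
        probs = sigmoid(gamma)                               # σ(γ_ij)
        num = np.zeros_like(gamma)
        den = lik.sum(axis=0) + 1e-12                        # Σ_k L^(k)_i for each variable i
        for k in range(K):
            num += (probs - C_samples[k]) * lik[k][:, None]  # (σ(γ_ij) − c^(k)_ij) · L^(k)_i
        return num / den[:, None]

    rng = np.random.default_rng(0)
    M, K = 3, 5
    gamma = rng.normal(size=(M, M))
    C_samples = (rng.random((K, M, M)) < sigmoid(gamma)).astype(float)
    lik = rng.random((K, M))                                 # stand-in per-variable likelihoods
    print(gamma_gradient(gamma, C_samples, lik))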

SLIDE 19

Acyclic Constraint

SLIDE 20

Acyclic Constraint
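
A standard way to turn the cycle-counting fact from Slide 7 into a differentiable penalty on the edge beliefs, shown here as an assumed sketch rather than the paper's exact regularizer:

    \mathcal{R}_{\mathrm{DAG}}(\gamma) \;=\; \mathrm{Tr}\big(e^{A}\big) - M
        \;=\; \sum_{n \ge 1} \frac{\mathrm{Tr}\big(A^{n}\big)}{n!},
        \qquad A_{ij} = \sigma(\gamma_{ij}),\ A_{ii} = 0

Since Tr(A^n) accumulates the (soft) length-n cycles of the belief matrix, this term vanishes only for cycle-free edge beliefs; weighted by a coefficient, it can be added to the loss alongside the L1 sparsity term of Slide 24.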

SLIDE 21

Predicting Interventions

  • After an intervention on X_i, the gradients into γ_i and the slow weights of the i-th conditional are spurious, because they do not bear the blame for X_i’s outcome (which lies with the intervener).
  • The conditional likelihood of the intervened variable tends to be relatively poorer under D_int.
  • Hence, the variable with the greatest deterioration in likelihood is picked as the guess for the intervened node.

SLIDE 22

Model Description

  • The MLPs are identical in shape but do not share any parameters, since they are modeling independent causal mechanisms.
  • Inputs: M one-hot vectors (one per node) of length N each.
  • c_ij ∼ Bernoulli(σ(γ_ij))
  • X_i = f_θi(c_i0 × X_0, c_i1 × X_1, …, c_im × X_m, ε_i)

SLIDE 23

Stability of Training

  • Simultaneous training of both the structural and the functional meta-parameters.
  • These are not independent and do influence each other, which leads to instability in training.
  • Pre-train the model on observational data (from the distribution of the data before interventions) using dropout on the inputs (sketched below).
  • This way the functional meta-parameters θ_slow are not too biased towards certain configurations of the structural meta-parameters γ.
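
A sketch of this pre-training scheme (the dropout rate of 0.5 and the uniform Bernoulli masks are assumptions, not values from the paper):

    import numpy as np

    rng = np.random.default_rng(0)
    M = 3
    DROP_P = 0.5          # assumed input-dropout rate for observational pre-training

    def pretrain_masks(n_steps, rng):
        """During observational pre-training the input masks are drawn at random rather
        than from σ(γ), so θ_slow sees many parent configurations and is not biased
        towards any particular graph."""
        for _ in range(n_steps):
            c = (rng.random((M, M)) > DROP_P).astype(np.float32)
            np.fill_diagonal(c, 0.0)      # a variable never predicts itself
            yield c                       # gate the MLP inputs with c for this batch

    for step, mask in enumerate(pretrain_masks(2, rng)):
        print(step, mask)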

SLIDE 24

Regularizers

  • DAG constraint.
  • Sparsity: encourages a sparse representation of edges in the causal graph.
  • L1 regularizer.
  • Slightly faster convergence.
SLIDE 25

Temperature

  • A temperature hyperparameter is used to encourage the ground-truth model to generate some very rare events in the conditional probability tables (CPTs) more frequently.
  • The near-ground-truth MLP model’s logit outputs are divided by the temperature before being used for sampling (see the snippet below).
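
A minimal illustration of this temperature scaling (the logits are illustrative, not taken from the paper):

    import numpy as np

    rng = np.random.default_rng(0)

    def sample_with_temperature(logits, T, rng):
        """Divide the logits by T before the softmax; T > 1 flattens the distribution,
        so rare categories are sampled more often."""
        z = logits / T
        z = z - z.max()
        p = np.exp(z) / np.exp(z).sum()
        return rng.choice(len(p), p=p), p

    logits = np.array([4.0, 0.0, -2.0])
    for T in (1.0, 2.0):
        idx, p = sample_with_temperature(logits, T, rng)
        print(T, p.round(3), idx)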

SLIDE 26

All in all: Algorithm

[Algorithm overview flowchart, from pre-training on observational data to predicting the intervened node]

SLIDE 27

Simulations: Synthetic Data

The results are, however, sensitive to some hyperparameters, notably the DAG penalty and the sparsity penalty.

SLIDE 28

Simulations: Real-World Data

  • BNLearn: Earthquake, Cancer, Asia
  • M = 5, 5, 8 variables respectively, maximum 2 parents per node.
  • Learn a near-ground-truth MLP from each dataset’s CPT and use it as the ground-truth data generator.
  • Despite having the same causal graph, the CPTs were different; hence the SCMs are different.

SLIDE 29

SLIDE 30

Simulations: Comparison

  • Peters et al. (2016): ICP
  • Eaton & Murphy (2007a): uncertain interventions
  • Peters et al. (2016): unknown interventions
  • However, none of these attempt to predict the intervention.
SLIDE 31

Importance of Dropout

  • Used for the initial training on observational data.
  • The alternative is a fully connected off-diagonal mask (the most degrees of freedom).
  • Pre-training cannot be carried out this way.
SLIDE 32

Importance of Intervention Prediction

  • After the intervention has been performed, the learner draws data samples from the intervention distribution and computes the per-variable average log-probability under sampled adjacency matrices.
  • The variable consistently producing the least-likely outputs is predicted to be the intervened node (see the sketch below).
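
A minimal sketch of this decision rule (the per-variable average log-probabilities are assumed to have been computed as described in the first bullet):

    import numpy as np

    def predict_intervened_node(avg_logprob):
        """avg_logprob[i]: average log-probability of the intervention data for variable i,
        taken over data samples and over sampled adjacency matrices. The variable that is
        consistently the hardest to explain is returned as the guessed intervention target."""
        return int(np.argmin(avg_logprob))

    avg_logprob = np.array([-0.7, -2.9, -0.8])   # illustrative numbers
    print(predict_intervened_node(avg_logprob))  # -> 1
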
SLIDE 33

Importance of Intervention Prediction