

SLIDE 1

Scalable Deep Generative Modeling for Sparse Graphs

Hanjun Dai¹, Azade Nazi¹, Yujia Li², Bo Dai¹, Dale Schuurmans¹

¹Google Brain, ²DeepMind

SLIDE 2

Graph generative models

Given a set of graphs {G1, G2, …, GN}, fit a probabilistic model p(G) over graphs, so that we can:

  • sample from it to get new graphs: G ~ p(G)
  • complete a graph given parts of it: Grest ~ p(Grest | Gpart)
  • obtain graph representations

Such a model can also be used for structured prediction p(G | z).

SLIDE 3

Types of deep graph generative models

Modeling the adjacency matrix directly ⇒ like an image:

  • VAE (Kipf et al., 16; Simonovsky et al., 18)
  • Autoregressive (You et al., 18; Liao et al., 19)

Leveraging the sparse structure of graphs:

  • Junction-Tree VAE (Jin et al., 18)
  • NetGAN (Bojchevski et al., 18)
  • Deep GraphGen (Li et al., 18)

SLIDE 4

Autoregressive graph generative models

Time complexity per graph during inference (n nodes, m edges):

  Model                         | Complexity       | Scalability
  Deep GraphGen (Li et al., 18) | O((m + n)²)      | ~100 nodes
  GraphRNN (You et al., 18)     | O(n²)            | ~2,000 nodes
  GRAN (Liao et al., 19)        | O(n²)            | ~5,000 nodes
  BiGG (Dai et al., 20)         | O((m + n) log n) | ~100,000 nodes

BiGG (this work) costs O(n²) at worst for a fully connected graph, i.e., O(min{(m + n) log n, n²}).

SLIDE 5

Autoregressive graph generative models

Time/memory complexity per graph during training:

  Model                         | # syncs during training | Memory cost
  Deep GraphGen (Li et al., 18) | O(m)                    | O(m(m + n))
  GraphRNN (You et al., 18)     | O(n²) or O(n)           | O(n²)
  GRAN (Liao et al., 19)        | O(n)                    | O(n(m + n))
  BiGG (Dai et al., 20)         | O(log n)                | O(√m log n)

BiGG is this work.

SLIDE 6

Saving computation for sparse graphs

[Figure: per-graph generation cost on the adjacency matrix]

  • GraphRNN: O(n²)
  • GRAN, GraphRNN-S: O(n²)
  • BiGG (this work): O((m + n) log n)

SLIDE 7

Autoregressive Generation

Autoregressive generation of the adjacency matrix:

  01 Generating one cell
  02 Generating one row
  03 Generating rows

SLIDE 8

Autoregressive Generation

Autoregressive generation of the adjacency matrix:

  01 Generating one cell
  02 Generating one row
  03 Generating rows

SLIDE 9

O(log n) procedure for generating one edge

Naive approach: given node u, choose a neighbor v directly, i.e., pick 1 out of n nodes with a softmax: O(n).

Efficient approach: recursively divide the range [1, n] into two halves and choose one: at most O(log n) decisions.
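
A minimal sketch of the efficient approach (not the authors' code): the hypothetical choose_left_prob callable stands in for the learned model's Bernoulli head, and the loop halves the candidate range until one node remains.

    import random

    def sample_neighbor(lo, hi, choose_left_prob):
        # Sample one endpoint in [lo, hi] using at most O(log n) binary
        # decisions. choose_left_prob(lo, mid, hi) is a stand-in for the
        # model's predicted probability of descending into the left half.
        while lo < hi:
            mid = (lo + hi) // 2
            if random.random() < choose_left_prob(lo, mid, hi):
                hi = mid        # descend into the left half [lo, mid]
            else:
                lo = mid + 1    # descend into the right half [mid + 1, hi]
        return lo

    # With n = 1000, each sample takes about log2(1000) ~ 10 decisions.
    v = sample_neighbor(1, 1000, lambda lo, mid, hi: 0.5)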

SLIDE 10

Binary tree generation

  • Following a single path from the root: O(log n).
  • Generating each neighbor's path separately: O(Nu log n), where Nu is the number of neighbors of node u.
  • Generating the whole tree via DFS: O(|T|), where |T| is the tree size; since paths to nearby leaves share their upper segments, |T| < min{Nu log n, 2n}.

SLIDE 11

Autoregressive Generation

Autoregressive generation of the adjacency matrix:

  01 Generating one cell
  02 Generating one row
  03 Generating rows

SLIDE 12

Autoregressive row-binary tree generation

For node t, we first decide whether to generate the left child.

    def generate_tree(t):
        should generate left child?

SLIDE 13

Autoregressive row-binary tree generation

For node t, we first decide whether to generate the left child.

⇒ Generate left child: conditioning on htop(t), which summarizes the existing tree (top-down):

    Has-left ~ Bernoulli(· | htop(t))

    def generate_tree(t):
        should generate left child?

SLIDE 14

Autoregressive row-binary tree generation

For node t, we first decide whether to generate the left child.

⇒ Generate left child: conditioning on htop(t), which summarizes the existing tree (top-down):

    Has-left ~ Bernoulli(· | htop(t))

  Yes? ⇒ Recursively generate the left subtree.

    def generate_tree(t):
        should generate left child?
        if yes:
            create left child
            generate_tree(lch(t))

SLIDE 15

Autoregressive row-binary tree generation

For node t, we first decide whether to generate the left child.

⇒ Generate left child: conditioning on htop(t), which summarizes the existing tree (top-down):

    Has-left ~ Bernoulli(· | htop(t))

  Yes? ⇒ Recursively generate the left subtree.

⇒ Generate right child: conditioning on htop(t) and on hbot(lch(t)), which summarizes the left subtree of t (bottom-up):

    Has-right ~ Bernoulli(· | htop(t), hbot(lch(t)))

  Yes? ⇒ Recursively generate the right subtree.

    def generate_tree(t):
        should generate left child?
        if yes:
            create left child
            generate_tree(lch(t))
        should generate right child?
        if yes:
            create right child
            generate_tree(rch(t))
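
The recursion above, written out as a runnable sketch; Node, has_left_prob, and has_right_prob are illustrative stand-ins for the model's state and its two Bernoulli heads, not BiGG's actual implementation.

    import random

    class Node:
        def __init__(self, lo, hi):
            self.lo, self.hi = lo, hi        # column interval this node covers
            self.left = self.right = None

    def generate_tree(t, has_left_prob, has_right_prob):
        # has_left_prob(t) plays Bernoulli(. | htop(t)); has_right_prob(t)
        # plays Bernoulli(. | htop(t), hbot(lch(t))). Both are plain
        # callables here, purely for illustration.
        if t.lo == t.hi:                     # leaf: one concrete edge
            return
        mid = (t.lo + t.hi) // 2
        if random.random() < has_left_prob(t):
            t.left = Node(t.lo, mid)
            generate_tree(t.left, has_left_prob, has_right_prob)
        if random.random() < has_right_prob(t):
            t.right = Node(mid + 1, t.hi)
            generate_tree(t.right, has_left_prob, has_right_prob)

    root = Node(1, 8)                        # one row over columns 1..8
    generate_tree(root, lambda t: 0.7, lambda t: 0.7)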

SLIDE 16

Realize top-down and bottom-up recursion

Bottom-up summary of a completed subtree:

    hbot(t) = TreeLSTMCell(hbot(lch(t)), hbot(rch(t)))

Top-down state of the left child:

    htop(lch(t)) = LSTMCell(htop(t), x_left)

Top-down state of the right child, which additionally conditions on the finished left subtree:

    ĥtop(rch(t)) = TreeLSTMCell(htop(t), hbot(lch(t)))
    htop(rch(t)) = LSTMCell(ĥtop(rch(t)), x_right)

(x_left and x_right denote the child-direction embeddings.)
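
A sketch of one way to realize the binary TreeLSTMCell above, in the style of Tai et al. (2015) with one forget gate per child; the layer shapes and the single fused linear map are assumptions for illustration, not the authors' implementation.

    import torch
    import torch.nn as nn

    class BinaryTreeLSTMCell(nn.Module):
        # Merges two child states (h, c) into one parent state, e.g.
        # hbot(t) from hbot(lch(t)) and hbot(rch(t)).
        def __init__(self, dim):
            super().__init__()
            # input, output, update gates + one forget gate per child
            self.gates = nn.Linear(2 * dim, 5 * dim)

        def forward(self, left, right):
            (h_l, c_l), (h_r, c_r) = left, right
            i, o, u, f_l, f_r = self.gates(torch.cat([h_l, h_r], -1)).chunk(5, -1)
            c = (torch.sigmoid(i) * torch.tanh(u)
                 + torch.sigmoid(f_l) * c_l + torch.sigmoid(f_r) * c_r)
            return torch.sigmoid(o) * torch.tanh(c), c

    cell = BinaryTreeLSTMCell(64)
    zero = (torch.zeros(1, 64), torch.zeros(1, 64))
    h_bot_t = cell(zero, zero)    # parent state from two child states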

SLIDE 17

Autoregressive Generation

Autoregressive generation of the adjacency matrix:

  01 Generating one cell
  02 Generating one row
  03 Generating rows

SLIDE 18

Autoregressive conditioning between rows

To generate the neighbors of node u (i.e., the u-th row), how do we summarize row 0 through row u-1?

Use an LSTM over hrow(0), hrow(1), …, hrow(u-1)? Not efficient: O(n) dependency length.

SLIDE 19

Fenwick tree for prefix summarization

Fenwick tree: a data structure that supports prefix-sum queries and single-element updates in O(log n).

[Figure: Fenwick tree built over the row states hrow(0), …, hrow(5)]

SLIDE 20

Fenwick tree for prefix summarization

Fenwick tree: a data structure that supports prefix-sum queries and single-element updates in O(log n).

Obtaining the "prefix sum" of row states using low-bit queries: for the current row u, the required context is the set of Fenwick nodes that together cover rows 0 to u-1. For example, u = 3 uses hrow(2) plus the internal node covering rows 0-1; u = 5 uses hrow(4) plus the node covering rows 0-3; similarly for u = 6.

At most O(log n) dependencies per row.
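
A minimal sketch of the low-bit query itself, using standard 1-based Fenwick indexing; in BiGG each returned index holds an LSTM summary state that gets merged, rather than a number that gets added.

    def lowbit(x):
        return x & (-x)

    def prefix_context(u):
        # Fenwick nodes whose summaries together cover rows 1..u; every
        # row therefore depends on at most O(log n) stored states.
        idx = []
        while u > 0:
            idx.append(u)
            u -= lowbit(u)          # strip the lowest set bit
        return idx

    print(prefix_context(3))        # [3, 2]: row 3 + the node for rows 1-2
    print(prefix_context(6))        # [6, 4]: rows 5-6 + the node for rows 1-4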

SLIDE 21

Optimizing BiGG

  01 Training with O(log n) synchronizations
  02 Model parallelism & sublinear memory cost

SLIDE 22

Optimizing BiGG

  01 Training with O(log n) synchronizations
  02 Model parallelism & sublinear memory cost

SLIDE 23

Training with O(log n) synchronizations

Stage 1: Compute all bottom-up summarizations for all rows: O(log n) steps.

[Figure: all tree nodes at the same depth are computed in one batched step, bottom level first (Sync 1 … Sync 4)]
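
A toy illustration of why Stage 1 needs only O(log n) synchronizations: every tree level is reduced in one batched call. Here merge stands in for a batched TreeLSTM cell; np.add works for the illustration.

    import numpy as np

    def bottom_up_levels(leaves, merge):
        # Reduce all nodes level by level; each level is one batched
        # call, i.e., one synchronization, so ceil(log2(n)) syncs total.
        states, syncs = list(leaves), 0
        while len(states) > 1:
            carry = [states.pop()] if len(states) % 2 else []  # odd node waits
            merged = merge(np.asarray(states[0::2]), np.asarray(states[1::2]))
            states = list(merged) + carry
            syncs += 1
        return states[0], syncs

    total, syncs = bottom_up_levels([1, 1, 1, 1, 1, 1, 1], np.add)
    print(total, syncs)             # 7 3: seven leaves, three batched levels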

SLIDE 24

Training with O(log n) synchronizations

Stage 2: Construct the entire Fenwick tree: O(log n) steps.

[Figure: Fenwick tree nodes built level by level (Sync 1 … Sync 3)]

SLIDE 25

Training with O(log n) synchronizations

Stage 3: Retrieve all the prefix contexts: O(log n) steps.

[Figure: the prefix context hrow(·) for every row is gathered from the Fenwick tree in parallel (Sync 1, Sync 2)]

SLIDE 26

Training with O(log n) synchronizations

Stage 4: Compute the cross-entropy loss: O(log n) steps.

[Figure: per-edge losses computed over the trees level by level (Sync 1 … Sync 4)]

SLIDE 27

Optimizing BiGG

  01 Training with O(log n) synchronizations
  02 Model parallelism & sublinear memory cost

SLIDE 28

Model parallelism

SLIDE 29

Model parallelism

[Figure: the computation tree is split between GPU 1 and GPU 2]

SLIDE 30

Model parallelism

[Figure: the split computation runs on GPU 1 and GPU 2, with a single GPU 1 -> 2 message at the boundary]

SLIDE 31

Sublinear memory cost

Split the computation into two passes (pass-1, pass-2) and checkpoint the boundary between them: run 2x forward + 1x backward.

Memory cost during training: O(√m log n)
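
A minimal sketch of the 2x forward + 1x backward trade using PyTorch's stock gradient checkpointing; the two stages are generic stand-ins for the pass-1/pass-2 split, not BiGG's actual modules.

    import torch
    from torch.utils.checkpoint import checkpoint

    stage1 = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU())
    stage2 = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU())

    x = torch.randn(32, 64, requires_grad=True)
    # Checkpointed stages discard their intermediate activations after the
    # forward pass, cutting peak memory; backward re-runs each stage's
    # forward once, hence 2x forward + 1x backward overall.
    h = checkpoint(stage1, x, use_reentrant=False)
    loss = checkpoint(stage2, h, use_reentrant=False).sum()
    loss.backward()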

SLIDE 32

Experiments

SLIDE 33

Inference speed

SLIDE 34

Training memory

SLIDE 35

Training time

Main reason: the number of available GPU cores is limited.

SLIDE 36

Sample quality on benchmark datasets

SLIDE 37

Sample quality as graph size grows

SLIDE 38

Summary

Advantages:

  • Improves inference time to O(min{(m + n) log n, n²})
  • Enables parallelized training with sublinear memory cost
  • Does not sacrifice sample quality

Limitations:

  • Limited by the parallelism of existing hardware
  • Good capacity, but limited extrapolation ability
SLIDE 39

Thank You

Hanjun Dai

Research Scientist
hadai@google.com