Model-driven Deep Learning, Jian Sun (孙剑), Xi'an Jiaotong University (PowerPoint presentation)



slide-1
SLIDE 1

Jian Sun (孙剑)

Model-driven Deep Learning

Xi'an Jiaotong University Email: jiansun@mail.xjtu.edu.cn Home page: http://jiansun.gr.xjtu.edu.cn April, 2019

slide-2
SLIDE 2

Outline

⚫ Introduction

– Background: Image analysis / deep neural networks
– Motivation

⚫ Model-driven Deep Learning Approach

– Learning Markov Random Field Model for Image Restoration
– Deep ADMM-Net for Fast Compressive Sensing MRI
– Deep Fusion-Net for Multi-Atlas MR Image Segmentation

⚫ Recent Progress

– Learning proximal operators
– Multimodal medical image synthesis
– Learning Graph CNNs for 3D shape analysis
– Learning to Optimize

⚫ Discussion & Conclusion

slide-3
SLIDE 3

Background: Image Processing & Analysis

⚫ Restoration & Reconstruction

Image degradation (noise, motion blur, k-space sampling, etc.) is described by a physical imaging model; restoration and reconstruction invert this model, i.e., they are inverse problems.

slide-4
SLIDE 4

Background: Image Processing & Analysis

⚫ Segmentation & Recognition

Examples: semantic segmentation; lesion (pulmonary nodule) localization and classification

slide-5
SLIDE 5

Background: Models

⚫ Conventional models: signal processing approaches

– Wavelets
– Image filtering

slide-6
SLIDE 6

Background: Models

⚫ Conventional models: energy models and their optimization

– Energy model with regularization
– Dictionary learning

Applications: image restoration / segmentation / classification / MRI / lesion detection

$$x^* = \arg\min_x \; D(x, y; w) + R(x; w)$$

where D is the data term, R is the regularization term, and w denotes the model parameters.
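As a minimal, hypothetical instance of the energy model above, the sketch assumes a 1D signal, a quadratic data term ½‖x − y‖², and a quadratic smoothness regularizer on x weighted by a single parameter w, namely ½w‖Dx‖² with D a circular forward difference. The energy is minimized by plain gradient descent. Function names, the step size, and the iteration count are illustrative assumptions, not part of the talk.

```python
import numpy as np

def forward_diff(x):
    # Circular forward difference: (Dx)_i = x_{i+1} - x_i
    return np.roll(x, -1) - x

def forward_diff_T(v):
    # Adjoint of forward_diff: (D^T v)_i = v_{i-1} - v_i
    return np.roll(v, 1) - v

def energy(x, y, w):
    # E(x) = data term + w-weighted smoothness regularizer (both quadratic here)
    return 0.5 * np.sum((x - y) ** 2) + 0.5 * w * np.sum(forward_diff(x) ** 2)

def minimize_energy(y, w, step=0.1, iters=500):
    # Gradient descent; step must stay below 2 / (1 + 4w) for stability
    x = y.copy()
    for _ in range(iters):
        grad = (x - y) + w * forward_diff_T(forward_diff(x))
        x = x - step * grad
    return x
```

Because both terms are quadratic, the minimizer also has the closed form (I + wDᵀD)⁻¹y, which makes a convenient correctness check. A non-quadratic regularizer (e.g., total variation) would keep the same loop and change only the gradient.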

slide-7
SLIDE 7

Background: Models

⚫ Conventional models: statistical models

– Evidence lower bound (ELBO)
– Expectation-maximization (EM)
– Variational inference
– Variational expectation-maximization

slide-8
SLIDE 8

Background: Deep Neural Networks

⚫ Deep convolutional neural network (CNN) [Krizhevsky A, et al., 2012]

slide-9
SLIDE 9

Background: Deep Neural Networks

⚫ LSTM [Hochreiter & Schmidhuber, 1997]

⚫ GAN [Ian Goodfellow et al., 2014]: a generator produces samples and a discriminator classifies them as true or fake

slide-10
SLIDE 10

Conventional Models vs. Deep NNs

Deep Neural Networks (CNN / LSTM / GAN …)

Pros:

⚫ A universal regressor
⚫ Efficiency
⚫ Effectiveness

Cons:

⚫ Rely on large training sets
⚫ Relatively fixed structure
⚫ Hard to incorporate domain knowledge

Conventional Models (optimization / statistics / energy models …)

Pros:

⚫ Easy to incorporate domain knowledge
⚫ Rely on less training data
⚫ Good generalization ability

Cons:

⚫ May not be optimal for a specific task
⚫ Parameter tuning

slide-11
SLIDE 11

Model-driven Deep Learning

Model

⚫ Formulations?

– Energy model
– Statistical model
– Image priors

⚫ Parameters?

– Hyperparameters
– Statistical model parameters

⚫ Strategies?

– Gradient updates in optimization
– Actions in control

Model + task-specific training data → deep learning

Why model-driven? Explainable ML; prior knowledge; builds on the traditional model-based approach

slide-12
SLIDE 12

⚫ Optimization-driven DL

– Sparse coding optimization

[K. Gregor, et al., ICML 2010; P. Sprechmann, et al., PAMI 2015, etc.]

– Gradient descent, ADMM, proximal operators, etc.

[J. Sun, et al., CVPR 2011; Y. Yang, J. Sun, et al., NIPS 2016; T. Meinhardt, et al., ICCV 2017, etc.]

⚫ Statistical model-driven DL

– MRF, CRF

[S. Zheng, et al., ICCV 2015; J. Sun, et al., IEEE TIP 2013, etc.]

– Variational inference

[J. Marino, et al., ICLR 2018, etc.]

– EM

[D. P. Kingma, ICLR 2014; K. Greff, et al., NIPS 2017, etc.]

……

Model-driven Deep Learning
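The optimization-driven line of work [K. Gregor, et al., ICML 2010] unrolls an iterative solver into network layers. A hedged NumPy sketch, assuming plain ISTA for sparse coding, min_z ½‖y − Az‖² + λ‖z‖₁: each iteration z ← soft(W_e y + S z, θ) is written as one "layer" with W_e = Aᵀ/L, S = I − AᵀA/L and θ = λ/L, where L is the Lipschitz constant of AᵀA. In a learned (LISTA-style) network these per-layer quantities would be trained by back-propagation; here they are fixed, and all names are illustrative.

```python
import numpy as np

def soft_threshold(v, theta):
    # Proximal operator of theta * ||.||_1 (element-wise shrinkage)
    return np.sign(v) * np.maximum(np.abs(v) - theta, 0.0)

def ista_layers(A, lam):
    # Fixed ISTA parameters; L = largest singular value of A, squared
    L = np.linalg.norm(A, 2) ** 2
    W_e = A.T / L
    S = np.eye(A.shape[1]) - A.T @ A / L
    return W_e, S, lam / L

def unrolled_ista(y, W_e, S, theta, num_layers=100):
    z = np.zeros(S.shape[0])
    b = W_e @ y                      # computed once, shared by all layers
    for _ in range(num_layers):
        z = soft_threshold(b + S @ z, theta)   # one "layer"
    return z
```

With a fixed number of layers the solver becomes a feed-forward network of fixed depth, which is what makes end-to-end training of W_e, S and θ possible.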

slide-13
SLIDE 13

Outline

⚫ Introduction

– Background: Image analysis / deep neural networks
– Motivation

⚫ Model-driven Deep Learning Approach

– Learning Markov Random Field Model for Image Restoration
– Deep ADMM-Net for Fast Compressive Sensing MRI
– Deep Fusion-Net for Multi-Atlas MR Image Segmentation

⚫ Recent Progress

– Learning proximal operators
– Multimodal medical image synthesis
– Learning Graph CNNs for 3D shape analysis
– Learning to Optimize

⚫ Discussion & Conclusion

slide-14
SLIDE 14

Example

⚫ Non-local Range MRF [J. Sun, M. Tappen, CVPR 2011]

– A novel Markov random field model
– Discriminative parameter learning

slide-15
SLIDE 15

Example

⚫ Non-local Range MRF [J. Sun, M. Tappen, CVPR 2011]

– A novel Markov random field model
– Discriminative parameter learning

Non-local Range MRF

slide-16
SLIDE 16

Example

⚫ Non-local Range MRF [J. Sun, M. Tappen, CVPR 2011]

– A novel Markov random field model
– Discriminative parameter learning

Non-local Range MRF

slide-17
SLIDE 17

Example

⚫ Non-local Range MRF [J. Sun, M. Tappen, CVPR 2011]

– A novel Markov random field model
– Discriminative parameter learning

Non-local Range MRF

slide-18
SLIDE 18

Example

⚫ Non-local Range MRF [J. Sun, M. Tappen, CVPR 2011]

– A novel Markov random field model
– Discriminative parameter learning

Non-local Range MRF

unfolding

slide-19
SLIDE 19

Non-local Range Markov Random Field Model

⚫ Gradients of the loss function w.r.t. the model parameters

KEY IDEA:

– A general framework to compute the gradient w.r.t. the model parameters
– The K unfolded iterations are similar to a neural network with K layers, so the gradient is computed by back-propagation
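The key idea above (treat K unfolded optimization steps as a K-layer network and back-propagate the loss to the model parameters) can be sketched on a toy problem. The assumptions here are illustrative and are not the non-local range MRF itself: a 1D signal, a quadratic smoothness energy E(x) = ½‖x − y‖² + (λ/2)‖Dx‖² with a single weight λ, K gradient-descent steps as layers, and a squared-error loss against the ground truth. The hand-written backward pass is checked against finite differences.

```python
import numpy as np

def D(x):
    return np.roll(x, -1) - x          # circular forward difference

def DT(v):
    return np.roll(v, 1) - v           # its adjoint

def forward(y, lam, K=20, eta=0.1):
    # K unfolded gradient steps ("layers") on E(x); x_0 = y
    xs = [y.copy()]
    for _ in range(K):
        x = xs[-1]
        xs.append(x - eta * ((x - y) + lam * DT(D(x))))
    return xs

def loss_and_grad(y, x_gt, lam, K=20, eta=0.1):
    xs = forward(y, lam, K, eta)
    loss = 0.5 * np.sum((xs[-1] - x_gt) ** 2)
    g = xs[-1] - x_gt                  # dL/dx_K
    dlam = 0.0
    for k in range(K - 1, -1, -1):
        x = xs[k]
        # x_{k+1} = (1 - eta) x_k + eta y - eta * lam * D^T D x_k
        dlam += g @ (-eta * DT(D(x)))  # direct dependence of layer k on lam
        # back-propagate through the symmetric Jacobian to dL/dx_k
        g = (1 - eta) * g - eta * lam * DT(D(g))
    return loss, dlam
```

The returned gradient can drive any first-order update of λ; with more parameters (filters, potential functions) the same reverse sweep applies layer by layer.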

slide-20
SLIDE 20

Outline

⚫ Introduction

– Background: Image analysis / deep neural networks
– Motivation

⚫ Model-driven Deep Learning Approach

– Learning Markov Random Field Model for Image Restoration
– Deep ADMM-Net for Fast Compressive Sensing MRI
– Deep Fusion-Net for Multi-Atlas MR Image Segmentation

⚫ Recent Progress

– Learning proximal operators
– Multimodal medical image synthesis
– Learning Graph CNNs for 3D shape analysis
– Learning to Optimize

⚫ Discussion & Conclusion

slide-21
SLIDE 21

◆ Less sampling and fast reconstruction?
◆ Compressive sensing: a dominant approach in fast MRI reconstruction [1]

[1] Michael Lustig, David L. Donoho, et al., Compressed Sensing MRI, IEEE Signal Processing Magazine, 2008.

MRI Image Reconstruction

Deep ADMM-Net for Compressive Sensing

slide-22
SLIDE 22

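A basic compressive sensing (CS) model:

$$\hat{x} = \arg\min_x \; \frac{1}{2}\|Ax - y\|_2^2 + \sum_{l=1}^{L} \lambda_l\, g(D_l x)$$

– A: measurement matrix, A = PF (P: sampling matrix; F: Fourier transform)
– D_l: filter matrix corresponding to a convolution operation
– g(·): regularization function, e.g., the l0 or l1 norm
– λ_l: regularization parameter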

Deep ADMM-Net for Compressive Sensing

slide-23
SLIDE 23

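ADMM (Alternating Direction Method of Multipliers) [Y. Yang, J. Sun, et al., NIPS 2016]

Introducing auxiliary variables z_l = D_l x, the augmented Lagrangian function is

$$\mathcal{L}_\rho(x, z, \alpha) = \frac{1}{2}\|Ax - y\|_2^2 + \sum_{l=1}^{L}\Big[\lambda_l\, g(z_l) - \langle \alpha_l,\, z_l - D_l x\rangle + \frac{\rho_l}{2}\|z_l - D_l x\|_2^2\Big]$$

ADMM iterations (with scaled multipliers β_l = α_l / ρ_l):

$$x^{(n)} = \arg\min_x \frac{1}{2}\|Ax - y\|_2^2 + \sum_{l=1}^{L}\frac{\rho_l}{2}\|D_l x + \beta_l^{(n-1)} - z_l^{(n-1)}\|_2^2$$

$$z^{(n)} = \arg\min_z \sum_{l=1}^{L}\Big[\lambda_l\, g(z_l) + \frac{\rho_l}{2}\|D_l x^{(n)} + \beta_l^{(n-1)} - z_l\|_2^2\Big]$$

$$\beta_l^{(n)} = \beta_l^{(n-1)} + \eta_l\big(D_l x^{(n)} - z_l^{(n)}\big)$$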

Deep ADMM-Net for Compressive Sensing

slide-24
SLIDE 24

Data Flow Graph (DFG) for ADMM

Unfolding to stage n in DFG

C(n) = D_l x(n)

Deep ADMM-Net for Compressive Sensing

slide-25
SLIDE 25

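⚫ Deep ADMM-Net:

– Reconstruction layer (X(n)): computes the x-update
– Convolution layer (C(n)): applies the filters D_l to x(n)
– Nonlinear transform layer (Z(n)): computes the z-update with a learnable nonlinear function
– Multiplier updating layer (M(n)): updates the Lagrange multipliers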

Deep ADMM-Net for Compressive Sensing
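A hedged 1D NumPy sketch of one possible forward pass through the four layer types, where each loop iteration is one stage. Simplifying assumptions that differ from ADMM-Net: a single fixed filter D (circular forward difference), the l1 norm as g, so the nonlinear transform layer is plain complex soft-thresholding instead of a learned piecewise-linear function, fixed λ, ρ, η shared by all stages, and a unitary FFT as F. Because D is a circular convolution, the reconstruction layer's linear solve is diagonal in k-space; the mask must sample the DC coefficient so the division is well defined. The function name is hypothetical.

```python
import numpy as np

def admm_net_forward(y0, mask, lam=0.01, rho=1.0, eta=1.0, stages=50):
    """Unrolled ADMM for min_x 0.5*||P F x - y||^2 + lam*||D x||_1 (1D sketch).

    y0:   zero-filled k-space data of length N (unitary DFT convention)
    mask: boolean sampling pattern P; mask[0] (DC) must be True
    """
    if not mask[0]:
        raise ValueError("mask[0] (DC) must be sampled for the x-update to be well defined")
    n = y0.shape[0]
    kernel = np.zeros(n)
    kernel[0], kernel[-1] = -1.0, 1.0         # circular forward difference x[i+1] - x[i]
    d_hat = np.fft.fft(kernel)                # Fourier multiplier of D
    m = mask.astype(float)
    denom = m + rho * np.abs(d_hat) ** 2      # diagonal of P^T P + rho * D^T D in k-space
    z = np.zeros(n, dtype=complex)
    beta = np.zeros(n, dtype=complex)
    t = lam / rho                             # soft-threshold level
    for _ in range(stages):
        # Reconstruction layer X(n): closed-form x-update, diagonal in k-space
        u = (m * y0 + rho * np.conj(d_hat) * np.fft.fft(z - beta, norm="ortho")) / denom
        x = np.fft.ifft(u, norm="ortho")
        # Convolution layer C(n): c = D x
        c = np.fft.ifft(d_hat * np.fft.fft(x))
        # Nonlinear transform layer Z(n): complex soft-thresholding (prox of t*|.|)
        v = c + beta
        mag = np.abs(v)
        z = np.where(mag > t, (1.0 - t / np.maximum(mag, 1e-12)) * v, 0.0)
        # Multiplier updating layer M(n)
        beta = beta + eta * (c - z)
    return x
```

In the network version, the filters, the nonlinearity, ρ and η become per-stage parameters trained by back-propagation; the fixed values here only show the data flow between layers.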

slide-26
SLIDE 26

⚫ Network training: Gradient computation by backpropagation

Parameter optimization: L-BFGS

Deep ADMM-Net for Compressive Sensing

slide-27
SLIDE 27

⚫ Training data generation: ground-truth images are sampled in k-space to produce the observed data.

⚫ Training loss

Deep ADMM-Net for Compressive Sensing

slide-28
SLIDE 28

Deep ADMM-Net for Compressive Sensing

slide-29
SLIDE 29

⚫ Extensions of ADMM-Net [IEEE PAMI, 2018]

– More flexible network structure

Deep ADMM-Net for Compressive Sensing

slide-30
SLIDE 30

ADMM-Net-v2 (figure: network structure of stage n)

Deep ADMM-Net for Compressive Sensing

slide-31
SLIDE 31

Deep ADMM-Net for Compressive Sensing

slide-32
SLIDE 32

Deep ADMM-Net for Compressive Sensing

slide-33
SLIDE 33

Figure: our results vs. ground truth

Deep ADMM-Net for Compressive Sensing

slide-34
SLIDE 34

Applications to more general compressive imaging

Bottleneck: fast inversion in the reconstruction layer. Fast inversion is available when A is a

  • Partial Fourier matrix
  • Random matrix with orthogonal rows
  • Structurally random matrix

Deep ADMM-Net for Compressive Sensing
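Why orthogonal rows help: assuming, for illustration only, that the sparsifying transform is the identity, the reconstruction layer must apply (AᵀA + ρI)⁻¹. When A has orthonormal rows (AAᵀ = I), this inverse has the closed form (1/ρ)(I − AᵀA/(1 + ρ)), which needs only one product with A and one with Aᵀ. The sketch checks this identity against a dense solve; the QR-based construction of A and the sizes are assumptions.

```python
import numpy as np

def random_orthonormal_rows(m, n, seed=0):
    # Reduced QR of an n x m Gaussian matrix gives Q (n x m) with orthonormal
    # columns; its transpose is an m x n matrix with orthonormal rows.
    q, _ = np.linalg.qr(np.random.default_rng(seed).normal(size=(n, m)))
    return q.T

def fast_solve(A, b, rho):
    # Solves (A^T A + rho I) x = b using A A^T = I; no matrix is inverted.
    return (b - A.T @ (A @ b) / (1.0 + rho)) / rho
```

The identity follows from AᵀA·AᵀA = AᵀA; for partial Fourier or structurally random matrices the products with A and Aᵀ can in turn be computed by fast transforms.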

slide-35
SLIDE 35

Natural image compressive sensing

Deep ADMM-Net for Compressive Sensing

slide-36
SLIDE 36

Outline

⚫ Introduction

– Background: Image analysis / deep neural networks
– Motivation

⚫ Model-driven Deep Learning Approach

– Learning Markov Random Field Model for Image Restoration
– Deep ADMM-Net for Fast Compressive Sensing MRI
– Deep Fusion-Net for Multi-Atlas MR Image Segmentation

⚫ Recent Progress

– Learning proximal operators
– Multimodal medical image synthesis
– Learning Graph CNNs for 3D shape analysis
– Learning to Optimize

⚫ Discussion & Conclusion

slide-37
SLIDE 37

Introduction

⚫ Background: multi-atlas segmentation has been one of the most widely used and successful medical image segmentation techniques of the past decade.

Pipeline (figure): atlases (image + label) → registration to the target image → atlas selection → label fusion (e.g., weighted voting, statistical theory, …)

Iglesias, J.E., et al.: Multi-atlas segmentation of biomedical images: a survey. (Med. Image Anal. 2015)

Deep Fusion Net for MR Image Segmentation

slide-38
SLIDE 38

Non-local patch-based label fusion (NL-PLF) model

Label fusion: the label of target pixel p is a weighted vote of the atlas labels at pixels q in a non-local neighborhood, with fusion weight w_pq computed from patch features.

Hand-crafted features:

1. Intensity (Coupe et al., 2011)
2. Intensity + spatial context (Wang et al., 2014)
3. Intensity + gradient + contextual features (Bai et al., 2015)

[1] Coupe, P., et al. Patch-based segmentation using expert priors: Application to hippocampus and ventricle segmentation. (NeuroImage 2011)
[2] Wang, Z., et al. Geodesic patch-based segmentation. (MICCAI 2014)
[3] Bai, W., et al. Multi-atlas segmentation with augmented features for cardiac MR images. (Med. Image Anal. 2015)

Deep Fusion Net for MR Image Segmentation
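A minimal sketch of non-local patch-based label fusion with hand-crafted intensity features, in the spirit of the intensity-based variant [1]: each target pixel p receives a weighted vote of atlas labels at pixels q in a search window, with w_pq = exp(−‖P_T(p) − P_A(q)‖² / h). The patch radius, search radius, bandwidth h, edge padding, and the shift by the minimum distance (which avoids underflow and cancels after normalization) are assumptions for illustration, not the settings of the cited papers. Deep Fusion Net replaces the raw intensity patches with learned CNN features.

```python
import numpy as np

def nl_patch_label_fusion(target, atlases, labels, patch=1, search=2, h=1.0):
    """target: (H, W); atlases, labels: lists of (H, W) arrays (labels in {0, 1}).

    Returns a soft label map in [0, 1]; threshold at 0.5 for a binary mask.
    """
    H, W = target.shape
    r = patch + search
    pad = lambda a: np.pad(a.astype(float), r, mode="edge")
    T = pad(target)
    As = [pad(a) for a in atlases]
    Ls = [pad(l) for l in labels]
    fused = np.zeros((H, W))
    for i in range(H):
        for j in range(W):
            pi, pj = i + r, j + r
            tp = T[pi - patch:pi + patch + 1, pj - patch:pj + patch + 1]
            dists, votes = [], []
            for A, L in zip(As, Ls):
                for di in range(-search, search + 1):
                    for dj in range(-search, search + 1):
                        qi, qj = pi + di, pj + dj
                        ap = A[qi - patch:qi + patch + 1, qj - patch:qj + patch + 1]
                        dists.append(np.sum((tp - ap) ** 2))
                        votes.append(L[qi, qj])
            dists = np.array(dists)
            w = np.exp(-(dists - dists.min()) / h)   # fusion weights w_pq
            fused[i, j] = np.dot(w, votes) / w.sum()  # weighted voting
    return fused
```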

slide-39
SLIDE 39

Deep Fusion Net

Deep Fusion Net for MR Image Segmentation

Feature extraction (figure): CNN layers map the target image T and atlases X1, X2 to deep features F(T; θ), F(X1; θ), F(X2; θ).

• Deep Fusion Net (MICCAI 2016): an end-to-end learnable deep architecture for NL-PLF that concatenates feature extraction and non-local patch-based label fusion

[H. R. Yang, J. Sun, et al., MICCAI 2016, Medical Image Analysis, 2018]

slide-40
SLIDE 40

Label fusion (figure): the deep features of the target and atlases X1, X2 are used for computing fusion weights, followed by weighted voting of the atlas labels to produce the estimated label.

Deep Fusion Net for MR Image Segmentation


slide-41
SLIDE 41

Deep Fusion Net

Implementation of Label Fusion Sub-Net

Deep Fusion Net for MR Image Segmentation

slide-42
SLIDE 42

Deep Fusion Net

Network structure

Deep Fusion Net for MR Image Segmentation

slide-43
SLIDE 43

Experiments

⚫ Atlas selection

Figure: a target image; the top-5 atlas images selected by normalized mutual information (NMI); the top-5 atlas images selected by deep feature distance.

Deep Fusion Net for MR Image Segmentation

Database: MICCAI 2013 SATA Segmentation Challenge

Deep feature distance:
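The NMI baseline for atlas selection can be sketched as follows, assuming Studholme's definition NMI(A, B) = (H(A) + H(B)) / H(A, B), computed from a joint intensity histogram with a hypothetical 32 bins; which NMI variant and binning the experiments used is not stated. Atlases are ranked by NMI with the target and the top k are kept. Selection by deep feature distance would swap the score for a negated distance between CNN features.

```python
import numpy as np

def entropy(p):
    p = p[p > 0]
    return -np.sum(p * np.log(p))

def nmi(a, b, bins=32):
    # Joint histogram -> joint and marginal distributions (non-constant images assumed)
    joint, _, _ = np.histogram2d(a.ravel(), b.ravel(), bins=bins)
    pxy = joint / joint.sum()
    px, py = pxy.sum(axis=1), pxy.sum(axis=0)
    return (entropy(px) + entropy(py)) / entropy(pxy.ravel())

def select_atlases(target, atlases, k=5, bins=32):
    # Indices of the k atlases with the highest NMI to the target
    scores = np.array([nmi(target, a, bins) for a in atlases])
    return [int(i) for i in np.argsort(-scores)[:k]]
```

NMI reaches its maximum of 2 for identical images (and for any one-to-one intensity remapping), which is why it is a common registration and selection score across intensity scales.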

slide-44
SLIDE 44

Experiments

⚫ Atlas selection

Deep Fusion Net for MR Image Segmentation

slide-45
SLIDE 45

⚫ Segmentation accuracy

Figure: ground truth vs. MV, PB [1], MAPM [2], SVMAF [3], CNN and DFN

[1] Coupe, P., et al. Patch-based segmentation using expert priors: Application to hippocampus and ventricle segmentation. (NeuroImage 2011)
[2] Shi, W., et al. Cardiac image super-resolution with global correspondence using multi-atlas patchmatch. (MICCAI 2013)
[3] Bai, W., et al. Multi-atlas segmentation with augmented features for cardiac MR images. (Med. Image Anal. 2015)

Deep Fusion Net for MR Image Segmentation

MICCAI 2013 SATA Dataset

slide-46
SLIDE 46

⚫ Examples of results

Figure: target image, output segmentation, and ground truth for slices 1 to 10

Deep Fusion Net for MR Image Segmentation

slide-47
SLIDE 47

Deep Fusion Net for MR Image Segmentation

2009 LV Segmentation Challenge

ADM: averaged Dice metric; AJM: averaged Jaccard metric

Epicardium

DLLS: Combining deep learning and level set for the automated segmentation of the left ventricle of the heart from cardiac cine magnetic resonance. Medical Image Analysis, 2017

DLDM: A combined deep-learning and deformable-model approach to fully automatic segmentation of the left ventricle in cardiac MRI. Medical Image Analysis, 2016
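The two metrics reported above are standard overlap measures, Dice = 2|A∩B| / (|A| + |B|) and Jaccard = |A∩B| / |A∪B|. A short sketch, assuming binary masks that are not both empty; averaging over a list of images (rather than, say, per patient or per slice) is an assumption about how ADM and AJM are aggregated.

```python
import numpy as np

def dice(a, b):
    a, b = a.astype(bool), b.astype(bool)
    return 2.0 * np.logical_and(a, b).sum() / (a.sum() + b.sum())

def jaccard(a, b):
    a, b = a.astype(bool), b.astype(bool)
    return np.logical_and(a, b).sum() / np.logical_or(a, b).sum()

def averaged(metric, preds, gts):
    # e.g. averaged(dice, preds, gts) for ADM, averaged(jaccard, preds, gts) for AJM
    return float(np.mean([metric(p, g) for p, g in zip(preds, gts)]))
```

For a single pair the two are related by J = D / (2 − D), so they rank individual segmentations identically; their averages can still differ in emphasis.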

slide-48
SLIDE 48

Outline

⚫ Introduction

– Background: Image analysis / deep neural networks
– Motivation

⚫ Model-driven Deep Learning Approach

– Learning Markov Random Field Model for Image Restoration
– Deep ADMM-Net for Fast Compressive Sensing MRI
– Deep Fusion-Net for Multi-Atlas MR Image Segmentation

⚫ Recent Progress

– Multimodal medical image synthesis
– Learning proximal operators
– Learning Graph CNNs for 3D shape analysis
– Learning to Optimize

⚫ Discussion & Conclusion

slide-49
SLIDE 49

Multi-modal Medical Image Synthesis

⚫ Background

– MR: excellent soft-tissue contrast
– CT: provides tissue electron densities

Setting (paired training data): atlas MR and atlas CT images are given; for a target MR image, the target CT is unknown and must be synthesized.

slide-50
SLIDE 50

Multi-modal Medical Image Synthesis

⚫ Background

– MR: excellent soft-tissue contrast
– CT: provides tissue electron densities

Setting (unpaired training data): atlas MR and atlas CT images are given without pairing; for a target MR image, the target CT is unknown and must be synthesized.

slide-51
SLIDE 51

Multi-modal Medical Image Synthesis

⚫ MR images → CT images

[H. R. Yang, J. Sun, et al., MICCAI-DLMIA, 2018]

Non-local structure:

slide-52
SLIDE 52

Multi-modal Medical Image Synthesis

Training Loss

slide-53
SLIDE 53

Multi-modal Medical Image Synthesis

⚫ Compared methods

– "CycleGAN": conventional CycleGAN
– "CycleGAN (paired)": CycleGAN trained with paired data

⚫ Evaluation: MAE, PSNR, SSIM, SSIM (HG)

Table: MAE, PSNR, SSIM and SSIM (HG) for CycleGAN (unpaired), CycleGAN (paired) and the proposed method
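For reference, a minimal sketch of two of the evaluation metrics: MAE = mean|x − y| and PSNR = 10·log10(R² / MSE), where R is the dynamic range of the images. The value of R used for CT (e.g., a Hounsfield-unit window) is an assumption the caller must supply; it is not given in the talk. For SSIM, scikit-image's `skimage.metrics.structural_similarity` is a tested implementation and is preferable to a hand-rolled one; "SSIM (HG)" is not reproduced here.

```python
import numpy as np

def mae(x, y):
    return float(np.mean(np.abs(x - y)))

def psnr(x, y, data_range):
    # data_range: dynamic range R of the images (assumed known, e.g. a HU window)
    mse = np.mean((x - y) ** 2)
    return float(10.0 * np.log10(data_range ** 2 / mse))
```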

slide-54
SLIDE 54

Learning proximal operators

⚫ Learning proximal operators for optimization [ECCV 2018]

slide-55
SLIDE 55

Proximal-Dehaze Network Structure [ECCV 2018]

Learning proximal operators

slide-56
SLIDE 56

Learning proximal operators

slide-57
SLIDE 57

Learning proximal operators

slide-58
SLIDE 58

Learning proximal operators

slide-59
SLIDE 59

Learning proximal operators

slide-60
SLIDE 60

⚫ Matrix Deep Learning / Graph-based Deep Learning

Learning on 3D shapes

Data structures: graph, matrix, hyper-graph, tensor

Graph representation: shape graph, data graph

slide-61
SLIDE 61

Learning on 3D shapes

Spectral Network [ECCV-GMDL, 2018]
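As background for spectral networks on graphs, here is a hedged sketch of a generic spectral graph filter (not the cited ECCV-GMDL architecture): build the normalized Laplacian L = I − D^(−1/2) W D^(−1/2), take its eigenvectors U as the graph Fourier basis, and filter a signal as U g(Λ) Uᵀ x. In a spectral CNN the response g would be learned; here it is any Python function of the eigenvalues, and the function names are illustrative.

```python
import numpy as np

def normalized_laplacian(W):
    # L = I - D^{-1/2} W D^{-1/2}; W symmetric, non-negative, no isolated nodes
    d_inv_sqrt = 1.0 / np.sqrt(W.sum(axis=1))
    return np.eye(W.shape[0]) - d_inv_sqrt[:, None] * W * d_inv_sqrt[None, :]

def spectral_filter(W, signal, g):
    # signal: (N, C) node features; g maps eigenvalues (N,) to filter responses (N,)
    lam, U = np.linalg.eigh(normalized_laplacian(W))
    return U @ (g(lam)[:, None] * (U.T @ signal))
```

The eigenvalues of the normalized Laplacian lie in [0, 2]; on a connected regular graph, keeping only the zero eigenvalue projects every channel onto its mean. The dense eigendecomposition costs O(N³), which is why practical spectral networks often use polynomial approximations of g instead.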

slide-62
SLIDE 62

Learning on 3D shapes

slide-63
SLIDE 63

Learning to optimize

⚫ Network optimizers

– Traditional approach: optimizers designed by experts (SGD, Adam, RMSProp, AdaGrad, …)
– Learning-based approach: learn the optimizer with a recurrent neural network (RNN)

Andrychowicz, Marcin, et al. Learning to Learn by Gradient Descent by Gradient Descent. In NIPS, 2016

RNN: black box

slide-64
SLIDE 64

⚫ Hyper-Adam [AAAI 2019]:

In each iteration of network parameter updating:

– Generate multiple parameter updates using Adam with multiple weight decay rates
– Adaptively combine these updates to generate the final update

Learning to optimize

slide-65
SLIDE 65

Hyper-Adam algorithm (figure): current state → determine multiple groups of hyperparameters → generate multiple candidate updates with the corresponding hyperparameters in parallel → combine these updates into the final update using adaptively learned combination weights

Learning to optimize

Hyper-Adam Algorithm
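A hedged NumPy sketch of the combination idea only: several Adam candidates, each with its own pair of moment decay rates (β1, β2), run in parallel on the same gradient, and their normalized steps are mixed with combination weights. In Hyper-Adam the decay rates and combination weights are produced adaptively by a learned network; here they are fixed, hypothetical inputs, so this is not the published optimizer. The class name `MultiAdam` is illustrative.

```python
import numpy as np

class MultiAdam:
    """Runs several Adam candidates and mixes their updates (illustrative only)."""

    def __init__(self, shape, betas, lr=0.01, eps=1e-8):
        self.betas, self.lr, self.eps, self.t = betas, lr, eps, 0
        self.m = [np.zeros(shape) for _ in betas]   # first moments per candidate
        self.v = [np.zeros(shape) for _ in betas]   # second moments per candidate

    def step(self, theta, grad, weights):
        weights = np.asarray(weights, dtype=float)
        weights = weights / weights.sum()           # convex combination weights
        self.t += 1
        update = np.zeros_like(theta)
        for j, (b1, b2) in enumerate(self.betas):
            # Standard Adam moment estimates with bias correction, per candidate
            self.m[j] = b1 * self.m[j] + (1 - b1) * grad
            self.v[j] = b2 * self.v[j] + (1 - b2) * grad ** 2
            m_hat = self.m[j] / (1 - b1 ** self.t)
            v_hat = self.v[j] / (1 - b2 ** self.t)
            update += weights[j] * m_hat / (np.sqrt(v_hat) + self.eps)
        return theta - self.lr * update
```

At the first step every bias-corrected candidate reduces to roughly sign(grad), so the mixture does too; the candidates only diverge, and the combination weights only matter, as the moment histories accumulate.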

slide-66
SLIDE 66

Learning to optimize

Computational graph of HyperAdam

slide-67
SLIDE 67

Learning to optimize

Generalization to longer horizons:

➢ Structure
➢ Depth
➢ Dataset

slide-68
SLIDE 68

Learning to optimize

Generalization with fixed steps

slide-69
SLIDE 69

Generalization of the learners; ablation study

Learning to optimize

slide-70
SLIDE 70

Summary

⚫ Model-driven deep learning: deep learning approaches that combine the merits of model-based and deep learning-based approaches

– Gradient descent for energy minimization → deep CNN
– ADMM algorithm → deep ADMM-Net
– Non-local approach → deep fusion-net
– Graph-based deep models

⚫ Current work (IMAGINE: Image Intelligence Group)

– Deep learning on graphs / manifolds
– Learning to learn
– Applications: natural and medical image analysis / data analysis

slide-71
SLIDE 71

Thanks for your attention!