Multiple Instance Learning for Fast, Stable and Early RNN Predictions - PowerPoint PPT Presentation




SLIDE 1

The Edge of Machine Learning

Multiple Instance Learning for Fast, Stable and Early RNN Predictions

Don Dennis, Microsoft Research India. Joint work with Chirag P., Harsha and Prateek. Accepted to NIPS ’18.

SLIDE 2

Algorithms for the IDE - EdgeML

  • A library of machine learning algorithms
  • Trained on the cloud
  • Able to run on the tiniest of IoT devices

Arduino Uno

SLIDE 3

Previous Work: EdgeML Classifiers

Code: https://github.com/Microsoft/EdgeML

Bonsai (Kumar et al., ICML ’17), ProtoNN (Gupta et al., ICML ’17), Fast(G)RNN (Kusupati et al., NIPS ’18)

SLIDE 4

Previous Work: EdgeML Applications

Code: En route

Wake Word (work in progress), GesturePod (Patil et al., to be submitted)

SLIDE 5

Problem

SLIDE 6

Problem

  • Given a time series data point, classify it as a certain class.
  • GesturePod:
    – Data: accelerometer and gyroscope information
    – Task: detect whether a gesture was performed

SLIDE 7

Problem

(Figure: classifying the full time series with ProtoNN and Bonsai.)

Expensive! Prohibitive on IoT devices.

SLIDE 11

RNNs are Expensive

  • For time series data, T RNN updates are performed.
  • T is determined by the data labelling process. Example – GesturePod: 2 seconds.


SLIDE 13

RNNs are Expensive

Observe how k << T (the signature occupies only k of the T steps).

  • The RNN runs over a longer data point – unnecessarily large T and prediction time.
  • Predictors must recognize signatures at different offsets – this requires larger predictors.
  • Compute is sequential.
  • There is also lag.

SLIDE 14

RNNs are Expensive

Solution? Approach 1 of 2: exploit the fact that k << T and learn a smaller classifier. How?

SLIDES 15-25

How ?

  • STEP 1: Divide X into smaller instances.
  • STEP 2: Identify positive instances. Discard negative (noise) instances.
  • STEP 3: Use these instances to train a smaller classifier.

Note! Most of the instances are just noise.

Robust Learning – standard techniques don’t apply:
  • Too much noise.
  • Ignores temporal structure of the data.

Traditional Multiple Instance Learning (MIL) – standard techniques don’t apply:
  • Heterogeneous.
  • Ignores temporal structure of the data.
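STEP 1 above amounts to simple overlapping windowing. A minimal sketch, assuming NumPy arrays; `make_instances`, the window width `w`, and the stride are illustrative choices, not the talk's exact parameters:

```python
import numpy as np

def make_instances(X, w, stride):
    """Split a (T, d) time series into overlapping (w, d) instances."""
    T = X.shape[0]
    starts = range(0, T - w + 1, stride)
    return np.stack([X[s:s + w] for s in starts])

# Toy series: T = 10 steps, d = 2 features (e.g. accelerometer axes).
X = np.arange(20).reshape(10, 2)
Z = make_instances(X, w=4, stride=2)
print(Z.shape)  # (4, 4, 2): four instances of width 4
```

Each instance is much shorter than the full series, which is what lets the downstream classifier be smaller.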

SLIDE 26

How ?

Property 1: Positive instances are clustered together.
Property 2: The number of positive instances can be estimated.

Exploit temporal locality with MIL/robust-learning techniques.

SLIDE 27

Algorithm: MI-RNN

A two-phase algorithm that alternates between identifying positive instances and training on them.
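The two-phase alternation can be outlined as below. This is a hedged skeleton only: `fit` and `relabel` are assumed placeholder callables, not the paper's actual RNN training and instance-selection routines.

```python
import numpy as np

def mi_rnn_train(bags, bag_labels, rounds, fit, relabel):
    """Skeleton of the alternation: train on the current instance labels,
    then re-identify which instances are positive, and repeat."""
    N, n_inst = bags.shape[:2]
    # Initially, every instance inherits its bag's (data point's) label.
    inst_labels = np.repeat(bag_labels[:, None], n_inst, axis=1)
    model = None
    for _ in range(rounds):
        model = fit(bags, inst_labels)                   # phase 1: train
        inst_labels = relabel(model, bags, bag_labels)   # phase 2: relabel
    return model, inst_labels

# Stub run: 2 bags of 3 instances each; relabel marks everything negative.
model, labels = mi_rnn_train(
    np.zeros((2, 3, 4, 1)), np.array([1, 0]), rounds=2,
    fit=lambda b, y: y.sum(),                  # stand-in "model"
    relabel=lambda m, b, y: np.zeros((2, 3), dtype=int))
```

The stub shows only the control flow; in the real algorithm `relabel` uses the trained classifier's confidences, as the following steps describe.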

SLIDE 28

Algorithm: MI-RNN

  • Step 1: Assign labels – each instance inherits the label of its source data point.

SLIDE 31

Algorithm: MI-RNN

  • Step 2: Train a classifier on this data.

True positive instances are correctly labeled. Mislabeled (noise) instances look common to all classes, so the classifier is confused by them and assigns them low prediction confidence.

SLIDE 36

Algorithm: MI-RNN

  • Step 3: Wherever possible, use the classifier’s prediction scores to pick the top-κ instances. The selection should satisfy Property 1 and Property 2.
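One way to realize this top-κ selection is to keep the best contiguous run of κ instances. This is an assumed sketch consistent with the two properties, not the paper's exact selection rule:

```python
import numpy as np

def top_kappa_window(conf, kappa):
    """Label the contiguous run of kappa instances with the highest total
    confidence as positive: contiguity enforces Property 1, and the fixed
    count kappa uses the estimate from Property 2."""
    sums = [conf[s:s + kappa].sum() for s in range(len(conf) - kappa + 1)]
    start = int(np.argmax(sums))
    labels = np.zeros(len(conf), dtype=int)
    labels[start:start + kappa] = 1
    return labels

print(top_kappa_window(np.array([0.1, 0.2, 0.9, 0.8, 0.1]), kappa=2))
# [0 0 1 1 0]
```

Requiring a contiguous run is what distinguishes this from ordinary top-κ thresholding, which would ignore the temporal structure of the data.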

SLIDE 38

Algorithm: MI-RNN

  • Step 4: Repeat with the new labels.

SLIDE 39

MI-RNN: Does It Work?

  • Of course!
  • Theoretical analysis: convergence to the global optimum in linear time for nice data.
  • Experiments: significantly improved accuracy while saving computation.
    – Various tasks: activity recognition, audio keyword detection, gesture recognition.

SLIDE 43

MI-RNN: Does It Work?

Dataset (task)                Hidden Dim   LSTM    MI-RNN   Savings %
HAR-6 (activity detection)    8            89.54   91.92    62.5
                              16           92.90   93.89
                              32           93.04   91.78
Google-13 (audio)             16           86.99   89.78    50.5
                              32           89.84   92.61
                              64           91.13   93.16
WakeWord-2 (audio)            8            98.07   98.08    50.0
                              16           98.78   99.07
                              32           99.01   98.96

MI-RNN better than LSTM almost always.

SLIDE 45

MI-RNN: Does It Work?

Dataset (task)                    Hidden Dim   LSTM    MI-RNN   Savings %
GesturePod-6 (gesture detection)  8            –       98.00    50
                                  32           94.04   99.13
                                  48           97.13   98.43
DSA-19 (activity detection)       32           84.56   87.01    28
                                  48           85.35   89.60
                                  64           85.17   88.11

MI-RNN better than LSTM almost always.

SLIDE 46

MI-RNN: Savings?

Dataset       LSTM Hidden   LSTM    MI-RNN Hidden   MI-RNN   Savings   Savings at ~1% drop
HAR-6         32            93.04   16              93.89    10.5x     42x
Google-13     64            91.13   32              92.61    8x        32x
WakeWord-2    32            99.01   16              99.07    8x        32x
GesturePod-6  48            97.13   8               98.00    72x
DSA-19        64            85.17   32              87.01    5.5x

  • MI-RNN achieves the same or better accuracy with ½ or ¼ of the LSTM hidden dimension.

SLIDE 48

MI-RNN in Action

Synthetic MNIST: detecting the presence of a zero.

SLIDE 49

MI-RNN in Action

SLIDE 50

RNNs are Expensive

Solution? Approach 2 of 2: Early Prediction. How?

SLIDE 51

Can we do even better?

  • In many cases, looking only at a small prefix is enough to classify/reject: Early Prediction.

SLIDE 52

Can we do even better?

  • Existing work:
    – Assumes a pretrained classifier and uses secondary classifiers
    – Template-matching approaches
    – A separate policy for early classification
  • Not feasible!

SLIDE 53

Early Prediction

Our approach –
  • Inference: predict at each step; stop as soon as prediction confidence is high.
  • Training: incentivize early prediction by rewarding correct and early detections.

SLIDE 54

Algorithm: E-RNN

Early Loss:
Regular Loss:

Incentivizes early and consistent prediction.
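The two loss formulas on this slide appear only as images in the transcript. As an assumed illustration of the idea (not the paper's exact losses): a regular loss scores only the final step's prediction, while an early loss averages the per-step losses, so being right early and staying right is rewarded.

```python
import numpy as np

def regular_loss(step_probs, y):
    """Cross-entropy on the final step's prediction only."""
    return -np.log(step_probs[-1, y])

def early_loss(step_probs, y):
    """Average cross-entropy over every step: correct predictions made
    early (and held consistently) lower the loss."""
    return -np.mean(np.log(step_probs[:, y]))

# Two runs with the same final prediction; one is confident from step 0.
early = np.array([[0.1, 0.9], [0.1, 0.9], [0.1, 0.9]])
late  = np.array([[0.9, 0.1], [0.6, 0.4], [0.1, 0.9]])
print(regular_loss(early, 1) == regular_loss(late, 1))  # True
print(early_loss(early, 1) < early_loss(late, 1))       # True
```

The regular loss cannot distinguish the two runs, while the early loss prefers the one that commits to the right class sooner.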

SLIDE 56

E-RNN: How well does it work?

  • Abysmally badly.
  • On GesturePod-6, we lose 10-12% accuracy when attempting to predict early.
  • The model gets confused easily due to common prefixes!

(Figure: a positive and a negative datapoint sharing a common prefix.)

SLIDE 60

E-RNN: How well does it work?

  • MI-RNN can help! Instances are very tight around the signatures.
  • Low confusion – the common prefixes are small.

(Figure: a positive and a negative datapoint.)

SLIDE 63

Algorithm: EMI-RNN

  • Combine the MI-RNN training routine with the E-RNN loss function and train jointly.
  • Not only do you predict on smaller windows, you also predict early very often!

SLIDE 64

EMI-RNN: Results

SLIDE 65

EMI-RNN: Results

For HAR-6, we are 8x faster at hidden size 8, with better accuracy.

SLIDE 66

EMI-RNN: Results

Comparing across hidden sizes, savings are amplified by 4-16x.

SLIDE 67

Raspberry Pi0

Device           Hidden Dim   LSTM (ms)   MI-RNN (ms)   EMI-RNN (ms)
RPi0 (22.5 ms)   16           28.14       14.06         5.62
                 32           74.46       37.41         14.96
                 64           226.1       112.6         45.03
RPi3 (26.39 ms)  16           12.76       6.48          2.59
                 32           33.10       16.47         6.58
                 64           92.09       46.28         18.51

RPi0: 1 GHz single-core CPU, 512 MB RAM.

SLIDE 68

Conclusions and Future Work

  • 8x-72x savings with MI-RNN; additional savings from early prediction.
  • Performance better than or matching LSTM.
  • A 10x performance gain away from Arduino-class devices:
    – EMI-FastGRNN
    – Rolling LSTM

SLIDE 69

Thank You!

SLIDE 70

Next Talk

Support Recovery for Orthogonal Matching Pursuit: Upper and Lower Bounds
Somani et al., NIPS ’18