Relay: a high-level differentiable IR
Jared Roesch TVMConf December 12th, 2018
1
2
This represents months of joint work with lots of great folks:
[Figure: the TVM stack. Optimization layer: AutoTVM and AutoVTA. The High-Level Differentiable IR (Relay) sits above the Tensor Expression IR, lowering to LLVM, CUDA, Metal, and VTA, targeting edge FPGA, cloud FPGA, and ASIC across the hardware fleet.]
3
Relay
Machine learning models become functions.
4
[Figure: the existing stack. A Computation Graph IR sits above the Tensor Expression IR, lowering to LLVM, CUDA, Metal, and VTA, targeting edge FPGA, cloud FPGA, and ASIC.]
5
[Figure: the new stack. The High-Level Differentiable IR replaces the computation graph above the Tensor Expression IR, lowering to LLVM, CUDA, Metal, and VTA (edge FPGA, cloud FPGA, ASIC). Annotated workloads: LSTM, training loop, ResNet, DCGAN.]
6
Python → Relay:

    for i in range(…):
        inp, hs = …
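For concreteness, here is a toy version of such a loop; cell, xs, and the initial state below are illustrative stand-ins, not names from the talk.

    # A Python loop carrying hidden state across iterations: the pattern
    # Relay needs to capture from frontends.
    def cell(x, h):
        s = x + h
        return s, s                # (output, new hidden state)

    xs = [1.0, 2.0, 3.0]
    hs = 0.0
    outputs = []
    for i in range(len(xs)):
        inp, hs = cell(xs[i], hs)
        outputs.append(inp)
    print(outputs)                 # [1.0, 3.0, 6.0]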
8
9
Relay beats PyTorch by up to 3x on model inference.
Relay is competitive with TensorFlow and TensorFlow Lite.
Quantization and VTA bring an 11x performance improvement over baseline.
10
[Figure: Relay's architecture in three stages. Frontend: the text format, a model importer, and a DSL produce an AST, with an on-disk representation. Compiler: the optimizer and the operator language yield compiled operators. Execution: an ahead-of-time compiler, a reference interpreter, or the graph runtime, running on GPU, CPU, or FPGA.]
11
Relay is a functional, statically typed IR for machine learning.
Tensors are the primary value type.
Recursion replaces explicit loops.
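To make "tensors as the primary value type" concrete, here is a minimal sketch using TVM's Python bindings (relay.var, relay.Function, and relay.const are real TVM APIs; the printed text format varies across versions):

    import tvm
    from tvm import relay

    # A one-operator Relay function: both its input and its output are tensors.
    x = relay.var("x", shape=(10,), dtype="float32")
    f = relay.Function([x], x * relay.const(2.0))
    print(f)  # prints the function in Relay's text format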
12
[Figure: an unrolled RNN. Inputs x0 … xN feed a chain of cells; starting from state s0, each cell produces the next state s1, s2, …, sn + 1.]
13
14
def @generate(n, i, h, …) {
  if (n == 0) {
    []
  } else {
    let (output, new_hidden) = @rnn_cell(i, h, …);
    @generate(n - 1, output, new_hidden, …)
  }
}

(n, i, h, … are the parameters; n is the loop counter; the tail call to @generate is a functional-style loop.)
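A minimal Python sketch of the same functional-style loop; rnn_cell below is a toy stand-in for the slide's @rnn_cell:

    def rnn_cell(inp, hidden):
        new_hidden = inp + hidden          # toy cell; a real cell applies weights
        return new_hidden, new_hidden      # (output, new hidden state)

    def generate(n, inp, hidden):
        if n == 0:
            return []
        output, new_hidden = rnn_cell(inp, hidden)
        return [output] + generate(n - 1, output, new_hidden)

    print(generate(3, 1.0, 0.0))           # [1.0, 2.0, 4.0]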
A static type system lets us optimize aggressively and provide better errors.
Types encode shape relationships such as broadcast, flatten, concat, squeeze, and more.
15
16
Tensor : (BaseType, Shape) -> Type
Float : (Width: Int, Lanes: Int) -> BaseType
f32 = Float<32, 1>

Tensor<f32, (32, 3, 32, 32)> is a 4-d tensor: N * Channels * Height * Width.

Operator types impose constraints the type checker must check (e.g. preconditions must hold over input tensors).
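A toy Python model of these type constructors (illustrative only, not TVM's implementation):

    from dataclasses import dataclass
    from typing import Tuple

    @dataclass(frozen=True)
    class Float:               # Float : (Width: Int, Lanes: Int) -> BaseType
        width: int
        lanes: int = 1

    @dataclass(frozen=True)
    class Tensor:              # Tensor : (BaseType, Shape) -> Type
        base: Float
        shape: Tuple[int, ...]

    f32 = Float(32, 1)
    batch = Tensor(f32, (32, 3, 32, 32))   # N * Channels * Height * Width
    print(batch)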
17
Broadcasting is a tricky rule often employed in machine learning. For example, we can type broadcasting addition:

add : forall (Lhs: Type, Rhs: Type, Out: Type),
      (Lhs, Rhs) -> Out where Broadcast(Lhs, Rhs, Out)

Broadcast(Tensor<f32, (3, 4, 5)>, Tensor<f32, (n, 3, 4, 5)>, Tensor<f32, (n, 3, 4, 5)>)
Broadcast(Tensor<f32, (1, 5)>, Tensor<f32, (n, 5)>, Tensor<f32, (n, 5)>)
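The Broadcast relation can be sketched in a few lines of Python; this mirrors the NumPy-style rule rather than TVM's actual implementation, and the symbolic n above becomes a concrete dimension here:

    from itertools import zip_longest

    # Align shapes from the right; each dimension pair must match or contain a 1.
    def broadcast_shape(lhs, rhs):
        out = []
        for l, r in zip_longest(reversed(lhs), reversed(rhs), fillvalue=1):
            if l != r and 1 not in (l, r):
                raise TypeError(f"cannot broadcast {lhs} with {rhs}")
            out.append(max(l, r))
        return tuple(reversed(out))

    assert broadcast_shape((3, 4, 5), (2, 3, 4, 5)) == (2, 3, 4, 5)
    assert broadcast_shape((1, 5), (4, 5)) == (4, 5)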
18
19
Or more complex constraints, such as:

concat : forall (Args: Type, Out: Type),
         (Args) -> Out where IsTuple(Args), Concat(Args, Out)
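The Concat relation can likewise be sketched in Python (a simplification along axis 0, not TVM's implementation):

    # Inputs must agree on every dimension except the concat axis;
    # the output's extent on that axis is the sum of the inputs'.
    def concat_shape(shapes, axis=0):
        rank = len(shapes[0])
        for s in shapes:
            if len(s) != rank:
                raise TypeError("concat: rank mismatch")
            for d in range(rank):
                if d != axis and s[d] != shapes[0][d]:
                    raise TypeError("concat: dimension mismatch")
        out = list(shapes[0])
        out[axis] = sum(s[axis] for s in shapes)
        return tuple(out)

    assert concat_shape([(2, 5), (3, 5)]) == (5, 5)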
20
[Figure: executing Relay. Programs run on the graph runtime, the interpreter, or the AoT compiler, targeting CPU, GPU, and FPGA.]
21
The interpreter supports the full language (allocation, control-flow, recursion).
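A hedged sketch of invoking the interpreter through TVM's Python API, assuming a ~v0.7-era TVM (relay.create_executor exists there; argument names changed in later releases):

    import numpy as np
    import tvm
    from tvm import relay

    x = relay.var("x", shape=(2, 2), dtype="float32")
    func = relay.Function([x], relay.add(x, x))
    mod = tvm.IRModule.from_expr(func)

    # "debug" selects the reference interpreter.
    intrp = relay.create_executor("debug", mod=mod, ctx=tvm.cpu(0), target="llvm")
    print(intrp.evaluate()(np.ones((2, 2), dtype="float32")))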
22
23
The graph runtime can execute a subset of Relay programs.

def @my_func(…) { … }

TVM lowers a Relay function to a graph containing operators and parameters, which the GraphRTS executes together with operators.so.
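A sketch of the graph-runtime flow, again assuming a ~v0.7-era TVM in which relay.build returns a (graph, lib, params) triple (newer releases bundle these into a single module):

    import numpy as np
    import tvm
    from tvm import relay
    from tvm.contrib import graph_runtime

    x = relay.var("x", shape=(1, 5), dtype="float32")
    func = relay.Function([x], relay.add(x, x))
    graph, lib, params = relay.build(tvm.IRModule.from_expr(func), target="llvm")

    rt = graph_runtime.create(graph, lib, tvm.cpu(0))   # GraphRTS + operators.so
    rt.set_input("x", np.ones((1, 5), dtype="float32"))
    rt.run()
    print(rt.get_output(0))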
24
25
We built an ahead-of-time compiler in 3 weeks.
26
def @my_func(…) { … }

Standard Optimize → AoT Optimize → LittleCpp → Clang → librelay_aot_my_func.so
27
f = compile(my_func)
f(…)
Frameworks such as MxNet can be imported directly to Relay.
VTA support will be upstreamed soon after the conference.
28
[Figure: the VTA hardware design. An instruction fetch module dispatches to load, compute, and store command queues. The compute module (tensor core, vector ALU, register file, micro-op buffer) synchronizes with the load and store modules through dependency queues (LD→CMP, CMP→LD, CMP→ST, ST→CMP), reading and writing the input, weight, and store buffers, all backed by DRAM.]
Relay-compiled RNNs run inference faster, beating PyTorch by up to 3x.
Relay is competitive with TensorFlow and TensorFlow Lite across a suite of models.
Quantization and VTA bring an 11x performance improvement over baseline.
29
30
[Figure: RNN inference time for Relay-Interpreted RNN, Relay-Interpreted Cell, Relay-Compiled Cell, Relay-Compiled RNN, and PyTorch.]
31
32
The community is already building on Relay (a … frontend, a Haskell DSL).
33
Ongoing work includes quantization and performance improvements.
34
Machine learning models are just programs.
Treating models as programs enables a greater range of programs, new optimizations, and the ability to target a wide range of devices.
We welcome future collaborations.
http://sampl.cs.washington.edu http://tvm.ai
35