Basics of Numerical Optimization: Computing Derivatives - Ju Sun

Basics of Numerical Optimization: Computing Derivatives. Ju Sun, Computer Science & Engineering, University of Minnesota, Twin Cities. February 25, 2020. Derivatives for numerical optimization: gradient descent, Newton's method.


  1. Approximate the gradient

  Recall the definition f′(x) = lim_{δ→0} (f(x + δ) − f(x)) / δ.

  For f(x): R^n → R:
  ∂f/∂x_i ≈ (f(x + δe_i) − f(x)) / δ (forward)
  ∂f/∂x_i ≈ (f(x) − f(x − δe_i)) / δ (backward)
  ∂f/∂x_i ≈ (f(x + δe_i) − f(x − δe_i)) / (2δ) (central)
  (Credit: numex-blog.com)

  Similarly, to approximate the Jacobian for f(x): R^n → R^m:
  ∂f_j/∂x_i ≈ (f_j(x + δe_i) − f_j(x)) / δ (one element each time)
  ∂f/∂x_i ≈ (f(x + δe_i) − f(x)) / δ (one column each time)
  Jp ≈ (f(x + δp) − f(x)) / δ (directional)
  Central schemes can also be derived.
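  As a quick illustration (not from the slides), here is a minimal NumPy sketch of the forward and central schemes above; the function names and the test objective are illustrative.

```python
import numpy as np

def grad_forward(f, x, delta=1e-6):
    """Forward-difference approximation of the gradient of f: R^n -> R."""
    g = np.zeros_like(x, dtype=float)
    fx = f(x)
    for i in range(x.size):
        e = np.zeros_like(x, dtype=float)
        e[i] = delta
        g[i] = (f(x + e) - fx) / delta
    return g

def grad_central(f, x, delta=1e-6):
    """Central-difference approximation of the gradient of f: R^n -> R."""
    g = np.zeros_like(x, dtype=float)
    for i in range(x.size):
        e = np.zeros_like(x, dtype=float)
        e[i] = delta
        g[i] = (f(x + e) - f(x - e)) / (2 * delta)
    return g

# Example: f(x) = ||x||^2 has gradient 2x.
x = np.array([1.0, -2.0, 3.0])
print(grad_central(lambda z: np.sum(z**2), x))  # approximately [2., -4., 6.]
```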

  2. Why central? Stronger forms of Taylor's theorem

  – 1st order: if f(x): R^n → R is twice continuously differentiable,
    f(x + δ) = f(x) + ⟨∇f(x), δ⟩ + O(‖δ‖²)
  – 2nd order: if f(x): R^n → R is three-times continuously differentiable,
    f(x + δ) = f(x) + ⟨∇f(x), δ⟩ + ½ ⟨δ, ∇²f(x) δ⟩ + O(‖δ‖³)

  Why is the central scheme better?
  – Forward: by the 1st-order Taylor expansion,
    (1/δ)(f(x + δe_i) − f(x)) = (1/δ)(δ ∂f/∂x_i + O(δ²)) = ∂f/∂x_i + O(δ)
  – Central: by the 2nd-order Taylor expansion,
    (1/(2δ))(f(x + δe_i) − f(x − δe_i)) = (1/(2δ))(δ ∂f/∂x_i + ½ δ² ∂²f/∂x_i² + δ ∂f/∂x_i − ½ δ² ∂²f/∂x_i² + O(δ³)) = ∂f/∂x_i + O(δ²)
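  A small numerical check (not from the slides) of the error rates just derived, using exp as a test function whose derivative is known exactly; the O(δ) vs. O(δ²) behavior shows up directly in the printed errors.

```python
import numpy as np

f, df, x = np.exp, np.exp, 1.0  # the exact derivative of exp is exp itself
for delta in [1e-1, 1e-2, 1e-3]:
    fwd = (f(x + delta) - f(x)) / delta
    cen = (f(x + delta) - f(x - delta)) / (2 * delta)
    # forward error shrinks roughly linearly in delta, central roughly quadratically
    print(delta, abs(fwd - df(x)), abs(cen - df(x)))
```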

  3. Approximate the Hessian

  – Recall that for f(x): R^n → R that is 2nd-order differentiable, ∂f/∂x_i(x): R^n → R. So
    ∂²f/(∂x_j ∂x_i)(x) = ∂/∂x_j (∂f/∂x_i)(x) ≈ ((∂f/∂x_i)(x + δe_j) − (∂f/∂x_i)(x)) / δ
  – We can also compute one row of the Hessian each time by
    ∂/∂x_j (∂f/∂x)(x) ≈ ((∂f/∂x)(x + δe_j) − (∂f/∂x)(x)) / δ,
    obtaining H̃, which might not be symmetric. Return ½(H̃ + H̃ᵀ) instead.
  – Most times (e.g., in TRM, Newton-CG), only ∇²f(x)v for certain v's is needed (see, e.g., Manopt https://www.manopt.org/):
    ∇²f(x)v ≈ (∇f(x + δv) − ∇f(x)) / δ    (1)
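  A minimal sketch of the Hessian-vector approximation in Eq. (1), assuming an analytic gradient grad_f is available; all names here are illustrative, and a quadratic test function makes the exact product easy to compare against.

```python
import numpy as np

def hvp_fd(grad_f, x, v, delta=1e-6):
    """Finite-difference Hessian-vector product, as in Eq. (1):
    Hv ~= (grad_f(x + delta*v) - grad_f(x)) / delta."""
    return (grad_f(x + delta * v) - grad_f(x)) / delta

# Example: f(x) = 0.5 x^T A x with A symmetric, so grad f(x) = A x and Hessian-vector product = A v.
A = np.array([[2.0, 1.0], [1.0, 3.0]])
grad_f = lambda x: A @ x
x, v = np.array([1.0, -1.0]), np.array([0.5, 2.0])
print(hvp_fd(grad_f, x, v), A @ v)  # the two outputs should nearly agree
```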

  4. A few words

  – Can be used as a sanity check for the correctness of an analytic gradient
  – Finite-difference approximation of higher (i.e., ≥ 2)-order derivatives, combined with high-order iterative methods, can be very efficient (e.g., Manopt https://www.manopt.org/tutorial.html#costdescription)
  – Numerical stability can be an issue: truncation and round-off errors (finite δ; accurate evaluation of the numerators)
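  Following the first point above, a hedged sketch of a gradient sanity check that compares an analytic gradient against a central-difference estimate via a relative error; check_grad and the tolerance are illustrative choices.

```python
import numpy as np

def check_grad(f, grad_f, x, delta=1e-6, tol=1e-4):
    """Sanity-check an analytic gradient against a central-difference estimate."""
    g_num = np.zeros_like(x, dtype=float)
    for i in range(x.size):
        e = np.zeros_like(x, dtype=float)
        e[i] = delta
        g_num[i] = (f(x + e) - f(x - e)) / (2 * delta)
    err = np.linalg.norm(grad_f(x) - g_num) / max(1.0, np.linalg.norm(g_num))
    return err < tol, err

# Example: f(x) = sum(x^3) has gradient 3 x^2.
x = np.random.randn(4)
print(check_grad(lambda z: np.sum(z**3), lambda z: 3 * z**2, x))
```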

  5. Outline: Analytic differentiation; Finite-difference approximation; Automatic differentiation; Differentiable programming; Suggested reading

  6. Four kinds of computing techniques (Credit: [Baydin et al., 2017]). Misnomer: should be automatic numerical differentiation.

  7. Forward mode in 1D

  Consider a univariate function f_k ∘ f_{k−1} ∘ ··· ∘ f_2 ∘ f_1(x): R → R. Write y_0 = x, y_1 = f_1(y_0), y_2 = f_2(y_1), ..., y_k = f_k(y_{k−1}), or in computational graph form.

  Chain rule:
  df/dx = dy_k/dy_0 = (dy_k/dy_{k−1}) ((dy_{k−1}/dy_{k−2}) (··· ((dy_2/dy_1) (dy_1/dy_0)) ···))

  Compute df/dx|_{x_0} in one pass, from the innermost to the outermost parentheses:

  Input: x_0; initialize dy_0/dy_0|_{x_0} = 1
  for i = 1, ..., k do
    compute y_i = f_i(y_{i−1})
    compute dy_i/dy_0|_{x_0} = (dy_i/dy_{i−1})|_{y_{i−1}} · (dy_{i−1}/dy_0)|_{x_0} = f_i′(y_{i−1}) · (dy_{i−1}/dy_0)|_{x_0}
  end for
  Output: dy_k/dy_0|_{x_0}
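  A minimal Python sketch of this forward-mode loop, under the assumption that each elementary function f_i and its derivative f_i′ are supplied explicitly as lists; names are illustrative.

```python
import math

def forward_mode_1d(fs, dfs, x0):
    """Forward mode for f_k o ... o f_1 at x0: propagate (y_i, dy_i/dy_0) together.
    fs and dfs are lists of the elementary functions and their derivatives."""
    y, dy_dy0 = x0, 1.0            # y_0 = x_0, dy_0/dy_0 = 1
    for f_i, df_i in zip(fs, dfs):
        dy_dy0 = df_i(y) * dy_dy0  # chain rule: dy_i/dy_0 = f_i'(y_{i-1}) * dy_{i-1}/dy_0
        y = f_i(y)                 # y_i = f_i(y_{i-1})
    return y, dy_dy0

# Example: f(x) = exp(sin(x)); its derivative is cos(x) * exp(sin(x)).
fs  = [math.sin, math.exp]
dfs = [math.cos, math.exp]
print(forward_mode_1d(fs, dfs, 0.5))
```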

  8. Reverse mode in 1D

  Consider the same univariate function f_k ∘ f_{k−1} ∘ ··· ∘ f_2 ∘ f_1(x): R → R, with y_0 = x, y_1 = f_1(y_0), ..., y_k = f_k(y_{k−1}), or in computational graph form.

  Chain rule, grouped the other way:
  df/dx = dy_k/dy_0 = ((··· ((dy_k/dy_{k−1}) (dy_{k−1}/dy_{k−2})) ···) (dy_2/dy_1)) (dy_1/dy_0)

  Compute df/dx|_{x_0} in two passes, from the innermost to the outermost parentheses for the second:

  Input: x_0; initialize dy_k/dy_k = 1
  for i = 1, ..., k do
    compute y_i = f_i(y_{i−1})
  end for   // forward pass
  for i = k−1, k−2, ..., 0 do
    compute dy_k/dy_i|_{y_i} = (dy_k/dy_{i+1})|_{y_{i+1}} · (dy_{i+1}/dy_i)|_{y_i} = (dy_k/dy_{i+1})|_{y_{i+1}} · f_{i+1}′(y_i)
  end for   // backward pass
  Output: dy_k/dy_0|_{x_0}
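  A matching sketch of the reverse-mode procedure, again assuming the elementary functions and their derivatives are given as lists; the forward pass stores all intermediate y_i, and the backward pass accumulates dy_k/dy_i.

```python
import math

def reverse_mode_1d(fs, dfs, x0):
    """Reverse mode for f_k o ... o f_1 at x0: forward pass stores y_0..y_k,
    backward pass accumulates dy_k/dy_i from i = k-1 down to 0."""
    ys = [x0]
    for f_i in fs:                        # forward pass: y_i = f_i(y_{i-1})
        ys.append(f_i(ys[-1]))
    bar = 1.0                             # dy_k/dy_k = 1
    for i in range(len(fs) - 1, -1, -1):  # backward pass
        bar = bar * dfs[i](ys[i])         # dy_k/dy_i = dy_k/dy_{i+1} * f_{i+1}'(y_i)
    return ys[-1], bar

# Same example as the forward-mode sketch: f(x) = exp(sin(x)).
fs  = [math.sin, math.exp]
dfs = [math.cos, math.exp]
print(reverse_mode_1d(fs, dfs, 0.5))
```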

  9. Forward vs reverse modes

  – forward mode AD: one forward pass; compute the intermediate variable values and derivative values together
  – reverse mode AD: one forward pass to compute the intermediate variable values, then one backward pass to compute the intermediate derivatives

  Effectively, these are two different ways of grouping the multiplicative differential terms:

  df/dx = dy_k/dy_0 = (dy_k/dy_{k−1}) ((dy_{k−1}/dy_{k−2}) (··· ((dy_2/dy_1) (dy_1/dy_0)) ···)),
  i.e., starting from the root: dy_0/dy_0 ↦ dy_1/dy_0 ↦ dy_2/dy_0 ↦ ··· ↦ dy_k/dy_0

  df/dx = dy_k/dy_0 = ((··· ((dy_k/dy_{k−1}) (dy_{k−1}/dy_{k−2})) ···) (dy_2/dy_1)) (dy_1/dy_0),
  i.e., starting from the leaf: dy_k/dy_k ↦ dy_k/dy_{k−1} ↦ dy_k/dy_{k−2} ↦ ··· ↦ dy_k/dy_0

  ...mixed forward and reverse modes are indeed possible!

  10. Chain rule in computational graphs

  Chain rule: let f: R^n → R^m and h: R^m → R^k, with f differentiable at x, y = f(x), and h differentiable at y. Then h ∘ f: R^n → R^k is differentiable at x, and (writing z = h(y))

  J_{h∘f}(x) = J_h(f(x)) J_f(x),  or  ∂z_j/∂x_i = Σ_{ℓ=1}^{m} (∂z_j/∂y_ℓ)(∂y_ℓ/∂x_i)  for all i, j

  – Each node is a variable, viewed as a function of all incoming variables
  – If node B is a descendant of node A, ∂B/∂A is the rate of change in B wrt a change in A
  – Traveling along a path, rates of change are multiplied
  – Chain rule: sum up the rates over all connecting paths! (e.g., x_2 to z_j as shown)

  NB: this is a computational graph, not a NN

  11. A multivariate example — forward mode

  y = (sin(x_1/x_2) + x_1/x_2 − e^{x_2}) · (x_1/x_2 − e^{x_2})

  – interested in ∂/∂x_1; for each variable v_i, write v̇_i := ∂v_i/∂x_1
  – for each node, sum up partials over all incoming edges, e.g., v̇_4 = (∂v_4/∂v_1) v̇_1 + (∂v_4/∂v_3) v̇_3
  – complexity: O(#edges + #nodes)
  – for f: R^n → R^m, make n forward passes: O(n(#edges + #nodes))
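  A hedged illustration of forward mode on this example using a tiny dual-number class (a value paired with its ∂/∂x_1); this is a sketch of the idea rather than the slide's tabulated evaluation trace, and the class and variable names are illustrative.

```python
import math

class Dual:
    """Minimal dual numbers: val carries v_i, dot carries its derivative wrt x_1."""
    def __init__(self, val, dot=0.0):
        self.val, self.dot = val, dot
    def __add__(self, o): return Dual(self.val + o.val, self.dot + o.dot)
    def __sub__(self, o): return Dual(self.val - o.val, self.dot - o.dot)
    def __mul__(self, o): return Dual(self.val * o.val,
                                      self.dot * o.val + self.val * o.dot)
    def __truediv__(self, o): return Dual(self.val / o.val,
                                          (self.dot * o.val - self.val * o.dot) / o.val**2)

def sin(d): return Dual(math.sin(d.val), math.cos(d.val) * d.dot)
def exp(d): return Dual(math.exp(d.val), math.exp(d.val) * d.dot)

# Seed x1 with dot = 1 to get dy/dx1; seed x2 instead to get dy/dx2.
x1, x2 = Dual(1.5, 1.0), Dual(0.5, 0.0)
t = x1 / x2
y = (sin(t) + t - exp(x2)) * (t - exp(x2))
print(y.val, y.dot)  # function value and dy/dx1
```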

  12. A multivariate example — reverse mode

  – interested in ∂y/∂·; for each variable v_i, write v̄_i := ∂y/∂v_i (called the adjoint variable)
  – for each node, sum up partials over all outgoing edges, e.g., v̄_4 = (∂v_5/∂v_4) v̄_5 + (∂v_6/∂v_4) v̄_6
  – complexity: O(#edges + #nodes)
  – for f: R^n → R^m, make m backward passes: O(m(#edges + #nodes))

  (example from Ch 1 of [Griewank and Walther, 2008])

  13. Forward vs. reverse modes

  For a general function f: R^n → R^m, suppose there is no loop in the computational graph, i.e., the graph is acyclic. Define E: set of edges; V: set of nodes.

  forward mode: start from the roots, end with the leaves; invariant v̇_i := ∂v_i/∂x (x: the root of interest); rule: sum over incoming edges; complexity: O(n|E| + n|V|); better when m ≫ n
  reverse mode: start from the leaves, end with the roots; invariant v̄_i := ∂y/∂v_i (y: the leaf of interest); rule: sum over outgoing edges; complexity: O(m|E| + m|V|); better when n ≫ m

  14. Directional derivatives

  Consider f(x): R^n → R^m. Let the v's be the variables in its computational graph; in particular, v_{n−1} = x_1, v_{n−2} = x_2, ..., v_0 = x_n. D_p(·) denotes the directional derivative wrt p. In practical implementations:

  forward mode: compute J_f p, i.e., a Jacobian-vector product
  – Why? (1) Columns of J_f can be obtained by setting p = e_1, ..., e_n. (2) When J_f has special structure (e.g., sparsity), judicious choices of p save computation. (3) A problem may only need J_f p for a specific p, not J_f itself, saving computation again.
  – How? (1) Initialize D_p v_{n−1} = p_1, ..., D_p v_0 = p_n. (2) Apply the chain rule:
    ∇_x v_i = Σ_{j: incoming} (∂v_i/∂v_j) ∇_x v_j  ⟹  D_p v_i = Σ_{j: incoming} (∂v_i/∂v_j) D_p v_j

  reverse mode: compute J_fᵀ q = ∇_x(fᵀq), i.e., a Jacobian-transpose-vector product
  – Why? Similar to the above.
  – How? Track d(fᵀq)/dv_i: d(fᵀq)/dv_i = Σ_{k: outgoing} (∂v_k/∂v_i) · d(fᵀq)/dv_k
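  For concreteness, a small JAX sketch (not from the slides) of both products: jax.jvp computes J_f p in forward mode and jax.vjp computes J_fᵀ q in reverse mode; the test map f is an arbitrary illustration.

```python
import jax
import jax.numpy as jnp

def f(x):
    # a small R^3 -> R^2 test map
    return jnp.stack([jnp.sin(x[0]) * x[1], x[1] * x[2] + jnp.exp(x[0])])

x = jnp.array([0.3, 1.2, -0.7])
p = jnp.array([1.0, 0.0, 0.5])
q = jnp.array([2.0, -1.0])

# Forward mode: Jacobian-vector product J_f(x) p
y, jvp_val = jax.jvp(f, (x,), (p,))

# Reverse mode: Jacobian-transpose-vector product J_f(x)^T q
y, vjp_fun = jax.vjp(f, x)
vjp_val, = vjp_fun(q)

print(jvp_val, vjp_val)
```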

  15. Tensor abstraction

  Tensors: multi-dimensional arrays. Each node in the computational graph can be a tensor (scalar, vector, matrix, 3-D tensor, ...), e.g., the computational graph for a DNN
  f(W) = ‖Y − σ(W_k σ(W_{k−1} σ(··· (W_1 X))))‖_F²

  16. Tensor abstraction (cont'd)

  – Abstracts out low-level details; operations are often simple (e.g., ∗, σ), so the partials are simple
  – Tensor (i.e., vector) chain rules apply, often via tensor-free computation
  – Basis of implementation for TensorFlow, PyTorch, Jax, etc. (Jax: https://github.com/google/jax)

  Good to know:
  – In practice, graphs are built automatically by the software
  – Higher-order derivatives can also be handled, particularly the Hessian-vector product ∇²f(x)v (check out Jax!)
  – Auto-diff in TensorFlow and PyTorch is specialized to DNNs and focuses on 1st order; Jax (in Python) is full-fledged and also supports GPU
  – General resources for autodiff: http://www.autodiff.org/, [Griewank and Walther, 2008]
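  As a hedged example of the Hessian-vector product point above, a standard forward-over-reverse pattern in JAX; the test function f is illustrative, and the explicit jax.hessian call is only there for checking at this small size.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(jnp.sin(x) ** 2) + jnp.dot(x, x)

def hvp(f, x, v):
    # Hessian-vector product: differentiate the gradient in direction v,
    # without ever forming the Hessian explicitly
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

x = jnp.array([0.1, 0.2, 0.3])
v = jnp.array([1.0, 0.0, -1.0])
print(hvp(f, x, v))
# Check against the explicit Hessian (fine at this small size):
print(jax.hessian(f)(x) @ v)
```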

  17. Autodiff in PyTorch

  Solve least squares f(x) = ½‖y − Ax‖²_2, with ∇f(x) = −Aᵀ(y − Ax). (Slide shows a plot of loss vs. iteration.)
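  A minimal PyTorch sketch (assumed setup, not the slide's exact script) that solves this least-squares problem by gradient descent, with autograd supplying the gradient −Aᵀ(y − Ax); sizes and step size are illustrative.

```python
import torch

torch.manual_seed(0)
A = torch.randn(50, 10)
y = torch.randn(50)
x = torch.zeros(10, requires_grad=True)

lr = 1e-2
for _ in range(200):
    loss = 0.5 * torch.sum((y - A @ x) ** 2)
    loss.backward()            # reverse-mode AD fills x.grad with -A^T(y - Ax)
    with torch.no_grad():      # plain gradient-descent update, outside the graph
        x -= lr * x.grad
    x.grad.zero_()
print(loss.item())
```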

  18. Autodiff in PyTorch

  Train a shallow neural network
  f(W) = Σ_i ‖y_i − W_2 σ(W_1 x_i)‖²_2,
  where σ(z) = max(z, 0), i.e., ReLU.
  https://pytorch.org/tutorials/beginner/pytorch_with_examples.html
  – torch.mm
  – torch.clamp
  – torch.no_grad()
  Back-propagation is reverse-mode auto-differentiation!
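  A hedged sketch in the spirit of the linked tutorial, using the three primitives listed above (torch.mm, torch.clamp, torch.no_grad); the sizes, learning rate, and random data are illustrative.

```python
import torch

torch.manual_seed(0)
N, d_in, d_hidden, d_out = 64, 100, 50, 10
X = torch.randn(N, d_in)
Y = torch.randn(N, d_out)

W1 = torch.randn(d_in, d_hidden, requires_grad=True)
W2 = torch.randn(d_hidden, d_out, requires_grad=True)

lr = 1e-6
for _ in range(500):
    H = torch.clamp(torch.mm(X, W1), min=0)   # ReLU: max(z, 0)
    Y_pred = torch.mm(H, W2)
    loss = torch.sum((Y - Y_pred) ** 2)
    loss.backward()                            # back-prop = reverse-mode AD
    with torch.no_grad():                      # update weights outside the graph
        W1 -= lr * W1.grad
        W2 -= lr * W2.grad
        W1.grad.zero_()
        W2.grad.zero_()
print(loss.item())
```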

  19. Outline: Analytic differentiation; Finite-difference approximation; Automatic differentiation; Differentiable programming; Suggested reading

  20. Example: image enhancement

  – Each stage applies a parameterized function to the image, i.e., q_{w_k} ∘ ··· ∘ h_{w_3} ∘ g_{w_2} ∘ f_{w_1}(X) (X is the camera raw)
  – The parameterized functions may or may not be DNNs
  – Each function may be analytic, or simply a chunk of code that depends on the parameters
  – The w_i's are the trainable parameters

  (Credit: https://people.csail.mit.edu/tzumao/gradient_halide/)

  21. Example: image enhancement (cont'd)

  – The trainable parameters are learned by gradient descent based on auto-differentiation
  – This generalizes training DNNs with the classic feedforward structure to training general parameterized functions, using derivative-based methods

  (Credit: https://people.csail.mit.edu/tzumao/gradient_halide/)

  22. Example: control a trebuchet
  https://fluxml.ai/2019/03/05/dp-vs-rl.html
