Generative Models and Model Criticism via Optimized Maximum Mean Discrepancy



  1. Generative Models and Model Criticism via Optimized Maximum Mean Discrepancy. Dougal J. Sutherland, Gatsby unit, UCL.

  2. Two-sample testing. Observe two different datasets: X ∼ P vs. Y ∼ Q. Question we want to answer: is P = Q?

  3. Two-sample testing. Applications:
  • Do cigarette smokers and non-smokers have different distributions of cancers?
  • Do these neurons behave differently when the subject is looking at background image A instead of B?
  • Do these columns from different databases mean the same thing?
  • Did my generative model actually learn the distribution I wanted it to?

  4. Standard approaches.
  • (Unpaired) t-test, Wilcoxon rank-sum test, etc.
    – Only test differences in location (mean)
  • Kolmogorov–Smirnov test
    – Tests for all differences
    – Nonparametric
    – Hard to extend to > 1d
  • Want a test that looks for all possible differences, without parametric assumptions, in multiple dimensions (a SciPy sketch of the classical baselines follows this list).
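As a concrete illustration (not from the slides), a minimal SciPy sketch of the classical 1-d baselines; the Laplace scale 1/√2 matches the example distribution used later in the talk, so the two samples share mean and variance:

```python
# Illustrative baseline comparison; assumes 1-d data and SciPy.
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
X = rng.normal(0.0, 1.0, size=500)                 # X ~ P = N(0, 1)
Y = rng.laplace(0.0, 1.0 / np.sqrt(2), size=500)   # Y ~ Q: same mean and variance as P

# The t-test only compares means, so it should NOT reject here.
print(stats.ttest_ind(X, Y))
# Kolmogorov-Smirnov compares whole 1-d CDFs, so it has a chance to reject.
print(stats.ks_2samp(X, Y))
```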

  5. Defining a two-sample test.
  1. Choose a distance ρ(P, Q) between distributions
     – Ideally, ρ(P, Q) = 0 if and only if P = Q
  2. Estimate the distribution distance from data: ρ̂(X, Y)
  3. Choose a rejection threshold: say P ≠ Q when ρ̂ > c_α, with Pr_{X,Y∼P}(ρ̂(X, Y) > c_α) < α
  Reminder:
  • The level of a test is the probability of false rejection
  • The power of a test is the probability of true rejection
  (A sketch of this recipe as code follows.)
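A minimal sketch of this recipe as code, assuming some distance estimator `rho_hat` is supplied (e.g. the MMD estimator defined later); the threshold step here already uses the permutation idea that slide 22 introduces:

```python
import numpy as np

def two_sample_test(X, Y, rho_hat, alpha=0.05, n_perm=500, seed=None):
    """Recipe from the slide: estimate rho, then pick c_alpha so that the
    estimated false-rejection probability under P = Q is below alpha."""
    rng = np.random.default_rng(seed)
    observed = rho_hat(X, Y)
    pooled, m = np.concatenate([X, Y]), len(X)
    # Simulate the null P = Q by randomly re-splitting the pooled sample.
    null = np.empty(n_perm)
    for i in range(n_perm):
        p = rng.permutation(len(pooled))
        null[i] = rho_hat(pooled[p[:m]], pooled[p[m:]])
    c_alpha = np.quantile(null, 1 - alpha)
    return observed > c_alpha, observed, c_alpha
```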

  6. A kernel distance on distributions. Quick reminder about kernels:
  • Our data lives in a space X
  • The kernel is a similarity function k: X × X → ℝ, e.g. k(x, y) = exp(−‖x − y‖² / (2σ²))
  • It corresponds to a reproducing kernel Hilbert space (RKHS) H, with feature map ϕ: X → H, via ⟨ϕ(x), ϕ(y)⟩_H = k(x, y)
  We'll use a base kernel on X to make a distance between distributions over X.
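A minimal NumPy sketch of this Gaussian (RBF) base kernel, evaluated for all pairs of rows; the function name and vectorization are mine, not from the talk:

```python
import numpy as np

def gaussian_kernel(X, Y, sigma=1.0):
    """k(x, y) = exp(-||x - y||^2 / (2 sigma^2)) for every row pair of X and Y."""
    X = np.asarray(X, dtype=float).reshape(len(X), -1)
    Y = np.asarray(Y, dtype=float).reshape(len(Y), -1)
    sq_dists = ((X[:, None, :] - Y[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-sq_dists / (2 * sigma ** 2))
```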

  7–11. Mean embeddings of distributions (built up over five animation slides).
  The mean embedding of a distribution P in an RKHS H is
    µ_P = E_{x∼P}[ϕ(x)].
  Remember ⟨ϕ(x), ϕ(y)⟩ = k(x, y), so we can think of ϕ(x) as k(x, ·).
  [Figure: a density Pr(X) on the left; its mean embedding µ_P(x) = E_X[k(X, x)] on the right.]
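Reusing the `gaussian_kernel` sketch above, the empirical mean embedding evaluated on a grid is just a kernel average; a small illustrative helper:

```python
import numpy as np

def mean_embedding(X, xs, kernel, **kw):
    """mu_hat_P(x) = (1/m) sum_i k(X_i, x), evaluated at each point in xs."""
    return kernel(X, xs, **kw).mean(axis=0)

# e.g. plot mean_embedding(X, grid, gaussian_kernel, sigma=0.5) over a grid
# of x values to reproduce the right-hand panels of the figure.
```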


  12. Maximum Mean Discrepancy (MMD).
  The MMD is the distance between mean embeddings, µ_P = E_{X∼P}[ϕ(X)]:
    mmd²(P, Q) = ‖µ_P − µ_Q‖²_H = ⟨µ_P, µ_P⟩ + ⟨µ_Q, µ_Q⟩ − 2⟨µ_P, µ_Q⟩
  The cross term is an expected kernel evaluation:
    ⟨µ_P, µ_Q⟩ = ⟨E_{X∼P}[ϕ(X)], E_{Y∼Q}[ϕ(Y)]⟩ = E_{X∼P, Y∼Q}[⟨ϕ(X), ϕ(Y)⟩] = E_{X∼P, Y∼Q}[k(X, Y)]
  Equivalently, taking the supremum over the unit ball of H:
    mmd(P, Q) = sup_{f∈H, ‖f‖_H ≤ 1} E_{X∼P} f(X) − E_{Y∼Q} f(Y)

  13–21. MMD estimator (built up over nine animation slides).
  mmd²(P, Q) = ‖µ_P − µ_Q‖²_H = ⟨µ_P, µ_P⟩ + ⟨µ_Q, µ_Q⟩ − 2⟨µ_P, µ_Q⟩
  Each inner product is an expected kernel value, e.g. ⟨µ_P, µ_Q⟩ = E_{X∼P, Y∼Q}[k(X, Y)], so each can be estimated by an empirical average of kernel evaluations:
    ⟨µ̂_P, µ̂_Q⟩ = (1 / (mn)) Σ_{i,j} k(X_i, Y_j)
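A sketch of the resulting plug-in estimator; the slide's 1/(mn) averages correspond to this biased version (an unbiased variant would additionally drop the diagonal terms of the within-sample averages):

```python
import numpy as np

def mmd2_biased(X, Y, kernel, **kw):
    """Plug-in mmd^2 estimate: replace each inner product of mean embeddings
    by the empirical average of kernel values over the corresponding samples."""
    k_xx = kernel(X, X, **kw).mean()   # estimates <mu_P, mu_P>
    k_yy = kernel(Y, Y, **kw).mean()   # estimates <mu_Q, mu_Q>
    k_xy = kernel(X, Y, **kw).mean()   # estimates <mu_P, mu_Q>
    return k_xx + k_yy - 2 * k_xy
```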


  22. Permutation testing.
  We still need the rejection threshold: Pr_{X,Y∼P}(ρ̂(X, Y) > c_α) ≤ α.
  When P = Q, the MMD asymptotics depend on P, so it's hard to find a threshold that way.
  Permutation test: pool X and Y and split randomly to estimate the MMD when P = Q, giving estimates mmd²(X^(1), Y^(1)), mmd²(X^(2), Y^(2)), …
  ĉ_α: the (1 − α)th quantile of the permuted estimates mmd²(X^(i), Y^(i)). (A short usage sketch follows.)
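This is exactly the threshold step in the `two_sample_test` sketch after slide 5; plugging in the MMD estimator and Gaussian kernel from the earlier sketches (the bandwidth 0.5 is an arbitrary illustrative choice):

```python
# Hypothetical usage, reusing the earlier sketches.
reject, mmd2_hat, c_alpha = two_sample_test(
    X, Y,
    rho_hat=lambda A, B: mmd2_biased(A, B, gaussian_kernel, sigma=0.5),
    alpha=0.05, n_perm=500)
```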

  23. Computing permutation tests.
  Precompute the joint kernel matrix over the pooled sample,
    K = [ K_XX  K_XY ; K_YX  K_YY ],  where (K_XX)_{ij} = k(X_i, X_j), (K_XY)_{ij} = k(X_i, Y_j), and so on.
  Each element of K is added or subtracted in a term of each permutation estimate. So, do it all in one pass.
  Original Matlab code: 381 s; better Python code: 182 s.
  [Figure: time (s) vs. number of threads, comparing our permutation code against MKL spectral code.]
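The talk's implementation accumulates every element of K into all permutation estimates in a single pass (and parallelizes over threads); a simpler NumPy sketch of the same core idea, computing K once and only reindexing it per permutation, might look like:

```python
import numpy as np

def mmd_perm_test_precomputed(X, Y, kernel, alpha=0.05, n_perm=500, seed=None, **kw):
    """Compute the joint kernel matrix K once; each permutation estimate is
    then only a reindexing of K, never a fresh kernel evaluation."""
    rng = np.random.default_rng(seed)
    pooled = np.concatenate([X, Y])
    m, n = len(X), len(X) + len(Y)
    K = kernel(pooled, pooled, **kw)          # all pairwise kernel values, once

    def mmd2(idx):
        a, b = idx[:m], idx[m:]
        return (K[np.ix_(a, a)].mean() + K[np.ix_(b, b)].mean()
                - 2 * K[np.ix_(a, b)].mean())

    observed = mmd2(np.arange(n))
    null = np.array([mmd2(rng.permutation(n)) for _ in range(n_perm)])
    return observed > np.quantile(null, 1 - alpha)
```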

  24. Example two-sample test. X ∼ P = N(0, 1) vs. Y ∼ Q = Laplace(0, 1/√2).
  1. Choose a kernel k
  2. Estimate the MMD for the true division and for many permutations
  3. Reject if mmd²_k(X, Y) > ĉ_α
  [Figure: the two densities, which match in mean and variance but differ in shape.]
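Putting the earlier sketches together on this exact example (the sample size 500 and bandwidth 0.5 are my choices, not the talk's):

```python
import numpy as np

rng = np.random.default_rng(1)
X = rng.normal(0.0, 1.0, size=(500, 1))                  # P = N(0, 1)
Y = rng.laplace(0.0, 1.0 / np.sqrt(2), size=(500, 1))    # Q = Laplace(0, 1/sqrt(2))

# P and Q share mean and variance, so only a shape-sensitive test separates them.
print(mmd_perm_test_precomputed(X, Y, gaussian_kernel, sigma=0.5))
```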

  25–27. The kernel matters! (built up over three animation slides).
  The witness function f helps compare samples:
    mmd(P, Q) = E_{X∼P} f(X) − E_{Y∼Q} f(Y), with f(x) = µ_P(x) − µ_Q(x) = E_{X∼P} k(x, X) − E_{Y∼Q} k(x, Y).
  For the example above, different Gaussian bandwidths give very different tests: σ = 0.75 gives p ≈ 0.0; σ = 2 and σ = 1 give p ≈ 0.3; σ = 0.1 gives p ≈ 0.16.
  [Figure: witness functions plotted over x ∈ [−4, 4] for each bandwidth.]
  (An empirical witness-function sketch follows.)
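A small sketch of the empirical witness function from the slide's formula, reusing `gaussian_kernel`; it is positive where P's sample puts more kernel mass than Q's and negative where Q's does:

```python
import numpy as np

def witness(X, Y, xs, kernel, **kw):
    """f(x) = mu_P(x) - mu_Q(x), estimated with empirical kernel averages."""
    return kernel(X, xs, **kw).mean(axis=0) - kernel(Y, xs, **kw).mean(axis=0)

# e.g. plot witness(X, Y, np.linspace(-4, 4, 200)[:, None], gaussian_kernel,
# sigma=0.75) to reproduce the kind of curves shown on these slides.
```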

  28–29. Choosing a kernel (two animation slides).
  So we need a way to pick a kernel to do the test. Split the data: use one part of X and Y to choose a kernel k, then use the chosen k in the MMD test on the rest.
  How to pick k? Typically: maximize the MMD. But we want the (asymptotically) most powerful test. (A bandwidth-selection sketch follows.)
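Two common selection rules as a sketch, reusing the earlier helpers: the median heuristic as a default bandwidth, and the split-and-maximize-MMD rule the slide calls "typical" (which, as the talk goes on to argue, is not the same as maximizing test power):

```python
import numpy as np

def median_heuristic(X, Y):
    """Default bandwidth: median pairwise distance of the pooled sample."""
    Z = np.concatenate([X, Y]).reshape(len(X) + len(Y), -1)
    dists = np.sqrt(((Z[:, None, :] - Z[None, :, :]) ** 2).sum(axis=-1))
    return np.median(dists[np.triu_indices_from(dists, k=1)])

def choose_sigma_by_mmd(X_train, Y_train, sigmas):
    """On a held-out split, pick the bandwidth with the largest MMD estimate;
    the test itself is then run on the remaining data with that bandwidth."""
    return max(sigmas, key=lambda s: mmd2_biased(X_train, Y_train,
                                                 gaussian_kernel, sigma=s))
```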
