Generative Models and Model Criticism via Optimized Maximum Mean Discrepancy
Dougal J. Sutherland, Gatsby Unit, UCL
Two-sample testing
Observe two different datasets: X ∼ P vs. Y ∼ Q.
Question we want to answer: is P = Q?
Two-sample testing
Applications:
• Do cigarette smokers and non-smokers have different distributions of cancers?
• Do these neurons behave differently when the subject is looking at background image A instead of B?
• Do these columns from different databases mean the same thing?
• Did my generative model actually learn the distribution I wanted it to?
Standard approaches
• (Unpaired) t-test, Wilcoxon rank-sum test, etc.
  – Only test differences in location (mean)
• Kolmogorov–Smirnov test
  – Tests for all differences
  – Nonparametric
  – Hard to extend to > 1d
• Want a test that looks for all possible differences, without parametric assumptions, in multiple dimensions
Defining a two-sample test
1. Choose a distance between distributions, ρ(P, Q).
   – Ideally, ρ(P, Q) = 0 if and only if P = Q.
2. Estimate the distribution distance from data: ρ̂(X, Y).
3. Choose a rejection threshold: say P ≠ Q when ρ̂ > c_α, where Pr_{X,Y∼P}(ρ̂(X, Y) > c_α) < α.
Reminder:
• The level of a test is the probability of false rejection.
• The power of a test is the probability of true rejection.
A Kernel Distance on Distributions
Quick reminder about kernels:
• Our data lives in a space 𝒳.
• The kernel is a similarity function k : 𝒳 × 𝒳 → ℝ, e.g. the Gaussian kernel
    k(x, y) = exp(−‖x − y‖² / (2σ²))
• It corresponds to a reproducing kernel Hilbert space (RKHS) ℋ, with feature map φ : 𝒳 → ℋ, defined by ⟨φ(x), φ(y)⟩_ℋ = k(x, y).
We'll use a base kernel on 𝒳 to make a distance between distributions over 𝒳.
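As a concrete sketch (not the talk's code), here is the Gaussian base kernel in numpy; the function name gaussian_kernel and the default bandwidth are my own choices, and later sketches reuse this helper:

```python
import numpy as np

def gaussian_kernel(X, Y, sigma=1.0):
    """Pairwise k(x, y) = exp(-||x - y||^2 / (2 sigma^2)) between rows of X and Y."""
    # Squared distances via the expansion ||x - y||^2 = ||x||^2 + ||y||^2 - 2<x, y>.
    sq_dists = (np.sum(X**2, axis=1)[:, None]
                + np.sum(Y**2, axis=1)[None, :]
                - 2 * X @ Y.T)
    # Clip tiny negative values caused by floating-point roundoff.
    return np.exp(-np.maximum(sq_dists, 0) / (2 * sigma**2))
```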
Mean embeddings of distributions
The mean embedding of a distribution P in an RKHS ℋ:
    μ_P = E_{x∼P}[φ(x)]
Remember ⟨φ(x), φ(y)⟩ = k(x, y), so we can think of φ(x) as k(x, ·).
[Figure: left panel, the density Pr(X); right panel, its mean embedding E_{X∼P}[k(X, x)], a smoothed version of the density.]
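A minimal sketch of the empirical mean embedding μ̂_P(x) = (1/m) Σ_i k(X_i, x), reusing the gaussian_kernel helper above (the function name is my own):

```python
def empirical_mean_embedding(eval_points, X, sigma=1.0):
    """mu_hat_P evaluated at each row of eval_points: (1/m) sum_i k(X_i, x)."""
    return gaussian_kernel(eval_points, X, sigma).mean(axis=1)
```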
Maximum Mean Discrepancy (MMD)
The MMD is the distance between mean embeddings, μ_P = E_{X∼P}[φ(X)]:
    mmd²(P, Q) = ‖μ_P − μ_Q‖²_ℋ = ⟨μ_P, μ_P⟩ + ⟨μ_Q, μ_Q⟩ − 2⟨μ_P, μ_Q⟩
    ⟨μ_P, μ_Q⟩ = ⟨E_{X∼P}[φ(X)], E_{Y∼Q}[φ(Y)]⟩ = E_{X∼P, Y∼Q}[⟨φ(X), φ(Y)⟩] = E_{X∼P, Y∼Q}[k(X, Y)]
Equivalently, the MMD is the largest difference in means over the unit ball of ℋ:
    mmd(P, Q) = sup_{f∈ℋ, ‖f‖_ℋ≤1} E_{X∼P}[f(X)] − E_{Y∼Q}[f(Y)]
MMD estimator
    mmd²(P, Q) = ‖μ_P − μ_Q‖²_ℋ = ⟨μ_P, μ_P⟩ + ⟨μ_Q, μ_Q⟩ − 2⟨μ_P, μ_Q⟩
Each inner product is an expectation of the kernel, e.g. ⟨μ_P, μ_Q⟩ = E_{X∼P, Y∼Q}[k(X, Y)], so we can estimate it from samples:
    ⟨μ̂_P, μ̂_Q⟩ = (1/mn) Σ_{i,j} k(X_i, Y_j)
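Plugging the sample estimate of each inner product into the expansion above gives a simple plug-in estimator. A sketch of the biased (V-statistic) version, reusing gaussian_kernel from earlier; an unbiased version would drop the diagonal terms of K_XX and K_YY:

```python
def mmd2_biased(X, Y, sigma=1.0):
    """Biased plug-in estimate of mmd^2(P, Q) with a Gaussian kernel."""
    K_XX = gaussian_kernel(X, X, sigma)  # estimates <mu_P, mu_P>
    K_YY = gaussian_kernel(Y, Y, sigma)  # estimates <mu_Q, mu_Q>
    K_XY = gaussian_kernel(X, Y, sigma)  # estimates <mu_P, mu_Q>
    return K_XX.mean() + K_YY.mean() - 2 * K_XY.mean()
```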
Permutation testing
We still need a rejection threshold c_α with Pr_{X,Y∼P}(ρ̂(X, Y) > c_α) ≤ α.
When P = Q, the MMD's asymptotic null distribution depends on P, so it's hard to find a threshold that way.
Permutation test: repeatedly shuffle X and Y together and split at random to estimate the distribution of the MMD when P = Q:
    mmd̂²(X^(1), Y^(1)), mmd̂²(X^(2), Y^(2)), …
    ĉ_α: the (1 − α)th quantile of the mmd̂²(X^(i), Y^(i))
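A direct (unoptimized) sketch of this procedure, using the mmd2_biased sketch above; n_perms and alpha are illustrative defaults:

```python
def permutation_test(X, Y, sigma=1.0, n_perms=500, alpha=0.05, seed=0):
    """Reject P = Q if the observed mmd^2 exceeds the (1 - alpha) quantile
    of mmd^2 values computed on random re-splits of the pooled data."""
    rng = np.random.default_rng(seed)
    Z = np.concatenate([X, Y])
    m = len(X)
    observed = mmd2_biased(X, Y, sigma)
    null_stats = np.empty(n_perms)
    for i in range(n_perms):
        perm = rng.permutation(len(Z))
        null_stats[i] = mmd2_biased(Z[perm[:m]], Z[perm[m:]], sigma)
    c_alpha = np.quantile(null_stats, 1 - alpha)
    return observed > c_alpha, observed, c_alpha
```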
Computing permutation tests
Let K be the kernel matrix on the pooled sample (X_1, …, X_m, Y_1, …, Y_m), with block structure
    K = [ K_XX  K_XY ]
        [ K_YX  K_YY ]
where (K_XY)_{ij} = k(X_i, Y_j), and similarly for the other blocks.
Each element of K is added or subtracted in a term of every permutation estimate. So, compute them all in one pass over K.
[Figure: time vs. number of threads for our one-pass permutation code and for the MKL spectral approach. Original Matlab code: 381s; better Python code: 182s.]
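The naive sketch above recomputes kernel values for every permutation. A sketch of the caching idea: evaluate K on the pooled sample once, and let each permutation merely re-index it. (The talk's actual one-pass code fuses all the permutation estimates into a single sweep over K, which this sketch does not reproduce.)

```python
def permutation_test_cached(X, Y, sigma=1.0, n_perms=500, alpha=0.05, seed=0):
    """Permutation test that evaluates the kernel matrix only once."""
    rng = np.random.default_rng(seed)
    Z = np.concatenate([X, Y])
    m, n = len(X), len(X) + len(Y)
    K = gaussian_kernel(Z, Z, sigma)  # all kernel evaluations, done once

    def mmd2_from_indices(idx_x, idx_y):
        # Biased mmd^2 estimate built by re-indexing the cached K.
        return (K[np.ix_(idx_x, idx_x)].mean()
                + K[np.ix_(idx_y, idx_y)].mean()
                - 2 * K[np.ix_(idx_x, idx_y)].mean())

    observed = mmd2_from_indices(np.arange(m), np.arange(m, n))
    null_stats = np.array([
        mmd2_from_indices(p[:m], p[m:])
        for p in (rng.permutation(n) for _ in range(n_perms))
    ])
    return observed > np.quantile(null_stats, 1 - alpha), observed
```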
Example two-sample test
X ∼ P = N(0, 1),  Y ∼ Q = Laplace(0, 1/√2)  (same mean and variance).
1. Choose a kernel k.
2. Estimate the MMD for the true division and for many permutations.
3. Reject if mmd̂²_k(X, Y) > ĉ_α.
[Figure: histograms of the Gaussian and Laplace samples.]
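For instance, running the permutation_test sketch above on this example (the sample sizes and bandwidth are illustrative choices; since both distributions have mean 0 and variance 1, a t-test would have no power here):

```python
rng = np.random.default_rng(0)
X = rng.normal(0.0, 1.0, size=(500, 1))                # P = N(0, 1)
Y = rng.laplace(0.0, 1.0 / np.sqrt(2), size=(500, 1))  # Q = Laplace(0, 1/sqrt(2))
reject, stat, c_alpha = permutation_test(X, Y, sigma=0.5)
print(f"mmd^2 = {stat:.4f}, threshold c_alpha = {c_alpha:.4f}, reject: {reject}")
```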
The kernel matters!
The witness function f helps compare samples:
    mmd(P, Q) = E_{X∼P}[f(X)] − E_{Y∼Q}[f(Y)]
    f(x) = μ_P(x) − μ_Q(x) = E_{X∼P}[k(x, X)] − E_{Y∼Q}[k(x, Y)]
[Figure: witness functions for Gaussian kernels of different bandwidths on the Gaussian-vs-Laplace example: σ = 0.75 (p = 0.0), σ = 1 (p = 0.3), σ = 0.1 (p = 0.16). Only a well-chosen bandwidth gives a small p-value.]
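A sketch of the (unnormalized) empirical witness function from these formulas, built on the earlier helpers; evaluating it over a grid of points reproduces the kind of curves shown in the figure:

```python
def witness(x_grid, X, Y, sigma=1.0):
    """f(x) = mu_hat_P(x) - mu_hat_Q(x): positive where the sample X has
    more mass than Y, negative where Y has more."""
    return (empirical_mean_embedding(x_grid, X, sigma)
            - empirical_mean_embedding(x_grid, Y, sigma))
```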
Choosing a kernel
So we need a way to pick a kernel to do the test.
Split the data: choose a kernel k on one part, then run the MMD test with the chosen k on the rest.
How to pick k? Typically: maximize mmd̂. But we actually want the (asymptotically) most powerful test. A sketch of the typical recipe follows.
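A sketch of that "typical" recipe: hold out part of the data, pick the bandwidth that maximizes the estimated MMD there, and test on the rest. The candidate grid is an illustrative assumption, and maximizing mmd̂ itself is not the power-optimal criterion, which is exactly the point being made here.

```python
def pick_sigma_then_test(X, Y, candidates=(0.1, 0.5, 1.0, 2.0), seed=0):
    """Select a bandwidth on half the data, then run the test on the other half."""
    rng = np.random.default_rng(seed)
    # Shuffle, then split each sample into a selection half and a test half.
    X_sel, X_test = np.array_split(rng.permutation(X), 2)
    Y_sel, Y_test = np.array_split(rng.permutation(Y), 2)
    best_sigma = max(candidates, key=lambda s: mmd2_biased(X_sel, Y_sel, s))
    return best_sigma, permutation_test(X_test, Y_test, sigma=best_sigma)
```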