
Maximum Likelihood with Bias-Corrected Calibration is Hard-To-Beat

  1. Maximum Likelihood with Bias-Corrected Calibration is Hard-To-Beat at Label Shift Adaptation
     Amr M. Alexandari*, Anshul Kundaje†, Avanti Shrikumar*† (*co-first authors, †co-corresponding authors)
     Amr Alexandari, PhD Student, Dept. of Computer Science
     Anshul Kundaje, Assistant Professor, Depts. of CS & Genetics

  2. Label Shift Illustrated (figure: train model)

  3. Label Shift Illustrated (figure: original model under-predicts)

  4. Label Shift Illustrated (figure: update)

  5. Label Shift Illustrated: We don’t have ground-truth labels for the new patients! How do we update our classifier?

  6. Main Contributions
     - An approach that achieves state-of-the-art on label shift adaptation
         - Scales to datasets with high-dimensional inputs
         - Does not require model retraining
     - Combines Max Likelihood with specific types of calibration
         - Calibration with Temp. Scaling (TS) was insufficient (& sometimes harmful!)
         - Achieved state-of-the-art with extensions of TS (one of which we propose) that correct for systematic bias
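A bias-correcting extension of TS can be sketched as follows. This is a minimal sketch assuming that such an extension divides the logits by a single learned temperature 𝑇 and then adds a learned per-class bias vector 𝒃 before the softmax, with both fitted by minimizing negative log-likelihood on held-out labeled source-domain data; plain TS is the special case 𝒃 = 0. The function names (`softmax`, `nll`, `fit_bias_corrected_ts`), the plain gradient-descent optimizer, the learning rate and the step count are illustrative assumptions, not the authors' implementation.

```python
import numpy as np

def softmax(s):
    s = s - s.max(axis=1, keepdims=True)  # subtract the row max for numerical stability
    e = np.exp(s)
    return e / e.sum(axis=1, keepdims=True)

def nll(logits, labels, inv_temp, bias):
    """Mean negative log-likelihood of labels under softmax(logits / T + b), with inv_temp = 1/T."""
    probs = softmax(inv_temp * logits + bias)
    return -np.mean(np.log(probs[np.arange(len(labels)), labels]))

def fit_bias_corrected_ts(logits, labels, lr=0.05, steps=2000):
    """Hypothetical helper: fit 1/T and a per-class bias b by gradient descent on the mean NLL."""
    n, k = logits.shape
    onehot = np.eye(k)[labels]
    inv_temp, bias = 1.0, np.zeros(k)  # start from the uncalibrated model (T = 1, b = 0)
    for _ in range(steps):
        # Gradient of the mean NLL w.r.t. the rescaled logits is (probs - onehot) / n.
        resid = (softmax(inv_temp * logits + bias) - onehot) / n
        inv_temp -= lr * np.sum(resid * logits)  # chain rule through inv_temp * logits
        bias -= lr * resid.sum(axis=0)           # chain rule through + bias
    return inv_temp, bias
```

A single temperature can only sharpen or soften every prediction uniformly, whereas the per-class bias can also shift probability mass toward or away from particular classes, which is the kind of systematic bias the slide refers to.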

  7. Formal Definition of Label Shift
     Let:
     - 𝑧 denote our labels (whether or not a person has the disease)
     - 𝒚 denote the observed symptoms
     - 𝑞(𝒚, 𝑧) denote the joint distribution of (𝒚, 𝑧) at the beginning of the outbreak (“source domain”)
     - 𝑟(𝒚, 𝑧) denote the joint distribution at the widespread stage (“target domain”), when we don’t know the labels
     - Goal: adapt a source-domain classifier that predicts 𝑞(𝑧|𝒚) to instead predict 𝑟(𝑧|𝒚) for the target domain
     Core assumption: the disease has the same symptoms irrespective of outbreak stage, i.e. 𝑞(𝒚|𝑧) = 𝑟(𝒚|𝑧).
     - Thus, the difference between the source and target domains is caused exclusively by a shift in the label proportions 𝑞(𝑧) and 𝑟(𝑧). Formally, 𝑟(𝒚, 𝑧) = 𝑞(𝒚|𝑧) 𝑟(𝑧).
     - Also called prior probability shift (Storkey, 2008); corresponds to “anti-causal learning”, i.e. predicting the cause 𝑧 from its effects 𝒚 (Schölkopf et al., 2012).
     - Anti-causal learning is appropriate here because disease status 𝑧 causes the symptoms 𝒚.
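To make the label shift assumption concrete, the following sketch simulates it with a hypothetical one-dimensional "symptom score" 𝒚: each class 𝑧 ∈ {0, 1} has a fixed Gaussian class-conditional, so 𝑞(𝒚|𝑧) = 𝑟(𝒚|𝑧) by construction, and only the class proportions change between the source and target samples. The function `sample_domain`, the Gaussian means and the two priors are illustrative assumptions, not values from the slides.

```python
import numpy as np

# Hypothetical class-conditional distributions shared by both domains:
# y | z ~ Normal(MEANS[z], 1), so q(y|z) = r(y|z) by construction.
MEANS = np.array([0.0, 2.0])

def sample_domain(prior, n, rng):
    """Draw n (y, z) pairs: z from this domain's label prior, then y from the shared y|z."""
    z = rng.choice(len(prior), size=n, p=prior)
    y = rng.normal(loc=MEANS[z], scale=1.0)
    return y, z

rng = np.random.default_rng(0)
# Source domain (beginning of outbreak): the disease (z = 1) is rare.
y_src, z_src = sample_domain(np.array([0.9, 0.1]), 100_000, rng)
# Target domain (widespread stage): the disease is common.
y_tgt, z_tgt = sample_domain(np.array([0.5, 0.5]), 100_000, rng)

# The label proportions q(z) and r(z) differ between the domains...
print(f"q(z=1) ~ {z_src.mean():.3f}, r(z=1) ~ {z_tgt.mean():.3f}")
# ...but the symptoms of sick patients look the same in both.
print(f"mean y | z=1: source {y_src[z_src == 1].mean():.3f}, target {y_tgt[z_tgt == 1].mean():.3f}")
```

The marginal distribution of 𝒚 still differs between the two samples, because it mixes the same class-conditionals with different weights.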

  8. Estimating 𝑟(𝑧|𝒚) with Bayes’ Rule
     - Although 𝑞(𝒚|𝑧) is preserved, computing it is hard when 𝒚 is high-dimensional.
     - It is much easier to estimate 𝑞(𝑧|𝒚) and 𝑞(𝑧) from the source domain, as 𝑧 is lower-dimensional.
     - If we know 𝑟(𝑧), we can retrieve 𝑟(𝑧|𝒚) without ever estimating 𝑞(𝒚|𝑧), using Bayes’ rule (first shown in Saerens et al., 2002).
     We first write (the likelihoods 𝑟(𝒚|𝑧) are not explicitly known):
     $$ r(z \mid \mathbf{y}) = \frac{r(z, \mathbf{y})}{r(\mathbf{y})} = \frac{r(\mathbf{y} \mid z)\, r(z)}{\sum_{z^*} r(\mathbf{y} \mid z^*)\, r(z^*)} $$
     Substituting 𝑟(𝒚|𝑧) = 𝑞(𝒚|𝑧) (the label shift assumption), we have
     $$ r(z \mid \mathbf{y}) = \frac{q(\mathbf{y} \mid z)\, r(z)}{\sum_{z^*} q(\mathbf{y} \mid z^*)\, r(z^*)} $$
     Through Bayes’ rule, observe that
     $$ q(\mathbf{y} \mid z) = \frac{q(z \mid \mathbf{y})\, q(\mathbf{y})}{q(z)} $$
     Substituting, we get
     $$ r(z \mid \mathbf{y}) = \frac{q(z \mid \mathbf{y})\, q(\mathbf{y})\, \frac{r(z)}{q(z)}}{\sum_{z^*} q(z^* \mid \mathbf{y})\, q(\mathbf{y})\, \frac{r(z^*)}{q(z^*)}} $$
     𝑞(𝒚) cancels out, giving
     $$ r(z \mid \mathbf{y}) = \frac{q(z \mid \mathbf{y})\, \frac{r(z)}{q(z)}}{\sum_{z^*} q(z^* \mid \mathbf{y})\, \frac{r(z^*)}{q(z^*)}} $$
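The Bayes' rule correction derived on this slide can be sketched in Python as a vectorized function over a batch of predictions. This is a minimal sketch assuming the classifier's outputs 𝑞(𝑧|𝒚) are stored as a NumPy array with one row per example and one column per class; the name `adapt_posteriors` is hypothetical.

```python
import numpy as np

def adapt_posteriors(source_probs, source_prior, target_prior):
    """Reweight source-domain posteriors q(z|y) into target posteriors r(z|y).

    source_probs: shape (n_examples, n_classes); each row is q(z|y).
    source_prior: shape (n_classes,); q(z).
    target_prior: shape (n_classes,); r(z).
    """
    # Numerator of the final formula: q(z|y) * r(z) / q(z) for every class z.
    weighted = source_probs * (target_prior / source_prior)
    # Denominator: the same quantity summed over all classes z*.
    return weighted / weighted.sum(axis=1, keepdims=True)

probs = np.array([[0.5, 0.5]])  # columns: z = 0, z = 1
adapted = adapt_posteriors(probs, np.array([0.9, 0.1]), np.array([0.5, 0.5]))
print(adapted)  # approximately [[0.1, 0.9]]
```

Only the low-dimensional quantities 𝑞(𝑧|𝒚), 𝑞(𝑧) and 𝑟(𝑧) appear in the function, which is why no density over the high-dimensional 𝒚 is ever needed.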

  9. Reminders:
     - 𝒚 denotes features (e.g. symptoms)
     - 𝑧 denotes labels (e.g. disease status)
     - 𝑞 indicates the source domain (labels known)
     - 𝑟 indicates the target domain (labels unknown)
     - Label shift assumes 𝑟(𝒚|𝑧) = 𝑞(𝒚|𝑧)
     - If we estimate 𝑞(𝑧|𝒚) and 𝑞(𝑧) from source data and are told 𝑟(𝑧), we can find 𝑟(𝑧|𝒚) using Bayes’ rule
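As a hypothetical numerical illustration of the last point, suppose a binary classifier trained when the disease was rare, 𝑞(𝑧=1) = 0.1, outputs 𝑞(𝑧=1|𝒚) = 0.5 for some patient, and we are told that in the target domain 𝑟(𝑧=1) = 0.5. The numbers are made up for illustration only. Applying Bayes' rule under the label shift assumption:

```latex
r(z{=}1 \mid \mathbf{y})
  = \frac{q(z{=}1 \mid \mathbf{y}) \, \frac{r(z{=}1)}{q(z{=}1)}}
         {q(z{=}0 \mid \mathbf{y}) \, \frac{r(z{=}0)}{q(z{=}0)} + q(z{=}1 \mid \mathbf{y}) \, \frac{r(z{=}1)}{q(z{=}1)}}
  = \frac{0.5 \cdot \frac{0.5}{0.1}}{0.5 \cdot \frac{0.5}{0.9} + 0.5 \cdot \frac{0.5}{0.1}}
  = \frac{2.5}{\frac{5}{18} + 2.5}
  = 0.9
```

So once the disease becomes common, the source model's 0.5 under-predicts the probability of disease for this patient.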

  10. A Simple Iterative Approach to Label Shift…
     In practice, we are not told 𝑟(𝑧). How can we estimate it?
     - We could use 𝑞(𝑧|𝒚) to predict on the test set and average the predictions to estimate 𝑟(𝑧).
     - We could then use 𝑟(𝑧) to update 𝑞(𝑧|𝒚), and repeat the process until convergence!
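The iterative procedure on this slide is the expectation-maximization (EM) algorithm of Saerens et al. (2002). Below is a minimal sketch, assuming the source classifier's (ideally calibrated) predictions 𝑞(𝑧|𝒚) on the unlabeled target examples and the source label proportions 𝑞(𝑧) are available as NumPy arrays; the names `reweight` and `em_label_shift`, the tolerance and the iteration cap are hypothetical choices.

```python
import numpy as np

def reweight(source_probs, source_prior, target_prior):
    """Bayes' rule correction: q(z|y) * r(z) / q(z), renormalized over classes."""
    weighted = source_probs * (target_prior / source_prior)
    return weighted / weighted.sum(axis=1, keepdims=True)

def em_label_shift(source_probs, source_prior, max_iter=1000, tol=1e-8):
    """Estimate r(z) and r(z|y) from unlabeled target data by EM.

    source_probs: (n_examples, n_classes) predictions q(z|y) on the target examples.
    source_prior: (n_classes,) q(z), e.g. the label proportions of the source training data.
    """
    target_prior = source_prior.copy()  # so the first pass just averages the raw q(z|y)
    for _ in range(max_iter):
        # E-step: adapt the predictions using the current estimate of r(z).
        target_probs = reweight(source_probs, source_prior, target_prior)
        # M-step: the new estimate of r(z) is the average adapted prediction.
        new_prior = target_probs.mean(axis=0)
        done = np.max(np.abs(new_prior - target_prior)) < tol
        target_prior = new_prior
        if done:
            break
    return target_prior, reweight(source_probs, source_prior, target_prior)
```

As with any EM procedure, no iteration decreases the likelihood of the unlabeled target data viewed as a function of 𝑟(𝑧), which connects this loop to the maximum-likelihood approach in the title. Because the E-step trusts 𝑞(𝑧|𝒚) as a probability, the estimate is only as good as the classifier's calibration.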

