def test_flag_with_cohort_summary(self):
     """Smoke-test ``measure.summary`` with ``flag_oor=True`` and one cohort.

     The first entry of ``self.cohorts`` supplies the cohort labels. Nothing
     is asserted about the result; the test passes as long as the call does
     not raise.
     """
     frame = self.df
     summary_kwargs = dict(
         y_true=frame["classification"],
         y_pred=frame["avg_classification"],
         prtc_attr=frame["prtc_attr"],
         cohort_labels=self.cohorts[0],
         pred_type="classification",
         flag_oor=True,
     )
     measure.summary(frame, **summary_kwargs)
 def test_multi_cohort_cols(self):
     """Smoke-test ``measure.summary`` with all of ``self.cohorts`` as labels.

     Unlike the single-cohort tests, the whole ``self.cohorts`` object is
     passed (presumably several cohort columns -- confirm in the fixture), and
     ``pred_type`` is left at ``summary``'s default. Passes if no exception is
     raised; the return value is discarded.
     """
     frame = self.df
     truth = frame["classification"]
     predicted = frame["avg_classification"]
     protected = frame["prtc_attr"]
     measure.summary(
         frame,
         y_true=truth,
         y_pred=predicted,
         prtc_attr=protected,
         cohort_labels=self.cohorts,
     )
 def test_no_cohort(self):
     """Smoke-test ``measure.summary`` without any ``cohort_labels`` argument.

     Only ``X``, ``y_true``, ``y_pred``, ``prtc_attr`` and ``pred_type`` are
     supplied. Passes if the call does not raise; the result is discarded.
     """
     frame = self.df
     call_args = {
         "X": frame,
         "y_true": frame["classification"],
         "y_pred": frame["avg_classification"],
         "prtc_attr": frame["prtc_attr"],
         "pred_type": "classification",
     }
     measure.summary(**call_args)
 def test_cohort_summary(self):
     """Smoke-test ``measure.summary`` grouped by a single cohort.

     ``self.cohorts[0]`` supplies the cohort labels and flagging is left at
     the function's default. Passes if the call does not raise; the result
     is discarded.
     """
     frame = self.df
     first_cohort = self.cohorts[0]
     measure.summary(
         X=frame,
         y_true=frame["classification"],
         y_pred=frame["avg_classification"],
         prtc_attr=frame["prtc_attr"],
         cohort_labels=first_cohort,
         pred_type="classification",
     )
 def test_summary_default_flags_regression(self):
     """Smoke-test ``measure.summary`` on regression data with ``flag_oor=True``.

     No other flag-related arguments are passed, so whatever default flag
     settings ``summary`` uses apply (presumably what "default_flags" in the
     test name refers to). Passes if the call does not raise.
     """
     frame = self.df
     summary_kwargs = dict(
         X=frame,
         y_true=frame["regression"],
         y_pred=frame["avg_regression"],
         prtc_attr=frame["prtc_attr"],
         pred_type="regression",
         flag_oor=True,
     )
     measure.summary(**summary_kwargs)
 def test_summary_default_flags_classification(self):
     """Smoke-test ``measure.summary`` on classification data with ``flag_oor=True``.

     Mirrors the regression variant using the ``classification`` and
     ``avg_classification`` columns. No other flag-related arguments are
     passed. Passes if the call does not raise; the result is discarded.
     """
     frame = self.df
     truth = frame["classification"]
     predicted = frame["avg_classification"]
     measure.summary(
         frame,
         y_true=truth,
         y_pred=predicted,
         prtc_attr=frame["prtc_attr"],
         pred_type="classification",
         flag_oor=True,
     )
 def test_toomany_cohorts(self):
     """Expect ``measure.summary`` to reject an excessive number of cohorts.

     ``reset_index`` moves ``self.df``'s index into an ``"index"`` column.
     Column ``"A"`` is only a carrier; just the ``"index"`` column is used.
     That column becomes the cohort labels -- presumably one distinct label
     per row, which should exceed the number of cohorts ``summary`` accepts.
     The call must raise ``valid.ValidationError``.
     """
     reindexed = self.df["A"].reset_index()
     with pytest.raises(valid.ValidationError):
         frame = self.df
         measure.summary(
             frame,
             y_true=frame["classification"],
             y_pred=frame["avg_classification"],
             prtc_attr=frame["prtc_attr"],
             cohort_labels=reindexed["index"],
         )