def test_classification_metrics_avg() -> None:
    """
    Test averaging of classification metrics in a MetricsDict, both per hue and across hues.
    """
    hue1 = "H1"
    hue2 = "H2"
    m = MetricsDict(hues=[hue1, hue2], is_classification_metrics=True)
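    # "foo" is added without an explicit hue and hence lands under MetricsDict.DEFAULT_HUE_KEY;
    # the prediction entries below are added separately for each of the two hues.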
    m.add_metric("foo", 1.0)
    m.add_metric("foo", 2.0)
    # Perfect predictions for hue1, should give AUC == 1.0
    m.add_predictions(["S1", "S2"], np.array([0.0, 1.0]), np.array([0.0, 1.0]), hue=hue1)
    expected_hue1_auc = 1.0
    # Worst possible predictions for hue2, should give AUC == 0.0
    m.add_predictions(["S1", "S2"], np.array([1.0, 0.0]), np.array([0.0, 1.0]), hue=hue2)
    expected_hue2_auc = 0.0
    averaged = m.average(across_hues=False)
    g1_averaged = averaged.values(hue=hue1)
    assert MetricType.AREA_UNDER_ROC_CURVE.value in g1_averaged
    assert g1_averaged[MetricType.AREA_UNDER_ROC_CURVE.value] == [expected_hue1_auc]
    assert MetricType.AREA_UNDER_PR_CURVE.value in g1_averaged
    assert MetricType.SUBJECT_COUNT.value in g1_averaged
    assert g1_averaged[MetricType.SUBJECT_COUNT.value] == [2.0]
    default_averaged = averaged.values()
    assert default_averaged == {"foo": [1.5]}
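    # enumerate_single_values yields (hue, metric_name, value) triples, listing the explicit hues
    # before the default hue. The exact number of per-hue metrics is implementation-dependent,
    # hence the >= 8 lower bound below rather than an exact count.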
    can_enumerate = list(averaged.enumerate_single_values())
    assert len(can_enumerate) >= 8
    assert can_enumerate[0] == (hue1, MetricType.AREA_UNDER_ROC_CURVE.value, 1.0)
    assert can_enumerate[-1] == (MetricsDict.DEFAULT_HUE_KEY, "foo", 1.5)

    g2_averaged = averaged.values(hue=hue2)
    assert MetricType.AREA_UNDER_ROC_CURVE.value in g2_averaged
    assert g2_averaged[MetricType.AREA_UNDER_ROC_CURVE.value] == [expected_hue2_auc]

    averaged_across_hues = m.average(across_hues=True)
    assert averaged_across_hues.get_hue_names() == [MetricsDict.DEFAULT_HUE_KEY]
    assert MetricType.AREA_UNDER_ROC_CURVE.value in averaged_across_hues.values()
    expected_averaged_auc = 0.5 * (expected_hue1_auc + expected_hue2_auc)
    assert averaged_across_hues.values()[MetricType.AREA_UNDER_ROC_CURVE.value] == [expected_averaged_auc]


def test_metrics_dict_roc() -> None:
    """
    Test if adding ROC entries to a MetricsDict instance works, and returns the correct AUC.
    """
    # Prepare a vector of predictions and labels. We can compute AUC off those to compare.
    # MetricsDict will get that supplied in 3 chunks, and should return the same AUC value.
    predictions = np.array([0.5, 0.6, 0.1, 0.8, 0.2, 0.9])
    labels = np.array([0, 1.0, 0, 0, 1, 1], dtype=float)
    split_length = [3, 2, 1]
    assert sum(split_length) == len(predictions)
    summed = np.cumsum(split_length)
    m = MetricsDict()
    for i, end in enumerate(summed):
        start = 0 if i == 0 else summed[i - 1]
        pred = predictions[start:end]
        label = labels[start:end]
        subject_ids = list(range(len(pred)))
        m.add_predictions(subject_ids, pred, label)
    assert m.has_prediction_entries
    actual_auc = m.get_roc_auc()
    expected_auc = roc_auc_score(labels, predictions)
    assert actual_auc == pytest.approx(expected_auc, 1e-6)
    actual_pr_auc = m.get_pr_auc()
    expected_pr_auc = 0.7111111
    assert actual_pr_auc == pytest.approx(expected_pr_auc, 1e-6)
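    # Sanity check for the hard-coded value: trapezoidal integration of sklearn's precision/recall
    # curve gives the same number. This assumes get_pr_auc integrates the PR curve in the same way.
    from sklearn.metrics import auc, precision_recall_curve
    precision, recall, _ = precision_recall_curve(labels, predictions)
    assert auc(recall, precision) == pytest.approx(expected_pr_auc, 1e-6)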


def test_metrics_dict_roc_degenerate() -> None:
    """
    Test if adding ROC entries to a MetricsDict instance works when only a single class is present in the labels.
    """
    # Prepare a vector of predictions where all labels belong to the positive class.
    # Both ROC AUC and PR AUC should degenerate to 1.0 in this case.
    predictions = np.array([0.5, 0.6, 0.1, 0.8, 0.2, 0.9])
    m = MetricsDict()
    subject_ids = list(range(len(predictions)))
    m.add_predictions(subject_ids, predictions, np.ones_like(predictions))
    assert m.has_prediction_entries
    assert m.get_roc_auc() == 1.0
    assert m.get_pr_auc() == 1.0
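    # For comparison: sklearn cannot compute an ROC AUC when only one class is present in the labels,
    # so the values above rely on MetricsDict's own handling of this degenerate case.
    with pytest.raises(ValueError):
        roc_auc_score(np.ones_like(predictions), predictions)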


def test_metrics_dict_average_additional_metrics() -> None:
    """
    Test if computing the ROC entries and metrics at optimal threshold with MetricsDict.average() works
    as expected and returns the correct values.
    """
    # Prepare a vector of predictions and labels.
    predictions = np.array([0.5, 0.6, 0.1, 0.8, 0.2, 0.9])
    labels = np.array([0, 1.0, 0, 0, 1, 1], dtype=float)
    split_length = [3, 2, 1]

    # Get MetricsDict
    assert sum(split_length) == len(predictions)
    summed = np.cumsum(split_length)
    # MetricsDict will get that supplied in 3 chunks.
    m = MetricsDict()
    for i, end in enumerate(summed):
        start = 0 if i == 0 else summed[i - 1]
        pred = predictions[start:end]
        label = labels[start:end]
        subject_ids = list(range(len(pred)))
        m.add_predictions(subject_ids, pred, label)
    assert m.has_prediction_entries

    # Compute average MetricsDict
    averaged = m.average()

    # Compute additional expected metrics for the averaged MetricsDict
    expected_auc = roc_auc_score(labels, predictions)
    expected_fpr, expected_tpr, thresholds = roc_curve(labels, predictions)
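    # The optimal threshold is the ROC point that maximizes Youden's J statistic (TPR - FPR).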
    expected_optimal_idx = np.argmax(expected_tpr - expected_fpr)
    expected_optimal_threshold = float(thresholds[expected_optimal_idx])
    expected_accuracy = np.mean((predictions > expected_optimal_threshold) == labels)

    # Check computed values against expected
    assert averaged.values()[MetricType.OPTIMAL_THRESHOLD.value][0] == pytest.approx(expected_optimal_threshold)
    assert averaged.values()[MetricType.ACCURACY_AT_OPTIMAL_THRESHOLD.value][0] == pytest.approx(expected_accuracy)
    assert averaged.values()[MetricType.FALSE_POSITIVE_RATE_AT_OPTIMAL_THRESHOLD.value][0] == \
           pytest.approx(expected_fpr[expected_optimal_idx])
    assert averaged.values()[MetricType.FALSE_NEGATIVE_RATE_AT_OPTIMAL_THRESHOLD.value][0] == \
           pytest.approx(1 - expected_tpr[expected_optimal_idx])
    assert averaged.values()[MetricType.AREA_UNDER_ROC_CURVE.value][0] == pytest.approx(expected_auc, 1e-6)