def test_single_bound(sample_metric_frame):
    """Check that a confidence-interval column with a single bound is rejected.

    The fixture's "Recall Bounds Single" column presumably holds one value per
    group rather than a (lower, upper) pair (TODO confirm against the
    fixture). plot_metric_frame must raise ValueError whose message is
    exactly _CONF_INTERVALS_MUST_BE_ARRAY.
    """
    call_kwargs = {
        "metrics": ["Recall"],
        "conf_intervals": ["Recall Bounds Single"],
    }
    with pytest.raises(ValueError) as raised:
        plot_metric_frame(sample_metric_frame, **call_kwargs)
    message = str(raised.value)
    assert message == _CONF_INTERVALS_MUST_BE_ARRAY
def test_invalid_metric_frame():
    """Tests handling of an invalid metric frame.

    A plain string is passed where a MetricFrame is expected, and
    plot_metric_frame must raise ValueError whose message is exactly
    _METRIC_FRAME_INVALID_ERROR.

    The confidence-interval argument is deliberately a *valid* name
    ("Recall Bounds"). That way the metric frame is the only invalid input,
    and the test does not depend on the order in which plot_metric_frame
    validates its arguments.
    """
    with pytest.raises(ValueError) as exc:
        plot_metric_frame(
            "not_metric_frame",
            metrics=["Recall"],
            conf_intervals=["Recall Bounds"],
        )
    assert str(exc.value) == _METRIC_FRAME_INVALID_ERROR
def test_flipped_bounds(sample_metric_frame):
    """Check that reversed confidence-interval bounds are rejected.

    Bounds are "flipped" when upper_bound is lower than lower_bound. The
    fixture's "Recall Bounds Flipped" column must make plot_metric_frame
    raise ValueError whose message is exactly
    _CONF_INTERVALS_FLIPPED_BOUNDS_ERROR.
    """
    call_kwargs = {
        "metrics": ["Recall"],
        "conf_intervals": ["Recall Bounds Flipped"],
    }
    with pytest.raises(ValueError) as raised:
        plot_metric_frame(sample_metric_frame, **call_kwargs)
    message = str(raised.value)
    assert message == _CONF_INTERVALS_FLIPPED_BOUNDS_ERROR
def test_single_ax_input(sample_metric_frame):
    """Tests plotting function works with a single user-supplied axis.

    This is a smoke test: it only checks that plotting onto a caller-provided
    Axes completes and returns something. The pyplot figure backing the axis
    is always closed afterwards so it does not stay open for later tests.
    """
    ax = matplotlib.pyplot.subplot()
    try:
        # Keep the input axis and the return value in separate names instead
        # of rebinding ``ax``, so the figure handle stays available for cleanup.
        result = plot_metric_frame(
            sample_metric_frame,
            metrics=["Recall"],
            conf_intervals=["Recall Bounds"],
            ax=ax,
            kind="bar",
            colormap="Pastel1",
        )
        assert result is not None
    finally:
        matplotlib.pyplot.close(ax.figure)
def test_multi_ax_input(sample_metric_frame):
    """Tests plotting function works with an array of user-supplied axes.

    The first axis is pre-configured (title, labels, y-limits) before two
    metrics with confidence intervals are plotted onto the 1x2 grid. This is
    a smoke test; the figure is always closed afterwards so it does not stay
    open for later tests.
    """
    fig, axes = matplotlib.pyplot.subplots(nrows=1, ncols=2)
    try:
        axes[0].set_title("Recall Plot")
        axes[0].set_ylabel("Recall")
        axes[0].set_xlabel("Race")
        axes[0].set_ylim((0, 1))
        # Do not rebind ``axes``; the result is checked separately.
        result = plot_metric_frame(
            sample_metric_frame,
            metrics=["Recall", "Accuracy"],
            conf_intervals=["Recall Bounds", "Accuracy Bounds"],
            ax=axes,
            kind="bar",
            colormap="Pastel1",
        )
        assert result is not None
    finally:
        matplotlib.pyplot.close(fig)
def test_plotting_output(sample_metric_frame):
    """Tests for the correct output shape and output type.

    With default arguments, plot_metric_frame returns an array of Axes
    (flattened here). The figure it creates is closed afterwards so it does
    not stay open for later tests.
    """
    axs = plot_metric_frame(sample_metric_frame).flatten()
    try:
        # 3 is the number of fixture metrics whose values aren't arrays; the
        # array-valued (bounds) columns are presumably not plotted by default.
        assert len(axs) == 3
        assert isinstance(axs[0], matplotlib.axes.Axes)
    finally:
        # close("all") rather than axs[0].figure: it stays safe even if the
        # assertions above failed because axs is empty or holds non-Axes.
        matplotlib.pyplot.close("all")
Exemplo n.º 7
0
    "Accuracy": accuracy_score,
    "Accuracy Bounds": accuracy_wilson,
}
metric_frame = MetricFrame(
    y_true=y_test_true,
    y_pred=y_test_pred,
    sensitive_features=test_set_sex,
    metrics=metrics_dict,  # scalar metrics plus their "... Bounds" callables
)

# %%
# Plotting
# ========
# Plot metrics without confidence intervals
# -----------------------------------------
plot_metric_frame(
    metric_frame,
    metrics=["Recall", "Accuracy"],
    kind="point",
)


# %%
# Plot metrics with confidence intervals (possibly asymmetric)
# ------------------------------------------------------------
plot_metric_frame(
    metric_frame,
    metrics=["Recall", "Accuracy"],
    conf_intervals=["Recall Bounds", "Accuracy Bounds"],
    kind="bar",
    subplots=False,
    plot_ci_labels=True,
)
plot_metric_frame(
    metric_frame,