def _categorical_metric_call_wrapper(
    metric: Callable,
    df: pd.DataFrame,
    feature: pd.Series,
    facet_values: Optional[List[Any]],
    positive_label_index: pd.Series,
    positive_predicted_label_index: pd.Series,
    group_variable: pd.Series,
) -> MetricResult:
    """
    Dispatch calling of different metric functions with the correct arguments
    Calculate CI from a list of values or 1 vs all
    """
    if facet_values:
        try:
            # Build index series from facet
            sensitive_facet_index = _categorical_data_idx(feature, facet_values)
            metric_description = common.metric_description(metric)
            metric_value = smclarify.bias.metrics.call_metric(
                metric,
                df=df,
                feature=feature,
                sensitive_facet_index=sensitive_facet_index,
                label=positive_label_index,
                positive_label_index=positive_label_index,
                predicted_label=positive_predicted_label_index,
                positive_predicted_label_index=positive_predicted_label_index,
                group_variable=group_variable,
            )
        except Exception as exc:
            logger.exception(f"{metric.__name__} metrics failed")
            return MetricError(metric.__name__, metric_description, error=exc)
    else:
        raise ValueError("Facet values must be provided to compute the bias metrics")
    return MetricResult(metric.__name__, metric_description, metric_value)
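

# A minimal illustrative sketch, not the module's implementation:
# _categorical_data_idx (called above) is assumed to return a boolean Series
# that is True for rows whose feature value belongs to the sensitive facet
# (one-vs-all membership). The function name below is hypothetical.
def _example_categorical_data_idx(col: pd.Series, facet_values: List[Any]) -> pd.Series:
    # Rows whose value is one of the given facet values form the sensitive facet
    return col.isin(facet_values)
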
def _continuous_metric_call_wrapper(
    metric: Callable,
    df: pd.DataFrame,
    feature: pd.Series,
    facet_threshold_index: pd.IntervalIndex,
    positive_label_index: pd.Series,
    positive_predicted_label_index: pd.Series,
    group_variable: pd.Series,
) -> MetricResult:
    """
    Dispatch calling of different metric functions with the correct arguments.
    The boolean sensitive facet index is derived from the continuous feature
    using the given facet threshold interval(s).
    """
    # Resolve the description up front so it is available on the error path below
    metric_description = common.metric_description(metric)
    try:
        # Boolean index series for values that fall inside the facet threshold interval(s)
        sensitive_facet_index = _continuous_data_idx(feature, facet_threshold_index)
        metric_value = smclarify.bias.metrics.call_metric(
            metric,
            df=df,
            feature=feature,
            sensitive_facet_index=sensitive_facet_index,
            label=positive_label_index,
            positive_label_index=positive_label_index,
            predicted_label=positive_predicted_label_index,
            positive_predicted_label_index=positive_predicted_label_index,
            group_variable=group_variable,
        )
    except Exception as exc:
        logger.exception(f"{metric.__name__} metrics failed")
        return MetricError(metric.__name__, metric_description, error=exc)
    return MetricResult(metric.__name__, metric_description, metric_value)
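

# A minimal illustrative sketch, not the module's implementation:
# _continuous_data_idx (called above) is assumed to return a boolean Series
# that is True for rows whose continuous feature value falls inside any of the
# facet threshold intervals. The function name below is hypothetical.
def _example_continuous_data_idx(col: pd.Series, intervals: pd.IntervalIndex) -> pd.Series:
    # Membership test against the interval(s) that define the sensitive facet
    return col.apply(lambda value: any(value in interval for interval in intervals))
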
def test_metric_descriptions():
    """
    Test that all callable metrics have descriptions present
    """
    pretraining_metrics = PRETRAINING_METRICS
    posttraining_metrics = POSTTRAINING_METRICS

    pretraining_metric_descriptions = {}
    for metric in pretraining_metrics:
        description = common.metric_description(metric)
        pretraining_metric_descriptions[metric.__name__] = description
    expected_result_1 = {
        "CDDL": "Conditional Demographic Disparity in Labels (CDDL)",
        "CI": "Class Imbalance (CI)",
        "DPL": "Difference in Positive Proportions in Labels (DPL)",
        "JS": "Jensen-Shannon Divergence (JS)",
        "KL": "Kullback-Liebler Divergence (KL)",
        "KS": "Kolmogorov-Smirnov Distance (KS)",
        "LP": "L-p Norm (LP)",
        "TVD": "Total Variation Distance (TVD)",
    }
    assert pretraining_metric_descriptions == expected_result_1

    # post training metrics
    posttraining_metric_descriptions = {}
    for metric in posttraining_metrics:
        description = common.metric_description(metric)
        posttraining_metric_descriptions[metric.__name__] = description
    expected_result_2 = {
        "AD": "Accuracy Difference (AD)",
        "CDDPL":
        "Conditional Demographic Disparity in Predicted Labels (CDDPL)",
        "DAR": "Difference in Acceptance Rates (DAR)",
        "DCA": "Difference in Conditional Acceptance (DCA)",
        "DCR": "Difference in Conditional Rejection (DCR)",
        "DI": "Disparate Impact (DI)",
        "DPPL":
        "Difference in Positive Proportions in Predicted Labels (DPPL)",
        "DRR": "Difference in Rejection Rates (DRR)",
        "FT": "Flip Test (FT)",
        "RD": "Recall Difference (RD)",
        "TE": "Treatment Equality (TE)",
    }
    assert posttraining_metric_descriptions == expected_result_2
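

# Hypothetical companion check (a sketch, not part of the original suite):
# every registered metric description is non-empty and embeds the metric's
# short name in parentheses, as the expected results above show.
def test_metric_descriptions_are_well_formed():
    for metric in list(PRETRAINING_METRICS) + list(POSTTRAINING_METRICS):
        description = common.metric_description(metric)
        assert description, f"{metric.__name__} has no description"
        assert f"({metric.__name__})" in description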