Example #1
0
 def test_filter_labels(self):
     """Check filter_labels on paired gold/pred arrays with -1 filtered out.

     Position 0 holds -1 in golds and position 4 holds -1 in preds; the test
     expects both positions to be removed from both returned arrays.
     """
     label_dict = {
         "golds": np.array([-1, 0, 0, 1, 1]),
         "preds": np.array([0, 0, 1, 1, -1]),
     }
     filter_dict = {"golds": [-1], "preds": [-1]}

     filtered = filter_labels(label_dict=label_dict, filter_dict=filter_dict)

     # Only positions 1-3 are expected to remain in each label set.
     expected = {
         "golds": np.array([0, 0, 1]),
         "preds": np.array([0, 1, 1]),
     }
     for name, want in expected.items():
         np.testing.assert_array_equal(filtered[name], want)
Example #2
0
def metric_score(
    golds: Optional[np.ndarray] = None,
    preds: Optional[np.ndarray] = None,
    probs: Optional[np.ndarray] = None,
    metric: str = "accuracy",
    filter_dict: Optional[Dict[str, List[int]]] = None,
    **kwargs: Any,
) -> float:
    """Compute a named metric from gold labels, predictions and/or probabilities.

    Parameters
    ----------
    golds
        An array of gold (int) labels
    preds
        An array of (int) predictions
    probs
        An [n_datapoints, n_classes] array of probabilistic (float) predictions
    metric
        The name of the metric to calculate; must be a key of ``METRICS``
    filter_dict
        A mapping from label set name ("golds", "preds" or "probs") to the labels
        that should be filtered out for that label set
    **kwargs
        Extra keyword arguments forwarded unchanged to the metric function

    Returns
    -------
    float
        The value of the requested metric

    Raises
    ------
    ValueError
        The requested metric is not currently supported
    ValueError
        ``filter_dict`` contains a key other than "golds", "preds" or "probs"
    ValueError
        A label set required by the requested metric was not provided
    ValueError
        The user attempted to calculate roc_auc score for a non-binary problem
        (presumably raised by the metric function itself; not visible here)
    """
    if metric not in METRICS:
        raise ValueError(
            f"The metric you provided ({metric}) is not currently implemented."
        )

    # Route provided golds/preds through to_int_label_array, which is expected
    # to complain about arrays with an invalid shape or type.
    if golds is not None:
        golds = to_int_label_array(golds)
    if preds is not None:
        preds = to_int_label_array(preds)

    available: Dict[str, Optional[np.ndarray]] = dict(
        golds=golds, preds=preds, probs=probs
    )

    # Optional filtering step (e.g., abstain predictions or unknown labels)
    if filter_dict:
        if not set(filter_dict.keys()).issubset(available.keys()):
            raise ValueError(
                "filter_dict must only include keys in ['golds', 'preds', 'probs']"
            )
        # The annotation on `available` no longer matches after this call
        # (values become non-Optional per filter_labels' signature).
        available = filter_labels(available, filter_dict)  # type: ignore

    # Gather the inputs the metric needs, failing on the first missing one
    metric_fn, required_names = METRICS[metric]
    inputs = []
    for name in required_names:
        values = available[name]
        if values is None:
            raise ValueError(f"Metric {metric} requires access to {name}.")
        inputs.append(values)

    return metric_fn(*inputs, **kwargs)