def test_filter_labels(self):
    """Rows where either filtered label set holds -1 are dropped from both arrays."""
    label_dict = {
        "golds": np.array([-1, 0, 0, 1, 1]),
        "preds": np.array([0, 0, 1, 1, -1]),
    }
    result = filter_labels(
        label_dict=label_dict,
        filter_dict={"golds": [-1], "preds": [-1]},
    )

    # Row 0 has gold == -1 and row 4 has pred == -1, so only rows 1-3
    # remain, and they remain in both label sets.
    expected = {
        "golds": np.array([0, 0, 1]),
        "preds": np.array([0, 1, 1]),
    }
    for name, want in expected.items():
        np.testing.assert_array_equal(result[name], want)
def metric_score(
    golds: Optional[np.ndarray] = None,
    preds: Optional[np.ndarray] = None,
    probs: Optional[np.ndarray] = None,
    metric: str = "accuracy",
    filter_dict: Optional[Dict[str, List[int]]] = None,
    **kwargs: Any,
) -> float:
    """Compute a named metric from gold labels, predictions and/or probabilities.

    The metric function and the label sets it consumes are looked up in
    ``METRICS``. Integer label arrays are passed through ``to_int_label_array``.
    If ``filter_dict`` is given, the label sets are passed through
    ``filter_labels`` before the metric is evaluated.

    Parameters
    ----------
    golds
        Gold (int) labels, or None if not available
    preds
        Predicted (int) labels, or None if not available
    probs
        An [n_datapoints, n_classes] array of probabilistic (float)
        predictions, or None if not available
    metric
        Name of the metric to compute; must be a key of ``METRICS``
    filter_dict
        Maps a label set name ("golds", "preds" or "probs") to the label
        values that should be filtered out for that label set
    **kwargs
        Forwarded unchanged to the metric function

    Returns
    -------
    float
        The value returned by the metric function

    Raises
    ------
    ValueError
        If ``metric`` is not a key of ``METRICS``
    ValueError
        If ``filter_dict`` has a key other than "golds", "preds" or "probs"
    ValueError
        If a label set required by the metric was not supplied
    ValueError
        The user attempted to calculate roc_auc score for a non-binary
        problem (raised by the underlying metric function)
    """
    if metric not in METRICS:
        raise ValueError(
            f"The metric you provided ({metric}) is not currently implemented."
        )

    # Converting here surfaces helpful errors for badly shaped/typed labels.
    if golds is not None:
        golds = to_int_label_array(golds)
    if preds is not None:
        preds = to_int_label_array(preds)

    label_dict: Dict[str, Optional[np.ndarray]] = dict(
        golds=golds, preds=preds, probs=probs
    )

    # Optionally drop examples, e.g. abstain predictions or unknown labels.
    if filter_dict:
        unknown_keys = set(filter_dict) - set(label_dict)
        if unknown_keys:
            raise ValueError(
                "filter_dict must only include keys in ['golds', 'preds', 'probs']"
            )
        # filter_labels narrows Optional[np.ndarray] values to np.ndarray,
        # which the declared type of label_dict does not express.
        label_dict = filter_labels(label_dict, filter_dict)  # type: ignore

    metric_func, required_names = METRICS[metric]

    # Fail on the first required label set that was not provided.
    for name in required_names:
        if label_dict[name] is None:
            raise ValueError(f"Metric {metric} requires access to {name}.")

    return metric_func(*(label_dict[name] for name in required_names), **kwargs)