def test_roc_curve(pred, target, expected_tpr, expected_fpr):
    """Check that ``roc`` reproduces the expected FPR/TPR values.

    Args:
        pred: prediction values, converted to a tensor before the call.
        target: target values, converted to a tensor before the call.
        expected_tpr: reference true-positive rates.
        expected_fpr: reference false-positive rates.
    """
    pred_tensor = torch.tensor(pred)
    target_tensor = torch.tensor(target)
    fpr, tpr, thresholds = roc(pred_tensor, target_tensor)

    # fpr and tpr must agree in shape, and have one entry per threshold
    assert fpr.shape == tpr.shape
    assert fpr.size(0) == thresholds.size(0)

    # cast the references to the dtype/device of the computed curves
    reference_fpr = torch.tensor(expected_fpr).to(fpr)
    assert torch.allclose(fpr, reference_fpr)
    reference_tpr = torch.tensor(expected_tpr).to(tpr)
    assert torch.allclose(tpr, reference_tpr)
def _auroc_compute( preds: torch.Tensor, target: torch.Tensor, mode: str, num_classes: Optional[int] = None, pos_label: Optional[int] = None, average: Optional[str] = 'macro', max_fpr: Optional[float] = None, sample_weights: Optional[Sequence] = None, ) -> torch.Tensor: # binary mode override num_classes if mode == 'binary': num_classes = 1 # check max_fpr parameter if max_fpr is not None: if (not isinstance(max_fpr, float) and 0 < max_fpr <= 1): raise ValueError( f"`max_fpr` should be a float in range (0, 1], got: {max_fpr}") if LooseVersion(torch.__version__) < LooseVersion('1.6.0'): raise RuntimeError( "`max_fpr` argument requires `torch.bucketize` which" " is not available below PyTorch version 1.6") # max_fpr parameter is only support for binary if mode != 'binary': raise ValueError( f"Partial AUC computation not available in" f" multilabel/multiclass setting, 'max_fpr' must be" f" set to `None`, received `{max_fpr}`.") # calculate fpr, tpr if mode == 'multi-label': # for multilabel we iteratively evaluate roc in a binary fashion output = [ roc(preds[:, i], target[:, i], num_classes=1, pos_label=1, sample_weights=sample_weights) for i in range(num_classes) ] fpr = [o[0] for o in output] tpr = [o[1] for o in output] else: fpr, tpr, _ = roc(preds, target, num_classes, pos_label, sample_weights) # calculate standard roc auc score if max_fpr is None or max_fpr == 1: if num_classes != 1: # calculate auc scores per class auc_scores = [auc(x, y) for x, y in zip(fpr, tpr)] # calculate average if average == AverageMethods.NONE: return auc_scores elif average == AverageMethods.MACRO: return torch.mean(torch.stack(auc_scores)) elif average == AverageMethods.WEIGHTED: if mode == DataType.MULTILABEL: support = torch.sum(target, dim=0) else: support = torch.bincount(target.flatten(), minlength=num_classes) return torch.sum( torch.stack(auc_scores) * support / support.sum()) allowed_average = [e.value for e in AverageMethods] raise ValueError( f"Argument `average` expected 
to be one of the following:" f" {allowed_average} but got {average}") return auc(fpr, tpr) max_fpr = torch.tensor(max_fpr, device=fpr.device) # Add a single point at max_fpr and interpolate its tpr value stop = torch.bucketize(max_fpr, fpr, out_int32=True, right=True) weight = (max_fpr - fpr[stop - 1]) / (fpr[stop] - fpr[stop - 1]) interp_tpr = torch.lerp(tpr[stop - 1], tpr[stop], weight) tpr = torch.cat([tpr[:stop], interp_tpr.view(1)]) fpr = torch.cat([fpr[:stop], max_fpr.view(1)]) # Compute partial AUC partial_auc = auc(fpr, tpr) # McClish correction: standardize result to be 0.5 if non-discriminant # and 1 if maximal min_area = 0.5 * max_fpr**2 max_area = max_fpr return 0.5 * (1 + (partial_auc - min_area) / (max_area - min_area))