def _auroc_compute(
    preds: Tensor,
    target: Tensor,
    mode: DataType,
    num_classes: Optional[int] = None,
    pos_label: Optional[int] = None,
    average: Optional[str] = "macro",
    max_fpr: Optional[float] = None,
    sample_weights: Optional[Sequence] = None,
) -> Tensor:
    """Computes Area Under the Receiver Operating Characteristic Curve.

    Args:
        preds: predictions from model (logits or probabilities)
        target: ground truth labels
        mode: one of ``'binary'``, ``'multi-label'`` or ``'multi class multi dim'``
        num_classes: integer with number of classes for multi-label and multiclass problems.
            Should be set to ``None`` for binary problems
        pos_label: integer determining the positive class.
            Should be set to ``None`` for binary problems
        average: defines the reduction that is applied to the output.
            Should be one of ``'micro'``, ``'macro'``, ``'weighted'`` or ``None``
        max_fpr: if not ``None``, calculates standardized partial AUC over the
            range ``[0, max_fpr]``. Should be a float between 0 and 1.
        sample_weights: sample weights for each data point

    Example:
        >>> # binary case
        >>> preds = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34])
        >>> target = torch.tensor([0, 0, 1, 1, 1])
        >>> preds, target, mode = _auroc_update(preds, target)
        >>> _auroc_compute(preds, target, mode, pos_label=1)
        tensor(0.5000)

        >>> # multiclass case
        >>> preds = torch.tensor([[0.90, 0.05, 0.05],
        ...                       [0.05, 0.90, 0.05],
        ...                       [0.05, 0.05, 0.90],
        ...                       [0.85, 0.05, 0.10],
        ...                       [0.10, 0.10, 0.80]])
        >>> target = torch.tensor([0, 1, 1, 2, 2])
        >>> preds, target, mode = _auroc_update(preds, target)
        >>> _auroc_compute(preds, target, mode, num_classes=3)
        tensor(0.7778)
    """
    # binary mode overrides num_classes
    if mode == DataType.BINARY:
        num_classes = 1

    # check max_fpr parameter
    if max_fpr is not None:
        if not isinstance(max_fpr, float) or not 0 < max_fpr <= 1:
            raise ValueError(f"`max_fpr` should be a float in range (0, 1], got: {max_fpr}")

        if _TORCH_LOWER_1_6:
            raise RuntimeError(
                "`max_fpr` argument requires `torch.bucketize` which"
                " is not available below PyTorch version 1.6"
            )

        # max_fpr parameter is only supported for binary problems
        if mode != DataType.BINARY:
            raise ValueError(
                "Partial AUC computation not available in multilabel/multiclass setting,"
                f" `max_fpr` must be set to `None`, received `{max_fpr}`."
            )

    # calculate fpr, tpr
    if mode == DataType.MULTILABEL:
        if average == AverageMethod.MICRO:
            fpr, tpr, _ = roc(preds.flatten(), target.flatten(), 1, pos_label, sample_weights)
        elif num_classes:
            # for multilabel we iteratively evaluate roc in a binary fashion
            output = [
                roc(preds[:, i], target[:, i], num_classes=1, pos_label=1, sample_weights=sample_weights)
                for i in range(num_classes)
            ]
            fpr = [o[0] for o in output]
            tpr = [o[1] for o in output]
        else:
            raise ValueError("Detected input to be `multilabel` but you did not provide `num_classes` argument")
    else:
        if mode != DataType.BINARY:
            if num_classes is None:
                raise ValueError("Detected input to be `multiclass` but you did not provide `num_classes` argument")
            if average == AverageMethod.WEIGHTED and len(torch.unique(target)) < num_classes:
                # If one or more classes has 0 observations, we should exclude them, as its weight will be 0
                target_bool_mat = torch.zeros((len(target), num_classes), dtype=bool, device=target.device)
                target_bool_mat[torch.arange(len(target)), target.long()] = 1
                class_observed = target_bool_mat.sum(axis=0) > 0
                for c in range(num_classes):
                    if not class_observed[c]:
                        warnings.warn(f"Class {c} had 0 observations, omitted from AUROC calculation", UserWarning)
                preds = preds[:, class_observed]
                target = target_bool_mat[:, class_observed]
                target = torch.where(target)[1]
                num_classes = class_observed.sum()
                if num_classes == 1:
                    raise ValueError("Found 1 non-empty class in `multiclass` AUROC calculation")
        fpr, tpr, _ = roc(preds, target, num_classes, pos_label, sample_weights)

    # calculate standard roc auc score
    if max_fpr is None or max_fpr == 1:
        if mode == DataType.MULTILABEL and average == AverageMethod.MICRO:
            pass
        elif num_classes != 1:
            # calculate auc scores per class
            auc_scores = [_auc_compute_without_check(x, y, 1.0) for x, y in zip(fpr, tpr)]

            # calculate average
            if average == AverageMethod.NONE:
                return tensor(auc_scores)
            if average == AverageMethod.MACRO:
                return torch.mean(torch.stack(auc_scores))
            if average == AverageMethod.WEIGHTED:
                if mode == DataType.MULTILABEL:
                    support = torch.sum(target, dim=0)
                else:
                    support = torch.bincount(target.flatten(), minlength=num_classes)
                return torch.sum(torch.stack(auc_scores) * support / support.sum())

            allowed_average = (AverageMethod.NONE.value, AverageMethod.MACRO.value, AverageMethod.WEIGHTED.value)
            raise ValueError(
                f"Argument `average` expected to be one of the following: {allowed_average} but got {average}"
            )

        return _auc_compute_without_check(fpr, tpr, 1.0)

    _device = fpr.device if isinstance(fpr, Tensor) else fpr[0].device
    max_area: Tensor = tensor(max_fpr, device=_device)
    # Add a single point at max_fpr and interpolate its tpr value
    stop = torch.bucketize(max_area, fpr, out_int32=True, right=True)
    weight = (max_area - fpr[stop - 1]) / (fpr[stop] - fpr[stop - 1])
    interp_tpr: Tensor = torch.lerp(tpr[stop - 1], tpr[stop], weight)
    tpr = torch.cat([tpr[:stop], interp_tpr.view(1)])
    fpr = torch.cat([fpr[:stop], max_area.view(1)])

    # Compute partial AUC
    partial_auc = _auc_compute_without_check(fpr, tpr, 1.0)

    # McClish correction: standardize result to be 0.5 if non-discriminant and 1 if maximal
    min_area: Tensor = 0.5 * max_area**2
    return 0.5 * (1 + (partial_auc - min_area) / (max_area - min_area))
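

if __name__ == "__main__":
    # Illustrative sketch only, not part of the library API: it exercises the
    # partial-AUC branch of `_auroc_compute` above. With `max_fpr` set on binary
    # inputs, the ROC curve is truncated at the requested false-positive rate
    # (the last point is interpolated via `torch.bucketize`/`torch.lerp`) and the
    # McClish correction rescales the result so a non-discriminant classifier
    # still scores 0.5. The input values below are arbitrary examples.
    import torch

    preds = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34])
    target = torch.tensor([0, 0, 1, 1, 1])
    preds, target, mode = _auroc_update(preds, target)
    print(_auroc_compute(preds, target, mode, pos_label=1, max_fpr=0.5))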