示例#1
0
def test_auc(x, y, expected):
    """Check that ``auc`` over the points ``(x, y)`` equals ``expected``.

    The inputs are converted to tensors and ``auc`` is called with
    ``reorder=True``.
    """
    x_tensor = torch.tensor(x)
    y_tensor = torch.tensor(y)
    area = auc(x_tensor, y_tensor, reorder=True)
    assert area == expected
示例#2
0
def _auroc_compute(
    preds: torch.Tensor,
    target: torch.Tensor,
    mode: str,
    num_classes: Optional[int] = None,
    pos_label: Optional[int] = None,
    average: Optional[str] = 'macro',
    max_fpr: Optional[float] = None,
    sample_weights: Optional[Sequence] = None,
) -> torch.Tensor:
    """Compute the area under the ROC curve from predictions and targets.

    Args:
        preds: Predictions forwarded to ``roc``. In ``'multi-label'`` mode
            they are indexed column-wise as ``preds[:, i]``.
        target: Ground-truth labels forwarded to ``roc``. In
            ``'multi-label'`` mode they are indexed as ``target[:, i]``.
        mode: Kind of input data. ``'binary'`` forces ``num_classes`` to 1;
            ``'multi-label'`` evaluates one binary ROC per column; any other
            value passes ``preds``/``target`` to ``roc`` in a single call.
        num_classes: Number of classes (overridden to 1 in binary mode).
        pos_label: Positive label forwarded to ``roc`` (not used in
            ``'multi-label'`` mode, where ``pos_label=1`` is passed).
        average: Reduction applied to the per-class scores when
            ``num_classes != 1``: ``AverageMethods.NONE`` returns the list of
            per-class scores, ``AverageMethods.MACRO`` their mean, and
            ``AverageMethods.WEIGHTED`` a support-weighted sum.
        max_fpr: If not ``None``, must be a ``float`` in ``(0, 1]``. For
            values other than 1 the partial AUC up to this false positive
            rate is computed and standardized. Only allowed in binary mode.
        sample_weights: Sample weights forwarded to ``roc``.

    Returns:
        The AUROC score (or, for ``average == AverageMethods.NONE`` with
        ``num_classes != 1``, the list of per-class scores).

    Raises:
        ValueError: If ``max_fpr`` is not a float in ``(0, 1]``, if
            ``max_fpr`` is given while ``mode`` is not ``'binary'``, or if
            ``average`` is not one of the ``AverageMethods`` values.
        RuntimeError: If ``max_fpr`` is given and the installed PyTorch
            version is below 1.6.0.
    """
    # binary mode overrides num_classes
    if mode == 'binary':
        num_classes = 1

    # check max_fpr parameter
    if max_fpr is not None:
        # The value is valid only when BOTH conditions hold, so the negation
        # must wrap the whole conjunction. Negating only the isinstance check
        # would let out-of-range floats through and reject the wrong inputs.
        if not (isinstance(max_fpr, float) and 0 < max_fpr <= 1):
            raise ValueError(
                f"`max_fpr` should be a float in range (0, 1], got: {max_fpr}")

        if LooseVersion(torch.__version__) < LooseVersion('1.6.0'):
            raise RuntimeError(
                "`max_fpr` argument requires `torch.bucketize` which"
                " is not available below PyTorch version 1.6")

        # max_fpr parameter is only supported for binary
        if mode != 'binary':
            raise ValueError(
                f"Partial AUC computation not available in"
                f" multilabel/multiclass setting, 'max_fpr' must be"
                f" set to `None`, received `{max_fpr}`.")

    # calculate fpr, tpr
    if mode == 'multi-label':
        # for multilabel we iteratively evaluate roc in a binary fashion
        output = [
            roc(preds[:, i],
                target[:, i],
                num_classes=1,
                pos_label=1,
                sample_weights=sample_weights) for i in range(num_classes)
        ]
        fpr = [o[0] for o in output]
        tpr = [o[1] for o in output]
    else:
        fpr, tpr, _ = roc(preds, target, num_classes, pos_label,
                          sample_weights)

    # calculate standard roc auc score
    if max_fpr is None or max_fpr == 1:
        if num_classes != 1:
            # calculate auc scores per class
            auc_scores = [auc(x, y) for x, y in zip(fpr, tpr)]

            # calculate average
            if average == AverageMethods.NONE:
                return auc_scores
            elif average == AverageMethods.MACRO:
                return torch.mean(torch.stack(auc_scores))
            elif average == AverageMethods.WEIGHTED:
                # support = number of positive entries per class
                if mode == DataType.MULTILABEL:
                    support = torch.sum(target, dim=0)
                else:
                    support = torch.bincount(target.flatten(),
                                             minlength=num_classes)
                return torch.sum(
                    torch.stack(auc_scores) * support / support.sum())

            allowed_average = [e.value for e in AverageMethods]
            raise ValueError(
                f"Argument `average` expected to be one of the following:"
                f" {allowed_average} but got {average}")

        return auc(fpr, tpr)

    max_fpr = torch.tensor(max_fpr, device=fpr.device)
    # Add a single point at max_fpr and interpolate its tpr value
    stop = torch.bucketize(max_fpr, fpr, out_int32=True, right=True)
    weight = (max_fpr - fpr[stop - 1]) / (fpr[stop] - fpr[stop - 1])
    interp_tpr = torch.lerp(tpr[stop - 1], tpr[stop], weight)
    tpr = torch.cat([tpr[:stop], interp_tpr.view(1)])
    fpr = torch.cat([fpr[:stop], max_fpr.view(1)])

    # Compute partial AUC
    partial_auc = auc(fpr, tpr)

    # McClish correction: standardize result to be 0.5 if non-discriminant
    # and 1 if maximal
    min_area = 0.5 * max_fpr**2
    max_area = max_fpr
    return 0.5 * (1 + (partial_auc - min_area) / (max_area - min_area))