Exemplo n.º 1
0
def multiclass_auroc(
    pred: torch.Tensor,
    target: torch.Tensor,
    sample_weight: Optional[Sequence] = None,
    num_classes: Optional[int] = None,
) -> torch.Tensor:
    """
    Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC) from multiclass
    prediction scores

    .. warning :: Deprecated in favor of
     :func:`~pytorch_lightning.metrics.functional.auroc.auroc`. Will be removed
     in v1.4.0.

    Args:
        pred: estimated probabilities, with shape [N, C]; every row must sum to 1
        target: ground-truth labels, with shape [N,]; every one of the C classes
            must occur at least once
        sample_weight: sample weights
        num_classes: number of classes (default: None); when given it must equal
            the number of columns of ``pred``

    Return:
        Tensor containing ROCAUC score

    Raises:
        ValueError: if the rows of ``pred`` do not sum to 1, if the number of
            distinct labels in ``target`` differs from the number of columns of
            ``pred``, or if ``num_classes`` is given and differs from the number
            of columns of ``pred``

    Example:

        >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
        ...                      [0.05, 0.85, 0.05, 0.05],
        ...                      [0.05, 0.05, 0.85, 0.05],
        ...                      [0.05, 0.05, 0.05, 0.85]])
        >>> target = torch.tensor([0, 1, 3, 2])
        >>> multiclass_auroc(pred, target, num_classes=4)
        tensor(0.6667)
    """
    rank_zero_warn(
        "This `multiclass_auroc` was deprecated in v1.2.0 in favor of"
        " `pytorch_lightning.metrics.functional.auroc import auroc`."
        " It will be removed in v1.4.0", DeprecationWarning)

    # Build the reference tensor from the row sums themselves so that its dtype
    # and device always match those of `pred`; a hard-coded `torch.tensor(1.0)`
    # is a default-dtype CPU scalar and does not match e.g. float64 inputs.
    row_sums = pred.sum(dim=1)
    if not torch.allclose(row_sums, torch.ones_like(row_sums)):
        raise ValueError(
            "Multiclass AUROC metric expects the target scores to be"
            " probabilities, i.e. they should sum up to 1.0 over classes")

    # Count the distinct labels once and reuse the value in the error message.
    n_labels_seen = torch.unique(target).size(0)
    n_columns = pred.size(1)
    if n_labels_seen != n_columns:
        raise ValueError(
            f"Number of classes found in in 'target' ({n_labels_seen})"
            f" does not equal the number of columns in 'pred' ({n_columns})."
            " Multiclass AUROC is not defined when all of the classes do not"
            " occur in the target labels.")

    if num_classes is not None and num_classes != n_columns:
        raise ValueError(
            f"Number of classes deduced from 'pred' ({n_columns}) does not equal"
            f" the number of classes passed in 'num_classes' ({num_classes}).")

    return __auroc(preds=pred,
                   target=target,
                   sample_weights=sample_weight,
                   num_classes=num_classes)
Exemplo n.º 2
0
def multiclass_auroc(
    pred: torch.Tensor,
    target: torch.Tensor,
    sample_weight: Optional[Sequence] = None,
    num_classes: Optional[int] = None,
) -> torch.Tensor:
    """
    Deprecated multiclass AUROC entry point.

    Issues a deprecation notice through ``rank_zero_warn`` and then forwards
    all arguments to ``__auroc``.

    .. deprecated::
        Use :func:`torchmetrics.functional.auroc`. Will be removed in v1.4.0.

    Args:
        pred: forwarded to ``__auroc`` as ``preds``
        target: forwarded to ``__auroc`` as ``target``
        sample_weight: forwarded to ``__auroc`` as ``sample_weights``
        num_classes: forwarded to ``__auroc`` as ``num_classes``

    Return:
        The value returned by ``__auroc``
    """
    deprecation_notice = (
        "This `multiclass_auroc` was deprecated in v1.2.0 in favor of"
        " `pytorch_lightning.metrics.functional.auroc import auroc`."
        " It will be removed in v1.4.0"
    )
    rank_zero_warn(deprecation_notice, DeprecationWarning)

    # The legacy parameter names differ from the ones `__auroc` expects
    # (`pred` -> `preds`, `sample_weight` -> `sample_weights`).
    forwarded = {
        "preds": pred,
        "target": target,
        "sample_weights": sample_weight,
        "num_classes": num_classes,
    }
    return __auroc(**forwarded)
Exemplo n.º 3
0
def auroc(
    pred: torch.Tensor,
    target: torch.Tensor,
    sample_weight: Optional[Sequence] = None,
    pos_label: int = 1.,
    max_fpr: Optional[float] = None,
) -> torch.Tensor:
    """
    Deprecated AUROC entry point.

    Issues a deprecation notice through ``rank_zero_warn`` and then forwards
    the arguments to ``__auroc`` with ``num_classes=1``.

    .. deprecated::
        Use :func:`torchmetrics.functional.auroc`. Will be removed in v1.4.0.

    Args:
        pred: forwarded to ``__auroc`` as ``preds``
        target: forwarded to ``__auroc`` as ``target``
        sample_weight: forwarded to ``__auroc`` as ``sample_weights``
        pos_label: forwarded to ``__auroc`` as ``pos_label``; note that the
            default is the float literal ``1.`` although the annotation is ``int``
        max_fpr: forwarded to ``__auroc`` as ``max_fpr``; ``None`` by default

    Return:
        The value returned by ``__auroc``
    """
    deprecation_notice = (
        "This `auroc` was deprecated in v1.2.0 in favor of"
        " `pytorch_lightning.metrics.functional.auroc import auroc`."
        " It will be removed in v1.4.0"
    )
    rank_zero_warn(deprecation_notice, DeprecationWarning)

    # `num_classes` is pinned to 1 here; it is not exposed by this wrapper.
    return __auroc(
        preds=pred,
        target=target,
        sample_weights=sample_weight,
        pos_label=pos_label,
        max_fpr=max_fpr,
        num_classes=1,
    )
def auroc(
    pred: torch.Tensor,
    target: torch.Tensor,
    sample_weight: Optional[Sequence] = None,
    pos_label: int = 1.,
    max_fpr: Optional[float] = None,
) -> torch.Tensor:
    """
    Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC) from prediction scores

    .. warning :: Deprecated in favor of
     :func:`~pytorch_lightning.metrics.functional.auroc.auroc`. Will be removed
     in v1.4.0.

    This wrapper only issues the deprecation notice and forwards its arguments
    to ``__auroc`` with ``num_classes=1``.

    Args:
        pred: estimated probabilities
        target: ground-truth labels
        sample_weight: sample weights
        pos_label: the label for the positive class (the default is the float
            literal ``1.`` although the annotation is ``int``)
        max_fpr: If not ``None``, calculates standardized partial AUC over the
            range [0, max_fpr]. Should be a float between 0 and 1.

    Return:
        Tensor containing ROCAUC score

    Example:

        >>> x = torch.tensor([0, 1, 2, 3])
        >>> y = torch.tensor([0, 1, 1, 0])
        >>> auroc(x, y)
        tensor(0.5000)
    """
    rank_zero_warn(
        "This `auroc` was deprecated in v1.2.0 in favor of"
        " `pytorch_lightning.metrics.functional.auroc import auroc`."
        " It will be removed in v1.4.0",
        DeprecationWarning,
    )

    # The legacy parameter names differ from the ones `__auroc` expects
    # (`pred` -> `preds`, `sample_weight` -> `sample_weights`); `num_classes`
    # is pinned to 1 and is not exposed by this wrapper.
    forwarded = {
        "preds": pred,
        "target": target,
        "sample_weights": sample_weight,
        "pos_label": pos_label,
        "max_fpr": max_fpr,
        "num_classes": 1,
    }
    return __auroc(**forwarded)