Example #1
    def compute(
        self
    ) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor],
                                                    List[Tensor]]]:
        """
        Compute the precision-recall curve.

        Returns:
            3-element tuple containing

            precision:
                tensor where element i is the precision of predictions with
                score >= thresholds[i] and the last element is 1.
                If multiclass, this is a list of such tensors, one for each class.
            recall:
                tensor where element i is the recall of predictions with
                score >= thresholds[i] and the last element is 0.
                If multiclass, this is a list of such tensors, one for each class.
            thresholds:
                Thresholds used for computing precision/recall scores
        """
        preds = torch.cat(self.preds, dim=0)
        target = torch.cat(self.target, dim=0)
        return _precision_recall_curve_compute(preds, target, self.num_classes,
                                               self.pos_label)
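
As a standalone illustration of the threshold semantics described in the docstring, the binary case can be sketched with plain torch. This is an illustrative sketch only (`binary_pr_curve` is a hypothetical name, not library code):

import torch

# Hypothetical sketch of the binary curve convention: precision[i]/recall[i]
# are computed over predictions with score >= thresholds[i], and the curve
# is closed with precision=1, recall=0 (matching the docstring above).
def binary_pr_curve(scores, target):
    thresholds = torch.unique(scores)  # sorted ascending
    n_pos = target.sum()
    precision, recall = [], []
    for t in thresholds:
        sel = scores >= t
        tp = (target[sel] == 1).sum()
        precision.append(tp / sel.sum())
        recall.append(tp / n_pos)
    precision.append(torch.tensor(1.0))
    recall.append(torch.tensor(0.0))
    return torch.stack(precision), torch.stack(recall), thresholds

scores = torch.tensor([0.1, 0.35, 0.4, 0.8])
target = torch.tensor([0, 1, 0, 1])
p, r, t = binary_pr_curve(scores, target)
# p = [0.5000, 0.6667, 0.5000, 1.0000, 1.0000]
# r = [1.0000, 1.0000, 0.5000, 0.5000, 0.0000]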
Example #2
    def compute(
        self
    ) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor],
                                                    List[Tensor]]]:
        """Compute the precision-recall curve.

        Returns:
            3-element tuple containing

            precision:
                tensor where element ``i`` is the precision of predictions with
                ``score >= thresholds[i]`` and the last element is 1.
                If multiclass, this is a list of such tensors, one for each class.
            recall:
                tensor where element ``i`` is the recall of predictions with
                ``score >= thresholds[i]`` and the last element is 0.
                If multiclass, this is a list of such tensors, one for each class.
            thresholds:
                Thresholds used for computing precision/recall scores
        """
        preds = dim_zero_cat(self.preds)
        target = dim_zero_cat(self.target)
        if not self.num_classes:
            raise ValueError(
                f"`num_classes` bas to be positive number, but got {self.num_classes}"
            )
        return _precision_recall_curve_compute(preds, target, self.num_classes,
                                               self.pos_label)
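
Relative to Example #1, this version concatenates its state with torchmetrics' dim_zero_cat utility rather than a bare torch.cat, and validates num_classes first. Roughly, the utility behaves as in this simplified sketch (the real helper, from torchmetrics.utilities.data, additionally handles nested lists and empty states):

import torch
from torch import Tensor
from typing import List, Union

# Simplified sketch of dim_zero_cat: a state accumulated as a list of
# per-batch tensors is concatenated along dim 0, while a state that is
# already a single tensor passes through unchanged.
def dim_zero_cat(x: Union[Tensor, List[Tensor]]) -> Tensor:
    if isinstance(x, Tensor):
        return x
    return torch.cat(x, dim=0)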
Example #3
def _average_precision_compute(
    preds: Tensor,
    target: Tensor,
    num_classes: int,
    pos_label: Optional[int] = None,
    average: Optional[str] = "macro",
    sample_weights: Optional[Sequence] = None,
) -> Union[List[Tensor], Tensor]:
    """Computes the average precision score.

    Args:
        preds: predictions from model (logits or probabilities)
        target: ground truth values
        num_classes: integer with number of classes.
        pos_label: integer determining the positive class. Default is ``None``, which for binary problems is
            translated to 1. For multiclass problems this argument should not be set, as we iteratively change it
            in the range ``[0, num_classes-1]``
        average: reduction method for multi-class or multi-label problems
        sample_weights: sample weights for each data point

    Example:
        >>> # binary case
        >>> preds = torch.tensor([0, 1, 2, 3])
        >>> target = torch.tensor([0, 1, 1, 1])
        >>> pos_label = 1
        >>> preds, target, num_classes, pos_label = _average_precision_update(preds, target, pos_label=pos_label)
        >>> _average_precision_compute(preds, target, num_classes, pos_label)
        tensor(1.)

        >>> # multiclass case
        >>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
        ...                        [0.05, 0.75, 0.05, 0.05, 0.05],
        ...                        [0.05, 0.05, 0.75, 0.05, 0.05],
        ...                        [0.05, 0.05, 0.05, 0.75, 0.05]])
        >>> target = torch.tensor([0, 1, 3, 2])
        >>> num_classes = 5
        >>> preds, target, num_classes, pos_label = _average_precision_update(preds, target, num_classes)
        >>> _average_precision_compute(preds, target, num_classes, average=None)
        [tensor(1.), tensor(1.), tensor(0.2500), tensor(0.2500), tensor(nan)]
    """

    # todo: `sample_weights` is unused
    precision, recall, _ = _precision_recall_curve_compute(
        preds, target, num_classes, pos_label)
    if average == "weighted":
        if preds.ndim == target.ndim and target.ndim > 1:
            weights = target.sum(dim=0).float()
        else:
            weights = _bincount(target, minlength=num_classes).float()
        weights = weights / torch.sum(weights)
    else:
        weights = None
    return _average_precision_compute_with_precision_recall(
        precision, recall, num_classes, average, weights)
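
The average == "weighted" branch weights each class's AP by that class's frequency in target. A worked sketch of that reduction, reusing the multiclass doctest values above (torch.bincount stands in for the internal _bincount helper):

import torch

# Class frequencies from the doctest, normalized to sum to 1
target = torch.tensor([0, 1, 3, 2])
num_classes = 5
weights = torch.bincount(target, minlength=num_classes).float()
weights = weights / weights.sum()  # tensor([0.25, 0.25, 0.25, 0.25, 0.00])

# Per-class AP from the doctest; class 4 never occurs in target, so its AP
# is nan -- but it also carries zero weight, so a nan-safe sum reduces to
# the weighted mean over the observed classes.
per_class_ap = torch.tensor([1.0, 1.0, 0.25, 0.25, float("nan")])
weighted_ap = torch.nansum(per_class_ap * weights)  # tensor(0.6250)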
Example #4
def _average_precision_compute(
    preds: Tensor,
    target: Tensor,
    num_classes: int,
    pos_label: int,
    sample_weights: Optional[Sequence] = None,
) -> Union[List[Tensor], Tensor]:
    # todo: `sample_weights` is unused
    precision, recall, _ = _precision_recall_curve_compute(
        preds, target, num_classes, pos_label)
    return _average_precision_compute_with_precision_recall(
        precision, recall, num_classes)
Example #5
def _average_precision_compute(
    preds: Tensor,
    target: Tensor,
    num_classes: int,
    pos_label: int,
    sample_weights: Optional[Sequence] = None,
) -> Union[List[Tensor], Tensor]:
    # todo: `sample_weights` is unused
    precision, recall, _ = _precision_recall_curve_compute(preds, target, num_classes, pos_label)
    # Return the step function integral
    # The following works because the last entry of precision is
    # guaranteed to be 1, as returned by precision_recall_curve
    if num_classes == 1:
        return -torch.sum((recall[1:] - recall[:-1]) * precision[:-1])

    res = []
    for p, r in zip(precision, recall):
        res.append(-torch.sum((r[1:] - r[:-1]) * p[:-1]))
    return res
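
The negated sum in this version is a left-point Riemann sum of precision over the recall steps: recall decreases as the threshold grows, so the successive differences are negative, and the sum is negated to yield a positive area. A small numeric sketch with made-up curve values:

import torch

# Curve convention from the examples above: recall ends at 0, precision at 1
recall    = torch.tensor([1.0, 0.5, 0.5, 0.0])
precision = torch.tensor([0.5, 0.5, 1.0, 1.0])
# (0.5-1.0)*0.5 + (0.5-0.5)*0.5 + (0.0-0.5)*1.0 = -0.75
ap = -torch.sum((recall[1:] - recall[:-1]) * precision[:-1])  # tensor(0.7500)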