Example #1
    def compute(self) -> Tuple[torch.Tensor, float, float, float]:
        """Computes the AUC metric based on saved statistics."""
        targets = torch.cat(self.targets)
        scores = torch.cat(self.scores)

        # DDP hotfix, could be done better,
        # but the metric must handle DDP on its own
        if self._ddp_backend == "xla":
            # if you have "RuntimeError: Aborted: Session XXX is not found" here
            # please, ask Google for a more powerful TPU setup ;)
            device = get_device()
            scores = xm.all_gather(scores.to(device)).cpu().detach()
            targets = xm.all_gather(targets.to(device)).cpu().detach()
        elif self._ddp_backend == "ddp":
            scores = torch.cat(all_gather(scores))
            targets = torch.cat(all_gather(targets))

        scores, targets, _, _ = process_multilabel_components(
            outputs=scores, targets=targets
        )
        per_class = auc(scores=scores, targets=targets)
        micro = binary_auc(scores=scores.view(-1), targets=targets.view(-1))[0]
        macro = per_class.mean().item()
        # per-class weights: prevalence of each class among the samples
        weights = targets.sum(dim=0) / len(targets)
        weighted = (per_class * weights).sum().item()
        if self.compute_per_class_metrics:
            return per_class, micro, macro, weighted
        else:
            # per-class scores are omitted; an empty list stands in for the tensor
            return [], micro, macro, weighted
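
For reference, here is a minimal sketch of the final aggregation step, using made-up per-class AUC values (the tensors below are illustrative, not from the library):

import torch

# hypothetical per-class AUCs and binary targets: 4 samples, 3 classes
per_class = torch.tensor([0.90, 0.70, 0.50])
targets = torch.tensor([[1, 0, 0],
                        [1, 1, 0],
                        [0, 0, 1],
                        [1, 0, 0]]).float()

macro = per_class.mean().item()                # 0.70 (unweighted mean)
weights = targets.sum(dim=0) / len(targets)    # prevalence per class: [0.75, 0.25, 0.25]
weighted = (per_class * weights).sum().item()  # 0.975

Note that in a multilabel setting the prevalence weights need not sum to 1, so the weighted score is not guaranteed to lie between the per-class minimum and maximum.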
Example #2
    def compute(self) -> Tuple[torch.Tensor, float, float, float]:
        """Computes the AUC metric based on saved statistics."""
        targets = torch.cat(self.targets)
        scores = torch.cat(self.scores)

        # @TODO: ddp hotfix, could be done better
        if self._is_ddp:
            scores = torch.cat(all_gather(scores))
            targets = torch.cat(all_gather(targets))

        scores, targets, _ = process_multilabel_components(
            outputs=scores, targets=targets
        )
        per_class = auc(scores=scores, targets=targets)
        micro = binary_auc(scores=scores.view(-1), targets=targets.view(-1))[0]
        macro = per_class.mean().item()
        weights = targets.sum(dim=0) / len(targets)
        weighted = (per_class * weights).sum().item()
        return per_class, micro, macro, weighted
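
The DDP branch assumes an all_gather helper that returns a list with one tensor per process, which torch.cat then flattens back into the full set of predictions. A single-process illustration of that gather-then-concat pattern, with a stand-in gather function (fake_all_gather is hypothetical):

import torch
from typing import List

def fake_all_gather(local: torch.Tensor) -> List[torch.Tensor]:
    # stand-in for a distributed gather: pretend a second rank held "local + 10"
    return [local, local + 10]

local_scores = torch.tensor([1.0, 2.0])
full_scores = torch.cat(fake_all_gather(local_scores))
# full_scores: tensor([ 1.,  2., 11., 12.])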
Example #3
def binary_average_precision(
    outputs: torch.Tensor, targets: torch.Tensor, weights: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Computes the average precision.

    Args:
        outputs: NxK tensor that for each of the N examples
            indicates the probability of the example belonging to each of
            the K classes, according to the model.
        targets: binary NxK tensor that encodes which of the K
            classes are associated with the N-th input
            (e.g., a row [0, 1, 0, 1] indicates that the example is
            associated with classes 2 and 4)
        weights: importance for each sample

    Returns:
        torch.Tensor: tensor of [K; ] shape,
        with average precision for K classes

    Examples:
        >>> binary_average_precision(
        >>>     outputs=torch.Tensor([0.1, 0.4, 0.35, 0.8]),
        >>>     targets=torch.Tensor([0, 0, 1, 1]),
        >>> )
        tensor([0.8333])
    """
    # outputs - [bs; num_classes] with scores
    # targets - [bs; num_classes] with binary labels
    outputs, targets, weights = process_multilabel_components(
        outputs=outputs, targets=targets, weights=weights,
    )
    if outputs.numel() == 0:
        return torch.zeros(1)

    ap = torch.zeros(targets.size(1))

    # compute average precision for each class
    for class_i in range(targets.size(1)):
        # sort scores
        class_scores = outputs[:, class_i]
        class_targets = targets[:, class_i]
        _, sortind = torch.sort(class_scores, dim=0, descending=True)
        correct = class_targets[sortind]

        # compute true positive sums
        if weights is not None:
            class_weight = weights[sortind]
            weighted_correct = correct.float() * class_weight

            tp = weighted_correct.cumsum(0)
            rg = class_weight.cumsum(0)
        else:
            tp = correct.float().cumsum(0)
            rg = torch.arange(1, targets.size(0) + 1).float()

        # compute precision curve
        precision = tp.div(rg)

        # compute average precision
        ap[class_i] = precision[correct.bool()].sum() / max(float(correct.sum()), 1)

    return ap
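
A worked pass through the docstring example for a single class, mirroring the loop body above step by step:

import torch

scores = torch.tensor([0.1, 0.4, 0.35, 0.8])
labels = torch.tensor([0.0, 0.0, 1.0, 1.0])

_, order = torch.sort(scores, descending=True)  # rank by score: 0.8, 0.4, 0.35, 0.1
correct = labels[order]                         # [1., 0., 1., 0.]
tp = correct.cumsum(0)                          # [1., 1., 2., 2.]
rg = torch.arange(1, 5).float()                 # [1., 2., 3., 4.]
precision = tp / rg                             # [1.00, 0.50, 0.67, 0.50]
ap = precision[correct.bool()].sum() / correct.sum()
# ap = (1.0 + 0.6667) / 2 = 0.8333, matching the docstring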