# Example 1
def test_get_multiclass_statistics(outputs, targets, tn_true, fp_true, fn_true,
                                   tp_true, support_true):
    """Verify per-class tn/fp/fn/tp/support against the expected values."""
    actual = get_multiclass_statistics(outputs, targets)
    expected = (tn_true, fp_true, fn_true, tp_true, support_true)

    # Compare each computed statistic to its reference, cast to the
    # computed tensor's dtype/device before the comparison.
    for expected_value, actual_value in zip(expected, actual):
        assert torch.allclose(
            torch.tensor(expected_value).to(actual_value), actual_value
        )
# Example 2
    def update(
        self, outputs: torch.Tensor, targets: torch.Tensor
    ) -> Union[Tuple[int, int, int, int, int, int], Tuple[Any, Any, Any, Any, Any, int]]:
        """
        Compute confusion statistics for a new batch of predictions
        and merge them into the accumulated totals.

        Args:
            outputs: prediction values
            targets: true answers

        Returns:
            Tuple of int or array: true negative, false positive, false
                negative, true positive, support statistics and num_classes

        """
        # Move data off any accelerator and drop the autograd graph
        # before computing the statistics.
        *batch_stats, num_classes = get_multiclass_statistics(
            outputs=outputs.cpu().detach(),
            targets=targets.cpu().detach(),
            num_classes=self.num_classes,
        )

        keys = ("tn", "fp", "fn", "tp", "support")
        batch = {key: value.numpy() for key, value in zip(keys, batch_stats)}

        # Lock in the class count the first time it becomes known.
        if self.num_classes is None:
            self.num_classes = num_classes

        for key in keys:
            self.statistics[key] += batch[key]

        return (
            batch["tn"],
            batch["fp"],
            batch["fn"],
            batch["tp"],
            batch["support"],
            self.num_classes,
        )
# Example 3
def precision_recall_fbeta_support(
    outputs: torch.Tensor,
    targets: torch.Tensor,
    beta: float = 1,
    eps: float = 1e-6,
    argmax_dim: int = -1,
    num_classes: Optional[int] = None,
    zero_division: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Counts precision_val, recall, fbeta_score.

    Args:
        outputs: A list of predicted elements
        targets:  A list of elements that are to be predicted
        beta: beta param for f_score
        eps: epsilon to avoid zero division
        argmax_dim: int, that specifies dimension for argmax transformation
            in case of scores/probabilities in ``outputs``
        num_classes: int, that specifies number of classes if it known.
        zero_division: int value, should be one of 0 or 1;
            accepted for API compatibility but currently unused — the
            eps-smoothed formulas below handle zero denominators instead
            (see the TODO about syncing metrics)

    Returns:
        tuple of precision_val, recall, fbeta_score, support

    Examples:
        >>> precision_recall_fbeta_support(
        >>>     outputs=torch.tensor([
        >>>         [1, 0, 0],
        >>>         [0, 1, 0],
        >>>         [0, 0, 1],
        >>>     ]),
        >>>     targets=torch.tensor([0, 1, 2]),
        >>>     beta=1,
        >>> )
        (
            tensor([1., 1., 1.]),  # precision_val per class
            tensor([1., 1., 1.]),  # recall per class
            tensor([1., 1., 1.]),  # fbeta per class
            tensor([1., 1., 1.]),  # support per class
        )
        >>> precision_recall_fbeta_support(
        >>>     outputs=torch.tensor([[0, 0, 1, 1, 0, 1, 0, 1]]),
        >>>     targets=torch.tensor([[0, 1, 0, 1, 0, 0, 1, 1]]),
        >>>     beta=1,
        >>> )
        (
            tensor([0.5000, 0.5000]),  # precision per class
            tensor([0.5000, 0.5000]),  # recall per class
            tensor([0.5000, 0.5000]),  # fbeta per class
            tensor([4., 4.]),          # support per class
        )
    """
    # tn is computed but not needed for precision/recall/fbeta.
    tn, fp, fn, tp, support = get_multiclass_statistics(
        outputs=outputs,
        targets=targets,
        argmax_dim=argmax_dim,
        num_classes=num_classes,
    )
    # @TODO: sync between metrics
    # precision_val = _precision(tp=tp, fp=fp, eps=eps, zero_division=zero_division)
    # eps in both numerator and denominator keeps the ratios finite
    # (and equal to 1) when a class has no predictions/targets.
    precision_val = (tp + eps) / (fp + tp + eps)
    recall_val = (tp + eps) / (fn + tp + eps)

    # F-beta: weighted harmonic mean of precision and recall.
    beta_sq = beta**2
    numerator = (1 + beta_sq) * precision_val * recall_val
    denominator = beta_sq * precision_val + recall_val
    fbeta = numerator / denominator

    return precision_val, recall_val, fbeta, support