コード例 #1
0
def test_micro(
    tp: np.ndarray,
    fp: np.ndarray,
    fn: np.ndarray,
    support: np.ndarray,
    zero_division: int,
    true_answer: Tuple[float, ...],
):
    """
    Test micro metrics averaging.

    Runs ``get_aggregated_metrics`` on the given statistics and checks the
    micro-averaged result: its last element must be ``None`` (presumably the
    support entry, which has no micro average), and every other element must
    match the corresponding value of ``true_answer`` to within ``EPS``.

    Args:
        tp: true positive statistic
        fp: false positive statistic
        fn: false negative statistic
        support: support statistic
        zero_division: 0 or 1
        true_answer: expected micro-averaged metric values, in the same order
            as the leading elements of the micro result
    """
    _, micro, _, _ = get_aggregated_metrics(tp=tp,
                                            fp=fp,
                                            fn=fn,
                                            support=support,
                                            zero_division=zero_division)
    assert micro[-1] is None
    # NOTE(review): zip stops at the shorter sequence, so a true_answer shorter
    # than micro[:-1] would silently skip comparisons.
    for pred, real in zip(micro[:-1], true_answer):
        assert abs(pred - real) < EPS
コード例 #2
0
    def compute(self) -> Any:
        """
        Compute precision, recall, f1 score and support.
        Compute micro, macro and weighted average for the metrics.

        When ``self._ddp_backend`` is ``"xla"`` or ``"ddp"``, every statistic is
        first summed over all processes. The summed values are kept in a local
        dict instead of being written back into ``self.statistics``.

        Writing them back would make a second ``compute()`` call sum values
        that are already global, inflating them by the number of processes.

        Returns:
            tuple of aggregated metrics: per-class, micro, macro and weighted
                averaging of precision, recall, f1 score and support metrics
        """
        # DDP hotfix, could be done better, but the metric must handle DDP on its own.
        # Keys are visited in the dict's insertion order: collective calls must be
        # issued in the same order on every process.
        if self._ddp_backend == "xla":
            device = get_device()
            statistics = {}
            for key, local_value in self.statistics.items():
                # Wrapping in a list adds a leading axis. The gathered result is
                # summed over dim 0 to combine the per-process values.
                gathered = xm.all_gather(torch.tensor([local_value], device=device))
                statistics[key] = gathered.sum(dim=0).cpu().numpy()
        elif self._ddp_backend == "ddp":
            statistics = {}
            for key, local_value in self.statistics.items():
                gathered_values: List[np.ndarray] = all_gather(local_value)
                statistics[key] = np.sum(np.vstack(gathered_values), axis=0)
        else:
            statistics = self.statistics

        per_class, micro, macro, weighted = get_aggregated_metrics(
            tp=statistics["tp"],
            fp=statistics["fp"],
            fn=statistics["fn"],
            support=statistics["support"],
            zero_division=self.zero_division,
        )
        return per_class, micro, macro, weighted
コード例 #3
0
    def update(
        self, outputs: torch.Tensor, targets: torch.Tensor
    ) -> Tuple[Any, Any, Any, Any]:
        """
        Feed new predictions into the metric and report its current values.

        Delegates to the parent ``update``, which yields tn, fp, fn, tp, support
        and the number of classes. It then aggregates those statistics into
        per-class, micro, macro and weighted metrics.

        Args:
            outputs: prediction values
            targets: true answers

        Returns:
            tuple ``(per_class, micro, macro, weighted)`` with the intermediate
            metric values
        """
        _tn, stat_fp, stat_fn, stat_tp, stat_support, seen_num_classes = super().update(
            outputs=outputs, targets=targets
        )
        per_class, micro, macro, weighted = get_aggregated_metrics(
            tp=stat_tp,
            fp=stat_fp,
            fn=stat_fn,
            support=stat_support,
            zero_division=self.zero_division,
        )
        # Record the class count only once, the first time it becomes known.
        if self.num_classes is None:
            self.num_classes = seen_num_classes

        return per_class, micro, macro, weighted
コード例 #4
0
    def compute(self) -> Any:
        """
        Turn the accumulated statistics into final metric values.

        Precision, recall, f1 score and support are derived from the stored
        ``tp``, ``fp``, ``fn`` and ``support`` statistics, along with their
        micro, macro and weighted averages.

        Returns:
            tuple ``(per_class, micro, macro, weighted)``: per-class metrics and
            the micro, macro and weighted averages of precision, recall, f1
            score and support
        """
        stats = self.statistics
        aggregated = get_aggregated_metrics(
            tp=stats["tp"],
            fp=stats["fp"],
            fn=stats["fn"],
            support=stats["support"],
            zero_division=self.zero_division,
        )
        per_class, micro, macro, weighted = aggregated
        return (per_class, micro, macro, weighted)