示例#1
0
    def compute(self) -> Tuple[float, float, float]:
        """
        Reduce the accumulated statistics and derive the binary metrics

        When ``self._ddp_backend`` is ``"xla"`` every entry of
        ``self.statistics`` is passed through ``xm.mesh_reduce`` with
        ``np.sum``; when it is ``"ddp"`` every entry is replaced by the sum
        of the values returned by ``all_gather``. Any other backend value
        leaves the statistics untouched. In both distributed branches the
        reduced values are written back to ``self.statistics``.

        Returns:
            tuple of metrics: precision, recall, f1 score
        """
        # DDP hotfix: could be done better, but the metric has to handle
        # distributed reduction on its own.
        # NOTE(review): the reduced totals overwrite the stored statistics,
        # so a second call without a reset would presumably reduce values
        # that are already reduced - confirm callers reset in between.
        backend = self._ddp_backend
        if backend == "xla":
            reduced = {}
            for name, local_value in self.statistics.items():
                reduced[name] = xm.mesh_reduce(name, local_value, np.sum)
            self.statistics = reduced
        elif backend == "ddp":
            # Only existing keys are reassigned, so the dict size is stable
            # while it is being iterated.
            for name, local_value in self.statistics.items():
                self.statistics[name] = sum(all_gather(local_value))

        stats = self.statistics
        precision, recall, f1 = get_binary_metrics(
            tp=stats["tp"],
            fp=stats["fp"],
            fn=stats["fn"],
            zero_division=self.zero_division,
        )
        return precision, recall, f1
示例#2
0
    def update(self, outputs: torch.Tensor,
               targets: torch.Tensor) -> Tuple[float, float, float]:
        """
        Feed a batch to the parent ``update`` and score its statistics

        The parent call returns five values, unpacked here as
        ``(tn, fp, fn, tp, support)``; only ``tp``, ``fp`` and ``fn`` are
        forwarded to ``get_binary_metrics``.

        Args:
            outputs: predicted labels
            targets: target labels

        Returns:
            tuple of intermediate metrics: precision, recall, f1 score
        """
        batch_stats = super().update(outputs=outputs, targets=targets)
        # tn and support are not needed for precision / recall / f1.
        _tn, fp, fn, tp, _support = batch_stats
        precision, recall, f1 = get_binary_metrics(
            tp=tp,
            fp=fp,
            fn=fn,
            zero_division=self.zero_division,
        )
        return precision, recall, f1
示例#3
0
    def compute(self) -> Tuple[float, float, float]:
        """
        Derive the binary metrics from the accumulated statistics

        When ``self._is_ddp`` is truthy, every entry of ``self.statistics``
        is first replaced by the sum of the values returned by
        ``all_gather`` for it; otherwise the statistics are used as they are.

        Returns:
            tuple of metrics: precision, recall, f1 score
        """
        # @TODO: ddp hotfix, could be done better
        # NOTE(review): the summed values overwrite the stored statistics,
        # so a second call without a reset would presumably gather values
        # that are already summed - confirm callers reset in between.
        if self._is_ddp:
            # Only existing keys are reassigned, so the dict size is stable
            # while it is being iterated.
            for name, local_value in self.statistics.items():
                self.statistics[name] = sum(all_gather(local_value))

        stats = self.statistics
        precision, recall, f1 = get_binary_metrics(
            tp=stats["tp"],
            fp=stats["fp"],
            fn=stats["fn"],
            zero_division=self.zero_division,
        )
        return precision, recall, f1