def test_micro(
    tp: np.array,
    fp: np.array,
    fn: np.array,
    support: np.array,
    zero_division: int,
    true_answer: Tuple[float],
):
    """
    Test micro averaging of precision/recall/f1 metrics.

    Args:
        tp: true positive statistic
        fp: false positive statistic
        fn: false negative statistic
        support: support statistic
        zero_division: 0 or 1
        true_answer: true metric value
    """
    _, micro, _, _ = get_aggregated_metrics(
        tp=tp, fp=fp, fn=fn, support=support, zero_division=zero_division
    )
    # micro-averaged support slot is always None
    assert micro[-1] is None
    # remaining entries must match the expected values within tolerance
    for predicted, expected in zip(micro[:-1], true_answer):
        assert abs(predicted - expected) < EPS
def compute(self) -> Any:
    """
    Compute precision, recall, f1 score and support.
    Compute micro, macro and weighted average for the metrics.

    Returns:
        list of aggregated metrics: per-class, micro, macro and
            weighted averaging of precision, recall, f1 score and support metrics
    """
    # ddp hotfix, could be done better
    # but metric must handle DDP on it's own
    if self._ddp_backend == "xla":
        # gather per-process statistics through XLA and reduce on host
        device = get_device()
        for key in self.statistics:
            gathered = xm.all_gather(
                torch.tensor([self.statistics[key]], device=device)
            )
            self.statistics[key] = gathered.sum(dim=0).cpu().numpy()
    elif self._ddp_backend == "ddp":
        # gather numpy statistics from all processes and sum them up
        for key in self.statistics:
            pieces: List[np.ndarray] = all_gather(self.statistics[key])
            self.statistics[key] = np.sum(np.vstack(pieces), axis=0)

    return get_aggregated_metrics(
        tp=self.statistics["tp"],
        fp=self.statistics["fp"],
        fn=self.statistics["fn"],
        support=self.statistics["support"],
        zero_division=self.zero_division,
    )
def update(
    self, outputs: torch.Tensor, targets: torch.Tensor
) -> Tuple[Any, Any, Any, Any]:
    """
    Update statistics and return intermediate metrics results

    Args:
        outputs: prediction values
        targets: true answers

    Returns:
        tuple of metrics intermediate results with per-class, micro, macro and
            weighted averaging
    """
    # accumulate confusion statistics in the parent class
    tn, fp, fn, tp, support, num_classes = super().update(
        outputs=outputs, targets=targets
    )
    # remember the number of classes once it becomes known
    if self.num_classes is None:
        self.num_classes = num_classes
    return get_aggregated_metrics(
        tp=tp, fp=fp, fn=fn, support=support, zero_division=self.zero_division
    )
def compute(self) -> Any:
    """
    Compute precision, recall, f1 score and support.
    Compute micro, macro and weighted average for the metrics.

    Returns:
        list of aggregated metrics: per-class, micro, macro and
            weighted averaging of precision, recall, f1 score and support metrics
    """
    # aggregate the accumulated statistics into the four averaging schemes
    stats = self.statistics
    return get_aggregated_metrics(
        tp=stats["tp"],
        fp=stats["fp"],
        fn=stats["fn"],
        support=stats["support"],
        zero_division=self.zero_division,
    )