Example No. 1
    def compute(self) -> Tuple[torch.Tensor, float, float, float]:
        """Computes the AUC metric based on saved statistics."""
        targets = torch.cat(self.targets)
        scores = torch.cat(self.scores)

        # ddp hotfix, could be done better
        # but the metric must handle DDP on its own
        if self._ddp_backend == "xla":
            # if you have "RuntimeError: Aborted: Session XXX is not found" here
            # please ask Google for a more powerful TPU setup ;)
            device = get_device()
            scores = xm.all_gather(scores.to(device)).cpu().detach()
            targets = xm.all_gather(targets.to(device)).cpu().detach()
        elif self._ddp_backend == "ddp":
            scores = torch.cat(all_gather(scores))
            targets = torch.cat(all_gather(targets))

        scores, targets, _, _ = process_multilabel_components(
            outputs=scores, targets=targets
        )
        per_class = auc(scores=scores, targets=targets)
        micro = binary_auc(scores=scores.view(-1), targets=targets.view(-1))[0]
        macro = per_class.mean().item()
        weights = targets.sum(axis=0) / len(targets)
        weighted = (per_class * weights).sum().item()
        if self.compute_per_class_metrics:
            return per_class, micro, macro, weighted
        else:
            return [], micro, macro, weighted
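
The micro / macro / weighted aggregation at the end of this method can be checked by hand. A minimal standalone sketch with hypothetical per-class AUC values and class weights (the weights here play the role of targets.sum(0) / len(targets) above):

import torch

per_class = torch.tensor([0.90, 0.60, 0.75])   # hypothetical per-class AUCs
weights = torch.tensor([0.50, 0.30, 0.20])     # hypothetical per-class positive rates

macro = per_class.mean().item()                # 0.75 - every class counts equally
weighted = (per_class * weights).sum().item()  # 0.78 - frequent classes dominate
# "micro" is different in kind: it flattens scores/targets into one long vector and
# computes a single binary AUC over all (sample, class) entries, as binary_auc does above.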
Example No. 2
    def compute(self) -> torch.Tensor:
        """Computes the AUC metric based on saved statistics."""
        targets = torch.cat(self.targets)
        scores = torch.cat(self.scores)

        # @TODO: ddp hotfix, could be done better
        if self._is_ddp:
            scores = torch.cat(all_gather(scores))
            targets = torch.cat(all_gather(targets))

        score = auc(outputs=scores, targets=targets)
        return score
Example No. 3
    def compute(self) -> Any:
        """
        Returns:
            Confusion matrix of K rows and K columns, where rows correspond
            to ground-truth targets and columns correspond to predicted
            targets.
        """
        # ddp hotfix, could be done better
        # but the metric must handle DDP on its own
        if self._ddp_backend == "xla":
            # if you have "RuntimeError: Aborted: Session XXX is not found" here
            # please ask Google for a more powerful TPU setup ;)
            device = get_device()
            value = torch.tensor([self.conf], device=device)
            self.conf = xm.all_gather(value).sum(0).cpu().detach().numpy()
        elif self._ddp_backend == "ddp":
            value: List[np.ndarray] = all_gather(self.conf)
            value: np.ndarray = np.sum(np.stack(value, axis=0), axis=0)
            self.conf = value

        if self.normalized:
            conf = self.conf.astype(np.float32)
            return conf / conf.sum(1).clip(min=1e-12)[:, None]
        else:
            return self.conf
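
The row normalization in the last branch is compact enough to deserve a standalone illustration; a small sketch with a hypothetical 2x2 confusion matrix:

import numpy as np

conf = np.array([[8, 2], [1, 9]], dtype=np.float32)
# divide each row by its sum; the clip avoids division by zero for empty rows
normalized = conf / conf.sum(1).clip(min=1e-12)[:, None]
# each row now sums to 1: [[0.8, 0.2], [0.1, 0.9]]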
Example No. 4
    def compute(self) -> Tuple[float, float, float]:
        """
        Compute metrics with accumulated statistics

        Returns:
            tuple of metrics: precision, recall, f1 score
        """
        # ddp hotfix, could be done better
        # but the metric must handle DDP on its own
        if self._ddp_backend == "xla":
            self.statistics = {
                k: xm.mesh_reduce(k, v, np.sum)
                for k, v in self.statistics.items()
            }
        elif self._ddp_backend == "ddp":
            for key in self.statistics:
                value: List[int] = all_gather(self.statistics[key])
                value: int = sum(value)
                self.statistics[key] = value

        precision_value, recall_value, f1_value = get_binary_metrics(
            tp=self.statistics["tp"],
            fp=self.statistics["fp"],
            fn=self.statistics["fn"],
            zero_division=self.zero_division,
        )
        return precision_value, recall_value, f1_value
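
get_binary_metrics itself is not shown in this example. As a rough sketch of how precision, recall and F1 follow from tp/fp/fn counts (the library's exact zero_division semantics may differ):

def binary_metrics_sketch(tp, fp, fn, zero_division=0):
    # precision: fraction of predicted positives that are correct
    precision = tp / (tp + fp) if (tp + fp) > 0 else zero_division
    # recall: fraction of actual positives that were found
    recall = tp / (tp + fn) if (tp + fn) > 0 else zero_division
    # F1: harmonic mean of precision and recall
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom > 0 else zero_division
    return precision, recall, f1

print(binary_metrics_sketch(tp=8, fp=2, fn=4))  # (0.8, 0.666..., 0.727...)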
Example No. 5
    def compute(self) -> Any:
        """
        Compute precision, recall, f1 score and support.
        Compute micro, macro and weighted average for the metrics.

        Returns:
            list of aggregated metrics: per-class values plus micro, macro and
                weighted averages of precision, recall, f1 score and support
        """
        # ddp hotfix, could be done better
        # but the metric must handle DDP on its own
        if self._ddp_backend == "xla":
            device = get_device()
            for key in self.statistics:
                key_statistics = torch.tensor([self.statistics[key]], device=device)
                key_statistics = xm.all_gather(key_statistics).sum(dim=0).cpu().numpy()
                self.statistics[key] = key_statistics
        elif self._ddp_backend == "ddp":
            for key in self.statistics:
                value: List[np.ndarray] = all_gather(self.statistics[key])
                value: np.ndarray = np.sum(np.vstack(value), axis=0)
                self.statistics[key] = value

        per_class, micro, macro, weighted = get_aggregated_metrics(
            tp=self.statistics["tp"],
            fp=self.statistics["fp"],
            fn=self.statistics["fn"],
            support=self.statistics["support"],
            zero_division=self.zero_division,
        )
        return per_class, micro, macro, weighted
Example No. 6
    def compute(self) -> Tuple[torch.Tensor, float, float, float]:
        """Computes the AUC metric based on saved statistics."""
        targets = torch.cat(self.targets)
        scores = torch.cat(self.scores)

        # @TODO: ddp hotfix, could be done better
        if self._is_ddp:
            scores = torch.cat(all_gather(scores))
            targets = torch.cat(all_gather(targets))

        scores, targets, _ = process_multilabel_components(
            outputs=scores, targets=targets
        )
        per_class = auc(scores=scores, targets=targets)
        micro = binary_auc(scores=scores.view(-1), targets=targets.view(-1))[0]
        macro = per_class.mean().item()
        weights = targets.sum(axis=0) / len(targets)
        weighted = (per_class * weights).sum().item()
        return per_class, micro, macro, weighted
Example No. 7
def _all_gather(rank, world_size):
    _setup(rank, world_size)

    to_gather = torch.ones(3, dtype=torch.int) * (rank + 1)  # use cpu tensors
    actual = all_gather(to_gather)
    actual = torch.cat(actual)

    expected = torch.cat([torch.ones(3, dtype=torch.int) * (i + 1) for i in range(world_size)])

    assert torch.all(actual.eq(expected))

    _cleanup()
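
_setup and _cleanup are not part of the snippet. One plausible minimal implementation (assuming the gloo backend, a local master address, and that _all_gather above lives in the same module) that makes the test runnable with torch.multiprocessing:

import os

import torch.distributed as dist
import torch.multiprocessing as mp


def _setup(rank, world_size):
    # hypothetical helper: start a local process group so all_gather has peers to talk to
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)


def _cleanup():
    dist.destroy_process_group()


if __name__ == "__main__":
    # spawn calls _all_gather(rank, world_size=2) in each of the two worker processes
    mp.spawn(_all_gather, args=(2,), nprocs=2)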
Example No. 8
    def compute(self) -> Any:
        """
        Returns:
            Confusion matrix of K rows and K columns, where rows correspond
            to ground-truth targets and columns correspond to predicted
            targets.
        """
        if self._is_ddp:
            value: List[np.ndarray] = all_gather(self.conf)
            value: np.ndarray = np.sum(np.stack(value, axis=0), axis=0)
            self.conf = value
        if self.normalized:
            conf = self.conf.astype(np.float32)
            return conf / conf.sum(1).clip(min=1e-12)[:, None]
        else:
            return self.conf
Example No. 9
    def compute(self):
        """
        Compute metrics with accumulated statistics

        Returns:
            tuple of metrics: per_class, micro_metric, macro_metric,
                weighted_metric (None if weights is None)
        """
        per_class = []
        total_statistics = {}
        macro_metric = 0
        weighted_metric = 0
        # ddp hotfix, could be done better
        # but the metric must handle DDP on its own
        # TODO: optimise speed
        if self._ddp_backend == "xla":
            device = get_device()
            for _, statistics in self.statistics.items():
                for key in statistics:
                    value = torch.tensor([statistics[key]], device=device)
                    statistics[key] = xm.all_gather(value).sum(dim=0)
        elif self._ddp_backend == "ddp":
            for _, statistics in self.statistics.items():
                for key in statistics:
                    value: List[torch.Tensor] = all_gather(statistics[key])
                    value: torch.Tensor = torch.sum(torch.vstack(value), dim=0)
                    statistics[key] = value

        for class_idx, statistics in self.statistics.items():
            value = self.metric_fn(**statistics)
            per_class.append(value)
            macro_metric += value
            if self.weights is not None:
                weighted_metric += value * self.weights[class_idx]
            for stats_name, value in statistics.items():
                total_statistics[stats_name] = (
                    total_statistics.get(stats_name, 0) + value)

        macro_metric /= len(self.statistics)
        micro_metric = self.metric_fn(**total_statistics)

        if self.weights is None:
            weighted_metric = None
        if self.compute_per_class_metrics:
            return per_class, micro_metric, macro_metric, weighted_metric
        else:
            return [], micro_metric, macro_metric, weighted_metric
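
Returning both a micro and a macro value matters most for imbalanced data; a tiny, hypothetical example where metric_fn stands in for precision makes the difference concrete:

def precision(tp, fp):
    return tp / (tp + fp) if (tp + fp) > 0 else 0.0

# hypothetical per-class statistics: class 0 is common, class 1 is rare
stats = {0: {"tp": 90, "fp": 10}, 1: {"tp": 1, "fp": 9}}

per_class = [precision(**s) for s in stats.values()]                 # [0.9, 0.1]
macro = sum(per_class) / len(per_class)                              # 0.5 - classes weighted equally
totals = {k: sum(s[k] for s in stats.values()) for k in ("tp", "fp")}
micro = precision(**totals)                                          # 91/110 ~ 0.83 - dominated by class 0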
Example No. 10
    def compute_key_value(self) -> Dict[str, torch.Tensor]:
        """
        Compute segmentation metric for all data and return results in key-value format

        Returns:
             dict of metrics, including micro, macro and weighted (if weights were given) metrics
        """
        metrics = {}
        total_statistics = {}
        macro_metric = 0
        weighted_metric = 0

        # @TODO: ddp hotfix, could be done better
        if self._is_ddp:
            for _, statistics in self.statistics.items():
                for key in statistics:
                    device = statistics[key].device
                    value: List[torch.Tensor] = all_gather(statistics[key].cpu())
                    value: torch.Tensor = torch.sum(torch.vstack(value), dim=0).to(device)
                    statistics[key] = value

        base_key = f"{self.prefix}{self.metric_name}{self.suffix}"
        for class_idx, statistics in self.statistics.items():
            value = self.metric_fn(**statistics)
            macro_metric += value
            if self.weights is not None:
                weighted_metric += value * self.weights[class_idx]
            metrics[f"{base_key}/{self.class_names[class_idx]}"] = value
            for stats_name, value in statistics.items():
                total_statistics[stats_name] = (
                    total_statistics.get(stats_name, 0) + value
                )
        macro_metric /= len(self.statistics)
        micro_metric = self.metric_fn(**total_statistics)
        metrics[f"{base_key}/_micro"] = micro_metric
        metrics[base_key] = macro_metric
        metrics[f"{base_key}/_macro"] = macro_metric
        if self.weights is not None:
            metrics[f"{base_key}/_weighted"] = weighted_metric
        # convert torch.Tensor to float
        # metrics = {k: float(v) for k, v in metrics.items()}
        return metrics
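
For orientation, under hypothetical settings (prefix="", metric_name="iou", suffix="", class_names=["background", "road"], weights given) the returned dict would contain keys along these lines:

expected_keys = [
    "iou/background", "iou/road",  # per-class values
    "iou/_micro",                  # metric_fn applied to the summed statistics
    "iou", "iou/_macro",           # mean over classes (reported twice)
    "iou/_weighted",               # only present when weights are given
]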
Example No. 11
    def compute(self) -> Tuple[float, float, float]:
        """
        Compute metrics with accumulated statistics

        Returns:
            tuple of metrics: precision, recall, f1 score
        """
        # @TODO: ddp hotfix, could be done better
        if self._is_ddp:
            for key in self.statistics:
                value: List[float] = all_gather(self.statistics[key])
                value: float = sum(value)
                self.statistics[key] = value

        precision_value, recall_value, f1_value = get_binary_metrics(
            tp=self.statistics["tp"],
            fp=self.statistics["fp"],
            fn=self.statistics["fn"],
            zero_division=self.zero_division,
        )
        return precision_value, recall_value, f1_value
Example No. 12
    def compute_key_value(self) -> Dict[str, float]:
        """
        Compute precision, recall, f1 score and support.
        Compute micro, macro and weighted average for the metrics.

        Returns:
            dict of metrics
        """
        # @TODO: ddp hotfix, could be done better
        if self._is_ddp:
            for key in self.statistics:
                value: List[np.ndarray] = all_gather(self.statistics[key])
                value: np.ndarray = np.sum(np.vstack(value), axis=0)
                self.statistics[key] = value

        per_class, micro, macro, weighted = self.compute()
        metrics = self._convert_metrics_to_kv(
            per_class=per_class, micro=micro, macro=macro, weighted=weighted
        )
        return metrics