def compute(self) -> Tuple[torch.Tensor, float, float, float]:
    """Computes the AUC metric based on saved statistics.

    The accumulated ``self.targets`` and ``self.scores`` chunks are
    concatenated, gathered across workers when ``self._ddp_backend`` is
    ``"xla"`` or ``"ddp"``, and reduced to per-class, micro, macro and
    weighted AUC values.

    Returns:
        Tuple of ``(per_class, micro, macro, weighted)``. ``per_class`` is the
        result of ``auc`` when ``self.compute_per_class_metrics`` is truthy;
        otherwise an empty list is returned in its place.
    """
    targets = torch.cat(self.targets)
    scores = torch.cat(self.scores)

    # ddp hotfix, could be done better
    # but metric must handle DDP on its own
    if self._ddp_backend == "xla":
        # if you have "RuntimeError: Aborted: Session XXX is not found" here
        # please, ask Google for a more powerful TPU setup ;)
        device = get_device()
        scores = xm.all_gather(scores.to(device)).cpu().detach()
        targets = xm.all_gather(targets.to(device)).cpu().detach()
    elif self._ddp_backend == "ddp":
        scores = torch.cat(all_gather(scores))
        targets = torch.cat(all_gather(targets))

    # Only the first two components are used here. Star-unpacking keeps the
    # call working whether the helper returns three values (as the other
    # callers in this file unpack it) or four.
    scores, targets, *_ = process_multilabel_components(
        outputs=scores, targets=targets
    )

    per_class = auc(scores=scores, targets=targets)
    # micro: every (sample, class) pair is flattened into one binary problem
    micro = binary_auc(scores=scores.view(-1), targets=targets.view(-1))[0]
    # macro: plain mean of the per-class values
    macro = per_class.mean().item()
    # weighted: per-class values scaled by the column-wise sum of targets
    # divided by the number of rows
    weights = targets.sum(axis=0) / len(targets)
    weighted = (per_class * weights).sum().item()

    if self.compute_per_class_metrics:
        return per_class, micro, macro, weighted
    return [], micro, macro, weighted
def compute(self) -> Tuple[torch.Tensor, float, float, float]:
    """Reduce the stored targets and scores to AUC statistics.

    Returns:
        Tuple of ``(per_class, micro, macro, weighted)`` AUC values.
    """
    all_targets = torch.cat(self.targets)
    all_scores = torch.cat(self.scores)

    if self._is_ddp:
        # @TODO: ddp hotfix, could be done better
        gathered_scores = all_gather(all_scores)
        all_scores = torch.cat(gathered_scores)
        gathered_targets = all_gather(all_targets)
        all_targets = torch.cat(gathered_targets)

    components = process_multilabel_components(
        outputs=all_scores, targets=all_targets
    )
    all_scores, all_targets, _ = components

    per_class = auc(scores=all_scores, targets=all_targets)

    # micro: all (sample, class) pairs flattened into a single binary problem
    flat_result = binary_auc(
        scores=all_scores.view(-1), targets=all_targets.view(-1)
    )
    micro = flat_result[0]

    # macro: plain mean over the per-class values
    macro = per_class.mean().item()

    # weighted: per-class values scaled by column sums of the targets
    # divided by the number of rows
    class_weights = all_targets.sum(axis=0) / len(all_targets)
    weighted = (per_class * class_weights).sum().item()

    return per_class, micro, macro, weighted
def binary_average_precision(
    outputs: torch.Tensor,
    targets: torch.Tensor,
    weights: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Computes the average precision.

    Args:
        outputs: NxK tensor that for each of the N examples
            indicates the probability of the example belonging to each of
            the K classes, according to the model.
        targets: binary NxK tensor that encodes which of the K
            classes are associated with the N-th input
            (eg: a row [0, 1, 0, 1] indicates that the example is
            associated with classes 2 and 4)
        weights: importance for each sample

    Returns:
        torch.Tensor: tensor of [K; ] shape,
        with average precision for K classes.
        A single-element zero tensor is returned for empty ``outputs``.

    Examples:
        >>> binary_average_precision(
        >>>     outputs=torch.Tensor([0.1, 0.4, 0.35, 0.8]),
        >>>     targets=torch.Tensor([0, 0, 1, 1]),
        >>> )
        tensor([0.8333])
    """
    # outputs - [bs; num_classes] with scores
    # targets - [bs; num_classes] with binary labels
    outputs, targets, weights = process_multilabel_components(
        outputs=outputs, targets=targets, weights=weights,
    )
    if outputs.numel() == 0:
        return torch.zeros(1)

    num_samples = targets.size(0)
    num_classes = targets.size(1)
    ap = torch.zeros(num_classes)

    # In the unweighted case the denominators are just the rank positions
    # 1..N, identical for every class: build them once. They are created on
    # the targets' device so that dividing the cumulative true positives by
    # them does not fail with a device mismatch for non-CPU inputs.
    ranks = None
    if weights is None:
        ranks = torch.arange(
            1, num_samples + 1, device=targets.device
        ).float()

    # compute average precision for each class
    for class_i in range(num_classes):
        # sort scores in descending order and reorder the labels accordingly
        class_scores = outputs[:, class_i]
        class_targets = targets[:, class_i]
        _, sortind = torch.sort(class_scores, dim=0, descending=True)
        correct = class_targets[sortind]

        # compute true positive sums
        if weights is not None:
            class_weight = weights[sortind]
            weighted_correct = correct.float() * class_weight
            tp = weighted_correct.cumsum(0)
            rg = class_weight.cumsum(0)
        else:
            tp = correct.float().cumsum(0)
            rg = ranks

        # compute precision curve
        precision = tp.div(rg)

        # compute average precision: mean of the precision values at the
        # positions of positive labels; the max(..., 1) guards against a
        # division by zero when the class has no positives
        ap[class_i] = precision[correct.bool()].sum() / max(float(correct.sum()), 1)

    return ap