def compute(
    self,
) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
    """Return the precision-recall curve over all accumulated batches.

    The stored per-batch ``preds`` and ``target`` tensors are concatenated
    along dimension 0 before the curve is computed.

    Returns:
        A 3-tuple ``(precision, recall, thresholds)``:

        - ``precision``: element ``i`` is the precision of predictions with
          ``score >= thresholds[i]``; the final element is 1. For multiclass
          input this is a list holding one such tensor per class.
        - ``recall``: element ``i`` is the recall of predictions with
          ``score >= thresholds[i]``; the final element is 0. For multiclass
          input this is a list holding one such tensor per class.
        - ``thresholds``: the thresholds used to compute precision and recall.
    """
    stacked_preds = torch.cat(self.preds, dim=0)
    stacked_target = torch.cat(self.target, dim=0)
    return _precision_recall_curve_compute(
        stacked_preds,
        stacked_target,
        self.num_classes,
        self.pos_label,
    )
def compute(
    self,
) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
    """Compute the precision-recall curve.

    The accumulated ``preds`` and ``target`` states are concatenated with
    ``dim_zero_cat`` and then passed to ``_precision_recall_curve_compute``.

    Returns:
        3-element tuple containing

        precision:
            tensor where element ``i`` is the precision of predictions with
            ``score >= thresholds[i]`` and the last element is 1.
            If multiclass, this is a list of such tensors, one for each class.
        recall:
            tensor where element ``i`` is the recall of predictions with
            ``score >= thresholds[i]`` and the last element is 0.
            If multiclass, this is a list of such tensors, one for each class.
        thresholds:
            Thresholds used for computing precision/recall scores

    Raises:
        ValueError: if ``self.num_classes`` is unset (``None``/0) or not positive.
    """
    # Validate configuration before doing any (potentially large) concatenation,
    # and make the check match the message: negative values are rejected too.
    if not self.num_classes or self.num_classes < 1:
        raise ValueError(
            f"`num_classes` has to be a positive number, but got {self.num_classes}"
        )
    preds = dim_zero_cat(self.preds)
    target = dim_zero_cat(self.target)
    return _precision_recall_curve_compute(preds, target, self.num_classes, self.pos_label)
def _average_precision_compute(
    preds: Tensor,
    target: Tensor,
    num_classes: int,
    pos_label: Optional[int] = None,
    average: Optional[str] = "macro",
    sample_weights: Optional[Sequence] = None,
) -> Union[List[Tensor], Tensor]:
    """Computes the average precision score.

    The precision-recall curve is built first. It is then reduced to average
    precision by ``_average_precision_compute_with_precision_recall``, which
    also receives ``average`` and the optional per-class weights.

    Args:
        preds: predictions from model (logits or probabilities)
        target: ground truth values
        num_classes: integer with number of classes.
        pos_label: integer determining the positive class. Default is ``None``
            which for binary problem is translated to 1. For multiclass problems
            this argument should not be set as we iteratively change it in the
            range ``[0, num_classes-1]``
        average: reduction method for multi-class or multi-label problems.
            With ``"weighted"``, per-class weights are derived from the class
            support in ``target`` and normalised to sum to 1. Any other value is
            forwarded unchanged with no weights.
        sample_weights: sample weights for each data point (currently unused)

    Example:
        >>> # binary case
        >>> preds = torch.tensor([0, 1, 2, 3])
        >>> target = torch.tensor([0, 1, 1, 1])
        >>> pos_label = 1
        >>> preds, target, num_classes, pos_label = _average_precision_update(preds, target, pos_label=pos_label)
        >>> _average_precision_compute(preds, target, num_classes, pos_label)
        tensor(1.)

        >>> # multiclass case
        >>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
        ...                       [0.05, 0.75, 0.05, 0.05, 0.05],
        ...                       [0.05, 0.05, 0.75, 0.05, 0.05],
        ...                       [0.05, 0.05, 0.05, 0.75, 0.05]])
        >>> target = torch.tensor([0, 1, 3, 2])
        >>> num_classes = 5
        >>> preds, target, num_classes, pos_label = _average_precision_update(preds, target, num_classes)
        >>> _average_precision_compute(preds, target, num_classes, average=None)
        [tensor(1.), tensor(1.), tensor(0.2500), tensor(0.2500), tensor(nan)]
    """
    # TODO: `sample_weights` is accepted but not used yet.
    precision, recall, _ = _precision_recall_curve_compute(preds, target, num_classes, pos_label)

    class_weights = None
    if average == "weighted":
        # preds/target with matching rank > 1 is treated as multilabel:
        # support is the column sum of target. Otherwise it is counted per label.
        multilabel = preds.ndim == target.ndim and target.ndim > 1
        if multilabel:
            support = target.sum(dim=0).float()
        else:
            support = _bincount(target, minlength=num_classes).float()
        class_weights = support / torch.sum(support)

    return _average_precision_compute_with_precision_recall(
        precision,
        recall,
        num_classes,
        average,
        class_weights,
    )
def _average_precision_compute(
    preds: Tensor,
    target: Tensor,
    num_classes: int,
    pos_label: int,
    sample_weights: Optional[Sequence] = None,
) -> Union[List[Tensor], Tensor]:
    """Compute average precision from predictions and targets.

    Builds the precision-recall curve with ``_precision_recall_curve_compute``.
    The resulting precision and recall are then handed to
    ``_average_precision_compute_with_precision_recall``, and its result is
    returned unchanged.

    Args:
        preds: predictions from the model.
        target: ground-truth values.
        num_classes: number of classes.
        pos_label: label treated as the positive class.
        sample_weights: per-sample weights; currently ignored.

    Returns:
        Whatever ``_average_precision_compute_with_precision_recall`` returns:
        a tensor or a list of tensors.
    """
    # TODO: `sample_weights` is accepted but not used yet.
    curve_precision, curve_recall, _thresholds = _precision_recall_curve_compute(
        preds, target, num_classes, pos_label
    )
    return _average_precision_compute_with_precision_recall(
        curve_precision, curve_recall, num_classes
    )
def _average_precision_compute(
    preds: Tensor,
    target: Tensor,
    num_classes: int,
    pos_label: int,
    sample_weights: Optional[Sequence] = None,
) -> Union[List[Tensor], Tensor]:
    """Compute average precision as the step-function integral of the PR curve.

    For each curve, this evaluates ``-sum((recall[1:] - recall[:-1]) * precision[:-1])``.

    Args:
        preds: predictions from the model.
        target: ground-truth values.
        num_classes: number of classes. ``1`` selects the single-curve path.
        pos_label: label treated as the positive class.
        sample_weights: per-sample weights; currently ignored.

    Returns:
        A scalar tensor when ``num_classes == 1``. Otherwise a list with one
        scalar tensor per (precision, recall) curve pair.
    """
    # TODO: `sample_weights` is accepted but not used yet.
    precision, recall, _ = _precision_recall_curve_compute(preds, target, num_classes, pos_label)

    def _step_integral(prec: Tensor, rec: Tensor) -> Tensor:
        # The following works because the last entry of precision is
        # guaranteed to be 1, as returned by precision_recall_curve.
        # The leading minus presumes recall is non-increasing along the curve.
        return -torch.sum((rec[1:] - rec[:-1]) * prec[:-1])

    if num_classes == 1:
        return _step_integral(precision, recall)
    return [_step_integral(p, r) for p, r in zip(precision, recall)]