Example #1
0
    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """
        Accumulate true-positive and predicted-positive counts for one batch.

        Both inputs are first normalized by
        ``_input_format_classification_one_hot`` using this metric's
        ``num_classes``, ``threshold`` and ``multilabel`` settings.

        Args:
            preds: Predictions from model
            target: Ground truth values
        """
        # NOTE(review): the helper presumably returns binary tensors laid out
        # so that dim 1 runs over samples -- confirm against its definition.
        hot_preds, hot_target = _input_format_classification_one_hot(
            self.num_classes,
            preds,
            target,
            self.threshold,
            self.multilabel,
        )

        # An element-wise product is non-zero only where prediction and
        # target are both set, i.e. a (1, 1) pair == true positive.
        matched = hot_preds * hot_target

        # In-place accumulation into the running state, reduced over dim 1.
        self.true_positives += matched.sum(dim=1)
        self.predicted_positives += hot_preds.sum(dim=1)
Example #2
0
def _fbeta_update(
    preds: torch.Tensor,
    target: torch.Tensor,
    num_classes: int,
    threshold: float = 0.5,
    multilabel: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Compute the three count tensors needed for an F-beta style score.

    The inputs are normalized by ``_input_format_classification_one_hot``
    and then reduced over dim 1.

    Args:
        preds: Predictions from model
        target: Ground truth values
        num_classes: Forwarded to the one-hot formatting helper
        threshold: Forwarded to the one-hot formatting helper
        multilabel: Forwarded to the one-hot formatting helper

    Returns:
        A tuple ``(true_positives, predicted_positives, actual_positives)``,
        each obtained by summing over dim 1 of the formatted tensors.
    """
    # NOTE(review): the helper presumably yields binary tensors with samples
    # along dim 1 -- confirm against its definition.
    hot_preds, hot_target = _input_format_classification_one_hot(
        num_classes, preds, target, threshold, multilabel
    )

    return (
        # product is non-zero only where both prediction and target are set
        (hot_preds * hot_target).sum(dim=1),
        hot_preds.sum(dim=1),
        hot_target.sum(dim=1),
    )
Example #3
0
    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """
        Accumulate true-positive and actual-positive counts for one batch.

        Both inputs are first normalized by
        ``_input_format_classification_one_hot`` using this metric's
        ``num_classes``, ``threshold`` and ``multilabel`` settings.

        Args:
            preds: Predictions from model
            target: Ground truth values
        """
        # NOTE(review): the helper presumably returns binary tensors laid out
        # so that dim 1 runs over samples -- confirm against its definition.
        hot_preds, hot_target = _input_format_classification_one_hot(
            self.num_classes,
            preds,
            target,
            self.threshold,
            self.multilabel,
        )

        # An element-wise product is non-zero only where prediction and
        # target are both set, i.e. a (1, 1) pair == true positive.
        matched = hot_preds * hot_target

        # In-place accumulation into the running state, reduced over dim 1.
        self.true_positives += matched.sum(dim=1)
        self.actual_positives += hot_target.sum(dim=1)