Example #1
from typing import Tuple

import torch
from torch import Tensor
from torchmetrics.utilities.data import to_onehot


def _input_format_classification_one_hot(
    num_classes: int,
    preds: Tensor,
    target: Tensor,
    threshold: float = 0.5,
    multilabel: bool = False,
) -> Tuple[Tensor, Tensor]:
    """Convert preds and target tensors into one hot spare label tensors

    Args:
        num_classes: number of classes
        preds: either tensor with labels, tensor with probabilities/logits or multilabel tensor
        target: tensor with ground truth labels
        threshold: float used for thresholding multilabel input
        multilabel: boolean flag indicating if input is multilabel

    Raises:
        ValueError:
            If ``preds`` has neither the same number of dimensions as ``target``
            nor exactly one additional dimension.

    Returns:
        preds: one-hot tensor of shape [num_classes, -1] with predicted labels
        target: one-hot tensor of shape [num_classes, -1] with true labels
    """
    if preds.ndim not in (target.ndim, target.ndim + 1):
        raise ValueError(
            "preds and target must have same number of dimensions, or one additional dimension for preds"
        )

    if preds.ndim == target.ndim + 1:
        # multi class probabilities
        preds = torch.argmax(preds, dim=1)

    if (preds.ndim == target.ndim and preds.dtype in (torch.long, torch.int)
            and num_classes > 1 and not multilabel):
        # multi-class
        preds = to_onehot(preds, num_classes=num_classes)
        target = to_onehot(target, num_classes=num_classes)

    elif preds.ndim == target.ndim and preds.is_floating_point():
        # binary or multilabel probabilities
        preds = (preds >= threshold).long()

    # transpose class as first dim and reshape
    if preds.ndim > 1:
        preds = preds.transpose(1, 0)
        target = target.transpose(1, 0)

    return preds.reshape(num_classes, -1), target.reshape(num_classes, -1)
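
A quick sanity check of the helper above, using invented inputs (assumes the function and imports from the snippet):

preds = torch.tensor([0, 2, 1, 2])
target = torch.tensor([0, 1, 1, 2])
p, t = _input_format_classification_one_hot(num_classes=3, preds=preds, target=target)
print(p.shape, t.shape)  # torch.Size([3, 4]) torch.Size([3, 4])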
Example #2
    def update(self, preds: torch.Tensor, targets: torch.Tensor):  # type: ignore
        if self.exclude_neutral:
            # Drop neutral targets (zeros) before binarizing both tensors by sign.
            pp = (preds[targets != 0] >= 0).int()
            tt = (targets[targets != 0] >= 0).int()
        else:
            pp = (preds >= 0).int()
            tt = (targets >= 0).int()

        pp = to_onehot(pp, num_classes=2).transpose(1, 0).reshape(2, -1)
        tt = to_onehot(tt, num_classes=2).transpose(1, 0).reshape(2, -1)

        true_positives = torch.sum(pp * tt, dim=1)
        predicted_positives = torch.sum(pp, dim=1)
        actual_positives = torch.sum(tt, dim=1)

        self.true_positives += true_positives
        self.predicted_positives += predicted_positives
        self.actual_positives += actual_positives
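
The sign-threshold-plus-one-hot trick in update can be exercised outside the class; a minimal sketch with invented tensors (assumes torchmetrics' to_onehot):

import torch
from torchmetrics.utilities.data import to_onehot

preds = torch.tensor([0.7, -0.2, 0.1, -0.9])    # scores; non-negative means the positive class
targets = torch.tensor([1.0, -1.0, -1.0, 1.0])  # signed labels (0 would be neutral)

pp = to_onehot((preds >= 0).int(), num_classes=2).transpose(1, 0).reshape(2, -1)
tt = to_onehot((targets >= 0).int(), num_classes=2).transpose(1, 0).reshape(2, -1)
print(torch.sum(pp * tt, dim=1))  # per-class true positives -> [1, 1]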
Example #3
def test_onehot():
    test_tensor = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
    expected = torch.stack([
        torch.cat([torch.eye(5, dtype=int), torch.zeros((5, 5), dtype=int)]),
        torch.cat([torch.zeros((5, 5), dtype=int), torch.eye(5, dtype=int)])
    ])

    assert test_tensor.shape == (2, 5)
    assert expected.shape == (2, 10, 5)

    onehot_classes = to_onehot(test_tensor, num_classes=10)
    onehot_no_classes = to_onehot(test_tensor)

    assert torch.allclose(onehot_classes, onehot_no_classes)

    assert onehot_classes.shape == expected.shape
    assert onehot_no_classes.shape == expected.shape

    assert torch.allclose(expected.to(onehot_no_classes), onehot_no_classes)
    assert torch.allclose(expected.to(onehot_classes), onehot_classes)
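
Note how to_onehot inserts the class dimension at dim 1 rather than appending it, which is why the [2, 5] label tensor above one-hots to [2, 10, 5]. A standalone check (assuming torchmetrics' to_onehot):

import torch
from torchmetrics.utilities.data import to_onehot

labels = torch.tensor([[0, 1], [2, 3]])        # shape [2, 2]
print(to_onehot(labels, num_classes=4).shape)  # torch.Size([2, 4, 2])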
Example #4
def _hinge_update(
    preds: Tensor,
    target: Tensor,
    squared: bool = False,
    multiclass_mode: Optional[Union[str, MulticlassMode]] = None,
) -> Tuple[Tensor, Tensor]:
    """Updates and returns sum over Hinge loss scores for each observation and the total number of observations.

    Args:
        preds: Predicted tensor
        target: Ground truth tensor
        squared: If True, this will compute the squared hinge loss. Otherwise, computes the regular hinge loss.
        multiclass_mode:
            Which approach to use for multi-class inputs (has no effect in the binary case). ``None`` (default),
            ``MulticlassMode.CRAMMER_SINGER`` or ``"crammer-singer"``, uses the Crammer Singer multi-class hinge loss.
            ``MulticlassMode.ONE_VS_ALL`` or ``"one-vs-all"`` computes the hinge loss in a one-vs-all fashion.
    """
    preds, target = _input_squeeze(preds, target)

    mode = _check_shape_and_type_consistency_hinge(preds, target)

    if mode == DataType.MULTICLASS:
        target = to_onehot(target, max(2, preds.shape[1])).bool()

    if mode == DataType.MULTICLASS and (multiclass_mode is None
                                        or multiclass_mode == MulticlassMode.CRAMMER_SINGER):
        margin = preds[target]
        margin -= torch.max(preds[~target].view(preds.shape[0], -1), dim=1)[0]
    elif mode == DataType.BINARY or multiclass_mode == MulticlassMode.ONE_VS_ALL:
        target = target.bool()
        margin = torch.zeros_like(preds)
        margin[target] = preds[target]
        margin[~target] = -preds[~target]
    else:
        raise ValueError(
            "The `multiclass_mode` should be either None / 'crammer-singer' / MulticlassMode.CRAMMER_SINGER"
            " (default) or 'one-vs-all' / MulticlassMode.ONE_VS_ALL,"
            f" got {multiclass_mode}.")

    measures = 1 - margin
    measures = torch.clamp(measures, 0)

    if squared:
        measures = measures.pow(2)

    total = tensor(target.shape[0], device=target.device)
    return measures.sum(dim=0), total
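
The Crammer-Singer branch compares the true-class score against the best wrong-class score. That margin step can be reproduced in isolation with invented inputs (torchmetrics' to_onehot assumed):

import torch
from torchmetrics.utilities.data import to_onehot

preds = torch.tensor([[2.0, 0.5, -1.0],
                      [0.1, 0.3, 0.2]])
target = to_onehot(torch.tensor([0, 2]), num_classes=3).bool()

margin = preds[target] - torch.max(preds[~target].view(preds.shape[0], -1), dim=1)[0]
print(torch.clamp(1 - margin, min=0))  # per-sample hinge losses: tensor([0.0000, 1.1000])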
Example #5
def _hinge_update(
    preds: Tensor,
    target: Tensor,
    squared: bool = False,
    multiclass_mode: Optional[Union[str, MulticlassMode]] = None,
) -> Tuple[Tensor, Tensor]:
    if preds.shape[0] == 1:
        preds, target = preds.squeeze().unsqueeze(0), target.squeeze().unsqueeze(0)
    else:
        preds, target = preds.squeeze(), target.squeeze()

    mode = _check_shape_and_type_consistency_hinge(preds, target)

    if mode == DataType.MULTICLASS:
        target = to_onehot(target, max(2, preds.shape[1])).bool()

    if mode == DataType.MULTICLASS and (multiclass_mode is None
                                        or multiclass_mode == MulticlassMode.CRAMMER_SINGER):
        margin = preds[target]
        margin -= torch.max(preds[~target].view(preds.shape[0], -1), dim=1)[0]
    elif mode == DataType.BINARY or multiclass_mode == MulticlassMode.ONE_VS_ALL:
        target = target.bool()
        margin = torch.zeros_like(preds)
        margin[target] = preds[target]
        margin[~target] = -preds[~target]
    else:
        raise ValueError(
            "The `multiclass_mode` should be either None / 'crammer-singer' / MulticlassMode.CRAMMER_SINGER"
            " (default) or 'one-vs-all' / MulticlassMode.ONE_VS_ALL,"
            f" got {multiclass_mode}.")

    measures = 1 - margin
    measures = torch.clamp(measures, 0)

    if squared:
        measures = measures.pow(2)

    total = tensor(target.shape[0], device=target.device)
    return measures.sum(dim=0), total
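
This variant differs from Example #4 only in inlining the squeeze step. For contrast with the Crammer-Singer sketch above, the one-vs-all branch in isolation (same invented inputs, torchmetrics' to_onehot assumed):

import torch
from torchmetrics.utilities.data import to_onehot

preds = torch.tensor([[2.0, 0.5, -1.0],
                      [0.1, 0.3, 0.2]])
target = to_onehot(torch.tensor([0, 2]), num_classes=3).bool()

margin = torch.zeros_like(preds)
margin[target] = preds[target]
margin[~target] = -preds[~target]
print(torch.clamp(1 - margin, min=0).sum(dim=0))  # per-class sums: tensor([1.1000, 2.8000, 0.8000])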
Example #6
    def update(self, preds: Tensor, target: Tensor) -> None:  # type: ignore
        """
        Args
            preds: (n_samples, n_classes) tensor
            target: (n_samples, n_classes) tensor
        """
        # binary case
        if len(preds.shape) == len(target.shape) == 1:
            preds = preds.reshape(-1, 1)
            target = target.reshape(-1, 1)

        if len(preds.shape) == len(target.shape) + 1:
            target = to_onehot(target, num_classes=self.num_classes)

        target = target == 1
        # Iterate one threshold at a time to conserve memory
        for i in range(self.num_thresholds):
            predictions = preds >= self.thresholds[i]
            self.TPs[:, i] += (target & predictions).sum(dim=0)
            self.FPs[:, i] += (~target & predictions).sum(dim=0)
            self.FNs[:, i] += (target & ~predictions).sum(dim=0)
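
The per-threshold counting loop can be sketched without the metric-class state (names and values invented):

import torch

preds = torch.tensor([[0.9, 0.2], [0.4, 0.7]])  # (n_samples, n_classes) probabilities
target = torch.tensor([[1, 0], [0, 1]]) == 1
for t in torch.linspace(0, 1, steps=5):
    predictions = preds >= t
    print(float(t), (target & predictions).sum(dim=0).tolist())  # per-class true positives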
Example #7
def _onehot2(x):
    return to_onehot(x, 2)
Example #8
def _onehot(x):
    return to_onehot(x, NUM_CLASSES)
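
Both one-liners are typically bound as label-preprocessing callables in parametrized tests; a hypothetical call, assuming the definitions above:

import torch

print(_onehot2(torch.tensor([0, 1, 1, 0])))
# tensor([[1, 0],
#         [0, 1],
#         [0, 1],
#         [1, 0]])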
Example #9
    )

    if case in (DataType.BINARY, DataType.MULTILABEL) and not top_k:
        preds = (preds >= threshold).int()
        num_classes = num_classes if not multiclass else 2

    if case == DataType.MULTILABEL and top_k:
        preds = select_topk(preds, top_k)

    if case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS) or multiclass:
        if preds.is_floating_point():
            num_classes = preds.shape[1]
            preds = select_topk(preds, top_k or 1)
        else:
            num_classes = num_classes if num_classes else max(preds.max(), target.max()) + 1
            preds = to_onehot(preds, max(2, num_classes))

        target = to_onehot(target, max(2, num_classes))

        if multiclass is False:
            preds, target = preds[:, 1, ...], target[:, 1, ...]

    if (case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS) and multiclass is not False) or multiclass:
        target = target.reshape(target.shape[0], target.shape[1], -1)
        preds = preds.reshape(preds.shape[0], preds.shape[1], -1)
    else:
        target = target.reshape(target.shape[0], -1)
        preds = preds.reshape(preds.shape[0], -1)

    # Some operations above create an extra dimension for MC/binary case - this removes it
    if preds.ndim > 2:
        preds, target = preds.squeeze(-1), target.squeeze(-1)
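
The select_topk call used above turns a probability matrix into a binary top-k mask per row; a quick check (assuming torchmetrics' select_topk):

import torch
from torchmetrics.utilities.data import select_topk

probs = torch.tensor([[0.1, 0.7, 0.2],
                      [0.5, 0.2, 0.3]])
print(select_topk(probs, topk=2))
# tensor([[0, 1, 1],
#         [1, 0, 1]], dtype=torch.int32)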