Beispiel #1
0
def test_incorrect_inputs(preds, target, num_classes, is_multiclass):
    """Expect ``_input_format_classification`` to raise ``ValueError`` for these inputs."""
    call_kwargs = {
        "preds": preds,
        "target": target,
        "threshold": THRESHOLD,
        "num_classes": num_classes,
        "is_multiclass": is_multiclass,
    }

    with pytest.raises(ValueError):
        _input_format_classification(**call_kwargs)
Beispiel #2
0
def test_usual_cases(inputs, num_classes, is_multiclass, top_k, exp_mode,
                     post_preds, post_target):
    """Check the formatted outputs and detected mode for valid inputs.

    The check runs twice: once comparing the returned mode against ``exp_mode``
    as given, and once against the matching ``DataType`` member. Each pass
    covers the first batch of ``inputs`` and a batch of size one cut from it.

    Args:
        inputs: Object exposing ``preds`` and ``target``; only index ``0`` of each is used.
        num_classes: Forwarded to ``_input_format_classification``.
        is_multiclass: Forwarded to ``_input_format_classification``.
        top_k: Forwarded to ``_input_format_classification``.
        exp_mode: Expected mode, comparable with ``DataType`` members.
        post_preds: Callable producing the expected formatted predictions.
        post_target: Callable producing the expected formatted targets.
    """
    def _as_data_type(mode_name):
        # Iterate the enum members themselves. Looking names up through
        # ``dir(DataType)`` only works while member names happen to sort before
        # the dunder entries, and fails with ``KeyError`` for an unknown mode.
        return next(member for member in DataType if member == mode_name)

    # Selectors are re-applied for every use so each call sees freshly indexed data.
    selectors = (
        lambda batch: batch,            # whole first batch
        lambda batch: batch[[0], ...],  # batch_size = 1
    )

    for expected_mode in (exp_mode, _as_data_type(exp_mode)):
        for select in selectors:
            preds_out, target_out, mode = _input_format_classification(
                preds=select(inputs.preds[0]),
                target=select(inputs.target[0]),
                threshold=THRESHOLD,
                num_classes=num_classes,
                is_multiclass=is_multiclass,
                top_k=top_k,
            )

            assert mode == expected_mode
            assert torch.equal(preds_out,
                               post_preds(select(inputs.preds[0])).int())
            assert torch.equal(target_out,
                               post_target(select(inputs.target[0])).int())
Beispiel #3
0
def test_incorrect_inputs_topk(preds, target, num_classes, multiclass, top_k):
    """Expect ``_input_format_classification`` to raise ``ValueError`` for these ``top_k`` inputs."""
    call_kwargs = {
        "preds": preds,
        "target": target,
        "threshold": THRESHOLD,
        "num_classes": num_classes,
        "multiclass": multiclass,
        "top_k": top_k,
    }

    with pytest.raises(ValueError):
        _input_format_classification(**call_kwargs)
Beispiel #4
0
def test_threshold():
    """Check binarization just below, exactly at, and just above a 0.5 threshold."""
    labels = T([1, 1, 1]).int()
    probabilities = T([0.5 - 1e-5, 0.5, 0.5 + 1e-5])

    formatted_preds, _, _ = _input_format_classification(probabilities, labels, threshold=0.5)

    # The value exactly at the threshold is expected to map to 1.
    expected = tensor([0, 1, 1], dtype=torch.int)
    assert torch.equal(expected, formatted_preds.squeeze().int())
Beispiel #5
0
def _sk_stat_scores(preds, target, reduce, num_classes, is_multiclass, ignore_index, top_k, mdmc_reduce=None):
    """Build reference stat scores from ``multilabel_confusion_matrix``.

    The inputs are formatted with ``_input_format_classification``, the
    confusion-matrix cells are reordered with the column index ``[3, 1, 0, 2]``
    and a fifth column (sum of reordered columns 3 and 0) is appended.
    ``mdmc_reduce`` is accepted but not used here.
    """
    preds, target, _ = _input_format_classification(
        preds, target, threshold=THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass, top_k=top_k
    )
    n_columns = preds.shape[1]
    single_column = n_columns == 1
    samplewise = reduce == "samples"
    has_ignore = ignore_index is not None

    sk_preds = preds.numpy()
    sk_target = target.numpy()

    # Drop the ignored column up front unless per-class ("macro") output is requested.
    if has_ignore and reduce != "macro" and n_columns > 1:
        sk_preds = np.delete(sk_preds, ignore_index, 1)
        sk_target = np.delete(sk_target, ignore_index, 1)

    if single_column and samplewise:
        sk_target, sk_preds = sk_target.T, sk_preds.T

    confmat = multilabel_confusion_matrix(
        sk_target, sk_preds, samplewise=samplewise and not single_column
    )

    # With a single column and no samplewise reduction only entry 1 is kept.
    if single_column and not samplewise:
        confmat = confmat[[1]]
    sk_stats = confmat.reshape(-1, 4)[:, [3, 1, 0, 2]]

    micro = reduce == "micro"
    if micro:
        sk_stats = sk_stats.sum(axis=0, keepdims=True)

    extra_column = sk_stats[:, [3]] + sk_stats[:, [0]]
    sk_stats = np.concatenate([sk_stats, extra_column], 1)

    if micro:
        sk_stats = sk_stats[0]

    # For "macro" the ignored class row is kept but flagged with -1.
    if reduce == "macro" and has_ignore and n_columns:
        sk_stats[ignore_index, :] = -1

    return sk_stats
Beispiel #6
0
def _sk_stat_scores_mdim_mcls(preds, target, reduce, mdmc_reduce, num_classes,
                              multiclass, ignore_index, top_k, threshold):
    """Reference stat scores for multi-dimensional inputs via ``_sk_stat_scores``.

    ``mdmc_reduce == "global"`` flattens the extra dimension into the sample
    dimension and scores once; ``"samplewise"`` scores each sample separately
    and stacks the results along a new leading axis. Any other value returns
    ``None``.
    """
    preds, target, _ = _input_format_classification(
        preds,
        target,
        threshold=threshold,
        num_classes=num_classes,
        multiclass=multiclass,
        top_k=top_k,
    )

    if mdmc_reduce == "global":
        flat_preds = torch.transpose(preds, 1, 2).reshape(-1, preds.shape[1])
        flat_target = torch.transpose(target, 1, 2).reshape(-1, target.shape[1])
        return _sk_stat_scores(flat_preds, flat_target, reduce, None, False,
                               ignore_index, top_k, threshold)

    if mdmc_reduce == "samplewise":
        per_sample = []
        for idx in range(preds.shape[0]):
            sample_scores = _sk_stat_scores(preds[idx, ...].T, target[idx, ...].T,
                                            reduce, None, False, ignore_index,
                                            top_k, threshold)
            per_sample.append(np.expand_dims(sample_scores, 0))
        return np.concatenate(per_sample)

    return None
Beispiel #7
0
def _subset_accuracy_update(
    preds: Tensor,
    target: Tensor,
    threshold: float,
    top_k: Optional[int],
) -> Tuple[Tensor, Tensor]:
    """Return the ``(correct, total)`` counts needed to compute subset accuracy.

    Args:
        preds: Predicted tensor.
        target: Ground truth tensor.
        threshold: Forwarded to ``_input_format_classification``.
        top_k: Forwarded to ``_input_format_classification``; not allowed for
            multi-label inputs.

    Returns:
        Tuple ``(correct, total)`` of tensors. Both are zero when the detected
        mode is none of multi-label, multi-class or multi-dim multi-class.

    Raises:
        ValueError: If the inputs are multi-label and ``top_k`` is set.
    """
    preds, target = _input_squeeze(preds, target)
    preds, target, mode = _input_format_classification(preds,
                                                       target,
                                                       threshold=threshold,
                                                       top_k=top_k)

    if mode == DataType.MULTILABEL and top_k:
        raise ValueError(
            "You can not use the `top_k` parameter to calculate accuracy for multi-label inputs."
        )

    if mode == DataType.MULTILABEL:
        # A sample counts only if every label matches.
        correct = (preds == target).all(dim=1).sum()
        total = tensor(target.shape[0], device=target.device)
    elif mode == DataType.MULTICLASS:
        correct = (preds * target).sum()
        total = target.sum()
    elif mode == DataType.MULTIDIM_MULTICLASS:
        # A sample counts only if all positions along the last dimension match.
        sample_correct = (preds * target).sum(dim=(1, 2))
        correct = (sample_correct == target.shape[2]).sum()
        total = tensor(target.shape[0], device=target.device)
    else:
        # Without this branch any other mode (e.g. binary) reached the return
        # statement with unbound locals and raised UnboundLocalError.
        correct = tensor(0, device=target.device)
        total = tensor(0, device=target.device)

    return correct, total
Beispiel #8
0
def _sk_fbeta_f1_multidim_multiclass(preds, target, sk_fn, num_classes,
                                     average, multiclass, ignore_index,
                                     mdmc_average):
    """Reference F-beta/F1 for multi-dimensional inputs via ``_sk_fbeta_f1``.

    ``mdmc_average == "global"`` flattens the extra dimension into the sample
    dimension and scores once; ``"samplewise"`` scores each sample separately
    and averages the scores over samples. Any other value returns ``None``.
    """
    preds, target, _ = _input_format_classification(
        preds,
        target,
        threshold=THRESHOLD,
        num_classes=num_classes,
        multiclass=multiclass,
    )

    if mdmc_average == "global":
        flat_preds = torch.transpose(preds, 1, 2).reshape(-1, preds.shape[1])
        flat_target = torch.transpose(target, 1, 2).reshape(-1, target.shape[1])
        return _sk_fbeta_f1(flat_preds, flat_target, sk_fn, num_classes,
                            average, False, ignore_index)

    if mdmc_average == "samplewise":
        per_sample = []
        for idx in range(preds.shape[0]):
            sample_score = _sk_fbeta_f1(preds[idx, ...].T, target[idx, ...].T,
                                        sk_fn, num_classes, average, False,
                                        ignore_index)
            per_sample.append(np.expand_dims(sample_score, 0))
        return np.concatenate(per_sample).mean(axis=0)

    return None
def _confusion_matrix_update(preds: Tensor,
                             target: Tensor,
                             num_classes: int,
                             threshold: float = 0.5,
                             multilabel: bool = False) -> Tensor:
    """Count an un-normalized confusion matrix from predictions and targets.

    Returns a ``(num_classes, 2, 2)`` tensor when ``multilabel`` is true and a
    ``(num_classes, num_classes)`` tensor otherwise.
    """
    preds, target, mode = _input_format_classification(preds, target,
                                                       threshold)
    # Anything that is not binary / multi-label is reduced to class indices.
    if mode not in (DataType.BINARY, DataType.MULTILABEL):
        preds, target = preds.argmax(dim=1), target.argmax(dim=1)

    if multilabel:
        # Each class owns 4 consecutive bins, indexed by 2 * target + pred.
        class_offsets = 4 * torch.arange(num_classes, device=preds.device)
        bin_index = ((2 * target + preds) + class_offsets).flatten()
        counts = torch.bincount(bin_index, minlength=4 * num_classes)
        return counts.reshape(num_classes, 2, 2)

    # Each (target, pred) pair maps to one cell of the flattened square matrix.
    bin_index = (target.view(-1) * num_classes + preds.view(-1)).to(torch.long)
    counts = torch.bincount(bin_index, minlength=num_classes**2)
    return counts.reshape(num_classes, num_classes)
Beispiel #10
0
def _ce_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
    """Compute top-1 confidences and whether each top-1 prediction was correct.

    Args:
        preds:  Input softmaxed predictions.
        target: Labels.

    Raises:
        ValueError: If the detected mode is not binary, multiclass, or multidimensional-multiclass.

    Returns:
        Tuple ``(confidences, accuracies)``, both cast to float.
    """
    _, _, mode = _input_format_classification(preds, target)

    if mode == DataType.BINARY:
        top_confidence, is_correct = preds, target
    elif mode == DataType.MULTICLASS:
        top_confidence, top_class = preds.max(dim=1)
        is_correct = top_class.eq(target)
    elif mode == DataType.MULTIDIM_MULTICLASS:
        # Put the class dimension last, then collapse everything else into rows.
        class_last = torch.transpose(preds, 1, -1)
        rows = class_last.flatten(0, -2)
        top_confidence, top_class = rows.max(dim=1)
        # Targets only need to be flattened.
        is_correct = top_class.eq(target.flatten())
    else:
        raise ValueError(
            f"Calibration error is not well-defined for data with size {preds.size()} and targets {target.size()}."
        )

    # must be cast to float for ddp allgather to work
    return top_confidence.float(), is_correct.float()
Beispiel #11
0
def _sk_spec_mdim_mcls(preds,
                       target,
                       reduce,
                       mdmc_reduce,
                       num_classes,
                       multiclass,
                       ignore_index,
                       top_k=None):
    """Reference specificity for multi-dimensional inputs via ``_sk_spec``.

    With ``mdmc_reduce == "global"`` the extra dimension is flattened into the
    sample dimension and ``_sk_spec`` is called once. Otherwise the per-sample
    pairs returned by ``_sk_stats_score`` are collected and passed to
    ``_sk_spec`` as ``[fp_list, tn_list]``.
    """
    preds, target, _ = _input_format_classification(
        preds,
        target,
        threshold=THRESHOLD,
        num_classes=num_classes,
        multiclass=multiclass,
        top_k=top_k,
    )

    if mdmc_reduce == "global":
        flat_preds = torch.transpose(preds, 1, 2).reshape(-1, preds.shape[1])
        flat_target = torch.transpose(target, 1, 2).reshape(-1, target.shape[1])
        return _sk_spec(flat_preds, flat_target, reduce, num_classes, False,
                        ignore_index, top_k, mdmc_reduce)

    false_positives = []
    true_negatives = []
    for idx in range(preds.shape[0]):
        sample_fp, sample_tn = _sk_stats_score(preds[idx, ...].T,
                                               target[idx, ...].T, reduce,
                                               num_classes, False,
                                               ignore_index, top_k)
        false_positives.append(sample_fp)
        true_negatives.append(sample_tn)

    collected = [false_positives, true_negatives]
    return _sk_spec(preds[0], target[0], reduce, num_classes, multiclass,
                    ignore_index, top_k, mdmc_reduce, collected)
Beispiel #12
0
def _confusion_matrix_update(preds: Tensor, target: Tensor, num_classes: int, threshold: float = 0.5) -> Tensor:
    """Count an un-normalized ``(num_classes, num_classes)`` confusion matrix."""
    preds, target, mode = _input_format_classification(preds, target, threshold)

    # Anything that is not binary / multi-label is reduced to class indices.
    if mode not in (DataType.BINARY, DataType.MULTILABEL):
        preds, target = preds.argmax(dim=1), target.argmax(dim=1)

    # Each (target, pred) pair maps to one cell of the flattened square matrix.
    cell_index = target.view(-1) * num_classes + preds.view(-1)
    counts = torch.bincount(cell_index.to(torch.long), minlength=num_classes**2)
    return counts.reshape(num_classes, num_classes)
def _sk_hamming_loss(preds, target):
    """Reference Hamming loss: format inputs, flatten per sample, call ``sk_hamming_loss``."""
    fmt_preds, fmt_target, _ = _input_format_classification(preds,
                                                            target,
                                                            threshold=THRESHOLD)
    y_pred = fmt_preds.numpy()
    y_true = fmt_target.numpy()

    # Keep the leading (sample) axis and collapse everything else.
    y_pred = y_pred.reshape(y_pred.shape[0], -1)
    y_true = y_true.reshape(y_true.shape[0], -1)

    return sk_hamming_loss(y_true=y_true, y_pred=y_pred)
Beispiel #14
0
def _stat_scores_update(
    preds: Tensor,
    target: Tensor,
    reduce: str = "micro",
    mdmc_reduce: Optional[str] = None,
    num_classes: Optional[int] = None,
    top_k: Optional[int] = None,
    threshold: float = 0.5,
    multiclass: Optional[bool] = None,
    ignore_index: Optional[int] = None,
    is_multiclass: Optional[bool] = None,  # todo: deprecated, remove in v0.4
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Compute true positives, false positives, true negatives and false negatives.

    Args:
        preds: Predicted tensor.
        target: Ground truth tensor.
        reduce: Reduction passed to ``_stat_scores``; ``"macro"`` keeps the
            ignored class and marks its statistics with ``-1``.
        mdmc_reduce: Required when the formatted inputs are 3-dimensional;
            ``"global"`` flattens the extra dimension into the sample dimension.
        num_classes: Forwarded to ``_input_format_classification``.
        top_k: Forwarded to ``_input_format_classification``.
        threshold: Forwarded to ``_input_format_classification``.
        multiclass: Forwarded to ``_input_format_classification``.
        ignore_index: Class column to leave out of the statistics.
        is_multiclass: Deprecated alias of ``multiclass``; used only when
            ``multiclass`` is ``None``.

    Returns:
        Tuple ``(tp, fp, tn, fn)``.

    Raises:
        ValueError: If ``ignore_index`` is outside ``[0, number of classes)``,
            if ``ignore_index`` is used with single-column (binary) data, or if
            the inputs are 3-dimensional and ``mdmc_reduce`` is not set.
    """
    if is_multiclass is not None and multiclass is None:
        warn(
            "Argument `is_multiclass` was deprecated in v0.3.0 and will be removed in v0.4. Use `multiclass`.",
            DeprecationWarning
        )
        multiclass = is_multiclass

    preds, target, _ = _input_format_classification(
        preds, target, threshold=threshold, num_classes=num_classes, multiclass=multiclass, top_k=top_k
    )

    # The class dimension is axis 1, so report shape[1] (not the batch size) as the class count.
    if ignore_index is not None and not 0 <= ignore_index < preds.shape[1]:
        raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {preds.shape[1]} classes")

    if ignore_index is not None and preds.shape[1] == 1:
        raise ValueError("You can not use `ignore_index` with binary data.")

    if preds.ndim == 3:
        if not mdmc_reduce:
            raise ValueError(
                "When your inputs are multi-dimensional multi-class, you have to set the `mdmc_reduce` parameter"
            )
        if mdmc_reduce == "global":
            preds = torch.transpose(preds, 1, 2).reshape(-1, preds.shape[1])
            target = torch.transpose(target, 1, 2).reshape(-1, target.shape[1])

    # Delete what is in ignore_index, if applicable (and classes don't matter):
    if ignore_index is not None and reduce != "macro":
        preds = _del_column(preds, ignore_index)
        target = _del_column(target, ignore_index)

    tp, fp, tn, fn = _stat_scores(preds, target, reduce=reduce)

    # For "macro" the ignored class keeps its slot but is flagged with -1.
    if ignore_index is not None and reduce == "macro":
        tp[..., ignore_index] = -1
        fp[..., ignore_index] = -1
        tn[..., ignore_index] = -1
        fn[..., ignore_index] = -1

    return tp, fp, tn, fn
Beispiel #15
0
def _sk_accuracy(preds, target, subset_accuracy):
    """Reference accuracy computed with ``sk_accuracy`` on the formatted inputs.

    Multi-dim multi-class inputs with ``subset_accuracy`` are scored directly
    as the fraction of samples whose entries all match.
    """
    fmt_preds, fmt_target, mode = _input_format_classification(preds, target, threshold=THRESHOLD)
    y_pred = fmt_preds.numpy()
    y_true = fmt_target.numpy()

    if mode == DataType.MULTIDIM_MULTICLASS:
        if subset_accuracy:
            return np.all(y_pred == y_true, axis=(1, 2)).mean()
        # Move the class axis last, then merge the first two axes into rows.
        y_pred = np.transpose(y_pred, (0, 2, 1))
        y_true = np.transpose(y_true, (0, 2, 1))
        y_pred = y_pred.reshape(-1, y_pred.shape[2])
        y_true = y_true.reshape(-1, y_true.shape[2])
    elif mode == DataType.MULTILABEL and not subset_accuracy:
        # Score every label independently.
        y_pred = y_pred.reshape(-1)
        y_true = y_true.reshape(-1)

    return sk_accuracy(y_true=y_true, y_pred=y_pred)
Beispiel #16
0
def _hamming_distance_update(
    preds: Tensor,
    target: Tensor,
    threshold: float = 0.5,
) -> Tuple[Tensor, int]:
    """Return the number of matching positions and the number of predictions."""
    fmt_preds, fmt_target, _ = _input_format_classification(
        preds, target, threshold=threshold)

    matches = fmt_preds == fmt_target
    return matches.sum(), fmt_preds.numel()
Beispiel #17
0
def _auroc_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor, str]:
    """Validate the inputs, flatten extra dimensions where needed, and return the mode.

    Returns:
        Tuple ``(preds, target, mode)``.
    """
    # _input_format_classification is used only to validate and to detect the mode.
    _, _, mode = _input_format_classification(preds, target)

    def _classes_to_columns(tensor_in, n_classes):
        # (N, C, ...) -> (N * ..., C)
        return tensor_in.transpose(0, 1).reshape(n_classes, -1).transpose(0, 1)

    # NOTE(review): other code compares `mode` against DataType members; confirm
    # these string literals match the values _input_format_classification returns.
    if mode == 'multi class multi dim':
        preds = _classes_to_columns(preds, preds.shape[1])
        target = target.flatten()
    if mode == 'multi-label' and preds.ndim > 2:
        class_count = preds.shape[1]
        preds = _classes_to_columns(preds, class_count)
        target = _classes_to_columns(target, class_count)

    return preds, target, mode
Beispiel #18
0
def _subset_accuracy_update(
    preds: Tensor,
    target: Tensor,
    threshold: float,
    top_k: Optional[int],
    ignore_index: Optional[int] = None,
) -> Tuple[Tensor, Tensor]:
    """Compute the ``(correct, total)`` counts used for subset accuracy.

    Args:
        preds: Predicted tensor
        target: Ground truth tensor
        threshold: Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case
            of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities.
        top_k: Number of highest probability or logit score predictions considered to find the correct label,
            relevant only for (multi-dimensional) multi-class inputs.
        ignore_index: Forwarded to ``_input_format_classification``.

    Raises:
        ValueError: If the inputs are multi-label and ``top_k`` is set.
    """
    preds, target = _input_squeeze(preds, target)
    preds, target, mode = _input_format_classification(
        preds, target, threshold=threshold, top_k=top_k, ignore_index=ignore_index)

    is_multilabel = mode == DataType.MULTILABEL
    if is_multilabel and top_k:
        raise ValueError(
            "You can not use the `top_k` parameter to calculate accuracy for multi-label inputs."
        )

    if is_multilabel:
        # A sample counts only if every label matches.
        rows_match = (preds == target).all(dim=1)
        return rows_match.sum(), tensor(target.shape[0], device=target.device)

    if mode == DataType.MULTICLASS:
        return (preds * target).sum(), target.sum()

    if mode == DataType.MULTIDIM_MULTICLASS:
        # A sample counts only if all positions along the last dimension match.
        hits_per_sample = (preds * target).sum(dim=(1, 2))
        n_correct = (hits_per_sample == target.shape[2]).sum()
        return n_correct, tensor(target.shape[0], device=target.device)

    # Any other mode contributes nothing.
    return tensor(0), tensor(0)
Beispiel #19
0
def _ce_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
    """Compute top-1 confidences and whether each top-1 prediction was correct.

    Raises:
        ValueError: If the detected mode is not binary, multi-class or
            multi-dim multi-class.
    """
    _, _, mode = _input_format_classification(preds, target)

    if mode == DataType.BINARY:
        return preds, target

    if mode == DataType.MULTICLASS:
        top_confidence, top_class = preds.max(dim=1)
        return top_confidence, top_class.eq(target)

    if mode == DataType.MULTIDIM_MULTICLASS:
        # Put the class dimension last, then collapse everything else into rows.
        rows = torch.transpose(preds, 1, -1).flatten(0, -2)
        top_confidence, top_class = rows.max(dim=1)
        # Targets only need to be flattened.
        return top_confidence, top_class.eq(target.flatten())

    raise ValueError(
        f"Calibration error is not well-defined for data with size {preds.size()} and targets {target.size()}")
Beispiel #20
0
def _hamming_distance_update(
    preds: Tensor,
    target: Tensor,
    threshold: float = 0.5,
) -> Tuple[Tensor, int]:
    """Count positions where prediction equals target, and the number of predictions.

    Args:
        preds: Predicted tensor
        target: Ground truth tensor
        threshold: Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case
            of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities.

    Returns:
        Tuple ``(correct, total)``.
    """
    fmt_preds, fmt_target, _ = _input_format_classification(preds, target, threshold=threshold)

    n_equal = (fmt_preds == fmt_target).sum()
    n_total = fmt_preds.numel()
    return n_equal, n_total
Beispiel #21
0
def _sk_fbeta_f1(preds, target, sk_fn, num_classes, average, multiclass, ignore_index, mdmc_average=None):
    """Reference F-beta/F1 score computed by ``sk_fn`` on the formatted inputs.

    ``ignore_index`` is left out of the ``labels`` passed to ``sk_fn``; when no
    averaging is applied, a NaN is inserted at its position in the result.
    ``mdmc_average`` is accepted but not used here.
    """
    # A single class forces "binary"; otherwise the string "none" means no averaging.
    if num_classes == 1:
        average = "binary"
    elif average == "none":
        average = None

    labels = [label for label in range(num_classes) if label != ignore_index]

    fmt_preds, fmt_target, _ = _input_format_classification(
        preds, target, THRESHOLD, num_classes=num_classes, multiclass=multiclass
    )
    sk_scores = sk_fn(fmt_target.numpy(), fmt_preds.numpy(), average=average, zero_division=0, labels=labels)

    class_was_dropped = len(labels) != num_classes
    if class_was_dropped and not average:
        sk_scores = np.insert(sk_scores, ignore_index, np.nan)

    return sk_scores
Beispiel #22
0
def _auroc_update(preds: Tensor,
                  target: Tensor) -> Tuple[Tensor, Tensor, DataType]:
    """Updates and returns variables required to compute Area Under the Receiver Operating Characteristic Curve.
    Validates the inputs and returns the mode of the inputs.

    Args:
        preds: Predicted tensor
        target: Ground truth tensor

    Returns:
        Tuple ``(preds, target, mode)``. For multi-dim multi-class inputs, and
        for multi-label inputs with more than two dimensions, the extra
        dimensions are folded into the first one.
    """

    # use _input_format_classification for validating the input and get the mode of data
    _, _, mode = _input_format_classification(preds, target)

    # Compare against the enum members rather than hand-typed strings: a literal
    # that does not match the enum value silently disables its branch.
    if mode == DataType.MULTIDIM_MULTICLASS:
        n_classes = preds.shape[1]
        # (N, C, ...) -> (N * ..., C)
        preds = preds.transpose(0, 1).reshape(n_classes, -1).transpose(0, 1)
        target = target.flatten()
    if mode == DataType.MULTILABEL and preds.ndim > 2:
        n_classes = preds.shape[1]
        preds = preds.transpose(0, 1).reshape(n_classes, -1).transpose(0, 1)
        target = target.transpose(0, 1).reshape(n_classes, -1).transpose(0, 1)

    return preds, target, mode
Beispiel #23
0
def _accuracy_update(
    preds: Tensor,
    target: Tensor,
    threshold: float,
    top_k: Optional[int],
    ignore_index: Optional[int],
    subset_accuracy: bool,
) -> Tuple[Tensor, Tensor]:
    """Compute the ``(correct, total)`` counts used for accuracy.

    Both values stay ``None`` when the detected mode matches none of the
    handled ``DataType`` members.
    """
    preds, target, mode = _input_format_classification(
        preds, target, threshold=threshold, top_k=top_k)

    # Delete what is in ignore_index, if applicable (and classes don't matter):
    if ignore_index is not None:
        preds = _del_column(preds, ignore_index)
        target = _del_column(target, ignore_index)

    def _sample_count():
        return tensor(target.shape[0], device=target.device)

    correct, total = None, None

    if mode == DataType.BINARY:
        correct, total = (preds == target).all(dim=1).sum(), _sample_count()
    elif mode == DataType.MULTILABEL:
        if subset_accuracy:
            # A sample counts only if every label matches.
            correct, total = (preds == target).all(dim=1).sum(), _sample_count()
        else:
            # Every label is scored on its own.
            correct = (preds == target).sum()
            total = tensor(target.numel(), device=target.device)
    elif mode == DataType.MULTICLASS:
        correct, total = (preds * target).sum(), target.sum()
    elif mode == DataType.MULTIDIM_MULTICLASS:
        if subset_accuracy:
            # A sample counts only if all positions along the last dimension match.
            hits_per_sample = (preds * target).sum(dim=(1, 2))
            correct = (hits_per_sample == target.shape[2]).sum()
            total = _sample_count()
        else:
            correct, total = (preds * target).sum(), target.sum()

    return correct, total
Beispiel #24
0
def _confusion_matrix_update(preds: Tensor,
                             target: Tensor,
                             num_classes: int,
                             threshold: float = 0.5,
                             multilabel: bool = False) -> Tensor:
    """Build the raw (un-normalized) confusion matrix for the given inputs.

    Args:
        preds: Predicted tensor
        target: Ground truth tensor
        num_classes: Number of classes; fixes the size of the returned matrix.
        threshold: Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the
            case of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities.
        multilabel: determines if data is multilabel or not.

    Returns:
        Tensor of shape ``(num_classes, 2, 2)`` if ``multilabel`` is true,
        otherwise ``(num_classes, num_classes)``.
    """
    preds, target, mode = _input_format_classification(preds, target,
                                                       threshold)
    if mode not in (DataType.BINARY, DataType.MULTILABEL):
        # Reduce to class indices along the class dimension.
        preds = preds.argmax(dim=1)
        target = target.argmax(dim=1)

    if multilabel:
        out_shape = (num_classes, 2, 2)
        n_bins = 4 * num_classes
        # Each class owns 4 consecutive bins, indexed by 2 * target + pred.
        per_class_start = 4 * torch.arange(num_classes, device=preds.device)
        flat_index = ((2 * target + preds) + per_class_start).flatten()
    else:
        out_shape = (num_classes, num_classes)
        n_bins = num_classes**2
        # Each (target, pred) pair maps to one cell of the flattened square matrix.
        flat_index = (target.view(-1) * num_classes + preds.view(-1)).to(
            torch.long)

    return torch.bincount(flat_index, minlength=n_bins).reshape(*out_shape)
Beispiel #25
0
def _sk_calibration(preds, target, n_bins, norm, debias=False):
    """Reference calibration error computed with ``sk_calib``.

    For multi-class and multi-dim multi-class inputs, the probability passed
    on is the top-1 probability and the label is whether the top-1 class was
    correct. Other modes pass the raw arrays through.
    """
    _, _, mode = _input_format_classification(preds,
                                              target,
                                              threshold=THRESHOLD)
    y_prob = preds.numpy()
    y_true = target.numpy()

    if mode == DataType.MULTICLASS:
        # binary label is whether or not the predicted class is correct
        predicted_class = np.argmax(y_prob, axis=1)
        y_true = np.equal(predicted_class, y_true)
        y_prob = np.max(y_prob, axis=1)
    elif mode == DataType.MULTIDIM_MULTICLASS:
        # reshape from shape (N, C, ...) to (N*EXTRA_DIMS, C)
        y_prob = np.transpose(y_prob, axes=(0, 2, 1))
        n_rows = np.prod(y_prob.shape[:-1])
        y_prob = y_prob.reshape(n_rows, y_prob.shape[-1])
        # targets go from shape (N, ...) to (N*EXTRA_DIMS,)
        predicted_class = np.argmax(y_prob, axis=1)
        y_true = np.equal(predicted_class, y_true.flatten())
        y_prob = np.max(y_prob, axis=1)

    return sk_calib(y_true=y_true,
                    y_prob=y_prob,
                    norm=norm,
                    n_bins=n_bins,
                    reduce_bias=debias)
Beispiel #26
0
def _stat_scores_update(
    preds: Tensor,
    target: Tensor,
    reduce: Optional[str] = "micro",
    mdmc_reduce: Optional[str] = None,
    num_classes: Optional[int] = None,
    top_k: Optional[int] = None,
    threshold: float = 0.5,
    multiclass: Optional[bool] = None,
    ignore_index: Optional[int] = None,
    mode: DataType = None,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Compute the number of true positives, false positives, true negatives and false negatives.

    Raises ValueError if:

        - The `ignore_index` is not valid
        - When `ignore_index` is used with binary data
        - When inputs are multi-dimensional multi-class, and the ``mdmc_reduce`` parameter is not set

    Args:
        preds: Predicted tensor
        target: Ground truth tensor
        reduce: Defines the reduction that is applied
        mdmc_reduce: Defines how the multi-dimensional multi-class inputs are handled
        num_classes: Number of classes. Necessary for (multi-dimensional) multi-class or multi-label data.
        top_k: Number of the highest probability or logit score predictions considered finding the correct label,
            relevant only for (multi-dimensional) multi-class inputs
        threshold: Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case
            of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities
        multiclass: Used only in certain special cases, where you want to treat inputs as a different type
            than what they appear to be
        ignore_index: Specify a class (label) to ignore. If given, this class index does not contribute
            to the returned score, regardless of reduction method. If an index is ignored, and
            ``reduce='macro'``, the class statistics for the ignored class will all be returned
            as ``-1``.
        mode: Mode of the input tensors
    """
    has_ignore = ignore_index is not None

    # A negative ignore_index is handled by dropping entries up front, which
    # is only possible when the input mode is known.
    negative_handled = has_ignore and ignore_index < 0 and mode is not None
    if negative_handled:
        preds, target = _drop_negative_ignored_indices(preds, target,
                                                       ignore_index, mode)

    preds, target, _ = _input_format_classification(
        preds,
        target,
        threshold=threshold,
        num_classes=num_classes,
        multiclass=multiclass,
        top_k=top_k,
        ignore_index=ignore_index,
    )

    if has_ignore:
        if ignore_index >= preds.shape[1]:
            raise ValueError(
                f"The `ignore_index` {ignore_index} is not valid for inputs with {preds.shape[1]} classes"
            )
        if preds.shape[1] == 1:
            raise ValueError("You can not use `ignore_index` with binary data.")

    if preds.ndim == 3:
        if not mdmc_reduce:
            raise ValueError(
                "When your inputs are multi-dimensional multi-class, you have to set the `mdmc_reduce` parameter"
            )
        if mdmc_reduce == "global":
            # Fold the extra dimension into the sample dimension.
            preds = torch.transpose(preds, 1, 2).reshape(-1, preds.shape[1])
            target = torch.transpose(target, 1, 2).reshape(-1, target.shape[1])

    # Column-based ignoring applies only when nothing was dropped up front.
    ignore_by_column = has_ignore and not negative_handled

    # Delete what is in ignore_index, if applicable (and classes don't matter):
    if ignore_by_column and reduce != "macro":
        preds = _del_column(preds, ignore_index)
        target = _del_column(target, ignore_index)

    tp, fp, tn, fn = _stat_scores(preds, target, reduce=reduce)

    # For "macro" the ignored class keeps its slot but is flagged with -1.
    if ignore_by_column and reduce == "macro":
        for stat in (tp, fp, tn, fn):
            stat[..., ignore_index] = -1

    return tp, fp, tn, fn
Beispiel #27
0
def test_incorrect_threshold(threshold):
    """Expect ``_input_format_classification`` to raise ``ValueError`` for this ``threshold``."""
    shape = (7, )
    preds = rand(size=shape)
    target = randint(high=2, size=shape)

    with pytest.raises(ValueError):
        _input_format_classification(preds, target, threshold=threshold)