示例#1
0
def test_incorrect_inputs(preds, target, num_classes, is_multiclass):
    """Expect ``ValueError`` for the given (invalid) input combination.

    The arguments are forwarded unchanged to
    ``_input_format_classification`` together with the module-level
    ``THRESHOLD``; the test passes only if that call raises ``ValueError``.
    """
    call_kwargs = dict(
        preds=preds,
        target=target,
        threshold=THRESHOLD,
        num_classes=num_classes,
        is_multiclass=is_multiclass,
    )

    with pytest.raises(ValueError):
        _input_format_classification(**call_kwargs)
示例#2
0
def test_usual_cases(inputs, num_classes, is_multiclass, top_k, exp_mode,
                     post_preds, post_target):
    """Check formatted outputs and detected mode for well-formed inputs.

    The first batch of ``inputs`` is formatted twice: once in full and once
    reduced to a single sample. Each time the returned mode must equal
    ``exp_mode`` (checked both as given and as the matching ``DataType``
    member), and the formatted tensors must equal ``post_preds`` /
    ``post_target`` applied to the raw tensors and cast to int.
    """
    def _as_data_type(str_exp_mode):
        # Iterate the enum members themselves. Going through
        # ``dir(DataType)`` also yields non-member attribute names, and
        # ``DataType[name]`` raises ``KeyError`` for those.
        return next(member for member in DataType
                    if member == str_exp_mode)

    def _check(raw_preds, raw_target, expected_mode):
        preds_out, target_out, mode = _input_format_classification(
            preds=raw_preds,
            target=raw_target,
            threshold=THRESHOLD,
            num_classes=num_classes,
            is_multiclass=is_multiclass,
            top_k=top_k,
        )

        assert mode == expected_mode
        assert torch.equal(preds_out, post_preds(raw_preds).int())
        assert torch.equal(target_out, post_target(raw_target).int())

    batch_preds, batch_target = inputs.preds[0], inputs.target[0]

    for expected_mode in (exp_mode, _as_data_type(exp_mode)):
        _check(batch_preds, batch_target, expected_mode)

        # Test that things work when batch_size = 1
        _check(batch_preds[[0], ...], batch_target[[0], ...], expected_mode)
示例#3
0
def _accuracy_update(
        preds: torch.Tensor, target: torch.Tensor, threshold: float,
        top_k: Optional[int],
        subset_accuracy: bool) -> Tuple[torch.Tensor, torch.Tensor]:
    """Count correct predictions and the matching denominator.

    Args:
        preds: Predictions, passed to ``_input_format_classification``.
        target: Ground truth, passed to ``_input_format_classification``.
        threshold: Forwarded to ``_input_format_classification``.
        top_k: Forwarded to ``_input_format_classification``; must be falsy
            for multi-label inputs.
        subset_accuracy: For multi-label and multi-dim multi-class inputs,
            count a sample as correct only when the whole sample matches.

    Returns:
        ``(correct, total)`` as tensors.

    Raises:
        ValueError: If the inputs are multi-label and ``top_k`` is truthy.
    """
    preds, target, mode = _input_format_classification(preds,
                                                       target,
                                                       threshold=threshold,
                                                       top_k=top_k)

    multilabel = mode == DataType.MULTILABEL
    mdmc = mode == DataType.MULTIDIM_MULTICLASS

    if multilabel and top_k:
        raise ValueError(
            "You can not use the `top_k` parameter to calculate accuracy for multi-label inputs."
        )

    if mode == DataType.BINARY or (multilabel and subset_accuracy):
        # A sample counts only if every entry along dim 1 matches.
        rows_match = (preds == target).all(dim=1)
        correct = rows_match.sum()
        total = torch.tensor(target.shape[0], device=target.device)
    elif multilabel:
        # Multi-label without subset accuracy: every entry is scored.
        correct = (preds == target).sum()
        total = torch.tensor(target.numel(), device=target.device)
    elif mode == DataType.MULTICLASS or (mdmc and not subset_accuracy):
        overlap = preds * target
        correct = overlap.sum()
        total = target.sum()
    elif mdmc:
        # Multi-dim multi-class with subset accuracy: a sample is correct
        # when its overlap count reaches the size of dim 2.
        per_sample_hits = (preds * target).sum(dim=(1, 2))
        correct = (per_sample_hits == target.shape[2]).sum()
        total = torch.tensor(target.shape[0], device=target.device)

    return correct, total
示例#4
0
def _sk_stat_scores(preds, target, reduce, num_classes, is_multiclass, ignore_index, top_k, mdmc_reduce=None):
    """Reference stat scores built on ``multilabel_confusion_matrix``.

    The inputs are formatted with ``_input_format_classification`` and
    converted to NumPy. Each confusion matrix is flattened to four values
    and reordered with ``[3, 1, 0, 2]`` (presumably tp, fp, tn, fn given
    scikit-learn's layout -- confirm). A fifth column holding the sum of
    columns 3 and 0 is then appended.

    ``mdmc_reduce`` is accepted but not used by this function.
    """
    preds, target, _ = _input_format_classification(
        preds, target, threshold=THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass, top_k=top_k
    )
    sk_preds = preds.numpy()
    sk_target = target.numpy()

    # Predicates are taken from the formatted tensor, before any column is
    # removed from the NumPy copies.
    n_columns = preds.shape[1]
    single_column = n_columns == 1
    samplewise = reduce == "samples"
    micro = reduce == "micro"
    macro = reduce == "macro"
    has_ignore = ignore_index is not None

    if not macro and has_ignore and n_columns > 1:
        sk_preds = np.delete(sk_preds, ignore_index, 1)
        sk_target = np.delete(sk_target, ignore_index, 1)

    if single_column and samplewise:
        sk_target = sk_target.T
        sk_preds = sk_preds.T

    confusion = multilabel_confusion_matrix(
        sk_target, sk_preds, samplewise=samplewise and not single_column
    )

    if single_column and not samplewise:
        # Keep only the matrix at index 1 for single-column input.
        confusion = confusion[[1]]
    sk_stats = confusion.reshape(-1, 4)[:, [3, 1, 0, 2]]

    if micro:
        sk_stats = sk_stats.sum(axis=0, keepdims=True)

    extra_column = sk_stats[:, [3]] + sk_stats[:, [0]]
    sk_stats = np.concatenate([sk_stats, extra_column], 1)

    if micro:
        sk_stats = sk_stats[0]

    if macro and has_ignore and n_columns:
        # The ignored row is kept but filled with -1.
        sk_stats[ignore_index, :] = -1

    return sk_stats
示例#5
0
def _sk_stat_scores_mdim_mcls(preds, target, reduce, mdmc_reduce, num_classes,
                              is_multiclass, ignore_index, top_k):
    """Reference stat scores for inputs with an extra dimension.

    After formatting, ``mdmc_reduce == "global"`` folds dim 2 into dim 0 and
    scores everything in one call to ``_sk_stat_scores``.
    ``mdmc_reduce == "samplewise"`` scores each sample on its own and stacks
    the results along a new leading axis. Any other value returns ``None``.
    """
    preds, target, _ = _input_format_classification(
        preds,
        target,
        threshold=THRESHOLD,
        num_classes=num_classes,
        is_multiclass=is_multiclass,
        top_k=top_k)

    if mdmc_reduce == "global":
        flat_preds = torch.transpose(preds, 1, 2).reshape(-1, preds.shape[1])
        flat_target = torch.transpose(target, 1,
                                      2).reshape(-1, target.shape[1])
        return _sk_stat_scores(flat_preds, flat_target, reduce, None, False,
                               ignore_index, top_k)

    if mdmc_reduce == "samplewise":
        per_sample = [
            np.expand_dims(
                _sk_stat_scores(preds[idx, ...].T, target[idx, ...].T, reduce,
                                None, False, ignore_index, top_k), 0)
            for idx in range(preds.shape[0])
        ]
        return np.concatenate(per_sample)

    return None
def _sk_prec_recall_multidim_multiclass(preds, target, sk_fn, num_classes,
                                        average, is_multiclass, ignore_index,
                                        mdmc_average):
    """Reference precision/recall for inputs with an extra dimension.

    After formatting, ``mdmc_average == "global"`` folds dim 2 into dim 0
    and scores everything in one call to ``_sk_prec_recall``.
    ``mdmc_average == "samplewise"`` scores each sample separately and
    returns the mean over samples. Any other value returns ``None``.
    """
    preds, target, _ = _input_format_classification(
        preds,
        target,
        threshold=THRESHOLD,
        num_classes=num_classes,
        is_multiclass=is_multiclass)

    if mdmc_average == "global":
        flat_preds = torch.transpose(preds, 1, 2).reshape(-1, preds.shape[1])
        flat_target = torch.transpose(target, 1,
                                      2).reshape(-1, target.shape[1])
        return _sk_prec_recall(flat_preds, flat_target, sk_fn, num_classes,
                               average, False, ignore_index)

    if mdmc_average == "samplewise":
        per_sample = [
            np.expand_dims(
                _sk_prec_recall(preds[idx, ...].T, target[idx, ...].T, sk_fn,
                                num_classes, average, False, ignore_index), 0)
            for idx in range(preds.shape[0])
        ]
        return np.concatenate(per_sample).mean(axis=0)

    return None
def _sk_hamming_loss(preds, target):
    """Reference Hamming loss computed with ``sk_hamming_loss``.

    The inputs are formatted with ``_input_format_classification``,
    converted to NumPy, and flattened to 2D (first dimension preserved)
    before being handed to ``sk_hamming_loss``.
    """
    fmt_preds, fmt_target, _ = _input_format_classification(
        preds, target, threshold=THRESHOLD)

    y_pred = fmt_preds.numpy()
    y_true = fmt_target.numpy()

    # Collapse everything after the first dimension into one axis.
    y_pred = y_pred.reshape(y_pred.shape[0], -1)
    y_true = y_true.reshape(y_true.shape[0], -1)

    return sk_hamming_loss(y_true=y_true, y_pred=y_pred)
示例#8
0
def test_threshold():
    """Probabilities just below, at, and just above 0.5 map to 0, 1, 1."""
    cutoff = 0.5
    target = T([1, 1, 1]).int()
    preds_probs = T([0.5 - 1e-5, 0.5, 0.5 + 1e-5])

    formatted = _input_format_classification(preds_probs,
                                             target,
                                             threshold=cutoff)
    preds_probs_out = formatted[0]

    expected = torch.tensor([0, 1, 1], dtype=torch.int)
    assert torch.equal(expected, preds_probs_out.squeeze().int())
示例#9
0
def _stat_scores_update(
    preds: torch.Tensor,
    target: torch.Tensor,
    reduce: str = "micro",
    mdmc_reduce: Optional[str] = None,
    num_classes: Optional[int] = None,
    top_k: Optional[int] = None,
    threshold: float = 0.5,
    is_multiclass: Optional[bool] = None,
    ignore_index: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Format the inputs and compute the tp / fp / tn / fn statistics.

    Args:
        preds: Predictions, passed to ``_input_format_classification``.
        target: Ground truth, passed to ``_input_format_classification``.
        reduce: Reduction forwarded to ``_stat_scores``. ``"macro"`` also
            changes how ``ignore_index`` is applied (see below).
        mdmc_reduce: Required when the formatted predictions are 3D.
            ``"global"`` folds dim 2 into dim 0 before scoring.
        num_classes: Forwarded to ``_input_format_classification``.
        top_k: Forwarded to ``_input_format_classification``.
        threshold: Forwarded to ``_input_format_classification``.
        is_multiclass: Forwarded to ``_input_format_classification``.
        ignore_index: Column (dim 1) of the formatted inputs to ignore. For
            ``reduce != "macro"`` the column is removed before scoring; for
            ``reduce == "macro"`` its statistics are set to ``-1``.

    Returns:
        The ``(tp, fp, tn, fn)`` tensors returned by ``_stat_scores``.

    Raises:
        ValueError: If ``ignore_index`` is outside ``[0, preds.shape[1])``,
            if ``ignore_index`` is given for single-column (binary) data, or
            if the formatted inputs are 3D and ``mdmc_reduce`` is not set.
    """

    preds, target, _ = _input_format_classification(
        preds,
        target,
        threshold=threshold,
        num_classes=num_classes,
        is_multiclass=is_multiclass,
        top_k=top_k)

    if ignore_index is not None and not 0 <= ignore_index < preds.shape[1]:
        # The class count is the size of dim 1 (the dimension the index is
        # validated against), not the batch size in dim 0.
        raise ValueError(
            f"The `ignore_index` {ignore_index} is not valid for inputs with {preds.shape[1]} classes"
        )

    if ignore_index is not None and preds.shape[1] == 1:
        raise ValueError("You can not use `ignore_index` with binary data.")

    if preds.ndim == 3:
        if not mdmc_reduce:
            raise ValueError(
                "When your inputs are multi-dimensional multi-class, you have to set the `mdmc_reduce` parameter"
            )
        if mdmc_reduce == "global":
            preds = torch.transpose(preds, 1, 2).reshape(-1, preds.shape[1])
            target = torch.transpose(target, 1, 2).reshape(-1, target.shape[1])

    # Delete what is in ignore_index, if applicable (and classes don't matter):
    if ignore_index is not None and reduce != "macro":
        preds = _del_column(preds, ignore_index)
        target = _del_column(target, ignore_index)

    tp, fp, tn, fn = _stat_scores(preds, target, reduce=reduce)

    # Take care of ignore_index
    if ignore_index is not None and reduce == "macro":
        tp[..., ignore_index] = -1
        fp[..., ignore_index] = -1
        tn[..., ignore_index] = -1
        fn[..., ignore_index] = -1

    return tp, fp, tn, fn
示例#10
0
def _sk_accuracy(preds, target, subset_accuracy):
    """Reference accuracy computed with ``sk_accuracy`` (or NumPy directly).

    The inputs are formatted with ``_input_format_classification`` and
    reshaped according to the detected mode and ``subset_accuracy`` before
    scoring.
    """
    fmt_preds, fmt_target, mode = _input_format_classification(preds, target, threshold=THRESHOLD)
    y_pred = fmt_preds.numpy()
    y_true = fmt_target.numpy()

    if mode == DataType.MULTIDIM_MULTICLASS:
        if subset_accuracy:
            # A sample is correct only if all of its entries match.
            return np.all(y_pred == y_true, axis=(1, 2)).mean()

        # Move dim 1 last, then fold the first two axes together.
        y_pred = np.transpose(y_pred, (0, 2, 1))
        y_true = np.transpose(y_true, (0, 2, 1))
        y_pred = y_pred.reshape(-1, y_pred.shape[2])
        y_true = y_true.reshape(-1, y_true.shape[2])
    elif mode == DataType.MULTILABEL and not subset_accuracy:
        # Score every label independently.
        y_pred = y_pred.reshape(-1)
        y_true = y_true.reshape(-1)

    return sk_accuracy(y_true=y_true, y_pred=y_pred)
示例#11
0
def _hamming_distance_update(
    preds: torch.Tensor,
    target: torch.Tensor,
    threshold: float = 0.5,
) -> Tuple[torch.Tensor, int]:
    """Count matching entries between formatted predictions and targets.

    Args:
        preds: Predictions, passed to ``_input_format_classification``.
        target: Ground truth, passed to ``_input_format_classification``.
        threshold: Forwarded to ``_input_format_classification``.

    Returns:
        ``(correct, total)``: the number of equal entries as a tensor, and
        the number of elements in the formatted predictions as an int.
    """
    fmt_preds, fmt_target, _ = _input_format_classification(
        preds, target, threshold=threshold)

    matches = fmt_preds == fmt_target

    return matches.sum(), fmt_preds.numel()
示例#12
0
def _auroc_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, str]:
    """Validate AUROC inputs and fold extra dimensions into the sample axis.

    Args:
        preds: Predictions; dim 1 is used as the class dimension.
        target: Ground truth.

    Returns:
        ``(preds, target, mode)``. ``mode`` is the data type reported by
        ``_input_format_classification``. For multi-dim multi-class inputs
        ``preds`` is reshaped to ``(-1, n_classes)`` and ``target`` is
        flattened. For multi-label inputs with more than two dimensions both
        are reshaped to ``(-1, n_classes)``. Otherwise the inputs are
        returned as given.
    """
    # use _input_format_classification for validating the input and get the mode of data
    _, _, mode = _input_format_classification(preds, target)

    # Compare against the enum members rather than hand-typed strings, so
    # that a misspelled literal cannot silently disable a branch.
    if mode == DataType.MULTIDIM_MULTICLASS:
        n_classes = preds.shape[1]
        preds = preds.transpose(0, 1).reshape(n_classes, -1).transpose(0, 1)
        target = target.flatten()
    if mode == DataType.MULTILABEL and preds.ndim > 2:
        n_classes = preds.shape[1]
        preds = preds.transpose(0, 1).reshape(n_classes, -1).transpose(0, 1)
        target = target.transpose(0, 1).reshape(n_classes, -1).transpose(0, 1)

    return preds, target, mode
示例#13
0
def _confusion_matrix_update(preds: torch.Tensor,
                             target: torch.Tensor,
                             num_classes: int,
                             threshold: float = 0.5) -> torch.Tensor:
    """Build a ``num_classes x num_classes`` count matrix.

    Args:
        preds: Predictions, passed to ``_input_format_classification``.
        target: Ground truth, passed to ``_input_format_classification``.
        num_classes: Side length of the returned square matrix.
        threshold: Forwarded to ``_input_format_classification``.

    Returns:
        A tensor in which entry ``[t, p]`` counts the elements with target
        label ``t`` and predicted label ``p``.
    """
    preds, target, mode = _input_format_classification(preds, target,
                                                       threshold)

    label_like_modes = (DataType.BINARY, DataType.MULTILABEL)
    if mode not in label_like_modes:
        # Reduce dim 1 to a single label index per element.
        preds = preds.argmax(dim=1)
        target = target.argmax(dim=1)

    # Encode each (target, pred) pair as one integer so that a single
    # bincount yields the whole matrix.
    pair_codes = target.view(-1) * num_classes + preds.view(-1)
    counts = torch.bincount(pair_codes.to(torch.long),
                            minlength=num_classes**2)

    return counts.reshape(num_classes, num_classes)
def _sk_prec_recall(preds,
                    target,
                    sk_fn,
                    num_classes,
                    average,
                    is_multiclass,
                    ignore_index,
                    mdmc_average=None):
    """Reference precision/recall computed through ``sk_fn``.

    ``average="none"`` is translated to ``None``, and a single class forces
    ``average="binary"``. ``ignore_index`` is dropped from the label list
    passed to ``sk_fn``. When no averaging is applied and a label was
    dropped, a ``nan`` is inserted at that position in the result.

    ``mdmc_average`` is accepted but not used by this function.
    """
    if num_classes == 1:
        average = "binary"
    elif average == "none":
        average = None

    # Every class index except the ignored one (if it is a valid index).
    labels = [cls for cls in range(num_classes) if cls != ignore_index]

    fmt_preds, fmt_target, _ = _input_format_classification(
        preds,
        target,
        THRESHOLD,
        num_classes=num_classes,
        is_multiclass=is_multiclass)
    y_pred = fmt_preds.numpy()
    y_true = fmt_target.numpy()

    sk_scores = sk_fn(y_true,
                      y_pred,
                      average=average,
                      zero_division=0,
                      labels=labels)

    label_was_dropped = len(labels) != num_classes
    if label_was_dropped and not average:
        sk_scores = np.insert(sk_scores, ignore_index, np.nan)

    return sk_scores
示例#15
0
def test_incorrect_threshold(threshold):
    """Expect ``ValueError`` for the given (invalid) ``threshold``."""
    shape = (7, )
    preds = rand(size=shape)
    target = randint(high=2, size=shape)

    with pytest.raises(ValueError):
        _input_format_classification(preds, target, threshold=threshold)