Пример #1
0
def test_v1_5_metrics_utils():
    """Exercise the deprecated metric utility helpers.

    Every helper is invoked inside ``pytest.deprecated_call`` so the test
    fails unless the call emits a deprecation warning whose message matches
    the v1.5.0 removal notice. The value returned by each helper is also
    compared against a hand-written expectation.
    """
    removal_notice = "It will be removed in v1.5.0"

    # to_onehot: labels 1, 2, 3 are expected to yield a 3x4 int64 matrix.
    labels = torch.tensor([1, 2, 3])
    onehot_expected = torch.tensor(
        [[0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]],
        dtype=torch.int64,
    )
    with pytest.deprecated_call(match=removal_notice):
        assert torch.equal(to_onehot(labels), onehot_expected)

    # get_num_classes: the largest label across both tensors is 3 -> 4 expected.
    with pytest.deprecated_call(match=removal_notice):
        num_classes = get_num_classes(torch.tensor([1, 2, 3]), torch.tensor([1, 2, 0]))
        assert num_classes == 4

    # select_topk: with topk=2 the expected int32 mask flags two entries per row.
    scores = torch.tensor([[1.1, 2.0, 3.0], [2.0, 1.0, 0.5]])
    topk_expected = torch.tensor([[0, 1, 1], [1, 1, 0]], dtype=torch.int32)
    with pytest.deprecated_call(match=removal_notice):
        assert torch.equal(select_topk(scores, topk=2), topk_expected)

    # to_categorical: expected int64 indices are 1 for the first row, 0 for the second.
    probabilities = torch.tensor([[0.2, 0.5], [0.9, 0.1]])
    categorical_expected = torch.tensor([1, 0], dtype=torch.int64)
    with pytest.deprecated_call(match=removal_notice):
        assert torch.equal(to_categorical(probabilities), categorical_expected)
Пример #2
0
def _top1(x):
    """Return the result of ``select_topk`` applied to ``x`` with ``topk`` fixed to 1."""
    return select_topk(x, topk=1)
Пример #3
0
def _top2(x):
    """Return the result of ``select_topk`` applied to ``x`` with ``topk`` fixed to 2."""
    return select_topk(x, topk=2)
Пример #4
0
    case = _check_classification_inputs(
        preds,
        target,
        threshold=threshold,
        num_classes=num_classes,
        is_multiclass=is_multiclass,
        top_k=top_k,
    )

    if case in (DataType.BINARY, DataType.MULTILABEL) and not top_k:
        preds = (preds >= threshold).int()
        num_classes = num_classes if not is_multiclass else 2

    if case == DataType.MULTILABEL and top_k:
        preds = select_topk(preds, top_k)

    if case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS) or is_multiclass:
        if preds.is_floating_point():
            num_classes = preds.shape[1]
            preds = select_topk(preds, top_k or 1)
        else:
            num_classes = num_classes if num_classes else max(preds.max(), target.max()) + 1
            preds = to_onehot(preds, max(2, num_classes))

        target = to_onehot(target, max(2, num_classes))

        if is_multiclass is False:
            preds, target = preds[:, 1, ...], target[:, 1, ...]

    if (case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS) and is_multiclass is not False) or is_multiclass: