Ejemplo n.º 1
0
def test_wrong_params(reduce, mdmc_reduce, num_classes, inputs, ignore_index):
    """Test that invalid parameter combinations raise a ``ValueError``.

    This covers invalid ``reduce`` and ``mdmc_reduce`` values, not setting
    ``num_classes`` when ``reduce='macro'``, not setting ``mdmc_reduce`` when
    inputs are multi-dim multi-class, setting ``ignore_index`` when inputs are
    binary, and setting ``ignore_index`` to a value higher than the number of
    classes.

    Both the functional ``stat_scores`` and the ``StatScores`` module are checked.
    """
    # Only the first batch of the fixture inputs is used.
    preds, target = inputs.preds[0], inputs.target[0]

    # Pass every option by keyword so the test does not depend on the
    # positional order of ``stat_scores``' parameters (and matches the
    # keyword style used for ``StatScores`` below).
    with pytest.raises(ValueError):
        stat_scores(
            preds,
            target,
            reduce=reduce,
            mdmc_reduce=mdmc_reduce,
            num_classes=num_classes,
            ignore_index=ignore_index,
        )

    # Presumably some invalid combinations are rejected at construction time
    # and others only when data is passed in, so both steps sit inside the
    # ``raises`` context.
    with pytest.raises(ValueError):
        sts = StatScores(reduce=reduce, mdmc_reduce=mdmc_reduce, num_classes=num_classes, ignore_index=ignore_index)
        sts(preds, target)
Ejemplo n.º 2
0
def test_top_k(k: int, preds: torch.Tensor, target: torch.Tensor, reduce: str, expected: torch.Tensor):
    """Check that ``top_k`` yields the expected stat scores.

    The ``StatScores`` module and the functional ``stat_scores`` are both run
    with ``num_classes=3`` and compared against the transpose of ``expected``.
    """
    want = expected.T

    # Module interface: accumulate one update, then compute.
    metric = StatScores(top_k=k, reduce=reduce, num_classes=3)
    metric.update(preds, target)
    assert torch.equal(metric.compute(), want)

    # Functional interface must agree with the module.
    functional_result = stat_scores(preds, target, top_k=k, reduce=reduce, num_classes=3)
    assert torch.equal(functional_result, want)