Example #1
0
def test_error_on_wrong_input():
    """Check that malformed constructor keyword arguments raise ``ValueError``.

    Each case passes exactly one bad keyword argument to ``DummyMetric`` and
    matches the raised error message against a regex prefix.
    """
    bad_inputs = [
        ({"dist_sync_on_step": None}, "Expected keyword argument `dist_sync_on_step` to be an `bool` but.*"),
        ({"dist_sync_fn": [2, 3]}, "Expected keyword argument `dist_sync_fn` to be an callable function.*"),
        ({"compute_on_cpu": None}, "Expected keyword argument `compute_on_cpu` to be an `bool` bu.*"),
    ]
    for kwargs, pattern in bad_inputs:
        with pytest.raises(ValueError, match=pattern):
            DummyMetric(**kwargs)
Example #2
0
def test_state_dict(tmpdir):
    """Check that toggling ``persistent`` adds and removes metric states in ``state_dict``.

    ``tmpdir`` is accepted but not used by the body.
    """
    metric = DummyMetric()
    assert metric.state_dict() == OrderedDict()

    # Switching persistence on exposes state ``x`` (value 0); switching it
    # back off hides it again.
    for enabled, expected in ((True, OrderedDict(x=0)), (False, OrderedDict())):
        metric.persistent(enabled)
        assert metric.state_dict() == expected
Example #3
0
def test_add_state_persistent():
    """Check that the ``persistent`` flag of ``add_state`` controls ``state_dict`` membership.

    A state added with ``persistent=True`` must appear in ``state_dict``; one
    added with ``persistent=False`` must not, on torch versions that support
    non-persistent buffers.
    """
    a = DummyMetric()

    a.add_state("a", tensor(0), "sum", persistent=True)
    assert "a" in a.state_dict()

    a.add_state("b", tensor(0), "sum", persistent=False)

    # ``register_buffer`` only accepts ``persistent=`` from torch 1.6 onwards
    # (assuming ``add_state`` registers tensor defaults as buffers -- TODO
    # confirm), so the exclusion of "b" can only be asserted on torch >= 1.6.
    # The ``not`` matters: this check must run on *new* torch versions.
    if not _TORCH_LOWER_1_6:
        assert "b" not in a.state_dict()
Example #4
0
def test_add_state_persistent():
    """Check that ``add_state(..., persistent=...)`` decides ``state_dict`` membership."""
    metric = DummyMetric()

    metric.add_state("a", torch.tensor(0), "sum", persistent=True)
    assert "a" in metric.state_dict()

    metric.add_state("b", torch.tensor(0), "sum", persistent=False)

    # Non-persistent exclusion is only asserted on torch >= 1.6, the first
    # release whose ``register_buffer`` accepts ``persistent``.
    supports_non_persistent = LooseVersion(torch.__version__) >= LooseVersion("1.6.0")
    if supports_non_persistent:
        assert "b" not in metric.state_dict()
Example #5
0
def _test_ddp_sum_cat(rank, worldsize):
    """Per-process DDP body: sync one ``cat``-reduced and one ``sum``-reduced state.

    After ``_sync_dist``, ``foo`` is expected to equal ``tensor([1, 1])``
    (presumably run with ``worldsize == 2``) and ``bar`` to equal ``worldsize``.
    """
    setup_ddp(rank, worldsize)
    metric = DummyMetric()
    metric._reductions = {"foo": torch.cat, "bar": torch.sum}
    metric.foo = [tensor([1])]
    metric.bar = tensor(1)
    metric._sync_dist()

    expected_foo = tensor([1, 1])
    foo_matches = torch.eq(metric.foo, expected_foo)
    assert torch.all(foo_matches)
    assert metric.bar == worldsize
Example #6
0
 class TestModule(nn.Module):
     """Module wrapping a ``DummyMetric`` with persistent states ``a``/``b`` and a buffer ``c``."""

     def __init__(self):
         super().__init__()
         self.metric = DummyMetric()
         # Tensor state ``a`` and list state ``b``, both kept in the state dict.
         for state_name, default in (('a', tensor(0)), ('b', [])):
             self.metric.add_state(state_name, default, persistent=True)
         # Plain buffer registered directly, bypassing ``add_state``.
         self.metric.register_buffer('c', tensor(0))
Example #7
0
def _test_ddp_sum(rank, worldsize):
    """Per-process DDP body: a ``sum``-reduced state of 1 per rank syncs to ``worldsize``."""
    setup_ddp(rank, worldsize)
    metric = DummyMetric()
    metric._reductions = dict(foo=torch.sum)
    metric.foo = tensor(1)
    metric._sync_dist()

    assert metric.foo == worldsize
Example #8
0
def _test_ddp_cat(rank, worldsize):
    """Per-process DDP body: a ``cat``-reduced list state is gathered into one tensor.

    The expected result ``[1, 1]`` presumes ``worldsize == 2`` -- confirm at the call site.
    """
    setup_ddp(rank, worldsize)
    metric = DummyMetric()
    metric._reductions = dict(foo=torch.cat)
    metric.foo = [torch.tensor([1])]
    metric._sync_dist()

    elementwise_equal = torch.eq(metric.foo, torch.tensor([1, 1]))
    assert torch.all(elementwise_equal)
Example #9
0
def test_add_state():
    """Check reduction resolution and argument validation in ``add_state``."""
    metric = DummyMetric()

    # String names resolve to built-in reductions.
    metric.add_state("a", tensor(0), "sum")
    assert metric._reductions["a"](tensor([1, 1])) == 2

    metric.add_state("b", tensor(0), "mean")
    mean_value = metric._reductions["b"](tensor([1.0, 2.0])).numpy()
    assert np.allclose(mean_value, 1.5)

    metric.add_state("c", tensor(0), "cat")
    concatenated = metric._reductions["c"]([tensor([1]), tensor([1])])
    assert concatenated.shape == (2, )

    # Each combination must raise ValueError: an unknown reduction name, a
    # non-callable reduction, a non-empty list default, a non-tensor default.
    invalid_calls = [
        ("d1", tensor(0), 'xyz'),
        ("d2", tensor(0), 42),
        ("d3", [tensor(0)], 'sum'),
        ("d4", 42, 'sum'),
    ]
    for state_name, default, reduce_fx in invalid_calls:
        with pytest.raises(ValueError):
            metric.add_state(state_name, default, reduce_fx)

    # A user-supplied callable is stored and used as the reduction.
    def custom_fx(_):
        return -1

    metric.add_state("e", tensor(0), custom_fx)
    assert metric._reductions["e"](tensor([1, 1])) == -1
Example #10
0
def test_metric_scripts():
    """Check that ``DummyMetric`` and ``DummyMetricSum`` instances compile with TorchScript."""
    for metric_cls in (DummyMetric, DummyMetricSum):
        torch.jit.script(metric_cls())
Example #11
0
def test_inherit():
    """Check that ``DummyMetric`` can be constructed without arguments."""
    _ = DummyMetric()
Example #12
0
def test_inherit():
    """Test that a metric which inherits from the base class can be instantiated."""
    _ = DummyMetric()
Example #13
0
def test_metric_scripts():
    """Test that metrics are scriptable with ``torch.jit.script``."""
    for metric_cls in (DummyMetric, DummyMetricSum):
        torch.jit.script(metric_cls())
Example #14
0
 def __init__(self):
     """Attach a ``DummyMetric`` with persistent states ``a``/``b`` and an extra buffer ``c``."""
     super().__init__()
     self.metric = DummyMetric()
     for state_name, default in (("a", tensor(0)), ("b", [])):
         self.metric.add_state(state_name, default, persistent=True)
     # ``c`` is registered directly as a buffer rather than through ``add_state``.
     self.metric.register_buffer("c", tensor(0))
Example #15
0
 def __init__(self):
     """Attach a ``DummyMetric`` with persistent states ``a``/``b`` and an extra buffer ``c``."""
     super().__init__()
     self.metric = DummyMetric()
     persistent_states = {'a': torch.tensor(0), 'b': []}
     for state_name, default in persistent_states.items():
         self.metric.add_state(state_name, default, persistent=True)
     # ``c`` is registered directly as a buffer rather than through ``add_state``.
     self.metric.register_buffer('c', torch.tensor(0))