Example 1
def test_reset_compute():
    a = DummyMetricSum()
    assert a.x == 0
    a.update(tensor(5))
    assert a.compute() == 5
    a.reset()
    assert a.compute() == 0
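All of the examples use a DummyMetricSum helper that is not shown in the excerpts. A minimal sketch of such a metric, assuming the torchmetrics.Metric base class (the real test helper may differ in its details):

from torch import tensor
from torchmetrics import Metric


class DummyMetricSum(Metric):
    """Keeps a running sum in a single tensor state ``x``."""

    def __init__(self):
        super().__init__()
        # no automatic cross-process reduction here; Example 3 installs torch.sum by hand
        self.add_state("x", default=tensor(0.0), dist_reduce_fx=None)

    def update(self, x):
        self.x += x

    def compute(self):
        return self.x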
Example 2
def test_load_state_dict(tmpdir):
    """ test that metric states can be loaded with state dict """
    metric = DummyMetricSum()
    metric.persistent(True)
    metric.update(5)
    loaded_metric = DummyMetricSum()
    loaded_metric.load_state_dict(metric.state_dict())
    assert metric.compute() == 5
    assert loaded_metric.compute() == 5
Example 3
def _test_ddp_compositional_tensor(rank, worldsize):
    setup_ddp(rank, worldsize)
    dummy = DummyMetricSum()
    dummy._reductions = {"x": torch.sum}
    dummy = dummy.clone() + dummy.clone()
    dummy.update(tensor(1))
    val = dummy.compute()
    assert val == 2 * worldsize
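Example 3 only shows the per-rank worker: the metric is composed with itself, so each rank's composed result is 2, and the manually installed torch.sum reduction makes the final value 2 * worldsize. The setup_ddp helper and the launcher are not shown; a minimal sketch, assuming a local gloo process group and torch.multiprocessing.spawn:

import os

import torch
import torch.multiprocessing as mp


def setup_ddp(rank, worldsize):
    # hypothetical helper: every rank joins the same local "gloo" process group
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12355")
    torch.distributed.init_process_group("gloo", rank=rank, world_size=worldsize)


def test_ddp_compositional_tensor():
    # launch one process per rank; each runs the worker from Example 3
    worldsize = 2
    mp.spawn(_test_ddp_compositional_tensor, args=(worldsize,), nprocs=worldsize)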
Example 4
def test_reset_compute():
    a = DummyMetricSum()
    assert a.x == 0
    a.update(tensor(5))
    assert a.compute() == 5
    a.reset()
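    # Lightning < 1.3 keeps returning the cached result after reset();
    # newer versions (or no Lightning at all) fall back to the default state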
    if not _LIGHTNING_AVAILABLE or _LIGHTNING_GREATER_EQUAL_1_3:
        assert a.compute() == 0
    else:
        assert a.compute() == 5
Example 5
def test_pickle(tmpdir):
    # does not test DDP
    a = DummyMetricSum()
    a.update(1)

    metric_pickled = pickle.dumps(a)
    metric_loaded = pickle.loads(metric_pickled)

    assert metric_loaded.compute() == 1

    metric_loaded.update(5)
    assert metric_loaded.compute() == 6

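    # the same metric should also round-trip through cloudpickle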
    metric_pickled = cloudpickle.dumps(a)
    metric_loaded = cloudpickle.loads(metric_pickled)

    assert metric_loaded.compute() == 1
Example 6
def test_warning_on_compute_before_update():
    metric = DummyMetricSum()

    # make sure everything is fine with forward
    with pytest.warns(None) as record:
        val = metric(1)
    assert not record

    metric.reset()

    with pytest.warns(UserWarning, match=r'The ``compute`` method of metric .*'):
        val = metric.compute()
    assert val == 0.0

    # after update things should be fine
    metric.update(2.0)
    with pytest.warns(None) as record:
        val = metric.compute()
    assert not record
    assert val == 2.0
Example 7
def test_warning_on_compute_before_update():
    """test that an warning is raised if user tries to call compute before update."""
    metric = DummyMetricSum()

    # make sure everything is fine with forward
    with pytest.warns(None) as record:
        val = metric(1)
    assert not record

    metric.reset()

    with pytest.warns(UserWarning, match=r"The ``compute`` method of metric .*"):
        val = metric.compute()
    assert val == 0.0

    # after update things should be fine
    metric.update(2.0)
    with pytest.warns(None) as record:
        val = metric.compute()
    assert not record
    assert val == 2.0
Example 8
# parametrization assumed here; the excerpt omits the test's decorators
@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("requires_grad", [True, False])
def test_constant_memory(device, requires_grad):
    """Checks that when updating a metric the memory does not increase."""
    if not torch.cuda.is_available() and device == "cuda":
        pytest.skip("Test requires GPU support")

    def get_memory_usage():
        if device == "cpu":
            pid = os.getpid()
            py = psutil.Process(pid)
            return py.memory_info()[0] / 2.0**30  # resident set size in GiB
        else:
            return torch.cuda.memory_allocated()

    x = torch.randn(10, requires_grad=requires_grad, device=device)

    # try update method
    metric = DummyMetricSum().to(device)

    metric.update(x.sum())

    # we allow for 5% fluctuation due to measurement noise
    base_memory_level = 1.05 * get_memory_usage()

    for _ in range(10):
        metric.update(x.sum())
        memory = get_memory_usage()
        assert base_memory_level >= memory, "memory increased above base level"

    # try forward method
    metric = DummyMetricSum().to(device)
    metric(x.sum())
    base_memory_level = get_memory_usage()

    for _ in range(10):
        metric(x.sum())
        memory = get_memory_usage()
        assert base_memory_level >= memory, "memory increased above base level"