def test_reset_compute():
    """Check that ``reset`` brings the summed state back to zero."""
    metric = DummyMetricSum()
    assert metric.x == 0

    metric.update(tensor(5))
    assert metric.compute() == 5

    metric.reset()
    assert metric.compute() == 0
def test_load_state_dict(tmpdir):
    """Test that metric states can be loaded with state dict.

    A metric is updated, serialized via ``state_dict`` and loaded into a fresh
    instance. The fresh instance must end up with the same accumulated state.
    Asserting only on the source metric, as before, would pass even if
    ``load_state_dict`` restored nothing.

    ``tmpdir`` is unused; it is kept so the fixture signature stays unchanged.
    """
    metric = DummyMetricSum()
    # presumably needed so the metric state is included in ``state_dict``
    metric.persistent(True)
    metric.update(5)

    loaded_metric = DummyMetricSum()
    loaded_metric.load_state_dict(metric.state_dict())

    # serializing must not disturb the source metric
    assert metric.compute() == 5
    # the restored metric must carry the same accumulated state
    assert loaded_metric.x == 5
def _test_ddp_compositional_tensor(rank, worldsize):
    """Check that a composed (``a + b``) metric reduces correctly under DDP.

    Two clones of a sum metric are added together. Each receives
    ``tensor(1)`` on every rank, and the composed result is expected to equal
    ``2 * worldsize``.
    """
    setup_ddp(rank, worldsize)

    base = DummyMetricSum()
    # reduce the "x" state with torch.sum; the clones below inherit this
    base._reductions = {"x": torch.sum}

    composed = base.clone() + base.clone()
    composed.update(tensor(1))

    result = composed.compute()
    assert result == 2 * worldsize
def test_reset_compute():
    """Check ``compute`` after ``reset``; the expected value depends on Lightning.

    Without Lightning, or with Lightning >= 1.3, ``compute`` after ``reset``
    yields 0. With an older Lightning installed it still reports 5.
    """
    metric = DummyMetricSum()
    assert metric.x == 0

    metric.update(tensor(5))
    assert metric.compute() == 5

    metric.reset()
    expected = 5 if (_LIGHTNING_AVAILABLE and not _LIGHTNING_GREATER_EQUAL_1_3) else 0
    assert metric.compute() == expected
def test_pickle(tmpdir):
    """Check that a metric survives ``pickle`` and ``cloudpickle`` round trips.

    This does not cover the DDP case.
    """
    metric = DummyMetricSum()
    metric.update(1)

    # stdlib pickle: the restored copy keeps its state and can keep accumulating
    restored = pickle.loads(pickle.dumps(metric))
    assert restored.compute() == 1
    restored.update(5)
    assert restored.compute() == 6

    # cloudpickle round trip of the original metric, which was not updated again
    restored = cloudpickle.loads(cloudpickle.dumps(metric))
    assert restored.compute() == 1
def test_warning_on_compute_before_update():
    """Test that a warning is raised if ``compute`` is called before ``update``.

    "No warning" is checked with ``warnings.catch_warnings(record=True)``
    rather than ``pytest.warns(None)``. The latter was deprecated in pytest 7
    and is rejected by pytest 8.
    """
    import warnings

    metric = DummyMetricSum()

    # make sure everything is fine with forward
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")  # record every warning, like pytest.warns(None)
        metric(1)
    assert not record

    metric.reset()

    with pytest.warns(UserWarning, match=r'The ``compute`` method of metric .*'):
        val = metric.compute()
    assert val == 0.0

    # after update things should be fine
    metric.update(2.0)
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        val = metric.compute()
    assert not record
    assert val == 2.0
def test_warning_on_compute_before_update():
    """Test that a warning is raised if a user calls ``compute`` before ``update``.

    "No warning" is checked with ``warnings.catch_warnings(record=True)``
    rather than ``pytest.warns(None)``. The latter was deprecated in pytest 7
    and is rejected by pytest 8.
    """
    import warnings

    metric = DummyMetricSum()

    # make sure everything is fine with forward
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")  # record every warning, like pytest.warns(None)
        metric(1)
    assert not record

    metric.reset()

    with pytest.warns(UserWarning, match=r"The ``compute`` method of metric .*"):
        val = metric.compute()
    assert val == 0.0

    # after update things should be fine
    metric.update(2.0)
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        val = metric.compute()
    assert not record
    assert val == 2.0
def test_constant_memory(device, requires_grad):
    """Checks that repeatedly updating a metric does not increase memory.

    Both the ``update`` method and the forward path (calling the metric
    directly) are exercised. Each runs once as a warm-up. The next ten calls
    must then stay within 5% of the memory level measured after the warm-up.
    """
    if not torch.cuda.is_available() and device == "cuda":
        pytest.skip("Test requires GPU support")

    def get_memory_usage():
        # CPU: first field of psutil's memory_info (RSS) in GiB.
        # Otherwise: torch.cuda.memory_allocated() in bytes.
        # Units differ, but only relative comparisons on one device are made.
        if device == "cpu":
            pid = os.getpid()
            py = psutil.Process(pid)
            return py.memory_info()[0] / 2.0**30
        else:
            return torch.cuda.memory_allocated()

    def assert_constant_memory(step):
        # warm up once so one-off allocations become part of the baseline
        step()
        # we allow for 5% fluctuation due to measuring
        base_memory_level = 1.05 * get_memory_usage()
        for _ in range(10):
            step()
            memory = get_memory_usage()
            assert base_memory_level >= memory, "memory increased above base level"

    x = torch.randn(10, requires_grad=requires_grad, device=device)

    # try update method
    update_metric = DummyMetricSum().to(device)
    assert_constant_memory(lambda: update_metric.update(x.sum()))

    # try forward method; the loop must call the metric itself, not ``update``
    forward_metric = DummyMetricSum().to(device)
    assert_constant_memory(lambda: forward_metric(x.sum()))