def test_result_collection_restoration(tmpdir):
    """This test makes sure metrics are properly reloaded on failure.

    It simulates two training epochs of logging, snapshots the ``ResultCollection``
    state dict mid-epoch, restores it into a fresh collection, and asserts the
    restored collection matches the original (including the sync fn, which is
    dropped from the serialized state and re-attached on load).
    """
    result = ResultCollection(True, torch.device("cpu"))
    metric_a = DummyMetric()
    metric_b = DummyMetric()
    metric_c = DummyMetric()
    metric_d = DummyMetric()
    current_fx_name = None
    batch_idx = None

    def lightning_log(fx, *args, **kwargs):
        # Mimics LightningModule.log: reset the per-hook results when a new
        # hook (fx) starts logging at the beginning of an epoch/batch sequence.
        nonlocal current_fx_name
        if current_fx_name != fx and batch_idx in (None, 0):
            result.reset(metrics=False, fx=fx)
        result.log(fx, *args, **kwargs, sync_dist_fn=my_sync_dist)
        current_fx_name = fx

    for epoch in range(2):
        cumulative_sum = 0

        for i in range(3):
            # Update the metric objects and keep their per-step outputs.
            a = metric_a(i)
            b = metric_b(i)
            c = metric_c(i)
            metric_d(i)

            cumulative_sum += i

            # Alternate which metric object backs the "a" key: metric_a on the
            # first step, metric_d afterwards.
            metric = metric_a if i < 1 else metric_d
            lightning_log("training_step", "a", metric, on_step=True, on_epoch=True, metric_attribute="metric")
            lightning_log("training_step", "b", metric_b, on_step=False, on_epoch=True, metric_attribute="metric_b")
            lightning_log("training_step", "c", metric_c, on_step=True, on_epoch=False, metric_attribute="metric_c")
            lightning_log("training_step", "a_1", a, on_step=True, on_epoch=True)
            lightning_log("training_step", "b_1", b, on_step=False, on_epoch=True)
            lightning_log("training_step", "c_1", {"1": c, "2": c}, on_step=True, on_epoch=False)

            # Only the on_step entries should appear in the batch-level log.
            batch_log = result.metrics(on_step=True)["log"]
            assert set(batch_log) == {"a_step", "c", "a_1_step", "c_1"}
            assert set(batch_log["c_1"]) == {"1", "2"}

            # Snapshot, then restore into a brand-new collection.
            result_copy = deepcopy(result)
            new_result = ResultCollection(True, torch.device("cpu"))
            state_dict = result.state_dict()

            # check the sync fn was dropped
            assert "fn" not in state_dict["items"]["training_step.a"]["meta"][
                "_sync"]

            assert not new_result.result_metrics
            # 7 tracked results in epoch 0; the epoch-end "a" log adds one more
            # from epoch 1 onwards.
            assert len(result.result_metrics) == 7 + epoch > 0

            # Restoring requires the live metric objects to be re-linked by
            # their ``metric_attribute`` names.
            new_result.load_state_dict(state_dict, metrics={
                "metric": metric,
                "metric_b": metric_b,
                "metric_c": metric_c
            })

            # should match
            assert result_copy == new_result
            # the sync fn has been kept
            assert result_copy["training_step.a"].meta.sync.fn == new_result[
                "training_step.a"].meta.sync.fn

        # Epoch-level logs of the restored copy must match the original.
        epoch_log = result.metrics(on_step=False)["log"]
        epoch_log_copy = result_copy.metrics(on_step=False)["log"]
        assert epoch_log == epoch_log_copy

        lightning_log("train_epoch_end", "a", metric_a, on_step=False, on_epoch=True)

        epoch_log = result.metrics(on_step=False)["log"]
        assert epoch_log == {
            "a_1_epoch": 1,
            "a_epoch": cumulative_sum,
            "a": cumulative_sum,
            "b": cumulative_sum,
            "b_1": 1,
        }

        # make sure can be pickled
        pickle.loads(pickle.dumps(result))
        # make sure can be torch.loaded
        filepath = str(tmpdir / "result")
        torch.save(result, filepath)
        torch.load(filepath)

        # assert metric state reset to default values
        result.reset()
        assert metric_a.x == metric_a._defaults["x"]
        assert metric_b.x == metric_b._defaults["x"]
        assert metric_c.x == metric_c._defaults["x"]

        batch_idx = None
# NOTE(review): duplicate definition of ``test_result_collection_restoration``
# — if both versions live in the same module this one shadows the earlier one;
# confirm which variant is intended to run.
def test_result_collection_restoration(tmpdir):
    """This test makes sure metrics are properly reloaded on failure.

    Two simulated training epochs log metrics through a ``ResultCollection``;
    after every batch the state dict is serialized and restored into a fresh
    collection, which must compare equal to a deep copy of the original.
    """
    result = ResultCollection(True, torch.device("cpu"))
    metric_a = DummyMetric()
    metric_b = DummyMetric()
    metric_c = DummyMetric()
    metric_d = DummyMetric()
    current_fx_name = None
    batch_idx = None

    def lightning_log(fx, *args, **kwargs):
        # Mimics LightningModule.log: reset the per-hook results when a new
        # hook (fx) starts logging at the start of a batch sequence.
        nonlocal current_fx_name
        if current_fx_name != fx and batch_idx in (None, 0):
            result.reset(metrics=False, fx=fx)
        result.log(fx, *args, **kwargs, sync_dist_fn=my_sync_dist)
        current_fx_name = fx

    for _ in range(2):
        cumulative_sum = 0

        for i in range(3):
            # Update the metric objects and keep their per-step outputs.
            a = metric_a(i)
            b = metric_b(i)
            c = metric_c(i)
            metric_d(i)

            cumulative_sum += i

            # Alternate which metric object backs the "a" key: metric_a on the
            # first step, metric_d afterwards.
            metric = metric_a if i < 1 else metric_d
            lightning_log('training_step', 'a', metric, on_step=True, on_epoch=True)
            lightning_log('training_step', 'b', metric_b, on_step=False, on_epoch=True)
            lightning_log('training_step', 'c', metric_c, on_step=True, on_epoch=False)
            lightning_log('training_step', 'a_1', a, on_step=True, on_epoch=True)
            lightning_log('training_step', 'b_1', b, on_step=False, on_epoch=True)
            lightning_log('training_step', 'c_1', {
                '1': c,
                '2': c
            }, on_step=True, on_epoch=False)

            # Only the on_step entries should appear in the batch-level log.
            batch_log = result.metrics(on_step=True)[MetricSource.LOG]
            assert set(batch_log) == {"a_step", "c", "a_1_step", "c_1"}
            assert set(batch_log['c_1']) == {'1', '2'}

            # Snapshot, then restore into a brand-new collection.
            result_copy = deepcopy(result)
            new_result = ResultCollection(True, torch.device("cpu"))
            state_dict = result.state_dict()

            # check the sync fn was dropped
            assert 'fn' not in state_dict['items']['training_step.a']['meta'][
                '_sync']

            new_result.load_state_dict(state_dict)

            # should match
            assert result_copy == new_result
            # the sync fn has been kept
            assert result_copy['training_step.a'].meta.sync.fn == new_result[
                'training_step.a'].meta.sync.fn

        # Epoch-level logs of the restored copy must match the original.
        epoch_log = result.metrics(on_step=False)[MetricSource.LOG]
        epoch_log_copy = result_copy.metrics(on_step=False)[MetricSource.LOG]
        assert epoch_log == epoch_log_copy

        lightning_log('train_epoch_end', 'a', metric_a,
                      on_step=False, on_epoch=True)

        epoch_log = result.metrics(on_step=False)[MetricSource.LOG]
        assert epoch_log == {
            'a_1_epoch': 1,
            'a_epoch': cumulative_sum,
            'a': cumulative_sum,
            'b': cumulative_sum,
            'b_1': 1
        }

        # make sure can be pickled
        pickle.loads(pickle.dumps(result))
        # make sure can be torch.loaded
        filepath = str(tmpdir / 'result')
        torch.save(result, filepath)
        torch.load(filepath)

        # assert metric state reset to default values
        result.reset()
        assert metric_a.x == metric_a._defaults['x']
        assert metric_b.x == metric_b._defaults['x']
        assert metric_c.x == metric_c._defaults['x']

        batch_idx = None