def _ddp_test_fn(rank, worldsize):
    """Check ``Result`` logging of ``DummyMetric`` objects inside one DDP worker.

    After ``_setup_ddp(rank, worldsize)``, three metrics are updated with the
    step index for 5 steps in each of 3 epochs and logged under the names
    ``'a'`` (step and epoch), ``'b'`` (epoch only) and ``'c'`` (step only).

    Asserted on every step: the batch log holds exactly the keys ``a_step``,
    ``a`` and ``c``, each equal to the current step index.

    Asserted on every epoch: after ``result.reset()`` each metric's ``x``
    equals its ``_defaults['x']``, and the epoch log holds exactly ``b`` and
    ``a_epoch``, each equal to the sum of the step indices times ``worldsize``.

    Args:
        rank: forwarded to ``_setup_ddp``.
        worldsize: forwarded to ``_setup_ddp`` and used as the multiplier for
            the expected epoch values.
    """
    _setup_ddp(rank, worldsize)
    # The tensor is created and immediately discarded.
    # NOTE(review): its purpose is not evident from this function - confirm
    # before removing.
    torch.tensor([1.0])

    first, second, third = DummyMetric(), DummyMetric(), DummyMetric()
    tracked = (first, second, third)
    # (logged name, metric, on_step, on_epoch)
    log_plan = (
        ('a', first, True, True),
        ('b', second, False, True),
        ('c', third, True, False),
    )

    # dist_sync_on_step is False by default
    result = Result()

    for _ in range(3):
        running_total = 0

        for step in range(5):
            for metric in tracked:
                metric(step)
            running_total += step

            for name, metric, on_step, on_epoch in log_plan:
                result.log(name, metric, on_step=on_step, on_epoch=on_epoch)

            batch_log = result.get_batch_log_metrics()
            batch_expected = {"a_step": step, "a": step, "c": step}
            assert set(batch_log.keys()) == set(batch_expected)
            for key, expected in batch_expected.items():
                assert expected == batch_log[key]

        # Grab the epoch-level values before the reset below.
        epoch_log = result.get_epoch_log_metrics()
        result.reset()

        # assert metric state reset to default values
        for metric in tracked:
            assert metric.x == metric._defaults['x']

        scaled_total = running_total * worldsize
        epoch_expected = {"b": scaled_total, "a_epoch": scaled_total}

        assert set(epoch_log.keys()) == set(epoch_expected)
        for key, expected in epoch_expected.items():
            assert expected == epoch_log[key]
def test_result_metric_integration():
    """Check ``Result`` logging of ``DummyMetric`` objects in a single process.

    Three metrics are updated with the step index for 5 steps in each of 3
    epochs and logged under the names ``'a'`` (step and epoch), ``'b'``
    (epoch only) and ``'c'`` (step only).

    Asserted on every step: the batch log holds exactly the keys ``a_step``,
    ``a`` and ``c``, each equal to the current step index.

    Asserted on every epoch: after ``result.reset()`` each metric's ``x``
    equals its ``_defaults['x']``, and the epoch log holds exactly ``b`` and
    ``a_epoch``, each equal to the sum of the step indices of that epoch.
    """
    first, second, third = DummyMetric(), DummyMetric(), DummyMetric()
    tracked = (first, second, third)
    # (logged name, metric, on_step, on_epoch)
    log_plan = (
        ('a', first, True, True),
        ('b', second, False, True),
        ('c', third, True, False),
    )

    result = Result()

    for _ in range(3):
        running_total = 0

        for step in range(5):
            for metric in tracked:
                metric(step)
            running_total += step

            for name, metric, on_step, on_epoch in log_plan:
                result.log(name, metric, on_step=on_step, on_epoch=on_epoch)

            batch_log = result.get_batch_log_metrics()
            batch_expected = {"a_step": step, "a": step, "c": step}
            assert set(batch_log.keys()) == set(batch_expected)
            for key, expected in batch_expected.items():
                assert expected == batch_log[key]

        # Grab the epoch-level values before the reset below.
        epoch_log = result.get_epoch_log_metrics()
        result.reset()

        # assert metric state reset to default values
        for metric in tracked:
            assert metric.x == metric._defaults['x']

        epoch_expected = {"b": running_total, "a_epoch": running_total}

        assert set(epoch_log.keys()) == set(epoch_expected)
        for key, expected in epoch_expected.items():
            assert expected == epoch_log[key]