예제 #1
0
def _ddp_test_fn(rank, worldsize):
    """Check ``Result`` logging of ``DummyMetric`` objects inside one DDP worker.

    Runs three epochs of five steps. On every step each of three metrics is
    called with the step index and then logged with a different
    ``on_step``/``on_epoch`` combination; the batch-level log is verified
    after each step and the epoch-level log after each epoch.

    Args:
        rank: Forwarded to ``_setup_ddp``; presumably this worker's process
            index (confirm against ``_setup_ddp``).
        worldsize: Forwarded to ``_setup_ddp`` and used as the multiplier for
            the expected epoch-level values.
    """
    _setup_ddp(rank, worldsize)
    # NOTE(review): the tensor is created and immediately discarded; it looks
    # like it has no effect on the test -- confirm before removing.
    torch.tensor([1.0])

    def _assert_logged(logged, wanted):
        # Same key set, and every expected value equals the logged one.
        assert set(logged.keys()) == set(wanted.keys())
        for name, value in wanted.items():
            assert value == logged[name]

    first = DummyMetric()
    second = DummyMetric()
    third = DummyMetric()
    tracked = (first, second, third)

    # dist_sync_on_step is False by default
    result = Result()

    for _ in range(3):
        running_total = 0

        for step in range(5):
            for metric in tracked:
                metric(step)

            running_total += step

            # 'a' is logged at both levels, 'b' per epoch only, 'c' per step only.
            result.log('a', first, on_step=True, on_epoch=True)
            result.log('b', second, on_step=False, on_epoch=True)
            result.log('c', third, on_step=True, on_epoch=False)

            _assert_logged(
                result.get_batch_log_metrics(),
                {"a_step": step, "a": step, "c": step},
            )

        epoch_log = result.get_epoch_log_metrics()
        result.reset()

        # assert metric state reset to default values
        for metric in tracked:
            assert metric.x == metric._defaults['x']

        # Epoch-level values are expected to be the per-worker sum times the
        # number of workers.
        synced_total = running_total * worldsize
        _assert_logged(epoch_log, {"b": synced_total, "a_epoch": synced_total})
예제 #2
0
def test_result_metric_integration():
    """Check ``Result`` logging of ``DummyMetric`` objects in a single process.

    Runs three epochs of five steps. On every step each of three metrics is
    called with the step index and then logged with a different
    ``on_step``/``on_epoch`` combination. After each step the batch-level log
    is compared against the step index; after each epoch the epoch-level log
    is compared against the sum of the step indices, ``result`` is reset, and
    the metrics are checked to be back at their default state.
    """
    metric_a = DummyMetric()
    metric_b = DummyMetric()
    metric_c = DummyMetric()

    result = Result()

    for _ in range(3):
        cumulative_sum = 0

        for i in range(5):
            metric_a(i)
            metric_b(i)
            metric_c(i)

            cumulative_sum += i

            # 'a' is logged at both levels, 'b' per epoch only, 'c' per step only.
            result.log('a', metric_a, on_step=True, on_epoch=True)
            result.log('b', metric_b, on_step=False, on_epoch=True)
            result.log('c', metric_c, on_step=True, on_epoch=False)

            batch_log = result.get_batch_log_metrics()
            batch_expected = {"a_step": i, "a": i, "c": i}
            assert set(batch_log.keys()) == set(batch_expected.keys())
            for key, expected in batch_expected.items():
                assert expected == batch_log[key]

        epoch_log = result.get_epoch_log_metrics()
        # Reset before checking metric state: without this call nothing in
        # this test clears the metrics, so the default-state assertions below
        # and the per-epoch sums of later epochs could not hold. This mirrors
        # the sequence used in `_ddp_test_fn`.
        result.reset()

        # assert metric state reset to default values
        assert metric_a.x == metric_a._defaults['x']
        assert metric_b.x == metric_b._defaults['x']
        assert metric_c.x == metric_c._defaults['x']

        # Same epoch-level key set as `_ddp_test_fn`, which issues identical
        # `log` calls: only 'b' and the '_epoch' variant of 'a' are expected.
        epoch_expected = {
            "b": cumulative_sum,
            "a_epoch": cumulative_sum,
        }

        assert set(epoch_log.keys()) == set(epoch_expected.keys())
        for key, expected in epoch_expected.items():
            assert expected == epoch_log[key]