Example no. 1
def test_result_metric_integration():
    """Run three fake epochs of metric logging and validate the batch/epoch
    logs as well as the ``str``/``repr`` output of the ``ResultCollection``."""
    metrics = [DummyMetric() for _ in range(3)]
    metric_a, metric_b, metric_c = metrics

    result = ResultCollection(True, torch.device("cpu"))

    for _ in range(3):
        running_total = 0
        for step in range(5):
            # update every metric with the current step value
            for m in metrics:
                m(step)

            running_total += step

            result.log("h", "a", metric_a, on_step=True, on_epoch=True)
            result.log("h", "b", metric_b, on_step=False, on_epoch=True)
            result.log("h", "c", metric_c, on_step=True, on_epoch=False)

            # only the on_step=True entries show up in the batch-level log
            assert result.metrics(True)["log"] == {"a_step": step, "c": step}

        epoch_log = result.metrics(False)["log"]
        result.reset()

        # reset() should restore each metric's state to its default
        for m in metrics:
            assert m.x == m._defaults["x"]

        # only the on_epoch=True entries show up in the epoch-level log
        assert epoch_log == {"b": running_total, "a_epoch": running_total}

    result.minimize = torch.tensor(1.0)
    result.extra = {}
    expected_str = (
        "ResultCollection("
        "minimize=1.0, "
        "{"
        "'h.a': ResultMetric('a', value=DummyMetric()), "
        "'h.b': ResultMetric('b', value=DummyMetric()), "
        "'h.c': ResultMetric('c', value=DummyMetric())"
        "})"
    )
    assert str(result) == expected_str
    expected_repr = (
        "{"
        "True, "
        "device(type='cpu'), "
        "minimize=tensor(1.), "
        "{'h.a': ResultMetric('a', value=DummyMetric()), "
        "'h.b': ResultMetric('b', value=DummyMetric()), "
        "'h.c': ResultMetric('c', value=DummyMetric()), "
        "'_extra': {}}"
        "}"
    )
    assert repr(result) == expected_repr
Example no. 2
def _ddp_test_fn(rank, worldsize):
    """Worker function: run the metric-logging integration flow on one DDP
    rank and check the epoch aggregates are reduced across the world."""
    _setup_ddp(rank, worldsize)
    torch.tensor([1.0])

    device = f"cuda:{rank}"
    metric_a = DummyMetric().to(device)
    metric_b = DummyMetric().to(device)
    metric_c = DummyMetric().to(device)

    result = ResultCollection(True, torch.device(device))

    for _ in range(3):
        running_total = 0
        for step in range(5):
            metric_a(step)
            metric_b(step)
            metric_c(step)

            running_total += step

            result.log("h", "a", metric_a, on_step=True, on_epoch=True)
            result.log("h", "b", metric_b, on_step=False, on_epoch=True)
            result.log("h", "c", metric_c, on_step=True, on_epoch=False)

            # only the on_step=True entries show up in the batch-level log
            assert result.metrics(True)["log"] == {"a_step": step, "c": step}

        epoch_log = result.metrics(False)["log"]
        result.reset()

        # reset() should restore each metric's state to its default
        assert metric_a.x == metric_a._defaults["x"], (metric_a.x,
                                                       metric_a._defaults["x"])
        assert metric_b.x == metric_b._defaults["x"]
        assert metric_c.x == metric_c._defaults["x"]

        # epoch values are summed across ranks, hence scaled by the world size
        expected = running_total * worldsize
        assert epoch_log == {"b": expected, "a_epoch": expected}
def test_result_metric_integration():
    """Run three fake epochs of metric logging and validate the batch/epoch
    logs as well as the ``str`` output of the ``ResultCollection``."""
    metric_a, metric_b, metric_c = DummyMetric(), DummyMetric(), DummyMetric()

    result = ResultCollection(True, torch.device("cpu"))

    for _ in range(3):
        running_total = 0
        for step in range(5):
            metric_a(step)
            metric_b(step)
            metric_c(step)

            running_total += step

            result.log("h", "a", metric_a, on_step=True, on_epoch=True)
            result.log("h", "b", metric_b, on_step=False, on_epoch=True)
            result.log("h", "c", metric_c, on_step=True, on_epoch=False)

            # only the on_step=True entries show up in the batch-level log
            batch_log = result.metrics(True)[MetricSource.LOG]
            assert batch_log == {"a_step": step, "c": step}

        epoch_log = result.metrics(False)[MetricSource.LOG]
        result.reset()

        # reset() should restore each metric's state to its default
        for m in (metric_a, metric_b, metric_c):
            assert m.x == m._defaults["x"]

        # only the on_epoch=True entries show up in the epoch-level log
        assert epoch_log == {"b": running_total, "a_epoch": running_total}

    expected = (
        "ResultCollection(True, cpu, {"
        "'h.a': ResultMetric('a', value=DummyMetric()), "
        "'h.b': ResultMetric('b', value=DummyMetric()), "
        "'h.c': ResultMetric('c', value=DummyMetric())"
        "})"
    )
    assert str(result) == expected
Example no. 4
def test_result_collection_restoration(tmpdir):
    """This test makes sure metrics are properly reloaded on failure.

    Two fake training epochs are run; on every batch the ``ResultCollection``
    state dict is snapshotted and loaded into a fresh collection, which must
    compare equal to a deepcopy of the original — including the sync fn,
    which is dropped from the serialized state and re-attached on load.
    """

    result = ResultCollection(True, torch.device("cpu"))
    metric_a = DummyMetric()
    metric_b = DummyMetric()
    metric_c = DummyMetric()
    metric_d = DummyMetric()
    current_fx_name = None
    batch_idx = None

    def lightning_log(fx, *args, **kwargs):
        # Mimics how the Trainer logs: reset the per-hook state when a new
        # hook ``fx`` starts logging at the beginning of an epoch.
        nonlocal current_fx_name
        if current_fx_name != fx and batch_idx in (None, 0):
            result.reset(metrics=False, fx=fx)
        result.log(fx, *args, **kwargs, sync_dist_fn=my_sync_dist)
        current_fx_name = fx

    for epoch in range(2):

        cumulative_sum = 0

        for i in range(3):

            a = metric_a(i)
            b = metric_b(i)
            c = metric_c(i)
            metric_d(i)

            cumulative_sum += i

            # from the second batch on, a different metric object is logged
            # under the same key "a" to exercise ``metric_attribute`` mapping
            metric = metric_a if i < 1 else metric_d
            lightning_log("training_step",
                          "a",
                          metric,
                          on_step=True,
                          on_epoch=True,
                          metric_attribute="metric")
            lightning_log("training_step",
                          "b",
                          metric_b,
                          on_step=False,
                          on_epoch=True,
                          metric_attribute="metric_b")
            lightning_log("training_step",
                          "c",
                          metric_c,
                          on_step=True,
                          on_epoch=False,
                          metric_attribute="metric_c")
            lightning_log("training_step",
                          "a_1",
                          a,
                          on_step=True,
                          on_epoch=True)
            lightning_log("training_step",
                          "b_1",
                          b,
                          on_step=False,
                          on_epoch=True)
            lightning_log("training_step",
                          "c_1", {
                              "1": c,
                              "2": c
                          },
                          on_step=True,
                          on_epoch=False)

            # only the on_step=True entries appear in the batch-level log
            batch_log = result.metrics(on_step=True)["log"]
            assert set(batch_log) == {"a_step", "c", "a_1_step", "c_1"}
            assert set(batch_log["c_1"]) == {"1", "2"}

            result_copy = deepcopy(result)
            new_result = ResultCollection(True, torch.device("cpu"))
            state_dict = result.state_dict()
            # check the sync fn was dropped
            assert "fn" not in state_dict["items"]["training_step.a"]["meta"][
                "_sync"]

            assert not new_result.result_metrics
            # NOTE(review): presumably 7 items in epoch 0 (the "c_1" dict
            # contributing two result metrics) plus one more once
            # "train_epoch_end" has logged — confirm against ResultCollection
            assert len(result.result_metrics) == 7 + epoch > 0

            # restoring needs the live metric objects to re-bind by attribute
            new_result.load_state_dict(state_dict,
                                       metrics={
                                           "metric": metric,
                                           "metric_b": metric_b,
                                           "metric_c": metric_c
                                       })
            # should match
            assert result_copy == new_result
            # the sync fn has been kept
            assert result_copy["training_step.a"].meta.sync.fn == new_result[
                "training_step.a"].meta.sync.fn

        epoch_log = result.metrics(on_step=False)["log"]
        epoch_log_copy = result_copy.metrics(on_step=False)["log"]
        assert epoch_log == epoch_log_copy

        lightning_log("train_epoch_end",
                      "a",
                      metric_a,
                      on_step=False,
                      on_epoch=True)
        epoch_log = result.metrics(on_step=False)["log"]
        assert epoch_log == {
            "a_1_epoch": 1,
            "a_epoch": cumulative_sum,
            "a": cumulative_sum,
            "b": cumulative_sum,
            "b_1": 1,
        }

        # make sure can be pickled
        pickle.loads(pickle.dumps(result))
        # make sure can be torch.loaded
        filepath = str(tmpdir / "result")
        torch.save(result, filepath)
        torch.load(filepath)

        # assert metric state reset to default values
        result.reset()
        assert metric_a.x == metric_a._defaults["x"]
        assert metric_b.x == metric_b._defaults["x"]
        assert metric_c.x == metric_c._defaults["x"]

        batch_idx = None
Example no. 5
def test_result_collection_restoration(tmpdir):
    """This test makes sure metrics are properly reloaded on failure.

    Two fake training epochs are run; on every batch the ``ResultCollection``
    state dict is snapshotted and loaded into a fresh collection, which must
    compare equal to a deepcopy of the original — including the sync fn,
    which is dropped from the serialized state and re-attached on load.
    """

    result = ResultCollection(True, torch.device("cpu"))
    metric_a = DummyMetric()
    metric_b = DummyMetric()
    metric_c = DummyMetric()
    metric_d = DummyMetric()
    current_fx_name = None
    batch_idx = None

    def lightning_log(fx, *args, **kwargs):
        # Mimics how the Trainer logs: reset the per-hook state when a new
        # hook ``fx`` starts logging at the beginning of an epoch.
        nonlocal current_fx_name
        if current_fx_name != fx and batch_idx in (None, 0):
            result.reset(metrics=False, fx=fx)
        result.log(fx, *args, **kwargs, sync_dist_fn=my_sync_dist)
        current_fx_name = fx

    for _ in range(2):

        cumulative_sum = 0

        for i in range(3):

            a = metric_a(i)
            b = metric_b(i)
            c = metric_c(i)
            metric_d(i)

            cumulative_sum += i

            # from the second batch on, a different metric object is logged
            # under the same key 'a'
            metric = metric_a if i < 1 else metric_d
            lightning_log('training_step',
                          'a',
                          metric,
                          on_step=True,
                          on_epoch=True)
            lightning_log('training_step',
                          'b',
                          metric_b,
                          on_step=False,
                          on_epoch=True)
            lightning_log('training_step',
                          'c',
                          metric_c,
                          on_step=True,
                          on_epoch=False)
            lightning_log('training_step',
                          'a_1',
                          a,
                          on_step=True,
                          on_epoch=True)
            lightning_log('training_step',
                          'b_1',
                          b,
                          on_step=False,
                          on_epoch=True)
            lightning_log('training_step',
                          'c_1', {
                              '1': c,
                              '2': c
                          },
                          on_step=True,
                          on_epoch=False)

            # only the on_step=True entries appear in the batch-level log
            batch_log = result.metrics(on_step=True)[MetricSource.LOG]
            assert set(batch_log) == {"a_step", "c", "a_1_step", "c_1"}
            assert set(batch_log['c_1']) == {'1', '2'}

            result_copy = deepcopy(result)
            new_result = ResultCollection(True, torch.device("cpu"))
            state_dict = result.state_dict()
            # check the sync fn was dropped
            assert 'fn' not in state_dict['items']['training_step.a']['meta'][
                '_sync']
            new_result.load_state_dict(state_dict)
            # should match
            assert result_copy == new_result
            # the sync fn has been kept
            assert result_copy['training_step.a'].meta.sync.fn == new_result[
                'training_step.a'].meta.sync.fn

        epoch_log = result.metrics(on_step=False)[MetricSource.LOG]
        epoch_log_copy = result_copy.metrics(on_step=False)[MetricSource.LOG]
        assert epoch_log == epoch_log_copy

        lightning_log('train_epoch_end',
                      'a',
                      metric_a,
                      on_step=False,
                      on_epoch=True)
        epoch_log = result.metrics(on_step=False)[MetricSource.LOG]
        assert epoch_log == {
            'a_1_epoch': 1,
            'a_epoch': cumulative_sum,
            'a': cumulative_sum,
            'b': cumulative_sum,
            'b_1': 1
        }

        # make sure can be pickled
        pickle.loads(pickle.dumps(result))
        # make sure can be torch.loaded
        filepath = str(tmpdir / 'result')
        torch.save(result, filepath)
        torch.load(filepath)

        # assert metric state reset to default values
        result.reset()
        assert metric_a.x == metric_a._defaults['x']
        assert metric_b.x == metric_b._defaults['x']
        assert metric_c.x == metric_c._defaults['x']

        batch_idx = None