Example #1
import pickle
from copy import deepcopy

import torch

# Assumed imports: ``ResultCollection`` lives under this path in recent
# PyTorch Lightning 1.x releases (the exact module path varies by version);
# ``DummyMetric`` and ``my_sync_dist`` are helpers defined alongside this
# test in the Lightning test suite.
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection


def test_result_collection_restoration(tmpdir):
    """This test makes sure metrics are properly reloaded on failure."""

    result = ResultCollection(True, torch.device("cpu"))
    metric_a = DummyMetric()
    metric_b = DummyMetric()
    metric_c = DummyMetric()
    metric_d = DummyMetric()
    current_fx_name = None
    batch_idx = None

    def lightning_log(fx, *args, **kwargs):
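        # Emulate ``LightningModule.log``: when a hook starts logging on the
        # first batch, reset that hook's tensor results (keeping metric state),
        # then forward the call to ``ResultCollection.log``.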
        nonlocal current_fx_name
        if current_fx_name != fx and batch_idx in (None, 0):
            result.reset(metrics=False, fx=fx)
        result.log(fx, *args, **kwargs, sync_dist_fn=my_sync_dist)
        current_fx_name = fx

    for epoch in range(2):

        cumulative_sum = 0

        for i in range(3):

            a = metric_a(i)
            b = metric_b(i)
            c = metric_c(i)
            metric_d(i)

            cumulative_sum += i

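            # Alternate the object logged under "a" between ``metric_a`` and
            # ``metric_d`` so the same ``metric_attribute`` points at different
            # objects across batches.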
            metric = metric_a if i < 1 else metric_d
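            # Log a mix of ``Metric`` objects and raw tensor outputs under
            # different on_step/on_epoch combinations, including a dict value.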
            lightning_log("training_step",
                          "a",
                          metric,
                          on_step=True,
                          on_epoch=True,
                          metric_attribute="metric")
            lightning_log("training_step",
                          "b",
                          metric_b,
                          on_step=False,
                          on_epoch=True,
                          metric_attribute="metric_b")
            lightning_log("training_step",
                          "c",
                          metric_c,
                          on_step=True,
                          on_epoch=False,
                          metric_attribute="metric_c")
            lightning_log("training_step",
                          "a_1",
                          a,
                          on_step=True,
                          on_epoch=True)
            lightning_log("training_step",
                          "b_1",
                          b,
                          on_step=False,
                          on_epoch=True)
            lightning_log("training_step",
                          "c_1", {
                              "1": c,
                              "2": c
                          },
                          on_step=True,
                          on_epoch=False)

            batch_log = result.metrics(on_step=True)["log"]
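            # Only entries logged with on_step=True surface in the step-level
            # payload; the epoch-only keys ("b", "b_1") are held back.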
            assert set(batch_log) == {"a_step", "c", "a_1_step", "c_1"}
            assert set(batch_log["c_1"]) == {"1", "2"}

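            # Simulate a failure/restart: serialize the collection and restore
            # the state dict into a fresh ``ResultCollection``.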
            result_copy = deepcopy(result)
            new_result = ResultCollection(True, torch.device("cpu"))
            state_dict = result.state_dict()
            # check the sync fn was dropped
            assert "fn" not in state_dict["items"]["training_step.a"]["meta"][
                "_sync"]

            assert not new_result.result_metrics
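            # 6 keys are logged from ``training_step`` and "c_1" holds two
            # entries, giving 7 result metrics; epoch 1 adds one more from
            # ``train_epoch_end``.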
            assert len(result.result_metrics) == 7 + epoch > 0

            new_result.load_state_dict(state_dict,
                                       metrics={
                                           "metric": metric,
                                           "metric_b": metric_b,
                                           "metric_c": metric_c
                                       })
            # should match
            assert result_copy == new_result
            # the sync fn has been kept
            assert result_copy["training_step.a"].meta.sync.fn == new_result[
                "training_step.a"].meta.sync.fn

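        # The deep copy must aggregate to the same epoch-level payload as the
        # original collection.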
        epoch_log = result.metrics(on_step=False)["log"]
        epoch_log_copy = result_copy.metrics(on_step=False)["log"]
        assert epoch_log == epoch_log_copy

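        # Log ``metric_a`` again from the epoch-end hook; its epoch-level
        # value is added under the plain key "a".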
        lightning_log("train_epoch_end",
                      "a",
                      metric_a,
                      on_step=False,
                      on_epoch=True)
        epoch_log = result.metrics(on_step=False)["log"]
        assert epoch_log == {
            "a_1_epoch": 1,
            "a_epoch": cumulative_sum,
            "a": cumulative_sum,
            "b": cumulative_sum,
            "b_1": 1,
        }

        # make sure can be pickled
        pickle.loads(pickle.dumps(result))
        # make sure can be torch.loaded
        filepath = str(tmpdir / "result")
        torch.save(result, filepath)
        torch.load(filepath)

        # assert metric state reset to default values
        result.reset()
        assert metric_a.x == metric_a._defaults["x"]
        assert metric_b.x == metric_b._defaults["x"]
        assert metric_c.x == metric_c._defaults["x"]

        batch_idx = None
Example #2
import pickle
from copy import deepcopy

import torch

# Assumed imports, as in Example #1; ``MetricSource`` is the enum used to key
# the payloads returned by ``ResultCollection.metrics`` in this variant.
from pytorch_lightning.trainer.connectors.logger_connector.result import MetricSource, ResultCollection


def test_result_collection_restoration(tmpdir):
    """This test makes sure metrics are properly reloaded on failure."""

    result = ResultCollection(True, torch.device("cpu"))
    metric_a = DummyMetric()
    metric_b = DummyMetric()
    metric_c = DummyMetric()
    metric_d = DummyMetric()
    current_fx_name = None
    batch_idx = None

    def lightning_log(fx, *args, **kwargs):
        nonlocal current_fx_name
        if current_fx_name != fx and batch_idx in (None, 0):
            result.reset(metrics=False, fx=fx)
        result.log(fx, *args, **kwargs, sync_dist_fn=my_sync_dist)
        current_fx_name = fx

    for _ in range(2):

        cumulative_sum = 0

        for i in range(3):

            a = metric_a(i)
            b = metric_b(i)
            c = metric_c(i)
            metric_d(i)

            cumulative_sum += i

            metric = metric_a if i < 1 else metric_d
            lightning_log('training_step',
                          'a',
                          metric,
                          on_step=True,
                          on_epoch=True)
            lightning_log('training_step',
                          'b',
                          metric_b,
                          on_step=False,
                          on_epoch=True)
            lightning_log('training_step',
                          'c',
                          metric_c,
                          on_step=True,
                          on_epoch=False)
            lightning_log('training_step',
                          'a_1',
                          a,
                          on_step=True,
                          on_epoch=True)
            lightning_log('training_step',
                          'b_1',
                          b,
                          on_step=False,
                          on_epoch=True)
            lightning_log('training_step',
                          'c_1', {
                              '1': c,
                              '2': c
                          },
                          on_step=True,
                          on_epoch=False)

            batch_log = result.metrics(on_step=True)[MetricSource.LOG]
            assert set(batch_log) == {"a_step", "c", "a_1_step", "c_1"}
            assert set(batch_log['c_1']) == {'1', '2'}

            result_copy = deepcopy(result)
            new_result = ResultCollection(True, torch.device("cpu"))
            state_dict = result.state_dict()
            # check the sync fn was dropped
            assert 'fn' not in state_dict['items']['training_step.a']['meta'][
                '_sync']
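            # Restore into a fresh collection; in this variant no ``metrics``
            # mapping is passed, so metric objects are not re-attached.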
            new_result.load_state_dict(state_dict)
            # should match
            assert result_copy == new_result
            # the sync fn has been kept
            assert result_copy['training_step.a'].meta.sync.fn == new_result[
                'training_step.a'].meta.sync.fn

        epoch_log = result.metrics(on_step=False)[MetricSource.LOG]
        epoch_log_copy = result_copy.metrics(on_step=False)[MetricSource.LOG]
        assert epoch_log == epoch_log_copy

        lightning_log('train_epoch_end',
                      'a',
                      metric_a,
                      on_step=False,
                      on_epoch=True)
        epoch_log = result.metrics(on_step=False)[MetricSource.LOG]
        assert epoch_log == {
            'a_1_epoch': 1,
            'a_epoch': cumulative_sum,
            'a': cumulative_sum,
            'b': cumulative_sum,
            'b_1': 1
        }

        # make sure can be pickled
        pickle.loads(pickle.dumps(result))
        # make sure can be torch.loaded
        filepath = str(tmpdir / 'result')
        torch.save(result, filepath)
        torch.load(filepath)

        # assert metric state reset to default values
        result.reset()
        assert metric_a.x == metric_a._defaults['x']
        assert metric_b.x == metric_b._defaults['x']
        assert metric_c.x == metric_c._defaults['x']

        batch_idx = None