Code example #1
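An on_save_checkpoint hook body that round-trips the trainer's ResultCollection through state_dict() / load_state_dict(): it checks that logged values stay on the expected device, that the sync fn survives on the live object but is dropped from the serialized state, and that a fresh collection loaded with map_location="cpu" moves the value and falls back to the default sync fn.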
        def on_save_checkpoint(self, checkpoint) -> None:
            results = self.trainer._results
            # keep the logged values in the state dict (drop_value=False) so the checks below can inspect them
            state_dict = results.state_dict(drop_value=False)

            # check device
            assert results["validation_step.v"].value.device.type == device
            assert state_dict["items"]["validation_step.v"][
                "value"].device.type == device

            # sync fn should be kept
            assert results[
                "validation_step.v"].meta.sync.fn == self.trainer.training_type_plugin.reduce

            # sync fn dropped from the state dict
            assert "fn" not in state_dict["items"]["validation_step.v"][
                "meta"]["_sync"]
            results.load_state_dict(state_dict)

            # check device after loading
            assert results["validation_step.v"].value.device.type == device

            # sync fn was preserved in the original result
            assert results[
                "validation_step.v"].meta.sync.fn == self.trainer.training_type_plugin.reduce

            # default sync fn
            new_results = ResultCollection(False, device)
            new_results.load_state_dict(state_dict, map_location="cpu")
            assert new_results["validation_step.v"].meta.sync.fn is None

            # check map location
            assert new_results["validation_step.v"].value.device.type == "cpu"
Code example #2
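The same hook as code example #1 in single-quote style; the one behavioral difference is the expected default sync fn on a freshly constructed ResultCollection (_Sync.no_op rather than None).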
        def on_save_checkpoint(self, checkpoint) -> None:
            results = self.trainer._results
            # keep the logged values in the state dict (drop_value=False) so the checks below can inspect them
            state_dict = results.state_dict(drop_value=False)

            # check device
            assert results['validation_step.v'].value.device.type == device
            assert state_dict['items']['validation_step.v'][
                'value'].device.type == device

            # sync fn should be kept
            assert results[
                'validation_step.v'].meta.sync.fn == self.trainer.training_type_plugin.reduce

            # sync fn dropped from the state dict
            assert 'fn' not in state_dict['items']['validation_step.v'][
                'meta']['_sync']
            results.load_state_dict(state_dict)

            # check device after loading
            assert results['validation_step.v'].value.device.type == device

            # sync fn was preserved in the original result
            assert results[
                'validation_step.v'].meta.sync.fn == self.trainer.training_type_plugin.reduce

            # default sync fn
            new_results = ResultCollection(False, device)
            new_results.load_state_dict(state_dict, map_location='cpu')
            assert new_results['validation_step.v'].meta.sync.fn == _Sync.no_op

            # check map location
            assert new_results['validation_step.v'].value.device.type == 'cpu'
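
Both hook bodies above run inside a LightningModule and assume an enclosing test that defines device and logs "v" from validation_step with a distributed sync fn; neither snippet is runnable on its own. Below is a minimal harness sketch under those assumptions, using Lightning's BoringModel test helper; the class name ValTestModel, the choice of device = "cpu", and the Trainer settings are illustrative, not taken from the original.

import torch
from pytorch_lightning import Trainer
from tests.helpers import BoringModel  # Lightning's in-repo test model (import path assumed)

device = "cpu"  # assumption: the real test likely parametrizes this


class ValTestModel(BoringModel):  # hypothetical name
    def validation_step(self, batch, batch_idx):
        out = super().validation_step(batch, batch_idx)
        # log "v" so that "validation_step.v" exists in self.trainer._results
        self.log("v", out["x"], sync_dist=True)
        return out

    def on_save_checkpoint(self, checkpoint) -> None:
        ...  # body as in code example #1 or #2 above


trainer = Trainer(
    default_root_dir=".",
    limit_train_batches=2,
    limit_val_batches=2,
    max_epochs=1,
)
trainer.fit(ValTestModel())  # checkpointing at epoch end triggers the hook
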
Code example #3
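A full restoration test: over two epochs it logs tensors and DummyMetric instances, snapshots the collection with state_dict(), reloads it into a fresh ResultCollection (re-attaching the metric objects via the metrics argument), and verifies equality, sync-fn retention, pickling, and a torch.save / torch.load round trip.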
def test_result_collection_restoration(tmpdir):
    """This test make sure metrics are properly reloaded on failure."""

    result = ResultCollection(True, torch.device("cpu"))
    metric_a = DummyMetric()
    metric_b = DummyMetric()
    metric_c = DummyMetric()
    metric_d = DummyMetric()
    current_fx_name = None
    batch_idx = None

    def lightning_log(fx, *args, **kwargs):
        nonlocal current_fx_name
        if current_fx_name != fx and batch_idx in (None, 0):
            result.reset(metrics=False, fx=fx)
        result.log(fx, *args, **kwargs, sync_dist_fn=my_sync_dist)
        current_fx_name = fx

    for epoch in range(2):

        cumulative_sum = 0

        for i in range(3):

            a = metric_a(i)
            b = metric_b(i)
            c = metric_c(i)
            metric_d(i)

            cumulative_sum += i

            metric = metric_a if i < 1 else metric_d
            lightning_log("training_step",
                          "a",
                          metric,
                          on_step=True,
                          on_epoch=True,
                          metric_attribute="metric")
            lightning_log("training_step",
                          "b",
                          metric_b,
                          on_step=False,
                          on_epoch=True,
                          metric_attribute="metric_b")
            lightning_log("training_step",
                          "c",
                          metric_c,
                          on_step=True,
                          on_epoch=False,
                          metric_attribute="metric_c")
            lightning_log("training_step",
                          "a_1",
                          a,
                          on_step=True,
                          on_epoch=True)
            lightning_log("training_step",
                          "b_1",
                          b,
                          on_step=False,
                          on_epoch=True)
            lightning_log("training_step",
                          "c_1", {
                              "1": c,
                              "2": c
                          },
                          on_step=True,
                          on_epoch=False)

            batch_log = result.metrics(on_step=True)["log"]
            assert set(batch_log) == {"a_step", "c", "a_1_step", "c_1"}
            assert set(batch_log["c_1"]) == {"1", "2"}

            result_copy = deepcopy(result)
            new_result = ResultCollection(True, torch.device("cpu"))
            state_dict = result.state_dict()
            # check the sync fn was dropped
            assert "fn" not in state_dict["items"]["training_step.a"]["meta"][
                "_sync"]

            assert not new_result.result_metrics
            assert len(result.result_metrics) == 7 + epoch > 0

            new_result.load_state_dict(state_dict,
                                       metrics={
                                           "metric": metric,
                                           "metric_b": metric_b,
                                           "metric_c": metric_c
                                       })
            # should match
            assert result_copy == new_result
            # the sync fn has been kept
            assert result_copy["training_step.a"].meta.sync.fn == new_result[
                "training_step.a"].meta.sync.fn

        epoch_log = result.metrics(on_step=False)["log"]
        epoch_log_copy = result_copy.metrics(on_step=False)["log"]
        assert epoch_log == epoch_log_copy

        lightning_log("train_epoch_end",
                      "a",
                      metric_a,
                      on_step=False,
                      on_epoch=True)
        epoch_log = result.metrics(on_step=False)["log"]
        assert epoch_log == {
            "a_1_epoch": 1,
            "a_epoch": cumulative_sum,
            "a": cumulative_sum,
            "b": cumulative_sum,
            "b_1": 1,
        }

        # make sure can be pickled
        pickle.loads(pickle.dumps(result))
        # make sure can be torch.loaded
        filepath = str(tmpdir / "result")
        torch.save(result, filepath)
        torch.load(filepath)

        # assert metric state reset to default values
        result.reset()
        assert metric_a.x == metric_a._defaults["x"]
        assert metric_b.x == metric_b._defaults["x"]
        assert metric_c.x == metric_c._defaults["x"]

        batch_idx = None
Code example #4
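A variant of code example #3 that indexes the logged metrics with MetricSource.LOG, omits the metric_attribute arguments, and calls load_state_dict() without re-attaching the metric objects.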
def test_result_collection_restoration(tmpdir):
    """"
    This test make sure metrics are properly reloaded on failure.
    """

    result = ResultCollection(True, torch.device("cpu"))
    metric_a = DummyMetric()
    metric_b = DummyMetric()
    metric_c = DummyMetric()
    metric_d = DummyMetric()
    current_fx_name = None
    batch_idx = None

    def lightning_log(fx, *args, **kwargs):
        nonlocal current_fx_name
        if current_fx_name != fx and batch_idx in (None, 0):
            result.reset(metrics=False, fx=fx)
        result.log(fx, *args, **kwargs, sync_dist_fn=my_sync_dist)
        current_fx_name = fx

    for _ in range(2):

        cumulative_sum = 0

        for i in range(3):

            a = metric_a(i)
            b = metric_b(i)
            c = metric_c(i)
            metric_d(i)

            cumulative_sum += i

            metric = metric_a if i < 1 else metric_d
            lightning_log('training_step',
                          'a',
                          metric,
                          on_step=True,
                          on_epoch=True)
            lightning_log('training_step',
                          'b',
                          metric_b,
                          on_step=False,
                          on_epoch=True)
            lightning_log('training_step',
                          'c',
                          metric_c,
                          on_step=True,
                          on_epoch=False)
            lightning_log('training_step',
                          'a_1',
                          a,
                          on_step=True,
                          on_epoch=True)
            lightning_log('training_step',
                          'b_1',
                          b,
                          on_step=False,
                          on_epoch=True)
            lightning_log('training_step',
                          'c_1', {
                              '1': c,
                              '2': c
                          },
                          on_step=True,
                          on_epoch=False)

            batch_log = result.metrics(on_step=True)[MetricSource.LOG]
            assert set(batch_log) == {"a_step", "c", "a_1_step", "c_1"}
            assert set(batch_log['c_1']) == {'1', '2'}

            result_copy = deepcopy(result)
            new_result = ResultCollection(True, torch.device("cpu"))
            state_dict = result.state_dict()
            # check the sync fn was dropped
            assert 'fn' not in state_dict['items']['training_step.a']['meta'][
                '_sync']
            new_result.load_state_dict(state_dict)
            # should match
            assert result_copy == new_result
            # the sync fn has been kept
            assert result_copy['training_step.a'].meta.sync.fn == new_result[
                'training_step.a'].meta.sync.fn

        epoch_log = result.metrics(on_step=False)[MetricSource.LOG]
        epoch_log_copy = result_copy.metrics(on_step=False)[MetricSource.LOG]
        assert epoch_log == epoch_log_copy

        lightning_log('train_epoch_end',
                      'a',
                      metric_a,
                      on_step=False,
                      on_epoch=True)
        epoch_log = result.metrics(on_step=False)[MetricSource.LOG]
        assert epoch_log == {
            'a_1_epoch': 1,
            'a_epoch': cumulative_sum,
            'a': cumulative_sum,
            'b': cumulative_sum,
            'b_1': 1
        }

        # make sure can be pickled
        pickle.loads(pickle.dumps(result))
        # make sure can be torch.loaded
        filepath = str(tmpdir / 'result')
        torch.save(result, filepath)
        torch.load(filepath)

        # assert metric state reset to default values
        result.reset()
        assert metric_a.x == metric_a._defaults['x']
        assert metric_b.x == metric_b._defaults['x']
        assert metric_c.x == metric_c._defaults['x']

        batch_idx = None
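
Code examples #3 and #4 reference a DummyMetric class and a my_sync_dist function that are not shown, along with several imports. Below is a minimal sketch of what would make them self-contained, assuming the torchmetrics Metric base class; the no-op my_sync_dist stand-in and the Lightning import path are assumptions, and the originals may differ.

import pickle
from copy import deepcopy

import torch
from torchmetrics import Metric

# Import path as of PyTorch Lightning ~1.4 (assumed)
from pytorch_lightning.trainer.connectors.logger_connector.result import (
    MetricSource,
    ResultCollection,
)


class DummyMetric(Metric):
    """Accumulates the sum of everything passed to update()."""

    def __init__(self):
        super().__init__()
        # "x" is the state the tests compare against metric._defaults["x"]
        self.add_state("x", torch.tensor(0), dist_reduce_fx="sum")

    def update(self, x):
        self.x += x

    def compute(self):
        return self.x


def my_sync_dist(value, *args, **kwargs):
    # stand-in sync fn: a no-op that returns the value unchanged is
    # sufficient for these single-process tests
    return value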