def __init__(self):
        """Set up the evaluation loop's state.

        Creates the nested ``EvaluationEpochLoop`` and two distinct
        ``ResultCollection`` objects (both ``training=False``), one for
        validation results and one for test results.
        """
        super().__init__()
        # batch limit(s); None until assigned outside this constructor.
        # NOTE(review): a sequence presumably holds one limit per dataloader — confirm.
        self._max_batches: Optional[Union[int, Sequence[int]]] = None
        # outputs collected during evaluation; starts empty
        self.outputs = []
        self.evaluation_loop = EvaluationEpochLoop()

        # distinct collections for validation and test results
        self._val_results = ResultCollection(training=False)
        self._test_results = ResultCollection(training=False)
# Example #2
# 0
 def __init__(self, trainer: 'pl.Trainer'):
     """Store the trainer reference and initialise evaluation bookkeeping.

     Args:
         trainer: the owning trainer, kept on ``self.trainer``.
     """
     self.trainer: 'pl.Trainer' = trainer
     self.outputs: EPOCH_OUTPUT = []
     self.predictions: Optional[PredictionCollection] = None
     # batch limits (int or float); None until assigned elsewhere.
     # NOTE(review): presumably one entry per dataloader — confirm.
     self.max_batches: Optional[List[Union[int, float]]] = None
     self.warning_cache = WarningCache()
     self.num_dataloaders: Optional[int] = None
     # distinct collections for validation and test results
     self._val_results = ResultCollection(training=False)
     self._test_results = ResultCollection(training=False)
# Example #3
# 0
    def __init__(self, min_steps: Optional[int] = 0, max_steps: int = -1) -> None:
        """Initialise the training epoch loop.

        Args:
            min_steps: minimum number of steps, stored as given.
            max_steps: maximum number of steps; ``-1`` means infinite steps.
                ``None`` is still accepted but emits a deprecation warning and
                is converted to ``-1``.

        Raises:
            MisconfigurationException: if ``max_steps`` is below ``-1``.
        """
        super().__init__()

        # Reject out-of-range limits first; the deprecated ``None`` is handled next.
        if max_steps is not None and max_steps < -1:
            raise MisconfigurationException(
                f"`max_steps` must be a non-negative integer or -1 (infinite steps). You passed in {max_steps}."
            )
        if max_steps is None:
            rank_zero_deprecation(
                "Setting `max_steps = None` is deprecated in v1.5 and will no longer be supported in v1.7."
                " Use `max_steps = -1` instead."
            )
            max_steps = -1

        self.min_steps = min_steps
        self.max_steps = max_steps

        # progress tracking
        self.global_step: int = 0
        self.batch_progress = BatchProgress()
        self.scheduler_progress = SchedulerProgress()

        # child loops are not created here
        self.batch_loop: Optional[TrainingBatchLoop] = None
        self.val_loop: Optional["loops.EvaluationLoop"] = None

        self._results = ResultCollection(training=True)
        self._outputs: _OUTPUTS_TYPE = []
        self._warning_cache = WarningCache()
        self._dataloader_iter: Optional[Iterator] = None
        # caches the loaded dataloader state until dataloader objects are available
        self._dataloader_state_dict: Dict[str, Any] = {}
        def on_save_checkpoint(self, checkpoint) -> None:
            """Check that the trainer's ``ResultCollection`` survives a state-dict round trip.

            Asserts that values stay on ``device`` before and after
            ``load_state_dict``, that the sync fn (the training-type plugin's
            ``reduce``) is kept on the live collection but left out of the
            serialized state, and that loading into a fresh collection with
            ``map_location='cpu'`` gives ``_Sync.no_op`` and CPU values.
            ``checkpoint`` is not used.

            NOTE(review): ``device`` is a free variable here, presumably bound by an
            enclosing test — confirm.
            """
            results = self.trainer._results
            # simplify logic
            state_dict = results.state_dict(drop_value=False)

            # check device
            assert results['validation_step.v'].value.device.type == device
            assert state_dict['items']['validation_step.v'][
                'value'].device.type == device

            # sync fn should be kept
            assert results[
                'validation_step.v'].meta.sync.fn == self.trainer.training_type_plugin.reduce

            # sync fn dropped from the state dict
            assert 'fn' not in state_dict['items']['validation_step.v'][
                'meta']['_sync']
            # round-trip into the same collection
            results.load_state_dict(state_dict)

            # check device after loading
            assert results['validation_step.v'].value.device.type == device

            # sync fn was preserved in the original result
            assert results[
                'validation_step.v'].meta.sync.fn == self.trainer.training_type_plugin.reduce

            # default sync fn
            new_results = ResultCollection(False, device)
            new_results.load_state_dict(state_dict, map_location='cpu')
            assert new_results['validation_step.v'].meta.sync.fn == _Sync.no_op

            # check map location
            assert new_results['validation_step.v'].value.device.type == 'cpu'
# Example #5
# 0
    def __init__(self, min_steps: int, max_steps: int):
        """Initialise the training epoch loop's counters and state.

        Args:
            min_steps: minimum number of steps, stored as given.
            max_steps: maximum number of steps, stored as given (no validation here).
        """
        super().__init__()
        self.min_steps: int = min_steps
        self.max_steps: int = max_steps

        self.global_step: int = 0

        # the total batch index across all epochs
        self.total_batch_idx: int = 0
        # the current batch index in the loop that runs over the dataloader(s)
        self.iteration_count: int = 0
        # the current split index when the batch gets split into chunks in truncated backprop through time
        self.split_idx: Optional[int] = None

        self._dataloader_idx: Optional[int] = None
        self._should_stop: bool = False

        self.is_last_batch: Optional[bool] = None
        self.batches_seen: int = 0
        self.warning_cache: WarningCache = WarningCache()
        # NOTE(review): nested list of step outputs; the meaning of the outer level is not visible here
        self.epoch_output: Optional[List[List[STEP_OUTPUT]]] = None

        # child batch loop is not created here
        self.batch_loop: Optional[TrainingBatchLoop] = None

        self._results = ResultCollection(training=True)
# Example #6
# 0
        def on_save_checkpoint(self, checkpoint) -> None:
            """Check that the trainer's ``ResultCollection`` survives a state-dict round trip.

            Asserts that values stay on ``device`` before and after
            ``load_state_dict``, that the sync fn (the training-type plugin's
            ``reduce``) is kept on the live collection but left out of the
            serialized state, and that loading into a fresh collection with
            ``map_location="cpu"`` leaves the sync fn as ``None`` and puts the
            values on CPU. ``checkpoint`` is not used.

            NOTE(review): ``device`` is a free variable here, presumably bound by an
            enclosing test — confirm.
            """
            results = self.trainer._results
            # simplify logic
            state_dict = results.state_dict(drop_value=False)

            # check device
            assert results["validation_step.v"].value.device.type == device
            assert state_dict["items"]["validation_step.v"][
                "value"].device.type == device

            # sync fn should be kept
            assert results[
                "validation_step.v"].meta.sync.fn == self.trainer.training_type_plugin.reduce

            # sync fn dropped from the state dict
            assert "fn" not in state_dict["items"]["validation_step.v"][
                "meta"]["_sync"]
            # round-trip into the same collection
            results.load_state_dict(state_dict)

            # check device after loading
            assert results["validation_step.v"].value.device.type == device

            # sync fn was preserved in the original result
            assert results[
                "validation_step.v"].meta.sync.fn == self.trainer.training_type_plugin.reduce

            # default sync fn
            new_results = ResultCollection(False, device)
            new_results.load_state_dict(state_dict, map_location="cpu")
            assert new_results["validation_step.v"].meta.sync.fn is None

            # check map location
            assert new_results["validation_step.v"].value.device.type == "cpu"
# Example #7
# 0
    def __init__(self):
        """Set up the evaluation loop's state.

        Creates the nested ``EvaluationEpochLoop`` and a single
        ``ResultCollection`` with ``training=False``.
        """
        super().__init__()
        self.outputs: List[EPOCH_OUTPUT] = []
        self.epoch_loop = EvaluationEpochLoop()

        self._results = ResultCollection(training=False)
        # batch limit(s); None until assigned outside this constructor
        self._max_batches: Optional[Union[int, Sequence[int]]] = None
        # starts False; presumably flipped once the loop has run — confirm
        self._has_run: bool = False
# Example #8
# 0
 def __init__(self,
              min_epochs: Optional[int] = None,
              max_epochs: Optional[int] = None,
              min_steps: Optional[int] = None,
              max_steps: Optional[int] = None):
     """Initialise the fit loop with epoch and step limits.

     If neither ``max_epochs`` nor ``max_steps`` is given, ``max_epochs``
     defaults to 1000; if neither ``min_epochs`` nor ``min_steps`` is given,
     ``min_epochs`` defaults to 1. The step limits go to the nested
     ``TrainingEpochLoop``.
     """
     super().__init__()
     if max_epochs is None and max_steps is None:
         max_epochs = 1000
     if min_epochs is None and min_steps is None:
         min_epochs = 1
     self.max_epochs = max_epochs
     self.min_epochs = min_epochs
     self.training_loop = TrainingEpochLoop(min_steps, max_steps)
     self.results = ResultCollection(training=True)
# Example #9
# 0
def test_result_metric_integration():
    """Log three metrics with different step/epoch flags; check the values and str/repr."""
    metrics = [DummyMetric() for _ in range(3)]
    metric_a, metric_b, metric_c = metrics

    result = ResultCollection(True, torch.device("cpu"))

    for _ in range(3):
        running_total = 0
        for step in range(5):
            for metric in metrics:
                metric(step)
            running_total += step

            result.log("h", "a", metric_a, on_step=True, on_epoch=True)
            result.log("h", "b", metric_b, on_step=False, on_epoch=True)
            result.log("h", "c", metric_c, on_step=True, on_epoch=False)

            step_log = result.metrics(True)["log"]
            assert step_log == {"a_step": step, "c": step}

        epoch_log = result.metrics(False)["log"]
        result.reset()

        # reset() must bring every metric back to its default state
        for metric in metrics:
            assert metric.x == metric._defaults["x"]

        assert epoch_log == {"b": running_total, "a_epoch": running_total}

    result.minimize = torch.tensor(1.0)
    result.extra = {}

    expected_str = ("ResultCollection("
                    "minimize=1.0, "
                    "{"
                    "'h.a': ResultMetric('a', value=DummyMetric()), "
                    "'h.b': ResultMetric('b', value=DummyMetric()), "
                    "'h.c': ResultMetric('c', value=DummyMetric())"
                    "})")
    assert str(result) == expected_str

    expected_repr = ("{"
                     "True, "
                     "device(type='cpu'), "
                     "minimize=tensor(1.), "
                     "{'h.a': ResultMetric('a', value=DummyMetric()), "
                     "'h.b': ResultMetric('b', value=DummyMetric()), "
                     "'h.c': ResultMetric('c', value=DummyMetric()), "
                     "'_extra': {}}"
                     "}")
    assert repr(result) == expected_repr
    def __init__(self, min_steps: int, max_steps: int):
        """Initialise the training epoch loop's counters and progress trackers.

        Args:
            min_steps: minimum number of steps, stored as given.
            max_steps: maximum number of steps, stored as given (no validation here).
        """
        super().__init__()
        self.min_steps: int = min_steps
        self.max_steps: int = max_steps

        self.global_step: int = 0
        # manually tracking which is the last batch is necessary for iterable dataset support
        self.is_last_batch: Optional[bool] = None
        self.batch_progress = Progress()
        self.scheduler_progress = SchedulerProgress()

        # child loops are not created here
        self.batch_loop: Optional[TrainingBatchLoop] = None
        self.val_loop: Optional["loops.EvaluationLoop"] = None

        self._results = ResultCollection(training=True)
        self._epoch_output: Optional[List[List[STEP_OUTPUT]]] = None
    def __init__(self, min_steps: int, max_steps: int):
        """Initialise the training epoch loop's counters and progress trackers.

        Args:
            min_steps: minimum number of steps, stored as given.
            max_steps: maximum number of steps, stored as given (no validation here).
        """
        super().__init__()
        self.min_steps: int = min_steps
        self.max_steps: int = max_steps
        self.global_step: int = 0
        # the total batch index across all epochs
        self.total_batch_idx: int = 0
        self.is_last_batch: Optional[bool] = None
        self.batch_progress = Progress()
        self.scheduler_progress = SchedulerProgress()

        # child loops are not created here
        self.batch_loop: Optional[TrainingBatchLoop] = None
        self.val_loop: Optional["loops.EvaluationLoop"] = None

        self._results = ResultCollection(training=True)
        self._dataloader_idx: Optional[int] = None
        self._warning_cache: WarningCache = WarningCache()
        self._epoch_output: Optional[List[List[STEP_OUTPUT]]] = None
# Example #12
# 0
    def __init__(
        self,
        trainer,
        max_epochs: Optional[int],
        min_epochs: Optional[int],
        max_steps: Optional[int],
        min_steps: Optional[int],
        num_sanity_val_steps: int,
    ):
        """Initialise training-loop state and reset trainer-level counters.

        Args:
            trainer: the owning trainer; several of its attributes are reset here.
            max_epochs: maximum epochs; becomes 1000 when both this and
                ``max_steps`` are None.
            min_epochs: minimum epochs; becomes 1 when both this and
                ``min_steps`` are None.
            max_steps: maximum number of steps, stored as given.
            min_steps: minimum number of steps, stored as given.
            num_sanity_val_steps: sanity-check validation steps; ``-1`` is stored
                on the trainer as ``float("inf")``.
        """
        self.trainer = trainer
        self.accumulated_loss = None
        self.warning_cache = WarningCache()
        self.running_loss = TensorRunningAccum(window_length=20)
        self._skip_backward = False
        self._optimizer_freq_cumsum = None
        self._hiddens = None

        self.global_step = 0
        self.current_epoch = 0
        self.trainer.should_stop = False

        # Position counters: total batch index across all epochs, batch index
        # within the current pass over the dataloader(s), and the split index
        # used when a batch is chunked for truncated backprop through time.
        self.total_batch_idx = 0
        self.batch_idx = 0
        self.split_idx = None

        self.trainer.num_training_batches = 0
        self.trainer.train_dataloader = None

        # Defaults apply only when neither the epoch nor the step limit was given.
        if max_epochs is None and max_steps is None:
            max_epochs = 1000
        if min_epochs is None and min_steps is None:
            min_epochs = 1
        self.max_epochs = max_epochs
        self.min_epochs = min_epochs
        self.max_steps = max_steps
        self.min_steps = min_steps

        # -1 is the sentinel for "run every sanity validation step"
        self.trainer.num_sanity_val_steps = (
            float("inf") if num_sanity_val_steps == -1 else num_sanity_val_steps
        )

        self.results = ResultCollection(training=True)
# Example #13
# 0
def _ddp_test_fn(rank, worldsize):
    """Per-process body: log three metrics on ``cuda:{rank}`` and check logged values.

    Step values must match the local inputs. Epoch values must equal the
    local cumulative sum multiplied by ``worldsize``.
    """
    _setup_ddp(rank, worldsize)
    torch.tensor([1.0])  # result is discarded

    device_name = f"cuda:{rank}"
    fresh_metrics = [DummyMetric() for _ in range(3)]
    metric_a, metric_b, metric_c = [metric.to(device_name) for metric in fresh_metrics]

    result = ResultCollection(True, torch.device(device_name))

    for _ in range(3):
        running_total = 0
        for step in range(5):
            for metric in (metric_a, metric_b, metric_c):
                metric(step)
            running_total += step

            result.log("h", "a", metric_a, on_step=True, on_epoch=True)
            result.log("h", "b", metric_b, on_step=False, on_epoch=True)
            result.log("h", "c", metric_c, on_step=True, on_epoch=False)

            step_log = result.metrics(True)["log"]
            assert step_log == {"a_step": step, "c": step}

        epoch_log = result.metrics(False)["log"]
        result.reset()

        # reset() must bring every metric back to its default state
        assert metric_a.x == metric_a._defaults["x"], (metric_a.x, metric_a._defaults["x"])
        assert metric_b.x == metric_b._defaults["x"]
        assert metric_c.x == metric_c._defaults["x"]

        expected_total = running_total * worldsize
        assert epoch_log == {"b": expected_total, "a_epoch": expected_total}
# Example #14
# 0
    def __init__(self, min_steps: int, max_steps: int):
        """Initialise the training epoch loop.

        Args:
            min_steps: minimum number of steps, stored as given.
            max_steps: maximum number of steps. ``-1`` and every value >= 0
                (including 0) are accepted; ``None`` skips validation.

        Raises:
            MisconfigurationException: if ``max_steps`` is below ``-1``.
        """
        super().__init__()
        self.min_steps: int = min_steps

        # Only values below -1 are rejected, so 0 is valid: the message must
        # say "non-negative", not "positive".
        if max_steps is not None and max_steps < -1:
            raise MisconfigurationException(
                f"`max_steps` must be a non-negative integer or -1 (infinite steps). You passed in {max_steps}."
            )
        self.max_steps: int = max_steps

        self.global_step: int = 0
        # manually tracking which is the last batch is necessary for iterable dataset support
        self.is_last_batch: Optional[bool] = None
        self.batch_progress = Progress()
        self.scheduler_progress = SchedulerProgress()

        # child loops are not created here
        self.batch_loop: Optional[TrainingBatchLoop] = None
        self.val_loop: Optional["loops.EvaluationLoop"] = None

        self._results = ResultCollection(training=True)
        self._epoch_output: Optional[List[List[STEP_OUTPUT]]] = None
def test_result_collection_simple_loop():
    """Simulate hooks logging once, per epoch, per batch and at epoch end; check accumulated values."""
    result = ResultCollection(True, torch.device("cpu"))
    current_fx_name = None
    batch_idx = None

    def lightning_log(fx, *args, **kwargs):
        # Reset this hook's logged values when it starts logging outside a
        # batch or on batch 0, then log the value.
        nonlocal current_fx_name
        switched_hook = current_fx_name != fx
        if switched_hook and batch_idx in (None, 0):
            result.reset(metrics=False, fx=fx)
        result.log(fx, *args, **kwargs)
        current_fx_name = fx

    for fx in ('a0', 'a1'):
        lightning_log(fx, 'a', torch.tensor(0.), on_step=True, on_epoch=True)
    for epoch in range(2):
        for fx in ('b0', 'b1'):
            lightning_log(fx, 'a', torch.tensor(1.) + epoch, on_step=True, on_epoch=True)
        # rebinds the enclosing ``batch_idx`` that ``lightning_log`` reads
        for batch_idx in range(2):
            for fx in ('c0', 'c1', 'c2'):
                lightning_log(fx, 'a', torch.tensor(2.) + epoch, on_step=True, on_epoch=True)
        batch_idx = None
        for fx in ('d0', 'd1'):
            lightning_log(fx, 'a', torch.tensor(3.) + epoch, on_step=False, on_epoch=True)

        # (keys, expected accumulated value, expected cumulated batch size)
        expectations = (
            (('a0.a', 'a1.a'), torch.tensor(0.), torch.tensor(1.)),
            (('b0.a', 'b1.a'), torch.tensor(1.) + epoch, torch.tensor(1.)),
            (('c0.a', 'c1.a', 'c2.a'), torch.tensor(4.) + epoch * 2, torch.tensor(2.)),
            (('d0.a', 'd1.a'), torch.tensor(3.) + epoch, torch.tensor(1.)),
        )
        for keys, expected_value, expected_size in expectations:
            for k in keys:
                assert result[k].value == expected_value, k
                assert result[k].cumulated_batch_size == expected_size, k
def test_result_metric_integration():
    """Log three metrics with different step/epoch flags; check logged values and ``str``."""
    metrics = [DummyMetric() for _ in range(3)]
    metric_a, metric_b, metric_c = metrics

    result = ResultCollection(True, torch.device("cpu"))

    for _ in range(3):
        running_total = 0
        for step in range(5):
            for metric in metrics:
                metric(step)
            running_total += step

            result.log('h', 'a', metric_a, on_step=True, on_epoch=True)
            result.log('h', 'b', metric_b, on_step=False, on_epoch=True)
            result.log('h', 'c', metric_c, on_step=True, on_epoch=False)

            step_log = result.metrics(True)[MetricSource.LOG]
            assert step_log == {"a_step": step, "c": step}

        epoch_log = result.metrics(False)[MetricSource.LOG]
        result.reset()

        # reset() must bring every metric back to its default state
        for metric in metrics:
            assert metric.x == metric._defaults['x']

        assert epoch_log == {"b": running_total, "a_epoch": running_total}

    expected = (
        "ResultCollection(True, cpu, {"
        "'h.a': ResultMetric('a', value=DummyMetric()), "
        "'h.b': ResultMetric('b', value=DummyMetric()), "
        "'h.c': ResultMetric('c', value=DummyMetric())"
        "})"
    )
    assert str(result) == expected
    def __init__(self, min_steps: int, max_steps: int):
        """Initialise the training epoch loop's counters and progress trackers.

        Args:
            min_steps: minimum number of steps, stored as given.
            max_steps: maximum number of steps, stored as given (no validation here).
        """
        super().__init__()
        self.min_steps: int = min_steps
        self.max_steps: int = max_steps
        self.global_step: int = 0
        # the total batch index across all epochs
        self.total_batch_idx: int = 0
        # the current split index when the batch gets split into chunks in truncated backprop through time
        self.split_idx: Optional[int] = None
        # the number of batches seen this run, updates immediately after batch_loop.run()
        # TODO: replace by progress tracking
        self.batches_seen: int = 0
        self.is_last_batch: Optional[bool] = None
        self.batch_progress = Progress()
        self.scheduler_progress = SchedulerProgress()

        # child loops are not created here
        self.batch_loop: Optional[TrainingBatchLoop] = None
        self.val_loop: Optional["loops.EvaluationLoop"] = None

        self._results = ResultCollection(training=True)
        self._dataloader_idx: Optional[int] = None
        self._warning_cache: WarningCache = WarningCache()
        self._epoch_output: Optional[List[List[STEP_OUTPUT]]] = None
def test_result_collection_on_tensor_with_mean_reduction():
    """Log tensors under every flag combination and check step and mean-reduced epoch metrics.

    Each value ``v`` is logged with ``batch_size=v * v`` under a name that
    encodes its ``on_step``/``on_epoch``/``prog_bar``/``logger`` flags. The
    combination ``on_step=False, on_epoch=False`` must raise
    ``MisconfigurationException``.
    """
    result_collection = ResultCollection(True, torch.device("cpu"))
    step_epoch_flags = [(True, True), (False, True), (True, False), (False, False)]
    # need to convert to float() due to precision issues using torch 1.4
    values = torch.arange(1, 10).float()
    batches = values * values

    for i, v in enumerate(values):
        for prog_bar in [False, True]:
            for logger in [False, True]:
                for on_step, on_epoch in step_epoch_flags:
                    flagged_suffixes = (
                        ("_on_step", on_step),
                        ("_on_epoch", on_epoch),
                        ("_prog_bar", prog_bar),
                        ("_logger", logger),
                    )
                    name = "loss" + "".join(suffix for suffix, enabled in flagged_suffixes if enabled)
                    log_kwargs = dict(
                        fx="training_step",
                        name=name,
                        value=v,
                        on_step=on_step,
                        on_epoch=on_epoch,
                        batch_size=batches[i],
                        prog_bar=prog_bar,
                        logger=logger,
                    )
                    if on_step or on_epoch:
                        result_collection.log(**log_kwargs)
                    else:
                        with pytest.raises(MisconfigurationException, match="on_step=False, on_epoch=False"):
                            result_collection.log(**log_kwargs)

    total_value = sum(values * batches)
    total_batches = sum(batches)
    accumulated = result_collection["training_step.loss_on_step_on_epoch"]
    assert accumulated.value == total_value
    assert accumulated.cumulated_batch_size == total_batches

    batch_metrics = result_collection.metrics(True)
    last_value = max(values)
    assert batch_metrics["pbar"] == {
        "loss_on_step_on_epoch_prog_bar_step": last_value,
        "loss_on_step_on_epoch_prog_bar_logger_step": last_value,
        "loss_on_step_prog_bar": last_value,
        "loss_on_step_prog_bar_logger": last_value,
    }
    assert batch_metrics["log"] == {
        "loss_on_step_on_epoch_logger_step": last_value,
        "loss_on_step_logger": last_value,
        "loss_on_step_on_epoch_prog_bar_logger_step": last_value,
        "loss_on_step_prog_bar_logger": last_value,
    }
    assert batch_metrics["callback"] == {
        "loss_on_step": last_value,
        "loss_on_step_logger": last_value,
        "loss_on_step_on_epoch": last_value,
        "loss_on_step_on_epoch_logger": last_value,
        "loss_on_step_on_epoch_logger_step": last_value,
        "loss_on_step_on_epoch_prog_bar": last_value,
        "loss_on_step_on_epoch_prog_bar_logger": last_value,
        "loss_on_step_on_epoch_prog_bar_logger_step": last_value,
        "loss_on_step_on_epoch_prog_bar_step": last_value,
        "loss_on_step_on_epoch_step": last_value,
        "loss_on_step_prog_bar": last_value,
        "loss_on_step_prog_bar_logger": last_value,
    }

    epoch_metrics = result_collection.metrics(False)
    # batch-size-weighted mean of all logged values
    weighted_mean = total_value / total_batches
    assert epoch_metrics["pbar"] == {
        "loss_on_epoch_prog_bar": weighted_mean,
        "loss_on_epoch_prog_bar_logger": weighted_mean,
        "loss_on_step_on_epoch_prog_bar_epoch": weighted_mean,
        "loss_on_step_on_epoch_prog_bar_logger_epoch": weighted_mean,
    }
    assert epoch_metrics["log"] == {
        "loss_on_epoch_logger": weighted_mean,
        "loss_on_epoch_prog_bar_logger": weighted_mean,
        "loss_on_step_on_epoch_logger_epoch": weighted_mean,
        "loss_on_step_on_epoch_prog_bar_logger_epoch": weighted_mean,
    }
    assert epoch_metrics["callback"] == {
        "loss_on_epoch": weighted_mean,
        "loss_on_epoch_logger": weighted_mean,
        "loss_on_epoch_prog_bar": weighted_mean,
        "loss_on_epoch_prog_bar_logger": weighted_mean,
        "loss_on_step_on_epoch": weighted_mean,
        "loss_on_step_on_epoch_epoch": weighted_mean,
        "loss_on_step_on_epoch_logger": weighted_mean,
        "loss_on_step_on_epoch_logger_epoch": weighted_mean,
        "loss_on_step_on_epoch_prog_bar": weighted_mean,
        "loss_on_step_on_epoch_prog_bar_epoch": weighted_mean,
        "loss_on_step_on_epoch_prog_bar_logger": weighted_mean,
        "loss_on_step_on_epoch_prog_bar_logger_epoch": weighted_mean,
    }
# Example #19
# 0
def test_result_collection_extra_reference():
    """Check that ``ResultCollection.extra`` is the same object stored under ``"_extra"``."""
    collection = ResultCollection(True)
    extra = collection.extra
    assert extra is collection["_extra"]
# Example #20
# 0
def test_result_collection_restoration(tmpdir):
    """This test makes sure metrics are properly reloaded on failure.

    After every batch, the collection is deep-copied and round-tripped through
    ``state_dict``/``load_state_dict`` into a fresh ``ResultCollection``, which
    must compare equal to the copy. After every epoch it is also pickled and
    ``torch.save``/``torch.load``-ed.
    """

    result = ResultCollection(True, torch.device("cpu"))
    metric_a = DummyMetric()
    metric_b = DummyMetric()
    metric_c = DummyMetric()
    metric_d = DummyMetric()
    current_fx_name = None
    batch_idx = None

    def lightning_log(fx, *args, **kwargs):
        # reset a hook's values when it starts logging outside a batch or on batch 0
        nonlocal current_fx_name
        if current_fx_name != fx and batch_idx in (None, 0):
            result.reset(metrics=False, fx=fx)
        result.log(fx, *args, **kwargs, sync_dist_fn=my_sync_dist)
        current_fx_name = fx

    for _ in range(2):

        cumulative_sum = 0

        for i in range(3):

            a = metric_a(i)
            b = metric_b(i)
            c = metric_c(i)
            metric_d(i)

            cumulative_sum += i

            # 'a' is logged with metric_a on the first batch, metric_d afterwards
            metric = metric_a if i < 1 else metric_d
            lightning_log('training_step',
                          'a',
                          metric,
                          on_step=True,
                          on_epoch=True)
            lightning_log('training_step',
                          'b',
                          metric_b,
                          on_step=False,
                          on_epoch=True)
            lightning_log('training_step',
                          'c',
                          metric_c,
                          on_step=True,
                          on_epoch=False)
            lightning_log('training_step',
                          'a_1',
                          a,
                          on_step=True,
                          on_epoch=True)
            lightning_log('training_step',
                          'b_1',
                          b,
                          on_step=False,
                          on_epoch=True)
            lightning_log('training_step',
                          'c_1', {
                              '1': c,
                              '2': c
                          },
                          on_step=True,
                          on_epoch=False)

            batch_log = result.metrics(on_step=True)[MetricSource.LOG]
            assert set(batch_log) == {"a_step", "c", "a_1_step", "c_1"}
            assert set(batch_log['c_1']) == {'1', '2'}

            result_copy = deepcopy(result)
            new_result = ResultCollection(True, torch.device("cpu"))
            state_dict = result.state_dict()
            # check the sync fn was dropped
            assert 'fn' not in state_dict['items']['training_step.a']['meta'][
                '_sync']
            new_result.load_state_dict(state_dict)
            # should match
            assert result_copy == new_result
            # the sync fn has been kept
            assert result_copy['training_step.a'].meta.sync.fn == new_result[
                'training_step.a'].meta.sync.fn

        epoch_log = result.metrics(on_step=False)[MetricSource.LOG]
        epoch_log_copy = result_copy.metrics(on_step=False)[MetricSource.LOG]
        assert epoch_log == epoch_log_copy

        lightning_log('train_epoch_end',
                      'a',
                      metric_a,
                      on_step=False,
                      on_epoch=True)
        epoch_log = result.metrics(on_step=False)[MetricSource.LOG]
        assert epoch_log == {
            'a_1_epoch': 1,
            'a_epoch': cumulative_sum,
            'a': cumulative_sum,
            'b': cumulative_sum,
            'b_1': 1
        }

        # make sure can be pickled
        pickle.loads(pickle.dumps(result))
        # make sure can be torch.loaded
        filepath = str(tmpdir / 'result')
        torch.save(result, filepath)
        torch.load(filepath)

        # assert metric state reset to default values
        result.reset()
        assert metric_a.x == metric_a._defaults['x']
        assert metric_b.x == metric_b._defaults['x']
        assert metric_c.x == metric_c._defaults['x']

        batch_idx = None
# Example #21
# 0
def test_result_collection_restoration(tmpdir):
    """This test makes sure metrics are properly reloaded on failure.

    After every batch, the collection is deep-copied and round-tripped through
    ``state_dict``/``load_state_dict`` into a fresh ``ResultCollection``, with
    the metric objects passed by their ``metric_attribute`` names. The result
    must compare equal to the copy. After every epoch it is also pickled and
    ``torch.save``/``torch.load``-ed.
    """

    result = ResultCollection(True, torch.device("cpu"))
    metric_a = DummyMetric()
    metric_b = DummyMetric()
    metric_c = DummyMetric()
    metric_d = DummyMetric()
    current_fx_name = None
    batch_idx = None

    def lightning_log(fx, *args, **kwargs):
        # reset a hook's values when it starts logging outside a batch or on batch 0
        nonlocal current_fx_name
        if current_fx_name != fx and batch_idx in (None, 0):
            result.reset(metrics=False, fx=fx)
        result.log(fx, *args, **kwargs, sync_dist_fn=my_sync_dist)
        current_fx_name = fx

    for epoch in range(2):

        cumulative_sum = 0

        for i in range(3):

            a = metric_a(i)
            b = metric_b(i)
            c = metric_c(i)
            metric_d(i)

            cumulative_sum += i

            # 'a' is logged with metric_a on the first batch, metric_d afterwards
            metric = metric_a if i < 1 else metric_d
            lightning_log("training_step",
                          "a",
                          metric,
                          on_step=True,
                          on_epoch=True,
                          metric_attribute="metric")
            lightning_log("training_step",
                          "b",
                          metric_b,
                          on_step=False,
                          on_epoch=True,
                          metric_attribute="metric_b")
            lightning_log("training_step",
                          "c",
                          metric_c,
                          on_step=True,
                          on_epoch=False,
                          metric_attribute="metric_c")
            lightning_log("training_step",
                          "a_1",
                          a,
                          on_step=True,
                          on_epoch=True)
            lightning_log("training_step",
                          "b_1",
                          b,
                          on_step=False,
                          on_epoch=True)
            lightning_log("training_step",
                          "c_1", {
                              "1": c,
                              "2": c
                          },
                          on_step=True,
                          on_epoch=False)

            batch_log = result.metrics(on_step=True)["log"]
            assert set(batch_log) == {"a_step", "c", "a_1_step", "c_1"}
            assert set(batch_log["c_1"]) == {"1", "2"}

            result_copy = deepcopy(result)
            new_result = ResultCollection(True, torch.device("cpu"))
            state_dict = result.state_dict()
            # check the sync fn was dropped
            assert "fn" not in state_dict["items"]["training_step.a"]["meta"][
                "_sync"]

            assert not new_result.result_metrics
            # 7 entries from the training_step calls above (presumably the
            # ``c_1`` dict contributes two), plus, from the second epoch on,
            # the ``train_epoch_end`` entry logged at the end of the previous
            # epoch. Parenthesised so this is not a chained comparison.
            assert len(result.result_metrics) == 7 + (epoch > 0)

            new_result.load_state_dict(state_dict,
                                       metrics={
                                           "metric": metric,
                                           "metric_b": metric_b,
                                           "metric_c": metric_c
                                       })
            # should match
            assert result_copy == new_result
            # the sync fn has been kept
            assert result_copy["training_step.a"].meta.sync.fn == new_result[
                "training_step.a"].meta.sync.fn

        epoch_log = result.metrics(on_step=False)["log"]
        epoch_log_copy = result_copy.metrics(on_step=False)["log"]
        assert epoch_log == epoch_log_copy

        lightning_log("train_epoch_end",
                      "a",
                      metric_a,
                      on_step=False,
                      on_epoch=True)
        epoch_log = result.metrics(on_step=False)["log"]
        assert epoch_log == {
            "a_1_epoch": 1,
            "a_epoch": cumulative_sum,
            "a": cumulative_sum,
            "b": cumulative_sum,
            "b_1": 1,
        }

        # make sure can be pickled
        pickle.loads(pickle.dumps(result))
        # make sure can be torch.loaded
        filepath = str(tmpdir / "result")
        torch.save(result, filepath)
        torch.load(filepath)

        # assert metric state reset to default values
        result.reset()
        assert metric_a.x == metric_a._defaults["x"]
        assert metric_b.x == metric_b._defaults["x"]
        assert metric_c.x == metric_c._defaults["x"]

        batch_idx = None
# Example #22
# 0
def test_result_collection_simple_loop():
    """Simulate a two-epoch loop of hooks logging key ``"a"`` and check what each hook retains.

    Each hook name ``fx`` logs one value under ``"a"``. Before logging, the helper calls
    ``result.reset(metrics=False, fx=fx)`` whenever the hook differs from the previously
    logged one and we are not past the first batch (``batch_idx`` is ``None`` or ``0``).
    The assertions then expect each hook's entry to reflect only the values logged for
    it in the current epoch (hooks ``a0``/``a1`` are logged once, before the epochs).
    """
    result = ResultCollection(True, torch.device("cpu"))
    current_fx_name = None
    batch_idx = None

    def lightning_log(fx, *args, **kwargs):
        nonlocal current_fx_name
        switching_hook = current_fx_name != fx
        # `batch_idx` is read from the enclosing scope; it is rebound by the batch loop below
        first_batch_or_outside_loop = batch_idx in (None, 0)
        if switching_hook and first_batch_or_outside_loop:
            result.reset(metrics=False, fx=fx)
        result.log(fx, *args, **kwargs)
        current_fx_name = fx

    for fx in ("a0", "a1"):
        lightning_log(fx, "a", torch.tensor(0.0), on_step=True, on_epoch=True)

    for epoch in range(2):
        for fx in ("b0", "b1"):
            lightning_log(fx, "a", torch.tensor(1.0) + epoch, on_step=True, on_epoch=True)

        # this rebinds the enclosing `batch_idx`, which `lightning_log` observes
        for batch_idx in range(2):
            for fx in ("c0", "c1", "c2"):
                lightning_log(fx, "a", torch.tensor(2.0) + epoch, on_step=True, on_epoch=True)
        batch_idx = None

        for fx in ("d0", "d1"):
            lightning_log(fx, "a", torch.tensor(3.0) + epoch, on_step=False, on_epoch=True)

        # (keys, expected accumulated value, expected cumulated batch size)
        expectations = (
            (("a0.a", "a1.a"), torch.tensor(0.0), torch.tensor(1.0)),
            (("b0.a", "b1.a"), torch.tensor(1.0) + epoch, torch.tensor(1.0)),
            (("c0.a", "c1.a", "c2.a"), torch.tensor(4.0) + epoch * 2, torch.tensor(2.0)),
            (("d0.a", "d1.a"), torch.tensor(3.0) + epoch, torch.tensor(1.0)),
        )
        for keys, expected_value, expected_size in expectations:
            for k in keys:
                assert result[k].value == expected_value, k
                assert result[k].cumulated_batch_size == expected_size, k
Пример #23
0
def test_result_collection_on_tensor_with_mean_reduction():
    """Log tensors with per-value batch sizes under every flag combination and check reductions.

    Each value ``v`` in ``1..9`` is logged with ``batch_size=v * v`` for every combination
    of ``on_step``/``on_epoch``/``prog_bar``/``logger``. The key name encodes which flags
    are enabled. Step metrics are expected to equal ``max(values)``, which is the last
    value logged. Epoch metrics are expected to equal the batch-size-weighted mean
    ``sum(values * batches) / sum(batches)``.
    """
    result_collection = ResultCollection(True, torch.device("cpu"))
    step_epoch_flags = [(True, True), (False, True), (True, False), (False, False)]
    # cast to float: precision issues were seen with torch 1.4 otherwise
    values = torch.arange(1, 10).float()
    batches = values * values

    for v, batch_size in zip(values, batches):
        for prog_bar in (False, True):
            for logger in (False, True):
                for on_step, on_epoch in step_epoch_flags:
                    suffixes = (
                        ("_on_step", on_step),
                        ("_on_epoch", on_epoch),
                        ("_prog_bar", prog_bar),
                        ("_logger", logger),
                    )
                    name = "loss" + "".join(suffix for suffix, enabled in suffixes if enabled)
                    result_collection.log(
                        "training_step",
                        name,
                        v,
                        on_step=on_step,
                        on_epoch=on_epoch,
                        batch_size=batch_size,
                        prog_bar=prog_bar,
                        logger=logger,
                    )

    total_value = sum(values * batches)
    total_batches = sum(batches)
    tracked = result_collection["training_step.loss_on_step_on_epoch"]
    assert tracked.value == total_value
    assert tracked.cumulated_batch_size == total_batches

    # --- step-level metrics: every expected entry is the maximum logged value ---
    batch_metrics = result_collection.metrics(True)
    max_ = max(values)
    assert batch_metrics[MetricSource.PBAR] == dict.fromkeys(
        [
            "loss_on_step_on_epoch_prog_bar_step",
            "loss_on_step_on_epoch_prog_bar_logger_step",
            "loss_on_step_prog_bar",
            "loss_on_step_prog_bar_logger",
        ],
        max_,
    )
    assert batch_metrics[MetricSource.LOG] == dict.fromkeys(
        [
            "loss_on_step_on_epoch_logger_step",
            "loss_on_step_logger",
            "loss_on_step_on_epoch_prog_bar_logger_step",
            "loss_on_step_prog_bar_logger",
        ],
        max_,
    )
    assert batch_metrics[MetricSource.CALLBACK] == dict.fromkeys(
        [
            "loss_on_step",
            "loss_on_step_logger",
            "loss_on_step_on_epoch",
            "loss_on_step_on_epoch_logger",
            "loss_on_step_on_epoch_logger_step",
            "loss_on_step_on_epoch_prog_bar",
            "loss_on_step_on_epoch_prog_bar_logger",
            "loss_on_step_on_epoch_prog_bar_logger_step",
            "loss_on_step_on_epoch_prog_bar_step",
            "loss_on_step_on_epoch_step",
            "loss_on_step_prog_bar",
            "loss_on_step_prog_bar_logger",
        ],
        max_,
    )

    # --- epoch-level metrics: every expected entry is the batch-size-weighted mean ---
    epoch_metrics = result_collection.metrics(False)
    mean = total_value / total_batches
    assert epoch_metrics[MetricSource.PBAR] == dict.fromkeys(
        [
            "loss_on_epoch_prog_bar",
            "loss_on_epoch_prog_bar_logger",
            "loss_on_step_on_epoch_prog_bar_epoch",
            "loss_on_step_on_epoch_prog_bar_logger_epoch",
        ],
        mean,
    )
    assert epoch_metrics[MetricSource.LOG] == dict.fromkeys(
        [
            "loss_on_epoch_logger",
            "loss_on_epoch_prog_bar_logger",
            "loss_on_step_on_epoch_logger_epoch",
            "loss_on_step_on_epoch_prog_bar_logger_epoch",
        ],
        mean,
    )
    assert epoch_metrics[MetricSource.CALLBACK] == dict.fromkeys(
        [
            "loss_on_epoch",
            "loss_on_epoch_logger",
            "loss_on_epoch_prog_bar",
            "loss_on_epoch_prog_bar_logger",
            "loss_on_step_on_epoch",
            "loss_on_step_on_epoch_epoch",
            "loss_on_step_on_epoch_logger",
            "loss_on_step_on_epoch_logger_epoch",
            "loss_on_step_on_epoch_prog_bar",
            "loss_on_step_on_epoch_prog_bar_epoch",
            "loss_on_step_on_epoch_prog_bar_logger",
            "loss_on_step_on_epoch_prog_bar_logger_epoch",
        ],
        mean,
    )