# Assumed imports for the test snippets below; the module paths follow recent
# pytorch_lightning / torchmetrics releases and are not part of the original snippets.
import pickle
from copy import deepcopy
from typing import Any, Dict, List, Optional

import pytest
import torch
from torchmetrics import MeanAbsoluteError, MeanSquaredError

from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
from pytorch_lightning.utilities.exceptions import MisconfigurationException


def test_result_collection_batch_size_extraction():
    fx_name = "training_step"
    log_val = torch.tensor(7.0)

    results = _ResultCollection(training=True, device="cpu")
    results.batch = torch.randn(1, 4)
    train_mse = MeanSquaredError()
    train_mse(torch.randn(4, 5), torch.randn(4, 5))
    results.log(fx_name, "train_logs", {"mse": train_mse, "log_val": log_val}, on_step=False, on_epoch=True)
    assert results.batch_size == 1
    assert isinstance(results["training_step.train_logs"]["mse"].value,
                      MeanSquaredError)
    assert results["training_step.train_logs"]["log_val"].value == log_val

    results = _ResultCollection(training=True, device="cpu")
    results.batch = torch.randn(1, 4)
    results.log(fx_name, "train_log", log_val, on_step=False, on_epoch=True)
    assert results.batch_size == 1
    assert results["training_step.train_log"].value == log_val
    assert results["training_step.train_log"].cumulated_batch_size == 1
        def on_save_checkpoint(self, checkpoint) -> None:
            results = self.trainer._results
            # simplify logic
            state_dict = results.state_dict(drop_value=False)

            # check device
            assert results["validation_step.v"].value.device.type == device
            assert state_dict["items"]["validation_step.v"][
                "value"].device.type == device

            # sync fn should be kept
            assert results["validation_step.v"].meta.sync.fn == self.trainer.strategy.reduce

            # sync fn dropped from the state dict
            assert "fn" not in state_dict["items"]["validation_step.v"][
                "meta"]["_sync"]
            results.load_state_dict(state_dict)

            # check device after loading
            assert results["validation_step.v"].value.device.type == device

            # sync fn was preserved in the original result
            assert results["validation_step.v"].meta.sync.fn == self.trainer.strategy.reduce

            # default sync fn
            new_results = _ResultCollection(False, device)
            new_results.load_state_dict(state_dict, map_location="cpu")
            assert new_results["validation_step.v"].meta.sync.fn is None

            # check map location
            assert new_results["validation_step.v"].value.device.type == "cpu"
def test_result_collection_no_batch_size_extraction():
    results = _ResultCollection(training=True, device="cpu")
    results.batch = torch.randn(1, 4)
    fx_name = "training_step"
    batch_size = 10
    log_val = torch.tensor(7.0)

    train_mae = MeanAbsoluteError()
    train_mae(torch.randn(4, 5), torch.randn(4, 5))
    train_mse = MeanSquaredError()
    train_mse(torch.randn(4, 5), torch.randn(4, 5))
    results.log(fx_name, "step_log_val", log_val, on_step=True, on_epoch=False)
    results.log(fx_name, "epoch_log_val", log_val, on_step=False, on_epoch=True, batch_size=batch_size)
    results.log(fx_name, "epoch_sum_log_val", log_val, on_step=True, on_epoch=True, reduce_fx="sum")
    results.log(fx_name, "train_mae", train_mae, on_step=True, on_epoch=False)
    results.log(fx_name, "train_mse", {"mse": train_mse}, on_step=True, on_epoch=False)

    assert results.batch_size is None
    assert isinstance(results["training_step.train_mse"]["mse"].value, MeanSquaredError)
    assert isinstance(results["training_step.train_mae"].value, MeanAbsoluteError)
    assert results["training_step.step_log_val"].value == log_val
    assert results["training_step.step_log_val"].cumulated_batch_size == 0
    assert results["training_step.epoch_log_val"].value == log_val * batch_size
    assert results["training_step.epoch_log_val"].cumulated_batch_size == batch_size
    assert results["training_step.epoch_sum_log_val"].value == log_val
    def __init__(self, min_steps: Optional[int] = None, max_steps: int = -1) -> None:
        super().__init__()
        if max_steps is None:
            rank_zero_deprecation(
                "Setting `max_steps = None` is deprecated in v1.5 and will no longer be supported in v1.7."
                " Use `max_steps = -1` instead.")
            max_steps = -1
        elif max_steps < -1:
            raise MisconfigurationException(
                f"`max_steps` must be a non-negative integer or -1 (infinite steps). You passed in {max_steps}."
            )
        self.min_steps = min_steps
        self.max_steps = max_steps

        self.batch_progress = BatchProgress()
        self.scheduler_progress = SchedulerProgress()

        self.batch_loop = TrainingBatchLoop()
        self.val_loop = loops.EvaluationLoop(verbose=False)

        self._results = _ResultCollection(training=True)
        self._outputs: _OUTPUTS_TYPE = []
        self._warning_cache = WarningCache()
        # caches the loaded dataloader state until dataloader objects are available
        self._dataloader_state_dict: Dict[str, Any] = {}
        self._batches_that_stepped: int = 0
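# The guard above accepts ``max_steps`` values of -1 (no step limit) or >= 0 and rejects
# anything else. A hedged usage sketch; the class name ``TrainingEpochLoop`` is an
# assumption about which loop this ``__init__`` belongs to:
def _sketch_max_steps_validation():
    from pytorch_lightning.loops import TrainingEpochLoop  # assumed import path

    with pytest.raises(MisconfigurationException, match="must be a non-negative integer"):
        TrainingEpochLoop(max_steps=-2)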
    def __init__(self, verbose: bool = True) -> None:
        super().__init__()
        self.epoch_loop = EvaluationEpochLoop()
        self.verbose = verbose

        self._results = _ResultCollection(training=False)
        self._outputs: List[EPOCH_OUTPUT] = []
        self._logged_outputs: List[_OUT_DICT] = []
        self._max_batches: List[int] = []
        self._has_run: bool = False
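# ``DummyMetric`` is defined elsewhere in the test module. A minimal sketch that is
# consistent with how the tests below use it: a single summed state ``x`` that is
# compared against ``_defaults["x"]`` after ``reset``.
from torchmetrics import Metric


class DummyMetric(Metric):
    def __init__(self):
        super().__init__()
        # one integer state, reduced across processes with a sum
        self.add_state("x", torch.tensor(0), dist_reduce_fx="sum")

    def update(self, x):
        self.x += x

    def compute(self):
        return self.x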
def test_result_metric_integration():
    metric_a = DummyMetric()
    metric_b = DummyMetric()
    metric_c = DummyMetric()

    result = _ResultCollection(True, torch.device("cpu"))

    for _ in range(3):
        cumulative_sum = 0
        for i in range(5):
            metric_a(i)
            metric_b(i)
            metric_c(i)

            cumulative_sum += i

            result.log("h", "a", metric_a, on_step=True, on_epoch=True)
            result.log("h", "b", metric_b, on_step=False, on_epoch=True)
            result.log("h", "c", metric_c, on_step=True, on_epoch=False)

            batch_log = result.metrics(True)["log"]
            assert batch_log == {"a_step": i, "c": i}

        epoch_log = result.metrics(False)["log"]
        result.reset()

        # assert metric state reset to default values
        assert metric_a.x == metric_a._defaults["x"]
        assert metric_b.x == metric_b._defaults["x"]
        assert metric_c.x == metric_c._defaults["x"]

        assert epoch_log == {"b": cumulative_sum, "a_epoch": cumulative_sum}

    result.minimize = torch.tensor(1.0)
    result.extra = {}
    assert str(result) == (
        "_ResultCollection("
        "{"
        "'h.a': _ResultMetric('a', value=DummyMetric()), "
        "'h.b': _ResultMetric('b', value=DummyMetric()), "
        "'h.c': _ResultMetric('c', value=DummyMetric())"
        "})"
    )
    assert repr(result) == (
        "{"
        "True, "
        "device(type='cpu'), "
        "{'h.a': _ResultMetric('a', value=DummyMetric()), "
        "'h.b': _ResultMetric('b', value=DummyMetric()), "
        "'h.c': _ResultMetric('c', value=DummyMetric())"
        "}}"
    )
def _ddp_test_fn(rank, worldsize):
    _setup_ddp(rank, worldsize)
    torch.tensor([1.0])

    metric_a = DummyMetric()
    metric_b = DummyMetric()
    metric_c = DummyMetric()

    metric_a = metric_a.to(f"cuda:{rank}")
    metric_b = metric_b.to(f"cuda:{rank}")
    metric_c = metric_c.to(f"cuda:{rank}")

    result = _ResultCollection(True, torch.device(f"cuda:{rank}"))

    for _ in range(3):
        cumulative_sum = 0
        for i in range(5):
            metric_a(i)
            metric_b(i)
            metric_c(i)

            cumulative_sum += i

            result.log("h", "a", metric_a, on_step=True, on_epoch=True)
            result.log("h", "b", metric_b, on_step=False, on_epoch=True)
            result.log("h", "c", metric_c, on_step=True, on_epoch=False)

            batch_log = result.metrics(True)["log"]
            assert batch_log == {"a_step": i, "c": i}

        epoch_log = result.metrics(False)["log"]
        result.reset()

        # assert metric state reset to default values
        assert metric_a.x == metric_a._defaults["x"], (metric_a.x,
                                                       metric_a._defaults["x"])
        assert metric_b.x == metric_b._defaults["x"]
        assert metric_c.x == metric_c._defaults["x"]

        assert epoch_log == {
            "b": cumulative_sum * worldsize,
            "a_epoch": cumulative_sum * worldsize
        }
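# ``_setup_ddp`` is defined elsewhere in the test module. A minimal sketch of such a
# helper and of how the DDP test above is typically launched; the "gloo" backend and
# the master address/port values are assumptions, not taken from the snippet:
import os

import torch.distributed as dist
import torch.multiprocessing as mp


def _setup_ddp(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    # join the process group so that metric states can be synced across ranks
    dist.init_process_group("gloo", rank=rank, world_size=world_size)


# example launch with two processes (requires two visible CUDA devices):
# mp.spawn(_ddp_test_fn, args=(2,), nprocs=2)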
def test_result_collection_simple_loop():
    result = _ResultCollection(True, torch.device("cpu"))
    current_fx_name = None
    batch_idx = None

    def lightning_log(fx, *args, **kwargs):
        nonlocal current_fx_name
        if current_fx_name != fx and batch_idx in (None, 0):
            result.reset(metrics=False, fx=fx)
        result.log(fx, *args, **kwargs)
        current_fx_name = fx

    lightning_log("a0", "a", torch.tensor(0.0), on_step=True, on_epoch=True)
    lightning_log("a1", "a", torch.tensor(0.0), on_step=True, on_epoch=True)
    for epoch in range(2):
        lightning_log("b0", "a", torch.tensor(1.0) + epoch, on_step=True, on_epoch=True)
        lightning_log("b1", "a", torch.tensor(1.0) + epoch, on_step=True, on_epoch=True)
        for batch_idx in range(2):
            lightning_log("c0", "a", torch.tensor(2.0) + epoch, on_step=True, on_epoch=True)
            lightning_log("c1", "a", torch.tensor(2.0) + epoch, on_step=True, on_epoch=True)
            lightning_log("c2", "a", torch.tensor(2.0) + epoch, on_step=True, on_epoch=True)
        batch_idx = None
        lightning_log("d0", "a", torch.tensor(3.0) + epoch, on_step=False, on_epoch=True)
        lightning_log("d1", "a", torch.tensor(3.0) + epoch, on_step=False, on_epoch=True)

        for k in ("a0.a", "a1.a"):
            assert result[k].value == torch.tensor(0.0), k
            assert result[k].cumulated_batch_size == torch.tensor(1.0), k

        for k in ("b0.a", "b1.a"):
            assert result[k].value == torch.tensor(1.0) + epoch, k
            assert result[k].cumulated_batch_size == torch.tensor(1.0), k

        for k in ("c0.a", "c1.a", "c2.a"):
            assert result[k].value == torch.tensor(4.0) + epoch * 2, k
            assert result[k].cumulated_batch_size == torch.tensor(2.0), k

        for k in ("d0.a", "d1.a"):
            assert result[k].value == torch.tensor(3.0) + epoch, k
            assert result[k].cumulated_batch_size == torch.tensor(1.0), k
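# ``my_sync_dist`` used below is defined elsewhere in the test module. Since it is only
# passed as ``sync_dist_fn``, a no-op stand-in with a reduce-like signature is enough
# for a single-process run; a minimal sketch:
def my_sync_dist(tensor, *args, **kwargs):
    # pretend to all-reduce: on a single process the value is returned unchanged
    return tensor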
def test_result_collection_restoration(tmpdir):
    """This test make sure metrics are properly reloaded on failure."""

    result = _ResultCollection(True, torch.device("cpu"))
    metric_a = DummyMetric()
    metric_b = DummyMetric()
    metric_c = DummyMetric()
    metric_d = DummyMetric()
    current_fx_name = None
    batch_idx = None

    def lightning_log(fx, *args, **kwargs):
        nonlocal current_fx_name
        if current_fx_name != fx and batch_idx in (None, 0):
            result.reset(metrics=False, fx=fx)
        result.log(fx, *args, **kwargs, sync_dist_fn=my_sync_dist)
        current_fx_name = fx

    for epoch in range(2):

        cumulative_sum = 0

        for i in range(3):

            a = metric_a(i)
            b = metric_b(i)
            c = metric_c(i)
            metric_d(i)

            cumulative_sum += i

            metric = metric_a if i < 1 else metric_d
            lightning_log("training_step",
                          "a",
                          metric,
                          on_step=True,
                          on_epoch=True,
                          metric_attribute="metric")
            lightning_log("training_step",
                          "b",
                          metric_b,
                          on_step=False,
                          on_epoch=True,
                          metric_attribute="metric_b")
            lightning_log("training_step",
                          "c",
                          metric_c,
                          on_step=True,
                          on_epoch=False,
                          metric_attribute="metric_c")
            lightning_log("training_step",
                          "a_1",
                          a,
                          on_step=True,
                          on_epoch=True)
            lightning_log("training_step",
                          "b_1",
                          b,
                          on_step=False,
                          on_epoch=True)
            lightning_log("training_step",
                          "c_1", {
                              "1": c,
                              "2": c
                          },
                          on_step=True,
                          on_epoch=False)

            batch_log = result.metrics(on_step=True)["log"]
            assert set(batch_log) == {"a_step", "c", "a_1_step", "c_1"}
            assert set(batch_log["c_1"]) == {"1", "2"}

            result_copy = deepcopy(result)
            new_result = _ResultCollection(True, torch.device("cpu"))
            state_dict = result.state_dict()
            # check the sync fn was dropped
            assert "fn" not in state_dict["items"]["training_step.a"]["meta"][
                "_sync"]

            assert not new_result.result_metrics
            assert len(result.result_metrics) == 7 + epoch > 0

            new_result.load_state_dict(
                state_dict, metrics={"metric": metric, "metric_b": metric_b, "metric_c": metric_c}
            )
            # should match
            assert result_copy == new_result
            # the sync fn has been kept
            assert result_copy["training_step.a"].meta.sync.fn == new_result[
                "training_step.a"].meta.sync.fn

        epoch_log = result.metrics(on_step=False)["log"]
        epoch_log_copy = result_copy.metrics(on_step=False)["log"]
        assert epoch_log == epoch_log_copy

        lightning_log("train_epoch_end",
                      "a",
                      metric_a,
                      on_step=False,
                      on_epoch=True)
        epoch_log = result.metrics(on_step=False)["log"]
        assert epoch_log == {
            "a_1_epoch": 1,
            "a_epoch": cumulative_sum,
            "a": cumulative_sum,
            "b": cumulative_sum,
            "b_1": 1,
        }

        # make sure can be pickled
        pickle.loads(pickle.dumps(result))
        # make sure can be torch.loaded
        filepath = str(tmpdir / "result")
        torch.save(result, filepath)
        torch.load(filepath)

        # assert metric state reset to default values
        result.reset()
        assert metric_a.x == metric_a._defaults["x"]
        assert metric_b.x == metric_b._defaults["x"]
        assert metric_c.x == metric_c._defaults["x"]

        batch_idx = None
def test_result_collection_on_tensor_with_mean_reduction():
    result_collection = _ResultCollection(True)
    product = [(True, True), (False, True), (True, False), (False, False)]
    values = torch.arange(1, 10)
    batches = values * values

    for i, v in enumerate(values):
        for prog_bar in [False, True]:
            for logger in [False, True]:
                for on_step, on_epoch in product:
                    name = "loss"
                    if on_step:
                        name += "_on_step"
                    if on_epoch:
                        name += "_on_epoch"
                    if prog_bar:
                        name += "_prog_bar"
                    if logger:
                        name += "_logger"
                    log_kwargs = dict(
                        fx="training_step",
                        name=name,
                        value=v,
                        on_step=on_step,
                        on_epoch=on_epoch,
                        batch_size=batches[i],
                        prog_bar=prog_bar,
                        logger=logger,
                    )
                    if not on_step and not on_epoch:
                        with pytest.raises(MisconfigurationException, match="on_step=False, on_epoch=False"):
                            result_collection.log(**log_kwargs)
                    else:
                        result_collection.log(**log_kwargs)

    total_value = sum(values * batches)
    total_batches = sum(batches)
    assert result_collection["training_step.loss_on_step_on_epoch"].value == total_value
    assert result_collection["training_step.loss_on_step_on_epoch"].cumulated_batch_size == total_batches

    batch_metrics = result_collection.metrics(True)
    max_ = max(values)
    assert batch_metrics["pbar"] == {
        "loss_on_step_on_epoch_prog_bar_step": max_,
        "loss_on_step_on_epoch_prog_bar_logger_step": max_,
        "loss_on_step_prog_bar": max_,
        "loss_on_step_prog_bar_logger": max_,
    }
    assert batch_metrics["log"] == {
        "loss_on_step_on_epoch_logger_step": max_,
        "loss_on_step_logger": max_,
        "loss_on_step_on_epoch_prog_bar_logger_step": max_,
        "loss_on_step_prog_bar_logger": max_,
    }
    assert batch_metrics["callback"] == {
        "loss_on_step": max_,
        "loss_on_step_logger": max_,
        "loss_on_step_on_epoch": max_,
        "loss_on_step_on_epoch_logger": max_,
        "loss_on_step_on_epoch_logger_step": max_,
        "loss_on_step_on_epoch_prog_bar": max_,
        "loss_on_step_on_epoch_prog_bar_logger": max_,
        "loss_on_step_on_epoch_prog_bar_logger_step": max_,
        "loss_on_step_on_epoch_prog_bar_step": max_,
        "loss_on_step_on_epoch_step": max_,
        "loss_on_step_prog_bar": max_,
        "loss_on_step_prog_bar_logger": max_,
    }

    epoch_metrics = result_collection.metrics(False)
    mean = total_value / total_batches
    assert epoch_metrics["pbar"] == {
        "loss_on_epoch_prog_bar": mean,
        "loss_on_epoch_prog_bar_logger": mean,
        "loss_on_step_on_epoch_prog_bar_epoch": mean,
        "loss_on_step_on_epoch_prog_bar_logger_epoch": mean,
    }
    assert epoch_metrics["log"] == {
        "loss_on_epoch_logger": mean,
        "loss_on_epoch_prog_bar_logger": mean,
        "loss_on_step_on_epoch_logger_epoch": mean,
        "loss_on_step_on_epoch_prog_bar_logger_epoch": mean,
    }
    assert epoch_metrics["callback"] == {
        "loss_on_epoch": mean,
        "loss_on_epoch_logger": mean,
        "loss_on_epoch_prog_bar": mean,
        "loss_on_epoch_prog_bar_logger": mean,
        "loss_on_step_on_epoch": mean,
        "loss_on_step_on_epoch_epoch": mean,
        "loss_on_step_on_epoch_logger": mean,
        "loss_on_step_on_epoch_logger_epoch": mean,
        "loss_on_step_on_epoch_prog_bar": mean,
        "loss_on_step_on_epoch_prog_bar_epoch": mean,
        "loss_on_step_on_epoch_prog_bar_logger": mean,
        "loss_on_step_on_epoch_prog_bar_logger_epoch": mean,
    }