# Imports assumed by the tests in this section (module paths as of the v1.5/v1.6
# API referenced below). Helpers such as `DummyMetric`, `my_sync_dist`, and
# `_setup_ddp` are defined elsewhere in the test module, and the two loop
# constructors further down are library source excerpts whose own module-level
# imports are not shown.
import pickle
from copy import deepcopy

import pytest
import torch
from torchmetrics import MeanAbsoluteError, MeanSquaredError

from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
from pytorch_lightning.utilities.exceptions import MisconfigurationException


def test_result_collection_batch_size_extraction():
    fx_name = "training_step"
    log_val = torch.tensor(7.0)

    results = _ResultCollection(training=True, device="cpu")
    results.batch = torch.randn(1, 4)
    train_mse = MeanSquaredError()
    train_mse(torch.randn(4, 5), torch.randn(4, 5))
    results.log(fx_name, "train_logs", {"mse": train_mse, "log_val": log_val}, on_step=False, on_epoch=True)
    assert results.batch_size == 1
    assert isinstance(results["training_step.train_logs"]["mse"].value, MeanSquaredError)
    assert results["training_step.train_logs"]["log_val"].value == log_val

    results = _ResultCollection(training=True, device="cpu")
    results.batch = torch.randn(1, 4)
    results.log(fx_name, "train_log", log_val, on_step=False, on_epoch=True)
    assert results.batch_size == 1
    assert results["training_step.train_log"].value == log_val
    assert results["training_step.train_log"].cumulated_batch_size == 1
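# The assertions on `results.batch_size` above rely on the collection inferring a
# batch size from the currently assigned `results.batch`. A minimal sketch of that
# inference, assuming the public `extract_batch_size` helper is what backs it:
#
#     from pytorch_lightning.utilities.data import extract_batch_size
#
#     extract_batch_size(torch.randn(1, 4))  # -> 1, the size of dim 0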
# `on_save_checkpoint` hook of the LightningModule used by the enclosing test;
# `device` is captured from the enclosing test's scope.
def on_save_checkpoint(self, checkpoint) -> None:
    results = self.trainer._results
    # simplify logic
    state_dict = results.state_dict(drop_value=False)

    # check device
    assert results["validation_step.v"].value.device.type == device
    assert state_dict["items"]["validation_step.v"]["value"].device.type == device

    # sync fn should be kept
    assert results["validation_step.v"].meta.sync.fn == self.trainer.strategy.reduce
    # sync fn dropped from the state dict
    assert "fn" not in state_dict["items"]["validation_step.v"]["meta"]["_sync"]

    results.load_state_dict(state_dict)

    # check device after loading
    assert results["validation_step.v"].value.device.type == device
    # sync fn was preserved in the original result
    assert results["validation_step.v"].meta.sync.fn == self.trainer.strategy.reduce

    # default sync fn
    new_results = _ResultCollection(False, device)
    new_results.load_state_dict(state_dict, map_location="cpu")
    assert new_results["validation_step.v"].meta.sync.fn is None

    # check map location
    assert new_results["validation_step.v"].value.device.type == "cpu"
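# The `map_location="cpu"` branch above shows that a collection saved on an
# accelerator can be restored onto CPU; the keyword follows a convention
# analogous to `torch.load`'s `map_location` argument.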
def test_result_collection_no_batch_size_extraction():
    results = _ResultCollection(training=True, device="cpu")
    results.batch = torch.randn(1, 4)
    fx_name = "training_step"
    batch_size = 10
    log_val = torch.tensor(7.0)

    train_mae = MeanAbsoluteError()
    train_mae(torch.randn(4, 5), torch.randn(4, 5))
    train_mse = MeanSquaredError()
    train_mse(torch.randn(4, 5), torch.randn(4, 5))
    results.log(fx_name, "step_log_val", log_val, on_step=True, on_epoch=False)
    results.log(fx_name, "epoch_log_val", log_val, on_step=False, on_epoch=True, batch_size=batch_size)
    results.log(fx_name, "epoch_sum_log_val", log_val, on_step=True, on_epoch=True, reduce_fx="sum")
    results.log(fx_name, "train_mae", train_mae, on_step=True, on_epoch=False)
    results.log(fx_name, "train_mse", {"mse": train_mse}, on_step=True, on_epoch=False)

    assert results.batch_size is None
    assert isinstance(results["training_step.train_mse"]["mse"].value, MeanSquaredError)
    assert isinstance(results["training_step.train_mae"].value, MeanAbsoluteError)
    assert results["training_step.step_log_val"].value == log_val
    assert results["training_step.step_log_val"].cumulated_batch_size == 0
    assert results["training_step.epoch_log_val"].value == log_val * batch_size
    assert results["training_step.epoch_log_val"].cumulated_batch_size == batch_size
    assert results["training_step.epoch_sum_log_val"].value == log_val
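# No extraction is triggered in the test above: every on_epoch log either passes
# `batch_size=` explicitly or reduces with "sum", and the remaining entries are
# on_step-only values or metric objects, so `results.batch_size` stays `None`.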
# Constructor of the training epoch loop (library source, shown for context).
def __init__(self, min_steps: Optional[int] = None, max_steps: int = -1) -> None:
    super().__init__()
    if max_steps is None:
        rank_zero_deprecation(
            "Setting `max_steps = None` is deprecated in v1.5 and will no longer be supported in v1.7."
            " Use `max_steps = -1` instead."
        )
        max_steps = -1
    elif max_steps < -1:
        raise MisconfigurationException(
            f"`max_steps` must be a non-negative integer or -1 (infinite steps). You passed in {max_steps}."
        )
    self.min_steps = min_steps
    self.max_steps = max_steps

    self.batch_progress = BatchProgress()
    self.scheduler_progress = SchedulerProgress()

    self.batch_loop = TrainingBatchLoop()
    self.val_loop = loops.EvaluationLoop(verbose=False)

    self._results = _ResultCollection(training=True)
    self._outputs: _OUTPUTS_TYPE = []
    self._warning_cache = WarningCache()
    # caches the loaded dataloader state until dataloader objects are available
    self._dataloader_state_dict: Dict[str, Any] = {}
    self._batches_that_stepped: int = 0
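# A sketch of the `max_steps` validation above, assuming this constructor belongs
# to `TrainingEpochLoop` (direct construction here is for illustration only; the
# Trainer normally instantiates it):
#
#     loop = TrainingEpochLoop(max_steps=None)  # v1.5 deprecation warning, coerced to -1
#     loop = TrainingEpochLoop(max_steps=-1)    # no step cap (infinite steps)
#     loop = TrainingEpochLoop(max_steps=-2)    # raises MisconfigurationException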
def __init__(self, verbose: bool = True) -> None:
    super().__init__()
    self.epoch_loop = EvaluationEpochLoop()
    self.verbose = verbose

    self._results = _ResultCollection(training=False)
    self._outputs: List[EPOCH_OUTPUT] = []
    self._logged_outputs: List[_OUT_DICT] = []
    self._max_batches: List[int] = []
    self._has_run: bool = False
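# Note the pairing with the constructor above: the training epoch loop owns a
# `_ResultCollection(training=True)` while this evaluation loop owns one created
# with `training=False`, the same flag the tests in this section toggle.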
def test_result_metric_integration():
    metric_a = DummyMetric()
    metric_b = DummyMetric()
    metric_c = DummyMetric()

    result = _ResultCollection(True, torch.device("cpu"))

    for _ in range(3):
        cumulative_sum = 0
        for i in range(5):
            metric_a(i)
            metric_b(i)
            metric_c(i)

            cumulative_sum += i

            result.log("h", "a", metric_a, on_step=True, on_epoch=True)
            result.log("h", "b", metric_b, on_step=False, on_epoch=True)
            result.log("h", "c", metric_c, on_step=True, on_epoch=False)

            batch_log = result.metrics(True)["log"]
            assert batch_log == {"a_step": i, "c": i}

        epoch_log = result.metrics(False)["log"]
        result.reset()

        # assert metric state reset to default values
        assert metric_a.x == metric_a._defaults["x"]
        assert metric_b.x == metric_b._defaults["x"]
        assert metric_c.x == metric_c._defaults["x"]

        assert epoch_log == {"b": cumulative_sum, "a_epoch": cumulative_sum}

    result.minimize = torch.tensor(1.0)
    result.extra = {}
    assert str(result) == (
        "_ResultCollection("
        "{"
        "'h.a': _ResultMetric('a', value=DummyMetric()), "
        "'h.b': _ResultMetric('b', value=DummyMetric()), "
        "'h.c': _ResultMetric('c', value=DummyMetric())"
        "})"
    )
    assert repr(result) == (
        "{"
        "True, "
        "device(type='cpu'), "
        "{'h.a': _ResultMetric('a', value=DummyMetric()), "
        "'h.b': _ResultMetric('b', value=DummyMetric()), "
        "'h.c': _ResultMetric('c', value=DummyMetric())"
        "}}"
    )
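# Worked values for the epoch assertions above: the inner loop feeds i = 0..4, so
# cumulative_sum = 0 + 1 + 2 + 3 + 4 = 10, and the epoch log reports 10 for both
# "b" and "a_epoch" (the `DummyMetric` helper accumulates a running sum).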
def _ddp_test_fn(rank, worldsize):
    _setup_ddp(rank, worldsize)
    torch.tensor([1.0])

    metric_a = DummyMetric()
    metric_b = DummyMetric()
    metric_c = DummyMetric()

    metric_a = metric_a.to(f"cuda:{rank}")
    metric_b = metric_b.to(f"cuda:{rank}")
    metric_c = metric_c.to(f"cuda:{rank}")

    result = _ResultCollection(True, torch.device(f"cuda:{rank}"))

    for _ in range(3):
        cumulative_sum = 0
        for i in range(5):
            metric_a(i)
            metric_b(i)
            metric_c(i)

            cumulative_sum += i

            result.log("h", "a", metric_a, on_step=True, on_epoch=True)
            result.log("h", "b", metric_b, on_step=False, on_epoch=True)
            result.log("h", "c", metric_c, on_step=True, on_epoch=False)

            batch_log = result.metrics(True)["log"]
            assert batch_log == {"a_step": i, "c": i}

        epoch_log = result.metrics(False)["log"]
        result.reset()

        # assert metric state reset to default values
        assert metric_a.x == metric_a._defaults["x"], (metric_a.x, metric_a._defaults["x"])
        assert metric_b.x == metric_b._defaults["x"]
        assert metric_c.x == metric_c._defaults["x"]

        assert epoch_log == {"b": cumulative_sum * worldsize, "a_epoch": cumulative_sum * worldsize}
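# Each rank accumulates the same cumulative_sum (10 per epoch), and the metric
# sync across processes sums the per-rank states, hence the `* worldsize` in the
# final assertion. A sketch of how this worker is typically launched (assuming
# two visible GPUs):
#
#     import torch.multiprocessing as mp
#
#     worldsize = 2
#     mp.spawn(_ddp_test_fn, args=(worldsize,), nprocs=worldsize)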
def test_result_collection_simple_loop():
    result = _ResultCollection(True, torch.device("cpu"))
    current_fx_name = None
    batch_idx = None

    def lightning_log(fx, *args, **kwargs):
        nonlocal current_fx_name
        if current_fx_name != fx and batch_idx in (None, 0):
            result.reset(metrics=False, fx=fx)
        result.log(fx, *args, **kwargs)
        current_fx_name = fx

    lightning_log("a0", "a", torch.tensor(0.0), on_step=True, on_epoch=True)
    lightning_log("a1", "a", torch.tensor(0.0), on_step=True, on_epoch=True)
    for epoch in range(2):
        lightning_log("b0", "a", torch.tensor(1.0) + epoch, on_step=True, on_epoch=True)
        lightning_log("b1", "a", torch.tensor(1.0) + epoch, on_step=True, on_epoch=True)
        for batch_idx in range(2):
            lightning_log("c0", "a", torch.tensor(2.0) + epoch, on_step=True, on_epoch=True)
            lightning_log("c1", "a", torch.tensor(2.0) + epoch, on_step=True, on_epoch=True)
            lightning_log("c2", "a", torch.tensor(2.0) + epoch, on_step=True, on_epoch=True)
        batch_idx = None
        lightning_log("d0", "a", torch.tensor(3.0) + epoch, on_step=False, on_epoch=True)
        lightning_log("d1", "a", torch.tensor(3.0) + epoch, on_step=False, on_epoch=True)

        for k in ("a0.a", "a1.a"):
            assert result[k].value == torch.tensor(0.0), k
            assert result[k].cumulated_batch_size == torch.tensor(1.0), k

        for k in ("b0.a", "b1.a"):
            assert result[k].value == torch.tensor(1.0) + epoch, k
            assert result[k].cumulated_batch_size == torch.tensor(1.0), k

        for k in ("c0.a", "c1.a", "c2.a"):
            assert result[k].value == torch.tensor(4.0) + epoch * 2, k
            assert result[k].cumulated_batch_size == torch.tensor(2.0), k

        for k in ("d0.a", "d1.a"):
            assert result[k].value == torch.tensor(3.0) + epoch, k
            assert result[k].cumulated_batch_size == torch.tensor(1.0), k
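# Worked values for the "c" assertions above: each c-key is logged twice per epoch
# (once per batch) with value 2.0 + epoch, and `.value` holds the running weighted
# sum, so value = 2 * (2.0 + epoch) = 4.0 + 2 * epoch with cumulated_batch_size = 2;
# the mean would only be taken when the epoch-level metric is computed.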
def test_result_collection_restoration(tmpdir):
    """This test makes sure metrics are properly reloaded on failure."""
    result = _ResultCollection(True, torch.device("cpu"))
    metric_a = DummyMetric()
    metric_b = DummyMetric()
    metric_c = DummyMetric()
    metric_d = DummyMetric()
    current_fx_name = None
    batch_idx = None

    def lightning_log(fx, *args, **kwargs):
        nonlocal current_fx_name
        if current_fx_name != fx and batch_idx in (None, 0):
            result.reset(metrics=False, fx=fx)
        result.log(fx, *args, **kwargs, sync_dist_fn=my_sync_dist)
        current_fx_name = fx

    for epoch in range(2):
        cumulative_sum = 0

        for i in range(3):
            a = metric_a(i)
            b = metric_b(i)
            c = metric_c(i)
            metric_d(i)

            cumulative_sum += i

            metric = metric_a if i < 1 else metric_d
            lightning_log("training_step", "a", metric, on_step=True, on_epoch=True, metric_attribute="metric")
            lightning_log("training_step", "b", metric_b, on_step=False, on_epoch=True, metric_attribute="metric_b")
            lightning_log("training_step", "c", metric_c, on_step=True, on_epoch=False, metric_attribute="metric_c")
            lightning_log("training_step", "a_1", a, on_step=True, on_epoch=True)
            lightning_log("training_step", "b_1", b, on_step=False, on_epoch=True)
            lightning_log("training_step", "c_1", {"1": c, "2": c}, on_step=True, on_epoch=False)

            batch_log = result.metrics(on_step=True)["log"]
            assert set(batch_log) == {"a_step", "c", "a_1_step", "c_1"}
            assert set(batch_log["c_1"]) == {"1", "2"}

            result_copy = deepcopy(result)
            new_result = _ResultCollection(True, torch.device("cpu"))
            state_dict = result.state_dict()
            # check the sync fn was dropped
            assert "fn" not in state_dict["items"]["training_step.a"]["meta"]["_sync"]

            assert not new_result.result_metrics
            # chained comparison: the length matches and is non-zero
            assert len(result.result_metrics) == 7 + epoch > 0

            new_result.load_state_dict(
                state_dict, metrics={"metric": metric, "metric_b": metric_b, "metric_c": metric_c}
            )
            # should match
            assert result_copy == new_result
            # the sync fn has been kept
            assert result_copy["training_step.a"].meta.sync.fn == new_result["training_step.a"].meta.sync.fn

        epoch_log = result.metrics(on_step=False)["log"]
        epoch_log_copy = result_copy.metrics(on_step=False)["log"]
        assert epoch_log == epoch_log_copy

        lightning_log("train_epoch_end", "a", metric_a, on_step=False, on_epoch=True)
        epoch_log = result.metrics(on_step=False)["log"]
        assert epoch_log == {
            "a_1_epoch": 1,
            "a_epoch": cumulative_sum,
            "a": cumulative_sum,
            "b": cumulative_sum,
            "b_1": 1,
        }

        # make sure can be pickled
        pickle.loads(pickle.dumps(result))
        # make sure can be torch.loaded
        filepath = str(tmpdir / "result")
        torch.save(result, filepath)
        torch.load(filepath)

        # assert metric state reset to default values
        result.reset()
        assert metric_a.x == metric_a._defaults["x"]
        assert metric_b.x == metric_b._defaults["x"]
        assert metric_c.x == metric_c._defaults["x"]

        batch_idx = None
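# The round trip above exercises the fault-tolerance contract: `state_dict()` drops
# the (potentially unpicklable) sync fn, and `load_state_dict(..., metrics=...)`
# re-attaches the live metric objects by their `metric_attribute` names, so the
# restored collection compares equal to the original.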
def test_result_collection_on_tensor_with_mean_reduction():
    result_collection = _ResultCollection(True)
    product = [(True, True), (False, True), (True, False), (False, False)]
    values = torch.arange(1, 10)
    batches = values * values

    for i, v in enumerate(values):
        for prog_bar in [False, True]:
            for logger in [False, True]:
                for on_step, on_epoch in product:
                    name = "loss"
                    if on_step:
                        name += "_on_step"
                    if on_epoch:
                        name += "_on_epoch"
                    if prog_bar:
                        name += "_prog_bar"
                    if logger:
                        name += "_logger"
                    log_kwargs = dict(
                        fx="training_step",
                        name=name,
                        value=v,
                        on_step=on_step,
                        on_epoch=on_epoch,
                        batch_size=batches[i],
                        prog_bar=prog_bar,
                        logger=logger,
                    )
                    if not on_step and not on_epoch:
                        with pytest.raises(MisconfigurationException, match="on_step=False, on_epoch=False"):
                            result_collection.log(**log_kwargs)
                    else:
                        result_collection.log(**log_kwargs)

    total_value = sum(values * batches)
    total_batches = sum(batches)
    assert result_collection["training_step.loss_on_step_on_epoch"].value == total_value
    assert result_collection["training_step.loss_on_step_on_epoch"].cumulated_batch_size == total_batches

    batch_metrics = result_collection.metrics(True)
    max_ = max(values)
    assert batch_metrics["pbar"] == {
        "loss_on_step_on_epoch_prog_bar_step": max_,
        "loss_on_step_on_epoch_prog_bar_logger_step": max_,
        "loss_on_step_prog_bar": max_,
        "loss_on_step_prog_bar_logger": max_,
    }
    assert batch_metrics["log"] == {
        "loss_on_step_on_epoch_logger_step": max_,
        "loss_on_step_logger": max_,
        "loss_on_step_on_epoch_prog_bar_logger_step": max_,
        "loss_on_step_prog_bar_logger": max_,
    }
    assert batch_metrics["callback"] == {
        "loss_on_step": max_,
        "loss_on_step_logger": max_,
        "loss_on_step_on_epoch": max_,
        "loss_on_step_on_epoch_logger": max_,
        "loss_on_step_on_epoch_logger_step": max_,
        "loss_on_step_on_epoch_prog_bar": max_,
        "loss_on_step_on_epoch_prog_bar_logger": max_,
        "loss_on_step_on_epoch_prog_bar_logger_step": max_,
        "loss_on_step_on_epoch_prog_bar_step": max_,
        "loss_on_step_on_epoch_step": max_,
        "loss_on_step_prog_bar": max_,
        "loss_on_step_prog_bar_logger": max_,
    }

    epoch_metrics = result_collection.metrics(False)
    mean = total_value / total_batches
    assert epoch_metrics["pbar"] == {
        "loss_on_epoch_prog_bar": mean,
        "loss_on_epoch_prog_bar_logger": mean,
        "loss_on_step_on_epoch_prog_bar_epoch": mean,
        "loss_on_step_on_epoch_prog_bar_logger_epoch": mean,
    }
    assert epoch_metrics["log"] == {
        "loss_on_epoch_logger": mean,
        "loss_on_epoch_prog_bar_logger": mean,
        "loss_on_step_on_epoch_logger_epoch": mean,
        "loss_on_step_on_epoch_prog_bar_logger_epoch": mean,
    }
    assert epoch_metrics["callback"] == {
        "loss_on_epoch": mean,
        "loss_on_epoch_logger": mean,
        "loss_on_epoch_prog_bar": mean,
        "loss_on_epoch_prog_bar_logger": mean,
        "loss_on_step_on_epoch": mean,
        "loss_on_step_on_epoch_epoch": mean,
        "loss_on_step_on_epoch_logger": mean,
        "loss_on_step_on_epoch_logger_epoch": mean,
        "loss_on_step_on_epoch_prog_bar": mean,
        "loss_on_step_on_epoch_prog_bar_epoch": mean,
        "loss_on_step_on_epoch_prog_bar_logger": mean,
        "loss_on_step_on_epoch_prog_bar_logger_epoch": mean,
    }
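# Worked values for the mean reduction above: values run 1..9 with batch sizes
# equal to their squares, so total_value = sum(v**3) = 2025, total_batches =
# sum(v**2) = 285, and mean = 2025 / 285 ≈ 7.105; the step-level metrics report
# the last logged value, 9 (written as `max(values)` in the test).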