def __init__(self):
    """Create the evaluation epoch loop, the batch limit and separate result containers."""
    super().__init__()
    # a single batch limit or a sequence of limits; unset until configured
    self._max_batches: Optional[Union[int, Sequence[int]]] = None
    self.outputs = []
    self.evaluation_loop = EvaluationEpochLoop()
    # validation and test each get their own non-training ResultCollection
    self._val_results = ResultCollection(training=False)
    self._test_results = ResultCollection(training=False)
def __init__(self, trainer: 'pl.Trainer'):
    """Bind the loop to ``trainer`` and reset all evaluation bookkeeping.

    Args:
        trainer: the trainer instance, kept on ``self.trainer``.
    """
    self.trainer: 'pl.Trainer' = trainer

    # collected outputs and predictions (predictions stay unset until created)
    self.outputs: EPOCH_OUTPUT = []
    self.predictions: Optional[PredictionCollection] = None

    # per-run limits / counts, unknown until the dataloaders are inspected
    self.max_batches: Optional[List[Union[int, float]]] = None
    self.warning_cache = WarningCache()
    self.num_dataloaders: Optional[int] = None

    # validation and test results are tracked in separate non-training collections
    self._val_results = ResultCollection(training=False)
    self._test_results = ResultCollection(training=False)
def __init__(self, min_steps: Optional[int] = 0, max_steps: Optional[int] = -1) -> None:
    """Initialise step limits, progress trackers and per-epoch training state.

    Args:
        min_steps: stored on ``self.min_steps``.
        max_steps: maximum number of steps; ``-1`` means infinite steps. Passing ``None``
            is deprecated (a deprecation warning is emitted) and is treated as ``-1``.

    Raises:
        MisconfigurationException: if ``max_steps`` is smaller than ``-1``.
    """
    super().__init__()
    if max_steps is None:
        rank_zero_deprecation(
            "Setting `max_steps = None` is deprecated in v1.5 and will no longer be supported in v1.7."
            " Use `max_steps = -1` instead."
        )
        max_steps = -1
    elif max_steps < -1:
        raise MisconfigurationException(
            f"`max_steps` must be a non-negative integer or -1 (infinite steps). You passed in {max_steps}."
        )
    self.min_steps = min_steps
    self.max_steps = max_steps
    self.global_step: int = 0
    self.batch_progress = BatchProgress()
    self.scheduler_progress = SchedulerProgress()

    # child loops are attached later; unset until then
    self.batch_loop: Optional[TrainingBatchLoop] = None
    self.val_loop: Optional["loops.EvaluationLoop"] = None

    self._results = ResultCollection(training=True)
    self._outputs: _OUTPUTS_TYPE = []
    self._warning_cache = WarningCache()
    self._dataloader_iter: Optional[Iterator] = None
    # caches the loaded dataloader state until dataloader objects are available
    self._dataloader_state_dict: Dict[str, Any] = {}
def on_save_checkpoint(self, checkpoint) -> None:
    """Check that the trainer's results survive a state-dict round trip.

    The checks cover device placement, dropping of the sync fn from the state dict, and the
    default sync fn on a fresh collection. ``checkpoint`` is not used. ``device`` is read
    from an enclosing scope (NOTE(review): presumably the test's parametrised device; confirm).
    """
    key = 'validation_step.v'
    results = self.trainer._results
    # simplify logic
    state_dict = results.state_dict(drop_value=False)

    # check device
    assert results[key].value.device.type == device
    saved_item = state_dict['items'][key]
    assert saved_item['value'].device.type == device

    # sync fn should be kept
    assert results[key].meta.sync.fn == self.trainer.training_type_plugin.reduce
    # sync fn dropped from the state dict
    assert 'fn' not in saved_item['meta']['_sync']

    results.load_state_dict(state_dict)
    # check device after loading
    assert results[key].value.device.type == device
    # sync fn was preserved in the original result
    assert results[key].meta.sync.fn == self.trainer.training_type_plugin.reduce

    # default sync fn
    fresh = ResultCollection(False, device)
    fresh.load_state_dict(state_dict, map_location='cpu')
    assert fresh[key].meta.sync.fn == _Sync.no_op
    # check map location
    assert fresh[key].value.device.type == 'cpu'
def __init__(self, min_steps: int, max_steps: int):
    """Initialise step limits, batch counters and per-epoch training state.

    Args:
        min_steps: stored on ``self.min_steps``.
        max_steps: stored on ``self.max_steps``.
    """
    super().__init__()
    self.min_steps: int = min_steps
    self.max_steps: int = max_steps

    # --- counters -------------------------------------------------------
    self.global_step: int = 0
    # the total batch index across all epochs
    self.total_batch_idx: int = 0
    # the current batch index in the loop that runs over the dataloader(s)
    self.iteration_count: int = 0
    # the current split index when the batch gets split into chunks in truncated backprop through time
    self.split_idx: Optional[int] = None

    # --- run-time flags / state -----------------------------------------
    self._dataloader_idx: Optional[int] = None
    self._should_stop: bool = False
    self.is_last_batch: Optional[bool] = None
    self.batches_seen: int = 0
    self.warning_cache: WarningCache = WarningCache()
    self.epoch_output: Optional[List[List[STEP_OUTPUT]]] = None
    self.batch_loop: Optional[TrainingBatchLoop] = None
    self._results = ResultCollection(training=True)
def on_save_checkpoint(self, checkpoint) -> None:
    """Check that the trainer's results survive a state-dict round trip.

    The checks cover device placement, dropping of the sync fn from the state dict, and that
    a fresh collection loaded from the state dict has no sync fn. ``checkpoint`` is not used.
    ``device`` is read from an enclosing scope (NOTE(review): presumably the test's
    parametrised device; confirm).
    """
    key = "validation_step.v"
    results = self.trainer._results
    # simplify logic
    state_dict = results.state_dict(drop_value=False)

    # check device
    assert results[key].value.device.type == device
    saved_item = state_dict["items"][key]
    assert saved_item["value"].device.type == device

    # sync fn should be kept
    assert results[key].meta.sync.fn == self.trainer.training_type_plugin.reduce
    # sync fn dropped from the state dict
    assert "fn" not in saved_item["meta"]["_sync"]

    results.load_state_dict(state_dict)
    # check device after loading
    assert results[key].value.device.type == device
    # sync fn was preserved in the original result
    assert results[key].meta.sync.fn == self.trainer.training_type_plugin.reduce

    # default sync fn
    fresh = ResultCollection(False, device)
    fresh.load_state_dict(state_dict, map_location="cpu")
    assert fresh[key].meta.sync.fn is None
    # check map location
    assert fresh[key].value.device.type == "cpu"
def __init__(self):
    """Create the evaluation epoch loop and reset the loop's outputs, results and run flag."""
    super().__init__()
    # one entry per dataloader's epoch outputs is expected to be appended here (NOTE(review): confirm)
    self.outputs: List[EPOCH_OUTPUT] = []
    self.epoch_loop = EvaluationEpochLoop()
    self._results = ResultCollection(training=False)
    # single limit or sequence of limits; unset until configured
    self._max_batches: Optional[Union[int, Sequence[int]]] = None
    self._has_run: bool = False
def __init__(self, min_epochs: Optional[int] = None, max_epochs: Optional[int] = None, min_steps: Optional[int] = None, max_steps: Optional[int] = None):
    """Resolve epoch limits and build the inner training epoch loop.

    The defaults are applied only when the corresponding step limit is also unset:
    ``max_epochs`` becomes 1000 when both ``max_epochs`` and ``max_steps`` are ``None``, and
    ``min_epochs`` becomes 1 when both ``min_epochs`` and ``min_steps`` are ``None``.
    """
    super().__init__()
    no_upper_limit = max_epochs is None and max_steps is None
    self.max_epochs = 1000 if no_upper_limit else max_epochs
    no_lower_limit = min_epochs is None and min_steps is None
    self.min_epochs = 1 if no_lower_limit else min_epochs
    # step limits are forwarded to the epoch-level loop
    self.training_loop = TrainingEpochLoop(min_steps, max_steps)
    self.results = ResultCollection(training=True)
def test_result_metric_integration():
    """Log three metrics with different on_step/on_epoch settings, then check values, reset and str/repr."""
    metric_a = DummyMetric()
    metric_b = DummyMetric()
    metric_c = DummyMetric()
    result = ResultCollection(True, torch.device("cpu"))

    for _ in range(3):
        cumulative_sum = 0
        for step in range(5):
            for metric in (metric_a, metric_b, metric_c):
                metric(step)
            cumulative_sum += step

            result.log("h", "a", metric_a, on_step=True, on_epoch=True)
            result.log("h", "b", metric_b, on_step=False, on_epoch=True)
            result.log("h", "c", metric_c, on_step=True, on_epoch=False)

            assert result.metrics(True)["log"] == {"a_step": step, "c": step}

        epoch_log = result.metrics(False)["log"]
        result.reset()

        # assert metric state reset to default values
        for metric in (metric_a, metric_b, metric_c):
            assert metric.x == metric._defaults["x"]

        assert epoch_log == {"b": cumulative_sum, "a_epoch": cumulative_sum}

    result.minimize = torch.tensor(1.0)
    result.extra = {}
    expected_str = (
        "ResultCollection("
        "minimize=1.0, "
        "{"
        "'h.a': ResultMetric('a', value=DummyMetric()), "
        "'h.b': ResultMetric('b', value=DummyMetric()), "
        "'h.c': ResultMetric('c', value=DummyMetric())"
        "})"
    )
    expected_repr = (
        "{"
        "True, "
        "device(type='cpu'), "
        "minimize=tensor(1.), "
        "{'h.a': ResultMetric('a', value=DummyMetric()), "
        "'h.b': ResultMetric('b', value=DummyMetric()), "
        "'h.c': ResultMetric('c', value=DummyMetric()), "
        "'_extra': {}}"
        "}"
    )
    assert str(result) == expected_str
    assert repr(result) == expected_repr
def __init__(self, min_steps: int, max_steps: int):
    """Initialise step limits, progress trackers and per-epoch training state.

    Args:
        min_steps: stored on ``self.min_steps``.
        max_steps: stored on ``self.max_steps``.
    """
    super().__init__()
    self.min_steps: int = min_steps
    self.max_steps: int = max_steps
    self.global_step: int = 0
    # manually tracking which is the last batch is necessary for iterable dataset support
    self.is_last_batch: Optional[bool] = None

    self.batch_progress = Progress()
    self.scheduler_progress = SchedulerProgress()

    # child loops are attached later; unset until then
    self.batch_loop: Optional[TrainingBatchLoop] = None
    self.val_loop: Optional["loops.EvaluationLoop"] = None

    self._results = ResultCollection(training=True)
    self._epoch_output: Optional[List[List[STEP_OUTPUT]]] = None
def __init__(self, min_steps: int, max_steps: int):
    """Initialise step limits, batch counters, progress trackers and per-epoch state.

    Args:
        min_steps: stored on ``self.min_steps``.
        max_steps: stored on ``self.max_steps``.
    """
    super().__init__()
    self.min_steps: int = min_steps
    self.max_steps: int = max_steps

    self.global_step: int = 0
    # the total batch index across all epochs
    self.total_batch_idx: int = 0
    self.is_last_batch: Optional[bool] = None

    self.batch_progress = Progress()
    self.scheduler_progress = SchedulerProgress()

    # child loops are attached later; unset until then
    self.batch_loop: Optional[TrainingBatchLoop] = None
    self.val_loop: Optional["loops.EvaluationLoop"] = None

    self._results = ResultCollection(training=True)
    self._dataloader_idx: Optional[int] = None
    self._warning_cache: WarningCache = WarningCache()
    self._epoch_output: Optional[List[List[STEP_OUTPUT]]] = None
def __init__(
    self,
    trainer,
    max_epochs: Optional[int],
    min_epochs: Optional[int],
    max_steps: Optional[int],
    min_steps: Optional[int],
    num_sanity_val_steps: int,
):
    """Bind the loop to ``trainer`` and reset training counters, limits and trainer state.

    This also writes ``should_stop``, ``num_training_batches``, ``train_dataloader`` and
    ``num_sanity_val_steps`` onto ``trainer``.

    Args:
        trainer: the trainer instance this loop drives.
        max_epochs / min_epochs: epoch limits. Defaults of 1000 / 1 are used only when the
            matching step limit is also ``None``.
        max_steps / min_steps: step limits, stored as-is.
        num_sanity_val_steps: ``-1`` is stored on the trainer as ``float("inf")``; any other
            value is stored unchanged.
    """
    self.trainer = trainer
    self.accumulated_loss = None
    self.warning_cache = WarningCache()
    self.running_loss = TensorRunningAccum(window_length=20)
    self._skip_backward = False
    self._optimizer_freq_cumsum = None
    self._hiddens = None
    self.global_step = 0
    self.current_epoch = 0
    self.trainer.should_stop = False

    # the total batch index across all epochs
    self.total_batch_idx = 0
    # the current batch index in the loop that runs over the dataloader(s)
    self.batch_idx = 0
    # the current split index when the batch gets split into chunks in truncated backprop through time
    self.split_idx = None

    self.trainer.num_training_batches = 0
    self.trainer.train_dataloader = None

    # If neither max_epochs or max_steps is set, then use existing default of max_epochs = 1000
    if max_epochs is None and max_steps is None:
        self.max_epochs = 1000
    else:
        self.max_epochs = max_epochs

    # If neither min_epochs or min_steps is set, then use existing default of min_epochs = 1
    if min_epochs is None and min_steps is None:
        self.min_epochs = 1
    else:
        self.min_epochs = min_epochs

    self.max_steps = max_steps
    self.min_steps = min_steps

    # -1 is translated into an unbounded number of sanity validation steps
    self.trainer.num_sanity_val_steps = float("inf") if num_sanity_val_steps == -1 else num_sanity_val_steps

    self.results = ResultCollection(training=True)
def _ddp_test_fn(rank, worldsize):
    """DDP worker that logs metrics on ``cuda:{rank}`` and checks synced epoch values.

    Args:
        rank: index of this process; also selects the CUDA device the metrics live on.
        worldsize: number of processes. Epoch-level values are expected to equal the
            local cumulative sum multiplied by ``worldsize``.
    """
    _setup_ddp(rank, worldsize)
    device = f"cuda:{rank}"

    metric_a = DummyMetric().to(device)
    metric_b = DummyMetric().to(device)
    metric_c = DummyMetric().to(device)

    result = ResultCollection(True, torch.device(device))

    for _ in range(3):
        cumulative_sum = 0

        for i in range(5):
            metric_a(i)
            metric_b(i)
            metric_c(i)

            cumulative_sum += i

            result.log("h", "a", metric_a, on_step=True, on_epoch=True)
            result.log("h", "b", metric_b, on_step=False, on_epoch=True)
            result.log("h", "c", metric_c, on_step=True, on_epoch=False)

            # step-level values are local to this process
            batch_log = result.metrics(True)["log"]
            assert batch_log == {"a_step": i, "c": i}

        epoch_log = result.metrics(False)["log"]
        result.reset()

        # assert metric state reset to default values
        assert metric_a.x == metric_a._defaults["x"], (metric_a.x, metric_a._defaults["x"])
        assert metric_b.x == metric_b._defaults["x"]
        assert metric_c.x == metric_c._defaults["x"]

        # epoch-level values are aggregated over all processes
        assert epoch_log == {"b": cumulative_sum * worldsize, "a_epoch": cumulative_sum * worldsize}
def __init__(self, min_steps: int, max_steps: int):
    """Validate ``max_steps`` and initialise progress trackers and per-epoch training state.

    Args:
        min_steps: stored on ``self.min_steps``.
        max_steps: maximum number of steps; ``-1`` means infinite steps. ``None`` is also
            accepted unchanged.

    Raises:
        MisconfigurationException: if ``max_steps`` is smaller than ``-1``.
    """
    super().__init__()
    self.min_steps: int = min_steps
    # 0 and every value >= -1 are valid, so the message says "non-negative", not "positive"
    if max_steps is not None and max_steps < -1:
        raise MisconfigurationException(
            f"`max_steps` must be a non-negative integer or -1 (infinite steps). You passed in {max_steps}."
        )
    self.max_steps: int = max_steps
    self.global_step: int = 0
    # manually tracking which is the last batch is necessary for iterable dataset support
    self.is_last_batch: Optional[bool] = None
    self.batch_progress = Progress()
    self.scheduler_progress = SchedulerProgress()
    # child loops are attached later; unset until then
    self.batch_loop: Optional[TrainingBatchLoop] = None
    self.val_loop: Optional["loops.EvaluationLoop"] = None
    self._results = ResultCollection(training=True)
    self._epoch_output: Optional[List[List[STEP_OUTPUT]]] = None
def test_result_collection_simple_loop():
    """Log from several simulated hooks over two epochs and check accumulated values and batch sizes."""
    result = ResultCollection(True, torch.device("cpu"))
    current_fx_name = None
    batch_idx = None

    def lightning_log(fx, *args, **kwargs):
        nonlocal current_fx_name
        # reset the hook's non-metric results when a different hook logs outside a batch or on batch 0
        switched_hook = current_fx_name != fx
        if switched_hook and batch_idx in (None, 0):
            result.reset(metrics=False, fx=fx)
        result.log(fx, *args, **kwargs)
        current_fx_name = fx

    for fx_name in ('a0', 'a1'):
        lightning_log(fx_name, 'a', torch.tensor(0.), on_step=True, on_epoch=True)

    for epoch in range(2):
        for fx_name in ('b0', 'b1'):
            lightning_log(fx_name, 'a', torch.tensor(1.) + epoch, on_step=True, on_epoch=True)

        for batch_idx in range(2):
            for fx_name in ('c0', 'c1', 'c2'):
                lightning_log(fx_name, 'a', torch.tensor(2.) + epoch, on_step=True, on_epoch=True)

        batch_idx = None
        for fx_name in ('d0', 'd1'):
            lightning_log(fx_name, 'a', torch.tensor(3.) + epoch, on_step=False, on_epoch=True)

        # (keys, expected accumulated value, expected cumulated batch size)
        expectations = (
            (('a0.a', 'a1.a'), torch.tensor(0.), torch.tensor(1.)),
            (('b0.a', 'b1.a'), torch.tensor(1.) + epoch, torch.tensor(1.)),
            (('c0.a', 'c1.a', 'c2.a'), torch.tensor(4.) + epoch * 2, torch.tensor(2.)),
            (('d0.a', 'd1.a'), torch.tensor(3.) + epoch, torch.tensor(1.)),
        )
        for keys, expected_value, expected_size in expectations:
            for k in keys:
                assert result[k].value == expected_value, k
                assert result[k].cumulated_batch_size == expected_size, k
def test_result_metric_integration():
    """Log three metrics with different on_step/on_epoch settings, then check values, reset and str()."""
    metric_a = DummyMetric()
    metric_b = DummyMetric()
    metric_c = DummyMetric()
    result = ResultCollection(True, torch.device("cpu"))

    for _ in range(3):
        cumulative_sum = 0
        for step in range(5):
            for metric in (metric_a, metric_b, metric_c):
                metric(step)
            cumulative_sum += step

            result.log('h', 'a', metric_a, on_step=True, on_epoch=True)
            result.log('h', 'b', metric_b, on_step=False, on_epoch=True)
            result.log('h', 'c', metric_c, on_step=True, on_epoch=False)

            assert result.metrics(True)[MetricSource.LOG] == {"a_step": step, "c": step}

        epoch_log = result.metrics(False)[MetricSource.LOG]
        result.reset()

        # assert metric state reset to default values
        for metric in (metric_a, metric_b, metric_c):
            assert metric.x == metric._defaults['x']

        assert epoch_log == {"b": cumulative_sum, "a_epoch": cumulative_sum}

    expected = (
        "ResultCollection(True, cpu, {"
        "'h.a': ResultMetric('a', value=DummyMetric()), "
        "'h.b': ResultMetric('b', value=DummyMetric()), "
        "'h.c': ResultMetric('c', value=DummyMetric())"
        "})"
    )
    assert str(result) == expected
def __init__(self, min_steps: int, max_steps: int):
    """Initialise step limits, batch counters, progress trackers and per-epoch state.

    Args:
        min_steps: stored on ``self.min_steps``.
        max_steps: stored on ``self.max_steps``.
    """
    super().__init__()
    self.min_steps: int = min_steps
    self.max_steps: int = max_steps

    # --- counters -------------------------------------------------------
    self.global_step: int = 0
    # the total batch index across all epochs
    self.total_batch_idx: int = 0
    # the current split index when the batch gets split into chunks in truncated backprop through time
    self.split_idx: Optional[int] = None
    # the number of batches seen this run, updates immediately after batch_loop.run()
    # TODO: replace by progress tracking
    self.batches_seen: int = 0
    self.is_last_batch: Optional[bool] = None

    # --- progress tracking and child loops ------------------------------
    self.batch_progress = Progress()
    self.scheduler_progress = SchedulerProgress()
    self.batch_loop: Optional[TrainingBatchLoop] = None
    self.val_loop: Optional["loops.EvaluationLoop"] = None

    # --- private state --------------------------------------------------
    self._results = ResultCollection(training=True)
    self._dataloader_idx: Optional[int] = None
    self._warning_cache: WarningCache = WarningCache()
    self._epoch_output: Optional[List[List[STEP_OUTPUT]]] = None
def test_result_collection_on_tensor_with_mean_reduction():
    """Log a tensor under every flag combination and check step maxima and batch-weighted epoch means.

    Logging with both ``on_step=False`` and ``on_epoch=False`` must raise.
    """
    result_collection = ResultCollection(True, torch.device("cpu"))

    step_epoch_flags = [(True, True), (False, True), (True, False), (False, False)]
    values = torch.arange(1, 10).float()  # need to convert to float() due to precision issues using torch 1.4
    batches = values * values

    for idx, value in enumerate(values):
        for prog_bar in [False, True]:
            for logger in [False, True]:
                for on_step, on_epoch in step_epoch_flags:
                    suffixes = (
                        ("_on_step", on_step),
                        ("_on_epoch", on_epoch),
                        ("_prog_bar", prog_bar),
                        ("_logger", logger),
                    )
                    name = "loss" + "".join(suffix for suffix, enabled in suffixes if enabled)
                    log_kwargs = {
                        "fx": "training_step",
                        "name": name,
                        "value": value,
                        "on_step": on_step,
                        "on_epoch": on_epoch,
                        "batch_size": batches[idx],
                        "prog_bar": prog_bar,
                        "logger": logger,
                    }
                    if on_step or on_epoch:
                        result_collection.log(**log_kwargs)
                    else:
                        with pytest.raises(MisconfigurationException, match="on_step=False, on_epoch=False"):
                            result_collection.log(**log_kwargs)

    total_value = sum(values * batches)
    total_batches = sum(batches)
    key = "training_step.loss_on_step_on_epoch"
    assert result_collection[key].value == total_value
    assert result_collection[key].cumulated_batch_size == total_batches

    batch_metrics = result_collection.metrics(True)
    max_ = max(values)
    assert batch_metrics["pbar"] == {
        "loss_on_step_on_epoch_prog_bar_step": max_,
        "loss_on_step_on_epoch_prog_bar_logger_step": max_,
        "loss_on_step_prog_bar": max_,
        "loss_on_step_prog_bar_logger": max_,
    }
    assert batch_metrics["log"] == {
        "loss_on_step_on_epoch_logger_step": max_,
        "loss_on_step_logger": max_,
        "loss_on_step_on_epoch_prog_bar_logger_step": max_,
        "loss_on_step_prog_bar_logger": max_,
    }
    assert batch_metrics["callback"] == {
        "loss_on_step": max_,
        "loss_on_step_logger": max_,
        "loss_on_step_on_epoch": max_,
        "loss_on_step_on_epoch_logger": max_,
        "loss_on_step_on_epoch_logger_step": max_,
        "loss_on_step_on_epoch_prog_bar": max_,
        "loss_on_step_on_epoch_prog_bar_logger": max_,
        "loss_on_step_on_epoch_prog_bar_logger_step": max_,
        "loss_on_step_on_epoch_prog_bar_step": max_,
        "loss_on_step_on_epoch_step": max_,
        "loss_on_step_prog_bar": max_,
        "loss_on_step_prog_bar_logger": max_,
    }

    epoch_metrics = result_collection.metrics(False)
    # epoch values are the batch-size-weighted mean
    mean = total_value / total_batches
    assert epoch_metrics["pbar"] == {
        "loss_on_epoch_prog_bar": mean,
        "loss_on_epoch_prog_bar_logger": mean,
        "loss_on_step_on_epoch_prog_bar_epoch": mean,
        "loss_on_step_on_epoch_prog_bar_logger_epoch": mean,
    }
    assert epoch_metrics["log"] == {
        "loss_on_epoch_logger": mean,
        "loss_on_epoch_prog_bar_logger": mean,
        "loss_on_step_on_epoch_logger_epoch": mean,
        "loss_on_step_on_epoch_prog_bar_logger_epoch": mean,
    }
    assert epoch_metrics["callback"] == {
        "loss_on_epoch": mean,
        "loss_on_epoch_logger": mean,
        "loss_on_epoch_prog_bar": mean,
        "loss_on_epoch_prog_bar_logger": mean,
        "loss_on_step_on_epoch": mean,
        "loss_on_step_on_epoch_epoch": mean,
        "loss_on_step_on_epoch_logger": mean,
        "loss_on_step_on_epoch_logger_epoch": mean,
        "loss_on_step_on_epoch_prog_bar": mean,
        "loss_on_step_on_epoch_prog_bar_epoch": mean,
        "loss_on_step_on_epoch_prog_bar_logger": mean,
        "loss_on_step_on_epoch_prog_bar_logger_epoch": mean,
    }
def test_result_collection_extra_reference():
    """Check that ``extra`` and the ``"_extra"`` item refer to the very same dict object."""
    collection = ResultCollection(True)
    via_attribute = collection.extra
    via_item = collection["_extra"]
    assert via_attribute is via_item
def test_result_collection_restoration(tmpdir):
    """This test makes sure metrics are properly reloaded on failure."""
    result = ResultCollection(True, torch.device("cpu"))
    metric_a = DummyMetric()
    metric_b = DummyMetric()
    metric_c = DummyMetric()
    metric_d = DummyMetric()
    current_fx_name = None
    batch_idx = None

    def lightning_log(fx, *args, **kwargs):
        # reset ``fx``'s non-metric results when a different hook logs outside a batch or on
        # batch 0, then log using the test's sync function
        nonlocal current_fx_name
        if current_fx_name != fx and batch_idx in (None, 0):
            result.reset(metrics=False, fx=fx)
        result.log(fx, *args, **kwargs, sync_dist_fn=my_sync_dist)
        current_fx_name = fx

    for _ in range(2):
        cumulative_sum = 0

        for i in range(3):
            a = metric_a(i)
            b = metric_b(i)
            c = metric_c(i)
            metric_d(i)

            cumulative_sum += i

            metric = metric_a if i < 1 else metric_d
            lightning_log('training_step', 'a', metric, on_step=True, on_epoch=True)
            lightning_log('training_step', 'b', metric_b, on_step=False, on_epoch=True)
            lightning_log('training_step', 'c', metric_c, on_step=True, on_epoch=False)
            lightning_log('training_step', 'a_1', a, on_step=True, on_epoch=True)
            lightning_log('training_step', 'b_1', b, on_step=False, on_epoch=True)
            lightning_log('training_step', 'c_1', {'1': c, '2': c}, on_step=True, on_epoch=False)

            batch_log = result.metrics(on_step=True)[MetricSource.LOG]
            assert set(batch_log) == {"a_step", "c", "a_1_step", "c_1"}
            assert set(batch_log['c_1']) == {'1', '2'}

            result_copy = deepcopy(result)
            new_result = ResultCollection(True, torch.device("cpu"))
            state_dict = result.state_dict()
            # check the sync fn was dropped
            assert 'fn' not in state_dict['items']['training_step.a']['meta']['_sync']
            new_result.load_state_dict(state_dict)
            # should match
            assert result_copy == new_result
            # the sync fn has been kept
            assert result_copy['training_step.a'].meta.sync.fn == new_result['training_step.a'].meta.sync.fn

        epoch_log = result.metrics(on_step=False)[MetricSource.LOG]
        epoch_log_copy = result_copy.metrics(on_step=False)[MetricSource.LOG]
        assert epoch_log == epoch_log_copy

        lightning_log('train_epoch_end', 'a', metric_a, on_step=False, on_epoch=True)
        epoch_log = result.metrics(on_step=False)[MetricSource.LOG]
        assert epoch_log == {
            'a_1_epoch': 1,
            'a_epoch': cumulative_sum,
            'a': cumulative_sum,
            'b': cumulative_sum,
            'b_1': 1
        }

        # make sure can be pickled
        pickle.loads(pickle.dumps(result))
        # make sure can be torch.loaded
        filepath = str(tmpdir / 'result')
        torch.save(result, filepath)
        torch.load(filepath)

        # assert metric state reset to default values
        result.reset()
        assert metric_a.x == metric_a._defaults['x']
        assert metric_b.x == metric_b._defaults['x']
        assert metric_c.x == metric_c._defaults['x']

        batch_idx = None
def test_result_collection_restoration(tmpdir):
    """This test makes sure metrics are properly reloaded on failure."""
    result = ResultCollection(True, torch.device("cpu"))
    metric_a = DummyMetric()
    metric_b = DummyMetric()
    metric_c = DummyMetric()
    metric_d = DummyMetric()
    current_fx_name = None
    batch_idx = None

    def lightning_log(fx, *args, **kwargs):
        # reset ``fx``'s non-metric results when a different hook logs outside a batch or on
        # batch 0, then log using the test's sync function
        nonlocal current_fx_name
        if current_fx_name != fx and batch_idx in (None, 0):
            result.reset(metrics=False, fx=fx)
        result.log(fx, *args, **kwargs, sync_dist_fn=my_sync_dist)
        current_fx_name = fx

    for epoch in range(2):
        cumulative_sum = 0

        for i in range(3):
            a = metric_a(i)
            b = metric_b(i)
            c = metric_c(i)
            metric_d(i)

            cumulative_sum += i

            metric = metric_a if i < 1 else metric_d
            lightning_log("training_step", "a", metric, on_step=True, on_epoch=True, metric_attribute="metric")
            lightning_log("training_step", "b", metric_b, on_step=False, on_epoch=True, metric_attribute="metric_b")
            lightning_log("training_step", "c", metric_c, on_step=True, on_epoch=False, metric_attribute="metric_c")
            lightning_log("training_step", "a_1", a, on_step=True, on_epoch=True)
            lightning_log("training_step", "b_1", b, on_step=False, on_epoch=True)
            lightning_log("training_step", "c_1", {"1": c, "2": c}, on_step=True, on_epoch=False)

            batch_log = result.metrics(on_step=True)["log"]
            assert set(batch_log) == {"a_step", "c", "a_1_step", "c_1"}
            assert set(batch_log["c_1"]) == {"1", "2"}

            result_copy = deepcopy(result)
            new_result = ResultCollection(True, torch.device("cpu"))
            state_dict = result.state_dict()
            # check the sync fn was dropped
            assert "fn" not in state_dict["items"]["training_step.a"]["meta"]["_sync"]

            assert not new_result.result_metrics
            # 7 entries per training step (``c_1`` presumably contributes two), plus one more from
            # the second epoch on (presumably ``train_epoch_end.a`` from the previous epoch).
            # Parenthesised on purpose: ``7 + epoch > 0`` would parse as a chained comparison.
            assert len(result.result_metrics) == 7 + (epoch > 0)
            new_result.load_state_dict(
                state_dict, metrics={"metric": metric, "metric_b": metric_b, "metric_c": metric_c}
            )
            # should match
            assert result_copy == new_result
            # the sync fn has been kept
            assert result_copy["training_step.a"].meta.sync.fn == new_result["training_step.a"].meta.sync.fn

        epoch_log = result.metrics(on_step=False)["log"]
        epoch_log_copy = result_copy.metrics(on_step=False)["log"]
        assert epoch_log == epoch_log_copy

        lightning_log("train_epoch_end", "a", metric_a, on_step=False, on_epoch=True)
        epoch_log = result.metrics(on_step=False)["log"]
        assert epoch_log == {
            "a_1_epoch": 1,
            "a_epoch": cumulative_sum,
            "a": cumulative_sum,
            "b": cumulative_sum,
            "b_1": 1,
        }

        # make sure can be pickled
        pickle.loads(pickle.dumps(result))
        # make sure can be torch.loaded
        filepath = str(tmpdir / "result")
        torch.save(result, filepath)
        torch.load(filepath)

        # assert metric state reset to default values
        result.reset()
        assert metric_a.x == metric_a._defaults["x"]
        assert metric_b.x == metric_b._defaults["x"]
        assert metric_c.x == metric_c._defaults["x"]

        batch_idx = None
def test_result_collection_simple_loop():
    """Log from several simulated hooks over two epochs and check accumulated values and batch sizes."""
    result = ResultCollection(True, torch.device("cpu"))
    current_fx_name = None
    batch_idx = None

    def lightning_log(fx, *args, **kwargs):
        nonlocal current_fx_name
        # reset the hook's non-metric results when a different hook logs outside a batch or on batch 0
        switched_hook = current_fx_name != fx
        if switched_hook and batch_idx in (None, 0):
            result.reset(metrics=False, fx=fx)
        result.log(fx, *args, **kwargs)
        current_fx_name = fx

    for fx_name in ("a0", "a1"):
        lightning_log(fx_name, "a", torch.tensor(0.0), on_step=True, on_epoch=True)

    for epoch in range(2):
        for fx_name in ("b0", "b1"):
            lightning_log(fx_name, "a", torch.tensor(1.0) + epoch, on_step=True, on_epoch=True)

        for batch_idx in range(2):
            for fx_name in ("c0", "c1", "c2"):
                lightning_log(fx_name, "a", torch.tensor(2.0) + epoch, on_step=True, on_epoch=True)

        batch_idx = None
        for fx_name in ("d0", "d1"):
            lightning_log(fx_name, "a", torch.tensor(3.0) + epoch, on_step=False, on_epoch=True)

        # (keys, expected accumulated value, expected cumulated batch size)
        expectations = (
            (("a0.a", "a1.a"), torch.tensor(0.0), torch.tensor(1.0)),
            (("b0.a", "b1.a"), torch.tensor(1.0) + epoch, torch.tensor(1.0)),
            (("c0.a", "c1.a", "c2.a"), torch.tensor(4.0) + epoch * 2, torch.tensor(2.0)),
            (("d0.a", "d1.a"), torch.tensor(3.0) + epoch, torch.tensor(1.0)),
        )
        for keys, expected_value, expected_size in expectations:
            for k in keys:
                assert result[k].value == expected_value, k
                assert result[k].cumulated_batch_size == expected_size, k
def test_result_collection_on_tensor_with_mean_reduction():
    """Log a tensor under every flag combination and check step maxima and batch-weighted epoch means."""
    result_collection = ResultCollection(True, torch.device("cpu"))

    step_epoch_flags = [(True, True), (False, True), (True, False), (False, False)]
    values = torch.arange(1, 10).float()  # need to convert to float() due to precision issues using torch 1.4
    batches = values * values

    for idx, value in enumerate(values):
        for prog_bar in [False, True]:
            for logger in [False, True]:
                for on_step, on_epoch in step_epoch_flags:
                    suffixes = (
                        ("_on_step", on_step),
                        ("_on_epoch", on_epoch),
                        ("_prog_bar", prog_bar),
                        ("_logger", logger),
                    )
                    name = "loss" + "".join(suffix for suffix, enabled in suffixes if enabled)
                    result_collection.log(
                        "training_step",
                        name,
                        value,
                        on_step=on_step,
                        on_epoch=on_epoch,
                        batch_size=batches[idx],
                        prog_bar=prog_bar,
                        logger=logger,
                    )

    total_value = sum(values * batches)
    total_batches = sum(batches)
    key = "training_step.loss_on_step_on_epoch"
    assert result_collection[key].value == total_value
    assert result_collection[key].cumulated_batch_size == total_batches

    batch_metrics = result_collection.metrics(True)
    max_ = max(values)
    assert batch_metrics[MetricSource.PBAR] == {
        "loss_on_step_on_epoch_prog_bar_step": max_,
        "loss_on_step_on_epoch_prog_bar_logger_step": max_,
        "loss_on_step_prog_bar": max_,
        "loss_on_step_prog_bar_logger": max_,
    }
    assert batch_metrics[MetricSource.LOG] == {
        "loss_on_step_on_epoch_logger_step": max_,
        "loss_on_step_logger": max_,
        "loss_on_step_on_epoch_prog_bar_logger_step": max_,
        "loss_on_step_prog_bar_logger": max_,
    }
    assert batch_metrics[MetricSource.CALLBACK] == {
        "loss_on_step": max_,
        "loss_on_step_logger": max_,
        "loss_on_step_on_epoch": max_,
        "loss_on_step_on_epoch_logger": max_,
        "loss_on_step_on_epoch_logger_step": max_,
        "loss_on_step_on_epoch_prog_bar": max_,
        "loss_on_step_on_epoch_prog_bar_logger": max_,
        "loss_on_step_on_epoch_prog_bar_logger_step": max_,
        "loss_on_step_on_epoch_prog_bar_step": max_,
        "loss_on_step_on_epoch_step": max_,
        "loss_on_step_prog_bar": max_,
        "loss_on_step_prog_bar_logger": max_,
    }

    epoch_metrics = result_collection.metrics(False)
    # epoch values are the batch-size-weighted mean
    mean = total_value / total_batches
    assert epoch_metrics[MetricSource.PBAR] == {
        "loss_on_epoch_prog_bar": mean,
        "loss_on_epoch_prog_bar_logger": mean,
        "loss_on_step_on_epoch_prog_bar_epoch": mean,
        "loss_on_step_on_epoch_prog_bar_logger_epoch": mean,
    }
    assert epoch_metrics[MetricSource.LOG] == {
        "loss_on_epoch_logger": mean,
        "loss_on_epoch_prog_bar_logger": mean,
        "loss_on_step_on_epoch_logger_epoch": mean,
        "loss_on_step_on_epoch_prog_bar_logger_epoch": mean,
    }
    assert epoch_metrics[MetricSource.CALLBACK] == {
        "loss_on_epoch": mean,
        "loss_on_epoch_logger": mean,
        "loss_on_epoch_prog_bar": mean,
        "loss_on_epoch_prog_bar_logger": mean,
        "loss_on_step_on_epoch": mean,
        "loss_on_step_on_epoch_epoch": mean,
        "loss_on_step_on_epoch_logger": mean,
        "loss_on_step_on_epoch_logger_epoch": mean,
        "loss_on_step_on_epoch_prog_bar": mean,
        "loss_on_step_on_epoch_prog_bar_epoch": mean,
        "loss_on_step_on_epoch_prog_bar_logger": mean,
        "loss_on_step_on_epoch_prog_bar_logger_epoch": mean,
    }