def test_metric_result_computed_check():
    """Check ``_get_cache`` on a metric whose computed value is a multi-element tensor.

    The metric is logged with ``enable_graph=True``, so the cached value must be
    the exact tensor object stored in ``_computed`` (not a detached copy).
    """
    meta = _Metadata("foo", "bar", on_epoch=True, enable_graph=True)
    meta.sync = _Sync()
    result_metric = _ResultMetric(meta, is_tensor=True)

    expected = torch.tensor([1, 2, 3])
    result_metric._computed = expected

    # `enable_graph=True` means no detach happens, so identity (not just equality) must hold.
    assert _ResultCollection._get_cache(result_metric, on_step=False) is expected
def test_logger_sync_dist(distributed_env):
    """Check the ``sync_dist`` recommendation warning emitted by ``_get_cache``.

    Simulates ``self.log('bar', 7, ..., sync_dist=False)``. When
    ``distributed_available`` is patched to report a distributed environment,
    a ``PossibleUserWarning`` recommending ``sync_dist=True`` is expected;
    otherwise no such warning may be raised. In both cases the cached value
    must equal the logged value.
    """
    metadata = _Metadata("foo", "bar")
    metadata.sync = _Sync(_should=False)
    metric = _ResultMetric(metadata=metadata, is_tensor=True)
    metric.update(torch.tensor(7.0), 10)

    # Pick the context manager: expect the warning only in a distributed setting.
    if distributed_env:
        warning_ctx = pytest.warns
    else:
        warning_ctx = no_warning_call

    patch_target = "pytorch_lightning.trainer.connectors.logger_connector.result.distributed_available"
    warning_pattern = r"recommended to use `self.log\('bar', ..., sync_dist=True\)`"

    # The patch is entered first, then the warning context is created and entered.
    with mock.patch(patch_target, return_value=distributed_env):
        with warning_ctx(PossibleUserWarning, match=warning_pattern):
            value = _ResultCollection._get_cache(metric, on_step=False)
        assert value == 7.0