def test_metric_result_respects_dtype(floating_dtype):
    """Check that ``_ResultMetric`` follows the default floating dtype.

    The metric's ``value`` state is created with the current default dtype,
    while ``cumulated_batch_size`` keeps PyTorch's default integer dtype.
    Logging an integer tensor is expected to emit a ``UserWarning`` and be
    converted to ``floating_dtype``. The final ``compute()`` is a
    batch-size-weighted mean in ``floating_dtype``.

    The global default dtype is restored in ``finally`` to its previous value,
    so a failure here cannot leak ``floating_dtype`` into other tests.
    """
    previous_dtype = torch.get_default_dtype()
    torch.set_default_dtype(floating_dtype)
    try:
        fixed_dtype = torch.long  # default by PyTorch
        metadata = _Metadata("foo", "bar")
        metadata.sync = _Sync()
        rm = _ResultMetric(metadata, is_tensor=True)
        assert rm.value.dtype == floating_dtype
        assert rm.cumulated_batch_size.dtype == fixed_dtype

        # two fixed point numbers - should be converted
        value, batch_size = torch.tensor(2), 3
        assert value.dtype == fixed_dtype
        with pytest.warns(
            UserWarning,
            match=rf"`self.log\('bar', ...\)` in your `foo` .* Converting it to {floating_dtype}",
        ):
            rm.update(value, batch_size)
        # floating and fixed
        rm.update(torch.tensor(4.0), 5)

        total = rm.compute()
        # weighted by batch size: (2 * 3 + 4 * 5) / (3 + 5)
        assert total == (2 * 3 + 4 * 5) / (5 + 3)
        assert total.dtype == floating_dtype
    finally:
        # restore to avoid impacting other tests, even if an assertion above failed
        torch.set_default_dtype(previous_dtype)
def _ddp_test_fn(rank, worldsize):
    """Per-process body: sum-reduce a ``[1.0]`` tensor via ``_Sync`` and compare to the world size.

    Each rank contributes ``1.0``, so a working ``"SUM"`` reduction yields
    exactly ``dist.get_world_size()``.
    """
    _setup_ddp(rank, worldsize)
    ones = torch.tensor([1.0])
    summing_sync = _Sync(sync_ddp_if_available, _should=True, op="SUM")
    reduced = summing_sync(ones)
    assert reduced.item() == dist.get_world_size(), "Result-Log does not work properly with DDP and Tensors"
def test_sync_dist(_):
    """Sum-reduce a ``[1.0]`` tensor through ``_Sync`` with the TPU spawn plugin's ``reduce``.

    The positional argument is ignored; presumably it is the process index
    supplied by a spawner (TODO confirm). The expected total of ``8`` assumes
    eight participating processes, which should be confirmed against the launcher.
    """
    # `_should` is the constructor keyword used by every other `_Sync(...)` call in this file.
    sync = _Sync(TPUSpawnPlugin().reduce, _should=True, op=torch.distributed.ReduceOp.SUM)
    value = torch.tensor([1.0])
    # Keep the reduced tensor itself: the old code wrapped it in a 1-tuple,
    # which has no `.item()` and made the assertion below raise AttributeError.
    value = sync(value)
    assert value.item() == 8
def test_metric_result_computed_check():
    """Check that ``_get_cache`` hands back an already-computed multi-element tensor as-is."""
    meta = _Metadata("foo", "bar", on_epoch=True, enable_graph=True)
    meta.sync = _Sync()
    result_metric = _ResultMetric(meta, is_tensor=True)

    precomputed = torch.tensor([1, 2, 3])
    result_metric._computed = precomputed

    # With `enable_graph=True` nothing is detached, so the very same object must come back.
    assert _ResultCollection._get_cache(result_metric, on_step=False) is precomputed
def test_metric_result_dtype_promotion(reduce_fx):
    """Check that ``_ResultMetric.value`` is promoted to double and never demoted back.

    The metric starts as ``torch.float`` (this assumes the default dtype is float).
    Logging a double switches it to ``torch.double``. Logging a float afterwards
    keeps it ``torch.double``, and ``compute()`` returns a double too.
    """
    meta = _Metadata("foo", "bar", reduce_fx=reduce_fx)
    meta.sync = _Sync()
    metric = _ResultMetric(meta, is_tensor=True)
    assert metric.value.dtype == torch.float

    # (dtype of the logged tensor, dtype the accumulator must have afterwards)
    steps = ((torch.double, torch.double), (torch.float, torch.double))
    for logged_dtype, expected_dtype in steps:
        metric.update(torch.tensor(0, dtype=logged_dtype), 1)
        assert metric.value.dtype == expected_dtype

    assert metric.compute().dtype == torch.double
def test_logger_sync_dist(distributed_env):
    """Check the ``sync_dist=True`` recommendation when reading a non-synced metric.

    A ``PossibleUserWarning`` is expected only when ``distributed_available``
    (patched here) reports a distributed environment. Otherwise no warning may
    be emitted. In both cases the cached value is the logged ``7.0``.
    """
    # equivalent of: self.log('bar', 7, ..., sync_dist=False)
    meta = _Metadata("foo", "bar")
    meta.sync = _Sync(_should=False)
    metric = _ResultMetric(metadata=meta, is_tensor=True)
    metric.update(torch.tensor(7.0), 10)

    expect_warning = pytest.warns if distributed_env else no_warning_call
    patched_env = mock.patch(
        "pytorch_lightning.trainer.connectors.logger_connector.result.distributed_available",
        return_value=distributed_env,
    )
    with patched_env, expect_warning(
        PossibleUserWarning,
        match=r"recommended to use `self.log\('bar', ..., sync_dist=True\)`",
    ):
        cached = _ResultCollection._get_cache(metric, on_step=False)
    assert cached == 7.0
def test_result_metric_max_min(reduce_fx, expected):
    """Check that a single value logged with ``reduce_fx`` comes back unchanged from ``compute()``.

    Presumably ``reduce_fx`` is parametrized with max/min reductions (per the
    test name) — confirm against the parametrize decorator.
    """
    meta = _Metadata("foo", "bar", reduce_fx=reduce_fx)
    meta.sync = _Sync()
    metric = _ResultMetric(meta, is_tensor=True)

    logged = torch.tensor(expected)
    metric.update(logged, 1)
    reduced = metric.compute()
    assert reduced == expected