import torch

from pytorch_lightning.trainer.connectors.logger_connector.result import (
    _Metadata,
    _ResultCollection,
    _ResultMetric,
    _Sync,
)


def test_metric_result_computed_check():
    """Unit test ``_get_cache`` with multi-element tensors."""
    metadata = _Metadata("foo", "bar", on_epoch=True, enable_graph=True)
    metadata.sync = _Sync()
    rm = _ResultMetric(metadata, is_tensor=True)
    computed_value = torch.tensor([1, 2, 3])
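    # pretend the metric's epoch-level value has already been computed/reduced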
    rm._computed = computed_value
    cache = _ResultCollection._get_cache(rm, on_step=False)
    # with `enable_graph=True` the value is not detached from the autograd graph,
    # so the cache must be the very same tensor object (identity, not equality)
    assert cache is computed_value
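

# For context, a minimal sketch (illustrative, not part of the original test) of
# the user-facing call that reaches this caching path: passing
# ``enable_graph=True`` to ``self.log`` keeps the autograd graph attached to the
# logged value instead of detaching it. The module and its toy loss are assumptions.
import pytorch_lightning as pl


class _GraphLoggingSketch(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.scale = torch.nn.Parameter(torch.ones(1))

    def training_step(self, batch, batch_idx):
        loss = (self.scale * batch.float()).mean()  # toy loss attached to the graph
        self.log("bar", loss, on_epoch=True, enable_graph=True)
        return loss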
Example #2
import pytest
import torch
from unittest import mock

from pytorch_lightning.trainer.connectors.logger_connector.result import (
    _Metadata, _ResultCollection, _ResultMetric, _Sync
)
from pytorch_lightning.utilities.warnings import PossibleUserWarning
# `no_warning_call` is a context-manager helper from the Lightning test suite
# asserting that no matching warning is raised; the import path follows the
# Lightning 1.x test layout and may differ across versions.
from tests.helpers.utils import no_warning_call


# `distributed_env` is a parametrized flag, not a fixture: the test runs once
# with a (mocked) distributed environment available and once without.
@pytest.mark.parametrize("distributed_env", [True, False])
def test_logger_sync_dist(distributed_env):
    # mirrors the user call: self.log('bar', 7, ..., sync_dist=False)
    meta = _Metadata("foo", "bar")
    meta.sync = _Sync(_should=False)
    result_metric = _ResultMetric(metadata=meta, is_tensor=True)
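    # record the value 7.0 with a batch size of 10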
    result_metric.update(torch.tensor(7.0), 10)

    warning_ctx = pytest.warns if distributed_env else no_warning_call

    with mock.patch(
        "pytorch_lightning.trainer.connectors.logger_connector.result.distributed_available",
        return_value=distributed_env,
    ):
        with warning_ctx(
            PossibleUserWarning,
            match=r"recommended to use `self.log\('bar', ..., sync_dist=True\)`",
        ):
            value = _ResultCollection._get_cache(result_metric, on_step=False)
        assert value == 7.0
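

# The warning above points at the fix: opt into cross-rank reduction when
# logging in a distributed run. A minimal sketch (illustrative, not part of the
# original test) of the recommended call; the module and hook are assumptions.
import pytorch_lightning as pl


class _SyncedLoggingSketch(pl.LightningModule):
    def validation_step(self, batch, batch_idx):
        metric = torch.tensor(7.0)  # stand-in for a real metric value
        self.log("bar", metric, sync_dist=True)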