예제 #1
0
def test_rank_zero_known_cluster_envs(env_vars: Mapping[str, str]):
    """Test that ``rank_zero_only`` runs the wrapped function when a known cluster rank variable reports rank 0.

    Args:
        env_vars: Environment variables patched into ``os.environ`` for the duration of the test. The rank that
            ``_get_rank`` resolves from them is expected to be 0 (presumably supplied by a
            ``pytest.mark.parametrize`` outside this view — confirm).
    """
    # Import once, before the environment is patched. Re-importing inside the patch would only rebind the same
    # objects cached in ``sys.modules``.
    from pytorch_lightning.utilities.rank_zero import _get_rank, rank_zero_only

    with mock.patch.dict(os.environ, env_vars):
        # ``rank_zero_only.rank`` is shared state on the imported function object. ``patch.object`` sets it from the
        # patched environment (``_get_rank()`` is evaluated after ``env_vars`` is applied) and restores the previous
        # value on exit, so this test does not leak a rank into later tests.
        with mock.patch.object(rank_zero_only, "rank", _get_rank(), create=True):

            @rank_zero_only
            def foo():  # ``rank_zero_only`` skips this (returning None) on non-zero ranks
                return 1

            x = foo()
            assert x == 1
def test_rank_zero_none_set(rank_key: str, rank: str):
    """Test that ``rank_zero_only`` skips the wrapped function when the rank variable is not global zero.

    Args:
        rank_key: Name of the rank environment variable to set.
        rank: Value for ``rank_key``; a non-zero rank, so the decorated function must not run and ``None`` is
            returned instead.
    """
    # Import before patching the environment, so that if this happens to be the first import of the module, the
    # import does not run while the non-zero rank is in ``os.environ``.
    from pytorch_lightning.utilities.rank_zero import _get_rank, rank_zero_only

    with mock.patch.dict(os.environ, {rank_key: rank}):
        # ``rank_zero_only.rank`` is shared state. ``patch.object`` restores it on exit; without that, the non-zero
        # rank set here would persist and cause later ``rank_zero_only`` calls in the same process to be skipped.
        with mock.patch.object(rank_zero_only, "rank", _get_rank(), create=True):

            @rank_zero_only
            def foo():
                return 1

            x = foo()
            assert x is None
def test_rank_zero_priority(environ, expected_rank):
    """Check which rank ``_get_rank`` picks when several rank-related environment variables are set together.

    Args:
        environ: Environment variables patched into ``os.environ`` while the rank is resolved.
        expected_rank: The rank ``_get_rank`` must return for ``environ``.
    """
    env_patch = mock.patch.dict(os.environ, environ)
    with env_patch:
        from pytorch_lightning.utilities.rank_zero import _get_rank

        resolved_rank = _get_rank()
        assert resolved_rank == expected_rank