def test_rank_zero_known_cluster_envs(env_vars: Mapping[str, str]):
    """Test that a ``rank_zero_only`` function runs when the patched environment resolves to rank zero.

    Args:
        env_vars: Environment variables patched into ``os.environ`` for the duration of the test. Presumably
            supplied by a parametrization that is not visible here -- confirm against the decorator.
    """
    from pytorch_lightning.utilities.rank_zero import _get_rank, rank_zero_only

    with mock.patch.dict(os.environ, env_vars):
        # ``rank_zero_only.rank`` is process-wide state. ``patch.object`` installs the rank computed from the
        # patched environment and puts the previous value back on exit, so it cannot leak into other tests.
        with mock.patch.object(rank_zero_only, "rank", _get_rank(), create=True):

            @rank_zero_only
            def foo():
                # In general the result is optional: on non-zero ranks the wrapped function is not called.
                return 1

            assert foo() == 1
def test_rank_zero_none_set(rank_key, rank):
    """Test that a ``rank_zero_only`` function is not called when the patched rank is not global zero.

    Args:
        rank_key: Name of the environment variable to set.
        rank: Value stored under ``rank_key``. Presumably a non-zero rank string supplied by a parametrization
            that is not visible here -- confirm against the decorator.
    """
    # Imported before patching so that any import-time rank computation sees the real environment.
    from pytorch_lightning.utilities.rank_zero import _get_rank, rank_zero_only

    with mock.patch.dict(os.environ, {rank_key: rank}):
        # Scope the override: leaving a non-zero ``rank_zero_only.rank`` behind would silently disable every
        # ``rank_zero_only``-decorated function for tests that run later in the same process.
        with mock.patch.object(rank_zero_only, "rank", _get_rank(), create=True):

            @rank_zero_only
            def foo():
                return 1

            # The decorator skips the call on this rank, so no value comes back.
            assert foo() is None
def test_rank_zero_priority(environ, expected_rank):
    """Check which rank ``_get_rank`` reports when several rank environment variables are set at once.

    Args:
        environ: Environment variables patched into ``os.environ`` while the rank is resolved.
        expected_rank: The rank ``_get_rank`` is expected to return for ``environ``.
    """
    with mock.patch.dict(os.environ, environ):
        from pytorch_lightning.utilities.rank_zero import _get_rank

        # Resolve the rank while the patched variables are still in effect.
        resolved_rank = _get_rank()

    assert resolved_rank == expected_rank