コード例 #1
0
 def select_cluster_environment(self) -> ClusterEnvironment:
     """Pick the cluster environment to use.

     An environment already stored in ``self._cluster_environment`` is
     returned unchanged. Otherwise the candidates are checked in a fixed
     order -- SLURM, TorchElastic, Kubeflow -- and the first one whose
     check succeeds is instantiated. ``LightningEnvironment`` is used when
     none of the checks succeed.

     Returns:
         The stored environment, or a newly created one.
     """
     # An environment that is already set wins over any detection.
     if self._cluster_environment is not None:
         return self._cluster_environment

     # Order matters: the first positive check decides the result.
     if self.is_slurm_managing_tasks:
         return SLURMEnvironment()
     if TorchElasticEnvironment.is_using_torchelastic():
         return TorchElasticEnvironment()
     if KubeflowEnvironment.is_using_kubeflow():
         return KubeflowEnvironment()

     # Nothing was detected: fall back to LightningEnvironment.
     return LightningEnvironment()
コード例 #2
0
 def select_cluster_environment(self) -> ClusterEnvironment:
     """Pick the cluster environment to use.

     An environment already stored in ``self._cluster_environment`` is
     returned unchanged. Otherwise the candidates are checked in a fixed
     order -- SLURM, TorchElastic, Kubeflow, LSF -- and the first one whose
     check succeeds is instantiated. ``LightningEnvironment`` is used when
     none of the checks succeed.

     Returns:
         The stored environment, or a newly created one.
     """
     # An environment that is already set wins over any detection.
     if self._cluster_environment is not None:
         return self._cluster_environment

     # Order matters: the first positive check decides the result.
     if self._is_slurm_managing_tasks():
         slurm_env = SLURMEnvironment()
         # The info message is emitted only after the environment was built.
         rank_zero_info("Multiprocessing is handled by SLURM.")
         return slurm_env
     if TorchElasticEnvironment.is_using_torchelastic():
         return TorchElasticEnvironment()
     if KubeflowEnvironment.is_using_kubeflow():
         return KubeflowEnvironment()
     if LSFEnvironment.is_using_lsf():
         return LSFEnvironment()

     # Nothing was detected: fall back to LightningEnvironment.
     return LightningEnvironment()
コード例 #3
0
def test_default_attributes():
    """Check the default attributes when no environment variables are set.

    A fresh ``KubeflowEnvironment`` must report that processes are created
    externally and a local rank of 0. Every lookup that depends on a
    required environment variable must raise ``KeyError``.
    """
    env = KubeflowEnvironment()
    assert env.creates_processes_externally

    # Pair each required variable with the lookup that depends on it. The
    # lambdas defer evaluation so every lookup runs inside ``pytest.raises``.
    required_lookups = (
        ("MASTER_ADDR", lambda: env.main_address),
        ("MASTER_PORT", lambda: env.main_port),
        ("WORLD_SIZE", lambda: env.world_size()),
        ("RANK", lambda: env.global_rank()),
    )
    for _variable, lookup in required_lookups:
        # ``_variable`` is unused by the check; it labels the failing case
        # in a traceback.
        with pytest.raises(KeyError):
            lookup()

    assert env.local_rank() == 0
コード例 #4
0
def test_attributes_from_environment_variables(caplog):
    """Check that ``KubeflowEnvironment`` reports the expected attributes and ignores setters.

    The expected values are presumably supplied through environment
    variables patched outside this function (not visible here) -- confirm
    against the surrounding test module.

    Args:
        caplog: pytest fixture used to capture the debug messages emitted by
            the no-op setters.
    """
    logger_name = "pytorch_lightning.plugins.environments"
    env = KubeflowEnvironment()

    # Checked in order; the first mismatch fails the test.
    expected_readings = (
        (env.master_address, "1.2.3.4"),
        (env.master_port, 500),
        (env.world_size, 20),
        (env.global_rank, 1),
        (env.local_rank, 0),
        (env.node_rank, 1),
    )
    for read, expected in expected_readings:
        assert read() == expected

    # Setting the global rank must leave the value untouched and log a message.
    with caplog.at_level(logging.DEBUG, logger=logger_name):
        env.set_global_rank(100)
    assert env.global_rank() == 1
    assert "setting global rank is not allowed" in caplog.text

    # Drop captured records so the next check only sees its own message.
    caplog.clear()

    # Setting the world size must behave the same way.
    with caplog.at_level(logging.DEBUG, logger=logger_name):
        env.set_world_size(100)
    assert env.world_size() == 20
    assert "setting world size is not allowed" in caplog.text