def test_attributes_from_environment_variables(caplog):
    """Check the values ``SLURMEnvironment`` reports and that its setters are no-ops.

    The expected values are presumably supplied through environment variables
    patched in by the surrounding test setup, which is not visible in this
    function -- TODO confirm against the decorators/fixtures.

    Args:
        caplog: pytest's log-capturing fixture, used to inspect the debug
            messages emitted when a setter is called.
    """
    environments_logger = "pytorch_lightning.plugins.environments"
    env = SLURMEnvironment()

    # plain attributes / properties
    assert env.auto_requeue is True
    assert env.main_address == "1.1.1.1"
    assert env.main_port == 15000 + 1234

    # values exposed through getter methods, checked in a fixed order
    expected_by_getter = {
        "job_id": int("0001234"),
        "world_size": 20,
        "global_rank": 1,
        "local_rank": 2,
        "node_rank": 3,
        "job_name": "JOB",
    }
    for getter_name, expected in expected_by_getter.items():
        assert getattr(env, getter_name)() == expected

    def call_with_debug_capture(setter):
        """Call ``setter(100)`` while DEBUG records of the environments logger are captured."""
        with caplog.at_level(logging.DEBUG, logger=environments_logger):
            setter(100)

    # the global rank setter must leave the value untouched and log a message
    call_with_debug_capture(env.set_global_rank)
    assert env.global_rank() == 1
    assert "setting global rank is not allowed" in caplog.text

    # drop the records so the next message check cannot match stale output
    caplog.clear()

    # same contract for the world size setter
    call_with_debug_capture(env.set_world_size)
    assert env.world_size() == 20
    assert "setting world size is not allowed" in caplog.text
Example #2
0
    def _is_slurm_managing_tasks(self) -> bool:
        """Tell whether SLURM is managing the tasks; used when choosing the cluster environment.

        Returns:
            ``False`` when ``SLURMEnvironment.detect()`` is falsy or the SLURM job
            name is ``"bash"``. Otherwise ``True`` exactly when the
            ``SLURM_NTASKS`` environment variable equals
            ``len(self._parallel_devices) * self._num_nodes_flag``.

        Raises:
            KeyError: If the checks above pass but ``SLURM_NTASKS`` is not set.
        """
        if not SLURMEnvironment.detect():
            return False
        if SLURMEnvironment.job_name() == "bash":
            return False

        requested_total = len(self._parallel_devices) * self._num_nodes_flag
        # base 0 makes int() infer the radix from the string's prefix
        slurm_task_count = int(os.environ["SLURM_NTASKS"], 0)
        return slurm_task_count == requested_total
def test_default_attributes():
    """Check what ``SLURMEnvironment`` reports when no environment variables are set."""
    env = SLURMEnvironment()

    assert env.creates_processes_externally
    assert env.main_address == "127.0.0.1"
    assert env.main_port == 12910
    assert env.job_name() is None
    assert env.job_id() is None

    # world size, local rank and node rank are required to be passed as
    # environment variables, so each getter is expected to raise KeyError
    for required_getter in ("world_size", "local_rank", "node_rank"):
        with pytest.raises(KeyError):
            getattr(env, required_getter)()
    def _is_slurm_managing_tasks(self) -> bool:
        """Tell whether process management is left to SLURM.

        The result is ``True`` only when every one of these holds:

            - ``self.use_ddp`` or ``self.use_ddp2`` is truthy
            - ``SLURMEnvironment.detect()`` reports a SLURM cluster
            - the SLURM job name is not ``"bash"`` (interactive mode)
            - ``SLURM_NTASKS`` equals
              ``(self.num_gpus or self.num_processes) * self.num_nodes``

        Returns:
            ``True`` if all conditions above are met, ``False`` otherwise.

        Raises:
            KeyError: If the first three conditions hold but ``SLURM_NTASKS``
                is not set in the environment.
        """
        if not (self.use_ddp or self.use_ddp2):
            return False
        if not SLURMEnvironment.detect():
            return False
        if SLURMEnvironment.job_name() == "bash":
            # in interactive mode we don't manage tasks
            return False

        devices_per_node = self.num_gpus or self.num_processes
        requested_total = devices_per_node * self.num_nodes
        # base 0 makes int() infer the radix from the string's prefix
        slurm_task_count = int(os.environ["SLURM_NTASKS"], 0)
        return slurm_task_count == requested_total