def test_detect():
    """Verify that ``SLURMEnvironment.detect()`` follows the process environment.

    Two scenarios are checked:

    * with ``os.environ`` completely emptied, detection must be falsy;
    * with ``SLURM_NTASKS`` set to ``"2"`` on top of the current environment
      (this patch does not clear the other variables), detection must be truthy.
    """
    # Scenario 1: no environment variables at all.
    empty_environment = mock.patch.dict(os.environ, {}, clear=True)
    with empty_environment:
        detected_without_slurm = SLURMEnvironment.detect()
        assert not detected_without_slurm

    # Scenario 2: a SLURM task count is advertised.
    slurm_environment = mock.patch.dict(os.environ, {"SLURM_NTASKS": "2"})
    with slurm_environment:
        detected_with_slurm = SLURMEnvironment.detect()
        assert detected_with_slurm
Exemple #2
0
    def _is_slurm_managing_tasks(self) -> bool:
        """Return whether SLURM is managing the tasks; used when choosing the cluster environment.

        Returns:
            ``False`` when ``SLURMEnvironment.detect()`` is falsy or the SLURM job name is
            ``"bash"``. Otherwise ``True`` if and only if ``SLURM_NTASKS`` equals
            ``len(self._parallel_devices) * self._num_nodes_flag``.

        Raises:
            KeyError: If ``SLURM_NTASKS`` is unset even though the checks above passed.
            ValueError: If ``SLURM_NTASKS`` is not a decimal integer.
        """
        if not SLURMEnvironment.detect() or SLURMEnvironment.job_name() == "bash":
            return False

        total_requested_devices = len(self._parallel_devices) * self._num_nodes_flag
        # Parse as plain base-10. A second positional argument to ``int`` is the *base*,
        # not a default: base 0 applies Python literal rules, which reject zero-padded
        # values such as "08" and read "0x10" as 16.
        num_slurm_tasks = int(os.environ["SLURM_NTASKS"])
        return num_slurm_tasks == total_requested_devices
    def _is_slurm_managing_tasks(self) -> bool:
        """Returns whether we let SLURM manage the processes or not.

        Returns ``True`` if and only if these conditions match:

            - A SLURM cluster is detected
            - A distributed plugin is being used
            - The process is not launching in interactive mode
            - The number of tasks in SLURM matches the requested number of devices and nodes in the Trainer

        Raises:
            KeyError: If ``SLURM_NTASKS`` is unset even though the checks above passed.
            ValueError: If ``SLURM_NTASKS`` is not a decimal integer.
        """
        if (
            (not self.use_ddp and not self.use_ddp2)
            or not SLURMEnvironment.detect()
            or SLURMEnvironment.job_name() == "bash"  # in interactive mode we don't manage tasks
        ):
            return False

        # ``num_processes`` is only used when ``num_gpus`` is falsy (``None`` or 0).
        total_requested_devices = (self.num_gpus or self.num_processes) * self.num_nodes
        # Parse as plain base-10. A second positional argument to ``int`` is the *base*,
        # not a default: base 0 applies Python literal rules, which reject zero-padded
        # values such as "08" and read "0x10" as 16.
        num_slurm_tasks = int(os.environ["SLURM_NTASKS"])
        return num_slurm_tasks == total_requested_devices