def test_detect():
    """Check that ``SLURMEnvironment.detect`` follows the process environment.

    With an empty environment nothing must be detected; once
    ``SLURM_NTASKS`` is present, detection must succeed.
    """
    # Case 1: every environment variable is wiped for the duration of the block.
    with mock.patch.dict(os.environ, {}, clear=True):
        detected = SLURMEnvironment.detect()
        assert not detected

    # Case 2: SLURM_NTASKS is added on top of the existing environment.
    slurm_vars = {"SLURM_NTASKS": "2"}
    with mock.patch.dict(os.environ, slurm_vars):
        detected = SLURMEnvironment.detect()
        assert detected
def _is_slurm_managing_tasks(self) -> bool:
    """Tell whether SLURM's task count matches the requested device count.

    Used when choosing the cluster environment.

    Returns:
        ``False`` when no SLURM environment is detected or the SLURM job
        name is ``"bash"``; otherwise whether ``SLURM_NTASKS`` equals
        ``len(self._parallel_devices) * self._num_nodes_flag``.
    """
    if not SLURMEnvironment.detect():
        return False
    if SLURMEnvironment.job_name() == "bash":
        return False

    requested = len(self._parallel_devices) * self._num_nodes_flag
    # Base 0 makes int() infer the radix from the literal's prefix.
    slurm_tasks = int(os.environ["SLURM_NTASKS"], 0)
    return slurm_tasks == requested
def _is_slurm_managing_tasks(self) -> bool:
    """Decide whether process management is left to SLURM.

    ``True`` is returned only when every one of these holds:

    - a distributed plugin is in use (``use_ddp`` or ``use_ddp2``)
    - a SLURM cluster is detected
    - the job is not running in interactive mode (job name is not ``"bash"``)
    - ``SLURM_NTASKS`` equals the number of devices times the number of
      nodes requested in the Trainer
    """
    if not (self.use_ddp or self.use_ddp2):
        return False
    if not SLURMEnvironment.detect():
        return False
    if SLURMEnvironment.job_name() == "bash":
        # in interactive mode we don't manage tasks
        return False

    # Fall back to the process count when no GPUs are requested.
    device_count = self.num_gpus or self.num_processes
    requested = device_count * self.num_nodes
    # Base 0 makes int() infer the radix from the literal's prefix.
    slurm_tasks = int(os.environ["SLURM_NTASKS"], 0)
    return slurm_tasks == requested