def test_launch_shutdown(self, agent_mock_cls):
    """Verify that ``launch.main`` shuts down the rendezvous handler once.

    Args:
        agent_mock_cls: Mock standing in for the agent class; presumably
            injected by a ``patch`` decorator that is not visible in this
            part of the file (TODO confirm).
    """
    node_count = 1
    procs_per_node = 4
    cli_args = [
        f"--nnodes={node_count}",
        f"--nproc_per_node={procs_per_node}",
        "--monitor_interval=1",
        "--start_method=fork",
        path("bin/test_script.py"),
        f"--touch_file_dir={self.test_dir}",
    ]

    # The agent instance handed out by the mocked class reports a
    # SUCCEEDED run result from run().
    fake_agent = Mock()
    fake_agent.run.return_value = RunResult(WorkerState.SUCCEEDED)
    agent_mock_cls.return_value = fake_agent

    fake_rdzv_handler = Mock()
    handler_factory = "torchelastic.rendezvous.registry.get_rendezvous_handler"
    with patch(handler_factory) as get_handler_mock:
        # Make the launcher obtain our mock handler so that its shutdown()
        # call can be observed.
        get_handler_mock.return_value = fake_rdzv_handler
        launch.main(cli_args)
        fake_rdzv_handler.shutdown.assert_called_once()
def test_launch_standalone(self):
    """Run the launcher with ``--standalone`` and check every worker ran."""
    node_count = 1
    procs_per_node = 4
    total_workers = node_count * procs_per_node

    cli_args = [
        f"--nnodes={node_count}",
        f"--nproc_per_node={procs_per_node}",
        "--standalone",
        "--monitor_interval=1",
        "--start_method=fork",
        path("bin/test_script.py"),
        f"--touch_file_dir={self.test_dir}",
    ]
    launch.main(cli_args)

    # Each worker touches a file named after its global rank, so the
    # directory listing must be exactly the set of ranks.
    expected_ranks = set(map(str, range(total_workers)))
    touched_files = set(os.listdir(self.test_dir))
    self.assertSetEqual(expected_ranks, touched_files)
def test_launch_user_script_bash(self):
    """Launch a bash user script via ``--no_python`` and check all workers ran.

    Also checks that ``launch.main`` raises ``ValueError`` when
    ``--no_python`` is given without ``--use_env`` or together with
    ``--module``.
    """
    run_id = str(uuid.uuid4().int)
    nnodes = 1
    nproc_per_node = 4
    world_size = nnodes * nproc_per_node
    # Only literals that interpolate a value are f-strings; constant
    # options are plain strings.
    args = [
        f"--nnodes={nnodes}",
        f"--nproc_per_node={nproc_per_node}",
        "--rdzv_backend=etcd",
        f"--rdzv_endpoint={self._etcd_endpoint}",
        f"--rdzv_id={run_id}",
        "--monitor_interval=1",
        "--start_method=fork",
        "--no_python",
    ]
    script_args = [path("bin/test_script.sh"), f"{self.test_dir}"]

    with self.assertRaises(ValueError):
        # --no_python also requires --use_env
        launch.main(args + script_args)

    with self.assertRaises(ValueError):
        # --no_python cannot be used with --module
        launch.main(args + ["--module"] + script_args)

    launch.main(args + ["--use_env"] + script_args)

    # make sure all the workers ran
    # each worker touches a file with its global rank as the name
    self.assertSetEqual(
        {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
    )
def test_launch_user_script_python(self):
    """Launch a Python user script with the etcd backend and check all workers ran."""
    run_id = str(uuid.uuid4().int)
    nnodes = 1
    nproc_per_node = 4
    world_size = nnodes * nproc_per_node
    # Only literals that interpolate a value are f-strings; constant
    # options are plain strings.
    args = [
        f"--nnodes={nnodes}",
        f"--nproc_per_node={nproc_per_node}",
        "--rdzv_backend=etcd",
        f"--rdzv_endpoint={self._etcd_endpoint}",
        f"--rdzv_id={run_id}",
        "--monitor_interval=1",
        "--start_method=fork",
        path("bin/test_script.py"),
        f"--touch_file_dir={self.test_dir}",
    ]
    launch.main(args)

    # make sure all the workers ran
    # each worker touches a file with its global rank as the name
    self.assertSetEqual(
        {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
    )
def launch_in_proc(args):
    """Call ``launch.main`` with ``args``; its return value is discarded."""
    launch.main(args)