def test_invalid_log_dir(self):
    """Each unusable ``log_dir`` must make ``start_processes`` raise a specific error."""
    with tempfile.NamedTemporaryFile(dir=self.test_dir) as regular_file:
        expected_errors = {
            # a path that was never created
            "does_not_exist": FileNotFoundError,
            # a path that exists but is a regular file, not a directory
            regular_file.name: NotADirectoryError,
            # test_dir is not empty because regular_file was created inside it
            self.test_dir: RuntimeError,
        }
        for bad_dir, error_type in expected_errors.items():
            with self.subTest(log_dir=bad_dir, expected_error=error_type):
                with self.assertRaises(error_type):
                    start_processes(
                        name="echo",
                        entrypoint=echo1,
                        args={0: ("hello",)},
                        envs={0: {"RANK": "0"}},
                        log_dir=bad_dir,
                    )
def test_args_env_len_mismatch(self):
    """``start_processes`` must reject ``args`` and ``envs`` maps of different sizes."""
    one_arg_two_envs = {
        "args": {0: ("hello",)},
        "envs": {0: {"RANK": "0"}, 1: {"RANK": "1"}},
    }
    two_args_one_env = {
        "args": {0: ("hello",), 1: ("world",)},
        "envs": {0: {"RANK": "0"}},
    }
    for case in (one_arg_two_envs, two_args_one_env):
        case_args, case_envs = case["args"], case["envs"]
        with self.subTest(args=case_args, envs=case_envs):
            with self.assertRaises(RuntimeError):
                start_processes(
                    name="echo",
                    entrypoint=echo1,
                    args=case_args,
                    envs=case_envs,
                    log_dir=self.log_dir(),
                )
def test_binary_incorrect_entrypoint(self):
    """A binary entrypoint path that does not exist must raise ``FileNotFoundError``."""
    with self.assertRaises(FileNotFoundError):
        start_processes(
            name="echo",
            entrypoint="does_not_exist.py",
            # ("foo",) needs the trailing comma: ("foo") is just the str "foo",
            # not a 1-tuple of arguments like the other ranks receive.
            args={0: ("foo",), 1: ("bar",)},
            envs={0: {}, 1: {}},
            log_dir=self.log_dir(),
        )
def test_binary_exit(self):
    """Rank 0 of ``echo1.py`` exits with ``FAIL`` while rank 1 exits cleanly.

    Checks that the single failure carries the expected exit code, no signal
    name and no error-file message, that rank 0's redirected stderr contains
    the exit message, that rank 1 (not redirected) has no log files, and
    that both log tailers have stopped.
    """
    FAIL = 138
    pc = start_processes(
        name="echo",
        entrypoint=self.bin("echo1.py"),
        args={0: ("--exitcode", FAIL, "foo"), 1: ("--exitcode", 0, "bar")},
        envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
        log_dir=self.log_dir(),
        redirects={0: Std.ALL},
    )

    results = pc.wait(period=0.1)

    # consistent with the other binary/function tests: no worker outlives wait()
    self.assert_pids_noexist(pc.pids())
    self.assertTrue(results.is_failed())
    self.assertEqual(1, len(results.failures))

    failure = results.failures[0]
    # compare against the constant instead of a duplicated literal
    self.assertEqual(FAIL, failure.exitcode)
    self.assertEqual("<N/A>", failure.signal_name())
    self.assertEqual("<NONE>", failure.error_file_data["message"])
    self.assert_in_file([f"exit {FAIL} from 0"], results.stderrs[0])
    self.assert_in_file([], results.stdouts[0])
    self.assertFalse(results.stderrs[1])
    self.assertFalse(results.stdouts[1])
    self.assertTrue(pc._stderr_tail.stopped())
    self.assertTrue(pc._stdout_tail.stopped())
def test_binary(self):
    """Run two copies of ``echo1.py`` under each redirect setting and inspect the logs."""
    for redirs in redirects():
        with self.subTest(redirs=redirs):
            pc = start_processes(
                name="echo",
                entrypoint=self.bin("echo1.py"),
                args={0: ("hello",), 1: ("hello",)},
                envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
                log_dir=self.log_dir(),
                redirects=redirs,
            )
            results = pc.wait(period=0.1)
            self.assert_pids_noexist(pc.pids())

            # a binary entrypoint has no Python return value to collect
            self.assertEqual(0, len(results.return_values))
            self.assertFalse(results.is_failed())

            out_redirected = (redirs & Std.OUT) == Std.OUT
            err_redirected = (redirs & Std.ERR) == Std.ERR
            for rank in range(pc.nprocs):
                if out_redirected:
                    self.assert_in_file(
                        [f"hello stdout from {rank}"], results.stdouts[rank]
                    )
                else:
                    self.assertFalse(results.stdouts[rank])

                if err_redirected:
                    self.assert_in_file(
                        [f"hello stderr from {rank}"], results.stderrs[rank]
                    )
                else:
                    self.assertFalse(results.stderrs[rank])
def test_function_signal(self):
    """Run 2x copies of echo3 and induce a segfault on the first.

    For every start-method/redirect combination, checks that exactly one
    failure is reported with no return values. The failure must have exit
    code ``-SIGSEGV``, signal name ``SIGSEGV`` and the failing pid, and its
    error file must be ``<log_dir>/0/error.json``.
    """
    SEGFAULT = True
    for start_method, redirs in product(start_methods(), redirects()):
        # include redirs so each combination shows up as a distinct subtest
        with self.subTest(start_method=start_method, redirs=redirs):
            log_dir = self.log_dir()
            pc = start_processes(
                name="echo",
                entrypoint=echo3,
                args={0: ("hello", SEGFAULT), 1: ("world",)},
                envs={0: {}, 1: {}},
                log_dir=log_dir,
                start_method=start_method,
                redirects=redirs,
            )

            results = pc.wait(period=0.1)

            self.assert_pids_noexist(pc.pids())
            self.assertEqual(1, len(results.failures))
            self.assertFalse(results.return_values)

            failure = results.failures[0]
            error_file = failure.error_file

            self.assertEqual(-signal.SIGSEGV, failure.exitcode)
            self.assertEqual("SIGSEGV", failure.signal_name())
            self.assertEqual(pc.pids()[0], failure.pid)
            self.assertEqual(os.path.join(log_dir, "0", "error.json"), error_file)
def test_function(self):
    """Run ``echo1`` as a function entrypoint for every start method and redirect mode.

    Each rank must return ``hello_<rank>``, and its stdout/stderr logs exist
    only when the corresponding stream was redirected.
    """
    for start_method, redirs in product(start_methods(), redirects()):
        with self.subTest(start_method=start_method, redirs=redirs):
            pc = start_processes(
                name="echo",
                entrypoint=echo1,
                args={0: ("hello",), 1: ("hello",)},
                envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
                log_dir=self.log_dir(),
                start_method=start_method,
                redirects=redirs,
            )
            results = pc.wait(period=0.1)
            world = pc.nprocs

            self.assert_pids_noexist(pc.pids())
            expected_returns = {rank: f"hello_{rank}" for rank in range(world)}
            self.assertEqual(expected_returns, results.return_values)

            out_redirected = (redirs & Std.OUT) == Std.OUT
            err_redirected = (redirs & Std.ERR) == Std.ERR
            for rank in range(world):
                if out_redirected:
                    self.assert_in_file(
                        [f"hello stdout from {rank}"], results.stdouts[rank]
                    )
                else:
                    self.assertFalse(results.stdouts[rank])

                if err_redirected:
                    self.assert_in_file(
                        [f"hello stderr from {rank}"], results.stderrs[rank]
                    )
                else:
                    self.assertFalse(results.stderrs[rank])
def test_binary_signal(self):
    """Make rank 0 of ``echo3.py`` segfault and check the reported failure."""
    rank_args = {0: ("--segfault", "true", "foo"), 1: ("bar",)}
    rank_envs = {0: {"RANK": "0"}, 1: {"RANK": "1"}}
    pc = start_processes(
        name="echo",
        entrypoint=self.bin("echo3.py"),
        args=rank_args,
        envs=rank_envs,
        log_dir=self.log_dir(),
    )

    results = pc.wait(period=0.1)

    self.assert_pids_noexist(pc.pids())
    self.assertTrue(results.is_failed())
    self.assertEqual(1, len(results.failures))

    segfaulted = results.failures[0]
    # Only rules out the positive signal number as the exit code;
    # signal_name() below is what pins the signal to SIGSEGV.
    self.assertNotEqual(signal.SIGSEGV, segfaulted.exitcode)
    self.assertEqual("SIGSEGV", segfaulted.signal_name())
    self.assertEqual("<NONE>", segfaulted.error_file_data["message"])
def _start_mp(
    self, dist_infos: Dict[int, _DistInfo], spec: WorkerSpec
) -> BaseProcessContext:
    """Launch ``spec.local_world_size`` workers that run ``spec.fn`` through ``_wrap``.

    Returns the process context produced by ``start_processes``.
    """
    wrapped = MpParameters(fn=_wrap, args=(dist_infos, spec.fn, spec.args))
    # NOTE: the same MpParameters object is repeated once per local worker.
    worker_params = [wrapped] * spec.local_world_size
    return start_processes(worker_params, start_method=self._start_method)
def test_function_large_ret_val(self):
    """Workers returning very large values must still be collected by ``wait``.

    Python's multiprocessing queues are built on pipes. When a single
    object is larger than the pipe buffer, the writer blocks until the
    reader starts draining the pipe. Each worker here runs
    ``echo_large(size)``, which is expected to return a value well above a
    pipe buffer. An earlier note estimated roughly 10 MB; that depends on
    ``echo_large`` and should be confirmed there.
    """
    size = 200000
    nworkers = 4
    for start_method in start_methods():
        with self.subTest(start_method=start_method):
            pc = start_processes(
                name="echo",
                entrypoint=echo_large,
                args={rank: (size,) for rank in range(nworkers)},
                envs={rank: {} for rank in range(nworkers)},
                log_dir=self.log_dir(),
                start_method=start_method,
            )
            results = pc.wait(period=0.1)
            for rank in range(pc.nprocs):
                self.assertEqual(size, len(results.return_values[rank]))
def test_function_exit(self):
    """Run 2x copies of echo1 and make the first one exit with ``FAIL``.

    Functions that leave via an exit from Python do not generate an error
    file, even when they are decorated with ``@record``, so the failure's
    error file is reported as ``<N/A>``.
    """
    FAIL = 138
    for start_method in start_methods():
        with self.subTest(start_method=start_method):
            log_dir = self.log_dir()
            pc = start_processes(
                name="echo",
                entrypoint=echo1,
                args={0: ("hello", FAIL), 1: ("hello",)},
                envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
                log_dir=log_dir,
                start_method=start_method,
                redirects={0: Std.ERR},
            )
            results = pc.wait(period=0.1)

            self.assert_pids_noexist(pc.pids())
            self.assertTrue(results.is_failed())
            self.assertEqual(1, len(results.failures))
            self.assertFalse(results.return_values)

            first_failure = results.failures[0]
            self.assertEqual(FAIL, first_failure.exitcode)
            self.assertEqual("<N/A>", first_failure.signal_name())
            self.assertEqual(pc.pids()[0], first_failure.pid)
            # a plain exit writes no error file
            self.assertEqual("<N/A>", first_failure.error_file)
            self.assertEqual(
                f"Process failed with exitcode {FAIL}", first_failure.message
            )
            self.assertLessEqual(first_failure.timestamp, int(time.time()))

            # rank 0 redirected only stderr; rank 1 redirected nothing
            self.assert_in_file([f"exit {FAIL} from 0"], results.stderrs[0])
            for empty_log in (
                results.stdouts[0],
                results.stderrs[1],
                results.stdouts[1],
            ):
                self.assertFalse(empty_log)
            for tail in (pc._stderr_tail, pc._stdout_tail):
                self.assertTrue(tail.stopped())
def test_pcontext_wait(self):
    """``wait`` yields None when it times out and a result once the worker ends."""
    pc = start_processes(
        name="sleep",
        entrypoint=time.sleep,
        args={0: (1,)},
        envs={0: {}},
        log_dir=self.log_dir(),
        start_method="fork",
    )
    # the worker sleeps for 1s, so a 0.1s timeout expires first
    timed_out = pc.wait(timeout=0.1, period=0.01)
    self.assertIsNone(timed_out)

    finished = pc.wait(period=0.1)
    self.assertIsNotNone(finished)

    for tail in (pc._stderr_tail, pc._stdout_tail):
        self.assertTrue(tail.stopped())
def test_void_function(self):
    """Each rank running the void ``echo0`` entrypoint reports ``None`` as its return value."""
    for start_method in start_methods():
        with self.subTest(start_method=start_method):
            pc = start_processes(
                name="echo",
                entrypoint=echo0,
                args={0: ("hello",), 1: ("world",)},
                envs={0: {}, 1: {}},
                log_dir=self.log_dir(),
                start_method=start_method,
            )
            outcome = pc.wait(period=0.1)
            self.assertEqual(dict.fromkeys((0, 1)), outcome.return_values)
def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
    """Launch one local process per worker in ``worker_group`` and return their pids.

    Builds a per-worker environment (rank and world-size info, master
    address/port, restart counters, run id). It then recreates this
    attempt's log directory and starts the processes through
    ``start_processes``, keeping the resulting context on
    ``self._pcontext``.
    """
    spec = worker_group.spec
    master_addr, master_port = super()._get_master_addr_port(worker_group.store)
    restart_count = spec.max_restarts - self._remaining_restarts

    worker_args: Dict[int, Tuple] = {}
    worker_envs: Dict[int, Dict[str, str]] = {}
    for w in worker_group.workers:
        rank = w.local_rank
        env = {
            "LOCAL_RANK": str(rank),
            "RANK": str(w.global_rank),
            "GROUP_RANK": str(worker_group.group_rank),
            "ROLE_RANK": str(w.role_rank),
            "ROLE_NAME": spec.role,
            "LOCAL_WORLD_SIZE": str(spec.local_world_size),
            "WORLD_SIZE": str(w.world_size),
            "ROLE_WORLD_SIZE": str(w.role_world_size),
            "MASTER_ADDR": master_addr,
            "MASTER_PORT": str(master_port),
            "TORCHELASTIC_RESTART_COUNT": str(restart_count),
            "TORCHELASTIC_MAX_RESTARTS": str(spec.max_restarts),
            "TORCHELASTIC_RUN_ID": spec.rdzv_handler.get_run_id(),
            "NCCL_ASYNC_ERROR_HANDLING": str(1),
        }
        # forward the agent's OMP_NUM_THREADS only when it is actually set
        omp_threads = os.environ.get("OMP_NUM_THREADS")
        if omp_threads is not None:
            env["OMP_NUM_THREADS"] = omp_threads
        worker_envs[rank] = env
        worker_args[rank] = spec.args

    # Scaling events do not count towards restarts, so they reuse the same
    # attempt number; remove any log dir a previous launch left for it.
    attempt_log_dir = os.path.join(self._log_dir, f"attempt_{restart_count}")
    shutil.rmtree(attempt_log_dir, ignore_errors=True)
    os.makedirs(attempt_log_dir)

    self._pcontext = start_processes(
        name=spec.role,
        entrypoint=spec.entrypoint,
        args=worker_args,
        envs=worker_envs,
        log_dir=attempt_log_dir,
        start_method=self._start_method,
        redirects=spec.redirects,
        tee=spec.tee,
    )
    return self._pcontext.pids()
def test_multiprocess_context_close(self):
    """``close()`` leaves no live worker pids behind and stops both log tailers."""
    pc = start_processes(
        name="sleep",
        entrypoint=time.sleep,
        args={0: (1,)},
        envs={0: {}},
        log_dir=self.log_dir(),
        start_method="fork",
    )
    # capture the pids before closing so they can be checked afterwards
    launched = pc.pids()
    pc.close()

    self.assert_pids_noexist(launched)
    for tail in (pc._stderr_tail, pc._stdout_tail):
        self.assertTrue(tail.stopped())
def test_function_raise(self):
    """Run 2x copies of echo2 and raise an exception on the first.

    The raised exception must be reported as a failure with exit code 1,
    no signal, the failing pid, an error file at ``<log_dir>/0/error.json``,
    and a timestamp that agrees with the one recorded in that file.
    """
    RAISE = True
    for start_method in start_methods():
        with self.subTest(start_method=start_method):
            log_dir = self.log_dir()
            pc = start_processes(
                name="echo",
                entrypoint=echo2,
                args={0: ("hello", RAISE), 1: ("world",)},
                envs={0: {}, 1: {}},
                log_dir=log_dir,
                start_method=start_method,
            )
            results = pc.wait(period=0.1)

            self.assert_pids_noexist(pc.pids())
            self.assertEqual(1, len(results.failures))
            self.assertFalse(results.return_values)

            raised = results.failures[0]
            expected_error_file = os.path.join(log_dir, "0", "error.json")
            self.assertEqual(1, raised.exitcode)
            self.assertEqual("<N/A>", raised.signal_name())
            self.assertEqual(pc.pids()[0], raised.pid)
            self.assertEqual(expected_error_file, raised.error_file)

            recorded_ts = raised.error_file_data["message"]["extraInfo"]["timestamp"]
            self.assertEqual(int(recorded_ts), int(raised.timestamp))

            for tail in (pc._stderr_tail, pc._stdout_tail):
                self.assertTrue(tail.stopped())
def test_binary_redirect_and_tee(self):
    """Mix per-rank ``redirects`` and ``tee`` for a binary entrypoint and check the logs.

    Rank 0 redirects stderr and tees stdout; rank 1 redirects nothing and
    tees stderr.
    """
    per_rank_redirects = {0: Std.ERR, 1: Std.NONE}
    per_rank_tee = {0: Std.OUT, 1: Std.ERR}
    pc = start_processes(
        name="trainer",
        entrypoint=self.bin("echo1.py"),
        args={0: ("hello",), 1: ("world",)},
        envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
        log_dir=self.log_dir(),
        start_method="fork",
        redirects=per_rank_redirects,
        tee=per_rank_tee,
    )

    outcome = pc.wait()
    self.assertFalse(outcome.is_failed())

    self.assert_in_file(["hello stdout from 0"], pc.stdouts[0])
    self.assert_in_file(["hello stderr from 0"], pc.stderrs[0])
    self.assert_in_file(["world stderr from 1"], pc.stderrs[1])
    self.assertFalse(pc.stdouts[1])

    for tail in (pc._stderr_tail, pc._stdout_tail):
        self.assertTrue(tail.stopped())
def test_function_redirect_and_tee(self):
    """Mix per-rank ``redirects`` and ``tee`` for a function entrypoint.

    Rank 0 redirects stderr and tees stdout; rank 1 redirects nothing and
    tees stderr. The test runs once per start method returned by
    ``start_methods()``.
    """
    for start_method in start_methods():
        with self.subTest(start_method=start_method):
            log_dir = self.log_dir()
            pc = start_processes(
                name="trainer",
                entrypoint=echo1,
                args={0: ("hello",), 1: ("world",)},
                envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
                log_dir=log_dir,
                # previously hard-coded to "fork", which made the loop (and the
                # subTest label) meaningless for every other start method
                start_method=start_method,
                redirects={0: Std.ERR, 1: Std.NONE},
                tee={0: Std.OUT, 1: Std.ERR},
            )

            result = pc.wait()

            self.assertFalse(result.is_failed())
            self.assert_in_file(["hello stdout from 0"], pc.stdouts[0])
            self.assert_in_file(["hello stderr from 0"], pc.stderrs[0])
            self.assert_in_file(["world stderr from 1"], pc.stderrs[1])
            self.assertFalse(pc.stdouts[1])
            self.assertTrue(pc._stderr_tail.stopped())
            self.assertTrue(pc._stdout_tail.stopped())
def test_invoke_mp(self, mp_mock):
    """The mocked callable ``mp_mock`` must be called once with ``(params, "fork")``."""
    single = MpParameters(fn=dummy_fn, args=())
    params = [single] * 4
    start_processes(params, start_method="fork")
    mp_mock.assert_called_once_with(params, "fork")
def test_invoke_mp_no_params(self):
    """An empty parameter list is rejected with ``ValueError``."""
    no_params = []
    with self.assertRaises(ValueError):
        start_processes(no_params, start_method="fork")