@dataclass
class Conf:
    """
    Bundle of arguments used to launch one agent (e.g. to simulate an agent
    run on a node).
    """

    entrypoint: Callable
    local_world_size: int
    args: Tuple = ()
    role: str = "default"
    redirects: Std = Std.NONE
    tee: Std = Std.NONE


@unittest.skipIf(is_tsan(), "tests incompatible with tsan")
class LocalElasticAgentTest(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # a single standalone, single-process etcd server backs all tests
        cls._etcd_server = EtcdServer()
        cls._etcd_server.start()

    @classmethod
    def tearDownClass(cls):
        # shut the shared etcd server down once every test has run
        cls._etcd_server.stop()

    def setUp(self):
        class_name = type(self).__name__
        self._test_dir = tempfile.mkdtemp(prefix=class_name)
        # the run id is the leading hex group of a random uuid
        self._run_id, _, _ = str(uuid.uuid4()).partition("-")
class LaunchTest(unittest.TestCase):
    """
    Exercises ``launch.main`` end to end against a standalone etcd server
    that is shared by all tests of this class.
    """

    @classmethod
    def setUpClass(cls):
        # start a standalone, single process etcd server to use for all tests
        cls._etcd_server = EtcdServer()
        cls._etcd_server.start()
        cls._etcd_endpoint = cls._etcd_server.get_endpoint()

    @classmethod
    def tearDownClass(cls):
        # stop the standalone etcd server
        cls._etcd_server.stop()

    def setUp(self):
        self.test_dir = tempfile.mkdtemp()

        # remove any lingering environment variables
        self._clear_pet_env_vars()

        # set a sentinel env var on the parent proc
        # this should be present on the child and gets
        # asserted in ``bin/test_script.py``
        os.environ["TEST_SENTINEL_PARENT"] = "FOOBAR"

    def tearDown(self):
        shutil.rmtree(self.test_dir)
        # do not leak ``PET_*`` variables set by a test into whatever runs next
        self._clear_pet_env_vars()

    @staticmethod
    def _clear_pet_env_vars():
        """Deletes every ``PET_``-prefixed variable from ``os.environ``."""
        # iterate over a snapshot of the keys: entries are deleted in the loop
        for env in list(os.environ):
            if env.startswith("PET_"):
                del os.environ[env]

    def _etcd_launch_args(self, nnodes, nproc_per_node, *extra_opts):
        """
        Returns the launcher options for an etcd rendezvous under a freshly
        generated run id. ``extra_opts`` are appended after the common options.
        """
        run_id = str(uuid.uuid4().int)
        return [
            f"--nnodes={nnodes}",
            f"--nproc_per_node={nproc_per_node}",
            "--rdzv_backend=etcd",
            f"--rdzv_endpoint={self._etcd_endpoint}",
            f"--rdzv_id={run_id}",
            "--monitor_interval=1",
            "--start_method=fork",
            *extra_opts,
        ]

    def _python_script_args(self):
        """Returns the python test script followed by its touch-dir argument."""
        return [path("bin/test_script.py"), f"--touch_file_dir={self.test_dir}"]

    def _bash_script_args(self):
        """Returns the bash test script followed by its touch-dir argument."""
        return [path("bin/test_script.sh"), f"{self.test_dir}"]

    def _assert_all_workers_ran(self, world_size):
        """Asserts that ``test_dir`` holds exactly one file per global rank."""
        # each worker touches a file with its global rank as the name
        self.assertSetEqual(
            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
        )

    def test_launch_user_script_python(self):
        nnodes = 1
        nproc_per_node = 4
        world_size = nnodes * nproc_per_node
        args = self._etcd_launch_args(nnodes, nproc_per_node)
        launch.main(args + self._python_script_args())

        # make sure all the workers ran
        self._assert_all_workers_ran(world_size)

    def test_launch_user_script_bash(self):
        nnodes = 1
        nproc_per_node = 4
        world_size = nnodes * nproc_per_node
        args = self._etcd_launch_args(nnodes, nproc_per_node, "--no_python")
        script_args = self._bash_script_args()

        with self.assertRaises(ValueError):
            # --no_python cannot be used with --module
            launch.main(args + ["--module"] + script_args)

        launch.main(args + script_args)

        # make sure all the workers ran
        self._assert_all_workers_ran(world_size)

    def test_launch_with_env_vars(self):
        run_id = str(uuid.uuid4().int)
        nnodes = 1
        nproc_per_node = 4
        world_size = nnodes * nproc_per_node

        os.environ["PET_NNODES"] = str(nnodes)
        os.environ["PET_NPROC_PER_NODE"] = str(nproc_per_node)
        os.environ["PET_RDZV_BACKEND"] = "etcd"
        os.environ["PET_RDZV_ENDPOINT"] = self._etcd_endpoint
        os.environ["PET_RDZV_ID"] = run_id
        os.environ["PET_MONITOR_INTERVAL"] = "1"
        os.environ["PET_START_METHOD"] = "fork"
        os.environ["PET_NO_PYTHON"] = "1"

        script_args = self._bash_script_args()

        with self.assertRaises(ValueError):
            # --no_python cannot be used with --module
            os.environ["PET_MODULE"] = "1"
            launch.main(script_args)

        os.environ["PET_MODULE"] = "0"
        launch.main(script_args)

        # make sure all the workers ran
        self._assert_all_workers_ran(world_size)

    def _test_nproc_launch_configuration(self, nproc_type, expected_number):
        """
        Launches the bash test script with ``--nproc_per_node=nproc_type`` and
        asserts that ``expected_number`` workers ran on the single node.
        """
        nnodes = 1
        args = self._etcd_launch_args(nnodes, nproc_type, "--no_python")

        launch.main(args + self._bash_script_args())

        world_size = nnodes * expected_number
        # make sure all the workers ran
        self._assert_all_workers_ran(world_size)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_nproc_launch_auto_configurations(self):
        self._test_nproc_launch_configuration("auto", os.cpu_count())

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_nproc_launch_number_configurations(self):
        self._test_nproc_launch_configuration("4", 4)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_nproc_launch_unknown_configurations(self):
        with self.assertRaises(ValueError):
            self._test_nproc_launch_configuration("unknown", 4)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    @patch("torch.cuda.is_available", return_value=True)
    @patch("torch.cuda.device_count", return_value=3)
    def test_nproc_gpu_launch_configurations(self, _mock1, _mock2):
        self._test_nproc_launch_configuration("auto", 3)
        self._test_nproc_launch_configuration("gpu", 3)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_launch_elastic(self):
        min_nodes = 1
        max_nodes = 2
        nproc_per_node = 4
        # we are only launching 1 node (even though max = 2)
        world_size = nproc_per_node
        args = self._etcd_launch_args(f"{min_nodes}:{max_nodes}", nproc_per_node)
        launch.main(args + self._python_script_args())

        # make sure all the workers ran
        self._assert_all_workers_ran(world_size)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_launch_elastic_worker_raise_exception(self):
        """
        Asserts that when the worker program fails the launcher raises an
        exception to indicate that the worker process failed.
        """
        min_nodes = 1
        max_nodes = 2
        nproc_per_node = 4
        args = self._etcd_launch_args(
            f"{min_nodes}:{max_nodes}", nproc_per_node, "--max_restarts=0"
        )
        args += [path("bin/test_script.py"), "--fail"]

        # run the launcher in a child process so its failure surfaces as an
        # exit code instead of propagating into the test process
        proc = mp.Process(target=launch_in_proc, args=(args,))
        proc.start()
        proc.join()
        self.assertEqual(1, proc.exitcode)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    @mock.patch("torchelastic.agent.server.local_elastic_agent.LocalElasticAgent.run")
    def test_launch_elastic_agent_raise_exception(self, mock_agent_run):
        """
        Asserts that when the agent raises an exception
        the launcher re-raises the original exception
        """
        min_nodes = 1
        max_nodes = 2
        nproc_per_node = 4
        args = self._etcd_launch_args(
            f"{min_nodes}:{max_nodes}", nproc_per_node, "--max_restarts=0"
        )
        args += self._python_script_args()

        mock_agent_run.side_effect = MockException
        with self.assertRaises(MockException):
            launch.main(args)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_launch_standalone(self):
        nnodes = 1
        nproc_per_node = 4
        world_size = nnodes * nproc_per_node
        args = [
            f"--nnodes={nnodes}",
            f"--nproc_per_node={nproc_per_node}",
            "--standalone",
            "--monitor_interval=1",
            "--start_method=fork",
            *self._python_script_args(),
        ]
        launch.main(args)

        # make sure all the workers ran
        self._assert_all_workers_ran(world_size)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_launch_elastic_multiple_agents(self):
        min_nodes = 1
        max_nodes = 2
        nproc_per_node = 4
        nnodes = 2
        world_size = nnodes * nproc_per_node
        # every agent gets the same args (and therefore the same run id)
        args = self._etcd_launch_args(f"{min_nodes}:{max_nodes}", nproc_per_node)
        args += self._python_script_args()

        procs = []
        for _ in range(nnodes - 1):
            p = mp.Process(target=launch.main, args=[args])
            procs.append(p)
            p.start()

        # the last agent runs in the test process itself
        launch.main(args)

        for p in procs:
            p.join()
            self.assertEqual(0, p.exitcode)

        # make sure all the workers ran
        self._assert_all_workers_ran(world_size)

    def test_min_max_nodes_parse(self):
        # NOTE: assertEqual, not assertTrue -- the second positional argument
        # of assertTrue is only the failure message, so it compares nothing
        min_nodes, max_nodes = launch.parse_min_max_nnodes("1")
        self.assertEqual(min_nodes, max_nodes)
        self.assertEqual(1, min_nodes)

        min_nodes, max_nodes = launch.parse_min_max_nnodes("2:20")
        self.assertEqual(2, min_nodes)
        self.assertEqual(20, max_nodes)

        with self.assertRaises(RuntimeError):
            launch.parse_min_max_nnodes("2:20:30")

    @patch("torchelastic.distributed.launch.LocalElasticAgent")
    def test_launch_rdzv_shutdown(self, agent_mock_cls):
        nnodes = 1
        nproc_per_node = 4
        args = [
            f"--nnodes={nnodes}",
            f"--nproc_per_node={nproc_per_node}",
            "--monitor_interval=1",
            "--start_method=fork",
            *self._python_script_args(),
        ]

        # the mocked agent reports a successful run without starting workers
        agent_mock = Mock()
        agent_mock.run.return_value = RunResult(WorkerState.SUCCEEDED)
        agent_mock_cls.return_value = agent_mock

        rdzv_handler_mock = Mock()
        with patch(
            "torchelastic.rendezvous.registry.get_rendezvous_handler"
        ) as param_mock:
            param_mock.return_value = rdzv_handler_mock
            launch.main(args)
            rdzv_handler_mock.shutdown.assert_called_once()
class LaunchTest(unittest.TestCase):
    """
    Exercises ``launch.main`` (and ``launch.wrapper_fn``) against a standalone
    etcd server that is shared by all tests of this class.
    """

    @classmethod
    def setUpClass(cls):
        # start a standalone, single process etcd server to use for all tests
        cls._etcd_server = EtcdServer()
        cls._etcd_server.start()
        cls._etcd_endpoint = cls._etcd_server.get_endpoint()

    @classmethod
    def tearDownClass(cls):
        # stop the standalone etcd server
        cls._etcd_server.stop()

    def setUp(self):
        self.test_dir = tempfile.mkdtemp()

    def tearDown(self):
        shutil.rmtree(self.test_dir)

    def _etcd_launch_args(self, nnodes, nproc_per_node, *extra_opts):
        """
        Returns the launcher options for an etcd rendezvous under a freshly
        generated run id. ``extra_opts`` are appended after the common options.
        """
        run_id = str(uuid.uuid4().int)
        return [
            f"--nnodes={nnodes}",
            f"--nproc_per_node={nproc_per_node}",
            "--rdzv_backend=etcd",
            f"--rdzv_endpoint={self._etcd_endpoint}",
            f"--rdzv_id={run_id}",
            "--monitor_interval=1",
            "--start_method=fork",
            *extra_opts,
        ]

    def _python_script_args(self):
        """Returns the python test script followed by its touch-dir argument."""
        return [path("bin/test_script.py"), f"--touch_file_dir={self.test_dir}"]

    def _bash_script_args(self):
        """Returns the bash test script followed by its touch-dir argument."""
        return [path("bin/test_script.sh"), f"{self.test_dir}"]

    def _assert_all_workers_ran(self, world_size):
        """Asserts that ``test_dir`` holds exactly one file per global rank."""
        # each worker touches a file with its global rank as the name
        self.assertSetEqual(
            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
        )

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_launch_user_script_python(self):
        nnodes = 1
        nproc_per_node = 4
        world_size = nnodes * nproc_per_node
        args = self._etcd_launch_args(nnodes, nproc_per_node)
        launch.main(args + self._python_script_args())

        # make sure all the workers ran
        self._assert_all_workers_ran(world_size)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_launch_user_script_bash(self):
        nnodes = 1
        nproc_per_node = 4
        world_size = nnodes * nproc_per_node
        args = self._etcd_launch_args(nnodes, nproc_per_node, "--no_python")
        script_args = self._bash_script_args()

        with self.assertRaises(ValueError):
            # --no_python cannot be used with --module
            launch.main(args + ["--module"] + script_args)

        launch.main(args + script_args)

        # make sure all the workers ran
        self._assert_all_workers_ran(world_size)

    # @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_wrapper_fn_kill_script_process(self):
        """
        tests that the wrapper_fn properly terminates
        the script process (the script process is the sub_sub_process of
        the agent)
        """
        nprocs = 2
        sleep = 300

        # wraps wrapper_fn to be torch.multiprocessing compatible
        # which requires rank to be passed as first argument
        def wrap_wrap(rank, *args):
            launch.wrapper_fn(*args)

        context = start_processes(
            fn=wrap_wrap,
            args=(None, (path("bin/sleep_script.py"), "--sleep", f"{sleep}")),
            nprocs=nprocs,
            join=False,
            start_method="fork",
        )
        # quick check to see that the wrapper_fn started running
        # without this join() call we don't see an exception on typos
        # and other silly mistakes (silently fails)
        context.join(timeout=-1)

        script_pids = []
        for wrapper_fn_pid in context.pids():
            script_pid = get_child_pids(wrapper_fn_pid)
            # there should only be one child of wrapper_fn
            self.assertEqual(1, len(script_pid))
            script_pids.append(script_pid[0])

        for wrapper_fn_proc in context.processes:
            wrapper_fn_proc.terminate()
            wrapper_fn_proc.join()

        # terminating the wrappers must have taken the script processes down
        for script_pid in script_pids:
            self.assertFalse(pid_exists(script_pid))

    def _test_nproc_launch_configuration(self, nproc_type, expected_number):
        """
        Launches the bash test script with ``--nproc_per_node=nproc_type`` and
        asserts that ``expected_number`` workers ran on the single node.
        """
        nnodes = 1
        args = self._etcd_launch_args(nnodes, nproc_type, "--no_python")

        launch.main(args + self._bash_script_args())

        world_size = nnodes * expected_number
        # make sure all the workers ran
        self._assert_all_workers_ran(world_size)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_nproc_launch_auto_configurations(self):
        self._test_nproc_launch_configuration("auto", os.cpu_count())

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_nproc_launch_number_configurations(self):
        self._test_nproc_launch_configuration("4", 4)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_nproc_launch_unknown_configurations(self):
        with self.assertRaises(ValueError):
            self._test_nproc_launch_configuration("unknown", 4)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    @patch("torch.cuda.is_available", return_value=True)
    @patch("torch.cuda.device_count", return_value=3)
    def test_nproc_gpu_launch_configurations(self, _mock1, _mock2):
        self._test_nproc_launch_configuration("auto", 3)
        self._test_nproc_launch_configuration("gpu", 3)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_launch_elastic(self):
        min_nodes = 1
        max_nodes = 2
        nproc_per_node = 4
        # we are only launching 1 node (even though max = 2)
        world_size = nproc_per_node
        args = self._etcd_launch_args(f"{min_nodes}:{max_nodes}", nproc_per_node)
        launch.main(args + self._python_script_args())

        # make sure all the workers ran
        self._assert_all_workers_ran(world_size)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_launch_standalone(self):
        nnodes = 1
        nproc_per_node = 4
        world_size = nnodes * nproc_per_node
        args = [
            f"--nnodes={nnodes}",
            f"--nproc_per_node={nproc_per_node}",
            "--standalone",
            "--monitor_interval=1",
            "--start_method=fork",
            *self._python_script_args(),
        ]
        launch.main(args)

        # make sure all the workers ran
        self._assert_all_workers_ran(world_size)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_launch_elastic_multiple_agents(self):
        min_nodes = 1
        max_nodes = 2
        nproc_per_node = 4
        nnodes = 2
        world_size = nnodes * nproc_per_node
        # every agent gets the same args (and therefore the same run id)
        args = self._etcd_launch_args(f"{min_nodes}:{max_nodes}", nproc_per_node)
        args += self._python_script_args()

        procs = []
        for _ in range(nnodes - 1):
            p = mp.Process(target=launch.main, args=[args])
            procs.append(p)
            p.start()

        # the last agent runs in the test process itself
        launch.main(args)

        for p in procs:
            p.join()
            self.assertEqual(0, p.exitcode)

        # make sure all the workers ran
        self._assert_all_workers_ran(world_size)

    def test_min_max_nodes_parse(self):
        # NOTE: assertEqual, not assertTrue -- the second positional argument
        # of assertTrue is only the failure message, so it compares nothing
        min_nodes, max_nodes = launch.parse_min_max_nnodes("1")
        self.assertEqual(min_nodes, max_nodes)
        self.assertEqual(1, min_nodes)

        min_nodes, max_nodes = launch.parse_min_max_nnodes("2:20")
        self.assertEqual(2, min_nodes)
        self.assertEqual(20, max_nodes)

        with self.assertRaises(RuntimeError):
            launch.parse_min_max_nnodes("2:20:30")

    @patch("torchelastic.distributed.launch.LocalElasticAgent")
    def test_launch_rdzv_shutdown(self, _):
        nnodes = 1
        nproc_per_node = 4
        args = [
            f"--nnodes={nnodes}",
            f"--nproc_per_node={nproc_per_node}",
            "--monitor_interval=1",
            "--start_method=fork",
            *self._python_script_args(),
        ]
        rdzv_handler_mock = Mock()
        with patch(
            "torchelastic.rendezvous.registry.get_rendezvous_handler"
        ) as param_mock:
            param_mock.return_value = rdzv_handler_mock
            launch.main(args)
            rdzv_handler_mock.shutdown.assert_called_once()
class LocalElasticAgentTest(unittest.TestCase):
    """
    Tests for ``LocalElasticAgent`` that run against a standalone etcd server
    shared by all tests of this class. Workers are started with ``fork``.
    """

    @classmethod
    def setUpClass(cls):
        # start a standalone, single process etcd server to use for all tests
        cls._etcd_server = EtcdServer()
        cls._etcd_server.start()

    @classmethod
    def tearDownClass(cls):
        # stop the standalone etcd server
        cls._etcd_server.stop()

    def setUp(self):
        # clear env vars
        os.environ.pop("TORCHELASTIC_ERROR_FILE", None)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_happy_function(self):
        spec = self._get_worker_spec(fn=_happy_function)
        agent = LocalElasticAgent(spec, start_method="fork")
        agent.run()

    def _get_worker_spec(
        self,
        fn=None,
        cmd=None,
        args=(),
        max_restarts=1,
        num_agents=1,
        monitor_interval=0.1,
        local_world_size=8,
    ):
        """
        Builds a ``WorkerSpec`` for role ``test_trainer`` whose rendezvous is an
        etcd handler created under a fresh run id, with both ``min_nodes`` and
        ``max_nodes`` set to ``num_agents``.
        """
        run_id = str(uuid.uuid4().int)

        rdzv_params = RendezvousParameters(
            backend="etcd",
            endpoint=f"{self._etcd_server.get_endpoint()}",
            run_id=run_id,
            min_nodes=num_agents,
            max_nodes=num_agents,
        )
        rdzv_handler = rdzv_registry.get_rendezvous_handler(rdzv_params)
        spec = WorkerSpec(
            role="test_trainer",
            local_world_size=local_world_size,
            fn=fn,
            cmd=cmd,
            args=args,
            rdzv_handler=rdzv_handler,
            max_restarts=max_restarts,
            monitor_interval=monitor_interval,
        )
        return spec

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_check_role_name(self):
        spec = self._get_worker_spec(fn=_get_env_var, args=("ROLE_NAME",))
        agent = LocalElasticAgent(spec, start_method="fork")
        group_result = agent.run()
        results = group_result.return_values
        for role_name in results.values():
            # assertEqual: the ``assertEquals`` alias is deprecated
            self.assertEqual(spec.role, role_name)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_distributed_sum(self):
        spec = self._get_worker_spec(fn=_distributed_sum, args=(0,))
        agent = LocalElasticAgent(spec, start_method="fork")
        agent.run()

    class RoleConfig:
        """
        Describes one role of a test job. ``workers`` lists the number of
        workers per agent; when both ``num_agents`` and ``workers_num`` are
        non-zero it is replaced by ``num_agents`` entries of ``workers_num``.
        ``role_size`` is the sum of ``workers``.
        """

        __slots__ = ["role", "workers", "num_agents", "workers_num", "role_size"]

        def __init__(
            self, role: str, workers=None, num_agents: int = 0, workers_num: int = 0
        ):
            self.role = role
            self.workers = workers
            if workers_num != 0 and num_agents != 0:
                self.workers = [workers_num] * num_agents
            self.role_size = sum(self.workers)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_correct_rank_assignment_heterogeneous(self):
        roles_config = [
            self.RoleConfig("trainer", workers=[1, 2, 3, 4]),
            self.RoleConfig("ps", workers=[5, 2]),
            # split configuration to run the last one on the main process
            self.RoleConfig("master", workers=[8]),
        ]
        self.run_configuration(roles_config, 25)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_correct_rank_assignment_homogeneous(self):
        num_workers = 4
        roles_config = [
            self.RoleConfig("trainer", num_agents=4, workers_num=num_workers),
            self.RoleConfig("ps", num_agents=2, workers_num=num_workers),
            # split configuration to run the last one on the main process
            self.RoleConfig("master", num_agents=1, workers_num=num_workers),
        ]
        self.run_configuration(roles_config, 28)

    def run_configuration(self, roles_config, expected_world_size):
        """
        Runs one agent per entry of every role's ``workers`` list and verifies
        the rank assignment that the agents report through a shared dict.

        Agents of all roles but the last run in child processes; the first
        agent of the last role runs in the test process.
        """
        host = self._etcd_server.get_host()
        port = self._etcd_server.get_port()
        nnodes = sum(len(cfg.workers) for cfg in roles_config)
        run_id = str(uuid.uuid4().int)

        procs = []
        manager = multiprocessing.Manager()
        return_dict = manager.dict()
        default_args = (run_id, host, port, nnodes, nnodes, _check_rank_assignment, ())
        for config in roles_config[:-1]:
            for num_workers in config.workers:
                p = multiprocessing.Process(
                    target=_run_agent,
                    args=(*default_args, num_workers, config.role, return_dict),
                )
                procs.append(p)
                p.start()

        # run one on the main process for debugging
        config = roles_config[-1]
        _run_agent(*default_args, config.workers[0], config.role, return_dict)

        for p in procs:
            p.join()
            self.assertEqual(0, p.exitcode)
        role_info_dict = {role_info.role: role_info for role_info in roles_config}
        self.verify_rank_consistency(return_dict, role_info_dict, expected_world_size)

    def verify_rank_consistency(self, return_dict, role_info_dict, expected_world_size):
        """
        Checks the rank tuples reported by every worker: world sizes match the
        expectation, global ranks and per-role ranks each cover ``0..size-1``
        exactly once, and the ranks within one group are consecutive.
        """
        role_ranks = {}
        global_ranks = []
        grouped_ranks = {}
        for role, group_result in return_dict.values():
            res = group_result.return_values
            for (
                group_rank,
                rank,
                world_size,
                role_rank,
                role_world_size,
            ) in res.values():
                role_info_config = role_info_dict[role]
                self.assertEqual(expected_world_size, world_size)
                self.assertEqual(role_info_config.role_size, role_world_size)
                grouped_ranks.setdefault(group_rank, []).append((rank, role_rank))
                global_ranks.append(rank)
                role_ranks.setdefault(role, []).append(role_rank)
        global_ranks = sorted(global_ranks)
        self.assertEqual(list(range(0, expected_world_size)), global_ranks)
        for role, role_config_info in role_info_dict.items():
            self.assertEqual(
                list(range(0, role_config_info.role_size)), sorted(role_ranks[role])
            )
        # Make sure that each agent assigns consecutive ranks to workers
        # The first argument is the global_rank and the second argument
        # is role_rank
        for ranks_lst in grouped_ranks.values():
            self.verify_ranks_sequential(ranks_lst, 0)
            self.verify_ranks_sequential(ranks_lst, 1)

    def verify_ranks_sequential(self, ranks_pairs, rank_idx):
        """
        Asserts that element ``rank_idx`` of the given pairs forms, once
        sorted, a gap-free run of consecutive integers.
        """
        ranks = sorted(rank_pair[rank_idx] for rank_pair in ranks_pairs)
        start_rank, end_rank = ranks[0], ranks[-1]
        self.assertEqual(list(range(start_rank, end_rank + 1)), ranks)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_distributed_sum_heterogenous(self):
        host = self._etcd_server.get_host()
        port = self._etcd_server.get_port()
        nnodes = 4
        run_id = str(uuid.uuid4().int)

        procs = []
        default_args = (run_id, host, port, nnodes, nnodes, _distributed_sum, (0,))
        for ind in range(nnodes - 1):
            p = multiprocessing.Process(
                target=_run_agent, args=(*default_args, ind + 1)
            )
            procs.append(p)
            p.start()

        # run one on the main process for debugging
        _run_agent(*default_args, 8)

        for p in procs:
            p.join()
            self.assertEqual(0, p.exitcode)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_sad_function(self):
        self._test_run_sad_function()

    @record
    def _test_run_sad_function(self):
        spec = self._get_worker_spec(fn=_sad_function, max_restarts=0)
        agent = LocalElasticAgent(spec, start_method="fork")
        group_results = agent.run()
        failed_results = group_results.failures
        self.assertEqual(spec.local_world_size, len(failed_results))
        # all ranks will have the same result
        for result in failed_results.values():
            self.assertTrue(os.path.exists(result.error_file))
            with open(result.error_file, "r") as f:
                data = f.read().replace("\n", "")
                self.assertIn("RuntimeError: sad because i throw", data)
        self.assertEqual(WorkerState.FAILED, agent.get_worker_group().state)
        self.assertEqual(0, agent._remaining_restarts)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_bipolar_function(self):
        spec = self._get_worker_spec(fn=_bipolar_function, max_restarts=2)
        agent = LocalElasticAgent(spec, start_method="fork")
        agent.run()
        self.assertEqual(WorkerState.FAILED, agent.get_worker_group().state)
        self.assertEqual(0, agent._remaining_restarts)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_check_env_function(self):
        spec = self._get_worker_spec(fn=_check_env_function, max_restarts=2)
        agent = LocalElasticAgent(spec, start_method="fork")
        agent.run()

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_check_run_id(self):
        def return_run_id():
            return os.environ["TORCHELASTIC_RUN_ID"]

        spec = self._get_worker_spec(fn=return_run_id, max_restarts=0)
        agent = LocalElasticAgent(spec, start_method="fork")
        group_result = agent.run()
        results = group_result.return_values

        for i in range(spec.local_world_size):
            self.assertEqual(spec.rdzv_handler.get_run_id(), results[i])

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_get_worker_return_values(self):
        spec = self._get_worker_spec(fn=_return_rank_times, args=(2,))
        agent = LocalElasticAgent(spec, start_method="fork")
        group_result = agent.run()
        results = group_result.return_values

        self.assertEqual(spec.local_world_size, len(results))
        for i in range(spec.local_world_size):
            self.assertEqual(i * 2, results[i])

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_double_agent_happy(self):
        host = self._etcd_server.get_host()
        port = self._etcd_server.get_port()
        nnodes = 2
        run_id = str(uuid.uuid4().int)

        procs = []
        for _ in range(nnodes - 1):
            p = multiprocessing.Process(
                target=_run_agent,
                args=(run_id, host, port, nnodes, nnodes, _distributed_sum, (0,)),
            )
            procs.append(p)
            p.start()

        # run one on the main process for debugging
        _run_agent(run_id, host, port, nnodes, nnodes, _distributed_sum, (0,))

        for p in procs:
            p.join()
            self.assertEqual(0, p.exitcode)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_double_agent_fault_tolerance(self):
        host = self._etcd_server.get_host()
        port = self._etcd_server.get_port()
        nnodes = 2
        run_id = str(uuid.uuid4().int)

        procs = []
        for _ in range(nnodes):
            p = multiprocessing.Process(
                target=_run_agent,
                args=(run_id, host, port, nnodes, nnodes, _distributed_sum, (0,)),
            )
            procs.append(p)
            p.start()

        # restart odd agents
        for i in range(nnodes):
            if i % 2 != 0:
                procs[i].kill()
                p = multiprocessing.Process(
                    target=_run_agent,
                    args=(run_id, host, port, nnodes, nnodes, _distributed_sum, (0,)),
                )
                procs[i] = p
                p.start()

        for p in procs:
            p.join()
            self.assertEqual(0, p.exitcode)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_double_agent_elastic(self):
        host = self._etcd_server.get_host()
        port = self._etcd_server.get_port()
        min_size = 1
        max_size = 2
        run_id = str(uuid.uuid4().int)

        procs = []
        for _ in range(max_size):
            p = multiprocessing.Process(
                target=_run_agent,
                args=(run_id, host, port, min_size, max_size, _distributed_sum, (0,)),
            )
            procs.append(p)
            p.start()

        # kill odd agents
        for i in range(max_size):
            if i % 2 != 0:
                procs[i].kill()

        # only the surviving (even) agents are expected to finish cleanly
        for i in range(max_size):
            if i % 2 == 0:
                p = procs[i]
                p.join()
                self.assertEqual(0, p.exitcode)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_torch_rpc(self):
        """
        Simple torch rpc example with torchelastic.
        Creates two agents (to simulate two node job),
        each agent runs a single worker. worker0 calls an rpc_sync on
        worker1.
        """

        # TODO upstream this to torch.distributed.rpc so that users do not have
        # to redundantly set rank as part of name (e.g. worker0) AND also pass
        # it explicitly as an argument to rpc.init_rpc
        def init_rpc(name_prefix, backend):
            rank = int(os.environ["RANK"])
            world_size = int(os.environ["WORLD_SIZE"])
            rpc.init_rpc(
                name=f"{name_prefix}{rank}",
                backend=backend,
                rank=rank,
                world_size=world_size,
            )

        def worker_0(queue, msg):
            init_rpc("worker", BackendType.PROCESS_GROUP)
            ret = rpc.rpc_sync(to="worker1", func=echo, args=(msg,))
            queue.put(ret)
            rpc.shutdown()

        def worker_1():
            init_rpc("worker", BackendType.PROCESS_GROUP)
            rpc.shutdown()

        def run_agent(
            run_id, etcd_host, etcd_port, start_method, worker_fn, worker_args=()
        ):
            rdzv_params = RendezvousParameters(
                backend="etcd",
                endpoint=f"{etcd_host}:{etcd_port}",
                run_id=run_id,
                min_nodes=2,
                max_nodes=2,
            )
            rdzv_handler = rdzv_registry.get_rendezvous_handler(rdzv_params)
            spec = WorkerSpec(
                role="test_trainer",
                local_world_size=1,
                fn=worker_fn,
                args=worker_args,
                rdzv_handler=rdzv_handler,
                max_restarts=3,
                monitor_interval=1,
            )

            agent = LocalElasticAgent(spec, start_method)
            agent.run()

        run_id = str(uuid.uuid4().int)
        host = self._etcd_server.get_host()
        port = self._etcd_server.get_port()

        start_method = "fork"
        msg = "hello world"
        mp_queue = multiprocessing.get_context(start_method).Queue()

        agent0 = multiprocessing.Process(
            target=run_agent,
            args=(run_id, host, port, start_method, worker_0, (mp_queue, msg)),
        )
        agent1 = multiprocessing.Process(
            target=run_agent, args=(run_id, host, port, start_method, worker_1, ())
        )

        agent0.start()
        agent1.start()

        agent0.join()
        agent1.join()

        self.assertEqual(msg, mp_queue.get())

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_workers_drift_success(self):
        host = self._etcd_server.get_host()
        port = self._etcd_server.get_port()
        nnodes = 2
        run_id = str(uuid.uuid4().int)

        procs = []
        default_args = (run_id, host, port, nnodes, nnodes, _simulate_work)
        for _ in range(nnodes - 1):
            p = multiprocessing.Process(
                target=_run_agent,
                args=(*default_args, (10,), 2, "test_trainer", {}, 30),
            )
            procs.append(p)
            p.start()

        _run_agent(*default_args, (1,), 2, "test_trainer", {}, 30)

        for p in procs:
            p.join()
            self.assertEqual(0, p.exitcode)

    @patch("torchelastic.utils.store.barrier")
    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_workers_drift_fail(self, barrier_mock):
        host = self._etcd_server.get_host()
        port = self._etcd_server.get_port()
        nnodes = 2
        run_id = str(uuid.uuid4().int)

        procs = []
        default_args = (run_id, host, port, nnodes, nnodes, _simulate_work)
        for _ in range(nnodes - 1):
            p = multiprocessing.Process(
                target=_run_agent,
                args=(*default_args, (60,), 2, "test_trainer", {}, 10),
            )
            procs.append(p)
            p.start()

        # NOTE(review): the child agents are never joined here -- confirm
        # whether that is intentional
        _run_agent(*default_args, (1,), 2, "test_trainer", {}, 10)
        barrier_mock.assert_called_once()

    @patch("torchelastic.utils.store.barrier")
    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_barrier_failed(self, barrier_mock):
        barrier_mock.side_effect = RuntimeError("test error")
        spec = self._get_worker_spec(fn=_happy_function)
        agent = LocalElasticAgent(spec, start_method="fork")
        agent.run()
        barrier_mock.assert_called_once()

    def test_provide_fn_and_cmd(self):
        with self.assertRaises(AssertionError):
            self._get_worker_spec(
                fn=_bipolar_function, cmd=["test.bin"], max_restarts=2
            )

    def test_provide_none(self):
        with self.assertRaises(AssertionError):
            self._get_worker_spec(max_restarts=2)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_failed_result_with_run_id(self):
        temp_dir = tempfile.mkdtemp()
        os.environ["TORCHELASTIC_ERROR_FILE"] = f"{temp_dir}/error.log"
        try:
            self._test_failed_result_with_run_id()
        finally:
            # clean up even when the assertions above fail
            os.environ.pop("TORCHELASTIC_ERROR_FILE", None)
            shutil.rmtree(temp_dir)

    @record
    def _test_failed_result_with_run_id(self):
        max_restarts = 3
        spec = self._get_worker_spec(fn=_sad_function, max_restarts=max_restarts)
        agent = LocalElasticAgent(spec, start_method="fork")
        run_result = agent.run()
        for failure in run_result.failures.values():
            error_file = failure.error_file
            self.assertTrue(error_file.endswith(f"_{max_restarts}"))

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_transient_bug(self):
        temp_dir = tempfile.mkdtemp()
        os.environ["TORCHELASTIC_ERROR_FILE"] = f"{temp_dir}/error.log"
        try:
            self._test_transient_bug(temp_dir)
        finally:
            # clean up even when the assertions above fail
            os.environ.pop("TORCHELASTIC_ERROR_FILE", None)
            shutil.rmtree(temp_dir)

    @record
    def _test_transient_bug(self, error_dir: str):
        max_restarts = 3
        spec = self._get_worker_spec(fn=_transient_bug, max_restarts=max_restarts)
        agent = LocalElasticAgent(spec, start_method="fork")
        run_result = agent.run()
        self.assertEqual(WorkerState.SUCCEEDED, run_result.state)
        for rank in range(len(run_result.return_values)):
            # each rank is expected to leave an error file for attempt 0 only
            error_file_0 = os.path.join(error_dir, str(rank), "error.log_0")
            self.assertTrue(os.path.exists(error_file_0))
            error_file_1 = os.path.join(error_dir, str(rank), "error.log_1")
            self.assertFalse(os.path.exists(error_file_1))
class LocalTimerTest(unittest.TestCase):
    """
    Tests ``timer.expires`` against a ``LocalTimerServer`` that is started
    in ``setUp`` and stopped in ``tearDown`` of every test.
    """

    def setUp(self):
        self.mp_queue = mp.Queue()
        self.max_interval = 0.01
        self.server = timer.LocalTimerServer(self.mp_queue, self.max_interval)
        self.server.start()

    def tearDown(self):
        self.server.stop()

    def test_exception_propagation(self):
        """
        An exception raised inside an ``expires`` block must reach the caller.
        """
        # Configure a client explicitly: without one ``timer.expires`` raises
        # RuntimeError (see ``test_no_client``), which would satisfy a bare
        # ``assertRaises(Exception)`` without the body ever running.
        timer.configure(timer.LocalTimerClient(self.mp_queue))
        # match on the message so only the exception raised below is accepted
        with self.assertRaisesRegex(Exception, "foobar"):
            with timer.expires(after=1):
                raise Exception("foobar")

    def test_no_client(self):
        # no timer client configured; exception expected
        timer.configure(None)
        with self.assertRaises(RuntimeError):
            with timer.expires(after=1):
                pass

    def test_client_interaction(self):
        # no timer client configured but one passed in explicitly
        # no exception expected
        timer_client = timer.LocalTimerClient(self.mp_queue)
        # wrap (rather than replace) so the real acquire/release still run
        timer_client.acquire = mock.MagicMock(wraps=timer_client.acquire)
        timer_client.release = mock.MagicMock(wraps=timer_client.release)
        with timer.expires(after=1, scope="test", client=timer_client):
            pass

        timer_client.acquire.assert_called_once_with("test", mock.ANY)
        timer_client.release.assert_called_once_with("test")

    def test_happy_path(self):
        # the block finishes (0.1s) well before the timer expires (0.5s)
        timer.configure(timer.LocalTimerClient(self.mp_queue))
        with timer.expires(after=0.5):
            time.sleep(0.1)

    @unittest.skipIf(is_tsan(), "test is tsan incompatible")
    def test_get_timer_recursive(self):
        """
        If a function acquires a countdown timer with default scope,
        then recursive calls to the function should re-acquire the
        timer rather than creating a new one. That is only the last
        recursive call's timer will take effect.
        """
        # NOTE(review): setUp() already started this server; confirm that a
        # second start() is intended/harmless before removing it.
        self.server.start()
        timer.configure(timer.LocalTimerClient(self.mp_queue))

        # func should not time out
        def func(n):
            if n > 0:
                with timer.expires(after=0.1):
                    func(n - 1)
                    time.sleep(0.05)

        func(4)

        # func2 should time out
        def func2(n):
            if n > 0:
                with timer.expires(after=0.1):
                    func2(n - 1)
                    time.sleep(0.2)

        p = mp.Process(target=func2, args=(2,))
        p.start()
        p.join()
        # a negative exitcode -N means the child was terminated by signal N
        self.assertEqual(-signal.SIGKILL, p.exitcode)

    @staticmethod
    def _run(mp_queue, timeout, duration):
        """
        Child-process entry point: configures a client on ``mp_queue`` and
        sleeps ``duration`` seconds inside a timer expiring after ``timeout``.
        """
        client = timer.LocalTimerClient(mp_queue)
        timer.configure(client)

        with timer.expires(after=timeout):
            time.sleep(duration)

    @unittest.skipIf(is_tsan(), "test is tsan incompatible")
    def test_timer(self):
        # duration (1s) exceeds the timeout (0.1s), so the child is expected
        # to be killed rather than exit normally
        timeout = 0.1
        duration = 1
        p = mp.Process(target=self._run, args=(self.mp_queue, timeout, duration))
        p.start()
        p.join()
        self.assertEqual(-signal.SIGKILL, p.exitcode)
class LocalTimerServerTest(unittest.TestCase):
    """
    Feeds hand-built ``TimerRequest`` objects through the queue and invokes
    ``LocalTimerServer._run_watchdog`` directly to check how they are handled.
    """

    def setUp(self):
        self.max_interval = 0.01
        self.mp_queue = mp.Queue()
        self.server = timer.LocalTimerServer(self.mp_queue, self.max_interval)

    def tearDown(self):
        self.server.stop()

    @unittest.skipIf(is_tsan(), "test is tsan incompatible")
    def test_watchdog_call_count(self):
        """
        checks that the watchdog function ran wait/interval +- 1 times
        """
        # wrap the real watchdog so it still runs while calls are counted
        wrapped_watchdog = mock.MagicMock(wraps=self.server._run_watchdog)
        self.server._run_watchdog = wrapped_watchdog

        wait = 0.1
        self.server.start()
        time.sleep(wait)
        self.server.stop()

        expected_calls = int(wait / self.max_interval)
        actual_calls = wrapped_watchdog.call_count
        self.assertGreaterEqual(actual_calls, expected_calls - 1)
        self.assertLessEqual(actual_calls, expected_calls + 1)

    def test_watchdog_empty_queue(self):
        """
        checks that the watchdog can run on an empty queue
        """
        self.server._run_watchdog()

    def _expired_timer(self, pid, scope):
        # request whose expiration time lies 60 seconds in the past
        return TimerRequest(
            worker_id=pid, scope_id=scope, expiration_time=time.time() - 60
        )

    def _valid_timer(self, pid, scope):
        # request whose expiration time lies 60 seconds in the future
        return TimerRequest(
            worker_id=pid, scope_id=scope, expiration_time=time.time() + 60
        )

    def _release_timer(self, pid, scope):
        # request carrying expiration_time=-1 (used by the tests as a release)
        return TimerRequest(worker_id=pid, scope_id=scope, expiration_time=-1)

    @unittest.skipIf(is_tsan(), "test is tsan incompatible")
    @mock.patch("os.kill")
    def test_expired_timers(self, mock_os_kill):
        """
        tests that a single expired timer on a process should terminate
        the process and clean up all pending timers that were owned by the
        process
        """
        test_pid = -3
        requests = [
            self._expired_timer(pid=test_pid, scope="test1"),
            self._valid_timer(pid=test_pid, scope="test2"),
        ]
        for request in requests:
            self.mp_queue.put(request)

        self.server._run_watchdog()

        self.assertEqual(0, len(self.server._timers))
        mock_os_kill.assert_called_once_with(test_pid, signal.SIGKILL)

    @mock.patch("os.kill")
    def test_acquire_release(self, mock_os_kill):
        """
        tests that:
        1. a timer can be acquired then released (should not terminate process)
        2. a timer can be vacuously released (e.g. no-op)
        """
        test_pid = -3
        requests = [
            self._valid_timer(pid=test_pid, scope="test1"),
            self._release_timer(pid=test_pid, scope="test1"),
            self._release_timer(pid=test_pid, scope="test2"),
        ]
        for request in requests:
            self.mp_queue.put(request)

        self.server._run_watchdog()

        self.assertEqual(0, len(self.server._timers))
        mock_os_kill.assert_not_called()

    @mock.patch("os.kill")
    def test_valid_timers(self, mock_os_kill):
        """
        tests that valid timers are processed correctly and the process is
        left alone
        """
        # same enqueue order as before: (-3, test1), (-3, test2),
        # (-2, test1), (-2, test2)
        timer_keys = [
            (pid, scope) for pid in (-3, -2) for scope in ("test1", "test2")
        ]
        for pid, scope in timer_keys:
            self.mp_queue.put(self._valid_timer(pid=pid, scope=scope))

        self.server._run_watchdog()

        self.assertEqual(4, len(self.server._timers))
        for timer_key in timer_keys:
            self.assertTrue(timer_key in self.server._timers)
        mock_os_kill.assert_not_called()
class LocalTimerExample(unittest.TestCase):
    """
    Demonstrates how to use LocalTimerServer and LocalTimerClient
    to enforce expiration of code-blocks.

    Since torch multiprocessing's ``start_process`` method currently
    does not take the multiprocessing context as parameter argument
    there is no way to create the mp.Queue in the correct
    context BEFORE spawning child processes. Once the ``start_process``
    API is changed in torch, then re-enable ``test_torch_mp_example``
    unittest. As of now this will SIGSEGV.
    """

    # NOTE(review): the docstring and the inline comment below describe this
    # test as disabled, but the decorator only skips it under asan/tsan;
    # confirm which is intended.
    @unittest.skipIf(is_asan_or_tsan(), "test is a/tsan incompatible")
    def test_torch_mp_example(self):
        # in practice set the max_interval to a larger value (e.g. 60 seconds)
        mp_queue = mp.get_context("spawn").Queue()
        server = timer.LocalTimerServer(mp_queue, max_interval=0.01)
        server.start()

        # stop the server even when a spawn or an assertion fails
        try:
            world_size = 8

            # all processes should complete successfully
            # since start_process does NOT take context as parameter argument
            # yet this method WILL FAIL (hence the test is disabled)
            torch_mp.spawn(
                fn=_happy_function, args=(mp_queue,), nprocs=world_size, join=True
            )

            with self.assertRaises(Exception):
                # torch.multiprocessing.spawn kills all sub-procs
                # if one of them gets killed
                torch_mp.spawn(
                    fn=_stuck_function,
                    args=(mp_queue,),
                    nprocs=world_size,
                    join=True,
                )
        finally:
            server.stop()

    @unittest.skipIf(is_asan_or_tsan(), "test is a/tsan incompatible")
    def test_example_start_method_spawn(self):
        self._run_example_with(start_method="spawn")

    @unittest.skipIf(is_asan_or_tsan(), "test is a/tsan incompatible")
    def test_example_start_method_forkserver(self):
        self._run_example_with(start_method="forkserver")

    @unittest.skipIf(is_tsan(), "test is tsan incompatible")
    def test_example_start_method_fork(self):
        self._run_example_with(start_method="fork")

    def _run_example_with(self, start_method):
        """
        Starts 8 child processes using the ``start_method`` multiprocessing
        context, alternating ``_stuck_function`` (even index) and
        ``_happy_function`` (odd index), and asserts that the even ones are
        killed with SIGKILL while the odd ones exit with code 0.
        """
        spawn_ctx = mp.get_context(start_method)
        mp_queue = spawn_ctx.Queue()
        server = timer.LocalTimerServer(mp_queue, max_interval=0.01)
        server.start()

        # stop the server even when an assertion below fails
        try:
            world_size = 8
            processes = []
            for i in range(world_size):
                target = _stuck_function if i % 2 == 0 else _happy_function
                p = spawn_ctx.Process(target=target, args=(i, mp_queue))
                p.start()
                processes.append(p)

            # join every child before asserting so that a failed assertion
            # does not leave unjoined processes behind
            for p in processes:
                p.join()

            for i, p in enumerate(processes):
                expected_exitcode = -signal.SIGKILL if i % 2 == 0 else 0
                self.assertEqual(expected_exitcode, p.exitcode)
        finally:
            server.stop()
class MpProcessContextTest(unittest.TestCase):
    """
    Tests for ``start_processes`` and the process context it returns:
    successful runs, large return values, failures, termination and waits
    with a timeout.
    """

    @staticmethod
    def _wait_until_done(proc_context):
        # keep calling wait() until it hands back a truthy result
        results = proc_context.wait()
        while not results:
            results = proc_context.wait()
        return results

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_success(self):
        nprocs = 4
        mult = 2
        worker_params = [MpParameters(fn=run_compute, args=(mult,))] * nprocs
        ctx = start_processes(worker_params, start_method="spawn")
        results = self._wait_until_done(ctx)
        self.assertEqual(4, len(results))
        for local_rank in results:
            self.assertEqual(mult * local_rank, results[local_rank])

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_success_no_return_func(self):
        nprocs = 4
        worker_params = [MpParameters(fn=run_dummy, args=())] * nprocs
        ctx = start_processes(worker_params, start_method="spawn")
        results = self._wait_until_done(ctx)
        self.assertEqual(4, len(results))
        for value in results.values():
            self.assertEqual(None, value)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_huge_output(self):
        # python multiprocessing.queue module uses pipes and actually PipedQueues
        # This means that if a single object is greater than a pipe size
        # the writer process will block until reader process will start
        # reading the pipe.
        # This test makes a worker fn to return huge output, around ~10 MB
        nprocs = 4
        size = 200000
        worker_params = [MpParameters(fn=fill_dict, args=(size,))] * nprocs
        ctx = start_processes(worker_params, start_method="spawn")
        results = self._wait_until_done(ctx)
        self.assertEqual(4, len(results))
        for value in results.values():
            self.assertEqual(size, len(value))

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_failure(self):
        nprocs = 4
        worker_params = [MpParameters(fn=run_failure, args=())] * nprocs
        ctx = start_processes(worker_params, start_method="spawn")
        # the worker failure is expected to surface while waiting
        with self.assertRaises(ProcessRaisedException):
            self._wait_until_done(ctx)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_termination(self):
        nprocs = 5
        worker_params = [MpParameters(fn=run_infinite, args=())] * nprocs
        ctx = start_processes(worker_params, start_method="spawn")
        ctx.terminate()
        # Processes should terminate with SIGTERM
        with self.assertRaises(Exception):
            ctx.wait()

    def test_wait_busy_loop(self):
        nprocs = 2
        wait_time = 10  # seconds
        worker_params = [MpParameters(fn=run_with_wait, args=(wait_time,))] * nprocs
        ctx = start_processes(worker_params, start_method="spawn")
        # the first one-second wait must come back empty-handed
        self.assertIsNone(ctx.wait(1))
        outcome = None
        while outcome is None:
            outcome = ctx.wait(1)

    def test_wrap_fn(self):
        nprocs = 2
        start_method = "spawn"
        out_queues: Dict[int, mp.SimpleQueue] = {}
        for rank in range(nprocs):
            out_queues[rank] = mp.get_context(start_method).SimpleQueue()
        worker_params = [MpParameters(fn=run_compute, args=(1,))] * nprocs
        for rank in range(nprocs):
            _wrap(rank, worker_params, out_queues)
        for rank, queue in out_queues.items():
            self.assertFalse(queue.empty(), "out queue should not be empty")
            self.assertEqual(rank, queue.get())
class MpProcessContextTest(unittest.TestCase):
    """
    Tests for ``start_processes`` and the ``ProcessGroupResult`` returned by
    waiting on its process context, including error-file reporting for
    failed workers.
    """

    def setUp(self):
        self.test_dir = tempfile.mkdtemp()

    def tearDown(self):
        shutil.rmtree(self.test_dir)

    def _set_error_file_env(self):
        """
        Points ``TORCHELASTIC_ERROR_FILE`` at ``<test_dir>/error.log`` and
        registers a cleanup that restores the previous value, so the setting
        does not leak into other tests.
        """
        key = "TORCHELASTIC_ERROR_FILE"
        previous = os.environ.get(key)

        def _restore():
            if previous is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = previous

        self.addCleanup(_restore)
        os.environ[key] = f"{self.test_dir}/error.log"

    def _get_result(self, proc_context) -> ProcessGroupResult:
        """
        Calls ``proc_context.wait()`` repeatedly until it returns a truthy
        result and returns that result.
        """
        proc_group_result = proc_context.wait()
        while not proc_group_result:
            proc_group_result = proc_context.wait()
        return proc_group_result

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_success(self):
        nprocs = 4
        mult = 2
        params = [MpParameters(fn=run_compute, args=(mult,))] * nprocs
        proc_context = start_processes(params, start_method="spawn")
        ret_vals = self._get_result(proc_context).return_values
        self.assertEqual(nprocs, len(ret_vals))
        for local_rank, ret_val in ret_vals.items():
            self.assertEqual(mult * local_rank, ret_val)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_success_no_return_func(self):
        nprocs = 4
        params = [MpParameters(fn=run_dummy, args=())] * nprocs
        proc_context = start_processes(params, start_method="spawn")
        ret_vals = self._get_result(proc_context).return_values
        self.assertEqual(nprocs, len(ret_vals))
        for ret_val in ret_vals.values():
            self.assertIsNone(ret_val)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_huge_output(self):
        # python multiprocessing.queue module uses pipes and actually PipedQueues
        # This means that if a single object is greater than a pipe size
        # the writer process will block until reader process will start
        # reading the pipe.
        # This test makes a worker fn to return huge output, around ~10 MB
        nprocs = 4
        size = 200000
        params = [MpParameters(fn=fill_dict, args=(size,))] * nprocs
        proc_context = start_processes(params, start_method="spawn")
        ret_vals = self._get_result(proc_context).return_values
        self.assertEqual(nprocs, len(ret_vals))
        for ret_val in ret_vals.values():
            self.assertEqual(size, len(ret_val))

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_failure(self):
        self._set_error_file_env()
        _process_error_handler.configure()
        # always undo configure(), even when an assertion below fails
        try:
            nprocs = 4
            params = [MpParameters(fn=run_failure, args=())] * nprocs
            proc_context = start_processes(params, start_method="spawn")
            failed_result = self._get_result(proc_context).failure
            self.assertTrue(os.path.exists(failed_result.error_file))
            with open(failed_result.error_file, "r") as f:
                data = f.read().replace("\n", "")
            self.assertIn("RuntimeError: Test error", data)
        finally:
            _process_error_handler.cleanup()

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_failure_signal(self):
        self._set_error_file_env()
        _process_error_handler.configure()
        # always undo configure(), even when an assertion below fails
        try:
            nprocs = 5
            params = [MpParameters(fn=run_failure_signal, args=())] * nprocs
            proc_context = start_processes(params, start_method="spawn")
            # Processes should terminate with SIGSEGV
            # Use the same retry loop as the other tests: a single wait() may
            # return a falsy value, and ``.failure`` would then blow up.
            failure = self._get_result(proc_context).failure
            self.assertTrue(os.path.exists(failure.error_file))
            self.assertEqual("SIGSEGV", failure.get_signal_name())
            with open(failure.error_file, "r") as f:
                data = f.read().replace("\n", "")
            self.assertIn("string_at", data)
        finally:
            _process_error_handler.cleanup()

    def test_wait_busy_loop(self):
        nprocs = 2
        wait_time = 10  # seconds
        params = [MpParameters(fn=run_with_wait, args=(wait_time,))] * nprocs
        proc_context = start_processes(params, start_method="spawn")
        # workers run for ``wait_time`` seconds, so a one-second wait must
        # return None
        self.assertIsNone(proc_context.wait(1))
        while proc_context.wait(1) is None:
            pass

    def test_wrap_fn(self):
        nprocs = 2
        start_method = "spawn"
        out_queues: Dict[int, mp.SimpleQueue] = {
            i: mp.get_context(start_method).SimpleQueue() for i in range(nprocs)
        }
        error_files = ["error.log" for _ in range(nprocs)]
        params = [MpParameters(fn=run_compute, args=(1,))] * nprocs
        # call the wrapper in-process, once per rank
        for idx in range(nprocs):
            _wrap(idx, error_files, params, out_queues)
        for idx, out_q in out_queues.items():
            self.assertFalse(out_q.empty(), "out queue should not be empty")
            self.assertEqual(idx, out_q.get())