class LocalElasticAgentZeusTest(unittest.TestCase):
    """
    Sets up and tears down a local zeus server for testing.
    Tests should connect to localhost:self._mock_zeus_port
    """

    def setUp(self):
        self._mock_zeus = ZeusServerFixture()
        self._mock_zeus.wait_for_ready(max_secs=10)
        self._mock_zeus_port = self._mock_zeus.get_client_port()

    def tearDown(self):
        self._mock_zeus.kill()
        self._mock_zeus_port = None

    def _get_worker_spec(
        self,
        fn,
        args=(),
        max_restarts=1,
        num_agents=1,
        monitor_interval=0.1,
        local_world_size=8,
    ):
        run_id = str(uuid.uuid4().int)
        rdzv_handler = dist.rendezvous(
            f"zeus-adapter://localhost:{self._mock_zeus_port}/{run_id}"
            f"?min_size={num_agents}"
            f"&max_size={num_agents}"
        )
        spec = WorkerSpec(
            role="test_trainer",
            local_world_size=local_world_size,
            fn=fn,
            args=args,
            rdzv_handler=rdzv_handler,
            max_restarts=max_restarts,
            monitor_interval=monitor_interval,
        )
        return spec

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_sad_function(self):
        spec = self._get_worker_spec(fn=_sad_function, max_restarts=2)
        try:
            agent = LocalElasticAgent(spec, start_method="spawn")
            with self.assertRaises(WorkerGroupFailureException) as cm:
                agent.run()
        finally:
            spec.rdzv_handler.shutdown()

        excs = cm.exception.get_worker_exceptions()
        for i in range(spec.local_world_size):
            self.assertTrue(isinstance(excs[i], Exception))
        self.assertEqual(WorkerState.FAILED, agent.get_worker_group().state)
        self.assertEqual(0, agent._remaining_restarts)
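# The agent tests in this file reference module-level worker functions
# (_sad_function, _happy_function, _distributed_sum, ...) defined elsewhere
# in the module. For orientation, a minimal sketch of what the deliberately
# failing worker plausibly looks like -- the name matches the real helper,
# but this body is illustrative only:
#
#     def _sad_function():
#         # always raises, so the agent burns through its max_restarts budget
#         # and surfaces a WorkerGroupFailureException to the caller
#         raise RuntimeError("sad because i throw")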
class LaunchTest(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # start a standalone, single process etcd server to use for all tests
        cls._etcd_server = EtcdServer()
        cls._etcd_server.start()
        cls._etcd_endpoint = cls._etcd_server.get_endpoint()

    @classmethod
    def tearDownClass(cls):
        # stop the standalone etcd server
        cls._etcd_server.stop()

    def setUp(self):
        self.test_dir = tempfile.mkdtemp()

    def tearDown(self):
        shutil.rmtree(self.test_dir)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_launch_user_script_python(self):
        run_id = str(uuid.uuid4().int)
        nnodes = 1
        nproc_per_node = 4
        world_size = nnodes * nproc_per_node
        args = [
            f"--nnodes={nnodes}",
            f"--nproc_per_node={nproc_per_node}",
            "--rdzv_backend=etcd",
            f"--rdzv_endpoint={self._etcd_endpoint}",
            f"--rdzv_id={run_id}",
            "--monitor_interval=1",
            "--start_method=fork",
            path("bin/test_script.py"),
            f"--touch_file_dir={self.test_dir}",
        ]
        launch.main(args)

        # make sure all the workers ran;
        # each worker touches a file with its global rank as the name
        self.assertSetEqual(
            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
        )

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_launch_user_script_bash(self):
        run_id = str(uuid.uuid4().int)
        nnodes = 1
        nproc_per_node = 4
        world_size = nnodes * nproc_per_node
        args = [
            f"--nnodes={nnodes}",
            f"--nproc_per_node={nproc_per_node}",
            "--rdzv_backend=etcd",
            f"--rdzv_endpoint={self._etcd_endpoint}",
            f"--rdzv_id={run_id}",
            "--monitor_interval=1",
            "--start_method=fork",
            "--no_python",
        ]
        script_args = [path("bin/test_script.sh"), f"{self.test_dir}"]

        with self.assertRaises(ValueError):
            # --no_python cannot be used with --module
            launch.main(args + ["--module"] + script_args)

        launch.main(args + script_args)

        # make sure all the workers ran;
        # each worker touches a file with its global rank as the name
        self.assertSetEqual(
            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
        )

    # @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_wrapper_fn_kill_script_process(self):
        """
        tests that the wrapper_fn properly terminates the script process
        (the script process is the sub-sub-process of the agent)
        """
        nprocs = 2
        sleep = 300

        # wraps wrapper_fn to be torch.multiprocessing compatible,
        # which requires rank to be passed as the first argument
        def wrap_wrap(rank, *args):
            launch.wrapper_fn(*args)

        context = start_processes(
            fn=wrap_wrap,
            args=(None, (path("bin/sleep_script.py"), "--sleep", f"{sleep}")),
            nprocs=nprocs,
            join=False,
            start_method="fork",
        )
        # quick check to see that the wrapper_fn started running;
        # without this join() call we don't see an exception on typos
        # and other silly mistakes (they fail silently)
        context.join(timeout=-1)

        script_pids = []
        for wrapper_fn_pid in context.pids():
            script_pid = get_child_pids(wrapper_fn_pid)
            # there should only be one child of wrapper_fn
            self.assertEqual(1, len(script_pid))
            script_pids.append(script_pid[0])

        for wrapper_fn_proc in context.processes:
            wrapper_fn_proc.terminate()
            wrapper_fn_proc.join()

        for script_pid in script_pids:
            self.assertFalse(pid_exists(script_pid))

    def _test_nproc_launch_configuration(self, nproc_type, expected_number):
        run_id = str(uuid.uuid4().int)
        nnodes = 1
        args = [
            f"--nnodes={nnodes}",
            f"--nproc_per_node={nproc_type}",
            "--rdzv_backend=etcd",
            f"--rdzv_endpoint={self._etcd_endpoint}",
            f"--rdzv_id={run_id}",
            "--monitor_interval=1",
            "--start_method=fork",
            "--no_python",
        ]
        script_args = [path("bin/test_script.sh"), f"{self.test_dir}"]

        launch.main(args + script_args)

        world_size = nnodes * expected_number
        # make sure all the workers ran;
        # each worker touches a file with its global rank as the name
        self.assertSetEqual(
            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
        )

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_nproc_launch_auto_configurations(self):
        self._test_nproc_launch_configuration("auto", os.cpu_count())

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_nproc_launch_number_configurations(self):
        self._test_nproc_launch_configuration("4", 4)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_nproc_launch_unknown_configurations(self):
        with self.assertRaises(ValueError):
            self._test_nproc_launch_configuration("unknown", 4)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    @patch("torch.cuda.is_available", return_value=True)
    @patch("torch.cuda.device_count", return_value=3)
    def test_nproc_gpu_launch_configurations(self, _mock1, _mock2):
        self._test_nproc_launch_configuration("auto", 3)
        self._test_nproc_launch_configuration("gpu", 3)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_launch_elastic(self):
        run_id = str(uuid.uuid4().int)
        min_nodes = 1
        max_nodes = 2
        nproc_per_node = 4
        # we are only launching 1 node (even though max = 2)
        world_size = nproc_per_node
        args = [
            f"--nnodes={min_nodes}:{max_nodes}",
            f"--nproc_per_node={nproc_per_node}",
            "--rdzv_backend=etcd",
            f"--rdzv_endpoint={self._etcd_endpoint}",
            f"--rdzv_id={run_id}",
            "--monitor_interval=1",
            "--start_method=fork",
            path("bin/test_script.py"),
            f"--touch_file_dir={self.test_dir}",
        ]
        launch.main(args)

        # make sure all the workers ran;
        # each worker touches a file with its global rank as the name
        self.assertSetEqual(
            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
        )

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_launch_standalone(self):
        nnodes = 1
        nproc_per_node = 4
        world_size = nnodes * nproc_per_node
        args = [
            f"--nnodes={nnodes}",
            f"--nproc_per_node={nproc_per_node}",
            "--standalone",
            "--monitor_interval=1",
            "--start_method=fork",
            path("bin/test_script.py"),
            f"--touch_file_dir={self.test_dir}",
        ]
        launch.main(args)

        # make sure all the workers ran;
        # each worker touches a file with its global rank as the name
        self.assertSetEqual(
            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
        )

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_launch_elastic_multiple_agents(self):
        run_id = str(uuid.uuid4().int)
        min_nodes = 1
        max_nodes = 2
        nproc_per_node = 4
        nnodes = 2
        world_size = nnodes * nproc_per_node
        args = [
            f"--nnodes={min_nodes}:{max_nodes}",
            f"--nproc_per_node={nproc_per_node}",
            "--rdzv_backend=etcd",
            f"--rdzv_endpoint={self._etcd_endpoint}",
            f"--rdzv_id={run_id}",
            "--monitor_interval=1",
            "--start_method=fork",
            path("bin/test_script.py"),
            f"--touch_file_dir={self.test_dir}",
        ]
        procs = []
        for _ in range(nnodes - 1):
            p = mp.Process(target=launch.main, args=[args])
            procs.append(p)
            p.start()
        launch.main(args)

        for i in range(nnodes - 1):
            p = procs[i]
            p.join()
            self.assertEqual(0, p.exitcode)

        # make sure all the workers ran;
        # each worker touches a file with its global rank as the name
        self.assertSetEqual(
            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
        )

    def test_min_max_nodes_parse(self):
        min_nodes, max_nodes = launch.parse_min_max_nnodes("1")
        self.assertEqual(1, min_nodes)
        self.assertEqual(1, max_nodes)

        min_nodes, max_nodes = launch.parse_min_max_nnodes("2:20")
        self.assertEqual(2, min_nodes)
        self.assertEqual(20, max_nodes)

        with self.assertRaises(RuntimeError):
            launch.parse_min_max_nnodes("2:20:30")
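# The launch tests above check that bin/test_script.py touches one file per
# worker, named after the worker's global rank. The script itself is not
# shown in this file; a minimal sketch of what it might contain, assuming
# the launcher passes --local_rank and exports the RANK env var (the real
# bin/test_script.py may differ):
#
#     import argparse
#     import os
#
#     if __name__ == "__main__":
#         parser = argparse.ArgumentParser()
#         parser.add_argument("--local_rank", type=int, default=0)
#         parser.add_argument("--touch_file_dir", type=str, required=True)
#         args = parser.parse_args()
#         # touch an empty file named after this worker's global rank
#         rank = os.environ["RANK"]
#         open(os.path.join(args.touch_file_dir, rank), "a").close()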
class LocalElasticAgentTest(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # start a standalone, single process etcd server to use for all tests
        cls._etcd_server = EtcdServerFixture()
        cls._etcd_server.start()

    @classmethod
    def tearDownClass(cls):
        # stop the standalone etcd server
        cls._etcd_server.stop()

    def _get_worker_spec(
        self, fn, args=(), max_restarts=1, num_agents=1, monitor_interval=0.1
    ):
        host = self._etcd_server.get_host()
        port = self._etcd_server.get_port()
        run_id = str(uuid.uuid4().int)
        rdzv_handler = dist.rendezvous(
            f"etcd://{host}:{port}/{run_id}"
            f"?min_workers={num_agents}"
            f"&max_workers={num_agents}"
        )
        spec = WorkerSpec(
            role="test_trainer",
            local_world_size=8,
            fn=fn,
            args=args,
            rdzv_handler=rdzv_handler,
            max_restarts=max_restarts,
            monitor_interval=monitor_interval,
        )
        return spec

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_happy_function(self):
        spec = self._get_worker_spec(fn=_happy_function)
        agent = LocalElasticAgent(spec, start_method="fork")
        agent.run()

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_distributed_sum(self):
        spec = self._get_worker_spec(fn=_distributed_sum, args=(0,))
        agent = LocalElasticAgent(spec, start_method="fork")
        agent.run()

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_sad_function(self):
        spec = self._get_worker_spec(fn=_sad_function, max_restarts=2)
        agent = LocalElasticAgent(spec, start_method="fork")
        with self.assertRaises(WorkerGroupFailureException) as cm:
            agent.run()

        excs = cm.exception.get_worker_exceptions()
        for i in range(spec.local_world_size):
            self.assertTrue(isinstance(excs[i], Exception))
        self.assertEqual(WorkerState.FAILED, agent.get_worker_group().state)
        self.assertEqual(0, agent._remaining_restarts)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_bipolar_function(self):
        spec = self._get_worker_spec(fn=_bipolar_function, max_restarts=2)
        agent = LocalElasticAgent(spec, start_method="fork")
        with self.assertRaises(Exception):
            agent.run()
        self.assertEqual(WorkerState.FAILED, agent.get_worker_group().state)
        self.assertEqual(0, agent._remaining_restarts)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_check_env_function(self):
        spec = self._get_worker_spec(fn=_check_env_function, max_restarts=2)
        agent = LocalElasticAgent(spec, start_method="fork")
        agent.run()

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_get_worker_return_values(self):
        spec = self._get_worker_spec(fn=_return_rank_times, args=(2,))
        agent = LocalElasticAgent(spec, start_method="fork")
        ret_vals = agent.run()
        self.assertEqual(spec.local_world_size, len(ret_vals))
        for i in range(spec.local_world_size):
            self.assertEqual(i * 2, ret_vals[i])

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_double_agent_happy(self):
        host = self._etcd_server.get_host()
        port = self._etcd_server.get_port()
        nnodes = 2
        run_id = str(uuid.uuid4().int)

        procs = []
        for _ in range(nnodes - 1):
            p = multiprocessing.Process(
                target=_run_agent, args=(run_id, host, port, nnodes, nnodes)
            )
            procs.append(p)
            p.start()

        # run one on the main process for debugging
        _run_agent(run_id, host, port, nnodes, nnodes)

        for i in range(nnodes - 1):
            p = procs[i]
            p.join()
            self.assertEqual(0, p.exitcode)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_double_agent_fault_tolerance(self):
        host = self._etcd_server.get_host()
        port = self._etcd_server.get_port()
        nnodes = 2
        run_id = str(uuid.uuid4().int)

        procs = []
        for _ in range(nnodes):
            p = multiprocessing.Process(
                target=_run_agent, args=(run_id, host, port, nnodes, nnodes)
            )
            procs.append(p)
            p.start()

        # restart odd agents
        for i in range(nnodes):
            if i % 2 != 0:
                procs[i].kill()
                p = multiprocessing.Process(
                    target=_run_agent, args=(run_id, host, port, nnodes, nnodes)
                )
                procs[i] = p
                p.start()

        for i in range(nnodes):
            p = procs[i]
            p.join()
            self.assertEqual(0, p.exitcode)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_double_agent_elastic(self):
        host = self._etcd_server.get_host()
        port = self._etcd_server.get_port()
        min_size = 1
        max_size = 2
        run_id = str(uuid.uuid4().int)

        procs = []
        for _ in range(max_size):
            p = multiprocessing.Process(
                target=_run_agent, args=(run_id, host, port, min_size, max_size)
            )
            procs.append(p)
            p.start()

        # kill odd agents
        for i in range(max_size):
            if i % 2 != 0:
                procs[i].kill()

        for i in range(max_size):
            if i % 2 == 0:
                p = procs[i]
                p.join()
                self.assertEqual(0, p.exitcode)
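# _run_agent is a module-level helper defined elsewhere in this file. Its
# shape can be inferred from the call sites above (later test classes pass
# additional arguments such as fn, args, local_world_size, role, and a
# return dict). A hypothetical sketch under that assumption, for
# orientation only:
#
#     def _run_agent(run_id, etcd_host, etcd_port, min_size, max_size,
#                    fn=_distributed_sum, args=(0,)):
#         rdzv_handler = dist.rendezvous(
#             f"etcd://{etcd_host}:{etcd_port}/{run_id}"
#             f"?min_workers={min_size}"
#             f"&max_workers={max_size}"
#         )
#         spec = WorkerSpec(
#             role="test_trainer",
#             local_world_size=8,
#             fn=fn,
#             args=args,
#             rdzv_handler=rdzv_handler,
#             max_restarts=3,
#             monitor_interval=1,
#         )
#         # exits 0 on success, which is what the parent asserts on
#         LocalElasticAgent(spec, start_method="fork").run()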
class LocalTimerExample(unittest.TestCase):
    """
    Demonstrates how to use LocalTimerServer and LocalTimerClient to
    enforce expiration of code blocks.

    Since torch multiprocessing's ``start_process`` method currently does
    not take the multiprocessing context as a parameter, there is no way to
    create the mp.Queue in the correct context BEFORE spawning child
    processes. Once the ``start_process`` API is changed in torch,
    re-enable the ``test_torch_mp_example`` unittest. As of now it will
    SIGSEGV.
    """

    @unittest.skip("start_process does not yet accept an mp context; see class docstring")
    @unittest.skipIf(is_asan_or_tsan(), "test is a/tsan incompatible")
    def test_torch_mp_example(self):
        # in practice set the max_interval to a larger value (e.g. 60 seconds)
        mp_queue = mp.get_context("spawn").Queue()
        server = timer.LocalTimerServer(mp_queue, max_interval=0.01)
        server.start()

        world_size = 8

        # all processes should complete successfully;
        # since start_process does NOT take context as a parameter yet,
        # this method WILL FAIL (hence the test is disabled)
        torch_mp.spawn(
            fn=_happy_function, args=(mp_queue,), nprocs=world_size, join=True
        )

        with self.assertRaises(Exception):
            # torch.multiprocessing.spawn kills all sub-procs
            # if one of them gets killed
            torch_mp.spawn(
                fn=_stuck_function, args=(mp_queue,), nprocs=world_size, join=True
            )

        server.stop()

    @unittest.skipIf(is_asan_or_tsan(), "test is a/tsan incompatible")
    def test_example_start_method_spawn(self):
        self._run_example_with(start_method="spawn")

    @unittest.skipIf(is_asan_or_tsan(), "test is a/tsan incompatible")
    def test_example_start_method_forkserver(self):
        self._run_example_with(start_method="forkserver")

    @unittest.skipIf(is_tsan(), "test is tsan incompatible")
    def test_example_start_method_fork(self):
        self._run_example_with(start_method="fork")

    def _run_example_with(self, start_method):
        spawn_ctx = mp.get_context(start_method)
        mp_queue = spawn_ctx.Queue()
        server = timer.LocalTimerServer(mp_queue, max_interval=0.01)
        server.start()

        world_size = 8
        processes = []
        for i in range(0, world_size):
            if i % 2 == 0:
                p = spawn_ctx.Process(target=_stuck_function, args=(i, mp_queue))
            else:
                p = spawn_ctx.Process(target=_happy_function, args=(i, mp_queue))
            p.start()
            processes.append(p)

        for i in range(0, world_size):
            p = processes[i]
            p.join()
            if i % 2 == 0:
                self.assertEqual(-signal.SIGKILL, p.exitcode)
            else:
                self.assertEqual(0, p.exitcode)

        server.stop()
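# _happy_function and _stuck_function used above are module-level helpers
# defined elsewhere in this file; both take (rank, mp_queue). A hedged
# sketch of what they plausibly do, consistent with the exit codes the
# tests assert (these are not the real bodies):
#
#     def _happy_function(rank, mp_queue):
#         timer.configure(timer.LocalTimerClient(mp_queue))
#         with timer.expires(after=1):
#             pass  # finishes well within the deadline, so exitcode is 0
#
#     def _stuck_function(rank, mp_queue):
#         timer.configure(timer.LocalTimerClient(mp_queue))
#         with timer.expires(after=0.1):
#             time.sleep(5)  # blows the deadline; the watchdog SIGKILLs us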
class LocalElasticAgentTest(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # start a standalone, single process etcd server to use for all tests
        cls._etcd_server = EtcdServer()
        cls._etcd_server.start()

    @classmethod
    def tearDownClass(cls):
        # stop the standalone etcd server
        cls._etcd_server.stop()

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_happy_function(self):
        spec = self._get_worker_spec(fn=_happy_function)
        agent = LocalElasticAgent(spec, start_method="fork")
        agent.run()

    def _get_worker_spec(
        self,
        fn,
        args=(),
        max_restarts=1,
        num_agents=1,
        monitor_interval=0.1,
        local_world_size=8,
    ):
        run_id = str(uuid.uuid4().int)
        rdzv_handler = dist.rendezvous(
            f"etcd://{self._etcd_server.get_endpoint()}/{run_id}"
            f"?min_workers={num_agents}"
            f"&max_workers={num_agents}"
        )
        spec = WorkerSpec(
            role="test_trainer",
            local_world_size=local_world_size,
            fn=fn,
            args=args,
            rdzv_handler=rdzv_handler,
            max_restarts=max_restarts,
            monitor_interval=monitor_interval,
        )
        return spec

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_distributed_sum(self):
        spec = self._get_worker_spec(fn=_distributed_sum, args=(0,))
        agent = LocalElasticAgent(spec, start_method="fork")
        agent.run()

    class RoleConfig:
        __slots__ = ["role", "workers", "num_agents", "workers_num", "role_size"]

        def __init__(
            self, role: str, workers=None, num_agents: int = 0, workers_num: int = 0
        ):
            self.role = role
            self.workers = workers
            if workers_num != 0 and num_agents != 0:
                self.workers = [workers_num] * num_agents
            self.role_size = sum(self.workers)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_correct_rank_assignment_heterogeneous(self):
        roles_config = [
            self.RoleConfig("trainer", workers=[1, 2, 3, 4]),
            self.RoleConfig("ps", workers=[5, 2]),
            # split configuration to run the last one on the main process
            self.RoleConfig("master", workers=[8]),
        ]
        self.run_configuration(roles_config, 25)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_correct_rank_assignment_homogeneous(self):
        num_workers = 4
        roles_config = [
            self.RoleConfig("trainer", num_agents=4, workers_num=num_workers),
            self.RoleConfig("ps", num_agents=2, workers_num=num_workers),
            # split configuration to run the last one on the main process
            self.RoleConfig("master", num_agents=1, workers_num=num_workers),
        ]
        self.run_configuration(roles_config, 28)

    def run_configuration(self, roles_config, expected_world_size):
        host = self._etcd_server.get_host()
        port = self._etcd_server.get_port()
        nnodes = sum(len(cfg.workers) for cfg in roles_config)
        run_id = str(uuid.uuid4().int)

        procs = []
        manager = multiprocessing.Manager()
        return_dict = manager.dict()
        default_args = (run_id, host, port, nnodes, nnodes, _check_rank_assignment, ())
        for ind in range(len(roles_config) - 1):
            config = roles_config[ind]
            for num_workers in config.workers:
                p = multiprocessing.Process(
                    target=_run_agent,
                    args=(*default_args, num_workers, config.role, return_dict),
                )
                procs.append(p)
                p.start()

        # run one on the main process for debugging
        config = roles_config[len(roles_config) - 1]
        _run_agent(*default_args, config.workers[0], config.role, return_dict)

        for i in range(nnodes - 1):
            p = procs[i]
            p.join()
            self.assertEqual(0, p.exitcode)

        role_info_dict = {role_info.role: role_info for role_info in roles_config}
        self.verify_rank_consistency(return_dict, role_info_dict, expected_world_size)

    def verify_rank_consistency(self, return_dict, role_info_dict, expected_world_size):
        role_ranks = {}
        global_ranks = []
        grouped_ranks = {}
        for role, res in return_dict.values():
            for (
                group_rank,
                rank,
                world_size,
                role_rank,
                role_world_size,
            ) in res.values():
                role_info_config = role_info_dict[role]
                self.assertEqual(expected_world_size, world_size)
                self.assertEqual(role_info_config.role_size, role_world_size)
                if group_rank not in grouped_ranks:
                    grouped_ranks[group_rank] = []
                grouped_ranks[group_rank].append((rank, role_rank))
                global_ranks.append(rank)
                if role not in role_ranks:
                    role_ranks[role] = []
                role_ranks[role].append(role_rank)

        global_ranks = sorted(global_ranks)
        self.assertEqual(list(range(0, expected_world_size)), global_ranks)
        for role, role_config_info in role_info_dict.items():
            self.assertEqual(
                list(range(0, role_config_info.role_size)), sorted(role_ranks[role])
            )

        # Make sure that each agent assigns consecutive ranks to its workers.
        # The first element of each pair is the global rank, the second is
        # the role rank.
        for ranks_lst in grouped_ranks.values():
            self.verify_ranks_sequential(ranks_lst, 0)
            self.verify_ranks_sequential(ranks_lst, 1)

    def verify_ranks_sequential(self, ranks_pairs, rank_idx):
        ranks = sorted(rank_pair[rank_idx] for rank_pair in ranks_pairs)
        start_rank, end_rank = ranks[0], ranks[-1]
        self.assertEqual(list(range(start_rank, end_rank + 1)), ranks)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_distributed_sum_heterogenous(self):
        host = self._etcd_server.get_host()
        port = self._etcd_server.get_port()
        nnodes = 4
        run_id = str(uuid.uuid4().int)

        procs = []
        default_args = (run_id, host, port, nnodes, nnodes, _distributed_sum, (0,))
        for ind in range(nnodes - 1):
            p = multiprocessing.Process(
                target=_run_agent, args=(*default_args, ind + 1)
            )
            procs.append(p)
            p.start()

        # run one on the main process for debugging
        _run_agent(*default_args, 8)

        for i in range(nnodes - 1):
            p = procs[i]
            p.join()
            self.assertEqual(0, p.exitcode)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_sad_function(self):
        spec = self._get_worker_spec(fn=_sad_function, max_restarts=2)
        agent = LocalElasticAgent(spec, start_method="fork")
        with self.assertRaises(WorkerGroupFailureException) as cm:
            agent.run()

        excs = cm.exception.get_worker_exceptions()
        for i in range(spec.local_world_size):
            self.assertTrue(isinstance(excs[i], Exception))
        self.assertEqual(WorkerState.FAILED, agent.get_worker_group().state)
        self.assertEqual(0, agent._remaining_restarts)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_bipolar_function(self):
        spec = self._get_worker_spec(fn=_bipolar_function, max_restarts=2)
        agent = LocalElasticAgent(spec, start_method="fork")
        with self.assertRaises(Exception):
            agent.run()
        self.assertEqual(WorkerState.FAILED, agent.get_worker_group().state)
        self.assertEqual(0, agent._remaining_restarts)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_check_env_function(self):
        spec = self._get_worker_spec(fn=_check_env_function, max_restarts=2)
        agent = LocalElasticAgent(spec, start_method="fork")
        agent.run()

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_check_run_id(self):
        def return_run_id():
            return os.environ["TORCHELASTIC_RUN_ID"]

        spec = self._get_worker_spec(fn=return_run_id, max_restarts=0)
        agent = LocalElasticAgent(spec, start_method="fork")
        ret = agent.run()

        for i in range(spec.local_world_size):
            self.assertEqual(spec.rdzv_handler.get_run_id(), ret[i])

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_get_worker_return_values(self):
        spec = self._get_worker_spec(fn=_return_rank_times, args=(2,))
        agent = LocalElasticAgent(spec, start_method="fork")
        ret_vals = agent.run()
        self.assertEqual(spec.local_world_size, len(ret_vals))
        for i in range(spec.local_world_size):
            self.assertEqual(i * 2, ret_vals[i])

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_double_agent_happy(self):
        host = self._etcd_server.get_host()
        port = self._etcd_server.get_port()
        nnodes = 2
        run_id = str(uuid.uuid4().int)

        procs = []
        for _ in range(nnodes - 1):
            p = multiprocessing.Process(
                target=_run_agent,
                args=(run_id, host, port, nnodes, nnodes, _distributed_sum, (0,)),
            )
            procs.append(p)
            p.start()

        # run one on the main process for debugging
        _run_agent(run_id, host, port, nnodes, nnodes, _distributed_sum, (0,))

        for i in range(nnodes - 1):
            p = procs[i]
            p.join()
            self.assertEqual(0, p.exitcode)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_double_agent_fault_tolerance(self):
        host = self._etcd_server.get_host()
        port = self._etcd_server.get_port()
        nnodes = 2
        run_id = str(uuid.uuid4().int)

        procs = []
        for _ in range(nnodes):
            p = multiprocessing.Process(
                target=_run_agent,
                args=(run_id, host, port, nnodes, nnodes, _distributed_sum, (0,)),
            )
            procs.append(p)
            p.start()

        # restart odd agents
        for i in range(nnodes):
            if i % 2 != 0:
                procs[i].kill()
                p = multiprocessing.Process(
                    target=_run_agent,
                    args=(run_id, host, port, nnodes, nnodes, _distributed_sum, (0,)),
                )
                procs[i] = p
                p.start()

        for i in range(nnodes):
            p = procs[i]
            p.join()
            self.assertEqual(0, p.exitcode)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_double_agent_elastic(self):
        host = self._etcd_server.get_host()
        port = self._etcd_server.get_port()
        min_size = 1
        max_size = 2
        run_id = str(uuid.uuid4().int)

        procs = []
        for _ in range(max_size):
            p = multiprocessing.Process(
                target=_run_agent,
                args=(run_id, host, port, min_size, max_size, _distributed_sum, (0,)),
            )
            procs.append(p)
            p.start()

        # kill odd agents
        for i in range(max_size):
            if i % 2 != 0:
                procs[i].kill()

        for i in range(max_size):
            if i % 2 == 0:
                p = procs[i]
                p.join()
                self.assertEqual(0, p.exitcode)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_torch_rpc(self):
        """
        Simple torch rpc example with torchelastic.
        Creates two agents (to simulate a two node job),
        each agent runs a single worker. worker0 calls an rpc_sync on worker1.
        """

        # TODO upstream this to torch.distributed.rpc so that users do not
        # have to redundantly set rank as part of the name (e.g. worker0)
        # AND also pass it explicitly as an argument to rpc.init_rpc
        def init_rpc(name_prefix, backend):
            rank = int(os.environ["RANK"])
            world_size = int(os.environ["WORLD_SIZE"])
            rpc.init_rpc(
                name=f"{name_prefix}{rank}",
                backend=backend,
                rank=rank,
                world_size=world_size,
            )

        def worker_0(queue, msg):
            init_rpc("worker", BackendType.PROCESS_GROUP)
            ret = rpc.rpc_sync(to="worker1", func=echo, args=(msg,))
            queue.put(ret)
            rpc.shutdown()

        def worker_1():
            init_rpc("worker", BackendType.PROCESS_GROUP)
            rpc.shutdown()

        def run_agent(
            run_id, etcd_host, etcd_port, start_method, worker_fn, worker_args=()
        ):
            rdzv_handler = dist.rendezvous(
                f"etcd://{etcd_host}:{etcd_port}/{run_id}"
                f"?min_workers=2"
                f"&max_workers=2"
            )
            spec = WorkerSpec(
                role="test_trainer",
                local_world_size=1,
                fn=worker_fn,
                args=worker_args,
                rdzv_handler=rdzv_handler,
                max_restarts=3,
                monitor_interval=1,
            )
            agent = LocalElasticAgent(spec, start_method)
            agent.run()

        run_id = str(uuid.uuid4().int)
        host = self._etcd_server.get_host()
        port = self._etcd_server.get_port()
        start_method = "fork"
        msg = "hello world"
        mp_queue = multiprocessing.get_context(start_method).Queue()

        agent0 = multiprocessing.Process(
            target=run_agent,
            args=(run_id, host, port, start_method, worker_0, (mp_queue, msg)),
        )
        agent1 = multiprocessing.Process(
            target=run_agent,
            args=(run_id, host, port, start_method, worker_1, ()),
        )

        agent0.start()
        agent1.start()

        agent0.join()
        agent1.join()

        self.assertEqual(msg, mp_queue.get())

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_workers_drift_success(self):
        host = self._etcd_server.get_host()
        port = self._etcd_server.get_port()
        nnodes = 2
        run_id = str(uuid.uuid4().int)

        procs = []
        default_args = (run_id, host, port, nnodes, nnodes, _simulate_work)
        for _ in range(nnodes - 1):
            p = multiprocessing.Process(
                target=_run_agent,
                args=(*default_args, (10,), 2, "test_trainer", {}, 30),
            )
            procs.append(p)
            p.start()

        _run_agent(*default_args, (1,), 2, "test_trainer", {}, 30)

        for i in range(nnodes - 1):
            p = procs[i]
            p.join()
            self.assertEqual(0, p.exitcode)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_workers_drift_fail(self):
        host = self._etcd_server.get_host()
        port = self._etcd_server.get_port()
        nnodes = 2
        run_id = str(uuid.uuid4().int)

        procs = []
        default_args = (run_id, host, port, nnodes, nnodes, _simulate_work)
        for _ in range(nnodes - 1):
            p = multiprocessing.Process(
                target=_run_agent,
                args=(*default_args, (60,), 2, "test_trainer", {}, 10),
            )
            procs.append(p)
            p.start()

        # TODO(aivanou): standardize error between different rendezvous stores
        with self.assertRaises(LookupError):
            _run_agent(*default_args, (1,), 2, "test_trainer", {}, 10)
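# _check_rank_assignment is the worker fn used by run_configuration above.
# Judging from verify_rank_consistency, each worker returns a
# (group_rank, rank, world_size, role_rank, role_world_size) tuple read
# from its environment. A hypothetical sketch -- the env var names are
# assumptions based on the variables torchelastic exports to workers:
#
#     def _check_rank_assignment():
#         return (
#             int(os.environ["GROUP_RANK"]),       # rank of this agent
#             int(os.environ["RANK"]),             # global worker rank
#             int(os.environ["WORLD_SIZE"]),       # total workers, all roles
#             int(os.environ["ROLE_RANK"]),        # rank within the role
#             int(os.environ["ROLE_WORLD_SIZE"]),  # total workers in the role
#         )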
class LaunchTest(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # start a standalone, single process etcd server to use for all tests
        cls._etcd_server = EtcdServer()
        cls._etcd_server.start()
        cls._etcd_endpoint = cls._etcd_server.get_endpoint()

    @classmethod
    def tearDownClass(cls):
        # stop the standalone etcd server
        cls._etcd_server.stop()

    def setUp(self):
        self.test_dir = tempfile.mkdtemp()

    def tearDown(self):
        shutil.rmtree(self.test_dir)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_launch_user_script_python(self):
        run_id = str(uuid.uuid4().int)
        nnodes = 1
        nproc_per_node = 4
        world_size = nnodes * nproc_per_node
        args = [
            f"--nnodes={nnodes}",
            f"--nproc_per_node={nproc_per_node}",
            "--rdzv_backend=etcd",
            f"--rdzv_endpoint={self._etcd_endpoint}",
            f"--rdzv_id={run_id}",
            "--monitor_interval=1",
            "--start_method=fork",
            path("bin/test_script.py"),
            f"--touch_file_dir={self.test_dir}",
        ]
        launch.main(args)

        # make sure all the workers ran;
        # each worker touches a file with its global rank as the name
        self.assertSetEqual(
            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
        )

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_launch_user_script_bash(self):
        run_id = str(uuid.uuid4().int)
        nnodes = 1
        nproc_per_node = 4
        world_size = nnodes * nproc_per_node
        args = [
            f"--nnodes={nnodes}",
            f"--nproc_per_node={nproc_per_node}",
            "--rdzv_backend=etcd",
            f"--rdzv_endpoint={self._etcd_endpoint}",
            f"--rdzv_id={run_id}",
            "--monitor_interval=1",
            "--start_method=fork",
            "--no_python",
        ]
        script_args = [path("bin/test_script.sh"), f"{self.test_dir}"]

        with self.assertRaises(ValueError):
            # --no_python cannot be used with --module
            launch.main(args + ["--module"] + script_args)

        launch.main(args + script_args)

        # make sure all the workers ran;
        # each worker touches a file with its global rank as the name
        self.assertSetEqual(
            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
        )

    def _test_nproc_launch_configuration(self, nproc_type, expected_number):
        run_id = str(uuid.uuid4().int)
        nnodes = 1
        args = [
            f"--nnodes={nnodes}",
            f"--nproc_per_node={nproc_type}",
            "--rdzv_backend=etcd",
            f"--rdzv_endpoint={self._etcd_endpoint}",
            f"--rdzv_id={run_id}",
            "--monitor_interval=1",
            "--start_method=fork",
            "--no_python",
        ]
        script_args = [path("bin/test_script.sh"), f"{self.test_dir}"]

        launch.main(args + script_args)

        world_size = nnodes * expected_number
        # make sure all the workers ran;
        # each worker touches a file with its global rank as the name
        self.assertSetEqual(
            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
        )

    def test_nproc_launch_auto_configurations(self):
        self._test_nproc_launch_configuration("auto", os.cpu_count())

    def test_nproc_launch_number_configurations(self):
        self._test_nproc_launch_configuration("4", 4)

    def test_nproc_launch_unknown_configurations(self):
        with self.assertRaises(ValueError):
            self._test_nproc_launch_configuration("unknown", 4)

    @patch("torch.cuda.is_available", return_value=True)
    @patch("torch.cuda.device_count", return_value=3)
    def test_nproc_gpu_launch_configurations(self, _mock1, _mock2):
        self._test_nproc_launch_configuration("auto", 3)
        self._test_nproc_launch_configuration("gpu", 3)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_launch_elastic(self):
        run_id = str(uuid.uuid4().int)
        min_nodes = 1
        max_nodes = 2
        nproc_per_node = 4
        # we are only launching 1 node (even though max = 2)
        world_size = nproc_per_node
        args = [
            f"--nnodes={min_nodes}:{max_nodes}",
            f"--nproc_per_node={nproc_per_node}",
            "--rdzv_backend=etcd",
            f"--rdzv_endpoint={self._etcd_endpoint}",
            f"--rdzv_id={run_id}",
            "--monitor_interval=1",
            "--start_method=fork",
            path("bin/test_script.py"),
            f"--touch_file_dir={self.test_dir}",
        ]
        launch.main(args)

        # make sure all the workers ran;
        # each worker touches a file with its global rank as the name
        self.assertSetEqual(
            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
        )

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_launch_with_etcd(self):
        nnodes = 1
        nproc_per_node = 4
        world_size = nnodes * nproc_per_node
        args = [
            f"--nnodes={nnodes}",
            f"--nproc_per_node={nproc_per_node}",
            "--with_etcd",
            "--monitor_interval=1",
            "--start_method=fork",
            path("bin/test_script.py"),
            f"--touch_file_dir={self.test_dir}",
        ]
        launch.main(args)

        # make sure all the workers ran;
        # each worker touches a file with its global rank as the name
        self.assertSetEqual(
            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
        )
class TestRpc(unittest.TestCase):
    def test_init_app(self):
        app.init_app(
            role="trainer", backend=BackendType.PROCESS_GROUP, backend_options=None
        )

    @patch("torch.distributed.autograd._init")
    @patch("torch.distributed.rpc.api._init_rpc_backend")
    def test_init_rpc(self, rpc_backend_mock, autograd_mock):
        os.environ["RANK"] = "0"
        os.environ["WORLD_SIZE"] = "1"
        store = TestStore()
        app.init_rpc(
            name="trainer_worker",
            backend=BackendType.PROCESS_GROUP,
            backend_options=None,
            store=store,
        )
        autograd_mock.assert_called_once()
        rpc_backend_mock.assert_called_once()

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_custom_init_rpc(self):
        def init_rpc(rank, world_size, port, name):
            os.environ["RANK"] = str(rank)
            os.environ["WORLD_SIZE"] = str(world_size)
            os.environ["MASTER_ADDR"] = "localhost"
            os.environ["MASTER_PORT"] = str(port)
            rendezvous_iterator = dist.rendezvous(
                "env://", rank=rank, world_size=world_size
            )
            store, _, _ = next(rendezvous_iterator)
            app.init_rpc(
                name=name,
                backend=BackendType.PROCESS_GROUP,
                backend_options=None,
                store=store,
            )

        def master(msg, port):
            init_rpc(rank=0, world_size=2, port=port, name="master")
            ret = rpc.rpc_sync(to="worker", func=echo, args=(msg,))
            rpc.shutdown()
            return ret

        def worker(port):
            init_rpc(rank=1, world_size=2, port=port, name="worker")
            rpc.shutdown()

        sock = find_free_port()
        port = sock.getsockname()[1]
        sock.close()

        worker_proc = multiprocessing.Process(target=worker, args=(port,))
        worker_proc.start()

        expected_msg = "test_message_on_worker"
        actual_msg = master(expected_msg, port)
        worker_proc.join()
        self.assertEqual(expected_msg, actual_msg)

    def test_get_worker_names(self):
        pass

    def test_get_role_info(self):
        pass

    def test_get_all_roles(self):
        pass

    def test_wait_all(self):
        pass

    def test_rpc_sync_on_role(self):
        pass

    def test_rpc_async_on_role(self):
        pass

    def test_rpc_remote_on_role(self):
        pass

    def test_init_process_group(self):
        pass
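# echo is the module-level rpc target used by test_custom_init_rpc above
# and by test_torch_rpc earlier. Given that both tests assert the returned
# message equals the sent one, the obvious (assumed) implementation is:
#
#     def echo(msg):
#         return msg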
class LaunchTest(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # start a standalone, single process etcd server to use for all tests
        cls._etcd_server = EtcdServerFixture()
        cls._etcd_server.start()
        host = cls._etcd_server.get_host()
        port = cls._etcd_server.get_port()
        cls._etcd_endpoint = f"{host}:{port}"

    @classmethod
    def tearDownClass(cls):
        # stop the standalone etcd server
        cls._etcd_server.stop()

    def setUp(self):
        self.test_dir = tempfile.mkdtemp()

    def tearDown(self):
        shutil.rmtree(self.test_dir)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_launch_user_script_python(self):
        run_id = str(uuid.uuid4().int)
        nnodes = 1
        nproc_per_node = 4
        world_size = nnodes * nproc_per_node
        args = [
            f"--nnodes={nnodes}",
            f"--nproc_per_node={nproc_per_node}",
            "--rdzv_backend=etcd",
            f"--rdzv_endpoint={self._etcd_endpoint}",
            f"--rdzv_id={run_id}",
            "--monitor_interval=1",
            "--start_method=fork",
            path("bin/test_script.py"),
            f"--touch_file_dir={self.test_dir}",
        ]
        launch.main(args)

        # make sure all the workers ran;
        # each worker touches a file with its global rank as the name
        self.assertSetEqual(
            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
        )

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_launch_user_script_python_use_env(self):
        run_id = str(uuid.uuid4().int)
        nnodes = 1
        nproc_per_node = 4
        world_size = nnodes * nproc_per_node
        args = [
            f"--nnodes={nnodes}",
            f"--nproc_per_node={nproc_per_node}",
            "--use_env",
            "--rdzv_backend=etcd",
            f"--rdzv_endpoint={self._etcd_endpoint}",
            f"--rdzv_id={run_id}",
            "--monitor_interval=1",
            "--start_method=fork",
            path("bin/test_script_use_env.py"),
            f"--touch_file_dir={self.test_dir}",
        ]
        launch.main(args)

        # make sure all the workers ran;
        # each worker touches a file with its global rank as the name
        self.assertSetEqual(
            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
        )

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_launch_user_script_bash(self):
        run_id = str(uuid.uuid4().int)
        nnodes = 1
        nproc_per_node = 4
        world_size = nnodes * nproc_per_node
        args = [
            f"--nnodes={nnodes}",
            f"--nproc_per_node={nproc_per_node}",
            "--rdzv_backend=etcd",
            f"--rdzv_endpoint={self._etcd_endpoint}",
            f"--rdzv_id={run_id}",
            "--monitor_interval=1",
            "--start_method=fork",
            "--no_python",
        ]
        script_args = [path("bin/test_script.sh"), f"{self.test_dir}"]

        with self.assertRaises(ValueError):
            # --no_python also requires --use_env
            launch.main(args + script_args)

        with self.assertRaises(ValueError):
            # --no_python cannot be used with --module
            launch.main(args + ["--module"] + script_args)

        launch.main(args + ["--use_env"] + script_args)

        # make sure all the workers ran;
        # each worker touches a file with its global rank as the name
        self.assertSetEqual(
            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
        )

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_launch_elastic(self):
        run_id = str(uuid.uuid4().int)
        min_nodes = 1
        max_nodes = 2
        nproc_per_node = 4
        # we are only launching 1 node (even though max = 2)
        world_size = nproc_per_node
        args = [
            f"--nnodes={min_nodes}:{max_nodes}",
            f"--nproc_per_node={nproc_per_node}",
            "--rdzv_backend=etcd",
            f"--rdzv_endpoint={self._etcd_endpoint}",
            f"--rdzv_id={run_id}",
            "--monitor_interval=1",
            "--start_method=fork",
            path("bin/test_script.py"),
            f"--touch_file_dir={self.test_dir}",
        ]
        launch.main(args)

        # make sure all the workers ran;
        # each worker touches a file with its global rank as the name
        self.assertSetEqual(
            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
        )
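# bin/test_script_use_env.py is the --use_env variant of bin/test_script.py:
# with --use_env the launcher exports LOCAL_RANK instead of passing a
# --local_rank argument. A hedged sketch of what that script might contain
# (the real file may differ):
#
#     import argparse
#     import os
#
#     if __name__ == "__main__":
#         parser = argparse.ArgumentParser()
#         parser.add_argument("--touch_file_dir", type=str, required=True)
#         args = parser.parse_args()
#         # --use_env means no --local_rank arg; LOCAL_RANK comes from the env
#         assert "LOCAL_RANK" in os.environ
#         rank = os.environ["RANK"]
#         open(os.path.join(args.touch_file_dir, rank), "a").close()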
class LocalTimerTest(unittest.TestCase):
    def setUp(self):
        self.mp_queue = mp.Queue()
        self.max_interval = 0.01
        self.server = timer.LocalTimerServer(self.mp_queue, self.max_interval)
        self.server.start()

    def tearDown(self):
        self.server.stop()

    def test_exception_propagation(self):
        with self.assertRaises(Exception, msg="foobar"):
            with timer.expires(after=1):
                raise Exception("foobar")

    def test_no_client(self):
        # no timer client configured; exception expected
        timer.configure(None)
        with self.assertRaises(RuntimeError):
            with timer.expires(after=1):
                pass

    def test_client_interaction(self):
        # no timer client configured, but one is passed in explicitly;
        # no exception expected
        timer_client = timer.LocalTimerClient(self.mp_queue)
        timer_client.acquire = mock.MagicMock(wraps=timer_client.acquire)
        timer_client.release = mock.MagicMock(wraps=timer_client.release)
        with timer.expires(after=1, scope="test", client=timer_client):
            pass

        timer_client.acquire.assert_called_once_with("test", mock.ANY)
        timer_client.release.assert_called_once_with("test")

    def test_happy_path(self):
        timer.configure(timer.LocalTimerClient(self.mp_queue))
        with timer.expires(after=0.5):
            time.sleep(0.1)

    @unittest.skipIf(is_tsan(), "test is tsan incompatible")
    def test_get_timer_recursive(self):
        """
        If a function acquires a countdown timer with the default scope,
        then recursive calls to the function should re-acquire the timer
        rather than creating a new one. That is, only the last recursive
        call's timer will take effect.
        """
        # note: the server was already started in setUp()
        timer.configure(timer.LocalTimerClient(self.mp_queue))

        # func should not time out
        def func(n):
            if n > 0:
                with timer.expires(after=0.1):
                    func(n - 1)
                    time.sleep(0.05)

        func(4)

        # func2 should time out
        def func2(n):
            if n > 0:
                with timer.expires(after=0.1):
                    func2(n - 1)
                    time.sleep(0.2)

        p = mp.Process(target=func2, args=(2,))
        p.start()
        p.join()
        self.assertEqual(-signal.SIGKILL, p.exitcode)

    @staticmethod
    def _run(mp_queue, timeout, duration):
        client = timer.LocalTimerClient(mp_queue)
        timer.configure(client)

        with timer.expires(after=timeout):
            time.sleep(duration)

    @unittest.skipIf(is_tsan(), "test is tsan incompatible")
    def test_timer(self):
        timeout = 0.1
        duration = 1
        p = mp.Process(target=self._run, args=(self.mp_queue, timeout, duration))
        p.start()
        p.join()
        self.assertEqual(-signal.SIGKILL, p.exitcode)
class LocalTimerServerTest(unittest.TestCase):
    def setUp(self):
        self.mp_queue = mp.Queue()
        self.max_interval = 0.01
        self.server = timer.LocalTimerServer(self.mp_queue, self.max_interval)

    def tearDown(self):
        self.server.stop()

    @unittest.skipIf(is_tsan(), "test is tsan incompatible")
    def test_watchdog_call_count(self):
        """
        checks that the watchdog function ran wait/interval +- 1 times
        """
        self.server._run_watchdog = mock.MagicMock(wraps=self.server._run_watchdog)

        wait = 0.1

        self.server.start()
        time.sleep(wait)
        self.server.stop()
        watchdog_call_count = self.server._run_watchdog.call_count
        self.assertGreaterEqual(
            watchdog_call_count, int(wait / self.max_interval) - 1
        )
        self.assertLessEqual(
            watchdog_call_count, int(wait / self.max_interval) + 1
        )

    def test_watchdog_empty_queue(self):
        """
        checks that the watchdog can run on an empty queue
        """
        self.server._run_watchdog()

    def _expired_timer(self, pid, scope):
        expired = time.time() - 60
        return TimerRequest(worker_id=pid, scope_id=scope, expiration_time=expired)

    def _valid_timer(self, pid, scope):
        valid = time.time() + 60
        return TimerRequest(worker_id=pid, scope_id=scope, expiration_time=valid)

    def _release_timer(self, pid, scope):
        return TimerRequest(worker_id=pid, scope_id=scope, expiration_time=-1)

    @unittest.skipIf(is_tsan(), "test is tsan incompatible")
    @mock.patch("os.kill")
    def test_expired_timers(self, mock_os_kill):
        """
        tests that a single expired timer on a process terminates
        the process and cleans up all pending timers owned by that process
        """
        test_pid = -3
        self.mp_queue.put(self._expired_timer(pid=test_pid, scope="test1"))
        self.mp_queue.put(self._valid_timer(pid=test_pid, scope="test2"))

        self.server._run_watchdog()

        self.assertEqual(0, len(self.server._timers))
        mock_os_kill.assert_called_once_with(test_pid, signal.SIGKILL)

    @mock.patch("os.kill")
    def test_acquire_release(self, mock_os_kill):
        """
        tests that:
        1. a timer can be acquired then released (should not terminate process)
        2. a timer can be vacuously released (e.g. a no-op)
        """
        test_pid = -3
        self.mp_queue.put(self._valid_timer(pid=test_pid, scope="test1"))
        self.mp_queue.put(self._release_timer(pid=test_pid, scope="test1"))
        self.mp_queue.put(self._release_timer(pid=test_pid, scope="test2"))

        self.server._run_watchdog()

        self.assertEqual(0, len(self.server._timers))
        mock_os_kill.assert_not_called()

    @mock.patch("os.kill")
    def test_valid_timers(self, mock_os_kill):
        """
        tests that valid timers are processed correctly and the process is
        left alone
        """
        self.mp_queue.put(self._valid_timer(pid=-3, scope="test1"))
        self.mp_queue.put(self._valid_timer(pid=-3, scope="test2"))
        self.mp_queue.put(self._valid_timer(pid=-2, scope="test1"))
        self.mp_queue.put(self._valid_timer(pid=-2, scope="test2"))

        self.server._run_watchdog()

        self.assertEqual(4, len(self.server._timers))
        self.assertTrue((-3, "test1") in self.server._timers)
        self.assertTrue((-3, "test2") in self.server._timers)
        self.assertTrue((-2, "test1") in self.server._timers)
        self.assertTrue((-2, "test2") in self.server._timers)
        mock_os_kill.assert_not_called()