def test_reinit(self):
    rpc.init_rpc(
        name="worker{}".format(self.rank),
        backend=self.rpc_backend,
        init_method=self.init_method,
        rank=self.rank,
        world_size=self.world_size,
        rpc_backend_options=self.rpc_backend_options,
    )

    # This is for the below `dist.barrier`.
    # For `RpcAgent` other than `ProcessGroupAgent`,
    # no `_default_pg` is initialized.
    if not dist.is_initialized():
        dist.init_process_group(
            backend="gloo",
            init_method=self.init_method,
            rank=self.rank,
            world_size=self.world_size,
        )
    # Wait for all init to complete.
    dist.barrier()

    with self.assertRaisesRegex(RuntimeError, "is already initialized"):
        rpc.init_rpc(
            name="worker{}".format(self.rank),
            backend=self.rpc_backend,
            init_method=self.init_method,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=self.rpc_backend_options,
        )
    rpc.wait_all_workers()
def test_wait_all_workers(self):
    # Initialize RPC.
    rpc.init_rpc(
        name="worker%d" % self.rank,
        backend=self.rpc_backend,
        init_method=self.init_method,
        rank=self.rank,
        world_size=self.world_size,
        rpc_backend_options=self.rpc_backend_options,
    )

    n = self.rank + 1
    dst_rank = n % self.world_size
    ret = rpc.rpc_sync(
        "worker{}".format(dst_rank),
        torch.add,
        args=(torch.ones(n, n), torch.ones(n, n)),
    )
    self.assertEqual(ret, torch.ones(n, n) * 2)
    rpc.wait_all_workers()

    # wait_all_workers() also shuts down the local RPC agent,
    # so subsequent RPC calls must fail.
    with self.assertRaisesRegex(RuntimeError, "^RPC has not been initialized"):
        rpc.rpc_sync(
            "worker{}".format(dst_rank),
            torch.add,
            args=(torch.ones(n, n), torch.ones(n, n)),
        )

    # It is safe to call wait_all_workers() multiple times.
    rpc.wait_all_workers()
def new_test_method(self, *arg, **kwargs):
    self.worker_id = self.rank

    if setup_rpc:
        global _ALL_NODE_NAMES
        _ALL_NODE_NAMES = {
            "worker{}".format(rank) for rank in range(self.world_size)
        }

        rpc.init_rpc(
            name="worker%d" % self.rank,
            backend=self.rpc_backend,
            init_method=self.init_method,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=self.rpc_backend_options,
        )

    return_value = old_test_method(self, *arg, **kwargs)

    if setup_rpc:
        if clean_shutdown:
            # Follower reports done.
            if self.rank == MASTER_RANK:
                on_master_follower_report_done("worker{}".format(MASTER_RANK))
            else:
                rpc.rpc_async(
                    "worker{}".format(MASTER_RANK),
                    on_master_follower_report_done,
                    args=("worker{}".format(self.rank),),
                )

            # Master waits for followers to report done.
            # Follower waits for master's termination command.
            _TERMINATION_SIGNAL.wait()
            if self.rank == MASTER_RANK:
                # Master sends termination command.
                futs = []
                for dst_rank in range(self.world_size):
                    # torch.distributed.rpc module does not support sending to self.
                    if dst_rank == MASTER_RANK:
                        continue
                    dst_name = "worker{}".format(dst_rank)
                    fut = rpc.rpc_async(dst_name, set_termination_signal, args=())
                    futs.append(fut)
                for fut in futs:
                    assert fut.wait() is None, "Sending termination signal failed."

        # Close RPC. Need to do this even if we don't have a clean shutdown
        # since we need to shutdown the RPC agent. If we don't shutdown the
        # RPC agent, tests would fail since RPC agent threads, locks and
        # condition variables are not properly terminated.
        rpc.wait_all_workers()

    return return_value
def _test_rref_leak(self, ignore_leak=False):
    rpc.init_rpc(
        name="worker{}".format(self.rank),
        backend=self.rpc_backend,
        rank=self.rank,
        world_size=self.world_size,
        rpc_backend_options=self.rpc_backend_options,
    )

    # This is for the below `dist.barrier`.
    # For `RpcAgent` other than `ProcessGroupAgent`,
    # no `_default_pg` is initialized.
    if not dist.is_initialized():
        dist.init_process_group(
            backend="gloo",
            init_method=self.init_method,
            rank=self.rank,
            world_size=self.world_size,
        )
    # Wait for all init to complete.
    dist.barrier()

    # Keep the RRef alive through shutdown so the leak check in
    # wait_all_workers() sees an outstanding RRef.
    rref = rpc.remote(
        "worker{}".format((self.rank + 1) % self.world_size),
        torch.add,
        args=(torch.ones(2, 2), 1),
    )

    if ignore_leak:
        import torch.distributed.rpc.api as api
        api._ignore_rref_leak = True

    rpc.wait_all_workers()
def test_duplicate_name(self):
    # Every rank registers with the same name, which should be rejected
    # because worker names must be unique.
    with self.assertRaisesRegex(RuntimeError, "is not unique"):
        store, _, _ = next(
            torch.distributed.rendezvous(
                self.init_method, rank=self.rank, world_size=self.world_size
            )
        )
        rpc._init_rpc_backend(
            backend=self.rpc_backend,
            store=store,
            name="duplicate_name",
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=self.rpc_backend_options,
        )
    rpc.wait_all_workers()
def test_get_rpc_timeout(self):
    timeout = timedelta(seconds=1)

    # A new `RpcBackendOptions` is constructed
    # when accessing `self.rpc_backend_options`.
    rpc_backend_options = self.rpc_backend_options
    rpc_backend_options.rpc_timeout = timeout

    rpc.init_rpc(
        name="worker{}".format(self.rank),
        backend=self.rpc_backend,
        rank=self.rank,
        world_size=self.world_size,
        rpc_backend_options=rpc_backend_options,
    )
    set_timeout = rpc.get_rpc_timeout()
    self.assertEqual(timeout, set_timeout)
    rpc.wait_all_workers()
def test_invalid_names(self):
    with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
        store, _, _ = next(
            torch.distributed.rendezvous(
                self.init_method, rank=self.rank, world_size=self.world_size
            )
        )
        rpc._init_rpc_backend(
            backend=self.rpc_backend,
            store=store,
            name="abc*",
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=self.rpc_backend_options,
        )

    base_file_name = self.file_name
    # Use a different file path for FileStore to avoid rendezvous mismatch.
    self.file_name = base_file_name + "1"
    with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
        store, _, _ = next(
            torch.distributed.rendezvous(
                self.init_method, rank=self.rank, world_size=self.world_size
            )
        )
        rpc._init_rpc_backend(
            backend=self.rpc_backend,
            store=store,
            name=" ",
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=self.rpc_backend_options,
        )

    # Use a different file path for FileStore to avoid rendezvous mismatch.
    self.file_name = base_file_name + "2"
    with self.assertRaisesRegex(RuntimeError, "must be non-empty"):
        store, _, _ = next(
            torch.distributed.rendezvous(
                self.init_method, rank=self.rank, world_size=self.world_size
            )
        )
        rpc._init_rpc_backend(
            backend=self.rpc_backend,
            store=store,
            name="",
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=self.rpc_backend_options,
        )

    # Use a different file path for FileStore to avoid rendezvous mismatch.
    self.file_name = base_file_name + "3"
    # If the number in the message does not match, it is likely that the
    # value of MAX_NAME_LEN in RPC WorkerInfo has changed.
    with self.assertRaisesRegex(RuntimeError, "shorter than 128"):
        store, _, _ = next(
            torch.distributed.rendezvous(
                self.init_method, rank=self.rank, world_size=self.world_size
            )
        )
        rpc._init_rpc_backend(
            backend=self.rpc_backend,
            store=store,
            name="".join(["a" for i in range(500)]),
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=self.rpc_backend_options,
        )

    from torch.distributed.rpc.api import _agent
    self.assertEqual(_agent, None)
    # wait_all_workers() should not do anything as _agent is None.
    rpc.wait_all_workers()