def test_reinit(self):
    """Calling init_model_parallel a second time must raise RuntimeError."""
    dist.init_process_group(
        backend=dist.Backend.GLOO,
        init_method=self.init_method,
        rank=self.rank,
        world_size=self.world_size,
    )

    worker_name = "worker{}".format(self.rank)
    rpc.init_model_parallel(
        self_name=worker_name,
        backend=TEST_CONFIG.rpc_backend,
        init_method=self.init_method,
        self_rank=self.rank,
        worker_name_to_id=self.worker_name_to_id,
    )

    # Wait for all init to complete.
    dist.barrier()

    # A second initialization on the same process must be rejected.
    with self.assertRaisesRegex(RuntimeError, "is already initialized"):
        rpc.init_model_parallel(
            self_name=worker_name,
            backend=TEST_CONFIG.rpc_backend,
            init_method=self.init_method,
            self_rank=self.rank,
            worker_name_to_id=self.worker_name_to_id,
        )

    rpc.join_rpc()
def test_reinit(self):
    """Re-invoking init_model_parallel on an initialized process must fail."""
    worker_name = "worker{}".format(self.rank)
    backend_type = rpc.backend_registry.BackendType[TEST_CONFIG.rpc_backend_name]

    rpc.init_model_parallel(
        self_name=worker_name,
        backend=backend_type,
        init_method=self.init_method,
        self_rank=self.rank,
        worker_name_to_id=self.worker_name_to_id,
    )

    # This is for the below `dist.barrier`.
    # For `RpcAgent` other than `ProcessGroupAgent`,
    # no `_default_pg` is initialized.
    if not dist.is_initialized():
        dist.init_process_group(
            backend="gloo",
            init_method=self.init_method,
            rank=self.rank,
            world_size=self.world_size,
        )

    # Wait for all init to complete.
    dist.barrier()

    with self.assertRaisesRegex(RuntimeError, "is already initialized"):
        rpc.init_model_parallel(
            self_name=worker_name,
            backend=backend_type,
            init_method=self.init_method,
            self_rank=self.rank,
            worker_name_to_id=self.worker_name_to_id,
        )

    rpc.join_rpc()
def test_join_rpc(self):
    """RPC works before join_rpc(), raises afterwards, and join is idempotent."""
    # Initialize RPC.
    rpc.init_model_parallel(
        self_name="worker%d" % self.rank,
        backend=rpc.backend_registry.BackendType[TEST_CONFIG.rpc_backend_name],
        init_method=self.init_method,
        self_rank=self.rank,
        worker_name_to_id=self.worker_name_to_id,
    )

    size = self.rank + 1
    dst = "worker{}".format(size % self.world_size)

    # A sync RPC to the next worker should succeed while RPC is up.
    result = rpc.rpc_sync(
        dst,
        torch.add,
        args=(torch.ones(size, size), torch.ones(size, size)),
    )
    self.assertEqual(result, torch.ones(size, size) * 2)

    rpc.join_rpc()

    # After shutdown, any RPC must report that RPC is uninitialized.
    with self.assertRaisesRegex(RuntimeError, "^RPC has not been initialized"):
        rpc.rpc_sync(
            dst,
            torch.add,
            args=(torch.ones(size, size), torch.ones(size, size)),
        )

    # it's safe to call join_rpc() multiple times
    rpc.join_rpc()
def test_register_rpc_backend_and_init_rpc_backend(
    self, mock_init_rref_context, mock_dist_autograd_init
):
    """A backend registered by name can be used to initialize RPC."""
    custom_backend_name = "stub_backend"
    rpc.register_backend(custom_backend_name, stub_init_rpc_backend_handler)
    rpc.init_model_parallel(
        self_name="worker1",
        backend=custom_backend_name,
        self_rank=1,
    )
def test_init_invalid_backend(self):
    """Initializing with an unknown backend string must be rejected."""
    with self.assertRaisesRegex(RuntimeError, "Unrecognized RPC backend"):
        rpc.init_model_parallel(
            self_name="worker{}".format(self.rank),
            backend="invalid",
            self_rank=self.rank,
            init_method=self.init_method,
        )
def new_test_method(self, *arg, **kwargs):
    """Run old_test_method with optional RPC setup and coordinated shutdown.

    When `setup_model_parallel` is set, RPC is initialized before the test
    and, if `clean_shutdown` is also set, all workers synchronize through a
    master/follower termination handshake before joining RPC.
    """
    self.worker_id = self.rank
    self.worker_name_to_id = {
        "worker{}".format(rank): rank for rank in range(self.world_size)
    }

    if setup_model_parallel:
        global _ALL_NODE_NAMES
        _ALL_NODE_NAMES = self.worker_name_to_id.keys()

        # Use enough 'num_send_recv_threads' until we fix https://github.com/pytorch/pytorch/issues/26359
        rpc.init_model_parallel(
            self_name="worker%d" % self.rank,
            backend=rpc.backend_registry.BackendType[TEST_CONFIG.rpc_backend_name],
            init_method=self.init_method,
            self_rank=self.rank,
            worker_name_to_id=self.worker_name_to_id,
            num_send_recv_threads=16,
        )

    return_value = old_test_method(self, *arg, **kwargs)

    if setup_model_parallel and clean_shutdown:
        # Follower reports done.
        if self.rank == MASTER_RANK:
            on_master_follower_report_done("worker{}".format(MASTER_RANK))
        else:
            rpc.rpc_async(
                "worker{}".format(MASTER_RANK),
                on_master_follower_report_done,
                args=("worker{}".format(self.rank),),
            )

        # Master waits for followers to report done.
        # Follower waits for master's termination command.
        _TERMINATION_SIGNAL.wait()

        if self.rank == MASTER_RANK:
            # Master sends termination command.
            pending = []
            for peer_rank in range(self.world_size):
                # torch.distributed.rpc module does not support sending to self.
                if peer_rank == MASTER_RANK:
                    continue
                pending.append(
                    rpc.rpc_async(
                        "worker{}".format(peer_rank),
                        set_termination_signal,
                        args=(),
                    )
                )
            for fut in pending:
                assert fut.wait() is None, "Sending termination signal failed."

        # Close RPC.
        rpc.join_rpc()

    return return_value
def test_duplicate_name(self):
    """Registering the same worker name from every rank must fail."""
    dist.init_process_group(
        backend=dist.Backend.GLOO, init_method=self.init_method
    )
    with self.assertRaisesRegex(RuntimeError, "is not unique"):
        rpc.init_model_parallel(
            self_name="duplicate_name",
            backend=TEST_CONFIG.rpc_backend,
            self_rank=self.rank,
            init_method=self.init_method,
        )
    rpc.join_rpc()
def test_duplicate_name(self):
    """All ranks claiming the same worker name must trigger a uniqueness error."""
    with self.assertRaisesRegex(RuntimeError, "is not unique"):
        rpc.init_model_parallel(
            self_name="duplicate_name",
            backend=rpc.backend_registry.BackendType[TEST_CONFIG.rpc_backend_name],
            init_method=self.init_method,
            self_rank=self.rank,
            worker_name_to_id=self.worker_name_to_id,
        )
    rpc.join_rpc()
def wrapper(self, *arg, **kwargs):
    """Bring up the process group and RPC, run the test, then join RPC."""
    self.worker_id = self.rank
    dist.init_process_group(backend="gloo", init_method=self.init_method)
    rpc.init_model_parallel(
        self_name="worker%d" % self.rank,
        backend=TEST_CONFIG.backend,
        self_rank=self.rank,
        init_method=self.init_method,
    )
    test_method(self, *arg, **kwargs)
    rpc.join_rpc()
def test_register_rpc_backend_and_start_rpc_backend(
    self, mock_rpc_agent, mock_dist_autograd_init
):
    """A custom backend registered by name can start the RPC layer."""
    custom_backend_name = "stub_backend"
    rpc.register_backend(custom_backend_name, stub_start_rpc_backend_handler)
    rpc.init_model_parallel(
        self_name="worker1",
        backend=custom_backend_name,
        init_method=self.init_method,
        self_rank=self.rank,
        worker_name_to_id=self.worker_name_to_id,
    )
def wrapper(self, *arg, **kwargs):
    """Set up process group + RPC around the wrapped test method."""
    self.worker_id = self.rank
    dist.init_process_group(backend="gloo", init_method=self.init_method)
    # Use enough 'num_send_recv_threads' until we fix https://github.com/pytorch/pytorch/issues/26359
    rpc.init_model_parallel(
        self_name="worker%d" % self.rank,
        backend=TEST_CONFIG.rpc_backend,
        self_rank=self.rank,
        init_method=self.init_method,
        num_send_recv_threads=16,
    )
    test_method(self, *arg, **kwargs)
    rpc.join_rpc()
def test_set_rpc_timeout(self):
    """The rpc_timeout given at init must be returned by get_rpc_timeout()."""
    delay = timedelta(seconds=1)
    rpc.init_model_parallel(
        self_name="worker{}".format(self.rank),
        backend=rpc.backend_registry.BackendType[TEST_CONFIG.rpc_backend_name],
        init_method=self.init_method,
        self_rank=self.rank,
        worker_name_to_id=self.worker_name_to_id,
        rpc_timeout=delay,
    )
    self.assertEqual(delay, rpc.get_rpc_timeout())
    rpc.join_rpc()
def test_invalid_names(self):
    """Invalid worker names (bad chars, whitespace, empty, too long) must be
    rejected by init_model_parallel, and join_rpc() stays a no-op afterwards.
    """
    dist.init_process_group(
        backend=dist.Backend.GLOO, init_method=self.init_method
    )

    with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
        rpc.init_model_parallel(self_name="abc*")

    with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
        rpc.init_model_parallel(self_name=" ")

    with self.assertRaisesRegex(RuntimeError, "must be non-empty"):
        rpc.init_model_parallel(self_name="")

    # If the number in the message does not match, it is likely that the
    # value of MAX_NAME_LEN in RPC WorkerInfo has changed.
    with self.assertRaisesRegex(RuntimeError, "shorter than 128"):
        # Idiomatic string repetition instead of joining a comprehension.
        rpc.init_model_parallel(self_name="a" * 500)

    from torch.distributed.rpc.api import _agent
    self.assertEqual(_agent, None)
    # join_rpc() should not do anything as _agent is None
    rpc.join_rpc()

    # We need this barrier here because although init_process_group is
    # blocking, it does not guarantee that all ranks are done with
    # initialization after the call. We did run into issues with it where
    # rank 3 crashed with "connection closed by peer" RuntimeError, which is
    # caused by other ranks exit before rank 3 is ready. This can be fixed
    # by adding a collective call to sync all processes.
    #
    # We decided not fixing this issue in init_process_group because it
    # would add extra overhead to the call, and normal use cases won't
    # create a progress group and exit without doing anything. Hence, it is
    # not worthy to introduce the overhead just for this test case.
    dist.barrier()
def wrapper(self, *arg, **kwargs):
    """Initialize RPC, run the test, then perform a coordinated shutdown.

    Followers report completion to the master; the master waits for all
    reports and then broadcasts a termination signal before joining RPC.
    """
    self.worker_id = self.rank

    global _ALL_NODE_NAMES
    _ALL_NODE_NAMES = {
        "worker{}".format(rank) for rank in range(self.world_size)
    }

    # Initialize RPC.
    dist.init_process_group(backend="gloo", init_method=self.init_method)
    # Use enough 'num_send_recv_threads' until we fix https://github.com/pytorch/pytorch/issues/26359
    rpc.init_model_parallel(
        self_name="worker%d" % self.rank,
        backend=TEST_CONFIG.rpc_backend,
        self_rank=self.rank,
        init_method=self.init_method,
        num_send_recv_threads=16,
    )

    test_method(self, *arg, **kwargs)

    # Follower reports done.
    if self.rank == MASTER_RANK:
        on_master_follower_report_done("worker{}".format(MASTER_RANK))
    else:
        rpc.rpc_async(
            "worker{}".format(MASTER_RANK),
            on_master_follower_report_done,
            args=("worker{}".format(self.rank),),
        )

    # Master waits for followers to report done.
    # Follower waits for master's termination command.
    _TERMINATION_SIGNAL.wait()

    if self.rank == MASTER_RANK:
        # Master sends termination command.
        pending = []
        for peer_rank in range(self.world_size):
            # torch.distributed.rpc module does not support sending to self.
            if peer_rank == MASTER_RANK:
                continue
            pending.append(
                rpc.rpc_async(
                    "worker{}".format(peer_rank),
                    set_termination_signal,
                    args=(),
                )
            )
        for fut in pending:
            assert fut.wait() is None, "Sending termination signal failed."

    # Close RPC.
    rpc.join_rpc()
def test_invalid_names(self):
    """Invalid worker names must be rejected; each attempt uses a fresh
    FileStore path so failed rendezvous state cannot leak between cases.
    """
    with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
        rpc.init_model_parallel(
            self_name="abc*",
            backend=rpc.backend_registry.BackendType[
                TEST_CONFIG.rpc_backend_name],
            init_method=self.init_method,
            self_rank=self.rank,
            worker_name_to_id=self.worker_name_to_id,
            num_send_recv_threads=16,
        )

    base_file_name = self.file_name
    # Use a different file path for FileStore to avoid rendezvous mismatch.
    self.file_name = base_file_name + "1"
    with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
        rpc.init_model_parallel(
            self_name=" ",
            backend=rpc.backend_registry.BackendType[
                TEST_CONFIG.rpc_backend_name],
            init_method=self.init_method,
            self_rank=self.rank,
            worker_name_to_id=self.worker_name_to_id,
            num_send_recv_threads=16,
        )

    # Use a different file path for FileStore to avoid rendezvous mismatch.
    self.file_name = base_file_name + "2"
    with self.assertRaisesRegex(RuntimeError, "must be non-empty"):
        rpc.init_model_parallel(
            self_name="",
            backend=rpc.backend_registry.BackendType[
                TEST_CONFIG.rpc_backend_name],
            init_method=self.init_method,
            self_rank=self.rank,
            worker_name_to_id=self.worker_name_to_id,
            num_send_recv_threads=16,
        )

    # Use a different file path for FileStore to avoid rendezvous mismatch.
    self.file_name = base_file_name + "3"
    # If the number in the message does not match, it is likely that the
    # value of MAX_NAME_LEN in RPC WorkerInfo has changed.
    with self.assertRaisesRegex(RuntimeError, "shorter than 128"):
        rpc.init_model_parallel(
            # Idiomatic string repetition instead of joining a comprehension.
            self_name="a" * 500,
            backend=rpc.backend_registry.BackendType[
                TEST_CONFIG.rpc_backend_name],
            init_method=self.init_method,
            self_rank=self.rank,
            worker_name_to_id=self.worker_name_to_id,
            num_send_recv_threads=16,
        )

    from torch.distributed.rpc.api import _agent
    self.assertEqual(_agent, None)
    # join_rpc() should not do anything as _agent is None
    rpc.join_rpc()
def test_reinit(self):
    """A second init_model_parallel on the same process must raise."""
    dist.init_process_group(
        backend=dist.Backend.GLOO, init_method=self.init_method
    )

    worker_name = "worker{}".format(self.rank)
    rpc.init_model_parallel(
        self_name=worker_name,
        backend=TEST_CONFIG.rpc_backend,
        self_rank=self.rank,
        init_method=self.init_method,
    )
    with self.assertRaisesRegex(RuntimeError, "is already initialized"):
        rpc.init_model_parallel(
            self_name=worker_name,
            backend=TEST_CONFIG.rpc_backend,
            self_rank=self.rank,
            init_method=self.init_method,
        )
    rpc.join_rpc()
def test_register_rpc_backend_and_start_rpc_backend(
    self, mock_rpc_agent, mock_dist_autograd_init
):
    """Registering a backend twice fails; the returned handle works for init."""
    custom_backend_name = "stub_backend"
    registered_backend = rpc.backend_registry.register_backend(
        custom_backend_name, stub_start_rpc_backend_handler
    )

    # A duplicate registration under the same name must be rejected.
    with self.assertRaisesRegex(
        RuntimeError, "^RPC backend .+: already registered$"
    ):
        rpc.backend_registry.register_backend(
            custom_backend_name, stub_start_rpc_backend_handler
        )

    rpc.init_model_parallel(
        self_name="worker1",
        backend=registered_backend,
        init_method=self.init_method,
        self_rank=self.rank,
        worker_name_to_id=self.worker_name_to_id,
    )
def test_invalid_names(self):
    """Worker names with bad characters, whitespace, empty strings, or
    over-length names must all be rejected by init_model_parallel.
    """
    dist.init_process_group(
        backend=dist.Backend.GLOO, init_method=self.init_method
    )

    with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
        rpc.init_model_parallel(self_name="abc*")

    with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
        rpc.init_model_parallel(self_name=" ")

    with self.assertRaisesRegex(RuntimeError, "must be non-empty"):
        rpc.init_model_parallel(self_name="")

    # If the number in the message does not match, it is likely that the
    # value of MAX_NAME_LEN in RPC WorkerInfo has changed.
    with self.assertRaisesRegex(RuntimeError, "shorter than 128"):
        rpc.init_model_parallel(
            # Idiomatic string repetition instead of joining a comprehension.
            self_name="a" * 500,
            backend=TEST_CONFIG.rpc_backend,
            self_rank=self.rank,
            init_method=self.init_method,
        )

    rpc.join_rpc()