def test_duplicate_name(self):
    with self.assertRaisesRegex(RuntimeError, "is not unique"):
        store, _, _ = next(torch.distributed.rendezvous(
            self.init_method, rank=self.rank, world_size=self.world_size
        ))
        rpc._init_rpc(
            backend=rpc.backend_registry.BackendType[TEST_CONFIG.rpc_backend_name],
            store=store,
            self_name="duplicate_name",
            self_rank=self.rank,
            worker_name_to_id=self.worker_name_to_id,
        )
    rpc.join_rpc()
def test_invalid_names(self):
    with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
        store, _, _ = next(torch.distributed.rendezvous(
            self.init_method, rank=self.rank, world_size=self.world_size
        ))
        rpc._init_rpc(
            backend=rpc.backend_registry.BackendType[TEST_CONFIG.rpc_backend_name],
            store=store,
            self_name="abc*",
            self_rank=self.rank,
            worker_name_to_id=self.worker_name_to_id,
            num_send_recv_threads=16,
        )

    base_file_name = self.file_name
    # Use a different file path for FileStore to avoid rendezvous mismatch.
    self.file_name = base_file_name + "1"
    with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
        store, _, _ = next(torch.distributed.rendezvous(
            self.init_method, rank=self.rank, world_size=self.world_size
        ))
        rpc._init_rpc(
            backend=rpc.backend_registry.BackendType[TEST_CONFIG.rpc_backend_name],
            store=store,
            self_name=" ",
            self_rank=self.rank,
            worker_name_to_id=self.worker_name_to_id,
            num_send_recv_threads=16,
        )

    # Use a different file path for FileStore to avoid rendezvous mismatch.
    self.file_name = base_file_name + "2"
    with self.assertRaisesRegex(RuntimeError, "must be non-empty"):
        store, _, _ = next(torch.distributed.rendezvous(
            self.init_method, rank=self.rank, world_size=self.world_size
        ))
        rpc._init_rpc(
            backend=rpc.backend_registry.BackendType[TEST_CONFIG.rpc_backend_name],
            store=store,
            self_name="",
            self_rank=self.rank,
            worker_name_to_id=self.worker_name_to_id,
            num_send_recv_threads=16,
        )

    # Use a different file path for FileStore to avoid rendezvous mismatch.
    self.file_name = base_file_name + "3"
    # If the number in the message does not match, it is likely that the
    # value of MAX_NAME_LEN in RPC WorkerInfo has changed.
    with self.assertRaisesRegex(RuntimeError, "shorter than 128"):
        store, _, _ = next(torch.distributed.rendezvous(
            self.init_method, rank=self.rank, world_size=self.world_size
        ))
        rpc._init_rpc(
            backend=rpc.backend_registry.BackendType[TEST_CONFIG.rpc_backend_name],
            store=store,
            self_name="".join(["a" for i in range(500)]),
            self_rank=self.rank,
            worker_name_to_id=self.worker_name_to_id,
            num_send_recv_threads=16,
        )

    from torch.distributed.rpc.api import _agent
    self.assertEqual(_agent, None)
    # join_rpc() should not do anything as _agent is None
    rpc.join_rpc()
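# For contrast with the failure cases above, the non-failing path these checks
# guard uses the same rendezvous + _init_rpc + join_rpc sequence with a valid
# worker name. The sketch below is illustrative only; the name "worker{rank}"
# is an assumed valid name, not something this file defines.
#
#     store, _, _ = next(torch.distributed.rendezvous(
#         self.init_method, rank=self.rank, world_size=self.world_size
#     ))
#     rpc._init_rpc(
#         backend=rpc.backend_registry.BackendType[TEST_CONFIG.rpc_backend_name],
#         store=store,
#         self_name="worker{}".format(self.rank),
#         self_rank=self.rank,
#         worker_name_to_id=self.worker_name_to_id,
#         num_send_recv_threads=16,
#     )
#     rpc.join_rpc()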