예제 #1
0
 def test_duplicate_name(self):
     """Check that RPC initialization rejects a worker name used by every rank.

     Every process running this test passes the same literal
     ``"duplicate_name"`` as ``self_name``. ``_init_rpc`` is therefore
     expected to raise a ``RuntimeError`` whose message matches
     ``"is not unique"``.
     """
     store, _, _ = next(torch.distributed.rendezvous(
         self.init_method, rank=self.rank, world_size=self.world_size
     ))
     # Only the RPC initialization is expected to fail. Keeping the
     # rendezvous outside the assertion context means a setup failure is
     # reported as such, not as a regex mismatch.
     with self.assertRaisesRegex(RuntimeError, "is not unique"):
         rpc._init_rpc(
             backend=rpc.backend_registry.BackendType[TEST_CONFIG.rpc_backend_name],
             store=store,
             self_name="duplicate_name",
             self_rank=self.rank,
             worker_name_to_id=self.worker_name_to_id,
         )
     # Called even though initialization failed; presumably a no-op when no
     # agent was created (see the _agent check in test_invalid_names) -- confirm.
     rpc.join_rpc()
예제 #2
0
    def test_invalid_names(self):
        """Check that RPC initialization rejects malformed worker names.

        Each case below calls ``_init_rpc`` with an invalid ``self_name`` and
        expects a ``RuntimeError`` whose message matches the paired regex.
        After all cases, no RPC agent may have been created, so
        ``join_rpc()`` is expected to do nothing.
        """
        backend = rpc.backend_registry.BackendType[TEST_CONFIG.rpc_backend_name]
        # (invalid worker name, regex the RuntimeError message must match)
        invalid_cases = [
            ("abc*", "Worker name must match"),
            (" ", "Worker name must match"),
            ("", "must be non-empty"),
            # If the number in the message does not match, it is likely that
            # the value of MAX_NAME_LEN in RPC WorkerInfo has changed.
            ("a" * 500, "shorter than 128"),
        ]

        base_file_name = self.file_name
        for case_idx, (bad_name, expected_regex) in enumerate(invalid_cases):
            if case_idx > 0:
                # Use a different file path for FileStore to avoid rendezvous
                # mismatch with the previous (failed) attempt.
                self.file_name = base_file_name + str(case_idx)
            store, _, _ = next(torch.distributed.rendezvous(
                self.init_method, rank=self.rank, world_size=self.world_size
            ))
            # Only the RPC initialization is expected to raise; rendezvous
            # stays outside so setup errors are not masked as regex mismatches.
            with self.assertRaisesRegex(RuntimeError, expected_regex):
                rpc._init_rpc(
                    backend=backend,
                    store=store,
                    self_name=bad_name,
                    self_rank=self.rank,
                    worker_name_to_id=self.worker_name_to_id,
                    num_send_recv_threads=16,
                )

        from torch.distributed.rpc.api import _agent
        self.assertIsNone(_agent)
        # join_rpc() should not do anything as _agent is None
        rpc.join_rpc()