Example #1
 def test_duplicate_name(self):
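     # Rendezvous manually to obtain the store, then have every rank try to
     # register under the same worker name; the backend is expected to reject
     # the duplicates with an "is not unique" RuntimeError.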
     with self.assertRaisesRegex(RuntimeError, "is not unique"):
         store, _, _ = next(torch.distributed.rendezvous(
             self.init_method, rank=self.rank, world_size=self.world_size
         ))
         rpc._init_rpc_backend(
             backend=self.rpc_backend,
             store=store,
             name="duplicate_name",
             rank=self.rank,
             world_size=self.world_size,
             rpc_backend_options=self.rpc_backend_options,
         )
     rpc.shutdown()
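Example #1 exercises the private `rpc._init_rpc_backend` helper directly after a manual `torch.distributed.rendezvous`. For comparison, here is a minimal sketch of how the same duplicate-name check surfaces through the public `torch.distributed.rpc.init_rpc` entry point; the `init_worker` helper and its arguments are placeholders for this page, not part of the original test:

    # Minimal sketch, not taken from the test suite: init_rpc performs the
    # rendezvous and worker-name registration internally, so a name already
    # claimed by another process raises the same "is not unique" RuntimeError.
    import torch.distributed.rpc as rpc

    def init_worker(name, rank, world_size, init_method):
        options = rpc.TensorPipeRpcBackendOptions(init_method=init_method)
        rpc.init_rpc(
            name=name,              # duplicate names across processes fail here
            rank=rank,
            world_size=world_size,
            rpc_backend_options=options,
        )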
Example #2
 def test_duplicate_name(self):
     with self.assertRaisesRegex(RuntimeError, "is not unique"):
         store, _, _ = next(
             torch.distributed.rendezvous(self.init_method,
                                          rank=self.rank,
                                          world_size=self.world_size))
         rpc._init_rpc_backend(
             backend=self.rpc_backend,
             store=store,
             self_name="duplicate_name",
             self_rank=self.rank,
             worker_name_to_id=self.worker_name_to_id,
         )
     rpc.join_rpc()
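Example #2 appears to come from an older PyTorch release: `_init_rpc_backend` is called with the legacy keyword names `self_name` and `self_rank`, a `worker_name_to_id` mapping takes the place of `world_size`/`rpc_backend_options`, and teardown uses `rpc.join_rpc()` where Example #1 uses `rpc.shutdown()`.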
Example #3
    def test_invalid_names(self):
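        # Attempt to register workers under invalid names and assert on the
        # specific error message for each case; later cases switch to a fresh
        # FileStore path so rendezvous does not collide with an earlier,
        # failed attempt.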
        with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
            store, _, _ = next(
                torch.distributed.rendezvous(self.init_method,
                                             rank=self.rank,
                                             world_size=self.world_size))
            rpc._init_rpc_backend(
                backend=self.rpc_backend,
                store=store,
                name="abc*",
                rank=self.rank,
                world_size=self.world_size,
                rpc_backend_options=self.rpc_backend_options,
            )

        base_file_name = self.file_name
        # Use a different file path for FileStore to avoid rendezvous mismatch.
        self.file_name = base_file_name + "1"

        with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
            store, _, _ = next(
                torch.distributed.rendezvous(self.init_method,
                                             rank=self.rank,
                                             world_size=self.world_size))
            rpc._init_rpc_backend(
                backend=self.rpc_backend,
                store=store,
                name=" ",
                rank=self.rank,
                world_size=self.world_size,
                rpc_backend_options=self.rpc_backend_options,
            )

        # Use a different file path for FileStore to avoid rendezvous mismatch.
        self.file_name = base_file_name + "2"
        with self.assertRaisesRegex(RuntimeError, "must be non-empty"):
            store, _, _ = next(
                torch.distributed.rendezvous(self.init_method,
                                             rank=self.rank,
                                             world_size=self.world_size))
            rpc._init_rpc_backend(
                backend=self.rpc_backend,
                store=store,
                name="",
                rank=self.rank,
                world_size=self.world_size,
                rpc_backend_options=self.rpc_backend_options,
            )

        # Use a different file path for FileStore to avoid rendezvous mismatch.
        self.file_name = base_file_name + "3"
        # If the number in the message does not match, it is likely that the
        # value of MAX_NAME_LEN in RPC WorkerInfo has changed.
        with self.assertRaisesRegex(RuntimeError, "shorter than 128"):
            store, _, _ = next(
                torch.distributed.rendezvous(self.init_method,
                                             rank=self.rank,
                                             world_size=self.world_size))
            rpc._init_rpc_backend(
                backend=self.rpc_backend,
                store=store,
                name="".join(["a" for i in range(500)]),
                rank=self.rank,
                world_size=self.world_size,
                rpc_backend_options=self.rpc_backend_options,
            )

        from torch.distributed.rpc.api import _agent
        self.assertEqual(_agent, None)
        # wait_all_workers() should not do anything as _agent is None
        rpc.wait_all_workers()
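The error messages asserted above imply three constraints on worker names: a restricted character set, non-emptiness, and a length below `MAX_NAME_LEN` (128). The sketch below mirrors those checks for illustration only; the exact allowed-character set is an assumption inferred from the "Worker name must match" message, not copied from the PyTorch source.

    # Illustrative approximation of the name checks this test exercises; the
    # character class is assumed, and MAX_NAME_LEN comes from the
    # "shorter than 128" assertion above.
    import re

    MAX_NAME_LEN = 128
    _NAME_PATTERN = re.compile(r"^[A-Za-z0-9_:-]+$")  # assumed character set

    def check_worker_name(name: str) -> None:
        if not name:
            raise RuntimeError("Worker name must be non-empty")
        if len(name) >= MAX_NAME_LEN:
            raise RuntimeError(f"Worker name must be shorter than {MAX_NAME_LEN} chars")
        if not _NAME_PATTERN.match(name):
            raise RuntimeError("Worker name must match ^[A-Za-z0-9_:-]+$")

    # The names the test expects to be rejected:
    for bad in ("abc*", " ", "", "a" * 500):
        try:
            check_worker_name(bad)
        except RuntimeError as err:
            print(f"{bad!r}: {err}")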