Example #1
    def test_reinit(self):
        rpc.init_rpc(
            name="worker{}".format(self.rank),
            backend=self.rpc_backend,
            init_method=self.init_method,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=self.rpc_backend_options,
        )

        # This is needed for the `dist.barrier` below.
        # For `RpcAgent`s other than `ProcessGroupAgent`,
        # no `_default_pg` is initialized, so create one explicitly.
        if not dist.is_initialized():
            dist.init_process_group(
                backend="gloo",
                init_method=self.init_method,
                rank=self.rank,
                world_size=self.world_size,
            )
        # Wait for all init to complete.
        dist.barrier()

        with self.assertRaisesRegex(RuntimeError, "is already initialized"):
            rpc.init_rpc(
                name="worker{}".format(self.rank),
                backend=self.rpc_backend,
                init_method=self.init_method,
                rank=self.rank,
                world_size=self.world_size,
                rpc_backend_options=self.rpc_backend_options,
            )
        rpc.wait_all_workers()
Example #2
    def test_wait_all_workers(self):
        # Initialize RPC.
        rpc.init_rpc(
            name="worker%d" % self.rank,
            backend=self.rpc_backend,
            init_method=self.init_method,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=self.rpc_backend_options,
        )

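        # Send a blocking RPC to the next worker in the ring and verify the result.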
        n = self.rank + 1
        dst_rank = n % self.world_size
        ret = rpc.rpc_sync(
            "worker{}".format(dst_rank),
            torch.add,
            args=(torch.ones(n, n), torch.ones(n, n)),
        )
        self.assertEqual(ret, torch.ones(n, n) * 2)
        rpc.wait_all_workers()

        with self.assertRaisesRegex(RuntimeError,
                                    "^RPC has not been initialized"):
            rpc.rpc_sync(
                "worker{}".format(dst_rank),
                torch.add,
                args=(torch.ones(n, n), torch.ones(n, n)),
            )

        # it's safe to call wait_all_workers() multiple times
        rpc.wait_all_workers()
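
Distilled from the test above, a minimal standalone usage of this API might look like the sketch below. It is only a sketch: the worker names, the tcp:// init_method address, and the use of torch.multiprocessing to launch both ranks are illustrative assumptions, and backend / rpc_backend_options are left at their defaults rather than taken from the test fixture.

import torch
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp


def run(rank, world_size, init_method="tcp://127.0.0.1:29500"):
    # Join the RPC group under a unique worker name.
    rpc.init_rpc(
        name="worker{}".format(rank),
        init_method=init_method,
        rank=rank,
        world_size=world_size,
    )
    if rank == 0:
        # Blocking RPC to the other worker, mirroring test_wait_all_workers.
        ret = rpc.rpc_sync(
            "worker1", torch.add, args=(torch.ones(2, 2), torch.ones(2, 2))
        )
        print(ret)
    # Block until every worker is done, then shut the local agent down.
    rpc.wait_all_workers()


if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2, join=True)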
Example #3
    def new_test_method(self, *arg, **kwargs):
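        # `old_test_method`, `setup_rpc`, and `clean_shutdown` are closure
        # variables from the enclosing decorator; `MASTER_RANK`,
        # `_TERMINATION_SIGNAL`, and the report/termination helpers come from
        # the surrounding module and are not shown in this snippet.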
        self.worker_id = self.rank

        if setup_rpc:
            global _ALL_NODE_NAMES
            _ALL_NODE_NAMES = {
                "worker{}".format(rank)
                for rank in range(self.world_size)
            }

            rpc.init_rpc(
                name="worker%d" % self.rank,
                backend=self.rpc_backend,
                init_method=self.init_method,
                rank=self.rank,
                world_size=self.world_size,
                rpc_backend_options=self.rpc_backend_options,
            )

        return_value = old_test_method(self, *arg, **kwargs)

        if setup_rpc:
            if clean_shutdown:
                # Follower reports done.
                if self.rank == MASTER_RANK:
                    on_master_follower_report_done(
                        "worker{}".format(MASTER_RANK))
                else:
                    rpc.rpc_async(
                        "worker{}".format(MASTER_RANK),
                        on_master_follower_report_done,
                        args=("worker{}".format(self.rank), ),
                    )

                # Master waits for followers to report done.
                # Follower waits for master's termination command.
                _TERMINATION_SIGNAL.wait()
                if self.rank == MASTER_RANK:
                    # Master sends termination command.
                    futs = []
                    for dst_rank in range(self.world_size):
                        # torch.distributed.rpc module does not support sending to self.
                        if dst_rank == MASTER_RANK:
                            continue
                        dst_name = "worker{}".format(dst_rank)
                        fut = rpc.rpc_async(dst_name,
                                            set_termination_signal,
                                            args=())
                        futs.append(fut)
                    for fut in futs:
                        assert fut.wait() is None, "Sending termination signal failed."

            # Close RPC. This is needed even without a clean shutdown, because
            # the RPC agent must still be shut down; otherwise tests would fail
            # since the agent's threads, locks, and condition variables are not
            # properly terminated.
            rpc.wait_all_workers()

        return return_value
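
For context, the wrapper above is the inner function of a test decorator. A minimal sketch of the kind of decorator that could produce it is shown below; the decorator name dist_init and its parameter defaults are assumptions, and only the closure variables (old_test_method, setup_rpc, clean_shutdown) are taken from the snippet itself.

import functools


def dist_init(old_test_method=None, setup_rpc=True, clean_shutdown=True):
    # Support both @dist_init and @dist_init(setup_rpc=False, ...) usage.
    if old_test_method is None:
        return functools.partial(
            dist_init, setup_rpc=setup_rpc, clean_shutdown=clean_shutdown
        )

    @functools.wraps(old_test_method)
    def new_test_method(self, *arg, **kwargs):
        # Body as in Example #3: optionally call rpc.init_rpc, run the wrapped
        # test, coordinate a clean shutdown, then call rpc.wait_all_workers().
        return old_test_method(self, *arg, **kwargs)

    return new_test_method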
Example #4
    def _test_rref_leak(self, ignore_leak=False):
        rpc.init_rpc(
            name="worker{}".format(self.rank),
            backend=self.rpc_backend,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=self.rpc_backend_options,
        )

        # This is needed for the `dist.barrier` below.
        # For `RpcAgent`s other than `ProcessGroupAgent`,
        # no `_default_pg` is initialized, so create one explicitly.
        if not dist.is_initialized():
            dist.init_process_group(
                backend="gloo",
                init_method=self.init_method,
                rank=self.rank,
                world_size=self.world_size,
            )
        # Wait for all init to complete.
        dist.barrier()

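        # Create a remote reference on the next worker; shutting down while it
        # may still be outstanding exercises the RRef leak check (see
        # `ignore_leak` below).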
        rref = rpc.remote("worker{}".format((self.rank + 1) % self.world_size),
                          torch.add,
                          args=(torch.ones(2, 2), 1))

        if ignore_leak:
            import torch.distributed.rpc.api as api
            api._ignore_rref_leak = True

        rpc.wait_all_workers()
Example #5
    def test_duplicate_name(self):
        with self.assertRaisesRegex(RuntimeError, "is not unique"):
            # Every rank registers with the same name, so the RPC backend
            # should reject the duplicate registration.
            store, _, _ = next(torch.distributed.rendezvous(
                self.init_method, rank=self.rank, world_size=self.world_size
            ))
            rpc._init_rpc_backend(
                backend=self.rpc_backend,
                store=store,
                name="duplicate_name",
                rank=self.rank,
                world_size=self.world_size,
                rpc_backend_options=self.rpc_backend_options,
            )
        rpc.wait_all_workers()
Example #6
    def test_get_rpc_timeout(self):
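        # `timedelta` here is `datetime.timedelta`; the import is not shown in
        # this snippet.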
        timeout = timedelta(seconds=1)

        # A new `RpcBackendOptions` is constructed
        # when accessing `self.rpc_backend_options`.
        rpc_backend_options = self.rpc_backend_options
        rpc_backend_options.rpc_timeout = timeout

        rpc.init_rpc(
            name="worker{}".format(self.rank),
            backend=self.rpc_backend,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=rpc_backend_options,
        )
        set_timeout = rpc.get_rpc_timeout()
        self.assertEqual(timeout, set_timeout)
        rpc.wait_all_workers()
Example #7
    def test_invalid_names(self):
        with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
            store, _, _ = next(
                torch.distributed.rendezvous(self.init_method,
                                             rank=self.rank,
                                             world_size=self.world_size))
            rpc._init_rpc_backend(
                backend=self.rpc_backend,
                store=store,
                name="abc*",
                rank=self.rank,
                world_size=self.world_size,
                rpc_backend_options=self.rpc_backend_options,
            )

        base_file_name = self.file_name
        # Use a different file path for FileStore to avoid rendezvous mismatch.
        self.file_name = base_file_name + "1"

        with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
            store, _, _ = next(
                torch.distributed.rendezvous(self.init_method,
                                             rank=self.rank,
                                             world_size=self.world_size))
            rpc._init_rpc_backend(
                backend=self.rpc_backend,
                store=store,
                name=" ",
                rank=self.rank,
                world_size=self.world_size,
                rpc_backend_options=self.rpc_backend_options,
            )

        # Use a different file path for FileStore to avoid rendezvous mismatch.
        self.file_name = base_file_name + "2"
        with self.assertRaisesRegex(RuntimeError, "must be non-empty"):
            store, _, _ = next(
                torch.distributed.rendezvous(self.init_method,
                                             rank=self.rank,
                                             world_size=self.world_size))
            rpc._init_rpc_backend(
                backend=self.rpc_backend,
                store=store,
                name="",
                rank=self.rank,
                world_size=self.world_size,
                rpc_backend_options=self.rpc_backend_options,
            )

        # Use a different file path for FileStore to avoid rendezvous mismatch.
        self.file_name = base_file_name + "3"
        # If the number in the message does not match, it is likely that the
        # value of MAX_NAME_LEN in RPC WorkerInfo has changed.
        with self.assertRaisesRegex(RuntimeError, "shorter than 128"):
            store, _, _ = next(
                torch.distributed.rendezvous(self.init_method,
                                             rank=self.rank,
                                             world_size=self.world_size))
            rpc._init_rpc_backend(
                backend=self.rpc_backend,
                store=store,
                name="".join(["a" for i in range(500)]),
                rank=self.rank,
                world_size=self.world_size,
                rpc_backend_options=self.rpc_backend_options,
            )

        from torch.distributed.rpc.api import _agent
        self.assertEqual(_agent, None)
        # wait_all_workers() should not do anything as _agent is None
        rpc.wait_all_workers()
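
The three error messages asserted above imply the worker-name rules enforced by the backend: the name must match an allowed pattern, must be non-empty, and must be shorter than 128 characters (MAX_NAME_LEN). A client-side pre-check could look like the following sketch; the exact character class is an assumption, inferred only from "abc*" and " " being rejected here.

import re

# 128 mirrors the "shorter than 128" message above; the character class is an
# assumption consistent with "abc*" and " " being rejected by the backend.
_MAX_NAME_LEN = 128
_NAME_RE = re.compile(r"^[A-Za-z0-9_:-]+$")


def check_worker_name(name):
    if not name:
        raise ValueError("worker name must be non-empty")
    if len(name) >= _MAX_NAME_LEN:
        raise ValueError(
            "worker name must be shorter than {}".format(_MAX_NAME_LEN)
        )
    if not _NAME_RE.match(name):
        raise ValueError("invalid worker name: {!r}".format(name))
    return name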