Example 1
 def test_reinit(self):
     dist.init_process_group(
         backend=dist.Backend.GLOO,
         init_method=self.init_method,
         rank=self.rank,
         world_size=self.world_size,
     )
     rpc.init_model_parallel(
         self_name="worker{}".format(self.rank),
         backend=TEST_CONFIG.rpc_backend,
         init_method=self.init_method,
         self_rank=self.rank,
         worker_name_to_id=self.worker_name_to_id,
     )
     # Wait for all init to complete.
     dist.barrier()
     with self.assertRaisesRegex(RuntimeError, "is already initialized"):
         rpc.init_model_parallel(
             self_name="worker{}".format(self.rank),
             backend=TEST_CONFIG.rpc_backend,
             init_method=self.init_method,
             self_rank=self.rank,
             worker_name_to_id=self.worker_name_to_id,
         )
     rpc.join_rpc()
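All of these snippets are excerpts from PyTorch's distributed RPC test suite, and they rely on fixture attributes (self.rank, self.world_size, self.init_method, self.worker_name_to_id) and a TEST_CONFIG object defined elsewhere in that suite. The following is a minimal sketch of plausible stand-in values; the file-based rendezvous and the "PROCESS_GROUP" backend name are assumptions, not taken from the excerpts.

    import tempfile
    from types import SimpleNamespace

    # Hypothetical stand-ins for the fixtures the excerpts assume; the real
    # test suite derives these from its multiprocess test base class.
    world_size = 4
    rank = 0  # each spawned worker process uses its own rank
    # File-based rendezvous; a tcp://host:port URL would work as well.
    init_method = "file://{}".format(tempfile.NamedTemporaryFile(delete=False).name)
    worker_name_to_id = {"worker{}".format(r): r for r in range(world_size)}
    # The tests read the backend from a TEST_CONFIG object; "PROCESS_GROUP" is
    # assumed here and may not match the suite's actual configuration.
    TEST_CONFIG = SimpleNamespace(rpc_backend_name="PROCESS_GROUP")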
Example 2
    def test_reinit(self):
        rpc.init_model_parallel(
            self_name="worker{}".format(self.rank),
            backend=rpc.backend_registry.BackendType[TEST_CONFIG.rpc_backend_name],
            init_method=self.init_method,
            self_rank=self.rank,
            worker_name_to_id=self.worker_name_to_id,
        )

        # This is needed for the `dist.barrier` below.
        # For `RpcAgent`s other than `ProcessGroupAgent`,
        # no `_default_pg` is initialized.
        if not dist.is_initialized():
            dist.init_process_group(
                backend="gloo",
                init_method=self.init_method,
                rank=self.rank,
                world_size=self.world_size,
            )
        # Wait for all init to complete.
        dist.barrier()

        with self.assertRaisesRegex(RuntimeError, "is already initialized"):
            rpc.init_model_parallel(
                self_name="worker{}".format(self.rank),
                backend=rpc.backend_registry.BackendType[TEST_CONFIG.rpc_backend_name],
                init_method=self.init_method,
                self_rank=self.rank,
                worker_name_to_id=self.worker_name_to_id,
            )
        rpc.join_rpc()
Example 3
    def test_join_rpc(self):
        # Initialize RPC.
        rpc.init_model_parallel(
            self_name="worker%d" % self.rank,
            backend=rpc.backend_registry.BackendType[TEST_CONFIG.rpc_backend_name],
            init_method=self.init_method,
            self_rank=self.rank,
            worker_name_to_id=self.worker_name_to_id,
        )

        n = self.rank + 1
        dst_rank = n % self.world_size
        ret = rpc.rpc_sync(
            "worker{}".format(dst_rank),
            torch.add,
            args=(torch.ones(n, n), torch.ones(n, n)),
        )
        self.assertEqual(ret, torch.ones(n, n) * 2)
        rpc.join_rpc()

        with self.assertRaisesRegex(RuntimeError, "^RPC has not been initialized"):
            rpc.rpc_sync(
                "worker{}".format(dst_rank),
                torch.add,
                args=(torch.ones(n, n), torch.ones(n, n)),
            )

        # It is safe to call join_rpc() multiple times.
        rpc.join_rpc()
Example 4
 def test_register_rpc_backend_and_init_rpc_backend(
         self, mock_init_rref_context, mock_dist_autograd_init):
     backend_name = "stub_backend"
     rpc.register_backend(backend_name, stub_init_rpc_backend_handler)
     rpc.init_model_parallel(self_name="worker1",
                             backend=backend_name,
                             self_rank=1)
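The stub handlers used here and in Examples 10 and 17 (stub_init_rpc_backend_handler, stub_start_rpc_backend_handler) are not shown in these excerpts. In this version of the API, a registered handler is invoked by init_model_parallel to set up the backend; a hypothetical stub with a deliberately loose signature (the real signature is not shown) could simply return a mock object in place of the agent:

    from unittest import mock

    def stub_start_rpc_backend_handler(*args, **kwargs):
        # Hypothetical stub: accept whatever the backend registry passes
        # through and hand back a mock standing in for the RPC agent.
        return mock.Mock()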
Example 5
 def test_init_invalid_backend(self):
     with self.assertRaisesRegex(RuntimeError, "Unrecognized RPC backend"):
         rpc.init_model_parallel(
             self_name="worker{}".format(self.rank),
             backend="invalid",
             self_rank=self.rank,
             init_method=self.init_method,
         )
Example 6
    def new_test_method(self, *arg, **kwargs):
        self.worker_id = self.rank
        self.worker_name_to_id = {
            "worker{}".format(rank): rank
            for rank in range(self.world_size)
        }

        if setup_model_parallel:
            global _ALL_NODE_NAMES
            _ALL_NODE_NAMES = self.worker_name_to_id.keys()

            # Use enough 'num_send_recv_threads' until we fix https://github.com/pytorch/pytorch/issues/26359
            rpc.init_model_parallel(
                self_name="worker%d" % self.rank,
                backend=rpc.backend_registry.BackendType[
                    TEST_CONFIG.rpc_backend_name],
                init_method=self.init_method,
                self_rank=self.rank,
                worker_name_to_id=self.worker_name_to_id,
                num_send_recv_threads=16,
            )

        return_value = old_test_method(self, *arg, **kwargs)

        if setup_model_parallel and clean_shutdown:
            # Follower reports done.
            if self.rank == MASTER_RANK:
                on_master_follower_report_done("worker{}".format(MASTER_RANK))
            else:
                rpc.rpc_async(
                    "worker{}".format(MASTER_RANK),
                    on_master_follower_report_done,
                    args=("worker{}".format(self.rank), ),
                )

            # Master waits for followers to report done.
            # Follower waits for master's termination command.
            _TERMINATION_SIGNAL.wait()
            if self.rank == MASTER_RANK:
                # Master sends termination command.
                futs = []
                for dst_rank in range(self.world_size):
                    # torch.distributed.rpc module does not support sending to self.
                    if dst_rank == MASTER_RANK:
                        continue
                    dst_name = "worker{}".format(dst_rank)
                    fut = rpc.rpc_async(dst_name,
                                        set_termination_signal,
                                        args=())
                    futs.append(fut)
                for fut in futs:
                    assert fut.wait() is None, "Sending termination signal failed."

            # Close RPC.
            rpc.join_rpc()

        return return_value
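This wrapper, and the similar one in Example 14, depends on a few module-level helpers that are not part of the excerpt: MASTER_RANK, _ALL_NODE_NAMES, _TERMINATION_SIGNAL, on_master_follower_report_done, and set_termination_signal. The sketch below shows one way such helpers could be wired up around a threading.Event; it is an illustration under those assumptions, not the suite's actual definitions.

    import threading

    MASTER_RANK = 0
    _ALL_NODE_NAMES = set()    # populated by the wrapper before the test body runs
    _DONE_NODE_NAMES = set()
    _TERMINATION_SIGNAL = threading.Event()

    def set_termination_signal():
        # Invoked on followers via rpc_async from the master, and locally on
        # the master once every worker has reported done.
        _TERMINATION_SIGNAL.set()

    def on_master_follower_report_done(worker_name):
        # Runs on the master: record that `worker_name` finished its test body
        # and release everyone once all workers have checked in.
        _DONE_NODE_NAMES.add(worker_name)
        if _DONE_NODE_NAMES == set(_ALL_NODE_NAMES):
            set_termination_signal()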
Example 7
 def test_duplicate_name(self):
     dist.init_process_group(backend=dist.Backend.GLOO, init_method=self.init_method)
     with self.assertRaisesRegex(RuntimeError, "is not unique"):
         rpc.init_model_parallel(
             self_name="duplicate_name",
             backend=TEST_CONFIG.rpc_backend,
             self_rank=self.rank,
             init_method=self.init_method,
         )
     rpc.join_rpc()
Example 8
 def test_duplicate_name(self):
     with self.assertRaisesRegex(RuntimeError, "is not unique"):
         rpc.init_model_parallel(
             self_name="duplicate_name",
             backend=rpc.backend_registry.BackendType[TEST_CONFIG.rpc_backend_name],
             init_method=self.init_method,
             self_rank=self.rank,
             worker_name_to_id=self.worker_name_to_id,
         )
     rpc.join_rpc()
Example 9
 def wrapper(self, *arg, **kwargs):
     self.worker_id = self.rank
     dist.init_process_group(backend="gloo", init_method=self.init_method)
     rpc.init_model_parallel(
         self_name="worker%d" % self.rank,
         backend=TEST_CONFIG.backend,
         self_rank=self.rank,
         init_method=self.init_method,
     )
     test_method(self, *arg, **kwargs)
     rpc.join_rpc()
Example 10
 def test_register_rpc_backend_and_start_rpc_backend(
         self, mock_rpc_agent, mock_dist_autograd_init):
     backend_name = "stub_backend"
     rpc.register_backend(backend_name, stub_start_rpc_backend_handler)
     rpc.init_model_parallel(
         self_name="worker1",
         backend=backend_name,
         init_method=self.init_method,
         self_rank=self.rank,
         worker_name_to_id=self.worker_name_to_id,
     )
Example 11
 def wrapper(self, *arg, **kwargs):
     self.worker_id = self.rank
     dist.init_process_group(backend="gloo", init_method=self.init_method)
     # Use enough 'num_send_recv_threads' until we fix https://github.com/pytorch/pytorch/issues/26359
     rpc.init_model_parallel(self_name="worker%d" % self.rank,
                             backend=TEST_CONFIG.rpc_backend,
                             self_rank=self.rank,
                             init_method=self.init_method,
                             num_send_recv_threads=16)
     test_method(self, *arg, **kwargs)
     rpc.join_rpc()
Example 12
 def test_set_rpc_timeout(self):
     timeout = timedelta(seconds=1)
     rpc.init_model_parallel(self_name="worker{}".format(self.rank),
                             backend=rpc.backend_registry.BackendType[
                                 TEST_CONFIG.rpc_backend_name],
                             init_method=self.init_method,
                             self_rank=self.rank,
                             worker_name_to_id=self.worker_name_to_id,
                             rpc_timeout=timeout)
     set_timeout = rpc.get_rpc_timeout()
     self.assertEqual(timeout, set_timeout)
     rpc.join_rpc()
Example 13
    def test_invalid_names(self):
        dist.init_process_group(backend=dist.Backend.GLOO,
                                init_method=self.init_method)

        with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
            rpc.init_model_parallel(self_name="abc*")

        with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
            rpc.init_model_parallel(self_name=" ")

        with self.assertRaisesRegex(RuntimeError, "must be non-empty"):
            rpc.init_model_parallel(self_name="")

        # If the number in the message does not match, it is likely that the
        # value of MAX_NAME_LEN in RPC WorkerInfo has changed.
        with self.assertRaisesRegex(RuntimeError, "shorter than 128"):
            rpc.init_model_parallel(self_name="".join(["a" for _ in range(500)]))

        from torch.distributed.rpc.api import _agent
        self.assertEqual(_agent, None)
        # join_rpc() should not do anything as _agent is None
        rpc.join_rpc()
        # We need this barrier here because although init_process_group is
        # blocking, it does not guarantee that all ranks are done with
        # initialization when the call returns. We did run into issues where
        # rank 3 crashed with a "connection closed by peer" RuntimeError,
        # caused by other ranks exiting before rank 3 was ready. This can be
        # fixed by adding a collective call to sync all processes.
        #
        # We decided not to fix this in init_process_group because it would
        # add extra overhead to the call, and normal use cases won't create a
        # process group and then exit without doing anything. Hence, it is not
        # worth introducing the overhead just for this test case.
        dist.barrier()
Example 14
    def wrapper(self, *arg, **kwargs):
        self.worker_id = self.rank
        global _ALL_NODE_NAMES
        _ALL_NODE_NAMES = {
            "worker{}".format(rank)
            for rank in range(self.world_size)
        }

        # Initialize RPC.
        dist.init_process_group(backend="gloo", init_method=self.init_method)
        # Use enough 'num_send_recv_threads' until we fix https://github.com/pytorch/pytorch/issues/26359
        rpc.init_model_parallel(self_name="worker%d" % self.rank,
                                backend=TEST_CONFIG.rpc_backend,
                                self_rank=self.rank,
                                init_method=self.init_method,
                                num_send_recv_threads=16)
        test_method(self, *arg, **kwargs)

        # Follower reports done.
        if self.rank == MASTER_RANK:
            on_master_follower_report_done("worker{}".format(MASTER_RANK))
        else:
            rpc.rpc_async(
                "worker{}".format(MASTER_RANK),
                on_master_follower_report_done,
                args=("worker{}".format(self.rank), ),
            )

        # Master waits for followers to report done.
        # Follower waits for master's termination command.
        _TERMINATION_SIGNAL.wait()
        if self.rank == MASTER_RANK:
            # Master sends termination command.
            futs = []
            for dst_rank in range(self.world_size):
                # torch.distributed.rpc module does not support sending to self.
                if dst_rank == MASTER_RANK:
                    continue
                dst_name = "worker{}".format(dst_rank)
                fut = rpc.rpc_async(
                    dst_name,
                    set_termination_signal,
                    args=(),
                )
                futs.append(fut)
            for fut in futs:
                assert fut.wait() is None, "Sending termination signal failed."

        # Close RPC.
        rpc.join_rpc()
Example 15
    def test_invalid_names(self):
        with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
            rpc.init_model_parallel(
                self_name="abc*",
                backend=rpc.backend_registry.BackendType[
                    TEST_CONFIG.rpc_backend_name],
                init_method=self.init_method,
                self_rank=self.rank,
                worker_name_to_id=self.worker_name_to_id,
                num_send_recv_threads=16,
            )

        base_file_name = self.file_name
        # Use a different file path for FileStore to avoid rendezvous mismatch.
        self.file_name = base_file_name + "1"
        with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
            rpc.init_model_parallel(
                self_name=" ",
                backend=rpc.backend_registry.BackendType[
                    TEST_CONFIG.rpc_backend_name],
                init_method=self.init_method,
                self_rank=self.rank,
                worker_name_to_id=self.worker_name_to_id,
                num_send_recv_threads=16,
            )

        # Use a different file path for FileStore to avoid rendezvous mismatch.
        self.file_name = base_file_name + "2"
        with self.assertRaisesRegex(RuntimeError, "must be non-empty"):
            rpc.init_model_parallel(
                self_name="",
                backend=rpc.backend_registry.BackendType[
                    TEST_CONFIG.rpc_backend_name],
                init_method=self.init_method,
                self_rank=self.rank,
                worker_name_to_id=self.worker_name_to_id,
                num_send_recv_threads=16,
            )

        # Use a different file path for FileStore to avoid rendezvous mismatch.
        self.file_name = base_file_name + "3"
        # If the number in the message does not match, it is likely that the
        # value of MAX_NAME_LEN in RPC WorkerInfo has changed.
        with self.assertRaisesRegex(RuntimeError, "shorter than 128"):
            rpc.init_model_parallel(
                self_name="".join(["a" for _ in range(500)]),
                backend=rpc.backend_registry.BackendType[
                    TEST_CONFIG.rpc_backend_name],
                init_method=self.init_method,
                self_rank=self.rank,
                worker_name_to_id=self.worker_name_to_id,
                num_send_recv_threads=16,
            )

        from torch.distributed.rpc.api import _agent
        self.assertEqual(_agent, None)
        # join_rpc() should not do anything as _agent is None
        rpc.join_rpc()
Example 16
 def test_reinit(self):
     dist.init_process_group(backend=dist.Backend.GLOO, init_method=self.init_method)
     rpc.init_model_parallel(
         self_name="worker{}".format(self.rank),
         backend=TEST_CONFIG.rpc_backend,
         self_rank=self.rank,
         init_method=self.init_method,
     )
     with self.assertRaisesRegex(RuntimeError, "is already initialized"):
         rpc.init_model_parallel(
             self_name="worker{}".format(self.rank),
             backend=TEST_CONFIG.rpc_backend,
             self_rank=self.rank,
             init_method=self.init_method,
         )
     rpc.join_rpc()
Example 17
    def test_register_rpc_backend_and_start_rpc_backend(
            self, mock_rpc_agent, mock_dist_autograd_init):
        backend_name = "stub_backend"

        backend = rpc.backend_registry.register_backend(
            backend_name, stub_start_rpc_backend_handler)

        with self.assertRaisesRegex(RuntimeError,
                                    "^RPC backend .+: already registered$"):
            rpc.backend_registry.register_backend(
                backend_name, stub_start_rpc_backend_handler)

        rpc.init_model_parallel(
            self_name="worker1",
            backend=backend,
            init_method=self.init_method,
            self_rank=self.rank,
            worker_name_to_id=self.worker_name_to_id,
        )
Example 18
    def test_invalid_names(self):
        dist.init_process_group(backend=dist.Backend.GLOO, init_method=self.init_method)

        with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
            rpc.init_model_parallel(self_name="abc*")

        with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
            rpc.init_model_parallel(self_name=" ")

        with self.assertRaisesRegex(RuntimeError, "must be non-empty"):
            rpc.init_model_parallel(self_name="")

        # If the number in the message does not match, it is likely that the
        # value of MAX_NAME_LEN in RPC WorkerInfo has changed.
        with self.assertRaisesRegex(RuntimeError, "shorter than 128"):
            rpc.init_model_parallel(
                self_name="".join(["a" for _ in range(500)]),
                backend=TEST_CONFIG.rpc_backend,
                self_rank=self.rank,
                init_method=self.init_method,
            )
        rpc.join_rpc()