Code Example #1
File: rpc_test.py  Project: vincentqb/pytorch
    def test_invalid_names(self):
        dist.init_process_group(backend=dist.Backend.GLOO, init_method=self.init_method)

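        # Names containing disallowed characters or whitespace should be rejected.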
        with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
            rpc.init_model_parallel(self_name="abc*")

        with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
            rpc.init_model_parallel(self_name=" ")

        with self.assertRaisesRegex(RuntimeError, "must be non-empty"):
            rpc.init_model_parallel(self_name="")

        # If the number in the message does not match, it is likely that the
        # value of MAX_NAME_LEN in RPC WorkerInfo has changed.
        with self.assertRaisesRegex(RuntimeError, "shorter than 128"):
            rpc.init_model_parallel(
                self_name="".join(["a" for _ in range(500)]),
                backend=TEST_CONFIG.rpc_backend,
                self_rank=self.rank,
                init_method=self.init_method,
            )
        rpc.join_rpc()
Code Example #2
File: rpc_test.py  Project: wywywy01/pytorch
    def test_reinit(self):
        dist.init_process_group(
            backend=dist.Backend.GLOO,
            init_method=self.init_method,
            rank=self.rank,
            world_size=self.world_size,
        )
        rpc.init_model_parallel(
            self_name="worker{}".format(self.rank),
            backend=TEST_CONFIG.rpc_backend,
            init_method=self.init_method,
            self_rank=self.rank,
            worker_name_to_id=self.worker_name_to_id,
        )
        # A second init_model_parallel call on the same worker must fail.
        with self.assertRaisesRegex(RuntimeError, "is already initialized"):
            rpc.init_model_parallel(
                self_name="worker{}".format(self.rank),
                backend=TEST_CONFIG.rpc_backend,
                init_method=self.init_method,
                self_rank=self.rank,
                worker_name_to_id=self.worker_name_to_id,
            )
        rpc.join_rpc()
Code Example #3
File: rpc_test.py  Project: wywywy01/pytorch
    def test_join_rpc(self):
        # Initialize RPC.
        dist.init_process_group(
            backend="gloo",
            init_method=self.init_method,
            rank=self.rank,
            world_size=self.world_size,
        )
        rpc.init_model_parallel(
            self_name="worker%d" % self.rank,
            backend=TEST_CONFIG.rpc_backend,
            init_method=self.init_method,
            self_rank=self.rank,
            worker_name_to_id=self.worker_name_to_id,
        )

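        # Send a synchronous RPC to the next worker in the ring and verify the result.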
        n = self.rank + 1
        dst_rank = n % self.world_size
        ret = rpc.rpc_sync(
            "worker{}".format(dst_rank),
            torch.add,
            args=(torch.ones(n, n), torch.ones(n, n)),
        )
        self.assertEqual(ret, torch.ones(n, n) * 2)
        rpc.join_rpc()

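        # Once RPC has been shut down on this worker, further RPC calls should fail.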
        with self.assertRaisesRegex(RuntimeError,
                                    "^RPC has not been initialized"):
            rpc.rpc_sync(
                "worker{}".format(dst_rank),
                torch.add,
                args=(torch.ones(n, n), torch.ones(n, n)),
            )

        # It is safe to call join_rpc() multiple times.
        rpc.join_rpc()
Code Example #4
File: dist_utils.py  Project: wywywy01/pytorch
        def new_test_method(self, *arg, **kwargs):
            self.worker_id = self.rank
            self.worker_name_to_id = {
                "worker{}".format(rank): rank
                for rank in range(self.world_size)
            }

            if setup_model_parallel:
                global _ALL_NODE_NAMES
                _ALL_NODE_NAMES = self.worker_name_to_id.keys()

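                # Initialize the default process group before setting up RPC.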
                dist.init_process_group(
                    backend="gloo",
                    init_method=self.init_method,
                    rank=self.rank,
                    world_size=self.world_size,
                )
                # Use a generous num_send_recv_threads until
                # https://github.com/pytorch/pytorch/issues/26359 is fixed.
                rpc.init_model_parallel(
                    self_name="worker%d" % self.rank,
                    backend=TEST_CONFIG.rpc_backend,
                    init_method=self.init_method,
                    self_rank=self.rank,
                    worker_name_to_id=self.worker_name_to_id,
                    num_send_recv_threads=16,
                )

            old_test_method(self, *arg, **kwargs)

            if setup_model_parallel:
                # Follower reports done.
                if self.rank == MASTER_RANK:
                    on_master_follower_report_done(
                        "worker{}".format(MASTER_RANK))
                else:
                    rpc.rpc_async(
                        "worker{}".format(MASTER_RANK),
                        on_master_follower_report_done,
                        args=("worker{}".format(self.rank), ),
                    )

                # Master waits for followers to report done.
                # Follower waits for master's termination command.
                _TERMINATION_SIGNAL.wait()
                if self.rank == MASTER_RANK:
                    # Master sends termination command.
                    futs = []
                    for dst_rank in range(self.world_size):
                        # torch.distributed.rpc module does not support sending to self.
                        if dst_rank == MASTER_RANK:
                            continue
                        dst_name = "worker{}".format(dst_rank)
                        fut = rpc.rpc_async(dst_name,
                                            set_termination_signal,
                                            args=())
                        futs.append(fut)
                    for fut in futs:
                        assert fut.wait() is None, "Sending termination signal failed."

                # Close RPC.
                rpc.join_rpc()
Code Example #5
File: rpc_test.py  Project: zhuyawen/pytorch
    def test_invalid_names(self):
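        # Each case below re-runs rendezvous to get a fresh store before
        # calling the low-level _init_rpc entry point with an invalid name.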
        with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
            store, _, _ = next(torch.distributed.rendezvous(
                self.init_method, rank=self.rank, world_size=self.world_size
            ))
            rpc._init_rpc(
                backend=rpc.backend_registry.BackendType[TEST_CONFIG.rpc_backend_name],
                store=store,
                self_name="abc*",
                self_rank=self.rank,
                worker_name_to_id=self.worker_name_to_id,
                num_send_recv_threads=16,
            )

        base_file_name = self.file_name
        # Use a different file path for FileStore to avoid rendezvous mismatch.
        self.file_name = base_file_name + "1"

        with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
            store, _, _ = next(torch.distributed.rendezvous(
                self.init_method, rank=self.rank, world_size=self.world_size
            ))
            rpc._init_rpc(
                backend=rpc.backend_registry.BackendType[TEST_CONFIG.rpc_backend_name],
                store=store,
                self_name=" ",
                self_rank=self.rank,
                worker_name_to_id=self.worker_name_to_id,
                num_send_recv_threads=16,
            )

        # Use a different file path for FileStore to avoid rendezvous mismatch.
        self.file_name = base_file_name + "2"
        with self.assertRaisesRegex(RuntimeError, "must be non-empty"):
            store, _, _ = next(torch.distributed.rendezvous(
                self.init_method, rank=self.rank, world_size=self.world_size
            ))
            rpc._init_rpc(
                backend=rpc.backend_registry.BackendType[TEST_CONFIG.rpc_backend_name],
                store=store,
                self_name="",
                self_rank=self.rank,
                worker_name_to_id=self.worker_name_to_id,
                num_send_recv_threads=16,
            )

        # Use a different file path for FileStore to avoid rendezvous mismatch.
        self.file_name = base_file_name + "3"
        # If the number in the message does not match, it is likely that the
        # value of MAX_NAME_LEN in RPC WorkerInfo has changed.
        with self.assertRaisesRegex(RuntimeError, "shorter than 128"):
            store, _, _ = next(torch.distributed.rendezvous(
                self.init_method, rank=self.rank, world_size=self.world_size
            ))
            rpc._init_rpc(
                backend=rpc.backend_registry.BackendType[TEST_CONFIG.rpc_backend_name],
                store=store,
                self_name="".join(["a" for i in range(500)]),
                self_rank=self.rank,
                worker_name_to_id=self.worker_name_to_id,
                num_send_recv_threads=16,
            )

        from torch.distributed.rpc.api import _agent
        self.assertEqual(_agent, None)
        # join_rpc() should not do anything as _agent is None
        rpc.join_rpc()
Code Example #6
File: rpc_test.py  Project: wywywy01/pytorch
    def test_invalid_names(self):
        dist.init_process_group(
            backend=dist.Backend.GLOO,
            init_method=self.init_method,
            rank=self.rank,
            world_size=self.world_size,
        )

        with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
            rpc.init_model_parallel(
                self_name="abc*",
                backend=TEST_CONFIG.rpc_backend,
                init_method=self.init_method,
                self_rank=self.rank,
                worker_name_to_id=self.worker_name_to_id,
                num_send_recv_threads=16,
            )

        with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
            rpc.init_model_parallel(
                self_name=" ",
                backend=TEST_CONFIG.rpc_backend,
                init_method=self.init_method,
                self_rank=self.rank,
                worker_name_to_id=self.worker_name_to_id,
                num_send_recv_threads=16,
            )

        with self.assertRaisesRegex(RuntimeError, "must be non-empty"):
            rpc.init_model_parallel(
                self_name="",
                backend=TEST_CONFIG.rpc_backend,
                init_method=self.init_method,
                self_rank=self.rank,
                worker_name_to_id=self.worker_name_to_id,
                num_send_recv_threads=16,
            )

        # If the number in the message does not match, it is likely that the
        # value of MAX_NAME_LEN in RPC WorkerInfo has changed.
        with self.assertRaisesRegex(RuntimeError, "shorter than 128"):
            rpc.init_model_parallel(
                self_name="".join(["a" for _ in range(500)]),
                backend=TEST_CONFIG.rpc_backend,
                init_method=self.init_method,
                self_rank=self.rank,
                worker_name_to_id=self.worker_name_to_id,
                num_send_recv_threads=16,
            )

        from torch.distributed.rpc.api import _agent
        self.assertEqual(_agent, None)
        # join_rpc() should not do anything as _agent is None
        rpc.join_rpc()
        # We need this barrier here because although init_process_group is
        # blocking, it does not guarantee that all ranks are done with
        # initialization after the call returns. We did run into issues where
        # rank 3 crashed with a "connection closed by peer" RuntimeError,
        # caused by other ranks exiting before rank 3 was ready. This can be
        # fixed by adding a collective call to sync all processes.
        #
        # We decided not to fix this inside init_process_group because it
        # would add extra overhead to the call, and normal use cases won't
        # create a process group and then exit without doing anything. Hence,
        # it is not worth introducing the overhead just for this test case.
        dist.barrier()
Code Example #7
    def new_test_method(self, *arg, **kwargs):
        self.worker_id = self.rank
        self.worker_name_to_id = {
            "worker{}".format(rank): rank
            for rank in range(self.world_size)
        }

        if setup_model_parallel:
            global _ALL_NODE_NAMES
            _ALL_NODE_NAMES = self.worker_name_to_id.keys()

            # Use a generous num_send_recv_threads until
            # https://github.com/pytorch/pytorch/issues/26359 is fixed.
            rpc.init_model_parallel(
                self_name="worker%d" % self.rank,
                backend=rpc.backend_registry.BackendType[
                    TEST_CONFIG.rpc_backend_name],
                init_method=self.init_method,
                self_rank=self.rank,
                worker_name_to_id=self.worker_name_to_id,
                num_send_recv_threads=16,
            )

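        # Run the wrapped test body and keep its return value.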
        return_value = old_test_method(self, *arg, **kwargs)

        if setup_model_parallel:
            if clean_shutdown:
                # Follower reports done.
                if self.rank == MASTER_RANK:
                    on_master_follower_report_done(
                        "worker{}".format(MASTER_RANK))
                else:
                    rpc.rpc_async(
                        "worker{}".format(MASTER_RANK),
                        on_master_follower_report_done,
                        args=("worker{}".format(self.rank), ),
                    )

                # Master waits for followers to report done.
                # Follower waits for master's termination command.
                _TERMINATION_SIGNAL.wait()
                if self.rank == MASTER_RANK:
                    # Master sends termination command.
                    futs = []
                    for dst_rank in range(self.world_size):
                        # torch.distributed.rpc module does not support sending to self.
                        if dst_rank == MASTER_RANK:
                            continue
                        dst_name = "worker{}".format(dst_rank)
                        fut = rpc.rpc_async(dst_name,
                                            set_termination_signal,
                                            args=())
                        futs.append(fut)
                    for fut in futs:
                        assert fut.wait() is None, "Sending termination signal failed."

            # Close RPC. Need to do this even if we don't have a clean shutdown
            # since we need to shutdown the RPC agent. If we don't shutdown the
            # RPC agent, tests would fail since RPC agent threads, locks and
            # condition variables are not properly terminated.
            rpc.join_rpc()

        return return_value