コード例 #1
0
ファイル: test_rpc.py プロジェクト: zhuzijing/pytorch
 def test_init_invalid_backend(self):
     """Initializing RPC with an unknown backend name must fail.

     Calls ``dist.init_model_parallel`` with ``backend="invalid"`` and
     expects a ``RuntimeError`` whose message matches
     "Unrecognized RPC backend".
     """
     expected_message = "Unrecognized RPC backend"
     with self.assertRaisesRegex(RuntimeError, expected_message):
         my_name = 'worker{}'.format(self.rank)
         dist.init_model_parallel(
             self_name=my_name,
             backend="invalid",
             self_rank=self.rank,
             init_method=RPC_INIT_URL,
         )
コード例 #2
0
 def wrapper(self):
     """Run ``func`` inside a freshly initialized process group and RPC.

     Creates a ``FileStore`` on ``self.file.name`` sized to
     ``self.world_size``, joins a ``gloo`` process group through it,
     starts RPC for this rank under the name ``worker<rank>``, runs the
     wrapped test, then calls ``dist.join_rpc()``.

     NOTE(review): if ``func`` raises, ``dist.join_rpc()`` is not reached
     (no try/finally here).
     """
     file_store = dist.FileStore(self.file.name, self.world_size)
     dist.init_process_group(
         backend='gloo',
         rank=self.rank,
         world_size=self.world_size,
         store=file_store,
     )
     worker_name = 'worker%d' % self.rank
     dist.init_model_parallel(worker_name)
     func(self)
     dist.join_rpc()
コード例 #3
0
 def test_duplicated_names(self):
     """RPC init must reject a worker name that is not unique.

     The worker name is the constant ``"duplicated_name"``, so every rank
     running this test registers the same name. ``init_model_parallel``
     is expected to raise a ``RuntimeError`` matching "is not unique".
     """
     shared_store = dist.FileStore(self.file.name, self.world_size)
     dist.init_process_group(backend="gloo",
                             rank=self.rank,
                             world_size=self.world_size,
                             store=shared_store)
     # Same literal name on every rank -> collision.
     with self.assertRaisesRegex(RuntimeError, "is not unique"):
         dist.init_model_parallel("duplicated_name")
     dist.join_rpc()
コード例 #4
0
 def test_register_rpc_backend_and_init_rpc_backend(
     self, mock_init_rref_context, mock_dist_autograd_init
 ):
     """Register a stub RPC backend and initialize RPC through it.

     Registers ``stub_init_rpc_backend_handler`` under the name
     ``"stub_backend"``, then calls ``init_model_parallel`` with that
     backend name. The test passes if neither call raises.

     ``mock_init_rref_context`` and ``mock_dist_autograd_init`` are not
     used in the body. Presumably they are injected by ``mock.patch``
     decorators outside this view; verify.
     """
     stub_name = "stub_backend"
     rpc_backend_registry.register_rpc_backend(
         stub_name,
         stub_init_rpc_backend_handler,
     )
     dist.init_model_parallel(
         self_name="worker1",
         backend=stub_name,
         self_rank=1,
     )
コード例 #5
0
 def test_duplicate_name(self):
     """RPC init must fail when the worker name is not unique.

     Every rank uses the literal name ``"duplicate_name"`` together with
     the configured ``TEST_CONFIG.backend``. A ``RuntimeError`` matching
     "is not unique" is expected, after which RPC is joined.
     """
     dist.init_process_group(backend="gloo", init_method=self.init_method)
     with self.assertRaisesRegex(RuntimeError, "is not unique"):
         rpc_kwargs = {
             "self_name": "duplicate_name",
             "backend": TEST_CONFIG.backend,
             "self_rank": self.rank,
             "init_method": self.init_method,
         }
         dist.init_model_parallel(**rpc_kwargs)
     dist.join_rpc()
コード例 #6
0
ファイル: test_rpc.py プロジェクト: zhilanyue/pytorch
 def wrapper(self):
     """Run ``func`` with a ``gloo`` process group and RPC already started.

     The process group is rendezvoused through a ``FileStore`` on
     ``self.file.name``. RPC is started as ``worker<rank>`` using the
     module-level ``BACKEND`` and ``RPC_INIT_URL``. After ``func`` returns,
     ``dist.join_rpc()`` is called.

     NOTE(review): ``dist.join_rpc()`` is skipped if ``func`` raises.
     """
     file_store = dist.FileStore(self.file.name, self.world_size)
     dist.init_process_group(
         backend='gloo',
         rank=self.rank,
         world_size=self.world_size,
         store=file_store,
     )
     rpc_options = {
         'self_name': 'worker%d' % self.rank,
         'backend': BACKEND,
         'self_rank': self.rank,
         'init_method': RPC_INIT_URL,
     }
     dist.init_model_parallel(**rpc_options)
     func(self)
     dist.join_rpc()
コード例 #7
0
ファイル: test_rpc.py プロジェクト: zhilanyue/pytorch
 def test_reinit(self):
     """Expect RPC init to fail when every rank uses the same worker name.

     NOTE(review): despite its name, this test checks the duplicate-name
     error ("is not unique"), not re-initialization. The method name is
     kept so test discovery is unchanged.
     """
     file_store = dist.FileStore(self.file.name, self.world_size)
     dist.init_process_group(backend="gloo",
                             rank=self.rank,
                             world_size=self.world_size,
                             store=file_store)
     with self.assertRaisesRegex(RuntimeError, "is not unique"):
         shared_name = "duplicate_name"
         dist.init_model_parallel(
             self_name=shared_name,
             backend=BACKEND,
             self_rank=self.rank,
             init_method=RPC_INIT_URL,
         )
     dist.join_rpc()
コード例 #8
0
 def wrapper(self, *arg, **kwargs):
     """Run ``test_method`` between RPC setup and ``dist.join_rpc()``.

     Steps:
     1. Copy ``self.rank`` into ``self.worker_id``.
     2. Initialize a ``gloo`` process group and RPC (as ``worker<rank>``
        with ``TEST_CONFIG.backend``), both from ``self.init_method``.
     3. Call ``test_method`` with every extra positional and keyword
        argument.
     4. Join RPC.

     The wrapped method's return value is not propagated.
     """
     self.worker_id = self.rank
     dist.init_process_group(backend="gloo", init_method=self.init_method)

     rpc_kwargs = {
         "self_name": "worker%d" % self.rank,
         "backend": TEST_CONFIG.backend,
         "self_rank": self.rank,
         "init_method": self.init_method,
     }
     dist.init_model_parallel(**rpc_kwargs)

     test_method(self, *arg, **kwargs)
     dist.join_rpc()
コード例 #9
0
ファイル: rpc_test.py プロジェクト: xcheng16/pytorch
 def wrapper(self, *arg, **kwargs):
     """Run ``test_method`` with a process group and RPC already up.

     Builds a ``FileStore`` at ``self.file_name`` sized to
     ``self.world_size`` and joins a ``gloo`` process group through it.
     Starts RPC as ``worker<rank>`` using ``BACKEND`` and ``RPC_INIT_URL``,
     invokes ``test_method`` with all extra arguments, and finally calls
     ``dist.join_rpc()``. The wrapped method's return value is discarded.
     """
     file_store = dist.FileStore(self.file_name, self.world_size)
     pg_options = {
         "backend": "gloo",
         "rank": self.rank,
         "world_size": self.world_size,
         "store": file_store,
     }
     dist.init_process_group(**pg_options)
     dist.init_model_parallel(
         self_name="worker%d" % self.rank,
         backend=BACKEND,
         self_rank=self.rank,
         init_method=RPC_INIT_URL,
     )
     test_method(self, *arg, **kwargs)
     dist.join_rpc()
コード例 #10
0
ファイル: test_rpc.py プロジェクト: zwh930712/pytorch
    def test_invalid_names(self):
        """Reject malformed worker names in ``init_model_parallel``.

        Each invalid name is checked against the ``RuntimeError`` message it
        should produce:

        * ``"abc*"`` and ``" "``: "Worker name must match"
        * ``""``: "must be non-empty"
        * a 500-character name: "shorter than 128"
        """
        store = dist.FileStore(self.file.name, self.world_size)
        dist.init_process_group(backend="gloo",
                                rank=self.rank,
                                world_size=self.world_size,
                                store=store)

        # NOTE(review): the first three calls rely on the default backend,
        # self_rank and init_method of init_model_parallel; only the
        # length check passes them explicitly.
        with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
            dist.init_model_parallel(self_name="abc*")

        with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
            dist.init_model_parallel(self_name=" ")

        with self.assertRaisesRegex(RuntimeError, "must be non-empty"):
            dist.init_model_parallel(self_name="")

        # If the number in the message does not match, it is likely that the
        # value of MAX_NAME_LEN in RPC WorkerId has changed.
        too_long_name = "a" * 500
        with self.assertRaisesRegex(RuntimeError, "shorter than 128"):
            dist.init_model_parallel(self_name=too_long_name,
                                     backend=BACKEND,
                                     self_rank=self.rank,
                                     init_method=RPC_INIT_URL)
        dist.join_rpc()
コード例 #11
0
ファイル: test_rpc.py プロジェクト: zhuzijing/pytorch
 def test_reinit(self):
     """A second ``init_model_parallel`` call on the same rank must fail.

     RPC is started once as ``worker<rank>``. Repeating the identical call
     is expected to raise a ``RuntimeError`` matching
     "is already initialized". RPC is then joined.
     """
     file_store = dist.FileStore(self.file.name, self.world_size)
     dist.init_process_group(backend="gloo",
                             rank=self.rank,
                             world_size=self.world_size,
                             store=file_store)

     def start_rpc():
         # Identical arguments both times; only the second call should fail.
         dist.init_model_parallel(self_name='worker{}'.format(self.rank),
                                  backend=BACKEND,
                                  self_rank=self.rank,
                                  init_method=RPC_INIT_URL)

     start_rpc()
     with self.assertRaisesRegex(RuntimeError, "is already initialized"):
         start_rpc()
     dist.join_rpc()
コード例 #12
0
 def test_reinit(self):
     """Re-initializing RPC on an already-initialized rank must fail.

     Uses ``TEST_CONFIG.backend`` and ``self.init_method``. The first
     ``init_model_parallel`` call succeeds. The identical second call must
     raise a ``RuntimeError`` matching "is already initialized".
     """
     dist.init_process_group(backend="gloo", init_method=self.init_method)

     def init_rpc_once():
         dist.init_model_parallel(
             self_name="worker{}".format(self.rank),
             backend=TEST_CONFIG.backend,
             self_rank=self.rank,
             init_method=self.init_method,
         )

     init_rpc_once()
     with self.assertRaisesRegex(RuntimeError, "is already initialized"):
         init_rpc_once()
     dist.join_rpc()
コード例 #13
0
    def test_invalid_names(self):
        """Reject malformed worker names in ``init_model_parallel``.

        Each invalid name is checked against the ``RuntimeError`` message it
        should produce:

        * ``"abc*"`` and ``" "``: "Worker name must match"
        * ``""``: "must be non-empty"
        * a 500-character name: "shorter than 128"
        """
        dist.init_process_group(backend="gloo", init_method=self.init_method)

        # NOTE(review): the first three calls rely on the default backend,
        # self_rank and init_method of init_model_parallel; only the
        # length check passes them explicitly.
        with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
            dist.init_model_parallel(self_name="abc*")

        with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
            dist.init_model_parallel(self_name=" ")

        with self.assertRaisesRegex(RuntimeError, "must be non-empty"):
            dist.init_model_parallel(self_name="")

        # If the number in the message does not match, it is likely that the
        # value of MAX_NAME_LEN in RPC WorkerInfo has changed.
        too_long_name = "a" * 500
        with self.assertRaisesRegex(RuntimeError, "shorter than 128"):
            dist.init_model_parallel(
                self_name=too_long_name,
                backend=TEST_CONFIG.backend,
                self_rank=self.rank,
                init_method=self.init_method,
            )
        dist.join_rpc()