Example #1
0
 def test_duplicated_names(self):
     # The worker name below is a fixed literal (it does not depend on
     # self.rank); the init call is expected to raise a RuntimeError whose
     # message matches "is not unique".
     rendezvous = dist.FileStore(self.file.name, self.world_size)
     dist.init_process_group(
         backend="gloo",
         rank=self.rank,
         world_size=self.world_size,
         store=rendezvous,
     )
     expect_not_unique = self.assertRaisesRegex(RuntimeError, "is not unique")
     with expect_not_unique:
         dist.init_model_parallel("duplicated_name")
     dist.join_rpc()
Example #2
0
    def test_invalid_names(self):
        # Verifies that malformed worker names make init_model_parallel raise
        # a RuntimeError with the expected message.
        file_store = dist.FileStore(self.file.name, self.world_size)
        dist.init_process_group(
            backend="gloo",
            rank=self.rank,
            world_size=self.world_size,
            store=file_store,
        )

        # (invalid name, regex the error message must match), tried in order.
        rejected_names = (
            ("abc*", "Worker name must match"),
            (" ", "Worker name must match"),
            ("", "must be non-empty"),
        )
        for bad_name, expected_message in rejected_names:
            with self.assertRaisesRegex(RuntimeError, expected_message):
                dist.init_model_parallel(self_name=bad_name)

        # If the number in the message does not match, it is likely that the
        # value of MAX_NAME_LEN in RPC WorkerId has changed.
        with self.assertRaisesRegex(RuntimeError, "shorter than 128"):
            dist.init_model_parallel(
                self_name="a" * 500,
                backend=BACKEND,
                self_rank=self.rank,
                init_method=RPC_INIT_URL,
            )
        dist.join_rpc()
Example #3
0
 def wrapper(self):
     # Sets up the gloo process group (file-backed store) and the RPC layer
     # under the name "worker<rank>", runs the wrapped callable, and then
     # calls join_rpc(). The wrapped callable's return value is discarded.
     file_store = dist.FileStore(self.file.name, self.world_size)
     dist.init_process_group(
         backend='gloo',
         rank=self.rank,
         world_size=self.world_size,
         store=file_store,
     )
     worker_name = 'worker%d' % self.rank
     dist.init_model_parallel(worker_name)
     func(self)
     dist.join_rpc()
Example #4
0
 def wrapper(self):
     # Sets up the gloo process group (file-backed store), initializes RPC
     # via init_rpc under the name "worker<rank>", runs the wrapped
     # callable, and then calls join_rpc(). The wrapped callable's return
     # value is discarded.
     file_store = dist.FileStore(self.file.name, self.world_size)
     dist.init_process_group(
         backend='gloo',
         rank=self.rank,
         world_size=self.world_size,
         store=file_store,
     )
     worker_name = 'worker{}'.format(self.rank)
     dist.init_rpc(worker_name)
     func(self)
     dist.join_rpc()
Example #5
0
 def test_duplicate_name(self):
     # The worker name is a fixed literal (independent of self.rank); the
     # init call is expected to raise a RuntimeError matching
     # "is not unique".
     dist.init_process_group(backend="gloo", init_method=self.init_method)
     rpc_options = {
         "self_name": "duplicate_name",
         "backend": TEST_CONFIG.backend,
         "self_rank": self.rank,
         "init_method": self.init_method,
     }
     with self.assertRaisesRegex(RuntimeError, "is not unique"):
         dist.init_model_parallel(**rpc_options)
     dist.join_rpc()
Example #6
0
 def wrapper(self):
     # Sets up the gloo process group (file-backed store), initializes the
     # RPC layer as "worker<rank>" using BACKEND and RPC_INIT_URL, runs the
     # wrapped callable, and then calls join_rpc(). The wrapped callable's
     # return value is discarded.
     file_store = dist.FileStore(self.file.name, self.world_size)
     dist.init_process_group(
         backend='gloo',
         rank=self.rank,
         world_size=self.world_size,
         store=file_store,
     )
     dist.init_model_parallel(
         self_name='worker%d' % self.rank,
         backend=BACKEND,
         self_rank=self.rank,
         init_method=RPC_INIT_URL,
     )
     func(self)
     dist.join_rpc()
Example #7
0
 def test_reinit(self):
     # NOTE(review): despite the method name, this body checks the
     # duplicate-name error ("is not unique"), not re-initialization --
     # confirm the name is intended.
     file_store = dist.FileStore(self.file.name, self.world_size)
     dist.init_process_group(
         backend="gloo",
         rank=self.rank,
         world_size=self.world_size,
         store=file_store,
     )
     expect_not_unique = self.assertRaisesRegex(RuntimeError, "is not unique")
     with expect_not_unique:
         dist.init_model_parallel(
             self_name="duplicate_name",
             backend=BACKEND,
             self_rank=self.rank,
             init_method=RPC_INIT_URL,
         )
     dist.join_rpc()
Example #8
0
 def wrapper(self, *arg, **kwargs):
     # Records the rank as worker_id, sets up the gloo process group and the
     # RPC layer (both via self.init_method), forwards all arguments to the
     # wrapped test method, and then calls join_rpc(). The wrapped method's
     # return value is discarded.
     self.worker_id = self.rank
     dist.init_process_group(backend="gloo", init_method=self.init_method)
     rpc_options = {
         "self_name": "worker%d" % self.rank,
         "backend": TEST_CONFIG.backend,
         "self_rank": self.rank,
         "init_method": self.init_method,
     }
     dist.init_model_parallel(**rpc_options)
     test_method(self, *arg, **kwargs)
     dist.join_rpc()
Example #9
0
 def wrapper(self, *arg, **kwargs):
     # Sets up the gloo process group (file-backed store at self.file_name)
     # and the RPC layer as "worker<rank>", forwards all arguments to the
     # wrapped test method, and then calls join_rpc(). The wrapped method's
     # return value is discarded.
     file_store = dist.FileStore(self.file_name, self.world_size)
     dist.init_process_group(
         backend="gloo",
         rank=self.rank,
         world_size=self.world_size,
         store=file_store,
     )
     rpc_options = {
         "self_name": "worker%d" % self.rank,
         "backend": BACKEND,
         "self_rank": self.rank,
         "init_method": RPC_INIT_URL,
     }
     dist.init_model_parallel(**rpc_options)
     test_method(self, *arg, **kwargs)
     dist.join_rpc()
Example #10
0
    def test_join_rpc(self):
        # Issues one RPC before join_rpc() and checks its result, then checks
        # that the same RPC raises once join_rpc() has been called.
        size = self.rank + 1
        peer_rank = size % self.world_size

        def add_ones_on_peer():
            # Ask worker<peer_rank> to add two (size x size) tensors of ones.
            return dist.rpc(
                'worker%d' % peer_rank,
                torch.add,
                args=(torch.ones(size, size), torch.ones(size, size)),
            )

        result = add_ones_on_peer()
        self.assertEqual(result, torch.ones(size, size) * 2)
        dist.join_rpc()

        with self.assertRaisesRegex(RuntimeError, "^RPC has not been initialized"):
            add_ones_on_peer()

        # it's safe to call join_rpc() multiple times
        dist.join_rpc()
Example #11
0
 def test_reinit(self):
     # Initializes the RPC layer once, then expects an identical second
     # init call to raise a RuntimeError matching "is already initialized".
     file_store = dist.FileStore(self.file.name, self.world_size)
     dist.init_process_group(
         backend="gloo",
         rank=self.rank,
         world_size=self.world_size,
         store=file_store,
     )
     # Both init calls use exactly the same arguments.
     rpc_options = {
         "self_name": 'worker{}'.format(self.rank),
         "backend": BACKEND,
         "self_rank": self.rank,
         "init_method": RPC_INIT_URL,
     }
     dist.init_model_parallel(**rpc_options)
     with self.assertRaisesRegex(RuntimeError, "is already initialized"):
         dist.init_model_parallel(**rpc_options)
     dist.join_rpc()
Example #12
0
 def test_reinit(self):
     # Initializes the RPC layer once, then expects an identical second
     # init call to raise a RuntimeError matching "is already initialized".
     dist.init_process_group(backend="gloo", init_method=self.init_method)
     # Both init calls use exactly the same arguments.
     rpc_options = {
         "self_name": "worker{}".format(self.rank),
         "backend": TEST_CONFIG.backend,
         "self_rank": self.rank,
         "init_method": self.init_method,
     }
     dist.init_model_parallel(**rpc_options)
     expect_already_initialized = self.assertRaisesRegex(
         RuntimeError, "is already initialized"
     )
     with expect_already_initialized:
         dist.init_model_parallel(**rpc_options)
     dist.join_rpc()
Example #13
0
    def test_invalid_names(self):
        # Verifies that malformed worker names make init_model_parallel raise
        # a RuntimeError with the expected message.
        dist.init_process_group(backend="gloo", init_method=self.init_method)

        # (invalid name, regex the error message must match), tried in order.
        rejected_names = (
            ("abc*", "Worker name must match"),
            (" ", "Worker name must match"),
            ("", "must be non-empty"),
        )
        for bad_name, expected_message in rejected_names:
            with self.assertRaisesRegex(RuntimeError, expected_message):
                dist.init_model_parallel(self_name=bad_name)

        # If the number in the message does not match, it is likely that the
        # value of MAX_NAME_LEN in RPC WorkerInfo has changed.
        overlong_name = "a" * 500
        with self.assertRaisesRegex(RuntimeError, "shorter than 128"):
            dist.init_model_parallel(
                self_name=overlong_name,
                backend=TEST_CONFIG.backend,
                self_rank=self.rank,
                init_method=self.init_method,
            )
        dist.join_rpc()