Exemple #1
0
 def test_remote_same_worker(self):
     """Consume two RRefs on the same worker that owns them.

     Both operand RRefs are created with ``torch.add`` on the destination
     worker, and ``my_rref_function`` is then scheduled on that same worker
     with both RRefs as arguments. The fetched result must equal
     ``torch.ones(n, n) + 4``.
     """
     n = self.rank + 1
     owner = "worker{}".format(n % self.world_size)
     # Operands and consumer all target the same worker.
     lhs = dist.remote(owner, torch.add, args=(torch.ones(n, n), 2))
     rhs = dist.remote(owner, torch.add, args=(torch.ones(n, n), 1))
     combined = dist.remote(owner, my_rref_function, args=(lhs, rhs))
     self.assertEqual(combined.to_here(), torch.ones(n, n) + 4)
Exemple #2
0
 def test_py_rref_args_user_share(self):
     """Hand RRefs owned by one worker to a UDF that runs on another worker.

     The operand RRefs are created on ``worker{(rank + 1) % world_size}``,
     while ``my_rref_function`` runs on ``worker{(rank + 2) % world_size}``.
     The consumer is therefore a non-owning user of both RRefs, unless
     ``world_size`` makes the two ranks coincide.
     """
     n = self.rank + 1
     owner = "worker{}".format(n % self.world_size)
     user = "worker{}".format((n + 1) % self.world_size)
     first = dist.remote(owner, my_function, args=(torch.ones(n, n), 2, 0))
     second = dist.remote(owner, my_function, args=(torch.ones(n, n), 1, 0))
     # The consumer runs on `user`, which receives RRefs it does not own.
     result = dist.remote(user, my_rref_function, args=(first, second))
     self.assertEqual(result.to_here(), torch.ones(n, n) + 4)
Exemple #3
0
    def test_py_rpc_rref_args(self):
        """Pass two remotely created RRefs as arguments to a synchronous RPC.

        Both RRefs are created on the destination worker. ``my_rref_function``
        is then invoked there through ``dist.rpc_sync``, and its return value
        is compared with ``torch.ones(n, n) + 4``.
        """
        n = self.rank + 1
        callee = "worker{}".format(n % self.world_size)
        operands = (
            dist.remote(callee, my_function, args=(torch.ones(n, n), 2, 0)),
            dist.remote(callee, my_function, args=(torch.ones(n, n), 1, 0)),
        )

        value = dist.rpc_sync(callee, my_rref_function, args=operands)

        self.assertEqual(value, torch.ones(n, n) + 4)
Exemple #4
0
 def test_builtin_remote_ret(self):
     """Run builtin ``torch.add`` remotely and fetch it with ``to_here()``."""
     n = self.rank + 1
     dst_name = 'worker{}'.format(n % self.world_size)
     lhs, rhs = torch.ones(n, n), torch.ones(n, n)
     rref = dist.remote(dst_name, torch.add, args=(lhs, rhs))
     self.assertEqual(rref.to_here(), torch.ones(n, n) * 2)
Exemple #5
0
 def test_py_udf_remote(self):
     """Invoke a Python UDF remotely, passing keyword arguments only.

     The remote result must match calling ``my_function`` locally with the
     same three values passed positionally.
     """
     n = self.rank + 1
     a, b, c = n, n + 1, n + 2
     rref = dist.remote(
         "worker{}".format(n % self.world_size),
         my_function,
         kwargs={"a": a, "b": b, "c": c},
     )
     self.assertEqual(rref.to_here(), my_function(a, b, c))
Exemple #6
0
 def test_nested_remote(self):
     """Have a remote UDF issue its own ``dist.remote`` call to another worker.

     ``nested_remote`` runs on ``worker{(rank + 1) % world_size}``. It is
     given the name ``worker{(rank + 2) % world_size}`` as the target of its
     inner remote call.
     """
     n = self.rank + 1
     outer, inner = (
         "worker{}".format((n + hop) % self.world_size) for hop in (0, 1)
     )
     rref = dist.remote(outer, nested_remote, args=(inner, ))
     self.assertEqual(rref.to_here(), torch.ones(2, 2) + 3)
Exemple #7
0
def rref_forward_chain(dst, world_size, rref, ttl):
    """Forward ``rref`` around a ring of workers for ``ttl`` hops.

    While ``ttl`` is positive, this function schedules itself on
    ``worker{dst}``, passing the next rank in the ring and ``ttl - 1``. It
    then returns a one-element list holding the resulting RRef. Otherwise
    (``ttl`` is zero or negative), the value behind ``rref`` is fetched with
    ``to_here()`` and returned directly.

    Args:
        dst: rank of the worker that runs the next hop.
        world_size: number of workers; used to wrap ``dst`` around the ring.
        rref: RRef whose value is fetched at the end of the chain.
        ttl: number of hops remaining.

    Returns:
        ``[RRef]`` while hops remain, otherwise ``rref.to_here()``.
    """
    if ttl > 0:
        following = (dst + 1) % world_size
        hop = dist.remote(
            "worker{}".format(dst),
            rref_forward_chain,
            args=(following, world_size, rref, ttl - 1),
        )
        return [hop]
    return rref.to_here()
Exemple #8
0
 def test_nested_rref(self):
     """Fetch a pair of RRefs that a remote UDF created for another worker.

     ``nested_rref`` runs on ``worker{(rank + 1) % world_size}``. It returns
     two RRefs produced by ``dist.remote`` calls targeting
     ``worker{(rank + 2) % world_size}``, and this worker then fetches each
     one directly.
     """
     n = self.rank + 1
     creator = "worker{}".format(n % self.world_size)
     owner = "worker{}".format((n + 1) % self.world_size)
     rref_of_rrefs = dist.remote(creator, nested_rref, args=(owner, ))
     inner_rrefs = rref_of_rrefs.to_here()
     self.assertEqual(len(inner_rrefs), 2)
     # The first inner RRef holds ones + 1, the second ones + 2.
     for offset, inner in enumerate(inner_rrefs, start=1):
         self.assertEqual(inner.to_here(), torch.ones(2, 2) + offset)
Exemple #9
0
    def test_multi_builtin_remote_ret(self):
        """Issue several remote ``torch.add`` calls before fetching any result.

        ``n`` accumulates across iterations (``n += i``), so every call works
        on a tensor of a different size.
        """
        m = 10
        n = self.rank + 1
        dst_name = 'worker{}'.format(n % self.world_size)
        pending = []
        for i in range(m):
            # Sizes go n0, n0 + 1, n0 + 3, n0 + 6, ...
            n += i
            rref = dist.remote(dst_name, torch.add,
                               args=(torch.ones(n, n), torch.ones(n, n)))
            pending.append((rref, torch.ones(n, n) * 2))

        for rref, expected in pending:
            self.assertEqual(rref.to_here(), expected)
Exemple #10
0
    def test_rref_forward_chain(self):
        """Unwrap ``ttl`` levels of single-element lists of RRefs.

        ``rref_forward_chain`` wraps the RRef created here in ``ttl`` hops.
        Each level is peeled off with ``to_here()`` until the original
        tensor value remains.
        """
        ttl = 8
        n = self.rank + 1
        dst_rank = n % self.world_size

        source = dist.remote("worker{}".format(dst_rank),
                             torch.add,
                             args=(torch.ones(n, n), 1))

        current = rref_forward_chain(dst_rank, self.world_size, source, ttl)
        for _ in range(ttl):
            self.assertEqual(len(current), 1)
            (head,) = current
            current = head.to_here()

        self.assertEqual(current, torch.add(torch.ones(n, n), 1))
Exemple #11
0
    def test_nested_rref_stress(self):
        """Issue 20 ``nested_rref`` calls up front, then verify each result."""
        n = self.rank + 1
        creator = "worker{}".format(n % self.world_size)
        owner = "worker{}".format((n + 1) % self.world_size)
        all_rrefs = [
            dist.remote(creator, nested_rref, args=(owner, ))
            for _ in range(20)
        ]

        for rref_of_rrefs in all_rrefs:
            rrefs = rref_of_rrefs.to_here()
            self.assertEqual(len(rrefs), 2)
            self.assertEqual(rrefs[0].to_here(), torch.ones(2, 2) + 1)
            self.assertEqual(rrefs[1].to_here(), torch.ones(2, 2) + 2)
Exemple #12
0
    def _test_multi_remote_call(self, fn, args_fn=lambda x: (), kwargs_fn=lambda x: {}):
        """Run ``fn`` remotely ``m`` times and compare against local calls.

        All remote calls are issued before any result is fetched. ``n``
        accumulates across iterations, so each call is built from a
        different value.

        Args:
            fn: callable executed both remotely and locally.
            args_fn: maps the current ``n`` to positional args for ``fn``.
            kwargs_fn: maps the current ``n`` to keyword args for ``fn``.
        """
        m = 10
        n = self.rank + 1
        target = "worker{}".format(n % self.world_size)
        pending = []
        for i in range(m):
            n += i
            # The remote call is issued before the local reference call.
            pending.append((
                dist.remote(target, fn, args=args_fn(n), kwargs=kwargs_fn(n)),
                fn(*args_fn(n), **kwargs_fn(n)),
            ))

        for rref, expected in pending:
            self.assertEqual(rref.to_here(), expected)
Exemple #13
0
 def test_remote_with_exception(self):
     """An error raised inside a remote UDF surfaces when calling ``to_here()``.

     ``raise_func`` is expected to raise (presumably a ``ValueError``). Only
     the message is checked: the raised exception must match the regex
     ``"ValueError"``.
     """
     n = self.rank + 1
     target = "worker{}".format(n % self.world_size)
     rref = dist.remote(target, raise_func)
     with self.assertRaisesRegex(Exception, "ValueError"):
         rref.to_here()
Exemple #14
0
def rpc_return_rref(dst):
    """Schedule ``torch.add(torch.ones(2, 2), 1)`` on ``dst``; return its RRef."""
    operand = torch.ones(2, 2)
    return dist.remote(dst, torch.add, args=(operand, 1))
Exemple #15
0
def nested_remote(dst):
    """Compute ``torch.ones(2, 2) + 3`` on ``dst`` and return the fetched value.

    Blocks on ``to_here()``, so the caller receives the result value rather
    than an RRef.
    """
    return dist.remote(dst, torch.add, args=(torch.ones(2, 2), 3)).to_here()
Exemple #16
0
def nested_rref(dst):
    """Return a 2-tuple of RRefs scheduled on ``dst``.

    The first holds ``torch.ones(2, 2) + 1`` and the second
    ``torch.ones(2, 2) + 2``.
    """
    return tuple(
        dist.remote(dst, torch.add, args=(torch.ones(2, 2), offset))
        for offset in (1, 2)
    )