def test_remote_same_worker(self):
    # Both torch.add results and the UDF that consumes their RRefs are
    # created on the same destination worker; the fetched value must
    # equal ones(n, n) + 4.
    n = self.rank + 1
    worker = "worker{}".format(n % self.world_size)
    plus_two = dist.remote(worker, torch.add, args=(torch.ones(n, n), 2))
    plus_one = dist.remote(worker, torch.add, args=(torch.ones(n, n), 1))
    combined = dist.remote(
        worker,
        my_rref_function,
        args=(plus_two, plus_one),
    )
    actual = combined.to_here()
    self.assertEqual(actual, torch.ones(n, n) + 4)
def test_py_rref_args_user_share(self):
    # The two input RRefs are created on one worker (the owner) and then
    # handed to my_rref_function running on a different-ranked worker
    # (the user); the fetched value must equal ones(n, n) + 4.
    n = self.rank + 1
    owner = "worker{}".format(n % self.world_size)
    user = "worker{}".format((n + 1) % self.world_size)
    first = dist.remote(owner, my_function, args=(torch.ones(n, n), 2, 0))
    second = dist.remote(owner, my_function, args=(torch.ones(n, n), 1, 0))
    shared = dist.remote(user, my_rref_function, args=(first, second))
    actual = shared.to_here()
    self.assertEqual(actual, torch.ones(n, n) + 4)
def test_py_rpc_rref_args(self):
    # Same as passing RRefs to a remote UDF, except the consumer is
    # invoked with a blocking rpc_sync call, so the result comes back
    # directly instead of through an RRef.
    n = self.rank + 1
    worker = "worker{}".format(n % self.world_size)
    first = dist.remote(worker, my_function, args=(torch.ones(n, n), 2, 0))
    second = dist.remote(worker, my_function, args=(torch.ones(n, n), 1, 0))
    result = dist.rpc_sync(worker, my_rref_function, args=(first, second))
    self.assertEqual(result, torch.ones(n, n) + 4)
def test_builtin_remote_ret(self):
    # A builtin (torch.add) run remotely on two ones(n, n) tensors must
    # yield ones(n, n) * 2 when fetched from the RRef.
    n = self.rank + 1
    worker = 'worker{}'.format(n % self.world_size)
    operands = (torch.ones(n, n), torch.ones(n, n))
    remote_sum = dist.remote(worker, torch.add, args=operands)
    actual = remote_sum.to_here()
    self.assertEqual(actual, torch.ones(n, n) * 2)
def test_py_udf_remote(self):
    # Calls my_function remotely with keyword arguments only and checks
    # the fetched value against the same call made locally.
    n = self.rank + 1
    worker = "worker{}".format(n % self.world_size)
    call_kwargs = {"a": n, "b": n + 1, "c": n + 2}
    rref = dist.remote(worker, my_function, kwargs=call_kwargs)
    actual = rref.to_here()
    expected = my_function(n, n + 1, n + 2)
    self.assertEqual(actual, expected)
def test_nested_remote(self):
    # Runs nested_remote on a first worker and passes it the name of a
    # second worker; the value fetched back must equal ones(2, 2) + 3.
    n = self.rank + 1
    first_hop = "worker{}".format(n % self.world_size)
    second_hop = "worker{}".format((n + 1) % self.world_size)
    rref = dist.remote(first_hop, nested_remote, args=(second_hop,))
    actual = rref.to_here()
    self.assertEqual(actual, torch.ones(2, 2) + 3)
def rref_forward_chain(dst, world_size, rref, ttl):
    """Forward ``rref`` through a chain of workers, one hop per ``ttl``.

    While ``ttl`` is positive, this issues a remote call to itself on
    ``worker<dst>`` with the next rank (wrapping at ``world_size``) and
    ``ttl - 1``, and returns the resulting RRef wrapped in a one-element
    list. Once ``ttl`` is no longer positive, it returns the value
    fetched from ``rref`` instead.
    """
    if not ttl > 0:
        # End of the chain: resolve the forwarded RRef to its value.
        return rref.to_here()

    target = "worker{}".format(dst)
    following = (dst + 1) % world_size
    hop = dist.remote(
        target,
        rref_forward_chain,
        args=(following, world_size, rref, ttl - 1),
    )
    return [hop]
def test_nested_rref(self):
    # nested_rref runs on a first worker and returns RRefs it created on
    # a second worker; fetch the container, then fetch each inner RRef.
    n = self.rank + 1
    first_hop = "worker{}".format(n % self.world_size)
    second_hop = "worker{}".format((n + 1) % self.world_size)
    outer = dist.remote(first_hop, nested_rref, args=(second_hop,))
    inner = outer.to_here()
    self.assertEqual(len(inner), 2)
    plus_one, plus_two = inner[0], inner[1]
    self.assertEqual(plus_one.to_here(), torch.ones(2, 2) + 1)
    self.assertEqual(plus_two.to_here(), torch.ones(2, 2) + 2)
def test_multi_builtin_remote_ret(self):
    # Issue all remote torch.add calls first, then fetch and verify them,
    # so several RRefs are outstanding at the same time.
    m = 10
    n = self.rank + 1
    worker = 'worker{}'.format(n % self.world_size)
    pending = []
    expected = []
    for step in range(m):
        # The tensor size accumulates: each step adds its index to n.
        n += step
        operands = (torch.ones(n, n), torch.ones(n, n))
        pending.append(dist.remote(worker, torch.add, args=operands))
        expected.append(torch.ones(n, n) * 2)

    for rref, want in zip(pending, expected):
        self.assertEqual(rref.to_here(), want)
def test_rref_forward_chain(self):
    # Start the chain with a local call to rref_forward_chain, then peel
    # off one single-element list per hop until the final value is left.
    ttl = 8
    n = self.rank + 1
    dst_rank = n % self.world_size
    source = dist.remote(
        "worker{}".format(dst_rank),
        torch.add,
        args=(torch.ones(n, n), 1),
    )

    current = rref_forward_chain(dst_rank, self.world_size, source, ttl)
    for _ in range(ttl):
        self.assertEqual(len(current), 1)
        current = current[0].to_here()

    # After ttl hops, `current` holds the value fetched from `source`.
    self.assertEqual(current, torch.add(torch.ones(n, n), 1))
def test_nested_rref_stress(self):
    # Launch 20 nested_rref calls up front, then fetch and verify each
    # returned pair of RRefs in launch order.
    n = self.rank + 1
    first_hop = "worker{}".format(n % self.world_size)
    second_hop = "worker{}".format((n + 1) % self.world_size)

    outer_rrefs = [
        dist.remote(first_hop, nested_rref, args=(second_hop,))
        for _ in range(20)
    ]

    for outer in outer_rrefs:
        inner = outer.to_here()
        self.assertEqual(len(inner), 2)
        self.assertEqual(inner[0].to_here(), torch.ones(2, 2) + 1)
        self.assertEqual(inner[1].to_here(), torch.ones(2, 2) + 2)
def _test_multi_remote_call(self, fn, args_fn=lambda x: (), kwargs_fn=lambda x: {}):
    # Helper: run `fn` remotely ten times with arguments derived from a
    # growing n, run the same call locally for the expected value, and
    # compare after all remote calls have been issued.
    #
    # args_fn / kwargs_fn map the current n to the positional and keyword
    # arguments; each is invoked separately for the remote call and for
    # the local one.
    m = 10
    n = self.rank + 1
    worker = "worker{}".format(n % self.world_size)
    pending = []
    expected = []
    for step in range(m):
        n += step
        remote_result = dist.remote(
            worker,
            fn,
            args=args_fn(n),
            kwargs=kwargs_fn(n),
        )
        pending.append(remote_result)
        expected.append(fn(*args_fn(n), **kwargs_fn(n)))

    for rref, want in zip(pending, expected):
        self.assertEqual(rref.to_here(), want)
def test_remote_with_exception(self):
    # Fetching the result of a remote raise_func call must raise an
    # exception whose message matches "ValueError".
    n = self.rank + 1
    worker = "worker{}".format(n % self.world_size)
    failing = dist.remote(worker, raise_func)
    with self.assertRaisesRegex(Exception, "ValueError"):
        failing.to_here()
def rpc_return_rref(dst):
    """Return an RRef to ``torch.add(torch.ones(2, 2), 1)`` run on ``dst``."""
    operands = (torch.ones(2, 2), 1)
    rref = dist.remote(dst, torch.add, args=operands)
    return rref
def nested_remote(dst):
    """Run ``torch.add(torch.ones(2, 2), 3)`` on ``dst`` and return the value.

    Unlike returning the RRef itself, this fetches the result with
    ``to_here()`` before returning.
    """
    operands = (torch.ones(2, 2), 3)
    return dist.remote(dst, torch.add, args=operands).to_here()
def nested_rref(dst):
    """Return a pair of RRefs created on ``dst``.

    The first refers to ``torch.add(torch.ones(2, 2), 1)`` and the second
    to ``torch.add(torch.ones(2, 2), 2)``; neither is fetched here.
    """
    plus_one = dist.remote(dst, torch.add, args=(torch.ones(2, 2), 1))
    plus_two = dist.remote(dst, torch.add, args=(torch.ones(2, 2), 2))
    return (plus_one, plus_two)