def test_self_add(self):
    self_worker_id = dist.get_worker_id()
    self_worker_name = "worker{}".format(self.rank)
    with self.assertRaisesRegex(
        RuntimeError, "does not support making RPC calls to self"
    ):
        dist.rpc_sync(self_worker_id, torch.add, args=(torch.ones(2, 2), 1))
    with self.assertRaisesRegex(
        RuntimeError, "does not support making RPC calls to self"
    ):
        dist.rpc_sync(self_worker_name, torch.add, args=(torch.ones(2, 2), 1))
def test_py_built_in(self):
    n = self.rank + 1
    dst_rank = n % self.world_size
    ret = dist.rpc_sync("worker{}".format(dst_rank), min, args=(n, n + 1, n + 2))
    self.assertEqual(ret, min(n, n + 1, n + 2))
def test_scalar_add(self):
    n = self.rank + 1
    dst_rank = n % self.world_size
    ret = dist.rpc_sync(
        "worker{}".format(dst_rank), torch.add, args=(torch.ones(n, n), n)
    )
    self.assertEqual(ret, (torch.ones(n, n) + n))
def test_py_class_instance_method(self):
    n = self.rank + 1
    dst_rank = n % self.world_size
    ret = dist.rpc_sync(
        "worker{}".format(dst_rank), my_class(2).my_instance_method, args=(n,)
    )
    self.assertEqual(ret, my_class(2).my_instance_method(n))
def test_py_class_static_method(self):
    n = self.rank + 1
    dst_rank = n % self.world_size
    ret = dist.rpc_sync(
        "worker{}".format(dst_rank), my_class.my_static_method, args=(n + 10,)
    )
    self.assertEqual(ret, my_class.my_static_method(n + 10))
def test_py_function_exception(self):
    n = self.rank + 1
    dst_rank = n % self.world_size
    with self.assertRaisesRegex(Exception, "TypeError"):
        ret = dist.rpc_sync("worker{}".format(dst_rank), no_result, args=(10,))
def test_nonzero(self):
    n = self.rank + 1
    dst_rank = n % self.world_size
    x = torch.ones(self.world_size, self.world_size)
    x[self.rank][self.rank] = 0
    ret = dist.rpc_sync("worker{}".format(dst_rank), torch.nonzero, args=(x,))
    self.assertEqual(ret, x.nonzero())
def test_sync_rpc(self):
    dst_rank = (self.rank + 1) % self.world_size
    for i in range(20):
        dist.sync_rpc()
        n = i + self.rank + 1
        ret1 = dist.rpc_sync(
            "worker{}".format(dst_rank),
            torch.add,
            args=(torch.ones(n, n), torch.ones(n, n)),
        )
        dist.sync_rpc()
        ret2 = dist.rpc_sync(
            "worker{}".format(dst_rank), torch.add, args=(torch.ones(n, n), 2)
        )
        dist.sync_rpc()
        self.assertEqual(ret1, torch.ones(n, n) * 2)
        self.assertEqual(ret2, torch.ones(n, n) * 3)
def test_add_with_id(self):
    n = self.rank + 1
    dst_rank = n % self.world_size
    worker_id = dist.get_worker_id("worker{}".format(dst_rank))
    ret = dist.rpc_sync(
        worker_id, torch.add, args=(torch.ones(n, n), torch.ones(n, n))
    )
    self.assertEqual(ret, torch.ones(n, n) * 2)
def test_py_user_defined(self):
    n = self.rank + 1
    dst_rank = n % self.world_size
    ret = dist.rpc_sync(
        "worker{}".format(dst_rank),
        my_function,
        kwargs={"a": n, "b": n + 1, "c": n + 2},
    )
    self.assertEqual(ret, my_function(n, n + 1, n + 2))
def test_py_tensors(self):
    n = self.rank + 1
    dst_rank = n % self.world_size
    ret = dist.rpc_sync(
        "worker{}".format(dst_rank),
        my_tensor_function,
        args=(torch.ones(n, n), torch.ones(n, n)),
    )
    self.assertEqual(ret, my_tensor_function(torch.ones(n, n), torch.ones(n, n)))
def test_nested_rpc(self):
    n = self.rank + 1
    dst_rank = n % self.world_size
    ret = dist.rpc_sync(
        "worker{}".format(dst_rank),
        nested_rpc,
        args=("worker{}".format(self.rank),),
    )
    self.assertEqual(ret, torch.ones(2, 2) + 1)
def test_multi_rpc(self):
    dst_rank = (self.rank + 1) % self.world_size
    for i in range(20):
        n = i + self.rank + 1
        ret = dist.rpc_sync(
            "worker{}".format(dst_rank),
            torch.add,
            args=(torch.ones(n, n), torch.ones(n, n)),
        )
        self.assertEqual(ret, torch.ones(n, n) * 2)
def test_py_tensors_in_container(self):
    n = self.rank + 1
    dst_rank = n % self.world_size
    a = [torch.ones(n, n), torch.ones(n, n)]
    b = TensorClass(build_complex_tensors())
    c = {"foo": torch.ones(n, n), "bar": torch.ones(n, n)}
    ret = dist.rpc_sync(
        "worker{}".format(dst_rank), my_complex_tensor_function, args=(a, b, c)
    )
    self.assertEqual(ret, my_complex_tensor_function(a, b, c))
def test_rpc_return_rref(self):
    n = self.rank + 1
    dst_rank1 = n % self.world_size
    dst_rank2 = (n + 1) % self.world_size
    rref = dist.rpc_sync(
        "worker{}".format(dst_rank1),
        rpc_return_rref,
        args=("worker{}".format(dst_rank2),),
    )
    self.assertEqual(rref.to_here(), torch.ones(2, 2) + 1)
def test_join_rpc(self):
    n = self.rank + 1
    dst_rank = n % self.world_size
    ret = dist.rpc_sync(
        "worker{}".format(dst_rank),
        torch.add,
        args=(torch.ones(n, n), torch.ones(n, n)),
    )
    self.assertEqual(ret, torch.ones(n, n) * 2)
    dist.join_rpc()

    with self.assertRaisesRegex(RuntimeError, "^RPC has not been initialized"):
        dist.rpc_sync(
            "worker{}".format(dst_rank),
            torch.add,
            args=(torch.ones(n, n), torch.ones(n, n)),
        )

    # it's safe to call join_rpc() multiple times
    dist.join_rpc()
def test_py_nested_pickle(self):
    n = self.rank + 1
    dst_rank = n % self.world_size
    ret = dist.rpc_sync(
        "worker{}".format(dst_rank),
        run_nested_pickle,
        args=(MyPickleClass(), torch.ones(2, 2)),
    )

    m = MyPickleClass()
    m.set(my_tensor_function(torch.ones(2, 2), torch.ones(2, 2)))
    self.assertEqual(ret, run_nested_pickle(m, torch.ones(2, 2)))
def test_py_rpc_rref_args(self):
    n = self.rank + 1
    dst_rank = n % self.world_size
    rref_a = dist.remote(
        "worker{}".format(dst_rank), my_function, args=(torch.ones(n, n), 2, 0)
    )
    rref_b = dist.remote(
        "worker{}".format(dst_rank), my_function, args=(torch.ones(n, n), 1, 0)
    )
    c = dist.rpc_sync(
        "worker{}".format(dst_rank), my_rref_function, args=(rref_a, rref_b)
    )
    self.assertEqual(c, torch.ones(n, n) + 4)
def test_rpc_complex_args(self):
    dst_rank = (self.rank + 1) % self.world_size
    with dist_autograd.context() as context_id:
        num_tensors = 10
        tensors = []
        for i in range(num_tensors):
            tensors.append(torch.ones(3, 3, requires_grad=(i % 2 == 0)))
        ret = dist.rpc_sync(
            "worker{}".format(dst_rank), torch.stack, args=(tensors,)
        )
        self.assertEqual(torch.stack(tensors), ret)

        # Verify appropriate tensors have been attached to the autograd graph.
        next_funcs = dist_autograd._current_context()._send_functions()[0].next_functions
        idx = 0
        for i in range(num_tensors):
            if i % 2 == 0:
                self.assertEqual(
                    "torch::autograd::AccumulateGrad", next_funcs[i][0].name()
                )
                self.assertEqual(tensors[i], next_funcs[i][0].variable)
            else:
                self.assertIsNone(next_funcs[i][0])
def test_autograd_send_function(self):
    dst_rank = (self.rank + 1) % self.world_size
    with dist_autograd.context() as context_id:
        t1 = torch.ones(3, 3, requires_grad=True)
        t2 = torch.zeros(3, 3, requires_grad=True)
        ret = dist.rpc_sync("worker{}".format(dst_rank), torch.add, args=(t1, t2))

        # Get send function.
        ctx = dist_autograd._current_context()
        self.assertEqual(context_id, ctx._context_id())
        send_functions = ctx._send_functions()
        self.assertEqual(1, len(send_functions))

        # Retrieve the next functions in the graph.
        next_funcs = send_functions[0].next_functions
        self.assertEqual(2, len(next_funcs))

        # We should now hit t1 and t2 in the autograd graph.
        self.assertEqual("torch::autograd::AccumulateGrad", next_funcs[0][0].name())
        self.assertEqual(t1, next_funcs[0][0].variable)
        self.assertEqual(0, next_funcs[0][1])
        self.assertEqual("torch::autograd::AccumulateGrad", next_funcs[1][0].name())
        self.assertEqual(t2, next_funcs[1][0].variable)
        self.assertEqual(0, next_funcs[1][1])

    # autograd context should be cleaned up by now.
    with self.assertRaises(RuntimeError):
        ctx = dist_autograd._retrieve_context(context_id)

    # No autograd context available.
    with self.assertRaises(RuntimeError):
        ctx = dist_autograd._current_context()
def test_autograd_functions(self):
    dst_rank = (self.rank + 1) % self.world_size
    with dist_autograd.context() as context_id:
        t1 = torch.ones(3, 3, requires_grad=True)
        t2 = torch.zeros(3, 3, requires_grad=True)
        ret = dist.rpc_sync("worker{}".format(dst_rank), torch.add, args=(t1, t2))
        dist.rpc_sync("worker{}".format(dst_rank), _set_rpc_done, args=(context_id,))

        # Get send function.
        ctx = dist_autograd._current_context()
        self.assertEqual(context_id, ctx._context_id())
        send_functions = ctx._send_functions()
        self.assertEqual(1, len(send_functions))

        # Retrieve the next functions in the graph.
        next_funcs = list(send_functions.values())[0].next_functions
        self.assertEqual(2, len(next_funcs))

        # We should now hit t1 and t2 in the autograd graph.
        self.assertEqual("torch::autograd::AccumulateGrad", next_funcs[0][0].name())
        self.assertEqual(t1, next_funcs[0][0].variable)
        self.assertEqual(0, next_funcs[0][1])
        self.assertEqual("torch::autograd::AccumulateGrad", next_funcs[1][0].name())
        self.assertEqual(t2, next_funcs[1][0].variable)
        self.assertEqual(0, next_funcs[1][1])

        # Test recv functions.
        recv_functions = ctx._recv_functions()
        self.assertEqual(1, len(recv_functions))
        self.assertEqual(ret.grad_fn, list(recv_functions.values())[0])

        # We should have send/recv functions from the previous rank, get all
        # contexts in this node to find them.

        # Wait for the prev rank to be done with rpc.
        while not prev_rank_rpc_done:
            time.sleep(0.1)

        # Now verify the autograd graph.
        ctx = dist_autograd._retrieve_context(prev_rank_context_id)

        # Get the send function.
        send_functions = ctx._send_functions()
        self.assertEqual(1, len(send_functions))

        # Verify next function is AddBackward0
        next_funcs = list(send_functions.values())[0].next_functions
        self.assertEqual(1, len(next_funcs))
        add_backward_fn = next_funcs[0][0]
        self.assertEqual("AddBackward0", add_backward_fn.name())

        # Verify the next two functions are the same recv backward function.
        next_funcs = add_backward_fn.next_functions
        self.assertEqual(2, len(next_funcs))
        self.assertEqual(
            "torch::distributed::autograd::RecvRpcBackward", next_funcs[0][0].name()
        )
        self.assertEqual(
            "torch::distributed::autograd::RecvRpcBackward", next_funcs[1][0].name()
        )
        self.assertEqual(next_funcs[0][0], next_funcs[1][0])

    # autograd context should be cleaned up by now.
    with self.assertRaises(RuntimeError):
        ctx = dist_autograd._retrieve_context(context_id)

    # No autograd context available.
    with self.assertRaises(RuntimeError):
        ctx = dist_autograd._current_context()
def test_py_no_return_result(self):
    n = self.rank + 1
    dst_rank = n % self.world_size
    ret = dist.rpc_sync("worker{}".format(dst_rank), no_result)
    self.assertEqual(ret, no_result())
def test_expected_src(self):
    dst_rank = (self.rank + 1) % self.world_size
    expected_src_rank = (self.rank - 1) % self.world_size
    ret = dist.rpc_sync("worker{}".format(dst_rank), set_value, args=(self.rank,))
    value = VALUE_FUTURE.result()
    self.assertEqual(value, expected_src_rank)
def test_py_class_constructor(self):
    n = self.rank + 1
    dst_rank = n % self.world_size
    ret = dist.rpc_sync("worker{}".format(dst_rank), my_class, args=(n,))
    self.assertEqual(ret.a, n)
def nested_rpc(dst):
    return dist.rpc_sync(dst, torch.add, args=(torch.ones(2, 2), 1))
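
# The tests above reference a few module-level helpers that are not defined in
# this excerpt (e.g. _set_rpc_done / prev_rank_rpc_done / prev_rank_context_id
# used by test_autograd_functions, and set_value / VALUE_FUTURE used by
# test_expected_src). The sketch below shows one plausible shape for them; it
# is an illustrative assumption, not the original definitions.

import concurrent.futures

# Assumed module-level state: a flag and context id set over RPC by the
# previous rank, plus a future resolved with the calling rank's value.
prev_rank_rpc_done = False
prev_rank_context_id = 0
VALUE_FUTURE = concurrent.futures.Future()


def _set_rpc_done(context_id):
    # Hypothetical helper: record that the previous rank finished its RPC and
    # remember its distributed autograd context id.
    global prev_rank_rpc_done, prev_rank_context_id
    prev_rank_context_id = context_id
    prev_rank_rpc_done = True


def set_value(value):
    # Hypothetical helper: resolve VALUE_FUTURE so test_expected_src can check
    # which source rank made the call.
    VALUE_FUTURE.set_result(value)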