def test_join_rpc(self):
    """After join_rpc(), further RPCs fail and repeated joins are no-ops."""
    n = self.rank + 1
    dst_rank = n % self.world_size
    dst_name = "worker{}".format(dst_rank)

    result = dist.rpc(
        dst_name,
        torch.add,
        args=(torch.ones(n, n), torch.ones(n, n)),
    )
    self.assertEqual(result, torch.ones(n, n) * 2)

    dist.join_rpc()

    # Once RPC is shut down, any further call must fail loudly.
    with self.assertRaisesRegex(RuntimeError, "^RPC has not been initialized"):
        dist.rpc(
            dst_name,
            torch.add,
            args=(torch.ones(n, n), torch.ones(n, n)),
        )

    # Joining again must be safe (idempotent), not an error.
    dist.join_rpc()
def test_py_user_defined(self):
    """A user-defined Python function can be called over RPC with kwargs."""
    n = self.rank + 1
    dst_rank = n % self.world_size
    kwargs = {"a": n, "b": n + 1, "c": n + 2}
    remote_result = dist.rpc(
        "worker{}".format(dst_rank), my_function, kwargs=kwargs)
    # The remote result must match a plain local invocation.
    self.assertEqual(remote_result, my_function(n, n + 1, n + 2))
def test_py_tensors_multi_async_call(self):
    """Fire many async RPCs at once, then verify each future's result."""
    n = self.rank + 1
    dst_rank = n % self.world_size
    dst_name = "worker{}".format(dst_rank)

    # Launch all 100 calls before waiting on any of them.
    futs = [
        dist.rpc(dst_name,
                 my_tensor_function,
                 args=(torch.ones(i, i), torch.ones(i, i)),
                 async_call=True)
        for i in range(100)
    ]

    # Future index i corresponds to the call made with i-by-i tensors.
    for i, fut in enumerate(futs):
        self.assertEqual(
            fut.wait(),
            my_tensor_function(torch.ones(i, i), torch.ones(i, i)))
def _stress_test_rpc(self, f, repeat=1000, args=()):
    """Issue `repeat` async RPCs of `f` to the next rank and time them.

    Every call is expected to return 0; elapsed wall time is printed.
    """
    import time

    n = self.rank + 1
    dst_rank = n % self.world_size
    dst_name = "worker{}".format(dst_rank)

    start = time.time()
    futs = [
        dist.rpc(dst_name, f, args=args, async_call=True)
        for _ in range(repeat)
    ]
    for fut in futs:
        self.assertEqual(fut.wait(), 0)
    elapsed = time.time() - start

    print("Rank {} finished testing {} {} times in {} seconds.".format(
        self.rank, f.__name__, repeat, elapsed))
def test_rpc_complex_args(self):
    """RPC torch.stack with a mix of grad / no-grad tensor arguments.

    Only tensors created with requires_grad=True should be attached to
    the autograd graph by the send function; the others must have no
    grad function recorded.

    Fix: removed the dead local `idx = 0`, which was never read.
    """
    dst_rank = (self.rank + 1) % self.world_size
    with dist_autograd.context() as context_id:
        num_tensors = 10
        # Even-indexed tensors require grad; odd-indexed ones do not.
        tensors = [
            torch.ones(3, 3, requires_grad=(i % 2 == 0))
            for i in range(num_tensors)
        ]
        ret = dist.rpc('worker{}'.format(dst_rank), torch.stack,
                       args=(tensors, ))
        self.assertEqual(torch.stack(tensors), ret)

        # Verify appropriate tensors have been attached to the autograd
        # graph: exactly the requires_grad ones appear as AccumulateGrad.
        next_funcs = dist_autograd._current_context()._send_functions(
        )[0].next_functions
        for i in range(num_tensors):
            if i % 2 == 0:
                self.assertEqual('torch::autograd::AccumulateGrad',
                                 next_funcs[i][0].name())
                self.assertEqual(tensors[i], next_funcs[i][0].variable)
            else:
                self.assertIsNone(next_funcs[i][0])
def test_autograd_send_function(self):
    """An RPC records one send function whose inputs are the two leaves."""
    dst_rank = (self.rank + 1) % self.world_size
    with dist_autograd.context() as context_id:
        t1 = torch.ones(3, 3, requires_grad=True)
        t2 = torch.zeros(3, 3, requires_grad=True)
        ret = dist.rpc('worker{}'.format(dst_rank), torch.add,
                       args=(t1, t2))

        # Get send function.
        ctx = dist_autograd._current_context()
        self.assertEqual(context_id, ctx._context_id())
        send_functions = ctx._send_functions()
        self.assertEqual(1, len(send_functions))

        # Both t1 and t2 should appear, in argument order, as
        # AccumulateGrad leaves of the send function.
        next_funcs = send_functions[0].next_functions
        self.assertEqual(2, len(next_funcs))
        for pos, tensor in enumerate((t1, t2)):
            grad_fn, input_nr = next_funcs[pos]
            self.assertEqual('torch::autograd::AccumulateGrad',
                             grad_fn.name())
            self.assertEqual(tensor, grad_fn.variable)
            self.assertEqual(0, input_nr)

    # autograd context should be cleaned up by now.
    with self.assertRaises(RuntimeError):
        ctx = dist_autograd._retrieve_context(context_id)

    # No autograd context available.
    with self.assertRaises(RuntimeError):
        ctx = dist_autograd._current_context()
def test_py_function_exception(self):
    """A remote call with a bad argument list surfaces TypeError locally.

    Fix: dropped the assignment to `ret` — the call is expected to raise,
    so the binding inside the assertRaises block could never be used.
    """
    n = self.rank + 1
    dst_rank = n % self.world_size
    # Passing an argument presumably mismatches no_result's signature,
    # producing a TypeError on the remote side.
    with self.assertRaisesRegex(Exception, "TypeError"):
        dist.rpc("worker{}".format(dst_rank), no_result, args=(10, ))
def test_py_no_return_result(self):
    """A remote function with no return value matches its local result."""
    n = self.rank + 1
    dst_rank = n % self.world_size
    remote_result = dist.rpc("worker{}".format(dst_rank), no_result)
    self.assertEqual(remote_result, no_result())
def nested_rpc(dst):
    """Issue an add RPC to `dst`; used to exercise RPC-from-within-RPC."""
    operand = torch.ones(2, 2)
    return dist.rpc(dst, torch.add, args=(operand, 1))
def test_py_class_constructor(self):
    """A class constructor can be the RPC target; the instance comes back."""
    n = self.rank + 1
    dst_rank = n % self.world_size
    instance = dist.rpc("worker{}".format(dst_rank), my_class, args=(n, ))
    self.assertEqual(instance.a, n)
def test_py_raise_in_user_func(self):
    """An exception raised remotely propagates when the future is waited on."""
    n = self.rank + 1
    dst_rank = n % self.world_size
    dst_name = "worker{}".format(dst_rank)
    fut = dist.rpc(dst_name, raise_func, async_call=True)
    with self.assertRaisesRegex(Exception, "ValueError"):
        fut.wait()
def test_py_built_in(self):
    """A Python builtin (min) can be invoked over RPC."""
    n = self.rank + 1
    dst_rank = n % self.world_size
    values = (n, n + 1, n + 2)
    remote_result = dist.rpc("worker{}".format(dst_rank), min, args=values)
    self.assertEqual(remote_result, min(*values))
def test_scalar_add(self):
    """Adding a Python scalar to a tensor over RPC matches the local sum.

    Fix: renamed the camelCase local `dstRank` to `dst_rank` and switched
    from `%`-formatting to str.format, matching every sibling test in
    this suite (the produced worker name is identical).
    """
    n = self.rank + 1
    dst_rank = n % self.world_size
    ret = dist.rpc('worker{}'.format(dst_rank), torch.add,
                   args=(torch.ones(n, n), n))
    self.assertEqual(ret, (torch.ones(n, n) + n))
def test_add(self):
    """Basic tensor-plus-tensor addition over RPC."""
    n = self.rank + 1
    dst_rank = n % self.world_size
    result = dist.rpc('worker{}'.format(dst_rank), torch.add,
                      args=(torch.ones(n, n), torch.ones(n, n)))
    expected = torch.ones(n, n) * 2
    self.assertEqual(result, expected)
def test_py_class_static_method(self):
    """A static method can be the RPC target."""
    n = self.rank + 1
    dst_rank = n % self.world_size
    arg = n + 10
    remote_result = dist.rpc('worker{}'.format(dst_rank),
                             my_class.my_static_method, args=(arg,))
    self.assertEqual(remote_result, my_class.my_static_method(arg))
def test_py_class_instance_method(self):
    """A bound instance method can be the RPC target."""
    n = self.rank + 1
    dst_rank = n % self.world_size
    remote_result = dist.rpc('worker{}'.format(dst_rank),
                             my_class(2).my_instance_method, args=(n,))
    # Compare against an identical local invocation on a fresh instance.
    self.assertEqual(remote_result, my_class(2).my_instance_method(n))