def test_rpc_complex_args(self):
    with dist_autograd.context() as context_id:
        num_tensors = 10
        tensors = []
        for i in range(num_tensors):
            tensors.append(torch.ones(3, 3, requires_grad=(i % 2 == 0)))
        ret = rpc.rpc_sync(
            "worker{}".format(self._next_rank()), torch.stack, args=(tensors,)
        )
        self.assertEqual(torch.stack(tensors), ret)

        # Verify appropriate tensors have been attached to the autograd graph.
        next_funcs = list(
            dist_autograd._current_context()._send_functions().values()
        )[0].next_functions
        for i in range(num_tensors):
            if i % 2 == 0:
                self.assertEqual(
                    "torch::autograd::AccumulateGrad", next_funcs[i][0].name()
                )
                self.assertEqual(tensors[i], next_funcs[i][0].variable)
            else:
                self.assertIsNone(next_funcs[i][0])

        # Verify that the worker id has been recorded in the context.
        ctx = dist_autograd._current_context()
        worker_ids = ctx._known_worker_ids()
        self.assertEqual(len(worker_ids), 1)
        dst_rank = (self.rank + 1) % self.world_size
        self.assertEqual(worker_ids[0], dst_rank)

def test_worker_ids_recorded(self):
    dst_ranks = {rank for rank in range(self.world_size) if rank != self.rank}
    with dist_autograd.context() as context_id:
        # If no tensors require grad, we do not add the send functions, so
        # no worker ids should be recorded.
        t1 = torch.ones(3, 3, requires_grad=False)
        t2 = torch.zeros(3, 3, requires_grad=False)
        for dst_rank in dst_ranks:
            ret = rpc.rpc_sync(
                "worker{}".format(dst_rank), torch.add, args=(t1, t2)
            )
            rpc.rpc_sync(
                "worker{}".format(dst_rank), _set_rpc_done, args=(context_id, 1)
            )
        # No worker ids should be recorded.
        ctx = dist_autograd._current_context()
        worker_ids = ctx._known_worker_ids()
        self.assertEqual(len(worker_ids), 0)

        # worker_ids should be recorded when tensors do require grad.
        t1.requires_grad = True
        t2.requires_grad = True
        for dst_rank in dst_ranks:
            ret = rpc.rpc_sync(
                "worker{}".format(dst_rank), torch.add, args=(t1, t2)
            )
            rpc.rpc_sync(
                "worker{}".format(dst_rank), _set_rpc_done, args=(context_id, 1)
            )
        # All worker_ids in dst_ranks should be recorded.
        worker_ids = ctx._known_worker_ids()
        self.assertEqual(len(worker_ids), len(dst_ranks))
        self.assertEqual(set(worker_ids), dst_ranks)

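# The tests in this file rely on module-level helpers (`_set_rpc_done`,
# `_check_rpc_done`) and shared state (`ctx_ids`) that are not shown in this
# excerpt. The definitions below are a minimal sketch of what they are assumed
# to look like, inferred only from how they are called here; they are
# illustrative, not the authoritative implementation. In the tests,
# `self._check_rpc_done(rank_distance)` is assumed to be a thin wrapper around
# the module-level wait helper sketched below.
import time

# Per rank-distance flags and context ids, filled in by the callee so that the
# caller can later look up the autograd context that was propagated to it.
rpc_done = [False, False, False, False]
ctx_ids = [-1, -1, -1, -1]


def _set_rpc_done(ctx_id, rank_distance):
    global rpc_done
    global ctx_ids
    rpc_done[rank_distance] = True
    ctx_ids[rank_distance] = ctx_id


def _check_rpc_done(rank_distance):
    # Busy-wait until the peer at the given rank distance reports that its
    # rpc (and context propagation) is done.
    while not rpc_done[rank_distance]:
        time.sleep(0.1)
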
def test_graph_for_py_nested_call_itself(self):
    dst_rank = (self.rank + 1) % self.world_size
    with dist_autograd.context() as context_id:
        t1 = torch.ones(3, 3, requires_grad=True)
        t2 = torch.zeros(3, 3, requires_grad=True)
        ret = rpc.rpc_sync(
            "worker{}".format(dst_rank),
            my_py_nested_call,
            args=(
                t1,
                t2,
                (self.rank - 1 + self.world_size) % self.world_size,
                self.world_size,
                0,
            ),
        )
        rpc.rpc_sync(
            "worker{}".format((self.rank + 1) % self.world_size),
            _set_rpc_done,
            args=(context_id, 1),
        )

        # For self.rank, there are 2 graphs to verify.
        # One is for the current context id, when this rank sends the first
        # rpc call and executes the torch.add() operator.
        # The other is for the prev context id, when this rank makes the
        # nested call.
        ctx = dist_autograd._current_context()
        self.assertEqual(context_id, ctx._context_id())
        send_functions = ctx._send_functions()
        self.assertEqual(2, len(send_functions))
        recv_functions = ctx._recv_functions()
        self.assertEqual(2, len(recv_functions))
        self._verify_graph_for_first_rpc_call(
            list(send_functions.values())[0],
            list(recv_functions.values())[1],
            t1,
            t2,
            ret,
        )
        self._verify_graph_for_rpc_call_exec(list(send_functions.values())[1])

        # Verify two pairs of send and recv functions for the nested call.
        self._check_rpc_done(1)
        ctx = dist_autograd._retrieve_context(ctx_ids[1])
        self._verify_graph_for_nested_rpc_call(ctx)

        # This barrier is needed so one worker does not clean up its
        # autograd context before another worker tries to access it.
        dist.barrier()

def test_graph_for_py_nested_call(self):
    dst_rank = (self.rank + 1) % self.world_size
    with dist_autograd.context() as context_id:
        t1 = torch.ones(3, 3, requires_grad=True)
        t2 = torch.zeros(3, 3, requires_grad=True)
        nest_dst_rank = (dst_rank + 1) % self.world_size
        ret = rpc.rpc_sync(
            "worker{}".format(dst_rank),
            my_py_nested_call,
            args=(t1, t2, dst_rank, self.world_size, 1),
        )
        for rd in [1, 2, 3]:
            rpc.rpc_sync(
                "worker{}".format((self.rank + rd) % self.world_size),
                _set_rpc_done,
                args=(context_id, rd),
            )

        # For self.rank, there are 4 graphs to verify:
        # - one for the current context id, when this rank sends the first
        #   rpc call;
        # - a second for the prev context id, when this rank makes the 1st
        #   nested call;
        # - a third for the prev prev context id, when this rank makes the
        #   2nd nested call;
        # - a last one for the prev prev prev context id, when this rank
        #   executes the torch.add() operator.

        # Verify the first graph for the current context id.
        ctx = dist_autograd._current_context()
        self.assertEqual(context_id, ctx._context_id())
        send_functions = ctx._send_functions()
        self.assertEqual(1, len(send_functions))
        recv_functions = ctx._recv_functions()
        self.assertEqual(1, len(recv_functions))
        self._verify_graph_for_first_rpc_call(
            list(send_functions.values())[0],
            list(recv_functions.values())[0],
            t1,
            t2,
            ret,
        )

        # Verify the second graph for the 1st nested call.
        self._check_rpc_done(1)
        ctx = dist_autograd._retrieve_context(ctx_ids[1])
        self._verify_graph_for_nested_rpc_call(ctx)

        # Verify the third graph for the 2nd nested call.
        self._check_rpc_done(2)
        ctx = dist_autograd._retrieve_context(ctx_ids[2])
        self._verify_graph_for_nested_rpc_call(ctx)

        # Verify the last graph for the rpc call execution.
        self._check_rpc_done(3)
        ctx = dist_autograd._retrieve_context(ctx_ids[3])
        send_functions = ctx._send_functions()
        self.assertEqual(1, len(send_functions))
        self._verify_graph_for_rpc_call_exec(list(send_functions.values())[0])

        # This barrier is needed so one worker does not clean up its
        # autograd context before another worker tries to access it.
        dist.barrier()

def _test_graph(self, fn):
    dst_rank = (self.rank + 1) % self.world_size
    with dist_autograd.context() as context_id:
        t1 = torch.ones(3, 3, requires_grad=True)
        t2 = torch.zeros(3, 3, requires_grad=True)
        ret = rpc.rpc_sync("worker{}".format(dst_rank), fn, args=(t1, t2))
        rpc.rpc_sync(
            "worker{}".format(dst_rank), _set_rpc_done, args=(context_id, 1)
        )

        # Verify graph for current context id.
        ctx = dist_autograd._current_context()
        self.assertEqual(context_id, ctx._context_id())
        send_functions = ctx._send_functions()
        self.assertEqual(1, len(send_functions))
        recv_functions = ctx._recv_functions()
        self.assertEqual(1, len(recv_functions))
        self._verify_graph_for_first_rpc_call(
            list(send_functions.values())[0],
            list(recv_functions.values())[0],
            t1,
            t2,
            ret,
        )

        # Wait for the prev rank to be done with rpc.
        self._check_rpc_done(1)
        # Verify graph for previous context id.
        ctx = dist_autograd._retrieve_context(ctx_ids[1])
        send_functions = ctx._send_functions()
        self.assertEqual(1, len(send_functions))
        self._verify_graph_for_rpc_call_exec(list(send_functions.values())[0])

        # This barrier is needed so one worker does not clean up its
        # autograd context before another worker tries to access it.
        dist.barrier()

    # autograd context should be cleaned up by now.
    with self.assertRaises(RuntimeError):
        ctx = dist_autograd._retrieve_context(context_id)

    # No autograd context available.
    with self.assertRaises(RuntimeError):
        ctx = dist_autograd._current_context()

def test_autograd_send_function(self):
    dst_rank = (self.rank + 1) % self.world_size
    with dist_autograd.context() as context_id:
        t1 = torch.ones(3, 3, requires_grad=True)
        t2 = torch.zeros(3, 3, requires_grad=True)
        ret = rpc.rpc_sync("worker{}".format(dst_rank), torch.add, args=(t1, t2))

        # Get send function.
        ctx = dist_autograd._current_context()
        self.assertEqual(context_id, ctx._context_id())
        send_functions = ctx._send_functions()
        self.assertEqual(1, len(send_functions))

        # Retrieve the next functions in the graph.
        next_funcs = list(send_functions.values())[0].next_functions
        self.assertEqual(2, len(next_funcs))

        # We should now hit t1 and t2 in the autograd graph.
        self.assertEqual(
            "torch::autograd::AccumulateGrad", next_funcs[0][0].name()
        )
        self.assertEqual(t1, next_funcs[0][0].variable)
        self.assertEqual(0, next_funcs[0][1])
        self.assertEqual(
            "torch::autograd::AccumulateGrad", next_funcs[1][0].name()
        )
        self.assertEqual(t2, next_funcs[1][0].variable)
        self.assertEqual(0, next_funcs[1][1])

    # autograd context should be cleaned up by now.
    with self.assertRaises(RuntimeError):
        ctx = dist_autograd._retrieve_context(context_id)

    # No autograd context available.
    with self.assertRaises(RuntimeError):
        ctx = dist_autograd._current_context()

def step(self):
    """
    Performs a single optimization step.

    This will call :meth:`torch.optim.Optimizer.step` on each worker
    containing parameters to be optimized, and will block until all workers
    return. The current distributed autograd
    :class:`~torch.distributed.autograd.context` will be used globally.
    """
    autograd_ctx_id = dist_autograd._current_context()._context_id()
    rpc_futs = []
    for optim in self.remote_optimizers:
        rpc_futs.append(rpc.rpc_async(
            optim.owner(),
            _local_optimizer_step,
            args=(optim, autograd_ctx_id),
        ))
    _wait_for_all(rpc_futs)

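# Illustrative usage sketch for the step() method above (an assumption, not
# part of this module): `_example_distributed_step`, `remote_params`,
# `loss_fn`, and `inputs` are placeholder names, and the call sequence follows
# the API version shown here, where step() implicitly picks up the current
# distributed autograd context instead of taking a context id argument.
def _example_distributed_step(remote_params, loss_fn, inputs):
    import torch.distributed.autograd as dist_autograd
    import torch.optim as optim
    from torch.distributed.optim import DistributedOptimizer

    # remote_params is assumed to be a list of RRefs to parameters owned by
    # other workers; DistributedOptimizer creates one local optimizer per
    # owning worker.
    dist_optim = DistributedOptimizer(optim.SGD, remote_params, lr=0.05)
    with dist_autograd.context():
        loss = loss_fn(inputs)
        # Assumed signature for this API version: backward() takes only the
        # list of roots and accumulates gradients into the current context.
        dist_autograd.backward([loss])
        # Blocks until every worker owning parameters has run its local
        # optimizer.step() under this same autograd context.
        dist_optim.step()
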
def test_no_graph_with_tensors_not_require_grad(self):
    dst_rank = (self.rank + 1) % self.world_size
    with dist_autograd.context() as context_id:
        t1 = torch.ones(3, 3, requires_grad=False)
        t2 = torch.zeros(3, 3, requires_grad=False)
        ret = rpc.rpc_sync("worker{}".format(dst_rank), torch.add, args=(t1, t2))
        rpc.rpc_sync(
            "worker{}".format(dst_rank), _set_rpc_done, args=(context_id, 1)
        )

        ctx = dist_autograd._current_context()
        send_functions = ctx._send_functions()
        self.assertEqual(len(send_functions), 0)
        recv_functions = ctx._recv_functions()
        self.assertEqual(len(recv_functions), 0)

        # Wait for the prev rank to be done with rpc.
        self._check_rpc_done(1)
        # The prev context id is not passed over as the tensors do not
        # require grads.
        with self.assertRaises(RuntimeError):
            ctx = dist_autograd._retrieve_context(ctx_ids[1])

def _test_no_graph_with_tensors_not_require_grad(self, exec_mode):
    dst_rank = (self.rank + 1) % self.world_size
    with dist_autograd.context() as context_id:
        t1 = torch.ones(3, 3, requires_grad=False)
        t2 = torch.zeros(3, 3, requires_grad=False)
        if ExecMode.RPC_SYNC == exec_mode:
            ret = rpc.rpc_sync(
                "worker{}".format(dst_rank), torch.add, args=(t1, t2)
            )
        elif ExecMode.REMOTE == exec_mode:
            ret = rpc.remote(
                "worker{}".format(dst_rank), torch.add, args=(t1, t2)
            ).to_here().wait()
        else:
            raise ValueError("Unrecognized ExecMode {}".format(exec_mode))

        rpc.rpc_sync(
            "worker{}".format(dst_rank), _set_rpc_done, args=(context_id, 1)
        )

        ctx = dist_autograd._current_context()
        send_functions = ctx._send_functions()
        self.assertEqual(len(send_functions), 0)
        recv_functions = ctx._recv_functions()
        self.assertEqual(len(recv_functions), 0)

        # Wait for the prev rank to be done with rpc.
        self._check_rpc_done(1)
        if ExecMode.RPC_SYNC == exec_mode:
            # The prev context id is not passed over as the tensors do not
            # require grads.
            with self.assertRaises(RuntimeError):
                ctx = dist_autograd._retrieve_context(ctx_ids[1])
        elif ExecMode.REMOTE == exec_mode:
            # NB: RRef.to_here() always passes the autograd context to the
            # callee, as the caller does not know whether the return value
            # would contain a requires_grad tensor or not.
            pass

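# `ExecMode` is used by the parametrized helpers in this file but is not
# defined in this excerpt. A minimal sketch of the assumed module-level enum,
# covering only the members these helpers dispatch on:
from enum import Enum


class ExecMode(Enum):
    RPC_SYNC = 1  # call via rpc.rpc_sync()
    REMOTE = 2    # call via rpc.remote() followed by RRef.to_here()
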
def test_rpc_complex_args(self):
    dst_rank = (self.rank + 1) % self.world_size
    with dist_autograd.context() as context_id:
        num_tensors = 10
        tensors = []
        for i in range(num_tensors):
            tensors.append(torch.ones(3, 3, requires_grad=(i % 2 == 0)))
        ret = rpc.rpc_sync(
            "worker{}".format(dst_rank), torch.stack, args=(tensors,)
        )
        self.assertEqual(torch.stack(tensors), ret)

        # Verify appropriate tensors have been attached to the autograd graph.
        next_funcs = list(
            dist_autograd._current_context()._send_functions().values()
        )[0].next_functions
        for i in range(num_tensors):
            if i % 2 == 0:
                self.assertEqual(
                    "torch::autograd::AccumulateGrad", next_funcs[i][0].name()
                )
                self.assertEqual(tensors[i], next_funcs[i][0].variable)
            else:
                self.assertIsNone(next_funcs[i][0])

def test_autograd_functions(self):
    dst_rank = (self.rank + 1) % self.world_size
    with dist_autograd.context() as context_id:
        t1 = torch.ones(3, 3, requires_grad=True)
        t2 = torch.zeros(3, 3, requires_grad=True)
        ret = rpc.rpc_sync("worker{}".format(dst_rank), torch.add, args=(t1, t2))
        rpc.rpc_sync(
            "worker{}".format(dst_rank), _set_rpc_done, args=(context_id,)
        )

        # Get send function.
        ctx = dist_autograd._current_context()
        self.assertEqual(context_id, ctx._context_id())
        send_functions = ctx._send_functions()
        self.assertEqual(1, len(send_functions))

        # Retrieve the next functions in the graph.
        next_funcs = list(send_functions.values())[0].next_functions
        self.assertEqual(2, len(next_funcs))

        # We should now hit t1 and t2 in the autograd graph.
        self.assertEqual(
            "torch::autograd::AccumulateGrad", next_funcs[0][0].name()
        )
        self.assertEqual(t1, next_funcs[0][0].variable)
        self.assertEqual(0, next_funcs[0][1])
        self.assertEqual(
            "torch::autograd::AccumulateGrad", next_funcs[1][0].name()
        )
        self.assertEqual(t2, next_funcs[1][0].variable)
        self.assertEqual(0, next_funcs[1][1])

        # Test recv functions.
        recv_functions = ctx._recv_functions()
        self.assertEqual(1, len(recv_functions))
        self.assertEqual(ret.grad_fn, list(recv_functions.values())[0])

        # We should have send/recv functions from the previous rank; get all
        # contexts in this node to find them.

        # Wait for the prev rank to be done with rpc.
        while not prev_rank_rpc_done:
            time.sleep(0.1)

        # Now verify the autograd graph.
        ctx = dist_autograd._retrieve_context(prev_rank_context_id)

        # Get the send function.
        send_functions = ctx._send_functions()
        self.assertEqual(1, len(send_functions))

        # Verify the next function is AddBackward0.
        next_funcs = list(send_functions.values())[0].next_functions
        self.assertEqual(1, len(next_funcs))
        add_backward_fn = next_funcs[0][0]
        self.assertEqual("AddBackward0", add_backward_fn.name())

        # Verify the next two functions are the same recv backward function.
        next_funcs = add_backward_fn.next_functions
        self.assertEqual(2, len(next_funcs))
        self.assertEqual(
            "torch::distributed::autograd::RecvRpcBackward",
            next_funcs[0][0].name(),
        )
        self.assertEqual(
            "torch::distributed::autograd::RecvRpcBackward",
            next_funcs[1][0].name(),
        )
        self.assertEqual(next_funcs[0][0], next_funcs[1][0])

    # autograd context should be cleaned up by now.
    with self.assertRaises(RuntimeError):
        ctx = dist_autograd._retrieve_context(context_id)

    # No autograd context available.
    with self.assertRaises(RuntimeError):
        ctx = dist_autograd._current_context()

def _test_graph_for_py_nested_call_itself(self, exec_mode):
    dst_rank = (self.rank + 1) % self.world_size

    # This is for the below `dist.barrier`.
    # For an `RpcAgent` other than `ProcessGroupAgent`,
    # no `_default_pg` is initialized.
    if not dist.is_initialized():
        dist.init_process_group(
            backend="gloo",
            init_method=self.init_method,
            rank=self.rank,
            world_size=self.world_size,
        )

    with dist_autograd.context() as context_id:
        t1 = torch.ones(3, 3, requires_grad=True)
        t2 = torch.zeros(3, 3, requires_grad=True)
        if ExecMode.RPC_SYNC == exec_mode:
            ret = rpc.rpc_sync(
                "worker{}".format(dst_rank),
                my_py_nested_call,
                args=(
                    t1,
                    t2,
                    (self.rank - 1 + self.world_size) % self.world_size,
                    self.world_size,
                    0,
                ),
            )
        elif ExecMode.REMOTE == exec_mode:
            ret = rpc.remote(
                "worker{}".format(dst_rank),
                my_py_nested_call,
                args=(
                    t1,
                    t2,
                    (self.rank - 1 + self.world_size) % self.world_size,
                    self.world_size,
                    0,
                ),
            ).to_here().wait()
        else:
            raise ValueError("Unrecognized ExecMode {}".format(exec_mode))

        rpc.rpc_sync(
            "worker{}".format((self.rank + 1) % self.world_size),
            _set_rpc_done,
            args=(context_id, 1),
        )

        # For self.rank, there are 2 graphs to verify.
        # One is for the current context id, when this rank sends the first
        # rpc call and executes the torch.add() operator.
        # The other is for the prev context id, when this rank makes the
        # nested call.
        ctx = dist_autograd._current_context()
        self.assertEqual(context_id, ctx._context_id())
        send_functions = ctx._send_functions()
        self.assertEqual(2, len(send_functions))
        recv_functions = ctx._recv_functions()
        self.assertEqual(2, len(recv_functions))
        self._verify_graph_for_first_rpc_call(
            list(send_functions.values())[0],
            list(recv_functions.values())[1],
            t1,
            t2,
            ret,
        )
        self._verify_graph_for_rpc_call_exec(list(send_functions.values())[1])

        # Verify two pairs of send and recv functions for the nested call.
        self._check_rpc_done(1)
        ctx = dist_autograd._retrieve_context(ctx_ids[1])
        self._verify_graph_for_nested_rpc_call(ctx)

        # This barrier is needed so one worker does not clean up its
        # autograd context before another worker tries to access it.
        dist.barrier()

def _test_graph(self, fn, exec_mode):
    dst_rank = (self.rank + 1) % self.world_size

    # This is for the below `dist.barrier`.
    # For an `RpcAgent` other than `ProcessGroupAgent`,
    # no `_default_pg` is initialized.
    if not dist.is_initialized():
        dist.init_process_group(
            backend="gloo",
            init_method=self.init_method,
            rank=self.rank,
            world_size=self.world_size,
        )

    with dist_autograd.context() as context_id:
        t1 = torch.ones(3, 3, requires_grad=True)
        t2 = torch.zeros(3, 3, requires_grad=True)
        if ExecMode.RPC_SYNC == exec_mode:
            ret = rpc.rpc_sync("worker{}".format(dst_rank), fn, args=(t1, t2))
        elif ExecMode.REMOTE == exec_mode:
            ret = rpc.remote(
                "worker{}".format(dst_rank), fn, args=(t1, t2)
            ).to_here().wait()
        else:
            raise ValueError("Unrecognized ExecMode {}".format(exec_mode))

        rpc.rpc_sync(
            "worker{}".format(dst_rank), _set_rpc_done, args=(context_id, 1)
        )

        # Verify graph for current context id.
        ctx = dist_autograd._current_context()
        self.assertEqual(context_id, ctx._context_id())
        send_functions = ctx._send_functions()
        self.assertEqual(1, len(send_functions))
        recv_functions = ctx._recv_functions()
        self.assertEqual(1, len(recv_functions))
        self._verify_graph_for_first_rpc_call(
            list(send_functions.values())[0],
            list(recv_functions.values())[0],
            t1,
            t2,
            ret,
        )

        # Wait for the prev rank to be done with rpc.
        self._check_rpc_done(1)
        # Verify graph for previous context id.
        ctx = dist_autograd._retrieve_context(ctx_ids[1])
        send_functions = ctx._send_functions()
        self.assertEqual(1, len(send_functions))
        self._verify_graph_for_rpc_call_exec(list(send_functions.values())[0])

        # This barrier is needed so one worker does not clean up its
        # autograd context before another worker tries to access it.
        dist.barrier()

    # autograd context should be cleaned up by now.
    with self.assertRaises(RuntimeError):
        ctx = dist_autograd._retrieve_context(context_id)

    # No autograd context available.
    with self.assertRaises(RuntimeError):
        ctx = dist_autograd._current_context()

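# Illustrative sketch (hypothetical test names, not taken from the original
# file): how a parametrized helper such as `_test_graph` above would typically
# be driven from the same test class, once per ExecMode, here using the
# builtin torch.add as `fn`.
def test_graph_for_builtin_call(self):
    self._test_graph(torch.add, ExecMode.RPC_SYNC)


def test_graph_for_builtin_remote_call(self):
    self._test_graph(torch.add, ExecMode.REMOTE)
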