def test_graph_for_py_nested_call(self):
    dst_rank = (self.rank + 1) % self.world_size
    with dist_autograd.context() as context_id:
        t1 = torch.ones(3, 3, requires_grad=True)
        t2 = torch.zeros(3, 3, requires_grad=True)
        nest_dst_rank = (dst_rank + 1) % self.world_size
        ret = rpc.rpc_sync("worker{}".format(dst_rank),
                           my_py_nested_call,
                           args=(t1, t2, dst_rank, self.world_size, 1))
        for rd in [1, 2, 3]:
            rpc.rpc_sync("worker{}".format((self.rank + rd) % self.world_size),
                         _set_rpc_done,
                         args=(context_id, rd))

        # For self.rank, there are 4 graphs to verify:
        # - the current context id, where this rank sends the first rpc call;
        # - the prev context id, where this rank makes the 1st nested call;
        # - the prev prev context id, where this rank makes the 2nd nested
        #   call;
        # - the prev prev prev context id, where this rank executes the
        #   torch.add() operator.

        # Verify first graph for current context id.
        ctx = dist_autograd._current_context()
        self.assertEqual(context_id, ctx._context_id())
        send_functions = ctx._send_functions()
        self.assertEqual(1, len(send_functions))
        recv_functions = ctx._recv_functions()
        self.assertEqual(1, len(recv_functions))
        self._verify_graph_for_first_rpc_call(
            list(send_functions.values())[0],
            list(recv_functions.values())[0],
            t1, t2, ret)

        # Verify second graph for 1st nested call.
        self._check_rpc_done(1)
        ctx = dist_autograd._retrieve_context(ctx_ids[1])
        self._verify_graph_for_nested_rpc_call(ctx)

        # Verify third graph for 2nd nested call.
        self._check_rpc_done(2)
        ctx = dist_autograd._retrieve_context(ctx_ids[2])
        self._verify_graph_for_nested_rpc_call(ctx)

        # Verify last graph for rpc call execution.
        self._check_rpc_done(3)
        ctx = dist_autograd._retrieve_context(ctx_ids[3])
        send_functions = ctx._send_functions()
        self.assertEqual(1, len(send_functions))
        self._verify_graph_for_rpc_call_exec(
            list(send_functions.values())[0])

        # this barrier is needed so one worker does not clean up their
        # autograd context before another worker tries to access it.
        dist.barrier()
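# The tests in this excerpt rely on the module-level helpers `_set_rpc_done`
# and `_check_rpc_done` together with the `rpc_done` and `ctx_ids` globals,
# none of which are defined here. The sketch below is an assumption
# reconstructed from how the tests call them: `_set_rpc_done` runs on the
# callee and records the caller's context id keyed by how many RPC hops away
# the caller is, while `_check_rpc_done` polls until that entry appears (the
# tests invoke it as `self._check_rpc_done`, so the test class presumably
# wraps a helper like this one). Exact names and signatures are not verified.

import time

# rank_distance -> "the caller this many hops away has finished its rpc"
rpc_done = [False, False, False, False]
# rank_distance -> autograd context id recorded by that caller
ctx_ids = [-1, -1, -1, -1]


def _set_rpc_done(ctx_id, rank_distance):
    global rpc_done
    global ctx_ids
    ctx_ids[rank_distance] = ctx_id
    rpc_done[rank_distance] = True


def _check_rpc_done(rank_distance):
    # Busy-wait until the caller at the given distance has reported done.
    while not rpc_done[rank_distance]:
        time.sleep(0.1)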
def _all_contexts_cleaned_up(num_contexts, timeout_seconds=10):
    global known_context_ids
    start = time.time()
    context_id_to_raised = {}
    while time.time() - start < timeout_seconds:
        for context_id in known_context_ids:
            try:
                dist_autograd._retrieve_context(context_id)
            except RuntimeError:
                context_id_to_raised[context_id] = True
        if len(context_id_to_raised) == num_contexts:
            break
    # All contexts have been cleaned up if trying to retrieve any context
    # resulted in a RuntimeError.
    success = (len(context_id_to_raised) == num_contexts
               and all(context_id_to_raised.values()))
    return success
def test_context_cleanup_many_workers(self):
    global known_context_ids
    dst_ranks = {rank for rank in range(self.world_size) if rank != self.rank}
    with dist_autograd.context() as context_id:
        t1 = torch.ones(3, 3, requires_grad=True)
        t2 = torch.zeros(3, 3, requires_grad=True)
        for dst_rank in dst_ranks:
            ret = rpc.rpc_sync("worker{}".format(dst_rank), torch.add,
                               args=(t1, t2))
            rpc.rpc_sync("worker{}".format(dst_rank), _store_context_id,
                         args=(context_id,))

    # The thread's context id should be cleaned up.
    with self.assertRaises(RuntimeError):
        dist_autograd._retrieve_context(context_id)
    # Check that all contexts have been cleaned up.
    success = _all_contexts_cleaned_up(num_contexts=len(dst_ranks))
    self.assertTrue(success)
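# `test_context_cleanup_many_workers` sends `_store_context_id` to each peer,
# and `_all_contexts_cleaned_up` then polls the `known_context_ids` global
# that it populates. Neither definition appears in this excerpt; the sketch
# below is an assumption about their shape based solely on that usage.

known_context_ids = set()


def _store_context_id(context_id):
    # Runs on the callee: remember the caller's autograd context id so that
    # the cleanup check can later try (and fail) to retrieve it.
    global known_context_ids
    known_context_ids.add(context_id)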
def test_graph_for_py_nested_call_itself(self):
    dst_rank = (self.rank + 1) % self.world_size
    with dist_autograd.context() as context_id:
        t1 = torch.ones(3, 3, requires_grad=True)
        t2 = torch.zeros(3, 3, requires_grad=True)
        ret = rpc.rpc_sync("worker{}".format(dst_rank),
                           my_py_nested_call,
                           args=(t1, t2,
                                 (self.rank - 1 + self.world_size) % self.world_size,
                                 self.world_size,
                                 0))
        rpc.rpc_sync("worker{}".format((self.rank + 1) % self.world_size),
                     _set_rpc_done, args=(context_id, 1))

        # For self.rank, there are 2 graphs to verify:
        # - the current context id, where this rank sends the first rpc call
        #   and executes the torch.add() operator;
        # - the prev context id, where this rank makes the nested call.
        ctx = dist_autograd._current_context()
        self.assertEqual(context_id, ctx._context_id())
        send_functions = ctx._send_functions()
        self.assertEqual(2, len(send_functions))
        recv_functions = ctx._recv_functions()
        self.assertEqual(2, len(recv_functions))
        self._verify_graph_for_first_rpc_call(
            list(send_functions.values())[0],
            list(recv_functions.values())[1],
            t1, t2, ret)
        self._verify_graph_for_rpc_call_exec(
            list(send_functions.values())[1])

        # Verify two pairs of send and recv functions for the nested call.
        self._check_rpc_done(1)
        ctx = dist_autograd._retrieve_context(ctx_ids[1])
        self._verify_graph_for_nested_rpc_call(ctx)

        # this barrier is needed so one worker does not clean up their
        # autograd context before another worker tries to access it.
        dist.barrier()
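# `my_py_nested_call` is the RPC target of several tests above but is not
# defined in this excerpt. Judging from how it is called and from the graphs
# the tests expect, it most likely forwards the tensors to the next worker in
# the ring, one hop per call, and runs torch.add on the final hop. The sketch
# below is a reconstruction from that usage, not the verified definition.


def my_py_nested_call(t1, t2, dst, world_size, hops):
    next_dst = (dst + 1) % world_size
    if hops > 0:
        # Keep forwarding: each hop adds another send/recv pair to the
        # distributed autograd graph of the originating context.
        return rpc.rpc_sync("worker{}".format(next_dst), my_py_nested_call,
                            args=(t1, t2, next_dst, world_size, hops - 1))
    else:
        # Last hop: the next worker executes the actual torch.add operator.
        return rpc.rpc_sync("worker{}".format(next_dst), torch.add,
                            args=(t1, t2))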
def test_autograd_context(self):
    context_ids = []
    for i in range(1000):
        with dist_autograd.context() as context_id:
            self.assertEqual(
                context_id,
                dist_autograd._retrieve_context(context_id)._context_id())
            # First 16 bits should be worker_id.
            self.assertEqual(self.worker_id, context_id >> 48)
            context_ids.append(context_id)

    for context_id in context_ids:
        with self.assertRaisesRegex(
                RuntimeError,
                'Could not find autograd context with id: {}'.format(context_id)):
            dist_autograd._retrieve_context(context_id)
def _test_graph(self, fn):
    dst_rank = (self.rank + 1) % self.world_size
    with dist_autograd.context() as context_id:
        t1 = torch.ones(3, 3, requires_grad=True)
        t2 = torch.zeros(3, 3, requires_grad=True)
        ret = rpc.rpc_sync("worker{}".format(dst_rank), fn, args=(t1, t2))
        rpc.rpc_sync("worker{}".format(dst_rank), _set_rpc_done,
                     args=(context_id, 1))

        # Verify graph for current context id.
        ctx = dist_autograd._current_context()
        self.assertEqual(context_id, ctx._context_id())
        send_functions = ctx._send_functions()
        self.assertEqual(1, len(send_functions))
        recv_functions = ctx._recv_functions()
        self.assertEqual(1, len(recv_functions))
        self._verify_graph_for_first_rpc_call(
            list(send_functions.values())[0],
            list(recv_functions.values())[0],
            t1, t2, ret)

        # Wait for the prev rank to be done with rpc.
        self._check_rpc_done(1)
        # Verify graph for previous context id.
        ctx = dist_autograd._retrieve_context(ctx_ids[1])
        send_functions = ctx._send_functions()
        self.assertEqual(1, len(send_functions))
        self._verify_graph_for_rpc_call_exec(
            list(send_functions.values())[0])
        # this barrier is needed so one worker does not clean up their
        # autograd context before another worker tries to access it.
        dist.barrier()

    # autograd context should be cleaned up by now.
    with self.assertRaises(RuntimeError):
        ctx = dist_autograd._retrieve_context(context_id)

    # No autograd context available.
    with self.assertRaises(RuntimeError):
        ctx = dist_autograd._current_context()
def test_autograd_context(self):
    # Verify max possible id.
    max_auto_increment = 281474976710655
    self.assertEqual(max_auto_increment + (self.worker_id << 48),
                     dist_autograd._get_max_id())

    context_ids = []
    for i in range(1000):
        with dist_autograd.context() as context_id:
            self.assertEqual(
                context_id,
                dist_autograd._retrieve_context(context_id)._context_id())
            # First 16 bits should be worker_id.
            self.assertEqual(self.worker_id, context_id >> 48)
            context_ids.append(context_id)

    for context_id in context_ids:
        with self.assertRaisesRegex(
                RuntimeError,
                'Could not find autograd context with id: {}'.format(context_id)):
            dist_autograd._retrieve_context(context_id)
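# The two assertions above imply the context id layout: the top 16 bits hold
# the worker_id and the low 48 bits hold a per-worker auto-increment counter,
# so the largest id a worker can mint is (worker_id << 48) | (2**48 - 1). The
# snippet below is plain Python illustrating that packing; the helper names
# are invented for illustration and are not part of the dist_autograd API.


def _pack_context_id(worker_id, local_id):
    # worker_id occupies bits 48..63, local_id occupies bits 0..47.
    assert 0 <= worker_id < 2 ** 16 and 0 <= local_id < 2 ** 48
    return (worker_id << 48) | local_id


def _unpack_context_id(context_id):
    return context_id >> 48, context_id & (2 ** 48 - 1)


assert 281474976710655 == 2 ** 48 - 1
assert _unpack_context_id(_pack_context_id(3, 42)) == (3, 42)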
def test_no_graph_with_tensors_not_require_grad(self):
    dst_rank = (self.rank + 1) % self.world_size
    with dist_autograd.context() as context_id:
        t1 = torch.ones(3, 3, requires_grad=False)
        t2 = torch.zeros(3, 3, requires_grad=False)
        ret = rpc.rpc_sync("worker{}".format(dst_rank), torch.add,
                           args=(t1, t2))
        rpc.rpc_sync("worker{}".format(dst_rank), _set_rpc_done,
                     args=(context_id, 1))

        ctx = dist_autograd._current_context()
        send_functions = ctx._send_functions()
        self.assertEqual(len(send_functions), 0)
        recv_functions = ctx._recv_functions()
        self.assertEqual(len(recv_functions), 0)

        # Wait for the prev rank to be done with rpc.
        self._check_rpc_done(1)
        # prev context id is not passed over as tensors do not require grads
        with self.assertRaises(RuntimeError):
            ctx = dist_autograd._retrieve_context(ctx_ids[1])
def _test_no_graph_with_tensors_not_require_grad(self, exec_mode):
    dst_rank = (self.rank + 1) % self.world_size
    with dist_autograd.context() as context_id:
        t1 = torch.ones(3, 3, requires_grad=False)
        t2 = torch.zeros(3, 3, requires_grad=False)
        if ExecMode.RPC_SYNC == exec_mode:
            ret = rpc.rpc_sync("worker{}".format(dst_rank), torch.add,
                               args=(t1, t2))
        elif ExecMode.REMOTE == exec_mode:
            ret = rpc.remote("worker{}".format(dst_rank), torch.add,
                             args=(t1, t2)).to_here().wait()
        else:
            raise ValueError("Unrecognized ExecMode {}".format(exec_mode))

        rpc.rpc_sync("worker{}".format(dst_rank), _set_rpc_done,
                     args=(context_id, 1))

        ctx = dist_autograd._current_context()
        send_functions = ctx._send_functions()
        self.assertEqual(len(send_functions), 0)
        recv_functions = ctx._recv_functions()
        self.assertEqual(len(recv_functions), 0)

        # Wait for the prev rank to be done with rpc.
        self._check_rpc_done(1)
        if ExecMode.RPC_SYNC == exec_mode:
            # prev context id is not passed over as tensors do not require
            # grads
            with self.assertRaises(RuntimeError):
                ctx = dist_autograd._retrieve_context(ctx_ids[1])
        elif ExecMode.REMOTE == exec_mode:
            # NB: RRef.to_here() always passes the autograd context to the
            # callee, as the caller does not know whether the return value
            # would contain a requires_grad tensor or not.
            pass
def test_autograd_send_function(self):
    dst_rank = (self.rank + 1) % self.world_size
    with dist_autograd.context() as context_id:
        t1 = torch.ones(3, 3, requires_grad=True)
        t2 = torch.zeros(3, 3, requires_grad=True)
        ret = rpc.rpc_sync("worker{}".format(dst_rank), torch.add,
                           args=(t1, t2))

        # Get send function.
        ctx = dist_autograd._current_context()
        self.assertEqual(context_id, ctx._context_id())
        send_functions = ctx._send_functions()
        self.assertEqual(1, len(send_functions))

        # Retrieve the next functions in the graph.
        next_funcs = list(send_functions.values())[0].next_functions
        self.assertEqual(2, len(next_funcs))

        # We should now hit t1 and t2 in the autograd graph.
        self.assertEqual('torch::autograd::AccumulateGrad',
                         next_funcs[0][0].name())
        self.assertEqual(t1, next_funcs[0][0].variable)
        self.assertEqual(0, next_funcs[0][1])
        self.assertEqual('torch::autograd::AccumulateGrad',
                         next_funcs[1][0].name())
        self.assertEqual(t2, next_funcs[1][0].variable)
        self.assertEqual(0, next_funcs[1][1])

    # autograd context should be cleaned up by now.
    with self.assertRaises(RuntimeError):
        ctx = dist_autograd._retrieve_context(context_id)

    # No autograd context available.
    with self.assertRaises(RuntimeError):
        ctx = dist_autograd._current_context()
def test_autograd_functions(self):
    dst_rank = (self.rank + 1) % self.world_size
    with dist_autograd.context() as context_id:
        t1 = torch.ones(3, 3, requires_grad=True)
        t2 = torch.zeros(3, 3, requires_grad=True)
        ret = rpc.rpc_sync("worker{}".format(dst_rank), torch.add,
                           args=(t1, t2))
        rpc.rpc_sync("worker{}".format(dst_rank), _set_rpc_done,
                     args=(context_id,))

        # Get send function.
        ctx = dist_autograd._current_context()
        self.assertEqual(context_id, ctx._context_id())
        send_functions = ctx._send_functions()
        self.assertEqual(1, len(send_functions))

        # Retrieve the next functions in the graph.
        next_funcs = list(send_functions.values())[0].next_functions
        self.assertEqual(2, len(next_funcs))

        # We should now hit t1 and t2 in the autograd graph.
        self.assertEqual("torch::autograd::AccumulateGrad",
                         next_funcs[0][0].name())
        self.assertEqual(t1, next_funcs[0][0].variable)
        self.assertEqual(0, next_funcs[0][1])
        self.assertEqual("torch::autograd::AccumulateGrad",
                         next_funcs[1][0].name())
        self.assertEqual(t2, next_funcs[1][0].variable)
        self.assertEqual(0, next_funcs[1][1])

        # Test recv functions.
        recv_functions = ctx._recv_functions()
        self.assertEqual(1, len(recv_functions))
        self.assertEqual(ret.grad_fn, list(recv_functions.values())[0])

        # We should have send/recv functions from the previous rank; get all
        # contexts in this node to find them.

        # Wait for the prev rank to be done with rpc.
        while not prev_rank_rpc_done:
            time.sleep(0.1)

        # Now verify the autograd graph.
        ctx = dist_autograd._retrieve_context(prev_rank_context_id)

        # Get the send function.
        send_functions = ctx._send_functions()
        self.assertEqual(1, len(send_functions))

        # Verify next function is AddBackward0.
        next_funcs = list(send_functions.values())[0].next_functions
        self.assertEqual(1, len(next_funcs))
        add_backward_fn = next_funcs[0][0]
        self.assertEqual("AddBackward0", add_backward_fn.name())

        # Verify the next two functions are the same recv backward function.
        next_funcs = add_backward_fn.next_functions
        self.assertEqual(2, len(next_funcs))
        self.assertEqual("torch::distributed::autograd::RecvRpcBackward",
                         next_funcs[0][0].name())
        self.assertEqual("torch::distributed::autograd::RecvRpcBackward",
                         next_funcs[1][0].name())
        self.assertEqual(next_funcs[0][0], next_funcs[1][0])

    # autograd context should be cleaned up by now.
    with self.assertRaises(RuntimeError):
        ctx = dist_autograd._retrieve_context(context_id)

    # No autograd context available.
    with self.assertRaises(RuntimeError):
        ctx = dist_autograd._current_context()
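# Several tests above call self._verify_graph_for_first_rpc_call(...), which
# is not defined in this excerpt. Judging from the inline assertions in
# test_autograd_functions, it likely checks that the send function feeds two
# AccumulateGrad nodes (for t1 and t2) and that the recv function is the
# grad_fn of the returned tensor. The method below is a hedged sketch of such
# a helper, not the verified implementation.


def _verify_graph_for_first_rpc_call(self, send_function, recv_function,
                                     t1, t2, ret):
    # The send function should have two inputs, one per tensor argument.
    next_funcs = send_function.next_functions
    self.assertEqual(2, len(next_funcs))

    # Both inputs should be AccumulateGrad nodes wrapping t1 and t2.
    self.assertEqual("torch::autograd::AccumulateGrad", next_funcs[0][0].name())
    self.assertEqual(t1, next_funcs[0][0].variable)
    self.assertEqual("torch::autograd::AccumulateGrad", next_funcs[1][0].name())
    self.assertEqual(t2, next_funcs[1][0].variable)

    # The recv function is where the returned tensor's backward graph starts.
    self.assertEqual(ret.grad_fn, recv_function)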
def _test_graph_for_py_nested_call_itself(self, exec_mode):
    dst_rank = (self.rank + 1) % self.world_size

    # This is for the below `dist.barrier`.
    # For `RpcAgent` other than `ProcessGroupAgent`,
    # no `_default_pg` is initialized.
    if not dist.is_initialized():
        dist.init_process_group(
            backend="gloo",
            init_method=self.init_method,
            rank=self.rank,
            world_size=self.world_size,
        )

    with dist_autograd.context() as context_id:
        t1 = torch.ones(3, 3, requires_grad=True)
        t2 = torch.zeros(3, 3, requires_grad=True)
        if ExecMode.RPC_SYNC == exec_mode:
            ret = rpc.rpc_sync(
                "worker{}".format(dst_rank),
                my_py_nested_call,
                args=(t1, t2,
                      (self.rank - 1 + self.world_size) % self.world_size,
                      self.world_size,
                      0))
        elif ExecMode.REMOTE == exec_mode:
            ret = rpc.remote(
                "worker{}".format(dst_rank),
                my_py_nested_call,
                args=(t1, t2,
                      (self.rank - 1 + self.world_size) % self.world_size,
                      self.world_size,
                      0)).to_here().wait()
        else:
            raise ValueError("Unrecognized ExecMode {}".format(exec_mode))

        rpc.rpc_sync("worker{}".format((self.rank + 1) % self.world_size),
                     _set_rpc_done, args=(context_id, 1))

        # For self.rank, there are 2 graphs to verify:
        # - the current context id, where this rank sends the first rpc call
        #   and executes the torch.add() operator;
        # - the prev context id, where this rank makes the nested call.
        ctx = dist_autograd._current_context()
        self.assertEqual(context_id, ctx._context_id())
        send_functions = ctx._send_functions()
        self.assertEqual(2, len(send_functions))
        recv_functions = ctx._recv_functions()
        self.assertEqual(2, len(recv_functions))
        self._verify_graph_for_first_rpc_call(
            list(send_functions.values())[0],
            list(recv_functions.values())[1],
            t1, t2, ret)
        self._verify_graph_for_rpc_call_exec(
            list(send_functions.values())[1])

        # Verify two pairs of send and recv functions for the nested call.
        self._check_rpc_done(1)
        ctx = dist_autograd._retrieve_context(ctx_ids[1])
        self._verify_graph_for_nested_rpc_call(ctx)

        # this barrier is needed so one worker does not clean up their
        # autograd context before another worker tries to access it.
        dist.barrier()
def _test_graph(self, fn, exec_mode):
    dst_rank = (self.rank + 1) % self.world_size

    # This is for the below `dist.barrier`.
    # For `RpcAgent` other than `ProcessGroupAgent`,
    # no `_default_pg` is initialized.
    if not dist.is_initialized():
        dist.init_process_group(
            backend="gloo",
            init_method=self.init_method,
            rank=self.rank,
            world_size=self.world_size,
        )

    with dist_autograd.context() as context_id:
        t1 = torch.ones(3, 3, requires_grad=True)
        t2 = torch.zeros(3, 3, requires_grad=True)
        if ExecMode.RPC_SYNC == exec_mode:
            ret = rpc.rpc_sync("worker{}".format(dst_rank), fn, args=(t1, t2))
        elif ExecMode.REMOTE == exec_mode:
            ret = rpc.remote("worker{}".format(dst_rank), fn,
                             args=(t1, t2)).to_here().wait()
        else:
            raise ValueError("Unrecognized ExecMode {}".format(exec_mode))

        rpc.rpc_sync("worker{}".format(dst_rank), _set_rpc_done,
                     args=(context_id, 1))

        # Verify graph for current context id.
        ctx = dist_autograd._current_context()
        self.assertEqual(context_id, ctx._context_id())
        send_functions = ctx._send_functions()
        self.assertEqual(1, len(send_functions))
        recv_functions = ctx._recv_functions()
        self.assertEqual(1, len(recv_functions))
        self._verify_graph_for_first_rpc_call(
            list(send_functions.values())[0],
            list(recv_functions.values())[0],
            t1, t2, ret)

        # Wait for the prev rank to be done with rpc.
        self._check_rpc_done(1)
        # Verify graph for previous context id.
        ctx = dist_autograd._retrieve_context(ctx_ids[1])
        send_functions = ctx._send_functions()
        self.assertEqual(1, len(send_functions))
        self._verify_graph_for_rpc_call_exec(
            list(send_functions.values())[0])
        # this barrier is needed so one worker does not clean up their
        # autograd context before another worker tries to access it.
        dist.barrier()

    # autograd context should be cleaned up by now.
    with self.assertRaises(RuntimeError):
        ctx = dist_autograd._retrieve_context(context_id)

    # No autograd context available.
    with self.assertRaises(RuntimeError):
        ctx = dist_autograd._current_context()