def wait_until_pending_futures_and_users_flushed(timeout: int = 20) -> None:
    '''
    The RRef protocol holds forkIds of rrefs in a map until those forks are
    confirmed by the owner. The message confirming the fork may arrive after
    our tests check whether this map is empty, which leads to failures and
    flaky tests. to_here also does not guarantee that we have finished
    processing the owner's confirmation message for the RRef. This function
    loops until the map is empty, which means the messages have been received
    as processed. Call this function before asserting the map returned by
    _get_debug_info is empty.

    Args:
        timeout: maximum number of seconds to wait before giving up.

    Raises:
        ValueError: if the pending futures and users are not flushed within
            ``timeout`` seconds.
    '''
    # Use a monotonic clock so the timeout is immune to wall-clock
    # adjustments (NTP steps, manual clock changes).
    start = time.monotonic()
    while True:
        debug_info = _rref_context_get_debug_info()
        num_pending_futures = int(debug_info["num_pending_futures"])
        num_pending_users = int(debug_info["num_pending_users"])
        if num_pending_futures == 0 and num_pending_users == 0:
            break
        # Confirmation messages arrive asynchronously from the owner; poll
        # at a coarse interval instead of busy-waiting.
        time.sleep(0.1)
        if time.monotonic() - start > timeout:
            raise ValueError(
                "Timed out waiting to flush pending futures and users, "
                "had {} pending futures and {} pending users".format(
                    num_pending_futures, num_pending_users
                )
            )
def wait_until_pending_users_flushed(timeout: int = 20) -> None:
    '''
    The RRef protocol holds forkIds of rrefs in a map until those forks are
    confirmed by the owner. The message confirming the fork may arrive after
    our tests check whether this map is empty, which leads to failures and
    flaky tests. to_here also does not guarantee that we have finished
    processing the owner's confirmation message for the RRef. This function
    loops until the map is empty, which means the messages have been received
    as processed. Call this function before asserting the map returned by
    _get_debug_info is empty.

    Args:
        timeout: maximum number of seconds to wait before giving up.

    Raises:
        ValueError: if the pending users are not flushed within ``timeout``
            seconds.
    '''
    start = time.monotonic()
    num_pending_users = int(
        _rref_context_get_debug_info()["num_pending_users"]
    )
    while num_pending_users != 0:
        time.sleep(0.1)
        num_pending_users = int(
            _rref_context_get_debug_info()["num_pending_users"]
        )
        # Fail loudly instead of hanging the test run forever, mirroring the
        # timeout behavior of wait_until_pending_futures_and_users_flushed.
        if time.monotonic() - start > timeout:
            raise ValueError(
                "Timed out waiting to flush pending users, "
                "had {} pending users".format(num_pending_users)
            )
def get_num_owners_and_forks() -> Tuple[str, str]:
    """
    Retrieves number of OwnerRRefs and forks on this node from
    _rref_context_get_debug_info.
    """
    info = _rref_context_get_debug_info()
    # Both counters come back as strings in the debug map; return them as-is.
    return info["num_owner_rrefs"], info["num_forks"]
def test_rref_context_debug_info(self):
    """
    Checks that the RRef context's ``num_owner_rrefs`` counter reflects
    RRefs shared over RPC: a purely local RRef is not counted, an RRef
    passed as an rpc_sync arg is, and each rpc.remote call adds one
    OwnerRRef on the callee.
    """
    # Lazily set up a gloo process group so dist.barrier() below works even
    # when the enclosing fixture has not initialized one.
    if not dist.is_initialized():
        dist.init_process_group(
            backend="gloo",
            init_method=self.init_method,
            rank=self.rank,
            world_size=self.world_size,
        )
    from torch.distributed.rpc import _rref_context_get_debug_info
    rref1 = RRef(self.rank)
    info = _rref_context_get_debug_info()
    self.assertIn("num_owner_rrefs", info)
    # RRef on local value is not added to context until shared across RPC
    self.assertEqual("0", info["num_owner_rrefs"])
    # Share rref1 with the next worker; the fork should register one
    # OwnerRRef on this node.
    dst_rank = (self.rank + 1) % self.world_size
    rpc.rpc_sync("worker{}".format(dst_rank), set_global_rref, args=(rref1, ))
    # NOTE(review): there is no barrier between the rpc_sync and this check,
    # so a peer's own set_global_rref could race with the count below —
    # presumably this is why a barrier-based variant of this test exists;
    # confirm before relying on this version.
    info = _rref_context_get_debug_info()
    self.assertIn("num_owner_rrefs", info)
    self.assertEqual("1", info["num_owner_rrefs"])
    rpc.rpc_sync("worker{}".format(dst_rank), clear_global_rref)
    # Two rpc.remote calls targeting this worker's neighbor; symmetrically,
    # the neighbor's calls create two OwnerRRefs here.
    rref2 = rpc.remote("worker{}".format(dst_rank), torch.add, args=(torch.ones(2, 2), 1))
    rref3 = rpc.remote("worker{}".format(dst_rank), torch.add, args=(torch.ones(2, 2), 1))
    rref2.to_here()
    rref3.to_here()
    # Use a barrier to make sure that OwnerRRefs are created on this worker
    # before checking debug info
    dist.barrier()
    info = _rref_context_get_debug_info()
    self.assertIn("num_owner_rrefs", info)
    self.assertEqual("2", info["num_owner_rrefs"])
    # Use another barrier to make sure that UserRRefs are only deleted after
    # checking debug info
    dist.barrier()
def test_debug_info(self):
    """
    Checks that _get_debug_info is exactly the union of the debug maps of
    the RRef context, the RPC agent, and distributed autograd, with no key
    collisions between any two of them.
    """
    # only test keys in this test case. Values should be covered by
    # individual module debug info tests
    from torch.distributed.rpc import (
        _get_debug_info,
        _rref_context_get_debug_info,
    )
    from torch.distributed.rpc.api import _agent
    import torch.distributed.autograd as dist_autograd

    info = _get_debug_info()
    rref_info = _rref_context_get_debug_info()
    agent_info = _agent.get_debug_info()
    autograd_info = dist_autograd._get_debug_info()
    # Check every PAIR of maps for overlap. A three-way intersection
    # (a & b & c) is weaker: it is empty even when just two of the maps
    # share a key, in which case the merged `expected` dict below would
    # silently drop an entry.
    self.assertEqual(0, len(rref_info.keys() & agent_info.keys()))
    self.assertEqual(0, len(rref_info.keys() & autograd_info.keys()))
    self.assertEqual(0, len(agent_info.keys() & autograd_info.keys()))
    # The aggregate must match the union of the three module maps exactly.
    expected = {}
    expected.update(rref_info)
    expected.update(agent_info)
    expected.update(autograd_info)
    self.assertEqual(expected, info)
def test_rref_context_debug_info(self):
    """
    Barrier-synchronized version of the RRef context debug-info check:
    verifies ``num_owner_rrefs`` for (1) a purely local RRef, (2) an RRef
    shared as an rpc_sync argument, and (3) RRefs created via rpc.remote.
    """
    # This test checks local states that are modified by remote workers.
    # This means that we would need barrier before and after every check.
    # The barrier before the check makes sure that all previous states are
    # cleared globally, the barrier after ensures that no following states
    # change gets into the current check.
    if not dist.is_initialized():
        dist.init_process_group(
            backend="gloo",
            init_method=self.init_method,
            rank=self.rank,
            world_size=self.world_size,
        )
    from torch.distributed.rpc import _rref_context_get_debug_info
    # Check 1: local RRef does not update owners_ map
    #################################################
    rref1 = RRef(self.rank)
    # don't need a barrier here as local RRef is handled by this thread
    info = _rref_context_get_debug_info()
    self.assertIn("num_owner_rrefs", info)
    # RRef on local value is not added to context until shared across RPC
    self.assertEqual(0, int(info["num_owner_rrefs"]))
    # barrier after the check 1
    dist.barrier()
    # Check 2: Sharing RRef as an arg should update owners_ map
    ###########################################################
    dst_rank = (self.rank + 1) % self.world_size
    rpc.rpc_sync(
        "worker{}".format(dst_rank),
        set_global_rref,
        args=(rref1,)
    )
    # barrier before check 2
    dist.barrier()
    info = _rref_context_get_debug_info()
    self.assertIn("num_owner_rrefs", info)
    self.assertEqual(1, int(info["num_owner_rrefs"]))
    # barrier after check 2
    dist.barrier()
    # clear states for check 2
    rpc.rpc_sync("worker{}".format(dst_rank), clear_global_rref)
    # Check 3: rpc.remote call should update owners_ map
    ####################################################
    # Each worker issues two remote calls to its neighbor; symmetrically,
    # this worker ends up owning two OwnerRRefs created by its peer.
    rref2 = rpc.remote(
        "worker{}".format(dst_rank),
        torch.add,
        args=(torch.ones(2, 2), 1)
    )
    rref3 = rpc.remote(
        "worker{}".format(dst_rank),
        torch.add,
        args=(torch.ones(2, 2), 1)
    )
    rref2.to_here()
    rref3.to_here()
    # barrier before check 3
    dist.barrier()
    info = _rref_context_get_debug_info()
    self.assertIn("num_owner_rrefs", info)
    self.assertEqual(2, int(info["num_owner_rrefs"]))
    # barrier after check 3
    dist.barrier()