Code example #1
File: dist_utils.py Project: zoltanszekely21/pytorch
import time

from torch.distributed.rpc import _rref_context_get_debug_info


def wait_until_pending_futures_and_users_flushed(timeout: int = 20) -> None:
    '''
    The RRef protocol holds forkIds of rrefs in a map until those forks are
    confirmed by the owner. The message confirming the fork may arrive after
    our tests check whether this map is empty, which leads to failures and
    flaky tests. to_here also does not guarantee that we have finished
    processing the owner's confirmation message for the RRef. This function
    loops until the map is empty, which means the messages have been received
    and processed. Call this function before asserting that the map returned by
    _get_debug_info is empty.
    '''
    start = time.time()
    while True:
        debug_info = _rref_context_get_debug_info()
        num_pending_futures = int(debug_info["num_pending_futures"])
        num_pending_users = int(debug_info["num_pending_users"])
        if num_pending_futures == 0 and num_pending_users == 0:
            break
        time.sleep(0.1)
        if time.time() - start > timeout:
            raise ValueError(
                "Timed out waiting to flush pending futures and users, had {} pending futures and {} pending users".format(
                    num_pending_futures, num_pending_users
                )
            )
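
A minimal usage sketch (not from the original project) showing where the helper fits in a test. It assumes RPC is already initialized, that the dist_utils module above is importable, and that dst_worker names a peer worker:

import torch
import torch.distributed.rpc as rpc
from torch.distributed.rpc import _rref_context_get_debug_info
from dist_utils import wait_until_pending_futures_and_users_flushed  # assumed import path

def assert_no_pending_rref_state(dst_worker: str) -> None:
    # Creating a UserRRef via rpc.remote registers a fork that the owner
    # confirms asynchronously; to_here() does not wait for that confirmation.
    rref = rpc.remote(dst_worker, torch.add, args=(torch.ones(2, 2), 1))
    rref.to_here()
    # Block until all fork confirmations have been received and processed ...
    wait_until_pending_futures_and_users_flushed()
    # ... so the pending maps can safely be asserted empty.
    info = _rref_context_get_debug_info()
    assert int(info["num_pending_futures"]) == 0
    assert int(info["num_pending_users"]) == 0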
Code example #2
File: dist_utils.py Project: zmmyc/pytorch
import time

from torch.distributed.rpc import _rref_context_get_debug_info


def wait_until_pending_users_flushed():
    '''
    The RRef protocol holds forkIds of rrefs in a map until those forks are
    confirmed by the owner. The message confirming the fork may arrive after
    our tests check whether this map is empty, which leads to failures and
    flaky tests. to_here also does not guarantee that we have finished
    processing the owner's confirmation message for the RRef. This function
    loops until the map is empty, which means the messages have been received
    and processed. Call this function before asserting that the map returned by
    _get_debug_info is empty.
    '''
    num_pending_users = int(_rref_context_get_debug_info()["num_pending_users"])
    while num_pending_users != 0:
        time.sleep(0.1)
        num_pending_users = int(_rref_context_get_debug_info()["num_pending_users"])
    return
Code example #3
File: dist_utils.py Project: skidipap/pytorch
from typing import Tuple

from torch.distributed.rpc import _rref_context_get_debug_info


def get_num_owners_and_forks() -> Tuple[str, str]:
    """
    Retrieves the number of OwnerRRefs and forks on this node from
    _rref_context_get_debug_info.
    """
    rref_dbg_info = _rref_context_get_debug_info()
    num_owners = rref_dbg_info["num_owner_rrefs"]
    num_forks = rref_dbg_info["num_forks"]
    return num_owners, num_forks
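
Both counts come back as strings taken directly from the debug-info map (the tests below compare them either as strings or after an int conversion). A short, hypothetical usage sketch; the expected values are illustrative only:

# Hypothetical check inside a test; the expected counts depend on how many
# OwnerRRefs and forks the test has created on this node so far.
num_owners, num_forks = get_num_owners_and_forks()
assert int(num_owners) == 2
assert int(num_forks) == 2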
Code example #4
    def test_rref_context_debug_info(self):
        if not dist.is_initialized():
            dist.init_process_group(
                backend="gloo",
                init_method=self.init_method,
                rank=self.rank,
                world_size=self.world_size,
            )

        from torch.distributed.rpc import _rref_context_get_debug_info
        rref1 = RRef(self.rank)
        info = _rref_context_get_debug_info()
        self.assertIn("num_owner_rrefs", info)
        # RRef on local value is not added to context until shared across RPC
        self.assertEqual("0", info["num_owner_rrefs"])

        dst_rank = (self.rank + 1) % self.world_size
        rpc.rpc_sync("worker{}".format(dst_rank),
                     set_global_rref,
                     args=(rref1, ))
        info = _rref_context_get_debug_info()
        self.assertIn("num_owner_rrefs", info)
        self.assertEqual("1", info["num_owner_rrefs"])
        rpc.rpc_sync("worker{}".format(dst_rank), clear_global_rref)

        rref2 = rpc.remote("worker{}".format(dst_rank),
                           torch.add,
                           args=(torch.ones(2, 2), 1))
        rref3 = rpc.remote("worker{}".format(dst_rank),
                           torch.add,
                           args=(torch.ones(2, 2), 1))
        rref2.to_here()
        rref3.to_here()

        # Use a barrier to make sure that OwnerRRefs are created on this worker
        # before checking debug info
        dist.barrier()
        info = _rref_context_get_debug_info()
        self.assertIn("num_owner_rrefs", info)
        self.assertEqual("2", info["num_owner_rrefs"])

        # Use another barrier to make sure that UserRRefs are only deleted after
        # checking debug info
        dist.barrier()
Code example #5
    def test_debug_info(self):
        # Only test keys in this test case; values are covered by the
        # individual module debug info tests.
        from torch.distributed.rpc import (_get_debug_info,
                                           _rref_context_get_debug_info)
        from torch.distributed.rpc.api import _agent
        import torch.distributed.autograd as dist_autograd

        info = _get_debug_info()
        rref_info = _rref_context_get_debug_info()
        agent_info = _agent.get_debug_info()
        autograd_info = dist_autograd._get_debug_info()
        common_keys = (
            rref_info.keys() & agent_info.keys() & autograd_info.keys()
        )
        self.assertEqual(0, len(common_keys))
        expected = {}
        expected.update(rref_info)
        expected.update(agent_info)
        expected.update(autograd_info)
        self.assertEqual(expected, info)
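
Because _get_debug_info() merges the RRef context, RPC agent, and distributed autograd maps, and the test above asserts that their key sets are disjoint, a flat dump loses no information. A hypothetical diagnostic helper, assuming RPC is initialized:

from torch.distributed.rpc import _get_debug_info

def dump_rpc_debug_info() -> None:
    # Keys from the RRef context, the RPC agent, and distributed autograd do
    # not overlap, so printing the merged map is lossless.
    for key, value in sorted(_get_debug_info().items()):
        print("{}: {}".format(key, value))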
Code example #6
File: rpc_test.py Project: yangcc2019/pytorch
    def test_rref_context_debug_info(self):
        # This test checks local states that are modified by remote workers.
        # This means that we need a barrier before and after every check.
        # The barrier before a check makes sure that all previous states have
        # been cleared globally; the barrier after ensures that no subsequent
        # state changes leak into the current check.
        if not dist.is_initialized():
            dist.init_process_group(
                backend="gloo",
                init_method=self.init_method,
                rank=self.rank,
                world_size=self.world_size,
            )

        from torch.distributed.rpc import _rref_context_get_debug_info
        # Check 1: local RRef does not update owners_ map
        #################################################

        rref1 = RRef(self.rank)

        # don't need a barrier here as local RRef is handled by this thread
        info = _rref_context_get_debug_info()
        self.assertIn("num_owner_rrefs", info)
        # RRef on local value is not added to context until shared across RPC
        self.assertEqual(0, int(info["num_owner_rrefs"]))

        # barrier after the check 1
        dist.barrier()

        # Check 2: Sharing RRef as an arg should update owners_ map
        ###########################################################

        dst_rank = (self.rank + 1) % self.world_size
        rpc.rpc_sync(
            "worker{}".format(dst_rank),
            set_global_rref,
            args=(rref1,)
        )

        # barrier before check 2
        dist.barrier()

        info = _rref_context_get_debug_info()
        self.assertIn("num_owner_rrefs", info)
        self.assertEqual(1, int(info["num_owner_rrefs"]))

        # barrier after check 2
        dist.barrier()

        # clear states for check 2
        rpc.rpc_sync("worker{}".format(dst_rank), clear_global_rref)

        # Check 3: rpc.remote call should update owners_ map
        ####################################################
        rref2 = rpc.remote(
            "worker{}".format(dst_rank),
            torch.add,
            args=(torch.ones(2, 2), 1)
        )
        rref3 = rpc.remote(
            "worker{}".format(dst_rank),
            torch.add,
            args=(torch.ones(2, 2), 1)
        )
        rref2.to_here()
        rref3.to_here()

        # barrier before check 3
        dist.barrier()

        info = _rref_context_get_debug_info()
        self.assertIn("num_owner_rrefs", info)
        self.assertEqual(2, int(info["num_owner_rrefs"]))

        # barrier after check 3
        dist.barrier()