Example 1
 def test_rref_timeout_pickle_script_func(self):
     # Similar to the test above, but calls a Python RPC with a script function.
     if self.rank != 0:
         return
     dst_rank = (self.rank + 1) % self.world_size
     dst_worker = "worker{}".format(dst_rank)
     rref = rpc.remote(
         dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
     )
     # Will ensure error handling callbacks are run.
     wait_until_pending_futures_and_users_flushed()
     # Call RPC with script function that takes RRef, ensure timeout during pickling
     with self.assertRaisesRegex(RuntimeError, "RRef creation"):
         rpc.rpc_sync(dst_worker, rref_to_here, args=(rref, ))
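The rref_to_here helper referenced above is defined elsewhere in the test suite. A minimal TorchScript sketch consistent with how it is used here (the name and signature are assumptions, not the verbatim definition):

import torch
from torch import Tensor
from torch.distributed.rpc import RRef


@torch.jit.script
def rref_to_here(rref_var: RRef[Tensor]) -> Tensor:
    # Block until the value owned by the remote worker is fetched locally.
    return rref_var.to_here()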
Example 2
 def test_rref_timeout_pickle_in_jit(self):
     if self.rank != 0:
         return
     dst_rank = (self.rank + 1) % self.world_size
     dst_worker = "worker{}".format(dst_rank)
     rref = rpc.remote(
         dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
     )
     # Will ensure error handling callbacks are run.
     wait_until_pending_futures_and_users_flushed()
     # Call RPC with an RRef arg in JIT, which goes through JIT pickling, and
     # ensure an error is raised.
     with self.assertRaisesRegex(RuntimeError, "RRef creation"):
         rpc_async_with_rref_arg(dst_worker, (rref, ))
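rpc_async_with_rref_arg is likewise defined outside this excerpt. A plausible sketch, assuming the rref_to_here script function from the previous example (names and signature are assumptions):

from typing import Tuple

import torch
import torch.distributed.rpc as rpc
from torch import Tensor
from torch.distributed.rpc import RRef


@torch.jit.script
def rpc_async_with_rref_arg(dst_worker_name: str, args: Tuple[RRef[Tensor]]) -> Tensor:
    # Passing the RRef as an RPC argument goes through JIT pickling, which is
    # where the failed rpc.remote() creation surfaces as an error.
    fut = rpc.rpc_async(dst_worker_name, rref_to_here, args)
    return fut.wait()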
    def _test_remote_message_dropped_timeout(self, func, args, dst=None):
        if self.rank != 0:
            return

        # Test the case where the rpc.remote() message creation is completely dropped.
        dst_rank = dst if dst is not None else (self.rank + 1) % self.world_size
        dst_worker = "worker{}".format(dst_rank)
        # Since we fail python_remote_call messages synchronously, the future
        # corresponding to this remote call will be marked with an error when
        # this function returns.
        rref = rpc.remote(dst_worker, func, args=args)
        # Call to ensure pending callbacks are run.
        wait_until_pending_futures_and_users_flushed()
        with self.assertRaisesRegex(RuntimeError, "RRef creation"):
            rref.to_here()
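A concrete test would drive this parametrized helper with a specific remote function and arguments, with the RPC agent configured to drop the remote-creation message. A hypothetical wrapper, assuming the my_sleep_func helper sketched later in this section:

def test_remote_message_dropped_timeout(self):
    # Hypothetical driver: any picklable function works, since the creation
    # message is dropped before it ever reaches the owner.
    self._test_remote_message_dropped_timeout(my_sleep_func, args=(2,))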
Example 4
 def test_remote_timeout_to_here_in_jit(self):
     # Test that calling to_here() in JIT will raise a timeout error if
     # rpc.remote failed.
     if self.rank != 0:
         return
     dst_rank = (self.rank + 1) % self.world_size
     dst_worker = "worker{}".format(dst_rank)
     rref = rpc.remote(
         dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
     )
     # Will ensure error handling callbacks are run.
     wait_until_pending_futures_and_users_flushed()
     # Call to_here() within a ScriptFunction and ensure it raises
     with self.assertRaisesRegex(RuntimeError, "RRef creation"):
         rref_to_here(rref)
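Here rref_to_here is the same TorchScript helper sketched under Example 1; calling it directly still goes through the RRef's to_here() path, so the failed rpc.remote() creation surfaces as the same "RRef creation" error.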
 def _test_remote_message_dropped_pickle(self, dst=None):
     if self.rank != 0:
         return
     dst_rank = dst if dst is not None else (self.rank + 1) % self.world_size
     dst_worker = "worker{}".format(dst_rank)
     # Since we fail python_remote_call messages synchronously, the future
     # corresponding to this remote call will be marked with an error when
     # this function returns.
     rref = rpc.remote(dst_worker, my_sleep_func, args=(1,))
     # Call to ensure pending callbacks are run.
     wait_until_pending_futures_and_users_flushed()
     # Attempting to fork the RRef should raise an error indicating the rpc.remote timeout.
     with self.assertRaisesRegex(RuntimeError, "RRef creation"):
         rref._serialize()
     # Test that using RRef as arg over RPC (which forks) results in the same
     # error
     with self.assertRaisesRegex(RuntimeError, "RRef creation"):
         rpc.rpc_async(dst_worker, add_rref_to_value, args=(rref, 1))
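The plain-Python helpers used above are defined elsewhere in the test suite; definitions consistent with how they are called here (assumed, not verbatim) would be:

import time

import torch


def my_sleep_func(seconds=1):
    # Slow remote target: sleep so caller-side timeouts can fire first.
    time.sleep(seconds)
    return torch.mul(torch.tensor(1), torch.tensor(1))


def add_rref_to_value(rref, value):
    # Fetch the RRef's value on the callee and add a constant to it.
    return rref.to_here() + value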
    def _test_remote_message_delay_timeout(self, func, args, dst=None):
        if self.rank != 0:
            return
        # Test the case where remote message is eventually processed on the owner,
        # but the future on the creator times out before the response comes back.
        dst_rank = dst if dst is not None else (self.rank + 1) % self.world_size
        dst_worker = "worker{}".format(dst_rank)
        # 1 ms timeout
        rref = rpc.remote(dst_worker, func, args=args, timeout=0.001)
        # Future corresponding to the remote creation should time out.
        expected_error = self.get_timeout_error_regex()
        with self.assertRaisesRegex(RuntimeError, expected_error):
            rref._get_future().wait()

        # Call to ensure pending callbacks are run.
        wait_until_pending_futures_and_users_flushed()
        # to_here() should now pick up that rpc.remote() creation has failed.
        with self.assertRaisesRegex(RuntimeError, "RRef creation"):
            rref.to_here()

        # Test the case where rpc.remote() times out, but to_here() has already
        # started blocking before.
        # NOTE: we only test this when not sending to self, as to_here() calls
        # localValue(), which does not send an RPC and thus does not have
        # a timeout. This can be supported by allowing future.wait() to
        # take in an optional timeout (https://github.com/pytorch/pytorch/issues/39280)
        if dst_rank != self.rank:
            slow_rref = rpc.remote(dst_worker, func, args=args, timeout=2)

            with self.assertRaisesRegex(RuntimeError, expected_error):
                # to_here() should raise a timeout error, since it does not know about the
                # status of rpc.remote().
                slow_rref.to_here(0.001)
        # Note: If we proceed with shutdown, UserRRef will send out an RRefUserDelete
        # but this can be a noop since it may not exist on the owner yet. Later,
        # the owner can process the RRef creation and wait for the delete message,
        # thus leading to a timeout.
        # Therefore, we wait until we get notification that pending owners have
        # been confirmed before sending out RRefUserDeletes.
        if dst_rank != self.rank:
            wait_until_owners_and_forks_on_rank(2, 2, rank=dst_rank)
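Concrete tests built on this helper would pass in a slow remote function so that owner-side processing outlives the creator-side timeout. Hypothetical wrappers (the function names and the use of my_sleep_func are assumptions):

def test_remote_message_delay_timeout(self):
    # Hypothetical driver: the remote call sleeps long enough for the 1 ms
    # creation future on the caller to time out first.
    self._test_remote_message_delay_timeout(my_sleep_func, args=(2,))


def test_remote_message_delay_timeout_to_self(self):
    # Same scenario, but targeting this rank (dst == self.rank); the slow_rref
    # portion of the helper is skipped in that case, as noted above.
    self._test_remote_message_delay_timeout(my_sleep_func, args=(1,), dst=0)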