def test_rref_timeout_pickle_script_func(self):
    """Pass an RRef whose creation failed to a script function over Python RPC.

    Counterpart of the test defined just before this one in the file; here
    the RRef travels as an argument of ``rpc.rpc_sync`` to the script
    function ``rref_to_here``. The call is expected to raise a
    ``RuntimeError`` matching "RRef creation" while the RRef is pickled.
    Only rank 0 executes the body; every other rank returns immediately.
    """
    if self.rank != 0:
        return

    owner_rank = (self.rank + 1) % self.world_size
    owner_name = "worker{}".format(owner_rank)
    add_operands = (torch.tensor(1), torch.tensor(1))
    pending_rref = rpc.remote(owner_name, torch.add, args=add_operands)

    # Make sure the error handling callbacks have had a chance to run.
    wait_until_pending_futures_and_users_flushed()

    # Shipping the RRef as an RPC argument forces it to be pickled, and that
    # is where the creation failure is expected to surface.
    with self.assertRaisesRegex(RuntimeError, "RRef creation"):
        rpc.rpc_sync(owner_name, rref_to_here, args=(pending_rref,))
def test_rref_timeout_pickle_in_jit(self):
    """Pass an RRef whose creation failed as an RPC argument from within JIT.

    ``rpc_async_with_rref_arg`` is invoked with the RRef so that it goes
    through JIT pickling; the call is expected to raise a ``RuntimeError``
    matching "RRef creation". Only rank 0 executes the body.
    """
    if self.rank != 0:
        return

    owner_rank = (self.rank + 1) % self.world_size
    owner_name = "worker{}".format(owner_rank)
    add_operands = (torch.tensor(1), torch.tensor(1))
    pending_rref = rpc.remote(owner_name, torch.add, args=add_operands)

    # Make sure the error handling callbacks have had a chance to run.
    wait_until_pending_futures_and_users_flushed()

    # The RRef argument is pickled by the JIT path, which should report the
    # failed creation.
    with self.assertRaisesRegex(RuntimeError, "RRef creation"):
        rpc_async_with_rref_arg(owner_name, (pending_rref,))
def _test_remote_message_dropped_timeout(self, func, args, dst=None):
    """Check ``to_here()`` when the ``rpc.remote()`` creation message is dropped.

    Args:
        func: callable handed to ``rpc.remote``.
        args: positional arguments forwarded to ``func`` via ``rpc.remote``.
        dst: rank of the destination worker; when ``None`` the next rank,
            ``(self.rank + 1) % self.world_size``, is used.

    Only rank 0 executes the body; every other rank returns immediately.
    """
    if self.rank != 0:
        return

    # Pick the destination: the explicit rank if given, else the next rank.
    if dst is None:
        owner_rank = (self.rank + 1) % self.world_size
    else:
        owner_rank = dst
    owner_name = "worker{}".format(owner_rank)

    # python_remote_call messages are failed synchronously, so by the time
    # rpc.remote() returns, the future behind this remote call has already
    # been marked with an error.
    dropped_rref = rpc.remote(owner_name, func, args=args)

    # Let any pending callbacks run before checking the outcome.
    wait_until_pending_futures_and_users_flushed()

    with self.assertRaisesRegex(RuntimeError, "RRef creation"):
        dropped_rref.to_here()
def test_remote_timeout_to_here_in_jit(self):
    """Check that ``to_here()`` called inside JIT raises if ``rpc.remote`` failed.

    The RRef is handed to the ScriptFunction ``rref_to_here``, and the call
    is expected to raise a ``RuntimeError`` matching "RRef creation".
    Only rank 0 executes the body; every other rank returns immediately.
    """
    if self.rank != 0:
        return

    owner_rank = (self.rank + 1) % self.world_size
    owner_name = "worker{}".format(owner_rank)
    add_operands = (torch.tensor(1), torch.tensor(1))
    pending_rref = rpc.remote(owner_name, torch.add, args=add_operands)

    # Make sure the error handling callbacks have had a chance to run.
    wait_until_pending_futures_and_users_flushed()

    # to_here() runs within the ScriptFunction and should raise.
    with self.assertRaisesRegex(RuntimeError, "RRef creation"):
        rref_to_here(pending_rref)
def _test_remote_message_dropped_pickle(self, dst=None):
    """Check that forking an RRef whose creation message was dropped raises.

    Two paths are exercised: serializing the RRef directly, and passing it
    as an argument of an RPC (which forks it). Both are expected to raise a
    ``RuntimeError`` matching "RRef creation".

    Args:
        dst: rank of the destination worker; when ``None`` the next rank,
            ``(self.rank + 1) % self.world_size``, is used.

    Only rank 0 executes the body; every other rank returns immediately.
    """
    if self.rank != 0:
        return

    # Pick the destination: the explicit rank if given, else the next rank.
    if dst is None:
        owner_rank = (self.rank + 1) % self.world_size
    else:
        owner_rank = dst
    owner_name = "worker{}".format(owner_rank)

    # python_remote_call messages are failed synchronously, so by the time
    # rpc.remote() returns, the future behind this remote call has already
    # been marked with an error.
    dropped_rref = rpc.remote(owner_name, my_sleep_func, args=(1,))

    # Let any pending callbacks run before checking the outcome.
    wait_until_pending_futures_and_users_flushed()

    # Forking the RRef must report that the rpc.remote() call timed out.
    with self.assertRaisesRegex(RuntimeError, "RRef creation"):
        dropped_rref._serialize()

    # Sending the RRef as an RPC argument also forks it, so the same error
    # is expected.
    with self.assertRaisesRegex(RuntimeError, "RRef creation"):
        rpc.rpc_async(owner_name, add_rref_to_value, args=(dropped_rref, 1))
def _test_remote_message_delay_timeout(self, func, args, dst=None):
    """Check RRef behaviour when the creation future times out on the creator.

    Covers the case where the remote message is eventually processed on the
    owner, but the creator-side future times out before the response
    arrives.

    Args:
        func: callable handed to ``rpc.remote``.
        args: positional arguments forwarded to ``func`` via ``rpc.remote``.
        dst: rank of the destination worker; when ``None`` the next rank,
            ``(self.rank + 1) % self.world_size``, is used.

    Only rank 0 executes the body; every other rank returns immediately.
    """
    if self.rank != 0:
        return

    # Pick the destination: the explicit rank if given, else the next rank.
    if dst is None:
        owner_rank = (self.rank + 1) % self.world_size
    else:
        owner_rank = dst
    owner_name = "worker{}".format(owner_rank)

    # 1 ms timeout (0.001 s) on the remote creation.
    short_rref = rpc.remote(owner_name, func, args=args, timeout=0.001)

    # The future tied to the remote creation is expected to time out.
    timeout_regex = self.get_timeout_error_regex()
    with self.assertRaisesRegex(RuntimeError, timeout_regex):
        short_rref._get_future().wait()

    # Let any pending callbacks run before checking the outcome.
    wait_until_pending_futures_and_users_flushed()

    # By now to_here() should observe that rpc.remote() creation failed.
    with self.assertRaisesRegex(RuntimeError, "RRef creation"):
        short_rref.to_here()

    if owner_rank != self.rank:
        # Second scenario: rpc.remote() times out, but to_here() was already
        # blocking beforehand.
        # NOTE: this is only exercised when not sending to self, because
        # to_here() then calls localValue(), which sends no RPC and so has
        # no timeout. Support could be added by letting future.wait() accept
        # an optional timeout
        # (https://github.com/pytorch/pytorch/issues/39280).
        slow_rref = rpc.remote(owner_name, func, args=args, timeout=2)

        with self.assertRaisesRegex(RuntimeError, timeout_regex):
            # to_here() is unaware of the rpc.remote() status, so it should
            # raise the timeout error.
            slow_rref.to_here(0.001)

        # Note: if shutdown proceeded right away, the UserRRef would send a
        # RRefUserDelete that may be a no-op because the RRef might not
        # exist on the owner yet. The owner could later process the RRef
        # creation and then wait for the delete message, leading to a
        # timeout. To avoid that, wait until pending owners are confirmed
        # before RRefUserDeletes are sent out.
        wait_until_owners_and_forks_on_rank(2, 2, rank=owner_rank)