def run_task_with_timeout(
    task: "Task",
    args: Sequence = (),
    kwargs: Mapping = None,
    logger: Logger = None,
) -> Any:
    """
    Helper function for implementing timeouts on task executions.

    The enforcement strategy depends on where this function is called from:

    - not on Windows, in the main thread: `run_with_thread_timeout` is used
    - not on Windows, in a non-daemonic process (outside the main thread):
      `run_with_multiprocess_timeout` is used
    - in a daemonic subprocess or on Windows: the task is run in a
      `ThreadPoolExecutor` and only a soft timeout is enforced, meaning a
      `TaskTimeoutSignal` is raised at the appropriate time but the task
      continues running in the background. A warning is emitted in this case.

    The task is passed instead of a function so we can give better logs and
    messages. If you need to run generic functions with timeout handlers,
    `run_with_thread_timeout` or `run_with_multiprocess_timeout` can be called
    directly.

    Args:
        - task (Task): the task to execute; `task.timeout` specifies the number
            of seconds to allow `task.run` to run for before terminating. When
            it is `None`, `task.run` is called directly with no timeout.
        - args (Sequence): positional arguments to pass to `task.run`
        - kwargs (Mapping): keyword arguments to pass to `task.run`
        - logger (Logger): an optional logger to use. If not passed, the
            logger returned by `get_logger()` is used.

    Returns:
        - the result of `task.run(*args, **kwargs)`

    Raises:
        - TaskTimeoutSignal: if task execution exceeds the allowed timeout
    """
    logger = logger or get_logger()
    kwargs = kwargs or {}

    # if no timeout, just run the function
    if task.timeout is None:
        return task.run(*args, **kwargs)  # type: ignore

    name = prefect.context.get("task_full_name", task.name)

    # if we are running the main thread, use a signal to stop execution at the
    # appropriate time; else if we are running in a non-daemonic process, spawn
    # a subprocess to kill at the appropriate time
    if not sys.platform.startswith("win"):
        if threading.current_thread() is threading.main_thread():
            # This case is typically encountered when using a non-parallel or local
            # multiprocess scheduler because then each worker is in the main
            # thread
            logger.debug(f"Task '{name}': Attaching thread based timeout handler...")
            return run_with_thread_timeout(
                task.run,
                args,
                kwargs,
                timeout=task.timeout,
                logger=logger,
                name=f"Task '{name}'",
            )

        elif not multiprocessing.current_process().daemon:
            # This case is typically encountered when using a multithread distributed
            # executor
            logger.debug(f"Task '{name}': Attaching process based timeout handler...")
            return run_with_multiprocess_timeout(
                task.run,
                args,
                kwargs,
                timeout=task.timeout,
                logger=logger,
                name=f"Task '{name}'",
            )

        # We are in a daemonic process and cannot enforce a timeout
        # This case is typically encountered when using a multiprocess distributed
        # executor
        soft_timeout_reason = "in a daemonic subprocess"
    else:
        # We are in windows and cannot enforce a timeout
        soft_timeout_reason = "on Windows"

    msg = (
        f"This task is running {soft_timeout_reason}; "
        "consequently Prefect can only enforce a soft timeout limit, i.e., "
        "if your Task reaches its timeout limit it will enter a TimedOut state "
        "but continue running in the background."
    )

    logger.debug(
        f"Task '{name}': Falling back to daemonic soft limit timeout handler because "
        f"we are running {soft_timeout_reason}."
    )
    warnings.warn(msg, stacklevel=2)

    # Exactly one call is submitted, so a single worker thread is enough.
    executor = ThreadPoolExecutor(max_workers=1)

    def run_with_ctx(context: dict) -> Any:
        # Re-enter the caller's Prefect context inside the worker thread
        with prefect.context(context):
            return task.run(*args, **kwargs)  # type: ignore

    # Run the function in the background and then retrieve its result with a timeout
    fut = executor.submit(run_with_ctx, prefect.context.to_dict())
    try:
        return fut.result(timeout=task.timeout)
    except FutureTimeout as exc:
        raise TaskTimeoutSignal(
            f"Execution timed out but was executed {soft_timeout_reason} and will "
            "continue to run in the background."
        ) from exc
    finally:
        # Release the executor without blocking: on a soft timeout the task keeps
        # running, and its worker thread is freed once the task finishes.
        executor.shutdown(wait=False)
def error_handler(signum, frame):  # type: ignore
    """
    Convert an incoming signal into a task timeout.

    The `(signum, frame)` signature matches what `signal` passes to handlers;
    presumably this is registered for an alarm signal elsewhere.

    Args:
        - signum: the signal number being handled (ignored)
        - frame: the interrupted stack frame (ignored)

    Raises:
        - TaskTimeoutSignal: unconditionally
    """
    del signum, frame  # neither argument influences the outcome
    raise TaskTimeoutSignal("Execution timed out.")
def run_with_multiprocess_timeout(
    fn: Callable,
    args: Sequence = (),
    kwargs: Mapping = None,
    timeout: int = None,
    logger: Logger = None,
    name: str = None,
) -> Any:
    """
    Helper function for implementing timeouts on function executions.

    Implemented by spawning a new process from the "spawn" multiprocessing
    context and waiting on a result queue with a timeout. The result is read
    from the queue *before* the process is joined. A child that has put data
    on a queue cannot exit until that data is flushed to the pipe, so joining
    first could stall until the timeout and discard a result that was actually
    produced.

    NOTE(review): this module defines `run_with_multiprocess_timeout` a second
    time further down, and that later definition replaces this one at import
    time. Consider removing the duplicate.

    Args:
        - fn (callable): the function to execute
        - args (Sequence): arguments to pass to the function
        - kwargs (Mapping): keyword arguments to pass to the function
        - timeout (int): the length of time to allow for execution before raising
            a `TaskTimeoutSignal`, represented as an integer in seconds. When
            `None`, `fn` is called directly in the current process.
        - logger (Logger): an optional logger to use. If not passed, the logger
            returned by `get_logger()` is used.
        - name (str): an optional name to attach to logs for this function run,
            defaults to the name of the given function. Provides an interface for
            passing task names for logs.

    Returns:
        - the result of `fn(*args, **kwargs)`

    Raises:
        - Exception: any exception (or `PrefectSignal`) sent back from the
            subprocess is re-raised
        - AssertionError: if run from a daemonic process
        - TaskTimeoutSignal: if no result is received within `timeout` seconds
    """
    logger = logger or get_logger()
    name = name or f"Function '{fn.__name__}'"
    kwargs = kwargs or {}

    if timeout is None:
        return fn(*args, **kwargs)

    spawn_mp = multiprocessing.get_context("spawn")

    # Create a queue to pass the function return value back
    queue = spawn_mp.Queue()  # type: multiprocessing.Queue

    # Set internal kwargs for the helper function
    request = {
        "fn": fn,
        "args": args,
        "kwargs": kwargs,
        "context": prefect.context.to_dict(),
        "name": name,
        "logger": logger,
    }
    payload = cloudpickle.dumps(request)

    run_process = spawn_mp.Process(
        target=multiprocessing_safe_run_and_retrieve, args=(queue, payload)
    )
    logger.debug(f"{name}: Sending execution to a new process...")
    run_process.start()
    logger.debug(f"{name}: Waiting for process to return with {timeout}s timeout...")

    try:
        # Guard only the wait itself: an `Empty` raised by `fn` and sent back as
        # its result must propagate unchanged, not be reported as a timeout.
        try:
            pickled_result = queue.get(block=True, timeout=timeout)
        except Empty:
            logger.debug(f"{name}: No result returned within the timeout period!")
            raise TaskTimeoutSignal(f"Execution timed out for {name}.") from None

        logger.debug(f"{name}: Result received from subprocess, unpickling...")
        result = cloudpickle.loads(pickled_result)
        if isinstance(result, (Exception, PrefectSignal)):
            raise result
        return result
    finally:
        # Do not let the process dangle
        run_process.join(0.1)
        run_process.terminate()
def run_with_multiprocess_timeout(
    fn: Callable,
    args: Sequence = (),
    kwargs: Mapping = None,
    timeout: int = None,
    logger: Logger = None,
    name: str = None,
) -> Any:
    """
    Helper function for implementing timeouts on function executions.

    Implemented by spawning a new process from the "spawn" multiprocessing
    context and using a queue to pass the result back. The parent reads the
    queue with a timeout; the helper `multiprocessing_safe_run_and_retrieve`
    is expected to put a cloudpickled return value (or raised exception) on it.

    Args:
        - fn (callable): the function to execute
        - args (Sequence): arguments to pass to the function
        - kwargs (Mapping): keyword arguments to pass to the function
        - timeout (int): the length of time to allow for execution before raising
            a `TaskTimeoutSignal`, represented as an integer in seconds. When
            `None`, `fn` is called directly in the current process.
        - logger (Logger): an optional logger to use. If not passed, the logger
            returned by `get_logger()` is used.
        - name (str): an optional name to attach to logs for this function run,
            defaults to the name of the given function. Provides an interface for
            passing task names for logs.

    Returns:
        - the result of `fn(*args, **kwargs)`

    Raises:
        - Exception: Any user errors within the subprocess will be pickled and reraised
        - AssertionError: if run from a daemonic process
        - TaskTimeoutSignal: if function execution exceeds the allowed timeout
    """
    logger = logger or get_logger()
    name = name or f"Function '{fn.__name__}'"
    kwargs = kwargs or {}

    if timeout is None:
        return fn(*args, **kwargs)

    spawn_mp = multiprocessing.get_context("spawn")

    # Create a queue to pass the function return value back
    queue = spawn_mp.Queue()  # type: multiprocessing.Queue

    # Set internal kwargs for the helper function
    request = {
        "fn": fn,
        "args": args,
        "kwargs": kwargs,
        "context": prefect.context.to_dict(),
        "name": name,
        "logger": logger,
    }
    payload = cloudpickle.dumps(request)

    run_process = spawn_mp.Process(
        target=multiprocessing_safe_run_and_retrieve,
        args=(queue, payload),
    )
    logger.debug(f"{name}: Sending execution to a new process...")
    run_process.start()
    logger.debug(f"{name}: Waiting for process to return with {timeout}s timeout...")

    try:
        # Pull the data from the queue. If empty, the function did not finish
        # before the timeout. Only this wait is guarded by `except Empty`: if `fn`
        # itself raised `queue.Empty`, that exception comes back as the result and
        # must be re-raised as-is rather than misreported as a timeout.
        try:
            pickled_result = queue.get(block=True, timeout=timeout)
        except Empty:
            logger.debug(f"{name}: No result returned within the timeout period!")
            raise TaskTimeoutSignal(f"Execution timed out for {name}.") from None

        logger.debug(f"{name}: Result received from subprocess, unpickling...")
        result = cloudpickle.loads(pickled_result)
        if isinstance(result, (Exception, PrefectSignal)):
            raise result
        return result
    finally:
        # Do not let the process dangle
        run_process.join(0.1)
        run_process.terminate()