Example #1
def run_task_with_timeout(
    task: "Task",
    args: Sequence = (),
    kwargs: Mapping = None,
    logger: Logger = None,
) -> Any:
    """
    Helper function for implementing timeouts on task executions.

    The exact implementation varies depending on whether this function is being
    run in the main thread or a non-daemonic subprocess.  If this is run from a
    daemonic subprocess or on Windows, the task is run in a `ThreadPoolExecutor`
    and only a soft timeout is enforced, meaning a `TaskTimeoutSignal` is raised at the
    appropriate time but the task continues running in the background.

    The task is passed instead of a function so we can give better logs and messages.
    If you need to run generic functions with timeout handlers,
    `run_with_thread_timeout` or `run_with_multiprocess_timeout` can be called directly.

    Args:
        - task (Task): the task to execute
            `task.timeout` specifies the number of seconds to allow `task.run` to run
            for before terminating
        - args (Sequence): arguments to pass to the function
        - kwargs (Mapping): keyword arguments to pass to the function
        - logger (Logger): an optional logger to use. If not passed, a logger for the
            `prefect.run_task_with_timeout_handler` namespace will be created.

    Returns:
        - the result of `task.run(*args, **kwargs)`

    Raises:
        - TaskTimeoutSignal: if function execution exceeds the allowed timeout
    """
    logger = logger or get_logger()
    name = prefect.context.get("task_full_name", task.name)
    kwargs = kwargs or {}

    # if no timeout, just run the function
    if task.timeout is None:
        return task.run(*args, **kwargs)  # type: ignore

    # if we are running the main thread, use a signal to stop execution at the
    # appropriate time; else if we are running in a non-daemonic process, spawn
    # a subprocess to kill at the appropriate time
    if not sys.platform.startswith("win"):

        if threading.current_thread() is threading.main_thread():
            # This case is typically encountered when using a non-parallel or local
            # multiprocess scheduler because then each worker is in the main
            # thread
            logger.debug(
                f"Task '{name}': Attaching thread based timeout handler...")
            return run_with_thread_timeout(
                task.run,
                args,
                kwargs,
                timeout=task.timeout,
                logger=logger,
                name=f"Task '{name}'",
            )

        elif multiprocessing.current_process().daemon is False:
            # This case is typically encountered when using a multithread distributed
            # executor
            logger.debug(
                f"Task '{name}': Attaching process based timeout handler...")
            return run_with_multiprocess_timeout(
                task.run,
                args,
                kwargs,
                timeout=task.timeout,
                logger=logger,
                name=f"Task '{name}'",
            )

        # We are in a daemonic process and cannot enforce a timeout
        # This case is typically encountered when using a multiprocess distributed
        # executor
        soft_timeout_reason = "in a daemonic subprocess"
    else:
        # We are in windows and cannot enforce a timeout
        soft_timeout_reason = "on Windows"

    msg = (
        f"This task is running {soft_timeout_reason}; "
        "consequently Prefect can only enforce a soft timeout limit, i.e., "
        "if your Task reaches its timeout limit it will enter a TimedOut state "
        "but continue running in the background.")

    logger.debug(
        f"Task '{name}': Falling back to daemonic soft limit timeout handler because "
        f"we are running {soft_timeout_reason}.")
    warnings.warn(msg, stacklevel=2)
    executor = ThreadPoolExecutor()

    def run_with_ctx(context: dict) -> Any:
        with prefect.context(context):
            return task.run(*args, **kwargs)  # type: ignore

    # Run the function in the background and then retrieve its result with a timeout
    fut = executor.submit(run_with_ctx, prefect.context.to_dict())

    try:
        return fut.result(timeout=task.timeout)
    except FutureTimeout as exc:
        raise TaskTimeoutSignal(
            f"Execution timed out but was executed {soft_timeout_reason} and will "
            "continue to run in the background.") from exc
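A minimal usage sketch for the helper above, assuming Prefect 1.x. The import paths (`prefect.utilities.executors` for the helper, `prefect.exceptions` for `TaskTimeoutSignal`) and the example task are assumptions for illustration; adjust them to your installation.

import time
from prefect import task
# Assumed import paths; adjust to where these helpers live in your codebase.
from prefect.exceptions import TaskTimeoutSignal
from prefect.utilities.executors import run_task_with_timeout

@task(timeout=2)  # seconds; stored on `task.timeout`
def slow_task():
    time.sleep(10)
    return "done"

try:
    result = run_task_with_timeout(slow_task)
except TaskTimeoutSignal as exc:
    print(f"Task hit its 2s timeout: {exc}")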
Example #2
def error_handler(signum, frame):  # type: ignore
    raise TaskTimeoutSignal("Execution timed out.")
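A handler like this is typically installed as a SIGALRM handler. The sketch below shows one way to wire it up; `run_with_signal_timeout` is a hypothetical name (not Prefect's actual `run_with_thread_timeout`), and the `TaskTimeoutSignal` import path is an assumption.

import signal
from typing import Any, Callable
from prefect.exceptions import TaskTimeoutSignal  # assumed import path

def run_with_signal_timeout(fn: Callable, timeout: int) -> Any:
    # Hypothetical helper: signal handlers can only be installed from the main
    # thread on POSIX, which is why Example #1 checks for the main thread first.
    def error_handler(signum, frame):  # type: ignore
        raise TaskTimeoutSignal("Execution timed out.")

    old_handler = signal.signal(signal.SIGALRM, error_handler)
    signal.alarm(timeout)  # deliver SIGALRM after `timeout` seconds
    try:
        return fn()
    finally:
        signal.alarm(0)  # cancel any pending alarm
        signal.signal(signal.SIGALRM, old_handler)  # restore the previous handler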
Example #3
def run_with_multiprocess_timeout(
    fn: Callable,
    args: Sequence = (),
    kwargs: Mapping = None,
    timeout: int = None,
    logger: Logger = None,
    name: str = None,
) -> Any:
    """
    Helper function for implementing timeouts on function executions.
    Implemented by spawning a new multiprocessing.Process() and joining with a timeout.

    Args:
        - fn (callable): the function to execute
        - args (Sequence): arguments to pass to the function
        - kwargs (Mapping): keyword arguments to pass to the function
        - timeout (int): the length of time to allow for execution before raising a
            `TaskTimeoutSignal`, represented as an integer in seconds
        - logger (Logger): an optional logger to use. If not passed, a logger for the
            `prefect.` namespace will be created.
        - name (str): an optional name to attach to logs for this function run, defaults
            to the name of the given function. Provides an interface for passing task
            names for logs.

    Returns:
        - the result of `fn(*args, **kwargs)`

    Raises:
        - AssertionError: if run from a daemonic process
        - TaskTimeoutSignal: if function execution exceeds the allowed timeout
    """
    logger = logger or get_logger()
    name = name or f"Function '{fn.__name__}'"
    kwargs = kwargs or {}

    if timeout is None:
        return fn(*args, **kwargs)

    spawn_mp = multiprocessing.get_context("spawn")

    # Create a queue to pass the function return value back
    queue = spawn_mp.Queue()  # type: multiprocessing.Queue

    # Set internal kwargs for the helper function
    request = {
        "fn": fn,
        "args": args,
        "kwargs": kwargs,
        "context": prefect.context.to_dict(),
        "name": name,
        "logger": logger,
    }
    payload = cloudpickle.dumps(request)

    run_process = spawn_mp.Process(
        target=multiprocessing_safe_run_and_retrieve, args=(queue, payload))
    logger.debug(f"{name}: Sending execution to a new process...")
    run_process.start()
    logger.debug(
        f"{name}: Waiting for process to return with {timeout}s timeout...")
    run_process.join(timeout)
    run_process.terminate()

    # Handle the process result, if the queue is empty the function did not finish
    # before the timeout
    logger.debug(f"{name}: Execution process closed, collecting result...")
    if not queue.empty():
        result = cloudpickle.loads(queue.get())
        if isinstance(result, Exception):
            raise result
        return result
    else:
        raise TaskTimeoutSignal(f"Execution timed out for {name}.")
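The subprocess target `multiprocessing_safe_run_and_retrieve` is referenced but not shown here. Below is a plausible sketch of what it does, reconstructed from how the request payload is built and consumed above; it is an illustration, not necessarily Prefect's exact implementation.

import cloudpickle
import prefect

def multiprocessing_safe_run_and_retrieve(queue, payload) -> None:
    """Sketch: unpickle the request, run it under the passed Prefect context,
    and put the pickled result (or raised exception) on the queue."""
    request = cloudpickle.loads(payload)
    fn = request["fn"]
    args = request["args"]
    kwargs = request["kwargs"]
    context = request["context"]
    name = request["name"]
    logger = request["logger"]

    try:
        with prefect.context(context):
            logger.debug(f"{name}: Executing...")
            result = fn(*args, **kwargs)
    except Exception as exc:
        # Exceptions are shipped back as the result so the parent can re-raise
        result = exc

    queue.put(cloudpickle.dumps(result))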
Example #4
def run_with_multiprocess_timeout(
    fn: Callable,
    args: Sequence = (),
    kwargs: Mapping = None,
    timeout: int = None,
    logger: Logger = None,
    name: str = None,
) -> Any:
    """
    Helper function for implementing timeouts on function executions.

    Implemented by spawning a new multiprocessing.Process() and using a queue to pass
    the result back. The result is retrieved from the queue with a timeout.

    Args:
        - fn (callable): the function to execute
        - args (Sequence): arguments to pass to the function
        - kwargs (Mapping): keyword arguments to pass to the function
        - timeout (int): the length of time to allow for execution before raising a
            `TaskTimeoutSignal`, represented as an integer in seconds
        - logger (Logger): an optional logger to use. If not passed, a logger for the
            `prefect.` namespace will be created.
        - name (str): an optional name to attach to logs for this function run, defaults
            to the name of the given function. Provides an interface for passing task
            names for logs.

    Returns:
        - the result of `fn(*args, **kwargs)`

    Raises:
        - Exception: Any user errors within the subprocess will be pickled and reraised
        - AssertionError: if run from a daemonic process
        - TaskTimeoutSignal: if function execution exceeds the allowed timeout
    """
    logger = logger or get_logger()
    name = name or f"Function '{fn.__name__}'"
    kwargs = kwargs or {}

    if timeout is None:
        return fn(*args, **kwargs)

    spawn_mp = multiprocessing.get_context("spawn")

    # Create a queue to pass the function return value back
    queue = spawn_mp.Queue()  # type: multiprocessing.Queue

    # Set internal kwargs for the helper function
    request = {
        "fn": fn,
        "args": args,
        "kwargs": kwargs,
        "context": prefect.context.to_dict(),
        "name": name,
        "logger": logger,
    }
    payload = cloudpickle.dumps(request)

    run_process = spawn_mp.Process(
        target=multiprocessing_safe_run_and_retrieve,
        args=(queue, payload),
    )
    logger.debug(f"{name}: Sending execution to a new process...")
    run_process.start()
    logger.debug(
        f"{name}: Waiting for process to return with {timeout}s timeout...")

    # Pull the data from the queue. If empty, the function did not finish before
    # the timeout
    try:
        pickled_result = queue.get(block=True, timeout=timeout)
        logger.debug(f"{name}: Result received from subprocess, unpickling...")
        result = cloudpickle.loads(pickled_result)
        if isinstance(result, (Exception, PrefectSignal)):
            raise result
        return result
    except Empty:
        logger.debug(f"{name}: No result returned within the timeout period!")
        raise TaskTimeoutSignal(f"Execution timed out for {name}.")
    finally:
        # Do not let the process dangle
        run_process.join(0.1)
        run_process.terminate()
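A hedged usage sketch for this variant. The import paths are assumptions as noted in Example #1, and the slow function exists only to trigger the timeout.

import time
# Assumed import paths; adjust to where these helpers live in your codebase.
from prefect.exceptions import TaskTimeoutSignal
from prefect.utilities.executors import run_with_multiprocess_timeout

def slow_double(x: int) -> int:
    time.sleep(10)  # deliberately exceeds the timeout below
    return x * 2

if __name__ == "__main__":  # guard is needed with the "spawn" start method
    try:
        value = run_with_multiprocess_timeout(slow_double, args=(21,), timeout=2)
    except TaskTimeoutSignal:
        print("No result within 2 seconds; the subprocess was terminated.")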