Example #1
def dask_task_wrapper(func, repack, *args):
    """
    A Ray remote function acting as a Dask task wrapper. This function will
    repackage the given flat `args` into its original data structures using
    `repack`, execute any Dask subtasks within the repackaged arguments
    (inlined by Dask's optimization pass), and then pass the concrete task
    arguments to the provided Dask task function, `func`.

    Args:
        func (callable): The Dask task function to execute.
        repack (callable): A function that repackages the provided args into
            the original (possibly nested) Python objects.
        *args (ObjectRef): Ray object references representing the Dask task's
            arguments.

    Returns:
        The output of the Dask task. In the context of Ray, a
        dask_task_wrapper.remote() invocation will return a Ray object
        reference representing the Ray task's result.
    """
    repacked_args, repacked_deps = repack(args)
    # Recursively execute Dask-inlined tasks.
    actual_args = [_execute_task(a, repacked_deps) for a in repacked_args]
    # Execute the actual underlying Dask task.
    return func(*actual_args)
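
The wrapper only requires `repack` to return a `(repacked_args, repacked_deps)` pair, so it can be exercised directly with a plain Python callable. The sketch below uses a hypothetical trivial repack (standing in for the closure produced by `unpack_object_refs`) and assumes `dask_task_wrapper` and its `_execute_task` import from `dask.core` are in scope; in real use the function is invoked via Ray's `.remote()` with `ObjectRef` arguments.

def identity_repack(args):
    # Hypothetical stand-in for the real repack closure: return the flat
    # args unchanged and an empty dependency mapping.
    return list(args), {}

# Calling the wrapper directly just runs `func` on the concrete arguments.
assert dask_task_wrapper(sum, identity_repack, [1, 2, 3]) == 6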
Example #2
def cache_entry(cache, key, task):
    # Return the cached value for `key`, computing it via `_execute_task` and
    # storing it on a miss. Lookup and compute happen under `cache.lock`, so
    # concurrent callers never compute the same key twice.
    with cache.lock:
        try:
            return cache.cache[key]
        except KeyError:
            cache.cache[key] = value = _execute_task(task, {})
            return value
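
The helper only assumes that `cache` exposes a `lock` and a `cache` mapping; the sketch below builds a minimal compatible object (the attribute names come from the code above, everything else is assumed) and relies on `cache_entry` and its `_execute_task` import from `dask.core` being in scope.

import threading
from operator import add
from types import SimpleNamespace

cache = SimpleNamespace(lock=threading.Lock(), cache={})

print(cache_entry(cache, "y", (add, 1, 2)))  # -> 3, computed and cached
print(cache_entry(cache, "y", (add, 1, 2)))  # -> 3, served from cache.cache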
Example #3
def execute_task(key, task_info, dumps, loads, get_id, pack_exception):
    """
    Compute task and handle all administration

    See Also
    --------
    _execute_task : actually execute task
    """
    try:
        task, data = loads(task_info)
        result = _execute_task(task, data)
        id = get_id()
        result = dumps((result, id))
        failed = False
    except BaseException as e:
        result = pack_exception(e, dumps)
        failed = True
    return key, result, failed
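
A rough sketch of a direct call with pass-through serialization, assuming `execute_task` and `_execute_task` are in scope; `identity`, `get_id`, and `pack_exception` below are local stand-ins for the defaults that `get_async` normally wires in (see Example #5).

identity = lambda x: x       # pass-through dumps/loads
get_id = lambda: "worker-0"  # stand-in worker id

def pack_exception(e, dumps):
    raise  # mirrors the documented default: just re-raise the exception

key, res, failed = execute_task(
    "total",                 # task key
    ((sum, [1, 2, 3]), {}),  # (task, data) pair, already "serialized" by identity
    identity, identity, get_id, pack_exception,
)
# failed is False and res == (6, "worker-0")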
Example #4
def dask_task_wrapper(func, repack, key, ray_pretask_cbs, ray_posttask_cbs,
                      *args):
    """
    A Ray remote function acting as a Dask task wrapper. This function will
    repackage the given flat `args` into its original data structures using
    `repack`, execute any Dask subtasks within the repackaged arguments
    (inlined by Dask's optimization pass), and then pass the concrete task
    arguments to the provided Dask task function, `func`.

    Args:
        func (callable): The Dask task function to execute.
        repack (callable): A function that repackages the provided args into
            the original (possibly nested) Python objects.
        key (str): The Dask key for this task.
        ray_pretask_cbs (list): Pre-task execution callbacks.
        ray_posttask_cbs (list): Post-task execution callbacks.
        *args (ObjectRef): Ray object references representing the Dask task's
            arguments.

    Returns:
        The output of the Dask task. In the context of Ray, a
        dask_task_wrapper.remote() invocation will return a Ray object
        reference representing the Ray task's result.
    """
    if ray_pretask_cbs is not None:
        pre_states = [
            cb(key, args) if cb is not None else None for cb in ray_pretask_cbs
        ]
    repacked_args, repacked_deps = repack(args)
    # Recursively execute Dask-inlined tasks.
    actual_args = [_execute_task(a, repacked_deps) for a in repacked_args]
    # Execute the actual underlying Dask task.
    result = func(*actual_args)

    if ray_posttask_cbs is not None:
        for cb, pre_state in zip(ray_posttask_cbs, pre_states):
            if cb is not None:
                cb(key, result, pre_state)

    return result
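
This variant adds a simple callback contract: each pre-task callback returns a state object that is later handed to the matching post-task callback. The timing callbacks and trivial repack below are hypothetical, and `dask_task_wrapper` plus its `_execute_task` import are assumed to be in scope.

import time

def timing_pretask(key, args):
    return time.time()  # returned state is passed to the posttask callback

def timing_posttask(key, result, start):
    print(f"{key} finished in {time.time() - start:.4f}s")

def identity_repack(args):
    # Hypothetical trivial repack: flat args unchanged, no dependencies.
    return list(args), {}

out = dask_task_wrapper(sum, identity_repack, "total",
                        [timing_pretask], [timing_posttask], [1, 2, 3])
assert out == 6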
Example #5
def get_async(
    submit,
    num_workers,
    dsk,
    result,
    cache=None,
    get_id=default_get_id,
    rerun_exceptions_locally=None,
    pack_exception=default_pack_exception,
    raise_exception=reraise,
    callbacks=None,
    dumps=identity,
    loads=identity,
    chunksize=None,
    **kwargs,
):
    """Asynchronous get function

    This is a general version of various asynchronous schedulers for dask.  It
    takes a ``concurrent.futures.Executor.submit`` function to form a more
    specific ``get`` method that walks through the dask graph with parallel
    workers, avoiding repeat computation and minimizing memory use.

    Parameters
    ----------
    submit : function
        A ``concurrent.futures.Executor.submit`` function
    num_workers : int
        The number of workers that task submissions can be spread over
    dsk : dict
        A dask dictionary specifying a workflow
    result : key or list of keys
        Keys corresponding to desired data
    cache : dict-like, optional
        Temporary storage of results
    get_id : callable, optional
        Function to return the worker id, takes no arguments. Examples are
        `threading.current_thread` and `multiprocessing.current_process`.
    rerun_exceptions_locally : bool, optional
        Whether to rerun failing tasks in local process to enable debugging
        (False by default)
    pack_exception : callable, optional
        Function to take an exception and ``dumps`` method, and return a
        serialized tuple of ``(exception, traceback)`` to send back to the
        scheduler. Default is to just raise the exception.
    raise_exception : callable, optional
        Function that takes an exception and a traceback, and raises an error.
    callbacks : tuple or list of tuples, optional
        Callbacks are passed in as tuples of length 5. Multiple sets of
        callbacks may be passed in as a list of tuples. For more information,
        see the dask.diagnostics documentation.
    dumps : callable, optional
        Function to serialize task data and results to communicate between
        worker and parent.  Defaults to identity.
    loads : callable, optional
        Inverse function of `dumps`.  Defaults to identity.
    chunksize : int, optional
        Size of chunks to use when dispatching work. Defaults to 1.
        If -1, will be computed to evenly divide ready work across workers.

    See Also
    --------
    threaded.get
    """
    chunksize = chunksize or config.get("chunksize", 1)

    queue = Queue()

    if isinstance(result, list):
        result_flat = set(flatten(result))
    else:
        result_flat = {result}
    results = set(result_flat)

    dsk = dict(dsk)
    with local_callbacks(callbacks) as callbacks:
        _, _, pretask_cbs, posttask_cbs, _ = unpack_callbacks(callbacks)
        started_cbs = []
        succeeded = False
        # if start_state_from_dask fails, we will have something
        # to pass to the final block.
        state = {}
        try:
            for cb in callbacks:
                if cb[0]:
                    cb[0](dsk)
                started_cbs.append(cb)

            keyorder = order(dsk)

            state = start_state_from_dask(dsk,
                                          cache=cache,
                                          sortkey=keyorder.get)

            for _, start_state, _, _, _ in callbacks:
                if start_state:
                    start_state(dsk, state)

            if rerun_exceptions_locally is None:
                rerun_exceptions_locally = config.get(
                    "rerun_exceptions_locally", False)

            if state["waiting"] and not state["ready"]:
                raise ValueError("Found no accessible jobs in dask")

            def fire_tasks(chunksize):
                """Fire off a task to the thread pool"""
                # Determine chunksize and/or number of tasks to submit
                nready = len(state["ready"])
                if chunksize == -1:
                    ntasks = nready
                    chunksize = -(ntasks // -num_workers)
                else:
                    used_workers = -(len(state["running"]) // -chunksize)
                    avail_workers = max(num_workers - used_workers, 0)
                    ntasks = min(nready, chunksize * avail_workers)

                # Prep all ready tasks for submission
                args = []
                for _ in range(ntasks):
                    # Get the next task to compute (most recently added)
                    key = state["ready"].pop()
                    # Notify task is running
                    state["running"].add(key)
                    for f in pretask_cbs:
                        f(key, dsk, state)

                    # Prep args to send
                    data = {
                        dep: state["cache"][dep]
                        for dep in get_dependencies(dsk, key)
                    }
                    args.append((
                        key,
                        dumps((dsk[key], data)),
                        dumps,
                        loads,
                        get_id,
                        pack_exception,
                    ))

                # Batch submit
                for i in range(-(len(args) // -chunksize)):
                    each_args = args[i * chunksize:(i + 1) * chunksize]
                    if not each_args:
                        break
                    fut = submit(batch_execute_tasks, each_args)
                    fut.add_done_callback(queue.put)

            # Main loop, wait on tasks to finish, insert new ones
            while state["waiting"] or state["ready"] or state["running"]:
                fire_tasks(chunksize)
                for key, res_info, failed in queue_get(queue).result():
                    if failed:
                        exc, tb = loads(res_info)
                        if rerun_exceptions_locally:
                            data = {
                                dep: state["cache"][dep]
                                for dep in get_dependencies(dsk, key)
                            }
                            task = dsk[key]
                            _execute_task(task, data)  # Re-execute locally
                        else:
                            raise_exception(exc, tb)
                    res, worker_id = loads(res_info)
                    state["cache"][key] = res
                    finish_task(dsk, key, state, results, keyorder.get)
                    for f in posttask_cbs:
                        f(key, res, dsk, state, worker_id)

            succeeded = True

        finally:
            for _, _, _, _, finish in started_cbs:
                if finish:
                    finish(dsk, state, not succeeded)

    return nested_get(result, state["cache"])
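
This is the signature the threaded scheduler builds on: it receives an Executor's ``submit`` plus a worker count and walks the graph itself. A minimal sketch, assuming this ``get_async`` (e.g. from ``dask.local`` in recent releases) is importable:

from concurrent.futures import ThreadPoolExecutor
from operator import add, mul

dsk = {
    "a": 1,
    "b": (add, "a", 10),  # b = a + 10
    "c": (mul, "b", 2),   # c = b * 2
}

with ThreadPoolExecutor(max_workers=2) as pool:
    print(get_async(pool.submit, 2, dsk, "c"))  # -> 22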
Example #6
def get_async(apply_async,
              num_workers,
              dsk,
              result,
              cache=None,
              get_id=default_get_id,
              rerun_exceptions_locally=None,
              pack_exception=default_pack_exception,
              raise_exception=reraise,
              callbacks=None,
              dumps=identity,
              loads=identity,
              **kwargs):
    """Asynchronous get function
    This is a general version of various asynchronous schedulers for dask.  It
    takes a an apply_async function as found on Pool objects to form a more
    specific ``get`` method that walks through the dask array with parallel
    workers, avoiding repeat computation and minimizing memory use.
    Parameters
    ----------
    apply_async : function
        Asynchronous apply function as found on Pool or ThreadPool
    num_workers : int
        The number of active tasks we should have at any one time
    dsk : dict
        A dask dictionary specifying a workflow
    result : key or list of keys
        Keys corresponding to desired data
    cache : dict-like, optional
        Temporary storage of results
    get_id : callable, optional
        Function to return the worker id, takes no arguments. Examples are
        `threading.current_thread` and `multiprocessing.current_process`.
    rerun_exceptions_locally : bool, optional
        Whether to rerun failing tasks in local process to enable debugging
        (False by default)
    pack_exception : callable, optional
        Function to take an exception and ``dumps`` method, and return a
        serialized tuple of ``(exception, traceback)`` to send back to the
        scheduler. Default is to just raise the exception.
    raise_exception : callable, optional
        Function that takes an exception and a traceback, and raises an error.
    dumps : callable, optional
        Function to serialize task data and results to communicate between
        worker and parent.  Defaults to identity.
    loads : callable, optional
        Inverse function of `dumps`.  Defaults to identity.
    callbacks : tuple or list of tuples, optional
        Callbacks are passed in as tuples of length 5. Multiple sets of
        callbacks may be passed in as a list of tuples. For more information,
        see the dask.diagnostics documentation.

    See Also
    --------
    threaded.get
    """
    queue = Queue()

    if isinstance(result, list):
        result_flat = set(flatten(result))
    else:
        result_flat = {result}
    results = set(result_flat)

    dsk = dict(dsk)
    with local_callbacks(callbacks) as callbacks:
        _, _, pretask_cbs, posttask_cbs, _ = unpack_callbacks(callbacks)
        started_cbs = []
        succeeded = False
        # if start_state_from_dask fails, we will have something
        # to pass to the final block.
        state = {}
        try:
            for cb in callbacks:
                if cb[0]:
                    cb[0](dsk)
                started_cbs.append(cb)

            keyorder = order(dsk)

            state = start_state_from_dask(dsk,
                                          cache=cache,
                                          sortkey=keyorder.get)

            for _, start_state, _, _, _ in callbacks:
                if start_state:
                    start_state(dsk, state)

            if rerun_exceptions_locally is None:
                rerun_exceptions_locally = config.get(
                    "rerun_exceptions_locally", False)

            if state["waiting"] and not state["ready"]:
                raise ValueError("Found no accessible jobs in dask")

            def fire_task():
                """ Fire off a task to the thread pool """
                # Choose a good task to compute
                key = state["ready"].pop()
                state["running"].add(key)
                for f in pretask_cbs:
                    f(key, dsk, state)

                # Prep data to send
                data = {
                    dep: state["cache"][dep]
                    for dep in get_dependencies(dsk, key)
                }
                # Submit
                apply_async(
                    execute_task,
                    args=(
                        key,
                        dumps((dsk[key], data)),
                        dumps,
                        loads,
                        get_id,
                        pack_exception,
                    ),
                    callback=queue.put,
                )

            # Seed initial tasks into the thread pool
            while state["ready"] and len(state["running"]) < num_workers:
                fire_task()

            # Main loop, wait on tasks to finish, insert new ones
            while state["waiting"] or state["ready"] or state["running"]:
                key, res_info, failed = queue_get(queue)
                if failed:
                    exc, tb = loads(res_info)
                    if rerun_exceptions_locally:
                        data = {
                            dep: state["cache"][dep]
                            for dep in get_dependencies(dsk, key)
                        }
                        task = dsk[key]
                        _execute_task(task, data)  # Re-execute locally
                    else:
                        raise_exception(exc, tb)
                res, worker_id = loads(res_info)
                state["cache"][key] = res
                finish_task(dsk, key, state, results, keyorder.get)
                for f in posttask_cbs:
                    f(key, res, dsk, state, worker_id)

                while state["ready"] and len(state["running"]) < num_workers:
                    fire_task()

            succeeded = True

        finally:
            for _, _, _, _, finish in started_cbs:
                if finish:
                    finish(dsk, state, not succeeded)

    return nested_get(result, state["cache"])
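
In this older variant the first argument is a Pool-style ``apply_async`` rather than an Executor ``submit``. A sketch using ``multiprocessing.pool.ThreadPool``, assuming this version of ``get_async`` is the one in scope:

from multiprocessing.pool import ThreadPool
from operator import add

dsk = {"x": 1, "y": (add, "x", 2)}

pool = ThreadPool(2)
try:
    print(get_async(pool.apply_async, 2, dsk, "y"))  # -> 3
finally:
    pool.close()
    pool.join()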
Example #7
def _rayify_task(
        task,
        key,
        deps,
        ray_presubmit_cbs,
        ray_postsubmit_cbs,
        ray_pretask_cbs,
        ray_posttask_cbs,
):
    """
    Rayifies the given task, submitting it as a Ray task to the Ray cluster.

    Args:
        task (tuple): A Dask graph value, being either a literal, dependency
            key, Dask task, or a list thereof.
        key (str): The Dask graph key for the given task.
        deps (dict): The dependencies of this task.
        ray_presubmit_cbs (list): Pre-task submission callbacks.
        ray_postsubmit_cbs (list): Post-task submission callbacks.
        ray_pretask_cbs (list): Pre-task execution callbacks.
        ray_posttask_cbs (list): Post-task execution callbacks.

    Returns:
        A literal, a Ray object reference representing a submitted task, or a
        list thereof.
    """
    if isinstance(task, list):
        # Recursively rayify this list. This will still bottom out at the first
        # actual task encountered, inlining any tasks in that task's arguments.
        return [
            _rayify_task(
                t,
                key,
                deps,
                ray_presubmit_cbs,
                ray_postsubmit_cbs,
                ray_pretask_cbs,
                ray_posttask_cbs,
            ) for t in task
        ]
    elif istask(task):
        # Unpacks and repacks Ray object references and submits the task to the
        # Ray cluster for execution.
        if ray_presubmit_cbs is not None:
            alternate_returns = [
                cb(task, key, deps) for cb in ray_presubmit_cbs
            ]
            for alternate_return in alternate_returns:
                # We don't submit a Ray task if a presubmit callback returns
                # a non-`None` value; instead, we return that value.
                # NOTE: This returns the first non-None presubmit callback
                # return value.
                if alternate_return is not None:
                    return alternate_return

        func, args = task[0], task[1:]
        if func is multiple_return_get:
            return _execute_task(task, deps)
        # If the function's arguments contain nested object references, we must
        # unpack said object references into a flat set of arguments so that
        # Ray properly tracks the object dependencies between Ray tasks.
        arg_object_refs, repack = unpack_object_refs(args, deps)
        # Submit the task using a wrapper function.
        object_refs = dask_task_wrapper.options(
            name=f"dask:{key!s}",
            num_returns=(1 if not isinstance(func, MultipleReturnFunc) else
                         func.num_returns),
        ).remote(
            func,
            repack,
            key,
            ray_pretask_cbs,
            ray_posttask_cbs,
            *arg_object_refs,
        )

        if ray_postsubmit_cbs is not None:
            for cb in ray_postsubmit_cbs:
                cb(task, key, deps, object_refs)

        return object_refs
    elif not ishashable(task):
        return task
    elif task in deps:
        return deps[task]
    else:
        return task
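
A hypothetical single-task illustration of the call pattern, assuming `_rayify_task` is used from its own module (so `unpack_object_refs` and the remote `dask_task_wrapper` are available) and that Ray has been initialized:

import ray
from operator import add

ray.init(ignore_reinit_error=True)

x_ref = ray.put(1)           # dependency already rayified to an ObjectRef
ref = _rayify_task(
    (add, "x", 2),           # the Dask task
    "y",                     # its Dask key
    {"x": x_ref},            # resolved dependencies
    None, None, None, None,  # no presubmit/postsubmit/pretask/posttask cbs
)
print(ray.get(ref))          # -> 3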