Example no. 1
def ray_dask_get(dsk, keys, **kwargs):
    """
    A Dask-Ray scheduler. This scheduler will send top-level (non-inlined) Dask
    tasks to a Ray cluster for execution. The scheduler will wait for the
    tasks to finish executing, fetch the results, and repackage them into the
    appropriate Dask collections. This particular scheduler uses a threadpool
    to submit Ray tasks.

    This can be passed directly to `dask.compute()`, as the scheduler:

    >>> dask.compute(obj, scheduler=ray_dask_get)

    You can override the currently active global Dask-Ray callbacks (e.g.
    supplied via a context manager), the number of threads to use when
    submitting the Ray tasks, or the threadpool used to submit Ray tasks:

    >>> dask.compute(
            obj,
            scheduler=ray_dask_get,
            ray_callbacks=some_ray_dask_callbacks,
            num_workers=8,
            pool=some_cool_pool,
        )

    Args:
        dsk: Dask graph, represented as a task DAG dictionary.
        keys (List[str]): List of Dask graph keys whose values we wish to
            compute and return.
        ray_callbacks (Optional[list[callable]]): Dask-Ray callbacks.
        num_workers (Optional[int]): The number of worker threads to use in
            the Ray task submission traversal of the Dask graph.
        pool (Optional[ThreadPool]): A multiprocessing threadpool to use to
            submit Ray tasks.
        ray_remote_args (Optional[dict]): Ray remote arguments (e.g. num_cpus)
            applied to the submitted Ray tasks, merged with any per-key Dask
            annotations. A top-level "resources" entry is rejected.
        ray_persist (bool): If True, return Ray object references instead of
            fetching the computed values. Defaults to False.

    Returns:
        Computed values corresponding to the provided keys.
    """
    num_workers = kwargs.pop("num_workers", None)
    pool = kwargs.pop("pool", None)
    # We attempt to reuse any other thread pools that have been created within
    # this thread and with the given number of workers. We reuse a global
    # thread pool if num_workers is not given and we're in the main thread.
    global default_pool
    thread = threading.current_thread()
    if pool is None:
        with pools_lock:
            if num_workers is None and thread is main_thread:
                if default_pool is None:
                    default_pool = ThreadPool(CPU_COUNT)
                    atexit.register(default_pool.close)
                pool = default_pool
            elif thread in pools and num_workers in pools[thread]:
                pool = pools[thread][num_workers]
            else:
                pool = ThreadPool(num_workers)
                atexit.register(pool.close)
                pools[thread][num_workers] = pool

    ray_callbacks = kwargs.pop("ray_callbacks", None)
    persist = kwargs.pop("ray_persist", False)
    enable_progress_bar = kwargs.pop("_ray_enable_progress_bar", None)

    # Handle Ray remote args and resource annotations.
    if "resources" in kwargs:
        raise ValueError(TOP_LEVEL_RESOURCES_ERR_MSG)
    ray_remote_args = kwargs.pop("ray_remote_args", {})
    try:
        annotations = dask.config.get("annotations")
    except KeyError:
        annotations = {}
    if "resources" in annotations:
        raise ValueError(TOP_LEVEL_RESOURCES_ERR_MSG)

    scoped_ray_remote_args = _build_key_scoped_ray_remote_args(
        dsk, annotations, ray_remote_args)

    with local_ray_callbacks(ray_callbacks) as ray_callbacks:
        # Unpack the Ray-specific callbacks.
        (
            ray_presubmit_cbs,
            ray_postsubmit_cbs,
            ray_pretask_cbs,
            ray_posttask_cbs,
            ray_postsubmit_all_cbs,
            ray_finish_cbs,
        ) = unpack_ray_callbacks(ray_callbacks)
        # NOTE: We hijack Dask's `get_async` function, injecting a different
        # task executor.
        object_refs = get_async(
            _apply_async_wrapper(
                pool.apply_async,
                _rayify_task_wrapper,
                ray_presubmit_cbs,
                ray_postsubmit_cbs,
                ray_pretask_cbs,
                ray_posttask_cbs,
                scoped_ray_remote_args,
            ),
            len(pool._pool),
            dsk,
            keys,
            get_id=_thread_get_id,
            pack_exception=pack_exception,
            **kwargs,
        )
        if ray_postsubmit_all_cbs is not None:
            for cb in ray_postsubmit_all_cbs:
                cb(object_refs, dsk)
        # NOTE: We explicitly delete the Dask graph here so object references
        # are garbage-collected before this function returns, i.e. before all
        # Ray tasks are done. Otherwise, no intermediate objects will be
        # cleaned up until all Ray tasks are done.
        del dsk
        if persist:
            result = object_refs
        else:
            pb_actor = None
            if enable_progress_bar:
                pb_actor = ray.get_actor("_dask_on_ray_pb")
            result = ray_get_unpack(object_refs, progress_bar_actor=pb_actor)
        if ray_finish_cbs is not None:
            for cb in ray_finish_cbs:
                cb(result)

    # Clean up pools associated with dead threads.
    with pools_lock:
        active_threads = set(threading.enumerate())
        if thread is not main_thread:
            for t in list(pools):
                if t not in active_threads:
                    for p in pools.pop(t).values():
                        p.close()
    return result
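
A minimal usage sketch for the scheduler above. It assumes ray and dask are installed and that the function is importable as ray.util.dask.ray_dask_get (the import path is an assumption; only the function body is shown in this section). The array shape and the num_cpus value are illustrative.

import dask
import dask.array as da
import ray

from ray.util.dask import ray_dask_get  # assumed import path

ray.init()

# Build a small Dask collection; its top-level tasks are submitted to Ray
# when computed with the scheduler above.
x = da.random.random((1000, 1000), chunks=(100, 100))

# Pass the scheduler directly to dask.compute(), as in the docstring.
(total,) = dask.compute(x.sum(), scheduler=ray_dask_get)

# Per-key Ray options flow in through Dask annotations, which the scheduler
# reads from dask.config ("annotations") and merges with ray_remote_args in
# _build_key_scoped_ray_remote_args. The "ray_remote_args" annotation key is
# an assumption here; note that a top-level "resources" entry is rejected by
# the checks above. Depending on the Dask version, optimize_graph=False may
# be needed so graph optimization does not drop the annotations.
with dask.annotate(ray_remote_args=dict(num_cpus=1)):
    y = (x + x.T).sum()
(total,) = dask.compute(y, scheduler=ray_dask_get, optimize_graph=False)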
Example no. 2
def ray_dask_get_sync(dsk, keys, **kwargs):
    """
    A synchronous Dask-Ray scheduler. This scheduler will send top-level
    (non-inlined) Dask tasks to a Ray cluster for execution. The scheduler will
    wait for the tasks to finish executing, fetch the results, and repackage
    them into the appropriate Dask collections. This particular scheduler
    submits Ray tasks synchronously, which can be useful for debugging.

    This can be passed directly to `dask.compute()`, as the scheduler:

    >>> dask.compute(obj, scheduler=ray_dask_get_sync)

    You can override the currently active global Dask-Ray callbacks (e.g.
    supplied via a context manager):

    >>> dask.compute(
            obj,
            scheduler=ray_dask_get_sync,
            ray_callbacks=some_ray_dask_callbacks,
        )

    Args:
        dsk: Dask graph, represented as a task DAG dictionary.
        keys (List[str]): List of Dask graph keys whose values we wish to
            compute and return.
        ray_callbacks (Optional[list[callable]]): Dask-Ray callbacks.
        ray_persist (bool): If True, return Ray object references instead of
            fetching the computed values. Defaults to False.

    Returns:
        Computed values corresponding to the provided keys.
    """

    ray_callbacks = kwargs.pop("ray_callbacks", None)
    persist = kwargs.pop("ray_persist", False)

    with local_ray_callbacks(ray_callbacks) as ray_callbacks:
        # Unpack the Ray-specific callbacks.
        (
            ray_presubmit_cbs,
            ray_postsubmit_cbs,
            ray_pretask_cbs,
            ray_posttask_cbs,
            ray_postsubmit_all_cbs,
            ray_finish_cbs,
        ) = unpack_ray_callbacks(ray_callbacks)
        # NOTE: We hijack Dask's `get_async` function, injecting a different
        # task executor.
        object_refs = get_async(
            _apply_async_wrapper(
                apply_sync,
                _rayify_task_wrapper,
                ray_presubmit_cbs,
                ray_postsubmit_cbs,
                ray_pretask_cbs,
                ray_posttask_cbs,
            ),
            1,
            dsk,
            keys,
            **kwargs,
        )
        if ray_postsubmit_all_cbs is not None:
            for cb in ray_postsubmit_all_cbs:
                cb(object_refs, dsk)
        # NOTE: We explicitly delete the Dask graph here so object references
        # are garbage-collected before this function returns, i.e. before all
        # Ray tasks are done. Otherwise, no intermediate objects will be
        # cleaned up until all Ray tasks are done.
        del dsk
        if persist:
            result = object_refs
        else:
            result = ray_get_unpack(object_refs)
        if ray_finish_cbs is not None:
            for cb in ray_finish_cbs:
                cb(result)

        return result
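
A sketch of the synchronous scheduler together with the callback override its docstring mentions. RayDaskCallback and its _ray_pretask/_ray_posttask hooks are assumed to come from ray.util.dask (they are not shown in this section), and the timing logic is purely illustrative.

import time

import dask
import dask.array as da
import ray

# Assumed imports; neither name is defined in the excerpts above.
from ray.util.dask import RayDaskCallback, ray_dask_get_sync


class TimerCallback(RayDaskCallback):
    """Illustrative callback: time each Dask task executed as a Ray task."""

    def _ray_pretask(self, key, object_refs):
        # Runs inside the Ray task before the Dask task executes; the return
        # value is handed to _ray_posttask as pre_state.
        return time.time()

    def _ray_posttask(self, key, result, pre_state):
        # Runs inside the Ray task after the Dask task executes.
        print(f"{key} took {time.time() - pre_state:.3f}s")


ray.init()
x = da.ones((100, 100), chunks=(10, 10))

# Tasks are submitted one at a time (apply_sync with a single worker slot,
# the 1 passed to get_async), which makes this scheduler convenient to step
# through in a debugger; it trades throughput for determinism.
with TimerCallback():
    (total,) = dask.compute(x.sum(), scheduler=ray_dask_get_sync)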