コード例 #1
0
ファイル: scheduler.py プロジェクト: MetaMind/ray-internal
def ray_dask_get_sync(dsk, keys, **kwargs):
    """
    Synchronous flavour of the Dask-on-Ray scheduler.

    Top-level (non-inlined) Dask tasks are handed to a Ray cluster for
    execution; the scheduler then waits for them, fetches the results and
    repackages them into the matching Dask collections. Tasks are submitted
    synchronously (a single worker slot), which makes this variant handy for
    debugging.

    Pass it straight to `dask.compute()` as the scheduler:

    >>> dask.compute(obj, scheduler=ray_dask_get_sync)

    Args:
        dsk (Dict): Dask graph, represented as a task DAG dictionary.
        keys (List[str]): Dask graph keys whose values should be computed
            and returned.
        **kwargs: Forwarded unchanged to Dask's `get_async`.

    Returns:
        Computed values corresponding to the provided keys.
    """
    # Dask's own `get_async` drives the graph traversal; we only swap in a
    # task executor that submits to Ray instead of running locally.
    submit = _apply_async_wrapper(apply_sync, _rayify_task_wrapper)
    refs = get_async(submit, 1, dsk, keys, **kwargs)

    # Drop our reference to the graph *before* blocking on the results so the
    # object references it holds can be garbage-collected while Ray tasks are
    # still running; otherwise intermediates would linger until every task
    # has finished.
    del dsk
    return ray_get_unpack(refs)
コード例 #2
0
def get(dsk, keys, **kwargs):
    """Compute ``keys`` from the Dask graph ``dsk`` through a remote invoker.

    Scheduler-specific options are popped from ``kwargs``; everything that
    remains is forwarded to ``get_async``.

    Args:
        dsk: Dask graph to execute.
        keys: Key (or keys) of the graph whose values are requested.
        num_workers (int, optional): Number of submission threads. Defaults
            to 100. Any value that is falsy or not greater than 1 selects the
            synchronous path.
        lambda_name (str, optional): Bound as the ``lambda_name`` keyword of
            the invoker. Defaults to ``'sprite'``.
        invoker (callable, optional): Callable used to run a task remotely.
            Defaults to ``sprite2.aws.remote``.

    Returns:
        Whatever ``get_async`` returns for ``keys``.
    """
    num_workers = kwargs.pop('num_workers', 100)
    lambda_name = kwargs.pop('lambda_name', 'sprite')
    invoker = kwargs.pop('invoker', sprite2.aws.remote)
    invoker = partial(invoker, lambda_name=lambda_name)

    if num_workers and num_workers > 1:
        executor = concurrent.futures.ThreadPoolExecutor(num_workers)
        try:
            apply = apply_async_lambda_factory(
                executor=executor,
                invoker=invoker,
            )
            return get_async(apply, num_workers, dsk, keys, **kwargs)
        finally:
            # The executor is private to this call, so release its threads
            # instead of leaking a fresh pool on every invocation.
            # wait=False keeps the caller from blocking on stragglers when
            # get_async raised part-way through the graph.
            executor.shutdown(wait=False)

    # Synchronous path: always advertise exactly one worker slot. Forwarding
    # a falsy num_workers (0 / None) here would hand get_async an unusable
    # concurrency limit.
    apply = apply_sync_lambda_factory(invoker=invoker)
    return get_async(apply, 1, dsk, keys, **kwargs)
コード例 #3
0
ファイル: scheduler.py プロジェクト: MetaMind/ray-internal
def ray_dask_get(dsk, keys, **kwargs):
    """
    Threaded Dask-on-Ray scheduler.

    Top-level (non-inlined) Dask tasks are sent to a Ray cluster for
    execution. The scheduler waits for the tasks to finish, fetches the
    results and repackages them into the appropriate Dask collections. Ray
    tasks are submitted from a thread pool.

    Pass it straight to `dask.compute()` as the scheduler:

    >>> dask.compute(obj, scheduler=ray_dask_get)

    The number of submission threads, or the pool itself, can be overridden:

    >>> dask.compute(
    ...     obj,
    ...     scheduler=ray_dask_get,
    ...     num_workers=8,
    ...     pool=some_cool_pool,
    ... )

    Args:
        dsk (Dict): Dask graph, represented as a task DAG dictionary.
        keys (List[str]): Dask graph keys whose values should be computed
            and returned.
        num_workers (Optional[int]): Number of worker threads used for the
            Ray task submission traversal of the Dask graph.
        pool (Optional[ThreadPool]): A multiprocessing thread pool used to
            submit Ray tasks. When given, no pool is looked up or created.

    Returns:
        Computed values corresponding to the provided keys.
    """
    global default_pool

    num_workers = kwargs.pop("num_workers", None)
    pool = kwargs.pop("pool", None)
    caller = threading.current_thread()

    if pool is None:
        # Pool selection reads and mutates module-level state, so it runs
        # under the shared lock.
        with pools_lock:
            if num_workers is None and caller is main_thread:
                # Main thread without an explicit size: share one global
                # pool, created lazily on first use.
                if default_pool is None:
                    default_pool = ThreadPool(CPU_COUNT)
                    atexit.register(default_pool.close)
                pool = default_pool
            elif caller not in pools or num_workers not in pools[caller]:
                # Nothing cached for this (thread, size) pair yet.
                pool = ThreadPool(num_workers)
                atexit.register(pool.close)
                pools[caller][num_workers] = pool
            else:
                # Reuse the pool this thread already created for this size.
                pool = pools[caller][num_workers]

    # Dask's own `get_async` drives the graph traversal; we only swap in a
    # task executor that submits to Ray through the pool.
    submit = _apply_async_wrapper(pool.apply_async, _rayify_task_wrapper)
    worker_count = len(pool._pool)
    object_refs = get_async(
        submit,
        worker_count,
        dsk,
        keys,
        get_id=_thread_get_id,
        pack_exception=pack_exception,
        **kwargs,
    )

    # Drop our reference to the graph *before* blocking on the results so the
    # object references it holds can be garbage-collected while Ray tasks are
    # still running; otherwise intermediates would linger until every task
    # has finished.
    del dsk
    result = ray_get_unpack(object_refs)

    # Close and forget pools whose owning thread has died. Only done when
    # called from a non-main thread.
    with pools_lock:
        alive = set(threading.enumerate())
        if caller is not main_thread:
            dead = [t for t in pools if t not in alive]
            for t in dead:
                for stale_pool in pools.pop(t).values():
                    stale_pool.close()
    return result
コード例 #4
0
ファイル: threaded.py プロジェクト: jakirkham/dask
def get(dsk, result, cache=None, num_workers=None, pool=None, **kwargs):
    """Threaded cached implementation of dask.get

    Parameters
    ----------

    dsk: dict
        A dask dictionary specifying a workflow
    result: key or list of keys
        Keys corresponding to desired data
    cache: dict-like (optional)
        Temporary storage of results
    num_workers: integer of thread count
        The number of threads to use in the pool that will actually execute
        tasks. Falls back to the ``num_workers`` config value.
    pool: executor or multiprocessing pool (optional)
        Pool to submit tasks to. Falls back to the ``pool`` config value; a
        ``multiprocessing.pool.Pool`` is wrapped in
        ``MultiprocessingPoolExecutor``. When no pool is supplied, one is
        reused or created per (thread, num_workers) pair.
    **kwargs
        Forwarded to ``get_async``.

    Examples
    --------
    >>> inc = lambda x: x + 1
    >>> add = lambda x, y: x + y
    >>> dsk = {'x': 1, 'y': 2, 'z': (inc, 'x'), 'w': (add, 'z', 'y')}
    >>> get(dsk, 'w')
    4
    >>> get(dsk, ['w', 'y'])
    (4, 2)
    """
    global default_pool

    # Explicit arguments win over configuration.
    pool = pool or config.get("pool", None)
    num_workers = num_workers or config.get("num_workers", None)
    caller = current_thread()

    # Pool selection reads and mutates module-level state, so it runs under
    # the shared lock.
    with pools_lock:
        if pool is not None:
            # Caller-supplied pool: adapt legacy multiprocessing pools to the
            # executor interface used below.
            if isinstance(pool, multiprocessing.pool.Pool):
                pool = MultiprocessingPoolExecutor(pool)
        elif num_workers is None and caller is main_thread:
            # Main thread without an explicit size: share one global
            # executor, created lazily on first use.
            if default_pool is None:
                default_pool = ThreadPoolExecutor(CPU_COUNT)
                atexit.register(default_pool.shutdown)
            pool = default_pool
        elif caller not in pools or num_workers not in pools[caller]:
            # Nothing cached for this (thread, size) pair yet.
            pool = ThreadPoolExecutor(num_workers)
            atexit.register(pool.shutdown)
            pools[caller][num_workers] = pool
        else:
            pool = pools[caller][num_workers]

    results = get_async(
        pool.submit,
        pool._max_workers,
        dsk,
        result,
        cache=cache,
        get_id=_thread_get_id,
        pack_exception=pack_exception,
        **kwargs,
    )

    # Shut down and forget pools whose owning thread has died. Only done
    # when called from a non-main thread.
    with pools_lock:
        alive = set(threading.enumerate())
        if caller is not main_thread:
            dead = [t for t in pools if t not in alive]
            for t in dead:
                for stale_pool in pools.pop(t).values():
                    stale_pool.shutdown()

    return results
コード例 #5
0
ファイル: multiprocessing.py プロジェクト: m-rossi/dask
def get(
    dsk: Mapping,
    keys: Sequence[Hashable] | Hashable,
    num_workers=None,
    func_loads=None,
    func_dumps=None,
    optimize_graph=True,
    pool=None,
    initializer=None,
    chunksize=None,
    **kwargs,
):
    """Multiprocessed get function appropriate for Bags

    Parameters
    ----------
    dsk : dict
        dask graph
    keys : object or list
        Desired results from graph
    num_workers : int
        Number of worker processes (falls back to the ``num_workers`` config
        value, then to ``CPU_COUNT``)
    func_dumps : function
        Function to use for function serialization (falls back to the
        ``func_dumps`` config value, then to ``_dumps``)
    func_loads : function
        Function to use for function deserialization (falls back to the
        ``func_loads`` config value, then to ``_loads``)
    optimize_graph : bool
        If True [default], `fuse` is applied to the graph before computation.
    pool : Executor or Pool
        Some sort of `Executor` or `Pool` to use. A pool supplied here (or
        via config) is not shut down by this function; a pool created here
        is.
    initializer: function
        Ignored if ``pool`` has been set.
        Function to initialize a worker process before running any tasks in it.
    chunksize: int, optional
        Size of chunks to use when dispatching work.
        Falls back to the ``chunksize`` config value, then to 6.
        If -1, will be computed to evenly divide ready work across workers.
    """
    # Explicit arguments win; then configuration; then built-in defaults.
    chunksize = chunksize or config.get("chunksize", 6)
    pool = pool or config.get("pool", None)
    initializer = initializer or config.get("multiprocessing.initializer", None)
    num_workers = num_workers or config.get("num_workers", None) or CPU_COUNT

    # We only tear down a pool that we created ourselves.
    cleanup = pool is None
    if cleanup:
        # Subprocesses need a consistent seed for Python's hash algorithm to
        # hash consistently. Environment variables cannot be set for the
        # pool processes alone, so we set it here and rely on inheritance.
        if os.environ.get("PYTHONHASHSEED") in (None, "0"):
            # Arbitrary value, chosen to commemorate
            # https://github.com/dask/dask/issues/6640.
            os.environ["PYTHONHASHSEED"] = "6640"
        pool = ProcessPoolExecutor(
            num_workers,
            mp_context=get_context(),
            initializer=partial(
                initialize_worker_process, user_initializer=initializer
            ),
        )
    else:
        if initializer is not None:
            warn(
                "The ``initializer`` argument is ignored when ``pool`` is provided. "
                "The user should configure ``pool`` with the needed ``initializer`` "
                "on creation."
            )
        if isinstance(pool, multiprocessing.pool.Pool):
            pool = MultiprocessingPoolExecutor(pool)

    # Optimize the graph: cull to the requested keys, then optionally fuse.
    dsk = ensure_dict(dsk)
    culled, dependencies = cull(dsk, keys)
    if optimize_graph:
        graph, dependencies = fuse(culled, keys, dependencies)
    else:
        graph = culled

    # Marshalling functions are passed explicitly so serialization errors
    # are caught and reported to the user.
    loads = func_loads or config.get("func_loads", None) or _loads
    dumps = func_dumps or config.get("func_dumps", None) or _dumps

    # Note former versions used a multiprocessing Manager to share
    # a Queue between parent and workers, but this is fragile on Windows
    # (issue #1652).
    try:
        return get_async(
            pool.submit,
            pool._max_workers,
            graph,
            keys,
            get_id=_process_get_id,
            dumps=dumps,
            loads=loads,
            pack_exception=pack_exception,
            raise_exception=reraise,
            chunksize=chunksize,
            **kwargs,
        )
    finally:
        if cleanup:
            pool.shutdown()
コード例 #6
0
def ray_dask_get_sync(dsk, keys, **kwargs):
    """
    Synchronous flavour of the Dask-on-Ray scheduler, with callback support.

    Top-level (non-inlined) Dask tasks are handed to a Ray cluster for
    execution; the scheduler then waits for them, fetches the results and
    repackages them into the matching Dask collections. Tasks are submitted
    synchronously (a single worker slot), which makes this variant handy for
    debugging.

    Pass it straight to `dask.compute()` as the scheduler:

    >>> dask.compute(obj, scheduler=ray_dask_get_sync)

    The currently active global Dask-Ray callbacks (e.g. those supplied via
    a context manager) can be overridden per call:

    >>> dask.compute(
    ...     obj,
    ...     scheduler=ray_dask_get_sync,
    ...     ray_callbacks=some_ray_dask_callbacks,
    ... )

    Args:
        dsk (Dict): Dask graph, represented as a task DAG dictionary.
        keys (List[str]): Dask graph keys whose values should be computed
            and returned.
        ray_callbacks: Optional Dask-Ray callbacks for this call; popped
            from ``kwargs`` before the rest is forwarded to `get_async`.

    Returns:
        Computed values corresponding to the provided keys.
    """
    ray_callbacks = kwargs.pop("ray_callbacks", None)

    with local_ray_callbacks(ray_callbacks) as ray_callbacks:
        # Split the Ray-specific callbacks into their per-stage groups.
        (
            presubmit_cbs,
            postsubmit_cbs,
            pretask_cbs,
            posttask_cbs,
            postsubmit_all_cbs,
            finish_cbs,
        ) = unpack_ray_callbacks(ray_callbacks)

        # Dask's own `get_async` drives the graph traversal; we only swap in
        # a task executor that submits to Ray and is given the per-task
        # callback groups.
        submit = _apply_async_wrapper(
            apply_sync,
            _rayify_task_wrapper,
            presubmit_cbs,
            postsubmit_cbs,
            pretask_cbs,
            posttask_cbs,
        )
        object_refs = get_async(submit, 1, dsk, keys, **kwargs)

        # Post-submit-all callbacks run once get_async has returned and
        # receive the object refs together with the graph.
        if postsubmit_all_cbs is not None:
            for callback in postsubmit_all_cbs:
                callback(object_refs, dsk)

        # Drop our reference to the graph *before* blocking on the results
        # so the object references it holds can be garbage-collected while
        # Ray tasks are still running; otherwise intermediates would linger
        # until every task has finished.
        del dsk
        result = ray_get_unpack(object_refs)

        # Finish callbacks receive the fully materialised result.
        if finish_cbs is not None:
            for callback in finish_cbs:
                callback(result)

        return result