def ray_dask_get_sync(dsk, keys, **kwargs):
    """
    A synchronous Dask-Ray scheduler. This scheduler will send top-level
    (non-inlined) Dask tasks to a Ray cluster for execution. The scheduler
    will wait for the tasks to finish executing, fetch the results, and
    repackage them into the appropriate Dask collections. This particular
    scheduler submits Ray tasks synchronously, which can be useful for
    debugging.

    This can be passed directly to `dask.compute()`, as the scheduler:

    >>> dask.compute(obj, scheduler=ray_dask_get_sync)

    Args:
        dsk (Dict): Dask graph, represented as a task DAG dictionary.
        keys (List[str]): List of Dask graph keys whose values we wish to
            compute and return.

    Returns:
        Computed values corresponding to the provided keys.
    """
    # NOTE: We hijack Dask's `get_async` function, injecting a different task
    # executor.
    object_refs = get_async(
        _apply_async_wrapper(apply_sync, _rayify_task_wrapper),
        1,
        dsk,
        keys,
        **kwargs,
    )
    # NOTE: We explicitly delete the Dask graph here so object references
    # are garbage-collected before this function returns, i.e. before all Ray
    # tasks are done. Otherwise, no intermediate objects will be cleaned up
    # until all Ray tasks are done.
    del dsk
    return ray_get_unpack(object_refs)

def get(dsk, keys, **kwargs):
    """A Dask scheduler that executes tasks on AWS Lambda.

    Tasks are submitted through ``invoker`` (by default ``sprite2.aws.remote``)
    against the Lambda function named by ``lambda_name``; with more than one
    worker, submissions are fanned out over a thread pool.
    """
    num_workers = kwargs.pop('num_workers', 100)
    lambda_name = kwargs.pop('lambda_name', 'sprite')
    invoker = kwargs.pop('invoker', sprite2.aws.remote)
    invoker = partial(invoker, lambda_name=lambda_name)
    if num_workers and num_workers > 1:
        executor = concurrent.futures.ThreadPoolExecutor(num_workers)
        apply = apply_async_lambda_factory(
            executor=executor,
            invoker=invoker,
        )
    else:
        apply = apply_sync_lambda_factory(invoker=invoker)
    return get_async(apply, num_workers, dsk, keys, **kwargs)

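# A minimal usage sketch for the Lambda-backed scheduler above, assuming
# `sprite2` and the helper factories are importable and that an AWS Lambda
# function named 'my-dask-fn' has been deployed. The function name and graph
# are hypothetical illustrations; only the `lambda_name`/`num_workers`
# keywords come from the code above.
def double(x):
    return x * 2

dsk_example = {'a': 21, 'b': (double, 'a')}
result = get(dsk_example, 'b', lambda_name='my-dask-fn', num_workers=4)
# result == 42, with the task executed remotely on the 'my-dask-fn' Lambda.
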
def ray_dask_get(dsk, keys, **kwargs):
    """
    A Dask-Ray scheduler. This scheduler will send top-level (non-inlined)
    Dask tasks to a Ray cluster for execution. The scheduler will wait for
    the tasks to finish executing, fetch the results, and repackage them into
    the appropriate Dask collections. This particular scheduler uses a
    threadpool to submit Ray tasks.

    This can be passed directly to `dask.compute()`, as the scheduler:

    >>> dask.compute(obj, scheduler=ray_dask_get)

    You can override the number of threads to use when submitting the Ray
    tasks, or the threadpool used to submit Ray tasks:

    >>> dask.compute(
            obj,
            scheduler=ray_dask_get,
            num_workers=8,
            pool=some_cool_pool,
        )

    Args:
        dsk (Dict): Dask graph, represented as a task DAG dictionary.
        keys (List[str]): List of Dask graph keys whose values we wish to
            compute and return.
        num_workers (Optional[int]): The number of worker threads to use in
            the Ray task submission traversal of the Dask graph.
        pool (Optional[ThreadPool]): A multiprocessing threadpool to use to
            submit Ray tasks.

    Returns:
        Computed values corresponding to the provided keys.
    """
    num_workers = kwargs.pop("num_workers", None)
    pool = kwargs.pop("pool", None)
    # We attempt to reuse any other thread pools that have been created within
    # this thread and with the given number of workers. We reuse a global
    # thread pool if num_workers is not given and we're in the main thread.
    global default_pool
    thread = threading.current_thread()
    if pool is None:
        with pools_lock:
            if num_workers is None and thread is main_thread:
                if default_pool is None:
                    default_pool = ThreadPool(CPU_COUNT)
                    atexit.register(default_pool.close)
                pool = default_pool
            elif thread in pools and num_workers in pools[thread]:
                pool = pools[thread][num_workers]
            else:
                pool = ThreadPool(num_workers)
                atexit.register(pool.close)
                pools[thread][num_workers] = pool

    # NOTE: We hijack Dask's `get_async` function, injecting a different task
    # executor.
    object_refs = get_async(
        _apply_async_wrapper(pool.apply_async, _rayify_task_wrapper),
        len(pool._pool),
        dsk,
        keys,
        get_id=_thread_get_id,
        pack_exception=pack_exception,
        **kwargs,
    )
    # NOTE: We explicitly delete the Dask graph here so object references
    # are garbage-collected before this function returns, i.e. before all Ray
    # tasks are done. Otherwise, no intermediate objects will be cleaned up
    # until all Ray tasks are done.
    del dsk
    result = ray_get_unpack(object_refs)

    # Cleanup pools associated with dead threads.
    with pools_lock:
        active_threads = set(threading.enumerate())
        if thread is not main_thread:
            for t in list(pools):
                if t not in active_threads:
                    for p in pools.pop(t).values():
                        p.close()

    return result

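# Sketch: overriding the submission threadpool, as the docstring above
# mentions. A plain multiprocessing ThreadPool is handed in via `pool`, so the
# scheduler bypasses its per-thread pool cache. `ray.init()` is assumed to
# have been called already, and `inc` is a hypothetical stand-in computation.
from multiprocessing.pool import ThreadPool

import dask

@dask.delayed
def inc(x):
    return x + 1

submission_pool = ThreadPool(8)
(result,) = dask.compute(inc(1), scheduler=ray_dask_get, pool=submission_pool)
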
def get(dsk, result, cache=None, num_workers=None, pool=None, **kwargs):
    """Threaded cached implementation of dask.get

    Parameters
    ----------

    dsk: dict
        A dask dictionary specifying a workflow
    result: key or list of keys
        Keys corresponding to desired data
    num_workers: integer of thread count
        The number of threads to use in the ThreadPool that will actually
        execute tasks
    cache: dict-like (optional)
        Temporary storage of results

    Examples
    --------
    >>> inc = lambda x: x + 1
    >>> add = lambda x, y: x + y
    >>> dsk = {'x': 1, 'y': 2, 'z': (inc, 'x'), 'w': (add, 'z', 'y')}
    >>> get(dsk, 'w')
    4
    >>> get(dsk, ['w', 'y'])
    (4, 2)
    """
    global default_pool
    pool = pool or config.get("pool", None)
    num_workers = num_workers or config.get("num_workers", None)
    thread = current_thread()

    with pools_lock:
        if pool is None:
            if num_workers is None and thread is main_thread:
                if default_pool is None:
                    default_pool = ThreadPoolExecutor(CPU_COUNT)
                    atexit.register(default_pool.shutdown)
                pool = default_pool
            elif thread in pools and num_workers in pools[thread]:
                pool = pools[thread][num_workers]
            else:
                pool = ThreadPoolExecutor(num_workers)
                atexit.register(pool.shutdown)
                pools[thread][num_workers] = pool
        elif isinstance(pool, multiprocessing.pool.Pool):
            pool = MultiprocessingPoolExecutor(pool)

    results = get_async(
        pool.submit,
        pool._max_workers,
        dsk,
        result,
        cache=cache,
        get_id=_thread_get_id,
        pack_exception=pack_exception,
        **kwargs,
    )

    # Cleanup pools associated to dead threads
    with pools_lock:
        active_threads = set(threading.enumerate())
        if thread is not main_thread:
            for t in list(pools):
                if t not in active_threads:
                    for p in pools.pop(t).values():
                        p.shutdown()

    return results

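# Sketch: passing a pre-built executor instead of relying on the shared
# default pool. The code above uses a provided `pool` directly (wrapping it
# only if it is a multiprocessing Pool), so its worker count governs
# parallelism. The graph here is a hypothetical illustration.
from concurrent.futures import ThreadPoolExecutor

my_pool = ThreadPoolExecutor(4)
dsk_example = {"x": 1, "y": (lambda x: x + 1, "x")}
print(get(dsk_example, "y", pool=my_pool))  # -> 2
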
def get(
    dsk: Mapping,
    keys: Sequence[Hashable] | Hashable,
    num_workers=None,
    func_loads=None,
    func_dumps=None,
    optimize_graph=True,
    pool=None,
    initializer=None,
    chunksize=None,
    **kwargs,
):
    """Multiprocessed get function appropriate for Bags

    Parameters
    ----------
    dsk : dict
        dask graph
    keys : object or list
        Desired results from graph
    num_workers : int
        Number of worker processes (defaults to number of cores)
    func_dumps : function
        Function to use for function serialization (defaults to
        cloudpickle.dumps)
    func_loads : function
        Function to use for function deserialization (defaults to
        cloudpickle.loads)
    optimize_graph : bool
        If True [default], `fuse` is applied to the graph before computation.
    pool : Executor or Pool
        Some sort of `Executor` or `Pool` to use
    initializer: function
        Ignored if ``pool`` has been set.
        Function to initialize a worker process before running any tasks in it.
    chunksize: int, optional
        Size of chunks to use when dispatching work. Defaults to 6 as some
        batching is helpful. If -1, will be computed to evenly divide ready
        work across workers.
    """
    chunksize = chunksize or config.get("chunksize", 6)
    pool = pool or config.get("pool", None)
    initializer = initializer or config.get("multiprocessing.initializer", None)
    num_workers = num_workers or config.get("num_workers", None) or CPU_COUNT
    if pool is None:
        # In order to get consistent hashing in subprocesses, we need to set a
        # consistent seed for the Python hash algorithm. Unfortunately, there
        # is no way to specify environment variables only for the Pool
        # processes, so we have to rely on environment variables being
        # inherited.
        if os.environ.get("PYTHONHASHSEED") in (None, "0"):
            # This number is arbitrary; it was chosen to commemorate
            # https://github.com/dask/dask/issues/6640.
            os.environ["PYTHONHASHSEED"] = "6640"
        context = get_context()
        initializer = partial(initialize_worker_process, user_initializer=initializer)
        pool = ProcessPoolExecutor(
            num_workers, mp_context=context, initializer=initializer
        )
        cleanup = True
    else:
        if initializer is not None:
            warn(
                "The ``initializer`` argument is ignored when ``pool`` is provided. "
                "The user should configure ``pool`` with the needed ``initializer`` "
                "on creation."
            )
        if isinstance(pool, multiprocessing.pool.Pool):
            pool = MultiprocessingPoolExecutor(pool)
        cleanup = False

    # Optimize Dask
    dsk = ensure_dict(dsk)
    dsk2, dependencies = cull(dsk, keys)
    if optimize_graph:
        dsk3, dependencies = fuse(dsk2, keys, dependencies)
    else:
        dsk3 = dsk2

    # We specify marshalling functions in order to catch serialization
    # errors and report them to the user.
    loads = func_loads or config.get("func_loads", None) or _loads
    dumps = func_dumps or config.get("func_dumps", None) or _dumps

    # Note former versions used a multiprocessing Manager to share
    # a Queue between parent and workers, but this is fragile on Windows
    # (issue #1652).
    try:
        # Run
        result = get_async(
            pool.submit,
            pool._max_workers,
            dsk3,
            keys,
            get_id=_process_get_id,
            dumps=dumps,
            loads=loads,
            pack_exception=pack_exception,
            raise_exception=reraise,
            chunksize=chunksize,
            **kwargs,
        )
    finally:
        if cleanup:
            pool.shutdown()
    return result

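# Sketch: a small graph computed with this process-based scheduler. Functions
# travel to the workers via cloudpickle (`_dumps`/`_loads` above); on
# platforms that spawn worker processes this should run under an
# `if __name__ == "__main__":` guard. The graph is a hypothetical
# illustration.
import operator

if __name__ == "__main__":
    dsk_example = {"x": 1, "y": 2, "z": (operator.add, "x", "y")}
    print(get(dsk_example, "z", num_workers=2))  # -> 3
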
def ray_dask_get_sync(dsk, keys, **kwargs):
    """
    A synchronous Dask-Ray scheduler. This scheduler will send top-level
    (non-inlined) Dask tasks to a Ray cluster for execution. The scheduler
    will wait for the tasks to finish executing, fetch the results, and
    repackage them into the appropriate Dask collections. This particular
    scheduler submits Ray tasks synchronously, which can be useful for
    debugging.

    This can be passed directly to `dask.compute()`, as the scheduler:

    >>> dask.compute(obj, scheduler=ray_dask_get_sync)

    You can override the currently active global Dask-Ray callbacks (e.g.
    supplied via a context manager):

    >>> dask.compute(
            obj,
            scheduler=ray_dask_get_sync,
            ray_callbacks=some_ray_dask_callbacks,
        )

    Args:
        dsk (Dict): Dask graph, represented as a task DAG dictionary.
        keys (List[str]): List of Dask graph keys whose values we wish to
            compute and return.

    Returns:
        Computed values corresponding to the provided keys.
    """
    ray_callbacks = kwargs.pop("ray_callbacks", None)

    with local_ray_callbacks(ray_callbacks) as ray_callbacks:
        # Unpack the Ray-specific callbacks.
        (
            ray_presubmit_cbs,
            ray_postsubmit_cbs,
            ray_pretask_cbs,
            ray_posttask_cbs,
            ray_postsubmit_all_cbs,
            ray_finish_cbs,
        ) = unpack_ray_callbacks(ray_callbacks)
        # NOTE: We hijack Dask's `get_async` function, injecting a different
        # task executor.
        object_refs = get_async(
            _apply_async_wrapper(
                apply_sync,
                _rayify_task_wrapper,
                ray_presubmit_cbs,
                ray_postsubmit_cbs,
                ray_pretask_cbs,
                ray_posttask_cbs,
            ),
            1,
            dsk,
            keys,
            **kwargs,
        )
        if ray_postsubmit_all_cbs is not None:
            for cb in ray_postsubmit_all_cbs:
                cb(object_refs, dsk)
        # NOTE: We explicitly delete the Dask graph here so object references
        # are garbage-collected before this function returns, i.e. before all
        # Ray tasks are done. Otherwise, no intermediate objects will be
        # cleaned up until all Ray tasks are done.
        del dsk
        result = ray_get_unpack(object_refs)
        if ray_finish_cbs is not None:
            for cb in ray_finish_cbs:
                cb(result)

    return result

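# Sketch of supplying Dask-Ray callbacks, assuming the callback helper
# consumed by `unpack_ray_callbacks` is the one exposed as
# `ray.util.dask.RayDaskCallback` in Ray's Dask-on-Ray integration. `inc` is
# a hypothetical stand-in task, and `ray.init()` is assumed to have run.
from timeit import default_timer as timer

import dask
from ray.util.dask import RayDaskCallback

class TimerCallback(RayDaskCallback):
    def _ray_pretask(self, key, object_refs):
        # Runs on the Ray worker before the task body executes; the return
        # value is handed to _ray_posttask as `pre_state`.
        return timer()

    def _ray_posttask(self, key, result, pre_state):
        # Runs on the Ray worker after the task body finishes.
        print(f"task {key} took {timer() - pre_state:.4f}s")

@dask.delayed
def inc(x):
    return x + 1

(result,) = dask.compute(
    inc(1),
    scheduler=ray_dask_get_sync,
    ray_callbacks=[TimerCallback()],
)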