def dask_task_wrapper(func, repack, *args):
    """
    A Ray remote function acting as a Dask task wrapper. This function will
    repackage the given flat `args` into its original data structures using
    `repack`, execute any Dask subtasks within the repackaged arguments
    (inlined by Dask's optimization pass), and then pass the concrete task
    arguments to the provided Dask task function, `func`.

    Args:
        func (callable): The Dask task function to execute.
        repack (callable): A function that repackages the provided args into
            the original (possibly nested) Python objects.
        *args (ObjectRef): Ray object references representing the Dask task's
            arguments.

    Returns:
        The output of the Dask task. In the context of Ray, a
        dask_task_wrapper.remote() invocation will return a Ray object
        reference representing the Ray task's result.
    """
    repacked_args, repacked_deps = repack(args)
    # Recursively execute Dask-inlined tasks.
    actual_args = [_execute_task(a, repacked_deps) for a in repacked_args]
    # Execute the actual underlying Dask task.
    return func(*actual_args)
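
# Illustrative sketch, not part of the original module: the wrapper's
# repack/execute logic shown with a direct call to the plain function as
# written above. In the Dask-on-Ray scheduler the same function is decorated
# with @ray.remote and invoked via dask_task_wrapper.remote(). The
# `_trivial_repack` helper is hypothetical: it stands in for the repack
# closure normally produced by `unpack_object_refs` and treats the flat args
# as final positional arguments with no inlined Dask subtasks.
def _example_dask_task_wrapper_logic():
    import operator

    def _trivial_repack(args):
        # No nested structure to restore and no dependency mapping.
        return list(args), {}

    # Equivalent to what a dask_task_wrapper.remote(...) invocation computes.
    assert dask_task_wrapper(operator.add, _trivial_repack, 1, 2) == 3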
def cache_entry(cache, key, task):
    with cache.lock:
        try:
            return cache.cache[key]
        except KeyError:
            cache.cache[key] = value = _execute_task(task, {})
            return value
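
# Illustrative sketch, not part of the original module: the cache object
# `cache_entry` expects exposes a `.lock` and a plain-dict `.cache`. The
# `_Cache` class, key, and value below are hypothetical.
def _example_cache_entry_usage():
    import threading

    class _Cache:
        def __init__(self):
            self.lock = threading.Lock()
            self.cache = {}

    c = _Cache()
    # A literal task is returned as-is by dask's _execute_task and cached
    # under the given key; the second lookup hits the cache.
    assert cache_entry(c, "x", 42) == 42
    assert cache_entry(c, "x", 42) == 42
    assert c.cache["x"] == 42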
def execute_task(key, task_info, dumps, loads, get_id, pack_exception):
    """
    Compute task and handle all administration

    See Also
    --------
    _execute_task : actually execute task
    """
    try:
        task, data = loads(task_info)
        result = _execute_task(task, data)
        id = get_id()
        result = dumps((result, id))
        failed = False
    except BaseException as e:
        result = pack_exception(e, dumps)
        failed = True
    return key, result, failed
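
# Illustrative sketch, not part of the original module: calling execute_task
# directly with identity (de)serialization, a simple Dask task tuple, and a
# fixed worker id. The lambdas stand in for the scheduler-provided dumps,
# loads, get_id, and pack_exception arguments.
def _example_execute_task_roundtrip():
    key, res_info, failed = execute_task(
        "x",
        ((sum, [1, 2, 3]), {}),  # (task, data), already "loaded"
        lambda obj: obj,         # dumps
        lambda obj: obj,         # loads
        lambda: "worker-0",      # get_id
        lambda e, dumps: dumps((e, None)),  # pack_exception
    )
    assert key == "x" and not failed
    assert res_info == (6, "worker-0")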
def dask_task_wrapper(func, repack, key, ray_pretask_cbs, ray_posttask_cbs,
                      *args):
    """
    A Ray remote function acting as a Dask task wrapper. This function will
    repackage the given flat `args` into its original data structures using
    `repack`, execute any Dask subtasks within the repackaged arguments
    (inlined by Dask's optimization pass), and then pass the concrete task
    arguments to the provided Dask task function, `func`.

    Args:
        func (callable): The Dask task function to execute.
        repack (callable): A function that repackages the provided args into
            the original (possibly nested) Python objects.
        key (str): The Dask key for this task.
        ray_pretask_cbs (callable): Pre-task execution callbacks.
        ray_posttask_cbs (callable): Post-task execution callbacks.
        *args (ObjectRef): Ray object references representing the Dask task's
            arguments.

    Returns:
        The output of the Dask task. In the context of Ray, a
        dask_task_wrapper.remote() invocation will return a Ray object
        reference representing the Ray task's result.
    """
    if ray_pretask_cbs is not None:
        pre_states = [
            cb(key, args) if cb is not None else None for cb in ray_pretask_cbs
        ]
    repacked_args, repacked_deps = repack(args)
    # Recursively execute Dask-inlined tasks.
    actual_args = [_execute_task(a, repacked_deps) for a in repacked_args]
    # Execute the actual underlying Dask task.
    result = func(*actual_args)
    if ray_posttask_cbs is not None:
        for cb, pre_state in zip(ray_posttask_cbs, pre_states):
            if cb is not None:
                cb(key, result, pre_state)
    return result
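
# Illustrative sketch, not part of the original module: the callback shapes
# this wrapper expects. A pre-task callback receives (key, args) and whatever
# it returns is handed to the matching post-task callback as `pre_state`.
# The timing callbacks below are hypothetical examples of that contract.
def _example_task_callbacks():
    import time

    def timer_pretask(key, object_refs):
        # Called just before the Dask task body runs inside the Ray task.
        return time.time()

    def timer_posttask(key, result, start_time):
        # Called with the task result and the matching pre-task state.
        print(f"{key} finished in {time.time() - start_time:.3f}s")

    return [timer_pretask], [timer_posttask]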
def get_async(
    submit,
    num_workers,
    dsk,
    result,
    cache=None,
    get_id=default_get_id,
    rerun_exceptions_locally=None,
    pack_exception=default_pack_exception,
    raise_exception=reraise,
    callbacks=None,
    dumps=identity,
    loads=identity,
    chunksize=None,
    **kwargs,
):
    """Asynchronous get function

    This is a general version of various asynchronous schedulers for dask. It
    takes a ``concurrent.futures.Executor.submit`` function to form a more
    specific ``get`` method that walks through the dask array with parallel
    workers, avoiding repeat computation and minimizing memory use.

    Parameters
    ----------
    submit : function
        A ``concurrent.futures.Executor.submit`` function
    num_workers : int
        The number of workers that task submissions can be spread over
    dsk : dict
        A dask dictionary specifying a workflow
    result : key or list of keys
        Keys corresponding to desired data
    cache : dict-like, optional
        Temporary storage of results
    get_id : callable, optional
        Function to return the worker id, takes no arguments. Examples are
        `threading.current_thread` and `multiprocessing.current_process`.
    rerun_exceptions_locally : bool, optional
        Whether to rerun failing tasks in local process to enable debugging
        (False by default)
    pack_exception : callable, optional
        Function to take an exception and ``dumps`` method, and return a
        serialized tuple of ``(exception, traceback)`` to send back to the
        scheduler. Default is to just raise the exception.
    raise_exception : callable, optional
        Function that takes an exception and a traceback, and raises an error.
    callbacks : tuple or list of tuples, optional
        Callbacks are passed in as tuples of length 5. Multiple sets of
        callbacks may be passed in as a list of tuples. For more information,
        see the dask.diagnostics documentation.
    dumps : callable, optional
        Function to serialize task data and results to communicate between
        worker and parent. Defaults to identity.
    loads : callable, optional
        Inverse function of `dumps`. Defaults to identity.
    chunksize : int, optional
        Size of chunks to use when dispatching work. Defaults to 1.
        If -1, will be computed to evenly divide ready work across workers.

    See Also
    --------
    threaded.get
    """
    chunksize = chunksize or config.get("chunksize", 1)

    queue = Queue()

    if isinstance(result, list):
        result_flat = set(flatten(result))
    else:
        result_flat = {result}
    results = set(result_flat)

    dsk = dict(dsk)
    with local_callbacks(callbacks) as callbacks:
        _, _, pretask_cbs, posttask_cbs, _ = unpack_callbacks(callbacks)
        started_cbs = []
        succeeded = False
        # if start_state_from_dask fails, we will have something
        # to pass to the final block.
        state = {}
        try:
            for cb in callbacks:
                if cb[0]:
                    cb[0](dsk)
                started_cbs.append(cb)

            keyorder = order(dsk)

            state = start_state_from_dask(dsk, cache=cache, sortkey=keyorder.get)

            for _, start_state, _, _, _ in callbacks:
                if start_state:
                    start_state(dsk, state)

            if rerun_exceptions_locally is None:
                rerun_exceptions_locally = config.get(
                    "rerun_exceptions_locally", False)

            if state["waiting"] and not state["ready"]:
                raise ValueError("Found no accessible jobs in dask")

            def fire_tasks(chunksize):
                """Fire off a task to the thread pool"""
                # Determine chunksize and/or number of tasks to submit
                nready = len(state["ready"])
                if chunksize == -1:
                    ntasks = nready
                    chunksize = -(ntasks // -num_workers)
                else:
                    used_workers = -(len(state["running"]) // -chunksize)
                    avail_workers = max(num_workers - used_workers, 0)
                    ntasks = min(nready, chunksize * avail_workers)

                # Prep all ready tasks for submission
                args = []
                for _ in range(ntasks):
                    # Get the next task to compute (most recently added)
                    key = state["ready"].pop()
                    # Notify task is running
                    state["running"].add(key)
                    for f in pretask_cbs:
                        f(key, dsk, state)

                    # Prep args to send
                    data = {
                        dep: state["cache"][dep]
                        for dep in get_dependencies(dsk, key)
                    }
                    args.append((
                        key,
                        dumps((dsk[key], data)),
                        dumps,
                        loads,
                        get_id,
                        pack_exception,
                    ))

                # Batch submit
                for i in range(-(len(args) // -chunksize)):
                    each_args = args[i * chunksize:(i + 1) * chunksize]
                    if not each_args:
                        break
                    fut = submit(batch_execute_tasks, each_args)
                    fut.add_done_callback(queue.put)

            # Main loop, wait on tasks to finish, insert new ones
            while state["waiting"] or state["ready"] or state["running"]:
                fire_tasks(chunksize)
                for key, res_info, failed in queue_get(queue).result():
                    if failed:
                        exc, tb = loads(res_info)
                        if rerun_exceptions_locally:
                            data = {
                                dep: state["cache"][dep]
                                for dep in get_dependencies(dsk, key)
                            }
                            task = dsk[key]
                            _execute_task(task, data)  # Re-execute locally
                        else:
                            raise_exception(exc, tb)
                    res, worker_id = loads(res_info)
                    state["cache"][key] = res
                    finish_task(dsk, key, state, results, keyorder.get)
                    for f in posttask_cbs:
                        f(key, res, dsk, state, worker_id)

            succeeded = True

        finally:
            for _, _, _, _, finish in started_cbs:
                if finish:
                    finish(dsk, state, not succeeded)

    return nested_get(result, state["cache"])
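
# Illustrative sketch, not part of the original module: building a toy `get`
# on top of this executor-based get_async, in the same way threaded.get wires
# it up to a concurrent.futures executor. The graph below is hypothetical.
def _example_get_async_with_executor():
    from concurrent.futures import ThreadPoolExecutor
    from operator import add

    dsk = {"x": 1, "y": (add, "x", 10), "z": (add, "y", 100)}

    with ThreadPoolExecutor(max_workers=2) as pool:
        result = get_async(pool.submit, 2, dsk, "z")

    assert result == 111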
def get_async(apply_async, num_workers, dsk, result, cache=None,
              get_id=default_get_id, rerun_exceptions_locally=None,
              pack_exception=default_pack_exception, raise_exception=reraise,
              callbacks=None, dumps=identity, loads=identity, **kwargs):
    """Asynchronous get function

    This is a general version of various asynchronous schedulers for dask. It
    takes an apply_async function as found on Pool objects to form a more
    specific ``get`` method that walks through the dask array with parallel
    workers, avoiding repeat computation and minimizing memory use.

    Parameters
    ----------
    apply_async : function
        Asynchronous apply function as found on Pool or ThreadPool
    num_workers : int
        The number of active tasks we should have at any one time
    dsk : dict
        A dask dictionary specifying a workflow
    result : key or list of keys
        Keys corresponding to desired data
    cache : dict-like, optional
        Temporary storage of results
    get_id : callable, optional
        Function to return the worker id, takes no arguments. Examples are
        `threading.current_thread` and `multiprocessing.current_process`.
    rerun_exceptions_locally : bool, optional
        Whether to rerun failing tasks in local process to enable debugging
        (False by default)
    pack_exception : callable, optional
        Function to take an exception and ``dumps`` method, and return a
        serialized tuple of ``(exception, traceback)`` to send back to the
        scheduler. Default is to just raise the exception.
    raise_exception : callable, optional
        Function that takes an exception and a traceback, and raises an error.
    dumps : callable, optional
        Function to serialize task data and results to communicate between
        worker and parent. Defaults to identity.
    loads : callable, optional
        Inverse function of `dumps`. Defaults to identity.
    callbacks : tuple or list of tuples, optional
        Callbacks are passed in as tuples of length 5. Multiple sets of
        callbacks may be passed in as a list of tuples. For more information,
        see the dask.diagnostics documentation.

    See Also
    --------
    threaded.get
    """
    queue = Queue()

    if isinstance(result, list):
        result_flat = set(flatten(result))
    else:
        result_flat = {result}
    results = set(result_flat)

    dsk = dict(dsk)
    with local_callbacks(callbacks) as callbacks:
        _, _, pretask_cbs, posttask_cbs, _ = unpack_callbacks(callbacks)
        started_cbs = []
        succeeded = False
        # if start_state_from_dask fails, we will have something
        # to pass to the final block.
        state = {}
        try:
            for cb in callbacks:
                if cb[0]:
                    cb[0](dsk)
                started_cbs.append(cb)

            keyorder = order(dsk)

            state = start_state_from_dask(dsk, cache=cache, sortkey=keyorder.get)

            for _, start_state, _, _, _ in callbacks:
                if start_state:
                    start_state(dsk, state)

            if rerun_exceptions_locally is None:
                rerun_exceptions_locally = config.get(
                    "rerun_exceptions_locally", False)

            if state["waiting"] and not state["ready"]:
                raise ValueError("Found no accessible jobs in dask")

            def fire_task():
                """ Fire off a task to the thread pool """
                # Choose a good task to compute
                key = state["ready"].pop()
                state["running"].add(key)
                for f in pretask_cbs:
                    f(key, dsk, state)

                # Prep data to send
                data = {
                    dep: state["cache"][dep]
                    for dep in get_dependencies(dsk, key)
                }
                # Submit
                apply_async(
                    execute_task,
                    args=(
                        key,
                        dumps((dsk[key], data)),
                        dumps,
                        loads,
                        get_id,
                        pack_exception,
                    ),
                    callback=queue.put,
                )

            # Seed initial tasks into the thread pool
            while state["ready"] and len(state["running"]) < num_workers:
                fire_task()

            # Main loop, wait on tasks to finish, insert new ones
            while state["waiting"] or state["ready"] or state["running"]:
                key, res_info, failed = queue_get(queue)
                if failed:
                    exc, tb = loads(res_info)
                    if rerun_exceptions_locally:
                        data = {
                            dep: state["cache"][dep]
                            for dep in get_dependencies(dsk, key)
                        }
                        task = dsk[key]
                        _execute_task(task, data)  # Re-execute locally
                    else:
                        raise_exception(exc, tb)
                res, worker_id = loads(res_info)
                state["cache"][key] = res
                finish_task(dsk, key, state, results, keyorder.get)
                for f in posttask_cbs:
                    f(key, res, dsk, state, worker_id)

                while state["ready"] and len(state["running"]) < num_workers:
                    fire_task()

            succeeded = True

        finally:
            for _, _, _, _, finish in started_cbs:
                if finish:
                    finish(dsk, state, not succeeded)

    return nested_get(result, state["cache"])
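
# Illustrative sketch, not part of the original module: the older
# apply_async-based variant driven by a multiprocessing ThreadPool, mirroring
# how the threaded scheduler used to call it. The graph below is hypothetical.
def _example_get_async_with_pool():
    from multiprocessing.pool import ThreadPool
    from operator import mul

    dsk = {"a": 2, "b": (mul, "a", 3), "c": (mul, "b", 4)}

    pool = ThreadPool(2)
    try:
        result = get_async(pool.apply_async, 2, dsk, "c")
    finally:
        pool.close()
        pool.join()

    assert result == 24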
def _rayify_task(
        task,
        key,
        deps,
        ray_presubmit_cbs,
        ray_postsubmit_cbs,
        ray_pretask_cbs,
        ray_posttask_cbs,
):
    """
    Rayifies the given task, submitting it as a Ray task to the Ray cluster.

    Args:
        task (tuple): A Dask graph value, being either a literal, dependency
            key, Dask task, or a list thereof.
        key (str): The Dask graph key for the given task.
        deps (dict): The dependencies of this task.
        ray_presubmit_cbs (callable): Pre-task submission callbacks.
        ray_postsubmit_cbs (callable): Post-task submission callbacks.
        ray_pretask_cbs (callable): Pre-task execution callbacks.
        ray_posttask_cbs (callable): Post-task execution callbacks.

    Returns:
        A literal, a Ray object reference representing a submitted task, or a
        list thereof.
    """
    if isinstance(task, list):
        # Recursively rayify this list. This will still bottom out at the
        # first actual task encountered, inlining any tasks in that task's
        # arguments.
        return [
            _rayify_task(
                t,
                key,
                deps,
                ray_presubmit_cbs,
                ray_postsubmit_cbs,
                ray_pretask_cbs,
                ray_posttask_cbs,
            ) for t in task
        ]
    elif istask(task):
        # Unpacks and repacks Ray object references and submits the task to
        # the Ray cluster for execution.
        if ray_presubmit_cbs is not None:
            alternate_returns = [
                cb(task, key, deps) for cb in ray_presubmit_cbs
            ]
            for alternate_return in alternate_returns:
                # We don't submit a Ray task if a presubmit callback returns
                # a non-`None` value, instead we return said value.
                # NOTE: This returns the first non-None presubmit callback
                # return value.
                if alternate_return is not None:
                    return alternate_return

        func, args = task[0], task[1:]
        if func is multiple_return_get:
            return _execute_task(task, deps)
        # If the function's arguments contain nested object references, we
        # must unpack said object references into a flat set of arguments so
        # that Ray properly tracks the object dependencies between Ray tasks.
        arg_object_refs, repack = unpack_object_refs(args, deps)
        # Submit the task using a wrapper function.
        object_refs = dask_task_wrapper.options(
            name=f"dask:{key!s}",
            num_returns=(1 if not isinstance(func, MultipleReturnFunc) else
                         func.num_returns),
        ).remote(
            func,
            repack,
            key,
            ray_pretask_cbs,
            ray_posttask_cbs,
            *arg_object_refs,
        )

        if ray_postsubmit_cbs is not None:
            for cb in ray_postsubmit_cbs:
                cb(task, key, deps, object_refs)

        return object_refs
    elif not ishashable(task):
        return task
    elif task in deps:
        return deps[task]
    else:
        return task
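
# Illustrative sketch, not part of the original module: a hypothetical
# presubmit callback of the kind _rayify_task accepts. Returning a non-None
# value short-circuits submission, so that value is used in place of a Ray
# object reference; returning None submits the task to Ray as usual.
def _example_presubmit_callback(task, key, deps):
    # Evaluate tiny single-argument literal tasks inline rather than paying
    # Ray task submission overhead for them.
    if istask(task) and len(task) == 2 and isinstance(task[1], (int, float)):
        return task[0](task[1])
    return None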