Ejemplo n.º 1
0
def ray_get_unpack(object_refs, progress_bar_actor=None):
    """
    Resolve every Ray object reference found in a (possibly nested) object.

    Nested collections are flattened before the blocking fetch and the
    fetched values are then put back into the original structure.

    Args:
        object_refs: A (potentially nested) Python object containing Ray object
            references.
        progress_bar_actor: Optional handle passed to ``render_progress_bar``
            together with the references right before they are fetched. When
            falsy, no progress bar is rendered.

    Returns:
        The input Python object with all contained Ray object references
        resolved with their concrete values.
    """
    def _fetch(refs):
        # Optionally render progress for these refs, then fetch their values.
        if progress_bar_actor:
            render_progress_bar(progress_bar_actor, refs)
        return ray.get(refs)

    # Tuples are handled exactly like lists from here on.
    if isinstance(object_refs, tuple):
        object_refs = list(object_refs)

    # Flattening is only required for a list holding at least one element
    # that is not itself an object reference.
    needs_flattening = isinstance(object_refs, list) and not all(
        isinstance(ref, ray.ObjectRef) for ref in object_refs)

    if not needs_flattening:
        return _fetch(object_refs)

    # Dask loves to nest collections in nested tuples, while Ray expects a
    # flat list of object references: flatten first, fetch, then repack the
    # fetched values.
    flat_refs, repack = unpack_object_refs(*object_refs)
    return repack(_fetch(flat_refs))
Ejemplo n.º 2
0
def _rayify_task(
    task,
    key,
    deps,
    ray_presubmit_cbs,
    ray_postsubmit_cbs,
    ray_pretask_cbs,
    ray_posttask_cbs,
    ray_remote_args,
):
    """
    Turn a Dask graph value into its Ray counterpart, submitting real tasks.

    Args:
        task: A Dask graph value, being either a literal, dependency
            key, Dask task, or a list thereof.
        key: The Dask graph key for the given task.
        deps: The dependencies of this task.
        ray_presubmit_cbs: Pre-task submission callbacks, or None. Each is
            called with ``(task, key, deps)``.
        ray_postsubmit_cbs: Post-task submission callbacks, or None. Each is
            called with ``(task, key, deps, object_refs)``.
        ray_pretask_cbs: Pre-task execution callbacks, forwarded to the
            task wrapper.
        ray_posttask_cbs: Post-task execution callbacks, forwarded to the
            task wrapper.
        ray_remote_args: Ray task options, expanded as keyword arguments to
            ``.options()``.

    Returns:
        A literal, a Ray object reference representing a submitted task, or a
        list thereof.
    """
    # Lists are rayified element by element, all sharing the same key, deps,
    # callbacks and options. The recursion bottoms out at the first actual
    # task encountered; tasks nested in that task's arguments are inlined.
    if isinstance(task, list):
        return [
            _rayify_task(
                element,
                key,
                deps,
                ray_presubmit_cbs,
                ray_postsubmit_cbs,
                ray_pretask_cbs,
                ray_posttask_cbs,
                ray_remote_args,
            )
            for element in task
        ]

    # Not a task: a hashable value that names a dependency is replaced by
    # that dependency's entry; anything else passes through untouched.
    if not istask(task):
        if ishashable(task) and task in deps:
            return deps[task]
        return task

    # Every presubmit callback is invoked. The first non-None value any of
    # them produced is returned in place of submitting a Ray task.
    if ray_presubmit_cbs is not None:
        overrides = [cb(task, key, deps) for cb in ray_presubmit_cbs]
        override = next(
            (value for value in overrides if value is not None), None)
        if override is not None:
            return override

    func = task[0]
    args = task[1:]

    # multiple_return_get tasks are handed to _execute_task instead of being
    # submitted through the wrapper below.
    if func is multiple_return_get:
        return _execute_task(task, deps)

    # Nested object references in the arguments are flattened so that Ray
    # properly tracks the object dependencies between Ray tasks; `repack`
    # travels with the submission so the wrapper can rebuild the arguments.
    arg_object_refs, repack = unpack_object_refs(args, deps)

    task_name = f"dask:{key!s}"
    if isinstance(func, MultipleReturnFunc):
        num_returns = func.num_returns
    else:
        num_returns = 1

    # Submit the task using a wrapper function.
    submit = dask_task_wrapper.options(
        name=task_name,
        num_returns=num_returns,
        **ray_remote_args,
    )
    object_refs = submit.remote(
        func,
        repack,
        key,
        ray_pretask_cbs,
        ray_posttask_cbs,
        *arg_object_refs,
    )

    if ray_postsubmit_cbs is not None:
        for cb in ray_postsubmit_cbs:
            cb(task, key, deps, object_refs)

    return object_refs