Example #1
def inline_functions(dsk,
                     output,
                     fast_functions=None,
                     inline_constants=False,
                     dependencies=None):
    """Inline cheap functions into larger operations

    Examples
    --------
    >>> double = lambda x: x*2  # doctest: +SKIP
    >>> dsk = {'out': (add, 'i', 'd'),  # doctest: +SKIP
    ...        'i': (inc, 'x'),
    ...        'd': (double, 'y'),
    ...        'x': 1, 'y': 1}
    >>> inline_functions(dsk, [], [inc])  # doctest: +SKIP
    {'out': (add, (inc, 'x'), 'd'),
     'd': (double, 'y'),
     'x': 1, 'y': 1}

    Protect output keys.  In the example below ``i`` is not inlined because it
    is marked as an output key.

    >>> inline_functions(dsk, ['i', 'out'], [inc, double])  # doctest: +SKIP
    {'out': (add, 'i', (double, 'y')),
     'i': (inc, 'x'),
     'x': 1, 'y': 1}
    """
    if not fast_functions:
        return dsk

    output = set(output)

    fast_functions = set(fast_functions)

    if dependencies is None:
        dependencies = {k: get_dependencies(dsk, k) for k in dsk}
    dependents = reverse_dict(dependencies)

    def inlinable(v):
        try:
            return functions_of(v).issubset(fast_functions)
        except TypeError:
            return False

    keys = [
        k for k, v in dsk.items()
        if istask(v) and dependents[k] and k not in output and inlinable(v)
    ]

    if keys:
        dsk = inline(dsk,
                     keys,
                     inline_constants=inline_constants,
                     dependencies=dependencies)
        for k in keys:
            del dsk[k]
    return dsk
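A runnable version of the docstring example above. This assumes the listing is the
``inline_functions`` from ``dask.optimization``; ``inc`` and ``double`` are defined
locally for illustration.

from operator import add

from dask.optimization import inline_functions

def inc(x):
    return x + 1

def double(x):
    return x * 2

dsk = {'out': (add, 'i', 'd'),
       'i': (inc, 'x'),
       'd': (double, 'y'),
       'x': 1, 'y': 1}

# Inline only `inc`: the cheap 'i' task is folded into 'out' and removed.
result = inline_functions(dsk, [], fast_functions=[inc])
assert result == {'out': (add, (inc, 'x'), 'd'),
                  'd': (double, 'y'),
                  'x': 1, 'y': 1}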
Example #2
def test_get_dependencies_many():
    dsk = {
        "a": [1, 2, 3],
        "b": "a",
        "c": [1, (inc, 1)],
        "d": [(sum, "c")],
        "e": ["a", "b", "zzz"],
        "f": [["a", "b"], 2, 3],
    }

    tasks = [dsk[k] for k in ("d", "f")]
    s = get_dependencies(dsk, task=tasks)
    assert s == {"a", "b", "c"}
    s = get_dependencies(dsk, task=tasks, as_list=True)
    assert sorted(s) == ["a", "b", "c"]

    s = get_dependencies(dsk, task=[])
    assert s == set()
    s = get_dependencies(dsk, task=[], as_list=True)
    assert s == []
Example #3
def inline(dsk, keys=None, inline_constants=True, dependencies=None):
    """Return new dask with the given keys inlined with their values.

    Inlines all constants if the ``inline_constants`` keyword is True. Note that
    the constant keys will remain in the graph; to remove them, follow
    ``inline`` with ``cull``.

    Examples
    --------
    >>> d = {'x': 1, 'y': (inc, 'x'), 'z': (add, 'x', 'y')}
    >>> inline(d)  # doctest: +SKIP
    {'x': 1, 'y': (inc, 1), 'z': (add, 1, 'y')}
    >>> inline(d, keys='y')  # doctest: +SKIP
    {'x': 1, 'y': (inc, 1), 'z': (add, 1, (inc, 1))}
    >>> inline(d, keys='y', inline_constants=False)  # doctest: +SKIP
    {'x': 1, 'y': (inc, 1), 'z': (add, 'x', (inc, 'x'))}
    """
    if dependencies and isinstance(next(iter(dependencies.values())), list):
        dependencies = {k: set(v) for k, v in dependencies.items()}

    keys = _flat_set(keys)

    if dependencies is None:
        dependencies = {k: get_dependencies(dsk, k) for k in dsk}

    if inline_constants:
        keys.update(k for k, v in dsk.items()
                    if (ishashable(v) and v in dsk) or (
                        not dependencies[k] and not istask(v)))

    # Keys may depend on other keys, so determine replace order with toposort.
    # The values stored in `keysubs` do not include other keys.
    replaceorder = toposort(dict((k, dsk[k]) for k in keys if k in dsk),
                            dependencies=dependencies)
    keysubs = {}
    for key in replaceorder:
        val = dsk[key]
        for dep in keys & dependencies[key]:
            if dep in keysubs:
                replace = keysubs[dep]
            else:
                replace = dsk[dep]
            val = subs(val, dep, replace)
        keysubs[key] = val

    # Make new dask with substitutions
    dsk2 = keysubs.copy()
    for key, val in dsk.items():
        if key not in dsk2:
            for item in keys & dependencies[key]:
                val = subs(val, item, keysubs[item])
            dsk2[key] = val
    return dsk2
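A runnable check of the docstring's second example, under the same
``dask.optimization`` import assumption. ``inline_constants`` defaults to True,
so the constant 'x' is inlined alongside the requested key 'y'.

from operator import add

from dask.optimization import inline

def inc(x):
    return x + 1

d = {'x': 1, 'y': (inc, 'x'), 'z': (add, 'x', 'y')}

# 'y' is inlined into 'z' but remains in the graph; follow with `cull` to drop it.
assert inline(d, keys='y') == {'x': 1, 'y': (inc, 1), 'z': (add, 1, (inc, 1))}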
Example #4
def cull(dsk, keys):
    """Return new task graph with only the tasks required to calculate keys.

    In other words, remove unnecessary tasks from task graph.
    ``keys`` may be a single key or list of keys.

    Examples
    --------
    >>> d = {'x': 1, 'y': (inc, 'x'), 'out': (add, 'x', 10)}
    >>> dsk, dependencies = cull(d, 'out')  # doctest: +SKIP
    >>> dsk  # doctest: +SKIP
    {'x': 1, 'out': (add, 'x', 10)}
    >>> dependencies  # doctest: +SKIP
    {'out': ['x'], 'x': []}

    Returns
    -------
    dsk: culled graph
    dependencies: Dict mapping {key: [deps]}.  Useful side effect to accelerate
        other optimizations, notably fuse.
    """
    if not isinstance(keys, (list, set)):
        keys = [keys]

    seen = set()
    dependencies = dict()
    out = {}
    work = list(set(flatten(keys)))

    while work:
        new_work = []
        for k in work:
            dependencies_k = get_dependencies(dsk, k,
                                              as_list=True)  # fuse needs lists
            out[k] = dsk[k]
            dependencies[k] = dependencies_k
            for d in dependencies_k:
                if d not in seen:
                    seen.add(d)
                    new_work.append(d)

        work = new_work

    return out, dependencies
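A concrete run of the docstring example, again assuming the ``dask.optimization``
import. Note that the returned dependency values are lists (``as_list=True``
above), which is the form ``fuse`` expects.

from operator import add

from dask.optimization import cull

def inc(x):
    return x + 1

d = {'x': 1, 'y': (inc, 'x'), 'out': (add, 'x', 10)}
dsk, dependencies = cull(d, 'out')

assert dsk == {'x': 1, 'out': (add, 'x', 10)}   # 'y' is not needed for 'out'
assert dependencies == {'out': ['x'], 'x': []}  # lists, ready for `fuse`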
Example #5
def test_get_dependencies_task_none():
    # Regression test for https://github.com/dask/distributed/issues/2756
    dsk = {"foo": None}
    assert get_dependencies(dsk, task=dsk["foo"]) == set()
Example #6
def test_get_dependencies_nothing():
    with pytest.raises(ValueError):
        get_dependencies({})
Example #7
def test_get_dependencies_task():
    dsk = {"x": 1, "y": 2, "z": ["x", [(inc, "y")]]}
    assert get_dependencies(dsk, task=(inc, "x")) == set(["x"])
    assert get_dependencies(dsk, task=(inc, "x"), as_list=True) == ["x"]
Example #8
def test_get_dependencies_list():
    dsk = {"x": 1, "y": 2, "z": ["x", [(inc, "y")]]}
    assert get_dependencies(dsk, "z") == set(["x", "y"])
    assert sorted(get_dependencies(dsk, "z", as_list=True)) == ["x", "y"]
Example #9
def test_get_dependencies_empty():
    dsk = {"x": (inc, )}
    assert get_dependencies(dsk, "x") == set()
    assert get_dependencies(dsk, "x", as_list=True) == []
Example #10
def test_get_dependencies_nested():
    dsk = {"x": 1, "y": 2, "z": (add, (inc, [["x"]]), "y")}

    assert get_dependencies(dsk, "z") == set(["x", "y"])
    assert sorted(get_dependencies(dsk, "z", as_list=True)) == ["x", "y"]
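Taken together, these tests pin down the contract of ``get_dependencies``: starting
from a key's value (or an explicit ``task=``), walk the arguments, descending into
task tuples and nested lists, and collect whichever arguments are themselves keys
of ``dsk``. Below is a minimal reimplementation sketch that satisfies the tests
above; it is for illustration only and is not dask's actual (more careful and much
faster) implementation.

_no_default = object()  # sentinel, since `task=None` is a legitimate constant task

def get_dependencies_sketch(dsk, key=None, task=_no_default, as_list=False):
    if key is not None:
        arg = dsk[key]
    elif task is not _no_default:
        arg = task
    else:
        raise ValueError("Provide either `key` or `task`")

    deps = []
    work = [arg]
    while work:
        w = work.pop()
        if type(w) is tuple and w and callable(w[0]):  # a task: skip the function
            work.extend(w[1:])
        elif isinstance(w, list):
            work.extend(w)
        else:
            try:
                if w in dsk and w not in deps:
                    deps.append(w)
            except TypeError:  # unhashable argument; cannot be a key
                pass
    return deps if as_list else set(deps)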
Example #11
def order(dsk, dependencies=None):
    """Order nodes in the task graph

    This produces an ordering over our tasks that we use to break ties when
    executing.  We do this ahead of time to reduce a bit of stress on the
    scheduler and also to assist in static analysis.

    This currently traverses the graph as a single-threaded scheduler would
    traverse it.  It breaks ties in the following ways:

    1.  Begin at a leaf node that is a dependency of a root node that has the
        largest subgraph (start hard things first)
    2.  Prefer tall branches with few dependents (start hard things first and
        try to avoid memory usage)
    3.  Prefer dependents that are dependencies of root nodes that have
        the smallest subgraph (do small goals that can terminate quickly)

    Examples
    --------
    >>> dsk = {'a': 1, 'b': 2, 'c': (inc, 'a'), 'd': (add, 'b', 'c')}
    >>> order(dsk)
    {'a': 0, 'c': 1, 'b': 2, 'd': 3}
    """
    if not dsk:
        return {}

    if dependencies is None:
        dependencies = {k: get_dependencies(dsk, k) for k in dsk}

    dependents = reverse_dict(dependencies)
    num_needed, total_dependencies = ndependencies(dependencies, dependents)
    metrics = graph_metrics(dependencies, dependents, total_dependencies)
    if len(metrics) != len(dsk):
        cycle = getcycle(dsk, None)
        raise RuntimeError(
            "Cycle detected between the following keys:\n  -> %s" % "\n  -> ".join(str(x) for x in cycle)
        )

    # Leaf nodes.  We choose one--the initial node--for each weakly connected subgraph.
    # Let's calculate the `initial_stack_key` as we determine the `init_stack` set.
    init_stack = {
        # First prioritize large, tall groups, then prioritize the same as ``dependents_key``.
        key: (
            # at a high-level, work towards a large goal (and prefer tall and narrow)
            -max_dependencies,
            num_dependents - max_heights,
            # tactically, finish small connected jobs first
            min_dependencies,
            num_dependents - min_heights,  # prefer tall and narrow
            -total_dependents,  # take a big step
            # try to be memory efficient
            num_dependents,
            # tie-breaker
            StrComparable(key),
        )
        for key, num_dependents, (
            total_dependents,
            min_dependencies,
            max_dependencies,
            min_heights,
            max_heights,
        ) in ((key, len(dependents[key]), metrics[key]) for key, val in dependencies.items() if not val)
    }
    # `initial_stack_key` chooses which task to run at the very beginning.
    # This value is static, so we pre-compute it as the value of this dict.
    initial_stack_key = init_stack.__getitem__

    def dependents_key(x):
        """Choose a path from our starting task to our tactical goal

        This path is connected to a large goal, but focuses on completing
        a small goal and being memory efficient.
        """
        return (
            # Focus on being memory-efficient
            len(dependents[x]) - len(dependencies[x]) + num_needed[x],
            -metrics[x][3],  # min_heights
            # tie-breaker
            StrComparable(x),
        )

    def dependencies_key(x):
        """Choose which dependency to run as part of a reverse DFS

        This is very similar to ``initial_stack_key``.
        """
        num_dependents = len(dependents[x])
        (
            total_dependents,
            min_dependencies,
            max_dependencies,
            min_heights,
            max_heights,
        ) = metrics[x]
        # Prefer short and narrow instead of tall and narrow, because we're going in
        # reverse along dependencies.
        return (
            # at a high-level, work towards a large goal (and prefer short and narrow)
            -max_dependencies,
            num_dependents + max_heights,
            # tactically, finish small connected jobs first
            min_dependencies,
            num_dependents + min_heights,  # prefer short and narrow
            -total_dependencies[x],  # go where the work is
            # try to be memory efficient
            num_dependents - len(dependencies[x]) + num_needed[x],
            num_dependents,
            total_dependents,  # already found work, so don't add more
            # tie-breaker
            StrComparable(x),
        )

    def finish_now_key(x):
        """ Determine the order of dependents that are ready to run and be released"""
        return (-len(dependencies[x]), StrComparable(x))

    # Computing this for all keys can sometimes be relatively expensive :(
    partition_keys = {
        key: ((min_dependencies - total_dependencies[key] + 1) * (total_dependents - min_heights))
        for key, (
            total_dependents,
            min_dependencies,
            _,
            min_heights,
            _,
        ) in metrics.items()
    }

    result = {}
    i = 0

    # `inner_stack` is used to perform a DFS along dependencies.  Once emptied
    # (when traversing dependencies), we continue down a path along dependents
    # until a root node is reached.
    #
    # Sometimes, a better path along a dependent is discovered (i.e., something
    # that is easier to compute and doesn't require holding too much in memory).
    # In this case, the current `inner_stack` is appended to `inner_stacks` and
    # we begin a new DFS from the better node.
    #
    # A "better path" is determined by comparing `partition_keys`.
    inner_stacks = [[min(init_stack, key=initial_stack_key)]]
    inner_stacks_append = inner_stacks.append
    inner_stacks_extend = inner_stacks.extend
    inner_stacks_pop = inner_stacks.pop

    # Okay, now we get to the data structures used for fancy behavior.
    #
    # As we traverse nodes in the DFS along dependencies, we partition the dependents
    # via `partition_key`.  A dependent goes to:
    #    1) `inner_stack` if it's better than our current target,
    #    2) `next_nodes` if the partition key is lower than its parent's,
    #    3) `later_nodes` otherwise.
    # When the inner stacks are depleted, we process `next_nodes`.  If `next_nodes` is
    # empty (and `outer_stack` is empty), then we process `later_nodes` the same way.
    # These dicts use `partition_keys` as keys.  We process them by placing the values
    # in `outer_stack` so that the smallest keys will be processed first.
    next_nodes = defaultdict(list)
    later_nodes = defaultdict(list)

    # `outer_stack` is used to populate `inner_stacks`.  From the time we partition the
    # dependents of a node, we group them: one list per partition key per parent node.
    # This likely results in many small lists.  We do this to avoid sorting many larger
    # lists (i.e., to avoid n*log(n) behavior).  So, we have many small lists that we
    # partitioned, and we keep them in the order that we saw them (we will process them
    # in a FIFO manner).  By delaying sorting for as long as we can, we can first filter
    # out nodes that have already been computed.  All this complexity is worth it!
    outer_stack = []
    outer_stack_extend = outer_stack.extend
    outer_stack_pop = outer_stack.pop

    # Keep track of nodes that are in `inner_stack` or `inner_stacks` so we don't
    # process them again.
    seen = set()  # seen in an inner_stack (and has dependencies)
    seen_update = seen.update
    seen_add = seen.add

    # alias for speed
    set_difference = set.difference

    is_init_sorted = False
    while True:
        while inner_stacks:
            inner_stack = inner_stacks_pop()
            inner_stack_pop = inner_stack.pop
            while inner_stack:
                # Perform a DFS along dependencies until we complete our tactical goal
                item = inner_stack_pop()
                if item in result:
                    continue
                if num_needed[item]:
                    inner_stack.append(item)
                    deps = set_difference(dependencies[item], result)
                    if 1 < len(deps) < 1000:
                        inner_stack.extend(sorted(deps, key=dependencies_key, reverse=True))
                    else:
                        inner_stack.extend(deps)
                    seen_update(deps)
                    continue

                result[item] = i
                i += 1
                deps = dependents[item]

                # If inner_stack is empty, then we typically add the best dependent to it.
                # However, we don't add to it if we complete a node early via "finish_now" below
                # or if a dependent is already on an inner_stack.  In this case, we add the
                # dependents (not in an inner_stack) to next_nodes or later_nodes to handle later.
                # This serves three purposes:
                #   1. shrink `deps` so that it can be processed faster,
                #   2. make sure we don't process the same dependency repeatedly, and
                #   3. make sure we don't accidentally continue down an expensive-to-compute path.
                add_to_inner_stack = True
                if metrics[item][3] == 1:  # min_height
                    # Don't leave any dangling single nodes!  Finish all dependents that are
                    # ready and are also root nodes.
                    finish_now = {dep for dep in deps if not dependents[dep] and num_needed[dep] == 1}
                    if finish_now:
                        deps -= finish_now  # Safe to mutate
                        if len(finish_now) > 1:
                            finish_now = sorted(finish_now, key=finish_now_key)
                        for dep in finish_now:
                            result[dep] = i
                            i += 1
                        add_to_inner_stack = False

                if deps:
                    for dep in deps:
                        num_needed[dep] -= 1

                    already_seen = deps & seen
                    if already_seen:
                        if len(deps) == len(already_seen):
                            continue
                        add_to_inner_stack = False
                        deps -= already_seen

                    if len(deps) == 1:
                        # Fast path!  We trim down `deps` above hoping to reach here.
                        (dep, ) = deps
                        if not inner_stack:
                            if add_to_inner_stack:
                                inner_stack = [dep]
                                inner_stack_pop = inner_stack.pop
                                seen_add(dep)
                                continue
                            key = partition_keys[dep]
                        else:
                            key = partition_keys[dep]
                            if key < partition_keys[inner_stack[0]]:
                                # Run before `inner_stack` (change tactical goal!)
                                inner_stacks_append(inner_stack)
                                inner_stack = [dep]
                                inner_stack_pop = inner_stack.pop
                                seen_add(dep)
                                continue
                        if key < partition_keys[item]:
                            next_nodes[key].append(deps)
                        else:
                            later_nodes[key].append(deps)
                    else:
                        # Slow path :(.  This requires grouping by partition_key.
                        dep_pools = defaultdict(list)
                        for dep in deps:
                            dep_pools[partition_keys[dep]].append(dep)
                        item_key = partition_keys[item]
                        if inner_stack:
                            # If we have an inner_stack, we need to look for a "better" path
                            prev_key = partition_keys[inner_stack[0]]
                            now_keys = []  # < inner_stack[0]
                            for key, vals in dep_pools.items():
                                if key < prev_key:
                                    now_keys.append(key)
                                elif key < item_key:
                                    next_nodes[key].append(vals)
                                else:
                                    later_nodes[key].append(vals)
                            if now_keys:
                                # Run before `inner_stack` (change tactical goal!)
                                inner_stacks_append(inner_stack)
                                if 1 < len(now_keys):
                                    now_keys.sort(reverse=True)
                                for key in now_keys:
                                    pool = dep_pools[key]
                                    if 1 < len(pool) < 100:
                                        pool.sort(key=dependents_key, reverse=True)
                                    inner_stacks_extend([dep] for dep in pool)
                                    seen_update(pool)
                                inner_stack = inner_stacks_pop()
                                inner_stack_pop = inner_stack.pop
                        else:
                            # If we don't have an inner_stack, then we don't need to look
                            # for a "better" path, but we do need to traverse along dependents.
                            if add_to_inner_stack:
                                min_key = min(dep_pools)
                                min_pool = dep_pools.pop(min_key)
                                if len(min_pool) == 1:
                                    inner_stack = min_pool
                                    seen_update(inner_stack)
                                elif (10 * item_key > 11 * len(min_pool) * len(min_pool) * min_key):
                                    # Put all items in min_pool onto inner_stacks.
                                    # I know this is a weird comparison.  Hear me out.
                                    # Although it is often beneficial to put all of the items in `min_pool`
                                    # onto `inner_stacks` to process next, it is very easy to be overzealous.
                                    # Sometimes it is actually better to defer until `next_nodes` is handled.
                                    # We should only put items onto `inner_stacks` that we're reasonably
                                    # confident about.  The above formula is a best effort heuristic given
                                    # what we have easily available.  It is obviously very specific to our
                                    # choice of partition_key.  Dask tests take this route about 40% of the time.
                                    if len(min_pool) < 100:
                                        min_pool.sort(key=dependents_key, reverse=True)
                                    inner_stacks_extend([dep] for dep in min_pool)
                                    inner_stack = inner_stacks_pop()
                                    seen_update(min_pool)
                                else:
                                    # Put one item in min_pool onto inner_stack and the rest into next_nodes.
                                    if len(min_pool) < 100:
                                        inner_stack = [min(min_pool, key=dependents_key)]
                                    else:
                                        inner_stack = [min_pool.pop()]
                                    next_nodes[min_key].append(min_pool)
                                    seen_update(inner_stack)

                                inner_stack_pop = inner_stack.pop
                            for key, vals in dep_pools.items():
                                if key < item_key:
                                    next_nodes[key].append(vals)
                                else:
                                    later_nodes[key].append(vals)

        if len(dependencies) == len(result):
            break  # all done!

        if next_nodes:
            for key in sorted(next_nodes, reverse=True):
                # `outer_stack` may not be empty here--it has data from previous `next_nodes`.
                # Since we pop things off of it (onto `inner_stacks`), this means we handle
                # multiple `next_nodes` in a LIFO manner.
                outer_stack_extend(reversed(next_nodes[key]))
            next_nodes = defaultdict(list)

        while outer_stack:
            # Try to add a few items to `inner_stacks`
            deps = [x for x in outer_stack_pop() if x not in result]
            if deps:
                if 1 < len(deps) < 100:
                    deps.sort(key=dependents_key, reverse=True)
                inner_stacks_extend([dep] for dep in deps)
                seen_update(deps)
                break

        if inner_stacks:
            continue

        if later_nodes:
            # You know all those dependents with large keys we've been hanging onto to run "later"?
            # Well, "later" has finally come.
            next_nodes, later_nodes = later_nodes, next_nodes
            continue

        # We just finished computing a connected group.
        # Let's choose the first `item` in the next group to compute.
        # If we have few large groups left, then it's best to find `item` by taking a minimum.
        # If we have many small groups left, then it's best to sort.
        # If we have many tiny groups left, then it's best to simply iterate.
        if not is_init_sorted:
            prev_len = len(init_stack)
            if type(init_stack) is dict:
                init_stack = set(init_stack)
            init_stack = set_difference(init_stack, result)
            N = len(init_stack)
            m = prev_len - N
            # is `min` likely better than `sort`?
            if m >= N or N + (N - m) * log(N - m) < N * log(N):
                item = min(init_stack, key=initial_stack_key)
                inner_stacks_append([item])
                continue

            if len(init_stack) < 10000:
                init_stack = sorted(init_stack, key=initial_stack_key, reverse=True)
            else:
                init_stack = list(init_stack)
            init_stack_pop = init_stack.pop
            is_init_sorted = True

        item = init_stack_pop()
        while item in result:
            item = init_stack_pop()
        inner_stacks_append([item])

    return result
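A quick sanity check of the docstring example, assuming this listing is
``dask.order.order``.

from operator import add

from dask.order import order

def inc(x):
    return x + 1

dsk = {'a': 1, 'b': 2, 'c': (inc, 'a'), 'd': (add, 'b', 'c')}
priorities = order(dsk)

# Every key receives a distinct priority in range(len(dsk)) ...
assert sorted(priorities.values()) == [0, 1, 2, 3]
# ... and every task is ordered after the tasks it depends on.
assert priorities['a'] < priorities['c'] < priorities['d']
assert priorities['b'] < priorities['d']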
Example #12
def with_deps(dsk):
    return dsk, {k: get_dependencies(dsk, k) for k in dsk}
Example #13
def fuse_linear(dsk, keys=None, dependencies=None, rename_keys=True):
    """Return new dask graph with linear sequence of tasks fused together.

    If specified, the keys in ``keys`` keyword argument are *not* fused.
    Supply ``dependencies`` from output of ``cull`` if available to avoid
    recomputing dependencies.

    **This function is mostly superseded by ``fuse``**

    Parameters
    ----------
    dsk: dict
    keys: list
    dependencies: dict, optional
        {key: [list-of-keys]}.  Must be a list to provide a count of each key.
        This optional input often comes from ``cull``
    rename_keys: bool or func, optional
        Whether to rename fused keys with ``default_fused_linear_keys_renamer``
        or not.  Renaming fused keys can keep the graph more understandable
        and comprehensible, but it comes at the cost of additional processing.
        If False, then the top-most key will be used.  For advanced usage, a
        func is also accepted, ``new_key = rename_keys(fused_key_list)``.

    Examples
    --------
    >>> d = {'a': 1, 'b': (inc, 'a'), 'c': (inc, 'b')}
    >>> dsk, dependencies = fuse_linear(d)
    >>> dsk # doctest: +SKIP
    {'a-b-c': (inc, (inc, 1)), 'c': 'a-b-c'}
    >>> dsk, dependencies = fuse_linear(d, rename_keys=False)
    >>> dsk # doctest: +SKIP
    {'c': (inc, (inc, 1))}
    >>> dsk, dependencies = fuse_linear(d, keys=['b'], rename_keys=False)
    >>> dsk  # doctest: +SKIP
    {'b': (inc, 1), 'c': (inc, 'b')}

    Returns
    -------
    dsk: output graph with keys fused
    dependencies: dict mapping dependencies after fusion.  Useful side effect
        to accelerate other downstream optimizations.
    """
    if keys is not None and not isinstance(keys, set):
        if not isinstance(keys, list):
            keys = [keys]
        keys = set(flatten(keys))

    if dependencies is None:
        dependencies = {k: get_dependencies(dsk, k, as_list=True) for k in dsk}

    # locate all members of linear chains
    child2parent = {}
    unfusible = set()
    for parent in dsk:
        deps = dependencies[parent]
        has_many_children = len(deps) > 1
        for child in deps:
            if keys is not None and child in keys:
                unfusible.add(child)
            elif child in child2parent:
                del child2parent[child]
                unfusible.add(child)
            elif has_many_children:
                unfusible.add(child)
            elif child not in unfusible:
                child2parent[child] = parent

    # construct the chains from ancestor to descendant
    chains = []
    parent2child = dict(map(reversed, child2parent.items()))
    while child2parent:
        child, parent = child2parent.popitem()
        chain = [child, parent]
        while parent in child2parent:
            parent = child2parent.pop(parent)
            del parent2child[parent]
            chain.append(parent)
        chain.reverse()
        while child in parent2child:
            child = parent2child.pop(child)
            del child2parent[child]
            chain.append(child)
        chains.append(chain)

    dependencies = {k: set(v) for k, v in dependencies.items()}

    if rename_keys is True:
        key_renamer = default_fused_linear_keys_renamer
    elif rename_keys is False:
        key_renamer = None
    else:
        key_renamer = rename_keys

    # create a new dask with fused chains
    rv = {}
    fused = set()
    aliases = set()
    is_renamed = False
    for chain in chains:
        if key_renamer is not None:
            new_key = key_renamer(chain)
            is_renamed = new_key is not None and new_key not in dsk and new_key not in rv
        child = chain.pop()
        val = dsk[child]
        while chain:
            parent = chain.pop()
            dependencies[parent].update(dependencies.pop(child))
            dependencies[parent].remove(child)
            val = subs(dsk[parent], child, val)
            fused.add(child)
            child = parent
        fused.add(child)
        if is_renamed:
            rv[new_key] = val
            rv[child] = new_key
            dependencies[new_key] = dependencies[child]
            dependencies[child] = {new_key}
            aliases.add(child)
        else:
            rv[child] = val
    for key, val in dsk.items():
        if key not in fused:
            rv[key] = val
    if aliases:
        for key, deps in dependencies.items():
            for old_key in deps & aliases:
                new_key = rv[old_key]
                deps.remove(old_key)
                deps.add(new_key)
                rv[key] = subs(rv[key], old_key, new_key)
        if keys is not None:
            for key in aliases - keys:
                del rv[key]
                del dependencies[key]
    return rv, dependencies
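A runnable pass over the docstring examples, assuming ``fuse_linear`` is importable
from ``dask.optimization``.

from dask.optimization import fuse_linear

def inc(x):
    return x + 1

d = {'a': 1, 'b': (inc, 'a'), 'c': (inc, 'b')}

# Without renaming, the chain collapses onto its top-most key 'c'.
dsk, dependencies = fuse_linear(d, rename_keys=False)
assert dsk == {'c': (inc, (inc, 1))}

# With the default renamer, the fused task gets a combined name
# and 'c' becomes an alias for it.
dsk, dependencies = fuse_linear(d)
assert dsk == {'a-b-c': (inc, (inc, 1)), 'c': 'a-b-c'}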
Example #14
def fuse(
    dsk,
    keys=None,
    dependencies=None,
    ave_width=None,
    max_width=None,
    max_height=None,
    max_depth_new_edges=None,
    rename_keys=True,
    fuse_subgraphs=False,
):
    """Fuse tasks that form reductions; more advanced than ``fuse_linear``

    This trades parallelism opportunities for faster scheduling by making tasks
    less granular.  It can replace ``fuse_linear`` in optimization passes.

    This optimization applies to all reductions--tasks that have at most one
    dependent--so it may be viewed as fusing "multiple input, single output"
    groups of tasks into a single task.  There are many parameters to fine
    tune the behavior, which are described below.  ``ave_width`` is the
    natural parameter with which to compare parallelism to granularity, so
    it should always be specified.  Reasonable values for other parameters
    will be determined using ``ave_width`` if necessary.

    Parameters
    ----------
    dsk: dict
        dask graph
    keys: list or set, optional
        Keys that must remain in the returned dask graph
    dependencies: dict, optional
        {key: [list-of-keys]}.  Must be a list to provide a count of each key.
        This optional input often comes from ``cull``
    ave_width: float (default 1)
        Upper limit for ``width = num_nodes / height``, a good measure of
        parallelizability.
    max_width: int or None (default None)
        Don't fuse if total width is greater than this. Set to None to
        dynamically adjust to ``1.5 + ave_width * log(ave_width + 1)``
    max_height: int or None (default None)
        Don't fuse more than this many levels. Set to None to dynamically
        adjust to ``1.5 + ave_width * log(ave_width + 1)``.
    max_depth_new_edges: int or None (default None)
        Don't fuse if new dependencies are added after this many levels.
        Set to None to dynamically adjust to ``ave_width * 1.5``
    rename_keys: bool or func, optional (default True)
        Whether to rename the fused keys with ``default_fused_keys_renamer``
        or not.  Renaming fused keys can keep the graph more understandable
        and comprehensible, but it comes at the cost of additional processing.
        If False, then the top-most key will be used.  For advanced usage, a
        function to create the new name is also accepted.
    fuse_subgraphs : bool, optional (default False)
        Whether to fuse multiple tasks into ``SubgraphCallable`` objects.
        Set to None to let the default optimizer of individual dask collections decide.
        If no collection-specific default exists, defaults to False.

    Returns
    -------
    dsk
        output graph with keys fused
    dependencies
        dict mapping dependencies after fusion.  Useful side effect to accelerate other
        downstream optimizations.
    """

    if keys is not None and not isinstance(keys, set):
        if not isinstance(keys, list):
            keys = [keys]
        keys = set(flatten(keys))

    if ave_width is None:
        ave_width = 1
    if max_height is None:
        max_height = 1.5 + (ave_width * math.log(ave_width + 1))
    if max_depth_new_edges is None:
        max_depth_new_edges = ave_width * 1.5
    if max_width is None:
        max_width = 1.5 + ave_width * math.log(ave_width + 1)

    if not ave_width or not max_height:
        return dsk, dependencies

    if rename_keys is True:
        key_renamer = default_fused_keys_renamer
    elif rename_keys is False:
        key_renamer = None
    elif not callable(rename_keys):
        raise TypeError("rename_keys must be a boolean or callable")
    else:
        key_renamer = rename_keys
    rename_keys = key_renamer is not None

    if dependencies is None:
        deps = {k: get_dependencies(dsk, k, as_list=True) for k in dsk}
    else:
        deps = dict(dependencies)

    rdeps = {}
    for k, vals in deps.items():
        for v in vals:
            if v not in rdeps:
                rdeps[v] = [k]
            else:
                rdeps[v].append(k)
        deps[k] = set(vals)

    reducible = {k for k, vals in rdeps.items() if len(vals) == 1}
    if keys:
        reducible -= keys

    for k, v in dsk.items():
        if type(v) is not tuple and not isinstance(v, (numbers.Number, str)):
            reducible.discard(k)

    if not reducible and (not fuse_subgraphs
                          or all(len(set(v)) != 1 for v in rdeps.values())):
        # Quick return if there's nothing to do. Only proceed if there are tasks
        # fusible by the main `fuse`, or by `fuse_subgraphs` if enabled.
        return dsk, deps

    rv = dsk.copy()
    fused_trees = {}
    # These are the stacks we use to store data as we traverse the graph
    info_stack = []
    children_stack = []
    # For speed
    deps_pop = deps.pop
    reducible_add = reducible.add
    reducible_pop = reducible.pop
    reducible_remove = reducible.remove
    fused_trees_pop = fused_trees.pop
    info_stack_append = info_stack.append
    info_stack_pop = info_stack.pop
    children_stack_append = children_stack.append
    children_stack_extend = children_stack.extend
    children_stack_pop = children_stack.pop
    while reducible:
        parent = reducible_pop()
        reducible_add(parent)
        while parent in reducible:
            # Go to the top
            parent = rdeps[parent][0]
        children_stack_append(parent)
        children_stack_extend(reducible & deps[parent])
        while True:
            child = children_stack[-1]
            if child != parent:
                children = reducible & deps[child]
                while children:
                    # Depth-first search
                    children_stack_extend(children)
                    parent = child
                    child = children_stack[-1]
                    children = reducible & deps[child]
                children_stack_pop()
                # This is a leaf node in the reduction region
                # key, task, fused_keys, height, width, number of nodes, fudge, set of edges
                info_stack_append((
                    child,
                    rv[child],
                    [child] if rename_keys else None,
                    1,
                    1,
                    1,
                    0,
                    deps[child] - reducible,
                ))
            else:
                children_stack_pop()
                # Calculate metrics and fuse as appropriate
                deps_parent = deps[parent]
                edges = deps_parent - reducible
                children = deps_parent - edges
                num_children = len(children)

                if num_children == 1:
                    (
                        child_key,
                        child_task,
                        child_keys,
                        height,
                        width,
                        num_nodes,
                        fudge,
                        children_edges,
                    ) = info_stack_pop()
                    num_children_edges = len(children_edges)

                    if fudge > num_children_edges - 1 >= 0:
                        fudge = num_children_edges - 1
                    edges |= children_edges
                    no_new_edges = len(edges) == num_children_edges
                    if not no_new_edges:
                        fudge += 1

                    # Sanity check; don't go too deep if new levels introduce new edge dependencies
                    if ((num_nodes + fudge) / height <= ave_width and
                        (no_new_edges or height < max_depth_new_edges)):
                        # Perform substitutions as we go
                        val = subs(dsk[parent], child_key, child_task)
                        deps_parent.remove(child_key)
                        deps_parent |= deps_pop(child_key)
                        del rv[child_key]
                        reducible_remove(child_key)
                        if rename_keys:
                            child_keys.append(parent)
                            fused_trees[parent] = child_keys
                            fused_trees_pop(child_key, None)

                        if children_stack:
                            if no_new_edges:
                                # Linear fuse
                                info_stack_append((
                                    parent,
                                    val,
                                    child_keys,
                                    height,
                                    width,
                                    num_nodes,
                                    fudge,
                                    edges,
                                ))
                            else:
                                info_stack_append((
                                    parent,
                                    val,
                                    child_keys,
                                    height + 1,
                                    width,
                                    num_nodes + 1,
                                    fudge,
                                    edges,
                                ))
                        else:
                            rv[parent] = val
                            break
                    else:
                        rv[child_key] = child_task
                        reducible_remove(child_key)
                        if children_stack:
                            # Allow the parent to be fused, but only under strict circumstances.
                            # Ensure that linear chains may still be fused.
                            if fudge > int(ave_width - 1):
                                fudge = int(ave_width - 1)
                            # This task *implicitly* depends on `edges`
                            info_stack_append((
                                parent,
                                rv[parent],
                                [parent] if rename_keys else None,
                                1,
                                width,
                                1,
                                fudge,
                                edges,
                            ))
                        else:
                            break
                else:
                    child_keys = []
                    height = 1
                    width = 0
                    num_single_nodes = 0
                    num_nodes = 0
                    fudge = 0
                    children_edges = set()
                    max_num_edges = 0
                    children_info = info_stack[-num_children:]
                    del info_stack[-num_children:]
                    for (
                            cur_key,
                            cur_task,
                            cur_keys,
                            cur_height,
                            cur_width,
                            cur_num_nodes,
                            cur_fudge,
                            cur_edges,
                    ) in children_info:
                        if cur_height == 1:
                            num_single_nodes += 1
                        elif cur_height > height:
                            height = cur_height
                        width += cur_width
                        num_nodes += cur_num_nodes
                        fudge += cur_fudge
                        if len(cur_edges) > max_num_edges:
                            max_num_edges = len(cur_edges)
                        children_edges |= cur_edges
                    # Fudge factor to account for possible parallelism with the boundaries
                    num_children_edges = len(children_edges)
                    fudge += min(num_children - 1,
                                 max(0, num_children_edges - max_num_edges))

                    if fudge > num_children_edges - 1 >= 0:
                        fudge = num_children_edges - 1
                    edges |= children_edges
                    no_new_edges = len(edges) == num_children_edges
                    if not no_new_edges:
                        fudge += 1
                    # Sanity check; don't go too deep if new levels introduce new edge dependencies
                    if ((num_nodes + fudge) / height <= ave_width
                            and num_single_nodes <= ave_width
                            and width <= max_width
                            and height <= max_height  # noqa E129
                            and
                        (no_new_edges
                         or height < max_depth_new_edges)):  # noqa E129
                        # Perform substitutions as we go
                        val = dsk[parent]
                        children_deps = set()
                        for child_info in children_info:
                            cur_child = child_info[0]
                            val = subs(val, cur_child, child_info[1])
                            del rv[cur_child]
                            children_deps |= deps_pop(cur_child)
                            reducible_remove(cur_child)
                            if rename_keys:
                                fused_trees_pop(cur_child, None)
                                child_keys.extend(child_info[2])
                        deps_parent -= children
                        deps_parent |= children_deps

                        if rename_keys:
                            child_keys.append(parent)
                            fused_trees[parent] = child_keys

                        if children_stack:
                            info_stack_append((
                                parent,
                                val,
                                child_keys,
                                height + 1,
                                width,
                                num_nodes + 1,
                                fudge,
                                edges,
                            ))
                        else:
                            rv[parent] = val
                            break
                    else:
                        for child_info in children_info:
                            rv[child_info[0]] = child_info[1]
                            reducible_remove(child_info[0])
                        if children_stack:
                            # Allow the parent to be fused, but only under strict circumstances.
                            # Ensure that linear chains may still be fused.
                            if width > max_width:
                                width = max_width
                            if fudge > int(ave_width - 1):
                                fudge = int(ave_width - 1)
                            # key, task, height, width, number of nodes, fudge, set of edges
                            # This task *implicitly* depends on `edges`
                            info_stack_append((
                                parent,
                                rv[parent],
                                [parent] if rename_keys else None,
                                1,
                                width,
                                1,
                                fudge,
                                edges,
                            ))
                        else:
                            break
                # Traverse upwards
                parent = rdeps[parent][0]

    if fuse_subgraphs:
        _inplace_fuse_subgraphs(rv, keys, deps, fused_trees, rename_keys)

    if key_renamer:
        for root_key, fused_keys in fused_trees.items():
            alias = key_renamer(fused_keys)
            if alias is not None and alias not in rv:
                rv[alias] = rv[root_key]
                rv[root_key] = alias
                deps[alias] = deps[root_key]
                deps[root_key] = {alias}

    return rv, deps
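A small demonstration of the ``ave_width`` trade-off, assuming this listing is
``dask.optimization.fuse``. Two independent chains feed one combine step: with
``ave_width=1`` only the linear pieces fuse, while ``ave_width=2`` also permits
the full width-2 reduction.

from dask.optimization import fuse

def inc(x):
    return x + 1

def add(x, y):
    return x + y

d = {'a1': 1, 'a2': 2,
     'b1': (inc, 'a1'), 'b2': (inc, 'a2'),
     'c': (add, 'b1', 'b2')}

dsk1, _ = fuse(d, ave_width=1, rename_keys=False)
assert dsk1 == {'b1': (inc, 1), 'b2': (inc, 1), 'c': (add, 'b1', 'b2')}

dsk2, _ = fuse(d, ave_width=2, rename_keys=False)
assert dsk2 == {'c': (add, (inc, 1), (inc, 1))}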