Example #1
0
def inline(dsk, keys=None, inline_constants=True, dependencies=None):
    """Return new dask with the given keys inlined with their values.

    Inlines all constants if ``inline_constants`` keyword is True. Note that
    the constant keys will remain in the graph, to remove them follow
    ``inline`` with ``cull``.

    Parameters
    ----------
    dsk : dict
        Task graph mapping keys to values or tasks. Only read; a new dict is
        returned.
    keys : key or collection of keys, optional
        Keys whose values are substituted into the tasks that reference them.
        Normalized to a set with ``_flat_set``.
    inline_constants : bool, optional
        If True, additionally inline every key whose value is itself a key of
        ``dsk`` (an alias), and every key that has no dependencies and whose
        value is not a task (a constant).
    dependencies : dict, optional
        Precomputed ``{key: dependencies}`` mapping for ``dsk``. Values may be
        sets or any other iterable of keys (e.g. lists); non-set values are
        converted to sets. Computed with ``get_dependencies`` when omitted.

    Returns
    -------
    dict
        A new graph with the same keys as ``dsk`` in which every reference to
        an inlined key has been replaced by that key's (fully inlined) value.

    Examples
    --------
    >>> d = {'x': 1, 'y': (inc, 'x'), 'z': (add, 'x', 'y')}
    >>> inline(d)  # doctest: +SKIP
    {'x': 1, 'y': (inc, 1), 'z': (add, 1, 'y')}
    >>> inline(d, keys='y')  # doctest: +SKIP
    {'x': 1, 'y': (inc, 1), 'z': (add, 1, (inc, 1))}
    >>> inline(d, keys='y', inline_constants=False)  # doctest: +SKIP
    {'x': 1, 'y': (inc, 'x'), 'z': (add, 'x', (inc, 'x'))}
    """
    if dependencies is not None and any(
        not isinstance(deps, (set, frozenset)) for deps in dependencies.values()
    ):
        # ``keys & dependencies[key]`` below needs set-like values. Checking
        # every value (not only the first) also accepts tuples and mappings
        # that mix lists with sets, instead of failing later with TypeError.
        dependencies = {k: set(v) for k, v in dependencies.items()}

    keys = _flat_set(keys)

    if dependencies is None:
        dependencies = {k: get_dependencies(dsk, k) for k in dsk}

    if inline_constants:
        # Also inline aliases (value is another key of ``dsk``) and constants
        # (no dependencies and not a task). Build a new set rather than calling
        # ``keys.update(...)``: ``_flat_set`` may hand back the caller's own
        # set, which must not be mutated as a side effect.
        keys = keys | {
            k
            for k, v in dsk.items()
            if (ishashable(v) and v in dsk)
            or (not dependencies[k] and not istask(v))
        }

    # Keys may depend on other keys, so determine replace order with toposort.
    # The values stored in `keysubs` do not include other keys.
    replaceorder = toposort(
        {k: dsk[k] for k in keys if k in dsk}, dependencies=dependencies
    )
    keysubs = {}
    for key in replaceorder:
        val = dsk[key]
        for dep in keys & dependencies[key]:
            # Prefer the already-substituted value so nested inlined keys end
            # up fully expanded; fall back to the raw graph value otherwise.
            replace = keysubs[dep] if dep in keysubs else dsk[dep]
            val = subs(val, dep, replace)
        keysubs[key] = val

    # Make new dask with substitutions
    dsk2 = keysubs.copy()
    for key, val in dsk.items():
        if key not in dsk2:
            for item in keys & dependencies[key]:
                val = subs(val, item, keysubs[item])
            dsk2[key] = val
    return dsk2
Example #2
0
def build_composition(
    endpoint_protocol: 'EndpointProtocol',
    components: Dict[str, 'ModelComponent'],
    connections: List['Connection'],
) -> 'TaskComposition':
    r"""Build the optimized, composed task graph for one endpoint.

    The graph is assembled from the initial per-component task graphs
    (``_process_initial``), the ``connections`` between components, and the
    endpoint's payload/result tasks. Rewrite rules let exposed endpoint inputs
    replace inputs that would otherwise come from a connected component. The
    graph is then culled to what the endpoint outputs need, ``identity``
    pass-through tasks are inlined, the result is culled again, and its
    topological order is computed once up front.

    Parameters
    ----------
    endpoint_protocol
        The endpoint being composed. Passed to ``_process_initial``; its
        ``name`` is passed to ``_verify_no_cycles``.
    components
        Mapping of component name to component, passed to
        ``_process_initial``.
    connections
        Links from ``<source_component>.outputs.<source_key>`` to
        ``<target_component>.inputs.<target_key>``. If several connections
        share the same target, the last one wins.

    Returns
    -------
    TaskComposition
        The final graph together with its precomputed topological order, the
        keys to fetch, the endpoint input/output key maps, and the
        pre-optimization graph.

    Raises
    ------
    RuntimeError
        If a payload task key is not of the form ``<target>.serial``.
        (``_verify_no_cycles`` runs after each cull; see it for the error
        raised when a cycle is detected.)

    Notes
    -----
    Notes on easy places to introduce bugs:

    ::

            Input Data
        --------------------
            a  b  c   d
            |  |  |   | \\
             \ | / \  |  ||
              C_2   C_1  ||
            /  |     | \ //
           /   |    /   *
        RES_2  |   |   // \
               |   |  //   RES_1
                \  | //
                C_2_1
                  |
                RES_3
        ---------------------
              Output Data

    Because there are connections between ``C_1 -> C_2_1`` and
    ``C_2 -> C_2_1`` we can eliminate the ``serialize <-> deserialize``
    tasks for the data transferred between these components. We need to be
    careful not to eliminate the ``serialize`` or ``deserialize`` tasks
    entirely, though. In the case shown above, it is apparent that ``RES_1``
    & ``RES_2`` still need the ``serialize`` function, but the same also
    applies to ``deserialize``. Consider the example below, with the same
    composition & connections as above:

    ::

            Input Data
        --------------------
            a  b  c   d
            |  |  |   | \\
             \ | /| \ |  \\
              C_2 |  C_1  ||
            /  |  |   @\  ||
           /   |  |   @ \ //
        RES_2  |  |  @   *
               |  | @  // \
                \ | @ //   RES_1
                 C_2_1
                  |
                RES_3
        ---------------------
              Output Data

    Though we are using the same composition, the endpoints have been changed
    so that the previous result of ``C_1 -> C_2_1`` is now being provided by
    input ``c``. However, there is still a connection between ``C_1`` and
    ``C_2_1``, which is denoted by the ``@`` symbols. Though the first example
    (shown at the top of these notes) would be able to eliminate the
    ``C_2_1`` ``deserialize`` task from ``C_2`` / ``C_1``, we see here that,
    since endpoints define the path through the DAG, we cannot eliminate them
    entirely either.
    """
    initial_task_dsk = _process_initial(endpoint_protocol, components)

    dsk_tgt_src_connections = {}
    for connection in connections:
        source_dsk = f"{connection.source_component}.outputs.{connection.source_key}"
        target_dsk = f"{connection.target_component}.inputs.{connection.target_key}"
        # value of target key is mapped one-to-one from value of source
        dsk_tgt_src_connections[target_dsk] = (identity, source_dsk)

    rewrite_ruleset = RuleSet()
    for dsk_payload_target_serial in initial_task_dsk.payload_tasks_dsk:
        # Payload task keys must look like ``<target>.serial``. ``rpartition``
        # never raises, so a key without any "." is reported through the
        # descriptive RuntimeError below instead of an opaque
        # "not enough values to unpack" ValueError from tuple unpacking.
        dsk_payload_target, sep, _serial_ident = dsk_payload_target_serial.rpartition(".")
        if not sep or _serial_ident != "serial":
            raise RuntimeError(
                f"dsk_payload_target_serial={dsk_payload_target_serial}, "
                f"dsk_payload_target={dsk_payload_target}, _serial_ident={_serial_ident}"
            )
        if dsk_payload_target in dsk_tgt_src_connections:
            # This rewrite rule ensures that exposed inputs are able to replace inputs
            # coming from connected components. If the payload keys are mapped in a
            # connection, replace the connection with the payload deserialize function.
            lhs = dsk_tgt_src_connections[dsk_payload_target]
            rhs = initial_task_dsk.merged_dsk[dsk_payload_target]
            rule = RewriteRule(lhs, rhs, vars=())
            rewrite_ruleset.add(rule)

    io_subgraphs_merged = merge(
        initial_task_dsk.merged_dsk,
        dsk_tgt_src_connections,
        initial_task_dsk.result_tasks_dsk,
        initial_task_dsk.payload_tasks_dsk,
    )

    # apply rewrite rules
    rewritten_dsk = valmap(rewrite_ruleset.rewrite, io_subgraphs_merged)

    # We perform a significant optimization here by culling any tasks which
    # have been made redundant by the rewrite rules, or which don't exist
    # on a path which is required for computation of the endpoint outputs
    culled_dsk, culled_deps = cull(rewritten_dsk, initial_task_dsk.output_keys)
    _verify_no_cycles(culled_dsk, initial_task_dsk.output_keys,
                      endpoint_protocol.name)

    # as an optimization, we inline the `one_to_one` functions, into the
    # execution of their dependency. Since they are so cheap, there's no
    # need to spend time sending off a task to perform them.
    inlined = inline_functions(
        culled_dsk,
        initial_task_dsk.output_keys,
        fast_functions=[identity],
        inline_constants=True,
        dependencies=culled_deps,
    )
    # The dependencies returned by this second cull are not needed.
    inlined_culled_dsk, _ = cull(inlined, initial_task_dsk.output_keys)
    _verify_no_cycles(inlined_culled_dsk, initial_task_dsk.output_keys,
                      endpoint_protocol.name)

    # Pre-run the topological sort of tasks so it doesn't have to be
    # recomputed upon every request.
    toposort_keys = toposort(inlined_culled_dsk)

    # construct results
    res = TaskComposition(
        dsk=inlined_culled_dsk,
        sortkeys=toposort_keys,
        get_keys=initial_task_dsk.output_keys,
        ep_dsk_input_keys=initial_task_dsk.payload_dsk_map,
        ep_dsk_output_keys=initial_task_dsk.result_dsk_map,
        pre_optimization_dsk=initial_task_dsk.merged_dsk,
    )
    return res