def test_inline_cull_dependencies():
    d = {
        "a": 1,
        "b": "a",
        "c": "b",
        "d": ["a", "b", "c"],
        "e": (add, (len, "d"), "a")
    }

    d2, dependencies = cull(d, ["d", "e"])
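    # smoke test: `inline` should accept the dependency mapping produced by
    # `cull` without raising; the return value isn't asserted on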
    inline(d2, {"b"}, dependencies=dependencies)

def test_cull():
    # 'out' depends on 'x' and 'y', but not 'z'
    d = {"x": 1, "y": (inc, "x"), "z": (inc, "x"), "out": (add, "y", 10)}
    culled, dependencies = cull(d, "out")
    assert culled == {"x": 1, "y": (inc, "x"), "out": (add, "y", 10)}
    assert dependencies == {"x": [], "y": ["x"], "out": ["y"]}

    assert cull(d, "out") == cull(d, ["out"])
    assert cull(d, ["out", "z"])[0] == d
    assert cull(d, [["out"], ["z"]]) == cull(d, ["out", "z"])
    pytest.raises(KeyError, lambda: cull(d, "badkey"))

def test_SubgraphCallable():
    non_hashable = [1, 2, 3]

    dsk = {
        "a": (apply, add, ["in1", 2]),
        "b": (
            apply,
            partial_by_order,
            ["in2"],
            {
                "function": func_with_kwargs,
                "other": [(1, 20)],
                "c": 4
            },
        ),
        "c": (
            apply,
            partial_by_order,
            ["in2", "in1"],
            {
                "function": func_with_kwargs,
                "other": [(1, 20)]
            },
        ),
        "d": (inc, "a"),
        "e": (add, "c", "d"),
        "f": ["a", 2, "b", (add, "b", (sum, non_hashable))],
        "h": (add, (sum, "f"), (sum, ["a", "b"])),
    }

    f = SubgraphCallable(dsk, "h", ["in1", "in2"], name="test")
    assert f.name == "test"
    assert repr(f) == "test"

    f2 = SubgraphCallable(dsk, "h", ["in1", "in2"], name="test")
    assert f == f2

    f3 = SubgraphCallable(dsk, "g", ["in1", "in2"], name="test")
    assert f != f3

    assert dict(f=None)
    assert hash(SubgraphCallable(None, None, [None]))
    assert hash(f3) != hash(f2)

    dsk2 = dsk.copy()
    dsk2.update({"in1": 1, "in2": 2})
    assert f(1, 2) == get(cull(dsk2, ["h"])[0], ["h"])[0]
    assert f(1, 2) == f(1, 2)

    f2 = pickle.loads(pickle.dumps(f))
    assert f2(1, 2) == f(1, 2)
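
For context, SubgraphCallable (importable from dask.optimization in the dask
versions this test targets) packs a task graph and an ordered list of input
keys into a plain callable. A minimal sketch, using a hypothetical two-task
subgraph:

from dask.optimization import SubgraphCallable

def _inc(v):
    return v + 1

def _double(v):
    return v * 2

sub = {"y": (_inc, "x"), "z": (_double, "y")}
g = SubgraphCallable(sub, outkey="z", inkeys=["x"], name="double_inc")
assert g(10) == 22  # binds x=10, computes y=11, then z=11 * 2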
Example #4
def build_composition(
    endpoint_protocol: 'EndpointProtocol',
    components: Dict[str, 'ModelComponent'],
    connections: List['Connection'],
) -> 'TaskComposition':
    r"""Build a composed graph.

    Notes on easy sources to introduce bugs.

    ::

            Input Data
        --------------------
            a  b  c   d
            |  |  |   | \\
             \ | / \  |  ||
              C_2   C_1  ||
            /  |     | \ //
           /   |    /   *
        RES_2  |   |   // \
               |   |  //   RES_1
                \  | //
                C_2_1
                  |
                RES_3
        ---------------------
              Output Data

    Because there are connections between ``C_1 -> C_2_1`` and
    ``C_2 -> C_2_1``, we can eliminate the ``serialize <-> deserialize``
    tasks for the data transferred between these components. We need to be
    careful not to eliminate the ``serialize`` or ``deserialize`` tasks
    entirely, though. In the case shown above, it is apparent that ``RES_1``
    & ``RES_2`` still need the ``serialize`` function, and the same applies
    to ``deserialize``. Consider the example below, with the same composition
    & connections as above:

    ::

            Input Data
        --------------------
            a  b  c   d
            |  |  |   | \\
             \ | /| \ |  \\
              C_2 |  C_1  ||
            /  |  |   @\  ||
           /   |  |   @ \ //
        RES_2  |  |  @   *
               |  | @  // \
                \ | @ //   RES_1
                 C_2_1
                  |
                RES_3
        ---------------------
              Output Data

    Though we are using the same composition, the endpoints have been changed
    so that the previous result of ``C_1 -> C_2_1`` is now being provided by
    input ``c``. However, there is still a connection between ``C_1`` and
    ``C_2_1``, denoted by the ``@`` symbols. Though the first example (shown
    at the top of this docstring) would be able to eliminate the ``C_2_1``
    ``deserialize`` from ``C_2`` / ``C_1``, we see here that, since endpoints
    define the path through the DAG, we cannot eliminate them entirely
    either.
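
    For example (key names here are purely illustrative), a connection task
    such as::

        "C_2_1.inputs.in_x": (identity, "C_1.outputs.out_y")

    can be rewritten so that the target key is fed by the endpoint payload's
    ``deserialize`` task instead of by ``C_1``'s output, which is what the
    rewrite rules built below do when an exposed input is mapped onto a
    connected key.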
    """
    initial_task_dsk = _process_initial(endpoint_protocol, components)

    dsk_tgt_src_connections = {}
    for connection in connections:
        source_dsk = f"{connection.source_component}.outputs.{connection.source_key}"
        target_dsk = f"{connection.target_component}.inputs.{connection.target_key}"
        # value of target key is mapped one-to-one from value of source
        dsk_tgt_src_connections[target_dsk] = (identity, source_dsk)

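    # e.g. dsk_tgt_src_connections might now contain (names illustrative):
    #   {"C_2_1.inputs.in_x": (identity, "C_1.outputs.out_y")}
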
    rewrite_ruleset = RuleSet()
    for dsk_payload_target_serial in initial_task_dsk.payload_tasks_dsk.keys():
        dsk_payload_target, _serial_ident = dsk_payload_target_serial.rsplit(
            ".", maxsplit=1)
        if _serial_ident != "serial":
            raise RuntimeError(
                f"dsk_payload_target_serial={dsk_payload_target_serial}, "
                f"dsk_payload_target={dsk_payload_target}, _serial_ident={_serial_ident}"
            )
        if dsk_payload_target in dsk_tgt_src_connections:
            # This rewrite rule ensures that exposed inputs are able to replace inputs
            # coming from connected components. If the payload keys are mapped in a
            # connection, replace the connection with the payload deserialize function.
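            # e.g. occurrences of (identity, "C_1.outputs.out_y") get replaced
            # by the payload's deserialize task (key names illustrative).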
            lhs = dsk_tgt_src_connections[dsk_payload_target]
            rhs = initial_task_dsk.merged_dsk[dsk_payload_target]
            rule = RewriteRule(lhs, rhs, vars=())
            rewrite_ruleset.add(rule)

    io_subgraphs_merged = merge(
        initial_task_dsk.merged_dsk,
        dsk_tgt_src_connections,
        initial_task_dsk.result_tasks_dsk,
        initial_task_dsk.payload_tasks_dsk,
    )
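    # (assuming `merge` here is a toolz-style dict merge: later mappings take
    # precedence on duplicate keys)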

    # apply rewrite rules
    rewritten_dsk = valmap(rewrite_ruleset.rewrite, io_subgraphs_merged)
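    # (`valmap` applies the rewrite to every task value; graph keys are
    # untouched)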

    # We perform a significant optimization here by culling any tasks which
    # have been made redundant by the rewrite rules, or which don't lie on a
    # path required for computation of the endpoint outputs.
    culled_dsk, culled_deps = cull(rewritten_dsk, initial_task_dsk.output_keys)
    _verify_no_cycles(culled_dsk, initial_task_dsk.output_keys,
                      endpoint_protocol.name)

    # as an optimization, we inline the `one_to_one` (identity) functions into
    # the execution of their dependents. Since they are so cheap, there's no
    # need to spend time sending off a separate task to perform them.
    inlined = inline_functions(
        culled_dsk,
        initial_task_dsk.output_keys,
        fast_functions=[identity],
        inline_constants=True,
        dependencies=culled_deps,
    )
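    # e.g. a key whose task is (identity, "src") is dropped, and the identity
    # call is embedded directly into its dependents' tasks (names illustrative)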
    inlined_culled_dsk, inlined_culled_deps = cull(
        inlined, initial_task_dsk.output_keys)
    _verify_no_cycles(inlined_culled_dsk, initial_task_dsk.output_keys,
                      endpoint_protocol.name)

    # pre-run the topological sort of tasks so it doesn't have to be
    # recomputed upon every request.
    toposort_keys = toposort(inlined_culled_dsk)
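    # (toposort orders keys so every task appears after its dependencies)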

    # construct results
    res = TaskComposition(
        dsk=inlined_culled_dsk,
        sortkeys=toposort_keys,
        get_keys=initial_task_dsk.output_keys,
        ep_dsk_input_keys=initial_task_dsk.payload_dsk_map,
        ep_dsk_output_keys=initial_task_dsk.result_dsk_map,
        pre_optimization_dsk=initial_task_dsk.merged_dsk,
    )
    return res
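
The cull -> inline_functions -> toposort sequence above is the standard dask
graph-optimization recipe. A minimal runnable sketch of the same pipeline on
a toy graph (assuming dask's cull / inline_functions from dask.optimization
and toposort from dask.core; the graph and key names are made up purely for
illustration):

from dask.core import toposort
from dask.optimization import cull, inline_functions

def identity(x):
    # stand-in for the serve-side identity used for connections
    return x

def inc(x):
    return x + 1

def times10(x):
    return x * 10

# hypothetical graph: "conn" is an identity hop between two "components"
dsk = {
    "payload": 1,
    "out_a": (inc, "payload"),
    "conn": (identity, "out_a"),
    "result": (times10, "conn"),
    "unused": (inc, "payload"),  # not on the "result" path
}

# 1) cull: drop tasks not needed for the requested outputs ("unused" goes)
culled, deps = cull(dsk, ["result"])

# 2) inline_functions: fold the cheap identity hop into its dependent so no
#    separate task is dispatched just to forward a value
inlined = inline_functions(
    culled,
    ["result"],
    fast_functions=[identity],
    inline_constants=True,
    dependencies=deps,
)

# 3) toposort: pre-compute an execution order once, as build_composition does
print(toposort(inlined))  # e.g. ['payload', 'out_a', 'result']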