コード例 #1
0
def test_inline_functions_protects_output_keys():
    """Keys named as outputs must survive inlining even when inlinable."""
    graph = {"x": (inc, 1), "y": (double, "x")}

    # With no protected outputs, the cheap ``inc`` task is folded into "y".
    unprotected = inline_functions(graph, [], [inc])
    assert unprotected == {"y": (double, (inc, 1))}

    # Listing "x" as an output keeps it as its own task.
    protected = inline_functions(graph, ["x"], [inc])
    assert protected == {"x": (inc, 1), "y": (double, "x")}
コード例 #2
0
def test_inline_functions():
    """Only tasks whose callable is in ``fast_functions`` get inlined."""
    x, y, i, d = "x", "y", "i", "d"
    graph = {
        "out": (add, i, d),
        i: (inc, x),
        d: (double, y),
        x: 1,
        y: 1,
    }

    inlined = inline_functions(graph, [], fast_functions={inc})

    # ``i`` (an ``inc`` task) is folded into "out"; ``d`` (``double``) stays.
    assert inlined == {"out": (add, (inc, x), d), d: (double, y), x: 1, y: 1}
コード例 #3
0
def test_inline_functions_non_hashable():
    """A callable whose ``__hash__`` raises must not break inlining."""

    class NonHashableCallable:
        def __call__(self, a):
            return a + 1

        def __hash__(self):
            raise TypeError("Not hashable")

    unhashable_fn = NonHashableCallable()
    graph = {
        "a": 1,
        "b": (inc, "a"),
        "c": (unhashable_fn, "b"),
        "d": (inc, "c"),
    }

    result = inline_functions(graph, [], fast_functions={inc})

    # "b" (a fast ``inc`` task) is folded into "c", and "c" itself is kept.
    assert "b" not in result
    assert result["c"] == (unhashable_fn, graph["b"])
コード例 #4
0
def test_inline_traverses_lists():
    """Inlining reaches task references nested inside list arguments."""
    x, y, i, d = "x", "y", "i", "d"
    graph = {"out": (sum, [i, d]), i: (inc, x), d: (double, y), x: 1, y: 1}

    result = inline_functions(graph, [], fast_functions={inc})

    # ``i`` inside the list is replaced by its task; ``d`` is left as a key.
    assert result == {"out": (sum, [(inc, x), d]), d: (double, y), x: 1, y: 1}
コード例 #5
0
def test_inline_doesnt_shrink_fast_functions_at_top():
    """A fast task that nothing depends on is kept, not removed."""
    graph = {"x": (inc, "y"), "y": 1}
    assert inline_functions(graph, [], fast_functions={inc}) == graph
コード例 #6
0
def test_inline_ignores_curries_and_partials():
    """A task calling ``partial(add, ...)`` is inlined when ``add`` is fast."""
    add_one = partial(add, 1)
    graph = {"x": 1, "y": 2, "a": (add_one, "x"), "b": (inc, "a")}

    result = inline_functions(graph, [], fast_functions={add})

    # The partial-wrapped ``add`` task "a" is folded into its dependent "b".
    assert "a" not in result
    assert result["b"] == (inc, graph["a"])
コード例 #7
0
def build_composition(
    endpoint_protocol: 'EndpointProtocol',
    components: Dict[str, 'ModelComponent'],
    connections: List['Connection'],
) -> 'TaskComposition':
    r"""Build a composed graph.

    Notes on easy sources to introduce bugs.

    ::

            Input Data
        --------------------
            a  b  c   d
            |  |  |   | \\
             \ | / \  |  ||
              C_2   C_1  ||
            /  |     | \ //
           /   |    /   *
        RES_2  |   |   // \
               |   |  //   RES_1
                \  | //
                C_2_1
                  |
                RES_3
        ---------------------
              Output Data

    Because there are connections between ``C_1 -> C_2_1`` and
    ``C_2 -> C_2_1`` we can eliminate the ``serialize <-> deserialize``
    tasks for the data transferred between these components. We need to be
    careful to not eliminate the ``serialize`` or ``deserialize`` tasks
    entirely though. In the case shown above, it is apparent that ``RES_1`` &
    ``RES_2`` still need the ``serialize`` function, but the same also applies
    for ``deserialize``. Consider the example below with the same composition &
    connections as above:

    ::

            Input Data
        --------------------
            a  b  c   d
            |  |  |   | \\
             \ | /| \ |  \\
              C_2 |  C_1  ||
            /  |  |   @\  ||
           /   |  |   @ \ //
        RES_2  |  |  @   *
               |  | @  // \
                \ | @ //   RES_1
                 C_2_1
                  |
                RES_3
        ---------------------
              Output Data

    Though we are using the same composition, the endpoints have been changed so
    that the previous result of ``C_1 -> C_2_1`` is now being provided by
    input ``c``. However, there is still a connection between ``C_1`` and
    ``C_2_1`` which is denoted by the ``@`` symbols. Though the first
    example (shown at the top of this docstring) would be able to eliminate
    ``C_2_1 deserialize`` from ``C_2`` / ``C_1``, we see here that since
    endpoints define the path through the DAG, we cannot eliminate them
    entirely either.

    Args:
        endpoint_protocol: Endpoint description passed to ``_process_initial``;
            its ``name`` is forwarded to ``_verify_no_cycles``.
        components: Mapping of component name to component, passed to
            ``_process_initial``.
        connections: Connections whose ``source_component``/``source_key``
            and ``target_component``/``target_key`` attributes link a
            component output to another component's input.

    Returns:
        A ``TaskComposition`` holding the optimized (rewritten, culled,
        inlined) task graph, its precomputed topological order, the output
        keys to fetch, the endpoint payload/result key maps, and the merged
        graph as it was before optimization.

    Raises:
        RuntimeError: If a payload task key does not have the form
            ``"<target>.serial"``.
        Whatever ``_verify_no_cycles`` raises (presumably on a cycle in the
        optimized graph -- confirm in its definition).
    """
    initial_task_dsk = _process_initial(endpoint_protocol, components)

    # Map each connection's target input key to a one-to-one ``identity``
    # task that reads the connected source output key.
    # NOTE(review): if two connections share a target key, the later one
    # silently wins; confirm uniqueness is enforced before this point.
    dsk_tgt_src_connections = {}
    for connection in connections:
        source_dsk = f"{connection.source_component}.outputs.{connection.source_key}"
        target_dsk = f"{connection.target_component}.inputs.{connection.target_key}"
        dsk_tgt_src_connections[target_dsk] = (identity, source_dsk)

    rewrite_ruleset = RuleSet()
    for dsk_payload_target_serial in initial_task_dsk.payload_tasks_dsk:
        # ``rpartition`` always returns three parts. A key without any "."
        # therefore reaches the descriptive RuntimeError below, rather than
        # failing tuple unpacking with a bare ValueError as ``rsplit`` did.
        dsk_payload_target, sep, _serial_ident = (
            dsk_payload_target_serial.rpartition("."))
        if not sep or _serial_ident != "serial":
            raise RuntimeError(
                f"dsk_payload_target_serial={dsk_payload_target_serial}, "
                f"dsk_payload_target={dsk_payload_target}, _serial_ident={_serial_ident}"
            )
        if dsk_payload_target in dsk_tgt_src_connections:
            # This rewrite rule ensures that exposed inputs are able to replace inputs
            # coming from connected components. If the payload keys are mapped in a
            # connection, replace the connection with the payload deserialize function.
            lhs = dsk_tgt_src_connections[dsk_payload_target]
            rhs = initial_task_dsk.merged_dsk[dsk_payload_target]
            rule = RewriteRule(lhs, rhs, vars=())
            rewrite_ruleset.add(rule)

    io_subgraphs_merged = merge(
        initial_task_dsk.merged_dsk,
        dsk_tgt_src_connections,
        initial_task_dsk.result_tasks_dsk,
        initial_task_dsk.payload_tasks_dsk,
    )

    # apply rewrite rules
    rewritten_dsk = valmap(rewrite_ruleset.rewrite, io_subgraphs_merged)

    # We perform a significant optimization here by culling any tasks which
    # have been made redundant by the rewrite rules, or which don't exist
    # on a path which is required for computation of the endpoint outputs
    culled_dsk, culled_deps = cull(rewritten_dsk, initial_task_dsk.output_keys)
    _verify_no_cycles(culled_dsk, initial_task_dsk.output_keys,
                      endpoint_protocol.name)

    # As an optimization, we inline the one-to-one ``identity`` tasks into
    # the execution of their dependency. Since they are so cheap, there's no
    # need to spend time sending off a task to perform them.
    inlined = inline_functions(
        culled_dsk,
        initial_task_dsk.output_keys,
        fast_functions=[identity],
        inline_constants=True,
        dependencies=culled_deps,
    )
    # Only the culled graph is needed; its dependency map is unused.
    inlined_culled_dsk, _ = cull(inlined, initial_task_dsk.output_keys)
    _verify_no_cycles(inlined_culled_dsk, initial_task_dsk.output_keys,
                      endpoint_protocol.name)

    # Pre-run the topological sort of tasks so it doesn't have to be
    # recomputed upon every request.
    toposort_keys = toposort(inlined_culled_dsk)

    # construct results
    res = TaskComposition(
        dsk=inlined_culled_dsk,
        sortkeys=toposort_keys,
        get_keys=initial_task_dsk.output_keys,
        ep_dsk_input_keys=initial_task_dsk.payload_dsk_map,
        ep_dsk_output_keys=initial_task_dsk.result_dsk_map,
        pre_optimization_dsk=initial_task_dsk.merged_dsk,
    )
    return res