# NOTE: the imports and tiny fixtures below are assumptions added for
# self-containment; they mirror the dask-style helpers these tests exercise.
import pickle
from operator import add

import pytest

from dask.core import get
from dask.optimization import cull, inline, SubgraphCallable
from dask.utils import apply, partial_by_order


def inc(x):
    return x + 1


def func_with_kwargs(a, b, c=2):
    return a + b + c


def test_inline_cull_dependencies():
    d = {"a": 1, "b": "a", "c": "b", "d": ["a", "b", "c"], "e": (add, (len, "d"), "a")}

    d2, dependencies = cull(d, ["d", "e"])
    inline(d2, {"b"}, dependencies=dependencies)
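
# A minimal sketch (not part of the original suite) of the `inline` semantics
# the test above relies on, assuming the dask-style `inline` imported here:
# inlined keys remain in the graph with their values substituted into any
# dependents; a follow-up `cull` is what actually drops them.
def test_inline_semantics_sketch():
    d = {"x": 1, "y": (inc, "x"), "z": (add, "x", "y")}
    assert inline(d, keys=["x"], inline_constants=False) == {
        "x": 1,
        "y": (inc, 1),
        "z": (add, 1, "y"),
    }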


def test_cull():
    # 'out' depends on 'x' and 'y', but not 'z'
    d = {"x": 1, "y": (inc, "x"), "z": (inc, "x"), "out": (add, "y", 10)}
    culled, dependencies = cull(d, "out")
    assert culled == {"x": 1, "y": (inc, "x"), "out": (add, "y", 10)}
    assert dependencies == {"x": [], "y": ["x"], "out": ["y"]}

    assert cull(d, "out") == cull(d, ["out"])
    assert cull(d, ["out", "z"])[0] == d
    assert cull(d, [["out"], ["z"]]) == cull(d, ["out", "z"])
    pytest.raises(KeyError, lambda: cull(d, "badkey"))
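
# Sketch (an assumption, not part of the original suite): the second return
# value of `cull` covers exactly the keys of the culled graph, which is why it
# can be handed straight to later passes (e.g. `inline` above) through their
# `dependencies=` parameter instead of re-walking the graph.
def test_cull_dependencies_cover_graph_sketch():
    d = {"x": 1, "y": (inc, "x"), "out": (add, "y", 10)}
    culled, deps = cull(d, "out")
    assert set(deps) == set(culled)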


def test_SubgraphCallable():
    non_hashable = [1, 2, 3]

    dsk = {
        "a": (apply, add, ["in1", 2]),
        "b": (
            apply,
            partial_by_order,
            ["in2"],
            {"function": func_with_kwargs, "other": [(1, 20)], "c": 4},
        ),
        "c": (
            apply,
            partial_by_order,
            ["in2", "in1"],
            {"function": func_with_kwargs, "other": [(1, 20)]},
        ),
        "d": (inc, "a"),
        "e": (add, "c", "d"),
        "f": ["a", 2, "b", (add, "b", (sum, non_hashable))],
        "h": (add, (sum, "f"), (sum, ["a", "b"])),
    }

    f = SubgraphCallable(dsk, "h", ["in1", "in2"], name="test")
    assert f.name == "test"
    assert repr(f) == "test"

    f2 = SubgraphCallable(dsk, "h", ["in1", "in2"], name="test")
    assert f == f2

    # "g" is not a key in dsk; construction does not evaluate the graph
    f3 = SubgraphCallable(dsk, "g", ["in1", "in2"], name="test")
    assert f != f3

    # the callable must be hashable (usable as a dict key), even when the
    # underlying dsk is not
    assert {f: None}
    assert hash(SubgraphCallable(None, None, [None]))
    assert hash(f3) != hash(f2)

    dsk2 = dsk.copy()
    dsk2.update({"in1": 1, "in2": 2})
    assert f(1, 2) == get(cull(dsk2, ["h"])[0], ["h"])[0]
    assert f(1, 2) == f(1, 2)

    f2 = pickle.loads(pickle.dumps(f))
    assert f2(1, 2) == f(1, 2)
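
# A minimal usage sketch (an assumption, mirroring the dask-style API
# exercised above): `SubgraphCallable` packages a graph, an output key, and
# ordered input names into a plain callable, so an entire subgraph can be
# shipped around and invoked as a single task.
def test_SubgraphCallable_basic_sketch():
    dsk = {"out": (add, (inc, "a"), "b")}
    f = SubgraphCallable(dsk, "out", ["a", "b"], name="basic")
    assert f(1, 10) == 12  # inc(1) + 10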
def build_composition(
    endpoint_protocol: 'EndpointProtocol',
    components: Dict[str, 'ModelComponent'],
    connections: List['Connection'],
) -> 'TaskComposition':
    r"""Build a composed graph.

    Notes on easy sources to introduce bugs.

    ::

              Input Data
        --------------------
           a  b    c   d
           |  |    |   |
            \ |    |  /
             C_2   C_1
            / |     | \
       RES_2  |     |  RES_1
               \   /
                \ /
               C_2_1
                 |
               RES_3
        ---------------------
              Output Data

    Because there are connections between ``C_1 -> C_2_1`` and
    ``C_2 -> C_2_1`` we can eliminate the ``serialize <-> deserialize``
    tasks for the data transferred between these components. We need to be
    careful not to eliminate the ``serialize`` or ``deserialize`` tasks
    entirely though. In the case shown above, it is apparent that ``RES_1``
    & ``RES_2`` still need the ``serialize`` function, and the same applies
    to ``deserialize``. Consider the example below, with the same
    composition & connections as above:

    ::

              Input Data
        --------------------
           a  b    c   d
           |  |    |\  |
            \ |    | \ |
             C_2   |  C_1
            / |    |  | \
       RES_2  |    |  @  RES_1
               \   | @
                \  |@
                C_2_1
                  |
                RES_3
        ---------------------
              Output Data

    Though we are using the same composition, the endpoints have been changed
    so that the result previously provided by ``C_1`` -> ``C_2_1`` is now
    being provided by input ``c``. However, there is still a connection
    between ``C_1`` and ``C_2_1``, denoted by the ``@`` symbols. Though the
    first example (shown at the top of this docstring) would be able to
    eliminate the ``C_2_1`` ``deserialize`` from ``C_2`` / ``C_1``, we see
    here that since endpoints define the path through the DAG, we cannot
    eliminate them entirely either.
    """
    initial_task_dsk = _process_initial(endpoint_protocol, components)

    dsk_tgt_src_connections = {}
    for connection in connections:
        source_dsk = f"{connection.source_component}.outputs.{connection.source_key}"
        target_dsk = f"{connection.target_component}.inputs.{connection.target_key}"
        # value of the target key is mapped one-to-one from the value of the source
        dsk_tgt_src_connections[target_dsk] = (identity, source_dsk)

    rewrite_ruleset = RuleSet()
    for dsk_payload_target_serial in initial_task_dsk.payload_tasks_dsk.keys():
        dsk_payload_target, _serial_ident = dsk_payload_target_serial.rsplit(".", maxsplit=1)
        if _serial_ident != "serial":
            raise RuntimeError(
                f"dsk_payload_target_serial={dsk_payload_target_serial}, "
                f"dsk_payload_target={dsk_payload_target}, _serial_ident={_serial_ident}"
            )
        if dsk_payload_target in dsk_tgt_src_connections:
            # This rewrite rule ensures that exposed inputs are able to replace inputs
            # coming from connected components. If the payload keys are mapped in a
            # connection, replace the connection with the payload deserialize function.
            lhs = dsk_tgt_src_connections[dsk_payload_target]
            rhs = initial_task_dsk.merged_dsk[dsk_payload_target]
            rule = RewriteRule(lhs, rhs, vars=())
            rewrite_ruleset.add(rule)

    io_subgraphs_merged = merge(
        initial_task_dsk.merged_dsk,
        dsk_tgt_src_connections,
        initial_task_dsk.result_tasks_dsk,
        initial_task_dsk.payload_tasks_dsk,
    )

    # apply rewrite rules
    rewritten_dsk = valmap(rewrite_ruleset.rewrite, io_subgraphs_merged)

    # We perform a significant optimization here by culling any tasks which
    # have been made redundant by the rewrite rules, or which don't exist
    # on a path required for computation of the endpoint outputs.
    culled_dsk, culled_deps = cull(rewritten_dsk, initial_task_dsk.output_keys)
    _verify_no_cycles(culled_dsk, initial_task_dsk.output_keys, endpoint_protocol.name)

    # As an optimization, we inline the `one_to_one` functions into the
    # execution of their dependency. Since they are so cheap, there's no
    # need to spend time sending off a task to perform them.
    inlined = inline_functions(
        culled_dsk,
        initial_task_dsk.output_keys,
        fast_functions=[identity],
        inline_constants=True,
        dependencies=culled_deps,
    )
    inlined_culled_dsk, inlined_culled_deps = cull(inlined, initial_task_dsk.output_keys)
    _verify_no_cycles(inlined_culled_dsk, initial_task_dsk.output_keys, endpoint_protocol.name)

    # Pre-run the topological sort of tasks so it doesn't have to be
    # recomputed upon every request.
    toposort_keys = toposort(inlined_culled_dsk)

    # construct results
    res = TaskComposition(
        dsk=inlined_culled_dsk,
        sortkeys=toposort_keys,
        get_keys=initial_task_dsk.output_keys,
        ep_dsk_input_keys=initial_task_dsk.payload_dsk_map,
        ep_dsk_output_keys=initial_task_dsk.result_dsk_map,
        pre_optimization_dsk=initial_task_dsk.merged_dsk,
    )
    return res
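
# A minimal sketch (an assumption, not part of this module) of the ground
# rewrite-rule mechanism ``build_composition`` relies on, shown here with the
# dask-style ``RewriteRule`` / ``RuleSet`` API. With ``vars=()`` a rule only
# fires on tasks exactly equal to ``lhs``, replacing them wholesale with
# ``rhs`` -- which is how a ``(identity, source)`` connection task gets
# swapped for the payload's deserialize task. The keys below are hypothetical.
def _ground_rewrite_rule_sketch():
    from dask.rewrite import RewriteRule, RuleSet

    def identity(x):
        return x

    lhs = (identity, "C_1.outputs.x")  # a connection task, as built above
    rhs = "C_2_1.inputs.y.serial"      # stand-in replacement task/key
    rules = RuleSet(RewriteRule(lhs, rhs, vars=()))

    # an exactly-matching task is rewritten...
    assert rules.rewrite((identity, "C_1.outputs.x")) == "C_2_1.inputs.y.serial"
    # ...while non-matching tasks pass through untouched
    assert rules.rewrite((identity, "other.outputs.x")) == (identity, "other.outputs.x")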