def inline(dsk, keys=None, inline_constants=True, dependencies=None):
    """Return new dask with the given keys inlined with their values.

    Inlines all constants if ``inline_constants`` keyword is True. Note that
    the constant keys will remain in the graph, to remove them follow
    ``inline`` with ``cull``.

    Parameters
    ----------
    dsk : dict
        The task graph.
    keys : key or iterable of keys, optional
        Keys whose values are substituted into every task that depends on
        them. A set passed here is never modified.
    inline_constants : bool, default True
        If True, additionally inline every key whose value is an alias of
        another key in ``dsk``, or which has no dependencies and is not a
        task.
    dependencies : dict, optional
        Mapping from each key of ``dsk`` to the keys it depends on. Computed
        with ``get_dependencies`` when omitted. Values may be sets or any
        other iterable (e.g. lists); non-set values are converted to sets.

    Returns
    -------
    dict
        A new graph dict holding the substituted values.

    Examples
    --------
    >>> d = {'x': 1, 'y': (inc, 'x'), 'z': (add, 'x', 'y')}
    >>> inline(d)  # doctest: +SKIP
    {'x': 1, 'y': (inc, 1), 'z': (add, 1, 'y')}

    >>> inline(d, keys='y')  # doctest: +SKIP
    {'x': 1, 'y': (inc, 1), 'z': (add, 1, (inc, 1))}

    >>> inline(d, keys='y', inline_constants=False)  # doctest: +SKIP
    {'x': 1, 'y': (inc, 'x'), 'z': (add, 'x', (inc, 'x'))}
    """
    # Copy, because ``keys.update`` below must never mutate a set supplied by
    # the caller (``_flat_set`` may hand a set argument back unchanged).
    keys = set(_flat_set(keys))

    if dependencies is None:
        dependencies = {k: get_dependencies(dsk, k) for k in dsk}
    else:
        # The set intersections ``keys & dependencies[key]`` below need set
        # values. Normalize every entry, not just when the first one is a list.
        dependencies = {
            k: v if isinstance(v, (set, frozenset)) else set(v)
            for k, v in dependencies.items()
        }

    if inline_constants:
        keys.update(
            k
            for k, v in dsk.items()
            # ``v in dsk``: the value is just an alias for another key.
            if (ishashable(v) and v in dsk)
            or (not dependencies[k] and not istask(v))
        )

    # Keys may depend on other keys, so determine replace order with toposort.
    # The values stored in `keysubs` do not include other keys.
    replaceorder = toposort(
        {k: dsk[k] for k in keys if k in dsk}, dependencies=dependencies
    )
    keysubs = {}
    for key in replaceorder:
        val = dsk[key]
        for dep in keys & dependencies[key]:
            # Prefer the already-substituted value so chains fully resolve.
            replace = keysubs[dep] if dep in keysubs else dsk[dep]
            val = subs(val, dep, replace)
        keysubs[key] = val

    # Make new dask with substitutions
    dsk2 = keysubs.copy()
    for key, val in dsk.items():
        if key not in dsk2:
            for item in keys & dependencies[key]:
                val = subs(val, item, keysubs[item])
            dsk2[key] = val
    return dsk2
def build_composition(
    endpoint_protocol: "EndpointProtocol",
    components: Dict[str, "ModelComponent"],
    connections: List["Connection"],
) -> "TaskComposition":
    r"""Build the optimized, composed task graph for one endpoint.

    The graph is assembled from the endpoint's initial payload/result tasks
    (via ``_process_initial``) plus one forwarding task per connection. Rewrite
    rules, culling and inlining then optimize it, and a topological order is
    precomputed.

    Parameters
    ----------
    endpoint_protocol
        The endpoint being composed. Its ``name`` is passed to the cycle
        checks.
    components
        Mapping of component name to component, forwarded to
        ``_process_initial``.
    connections
        Component-to-component connections. Each one maps
        ``"<target_component>.inputs.<target_key>"`` onto
        ``"<source_component>.outputs.<source_key>"``.

    Returns
    -------
    TaskComposition
        The optimized graph, its precomputed topological order, the output
        keys, the endpoint input/output key maps, and the pre-optimization
        graph.

    Raises
    ------
    RuntimeError
        If a payload task key does not end in ``".serial"``. Errors raised by
        ``_verify_no_cycles`` also propagate; presumably it raises when the
        graph contains a cycle.

    Notes
    -----
    This is an easy place to introduce bugs. Consider:

    ::

            Input Data
        --------------------
            a  b  c   d
            |  |  |   | \\
             \ | / \  |  ||
              C_2   C_1  ||
            /  |     | \ //
           /   |    /   *
        RES_2  |   |   // \
               |   |  //   RES_1
                \  | //
                C_2_1
                  |
                RES_3
        ---------------------
              Output Data

    Because there are connections ``C_1 -> C_2_1`` and ``C_2 -> C_2_1``, the
    ``serialize <-> deserialize`` tasks for data transferred between these
    components can be eliminated. They must not be eliminated entirely,
    though. Here ``RES_1`` and ``RES_2`` clearly still need ``serialize``, and
    the same reasoning applies to ``deserialize``. Now consider the same
    composition and connections with different endpoints:

    ::

            Input Data
        --------------------
            a  b  c   d
            |  |  |   | \\
             \ | /| \ |  \\
              C_2 |  C_1  ||
            /  |  |   @\  ||
           /   |  |   @ \ //
        RES_2  |  |  @   *
               |  | @  // \
                \ | @ //   RES_1
                 C_2_1
                   |
                 RES_3
        ---------------------
              Output Data

    The value ``C_1 -> C_2_1`` previously carried is now provided by input
    ``c``. The ``C_1`` to ``C_2_1`` connection still exists, shown by the
    ``@`` symbols. The first example could drop the ``C_2_1`` deserialize
    step for data coming from ``C_2`` / ``C_1``. Here, because endpoints
    define the path through the DAG, those tasks cannot be removed
    entirely either.
    """
    initial = _process_initial(endpoint_protocol, components)

    # One forwarding task per connection: the target input takes the source
    # output's value unchanged.
    forwarding = {
        f"{conn.target_component}.inputs.{conn.target_key}": (
            identity,
            f"{conn.source_component}.outputs.{conn.source_key}",
        )
        for conn in connections
    }

    rules = RuleSet()
    for serial_key in initial.payload_tasks_dsk:
        input_key, suffix = serial_key.rsplit(".", maxsplit=1)
        if suffix != "serial":
            raise RuntimeError(
                f"dsk_payload_target_serial={serial_key}, "
                f"dsk_payload_target={input_key}, _serial_ident={suffix}"
            )
        if input_key not in forwarding:
            continue
        # An input exposed in the payload wins over a connected component.
        # Rewrite the forwarding task into the task the initial graph defines
        # for that same input key.
        rules.add(
            RewriteRule(
                forwarding[input_key],
                initial.merged_dsk[input_key],
                vars=(),
            )
        )

    combined = merge(
        initial.merged_dsk,
        forwarding,
        initial.result_tasks_dsk,
        initial.payload_tasks_dsk,
    )
    rewritten = valmap(rules.rewrite, combined)

    # Keep only the tasks reachable from the endpoint outputs. This drops
    # anything the rewrite rules made redundant.
    pruned, pruned_deps = cull(rewritten, initial.output_keys)
    _verify_no_cycles(pruned, initial.output_keys, endpoint_protocol.name)

    # `identity` forwarders are cheap, so fold them into their dependents
    # rather than scheduling them as separate tasks.
    folded = inline_functions(
        pruned,
        initial.output_keys,
        fast_functions=[identity],
        inline_constants=True,
        dependencies=pruned_deps,
    )
    final_dsk, _final_deps = cull(folded, initial.output_keys)
    _verify_no_cycles(final_dsk, initial.output_keys, endpoint_protocol.name)

    # Pre-compute the topological order once so that it is not recomputed on
    # every request.
    return TaskComposition(
        dsk=final_dsk,
        sortkeys=toposort(final_dsk),
        get_keys=initial.output_keys,
        ep_dsk_input_keys=initial.payload_dsk_map,
        ep_dsk_output_keys=initial.result_dsk_map,
        pre_optimization_dsk=initial.merged_dsk,
    )