def combine_nested_closures(self):
    """
    Fold the arrays and callbacks of every nested closure into this closure.

    For each ``(_, child)`` entry of ``self.nested_closures``:

    * The child's closure arrays are visited in sorted name order. Each
      array's evaluator is called, and the array is skipped if the ``id`` of
      the returned object is already a key of ``self.array_mapping``.
      Otherwise, a name is requested from ``data.find_new_name`` and, only
      if the descriptor is not transient, the array is stored in
      ``self.closure_arrays`` and its object id is recorded in
      ``self.array_mapping``.
    * The child's callbacks are visited in sorted name order and always
      stored in ``self.callbacks`` under a name from
      ``data.find_new_name``; the callback's object id is recorded in
      ``self.array_mapping`` as well.

    Entries written here carry ``True`` as their last tuple element
    (presumably a "comes from a nested closure" marker -- TODO confirm
    against the readers of these tuples).
    """
    # NOTE: Removal of previously combined nested closures is disabled:
    # self.closure_arrays = {
    #     k: v
    #     for k, v in self.closure_arrays.items() if v[3] is False
    # }

    for _, nested in self.nested_closures:
        for inner_name, array_entry in sorted(nested.closure_arrays.items()):
            _, descriptor, get_array, _ = array_entry

            # Keep a reference to the evaluated object while its id is used
            array_obj = get_array()
            array_key = id(array_obj)

            # Only handle arrays that were not already passed as part of a
            # (nested) closure
            if array_key not in self.array_mapping:
                unique_name = data.find_new_name(inner_name,
                                                 self.closure_arrays.keys())
                if not descriptor.transient:
                    self.closure_arrays[unique_name] = (inner_name,
                                                        descriptor, get_array,
                                                        True)
                    self.array_mapping[array_key] = unique_name

        for inner_cb_name, cb_entry in sorted(nested.callbacks.items()):
            _, callback, _ = cb_entry
            unique_name = data.find_new_name(inner_cb_name,
                                             self.callbacks.keys())
            self.callbacks[unique_name] = (inner_cb_name, callback, True)
            self.array_mapping[id(callback)] = unique_name
def apply(self, outer_state: SDFGState, sdfg: SDFG):
    """
    Inline the states of the nested SDFG referenced by ``self.nested_sdfg``
    into ``sdfg``, replacing ``outer_state``.

    The steps, in order, are:

    1. Propagate the nested SDFG node's schedule (if not ``Default``) into
       the nested SDFG.
    2. Copy global/init/exit code, environments and constants to ``sdfg``.
    3. Apply the nested SDFG node's symbol mapping to the nested SDFG.
    4. Register nested transients as data descriptors of ``sdfg`` and
       rename nested data to the outer names.
    5. Make state labels and clashing inter-state assignment names unique.
    6. Add the nested states and inter-state edges to ``sdfg``, reconnect
       the edges that entered/left ``outer_state`` and update the start
       state if needed.
    7. Re-parent the inlined states and remove ``outer_state``.

    :param outer_state: The state of ``sdfg`` whose edges to/from the nested
                        SDFG node are inspected; it is removed from ``sdfg``
                        at the end.
    :param sdfg: The SDFG that receives the nested SDFG's states.
    :return: The result of ``nsdfg.nodes()``, i.e., the nested SDFG's states
             (which have been added to ``sdfg``).
    """
    nsdfg_node = self.nested_sdfg
    nsdfg: SDFG = nsdfg_node.sdfg

    # A non-default schedule on the nested SDFG node is pushed down into the
    # nested SDFG before its contents are moved out
    if nsdfg_node.schedule is not dtypes.ScheduleType.Default:
        infer_types.set_default_schedule_and_storage_types(
            nsdfg, nsdfg_node.schedule)

    #######################################################
    # Collect and update top-level SDFG metadata

    # Global/init/exit code
    for loc, code in nsdfg.global_code.items():
        sdfg.append_global_code(code.code, loc)
    for loc, code in nsdfg.init_code.items():
        sdfg.append_init_code(code.code, loc)
    for loc, code in nsdfg.exit_code.items():
        sdfg.append_exit_code(code.code, loc)

    # Environments: every code node in the nested SDFG additionally receives
    # the environments of the nested SDFG node
    for nstate in nsdfg.nodes():
        for node in nstate.nodes():
            if isinstance(node, nodes.CodeNode):
                node.environments |= nsdfg_node.environments

    # Constants: add missing ones to the outer SDFG; on a name clash with a
    # different value only warn (the outer value is left untouched)
    for cstname, cstval in nsdfg.constants.items():
        if cstname in sdfg.constants:
            if cstval != sdfg.constants[cstname]:
                warnings.warn('Constant value mismatch for "%s" while '
                              'inlining SDFG. Inner = %s != %s = outer' %
                              (cstname, cstval, sdfg.constants[cstname]))
        else:
            sdfg.add_constant(cstname, cstval)

    # Symbols: outer SDFG symbols plus the symbols reported by
    # ``new_symbols`` of each outer inter-state edge
    outer_symbols = {str(k): v for k, v in sdfg.symbols.items()}
    for ise in sdfg.edges():
        outer_symbols.update(ise.data.new_symbols(sdfg, outer_symbols))

    # Find original source/destination edges (there is only one edge per
    # connector, according to match)
    inputs: Dict[str, MultiConnectorEdge] = {}
    outputs: Dict[str, MultiConnectorEdge] = {}
    # Outer data name -> connector name. These two dictionaries are only
    # read by the commented-out code further below.
    input_set: Dict[str, str] = {}
    output_set: Dict[str, str] = {}
    for e in outer_state.in_edges(nsdfg_node):
        inputs[e.dst_conn] = e
        input_set[e.data.data] = e.dst_conn
    for e in outer_state.out_edges(nsdfg_node):
        outputs[e.src_conn] = e
        output_set[e.data.data] = e.src_conn

    # Replace symbols using invocation symbol mapping
    # Two-step replacement (N -> __dacesym_N --> map[N]) to avoid clashes
    symbolic.safe_replace(nsdfg_node.symbol_mapping, nsdfg.replace_dict)

    # Access nodes that need to be reshaped
    # reshapes: Set(str) = set()
    # for aname, array in nsdfg.arrays.items():
    #     if array.transient:
    #         continue
    #     edge = None
    #     if aname in inputs:
    #         edge = inputs[aname]
    #         if len(array.shape) > len(edge.data.subset):
    #             reshapes.add(aname)
    #             continue
    #     if aname in outputs:
    #         edge = outputs[aname]
    #         if len(array.shape) > len(edge.data.subset):
    #             reshapes.add(aname)
    #             continue
    #     if edge is not None and not InlineMultistateSDFG._check_strides(
    #             array.strides, sdfg.arrays[edge.data.data].strides,
    #             edge.data, nsdfg_node):
    #         reshapes.add(aname)

    # Mapping from nested transient name to top-level name
    transients: Dict[str, str] = {}

    # All transients become transients of the parent (if data already
    # exists, find new name)
    for nstate in nsdfg.nodes():
        for node in nstate.nodes():
            if isinstance(node, nodes.AccessNode):
                datadesc = nsdfg.arrays[node.data]
                if node.data not in transients and datadesc.transient:
                    new_name = node.data
                    # On a clash with outer data, symbols or constants,
                    # prefix the name with the nested SDFG's label
                    if (new_name in sdfg.arrays
                            or new_name in outer_symbols
                            or new_name in sdfg.constants):
                        new_name = f'{nsdfg.label}_{node.data}'

                    name = sdfg.add_datadesc(new_name,
                                             datadesc,
                                             find_new_name=True)
                    transients[node.data] = name

        # All transients of edges between code nodes are also added to parent
        for edge in nstate.edges():
            if (isinstance(edge.src, nodes.CodeNode)
                    and isinstance(edge.dst, nodes.CodeNode)):
                if edge.data.data is not None:
                    datadesc = nsdfg.arrays[edge.data.data]
                    if edge.data.data not in transients and datadesc.transient:
                        new_name = edge.data.data
                        if (new_name in sdfg.arrays
                                or new_name in outer_symbols
                                or new_name in sdfg.constants):
                            new_name = f'{nsdfg.label}_{edge.data.data}'

                        name = sdfg.add_datadesc(new_name,
                                                 datadesc,
                                                 find_new_name=True)
                        transients[edge.data.data] = name

    #######################################################
    # Replace data on inlined SDFG nodes/edges

    # Replace data names with their top-level counterparts: transients map
    # to their new outer names, connectors map to the data name of the
    # memlet on the corresponding outer edge
    repldict = {}
    repldict.update(transients)
    repldict.update({
        k: v.data.data
        for k, v in itertools.chain(inputs.items(), outputs.items())
    })

    symbolic.safe_replace(repldict,
                          lambda m: replace_datadesc_names(nsdfg, m),
                          value_as_string=True)

    # Add views whenever reshapes are necessary
    # for dname in reshapes:
    #     desc = nsdfg.arrays[dname]
    #     # To avoid potential confusion, rename protected __return keyword
    #     if dname.startswith('__return'):
    #         newname = f'{nsdfg.name}_ret{dname[8:]}'
    #     else:
    #         newname = dname
    #     newname, _ = sdfg.add_view(newname,
    #                                desc.shape,
    #                                desc.dtype,
    #                                storage=desc.storage,
    #                                strides=desc.strides,
    #                                offset=desc.offset,
    #                                debuginfo=desc.debuginfo,
    #                                allow_conflicts=desc.allow_conflicts,
    #                                total_size=desc.total_size,
    #                                alignment=desc.alignment,
    #                                may_alias=desc.may_alias,
    #                                find_new_name=True)
    #     repldict[dname] = newname

    # Add extra access nodes for out/in view nodes
    # inv_reshapes = {repldict[r]: r for r in reshapes}
    # for nstate in nsdfg.nodes():
    #     for node in nstate.nodes():
    #         if isinstance(node,
    #                       nodes.AccessNode) and node.data in inv_reshapes:
    #             if nstate.in_degree(node) > 0 and nstate.out_degree(
    #                     node) > 0:
    #                 # Such a node has to be in the output set
    #                 edge = outputs[inv_reshapes[node.data]]

    #                 # Redirect outgoing edges through access node
    #                 out_edges = list(nstate.out_edges(node))
    #                 anode = nstate.add_access(edge.data.data)
    #                 vnode = nstate.add_access(node.data)
    #                 nstate.add_nedge(node, anode, edge.data)
    #                 nstate.add_nedge(anode, vnode, edge.data)
    #                 for e in out_edges:
    #                     nstate.remove_edge(e)
    #                     nstate.add_edge(vnode, e.src_conn, e.dst,
    #                                     e.dst_conn, e.data)

    # Make unique names for states
    # NOTE(review): labels of nested states that are NOT renamed are never
    # added to ``statenames``, so a freshly generated name could coincide
    # with the label of a nested state visited later -- confirm whether
    # this can happen in practice.
    statenames = set(s.label for s in sdfg.nodes())
    for nstate in nsdfg.nodes():
        if nstate.label in statenames:
            newname = data.find_new_name(nstate.label, statenames)
            statenames.add(newname)
            nstate.set_label(newname)

    #######################################################
    # Collect and modify interstate edges as necessary

    # Names assigned on inter-state edges of the outer and nested SDFGs
    outer_assignments = set()
    for e in sdfg.edges():
        outer_assignments |= e.data.assignments.keys()

    inner_assignments = set()
    for e in nsdfg.edges():
        inner_assignments |= e.data.assignments.keys()

    # Names assigned in both SDFGs are renamed inside the nested SDFG,
    # avoiding every outer symbol and outer data name
    assignments_to_replace = inner_assignments & outer_assignments
    sym_replacements: Dict[str, str] = {}
    allnames = set(outer_symbols.keys()) | set(sdfg.arrays.keys())
    for assign in assignments_to_replace:
        newname = data.find_new_name(assign, allnames)
        allnames.add(newname)
        sym_replacements[assign] = newname
    nsdfg.replace_dict(sym_replacements)

    #######################################################
    # Add nested SDFG states into top-level SDFG

    # Read the start state before new nodes are added to the outer SDFG
    outer_start_state = sdfg.start_state

    sdfg.add_nodes_from(nsdfg.nodes())
    for ise in nsdfg.edges():
        sdfg.add_edge(ise.src, ise.dst, ise.data)

    #######################################################
    # Reconnect inlined SDFG

    source = nsdfg.start_state
    sinks = nsdfg.sink_nodes()

    # Reconnect state machine: edges into ``outer_state`` now lead to the
    # nested start state; edges out of ``outer_state`` now leave from every
    # nested sink state
    for e in sdfg.in_edges(outer_state):
        sdfg.add_edge(e.src, source, e.data)
    for e in sdfg.out_edges(outer_state):
        for sink in sinks:
            # NOTE(review): with more than one sink, the same ``e.data``
            # object is attached to several edges -- confirm that sharing
            # (rather than copying) is intended.
            sdfg.add_edge(sink, e.dst, e.data)

    # Modify start state as necessary
    if outer_start_state is outer_state:
        sdfg.start_state = sdfg.node_id(source)

    # TODO: Modify memlets by offsetting
    # If both source and sink nodes are inputs/outputs, reconnect once
    # edges_to_ignore = self._modify_access_to_access(new_incoming_edges,
    #                                                 nsdfg, nstate, state,
    #                                                 orig_data)

    # source_to_outer = {n: e.src for n, e in new_incoming_edges.items()}
    # sink_to_outer = {n: e.dst for n, e in new_outgoing_edges.items()}
    # # If a source/sink node is one of the inputs/outputs, reconnect it,
    # # replacing memlets in outgoing/incoming paths
    # modified_edges = set()
    # modified_edges |= self._modify_memlet_path(new_incoming_edges, nstate,
    #                                            state, sink_to_outer, True,
    #                                            edges_to_ignore)
    # modified_edges |= self._modify_memlet_path(new_outgoing_edges, nstate,
    #                                            state, source_to_outer,
    #                                            False, edges_to_ignore)

    # # Reshape: add connections to viewed data
    # self._modify_reshape_data(reshapes, repldict, inputs, nstate, state,
    #                           True)
    # self._modify_reshape_data(reshapes, repldict, outputs, nstate, state,
    #                           False)

    # Modify all other internal edges pertaining to input/output nodes
    # for nstate in nsdfg.nodes():
    #     for node in nstate.nodes():
    #         if isinstance(node, nodes.AccessNode):
    #             if node.data in input_set or node.data in output_set:
    #                 if node.data in input_set:
    #                     outer_edge = inputs[input_set[node.data]]
    #                 else:
    #                     outer_edge = outputs[output_set[node.data]]

    #                 for edge in state.all_edges(node):
    #                     if (edge not in modified_edges
    #                             and edge.data.data == node.data):
    #                         for e in state.memlet_tree(edge):
    #                             if e.data.data == node.data:
    #                                 e._data = helpers.unsqueeze_memlet(
    #                                     e.data, outer_edge.data)

    # Replace nested SDFG parents with new SDFG
    for nstate in nsdfg.nodes():
        nstate.parent = sdfg
        for node in nstate.nodes():
            # Nested SDFGs inside the inlined states now hang off ``sdfg``
            if isinstance(node, nodes.NestedSDFG):
                node.sdfg.parent_sdfg = sdfg
                node.sdfg.parent_nsdfg_node = node

    #######################################################
    # Remove nested SDFG and state
    sdfg.remove_node(outer_state)

    return nsdfg.nodes()