def apply(self, sdfg):
    """Inline the inter-state assignments of the matched end state.

    Every assignment returned by ``_assignments_to_consider`` for the first
    incoming edge of the end state is substituted into that state and then
    deleted from the edge. A removed name that no inter-state edge still
    reads is dropped from ``sdfg.symbols``; if its assigned value is itself
    an SDFG symbol, the name is also renamed to that value throughout the
    SDFG.

    :param sdfg: The SDFG the transformation is applied to (modified in
                 place).
    """
    end_state = sdfg.nodes()[self.subgraph[StateAssignElimination._end_state]]
    incoming = sdfg.in_edges(end_state)[0]

    # Inter-state assignments that read a value assigned on the same edge
    # (e.g., {m: n, n: m}) are undefined behavior, so each assignment can be
    # substituted into the state on its own.
    candidates = _assignments_to_consider(sdfg, incoming)
    removed = set()
    for name, value in candidates.items():
        end_state.replace(name, value)
        removed.add(name)

    renames = {}
    for name in removed:
        # The assignment has been inlined; drop it from the edge.
        del incoming.data.assignments[name]

        # Keep the symbol if any inter-state edge still reads it.
        if any(name in ise.data.free_symbols for ise in sdfg.edges()):
            continue

        value = candidates[name]
        if value in sdfg.symbols:
            renames[name] = value
        if name in sdfg.symbols:
            sdfg.remove_symbol(name)

    if not renames:
        return

    def _rename_in_sdfg(mapping):
        for old, new in mapping.items():
            sdfg.replace(str(old), str(new))

    symbolic.safe_replace(renames, _rename_in_sdfg)
def _check_strides(inner_strides: List[symbolic.SymbolicType],
                   outer_strides: List[symbolic.SymbolicType],
                   memlet: Memlet, nested_sdfg: nodes.NestedSDFG) -> bool:
    """
    Returns True if the strides of the inner array can be matched
    to the strides of the outer array upon inlining. Takes into
    consideration memlet (un)squeeze and nested SDFG symbol mapping.

    :param inner_strides: The strides of the array inside the nested SDFG.
    :param outer_strides: The strides of the array in the external SDFG.
    :param memlet: The outer memlet connected to the nested SDFG. Dimensions
                   of size 1 in its subset are ignored when comparing
                   against the inner strides (unsqueezing).
    :param nested_sdfg: Nested SDFG node with symbol mapping.
    :return: True if all strides match, False otherwise.
    """
    # Replace all inner symbols based on the symbol mapping. safe_replace
    # calls `replfunc` in two steps to avoid clashes such as {i: j, j: i}.
    istrides = list(inner_strides)

    def replfunc(mapping):
        for i, stride in enumerate(istrides):
            if symbolic.issymbolic(stride):
                istrides[i] = stride.subs(mapping)

    symbolic.safe_replace(nested_sdfg.symbol_mapping, replfunc)

    if istrides == list(outer_strides):
        return True

    # Take unsqueezing into account: drop outer dimensions of size 1
    dims_to_ignore = {
        i for i, size in enumerate(memlet.subset.size()) if size == 1
    }
    ostrides = [
        ostride for i, ostride in enumerate(outer_strides)
        if i not in dims_to_ignore
    ]
    if not ostrides:
        ostrides = [1]

    if len(ostrides) != len(istrides):
        return False

    return all(istr == ostr for istr, ostr in zip(istrides, ostrides))
def apply(self, _, sdfg: SDFG):
    """Inline the inter-state assignments of the matched end state.

    The assignments selected by ``_assignments_to_consider`` on the first
    incoming edge of ``self.end_state`` are substituted into that state and
    into every state and inter-state edge reachable from it (assignment keys
    on those edges are not renamed). The assignments are then removed from
    the edge. A name that no inter-state edge still reads is removed from
    ``sdfg.symbols`` and, if it still appears free in the SDFG, is replaced
    by its assigned value throughout.

    :param _: Unused positional argument.
    :param sdfg: The SDFG to modify in place.
    """
    state = self.end_state
    edge = sdfg.in_edges(state)[0]

    # Since inter-state assignments that use an assigned value leads to
    # undefined behavior (e.g., {m: n, n: m}), we can replace each
    # assignment separately.
    assignments_to_consider = _assignments_to_consider(sdfg, edge, True)

    def _str_repl(s, d, **kwargs):
        for k, v in d.items():
            s.replace(str(k), str(v), **kwargs)

    # Replace in state, and all successors
    symbolic.safe_replace(assignments_to_consider,
                          lambda m: _str_repl(state, m))

    # `edge` carries the assignments themselves and `state` has just been
    # processed. Mark both as visited so that a path leading back into
    # `state` (e.g., a loop) does not apply the substitution a second time.
    visited = {edge, state}
    for isedge in sdfg.bfs_edges(state):
        if isedge not in visited:
            symbolic.safe_replace(
                assignments_to_consider,
                lambda m, ise=isedge: _str_repl(ise.data, m,
                                                replace_keys=False))
            visited.add(isedge)
        if isedge.dst not in visited:
            symbolic.safe_replace(assignments_to_consider,
                                  lambda m, dst=isedge.dst: _str_repl(dst, m))
            visited.add(isedge.dst)

    repl_dict = {}
    for varname, value in assignments_to_consider.items():
        # Remove assignments from edge
        del edge.data.assignments[varname]

        # Keep the symbol if any inter-state edge still reads it
        if any(varname in e.data.free_symbols for e in sdfg.edges()):
            continue

        # The removed assignment does not appear in any other edge:
        # remove the symbol and replace remaining uses with its value
        if varname in sdfg.symbols:
            sdfg.remove_symbol(varname)
        if varname in sdfg.free_symbols:
            repl_dict[varname] = value

    if repl_dict:
        symbolic.safe_replace(repl_dict, lambda m: _str_repl(sdfg, m))
def can_be_applied(graph, candidate, expr_index, sdfg, permissive=False):
    """Return True if the two matched maps can be fused into one map.

    The candidate is rejected when any of the following holds:

    - a write-conflict-resolution (WCR) output of the first map is read by
      the second map;
    - the first map exit leads to anything other than access nodes, or an
      intermediate array has more than one access node in the SDFG;
    - the map ranges cannot be matched by ``MapFusion.find_permutation``;
    - an intermediate node read by the second map has other consumers;
    - a second-map input from intermediate data is not covered by the first
      map's outputs, or another input depends on the intermediate nodes;
    - data read by the first map and written by the second map is accessed
      at non-identical subsets (stencil pattern).

    :param graph: The state containing the candidate maps.
    :param candidate: Mapping from pattern nodes to node ids in ``graph``.
    :param expr_index: Index of the matched pattern (not used here).
    :param sdfg: The SDFG containing ``graph``.
    :param permissive: Not used here.
    :return: True if the maps can be fused, False otherwise.
    """
    first_map_exit = graph.nodes()[candidate[MapFusion.first_map_exit]]
    first_map_entry = graph.entry_node(first_map_exit)
    second_map_entry = graph.nodes()[candidate[MapFusion.second_map_entry]]
    second_map_exit = graph.exit_node(second_map_entry)

    def _same_access(a, b) -> bool:
        """Return True if subset ``b`` has a zero offset from ``a`` in every
        dimension."""
        if isinstance(a, subsets.Indices):
            c = subsets.Range.from_indices(a)
            c.offset(b, negative=True)
        else:
            c = a.offset_new(b, negative=True)
        return all(r == (0, 0, 1) for r in c)

    for _in_e in graph.in_edges(first_map_exit):
        if _in_e.data.wcr is not None:
            for _out_e in graph.out_edges(second_map_entry):
                if _out_e.data.data == _in_e.data.data:
                    # wcr is on a node that is used in the second map, quit
                    return False

    # Check whether there is a pattern map -> access -> map.
    intermediate_nodes = set()
    intermediate_data = set()
    for _, _, dst, _, _ in graph.out_edges(first_map_exit):
        if not isinstance(dst, nodes.AccessNode):
            return False
        intermediate_nodes.add(dst)
        intermediate_data.add(dst.data)

        # Reject if the array has more than one access node anywhere in the
        # SDFG (every state is scanned, not only this one).
        num_occurrences = sum(
            1 for s in sdfg.nodes() for n in s.nodes()
            if isinstance(n, nodes.AccessNode) and n.data == dst.data)
        if num_occurrences > 1:
            return False

    # Check map ranges
    perm = MapFusion.find_permutation(first_map_entry.map,
                                      second_map_entry.map)
    if perm is None:
        return False

    # Check if any intermediate transient is also going to another location
    second_inodes = set(e.src for e in graph.in_edges(second_map_entry)
                        if isinstance(e.src, nodes.AccessNode))
    transients_to_remove = intermediate_nodes & second_inodes
    if any(graph.out_degree(n) > 1 for n in transients_to_remove):
        return False

    # Map each parameter of the second map to its permuted counterpart in
    # the first map.
    params_dict = {
        param: first_map_entry.map.params[perm[index]]
        for index, param in enumerate(second_map_entry.map.params)
    }

    out_memlets = [e.data for e in graph.in_edges(first_map_exit)]

    # Check that input set of second map is provided by the output set
    # of the first map, or other unrelated maps
    for second_edge in graph.out_edges(second_map_entry):
        # Memlets that do not come from one of the intermediate arrays
        if second_edge.data.data not in intermediate_data:
            # however, if intermediate_data eventually leads to
            # second_memlet.data, need to fail.
            destination_node = graph.memlet_path(second_edge)[0].src
            for source_node in intermediate_nodes:
                # NOTE: Assumes graph has networkx version
                if destination_node in nx.descendants(graph._nx, source_node):
                    return False
            continue

        # Compute second subset with respect to first subset's symbols
        sbs_permuted = dcpy(second_edge.data.subset)
        if sbs_permuted:
            # Create intermediate dicts to avoid conflicts, such as {i:j, j:i}
            symbolic.safe_replace(params_dict,
                                  lambda m, sbs=sbs_permuted: sbs.replace(m))

        # The read must be covered by some write of the same data in the
        # first map.
        provided = any(
            first_memlet.data == second_edge.data.data
            and first_memlet.subset.covers(sbs_permuted)
            for first_memlet in out_memlets)
        if not provided:
            return False

    # Checking for stencil pattern and common input/output data
    # (after fusing the maps). Views are resolved to the data they view.
    first_map_inputnodes = {
        e.src: e.src.data
        for e in graph.in_edges(first_map_entry)
        if isinstance(e.src, nodes.AccessNode)
    }
    input_views = {
        n for n in first_map_inputnodes
        if isinstance(n.desc(sdfg), data.View)
    }
    viewed_inputnodes = {}
    for v in input_views:
        del first_map_inputnodes[v]
        e = sdutil.get_view_edge(graph, v)
        if e:
            first_map_inputnodes[e.src] = e.src.data
            viewed_inputnodes[e.src.data] = v

    second_map_outputnodes = {
        e.dst: e.dst.data
        for e in graph.out_edges(second_map_exit)
        if isinstance(e.dst, nodes.AccessNode)
    }
    output_views = {
        n for n in second_map_outputnodes
        if isinstance(n.desc(sdfg), data.View)
    }
    viewed_outputnodes = {}
    for v in output_views:
        del second_map_outputnodes[v]
        e = sdutil.get_view_edge(graph, v)
        if e:
            second_map_outputnodes[e.dst] = e.dst.data
            viewed_outputnodes[e.dst.data] = v

    common_data = set(first_map_inputnodes.values()).intersection(
        set(second_map_outputnodes.values()))
    if common_data:
        input_data = [
            viewed_inputnodes[d].data if d in viewed_inputnodes else d
            for d in common_data
        ]
        input_accesses = [
            graph.memlet_path(e)[-1].data.src_subset
            for e in graph.out_edges(first_map_entry)
            if e.data.data in input_data
        ]
        # All reads of the common data in the first map must be identical
        for i, a in enumerate(input_accesses[:-1]):
            for b in input_accesses[i + 1:]:
                if not _same_access(a, b):
                    return False

        output_data = [
            viewed_outputnodes[d].data if d in viewed_outputnodes else d
            for d in common_data
        ]
        output_accesses = [
            graph.memlet_path(e)[0].data.dst_subset
            for e in graph.in_edges(second_map_exit)
            if e.data.data in output_data
        ]

        # Compute output accesses with respect to first map's symbols
        oacc_permuted = [dcpy(a) for a in output_accesses]
        for acc in oacc_permuted:
            # Create intermediate dicts to avoid conflicts, such as {i:j, j:i}
            symbolic.safe_replace(params_dict,
                                  lambda m, acc=acc: acc.replace(m))

        # Without any first-map read of the common data there is no
        # reference access to compare against; reject conservatively
        # instead of raising IndexError.
        if not input_accesses:
            return False
        first_read = input_accesses[0]
        if not all(_same_access(first_read, b) for b in oacc_permuted):
            return False

    # Success
    return True
def apply(self, sdfg: SDFG):
    """Move the start state of a nested SDFG into the parent SDFG.

    A new state is added before the state that contains the nested SDFG
    node. The nodes and edges of the nested start state are moved into it,
    with inner connector/symbol names translated to the outer names. The
    first outgoing inter-state edge of the nested start state is then
    reproduced on the new outer edge. Its condition is copied and its
    assignments are promoted to outer symbols, which are mapped back into
    the nested SDFG. Finally, the nested start state is removed and the
    destination of that edge becomes the new nested start state.

    NOTE(review): only ``out_edges(source_state)[0]`` is used, so this
    presumably relies on the match guaranteeing a single outgoing edge.

    :param sdfg: The parent SDFG (modified in place).
    """
    nsdfg: nodes.NestedSDFG = self.nsdfg(sdfg)
    state = sdfg.node(self.state_id)
    new_state = sdfg.add_state_before(state)
    isedge = sdfg.edges_between(new_state, state)[0]

    # Find relevant symbol and data descriptor mapping
    mapping: Dict[str, str] = {}
    mapping.update({k: str(v) for k, v in nsdfg.symbol_mapping.items()})
    mapping.update({
        k: next(iter(state.in_edges_by_connector(nsdfg, k))).data.data
        for k in nsdfg.in_connectors
    })
    mapping.update({
        k: next(iter(state.out_edges_by_connector(nsdfg, k))).data.data
        for k in nsdfg.out_connectors
    })

    # Get internal state and interstate edge
    source_state = nsdfg.sdfg.start_state
    nisedge = nsdfg.sdfg.out_edges(source_state)[0]

    # Add state contents (nodes)
    new_state.add_nodes_from(source_state.nodes())

    # Replace data descriptors and symbols on state graph
    for node in source_state.nodes():
        if isinstance(node, nodes.AccessNode) and node.data in mapping:
            node.data = mapping[node.data]
    for edge in source_state.edges():
        edge.data.replace(mapping)
        if edge.data.data in mapping:
            edge.data.data = mapping[edge.data.data]

    # Add state contents (edges)
    for edge in source_state.edges():
        new_state.add_edge(edge.src, edge.src_conn, edge.dst, edge.dst_conn,
                           edge.data)

    # Safe replacement of edge contents. safe_replace calls the callback
    # twice with intermediate mappings (N -> __dacesym_N -> value). The
    # callback must apply the mapping it receives, not the full `mapping`;
    # otherwise clashing renames such as {m: n, n: m} collapse.
    def replfunc(m):
        for k, v in m.items():
            nisedge.data.replace(str(k), str(v), replace_keys=False)

    symbolic.safe_replace(mapping, replfunc)

    # Add interstate edge
    for akey, aval in nisedge.data.assignments.items():
        # Map assignment to outer edge, prefixing with the node label if
        # the name is already taken in the outer SDFG
        if akey not in sdfg.symbols and akey not in sdfg.arrays:
            newname = akey
        else:
            newname = nsdfg.label + '_' + akey

        isedge.data.assignments[newname] = aval

        # Add symbol to outer SDFG
        sdfg.add_symbol(newname, nsdfg.sdfg.symbols[akey])

        # Add symbol mapping to nested SDFG
        nsdfg.symbol_mapping[akey] = newname

    isedge.data.condition = nisedge.data.condition

    # Clean nested SDFG
    nsdfg.sdfg.remove_node(source_state)

    # Set new starting state
    nsdfg.sdfg.start_state = nsdfg.sdfg.node_id(nisedge.dst)
def apply(self, outer_state: SDFGState, sdfg: SDFG):
    """Inline a multi-state nested SDFG into its parent SDFG.

    The states of the nested SDFG replace ``outer_state``. The steps are:

    - Merge global/init/exit code, environments and constants into the
      parent.
    - Register nested transients as parent data, renaming on conflict.
    - Replace nested connector names by the outer data names.
    - Rename clashing state labels and inter-state assignments.
    - Rewire the state machine: edges into ``outer_state`` enter the nested
      start state, and every edge out of it is reproduced from each nested
      sink state.

    :param outer_state: The state that contains the nested SDFG node.
    :param sdfg: The parent SDFG (modified in place).
    :return: The states of the (former) nested SDFG.
    """
    import copy

    nsdfg_node = self.nested_sdfg
    nsdfg: SDFG = nsdfg_node.sdfg

    if nsdfg_node.schedule is not dtypes.ScheduleType.Default:
        infer_types.set_default_schedule_and_storage_types(
            nsdfg, nsdfg_node.schedule)

    #######################################################
    # Collect and update top-level SDFG metadata

    # Global/init/exit code
    for loc, code in nsdfg.global_code.items():
        sdfg.append_global_code(code.code, loc)
    for loc, code in nsdfg.init_code.items():
        sdfg.append_init_code(code.code, loc)
    for loc, code in nsdfg.exit_code.items():
        sdfg.append_exit_code(code.code, loc)

    # Environments
    for nstate in nsdfg.nodes():
        for node in nstate.nodes():
            if isinstance(node, nodes.CodeNode):
                node.environments |= nsdfg_node.environments

    # Constants
    for cstname, cstval in nsdfg.constants.items():
        if cstname in sdfg.constants:
            if cstval != sdfg.constants[cstname]:
                warnings.warn('Constant value mismatch for "%s" while '
                              'inlining SDFG. Inner = %s != %s = outer' %
                              (cstname, cstval, sdfg.constants[cstname]))
        else:
            sdfg.add_constant(cstname, cstval)

    # Symbols
    outer_symbols = {str(k): v for k, v in sdfg.symbols.items()}
    for ise in sdfg.edges():
        outer_symbols.update(ise.data.new_symbols(sdfg, outer_symbols))

    # Find original source/destination edges (there is only one edge per
    # connector, according to match)
    inputs: Dict[str, MultiConnectorEdge] = {
        e.dst_conn: e
        for e in outer_state.in_edges(nsdfg_node)
    }
    outputs: Dict[str, MultiConnectorEdge] = {
        e.src_conn: e
        for e in outer_state.out_edges(nsdfg_node)
    }

    # Replace symbols using invocation symbol mapping
    # Two-step replacement (N -> __dacesym_N --> map[N]) to avoid clashes
    symbolic.safe_replace(nsdfg_node.symbol_mapping, nsdfg.replace_dict)

    # Mapping from nested transient name to top-level name
    transients: Dict[str, str] = {}

    def _register_transient(dname: str) -> None:
        """Add nested transient ``dname`` to the parent under a unique name."""
        datadesc = nsdfg.arrays[dname]
        if dname in transients or not datadesc.transient:
            return
        new_name = dname
        if (new_name in sdfg.arrays or new_name in outer_symbols
                or new_name in sdfg.constants):
            new_name = f'{nsdfg.label}_{dname}'
        transients[dname] = sdfg.add_datadesc(new_name,
                                              datadesc,
                                              find_new_name=True)

    # All transients become transients of the parent (if data already
    # exists, find new name)
    for nstate in nsdfg.nodes():
        for node in nstate.nodes():
            if isinstance(node, nodes.AccessNode):
                _register_transient(node.data)
        # All transients of edges between code nodes are also added to parent
        for edge in nstate.edges():
            if (isinstance(edge.src, nodes.CodeNode)
                    and isinstance(edge.dst, nodes.CodeNode)
                    and edge.data.data is not None):
                _register_transient(edge.data.data)

    #######################################################
    # Replace data on inlined SDFG nodes/edges

    # Replace data names with their top-level counterparts
    repldict = {}
    repldict.update(transients)
    repldict.update({
        k: v.data.data
        for k, v in itertools.chain(inputs.items(), outputs.items())
    })

    symbolic.safe_replace(repldict,
                          lambda m: replace_datadesc_names(nsdfg, m),
                          value_as_string=True)

    # TODO: Arrays whose inner shape or strides do not match the outer
    # memlet (cf. `_check_strides`) are not reshaped through views here.
    # Memlets that refer to inputs/outputs are also not offset or
    # unsqueezed to the outer subsets.
    # NOTE(review): confirm the match rejects such cases.

    # Make unique names for states. Every final label is reserved, so a
    # renamed state cannot collide with a later nested state.
    statenames = set(s.label for s in sdfg.nodes())
    for nstate in nsdfg.nodes():
        if nstate.label in statenames:
            newname = data.find_new_name(nstate.label, statenames)
            nstate.set_label(newname)
        else:
            newname = nstate.label
        statenames.add(newname)

    #######################################################
    # Collect and modify interstate edges as necessary
    outer_assignments = set()
    for e in sdfg.edges():
        outer_assignments.update(e.data.assignments.keys())

    inner_assignments = set()
    for e in nsdfg.edges():
        inner_assignments.update(e.data.assignments.keys())

    assignments_to_replace = inner_assignments & outer_assignments
    sym_replacements: Dict[str, str] = {}
    allnames = set(outer_symbols.keys()) | set(sdfg.arrays.keys())
    for assign in assignments_to_replace:
        newname = data.find_new_name(assign, allnames)
        allnames.add(newname)
        sym_replacements[assign] = newname
    nsdfg.replace_dict(sym_replacements)

    #######################################################
    # Add nested SDFG states into top-level SDFG
    outer_start_state = sdfg.start_state

    sdfg.add_nodes_from(nsdfg.nodes())
    for ise in nsdfg.edges():
        sdfg.add_edge(ise.src, ise.dst, ise.data)

    #######################################################
    # Reconnect inlined SDFG

    source = nsdfg.start_state
    sinks = nsdfg.sink_nodes()

    # Reconnect state machine
    for e in sdfg.in_edges(outer_state):
        sdfg.add_edge(e.src, source, e.data)
    for e in sdfg.out_edges(outer_state):
        for i, sink in enumerate(sinks):
            # Each new edge needs its own InterstateEdge object; sharing one
            # would make a later change to any edge affect all of them.
            edge_data = e.data if i == 0 else copy.deepcopy(e.data)
            sdfg.add_edge(sink, e.dst, edge_data)

    # Modify start state as necessary
    if outer_start_state is outer_state:
        sdfg.start_state = sdfg.node_id(source)

    # Replace nested SDFG parents with new SDFG
    for nstate in nsdfg.nodes():
        nstate.parent = sdfg
        for node in nstate.nodes():
            if isinstance(node, nodes.NestedSDFG):
                node.sdfg.parent_sdfg = sdfg
                node.sdfg.parent_nsdfg_node = node

    #######################################################
    # Remove nested SDFG and state
    sdfg.remove_node(outer_state)

    return nsdfg.nodes()