class MapFusion(pattern_matching.Transformation):
    """ Implements the MapFusion transformation.
        It will check for all patterns MapExit -> AccessNode -> MapEntry, and
        based on the following rules, fuse them and remove the transient in
        between. There are several possibilities of what it does to this
        transient in between.

        Essentially, if there is some other place in the sdfg where it is
        required, or if it is not a transient, then it will not be removed.
        In such a case, it will be linked to the MapExit node of the new fused
        map.

        Rules for fusing maps:
          0. The map range of the second map should be a permutation of the
             first map range.
          1. Each of the access nodes that are adjacent to the first map exit
             should have an edge to the second map entry. If it doesn't, then
             the second map entry should not be reachable from this access
             node.
          2. Any node that has a wcr from the first map exit should not be
             adjacent to the second map entry.
          3. Access pattern for the access nodes in the second map should be
             the same permutation of the map parameters as the map ranges of
             the two maps. Alternatively, this access node should not be
             adjacent to the first map entry.
    """
    # Pattern nodes: first map exit -> intermediate array -> second map entry
    _first_map_exit = nodes.ExitNode()
    _some_array = nodes.AccessNode("_")
    _second_map_entry = nodes.EntryNode()

    @staticmethod
    def annotates_memlets():
        """ Always returns False. """
        return False

    @staticmethod
    def expressions():
        """ Returns the single pattern graph matched by this transformation:
            the path first map exit -> access node -> second map entry.
        """
        return [
            nxutil.node_path_graph(
                MapFusion._first_map_exit,
                MapFusion._some_array,
                MapFusion._second_map_entry,
            )
        ]

    @staticmethod
    def find_permutation(first_map: nodes.Map,
                         second_map: nodes.Map) -> Union[List[int], None]:
        """ Find permutation between two map ranges.
            :param first_map: First map.
            :param second_map: Second map.
            :return: None if no such permutation exists, otherwise a list of
                     indices L such that L[x]'th parameter of second map has
                     the same range as x'th parameter of the first map.
        """
        if len(first_map.range) != len(second_map.range):
            return None

        result = []
        # Greedily pair each first-map range with the first unused, equal
        # second-map range
        for first_rng in first_map.range:
            for j, second_rng in enumerate(second_map.range):
                if first_rng == second_rng and j not in result:
                    result.append(j)
                    break
            else:
                # No unused second-map range matches this first-map range
                return None

        return result

    @staticmethod
    def _permute_subset(subset, params_dict):
        """ Renames first-map parameters in a memlet subset to the
            corresponding second-map parameters.
            :param subset: Iterable of subset dimensions; each dimension is
                           either a single expression or a list/tuple of
                           expressions.
            :param params_dict: Maps first-map parameter names to second-map
                                parameter names.
            :return: A list with one entry per dimension, where every symbol
                     named in ``params_dict`` is replaced and everything else
                     is left untouched. List/tuple dimensions are returned as
                     tuples.
        """
        def rename(expr):
            # Only map parameters are renamed; any other symbol (or
            # non-symbol expression) is kept as-is instead of raising
            if isinstance(expr, symbolic.symbol) and str(expr) in params_dict:
                return symbolic.symbol(params_dict[str(expr)])
            return expr

        permuted = []
        for dim in subset:
            if isinstance(dim, (list, tuple)):
                permuted.append(tuple(rename(expr) for expr in dim))
            else:
                permuted.append(rename(dim))
        return permuted

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Checks whether the matched maps can be fused.
            :param graph: State in which the pattern was matched.
            :param candidate: Mapping from pattern nodes to node indices in
                              ``graph``.
            :param expr_index: Unused by this check.
            :param sdfg: Unused by this check.
            :param strict: Unused by this check.
            :return: True if all fusion conditions checked below hold, False
                     otherwise.
        """
        first_map_exit = graph.nodes()[candidate[MapFusion._first_map_exit]]
        first_map_entry = graph.entry_node(first_map_exit)
        second_map_entry = graph.nodes()[candidate[
            MapFusion._second_map_entry]]

        for _in_e in graph.in_edges(first_map_exit):
            if _in_e.data.wcr is not None:
                for _out_e in graph.out_edges(second_map_entry):
                    if _out_e.data.data == _in_e.data.data:
                        # wcr is on a node that is used in the second map, quit
                        return False

        # Check whether there is a pattern map -> access -> map.
        intermediate_nodes = set()
        intermediate_data = set()
        for _, _, dst, _, _ in graph.out_edges(first_map_exit):
            if not isinstance(dst, nodes.AccessNode):
                return False
            intermediate_nodes.add(dst)
            intermediate_data.add(dst.data)

            # If array is used anywhere else in this state.
            num_occurrences = len([
                n for n in graph.nodes()
                if isinstance(n, nodes.AccessNode) and n.data == dst.data
            ])
            if num_occurrences > 1:
                return False

        # Check map ranges
        perm = MapFusion.find_permutation(first_map_entry.map,
                                          second_map_entry.map)
        if perm is None:
            return False

        # Create a dict that maps parameters of the first map to those of the
        # second map.
        params_dict = {}
        for _index, _param in enumerate(first_map_entry.map.params):
            params_dict[_param] = second_map_entry.map.params[perm[_index]]

        out_memlets = [e.data for e in graph.in_edges(first_map_exit)]

        # Check that input set of second map is provided by the output set
        # of the first map, or other unrelated maps
        for _, _, _, _, second_memlet in graph.out_edges(second_map_entry):
            # Memlets that do not come from one of the intermediate arrays
            if second_memlet.data not in intermediate_data:
                # however, if intermediate_data eventually leads to
                # second_memlet.data, need to fail.
                for _n in intermediate_nodes:
                    destination_node = graph.find_node(second_memlet.data)
                    # NOTE: Assumes graph has networkx version
                    if destination_node in nx.descendants(graph._nx, _n):
                        return False
                continue

            # The memlet is provided if some output memlet of the first map
            # writes the same data with an equivalent (permuted) subset
            provided = any(
                first_memlet.data == second_memlet.data
                and MapFusion._permute_subset(first_memlet.subset, params_dict)
                == list(second_memlet.subset) for first_memlet in out_memlets)

            # If none of the output memlets of the first map provide the info,
            # fail.
            if not provided:
                return False

        # Success
        return True

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns a string of the form "<label>: <params> -> <label>:
            <params>" describing the two matched maps.
        """
        first_exit = graph.nodes()[candidate[MapFusion._first_map_exit]]
        second_entry = graph.nodes()[candidate[MapFusion._second_map_entry]]

        return " -> ".join(entry.map.label + ": " + str(entry.map.params)
                           for entry in [first_exit, second_entry])

    def apply(self, sdfg):
        """
            This method applies the mapfusion transformation.
            Other than the removal of the second map entry node (SME), and the
            first map exit (FME) node, it has the following side effects:

            1.  Any transient adjacent to both FME and SME with degree = 2 will
                be removed. The tasklets that use/produce it shall be connected
                directly with a scalar/new transient (if the dataflow is more
                than a single scalar)

            2.  If this transient is adjacent to FME and SME and has other
                uses, it will be adjacent to the new map exit post fusion.
                Tasklet-> Tasklet edges will ALSO be added as mentioned above.

            3.  If an access node is adjacent to FME but not SME, it will be
                adjacent to new map exit post fusion.

            4.  If an access node is adjacent to SME but not FME, it will be
                adjacent to the new map entry node post fusion.

            :param sdfg: The SDFG that contains the matched state.
        """
        graph = sdfg.nodes()[self.state_id]
        first_exit = graph.nodes()[self.subgraph[MapFusion._first_map_exit]]
        first_entry = graph.entry_node(first_exit)
        second_entry = graph.nodes()[self.subgraph[
            MapFusion._second_map_entry]]
        second_exit = graph.exit_nodes(second_entry)[0]

        intermediate_nodes = set()
        for _, _, dst, _, _ in graph.out_edges(first_exit):
            intermediate_nodes.add(dst)
            assert isinstance(dst, nodes.AccessNode)

        # Check if an access node refers to non transient memory, or transient
        # is used at another location (cannot erase)
        do_not_erase = set()
        for node in intermediate_nodes:
            if (sdfg.arrays[node.data].transient is False
                    or any(e.src != first_exit
                           for e in graph.in_edges(node))
                    or any(e.dst != second_entry
                           for e in graph.out_edges(node))):
                do_not_erase.add(node)

        # Find permutation between first and second scopes
        perm = MapFusion.find_permutation(first_entry.map, second_entry.map)
        params_dict = {}
        for index, param in enumerate(first_entry.map.params):
            params_dict[param] = second_entry.map.params[perm[index]]

        # Replaces (in memlets and tasklet) the second scope map
        # indices with the permuted first map indices.
        # This works in two passes to avoid problems when e.g., exchanging two
        # parameters (instead of replacing (j,i) and (i,j) to (j,j) and then
        # i,i).
        second_scope = graph.scope_subgraph(second_entry)
        for firstp, secondp in params_dict.items():
            if firstp != secondp:
                replace(second_scope, secondp, '__' + secondp + '_fused')
        for firstp, secondp in params_dict.items():
            if firstp != secondp:
                replace(second_scope, '__' + secondp + '_fused', firstp)

        # Isolate First exit node
        ############################
        edges_to_remove = set()
        nodes_to_remove = set()
        for edge in graph.in_edges(first_exit):
            memlet_path = graph.memlet_path(edge)
            edge_index = next(i for i, e in enumerate(memlet_path)
                              if e == edge)
            access_node = memlet_path[-1].dst
            if access_node not in do_not_erase:
                out_edges = [
                    e for e in graph.out_edges(access_node)
                    if e.dst == second_entry
                ]
                # In this transformation, there can only be one edge to the
                # second map
                assert len(out_edges) == 1
                # Get source connector to the second map (drops the first
                # three characters, i.e. an 'IN_' prefix)
                connector = out_edges[0].dst_conn[3:]

                new_dst = None
                new_dst_conn = None
                # Look at the second map entry out-edges to get the new
                # destination (drops the first four characters, i.e. an
                # 'OUT_' prefix). Edges without a source connector cannot
                # match and are skipped.
                for _e in graph.out_edges(second_entry):
                    if _e.src_conn and _e.src_conn[4:] == connector:
                        new_dst = _e.dst
                        new_dst_conn = _e.dst_conn
                        break
                if new_dst is None:
                    # Access node is not used in the second map
                    nodes_to_remove.add(access_node)
                    continue

                # If the source is an access node, modify the memlet to point
                # to it
                if (isinstance(edge.src, nodes.AccessNode)
                        and edge.data.data != edge.src.data):
                    edge.data.data = edge.src.data
                    edge.data.subset = ("0" if edge.data.other_subset is None
                                        else edge.data.other_subset)
                    edge.data.other_subset = None
                else:
                    # Add a transient scalar/array
                    self.fuse_nodes(sdfg, graph, edge, new_dst, new_dst_conn)

                edges_to_remove.add(edge)

                # Remove transient node between the two maps
                nodes_to_remove.add(access_node)
            else:
                # The case where intermediate array node cannot be removed
                # Node will become an output of the second map exit
                out_e = memlet_path[edge_index + 1]
                conn = second_exit.next_connector()
                graph.add_edge(
                    second_exit,
                    'OUT_' + conn,
                    out_e.dst,
                    out_e.dst_conn,
                    dcpy(out_e.data),
                )
                second_exit.add_out_connector('OUT_' + conn)

                graph.add_edge(edge.src, edge.src_conn, second_exit,
                               'IN_' + conn, dcpy(edge.data))
                second_exit.add_in_connector('IN_' + conn)

                edges_to_remove.add(out_e)

                # If the second map needs this node, link the connector
                # that generated this to the place where it is needed, with a
                # temp transient/scalar for memlet to be generated
                for entry_out in graph.out_edges(second_entry):
                    second_memlet_path = graph.memlet_path(entry_out)
                    source_node = second_memlet_path[0].src
                    if source_node == access_node:
                        self.fuse_nodes(sdfg, graph, edge, entry_out.dst,
                                        entry_out.dst_conn)

                edges_to_remove.add(edge)
            ###

        # First scope exit is isolated and can now be safely removed
        for e in edges_to_remove:
            graph.remove_edge(e)
        graph.remove_nodes_from(nodes_to_remove)
        graph.remove_node(first_exit)

        # Isolate second_entry node
        ###########################
        for edge in graph.in_edges(second_entry):
            memlet_path = graph.memlet_path(edge)
            edge_index = next(i for i, e in enumerate(memlet_path)
                              if e == edge)
            access_node = memlet_path[0].src
            if access_node in intermediate_nodes:
                # Already handled above, can be safely removed
                graph.remove_edge(edge)
                continue

            # This is an external input to the second map which will now go
            # through the first map.
            conn = first_entry.next_connector()
            graph.add_edge(edge.src, edge.src_conn, first_entry, 'IN_' + conn,
                           dcpy(edge.data))
            first_entry.add_in_connector('IN_' + conn)
            graph.remove_edge(edge)
            out_e = memlet_path[edge_index + 1]
            graph.add_edge(
                first_entry,
                'OUT_' + conn,
                out_e.dst,
                out_e.dst_conn,
                dcpy(out_e.data),
            )
            first_entry.add_out_connector('OUT_' + conn)

            graph.remove_edge(out_e)
        ###

        # Second node is isolated and can now be safely removed
        graph.remove_node(second_entry)

        # Fix scope exit to point to the right map
        second_exit.map = first_entry.map

    def fuse_nodes(self, sdfg, graph, edge, new_dst, new_dst_conn):
        """ Fuses two nodes via memlets and possibly transient arrays.
            :param sdfg: SDFG to which the new scalar/transient is added.
            :param graph: State that contains ``edge``.
            :param edge: Edge whose memlet is redirected to the new data; its
                         memlet is modified in place.
            :param new_dst: Node that consumes the data.
            :param new_dst_conn: Connector on ``new_dst`` to connect to.
        """
        memlet_path = graph.memlet_path(edge)
        access_node = memlet_path[-1].dst

        local_name = "__s%d_n%d%s_n%d%s" % (
            self.state_id,
            graph.node_id(edge.src),
            edge.src_conn,
            graph.node_id(edge.dst),
            edge.dst_conn,
        )
        # Add intermediate memory between subgraphs. If a scalar,
        # uses direct connection. If an array, adds a transient node
        if edge.data.subset.num_elements() == 1:
            sdfg.add_scalar(
                local_name,
                dtype=access_node.desc(graph).dtype,
                transient=True,
                storage=dtypes.StorageType.Register,
            )
            edge.data.data = local_name
            edge.data.subset = "0"
            local_node = edge.src
            src_connector = edge.src_conn
        else:
            # The transient is sized from the subset before the subset is
            # overwritten below
            sdfg.add_transient(local_name,
                               edge.data.subset.size(),
                               dtype=access_node.desc(graph).dtype)
            local_node = graph.add_access(local_name)
            src_connector = None
            edge.data.data = local_name
            edge.data.subset = ",".join(
                ["0:" + str(s) for s in edge.data.subset.size()])
            # Add edge that leads to transient node
            graph.add_edge(
                edge.src,
                edge.src_conn,
                local_node,
                None,
                dcpy(edge.data),
            )
            ########

        # Add edge that leads to the second node
        graph.add_edge(local_node, src_connector, new_dst, new_dst_conn,
                       dcpy(edge.data))
def __init__(self, *args, **kwargs):
    """ Creates the per-instance entry, tasklet and exit nodes, resets
        ``pairs`` to None, and forwards all arguments to the parent
        initializer.
        NOTE(review): the enclosing class is not visible here; these nodes
        are presumably pattern placeholders -- confirm against the class.
    """
    self.entry = nodes.EntryNode()
    self.tasklet = nodes.Tasklet('_')
    self.exit = nodes.ExitNode()
    self.pairs = None
    # Attributes are set before the parent initializer runs
    super().__init__(*args, **kwargs)
def __init__(self, *args, **kwargs):
    """ Creates the per-instance entry, tasklet and exit nodes and forwards
        all arguments to the parent initializer.
        NOTE(review): the enclosing class is not visible here; these nodes
        are presumably pattern placeholders -- confirm against the class.
    """
    self._entry = nodes.EntryNode()
    self._tasklet = nodes.Tasklet('_')
    self._exit = nodes.ExitNode()
    # Attributes are set before the parent initializer runs
    super().__init__(*args, **kwargs)
class MapFission(pattern_matching.Transformation):
    """ Implements the MapFission transformation.
        Map fission refers to subsuming a map scope into its internal subgraph,
        essentially replicating the map into maps in all of its internal
        components. This also extends the dimensions of "border" transient
        arrays (i.e., those between the maps), in order to retain program
        semantics after fission.

        There are two cases that match map fission:
        1. A map with an arbitrary subgraph with more than one computational
           (i.e., non-access) node. The use of arrays connecting the
           computational nodes must be limited to the subgraph, and non
           transient arrays may not be used as "border" arrays.
        2. A map with one internal node that is a nested SDFG, in which
           each state matches the conditions of case (1).

        If a map has nested SDFGs in its subgraph, they are not considered in
        the case (1) above, and MapFission must be invoked again on the maps
        with the nested SDFGs in question.
    """
    # Pattern nodes: the map entry, and (second expression only) a nested
    # SDFG node directly following it
    _map_entry = nodes.EntryNode()
    _nested_sdfg = nodes.NestedSDFG("", OrderedDiGraph(), set(), set())

    @staticmethod
    def annotates_memlets():
        """ Always returns False. """
        return False

    @staticmethod
    def expressions():
        """ Returns the two pattern graphs of this transformation: a lone map
            entry (expression 0) and a map entry followed by a nested SDFG
            (expression 1).
        """
        return [
            nxutil.node_path_graph(MapFission._map_entry, ),
            nxutil.node_path_graph(
                MapFission._map_entry,
                MapFission._nested_sdfg,
            )
        ]

    @staticmethod
    def _components(
            subgraph: sd.SubgraphView) -> List[Tuple[nodes.Node, nodes.Node]]:
        """
        Returns the list of tuples non-array components in this subgraph.
        Each element in the list is a 2 tuple of (input node, output node) of
        the component.
        """
        graph = (subgraph
                 if isinstance(subgraph, sd.SDFGState) else subgraph.graph)
        sdict = subgraph.scope_dict(node_to_children=True)
        # Only top-level (scope None) code nodes and scope entries are
        # components; a scope entry is paired with its first exit node, any
        # other node with itself
        ns = [(n, graph.exit_nodes(n)[0])
              if isinstance(n, nodes.EntryNode) else (n, n)
              for n in sdict[None]
              if isinstance(n, (nodes.CodeNode, nodes.EntryNode))]

        return ns

    @staticmethod
    def _border_arrays(sdfg, parent, subgraph):
        """ Returns a set of array names that are local to the fission
            subgraph.
            :param sdfg: SDFG whose ``arrays`` are consulted for transience.
            :param parent: Graph used as the parent of the top-level view.
            :param subgraph: Subgraph whose top-level access nodes are
                             collected.
        """
        nested = isinstance(parent, sd.SDFGState)
        sdict = subgraph.scope_dict(node_to_children=True)
        subset = sd.SubgraphView(parent, sdict[None])
        if nested:
            # Parent is a state: only transient arrays are returned
            return set(node.data for node in subset.nodes()
                       if isinstance(node, nodes.AccessNode)
                       and sdfg.arrays[node.data].transient)
        else:
            return set(node.data for node in subset.nodes()
                       if isinstance(node, nodes.AccessNode))

    @staticmethod
    def _internal_border_arrays(total_components, subgraphs):
        """ Returns the set of border arrays that appear between computational
            components (i.e., without sources and sinks).
            :param total_components: One component list (as returned by
                                     ``_components``) per subgraph.
            :param subgraphs: The subgraphs, in the same order.
        """
        inputs = set()
        outputs = set()
        for components, subgraph in zip(total_components, subgraphs):
            for component_in, component_out in components:
                for e in subgraph.in_edges(component_in):
                    if isinstance(e.src, nodes.AccessNode):
                        inputs.add(e.src.data)
                for e in subgraph.out_edges(component_out):
                    if isinstance(e.dst, nodes.AccessNode):
                        outputs.add(e.dst.data)

        # Arrays that are both read by some component and written by some
        # component
        return inputs & outputs

    @staticmethod
    def _outside_map(node, scope_dict, entry_nodes):
        """ Returns True iff node is not in any of the scopes spanned by
            entry_nodes.
        """
        # Walk up the scope hierarchy until the top level (None) is reached
        while scope_dict[node] is not None:
            if scope_dict[node] in entry_nodes:
                return False
            node = scope_dict[node]
        return True

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Checks whether the matched map can undergo fission.
            :param graph: State in which the pattern was matched.
            :param candidate: Mapping from pattern nodes to node indices in
                              ``graph``.
            :param expr_index: 0 for a map with a subgraph, otherwise a map
                               with a nested SDFG.
            :param sdfg: SDFG that contains ``graph``.
            :param strict: Not used by this check.
            :return: True if none of the conditions tested below fails.
        """
        map_node = graph.node(candidate[MapFission._map_entry])
        nsdfg_node = None

        # If the map is dynamic-ranged, the resulting border arrays would be
        # dynamically sized
        if sd.has_dynamic_map_inputs(graph, map_node):
            return False

        if expr_index == 0:  # Map with subgraph
            subgraphs = [
                graph.scope_subgraph(map_node,
                                     include_entry=False,
                                     include_exit=False)
            ]
        else:  # Map with nested SDFG
            nsdfg_node = graph.node(candidate[MapFission._nested_sdfg])
            # Make sure there are no other internal nodes in the map
            if len(set(e.dst for e in graph.out_edges(map_node))) > 1:
                return False

            subgraphs = list(nsdfg_node.sdfg.nodes())

        # Test subgraphs
        border_arrays = set()
        total_components = []
        for sg in subgraphs:
            components = MapFission._components(sg)
            snodes = sg.nodes()
            # Test that the subgraphs have more than one computational component
            if expr_index == 0 and len(snodes) > 0 and len(components) <= 1:
                return False

            # Test that the components are connected by transients that are not
            # used anywhere else
            border_arrays |= MapFission._border_arrays(
                nsdfg_node.sdfg if expr_index == 1 else sdfg,
                sg if expr_index == 1 else graph, sg)
            total_components.append(components)

            # In nested SDFGs and subgraphs, ensure none of the border
            # values are non-transients
            for array in border_arrays:
                if expr_index == 0:
                    ndesc = sdfg.arrays[array]
                else:
                    ndesc = nsdfg_node.sdfg.arrays[array]

                if ndesc.transient is False:
                    return False

            # In subgraphs, make sure transients are not used/allocated
            # in other scopes or states
            if expr_index == 0:
                # Find all nodes not in subgraph
                not_subgraph = set(
                    n.data for n in graph.nodes()
                    if n not in snodes and isinstance(n, nodes.AccessNode))
                not_subgraph.update(
                    set(n.data for s in sdfg.nodes() if s != graph
                        for n in s.nodes() if isinstance(n, nodes.AccessNode)))

                for _, component_out in components:
                    for e in sg.out_edges(component_out):
                        if isinstance(e.dst, nodes.AccessNode):
                            if e.dst.data in not_subgraph:
                                return False

        # Fail if there are arrays inside the map that are not a direct
        # output of a computational component
        # TODO(later): Support this case? Ambiguous array sizes and memlets
        external_arrays = (
            border_arrays -
            MapFission._internal_border_arrays(total_components, subgraphs))
        if len(external_arrays) > 0:
            return False

        return True

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns the label of the matched map. """
        map_entry = graph.node(candidate[MapFission._map_entry])
        return map_entry.map.label

    def apply(self, sdfg: sd.SDFG):
        """ Applies map fission to the matched map: wraps every component in
            its own copy of the outer map, prepends the map dimensions to the
            border arrays, and removes the original map entry and exit.
            :param sdfg: The SDFG that contains the matched state.
        """
        graph: sd.SDFGState = sdfg.nodes()[self.state_id]
        map_entry = graph.node(self.subgraph[MapFission._map_entry])
        map_exit = graph.exit_nodes(map_entry)[0]
        nsdfg_node: Optional[nodes.NestedSDFG] = None

        # Obtain subgraph to perform fission to
        # Each entry is a (state, subgraph) pair; ``parent`` is the SDFG whose
        # array descriptors are modified
        if self.expr_index == 0:  # Map with subgraph
            subgraphs = [(graph,
                          graph.scope_subgraph(map_entry,
                                               include_entry=False,
                                               include_exit=False))]
            parent = sdfg
        else:  # Map with nested SDFG
            nsdfg_node = graph.node(self.subgraph[MapFission._nested_sdfg])
            subgraphs = [(state, state) for state in nsdfg_node.sdfg.nodes()]
            parent = nsdfg_node.sdfg
        # Arrays already extended by the map dimensions (each is extended
        # only once, even if it appears in several states)
        modified_arrays = set()

        # Get map information
        outer_map: nodes.Map = map_entry.map
        mapsize = outer_map.range.size()

        # Add new symbols from outer map to nested SDFG
        if self.expr_index == 1:
            map_syms = outer_map.range.free_symbols
            for edge in graph.out_edges(map_entry):
                if edge.data.data:
                    map_syms.update(edge.data.subset.free_symbols)
            for edge in graph.in_edges(map_exit):
                if edge.data.data:
                    map_syms.update(edge.data.subset.free_symbols)
            for symname, sym in map_syms.items():
                # Map parameters are not forwarded through the symbol mapping
                if symname in outer_map.params:
                    continue
                if symname not in nsdfg_node.symbol_mapping.keys():
                    nsdfg_node.symbol_mapping[symname] = sym

        for state, subgraph in subgraphs:
            components = MapFission._components(subgraph)
            sources = subgraph.source_nodes()
            sinks = subgraph.sink_nodes()

            # Collect external edges
            if self.expr_index == 0:
                external_edges_entry = list(state.out_edges(map_entry))
                external_edges_exit = list(state.in_edges(map_exit))
            else:
                # In a nested SDFG, edges touching non-transient access nodes
                # are treated as the external ones
                external_edges_entry = [
                    e for e in subgraph.edges()
                    if (isinstance(e.src, nodes.AccessNode)
                        and not nsdfg_node.sdfg.arrays[e.src.data].transient)
                ]
                external_edges_exit = [
                    e for e in subgraph.edges()
                    if (isinstance(e.dst, nodes.AccessNode)
                        and not nsdfg_node.sdfg.arrays[e.dst.data].transient)
                ]

            # Map external edges to outer memlets
            edge_to_outer = {}
            for edge in external_edges_entry:
                if self.expr_index == 0:
                    # Subgraphs use the corresponding outer map edges
                    path = state.memlet_path(edge)
                    eindex = path.index(edge)
                    edge_to_outer[edge] = path[eindex - 1]
                else:
                    # Nested SDFGs use the internal map edges of the node
                    outer_edge = next(e for e in graph.in_edges(nsdfg_node)
                                      if e.dst_conn == edge.src.data)
                    edge_to_outer[edge] = outer_edge

            for edge in external_edges_exit:
                if self.expr_index == 0:
                    path = state.memlet_path(edge)
                    eindex = path.index(edge)
                    edge_to_outer[edge] = path[eindex + 1]
                else:
                    # Nested SDFGs use the internal map edges of the node
                    outer_edge = next(e for e in graph.out_edges(nsdfg_node)
                                      if e.src_conn == edge.dst.data)
                    edge_to_outer[edge] = outer_edge

            # Collect all border arrays and code->code edges
            arrays = MapFission._border_arrays(
                nsdfg_node.sdfg if self.expr_index == 1 else sdfg, state,
                subgraph)
            # Maps data name -> list of direct component-to-code-node edges
            scalars = defaultdict(list)
            for _, component_out in components:
                for e in subgraph.out_edges(component_out):
                    if isinstance(e.dst, nodes.CodeNode):
                        scalars[e.data.data].append(e)

            # Create new arrays for scalars
            for scalar, edges in scalars.items():
                desc = parent.arrays[scalar]
                name, newdesc = parent.add_temp_transient(
                    mapsize,
                    desc.dtype,
                    desc.storage,
                    toplevel=desc.toplevel,
                    debuginfo=desc.debuginfo,
                    allow_conflicts=desc.allow_conflicts)

                # Add extra nodes in component boundaries
                # (each direct edge is replaced by src -> new access node ->
                # dst, indexed by the outer map parameters)
                for edge in edges:
                    anode = state.add_access(name)
                    state.add_edge(
                        edge.src, edge.src_conn, anode, None,
                        mm.Memlet(
                            name, outer_map.range.num_elements(),
                            subsets.Range.from_string(','.join(
                                outer_map.params)), 1))
                    state.add_edge(
                        anode, None, edge.dst, edge.dst_conn,
                        mm.Memlet(
                            name, outer_map.range.num_elements(),
                            subsets.Range.from_string(','.join(
                                outer_map.params)), 1))
                    state.remove_edge(edge)

            # Add extra maps around components
            new_map_entries = []
            for component_in, component_out in components:
                # The map is created with a placeholder '0:1' range; the
                # actual range is copied from the outer map below
                me, mx = state.add_map(outer_map.label + '_fission',
                                       [(p, '0:1') for p in outer_map.params],
                                       outer_map.schedule,
                                       unroll=outer_map.unroll,
                                       debuginfo=outer_map.debuginfo)

                # Add dynamic input connectors
                for conn in map_entry.in_connectors:
                    if not conn.startswith('IN_'):
                        me.add_in_connector(conn)

                me.map.range = dcpy(outer_map.range)
                new_map_entries.append(me)

                # Reconnect edges through new map
                for e in state.in_edges(component_in):
                    state.add_edge(me, None, e.dst, e.dst_conn, dcpy(e.data))
                    # Reconnect inner edges at source directly to external nodes
                    if self.expr_index == 0 and e in external_edges_entry:
                        state.add_edge(edge_to_outer[e].src,
                                       edge_to_outer[e].src_conn, me, None,
                                       dcpy(edge_to_outer[e].data))
                    else:
                        state.add_edge(e.src, e.src_conn, me, None,
                                       dcpy(e.data))
                    state.remove_edge(e)
                # Empty memlet edge in nested SDFGs
                if state.in_degree(component_in) == 0:
                    state.add_edge(me, None, component_in, None,
                                   mm.EmptyMemlet())

                for e in state.out_edges(component_out):
                    state.add_edge(e.src, e.src_conn, mx, None, dcpy(e.data))
                    # Reconnect inner edges at sink directly to external nodes
                    if self.expr_index == 0 and e in external_edges_exit:
                        state.add_edge(mx, None, edge_to_outer[e].dst,
                                       edge_to_outer[e].dst_conn,
                                       dcpy(edge_to_outer[e].data))
                    else:
                        state.add_edge(mx, None, e.dst, e.dst_conn,
                                       dcpy(e.data))
                    state.remove_edge(e)
                # Empty memlet edge in nested SDFGs
                if state.out_degree(component_out) == 0:
                    state.add_edge(component_out, None, mx, None,
                                   mm.EmptyMemlet())

            # Connect other sources/sinks not in components (access nodes)
            # directly to external nodes
            if self.expr_index == 0:
                for node in sources:
                    if isinstance(node, nodes.AccessNode):
                        for edge in state.in_edges(node):
                            outer_edge = edge_to_outer[edge]
                            memlet = dcpy(edge.data)
                            # Prepend the full outer map range to the subset
                            memlet.subset = subsets.Range(
                                outer_map.range.ranges + memlet.subset.ranges)
                            state.add_edge(outer_edge.src, outer_edge.src_conn,
                                           edge.dst, edge.dst_conn, memlet)

                for node in sinks:
                    if isinstance(node, nodes.AccessNode):
                        for edge in state.out_edges(node):
                            outer_edge = edge_to_outer[edge]
                            state.add_edge(edge.src, edge.src_conn,
                                           outer_edge.dst, outer_edge.dst_conn,
                                           dcpy(outer_edge.data))

            # Augment arrays by prepending map dimensions
            for array in arrays:
                if array in modified_arrays:
                    continue
                desc = parent.arrays[array]
                # Innermost map dimension first: each new leading stride is
                # the total size accumulated so far
                for sz in reversed(mapsize):
                    desc.strides = [desc.total_size] + list(desc.strides)
                    desc.total_size = desc.total_size * sz

                desc.shape = mapsize + list(desc.shape)
                desc.offset = [0] * len(mapsize) + list(desc.offset)
                modified_arrays.add(array)

            # Fill scope connectors so that memlets can be tracked below
            state.fill_scope_connectors()

            # Correct connectors and memlets in nested SDFGs to account for
            # missing outside map
            if self.expr_index == 1:
                to_correct = ([(e, e.src) for e in external_edges_entry] +
                              [(e, e.dst) for e in external_edges_exit])
                corrected_nodes = set()
                for edge, node in to_correct:
                    if isinstance(node, nodes.AccessNode):
                        # Each access node is corrected only once
                        if node in corrected_nodes:
                            continue
                        corrected_nodes.add(node)

                        outer_edge = edge_to_outer[edge]
                        desc = parent.arrays[node.data]

                        # Modify shape of internal array to match outer one
                        outer_desc = sdfg.arrays[outer_edge.data.data]
                        if not isinstance(desc, dt.Scalar):
                            desc.shape = outer_desc.shape
                        if isinstance(desc, dt.Array):
                            desc.strides = outer_desc.strides
                            desc.total_size = outer_desc.total_size

                        # Inside the nested SDFG, offset all memlets to include
                        # the offsets from within the map.
                        # NOTE: Relies on propagation to fix outer memlets
                        for internal_edge in state.all_edges(node):
                            for e in state.memlet_tree(internal_edge):
                                e.data.subset.offset(desc.offset, False)
                                e.data.subset = helpers.unsqueeze_memlet(
                                    e.data, outer_edge.data).subset

                        # Only after offsetting memlets we can modify the
                        # overall offset
                        if isinstance(desc, dt.Array):
                            desc.offset = outer_desc.offset

            # Fill in memlet trees for border transients
            # NOTE: Memlet propagation should run to correct the outer edges
            for node in subgraph.nodes():
                if isinstance(node, nodes.AccessNode) and node.data in arrays:
                    for edge in state.all_edges(node):
                        for e in state.memlet_tree(edge):
                            # Prepend map dimensions to memlet
                            e.data.subset = subsets.Range(
                                [(d, d, 1) for d in outer_map.params] +
                                e.data.subset.ranges)

        # If nested SDFG, reconnect nodes around map and modify memlets
        if self.expr_index == 1:
            for edge in graph.in_edges(map_entry):
                # Only edges arriving at an 'IN_' connector are rerouted
                if not edge.dst_conn or not edge.dst_conn.startswith('IN_'):
                    continue

                # Modify edge coming into nested SDFG to include entire array
                desc = sdfg.arrays[edge.data.data]
                edge.data.subset = subsets.Range.from_array(desc)
                edge.data.num_accesses = edge.data.subset.num_elements()

                # Find matching edge inside map
                # (compares connector names after dropping the 3-character
                # 'IN_' prefix and a 4-character prefix, presumably 'OUT_')
                inner_edge = next(
                    e for e in graph.out_edges(map_entry)
                    if e.src_conn and e.src_conn[4:] == edge.dst_conn[3:])
                graph.add_edge(edge.src, edge.src_conn, nsdfg_node,
                               inner_edge.dst_conn, dcpy(edge.data))

            for edge in graph.out_edges(map_exit):
                # Modify edge coming out of nested SDFG to include entire array
                desc = sdfg.arrays[edge.data.data]
                edge.data.subset = subsets.Range.from_array(desc)

                # Find matching edge inside map
                inner_edge = next(e for e in graph.in_edges(map_exit)
                                  if e.dst_conn[3:] == edge.src_conn[4:])
                graph.add_edge(nsdfg_node, inner_edge.src_conn, edge.dst,
                               edge.dst_conn, dcpy(edge.data))

        # Remove outer map
        graph.remove_nodes_from([map_entry, map_exit])
class MapFusion(pattern_matching.Transformation):
    """ Implements the MapFusion transformation.
        It will check for all patterns MapExit -> AccessNode -> MapEntry, and
        based on the following rules, fuse them and remove the transient in
        between. There are several possibilities of what it does to this
        transient in between.

        Essentially, if there is some other place in the sdfg where it is
        required, or if it is not a transient, then it will not be removed.
        In such a case, it will be linked to the MapExit node of the new
        fused map.

        Rules for fusing maps:
          0. The map range of the second map should be a permutation of the
             first map range.
          1. Each of the access nodes that are adjacent to the first map exit
             should have an edge to the second map entry. If it doesn't, then
             the second map entry should not be reachable from this access
             node.
          2. Any node that has a wcr from the first map exit should not be
             adjacent to the second map entry.
          3. Access pattern for the access nodes in the second map should be
             the same permutation of the map parameters as the map ranges of
             the two maps. Alternatively, this access node should not be
             adjacent to the first map entry.
    """

    # Pattern nodes: first map exit -> access node -> second map entry
    _first_map_exit = nodes.ExitNode()
    _some_array = nodes.AccessNode("_")
    _second_map_entry = nodes.EntryNode()

    @staticmethod
    def annotates_memlets():
        """ This transformation reports that it does not annotate memlets. """
        return False

    @staticmethod
    def expressions():
        """ Returns the single pattern graph matched by this transformation:
            the path first map exit -> access node -> second map entry. """
        return [
            nxutil.node_path_graph(
                MapFusion._first_map_exit,
                MapFusion._some_array,
                MapFusion._second_map_entry,
            )
        ]

    @staticmethod
    def find_permutation(first_map: nodes.Map,
                         second_map: nodes.Map) -> Union[List[int], None]:
        """ Find permutation between two map ranges.
            :param first_map: First map.
            :param second_map: Second map.
            :return: None if no such permutation exists, otherwise a list of
                     indices L such that L[x]'th parameter of second map has
                     the same range as x'th parameter of the first map.
        """
        if len(first_map.range) != len(second_map.range):
            return None

        result = []
        # Greedily match every range of the first map with an equal, not yet
        # used range of the second map
        for first_rng in first_map.range:
            for j, second_rng in enumerate(second_map.range):
                if first_rng == second_rng and j not in result:
                    result.append(j)
                    break
            else:
                # No unused range of the second map matches this one
                return None

        return result

    @staticmethod
    def _num_access_nodes(graph, data_name):
        """ Counts the access nodes in `graph` that refer to `data_name`. """
        return sum(1 for n in graph.nodes()
                   if isinstance(n, nodes.AccessNode) and n.data == data_name)

    @staticmethod
    def _remap_symbol(sym, params_dict):
        """ Returns the second-map counterpart of `sym` if it is a symbol
            named like a key of `params_dict`; any other value is returned
            unchanged. """
        if isinstance(sym, symbolic.symbol) and str(sym) in params_dict:
            return symbolic.symbol(params_dict[str(sym)])
        return sym

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Checks whether the two matched maps can be fused.
            :param graph: State in which the pattern was matched.
            :param candidate: Mapping from pattern nodes to node indices in
                              `graph`.
            :param expr_index: Index of the matched expression (unused here).
            :param sdfg: SDFG whose array descriptors are consulted.
            :param strict: Unused by this check.
            :return: True if the fusion rules checked below hold, False
                     otherwise.
        """
        first_map_exit = graph.nodes()[candidate[MapFusion._first_map_exit]]
        first_map_entry = graph.entry_node(first_map_exit)
        second_map_entry = graph.nodes()[candidate[
            MapFusion._second_map_entry]]

        for _in_e in graph.in_edges(first_map_exit):
            if _in_e.data.wcr is not None:
                for _out_e in graph.out_edges(second_map_entry):
                    if _out_e.data.data == _in_e.data.data:
                        # wcr is on a node that is used in the second map, quit
                        return False

        # Check whether there is a pattern map -> access -> map.
        intermediate_nodes = set()
        intermediate_data = set()
        for _, _, dst, _, _ in graph.out_edges(first_map_exit):
            if not isinstance(dst, nodes.AccessNode):
                return False
            intermediate_nodes.add(dst)
            intermediate_data.add(dst.data)

            # apply() refuses to fuse when a transient intermediate array has
            # another access node in this state; reject the match here so
            # that a reported match is always actually applied.
            if (sdfg.arrays[dst.data].transient is not False
                    and MapFusion._num_access_nodes(graph, dst.data) > 1):
                return False

        # Check map ranges
        perm = MapFusion.find_permutation(first_map_entry.map,
                                          second_map_entry.map)
        if perm is None:
            return False

        # Create a dict that maps parameters of the first map to those of the
        # second map.
        params_dict = {}
        for _index, _param in enumerate(first_map_entry.map.params):
            params_dict[_param] = second_map_entry.map.params[perm[_index]]

        out_memlets = [e.data for e in graph.in_edges(first_map_exit)]

        # Check that input set of second map is provided by the output set
        # of the first map, or other unrelated maps
        for _, _, _, _, second_memlet in graph.out_edges(second_map_entry):
            # Memlets that do not come from one of the intermediate arrays
            if second_memlet.data not in intermediate_data:
                # however, if intermediate_data eventually leads to
                # second_memlet.data, need to fail.
                # The lookup does not depend on the intermediate node, so it
                # is done once (the matched pattern guarantees at least one
                # intermediate node).
                destination_node = graph.find_node(second_memlet.data)
                for source_node in intermediate_nodes:
                    # NOTE: Assumes graph has networkx version
                    if destination_node in nx.descendants(
                            graph._nx, source_node):
                        return False
                continue

            provided = False
            for first_memlet in out_memlets:
                if first_memlet.data != second_memlet.data:
                    continue

                # If there is an equivalent subset, it is provided.
                # Symbols that are not parameters of the first map are kept
                # as they are instead of raising a KeyError.
                expected_second_subset = []
                for _tup in first_memlet.subset:
                    if isinstance(_tup, (list, tuple)):
                        expected_second_subset.append(
                            tuple(
                                MapFusion._remap_symbol(_sym, params_dict)
                                for _sym in _tup))
                    else:
                        expected_second_subset.append(
                            MapFusion._remap_symbol(_tup, params_dict))

                if expected_second_subset == list(second_memlet.subset):
                    provided = True
                    break

            # If none of the output memlets of the first map provide the info,
            # fail.
            if not provided:
                return False

        # Success
        return True

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns "<label>: <params> -> <label>: <params>" for the maps of
            the matched first map exit and second map entry. """
        first_exit = graph.nodes()[candidate[MapFusion._first_map_exit]]
        second_entry = graph.nodes()[candidate[MapFusion._second_map_entry]]

        return " -> ".join(entry.map.label + ": " + str(entry.map.params)
                           for entry in [first_exit, second_entry])

    def _connect_via_local(self, sdfg, graph, edge, access_node, new_dst,
                           new_dst_conn):
        """ Routes the data carried by `edge` (an in-edge of the first map
            exit) to `new_dst` inside the fused scope through a new local
            container: a register scalar if the memlet subset has exactly one
            element, otherwise a transient with the size of the subset.
            NOTE: `edge.data` is modified in place (data name and subset).
            :param sdfg: SDFG in which the scalar is registered.
            :param graph: State that is being modified.
            :param edge: Edge whose source produces the data.
            :param access_node: Access node that provides the data type.
            :param new_dst: Node that consumes the data.
            :param new_dst_conn: Connector of `new_dst` to connect to.
        """
        local_name = "__s%d_n%d%s_n%d%s" % (
            self.state_id,
            graph.node_id(edge._src),
            edge.src_conn,
            graph.node_id(edge._dst),
            edge.dst_conn,
        )
        if edge.data.subset.num_elements() == 1:
            # We will add a scalar
            sdfg.add_scalar(
                local_name,
                dtype=access_node.desc(graph).dtype,
                toplevel=False,
                transient=True,
                storage=dtypes.StorageType.Register,
            )
            edge.data.data = local_name
            edge.data.subset = "0"
            graph.add_edge(
                edge._src,
                edge.src_conn,
                new_dst,
                new_dst_conn,
                dcpy(edge.data),
            )
        else:
            # We will add a transient of size = memlet subset size.
            # The state-level call is used because its result is used as a
            # node of `graph` in the edges added below.
            local_node = graph.add_transient(
                local_name,
                edge.data.subset.size(),
                dtype=access_node.desc(graph).dtype,
                toplevel=False,
            )
            edge.data.data = local_name
            edge.data.subset = ",".join(
                ["0:" + str(_s) for _s in edge.data.subset.size()])
            graph.add_edge(
                edge._src,
                edge.src_conn,
                local_node,
                None,
                dcpy(edge.data),
            )
            graph.add_edge(local_node, None, new_dst, new_dst_conn,
                           dcpy(edge.data))

    def apply(self, sdfg):
        """
            This method applies the mapfusion transformation.
            Other than the removal of the second map entry node (SME), and the
            first map exit (FME) node, it has the following side effects:

            1.  Any transient adjacent to both FME and SME with degree = 2
                will be removed. The tasklets that use/produce it shall be
                connected directly with a scalar/new transient (if the
                dataflow is more than a single scalar)

            2.  If this transient is adjacent to FME and SME and has other
                uses, it will be adjacent to the new map exit post fusion.
                Tasklet-> Tasklet edges will ALSO be added as mentioned above.

            3.  If an access node is adjacent to FME but not SME, it will be
                adjacent to new map exit post fusion.

            4.  If an access node is adjacent to SME but not FME, it will be
                adjacent to the new map entry node post fusion.

            :param sdfg: SDFG that contains the matched state.
            :return: False (without modifying the state) if a transient
                     intermediate array has another access node in the
                     state; None otherwise.
        """
        graph = sdfg.nodes()[self.state_id]
        first_exit = graph.nodes()[self.subgraph[MapFusion._first_map_exit]]
        first_entry = graph.entry_node(first_exit)
        second_entry = graph.nodes()[self.subgraph[
            MapFusion._second_map_entry]]
        second_exit = graph.exit_nodes(second_entry)[0]

        intermediate_nodes = set()
        for _, _, dst, _, _ in graph.out_edges(first_exit):
            intermediate_nodes.add(dst)
            assert isinstance(dst, nodes.AccessNode)

        # Check if an access node refers to non transient memory, or transient
        # is used at another location (cannot erase)
        do_not_erase = set()
        for node in intermediate_nodes:
            if sdfg.arrays[node.data].transient is False:
                do_not_erase.add(node)
            else:
                # If array is used anywhere else in this state.
                # can_be_applied rejects this case; the guard remains as a
                # defensive no-op for callers that invoke apply directly.
                if MapFusion._num_access_nodes(graph, node.data) > 1:
                    return False

                for edge in graph.in_edges(node):
                    if edge.src != first_exit:
                        do_not_erase.add(node)
                        break
                else:
                    for edge in graph.out_edges(node):
                        if edge.dst != second_entry:
                            do_not_erase.add(node)
                            break

        # Find permutation between first and second scopes
        if first_entry.map.params != second_entry.map.params:
            perm = MapFusion.find_permutation(first_entry.map,
                                              second_entry.map)
            params_dict = {}
            for _index, _param in enumerate(first_entry.map.params):
                params_dict[_param] = second_entry.map.params[perm[_index]]

            # Hopefully replaces (in memlets and tasklet) the second scope map
            # indices with the permuted first map indices
            second_scope = graph.scope_subgraph(second_entry)
            for _firstp, _secondp in params_dict.items():
                replace(second_scope, _secondp, _firstp)

        ########Isolate First MapExit node###########
        for _edge in graph.in_edges(first_exit):
            __some_str = _edge.data.data
            _access_node = graph.find_node(__some_str)
            # all outputs of first_exit are in intermediate_nodes set, so all
            # inputs to first_exit should also be!
            if _access_node not in do_not_erase:
                _new_dst = None
                _new_dst_conn = None
                # look at the second map entry out-edges to get the new
                # destination
                for _e in graph.out_edges(second_entry):
                    if _e.data.data == _access_node.data:
                        _new_dst = _e.dst
                        _new_dst_conn = _e.dst_conn
                        break
                if _new_dst is None:
                    # Access node is not even used in the second map
                    graph.remove_node(_access_node)
                    continue

                if _edge.data.data == _access_node.data and isinstance(
                        _edge._src, nodes.AccessNode):
                    _edge.data.data = _edge._src.data
                    _edge.data.subset = "0"
                    graph.add_edge(
                        _edge._src,
                        _edge.src_conn,
                        _new_dst,
                        _new_dst_conn,
                        dcpy(_edge.data),
                    )
                else:
                    self._connect_via_local(sdfg, graph, _edge, _access_node,
                                            _new_dst, _new_dst_conn)
                graph.remove_edge(_edge)

                ####Isolate this node#####
                for _in_e in graph.in_edges(_access_node):
                    graph.remove_edge(_in_e)
                for _out_e in graph.out_edges(_access_node):
                    graph.remove_edge(_out_e)
                graph.remove_node(_access_node)
            else:
                # _access_node will become an output of the second map exit
                for _out_e in graph.out_edges(first_exit):
                    if _out_e.data.data == _access_node.data:
                        graph.add_edge(
                            second_exit,
                            None,
                            _out_e._dst,
                            _out_e.dst_conn,
                            dcpy(_out_e.data),
                        )
                        graph.remove_edge(_out_e)
                        break
                else:
                    raise AssertionError(
                        "No out-edge was found that leads to {}".format(
                            _access_node))

                # The copy for the new map exit is taken before the memlet of
                # _edge is rewritten below.
                graph.add_edge(_edge._src, _edge.src_conn, second_exit, None,
                               dcpy(_edge.data))

                ### If the second map needs this node then link the connector
                # that generated this to the place where it is needed, with a
                # temp transient/scalar for memlet to be generated
                for _out_e in graph.out_edges(second_entry):
                    if _out_e.data.data == _access_node.data:
                        self._connect_via_local(sdfg, graph, _edge,
                                                _access_node, _out_e._dst,
                                                _out_e.dst_conn)
                        break
                graph.remove_edge(_edge)
        graph.remove_node(first_exit)  # Take a leap of faith

        #############Isolate second_entry node################
        for _edge in graph.in_edges(second_entry):
            _access_node = graph.find_node(_edge.data.data)
            if _access_node in intermediate_nodes:
                # Already handled above, just remove this
                graph.remove_edge(_edge)
                continue
            else:
                # This is an external input to the second map which will now
                # go through the first map.
                graph.add_edge(_edge._src, _edge.src_conn, first_entry, None,
                               dcpy(_edge.data))
                graph.remove_edge(_edge)
                for _out_e in graph.out_edges(second_entry):
                    if _out_e.data.data == _access_node.data:
                        graph.add_edge(
                            first_entry,
                            None,
                            _out_e._dst,
                            _out_e.dst_conn,
                            dcpy(_out_e.data),
                        )
                        graph.remove_edge(_out_e)
                        break
                else:
                    raise AssertionError(
                        "No out-edge was found that leads to {}".format(
                            _access_node))

        graph.remove_node(second_entry)

        # Fix scope exit
        second_exit.map = first_entry.map
        graph.fill_scope_connectors()
class MergeArrays(pattern_matching.Transformation):
    """ Merge duplicate arrays connected to the same scope entry. """

    # Pattern nodes: two access nodes that both feed one scope entry
    _array1 = nodes.AccessNode("_")
    _array2 = nodes.AccessNode("_")
    _map_entry = nodes.EntryNode()

    @staticmethod
    def expressions():
        """ Returns the single pattern graph: two access nodes, each with an
            edge to the same entry node. """
        # Matching
        #   o  o
        #   |  |
        # /======\

        g = SDFGState()
        g.add_node(MergeArrays._array1)
        g.add_node(MergeArrays._array2)
        g.add_node(MergeArrays._map_entry)
        g.add_edge(MergeArrays._array1, None, MergeArrays._map_entry, None,
                   memlet.EmptyMemlet())
        g.add_edge(MergeArrays._array2, None, MergeArrays._map_entry, None,
                   memlet.EmptyMemlet())
        return [g]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Checks whether the two matched access nodes can be merged.
            :param graph: State in which the pattern was matched.
            :param candidate: Mapping from pattern nodes to node IDs in
                              `graph`.
            :param expr_index: Index of the matched expression (unused here).
            :param sdfg: Unused by this check.
            :param strict: Unused by this check.
            :return: True if all checks below pass, False otherwise.
        """
        arr1_id = candidate[MergeArrays._array1]
        arr2_id = candidate[MergeArrays._array2]

        # Ensure both arrays contain the same data
        arr1 = graph.node(arr1_id)
        arr2 = graph.node(arr2_id)
        if arr1.data != arr2.data:
            return False

        arr1_in_degree = graph.in_degree(arr1)
        arr2_in_degree = graph.in_degree(arr2)

        # Ensure only arr1's node ID contains incoming edges
        if arr1_in_degree == 0 and arr2_in_degree > 0:
            return False

        # Ensure arr1 and arr2's node IDs are ordered (avoid duplicates)
        if (arr1_in_degree == 0 and arr2_in_degree == 0
                and arr1_id >= arr2_id):
            return False

        map_entry = graph.node(candidate[MergeArrays._map_entry])

        # If arr1's connector leads directly to map, skip it
        if all(e.dst_conn and not e.dst_conn.startswith('IN_')
               for e in graph.edges_between(arr1, map_entry)):
            return False

        # Both access nodes may only lead to the matched scope entry
        if (any(e.dst != map_entry for e in graph.out_edges(arr1))
                or any(e.dst != map_entry for e in graph.out_edges(arr2))):
            return False

        # Ensure arr1 and arr2 are the first two incoming nodes (avoid further
        # duplicates)
        all_source_nodes = set(
            graph.node_id(e.src) for e in graph.in_edges(map_entry)
            if e.src != arr1 and e.src != arr2 and e.dst_conn
            and e.dst_conn.startswith('IN_') and graph.in_degree(e.src) == 0)
        if any(nid < arr1_id or nid < arr2_id for nid in all_source_nodes):
            return False

        return True

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns "<data> (<array1 id>, <array2 id>) -> <entry label>" for
            the matched nodes. """
        arr = graph.node(candidate[MergeArrays._array1])
        map_entry = graph.node(candidate[MergeArrays._map_entry])
        return '%s (%d, %d) -> %s' % (arr.data,
                                      candidate[MergeArrays._array1],
                                      candidate[MergeArrays._array2],
                                      map_entry.label)

    def apply(self, sdfg):
        """ Redirects the inner edges of all duplicate source access nodes to
            the connector of the first matched array, then removes the
            duplicate nodes and their connectors from the scope entry.
            :param sdfg: SDFG that contains the matched state.
        """
        graph = sdfg.node(self.state_id)
        array = graph.node(self.subgraph[MergeArrays._array1])
        map_entry = graph.node(self.subgraph[MergeArrays._map_entry])
        map_edge = next(e for e in graph.out_edges(array)
                        if e.dst == map_entry)
        # Connector name without its 'IN_' prefix
        result_connector = map_edge.dst_conn[3:]

        # Find all other incoming access nodes without incoming edges
        source_edges = [
            e for e in graph.in_edges(map_entry)
            if isinstance(e.src, nodes.AccessNode)
            and e.src.data == array.data and e.src != array and e.dst_conn
            and e.dst_conn.startswith('IN_') and graph.in_degree(e.src) == 0
        ]

        # Modify connectors to point to first array
        connectors_to_remove = set()
        for e in source_edges:
            connector = e.dst_conn[3:]
            connectors_to_remove.add(connector)
            for inner_edge in graph.out_edges(map_entry):
                # Out-edges without a source connector cannot belong to the
                # removed connector and must not be sliced
                if (inner_edge.src_conn
                        and inner_edge.src_conn[4:] == connector):
                    inner_edge._src_conn = 'OUT_' + result_connector

        # Remove other nodes from state
        graph.remove_nodes_from(set(e.src for e in source_edges))

        # Remove connectors from scope entry
        map_entry.in_connectors -= set('IN_' + c
                                       for c in connectors_to_remove)
        map_entry.out_connectors -= set('OUT_' + c
                                        for c in connectors_to_remove)