class TrivialTaskletElimination(transformation.SingleStateTransformation):
    """
    Trivial-Tasklet Elimination.

    Matches the path ``read -> tasklet -> write`` and removes the tasklet when
    its whole body is a plain forwarding of its only input connector to its
    only output connector and the outgoing memlet has no write-conflict
    resolution (WCR).
    """

    read = transformation.PatternNode(nodes.AccessNode)
    tasklet = transformation.PatternNode(nodes.Tasklet)
    write = transformation.PatternNode(nodes.AccessNode)

    @classmethod
    def expressions(cls):
        """Return the single pattern graph: read -> tasklet -> write."""
        return [sdutil.node_path_graph(cls.read, cls.tasklet, cls.write)]

    def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
        """
        Return True iff the matched tasklet is a pure copy.

        The checks, in order: neither endpoint is a stream, the tasklet has
        exactly one incoming and one outgoing edge, the outgoing memlet has
        no WCR, the tasklet has exactly one connector on each side, and its
        code is literally ``<out_conn> = <in_conn>``.
        """
        src_node, code_node, dst_node = self.read, self.tasklet, self.write

        # Streams are never bypassed
        if isinstance(sdfg.arrays[src_node.data], data.Stream):
            return False
        if isinstance(sdfg.arrays[dst_node.data], data.Stream):
            return False

        # Exactly one edge on each side of the tasklet
        if len(graph.in_edges(code_node)) != 1:
            return False
        if len(graph.out_edges(code_node)) != 1:
            return False

        # A conflict-resolving write is not a plain copy
        if graph.edges_between(code_node, dst_node)[0].data.wcr:
            return False

        # Exactly one connector on each side of the tasklet
        if len(code_node.in_connectors) != 1:
            return False
        if len(code_node.out_connectors) != 1:
            return False

        (source_conn, ) = code_node.in_connectors.keys()
        (target_conn, ) = code_node.out_connectors.keys()
        return code_node.code.as_string == f'{target_conn} = {source_conn}'

    def apply(self, graph, sdfg):
        """
        Replace the tasklet by a direct read -> write edge.

        The new edge reuses the memlet of the tasklet's outgoing edge and
        records the subset of the incoming memlet as its ``other_subset``.
        """
        src_node, code_node, dst_node = self.read, self.tasklet, self.write

        incoming = graph.edges_between(src_node, code_node)[0]
        outgoing = graph.edges_between(code_node, dst_node)[0]
        graph.remove_edge(incoming)
        graph.remove_edge(outgoing)

        forwarded = outgoing.data
        forwarded.other_subset = incoming.data.subset
        graph.add_nedge(src_node, dst_node, forwarded)
        graph.remove_node(code_node)
class FalseConditionElimination(transformation.MultiStateTransformation):
    """
    Removes a state transition whose condition always evaluates to false.
    """

    state_a = transformation.PatternNode(sdfg.SDFGState)
    state_b = transformation.PatternNode(sdfg.SDFGState)

    @classmethod
    def expressions(cls):
        """Return the single pattern graph: state_a -> state_b."""
        return [sdutil.node_path_graph(cls.state_a, cls.state_b)]

    def can_be_applied(self, graph: SDFG, expr_index, sdfg: SDFG, permissive=False):
        """
        Return True iff the edge state_a -> state_b has no assignments, is
        conditional, its condition compares equal to ``False``, and state_b
        has more than one incoming edge.
        """
        src_state: SDFGState = self.state_a
        dst_state: SDFGState = self.state_b

        # With a single incoming edge the case is left to DeadStateElimination
        if len(graph.in_edges(dst_state)) <= 1:
            return False

        # A directed graph has only one edge between two nodes
        transition = graph.edges_between(src_state, dst_state)[0].data
        if transition.assignments or transition.is_unconditional():
            return False

        # Symbolic evaluation; `==` (not `is`) is intentional here
        return True if transition.condition_sympy() == False else False

    def apply(self, _, sdfg: SDFG):
        """Remove the edge state_a -> state_b from the SDFG."""
        src_state: SDFGState = self.state_a
        dst_state: SDFGState = self.state_b
        sdfg.remove_edge(sdfg.edges_between(src_state, dst_state)[0])
class StartStateElimination(transformation.MultiStateTransformation):
    """
    Start-state elimination removes a redundant state that has one outgoing
    edge and no contents. This transformation applies only to nested SDFGs.
    """

    start_state = transformation.PatternNode(SDFGState)

    @classmethod
    def expressions(cls):
        """Return the single pattern graph consisting of one state."""
        return [sdutil.node_path_graph(cls.start_state)]

    def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
        """
        Return True iff the matched state is an empty start state of a nested
        SDFG with exactly one unconditional outgoing edge whose assignments
        do not reference any data descriptor name.
        """
        candidate = self.start_state

        # Top-level SDFGs (no parent) are never touched
        if not graph.parent:
            return False

        # The state must be empty
        if candidate.number_of_nodes() > 0:
            return False

        outgoing = graph.out_edges(candidate)
        incoming = graph.in_edges(candidate)

        # A start state has no predecessors, and we require a single,
        # unconditional successor edge
        if len(incoming) != 0:
            return False
        if len(outgoing) != 1:
            return False
        transition = outgoing[0].data
        if not transition.is_unconditional():
            return False

        # Assignments that make descriptors into symbols cannot be eliminated
        return not any(graph.arrays.keys() & symbolic.free_symbols_and_functions(rhs)
                       for rhs in transition.assignments.values())

    def apply(self, _, sdfg):
        """
        Copy the outgoing edge's assignments into the symbol mapping of the
        parent nested-SDFG node, then remove the start state.
        """
        candidate = self.start_state
        nsdfg_node = sdfg.parent_nsdfg_node
        transition = sdfg.out_edges(candidate)[0].data
        for symbol_name, value in transition.assignments.items():
            nsdfg_node.symbol_mapping[symbol_name] = value
        sdfg.remove_node(candidate)
class RedundantArrayCopying3(pm.SingleStateTransformation):
    """
    Implements the redundant array removal transformation. Removes multiples
    of array B in pattern MapEntry -> B.
    """

    map_entry = pm.PatternNode(nodes.MapEntry)
    out_array = pm.PatternNode(nodes.AccessNode)

    @classmethod
    def expressions(cls):
        """Return the single pattern graph: map_entry -> out_array."""
        return [sdutil.node_path_graph(cls.map_entry, cls.out_array)]

    def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
        """
        Return True iff the map entry feeds at least one access node, other
        than the matched one, that refers to the same data container.
        """
        entry = self.map_entry
        kept = self.out_array

        def is_duplicate(node):
            return (isinstance(node, nodes.AccessNode) and node != kept and node.data == kept.data)

        return any(is_duplicate(edge.dst) for edge in graph.out_edges(entry))

    def apply(self, graph, sdfg):
        """
        Redirect the outgoing edges of every duplicate access node so that
        they leave from the matched access node, then remove the duplicate
        together with the edge that connected it to the map entry.
        """
        entry = self.map_entry
        kept = self.out_array

        for entry_edge in graph.out_edges(entry):
            duplicate = entry_edge.dst
            if not isinstance(duplicate, nodes.AccessNode):
                continue
            if not (duplicate != kept and duplicate.data == kept.data):
                continue

            for consumer_edge in graph.out_edges(duplicate):
                graph.add_edge(kept, None, consumer_edge.dst, consumer_edge.dst_conn, consumer_edge.data)
                graph.remove_edge(consumer_edge)
            graph.remove_edge(entry_edge)
            graph.remove_node(duplicate)
class TrueConditionElimination(transformation.MultiStateTransformation, transformation.SimplifyPass):
    """
    Replaces a state transition condition that is always true by the
    trivial condition.
    """

    state_a = transformation.PatternNode(sdfg.SDFGState)
    state_b = transformation.PatternNode(sdfg.SDFGState)

    @classmethod
    def expressions(cls):
        """Return the single pattern graph: state_a -> state_b."""
        return [sdutil.node_path_graph(cls.state_a, cls.state_b)]

    def can_be_applied(self, graph: SDFG, expr_index, sdfg: SDFG, permissive=False):
        """
        Return True iff the edge state_a -> state_b is conditional and its
        condition compares equal to ``True``.
        """
        src_state: SDFGState = self.state_a
        dst_state: SDFGState = self.state_b

        # A directed graph has only one edge between two nodes
        transition = graph.edges_between(src_state, dst_state)[0].data
        if transition.is_unconditional():
            return False

        # Symbolic evaluation; `==` (not `is`) is intentional here
        return True if transition.condition_sympy() == True else False

    def apply(self, _, sdfg: SDFG):
        """Overwrite the edge condition with the code block ``"1"``."""
        src_state: SDFGState = self.state_a
        dst_state: SDFGState = self.state_b
        transition = sdfg.edges_between(src_state, dst_state)[0].data
        transition.condition = CodeBlock("1")
class Reduction1Operation(pm.Transformation):
    """
    Detects reduction1 operations.

    A map matches when every memlet leaving its exit node carries a
    write-conflict resolution and targets either a scalar or a single
    element of an array/view. The transformation is detection-only:
    ``apply`` does nothing.
    """

    map_entry = pm.PatternNode(nodes.MapEntry)

    @staticmethod
    def expressions():
        """Return the single pattern graph consisting of one map entry."""
        return [sdutil.node_path_graph(Reduction1Operation.map_entry)]

    @staticmethod
    def can_be_applied(graph: dace.SDFGState,
                       candidate: Dict[pm.PatternNode, int],
                       expr_index: int,
                       sdfg: dace.SDFG,
                       permissive: bool = False):
        """
        Return True iff all outputs of the matched map are WCR writes to a
        scalar or to exactly one element of an array/view.

        :param graph: State containing the matched map.
        :param candidate: Mapping from pattern nodes to node IDs in ``graph``.
        :param expr_index: Index of the matched expression (unused).
        :param sdfg: SDFG that owns the data descriptors.
        :param permissive: Unused.
        """
        map_entry = graph.node(candidate[Reduction1Operation.map_entry])
        map_exit = graph.exit_node(map_entry)

        # Each output edge is judged on its own, so a single pass suffices
        # and descriptors never need to be hashed or grouped.
        for edge in graph.out_edges(map_exit):
            memlet = edge.data
            if not memlet.wcr:
                return False

            desc = sdfg.arrays[memlet.data]
            if isinstance(desc, dace.data.Scalar):
                continue
            if not isinstance(desc, (dace.data.Array, dace.data.View)):
                return False
            if memlet.subset.num_elements() != 1:
                return False

        return True

    @staticmethod
    def match_to_str(graph: dace.SDFGState, candidate: Dict[pm.PatternNode, int]) -> str:
        """Return ``"<map label>: <map params>"`` for the matched map."""
        map_entry = graph.node(candidate[Reduction1Operation.map_entry])
        return map_entry.map.label + ': ' + str(map_entry.map.params)

    def apply(self, sdfg: dace.SDFG):
        """Intentionally a no-op; this class only detects the pattern."""
        pass
class EndStateElimination(transformation.MultiStateTransformation, transformation.SimplifyPass):
    """
    End-state elimination removes a redundant state that has one incoming
    edge and no contents.
    """

    end_state = transformation.PatternNode(SDFGState)

    @classmethod
    def expressions(cls):
        """Return the single pattern graph consisting of one state."""
        return [sdutil.node_path_graph(cls.end_state)]

    def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
        """
        Return True iff the matched state is empty, has no outgoing edges
        and exactly one incoming edge, and that edge is unconditional.
        """
        candidate = self.end_state
        outgoing = graph.out_edges(candidate)
        incoming = graph.in_edges(candidate)

        # An end state has no successors; we also require a single predecessor
        if len(outgoing) != 0 or len(incoming) != 1:
            return False

        # The only way in must be unconditional
        if not incoming[0].data.is_unconditional():
            return False

        # Only empty states can be eliminated
        return not (candidate.number_of_nodes() > 0)

    def apply(self, _, sdfg):
        """
        Remove the end state, then drop every symbol that was assigned on its
        incoming edge and is still reported among the SDFG's free symbols.
        """
        candidate = self.end_state

        # Remember the symbols assigned on the edge that disappears with the state
        assigned_symbols = sdfg.in_edges(candidate)[0].data.assignments.keys()
        sdfg.remove_node(candidate)

        # Remove orphan symbols
        for symbol_name in assigned_symbols:
            if symbol_name in sdfg.free_symbols:
                sdfg.remove_symbol(symbol_name)
class DeadStateElimination(transformation.MultiStateTransformation):
    """
    Dead state elimination removes an unreachable state and all of its
    dominated states.
    """

    end_state = transformation.PatternNode(sdfg.SDFGState)

    @classmethod
    def expressions(cls):
        """Return the single pattern graph consisting of one state."""
        return [sdutil.node_path_graph(cls.end_state)]

    def can_be_applied(self, graph: SDFG, expr_index, sdfg: SDFG, permissive=False):
        """
        Return True iff the matched state has exactly one incoming edge, that
        edge has no assignments, is conditional, and its condition compares
        equal to ``False``.
        """
        candidate: SDFGState = self.end_state
        incoming = graph.in_edges(candidate)

        # Only states with a single way in are considered
        if len(incoming) != 1:
            return False

        transition = incoming[0].data
        # Edges carrying assignments, or without a condition, are left alone
        if transition.assignments or transition.is_unconditional():
            return False

        # Symbolic evaluation; `==` (not `is`) is intentional here
        return True if transition.condition_sympy() == False else False

    def apply(self, _, sdfg: SDFG):
        """
        Remove the matched state and every state whose dominator set
        contains it.
        """
        candidate = self.end_state
        dominators = cfg.all_dominators(sdfg)

        doomed = set()
        for other_state, dominated_by in dominators.items():
            if candidate in dominated_by:
                doomed.add(other_state)
        doomed.add(candidate)

        sdfg.remove_nodes_from(doomed)
class MapDimShuffle(transformation.Transformation):
    """
    Implements the map-dim shuffle transformation.

    MapDimShuffle takes a map and a list of params. It reorders the
    dimensions in the map such that it matches the list.
    """

    _map_entry = transformation.PatternNode(nodes.MapEntry)

    # Properties
    parameters = ShapeProperty(dtype=list, default=None, desc="Desired order of map parameters")

    @staticmethod
    def expressions():
        """Return the single pattern graph consisting of one map entry."""
        return [sdutil.node_path_graph(MapDimShuffle._map_entry)]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, permissive=False):
        """Every map entry matches; validation of the order is in ``apply``."""
        return True

    @staticmethod
    def match_to_str(graph, candidate):
        """Return ``"<map label>: <map params>"`` for the matched map."""
        map_entry = graph.nodes()[candidate[MapDimShuffle._map_entry]]
        return map_entry.map.label + ': ' + str(map_entry.map.params)

    def apply(self, sdfg: SDFG):
        """
        Reorder the parameters and ranges of the matched map to follow
        ``self.parameters``.

        Nothing is changed unless ``self.parameters`` is a permutation of the
        map's current parameters. In particular, an unset (``None``) order,
        or one with missing, extra or repeated names, leaves the map as is.
        """
        graph = sdfg.nodes()[self.state_id]
        map_entry = graph.nodes()[self.subgraph[self._map_entry]]

        # The property defaults to None; treat "no order given" as a no-op
        # instead of failing inside set().
        if self.parameters is None:
            return

        # Copy so that the map does not share its list with the property
        new_order = list(self.parameters)
        current = map_entry.map.params

        # Require a true permutation: set equality alone would let repeated
        # names through and leave the map with more ranges than parameters.
        if len(new_order) != len(current) or set(new_order) != set(current):
            return

        range_of = dict(zip(current, map_entry.range.ranges))
        map_entry.range.ranges = [range_of[param] for param in new_order]
        map_entry.map.params = new_order
class TrivialMapRangeElimination(transformation.SingleStateTransformation):
    """
    Implements the Trivial Map Range Elimination pattern.

    Takes a multi-dimensional map in which some dimension's range contains
    one element and removes that dimension.

    Example: Map[i=0:I,j=0] -> Map[i=0:I]
    """

    map_entry = transformation.PatternNode(nodes.MapEntry)

    @classmethod
    def expressions(cls):
        """Return the single pattern graph consisting of one map entry."""
        return [sdutil.node_path_graph(cls.map_entry)]

    def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
        """
        Return True iff the map has more than one dimension and at least one
        of them has equal begin and end.
        """
        map_range = self.map_entry.map.range

        # Only acts on multi-dimensional maps
        if len(map_range) <= 1:
            return False

        for begin, end, _ in map_range:
            if begin == end:
                return True
        return False

    def apply(self, graph, sdfg):
        """
        Drop every dimension whose begin equals its end, after replacing the
        corresponding map parameter by that value in the map's scope
        subgraph.
        """
        entry = self.map_entry
        kept_params = []
        kept_ranges = []

        for param, dim_range in zip(entry.map.params, entry.map.range.ranges):
            begin, end, _ = dim_range
            if begin == end:
                # The index variable can only take this one value
                graph.scope_subgraph(entry).replace(param, begin)
                continue
            kept_ranges.append(dim_range)
            kept_params.append(param)

        entry.map.range.ranges = kept_ranges
        entry.map.params = kept_params
class OTFMapFusion(transformation.SingleStateTransformation):
    """
    Performs fusion of two maps by replicating the contents of the first into
    the second map until all the input dependencies (memlets) of the second
    one are met.

    The matched pattern is ``first_map_exit -> array -> second_map_entry``.
    """
    first_map_exit = transformation.PatternNode(nds.ExitNode)
    array = transformation.PatternNode(nds.AccessNode)
    second_map_entry = transformation.PatternNode(nds.EntryNode)

    @staticmethod
    def annotates_memlets():
        """Return False."""
        return False

    @classmethod
    def expressions(cls):
        """Return the single pattern graph: exit -> array -> entry."""
        return [
            sdutil.node_path_graph(cls.first_map_exit, cls.array,
                                   cls.second_map_entry)
        ]

    def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
        """
        Return True iff no edge into the first map exit has a WCR and every
        node fed by the first map exit is a transient access node that is
        written only by the first map exit and read only by the second map
        entry.
        """
        # WCR: not supported on first map
        for _in_e in graph.in_edges(self.first_map_exit):
            if _in_e.data.wcr is not None:
                return False

        # Check intermediate nodes between both maps.
        for _, _, node, _, _ in graph.out_edges(self.first_map_exit):
            # Only map -> array -> map
            if not isinstance(node, nds.AccessNode):
                return False

            # Non-transient blocks removal of first map
            if not sdfg.arrays[node.data].transient:
                return False

            # Check that array is not co-produced by other parent map.
            producers = set(map(lambda edge: edge.src, graph.in_edges(node)))
            for prod in producers:
                if prod != self.first_map_exit:
                    return False

            # Check that array is not co-consumed by other child map
            consumers = set(map(lambda edge: edge.dst, graph.out_edges(node)))
            for cons in consumers:
                if cons != self.second_map_entry:
                    return False

        # Success
        return True

    def apply(self, graph: SDFGState, sdfg: SDFG):
        """
        Fuse the first map into the second one.

        Steps: collect the access nodes written by the first map exit, rewire
        the first map's inputs to the second map entry, replicate the first
        map's contents inside the second map, remove the first map (entry,
        exit and everything between them), and finally remove access nodes
        that ended up with no edges at all.
        """
        first_map_entry = graph.entry_node(self.first_map_exit)

        # Access nodes sitting between the two maps
        intermediate_dnodes = set()
        for _, _, node, _, _ in graph.out_edges(self.first_map_exit):
            if not isinstance(node, nds.AccessNode):
                continue

            intermediate_dnodes.add(node)

        self._update_in_connectors(graph, intermediate_dnodes)
        self._replicate_first_map(sdfg, graph, first_map_entry,
                                  intermediate_dnodes)

        # The original first map is no longer needed
        graph.remove_nodes_from(
            graph.all_nodes_between(first_map_entry, self.first_map_exit)
            | {first_map_entry, self.first_map_exit})

        # Clean up access nodes left without any edge
        for node in graph.nodes():
            if not isinstance(node, nds.AccessNode):
                continue

            if graph.in_degree(node) == 0 and graph.out_degree(node) == 0:
                graph.remove_node(node)

    def _update_in_connectors(self, graph, intermediate_dnodes):
        """
        Detach the intermediate access nodes from the second map entry and
        route every input of the first map entry into the second map entry.

        The new in-connectors on the second map entry are named after the
        first map's connectors with a trailing underscore.

        :raises ValueError: if ``add_in_connector`` reports failure.
        """
        first_map_entry = graph.entry_node(self.first_map_exit)
        for dnode in intermediate_dnodes:
            for edge in graph.edges_between(dnode, self.second_map_entry):
                graph.remove_edge_and_connectors(edge)

        for edge in graph.in_edges(first_map_entry):
            if self.second_map_entry.add_in_connector(edge.dst_conn + "_"):
                graph.add_edge(edge.src, edge.src_conn, self.second_map_entry,
                               edge.dst_conn + "_", edge.data)
            else:
                raise ValueError("Failed to connect")

    def _replicate_first_map(self, sdfg, graph, first_map_entry,
                             intermediate_dnodes):
        """
        For every intermediate array and every distinct read offset of that
        array in the second map, copy the contents of the first map, feed
        the copy from the second map entry (with memlets shifted by the
        offset), and deliver its result to the original readers through a
        new transient scalar.
        """
        for dnode in intermediate_dnodes:
            array_name = dnode.data
            array = sdfg.arrays[array_name]

            # NOTE: this call also removes the matching read edges from the
            # second map entry
            read_offsets = self._read_offsets(graph, array_name)

            # Replicate first map tasklets once for each read offset access and
            # connect them to other tasklets accordingly
            for offset, edges in read_offsets.items():
                new_nodes = self._copy_first_map_contents(
                    sdfg, graph, first_map_entry)
                tmp_name = "__otf"
                tmp_name, _ = sdfg.add_scalar(tmp_name,
                                              array.dtype,
                                              transient=True,
                                              find_new_name=True)
                tmp_access = graph.add_access(tmp_name)

                for node in new_nodes:
                    # Outputs of the copy go to the scalar instead of the
                    # first map exit
                    for edge in graph.edges_between(node, self.first_map_exit):
                        graph.add_edge(edge.src, edge.src_conn, tmp_access,
                                       None, Memlet(tmp_name))
                        graph.remove_edge(edge)

                    # Inputs of the copy come from the second map entry,
                    # shifted by the read offset
                    for edge in graph.edges_between(first_map_entry, node):
                        memlet = dcpy(edge.data)
                        memlet.subset.offset(list(offset), negative=False)
                        self.second_map_entry.add_out_connector(edge.src_conn +
                                                                "_")
                        graph.add_edge(self.second_map_entry,
                                       edge.src_conn + "_", node,
                                       edge.dst_conn, memlet)
                        graph.remove_edge(edge)

                # Former readers of the array now read the scalar
                for edge in edges:
                    graph.add_edge(tmp_access, None, edge.dst, edge.dst_conn,
                                   Memlet(tmp_name))

    def _read_offsets(self, state, array_name):
        """
        Compute offsets of read accesses in second map.

        Removes every outgoing edge of the second map entry whose memlet
        refers to ``array_name`` (together with its out-connector) and groups
        the removed edges by their offset relative to the write memlet.

        :return: A dictionary mapping offset tuples to lists of removed edges.
        """
        # Get output memlet of first tasklet
        output_edges = state.in_edges(self.first_map_exit)
        # Exactly one edge into the first map exit is supported here
        assert len(output_edges) == 1
        write_memlet = output_edges[0].data

        # Find read offsets by looping over second map entry connectors
        offsets = defaultdict(list)
        for edge in state.out_edges(self.second_map_entry):
            if edge.data.data == array_name:
                self.second_map_entry.remove_out_connector(edge.src_conn)
                state.remove_edge(edge)
                offset = OTFMapFusion._memlet_offsets(write_memlet, edge.data)
                offsets[offset].append(edge)

        return offsets

    def _copy_first_map_contents(self, sdfg, graph, first_map_entry):
        """
        Deep-copy all nodes between the first map entry and exit into the
        graph and duplicate every edge touching them.

        Copied access nodes that refer to transient scalars get a freshly
        named transient scalar, and memlets on the duplicated edges are
        renamed accordingly.

        :return: The list of newly added nodes.
        """
        inter_nodes = list(
            graph.all_nodes_between(first_map_entry, self.first_map_exit) -
            {first_map_entry})
        new_inter_nodes = [dcpy(node) for node in inter_nodes]
        # Old transient scalar name -> name of its fresh replacement
        tmp_map = dict()
        for node in new_inter_nodes:
            if isinstance(node, nds.AccessNode):
                data = sdfg.arrays[node.data]
                if isinstance(data, dt.Scalar) and data.transient:
                    tmp_name = sdfg.temp_data_name()
                    sdfg.add_scalar(tmp_name, data.dtype, transient=True)
                    tmp_map[node.data] = tmp_name
                    node.data = tmp_name
            graph.add_node(node)
        # Node ID of each original -> node ID of its copy
        id_map = {
            graph.node_id(old): graph.node_id(new)
            for old, new in zip(inter_nodes, new_inter_nodes)
        }

        def map_node(node):
            # Original node -> its copy
            return graph.node(id_map[graph.node_id(node)])

        def map_memlet(memlet):
            # Copy of the memlet, pointing at the renamed scalar if any
            memlet = dcpy(memlet)
            memlet.data = tmp_map.get(memlet.data, memlet.data)
            return memlet

        # Duplicate every edge with at least one endpoint among the originals;
        # endpoints outside the copied set are kept as they are
        for edge in graph.edges():
            if edge.src in inter_nodes or edge.dst in inter_nodes:
                src = map_node(
                    edge.src) if edge.src in inter_nodes else edge.src
                dst = map_node(
                    edge.dst) if edge.dst in inter_nodes else edge.dst
                edge_data = map_memlet(edge.data)
                graph.add_edge(src, edge.src_conn, dst, edge.dst_conn,
                               edge_data)

        return new_inter_nodes

    @staticmethod
    def _memlet_offsets(base_memlet, offset_memlet):
        """
        Compute subset offset of `offset_memlet` relative to `base_memlet`.

        :return: A tuple with one integer offset per dimension.
        """
        def offset(base_range, offset_range):
            b0, e0, s0 = base_range
            b1, e1, s1 = offset_range
            # Begin and end must be shifted by the same amount, with equal
            # steps, for a single integer offset to describe the difference
            assert e1 - e0 == b1 - b0 and s0 == s1
            return int(e1 - e0)

        return tuple(
            offset(b, o) for b, o in zip(base_memlet.subset.ranges,
                                         offset_memlet.subset.ranges))
class MapFission(transformation.SingleStateTransformation):
    """
    Implements the MapFission transformation.

    Map fission refers to subsuming a map scope into its internal subgraph,
    essentially replicating the map into maps in all of its internal
    components. This also extends the dimensions of "border" transient
    arrays (i.e., those between the maps), in order to retain program
    semantics after fission.

    There are two cases that match map fission:

    1. A map with an arbitrary subgraph with more than one computational
       (i.e., non-access) node. The use of arrays connecting the
       computational nodes must be limited to the subgraph, and non
       transient arrays may not be used as "border" arrays.
    2. A map with one internal node that is a nested SDFG, in which
       each state matches the conditions of case (1).

    If a map has nested SDFGs in its subgraph, they are not considered in
    the case (1) above, and MapFission must be invoked again on the maps
    with the nested SDFGs in question.
    """
    map_entry = transformation.PatternNode(nodes.EntryNode)
    nested_sdfg = transformation.PatternNode(nodes.NestedSDFG)

    @staticmethod
    def annotates_memlets():
        """Return False."""
        return False

    @classmethod
    def expressions(cls):
        """
        Return the two pattern graphs: a lone map entry (expression 0) and a
        map entry followed by a nested SDFG (expression 1).
        """
        return [
            sdutil.node_path_graph(cls.map_entry),
            sdutil.node_path_graph(cls.map_entry, cls.nested_sdfg),
        ]

    @staticmethod
    def _components(
            subgraph: gr.SubgraphView) -> List[Tuple[nodes.Node, nodes.Node]]:
        """
        Returns the list of tuples non-array components in this subgraph.
        Each element in the list is a 2 tuple of (input node, output node) of
        the component.

        Only top-scope code nodes and entry nodes are considered; for an
        entry node the output node is its exit node, otherwise input and
        output are the same node.
        """
        graph = (subgraph
                 if isinstance(subgraph, sd.SDFGState) else subgraph.graph)
        schildren = subgraph.scope_children()
        ns = [(n, graph.exit_node(n)) if isinstance(n, nodes.EntryNode) else
              (n, n) for n in schildren[None]
              if isinstance(n, (nodes.CodeNode, nodes.EntryNode))]

        return ns

    @staticmethod
    def _border_arrays(sdfg, parent, subgraph):
        """
        Returns a set of array names that are local to the fission
        subgraph.

        When ``parent`` is an SDFG state, only transient arrays are
        returned; otherwise every accessed array in the top scope is.
        """
        nested = isinstance(parent, sd.SDFGState)
        schildren = subgraph.scope_children()
        subset = gr.SubgraphView(parent, schildren[None])
        if nested:
            return set(node.data for node in subset.nodes()
                       if isinstance(node, nodes.AccessNode)
                       and sdfg.arrays[node.data].transient)
        else:
            return set(node.data for node in subset.nodes()
                       if isinstance(node, nodes.AccessNode))

    @staticmethod
    def _internal_border_arrays(total_components, subgraphs):
        """
        Returns the set of border arrays that appear between computational
        components (i.e., without sources and sinks).

        These are the arrays that are both read by some component and
        written by some component.
        """
        inputs = set()
        outputs = set()

        for components, subgraph in zip(total_components, subgraphs):
            for component_in, component_out in components:
                for e in subgraph.in_edges(component_in):
                    if isinstance(e.src, nodes.AccessNode):
                        inputs.add(e.src.data)
                for e in subgraph.out_edges(component_out):
                    if isinstance(e.dst, nodes.AccessNode):
                        outputs.add(e.dst.data)

        return inputs & outputs

    @staticmethod
    def _outside_map(node, scope_dict, entry_nodes):
        """
        Returns True iff node is not in any of the scopes spanned by
        entry_nodes.
        """
        # Walk up the scope chain until the top-level scope is reached
        while scope_dict[node] is not None:
            if scope_dict[node] in entry_nodes:
                return False
            node = scope_dict[node]
        return True

    def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
        """
        Return True iff the matched map can be fissioned.

        ``expr_index`` 0 is the "map with subgraph" case and 1 is the "map
        with nested SDFG" case.
        """
        map_node = self.map_entry
        nsdfg_node = None

        # If the map is dynamic-ranged, the resulting border arrays would be
        # dynamically sized
        if sd.has_dynamic_map_inputs(graph, map_node):
            return False

        if expr_index == 0:  # Map with subgraph
            subgraphs = [
                graph.scope_subgraph(map_node,
                                     include_entry=False,
                                     include_exit=False)
            ]
        else:  # Map with nested SDFG
            nsdfg_node = self.nested_sdfg
            # Make sure there are no other internal nodes in the map
            if len(set(e.dst for e in graph.out_edges(map_node))) > 1:
                return False

            # Every state of the nested SDFG is tested as a subgraph
            subgraphs = list(nsdfg_node.sdfg.nodes())

        # Test subgraphs
        border_arrays = set()
        total_components = []
        for sg in subgraphs:
            components = self._components(sg)
            snodes = sg.nodes()
            # Test that the subgraphs have more than one computational component
            if expr_index == 0 and len(snodes) > 0 and len(components) <= 1:
                return False

            # Test that the components are connected by transients that are not
            # used anywhere else
            border_arrays |= self._border_arrays(
                nsdfg_node.sdfg if expr_index == 1 else sdfg,
                sg if expr_index == 1 else graph, sg)
            total_components.append(components)

            # In nested SDFGs and subgraphs, ensure none of the border
            # values are non-transients
            for array in border_arrays:
                if expr_index == 0:
                    ndesc = sdfg.arrays[array]
                else:
                    ndesc = nsdfg_node.sdfg.arrays[array]

                if ndesc.transient is False:
                    return False

            # In subgraphs, make sure transients are not used/allocated
            # in other scopes or states
            if expr_index == 0:
                # Find all nodes not in subgraph
                not_subgraph = set(
                    n.data for n in graph.nodes()
                    if n not in snodes and isinstance(n, nodes.AccessNode))
                not_subgraph.update(
                    set(n.data for s in sdfg.nodes() if s != graph
                        for n in s.nodes() if isinstance(n, nodes.AccessNode)))

                for _, component_out in components:
                    for e in sg.out_edges(component_out):
                        if isinstance(e.dst, nodes.AccessNode):
                            if e.dst.data in not_subgraph:
                                return False

        # Fail if there are arrays inside the map that are not a direct
        # output of a computational component
        # TODO(later): Support this case? Ambiguous array sizes and memlets
        external_arrays = (
            border_arrays -
            self._internal_border_arrays(total_components, subgraphs))
        if len(external_arrays) > 0:
            return False

        return True

    def apply(self, graph: sd.SDFGState, sdfg: sd.SDFG):
        """
        Perform the fission: wrap every computational component in its own
        copy of the outer map, prepend the map dimensions to the border
        arrays and their memlets, and remove the outer map.

        NOTE: comments in the body state that memlet propagation is expected
        to run afterwards to correct outer memlets.
        """
        map_entry = self.map_entry
        map_exit = graph.exit_node(map_entry)
        nsdfg_node: Optional[nodes.NestedSDFG] = None

        # Obtain subgraph to perform fission to
        if self.expr_index == 0:  # Map with subgraph
            subgraphs = [(graph,
                          graph.scope_subgraph(map_entry,
                                               include_entry=False,
                                               include_exit=False))]
            parent = sdfg
        else:  # Map with nested SDFG
            nsdfg_node = self.nested_sdfg
            subgraphs = [(state, state) for state in nsdfg_node.sdfg.nodes()]
            parent = nsdfg_node.sdfg
        # Names of arrays that already received the extra map dimensions
        modified_arrays = set()

        # Get map information
        outer_map: nodes.Map = map_entry.map
        mapsize = outer_map.range.size()

        # Add new symbols from outer map to nested SDFG
        if self.expr_index == 1:
            map_syms = outer_map.range.free_symbols
            for edge in graph.out_edges(map_entry):
                if edge.data.data:
                    map_syms.update(edge.data.subset.free_symbols)
            for edge in graph.in_edges(map_exit):
                if edge.data.data:
                    map_syms.update(edge.data.subset.free_symbols)
            for sym in map_syms:
                symname = str(sym)
                # Map parameters are handled separately below
                if symname in outer_map.params:
                    continue
                if symname not in nsdfg_node.symbol_mapping.keys():
                    nsdfg_node.symbol_mapping[symname] = sym
                    nsdfg_node.sdfg.symbols[
                        symname] = graph.symbols_defined_at(
                            nsdfg_node)[symname]

            # Remove map symbols from nested mapping
            for name in outer_map.params:
                if str(name) in nsdfg_node.symbol_mapping:
                    del nsdfg_node.symbol_mapping[str(name)]
                if str(name) in nsdfg_node.sdfg.symbols:
                    del nsdfg_node.sdfg.symbols[str(name)]

        for state, subgraph in subgraphs:
            components = MapFission._components(subgraph)
            sources = subgraph.source_nodes()
            sinks = subgraph.sink_nodes()

            # Collect external edges
            if self.expr_index == 0:
                external_edges_entry = list(state.out_edges(map_entry))
                external_edges_exit = list(state.in_edges(map_exit))
            else:
                # In a nested SDFG, edges touching non-transient access
                # nodes are the ones that communicate with the outside
                external_edges_entry = [
                    e for e in subgraph.edges()
                    if (isinstance(e.src, nodes.AccessNode)
                        and not nsdfg_node.sdfg.arrays[e.src.data].transient)
                ]
                external_edges_exit = [
                    e for e in subgraph.edges()
                    if (isinstance(e.dst, nodes.AccessNode)
                        and not nsdfg_node.sdfg.arrays[e.dst.data].transient)
                ]

            # Map external edges to outer memlets
            edge_to_outer = {}
            for edge in external_edges_entry:
                if self.expr_index == 0:
                    # Subgraphs use the corresponding outer map edges
                    path = state.memlet_path(edge)
                    eindex = path.index(edge)
                    edge_to_outer[edge] = path[eindex - 1]
                else:
                    # Nested SDFGs use the internal map edges of the node
                    outer_edge = next(e for e in graph.in_edges(nsdfg_node)
                                      if e.dst_conn == edge.src.data)
                    edge_to_outer[edge] = outer_edge

            for edge in external_edges_exit:
                if self.expr_index == 0:
                    path = state.memlet_path(edge)
                    eindex = path.index(edge)
                    edge_to_outer[edge] = path[eindex + 1]
                else:
                    # Nested SDFGs use the internal map edges of the node
                    outer_edge = next(e for e in graph.out_edges(nsdfg_node)
                                      if e.src_conn == edge.dst.data)
                    edge_to_outer[edge] = outer_edge

            # Collect all border arrays and code->code edges
            arrays = MapFission._border_arrays(
                nsdfg_node.sdfg if self.expr_index == 1 else sdfg, state,
                subgraph)
            scalars = defaultdict(list)
            for _, component_out in components:
                for e in subgraph.out_edges(component_out):
                    if isinstance(e.dst, nodes.CodeNode):
                        scalars[e.data.data].append(e)

            # Create new arrays for scalars
            for scalar, edges in scalars.items():
                desc = parent.arrays[scalar]
                # The old descriptor is replaced by a map-sized transient
                del parent.arrays[scalar]
                name, newdesc = parent.add_transient(
                    scalar,
                    mapsize,
                    desc.dtype,
                    desc.storage,
                    lifetime=desc.lifetime,
                    debuginfo=desc.debuginfo,
                    allow_conflicts=desc.allow_conflicts,
                    find_new_name=True)

                # Add extra nodes in component boundaries
                for edge in edges:
                    anode = state.add_access(name)
                    sbs = subsets.Range.from_string(','.join(outer_map.params))
                    # Offset memlet by map range begin (to fit the transient)
                    sbs.offset([r[0] for r in outer_map.range], True)
                    state.add_edge(
                        edge.src, edge.src_conn, anode, None,
                        mm.Memlet.simple(
                            name,
                            sbs,
                            num_accesses=outer_map.range.num_elements()))
                    state.add_edge(
                        anode, None, edge.dst, edge.dst_conn,
                        mm.Memlet.simple(
                            name,
                            sbs,
                            num_accesses=outer_map.range.num_elements()))
                    state.remove_edge(edge)

            # Add extra maps around components
            new_map_entries = []
            for component_in, component_out in components:
                # The placeholder range '0:1' is overwritten just below with
                # a copy of the outer map's range
                me, mx = state.add_map(outer_map.label + '_fission',
                                       [(p, '0:1') for p in outer_map.params],
                                       outer_map.schedule,
                                       unroll=outer_map.unroll,
                                       debuginfo=outer_map.debuginfo)

                # Add dynamic input connectors
                for conn in map_entry.in_connectors:
                    if not conn.startswith('IN_'):
                        me.add_in_connector(conn)

                me.map.range = dcpy(outer_map.range)
                new_map_entries.append(me)

                # Reconnect edges through new map
                for e in state.in_edges(component_in):
                    state.add_edge(me, None, e.dst, e.dst_conn, dcpy(e.data))
                    # Reconnect inner edges at source directly to external nodes
                    if self.expr_index == 0 and e in external_edges_entry:
                        state.add_edge(edge_to_outer[e].src,
                                       edge_to_outer[e].src_conn, me, None,
                                       dcpy(edge_to_outer[e].data))
                    else:
                        state.add_edge(e.src, e.src_conn, me, None,
                                       dcpy(e.data))
                    state.remove_edge(e)
                # Empty memlet edge in nested SDFGs
                if state.in_degree(component_in) == 0:
                    state.add_edge(me, None, component_in, None, mm.Memlet())

                for e in state.out_edges(component_out):
                    state.add_edge(e.src, e.src_conn, mx, None, dcpy(e.data))
                    # Reconnect inner edges at sink directly to external nodes
                    if self.expr_index == 0 and e in external_edges_exit:
                        state.add_edge(mx, None, edge_to_outer[e].dst,
                                       edge_to_outer[e].dst_conn,
                                       dcpy(edge_to_outer[e].data))
                    else:
                        state.add_edge(mx, None, e.dst, e.dst_conn,
                                       dcpy(e.data))
                    state.remove_edge(e)
                # Empty memlet edge in nested SDFGs
                if state.out_degree(component_out) == 0:
                    state.add_edge(component_out, None, mx, None, mm.Memlet())
            # Connect other sources/sinks not in components (access nodes)
            # directly to external nodes
            if self.expr_index == 0:
                for node in sources:
                    if isinstance(node, nodes.AccessNode):
                        for edge in state.in_edges(node):
                            outer_edge = edge_to_outer[edge]
                            memlet = dcpy(edge.data)
                            # Prepend the full map range to the memlet subset
                            memlet.subset = subsets.Range(
                                outer_map.range.ranges + memlet.subset.ranges)
                            state.add_edge(outer_edge.src, outer_edge.src_conn,
                                           edge.dst, edge.dst_conn, memlet)

                for node in sinks:
                    if isinstance(node, nodes.AccessNode):
                        for edge in state.out_edges(node):
                            outer_edge = edge_to_outer[edge]
                            state.add_edge(edge.src, edge.src_conn,
                                           outer_edge.dst, outer_edge.dst_conn,
                                           dcpy(outer_edge.data))

            # Augment arrays by prepending map dimensions
            for array in arrays:
                # Each array is extended only once across all subgraphs
                if array in modified_arrays:
                    continue
                desc = parent.arrays[array]
                if isinstance(
                        desc, dt.Scalar
                ):  # Scalar needs to be augmented to an array
                    desc = dt.Array(desc.dtype, desc.shape, desc.transient,
                                    desc.allow_conflicts, desc.storage,
                                    desc.location, desc.strides, desc.offset,
                                    False, desc.lifetime, 0, desc.debuginfo,
                                    desc.total_size, desc.start_offset)
                    parent.arrays[array] = desc
                # Build the strides of the new leading dimensions from the
                # innermost map dimension outwards
                for sz in reversed(mapsize):
                    desc.strides = [desc.total_size] + list(desc.strides)
                    desc.total_size = desc.total_size * sz

                desc.shape = mapsize + list(desc.shape)
                desc.offset = [0] * len(mapsize) + list(desc.offset)
                modified_arrays.add(array)

            # Fill scope connectors so that memlets can be tracked below
            state.fill_scope_connectors()

            # Correct connectors and memlets in nested SDFGs to account for
            # missing outside map
            if self.expr_index == 1:
                to_correct = ([(e, e.src) for e in external_edges_entry] +
                              [(e, e.dst) for e in external_edges_exit])
                corrected_nodes = set()
                for edge, node in to_correct:
                    if isinstance(node, nodes.AccessNode):
                        # Each access node is corrected at most once
                        if node in corrected_nodes:
                            continue
                        corrected_nodes.add(node)

                        outer_edge = edge_to_outer[edge]
                        desc = parent.arrays[node.data]

                        # Modify shape of internal array to match outer one
                        outer_desc = sdfg.arrays[outer_edge.data.data]
                        if not isinstance(desc, dt.Scalar):
                            desc.shape = outer_desc.shape
                        if isinstance(desc, dt.Array):
                            desc.strides = outer_desc.strides
                            desc.total_size = outer_desc.total_size

                        # Inside the nested SDFG, offset all memlets to include
                        # the offsets from within the map.
                        # NOTE: Relies on propagation to fix outer memlets
                        for internal_edge in state.all_edges(node):
                            for e in state.memlet_tree(internal_edge):
                                e.data.subset.offset(desc.offset, False)
                                e.data.subset = helpers.unsqueeze_memlet(
                                    e.data, outer_edge.data).subset

                        # Only after offsetting memlets we can modify the
                        # overall offset
                        if isinstance(desc, dt.Array):
                            desc.offset = outer_desc.offset

            # Fill in memlet trees for border transients
            # NOTE: Memlet propagation should run to correct the outer edges
            for node in subgraph.nodes():
                if isinstance(node, nodes.AccessNode) and node.data in arrays:
                    for edge in state.all_edges(node):
                        for e in state.memlet_tree(edge):
                            # Prepend map dimensions to memlet
                            e.data.subset = subsets.Range(
                                [(pystr_to_symbolic(d) - r[0],
                                  pystr_to_symbolic(d) - r[0], 1) for d, r in
                                 zip(outer_map.params, outer_map.range)] +
                                e.data.subset.ranges)

        # If nested SDFG, reconnect nodes around map and modify memlets
        if self.expr_index == 1:
            for edge in graph.in_edges(map_entry):
                # Skip edges that do not arrive at an 'IN_' connector
                if not edge.dst_conn or not edge.dst_conn.startswith('IN_'):
                    continue

                # Modify edge coming into nested SDFG to include entire array
                desc = sdfg.arrays[edge.data.data]
                edge.data.subset = subsets.Range.from_array(desc)
                edge.data.num_accesses = edge.data.subset.num_elements()

                # Find matching edge inside map
                # (connector names are compared after dropping the first 4
                # characters of the source connector and the first 3 of the
                # destination connector)
                inner_edge = next(
                    e for e in graph.out_edges(map_entry)
                    if e.src_conn and e.src_conn[4:] == edge.dst_conn[3:])
                graph.add_edge(edge.src, edge.src_conn, nsdfg_node,
                               inner_edge.dst_conn, dcpy(edge.data))

            for edge in graph.out_edges(map_exit):
                # Modify edge coming out of nested SDFG to include entire array
                desc = sdfg.arrays[edge.data.data]
                edge.data.subset = subsets.Range.from_array(desc)

                # Find matching edge inside map
                inner_edge = next(e for e in graph.in_edges(map_exit)
                                  if e.dst_conn[3:] == edge.src_conn[4:])
                graph.add_edge(nsdfg_node, inner_edge.src_conn, edge.dst,
                               edge.dst_conn, dcpy(edge.data))

        # Remove outer map
        graph.remove_nodes_from([map_entry, map_exit])
class AugAssignToWCR(transformation.Transformation):
    """
    Converts an augmented assignment ("a += b", "a = a + b") into a tasklet
    with a write-conflict resolution.

    The matched pattern is ``input -> tasklet -> output`` where both access
    nodes refer to the same data container. Only C++ tasklets are rewritten;
    Python tasklets are rejected in ``can_be_applied`` and raise
    ``NotImplementedError`` in ``apply``.

    NOTE(review): ``expressions()`` lists a single pattern, so the branches
    guarded by ``expr_index == 1`` (free map) look unreachable through the
    pattern list shown here -- confirm before relying on them.
    """
    input = transformation.PatternNode(nodes.AccessNode)
    tasklet = transformation.PatternNode(nodes.Tasklet)
    output = transformation.PatternNode(nodes.AccessNode)
    map_entry = transformation.PatternNode(nodes.MapEntry)
    map_exit = transformation.PatternNode(nodes.MapExit)

    # Operators recognized in the tasklet code ('/' is currently disabled)
    _EXPRESSIONS = ['+', '-', '*', '^', '%']  #, '/']
    # Rewrites of asymmetric operators: operator -> (WCR operator, template
    # applied to the right-hand expression)
    _EXPR_MAP = {
        '-': ('+', '-({expr})'),
        '/': ('*', '((decltype({expr}))1)/({expr})')
    }

    @staticmethod
    def expressions():
        """ Returns the single pattern input -> tasklet -> output. """
        return [
            sdutil.node_path_graph(AugAssignToWCR.input,
                                   AugAssignToWCR.tasklet,
                                   AugAssignToWCR.output),
        ]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """
        Checks whether the matched tasklet is an augmented assignment on the
        same data and memlet subset that can be expressed as a WCR.

        :param graph: State in which the pattern was matched.
        :param candidate: Mapping from pattern nodes to node IDs in ``graph``.
        :param expr_index: Index of the matched pattern expression.
        :param sdfg: The SDFG containing ``graph`` (not read here).
        :param strict: Not read by this method.
        :return: True if some input edge matches ``out = in <op> ...;``.
        """
        inarr = graph.node(candidate[AugAssignToWCR.input])
        tasklet: nodes.Tasklet = graph.node(candidate[AugAssignToWCR.tasklet])
        outarr = graph.node(candidate[AugAssignToWCR.output])
        # Input and output must refer to the same data container
        if inarr.data != outarr.data:
            return False

        # Free tasklet
        if expr_index == 0:
            # Only free tasklets supported for now
            if graph.entry_node(tasklet) is not None:
                return False

            inedges = graph.edges_between(inarr, tasklet)
            if len(graph.edges_between(tasklet, outarr)) > 1:
                return False

            # Make sure augmented assignment can be fissioned as necessary
            if any(not isinstance(e.src, nodes.AccessNode)
                   for e in graph.in_edges(tasklet)):
                return False
            if graph.in_degree(inarr) > 0 and graph.out_degree(outarr) > 0:
                return False

            outedge = graph.edges_between(tasklet, outarr)[0]
        else:
            # Free map
            me: nodes.MapEntry = graph.node(
                candidate[AugAssignToWCR.map_entry])
            mx = graph.node(candidate[AugAssignToWCR.map_exit])

            # Only free maps supported for now
            if graph.entry_node(me) is not None:
                return False

            inedges = graph.edges_between(me, tasklet)
            if len(graph.edges_between(tasklet, mx)) > 1:
                return False

            # Currently no fission is supported
            if any(e.src is not me and not isinstance(e.src, nodes.AccessNode)
                   for e in graph.in_edges(me) + graph.in_edges(tasklet)):
                return False
            if graph.in_degree(inarr) > 0:
                return False

            outedge = graph.edges_between(tasklet, mx)[0]

        # Get relevant output connector
        outconn = outedge.src_conn

        # Regex character class containing every supported operator
        ops = '[%s]' % ''.join(
            re.escape(o) for o in AugAssignToWCR._EXPRESSIONS)

        if tasklet.language is dtypes.Language.Python:
            # Expect ast.Assign(ast.Expr())
            return False
        elif tasklet.language is dtypes.Language.CPP:
            cstr = tasklet.code.as_string.strip()
            for edge in inedges:
                # Try to match a single C assignment that can be converted to WCR
                inconn = edge.dst_conn
                lhs = r'^\s*%s\s*=\s*%s\s*%s.*;$' % (re.escape(outconn),
                                                     re.escape(inconn), ops)
                # NOTE(review): `rhs` is built but never matched below; only
                # the "out = in <op> expr" form is accepted.
                rhs = r'^\s*%s\s*=\s*.*%s\s*%s;$' % (re.escape(outconn), ops,
                                                     re.escape(inconn))
                if re.match(lhs, cstr) is None:
                    continue
                # Same memlet
                if edge.data.subset != outedge.data.subset:
                    continue

                # If in map, only match if the subset is independent of any
                # map indices (otherwise no conflict)
                if (expr_index == 1
                        and len(outedge.data.subset.free_symbols
                                & set(me.map.params)) == len(me.map.params)):
                    continue

                return True
        else:
            # Only Python/C++ tasklets supported
            return False

        return False

    def apply(self, sdfg: SDFG):
        """
        Rewrites the tasklet so that it only produces the update expression
        and sets a WCR on the output memlet, removing the read of the
        augmented value.

        :param sdfg: The SDFG to modify in place.
        :raises NotImplementedError: For tasklets not written in C++.
        """
        input: nodes.AccessNode = self.input(sdfg)
        tasklet: nodes.Tasklet = self.tasklet(sdfg)
        output: nodes.AccessNode = self.output(sdfg)
        state: SDFGState = sdfg.node(self.state_id)

        # If state fission is necessary to keep semantics, do it first
        if (self.expr_index == 0 and state.in_degree(input) > 0
                and state.out_degree(output) == 0):
            newstate = sdfg.add_state_after(state)
            newstate.add_node(tasklet)
            new_input, new_output = None, None

            # Keep old edges for after we remove tasklet from the original state
            in_edges = list(state.in_edges(tasklet))
            out_edges = list(state.out_edges(tasklet))

            # Recreate every tasklet input/output in the new state with fresh
            # access nodes, remembering the ones that replace `input`/`output`
            for e in in_edges:
                r = newstate.add_read(e.src.data)
                newstate.add_edge(r, e.src_conn, e.dst, e.dst_conn, e.data)
                if e.src is input:
                    new_input = r
            for e in out_edges:
                w = newstate.add_write(e.dst.data)
                newstate.add_edge(e.src, e.src_conn, w, e.dst_conn, e.data)
                if e.dst is output:
                    new_output = w

            # Remove tasklet and resulting isolated nodes
            state.remove_node(tasklet)
            for e in in_edges:
                if state.degree(e.src) == 0:
                    state.remove_node(e.src)
            for e in out_edges:
                if state.degree(e.dst) == 0:
                    state.remove_node(e.dst)

            # Reset state and nodes for rest of transformation
            input = new_input
            output = new_output
            state = newstate
        # End of state fission

        if self.expr_index == 0:
            inedges = state.edges_between(input, tasklet)
            outedge = state.edges_between(tasklet, output)[0]
        else:
            me = self.map_entry(sdfg)
            mx = self.map_exit(sdfg)

            inedges = state.edges_between(me, tasklet)
            outedge = state.edges_between(tasklet, mx)[0]

        # Get relevant output connector
        outconn = outedge.src_conn

        # Regex character class containing every supported operator
        ops = '[%s]' % ''.join(
            re.escape(o) for o in AugAssignToWCR._EXPRESSIONS)

        # Change tasklet code
        if tasklet.language is dtypes.Language.Python:
            raise NotImplementedError
        elif tasklet.language is dtypes.Language.CPP:
            cstr = tasklet.code.as_string.strip()
            # NOTE(review): if no edge matches, `op` and `inedge` stay unbound
            # and the code below fails with NameError -- presumably
            # can_be_applied guarantees a match; confirm.
            for edge in inedges:
                inconn = edge.dst_conn
                match = re.match(
                    r'^\s*%s\s*=\s*%s\s*(%s)(.*);$' %
                    (re.escape(outconn), re.escape(inconn), ops), cstr)
                if match is None:
                    # Disabled: matching of the "out = expr <op> in" form
                    # match = re.match(
                    #     r'^\s*%s\s*=\s*(.*)\s*(%s)\s*%s;$' %
                    #     (re.escape(outconn), ops, re.escape(inconn)), cstr)
                    # if match is None:
                    continue
                    # op = match.group(2)
                    # expr = match.group(1)
                else:
                    op = match.group(1)
                    expr = match.group(2)

                if edge.data.subset != outedge.data.subset:
                    continue

                # Map asymmetric WCRs to symmetric ones if possible
                if op in AugAssignToWCR._EXPR_MAP:
                    op, newexpr = AugAssignToWCR._EXPR_MAP[op]
                    expr = newexpr.format(expr=expr)

                # The tasklet now only computes the update value
                tasklet.code.code = '%s = %s;' % (outconn, expr)
                inedge = edge
                break
        else:
            raise NotImplementedError

        # Change output edge
        outedge.data.wcr = f'lambda a,b: a {op} b'

        if self.expr_index == 0:
            # Remove input node and connector
            state.remove_edge_and_connectors(inedge)
            if state.degree(input) == 0:
                state.remove_node(input)
        else:
            # Remove input edge and dst connector, but not necessarily src
            state.remove_memlet_path(inedge)

        # If outedge leads to non-transient, and this is a nested SDFG,
        # propagate outwards
        sd = sdfg
        while (not sd.arrays[outedge.data.data].transient
               and sd.parent_nsdfg_node is not None):
            nsdfg = sd.parent_nsdfg_node
            nstate = sd.parent
            sd = sd.parent_sdfg
            # Outer edge leaving the nested SDFG through the connector named
            # after the inner data container
            outedge = next(
                iter(nstate.out_edges_by_connector(nsdfg, outedge.data.data)))
            # `outedge` ends up bound to the last edge of the memlet path,
            # which the loop condition above inspects on the next iteration
            for outedge in nstate.memlet_path(outedge):
                outedge.data.wcr = f'lambda a,b: a {op} b'
class PruneConnectors(pm.Transformation):
    """
    Removes unused connectors from nested SDFGs, as well as their memlets in
    the outer scope, replacing them with empty memlets if necessary.
    """

    nsdfg = pm.PatternNode(nodes.NestedSDFG)

    @staticmethod
    def expressions():
        """ Returns the single-node pattern matching any nested SDFG. """
        return [utils.node_path_graph(PruneConnectors.nsdfg)]

    @staticmethod
    def _prunable_connectors(state, nested):
        """
        Computes which connectors of ``nested`` are not read/written inside.

        :param state: The state that contains the nested SDFG node.
        :param nested: The nested SDFG node.
        :return: A 3-tuple (prunable input connectors, prunable output
                 connectors, all data names read or written inside).
        """
        reads, writes = nested.sdfg.read_and_write_sets()
        unused_in = nested.in_connectors.keys() - reads
        unused_out = nested.out_connectors.keys() - writes

        # Add WCR outputs to "do not prune" input list
        for out_edge in state.out_edges(nested):
            name = out_edge.src_conn
            if out_edge.data.wcr is None or name not in unused_in:
                continue
            feeding_edge = next(iter(state.in_edges_by_connector(nested, name)))
            if state.in_degree(feeding_edge.src) > 0:
                unused_in.remove(name)

        return unused_in, unused_out, reads | writes

    @staticmethod
    def can_be_applied(graph: Union[SDFG, SDFGState],
                       candidate: Dict[pm.PatternNode, int],
                       expr_index: int,
                       sdfg: SDFG,
                       strict: bool = False) -> bool:
        """
        Returns True if at least one connector can be pruned and none of the
        prunable memlet paths is attached to a node with further
        predecessors (inputs) or successors (outputs).
        """
        nested = graph.node(candidate[PruneConnectors.nsdfg])
        unused_in, unused_out, _ = PruneConnectors._prunable_connectors(
            graph, nested)

        # Source of a prunable input path still has incoming edges
        blocked_upstream = any(
            graph.in_degree(graph.memlet_path(edge)[0].src) > 0
            for edge in graph.in_edges(nested) if edge.dst_conn in unused_in)
        # Destination of a prunable output path still has outgoing edges
        blocked_downstream = any(
            graph.out_degree(graph.memlet_path(edge)[-1].dst) > 0
            for edge in graph.out_edges(nested) if edge.src_conn in unused_out)
        if blocked_upstream or blocked_downstream:
            return False

        return len(unused_in) > 0 or len(unused_out) > 0

    def apply(self, sdfg: SDFG) -> Union[Any, None]:
        """
        Removes the memlet paths of all prunable connectors and purges
        data descriptors of the nested SDFG that are neither read nor written.
        """
        state = sdfg.node(self.state_id)
        nested = self.nsdfg(sdfg)

        # `used_data` lets us delete descriptors that are unused after the
        # connectors have been pruned
        unused_in, unused_out, used_data = (
            PruneConnectors._prunable_connectors(state, nested))

        # Inputs are processed first, then outputs
        passes = (
            (unused_in, state.in_edges_by_connector),
            (unused_out, state.out_edges_by_connector),
        )
        for connectors, edges_by_connector in passes:
            for name in connectors:
                for edge in edges_by_connector(nested, name):
                    state.remove_memlet_path(edge, remove_orphans=True)
                    # If the data is now unused, we can purge it from the SDFG
                    if name in nested.sdfg.arrays and name not in used_data:
                        nested.sdfg.remove_data(name)
class MapFusion(transformation.Transformation):
    """ Implements the MapFusion transformation.
        It will check for all patterns MapExit -> AccessNode -> MapEntry, and
        based on the following rules, fuse them and remove the transient in
        between. There are several possibilities of what it does to this
        transient in between.

        Essentially, if there is some other place in the
        sdfg where it is required, or if it is not a transient, then it will
        not be removed. In such a case, it will be linked to the MapExit node
        of the new fused map.

        Rules for fusing maps:
          0. The map range of the second map should be a permutation of the
             first map range.
          1. Each of the access nodes that are adjacent to the first map exit
             should have an edge to the second map entry. If it doesn't, then
             the second map entry should not be reachable from this access
             node.
          2. Any node that has a wcr from the first map exit should not be
             adjacent to the second map entry.
          3. Access pattern for the access nodes in the second map should be
             the same permutation of the map parameters as the map ranges of
             the two maps. Alternatively, this access node should not be
             adjacent to the first map entry.
    """
    first_map_exit = transformation.PatternNode(nodes.ExitNode)
    array = transformation.PatternNode(nodes.AccessNode)
    second_map_entry = transformation.PatternNode(nodes.EntryNode)

    @staticmethod
    def annotates_memlets():
        """ Always returns False for this transformation. """
        return False

    @staticmethod
    def expressions():
        """ Returns the pattern first map exit -> array -> second map entry. """
        return [
            sdutil.node_path_graph(
                MapFusion.first_map_exit,
                MapFusion.array,
                MapFusion.second_map_entry,
            )
        ]

    @staticmethod
    def find_permutation(first_map: nodes.Map,
                         second_map: nodes.Map) -> Union[List[int], None]:
        """ Find permutation between two map ranges.
            :param first_map: First map.
            :param second_map: Second map.
            :return: None if no such permutation exists, otherwise a list of
                     indices L such that L[x]'th parameter of second map has
                     the same range as x'th parameter of the first map.
        """
        result = []

        if len(first_map.range) != len(second_map.range):
            return None

        # Match map ranges with reduce ranges
        for i, tmap_rng in enumerate(first_map.range):
            found = False
            for j, rng in enumerate(second_map.range):
                # Each second-map index may be used at most once
                if tmap_rng == rng and j not in result:
                    result.append(j)
                    found = True
                    break
            if not found:
                break

        # Ensure all map ranges matched
        if len(result) != len(first_map.range):
            return None

        return result

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """
        Checks the fusion rules listed in the class docstring.

        :param graph: State in which the pattern was matched.
        :param candidate: Mapping from pattern nodes to node IDs in ``graph``.
        :param expr_index: Index of the matched pattern (not read here).
        :param sdfg: The SDFG containing ``graph``.
        :param strict: Not read by this method.
        :return: True if the two maps can be fused.
        """
        first_map_exit = graph.nodes()[candidate[MapFusion.first_map_exit]]
        first_map_entry = graph.entry_node(first_map_exit)
        second_map_entry = graph.nodes()[candidate[MapFusion.second_map_entry]]
        second_map_exit = graph.exit_node(second_map_entry)

        for _in_e in graph.in_edges(first_map_exit):
            if _in_e.data.wcr is not None:
                for _out_e in graph.out_edges(second_map_entry):
                    if _out_e.data.data == _in_e.data.data:
                        # wcr is on a node that is used in the second map, quit
                        return False
        # Check whether there is a pattern map -> access -> map.
        intermediate_nodes = set()
        intermediate_data = set()
        for _, _, dst, _, _ in graph.out_edges(first_map_exit):
            if isinstance(dst, nodes.AccessNode):
                intermediate_nodes.add(dst)
                intermediate_data.add(dst.data)

                # If array is used anywhere else (access nodes are counted
                # across every state of the SDFG, not only this one).
                num_occurrences = len([
                    n for s in sdfg.nodes() for n in s.nodes()
                    if isinstance(n, nodes.AccessNode) and n.data == dst.data
                ])
                if num_occurrences > 1:
                    return False
            else:
                return False
        # Check map ranges
        perm = MapFusion.find_permutation(first_map_entry.map,
                                          second_map_entry.map)
        if perm is None:
            return False

        # Check if any intermediate transient is also going to another location
        second_inodes = set(e.src for e in graph.in_edges(second_map_entry)
                            if isinstance(e.src, nodes.AccessNode))
        transients_to_remove = intermediate_nodes & second_inodes
        # if any(e.dst != second_map_entry for n in transients_to_remove
        #        for e in graph.out_edges(n)):
        if any(graph.out_degree(n) > 1 for n in transients_to_remove):
            return False

        # Create a dict that maps parameters of the second map to those of
        # the first map.
        # NOTE(review): `perm` is indexed here by second-map position, while
        # find_permutation documents it as indexed by first-map position (as
        # `apply` uses it). The two agree only when the permutation is its
        # own inverse -- confirm.
        params_dict = {}
        for _index, _param in enumerate(second_map_entry.map.params):
            params_dict[_param] = first_map_entry.map.params[perm[_index]]

        # Create intermediate dicts to avoid conflicts, such as {i:j, j:i}
        repldict = {
            symbolic.pystr_to_symbolic(k):
            symbolic.pystr_to_symbolic('__dacesym_' + str(v))
            for k, v in params_dict.items()
        }
        repldict_inv = {
            symbolic.pystr_to_symbolic('__dacesym_' + str(v)):
            symbolic.pystr_to_symbolic(v)
            for v in params_dict.values()
        }

        out_memlets = [e.data for e in graph.in_edges(first_map_exit)]

        # Check that input set of second map is provided by the output set
        # of the first map, or other unrelated maps
        for second_edge in graph.out_edges(second_map_entry):
            # Memlets that do not come from one of the intermediate arrays
            if second_edge.data.data not in intermediate_data:
                # however, if intermediate_data eventually leads to
                # second_memlet.data, need to fail.
                for _n in intermediate_nodes:
                    source_node = _n
                    destination_node = graph.memlet_path(second_edge)[0].src
                    # NOTE: Assumes graph has networkx version
                    if destination_node in nx.descendants(
                            graph._nx, source_node):
                        return False
                continue

            provided = False

            # Compute second subset with respect to first subset's symbols
            sbs_permuted = dcpy(second_edge.data.subset)
            if sbs_permuted:
                sbs_permuted.replace(repldict)
                sbs_permuted.replace(repldict_inv)

            for first_memlet in out_memlets:
                if first_memlet.data != second_edge.data.data:
                    continue

                # If there is a covered subset, it is provided
                if first_memlet.subset.covers(sbs_permuted):
                    provided = True
                    break

            # If none of the output memlets of the first map provide the info,
            # fail.
            if provided is False:
                return False

        # Checking for stencil pattern and common input/output data
        # (after fusing the maps)
        first_map_inputnodes = {
            e.src: e.src.data
            for e in graph.in_edges(first_map_entry)
            if isinstance(e.src, nodes.AccessNode)
        }
        input_views = set()
        viewed_inputnodes = dict()
        for n in first_map_inputnodes.keys():
            if isinstance(n.desc(sdfg), data.View):
                input_views.add(n)
        # Replace each view input by the node at the source of its view edge,
        # remembering which view refers to that data
        for v in input_views:
            del first_map_inputnodes[v]
            e = sdutil.get_view_edge(graph, v)
            if e:
                first_map_inputnodes[e.src] = e.src.data
                viewed_inputnodes[e.src.data] = v
        second_map_outputnodes = {
            e.dst: e.dst.data
            for e in graph.out_edges(second_map_exit)
            if isinstance(e.dst, nodes.AccessNode)
        }
        output_views = set()
        viewed_outputnodes = dict()
        for n in second_map_outputnodes:
            if isinstance(n.desc(sdfg), data.View):
                output_views.add(n)
        # Same replacement for view outputs, using the view edge destination
        for v in output_views:
            del second_map_outputnodes[v]
            e = sdutil.get_view_edge(graph, v)
            if e:
                second_map_outputnodes[e.dst] = e.dst.data
                viewed_outputnodes[e.dst.data] = v
        # Data that is both an input of the first map and an output of the
        # second map
        common_data = set(first_map_inputnodes.values()).intersection(
            set(second_map_outputnodes.values()))
        if common_data:
            input_data = [
                viewed_inputnodes[d].data
                if d in viewed_inputnodes.keys() else d for d in common_data
            ]
            input_accesses = [
                graph.memlet_path(e)[-1].data.src_subset
                for e in graph.out_edges(first_map_entry)
                if e.data.data in input_data
            ]
            if len(input_accesses) > 1:
                # Every pair of input accesses must have a zero offset
                for i, a in enumerate(input_accesses[:-1]):
                    for b in input_accesses[i + 1:]:
                        if isinstance(a, subsets.Indices):
                            c = subsets.Range.from_indices(a)
                            c.offset(b, negative=True)
                        else:
                            c = a.offset_new(b, negative=True)
                        for r in c:
                            if r != (0, 0, 1):
                                return False

            output_data = [
                viewed_outputnodes[d].data
                if d in viewed_outputnodes.keys() else d for d in common_data
            ]
            output_accesses = [
                graph.memlet_path(e)[0].data.dst_subset
                for e in graph.in_edges(second_map_exit)
                if e.data.data in output_data
            ]

            # Compute output accesses with respect to first map's symbols
            oacc_permuted = [dcpy(a) for a in output_accesses]
            for a in oacc_permuted:
                a.replace(repldict)
                a.replace(repldict_inv)

            # NOTE(review): indexing [0] raises IndexError if
            # `input_accesses` is empty -- confirm this cannot happen when
            # `common_data` is non-empty.
            a = input_accesses[0]
            for b in oacc_permuted:
                if isinstance(a, subsets.Indices):
                    c = subsets.Range.from_indices(a)
                    c.offset(b, negative=True)
                else:
                    c = a.offset_new(b, negative=True)
                for r in c:
                    if r != (0, 0, 1):
                        return False

        # Success
        return True

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns "<label>: <params> -> <label>: <params>" for both maps. """
        first_exit = graph.nodes()[candidate[MapFusion.first_map_exit]]
        second_entry = graph.nodes()[candidate[MapFusion.second_map_entry]]

        return " -> ".join(entry.map.label + ": " + str(entry.map.params)
                           for entry in [first_exit, second_entry])

    def apply(self, sdfg):
        """
            This method applies the mapfusion transformation.
            Other than the removal of the second map entry node (SME), and the
            first map exit (FME) node, it has the following side effects:

            1.  Any transient adjacent to both FME and SME with degree = 2
                will be removed. The tasklets that use/produce it shall be
                connected directly with a scalar/new transient (if the
                dataflow is more than a single scalar)

            2.  If this transient is adjacent to FME and SME and has other
                uses, it will be adjacent to the new map exit post fusion.
                Tasklet-> Tasklet edges will ALSO be added as mentioned above.

            3.  If an access node is adjacent to FME but not SME, it will be
                adjacent to new map exit post fusion.

            4.  If an access node is adjacent to SME but not FME, it will be
                adjacent to the new map entry node post fusion.

        """
        graph: SDFGState = sdfg.nodes()[self.state_id]
        first_exit = graph.nodes()[self.subgraph[MapFusion.first_map_exit]]
        first_entry = graph.entry_node(first_exit)
        second_entry = graph.nodes()[self.subgraph[MapFusion.second_map_entry]]
        second_exit = graph.exit_node(second_entry)

        intermediate_nodes = set()
        for _, _, dst, _, _ in graph.out_edges(first_exit):
            intermediate_nodes.add(dst)
            assert isinstance(dst, nodes.AccessNode)

        # Check if an access node refers to non transient memory, or transient
        # is used at another location (cannot erase)
        do_not_erase = set()
        for node in intermediate_nodes:
            if sdfg.arrays[node.data].transient is False:
                do_not_erase.add(node)
            else:
                for edge in graph.in_edges(node):
                    if edge.src != first_exit:
                        do_not_erase.add(node)
                        break
                else:
                    # No foreign producer found; look for foreign consumers
                    for edge in graph.out_edges(node):
                        if edge.dst != second_entry:
                            do_not_erase.add(node)
                            break

        # Find permutation between first and second scopes
        perm = MapFusion.find_permutation(first_entry.map, second_entry.map)
        params_dict = {}
        for index, param in enumerate(first_entry.map.params):
            params_dict[param] = second_entry.map.params[perm[index]]

        # Replaces (in memlets and tasklet) the second scope map
        # indices with the permuted first map indices.
        # This works in two passes to avoid problems when e.g., exchanging two
        # parameters (instead of replacing (j,i) and (i,j) to (j,j) and then
        # i,i).
        second_scope = graph.scope_subgraph(second_entry)
        for firstp, secondp in params_dict.items():
            if firstp != secondp:
                replace(second_scope, secondp, '__' + secondp + '_fused')
        for firstp, secondp in params_dict.items():
            if firstp != secondp:
                replace(second_scope, '__' + secondp + '_fused', firstp)

        # Isolate First exit node
        ############################
        edges_to_remove = set()
        nodes_to_remove = set()
        for edge in graph.in_edges(first_exit):
            tree = graph.memlet_tree(edge)
            access_node = tree.root().edge.dst
            if access_node not in do_not_erase:
                out_edges = [
                    e for e in graph.out_edges(access_node)
                    if e.dst == second_entry
                ]
                # In this transformation, there can only be one edge to the
                # second map
                assert len(out_edges) == 1

                # Get source connector to the second map
                # (strips the leading 'IN_' of the destination connector)
                connector = out_edges[0].dst_conn[3:]

                new_dsts = []
                # Look at the second map entry out-edges to get the new
                # destinations
                for e in graph.out_edges(second_entry):
                    # Strips the leading 'OUT_' of the source connector
                    if e.src_conn[4:] == connector:
                        new_dsts.append(e)
                if not new_dsts:
                    # Access node is not used in the second map
                    nodes_to_remove.add(access_node)
                    continue

                # If the source is an access node, modify the memlet to point
                # to it
                if (isinstance(edge.src, nodes.AccessNode)
                        and edge.data.data != edge.src.data):
                    edge.data.data = edge.src.data
                    edge.data.subset = ("0" if edge.data.other_subset is None
                                        else edge.data.other_subset)
                    edge.data.other_subset = None

                else:
                    # Add a transient scalar/array
                    self.fuse_nodes(sdfg, graph, edge, new_dsts[0].dst,
                                    new_dsts[0].dst_conn, new_dsts[1:])

                edges_to_remove.add(edge)

                # Remove transient node between the two maps
                nodes_to_remove.add(access_node)
            else:
                # The case where intermediate array node cannot be removed
                # Node will become an output of the second map exit
                out_e = tree.parent.edge
                conn = second_exit.next_connector()
                graph.add_edge(
                    second_exit,
                    'OUT_' + conn,
                    out_e.dst,
                    out_e.dst_conn,
                    dcpy(out_e.data),
                )
                second_exit.add_out_connector('OUT_' + conn)

                graph.add_edge(edge.src, edge.src_conn, second_exit,
                               'IN_' + conn, dcpy(edge.data))
                second_exit.add_in_connector('IN_' + conn)

                edges_to_remove.add(out_e)
                edges_to_remove.add(edge)

                # If the second map needs this node, link the connector
                # that generated this to the place where it is needed, with a
                # temp transient/scalar for memlet to be generated
                for out_e in graph.out_edges(second_entry):
                    second_memlet_path = graph.memlet_path(out_e)
                    source_node = second_memlet_path[0].src
                    if source_node == access_node:
                        self.fuse_nodes(sdfg, graph, edge, out_e.dst,
                                        out_e.dst_conn)

        ###
        # First scope exit is isolated and can now be safely removed
        for e in edges_to_remove:
            graph.remove_edge(e)
        graph.remove_nodes_from(nodes_to_remove)
        graph.remove_node(first_exit)

        # Isolate second_entry node
        ###########################
        for edge in graph.in_edges(second_entry):
            tree = graph.memlet_tree(edge)
            access_node = tree.root().edge.src
            if access_node in intermediate_nodes:
                # Already handled above, can be safely removed
                graph.remove_edge(edge)
                continue

            # This is an external input to the second map which will now go
            # through the first map.
            conn = first_entry.next_connector()
            graph.add_edge(edge.src, edge.src_conn, first_entry, 'IN_' + conn,
                           dcpy(edge.data))
            first_entry.add_in_connector('IN_' + conn)
            graph.remove_edge(edge)
            # Re-attach every inner consumer to the new connector on the
            # first map entry
            for out_enode in tree.children:
                out_e = out_enode.edge
                graph.add_edge(
                    first_entry,
                    'OUT_' + conn,
                    out_e.dst,
                    out_e.dst_conn,
                    dcpy(out_e.data),
                )
                graph.remove_edge(out_e)
            first_entry.add_out_connector('OUT_' + conn)

        ###
        # Second node is isolated and can now be safely removed
        graph.remove_node(second_entry)

        # Fix scope exit to point to the right map
        second_exit.map = first_entry.map

    def fuse_nodes(self,
                   sdfg,
                   graph,
                   edge,
                   new_dst,
                   new_dst_conn,
                   other_edges=None):
        """ Fuses two nodes via memlets and possibly transient arrays.

            :param sdfg: SDFG to which the new scalar/transient is added.
            :param graph: State containing ``edge``.
            :param edge: Edge whose source produces the value; its memlet is
                         modified in place to refer to the new container.
            :param new_dst: Node that will consume the value.
            :param new_dst_conn: Connector on ``new_dst`` to connect to.
            :param other_edges: Optional further edges whose ``dst`` and
                                ``dst_conn`` also receive the value.
        """
        other_edges = other_edges or []
        memlet_path = graph.memlet_path(edge)
        access_node = memlet_path[-1].dst

        # Name built from the state ID and both endpoints of `edge`
        local_name = "__s%d_n%d%s_n%d%s" % (
            self.state_id,
            graph.node_id(edge.src),
            edge.src_conn,
            graph.node_id(edge.dst),
            edge.dst_conn,
        )
        # Add intermediate memory between subgraphs. If a scalar,
        # uses direct connection. If an array, adds a transient node
        if edge.data.subset.num_elements() == 1:
            sdfg.add_scalar(
                local_name,
                dtype=access_node.desc(graph).dtype,
                transient=True,
                storage=dtypes.StorageType.Register,
            )
            edge.data.data = local_name
            edge.data.subset = "0"

            # If source of edge leads to multiple destinations,
            # redirect all through an access node
            out_edges = list(
                graph.out_edges_by_connector(edge.src, edge.src_conn))
            if len(out_edges) > 1:
                local_node = graph.add_access(local_name)
                src_connector = None

                # Add edge that leads to transient node
                graph.add_edge(edge.src, edge.src_conn, local_node, None,
                               dcpy(edge.data))

                for other_edge in out_edges:
                    if other_edge is not edge:
                        graph.remove_edge(other_edge)
                        graph.add_edge(local_node, src_connector,
                                       other_edge.dst, other_edge.dst_conn,
                                       other_edge.data)
            else:
                # Single consumer: connect directly from the producer
                local_node = edge.src
                src_connector = edge.src_conn

            # Add edge that leads to the second node
            graph.add_edge(local_node, src_connector, new_dst, new_dst_conn,
                           dcpy(edge.data))

            for e in other_edges:
                graph.add_edge(local_node, src_connector, e.dst, e.dst_conn,
                               dcpy(edge.data))
        else:
            sdfg.add_transient(local_name,
                               edge.data.subset.size(),
                               dtype=access_node.desc(graph).dtype)
            # Copy taken before the memlet is rewritten; its subset is used
            # below to offset the surrounding memlets
            old_edge = dcpy(edge)
            local_node = graph.add_access(local_name)
            src_connector = None
            edge.data.data = local_name
            edge.data.subset = ",".join(
                ["0:" + str(s) for s in edge.data.subset.size()])
            # Add edge that leads to transient node
            graph.add_edge(
                edge.src,
                edge.src_conn,
                local_node,
                None,
                dcpy(edge.data),
            )

            # Add edge that leads to the second node
            graph.add_edge(local_node, src_connector, new_dst, new_dst_conn,
                           dcpy(edge.data))

            for e in other_edges:
                graph.add_edge(local_node, src_connector, e.dst, e.dst_conn,
                               dcpy(edge.data))

            # Modify data and memlets on all surrounding edges to match array
            for neighbor in graph.all_edges(local_node):
                for e in graph.memlet_tree(neighbor):
                    e.data.data = local_name
                    e.data.subset.offset(old_edge.data.subset, negative=True)
class PruneSymbols(pm.Transformation):
    """
    Removes unused symbol mappings from nested SDFGs, as well as internal
    symbols if necessary.
    """

    nsdfg = pm.PatternNode(nodes.NestedSDFG)

    @staticmethod
    def expressions():
        """ Returns the single-node pattern matching any nested SDFG. """
        return [utils.node_path_graph(PruneSymbols.nsdfg)]

    @staticmethod
    def _candidates(nsdfg: nodes.NestedSDFG) -> Set[str]:
        """
        Collects the mapped symbols of ``nsdfg`` that are not used by data
        descriptors, states, or inter-state edges of the nested SDFG.

        :param nsdfg: The nested SDFG node to inspect.
        :return: Names of symbol mappings that may be removed.
        """
        unused = set(nsdfg.symbol_mapping.keys())
        if not unused:
            return set()

        # Symbols appearing in data descriptors are in use
        for descriptor in nsdfg.sdfg.arrays.values():
            unused -= set(map(str, descriptor.free_symbols))

        # Symbols that were reassigned on every path so far
        ignore = set()
        for nstate in cfg.stateorder_topological_sort(nsdfg.sdfg):
            used_here = nstate.free_symbols

            # Try to be conservative with C++ tasklets
            for node in nstate.nodes():
                if not isinstance(node, nodes.Tasklet):
                    continue
                if node.language is not dtypes.Language.CPP:
                    continue
                for name in unused:
                    if re.findall(r'\b%s\b' % re.escape(name),
                                  node.code.as_string):
                        used_here.add(name)

            # Any symbol used in this state is considered used
            unused -= (used_here - ignore)
            if not unused:
                return set()

            # Any symbol that is set in all outgoing edges is ignored from
            # this point
            set_on_all_edges = None
            for edge in nsdfg.sdfg.out_edges(nstate):
                # Look for symbols in condition
                condition_symbols = set(
                    map(str,
                        symbolic.symbols_in_ast(edge.data.condition.code[0])))
                unused -= (condition_symbols - ignore)

                for value in edge.data.assignments.values():
                    unused -= (symbolic.free_symbols_and_functions(value) -
                               ignore)

                assigned = edge.data.assignments.keys()
                if set_on_all_edges is None:
                    set_on_all_edges = set(assigned)
                else:
                    set_on_all_edges &= assigned

            if set_on_all_edges is not None:
                ignore |= set_on_all_edges

        return unused

    @staticmethod
    def can_be_applied(graph: Union[SDFG, SDFGState],
                       candidate: Dict[pm.PatternNode, int],
                       expr_index: int,
                       sdfg: SDFG,
                       strict: bool = False) -> bool:
        """ Returns True if at least one symbol mapping can be removed. """
        nested: nodes.NestedSDFG = graph.node(candidate[PruneSymbols.nsdfg])
        return len(PruneSymbols._candidates(nested)) > 0

    def apply(self, sdfg: SDFG) -> Union[Any, None]:
        """ Deletes every removable symbol mapping from the nested SDFG. """
        nested = self.nsdfg(sdfg)

        for name in PruneSymbols._candidates(nested):
            del nested.symbol_mapping[name]

            # If not used in SDFG, remove from symbols as well
            if helpers.is_symbol_unused(nested.sdfg, name):
                nested.sdfg.remove_symbol(name)
class BufferTiling(transformation.SingleStateTransformation):
    """ Implements the buffer tiling transformation.

        BufferTiling tiles a buffer that is in between two maps, where the
        preceding map writes to the buffer and the succeeding map reads from
        it. It introduces additional computations in exchange for reduced
        memory footprint. Commonly used to make use of shared memory on GPUs.
    """

    map1_exit = transformation.PatternNode(nodes.MapExit)
    array = transformation.PatternNode(nodes.AccessNode)
    map2_entry = transformation.PatternNode(nodes.MapEntry)

    tile_sizes = ShapeProperty(dtype=tuple,
                               default=(128, 128, 128),
                               desc="Tile size per dimension")

    # Returns a list of graphs that represent the pattern
    @classmethod
    def expressions(cls):
        """ Returns the pattern map1 exit -> array -> map2 entry. """
        return [
            sdutil.node_path_graph(cls.map1_exit, cls.array, cls.map2_entry)
        ]

    def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
        """
        Checks that every node between the two maps is a transient buffer
        that is written only by the first map, read only by the second map,
        fully provided, and accessed only once in the state.

        :param graph: State in which the pattern was matched.
        :param expr_index: Index of the matched pattern (not read here).
        :param sdfg: The SDFG containing ``graph``.
        :param permissive: Not read by this method.
        :return: True if the transformation can be applied.
        """
        map1_exit = self.map1_exit
        map2_entry = self.map2_entry

        for buf in graph.all_nodes_between(map1_exit, map2_entry):
            # Check that buffers are AccessNodes.
            if not isinstance(buf, nodes.AccessNode):
                return False

            # Check that buffers are transient.
            if not sdfg.arrays[buf.data].transient:
                return False

            # Check that buffers have exactly 1 input and 1 output edge.
            if graph.in_degree(buf) != 1:
                return False
            if graph.out_degree(buf) != 1:
                return False

            # Check that buffers are next to the maps.
            if graph.in_edges(buf)[0].src != map1_exit:
                return False
            if graph.out_edges(buf)[0].dst != map2_entry:
                return False

            # Check that the data consumed is provided.
            provided = graph.in_edges(buf)[0].data.subset
            consumed = graph.out_edges(buf)[0].data.subset
            if not provided.covers(consumed):
                return False

            # Check that buffers occur only once in this state.
            # Compare data names: comparing the name against the node object
            # itself never matches, which would make this check a no-op.
            num_occurrences = len([
                n for n in graph.nodes()
                if isinstance(n, nodes.AccessNode) and n.data == buf.data
            ])
            if num_occurrences > 1:
                return False
        return True

    def apply(self, graph, sdfg):
        """
        Tiles both maps (the first one with overlap), fuses the resulting
        tile maps, and simplifies ranges for tile sizes of 1.

        :param graph: State in which the pattern was matched.
        :param sdfg: The SDFG to modify in place.
        """
        map1_exit = self.map1_exit
        map1_entry = graph.entry_node(map1_exit)
        map2_entry = self.map2_entry
        buffers = graph.all_nodes_between(map1_exit, map2_entry)
        # Situation:
        # -> map1_entry -> ... -> map1_exit -> buffers -> map2_entry -> ...

        # Per-dimension difference between the ranges of the two maps
        lower_extents = tuple(b - a for a, b in zip(
            map1_entry.range.min_element(), map2_entry.range.min_element()))
        upper_extents = tuple(a - b for a, b in zip(
            map1_entry.range.max_element(), map2_entry.range.max_element()))

        # Tile the first map with overlap
        MapTilingWithOverlap.apply_to(sdfg,
                                      map_entry=map1_entry,
                                      options={
                                          'tile_sizes': self.tile_sizes,
                                          'lower_overlap': lower_extents,
                                          'upper_overlap': upper_extents
                                      })
        tile_map1_exit = graph.out_edges(map1_exit)[0].dst
        tile_map1_entry = graph.entry_node(tile_map1_exit)
        tile_map1_entry.label = 'BufferTiling'

        # Tile the second map
        MapTiling.apply_to(sdfg,
                           map_entry=map2_entry,
                           options={
                               'tile_sizes': self.tile_sizes,
                               'tile_trivial': True
                           })
        tile_map2_entry = graph.in_edges(map2_entry)[0].src

        # Fuse maps
        some_buffer = next(
            iter(buffers))  # some dummy to pass to MapFusion.apply_to()
        MapFusion.apply_to(sdfg,
                           first_map_exit=tile_map1_exit,
                           array=some_buffer,
                           second_map_entry=tile_map2_entry)

        # Optimize the simple cases
        map1_entry.range.ranges = [
            (r[0], r[0], r[2]) if l_ext == 0 and u_ext == 0 and ts == 1 else r
            for r, l_ext, u_ext, ts in zip(map1_entry.range.ranges,
                                           lower_extents, upper_extents,
                                           self.tile_sizes)
        ]

        map2_entry.range.ranges = [
            (r[0], r[0], r[2]) if ts == 1 else r
            for r, ts in zip(map2_entry.range.ranges, self.tile_sizes)
        ]

        if any(ts == 1 for ts in self.tile_sizes):
            if any(r[0] == r[1] for r in map1_entry.map.range):
                TrivialMapElimination.apply_to(sdfg, map_entry=map1_entry)
            if any(r[0] == r[1] for r in map2_entry.map.range):
                TrivialMapElimination.apply_to(sdfg, map_entry=map2_entry)
class InlineTransients(transformation.Transformation):
    """
    Inlines all transient arrays that are not used anywhere else into a
    nested SDFG.
    """

    nsdfg = transformation.PatternNode(nodes.NestedSDFG)

    @staticmethod
    def annotates_memlets():
        return True

    @staticmethod
    def expressions():
        return [sdutil.node_path_graph(InlineTransients.nsdfg)]

    @staticmethod
    def _candidates(sdfg: SDFG, graph: SDFGState, nsdfg: nodes.NestedSDFG) -> Dict[str, str]:
        """
        Collect the outer transients that can be moved into the nested SDFG.

        :param sdfg: The outer SDFG.
        :param graph: The state containing the nested SDFG node.
        :param nsdfg: The nested SDFG node.
        :return: A mapping from outer data name to the nested SDFG connector
                 it is attached to.
        """
        candidates = {}
        # Names that were found on more than one connector. They must stay
        # excluded even if further edges with the same data are encountered;
        # otherwise a later edge would silently re-add them.
        rejected = set()
        for e in graph.all_edges(nsdfg):
            if e.data.is_empty():
                continue
            dname = e.data.data
            if dname in rejected:
                continue
            conn = (e.src_conn if e.src is nsdfg else e.dst_conn)
            desc = sdfg.arrays[dname]
            # Needs to be transient
            if not desc.transient:
                continue
            # Needs to be allocated in "Scope" lifetime
            if desc.lifetime is not dtypes.AllocationLifetime.Scope:
                continue
            # If same transient is connected with multiple connectors, bail
            # for now
            if dname in candidates and candidates[dname] != conn:
                del candidates[dname]
                rejected.add(dname)
                continue
            # (for now) needs to use entire data descriptor (skipped due to
            # above check for multiple connectors)
            # if desc.shape != e.data.subset.size():
            #     continue
            candidates[dname] = conn

        if not candidates:
            return candidates

        # Check for uses in other states
        for state in sdfg.nodes():
            if state is graph:
                continue
            for node in state.data_nodes():
                if node.data in candidates:
                    del candidates[node.data]

        if not candidates:
            return candidates

        # Check for uses in state
        # Only access nodes that sit at the far end of a memlet path of the
        # nested SDFG and have no other neighbors on that side are allowed.
        access_nodes = set()
        for e in graph.in_edges(nsdfg):
            src = graph.memlet_path(e)[0].src
            if isinstance(src, nodes.AccessNode) and graph.in_degree(src) == 0:
                access_nodes.add(src)
        for e in graph.out_edges(nsdfg):
            dst = graph.memlet_path(e)[-1].dst
            if isinstance(dst, nodes.AccessNode) and graph.out_degree(dst) == 0:
                access_nodes.add(dst)
        for node in graph.data_nodes():
            if node.data in candidates and node not in access_nodes:
                del candidates[node.data]

        return candidates

    @staticmethod
    def can_be_applied(graph: SDFGState,
                       candidate: Dict[transformation.PatternNode, int],
                       expr_index: int,
                       sdfg: SDFG,
                       strict: bool = False):
        """
        The transformation applies when at least one transient qualifies.

        :param graph: The state containing the matched nested SDFG node.
        :param candidate: Mapping from pattern nodes to node IDs in ``graph``.
        :param expr_index: Index of the matched pattern expression (unused).
        :param sdfg: The outer SDFG.
        :param strict: If True, additionally restrict the nested SDFG schedule.
        :return: True if there is at least one transient to inline.
        """
        nsdfg = graph.node(candidate[InlineTransients.nsdfg])

        # Not every schedule is supported
        if strict:
            if nsdfg.schedule not in (dtypes.ScheduleType.Default, dtypes.ScheduleType.Sequential,
                                      dtypes.ScheduleType.CPU_Multicore, dtypes.ScheduleType.GPU_Device):
                return False

        candidates = InlineTransients._candidates(sdfg, graph, nsdfg)
        return len(candidates) > 0

    @staticmethod
    def match_to_str(graph, candidate):
        return graph.label

    def apply(self, sdfg):
        """
        Make the candidate arrays transient inside the nested SDFG and remove
        their connectors, outer descriptors, memlet trees and access nodes.

        :param sdfg: The outer SDFG.
        """
        state: SDFGState = sdfg.nodes()[self.state_id]
        nsdfg_node: nodes.NestedSDFG = self.nsdfg(sdfg)
        nsdfg: SDFG = nsdfg_node.sdfg
        toremove = InlineTransients._candidates(sdfg, state, nsdfg_node)

        for dname, cname in toremove.items():
            # Make nested SDFG data descriptors transient
            nsdfg.arrays[cname].transient = True

            # Remove connectors from node
            nsdfg_node.remove_in_connector(cname)
            nsdfg_node.remove_out_connector(cname)

            # Remove data descriptor from outer SDFG
            del sdfg.arrays[dname]

        # Remove edges from outer SDFG
        for e in state.in_edges(nsdfg_node):
            if e.data.data not in toremove:
                continue
            tree = state.memlet_tree(e)
            for te in tree:
                state.remove_edge_and_connectors(te)
            # Remove newly isolated node
            state.remove_node(tree.root().edge.src)

        for e in state.out_edges(nsdfg_node):
            if e.data.data not in toremove:
                continue
            tree = state.memlet_tree(e)
            for te in tree:
                state.remove_edge_and_connectors(te)
            # Remove newly isolated node
            state.remove_node(tree.root().edge.dst)
class MapWCRFusion(pm.SingleStateTransformation):
    """
    Map / expanded-reduce fusion.

    Matches a map that is immediately followed by a reduction expressed as two
    nested maps with a write-conflict resolution (i.e., a partial reduction),
    where the array in between is not used anywhere else, and fuses them.
    """

    tasklet = pm.PatternNode(nodes.Tasklet)
    tmap_exit = pm.PatternNode(nodes.MapExit)
    in_array = pm.PatternNode(nodes.AccessNode)
    rmap_in_entry = pm.PatternNode(nodes.MapEntry)
    rmap_in_tasklet = pm.PatternNode(nodes.Tasklet)
    rmap_in_cr = pm.PatternNode(nodes.MapExit)
    rmap_out_entry = pm.PatternNode(nodes.MapEntry)
    rmap_out_exit = pm.PatternNode(nodes.MapExit)
    out_array = pm.PatternNode(nodes.AccessNode)

    @classmethod
    def expressions(cls):
        # Map, then partial reduction of axes
        pattern = sdutil.node_path_graph(cls.tasklet, cls.tmap_exit, cls.in_array, cls.rmap_out_entry,
                                         cls.rmap_in_entry, cls.rmap_in_tasklet, cls.rmap_in_cr,
                                         cls.rmap_out_exit, cls.out_array)
        return [pattern]

    def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
        producer_exit = self.tmap_exit
        middle = self.in_array
        reduce_entry = self.rmap_out_entry

        # The intermediate array may only be written by the first map ...
        for edge in graph.in_edges(middle):
            if edge.src != producer_exit:
                return False
        # ... and only be read by the (outer) reduction map
        for edge in graph.out_edges(middle):
            if edge.dst != reduce_entry:
                return False

        # The second map must actually perform a reduction
        inner_exit = self.rmap_in_cr
        first_inner_edge = graph.in_edges(inner_exit)[0]
        if first_inner_edge.data.wcr is None:
            return False

        # Unless permissive, the transient must not appear anywhere else in
        # this state or in other states
        if not permissive:
            occurrences = sum(1 for n in graph.nodes()
                              if isinstance(n, nodes.AccessNode) and n.data == middle.data)
            if occurrences > 1 or middle.data in sdfg.shared_transients():
                return False

        # The range written by the tasklet map must equal the range read by
        # the reduction
        written = graph.in_edges(middle)[0].data
        read = graph.out_edges(middle)[0].data
        if written.subset != read.subset:
            return False

        return True

    def match_to_str(self, graph):
        shown = (self.tasklet, self.tmap_exit, self.rmap_in_cr)
        return ' -> '.join(map(str, shown))

    def apply(self, graph: SDFGState, sdfg: SDFG):
        # Step 1: collapse the two reduction maps into a single map
        collapse = MapCollapse()
        collapse_match = {
            MapCollapse.outer_map_entry: graph.node_id(self.rmap_out_entry),
            MapCollapse.inner_map_entry: graph.node_id(self.rmap_in_entry),
        }
        collapse.setup_match(sdfg, self.sdfg_id, self.state_id, collapse_match, 0)
        collapsed_entry, _ = collapse.apply(graph, sdfg)

        # Step 2: fuse the tasklet map with the collapsed reduction map
        fusion = MapFusion()
        fusion_match = {
            MapFusion.first_map_exit: graph.node_id(self.tmap_exit),
            MapFusion.second_map_entry: graph.node_id(collapsed_entry),
        }
        fusion.setup_match(sdfg, self.sdfg_id, self.state_id, fusion_match, 0)
        fusion.apply(graph, sdfg)
class StreamingMemory(xf.Transformation):
    """
    Converts a read or a write to streaming memory access, where data is
    read/written to/from a stream in a separate connected component than the
    computation.

    Two patterns are matched: an access node followed by a scope entry
    (expression 0, the "read" case) and a scope exit followed by an access
    node (expression 1, the "write" case).
    """
    access = xf.PatternNode(nodes.AccessNode)
    entry = xf.PatternNode(nodes.EntryNode)
    exit = xf.PatternNode(nodes.ExitNode)

    buffer_size = properties.Property(dtype=int, default=1, desc='Set buffer size for the newly-created stream')

    storage = properties.EnumProperty(dtype=dtypes.StorageType,
                                      desc='Set storage type for the newly-created stream',
                                      default=dtypes.StorageType.Default)

    @staticmethod
    def expressions() -> List[gr.SubgraphView]:
        return [
            sdutil.node_path_graph(StreamingMemory.access, StreamingMemory.entry),
            sdutil.node_path_graph(StreamingMemory.exit, StreamingMemory.access),
        ]

    @staticmethod
    def can_be_applied(graph: SDFGState,
                       candidate: Dict[xf.PatternNode, int],
                       expr_index: int,
                       sdfg: SDFG,
                       permissive: bool = False) -> bool:
        """
        Check whether the matched access node can be converted to a stream.

        :param graph: The state containing the match.
        :param candidate: Mapping from pattern nodes to node IDs in ``graph``.
        :param expr_index: 0 for the read pattern, 1 for the write pattern.
        :param sdfg: The SDFG containing the state.
        :param permissive: Not read by this check.
        :return: True if the transformation can be applied.
        """
        access = graph.node(candidate[StreamingMemory.access])
        # Make sure the access node is only accessed once (read or write),
        # and not at the same time
        if graph.out_degree(access) > 0 and graph.in_degree(access) > 0:
            return False

        # If already a stream, skip
        if isinstance(sdfg.arrays[access.data], data.Stream):
            return False

        # If does not exist on off-chip memory, skip
        if sdfg.arrays[access.data].storage not in [
                dtypes.StorageType.CPU_Heap, dtypes.StorageType.CPU_Pinned, dtypes.StorageType.GPU_Global,
                dtypes.StorageType.FPGA_Global
        ]:
            return False

        # Only free nodes are allowed (search up the SDFG tree)
        # Walks from this state through the parent nested-SDFG nodes and
        # rejects the match as soon as any of them is inside a scope.
        curstate = graph
        node = access
        while curstate is not None:
            if curstate.entry_node(node) is not None:
                return False
            if curstate.parent.parent_nsdfg_node is None:
                break
            node = curstate.parent.parent_nsdfg_node
            curstate = curstate.parent.parent

        # Only one memlet path is allowed per outgoing/incoming edge
        edges = (graph.out_edges(access) if expr_index == 0 else graph.in_edges(access))
        for edge in edges:
            mpath = graph.memlet_path(edge)
            # A tree with more edges than the path means the memlet branches
            if len(mpath) != len(list(graph.memlet_tree(edge))):
                return False

            # The innermost end of the path must have a clearly defined memory
            # access pattern
            innermost_edge = mpath[-1] if expr_index == 0 else mpath[0]
            if (innermost_edge.data.subset.num_elements() != 1 or innermost_edge.data.dynamic
                    or innermost_edge.data.volume != 1):
                return False

            # Check if any of the maps has a dynamic range
            # These cases can potentially work but some nodes (and perhaps
            # tasklets) need to be replicated, which are difficult to track.
            for pe in mpath:
                node = pe.dst if expr_index == 0 else graph.entry_node(pe.src)
                if isinstance(node, nodes.MapEntry) and sdutil.has_dynamic_map_inputs(graph, node):
                    return False

        # If already applied on this memlet and this is the I/O component, skip
        # (the maps created by ``apply`` are labeled with a ``__s`` prefix)
        if expr_index == 0:
            other_node = graph.node(candidate[StreamingMemory.entry])
        else:
            other_node = graph.node(candidate[StreamingMemory.exit])
            other_node = graph.entry_node(other_node)
        if other_node.label.startswith('__s'):
            return False

        return True

    def apply(self, sdfg: SDFG) -> List[nodes.AccessNode]:
        """
        Replace the accesses of the matched node by transient streams and
        build separate read/write components that move data between the
        original array and those streams.

        :param sdfg: The SDFG to transform.
        :return: The list of stream access nodes added for the newly created
                 read/write components.
        """
        state = sdfg.node(self.state_id)
        dnode: nodes.AccessNode = self.access(sdfg)
        if self.expr_index == 0:
            edges = state.out_edges(dnode)
        else:
            edges = state.in_edges(dnode)

        # To understand how many components we need to create, all map ranges
        # throughout memlet paths must match exactly. We thus create a
        # dictionary of unique ranges
        mapping: Dict[Tuple[subsets.Range], List[gr.MultiConnectorEdge[mm.Memlet]]] = defaultdict(list)
        ranges = {}
        for edge in edges:
            mpath = state.memlet_path(edge)
            ranges[edge] = _collect_map_ranges(state, mpath)
            mapping[tuple(r[1] for r in ranges[edge])].append(edge)

        # Collect all edges with the same memory access pattern
        components_to_create: Dict[Tuple[symbolic.SymbolicType],
                                   List[gr.MultiConnectorEdge[mm.Memlet]]] = defaultdict(list)
        for edges_with_same_range in mapping.values():
            for edge in edges_with_same_range:
                # Get memlet path and innermost edge
                # (deep-copied, so later memlet replacement on the path does
                # not affect the stored copy)
                mpath = state.memlet_path(edge)
                innermost_edge = copy.deepcopy(mpath[-1] if self.expr_index == 0 else mpath[0])

                # Store memlets of the same access in the same component
                # Each entry is a pair (copied innermost edge, outermost edge)
                expr = _canonicalize_memlet(innermost_edge.data, ranges[edge])
                components_to_create[expr].append((innermost_edge, edge))
        components = list(components_to_create.values())

        # Split out components that have dependencies between them to avoid
        # deadlocks
        if self.expr_index == 0:
            ccs_to_add = []
            for i, component in enumerate(components):
                edges_to_remove = set()
                for cedge in component:
                    # A path from another member's destination to this one
                    # means this edge is moved into a component of its own
                    if any(nx.has_path(state.nx, o[1].dst, cedge[1].dst) for o in component if o is not cedge):
                        ccs_to_add.append([cedge])
                        edges_to_remove.add(cedge)
                if edges_to_remove:
                    components[i] = [c for c in component if c not in edges_to_remove]
            components.extend(ccs_to_add)
        # End of split

        desc = sdfg.arrays[dnode.data]

        # Create new streams of shape 1
        streams = {}
        mpaths = {}
        for edge in edges:
            name, newdesc = sdfg.add_stream(dnode.data,
                                            desc.dtype,
                                            buffer_size=self.buffer_size,
                                            storage=self.storage,
                                            transient=True,
                                            find_new_name=True)
            streams[edge] = name
            mpath = state.memlet_path(edge)
            mpaths[edge] = mpath

            # Replace memlets in path with stream access
            for e in mpath:
                e.data = mm.Memlet(data=name, subset='0', other_subset=e.data.other_subset)
                if isinstance(e.src, nodes.NestedSDFG):
                    e.data.dynamic = True
                    _streamify_recursive(e.src, e.src_conn, newdesc)
                if isinstance(e.dst, nodes.NestedSDFG):
                    e.data.dynamic = True
                    _streamify_recursive(e.dst, e.dst_conn, newdesc)

            # Replace access node and memlet tree with one access
            if self.expr_index == 0:
                replacement = state.add_read(name)
                state.remove_edge(edge)
                state.add_edge(replacement, edge.src_conn, edge.dst, edge.dst_conn, edge.data)
            else:
                replacement = state.add_write(name)
                state.remove_edge(edge)
                state.add_edge(edge.src, edge.src_conn, replacement, edge.dst_conn, edge.data)

        # Make read/write components
        ionodes = []
        for component in components:

            # Pick the first edge as the edge to make the component from
            innermost_edge, outermost_edge = component[0]
            mpath = mpaths[outermost_edge]
            mapname = streams[outermost_edge]
            innermost_edge.data.other_subset = None

            # Get edge data and streams
            if self.expr_index == 0:
                opname = 'read'
                path = [e.dst for e in mpath[:-1]]
                rmemlets = [(dnode, '__inp', innermost_edge.data)]
                wmemlets = []
                for i, (_, edge) in enumerate(component):
                    name = streams[edge]
                    ionode = state.add_write(name)
                    ionodes.append(ionode)
                    wmemlets.append((ionode, '__out%d' % i, mm.Memlet(data=name, subset='0')))
                # One input element is copied to every output stream
                code = '\n'.join('__out%d = __inp' % i for i in range(len(component)))
            else:
                # More than one input stream might mean a data race, so we only
                # address the first one in the tasklet code
                if len(component) > 1:
                    warnings.warn(f'More than one input found for the same index for {dnode.data}')
                opname = 'write'
                path = [state.entry_node(e.src) for e in reversed(mpath[1:])]
                wmemlets = [(dnode, '__out', innermost_edge.data)]
                rmemlets = []
                for i, (_, edge) in enumerate(component):
                    name = streams[edge]
                    ionode = state.add_read(name)
                    ionodes.append(ionode)
                    rmemlets.append((ionode, '__inp%d' % i, mm.Memlet(data=name, subset='0')))
                code = '__out = __inp0'

            # Create map structure for read/write component
            # Each scope on the original path is replicated with the same
            # parameters, ranges and schedule; the '__s' label prefix is what
            # can_be_applied checks to skip these generated components.
            maps = []
            for entry in path:
                map: nodes.Map = entry.map
                maps.append(
                    state.add_map(f'__s{opname}_{mapname}', [(p, r) for p, r in zip(map.params, map.range)],
                                  map.schedule))
            tasklet = state.add_tasklet(
                f'{opname}_{mapname}',
                {m[1]
                 for m in rmemlets},
                {m[1]
                 for m in wmemlets},
                code,
            )
            for node, cname, memlet in rmemlets:
                state.add_memlet_path(node, *(me for me, _ in maps), tasklet, dst_conn=cname, memlet=memlet)
            for node, cname, memlet in wmemlets:
                state.add_memlet_path(tasklet, *(mx for _, mx in reversed(maps)), node, src_conn=cname, memlet=memlet)

        return ionodes
class OuterProductOperation(pm.Transformation):
    """ Pattern detector for maps that compute an outer product. Applying the
        transformation does not modify the SDFG.
    """

    map_entry = pm.PatternNode(nodes.MapEntry)

    @staticmethod
    def expressions():
        return [sdutil.node_path_graph(OuterProductOperation.map_entry)]

    @staticmethod
    def can_be_applied(graph: dace.SDFGState,
                       candidate: Dict[pm.PatternNode, int],
                       expr_index: int,
                       sdfg: dace.SDFG,
                       permissive: bool = False):
        entry = graph.node(candidate[OuterProductOperation.map_entry])
        map_exit = graph.exit_node(entry)
        params = [dace.symbol(p) for p in entry.map.params]

        # Group the subsets read inside the map by data descriptor
        reads = {}
        for edge in graph.out_edges(entry):
            memlet = edge.data
            if not memlet.data:
                continue
            reads.setdefault(sdfg.arrays[memlet.data], []).append(memlet.subset)

        # Every non-trivial input must be indexed purely by symbols and must
        # leave at least one map parameter unused
        found = False
        for desc, accessed in reads.items():
            if isinstance(desc, dace.data.Scalar):
                continue
            if not isinstance(desc, (dace.data.Array, dace.data.View)):
                return False
            if list(desc.shape) == [1]:
                continue
            for subset in accessed:
                remaining = set(params)
                for idx in subset.min_element():
                    if not isinstance(idx, sympy.Symbol):
                        return False
                    remaining.discard(idx)
                if not remaining:
                    return False
                found = True

        # Group the subsets written at the map exit by data descriptor
        writes = {}
        for edge in graph.in_edges(map_exit):
            memlet = edge.data
            if memlet.wcr:
                return False
            writes.setdefault(sdfg.arrays[memlet.data], []).append(memlet.subset)

        # Every output must be a single element indexed by all map parameters
        for desc, accessed in writes.items():
            if not isinstance(desc, (dace.data.Array, dace.data.View)):
                return False
            for subset in accessed:
                if subset.num_elements() != 1:
                    return False
                remaining = set(params)
                for idx in subset.min_element():
                    remaining.discard(idx)
                if remaining:
                    return False

        return found

    @staticmethod
    def match_to_str(graph: dace.SDFGState, candidate: Dict[pm.PatternNode, int]) -> str:
        matched_map = graph.node(candidate[OuterProductOperation.map_entry]).map
        label = matched_map.label
        return label + ': ' + str(matched_map.params)

    def apply(self, sdfg: dace.SDFG):
        # Detection only: nothing is changed
        return None
class RedundantComm2D(pm.Transformation):
    """ Redundant communication removal: matches data that is passed through
        a gather and then immediately through a scatter, with the intermediate
        result never used anywhere else, and removes both operations.
    """

    in_array = pm.PatternNode(nodes.AccessNode)
    gather = pm.PatternNode(nodes.Tasklet)
    mid_array = pm.PatternNode(nodes.AccessNode)
    scatter = pm.PatternNode(nodes.Tasklet)
    out_array = pm.PatternNode(nodes.AccessNode)

    @staticmethod
    def expressions():
        cls = RedundantComm2D
        chain = sdutil.node_path_graph(cls.in_array, cls.gather, cls.mid_array, cls.scatter, cls.out_array)
        return [chain]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, permissive=False):
        # The gather node is recognized by its block-size input connector
        gather_node = graph.nodes()[candidate[RedundantComm2D.gather]]
        if '_block_sizes' not in gather_node.in_connectors:
            return False

        # The scatter node is recognized by its global-descriptor output
        scatter_node = graph.nodes()[candidate[RedundantComm2D.scatter]]
        if '_gdescriptor' not in scatter_node.out_connectors:
            return False

        first = graph.nodes()[candidate[RedundantComm2D.in_array]]
        last = graph.nodes()[candidate[RedundantComm2D.out_array]]
        first_desc = first.desc(sdfg)
        last_desc = last.desc(sdfg)

        # Only two-dimensional data with matching shapes qualifies
        if len(first_desc.shape) != 2:
            return False
        return bool(first_desc.shape == last_desc.shape)

    @staticmethod
    def match_to_str(graph, candidate):
        matched = graph.nodes()[candidate[RedundantComm2D.in_array]]
        return f"Remove {matched!s}"

    def apply(self, sdfg):
        graph = sdfg.nodes()[self.state_id]
        in_array = self.in_array(sdfg)
        gather = self.gather(sdfg)
        mid_array = self.mid_array(sdfg)
        scatter = self.scatter(sdfg)
        out_array = self.out_array(sdfg)
        in_desc = sdfg.arrays[in_array.data]
        out_desc = sdfg.arrays[out_array.data]

        # Drop auxiliary inputs of the gather that feed nothing else
        for edge in graph.in_edges(gather):
            helper = edge.src
            if helper != in_array and graph.in_degree(helper) == 0 and graph.out_degree(helper) == 1:
                graph.remove_edge(edge)
                graph.remove_node(helper)

        # Drop auxiliary outputs of the scatter that are read by nothing else
        for edge in graph.out_edges(scatter):
            helper = edge.dst
            if helper != out_array and graph.in_degree(helper) == 1 and graph.out_degree(helper) == 0:
                graph.remove_edge(edge)
                graph.remove_node(helper)

        # Redirect every consumer of the output array to the input array
        for edge in graph.out_edges(out_array):
            for tree_edge in graph.memlet_tree(edge):
                if tree_edge.data.data == out_array.data:
                    tree_edge.data.data = in_array.data
            graph.remove_edge(edge)
            graph.add_edge(in_array, None, edge.dst, edge.dst_conn, dace.Memlet.from_array(in_array, in_desc))

        # Finally remove the now-unused chain
        for dead in (gather, mid_array, scatter, out_array):
            graph.remove_node(dead)
class ElementWiseArrayOperation2D(pm.Transformation):
    """ Distributes element-wise array operations.

        Matches a two-parameter map whose array inputs and outputs are
        two-dimensional and accessed one element at a time using all map
        parameters. Applying it inserts block-cyclic scatter/gather (and
        broadcast for scalars) library nodes around the map and shrinks the
        map range by the symbols ``Px`` and ``Py``.
    """

    _map_entry = pm.PatternNode(nodes.MapEntry)

    @staticmethod
    def expressions():
        return [sdutil.node_path_graph(ElementWiseArrayOperation2D._map_entry)]

    @staticmethod
    def can_be_applied(graph: dace.SDFGState,
                       candidate: Dict[pm.PatternNode, int],
                       expr_index: int,
                       sdfg: dace.SDFG,
                       permissive: bool = False):
        """
        Check that the matched map is a 2D element-wise operation.

        :param graph: The state containing the match.
        :param candidate: Mapping from pattern nodes to node IDs in ``graph``.
        :param expr_index: Index of the matched pattern expression (unused).
        :param sdfg: The SDFG containing the state.
        :param permissive: Not read by this check.
        :return: True if the transformation can be applied.
        """

        map_entry = graph.node(candidate[ElementWiseArrayOperation2D._map_entry])
        map_exit = graph.exit_node(map_entry)
        params = [dace.symbol(p) for p in map_entry.map.params]
        if len(params) != 2:
            return False
        # The symbols introduced by ``apply`` must not already occur in the
        # map range
        if "commsize" in map_entry.map.range.free_symbols:
            return False
        if "Px" in map_entry.map.range.free_symbols:
            return False
        if "Py" in map_entry.map.range.free_symbols:
            return False

        # Group the subsets read inside the map by data descriptor
        inputs = dict()
        for _, _, _, _, m in graph.out_edges(map_entry):
            if not m.data:
                continue
            desc = sdfg.arrays[m.data]
            if desc not in inputs.keys():
                inputs[desc] = []
            inputs[desc].append(m.subset)

        for desc, accesses in inputs.items():
            if isinstance(desc, dace.data.Scalar):
                continue
            elif isinstance(desc, (dace.data.Array, dace.data.View)):
                # Single-element arrays are skipped like scalars
                if list(desc.shape) == [1]:
                    continue
                if len(desc.shape) != 2:
                    return False
                for a in accesses:
                    if a.num_elements() != 1:
                        return False
                    # Every map parameter must appear among the indices
                    indices = a.min_element()
                    unmatched_indices = set(params)
                    for idx in indices:
                        if idx in unmatched_indices:
                            unmatched_indices.remove(idx)
                    if len(unmatched_indices) > 0:
                        return False
            else:
                return False

        # Group the subsets written at the map exit by data descriptor
        outputs = dict()
        for _, _, _, _, m in graph.in_edges(map_exit):
            if m.wcr:
                return False
            desc = sdfg.arrays[m.data]
            if desc not in outputs.keys():
                outputs[desc] = []
            outputs[desc].append(m.subset)

        for desc, accesses in outputs.items():
            if isinstance(desc, (dace.data.Array, dace.data.View)):
                if len(desc.shape) != 2:
                    return False
                for a in accesses:
                    if a.num_elements() != 1:
                        return False
                    indices = a.min_element()
                    unmatched_indices = set(params)
                    for idx in indices:
                        if idx in unmatched_indices:
                            unmatched_indices.remove(idx)
                    if len(unmatched_indices) > 0:
                        return False
            else:
                return False

        return True

    @staticmethod
    def match_to_str(graph: dace.SDFGState, candidate: Dict[pm.PatternNode, int]) -> str:
        map_entry = graph.node(candidate[ElementWiseArrayOperation2D._map_entry])
        return map_entry.map.label + ': ' + str(map_entry.map.params)

    def apply(self, sdfg: dace.SDFG):
        """
        Insert scatter/broadcast nodes before the map, a gather node after
        it, and replace the map's parameters and range by the local ones.

        :param sdfg: The SDFG to transform.
        :raises NotImplementedError: If a map input/output is not an access
                                     node of a supported descriptor type, or
                                     if the memlet does not cover the whole
                                     array.
        """
        graph = sdfg.nodes()[self.state_id]
        map_entry = graph.nodes()[self.subgraph[self._map_entry]]
        map_exit = graph.exit_node(map_entry)

        sz = dace.symbol('commsize', dtype=dace.int32, integer=True, positive=True)
        Px = dace.symbol('Px', dtype=dace.int32, integer=True, positive=True)
        Py = dace.symbol('Py', dtype=dace.int32, integer=True, positive=True)

        def _prod(sequence):
            return reduce(lambda a, b: a * b, sequence, 1)

        # NOTE: Maps with step in their ranges are currently not supported
        # NOTE(review): can_be_applied only accepts two-parameter maps, so the
        # else branch looks unreachable through the normal matching path.
        # ``strides`` is assigned in both branches but not read again in this
        # method.
        if len(map_entry.map.params) == 2:
            params = map_entry.map.params
            ranges = [None] * 2
            b, e, _ = map_entry.map.range[0]
            ranges[0] = (0, (e - b + 1) / Px - 1, 1)
            b, e, _ = map_entry.map.range[1]
            ranges[1] = (0, (e - b + 1) / Py - 1, 1)
            strides = [1]
        else:
            params = ['__iflat']
            sizes = map_entry.map.range.size_exact()
            total_size = _prod(sizes)
            ranges = [(0, (total_size) / sz - 1, 1)]
            strides = [_prod(sizes[i + 1:]) for i in range(len(sizes))]

        # Transient scalar set to 0 by a tasklet; used as the '_root' input of
        # the broadcast nodes below
        root_name = sdfg.temp_data_name()
        sdfg.add_scalar(root_name, dace.int32, transient=True)
        root_node = graph.add_access(root_name)
        root_tasklet = graph.add_tasklet('_set_root_', {}, {'__out'}, '__out = 0')
        graph.add_edge(root_tasklet, '__out', root_node, None, dace.Memlet.simple(root_name, '0'))

        from dace.libraries.mpi import Bcast
        from dace.libraries.pblas import BlockCyclicScatter, BlockCyclicGather

        # Collect the access nodes feeding the map; each must be read in full
        inputs = set()
        for src, _, _, _, m in graph.in_edges(map_entry):
            if not isinstance(src, nodes.AccessNode):
                raise NotImplementedError
            desc = src.desc(sdfg)
            if not isinstance(desc, (data.Scalar, data.Array)):
                raise NotImplementedError
            if list(desc.shape) != m.src_subset.size_exact():
                # Second attempt
                # TODO: We need a solution for symbols not matching
                if str(list(desc.shape)) != str(m.src_subset.size_exact()):
                    raise NotImplementedError
            inputs.add(src)

        for inp in inputs:
            desc = inp.desc(sdfg)

            if isinstance(desc, data.Scalar):
                # Scalars are broadcast; the map then reads the local copy
                local_access = graph.add_access(inp.data)
                bcast_node = Bcast('_Bcast_')
                graph.add_edge(inp, None, bcast_node, '_inbuffer', dace.Memlet.from_array(inp.data, desc))
                graph.add_edge(root_node, None, bcast_node, '_root', dace.Memlet.simple(root_name, '0'))
                graph.add_edge(bcast_node, '_outbuffer', local_access, None, dace.Memlet.from_array(inp.data, desc))
                for e in graph.edges_between(inp, map_entry):
                    graph.add_edge(local_access, None, map_entry, e.dst_conn, dace.Memlet.from_array(inp.data, desc))
                    graph.remove_edge(e)

            elif isinstance(desc, data.Array):

                # Local block of shape (shape[0] // Px, shape[1] // Py)
                local_name, local_arr = sdfg.add_temp_transient([(desc.shape[0]) // Px, (desc.shape[1]) // Py],
                                                                dtype=desc.dtype,
                                                                storage=desc.storage)
                local_access = graph.add_access(local_name)
                # Two-element array holding the block sizes, filled by a tasklet
                bsizes_name, bsizes_arr = sdfg.add_temp_transient((2, ), dtype=dace.int32)
                bsizes_access = graph.add_access(bsizes_name)
                bsizes_tasklet = nodes.Tasklet(
                    '_set_bsizes_', {}, {'__out'},
                    "__out[0] = {x}; __out[1] = {y}".format(x=(desc.shape[0]) // Px, y=(desc.shape[1]) // Py))
                graph.add_edge(bsizes_tasklet, '__out', bsizes_access, None,
                               dace.Memlet.from_array(bsizes_name, bsizes_arr))
                # Nine-element arrays receiving the scatter node's
                # '_gdescriptor' and '_ldescriptor' outputs
                gdesc_name, gdesc_arr = sdfg.add_temp_transient((9, ), dtype=dace.int32)
                gdesc_access = graph.add_access(gdesc_name)
                ldesc_name, ldesc_arr = sdfg.add_temp_transient((9, ), dtype=dace.int32)
                ldesc_access = graph.add_access(ldesc_name)
                scatter_node = BlockCyclicScatter('_Scatter_')
                graph.add_edge(inp, None, scatter_node, '_inbuffer', dace.Memlet.from_array(inp.data, desc))
                graph.add_edge(bsizes_access, None, scatter_node, '_block_sizes',
                               dace.Memlet.from_array(bsizes_name, bsizes_arr))
                graph.add_edge(scatter_node, '_outbuffer', local_access, None,
                               dace.Memlet.from_array(local_name, local_arr))
                graph.add_edge(scatter_node, '_gdescriptor', gdesc_access, None,
                               dace.Memlet.from_array(gdesc_name, gdesc_arr))
                graph.add_edge(scatter_node, '_ldescriptor', ldesc_access, None,
                               dace.Memlet.from_array(ldesc_name, ldesc_arr))
                # Feed the map from the local block instead of the full array
                for e in graph.edges_between(inp, map_entry):
                    graph.add_edge(local_access, None, map_entry, e.dst_conn,
                                   dace.Memlet.from_array(local_name, local_arr))
                    graph.remove_edge(e)
                # Rename the data of the memlets leaving the map entry
                for e in graph.out_edges(map_entry):
                    if e.data.data == inp.data:
                        e.data.data = local_name

            else:
                raise NotImplementedError

        # Collect the access nodes written by the map; each must be written in
        # full
        outputs = set()
        for _, _, dst, _, m in graph.out_edges(map_exit):
            if not isinstance(dst, nodes.AccessNode):
                raise NotImplementedError
            desc = dst.desc(sdfg)
            if not isinstance(desc, data.Array):
                raise NotImplementedError
            try:
                if list(desc.shape) != m.dst_subset.size_exact():
                    # Second attempt
                    # TODO: We need a solution for symbols not matching
                    if str(list(desc.shape)) != str(m.dst_subset.size_exact()):
                        raise NotImplementedError
            except AttributeError:
                # Fall back to ``m.subset`` when the ``dst_subset`` check
                # raises AttributeError
                if list(desc.shape) != m.subset.size_exact():
                    # Second attempt
                    # TODO: We need a solution for symbols not matching
                    if str(list(desc.shape)) != str(m.subset.size_exact()):
                        raise NotImplementedError
            outputs.add(dst)

        for out in outputs:
            desc = out.desc(sdfg)
            if isinstance(desc, data.Scalar):
                raise NotImplementedError
            elif isinstance(desc, data.Array):
                local_name, local_arr = sdfg.add_temp_transient([(desc.shape[0]) // Px, (desc.shape[1]) // Py],
                                                                dtype=desc.dtype,
                                                                storage=desc.storage)
                local_access = graph.add_access(local_name)
                bsizes_name, bsizes_arr = sdfg.add_temp_transient((2, ), dtype=dace.int32)
                bsizes_access = graph.add_access(bsizes_name)
                bsizes_tasklet = nodes.Tasklet(
                    '_set_bsizes_', {}, {'__out'},
                    "__out[0] = {x}; __out[1] = {y}".format(x=(desc.shape[0]) // Px, y=(desc.shape[1]) // Py))
                graph.add_edge(bsizes_tasklet, '__out', bsizes_access, None,
                               dace.Memlet.from_array(bsizes_name, bsizes_arr))
                # NOTE: despite its name, this variable holds the gather node
                scatter_node = BlockCyclicGather('_Gather_')
                graph.add_edge(local_access, None, scatter_node, '_inbuffer',
                               dace.Memlet.from_array(local_name, local_arr))
                graph.add_edge(bsizes_access, None, scatter_node, '_block_sizes',
                               dace.Memlet.from_array(bsizes_name, bsizes_arr))
                graph.add_edge(scatter_node, '_outbuffer', out, None, dace.Memlet.from_array(out.data, desc))

                # Let the map write the local block instead of the full array
                for e in graph.edges_between(map_exit, out):
                    graph.add_edge(map_exit, e.src_conn, local_access, None,
                                   dace.Memlet.from_array(local_name, local_arr))
                    graph.remove_edge(e)
                # Rename the data of the memlets entering the map exit
                for e in graph.in_edges(map_exit):
                    if e.data.data == out.data:
                        e.data.data = local_name
            else:
                raise NotImplementedError

        # Finally switch the map to the local parameters and range
        map_entry.map.params = params
        map_entry.map.range = subsets.Range(ranges)
class FPGATransformState(transformation.MultiStateTransformation):
    """ Implements the FPGATransformState transformation.

        Adds a state before and/or after the matched state that copies the
        relevant arrays to and from new ``fpga_``-prefixed transient arrays
        with FPGA global storage, and redirects the matched state to use
        those arrays.
    """

    state = transformation.PatternNode(sd.SDFGState)

    @classmethod
    def expressions(cls):
        return [sdutil.node_path_graph(cls.state)]

    def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
        """
        Check that the matched state contains nothing that prevents it from
        being moved to the FPGA.

        :param graph: The graph in which the state was matched. NOTE: this
                      name is rebound by the first loop below.
        :param expr_index: Index of the matched pattern expression (unused).
        :param sdfg: The SDFG containing the state.
        :param permissive: Not read by this check.
        :return: True if the transformation can be applied.
        """
        state = self.state

        for node, graph in state.all_nodes_recursive():
            # Consume scopes are currently unsupported
            if isinstance(node, (nodes.ConsumeEntry, nodes.ConsumeExit)):
                return False

            # Streams have strict conditions due to code generator limitations
            if (isinstance(node, nodes.AccessNode) and isinstance(graph.parent.arrays[node.data], data.Stream)):
                nodedesc = graph.parent.arrays[node.data]
                sdict = graph.scope_dict()
                if nodedesc.storage in [
                        dtypes.StorageType.CPU_Heap, dtypes.StorageType.CPU_Pinned,
                        dtypes.StorageType.CPU_ThreadLocal
                ]:
                    return False

                # Cannot allocate FIFO from CPU code
                if sdict[node] is None:
                    return False

                # Arrays of streams cannot have symbolic size on FPGA
                if dace.symbolic.issymbolic(nodedesc.total_size, graph.parent.constants):
                    return False

                # Streams cannot be unbounded on FPGA
                if nodedesc.buffer_size < 1:
                    return False

        for node in state.nodes():

            # Top-level access nodes must use default or register storage
            if (isinstance(node, nodes.AccessNode)
                    and node.desc(sdfg).storage not in (dtypes.StorageType.Default, dtypes.StorageType.Register)):
                return False

            if not isinstance(node, nodes.MapEntry):
                continue

            map_entry = node
            candidate_map = map_entry.map

            # Map schedules that are disallowed to transform to FPGAs
            if (candidate_map.schedule == dtypes.ScheduleType.MPI
                    or candidate_map.schedule == dtypes.ScheduleType.GPU_Device
                    or candidate_map.schedule == dtypes.ScheduleType.FPGA_Device
                    or candidate_map.schedule == dtypes.ScheduleType.GPU_ThreadBlock):
                return False

            # Recursively check parent for FPGA schedules
            sdict = state.scope_dict()
            current_node = map_entry
            while current_node is not None:
                if (current_node.map.schedule == dtypes.ScheduleType.GPU_Device
                        or current_node.map.schedule == dtypes.ScheduleType.FPGA_Device
                        or current_node.map.schedule == dtypes.ScheduleType.GPU_ThreadBlock):
                    return False
                current_node = sdict[current_node]

        return True

    def apply(self, _, sdfg):
        """
        Create the copy-in/copy-out states and FPGA arrays and rewire the
        matched state to them.

        :param _: Unused graph argument.
        :param sdfg: The SDFG to transform.
        """
        state = self.state

        # Find source/sink (data) nodes that are relevant outside this FPGA
        # kernel
        shared_transients = set(sdfg.shared_transients())
        input_nodes = [
            n for n in sdutil.find_source_nodes(state)
            if isinstance(n, nodes.AccessNode) and (not sdfg.arrays[n.data].transient or n.data in shared_transients)
        ]
        output_nodes = [
            n for n in sdutil.find_sink_nodes(state)
            if isinstance(n, nodes.AccessNode) and (not sdfg.arrays[n.data].transient or n.data in shared_transients)
        ]

        # Maps an original data name to the result of sdfg.add_array for its
        # FPGA counterpart
        fpga_data = {}

        # Input nodes may also be nodes with WCR memlets
        # We have to recur across nested SDFGs to find them
        wcr_input_nodes = set()
        stack = []  # NOTE: not used anywhere else in this method

        parent_sdfg = {state: sdfg}  # Map states to their parent SDFG
        for node, graph in state.all_nodes_recursive():
            if isinstance(graph, dace.SDFG):
                parent_sdfg[node] = graph
            if isinstance(node, dace.sdfg.nodes.AccessNode):
                for e in graph.in_edges(node):
                    if e.data.wcr is not None:
                        trace = dace.sdfg.trace_nested_access(node, graph, parent_sdfg[graph])
                        for node_trace, memlet_trace, state_trace, sdfg_trace in trace:
                            # Find the name of the accessed node in our scope
                            if state_trace == state and sdfg_trace == sdfg:
                                _, outer_node = node_trace
                                if outer_node is not None:
                                    break
                        else:
                            # This does not trace back to the current state, so
                            # we don't care
                            continue
                        input_nodes.append(outer_node)
                        wcr_input_nodes.add(outer_node)

        if input_nodes:
            # create pre_state
            pre_state = sd.SDFGState('pre_' + state.label, sdfg)

            for node in input_nodes:

                if not isinstance(node, dace.sdfg.nodes.AccessNode):
                    continue
                desc = node.desc(sdfg)
                if not isinstance(desc, dace.data.Array):
                    # TODO: handle streams
                    continue

                if node.data in fpga_data:
                    fpga_array = fpga_data[node.data]
                elif node not in wcr_input_nodes:
                    fpga_array = sdfg.add_array('fpga_' + node.data,
                                                desc.shape,
                                                desc.dtype,
                                                transient=True,
                                                storage=dtypes.StorageType.FPGA_Global,
                                                allow_conflicts=desc.allow_conflicts,
                                                strides=desc.strides,
                                                offset=desc.offset)
                    # The location is moved from the original descriptor to
                    # the new FPGA array
                    fpga_array[1].location = copy.copy(desc.location)
                    desc.location.clear()
                    fpga_data[node.data] = fpga_array

                # Copy the whole array into its FPGA counterpart in pre_state
                pre_node = pre_state.add_read(node.data)
                pre_fpga_node = pre_state.add_write('fpga_' + node.data)
                mem = memlet.Memlet(data=node.data, subset=subsets.Range.from_array(desc))
                pre_state.add_edge(pre_node, None, pre_fpga_node, None, mem)

                # WCR nodes stay in place; plain inputs are replaced by a read
                # of the FPGA array
                if node not in wcr_input_nodes:
                    fpga_node = state.add_read('fpga_' + node.data)
                    sdutil.change_edge_src(state, node, fpga_node)
                    state.remove_node(node)

            # Insert pre_state in front of the matched state
            sdfg.add_node(pre_state)
            sdutil.change_edge_dest(sdfg, state, pre_state)
            sdfg.add_edge(pre_state, state, sd.InterstateEdge())

        if output_nodes:

            post_state = sd.SDFGState('post_' + state.label, sdfg)

            for node in output_nodes:

                if not isinstance(node, dace.sdfg.nodes.AccessNode):
                    continue
                desc = node.desc(sdfg)
                if not isinstance(desc, dace.data.Array):
                    # TODO: handle streams
                    continue

                if node.data in fpga_data:
                    fpga_array = fpga_data[node.data]
                else:
                    fpga_array = sdfg.add_array('fpga_' + node.data,
                                                desc.shape,
                                                desc.dtype,
                                                transient=True,
                                                storage=dtypes.StorageType.FPGA_Global,
                                                allow_conflicts=desc.allow_conflicts,
                                                strides=desc.strides,
                                                offset=desc.offset)
                    fpga_array[1].location = copy.copy(desc.location)
                    desc.location.clear()
                    fpga_data[node.data] = fpga_array
                # fpga_node = type(node)(fpga_array)

                # Copy the whole FPGA array back to the original in post_state
                post_node = post_state.add_write(node.data)
                post_fpga_node = post_state.add_read('fpga_' + node.data)
                mem = memlet.Memlet(f"fpga_{node.data}", None, subsets.Range.from_array(desc))
                post_state.add_edge(post_fpga_node, None, post_node, None, mem)

                fpga_node = state.add_write('fpga_' + node.data)
                sdutil.change_edge_dest(state, node, fpga_node)
                state.remove_node(node)

            # Insert post_state after the matched state
            sdfg.add_node(post_state)
            sdutil.change_edge_src(sdfg, state, post_state)
            sdfg.add_edge(state, post_state, sd.InterstateEdge())

        # propagate memlet info from a nested sdfg
        # (memlets in the state that still name an original array are
        # renamed to its FPGA counterpart)
        for src, src_conn, dst, dst_conn, mem in state.edges():
            if mem.data is not None and mem.data in fpga_data:
                mem.data = 'fpga_' + mem.data
        fpga_update(sdfg, state, 0)
class RefineNestedAccess(transformation.Transformation):
    """
    Reduces memlet shape when a memlet is connected to a nested SDFG, but not
    using all of the contents. Makes the outer memlet smaller in shape and
    ensures that the offsets in the nested SDFG start with zero.
    This helps with subsequent transformations on the outer SDFGs.

    For example, in the following program::

        @dace.program
        def func_a(y):
            return y[1:5] + 1

        @dace.program
        def main(x: dace.float32[N]):
            return func_a(x)

    The memlet pointing to ``func_a`` will contain all of ``x`` (``x[0:N]``),
    and it is offset to ``y[1:5]`` in the function, with ``y``'s size being
    ``N``. After the transformation, the memlet connected to the nested SDFG
    of ``func_a`` would contain ``x[1:5]`` directly and the internal ``y``
    array would have a size of 4, accessed as ``y[0:4]``.
    """

    # Pattern: a single nested SDFG node
    nsdfg = transformation.PatternNode(nodes.NestedSDFG)

    @staticmethod
    def annotates_memlets():
        """ Always returns True. """
        return True

    @staticmethod
    def expressions():
        """ Returns the single pattern graph: one nested SDFG node. """
        return [sdutil.node_path_graph(RefineNestedAccess.nsdfg)]

    @staticmethod
    def _candidates(
        state: SDFGState, nsdfg: nodes.NestedSDFG
    ) -> Tuple[Dict[str, Tuple[Memlet, Set[int]]], Dict[str, Tuple[Memlet,
                                                                   Set[int]]]]:
        """
        Finds the data containers of the nested SDFG whose accesses can be
        refined.

        Every access node of non-transient data in every state of
        ``nsdfg.sdfg`` is inspected. For each memlet data name, the first
        memlet seen is recorded together with the set of dimension indices
        in which all further memlets to the same name have an identical
        range. Edges entering an access node populate the output candidates,
        edges leaving one populate the input candidates.

        :param state: The state that contains the nested SDFG node.
        :param nsdfg: The nested SDFG node to inspect.
        :return: A 2-tuple ``(inputs, outputs)`` of dictionaries that map a
                 data name to ``(deep-copied memlet, dimension indices)``.
                 Names that were placed in the ignore set are left out.
        """
        # Internally each candidate also carries the state it was found in
        in_candidates: Dict[str, Tuple[Memlet, SDFGState, Set[int]]] = {}
        out_candidates: Dict[str, Tuple[Memlet, SDFGState, Set[int]]] = {}
        # Names that turned out not to be refinable
        ignore = set()
        for nstate in nsdfg.sdfg.nodes():
            for dnode in nstate.data_nodes():
                if nsdfg.sdfg.arrays[dnode.data].transient:
                    continue

                # For now we only detect one element
                for e in nstate.in_edges(dnode):
                    # If more than one unique element detected, remove from
                    # candidates
                    if e.data.data in out_candidates:
                        memlet, ns, indices = out_candidates[e.data.data]
                        # Try to find dimensions in which there is a mismatch
                        # and remove them from list
                        for i, (s1, s2) in enumerate(
                                zip(e.data.subset, memlet.subset)):
                            if s1 != s2 and i in indices:
                                indices.remove(i)
                        # No dimension left in which all accesses agree
                        if len(indices) == 0:
                            ignore.add(e.data.data)
                        out_candidates[e.data.data] = (memlet, ns, indices)
                        continue
                    # First occurrence: all dimensions are candidates
                    out_candidates[e.data.data] = (e.data, nstate,
                                                   set(
                                                       range(len(
                                                           e.data.subset))))
                for e in nstate.out_edges(dnode):
                    # If more than one unique element detected, remove from
                    # candidates
                    if e.data.data in in_candidates:
                        memlet, ns, indices = in_candidates[e.data.data]
                        # Try to find dimensions in which there is a mismatch
                        # and remove them from list
                        for i, (s1, s2) in enumerate(
                                zip(e.data.subset, memlet.subset)):
                            if s1 != s2 and i in indices:
                                indices.remove(i)
                        # No dimension left in which all accesses agree
                        if len(indices) == 0:
                            ignore.add(e.data.data)
                        in_candidates[e.data.data] = (memlet, ns, indices)
                        continue
                    # First occurrence: all dimensions are candidates
                    in_candidates[e.data.data] = (e.data, nstate,
                                                  set(range(len(
                                                      e.data.subset))))

        # TODO: Check in_candidates in interstate edges as well

        # Check in/out candidates
        # Data that is both read and written must agree on the ranges of
        # the dimensions common to both index sets
        for cand in in_candidates.keys() & out_candidates.keys():
            s1, nstate1, ind1 = in_candidates[cand]
            s2, nstate2, ind2 = out_candidates[cand]
            indices = ind1 & ind2
            if any(s1.subset[ind] != s2.subset[ind] for ind in indices):
                ignore.add(cand)
            in_candidates[cand] = (s1, nstate1, indices)
            out_candidates[cand] = (s2, nstate2, indices)

        # Ensure minimum elements of candidates do not begin with zero
        def _check_cand(candidates, outer_edges):
            # ``outer_edges`` is called as ``outer_edges(nsdfg, cname)`` and
            # its first result is taken as the outer edge of the connector.
            for cname, (cand, nstate, indices) in candidates.items():
                # Nothing to refine if all tracked dimensions start at zero
                if all(me == 0
                       for i, me in enumerate(cand.subset.min_element())
                       if i in indices):
                    ignore.add(cname)
                    continue

                # Ensure outer memlets begin with 0
                outer_edge = next(iter(outer_edges(nsdfg, cname)))
                if any(me != 0 for i, me in enumerate(
                        outer_edge.data.subset.min_element()) if i in indices):
                    ignore.add(cname)
                    continue

                # Check w.r.t. loops
                if len(nstate.ranges) > 0:
                    # Re-annotate loop ranges, in case someone changed them
                    # TODO: Move out of here!
                    nstate.ranges = {}
                    from dace.sdfg.propagation import _annotate_loop_ranges
                    _annotate_loop_ranges(nsdfg.sdfg, [])

                    # Propagate the candidate memlet over the ranges
                    # annotated on its state
                    memlet = propagation.propagate_subset(
                        [cand], nsdfg.sdfg.arrays[cname],
                        sorted(nstate.ranges.keys()),
                        subsets.Range([
                            v.ndrange()[0]
                            for _, v in sorted(nstate.ranges.items())
                        ]))
                    if all(me == 0
                           for i, me in enumerate(memlet.subset.min_element())
                           if i in indices):
                        ignore.add(cname)
                        continue

                    # Modify memlet to propagated one
                    candidates[cname] = (memlet, nstate, indices)
                else:
                    memlet = cand

                # If there are any symbols here that are not defined
                # in "defined_symbols"
                missing_symbols = (memlet.free_symbols -
                                   set(nsdfg.symbol_mapping.keys()))
                if missing_symbols:
                    ignore.add(cname)
                    continue

        _check_cand(in_candidates, state.in_edges_by_connector)
        _check_cand(out_candidates, state.out_edges_by_connector)

        # Return result, filtering out the states
        return ({
            k: (dc(v), ind)
            for k, (v, _, ind) in in_candidates.items() if k not in ignore
        }, {
            k: (dc(v), ind)
            for k, (v, _, ind) in out_candidates.items() if k not in ignore
        })

    @staticmethod
    def can_be_applied(graph: SDFGState,
                       candidate: Dict[transformation.PatternNode, int],
                       expr_index: int,
                       sdfg: SDFG,
                       strict: bool = False):
        """
        Matches if at least one input or output of the nested SDFG is
        refinable according to ``_candidates``.
        """
        nsdfg = graph.node(candidate[RefineNestedAccess.nsdfg])
        ic, oc = RefineNestedAccess._candidates(graph, nsdfg)
        return (len(ic) + len(oc)) > 0

    @staticmethod
    def match_to_str(graph, candidate):
        """ Describes a match by the label of the matched graph. """
        return graph.label

    def apply(self, sdfg):
        """
        Refines every candidate found by ``_candidates``: the outer memlet
        subset is replaced in the refinable dimensions, and the memlets and
        interstate-edge expressions inside the nested SDFG are offset
        accordingly.

        :param sdfg: The SDFG that contains the matched state.
        :raises NotImplementedError: If an interstate edge condition of the
                                     nested SDFG is not written in Python.
        """
        state: SDFGState = sdfg.nodes()[self.state_id]
        nsdfg_node: nodes.NestedSDFG = self.nsdfg(sdfg)
        nsdfg: SDFG = nsdfg_node.sdfg
        torefine_in, torefine_out = RefineNestedAccess._candidates(
            state, nsdfg_node)

        # Names whose internal accesses were already offset; shared by both
        # calls below so data that is read and written is offset only once
        refined = set()

        def _offset_refine(
                torefine: Dict[str, Tuple[Memlet, Set[int]]],
                outer_edges: Callable[[nodes.NestedSDFG, str],
                                      Iterable[MultiConnectorEdge[Memlet]]]):
            # Offset memlets inside negatively by "refine", modify outer
            # memlets to be "refine"
            for aname, (refine, indices) in torefine.items():
                outer_edge = next(iter(outer_edges(nsdfg_node, aname)))
                new_memlet = helpers.unsqueeze_memlet(refine, outer_edge.data)
                # Only the refinable dimensions take the new range; the
                # others keep the original outer range
                outer_edge.data.subset = subsets.Range([
                    ns if i in indices else os
                    for i, (os, ns) in enumerate(
                        zip(outer_edge.data.subset, new_memlet.subset))
                ])
                if aname in refined:
                    continue
                # Refine internal memlets
                for nstate in nsdfg.nodes():
                    for e in nstate.edges():
                        if e.data.data == aname:
                            e.data.subset.offset(refine.subset, True, indices)
                # Refine accesses in interstate edges
                refiner = ASTRefiner(aname, refine.subset, nsdfg, indices)
                for isedge in nsdfg.edges():
                    for k, v in isedge.data.assignments.items():
                        vast = ast.parse(v)
                        refiner.visit(vast)
                        isedge.data.assignments[k] = astutils.unparse(vast)
                    if isedge.data.condition.language is dtypes.Language.Python:
                        for i, stmt in enumerate(isedge.data.condition.code):
                            isedge.data.condition.code[i] = refiner.visit(stmt)
                    else:
                        raise NotImplementedError
                refined.add(aname)

        # Proceed symmetrically on incoming and outgoing edges
        _offset_refine(torefine_in, state.in_edges_by_connector)
        _offset_refine(torefine_out, state.out_edges_by_connector)
class MapTiling(transformation.Transformation):
    """ Implements the orthogonal tiling transformation.

        Orthogonal tiling is a type of nested map fission that creates tiles
        in every dimension of the matched Map.
    """

    map_entry = transformation.PatternNode(nodes.MapEntry)

    # Properties
    prefix = Property(dtype=str,
                      default="tile",
                      desc="Prefix for new range symbols")
    tile_sizes = ShapeProperty(dtype=tuple,
                               default=(128, 128, 128),
                               desc="Tile size per dimension")

    strides = ShapeProperty(
        dtype=tuple,
        default=tuple(),
        desc="Tile stride (enables overlapping tiles). If empty, matches tile")

    tile_offset = ShapeProperty(dtype=tuple,
                                default=None,
                                desc="Negative Stride offset per dimension",
                                allow_none=True)

    divides_evenly = Property(dtype=bool,
                              default=False,
                              desc="Tile size divides dimension length evenly")
    tile_trivial = Property(dtype=bool,
                            default=False,
                            desc="Tiles even if tile_size is 1")

    @staticmethod
    def annotates_memlets():
        """ Always returns True. """
        return True

    @staticmethod
    def expressions():
        """ Returns the single pattern graph: one map entry node. """
        return [sdutil.node_path_graph(MapTiling.map_entry)]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Every map entry matches; there are no further conditions. """
        return True

    @staticmethod
    def match_to_str(graph, candidate):
        """ Describes a match as ``<map label>: <map parameters>``. """
        map_entry = graph.nodes()[candidate[MapTiling.map_entry]]
        return map_entry.map.label + ': ' + str(map_entry.map.params)

    def apply(self, sdfg):
        """
        Strip-mines each dimension of the matched map in turn and collapses
        the resulting outer (tile) maps into one map.

        If fewer tile sizes, strides or offsets than map dimensions are
        given, the last value is reused for the remaining dimensions.
        A dimension whose size equals its tile size is left untouched.

        :param sdfg: The SDFG that contains the matched state.
        :return: The source node of the first incoming edge of the matched
                 map entry after the last strip-mining step, or ``None`` if
                 every dimension was skipped.
        """
        graph = sdfg.nodes()[self.state_id]

        # Strides fall back to the tile sizes unless exactly one stride per
        # tile size was provided
        tile_strides = self.tile_sizes
        if self.strides is not None and len(self.strides) == len(tile_strides):
            tile_strides = self.strides

        # Retrieve the map entry node.
        map_entry = graph.nodes()[self.subgraph[MapTiling.map_entry]]

        # Imported locally (presumably to avoid circular imports; verify
        # before moving to module level)
        from dace.transformation.dataflow.map_collapse import MapCollapse
        from dace.transformation.dataflow.strip_mining import StripMining
        stripmine_subgraph = {
            StripMining._map_entry: self.subgraph[MapTiling.map_entry]
        }
        sdfg_id = sdfg.sdfg_id
        last_map_entry = None
        # Number of dimensions that strip-mining removed from the inner map
        removed_maps = 0

        original_schedule = map_entry.schedule

        for dim_idx in range(len(map_entry.map.params)):
            # Reuse the last tile size/stride for surplus dimensions
            if dim_idx >= len(self.tile_sizes):
                tile_size = symbolic.pystr_to_symbolic(self.tile_sizes[-1])
                tile_stride = symbolic.pystr_to_symbolic(tile_strides[-1])
            else:
                tile_size = symbolic.pystr_to_symbolic(
                    self.tile_sizes[dim_idx])
                tile_stride = symbolic.pystr_to_symbolic(tile_strides[dim_idx])

            # handle offsets
            if not self.tile_offset:
                offset = 0
            elif dim_idx >= len(self.tile_offset):
                offset = self.tile_offset[-1]
            else:
                offset = self.tile_offset[dim_idx]

            # Account for dimensions already dropped from the inner map
            dim_idx -= removed_maps
            # If tile size is trivial, skip strip-mining map dimension
            if tile_size == map_entry.map.range.size()[dim_idx]:
                continue

            stripmine = StripMining(sdfg_id, self.state_id, stripmine_subgraph,
                                    self.expr_index)

            # Special case: Tile size of 1 should be omitted from inner map
            omit_from_inner = (tile_size == 1 and tile_stride == 1
                               and not self.tile_trivial)

            stripmine.dim_idx = dim_idx
            stripmine.new_dim_prefix = '' if omit_from_inner else self.prefix
            stripmine.tile_size = str(tile_size)
            stripmine.tile_stride = str(tile_stride)
            stripmine.divides_evenly = (True if omit_from_inner else
                                        self.divides_evenly)
            stripmine.tile_offset = str(offset)
            stripmine.apply(sdfg)
            if omit_from_inner:
                removed_maps += 1

            # apply to the new map the schedule of the original one
            map_entry.schedule = original_schedule

            # Merge the tile map created in this step with the one from the
            # previous steps
            if last_map_entry:
                new_map_entry = graph.in_edges(map_entry)[0].src
                mapcollapse_subgraph = {
                    MapCollapse._outer_map_entry:
                    graph.node_id(last_map_entry),
                    MapCollapse._inner_map_entry: graph.node_id(new_map_entry)
                }
                mapcollapse = MapCollapse(sdfg_id, self.state_id,
                                          mapcollapse_subgraph, 0)
                mapcollapse.apply(sdfg)
            last_map_entry = graph.in_edges(map_entry)[0].src
        return last_map_entry
class MapReduceFusion(pm.SingleStateTransformation):
    """ Implements the map-reduce-fusion transformation.
        Fuses a map with an immediately following reduction, where the array
        between the map and the reduction is not used anywhere else.
    """

    no_init = Property(
        dtype=bool,
        default=False,
        desc='If enabled, does not create initialization states '
        'for reduce nodes with identity')

    tasklet = pm.PatternNode(nodes.Tasklet)
    tmap_exit = pm.PatternNode(nodes.MapExit)
    in_array = pm.PatternNode(nodes.AccessNode)

    import dace.libraries.standard as stdlib  # Avoid import loop

    reduce = pm.PatternNode(stdlib.Reduce)
    out_array = pm.PatternNode(nodes.AccessNode)

    @classmethod
    def expressions(cls):
        """
        Returns the single pattern graph:
        tasklet -> map exit -> array -> reduce -> array.
        """
        return [
            sdutil.node_path_graph(cls.tasklet, cls.tmap_exit, cls.in_array,
                                   cls.reduce, cls.out_array)
        ]

    def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
        """
        Checks whether the matched map and reduction can be fused.

        :param graph: The state in which the pattern was matched.
        :param expr_index: Index of the matched pattern expression.
        :param sdfg: The SDFG that contains ``graph``.
        :param permissive: If True, skips the check that the intermediate
                           array is not accessed anywhere else.
        :return: True if the transformation can be applied.
        """
        tmap_exit = self.tmap_exit
        in_array = self.in_array
        reduce_node = self.reduce
        tasklet = self.tasklet

        # Make sure that the array is only accessed by the map and the reduce
        if any(e.src != tmap_exit for e in graph.in_edges(in_array)):
            return False
        if any(e.dst != reduce_node for e in graph.out_edges(in_array)):
            return False

        # Memlet with which the tasklet writes the intermediate array. If the
        # tasklet is connected to the map exit only through other data, this
        # is not a match (instead of failing with StopIteration).
        tasklet_edge = next((e
                             for e in graph.edges_between(tasklet, tmap_exit)
                             if e.data.data == in_array.data), None)
        if tasklet_edge is None:
            return False
        tmem = tasklet_edge.data

        # Make sure that the transient is not accessed anywhere else
        # in this state or other states
        if not permissive:
            num_accesses = sum(1 for n in graph.nodes()
                               if isinstance(n, nodes.AccessNode)
                               and n.data == in_array.data)
            if (num_accesses > 1
                    or in_array.data in sdfg.shared_transients()):
                return False

        # If memlet already has WCR and it is different from reduce node,
        # do not match
        if tmem.wcr is not None and tmem.wcr != reduce_node.wcr:
            return False

        # Verify that reduction ranges match tasklet map
        tout_memlet = graph.in_edges(in_array)[0].data
        rin_memlet = graph.out_edges(in_array)[0].data
        if tout_memlet.subset != rin_memlet.subset:
            return False

        return True

    def match_to_str(self, graph):
        """ Describes a match as ``tasklet -> map exit -> reduce``. """
        return ' -> '.join(
            str(node) for node in [self.tasklet, self.tmap_exit, self.reduce])

    def apply(self, graph: SDFGState, sdfg: SDFG):
        """
        Removes the intermediate array and the reduce node and lets the map
        write to the reduction output directly with a write-conflict
        resolution memlet.

        :param graph: The state in which the pattern was matched.
        :param sdfg: The SDFG that contains ``graph``.
        :raises RuntimeError: If no edge into the map exit carries the
                              intermediate array.
        """
        tmap_exit = self.tmap_exit
        in_array = self.in_array
        reduce_node = self.reduce
        out_array = self.out_array

        # The intermediate array and the reduce node are removed
        nodes_to_remove = [in_array, reduce_node]

        # Edge inside the map that writes the intermediate array
        memlet_edge = next((edge for edge in graph.in_edges(tmap_exit)
                            if edge.data.data == in_array.data), None)
        if memlet_edge is None:
            raise RuntimeError('Reduction memlet cannot be None')

        # Find which indices should be removed from new memlet
        # (no axes on the reduce node means all dimensions are reduced)
        input_edge = graph.in_edges(reduce_node)[0]
        axes = reduce_node.axes or list(range(len(input_edge.data.subset)))
        array_edge = graph.out_edges(reduce_node)[0]

        # Delete relevant edges and nodes
        graph.remove_nodes_from(nodes_to_remove)

        # Delete relevant data descriptors
        for node in nodes_to_remove:
            if isinstance(node, nodes.AccessNode):
                # Best effort: remove_data raises ValueError if the data
                # descriptor is used somewhere else, in which case it is kept
                try:
                    sdfg.remove_data(node.data)
                except ValueError:
                    pass

        # Filter out reduced dimensions from subset
        filtered_subset = [
            dim for i, dim in enumerate(memlet_edge.data.subset)
            if i not in axes
        ]
        if not filtered_subset:  # Output is a scalar
            filtered_subset = [(0, 0, 1)]

        # Modify edge from tasklet to map exit
        memlet_edge.data.data = out_array.data
        memlet_edge.data.wcr = reduce_node.wcr
        memlet_edge.data.subset = type(
            memlet_edge.data.subset)(filtered_subset)

        # Add edge from map exit to output array
        # (dst_conn[3:] strips the first three characters of the input
        # connector name to form the matching 'OUT_' connector)
        graph.add_edge(
            memlet_edge.dst, 'OUT_' + memlet_edge.dst_conn[3:],
            array_edge.dst, array_edge.dst_conn,
            Memlet.simple(array_edge.data.data,
                          array_edge.data.subset,
                          num_accesses=array_edge.data.num_accesses,
                          wcr_str=reduce_node.wcr))

        # Add initialization state as necessary
        if not self.no_init and reduce_node.identity is not None:
            init_state = sdfg.add_state_before(graph)
            init_state.add_mapped_tasklet(
                'freduce_init',
                [('o%d' % i, '%s:%s:%s' % (r[0], r[1] + 1, r[2]))
                 for i, r in enumerate(array_edge.data.subset)], {},
                '__out = %s' % reduce_node.identity, {
                    '__out':
                    Memlet.simple(
                        array_edge.data.data, ','.join([
                            'o%d' % i
                            for i in range(len(array_edge.data.subset))
                        ]))
                },
                external_edges=True)
class DoubleBuffering(transformation.SingleStateTransformation):
    """ Implements the double buffering pattern, which pipelines reading
        and processing data by creating a second copy of the memory.
        In particular, the transformation takes a 1D map and all internal
        (directly connected) transients, adds an additional dimension of size 2,
        and turns the map into a for loop that processes and reads the data in a
        double-buffered manner. Other memlets will not be transformed.
    """

    # Pattern: a map entry directly connected to an access node
    map_entry = transformation.PatternNode(nodes.MapEntry)
    transient = transformation.PatternNode(nodes.AccessNode)

    @classmethod
    def expressions(cls):
        """ Returns the single pattern graph: map entry -> access node. """
        return [sdutil.node_path_graph(cls.map_entry, cls.transient)]

    def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
        """
        Matches one-dimensional maps that ``MapToForLoop`` accepts and whose
        directly connected access nodes all refer to transient arrays.

        :param graph: The state in which the pattern was matched.
        :param expr_index: Index of the matched pattern expression.
        :param sdfg: The SDFG that contains ``graph``.
        :param permissive: Forwarded to ``MapToForLoop.can_be_applied``.
        :return: True if the transformation can be applied.
        """
        map_entry = self.map_entry
        transient = self.transient

        # Only one dimensional maps are allowed
        if len(map_entry.map.params) != 1:
            return False

        # Verify the map can be transformed to a for-loop
        m2for = MapToForLoop()
        m2for.setup_match(
            sdfg, sdfg.sdfg_id, self.state_id,
            {MapToForLoop.map_entry: self.subgraph[DoubleBuffering.map_entry]},
            expr_index)
        if not m2for.can_be_applied(graph, expr_index, sdfg, permissive):
            return False

        # Verify that all directly-connected internal access nodes point to
        # transient arrays
        first = True
        for edge in graph.out_edges(map_entry):
            if isinstance(edge.dst, nodes.AccessNode):
                desc = sdfg.arrays[edge.dst.data]
                if not isinstance(desc, data.Array) or not desc.transient:
                    return False
                else:
                    # To avoid duplicate matches, only match the first transient
                    if first and edge.dst != transient:
                        return False
                    first = False

        return True

    def apply(self, graph: sd.SDFGState, sdfg: sd.SDFG):
        """
        Applies double buffering to the matched map.

        The statements below depend on their order: the placeholder symbol
        ``__dace_db_param`` is substituted twice in the main state, with
        different expressions, before and after new edges are added to it.

        :param graph: The state in which the pattern was matched.
        :param sdfg: The SDFG that contains ``graph``.
        :return: The nested SDFG node returned by ``MapToForLoop.apply``.
        """
        map_entry = self.map_entry
        map_param = map_entry.map.params[0]  # Assuming one dimensional

        ##############################
        # Change condition of loop to one fewer iteration (so that the
        # final one reads from the last buffer)
        map_rstart, map_rend, map_rstride = map_entry.map.range[0]
        map_rend = symbolic.pystr_to_symbolic('(%s) - (%s)' %
                                              (map_rend, map_rstride))
        map_entry.map.range = subsets.Range([(map_rstart, map_rend,
                                              map_rstride)])

        ##############################
        # Gather transients to modify
        transients_to_modify = set(edge.dst.data
                                   for edge in graph.out_edges(map_entry)
                                   if isinstance(edge.dst, nodes.AccessNode))

        # Add dimension to transients and modify memlets
        for transient in transients_to_modify:
            desc: data.Array = sdfg.arrays[transient]
            # Using non-python syntax to ensure properties change
            # (a new leading dimension of size 2 whose stride is the old
            # total size)
            desc.strides = [desc.total_size] + list(desc.strides)
            desc.shape = [2] + list(desc.shape)
            desc.offset = [0] + list(desc.offset)
            desc.total_size = desc.total_size * 2

        ##############################
        # Modify memlets to use map parameter as buffer index
        # NOTE(review): modified_subsets is filled but not read again in
        # this method
        modified_subsets = []  # Store modified memlets for final state
        for edge in graph.scope_subgraph(map_entry).edges():
            if edge.data.data in transients_to_modify:
                edge.data.subset = self._modify_memlet(sdfg, edge.data.subset,
                                                       edge.data.data)
                modified_subsets.append(edge.data.subset)
            else:  # Could be other_subset
                path = graph.memlet_path(edge)
                src_node = path[0].src
                dst_node = path[-1].dst

                # other_subset could be None. In that case, recreate from array
                dataname = None
                if (isinstance(src_node, nodes.AccessNode)
                        and src_node.data in transients_to_modify):
                    dataname = src_node.data
                elif (isinstance(dst_node, nodes.AccessNode)
                      and dst_node.data in transients_to_modify):
                    dataname = dst_node.data
                if dataname is not None:
                    subset = (edge.data.other_subset or
                              subsets.Range.from_array(sdfg.arrays[dataname]))
                    edge.data.other_subset = self._modify_memlet(
                        sdfg, subset, dataname)
                    modified_subsets.append(edge.data.other_subset)

        ##############################
        # Turn map into for loop
        map_to_for = MapToForLoop()
        map_to_for.setup_match(
            sdfg, self.sdfg_id, self.state_id,
            {MapToForLoop.map_entry: graph.node_id(self.map_entry)},
            self.expr_index)
        nsdfg_node, nstate = map_to_for.apply(graph, sdfg)

        ##############################
        # Gather node copies and remove memlets
        edges_to_replace = []
        for node in nstate.source_nodes():
            for edge in nstate.out_edges(node):
                if (isinstance(edge.dst, nodes.AccessNode)
                        and edge.dst.data in transients_to_modify):
                    edges_to_replace.append(edge)
                    nstate.remove_edge(edge)
            # Drop source nodes that no longer have outgoing edges
            if nstate.out_degree(node) == 0:
                nstate.remove_node(node)

        ##############################
        # Add initial reads to initial nested state
        initial_state: sd.SDFGState = nsdfg_node.sdfg.start_state
        initial_state.set_label('%s_init' % map_entry.map.label)
        for edge in edges_to_replace:
            initial_state.add_node(edge.src)
            rnode = edge.src
            wnode = initial_state.add_write(edge.dst.data)
            initial_state.add_edge(rnode, edge.src_conn, wnode, edge.dst_conn,
                                   copy.deepcopy(edge.data))

        # All instances of the map parameter in this state become the loop start
        sd.replace(initial_state, map_param, map_rstart)
        # Initial writes go to the appropriate buffer
        init_expr = symbolic.pystr_to_symbolic('(%s / %s) %% 2' %
                                               (map_rstart, map_rstride))
        sd.replace(initial_state, '__dace_db_param', init_expr)

        ##############################
        # Modify main state's memlets

        # Divide by loop stride
        new_expr = symbolic.pystr_to_symbolic('(%s / %s) %% 2' %
                                              (map_param, map_rstride))
        sd.replace(nstate, '__dace_db_param', new_expr)

        ##############################
        # Add the main state's contents to the last state, modifying
        # memlets appropriately.
        final_state: sd.SDFGState = nsdfg_node.sdfg.sink_nodes()[0]
        final_state.set_label('%s_final_computation' % map_entry.map.label)
        # A deep copy is added so that the final state does not share nodes
        # and memlets with the main state
        dup_nstate = copy.deepcopy(nstate)
        final_state.add_nodes_from(dup_nstate.nodes())
        for e in dup_nstate.edges():
            final_state.add_edge(e.src, e.src_conn, e.dst, e.dst_conn, e.data)

        # If there is a WCR output with transient, only output in last state
        nstate: sd.SDFGState
        for node in nstate.sink_nodes():
            # Iterate over a copy, since memlet paths are removed in the loop
            for e in list(nstate.in_edges(node)):
                if e.data.wcr is not None:
                    path = nstate.memlet_path(e)
                    if isinstance(path[0].src, nodes.AccessNode):
                        nstate.remove_memlet_path(e)

        ##############################
        # Add reads into next buffers to main state
        for edge in edges_to_replace:
            rnode = copy.deepcopy(edge.src)
            nstate.add_node(rnode)
            wnode = nstate.add_write(edge.dst.data)
            new_memlet = copy.deepcopy(edge.data)
            # Read the data of the next iteration (map parameter + stride)
            if new_memlet.data in transients_to_modify:
                new_memlet.other_subset = self._replace_in_subset(
                    new_memlet.other_subset, map_param,
                    '(%s + %s)' % (map_param, map_rstride))
            else:
                new_memlet.subset = self._replace_in_subset(
                    new_memlet.subset, map_param,
                    '(%s + %s)' % (map_param, map_rstride))

            nstate.add_edge(rnode, edge.src_conn, wnode, edge.dst_conn,
                            new_memlet)

        nstate.set_label('%s_double_buffered' % map_entry.map.label)
        # Divide by loop stride
        # (this second replacement reaches only the edges added just above;
        # earlier occurrences were already substituted)
        new_expr = symbolic.pystr_to_symbolic('((%s / %s) + 1) %% 2' %
                                              (map_param, map_rstride))
        sd.replace(nstate, '__dace_db_param', new_expr)

        # Remove symbol once done
        del nsdfg_node.sdfg.symbols['__dace_db_param']
        del nsdfg_node.symbol_mapping['__dace_db_param']

        return nsdfg_node

    @staticmethod
    def _modify_memlet(sdfg, subset, data_name):
        """
        Returns a new range that indexes the buffer dimension with the
        placeholder symbol ``__dace_db_param``, followed by the dimensions
        of ``subset``.

        :param sdfg: The SDFG whose ``arrays`` contain ``data_name``.
        :param subset: The subset to extend.
        :param data_name: Name of the data descriptor the subset refers to.
        :return: The new ``subsets.Range``.
        """
        desc = sdfg.arrays[data_name]
        if len(subset) == len(desc.shape):
            # Already in the right shape, modify new dimension
            subset = list(subset)[1:]

        new_subset = subsets.Range([('__dace_db_param', '__dace_db_param',
                                     1)] + list(subset))
        return new_subset

    @staticmethod
    def _replace_in_subset(subset, string_or_symbol, new_string_or_symbol):
        """
        Returns a deep copy of ``subset`` in which ``string_or_symbol`` is
        substituted by ``new_string_or_symbol`` in every dimension.

        :param subset: The subset to copy and modify.
        :param string_or_symbol: The symbol (or its name) to replace.
        :param new_string_or_symbol: The replacement expression.
        :return: The modified copy; ``subset`` itself is left unchanged.
        """
        new_subset = copy.deepcopy(subset)

        repldict = {
            symbolic.pystr_to_symbolic(string_or_symbol):
            symbolic.pystr_to_symbolic(new_string_or_symbol)
        }

        for i, dim in enumerate(new_subset):
            try:
                new_subset[i] = tuple(d.subs(repldict) for d in dim)
            except TypeError:
                # Dimension is not iterable: a single expression or value
                new_subset[i] = (dim.subs(repldict)
                                 if symbolic.issymbolic(dim) else dim)

        return new_subset
class StreamingComposition(xf.Transformation):
    """
    Converts two connected computations (nodes, map scopes) into two separate
    processing elements, with a stream connecting the results. Only applies
    if the memory access patterns of the two computations match.
    """
    first = xf.PatternNode(nodes.Node)
    access = xf.PatternNode(nodes.AccessNode)
    second = xf.PatternNode(nodes.Node)

    buffer_size = properties.Property(
        dtype=int,
        default=1,
        desc='Set buffer size for the newly-created stream')

    storage = properties.EnumProperty(
        dtype=dtypes.StorageType,
        desc='Set storage type for the newly-created stream',
        default=dtypes.StorageType.Default)

    @staticmethod
    def expressions() -> List[gr.SubgraphView]:
        """ Returns the single pattern graph: node -> access node -> node. """
        return [
            sdutil.node_path_graph(StreamingComposition.first,
                                   StreamingComposition.access,
                                   StreamingComposition.second)
        ]

    @staticmethod
    def can_be_applied(graph: SDFGState,
                       candidate: Dict[xf.PatternNode, int],
                       expr_index: int,
                       sdfg: SDFG,
                       permissive: bool = False) -> bool:
        """
        Checks whether the matched access node can be replaced by a stream.

        :param graph: The state in which the pattern was matched.
        :param candidate: Mapping from pattern nodes to node IDs in ``graph``.
        :param expr_index: Index of the matched pattern expression.
        :param sdfg: The SDFG that contains ``graph``.
        :param permissive: Not used by this check.
        :return: True if the transformation can be applied.
        """
        access = graph.node(candidate[StreamingComposition.access])
        # Make sure the access node is only accessed once (read or write),
        # and not at the same time
        if graph.in_degree(access) > 1 or graph.out_degree(access) > 1:
            return False

        # If already a stream, skip
        if isinstance(sdfg.arrays[access.data], data.Stream):
            return False

        # Only free nodes are allowed (search up the SDFG tree)
        curstate = graph
        node = access
        while curstate is not None:
            if curstate.entry_node(node) is not None:
                return False
            if curstate.parent.parent_nsdfg_node is None:
                break
            node = curstate.parent.parent_nsdfg_node
            curstate = curstate.parent.parent

        # Array must not be used anywhere else in the state
        if any(n is not access and n.data == access.data
               for n in graph.data_nodes()):
            return False

        # Only one memlet path on each direction is allowed
        # TODO: Relax so that repeated application of
        # transformation would yield additional streams
        first_edge = graph.in_edges(access)[0]
        second_edge = graph.out_edges(access)[0]
        first_mpath = graph.memlet_path(first_edge)
        second_mpath = graph.memlet_path(second_edge)
        if len(first_mpath) != len(list(graph.memlet_tree(first_edge))):
            return False
        if len(second_mpath) != len(list(graph.memlet_tree(second_edge))):
            return False

        # The innermost ends of the paths must have a clearly defined memory
        # access pattern and no WCR
        first_iedge = first_mpath[0]
        second_iedge = second_mpath[-1]
        if first_iedge.data.subset.num_elements() != 1:
            return False
        if first_iedge.data.volume != 1:
            return False
        if first_iedge.data.wcr is not None:
            return False
        if second_iedge.data.subset.num_elements() != 1:
            return False
        if second_iedge.data.volume != 1:
            return False

        ##################################################################
        # The memory access pattern must be exactly the same

        # Collect all maps and ranges
        ranges_first = _collect_map_ranges(graph, first_mpath)
        ranges_second = _collect_map_ranges(graph, second_mpath)

        # Check map ranges
        # NOTE(review): zip stops at the shorter list, so a difference in
        # the number of collected ranges is not detected by this loop --
        # confirm that _do_memlets_correspond covers that case.
        for (_, frng), (_, srng) in zip(ranges_first, ranges_second):
            if frng != srng:
                return False

        # Check memlets for equivalence
        if len(first_iedge.data.subset) != len(second_iedge.data.subset):
            return False
        if not _do_memlets_correspond(first_iedge.data, second_iedge.data,
                                      ranges_first, ranges_second):
            return False

        return True

    def apply(self,
              sdfg: SDFG) -> Tuple[nodes.AccessNode, nodes.AccessNode]:
        """
        Replaces the matched access node by a newly created transient stream.

        :param sdfg: The SDFG that contains the matched state.
        :return: A 2-tuple ``(wnode, rnode)`` with the write and the read
                 access node of the new stream.
        """
        state = sdfg.node(self.state_id)
        access: nodes.AccessNode = self.access(sdfg)

        # Get memlet paths
        first_edge = state.in_edges(access)[0]
        second_edge = state.out_edges(access)[0]
        first_mpath = state.memlet_path(first_edge)
        second_mpath = state.memlet_path(second_edge)

        # Create new stream of shape 1
        desc = sdfg.arrays[access.data]
        name, newdesc = sdfg.add_stream(access.data,
                                        desc.dtype,
                                        buffer_size=self.buffer_size,
                                        storage=self.storage,
                                        transient=True,
                                        find_new_name=True)

        # Remove transient array if possible
        # (the for-else deletes it only if no other state accesses it)
        for ostate in sdfg.nodes():
            if ostate is state:
                continue
            if any(n.data == access.data for n in ostate.data_nodes()):
                break
        else:
            del sdfg.arrays[access.data]

        # Replace memlets in both paths with stream access
        for mpath in (first_mpath, second_mpath):
            for e in mpath:
                e.data = mm.Memlet(data=name, subset='0')
                # Edges touching a nested SDFG become dynamic, and the
                # nested SDFG is updated to use the stream as well
                if isinstance(e.src, nodes.NestedSDFG):
                    e.data.dynamic = True
                    _streamify_recursive(e.src, e.src_conn, newdesc)
                if isinstance(e.dst, nodes.NestedSDFG):
                    e.data.dynamic = True
                    _streamify_recursive(e.dst, e.dst_conn, newdesc)

        # Replace array access node with two stream access nodes
        wnode = state.add_write(name)
        rnode = state.add_read(name)
        state.remove_edge(first_edge)
        state.add_edge(first_edge.src, first_edge.src_conn, wnode,
                       first_edge.dst_conn, first_edge.data)
        state.remove_edge(second_edge)
        state.add_edge(rnode, second_edge.src_conn, second_edge.dst,
                       second_edge.dst_conn, second_edge.data)

        # Remove original access node
        state.remove_node(access)

        return wnode, rnode
class MapExpansion(pm.Transformation):
    """ Implements the map-expansion pattern.

        An N-dimensional map is turned into a chain of N one-dimensional
        maps. The matched map keeps its first dimension; one new map is
        created for each remaining dimension.

        Edges are rewired according to these rules:
          1. If there are no edges coming from the outside, use empty memlets
          2. Edges with IN_* connectors replicate along the maps
          3. Edges for dynamic map ranges replicate until reaching range(s)
    """

    map_entry = pm.PatternNode(nodes.MapEntry)

    @staticmethod
    def expressions():
        """ Returns the single pattern graph: one map entry node. """
        pattern = sdutil.node_path_graph(MapExpansion.map_entry)
        return [pattern]

    @staticmethod
    def can_be_applied(graph: dace.SDFGState,
                       candidate: Dict[pm.PatternNode, int],
                       expr_index: int,
                       sdfg: dace.SDFG,
                       strict: bool = False):
        """ Matches maps that have more than one parameter. """
        matched_entry = graph.node(candidate[MapExpansion.map_entry])
        num_params = matched_entry.map.get_param_num()
        return num_params > 1

    @staticmethod
    def match_to_str(graph: dace.SDFGState,
                     candidate: Dict[pm.PatternNode, int]) -> str:
        """ Describes a match as ``<map label>: <map parameters>``. """
        matched_map = graph.node(candidate[MapExpansion.map_entry]).map
        return matched_map.label + ': ' + str(matched_map.params)

    def apply(self, sdfg: dace.SDFG):
        """
        Expands the matched map.

        :param sdfg: The SDFG that contains the matched state.
        :return: The list of map entries, outermost first: the matched
                 entry followed by the newly created entries.
        :raises ValueError: If the scope of the innermost new map entry
                            cannot be found in the state.
        """
        # Matched map together with its entry and exit nodes
        state = sdfg.node(self.state_id)
        outer_entry = self.map_entry(sdfg)
        outer_exit = state.exit_node(outer_entry)
        outer_map = outer_entry.map

        # One sequential map for every dimension but the first
        inner_maps = []
        for param, param_range in zip(outer_map.params[1:],
                                      outer_map.range[1:]):
            inner_maps.append(
                nodes.Map(outer_map.label + '_' + str(param), [param],
                          subsets.Range([param_range]),
                          schedule=dtypes.ScheduleType.Sequential))
        # The matched map keeps only its first dimension
        outer_map.params = [outer_map.params[0]]
        outer_map.range = subsets.Range([outer_map.range[0]])

        # Entry and exit nodes of the new maps
        entries = [nodes.MapEntry(inner) for inner in inner_maps]
        exits = [nodes.MapExit(inner) for inner in inner_maps]

        # Route every edge leaving the matched entry through all new entries
        for out_edge in state.out_edges(outer_entry):
            state.remove_edge(out_edge)
            state.add_memlet_path(outer_entry,
                                  *entries,
                                  out_edge.dst,
                                  src_conn=out_edge.src_conn,
                                  memlet=out_edge.data,
                                  dst_conn=out_edge.dst_conn)

        # Dynamic map range inputs are reconnected to each map whose range
        # uses the symbol named by the connector
        for dyn_edge in dace.sdfg.dynamic_map_inputs(state, outer_entry):
            # Drop the old edge together with its connector
            state.remove_edge(dyn_edge)
            dyn_edge.dst.remove_in_connector(dyn_edge.dst_conn)

            entry_chain = [outer_entry] + entries
            for depth, entry_node in enumerate(entry_chain, start=1):
                uses_symbol = any(
                    dyn_edge.dst_conn in map(str, symbolic.symlist(rng))
                    for rng in entry_node.map.range)
                if uses_symbol:
                    # Path through all entries up to and including this one
                    state.add_memlet_path(dyn_edge.src,
                                          *entry_chain[:depth],
                                          memlet=dyn_edge.data,
                                          src_conn=dyn_edge.src_conn,
                                          dst_conn=dyn_edge.dst_conn)

        # Route every edge entering the matched exit through the new exits,
        # innermost first
        for in_edge in state.in_edges(outer_exit):
            state.remove_edge(in_edge)
            state.add_memlet_path(in_edge.src,
                                  *reversed(exits),
                                  outer_exit,
                                  memlet=in_edge.data,
                                  src_conn=in_edge.src_conn,
                                  dst_conn=in_edge.dst_conn)

        from dace.sdfg.scope import ScopeTree

        # Walk up from the scope leaves until the scope of the innermost
        # new map entry is found
        scope = None
        pending: List[ScopeTree] = state.scope_leaves()
        while pending:
            current = pending.pop()
            if current.entry == entries[-1]:
                scope = current
                break
            if current.parent is not None:
                pending.append(current.parent)
        else:
            raise ValueError('Cannot find scope in state')

        consolidate_edges(sdfg, scope)

        return [outer_entry] + entries