def apply(self, graph, sdfg):
    """Tile a producer map and its consumer map, then fuse them tile by tile.

    The transformation operates in place on ``graph`` and ``sdfg``:

    1. Compute per-dimension extents between the two maps' ranges.
       ``lower`` is (map2 minimum - map1 minimum); ``upper`` is
       (map1 maximum - map2 maximum).
    2. Tile the first map with ``MapTilingWithOverlap``, using those extents
       as lower/upper overlaps. Label the resulting tile map ``'BufferTiling'``.
    3. Tile the second map with ``MapTiling`` (``tile_trivial=True``).
    4. Fuse the two tile maps with ``MapFusion``.
    5. Collapse original-map dimensions that have a tile size of 1 to a single
       point. For the first map, this also requires zero overlap in that
       dimension. Then run ``TrivialMapElimination`` on maps that end up with
       a trivial dimension.

    :param graph: The graph (state) that contains ``self.map1_exit`` and
                  ``self.map2_entry``.
    :param sdfg: The SDFG owning ``graph``. It is passed to the applied
                 sub-transformations.
    """
    map1_exit = self.map1_exit
    map1_entry = graph.entry_node(map1_exit)
    map2_entry = self.map2_entry
    buffers = graph.all_nodes_between(map1_exit, map2_entry)
    # Situation:
    # -> map1_entry -> ... -> map1_exit -> buffers -> map2_entry -> ...

    lower_extents = tuple(b - a for a, b in zip(map1_entry.range.min_element(), map2_entry.range.min_element()))
    upper_extents = tuple(a - b for a, b in zip(map1_entry.range.max_element(), map2_entry.range.max_element()))

    # Tile the first map with overlap
    MapTilingWithOverlap.apply_to(sdfg,
                                  map_entry=map1_entry,
                                  options={
                                      'tile_sizes': self.tile_sizes,
                                      'lower_overlap': lower_extents,
                                      'upper_overlap': upper_extents
                                  })
    # After tiling, map1_exit feeds directly into the exit of the new outer tile map.
    tile_map1_exit = graph.out_edges(map1_exit)[0].dst
    tile_map1_entry = graph.entry_node(tile_map1_exit)
    tile_map1_entry.label = 'BufferTiling'

    # Tile the second map
    MapTiling.apply_to(sdfg, map_entry=map2_entry, options={'tile_sizes': self.tile_sizes, 'tile_trivial': True})
    # Likewise, the new outer tile map's entry now feeds map2_entry.
    tile_map2_entry = graph.in_edges(map2_entry)[0].src

    # Fuse maps
    some_buffer = next(iter(buffers))  # some dummy to pass to MapFusion.apply_to()
    MapFusion.apply_to(sdfg, first_map_exit=tile_map1_exit, array=some_buffer, second_map_entry=tile_map2_entry)

    # Optimize the simple cases
    tile_sizes = self.tile_sizes

    def _unit_tile(dim):
        # Report whether dimension ``dim`` has an explicit tile size of 1.
        # Dimensions without a tile size are never collapsed (and, unlike with
        # ``zip``, never dropped). Dropping a range would leave the map with
        # fewer ranges than parameters, whereas collapsing is only an optimization.
        return dim < len(tile_sizes) and tile_sizes[dim] == 1

    def _no_overlap(dim):
        # Report whether dimension ``dim`` has zero lower and upper overlap.
        return (dim < len(lower_extents) and dim < len(upper_extents) and lower_extents[dim] == 0
                and upper_extents[dim] == 0)

    # Ranges are (begin, end, step) triples; (begin, begin, step) is a single point.
    map1_entry.range.ranges = [(r[0], r[0], r[2]) if _no_overlap(dim) and _unit_tile(dim) else r
                               for dim, r in enumerate(map1_entry.range.ranges)]

    map2_entry.range.ranges = [(r[0], r[0], r[2]) if _unit_tile(dim) else r
                               for dim, r in enumerate(map2_entry.range.ranges)]

    if any(ts == 1 for ts in tile_sizes):
        if any(r[0] == r[1] for r in map1_entry.map.range):
            TrivialMapElimination.apply_to(sdfg, map_entry=map1_entry)
        if any(r[0] == r[1] for r in map2_entry.map.range):
            TrivialMapElimination.apply_to(sdfg, map_entry=map2_entry)
def test_applyto_pattern():
    """Fuse the multiplication and addition maps of ``dbladd`` via ``apply_to``.

    The three pattern nodes are picked out by hand and passed to
    ``MapFusion.apply_to``:

    - the exit of the ``_Mult__map`` map;
    - the single transient access node that remains after simplification;
    - the entry of the ``_Add__map`` map.
    """
    sdfg = dbladd.to_sdfg()
    sdfg.simplify()

    # StateFusion (run by simplify) leaves a single state, so node 0 is it.
    state = sdfg.node(0)

    def _first(predicate):
        # Return the first node of the state satisfying ``predicate``.
        return next(node for node in state.nodes() if predicate(node))

    # Exit of the multiplication map.
    mult_exit = _first(lambda node: isinstance(node, dace.nodes.MapExit) and node.label == '_Mult__map')
    # Entry of the addition map.
    add_entry = _first(lambda node: isinstance(node, dace.nodes.MapEntry) and node.label == '_Add__map')

    # Simplification removed the redundant arrays; the only transient left is
    # the buffer between the two maps.
    transient = next(name for name, desc in sdfg.arrays.items() if desc.transient)
    access_node = _first(lambda node: isinstance(node, dace.nodes.AccessNode) and node.data == transient)

    MapFusion.apply_to(sdfg, first_map_exit=mult_exit, array=access_node, second_map_entry=add_entry)
def test_applyto_enumerate():
    """Fuse every MapExit -> AccessNode -> MapEntry path found by pattern matching.

    The SDFG of ``dbladd`` is simplified and then searched with
    ``enumerate_matches`` for a three-node path pattern. ``MapFusion`` is
    applied to each match through ``apply_to``.
    """
    sdfg = dbladd.to_sdfg()
    sdfg.simplify()

    # Construct subgraph pattern: a map exit feeding an access node that feeds a map entry.
    pattern = sdutil.node_path_graph(dace.nodes.MapExit, dace.nodes.AccessNode, dace.nodes.MapEntry)

    # Collect the matches before transforming anything. enumerate_matches may
    # yield lazily, and applying MapFusion mid-iteration would mutate the very
    # graph that is still being enumerated.
    matches = list(enumerate_matches(sdfg, pattern))
    # Without this check, the test would pass vacuously if the pattern stopped matching.
    assert matches, 'expected at least one MapExit -> AccessNode -> MapEntry match'

    for subgraph in matches:
        MapFusion.apply_to(sdfg,
                           first_map_exit=subgraph.source_nodes()[0],
                           array=next(n for n in subgraph.nodes() if isinstance(n, dace.nodes.AccessNode)),
                           second_map_entry=subgraph.sink_nodes()[0])