class StreamTransient(pattern_matching.Transformation):
    """ Implements the StreamTransient transformation, which places a
        transient stream access node between two nested map exits that are
        connected by a stream memlet. The transient then acts as a local
        buffer.
    """

    _tasklet = nodes.Tasklet('_')
    _map_exit = nodes.MapExit(nodes.Map("", [], []))
    _outer_map_exit = nodes.MapExit(nodes.Map("", [], []))

    @staticmethod
    def expressions():
        # Single pattern: tasklet -> inner map exit -> outer map exit
        pattern = nxutil.node_path_graph(StreamTransient._tasklet,
                                         StreamTransient._map_exit,
                                         StreamTransient._outer_map_exit)
        return [pattern]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        inner_exit = graph.nodes()[candidate[StreamTransient._map_exit]]
        outer_exit = graph.nodes()[candidate[
            StreamTransient._outer_map_exit]]

        # Applicable as soon as one stream memlet connects the two exits
        return any(
            isinstance(sdfg.arrays[mm.data], data.Stream)
            and target == outer_exit
            for _src, _, target, _, mm in graph.out_edges(inner_exit))

    @staticmethod
    def match_to_str(graph, candidate):
        pattern_nodes = (StreamTransient._tasklet, StreamTransient._map_exit,
                         StreamTransient._outer_map_exit)
        matched = [candidate[pnode] for pnode in pattern_nodes]
        return ' -> '.join(str(node) for node in matched)

    def apply(self, sdfg):
        state = sdfg.nodes()[self.state_id]
        tasklet = state.nodes()[self.subgraph[StreamTransient._tasklet]]
        inner_exit = state.nodes()[self.subgraph[StreamTransient._map_exit]]
        outer_exit = state.nodes()[self.subgraph[
            StreamTransient._outer_map_exit]]

        # Locate the stream edge between the two map exits. If nothing
        # matches, `memlet` keeps the data of the last edge inspected.
        # TODO: What if there's more than one?
        memlet = None
        stream_edge = None
        for candidate_edge in state.out_edges(inner_exit):
            memlet = candidate_edge.data
            if (candidate_edge.dst == outer_exit and isinstance(
                    sdfg.arrays[memlet.data], data.Stream)):
                stream_edge = candidate_edge
                break

        # Locate the tasklet output that writes to the same container
        tasklet_memlet = None
        for candidate_edge in state.out_edges(tasklet):
            tasklet_memlet = candidate_edge.data
            if tasklet_memlet.data == memlet.data:
                break

        # Over-approximated extents of the inner map; the first one is
        # passed on to the new stream descriptor
        extents = [
            symbolic.overapproximate(dim)
            for dim in inner_exit.map.range.bounding_box_size()
        ]
        dataname = memlet.data
        stream_name = 'tile_' + dataname

        # Create the new node: Temporary stream and an access node
        sdfg.add_stream(
            stream_name,
            sdfg.arrays[memlet.data].dtype,
            1,
            extents[0],
            [1],
            transient=True,
        )
        snode = nodes.AccessNode(stream_name)

        # Memlet into the transient is a copy of the original stream memlet,
        # retargeted; the tasklet output is retargeted in place
        to_stream_mm = copy.deepcopy(memlet)
        to_stream_mm.data = snode.data
        tasklet_memlet.data = snode.data

        # Reconnect, assuming one edge to the stream
        state.remove_edge(stream_edge)
        state.add_edge(inner_exit, None, snode, None, to_stream_mm)
        state.add_edge(snode, None, outer_exit, None, memlet)

        return

    def modifies_graph(self):
        return True
class AccumulateTransient(pattern_matching.Transformation):
    """ Implements the AccumulateTransient transformation, which adds a
        transient data node between nested map exits when the tasklet writes
        to the inner map exit with a write-conflict resolution (WCR) memlet.
        The transient data node then acts as a local accumulator.
    """

    # Pattern nodes. Every method of this class must key the candidate /
    # subgraph dictionaries with these same objects, since the match is
    # reported in terms of the nodes returned by `expressions()`.
    _tasklet = nodes.Tasklet('_')
    _map_exit = nodes.MapExit(nodes.Map("", [], []))
    _outer_map_exit = nodes.MapExit(nodes.Map("", [], []))

    @staticmethod
    def expressions():
        """ Returns the pattern: tasklet -> map exit -> outer map exit. """
        return [
            nxutil.node_path_graph(AccumulateTransient._tasklet,
                                   AccumulateTransient._map_exit,
                                   AccumulateTransient._outer_map_exit)
        ]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Returns True if the tasklet has a WCR output edge that leads to
            the matched inner map exit.
        """
        tasklet = graph.nodes()[candidate[AccumulateTransient._tasklet]]
        map_exit = graph.nodes()[candidate[AccumulateTransient._map_exit]]

        # Check if there is an accumulating (WCR) output
        for _src, _, dest, _, memlet in graph.out_edges(tasklet):
            if memlet.wcr is not None and dest == map_exit:
                return True

        return False

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns a string describing the matched node chain. """
        tasklet = candidate[AccumulateTransient._tasklet]
        map_exit = candidate[AccumulateTransient._map_exit]
        outer_map_exit = candidate[AccumulateTransient._outer_map_exit]

        return ' -> '.join(
            str(node) for node in [tasklet, map_exit, outer_map_exit])

    def apply(self, sdfg):
        """ Inserts a transient array named 'trans_<data>' between the inner
            and outer map exits and redirects the tasklet's WCR memlet to it.
        """
        graph = sdfg.nodes()[self.state_id]
        tasklet = graph.nodes()[self.subgraph[AccumulateTransient._tasklet]]
        map_exit = graph.nodes()[self.subgraph[AccumulateTransient._map_exit]]
        outer_map_exit = graph.nodes()[self.subgraph[
            AccumulateTransient._outer_map_exit]]

        # Find the WCR memlet leaving the tasklet. If no edge matches,
        # `memlet` is left at the last edge inspected.
        memlet = None
        edge = None
        for e in graph.out_edges(tasklet):
            memlet = e.data
            # TODO: What if there's more than one?
            if e.dst == map_exit and e.data.wcr is not None:
                break

        # Find the edge leaving the inner map exit for the same container
        out_memlet = None
        for e in graph.out_edges(map_exit):
            out_memlet = e.data
            if out_memlet.data == memlet.data:
                edge = e
                break

        dataname = memlet.data

        # Create a new node with the same size as the output
        sdfg.add_array(
            'trans_' + dataname,
            sdfg.arrays[memlet.data].shape,
            sdfg.arrays[memlet.data].dtype,
            transient=True)
        dnode = nodes.AccessNode('trans_' + dataname)

        # The copies must be taken before `memlet.data` is retargeted below
        to_data_mm = copy.deepcopy(memlet)
        to_data_mm.data = dnode.data
        to_data_mm.num_accesses = memlet.num_elements()

        to_exit_mm = copy.deepcopy(out_memlet)
        to_exit_mm.num_accesses = out_memlet.num_elements()
        memlet.data = dnode.data

        # Reconnect, assuming one edge to the output
        graph.remove_edge(edge)
        graph.add_edge(map_exit, edge.src_conn, dnode, None, to_data_mm)
        graph.add_edge(dnode, None, outer_map_exit, edge.dst_conn, to_exit_mm)

        return

    def modifies_graph(self):
        return True
class OutLocalStorage(pattern_matching.Transformation):
    """ Implements the OutLocalStorage transformation, which inserts a
        transient data node on the edge between two nested map exits and
        re-bases the affected memlets onto that transient.
    """

    _inner_map_exit = nodes.MapExit(nodes.Map("", [], []))
    _outer_map_exit = nodes.MapExit(nodes.Map("", [], []))

    @staticmethod
    def annotates_memlets():
        return True

    @staticmethod
    def expressions():
        # Single pattern: inner map exit -> outer map exit
        pattern = nxutil.node_path_graph(OutLocalStorage._inner_map_exit,
                                         OutLocalStorage._outer_map_exit)
        return [pattern]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        return True

    @staticmethod
    def match_to_str(graph, candidate):
        pattern_nodes = (OutLocalStorage._inner_map_exit,
                         OutLocalStorage._outer_map_exit)
        matched = [candidate[pnode] for pnode in pattern_nodes]
        return ' -> '.join(str(node) for node in matched)

    def apply(self, sdfg):
        state = sdfg.nodes()[self.state_id]
        inner_exit = state.nodes()[self.subgraph[
            OutLocalStorage._inner_map_exit]]
        outer_exit = state.nodes()[self.subgraph[
            OutLocalStorage._outer_map_exit]]

        # Pick the first edge that connects the inner exit to the outer exit
        original_edge = None
        invariant_memlet = None
        array = None
        for in_edge in state.in_edges(outer_exit):
            if in_edge.src == inner_exit:
                original_edge = in_edge
                invariant_memlet = in_edge.data
                array = invariant_memlet.data
                break

        # Transient sized by the over-approximated bounding box of the memlet
        transient_name = state.label + '_trans_' + invariant_memlet.data
        sdfg.add_array(
            transient_name, [
                symbolic.overapproximate(extent)
                for extent in invariant_memlet.bounding_box_size()
            ],
            sdfg.arrays[invariant_memlet.data].dtype,
            transient=True)
        data_node = nodes.AccessNode(transient_name)
        data_node.setzero = True

        from_data_mm = copy.deepcopy(invariant_memlet)
        to_data_mm = copy.deepcopy(invariant_memlet)
        to_data_mm.data = data_node.data

        # Shift the memlet into the transient so that every dimension starts
        # at zero, remembering the per-dimension start as the offset
        offset = []
        for ind, r in enumerate(invariant_memlet.subset):
            start = r[0]
            offset.append(start)
            current = invariant_memlet.subset[ind]
            if isinstance(current, tuple):
                to_data_mm.subset[ind] = (current[0] - start,
                                          current[1] - start, current[2])
            else:
                to_data_mm.subset[ind] -= start

        # Reconnect, assuming one edge to the stream
        state.remove_edge(original_edge)
        state.add_edge(inner_exit, original_edge.src_conn, data_node, None,
                       to_data_mm)
        state.add_edge(data_node, None, outer_exit, original_edge.dst_conn,
                       from_data_mm)

        # Walk backwards from the inner exit and retarget memlets of the
        # original array to the transient, stopping at the first code node
        for _parent, _, _child, _, memlet in state.bfs_edges(
                inner_exit, reverse=True):
            if isinstance(_child, nodes.CodeNode):
                break
            if memlet.data != array:
                continue
            for ind, r in enumerate(memlet.subset):
                if isinstance(memlet.subset[ind], tuple):
                    memlet.subset[ind] = (r[0] - offset[ind],
                                          r[1] - offset[ind], r[2])
                else:
                    memlet.subset[ind] -= offset[ind]
            memlet.data = transient_name

        return
def __stripmine(self, sdfg, graph, candidate):
    """ Strip-mines every dimension of the matched map into an outer map
        over tiles and the (modified) original map over one tile.

        Reads `self.prefix`, `self.tile_sizes` and `self.divides_evenly`.
        If fewer tile sizes than map dimensions are given, the last tile
        size is reused for the remaining dimensions.

        :param sdfg: The SDFG containing the state.
        :param graph: The state in which the map resides.
        :param candidate: Mapping from pattern nodes to node indices in
                          `graph`.
        :return: A 3-tuple (target_dim, new_dim, new_map), where the first
                 two entries refer to the last dimension processed.
        :raise ValueError: If a range entry of a memlet subset does not have
                           exactly three elements.
        :raise NotImplementedError: If a memlet subset is neither Indices nor
                                    Range.
    """
    # Retrieve map entry and exit nodes.
    map_entry = graph.nodes()[candidate[OrthogonalTiling._map_entry]]
    map_exit = graph.exit_nodes(map_entry)[0]

    # Map subgraph
    map_subgraph = graph.scope_subgraph(map_entry)

    # Retrieve transformation properties.
    prefix = self.prefix
    tile_sizes = self.tile_sizes
    divides_evenly = self.divides_evenly

    new_param = []
    new_range = []

    for dim_idx in range(len(map_entry.map.params)):

        if dim_idx >= len(tile_sizes):
            tile_size = tile_sizes[-1]
        else:
            tile_size = tile_sizes[dim_idx]

        # Retrieve parameter and range of dimension to be strip-mined.
        target_dim = map_entry.map.params[dim_idx]
        td_from, td_to, td_step = map_entry.map.range[dim_idx]

        new_dim = prefix + '_' + target_dim

        # Basic values
        if divides_evenly:
            tile_num = '(%s + 1 - %s) / %s' % (symbolic.symstr(td_to),
                                               symbolic.symstr(td_from),
                                               str(tile_size))
        else:
            tile_num = 'int_ceil((%s + 1 - %s), %s)' % (symbolic.symstr(
                td_to), symbolic.symstr(td_from), str(tile_size))

        # Outer map values (over all tiles)
        nd_from = 0
        nd_to = symbolic.pystr_to_symbolic(str(tile_num) + ' - 1')
        nd_step = 1

        # Inner map values (over one tile)
        td_from_new = dace.symbolic.pystr_to_symbolic(td_from)
        td_to_new_exact = symbolic.pystr_to_symbolic(
            'min(%s + 1 - %s * %s, %s + %s) - 1' %
            (symbolic.symstr(td_to), str(new_dim), str(tile_size),
             td_from_new, str(tile_size)))
        td_to_new_approx = symbolic.pystr_to_symbolic(
            '%s + %s - 1' % (td_from_new, str(tile_size)))

        # Outer map (over all tiles)
        new_dim_range = (nd_from, nd_to, nd_step)
        new_param.append(new_dim)
        new_range.append(new_dim_range)

        # Inner map (over one tile)
        if divides_evenly:
            td_to_new = td_to_new_approx
        else:
            td_to_new = dace.symbolic.SymExpr(td_to_new_exact,
                                              td_to_new_approx)
        map_entry.map.range[dim_idx] = (td_from_new, td_to_new, td_step)

        # Fix subgraph memlets: the inner parameter is now relative to the
        # tile, so every use is shifted by new_dim * tile_size
        target_dim = dace.symbolic.pystr_to_symbolic(target_dim)
        offset = dace.symbolic.pystr_to_symbolic(
            '%s * %s' % (new_dim, str(tile_size)))
        shifted = target_dim + offset
        for _, _, _, _, memlet in map_subgraph.edges():
            old_subset = memlet.subset
            if isinstance(old_subset, dace.subsets.Indices):
                new_indices = [
                    idx.subs(target_dim, shifted) for idx in old_subset
                ]
                memlet.subset = dace.subsets.Indices(new_indices)
            elif isinstance(old_subset, dace.subsets.Range):
                new_ranges = []
                for i, old_range in enumerate(old_subset):
                    if len(old_range) != 3:
                        # Wrap in a 1-tuple so that a tuple-valued range is
                        # formatted as one value
                        raise ValueError(
                            'Range %s is invalid.' % (old_range, ))
                    b, e, s = old_range
                    t = old_subset.tile_sizes[i]
                    new_ranges.append((b.subs(target_dim, shifted),
                                       e.subs(target_dim, shifted),
                                       s.subs(target_dim, shifted),
                                       t.subs(target_dim, shifted)))
                memlet.subset = dace.subsets.Range(new_ranges)
            else:
                raise NotImplementedError

    new_map = nodes.Map(prefix + '_' + map_entry.map.label, new_param,
                        subsets.Range(new_range))
    new_map_entry = nodes.MapEntry(new_map)
    new_exit = nodes.MapExit(new_map)

    # Make internal map's schedule to "not parallel"
    map_entry.map._schedule = dtypes.ScheduleType.Default

    # Redirect/create edges.
    # Each dictionary maps a data container name to a tuple
    # (src, src_conn, dst, dst_conn, memlet, connector number).
    new_in_edges = {}
    for _src, conn, _dest, _, memlet in graph.out_edges(map_entry):
        if not isinstance(sdfg.arrays[memlet.data], dace.data.Scalar):
            new_subset = copy.deepcopy(memlet.subset)
            if memlet.data in new_in_edges:
                # Container already seen: widen the existing memlet
                src, src_conn, dest, dest_conn, new_memlet, num = \
                    new_in_edges[memlet.data]
                new_memlet.subset = calc_set_union(
                    new_memlet.data, sdfg.arrays[new_memlet.data],
                    new_memlet.subset, new_subset)
                new_memlet.num_accesses = new_memlet.num_elements()
                new_in_edges[memlet.data] = (src, src_conn, dest, dest_conn,
                                             new_memlet,
                                             min(num, int(conn[4:])))
            else:
                new_memlet = dcpy(memlet)
                new_memlet.subset = new_subset
                new_memlet.num_accesses = new_memlet.num_elements()
                new_in_edges[memlet.data] = (new_map_entry, None, map_entry,
                                             None, new_memlet,
                                             int(conn[4:]))
    nxutil.change_edge_dest(graph, map_entry, new_map_entry)

    new_out_edges = {}
    for _src, conn, _dest, _, memlet in graph.in_edges(map_exit):
        if not isinstance(sdfg.arrays[memlet.data], dace.data.Scalar):
            new_subset = memlet.subset
            if memlet.data in new_out_edges:
                # Container already seen: widen the existing memlet
                src, src_conn, dest, dest_conn, new_memlet, num = \
                    new_out_edges[memlet.data]
                new_memlet.subset = calc_set_union(
                    new_memlet.data, sdfg.arrays[new_memlet.data],
                    new_memlet.subset, new_subset)
                new_memlet.num_accesses = new_memlet.num_elements()
                # NOTE(review): unlike the entry side, the connector suffix
                # is kept as a string here, so min() compares strings --
                # confirm whether that is intended
                new_out_edges[memlet.data] = (src, src_conn, dest,
                                              dest_conn, new_memlet,
                                              min(num, conn[4:]))
            else:
                new_memlet = dcpy(memlet)
                new_memlet.subset = new_subset
                new_memlet.num_accesses = new_memlet.num_elements()
                new_out_edges[memlet.data] = (map_exit, None, new_exit, None,
                                              new_memlet, conn[4:])
    nxutil.change_edge_src(graph, map_exit, new_exit)

    # Connector related work follows
    # 1. Dictionary 'old_connector_number': 'new_connector_number'
    # 2. New node in/out connectors
    # 3. New edges

    in_conn_nums = [num for _, _, _, _, _, num in new_in_edges.values()]
    in_conn = {num: i + 1 for i, num in enumerate(in_conn_nums)}
    entry_in_connectors = set()
    entry_out_connectors = set()
    for i in range(len(in_conn_nums)):
        entry_in_connectors.add('IN_' + str(i + 1))
        entry_out_connectors.add('OUT_' + str(i + 1))
    new_map_entry.in_connectors = entry_in_connectors
    new_map_entry.out_connectors = entry_out_connectors
    for src, _, dst, _, memlet, num in new_in_edges.values():
        graph.add_edge(src, 'OUT_' + str(in_conn[num]), dst,
                       'IN_' + str(in_conn[num]), memlet)

    out_conn_nums = [
        num for _, _, dst, _, _, num in new_out_edges.values()
        if dst is new_exit
    ]
    out_conn = {num: i + 1 for i, num in enumerate(out_conn_nums)}
    exit_in_connectors = set()
    exit_out_connectors = set()
    for i in range(len(out_conn_nums)):
        exit_in_connectors.add('IN_' + str(i + 1))
        exit_out_connectors.add('OUT_' + str(i + 1))
    new_exit.in_connectors = exit_in_connectors
    new_exit.out_connectors = exit_out_connectors
    for src, _, dst, _, memlet, num in new_out_edges.values():
        graph.add_edge(src, 'OUT_' + str(out_conn[num]), dst,
                       'IN_' + str(out_conn[num]), memlet)

    # Return strip-mined dimension.
    return target_dim, new_dim, new_map
class Vectorization(pattern_matching.Transformation):
    """ Implements the vectorization transformation.

        Vectorization matches when all the input and output memlets of a
        tasklet inside a map access the inner-most loop variable in their
        last dimension. The transformation changes the step of the
        inner-most loop to be equal to the length of the vector and
        vectorizes the memlets.
    """

    vector_len = Property(desc="Vector length", dtype=int, default=4)

    _map_entry = nodes.MapEntry(nodes.Map("", [], []))
    _tasklet = nodes.Tasklet('_')
    _map_exit = nodes.MapExit(nodes.Map("", [], []))

    @staticmethod
    def expressions():
        # Single pattern: map entry -> tasklet -> map exit
        pattern = nxutil.node_path_graph(Vectorization._map_entry,
                                         Vectorization._tasklet,
                                         Vectorization._map_exit)
        return [pattern]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        map_entry = graph.nodes()[candidate[Vectorization._map_entry]]
        tasklet = graph.nodes()[candidate[Vectorization._tasklet]]
        # Inner-most map parameter
        param = symbolic.pystr_to_symbolic(map_entry.map.params[-1])
        found = False

        # Check if all edges, adjacent to the tasklet,
        # use the parameter in their last dimension.
        for _src, _, _dest, _, memlet in graph.all_edges(tasklet):

            # Cases that do not matter for vectorization
            if (isinstance(sdfg.arrays[memlet.data], data.Stream)
                    or memlet.wcr is not None):
                continue

            try:
                subset = memlet.subset
                veclen = memlet.veclen
            except AttributeError:
                return False

            if subset is None:
                return False

            try:
                # Memlet is already vectorized
                if veclen > symbolic.pystr_to_symbolic('1'):
                    return False

                for idx, expr in enumerate(subset):
                    # Treat a plain index like a one-element range so both
                    # shapes share the same check
                    parts = expr if isinstance(expr, tuple) else (expr, )
                    for part in parts:
                        free = symbolic.pystr_to_symbolic(part).free_symbols
                        if param not in free:
                            continue
                        if idx != subset.dims() - 1:
                            return False
                        found = True
            except TypeError:  # cannot determine truth value of Relational
                return False

        return found

    @staticmethod
    def match_to_str(graph, candidate):
        pattern_nodes = (Vectorization._map_entry, Vectorization._tasklet,
                         Vectorization._map_exit)
        matched = [candidate[pnode] for pnode in pattern_nodes]
        return ' -> '.join(str(node) for node in matched)

    def apply(self, sdfg):
        state = sdfg.nodes()[self.state_id]
        map_entry = state.nodes()[self.subgraph[Vectorization._map_entry]]
        tasklet = state.nodes()[self.subgraph[Vectorization._tasklet]]
        param = symbolic.pystr_to_symbolic(map_entry.map.params[-1])

        # Create new vector size.
        vector_size = self.vector_len

        # Change the step of the inner-most dimension.
        dim_from, dim_to, _dim_step = map_entry.map.range[-1]
        map_entry.map.range[-1] = (dim_from, dim_to, vector_size)

        # Vectorize memlets adjacent to the tasklet.
        for _src, _, _dest, _, memlet in state.all_edges(tasklet):
            lastindex = memlet.subset[-1]
            if isinstance(lastindex, tuple):
                symbols = set()
                for component in lastindex:
                    symbols.update(
                        symbolic.pystr_to_symbolic(component).free_symbols)
            else:
                symbols = symbolic.pystr_to_symbolic(
                    memlet.subset[-1]).free_symbols

            if param not in symbols:
                continue
            try:
                memlet.veclen = vector_size
            except AttributeError:
                return

        # TODO: Create new map for non-vectorizable part.

        return

    def modifies_graph(self):
        return True
class MapWCRFusion(pm.Transformation):
    """ Implements the map expanded-reduce fusion transformation.

        Fuses a map with an immediately following reduction, where the array
        between the map and the reduction is not used anywhere else, and the
        reduction is divided to two maps with a WCR, denoting partial
        reduction.
    """

    _tasklet = nodes.Tasklet('_')
    _tmap_exit = nodes.MapExit(nodes.Map("", [], []))
    _in_array = nodes.AccessNode('_')
    _rmap_in_entry = nodes.MapEntry(nodes.Map("", [], []))
    _rmap_in_tasklet = nodes.Tasklet('_')
    _rmap_in_cr = nodes.MapExit(nodes.Map("", [], []))
    _rmap_out_entry = nodes.MapEntry(nodes.Map("", [], []))
    _rmap_out_exit = nodes.MapExit(nodes.Map("", [], []))
    _out_array = nodes.AccessNode('_')

    @staticmethod
    def expressions():
        # Map, then partial reduction of axes
        pattern = nxutil.node_path_graph(
            MapWCRFusion._tasklet, MapWCRFusion._tmap_exit,
            MapWCRFusion._in_array, MapWCRFusion._rmap_out_entry,
            MapWCRFusion._rmap_in_entry, MapWCRFusion._rmap_in_tasklet,
            MapWCRFusion._rmap_in_cr, MapWCRFusion._rmap_out_exit,
            MapWCRFusion._out_array)
        return [pattern]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        tmap_exit = graph.nodes()[candidate[MapWCRFusion._tmap_exit]]
        in_array = graph.nodes()[candidate[MapWCRFusion._in_array]]
        rmap_entry = graph.nodes()[candidate[MapWCRFusion._rmap_out_entry]]

        # Make sure that the array is only accessed by the map and the reduce
        if any(writer != tmap_exit
               for writer, _, _, _, _memlet in graph.in_edges(in_array)):
            return False
        if any(reader != rmap_entry
               for _, _, reader, _, _memlet in graph.out_edges(in_array)):
            return False

        # Make sure that there is a reduction in the second map
        rmap_cr = graph.nodes()[candidate[MapWCRFusion._rmap_in_cr]]
        if graph.in_edges(rmap_cr)[0].data.wcr is None:
            return False

        # (strict) Make sure that the transient is not accessed anywhere else
        # in this state or other states
        if strict:
            same_data_nodes = [
                n for n in graph.nodes()
                if isinstance(n, nodes.AccessNode) and n.data == in_array.data
            ]
            if (len(same_data_nodes) > 1
                    or in_array.data in sdfg.shared_transients()):
                return False

        # Verify that reduction ranges match tasklet map
        tout_memlet = graph.in_edges(in_array)[0].data
        rin_memlet = graph.out_edges(in_array)[0].data
        if tout_memlet.subset != rin_memlet.subset:
            return False

        return True

    @staticmethod
    def match_to_str(graph, candidate):
        pattern_nodes = (MapWCRFusion._tasklet, MapWCRFusion._tmap_exit,
                         MapWCRFusion._rmap_in_cr)
        matched = [candidate[pnode] for pnode in pattern_nodes]
        return ' -> '.join(str(node) for node in matched)

    def apply(self, sdfg):
        graph = sdfg.node(self.state_id)

        # To apply, collapse the second map and then fuse the two resulting
        # maps
        collapse_match = {
            MapCollapse._outer_map_entry:
            self.subgraph[MapWCRFusion._rmap_out_entry],
            MapCollapse._inner_map_entry:
            self.subgraph[MapWCRFusion._rmap_in_entry]
        }
        map_collapse = MapCollapse(self.sdfg_id, self.state_id,
                                   collapse_match, 0)
        map_entry, _ = map_collapse.apply(sdfg)

        fusion_match = {
            MapFusion._first_map_exit:
            self.subgraph[MapWCRFusion._tmap_exit],
            MapFusion._second_map_entry:
            graph.node_id(map_entry)
        }
        map_fusion = MapFusion(self.sdfg_id, self.state_id, fusion_match, 0)
        map_fusion.apply(sdfg)
def _stripmine(self, sdfg, graph, candidate):
    """ Strip-mines one dimension of the matched map by wrapping it in a
        new outer map over tiles and restricting the original dimension to
        a single tile.

        Reads `self.dim_idx`, `self.new_dim_prefix`, `self.tile_size`,
        `self.divides_evenly`, `self.strided` and `self.tile_stride`; an
        unset or empty tile stride falls back to the tile size.

        :param sdfg: The SDFG containing the state.
        :param graph: The state in which the map resides.
        :param candidate: Mapping from pattern nodes to node indices in
                          `graph`.
        :return: A 3-tuple (target_dim, new_dim, new_map).
    """
    # Retrieve map entry and exit nodes.
    map_entry = graph.nodes()[candidate[StripMining._map_entry]]
    map_exit = graph.exit_nodes(map_entry)[0]

    # Retrieve transformation properties.
    dim_idx = self.dim_idx
    new_dim_prefix = self.new_dim_prefix
    tile_size = self.tile_size
    divides_evenly = self.divides_evenly
    strided = self.strided

    tile_stride = self.tile_stride
    if tile_stride is None or len(tile_stride) == 0:
        tile_stride = tile_size

    # Retrieve parameter and range of dimension to be strip-mined.
    target_dim = map_entry.map.params[dim_idx]
    td_from, td_to, td_step = map_entry.map.range[dim_idx]

    # Create new map. Replace by cloning???
    new_dim = self._find_new_dim(sdfg, graph, map_entry, new_dim_prefix,
                                 target_dim)
    nd_from = 0
    nd_to = symbolic.pystr_to_symbolic(
        'int_ceil(%s + 1 - %s, %s) - 1' %
        (symbolic.symstr(td_to), symbolic.symstr(td_from), tile_stride))
    nd_step = 1
    new_dim_range = (nd_from, nd_to, nd_step)
    new_map = nodes.Map(new_dim + '_' + map_entry.map.label, [new_dim],
                        subsets.Range([new_dim_range]))
    new_map_entry = nodes.MapEntry(new_map)
    new_map_exit = nodes.MapExit(new_map)

    # Change the range of the selected dimension to iterate over a single
    # tile
    if strided:
        td_from_new = symbolic.pystr_to_symbolic(new_dim)
        td_to_new_approx = td_to
        td_step = symbolic.pystr_to_symbolic(tile_size)
    else:
        td_from_new = symbolic.pystr_to_symbolic(
            '%s + %s * %s' % (symbolic.symstr(td_from), str(new_dim),
                              tile_stride))
        td_to_new_exact = symbolic.pystr_to_symbolic(
            'min(%s + 1, %s + %s * %s + %s) - 1' %
            (symbolic.symstr(td_to), symbolic.symstr(td_from), tile_stride,
             str(new_dim), tile_size))
        td_to_new_approx = symbolic.pystr_to_symbolic(
            '%s + %s * %s + %s - 1' % (symbolic.symstr(td_from), tile_stride,
                                       str(new_dim), tile_size))
    if divides_evenly or strided:
        td_to_new = td_to_new_approx
    else:
        td_to_new = dace.symbolic.SymExpr(td_to_new_exact, td_to_new_approx)
    map_entry.map.range[dim_idx] = (td_from_new, td_to_new, td_step)

    # Make internal map's schedule to "not parallel"
    new_map.schedule = map_entry.map.schedule
    map_entry.map.schedule = dtypes.ScheduleType.Sequential

    # Redirect edges
    new_map_entry.in_connectors = dcpy(map_entry.in_connectors)
    nxutil.change_edge_dest(graph, map_entry, new_map_entry)
    new_map_exit.out_connectors = dcpy(map_exit.out_connectors)
    nxutil.change_edge_src(graph, map_exit, new_map_exit)

    # Create new entry edges.
    # Keys are (data name, inner IN connector, outer OUT connector).
    new_in_edges = dict()
    entry_in_conn = set()
    entry_out_conn = set()
    for _src, src_conn, _dst, _, memlet in graph.out_edges(map_entry):
        if (src_conn is not None and src_conn[:4] == 'OUT_'
                and not isinstance(sdfg.arrays[memlet.data],
                                   dace.data.Scalar)):
            new_subset = calc_set_image(
                map_entry.map.params,
                map_entry.map.range,
                memlet.subset,
            )
            conn = src_conn[4:]
            key = (memlet.data, 'IN_' + conn, 'OUT_' + conn)
            if key in new_in_edges:
                old_subset = new_in_edges[key].subset
                new_in_edges[key].subset = calc_set_union(
                    old_subset, new_subset)
            else:
                entry_in_conn.add('IN_' + conn)
                entry_out_conn.add('OUT_' + conn)
                new_memlet = dcpy(memlet)
                new_memlet.subset = new_subset
                new_memlet.num_accesses = new_memlet.num_elements()
                new_in_edges[key] = new_memlet
        else:
            if src_conn is not None and src_conn[:4] == 'OUT_':
                conn = src_conn[4:]
                in_conn = 'IN_' + conn
                out_conn = 'OUT_' + conn
            else:
                in_conn = src_conn
                out_conn = src_conn
            if in_conn:
                entry_in_conn.add(in_conn)
            if out_conn:
                entry_out_conn.add(out_conn)
            new_in_edges[(memlet.data, in_conn, out_conn)] = dcpy(memlet)
    new_map_entry.out_connectors = entry_out_conn
    map_entry.in_connectors = entry_in_conn
    for (_, in_conn, out_conn), memlet in new_in_edges.items():
        graph.add_edge(new_map_entry, out_conn, map_entry, in_conn, memlet)

    # Create new exit edges (same key layout as above)
    new_out_edges = dict()
    exit_in_conn = set()
    exit_out_conn = set()
    for _src, _, _dst, dst_conn, memlet in graph.in_edges(map_exit):
        if (dst_conn is not None and dst_conn[:3] == 'IN_'
                and not isinstance(sdfg.arrays[memlet.data],
                                   dace.data.Scalar)):
            new_subset = calc_set_image(
                map_entry.map.params,
                map_entry.map.range,
                memlet.subset,
            )
            conn = dst_conn[3:]
            key = (memlet.data, 'IN_' + conn, 'OUT_' + conn)
            if key in new_out_edges:
                old_subset = new_out_edges[key].subset
                new_out_edges[key].subset = calc_set_union(
                    old_subset, new_subset)
            else:
                exit_in_conn.add('IN_' + conn)
                exit_out_conn.add('OUT_' + conn)
                new_memlet = dcpy(memlet)
                new_memlet.subset = new_subset
                new_memlet.num_accesses = new_memlet.num_elements()
                new_out_edges[key] = new_memlet
        else:
            if dst_conn is not None and dst_conn[:3] == 'IN_':
                conn = dst_conn[3:]
                in_conn = 'IN_' + conn
                out_conn = 'OUT_' + conn
            else:
                # On the exit side the connector of interest is the one on
                # the map exit itself, i.e. the edge's destination connector
                in_conn = dst_conn
                out_conn = dst_conn
            if in_conn:
                exit_in_conn.add(in_conn)
            if out_conn:
                exit_out_conn.add(out_conn)
            # Must be recorded among the exit edges; the entry edges have
            # already been added to the graph at this point
            new_out_edges[(memlet.data, in_conn, out_conn)] = dcpy(memlet)
    new_map_exit.in_connectors = exit_in_conn
    map_exit.out_connectors = exit_out_conn
    for (_, in_conn, out_conn), memlet in new_out_edges.items():
        graph.add_edge(map_exit, out_conn, new_map_exit, in_conn, memlet)

    # Return strip-mined dimension.
    return target_dim, new_dim, new_map
def _build_dataflow_graph_recurse(sdfg, state, primitives, modules,
                                  superEntry, super_exit):
    """ Recursively adds the nodes and memlet edges of a list of AST
        primitives to a state.

        For every primitive, an entry/exit node pair is created (the same
        node for tasklets, reductions and nested SDFGs), incoming edges are
        connected, children are processed recursively with the new entry and
        exit as their enclosing scope, and outgoing edges are connected.

        :param sdfg: The SDFG being built.
        :param state: The state that receives nodes and edges.
        :param primitives: AST primitives to process; an empty list is
                           replaced by a single empty tasklet.
        :param modules: Passed to ModuleInliner when cleaning Python tasklet
                        code.
        :param superEntry: Entry node of the enclosing scope, or a falsy
                           value at the top level.
        :param super_exit: Exit node of the enclosing scope, or None at the
                           top level.
        :return: A list of (exit node, memlet) pairs that the calling
                 recursion level has to route further. Note that a bare
                 `return` (None) is issued when a _ProgramNode primitive is
                 encountered.
        :raise ValueError: If a map or consume primitive has no children.
        :raise SyntaxError: If a non-Python tasklet has no external code.
        :raise TypeError: If a primitive or output source has an unexpected
                          type.
        :raise RuntimeError: If a top-level primitive has an unexpected
                             parent, or a WCR memlet was left unhandled.
    """
    # Array of pairs (exit node, memlet)
    exit_nodes = []

    if len(primitives) == 0:
        # Inject empty tasklets into empty states
        primitives = [astnodes._EmptyTaskletNode("Empty Tasklet", None)]

    for prim in primitives:
        label = prim.name

        # Expand node to get entry and exit points
        if isinstance(prim, astnodes._MapNode):
            if len(prim.children) == 0:
                raise ValueError("Map node expected to have children")
            mapNode = nd.Map(
                label, prim.params, prim.range, is_async=prim.is_async)
            # Add connectors for inputs that exist as array nodes
            entry = nd.MapEntry(
                mapNode,
                _get_input_symbols(prim.inputs, prim.range.free_symbols))
            exit = nd.MapExit(mapNode)
        elif isinstance(prim, astnodes._ConsumeNode):
            if len(prim.children) == 0:
                raise ValueError("Consume node expected to have children")
            consumeNode = nd.Consume(label, (prim.params[1], prim.num_pes),
                                     prim.condition)
            entry = nd.ConsumeEntry(consumeNode)
            exit = nd.ConsumeExit(consumeNode)
        elif isinstance(prim, astnodes._ReduceNode):
            # A reduction is a single node acting as both entry and exit
            rednode = nd.Reduce(prim.ast, prim.axes, prim.identity)
            state.add_node(rednode)
            entry = rednode
            exit = rednode
        elif isinstance(prim, astnodes._TaskletNode):
            if isinstance(prim, astnodes._EmptyTaskletNode):
                tasklet = nd.EmptyTasklet(prim.name)
            else:
                # Remove memlets from tasklet AST
                if prim.language == types.Language.Python:
                    clean_code = MemletRemover().visit(prim.ast)
                    clean_code = ModuleInliner(modules).visit(clean_code)
                else:  # Use external code from tasklet definition
                    if prim.extcode is None:
                        raise SyntaxError("Cannot define an intrinsic "
                                          "tasklet without an implementation")
                    clean_code = prim.extcode
                tasklet = nd.Tasklet(
                    prim.name,
                    set(prim.inputs.keys()),
                    set(prim.outputs.keys()),
                    code=clean_code,
                    language=prim.language,
                    code_global=prim.gcode)  # TODO: location=prim.location

            # Need to add the tasklet in case we're in an empty state, where no
            # edge will be drawn to it
            state.add_node(tasklet)
            entry = tasklet
            exit = tasklet

        elif isinstance(prim, astnodes._NestedSDFGNode):
            # Register the nested SDFG with its parent state and SDFG
            prim.sdfg.parent = state
            prim.sdfg._parent_sdfg = sdfg
            prim.sdfg.update_sdfg_list([])
            nsdfg = nd.NestedSDFG(prim.name, prim.sdfg,
                                  set(prim.inputs.keys()),
                                  set(prim.outputs.keys()))
            state.add_node(nsdfg)
            entry = nsdfg
            exit = nsdfg

        elif isinstance(prim, astnodes._ProgramNode):
            # Stops processing altogether; the function returns None here
            return
        elif isinstance(prim, astnodes._ControlFlowNode):
            # Control-flow primitives add nothing to this state
            continue
        else:
            raise TypeError("Node type not implemented: " +
                            str(prim.__class__))

        # Add incoming edges
        for varname, memlet in prim.inputs.items():
            arr = memlet.dataname
            if (prim.parent is not None
                    and memlet.dataname in prim.parent.transients.keys()):
                node = input_node_for_array(state, memlet.dataname)

                # Add incoming edge into transient as well
                # FIXME: A bit hacked?
                if arr in prim.parent.inputs:
                    astmem = prim.parent.inputs[arr]
                    _add_astmemlet_edge(sdfg, state, superEntry, None, node,
                                        None, astmem)

                    # Remove local name from incoming edge to parent
                    prim.parent.inputs[arr].local_name = None
            elif superEntry:
                node = superEntry
            else:
                node = input_node_for_array(state, memlet.dataname)

            # Destination connector inference
            # Connected to a tasklet or a nested SDFG
            dst_conn = (memlet.local_name
                        if isinstance(entry, nd.CodeNode) else None)
            # Connected to a scope as part of its range
            # (9 is the length of the '__DACEIN_' prefix)
            if str(varname).startswith('__DACEIN_'):
                dst_conn = str(varname)[9:]
            # Handle special case of consume input stream
            if (isinstance(entry, nd.ConsumeEntry)
                    and memlet.data == prim.stream):
                dst_conn = 'IN_stream'

            # If a memlet that covers this input already exists, skip
            # generating this one; otherwise replace memlet with ours
            skip_incoming_edge = False
            remove_edge = None
            for e in state.edges_between(node, entry):
                if e.data.data != memlet.dataname or dst_conn != e.dst_conn:
                    continue
                if e.data.subset.covers(memlet.subset):
                    skip_incoming_edge = True
                    break
                elif memlet.subset.covers(e.data.subset):
                    remove_edge = e
                    break
                else:
                    # Neither subset covers the other: widen the existing
                    # edge in place instead of adding a second one
                    print('WARNING: Performing bounding-box union on',
                          memlet.subset, 'and', e.data.subset, '(in)')
                    e.data.subset = sbs.bounding_box_union(
                        e.data.subset, memlet.subset)
                    e.data.num_accesses += memlet.num_accesses
                    skip_incoming_edge = True
                    break

            if remove_edge is not None:
                state.remove_edge(remove_edge)

            if skip_incoming_edge == False:
                _add_astmemlet_edge(sdfg, state, node, None, entry, dst_conn,
                                    memlet)

        # If there are no inputs, generate a dummy edge
        if superEntry and len(prim.inputs) == 0:
            state.add_edge(superEntry, None, entry, None, EmptyMemlet())

        if len(prim.children) > 0:
            # Recurse
            inner_outputs = _build_dataflow_graph_recurse(
                sdfg, state, prim.children, modules, entry, exit)
            # Infer output node for each memlet
            for i, (out_src, mem) in enumerate(inner_outputs):
                # If there is no such array in this primitive's outputs,
                # it's an external array (e.g., a map in a map). In this case,
                # connect to the exit node
                if mem.dataname in prim.outputs:
                    inner_outputs[i] = (out_src, prim.outputs[mem.dataname])
                else:
                    inner_outputs[i] = (out_src, mem)
        else:
            inner_outputs = [(exit, mem) for mem in prim.outputs.values()]

        # Add outgoing edges
        for out_src, astmem in inner_outputs:

            data = astmem.data
            dataname = astmem.dataname

            # If WCR is not none, it needs to be handled in the code. Check for
            # this after, as we only expect it for one distinct case
            wcr_was_handled = astmem.wcr is None

            # TODO: This is convoluted. We should find a more readable
            # way of connecting the outgoing edges.

            if super_exit is None:

                # Assert that we're in a top-level node
                if ((not isinstance(prim.parent, astnodes._ProgramNode)) and
                    (not isinstance(prim.parent, astnodes._ControlFlowNode))):
                    raise RuntimeError("Expected to be at the top node")

                # Looks hacky
                src_conn = (astmem.local_name if isinstance(
                    out_src, (nd.Tasklet, nd.NestedSDFG)) else None)

                # Here we just need to connect memlets directly to their
                # respective data nodes
                out_tgt = output_node_for_array(state, astmem.dataname)

                # If a memlet that covers this output already exists, skip
                # generating this one; otherwise replace memlet with ours
                skip_outgoing_edge = False
                remove_edge = None
                for e in state.edges_between(out_src, out_tgt):
                    if e.data.data != astmem.dataname or src_conn != e.src_conn:
                        continue
                    if e.data.subset.covers(astmem.subset):
                        skip_outgoing_edge = True
                        break
                    elif astmem.subset.covers(e.data.subset):
                        remove_edge = e
                        break
                    else:
                        # Neither subset covers the other: widen the existing
                        # edge in place instead of adding a second one
                        print('WARNING: Performing bounding-box union on',
                              astmem.subset, 'and', e.data.subset, '(out)')
                        e.data.subset = sbs.bounding_box_union(
                            e.data.subset, astmem.subset)
                        e.data.num_accesses += astmem.num_accesses
                        skip_outgoing_edge = True
                        break

                if skip_outgoing_edge == True:
                    continue

                if remove_edge is not None:
                    state.remove_edge(remove_edge)

                _add_astmemlet_edge(
                    sdfg,
                    state,
                    out_src,
                    src_conn,
                    out_tgt,
                    None,
                    astmem,
                    wcr=astmem.wcr,
                    wcr_identity=astmem.wcr_identity)
                wcr_was_handled = (True if astmem.wcr is not None else
                                   wcr_was_handled)

                # If the program defines another output, connect it too.
                # This refers to the case where we have streams, which
                # must define an input and output, and sometimes this output
                # is defined in pdp.outputs
                if (isinstance(out_tgt, nd.AccessNode)
                        and isinstance(out_tgt.desc(sdfg), dt.Stream)):
                    try:
                        stream_memlet = next(
                            v for k, v in prim.parent.outputs.items()
                            if k == out_tgt.data)
                        stream_output = output_node_for_array(
                            state, stream_memlet.dataname)
                        _add_astmemlet_edge(sdfg, state, out_tgt, None,
                                            stream_output, None, stream_memlet)
                    except StopIteration:  # Stream output not found, skip
                        pass

            else:  # We're in a nest

                if isinstance(prim, astnodes._ScopeNode):
                    # We're a map or a consume node, that needs to connect our
                    # exit to either an array or to the super_exit
                    if data.transient and dataname in prim.parent.transients:
                        # Connect the exit directly
                        out_tgt = output_node_for_array(state, data.dataname)
                        _add_astmemlet_edge(sdfg, state, out_src, None,
                                            out_tgt, None, astmem)
                    else:
                        # This is either a transient defined in an outer scope,
                        # or an I/O array, so redirect through the exit node
                        _add_astmemlet_edge(sdfg, state, out_src, None,
                                            super_exit, None, astmem)
                        # Instruct outer recursion layer to continue the route
                        exit_nodes.append((super_exit, astmem))
                elif isinstance(
                        prim,
                    (astnodes._TaskletNode, astnodes._NestedSDFGNode)):
                    # We're a tasklet, and need to connect either to the exit
                    # if the array is I/O or is defined in a scope further out,
                    # or directly to the transient if it's defined locally
                    if dataname in prim.parent.transients:
                        # This is a local transient variable, so connect to it
                        # directly
                        out_tgt = output_node_for_array(state, data.dataname)
                        _add_astmemlet_edge(sdfg, state, out_src,
                                            astmem.local_name, out_tgt, None,
                                            astmem)
                    else:
                        # This is an I/O array, or an outer level transient, so
                        # redirect through the exit node
                        _add_astmemlet_edge(
                            sdfg,
                            state,
                            out_src,
                            astmem.local_name,
                            super_exit,
                            None,
                            astmem,
                            wcr=astmem.wcr,
                            wcr_identity=astmem.wcr_identity)
                        exit_nodes.append((super_exit, astmem))
                        if astmem.wcr is not None:
                            wcr_was_handled = True  # Sanity check
                else:
                    raise TypeError("Unexpected node type: {}".format(
                        type(out_src).__name__))

            if not wcr_was_handled and not isinstance(prim,
                                                      astnodes._ScopeNode):
                raise RuntimeError("Detected unhandled WCR for primitive '{}' "
                                   "of type {}. WCR is only expected for "
                                   "tasklets in a map/consume scope.".format(
                                       prim.name,
                                       type(prim).__name__))

    return exit_nodes
class MapReduceFusion(pm.Transformation):
    """ Implements the map-reduce-fusion transformation.
        Fuses a map with an immediately following reduction, where the array
        between the map and the reduction is not used anywhere else.

        The matched pattern is the path
        tasklet -> map exit -> intermediate array -> reduce node -> output
        array.
    """

    _tasklet = nodes.Tasklet('_')
    _tmap_exit = nodes.MapExit(nodes.Map("", [], []))
    _in_array = nodes.AccessNode('_')
    _reduce = nodes.Reduce('lambda: None', None)
    _out_array = nodes.AccessNode('_')

    @staticmethod
    def expressions():
        """ Returns the single pattern graph this transformation matches. """
        return [
            nxutil.node_path_graph(MapReduceFusion._tasklet,
                                   MapReduceFusion._tmap_exit,
                                   MapReduceFusion._in_array,
                                   MapReduceFusion._reduce,
                                   MapReduceFusion._out_array)
        ]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Checks whether the matched subgraph can be fused.

            :param graph: State in which the candidate was found.
            :param candidate: Mapping from pattern nodes to node indices in
                              `graph`.
            :param expr_index: Index of the matched expression (unused, there
                               is only one).
            :param sdfg: SDFG containing `graph`.
            :param strict: If True, additionally require that the
                           intermediate array is not accessed anywhere else.
            :return: True if the transformation may be applied.
        """
        tmap_exit = graph.nodes()[candidate[MapReduceFusion._tmap_exit]]
        in_array = graph.nodes()[candidate[MapReduceFusion._in_array]]
        reduce_node = graph.nodes()[candidate[MapReduceFusion._reduce]]
        tasklet = graph.nodes()[candidate[MapReduceFusion._tasklet]]

        # Make sure that the array is only accessed by the map and the reduce
        if any(src != tmap_exit
               for src, _, _, _, _ in graph.in_edges(in_array)):
            return False
        if any(dest != reduce_node
               for _, _, dest, _, _ in graph.out_edges(in_array)):
            return False

        # Find the tasklet output that writes the intermediate array. If the
        # tasklet reaches the map exit only through other data, there is
        # nothing to fuse; reject the match rather than leaking StopIteration.
        tmem_edge = next((e for e in graph.edges_between(tasklet, tmap_exit)
                          if e.data.data == in_array.data), None)
        if tmem_edge is None:
            return False
        tmem = tmem_edge.data

        # (strict) Make sure that the transient is not accessed anywhere else
        # in this state or other states
        if strict and (len([
                n for n in graph.nodes()
                if isinstance(n, nodes.AccessNode) and n.data == in_array.data
        ]) > 1 or in_array.data in sdfg.shared_transients()):
            return False

        # If memlet already has WCR and it is different from reduce node,
        # do not match
        if tmem.wcr is not None and tmem.wcr != reduce_node.wcr:
            return False

        # Verify that reduction ranges match tasklet map
        tout_memlet = graph.in_edges(in_array)[0].data
        rin_memlet = graph.out_edges(in_array)[0].data
        if tout_memlet.subset != rin_memlet.subset:
            return False

        return True

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns a human-readable description of the match. """
        tasklet = candidate[MapReduceFusion._tasklet]
        map_exit = candidate[MapReduceFusion._tmap_exit]
        reduce = candidate[MapReduceFusion._reduce]

        return ' -> '.join(str(node) for node in [tasklet, map_exit, reduce])

    def apply(self, sdfg):
        """ Removes the intermediate array and the reduce node, and redirects
            the tasklet's output memlet to the reduction's output array with
            the reduce node's WCR function and identity.

            :param sdfg: SDFG to modify in place.
            :raise RuntimeError: If no memlet entering the map exit refers to
                                 the intermediate array.
        """
        graph = sdfg.nodes()[self.state_id]
        tmap_exit = graph.nodes()[self.subgraph[MapReduceFusion._tmap_exit]]
        in_array = graph.nodes()[self.subgraph[MapReduceFusion._in_array]]
        reduce_node = graph.nodes()[self.subgraph[MapReduceFusion._reduce]]
        out_array = graph.nodes()[self.subgraph[MapReduceFusion._out_array]]

        # The intermediate array and the reduce node are fused away
        nodes_to_remove = [in_array, reduce_node]

        memlet_edge = None
        for edge in graph.in_edges(tmap_exit):
            if edge.data.data == in_array.data:
                memlet_edge = edge
                break
        if memlet_edge is None:
            raise RuntimeError('Reduction memlet cannot be None')

        # Find which indices should be removed from new memlet. With no
        # explicit axes on the reduce node, every dimension of its input is
        # reduced. range() needs the dimension count, not the subset itself.
        input_edge = graph.in_edges(reduce_node)[0]
        axes = reduce_node.axes or list(range(len(input_edge.data.subset)))
        # Must be read before the reduce node (and its edges) is removed
        array_edge = graph.out_edges(reduce_node)[0]

        # Delete relevant edges and nodes
        graph.remove_nodes_from(nodes_to_remove)

        # Filter out reduced dimensions from subset
        filtered_subset = [
            dim for i, dim in enumerate(memlet_edge.data.subset)
            if i not in axes
        ]
        if len(filtered_subset) == 0:  # Output is a scalar
            filtered_subset = [0]

        # Modify edge from tasklet to map exit
        memlet_edge.data.data = out_array.data
        memlet_edge.data.wcr = reduce_node.wcr
        memlet_edge.data.wcr_identity = reduce_node.identity
        memlet_edge.data.subset = type(
            memlet_edge.data.subset)(filtered_subset)

        # Add edge from map exit to output array. The exit's output connector
        # is derived from the input connector by replacing its first three
        # characters with 'OUT_' (presumably an 'IN_' prefix).
        graph.add_edge(
            memlet_edge.dst, 'OUT_' + memlet_edge.dst_conn[3:],
            array_edge.dst, array_edge.dst_conn,
            Memlet(array_edge.data.data, array_edge.data.num_accesses,
                   array_edge.data.subset, array_edge.data.veclen,
                   reduce_node.wcr, reduce_node.identity))
class Vectorization(pattern_matching.Transformation):
    """ Implements the vectorization transformation.

        Matches a map entry -> tasklet -> map exit path in which every
        non-empty, non-stream memlet adjacent to the tasklet uses the
        inner-most map parameter only in its last dimension. Applying it sets
        the step of the inner-most map dimension to the vector length and
        marks the affected memlet paths with that vector length.
    """

    vector_len = Property(desc="Vector length", dtype=int, default=4)
    propagate_parent = Property(desc="Propagate vector length through "
                                "parent SDFGs",
                                dtype=bool,
                                default=False)

    _map_entry = nodes.MapEntry(nodes.Map("", [], []))
    _tasklet = nodes.Tasklet('_')
    _map_exit = nodes.MapExit(nodes.Map("", [], []))

    @staticmethod
    def expressions():
        """ Returns the single pattern graph this transformation matches. """
        return [
            nxutil.node_path_graph(Vectorization._map_entry,
                                   Vectorization._tasklet,
                                   Vectorization._map_exit)
        ]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Returns True if at least one memlet adjacent to the tasklet uses
            the inner-most map parameter in its last dimension, and none uses
            it elsewhere, carries a WCR, or is already vectorized.
        """
        state_nodes = graph.nodes()
        map_entry = state_nodes[candidate[Vectorization._map_entry]]
        tasklet = state_nodes[candidate[Vectorization._tasklet]]
        param = symbolic.pystr_to_symbolic(map_entry.map.params[-1])
        param_in_last_dim = False

        for _u, _uconn, _v, _vconn, memlet in graph.all_edges(tasklet):
            # Empty memlets and streams are irrelevant for vectorization
            if memlet.data is None:
                continue
            if isinstance(sdfg.arrays[memlet.data], data.Stream):
                continue

            # Write-conflict resolution cannot be vectorized
            if memlet.wcr is not None:
                return False

            try:
                subset = memlet.subset
                veclen = memlet.veclen
            except AttributeError:
                return False
            if subset is None:
                return False

            try:
                # Already vectorized
                if veclen > symbolic.pystr_to_symbolic('1'):
                    return False

                for idx, expr in enumerate(subset):
                    # A dimension is either a tuple of expressions or a
                    # single expression; treat both uniformly
                    parts = expr if isinstance(expr, tuple) else (expr, )
                    for part in parts:
                        free = symbolic.pystr_to_symbolic(part).free_symbols
                        if param not in free:
                            continue
                        if idx != subset.dims() - 1:
                            return False
                        param_in_last_dim = True
            except TypeError:  # cannot determine truth value of Relational
                return False

        return param_in_last_dim

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns a human-readable description of the match. """
        matched = [
            candidate[Vectorization._map_entry],
            candidate[Vectorization._tasklet],
            candidate[Vectorization._map_exit],
        ]
        return ' -> '.join(str(node) for node in matched)

    def apply(self, sdfg):
        """ Sets the inner-most map step to `vector_len` and assigns that
            vector length to every memlet path through the tasklet whose last
            index uses the inner-most map parameter. If `propagate_parent` is
            set and the SDFG has a parent, the vector length is also applied
            to the matching memlet paths around the nested SDFG node.
        """
        graph = sdfg.nodes()[self.state_id]
        map_entry = graph.nodes()[self.subgraph[Vectorization._map_entry]]
        tasklet = graph.nodes()[self.subgraph[Vectorization._tasklet]]
        map_exit = graph.nodes()[self.subgraph[Vectorization._map_exit]]
        param = symbolic.pystr_to_symbolic(map_entry.map.params[-1])

        # Create new vector size.
        vector_size = self.vector_len

        # Change the step of the inner-most dimension.
        dim_from, dim_to, _dim_step = map_entry.map.range[-1]
        map_entry.map.range[-1] = (dim_from, dim_to, vector_size)

        # Vectorize memlets adjacent to the tasklet.
        for edge in graph.all_edges(tasklet):
            _u, _uconn, _v, _vconn, memlet = edge

            if memlet.data is None:  # Empty memlets
                continue

            # Gather the free symbols of the last index
            lastindex = memlet.subset[-1]
            if isinstance(lastindex, tuple):
                symbols = set()
                for component in lastindex:
                    symbols |= symbolic.pystr_to_symbolic(
                        component).free_symbols
            else:
                symbols = symbolic.pystr_to_symbolic(lastindex).free_symbols

            if param not in symbols:
                continue

            # Propagate vector length inside this SDFG
            path = graph.memlet_path(edge)
            for path_edge in path:
                path_edge.data.veclen = vector_size
            source_edge = path[0]
            sink_edge = path[-1]

            # Propagate to the parent
            # (TODO: handle multiple level of nestings)
            if not self.propagate_parent or sdfg.parent is None:
                continue

            # Find parent Nested SDFG node
            parent_node = next(n for n in sdfg.parent.nodes()
                               if isinstance(n, nodes.NestedSDFG)
                               and n.sdfg.name == sdfg.name)

            # Continue propagating along paths that arrive at the source of
            # the inner path or leave from its sink
            for pe in sdfg.parent.all_edges(parent_node):
                arrives = str(pe.dst_conn) == str(source_edge.src)
                if arrives or str(pe.src_conn) == str(sink_edge.dst):
                    for ppe in sdfg.parent.memlet_path(pe):
                        ppe.data.veclen = vector_size

        return
class MapReduceFusion(pm.Transformation):
    """ Implements the map-reduce-fusion transformation.
        Fuses a map with an immediately following reduction, where the array
        between the map and the reduction is not used anywhere else.

        Three patterns are matched (see `expressions`):
          0. map, then a reduction map over all axes;
          1. map, then an outer map wrapping a reduction map (partial
             reduction);
          2. map, then a reduce node.
    """

    _tasklet = nodes.Tasklet('_')
    _tmap_exit = nodes.MapExit(nodes.Map("", [], []))
    _in_array = nodes.AccessNode('_')
    _rmap_in_entry = nodes.MapEntry(nodes.Map("", [], []))
    _rmap_in_tasklet = nodes.Tasklet('_')
    _rmap_in_cr = nodes.MapExit(nodes.Map("", [], []))
    _rmap_out_entry = nodes.MapEntry(nodes.Map("", [], []))
    _rmap_out_exit = nodes.MapExit(nodes.Map("", [], []))
    _out_array = nodes.AccessNode('_')
    _reduce = nodes.Reduce('lambda: None', None)

    @staticmethod
    def expressions():
        """ Returns the three pattern graphs, indexed by `expr_index`. """
        return [
            # Map, then reduce of all axes
            nxutil.node_path_graph(
                MapReduceFusion._tasklet, MapReduceFusion._tmap_exit,
                MapReduceFusion._in_array, MapReduceFusion._rmap_in_entry,
                MapReduceFusion._rmap_in_tasklet, MapReduceFusion._rmap_in_cr,
                MapReduceFusion._out_array),
            # Map, then partial reduction of axes
            nxutil.node_path_graph(
                MapReduceFusion._tasklet, MapReduceFusion._tmap_exit,
                MapReduceFusion._in_array, MapReduceFusion._rmap_out_entry,
                MapReduceFusion._rmap_in_entry,
                MapReduceFusion._rmap_in_tasklet, MapReduceFusion._rmap_in_cr,
                MapReduceFusion._rmap_out_exit, MapReduceFusion._out_array),
            # Map, then reduce node
            nxutil.node_path_graph(
                MapReduceFusion._tasklet, MapReduceFusion._tmap_exit,
                MapReduceFusion._in_array, MapReduceFusion._reduce,
                MapReduceFusion._out_array)
        ]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Checks whether the matched subgraph can be fused.

            :param graph: State in which the candidate was found.
            :param candidate: Mapping from pattern nodes to node indices in
                              `graph`.
            :param expr_index: Index of the matched expression (0, 1 or 2).
            :param sdfg: SDFG containing `graph` (unused).
            :param strict: Unused.
            :return: True if the transformation may be applied.
        """
        tmap_exit = graph.nodes()[candidate[MapReduceFusion._tmap_exit]]
        in_array = graph.nodes()[candidate[MapReduceFusion._in_array]]
        if expr_index == 0:  # Reduce without outer map
            rmap_entry = graph.nodes()[candidate[
                MapReduceFusion._rmap_in_entry]]
        elif expr_index == 1:  # Reduce with outer map
            rmap_entry = graph.nodes()[candidate[
                MapReduceFusion._rmap_out_entry]]
        else:  # Reduce node
            rmap_entry = graph.nodes()[candidate[MapReduceFusion._reduce]]

        # Make sure that the array is only accessed by the map and the reduce
        if any([
                src != tmap_exit
                for src, _, _, _, memlet in graph.in_edges(in_array)
        ]):
            return False
        if any([
                dest != rmap_entry
                for _, _, dest, _, memlet in graph.out_edges(in_array)
        ]):
            return False

        # Make sure that there is a reduction in the second map
        if expr_index < 2:
            rmap_cr = graph.nodes()[candidate[MapReduceFusion._rmap_in_cr]]
            reduce_edge = graph.in_edges(rmap_cr)[0]
            if reduce_edge.data.wcr is None:
                return False

        # TODO: The following checks are not implemented yet:
        #   * the transient is not accessed by other states;
        #   * the reduction covers full-range arrays (partial reductions
        #     would need slice/subarray handling);
        #   * only one value is reduced at a time through the map exit;
        #   * each memlet index is a map parameter whose range starts at 0,
        #     has unit step, and matches the reduction's input slice (for
        #     reduced axes) or output slice (for remaining axes).

        # Verify that reduction ranges match tasklet map
        tout_memlet = graph.in_edges(in_array)[0].data
        rin_memlet = graph.out_edges(in_array)[0].data
        if tout_memlet.subset != rin_memlet.subset:
            return False

        return True

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns a human-readable description of the match. """
        tasklet = candidate[MapReduceFusion._tasklet]
        map_exit = candidate[MapReduceFusion._tmap_exit]
        # Only the reduce-node pattern has exactly five nodes
        if len(candidate) == 5:  # Expression 2
            reduce = candidate[MapReduceFusion._reduce]
        else:
            reduce = candidate[MapReduceFusion._rmap_in_cr]

        return ' -> '.join(str(node) for node in [tasklet, map_exit, reduce])

    @staticmethod
    def find_memlet_map_permutation(memlet: Memlet, map: nodes.Map):
        """ For every dimension of the memlet subset, finds the index of the
            first not-yet-used map parameter equal to it.

            :return: List with one entry per memlet dimension: the matching
                     map parameter index, or None if nothing matched.
        """
        perm = [None] * len(memlet.subset)
        indices = set()
        for i, dim in enumerate(memlet.subset):
            for j, mapdim in enumerate(map.params):
                if symbolic.pystr_to_symbolic(
                        mapdim) == dim and j not in indices:
                    perm[i] = j
                    indices.add(j)
                    break
        return perm

    @staticmethod
    def find_permutation(tasklet_map: nodes.Map, red_outer_map: nodes.Map,
                         red_inner_map: nodes.Map, tmem: Memlet):
        """ Splits the tasklet map's dimensions into those whose range equals
            a range of the outer reduction map and those whose range equals a
            range of the inner (conflict-resolution) map. `tmem` is unused.

            :return: Tuple ([outer map indices], [inner (CR) map indices]),
                     both indexing into `tasklet_map.range`.
            :raise AssertionError: If the number of dimensions does not add
                                   up or some dimension cannot be matched.
        """
        result = [], []

        assert len(tasklet_map.range) == len(red_inner_map.range) + len(
            red_outer_map.range)

        # Match map ranges with reduce ranges
        unavailable_ranges_out = set()
        unavailable_ranges_in = set()
        for i, tmap_rng in enumerate(tasklet_map.range):
            found = False
            for j, rng in enumerate(red_outer_map.range):
                if tmap_rng == rng and j not in unavailable_ranges_out:
                    result[0].append(i)
                    unavailable_ranges_out.add(j)
                    found = True
                    break
            if found:
                continue
            for j, rng in enumerate(red_inner_map.range):
                if tmap_rng == rng and j not in unavailable_ranges_in:
                    result[1].append(i)
                    unavailable_ranges_in.add(j)
                    found = True
                    break
            if not found:
                break

        # Ensure all map variables matched with reduce variables
        assert len(result[0]) + len(result[1]) == len(tasklet_map.range)

        # Returns ([outer map indices], [inner (CR) map indices])
        return result

    @staticmethod
    def find_permutation_reduce(tasklet_map: nodes.Map,
                                reduce_node: nodes.Reduce, graph: SDFGState,
                                tmem: Memlet):
        """ Finds, for every axis that the reduce node does not reduce, the
            tasklet map dimension that produces it.

            :return: List of indices into `tasklet_map.params`; empty when
                     `reduce_node.axes` is None (all axes are reduced).
            :raise AssertionError: If the dimensions cannot be matched.
        """
        in_memlet = graph.in_edges(reduce_node)[0].data
        out_memlet = graph.out_edges(reduce_node)[0].data

        assert len(tasklet_map.range) == in_memlet.subset.dims()

        # Find permutation between tasklet-exit memlet and tasklet map
        tmem_perm = MapReduceFusion.find_memlet_map_permutation(
            tmem, tasklet_map)
        mapred_perm = []

        # Match map ranges with reduce ranges
        unavailable_ranges = set()
        for i, tmap_rng in enumerate(tasklet_map.range):
            found = False
            for j, in_rng in enumerate(in_memlet.subset):
                if tmap_rng == in_rng and j not in unavailable_ranges:
                    mapred_perm.append(i)
                    unavailable_ranges.add(j)
                    found = True
                    break
            if not found:
                break

        # Ensure all map variables matched with reduce variables
        assert len(tmem_perm) == len(tmem.subset)
        assert len(mapred_perm) == len(in_memlet.subset)

        # Prepare result from the two permutations and the reduction axes
        result = []
        for i in range(len(mapred_perm)):
            if reduce_node.axes is None or i in reduce_node.axes:
                continue
            result.append(mapred_perm[tmem_perm[i]])

        return result

    def apply(self, sdfg):
        """ Fuses the matched map with the following reduction.

            For the reduce-node pattern (expression 2), the intermediate
            array and reduce node are removed and the tasklet writes to the
            output array with the reduce node's WCR. For the map-based
            patterns (expressions 0 and 1), the intermediate array is
            replaced by a new transient named 'tmp' that connects the
            tasklet to the reduction tasklet.

            :param sdfg: SDFG to modify in place.
            :raise RuntimeError: If no memlet entering the map exit refers to
                                 the intermediate array, or if the reduction
                                 tasklet does not have exactly one input.
        """

        def gnode(nname):
            return graph.nodes()[self.subgraph[nname]]

        expr_index = self.expr_index
        graph = sdfg.nodes()[self.state_id]
        tmap_exit = gnode(MapReduceFusion._tmap_exit)
        in_array = gnode(MapReduceFusion._in_array)
        if expr_index == 1:  # Reduce with outer map
            rmap_out_entry = gnode(MapReduceFusion._rmap_out_entry)
            rmap_out_exit = gnode(MapReduceFusion._rmap_out_exit)
            rmap_in_entry = gnode(MapReduceFusion._rmap_in_entry)

        if expr_index == 2:
            rmap_cr = gnode(MapReduceFusion._reduce)
        else:
            rmap_cr = gnode(MapReduceFusion._rmap_in_cr)
            # The reduction tasklet exists in both map-based patterns and is
            # needed by the code after the expression-2 early return; the
            # reduce-node pattern has no such node.
            rmap_tasklet = gnode(MapReduceFusion._rmap_in_tasklet)
        out_array = gnode(MapReduceFusion._out_array)

        # Set nodes to remove according to the expression index
        nodes_to_remove = [in_array]
        if expr_index == 0:
            nodes_to_remove.append(gnode(MapReduceFusion._rmap_in_entry))
        elif expr_index == 1:
            nodes_to_remove.append(gnode(MapReduceFusion._rmap_out_entry))
            nodes_to_remove.append(gnode(MapReduceFusion._rmap_in_entry))
            nodes_to_remove.append(gnode(MapReduceFusion._rmap_out_exit))
        else:
            nodes_to_remove.append(gnode(MapReduceFusion._reduce))

        # If no other edges lead to mapexit, remove it. Otherwise, keep
        # it and remove reduction incoming/outgoing edges
        if expr_index != 2 and len(graph.in_edges(tmap_exit)) == 1:
            nodes_to_remove.append(tmap_exit)

        memlet_edge = None
        for edge in graph.in_edges(tmap_exit):
            if edge.data.data == in_array.data:
                memlet_edge = edge
                break
        if memlet_edge is None:
            raise RuntimeError('Reduction memlet cannot be None')

        if expr_index == 0:  # Reduce without outer map
            # Index order does not matter, merge as-is
            pass
        elif expr_index == 1:  # Reduce with outer map
            tmap = tmap_exit.map
            perm_outer, perm_inner = MapReduceFusion.find_permutation(
                tmap, rmap_out_entry.map, rmap_in_entry.map, memlet_edge.data)

            # Split tasklet map into tmap_out -> tmap_in (according to
            # reduction)
            omap = nodes.Map(
                tmap.label + '_nonreduce',
                [p for i, p in enumerate(tmap.params) if i in perm_outer],
                [r for i, r in enumerate(tmap.range) if i in perm_outer],
                tmap.schedule, tmap.unroll, tmap.is_async)
            tmap.params = [
                p for i, p in enumerate(tmap.params) if i in perm_inner
            ]
            tmap.range = [
                r for i, r in enumerate(tmap.range) if i in perm_inner
            ]
            omap_entry = nodes.MapEntry(omap)
            omap_exit = rmap_out_exit
            rmap_out_exit.map = omap

            # Reconnect graph to new map
            tmap_entry = graph.entry_node(tmap_exit)
            tmap_in_edges = list(graph.in_edges(tmap_entry))
            for e in tmap_in_edges:
                nxutil.change_edge_dest(graph, tmap_entry, omap_entry)
            for e in tmap_in_edges:
                graph.add_edge(omap_entry, e.src_conn, tmap_entry, e.dst_conn,
                               copy.copy(e.data))
        elif expr_index == 2:  # Reduce node
            # Find correspondence between map indices and array outputs
            tmap = tmap_exit.map
            perm = MapReduceFusion.find_permutation_reduce(
                tmap, rmap_cr, graph, memlet_edge.data)

            output_subset = [tmap.params[d] for d in perm]
            if len(output_subset) == 0:  # Output is a scalar
                output_subset = [0]

            # Must be read before the reduce node is removed
            array_edge = graph.out_edges(rmap_cr)[0]

            # Delete relevant edges and nodes
            graph.remove_edge(memlet_edge)
            graph.remove_nodes_from(nodes_to_remove)

            # Add new edges and nodes
            #   From tasklet to map exit
            graph.add_edge(
                memlet_edge.src, memlet_edge.src_conn, memlet_edge.dst,
                memlet_edge.dst_conn,
                Memlet(out_array.data, memlet_edge.data.num_accesses,
                       subsets.Indices(output_subset),
                       memlet_edge.data.veclen, rmap_cr.wcr,
                       rmap_cr.identity))

            #   From map exit to output array
            graph.add_edge(
                memlet_edge.dst, 'OUT_' + memlet_edge.dst_conn[3:],
                array_edge.dst, array_edge.dst_conn,
                Memlet(array_edge.data.data, array_edge.data.num_accesses,
                       array_edge.data.subset, array_edge.data.veclen,
                       rmap_cr.wcr, rmap_cr.identity))

            return

        # Remove tmp array node prior to the others, so that a new one
        # can be created in its stead (see below)
        graph.remove_node(nodes_to_remove[0])
        nodes_to_remove = nodes_to_remove[1:]

        # Create tasklet -> tmp -> tasklet connection
        tmp = graph.add_array(
            'tmp',
            memlet_edge.data.subset.bounding_box_size(),
            sdfg.arrays[memlet_edge.data.data].dtype,
            transient=True)

        tasklet_tmp_memlet = copy.deepcopy(memlet_edge.data)
        tasklet_tmp_memlet.data = tmp.data
        tasklet_tmp_memlet.subset = ShapeProperty.to_string(tmp.shape)

        # Modify memlet to point to output array
        memlet_edge.data.data = out_array.data

        # Recover reduction axes from CR reduce subset: an axis is treated as
        # reduced when its index expression contains '__i'
        reduce_cr_subset = graph.in_edges(rmap_tasklet)[0].data.subset
        reduce_axes = []
        for ind, crvar in enumerate(reduce_cr_subset.indices):
            if '__i' in str(crvar):
                reduce_axes.append(ind)

        # Modify memlet access index by filtering out reduction axes
        newindices = [
            ovar for ind, ovar in enumerate(memlet_edge.data.subset.indices)
            if ind not in reduce_axes
        ]
        if len(newindices) == 0:
            newindices = [0]

        memlet_edge.data.subset = subsets.Indices(newindices)

        graph.remove_edge(memlet_edge)

        graph.add_edge(memlet_edge.src, memlet_edge.src_conn, tmp,
                       memlet_edge.dst_conn, tasklet_tmp_memlet)

        red_edges = list(graph.in_edges(rmap_tasklet))
        if len(red_edges) != 1:
            raise RuntimeError('CR edge must be unique')

        tmp_tasklet_memlet = copy.deepcopy(tasklet_tmp_memlet)
        graph.add_edge(tmp, None, rmap_tasklet, red_edges[0].dst_conn,
                       tmp_tasklet_memlet)

        for e in graph.edges_between(rmap_tasklet, rmap_cr):
            e.data.subset = memlet_edge.data.subset

        # Move output edges to point directly to CR node
        if expr_index == 1:
            # Set output memlet between CR node and outer reduction map to
            # contain the same subset as the one pointing to the CR node
            for e in graph.out_edges(rmap_cr):
                e.data.subset = memlet_edge.data.subset

            rmap_out = gnode(MapReduceFusion._rmap_out_exit)
            nxutil.change_edge_src(graph, rmap_out, omap_exit)

        # Remove nodes
        graph.remove_nodes_from(nodes_to_remove)

        # For unrelated outputs, connect original output to rmap_out
        if expr_index == 1 and tmap_exit not in nodes_to_remove:
            other_out_edges = list(graph.out_edges(tmap_exit))
            for e in other_out_edges:
                graph.remove_edge(e)
                graph.add_edge(e.src, e.src_conn, omap_exit, None, e.data)
                graph.add_edge(omap_exit, None, e.dst, e.dst_conn,
                               copy.copy(e.data))

    def modifies_graph(self):
        """ This transformation changes the graph structure. """
        return True
def apply(self, sdfg: dace.SDFG):
    """ Expands the matched multi-dimensional map into a chain of
        one-dimensional maps, one per map parameter.

        The first new map gets the default schedule and inherits the
        original entry's connectors; all other new maps are sequential.
        The original map entry and exit are removed from the state.

        :param sdfg: SDFG to modify in place.
    """
    # Extract the map and its entry and exit nodes.
    graph = sdfg.nodes()[self.state_id]
    map_entry = graph.nodes()[self.subgraph[MapExpansion._map_entry]]
    map_exit = graph.exit_nodes(map_entry)[0]
    current_map = map_entry.map

    # Create new maps
    maps = [
        nodes.Map(current_map.label + '_' + str(param), [param],
                  subsets.Range([param_range]),
                  schedule=dtypes.ScheduleType.Sequential)
        for param, param_range in zip(current_map.params,
                                      current_map.range)
    ]
    maps[0]._schedule = dtypes.ScheduleType.Default

    # Create new map entries. Inner entries get one IN_/OUT_ connector pair
    # per edge leaving the original entry, matching the numbering used when
    # those edges are re-routed below.
    entries = [nodes.MapEntry(new_map) for new_map in maps]
    entries[0].in_connectors = map_entry.in_connectors
    entries[0].out_connectors = map_entry.out_connectors
    num_entry_out_edges = len(graph.out_edges(map_entry))
    for inner_entry in entries[1:]:
        inner_entry.in_connectors = set(
            'IN_' + str(k + 1) for k in range(num_entry_out_edges))
        inner_entry.out_connectors = set(
            'OUT_' + str(k + 1) for k in range(num_entry_out_edges))

    # Create new map exits (inner-most first)
    exits = [nodes.MapExit(new_map) for new_map in maps]
    exits.reverse()
    exits[-1].in_connectors = map_exit.in_connectors
    exits[-1].out_connectors = map_exit.out_connectors
    # Inner exits get one connector pair per edge ENTERING the original
    # exit: those are the edges that get routed through them below, and
    # their count can differ from the number of edges leaving the exit.
    num_exit_in_edges = len(graph.in_edges(map_exit))
    for inner_exit in exits[:-1]:
        inner_exit.in_connectors = set(
            'IN_' + str(k + 1) for k in range(num_exit_in_edges))
        inner_exit.out_connectors = set(
            'OUT_' + str(k + 1) for k in range(num_exit_in_edges))

    # Add new nodes to state
    graph.add_nodes_from(entries)
    graph.add_nodes_from(exits)

    # Redirect edges to new nodes
    dace.graph.nxutil.change_edge_dest(graph, map_entry, entries[0])
    dace.graph.nxutil.change_edge_src(graph, map_exit, exits[-1])

    # Route every edge that left the original entry through all new entries
    for i, e in enumerate(graph.out_edges(map_entry)):
        graph.remove_edge(e)
        graph.add_edge(entries[0], e.src_conn, entries[1],
                       'IN_' + str(i + 1), copy.deepcopy(e.data))
        graph.add_edge(entries[-1], 'OUT_' + str(i + 1), e.dst,
                       e.dst_conn, copy.deepcopy(e.data))
        for j in range(1, len(entries) - 1):
            graph.add_edge(entries[j], 'OUT_' + str(i + 1),
                           entries[j + 1], 'IN_' + str(i + 1),
                           copy.deepcopy(e.data))

    # Route every edge that entered the original exit through all new exits
    for i, e in enumerate(graph.in_edges(map_exit)):
        graph.remove_edge(e)
        graph.add_edge(e.src, e.src_conn, exits[0], 'IN_' + str(i + 1),
                       copy.deepcopy(e.data))
        graph.add_edge(exits[-2], 'OUT_' + str(i + 1), exits[-1],
                       e.dst_conn, copy.deepcopy(e.data))
        for j in range(0, len(exits) - 2):
            graph.add_edge(exits[j], 'OUT_' + str(i + 1), exits[j + 1],
                           'IN_' + str(i + 1), copy.deepcopy(e.data))

    # Remove old nodes
    graph.remove_node(map_entry)
    graph.remove_node(map_exit)
class AccumulateTransient(pattern_matching.Transformation):
    """ Implements the AccumulateTransient transformation, which adds
        transient stream and data nodes between nested maps that lead to a
        stream. The transient data nodes then act as a local accumulator.
    """

    _tasklet = nodes.Tasklet('_')
    _map_exit = nodes.MapExit(nodes.Map("", [], []))
    _outer_map_exit = nodes.MapExit(nodes.Map("", [], []))

    array = Property(
        dtype=str,
        desc="Array to create local storage for (if empty, first available)",
        default=None,
        allow_none=True)

    @staticmethod
    def expressions():
        """ Returns the single pattern graph this transformation matches:
            tasklet -> map exit -> outer map exit.
        """
        return [
            nxutil.node_path_graph(AccumulateTransient._tasklet,
                                   AccumulateTransient._map_exit,
                                   AccumulateTransient._outer_map_exit)
        ]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Returns True if the tasklet has at least one output memlet with
            write-conflict resolution that leads to the matched map exit.
        """
        tasklet = graph.nodes()[candidate[AccumulateTransient._tasklet]]
        map_exit = graph.nodes()[candidate[AccumulateTransient._map_exit]]

        # Check if there is an accumulation output
        for _src, _, dest, _, memlet in graph.out_edges(tasklet):
            if memlet.wcr is not None and dest == map_exit:
                return True

        return False

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns a human-readable description of the match. """
        tasklet = candidate[AccumulateTransient._tasklet]
        map_exit = candidate[AccumulateTransient._map_exit]
        outer_map_exit = candidate[AccumulateTransient._outer_map_exit]

        return ' -> '.join(
            str(node) for node in [tasklet, map_exit, outer_map_exit])

    def apply(self, sdfg):
        """ Applies LocalStorage between the two matched map exits and, for
            sum reductions, marks the created data node to be zeroed.
            For any other reduction type a warning is issued, because the
            transient is left uninitialized.

            :param sdfg: SDFG to modify in place.
        """
        graph = sdfg.node(self.state_id)

        # Avoid import loop
        from dace.transformation.dataflow.local_storage import LocalStorage

        local_storage_subgraph = {
            LocalStorage._node_a:
            self.subgraph[AccumulateTransient._map_exit],
            LocalStorage._node_b:
            self.subgraph[AccumulateTransient._outer_map_exit]
        }
        sdfg_id = sdfg.sdfg_list.index(sdfg)
        in_local_storage = LocalStorage(sdfg_id, self.state_id,
                                        local_storage_subgraph,
                                        self.expr_index)
        in_local_storage.array = self.array
        in_local_storage.apply(sdfg)

        # Initialize transient to zero in case of summation
        # TODO: Initialize transient in other WCR types
        memlet = graph.in_edges(in_local_storage._data_node)[0].data
        if detect_reduction_type(memlet.wcr) == dtypes.ReductionType.Sum:
            in_local_storage._data_node.setzero = True
        else:
            # Adjacent string literals are concatenated verbatim, so the
            # separating space must be part of one of them.
            warnings.warn('AccumulateTransient did not properly initialize '
                          'newly-created transient!')
def apply(self, sdfg: dace.SDFG):
    """ Expands the matched multi-dimensional map in place.

        The original map keeps only its first parameter and range; every
        further parameter becomes its own sequential one-dimensional map
        nested inside it. Edges that left the original entry or entered the
        original exit are re-created as memlet paths through the new scopes.

        :param sdfg: SDFG to modify in place.
    """
    state = sdfg.nodes()[self.state_id]
    outer_entry = state.nodes()[self.subgraph[MapExpansion._map_entry]]
    outer_exit = state.exit_nodes(outer_entry)[0]
    outer_map = outer_entry.map

    # One sequential map per dimension after the first
    inner_maps = []
    for prm, prm_range in zip(outer_map.params[1:], outer_map.range[1:]):
        inner_maps.append(
            nodes.Map(outer_map.label + '_' + str(prm), [prm],
                      subsets.Range([prm_range]),
                      schedule=dtypes.ScheduleType.Sequential))

    # The original map is reduced to its first dimension
    outer_map.params = [outer_map.params[0]]
    outer_map.range = subsets.Range([outer_map.range[0]])

    inner_entries = [nodes.MapEntry(m) for m in inner_maps]
    inner_exits = [nodes.MapExit(m) for m in inner_maps]

    # Re-route each edge leaving the original entry through all new entries
    for out_edge in state.out_edges(outer_entry):
        state.remove_edge(out_edge)
        state.add_memlet_path(outer_entry,
                              *inner_entries,
                              out_edge.dst,
                              src_conn=out_edge.src_conn,
                              memlet=out_edge.data,
                              dst_conn=out_edge.dst_conn)

    # Dynamic map-range inputs: detach each from the original entry, then
    # re-attach it along the scope chain up to every map whose range
    # mentions the connector's name
    for dyn_edge in dace.sdfg.dynamic_map_inputs(state, outer_entry):
        state.remove_edge(dyn_edge)
        dyn_edge.dst._in_connectors.remove(dyn_edge.dst_conn)

        scope_chain = []
        for scope_entry in [outer_entry] + inner_entries:
            scope_chain.append(scope_entry)
            used_here = any(
                dyn_edge.dst_conn in map(str, symbolic.symlist(rng))
                for rng in scope_entry.map.range)
            if used_here:
                state.add_memlet_path(dyn_edge.src,
                                      *scope_chain,
                                      memlet=dyn_edge.data,
                                      src_conn=dyn_edge.src_conn,
                                      dst_conn=dyn_edge.dst_conn)

    # Re-route each edge entering the original exit through the new exits,
    # inner-most first
    exits_inner_first = inner_exits[::-1]
    for in_edge in state.in_edges(outer_exit):
        state.remove_edge(in_edge)
        state.add_memlet_path(in_edge.src,
                              *exits_inner_first,
                              outer_exit,
                              memlet=in_edge.data,
                              src_conn=in_edge.src_conn,
                              dst_conn=in_edge.dst_conn)
def __stripmine(self, sdfg, graph, candidate):
    """ Strip-mines one dimension of the matched map by wrapping it in a new
        one-dimensional map that iterates over tiles.

        The dimension, tile size and naming are read from the transformation
        properties `dim_idx`, `tile_size`, `new_dim_prefix`,
        `divides_evenly` and `strided`.

        :param sdfg: SDFG containing `graph`; used to look up data
                     descriptors of memlets.
        :param graph: State that contains the matched map.
        :param candidate: Mapping from pattern nodes to node indices in
                          `graph`.
        :return: Tuple of (original dimension name, new tile dimension name,
                 new outer map).
    """
    # Retrieve map entry and exit nodes.
    map_entry = graph.nodes()[candidate[StripMining._map_entry]]
    map_exits = graph.exit_nodes(map_entry)

    # Retrieve transformation properties.
    dim_idx = self.dim_idx
    new_dim_prefix = self.new_dim_prefix
    tile_size = self.tile_size
    divides_evenly = self.divides_evenly
    strided = self.strided

    # Retrieve parameter and range of dimension to be strip-mined.
    target_dim = map_entry.map.params[dim_idx]
    td_from, td_to, td_step = map_entry.map.range[dim_idx]

    # Create new map. Replace by cloning???
    # The new (outer) dimension counts tiles: 0 .. ceil(extent / tile) - 1.
    new_dim = new_dim_prefix + '_' + target_dim
    nd_from = 0
    nd_to = symbolic.pystr_to_symbolic(
        'int_ceil(%s + 1 - %s, %s) - 1' %
        (symbolic.symstr(td_to), symbolic.symstr(td_from), tile_size))
    nd_step = 1
    new_dim_range = (nd_from, nd_to, nd_step)
    new_map = nodes.Map(new_dim + '_' + map_entry.map.label, [new_dim],
                        subsets.Range([new_dim_range]))
    new_map_entry = nodes.MapEntry(new_map)

    # Change the range of the selected dimension to iterate over a single
    # tile
    if strided:
        # Strided: start at the tile index and step by the tile size, up to
        # the original end of the range
        td_from_new = symbolic.pystr_to_symbolic(new_dim)
        td_to_new_approx = td_to
        td_step = symbolic.pystr_to_symbolic(tile_size)
    else:
        # Contiguous: iterate from the start of the tile to its end. The
        # exact end is clamped to the original range; the approximate end
        # is not.
        td_from_new = symbolic.pystr_to_symbolic(
            '%s + %s * %s' %
            (symbolic.symstr(td_from), str(new_dim), tile_size))
        td_to_new_exact = symbolic.pystr_to_symbolic(
            'min(%s + 1, %s + %s * %s + %s) - 1' %
            (symbolic.symstr(td_to), symbolic.symstr(td_from), tile_size,
             str(new_dim), tile_size))
        td_to_new_approx = symbolic.pystr_to_symbolic(
            '%s + %s * %s + %s - 1' %
            (symbolic.symstr(td_from), tile_size, str(new_dim), tile_size))
    if divides_evenly or strided:
        td_to_new = td_to_new_approx
    else:
        # Carry both the exact and the approximate expression
        td_to_new = dace.symbolic.SymExpr(td_to_new_exact, td_to_new_approx)
    map_entry.map.range[dim_idx] = (td_from_new, td_to_new, td_step)

    # Make internal map's schedule to "not parallel"
    # NOTE(review): the value assigned is ScheduleType.Default - confirm
    # that this is what "not parallel" is meant to be.
    map_entry.map._schedule = dtypes.ScheduleType.Default

    # Redirect/create edges.
    # new_in_edges maps a data name to one merged edge description
    # (src, src_conn, dst, dst_conn, memlet, connector number) from the new
    # map entry to the original map entry.
    new_in_edges = {}
    for _src, conn, _dest, _, memlet in graph.out_edges(map_entry):
        if not isinstance(sdfg.arrays[memlet.data], dace.data.Scalar):
            new_subset = calc_set_image(
                map_entry.map.params,
                map_entry.map.range,
                memlet.subset,
            )
            if memlet.data in new_in_edges:
                # Same data seen before: merge the subsets and keep the
                # smaller connector number
                src, src_conn, dest, dest_conn, new_memlet, num = \
                    new_in_edges[memlet.data]
                new_memlet.subset = calc_set_union(new_memlet.subset,
                                                   new_subset)
                new_memlet.num_accesses = new_memlet.num_elements()
                new_in_edges.update({
                    memlet.data: (src, src_conn, dest, dest_conn,
                                  new_memlet, min(num, int(conn[4:])))
                })
            else:
                new_memlet = dcpy(memlet)
                new_memlet.subset = new_subset
                new_memlet.num_accesses = new_memlet.num_elements()
                # conn[4:] drops the first four characters of the connector
                # name (presumably an 'OUT_' prefix) to get its number
                new_in_edges.update({
                    memlet.data: (new_map_entry, None, map_entry, None,
                                  new_memlet, int(conn[4:]))
                })
    nxutil.change_edge_dest(graph, map_entry, new_map_entry)

    # Same procedure for the exits: one merged edge per data name from each
    # original map exit to its new exit.
    new_out_edges = {}
    new_exits = []
    for map_exit in map_exits:
        if isinstance(map_exit, nodes.MapExit):
            new_exit = nodes.MapExit(new_map)
            new_exits.append(new_exit)
        for _src, conn, _dest, _, memlet in graph.in_edges(map_exit):
            if not isinstance(sdfg.arrays[memlet.data], dace.data.Scalar):
                new_subset = calc_set_image(
                    map_entry.map.params,
                    map_entry.map.range,
                    memlet.subset,
                )
                if memlet.data in new_out_edges:
                    src, src_conn, dest, dest_conn, new_memlet, num = \
                        new_out_edges[memlet.data]
                    new_memlet.subset = calc_set_union(
                        new_memlet.subset, new_subset)
                    new_memlet.num_accesses = new_memlet.num_elements()
                    # NOTE(review): unlike the entry side, conn[4:] is kept
                    # as a string here, so min() compares strings - confirm
                    # this is intended.
                    new_out_edges.update({
                        memlet.data: (src, src_conn, dest, dest_conn,
                                      new_memlet, min(num, conn[4:]))
                    })
                else:
                    new_memlet = dcpy(memlet)
                    new_memlet.subset = new_subset
                    new_memlet.num_accesses = new_memlet.num_elements()
                    new_out_edges.update({
                        memlet.data: (map_exit, None, new_exit, None,
                                      new_memlet, conn[4:])
                    })
        nxutil.change_edge_src(graph, map_exit, new_exit)

    # Renumber the collected connector numbers consecutively from 1
    in_conn_nums = []
    for _, e in new_in_edges.items():
        _, _, _, _, _, num = e
        in_conn_nums.append(num)
    in_conn = {}
    for i, num in enumerate(in_conn_nums):
        in_conn.update({num: i + 1})

    # Declare the connectors on the new map entry and add the merged edges
    entry_in_connectors = set()
    entry_out_connectors = set()
    for i in range(len(in_conn_nums)):
        entry_in_connectors.add('IN_' + str(i + 1))
        entry_out_connectors.add('OUT_' + str(i + 1))
    new_map_entry.in_connectors = entry_in_connectors
    new_map_entry.out_connectors = entry_out_connectors

    for _, e in new_in_edges.items():
        src, _, dst, _, memlet, num = e
        graph.add_edge(src, 'OUT_' + str(in_conn[num]), dst,
                       'IN_' + str(in_conn[num]), memlet)

    # Same for every new exit, considering only the edges that end in it
    for new_exit in new_exits:
        out_conn_nums = []
        for _, e in new_out_edges.items():
            _, _, dst, _, _, num = e
            if dst is not new_exit:
                continue
            out_conn_nums.append(num)
        out_conn = {}
        for i, num in enumerate(out_conn_nums):
            out_conn.update({num: i + 1})

        exit_in_connectors = set()
        exit_out_connectors = set()
        for i in range(len(out_conn_nums)):
            exit_in_connectors.add('IN_' + str(i + 1))
            exit_out_connectors.add('OUT_' + str(i + 1))
        new_exit.in_connectors = exit_in_connectors
        new_exit.out_connectors = exit_out_connectors

        for _, e in new_out_edges.items():
            src, _, dst, _, memlet, num = e
            graph.add_edge(src, 'OUT_' + str(out_conn[num]), dst,
                           'IN_' + str(out_conn[num]), memlet)

    # Return strip-mined dimension.
    return target_dim, new_dim, new_map
def my_add_mapped_tasklet(
        state,
        inpdict,
        outdict,
        name: str,
        map_ranges: Dict[str, dace.subsets.Subset],
        inputs: Dict[str, dace.memlet.Memlet],
        code: str,
        outputs: Dict[str, dace.memlet.Memlet],
        schedule=dace.dtypes.ScheduleType.Default,
        unroll_map=False,
        code_global="",
        code_init="",
        code_exit="",
        location="-1",
        language=dace.dtypes.Language.Python,
        debuginfo=None,
        external_edges=True,
) -> Tuple[dace.graph.nodes.Node]:
    """ Convenience function that adds a map entry, tasklet, map exit,
        and the respective edges to already-existing external nodes.

        Unlike a variant that would create the external access nodes
        itself, this function connects the map to the nodes handed in
        through `inpdict` and `outdict`.

        Side effect: the `name` attribute of every memlet in `inputs` and
        `outputs` is overwritten with its local variable name.

        :param state: Graph to add the nodes and edges to. It must provide
                      `_map_from_ndrange`, `add_nodes_from`, `add_edge`,
                      `in_edges` and `out_edges` (presumably an SDFG
                      state -- confirm against callers).
        :param inpdict: Mapping from data name to the node that feeds the
                        map entry with that data. Keys are matched against
                        the `data` field of the memlets in `inputs`.
        :param outdict: Mapping from data name to the node that receives
                        that data from the map exit. Keys are matched
                        against the `data` field of the memlets in
                        `outputs`.
        :param name: Tasklet (and wrapping map) name
        :param map_ranges: Mapping between variable names and their
                           subsets
        :param inputs: Mapping between input local variable names and
                       their memlets
        :param code: Code (written in `language`)
        :param outputs: Mapping between output local variable names and
                        their memlets
        :param schedule: Map schedule
        :param unroll_map: True if map should be unrolled in code
                           generation
        :param code_global: (optional) Global code (outside functions)
        :param code_init: (optional) Forwarded to the tasklet as
                          `code_init`
        :param code_exit: (optional) Forwarded to the tasklet as
                          `code_exit`
        :param location: Forwarded to the tasklet as `location`
        :param language: Programming language in which the code is written
        :param debuginfo: Debugging information (mostly for DIODE)
        :param external_edges: Connect the nodes in `inpdict`/`outdict`
                               to the map entry/exit with memlets
                               automatically
        :return: tuple of (tasklet, map_entry, map_exit)
        :raises KeyError: if `external_edges` is set and a key of
                          `inpdict` (or `outdict`) does not match the
                          `data` of any memlet in `inputs` (or `outputs`).
    """
    import dace.graph.nodes as nd
    from dace.sdfg import getdebuginfo
    from dace.graph.labeling import propagate_memlet

    map_name = name + "_map"
    debuginfo = getdebuginfo(debuginfo)
    tasklet = nd.Tasklet(
        name,
        set(inputs),
        set(outputs),
        code,
        language=language,
        code_global=code_global,
        code_init=code_init,
        code_exit=code_exit,
        location=location,
        debuginfo=debuginfo,
    )
    # Entry and exit nodes share one map object.
    map_obj = state._map_from_ndrange(
        map_name, schedule, unroll_map, map_ranges, debuginfo=debuginfo)
    map_entry = nd.MapEntry(map_obj)
    map_exit = nd.MapExit(map_obj)
    state.add_nodes_from([map_entry, tasklet, map_exit])

    # Inner input edges: map entry -> tasklet. Remember each memlet by the
    # data it refers to, so the outer edges can be derived from it below.
    memlet_by_data = {}
    for conn_name, memlet in inputs.items():
        memlet.name = conn_name
        state.add_edge(map_entry, None, tasklet, conn_name, memlet)
        memlet_by_data[memlet.data] = memlet

    if not inputs:
        # Keep the tasklet attached to the map entry even without inputs.
        state.add_edge(map_entry, None, tasklet, None,
                       dace.memlet.EmptyMemlet())

    if external_edges:
        for inp, inpnode in inpdict.items():
            in_conn = "IN_" + inp
            out_conn = "OUT_" + inp
            outer_memlet = propagate_memlet(state, memlet_by_data[inp],
                                            map_entry, True)
            state.add_edge(inpnode, None, map_entry, in_conn, outer_memlet)

            # The inner edges were created without a source connector;
            # attach them to the matching OUT_ connector after the fact.
            for e in state.out_edges(map_entry):
                if e.data.data == inp:
                    e._src_conn = out_conn

            map_entry.add_in_connector(in_conn)
            map_entry.add_out_connector(out_conn)

    # Inner output edges: tasklet -> map exit.
    memlet_by_data = {}
    for conn_name, memlet in outputs.items():
        memlet.name = conn_name
        state.add_edge(tasklet, conn_name, map_exit, None, memlet)
        memlet_by_data[memlet.data] = memlet

    if not outputs:
        # Same spelling as the empty-input case above; the original used an
        # `mm` alias here that the rest of this function does not rely on.
        state.add_edge(tasklet, None, map_exit, None,
                       dace.memlet.EmptyMemlet())

    if external_edges:
        for out, outnode in outdict.items():
            in_conn = "IN_" + out
            out_conn = "OUT_" + out
            outer_memlet = propagate_memlet(state, memlet_by_data[out],
                                            map_exit, True)
            state.add_edge(map_exit, out_conn, outnode, None, outer_memlet)

            # Attach the inner edges to the matching IN_ connector.
            for e in state.in_edges(map_exit):
                if e.data.data == out:
                    e._dst_conn = in_conn

            map_exit.add_in_connector(in_conn)
            map_exit.add_out_connector(out_conn)

    return tasklet, map_entry, map_exit