def apply(self, sdfg: SDFG) -> Union[Any, None]:
    """Remove connectors of the matched nested SDFG that it never uses.

    Input connectors that are missing from the nested SDFG's read set and
    output connectors that are missing from its write set are pruned: the
    memlet paths attached to them are removed from the parent state (with
    ``remove_orphans=True``). If a data container with the connector's name
    exists in the nested SDFG and is neither read nor written there, the
    container is removed as well.

    :param sdfg: SDFG that contains the state ``self.state_id`` holding the
                 nested SDFG node.
    :return: Nothing (``None``).
    """
    parent_state = sdfg.node(self.state_id)
    nested = self.nsdfg(sdfg)

    reads, writes = nested.sdfg.read_and_write_sets()
    unused_inputs = nested.in_connectors.keys() - reads
    unused_outputs = nested.out_connectors.keys() - writes

    # Every container the nested SDFG touches; anything outside this set may
    # be dropped once its connector is gone.
    referenced = reads | writes

    def purge_if_unreferenced(name):
        # The container is only removed when the nested SDFG does not use it.
        if name in nested.sdfg.arrays and name not in referenced:
            nested.sdfg.remove_data(name)

    # An input whose name is also a WCR output stays connected when the node
    # feeding it has incoming edges of its own.
    for out_edge in parent_state.out_edges(nested):
        if out_edge.data.wcr is None or out_edge.src_conn not in unused_inputs:
            continue
        first_in_edge = next(
            iter(parent_state.in_edges_by_connector(nested,
                                                    out_edge.src_conn)))
        if parent_state.in_degree(first_in_edge.src) > 0:
            unused_inputs.remove(out_edge.src_conn)

    for conn in unused_inputs:
        for in_edge in parent_state.in_edges_by_connector(nested, conn):
            parent_state.remove_memlet_path(in_edge, remove_orphans=True)
        purge_if_unreferenced(conn)

    for conn in unused_outputs:
        for out_edge in parent_state.out_edges_by_connector(nested, conn):
            parent_state.remove_memlet_path(out_edge, remove_orphans=True)
        purge_if_unreferenced(conn)
def get_out_memlet_costs(sdfg: dace.SDFG, state_id: int, node: nodes.Node,
                         dfg: StateGraphView):
    """Accumulate the byte cost of the outgoing memlets of ``node``.

    Only edges that leave a code node and whose memlet path ends in an
    access node are considered. Edges ending in an access node of an inner
    scope are skipped, and only memlets whose subset reports zero data
    dimensions contribute. A memlet with write-conflict resolution counts
    three times its byte size (read old, read new, write); any other memlet
    counts once.

    :param sdfg: SDFG that owns the state.
    :param state_id: Index of the state used to obtain the scope dictionary.
    :param node: Node whose outgoing edges are inspected.
    :param dfg: Graph view used to enumerate edges and memlet paths.
    :return: Total cost in bytes, or ``0`` as soon as a considered edge has
             no source connector.
    """
    scopes = sdfg.node(state_id).scope_dict()
    node_is_code = isinstance(node, nodes.CodeNode)

    total = 0
    for out_edge in dfg.out_edges(node):
        _, src_conn, _, _, memlet = out_edge
        final_dst = dfg.memlet_path(out_edge)[-1].dst

        if not (node_is_code and isinstance(final_dst, nodes.AccessNode)):
            continue

        # A memlet pointing into an array of an inner scope is accounted for
        # by that inner scope.
        if (scopes[node] != scopes[final_dst]
                and scope_contains_scope(scopes, node, final_dst)):
            continue

        if not src_conn:
            # This would normally raise a syntax error
            return 0

        if memlet.subset.data_dims() != 0:
            continue

        if memlet.wcr is None:
            total += PAPIUtils.get_memlet_byte_size(sdfg, memlet)
        else:
            # Every reduction is assumed to cost three accesses of the same
            # size (read old, read new, write).
            total += 3 * PAPIUtils.get_memlet_byte_size(sdfg, memlet)
    return total
def apply(self, sdfg: dace.SDFG) -> None:
    """Fuse the matched ``left`` and ``right`` nodes into a single node.

    A new ``HorizontalExecutionLibraryNode`` is created whose OIR body and
    declarations are the concatenation of those of ``left`` and ``right``.
    All edges of both nodes are rewired to the new node, both old nodes are
    removed, and the nodes that sat on paths between them are cleaned up or
    have their access type adjusted.

    :param sdfg: SDFG that contains the state ``self.state_id`` with the
                 matched nodes.
    """
    state = sdfg.node(self.state_id)
    left = self.left(sdfg)
    right = self.right(sdfg)

    # Merge source locations
    dinfo = self._merge_source_locations(left, right)

    # Merge the OIR nodes; each side is converted only once.
    left_oir = left.as_oir()
    right_oir = right.as_oir()
    res = HorizontalExecutionLibraryNode(
        oir_node=oir.HorizontalExecution(
            body=left_oir.body + right_oir.body,
            declarations=left_oir.declarations + right_oir.declarations,
        ),
        iteration_space=left.iteration_space,
        debuginfo=dinfo,
    )
    state.add_node(res)

    # Interior nodes of every simple path from left to right.
    intermediate_accesses = set(
        n for path in nx.all_simple_paths(state.nx, left, right)
        for n in path[1:-1])

    # Drop direct left->right edges; the fused node does not need them.
    for edge in state.edges_between(left, right):
        state.remove_edge_and_connectors(edge)

    # Edges between an intermediate node and left/right become internal to
    # the fused node and are removed; all other edges move to the fused node.
    for acc in intermediate_accesses:
        for edge in state.in_edges(acc):
            if edge.src is not left:
                rewire_edge(state, edge, dst=res)
            else:
                state.remove_edge_and_connectors(edge)
        for edge in state.out_edges(acc):
            if edge.dst is not right:
                rewire_edge(state, edge, src=res)
            else:
                state.remove_edge_and_connectors(edge)

    # Move the remaining edges of both original nodes onto the fused node.
    for edge in state.in_edges(left):
        rewire_edge(state, edge, dst=res)
    for edge in state.out_edges(right):
        rewire_edge(state, edge, src=res)
    for edge in state.out_edges(left):
        rewire_edge(state, edge, src=res)
    for edge in state.in_edges(right):
        rewire_edge(state, edge, dst=res)
    state.remove_node(left)
    state.remove_node(right)

    for acc in intermediate_accesses:
        if not state.in_edges(acc):
            if not state.out_edges(acc):
                # Fully disconnected: nothing refers to it any more.
                state.remove_node(acc)
            else:
                assert (len(state.edges_between(acc, res)) == 1
                        and len(state.out_edges(acc)) == 1
                        ), "Previously written array now read-only."
                state.remove_node(acc)
                res.remove_in_connector("IN_" + acc.label)
        elif not state.out_edges(acc):
            # Fix: the original tested the bound method ``state.out_edges``
            # (always truthy), so this branch could never run.
            acc.access = dace.AccessType.WriteOnly
def apply(self, sdfg: SDFG) -> Union[Any, None]:
    """Prune unused connectors of the matched nested SDFG.

    Input connectors outside the nested SDFG's read set and output
    connectors outside its write set are candidates for pruning. A candidate
    is left untouched when the far end of one of its memlet paths is itself
    connected further (an input whose path source has incoming edges, or an
    output whose path destination has outgoing edges). For the remaining
    candidates the memlet paths are removed from the parent state and,
    if a same-named container exists in the nested SDFG and is unused there,
    the container is removed. When ``self.remove_unused_containers`` is set,
    removal of this SDFG's containers is additionally attempted on every
    parent SDFG of the nested SDFG.

    :param sdfg: SDFG that contains the state ``self.state_id`` holding the
                 nested SDFG node.
    :return: Nothing (``None``).
    """
    parent_state = sdfg.node(self.state_id)
    nested = self.nsdfg(sdfg)

    reads, writes = nested.sdfg.read_and_write_sets()
    unused_inputs = nested.in_connectors.keys() - reads
    unused_outputs = nested.out_connectors.keys() - writes

    # Every container the nested SDFG touches; only names outside this set
    # may be purged after their connector has been pruned.
    referenced = reads | writes

    # An input whose name is also a WCR output stays connected when the node
    # feeding it has incoming edges of its own.
    for out_edge in parent_state.out_edges(nested):
        if out_edge.data.wcr is None or out_edge.src_conn not in unused_inputs:
            continue
        first_in_edge = next(
            iter(parent_state.in_edges_by_connector(nested,
                                                    out_edge.src_conn)))
        if parent_state.in_degree(first_in_edge.src) > 0:
            unused_inputs.remove(out_edge.src_conn)

    # Connectors that were candidates but must keep their data container.
    protected = set()

    for conn in unused_inputs:
        source_has_inputs = any(
            parent_state.in_degree(parent_state.memlet_path(edge)[0].src) > 0
            for edge in parent_state.in_edges(nested)
            if edge.dst_conn == conn)
        if source_has_inputs:
            protected.add(conn)
            continue
        for edge in parent_state.in_edges_by_connector(nested, conn):
            parent_state.remove_memlet_path(edge, remove_orphans=True)

    for conn in unused_outputs:
        sink_has_outputs = any(
            parent_state.out_degree(parent_state.memlet_path(edge)[-1].dst) > 0
            for edge in parent_state.out_edges(nested)
            if edge.src_conn == conn)
        if sink_has_outputs:
            protected.add(conn)
            continue
        for edge in parent_state.out_edges_by_connector(nested, conn):
            parent_state.remove_memlet_path(edge, remove_orphans=True)

    # Purge containers that are now unused inside the nested SDFG: inputs
    # first, then outputs.
    for candidates in (unused_inputs, unused_outputs):
        for conn in candidates:
            if (conn in nested.sdfg.arrays and conn not in referenced
                    and conn not in protected):
                nested.sdfg.remove_data(conn)

    if not self.remove_unused_containers:
        return

    # Remove unused containers from parent SDFGs: walk upwards from the
    # nested SDFG and stop climbing at the first ValueError.
    for name in list(sdfg.arrays.keys()):
        current = nested.sdfg
        while current.parent_sdfg:
            current = current.parent_sdfg
            try:
                current.remove_data(name)
            except ValueError:
                break
def apply(self, sdfg: dace.SDFG):
    """Rewrite indices along the loop axis and shrink the array if possible.

    The loop variable is the first assignment key on the first incoming
    edge of the guard state. The loop axis and buffer size are obtained from
    ``_get_loop_axis`` and ``_get_buffer_size``, and ``_replace_indices`` is
    run over all states. If the extent of ``self.array`` along the loop axis
    equals its total size, that extent and the total size are both set to
    the buffer size.

    :param sdfg: SDFG containing the matched before/loop/guard states.
    """

    def matched_state(pattern_node):
        return sdfg.node(self.subgraph[pattern_node])

    # Looked up like the other pattern states, but not used afterwards.
    preceding_state = matched_state(self._before_state)
    body_state = matched_state(self._loop_state)
    guard = matched_state(self._guard_state)

    incoming = sdfg.in_edges(guard)[0]
    loop_var = next(iter(incoming.data.assignments))

    loop_axis = self._get_loop_axis(body_state, loop_var)
    buffer_size = self._get_buffer_size(body_state, loop_var, loop_axis)
    self._replace_indices(sdfg.states(), loop_var, loop_axis, buffer_size)

    desc = sdfg.arrays[self.array]
    # TODO: generalize
    if desc.shape[loop_axis] != desc.total_size:
        return
    desc.shape = tuple(buffer_size if axis == loop_axis else extent
                       for axis, extent in enumerate(desc.shape))
    desc.total_size = buffer_size
def apply(self, sdfg: dace.SDFG):
    """Replicate the first map for the second map and remove the original.

    After updating the map connectors and replicating the first map (both
    delegated to private helpers), every node between the first map's entry
    and exit, plus the exit itself, is removed from the state.

    :param sdfg: SDFG that contains the state ``self.state_id`` with the
                 matched pattern.
    """
    state = sdfg.node(self.state_id)

    def matched(pattern_node):
        return state.node(self.subgraph[pattern_node])

    entry_a = matched(self._first_map_entry)
    # Looked up like the other pattern nodes, but not used afterwards.
    tasklet_a = matched(self._first_tasklet)
    exit_a = matched(self._first_map_exit)
    access = matched(self._array_access)
    entry_b = matched(self._second_map_entry)

    self._update_map_connectors(state, access, entry_a, entry_b)
    self._replicate_first_map(sdfg, access, entry_a, exit_a, entry_b)

    obsolete = state.all_nodes_between(entry_a, exit_a) | {exit_a}
    state.remove_nodes_from(obsolete)
def apply(self, sdfg: SDFG):
    """Convert an augmented assignment in the matched tasklet into a WCR.

    The tasklet code is expected to have the form ``out = in <op> expr;``
    where ``in`` is one of its input connectors and the input and output
    memlets cover the same subset. The code is rewritten to ``out = expr;``,
    the output memlet receives the write-conflict resolution
    ``lambda a,b: a <op> b``, and the now redundant input edge is removed.
    For ``expr_index == 0`` the tasklet may first be moved to a new state
    placed after the current one; otherwise the pattern is matched between a
    map entry and a map exit. If the written container is not transient and
    the SDFG is nested, the WCR is propagated to the enclosing SDFG(s).

    :param sdfg: SDFG that contains the state ``self.state_id`` with the
                 matched nodes.
    :raises NotImplementedError: If the tasklet is not written in C++, or
                                 if no input edge matches the supported
                                 ``out = in <op> expr;`` form with an equal
                                 subset.
    """
    in_node: nodes.AccessNode = self.input(sdfg)
    tasklet: nodes.Tasklet = self.tasklet(sdfg)
    out_node: nodes.AccessNode = self.output(sdfg)
    state: SDFGState = sdfg.node(self.state_id)

    # If state fission is necessary to keep semantics, do it first
    if (self.expr_index == 0 and state.in_degree(in_node) > 0
            and state.out_degree(out_node) == 0):
        newstate = sdfg.add_state_after(state)
        newstate.add_node(tasklet)
        new_input, new_output = None, None

        # Keep old edges for after we remove tasklet from the original state
        in_edges = list(state.in_edges(tasklet))
        out_edges = list(state.out_edges(tasklet))

        for e in in_edges:
            r = newstate.add_read(e.src.data)
            newstate.add_edge(r, e.src_conn, e.dst, e.dst_conn, e.data)
            if e.src is in_node:
                new_input = r
        for e in out_edges:
            w = newstate.add_write(e.dst.data)
            newstate.add_edge(e.src, e.src_conn, w, e.dst_conn, e.data)
            if e.dst is out_node:
                new_output = w

        # Remove tasklet and resulting isolated nodes
        state.remove_node(tasklet)
        for e in in_edges:
            if state.degree(e.src) == 0:
                state.remove_node(e.src)
        for e in out_edges:
            if state.degree(e.dst) == 0:
                state.remove_node(e.dst)

        # Reset state and nodes for rest of transformation
        in_node = new_input
        out_node = new_output
        state = newstate
    # End of state fission

    if self.expr_index == 0:
        inedges = state.edges_between(in_node, tasklet)
        outedge = state.edges_between(tasklet, out_node)[0]
    else:
        me = self.map_entry(sdfg)
        mx = self.map_exit(sdfg)
        inedges = state.edges_between(me, tasklet)
        outedge = state.edges_between(tasklet, mx)[0]

    # Get relevant output connector
    outconn = outedge.src_conn

    # Character class matching any supported operator.
    ops = '[%s]' % ''.join(
        re.escape(o) for o in AugAssignToWCR._EXPRESSIONS)

    # Only C++ tasklets are supported.
    if tasklet.language is not dtypes.Language.CPP:
        raise NotImplementedError

    # Change tasklet code. Only the form ``out = in <op> expr;`` is
    # recognized; the mirrored form ``out = expr <op> in;`` is not handled.
    cstr = tasklet.code.as_string.strip()
    op = None
    inedge = None
    for edge in inedges:
        match = re.match(
            r'^\s*%s\s*=\s*%s\s*(%s)(.*);$' %
            (re.escape(outconn), re.escape(edge.dst_conn), ops), cstr)
        if match is None:
            continue
        if edge.data.subset != outedge.data.subset:
            continue

        op = match.group(1)
        expr = match.group(2)

        # Map asymmetric WCRs to symmetric ones if possible
        if op in AugAssignToWCR._EXPR_MAP:
            op, newexpr = AugAssignToWCR._EXPR_MAP[op]
            expr = newexpr.format(expr=expr)

        tasklet.code.code = '%s = %s;' % (outconn, expr)
        inedge = edge
        break

    if inedge is None:
        # Fail explicitly before touching any memlet, instead of running
        # into an unbound ``op``/``inedge`` further down.
        raise NotImplementedError(
            'No input edge matches the augmented-assignment pattern')

    wcr = f'lambda a,b: a {op} b'

    # Change output edge
    outedge.data.wcr = wcr

    if self.expr_index == 0:
        # Remove input node and connector
        state.remove_edge_and_connectors(inedge)
        if state.degree(in_node) == 0:
            state.remove_node(in_node)
    else:
        # Remove input edge and dst connector, but not necessarily src
        state.remove_memlet_path(inedge)

    # If outedge leads to non-transient, and this is a nested SDFG,
    # propagate outwards
    sd = sdfg
    while (not sd.arrays[outedge.data.data].transient
           and sd.parent_nsdfg_node is not None):
        nsdfg = sd.parent_nsdfg_node
        nstate = sd.parent
        sd = sd.parent_sdfg
        outedge = next(
            iter(nstate.out_edges_by_connector(nsdfg, outedge.data.data)))
        # ``outedge`` is deliberately reused as the loop variable: after the
        # loop it refers to the last edge of the path, which the enclosing
        # ``while`` condition inspects next.
        for outedge in nstate.memlet_path(outedge):
            outedge.data.wcr = wcr
def apply(self, sdfg: dace.SDFG):
    """Split the matched map into a nest of one-dimensional maps.

    The matched map keeps only its first parameter and range. For every
    further parameter a new map with ``Sequential`` schedule is created and
    nested inside, in parameter order. Edges leaving the original entry and
    entering the original exit are re-routed through the new entries and
    exits, dynamic map range inputs are forwarded to every map whose range
    mentions them, and the edges of the innermost new scope are consolidated.

    :param sdfg: SDFG that contains the state ``self.state_id`` with the
                 matched map entry.
    :return: List with the original map entry followed by the new entries,
             from outermost to innermost.
    :raises ValueError: If the scope of the innermost new map cannot be
                        found in the state.
    """
    state = sdfg.node(self.state_id)
    outer_entry = self.map_entry(sdfg)
    outer_exit = state.exit_node(outer_entry)
    outer_map = outer_entry.map

    # One new sequential map per remaining dimension.
    inner_maps = []
    for param, param_range in zip(outer_map.params[1:], outer_map.range[1:]):
        inner_maps.append(
            nodes.Map(outer_map.label + '_' + str(param), [param],
                      subsets.Range([param_range]),
                      schedule=dtypes.ScheduleType.Sequential))
    outer_map.params = [outer_map.params[0]]
    outer_map.range = subsets.Range([outer_map.range[0]])

    inner_entries = [nodes.MapEntry(m) for m in inner_maps]
    inner_exits = [nodes.MapExit(m) for m in inner_maps]

    # Create edges, abiding by the following rules:
    # 1. If there are no edges coming from the outside, use empty memlets
    # 2. Edges with IN_* connectors replicate along the maps
    # 3. Edges for dynamic map ranges replicate until reaching range(s)
    for old_edge in state.out_edges(outer_entry):
        state.remove_edge(old_edge)
        state.add_memlet_path(outer_entry,
                              *inner_entries,
                              old_edge.dst,
                              src_conn=old_edge.src_conn,
                              memlet=old_edge.data,
                              dst_conn=old_edge.dst_conn)

    # Modify dynamic map ranges
    for dyn_edge in dace.sdfg.dynamic_map_inputs(state, outer_entry):
        # Remove old edge and connector
        state.remove_edge(dyn_edge)
        dyn_edge.dst.remove_in_connector(dyn_edge.dst_conn)

        # Forward the input to each map whose range refers to it; the path
        # grows by one map entry per nesting level.
        prefix = []
        for entry in [outer_entry] + inner_entries:
            prefix.append(entry)
            used_here = any(
                dyn_edge.dst_conn in map(str, symbolic.symlist(rng))
                for rng in entry.map.range)
            if used_here:
                state.add_memlet_path(dyn_edge.src,
                                      *prefix,
                                      memlet=dyn_edge.data,
                                      src_conn=dyn_edge.src_conn,
                                      dst_conn=dyn_edge.dst_conn)

    # Route edges into the original exit through the new exits, innermost
    # exit first.
    for old_edge in state.in_edges(outer_exit):
        state.remove_edge(old_edge)
        state.add_memlet_path(old_edge.src,
                              *reversed(inner_exits),
                              outer_exit,
                              memlet=old_edge.data,
                              src_conn=old_edge.src_conn,
                              dst_conn=old_edge.dst_conn)

    from dace.sdfg.scope import ScopeTree

    # Walk from the scope leaves towards the root until the scope of the
    # innermost new map is found.
    innermost_scope = None
    pending: List[ScopeTree] = state.scope_leaves()
    while pending:
        candidate = pending.pop()
        if candidate.entry == inner_entries[-1]:
            innermost_scope = candidate
            break
        if candidate.parent is not None:
            pending.append(candidate.parent)
    else:
        raise ValueError('Cannot find scope in state')

    consolidate_edges(sdfg, innermost_scope)

    return [outer_entry] + inner_entries
def apply(self, sdfg: dace.SDFG):
    """Record the loop index name on the matched loop guard state.

    The name is the first assignment key found on the first incoming edge of
    the guard; it is stored in the guard's ``_LOOPINDEX`` attribute.

    :param sdfg: SDFG containing the matched loop guard.
    """
    guard_state = sdfg.node(self.subgraph[ld.DetectLoop._loop_guard])
    first_incoming = sdfg.in_edges(guard_state)[0]
    assignments = first_incoming.data.assignments
    guard_state._LOOPINDEX = next(iter(assignments.keys()))
def generate_scope(self, sdfg: dace.SDFG, scope: ScopeSubgraphView,
                   state_id: int, function_stream: CodeIOStream,
                   callsite_stream: CodeIOStream):
    """Emit the loop nest for a map scope, with an SVE do-while innermost.

    All map parameters except the last are emitted as plain C++ ``for``
    loops. The last parameter becomes a do-while loop driven by an SVE
    predicate named ``__pg_<param>``. The counter type (stored in
    ``self.counter_type``) is chosen from the byte size of the first dtype
    obtained from the set of dtypes of all arrays in the SDFG.

    :param sdfg: SDFG being generated.
    :param scope: Scope subgraph whose first source node is the map entry.
    :param state_id: Index of the state containing the scope.
    :param function_stream: Stream forwarded to the subgraph dispatcher.
    :param callsite_stream: Stream that receives the generated loop code.
    """
    map_entry = scope.source_nodes()[0]
    the_map = map_entry.map

    # NOTE(review): with several distinct dtypes the element taken from the
    # set is arbitrary - presumably all arrays share one dtype; confirm.
    dtype_set = {sdfg.arrays[name].dtype for name in sdfg.arrays}
    elem_bytes = list(dtype_set)[0].bytes
    elem_bits = elem_bytes * 8

    # 64-bit counters are spelled int64_t in the generated code.
    int64_counter = copy.copy(dace.int64)
    int64_counter.ctype = 'int64_t'

    counter_by_size = {
        1: dace.int8,
        2: dace.int16,
        4: dace.int32,
        8: int64_counter,
    }
    self.counter_type = counter_by_size[elem_bytes]

    callsite_stream.write('{')

    # Define all input connectors of the map entry
    state = sdfg.node(state_id)
    for dyn_edge in dace.sdfg.dynamic_map_inputs(state, map_entry):
        if dyn_edge.data.data == dyn_edge.dst_conn:
            continue
        definition = self.cpu_codegen.memlet_definition(
            sdfg, dyn_edge.data, False, dyn_edge.dst_conn,
            dyn_edge.dst.in_connectors[dyn_edge.dst_conn])
        callsite_stream.write(definition, sdfg, state_id, map_entry)

    # We only create an SVE do-while in the innermost loop
    for param, rng in zip(the_map.params, the_map.range):
        begin, end, stride = (sym2cpp(r) for r in rng)

        self.dispatcher.defined_vars.enter_scope(sdfg)

        is_innermost = param == the_map.params[-1]
        if is_innermost:
            # The name of our loop predicate is always __pg_{param}
            self.dispatcher.defined_vars.add('__pg_' + param,
                                             DefinedType.Scalar, 'svbool_t')

            # Declare our counting variable (e.g. i) and precompute the loop
            # predicate for our range
            callsite_stream.write(
                f'''{self.counter_type} {param} = {begin}; svbool_t __pg_{param} = svwhilele_b{elem_bits}({param}, ({self.counter_type}) {end}); do {{''',
                sdfg, state_id, map_entry)
        else:
            # Default C++ for-loop
            callsite_stream.write(
                f'for(auto {param} = {begin}; {param} <= {end}; {param} += {stride}) {{'
            )

    # Dispatch the subgraph generation
    self.dispatcher.dispatch_subgraph(sdfg,
                                      scope,
                                      state_id,
                                      function_stream,
                                      callsite_stream,
                                      skip_entry_node=True,
                                      skip_exit_node=True)

    # Close the loops from above (in reverse)
    for param, rng in zip(reversed(the_map.params), reversed(the_map.range)):
        is_innermost = param == the_map.params[-1]
        if is_innermost:
            # The SVE loop needs a special while-footer.
            _, end, stride = (sym2cpp(r) for r in rng)

            # Increase the counting variable (according to the number of
            # processed elements), then recompute the loop predicate and
            # test for it
            callsite_stream.write(
                f'''{param} += svcntp_b{elem_bits}(__pg_{param}, __pg_{param}) * {stride}; __pg_{param} = svwhilele_b{elem_bits}({param}, ({self.counter_type}) {end}); }} while(svptest_any(svptrue_b{elem_bits}(), __pg_{param}));''',
                sdfg, state_id, map_entry)
        else:
            # Close the default C++ for-loop
            callsite_stream.write('}')

        self.dispatcher.defined_vars.exit_scope(sdfg)

    callsite_stream.write('}')
def apply(self, sdfg: dace.SDFG):
    """Fuse the second matched stencil into the first one.

    The first stencil takes over the output fields, boundary conditions and
    output edges of the second stencil, gains the second stencil's inputs
    (merging the access lists of inputs both stencils read) and has the
    second stencil's statements appended to its own code. The intermediate
    access node, its data descriptor and the second stencil are removed.

    :param sdfg: SDFG that contains the state ``self.state_id`` with the
                 matched stencils and intermediate array.
    :raises NotImplementedError: If the first stencil's code is C++.
    :raises ValueError: If the first stencil's code language is neither
                        Python nor C++.
    """
    state: dace.SDFGState = sdfg.node(self.state_id)
    stencil_a: Stencil = state.node(self.subgraph[StencilFusion._stencil_a])
    stencil_b: Stencil = state.node(self.subgraph[StencilFusion._stencil_b])
    tmp_access: nodes.AccessNode = state.node(
        self.subgraph[StencilFusion._tmp_array])

    # Connector through which stencil_a writes the intermediate array, and
    # connector through which stencil_b reads it.
    out_conn_a = state.in_edges(tmp_access)[0].src_conn
    in_conn_b = state.out_edges(tmp_access)[0].dst_conn

    # The fused node produces what the second stencil produced.
    stencil_a.output_fields = stencil_b.output_fields
    stencil_a.boundary_conditions = stencil_b.boundary_conditions
    for out_edge in list(state.out_edges(stencil_a)):
        if out_edge.src_conn != out_conn_a:
            continue
        state.remove_edge(out_edge)
        del stencil_a._out_connectors[out_conn_a]
    for out_edge in state.out_edges(stencil_b):
        stencil_a.add_out_connector(out_edge.src_conn)
        state.add_edge(stencil_a, out_edge.src_conn, out_edge.dst,
                       out_edge.dst_conn, out_edge.data)

    # Bring the second stencil's remaining inputs over to the first.
    for in_edge in state.in_edges(stencil_b):
        conn = in_edge.dst_conn
        if conn == in_conn_b:
            # The intermediate input disappears with the fusion.
            continue
        if conn in stencil_a.accesses:
            # Input read by both stencils: only add the accesses that the
            # first stencil does not have yet.
            known = stencil_a.accesses[conn][1]
            for access in stencil_b.accesses[conn][1]:
                if access not in known:
                    known.append(access)
        else:
            stencil_a.accesses[conn] = stencil_b.accesses[conn]
            stencil_a.add_in_connector(conn)
            state.add_edge(in_edge.src, in_edge.src_conn, stencil_a, conn,
                           in_edge.data)

    # Add second stencil's statements to first stencil, replacing the input
    # to the second stencil with the name of the output access
    language = stencil_a.code.language
    if language == dace.Language.Python:
        # Replace first stencil's output with connector name
        statements = stencil_a.code.code
        for idx, stmt in enumerate(statements):
            stencil_a.code.code[idx] = ReplaceSubscript({
                out_conn_a: in_conn_b
            }).visit(stmt)

        # Append second stencil's contents, using connector name instead of
        # accessing the intermediate transient
        # TODO: Use offsetted stencil
        for stmt in stencil_b.code.code:
            stencil_a.code.code.append(
                ReplaceSubscript({
                    in_conn_b: in_conn_b
                }).visit(stmt))
    elif language == dace.Language.CPP:
        raise NotImplementedError
    else:
        raise ValueError('Unrecognized language: %s' %
                         stencil_a.code.language)

    # Remove array from graph
    state.remove_node(tmp_access)
    del sdfg.arrays[tmp_access.data]

    # Remove 2nd stencil
    state.remove_node(stencil_b)
def apply(self, sdfg: SDFG) -> nodes.MapEntry:
    """Insert a warp-sized thread-block map inside the matched map.

    A new map ``warp_tile`` over ``__tid`` in ``0:self.warp_size`` is placed
    directly inside the matched map. Every map immediately inside the new
    scope whose last dimension is not known to be smaller than the warp
    size is strided by the warp size and offset by ``__tid``. Optionally
    (``self.replicate_maps``) such a map is replicated once per additional
    consumer of its single transient output. Outputs with write-conflict
    resolution are redirected to a new transient scalar that is combined
    with a ``dace::warpReduce`` tasklet. Finally the new scope is nested
    into its own SDFG.

    :param sdfg: SDFG that contains the state ``self.state_id`` with the
                 matched map entry.
    :return: Entry node of the newly created ``warp_tile`` map.
    :raises NotImplementedError: For custom reduction types, or when a WCR
                                 output covers more than one element.
    """
    me: nodes.MapEntry = self.mapentry(sdfg)
    graph = sdfg.node(self.state_id)

    # Add new map within map
    mx = graph.exit_node(me)
    new_me, new_mx = graph.add_map('warp_tile',
                                   dict(__tid=f'0:{self.warp_size}'),
                                   dtypes.ScheduleType.GPU_ThreadBlock)
    tid_sym = symbolic.pystr_to_symbolic('__tid')
    for e in graph.out_edges(me):
        xfh.reconnect_edge_through_map(graph, e, new_me, True)
    for e in graph.in_edges(mx):
        xfh.reconnect_edge_through_map(graph, e, new_mx, False)

    # Stride and offset all internal maps
    maps_to_stride = xfh.get_internal_scopes(graph, new_me, immediate=True)
    for nstate, nmap in maps_to_stride:
        nsdfg = nstate.parent
        nsdfg_node = nsdfg.parent_nsdfg_node

        # Map cannot be partitioned across a warp.
        # NOTE(review): the explicit ``== True`` presumably keeps symbolic
        # (undecidable) comparisons from being treated as true - confirm
        # before simplifying.
        if (nmap.range.size()[-1] < self.warp_size) == True:
            continue

        # Make the thread index visible inside nested SDFGs.
        if nsdfg is not sdfg and nsdfg_node is not None:
            nsdfg_node.symbol_mapping['__tid'] = tid_sym
            if '__tid' not in nsdfg.symbols:
                nsdfg.add_symbol('__tid', dtypes.int32)

        begin, end, step = nmap.range[-1]
        nmap.range[-1] = (begin, end - tid_sym, step * self.warp_size)
        subgraph = nstate.scope_subgraph(nmap)
        subgraph.replace(nmap.params[-1], f'{nmap.params[-1]} + __tid')
        inner_map_exit = nstate.exit_node(nmap)

        # If requested, replicate maps with multiple dependent maps
        if self.replicate_maps:
            destinations = [
                nstate.memlet_path(edge)[-1].dst
                for edge in nstate.out_edges(inner_map_exit)
            ]

            for dst in destinations:
                # Transformation will not replicate map with more than one
                # output
                if len(destinations) != 1:
                    break
                if not isinstance(dst, nodes.AccessNode):
                    continue  # Not leading to access node
                if not xfh.contained_in(nstate, dst, new_me):
                    continue  # Memlet path goes out of map
                if not nsdfg.arrays[dst.data].transient:
                    continue  # Cannot modify non-transients
                for edge in nstate.out_edges(dst)[1:]:
                    rep_subgraph = xfh.replicate_scope(
                        nsdfg, nstate, subgraph)
                    rep_edge = nstate.out_edges(
                        rep_subgraph.sink_nodes()[0])[0]
                    # Add copy of data. Fix: the descriptor is taken from
                    # ``nsdfg`` (the SDFG that owns ``dst``) rather than
                    # from the top-level ``sdfg``, which does not hold it
                    # when the map lives in a nested SDFG.
                    newdesc = copy.deepcopy(nsdfg.arrays[dst.data])
                    newname = nsdfg.add_datadesc(dst.data,
                                                 newdesc,
                                                 find_new_name=True)
                    newaccess = nstate.add_access(newname)
                    # Redirect edges
                    xfh.redirect_edge(nstate,
                                      rep_edge,
                                      new_dst=newaccess,
                                      new_data=newname)
                    xfh.redirect_edge(nstate,
                                      edge,
                                      new_src=newaccess,
                                      new_data=newname)

        # If has WCR, add warp-collaborative reduction on outputs
        for out_edge in nstate.out_edges(inner_map_exit):
            if out_edge.data.wcr is None:
                continue

            ctype = nsdfg.arrays[out_edge.data.data].dtype.ctype
            redtype = detect_reduction_type(out_edge.data.wcr)
            if redtype == dtypes.ReductionType.Custom:
                raise NotImplementedError
            # Use the part of the enum's string form after the first dot.
            credtype = ('dace::ReductionType::' +
                        str(redtype)[str(redtype).find('.') + 1:])

            # Add local access between thread-local and warp reduction
            name = nsdfg._find_new_name(out_edge.data.data)
            nsdfg.add_scalar(name,
                             nsdfg.arrays[out_edge.data.data].dtype,
                             transient=True)

            # Initialize thread-local to global value
            read = nstate.add_read(out_edge.data.data)
            write = nstate.add_write(name)
            init_edge = nstate.add_nedge(read, write,
                                         copy.deepcopy(out_edge.data))
            init_edge.data.wcr = None
            xfh.state_fission(nsdfg, SubgraphView(nstate, [read, write]))

            # Redirect the map output to the thread-local scalar.
            newnode = nstate.add_access(name)
            nstate.remove_edge(out_edge)
            local_edge = nstate.add_edge(out_edge.src, out_edge.src_conn,
                                         newnode, None,
                                         copy.deepcopy(out_edge.data))
            for e in nstate.memlet_path(local_edge):
                e.data.data = name
                e.data.subset = subsets.Range([(0, 0, 1)])

            if out_edge.data.subset.num_elements() == 1:
                # One element: tasklet
                wrt = nstate.add_tasklet(
                    'warpreduce', {'__a'}, {'__out'},
                    f'__out = dace::warpReduce<{credtype}, {ctype}>::reduce(__a);',
                    dtypes.Language.CPP)
                nstate.add_edge(newnode, None, wrt, '__a', Memlet(name))
                out_edge.data.wcr = None
                nstate.add_edge(wrt, '__out', out_edge.dst, None,
                                out_edge.data)
            else:
                # More than one element: mapped tasklet
                raise NotImplementedError
        # End of WCR to warp reduction

    # Make nested SDFG out of new scope
    xfh.nest_state_subgraph(sdfg, graph,
                            graph.scope_subgraph(new_me, False, False))

    return new_me
def apply(self, sdfg: dace.SDFG):
    """Absorb the matched one-parameter map into the stencil node.

    The dimension indexed by the map parameter is located by scanning the
    three-dimensional memlets attached to the stencil. The map's outer edges
    are reconnected directly to the stencil with propagated memlets, the map
    entry and exit are removed, the stencil's shape in that dimension is set
    to the number of elements of the map range, and every field that did not
    use that dimension yet has it enabled, both in the stencil's metadata
    and in its code.

    :param sdfg: SDFG that contains the state ``self.state_id`` with the
                 matched map entry and stencil.
    :raises ValueError: If the stencil code is not written in Python.
    """
    state: dace.SDFGState = sdfg.node(self.state_id)
    map_entry: nodes.MapEntry = state.node(self.subgraph[NestK._map_entry])
    stencil: Stencil = state.node(self.subgraph[NestK._stencil])

    param_name = map_entry.map.params[0]

    def find_dimension():
        # First dimension of a 3D memlet whose range mentions the parameter.
        for edge in state.all_edges(stencil):
            if edge.data.data is None:  # Empty memlet
                continue
            if len(edge.data.subset) != 3:
                continue
            for dim, rng in enumerate(edge.data.subset.ndrange()):
                if any(param_name in map(str, r.free_symbols) for r in rng):
                    return dim
        return None

    dim_index = find_dimension()

    map_exit = state.exit_node(map_entry)

    # Reconnect external edges directly to stencil node
    for outer_edge in state.in_edges(map_entry):
        # Find matching internal edges
        for child in state.memlet_tree(outer_edge).children:
            propagated = propagation.propagate_memlet(state, child.edge.data,
                                                      map_entry, False)
            state.add_edge(outer_edge.src, outer_edge.src_conn, stencil,
                           child.edge.dst_conn, propagated)
    for outer_edge in state.out_edges(map_exit):
        # Find matching internal edges
        for child in state.memlet_tree(outer_edge).children:
            propagated = propagation.propagate_memlet(state, child.edge.data,
                                                      map_entry, False)
            state.add_edge(stencil, child.edge.src_conn, outer_edge.dst,
                           outer_edge.dst_conn, propagated)

    # Remove map
    state.remove_nodes_from([map_entry, map_exit])

    # Reshape stencil node computation based on nested map range
    stencil.shape[dim_index] = map_entry.map.range.num_elements()

    # Fields that gain the dimension; their code accesses are rewritten below.
    add_dims = set()
    for in_edge in state.in_edges(stencil):
        if not (in_edge.data.data and len(in_edge.data.subset) == 3):
            continue
        dims_used = stencil.accesses[in_edge.dst_conn][0]
        if dims_used[dim_index] is False:
            add_dims.add(in_edge.dst_conn)
            stencil.accesses[in_edge.dst_conn][0][dim_index] = True
    for out_edge in state.out_edges(stencil):
        if not (out_edge.data.data and len(out_edge.data.subset) == 3):
            continue
        dims_used = stencil.output_fields[out_edge.src_conn][0]
        if dims_used[dim_index] is False:
            add_dims.add(out_edge.src_conn)
            stencil.output_fields[out_edge.src_conn][0][dim_index] = True

    # Change all instances in the code as well
    if stencil.code.language != dace.Language.Python:
        raise ValueError(
            'For NestK to work, Stencil code language must be Python')
    for idx, stmt in enumerate(stencil.code.code):
        stencil.code.code[idx] = DimensionAdder(add_dims,
                                                dim_index).visit(stmt)