def apply(self, sdfg: sd.SDFG):
    """Rewrite the matched for-loop so that its body executes inside a map.

    The guard, first body state and exit state are looked up through
    ``self.subgraph``. When the loop body spans more than one state, those
    states are first moved into a nested SDFG ("loop_body") placed in a
    newly added state, which then takes the role of the body. The body
    state is wrapped in a map over the loop's iteration variable and range,
    and the loop edges around the guard are replaced by a straight
    guard -> body -> after path.

    Nothing is returned; ``sdfg`` is modified in place.

    :param sdfg: The SDFG containing the matched loop.
    :raises NotImplementedError: If a free symbol on the guard -> body edge
                                 names a data container that already
                                 exists inside the nested body SDFG.
    """
    from collections import deque

    # Obtain loop information
    guard: sd.SDFGState = sdfg.node(self.subgraph[DetectLoop._loop_guard])
    body: sd.SDFGState = sdfg.node(self.subgraph[DetectLoop._loop_begin])
    after: sd.SDFGState = sdfg.node(self.subgraph[DetectLoop._exit_state])

    # Obtain iteration variable, range, and stride
    itervar, (start, end, step), (_, body_end) = find_for_loop(
        sdfg, guard, body, itervar=self.itervar)

    # Find all loop-body states: breadth-first walk from the first body
    # state that never expands past the last body state
    states = {body_end}
    to_visit = deque([body])
    while to_visit:
        state = to_visit.popleft()
        if state is body_end:
            continue
        for _, dst, _ in sdfg.out_edges(state):
            if dst not in states:
                to_visit.append(dst)
        states.add(state)

    # Nest loop-body states
    if len(states) > 1:

        # Find read/write sets
        read_set, write_set = set(), set()
        for state in states:
            rset, wset = state.read_and_write_sets()
            read_set |= rset
            write_set |= wset

        # Add data from edges: containers referenced on inter-state edges
        # between loop-body states count as reads
        for src in states:
            for dst in states:
                for edge in sdfg.edges_between(src, dst):
                    for s in edge.data.free_symbols:
                        if s in sdfg.arrays:
                            read_set.add(s)

        # Find NestedSDFG's unique data: transients that no state outside
        # of the loop body accesses
        outside_states = [s for s in sdfg.states() if s not in states]
        unique_set = set()
        for name in read_set | write_set:
            if not sdfg.arrays[name].transient:
                continue
            used_outside = any(
                isinstance(node, nodes.AccessNode) and node.data == name
                for other in outside_states for node in other.nodes())
            if not used_outside:
                unique_set.add(name)

        # Find NestedSDFG's connectors
        read_set = {
            n
            for n in read_set
            if n not in unique_set or not sdfg.arrays[n].transient
        }
        write_set = {
            n
            for n in write_set
            if n not in unique_set or not sdfg.arrays[n].transient
        }

        # Create NestedSDFG and add all loop-body states and edges
        # Also, find defined symbols in NestedSDFG
        fsymbols = set(sdfg.free_symbols)
        new_body = sdfg.add_state('single_state_body')
        loop_sdfg = SDFG("loop_body", constants=sdfg.constants,
                         parent=new_body)
        loop_sdfg.add_node(body, is_start_state=True)
        body.parent = loop_sdfg
        exit_state = loop_sdfg.add_state('exit')
        nsymbols = {}
        # All states must be present in the nested SDFG before any edge
        # between them is added, hence two separate passes
        for state in states:
            if state is body:
                continue
            loop_sdfg.add_node(state)
            state.parent = loop_sdfg
        for state in states:
            if state is body:
                continue
            for src, dst, data in sdfg.in_edges(state):
                nsymbols.update({
                    s: sdfg.symbols[s]
                    for s in data.assignments if s in sdfg.symbols
                })
                loop_sdfg.add_edge(src, dst, data)
        loop_sdfg.add_edge(body_end, exit_state, InterstateEdge())

        # Move guard -> body edge to guard -> new_body
        for src, _, data in sdfg.edges_between(guard, body):
            sdfg.add_edge(src, new_body, data)
        # Move body_end -> guard edge to new_body -> guard
        for _, dst, data in sdfg.edges_between(body_end, guard):
            sdfg.add_edge(new_body, dst, data)

        # Delete loop-body states and edges from parent SDFG
        for state in states:
            for e in sdfg.all_edges(state):
                sdfg.remove_edge(e)
            sdfg.remove_node(state)

        # Add NestedSDFG arrays: connector data is copied as
        # non-transient, loop-private data is moved out of the parent
        for name in read_set | write_set:
            loop_sdfg.arrays[name] = copy.deepcopy(sdfg.arrays[name])
            loop_sdfg.arrays[name].transient = False
        for name in unique_set:
            loop_sdfg.arrays[name] = sdfg.arrays[name]
            del sdfg.arrays[name]

        # Add NestedSDFG node
        cnode = new_body.add_nested_sdfg(loop_sdfg, None, read_set,
                                         write_set)
        if sdfg.parent:
            # Forward the symbol mapping this SDFG itself received
            for s, m in sdfg.parent_nsdfg_node.symbol_mapping.items():
                if s not in cnode.symbol_mapping:
                    cnode.symbol_mapping[s] = m
                    loop_sdfg.add_symbol(s, sdfg.symbols[s])
        for name in read_set:
            r = new_body.add_read(name)
            new_body.add_edge(
                r, None, cnode, name,
                memlet.Memlet.from_array(name, sdfg.arrays[name]))
        for name in write_set:
            w = new_body.add_write(name)
            new_body.add_edge(
                cnode, name, w, None,
                memlet.Memlet.from_array(name, sdfg.arrays[name]))

        # Fix SDFG symbols
        for sym in sdfg.free_symbols - fsymbols:
            del sdfg.symbols[sym]
        for sym, dtype in nsymbols.items():
            loop_sdfg.symbols[sym] = dtype

        # Change body state reference
        body = new_body

    # NOTE(review): `step` is presumably a symbolic expression, in which
    # case `step < 0` need not be a plain bool; the explicit `== True`
    # looks deliberate -- confirm before simplifying.
    if (step < 0) == True:
        # If step is negative, we have to flip start and end to produce a
        # correct map with a positive increment
        start, end, step = end, start, -step

    # If necessary, make a nested SDFG with assignments
    isedge = sdfg.edges_between(guard, body)[0]
    symbols_to_remove = set()
    if isedge.data.assignments:
        nested_node = helpers.nest_state_subgraph(
            sdfg, body, gr.SubgraphView(body, body.nodes()))
        for sym in isedge.data.free_symbols:
            if (sym in nested_node.symbol_mapping
                    or sym in nested_node.in_connectors):
                continue
            if sym in sdfg.symbols:
                nested_node.symbol_mapping[sym] = \
                    symbolic.pystr_to_symbolic(sym)
                nested_node.sdfg.add_symbol(sym, sdfg.symbols[sym])
            elif sym in sdfg.arrays:
                if sym in nested_node.sdfg.arrays:
                    raise NotImplementedError
                rnode = body.add_read(sym)
                nested_node.add_in_connector(sym)
                desc = copy.deepcopy(sdfg.arrays[sym])
                desc.transient = False
                nested_node.sdfg.add_datadesc(sym, desc)
                body.add_edge(rnode, None, nested_node, sym,
                              memlet.Memlet(sym))

        # Move the assignments onto an edge from a new initial state
        # inside the nested SDFG
        nstate = nested_node.sdfg.node(0)
        init_state = nested_node.sdfg.add_state_before(nstate)
        nisedge = nested_node.sdfg.edges_between(init_state, nstate)[0]
        nisedge.data.assignments = isedge.data.assignments
        symbols_to_remove = set(nisedge.data.assignments)
        # Symbols assigned inside the nested SDFG must not also be mapped
        # from the outside
        for k in symbols_to_remove:
            if k in nested_node.symbol_mapping:
                del nested_node.symbol_mapping[k]
        isedge.data.assignments = {}

    # Sources and sinks are captured before the map nodes are added
    source_nodes = body.source_nodes()
    sink_nodes = body.sink_nodes()

    loop_map = nodes.Map(body.label + "_map", [itervar],
                         [(start, end, step)])
    map_entry = nodes.MapEntry(loop_map)
    map_exit = nodes.MapExit(loop_map)
    body.add_node(map_entry)
    body.add_node(map_exit)

    # If the map uses symbols from data containers, instantiate reads
    containers_to_read = map_entry.free_symbols & sdfg.arrays.keys()
    for rd in containers_to_read:
        # We are guaranteed that this is always a scalar, because
        # can_be_applied makes sure there are no sympy functions in each of
        # the loop expressions
        access_node = body.add_read(rd)
        body.add_memlet_path(access_node,
                             map_entry,
                             dst_conn=rd,
                             memlet=memlet.Memlet(rd))

    # Reroute all memlets through the entry and exit nodes
    for node in source_nodes:
        if isinstance(node, nodes.AccessNode):
            for edge in body.out_edges(node):
                body.remove_edge(edge)
                body.add_edge_pair(map_entry,
                                   edge.dst,
                                   node,
                                   edge.data,
                                   internal_connector=edge.dst_conn)
        else:
            body.add_nedge(map_entry, node, memlet.Memlet())
    for node in sink_nodes:
        if isinstance(node, nodes.AccessNode):
            for edge in body.in_edges(node):
                body.remove_edge(edge)
                body.add_edge_pair(map_exit,
                                   edge.src,
                                   node,
                                   edge.data,
                                   internal_connector=edge.src_conn)
        else:
            body.add_nedge(node, map_exit, memlet.Memlet())

    # Get rid of the loop exit condition edge
    after_edge = sdfg.edges_between(guard, after)[0]
    sdfg.remove_edge(after_edge)

    # Remove the assignment on the edge to the guard
    for e in sdfg.in_edges(guard):
        if itervar in e.data.assignments:
            del e.data.assignments[itervar]

    # Remove the condition on the entry edge
    condition_edge = sdfg.edges_between(guard, body)[0]
    condition_edge.data.condition = CodeBlock("1")

    # Get rid of backedge to guard
    sdfg.remove_edge(sdfg.edges_between(body, guard)[0])

    # Route body directly to after state, maintaining any other assignments
    # it might have had
    sdfg.add_edge(
        body, after,
        sd.InterstateEdge(assignments=after_edge.data.assignments))

    # If this had made the iteration variable a free symbol, we can remove
    # it from the SDFG symbols
    if itervar in sdfg.free_symbols:
        sdfg.remove_symbol(itervar)
    for sym in symbols_to_remove:
        if helpers.is_symbol_unused(sdfg, sym):
            sdfg.remove_symbol(sym)
def apply(self, sdfg: sd.SDFG):
    """Turn the matched single-state for-loop into a map over its body.

    The guard, body and exit states are resolved from ``self.subgraph``.
    Any assignments on the guard -> body edge are pushed into a nested SDFG
    built from the body state, the body's contents are enclosed in a map
    over the loop variable and range, and the loop's control edges are
    replaced by a direct guard -> body -> after path.

    Nothing is returned; ``sdfg`` is modified in place.

    :param sdfg: The SDFG containing the matched loop.
    :raises NotImplementedError: If a free symbol on the guard -> body edge
                                 names a data container that is already
                                 present in the nested body SDFG.
    """
    # --- Loop description ---------------------------------------------
    guard: sd.SDFGState = sdfg.node(self.subgraph[DetectLoop._loop_guard])
    body: sd.SDFGState = sdfg.node(self.subgraph[DetectLoop._loop_begin])
    after: sd.SDFGState = sdfg.node(self.subgraph[DetectLoop._exit_state])

    itervar, (start, end, step), _ = find_for_loop(sdfg, guard, body)

    if (step < 0) == True:
        # A negative stride means the range runs backwards; swap the bounds
        # and negate the stride so the map iterates with a positive step
        start, end, step = end, start, -step

    # --- Assignments on the loop-entry edge ---------------------------
    entry_edge = sdfg.edges_between(guard, body)[0]
    symbols_to_remove = set()
    if entry_edge.data.assignments:
        whole_body = gr.SubgraphView(body, body.nodes())
        nested = helpers.nest_state_subgraph(sdfg, body, whole_body)

        for sym in entry_edge.data.free_symbols:
            if sym in nested.symbol_mapping or sym in nested.in_connectors:
                # Already visible inside the nested SDFG
                continue
            if sym in sdfg.symbols:
                # Plain symbol: map it through under the same name
                nested.symbol_mapping[sym] = symbolic.pystr_to_symbolic(sym)
                nested.sdfg.add_symbol(sym, sdfg.symbols[sym])
                continue
            if sym not in sdfg.arrays:
                continue
            # Data container: feed it in through a new input connector
            if sym in nested.sdfg.arrays:
                raise NotImplementedError
            read_node = body.add_read(sym)
            nested.add_in_connector(sym)
            inner_desc = copy.deepcopy(sdfg.arrays[sym])
            inner_desc.transient = False
            nested.sdfg.add_datadesc(sym, inner_desc)
            body.add_edge(read_node, None, nested, sym, memlet.Memlet(sym))

        # Hand the assignments over to an edge leaving a new first state
        # of the nested SDFG
        first_inner = nested.sdfg.node(0)
        init_state = nested.sdfg.add_state_before(first_inner)
        inner_edge = nested.sdfg.edges_between(init_state, first_inner)[0]
        inner_edge.data.assignments = entry_edge.data.assignments
        symbols_to_remove = set(inner_edge.data.assignments.keys())
        for assigned in inner_edge.data.assignments.keys():
            if assigned in nested.symbol_mapping:
                del nested.symbol_mapping[assigned]
        entry_edge.data.assignments = {}

    # --- Map construction ---------------------------------------------
    # Record the body's sources and sinks before the map nodes join it
    source_nodes = body.source_nodes()
    sink_nodes = body.sink_nodes()

    map_label = body.label + "_map"
    loop_map = nodes.Map(map_label, [itervar], [(start, end, step)])
    map_entry = nodes.MapEntry(loop_map)
    map_exit = nodes.MapExit(loop_map)
    body.add_node(map_entry)
    body.add_node(map_exit)

    # Data containers appearing in the map's free symbols need an explicit
    # read feeding the map entry
    for container in map_entry.free_symbols & sdfg.arrays.keys():
        # This is always a scalar: can_be_applied ensures that none of the
        # loop expressions contain sympy functions
        reader = body.add_read(container)
        body.add_memlet_path(reader,
                             map_entry,
                             dst_conn=container,
                             memlet=memlet.Memlet(container))

    # Former sources now hang off the map entry
    for node in source_nodes:
        if not isinstance(node, nodes.AccessNode):
            body.add_nedge(map_entry, node, memlet.Memlet())
            continue
        for edge in body.out_edges(node):
            body.remove_edge(edge)
            body.add_edge_pair(map_entry,
                               edge.dst,
                               node,
                               edge.data,
                               internal_connector=edge.dst_conn)

    # Former sinks now lead into the map exit
    for node in sink_nodes:
        if not isinstance(node, nodes.AccessNode):
            body.add_nedge(node, map_exit, memlet.Memlet())
            continue
        for edge in body.in_edges(node):
            body.remove_edge(edge)
            body.add_edge_pair(map_exit,
                               edge.src,
                               node,
                               edge.data,
                               internal_connector=edge.src_conn)

    # --- Control-flow cleanup -----------------------------------------
    # Drop the guard -> after (loop exit condition) edge, remembering it
    after_edge = sdfg.edges_between(guard, after)[0]
    sdfg.remove_edge(after_edge)

    # No edge into the guard may assign the iteration variable any more
    for incoming in sdfg.in_edges(guard):
        incoming_assignments = incoming.data.assignments
        if itervar in incoming_assignments:
            del incoming_assignments[itervar]

    # The guard -> body edge becomes unconditional
    sdfg.edges_between(guard, body)[0].data.condition = CodeBlock("1")

    # The body no longer loops back to the guard
    back_edge = sdfg.edges_between(body, guard)[0]
    sdfg.remove_edge(back_edge)

    # Continue from the body straight to the exit state, carrying over the
    # assignments of the removed exit edge
    carried = sd.InterstateEdge(assignments=after_edge.data.assignments)
    sdfg.add_edge(body, after, carried)

    # If the iteration variable has become a free symbol, take it out of
    # the SDFG symbols; likewise for moved assignment targets that are now
    # unused
    if itervar in sdfg.free_symbols:
        sdfg.remove_symbol(itervar)
    for sym in symbols_to_remove:
        if helpers.is_symbol_unused(sdfg, sym):
            sdfg.remove_symbol(sym)