def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
    """Match only for-loops that provably execute exactly one iteration.

    The loop bounds come from ``find_for_loop``; ``end`` is treated as
    inclusive (the checks compare against ``end + 1`` / ``end - 1``).

    :return: True if the loop runs exactly once. False if it may run zero or
             several times, if it is not a recognizable for-loop, or if the
             relation between its bounds cannot be decided (e.g., symbolic).
    """
    # Is this even a loop
    if not super().can_be_applied(graph, expr_index, sdfg, permissive):
        return False

    # Obtain loop information
    guard: sd.SDFGState = self.loop_guard
    body: sd.SDFGState = self.loop_begin

    # Obtain iteration variable, range, and stride
    loop_info = find_for_loop(sdfg, guard, body)
    if not loop_info:
        return False
    _, (start, end, step), _ = loop_info

    try:
        if step > 0:
            # Zero iterations (start > end) or more than one iteration
            # (start + step <= end) do not make a trivial loop
            if start > end or start + step < end + 1:
                return False
        elif step < 0:
            if start < end or start + step > end - 1:
                return False
    except Exception:
        # If the relation can't be determined (e.g., an undecidable symbolic
        # comparison), it's not a trivial loop
        return False

    return True
def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
    """Accept the matched loop only if ``find_for_loop`` recognizes it.

    :return: False when the base loop pattern does not match or when no
             for-loop description can be extracted from the guard and
             first body state; True otherwise.
    """
    base_matches = super().can_be_applied(graph, expr_index, sdfg, permissive)
    if not base_matches:
        return False
    # A recognizable for-loop yields a non-None description
    loop_description = find_for_loop(sdfg, self.loop_guard, self.loop_begin)
    return loop_description is not None
def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
    """Accept the candidate only if it is a loop ``find_for_loop`` recognizes.

    :return: False when ``DetectLoop`` rejects the candidate or when the loop
             structure cannot be extracted; True otherwise.
    """
    if DetectLoop.can_be_applied(graph, candidate, expr_index, sdfg, strict):
        guard_state = graph.node(candidate[DetectLoop._loop_guard])
        begin_state = graph.node(candidate[DetectLoop._loop_begin])
        # A recognizable for-loop yields a non-None description
        return find_for_loop(sdfg, guard_state, begin_state) is not None
    return False
def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
    """Accept loops whose stride and range extent are constant-sized.

    Both the stride and the difference ``end - start`` reported by
    ``find_for_loop`` must be non-symbolic given ``sdfg.constants``.
    """
    # The match must first be a loop at all
    if not super().can_be_applied(graph, expr_index, sdfg, permissive):
        return False

    loop = find_for_loop(graph, self.loop_guard, self.loop_begin)
    # Bail out when the loop structure cannot be recognized
    if not loop:
        return False

    _, loop_range, _ = loop
    constants = sdfg.constants
    # The stride must be specialized or constant-sized
    if symbolic.issymbolic(loop_range[2], constants):
        return False
    # The range extent must be constant-sized as well
    return not symbolic.issymbolic(loop_range[1] - loop_range[0], constants)
def apply(self, _, sdfg: sd.SDFG):
    """Remove a single-iteration loop and splice its body in place.

    Every occurrence of the iteration variable in the loop-body states is
    replaced by the loop's start value. The guard state is deleted, and the
    body is connected between the guard's predecessor(s) and the loop exit.
    On the edge(s) entering the guard, only the iteration variable's
    assignment is dropped. Other assignments were executed before the loop
    originally and are kept.
    """
    # Obtain loop information
    guard: sd.SDFGState = self.loop_guard
    body: sd.SDFGState = self.loop_begin

    # Obtain iteration variable, range and stride
    itervar, (start, end, step), (_, body_end) = find_for_loop(sdfg, guard, body)

    # Find all loop-body states: everything reachable from the body without
    # passing through the guard
    states = set()
    to_visit = [body]
    while to_visit:
        state = to_visit.pop()
        if state in states:
            continue
        states.add(state)
        for _, dst, _ in sdfg.out_edges(state):
            if dst not in states and dst is not guard:
                to_visit.append(dst)

    # NOTE(review): only states are rewritten; interstate edges between
    # body states that mention the iteration variable are left unchanged.
    for state in states:
        state.replace(itervar, start)

    # Remove loop: drop the guard -> body edge and the back edge
    for body_inedge in sdfg.in_edges(body):
        sdfg.remove_edge(body_inedge)
    for body_outedge in sdfg.out_edges(body_end):
        sdfg.remove_edge(body_outedge)

    # Redirect the loop entry edge(s) to the body, dropping only the
    # iteration variable's initialization
    for guard_inedge in sdfg.in_edges(guard):
        guard_inedge.data.assignments.pop(itervar, None)
        sdfg.add_edge(guard_inedge.src, body, guard_inedge.data)
        sdfg.remove_edge(guard_inedge)
    # The exit edge(s) now leave the last body state unconditionally
    for guard_outedge in sdfg.out_edges(guard):
        guard_outedge.data.condition = CodeBlock("1")
        sdfg.add_edge(body_end, guard_outedge.dst, guard_outedge.data)
        sdfg.remove_edge(guard_outedge)
    sdfg.remove_node(guard)
    if itervar in sdfg.symbols and helpers.is_symbol_unused(sdfg, itervar):
        sdfg.remove_symbol(itervar)
def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
    """Accept loops whose stride and range extent are constant-sized.

    Both the stride and the difference ``end - start`` reported by
    ``find_for_loop`` must be non-symbolic given ``sdfg.constants``.
    """
    # The candidate must first be a loop at all
    if not DetectLoop.can_be_applied(graph, candidate, expr_index, sdfg, strict):
        return False

    guard_state = graph.node(candidate[DetectLoop._loop_guard])
    begin_state = graph.node(candidate[DetectLoop._loop_begin])
    loop = find_for_loop(graph, guard_state, begin_state)
    # Bail out when the loop structure cannot be recognized
    if not loop:
        return False

    _, loop_range, _ = loop
    constants = sdfg.constants
    # The stride must be specialized or constant-sized
    if symbolic.issymbolic(loop_range[2], constants):
        return False
    # The range extent must be constant-sized as well
    return not symbolic.issymbolic(loop_range[1] - loop_range[0], constants)
def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
    """Check whether the matched loop can be moved inside the map in its body.

    Conditions enforced below:
    - the loop stride is +1 or -1;
    - the loop body is a single state containing exactly one map;
    - outside the map scope, no edge (other than access-node <-> map
      edges) and no node depends on the iteration variable;
    - neither the map itself nor any data descriptor used in the body
      depends on the iteration variable;
    - inside the map, each container's memlets depend on the iteration
      variable in the same dimensions, and no dimension depends on both
      the iteration variable and a map parameter.
    """
    # Is this even a loop
    if not super().can_be_applied(graph, expr_index, sdfg, permissive):
        return False

    # Obtain loop information
    guard: sd.SDFGState = self.loop_guard
    body: sd.SDFGState = self.loop_begin
    # NOTE(review): `after` is not used below
    after: sd.SDFGState = self.exit_state

    # Obtain iteration variable, range, and stride
    loop_info = find_for_loop(sdfg, guard, body)
    if not loop_info:
        return False
    itervar, (start, end, step), (_, body_end) = loop_info

    # Only unit strides are supported
    if step not in [-1, 1]:
        return False

    # Body must contain a single state
    if body != body_end:
        return False

    # Check if body contains exactly one map
    maps = [node for node in body.nodes() if isinstance(node, nodes.MapEntry)]
    if len(maps) != 1:
        return False

    # Check that everything else is independent of the loop's itervar
    subgraph = body.scope_subgraph(maps[0])
    map_exit = body.exit_node(maps[0])
    # Names of all containers touched by non-empty memlets in the body
    descs: Set[str] = set()
    for e in body.edges():
        if not e.data.is_empty():
            descs.add(e.data.data)
        # Edges inside the map scope are checked separately below
        if e.src in subgraph.nodes() or e.dst in subgraph.nodes():
            continue
        if e.dst is maps[0] and isinstance(e.src, nodes.AccessNode):
            continue
        if e.src is map_exit and isinstance(e.dst, nodes.AccessNode):
            continue
        if str(itervar) in e.data.free_symbols:
            return False
    for n in body.nodes():
        if n in subgraph.nodes():
            continue
        if str(itervar) in n.free_symbols:
            return False

    # Check for iteration variable in map and data descriptors
    if str(itervar) in maps[0].free_symbols:
        return False
    for arr in descs:
        if str(itervar) in set(map(str, sdfg.arrays[arr].free_symbols)):
            return False

    def test_subset_dependency(subset: sbs.Subset, mparams: Set[str]) -> Tuple[bool, List[int]]:
        """Return (passed, dims) for one memlet subset.

        `dims` lists the indices of subset dimensions whose free symbols
        include the iteration variable. `passed` is False (with an empty
        list) as soon as a dimension mixes the iteration variable with a
        map parameter from `mparams`.
        """
        dims = []
        for i, r in enumerate(subset):
            # Normalize index expressions to a sequence of tokens
            if not isinstance(r, (list, tuple)):
                r = [r]
            fsymbols = set()
            for token in r:
                if symbolic.issymbolic(token):
                    fsymbols = fsymbols.union({str(s) for s in token.free_symbols})
            if itervar in fsymbols:
                if fsymbols.intersection(mparams):
                    return (False, [])
                else:
                    dims.append(i)
        return (True, dims)

    # Check that Map memlets depend on itervar in a consistent manner
    # a. A container must either not depend at all on itervar, or depend on it always in the same dimensions.
    # b. Abort when a dimension depends on both the itervar and a Map parameter.
    mparams = set(maps[0].map.params)
    # Maps container name -> dimensions that depend on itervar
    data_dependency = dict()
    for e in body.edges():
        if e.src in subgraph.nodes() and e.dst in subgraph.nodes():
            if itervar in e.data.free_symbols:
                for i, subset in enumerate((e.data.src_subset, e.data.dst_subset)):
                    if subset:
                        # Resolve the access node at the outer end of the
                        # memlet path (source side for i == 0, destination
                        # side otherwise)
                        if i == 0:
                            access = body.memlet_path(e)[0].src
                        else:
                            access = body.memlet_path(e)[-1].dst
                        passed, dims = test_subset_dependency(subset, mparams)
                        if not passed:
                            return False
                        if dims:
                            if access.data in data_dependency:
                                if data_dependency[access.data] != dims:
                                    return False
                            else:
                                data_dependency[access.data] = dims

    # NOTE(review): `in_edges(...)`/`out_edges(...)` return edges, and
    # `list.count(True)` counts items equal to True, so both conditions
    # below appear to always be 0 > 1 (i.e., never reject). If a degree
    # limit was intended, `len(...)`/`in_degree(...)` would be needed;
    # confirm the intended semantics before changing.
    for node in body.nodes():
        if isinstance(node, nodes.AccessNode):
            if body.in_edges(node).count(True) > 1:
                return False
            if body.out_edges(node).count(True) > 1:
                return False

    return True
def apply(self, _, sdfg: sd.SDFG):
    """Move the matched loop inside the map contained in its body state.

    Steps:
    - The map's contents are nested into a new nested SDFG.
    - The loop (same start, inclusive bound, and stride direction) is
      re-created around that nested SDFG's state.
    - The outer guard is removed, so the body state is entered once from
      the loop's predecessor(s) and continues to the former exit.
    - Missing symbols/data are then connected and unused symbol mappings
      are dropped.
    - Memlets are propagated, and ``RefineNestedAccess`` is applied to
      the new nested SDFG.
    """
    # Obtain loop information
    guard: sd.SDFGState = self.loop_guard
    body: sd.SDFGState = self.loop_begin

    # Obtain iteration variable, range, and stride
    itervar, (start, end, step), _ = find_for_loop(sdfg, guard, body)

    # can_be_applied restricts the stride to +1/-1, so this is a plain bool
    forward_loop = step > 0

    # Presumably exactly one map exists (checked in can_be_applied); the
    # last MapEntry/MapExit encountered is kept
    for node in body.nodes():
        if isinstance(node, nodes.MapEntry):
            map_entry = node
        if isinstance(node, nodes.MapExit):
            map_exit = node

    # nest map's content in sdfg
    map_subgraph = body.scope_subgraph(map_entry, include_entry=False, include_exit=False)
    nsdfg = helpers.nest_state_subgraph(sdfg, body, map_subgraph, full_data=True)

    # replicate loop in nested sdfg
    new_before, new_guard, new_after = nsdfg.sdfg.add_loop(
        before_state=None,
        loop_state=nsdfg.sdfg.nodes()[0],
        loop_end_state=None,
        after_state=None,
        loop_var=itervar,
        initialize_expr=f'{start}',
        condition_expr=f'{itervar} <= {end}' if forward_loop else f'{itervar} >= {end}',
        increment_expr=f'{itervar} + {step}' if forward_loop else f'{itervar} - {abs(step)}')

    # remove outer loop
    # Locate the new nested loop's init, exit, and entry edges so the outer
    # loop's interstate assignments can be transferred onto them
    before_guard_edge = nsdfg.sdfg.edges_between(new_before, new_guard)[0]
    for e in nsdfg.sdfg.out_edges(new_guard):
        if e.dst is new_after:
            guard_after_edge = e
        else:
            guard_body_edge = e

    for body_inedge in sdfg.in_edges(body):
        if body_inedge.src is guard:
            guard_body_edge.data.assignments.update(body_inedge.data.assignments)
        sdfg.remove_edge(body_inedge)
    # Removes the outer back edge (body -> guard)
    for body_outedge in sdfg.out_edges(body):
        sdfg.remove_edge(body_outedge)
    # Init edge(s): move assignments into the nested loop, enter body directly
    for guard_inedge in sdfg.in_edges(guard):
        before_guard_edge.data.assignments.update(guard_inedge.data.assignments)
        guard_inedge.data.assignments = {}
        sdfg.add_edge(guard_inedge.src, body, guard_inedge.data)
        sdfg.remove_edge(guard_inedge)
    # Exit edge(s): the guard -> body edge was already removed above, so only
    # the exit branch is expected here; it now leaves body unconditionally
    for guard_outedge in sdfg.out_edges(guard):
        if guard_outedge.dst is body:
            guard_body_edge.data.assignments.update(guard_outedge.data.assignments)
        else:
            guard_after_edge.data.assignments.update(guard_outedge.data.assignments)
        guard_outedge.data.condition = CodeBlock("1")
        sdfg.add_edge(body, guard_outedge.dst, guard_outedge.data)
        sdfg.remove_edge(guard_outedge)
    sdfg.remove_node(guard)
    # The iteration variable is now defined only inside the nested SDFG
    if itervar in nsdfg.symbol_mapping:
        del nsdfg.symbol_mapping[itervar]
    if itervar in sdfg.symbols:
        del sdfg.symbols[itervar]

    # Add missing data/symbols
    for s in nsdfg.sdfg.free_symbols:
        if s in nsdfg.symbol_mapping:
            continue
        if s in sdfg.symbols:
            nsdfg.symbol_mapping[s] = s
        elif s in sdfg.arrays:
            # Free symbol refers to a container: pass it in through the map
            desc = sdfg.arrays[s]
            access = body.add_access(s)
            conn = nsdfg.sdfg.add_datadesc(s, copy.deepcopy(desc))
            nsdfg.sdfg.arrays[s].transient = False
            nsdfg.add_in_connector(conn)
            body.add_memlet_path(access, map_entry, nsdfg, memlet=Memlet.from_array(s, desc), dst_conn=conn)
        else:
            raise NotImplementedError(f"Free symbol {s} is neither a symbol nor data.")
    # Drop symbol mappings the nested SDFG no longer uses
    to_delete = set()
    for s in nsdfg.symbol_mapping:
        if s not in nsdfg.sdfg.free_symbols:
            to_delete.add(s)
    for s in to_delete:
        del nsdfg.symbol_mapping[s]

    # propagate scope for correct volumes
    scope_tree = ScopeTree(map_entry, map_exit)
    scope_tree.parent = ScopeTree(None, None)
    # The first execution helps remove appearances of symbols
    # that are now defined only in the nested SDFG in memlets.
    propagation.propagate_memlets_scope(sdfg, body, scope_tree)

    for s in to_delete:
        if helpers.is_symbol_unused(sdfg, s):
            sdfg.remove_symbol(s)

    from dace.transformation.interstate import RefineNestedAccess
    transformation = RefineNestedAccess()
    transformation.setup_match(sdfg, 0, sdfg.node_id(body), {RefineNestedAccess.nsdfg: body.node_id(nsdfg)}, 0)
    transformation.apply(body, sdfg)

    # Second propagation for refined accesses.
    propagation.propagate_memlets_scope(sdfg, body, scope_tree)
def can_be_applied(self, graph, candidate, expr_index, sdfg, strict=False):
    """Check whether the matched for-loop can be turned into a parallel map.

    The guard must hold no dataflow, the loop bounds must not read data
    containers through sympy functions, and every write must hit unique
    indices per iteration. Every read of a written container must be
    checked against *all* writes to that container so no data race is
    introduced. Finally, the iteration variable must not be used after the
    loop before being reassigned.
    """
    # Is this even a loop
    if not DetectLoop.can_be_applied(graph, candidate, expr_index, sdfg, strict):
        return False

    guard = graph.node(candidate[DetectLoop._loop_guard])
    begin = graph.node(candidate[DetectLoop._loop_begin])

    # Guard state should not contain any dataflow
    if len(guard.nodes()) != 0:
        return False

    # If loop cannot be detected, fail
    found = find_for_loop(graph, guard, begin, itervar=self.itervar)
    if not found:
        return False

    itervar, (start, end, step), (_, body_end) = found

    # We cannot handle symbols read from data containers unless they are
    # scalar
    for expr in (start, end, step):
        if symbolic.contains_sympy_functions(expr):
            return False

    # Find all loop-body states
    states = set([body_end])
    to_visit = [begin]
    while to_visit:
        state = to_visit.pop(0)
        if state is body_end:
            continue
        for _, dst, _ in graph.out_edges(state):
            if dst not in states:
                to_visit.append(dst)
        states.add(state)

    write_set = set()
    for state in states:
        _, wset = state.read_and_write_sets()
        write_set |= wset

    # Get access nodes from other states to isolate local loop variables
    other_access_nodes = set()
    for state in sdfg.nodes():
        if state in states:
            continue
        other_access_nodes |= set(n.data for n in state.data_nodes() if sdfg.arrays[n.data].transient)
    # Add non-transient nodes from loop state
    for state in states:
        other_access_nodes |= set(n.data for n in state.data_nodes() if not sdfg.arrays[n.data].transient)

    write_memlets = defaultdict(list)

    itersym = symbolic.pystr_to_symbolic(itervar)
    a = sp.Wild('a', exclude=[itersym])
    b = sp.Wild('b', exclude=[itersym])

    for state in states:
        for dn in state.data_nodes():
            if dn.data not in other_access_nodes:
                continue
            # Take all writes that are not conflicted into consideration
            if dn.data in write_set:
                for e in state.in_edges(dn):
                    if e.data.dynamic and e.data.wcr is None:
                        # If pointers are involved, give up
                        return False
                    # To be sure that the value is only written at unique
                    # indices per loop iteration, we want to match symbols
                    # of the form "a*i+b" where a >= 1, and i is the iteration
                    # variable. The iteration variable must be used.
                    if e.data.wcr is None:
                        dst_subset = e.data.get_dst_subset(e, state)
                        if not _check_range(dst_subset, a, itersym, b, step):
                            return False
                    # End of check

                    write_memlets[dn.data].append(e.data)

    # After looping over relevant writes, consider reads that may overlap
    for state in states:
        for dn in state.data_nodes():
            if dn.data not in other_access_nodes:
                continue
            data = dn.data
            if data in write_memlets:
                # Import as necessary
                from dace.sdfg.propagation import propagate_subset

                for e in state.out_edges(dn):
                    # If the same container is both read and written, only match if
                    # it is read and written at locations that will not create data races
                    if e.data.dynamic and e.data.src_subset.num_elements() != 1:
                        # If pointers are involved, give up
                        return False
                    src_subset = e.data.get_src_subset(e, state)
                    if not _check_range(src_subset, a, itersym, b, step):
                        return False

                    pread = propagate_subset([e.data], sdfg.arrays[data], [itervar],
                                             subsets.Range([(start, end, step)]))
                    # Every write to this container must be compatible with
                    # the read; a compatible write only moves on to the
                    # next one (previously `break` skipped the remaining
                    # writes). The loop variable is also named so that it
                    # no longer shadows the `candidate` parameter.
                    for write_memlet in write_memlets[data]:
                        # Simple case: read and write are in the same subset
                        if e.data.subset == write_memlet.subset:
                            continue
                        # Propagated read does not overlap with propagated write
                        pwrite = propagate_subset([write_memlet], sdfg.arrays[data], [itervar],
                                                  subsets.Range([(start, end, step)]))
                        if subsets.intersects(pread.subset, pwrite.subset) is False:
                            continue
                        return False

    # Check that the iteration variable is not used on other edges or states
    # before it is reassigned
    prior_states = True
    for state in cfg.stateorder_topological_sort(sdfg):
        # Skip all states up to and including the first loop-body state
        if prior_states:
            if state is begin:
                prior_states = False
            continue
        # We do not need to check the loop-body states
        if state in states:
            continue
        if itervar in state.free_symbols:
            return False
        # Don't continue in this direction, as the variable has
        # now been reassigned
        # TODO: Handle case of subset of out_edges
        if all(itervar in e.data.assignments for e in sdfg.out_edges(state)):
            break

    return True
def apply(self, sdfg: sd.SDFG):
    """Convert the matched for-loop into a map.

    Steps:
    - If the loop body spans several states, those states are first moved
      into a new nested SDFG placed in a single state.
    - If the guard -> body edge carries assignments, the body is nested
      again and the assignments are moved to an initial state of that
      nested SDFG.
    - A map over the iteration range (flipped to a positive stride if
      needed) then wraps the body's contents.
    - The guard's exit condition edge, the back edge, and the iteration
      variable's initialization are removed.
    """
    # Obtain loop information
    guard: sd.SDFGState = sdfg.node(self.subgraph[DetectLoop._loop_guard])
    body: sd.SDFGState = sdfg.node(self.subgraph[DetectLoop._loop_begin])
    after: sd.SDFGState = sdfg.node(self.subgraph[DetectLoop._exit_state])

    # Obtain iteration variable, range, and stride
    itervar, (start, end, step), (_, body_end) = find_for_loop(sdfg, guard, body, itervar=self.itervar)

    # Find all loop-body states
    states = set([body_end])
    to_visit = [body]
    while to_visit:
        state = to_visit.pop(0)
        if state is body_end:
            continue
        for _, dst, _ in sdfg.out_edges(state):
            if dst not in states:
                to_visit.append(dst)
        states.add(state)

    # Nest loop-body states
    if len(states) > 1:

        # Find read/write sets
        read_set, write_set = set(), set()
        for state in states:
            rset, wset = state.read_and_write_sets()
            read_set |= rset
            write_set |= wset
        # Add data from edges
        for src in states:
            for dst in states:
                for edge in sdfg.edges_between(src, dst):
                    for s in edge.data.free_symbols:
                        if s in sdfg.arrays:
                            read_set.add(s)

        # Find NestedSDFG's unique data (transients accessed only in the loop)
        rw_set = read_set | write_set
        unique_set = set()
        for name in rw_set:
            if not sdfg.arrays[name].transient:
                continue
            found = False
            for state in sdfg.states():
                if state in states:
                    continue
                for node in state.nodes():
                    if (isinstance(node, nodes.AccessNode) and node.data == name):
                        found = True
                        break
            if not found:
                unique_set.add(name)

        # Find NestedSDFG's connectors
        read_set = {n for n in read_set if n not in unique_set or not sdfg.arrays[n].transient}
        write_set = {n for n in write_set if n not in unique_set or not sdfg.arrays[n].transient}

        # Create NestedSDFG and add all loop-body states and edges
        # Also, find defined symbols in NestedSDFG
        fsymbols = set(sdfg.free_symbols)
        new_body = sdfg.add_state('single_state_body')
        nsdfg = SDFG("loop_body", constants=sdfg.constants, parent=new_body)
        nsdfg.add_node(body, is_start_state=True)
        body.parent = nsdfg
        exit_state = nsdfg.add_state('exit')
        nsymbols = dict()
        for state in states:
            if state is body:
                continue
            nsdfg.add_node(state)
            state.parent = nsdfg
        for state in states:
            if state is body:
                continue
            for src, dst, data in sdfg.in_edges(state):
                nsymbols.update({s: sdfg.symbols[s] for s in data.assignments.keys() if s in sdfg.symbols})
                nsdfg.add_edge(src, dst, data)
        nsdfg.add_edge(body_end, exit_state, InterstateEdge())

        # Move guard -> body edge to guard -> new_body
        for src, _, data in sdfg.edges_between(guard, body):
            sdfg.add_edge(src, new_body, data)
        # Move body_end -> guard edge to new_body -> guard
        for src, dst, data in sdfg.edges_between(body_end, guard):
            sdfg.add_edge(new_body, dst, data)

        # Delete loop-body states and edges from parent SDFG
        for state in states:
            for e in sdfg.all_edges(state):
                sdfg.remove_edge(e)
            sdfg.remove_node(state)

        # Add NestedSDFG arrays
        for name in read_set | write_set:
            nsdfg.arrays[name] = copy.deepcopy(sdfg.arrays[name])
            nsdfg.arrays[name].transient = False
        for name in unique_set:
            nsdfg.arrays[name] = sdfg.arrays[name]
            del sdfg.arrays[name]

        # Add NestedSDFG node
        cnode = new_body.add_nested_sdfg(nsdfg, None, read_set, write_set)
        if sdfg.parent:
            # Symbols passed into `sdfg` by its parent are referenced by the
            # same name inside `sdfg`, so the new nested node maps them
            # identically. The parent's mapped expression belongs to the
            # grandparent's namespace and must not be used here.
            for s in sdfg.parent_nsdfg_node.symbol_mapping:
                if s not in cnode.symbol_mapping:
                    cnode.symbol_mapping[s] = s
                    nsdfg.add_symbol(s, sdfg.symbols[s])
        for name in read_set:
            r = new_body.add_read(name)
            new_body.add_edge(r, None, cnode, name, memlet.Memlet.from_array(name, sdfg.arrays[name]))
        for name in write_set:
            w = new_body.add_write(name)
            new_body.add_edge(cnode, name, w, None, memlet.Memlet.from_array(name, sdfg.arrays[name]))

        # Fix SDFG symbols
        for sym in sdfg.free_symbols - fsymbols:
            del sdfg.symbols[sym]
        for sym, dtype in nsymbols.items():
            nsdfg.symbols[sym] = dtype

        # Change body state reference
        body = new_body

    # `== True` keeps undecidable symbolic comparisons on the "no flip" path
    if (step < 0) == True:
        # If step is negative, we have to flip start and end to produce a
        # correct map with a positive increment
        start, end, step = end, start, -step

    # If necessary, make a nested SDFG with assignments
    isedge = sdfg.edges_between(guard, body)[0]
    symbols_to_remove = set()
    if len(isedge.data.assignments) > 0:
        nsdfg = helpers.nest_state_subgraph(sdfg, body, gr.SubgraphView(body, body.nodes()))
        for sym in isedge.data.free_symbols:
            if sym in nsdfg.symbol_mapping or sym in nsdfg.in_connectors:
                continue
            if sym in sdfg.symbols:
                nsdfg.symbol_mapping[sym] = symbolic.pystr_to_symbolic(sym)
                nsdfg.sdfg.add_symbol(sym, sdfg.symbols[sym])
            elif sym in sdfg.arrays:
                if sym in nsdfg.sdfg.arrays:
                    raise NotImplementedError
                rnode = body.add_read(sym)
                nsdfg.add_in_connector(sym)
                desc = copy.deepcopy(sdfg.arrays[sym])
                desc.transient = False
                nsdfg.sdfg.add_datadesc(sym, desc)
                body.add_edge(rnode, None, nsdfg, sym, memlet.Memlet(sym))

        nstate = nsdfg.sdfg.node(0)
        init_state = nsdfg.sdfg.add_state_before(nstate)
        nisedge = nsdfg.sdfg.edges_between(init_state, nstate)[0]
        nisedge.data.assignments = isedge.data.assignments
        symbols_to_remove = set(nisedge.data.assignments.keys())
        for k in nisedge.data.assignments.keys():
            if k in nsdfg.symbol_mapping:
                del nsdfg.symbol_mapping[k]
        isedge.data.assignments = {}

    source_nodes = body.source_nodes()
    sink_nodes = body.sink_nodes()

    # Named to avoid shadowing the `map` and `exit` builtins
    loop_map = nodes.Map(body.label + "_map", [itervar], [(start, end, step)])
    entry = nodes.MapEntry(loop_map)
    map_exit = nodes.MapExit(loop_map)
    body.add_node(entry)
    body.add_node(map_exit)

    # If the map uses symbols from data containers, instantiate reads
    containers_to_read = entry.free_symbols & sdfg.arrays.keys()
    for rd in containers_to_read:
        # We are guaranteed that this is always a scalar, because
        # can_be_applied makes sure there are no sympy functions in each of
        # the loop expressions
        access_node = body.add_read(rd)
        body.add_memlet_path(access_node, entry, dst_conn=rd, memlet=memlet.Memlet(rd))

    # Reroute all memlets through the entry and exit nodes
    for n in source_nodes:
        if isinstance(n, nodes.AccessNode):
            for e in body.out_edges(n):
                body.remove_edge(e)
                body.add_edge_pair(entry, e.dst, n, e.data, internal_connector=e.dst_conn)
        else:
            body.add_nedge(entry, n, memlet.Memlet())
    for n in sink_nodes:
        if isinstance(n, nodes.AccessNode):
            for e in body.in_edges(n):
                body.remove_edge(e)
                body.add_edge_pair(map_exit, e.src, n, e.data, internal_connector=e.src_conn)
        else:
            body.add_nedge(n, map_exit, memlet.Memlet())

    # Get rid of the loop exit condition edge
    after_edge = sdfg.edges_between(guard, after)[0]
    sdfg.remove_edge(after_edge)

    # Remove the assignment on the edge to the guard
    for e in sdfg.in_edges(guard):
        if itervar in e.data.assignments:
            del e.data.assignments[itervar]

    # Remove the condition on the entry edge
    condition_edge = sdfg.edges_between(guard, body)[0]
    condition_edge.data.condition = CodeBlock("1")

    # Get rid of backedge to guard
    sdfg.remove_edge(sdfg.edges_between(body, guard)[0])

    # Route body directly to after state, maintaining any other assignments
    # it might have had
    sdfg.add_edge(body, after, sd.InterstateEdge(assignments=after_edge.data.assignments))

    # If this had made the iteration variable a free symbol, we can remove
    # it from the SDFG symbols
    if itervar in sdfg.free_symbols:
        sdfg.remove_symbol(itervar)
    for sym in symbols_to_remove:
        if helpers.is_symbol_unused(sdfg, sym):
            sdfg.remove_symbol(sym)
def apply(self, sdfg: sd.SDFG):
    """Convert the matched single-state for-loop into a map.

    Steps:
    - If the guard -> body edge carries assignments, the body is nested
      and the assignments are moved to an initial state of the nested
      SDFG.
    - A map over the iteration range (flipped to a positive stride if
      needed) then wraps the body's contents.
    - The guard's exit condition edge, the back edge, and the iteration
      variable's initialization are removed.
    """
    # Obtain loop information
    guard: sd.SDFGState = sdfg.node(self.subgraph[DetectLoop._loop_guard])
    body: sd.SDFGState = sdfg.node(self.subgraph[DetectLoop._loop_begin])
    after: sd.SDFGState = sdfg.node(self.subgraph[DetectLoop._exit_state])

    # Obtain iteration variable, range, and stride
    itervar, (start, end, step), _ = find_for_loop(sdfg, guard, body)

    # `== True` keeps undecidable symbolic comparisons on the "no flip" path
    if (step < 0) == True:
        # If step is negative, we have to flip start and end to produce a
        # correct map with a positive increment
        start, end, step = end, start, -step

    # If necessary, make a nested SDFG with assignments
    isedge = sdfg.edges_between(guard, body)[0]
    symbols_to_remove = set()
    if len(isedge.data.assignments) > 0:
        nsdfg = helpers.nest_state_subgraph(sdfg, body, gr.SubgraphView(body, body.nodes()))
        for sym in isedge.data.free_symbols:
            if sym in nsdfg.symbol_mapping or sym in nsdfg.in_connectors:
                continue
            if sym in sdfg.symbols:
                nsdfg.symbol_mapping[sym] = symbolic.pystr_to_symbolic(sym)
                nsdfg.sdfg.add_symbol(sym, sdfg.symbols[sym])
            elif sym in sdfg.arrays:
                if sym in nsdfg.sdfg.arrays:
                    raise NotImplementedError
                rnode = body.add_read(sym)
                nsdfg.add_in_connector(sym)
                desc = copy.deepcopy(sdfg.arrays[sym])
                desc.transient = False
                nsdfg.sdfg.add_datadesc(sym, desc)
                body.add_edge(rnode, None, nsdfg, sym, memlet.Memlet(sym))

        nstate = nsdfg.sdfg.node(0)
        init_state = nsdfg.sdfg.add_state_before(nstate)
        nisedge = nsdfg.sdfg.edges_between(init_state, nstate)[0]
        nisedge.data.assignments = isedge.data.assignments
        symbols_to_remove = set(nisedge.data.assignments.keys())
        for k in nisedge.data.assignments.keys():
            if k in nsdfg.symbol_mapping:
                del nsdfg.symbol_mapping[k]
        isedge.data.assignments = {}

    source_nodes = body.source_nodes()
    sink_nodes = body.sink_nodes()

    # Named to avoid shadowing the `map` and `exit` builtins
    loop_map = nodes.Map(body.label + "_map", [itervar], [(start, end, step)])
    entry = nodes.MapEntry(loop_map)
    map_exit = nodes.MapExit(loop_map)
    body.add_node(entry)
    body.add_node(map_exit)

    # If the map uses symbols from data containers, instantiate reads
    containers_to_read = entry.free_symbols & sdfg.arrays.keys()
    for rd in containers_to_read:
        # We are guaranteed that this is always a scalar, because
        # can_be_applied makes sure there are no sympy functions in each of
        # the loop expressions
        access_node = body.add_read(rd)
        body.add_memlet_path(access_node, entry, dst_conn=rd, memlet=memlet.Memlet(rd))

    # Reroute all memlets through the entry and exit nodes
    for n in source_nodes:
        if isinstance(n, nodes.AccessNode):
            for e in body.out_edges(n):
                body.remove_edge(e)
                body.add_edge_pair(entry, e.dst, n, e.data, internal_connector=e.dst_conn)
        else:
            body.add_nedge(entry, n, memlet.Memlet())
    for n in sink_nodes:
        if isinstance(n, nodes.AccessNode):
            for e in body.in_edges(n):
                body.remove_edge(e)
                body.add_edge_pair(map_exit, e.src, n, e.data, internal_connector=e.src_conn)
        else:
            body.add_nedge(n, map_exit, memlet.Memlet())

    # Get rid of the loop exit condition edge
    after_edge = sdfg.edges_between(guard, after)[0]
    sdfg.remove_edge(after_edge)

    # Remove the assignment on the edge to the guard
    for e in sdfg.in_edges(guard):
        if itervar in e.data.assignments:
            del e.data.assignments[itervar]

    # Remove the condition on the entry edge
    condition_edge = sdfg.edges_between(guard, body)[0]
    condition_edge.data.condition = CodeBlock("1")

    # Get rid of backedge to guard
    sdfg.remove_edge(sdfg.edges_between(body, guard)[0])

    # Route body directly to after state, maintaining any other assignments
    # it might have had
    sdfg.add_edge(body, after, sd.InterstateEdge(assignments=after_edge.data.assignments))

    # If this had made the iteration variable a free symbol, we can remove
    # it from the SDFG symbols
    if itervar in sdfg.free_symbols:
        sdfg.remove_symbol(itervar)
    for sym in symbols_to_remove:
        if helpers.is_symbol_unused(sdfg, sym):
            sdfg.remove_symbol(sym)
def apply(self, _, sdfg: sd.SDFG):
    """Peel ``self.count`` iterations off the start or the end of the loop.

    If ``self.begin``:
    - The init assignment of the iteration variable is set to
      ``start + count * stride``.
    - ``count`` instantiated copies of the loop states are inserted
      between the pre-loop state(s) and the guard, with the iteration
      variable fixed to ``start + i * stride``.

    Otherwise:
    - Both guard conditions are rewritten with ``self._modify_cond``.
    - ``count`` instantiated copies are inserted between the guard's exit
      edge and the after state, with the iteration variable given as
      ``itervar + i * stride``.
    """
    ####################################################################
    # Obtain loop information
    guard: sd.SDFGState = self.loop_guard
    begin: sd.SDFGState = self.loop_begin
    after_state: sd.SDFGState = self.exit_state

    # Obtain iteration variable, range, and stride
    condition_edge = sdfg.edges_between(guard, begin)[0]
    not_condition_edge = sdfg.edges_between(guard, after_state)[0]
    itervar, rng, loop_struct = find_for_loop(sdfg, guard, begin)

    # Get loop states
    loop_states = list(
        sdutil.dfs_conditional(sdfg, sources=[begin], condition=lambda _, child: child != guard))
    first_id = loop_states.index(begin)
    last_state = loop_struct[1]
    last_id = loop_states.index(last_state)
    loop_subgraph = gr.SubgraphView(sdfg, loop_states)

    ####################################################################
    # Transform

    if self.begin:
        # If begin, change initialization assignment and prepend states before
        # guard
        init_edges = []
        before_states = loop_struct[0]
        for before_state in before_states:
            init_edge = sdfg.edges_between(before_state, guard)[0]
            init_edge.data.assignments[itervar] = str(rng[0] + self.count * rng[2])
            init_edges.append(init_edge)
        append_states = before_states

        # Add `count` states, each with instantiated iteration variable
        for i in range(self.count):
            # Instantiate loop states with iterate value
            state_name: str = 'start_' + itervar + str(i * rng[2])
            # Make the label identifier-safe
            state_name = state_name.replace('-', 'm').replace(
                '+', 'p').replace('*', 'M').replace('/', 'D')
            new_states = self.instantiate_loop(
                sdfg,
                loop_states,
                loop_subgraph,
                itervar,
                rng[0] + i * rng[2],
                state_name,
            )

            # Connect states to before the loop with unconditional edges
            for append_state in append_states:
                sdfg.add_edge(append_state, new_states[first_id], sd.InterstateEdge())
            append_states = [new_states[last_id]]

        # Reconnect edge to guard state from last peeled iteration
        # NOTE(review): peeled iterations are entered through bare
        # InterstateEdge()s, so any other assignments/condition on the init
        # edge(s) now take effect only after the peeled iterations, and with
        # several pre-loop states only init_edges[0].data is kept — confirm
        # this is intended.
        for append_state in append_states:
            if append_state not in before_states:
                for init_edge in init_edges:
                    sdfg.remove_edge(init_edge)
                sdfg.add_edge(append_state, guard, init_edges[0].data)
    else:
        # If end, rewrite the loop condition (and its negation) via
        # `_modify_cond` and append the peeled states between the guard and
        # the after state
        itervar_sym = pystr_to_symbolic(itervar)
        condition_edge.data.condition = CodeBlock(
            self._modify_cond(condition_edge.data.condition, itervar, rng[2]))
        not_condition_edge.data.condition = CodeBlock(
            self._modify_cond(not_condition_edge.data.condition, itervar, rng[2]))
        prepend_state = after_state

        # Add `count` states, each with instantiated iteration variable
        # (built back-to-front so each copy precedes the previous one)
        for i in reversed(range(self.count)):
            # Instantiate loop states with iterate value
            state_name: str = 'end_' + itervar + str(-i * rng[2])
            # Make the label identifier-safe
            state_name = state_name.replace('-', 'm').replace(
                '+', 'p').replace('*', 'M').replace('/', 'D')
            new_states = self.instantiate_loop(
                sdfg,
                loop_states,
                loop_subgraph,
                itervar,
                itervar_sym + i * rng[2],
                state_name,
            )

            # Connect each peeled iteration to the state that follows it
            # with an unconditional edge
            sdfg.add_edge(new_states[last_id], prepend_state, sd.InterstateEdge())
            prepend_state = new_states[first_id]

        # Reconnect the guard's exit edge to the first peeled iteration
        if prepend_state != after_state:
            sdfg.remove_edge(not_condition_edge)
            sdfg.add_edge(guard, prepend_state, not_condition_edge.data)
def apply(self, sdfg):
    """Fully unroll the matched for-loop.

    One instantiated copy of the loop states is created per iteration and
    chained with unconditional edges. The chain is then connected between
    the pre-loop state(s) and the after state, and the guard and original
    loop states are removed. A loop with zero iterations is replaced by a
    direct edge from the pre-loop state(s) to the after state.

    :raises NotImplementedError: if partial unrolling (``self.count != 0``)
                                 is requested.
    :raises TypeError: if the stride or trip count cannot be evaluated to
                       constants.
    """
    # Obtain loop information
    guard: sd.SDFGState = sdfg.node(self.subgraph[DetectLoop._loop_guard])
    begin: sd.SDFGState = sdfg.node(self.subgraph[DetectLoop._loop_begin])
    after_state: sd.SDFGState = sdfg.node(self.subgraph[DetectLoop._exit_state])

    # Obtain iteration variable, range, and stride, together with the last
    # state(s) before the loop and the last loop state.
    itervar, rng, loop_struct = find_for_loop(sdfg, guard, begin)

    # Loop must be fully unrollable for now.
    if self.count != 0:
        raise NotImplementedError  # TODO(later)

    # Get loop states
    loop_states = list(sdutil.dfs_conditional(sdfg, sources=[begin], condition=lambda _, child: child != guard))
    first_id = loop_states.index(begin)
    last_state = loop_struct[1]
    last_id = loop_states.index(last_state)
    loop_subgraph = gr.SubgraphView(sdfg, loop_states)

    try:
        start, end, stride = (r for r in rng)
        stride = symbolic.evaluate(stride, sdfg.constants)
        # `end` is inclusive, so the exclusive offset bound lies one unit past
        # it in the direction of iteration (the old `+ 1` undercounted
        # iterations for negative strides)
        loop_diff = int(symbolic.evaluate(end - start + (1 if stride > 0 else -1), sdfg.constants))
        is_symbolic = any(symbolic.issymbolic(r) for r in rng[:2])
    except TypeError as err:
        raise TypeError('Loop difference and strides cannot be symbolic.') from err

    # Create states for loop subgraph
    unrolled_states = []
    for i in range(0, loop_diff, stride):
        current_index = start + i
        # Instantiate loop states with iterate value
        new_states = self.instantiate_loop(sdfg, loop_states, loop_subgraph, itervar, current_index,
                                           str(i) if is_symbolic else None)

        # Connect iterations with unconditional edges
        if len(unrolled_states) > 0:
            sdfg.add_edge(unrolled_states[-1][1], new_states[first_id], sd.InterstateEdge())

        unrolled_states.append((new_states[first_id], new_states[last_id]))

    # Get any assignments that might be on the edge to the after state
    after_assignments = (sdfg.edges_between(guard, after_state)[0].data.assignments)

    # Connect new states to before and after states without conditions
    before_states = loop_struct[0]
    if unrolled_states:
        for before_state in before_states:
            sdfg.add_edge(before_state, unrolled_states[0][0], sd.InterstateEdge())
        sdfg.add_edge(unrolled_states[-1][1], after_state, sd.InterstateEdge(assignments=after_assignments))
    else:
        # Zero-trip loop: without these edges, removing the guard would leave
        # the after state disconnected from the states preceding the loop
        for before_state in before_states:
            sdfg.add_edge(before_state, after_state, sd.InterstateEdge(assignments=dict(after_assignments)))

    # Remove old states from SDFG
    sdfg.remove_nodes_from([guard] + loop_states)