def _candidates(nsdfg: nodes.NestedSDFG) -> Set[str]:
    """
    Returns the set of symbols mapped into a nested SDFG that appear to be
    unused inside it, and are therefore candidates for removal from the
    symbol mapping.

    A symbol is disqualified as soon as it is used anywhere in the nested
    SDFG: in an array descriptor, in a state's free symbols (including a
    conservative textual scan of C++ tasklets), or on an inter-state edge
    (condition or assignment right-hand side). Symbols that are reassigned
    on *all* outgoing edges of a state are considered shadowed from that
    point on and no longer count as uses (tracked via ``ignore``).

    :param nsdfg: The nested SDFG node whose symbol mapping is inspected.
    :return: Set of symbol names from ``nsdfg.symbol_mapping`` that are
             never read inside the nested SDFG.
    """
    # Start from every mapped-in symbol; remove each one on first evidence of use.
    candidates = set(nsdfg.symbol_mapping.keys())
    if len(candidates) == 0:
        return set()

    # Symbols appearing in data descriptors (shapes, strides, etc.) are used.
    for desc in nsdfg.sdfg.arrays.values():
        candidates -= set(map(str, desc.free_symbols))

    # Symbols reassigned on all out-edges of previously-visited states;
    # subsequent appearances refer to the new value, not the mapped-in one.
    ignore = set()
    for nstate in cfg.stateorder_topological_sort(nsdfg.sdfg):
        state_syms = nstate.free_symbols

        # Try to be conservative with C++ tasklets: free_symbols cannot parse
        # C++ code, so fall back to a whole-word regex search of the source.
        for node in nstate.nodes():
            if (isinstance(node, nodes.Tasklet)
                    and node.language is dtypes.Language.CPP):
                for candidate in candidates:
                    if re.findall(r'\b%s\b' % re.escape(candidate),
                                  node.code.as_string):
                        state_syms.add(candidate)

        # Any symbol used in this state is considered used
        candidates -= (state_syms - ignore)
        if len(candidates) == 0:
            return set()

        # Any symbol that is set in all outgoing edges is ignored from
        # this point
        local_ignore = None
        for e in nsdfg.sdfg.out_edges(nstate):
            # Look for symbols in condition
            candidates -= (set(
                map(str, symbolic.symbols_in_ast(
                    e.data.condition.code[0]))) - ignore)

            # Symbols read on assignment right-hand sides count as uses.
            for assign in e.data.assignments.values():
                candidates -= (symbolic.free_symbols_and_functions(assign) -
                               ignore)

            # Intersect assigned symbols across all out-edges: only symbols
            # reassigned on every path are safely shadowed afterwards.
            if local_ignore is None:
                local_ignore = set(e.data.assignments.keys())
            else:
                local_ignore &= e.data.assignments.keys()
        if local_ignore is not None:
            ignore |= local_ignore

    return candidates
def can_be_applied(self, graph, candidate, expr_index, sdfg, strict=False):
    """
    Checks whether the matched for-loop pattern can be converted into a map.

    The check succeeds only if: the pattern is a detectable loop with an
    empty guard state; the loop bounds/step contain no data-container reads
    (sympy functions); every write to a non-loop-local container is indexed
    affinely by the iteration variable (``a*i + b``) so each iteration
    writes unique locations; reads of written containers provably do not
    race with writes of other iterations; and the iteration variable is not
    used outside the loop body before being reassigned.

    :param graph: The state machine (SDFG) in which the pattern was matched.
    :param candidate: Mapping from pattern nodes to graph node indices.
    :param expr_index: Index of the matched pattern expression.
    :param sdfg: The SDFG containing the loop.
    :param strict: Whether strict-transformation mode is enabled.
    :return: True if the loop can be turned into a map, False otherwise.
    """
    # Is this even a loop
    if not DetectLoop.can_be_applied(graph, candidate, expr_index, sdfg,
                                     strict):
        return False

    guard = graph.node(candidate[DetectLoop._loop_guard])
    begin = graph.node(candidate[DetectLoop._loop_begin])

    # Guard state should not contain any dataflow
    if len(guard.nodes()) != 0:
        return False

    # If loop cannot be detected, fail
    found = find_for_loop(graph, guard, begin, itervar=self.itervar)
    if not found:
        return False

    itervar, (start, end, step), (_, body_end) = found

    # We cannot handle symbols read from data containers unless they are
    # scalar
    for expr in (start, end, step):
        if symbolic.contains_sympy_functions(expr):
            return False

    # Find all loop-body states (BFS from the loop entry up to body_end)
    states = set([body_end])
    to_visit = [begin]
    while to_visit:
        state = to_visit.pop(0)
        if state is body_end:
            continue
        for _, dst, _ in graph.out_edges(state):
            if dst not in states:
                to_visit.append(dst)
        states.add(state)

    # Union of all containers written anywhere in the loop body.
    write_set = set()
    for state in states:
        _, wset = state.read_and_write_sets()
        write_set |= wset

    # Get access nodes from other states to isolate local loop variables
    other_access_nodes = set()
    for state in sdfg.nodes():
        if state in states:
            continue
        other_access_nodes |= set(n.data for n in state.data_nodes()
                                  if sdfg.arrays[n.data].transient)
    # Add non-transient nodes from loop state
    for state in states:
        other_access_nodes |= set(n.data for n in state.data_nodes()
                                  if not sdfg.arrays[n.data].transient)

    write_memlets = defaultdict(list)

    # Wildcards for matching affine index expressions "a*i + b", where
    # neither coefficient may itself contain the iteration variable.
    itersym = symbolic.pystr_to_symbolic(itervar)
    a = sp.Wild('a', exclude=[itersym])
    b = sp.Wild('b', exclude=[itersym])

    for state in states:
        for dn in state.data_nodes():
            if dn.data not in other_access_nodes:
                continue
            # Take all writes that are not conflicted into consideration
            if dn.data in write_set:
                for e in state.in_edges(dn):
                    if e.data.dynamic and e.data.wcr is None:
                        # If pointers are involved, give up
                        return False
                    # To be sure that the value is only written at unique
                    # indices per loop iteration, we want to match symbols
                    # of the form "a*i+b" where a >= 1, and i is the iteration
                    # variable. The iteration variable must be used.
                    if e.data.wcr is None:
                        dst_subset = e.data.get_dst_subset(e, state)
                        if not _check_range(dst_subset, a, itersym, b, step):
                            return False
                    # End of check
                    write_memlets[dn.data].append(e.data)

    # After looping over relevant writes, consider reads that may overlap
    for state in states:
        for dn in state.data_nodes():
            if dn.data not in other_access_nodes:
                continue
            data = dn.data
            if data in write_memlets:
                # Import as necessary
                from dace.sdfg.propagation import propagate_subset

                for e in state.out_edges(dn):
                    # If the same container is both read and written, only match if
                    # it read and written at locations that will not create data races
                    if (e.data.dynamic
                            and e.data.src_subset.num_elements() != 1):
                        # If pointers are involved, give up
                        return False
                    src_subset = e.data.get_src_subset(e, state)
                    if not _check_range(src_subset, a, itersym, b, step):
                        return False

                    # Propagate the read over the whole iteration range to
                    # compare against each propagated write.
                    pread = propagate_subset([e.data], sdfg.arrays[data],
                                             [itervar],
                                             subsets.Range([(start, end, step)
                                                            ]))
                    for candidate in write_memlets[data]:
                        # Simple case: read and write are in the same subset
                        if e.data.subset == candidate.subset:
                            break
                        # Propagated read does not overlap with propagated write
                        pwrite = propagate_subset([candidate],
                                                  sdfg.arrays[data], [itervar],
                                                  subsets.Range([(start, end,
                                                                  step)]))
                        if subsets.intersects(pread.subset,
                                              pwrite.subset) is False:
                            break
                        # Neither safe case held for this write: potential race.
                        return False

    # Check that the iteration variable is not used on other edges or states
    # before it is reassigned
    prior_states = True
    for state in cfg.stateorder_topological_sort(sdfg):
        # Skip all states up to guard
        if prior_states:
            if state is begin:
                prior_states = False
            continue
        # We do not need to check the loop-body states
        if state in states:
            continue
        if itervar in state.free_symbols:
            return False
        # Don't continue in this direction, as the variable has
        # now been reassigned
        # TODO: Handle case of subset of out_edges
        if all(itervar in e.data.assignments
               for e in sdfg.out_edges(state)):
            break

    return True
def apply_pass(
        self, sdfg: SDFG,
        pipeline_results: Dict[str, Any]) -> Optional[Dict[SDFGState, Set[str]]]:
    """
    Removes unreachable dataflow throughout SDFG states.
    :param sdfg: The SDFG to modify.
    :param pipeline_results: If in the context of a ``Pipeline``, a dictionary that is populated with prior Pass
                             results as ``{Pass subclass name: returned object from pass}``. If not run in a
                             pipeline, an empty dictionary is expected.
    :return: A dictionary mapping states to removed data descriptor names, or None if nothing changed.
    """
    # Depends on the following analysis passes:
    #  * State reachability
    #  * Read/write access sets per state
    reachable: Dict[SDFGState, Set[SDFGState]] = pipeline_results['StateReachability']
    access_sets: Dict[SDFGState, Tuple[Set[str], Set[str]]] = pipeline_results['AccessSets']
    result: Dict[SDFGState, Set[str]] = defaultdict(set)

    # Traverse SDFG backwards (reverse topological state order) so that
    # updated read sets of later states are available to earlier ones.
    for state in reversed(list(cfg.stateorder_topological_sort(sdfg))):
        #############################################
        # Analysis
        #############################################

        # Compute states where memory will no longer be read: a container
        # written here is dead if no reachable descendant state reads it.
        writes = access_sets[state][1]
        descendants = reachable[state]
        descendant_reads = set().union(*(access_sets[succ][0]
                                         for succ in descendants))
        no_longer_used: Set[str] = set(data for data in writes
                                       if data not in descendant_reads)

        # Compute dead nodes
        dead_nodes: List[nodes.Node] = []

        # Propagate deadness backwards within a state: a node is dead if its
        # outputs are dead (reverse topological order guarantees successors
        # are classified first).
        for node in sdutil.dfs_topological_sort(state, reverse=True):
            if self._is_node_dead(node, sdfg, state, dead_nodes,
                                  no_longer_used):
                dead_nodes.append(node)

        # Scope exit nodes are only dead if their corresponding entry nodes are
        live_nodes = set()
        for node in dead_nodes:
            if isinstance(node, nodes.ExitNode) and state.entry_node(
                    node) not in dead_nodes:
                live_nodes.add(node)
        dead_nodes = dtypes.deduplicate(
            [n for n in dead_nodes if n not in live_nodes])

        if not dead_nodes:
            continue

        # Remove nodes while preserving scopes
        scopes_to_reconnect: Set[nodes.Node] = set()
        for node in state.nodes():
            # Look for scope exits that will be disconnected
            if isinstance(node, nodes.ExitNode) and node not in dead_nodes:
                if any(n in dead_nodes for n in state.predecessors(node)):
                    scopes_to_reconnect.add(node)

        # Two types of scope disconnections may occur:
        # 1. Two scope exits will no longer be connected
        # 2. A predecessor of dead nodes is in a scope and not connected to its exit
        # Case (1) is taken care of by ``remove_memlet_path``
        # Case (2) is handled below
        # Reconnect scopes
        if scopes_to_reconnect:
            schildren = state.scope_children()
            for exit_node in scopes_to_reconnect:
                entry_node = state.entry_node(exit_node)
                for node in schildren[entry_node]:
                    if node is exit_node:
                        continue
                    if isinstance(node, nodes.EntryNode):
                        node = state.exit_node(node)
                    # If node will be disconnected from exit node, add an empty memlet
                    if all(succ in dead_nodes
                           for succ in state.successors(node)):
                        state.add_nedge(node, exit_node, Memlet())

        #############################################
        # Removal
        #############################################
        predecessor_nsdfgs: Dict[nodes.NestedSDFG,
                                 Set[str]] = defaultdict(set)
        for node in dead_nodes:
            # Remove memlet paths and connectors pertaining to dead nodes
            for e in state.in_edges(node):
                mtree = state.memlet_tree(e)
                for leaf in mtree.leaves():
                    # Keep track of predecessors of removed nodes for connector pruning
                    if isinstance(leaf.src, nodes.NestedSDFG):
                        predecessor_nsdfgs[leaf.src].add(leaf.src_conn)
                    state.remove_memlet_path(leaf)

            # Remove the node itself as necessary
            state.remove_node(node)

        result[state].update(dead_nodes)

        # Remove isolated access nodes after elimination
        # NOTE: access_nodes is snapshotted here and reused below after the
        # removals, so it still contains the just-removed isolated nodes.
        access_nodes = set(state.data_nodes())
        for node in access_nodes:
            if state.degree(node) == 0:
                state.remove_node(node)
                result[state].add(node)

        # Prune now-dead connectors
        for node, dead_conns in predecessor_nsdfgs.items():
            for conn in dead_conns:
                # If removed connector belonged to a nested SDFG, and no other input connector shares name,
                # make nested data transient (dead dataflow elimination would remove internally as necessary)
                if conn not in node.in_connectors:
                    node.sdfg.arrays[conn].transient = True

        # Update read sets for the predecessor states to reuse
        access_nodes -= result[state]
        access_node_names = set(n.data for n in access_nodes
                                if state.out_degree(n) > 0)
        access_sets[state] = (access_node_names, access_sets[state][1])

    return result or None