def should_apply(self, sdfg: SDFG) -> bool:
    """
    Cheap pre-check (linear in the number of interstate edges and their
    assignments) that decides whether constant propagation needs to run at all.

    :param sdfg: The SDFG to inspect.
    :return: True as soon as some interstate edge assigns a value for which
             ``symbolic.issymbolic`` returns False (i.e., a potential constant);
             False if no edge does, including when no edge has assignments.
    """
    # Edges without assignments contribute no values, so they are naturally
    # skipped by the inner generator; ``any`` stops at the first hit.
    return any(
        not symbolic.issymbolic(assigned_value)
        for isedge in sdfg.edges()
        for assigned_value in isedge.data.assignments.values()
    )
def split_condition_interstate_edges(sdfg: dace.SDFG):
    """
    Split every interstate edge that has both a condition and assignments into
    two consecutive edges joined by a newly added, empty state.

    The first new edge carries only the original condition. The second carries
    only the original assignments (constructed without a condition argument,
    presumably making it unconditional).

    :param sdfg: The SDFG to modify in place.
    """
    # Collect first, then mutate, so the graph is not changed while its edges
    # are being iterated. A list (instead of a set of edge objects, whose
    # iteration order may depend on object identity) keeps the processing
    # order, and thus the order in which intermediate states are created,
    # deterministic across runs.
    edges_to_split = [
        isedge for isedge in sdfg.edges()
        if not isedge.data.is_unconditional() and len(isedge.data.assignments) > 0
    ]
    for ise in edges_to_split:
        sdfg.remove_edge(ise)
        interim = sdfg.add_state()
        sdfg.add_edge(ise.src, interim, dace.InterstateEdge(ise.data.condition))
        sdfg.add_edge(interim, ise.dst, dace.InterstateEdge(assignments=ise.data.assignments))
def apply_pass(self, sdfg: SDFG, _) -> Dict[SDFGState, Tuple[Set[str], Set[str]]]:
    """
    Collect, for every state, the names of the data containers it reads and writes.

    Within a state, an access node with at least one incoming edge counts as a
    write of its data, and one with at least one outgoing edge counts as a read
    (a node can be both). In addition, any interstate edge whose free symbols
    include data container names adds those names to the *read* sets of both
    its source and destination states.

    :param sdfg: The SDFG to analyze.
    :param _: Results of previously-run passes (unused here).
    :return: A dictionary mapping each state to a ``(readset, writeset)`` tuple
             of data container names.
    """
    result: Dict[SDFGState, Tuple[Set[str], Set[str]]] = {}
    for state in sdfg.nodes():
        readset: Set[str] = set()
        writeset: Set[str] = set()
        for anode in state.data_nodes():
            if state.in_degree(anode) > 0:
                writeset.add(anode.data)
            if state.out_degree(anode) > 0:
                readset.add(anode.data)
        result[state] = (readset, writeset)

    # Interstate edges that reference data containers (via their free symbols)
    # count as reads in the read sets of both endpoint states.
    anames = sdfg.arrays.keys()
    for e in sdfg.edges():
        fsyms = e.data.free_symbols & anames
        if fsyms:
            result[e.src][0].update(fsyms)
            result[e.dst][0].update(fsyms)
    return result
def apply_pass(self,
               sdfg: SDFG,
               _,
               initial_symbols: Optional[Dict[str, Any]] = None) -> Optional[Set[str]]:
    """
    Propagate constant symbol values throughout the SDFG, modifying it in place.

    :param sdfg: The SDFG to modify.
    :param _: When run inside a ``Pipeline``, the results of previously-run
              passes as ``{Pass subclass name: returned object from pass}``;
              an empty dictionary otherwise. It is not inspected here, only
              forwarded to recursive invocations on nested SDFGs.
    :param initial_symbols: If not None, values of symbols known on entry to the SDFG.
    :return: The set of fully propagated symbols, or None if nothing was
             propagated. If ``self.recursive`` is set, the set instead holds
             ``(sdfg_id, symbol)`` tuples, including those from nested SDFGs.
    """
    initial_symbols = initial_symbols or {}

    propagated: Dict[str, Any] = {}
    # Run the full analysis only if there are initial symbols or the cheap
    # pre-check reports that something might be propagated.
    if initial_symbols or self.should_apply(sdfg):
        # Values (or _UnknownValue) of every symbol reaching each state
        per_state_constants: Dict[SDFGState, Dict[str, Any]] = self.collect_constants(sdfg, initial_symbols)

        # Symbols appearing in data descriptors; the second set contains those
        # that may take several values across the SDFG and must not be propagated.
        desc_symbols, multivalue_desc_symbols = self._find_desc_symbols(sdfg, per_state_constants)

        symbols_replaced: Dict[str, Any] = {}
        remaining_unknowns: Set[str] = set()
        for state, state_mapping in per_state_constants.items():
            ambiguous = {
                name
                for name, value in state_mapping.items()
                if value is _UnknownValue or name in multivalue_desc_symbols
            }
            remaining_unknowns |= ambiguous
            known = {name: value for name, value in state_mapping.items() if name not in ambiguous}

            # Remember what was substituted, then substitute in the state body
            # and on every interstate edge leaving it.
            symbols_replaced.update(known)
            state.replace_dict(known)
            for out_edge in sdfg.out_edges(state):
                out_edge.data.replace_dict(known, replace_keys=False)

        # A symbol counts as propagated only if it is never ambiguous in any state
        propagated = {name: value for name, value in symbols_replaced.items() if name not in remaining_unknowns}

        # Drop propagated symbols from the symbol repository
        for name in propagated:
            if name in sdfg.symbols:
                sdfg.remove_symbol(name)

        # Substitute single-valued symbols into data descriptors (e.g., symbolic array sizes)
        descriptor_values = {name: value for name, value in propagated.items() if name in desc_symbols}
        sdfg.replace_dict(descriptor_values, replace_in_graph=False, replace_keys=False)

        # Interstate assignments to propagated symbols are now redundant
        for isedge in sdfg.edges():
            for name in propagated.keys() & isedge.data.assignments.keys():
                del isedge.data.assignments[name]

    propagated_names = set(propagated)

    if not self.recursive:
        return propagated_names or None

    # Recursive mode: tag results with the SDFG they belong to
    sid = sdfg.sdfg_id
    removed_pairs = {(sid, name) for name in propagated_names}

    for state in sdfg.nodes():
        for node in state.nodes():
            if not isinstance(node, nodes.NestedSDFG):
                continue
            nested_id = node.sdfg.sdfg_id
            # Only non-symbolic mapped values seed propagation inside the nested SDFG
            const_syms = {k: v for k, v in node.symbol_mapping.items() if not symbolic.issymbolic(v)}
            for nid, removed in self.apply_pass(node.sdfg, _, const_syms) or ():
                removed_pairs.add((nid, removed))
                # A constant fully propagated inside the nested SDFG no longer needs a mapping
                if nid == nested_id and removed in node.symbol_mapping:
                    del node.symbol_mapping[removed]

    return removed_pairs or None