def _specialize_transient_strides(sdfg: dace.SDFG, layout_map):
    """Specialize the strides of the transient arrays of ``sdfg`` according to ``layout_map``.

    ``replace_strides`` is called on every transient data descriptor and returns a
    symbol -> value replacement mapping (presumably stride symbols to their specialized
    values -- TODO confirm against ``replace_strides``). That mapping is then:

    * substituted throughout ``sdfg`` via ``sdfg.replace_dict``,
    * written into the ``symbol_mapping`` of every nested SDFG node found directly in the
      states of ``sdfg``, for keys that are already present in that mapping,
    * removed from the symbol table of ``sdfg`` for every key it contains.

    :param sdfg: The SDFG to modify in place.
    :param layout_map: Forwarded unchanged to ``replace_strides``.
    """
    transient_descs = [desc for desc in sdfg.arrays.values() if desc.transient]
    replacements = replace_strides(transient_descs, layout_map)
    sdfg.replace_dict(replacements)

    # Nested SDFG nodes carry their own symbol mapping; overwrite only the entries that
    # already exist for a replaced symbol (no new entries are added).
    nested_sdfg_nodes = (
        node
        for state in sdfg.nodes()
        for node in state.nodes()
        if isinstance(node, dace.nodes.NestedSDFG)
    )
    for nested in nested_sdfg_nodes:
        for sym, value in replacements.items():
            if sym in nested.symbol_mapping:
                nested.symbol_mapping[sym] = value

    # The replaced symbols are now substituted, so drop them from the symbol table.
    for sym in replacements:
        if sym in sdfg.symbols:
            sdfg.remove_symbol(sym)
def apply_pass(self, sdfg: SDFG, _, initial_symbols: Optional[Dict[str, Any]] = None) -> Optional[Set[str]]:
    """
    Propagates constants throughout the SDFG.

    :param sdfg: The SDFG to modify.
    :param _: Prior pass results. If in the context of a ``Pipeline``, a dictionary that is
              populated with prior Pass results as ``{Pass subclass name: returned object from pass}``.
              If not run in a pipeline, an empty dictionary is expected. This method does not read it;
              it only forwards it unchanged to the recursive calls made on nested SDFGs.
    :param initial_symbols: If not None, sets values of initial symbols.
    :return: The propagated constants, or None if nothing was changed. When ``self.recursive`` is
             false this is a set of symbol names. When it is true, it is a set of
             ``(sdfg_id, symbol name)`` tuples covering this SDFG and the nested SDFGs visited.
    """
    initial_symbols = initial_symbols or {}

    # Early exit if no constants can be propagated
    if not initial_symbols and not self.should_apply(sdfg):
        result = {}
    else:
        # Trace all constants and symbols through states
        per_state_constants: Dict[SDFGState, Dict[str, Any]] = self.collect_constants(sdfg, initial_symbols)

        # Keep track of replaced and ambiguous symbols
        symbols_replaced: Dict[str, Any] = {}
        remaining_unknowns: Set[str] = set()

        # Collect symbols from symbol-dependent data descriptors.
        # If there can be multiple values over the SDFG, the symbols are not propagated.
        desc_symbols, multivalue_desc_symbols = self._find_desc_symbols(sdfg, per_state_constants)

        # Replace constants per state
        for state, mapping in per_state_constants.items():
            # A symbol that is unknown (or multi-valued in a descriptor) in ANY state is recorded
            # here and later excluded from removal, even if it is constant in other states.
            remaining_unknowns.update({
                k
                for k, v in mapping.items() if v is _UnknownValue or k in multivalue_desc_symbols
            })
            mapping = {
                k: v
                for k, v in mapping.items() if v is not _UnknownValue and k not in multivalue_desc_symbols
            }

            # Update replaced symbols for later replacements
            symbols_replaced.update(mapping)

            # Replace in state contents
            state.replace_dict(mapping)
            # Replace in outgoing edges as well (with ``replace_keys=False``)
            for e in sdfg.out_edges(state):
                e.data.replace_dict(mapping, replace_keys=False)

        # If symbols are never unknown any longer, remove from SDFG
        result = {k: v for k, v in symbols_replaced.items() if k not in remaining_unknowns}

        # Remove from symbol repository
        for sym in result:
            if sym in sdfg.symbols:
                sdfg.remove_symbol(sym)

        # Remove single-valued symbols from data descriptors (e.g., symbolic array size)
        sdfg.replace_dict({k: v for k, v in result.items() if k in desc_symbols},
                          replace_in_graph=False,
                          replace_keys=False)

        # Remove constant symbol assignments in interstate edges.
        # Intersect two explicit key views rather than relying on ``dict & dict_keys`` being
        # reflected onto the view. The intersection is a new set, so deleting from
        # ``assignments`` while iterating it is safe.
        for edge in sdfg.edges():
            intersection = result.keys() & edge.data.assignments.keys()
            for sym in intersection:
                del edge.data.assignments[sym]

    result = set(result.keys())

    if self.recursive:
        # Change result to set of (sdfg_id, symbol) tuples so nested results can be merged
        sid = sdfg.sdfg_id
        result = {(sid, sym) for sym in result}

        for state in sdfg.nodes():
            for node in state.nodes():
                if not isinstance(node, nodes.NestedSDFG):
                    continue
                nested_id = node.sdfg.sdfg_id
                # Only non-symbolic mapped values seed the propagation inside the nested SDFG
                const_syms = {k: v for k, v in node.symbol_mapping.items() if not symbolic.issymbolic(v)}
                internal = self.apply_pass(node.sdfg, _, const_syms)
                if not internal:
                    continue
                for nid, removed in internal:
                    result.add((nid, removed))
                    # Remove symbol mapping if constant was completely propagated
                    # (only for symbols of this direct child, not of deeper nested SDFGs)
                    if nid == nested_id and removed in node.symbol_mapping:
                        del node.symbol_mapping[removed]

    # Return result
    if not result:
        return None
    return result