def apply_pass(self, sdfg: SDFG, _, initial_symbols: Optional[Dict[str, Any]] = None) -> Optional[Set[str]]:
    """
    Propagates constants throughout the SDFG.

    Symbols that resolve to a single known value are substituted into state contents and
    outgoing inter-state edges. They are then removed from the SDFG's symbol repository and
    from inter-state edge assignments.

    :param sdfg: The SDFG to modify.
    :param _: Pipeline results (``pipeline_results``). Inside a ``Pipeline`` this is a
              dictionary ``{Pass subclass name: returned object from pass}``; outside one,
              an empty dictionary is expected. It is not read here, only forwarded to the
              recursive calls on nested SDFGs.
    :param initial_symbols: If not None, sets values of initial symbols.
    :return: A set of propagated constants, or None if nothing was changed. When
             ``self.recursive`` is set, the set holds ``(sdfg_id, symbol_name)`` tuples
             that cover this SDFG and every nested SDFG.
    """
    initial_symbols = initial_symbols or {}

    propagated: Dict[str, Any] = {}

    # Only analyze when there are seed values or the pass reports it can apply
    if initial_symbols or self.should_apply(sdfg):
        # Per-state symbol values, as traced by collect_constants
        per_state_constants: Dict[SDFGState, Dict[str, Any]] = self.collect_constants(
            sdfg, initial_symbols)

        replaced: Dict[str, Any] = {}  # every symbol substituted in at least one state
        ambiguous: Set[str] = set()  # symbols unknown (or multi-valued) somewhere

        # Symbols used by data descriptors; multi-valued ones must not be propagated
        desc_symbols, multivalue_desc_symbols = self._find_desc_symbols(
            sdfg, per_state_constants)

        for state, constants in per_state_constants.items():
            # Split this state's symbols into substitutable values and ambiguous names
            known: Dict[str, Any] = {}
            for name, value in constants.items():
                if value is _UnknownValue or name in multivalue_desc_symbols:
                    ambiguous.add(name)
                else:
                    known[name] = value

            replaced.update(known)

            # Substitute into the state body and then into its outgoing edges
            state.replace_dict(known)
            for edge in sdfg.out_edges(state):
                edge.data.replace_dict(known, replace_keys=False)

        # Only symbols that were never ambiguous in any state are fully propagated
        propagated = {name: value for name, value in replaced.items() if name not in ambiguous}

        for name in propagated:
            if name in sdfg.symbols:
                sdfg.remove_symbol(name)

        # Single-valued descriptor symbols (e.g., symbolic array sizes) get their values inlined
        sdfg.replace_dict({name: value
                           for name, value in propagated.items() if name in desc_symbols},
                          replace_in_graph=False,
                          replace_keys=False)

        # Inter-state assignments to propagated constants are now redundant
        for edge in sdfg.edges():
            stale = propagated.keys() & edge.data.assignments.keys()
            for name in stale:
                del edge.data.assignments[name]

    names = set(propagated)

    if not self.recursive:
        return names or None

    # In recursive mode, tag each name with the SDFG it was propagated in
    sid = sdfg.sdfg_id
    tagged = {(sid, name) for name in names}

    for state in sdfg.nodes():
        for node in state.nodes():
            if not isinstance(node, nodes.NestedSDFG):
                continue

            nested_id = node.sdfg.sdfg_id
            # Non-symbolic symbol-mapping entries seed the nested SDFG's propagation
            const_syms = {
                key: val
                for key, val in node.symbol_mapping.items() if not symbolic.issymbolic(val)
            }
            internal = self.apply_pass(node.sdfg, _, const_syms)
            if not internal:
                continue

            for nid, removed in internal:
                tagged.add((nid, removed))
                # A constant fully propagated in the direct child no longer needs a mapping entry
                if nid == nested_id and removed in node.symbol_mapping:
                    del node.symbol_mapping[removed]

    return tagged or None
def find_dead_states(
        self,
        sdfg: SDFG,
        set_unconditional_edges: bool = True) -> Set[SDFGState]:
    """
    Finds "dead" (unreachable) states in an SDFG.

    A state is deemed unreachable if it is:

        * Unreachable from the starting state
        * Reachable only through conditions that always evaluate to False
        * Bypassed because another inter-state edge leaving its predecessor is
          unconditional (always True)

    :param sdfg: The SDFG to traverse.
    :param set_unconditional_edges: If True, the condition of every edge found to be
                                    definitely taken is rewritten to the literal ``1``.
    :return: A set of unreachable states.
    """
    reachable: Set[SDFGState] = set()

    # Breadth-first walk from the start state. Definitely-false edges are pruned, and a
    # definitely-taken edge shadows all of its siblings. Anything never reached is dead.
    frontier = collections.deque([sdfg.start_state])

    while frontier:
        current = frontier.popleft()
        if current in reachable:
            continue
        reachable.add(current)

        outgoing = list(sdfg.out_edges(current))

        # Look for the (at most one) edge that is always taken
        taken = None
        for edge in outgoing:
            if not self.is_definitely_taken(edge.data):
                continue
            if taken is not None:
                # Two always-taken edges out of one state make the SDFG invalid
                raise InvalidSDFGInterstateEdgeError(
                    'Multiple unconditional edges leave the same state', sdfg,
                    sdfg.edge_id(edge))
            taken = edge
            if set_unconditional_edges and not edge.data.is_unconditional():
                # Make the always-true condition explicit on the edge
                edge.data.condition = CodeBlock('1')
            if edge.dst not in reachable:
                frontier.append(edge.dst)

        if taken is not None:
            # Sibling edges of an unconditional edge can never be followed
            continue

        # No unconditional edge: follow every edge that is not definitely false
        frontier.extend(
            edge.dst for edge in outgoing
            if not self.is_definitely_not_taken(edge.data) and edge.dst not in reachable)

    # Dead states are exactly those the walk never reached
    return set(sdfg.nodes()) - reachable