Example #1
0
    def apply_pass(self,
                   sdfg: SDFG,
                   _,
                   initial_symbols: Optional[Dict[str, Any]] = None
                   ) -> Optional[Set[str]]:
        """
        Propagates constants throughout the SDFG.

        Per-state constant values (as computed by ``collect_constants``) are substituted into
        each state's contents and into that state's outgoing inter-state edges. A symbol whose
        value is never ``_UnknownValue`` in any state and that is not multi-valued across data
        descriptors is considered fully propagated. It is then removed from ``sdfg.symbols``,
        from data descriptors that reference it, and from inter-state edge assignments.

        :param sdfg: The SDFG to modify.
        :param _: Pipeline results. If in the context of a ``Pipeline``, a dictionary that is
                  populated with prior Pass results as
                  ``{Pass subclass name: returned object from pass}``. If not run in a pipeline,
                  an empty dictionary is expected. It is not read here, only forwarded to the
                  recursive invocations on nested SDFGs.
        :param initial_symbols: If not None, sets values of initial symbols.
        :return: None if nothing was propagated. Otherwise, when ``self.recursive`` is False,
                 the set of fully-propagated symbol names. When ``self.recursive`` is True, a
                 set of ``(sdfg_id, symbol name)`` tuples for this SDFG and all nested SDFGs.
        """
        initial_symbols = initial_symbols or {}

        # Early exit if no constants can be propagated
        if not initial_symbols and not self.should_apply(sdfg):
            result = {}
        else:
            # Trace all constants and symbols through states
            per_state_constants: Dict[SDFGState, Dict[str, Any]] = self.collect_constants(
                sdfg, initial_symbols)

            # Keep track of replaced and ambiguous symbols
            symbols_replaced: Dict[str, Any] = {}
            remaining_unknowns: Set[str] = set()

            # Collect symbols from symbol-dependent data descriptors.
            # If there can be multiple values over the SDFG, the symbols are not propagated.
            desc_symbols, multivalue_desc_symbols = self._find_desc_symbols(
                sdfg, per_state_constants)

            # Replace constants per state
            for state, mapping in per_state_constants.items():
                # A symbol is ambiguous if it is unknown in this state or multi-valued in
                # some data descriptor; such symbols must not be removed from the SDFG later.
                remaining_unknowns.update(
                    k for k, v in mapping.items()
                    if v is _UnknownValue or k in multivalue_desc_symbols)
                known = {
                    k: v
                    for k, v in mapping.items()
                    if v is not _UnknownValue and k not in multivalue_desc_symbols
                }

                # Update replaced symbols for later replacements
                symbols_replaced.update(known)

                # Replace in state contents and in outgoing edges (values only, not keys)
                state.replace_dict(known)
                for e in sdfg.out_edges(state):
                    e.data.replace_dict(known, replace_keys=False)

            # If symbols are never unknown any longer, remove from SDFG
            result = {
                k: v
                for k, v in symbols_replaced.items()
                if k not in remaining_unknowns
            }

            # Remove from symbol repository
            for sym in result:
                if sym in sdfg.symbols:
                    sdfg.remove_symbol(sym)

            # Remove single-valued symbols from data descriptors (e.g., symbolic array size)
            sdfg.replace_dict(
                {k: v for k, v in result.items() if k in desc_symbols},
                replace_in_graph=False,
                replace_keys=False)

            # Remove constant symbol assignments in interstate edges. The intersection is a
            # new set, so deleting from ``assignments`` while iterating over it is safe.
            for edge in sdfg.edges():
                for sym in result.keys() & edge.data.assignments.keys():
                    del edge.data.assignments[sym]

        result = set(result)

        if self.recursive:
            # Change result to set of (sdfg_id, symbol) tuples
            sid = sdfg.sdfg_id
            result = {(sid, sym) for sym in result}

            for state in sdfg.nodes():
                for node in state.nodes():
                    if not isinstance(node, nodes.NestedSDFG):
                        continue
                    nested_id = node.sdfg.sdfg_id
                    # Only non-symbolic mapped values are passed down as initial constants
                    const_syms = {
                        k: v
                        for k, v in node.symbol_mapping.items()
                        if not symbolic.issymbolic(v)
                    }
                    internal = self.apply_pass(node.sdfg, _, const_syms)
                    if not internal:
                        continue
                    result.update(internal)
                    # Remove symbol mapping if constant was completely propagated
                    for nid, removed in internal:
                        if nid == nested_id and removed in node.symbol_mapping:
                            del node.symbol_mapping[removed]

        return result or None
Example #2
0
    def find_dead_states(
            self,
            sdfg: SDFG,
            set_unconditional_edges: bool = True) -> Set[SDFGState]:
        '''
        Finds "dead" (unreachable) states in an SDFG. A state is deemed unreachable if it is:
            * Unreachable from the starting state
            * Conditions leading to it will always evaluate to False
            * There is another unconditional (always True) inter-state edge that leads to another state

        The search is a breadth-first traversal from ``sdfg.start_state``. Edges for which
        ``is_definitely_not_taken`` holds are not followed. If a state has an outgoing edge for
        which ``is_definitely_taken`` holds, only that edge is followed.

        :param sdfg: The SDFG to traverse.
        :param set_unconditional_edges: If True, the condition of an edge evaluated as
                                        unconditional (and not already unconditional) is
                                        replaced with ``CodeBlock('1')``.
        :return: A set of unreachable states.
        :raises InvalidSDFGInterstateEdgeError: If more than one outgoing edge of a reachable
                                                state is definitely taken. The outgoing edges of
                                                that state are left unmodified.
        '''
        visited: Set[SDFGState] = set()

        # Run a modified BFS where definitely False edges are not traversed, or if there is an
        # unconditional edge the rest are not. The inverse of the visited states is the dead set.
        queue = collections.deque([sdfg.start_state])
        while queue:
            node = queue.popleft()
            if node in visited:
                continue
            visited.add(node)

            # First, find the (at most one) definitely-taken outgoing edge. Validation happens
            # before any edge is annotated so an invalid SDFG is not left partially modified.
            unconditional = None
            for e in sdfg.out_edges(node):
                if self.is_definitely_taken(e.data):
                    # If more than one unconditional outgoing edge exist, fail with Invalid SDFG
                    if unconditional is not None:
                        raise InvalidSDFGInterstateEdgeError(
                            'Multiple unconditional edges leave the same state',
                            sdfg, sdfg.edge_id(e))
                    unconditional = e

            if unconditional is not None:
                # An always-taken edge makes all other outgoing edges irrelevant
                if set_unconditional_edges and not unconditional.data.is_unconditional():
                    # Annotate edge as unconditional
                    unconditional.data.condition = CodeBlock('1')
                if unconditional.dst not in visited:
                    queue.append(unconditional.dst)
                continue

            # No unconditional edge: follow every edge that is not definitely False
            for e in sdfg.out_edges(node):
                if self.is_definitely_not_taken(e.data):
                    continue
                if e.dst not in visited:
                    queue.append(e.dst)

        # Dead states are states that are not live (i.e., visited)
        return set(sdfg.nodes()) - visited