Ejemplo n.º 1
0
    def should_apply(self, sdfg: SDFG) -> bool:
        """
        Cheap O(m) pre-check over the interstate edges of ``sdfg`` that tells
        whether constant propagation may have any work to do, without
        traversing the states themselves.

        :param sdfg: The SDFG to inspect.
        :return: True as soon as some interstate edge assigns a value for which
                 ``symbolic.issymbolic`` is False (a candidate constant);
                 False otherwise, meaning the pass can exit early.
        """
        # Edges with no assignments yield no values from the inner loop, so
        # they are skipped implicitly. ``any`` stops at the first candidate
        # constant, visiting edges in ``sdfg.edges()`` order.
        # NOTE(review): if assignment values are stored as plain strings,
        # ``issymbolic`` may report False for all of them, which would make
        # this check succeed whenever any assignment exists -- confirm.
        return any(
            not symbolic.issymbolic(assigned_value)
            for isedge in sdfg.edges()
            for assigned_value in isedge.data.assignments.values())
Ejemplo n.º 2
0
def split_condition_interstate_edges(sdfg: dace.SDFG):
    """
    Split every conditional interstate edge that also carries assignments.

    Each such edge ``src -> dst`` is removed and replaced by two edges through
    a newly added intermediate state: ``src -> interim`` carries only the
    original condition, and ``interim -> dst`` carries only the original
    assignments. The SDFG is modified in place.

    :param sdfg: The SDFG to transform.
    """
    # Collect the edges first, because the graph is mutated below. Use a list
    # rather than a set: edge objects hash by identity, so set iteration order
    # (and hence the order and names of the states created by ``add_state``)
    # could vary between runs. A list follows ``sdfg.edges()`` order.
    edges_to_split = [
        isedge for isedge in sdfg.edges()
        if not isedge.data.is_unconditional() and isedge.data.assignments
    ]

    for ise in edges_to_split:
        sdfg.remove_edge(ise)
        interim = sdfg.add_state()
        # First hop: evaluate the condition only.
        sdfg.add_edge(ise.src, interim,
                      dace.InterstateEdge(ise.data.condition))
        # Second hop: perform the assignments unconditionally.
        sdfg.add_edge(interim, ise.dst,
                      dace.InterstateEdge(assignments=ise.data.assignments))
Ejemplo n.º 3
0
    def apply_pass(self, sdfg: SDFG,
                   _) -> Dict[SDFGState, Tuple[Set[str], Set[str]]]:
        """
        Collects the names of the data containers read and written in each
        state of the SDFG.

        Within a state, an access node with incoming edges counts as a write
        to its data container, and an access node with outgoing edges counts
        as a read (a node can be both). Data container names that appear among
        an interstate edge's free symbols are added to the read sets of both
        the source and the destination state of that edge.

        :param sdfg: The SDFG to analyze.
        :param _: Unused by this pass.
        :return: A dictionary mapping each state to a tuple
                 ``(read set, write set)`` of data container names.
        """
        result: Dict[SDFGState, Tuple[Set[str], Set[str]]] = {}
        for state in sdfg.nodes():
            readset: Set[str] = set()
            writeset: Set[str] = set()
            for anode in state.data_nodes():
                # Something flows into the node -> the container is written
                if state.in_degree(anode) > 0:
                    writeset.add(anode.data)
                # Something flows out of the node -> the container is read
                if state.out_degree(anode) > 0:
                    readset.add(anode.data)

            result[state] = (readset, writeset)

        # Edges that read from arrays add to both ends' read sets
        anames = sdfg.arrays.keys()
        for e in sdfg.edges():
            fsyms = e.data.free_symbols & anames
            if fsyms:
                result[e.src][0].update(fsyms)
                result[e.dst][0].update(fsyms)
        return result
Ejemplo n.º 4
0
    def apply_pass(self,
                   sdfg: SDFG,
                   _,
                   initial_symbols: Optional[Dict[str, Any]] = None
                   ) -> Optional[Set[str]]:
        """
        Propagates constants throughout the SDFG.

        Per state, every symbol with a known value (not ``_UnknownValue``, and
        not a multi-valued symbol of a data descriptor) is replaced by that
        value in the state's contents and in its outgoing interstate edges.
        A symbol that was replaced in at least one state and is never unknown
        (or multi-valued) in any state counts as fully propagated. Fully
        propagated symbols are removed in three ways:

        * from the SDFG's symbol repository;
        * from interstate edge assignments;
        * when used in data descriptors (e.g., a symbolic array size), by
          substituting their values there.

        :param sdfg: The SDFG to modify.
        :param _: If in the context of a ``Pipeline``, a dictionary that is
                  populated with prior Pass results as
                  ``{Pass subclass name: returned object from pass}``. If not
                  run in a pipeline, an empty dictionary is expected. It is
                  not read here, only forwarded to nested invocations.
        :param initial_symbols: If not None, sets values of initial symbols.
        :return: The names of the fully propagated symbols, or None if there
                 are none. Individual states may still have been modified when
                 None is returned. If ``self.recursive`` is set, the set
                 instead contains ``(sdfg_id, symbol_name)`` tuples covering
                 this SDFG and all nested SDFGs.
        """
        initial_symbols = initial_symbols or {}

        # Early exit if no constants can be propagated
        if not initial_symbols and not self.should_apply(sdfg):
            result = {}
        else:
            # Trace all constants and symbols through states
            per_state_constants: Dict[SDFGState,
                                      Dict[str, Any]] = self.collect_constants(
                                          sdfg, initial_symbols)

            # Keep track of replaced and ambiguous symbols
            symbols_replaced: Dict[str, Any] = {}
            remaining_unknowns: Set[str] = set()

            # Collect symbols from symbol-dependent data descriptors
            # If there can be multiple values over the SDFG, the symbols are not propagated
            desc_symbols, multivalue_desc_symbols = self._find_desc_symbols(
                sdfg, per_state_constants)

            # Replace constants per state
            for state, mapping in per_state_constants.items():
                # Symbols that cannot be replaced in this state block full
                # propagation (and removal) of that symbol SDFG-wide.
                remaining_unknowns.update({
                    k
                    for k, v in mapping.items()
                    if v is _UnknownValue or k in multivalue_desc_symbols
                })
                mapping = {
                    k: v
                    for k, v in mapping.items() if v is not _UnknownValue
                    and k not in multivalue_desc_symbols
                }

                # Update replaced symbols for later replacements
                symbols_replaced.update(mapping)

                # Replace in state contents
                state.replace_dict(mapping)
                # Replace in outgoing edges as well
                for e in sdfg.out_edges(state):
                    e.data.replace_dict(mapping, replace_keys=False)

            # If symbols are never unknown any longer, remove from SDFG
            result = {
                k: v
                for k, v in symbols_replaced.items()
                if k not in remaining_unknowns
            }
            # Remove from symbol repository
            for sym in result:
                if sym in sdfg.symbols:
                    sdfg.remove_symbol(sym)

            # Remove single-valued symbols from data descriptors (e.g., symbolic array size)
            sdfg.replace_dict(
                {k: v
                 for k, v in result.items() if k in desc_symbols},
                replace_in_graph=False,
                replace_keys=False)

            # Remove constant symbol assignments in interstate edges
            for edge in sdfg.edges():
                intersection = result.keys() & edge.data.assignments.keys()
                for sym in intersection:
                    del edge.data.assignments[sym]

        result = set(result.keys())

        if self.recursive:
            # Change result to set of (sdfg_id, symbol) tuples
            sid = sdfg.sdfg_id
            result = set((sid, sym) for sym in result)

            for state in sdfg.nodes():
                for node in state.nodes():
                    if isinstance(node, nodes.NestedSDFG):
                        nested_id = node.sdfg.sdfg_id
                        # Only non-symbolic mapped values seed the nested pass
                        const_syms = {
                            k: v
                            for k, v in node.symbol_mapping.items()
                            if not symbolic.issymbolic(v)
                        }
                        # The recursive call also returns (sdfg_id, symbol)
                        # tuples, including those of deeper nested SDFGs.
                        internal = self.apply_pass(node.sdfg, _, const_syms)
                        if internal:
                            for nid, removed in internal:
                                result.add((nid, removed))
                                # Remove symbol mapping if constant was completely propagated
                                if nid == nested_id and removed in node.symbol_mapping:
                                    del node.symbol_mapping[removed]

        # Return result
        if not result:
            return None
        return result