예제 #1
0
def _specialize_transient_strides(sdfg: dace.SDFG, layout_map):
    """Apply the stride replacements of all transient arrays to ``sdfg``.

    The replacement mapping is obtained from ``replace_strides`` for every
    transient data descriptor of the SDFG together with ``layout_map``. It is
    then applied in three steps:

    1. the SDFG itself is rewritten via ``sdfg.replace_dict``;
    2. every nested SDFG node found directly in a state of ``sdfg`` gets its
       ``symbol_mapping`` entries overwritten for the replaced names;
    3. replaced names that are still registered in ``sdfg.symbols`` are
       removed from the SDFG.

    :param sdfg: The SDFG to modify in place.
    :param layout_map: Forwarded unchanged to ``replace_strides``.
    """
    transient_descs = [
        desc for desc in sdfg.arrays.values() if desc.transient
    ]
    stride_replacements = replace_strides(transient_descs, layout_map)
    sdfg.replace_dict(stride_replacements)

    # Keep the symbol mappings of nested SDFG nodes in sync with the replacement.
    for state in sdfg.nodes():
        for node in state.nodes():
            if not isinstance(node, dace.nodes.NestedSDFG):
                continue
            for name, replacement in stride_replacements.items():
                if name in node.symbol_mapping:
                    node.symbol_mapping[name] = replacement

    # The replaced names no longer need to be registered as SDFG symbols.
    for name in stride_replacements:
        if name in sdfg.symbols:
            sdfg.remove_symbol(name)
예제 #2
0
    def apply_pass(self,
                   sdfg: SDFG,
                   _,
                   initial_symbols: Optional[Dict[str, Any]] = None
                   ) -> Optional[Set[str]]:
        """
        Propagates constant symbol values through the SDFG and removes the symbols that were fully replaced.

        :param sdfg: The SDFG to modify (in place).
        :param _: Placeholder for the pipeline-results argument. It is not read by this pass and is only
                  forwarded unchanged to the recursive invocations on nested SDFGs.
        :param initial_symbols: If not None, a mapping of symbol names to values that are known on entry.
                                ``None`` is treated like an empty mapping.
        :return: A set of propagated constants, or None if that set is empty. If ``self.recursive`` is set,
                 the elements are ``(sdfg_id, symbol)`` tuples and include the results of nested SDFGs.
        """
        initial_symbols = initial_symbols or {}

        # Symbols that end up constant everywhere; stays empty if nothing can be propagated
        propagated: Dict[str, Any] = {}

        if initial_symbols or self.should_apply(sdfg):
            # Trace all constants and symbols through states
            constants_by_state: Dict[SDFGState, Dict[str, Any]] = (
                self.collect_constants(sdfg, initial_symbols))

            # Symbols used by symbol-dependent data descriptors; those that may take
            # multiple values over the SDFG are never propagated
            desc_symbols, multivalue_desc_symbols = self._find_desc_symbols(
                sdfg, constants_by_state)

            # Bookkeeping of replaced and ambiguous symbols
            replaced: Dict[str, Any] = {}
            unresolved: Set[str] = set()

            for state, state_constants in constants_by_state.items():
                # Split this state's symbols into propagatable and ambiguous ones
                known = {}
                for name, value in state_constants.items():
                    if value is _UnknownValue or name in multivalue_desc_symbols:
                        unresolved.add(name)
                    else:
                        known[name] = value

                # Remember what was replaced for the SDFG-wide cleanup below
                replaced.update(known)

                # Replace in the state contents and in its outgoing edges
                state.replace_dict(known)
                for out_edge in sdfg.out_edges(state):
                    out_edge.data.replace_dict(known, replace_keys=False)

            # Only symbols that were never ambiguous in any state count as propagated
            for name, value in replaced.items():
                if name not in unresolved:
                    propagated[name] = value

            # Remove from symbol repository
            for name in propagated:
                if name in sdfg.symbols:
                    sdfg.remove_symbol(name)

            # Remove single-valued symbols from data descriptors (e.g., symbolic array size)
            desc_replacements = {
                name: value
                for name, value in propagated.items() if name in desc_symbols
            }
            sdfg.replace_dict(desc_replacements,
                              replace_in_graph=False,
                              replace_keys=False)

            # Remove constant symbol assignments in interstate edges
            for edge in sdfg.edges():
                for name in propagated.keys() & edge.data.assignments.keys():
                    del edge.data.assignments[name]

        result = set(propagated)

        if self.recursive:
            # Results become (SDFG id, symbol) tuples so nested results can be told apart
            sid = sdfg.sdfg_id
            result = {(sid, name) for name in result}

            for state in sdfg.nodes():
                for node in state.nodes():
                    if not isinstance(node, nodes.NestedSDFG):
                        continue

                    nested_id = node.sdfg.sdfg_id
                    # Non-symbolic entries of the symbol mapping seed the nested propagation
                    const_syms = {
                        k: v
                        for k, v in node.symbol_mapping.items()
                        if not symbolic.issymbolic(v)
                    }
                    internal = self.apply_pass(node.sdfg, _, const_syms)
                    if not internal:
                        continue

                    for nid, removed in internal:
                        result.add((nid, removed))
                        # Remove symbol mapping if constant was completely propagated
                        if nid == nested_id and removed in node.symbol_mapping:
                            del node.symbol_mapping[removed]

        # An empty result is reported as None
        return result or None