Example #1
    def _candidates(nsdfg: nodes.NestedSDFG) -> Set[str]:
        candidates = set(nsdfg.symbol_mapping.keys())
        if len(candidates) == 0:
            return set()

        for desc in nsdfg.sdfg.arrays.values():
            candidates -= set(map(str, desc.free_symbols))

        ignore = set()
        for nstate in cfg.stateorder_topological_sort(nsdfg.sdfg):
            state_syms = nstate.free_symbols

            # Try to be conservative with C++ tasklets
            for node in nstate.nodes():
                if (isinstance(node, nodes.Tasklet)
                        and node.language is dtypes.Language.CPP):
                    for candidate in candidates:
                        if re.findall(r'\b%s\b' % re.escape(candidate),
                                      node.code.as_string):
                            state_syms.add(candidate)

            # Any symbol used in this state is considered used
            candidates -= (state_syms - ignore)
            if len(candidates) == 0:
                return set()

            # Any symbol that is set in all outgoing edges is ignored from
            # this point
            local_ignore = None
            for e in nsdfg.sdfg.out_edges(nstate):
                # Look for symbols in condition
                candidates -= (set(
                    map(str, symbolic.symbols_in_ast(
                        e.data.condition.code[0]))) - ignore)

                for assign in e.data.assignments.values():
                    candidates -= (
                        symbolic.free_symbols_and_functions(assign) - ignore)

                if local_ignore is None:
                    local_ignore = set(e.data.assignments.keys())
                else:
                    local_ignore &= e.data.assignments.keys()
            if local_ignore is not None:
                ignore |= local_ignore

        return candidates
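
The conservative handling of C++ tasklets above cannot extract free symbols from the opaque C++ body, so it falls back to a whole-word regular-expression search for each candidate symbol. Below is a minimal, self-contained sketch of that idea; the symbol_used helper and the sample code string are illustrative and not part of the snippet above.

import re

def symbol_used(symbol: str, cpp_code: str) -> bool:
    # \b word boundaries make 'N' match only the identifier N,
    # not identifiers such as 'N2' or 'myN'
    return bool(re.findall(r'\b%s\b' % re.escape(symbol), cpp_code))

code = 'out[i] = in[i] * N + offset;'
assert symbol_used('N', code)
assert not symbol_used('stride', code)

Because the search is purely textual, it may report a symbol as used even when it only appears in a comment or string literal. That keeps the pruning conservative: a symbol is only removed from the mapping when it is provably unused.
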
Example #2
    def can_be_applied(self, graph, candidate, expr_index, sdfg, strict=False):
        # Is this even a loop
        if not DetectLoop.can_be_applied(graph, candidate, expr_index, sdfg,
                                         strict):
            return False

        guard = graph.node(candidate[DetectLoop._loop_guard])
        begin = graph.node(candidate[DetectLoop._loop_begin])

        # Guard state should not contain any dataflow
        if len(guard.nodes()) != 0:
            return False

        # If loop cannot be detected, fail
        found = find_for_loop(graph, guard, begin,
                              itervar=self.itervar)
        if not found:
            return False

        itervar, (start, end, step), (_, body_end) = found

        # We cannot handle symbols read from data containers unless they are
        # scalar
        for expr in (start, end, step):
            if symbolic.contains_sympy_functions(expr):
                return False

        # Find all loop-body states
        states = set([body_end])
        to_visit = [begin]
        while to_visit:
            state = to_visit.pop(0)
            if state is body_end:
                continue
            for _, dst, _ in graph.out_edges(state):
                if dst not in states:
                    to_visit.append(dst)
            states.add(state)

        write_set = set()
        for state in states:
            _, wset = state.read_and_write_sets()
            write_set |= wset

        # Get access nodes from other states to isolate local loop variables
        other_access_nodes = set()
        for state in sdfg.nodes():
            if state in states:
                continue
            other_access_nodes |= set(n.data for n in state.data_nodes()
                                      if sdfg.arrays[n.data].transient)
        # Add non-transient nodes from the loop states
        for state in states:
            other_access_nodes |= set(n.data for n in state.data_nodes()
                                      if not sdfg.arrays[n.data].transient)

        write_memlets = defaultdict(list)

        itersym = symbolic.pystr_to_symbolic(itervar)
        a = sp.Wild('a', exclude=[itersym])
        b = sp.Wild('b', exclude=[itersym])

        for state in states:
            for dn in state.data_nodes():
                if dn.data not in other_access_nodes:
                    continue
                # Take all writes that are not conflicted into consideration
                if dn.data in write_set:
                    for e in state.in_edges(dn):
                        if e.data.dynamic and e.data.wcr is None:
                            # If pointers are involved, give up
                            return False
                        # To be sure that the value is only written at unique
                        # indices per loop iteration, we want to match subscript
                        # expressions of the form "a*i + b", where a >= 1 and i is
                        # the iteration variable. The iteration variable must appear
                        # in the subscript expression.
                        if e.data.wcr is None:
                            dst_subset = e.data.get_dst_subset(e, state)
                            if not _check_range(dst_subset, a, itersym, b, step):
                                return False
                        # End of check

                        write_memlets[dn.data].append(e.data)

        # After looping over relevant writes, consider reads that may overlap
        for state in states:
            for dn in state.data_nodes():
                if dn.data not in other_access_nodes:
                    continue
                data = dn.data
                if data in write_memlets:
                    # Import as necessary
                    from dace.sdfg.propagation import propagate_subset

                    for e in state.out_edges(dn):
                        # If the same container is both read and written, only match
                        # if it is read and written at locations that will not create
                        # data races
                        if e.data.dynamic and e.data.src_subset.num_elements() != 1:
                            # If pointers are involved, give up
                            return False
                        src_subset = e.data.get_src_subset(e, state)
                        if not _check_range(src_subset, a, itersym, b, step):
                            return False

                        pread = propagate_subset([e.data], sdfg.arrays[data],
                                                 [itervar],
                                                 subsets.Range([(start, end, step)]))
                        for candidate in write_memlets[data]:
                            # Simple case: read and write are in the same subset
                            if e.data.subset == candidate.subset:
                                break
                            # Propagated read does not overlap with propagated write
                            pwrite = propagate_subset([candidate],
                                                      sdfg.arrays[data], [itervar],
                                                      subsets.Range([(start, end, step)]))
                            if subsets.intersects(pread.subset,
                                                  pwrite.subset) is False:
                                break
                            return False

        # Check that the iteration variable is not used on other edges or states
        # before it is reassigned
        prior_states = True
        for state in cfg.stateorder_topological_sort(sdfg):
            # Skip all states up to guard
            if prior_states:
                if state is begin:
                    prior_states = False
                continue
            # We do not need to check the loop-body states
            if state in states:
                continue
            if itervar in state.free_symbols:
                return False
            # Don't continue in this direction, as the variable has
            # now been reassigned
            # TODO: Handle case of subset of out_edges
            if all(itervar in e.data.assignments
                   for e in sdfg.out_edges(state)):
                break

        return True
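
The write- and read-subset checks above delegate to _check_range (not shown here), whose role, per the comments, is to verify that subscripts have the affine form "a*i + b" in the iteration variable. A rough, self-contained sketch of that kind of check using SymPy wildcards follows; the is_unique_per_iteration helper is hypothetical, and the real _check_range additionally accounts for the loop step and for every dimension of the subset.

import sympy as sp

i = sp.Symbol('i')             # loop iteration variable
a = sp.Wild('a', exclude=[i])  # multiplier, must not contain i
b = sp.Wild('b', exclude=[i])  # offset, must not contain i

def is_unique_per_iteration(index_expr) -> bool:
    # Match subscripts of the form a*i + b; a nonzero multiplier means
    # every loop iteration accesses a distinct index.
    m = sp.sympify(index_expr).match(a * i + b)
    return m is not None and m[a] != 0

assert is_unique_per_iteration(2 * i + 3)          # distinct index per iteration
assert not is_unique_per_iteration(sp.Integer(5))  # iteration variable unused
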
Example #3
    def apply_pass(
        self, sdfg: SDFG,
        pipeline_results: Dict[str,
                               Any]) -> Optional[Dict[SDFGState, Set[str]]]:
        """
        Removes unreachable dataflow throughout SDFG states.

        :param sdfg: The SDFG to modify.
        :param pipeline_results: If in the context of a ``Pipeline``, a dictionary that is populated with prior Pass
                                 results as ``{Pass subclass name: returned object from pass}``. If not run in a
                                 pipeline, an empty dictionary is expected.
        :return: A dictionary mapping states to removed data descriptor names, or None if nothing changed.
        """
        # Depends on the following analysis passes:
        #  * State reachability
        #  * Read/write access sets per state
        reachable: Dict[SDFGState,
                        Set[SDFGState]] = pipeline_results['StateReachability']
        access_sets: Dict[SDFGState,
                          Tuple[Set[str],
                                Set[str]]] = pipeline_results['AccessSets']
        result: Dict[SDFGState, Set[str]] = defaultdict(set)

        # Traverse SDFG backwards
        for state in reversed(list(cfg.stateorder_topological_sort(sdfg))):
            #############################################
            # Analysis
            #############################################

            # Compute states where memory will no longer be read
            writes = access_sets[state][1]
            descendants = reachable[state]
            descendant_reads = set().union(*(access_sets[succ][0]
                                             for succ in descendants))
            no_longer_used: Set[str] = set(data for data in writes
                                           if data not in descendant_reads)

            # Compute dead nodes
            dead_nodes: List[nodes.Node] = []

            # Propagate deadness backwards within a state
            for node in sdutil.dfs_topological_sort(state, reverse=True):
                if self._is_node_dead(node, sdfg, state, dead_nodes,
                                      no_longer_used):
                    dead_nodes.append(node)

            # Scope exit nodes are only dead if their corresponding entry nodes are
            live_nodes = set()
            for node in dead_nodes:
                if isinstance(node, nodes.ExitNode) and state.entry_node(
                        node) not in dead_nodes:
                    live_nodes.add(node)
            dead_nodes = dtypes.deduplicate(
                [n for n in dead_nodes if n not in live_nodes])

            if not dead_nodes:
                continue

            # Remove nodes while preserving scopes
            scopes_to_reconnect: Set[nodes.Node] = set()
            for node in state.nodes():
                # Look for scope exits that will be disconnected
                if isinstance(node, nodes.ExitNode) and node not in dead_nodes:
                    if any(n in dead_nodes for n in state.predecessors(node)):
                        scopes_to_reconnect.add(node)

            # Two types of scope disconnections may occur:
            # 1. Two scope exits will no longer be connected
            # 2. A predecessor of dead nodes is in a scope and not connected to its exit
            # Case (1) is taken care of by ``remove_memlet_path``
            # Case (2) is handled below
            # Reconnect scopes
            if scopes_to_reconnect:
                schildren = state.scope_children()
                for exit_node in scopes_to_reconnect:
                    entry_node = state.entry_node(exit_node)
                    for node in schildren[entry_node]:
                        if node is exit_node:
                            continue
                        if isinstance(node, nodes.EntryNode):
                            node = state.exit_node(node)
                        # If node will be disconnected from exit node, add an empty memlet
                        if all(succ in dead_nodes
                               for succ in state.successors(node)):
                            state.add_nedge(node, exit_node, Memlet())

            #############################################
            # Removal
            #############################################
            predecessor_nsdfgs: Dict[nodes.NestedSDFG,
                                     Set[str]] = defaultdict(set)
            for node in dead_nodes:
                # Remove memlet paths and connectors pertaining to dead nodes
                for e in state.in_edges(node):
                    mtree = state.memlet_tree(e)
                    for leaf in mtree.leaves():
                        # Keep track of predecessors of removed nodes for connector pruning
                        if isinstance(leaf.src, nodes.NestedSDFG):
                            predecessor_nsdfgs[leaf.src].add(leaf.src_conn)
                        state.remove_memlet_path(leaf)

                # Remove the node itself as necessary
                state.remove_node(node)

            result[state].update(dead_nodes)

            # Remove isolated access nodes after elimination
            access_nodes = set(state.data_nodes())
            for node in access_nodes:
                if state.degree(node) == 0:
                    state.remove_node(node)
                    result[state].add(node)

            # Prune now-dead connectors
            for node, dead_conns in predecessor_nsdfgs.items():
                for conn in dead_conns:
                    # If removed connector belonged to a nested SDFG, and no other input connector shares name,
                    # make nested data transient (dead dataflow elimination would remove internally as necessary)
                    if conn not in node.in_connectors:
                        node.sdfg.arrays[conn].transient = True

            # Update read sets for the predecessor states to reuse
            access_nodes -= result[state]
            access_node_names = set(n.data for n in access_nodes
                                    if state.out_degree(n) > 0)
            access_sets[state] = (access_node_names, access_sets[state][1])

        return result or None
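
The analysis step at the top of the loop ("Compute states where memory will no longer be read") collects, for each state, the written data that is never read by any reachable successor state. The following self-contained sketch reproduces just that set computation on hypothetical state names and access sets, without any DaCe types:

# Hypothetical inputs: per-state (reads, writes) and reachable successor states
access_sets = {
    's0': ({'A'}, {'B'}),
    's1': ({'B'}, {'C'}),
    's2': (set(), set()),
}
reachable = {'s0': {'s1', 's2'}, 's1': {'s2'}, 's2': set()}

for state in reversed(['s0', 's1', 's2']):  # traverse in reverse topological order
    writes = access_sets[state][1]
    descendant_reads = set().union(*(access_sets[succ][0]
                                     for succ in reachable[state]))
    no_longer_used = {d for d in writes if d not in descendant_reads}
    print(state, no_longer_used)

# Prints an empty set for s2, {'C'} for s1 (C is never read afterwards),
# and an empty set for s0 (B is still read in s1).

Only data in no_longer_used is eligible to have its writing nodes considered dead; everything else must stay because some downstream state still reads it.
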