Ejemplo n.º 1
0
def is_parallel(state: SDFGState, node: Optional[nd.Node] = None) -> bool:
    """
    Returns True if a node or state are contained within a parallel
    section.

    The enclosing scopes of ``node`` are walked outward via
    ``state.scope_dict()``. If any enclosing scope has a schedule other than
    ``dtypes.ScheduleType.Sequential``, the result is True. If the state's
    SDFG is nested (``state.parent.parent`` is not None), the search continues
    from the matching ``NestedSDFG`` node in that outer state.

    :param state: The state to test.
    :param node: An optional node in the state to test. If None, only checks
                 state.
    :return: True if the state or node are located within a map scope that
             is scheduled to run in parallel, False otherwise.
    """
    if node is not None:
        sdict = state.scope_dict()
        scope = sdict[node]
        # The scope dictionary yields None once the outermost scope is
        # passed; stop there instead of dereferencing ``None.schedule``.
        while scope is not None:
            if scope.schedule != dtypes.ScheduleType.Sequential:
                return True
            scope = sdict[scope]
    if state.parent.parent is not None:
        # Find nested SDFG node and continue recursion
        nsdfg_node = next(
            n for n in state.parent.parent
            if isinstance(n, nd.NestedSDFG) and n.sdfg == state.parent)
        return is_parallel(state.parent.parent, nsdfg_node)

    return False
Ejemplo n.º 2
0
    def apply(self, graph: SDFGState, sdfg: SDFG):
        """
        Inserts a local data container between ``self.node_a`` and
        ``self.node_b``.

        The edge between the two nodes that carries ``self.array`` is removed
        and replaced by the path ``node_a -> data_node -> node_b``. If
        ``self.array`` is None or empty, the first edge with a non-None data
        name and no ``wcr`` is used. If the requested array is not on any
        edge, the first edge between the nodes is used instead and a warning
        is emitted. The memlet tree of the newly inserted edge
        (``data_node -> node_b`` when ``node_a``'s scope contains ``node_b``,
        ``node_a -> data_node`` otherwise) is offset and renamed to refer to
        the new data container.

        :param graph: The state containing ``node_a`` and ``node_b``.
        :param sdfg: The SDFG owning ``graph``; a transient is added to it
                     when ``self.create_array`` is set.
        :return: The newly created access node.
        :raises NameError: If no suitable edge/array can be found between
                           ``node_a`` and ``node_b``.
        """
        node_a = self.node_a
        node_b = self.node_b
        prefix = self.prefix

        # Determine direction of new memlet
        scope_dict = graph.scope_dict()
        propagate_forward = sd.scope_contains_scope(scope_dict, node_a, node_b)

        array = self.array
        if array is None or len(array) == 0:
            # Auto-select the first edge that carries data and has no wcr.
            # ``next`` gets a default so that a missing candidate raises
            # NameError rather than leaking StopIteration, which can silently
            # terminate an enclosing iterator (e.g. ``map``) or become a
            # RuntimeError inside a generator (PEP 479).
            array = next((e.data.data
                          for e in graph.edges_between(node_a, node_b)
                          if e.data.data is not None and e.data.wcr is None),
                         None)
            if array is None:
                raise NameError('No array found between %s and %s!' %
                                (node_a, node_b))

        # Locate the edge carrying the requested array.
        original_edge = None
        invariant_memlet = None
        for edge in graph.edges_between(node_a, node_b):
            if array == edge.data.data:
                original_edge = edge
                invariant_memlet = edge.data
                break
        if invariant_memlet is None:
            # Requested array not present: fall back to the first edge.
            for edge in graph.edges_between(node_a, node_b):
                original_edge = edge
                invariant_memlet = edge.data
                warnings.warn('Array %s not found! Using array %s instead.' %
                              (array, invariant_memlet.data))
                array = invariant_memlet.data
                break
        if invariant_memlet is None:
            raise NameError('Array %s not found!' % array)
        if self.create_array:
            # Add transient array sized to the (over-approximated) bounding
            # box of the original memlet; the name may be uniquified.
            new_data, _ = sdfg.add_transient(
                name=prefix + invariant_memlet.data,
                shape=[
                    symbolic.overapproximate(r).simplify()
                    for r in invariant_memlet.bounding_box_size()
                ],
                dtype=sdfg.arrays[invariant_memlet.data].dtype,
                find_new_name=True)

        else:
            new_data = prefix + invariant_memlet.data
        data_node = nodes.AccessNode(new_data)
        # Store as fields so that other transformations can use them
        self._local_name = new_data
        self._data_node = data_node

        # Independent copies, because the memlet tree loop below mutates
        # subsets in place.
        to_data_mm = copy.deepcopy(invariant_memlet)
        from_data_mm = copy.deepcopy(invariant_memlet)
        # First element of each subset range (its start), used to rebase the
        # memlets that will refer to the local container.
        offset = subsets.Indices([r[0] for r in invariant_memlet.subset])

        # Reconnect, assuming one edge to the access node
        graph.remove_edge(original_edge)
        if propagate_forward:
            graph.add_edge(node_a, original_edge.src_conn, data_node, None,
                           to_data_mm)
            new_edge = graph.add_edge(data_node, None, node_b,
                                      original_edge.dst_conn, from_data_mm)
        else:
            new_edge = graph.add_edge(node_a, original_edge.src_conn,
                                      data_node, None, to_data_mm)
            graph.add_edge(data_node, None, node_b, original_edge.dst_conn,
                           from_data_mm)

        # Offset all edges in the memlet tree (including the new edge)
        # NOTE(review): ``offset(offset, True)`` presumably subtracts the
        # offset so local indices start at zero -- confirm against subsets API.
        for edge in graph.memlet_tree(new_edge):
            edge.data.subset.offset(offset, True)
            edge.data.data = new_data

        return data_node