예제 #1
0
class TrivialTaskletElimination(transformation.SingleStateTransformation):
    """ Implements the Trivial-Tasklet Elimination pattern.

        A tasklet is considered trivial when it has exactly one input and one
        output connector and its entire code is ``<out> = <in>``. Such a
        tasklet is removed and the two surrounding access nodes are connected
        directly. Streams and outputs with write-conflict resolution (WCR)
        are never matched.
    """

    read = transformation.PatternNode(nodes.AccessNode)
    tasklet = transformation.PatternNode(nodes.Tasklet)
    write = transformation.PatternNode(nodes.AccessNode)

    @classmethod
    def expressions(cls):
        """ Matches the path ``read -> tasklet -> write``. """
        return [sdutil.node_path_graph(cls.read, cls.tasklet, cls.write)]

    def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
        """ Returns True iff the matched tasklet is a plain copy that can be
            replaced by a direct edge between the access nodes. """
        copy_tasklet = self.tasklet
        dst_node = self.write

        # Neither side of the copy may be a stream
        for access in (self.read, dst_node):
            if isinstance(sdfg.arrays[access.data], data.Stream):
                return False

        # The tasklet must be connected by exactly one edge on each side
        if (len(graph.in_edges(copy_tasklet)) != 1
                or len(graph.out_edges(copy_tasklet)) != 1):
            return False

        # A conflict-resolving write is not a plain copy
        if graph.edges_between(copy_tasklet, dst_node)[0].data.wcr:
            return False

        # Exactly one connector on each side
        if (len(copy_tasklet.in_connectors) != 1
                or len(copy_tasklet.out_connectors) != 1):
            return False

        # The code must be exactly "<out> = <in>"
        (in_conn, ) = copy_tasklet.in_connectors.keys()
        (out_conn, ) = copy_tasklet.out_connectors.keys()
        return copy_tasklet.code.as_string == f'{out_conn} = {in_conn}'

    def apply(self, graph, sdfg):
        """ Removes the tasklet and links ``read`` to ``write`` directly. """
        src_node = self.read
        copy_tasklet = self.tasklet
        dst_node = self.write

        incoming = graph.edges_between(src_node, copy_tasklet)[0]
        outgoing = graph.edges_between(copy_tasklet, dst_node)[0]
        graph.remove_edge(incoming)
        graph.remove_edge(outgoing)

        # Reuse the output memlet for the new edge and record the subset that
        # was read from the source as its other subset
        copy_memlet = outgoing.data
        copy_memlet.other_subset = incoming.data.subset
        graph.add_nedge(src_node, dst_node, copy_memlet)
        graph.remove_node(copy_tasklet)
예제 #2
0
class FalseConditionElimination(transformation.MultiStateTransformation):
    """
    Removes a state transition edge whose condition always evaluates to false.
    """

    state_a = transformation.PatternNode(sdfg.SDFGState)
    state_b = transformation.PatternNode(sdfg.SDFGState)

    @classmethod
    def expressions(cls):
        """ Matches two states connected by an edge. """
        return [sdutil.node_path_graph(cls.state_a, cls.state_b)]

    def can_be_applied(self,
                       graph: SDFG,
                       expr_index,
                       sdfg: SDFG,
                       permissive=False):
        """ Returns True iff the edge ``state_a -> state_b`` carries no
            assignments, has a condition that evaluates to false, and is not
            the only incoming edge of ``state_b``. """
        src_state: SDFGState = self.state_a
        dst_state: SDFGState = self.state_b

        # With a single incoming edge the destination becomes unreachable,
        # which is the case DeadStateElimination covers
        if len(graph.in_edges(dst_state)) <= 1:
            return False

        # Directed graph has only one edge between two nodes
        transition = graph.edges_between(src_state, dst_state)[0].data

        if transition.assignments or transition.is_unconditional():
            return False

        # Equality (not identity) is intentional: the evaluated condition is
        # compared by value against the Python constant
        return bool(transition.condition_sympy() == False)

    def apply(self, _, sdfg: SDFG):
        """ Removes the always-false edge from the SDFG. """
        dead_edge = sdfg.edges_between(self.state_a, self.state_b)[0]
        sdfg.remove_edge(dead_edge)
예제 #3
0
class StartStateElimination(transformation.MultiStateTransformation):
    """
    Start-state elimination removes an empty start state that has a single,
    unconditional outgoing edge. It only applies to nested SDFGs, because the
    assignments on that edge are moved to the symbol mapping of the nested
    SDFG node.
    """

    start_state = transformation.PatternNode(SDFGState)

    @classmethod
    def expressions(cls):
        """ Matches any single state. """
        return [sdutil.node_path_graph(cls.start_state)]

    def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
        """ Returns True iff the matched state is an empty start state of a
            nested SDFG with one unconditional outgoing edge whose
            assignments do not refer to data descriptors. """
        candidate = self.start_state

        # Top-level SDFGs have no nested SDFG node to move assignments to
        if not graph.parent:
            return False

        # The state must not contain any nodes
        if candidate.number_of_nodes() > 0:
            return False

        # A start state has no predecessors
        if len(graph.in_edges(candidate)) != 0:
            return False

        # Exactly one successor, reached unconditionally
        successors = graph.out_edges(candidate)
        if len(successors) != 1:
            return False
        transition = successors[0].data
        if not transition.is_unconditional():
            return False

        # Assignments that make descriptors into symbols cannot be eliminated
        descriptor_names = graph.arrays.keys()
        return not any(
            descriptor_names & symbolic.free_symbols_and_functions(value)
            for value in transition.assignments.values())

    def apply(self, _, sdfg):
        """ Moves the edge assignments to the parent nested SDFG node's symbol
            mapping and removes the state. """
        candidate = self.start_state
        nsdfg_node = sdfg.parent_nsdfg_node
        transition = sdfg.out_edges(candidate)[0].data
        for symbol_name, value in transition.assignments.items():
            nsdfg_node.symbol_mapping[symbol_name] = value
        sdfg.remove_node(candidate)
예제 #4
0
class RedundantArrayCopying3(pm.SingleStateTransformation):
    """ Implements the redundant array removal transformation. When a map
        entry feeds several access nodes that refer to the same data as
        ``out_array``, the duplicates are removed and their outgoing edges
        are re-attached to ``out_array``.
    """

    map_entry = pm.PatternNode(nodes.MapEntry)
    out_array = pm.PatternNode(nodes.AccessNode)

    @classmethod
    def expressions(cls):
        """ Matches the path ``map_entry -> out_array``. """
        return [sdutil.node_path_graph(cls.map_entry, cls.out_array)]

    def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
        """ Returns True iff the map entry has at least one other successor
            access node that refers to the same data as ``out_array``. """
        kept = self.out_array
        return any(
            isinstance(edge.dst, nodes.AccessNode) and edge.dst != kept
            and edge.dst.data == kept.data
            for edge in graph.out_edges(self.map_entry))

    def apply(self, graph, sdfg):
        """ Merges every duplicate access node into ``out_array``. """
        kept = self.out_array

        for entry_edge in graph.out_edges(self.map_entry):
            duplicate = entry_edge.dst
            is_duplicate = (isinstance(duplicate, nodes.AccessNode)
                            and duplicate != kept
                            and duplicate.data == kept.data)
            if not is_duplicate:
                continue

            # Let the consumers of the duplicate read from the kept node
            for consumer_edge in graph.out_edges(duplicate):
                graph.add_edge(kept, None, consumer_edge.dst,
                               consumer_edge.dst_conn, consumer_edge.data)
                graph.remove_edge(consumer_edge)

            graph.remove_edge(entry_edge)
            graph.remove_node(duplicate)
예제 #5
0
class TrueConditionElimination(transformation.MultiStateTransformation,
                               transformation.SimplifyPass):
    """
    Replaces a state transition condition that is always true with the
    constant condition ``1``.
    """

    state_a = transformation.PatternNode(sdfg.SDFGState)
    state_b = transformation.PatternNode(sdfg.SDFGState)

    @classmethod
    def expressions(cls):
        """ Matches two states connected by an edge. """
        return [sdutil.node_path_graph(cls.state_a, cls.state_b)]

    def can_be_applied(self,
                       graph: SDFG,
                       expr_index,
                       sdfg: SDFG,
                       permissive=False):
        """ Returns True iff the edge ``state_a -> state_b`` has a condition
            that is not already unconditional and evaluates to true. """
        src_state: SDFGState = self.state_a
        dst_state: SDFGState = self.state_b
        # Directed graph has only one edge between two nodes
        transition = graph.edges_between(src_state, dst_state)[0].data

        # Nothing to simplify if the edge is already unconditional
        if transition.is_unconditional():
            return False

        # Equality (not identity) is intentional: the evaluated condition is
        # compared by value against the Python constant
        return bool(transition.condition_sympy() == True)

    def apply(self, _, sdfg: SDFG):
        """ Overwrites the condition of the matched edge with ``1``. """
        always_taken = sdfg.edges_between(self.state_a, self.state_b)[0]
        always_taken.data.condition = CodeBlock("1")
예제 #6
0
class Reduction1Operation(pm.Transformation):
    """ Detects reduction1 operations: maps whose outputs are all written
        with write-conflict resolution (WCR) and where each output is either
        a scalar or a single element of an array or view.
    """

    map_entry = pm.PatternNode(nodes.MapEntry)

    @staticmethod
    def expressions():
        """ Matches any single map entry. """
        return [sdutil.node_path_graph(Reduction1Operation.map_entry)]

    @staticmethod
    def can_be_applied(graph: dace.SDFGState,
                       candidate: Dict[pm.PatternNode, int],
                       expr_index: int,
                       sdfg: dace.SDFG,
                       permissive: bool = False):
        """ Returns True iff every memlet leaving the map exit uses WCR and
            targets a scalar or exactly one element of an array/view. """
        map_entry = graph.node(candidate[Reduction1Operation.map_entry])
        map_exit = graph.exit_node(map_entry)
        # NOTE(review): `params` is not used below
        params = [dace.symbol(p) for p in map_entry.map.params]

        # Group the written subsets by the descriptor they target
        written_subsets = {}
        for edge in graph.out_edges(map_exit):
            memlet = edge.data
            if not memlet.wcr:
                return False
            descriptor = sdfg.arrays[memlet.data]
            written_subsets.setdefault(descriptor, []).append(memlet.subset)

        for descriptor, subsets in written_subsets.items():
            # Scalars are always a single element
            if isinstance(descriptor, dace.data.Scalar):
                continue
            # Any other kind of container is not supported
            if not isinstance(descriptor, (dace.data.Array, dace.data.View)):
                return False
            # Arrays and views must be written one element at a time
            if any(subset.num_elements() != 1 for subset in subsets):
                return False

        return True

    @staticmethod
    def match_to_str(graph: dace.SDFGState,
                     candidate: Dict[pm.PatternNode, int]) -> str:
        """ Describes the match as ``<map label>: <map parameters>``. """
        matched_map = graph.node(
            candidate[Reduction1Operation.map_entry]).map
        return matched_map.label + ': ' + str(matched_map.params)

    def apply(self, sdfg: dace.SDFG):
        """ Detection only; applying the transformation does nothing. """
        pass
예제 #7
0
class EndStateElimination(transformation.MultiStateTransformation,
                          transformation.SimplifyPass):
    """
    End-state elimination removes an empty state that has no outgoing edges
    and exactly one incoming, unconditional edge.
    """

    end_state = transformation.PatternNode(SDFGState)

    @classmethod
    def expressions(cls):
        """ Matches any single state. """
        return [sdutil.node_path_graph(cls.end_state)]

    def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
        """ Returns True iff the matched state is an empty end state with a
            single unconditional predecessor edge. """
        candidate = self.end_state

        # An end state has no successors
        if len(graph.out_edges(candidate)) != 0:
            return False

        # Exactly one predecessor, reached unconditionally
        predecessors = graph.in_edges(candidate)
        if len(predecessors) != 1:
            return False
        if not predecessors[0].data.is_unconditional():
            return False

        # Only empty states can be eliminated
        return not candidate.number_of_nodes() > 0

    def apply(self, _, sdfg):
        """ Removes the state and any symbol that was assigned on its
            incoming edge and is reported as free afterwards. """
        candidate = self.end_state
        # Keep hold of the assigned names before the edge disappears together
        # with the state
        incoming = sdfg.in_edges(candidate)[0]
        assigned_symbols = incoming.data.assignments.keys()
        sdfg.remove_node(candidate)
        # Remove orphan symbols
        for name in assigned_symbols:
            if name in sdfg.free_symbols:
                sdfg.remove_symbol(name)
예제 #8
0
class DeadStateElimination(transformation.MultiStateTransformation):
    """
    Dead state elimination removes an unreachable state and all of its dominated
    states. A state is treated as unreachable when its only incoming edge has
    a condition that always evaluates to false.
    """

    end_state = transformation.PatternNode(sdfg.SDFGState)

    @classmethod
    def expressions(cls):
        """ Matches any single state. """
        return [sdutil.node_path_graph(cls.end_state)]

    def can_be_applied(self,
                       graph: SDFG,
                       expr_index,
                       sdfg: SDFG,
                       permissive=False):
        """ Returns True iff the state has exactly one incoming edge, that
            edge has no assignments, and its condition evaluates to false. """
        candidate: SDFGState = self.end_state
        predecessors = graph.in_edges(candidate)

        # Only states with a single incoming edge are matched
        if len(predecessors) != 1:
            return False
        transition = predecessors[0].data

        # Edges with assignments are left alone; unconditional edges are
        # always taken
        if transition.assignments or transition.is_unconditional():
            return False

        # Equality (not identity) is intentional: the evaluated condition is
        # compared by value against the Python constant
        return bool(transition.condition_sympy() == False)

    def apply(self, _, sdfg: SDFG):
        """ Removes the matched state together with every state whose
            dominator set contains it. """
        dead_state = self.end_state

        dominators = cfg.all_dominators(sdfg)
        doomed = {
            other
            for other, dominated_by in dominators.items()
            if dead_state in dominated_by
        } | {dead_state}
        sdfg.remove_nodes_from(doomed)
예제 #9
0
class MapDimShuffle(transformation.Transformation):
    """ Implements the map-dim shuffle transformation.

        MapDimShuffle takes a map and a list of params.
        It reorders the dimensions in the map such that it matches the list.
        If ``parameters`` is unset, or is not a permutation of the map's
        parameters, applying the transformation leaves the map untouched.
    """

    _map_entry = transformation.PatternNode(nodes.MapEntry)

    # Properties
    parameters = ShapeProperty(dtype=list,
                               default=None,
                               desc="Desired order of map parameters")

    @staticmethod
    def expressions():
        """ Matches any single map entry. """
        return [sdutil.node_path_graph(MapDimShuffle._map_entry)]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, permissive=False):
        """ Every map matches; the requested order is validated in
            :meth:`apply`. """
        return True

    @staticmethod
    def match_to_str(graph, candidate):
        """ Describes the match as ``<map label>: <map parameters>``. """
        map_entry = graph.nodes()[candidate[MapDimShuffle._map_entry]]
        return map_entry.map.label + ': ' + str(map_entry.map.params)

    def apply(self, sdfg: SDFG):
        """ Reorders the parameters and ranges of the matched map to follow
            ``self.parameters``.

            Nothing is changed when ``self.parameters`` is ``None`` or is not
            a permutation of the current map parameters (different names,
            missing names, or repeated names).
        """
        graph = sdfg.nodes()[self.state_id]
        map_entry = graph.nodes()[self.subgraph[self._map_entry]]

        # The property defaults to None; there is no order to apply then
        if self.parameters is None:
            return

        # Work on a copy so that the map does not share its parameter list
        # with this transformation's property
        new_order = list(self.parameters)
        current_params = map_entry.map.params

        # The length check rejects repeated names, which would otherwise pass
        # the set comparison and produce more dimensions than the map has
        if (len(new_order) != len(current_params)
                or set(new_order) != set(current_params)):
            return

        range_of = dict(zip(current_params, map_entry.range.ranges))
        map_entry.range.ranges = [range_of[param] for param in new_order]
        map_entry.map.params = new_order
예제 #10
0
class TrivialMapRangeElimination(transformation.SingleStateTransformation):
    """ Implements the Trivial Map Range Elimination pattern.

        Trivial Map Range Elimination takes a multi-dimensional map with
        a range containing one element and removes the corresponding dimension.
        Example: Map[i=0:I,j=0] -> Map[i=0:I]

        If every dimension contains one element, one dimension is kept so
        that the map never ends up without parameters.
        Example: Map[i=0,j=0] -> Map[i=0]
    """

    map_entry = transformation.PatternNode(nodes.MapEntry)

    @classmethod
    def expressions(cls):
        """ Matches any single map entry. """
        return [sdutil.node_path_graph(cls.map_entry)]

    def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
        """ Returns True iff the map has more than one dimension and at least
            one of them begins and ends at the same value. """
        map_entry = self.map_entry
        if len(map_entry.map.range) <= 1:
            return False  # only acts on multi-dimensional maps
        return any(frm == to for frm, to, _ in map_entry.map.range)

    def apply(self, graph, sdfg):
        """ Removes single-element dimensions from the map, replacing each
            removed parameter by its value inside the map scope. """
        map_entry = self.map_entry

        dimensions = list(
            zip(map_entry.map.params, map_entry.map.range.ranges))
        removable = [begin == end for _, (begin, end, _) in dimensions]

        # Removing every dimension would leave a map without parameters;
        # keep the first one in that case
        if removable and all(removable):
            removable[0] = False

        # The scope does not depend on the dimension being processed, so it
        # is computed once (and only if something is going to be replaced)
        scope = graph.scope_subgraph(map_entry) if any(removable) else None

        remaining_ranges = []
        remaining_params = []
        for (map_param, ranges), remove in zip(dimensions, removable):
            if remove:
                # Replace the map index variable with the value it obtained
                scope.replace(map_param, ranges[0])
            else:
                remaining_ranges.append(ranges)
                remaining_params.append(map_param)

        map_entry.map.range.ranges = remaining_ranges
        map_entry.map.params = remaining_params
예제 #11
0
class OTFMapFusion(transformation.SingleStateTransformation):
    """ Performs fusion of two maps by replicating the contents of the first into the second map
        until all the input dependencies (memlets) of the second one are met.

        The matched pattern is ``first_map_exit -> array -> second_map_entry``.
        For every distinct offset at which the second map reads an
        intermediate array, the contents of the first map are copied into the
        second map and their result is passed on through a new transient
        scalar. The first map is removed afterwards.
    """
    first_map_exit = transformation.PatternNode(nds.ExitNode)
    array = transformation.PatternNode(nds.AccessNode)
    second_map_entry = transformation.PatternNode(nds.EntryNode)

    @staticmethod
    def annotates_memlets():
        return False

    @classmethod
    def expressions(cls):
        """ Matches the path ``first_map_exit -> array -> second_map_entry``. """
        return [
            sdutil.node_path_graph(cls.first_map_exit, cls.array,
                                   cls.second_map_entry)
        ]

    def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
        """ Returns True iff the first map writes without WCR and all of its
            outputs are transient access nodes that are produced only by the
            first map and consumed only by the second map. """
        # WCR: not supported on first map
        for _in_e in graph.in_edges(self.first_map_exit):
            if _in_e.data.wcr is not None:
                return False

        # Check intermediate nodes between both maps.
        for _, _, node, _, _ in graph.out_edges(self.first_map_exit):
            # Only map -> array -> map
            if not isinstance(node, nds.AccessNode):
                return False

            # Non-transient blocks removal of first map
            if not sdfg.arrays[node.data].transient:
                return False

            # Check that array is not co-produced by other parent map.
            producers = set(map(lambda edge: edge.src, graph.in_edges(node)))
            for prod in producers:
                if prod != self.first_map_exit:
                    return False

            # Check that array is not co-consumed by other child map
            consumers = set(map(lambda edge: edge.dst, graph.out_edges(node)))
            for cons in consumers:
                if cons != self.second_map_entry:
                    return False

        # Success
        return True

    def apply(self, graph: SDFGState, sdfg: SDFG):
        """ Fuses the first map into the second one and removes the first map
            together with any access node left without edges. """
        first_map_entry = graph.entry_node(self.first_map_exit)

        # Access nodes written by the first map
        intermediate_dnodes = set()
        for _, _, node, _, _ in graph.out_edges(self.first_map_exit):
            if not isinstance(node, nds.AccessNode):
                continue

            intermediate_dnodes.add(node)

        # Rewire the inputs first, then replicate the first map's contents
        # inside the second map
        self._update_in_connectors(graph, intermediate_dnodes)
        self._replicate_first_map(sdfg, graph, first_map_entry,
                                  intermediate_dnodes)

        # The original first map is no longer needed
        graph.remove_nodes_from(
            graph.all_nodes_between(first_map_entry, self.first_map_exit)
            | {first_map_entry, self.first_map_exit})

        # Drop access nodes that ended up without any edge
        for node in graph.nodes():
            if not isinstance(node, nds.AccessNode):
                continue

            if graph.in_degree(node) == 0 and graph.out_degree(node) == 0:
                graph.remove_node(node)

    def _update_in_connectors(self, graph, intermediate_dnodes):
        """ Disconnects the intermediate access nodes from the second map
            entry and forwards every input of the first map entry to the
            second map entry instead.

            The forwarded inputs use the first map's connector names with a
            trailing underscore.

            :raises ValueError: If such a connector cannot be added to the
                                second map entry.
        """
        first_map_entry = graph.entry_node(self.first_map_exit)
        for dnode in intermediate_dnodes:
            for edge in graph.edges_between(dnode, self.second_map_entry):
                graph.remove_edge_and_connectors(edge)

        for edge in graph.in_edges(first_map_entry):
            if self.second_map_entry.add_in_connector(edge.dst_conn + "_"):
                graph.add_edge(edge.src, edge.src_conn, self.second_map_entry,
                               edge.dst_conn + "_", edge.data)
            else:
                raise ValueError("Failed to connect")

    def _replicate_first_map(self, sdfg, graph, first_map_entry,
                             intermediate_dnodes):
        """ Copies the contents of the first map into the second map, once for
            every distinct offset at which an intermediate array is read.

            Each copy writes into a new transient scalar, which then feeds
            the edges that used to read the intermediate array at that offset.
        """
        for dnode in intermediate_dnodes:
            array_name = dnode.data
            array = sdfg.arrays[array_name]

            # Note: this call also removes the reading edges from the graph
            read_offsets = self._read_offsets(graph, array_name)

            # Replicate first map tasklets once for each read offset access and
            # connect them to other tasklets accordingly
            for offset, edges in read_offsets.items():
                new_nodes = self._copy_first_map_contents(
                    sdfg, graph, first_map_entry)
                tmp_name = "__otf"
                tmp_name, _ = sdfg.add_scalar(tmp_name,
                                              array.dtype,
                                              transient=True,
                                              find_new_name=True)
                tmp_access = graph.add_access(tmp_name)

                for node in new_nodes:
                    # Outputs of the copy go to the new scalar instead of the
                    # first map exit
                    for edge in graph.edges_between(node, self.first_map_exit):
                        graph.add_edge(edge.src, edge.src_conn, tmp_access,
                                       None, Memlet(tmp_name))
                        graph.remove_edge(edge)

                    # Inputs of the copy come from the second map entry,
                    # shifted by the read offset
                    for edge in graph.edges_between(first_map_entry, node):
                        memlet = dcpy(edge.data)
                        memlet.subset.offset(list(offset), negative=False)
                        self.second_map_entry.add_out_connector(edge.src_conn +
                                                                "_")
                        graph.add_edge(self.second_map_entry,
                                       edge.src_conn + "_", node,
                                       edge.dst_conn, memlet)
                        graph.remove_edge(edge)

                # Re-create the removed reads so that they consume the scalar
                for edge in edges:
                    graph.add_edge(tmp_access, None, edge.dst, edge.dst_conn,
                                   Memlet(tmp_name))

    def _read_offsets(self, state, array_name):
        """Compute offsets of read accesses in second map.

        Besides computing the offsets, this removes every edge leaving the
        second map entry that reads ``array_name``, along with its
        out-connector.

        :return: A dictionary mapping each offset tuple to the list of
                 (removed) edges that read at that offset.
        """
        # Get output memlet of first tasklet
        output_edges = state.in_edges(self.first_map_exit)
        assert len(output_edges) == 1
        write_memlet = output_edges[0].data

        # Find read offsets by looping over second map entry connectors
        offsets = defaultdict(list)
        for edge in state.out_edges(self.second_map_entry):
            if edge.data.data == array_name:
                self.second_map_entry.remove_out_connector(edge.src_conn)
                state.remove_edge(edge)
                offset = OTFMapFusion._memlet_offsets(write_memlet, edge.data)
                offsets[offset].append(edge)

        return offsets

    def _copy_first_map_contents(self, sdfg, graph, first_map_entry):
        """ Adds deep copies of the nodes between the first map's entry and
            exit to the graph, including copies of the edges touching them.

            Transient scalars among the copied access nodes are given new
            data names. Edge endpoints that are not among the copied nodes
            are kept as they are.

            :return: The list of newly added nodes.
        """
        inter_nodes = list(
            graph.all_nodes_between(first_map_entry, self.first_map_exit) -
            {first_map_entry})
        new_inter_nodes = [dcpy(node) for node in inter_nodes]
        # Maps the original name of a renamed transient scalar to its new name
        tmp_map = dict()
        for node in new_inter_nodes:
            if isinstance(node, nds.AccessNode):
                data = sdfg.arrays[node.data]
                if isinstance(data, dt.Scalar) and data.transient:
                    tmp_name = sdfg.temp_data_name()
                    sdfg.add_scalar(tmp_name, data.dtype, transient=True)
                    tmp_map[node.data] = tmp_name
                    node.data = tmp_name
            graph.add_node(node)
        # Maps the node ID of each original node to the ID of its copy
        id_map = {
            graph.node_id(old): graph.node_id(new)
            for old, new in zip(inter_nodes, new_inter_nodes)
        }

        def map_node(node):
            return graph.node(id_map[graph.node_id(node)])

        def map_memlet(memlet):
            memlet = dcpy(memlet)
            memlet.data = tmp_map.get(memlet.data, memlet.data)
            return memlet

        # Duplicate every edge that touches an original node
        for edge in graph.edges():
            if edge.src in inter_nodes or edge.dst in inter_nodes:
                src = map_node(
                    edge.src) if edge.src in inter_nodes else edge.src
                dst = map_node(
                    edge.dst) if edge.dst in inter_nodes else edge.dst
                edge_data = map_memlet(edge.data)
                graph.add_edge(src, edge.src_conn, dst, edge.dst_conn,
                               edge_data)

        return new_inter_nodes

    @staticmethod
    def _memlet_offsets(base_memlet, offset_memlet):
        """Compute subset offset of `offset_memlet` relative to `base_memlet`.

        Asserts, per dimension, that both ranges have the same step and that
        begin and end are shifted by the same amount.

        :return: A tuple with one integer offset per dimension.
        """
        def offset(base_range, offset_range):
            b0, e0, s0 = base_range
            b1, e1, s1 = offset_range
            assert e1 - e0 == b1 - b0 and s0 == s1
            return int(e1 - e0)

        return tuple(
            offset(b, o) for b, o in zip(base_memlet.subset.ranges,
                                         offset_memlet.subset.ranges))
예제 #12
0
class MapFission(transformation.SingleStateTransformation):
    """ Implements the MapFission transformation.
        Map fission refers to subsuming a map scope into its internal subgraph,
        essentially replicating the map into maps in all of its internal
        components. This also extends the dimensions of "border" transient
        arrays (i.e., those between the maps), in order to retain program
        semantics after fission.

        There are two cases that match map fission:
        1. A map with an arbitrary subgraph with more than one computational
           (i.e., non-access) node. The use of arrays connecting the
           computational nodes must be limited to the subgraph, and non
           transient arrays may not be used as "border" arrays.
        2. A map with one internal node that is a nested SDFG, in which
           each state matches the conditions of case (1).

        If a map has nested SDFGs in its subgraph, they are not considered in
        the case (1) above, and MapFission must be invoked again on the maps
        with the nested SDFGs in question.
    """
    map_entry = transformation.PatternNode(nodes.EntryNode)
    nested_sdfg = transformation.PatternNode(nodes.NestedSDFG)

    @staticmethod
    def annotates_memlets():
        """ Always returns False for this transformation. """
        return False

    @classmethod
    def expressions(cls):
        """ Returns the two patterns matched by this transformation: a map
            entry on its own (index 0) and a map entry followed by a nested
            SDFG (index 1). """
        map_only = sdutil.node_path_graph(cls.map_entry)
        map_with_nested_sdfg = sdutil.node_path_graph(cls.map_entry,
                                                      cls.nested_sdfg)
        return [map_only, map_with_nested_sdfg]

    @staticmethod
    def _components(
            subgraph: gr.SubgraphView) -> List[Tuple[nodes.Node, nodes.Node]]:
        """
        Returns the non-array components found at the top scope level of the
        given subgraph, as a list of (input node, output node) tuples.
        A scope is reported as (entry node, exit node); any other code node
        is reported as (node, node).
        """
        if isinstance(subgraph, sd.SDFGState):
            graph = subgraph
        else:
            graph = subgraph.graph

        top_level_nodes = subgraph.scope_children()[None]

        components = []
        for node in top_level_nodes:
            if isinstance(node, nodes.EntryNode):
                components.append((node, graph.exit_node(node)))
            elif isinstance(node, nodes.CodeNode):
                components.append((node, node))

        return components

    @staticmethod
    def _border_arrays(sdfg, parent, subgraph):
        """ Returns the set of data names accessed at the top scope level of
            the fission subgraph. If ``parent`` is an SDFG state, only
            transient data is included. """
        transients_only = isinstance(parent, sd.SDFGState)
        top_level = gr.SubgraphView(parent, subgraph.scope_children()[None])

        names = set()
        for node in top_level.nodes():
            if not isinstance(node, nodes.AccessNode):
                continue
            if transients_only and not sdfg.arrays[node.data].transient:
                continue
            names.add(node.data)
        return names

    @staticmethod
    def _internal_border_arrays(total_components, subgraphs):
        """ Returns the names of the arrays that are both read by some
            component and written by some component, i.e., the border arrays
            that lie between computational components rather than being pure
            sources or sinks. """
        read_names = set()
        written_names = set()

        for components, subgraph in zip(total_components, subgraphs):
            for first_node, last_node in components:
                read_names.update(
                    edge.src.data for edge in subgraph.in_edges(first_node)
                    if isinstance(edge.src, nodes.AccessNode))
                written_names.update(
                    edge.dst.data for edge in subgraph.out_edges(last_node)
                    if isinstance(edge.dst, nodes.AccessNode))

        return read_names.intersection(written_names)

    @staticmethod
    def _outside_map(node, scope_dict, entry_nodes):
        """ Returns True iff none of the scopes enclosing ``node`` (following
            ``scope_dict`` upwards until it yields None) is one of
            ``entry_nodes``. """
        enclosing = scope_dict[node]
        while enclosing is not None:
            if enclosing in entry_nodes:
                return False
            enclosing = scope_dict[enclosing]
        return True

    def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
        """ Returns True iff the matched map can undergo fission.

            Expression 0 matches a map with a subgraph, expression 1 a map
            whose only internal node is a nested SDFG. In both cases the
            arrays between the computational components must be transients
            that are direct outputs of a component; for expression 0 they
            must additionally not be accessed outside of the map.
        """
        map_node = self.map_entry
        nsdfg_node = None
        plain_map = expr_index == 0
        nested_map = expr_index == 1

        # If the map is dynamic-ranged, the resulting border arrays would be
        # dynamically sized
        if sd.has_dynamic_map_inputs(graph, map_node):
            return False

        if plain_map:  # Map with subgraph
            subgraphs = [
                graph.scope_subgraph(map_node,
                                     include_entry=False,
                                     include_exit=False)
            ]
        else:  # Map with nested SDFG
            nsdfg_node = self.nested_sdfg
            # Make sure there are no other internal nodes in the map
            inner_nodes = set(e.dst for e in graph.out_edges(map_node))
            if len(inner_nodes) > 1:
                return False
            subgraphs = list(nsdfg_node.sdfg.nodes())

        # Test subgraphs
        border_arrays = set()
        total_components = []
        for sg in subgraphs:
            components = self._components(sg)
            snodes = sg.nodes()
            # Test that the subgraphs have more than one computational component
            if plain_map and len(snodes) > 0 and len(components) <= 1:
                return False

            # Test that the components are connected by transients that are not
            # used anywhere else
            if nested_map:
                border_arrays |= self._border_arrays(nsdfg_node.sdfg, sg, sg)
            else:
                border_arrays |= self._border_arrays(sdfg, graph, sg)
            total_components.append(components)

            # In nested SDFGs and subgraphs, ensure none of the border
            # values are non-transients
            descriptors = sdfg.arrays if plain_map else nsdfg_node.sdfg.arrays
            for name in border_arrays:
                if descriptors[name].transient is False:
                    return False

            # In subgraphs, make sure transients are not used/allocated
            # in other scopes or states
            if plain_map:
                # Data accessed in this state but outside of the subgraph ...
                used_elsewhere = set(
                    n.data for n in graph.nodes()
                    if n not in snodes and isinstance(n, nodes.AccessNode))
                # ... and data accessed in any other state
                for other_state in sdfg.nodes():
                    if other_state != graph:
                        used_elsewhere.update(
                            n.data for n in other_state.nodes()
                            if isinstance(n, nodes.AccessNode))

                for _, component_out in components:
                    for e in sg.out_edges(component_out):
                        if (isinstance(e.dst, nodes.AccessNode)
                                and e.dst.data in used_elsewhere):
                            return False

        # Fail if there are arrays inside the map that are not a direct
        # output of a computational component
        # TODO(later): Support this case? Ambiguous array sizes and memlets
        external_arrays = (
            border_arrays -
            self._internal_border_arrays(total_components, subgraphs))
        return not len(external_arrays) > 0

    def apply(self, graph: sd.SDFGState, sdfg: sd.SDFG):
        """
        Splits the matched map into one map per computational component.

        Depending on ``self.expr_index``, the map body is either the scope
        subgraph of ``self.map_entry`` (index 0) or every state of the nested
        SDFG ``self.nested_sdfg`` (any other index). For each
        ``(state, subgraph)`` pair the method:

          1. reroutes code->code edges that leave a component through new
             transients shaped by the outer map range,
          2. wraps every component in its own map that copies the range,
             schedule and unroll flag of the outer map,
          3. prepends the map dimensions to the border arrays and to the
             memlets that touch them.

        In the nested-SDFG case the edges around the outer map are
        additionally reconnected directly to the nested SDFG node. Finally
        the original map entry and exit nodes are removed from ``graph``.

        :param graph: The state that contains the map to be split.
        :param sdfg: The SDFG that contains ``graph``.
        :note: As the NOTE comments below state, memlet propagation is
               expected to run afterwards to correct the outer memlets.
        """
        map_entry = self.map_entry
        map_exit = graph.exit_node(map_entry)
        nsdfg_node: Optional[nodes.NestedSDFG] = None

        # Obtain subgraph to perform fission to.
        # ``subgraphs`` holds (state, subgraph) pairs and ``parent`` is the
        # SDFG whose data descriptors are modified further below.
        if self.expr_index == 0:  # Map with subgraph
            subgraphs = [(graph,
                          graph.scope_subgraph(map_entry,
                                               include_entry=False,
                                               include_exit=False))]
            parent = sdfg
        else:  # Map with nested SDFG
            nsdfg_node = self.nested_sdfg
            subgraphs = [(state, state) for state in nsdfg_node.sdfg.nodes()]
            parent = nsdfg_node.sdfg
        # Names of arrays that already received the extra map dimensions, so
        # that an array appearing in several states is augmented only once
        modified_arrays = set()

        # Get map information
        outer_map: nodes.Map = map_entry.map
        mapsize = outer_map.range.size()

        # Add new symbols from outer map to nested SDFG
        if self.expr_index == 1:
            # Symbols used by the map range and by the non-empty memlets
            # directly inside the map scope
            map_syms = outer_map.range.free_symbols
            for edge in graph.out_edges(map_entry):
                if edge.data.data:
                    map_syms.update(edge.data.subset.free_symbols)
            for edge in graph.in_edges(map_exit):
                if edge.data.data:
                    map_syms.update(edge.data.subset.free_symbols)
            for sym in map_syms:
                symname = str(sym)
                # Map parameters are not forwarded; they are removed from the
                # nested mapping right below
                if symname in outer_map.params:
                    continue
                if symname not in nsdfg_node.symbol_mapping.keys():
                    nsdfg_node.symbol_mapping[symname] = sym
                    nsdfg_node.sdfg.symbols[
                        symname] = graph.symbols_defined_at(
                            nsdfg_node)[symname]

            # Remove map symbols from nested mapping
            for name in outer_map.params:
                if str(name) in nsdfg_node.symbol_mapping:
                    del nsdfg_node.symbol_mapping[str(name)]
                if str(name) in nsdfg_node.sdfg.symbols:
                    del nsdfg_node.sdfg.symbols[str(name)]

        for state, subgraph in subgraphs:
            # Each component is an (input node, output node) pair; sources
            # and sinks are recorded before the subgraph is modified
            components = MapFission._components(subgraph)
            sources = subgraph.source_nodes()
            sinks = subgraph.sink_nodes()

            # Collect external edges
            if self.expr_index == 0:
                # Edges adjacent to the outer map entry/exit
                external_edges_entry = list(state.out_edges(map_entry))
                external_edges_exit = list(state.in_edges(map_exit))
            else:
                # Edges that read from / write to access nodes of
                # non-transient arrays of the nested SDFG
                external_edges_entry = [
                    e for e in subgraph.edges()
                    if (isinstance(e.src, nodes.AccessNode)
                        and not nsdfg_node.sdfg.arrays[e.src.data].transient)
                ]
                external_edges_exit = [
                    e for e in subgraph.edges()
                    if (isinstance(e.dst, nodes.AccessNode)
                        and not nsdfg_node.sdfg.arrays[e.dst.data].transient)
                ]

            # Map external edges to outer memlets
            edge_to_outer = {}
            for edge in external_edges_entry:
                if self.expr_index == 0:
                    # Subgraphs use the corresponding outer map edges
                    # (the edge that precedes this one on its memlet path)
                    path = state.memlet_path(edge)
                    eindex = path.index(edge)
                    edge_to_outer[edge] = path[eindex - 1]
                else:
                    # Nested SDFGs use the internal map edges of the node:
                    # the edge entering the nested SDFG through the connector
                    # named like the inner array
                    outer_edge = next(e for e in graph.in_edges(nsdfg_node)
                                      if e.dst_conn == edge.src.data)
                    edge_to_outer[edge] = outer_edge

            for edge in external_edges_exit:
                if self.expr_index == 0:
                    # The edge that follows this one on its memlet path
                    path = state.memlet_path(edge)
                    eindex = path.index(edge)
                    edge_to_outer[edge] = path[eindex + 1]
                else:
                    # Nested SDFGs use the internal map edges of the node
                    outer_edge = next(e for e in graph.out_edges(nsdfg_node)
                                      if e.src_conn == edge.dst.data)
                    edge_to_outer[edge] = outer_edge

            # Collect all border arrays and code->code edges
            arrays = MapFission._border_arrays(
                nsdfg_node.sdfg if self.expr_index == 1 else sdfg, state,
                subgraph)
            # Data name -> edges that leave a component output and enter
            # another code node directly
            scalars = defaultdict(list)
            for _, component_out in components:
                for e in subgraph.out_edges(component_out):
                    if isinstance(e.dst, nodes.CodeNode):
                        scalars[e.data.data].append(e)

            # Create new arrays for scalars
            for scalar, edges in scalars.items():
                # The old descriptor is dropped and a transient of shape
                # ``mapsize`` is registered in its place; because of
                # ``find_new_name=True`` the returned name is used from here on
                desc = parent.arrays[scalar]
                del parent.arrays[scalar]
                name, newdesc = parent.add_transient(
                    scalar,
                    mapsize,
                    desc.dtype,
                    desc.storage,
                    lifetime=desc.lifetime,
                    debuginfo=desc.debuginfo,
                    allow_conflicts=desc.allow_conflicts,
                    find_new_name=True)

                # Add extra nodes in component boundaries
                for edge in edges:
                    # Split the code->code edge into code->access->code
                    anode = state.add_access(name)
                    sbs = subsets.Range.from_string(','.join(outer_map.params))
                    # Offset memlet by map range begin (to fit the transient)
                    sbs.offset([r[0] for r in outer_map.range], True)
                    state.add_edge(
                        edge.src, edge.src_conn, anode, None,
                        mm.Memlet.simple(
                            name,
                            sbs,
                            num_accesses=outer_map.range.num_elements()))
                    state.add_edge(
                        anode, None, edge.dst, edge.dst_conn,
                        mm.Memlet.simple(
                            name,
                            sbs,
                            num_accesses=outer_map.range.num_elements()))
                    state.remove_edge(edge)

            # Add extra maps around components
            new_map_entries = []
            for component_in, component_out in components:
                # The '0:1' ranges are placeholders; the real range is
                # deep-copied from the outer map right below
                me, mx = state.add_map(outer_map.label + '_fission',
                                       [(p, '0:1') for p in outer_map.params],
                                       outer_map.schedule,
                                       unroll=outer_map.unroll,
                                       debuginfo=outer_map.debuginfo)

                # Add dynamic input connectors
                # (every connector of the outer entry not prefixed by 'IN_')
                for conn in map_entry.in_connectors:
                    if not conn.startswith('IN_'):
                        me.add_in_connector(conn)

                me.map.range = dcpy(outer_map.range)
                new_map_entries.append(me)

                # Reconnect edges through new map
                for e in state.in_edges(component_in):
                    state.add_edge(me, None, e.dst, e.dst_conn, dcpy(e.data))
                    # Reconnect inner edges at source directly to external nodes
                    if self.expr_index == 0 and e in external_edges_entry:
                        state.add_edge(edge_to_outer[e].src,
                                       edge_to_outer[e].src_conn, me, None,
                                       dcpy(edge_to_outer[e].data))
                    else:
                        state.add_edge(e.src, e.src_conn, me, None,
                                       dcpy(e.data))
                    state.remove_edge(e)
                # Empty memlet edge in nested SDFGs
                # (added whenever the component input has no incoming edge)
                if state.in_degree(component_in) == 0:
                    state.add_edge(me, None, component_in, None, mm.Memlet())

                for e in state.out_edges(component_out):
                    state.add_edge(e.src, e.src_conn, mx, None, dcpy(e.data))
                    # Reconnect inner edges at sink directly to external nodes
                    if self.expr_index == 0 and e in external_edges_exit:
                        state.add_edge(mx, None, edge_to_outer[e].dst,
                                       edge_to_outer[e].dst_conn,
                                       dcpy(edge_to_outer[e].data))
                    else:
                        state.add_edge(mx, None, e.dst, e.dst_conn,
                                       dcpy(e.data))
                    state.remove_edge(e)
                # Empty memlet edge in nested SDFGs
                # (added whenever the component output has no outgoing edge)
                if state.out_degree(component_out) == 0:
                    state.add_edge(component_out, None, mx, None, mm.Memlet())
            # Connect other sources/sinks not in components (access nodes)
            # directly to external nodes
            if self.expr_index == 0:
                for node in sources:
                    if isinstance(node, nodes.AccessNode):
                        for edge in state.in_edges(node):
                            outer_edge = edge_to_outer[edge]
                            memlet = dcpy(edge.data)
                            # Prepend the full map range to the copied subset
                            memlet.subset = subsets.Range(
                                outer_map.range.ranges + memlet.subset.ranges)
                            state.add_edge(outer_edge.src, outer_edge.src_conn,
                                           edge.dst, edge.dst_conn, memlet)

                for node in sinks:
                    if isinstance(node, nodes.AccessNode):
                        for edge in state.out_edges(node):
                            outer_edge = edge_to_outer[edge]
                            state.add_edge(edge.src, edge.src_conn,
                                           outer_edge.dst, outer_edge.dst_conn,
                                           dcpy(outer_edge.data))

            # Augment arrays by prepending map dimensions
            for array in arrays:
                if array in modified_arrays:
                    continue
                desc = parent.arrays[array]
                if isinstance(
                        desc,
                        dt.Scalar):  # Scalar needs to be augmented to an array
                    desc = dt.Array(desc.dtype, desc.shape, desc.transient,
                                    desc.allow_conflicts, desc.storage,
                                    desc.location, desc.strides, desc.offset,
                                    False, desc.lifetime, 0, desc.debuginfo,
                                    desc.total_size, desc.start_offset)
                    parent.arrays[array] = desc
                # Walk the map sizes innermost-first: each new leading
                # dimension strides over everything accumulated so far
                for sz in reversed(mapsize):
                    desc.strides = [desc.total_size] + list(desc.strides)
                    desc.total_size = desc.total_size * sz

                desc.shape = mapsize + list(desc.shape)
                desc.offset = [0] * len(mapsize) + list(desc.offset)
                modified_arrays.add(array)

            # Fill scope connectors so that memlets can be tracked below
            state.fill_scope_connectors()

            # Correct connectors and memlets in nested SDFGs to account for
            # missing outside map
            if self.expr_index == 1:
                to_correct = ([(e, e.src) for e in external_edges_entry] +
                              [(e, e.dst) for e in external_edges_exit])
                # Each access node is corrected once, even if several
                # external edges touch it
                corrected_nodes = set()
                for edge, node in to_correct:
                    if isinstance(node, nodes.AccessNode):
                        if node in corrected_nodes:
                            continue
                        corrected_nodes.add(node)

                        outer_edge = edge_to_outer[edge]
                        desc = parent.arrays[node.data]

                        # Modify shape of internal array to match outer one
                        outer_desc = sdfg.arrays[outer_edge.data.data]
                        if not isinstance(desc, dt.Scalar):
                            desc.shape = outer_desc.shape
                        if isinstance(desc, dt.Array):
                            desc.strides = outer_desc.strides
                            desc.total_size = outer_desc.total_size

                        # Inside the nested SDFG, offset all memlets to include
                        # the offsets from within the map.
                        # NOTE: Relies on propagation to fix outer memlets
                        for internal_edge in state.all_edges(node):
                            for e in state.memlet_tree(internal_edge):
                                e.data.subset.offset(desc.offset, False)
                                e.data.subset = helpers.unsqueeze_memlet(
                                    e.data, outer_edge.data).subset

                        # Only after offsetting memlets we can modify the
                        # overall offset
                        if isinstance(desc, dt.Array):
                            desc.offset = outer_desc.offset

            # Fill in memlet trees for border transients
            # NOTE: Memlet propagation should run to correct the outer edges
            for node in subgraph.nodes():
                if isinstance(node, nodes.AccessNode) and node.data in arrays:
                    for edge in state.all_edges(node):
                        for e in state.memlet_tree(edge):
                            # Prepend map dimensions to memlet
                            # (one single-element range ``param - begin`` per
                            # map parameter)
                            e.data.subset = subsets.Range(
                                [(pystr_to_symbolic(d) - r[0],
                                  pystr_to_symbolic(d) - r[0], 1) for d, r in
                                 zip(outer_map.params, outer_map.range)] +
                                e.data.subset.ranges)

        # If nested SDFG, reconnect nodes around map and modify memlets
        if self.expr_index == 1:
            for edge in graph.in_edges(map_entry):
                # Only regular 'IN_*' connectors are rerouted here
                if not edge.dst_conn or not edge.dst_conn.startswith('IN_'):
                    continue

                # Modify edge coming into nested SDFG to include entire array
                desc = sdfg.arrays[edge.data.data]
                edge.data.subset = subsets.Range.from_array(desc)
                edge.data.num_accesses = edge.data.subset.num_elements()

                # Find matching edge inside map
                # (connector names are compared with the 3-character 'IN_'
                # prefix and a 4-character prefix, presumably 'OUT_', removed)
                inner_edge = next(
                    e for e in graph.out_edges(map_entry)
                    if e.src_conn and e.src_conn[4:] == edge.dst_conn[3:])
                graph.add_edge(edge.src, edge.src_conn, nsdfg_node,
                               inner_edge.dst_conn, dcpy(edge.data))

            for edge in graph.out_edges(map_exit):
                # Modify edge coming out of nested SDFG to include entire array
                desc = sdfg.arrays[edge.data.data]
                edge.data.subset = subsets.Range.from_array(desc)

                # Find matching edge inside map
                inner_edge = next(e for e in graph.in_edges(map_exit)
                                  if e.dst_conn[3:] == edge.src_conn[4:])
                graph.add_edge(nsdfg_node, inner_edge.src_conn, edge.dst,
                               edge.dst_conn, dcpy(edge.data))

        # Remove outer map
        graph.remove_nodes_from([map_entry, map_exit])
예제 #13
0
class AugAssignToWCR(transformation.Transformation):
    """
    Converts an augmented assignment ("a += b", "a = a + b") into a tasklet
    with a write-conflict resolution.

    Only C++ tasklets whose entire code is one statement of the form
    ``out = in <op> <expr>;`` are handled, where ``<op>`` is one of
    ``_EXPRESSIONS``. Python tasklets are rejected by ``can_be_applied`` and
    raise ``NotImplementedError`` in ``apply``.

    NOTE(review): ``expressions`` registers only the free-tasklet pattern
    (index 0). The map branches (``expr_index != 0``) are only reachable if a
    second pattern is registered -- confirm whether that is intended.
    """
    input = transformation.PatternNode(nodes.AccessNode)
    tasklet = transformation.PatternNode(nodes.Tasklet)
    output = transformation.PatternNode(nodes.AccessNode)
    map_entry = transformation.PatternNode(nodes.MapEntry)
    map_exit = transformation.PatternNode(nodes.MapExit)

    # Binary operators accepted right after the input connector
    _EXPRESSIONS = ['+', '-', '*', '^', '%']  #, '/']
    # Asymmetric operator -> (symmetric WCR operator, operand template)
    _EXPR_MAP = {
        '-': ('+', '-({expr})'),
        '/': ('*', '((decltype({expr}))1)/({expr})')
    }

    @staticmethod
    def expressions():
        """ Returns the pattern ``input -> tasklet -> output``. """
        return [
            sdutil.node_path_graph(AugAssignToWCR.input, AugAssignToWCR.tasklet,
                                   AugAssignToWCR.output),
        ]

    @staticmethod
    def _match_augassign(code, outconn, inconn):
        """
        Matches ``<outconn> = <inconn> <op> <expr>;`` against ``code``.

        This is the single place where the pattern is defined, so that
        ``can_be_applied`` and ``apply`` always agree on what is accepted.

        NOTE(review): everything after the operator is taken verbatim as the
        WCR operand, so operator precedence is not considered (for example
        ``out = in * b + c;`` yields operator ``*`` and operand ``b + c``) --
        confirm that inputs are restricted accordingly.

        :param code: Stripped tasklet code.
        :param outconn: Name of the tasklet output connector.
        :param inconn: Name of the tasklet input connector.
        :return: The tuple ``(op, expr)`` on success, otherwise ``None``.
                 ``None`` is also returned if either connector is ``None``
                 (an edge without a connector cannot be part of the
                 assignment, and ``re.escape`` would raise on it).
        """
        if outconn is None or inconn is None:
            return None
        ops = '[%s]' % ''.join(
            re.escape(o) for o in AugAssignToWCR._EXPRESSIONS)
        match = re.match(
            r'^\s*%s\s*=\s*%s\s*(%s)(.*);$' %
            (re.escape(outconn), re.escape(inconn), ops), code)
        if match is None:
            return None
        return match.group(1), match.group(2)

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """
        Checks whether the matched tasklet is a convertible augmented
        assignment.

        :param graph: The state in which the pattern was found.
        :param candidate: Mapping from pattern nodes to node IDs in ``graph``.
        :param expr_index: Index of the matched pattern expression.
        :param sdfg: The SDFG containing ``graph`` (unused here).
        :param strict: Unused here.
        :return: True if some input edge of the tasklet carries the same
                 subset as the output edge and the tasklet code matches
                 ``out = in <op> <expr>;`` for that edge.
        """
        inarr = graph.node(candidate[AugAssignToWCR.input])
        tasklet: nodes.Tasklet = graph.node(candidate[AugAssignToWCR.tasklet])
        outarr = graph.node(candidate[AugAssignToWCR.output])
        if inarr.data != outarr.data:
            return False

        # Free tasklet
        if expr_index == 0:
            # Only free tasklets supported for now
            if graph.entry_node(tasklet) is not None:
                return False

            inedges = graph.edges_between(inarr, tasklet)
            if len(graph.edges_between(tasklet, outarr)) > 1:
                return False

            # Make sure augmented assignment can be fissioned as necessary
            if any(not isinstance(e.src, nodes.AccessNode)
                   for e in graph.in_edges(tasklet)):
                return False
            if graph.in_degree(inarr) > 0 and graph.out_degree(outarr) > 0:
                return False

            outedge = graph.edges_between(tasklet, outarr)[0]
        else:  # Free map
            me: nodes.MapEntry = graph.node(candidate[AugAssignToWCR.map_entry])
            mx = graph.node(candidate[AugAssignToWCR.map_exit])

            # Only free maps supported for now
            if graph.entry_node(me) is not None:
                return False

            inedges = graph.edges_between(me, tasklet)
            if len(graph.edges_between(tasklet, mx)) > 1:
                return False

            # Currently no fission is supported
            if any(e.src is not me and not isinstance(e.src, nodes.AccessNode)
                   for e in graph.in_edges(me) + graph.in_edges(tasklet)):
                return False
            if graph.in_degree(inarr) > 0:
                return False

            outedge = graph.edges_between(tasklet, mx)[0]

        # Only C++ tasklets are matched (Python tasklets are not supported yet)
        if tasklet.language is not dtypes.Language.CPP:
            return False

        # Get relevant output connector
        outconn = outedge.src_conn

        cstr = tasklet.code.as_string.strip()
        for edge in inedges:
            # Try to match a single C assignment that can be converted to WCR
            if AugAssignToWCR._match_augassign(cstr, outconn,
                                               edge.dst_conn) is None:
                continue
            # Same memlet
            if edge.data.subset != outedge.data.subset:
                continue

            # If in map, only match if the subset is independent of any
            # map indices (otherwise no conflict)
            if (expr_index == 1
                    and len(outedge.data.subset.free_symbols
                            & set(me.map.params)) == len(me.map.params)):
                continue

            return True

        return False

    def apply(self, sdfg: SDFG):
        """
        Rewrites the tasklet to ``out = <expr>;`` and moves the operator into
        a write-conflict resolution on the output memlet.

        :param sdfg: The SDFG on which the transformation is applied.
        :raises NotImplementedError: If the tasklet is not a C++ tasklet.
        :raises ValueError: If no input edge matches the augmented-assignment
                            pattern (``can_be_applied`` was not satisfied).
                            Raised before the tasklet code or any memlet is
                            changed.
        """
        in_node: nodes.AccessNode = self.input(sdfg)
        tasklet: nodes.Tasklet = self.tasklet(sdfg)
        out_node: nodes.AccessNode = self.output(sdfg)
        state: SDFGState = sdfg.node(self.state_id)

        # If state fission is necessary to keep semantics, do it first
        if (self.expr_index == 0 and state.in_degree(in_node) > 0
                and state.out_degree(out_node) == 0):
            newstate = sdfg.add_state_after(state)
            newstate.add_node(tasklet)
            new_input, new_output = None, None

            # Keep old edges for after we remove tasklet from the original state
            in_edges = list(state.in_edges(tasklet))
            out_edges = list(state.out_edges(tasklet))

            for e in in_edges:
                r = newstate.add_read(e.src.data)
                newstate.add_edge(r, e.src_conn, e.dst, e.dst_conn, e.data)
                if e.src is in_node:
                    new_input = r
            for e in out_edges:
                w = newstate.add_write(e.dst.data)
                newstate.add_edge(e.src, e.src_conn, w, e.dst_conn, e.data)
                if e.dst is out_node:
                    new_output = w

            # Remove tasklet and resulting isolated nodes
            state.remove_node(tasklet)
            for e in in_edges:
                if state.degree(e.src) == 0:
                    state.remove_node(e.src)
            for e in out_edges:
                if state.degree(e.dst) == 0:
                    state.remove_node(e.dst)

            # Reset state and nodes for rest of transformation
            in_node = new_input
            out_node = new_output
            state = newstate
        # End of state fission

        if self.expr_index == 0:
            inedges = state.edges_between(in_node, tasklet)
            outedge = state.edges_between(tasklet, out_node)[0]
        else:
            me = self.map_entry(sdfg)
            mx = self.map_exit(sdfg)

            inedges = state.edges_between(me, tasklet)
            outedge = state.edges_between(tasklet, mx)[0]

        # Only C++ tasklets can be rewritten
        if tasklet.language is not dtypes.Language.CPP:
            raise NotImplementedError

        # Get relevant output connector
        outconn = outedge.src_conn

        # Find the first input edge that both matches the pattern and carries
        # the same subset as the output edge
        cstr = tasklet.code.as_string.strip()
        inedge = None
        for edge in inedges:
            matched = AugAssignToWCR._match_augassign(cstr, outconn,
                                                      edge.dst_conn)
            if matched is None:
                continue
            if edge.data.subset != outedge.data.subset:
                continue

            op, expr = matched
            # Map asymmetric WCRs to symmetric ones if possible
            if op in AugAssignToWCR._EXPR_MAP:
                op, newexpr = AugAssignToWCR._EXPR_MAP[op]
                expr = newexpr.format(expr=expr)

            inedge = edge
            break

        if inedge is None:
            # Fail before touching the tasklet code or any memlet
            raise ValueError(
                'Tasklet code does not match an augmented assignment on any '
                'input edge')

        # Change tasklet code
        tasklet.code.code = '%s = %s;' % (outconn, expr)

        # Change output edge
        outedge.data.wcr = f'lambda a,b: a {op} b'

        if self.expr_index == 0:
            # Remove input node and connector
            state.remove_edge_and_connectors(inedge)
            if state.degree(in_node) == 0:
                state.remove_node(in_node)
        else:
            # Remove input edge and dst connector, but not necessarily src
            state.remove_memlet_path(inedge)

        # If outedge leads to non-transient, and this is a nested SDFG,
        # propagate outwards
        sd = sdfg
        while (not sd.arrays[outedge.data.data].transient
               and sd.parent_nsdfg_node is not None):
            nsdfg = sd.parent_nsdfg_node
            nstate = sd.parent
            sd = sd.parent_sdfg
            outedge = next(
                iter(nstate.out_edges_by_connector(nsdfg, outedge.data.data)))
            # After this loop ``outedge`` is the last edge of the memlet path,
            # which is what the ``while`` condition inspects next
            for outedge in nstate.memlet_path(outedge):
                outedge.data.wcr = f'lambda a,b: a {op} b'
예제 #14
0
class PruneConnectors(pm.Transformation):
    """ Prunes connectors of a nested SDFG that the nested SDFG never uses.

        An input connector is a candidate when its name is missing from the
        nested SDFG's read set; an output connector when its name is missing
        from the write set. The memlet paths attached to pruned connectors
        are removed from the enclosing state (orphaned nodes included).
    """

    nsdfg = pm.PatternNode(nodes.NestedSDFG)

    @staticmethod
    def expressions():
        """ Returns the single-node pattern matching any nested SDFG. """
        return [utils.node_path_graph(PruneConnectors.nsdfg)]

    @staticmethod
    def _prunable_connectors(state, node, read_set, write_set):
        """
        Computes which connectors of ``node`` may be pruned.

        :param state: The state that contains ``node``.
        :param node: The nested SDFG node.
        :param read_set: Names read inside the nested SDFG.
        :param write_set: Names written inside the nested SDFG.
        :return: A tuple ``(inputs, outputs)`` of connector-name sets.
        """
        unused_in = node.in_connectors.keys() - read_set
        unused_out = node.out_connectors.keys() - write_set

        # Add WCR outputs to "do not prune" input list: an input that backs
        # a WCR output is kept if its source node has incoming edges
        for out_edge in state.out_edges(node):
            conn = out_edge.src_conn
            if out_edge.data.wcr is None or conn not in unused_in:
                continue
            feeding_edge = next(iter(state.in_edges_by_connector(node, conn)))
            if state.in_degree(feeding_edge.src) > 0:
                unused_in.remove(conn)

        return unused_in, unused_out

    @staticmethod
    def can_be_applied(graph: Union[SDFG, SDFGState],
                       candidate: Dict[pm.PatternNode, int],
                       expr_index: int,
                       sdfg: SDFG,
                       strict: bool = False) -> bool:
        """
        Returns True if at least one connector can be pruned and none of the
        affected memlet paths starts at a node with incoming edges or ends at
        a node with outgoing edges.
        """
        node = graph.node(candidate[PruneConnectors.nsdfg])

        reads, writes = node.sdfg.read_and_write_sets()
        unused_in, unused_out = PruneConnectors._prunable_connectors(
            graph, node, reads, writes)

        # Path origins of pruned inputs must have no predecessors ...
        written_before = any(
            graph.in_degree(graph.memlet_path(edge)[0].src) > 0
            for edge in graph.in_edges(node) if edge.dst_conn in unused_in)
        # ... and path ends of pruned outputs must have no successors
        read_after = any(
            graph.out_degree(graph.memlet_path(edge)[-1].dst) > 0
            for edge in graph.out_edges(node) if edge.src_conn in unused_out)
        if written_before or read_after:
            return False

        return len(unused_in) > 0 or len(unused_out) > 0

    def apply(self, sdfg: SDFG) -> Union[Any, None]:
        """
        Removes the memlet paths of all prunable connectors and purges the
        matching data descriptors from the nested SDFG if they are unused.
        """
        state = sdfg.node(self.state_id)
        node = self.nsdfg(sdfg)
        inner_sdfg = node.sdfg

        reads, writes = inner_sdfg.read_and_write_sets()

        # Detect which nodes are used, so we can delete unused nodes after the
        # connectors have been pruned
        used_data = reads | writes

        unused_in, unused_out = PruneConnectors._prunable_connectors(
            state, node, reads, writes)

        for conn in unused_in:
            for edge in state.in_edges_by_connector(node, conn):
                state.remove_memlet_path(edge, remove_orphans=True)
                # If the data is now unused, we can purge it from the SDFG
                if conn not in used_data and conn in inner_sdfg.arrays:
                    inner_sdfg.remove_data(conn)

        for conn in unused_out:
            for edge in state.out_edges_by_connector(node, conn):
                state.remove_memlet_path(edge, remove_orphans=True)
                # If the data is now unused, we can purge it from the SDFG
                if conn not in used_data and conn in inner_sdfg.arrays:
                    inner_sdfg.remove_data(conn)
예제 #15
0
class MapFusion(transformation.Transformation):
    """ Implements the MapFusion transformation.
        It will check for all patterns MapExit -> AccessNode -> MapEntry, and
        based on the following rules, fuse them and remove the transient in
        between. There are several possibilities of what it does to this
        transient in between.

        Essentially, if there is some other place in the
        sdfg where it is required, or if it is not a transient, then it will
        not be removed. In such a case, it will be linked to the MapExit node
        of the new fused map.

        Rules for fusing maps:
          0. The map range of the second map should be a permutation of the
             first map range.
          1. Each of the access nodes that are adjacent to the first map exit
             should have an edge to the second map entry. If it doesn't, then the
             second map entry should not be reachable from this access node.
          2. Any node that has a wcr from the first map exit should not be
             adjacent to the second map entry.
          3. Access pattern for the access nodes in the second map should be
             the same permutation of the map parameters as the map ranges of the
             two maps. Alternatively, this access node should not be adjacent to
             the first map entry.
    """
    first_map_exit = transformation.PatternNode(nodes.ExitNode)
    array = transformation.PatternNode(nodes.AccessNode)
    second_map_entry = transformation.PatternNode(nodes.EntryNode)

    @staticmethod
    def annotates_memlets():
        """ Always returns False for this transformation.

            NOTE(review): presumably this tells the transformation framework
            that memlets touched by ``apply`` are not annotated and need to
            be propagated afterwards -- confirm against the base class.
        """
        return False

    @staticmethod
    def expressions():
        """ Returns the single pattern this transformation matches: a path
            from the first map exit through the access node to the second
            map entry.
        """
        exit_access_entry = sdutil.node_path_graph(MapFusion.first_map_exit,
                                                   MapFusion.array,
                                                   MapFusion.second_map_entry)
        return [exit_access_entry]

    @staticmethod
    def find_permutation(first_map: nodes.Map,
                         second_map: nodes.Map) -> Union[List[int], None]:
        """ Find permutation between two map ranges.
            :param first_map: First map.
            :param second_map: Second map.
            :return: None if no such permutation exists, otherwise a list of
                     indices L such that L[x]'th parameter of second map has
                     the same range as x'th parameter of the first map.
            """
        if len(first_map.range) != len(second_map.range):
            return None

        permutation = []
        for first_rng in first_map.range:
            # First not-yet-used dimension of the second map with this range
            match = next(
                (idx for idx, second_rng in enumerate(second_map.range)
                 if first_rng == second_rng and idx not in permutation), None)
            if match is None:
                # This range has no counterpart, so no permutation exists
                return None
            permutation.append(match)

        return permutation

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Checks whether the two maps around the matched access node can
            be fused.

            :param graph: State in which the pattern was matched.
            :param candidate: Mapping from pattern nodes to node indices in
                              ``graph``.
            :param expr_index: Index of the matched pattern expression
                               (unused here).
            :param sdfg: SDFG that contains ``graph``.
            :param strict: Unused by this check.
            :return: True if every fusion condition holds, False otherwise.
        """
        first_map_exit = graph.nodes()[candidate[MapFusion.first_map_exit]]
        first_map_entry = graph.entry_node(first_map_exit)
        second_map_entry = graph.nodes()[candidate[MapFusion.second_map_entry]]
        second_map_exit = graph.exit_node(second_map_entry)

        for _in_e in graph.in_edges(first_map_exit):
            if _in_e.data.wcr is not None:
                for _out_e in graph.out_edges(second_map_entry):
                    if _out_e.data.data == _in_e.data.data:
                        # wcr is on a node that is used in the second map, quit
                        return False
        # Check whether there is a pattern map -> access -> map.
        intermediate_nodes = set()
        intermediate_data = set()
        for _, _, dst, _, _ in graph.out_edges(first_map_exit):
            if isinstance(dst, nodes.AccessNode):
                intermediate_nodes.add(dst)
                intermediate_data.add(dst.data)

                # If the array is accessed by any other node, in any state of
                # the SDFG, fail.
                num_occurrences = len([
                    n for s in sdfg.nodes() for n in s.nodes()
                    if isinstance(n, nodes.AccessNode) and n.data == dst.data
                ])
                if num_occurrences > 1:
                    return False
            else:
                return False
        # Check map ranges
        perm = MapFusion.find_permutation(first_map_entry.map,
                                          second_map_entry.map)
        if perm is None:
            return False

        # Check if any intermediate transient is also going to another location
        second_inodes = set(e.src for e in graph.in_edges(second_map_entry)
                            if isinstance(e.src, nodes.AccessNode))
        transients_to_remove = intermediate_nodes & second_inodes
        if any(graph.out_degree(n) > 1 for n in transients_to_remove):
            return False

        # Create a dict that maps each parameter of the second map to the
        # first-map parameter that has the same range. ``perm[x]`` is the
        # second-map index matching first-map index ``x``, so the first-map
        # index that belongs to second-map index ``y`` is ``perm.index(y)``
        # (the inverse permutation). This is the same pairing that ``apply``
        # uses; indexing with ``perm[y]`` directly is only correct when the
        # permutation is its own inverse.
        params_dict = {}
        for _index, _param in enumerate(second_map_entry.map.params):
            params_dict[_param] = first_map_entry.map.params[perm.index(
                _index)]
        # Create intermediate dicts to avoid conflicts, such as {i:j, j:i}
        repldict = {
            symbolic.pystr_to_symbolic(k):
            symbolic.pystr_to_symbolic('__dacesym_' + str(v))
            for k, v in params_dict.items()
        }
        repldict_inv = {
            symbolic.pystr_to_symbolic('__dacesym_' + str(v)):
            symbolic.pystr_to_symbolic(v)
            for v in params_dict.values()
        }

        out_memlets = [e.data for e in graph.in_edges(first_map_exit)]

        # Check that input set of second map is provided by the output set
        # of the first map, or other unrelated maps
        for second_edge in graph.out_edges(second_map_entry):
            # Memlets that do not come from one of the intermediate arrays
            if second_edge.data.data not in intermediate_data:
                # however, if intermediate_data eventually leads to
                # second_memlet.data, need to fail.
                for _n in intermediate_nodes:
                    source_node = _n
                    destination_node = graph.memlet_path(second_edge)[0].src
                    # NOTE: Assumes graph has networkx version
                    if destination_node in nx.descendants(
                            graph._nx, source_node):
                        return False
                continue

            provided = False

            # Compute second subset with respect to first subset's symbols
            sbs_permuted = dcpy(second_edge.data.subset)
            if sbs_permuted:
                sbs_permuted.replace(repldict)
                sbs_permuted.replace(repldict_inv)

            for first_memlet in out_memlets:
                if first_memlet.data != second_edge.data.data:
                    continue

                # If there is a covered subset, it is provided
                if first_memlet.subset.covers(sbs_permuted):
                    provided = True
                    break

            # If none of the output memlets of the first map provide the info,
            # fail.
            if not provided:
                return False

        def _accesses_differ(a, b):
            """ Returns True if offsetting subset ``a`` by ``b`` leaves any
                dimension other than ``(0, 0, 1)``. """
            if isinstance(a, subsets.Indices):
                c = subsets.Range.from_indices(a)
                c.offset(b, negative=True)
            else:
                c = a.offset_new(b, negative=True)
            return any(r != (0, 0, 1) for r in c)

        # Checking for stencil pattern and common input/output data
        # (after fusing the maps)
        first_map_inputnodes = {
            e.src: e.src.data
            for e in graph.in_edges(first_map_entry)
            if isinstance(e.src, nodes.AccessNode)
        }
        input_views = set()
        viewed_inputnodes = dict()
        for n in first_map_inputnodes.keys():
            if isinstance(n.desc(sdfg), data.View):
                input_views.add(n)
        # Replace each input view by the node it views (if one is found)
        for v in input_views:
            del first_map_inputnodes[v]
            e = sdutil.get_view_edge(graph, v)
            if e:
                first_map_inputnodes[e.src] = e.src.data
                viewed_inputnodes[e.src.data] = v
        second_map_outputnodes = {
            e.dst: e.dst.data
            for e in graph.out_edges(second_map_exit)
            if isinstance(e.dst, nodes.AccessNode)
        }
        output_views = set()
        viewed_outputnodes = dict()
        for n in second_map_outputnodes:
            if isinstance(n.desc(sdfg), data.View):
                output_views.add(n)
        # Replace each output view by the node it views (if one is found)
        for v in output_views:
            del second_map_outputnodes[v]
            e = sdutil.get_view_edge(graph, v)
            if e:
                second_map_outputnodes[e.dst] = e.dst.data
                viewed_outputnodes[e.dst.data] = v
        common_data = set(first_map_inputnodes.values()).intersection(
            set(second_map_outputnodes.values()))
        if common_data:
            input_data = [
                viewed_inputnodes[d].data if d in viewed_inputnodes else d
                for d in common_data
            ]
            input_accesses = [
                graph.memlet_path(e)[-1].data.src_subset
                for e in graph.out_edges(first_map_entry)
                if e.data.data in input_data
            ]
            # All reads of the common data inside the first map must use the
            # same subset
            for i, a in enumerate(input_accesses[:-1]):
                for b in input_accesses[i + 1:]:
                    if _accesses_differ(a, b):
                        return False

            output_data = [
                viewed_outputnodes[d].data if d in viewed_outputnodes else d
                for d in common_data
            ]
            output_accesses = [
                graph.memlet_path(e)[0].data.dst_subset
                for e in graph.in_edges(second_map_exit)
                if e.data.data in output_data
            ]

            # Compute output accesses with respect to first map's symbols
            oacc_permuted = [dcpy(a) for a in output_accesses]
            for a in oacc_permuted:
                a.replace(repldict)
                a.replace(repldict_inv)

            # Every write in the second map must match the (common) read
            # subset of the first map
            a = input_accesses[0]
            for b in oacc_permuted:
                if _accesses_differ(a, b):
                    return False

        # Success
        return True

    @staticmethod
    def match_to_str(graph, candidate):
        """ Builds a readable description of a match in the form
            ``<label>: <params> -> <label>: <params>`` for the first map
            exit and the second map entry.
        """
        all_nodes = graph.nodes()
        matched_scopes = (
            all_nodes[candidate[MapFusion.first_map_exit]],
            all_nodes[candidate[MapFusion.second_map_entry]],
        )

        descriptions = []
        for scope_node in matched_scopes:
            descriptions.append(scope_node.map.label + ": " +
                                str(scope_node.map.params))
        return " -> ".join(descriptions)

    def apply(self, sdfg):
        """
            This method applies the mapfusion transformation.
            Other than the removal of the second map entry node (SME), and the first
            map exit (FME) node, it has the following side effects:

            1.  Any transient adjacent to both FME and SME with degree = 2 will be removed.
                The tasklets that use/produce it shall be connected directly with a
                scalar/new transient (if the dataflow is more than a single scalar)

            2.  If this transient is adjacent to FME and SME and has other
                uses, it will be adjacent to the new map exit post fusion.
                Tasklet-> Tasklet edges will ALSO be added as mentioned above.

            3.  If an access node is adjacent to FME but not SME, it will be
                adjacent to new map exit post fusion.

            4.  If an access node is adjacent to SME but not FME, it will be
                adjacent to the new map entry node post fusion.

            :param sdfg: The SDFG containing the state in which the pattern
                         was matched.
        """
        graph: SDFGState = sdfg.nodes()[self.state_id]
        first_exit = graph.nodes()[self.subgraph[MapFusion.first_map_exit]]
        first_entry = graph.entry_node(first_exit)
        second_entry = graph.nodes()[self.subgraph[MapFusion.second_map_entry]]
        second_exit = graph.exit_node(second_entry)

        # Every destination of the first map exit must be an access node
        intermediate_nodes = set()
        for _, _, dst, _, _ in graph.out_edges(first_exit):
            intermediate_nodes.add(dst)
            assert isinstance(dst, nodes.AccessNode)

        # Check if an access node refers to non transient memory, or transient
        # is used at another location (cannot erase)
        do_not_erase = set()
        for node in intermediate_nodes:
            if sdfg.arrays[node.data].transient is False:
                do_not_erase.add(node)
            else:
                for edge in graph.in_edges(node):
                    if edge.src != first_exit:
                        do_not_erase.add(node)
                        break
                else:
                    # for-else: only reached if all incoming edges come from
                    # the first map exit; now check the outgoing side
                    for edge in graph.out_edges(node):
                        if edge.dst != second_entry:
                            do_not_erase.add(node)
                            break

        # Find permutation between first and second scopes
        perm = MapFusion.find_permutation(first_entry.map, second_entry.map)
        # Maps each first-map parameter to the second-map parameter that has
        # the same range
        params_dict = {}
        for index, param in enumerate(first_entry.map.params):
            params_dict[param] = second_entry.map.params[perm[index]]

        # Replaces (in memlets and tasklet) the second scope map
        # indices with the permuted first map indices.
        # This works in two passes to avoid problems when e.g., exchanging two
        # parameters (instead of replacing (j,i) and (i,j) to (j,j) and then
        # i,i).
        second_scope = graph.scope_subgraph(second_entry)
        for firstp, secondp in params_dict.items():
            if firstp != secondp:
                replace(second_scope, secondp, '__' + secondp + '_fused')
        for firstp, secondp in params_dict.items():
            if firstp != secondp:
                replace(second_scope, '__' + secondp + '_fused', firstp)

        # Isolate First exit node
        ############################
        # Removals are collected here and performed after the loop
        edges_to_remove = set()
        nodes_to_remove = set()
        for edge in graph.in_edges(first_exit):
            tree = graph.memlet_tree(edge)
            access_node = tree.root().edge.dst
            if access_node not in do_not_erase:
                out_edges = [
                    e for e in graph.out_edges(access_node)
                    if e.dst == second_entry
                ]
                # In this transformation, there can only be one edge to the
                # second map
                assert len(out_edges) == 1

                # Get source connector to the second map
                # (drops the first three characters of the connector name,
                # which matches the 'IN_' prefix used when adding connectors
                # below)
                connector = out_edges[0].dst_conn[3:]

                new_dsts = []
                # Look at the second map entry out-edges to get the new
                # destinations
                # (the first four characters dropped here match the 'OUT_'
                # prefix)
                for e in graph.out_edges(second_entry):
                    if e.src_conn[4:] == connector:
                        new_dsts.append(e)
                if not new_dsts:  # Access node is not used in the second map
                    nodes_to_remove.add(access_node)
                    continue

                # If the source is an access node, modify the memlet to point
                # to it
                if (isinstance(edge.src, nodes.AccessNode)
                        and edge.data.data != edge.src.data):
                    edge.data.data = edge.src.data
                    edge.data.subset = ("0" if edge.data.other_subset is None
                                        else edge.data.other_subset)
                    edge.data.other_subset = None

                else:
                    # Add a transient scalar/array
                    self.fuse_nodes(sdfg, graph, edge, new_dsts[0].dst,
                                    new_dsts[0].dst_conn, new_dsts[1:])

                edges_to_remove.add(edge)

                # Remove transient node between the two maps
                nodes_to_remove.add(access_node)
            else:  # The case where intermediate array node cannot be removed
                # Node will become an output of the second map exit
                out_e = tree.parent.edge
                conn = second_exit.next_connector()
                graph.add_edge(
                    second_exit,
                    'OUT_' + conn,
                    out_e.dst,
                    out_e.dst_conn,
                    dcpy(out_e.data),
                )
                second_exit.add_out_connector('OUT_' + conn)

                graph.add_edge(edge.src, edge.src_conn, second_exit,
                               'IN_' + conn, dcpy(edge.data))
                second_exit.add_in_connector('IN_' + conn)

                edges_to_remove.add(out_e)
                edges_to_remove.add(edge)

                # If the second map needs this node, link the connector
                # that generated this to the place where it is needed, with a
                # temp transient/scalar for memlet to be generated
                for out_e in graph.out_edges(second_entry):
                    second_memlet_path = graph.memlet_path(out_e)
                    source_node = second_memlet_path[0].src
                    if source_node == access_node:
                        self.fuse_nodes(sdfg, graph, edge, out_e.dst,
                                        out_e.dst_conn)

        ###
        # First scope exit is isolated and can now be safely removed
        for e in edges_to_remove:
            graph.remove_edge(e)
        graph.remove_nodes_from(nodes_to_remove)
        graph.remove_node(first_exit)

        # Isolate second_entry node
        ###########################
        for edge in graph.in_edges(second_entry):
            tree = graph.memlet_tree(edge)
            access_node = tree.root().edge.src
            if access_node in intermediate_nodes:
                # Already handled above, can be safely removed
                graph.remove_edge(edge)
                continue

            # This is an external input to the second map which will now go
            # through the first map.
            conn = first_entry.next_connector()
            graph.add_edge(edge.src, edge.src_conn, first_entry, 'IN_' + conn,
                           dcpy(edge.data))
            first_entry.add_in_connector('IN_' + conn)
            graph.remove_edge(edge)
            # Re-attach everything that hung off the second map entry to the
            # new connector on the first map entry
            for out_enode in tree.children:
                out_e = out_enode.edge
                graph.add_edge(
                    first_entry,
                    'OUT_' + conn,
                    out_e.dst,
                    out_e.dst_conn,
                    dcpy(out_e.data),
                )
                graph.remove_edge(out_e)
            first_entry.add_out_connector('OUT_' + conn)

        ###
        # Second node is isolated and can now be safely removed
        graph.remove_node(second_entry)

        # Fix scope exit to point to the right map
        second_exit.map = first_entry.map

    def fuse_nodes(self,
                   sdfg,
                   graph,
                   edge,
                   new_dst,
                   new_dst_conn,
                   other_edges=None):
        """ Fuses two nodes via memlets and possibly transient arrays.

            A new transient named after the state and the endpoints of
            ``edge`` is registered in ``sdfg``. It is a scalar when the
            memlet on ``edge`` has exactly one element, otherwise an array
            with the size of the memlet subset. The memlet on ``edge`` is
            modified in place to refer to that transient.

            :param sdfg: SDFG in which the new transient is registered.
            :param graph: State whose edges are rewired.
            :param edge: Edge whose source produces the data to forward.
            :param new_dst: Node that should receive the data.
            :param new_dst_conn: Connector on ``new_dst`` to connect to.
            :param other_edges: Optional further edges whose destinations
                                (``dst``/``dst_conn``) also receive the data.
        """
        other_edges = other_edges or []
        memlet_path = graph.memlet_path(edge)
        # Last node on the memlet path; only its dtype is used below
        access_node = memlet_path[-1].dst

        # Name derived from the state id and both endpoints of the edge
        local_name = "__s%d_n%d%s_n%d%s" % (
            self.state_id,
            graph.node_id(edge.src),
            edge.src_conn,
            graph.node_id(edge.dst),
            edge.dst_conn,
        )
        # Add intermediate memory between subgraphs. If a scalar,
        # uses direct connection. If an array, adds a transient node
        if edge.data.subset.num_elements() == 1:
            sdfg.add_scalar(
                local_name,
                dtype=access_node.desc(graph).dtype,
                transient=True,
                storage=dtypes.StorageType.Register,
            )
            edge.data.data = local_name
            edge.data.subset = "0"

            # If source of edge leads to multiple destinations,
            # redirect all through an access node
            out_edges = list(
                graph.out_edges_by_connector(edge.src, edge.src_conn))
            if len(out_edges) > 1:
                local_node = graph.add_access(local_name)
                src_connector = None

                # Add edge that leads to transient node
                graph.add_edge(edge.src, edge.src_conn, local_node, None,
                               dcpy(edge.data))

                # Move every sibling edge so that it starts at the new
                # access node (the memlet object itself is reused)
                for other_edge in out_edges:
                    if other_edge is not edge:
                        graph.remove_edge(other_edge)
                        graph.add_edge(local_node, src_connector,
                                       other_edge.dst, other_edge.dst_conn,
                                       other_edge.data)
            else:
                # Single consumer: connect directly from the producing node
                local_node = edge.src
                src_connector = edge.src_conn

            # Add edge that leads to the second node
            graph.add_edge(local_node, src_connector, new_dst, new_dst_conn,
                           dcpy(edge.data))

            for e in other_edges:
                graph.add_edge(local_node, src_connector, e.dst, e.dst_conn,
                               dcpy(edge.data))
        else:
            sdfg.add_transient(local_name,
                               edge.data.subset.size(),
                               dtype=access_node.desc(graph).dtype)
            # Keep a copy with the original subset; it is the offset applied
            # to the surrounding memlets at the end
            old_edge = dcpy(edge)
            local_node = graph.add_access(local_name)
            src_connector = None
            edge.data.data = local_name
            # The new memlet spans the whole transient: "0:size" per dimension
            edge.data.subset = ",".join(
                ["0:" + str(s) for s in edge.data.subset.size()])
            # Add edge that leads to transient node
            graph.add_edge(
                edge.src,
                edge.src_conn,
                local_node,
                None,
                dcpy(edge.data),
            )

            # Add edge that leads to the second node
            graph.add_edge(local_node, src_connector, new_dst, new_dst_conn,
                           dcpy(edge.data))

            for e in other_edges:
                graph.add_edge(local_node, src_connector, e.dst, e.dst_conn,
                               dcpy(edge.data))

            # Modify data and memlets on all surrounding edges to match array
            for neighbor in graph.all_edges(local_node):
                for e in graph.memlet_tree(neighbor):
                    e.data.data = local_name
                    e.data.subset.offset(old_edge.data.subset, negative=True)
예제 #16
0
class PruneSymbols(pm.Transformation):
    """
    Removes unused symbol mappings from nested SDFGs, as well as internal
    symbols if necessary.
    """

    nsdfg = pm.PatternNode(nodes.NestedSDFG)

    @staticmethod
    def expressions():
        """ Returns the single pattern to match: one nested SDFG node. """
        single_node_pattern = utils.node_path_graph(PruneSymbols.nsdfg)
        return [single_node_pattern]

    @staticmethod
    def _candidates(nsdfg: nodes.NestedSDFG) -> Set[str]:
        """
        Collects the names in the symbol mapping of a nested SDFG that are
        not used by any data descriptor, state, inter-state condition or
        inter-state assignment before being reassigned.

        :param nsdfg: The nested SDFG node to inspect.
        :return: Set of symbol names that can be removed from the mapping.
        """
        candidates = set(nsdfg.symbol_mapping.keys())
        if not candidates:
            return set()

        # Symbols that appear in data descriptors are in use
        for desc in nsdfg.sdfg.arrays.values():
            candidates -= {str(sym) for sym in desc.free_symbols}

        # Symbols that were reassigned on every path taken so far; later uses
        # of these names do not count as uses of the mapped symbol
        ignore = set()
        for nstate in cfg.stateorder_topological_sort(nsdfg.sdfg):
            state_syms = nstate.free_symbols

            # Try to be conservative with C++ tasklets: any candidate whose
            # name occurs as a whole word in the code counts as used
            for node in nstate.nodes():
                if not isinstance(node, nodes.Tasklet):
                    continue
                if node.language is not dtypes.Language.CPP:
                    continue
                state_syms.update(
                    name for name in candidates
                    if re.search(r'\b%s\b' % re.escape(name),
                                 node.code.as_string))

            # Any symbol used in this state is considered used
            candidates -= (state_syms - ignore)
            if not candidates:
                return set()

            # Any symbol that is set in all outgoing edges is ignored from
            # this point
            local_ignore = None
            for isedge in nsdfg.sdfg.out_edges(nstate):
                # Look for symbols in condition
                condition_syms = {
                    str(sym)
                    for sym in symbolic.symbols_in_ast(
                        isedge.data.condition.code[0])
                }
                candidates -= (condition_syms - ignore)

                # Look for symbols on the right-hand side of assignments
                for assign in isedge.data.assignments.values():
                    candidates -= (
                        symbolic.free_symbols_and_functions(assign) - ignore)

                if local_ignore is None:
                    local_ignore = set(isedge.data.assignments.keys())
                else:
                    local_ignore &= isedge.data.assignments.keys()
            if local_ignore is not None:
                ignore |= local_ignore

        return candidates

    @staticmethod
    def can_be_applied(graph: Union[SDFG, SDFGState],
                       candidate: Dict[pm.PatternNode, int],
                       expr_index: int,
                       sdfg: SDFG,
                       strict: bool = False) -> bool:
        """
        Reports whether the matched nested SDFG has at least one symbol
        mapping that can be pruned.
        """
        nsdfg: nodes.NestedSDFG = graph.node(candidate[PruneSymbols.nsdfg])
        return len(PruneSymbols._candidates(nsdfg)) > 0

    def apply(self, sdfg: SDFG) -> Union[Any, None]:
        """
        Deletes every prunable entry from the symbol mapping of the matched
        nested SDFG.

        :param sdfg: The SDFG containing the matched nested SDFG node.
        """
        nsdfg = self.nsdfg(sdfg)

        for unused in PruneSymbols._candidates(nsdfg):
            del nsdfg.symbol_mapping[unused]

            # If not used in SDFG, remove from symbols as well
            if helpers.is_symbol_unused(nsdfg.sdfg, unused):
                nsdfg.sdfg.remove_symbol(unused)
예제 #17
0
class BufferTiling(transformation.SingleStateTransformation):
    """ Implements the buffer tiling transformation.

        BufferTiling tiles a buffer that is in between two maps, where the preceding map
        writes to the buffer and the succeeding map reads from it.
        It introduces additional computations in exchange for reduced memory footprint.
        Commonly used to make use of shared memory on GPUs.
    """

    map1_exit = transformation.PatternNode(nodes.MapExit)
    array = transformation.PatternNode(nodes.AccessNode)
    map2_entry = transformation.PatternNode(nodes.MapEntry)

    tile_sizes = ShapeProperty(dtype=tuple,
                               default=(128, 128, 128),
                               desc="Tile size per dimension")

    # Returns a list of graphs that represent the pattern
    @classmethod
    def expressions(cls):
        """ Returns the pattern map exit -> access node -> map entry. """
        return [
            sdutil.node_path_graph(cls.map1_exit, cls.array, cls.map2_entry)
        ]

    def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
        """ Checks that every node between the two maps is a buffer that can
            be tiled.

            :param graph: State in which the pattern was matched.
            :param expr_index: Index of the matched pattern expression
                               (unused here).
            :param sdfg: SDFG that contains ``graph``.
            :param permissive: Unused by this check.
            :return: True if all nodes between the first map exit and the
                     second map entry pass every check below.
        """
        map1_exit = self.map1_exit
        map2_entry = self.map2_entry

        for buf in graph.all_nodes_between(map1_exit, map2_entry):
            # Check that buffers are AccessNodes.
            if not isinstance(buf, nodes.AccessNode):
                return False

            # Check that buffers are transient.
            if not sdfg.arrays[buf.data].transient:
                return False

            # Check that buffers have exactly 1 input and 1 output edge.
            if graph.in_degree(buf) != 1:
                return False
            if graph.out_degree(buf) != 1:
                return False
            in_edge = graph.in_edges(buf)[0]
            out_edge = graph.out_edges(buf)[0]

            # Check that buffers are next to the maps.
            if in_edge.src != map1_exit:
                return False
            if out_edge.dst != map2_entry:
                return False

            # Check that the data consumed is provided.
            provided = in_edge.data.subset
            consumed = out_edge.data.subset
            if not provided.covers(consumed):
                return False

            # Check that buffers occur only once in this state. The data
            # names must be compared; comparing a name against the node
            # object itself never matches, which disabled this check.
            num_occurrences = sum(
                1 for n in graph.nodes()
                if isinstance(n, nodes.AccessNode) and n.data == buf.data)
            if num_occurrences > 1:
                return False
        return True

    def apply(self, graph, sdfg):
        """ Tiles both maps, fuses the resulting tile maps and simplifies
            trivial ranges.

            :param graph: State in which the pattern was matched.
            :param sdfg: SDFG that contains ``graph``.
        """
        map1_exit = self.map1_exit
        map1_entry = graph.entry_node(map1_exit)
        map2_entry = self.map2_entry
        buffers = graph.all_nodes_between(map1_exit, map2_entry)
        # Situation:
        # -> map1_entry -> ... -> map1_exit -> buffers -> map2_entry -> ...

        # Per dimension: how far the first map's range extends below/above
        # the range of the second map
        lower_extents = tuple(b - a for a, b in zip(
            map1_entry.range.min_element(), map2_entry.range.min_element()))
        upper_extents = tuple(a - b for a, b in zip(
            map1_entry.range.max_element(), map2_entry.range.max_element()))

        # Tile the first map with overlap
        MapTilingWithOverlap.apply_to(sdfg,
                                      map_entry=map1_entry,
                                      options={
                                          'tile_sizes': self.tile_sizes,
                                          'lower_overlap': lower_extents,
                                          'upper_overlap': upper_extents
                                      })
        # The node following the original map exit is taken as the exit of
        # the new tile map
        tile_map1_exit = graph.out_edges(map1_exit)[0].dst
        tile_map1_entry = graph.entry_node(tile_map1_exit)
        tile_map1_entry.label = 'BufferTiling'

        # Tile the second map
        MapTiling.apply_to(sdfg,
                           map_entry=map2_entry,
                           options={
                               'tile_sizes': self.tile_sizes,
                               'tile_trivial': True
                           })
        # The node preceding the original map entry is taken as the entry of
        # the new tile map
        tile_map2_entry = graph.in_edges(map2_entry)[0].src

        # Fuse maps
        some_buffer = next(
            iter(buffers))  # some dummy to pass to MapFusion.apply_to()
        MapFusion.apply_to(sdfg,
                           first_map_exit=tile_map1_exit,
                           array=some_buffer,
                           second_map_entry=tile_map2_entry)

        # Optimize the simple cases: with a tile size of 1 (and, for the
        # first map, no overlap) the range of that dimension is collapsed to
        # its start element
        map1_entry.range.ranges = [
            (r[0], r[0], r[2]) if l_ext == 0 and u_ext == 0 and ts == 1 else r
            for r, l_ext, u_ext, ts in zip(map1_entry.range.ranges,
                                           lower_extents, upper_extents,
                                           self.tile_sizes)
        ]

        map2_entry.range.ranges = [
            (r[0], r[0], r[2]) if ts == 1 else r
            for r, ts in zip(map2_entry.range.ranges, self.tile_sizes)
        ]

        # Maps that now contain a single-element dimension are handed to
        # TrivialMapElimination
        if any(ts == 1 for ts in self.tile_sizes):
            if any(r[0] == r[1] for r in map1_entry.map.range):
                TrivialMapElimination.apply_to(sdfg, map_entry=map1_entry)
            if any(r[0] == r[1] for r in map2_entry.map.range):
                TrivialMapElimination.apply_to(sdfg, map_entry=map2_entry)
예제 #18
0
class InlineTransients(transformation.Transformation):
    """
    Inlines all transient arrays that are not used anywhere else into a
    nested SDFG.

    A transient of the outer SDFG qualifies when it has ``Scope`` allocation
    lifetime, is attached to the nested SDFG node through a single connector
    name, has no access node in any other state, and within the matched state
    only appears on access nodes that terminate a memlet path of the nested
    SDFG node (no incoming edges on the read side, no outgoing edges on the
    write side).
    """

    nsdfg = transformation.PatternNode(nodes.NestedSDFG)

    @staticmethod
    def annotates_memlets():
        return True

    @staticmethod
    def expressions():
        return [sdutil.node_path_graph(InlineTransients.nsdfg)]

    @staticmethod
    def _candidates(sdfg: SDFG, graph: SDFGState,
                    nsdfg: nodes.NestedSDFG) -> Dict[str, str]:
        """
        Collects the transients that can be moved into the nested SDFG.

        :param sdfg: The SDFG that contains ``graph``.
        :param graph: The state that contains the nested SDFG node.
        :param nsdfg: The nested SDFG node to inspect.
        :return: A dictionary mapping outer data names to the connector name
                 on ``nsdfg`` they are attached to (empty if none qualify).
        """
        candidates = {}
        # Transients that were seen on more than one connector. They must
        # stay excluded even if yet another edge carrying them shows up
        # later; otherwise that later edge would silently re-admit them.
        rejected = set()
        for e in graph.all_edges(nsdfg):
            if e.data.is_empty():
                continue
            dname = e.data.data
            if dname in rejected:
                continue
            conn = (e.src_conn if e.src is nsdfg else e.dst_conn)
            desc = sdfg.arrays[dname]
            # Needs to be transient
            if not desc.transient:
                continue
            # Needs to be allocated in "Scope" lifetime
            if desc.lifetime is not dtypes.AllocationLifetime.Scope:
                continue
            # If same transient is connected with multiple connectors, bail
            # for now
            if dname in candidates and candidates[dname] != conn:
                del candidates[dname]
                rejected.add(dname)
                continue
            # NOTE: A check that the memlet covers the entire data descriptor
            # is intentionally not performed here (skipped due to the above
            # check for multiple connectors).
            candidates[dname] = conn

        if not candidates:
            return candidates

        # Check for uses in other states
        for state in sdfg.nodes():
            if state is graph:
                continue
            for node in state.data_nodes():
                candidates.pop(node.data, None)

        if not candidates:
            return candidates

        # Check for uses in state: only access nodes that sit at the far end
        # of a memlet path of the nested SDFG node (and have no edges on
        # their other side) are allowed to refer to a candidate.
        access_nodes = set()
        for e in graph.in_edges(nsdfg):
            src = graph.memlet_path(e)[0].src
            if isinstance(src, nodes.AccessNode) and graph.in_degree(src) == 0:
                access_nodes.add(src)
        for e in graph.out_edges(nsdfg):
            dst = graph.memlet_path(e)[-1].dst
            if isinstance(dst,
                          nodes.AccessNode) and graph.out_degree(dst) == 0:
                access_nodes.add(dst)
        for node in graph.data_nodes():
            if node.data in candidates and node not in access_nodes:
                del candidates[node.data]

        return candidates

    @staticmethod
    def can_be_applied(graph: SDFGState,
                       candidate: Dict[transformation.PatternNode, int],
                       expr_index: int,
                       sdfg: SDFG,
                       strict: bool = False):
        """
        Returns True if at least one transient can be inlined.

        :param graph: The state in which the pattern was matched.
        :param candidate: Mapping from pattern nodes to node IDs in ``graph``.
        :param expr_index: Index of the matched expression (unused).
        :param sdfg: The SDFG that contains ``graph``.
        :param strict: If True, only nested SDFGs with a Default, Sequential,
                       CPU_Multicore or GPU_Device schedule are accepted.
        """
        nsdfg = graph.node(candidate[InlineTransients.nsdfg])

        # Not every schedule is supported
        if strict:
            if nsdfg.schedule not in (dtypes.ScheduleType.Default,
                                      dtypes.ScheduleType.Sequential,
                                      dtypes.ScheduleType.CPU_Multicore,
                                      dtypes.ScheduleType.GPU_Device):
                return False

        candidates = InlineTransients._candidates(sdfg, graph, nsdfg)
        return len(candidates) > 0

    @staticmethod
    def match_to_str(graph, candidate):
        return graph.label

    def apply(self, sdfg):
        """
        Moves every candidate transient into the nested SDFG: the inner
        descriptor becomes transient, the connector and the outer descriptor
        are removed, and the outer memlet trees and their access nodes are
        deleted from the state.
        """
        state: SDFGState = sdfg.nodes()[self.state_id]
        nsdfg_node: nodes.NestedSDFG = self.nsdfg(sdfg)
        nsdfg: SDFG = nsdfg_node.sdfg
        toremove = InlineTransients._candidates(sdfg, state, nsdfg_node)

        for dname, cname in toremove.items():
            # Make nested SDFG data descriptors transient
            nsdfg.arrays[cname].transient = True

            # Remove connectors from node
            nsdfg_node.remove_in_connector(cname)
            nsdfg_node.remove_out_connector(cname)

            # Remove data descriptor from outer SDFG
            del sdfg.arrays[dname]

        # Remove edges from outer SDFG
        for e in state.in_edges(nsdfg_node):
            if e.data.data not in toremove:
                continue
            tree = state.memlet_tree(e)
            for te in tree:
                state.remove_edge_and_connectors(te)
            # Remove newly isolated node
            state.remove_node(tree.root().edge.src)

        for e in state.out_edges(nsdfg_node):
            if e.data.data not in toremove:
                continue
            tree = state.memlet_tree(e)
            for te in tree:
                state.remove_edge_and_connectors(te)
            # Remove newly isolated node
            state.remove_node(tree.root().edge.dst)
예제 #19
0
파일: mapreduce.py 프로젝트: am-ivanov/dace
class MapWCRFusion(pm.SingleStateTransformation):
    """ Implements the map expanded-reduce fusion transformation.

        The pattern is a map writing to an array that is immediately consumed
        by a reduction expressed as two nested maps with a write-conflict
        resolution (WCR) memlet, i.e., a partial reduction. The intermediate
        array must not be used anywhere else. Applying the transformation
        collapses the two reduction maps and fuses the result with the first
        map.
    """

    tasklet = pm.PatternNode(nodes.Tasklet)
    tmap_exit = pm.PatternNode(nodes.MapExit)
    in_array = pm.PatternNode(nodes.AccessNode)
    rmap_in_entry = pm.PatternNode(nodes.MapEntry)
    rmap_in_tasklet = pm.PatternNode(nodes.Tasklet)
    rmap_in_cr = pm.PatternNode(nodes.MapExit)
    rmap_out_entry = pm.PatternNode(nodes.MapEntry)
    rmap_out_exit = pm.PatternNode(nodes.MapExit)
    out_array = pm.PatternNode(nodes.AccessNode)

    @classmethod
    def expressions(cls):
        # Map, then partial reduction of axes
        pattern = sdutil.node_path_graph(cls.tasklet, cls.tmap_exit,
                                         cls.in_array, cls.rmap_out_entry,
                                         cls.rmap_in_entry,
                                         cls.rmap_in_tasklet, cls.rmap_in_cr,
                                         cls.rmap_out_exit, cls.out_array)
        return [pattern]

    def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
        first_exit = self.tmap_exit
        mid_array = self.in_array
        reduce_entry = self.rmap_out_entry

        # The intermediate array may only be written by the first map ...
        if any(e.src != first_exit for e in graph.in_edges(mid_array)):
            return False
        # ... and only be read by the (outer) reduction map
        if any(e.dst != reduce_entry for e in graph.out_edges(mid_array)):
            return False

        # The edge entering the inner reduction map exit must carry a WCR
        wcr_edge = graph.in_edges(self.rmap_in_cr)[0]
        if wcr_edge.data.wcr is None:
            return False

        # Unless permissive, the transient must not be accessed anywhere
        # else in this state or in other states
        if not permissive:
            occurrences = sum(
                1 for n in graph.nodes()
                if isinstance(n, nodes.AccessNode) and n.data == mid_array.data)
            if occurrences > 1:
                return False
            if mid_array.data in sdfg.shared_transients():
                return False

        # The subset written by the first map must equal the subset read by
        # the reduction
        written = graph.in_edges(mid_array)[0].data
        consumed = graph.out_edges(mid_array)[0].data
        if written.subset != consumed.subset:
            return False

        return True

    def match_to_str(self, graph):
        shown = (self.tasklet, self.tmap_exit, self.rmap_in_cr)
        return ' -> '.join(map(str, shown))

    def apply(self, graph: SDFGState, sdfg: SDFG):
        # Step 1: collapse the two reduction maps into a single map
        collapse = MapCollapse()
        collapse_match = {
            MapCollapse.outer_map_entry: graph.node_id(self.rmap_out_entry),
            MapCollapse.inner_map_entry: graph.node_id(self.rmap_in_entry),
        }
        collapse.setup_match(sdfg, self.sdfg_id, self.state_id, collapse_match,
                             0)
        collapsed_entry, _ = collapse.apply(graph, sdfg)

        # Step 2: fuse the first map with the collapsed reduction map. Node
        # IDs are queried only now, after the collapse has modified the graph.
        fusion = MapFusion()
        fusion_match = {
            MapFusion.first_map_exit: graph.node_id(self.tmap_exit),
            MapFusion.second_map_entry: graph.node_id(collapsed_entry),
        }
        fusion.setup_match(sdfg, self.sdfg_id, self.state_id, fusion_match, 0)
        fusion.apply(graph, sdfg)
예제 #20
0
class StreamingMemory(xf.Transformation):
    """
    Converts a read or a write to streaming memory access, where data is
    read/written to/from a stream in a separate connected component than the
    computation.

    Two patterns are matched: an access node followed by a scope entry
    (expression 0, a read) and a scope exit followed by an access node
    (expression 1, a write). For every edge on the matched side of the access
    node a new transient stream is created, the memlets along the path are
    redirected to it, and a separate map/tasklet component is generated that
    moves data between the original array and the stream(s).
    """
    access = xf.PatternNode(nodes.AccessNode)
    entry = xf.PatternNode(nodes.EntryNode)
    exit = xf.PatternNode(nodes.ExitNode)

    # Passed as ``buffer_size`` to every stream created in ``apply``
    buffer_size = properties.Property(
        dtype=int,
        default=1,
        desc='Set buffer size for the newly-created stream')

    # Passed as ``storage`` to every stream created in ``apply``
    storage = properties.EnumProperty(
        dtype=dtypes.StorageType,
        desc='Set storage type for the newly-created stream',
        default=dtypes.StorageType.Default)

    @staticmethod
    def expressions() -> List[gr.SubgraphView]:
        """
        Returns the two patterns: access -> entry (index 0) and
        exit -> access (index 1).
        """
        return [
            sdutil.node_path_graph(StreamingMemory.access,
                                   StreamingMemory.entry),
            sdutil.node_path_graph(StreamingMemory.exit,
                                   StreamingMemory.access),
        ]

    @staticmethod
    def can_be_applied(graph: SDFGState,
                       candidate: Dict[xf.PatternNode, int],
                       expr_index: int,
                       sdfg: SDFG,
                       permissive: bool = False) -> bool:
        """
        Checks whether the matched access can be converted to a stream.

        :param graph: The state in which the pattern was matched.
        :param candidate: Mapping from pattern nodes to node IDs in ``graph``.
        :param expr_index: 0 for the read pattern, 1 for the write pattern.
        :param sdfg: The SDFG that contains ``graph``.
        :param permissive: Not consulted by this check.
        :return: True if all conditions below hold, False otherwise.
        """
        access = graph.node(candidate[StreamingMemory.access])
        # Make sure the access node is only accessed once (read or write),
        # and not at the same time
        if graph.out_degree(access) > 0 and graph.in_degree(access) > 0:
            return False

        # If already a stream, skip
        if isinstance(sdfg.arrays[access.data], data.Stream):
            return False
        # If does not exist on off-chip memory, skip
        if sdfg.arrays[access.data].storage not in [
                dtypes.StorageType.CPU_Heap, dtypes.StorageType.CPU_Pinned,
                dtypes.StorageType.GPU_Global, dtypes.StorageType.FPGA_Global
        ]:
            return False

        # Only free nodes are allowed (search up the SDFG tree)
        # At each level the node (the access node first, then the enclosing
        # nested SDFG nodes) must not be inside any scope.
        curstate = graph
        node = access
        while curstate is not None:
            if curstate.entry_node(node) is not None:
                return False
            if curstate.parent.parent_nsdfg_node is None:
                break
            node = curstate.parent.parent_nsdfg_node
            curstate = curstate.parent.parent

        # Only one memlet path is allowed per outgoing/incoming edge
        edges = (graph.out_edges(access)
                 if expr_index == 0 else graph.in_edges(access))
        for edge in edges:
            mpath = graph.memlet_path(edge)
            # A memlet tree with more edges than the path means it branches
            if len(mpath) != len(list(graph.memlet_tree(edge))):
                return False

            # The innermost end of the path must have a clearly defined memory
            # access pattern
            innermost_edge = mpath[-1] if expr_index == 0 else mpath[0]
            if (innermost_edge.data.subset.num_elements() != 1
                    or innermost_edge.data.dynamic
                    or innermost_edge.data.volume != 1):
                return False

            # Check if any of the maps has a dynamic range
            # These cases can potentially work but some nodes (and perhaps
            # tasklets) need to be replicated, which are difficult to track.
            for pe in mpath:
                node = pe.dst if expr_index == 0 else graph.entry_node(pe.src)
                if isinstance(
                        node,
                        nodes.MapEntry) and sdutil.has_dynamic_map_inputs(
                            graph, node):
                    return False

        # If already applied on this memlet and this is the I/O component, skip
        # (maps created by ``apply`` are labeled with an '__s' prefix)
        if expr_index == 0:
            other_node = graph.node(candidate[StreamingMemory.entry])
        else:
            other_node = graph.node(candidate[StreamingMemory.exit])
            other_node = graph.entry_node(other_node)
        if other_node.label.startswith('__s'):
            return False

        return True

    def apply(self, sdfg: SDFG) -> List[nodes.AccessNode]:
        """
        Replaces the matched array access with stream accesses and creates
        the separate read/write components that feed or drain the streams.

        :param sdfg: The SDFG to transform.
        :return: The list of stream access nodes created inside the new
                 read/write components.
        """
        state = sdfg.node(self.state_id)
        dnode: nodes.AccessNode = self.access(sdfg)
        # Expression 0 streams the reads (outgoing edges), expression 1 the
        # writes (incoming edges)
        if self.expr_index == 0:
            edges = state.out_edges(dnode)
        else:
            edges = state.in_edges(dnode)

        # To understand how many components we need to create, all map ranges
        # throughout memlet paths must match exactly. We thus create a
        # dictionary of unique ranges
        mapping: Dict[Tuple[subsets.Range],
                      List[gr.MultiConnectorEdge[mm.Memlet]]] = defaultdict(
                          list)
        ranges = {}
        for edge in edges:
            mpath = state.memlet_path(edge)
            ranges[edge] = _collect_map_ranges(state, mpath)
            # Key on the second element of every collected entry
            mapping[tuple(r[1] for r in ranges[edge])].append(edge)

        # Collect all edges with the same memory access pattern
        components_to_create: Dict[
            Tuple[symbolic.SymbolicType],
            List[gr.MultiConnectorEdge[mm.Memlet]]] = defaultdict(list)
        for edges_with_same_range in mapping.values():
            for edge in edges_with_same_range:
                # Get memlet path and innermost edge
                # (deep-copied, since the path memlets are overwritten below)
                mpath = state.memlet_path(edge)
                innermost_edge = copy.deepcopy(mpath[-1] if self.expr_index ==
                                               0 else mpath[0])

                # Store memlets of the same access in the same component
                expr = _canonicalize_memlet(innermost_edge.data, ranges[edge])
                # Each entry is an (innermost edge copy, outermost edge) pair
                components_to_create[expr].append((innermost_edge, edge))
        components = list(components_to_create.values())

        # Split out components that have dependencies between them to avoid
        # deadlocks
        if self.expr_index == 0:
            ccs_to_add = []
            for i, component in enumerate(components):
                edges_to_remove = set()
                for cedge in component:
                    # If another member's destination can reach this member's
                    # destination, move this member to its own component
                    if any(
                            nx.has_path(state.nx, o[1].dst, cedge[1].dst)
                            for o in component if o is not cedge):
                        ccs_to_add.append([cedge])
                        edges_to_remove.add(cedge)
                if edges_to_remove:
                    components[i] = [
                        c for c in component if c not in edges_to_remove
                    ]
            components.extend(ccs_to_add)
        # End of split

        desc = sdfg.arrays[dnode.data]

        # Create new streams of shape 1
        streams = {}
        mpaths = {}
        for edge in edges:
            # One stream per edge, named after the original data
            # (find_new_name=True is passed to obtain an unused name)
            name, newdesc = sdfg.add_stream(dnode.data,
                                            desc.dtype,
                                            buffer_size=self.buffer_size,
                                            storage=self.storage,
                                            transient=True,
                                            find_new_name=True)
            streams[edge] = name
            mpath = state.memlet_path(edge)
            mpaths[edge] = mpath

            # Replace memlets in path with stream access
            for e in mpath:
                e.data = mm.Memlet(data=name,
                                   subset='0',
                                   other_subset=e.data.other_subset)
                # Memlets touching a nested SDFG are marked dynamic and the
                # corresponding connector is converted via the helper
                if isinstance(e.src, nodes.NestedSDFG):
                    e.data.dynamic = True
                    _streamify_recursive(e.src, e.src_conn, newdesc)
                if isinstance(e.dst, nodes.NestedSDFG):
                    e.data.dynamic = True
                    _streamify_recursive(e.dst, e.dst_conn, newdesc)

            # Replace access node and memlet tree with one access
            if self.expr_index == 0:
                replacement = state.add_read(name)
                state.remove_edge(edge)
                state.add_edge(replacement, edge.src_conn, edge.dst,
                               edge.dst_conn, edge.data)
            else:
                replacement = state.add_write(name)
                state.remove_edge(edge)
                state.add_edge(edge.src, edge.src_conn, replacement,
                               edge.dst_conn, edge.data)

        # Make read/write components
        ionodes = []
        for component in components:

            # Pick the first edge as the edge to make the component from
            innermost_edge, outermost_edge = component[0]
            mpath = mpaths[outermost_edge]
            mapname = streams[outermost_edge]
            innermost_edge.data.other_subset = None

            # Get edge data and streams
            if self.expr_index == 0:
                # Read component: one array input, one stream output per
                # member of the component
                opname = 'read'
                path = [e.dst for e in mpath[:-1]]
                rmemlets = [(dnode, '__inp', innermost_edge.data)]
                wmemlets = []
                for i, (_, edge) in enumerate(component):
                    name = streams[edge]
                    ionode = state.add_write(name)
                    ionodes.append(ionode)
                    wmemlets.append(
                        (ionode, '__out%d' % i, mm.Memlet(data=name,
                                                          subset='0')))
                code = '\n'.join('__out%d = __inp' % i
                                 for i in range(len(component)))
            else:
                # More than one input stream might mean a data race, so we only
                # address the first one in the tasklet code
                if len(component) > 1:
                    warnings.warn(
                        f'More than one input found for the same index for {dnode.data}'
                    )
                # Write component: one stream input per member of the
                # component, one array output
                opname = 'write'
                path = [state.entry_node(e.src) for e in reversed(mpath[1:])]
                wmemlets = [(dnode, '__out', innermost_edge.data)]
                rmemlets = []
                for i, (_, edge) in enumerate(component):
                    name = streams[edge]
                    ionode = state.add_read(name)
                    ionodes.append(ionode)
                    rmemlets.append(
                        (ionode, '__inp%d' % i, mm.Memlet(data=name,
                                                          subset='0')))
                code = '__out = __inp0'

            # Create map structure for read/write component
            # The '__s' label prefix is what can_be_applied looks for to
            # avoid re-matching these generated maps.
            maps = []
            for entry in path:
                # NOTE: this local shadows the builtin ``map`` in this method
                map: nodes.Map = entry.map
                maps.append(
                    state.add_map(f'__s{opname}_{mapname}',
                                  [(p, r)
                                   for p, r in zip(map.params, map.range)],
                                  map.schedule))
            tasklet = state.add_tasklet(
                f'{opname}_{mapname}',
                {m[1]
                 for m in rmemlets},
                {m[1]
                 for m in wmemlets},
                code,
            )
            # Inputs go through the map entries (outermost first) ...
            for node, cname, memlet in rmemlets:
                state.add_memlet_path(node,
                                      *(me for me, _ in maps),
                                      tasklet,
                                      dst_conn=cname,
                                      memlet=memlet)
            # ... and outputs through the map exits (innermost first)
            for node, cname, memlet in wmemlets:
                state.add_memlet_path(tasklet,
                                      *(mx for _, mx in reversed(maps)),
                                      node,
                                      src_conn=cname,
                                      memlet=memlet)

        return ionodes
예제 #21
0
class OuterProductOperation(pm.Transformation):
    """ Detects outer-product operations.

        A map is reported as an outer product when every non-trivial array
        read inside it is indexed purely by symbols and leaves at least one
        map parameter unused, while every array written by it is accessed
        one element at a time using all of the map parameters.
    """

    map_entry = pm.PatternNode(nodes.MapEntry)

    @staticmethod
    def expressions():
        return [sdutil.node_path_graph(OuterProductOperation.map_entry)]

    @staticmethod
    def can_be_applied(graph: dace.SDFGState,
                       candidate: Dict[pm.PatternNode, int],
                       expr_index: int,
                       sdfg: dace.SDFG,
                       permissive: bool = False):
        """ Returns True if the matched map has the access pattern of an
            outer product.

            :param graph: The state in which the map was matched.
            :param candidate: Mapping from pattern nodes to node IDs.
            :param expr_index: Index of the matched expression (unused).
            :param sdfg: The SDFG that contains ``graph``.
            :param permissive: Not consulted by this check.
        """

        map_entry = graph.node(candidate[OuterProductOperation.map_entry])
        map_exit = graph.exit_node(map_entry)
        params = [dace.symbol(p) for p in map_entry.map.params]

        # Group the subsets read inside the map by data descriptor
        inputs = dict()
        for _, _, _, _, m in graph.out_edges(map_entry):
            if not m.data:
                continue
            inputs.setdefault(sdfg.arrays[m.data], []).append(m.subset)

        outer_product_found = False
        for desc, accesses in inputs.items():
            if isinstance(desc, dace.data.Scalar):
                continue
            if not isinstance(desc, (dace.data.Array, dace.data.View)):
                return False
            if list(desc.shape) == [1]:
                continue
            for a in accesses:
                unmatched_indices = set(params)
                for idx in a.min_element():
                    # Only plain symbols are accepted as indices
                    if not isinstance(idx, sympy.Symbol):
                        return False
                    unmatched_indices.discard(idx)
                # An input indexed by every map parameter is not an
                # outer-product operand
                if len(unmatched_indices) == 0:
                    return False
                outer_product_found = True

        # Group the subsets written by the map by data descriptor
        outputs = dict()
        for _, _, _, _, m in graph.in_edges(map_exit):
            if m.wcr:
                return False
            # Empty memlets carry no data; skip them exactly as on the input
            # side instead of failing on the descriptor lookup
            if not m.data:
                continue
            outputs.setdefault(sdfg.arrays[m.data], []).append(m.subset)

        for desc, accesses in outputs.items():
            if not isinstance(desc, (dace.data.Array, dace.data.View)):
                return False
            for a in accesses:
                if a.num_elements() != 1:
                    return False
                unmatched_indices = set(params)
                for idx in a.min_element():
                    unmatched_indices.discard(idx)
                # Outputs must be indexed by all map parameters
                if len(unmatched_indices) > 0:
                    return False

        return outer_product_found

    @staticmethod
    def match_to_str(graph: dace.SDFGState, candidate: Dict[pm.PatternNode,
                                                            int]) -> str:
        map_entry = graph.node(candidate[OuterProductOperation.map_entry])
        return map_entry.map.label + ': ' + str(map_entry.map.params)

    def apply(self, sdfg: dace.SDFG):
        # Detection only: the SDFG is left unchanged
        pass
예제 #22
0
class RedundantComm2D(pm.Transformation):
    """ Removes redundant communication: data that is passed through one
        tasklet into an intermediate array and immediately through a second
        tasklet into an output array of the same 2D shape. Consumers of the
        output array are rewired to read from the input array directly. """

    in_array = pm.PatternNode(nodes.AccessNode)
    gather = pm.PatternNode(nodes.Tasklet)
    mid_array = pm.PatternNode(nodes.AccessNode)
    scatter = pm.PatternNode(nodes.Tasklet)
    out_array = pm.PatternNode(nodes.AccessNode)

    @staticmethod
    def expressions():
        chain = sdutil.node_path_graph(RedundantComm2D.in_array,
                                       RedundantComm2D.gather,
                                       RedundantComm2D.mid_array,
                                       RedundantComm2D.scatter,
                                       RedundantComm2D.out_array)
        return [chain]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, permissive=False):
        # The first tasklet must expose a '_block_sizes' input connector
        gather_node = graph.nodes()[candidate[RedundantComm2D.gather]]
        if '_block_sizes' not in gather_node.in_connectors:
            return False

        # The second tasklet must expose a '_gdescriptor' output connector
        scatter_node = graph.nodes()[candidate[RedundantComm2D.scatter]]
        if '_gdescriptor' not in scatter_node.out_connectors:
            return False

        # Only two-dimensional data with identical shapes on both ends
        source = graph.nodes()[candidate[RedundantComm2D.in_array]]
        sink = graph.nodes()[candidate[RedundantComm2D.out_array]]
        source_desc = source.desc(sdfg)
        sink_desc = sink.desc(sdfg)
        if len(source_desc.shape) != 2:
            return False
        return True if source_desc.shape == sink_desc.shape else False

    @staticmethod
    def match_to_str(graph, candidate):
        source = graph.nodes()[candidate[RedundantComm2D.in_array]]
        return "Remove " + str(source)

    def apply(self, sdfg):
        graph = sdfg.nodes()[self.state_id]
        in_array = self.in_array(sdfg)
        gather = self.gather(sdfg)
        mid_array = self.mid_array(sdfg)
        scatter = self.scatter(sdfg)
        out_array = self.out_array(sdfg)

        in_desc = sdfg.arrays[in_array.data]
        out_desc = sdfg.arrays[out_array.data]

        # Drop auxiliary inputs of the first tasklet that serve nothing else
        for edge in graph.in_edges(gather):
            feeder = edge.src
            if (feeder != in_array and graph.in_degree(feeder) == 0
                    and graph.out_degree(feeder) == 1):
                graph.remove_edge(edge)
                graph.remove_node(feeder)

        # Drop auxiliary outputs of the second tasklet that serve nothing else
        for edge in graph.out_edges(scatter):
            sink = edge.dst
            if (sink != out_array and graph.in_degree(sink) == 1
                    and graph.out_degree(sink) == 0):
                graph.remove_edge(edge)
                graph.remove_node(sink)

        # Redirect every consumer of the output array to the input array
        for edge in graph.out_edges(out_array):
            for tree_edge in graph.memlet_tree(edge):
                if tree_edge.data.data == out_array.data:
                    tree_edge.data.data = in_array.data
            graph.remove_edge(edge)
            graph.add_edge(in_array, None, edge.dst, edge.dst_conn,
                           dace.Memlet.from_array(in_array, in_desc))

        # Finally remove the now-unused chain
        for obsolete in (gather, mid_array, scatter, out_array):
            graph.remove_node(obsolete)
예제 #23
0
class ElementWiseArrayOperation2D(pm.Transformation):
    """ Distributes element-wise array operations.
    """

    _map_entry = pm.PatternNode(nodes.MapEntry)

    @staticmethod
    def expressions():
        """ Returns the single pattern of this transformation: one map
            entry node. """
        pattern = sdutil.node_path_graph(ElementWiseArrayOperation2D._map_entry)
        return [pattern]

    @staticmethod
    def can_be_applied(graph: dace.SDFGState,
                       candidate: Dict[pm.PatternNode, int],
                       expr_index: int,
                       sdfg: dace.SDFG,
                       permissive: bool = False):
        """ Returns True if the matched map is a two-dimensional element-wise
            operation.

            The map must have exactly two parameters, its range must not
            contain the names ``commsize``, ``Px`` or ``Py``, and every
            non-trivial array read or written inside it must be
            two-dimensional and accessed one element at a time using both
            map parameters. Outputs with write-conflict resolution and empty
            output memlets are rejected.

            :param graph: The state in which the map was matched.
            :param candidate: Mapping from pattern nodes to node IDs.
            :param expr_index: Index of the matched expression (unused).
            :param sdfg: The SDFG that contains ``graph``.
            :param permissive: Not consulted by this check.
        """

        map_entry = graph.node(
            candidate[ElementWiseArrayOperation2D._map_entry])
        map_exit = graph.exit_node(map_entry)
        params = [dace.symbol(p) for p in map_entry.map.params]
        if len(params) != 2:
            return False

        # These names are the symbols that apply() itself introduces
        free_symbols = map_entry.map.range.free_symbols
        if any(name in free_symbols for name in ("commsize", "Px", "Py")):
            return False

        def _indexed_by_all_params(accesses):
            """ True if every access is a single element whose indices use
                all of the map parameters. """
            for a in accesses:
                if a.num_elements() != 1:
                    return False
                unmatched_indices = set(params)
                for idx in a.min_element():
                    unmatched_indices.discard(idx)
                if len(unmatched_indices) > 0:
                    return False
            return True

        # Group the subsets read inside the map by data descriptor
        inputs = dict()
        for _, _, _, _, m in graph.out_edges(map_entry):
            if not m.data:
                continue
            inputs.setdefault(sdfg.arrays[m.data], []).append(m.subset)

        for desc, accesses in inputs.items():
            if isinstance(desc, dace.data.Scalar):
                continue
            if not isinstance(desc, (dace.data.Array, dace.data.View)):
                return False
            if list(desc.shape) == [1]:
                continue
            if len(desc.shape) != 2:
                return False
            if not _indexed_by_all_params(accesses):
                return False

        # Group the subsets written by the map by data descriptor
        outputs = dict()
        for _, _, _, _, m in graph.in_edges(map_exit):
            if m.wcr:
                return False
            # An empty memlet has no data descriptor to look up; reject the
            # match explicitly instead of failing on the lookup below
            if not m.data:
                return False
            outputs.setdefault(sdfg.arrays[m.data], []).append(m.subset)

        for desc, accesses in outputs.items():
            if not isinstance(desc, (dace.data.Array, dace.data.View)):
                return False
            if len(desc.shape) != 2:
                return False
            if not _indexed_by_all_params(accesses):
                return False

        return True

    @staticmethod
    def match_to_str(graph: dace.SDFGState, candidate: Dict[pm.PatternNode,
                                                            int]) -> str:
        """ Describes the match as '<map label>: <map parameters>'. """
        node_id = candidate[ElementWiseArrayOperation2D._map_entry]
        matched_map = graph.node(node_id).map
        return ': '.join((matched_map.label, str(matched_map.params)))

    def apply(self, sdfg: dace.SDFG):
        """ Rewrites the matched map so that it works on distributed data.

            Every input of the map is either broadcast (scalars, through a
            ``Bcast`` library node) or scattered (arrays, through a
            ``BlockCyclicScatter`` library node) into a local transient, and
            every array output is collected again with a ``BlockCyclicGather``
            library node. Finally the map parameters and range are replaced
            by the reduced per-process range.

            :param sdfg: The SDFG that contains the matched state and map.
            :raises NotImplementedError: If a map input/output is not an
                                         access node, has an unsupported data
                                         descriptor type, or is not accessed
                                         in full by the map.
        """
        graph = sdfg.nodes()[self.state_id]
        map_entry = graph.nodes()[self.subgraph[self._map_entry]]
        map_exit = graph.exit_node(map_entry)

        # Symbols by which the iteration space and the arrays are divided:
        # ``commsize`` for the flattened case, ``Px``/``Py`` for the 2D case.
        sz = dace.symbol('commsize',
                         dtype=dace.int32,
                         integer=True,
                         positive=True)
        Px = dace.symbol('Px', dtype=dace.int32, integer=True, positive=True)
        Py = dace.symbol('Py', dtype=dace.int32, integer=True, positive=True)

        def _prod(sequence):
            # Product of all elements of ``sequence`` (1 if it is empty).
            return reduce(lambda a, b: a * b, sequence, 1)

        # NOTE: Maps with step in their ranges are currently not supported
        if len(map_entry.map.params) == 2:
            # Two-parameter map: keep both parameters and divide the extent
            # of the first dimension by Px and of the second by Py.
            params = map_entry.map.params
            ranges = [None] * 2
            b, e, _ = map_entry.map.range[0]
            ranges[0] = (0, (e - b + 1) / Px - 1, 1)
            b, e, _ = map_entry.map.range[1]
            ranges[1] = (0, (e - b + 1) / Py - 1, 1)
            # NOTE(review): ``strides`` is not read again in this method.
            strides = [1]
        else:
            # Any other map: flatten into a single ``__iflat`` parameter whose
            # extent is the total iteration count divided by ``commsize``.
            params = ['__iflat']
            sizes = map_entry.map.range.size_exact()
            total_size = _prod(sizes)
            ranges = [(0, (total_size) / sz - 1, 1)]
            # NOTE(review): ``strides`` is not read again in this method.
            strides = [_prod(sizes[i + 1:]) for i in range(len(sizes))]

        # Transient int32 scalar that a tasklet sets to 0; it feeds the
        # ``_root`` connector of every Bcast node created below.
        root_name = sdfg.temp_data_name()
        sdfg.add_scalar(root_name, dace.int32, transient=True)
        root_node = graph.add_access(root_name)
        root_tasklet = graph.add_tasklet('_set_root_', {}, {'__out'},
                                         '__out = 0')
        graph.add_edge(root_tasklet, '__out', root_node, None,
                       dace.Memlet.simple(root_name, '0'))

        from dace.libraries.mpi import Bcast
        from dace.libraries.pblas import BlockCyclicScatter, BlockCyclicGather

        # Collect the access nodes feeding the map. Each one must be a scalar
        # or an array whose incoming memlet covers the whole data container.
        inputs = set()
        for src, _, _, _, m in graph.in_edges(map_entry):
            if not isinstance(src, nodes.AccessNode):
                raise NotImplementedError
            desc = src.desc(sdfg)
            if not isinstance(desc, (data.Scalar, data.Array)):
                raise NotImplementedError
            if list(desc.shape) != m.src_subset.size_exact():
                # Second attempt
                # TODO: We need a solution for symbols not matching
                if str(list(desc.shape)) != str(m.src_subset.size_exact()):
                    raise NotImplementedError
            inputs.add(src)

        for inp in inputs:
            desc = inp.desc(sdfg)

            if isinstance(desc, data.Scalar):
                # Scalars are broadcast: inp -> Bcast -> new access node of
                # the same data, which then replaces ``inp`` as map input.
                local_access = graph.add_access(inp.data)
                bcast_node = Bcast('_Bcast_')
                graph.add_edge(inp, None, bcast_node, '_inbuffer',
                               dace.Memlet.from_array(inp.data, desc))
                graph.add_edge(root_node, None, bcast_node, '_root',
                               dace.Memlet.simple(root_name, '0'))
                graph.add_edge(bcast_node, '_outbuffer', local_access, None,
                               dace.Memlet.from_array(inp.data, desc))
                for e in graph.edges_between(inp, map_entry):
                    graph.add_edge(local_access, None, map_entry, e.dst_conn,
                                   dace.Memlet.from_array(inp.data, desc))
                    graph.remove_edge(e)

            elif isinstance(desc, data.Array):

                # Local transient holding this process' part of the array:
                # the first two dimensions divided by Px and Py respectively.
                # NOTE(review): indexes shape[0] and shape[1] unconditionally,
                # so arrays are assumed to be (at least) 2D - confirm.
                local_name, local_arr = sdfg.add_temp_transient(
                    [(desc.shape[0]) // Px, (desc.shape[1]) // Py],
                    dtype=desc.dtype,
                    storage=desc.storage)
                local_access = graph.add_access(local_name)
                # Two-element int32 transient with the block sizes, written
                # by a dedicated tasklet and read by the scatter node.
                bsizes_name, bsizes_arr = sdfg.add_temp_transient(
                    (2, ), dtype=dace.int32)
                bsizes_access = graph.add_access(bsizes_name)
                bsizes_tasklet = nodes.Tasklet(
                    '_set_bsizes_', {}, {'__out'},
                    "__out[0] = {x}; __out[1] = {y}".format(
                        x=(desc.shape[0]) // Px, y=(desc.shape[1]) // Py))
                graph.add_edge(bsizes_tasklet, '__out', bsizes_access, None,
                               dace.Memlet.from_array(bsizes_name, bsizes_arr))
                # Nine-element int32 transients receiving the scatter node's
                # ``_gdescriptor`` and ``_ldescriptor`` outputs.
                gdesc_name, gdesc_arr = sdfg.add_temp_transient(
                    (9, ), dtype=dace.int32)
                gdesc_access = graph.add_access(gdesc_name)
                ldesc_name, ldesc_arr = sdfg.add_temp_transient(
                    (9, ), dtype=dace.int32)
                ldesc_access = graph.add_access(ldesc_name)
                scatter_node = BlockCyclicScatter('_Scatter_')
                graph.add_edge(inp, None, scatter_node, '_inbuffer',
                               dace.Memlet.from_array(inp.data, desc))
                graph.add_edge(bsizes_access, None, scatter_node,
                               '_block_sizes',
                               dace.Memlet.from_array(bsizes_name, bsizes_arr))
                graph.add_edge(scatter_node, '_outbuffer', local_access, None,
                               dace.Memlet.from_array(local_name, local_arr))
                graph.add_edge(scatter_node, '_gdescriptor', gdesc_access,
                               None,
                               dace.Memlet.from_array(gdesc_name, gdesc_arr))
                graph.add_edge(scatter_node, '_ldescriptor', ldesc_access,
                               None,
                               dace.Memlet.from_array(ldesc_name, ldesc_arr))
                # Feed the map from the local array instead of ``inp`` ...
                for e in graph.edges_between(inp, map_entry):
                    graph.add_edge(
                        local_access, None, map_entry, e.dst_conn,
                        dace.Memlet.from_array(local_name, local_arr))
                    graph.remove_edge(e)
                # ... and retarget the memlets leaving the map entry so that
                # they refer to the local array as well.
                for e in graph.out_edges(map_entry):
                    if e.data.data == inp.data:
                        e.data.data = local_name

            else:
                raise NotImplementedError

        # Collect the access nodes written by the map. Only arrays that are
        # written in full are supported.
        outputs = set()
        for _, _, dst, _, m in graph.out_edges(map_exit):
            if not isinstance(dst, nodes.AccessNode):
                raise NotImplementedError
            desc = dst.desc(sdfg)
            if not isinstance(desc, data.Array):
                raise NotImplementedError
            try:
                if list(desc.shape) != m.dst_subset.size_exact():
                    # Second attempt
                    # TODO: We need a solution for symbols not matching
                    if str(list(desc.shape)) != str(m.dst_subset.size_exact()):
                        raise NotImplementedError
            except AttributeError:
                # Fallback: repeat the same check using ``m.subset`` when the
                # ``dst_subset`` path raised AttributeError.
                if list(desc.shape) != m.subset.size_exact():
                    # Second attempt
                    # TODO: We need a solution for symbols not matching
                    if str(list(desc.shape)) != str(m.subset.size_exact()):
                        raise NotImplementedError
            outputs.add(dst)

        for out in outputs:
            desc = out.desc(sdfg)
            if isinstance(desc, data.Scalar):
                raise NotImplementedError
            elif isinstance(desc, data.Array):
                # Local transient and block-size array, built exactly like in
                # the input case above.
                local_name, local_arr = sdfg.add_temp_transient(
                    [(desc.shape[0]) // Px, (desc.shape[1]) // Py],
                    dtype=desc.dtype,
                    storage=desc.storage)
                local_access = graph.add_access(local_name)
                bsizes_name, bsizes_arr = sdfg.add_temp_transient(
                    (2, ), dtype=dace.int32)
                bsizes_access = graph.add_access(bsizes_name)
                bsizes_tasklet = nodes.Tasklet(
                    '_set_bsizes_', {}, {'__out'},
                    "__out[0] = {x}; __out[1] = {y}".format(
                        x=(desc.shape[0]) // Px, y=(desc.shape[1]) // Py))
                graph.add_edge(bsizes_tasklet, '__out', bsizes_access, None,
                               dace.Memlet.from_array(bsizes_name, bsizes_arr))
                # NOTE: despite its name, ``scatter_node`` is the gather node
                # here: local array -> BlockCyclicGather -> ``out``.
                scatter_node = BlockCyclicGather('_Gather_')
                graph.add_edge(local_access, None, scatter_node, '_inbuffer',
                               dace.Memlet.from_array(local_name, local_arr))
                graph.add_edge(bsizes_access, None, scatter_node,
                               '_block_sizes',
                               dace.Memlet.from_array(bsizes_name, bsizes_arr))
                graph.add_edge(scatter_node, '_outbuffer', out, None,
                               dace.Memlet.from_array(out.data, desc))

                # Let the map write into the local array instead of ``out``
                # and retarget the memlets entering the map exit accordingly.
                for e in graph.edges_between(map_exit, out):
                    graph.add_edge(
                        map_exit, e.src_conn, local_access, None,
                        dace.Memlet.from_array(local_name, local_arr))
                    graph.remove_edge(e)
                for e in graph.in_edges(map_exit):
                    if e.data.data == out.data:
                        e.data.data = local_name
            else:
                raise NotImplementedError

        # Install the reduced iteration space computed at the top.
        map_entry.map.params = params
        map_entry.map.range = subsets.Range(ranges)
예제 #24
0
class FPGATransformState(transformation.MultiStateTransformation):
    """ Implements the FPGATransformState transformation.

        The matched state is turned into an FPGA kernel: arrays that enter or
        leave the state are mirrored by ``fpga_``-prefixed transients in FPGA
        global memory, which are filled in a new state placed before the
        matched one and copied back in a new state placed after it.
    """

    state = transformation.PatternNode(sd.SDFGState)

    @classmethod
    def expressions(cls):
        return [sdutil.node_path_graph(cls.state)]

    def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
        """ Returns True if the matched state may be moved to an FPGA.

            The state is rejected if it contains consume scopes, streams that
            violate the code generator's restrictions, access nodes with a
            non-default storage type, or maps that already carry (or are
            nested in maps carrying) an accelerator schedule.
        """
        state = self.state

        host_only_storage = (dtypes.StorageType.CPU_Heap,
                             dtypes.StorageType.CPU_Pinned,
                             dtypes.StorageType.CPU_ThreadLocal)
        accelerator_schedules = (dtypes.ScheduleType.GPU_Device,
                                 dtypes.ScheduleType.FPGA_Device,
                                 dtypes.ScheduleType.GPU_ThreadBlock)
        forbidden_schedules = (dtypes.ScheduleType.MPI,
                               dtypes.ScheduleType.GPU_Device,
                               dtypes.ScheduleType.FPGA_Device,
                               dtypes.ScheduleType.GPU_ThreadBlock)
        allowed_storage = (dtypes.StorageType.Default,
                           dtypes.StorageType.Register)

        for node, parent in state.all_nodes_recursive():
            # Consume scopes are currently unsupported
            if isinstance(node, (nodes.ConsumeEntry, nodes.ConsumeExit)):
                return False

            # Only stream access nodes need the checks below
            if not isinstance(node, nodes.AccessNode):
                continue
            stream_desc = parent.parent.arrays[node.data]
            if not isinstance(stream_desc, data.Stream):
                continue

            # Streams have strict conditions due to code generator limitations
            scopes = parent.scope_dict()
            if stream_desc.storage in host_only_storage:
                return False

            # Cannot allocate FIFO from CPU code
            if scopes[node] is None:
                return False

            # Arrays of streams cannot have symbolic size on FPGA
            if dace.symbolic.issymbolic(stream_desc.total_size,
                                        parent.parent.constants):
                return False

            # Streams cannot be unbounded on FPGA
            if stream_desc.buffer_size < 1:
                return False

        for node in state.nodes():
            if (isinstance(node, nodes.AccessNode)
                    and node.desc(sdfg).storage not in allowed_storage):
                return False

            if not isinstance(node, nodes.MapEntry):
                continue

            # Map schedules that are disallowed to transform to FPGAs
            if node.map.schedule in forbidden_schedules:
                return False

            # Walk up the enclosing scopes and check them for accelerator
            # schedules as well
            scopes = state.scope_dict()
            scope_entry = node
            while scope_entry is not None:
                if scope_entry.map.schedule in accelerator_schedules:
                    return False
                scope_entry = scopes[scope_entry]

        return True

    def apply(self, _, sdfg):
        """ Moves the matched state to the FPGA.

            Creates the ``fpga_`` mirror arrays, a ``pre_`` state that copies
            inputs into them, a ``post_`` state that copies outputs back, and
            redirects all memlets of the state to the mirror arrays.
        """
        state = self.state

        shared_transients = set(sdfg.shared_transients())

        def _relevant_outside(n):
            # Data nodes that matter outside this FPGA kernel: non-transients
            # and transients shared with other states
            return isinstance(n, nodes.AccessNode) and (
                not sdfg.arrays[n.data].transient
                or n.data in shared_transients)

        input_nodes = [
            n for n in sdutil.find_source_nodes(state) if _relevant_outside(n)
        ]
        output_nodes = [
            n for n in sdutil.find_sink_nodes(state) if _relevant_outside(n)
        ]

        # Maps original array name -> result of sdfg.add_array for its mirror
        fpga_data = {}

        def _create_fpga_mirror(access_node, host_desc):
            # Registers the FPGA-global mirror of ``access_node``'s array and
            # moves the location attribute from the host array to the mirror
            mirror = sdfg.add_array('fpga_' + access_node.data,
                                    host_desc.shape,
                                    host_desc.dtype,
                                    transient=True,
                                    storage=dtypes.StorageType.FPGA_Global,
                                    allow_conflicts=host_desc.allow_conflicts,
                                    strides=host_desc.strides,
                                    offset=host_desc.offset)
            mirror[1].location = copy.copy(host_desc.location)
            host_desc.location.clear()
            fpga_data[access_node.data] = mirror

        # Input nodes may also be nodes with WCR memlets
        # We have to recur across nested SDFGs to find them
        wcr_input_nodes = set()

        parent_sdfg = {state: sdfg}  # Map states to their parent SDFG
        for node, parent in state.all_nodes_recursive():
            if isinstance(parent, dace.SDFG):
                parent_sdfg[node] = parent
            if not isinstance(node, dace.sdfg.nodes.AccessNode):
                continue
            for edge in parent.in_edges(node):
                if edge.data.wcr is None:
                    continue
                trace = dace.sdfg.trace_nested_access(node, parent,
                                                      parent_sdfg[parent])
                for node_trace, _mtrace, state_trace, sdfg_trace in trace:
                    # Find the name of the accessed node in our scope
                    if state_trace == state and sdfg_trace == sdfg:
                        _inner, outer_node = node_trace
                        if outer_node is not None:
                            break
                else:
                    # This does not trace back to the current state, so
                    # we don't care
                    continue
                input_nodes.append(outer_node)
                wcr_input_nodes.add(outer_node)

        if input_nodes:
            # create pre_state
            pre_state = sd.SDFGState('pre_' + state.label, sdfg)

            for node in input_nodes:
                if not isinstance(node, dace.sdfg.nodes.AccessNode):
                    continue
                desc = node.desc(sdfg)
                if not isinstance(desc, dace.data.Array):
                    # TODO: handle streams
                    continue

                is_wcr_input = node in wcr_input_nodes
                if node.data not in fpga_data and not is_wcr_input:
                    _create_fpga_mirror(node, desc)

                # Host -> FPGA copy in the pre-state
                mirror_name = 'fpga_' + node.data
                host_read = pre_state.add_read(node.data)
                mirror_write = pre_state.add_write(mirror_name)
                copy_memlet = memlet.Memlet(
                    data=node.data, subset=subsets.Range.from_array(desc))
                pre_state.add_edge(host_read, None, mirror_write, None,
                                   copy_memlet)

                if not is_wcr_input:
                    # Replace the host node in the kernel by the mirror
                    mirror_read = state.add_read(mirror_name)
                    sdutil.change_edge_src(state, node, mirror_read)
                    state.remove_node(node)

            sdfg.add_node(pre_state)
            sdutil.change_edge_dest(sdfg, state, pre_state)
            sdfg.add_edge(pre_state, state, sd.InterstateEdge())

        if output_nodes:
            post_state = sd.SDFGState('post_' + state.label, sdfg)

            for node in output_nodes:
                if not isinstance(node, dace.sdfg.nodes.AccessNode):
                    continue
                desc = node.desc(sdfg)
                if not isinstance(desc, dace.data.Array):
                    # TODO: handle streams
                    continue

                if node.data not in fpga_data:
                    _create_fpga_mirror(node, desc)

                # FPGA -> host copy in the post-state
                mirror_name = 'fpga_' + node.data
                host_write = post_state.add_write(node.data)
                mirror_read = post_state.add_read(mirror_name)
                copy_memlet = memlet.Memlet(mirror_name, None,
                                            subsets.Range.from_array(desc))
                post_state.add_edge(mirror_read, None, host_write, None,
                                    copy_memlet)

                # Replace the host node in the kernel by the mirror
                mirror_write = state.add_write(mirror_name)
                sdutil.change_edge_dest(state, node, mirror_write)
                state.remove_node(node)

            sdfg.add_node(post_state)
            sdutil.change_edge_src(sdfg, state, post_state)
            sdfg.add_edge(state, post_state, sd.InterstateEdge())

        # propagate memlet info from a nested sdfg
        for edge in state.edges():
            edge_memlet = edge.data
            if edge_memlet.data is not None and edge_memlet.data in fpga_data:
                edge_memlet.data = 'fpga_' + edge_memlet.data
        fpga_update(sdfg, state, 0)
예제 #25
0
class RefineNestedAccess(transformation.Transformation):
    """
    Reduces memlet shape when a memlet is connected to a nested SDFG, but not
    using all of the contents. Makes the outer memlet smaller in shape and
    ensures that the offsets in the nested SDFG start with zero.
    This helps with subsequent transformations on the outer SDFGs.

    For example, in the following program::

        @dace.program
        def func_a(y):
            return y[1:5] + 1

        @dace.program
        def main(x: dace.float32[N]):
            return func_a(x)

    The memlet pointing to ``func_a`` will contain all of ``x`` (``x[0:N]``),
    and it is offset to ``y[1:5]`` in the function, with ``y``'s size being
    ``N``. After the transformation, the memlet connected to the nested SDFG of
    ``func_a`` would contain ``x[1:5]`` directly and the internal ``y`` array
    would have a size of 4, accessed as ``y[0:4]``.
    """

    # The nested SDFG node whose connected memlets are refined.
    nsdfg = transformation.PatternNode(nodes.NestedSDFG)

    @staticmethod
    def annotates_memlets():
        """ Always returns True for this transformation. """
        return True

    @staticmethod
    def expressions():
        """ Returns the pattern to match: a single nested SDFG node. """
        return [sdutil.node_path_graph(RefineNestedAccess.nsdfg)]

    @staticmethod
    def _candidates(
        state: SDFGState, nsdfg: nodes.NestedSDFG
    ) -> Tuple[Dict[str, Tuple[Memlet, Set[int]]], Dict[str, Tuple[Memlet,
                                                                   Set[int]]]]:
        """
        Finds the data containers of the nested SDFG that can be refined.

        :param state: The state that contains the nested SDFG node.
        :param nsdfg: The nested SDFG node to inspect.
        :return: A 2-tuple ``(inputs, outputs)`` of dictionaries that map a
                 data name to a tuple of (copy of the memlet to refine by,
                 set of dimension indices that may be refined).
        """
        # Internally each candidate also remembers the nested state it was
        # found in; that state is dropped again when the result is returned.
        in_candidates: Dict[str, Tuple[Memlet, SDFGState, Set[int]]] = {}
        out_candidates: Dict[str, Tuple[Memlet, SDFGState, Set[int]]] = {}
        # Names that must not be refined (shared with ``_check_cand`` below).
        ignore = set()
        for nstate in nsdfg.sdfg.nodes():
            for dnode in nstate.data_nodes():
                # Transients of the nested SDFG are skipped.
                if nsdfg.sdfg.arrays[dnode.data].transient:
                    continue

                # For now we only detect one element
                # Edges into the data node are recorded as output candidates.
                for e in nstate.in_edges(dnode):
                    # If more than one unique element detected, remove from
                    # candidates
                    if e.data.data in out_candidates:
                        memlet, ns, indices = out_candidates[e.data.data]
                        # Try to find dimensions in which there is a mismatch
                        # and remove them from list
                        for i, (s1, s2) in enumerate(
                                zip(e.data.subset, memlet.subset)):
                            if s1 != s2 and i in indices:
                                indices.remove(i)
                        # No refinable dimension left: give up on this name
                        if len(indices) == 0:
                            ignore.add(e.data.data)
                        out_candidates[e.data.data] = (memlet, ns, indices)
                        continue
                    # First memlet seen for this name: all of its dimensions
                    # start out as refinable
                    out_candidates[e.data.data] = (e.data, nstate,
                                                   set(
                                                       range(len(
                                                           e.data.subset))))
                # Edges out of the data node are recorded as input candidates,
                # following the same procedure as above.
                for e in nstate.out_edges(dnode):
                    # If more than one unique element detected, remove from
                    # candidates
                    if e.data.data in in_candidates:
                        memlet, ns, indices = in_candidates[e.data.data]
                        # Try to find dimensions in which there is a mismatch
                        # and remove them from list
                        for i, (s1, s2) in enumerate(
                                zip(e.data.subset, memlet.subset)):
                            if s1 != s2 and i in indices:
                                indices.remove(i)
                        if len(indices) == 0:
                            ignore.add(e.data.data)
                        in_candidates[e.data.data] = (memlet, ns, indices)
                        continue
                    in_candidates[e.data.data] = (e.data, nstate,
                                                  set(range(len(
                                                      e.data.subset))))

        # TODO: Check in_candidates in interstate edges as well

        # Check in/out candidates
        # A name that is both read and written keeps only the dimensions that
        # are refinable on both sides, and is ignored entirely if the read and
        # written ranges differ in any of those dimensions.
        for cand in in_candidates.keys() & out_candidates.keys():
            s1, nstate1, ind1 = in_candidates[cand]
            s2, nstate2, ind2 = out_candidates[cand]
            indices = ind1 & ind2
            if any(s1.subset[ind] != s2.subset[ind] for ind in indices):
                ignore.add(cand)
            in_candidates[cand] = (s1, nstate1, indices)
            out_candidates[cand] = (s2, nstate2, indices)

        # Ensure minimum elements of candidates do not begin with zero
        def _check_cand(candidates, outer_edges):
            # ``outer_edges`` is called as ``outer_edges(nsdfg, cname)`` and
            # must yield the edges attached to connector ``cname`` in the
            # outer state. Rejected names are added to the enclosing
            # ``ignore`` set; accepted candidates may have their memlet
            # replaced by a propagated one.
            for cname, (cand, nstate, indices) in candidates.items():
                # All refinable dimensions already start at zero: skip
                if all(me == 0
                       for i, me in enumerate(cand.subset.min_element())
                       if i in indices):
                    ignore.add(cname)
                    continue

                # Ensure outer memlets begin with 0
                # (only the first edge on the connector is inspected)
                outer_edge = next(iter(outer_edges(nsdfg, cname)))
                if any(me != 0 for i, me in enumerate(
                        outer_edge.data.subset.min_element()) if i in indices):
                    ignore.add(cname)
                    continue

                # Check w.r.t. loops
                if len(nstate.ranges) > 0:
                    # Re-annotate loop ranges, in case someone changed them
                    # TODO: Move out of here!
                    nstate.ranges = {}
                    from dace.sdfg.propagation import _annotate_loop_ranges
                    _annotate_loop_ranges(nsdfg.sdfg, [])

                    # Propagate the candidate memlet over the state's loop
                    # ranges (taken in sorted key order)
                    memlet = propagation.propagate_subset(
                        [cand], nsdfg.sdfg.arrays[cname],
                        sorted(nstate.ranges.keys()),
                        subsets.Range([
                            v.ndrange()[0]
                            for _, v in sorted(nstate.ranges.items())
                        ]))
                    # Same zero-start check as above, on the propagated memlet
                    if all(me == 0
                           for i, me in enumerate(memlet.subset.min_element())
                           if i in indices):
                        ignore.add(cname)
                        continue

                    # Modify memlet to propagated one
                    # (replaces the value of an existing key only, so the
                    # dictionary does not change size while being iterated)
                    candidates[cname] = (memlet, nstate, indices)
                else:
                    memlet = cand

                # If there are any symbols here that are not defined
                # in "defined_symbols"
                # (i.e., not among the keys of the node's symbol mapping)
                missing_symbols = (memlet.free_symbols -
                                   set(nsdfg.symbol_mapping.keys()))
                if missing_symbols:
                    ignore.add(cname)
                    continue

        _check_cand(in_candidates, state.in_edges_by_connector)
        _check_cand(out_candidates, state.out_edges_by_connector)

        # Return result, filtering out the states
        # NOTE(review): ``dc`` is defined outside this view - presumably
        # ``copy.deepcopy``; confirm.
        return ({
            k: (dc(v), ind)
            for k, (v, _, ind) in in_candidates.items() if k not in ignore
        }, {
            k: (dc(v), ind)
            for k, (v, _, ind) in out_candidates.items() if k not in ignore
        })

    @staticmethod
    def can_be_applied(graph: SDFGState,
                       candidate: Dict[transformation.PatternNode, int],
                       expr_index: int,
                       sdfg: SDFG,
                       strict: bool = False):
        """
        Returns True if at least one input or output of the matched nested
        SDFG is a refinement candidate (see ``_candidates``).
        """
        nsdfg = graph.node(candidate[RefineNestedAccess.nsdfg])
        ic, oc = RefineNestedAccess._candidates(graph, nsdfg)
        return (len(ic) + len(oc)) > 0

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns the label of the graph in which the match was found. """
        return graph.label

    def apply(self, sdfg):
        """
        Refines all candidate inputs and outputs of the matched nested SDFG:
        the outer memlets are narrowed and the accesses inside the nested
        SDFG (memlets and interstate edges) are offset accordingly.

        :param sdfg: The SDFG that contains the matched state.
        :raises NotImplementedError: If an interstate edge of the nested SDFG
                                     has a condition that is not written in
                                     Python.
        """
        state: SDFGState = sdfg.nodes()[self.state_id]
        nsdfg_node: nodes.NestedSDFG = self.nsdfg(sdfg)
        nsdfg: SDFG = nsdfg_node.sdfg
        torefine_in, torefine_out = RefineNestedAccess._candidates(
            state, nsdfg_node)

        # Names whose internal accesses were already offset. Shared between
        # the input and the output pass so that data which is both read and
        # written is offset internally only once.
        refined = set()

        def _offset_refine(
            torefine: Dict[str, Tuple[Memlet, Set[int]]],
            outer_edges: Callable[[nodes.NestedSDFG, str],
                                  Iterable[MultiConnectorEdge[Memlet]]]):
            # Offset memlets inside negatively by "refine", modify outer
            # memlets to be "refine"
            for aname, (refine, indices) in torefine.items():
                # Only the first outer edge on the connector is modified
                outer_edge = next(iter(outer_edges(nsdfg_node, aname)))
                new_memlet = helpers.unsqueeze_memlet(refine, outer_edge.data)
                # Take the refined range in the refinable dimensions and keep
                # the original outer range everywhere else
                outer_edge.data.subset = subsets.Range([
                    ns if i in indices else os for i, (os, ns) in enumerate(
                        zip(outer_edge.data.subset, new_memlet.subset))
                ])
                if aname in refined:
                    continue
                # Refine internal memlets
                for nstate in nsdfg.nodes():
                    for e in nstate.edges():
                        if e.data.data == aname:
                            e.data.subset.offset(refine.subset, True, indices)
                # Refine accesses in interstate edges
                # NOTE(review): ``ASTRefiner`` is defined outside this view;
                # presumably it rewrites subscripts of ``aname`` - confirm.
                refiner = ASTRefiner(aname, refine.subset, nsdfg, indices)
                for isedge in nsdfg.edges():
                    # Assignments are stored as strings: parse, rewrite and
                    # unparse each of them
                    for k, v in isedge.data.assignments.items():
                        vast = ast.parse(v)
                        refiner.visit(vast)
                        isedge.data.assignments[k] = astutils.unparse(vast)
                    # Conditions are rewritten statement by statement
                    if isedge.data.condition.language is dtypes.Language.Python:
                        for i, stmt in enumerate(isedge.data.condition.code):
                            isedge.data.condition.code[i] = refiner.visit(stmt)
                    else:
                        raise NotImplementedError
                refined.add(aname)

        # Proceed symmetrically on incoming and outgoing edges
        _offset_refine(torefine_in, state.in_edges_by_connector)
        _offset_refine(torefine_out, state.out_edges_by_connector)
예제 #26
0
class MapTiling(transformation.Transformation):
    """ Implements the orthogonal tiling transformation.

        Each dimension of the matched Map is strip-mined in turn and the
        resulting outer (tile) maps are collapsed into one, so that the
        iteration space ends up tiled in every dimension.
    """

    map_entry = transformation.PatternNode(nodes.MapEntry)

    # Properties
    prefix = Property(
        dtype=str,
        default="tile",
        desc="Prefix for new range symbols",
    )
    tile_sizes = ShapeProperty(
        dtype=tuple,
        default=(128, 128, 128),
        desc="Tile size per dimension",
    )
    strides = ShapeProperty(
        dtype=tuple,
        default=tuple(),
        desc="Tile stride (enables overlapping tiles). If empty, matches tile",
    )
    tile_offset = ShapeProperty(
        dtype=tuple,
        default=None,
        desc="Negative Stride offset per dimension",
        allow_none=True,
    )
    divides_evenly = Property(
        dtype=bool,
        default=False,
        desc="Tile size divides dimension length evenly",
    )
    tile_trivial = Property(
        dtype=bool,
        default=False,
        desc="Tiles even if tile_size is 1",
    )

    @staticmethod
    def annotates_memlets():
        return True

    @staticmethod
    def expressions():
        return [sdutil.node_path_graph(MapTiling.map_entry)]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        # Every map can be tiled
        return True

    @staticmethod
    def match_to_str(graph, candidate):
        matched = graph.nodes()[candidate[MapTiling.map_entry]]
        return f'{matched.map.label}: {matched.map.params}'

    def apply(self, sdfg):
        """ Tiles the matched map.

            :param sdfg: The SDFG that contains the matched state.
            :return: The entry node of the outermost (tile) map, or None if
                     no dimension was strip-mined.
        """
        from dace.transformation.dataflow.map_collapse import MapCollapse
        from dace.transformation.dataflow.strip_mining import StripMining

        graph = sdfg.nodes()[self.state_id]

        # Strides default to the tile sizes unless one stride is given for
        # every tile size.
        has_own_strides = (self.strides is not None
                           and len(self.strides) == len(self.tile_sizes))
        tile_strides = self.strides if has_own_strides else self.tile_sizes

        # Retrieve the matched map entry node.
        map_entry = graph.nodes()[self.subgraph[MapTiling.map_entry]]
        stripmine_subgraph = {
            StripMining._map_entry: self.subgraph[MapTiling.map_entry]
        }
        sdfg_id = sdfg.sdfg_id
        original_schedule = map_entry.schedule

        outer_entry = None  # Entry of the tile map built so far
        dropped_dims = 0  # Dimensions whose inner map range was omitted

        for param_idx in range(len(map_entry.map.params)):
            # Dimensions beyond the given sizes reuse the last size/stride
            size_idx = param_idx if param_idx < len(self.tile_sizes) else -1
            tile_size = symbolic.pystr_to_symbolic(self.tile_sizes[size_idx])
            tile_stride = symbolic.pystr_to_symbolic(tile_strides[size_idx])

            # handle offsets (same fallback to the last entry; 0 if unset)
            offsets = self.tile_offset
            if not offsets:
                offset = 0
            elif param_idx >= len(offsets):
                offset = offsets[-1]
            else:
                offset = offsets[param_idx]

            # Position of this parameter in the (possibly shrunk) inner map
            dim_idx = param_idx - dropped_dims

            # If tile size is trivial, skip strip-mining map dimension
            if tile_size == map_entry.map.range.size()[dim_idx]:
                continue

            stripmine = StripMining(sdfg_id, self.state_id, stripmine_subgraph,
                                    self.expr_index)

            # Special case: Tile size of 1 should be omitted from inner map
            omit_inner = (tile_size == 1 and tile_stride == 1
                          and not self.tile_trivial)

            stripmine.dim_idx = dim_idx
            stripmine.new_dim_prefix = '' if omit_inner else self.prefix
            stripmine.tile_size = str(tile_size)
            stripmine.tile_stride = str(tile_stride)
            stripmine.divides_evenly = (True
                                        if omit_inner else self.divides_evenly)
            stripmine.tile_offset = str(offset)
            stripmine.apply(sdfg)
            if omit_inner:
                dropped_dims += 1

            # apply to the new map the schedule of the original one
            map_entry.schedule = original_schedule

            # Merge the freshly created tile map into the previous one
            if outer_entry:
                fresh_entry = graph.in_edges(map_entry)[0].src
                mapcollapse_subgraph = {
                    MapCollapse._outer_map_entry: graph.node_id(outer_entry),
                    MapCollapse._inner_map_entry: graph.node_id(fresh_entry)
                }
                collapse = MapCollapse(sdfg_id, self.state_id,
                                       mapcollapse_subgraph, 0)
                collapse.apply(sdfg)
            # Re-read after the collapse, which rewires the graph
            outer_entry = graph.in_edges(map_entry)[0].src

        return outer_entry
예제 #27
0
파일: mapreduce.py 프로젝트: am-ivanov/dace
class MapReduceFusion(pm.SingleStateTransformation):
    """ Implements the map-reduce-fusion transformation.
        Fuses a map with an immediately following reduction, where the array
        between the map and the reduction is not used anywhere else.

        The matched pattern is the node path
        ``tasklet -> tmap_exit -> in_array -> reduce -> out_array``. When
        applied, the intermediate array and the reduce node are removed and
        the tasklet's output memlet is redirected to ``out_array``, carrying
        the reduce node's write-conflict resolution (WCR) function.
    """

    no_init = Property(
        dtype=bool,
        default=False,
        desc='If enabled, does not create initialization states '
        'for reduce nodes with identity')

    tasklet = pm.PatternNode(nodes.Tasklet)
    tmap_exit = pm.PatternNode(nodes.MapExit)
    in_array = pm.PatternNode(nodes.AccessNode)

    import dace.libraries.standard as stdlib  # Avoid import loop
    reduce = pm.PatternNode(stdlib.Reduce)

    out_array = pm.PatternNode(nodes.AccessNode)

    @classmethod
    def expressions(cls):
        """ Returns the single pattern graph matched by this transformation:
            tasklet -> map exit -> intermediate array -> reduce -> output.
        """
        return [
            sdutil.node_path_graph(cls.tasklet, cls.tmap_exit, cls.in_array,
                                   cls.reduce, cls.out_array)
        ]

    def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
        """ Checks whether the matched subgraph can be fused.

            :param graph: The state in which the pattern was matched.
            :param expr_index: Index of the matched pattern expression
                               (unused here).
            :param sdfg: The SDFG containing ``graph``.
            :param permissive: If True, skips the check that the intermediate
                               array is not accessed anywhere else.
            :return: True if the transformation can be applied, False
                     otherwise.
        """
        tmap_exit = self.tmap_exit
        in_array = self.in_array
        reduce_node = self.reduce
        tasklet = self.tasklet

        # Make sure that the array is only accessed by the map and the reduce
        if any(e.src != tmap_exit for e in graph.in_edges(in_array)):
            return False
        if any(e.dst != reduce_node for e in graph.out_edges(in_array)):
            return False

        # Locate the tasklet output that writes the intermediate array. If the
        # matched tasklet writes a different container through this map exit,
        # there is nothing to fuse: reject the match instead of letting
        # StopIteration escape from the applicability check.
        tmem_edge = next((e for e in graph.edges_between(tasklet, tmap_exit)
                          if e.data.data == in_array.data), None)
        if tmem_edge is None:
            return False
        tmem = tmem_edge.data

        # Make sure that the transient is not accessed anywhere else
        # in this state or other states
        if not permissive:
            num_access_nodes = sum(
                1 for n in graph.nodes()
                if isinstance(n, nodes.AccessNode) and n.data == in_array.data)
            if (num_access_nodes > 1
                    or in_array.data in sdfg.shared_transients()):
                return False

        # If memlet already has WCR and it is different from reduce node,
        # do not match
        if tmem.wcr is not None and tmem.wcr != reduce_node.wcr:
            return False

        # Verify that reduction ranges match tasklet map
        tout_memlet = graph.in_edges(in_array)[0].data
        rin_memlet = graph.out_edges(in_array)[0].data
        if tout_memlet.subset != rin_memlet.subset:
            return False

        return True

    def match_to_str(self, graph):
        """ Returns a human-readable description of the matched nodes. """
        return ' -> '.join(
            str(node) for node in [self.tasklet, self.tmap_exit, self.reduce])

    def apply(self, graph: SDFGState, sdfg: SDFG):
        """ Fuses the matched reduction into the preceding map.

            Removes the intermediate access node and the reduce node, rewrites
            the memlet entering the map exit so that it targets the output
            array with the reduction's WCR, and connects the map exit to the
            original destination of the reduce node. Unless ``no_init`` is
            set, and if the reduce node has an identity value, a state that
            writes the identity to every element of the output subset is
            inserted before ``graph``.

            :param graph: The state in which the pattern was matched.
            :param sdfg: The SDFG containing ``graph``.
            :raise RuntimeError: If no edge entering the map exit carries the
                                 intermediate array.
        """
        tmap_exit = self.tmap_exit
        in_array = self.in_array
        reduce_node = self.reduce
        out_array = self.out_array

        # The intermediate array and the reduce node are replaced by a single
        # WCR memlet
        nodes_to_remove = [in_array, reduce_node]

        # First edge entering the map exit that writes the intermediate array
        memlet_edge = next((e for e in graph.in_edges(tmap_exit)
                            if e.data.data == in_array.data), None)
        if memlet_edge is None:
            raise RuntimeError('Reduction memlet cannot be None')

        # Find which indices should be removed from new memlet. These edges
        # must be queried before the nodes are removed from the graph.
        input_edge = graph.in_edges(reduce_node)[0]
        # No explicit axes means that all dimensions are reduced
        axes = reduce_node.axes or list(range(len(input_edge.data.subset)))
        array_edge = graph.out_edges(reduce_node)[0]

        # Delete relevant edges and nodes
        graph.remove_nodes_from(nodes_to_remove)

        # Delete relevant data descriptors
        for node in set(nodes_to_remove):
            if isinstance(node, nodes.AccessNode):
                # try to delete it
                try:
                    sdfg.remove_data(node.data)
                # will raise ValueError if the datadesc is used somewhere else
                except ValueError:
                    # Best effort: keep the descriptor if it is still in use
                    pass

        # Filter out reduced dimensions from subset
        filtered_subset = [
            dim for i, dim in enumerate(memlet_edge.data.subset)
            if i not in axes
        ]
        if len(filtered_subset) == 0:  # Output is a scalar
            filtered_subset = [(0, 0, 1)]

        # Modify edge from tasklet to map exit
        memlet_edge.data.data = out_array.data
        memlet_edge.data.wcr = reduce_node.wcr
        memlet_edge.data.subset = type(
            memlet_edge.data.subset)(filtered_subset)

        # Add edge from map exit to output array. The map exit's output
        # connector is derived from the input connector name by replacing
        # its first three characters with 'OUT_'.
        graph.add_edge(
            memlet_edge.dst, 'OUT_' + memlet_edge.dst_conn[3:], array_edge.dst,
            array_edge.dst_conn,
            Memlet.simple(array_edge.data.data,
                          array_edge.data.subset,
                          num_accesses=array_edge.data.num_accesses,
                          wcr_str=reduce_node.wcr))

        # Add initialization state as necessary
        if not self.no_init and reduce_node.identity is not None:
            init_state = sdfg.add_state_before(graph)
            # One map dimension per dimension of the output subset; subset
            # ranges are end-inclusive, hence the "+ 1" for the map range
            init_state.add_mapped_tasklet(
                'freduce_init',
                [('o%d' % i, '%s:%s:%s' % (r[0], r[1] + 1, r[2]))
                 for i, r in enumerate(array_edge.data.subset)], {},
                '__out = %s' % reduce_node.identity, {
                    '__out':
                    Memlet.simple(
                        array_edge.data.data, ','.join([
                            'o%d' % i
                            for i in range(len(array_edge.data.subset))
                        ]))
                },
                external_edges=True)
예제 #28
0
class DoubleBuffering(transformation.SingleStateTransformation):
    """ Implements the double buffering pattern, which pipelines reading
        and processing data by creating a second copy of the memory.
        In particular, the transformation takes a 1D map and all internal
        (directly connected) transients, adds an additional dimension of size 2,
        and turns the map into a for loop that processes and reads the data in a
        double-buffered manner. Other memlets will not be transformed.

        While the transformation is in progress, the buffer index in the
        rewritten memlets is the placeholder symbol ``__dace_db_param``. It
        is substituted with a concrete expression separately in each
        resulting state and removed from the nested SDFG at the end.
    """

    # Entry node of the map to be converted into a double-buffered loop
    map_entry = transformation.PatternNode(nodes.MapEntry)
    # Access node directly connected to the map entry
    transient = transformation.PatternNode(nodes.AccessNode)

    @classmethod
    def expressions(cls):
        """ Returns the single pattern graph: map entry -> access node. """
        return [sdutil.node_path_graph(cls.map_entry, cls.transient)]

    def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
        """ Checks whether the matched map can be double-buffered.

            The match is accepted only if the map has exactly one parameter,
            ``MapToForLoop`` reports that it can be applied to the same map,
            and every access node directly connected to the map entry refers
            to a transient ``data.Array``.

            :param graph: The state in which the pattern was matched.
            :param expr_index: Index of the matched pattern expression.
            :param sdfg: The SDFG containing ``graph``.
            :param permissive: Forwarded to ``MapToForLoop.can_be_applied``.
            :return: True if the transformation can be applied, False
                     otherwise.
        """
        map_entry = self.map_entry
        transient = self.transient

        # Only one dimensional maps are allowed
        if len(map_entry.map.params) != 1:
            return False

        # Verify the map can be transformed to a for-loop
        m2for = MapToForLoop()
        m2for.setup_match(
            sdfg, sdfg.sdfg_id, self.state_id,
            {MapToForLoop.map_entry: self.subgraph[DoubleBuffering.map_entry]},
            expr_index)
        if not m2for.can_be_applied(graph, expr_index, sdfg, permissive):
            return False

        # Verify that all directly-connected internal access nodes point to
        # transient arrays
        first = True
        for edge in graph.out_edges(map_entry):
            if isinstance(edge.dst, nodes.AccessNode):
                desc = sdfg.arrays[edge.dst.data]
                if not isinstance(desc, data.Array) or not desc.transient:
                    return False
                else:
                    # To avoid duplicate matches, only match the first transient
                    if first and edge.dst != transient:
                        return False
                    first = False

        return True

    def apply(self, graph: sd.SDFGState, sdfg: sd.SDFG):
        """ Applies double buffering to the matched map.

            The steps, in order, are:
              1. Shorten the map range by one stride.
              2. Prepend a dimension of size 2 to every transient that is
                 directly connected to the map entry, and rewrite the memlets
                 in the map scope to index that dimension with the
                 placeholder ``__dace_db_param``.
              3. Convert the map to a for-loop using ``MapToForLoop``.
              4. Move the reads into the transients out of the loop body,
                 adding them to the nested SDFG's start state (first
                 iteration) and back to the loop body with the map parameter
                 advanced by one stride (next iteration).
              5. Copy the loop body (without the reads) into the nested
                 SDFG's sink state, which performs the final computation.

            :param graph: The state in which the pattern was matched.
            :param sdfg: The SDFG containing ``graph``.
            :return: The nested SDFG node produced by ``MapToForLoop``.
        """
        map_entry = self.map_entry

        map_param = map_entry.map.params[0]  # Assuming one dimensional

        ##############################
        # Change condition of loop to one fewer iteration (so that the
        # final one reads from the last buffer)
        map_rstart, map_rend, map_rstride = map_entry.map.range[0]
        map_rend = symbolic.pystr_to_symbolic('(%s) - (%s)' %
                                              (map_rend, map_rstride))
        map_entry.map.range = subsets.Range([(map_rstart, map_rend,
                                              map_rstride)])

        ##############################
        # Gather transients to modify (names of the data containers whose
        # access nodes are directly connected to the map entry)
        transients_to_modify = set(edge.dst.data
                                   for edge in graph.out_edges(map_entry)
                                   if isinstance(edge.dst, nodes.AccessNode))

        # Add dimension to transients and modify memlets
        for transient in transients_to_modify:
            desc: data.Array = sdfg.arrays[transient]
            # Using non-python syntax to ensure properties change
            # The new leading dimension has size 2 and a stride equal to the
            # previous total size, i.e., two copies of the original buffer
            desc.strides = [desc.total_size] + list(desc.strides)
            desc.shape = [2] + list(desc.shape)
            desc.offset = [0] + list(desc.offset)
            desc.total_size = desc.total_size * 2

        ##############################
        # Modify memlets to use map parameter as buffer index
        modified_subsets = []  # Store modified memlets for final state
        for edge in graph.scope_subgraph(map_entry).edges():
            if edge.data.data in transients_to_modify:
                edge.data.subset = self._modify_memlet(sdfg, edge.data.subset,
                                                       edge.data.data)
                modified_subsets.append(edge.data.subset)
            else:  # Could be other_subset
                # The memlet names another container; check whether either
                # end of its memlet path is one of the modified transients
                path = graph.memlet_path(edge)
                src_node = path[0].src
                dst_node = path[-1].dst

                # other_subset could be None. In that case, recreate from array
                dataname = None
                if (isinstance(src_node, nodes.AccessNode)
                        and src_node.data in transients_to_modify):
                    dataname = src_node.data
                elif (isinstance(dst_node, nodes.AccessNode)
                      and dst_node.data in transients_to_modify):
                    dataname = dst_node.data
                if dataname is not None:
                    subset = (edge.data.other_subset or
                              subsets.Range.from_array(sdfg.arrays[dataname]))
                    edge.data.other_subset = self._modify_memlet(
                        sdfg, subset, dataname)
                    modified_subsets.append(edge.data.other_subset)

        ##############################
        # Turn map into for loop
        map_to_for = MapToForLoop()
        map_to_for.setup_match(
            sdfg, self.sdfg_id, self.state_id,
            {MapToForLoop.map_entry: graph.node_id(self.map_entry)},
            self.expr_index)
        # nstate is the state of the nested SDFG that holds the loop body
        nsdfg_node, nstate = map_to_for.apply(graph, sdfg)

        ##############################
        # Gather node copies and remove memlets
        # (edges from source nodes of the loop body into the modified
        # transients are detached here and re-added to other states below)
        edges_to_replace = []
        for node in nstate.source_nodes():
            for edge in nstate.out_edges(node):
                if (isinstance(edge.dst, nodes.AccessNode)
                        and edge.dst.data in transients_to_modify):
                    edges_to_replace.append(edge)
                    nstate.remove_edge(edge)
            # Drop source nodes that are left without any outgoing edge
            if nstate.out_degree(node) == 0:
                nstate.remove_node(node)

        ##############################
        # Add initial reads to initial nested state
        initial_state: sd.SDFGState = nsdfg_node.sdfg.start_state
        initial_state.set_label('%s_init' % map_entry.map.label)
        for edge in edges_to_replace:
            initial_state.add_node(edge.src)
            rnode = edge.src
            wnode = initial_state.add_write(edge.dst.data)
            initial_state.add_edge(rnode, edge.src_conn, wnode, edge.dst_conn,
                                   copy.deepcopy(edge.data))

        # All instances of the map parameter in this state become the loop start
        sd.replace(initial_state, map_param, map_rstart)
        # Initial writes go to the appropriate buffer
        init_expr = symbolic.pystr_to_symbolic('(%s / %s) %% 2' %
                                               (map_rstart, map_rstride))
        sd.replace(initial_state, '__dace_db_param', init_expr)

        ##############################
        # Modify main state's memlets

        # Divide by loop stride
        # (the detached read edges are not in nstate at this point, so only
        # the remaining memlets of the loop body are affected)
        new_expr = symbolic.pystr_to_symbolic('(%s / %s) %% 2' %
                                              (map_param, map_rstride))
        sd.replace(nstate, '__dace_db_param', new_expr)

        ##############################
        # Add the main state's contents to the last state, modifying
        # memlets appropriately.
        final_state: sd.SDFGState = nsdfg_node.sdfg.sink_nodes()[0]
        final_state.set_label('%s_final_computation' % map_entry.map.label)
        # The copy is taken before the next-buffer reads are added below, so
        # the final state contains the computation only
        dup_nstate = copy.deepcopy(nstate)
        final_state.add_nodes_from(dup_nstate.nodes())
        for e in dup_nstate.edges():
            final_state.add_edge(e.src, e.src_conn, e.dst, e.dst_conn, e.data)

        # If there is a WCR output with transient, only output in last state
        nstate: sd.SDFGState
        for node in nstate.sink_nodes():
            for e in list(nstate.in_edges(node)):
                if e.data.wcr is not None:
                    path = nstate.memlet_path(e)
                    if isinstance(path[0].src, nodes.AccessNode):
                        nstate.remove_memlet_path(e)

        ##############################
        # Add reads into next buffers to main state
        for edge in edges_to_replace:
            rnode = copy.deepcopy(edge.src)
            nstate.add_node(rnode)
            wnode = nstate.add_write(edge.dst.data)
            new_memlet = copy.deepcopy(edge.data)
            # Advance the map parameter by one stride on the side of the
            # memlet that does not refer to the transient, so that the data
            # of the next iteration is read
            if new_memlet.data in transients_to_modify:
                new_memlet.other_subset = self._replace_in_subset(
                    new_memlet.other_subset, map_param,
                    '(%s + %s)' % (map_param, map_rstride))
            else:
                new_memlet.subset = self._replace_in_subset(
                    new_memlet.subset, map_param,
                    '(%s + %s)' % (map_param, map_rstride))

            nstate.add_edge(rnode, edge.src_conn, wnode, edge.dst_conn,
                            new_memlet)

        nstate.set_label('%s_double_buffered' % map_entry.map.label)
        # Divide by loop stride
        # (the placeholder was already substituted in the other memlets of
        # this state above, so this targets the read memlets just added,
        # which write into the other buffer)
        new_expr = symbolic.pystr_to_symbolic('((%s / %s) + 1) %% 2' %
                                              (map_param, map_rstride))
        sd.replace(nstate, '__dace_db_param', new_expr)

        # Remove symbol once done
        del nsdfg_node.sdfg.symbols['__dace_db_param']
        del nsdfg_node.symbol_mapping['__dace_db_param']

        return nsdfg_node

    @staticmethod
    def _modify_memlet(sdfg, subset, data_name):
        """ Returns a new range that indexes the buffer dimension of
            ``data_name`` with the placeholder ``__dace_db_param``.

            :param sdfg: The SDFG in which ``data_name`` is defined.
            :param subset: The subset to extend.
            :param data_name: Name of the data container the subset refers to.
            :return: A ``subsets.Range`` whose first dimension is
                     ``__dace_db_param`` followed by the dimensions of
                     ``subset``. If ``subset`` already has as many dimensions
                     as the container's shape, its first dimension is
                     replaced rather than a new one being added.
        """
        desc = sdfg.arrays[data_name]
        if len(subset) == len(desc.shape):
            # Already in the right shape, modify new dimension
            subset = list(subset)[1:]

        new_subset = subsets.Range([('__dace_db_param', '__dace_db_param',
                                     1)] + list(subset))
        return new_subset

    @staticmethod
    def _replace_in_subset(subset, string_or_symbol, new_string_or_symbol):
        """ Returns a deep copy of ``subset`` in which a symbol is substituted
            with another expression. The input subset is not modified.

            :param subset: The subset to copy and substitute in.
            :param string_or_symbol: The symbol to replace; converted with
                                     ``symbolic.pystr_to_symbolic``.
            :param new_string_or_symbol: The replacement expression; converted
                                         with ``symbolic.pystr_to_symbolic``.
            :return: The modified copy of ``subset``.
        """
        new_subset = copy.deepcopy(subset)

        repldict = {
            symbolic.pystr_to_symbolic(string_or_symbol):
            symbolic.pystr_to_symbolic(new_string_or_symbol)
        }

        for i, dim in enumerate(new_subset):
            try:
                # Dimension given as an iterable of expressions
                new_subset[i] = tuple(d.subs(repldict) for d in dim)
            except TypeError:
                # Dimension is a single (non-iterable) value: substitute only
                # if it is symbolic, otherwise leave it unchanged
                new_subset[i] = (dim.subs(repldict)
                                 if symbolic.issymbolic(dim) else dim)

        return new_subset
예제 #29
0
class StreamingComposition(xf.Transformation):
    """
    Converts two connected computations (nodes, map scopes) into two separate
    processing elements, with a stream connecting the results. Only applies
    if the memory access patterns of the two computations match.

    The matched pattern is the node path ``first -> access -> second``, where
    ``access`` is the access node that is replaced by a stream.
    """
    first = xf.PatternNode(nodes.Node)
    access = xf.PatternNode(nodes.AccessNode)
    second = xf.PatternNode(nodes.Node)

    buffer_size = properties.Property(
        dtype=int,
        default=1,
        desc='Set buffer size for the newly-created stream')

    storage = properties.EnumProperty(
        dtype=dtypes.StorageType,
        desc='Set storage type for the newly-created stream',
        default=dtypes.StorageType.Default)

    @staticmethod
    def expressions() -> List[gr.SubgraphView]:
        """ Returns the single pattern graph: first -> access -> second. """
        return [
            sdutil.node_path_graph(StreamingComposition.first,
                                   StreamingComposition.access,
                                   StreamingComposition.second)
        ]

    @staticmethod
    def can_be_applied(graph: SDFGState,
                       candidate: Dict[xf.PatternNode, int],
                       expr_index: int,
                       sdfg: SDFG,
                       permissive: bool = False) -> bool:
        """
        Checks whether the matched access node can be turned into a stream.

        :param graph: The state in which the pattern was matched.
        :param candidate: Mapping from pattern nodes to node IDs in ``graph``.
        :param expr_index: Index of the matched pattern expression (unused).
        :param sdfg: The SDFG containing ``graph``.
        :param permissive: Unused by this check.
        :return: True if the transformation can be applied, False otherwise.
        """
        access = graph.node(candidate[StreamingComposition.access])
        # Make sure the access node is only accessed once (read or write),
        # and not at the same time
        if graph.in_degree(access) > 1 or graph.out_degree(access) > 1:
            return False

        # If already a stream, skip
        if isinstance(sdfg.arrays[access.data], data.Stream):
            return False

        # Only free nodes are allowed (search up the SDFG tree): neither the
        # access node nor any nested SDFG node containing it may be in a scope
        curstate = graph
        node = access
        while curstate is not None:
            if curstate.entry_node(node) is not None:
                return False
            if curstate.parent.parent_nsdfg_node is None:
                break
            node = curstate.parent.parent_nsdfg_node
            curstate = curstate.parent.parent

        # Array must not be used anywhere else in the state
        if any(n is not access and n.data == access.data
               for n in graph.data_nodes()):
            return False

        # Only one memlet path on each direction is allowed
        # TODO: Relax so that repeated application of
        # transformation would yield additional streams
        first_edge = graph.in_edges(access)[0]
        second_edge = graph.out_edges(access)[0]
        first_mpath = graph.memlet_path(first_edge)
        second_mpath = graph.memlet_path(second_edge)
        # A memlet tree with more edges than the path means that it branches
        if len(first_mpath) != len(list(graph.memlet_tree(first_edge))):
            return False
        if len(second_mpath) != len(list(graph.memlet_tree(second_edge))):
            return False

        # The innermost ends of the paths must have a clearly defined memory
        # access pattern and no WCR
        first_iedge = first_mpath[0]
        second_iedge = second_mpath[-1]
        if first_iedge.data.subset.num_elements() != 1:
            return False
        if first_iedge.data.volume != 1:
            return False
        if first_iedge.data.wcr is not None:
            return False
        if second_iedge.data.subset.num_elements() != 1:
            return False
        if second_iedge.data.volume != 1:
            return False

        ##################################################################
        # The memory access pattern must be exactly the same

        # Collect all maps and ranges
        ranges_first = _collect_map_ranges(graph, first_mpath)
        ranges_second = _collect_map_ranges(graph, second_mpath)

        # Check map ranges
        for (_, frng), (_, srng) in zip(ranges_first, ranges_second):
            if frng != srng:
                return False

        # Check memlets for equivalence
        if len(first_iedge.data.subset) != len(second_iedge.data.subset):
            return False
        if not _do_memlets_correspond(first_iedge.data, second_iedge.data,
                                      ranges_first, ranges_second):
            return False

        return True

    def apply(self, sdfg: SDFG) -> tuple:
        """
        Replaces the matched access node with a newly-created stream.

        A transient stream with the data type of the original array is added
        to the SDFG, every memlet on the paths into and out of the access
        node is redirected to it, and the access node is replaced by one
        write node and one read node of the stream. The original array
        descriptor is removed if no other state accesses it.

        :param sdfg: The SDFG to modify.
        :return: A 2-tuple ``(wnode, rnode)`` with the new write and read
                 access nodes of the stream.
        """
        state = sdfg.node(self.state_id)
        access: nodes.AccessNode = self.access(sdfg)

        # Get memlet paths
        first_edge = state.in_edges(access)[0]
        second_edge = state.out_edges(access)[0]
        first_mpath = state.memlet_path(first_edge)
        second_mpath = state.memlet_path(second_edge)

        # Create new stream of shape 1
        desc = sdfg.arrays[access.data]
        name, newdesc = sdfg.add_stream(access.data,
                                        desc.dtype,
                                        buffer_size=self.buffer_size,
                                        storage=self.storage,
                                        transient=True,
                                        find_new_name=True)

        # Remove transient array if possible (for-else: the descriptor is
        # deleted only if no other state contains an access node of the array)
        for ostate in sdfg.nodes():
            if ostate is state:
                continue
            if any(n.data == access.data for n in ostate.data_nodes()):
                break
        else:
            del sdfg.arrays[access.data]

        # Replace memlets in path with stream access, first along the path
        # into the access node and then along the path out of it
        for mpath in (first_mpath, second_mpath):
            for e in mpath:
                e.data = mm.Memlet(data=name, subset='0')
                # Memlets attached to a nested SDFG are marked dynamic, and
                # the connector's data inside it is converted as well
                if isinstance(e.src, nodes.NestedSDFG):
                    e.data.dynamic = True
                    _streamify_recursive(e.src, e.src_conn, newdesc)
                if isinstance(e.dst, nodes.NestedSDFG):
                    e.data.dynamic = True
                    _streamify_recursive(e.dst, e.dst_conn, newdesc)

        # Replace array access node with two stream access nodes
        wnode = state.add_write(name)
        rnode = state.add_read(name)
        state.remove_edge(first_edge)
        state.add_edge(first_edge.src, first_edge.src_conn, wnode,
                       first_edge.dst_conn, first_edge.data)
        state.remove_edge(second_edge)
        state.add_edge(rnode, second_edge.src_conn, second_edge.dst,
                       second_edge.dst_conn, second_edge.data)

        # Remove original access node
        state.remove_node(access)

        return wnode, rnode
예제 #30
0
class MapExpansion(pm.Transformation):
    """ Implements the map-expansion pattern.

        Map-expansion splits a map with N parameters into N nested maps with
        one parameter each. The original map keeps its first parameter; every
        further parameter gets its own, sequentially-scheduled map nested
        inside it.

        New edges abide by the following rules:
          1. If there are no edges coming from the outside, use empty memlets
          2. Edges with IN_* connectors replicate along the maps
          3. Edges for dynamic map ranges replicate until reaching range(s)
    """

    map_entry = pm.PatternNode(nodes.MapEntry)

    @staticmethod
    def expressions():
        """ Returns the single pattern graph, consisting of one map entry. """
        return [sdutil.node_path_graph(MapExpansion.map_entry)]

    @staticmethod
    def can_be_applied(graph: dace.SDFGState,
                       candidate: Dict[pm.PatternNode, int],
                       expr_index: int,
                       sdfg: dace.SDFG,
                       strict: bool = False):
        """ Matches any map that has more than one parameter. """
        matched_entry = graph.node(candidate[MapExpansion.map_entry])
        num_params = matched_entry.map.get_param_num()
        return num_params > 1

    @staticmethod
    def match_to_str(graph: dace.SDFGState, candidate: Dict[pm.PatternNode,
                                                            int]) -> str:
        """ Describes a match as the map label followed by its parameters. """
        matched_map = graph.node(candidate[MapExpansion.map_entry]).map
        return matched_map.label + ': ' + str(matched_map.params)

    def apply(self, sdfg: dace.SDFG):
        """ Expands the matched map into nested one-parameter maps.

            :param sdfg: The SDFG to modify.
            :return: List of map entry nodes, from the original (outermost)
                     entry to the innermost newly-created one.
            :raise ValueError: If the scope of the innermost new map cannot
                               be found in the state.
        """
        # Fetch the state, the matched map and its entry/exit nodes
        state = sdfg.node(self.state_id)
        outer_entry = self.map_entry(sdfg)
        outer_exit = state.exit_node(outer_entry)
        outer_map = outer_entry.map

        # One sequential map for every parameter after the first one
        inner_maps = []
        for param, param_range in zip(outer_map.params[1:],
                                      outer_map.range[1:]):
            inner_maps.append(
                nodes.Map(outer_map.label + '_' + str(param), [param],
                          subsets.Range([param_range]),
                          schedule=dtypes.ScheduleType.Sequential))

        # The original map is reduced to its first parameter
        outer_map.params = [outer_map.params[0]]
        outer_map.range = subsets.Range([outer_map.range[0]])

        # Entry and exit nodes of the new maps, outermost first
        entries = [nodes.MapEntry(inner_map) for inner_map in inner_maps]
        exits = [nodes.MapExit(inner_map) for inner_map in inner_maps]

        # Re-route everything leaving the original entry through the new
        # entries. Rules for the created edges:
        # 1. If there are no edges coming from the outside, use empty memlets
        # 2. Edges with IN_* connectors replicate along the maps
        # 3. Edges for dynamic map ranges replicate until reaching range(s)
        for out_edge in state.out_edges(outer_entry):
            state.remove_edge(out_edge)
            state.add_memlet_path(outer_entry,
                                  *entries,
                                  out_edge.dst,
                                  src_conn=out_edge.src_conn,
                                  memlet=out_edge.data,
                                  dst_conn=out_edge.dst_conn)

        # Dynamic map range inputs: detach each one from the original entry
        # and connect it to every map whose range uses its symbol
        for dyn_edge in dace.sdfg.dynamic_map_inputs(state, outer_entry):
            # Drop the old edge together with its connector
            state.remove_edge(dyn_edge)
            dyn_edge.dst.remove_in_connector(dyn_edge.dst_conn)

            # Grow the chain of entries from the outside in, adding a memlet
            # path whenever the current map's range refers to the connector
            chain = []
            for scope_entry in [outer_entry] + entries:
                chain.append(scope_entry)
                symbol_in_range = any(
                    dyn_edge.dst_conn in map(str, symbolic.symlist(rng))
                    for rng in scope_entry.map.range)
                if symbol_in_range:
                    state.add_memlet_path(dyn_edge.src,
                                          *chain,
                                          memlet=dyn_edge.data,
                                          src_conn=dyn_edge.src_conn,
                                          dst_conn=dyn_edge.dst_conn)

        # Re-route everything entering the original exit through the new
        # exits, innermost first
        for in_edge in state.in_edges(outer_exit):
            state.remove_edge(in_edge)
            state.add_memlet_path(in_edge.src,
                                  *reversed(exits),
                                  outer_exit,
                                  memlet=in_edge.data,
                                  src_conn=in_edge.src_conn,
                                  dst_conn=in_edge.dst_conn)

        # Walk from the scope leaves towards the root until the scope of the
        # innermost new map is found
        from dace.sdfg.scope import ScopeTree
        scope = None
        pending: List[ScopeTree] = state.scope_leaves()
        while pending:
            current = pending.pop()
            if current.entry == entries[-1]:
                scope = current
                break
            if current.parent is not None:
                pending.append(current.parent)
        else:
            raise ValueError('Cannot find scope in state')

        consolidate_edges(sdfg, scope)

        return [outer_entry] + entries