Beispiel #1
0
class MapDimInterchange(pm.Transformation):
    """ Map-dimension-interchange transformation.

        Permutes the dimensions of a single map: both the parameter list
        and the range of the map are re-ordered according to the ``order``
        property.
    """

    # Pattern node: the entry node of the map whose dimensions are permuted.
    _map_entry = nodes.MapEntry(nodes.Map("", [], []))

    # New dimension order; entry i names the old dimension placed at slot i.
    order = ShapeProperty()

    def __init__(self, *args, **kwargs):
        self.entry = nodes.EntryNode()
        self.tasklet = nodes.Tasklet('_')
        self.exit = nodes.ExitNode()
        self.pairs = None
        super().__init__(*args, **kwargs)

    @staticmethod
    def expressions():
        """ Returns the single pattern graph: one map entry node. """
        return [nxutil.node_path_graph(MapDimInterchange._map_entry)]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Matches only maps that have more than one dimension, since
            there is nothing to interchange otherwise.
        """
        entry_id = candidate[MapDimInterchange._map_entry]
        num_dims = graph.nodes()[entry_id].map.get_param_num()
        return num_dims > 1

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns the string form of the candidate's map-entry value. """
        return str(candidate[MapDimInterchange._map_entry])

    def apply(self, sdfg):
        """ Permutes the parameters and the range of the matched map
            following the ``order`` property. Does nothing when ``order``
            does not have exactly one entry per map dimension.
        """
        state = sdfg.nodes()[self.state_id]
        entry_id = self.subgraph[MapDimInterchange._map_entry]
        target_map = state.nodes()[entry_id].map

        permutation = self.order
        if len(self.order) == target_map.get_param_num():
            # Parameters and range must be permuted in lockstep.
            target_map.params = [target_map.params[i] for i in permutation]
            target_map.range.reorder(permutation)

    def modifies_graph(self):
        return True
Beispiel #2
0
class RedundantArrayCopying3(pm.Transformation):
    """ Redundant array removal for the pattern MapEntry -> B.

        When a map entry feeds several access nodes that refer to the same
        array as the matched node B, the extra access nodes are removed and
        their outgoing edges are re-attached to B.
    """

    # Number of removed access nodes (only counted when "debugprint" is on).
    _arrays_removed = 0
    # Pattern nodes: the map entry and the access node that is kept.
    _map_entry = nodes.MapEntry(nodes.Map("", [], []))
    _out_array = nodes.AccessNode("_")

    @staticmethod
    def expressions():
        """ Returns the single pattern graph: map entry -> access node. """
        pattern = nxutil.node_path_graph(RedundantArrayCopying3._map_entry,
                                         RedundantArrayCopying3._out_array)
        return [pattern]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Matches when the map entry has at least one other outgoing
            access node that refers to the same data as the matched one.
        """
        map_entry = graph.nodes()[candidate[RedundantArrayCopying3._map_entry]]
        out_array = graph.nodes()[candidate[RedundantArrayCopying3._out_array]]

        # Look for a sibling access node of the same array
        return any(
            isinstance(sibling, nodes.AccessNode) and sibling != out_array
            and sibling.data == out_array.data
            for _, _, sibling, _, _ in graph.out_edges(map_entry))

    @staticmethod
    def match_to_str(graph, candidate):
        kept = graph.nodes()[candidate[RedundantArrayCopying3._out_array]]
        return "Remove " + str(kept)

    def apply(self, sdfg):
        """ Redirects the consumers of every duplicate access node to the
            matched access node and removes the duplicates from the state.
        """
        state = sdfg.nodes()[self.state_id]
        entry = state.nodes()[self.subgraph[
            RedundantArrayCopying3._map_entry]]
        kept = state.nodes()[self.subgraph[RedundantArrayCopying3._out_array]]

        for entry_edge in state.out_edges(entry):
            duplicate = entry_edge.dst
            is_duplicate = (isinstance(duplicate, nodes.AccessNode)
                            and duplicate != kept
                            and duplicate.data == kept.data)
            if not is_duplicate:
                continue

            # Move all consumers of the duplicate over to the kept node
            for consumer_edge in state.out_edges(duplicate):
                state.add_edge(kept, None, consumer_edge.dst,
                               consumer_edge.dst_conn, consumer_edge.data)
                state.remove_edge(consumer_edge)

            # Detach and delete the duplicate itself
            state.remove_edge(entry_edge)
            state.remove_node(duplicate)
            if Config.get_bool("debugprint"):
                RedundantArrayCopying3._arrays_removed += 1
Beispiel #3
0
class GPUTransformMap(pattern_matching.Transformation):
    """ Implements the GPUTransformMap transformation.

        Nests the matched map scope (or a single reduce node) into a nested
        SDFG and applies ``GPUTransformSDFG`` to that nested SDFG, then
        re-applies the strict transformations on the outer SDFG.
    """

    fullcopy = Property(desc="Copy whole arrays rather than used subset",
                        dtype=bool,
                        default=False)

    toplevel_trans = Property(desc="Make all GPU transients top-level",
                              dtype=bool,
                              default=False)

    register_trans = Property(
        desc="Make all transients inside GPU maps registers",
        dtype=bool,
        default=False)

    sequential_innermaps = Property(desc="Make all internal maps Sequential",
                                    dtype=bool,
                                    default=False)

    # Pattern nodes: expression 0 matches a map entry, expression 1 a reduce.
    _map_entry = nodes.MapEntry(nodes.Map("", [], []))
    _reduce = nodes.Reduce('lambda: None', None)

    @staticmethod
    def expressions():
        """ Returns the two single-node patterns (map entry, reduce). """
        map_pattern = nxutil.node_path_graph(GPUTransformMap._map_entry)
        reduce_pattern = nxutil.node_path_graph(GPUTransformMap._reduce)
        return [map_pattern, reduce_pattern]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Rejects nodes that are MPI- or GPU-scheduled or already at
            device level; maps are additionally rejected for dynamic map
            inputs, inner arrays in non-default/non-register storage, and
            stream outputs.
        """
        if expr_index == 0:
            map_entry = graph.nodes()[candidate[GPUTransformMap._map_entry]]

            # Map schedules that are disallowed to transform to GPUs
            forbidden = [dtypes.ScheduleType.MPI] + dtypes.GPU_SCHEDULES
            if map_entry.map.schedule in forbidden:
                return False
            if sd.is_devicelevel(sdfg, graph, map_entry):
                return False

            # Dynamic map ranges cannot become kernels
            if sd.has_dynamic_map_inputs(graph, map_entry):
                return False

            # Arrays inside the map must live in default or register storage
            allowed_storage = (dtypes.StorageType.Default,
                               dtypes.StorageType.Register)
            for inner in graph.scope_subgraph(map_entry).nodes():
                if not isinstance(inner, nodes.AccessNode):
                    continue
                if inner.desc(sdfg).storage not in allowed_storage:
                    return False

            # A map that writes to a stream is not matched
            map_exit = graph.exit_nodes(map_entry)[0]
            for out_edge in graph.out_edges(map_exit):
                final_dst = graph.memlet_path(out_edge)[-1].dst
                if not isinstance(final_dst, nodes.AccessNode):
                    continue
                if isinstance(sdfg.arrays[final_dst.data], data.Stream):
                    return False

            return True
        elif expr_index == 1:
            reduce_node = graph.nodes()[candidate[GPUTransformMap._reduce]]

            # Reduce schedules that are disallowed to transform to GPUs
            forbidden = [dtypes.ScheduleType.MPI] + dtypes.GPU_SCHEDULES
            if reduce_node.schedule in forbidden:
                return False
            if sd.is_devicelevel(sdfg, graph, reduce_node):
                return False

            return True

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns the string form of the matched reduce or map entry. """
        if GPUTransformMap._reduce in candidate:
            matched_key = GPUTransformMap._reduce
        else:
            matched_key = GPUTransformMap._map_entry
        return str(graph.nodes()[candidate[matched_key]])

    def apply(self, sdfg):
        """ Nests the matched subgraph and GPU-transforms the nested SDFG,
            forwarding the transient/schedule options of this transformation.
        """
        state = sdfg.nodes()[self.state_id]

        # Pick the subgraph to nest: the whole map scope or the lone reduce
        if self.expr_index == 0:
            map_entry = state.nodes()[self.subgraph[
                GPUTransformMap._map_entry]]
            to_nest = state.scope_subgraph(map_entry)
        else:
            reduce_node = state.nodes()[self.subgraph[
                GPUTransformMap._reduce]]
            to_nest = SubgraphView(state, [reduce_node])
        nsdfg_node = helpers.nest_state_subgraph(sdfg,
                                                 state,
                                                 to_nest,
                                                 full_data=self.fullcopy)

        # Avoiding import loops
        from dace.transformation.interstate import GPUTransformSDFG
        gpu_xform = GPUTransformSDFG(0, 0, {}, 0)
        gpu_xform.register_trans = self.register_trans
        gpu_xform.sequential_innermaps = self.sequential_innermaps
        gpu_xform.toplevel_trans = self.toplevel_trans

        gpu_xform.apply(nsdfg_node.sdfg)

        # Inline back as necessary
        sdfg.apply_strict_transformations()
Beispiel #4
0
class InLocalStorage(pattern_matching.Transformation):
    """ Implements the InLocalStorage transformation, which inserts a
        transient array (named ``trans_<array>``) on an edge between two
        nested map entry nodes and re-bases the memlets inside the inner
        map onto that transient.
    """

    # Pattern nodes: outer map entry directly connected to inner map entry.
    _outer_map_entry = nodes.MapEntry(nodes.Map("", [], []))
    _inner_map_entry = nodes.MapEntry(nodes.Map("", [], []))

    array = Property(
        dtype=str,
        desc="Array to create local storage for (if empty, first available)",
        default=None,
        allow_none=True)

    @staticmethod
    def annotates_memlets():
        return True

    @staticmethod
    def expressions():
        """ Returns the single pattern graph: outer entry -> inner entry. """
        pattern = nxutil.node_path_graph(InLocalStorage._outer_map_entry,
                                         InLocalStorage._inner_map_entry)
        return [pattern]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        return True

    @staticmethod
    def match_to_str(graph, candidate):
        """ Joins the two candidate values with ' -> '. """
        matched = (candidate[InLocalStorage._outer_map_entry],
                   candidate[InLocalStorage._inner_map_entry])
        return ' -> '.join(str(item) for item in matched)

    def apply(self, sdfg):
        """ Creates the transient, splices its access node into the chosen
            outer->inner edge, and offsets every memlet of the array that is
            reachable from the inner map entry.

            Raises KeyError when no edge connects the two map entries.
        """
        state = sdfg.nodes()[self.state_id]
        outer_entry = state.nodes()[self.subgraph[
            InLocalStorage._outer_map_entry]]
        inner_entry = state.nodes()[self.subgraph[
            InLocalStorage._inner_map_entry]]

        array = self.array
        if array is None:
            array = state.edges_between(outer_entry,
                                        inner_entry)[0].data.data

        # Only edges running directly from the outer to the inner entry
        direct_edges = [
            e for e in state.in_edges(inner_entry) if e.src == outer_entry
        ]

        # Prefer the edge carrying the requested array; otherwise fall back
        # to the first direct edge and report the substitution.
        original_edge = next(
            (e for e in direct_edges if array == e.data.data), None)
        if original_edge is None and direct_edges:
            original_edge = direct_edges[0]
            print('WARNING: Array %s not found! Using array %s instead.' %
                  (array, original_edge.data.data))
            array = original_edge.data.data
        if original_edge is None:
            raise KeyError('Array %s not found!' % array)
        invariant_memlet = original_edge.data

        # Transient sized by the (over-approximated) memlet bounding box
        local_shape = [
            symbolic.overapproximate(extent)
            for extent in invariant_memlet.bounding_box_size()
        ]
        sdfg.add_array('trans_' + invariant_memlet.data,
                       local_shape,
                       sdfg.arrays[invariant_memlet.data].dtype,
                       transient=True)
        data_node = nodes.AccessNode('trans_' + invariant_memlet.data)

        to_data_mm = copy.deepcopy(invariant_memlet)
        from_data_mm = copy.deepcopy(invariant_memlet)
        from_data_mm.data = data_node.data

        # Shift the transient-side subset so that it starts at the origin,
        # remembering the per-dimension start for the inner memlets.
        offset = []
        for ind, r in enumerate(invariant_memlet.subset):
            start = r[0]
            offset.append(start)
            dim = invariant_memlet.subset[ind]
            if isinstance(dim, tuple):
                from_data_mm.subset[ind] = (dim[0] - start, dim[1] - start,
                                            dim[2])
            else:
                from_data_mm.subset[ind] -= start
        to_data_mm.other_subset = copy.deepcopy(from_data_mm.subset)

        # Reconnect, assuming one edge to the stream
        state.remove_edge(original_edge)
        state.add_edge(outer_entry, original_edge.src_conn, data_node, None,
                       to_data_mm)
        state.add_edge(data_node, None, inner_entry, original_edge.dst_conn,
                       from_data_mm)

        # Re-base all memlets of the array below the inner map entry
        for _parent, _, _child, _, memlet in state.bfs_edges(inner_entry,
                                                             reverse=False):
            if memlet.data != array:
                continue
            for ind, r in enumerate(memlet.subset):
                shift = offset[ind]
                if isinstance(memlet.subset[ind], tuple):
                    memlet.subset[ind] = (r[0] - shift, r[1] - shift, r[2])
                else:
                    memlet.subset[ind] -= shift
            memlet.data = 'trans_' + invariant_memlet.data
Beispiel #5
0
class AccumulateTransient(pattern_matching.Transformation):
    """ Implements the AccumulateTransient transformation, which inserts a
        transient array (named ``trans_<array>``) between the exit nodes of
        two nested maps for an output that is written with write-conflict
        resolution. The transient then acts as a local accumulator.
    """

    # Pattern nodes: tasklet -> inner map exit -> outer map exit.
    _tasklet = nodes.Tasklet('_')
    _map_exit = nodes.MapExit(nodes.Map("", [], []))
    _outer_map_exit = nodes.MapExit(nodes.Map("", [], []))

    @staticmethod
    def expressions():
        """ Returns the single pattern graph, built from this class' own
            pattern nodes so that candidates are keyed by them.
        """
        return [
            nxutil.node_path_graph(AccumulateTransient._tasklet,
                                   AccumulateTransient._map_exit,
                                   AccumulateTransient._outer_map_exit)
        ]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Matches when the tasklet has at least one write-conflict-resolved
            output edge that leads to the inner map exit.
        """
        tasklet = graph.nodes()[candidate[AccumulateTransient._tasklet]]
        map_exit = graph.nodes()[candidate[AccumulateTransient._map_exit]]

        # Check if there is an accumulating (WCR) output
        for _src, _, dest, _, memlet in graph.out_edges(tasklet):
            if memlet.wcr is not None and dest == map_exit:
                return True

        return False

    @staticmethod
    def match_to_str(graph, candidate):
        """ Joins the three candidate values with ' -> '. """
        tasklet = candidate[AccumulateTransient._tasklet]
        map_exit = candidate[AccumulateTransient._map_exit]
        outer_map_exit = candidate[AccumulateTransient._outer_map_exit]

        return ' -> '.join(
            str(node) for node in [tasklet, map_exit, outer_map_exit])

    def apply(self, sdfg):
        """ Adds the transient array to the SDFG and routes the matching
            inner-exit output through a new access node of that transient
            before it reaches the outer map exit.
        """
        graph = sdfg.nodes()[self.state_id]
        tasklet = graph.nodes()[self.subgraph[AccumulateTransient._tasklet]]
        map_exit = graph.nodes()[self.subgraph[AccumulateTransient._map_exit]]
        outer_map_exit = graph.nodes()[self.subgraph[
            AccumulateTransient._outer_map_exit]]

        # Pick the first WCR output of the tasklet that enters the inner
        # map exit (can_be_applied checked that one exists).
        memlet = None
        edge = None
        for e in graph.out_edges(tasklet):
            memlet = e.data
            # TODO: What if there's more than one?
            if e.dst == map_exit and e.data.wcr is not None:
                break

        # Find the edge leaving the inner map exit for the same data
        out_memlet = None
        for e in graph.out_edges(map_exit):
            out_memlet = e.data
            if out_memlet.data == memlet.data:
                edge = e
                break
        dataname = memlet.data

        # Create a new node with the same size as the output
        sdfg.add_array('trans_' + dataname,
                       sdfg.arrays[memlet.data].shape,
                       sdfg.arrays[memlet.data].dtype,
                       transient=True)
        dnode = nodes.AccessNode('trans_' + dataname)

        # Memlet into the transient (inner exit -> transient)
        to_data_mm = copy.deepcopy(memlet)
        to_data_mm.data = dnode.data
        to_data_mm.num_accesses = memlet.num_elements()

        # Memlet out of the transient (transient -> outer exit)
        to_exit_mm = copy.deepcopy(out_memlet)
        to_exit_mm.num_accesses = out_memlet.num_elements()

        # The tasklet now writes to the transient instead of the output
        memlet.data = dnode.data

        # Reconnect, assuming one edge to the stream
        graph.remove_edge(edge)
        graph.add_edge(map_exit, edge.src_conn, dnode, None, to_data_mm)
        graph.add_edge(dnode, None, outer_map_exit, edge.dst_conn, to_exit_mm)

        return

    def modifies_graph(self):
        return True
Beispiel #6
0
class MapTiling(pattern_matching.Transformation):
    """ Implements the orthogonal tiling transformation.

        Every dimension of the matched map is strip-mined in turn (via
        ``StripMining``), and the resulting outer tile maps are merged into
        one map (via ``MapCollapse``).
    """

    # Pattern node: the entry node of the map to tile.
    _map_entry = nodes.MapEntry(nodes.Map("", [], []))

    # Properties
    prefix = Property(dtype=str,
                      default="tile",
                      desc="Prefix for new range symbols")
    tile_sizes = ShapeProperty(dtype=tuple,
                               default=(128, 128, 128),
                               desc="Tile size per dimension")
    strides = ShapeProperty(
        dtype=tuple,
        default=tuple(),
        desc="Tile stride (enables overlapping tiles). If empty, matches tile")
    divides_evenly = Property(dtype=bool,
                              default=False,
                              desc="Tile size divides dimension length evenly")

    @staticmethod
    def annotates_memlets():
        return True

    @staticmethod
    def expressions():
        """ Returns the single pattern graph: one map entry node. """
        return [nxutil.node_path_graph(MapTiling._map_entry)]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        return True

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns '<map label>: <map parameters>' for the matched map. """
        matched_map = graph.nodes()[candidate[MapTiling._map_entry]].map
        return matched_map.label + ': ' + str(matched_map.params)

    def apply(self, sdfg):
        """ Strip-mines each map dimension with its tile size and stride.

            Dimensions beyond the length of ``tile_sizes`` reuse the last
            tile size. ``strides`` is honored only when it has as many
            entries as ``tile_sizes``; otherwise the tile sizes double as
            strides.
        """
        state = sdfg.nodes()[self.state_id]

        strides = self.tile_sizes
        if self.strides is not None and len(self.strides) == len(strides):
            strides = self.strides

        # Retrieve the map entry and set up the nested transformations.
        map_entry = state.nodes()[self.subgraph[MapTiling._map_entry]]
        from dace.transformation.dataflow.map_collapse import MapCollapse
        from dace.transformation.dataflow.strip_mining import StripMining
        stripmine_subgraph = {
            StripMining._map_entry: self.subgraph[MapTiling._map_entry]
        }
        sdfg_id = sdfg.sdfg_list.index(sdfg)
        previous_tile_entry = None
        dropped_dims = 0

        schedule_to_keep = map_entry.schedule

        for orig_idx in range(len(map_entry.map.params)):
            # Fall back to the last configured size for extra dimensions
            src_idx = orig_idx if orig_idx < len(self.tile_sizes) else -1
            tile_size = symbolic.pystr_to_symbolic(self.tile_sizes[src_idx])
            tile_stride = symbolic.pystr_to_symbolic(strides[src_idx])

            # Account for dimensions already removed from the inner map
            dim_idx = orig_idx - dropped_dims

            # If tile size is trivial, skip strip-mining map dimension
            if tile_size == map_entry.map.range.size()[dim_idx]:
                continue

            stripmine = StripMining(sdfg_id, self.state_id, stripmine_subgraph,
                                    self.expr_index)

            # Special case: Tile size of 1 should be omitted from inner map
            unit_tile = tile_size == 1 and tile_stride == 1

            stripmine.dim_idx = dim_idx
            stripmine.new_dim_prefix = '' if unit_tile else self.prefix
            stripmine.tile_size = str(tile_size)
            stripmine.tile_stride = str(tile_stride)
            if unit_tile:
                stripmine.divides_evenly = True
            else:
                stripmine.divides_evenly = self.divides_evenly
            stripmine.apply(sdfg)
            if unit_tile:
                dropped_dims += 1

            # apply to the new map the schedule of the original one
            map_entry.schedule = schedule_to_keep

            # Merge the freshly created tile map into the previous one
            if previous_tile_entry:
                fresh_tile_entry = state.in_edges(map_entry)[0].src
                collapse_subgraph = {
                    MapCollapse._outer_map_entry:
                    state.node_id(previous_tile_entry),
                    MapCollapse._inner_map_entry:
                    state.node_id(fresh_tile_entry)
                }
                collapse = MapCollapse(sdfg_id, self.state_id,
                                       collapse_subgraph, 0)
                collapse.apply(sdfg)

            # Re-read after the collapse: this is the current tile map entry
            previous_tile_entry = state.in_edges(map_entry)[0].src
Beispiel #7
0
class OrthogonalTiling(pattern_matching.Transformation):
    """ Implements the orthogonal tiling transformation.

        Orthogonal tiling is a type of nested map fission that creates tiles
        in every dimension of the matched Map: a new outer map that iterates
        over the tiles is wrapped around the matched map, whose ranges are
        rewritten to cover a single tile.
    """

    # Pattern node: the entry node of the map to tile.
    _map_entry = nodes.MapEntry(nodes.Map("", [], []))

    # Properties
    prefix = Property(
        dtype=str, default="tile", desc="Prefix for new iterators")
    tile_sizes = ShapeProperty(
        dtype=tuple, default=(128, 128, 128), desc="Tile size per dimension")
    divides_evenly = Property(
        dtype=bool,
        default=False,
        desc="Tile size divides dimension length evenly")

    @staticmethod
    def annotates_memlets():
        return False

    @staticmethod
    def expressions():
        """ Returns the single pattern graph: one map entry node. """
        return [nxutil.node_path_graph(OrthogonalTiling._map_entry)]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        return True

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns '<map label>: <map parameters>' for the matched map. """
        map_entry = graph.nodes()[candidate[OrthogonalTiling._map_entry]]
        return map_entry.map.label + ': ' + str(map_entry.map.params)

    def apply(self, sdfg):
        """ Tiles the matched map and returns the newly created outer map. """
        graph = sdfg.nodes()[self.state_id]
        # Tile map.
        target_dim, new_dim, new_map = self.__stripmine(
            sdfg, graph, self.subgraph)
        return new_map

    def __stripmine(self, sdfg, graph, candidate):
        """ Strip-mines every dimension of the matched map.

            For each dimension a tile iterator ``<prefix>_<param>`` is added
            to a new outer map, the inner range is limited to one tile, and
            the memlets inside the map scope are shifted by the tile offset.
            Finally the outer map entry/exit nodes are wired around the
            original map.

            :return: Tuple of (last processed parameter as a symbolic
                     expression, last new tile iterator name, new outer map).
        """
        # Retrieve map entry and exit nodes.
        map_entry = graph.nodes()[candidate[OrthogonalTiling._map_entry]]
        map_exit = graph.exit_nodes(map_entry)[0]

        # Map subgraph
        map_subgraph = graph.scope_subgraph(map_entry)

        # Retrieve transformation properties.
        prefix = self.prefix
        tile_sizes = self.tile_sizes
        divides_evenly = self.divides_evenly

        # Parameters and ranges of the new outer (tile) map
        new_param = []
        new_range = []

        for dim_idx in range(len(map_entry.map.params)):

            # Dimensions beyond the given tile sizes reuse the last size
            if dim_idx >= len(tile_sizes):
                tile_size = tile_sizes[-1]
            else:
                tile_size = tile_sizes[dim_idx]

            # Retrieve parameter and range of dimension to be strip-mined.
            target_dim = map_entry.map.params[dim_idx]
            td_from, td_to, td_step = map_entry.map.range[dim_idx]

            new_dim = prefix + '_' + target_dim

            # Basic values: number of tiles along this dimension
            if divides_evenly:
                tile_num = '(%s + 1 - %s) / %s' % (symbolic.symstr(td_to),
                                                   symbolic.symstr(td_from),
                                                   str(tile_size))
            else:
                tile_num = 'int_ceil((%s + 1 - %s), %s)' % (symbolic.symstr(
                    td_to), symbolic.symstr(td_from), str(tile_size))

            # Outer map values (over all tiles)
            nd_from = 0
            nd_to = symbolic.pystr_to_symbolic(str(tile_num) + ' - 1')
            nd_step = 1

            # Inner map values (over one tile)
            td_from_new = dace.symbolic.pystr_to_symbolic(td_from)
            td_to_new_exact = symbolic.pystr_to_symbolic(
                'min(%s + 1 - %s * %s, %s + %s) - 1' %
                (symbolic.symstr(td_to), str(new_dim), str(tile_size),
                 td_from_new, str(tile_size)))
            td_to_new_approx = symbolic.pystr_to_symbolic(
                '%s + %s - 1' % (td_from_new, str(tile_size)))

            # Outer map (over all tiles)
            new_dim_range = (nd_from, nd_to, nd_step)
            new_param.append(new_dim)
            new_range.append(new_dim_range)

            # Inner map (over one tile): the exact bound is only needed
            # when the last tile may be partial.
            if divides_evenly:
                td_to_new = td_to_new_approx
            else:
                td_to_new = dace.symbolic.SymExpr(td_to_new_exact,
                                                  td_to_new_approx)
            map_entry.map.range[dim_idx] = (td_from_new, td_to_new, td_step)

            # Fix subgraph memlets: substitute the parameter by
            # (parameter + tile iterator * tile size).
            target_dim = dace.symbolic.pystr_to_symbolic(target_dim)
            offset = dace.symbolic.pystr_to_symbolic(
                '%s * %s' % (new_dim, str(tile_size)))
            for _, _, _, _, memlet in map_subgraph.edges():
                old_subset = memlet.subset
                if isinstance(old_subset, dace.subsets.Indices):
                    new_indices = []
                    for idx in old_subset:
                        new_idx = idx.subs(target_dim, target_dim + offset)
                        new_indices.append(new_idx)
                    memlet.subset = dace.subsets.Indices(new_indices)
                elif isinstance(old_subset, dace.subsets.Range):
                    new_ranges = []
                    for i, old_range in enumerate(old_subset):
                        if len(old_range) == 3:
                            b, e, s, = old_range
                            t = old_subset.tile_sizes[i]
                        else:
                            raise ValueError(
                                'Range %s is invalid.' % old_range)
                        new_b = b.subs(target_dim, target_dim + offset)
                        new_e = e.subs(target_dim, target_dim + offset)
                        new_s = s.subs(target_dim, target_dim + offset)
                        new_t = t.subs(target_dim, target_dim + offset)
                        new_ranges.append((new_b, new_e, new_s, new_t))
                    memlet.subset = dace.subsets.Range(new_ranges)
                else:
                    raise NotImplementedError

        new_map = nodes.Map(prefix + '_' + map_entry.map.label, new_param,
                            subsets.Range(new_range))
        new_map_entry = nodes.MapEntry(new_map)
        new_exit = nodes.MapExit(new_map)

        # Make internal map's schedule to "not parallel"
        map_entry.map._schedule = dtypes.ScheduleType.Default

        # Redirect/create edges.
        # One new edge per non-scalar array between the new and the old map
        # entry; each dict value is
        # (src, src_conn, dst, dst_conn, memlet, connector number).
        new_in_edges = {}
        for _src, conn, _dest, _, memlet in graph.out_edges(map_entry):
            if not isinstance(sdfg.arrays[memlet.data], dace.data.Scalar):
                new_subset = copy.deepcopy(memlet.subset)
                if memlet.data in new_in_edges:
                    # Array seen before: merge the subsets into one memlet
                    src, src_conn, dest, dest_conn, new_memlet, num = \
                        new_in_edges[memlet.data]
                    new_memlet.subset = calc_set_union(
                        new_memlet.data, sdfg.arrays[new_memlet.data],
                        new_memlet.subset, new_subset)
                    new_memlet.num_accesses = new_memlet.num_elements()
                    new_in_edges.update({
                        memlet.data: (src, src_conn, dest, dest_conn,
                                      new_memlet, min(num, int(conn[4:])))
                    })
                else:
                    new_memlet = dcpy(memlet)
                    new_memlet.subset = new_subset
                    new_memlet.num_accesses = new_memlet.num_elements()
                    # conn[4:] strips the four-character connector prefix
                    new_in_edges.update({
                        memlet.data: (new_map_entry, None, map_entry, None,
                                      new_memlet, int(conn[4:]))
                    })
        nxutil.change_edge_dest(graph, map_entry, new_map_entry)

        # Same bookkeeping for the edges between the old and new map exit.
        new_out_edges = {}
        for _src, conn, _dest, _, memlet in graph.in_edges(map_exit):
            if not isinstance(sdfg.arrays[memlet.data], dace.data.Scalar):
                new_subset = memlet.subset
                if memlet.data in new_out_edges:
                    src, src_conn, dest, dest_conn, new_memlet, num = \
                        new_out_edges[memlet.data]
                    new_memlet.subset = calc_set_union(
                        new_memlet.data, sdfg.arrays[new_memlet.data],
                        new_memlet.subset, new_subset)
                    new_memlet.num_accesses = new_memlet.num_elements()
                    # NOTE(review): unlike the entry side, the connector
                    # suffix is kept as a string here, so min() compares
                    # lexicographically -- confirm this is intended.
                    new_out_edges.update({
                        memlet.data: (src, src_conn, dest, dest_conn,
                                      new_memlet, min(num, conn[4:]))
                    })
                else:
                    new_memlet = dcpy(memlet)
                    new_memlet.subset = new_subset
                    new_memlet.num_accesses = new_memlet.num_elements()
                    new_out_edges.update({
                        memlet.data: (map_exit, None, new_exit, None,
                                      new_memlet, conn[4:])
                    })
        nxutil.change_edge_src(graph, map_exit, new_exit)

        # Connector related work follows
        # 1. Dictionary 'old_connector_number': 'new_connector_number'
        # 2. New node in/out connectors
        # 3. New edges

        in_conn_nums = []
        for _, e in new_in_edges.items():
            _, _, _, _, _, num = e
            in_conn_nums.append(num)
        in_conn = {}
        for i, num in enumerate(in_conn_nums):
            in_conn.update({num: i + 1})

        entry_in_connectors = set()
        entry_out_connectors = set()
        for i in range(len(in_conn_nums)):
            entry_in_connectors.add('IN_' + str(i + 1))
            entry_out_connectors.add('OUT_' + str(i + 1))
        new_map_entry.in_connectors = entry_in_connectors
        new_map_entry.out_connectors = entry_out_connectors

        for _, e in new_in_edges.items():
            src, _, dst, _, memlet, num = e
            graph.add_edge(src, 'OUT_' + str(in_conn[num]), dst,
                           'IN_' + str(in_conn[num]), memlet)

        out_conn_nums = []
        for _, e in new_out_edges.items():
            _, _, dst, _, _, num = e
            if dst is not new_exit:
                continue
            out_conn_nums.append(num)
        out_conn = {}
        for i, num in enumerate(out_conn_nums):
            out_conn.update({num: i + 1})

        exit_in_connectors = set()
        exit_out_connectors = set()
        for i in range(len(out_conn_nums)):
            exit_in_connectors.add('IN_' + str(i + 1))
            exit_out_connectors.add('OUT_' + str(i + 1))
        new_exit.in_connectors = exit_in_connectors
        new_exit.out_connectors = exit_out_connectors

        for _, e in new_out_edges.items():
            src, _, dst, _, memlet, num = e
            graph.add_edge(src, 'OUT_' + str(out_conn[num]), dst,
                           'IN_' + str(out_conn[num]), memlet)

        # Return strip-mined dimension.
        return target_dim, new_dim, new_map

    @staticmethod
    def __modify_edges(sdfg, graph, candidate, target_dim, new_dim):
        """ Substitutes ``target_dim`` by ``new_dim + target_dim`` in the
            subsets of the memlets visited when traversing the scope of the
            matched map. Stream memlets are skipped; memlets with
            write-conflict resolution only get ``num_accesses`` set to 1.
        """
        map_entry = graph.nodes()[candidate[OrthogonalTiling._map_entry]]

        processed = []
        for src, _dest, memlet, _scope in nxutil.traverse_sdfg_scope(
                graph, map_entry, True):
            # Each memlet is rewritten at most once
            if memlet in processed:
                continue
            processed.append(memlet)

            # Corner cases
            if isinstance(sdfg.arrays[memlet.data], dace.data.Stream):
                continue
            if memlet.wcr is not None:
                memlet.num_accesses = 1
                continue

            for i, dim in enumerate(memlet.subset):
                if isinstance(dim, tuple):
                    dim = tuple(
                        symbolic.pystr_to_symbolic(d).subs(
                            symbolic.pystr_to_symbolic(target_dim),
                            symbolic.pystr_to_symbolic(
                                '%s + %s' % (str(new_dim), str(target_dim))))
                        for d in dim)
                else:
                    dim = symbolic.pystr_to_symbolic(dim).subs(
                        symbolic.pystr_to_symbolic(target_dim),
                        symbolic.pystr_to_symbolic(
                            '%s + %s' % (str(new_dim), str(target_dim))))

                memlet.subset[i] = dim
        return
Beispiel #8
0
class StripMining(pattern_matching.Transformation):
    """ Implements the strip-mining transformation.

        Strip-mining takes as input a map dimension and splits it into
        two dimensions. The new dimension iterates over the range of
        the original one with a parameterizable step, called the tile
        size. The original dimension is changed to iterate over the
        range of the tile size, with the same step as before.
    """

    _map_entry = nodes.MapEntry(nodes.Map("", [], []))

    # Properties
    dim_idx = Property(dtype=int,
                       default=-1,
                       desc="Index of dimension to be strip-mined")
    new_dim_prefix = Property(dtype=str,
                              default="tile",
                              desc="Prefix for new dimension name")
    tile_size = Property(dtype=str,
                         default="64",
                         desc="Tile size of strip-mined dimension")
    tile_stride = Property(dtype=str,
                           default="",
                           desc="Stride between two tiles of the "
                           "strip-mined dimension")
    divides_evenly = Property(dtype=bool,
                              default=False,
                              desc="Tile size divides dimension range evenly?")
    strided = Property(
        dtype=bool,
        default=False,
        desc="Continuous (false) or strided (true) elements in tile")

    @staticmethod
    def annotates_memlets():
        """ This transformation sets the memlets of the edges it creates. """
        return True

    @staticmethod
    def expressions():
        """ The pattern is a single map entry node. """
        return [nxutil.node_path_graph(StripMining._map_entry)]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Every matched map entry is accepted. """
        return True

    @staticmethod
    def match_to_str(graph, candidate):
        """ Describes a match as '<map label>: <map parameters>'. """
        map_entry = graph.nodes()[candidate[StripMining._map_entry]]
        return map_entry.map.label + ': ' + str(map_entry.map.params)

    def apply(self, sdfg):
        """ Strip-mines the selected dimension of the matched map and
            returns the newly created (outer) map.
        """
        graph = sdfg.nodes()[self.state_id]
        # Strip-mine selected dimension.
        _, _, new_map = self._stripmine(sdfg, graph, self.subgraph)
        return new_map

    def __init__(self, *args, **kwargs):
        self._entry = nodes.EntryNode()
        self._tasklet = nodes.Tasklet('_')
        self._exit = nodes.ExitNode()
        super().__init__(*args, **kwargs)

    @property
    def entry(self):
        return self._entry

    @property
    def exit(self):
        return self._exit

    @property
    def tasklet(self):
        return self._tasklet

    def print_match_pattern(self, candidate):
        """ Returns the last parameter of the matched entry's map as a
            string.
        """
        gentry = candidate[self.entry]
        return str(gentry.map.params[-1])

    def modifies_graph(self):
        return True

    def _find_new_dim(self, sdfg: SDFG, state: SDFGState,
                      entry: nodes.MapEntry, prefix: str, target_dim: str):
        """ Finds a variable that is not already defined in scope.

            The first candidate is '<prefix>_<target_dim>'; on a clash the
            candidates '<prefix>1_<target_dim>', '<prefix>2_<target_dim>',
            ... are tried in order.
        """
        stree = state.scope_tree()
        # Names already taken in the scope of `entry`; computed once
        # instead of on every loop iteration.
        taken = set(map(str, stree[entry].defined_vars))
        candidate = '%s_%s' % (prefix, target_dim)
        index = 1
        while candidate in taken:
            candidate = '%s%d_%s' % (prefix, index, target_dim)
            index += 1
        return candidate

    def _stripmine(self, sdfg, graph, candidate):
        """ Strip-mines dimension `dim_idx` of the matched map.

            A new one-dimensional map is created and placed around the
            matched map: edges that entered the matched map entry (resp.
            left its exit) are redirected to the new entry (resp. exit),
            and new edges are created between the new and the matched
            scope nodes. The range of the selected dimension of the matched
            map is rewritten to cover a single tile, and the matched map is
            scheduled sequentially while the new map inherits its former
            schedule.

            Arguments:
                sdfg: SDFG whose `arrays` describe the data referenced by
                      the memlets.
                graph: State that contains the matched map.
                candidate: Mapping from the pattern node to the index of
                           the matched node in `graph`.

            Returns:
                The tuple (target_dim, new_dim, new_map): the strip-mined
                parameter, the name of the new parameter, and the new map.
        """

        # Retrieve map entry and exit nodes.
        map_entry = graph.nodes()[candidate[StripMining._map_entry]]
        map_exit = graph.exit_nodes(map_entry)[0]

        # Retrieve transformation properties.
        dim_idx = self.dim_idx
        new_dim_prefix = self.new_dim_prefix
        tile_size = self.tile_size
        divides_evenly = self.divides_evenly
        strided = self.strided

        # An unset stride defaults to the tile size.
        tile_stride = self.tile_stride
        if not tile_stride:
            tile_stride = tile_size

        # Retrieve parameter and range of dimension to be strip-mined.
        target_dim = map_entry.map.params[dim_idx]
        td_from, td_to, td_step = map_entry.map.range[dim_idx]

        # Create new map. Replace by cloning???
        new_dim = self._find_new_dim(sdfg, graph, map_entry, new_dim_prefix,
                                     target_dim)
        nd_from = 0
        nd_to = symbolic.pystr_to_symbolic(
            'int_ceil(%s + 1 - %s, %s) - 1' %
            (symbolic.symstr(td_to), symbolic.symstr(td_from), tile_stride))
        nd_step = 1
        new_dim_range = (nd_from, nd_to, nd_step)
        new_map = nodes.Map(new_dim + '_' + map_entry.map.label, [new_dim],
                            subsets.Range([new_dim_range]))
        new_map_entry = nodes.MapEntry(new_map)
        new_map_exit = nodes.MapExit(new_map)

        # Change the range of the selected dimension to iterate over a single
        # tile
        if strided:
            td_from_new = symbolic.pystr_to_symbolic(new_dim)
            td_to_new_approx = td_to
            td_step = symbolic.pystr_to_symbolic(tile_size)
        else:
            td_from_new = symbolic.pystr_to_symbolic(
                '%s + %s * %s' %
                (symbolic.symstr(td_from), str(new_dim), tile_stride))
            td_to_new_exact = symbolic.pystr_to_symbolic(
                'min(%s + 1, %s + %s * %s + %s) - 1' %
                (symbolic.symstr(td_to), symbolic.symstr(td_from), tile_stride,
                 str(new_dim), tile_size))
            td_to_new_approx = symbolic.pystr_to_symbolic(
                '%s + %s * %s + %s - 1' %
                (symbolic.symstr(td_from), tile_stride, str(new_dim),
                 tile_size))
        if divides_evenly or strided:
            # No clamping against the original upper bound is used here
            td_to_new = td_to_new_approx
        else:
            # Keep both the exact (clamped) and the approximate bound
            td_to_new = dace.symbolic.SymExpr(td_to_new_exact,
                                              td_to_new_approx)
        map_entry.map.range[dim_idx] = (td_from_new, td_to_new, td_step)

        # Make internal map's schedule to "not parallel"
        new_map.schedule = map_entry.map.schedule
        map_entry.map.schedule = dtypes.ScheduleType.Sequential

        # Redirect edges
        new_map_entry.in_connectors = dcpy(map_entry.in_connectors)
        nxutil.change_edge_dest(graph, map_entry, new_map_entry)
        new_map_exit.out_connectors = dcpy(map_exit.out_connectors)
        nxutil.change_edge_src(graph, map_exit, new_map_exit)

        # Create new entry edges
        new_in_edges = dict()
        entry_in_conn = set()
        entry_out_conn = set()
        for _src, src_conn, _dst, _, memlet in graph.out_edges(map_entry):
            if (src_conn is not None
                    and src_conn[:4] == 'OUT_' and not isinstance(
                        sdfg.arrays[memlet.data], dace.data.Scalar)):
                new_subset = calc_set_image(
                    map_entry.map.params,
                    map_entry.map.range,
                    memlet.subset,
                )
                conn = src_conn[4:]
                key = (memlet.data, 'IN_' + conn, 'OUT_' + conn)
                if key in new_in_edges:
                    # Same data through the same connector: merge subsets
                    old_subset = new_in_edges[key].subset
                    new_in_edges[key].subset = calc_set_union(
                        old_subset, new_subset)
                else:
                    entry_in_conn.add('IN_' + conn)
                    entry_out_conn.add('OUT_' + conn)
                    new_memlet = dcpy(memlet)
                    new_memlet.subset = new_subset
                    new_memlet.num_accesses = new_memlet.num_elements()
                    new_in_edges[key] = new_memlet
            else:
                # Scalars and edges without an 'OUT_' connector are passed
                # through with an unchanged copy of their memlet
                if src_conn is not None and src_conn[:4] == 'OUT_':
                    conn = src_conn[4:]
                    in_conn = 'IN_' + conn
                    out_conn = 'OUT_' + conn
                else:
                    in_conn = src_conn
                    out_conn = src_conn
                if in_conn:
                    entry_in_conn.add(in_conn)
                if out_conn:
                    entry_out_conn.add(out_conn)
                new_in_edges[(memlet.data, in_conn, out_conn)] = dcpy(memlet)
        new_map_entry.out_connectors = entry_out_conn
        map_entry.in_connectors = entry_in_conn
        for (_, in_conn, out_conn), memlet in new_in_edges.items():
            graph.add_edge(new_map_entry, out_conn, map_entry, in_conn, memlet)

        # Create new exit edges
        new_out_edges = dict()
        exit_in_conn = set()
        exit_out_conn = set()
        for _src, _, _dst, dst_conn, memlet in graph.in_edges(map_exit):
            if (dst_conn is not None
                    and dst_conn[:3] == 'IN_' and not isinstance(
                        sdfg.arrays[memlet.data], dace.data.Scalar)):
                new_subset = calc_set_image(
                    map_entry.map.params,
                    map_entry.map.range,
                    memlet.subset,
                )
                conn = dst_conn[3:]
                key = (memlet.data, 'IN_' + conn, 'OUT_' + conn)
                if key in new_out_edges:
                    # Same data through the same connector: merge subsets
                    old_subset = new_out_edges[key].subset
                    new_out_edges[key].subset = calc_set_union(
                        old_subset, new_subset)
                else:
                    exit_in_conn.add('IN_' + conn)
                    exit_out_conn.add('OUT_' + conn)
                    new_memlet = dcpy(memlet)
                    new_memlet.subset = new_subset
                    new_memlet.num_accesses = new_memlet.num_elements()
                    new_out_edges[key] = new_memlet
            else:
                if dst_conn is not None and dst_conn[:3] == 'IN_':
                    conn = dst_conn[3:]
                    in_conn = 'IN_' + conn
                    out_conn = 'OUT_' + conn
                else:
                    # Use this edge's own destination connector. Reading
                    # the `src_conn` left over from the entry loop would
                    # pick up an unrelated (or unbound) value.
                    in_conn = dst_conn
                    out_conn = dst_conn
                if in_conn:
                    exit_in_conn.add(in_conn)
                if out_conn:
                    exit_out_conn.add(out_conn)
                # Must be recorded in new_out_edges: new_in_edges has
                # already been turned into graph edges above, so an entry
                # added there would never be created.
                new_out_edges[(memlet.data, in_conn, out_conn)] = dcpy(memlet)
        new_map_exit.in_connectors = exit_in_conn
        map_exit.out_connectors = exit_out_conn
        for (_, in_conn, out_conn), memlet in new_out_edges.items():
            graph.add_edge(map_exit, out_conn, new_map_exit, in_conn, memlet)

        # Return strip-mined dimension.
        return target_dim, new_dim, new_map
Beispiel #9
0
class MapToForLoop(pattern_matching.Transformation):
    """ Implements the Map to for-loop transformation.

        Takes a map and enforces a sequential schedule by transforming it into
        a state-machine of a for-loop. The loop is built inside a nested SDFG
        that replaces the map scope in the original state.
    """

    # Pattern node: a single map entry
    _map_entry = nodes.MapEntry(nodes.Map("", [], []))

    @staticmethod
    def annotates_memlets():
        """ This transformation sets the memlets of the edges it creates. """
        return True

    @staticmethod
    def expressions():
        """ The pattern is a single map entry node. """
        return [nxutil.node_path_graph(MapToForLoop._map_entry)]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Accepts the match unless the map has more than one parameter. """
        # Only uni-dimensional maps are accepted.
        map_entry = graph.nodes()[candidate[MapToForLoop._map_entry]]
        if len(map_entry.map.params) > 1:
            return False

        return True

    @staticmethod
    def match_to_str(graph, candidate):
        """ Describes a match as '<map label>: <map parameters>'. """
        map_entry = graph.nodes()[candidate[MapToForLoop._map_entry]]
        return map_entry.map.label + ': ' + str(map_entry.map.params)

    def apply(self, sdfg):
        """ Replaces the matched map scope with a nested SDFG that contains
            a begin/guard/body/end loop over the map's first dimension.

            The contents of the map scope are moved into the `body` state,
            the data entering and leaving the scope is exposed through new
            '_in_<name>' / '_out_<name>' data descriptors, and the original
            scope nodes are removed from the state.

            Raises NotImplementedError if data crossing the scope boundary
            is neither an Array nor a Scalar.
        """
        # Retrieve map entry and exit nodes.
        graph = sdfg.nodes()[self.state_id]
        map_entry = graph.nodes()[self.subgraph[MapToForLoop._map_entry]]
        map_exits = graph.exit_nodes(map_entry)
        loop_idx = map_entry.map.params[0]
        loop_from, loop_to, loop_step = map_entry.map.range[0]

        # The nested SDFG is named '<state label>_<map label>'
        nested_sdfg = dace.SDFG(graph.label + '_' + map_entry.map.label)

        # Construct nested SDFG
        begin = nested_sdfg.add_state('begin')
        guard = nested_sdfg.add_state('guard')
        body = nested_sdfg.add_state('body')
        end = nested_sdfg.add_state('end')

        # begin -> guard: initialize the loop index with the range start
        nested_sdfg.add_edge(
            begin, guard,
            edges.InterstateEdge(assignments={str(loop_idx): str(loop_from)}))
        # guard -> body: taken while index <= range end (inclusive bound)
        nested_sdfg.add_edge(
            guard,
            body,
            edges.InterstateEdge(condition = str(loop_idx) + ' <= ' + \
                                             str(loop_to))
        )
        # guard -> end: taken once index > range end
        nested_sdfg.add_edge(
            guard,
            end,
            edges.InterstateEdge(condition = str(loop_idx) + ' > ' + \
                                             str(loop_to))
        )
        # body -> guard: advance the index by the range step
        nested_sdfg.add_edge(
            body,
            guard,
            edges.InterstateEdge(assignments = {str(loop_idx): str(loop_idx) + \
                                                ' + ' +str(loop_step)})
        )

        # Add map contents
        # (every scope node except the entry/exits, and every edge that
        # neither starts at the entry nor ends at an exit)
        map_subgraph = graph.scope_subgraph(map_entry)
        for node in map_subgraph.nodes():
            if node is not map_entry and node not in map_exits:
                body.add_node(node)
        for src, src_conn, dst, dst_conn, memlet in map_subgraph.edges():
            if src is not map_entry and dst not in map_exits:
                body.add_edge(src, src_conn, dst, dst_conn, memlet)

        # Reconnect inputs
        # All three dictionaries are keyed by the position of the edge in
        # graph.in_edges(map_entry).
        nested_in_data_nodes = {}
        nested_in_connectors = {}
        nested_in_memlets = {}
        for i, edge in enumerate(graph.in_edges(map_entry)):
            src, src_conn, dst, dst_conn, memlet = edge
            data_label = '_in_' + memlet.data
            memdata = sdfg.arrays[memlet.data]
            # Register a descriptor for the incoming data. Note that it is
            # added to `sdfg` (the SDFG passed to apply), and that the
            # returned `data_array` is not used afterwards.
            if isinstance(memdata, data.Array):
                data_array = sdfg.add_array(data_label, memdata.dtype, [
                    symbolic.overapproximate(r)
                    for r in memlet.bounding_box_size()
                ])
            elif isinstance(memdata, data.Scalar):
                data_array = sdfg.add_scalar(data_label, memdata.dtype)
            else:
                raise NotImplementedError()
            data_node = nodes.AccessNode(data_label)
            body.add_node(data_node)
            nested_in_data_nodes.update({i: data_node})
            nested_in_connectors.update({i: data_label})
            nested_in_memlets.update({i: memlet})
            # Memlets already inside the body that refer to the original
            # data are retargeted to the new label
            for _, _, _, _, old_memlet in body.edges():
                if old_memlet.data == memlet.data:
                    old_memlet.data = data_label
            #body.add_edge(data_node, None, dst, dst_conn, memlet)

        # Reconnect outputs
        # NOTE(review): the key `i` restarts at 0 for every map exit, so
        # with more than one exit later entries overwrite earlier ones --
        # confirm that a single exit is assumed here.
        nested_out_data_nodes = {}
        nested_out_connectors = {}
        nested_out_memlets = {}
        for map_exit in map_exits:
            for i, edge in enumerate(graph.out_edges(map_exit)):
                src, src_conn, dst, dst_conn, memlet = edge
                data_label = '_out_' + memlet.data
                memdata = sdfg.arrays[memlet.data]
                if isinstance(memdata, data.Array):
                    data_array = sdfg.add_array(data_label, memdata.dtype, [
                        symbolic.overapproximate(r)
                        for r in memlet.bounding_box_size()
                    ])
                elif isinstance(memdata, data.Scalar):
                    data_array = sdfg.add_scalar(data_label, memdata.dtype)
                else:
                    raise NotImplementedError()
                data_node = nodes.AccessNode(data_label)
                body.add_node(data_node)
                nested_out_data_nodes.update({i: data_node})
                nested_out_connectors.update({i: data_label})
                nested_out_memlets.update({i: memlet})
                for _, _, _, _, old_memlet in body.edges():
                    if old_memlet.data == memlet.data:
                        old_memlet.data = data_label
                #body.add_edge(src, src_conn, data_node, None, memlet)

        # Add nested SDFG and reconnect it
        # (the connector names are the '_in_*' / '_out_*' labels above)
        nested_node = graph.add_nested_sdfg(
            nested_sdfg, sdfg, set(nested_in_connectors.values()),
            set(nested_out_connectors.values()))

        # Outer edges: former inputs of the map entry now feed the nested
        # SDFG node, with their original memlets
        for i, edge in enumerate(graph.in_edges(map_entry)):
            src, src_conn, dst, dst_conn, memlet = edge
            graph.add_edge(src, src_conn, nested_node, nested_in_connectors[i],
                           nested_in_memlets[i])

        # Outer edges: the nested SDFG node now feeds the former
        # destinations of the map exit(s)
        for map_exit in map_exits:
            for i, edge in enumerate(graph.out_edges(map_exit)):
                src, src_conn, dst, dst_conn, memlet = edge
                graph.add_edge(nested_node, nested_out_connectors[i], dst,
                               dst_conn, nested_out_memlets[i])

        # Inner edges: connect the new input access nodes to the nodes that
        # were fed by the map entry.
        # NOTE(review): assumes every connector is named 'OUT_<k>' and that
        # k - 1 equals the enumeration index used above -- confirm.
        for src, src_conn, dst, dst_conn, memlet in graph.out_edges(map_entry):
            i = int(src_conn[4:]) - 1
            new_memlet = dcpy(memlet)
            new_memlet.data = nested_in_data_nodes[i].data
            body.add_edge(nested_in_data_nodes[i], None, dst, dst_conn,
                          new_memlet)

        # Inner edges: connect the nodes that fed the map exit(s) to the new
        # output access nodes (same 'IN_<k>' naming assumption as above)
        for map_exit in map_exits:
            for src, src_conn, dst, dst_conn, memlet in graph.in_edges(
                    map_exit):
                i = int(dst_conn[3:]) - 1
                new_memlet = dcpy(memlet)
                new_memlet.data = nested_out_data_nodes[i].data
                body.add_edge(src, src_conn, nested_out_data_nodes[i], None,
                              new_memlet)

        # Finally, remove the whole original map scope from the state
        for node in map_subgraph:
            graph.remove_node(node)
Beispiel #10
0
    def apply(self, sdfg: dace.SDFG):
        """ Replaces the matched multi-dimensional map with a chain of
            nested one-dimensional maps, one per parameter of the original
            map, and rewires all edges through the new entries and exits.
        """
        state = sdfg.nodes()[self.state_id]
        old_entry = state.nodes()[self.subgraph[MapExpansion._map_entry]]
        old_exit = state.exit_nodes(old_entry)[0]
        old_map = old_entry.map

        # One sequential single-parameter map per original dimension; the
        # first (outermost) one is reset to the default schedule.
        new_maps = []
        for param, param_range in zip(old_map.params, old_map.range):
            new_maps.append(
                nodes.Map(old_map.label + '_' + str(param), [param],
                          subsets.Range([param_range]),
                          schedule=dtypes.ScheduleType.Sequential))
        new_maps[0]._schedule = dtypes.ScheduleType.Default

        # Entries: the outermost reuses the original connectors, the others
        # get numbered IN_/OUT_ connectors (one per original out-edge).
        entries = [nodes.MapEntry(m) for m in new_maps]
        entries[0].in_connectors = old_entry.in_connectors
        entries[0].out_connectors = old_entry.out_connectors
        edge_count = len(state.out_edges(old_entry))
        for inner_entry in entries[1:]:
            inner_entry.in_connectors = {
                'IN_' + str(k + 1)
                for k in range(edge_count)
            }
            inner_entry.out_connectors = {
                'OUT_' + str(k + 1)
                for k in range(edge_count)
            }

        # Exits are stored innermost-first, so the last one is outermost
        # and reuses the original exit's connectors.
        exits = [nodes.MapExit(m) for m in new_maps][::-1]
        exits[-1].in_connectors = old_exit.in_connectors
        exits[-1].out_connectors = old_exit.out_connectors
        edge_count = len(state.out_edges(old_exit))
        for inner_exit in exits[:-1]:
            inner_exit.in_connectors = {
                'IN_' + str(k + 1)
                for k in range(edge_count)
            }
            inner_exit.out_connectors = {
                'OUT_' + str(k + 1)
                for k in range(edge_count)
            }

        # Register the new scope nodes
        state.add_nodes_from(entries)
        state.add_nodes_from(exits)

        # Outer edges now end at the first entry / start at the last exit
        dace.graph.nxutil.change_edge_dest(state, old_entry, entries[0])
        dace.graph.nxutil.change_edge_src(state, old_exit, exits[-1])

        # Thread every former out-edge of the entry through all entries
        for idx, edge in enumerate(state.out_edges(old_entry)):
            in_name = 'IN_' + str(idx + 1)
            out_name = 'OUT_' + str(idx + 1)
            state.remove_edge(edge)
            state.add_edge(entries[0], edge.src_conn, entries[1], in_name,
                           copy.deepcopy(edge.data))
            state.add_edge(entries[-1], out_name, edge.dst, edge.dst_conn,
                           copy.deepcopy(edge.data))
            for upper, lower in zip(entries[1:-1], entries[2:]):
                state.add_edge(upper, out_name, lower, in_name,
                               copy.deepcopy(edge.data))

        # Thread every former in-edge of the exit through all exits
        for idx, edge in enumerate(state.in_edges(old_exit)):
            in_name = 'IN_' + str(idx + 1)
            out_name = 'OUT_' + str(idx + 1)
            state.remove_edge(edge)
            state.add_edge(edge.src, edge.src_conn, exits[0], in_name,
                           copy.deepcopy(edge.data))
            state.add_edge(exits[-2], out_name, exits[-1], edge.dst_conn,
                           copy.deepcopy(edge.data))
            for lower, upper in zip(exits[:-2], exits[1:-1]):
                state.add_edge(lower, out_name, upper, in_name,
                               copy.deepcopy(edge.data))

        # The original scope nodes are no longer needed
        state.remove_node(old_entry)
        state.remove_node(old_exit)
Beispiel #11
0
class MapExpansion(pm.Transformation):
    """ Implements the map-expansion pattern.

        Map-expansion turns one map with N parameters into N nested maps
        with a single parameter each.
    """

    _map_entry = nodes.MapEntry(nodes.Map("", [], []))

    @staticmethod
    def expressions():
        """ The pattern consists of a single map entry node. """
        pattern = nxutil.node_path_graph(MapExpansion._map_entry)
        return [pattern]

    @staticmethod
    def can_be_applied(graph: dace.graph.graph.OrderedMultiDiConnectorGraph,
                       candidate: Dict[dace.graph.nodes.Node, int],
                       expr_index: int,
                       sdfg: dace.SDFG,
                       strict: bool = False):
        """ Matches only maps that have more than one parameter. """
        matched_entry = graph.nodes()[candidate[MapExpansion._map_entry]]
        dimensions = matched_entry.map.get_param_num()
        return dimensions > 1

    @staticmethod
    def match_to_str(graph: dace.graph.graph.OrderedMultiDiConnectorGraph,
                     candidate: Dict[dace.graph.nodes.Node, int]):
        """ Describes a match as '<map label>: <map parameters>'. """
        matched_map = graph.nodes()[candidate[MapExpansion._map_entry]].map
        return matched_map.label + ': ' + str(matched_map.params)

    def apply(self, sdfg: dace.SDFG):
        """ Replaces the matched map by a nest of one-dimensional maps and
            reroutes the edges of the original scope through them.
        """

        def in_name(k):
            return 'IN_' + str(k + 1)

        def out_name(k):
            return 'OUT_' + str(k + 1)

        # Locate the matched map together with its entry and exit.
        state = sdfg.nodes()[self.state_id]
        outer_entry = state.nodes()[self.subgraph[MapExpansion._map_entry]]
        outer_exit = state.exit_nodes(outer_entry)[0]
        source_map = outer_entry.map

        # Build one sequential map per parameter; only the first one
        # gets the default schedule.
        expanded = [
            nodes.Map(source_map.label + '_' + str(name), [name],
                      subsets.Range([rng]),
                      schedule=dtypes.ScheduleType.Sequential)
            for name, rng in zip(source_map.params, source_map.range)
        ]
        expanded[0]._schedule = dtypes.ScheduleType.Default

        # Entry nodes, outermost first.
        entries = [nodes.MapEntry(m) for m in expanded]
        entries[0].in_connectors = outer_entry.in_connectors
        entries[0].out_connectors = outer_entry.out_connectors
        n_edges = len(state.out_edges(outer_entry))
        for entry in entries[1:]:
            entry.in_connectors = set(in_name(k) for k in range(n_edges))
            entry.out_connectors = set(out_name(k) for k in range(n_edges))

        # Exit nodes, innermost first (hence the reversal).
        exits = [nodes.MapExit(m) for m in expanded]
        exits.reverse()
        exits[-1].in_connectors = outer_exit.in_connectors
        exits[-1].out_connectors = outer_exit.out_connectors
        n_edges = len(state.out_edges(outer_exit))
        for exit_node in exits[:-1]:
            exit_node.in_connectors = set(in_name(k) for k in range(n_edges))
            exit_node.out_connectors = set(
                out_name(k) for k in range(n_edges))

        # Insert the new nodes into the state.
        state.add_nodes_from(entries)
        state.add_nodes_from(exits)

        # Edges from outside the scope are moved to the outermost nodes.
        dace.graph.nxutil.change_edge_dest(state, outer_entry, entries[0])
        dace.graph.nxutil.change_edge_src(state, outer_exit, exits[-1])

        # Each edge leaving the old entry becomes a path through all
        # new entries.
        for k, old_edge in enumerate(state.out_edges(outer_entry)):
            state.remove_edge(old_edge)
            state.add_edge(entries[0], old_edge.src_conn, entries[1],
                           in_name(k), copy.deepcopy(old_edge.data))
            state.add_edge(entries[-1], out_name(k), old_edge.dst,
                           old_edge.dst_conn, copy.deepcopy(old_edge.data))
            for pos in range(1, len(entries) - 1):
                state.add_edge(entries[pos], out_name(k), entries[pos + 1],
                               in_name(k), copy.deepcopy(old_edge.data))

        # Each edge entering the old exit becomes a path through all
        # new exits.
        for k, old_edge in enumerate(state.in_edges(outer_exit)):
            state.remove_edge(old_edge)
            state.add_edge(old_edge.src, old_edge.src_conn, exits[0],
                           in_name(k), copy.deepcopy(old_edge.data))
            state.add_edge(exits[-2], out_name(k), exits[-1],
                           old_edge.dst_conn, copy.deepcopy(old_edge.data))
            for pos in range(0, len(exits) - 2):
                state.add_edge(exits[pos], out_name(k), exits[pos + 1],
                               in_name(k), copy.deepcopy(old_edge.data))

        # Drop the original entry and exit.
        state.remove_node(outer_entry)
        state.remove_node(outer_exit)
Beispiel #12
0
class FPGATransformMap(pattern_matching.Transformation):
    """ Implements the FPGATransformMap transformation.

        Gives a single map the FPGA device schedule and places FPGA arrays
        around it, so that copies between the original arrays and their
        FPGA counterparts appear as edges in the state.
    """

    _map_entry = nodes.MapEntry(nodes.Map("", [], []))

    @staticmethod
    def expressions():
        """ The pattern consists of a single map entry node. """
        pattern = nxutil.node_path_graph(FPGATransformMap._map_entry)
        return [pattern]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Rejects maps with more than three dimensions, maps (or
            enclosing scopes) with a disallowed schedule, and maps whose
            scope contains access nodes on non-default storage.
        """
        entry = graph.nodes()[candidate[FPGATransformMap._map_entry]]
        the_map = entry.map

        # No more than 3 dimensions
        if the_map.range.dims() > 3:
            return False

        # Schedules that may not be transformed to FPGA
        disallowed = (dtypes.ScheduleType.MPI,
                      dtypes.ScheduleType.GPU_Device,
                      dtypes.ScheduleType.FPGA_Device,
                      dtypes.ScheduleType.GPU_ThreadBlock)
        if the_map.schedule in disallowed:
            return False

        # Walk up the scope hierarchy looking for device schedules
        parents = graph.scope_dict()
        device_schedules = (dtypes.ScheduleType.GPU_Device,
                            dtypes.ScheduleType.FPGA_Device,
                            dtypes.ScheduleType.GPU_ThreadBlock)
        scope = entry
        while scope is not None:
            if scope.map.schedule in device_schedules:
                return False
            scope = parents[scope]

        # Arrays inside the scope must live on default storage
        for node in graph.scope_subgraph(entry).nodes():
            if not isinstance(node, nodes.AccessNode):
                continue
            if node.desc(sdfg).storage != dtypes.StorageType.Default:
                return False

        return True

    @staticmethod
    def match_to_str(graph, candidate):
        """ Describes a match by the string form of the map entry. """
        entry = graph.nodes()[candidate[FPGATransformMap._map_entry]]
        return str(entry)

    def apply(self, sdfg):
        """ Schedules the matched map on the FPGA and inserts 'fpga_<name>'
            access nodes between the map and every array that is not
            already stored in an FPGA (or pinned) storage type.
        """
        state = sdfg.nodes()[self.state_id]
        entry = state.nodes()[self.subgraph[FPGATransformMap._map_entry]]
        entry.map._schedule = dtypes.ScheduleType.FPGA_Device

        # Exit nodes that belong to the map
        exit_nodes = state.exit_nodes(entry)

        # Storage types that need no FPGA clone
        device_storage = [
            dtypes.StorageType.FPGA_Global, dtypes.StorageType.FPGA_Local,
            dtypes.StorageType.CPU_Pinned
        ]

        #######################################################
        # Step 1: find the arrays that have to be cloned

        outgoing = []
        for exit_node in exit_nodes:
            outgoing.extend(state.out_edges(exit_node))

        inputs_to_clone = set()
        outputs_to_clone = set()
        for edge in state.in_edges(entry):
            source = sd.find_input_arraynode(state, edge)
            if source.desc(sdfg).storage not in device_storage:
                inputs_to_clone.add(source)
        for edge in outgoing:
            sink = sd.find_output_arraynode(state, edge)
            if sink.desc(sdfg).storage not in device_storage:
                outputs_to_clone.add(sink)

        #######################################################
        # Step 2: create an FPGA array (once per name) and an access node
        # for every array to clone

        cloned_arrays = {}

        def make_clones(array_nodes):
            # Returns {original array name: new 'fpga_' access node}
            clones = {}
            for array_node in array_nodes:
                array = array_node.desc(sdfg)
                fpga_name = 'fpga_' + array_node.data
                if (array_node.data not in cloned_arrays
                        and fpga_name not in sdfg.arrays):
                    sdfg.add_array(fpga_name,
                                   dtype=array.dtype,
                                   shape=array.shape,
                                   materialize_func=array.materialize_func,
                                   transient=True,
                                   storage=dtypes.StorageType.FPGA_Global,
                                   allow_conflicts=array.allow_conflicts,
                                   access_order=array.access_order,
                                   strides=array.strides,
                                   offset=array.offset)
                    cloned_arrays[array_node.data] = fpga_name
                clones[array_node.data] = nodes.AccessNode(fpga_name)
            return clones

        input_clones = make_clones(inputs_to_clone)
        output_clones = make_clones(outputs_to_clone)

        #######################################################
        # Step 3: splice the clones between the originals and the map
        # TODO(later): Shift indices and create only the necessary sub-arrays

        for array_name, clone in input_clones.items():
            state.add_node(clone)
            for edge in state.in_edges(entry):
                if edge.data.data != array_name:
                    continue
                state.remove_edge(edge)
                state.add_edge(edge.src, None, clone, None, edge.data)
                device_memlet = copy.copy(edge.data)
                device_memlet.data = clone.data
                state.add_edge(clone, edge.src_conn, edge.dst, edge.dst_conn,
                               device_memlet)
        for array_name, clone in output_clones.items():
            state.add_node(clone)
            for edge in outgoing:
                if edge.data.data != array_name:
                    continue
                state.remove_edge(edge)
                state.add_edge(clone, None, edge.dst, None, edge.data)
                device_memlet = copy.copy(edge.data)
                device_memlet.data = clone.data
                state.add_edge(edge.src, edge.src_conn, clone, edge.dst_conn,
                               device_memlet)

        #######################################################
        # Step 4: memlets inside the scope refer to the cloned arrays

        for edge in state.scope_subgraph(entry).edges():
            name = edge.data.data
            if name is not None and name in cloned_arrays:
                edge.data.data = cloned_arrays[name]

    def modifies_graph(self):
        return True
Beispiel #13
0
class AccumulateTransient(pattern_matching.Transformation):
    """ Implements the AccumulateTransient transformation, which adds
        transient stream and data nodes between nested maps that lead to a
        stream. The transient data nodes then act as a local accumulator.
    """

    # Pattern nodes: a tasklet followed by two consecutive map exits.
    _tasklet = nodes.Tasklet('_')
    _map_exit = nodes.MapExit(nodes.Map("", [], []))
    _outer_map_exit = nodes.MapExit(nodes.Map("", [], []))

    array = Property(
        dtype=str,
        desc="Array to create local storage for (if empty, first available)",
        default=None,
        allow_none=True)

    @staticmethod
    def expressions():
        """ Returns the single pattern: tasklet -> map exit -> outer map exit.
        """
        return [
            nxutil.node_path_graph(AccumulateTransient._tasklet,
                                   AccumulateTransient._map_exit,
                                   AccumulateTransient._outer_map_exit)
        ]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ A candidate matches when at least one edge from the tasklet to
            the inner map exit carries a write-conflict resolution (WCR).
        """
        tasklet = graph.nodes()[candidate[AccumulateTransient._tasklet]]
        map_exit = graph.nodes()[candidate[AccumulateTransient._map_exit]]

        # Check if there is an accumulation output
        return any(memlet.wcr is not None and dest == map_exit
                   for _, _, dest, _, memlet in graph.out_edges(tasklet))

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns the candidate entries for the three pattern nodes, joined
            by ' -> '.
        """
        tasklet = candidate[AccumulateTransient._tasklet]
        map_exit = candidate[AccumulateTransient._map_exit]
        outer_map_exit = candidate[AccumulateTransient._outer_map_exit]

        return ' -> '.join(
            str(node) for node in [tasklet, map_exit, outer_map_exit])

    def apply(self, sdfg):
        """ Applies LocalStorage between the inner and outer map exits and
            marks the resulting data node as zero-initialized when the
            accumulation is a summation. For any other reduction type a
            warning is emitted, since the transient is left uninitialized.
        """
        graph = sdfg.node(self.state_id)

        # Avoid import loop
        from dace.transformation.dataflow.local_storage import LocalStorage

        # Map this transformation's matched nodes onto LocalStorage's pattern
        local_storage_subgraph = {
            LocalStorage._node_a:
            self.subgraph[AccumulateTransient._map_exit],
            LocalStorage._node_b:
            self.subgraph[AccumulateTransient._outer_map_exit]
        }
        sdfg_id = sdfg.sdfg_list.index(sdfg)
        in_local_storage = LocalStorage(
            sdfg_id, self.state_id, local_storage_subgraph, self.expr_index)
        in_local_storage.array = self.array
        in_local_storage.apply(sdfg)

        # Initialize transient to zero in case of summation
        # TODO: Initialize transient in other WCR types
        memlet = graph.in_edges(in_local_storage._data_node)[0].data
        if detect_reduction_type(memlet.wcr) == dtypes.ReductionType.Sum:
            in_local_storage._data_node.setzero = True
        else:
            # The trailing space keeps the two adjacent literals from fusing
            # into "initializenewly-created".
            warnings.warn('AccumulateTransient did not properly initialize '
                          'newly-created transient!')
Beispiel #14
0
class DoubleBuffering(pattern_matching.Transformation):
    """ Implements the double buffering pattern, which pipelines reading
        and processing data by creating a second copy of the memory.
        In particular, the transformation takes a 1D map and all internal
        (directly connected) transients, adds an additional dimension of size 2,
        and turns the map into a for loop that processes and reads the data in a
        double-buffered manner. Other memlets will not be transformed.
    """

    # Pattern nodes: a map entry directly followed by an access node.
    _map_entry = nodes.MapEntry(nodes.Map('_', [], []))
    _transient = nodes.AccessNode('_')

    @staticmethod
    def expressions():
        """ Returns the single pattern: map entry -> access node. """
        return [
            nxutil.node_path_graph(DoubleBuffering._map_entry,
                                   DoubleBuffering._transient)
        ]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ A candidate matches when the map is one-dimensional, can be
            converted by MapToForLoop, and every access node directly
            connected to the map entry is a transient array. Only the
            candidate whose matched access node is the first such node is
            accepted, to avoid duplicate matches of the same map.
        """
        map_entry = graph.nodes()[candidate[DoubleBuffering._map_entry]]
        transient = graph.nodes()[candidate[DoubleBuffering._transient]]

        # Only one dimensional maps are allowed
        if len(map_entry.map.params) != 1:
            return False

        # Verify the map can be transformed to a for-loop
        if not MapToForLoop.can_be_applied(
                graph,
                {MapToForLoop._map_entry: candidate[DoubleBuffering._map_entry]},
                expr_index, sdfg, strict):
            return False

        # Verify that all directly-connected internal access nodes point to
        # transient arrays
        first = True
        for edge in graph.out_edges(map_entry):
            if isinstance(edge.dst, nodes.AccessNode):
                desc = sdfg.arrays[edge.dst.data]
                if not isinstance(desc, data.Array) or not desc.transient:
                    return False
                else:
                    # To avoid duplicate matches, only match the first transient
                    if first and edge.dst != transient:
                        return False
                    first = False

        return True

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns the string representation of the matched map entry. """
        return str(graph.node(candidate[DoubleBuffering._map_entry]))

    def apply(self, sdfg: sd.SDFG):
        """ Applies double buffering to the matched map.

            The steps, in order:
              1. Shorten the map range by one stride.
              2. Prepend a dimension of size 2 to every transient directly
                 connected to the map entry.
              3. Index that new dimension in all affected memlets with the
                 placeholder symbol ``__dace_db_param``.
              4. Convert the map to a for-loop (MapToForLoop).
              5. Move the reads into the transients to the initial state,
                 copy the main state's computation into the final state, and
                 re-add the reads to the main state so that they fetch the
                 next iteration's data.
              6. Replace the placeholder with concrete buffer indices.
        """
        graph: sd.SDFGState = sdfg.nodes()[self.state_id]
        map_entry = graph.node(self.subgraph[DoubleBuffering._map_entry])

        map_param = map_entry.map.params[0]  # Assuming one dimensional

        ##############################
        # Change condition of loop to one fewer iteration (so that the
        # final one reads from the last buffer)
        map_rstart, map_rend, map_rstride = map_entry.map.range[0]
        map_rend = symbolic.pystr_to_symbolic('(%s) - (%s)' %
                                              (map_rend, map_rstride))
        map_entry.map.range = subsets.Range([(map_rstart, map_rend,
                                              map_rstride)])

        ##############################
        # Gather transients to modify
        transients_to_modify = {
            edge.dst.data
            for edge in graph.out_edges(map_entry)
            if isinstance(edge.dst, nodes.AccessNode)
        }

        # Add dimension to transients and modify memlets
        for transient in transients_to_modify:
            desc: data.Array = sdfg.arrays[transient]
            # Using non-python syntax to ensure properties change
            # (new lists are assigned instead of mutating in place)
            desc.strides = [desc.total_size] + list(desc.strides)
            desc.shape = [2] + list(desc.shape)
            desc.offset = [0] + list(desc.offset)
            desc.total_size = desc.total_size * 2

        ##############################
        # Modify memlets to use map parameter as buffer index
        for edge in graph.scope_subgraph(map_entry).edges():
            if edge.data.data in transients_to_modify:
                edge.data.subset = self._modify_memlet(sdfg, edge.data.subset,
                                                       edge.data.data)
            else:  # Could be other_subset
                path = graph.memlet_path(edge)
                src_node = path[0].src
                dst_node = path[-1].dst

                # other_subset could be None. In that case, recreate from array
                dataname = None
                if (isinstance(src_node, nodes.AccessNode)
                        and src_node.data in transients_to_modify):
                    dataname = src_node.data
                elif (isinstance(dst_node, nodes.AccessNode)
                      and dst_node.data in transients_to_modify):
                    dataname = dst_node.data
                if dataname is not None:
                    subset = (edge.data.other_subset or
                              subsets.Range.from_array(sdfg.arrays[dataname]))
                    edge.data.other_subset = self._modify_memlet(
                        sdfg, subset, dataname)

        ##############################
        # Turn map into for loop
        map_to_for = MapToForLoop(self.sdfg_id, self.state_id, {
            MapToForLoop._map_entry:
            self.subgraph[DoubleBuffering._map_entry]
        }, self.expr_index)
        nsdfg_node, nstate = map_to_for.apply(sdfg)

        ##############################
        # Gather node copies and remove memlets
        edges_to_replace = []
        for node in nstate.source_nodes():
            for edge in nstate.out_edges(node):
                if (isinstance(edge.dst, nodes.AccessNode)
                        and edge.dst.data in transients_to_modify):
                    edges_to_replace.append(edge)
                    nstate.remove_edge(edge)
            # Drop source nodes that were only feeding the transients
            if nstate.out_degree(node) == 0:
                nstate.remove_node(node)

        ##############################
        # Add initial reads to initial nested state
        initial_state: sd.SDFGState = nsdfg_node.sdfg.start_state
        initial_state.set_label('%s_init' % map_entry.map.label)
        for edge in edges_to_replace:
            initial_state.add_node(edge.src)
            rnode = edge.src
            wnode = initial_state.add_write(edge.dst.data)
            # Deep-copied so that the replacements below leave the memlets
            # stored in edges_to_replace untouched
            initial_state.add_edge(rnode, edge.src_conn, wnode, edge.dst_conn,
                                   copy.deepcopy(edge.data))

        # All instances of the map parameter in this state become the loop start
        sd.replace(initial_state, map_param, map_rstart)
        # Initial writes go to the first buffer
        sd.replace(initial_state, '__dace_db_param', '0')

        ##############################
        # Modify main state's memlets

        # Divide by loop stride
        new_expr = symbolic.pystr_to_symbolic('(%s / %s) %% 2' %
                                              (map_param, map_rstride))
        sd.replace(nstate, '__dace_db_param', new_expr)

        ##############################
        # Add the main state's contents to the last state, modifying
        # memlets appropriately.
        final_state: sd.SDFGState = nsdfg_node.sdfg.sink_nodes()[0]
        final_state.set_label('%s_final_computation' % map_entry.map.label)
        dup_nstate = copy.deepcopy(nstate)
        final_state.add_nodes_from(dup_nstate.nodes())
        for e in dup_nstate.edges():
            final_state.add_edge(e.src, e.src_conn, e.dst, e.dst_conn, e.data)

        ##############################
        # Add reads into next buffers to main state
        for edge in edges_to_replace:
            rnode = copy.deepcopy(edge.src)
            nstate.add_node(rnode)
            wnode = nstate.add_write(edge.dst.data)
            new_memlet = copy.deepcopy(edge.data)
            # Shift the non-transient side of the copy to the next iteration
            if new_memlet.data in transients_to_modify:
                new_memlet.other_subset = self._replace_in_subset(
                    new_memlet.other_subset, map_param,
                    '(%s + %s)' % (map_param, map_rstride))
            else:
                new_memlet.subset = self._replace_in_subset(
                    new_memlet.subset, map_param,
                    '(%s + %s)' % (map_param, map_rstride))

            nstate.add_edge(rnode, edge.src_conn, wnode, edge.dst_conn,
                            new_memlet)

        nstate.set_label('%s_double_buffered' % map_entry.map.label)
        # Divide by loop stride. The placeholder was already substituted in
        # this state above, so this second substitution is expected to hit
        # only the read edges that were just re-added.
        new_expr = symbolic.pystr_to_symbolic('((%s / %s) + 1) %% 2' %
                                              (map_param, map_rstride))
        sd.replace(nstate, '__dace_db_param', new_expr)

    @staticmethod
    def _modify_memlet(sdfg, subset, data_name):
        """ Returns a new Range that indexes the buffer dimension of
            ``data_name`` with the placeholder ``__dace_db_param``.

            If ``subset`` already has as many dimensions as the array
            descriptor's shape, its first dimension is replaced; otherwise
            the buffer dimension is prepended.
        """
        desc = sdfg.arrays[data_name]
        if len(subset) == len(desc.shape):
            # Already in the right shape, modify new dimension
            subset = list(subset)[1:]

        new_subset = subsets.Range([('__dace_db_param', '__dace_db_param',
                                     1)] + list(subset))
        return new_subset

    @staticmethod
    def _replace_in_subset(subset, string_or_symbol, new_string_or_symbol):
        """ Returns a deep copy of ``subset`` in which every occurrence of
            ``string_or_symbol`` is substituted by ``new_string_or_symbol``.

            A ``None`` subset is returned unchanged (as ``None``): there is
            nothing to substitute, and iterating over it would otherwise
            fail with a TypeError.
        """
        if subset is None:
            return None

        new_subset = copy.deepcopy(subset)

        repldict = {
            symbolic.pystr_to_symbolic(string_or_symbol):
            symbolic.pystr_to_symbolic(new_string_or_symbol)
        }

        for i, dim in enumerate(new_subset):
            try:
                # Dimension given as a tuple of expressions
                new_subset[i] = tuple(d.subs(repldict) for d in dim)
            except TypeError:
                # Dimension is a single (non-iterable) entry
                new_subset[i] = (dim.subs(repldict)
                                 if symbolic.issymbolic(dim) else dim)

        return new_subset
Beispiel #15
0
class MapInterchange(pattern_matching.Transformation):
    """ Implements the map-interchange transformation.

        Map-interchange takes two nested maps and interchanges their position.
    """

    # Pattern nodes: the entry of the outer map followed by the inner one.
    _outer_map_entry = nodes.MapEntry(nodes.Map("", [], []))
    _inner_map_entry = nodes.MapEntry(nodes.Map("", [], []))

    @staticmethod
    def expressions():
        """ Returns the single pattern: outer map entry -> inner map entry.
        """
        return [
            nxutil.node_path_graph(MapInterchange._outer_map_entry,
                                   MapInterchange._inner_map_entry)
        ]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ A candidate matches when the two maps are perfectly nested:
            all edges leaving the outer entry go to the inner entry and vice
            versa, and likewise for the edges between the two exits.

            Maps that do not have exactly one exit node are rejected, since
            ``apply`` does not support them.
        """
        # TODO: Add matching condition that the map variables are independent
        # of each other.
        # TODO: Assuming that the subsets on the edges between the two map
        # entries/exits are the union of separate inner subsets, is it possible
        # that inverting these edges breaks the continuity of union? What about
        # the opposite?

        # Check the edges between the entries of the two maps.
        outer_map_entry = graph.nodes()[candidate[
            MapInterchange._outer_map_entry]]
        inner_map_entry = graph.nodes()[candidate[
            MapInterchange._inner_map_entry]]
        # Check that the destination of all the outgoing edges
        # from the outer map's entry is the inner map's entry.
        for e in graph.out_edges(outer_map_entry):
            if e.dst != inner_map_entry:
                return False
        # Check that the source of all the incoming edges
        # to the inner map's entry is the outer map's entry.
        for e in graph.in_edges(inner_map_entry):
            if e.src != outer_map_entry:
                return False

        # Check the edges between the exits of the two maps.
        inner_map_exits = graph.exit_nodes(inner_map_entry)
        outer_map_exits = graph.exit_nodes(outer_map_entry)
        # apply() raises NotImplementedError for multiple exits, and indexing
        # an empty list would fail here; do not report such maps as matches.
        if len(inner_map_exits) != 1 or len(outer_map_exits) != 1:
            return False
        inner_map_exit = inner_map_exits[0]
        outer_map_exit = outer_map_exits[0]

        # Check that the destination of all the outgoing edges
        # from the inner map's exit is the outer map's exit.
        for e in graph.out_edges(inner_map_exit):
            if e.dst != outer_map_exit:
                return False
        # Check that the source of all the incoming edges
        # to the outer map's exit is the inner map's exit.
        for e in graph.in_edges(outer_map_exit):
            if e.src != inner_map_exit:
                return False

        return True

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns '<label>: <params>' for the outer and inner maps, joined
            by ' -> '.
        """
        outer_map_entry = graph.nodes()[candidate[
            MapInterchange._outer_map_entry]]
        inner_map_entry = graph.nodes()[candidate[
            MapInterchange._inner_map_entry]]

        return ' -> '.join(entry.map.label + ': ' + str(entry.map.params)
                           for entry in [outer_map_entry, inner_map_entry])

    def apply(self, sdfg):
        """ Interchanges the matched outer and inner maps in place.

            The connectors of the two entries (and of the two exits) are
            swapped, the edges between them are removed, the remaining edges
            are re-attached to the opposite node, and finally the removed
            edges are re-added in the reverse direction.

            :raise NotImplementedError: If either map has more than one exit
                                        node.
        """
        # Extract the parameters and ranges of the inner/outer maps.
        graph = sdfg.nodes()[self.state_id]
        outer_map_entry = graph.nodes()[self.subgraph[
            MapInterchange._outer_map_entry]]
        inner_map_entry = graph.nodes()[self.subgraph[
            MapInterchange._inner_map_entry]]
        inner_map_exits = graph.exit_nodes(inner_map_entry)
        outer_map_exits = graph.exit_nodes(outer_map_entry)
        if len(inner_map_exits) > 1 or len(outer_map_exits) > 1:
            raise NotImplementedError('Map interchange does not work with ' +
                                      'multiple map exits')
        inner_map_exit = inner_map_exits[0]
        outer_map_exit = outer_map_exits[0]

        # Switch connectors
        outer_map_entry.in_connectors, inner_map_entry.in_connectors = \
            inner_map_entry.in_connectors, outer_map_entry.in_connectors
        outer_map_entry.out_connectors, inner_map_entry.out_connectors = \
            inner_map_entry.out_connectors, outer_map_entry.out_connectors
        outer_map_exit.in_connectors, inner_map_exit.in_connectors = \
            inner_map_exit.in_connectors, outer_map_exit.in_connectors
        outer_map_exit.out_connectors, inner_map_exit.out_connectors = \
            inner_map_exit.out_connectors, outer_map_exit.out_connectors

        # Get edges between the map entries and exits.
        entry_edges = graph.edges_between(outer_map_entry, inner_map_entry)
        exit_edges = graph.edges_between(inner_map_exit, outer_map_exit)
        for e in entry_edges + exit_edges:
            graph.remove_edge(e)

        # Change source and destination of edges.
        dace.graph.nxutil.change_edge_dest(graph, outer_map_entry,
                                           inner_map_entry)
        dace.graph.nxutil.change_edge_src(graph, inner_map_entry,
                                          outer_map_entry)
        dace.graph.nxutil.change_edge_dest(graph, inner_map_exit,
                                           outer_map_exit)
        dace.graph.nxutil.change_edge_src(graph, outer_map_exit,
                                          inner_map_exit)

        # Add edges between the map entries and exits. Source and destination
        # are deliberately reversed, since the maps have traded places.
        for e in entry_edges + exit_edges:
            graph.add_edge(e.dst, e.src_conn, e.src, e.dst_conn, e.data)
Beispiel #16
0
class MapReduceFusion(pm.Transformation):
    """ Implements the map-reduce-fusion transformation.
        Fuses a map with an immediately following reduction, where the array
        between the map and the reduction is not used anywhere else.
    """

    # Pattern nodes: tasklet -> map exit -> array -> reduce -> array.
    _tasklet = nodes.Tasklet('_')
    _tmap_exit = nodes.MapExit(nodes.Map("", [], []))
    _in_array = nodes.AccessNode('_')
    _reduce = nodes.Reduce('lambda: None', None)
    _out_array = nodes.AccessNode('_')

    @staticmethod
    def expressions():
        """ Returns the single pattern:
            tasklet -> map exit -> intermediate array -> reduce -> output.
        """
        return [
            nxutil.node_path_graph(MapReduceFusion._tasklet,
                                   MapReduceFusion._tmap_exit,
                                   MapReduceFusion._in_array,
                                   MapReduceFusion._reduce,
                                   MapReduceFusion._out_array)
        ]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ A candidate matches when:

              * the intermediate array is written only by the map exit and
                read only by the reduce node,
              * the tasklet has an edge to the map exit that carries the
                intermediate array,
              * (strict) the array has no other access node in this state
                and is not a shared transient,
              * an existing WCR on that tasklet edge equals the reduce
                node's WCR, and
              * the subsets written by the map and read by the reduction
                are equal.
        """
        tmap_exit = graph.nodes()[candidate[MapReduceFusion._tmap_exit]]
        in_array = graph.nodes()[candidate[MapReduceFusion._in_array]]
        reduce_node = graph.nodes()[candidate[MapReduceFusion._reduce]]
        tasklet = graph.nodes()[candidate[MapReduceFusion._tasklet]]

        # Make sure that the array is only accessed by the map and the reduce
        if any(src != tmap_exit
               for src, _, _, _, _ in graph.in_edges(in_array)):
            return False
        if any(dest != reduce_node
               for _, _, dest, _, _ in graph.out_edges(in_array)):
            return False

        # Without a default, next() would leak StopIteration out of the
        # matcher when no tasklet edge carries the intermediate array.
        tmem_edge = next((e for e in graph.edges_between(tasklet, tmap_exit)
                          if e.data.data == in_array.data), None)
        if tmem_edge is None:
            return False
        tmem = tmem_edge.data

        # (strict) Make sure that the transient is not accessed anywhere else
        # in this state or other states
        if strict and (len([
                n for n in graph.nodes()
                if isinstance(n, nodes.AccessNode) and n.data == in_array.data
        ]) > 1 or in_array.data in sdfg.shared_transients()):
            return False

        # If memlet already has WCR and it is different from reduce node,
        # do not match
        if tmem.wcr is not None and tmem.wcr != reduce_node.wcr:
            return False

        # Verify that reduction ranges match tasklet map
        tout_memlet = graph.in_edges(in_array)[0].data
        rin_memlet = graph.out_edges(in_array)[0].data
        if tout_memlet.subset != rin_memlet.subset:
            return False

        return True

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns the candidate entries for the tasklet, map exit and
            reduce node, joined by ' -> '.
        """
        tasklet = candidate[MapReduceFusion._tasklet]
        map_exit = candidate[MapReduceFusion._tmap_exit]
        reduce = candidate[MapReduceFusion._reduce]

        return ' -> '.join(str(node) for node in [tasklet, map_exit, reduce])

    def apply(self, sdfg):
        """ Removes the intermediate array and the reduce node, and turns
            the memlet entering the map exit into a WCR memlet that writes
            directly to the output array.

            :raise RuntimeError: If no edge entering the map exit carries
                                 the intermediate array.
        """
        graph = sdfg.nodes()[self.state_id]
        tmap_exit = graph.nodes()[self.subgraph[MapReduceFusion._tmap_exit]]
        in_array = graph.nodes()[self.subgraph[MapReduceFusion._in_array]]
        reduce_node = graph.nodes()[self.subgraph[MapReduceFusion._reduce]]
        out_array = graph.nodes()[self.subgraph[MapReduceFusion._out_array]]

        # Set nodes to remove according to the expression index
        nodes_to_remove = [in_array]
        nodes_to_remove.append(reduce_node)

        memlet_edge = None
        for edge in graph.in_edges(tmap_exit):
            if edge.data.data == in_array.data:
                memlet_edge = edge
                break
        if memlet_edge is None:
            raise RuntimeError('Reduction memlet cannot be None')

        # Find which indices should be removed from new memlet.
        # When the reduce node names no axes, every dimension of its input
        # subset is reduced; range() needs the dimension count, not the
        # subset object itself.
        input_edge = graph.in_edges(reduce_node)[0]
        axes = reduce_node.axes or list(range(len(input_edge.data.subset)))
        array_edge = graph.out_edges(reduce_node)[0]

        # Delete relevant edges and nodes
        graph.remove_nodes_from(nodes_to_remove)

        # Filter out reduced dimensions from subset
        filtered_subset = [
            dim for i, dim in enumerate(memlet_edge.data.subset)
            if i not in axes
        ]
        if len(filtered_subset) == 0:  # Output is a scalar
            filtered_subset = [0]

        # Modify edge from tasklet to map exit
        memlet_edge.data.data = out_array.data
        memlet_edge.data.wcr = reduce_node.wcr
        memlet_edge.data.wcr_identity = reduce_node.identity
        # Rebuild the subset with the same subset class as before
        memlet_edge.data.subset = type(
            memlet_edge.data.subset)(filtered_subset)

        # Add edge from map exit to output array. The output connector name
        # is derived from the input connector by swapping its first three
        # characters for 'OUT_'.
        graph.add_edge(
            memlet_edge.dst, 'OUT_' + memlet_edge.dst_conn[3:], array_edge.dst,
            array_edge.dst_conn,
            Memlet(array_edge.data.data, array_edge.data.num_accesses,
                   array_edge.data.subset, array_edge.data.veclen,
                   reduce_node.wcr, reduce_node.identity))
Beispiel #17
0
class MapWCRFusion(pm.Transformation):
    """ Implements the map expanded-reduce fusion transformation.
        Fuses a map with an immediately following reduction, where the array
        between the map and the reduction is not used anywhere else, and the
        reduction is divided to two maps with a WCR, denoting partial reduction.
    """

    # Pattern nodes, in the order they appear along the matched path.
    _tasklet = nodes.Tasklet('_')
    _tmap_exit = nodes.MapExit(nodes.Map("", [], []))
    _in_array = nodes.AccessNode('_')
    _rmap_in_entry = nodes.MapEntry(nodes.Map("", [], []))
    _rmap_in_tasklet = nodes.Tasklet('_')
    _rmap_in_cr = nodes.MapExit(nodes.Map("", [], []))
    _rmap_out_entry = nodes.MapEntry(nodes.Map("", [], []))
    _rmap_out_exit = nodes.MapExit(nodes.Map("", [], []))
    _out_array = nodes.AccessNode('_')

    @staticmethod
    def expressions():
        """ Returns the single pattern: a map, followed by an array, followed
            by two nested reduction maps and the output array.
        """
        # Map, then partial reduction of axes
        pattern = nxutil.node_path_graph(
            MapWCRFusion._tasklet, MapWCRFusion._tmap_exit,
            MapWCRFusion._in_array, MapWCRFusion._rmap_out_entry,
            MapWCRFusion._rmap_in_entry, MapWCRFusion._rmap_in_tasklet,
            MapWCRFusion._rmap_in_cr, MapWCRFusion._rmap_out_exit,
            MapWCRFusion._out_array)
        return [pattern]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Checks that the intermediate array is exclusively produced by the
            first map and consumed by the reduction map, that the inner
            reduction map writes with a WCR, and that the written and read
            subsets of the intermediate array are equal. In strict mode the
            array must additionally be unused elsewhere.
        """
        all_nodes = graph.nodes()
        first_exit = all_nodes[candidate[MapWCRFusion._tmap_exit]]
        mid_array = graph.nodes()[candidate[MapWCRFusion._in_array]]
        reduce_entry = graph.nodes()[candidate[MapWCRFusion._rmap_out_entry]]

        # Make sure that the array is only accessed by the map and the reduce
        for producer, _, _, _, _ in graph.in_edges(mid_array):
            if producer != first_exit:
                return False
        for _, _, consumer, _, _ in graph.out_edges(mid_array):
            if consumer != reduce_entry:
                return False

        # Make sure that there is a reduction in the second map
        inner_exit = graph.nodes()[candidate[MapWCRFusion._rmap_in_cr]]
        if graph.in_edges(inner_exit)[0].data.wcr is None:
            return False

        # (strict) Make sure that the transient is not accessed anywhere else
        # in this state or other states
        if strict:
            same_data_nodes = [
                n for n in graph.nodes()
                if isinstance(n, nodes.AccessNode) and n.data == mid_array.data
            ]
            if (len(same_data_nodes) > 1
                    or mid_array.data in sdfg.shared_transients()):
                return False

        # Verify that reduction ranges match tasklet map
        written = graph.in_edges(mid_array)[0].data
        consumed = graph.out_edges(mid_array)[0].data
        if written.subset != consumed.subset:
            return False
        return True

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns the candidate entries for the tasklet, first map exit and
            inner reduction map exit, joined by ' -> '.
        """
        keys = (MapWCRFusion._tasklet, MapWCRFusion._tmap_exit,
                MapWCRFusion._rmap_in_cr)
        return ' -> '.join(str(candidate[key]) for key in keys)

    def apply(self, sdfg):
        """ Collapses the two reduction maps into one (MapCollapse) and then
            fuses the first map with the collapsed map (MapFusion).
        """
        graph = sdfg.node(self.state_id)

        # To apply, collapse the second map and then fuse the two resulting maps
        collapse_match = {
            MapCollapse._outer_map_entry:
            self.subgraph[MapWCRFusion._rmap_out_entry],
            MapCollapse._inner_map_entry:
            self.subgraph[MapWCRFusion._rmap_in_entry]
        }
        collapser = MapCollapse(self.sdfg_id, self.state_id, collapse_match,
                                0)
        collapsed_entry, _ = collapser.apply(sdfg)

        # The collapsed entry is looked up by its id after the collapse
        fusion_match = {
            MapFusion._first_map_exit:
            self.subgraph[MapWCRFusion._tmap_exit],
            MapFusion._second_map_entry: graph.node_id(collapsed_entry)
        }
        fuser = MapFusion(self.sdfg_id, self.state_id, fusion_match, 0)
        fuser.apply(sdfg)
Beispiel #18
0
class MapToForLoop(pattern_matching.Transformation):
    """ Implements the Map to for-loop transformation.

        Takes a map and enforces a sequential schedule by transforming it into
        a state-machine of a for-loop. Creates a nested SDFG, if necessary.
    """

    # Pattern node: entry of the map that is converted to a loop.
    _map_entry = nodes.MapEntry(nodes.Map("", [], []))

    @staticmethod
    def annotates_memlets():
        return True

    @staticmethod
    def expressions():
        return [nxutil.node_path_graph(MapToForLoop._map_entry)]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Only uni-dimensional maps are accepted: ``apply`` builds a single
            loop out of the first parameter and range of the map, so a map
            without parameters is rejected as well.
        """
        map_entry = graph.nodes()[candidate[MapToForLoop._map_entry]]
        return len(map_entry.map.params) == 1

    @staticmethod
    def match_to_str(graph, candidate):
        map_entry = graph.nodes()[candidate[MapToForLoop._map_entry]]
        return map_entry.map.label + ': ' + str(map_entry.map.params)

    def apply(self, sdfg) -> Tuple[nodes.NestedSDFG, SDFGState]:
        """ Applies the transformation and returns a tuple with the new nested
            SDFG node and the main state in the for-loop. """

        # Retrieve map entry and exit nodes.
        graph = sdfg.nodes()[self.state_id]
        map_entry = graph.nodes()[self.subgraph[MapToForLoop._map_entry]]
        map_exit = graph.exit_nodes(map_entry)[0]

        loop_idx = map_entry.map.params[0]
        loop_from, loop_to, loop_step = map_entry.map.range[0]

        # Turn the map scope into a nested SDFG
        node = nest_state_subgraph(sdfg, graph,
                                   graph.scope_subgraph(map_entry))

        nsdfg: SDFG = node.sdfg
        nstate: SDFGState = nsdfg.nodes()[0]

        # If map range is dynamic, replace loop expressions with memlets.
        # Every in-edge of the map entry whose connector is not a regular
        # 'IN_*' connector gets a unique placeholder name, which is
        # substituted for the connector symbol in the loop bounds.
        param_to_edge = {}
        for edge in nstate.in_edges(map_entry):
            if edge.dst_conn and not edge.dst_conn.startswith('IN_'):
                param = '__DACE_P%d' % len(param_to_edge)
                repldict = {symbolic.pystr_to_symbolic(edge.dst_conn): param}
                param_to_edge[param] = edge
                loop_from = loop_from.subs(repldict)
                loop_to = loop_to.subs(repldict)
                loop_step = loop_step.subs(repldict)

        # Avoiding import loop
        from dace.codegen.targets.cpp import cpp_array_expr

        def replace_param(param):
            """ Returns ``param`` as a string in which every placeholder is
                replaced by an expression that reads the matching data. """
            param = symbolic.symstr(param)
            # Substitute longer placeholder names first: '__DACE_P1' is a
            # prefix of '__DACE_P10', so a plain textual replacement in
            # insertion order would corrupt the latter.
            for p in sorted(param_to_edge, key=len, reverse=True):
                pval = param_to_edge[p]
                desc = nsdfg.arrays[pval.data.data]
                # TODO: This special replacement condition will be removed
                #       when the code generator is modified to make consistent
                #       scalar/array decisions.
                if (isinstance(desc, data.Scalar)
                        or (desc.shape[0] == 1 and len(desc.shape) == 1)):
                    param = param.replace(p, pval.data.data)
                else:
                    param = param.replace(p, cpp_array_expr(nsdfg, pval.data))
            return param

        # End of dynamic input range

        # Create a loop inside the nested SDFG. Map ranges include their end
        # point, hence the '+ 1' in the strict '<' loop condition.
        nsdfg.add_loop(None, nstate, None, loop_idx, replace_param(loop_from),
                       '%s < %s' % (loop_idx, replace_param(loop_to + 1)),
                       '%s + %s' % (loop_idx, replace_param(loop_step)))

        # Skip map in input edges
        for edge in nstate.out_edges(map_entry):
            src_node = nstate.memlet_path(edge)[0].src
            nstate.add_edge(src_node, None, edge.dst, edge.dst_conn, edge.data)
            nstate.remove_edge(edge)

        # Skip map in output edges
        for edge in nstate.in_edges(map_exit):
            dst_node = nstate.memlet_path(edge)[-1].dst
            nstate.add_edge(edge.src, edge.src_conn, dst_node, None, edge.data)
            nstate.remove_edge(edge)

        # Remove nodes from dynamic map range
        nstate.remove_nodes_from(
            [e.src for e in dace.sdfg.dynamic_map_inputs(nstate, map_entry)])
        # Remove scope nodes
        nstate.remove_nodes_from([map_entry, map_exit])

        return node, nstate
Beispiel #19
0
    def _stripmine(self, sdfg, graph, candidate):
        """ Strip-mines one dimension of the matched map.

            A new one-dimensional map that iterates over the tiles is
            created, the range of the selected dimension of the matched map
            is restricted to a single tile, and the edges that entered/left
            the matched map are rerouted through the new map.

            :param sdfg: The SDFG that contains the map; its ``arrays`` are
                         consulted to detect scalar data.
            :param graph: The state in which the map resides.
            :param candidate: Mapping from ``StripMining._map_entry`` to the
                              ID of the map entry node in ``graph``.
            :return: A 3-tuple with the name of the strip-mined parameter,
                     the name of the new tile parameter, and the new map.
        """

        # Retrieve map entry and exit nodes.
        map_entry = graph.nodes()[candidate[StripMining._map_entry]]
        map_exit = graph.exit_nodes(map_entry)[0]

        # Retrieve transformation properties.
        dim_idx = self.dim_idx
        new_dim_prefix = self.new_dim_prefix
        tile_size = self.tile_size
        divides_evenly = self.divides_evenly
        strided = self.strided

        # The distance between consecutive tiles defaults to the tile size.
        tile_stride = self.tile_stride
        if tile_stride is None or len(tile_stride) == 0:
            tile_stride = tile_size

        # Retrieve parameter and range of dimension to be strip-mined.
        target_dim = map_entry.map.params[dim_idx]
        td_from, td_to, td_step = map_entry.map.range[dim_idx]

        # Create new map. Replace by cloning???
        new_dim = self._find_new_dim(sdfg, graph, map_entry, new_dim_prefix,
                                     target_dim)
        nd_from = 0
        nd_to = symbolic.pystr_to_symbolic(
            'int_ceil(%s + 1 - %s, %s) - 1' %
            (symbolic.symstr(td_to), symbolic.symstr(td_from), tile_stride))
        nd_step = 1
        new_dim_range = (nd_from, nd_to, nd_step)
        new_map = nodes.Map(new_dim + '_' + map_entry.map.label, [new_dim],
                            subsets.Range([new_dim_range]))
        new_map_entry = nodes.MapEntry(new_map)
        new_map_exit = nodes.MapExit(new_map)

        # Change the range of the selected dimension to iterate over a single
        # tile
        if strided:
            td_from_new = symbolic.pystr_to_symbolic(new_dim)
            td_to_new_approx = td_to
            td_step = symbolic.pystr_to_symbolic(tile_size)
        else:
            td_from_new = symbolic.pystr_to_symbolic(
                '%s + %s * %s' %
                (symbolic.symstr(td_from), str(new_dim), tile_stride))
            td_to_new_exact = symbolic.pystr_to_symbolic(
                'min(%s + 1, %s + %s * %s + %s) - 1' %
                (symbolic.symstr(td_to), symbolic.symstr(td_from), tile_stride,
                 str(new_dim), tile_size))
            td_to_new_approx = symbolic.pystr_to_symbolic(
                '%s + %s * %s + %s - 1' %
                (symbolic.symstr(td_from), tile_stride, str(new_dim),
                 tile_size))
        # The exact end is only defined (and only needed) for non-strided
        # tiles that may not divide the range evenly.
        if divides_evenly or strided:
            td_to_new = td_to_new_approx
        else:
            td_to_new = dace.symbolic.SymExpr(td_to_new_exact,
                                              td_to_new_approx)
        map_entry.map.range[dim_idx] = (td_from_new, td_to_new, td_step)

        # Make internal map's schedule to "not parallel"
        new_map.schedule = map_entry.map.schedule
        map_entry.map.schedule = dtypes.ScheduleType.Sequential

        # Redirect edges
        new_map_entry.in_connectors = dcpy(map_entry.in_connectors)
        nxutil.change_edge_dest(graph, map_entry, new_map_entry)
        new_map_exit.out_connectors = dcpy(map_exit.out_connectors)
        nxutil.change_edge_src(graph, map_exit, new_map_exit)

        # Create new entry edges. Edges are keyed by (data, in connector,
        # out connector), so several inner edges that share a connector are
        # merged into one edge between the two map entries.
        new_in_edges = dict()
        entry_in_conn = set()
        entry_out_conn = set()
        for _, src_conn, _, _, memlet in graph.out_edges(map_entry):
            is_passthrough = (src_conn is not None
                              and src_conn.startswith('OUT_'))
            if is_passthrough and not isinstance(sdfg.arrays[memlet.data],
                                                 dace.data.Scalar):
                new_subset = calc_set_image(
                    map_entry.map.params,
                    map_entry.map.range,
                    memlet.subset,
                )
                conn = src_conn[4:]
                key = (memlet.data, 'IN_' + conn, 'OUT_' + conn)
                if key in new_in_edges:
                    old_subset = new_in_edges[key].subset
                    new_in_edges[key].subset = calc_set_union(
                        old_subset, new_subset)
                else:
                    entry_in_conn.add('IN_' + conn)
                    entry_out_conn.add('OUT_' + conn)
                    new_memlet = dcpy(memlet)
                    new_memlet.subset = new_subset
                    new_memlet.num_accesses = new_memlet.num_elements()
                    new_in_edges[key] = new_memlet
            else:
                if is_passthrough:
                    conn = src_conn[4:]
                    in_conn = 'IN_' + conn
                    out_conn = 'OUT_' + conn
                else:
                    in_conn = src_conn
                    out_conn = src_conn
                if in_conn:
                    entry_in_conn.add(in_conn)
                if out_conn:
                    entry_out_conn.add(out_conn)
                new_in_edges[(memlet.data, in_conn, out_conn)] = dcpy(memlet)
        new_map_entry.out_connectors = entry_out_conn
        map_entry.in_connectors = entry_in_conn
        for (_, in_conn, out_conn), memlet in new_in_edges.items():
            graph.add_edge(new_map_entry, out_conn, map_entry, in_conn, memlet)

        # Create new exit edges (mirror image of the entry edges above)
        new_out_edges = dict()
        exit_in_conn = set()
        exit_out_conn = set()
        for _, _, _, dst_conn, memlet in graph.in_edges(map_exit):
            is_passthrough = (dst_conn is not None
                              and dst_conn.startswith('IN_'))
            if is_passthrough and not isinstance(sdfg.arrays[memlet.data],
                                                 dace.data.Scalar):
                new_subset = calc_set_image(
                    map_entry.map.params,
                    map_entry.map.range,
                    memlet.subset,
                )
                conn = dst_conn[3:]
                key = (memlet.data, 'IN_' + conn, 'OUT_' + conn)
                if key in new_out_edges:
                    old_subset = new_out_edges[key].subset
                    new_out_edges[key].subset = calc_set_union(
                        old_subset, new_subset)
                else:
                    exit_in_conn.add('IN_' + conn)
                    exit_out_conn.add('OUT_' + conn)
                    new_memlet = dcpy(memlet)
                    new_memlet.subset = new_subset
                    new_memlet.num_accesses = new_memlet.num_elements()
                    new_out_edges[key] = new_memlet
            else:
                if is_passthrough:
                    conn = dst_conn[3:]
                    in_conn = 'IN_' + conn
                    out_conn = 'OUT_' + conn
                else:
                    # Use the connector of *this* edge; the entry loop's
                    # source connector has no meaning here.
                    in_conn = dst_conn
                    out_conn = dst_conn
                if in_conn:
                    exit_in_conn.add(in_conn)
                if out_conn:
                    exit_out_conn.add(out_conn)
                # Must be stored with the exit edges: the entry edges have
                # already been added to the graph at this point.
                new_out_edges[(memlet.data, in_conn, out_conn)] = dcpy(memlet)
        new_map_exit.in_connectors = exit_in_conn
        map_exit.out_connectors = exit_out_conn
        for (_, in_conn, out_conn), memlet in new_out_edges.items():
            graph.add_edge(map_exit, out_conn, new_map_exit, in_conn, memlet)

        # Return strip-mined dimension.
        return target_dim, new_dim, new_map
Beispiel #20
0
class Vectorization(pattern_matching.Transformation):
    """ Implements the vectorization transformation.

        Vectorization matches when all the input and output memlets of a
        tasklet inside a map access the inner-most loop variable in their last
        dimension. The transformation changes the step of the inner-most loop
        to be equal to the length of the vector and vectorizes the memlets.
    """

    vector_len = Property(desc="Vector length", dtype=int, default=4)
    propagate_parent = Property(desc="Propagate vector length through "
                                "parent SDFGs",
                                dtype=bool,
                                default=False)

    # Pattern nodes: a tasklet directly enclosed by a map entry and exit.
    _map_entry = nodes.MapEntry(nodes.Map("", [], []))
    _tasklet = nodes.Tasklet('_')
    _map_exit = nodes.MapExit(nodes.Map("", [], []))

    @staticmethod
    def expressions():
        return [
            nxutil.node_path_graph(Vectorization._map_entry,
                                   Vectorization._tasklet,
                                   Vectorization._map_exit)
        ]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Returns True if at least one memlet adjacent to the tasklet uses
            the last map parameter in its last dimension, and no considered
            memlet uses it in any other dimension, has a write-conflict
            resolution, or is already vectorized. Empty memlets and memlets
            of streams are ignored.
        """
        map_entry = graph.nodes()[candidate[Vectorization._map_entry]]
        tasklet = graph.nodes()[candidate[Vectorization._tasklet]]
        param = symbolic.pystr_to_symbolic(map_entry.map.params[-1])
        found = False

        # Check if all edges, adjacent to the tasklet,
        # use the parameter in their last dimension.
        for _, _, _, _, memlet in graph.all_edges(tasklet):

            # Cases that do not matter for vectorization
            if memlet.data is None:  # Empty memlets
                continue
            if isinstance(sdfg.arrays[memlet.data], data.Stream):  # Streams
                continue

            # Vectorization can not be applied in WCR
            if memlet.wcr is not None:
                return False

            try:
                subset = memlet.subset
                veclen = memlet.veclen
            except AttributeError:
                return False

            if subset is None:
                return False

            try:
                if veclen > symbolic.pystr_to_symbolic('1'):
                    return False

                for idx, expr in enumerate(subset):
                    # A dimension is either a tuple of expressions or a
                    # single expression; treat both uniformly.
                    parts = expr if isinstance(expr, tuple) else (expr, )
                    for part in parts:
                        part = symbolic.pystr_to_symbolic(part)
                        if param in part.free_symbols:
                            if idx == subset.dims() - 1:
                                found = True
                            else:
                                return False
            except TypeError:  # cannot determine truth value of Relational
                return False

        return found

    @staticmethod
    def match_to_str(graph, candidate):

        map_entry = candidate[Vectorization._map_entry]
        tasklet = candidate[Vectorization._tasklet]
        map_exit = candidate[Vectorization._map_exit]

        return ' -> '.join(
            str(node) for node in [map_entry, tasklet, map_exit])

    def apply(self, sdfg):
        """ Sets the step of the last map dimension to ``vector_len`` and
            assigns that vector length to every memlet on the memlet paths
            of the tasklet's edges whose last index uses the map parameter.
            If ``propagate_parent`` is set and the SDFG has a parent, the
            vector length is also assigned along the matching memlet paths
            of the parent.
        """
        graph = sdfg.nodes()[self.state_id]
        map_entry = graph.nodes()[self.subgraph[Vectorization._map_entry]]
        tasklet = graph.nodes()[self.subgraph[Vectorization._tasklet]]
        param = symbolic.pystr_to_symbolic(map_entry.map.params[-1])

        # Create new vector size.
        vector_size = self.vector_len

        # Change the step of the inner-most dimension.
        dim_from, dim_to, _ = map_entry.map.range[-1]
        map_entry.map.range[-1] = (dim_from, dim_to, vector_size)

        # Vectorize memlets adjacent to the tasklet.
        for edge in graph.all_edges(tasklet):
            _, _, _, _, memlet = edge

            if memlet.data is None:  # Empty memlets
                continue

            # Collect the symbols used by the last index of the memlet
            lastindex = memlet.subset[-1]
            if isinstance(lastindex, tuple):
                symbols = set()
                for indd in lastindex:
                    symbols.update(
                        symbolic.pystr_to_symbolic(indd).free_symbols)
            else:
                symbols = symbolic.pystr_to_symbolic(lastindex).free_symbols

            if param not in symbols:
                continue

            # propagate vector length inside this SDFG
            path = graph.memlet_path(edge)
            for e in path:
                e.data.veclen = vector_size

            source_edge = path[0]
            sink_edge = path[-1]

            # propagate to the parent (TODO: handle multiple level of nestings)
            if self.propagate_parent and sdfg.parent is not None:
                # Find parent Nested SDFG node
                parent_node = next(n for n in sdfg.parent.nodes()
                                   if isinstance(n, nodes.NestedSDFG)
                                   and n.sdfg.name == sdfg.name)

                # continue in propagating the vector length following the
                # path that arrives to source_edge or starts from sink_edge
                for pe in sdfg.parent.all_edges(parent_node):
                    if str(pe.dst_conn) == str(source_edge.src) or str(
                            pe.src_conn) == str(sink_edge.dst):
                        for ppe in sdfg.parent.memlet_path(pe):
                            ppe.data.veclen = vector_size
Beispiel #21
0
    def __stripmine(self, sdfg, graph, candidate):
        """ Strip-mines every dimension of the matched map.

            For each map dimension, the range of the matched map is rewritten
            to cover one tile, all memlets inside the map scope are offset by
            the tile index, and a matching dimension is added to a new map
            that iterates over the tiles. The edges that entered/left the
            matched map are then rerouted through the new map.

            :param sdfg: The SDFG that contains the map; its ``arrays`` are
                         consulted to detect scalar data.
            :param graph: The state in which the map resides.
            :param candidate: Mapping from ``OrthogonalTiling._map_entry`` to
                              the ID of the map entry node in ``graph``.
            :return: A 3-tuple with the symbolic parameter and the tile
                     parameter name of the *last* processed dimension, and
                     the new map.
        """
        # Retrieve map entry and exit nodes.
        map_entry = graph.nodes()[candidate[OrthogonalTiling._map_entry]]
        map_exit = graph.exit_nodes(map_entry)[0]

        # Map subgraph
        map_subgraph = graph.scope_subgraph(map_entry)

        # Retrieve transformation properties.
        prefix = self.prefix
        tile_sizes = self.tile_sizes
        divides_evenly = self.divides_evenly

        new_param = []
        new_range = []

        for dim_idx in range(len(map_entry.map.params)):

            # Dimensions without an explicit tile size reuse the last one
            if dim_idx >= len(tile_sizes):
                tile_size = tile_sizes[-1]
            else:
                tile_size = tile_sizes[dim_idx]

            # Retrieve parameter and range of dimension to be strip-mined.
            target_dim = map_entry.map.params[dim_idx]
            td_from, td_to, td_step = map_entry.map.range[dim_idx]

            new_dim = prefix + '_' + target_dim

            # Basic values
            if divides_evenly:
                tile_num = '(%s + 1 - %s) / %s' % (symbolic.symstr(td_to),
                                                   symbolic.symstr(td_from),
                                                   str(tile_size))
            else:
                tile_num = 'int_ceil((%s + 1 - %s), %s)' % (symbolic.symstr(
                    td_to), symbolic.symstr(td_from), str(tile_size))

            # Outer map values (over all tiles)
            nd_from = 0
            nd_to = symbolic.pystr_to_symbolic(str(tile_num) + ' - 1')
            nd_step = 1

            # Inner map values (over one tile)
            td_from_new = dace.symbolic.pystr_to_symbolic(td_from)
            td_to_new_exact = symbolic.pystr_to_symbolic(
                'min(%s + 1 - %s * %s, %s + %s) - 1' %
                (symbolic.symstr(td_to), str(new_dim), str(tile_size),
                 td_from_new, str(tile_size)))
            td_to_new_approx = symbolic.pystr_to_symbolic(
                '%s + %s - 1' % (td_from_new, str(tile_size)))

            # Outer map (over all tiles)
            new_dim_range = (nd_from, nd_to, nd_step)
            new_param.append(new_dim)
            new_range.append(new_dim_range)

            # Inner map (over one tile)
            if divides_evenly:
                td_to_new = td_to_new_approx
            else:
                td_to_new = dace.symbolic.SymExpr(td_to_new_exact,
                                                  td_to_new_approx)
            map_entry.map.range[dim_idx] = (td_from_new, td_to_new, td_step)

            # Fix subgraph memlets: the inner parameter is now relative to
            # the tile, so every use is shifted by tile index * tile size.
            target_dim = dace.symbolic.pystr_to_symbolic(target_dim)
            offset = dace.symbolic.pystr_to_symbolic(
                '%s * %s' % (new_dim, str(tile_size)))
            for _, _, _, _, memlet in map_subgraph.edges():
                old_subset = memlet.subset
                if isinstance(old_subset, dace.subsets.Indices):
                    new_indices = []
                    for idx in old_subset:
                        new_idx = idx.subs(target_dim, target_dim + offset)
                        new_indices.append(new_idx)
                    memlet.subset = dace.subsets.Indices(new_indices)
                elif isinstance(old_subset, dace.subsets.Range):
                    new_ranges = []
                    for i, old_range in enumerate(old_subset):
                        if len(old_range) == 3:
                            b, e, s, = old_range
                            t = old_subset.tile_sizes[i]
                        else:
                            raise ValueError(
                                'Range %s is invalid.' % old_range)
                        new_b = b.subs(target_dim, target_dim + offset)
                        new_e = e.subs(target_dim, target_dim + offset)
                        new_s = s.subs(target_dim, target_dim + offset)
                        new_t = t.subs(target_dim, target_dim + offset)
                        new_ranges.append((new_b, new_e, new_s, new_t))
                    memlet.subset = dace.subsets.Range(new_ranges)
                else:
                    raise NotImplementedError

        new_map = nodes.Map(prefix + '_' + map_entry.map.label, new_param,
                            subsets.Range(new_range))
        new_map_entry = nodes.MapEntry(new_map)
        new_exit = nodes.MapExit(new_map)

        # Make internal map's schedule to "not parallel"
        map_entry.map._schedule = dtypes.ScheduleType.Default

        # Redirect/create edges. One edge per data container is created
        # between the new and the original map entry; each dictionary value
        # is (src, src_conn, dst, dst_conn, memlet, connector number).
        new_in_edges = {}
        for _src, conn, _dest, _, memlet in graph.out_edges(map_entry):
            if not isinstance(sdfg.arrays[memlet.data], dace.data.Scalar):
                new_subset = copy.deepcopy(memlet.subset)
                # new_subset = calc_set_image(map_entry.map.params,
                #                             map_entry.map.range, memlet.subset,
                #                             cont_or_strided)
                if memlet.data in new_in_edges:
                    # Container already seen: merge the subsets and keep the
                    # smallest connector number.
                    src, src_conn, dest, dest_conn, new_memlet, num = \
                        new_in_edges[memlet.data]
                    new_memlet.subset = calc_set_union(
                        new_memlet.data, sdfg.arrays[new_memlet.data],
                        new_memlet.subset, new_subset)
                    new_memlet.num_accesses = new_memlet.num_elements()
                    new_in_edges.update({
                        memlet.data: (src, src_conn, dest, dest_conn,
                                      new_memlet, min(num, int(conn[4:])))
                    })
                else:
                    new_memlet = dcpy(memlet)
                    new_memlet.subset = new_subset
                    new_memlet.num_accesses = new_memlet.num_elements()
                    new_in_edges.update({
                        memlet.data: (new_map_entry, None, map_entry, None,
                                      new_memlet, int(conn[4:]))
                    })
        nxutil.change_edge_dest(graph, map_entry, new_map_entry)

        # NOTE(review): unlike the entry side, 'conn' below is the source
        # connector of the edge entering the map exit and is kept as a
        # string; it is only used as a dictionary key further down. Confirm
        # that this is the intended connector.
        new_out_edges = {}
        for _src, conn, _dest, _, memlet in graph.in_edges(map_exit):
            if not isinstance(sdfg.arrays[memlet.data], dace.data.Scalar):
                new_subset = memlet.subset
                # new_subset = calc_set_image(map_entry.map.params,
                #                             map_entry.map.range,
                #                             memlet.subset, cont_or_strided)
                if memlet.data in new_out_edges:
                    src, src_conn, dest, dest_conn, new_memlet, num = \
                        new_out_edges[memlet.data]
                    new_memlet.subset = calc_set_union(
                        new_memlet.data, sdfg.arrays[new_memlet.data],
                        new_memlet.subset, new_subset)
                    new_memlet.num_accesses = new_memlet.num_elements()
                    new_out_edges.update({
                        memlet.data: (src, src_conn, dest, dest_conn,
                                      new_memlet, min(num, conn[4:]))
                    })
                else:
                    new_memlet = dcpy(memlet)
                    new_memlet.subset = new_subset
                    new_memlet.num_accesses = new_memlet.num_elements()
                    new_out_edges.update({
                        memlet.data: (map_exit, None, new_exit, None,
                                      new_memlet, conn[4:])
                    })
        nxutil.change_edge_src(graph, map_exit, new_exit)

        # Connector related work follows
        # 1. Dictionary 'old_connector_number': 'new_connector_numer'
        # 2. New node in/out connectors
        # 3. New edges

        in_conn_nums = []
        for _, e in new_in_edges.items():
            _, _, _, _, _, num = e
            in_conn_nums.append(num)
        in_conn = {}
        for i, num in enumerate(in_conn_nums):
            in_conn.update({num: i + 1})

        entry_in_connectors = set()
        entry_out_connectors = set()
        for i in range(len(in_conn_nums)):
            entry_in_connectors.add('IN_' + str(i + 1))
            entry_out_connectors.add('OUT_' + str(i + 1))
        new_map_entry.in_connectors = entry_in_connectors
        new_map_entry.out_connectors = entry_out_connectors

        for _, e in new_in_edges.items():
            src, _, dst, _, memlet, num = e
            graph.add_edge(src, 'OUT_' + str(in_conn[num]), dst,
                           'IN_' + str(in_conn[num]), memlet)

        out_conn_nums = []
        for _, e in new_out_edges.items():
            _, _, dst, _, _, num = e
            if dst is not new_exit:
                continue
            out_conn_nums.append(num)
        out_conn = {}
        for i, num in enumerate(out_conn_nums):
            out_conn.update({num: i + 1})

        exit_in_connectors = set()
        exit_out_connectors = set()
        for i in range(len(out_conn_nums)):
            exit_in_connectors.add('IN_' + str(i + 1))
            exit_out_connectors.add('OUT_' + str(i + 1))
        new_exit.in_connectors = exit_in_connectors
        new_exit.out_connectors = exit_out_connectors

        for _, e in new_out_edges.items():
            src, _, dst, _, memlet, num = e
            graph.add_edge(src, 'OUT_' + str(out_conn[num]), dst,
                           'IN_' + str(out_conn[num]), memlet)

        # Return strip-mined dimension.
        return target_dim, new_dim, new_map
Beispiel #22
0
class MPITransformMap(pattern_matching.Transformation):
    r""" Implements the MPI parallelization pattern.

        Takes a map and makes it an MPI-scheduled map, introduces transients
        that keep locally accessed data.

         Original SDFG
         =============
         ```
         Input1 -                                            Output1
                 \                                          /
         Input2 --- MapEntry -- Arbitrary R  -- MapExit -- Output2
                 /                                          \
         InputN -                                            OutputN
         ```

         Nothing in R may access other inputs/outputs that are not defined in R
         itself and do not go through MapEntry/MapExit
         Map must be a one-dimensional map for now.
         The range of the map must be a Range object.

         Output:
         =======

         * Add transients for the accessed parts
         * The schedule property of Map is set to MPI
         * The range of Map is changed to
           var = startexpr + p * chunksize ... startexpr + (p + 1) * chunksize
           where p is the current rank and P is the total number of ranks,
           and chunksize is defined as (endexpr - startexpr) / P, adding the
           remaining K iterations to the first K procs.
         * For each input InputI, create a new transient transInputI, which
           has an attribute that specifies that it needs to be filled with
           (possibly) remote data
         * Collect all accesses to InputI within R, assume their convex hull is
           InputI[rs ... re]
         * The transInputI transient will contain InputI[rs ... re]
         * Change all accesses to InputI within R to accesses to transInputI
    """

    # Pattern node: entry of the map to be distributed.
    _map_entry = nodes.MapEntry(nodes.Map("", [], []))

    @staticmethod
    def annotates_memlets():
        return True

    @staticmethod
    def expressions():
        return [nxutil.node_path_graph(MPITransformMap._map_entry)]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Matches one-dimensional maps that are not MPI-scheduled, whose
            enclosing scopes all have a Default or Sequential schedule,
            that have no dynamic map inputs, and whose exit node has no
            outgoing write-conflict-resolution memlets.
        """
        map_entry = graph.nodes()[candidate[MPITransformMap._map_entry]]

        # Check if the map is one-dimensional
        if map_entry.map.range.dims() != 1:
            return False

        # We cannot transform a map which is already of schedule type MPI
        if map_entry.map.schedule == dtypes.ScheduleType.MPI:
            return False

        # We cannot transform a map which is already inside a MPI map, or in
        # another device
        schedule_whitelist = [
            dtypes.ScheduleType.Default, dtypes.ScheduleType.Sequential
        ]
        sdict = graph.scope_dict()
        parent = sdict[map_entry]
        while parent is not None:
            if parent.map.schedule not in schedule_whitelist:
                return False
            parent = sdict[parent]

        # Dynamic map ranges not supported (will allocate dynamic memory)
        if has_dynamic_map_inputs(graph, map_entry):
            return False

        # MPI schedules currently do not support WCR
        map_exit = graph.exit_nodes(map_entry)[0]
        if any(e.data.wcr for e in graph.out_edges(map_exit)):
            return False

        return True

    @staticmethod
    def match_to_str(graph, candidate):
        map_entry = graph.nodes()[candidate[MPITransformMap._map_entry]]

        return map_entry.map.label

    def apply(self, sdfg):
        """ Strip-mines the matched map with a tile size of
            ``<number of iterations> / __dace_comm_size``, gives the new
            outer map the MPI schedule, and applies LocalStorage to the data
            entering the inner map and to the data leaving the outer map.
        """
        graph = sdfg.nodes()[self.state_id]

        map_entry_id = self.subgraph[MPITransformMap._map_entry]
        map_entry = graph.nodes()[map_entry_id]

        # Avoiding import loops
        from dace.transformation.dataflow.strip_mining import StripMining
        from dace.transformation.dataflow.local_storage import LocalStorage

        rangeexpr = str(map_entry.map.range.num_elements())

        # Looked up once and shared by all sub-transformations created below.
        sdfg_id = sdfg.sdfg_list.index(sdfg)

        # NOTE(review): 'divides_evenly' is set, so no remainder handling is
        # requested from StripMining here -- confirm against the class
        # description, which mentions distributing remaining iterations.
        stripmine_subgraph = {StripMining._map_entry: map_entry_id}
        stripmine = StripMining(sdfg_id, self.state_id, stripmine_subgraph,
                                self.expr_index)
        stripmine.dim_idx = -1
        stripmine.new_dim_prefix = "mpi"
        stripmine.tile_size = "(" + rangeexpr + "/__dace_comm_size)"
        stripmine.divides_evenly = True
        stripmine.apply(sdfg)

        # Find all in-edges that lead to candidate[MPITransformMap._map_entry]
        # and originate from a scope entry, i.e., from the new outer map.
        edges = [
            e for e in graph.in_edges(map_entry)
            if isinstance(e.src, nodes.EntryNode)
        ]

        outer_map = edges[0].src

        # Add MPI schedule attribute to outer map
        outer_map.map._schedule = dtypes.ScheduleType.MPI

        # Now create a transient for each array
        for e in edges:
            in_local_storage_subgraph = {
                LocalStorage._node_a: graph.node_id(outer_map),
                LocalStorage._node_b: map_entry_id
            }
            in_local_storage = LocalStorage(sdfg_id, self.state_id,
                                            in_local_storage_subgraph,
                                            self.expr_index)
            in_local_storage.array = e.data.data
            in_local_storage.apply(sdfg)

        # Transform OutLocalStorage for each output of the MPI map
        in_map_exit = graph.exit_nodes(map_entry)[0]
        out_map_exit = graph.exit_nodes(outer_map)[0]

        for e in graph.out_edges(out_map_exit):
            name = e.data.data
            outlocalstorage_subgraph = {
                LocalStorage._node_a: graph.node_id(in_map_exit),
                LocalStorage._node_b: graph.node_id(out_map_exit)
            }
            outlocalstorage = LocalStorage(sdfg_id, self.state_id,
                                           outlocalstorage_subgraph,
                                           self.expr_index)
            outlocalstorage.array = name
            outlocalstorage.apply(sdfg)
Beispiel #23
0
class Vectorization(pattern_matching.Transformation):
    """ Implements the vectorization transformation.

        Vectorization matches when all the input and output memlets of a 
        tasklet inside a map access the inner-most loop variable in their last
        dimension. The transformation changes the step of the inner-most loop
        to be equal to the length of the vector and vectorizes the memlets.
  """

    vector_len = Property(desc="Vector length", dtype=int, default=4)

    _map_entry = nodes.MapEntry(nodes.Map("", [], []))
    _tasklet = nodes.Tasklet('_')
    _map_exit = nodes.MapExit(nodes.Map("", [], []))

    @staticmethod
    def expressions():
        """ Returns the only pattern matched by this transformation, built
            with `nxutil.node_path_graph` from the map-entry, tasklet and
            map-exit pattern nodes (in that order).
        """
        pattern = nxutil.node_path_graph(Vectorization._map_entry,
                                         Vectorization._tasklet,
                                         Vectorization._map_exit)
        return [pattern]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Checks whether the matched map/tasklet can be vectorized.

            Every memlet adjacent to the tasklet is inspected, except
            memlets on streams and memlets with write-conflict resolution,
            which are skipped. The match is rejected if a memlet has no
            subset or no vector length, if its vector length is already
            greater than one, or if the last map parameter appears in any
            subset dimension other than the last one.

            @return: True if no inspected memlet was rejected and at least
                     one of them uses the last map parameter in its last
                     dimension, False otherwise.
        """
        map_entry = graph.nodes()[candidate[Vectorization._map_entry]]
        tasklet = graph.nodes()[candidate[Vectorization._tasklet]]
        param = symbolic.pystr_to_symbolic(map_entry.map.params[-1])
        param_in_last_dim = False

        for _, _, _, _, memlet in graph.all_edges(tasklet):
            # Streams and conflict-resolution memlets do not matter here
            if (isinstance(sdfg.arrays[memlet.data], data.Stream)
                    or memlet.wcr is not None):
                continue

            try:
                subset = memlet.subset
                veclen = memlet.veclen
            except AttributeError:
                return False

            if subset is None:
                return False

            try:
                # Already-vectorized memlets are not vectorized again
                if veclen > symbolic.pystr_to_symbolic('1'):
                    return False

                for idx, expr in enumerate(subset):
                    # A dimension is either one expression or a tuple of
                    # expressions; treat both uniformly.
                    parts = expr if isinstance(expr, tuple) else (expr, )
                    for part in parts:
                        symbols = symbolic.pystr_to_symbolic(
                            part).free_symbols
                        if param not in symbols:
                            continue
                        if idx != subset.dims() - 1:
                            return False
                        param_in_last_dim = True
            except TypeError:  # cannot determine truth value of Relational
                return False

        return param_in_last_dim

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns the candidate entries of the map entry, tasklet and map
            exit pattern nodes, converted to strings and joined by ' -> '.
        """
        pattern_nodes = (Vectorization._map_entry, Vectorization._tasklet,
                         Vectorization._map_exit)
        matched = [candidate[pattern_node] for pattern_node in pattern_nodes]
        return ' -> '.join(map(str, matched))

    def apply(self, sdfg):
        """ Vectorizes the matched map.

            Sets the step of the last map range dimension to
            `self.vector_len`, then sets `veclen` to the same value on every
            memlet adjacent to the tasklet whose last subset dimension
            references the last map parameter.

            @param sdfg: The SDFG that contains the matched state.
        """
        graph = sdfg.nodes()[self.state_id]
        map_entry = graph.nodes()[self.subgraph[Vectorization._map_entry]]
        tasklet = graph.nodes()[self.subgraph[Vectorization._tasklet]]
        param = symbolic.pystr_to_symbolic(map_entry.map.params[-1])

        # Create new vector size.
        vector_size = self.vector_len

        # Change the step of the inner-most dimension.
        dim_from, dim_to, _ = map_entry.map.range[-1]
        map_entry.map.range[-1] = (dim_from, dim_to, vector_size)

        # Vectorize memlets adjacent to the tasklet.
        for _, _, _, _, memlet in graph.all_edges(tasklet):
            # can_be_applied skips stream and conflict-resolution memlets
            # without looking at their subset, so a memlet without a subset
            # may reach this point. It cannot reference the map parameter.
            if memlet.subset is None:
                continue

            # The last dimension is either one expression or a tuple of
            # expressions; gather the symbols of all of them.
            lastindex = memlet.subset[-1]
            parts = lastindex if isinstance(lastindex, tuple) else (
                lastindex, )
            symbols = set()
            for part in parts:
                symbols.update(symbolic.pystr_to_symbolic(part).free_symbols)

            if param in symbols:
                try:
                    memlet.veclen = vector_size
                except AttributeError:
                    # NOTE(review): stops here and leaves the remaining
                    # memlets untouched although the map step has already
                    # been changed - confirm that this is intended.
                    return

        # TODO: Create new map for non-vectorizable part.

        return

    def modifies_graph(self):
        """ Reports that this transformation modifies the graph. """
        return True
Beispiel #24
0
class MapReduceFusion(pm.Transformation):
    """ Implements the map-reduce-fusion transformation.
        Fuses a map with an immediately following reduction, where the array
        between the map and the reduction is not used anywhere else.
    """

    _tasklet = nodes.Tasklet('_')
    _tmap_exit = nodes.MapExit(nodes.Map("", [], []))
    _in_array = nodes.AccessNode('_')
    _rmap_in_entry = nodes.MapEntry(nodes.Map("", [], []))
    _rmap_in_tasklet = nodes.Tasklet('_')
    _rmap_in_cr = nodes.MapExit(nodes.Map("", [], []))
    _rmap_out_entry = nodes.MapEntry(nodes.Map("", [], []))
    _rmap_out_exit = nodes.MapExit(nodes.Map("", [], []))
    _out_array = nodes.AccessNode('_')
    _reduce = nodes.Reduce('lambda: None', None)

    @staticmethod
    def expressions():
        """ Returns the three patterns matched by this transformation, in
            this order (the position is the expression index):
              0. map followed by a reduction map,
              1. map followed by a reduction map nested in an outer map,
              2. map followed by a reduce node.
        """
        cls = MapReduceFusion

        # All patterns start at the tasklet and end at the output array
        head = (cls._tasklet, cls._tmap_exit, cls._in_array)
        inner_reduction = (cls._rmap_in_entry, cls._rmap_in_tasklet,
                           cls._rmap_in_cr)

        full_reduce = nxutil.node_path_graph(*head, *inner_reduction,
                                             cls._out_array)
        partial_reduce = nxutil.node_path_graph(*head, cls._rmap_out_entry,
                                                *inner_reduction,
                                                cls._rmap_out_exit,
                                                cls._out_array)
        reduce_node = nxutil.node_path_graph(*head, cls._reduce,
                                             cls._out_array)
        return [full_reduce, partial_reduce, reduce_node]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Checks whether the matched map and reduction can be fused.

            The match is accepted only if all of the following hold:
              * every edge into the intermediate array comes from the
                tasklet map exit;
              * every edge out of the intermediate array leads to the
                reduction (outermost reduction map entry or reduce node);
              * for the map-based patterns (expression index 0 and 1), the
                first edge into the matched conflict-resolution map exit
                has a write-conflict resolution function;
              * the first memlet into and the first memlet out of the
                intermediate array have equal subsets.
        """
        def matched(pattern_node):
            return graph.nodes()[candidate[pattern_node]]

        tmap_exit = matched(MapReduceFusion._tmap_exit)
        in_array = matched(MapReduceFusion._in_array)

        # Node that consumes the intermediate array
        if expr_index == 0:  # Reduce without outer map
            consumer = matched(MapReduceFusion._rmap_in_entry)
        elif expr_index == 1:  # Reduce with outer map
            consumer = matched(MapReduceFusion._rmap_out_entry)
        else:  # Reduce node
            consumer = matched(MapReduceFusion._reduce)

        # Make sure that the array is only accessed by the map and the reduce
        for src, _, _, _, _ in graph.in_edges(in_array):
            if src != tmap_exit:
                return False
        for _, _, dest, _, _ in graph.out_edges(in_array):
            if dest != consumer:
                return False

        # Make sure that there is a reduction in the second map
        if expr_index < 2:
            rmap_cr = matched(MapReduceFusion._rmap_in_cr)
            if graph.in_edges(rmap_cr)[0].data.wcr is None:
                return False

        # NOTE: The following checks were disabled in the original code and
        # are not performed here:
        #   * the transient is not accessed by other states;
        #   * input and output of the reduction are full-range arrays;
        #   * only one value is reduced at a time through the map exit;
        #   * the map ranges match the reduction ranges.
        # TODO(later): Support fusion of partial reductions and refactor
        #              slice/subarray handling

        # Verify that reduction ranges match tasklet map
        written_subset = graph.in_edges(in_array)[0].data.subset
        read_subset = graph.out_edges(in_array)[0].data.subset
        if written_subset != read_subset:
            return False

        return True

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns the candidate entries of the tasklet, the tasklet map
            exit and the reduction, converted to strings and joined by
            ' -> '. A candidate with five entries is treated as the
            reduce-node pattern; otherwise the conflict-resolution map exit
            stands for the reduction.
        """
        if len(candidate) == 5:  # Expression 2
            reduction = MapReduceFusion._reduce
        else:
            reduction = MapReduceFusion._rmap_in_cr

        pattern_nodes = (MapReduceFusion._tasklet, MapReduceFusion._tmap_exit,
                         reduction)
        matched = [candidate[pattern_node] for pattern_node in pattern_nodes]
        return ' -> '.join(map(str, matched))

    @staticmethod
    def find_memlet_map_permutation(memlet: Memlet, map: nodes.Map):
        """ Matches each dimension of a memlet subset to a map parameter.

            @param memlet: Memlet whose subset dimensions are matched.
            @param map: Map whose parameters are matched against.
            @return: A list with one entry per subset dimension: the index
                     of the first not-yet-used map parameter that equals
                     the dimension, or None if there is no such parameter.
        """
        # Convert the parameters once, not once per subset dimension
        param_symbols = [symbolic.pystr_to_symbolic(p) for p in map.params]

        perm = [None] * len(memlet.subset)
        used = set()
        for i, dim in enumerate(memlet.subset):
            for j, symbol in enumerate(param_symbols):
                if j not in used and symbol == dim:
                    perm[i] = j
                    used.add(j)
                    break
        return perm

    @staticmethod
    def find_permutation(tasklet_map: nodes.Map, red_outer_map: nodes.Map,
                         red_inner_map: nodes.Map, tmem: Memlet):
        """ Assigns each range of the tasklet map to the outer or the inner
            reduction map.

            Each tasklet map range is matched by equality against the
            not-yet-used ranges of the outer reduction map first and of the
            inner reduction map second. The search stops at the first range
            that matches neither.

            @param tmem: Not used by this method.
            @return: A tuple ([indices of tasklet map ranges matched to the
                     outer map], [indices matched to the inner (CR) map]).
        """
        assert len(tasklet_map.range) == len(red_inner_map.range) + len(
            red_outer_map.range)

        def claim(rng, candidates, taken):
            """ Marks the first free candidate equal to rng as taken. """
            for j, candidate_rng in enumerate(candidates):
                if rng == candidate_rng and j not in taken:
                    taken.add(j)
                    return True
            return False

        outer_indices, inner_indices = [], []
        taken_outer, taken_inner = set(), set()

        # Match map ranges with reduce ranges
        for i, tmap_rng in enumerate(tasklet_map.range):
            if claim(tmap_rng, red_outer_map.range, taken_outer):
                outer_indices.append(i)
            elif claim(tmap_rng, red_inner_map.range, taken_inner):
                inner_indices.append(i)
            else:
                break

        # Ensure all map variables matched with reduce variables
        assert len(outer_indices) + len(inner_indices) == len(
            tasklet_map.range)

        return outer_indices, inner_indices

    @staticmethod
    def find_permutation_reduce(tasklet_map: nodes.Map,
                                reduce_node: nodes.Reduce, graph: SDFGState,
                                tmem: Memlet):
        """ Finds the tasklet map indices that index the reduction output.

            @param tasklet_map: Map that surrounds the tasklet.
            @param reduce_node: Reduce node that follows the map.
            @param graph: State that contains both nodes.
            @param tmem: Memlet between the tasklet and its map exit.
            @return: A list with one tasklet map index for every position
                     that is not listed in `reduce_node.axes`. The list is
                     empty if `reduce_node.axes` is None.
        """
        in_memlet = graph.in_edges(reduce_node)[0].data
        assert len(tasklet_map.range) == in_memlet.subset.dims()

        # Find permutation between tasklet-exit memlet and tasklet map
        tmem_perm = MapReduceFusion.find_memlet_map_permutation(
            tmem, tasklet_map)

        # Match map ranges with reduce ranges; stop at the first map range
        # without an unused, equal range in the reduce input subset
        mapred_perm = []
        taken = set()
        for i, tmap_rng in enumerate(tasklet_map.range):
            for j, in_rng in enumerate(in_memlet.subset):
                if tmap_rng == in_rng and j not in taken:
                    mapred_perm.append(i)
                    taken.add(j)
                    break
            else:
                break

        # Ensure all map variables matched with reduce variables
        assert len(tmem_perm) == len(tmem.subset)
        assert len(mapred_perm) == len(in_memlet.subset)

        # Prepare result from the two permutations and the reduction axes
        # NOTE(review): an unmatched dimension leaves None in tmem_perm,
        # which would fail as a list index below - confirm that callers
        # only pass memlets whose dimensions are all map parameters.
        axes = reduce_node.axes
        result = []
        for i in range(len(mapred_perm)):
            if axes is None or i in axes:
                continue
            result.append(mapred_perm[tmem_perm[i]])

        return result

    def apply(self, sdfg):
        """ Fuses the matched map with the reduction that follows it.

            Reduce-node pattern (expression 2): the intermediate array and
            the reduce node are removed. The tasklet then writes to the
            output array through memlets that carry the write-conflict
            resolution function and identity of the reduce node.

            Map-based patterns (expressions 0 and 1): the intermediate
            array node is removed and a transient named 'tmp' is added
            between the map tasklet and the reduction tasklet. The subsets
            of the conflict-resolution memlets are rewritten without the
            reduction axes, and the matched reduction map nodes are removed.
            For expression 1, the tasklet map is additionally split into a
            new outer ("_nonreduce") map and the remaining inner map.

            @param sdfg: The SDFG that contains the matched state.
            @raise RuntimeError: If no edge into the tasklet map exit
                                 carries the intermediate array, or if the
                                 reduction tasklet does not have exactly
                                 one incoming edge.
        """
        def gnode(nname):
            return graph.nodes()[self.subgraph[nname]]

        expr_index = self.expr_index
        graph = sdfg.nodes()[self.state_id]
        tmap_exit = gnode(MapReduceFusion._tmap_exit)
        in_array = gnode(MapReduceFusion._in_array)
        out_array = gnode(MapReduceFusion._out_array)

        # Collect the matched nodes and set the nodes to remove according
        # to the expression index
        nodes_to_remove = [in_array]
        if expr_index == 2:  # Reduce node
            rmap_cr = gnode(MapReduceFusion._reduce)
            nodes_to_remove.append(rmap_cr)
        else:
            rmap_cr = gnode(MapReduceFusion._rmap_in_cr)
            rmap_in_entry = gnode(MapReduceFusion._rmap_in_entry)
            # The reduction tasklet belongs to both map-based patterns and
            # is used below for either of them. It used to be bound for
            # expression 1 only, which raised UnboundLocalError when the
            # transformation was applied with expression 0.
            rmap_tasklet = gnode(MapReduceFusion._rmap_in_tasklet)
            if expr_index == 1:  # Reduce with outer map
                rmap_out_entry = gnode(MapReduceFusion._rmap_out_entry)
                rmap_out_exit = gnode(MapReduceFusion._rmap_out_exit)
                nodes_to_remove.append(rmap_out_entry)
                nodes_to_remove.append(rmap_in_entry)
                nodes_to_remove.append(rmap_out_exit)
            else:  # Reduce without outer map
                nodes_to_remove.append(rmap_in_entry)

        # If no other edges lead to mapexit, remove it. Otherwise, keep
        # it and remove reduction incoming/outgoing edges
        if expr_index != 2 and len(graph.in_edges(tmap_exit)) == 1:
            nodes_to_remove.append(tmap_exit)

        # Find the edge into the map exit that carries the intermediate array
        memlet_edge = None
        for edge in graph.in_edges(tmap_exit):
            if edge.data.data == in_array.data:
                memlet_edge = edge
                break
        if memlet_edge is None:
            raise RuntimeError('Reduction memlet cannot be None')

        # For expression 0 the index order does not matter: merge as-is
        if expr_index == 1:  # Reduce with outer map
            tmap = tmap_exit.map
            perm_outer, perm_inner = MapReduceFusion.find_permutation(
                tmap, rmap_out_entry.map, rmap_in_entry.map, memlet_edge.data)

            # Split tasklet map into tmap_out -> tmap_in (according to
            # reduction)
            omap = nodes.Map(
                tmap.label + '_nonreduce',
                [p for i, p in enumerate(tmap.params) if i in perm_outer],
                [r for i, r in enumerate(tmap.range) if i in perm_outer],
                tmap.schedule, tmap.unroll, tmap.is_async)
            tmap.params = [
                p for i, p in enumerate(tmap.params) if i in perm_inner
            ]
            tmap.range = [
                r for i, r in enumerate(tmap.range) if i in perm_inner
            ]
            omap_entry = nodes.MapEntry(omap)
            # The exit of the outer reduction map is reused as the exit of
            # the new outer map
            omap_exit = rmap_out_exit
            rmap_out_exit.map = omap

            # Reconnect graph to new map
            tmap_entry = graph.entry_node(tmap_exit)
            tmap_in_edges = list(graph.in_edges(tmap_entry))
            # NOTE(review): the call does not depend on the loop variable;
            # one call is presumably enough - confirm against
            # nxutil.change_edge_dest before simplifying.
            for e in tmap_in_edges:
                nxutil.change_edge_dest(graph, tmap_entry, omap_entry)
            for e in tmap_in_edges:
                graph.add_edge(omap_entry, e.src_conn, tmap_entry, e.dst_conn,
                               copy.copy(e.data))
        elif expr_index == 2:  # Reduce node
            # Find correspondence between map indices and array outputs
            tmap = tmap_exit.map
            perm = MapReduceFusion.find_permutation_reduce(
                tmap, rmap_cr, graph, memlet_edge.data)

            output_subset = [tmap.params[d] for d in perm]
            if len(output_subset) == 0:  # Output is a scalar
                output_subset = [0]

            # Must be read before the reduce node is removed
            array_edge = graph.out_edges(rmap_cr)[0]

            # Delete relevant edges and nodes
            graph.remove_edge(memlet_edge)
            graph.remove_nodes_from(nodes_to_remove)

            # Add new edges and nodes
            #   From tasklet to map exit
            graph.add_edge(
                memlet_edge.src, memlet_edge.src_conn, memlet_edge.dst,
                memlet_edge.dst_conn,
                Memlet(out_array.data, memlet_edge.data.num_accesses,
                       subsets.Indices(output_subset), memlet_edge.data.veclen,
                       rmap_cr.wcr, rmap_cr.identity))

            #   From map exit to output array
            graph.add_edge(
                memlet_edge.dst, 'OUT_' + memlet_edge.dst_conn[3:],
                array_edge.dst, array_edge.dst_conn,
                Memlet(array_edge.data.data, array_edge.data.num_accesses,
                       array_edge.data.subset, array_edge.data.veclen,
                       rmap_cr.wcr, rmap_cr.identity))

            return

        # Remove tmp array node prior to the others, so that a new one
        # can be created in its stead (see below)
        graph.remove_node(nodes_to_remove[0])
        nodes_to_remove = nodes_to_remove[1:]

        # Create tasklet -> tmp -> tasklet connection
        tmp = graph.add_array(
            'tmp',
            memlet_edge.data.subset.bounding_box_size(),
            sdfg.arrays[memlet_edge.data.data].dtype,
            transient=True)
        tasklet_tmp_memlet = copy.deepcopy(memlet_edge.data)
        tasklet_tmp_memlet.data = tmp.data
        tasklet_tmp_memlet.subset = ShapeProperty.to_string(tmp.shape)

        # Modify memlet to point to output array
        memlet_edge.data.data = out_array.data

        # Recover reduction axes from CR reduce subset: an axis counts as
        # reduced if its index expression contains '__i'
        reduce_cr_subset = graph.in_edges(rmap_tasklet)[0].data.subset
        reduce_axes = [
            ind for ind, crvar in enumerate(reduce_cr_subset.indices)
            if '__i' in str(crvar)
        ]

        # Modify memlet access index by filtering out reduction axes
        newindices = [
            ovar for ind, ovar in enumerate(memlet_edge.data.subset.indices)
            if ind not in reduce_axes
        ]
        if len(newindices) == 0:
            newindices = [0]

        memlet_edge.data.subset = subsets.Indices(newindices)

        graph.remove_edge(memlet_edge)

        graph.add_edge(memlet_edge.src, memlet_edge.src_conn, tmp,
                       memlet_edge.dst_conn, tasklet_tmp_memlet)

        # Collected before the tmp -> tasklet edge is added below
        red_edges = list(graph.in_edges(rmap_tasklet))
        if len(red_edges) != 1:
            raise RuntimeError('CR edge must be unique')

        tmp_tasklet_memlet = copy.deepcopy(tasklet_tmp_memlet)
        graph.add_edge(tmp, None, rmap_tasklet, red_edges[0].dst_conn,
                       tmp_tasklet_memlet)

        for e in graph.edges_between(rmap_tasklet, rmap_cr):
            e.data.subset = memlet_edge.data.subset

        # Move output edges to point directly to CR node
        if expr_index == 1:
            # Set output memlet between CR node and outer reduction map to
            # contain the same subset as the one pointing to the CR node
            for e in graph.out_edges(rmap_cr):
                e.data.subset = memlet_edge.data.subset

            rmap_out = gnode(MapReduceFusion._rmap_out_exit)
            nxutil.change_edge_src(graph, rmap_out, omap_exit)

        # Remove nodes
        graph.remove_nodes_from(nodes_to_remove)

        # For unrelated outputs, connect original output to rmap_out
        if expr_index == 1 and tmap_exit not in nodes_to_remove:
            other_out_edges = list(graph.out_edges(tmap_exit))
            for e in other_out_edges:
                graph.remove_edge(e)
                graph.add_edge(e.src, e.src_conn, omap_exit, None, e.data)
                graph.add_edge(omap_exit, None, e.dst, e.dst_conn,
                               copy.copy(e.data))

    def modifies_graph(self):
        """ Reports that this transformation modifies the graph. """
        return True
class GPUTransformLocalStorage(pattern_matching.Transformation):
    """Implements the GPUTransformLocalStorage transformation.

        Similar to GPUTransformMap, but takes multiple maps leading from the
        same data node into account, creating a local storage for each range.

        @see: GPUTransformMap
    """

    _arrays_removed = 0
    _maps_transformed = 0

    fullcopy = Property(desc="Copy whole arrays rather than used subset",
                        dtype=bool,
                        default=False)

    nested_seq = Property(
        desc="Makes nested code semantically-equivalent to single-core code,"
        "transforming nested maps and memory into sequential and "
        "local memory respectively.",
        dtype=bool,
        default=True,
    )

    _map_entry = nodes.MapEntry(nodes.Map("", [], []))
    _reduce = nodes.Reduce("lambda: None", None)

    @staticmethod
    def expressions():
        """ Returns the two single-node patterns matched by this
            transformation: a map entry (expression index 0) and a reduce
            node (expression index 1).
        """
        pattern_nodes = (GPUTransformLocalStorage._map_entry,
                         GPUTransformLocalStorage._reduce)
        return [nxutil.node_path_graph(node) for node in pattern_nodes]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Checks whether the matched map or reduce node can be
            transformed.

            A map (expression index 0) is rejected if
              * strict mode is on and the map is nested in another scope;
              * its schedule is MPI, GPU_Device, GPU_ThreadBlock or
                Sequential;
              * it has dynamic map inputs;
              * the map itself or one of its enclosing scopes has a
                GPU_Device or GPU_ThreadBlock schedule;
              * its scope contains an access node whose storage is neither
                Default nor Register;
              * the memlet path of one of its outputs ends in a stream.

            A reduce node (expression index 1) is rejected if its schedule
            is MPI, GPU_Device or GPU_ThreadBlock, or if one of its
            enclosing scopes has a GPU_Device or GPU_ThreadBlock schedule.

            @return: True or False for expression index 0 and 1; None for
                     any other expression index.
        """
        gpu_schedules = (dtypes.ScheduleType.GPU_Device,
                         dtypes.ScheduleType.GPU_ThreadBlock)

        def in_gpu_scope(scope, sdict):
            """ Walks up the scopes starting at `scope` and reports whether
                any of them has a GPU schedule. """
            while scope is not None:
                if scope.map.schedule in gpu_schedules:
                    return True
                scope = sdict[scope]
            return False

        if expr_index == 0:
            map_entry = graph.nodes()[candidate[
                GPUTransformLocalStorage._map_entry]]

            # Disallow GPUTransform on nested maps in strict mode
            if strict and graph.scope_dict()[map_entry] is not None:
                return False

            # Map schedules that are disallowed to transform to GPUs
            disallowed = ((dtypes.ScheduleType.MPI, ) + gpu_schedules +
                          (dtypes.ScheduleType.Sequential, ))
            if map_entry.map.schedule in disallowed:
                return False

            # Dynamic map ranges cannot become kernels
            if sd.has_dynamic_map_inputs(graph, map_entry):
                return False

            # Recursively check parent for GPU schedules
            if in_gpu_scope(map_entry, graph.scope_dict()):
                return False

            # Ensure that map does not include internal arrays that are
            # allocated on non-default space
            allowed_storage = (dtypes.StorageType.Default,
                               dtypes.StorageType.Register)
            for node in graph.scope_subgraph(map_entry).nodes():
                if not isinstance(node, nodes.AccessNode):
                    continue
                if node.desc(sdfg).storage not in allowed_storage:
                    return False

            # If one of the outputs is a stream, do not match
            map_exit = graph.exit_nodes(map_entry)[0]
            for edge in graph.out_edges(map_exit):
                dst = graph.memlet_path(edge)[-1].dst
                if not isinstance(dst, nodes.AccessNode):
                    continue
                if isinstance(sdfg.arrays[dst.data], data.Stream):
                    return False

            return True

        if expr_index == 1:
            reduce_node = graph.nodes()[candidate[
                GPUTransformLocalStorage._reduce]]

            # Schedules that are disallowed to transform to GPUs
            disallowed = (dtypes.ScheduleType.MPI, ) + gpu_schedules
            if reduce_node.schedule in disallowed:
                return False

            # Recursively check parent for GPU schedules
            sdict = graph.scope_dict()
            if in_gpu_scope(sdict[reduce_node], sdict):
                return False

            return True

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns the string form of the matched graph node: the reduce
            node if the candidate contains one, the map entry otherwise.
        """
        if GPUTransformLocalStorage._reduce in candidate:
            pattern_node = GPUTransformLocalStorage._reduce
        else:
            pattern_node = GPUTransformLocalStorage._map_entry
        return str(graph.nodes()[candidate[pattern_node]])

    def apply(self, sdfg):
        graph = sdfg.nodes()[self.state_id]
        if self.expr_index == 0:
            cnode = graph.nodes()[self.subgraph[
                GPUTransformLocalStorage._map_entry]]
            node_schedprop = cnode.map
            exit_nodes = graph.exit_nodes(cnode)
        else:
            cnode = graph.nodes()[self.subgraph[
                GPUTransformLocalStorage._reduce]]
            node_schedprop = cnode
            exit_nodes = [cnode]

        # Change schedule
        node_schedprop._schedule = dtypes.ScheduleType.GPU_Device
        if Config.get_bool("debugprint"):
            GPUTransformLocalStorage._maps_transformed += 1
        # If nested graph is designated as sequential, transform schedules and
        # storage from Default to Sequential/Register
        if self.nested_seq and self.expr_index == 0:
            for node in graph.scope_subgraph(cnode).nodes():
                if isinstance(node, nodes.AccessNode):
                    arr = node.desc(sdfg)
                    if arr.storage == dtypes.StorageType.Default:
                        arr.storage = dtypes.StorageType.Register
                elif isinstance(node, nodes.MapEntry):
                    if node.map.schedule == dtypes.ScheduleType.Default:
                        node.map.schedule = dtypes.ScheduleType.Sequential

        gpu_storage_types = [
            dtypes.StorageType.GPU_Global,
            dtypes.StorageType.GPU_Shared,
            dtypes.StorageType.GPU_Stack,
        ]

        #######################################################
        # Add GPU copies of CPU arrays (i.e., not already on GPU)

        # First, understand which arrays to clone
        all_out_edges = []
        for enode in exit_nodes:
            all_out_edges.extend(list(graph.out_edges(enode)))
        in_arrays_to_clone = set()
        out_arrays_to_clone = set()
        for e in graph.in_edges(cnode):
            data_node = sd.find_input_arraynode(graph, e)
            if data_node.desc(sdfg).storage not in gpu_storage_types:
                in_arrays_to_clone.add((data_node, e.data))
        for e in all_out_edges:
            data_node = sd.find_output_arraynode(graph, e)
            if data_node.desc(sdfg).storage not in gpu_storage_types:
                out_arrays_to_clone.add((data_node, e.data))

        if Config.get_bool("debugprint"):
            GPUTransformLocalStorage._arrays_removed += len(
                in_arrays_to_clone) + len(out_arrays_to_clone)

        # Second, create a GPU clone of each array
        # TODO: Overapproximate union of memlets
        cloned_arrays = {}
        in_cloned_arraynodes = {}
        out_cloned_arraynodes = {}
        for array_node, memlet in in_arrays_to_clone:
            array = array_node.desc(sdfg)
            cloned_name = "gpu_" + array_node.data
            for i, r in enumerate(memlet.bounding_box_size()):
                size = symbolic.overapproximate(r)
                try:
                    if int(size) == 1:
                        suffix = []
                        for c in str(memlet.subset[i][0]):
                            if c.isalpha() or c.isdigit() or c == "_":
                                suffix.append(c)
                            elif c == "+":
                                suffix.append("p")
                            elif c == "-":
                                suffix.append("m")
                            elif c == "*":
                                suffix.append("t")
                            elif c == "/":
                                suffix.append("d")
                        cloned_name += "_" + "".join(suffix)
                except:
                    continue
            if cloned_name in sdfg.arrays.keys():
                cloned_array = sdfg.arrays[cloned_name]
            elif array_node.data in cloned_arrays:
                cloned_array = cloned_arrays[array_node.data]
            else:
                full_shape = []
                for r in memlet.bounding_box_size():
                    size = symbolic.overapproximate(r)
                    try:
                        full_shape.append(int(size))
                    except:
                        full_shape.append(size)
                actual_dims = [
                    idx for idx, r in enumerate(full_shape)
                    if not (isinstance(r, int) and r == 1)
                ]
                if len(actual_dims) == 0:  # abort
                    actual_dims = [len(full_shape) - 1]
                if isinstance(array, data.Scalar):
                    sdfg.add_array(name=cloned_name,
                                   shape=[1],
                                   dtype=array.dtype,
                                   transient=True,
                                   storage=dtypes.StorageType.GPU_Global)
                elif isinstance(array, data.Stream):
                    sdfg.add_stream(
                        name=cloned_name,
                        dtype=array.dtype,
                        shape=[full_shape[d] for d in actual_dims],
                        veclen=array.veclen,
                        buffer_size=array.buffer_size,
                        storage=dtypes.StorageType.GPU_Global,
                        transient=True,
                        offset=[array.offset[d] for d in actual_dims])
                else:
                    sdfg.add_array(
                        name=cloned_name,
                        shape=[full_shape[d] for d in actual_dims],
                        dtype=array.dtype,
                        materialize_func=array.materialize_func,
                        transient=True,
                        storage=dtypes.StorageType.GPU_Global,
                        allow_conflicts=array.allow_conflicts,
                        strides=[array.strides[d] for d in actual_dims],
                        offset=[array.offset[d] for d in actual_dims],
                    )
                cloned_arrays[array_node.data] = cloned_name
            cloned_node = type(array_node)(cloned_name)

            in_cloned_arraynodes[array_node.data] = cloned_node
        for array_node, memlet in out_arrays_to_clone:
            array = array_node.desc(sdfg)
            cloned_name = "gpu_" + array_node.data
            for i, r in enumerate(memlet.bounding_box_size()):
                size = symbolic.overapproximate(r)
                try:
                    if int(size) == 1:
                        suffix = []
                        for c in str(memlet.subset[i][0]):
                            if c.isalpha() or c.isdigit() or c == "_":
                                suffix.append(c)
                            elif c == "+":
                                suffix.append("p")
                            elif c == "-":
                                suffix.append("m")
                            elif c == "*":
                                suffix.append("t")
                            elif c == "/":
                                suffix.append("d")
                        cloned_name += "_" + "".join(suffix)
                except:
                    continue
            if cloned_name in sdfg.arrays.keys():
                cloned_array = sdfg.arrays[cloned_name]
            elif array_node.data in cloned_arrays:
                cloned_array = cloned_arrays[array_node.data]
            else:
                full_shape = []
                for r in memlet.bounding_box_size():
                    size = symbolic.overapproximate(r)
                    try:
                        full_shape.append(int(size))
                    except:
                        full_shape.append(size)
                actual_dims = [
                    idx for idx, r in enumerate(full_shape)
                    if not (isinstance(r, int) and r == 1)
                ]
                if len(actual_dims) == 0:  # abort
                    actual_dims = [len(full_shape) - 1]
                if isinstance(array, data.Scalar):
                    sdfg.add_array(name=cloned_name,
                                   shape=[1],
                                   dtype=array.dtype,
                                   transient=True,
                                   storage=dtypes.StorageType.GPU_Global)
                elif isinstance(array, data.Stream):
                    sdfg.add_stream(
                        name=cloned_name,
                        dtype=array.dtype,
                        shape=[full_shape[d] for d in actual_dims],
                        veclen=array.veclen,
                        buffer_size=array.buffer_size,
                        storage=dtypes.StorageType.GPU_Global,
                        transient=True,
                        offset=[array.offset[d] for d in actual_dims])
                else:
                    sdfg.add_array(
                        name=cloned_name,
                        shape=[full_shape[d] for d in actual_dims],
                        dtype=array.dtype,
                        materialize_func=array.materialize_func,
                        transient=True,
                        storage=dtypes.StorageType.GPU_Global,
                        allow_conflicts=array.allow_conflicts,
                        strides=[array.strides[d] for d in actual_dims],
                        offset=[array.offset[d] for d in actual_dims],
                    )
                cloned_arrays[array_node.data] = cloned_name
            cloned_node = type(array_node)(cloned_name)
            cloned_node.setzero = True

            out_cloned_arraynodes[array_node.data] = cloned_node

        # Third, connect the cloned arrays to the originals
        for array_name, node in in_cloned_arraynodes.items():
            graph.add_node(node)
            is_scalar = isinstance(sdfg.arrays[array_name], data.Scalar)
            for edge in graph.in_edges(cnode):
                if edge.data.data == array_name:
                    newmemlet = copy.deepcopy(edge.data)
                    newmemlet.data = node.data

                    if is_scalar:
                        newmemlet.subset = sbs.Indices([0])
                    else:
                        offset = []
                        lost_dims = []
                        lost_ranges = []
                        newsubset = [None] * len(edge.data.subset)
                        for ind, r in enumerate(edge.data.subset):
                            offset.append(r[0])
                            if isinstance(edge.data.subset[ind], tuple):
                                begin = edge.data.subset[ind][0] - r[0]
                                end = edge.data.subset[ind][1] - r[0]
                                step = edge.data.subset[ind][2]
                                if begin == end:
                                    lost_dims.append(ind)
                                    lost_ranges.append((begin, end, step))
                                else:
                                    newsubset[ind] = (begin, end, step)
                            else:
                                newsubset[ind] -= r[0]
                        if len(lost_dims) == len(edge.data.subset):
                            lost_dims.pop()
                            newmemlet.subset = type(
                                edge.data.subset)([lost_ranges[-1]])
                        else:
                            newmemlet.subset = type(edge.data.subset)(
                                [r for r in newsubset if r is not None])

                    graph.add_edge(node, None, edge.dst, edge.dst_conn,
                                   newmemlet)

                    for e in graph.bfs_edges(edge.dst, reverse=False):
                        parent, _, _child, _, memlet = e
                        if parent != edge.dst and not in_scope(
                                graph, parent, edge.dst):
                            break
                        if memlet.data != edge.data.data:
                            continue
                        path = graph.memlet_path(e)
                        if not isinstance(path[-1].dst, nodes.CodeNode):
                            if in_path(path, e, nodes.ExitNode, forward=True):
                                if isinstance(parent, nodes.CodeNode):
                                    # Output edge
                                    break
                                else:
                                    continue
                        if is_scalar:
                            memlet.subset = sbs.Indices([0])
                        else:
                            newsubset = [None] * len(memlet.subset)
                            for ind, r in enumerate(memlet.subset):
                                if ind in lost_dims:
                                    continue
                                if isinstance(memlet.subset[ind], tuple):
                                    begin = r[0] - offset[ind]
                                    end = r[1] - offset[ind]
                                    step = r[2]
                                    newsubset[ind] = (begin, end, step)
                                else:
                                    newsubset[ind] = (
                                        r - offset[ind],
                                        r - offset[ind],
                                        1,
                                    )
                            memlet.subset = type(edge.data.subset)(
                                [r for r in newsubset if r is not None])
                        memlet.data = node.data

                    if self.fullcopy:
                        edge.data.subset = sbs.Range.from_array(
                            node.desc(sdfg))
                    edge.data.other_subset = newmemlet.subset
                    graph.add_edge(edge.src, edge.src_conn, node, None,
                                   edge.data)
                    graph.remove_edge(edge)

        for array_name, node in out_cloned_arraynodes.items():
            graph.add_node(node)
            is_scalar = isinstance(sdfg.arrays[array_name], data.Scalar)
            for edge in all_out_edges:
                if edge.data.data == array_name:
                    newmemlet = copy.deepcopy(edge.data)
                    newmemlet.data = node.data

                    if is_scalar:
                        newmemlet.subset = sbs.Indices([0])
                    else:
                        offset = []
                        lost_dims = []
                        lost_ranges = []
                        newsubset = [None] * len(edge.data.subset)
                        for ind, r in enumerate(edge.data.subset):
                            offset.append(r[0])
                            if isinstance(edge.data.subset[ind], tuple):
                                begin = edge.data.subset[ind][0] - r[0]
                                end = edge.data.subset[ind][1] - r[0]
                                step = edge.data.subset[ind][2]
                                if begin == end:
                                    lost_dims.append(ind)
                                    lost_ranges.append((begin, end, step))
                                else:
                                    newsubset[ind] = (begin, end, step)
                            else:
                                newsubset[ind] -= r[0]
                        if len(lost_dims) == len(edge.data.subset):
                            lost_dims.pop()
                            newmemlet.subset = type(
                                edge.data.subset)([lost_ranges[-1]])
                        else:
                            newmemlet.subset = type(edge.data.subset)(
                                [r for r in newsubset if r is not None])

                    graph.add_edge(edge.src, edge.src_conn, node, None,
                                   newmemlet)

                    end_node = graph.scope_dict()[edge.src]
                    for e in graph.bfs_edges(edge.src, reverse=True):
                        parent, _, _child, _, memlet = e
                        if parent == end_node:
                            break
                        if memlet.data != edge.data.data:
                            continue
                        path = graph.memlet_path(e)
                        if not isinstance(path[0].dst, nodes.CodeNode):
                            if in_path(path, e, nodes.EntryNode,
                                       forward=False):
                                if isinstance(parent, nodes.CodeNode):
                                    # Output edge
                                    break
                                else:
                                    continue
                        if is_scalar:
                            memlet.subset = sbs.Indices([0])
                        else:
                            newsubset = [None] * len(memlet.subset)
                            for ind, r in enumerate(memlet.subset):
                                if ind in lost_dims:
                                    continue
                                if isinstance(memlet.subset[ind], tuple):
                                    begin = r[0] - offset[ind]
                                    end = r[1] - offset[ind]
                                    step = r[2]
                                    newsubset[ind] = (begin, end, step)
                                else:
                                    newsubset[ind] = (
                                        r - offset[ind],
                                        r - offset[ind],
                                        1,
                                    )
                            memlet.subset = type(edge.data.subset)(
                                [r for r in newsubset if r is not None])
                        memlet.data = node.data

                    edge.data.wcr = None
                    if self.fullcopy:
                        edge.data.subset = sbs.Range.from_array(
                            node.desc(sdfg))
                    edge.data.other_subset = newmemlet.subset
                    graph.add_edge(node, None, edge.dst, edge.dst_conn,
                                   edge.data)
                    graph.remove_edge(edge)

        # Fourth, replace memlet arrays as necessary
        if self.expr_index == 0:
            scope_subgraph = graph.scope_subgraph(cnode)
            for edge in scope_subgraph.edges():
                if edge.data.data is not None and edge.data.data in cloned_arrays:
                    edge.data.data = cloned_arrays[edge.data.data]

    def modifies_graph(self):
        """ Always returns True.

            NOTE(review): presumably tells the caller that applying this
            transformation changes the graph -- confirm against the code
            that calls this method.
        """
        return True
Beispiel #26
0
    def apply(self, sdfg):
        """ Applies the map-reduce fusion to the matched subgraph of the
            state ``sdfg.nodes()[self.state_id]``.

            The work performed depends on ``self.expr_index``:
              * 0 -- reduce without outer map;
              * 1 -- reduce with outer map (the tasklet map is split into an
                outer "non-reduce" map and the remaining inner map);
              * 2 -- reduce node (the tasklet output memlet is redirected to
                the output array with the reduce node's ``wcr``/``identity``
                and the method returns early).

            Raises a RuntimeError if no in-edge of the tasklet map exit
            carries the intermediate array, or (for expression indices 0 and
            1) if the reduction tasklet does not have exactly one in-edge.
        """

        def gnode(nname):
            # Map a pattern node to the matched node of the state graph.
            return graph.nodes()[self.subgraph[nname]]

        expr_index = self.expr_index
        graph = sdfg.nodes()[self.state_id]
        # NOTE(review): `tasklet` is not referenced again in this method.
        tasklet = gnode(MapReduceFusion._tasklet)
        tmap_exit = graph.nodes()[self.subgraph[MapReduceFusion._tmap_exit]]
        in_array = graph.nodes()[self.subgraph[MapReduceFusion._in_array]]
        if expr_index == 0:  # Reduce without outer map
            # NOTE(review): `rmap_entry` is not referenced again in this
            # method.
            rmap_entry = graph.nodes()[self.subgraph[
                MapReduceFusion._rmap_in_entry]]
        elif expr_index == 1:  # Reduce with outer map
            rmap_out_entry = graph.nodes()[self.subgraph[
                MapReduceFusion._rmap_out_entry]]
            rmap_out_exit = graph.nodes()[self.subgraph[
                MapReduceFusion._rmap_out_exit]]
            rmap_in_entry = graph.nodes()[self.subgraph[
                MapReduceFusion._rmap_in_entry]]
            rmap_tasklet = graph.nodes()[self.subgraph[
                MapReduceFusion._rmap_in_tasklet]]

        # The node that carries the reduction (wcr/identity) differs between
        # the reduce-node expression and the map-based expressions.
        if expr_index == 2:
            rmap_cr = graph.nodes()[self.subgraph[MapReduceFusion._reduce]]
        else:
            rmap_cr = graph.nodes()[self.subgraph[MapReduceFusion._rmap_in_cr]]
        out_array = gnode(MapReduceFusion._out_array)

        # Set nodes to remove according to the expression index
        # (the intermediate array is always first in the list; the code
        # below relies on that ordering)
        nodes_to_remove = [in_array]
        if expr_index == 0:
            nodes_to_remove.append(gnode(MapReduceFusion._rmap_in_entry))
        elif expr_index == 1:
            nodes_to_remove.append(gnode(MapReduceFusion._rmap_out_entry))
            nodes_to_remove.append(gnode(MapReduceFusion._rmap_in_entry))
            nodes_to_remove.append(gnode(MapReduceFusion._rmap_out_exit))
        else:
            nodes_to_remove.append(gnode(MapReduceFusion._reduce))

        # If no other edges lead to mapexit, remove it. Otherwise, keep
        # it and remove reduction incoming/outgoing edges
        if expr_index != 2 and len(graph.in_edges(tmap_exit)) == 1:
            nodes_to_remove.append(tmap_exit)

        # Find the edge into the tasklet map exit that writes the
        # intermediate array; it is the memlet that gets rewritten below.
        memlet_edge = None
        for edge in graph.in_edges(tmap_exit):
            if edge.data.data == in_array.data:
                memlet_edge = edge
                break
        if memlet_edge is None:
            raise RuntimeError('Reduction memlet cannot be None')

        if expr_index == 0:  # Reduce without outer map
            # Index order does not matter, merge as-is
            pass
        elif expr_index == 1:  # Reduce with outer map
            tmap = tmap_exit.map
            perm_outer, perm_inner = MapReduceFusion.find_permutation(
                tmap, rmap_out_entry.map, rmap_in_entry.map, memlet_edge.data)

            # Split tasklet map into tmap_out -> tmap_in (according to
            # reduction)
            omap = nodes.Map(
                tmap.label + '_nonreduce',
                [p for i, p in enumerate(tmap.params) if i in perm_outer],
                [r for i, r in enumerate(tmap.range) if i in perm_outer],
                tmap.schedule, tmap.unroll, tmap.is_async)
            tmap.params = [
                p for i, p in enumerate(tmap.params) if i in perm_inner
            ]
            tmap.range = [
                r for i, r in enumerate(tmap.range) if i in perm_inner
            ]
            omap_entry = nodes.MapEntry(omap)
            # The outer reduction map's exit node is reused as the exit of
            # the new outer map.
            omap_exit = rmap_out_exit
            rmap_out_exit.map = omap

            # Reconnect graph to new map
            tmap_entry = graph.entry_node(tmap_exit)
            tmap_in_edges = list(graph.in_edges(tmap_entry))
            # NOTE(review): the loop variable `e` is unused here, so the
            # same call is repeated once per original in-edge -- confirm
            # whether a single call was intended.
            for e in tmap_in_edges:
                nxutil.change_edge_dest(graph, tmap_entry, omap_entry)
            for e in tmap_in_edges:
                graph.add_edge(omap_entry, e.src_conn, tmap_entry, e.dst_conn,
                               copy.copy(e.data))
        elif expr_index == 2:  # Reduce node
            # Find correspondence between map indices and array outputs
            tmap = tmap_exit.map
            perm = MapReduceFusion.find_permutation_reduce(
                tmap, rmap_cr, graph, memlet_edge.data)

            output_subset = [tmap.params[d] for d in perm]
            if len(output_subset) == 0:  # Output is a scalar
                output_subset = [0]

            # Keep a handle on the reduce node's first out-edge before the
            # node is removed from the graph.
            array_edge = graph.out_edges(rmap_cr)[0]

            # Delete relevant edges and nodes
            graph.remove_edge(memlet_edge)
            graph.remove_nodes_from(nodes_to_remove)

            # Add new edges and nodes
            #   From tasklet to map exit
            graph.add_edge(
                memlet_edge.src, memlet_edge.src_conn, memlet_edge.dst,
                memlet_edge.dst_conn,
                Memlet(out_array.data, memlet_edge.data.num_accesses,
                       subsets.Indices(output_subset), memlet_edge.data.veclen,
                       rmap_cr.wcr, rmap_cr.identity))

            #   From map exit to output array
            #   (the out-connector name is built from the in-connector name
            #   with its first three characters replaced by 'OUT_')
            graph.add_edge(
                memlet_edge.dst, 'OUT_' + memlet_edge.dst_conn[3:],
                array_edge.dst, array_edge.dst_conn,
                Memlet(array_edge.data.data, array_edge.data.num_accesses,
                       array_edge.data.subset, array_edge.data.veclen,
                       rmap_cr.wcr, rmap_cr.identity))

            # The reduce-node case is complete; the code below only handles
            # expression indices 0 and 1.
            return

        # Remove tmp array node prior to the others, so that a new one
        # can be created in its stead (see below)
        graph.remove_node(nodes_to_remove[0])
        nodes_to_remove = nodes_to_remove[1:]

        # Create tasklet -> tmp -> tasklet connection
        tmp = graph.add_array(
            'tmp',
            memlet_edge.data.subset.bounding_box_size(),
            sdfg.arrays[memlet_edge.data.data].dtype,
            transient=True)
        tasklet_tmp_memlet = copy.deepcopy(memlet_edge.data)
        tasklet_tmp_memlet.data = tmp.data
        tasklet_tmp_memlet.subset = ShapeProperty.to_string(tmp.shape)

        # Modify memlet to point to output array
        memlet_edge.data.data = out_array.data

        # Recover reduction axes from CR reduce subset
        # NOTE(review): `rmap_tasklet` is only assigned in the
        # `expr_index == 1` branch at the top of this method, but it is read
        # here (and further below) for expr_index 0 as well; that case looks
        # like it would raise an UnboundLocalError -- confirm and fix.
        reduce_cr_subset = graph.in_edges(rmap_tasklet)[0].data.subset
        reduce_axes = []
        for ind, crvar in enumerate(reduce_cr_subset.indices):
            # An index whose string form contains '__i' is treated as a
            # reduction axis.
            if '__i' in str(crvar):
                reduce_axes.append(ind)

        # Modify memlet access index by filtering out reduction axes
        if True:  # expr_index == 0:
            newindices = []
            for ind, ovar in enumerate(memlet_edge.data.subset.indices):
                if ind not in reduce_axes:
                    newindices.append(ovar)
        # If every axis was reduced, fall back to the single index 0.
        if len(newindices) == 0:
            newindices = [0]

        memlet_edge.data.subset = subsets.Indices(newindices)

        # Replace the original tasklet output edge by an edge into `tmp`.
        graph.remove_edge(memlet_edge)

        graph.add_edge(memlet_edge.src, memlet_edge.src_conn, tmp,
                       memlet_edge.dst_conn, tasklet_tmp_memlet)

        red_edges = list(graph.in_edges(rmap_tasklet))
        if len(red_edges) != 1:
            raise RuntimeError('CR edge must be unique')

        # Feed `tmp` into the reduction tasklet through the connector of its
        # single existing in-edge.
        tmp_tasklet_memlet = copy.deepcopy(tasklet_tmp_memlet)
        graph.add_edge(tmp, None, rmap_tasklet, red_edges[0].dst_conn,
                       tmp_tasklet_memlet)

        # The edges from the reduction tasklet to the CR node receive the
        # filtered subset computed above.
        for e in graph.edges_between(rmap_tasklet, rmap_cr):
            e.data.subset = memlet_edge.data.subset

        # Move output edges to point directly to CR node
        if expr_index == 1:
            # Set output memlet between CR node and outer reduction map to
            # contain the same subset as the one pointing to the CR node
            for e in graph.out_edges(rmap_cr):
                e.data.subset = memlet_edge.data.subset

            rmap_out = gnode(MapReduceFusion._rmap_out_exit)
            nxutil.change_edge_src(graph, rmap_out, omap_exit)

        # Remove nodes
        graph.remove_nodes_from(nodes_to_remove)

        # For unrelated outputs, connect original output to rmap_out
        if expr_index == 1 and tmap_exit not in nodes_to_remove:
            other_out_edges = list(graph.out_edges(tmap_exit))
            for e in other_out_edges:
                graph.remove_edge(e)
                graph.add_edge(e.src, e.src_conn, omap_exit, None, e.data)
                graph.add_edge(omap_exit, None, e.dst, e.dst_conn,
                               copy.copy(e.data))
Beispiel #27
0
class OutLocalStorage(pattern_matching.Transformation):
    """ OutLocalStorage transformation: places a transient data node on the
        edge that connects an inner map exit to an outer map exit.
    """

    _inner_map_exit = nodes.MapExit(nodes.Map("", [], []))
    _outer_map_exit = nodes.MapExit(nodes.Map("", [], []))

    @staticmethod
    def annotates_memlets():
        """ Always returns True. """
        return True

    @staticmethod
    def expressions():
        """ Returns the single pattern graph of this transformation: the
            inner map exit followed by the outer map exit.
        """
        exit_chain = nxutil.node_path_graph(OutLocalStorage._inner_map_exit,
                                            OutLocalStorage._outer_map_exit)
        return [exit_chain]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Accepts every match of the pattern unconditionally. """
        return True

    @staticmethod
    def match_to_str(graph, candidate):
        """ Renders a match as the string forms of the two candidate
            entries joined by ' -> '.
        """
        matched = (candidate[OutLocalStorage._inner_map_exit],
                   candidate[OutLocalStorage._outer_map_exit])
        return ' -> '.join(map(str, matched))

    def apply(self, sdfg):
        """ Routes the edge between the two matched map exits through a new
            transient array.

            The transient is named ``<state label>_trans_<data name>``, has
            the overapproximated bounding-box size of the edge's memlet as
            its shape, and its access node has ``setzero`` enabled. Memlets
            of the same data found by a reverse BFS from the inner map exit
            are redirected to the transient and shifted by the start of the
            original subset in every dimension.
        """
        state = sdfg.nodes()[self.state_id]
        state_nodes = state.nodes()
        inner_exit = state_nodes[self.subgraph[
            OutLocalStorage._inner_map_exit]]
        outer_exit = state_nodes[self.subgraph[
            OutLocalStorage._outer_map_exit]]

        # Pick the first edge entering the outer exit that originates at
        # the inner exit.
        bridged_edge = None
        bridged_memlet = None
        array_name = None
        for in_edge in state.in_edges(outer_exit):
            if in_edge.src != inner_exit:
                continue
            bridged_edge = in_edge
            bridged_memlet = in_edge.data
            array_name = bridged_memlet.data
            break

        # Register the transient array and create its access node.
        transient_name = state.label + '_trans_' + bridged_memlet.data
        transient_shape = [
            symbolic.overapproximate(extent)
            for extent in bridged_memlet.bounding_box_size()
        ]
        sdfg.add_array(transient_name,
                       transient_shape,
                       sdfg.arrays[bridged_memlet.data].dtype,
                       transient=True)
        transient_node = nodes.AccessNode(transient_name)
        transient_node.setzero = True

        # One memlet leaves the transient unchanged; the other enters it
        # and is rebased so that every dimension starts at zero.
        outgoing_memlet = copy.deepcopy(bridged_memlet)
        incoming_memlet = copy.deepcopy(bridged_memlet)
        incoming_memlet.data = transient_node.data
        origins = []
        for dim, rng in enumerate(bridged_memlet.subset):
            origin = rng[0]
            origins.append(origin)
            source_range = bridged_memlet.subset[dim]
            if isinstance(source_range, tuple):
                incoming_memlet.subset[dim] = (source_range[0] - origin,
                                               source_range[1] - origin,
                                               source_range[2])
            else:
                incoming_memlet.subset[dim] -= origin

        # Swap the direct edge for inner exit -> transient -> outer exit.
        state.remove_edge(bridged_edge)
        state.add_edge(inner_exit, bridged_edge.src_conn, transient_node,
                       None, incoming_memlet)
        state.add_edge(transient_node, None, outer_exit,
                       bridged_edge.dst_conn, outgoing_memlet)

        # Walk backwards from the inner exit, stopping at the first edge
        # whose third element is a code node, and retarget matching memlets.
        for _, _, child, _, inner_memlet in state.bfs_edges(inner_exit,
                                                            reverse=True):
            if isinstance(child, nodes.CodeNode):
                break
            if inner_memlet.data != array_name:
                continue
            for dim, rng in enumerate(inner_memlet.subset):
                shift = origins[dim]
                if isinstance(inner_memlet.subset[dim], tuple):
                    inner_memlet.subset[dim] = (rng[0] - shift,
                                                rng[1] - shift, rng[2])
                else:
                    inner_memlet.subset[dim] -= shift
            # The name is rebuilt on every iteration on purpose: it reads
            # the current data name of the original memlet each time.
            inner_memlet.data = state.label + '_trans_' + bridged_memlet.data
Beispiel #28
0
def _build_dataflow_graph_recurse(sdfg, state, primitives, modules, superEntry,
                                  super_exit):
    """ Adds nodes and memlet edges for a list of AST primitives to a state,
        recursing into the children of each primitive.

        For every primitive, a matching graph node (or entry/exit node pair)
        is created, its input memlets are connected, its children are
        processed recursively with the new entry/exit nodes as the enclosing
        scope, and finally its output memlets are connected.

        :param sdfg: The SDFG, forwarded to edge creation and used to look
                     up data descriptors of access nodes.
        :param state: The state to which nodes and edges are added.
        :param primitives: List of AST primitives to process. If empty, a
                           single empty tasklet primitive is used instead.
        :param modules: Forwarded to ``ModuleInliner`` when cleaning the code
                        of Python tasklets.
        :param superEntry: Entry node of the enclosing scope, or a false
                           value if there is none.
        :param super_exit: Exit node of the enclosing scope, or None when
                           processing top-level primitives.
        :return: A list of (exit node, memlet) pairs for outputs that were
                 routed to ``super_exit``. Note that None is returned
                 instead as soon as a ``_ProgramNode`` primitive is
                 encountered.
        :raise ValueError: If a map or consume primitive has no children.
        :raise SyntaxError: If a non-Python tasklet has no external code.
        :raise TypeError: If a primitive (or, inside a scope, an output
                          producer) has an unhandled type.
        :raise RuntimeError: If an output is processed with no enclosing
                             exit node but the primitive's parent is neither
                             a program nor a control-flow node, or if a WCR
                             memlet was left unhandled on a non-scope
                             primitive.
    """
    # Array of pairs (exit node, memlet)
    exit_nodes = []

    if len(primitives) == 0:
        # Inject empty tasklets into empty states
        primitives = [astnodes._EmptyTaskletNode("Empty Tasklet", None)]

    for prim in primitives:
        label = prim.name

        # Expand node to get entry and exit points
        if isinstance(prim, astnodes._MapNode):
            if len(prim.children) == 0:
                raise ValueError("Map node expected to have children")
            mapNode = nd.Map(label,
                             prim.params,
                             prim.range,
                             is_async=prim.is_async)
            # Add connectors for inputs that exist as array nodes
            entry = nd.MapEntry(
                mapNode,
                _get_input_symbols(prim.inputs, prim.range.free_symbols))
            exit = nd.MapExit(mapNode)
        elif isinstance(prim, astnodes._ConsumeNode):
            if len(prim.children) == 0:
                raise ValueError("Consume node expected to have children")
            consumeNode = nd.Consume(label, (prim.params[1], prim.num_pes),
                                     prim.condition)
            entry = nd.ConsumeEntry(consumeNode)
            exit = nd.ConsumeExit(consumeNode)
        elif isinstance(prim, astnodes._ReduceNode):
            # A reduce node is a single node that serves as both the entry
            # and the exit point
            rednode = nd.Reduce(prim.ast, prim.axes, prim.identity)
            state.add_node(rednode)
            entry = rednode
            exit = rednode
        elif isinstance(prim, astnodes._TaskletNode):
            if isinstance(prim, astnodes._EmptyTaskletNode):
                tasklet = nd.EmptyTasklet(prim.name)
            else:
                # Remove memlets from tasklet AST
                if prim.language == types.Language.Python:
                    clean_code = MemletRemover().visit(prim.ast)
                    clean_code = ModuleInliner(modules).visit(clean_code)
                else:  # Use external code from tasklet definition
                    if prim.extcode is None:
                        raise SyntaxError("Cannot define an intrinsic "
                                          "tasklet without an implementation")
                    clean_code = prim.extcode
                tasklet = nd.Tasklet(
                    prim.name,
                    set(prim.inputs.keys()),
                    set(prim.outputs.keys()),
                    code=clean_code,
                    language=prim.language,
                    code_global=prim.gcode)  # TODO: location=prim.location

            # Need to add the tasklet in case we're in an empty state, where no
            # edge will be drawn to it
            state.add_node(tasklet)
            entry = tasklet
            exit = tasklet

        elif isinstance(prim, astnodes._NestedSDFGNode):
            # Attach the nested SDFG to this state/SDFG before wrapping it
            # in a node
            prim.sdfg.parent = state
            prim.sdfg._parent_sdfg = sdfg
            prim.sdfg.update_sdfg_list([])
            nsdfg = nd.NestedSDFG(prim.name, prim.sdfg,
                                  set(prim.inputs.keys()),
                                  set(prim.outputs.keys()))
            state.add_node(nsdfg)
            entry = nsdfg
            exit = nsdfg

        elif isinstance(prim, astnodes._ProgramNode):
            # NOTE(review): this returns None rather than exit_nodes, and
            # skips all remaining primitives -- confirm that no caller
            # iterates over the result in this case.
            return
        elif isinstance(prim, astnodes._ControlFlowNode):
            # No dataflow node is created for control-flow primitives
            continue
        else:
            raise TypeError("Node type not implemented: " +
                            str(prim.__class__))

        # Add incoming edges
        for varname, memlet in prim.inputs.items():
            arr = memlet.dataname
            # Choose the source node of the edge: a transient of the parent
            # primitive, the enclosing scope entry, or an input array node
            if (prim.parent is not None
                    and memlet.dataname in prim.parent.transients.keys()):
                node = input_node_for_array(state, memlet.dataname)

                # Add incoming edge into transient as well
                # FIXME: A bit hacked?
                if arr in prim.parent.inputs:
                    astmem = prim.parent.inputs[arr]
                    _add_astmemlet_edge(sdfg, state, superEntry, None, node,
                                        None, astmem)

                    # Remove local name from incoming edge to parent
                    prim.parent.inputs[arr].local_name = None
            elif superEntry:
                node = superEntry
            else:
                node = input_node_for_array(state, memlet.dataname)

            # Destination connector inference
            # Connected to a tasklet or a nested SDFG
            dst_conn = (memlet.local_name
                        if isinstance(entry, nd.CodeNode) else None)
            # Connected to a scope as part of its range
            if str(varname).startswith('__DACEIN_'):
                # Strip the 9-character '__DACEIN_' prefix
                dst_conn = str(varname)[9:]
            # Handle special case of consume input stream
            if (isinstance(entry, nd.ConsumeEntry)
                    and memlet.data == prim.stream):
                dst_conn = 'IN_stream'

            # If a memlet that covers this input already exists, skip
            # generating this one; otherwise replace memlet with ours
            skip_incoming_edge = False
            remove_edge = None
            for e in state.edges_between(node, entry):
                # Only edges for the same data and connector are candidates
                if e.data.data != memlet.dataname or dst_conn != e.dst_conn:
                    continue
                if e.data.subset.covers(memlet.subset):
                    skip_incoming_edge = True
                    break
                elif memlet.subset.covers(e.data.subset):
                    remove_edge = e
                    break
                else:
                    # Neither subset covers the other: widen the existing
                    # edge in place and accumulate the access count
                    print('WARNING: Performing bounding-box union on',
                          memlet.subset, 'and', e.data.subset, '(in)')
                    e.data.subset = sbs.bounding_box_union(
                        e.data.subset, memlet.subset)
                    e.data.num_accesses += memlet.num_accesses
                    skip_incoming_edge = True
                    break

            if remove_edge is not None:
                state.remove_edge(remove_edge)

            if skip_incoming_edge == False:
                _add_astmemlet_edge(sdfg, state, node, None, entry, dst_conn,
                                    memlet)

        # If there are no inputs, generate a dummy edge
        if superEntry and len(prim.inputs) == 0:
            state.add_edge(superEntry, None, entry, None, EmptyMemlet())

        if len(prim.children) > 0:
            # Recurse
            inner_outputs = _build_dataflow_graph_recurse(
                sdfg, state, prim.children, modules, entry, exit)
            # Infer output node for each memlet
            for i, (out_src, mem) in enumerate(inner_outputs):
                # If there is no such array in this primitive's outputs,
                # it's an external array (e.g., a map in a map). In this case,
                # connect to the exit node
                if mem.dataname in prim.outputs:
                    inner_outputs[i] = (out_src, prim.outputs[mem.dataname])
                else:
                    inner_outputs[i] = (out_src, mem)
        else:
            # Childless primitive: all of its outputs leave from its exit
            inner_outputs = [(exit, mem) for mem in prim.outputs.values()]

        # Add outgoing edges
        for out_src, astmem in inner_outputs:

            data = astmem.data
            dataname = astmem.dataname

            # If WCR is not none, it needs to be handled in the code. Check for
            # this after, as we only expect it for one distinct case
            wcr_was_handled = astmem.wcr is None

            # TODO: This is convoluted. We should find a more readable
            # way of connecting the outgoing edges.

            if super_exit is None:

                # Assert that we're in a top-level node
                if ((not isinstance(prim.parent, astnodes._ProgramNode)) and
                    (not isinstance(prim.parent, astnodes._ControlFlowNode))):
                    raise RuntimeError("Expected to be at the top node")

                # Looks hacky
                src_conn = (astmem.local_name if isinstance(
                    out_src, (nd.Tasklet, nd.NestedSDFG)) else None)

                # Here we just need to connect memlets directly to their
                # respective data nodes
                out_tgt = output_node_for_array(state, astmem.dataname)

                # If a memlet that covers this output already exists, skip
                # generating this one; otherwise replace memlet with ours
                skip_outgoing_edge = False
                remove_edge = None
                for e in state.edges_between(out_src, out_tgt):
                    # Only edges for the same data and connector are
                    # candidates
                    if e.data.data != astmem.dataname or src_conn != e.src_conn:
                        continue
                    if e.data.subset.covers(astmem.subset):
                        skip_outgoing_edge = True
                        break
                    elif astmem.subset.covers(e.data.subset):
                        remove_edge = e
                        break
                    else:
                        # Neither subset covers the other: widen the existing
                        # edge in place and accumulate the access count
                        print('WARNING: Performing bounding-box union on',
                              astmem.subset, 'and', e.data.subset, '(out)')
                        e.data.subset = sbs.bounding_box_union(
                            e.data.subset, astmem.subset)
                        e.data.num_accesses += astmem.num_accesses
                        skip_outgoing_edge = True
                        break

                # Skipping here also bypasses the WCR sanity check at the end
                # of this loop iteration
                if skip_outgoing_edge == True:
                    continue
                if remove_edge is not None:
                    state.remove_edge(remove_edge)

                _add_astmemlet_edge(sdfg,
                                    state,
                                    out_src,
                                    src_conn,
                                    out_tgt,
                                    None,
                                    astmem,
                                    wcr=astmem.wcr,
                                    wcr_identity=astmem.wcr_identity)
                wcr_was_handled = (True if astmem.wcr is not None else
                                   wcr_was_handled)

                # If the program defines another output, connect it too.
                # This refers to the case where we have streams, which
                # must define an input and output, and sometimes this output
                # is defined in pdp.outputs
                if (isinstance(out_tgt, nd.AccessNode)
                        and isinstance(out_tgt.desc(sdfg), dt.Stream)):
                    try:
                        stream_memlet = next(
                            v for k, v in prim.parent.outputs.items()
                            if k == out_tgt.data)
                        stream_output = output_node_for_array(
                            state, stream_memlet.dataname)
                        _add_astmemlet_edge(sdfg, state, out_tgt, None,
                                            stream_output, None, stream_memlet)
                    except StopIteration:  # Stream output not found, skip
                        pass

            else:  # We're in a nest

                if isinstance(prim, astnodes._ScopeNode):
                    # We're a map or a consume node, that needs to connect our
                    # exit to either an array or to the super_exit
                    if data.transient and dataname in prim.parent.transients:
                        # Connect the exit directly
                        out_tgt = output_node_for_array(state, data.dataname)
                        _add_astmemlet_edge(sdfg, state, out_src, None,
                                            out_tgt, None, astmem)
                    else:
                        # This is either a transient defined in an outer scope,
                        # or an I/O array, so redirect through the exit node
                        _add_astmemlet_edge(sdfg, state, out_src, None,
                                            super_exit, None, astmem)
                        # Instruct outer recursion layer to continue the route
                        exit_nodes.append((super_exit, astmem))
                elif isinstance(
                        prim,
                    (astnodes._TaskletNode, astnodes._NestedSDFGNode)):
                    # We're a tasklet, and need to connect either to the exit
                    # if the array is I/O or is defined in a scope further out,
                    # or directly to the transient if it's defined locally
                    if dataname in prim.parent.transients:
                        # This is a local transient variable, so connect to it
                        # directly
                        out_tgt = output_node_for_array(state, data.dataname)
                        _add_astmemlet_edge(sdfg, state, out_src,
                                            astmem.local_name, out_tgt, None,
                                            astmem)
                    else:
                        # This is an I/O array, or an outer level transient, so
                        # redirect through the exit node
                        _add_astmemlet_edge(sdfg,
                                            state,
                                            out_src,
                                            astmem.local_name,
                                            super_exit,
                                            None,
                                            astmem,
                                            wcr=astmem.wcr,
                                            wcr_identity=astmem.wcr_identity)
                        exit_nodes.append((super_exit, astmem))
                        if astmem.wcr is not None:
                            wcr_was_handled = True  # Sanity check
                else:
                    # NOTE(review): the branch above tests the type of
                    # `prim`, but the message reports the type of `out_src`
                    # -- confirm which one is intended.
                    raise TypeError("Unexpected node type: {}".format(
                        type(out_src).__name__))

            if not wcr_was_handled and not isinstance(prim,
                                                      astnodes._ScopeNode):
                raise RuntimeError("Detected unhandled WCR for primitive '{}' "
                                   "of type {}. WCR is only expected for "
                                   "tasklets in a map/consume scope.".format(
                                       prim.name,
                                       type(prim).__name__))

    return exit_nodes
Beispiel #29
0
class StreamTransient(pattern_matching.Transformation):
    """ Implements the StreamTransient transformation, which adds a transient
        stream node between nested maps that lead to a stream. The transient
        then acts as a local buffer.
    """

    # Pattern nodes: tasklet -> (inner) map exit -> outer map exit
    _tasklet = nodes.Tasklet('_')
    _map_exit = nodes.MapExit(nodes.Map("", [], []))
    _outer_map_exit = nodes.MapExit(nodes.Map("", [], []))

    @staticmethod
    def expressions():
        """ Returns the pattern graph to match: a path from a tasklet
            through a map exit to an outer map exit.
        """
        return [
            nxutil.node_path_graph(StreamTransient._tasklet,
                                   StreamTransient._map_exit,
                                   StreamTransient._outer_map_exit)
        ]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ A candidate matches when at least one edge leading from the
            inner map exit to the outer map exit carries a stream.

            :param graph: The state in which the candidate was found.
            :param candidate: Mapping from pattern nodes to node indices in
                              ``graph``.
            :param expr_index: Unused.
            :param sdfg: The SDFG whose ``arrays`` hold the data descriptors.
            :param strict: Unused.
            :return: True if a streaming output exists, False otherwise.
        """
        map_exit = graph.nodes()[candidate[StreamTransient._map_exit]]
        outer_map_exit = graph.nodes()[candidate[
            StreamTransient._outer_map_exit]]

        # Check if there is a streaming output. The destination is tested
        # first (as in apply) so that descriptors are only looked up for
        # edges that actually lead to the outer map exit.
        return any(
            dest == outer_map_exit
            and isinstance(sdfg.arrays[memlet.data], data.Stream)
            for _src, _, dest, _, memlet in graph.out_edges(map_exit))

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns the candidate's entries for the three pattern nodes,
            converted to strings and joined by ' -> '.
        """
        tasklet = candidate[StreamTransient._tasklet]
        map_exit = candidate[StreamTransient._map_exit]
        outer_map_exit = candidate[StreamTransient._outer_map_exit]

        return ' -> '.join(
            str(node) for node in [tasklet, map_exit, outer_map_exit])

    def apply(self, sdfg):
        """ Inserts a transient stream named ``'tile_' + <stream name>``
            between the inner and the outer map exit and redirects the
            first stream edge found between them through it.

            If no edge from the inner to the outer map exit carries a
            stream, the SDFG is left unmodified. The tasklet's output memlet
            is only redirected to the new stream if it actually refers to
            the stream being buffered.

            :param sdfg: The SDFG to modify.
        """
        graph = sdfg.nodes()[self.state_id]
        tasklet = graph.nodes()[self.subgraph[StreamTransient._tasklet]]
        map_exit = graph.nodes()[self.subgraph[StreamTransient._map_exit]]
        outer_map_exit = graph.nodes()[self.subgraph[
            StreamTransient._outer_map_exit]]

        # Find the stream edge between the two map exits.
        # TODO: What if there's more than one?
        edge = next(
            (e for e in graph.out_edges(map_exit)
             if e.dst == outer_map_exit
             and isinstance(sdfg.arrays[e.data.data], data.Stream)), None)
        if edge is None:
            # Nothing to buffer: do not create a stream or touch any edge
            # based on an unrelated memlet.
            return
        memlet = edge.data
        dataname = memlet.data

        # Find the tasklet output that writes to the same stream (if any)
        tasklet_memlet = next((e.data for e in graph.out_edges(tasklet)
                               if e.data.data == dataname), None)

        bbox = map_exit.map.range.bounding_box_size()
        bbox_approx = [symbolic.overapproximate(dim) for dim in bbox]

        # Create the new node: Temporary stream and an access node
        stream_name = 'tile_' + dataname
        sdfg.add_stream(
            stream_name,
            sdfg.arrays[dataname].dtype,
            1,
            bbox_approx[0],
            [1],
            transient=True,
        )
        snode = nodes.AccessNode(stream_name)

        # Copy the memlet before any renaming, so that the original keeps
        # pointing to the outer stream
        to_stream_mm = copy.deepcopy(memlet)
        to_stream_mm.data = snode.data
        if tasklet_memlet is not None:
            tasklet_memlet.data = snode.data

        # Reconnect, assuming one edge to the stream
        graph.remove_edge(edge)
        graph.add_edge(map_exit, None, snode, None, to_stream_mm)
        graph.add_edge(snode, None, outer_map_exit, None, memlet)

        return

    def modifies_graph(self):
        """ Returns True. """
        return True
Beispiel #30
0
class MapFusion(pattern_matching.Transformation):
    """ Implements the map fusion pattern.

        Map Fusion takes two maps that are connected in series and have the 
        same range, and fuses them to one map. The tasklets in the new map are
        connected in the same manner as they were before the fusion.
    """

    _first_map_entry = nodes.MapEntry(nodes.Map("", [], []))
    _second_map_entry = nodes.MapEntry(nodes.Map("", [], []))

    @staticmethod
    def annotates_memlets():
        """ Always returns False for this transformation. """
        return False

    @staticmethod
    def expressions():
        """ Returns the pattern graphs to match. Only the first map entry is
            part of the pattern.
        """
        first_entry_pattern = nxutil.node_path_graph(
            MapFusion._first_map_entry)
        return [first_entry_pattern]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Checks for a ``map -> access node(s) -> map`` chain in which
            both maps have equal ranges.

            On success, the index of the second map entry is stored in
            ``candidate`` under ``MapFusion._second_map_entry``.

            :param graph: The state in which the candidate was found.
            :param candidate: Mapping from pattern nodes to node indices in
                              ``graph``; modified on success.
            :param expr_index: Unused.
            :param sdfg: Unused.
            :param strict: Unused.
            :return: True if the two maps can be fused, False otherwise.
        """
        entry_one = graph.nodes()[candidate[MapFusion._first_map_entry]]
        exit_one = graph.exit_nodes(entry_one)[0]

        # The first map must have a non-conflicting map exit.
        # (cannot fuse with CR in the first map)
        for incoming in graph.in_edges(exit_one):
            if incoming.data.wcr is not None:
                return False

        # Everything leaving the first map exit must be an access node
        intermediates = []
        for _, _, target, _, _ in graph.out_edges(exit_one):
            if not isinstance(target, nodes.AccessNode):
                return False
            intermediates.append(target)

        # All access nodes must feed one and the same map entry
        entry_two = None
        for access in intermediates:
            for _, _, target, _, _ in graph.out_edges(access):
                if not isinstance(target, nodes.MapEntry):
                    return False
                if entry_two is None:
                    entry_two = target
                elif target != entry_two:
                    return False
        if entry_two is None:
            return False

        # The second map may only read from those access nodes
        for source, _, _, _, _ in graph.in_edges(entry_two):
            if source not in intermediates:
                return False

        # Check map spaces (this should be generalized to ignore order).
        if entry_one.map.range != entry_two.map.range:
            return False

        # Success: record the second map entry for later use
        candidate[MapFusion._second_map_entry] = graph.nodes().index(
            entry_two)
        return True

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns ``'<label>: <params>'`` for the first and the second map,
            joined by ' -> '.
        """
        state_nodes = graph.nodes()
        matched_maps = [
            state_nodes[candidate[pattern_node]].map
            for pattern_node in (MapFusion._first_map_entry,
                                 MapFusion._second_map_entry)
        ]

        descriptions = []
        for matched in matched_maps:
            descriptions.append(matched.label + ': ' + str(matched.params))
        return ' -> '.join(descriptions)

    def apply(self, sdfg):
        graph = sdfg.nodes()[self.state_id]
        first_map_entry = graph.nodes()[self.subgraph[
            MapFusion._first_map_entry]]
        first_map_exit = graph.exit_nodes(first_map_entry)[0]
        second_map_entry = graph.nodes()[self.subgraph[
            MapFusion._second_map_entry]]
        second_exits = graph.exit_nodes(second_map_entry)
        first_map_params = [
            symbolic.pystr_to_symbolic(p) for p in first_map_entry.map.params
        ]
        second_map_params = [
            symbolic.pystr_to_symbolic(p) for p in second_map_entry.map.params
        ]

        # Fix exits
        for exit_node in second_exits:
            if isinstance(exit_node, nodes.MapExit):
                exit_node.map = first_map_entry.map

        # Substitute symbols in second map.
        for _parent, _, _child, _, memlet in graph.bfs_edges(second_map_entry,
                                                             reverse=False):
            for fp, sp in zip(first_map_params, second_map_params):
                for ind, r in enumerate(memlet.subset):
                    if isinstance(memlet.subset[ind], tuple):
                        begin = r[0].subs(sp, fp)
                        end = r[1].subs(sp, fp)
                        step = r[2].subs(sp, fp)
                        memlet.subset[ind] = (begin, end, step)
                    else:
                        memlet.subset[ind] = memlet.subset[ind].subs(sp, fp)

        transients = {}
        for _, _, dst, _, memlet in graph.out_edges(first_map_exit):
            if not memlet.data in transients:
                transients[memlet.data] = dst
        new_edges = []
        for src, src_conn, _, dst_conn, memlet in graph.in_edges(
                first_map_exit):
            new_memlet = dcpy(memlet)
            new_edges.append(
                (src, src_conn, transients[memlet.data], dst_conn, new_memlet))
        for _, src_conn, dst, dst_conn, memlet in graph.out_edges(
                second_map_entry):
            new_memlet = dcpy(memlet)
            new_edges.append(
                (transients[memlet.data], src_conn, dst, dst_conn, new_memlet))

        # Delete nodes/edges
        for edge in graph.in_edges(first_map_exit):
            graph.remove_edge(edge)
        for edge in graph.out_edges(second_map_entry):
            graph.remove_edge(edge)
        data_nodes = []
        for _, _, dst, _, _ in graph.out_edges(first_map_exit):
            data_nodes.append(dst)
        for data_node in data_nodes:
            for edge in graph.all_edges(data_node):
                graph.remove_edge(edge)
        graph.remove_node(first_map_exit)
        graph.remove_node(second_map_entry)

        # Add edges
        for edge in new_edges:
            graph.add_edge(*edge)

        # Reduce transient sizes
        for data_node in data_nodes:
            data_desc = data_node.desc(sdfg)
            if data_desc.transient:
                edges = graph.in_edges(data_node)
                subset = edges[0].data.subset
                for idx in range(1, len(edges)):
                    subset = calc_set_union(subset, edges[idx].data.subset)
                data_desc.shape = subset.bounding_box_size()
                data_desc.strides = list(subset.bounding_box_size())
                data_desc.offset = [0] * subset.dims()