Exemple #1
0
    def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
        """
        Decide whether double buffering may be applied to the matched map.

        The match is accepted only if the map has exactly one parameter, the
        map passes ``MapToForLoop.can_be_applied``, every access node that is
        a direct successor of the map entry refers to a transient
        ``data.Array``, and the first such access node is the matched
        transient (which prevents reporting the same map several times).

        :param graph: The state in which the pattern was matched.
        :param expr_index: Index of the matched pattern expression.
        :param sdfg: The SDFG that contains ``graph``.
        :param permissive: Forwarded to ``MapToForLoop.can_be_applied``.
        :return: True if the transformation can be applied, False otherwise.
        """
        entry = self.map_entry
        matched_transient = self.transient

        # Double buffering is restricted to one-dimensional maps
        if len(entry.map.params) != 1:
            return False

        # The map must be convertible into a for-loop
        loop_check = MapToForLoop()
        loop_check.setup_match(
            sdfg, sdfg.sdfg_id, self.state_id,
            {MapToForLoop.map_entry: self.subgraph[DoubleBuffering.map_entry]},
            expr_index)
        if not loop_check.can_be_applied(graph, expr_index, sdfg, permissive):
            return False

        # Every access node fed directly by the map entry has to be a
        # transient array
        seen_access_node = False
        for out_edge in graph.out_edges(entry):
            target = out_edge.dst
            if not isinstance(target, nodes.AccessNode):
                continue

            descriptor = sdfg.arrays[target.data]
            if not (isinstance(descriptor, data.Array)
                    and descriptor.transient):
                return False

            # Only the first transient may be the matched one; otherwise the
            # same map would be matched once per transient
            if not seen_access_node and target != matched_transient:
                return False
            seen_access_node = True

        return True
Exemple #2
0
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """
        Decide whether double buffering may be applied to a candidate match.

        The candidate is accepted only if the map has exactly one parameter,
        ``MapToForLoop.can_be_applied`` accepts the same map entry, every
        access node that is a direct successor of the map entry refers to a
        transient ``data.Array``, and the first such access node is the
        matched transient (which prevents duplicate matches).

        :param graph: The state in which the pattern was matched.
        :param candidate: Mapping from pattern nodes to node indices in
                          ``graph``.
        :param expr_index: Index of the matched pattern expression.
        :param sdfg: The SDFG that contains ``graph``.
        :param strict: Forwarded to ``MapToForLoop.can_be_applied``.
        :return: True if the transformation can be applied, False otherwise.
        """
        entry_id = candidate[DoubleBuffering._map_entry]
        entry = graph.nodes()[entry_id]
        matched_transient = graph.nodes()[candidate[DoubleBuffering._transient]]

        # Double buffering is restricted to one-dimensional maps
        if len(entry.map.params) != 1:
            return False

        # The map must be convertible into a for-loop
        loop_candidate = {MapToForLoop._map_entry: entry_id}
        if not MapToForLoop.can_be_applied(graph, loop_candidate, expr_index,
                                           sdfg, strict):
            return False

        # Every access node fed directly by the map entry has to be a
        # transient array
        seen_access_node = False
        for out_edge in graph.out_edges(entry):
            target = out_edge.dst
            if not isinstance(target, nodes.AccessNode):
                continue

            descriptor = sdfg.arrays[target.data]
            if not (isinstance(descriptor, data.Array)
                    and descriptor.transient):
                return False

            # Only the first transient may be the matched one; otherwise the
            # same map would be matched once per transient
            if not seen_access_node and target != matched_transient:
                return False
            seen_access_node = True

        return True
Exemple #3
0
    def apply(self, graph: sd.SDFGState, sdfg: sd.SDFG):
        """
        Apply double buffering to the matched one-dimensional map.

        The method performs the following steps, in this order:

        1. Shortens the map range by one stride.
        2. Prepends a dimension of size 2 to the descriptor of every
           transient whose access node is a direct successor of the map
           entry, and rewrites the memlets in the map scope through
           ``self._modify_memlet``.
        3. Converts the map into a for-loop with ``MapToForLoop``, which
           yields a nested SDFG node and the loop-body state.
        4. Moves the copies into the transients out of the loop body and
           into the nested SDFG's start state, evaluated at the loop start.
        5. Duplicates the loop body into the nested SDFG's first sink state.
        6. Adds copies to the loop body that read the data of the *next*
           iteration (map parameter plus stride).
        7. Substitutes the ``__dace_db_param`` placeholder in each state and
           removes the placeholder symbol from the nested SDFG.

        :param graph: The state that contains the matched map.
        :param sdfg: The SDFG that contains ``graph``.
        :return: The nested SDFG node created by ``MapToForLoop``.
        """
        map_entry = self.map_entry

        map_param = map_entry.map.params[0]  # Assuming one dimensional

        ##############################
        # Change condition of loop to one fewer iteration (so that the
        # final one reads from the last buffer)
        map_rstart, map_rend, map_rstride = map_entry.map.range[0]
        map_rend = symbolic.pystr_to_symbolic('(%s) - (%s)' %
                                              (map_rend, map_rstride))
        map_entry.map.range = subsets.Range([(map_rstart, map_rend,
                                              map_rstride)])

        ##############################
        # Gather transients to modify: the data names of all access nodes
        # that are direct successors of the map entry
        transients_to_modify = set(edge.dst.data
                                   for edge in graph.out_edges(map_entry)
                                   if isinstance(edge.dst, nodes.AccessNode))

        # Add dimension to transients and modify memlets.
        # Each descriptor receives a new leading dimension of size 2 whose
        # stride is the previous total size, and its total size is doubled.
        for transient in transients_to_modify:
            desc: data.Array = sdfg.arrays[transient]
            # Using non-python syntax to ensure properties change
            desc.strides = [desc.total_size] + list(desc.strides)
            desc.shape = [2] + list(desc.shape)
            desc.offset = [0] + list(desc.offset)
            desc.total_size = desc.total_size * 2

        ##############################
        # Modify memlets to use map parameter as buffer index
        # NOTE(review): ``_modify_memlet`` is defined outside this block;
        # presumably it prepends the ``__dace_db_param`` placeholder that is
        # substituted further below -- confirm against its definition.
        # NOTE(review): ``modified_subsets`` is filled here but is not read
        # again within this method.
        modified_subsets = []  # Store modified memlets for final state
        for edge in graph.scope_subgraph(map_entry).edges():
            if edge.data.data in transients_to_modify:
                edge.data.subset = self._modify_memlet(sdfg, edge.data.subset,
                                                       edge.data.data)
                modified_subsets.append(edge.data.subset)
            else:  # Could be other_subset
                # The memlet names other data; check whether either end of
                # its memlet path is an access node of a modified transient
                path = graph.memlet_path(edge)
                src_node = path[0].src
                dst_node = path[-1].dst

                # other_subset could be None. In that case, recreate from array
                dataname = None
                if (isinstance(src_node, nodes.AccessNode)
                        and src_node.data in transients_to_modify):
                    dataname = src_node.data
                elif (isinstance(dst_node, nodes.AccessNode)
                      and dst_node.data in transients_to_modify):
                    dataname = dst_node.data
                if dataname is not None:
                    subset = (edge.data.other_subset or
                              subsets.Range.from_array(sdfg.arrays[dataname]))
                    edge.data.other_subset = self._modify_memlet(
                        sdfg, subset, dataname)
                    modified_subsets.append(edge.data.other_subset)

        ##############################
        # Turn map into for loop.
        # The result is the nested SDFG node and the state holding the loop
        # body.
        map_to_for = MapToForLoop()
        map_to_for.setup_match(
            sdfg, self.sdfg_id, self.state_id,
            {MapToForLoop.map_entry: graph.node_id(self.map_entry)},
            self.expr_index)
        nsdfg_node, nstate = map_to_for.apply(graph, sdfg)

        ##############################
        # Gather node copies and remove memlets.
        # Edges from source nodes of the loop body into the modified
        # transients are remembered and removed; a source node left without
        # outgoing edges is removed as well.
        edges_to_replace = []
        for node in nstate.source_nodes():
            for edge in nstate.out_edges(node):
                if (isinstance(edge.dst, nodes.AccessNode)
                        and edge.dst.data in transients_to_modify):
                    edges_to_replace.append(edge)
                    nstate.remove_edge(edge)
            if nstate.out_degree(node) == 0:
                nstate.remove_node(node)

        ##############################
        # Add initial reads to initial nested state
        initial_state: sd.SDFGState = nsdfg_node.sdfg.start_state
        initial_state.set_label('%s_init' % map_entry.map.label)
        for edge in edges_to_replace:
            # The original source node object is added here (not a copy);
            # the memlet is deep-copied so later substitutions in this state
            # do not alter the remembered edge
            initial_state.add_node(edge.src)
            rnode = edge.src
            wnode = initial_state.add_write(edge.dst.data)
            initial_state.add_edge(rnode, edge.src_conn, wnode, edge.dst_conn,
                                   copy.deepcopy(edge.data))

        # All instances of the map parameter in this state become the loop start
        sd.replace(initial_state, map_param, map_rstart)
        # Initial writes go to the appropriate buffer
        init_expr = symbolic.pystr_to_symbolic('(%s / %s) %% 2' %
                                               (map_rstart, map_rstride))
        sd.replace(initial_state, '__dace_db_param', init_expr)

        ##############################
        # Modify main state's memlets

        # Divide by loop stride.
        # The loop body uses buffer (param / stride) % 2 for the current
        # iteration.
        new_expr = symbolic.pystr_to_symbolic('(%s / %s) %% 2' %
                                              (map_param, map_rstride))
        sd.replace(nstate, '__dace_db_param', new_expr)

        ##############################
        # Add the main state's contents to the last state, modifying
        # memlets appropriately.
        # The copy is taken at this point, i.e. before the WCR outputs are
        # removed from and the next-buffer reads are added to the main state.
        final_state: sd.SDFGState = nsdfg_node.sdfg.sink_nodes()[0]
        final_state.set_label('%s_final_computation' % map_entry.map.label)
        dup_nstate = copy.deepcopy(nstate)
        final_state.add_nodes_from(dup_nstate.nodes())
        for e in dup_nstate.edges():
            final_state.add_edge(e.src, e.src_conn, e.dst, e.dst_conn, e.data)

        # If there is a WCR output with transient, only output in last state.
        # The memlet path is removed from the main state only; the duplicate
        # placed in the final state above still contains it.
        nstate: sd.SDFGState
        for node in nstate.sink_nodes():
            for e in list(nstate.in_edges(node)):
                if e.data.wcr is not None:
                    path = nstate.memlet_path(e)
                    if isinstance(path[0].src, nodes.AccessNode):
                        nstate.remove_memlet_path(e)

        ##############################
        # Add reads into next buffers to main state.
        # Each remembered copy is re-added with the map parameter replaced by
        # (param + stride) in the subset that does not belong to the
        # transient.
        # NOTE(review): ``_replace_in_subset`` is defined outside this block;
        # presumably it substitutes the symbol inside the given subset --
        # confirm against its definition.
        for edge in edges_to_replace:
            rnode = copy.deepcopy(edge.src)
            nstate.add_node(rnode)
            wnode = nstate.add_write(edge.dst.data)
            new_memlet = copy.deepcopy(edge.data)
            if new_memlet.data in transients_to_modify:
                new_memlet.other_subset = self._replace_in_subset(
                    new_memlet.other_subset, map_param,
                    '(%s + %s)' % (map_param, map_rstride))
            else:
                new_memlet.subset = self._replace_in_subset(
                    new_memlet.subset, map_param,
                    '(%s + %s)' % (map_param, map_rstride))

            nstate.add_edge(rnode, edge.src_conn, wnode, edge.dst_conn,
                            new_memlet)

        nstate.set_label('%s_double_buffered' % map_entry.map.label)
        # Divide by loop stride.
        # A first substitution of the placeholder was already run on this
        # state above, so this one presumably only affects the read edges
        # that were just added: they target the other buffer,
        # ((param / stride) + 1) % 2.
        new_expr = symbolic.pystr_to_symbolic('((%s / %s) + 1) %% 2' %
                                              (map_param, map_rstride))
        sd.replace(nstate, '__dace_db_param', new_expr)

        # Remove symbol once done
        del nsdfg_node.sdfg.symbols['__dace_db_param']
        del nsdfg_node.symbol_mapping['__dace_db_param']

        return nsdfg_node
    def apply(self, sdfg):
        """
        Apply stencil tiling to the maps of the matched subgraph.

        The method works on the state ``sdfg.nodes()[self.state_id]`` and on
        the map entries returned by ``helpers.get_outermost_scope_maps`` for
        the matched subgraph. It proceeds as follows:

        1. Obtains the map topology (children, parents and sink maps) from
           ``StencilTiling.topology`` and orders the maps starting from the
           sink maps.
        2. Infers, per map and per map parameter, an outer range from the
           coverage of the map and of its children.
        3. Determines which parameter dimensions are invariant, i.e. have no
           inferred range or the same inferred range in every map.
        4. Strip-mines every map dimension with the configured strides,
           shrinking the outer range and enlarging the inner range by the
           computed lower/upper tile offsets, and collapses the resulting
           outer maps.
        5. If ``self.unroll_loops`` is set and the conditions hold, expands
           the inner map, converts each resulting map into a for-loop and
           unrolls it.

        Attributes written: ``self.reference_range``, ``self.tile_sizes``,
        ``self.tile_offset_lower`` and ``self.tile_offset_upper`` (these
        three are reset for every map, so after the call they describe the
        map processed last), and ``self._outer_entries`` (a list of the
        outer map entries created by strip mining).

        :param sdfg: The SDFG that contains the matched state.
        :raises NotImplementedError: If a dimension of a stencil subset is
                                     indexed by more than one symbol.
        :raises RuntimeError: If the tile offsets cannot be evaluated.
        :raises AssertionError: If the same data is indexed through different
                                parameter mappings on different edges.
        """
        graph = sdfg.nodes()[self.state_id]
        subgraph = self.subgraph_view(sdfg)
        map_entries = helpers.get_outermost_scope_maps(sdfg, graph, subgraph)

        result = StencilTiling.topology(sdfg, graph, map_entries)
        (children_dict, parent_dict, sink_maps) = result

        # next up, calculate inferred ranges for each map
        # for each map entry, this contains a tuple of dicts:
        # each of those maps from data_name of the array to
        # inferred outer ranges. An inferred outer range is created
        # by taking the union of ranges of inner subsets corresponding
        # to that data and substituting this subset by the min / max of the
        # parametrized map boundaries
        # finally, from these outer ranges we can easily calculate
        # strides and tile sizes required for every map
        inferred_ranges = defaultdict(dict)

        # create array of reverse topologically sorted map entries
        # to iterate over
        # A map is emitted once all of its children have been emitted; its
        # parents are then queued.
        topo_reversed = []
        queue = set(sink_maps.copy())
        while len(queue) > 0:
            element = next(e for e in queue
                           if not children_dict[e] - set(topo_reversed))
            topo_reversed.append(element)
            queue.remove(element)
            for parent in parent_dict[element]:
                queue.add(parent)

        # main loop
        # first get coverage dicts for each map entry
        # for each map, contains a tuple of two dicts
        # each of those two maps from data name to outer range
        # NOTE(review): below, index 1 is used for the map itself and index 0
        # for its children; presumably these are the output and input
        # coverage respectively -- confirm against coverage_dicts.
        coverage = {}
        for map_entry in map_entries:
            coverage[map_entry] = StencilTiling.coverage_dicts(
                sdfg, graph, map_entry, outer_range=True)

        # we have a mapping from data name to outer range
        # however we want a mapping from map parameters to outer ranges
        # for this we need to find out how all array dimensions map to
        # outer ranges

        variable_mapping = defaultdict(list)
        for map_entry in topo_reversed:
            map = map_entry.map

            # first find out variable mapping
            # For each edge leaving the map entry or entering the map exit,
            # record per subset dimension the single symbol it depends on
            # (or None if it depends on no symbol).
            for e in itertools.chain(
                    graph.out_edges(map_entry),
                    graph.in_edges(graph.exit_node(map_entry))):
                mapping = []
                for dim in e.data.subset:
                    syms = set()
                    for d in dim:
                        syms |= symbolic.symlist(d).keys()
                    if len(syms) > 1:
                        raise NotImplementedError(
                            "One incoming or outgoing stencil subset is indexed "
                            "by multiple map parameters. "
                            "This is not supported yet.")
                    try:
                        mapping.append(syms.pop())
                    except KeyError:
                        # just append None if there is no map symbol in it.
                        # we don't care for now.
                        mapping.append(None)

                # The dictionary is keyed by data name, so the membership
                # test has to use the data name as well. Testing the memlet
                # object would not find previously recorded entries, and the
                # consistency check below would be skipped.
                if e.data.data in variable_mapping:
                    # assert that this is the same everywhere.
                    # else we might run into problems
                    assert variable_mapping[e.data.data] == mapping
                else:
                    variable_mapping[e.data.data] = mapping

            # now do mapping data name -> outer range
            # and from that infer mapping variable -> outer range
            local_ranges = {dn: None for dn in coverage[map_entry][1].keys()}
            for data_name, cov in coverage[map_entry][1].items():
                local_ranges[data_name] = subsets.union(
                    local_ranges[data_name], cov)
                # now look at proceeding maps
                # and union those subsets -> could be larger with stencil indent
                for child_map in children_dict[map_entry]:
                    if data_name in coverage[child_map][0]:
                        local_ranges[data_name] = subsets.union(
                            local_ranges[data_name],
                            coverage[child_map][0][data_name])

            # final assignment: combine local_ranges and variable_mapping
            # together into inferred_ranges
            inferred_ranges[map_entry] = {p: None for p in map.params}
            for data_name, ranges in local_ranges.items():
                for param, r in zip(variable_mapping[data_name], ranges):
                    # create new range from this subset and assign
                    rng = subsets.Range((r, ))
                    if param:
                        inferred_ranges[map_entry][param] = subsets.union(
                            inferred_ranges[map_entry][param], rng)

        # get parameters -- should all be the same
        params = next(iter(map_entries)).map.params.copy()
        # define reference range as inferred range of one of the sink maps
        self.reference_range = inferred_ranges[next(iter(sink_maps))]
        if self.debug:
            print("StencilTiling::Reference Range", self.reference_range)
        # next up, search for the ranges that don't change
        # A dimension is invariant if the reference range has no entry for
        # its parameter, or if every map has the same inferred range for it.
        invariant_dims = []
        for idx, p in enumerate(params):
            different = False
            if self.reference_range[p] is None:
                invariant_dims.append(idx)
                warnings.warn(
                    f"StencilTiling::No Stencil pattern detected for parameter {p}"
                )
                continue
            for m in map_entries:
                if inferred_ranges[m][p] != self.reference_range[p]:
                    different = True
                    break
            if not different:
                invariant_dims.append(idx)
                warnings.warn(
                    f"StencilTiling::No Stencil pattern detected for parameter {p}"
                )

        # during stripmining, we will create new outer map entries
        # for easy access
        self._outer_entries = set()
        # with inferred_ranges constructed, we can begin to strip mine
        for map_entry in map_entries:
            # Retrieve map entry and exit nodes.
            map = map_entry.map

            stripmine_subgraph = {
                StripMining._map_entry: graph.nodes().index(map_entry)
            }

            sdfg_id = sdfg.sdfg_id
            last_map_entry = None
            original_schedule = map_entry.schedule
            # Tile bookkeeping is reset for every map
            self.tile_sizes = []
            self.tile_offset_lower = []
            self.tile_offset_upper = []

            # strip mining each dimension where necessary
            # removed_maps counts dimensions strip-mined as trivial; it is
            # used to shift dim_idx for the remaining dimensions
            removed_maps = 0
            for dim_idx, param in enumerate(map_entry.map.params):
                # get current_node tile size
                # If fewer strides than dimensions are given, the last stride
                # is reused
                if dim_idx >= len(self.strides):
                    tile_stride = symbolic.pystr_to_symbolic(self.strides[-1])
                else:
                    tile_stride = symbolic.pystr_to_symbolic(
                        self.strides[dim_idx])

                trivial = False

                if dim_idx in invariant_dims:
                    self.tile_sizes.append(tile_stride)
                    self.tile_offset_lower.append(0)
                    self.tile_offset_upper.append(0)
                else:
                    target_range_current = inferred_ranges[map_entry][param]
                    reference_range_current = self.reference_range[param]

                    # Distance by which this map's range extends below /
                    # above the reference range
                    min_diff = symbolic.SymExpr(reference_range_current.min_element()[0] \
                                    - target_range_current.min_element()[0])
                    max_diff = symbolic.SymExpr(target_range_current.max_element()[0] \
                                    - reference_range_current.max_element()[0])

                    try:
                        min_diff = symbolic.evaluate(min_diff, {})
                        max_diff = symbolic.evaluate(max_diff, {})
                    except TypeError:
                        raise RuntimeError("Symbolic evaluation of map "
                                           "ranges failed. Please check "
                                           "your parameters and match.")

                    self.tile_sizes.append(tile_stride + max_diff + min_diff)
                    self.tile_offset_lower.append(
                        symbolic.pystr_to_symbolic(str(min_diff)))
                    self.tile_offset_upper.append(
                        symbolic.pystr_to_symbolic(str(max_diff)))

                # get calculated parameters
                tile_size = self.tile_sizes[-1]

                # From here on dim_idx indexes the current map range;
                # (dim_idx + removed_maps) is the original parameter index
                dim_idx -= removed_maps
                # If map or tile sizes are trivial, skip strip-mining map dimension
                # special cases:
                # if tile size is trivial AND we have an invariant dimension, skip
                if tile_size == map.range.size()[dim_idx] and (
                        dim_idx + removed_maps) in invariant_dims:
                    continue

                # trivial map: we just continue
                if map.range.size()[dim_idx] in [0, 1]:
                    continue

                if tile_size == 1 and tile_stride == 1 and (
                        dim_idx + removed_maps) in invariant_dims:
                    trivial = True
                    removed_maps += 1

                # indent all map ranges accordingly and then perform
                # strip mining on these. Offset inner maps accordingly afterwards

                range_tuple = (map.range[dim_idx][0] +
                               self.tile_offset_lower[-1],
                               map.range[dim_idx][1] -
                               self.tile_offset_upper[-1],
                               map.range[dim_idx][2])
                map.range[dim_idx] = range_tuple
                stripmine = StripMining(sdfg_id, self.state_id,
                                        stripmine_subgraph, 0)

                stripmine.tiling_type = 'ceilrange'
                stripmine.dim_idx = dim_idx
                stripmine.new_dim_prefix = self.prefix if not trivial else ''
                # use tile_stride for both -- we will extend
                # the inner tiles later
                stripmine.tile_size = str(tile_stride)
                stripmine.tile_stride = str(tile_stride)
                outer_map = stripmine.apply(sdfg)
                outer_map.schedule = original_schedule

                # apply to the new map the schedule of the original one
                map_entry.schedule = self.schedule

                # if tile stride is 1, we can make a nice simplification by just
                # taking the overapproximated inner range as inner range
                # this eliminates the min/max in the range which
                # enables loop unrolling
                if tile_stride == 1:
                    map_entry.range[dim_idx] = tuple(
                        symbolic.SymExpr(el._approx_expr) if isinstance(
                            el, symbolic.SymExpr) else el
                        for el in map_entry.range[dim_idx])

                # in map_entry: enlarge tiles by upper and lower offset
                # doing it this way and not via stripmine strides ensures
                # that the max gets changed as well
                old_range = map_entry.range[dim_idx]
                map_entry.range[dim_idx] = ((old_range[0] -
                                             self.tile_offset_lower[-1]),
                                            (old_range[1] +
                                             self.tile_offset_upper[-1]),
                                            old_range[2])

                # We have to propagate here for correct outer volume and subset sizes
                _propagate_node(graph, map_entry)
                _propagate_node(graph, graph.exit_node(map_entry))

                # usual tiling pipeline
                # Collapse the outer map created for this dimension into the
                # outer map of the previous dimension
                if last_map_entry:
                    new_map_entry = graph.in_edges(map_entry)[0].src
                    mapcollapse_subgraph = {
                        MapCollapse._outer_map_entry:
                        graph.node_id(last_map_entry),
                        MapCollapse._inner_map_entry:
                        graph.node_id(new_map_entry)
                    }
                    mapcollapse = MapCollapse(sdfg_id, self.state_id,
                                              mapcollapse_subgraph, 0)
                    mapcollapse.apply(sdfg)
                last_map_entry = graph.in_edges(map_entry)[0].src
            # add last instance of map entries to _outer_entries
            if last_map_entry:
                self._outer_entries.add(last_map_entry)

            # Map Unroll Feature: only unroll if conditions are met:
            # Only unroll if at least one of the inner map ranges is strictly larger than 1
            # Only unroll if strides all are one
            if self.unroll_loops and all(s == 1 for s in self.strides) and any(
                    s not in [0, 1] for s in map_entry.range.size()):
                l = len(map_entry.params)
                if l > 1:
                    subgraph = {
                        MapExpansion.map_entry: graph.nodes().index(map_entry)
                    }
                    trafo_expansion = MapExpansion(sdfg.sdfg_id,
                                                   sdfg.nodes().index(graph),
                                                   subgraph, 0)
                    trafo_expansion.apply(sdfg)
                # Collect the chain of nested map entries by following the
                # first outgoing edge of each entry
                maps = [map_entry]
                for _ in range(l - 1):
                    map_entry = graph.out_edges(map_entry)[0].dst
                    maps.append(map_entry)

                # Innermost map first
                for map in reversed(maps):
                    # MapToForLoop
                    subgraph = {
                        MapToForLoop._map_entry: graph.nodes().index(map)
                    }
                    trafo_for_loop = MapToForLoop(sdfg.sdfg_id,
                                                  sdfg.nodes().index(graph),
                                                  subgraph, 0)
                    trafo_for_loop.apply(sdfg)
                    nsdfg = trafo_for_loop.nsdfg

                    # LoopUnroll

                    guard = trafo_for_loop.guard
                    end = trafo_for_loop.after_state
                    # The loop body is the guard successor that is not the
                    # state after the loop
                    begin = next(e.dst for e in nsdfg.out_edges(guard)
                                 if e.dst != end)

                    subgraph = {
                        DetectLoop._loop_guard: nsdfg.nodes().index(guard),
                        DetectLoop._loop_begin: nsdfg.nodes().index(begin),
                        DetectLoop._exit_state: nsdfg.nodes().index(end)
                    }
                    transformation = LoopUnroll(0, 0, subgraph, 0)
                    transformation.apply(nsdfg)
            elif self.unroll_loops:
                warnings.warn(
                    "Did not unroll loops. Either all ranges are equal to "
                    "one or range difference is symbolic.")

        self._outer_entries = list(self._outer_entries)