Exemplo n.º 1
0
    def __stripmine(self, sdfg, graph, candidate):
        """Tile every dimension of the matched map by nesting it in a new map.

        For each parameter ``p`` of the matched map, an outer parameter
        ``<prefix>_<p>`` ranging over the tiles is created, the range of the
        matched map is rewritten to cover a single tile, and every memlet in
        the map scope has ``p`` substituted by ``p + <prefix>_<p> * tile_size``.
        The new outer map entry/exit are then wired around the matched map,
        with one ``IN_k``/``OUT_k`` connector pair per non-scalar data
        container crossing the scope.

        If fewer tile sizes than map dimensions are given, the last tile size
        is reused for the remaining dimensions.

        :param sdfg: SDFG whose ``arrays`` descriptors are looked up by memlet
                     data name.
        :param graph: Graph (state) that contains the matched map.
        :param candidate: Mapping from pattern nodes to indices into
                          ``graph.nodes()``.
        :return: Tuple ``(target_dim, new_dim, new_map)`` where ``target_dim``
                 and ``new_dim`` refer to the last dimension processed and
                 ``new_map`` is the newly created outer map.
        :raises ValueError: If a range memlet contains an entry that does not
                            have exactly three components.
        :raises NotImplementedError: If a memlet subset is neither ``Indices``
                                     nor ``Range``.
        """
        # Retrieve map entry and exit nodes.
        map_entry = graph.nodes()[candidate[OrthogonalTiling._map_entry]]
        map_exit = graph.exit_nodes(map_entry)[0]

        # Map subgraph
        map_subgraph = graph.scope_subgraph(map_entry)

        # Retrieve transformation properties.
        prefix = self.prefix
        tile_sizes = self.tile_sizes
        divides_evenly = self.divides_evenly

        new_param = []
        new_range = []

        for dim_idx in range(len(map_entry.map.params)):

            # Reuse the last tile size when fewer sizes than dimensions exist.
            if dim_idx >= len(tile_sizes):
                tile_size = tile_sizes[-1]
            else:
                tile_size = tile_sizes[dim_idx]

            # Retrieve parameter and range of dimension to be strip-mined.
            target_dim = map_entry.map.params[dim_idx]
            td_from, td_to, td_step = map_entry.map.range[dim_idx]

            new_dim = prefix + '_' + target_dim

            # Number of tiles: exact division if the tile size is known to
            # divide the range, otherwise a ceiling division.
            if divides_evenly:
                tile_num = '(%s + 1 - %s) / %s' % (symbolic.symstr(td_to),
                                                   symbolic.symstr(td_from),
                                                   str(tile_size))
            else:
                tile_num = 'int_ceil((%s + 1 - %s), %s)' % (symbolic.symstr(
                    td_to), symbolic.symstr(td_from), str(tile_size))

            # Outer map values (over all tiles)
            nd_from = 0
            nd_to = symbolic.pystr_to_symbolic(str(tile_num) + ' - 1')
            nd_step = 1

            # Inner map values (over one tile)
            td_from_new = dace.symbolic.pystr_to_symbolic(td_from)
            td_to_new_exact = symbolic.pystr_to_symbolic(
                'min(%s + 1 - %s * %s, %s + %s) - 1' %
                (symbolic.symstr(td_to), str(new_dim), str(tile_size),
                 td_from_new, str(tile_size)))
            td_to_new_approx = symbolic.pystr_to_symbolic(
                '%s + %s - 1' % (td_from_new, str(tile_size)))

            # Outer map (over all tiles)
            new_dim_range = (nd_from, nd_to, nd_step)
            new_param.append(new_dim)
            new_range.append(new_dim_range)

            # Inner map (over one tile): the exact bound is only needed when
            # the last tile may be partial.
            if divides_evenly:
                td_to_new = td_to_new_approx
            else:
                td_to_new = dace.symbolic.SymExpr(td_to_new_exact,
                                                  td_to_new_approx)
            map_entry.map.range[dim_idx] = (td_from_new, td_to_new, td_step)

            # Fix subgraph memlets: shift every use of the tiled parameter by
            # the offset of the current tile.
            target_dim = dace.symbolic.pystr_to_symbolic(target_dim)
            offset = dace.symbolic.pystr_to_symbolic(
                '%s * %s' % (new_dim, str(tile_size)))
            for _, _, _, _, memlet in map_subgraph.edges():
                old_subset = memlet.subset
                if isinstance(old_subset, dace.subsets.Indices):
                    new_indices = []
                    for idx in old_subset:
                        new_idx = idx.subs(target_dim, target_dim + offset)
                        new_indices.append(new_idx)
                    memlet.subset = dace.subsets.Indices(new_indices)
                elif isinstance(old_subset, dace.subsets.Range):
                    new_ranges = []
                    for i, old_range in enumerate(old_subset):
                        if len(old_range) == 3:
                            b, e, s, = old_range
                            t = old_subset.tile_sizes[i]
                        else:
                            raise ValueError(
                                'Range %s is invalid.' % old_range)
                        new_b = b.subs(target_dim, target_dim + offset)
                        new_e = e.subs(target_dim, target_dim + offset)
                        new_s = s.subs(target_dim, target_dim + offset)
                        new_t = t.subs(target_dim, target_dim + offset)
                        new_ranges.append((new_b, new_e, new_s, new_t))
                    memlet.subset = dace.subsets.Range(new_ranges)
                else:
                    raise NotImplementedError

        new_map = nodes.Map(prefix + '_' + map_entry.map.label, new_param,
                            subsets.Range(new_range))
        new_map_entry = nodes.MapEntry(new_map)
        new_exit = nodes.MapExit(new_map)

        # Reset the inner map's schedule to the default schedule type.
        map_entry.map._schedule = dtypes.ScheduleType.Default

        # Redirect/create edges.
        # Entry side: one merged memlet per non-scalar data container, keyed
        # by data name. Each value is
        # (src, src_conn, dst, dst_conn, memlet, connector number).
        new_in_edges = {}
        for _src, conn, _dest, _, memlet in graph.out_edges(map_entry):
            if not isinstance(sdfg.arrays[memlet.data], dace.data.Scalar):
                new_subset = copy.deepcopy(memlet.subset)
                if memlet.data in new_in_edges:
                    src, src_conn, dest, dest_conn, new_memlet, num = \
                        new_in_edges[memlet.data]
                    # The array descriptor must be looked up through
                    # ``new_memlet`` (the previous ``nnew_memlet`` spelling
                    # was an undefined name and raised NameError).
                    new_memlet.subset = calc_set_union(
                        new_memlet.data, sdfg.arrays[new_memlet.data],
                        new_memlet.subset, new_subset)
                    new_memlet.num_accesses = new_memlet.num_elements()
                    # ``conn`` is the map entry's out connector; the slice
                    # drops its first four characters (the 'OUT_' prefix).
                    new_in_edges[memlet.data] = (src, src_conn, dest,
                                                 dest_conn, new_memlet,
                                                 min(num, int(conn[4:])))
                else:
                    new_memlet = dcpy(memlet)
                    new_memlet.subset = new_subset
                    new_memlet.num_accesses = new_memlet.num_elements()
                    new_in_edges[memlet.data] = (new_map_entry, None,
                                                 map_entry, None, new_memlet,
                                                 int(conn[4:]))
        nxutil.change_edge_dest(graph, map_entry, new_map_entry)

        # Exit side: same merging scheme for edges entering the map exit.
        new_out_edges = {}
        for _src, conn, _dest, _, memlet in graph.in_edges(map_exit):
            if not isinstance(sdfg.arrays[memlet.data], dace.data.Scalar):
                new_subset = memlet.subset
                if memlet.data in new_out_edges:
                    src, src_conn, dest, dest_conn, new_memlet, num = \
                        new_out_edges[memlet.data]
                    new_memlet.subset = calc_set_union(
                        new_memlet.data, sdfg.arrays[new_memlet.data],
                        new_memlet.subset, new_subset)
                    new_memlet.num_accesses = new_memlet.num_elements()
                    # NOTE(review): unlike the entry side, the connector
                    # suffix is kept as a string here, so ``min`` compares
                    # strings. Left unchanged; confirm whether numeric
                    # comparison is intended.
                    new_out_edges[memlet.data] = (src, src_conn, dest,
                                                  dest_conn, new_memlet,
                                                  min(num, conn[4:]))
                else:
                    new_memlet = dcpy(memlet)
                    new_memlet.subset = new_subset
                    new_memlet.num_accesses = new_memlet.num_elements()
                    new_out_edges[memlet.data] = (map_exit, None, new_exit,
                                                  None, new_memlet, conn[4:])
        nxutil.change_edge_src(graph, map_exit, new_exit)

        # Connector related work follows
        # 1. Dictionary 'old_connector_number': 'new_connector_number'
        # 2. New node in/out connectors
        # 3. New edges

        in_conn_nums = [num for _, _, _, _, _, num in new_in_edges.values()]
        in_conn = {num: i + 1 for i, num in enumerate(in_conn_nums)}

        new_map_entry.in_connectors = {
            'IN_' + str(i + 1)
            for i in range(len(in_conn_nums))
        }
        new_map_entry.out_connectors = {
            'OUT_' + str(i + 1)
            for i in range(len(in_conn_nums))
        }

        for src, _, dst, _, memlet, num in new_in_edges.values():
            graph.add_edge(src, 'OUT_' + str(in_conn[num]), dst,
                           'IN_' + str(in_conn[num]), memlet)

        out_conn_nums = [
            num for _, _, dst, _, _, num in new_out_edges.values()
            if dst is new_exit
        ]
        out_conn = {num: i + 1 for i, num in enumerate(out_conn_nums)}

        new_exit.in_connectors = {
            'IN_' + str(i + 1)
            for i in range(len(out_conn_nums))
        }
        new_exit.out_connectors = {
            'OUT_' + str(i + 1)
            for i in range(len(out_conn_nums))
        }

        for src, _, dst, _, memlet, num in new_out_edges.values():
            graph.add_edge(src, 'OUT_' + str(out_conn[num]), dst,
                           'IN_' + str(out_conn[num]), memlet)

        # Return strip-mined dimension.
        return target_dim, new_dim, new_map
Exemplo n.º 2
0
    def apply(self, sdfg):
        """Fuse the matched map and reduction in state ``self.state_id``.

        The behavior depends on ``self.expr_index``:

        * ``0`` -- reduce without an outer map,
        * ``1`` -- reduce with an outer map,
        * ``2`` -- a reduce node; this case rewires the graph and returns
          early.

        For cases 0 and 1 the intermediate array node is removed, a new
        transient array named ``'tmp'`` is inserted between the tasklet map
        and the reduction tasklet, and the reduction memlets are redirected
        to the output array.

        :param sdfg: SDFG containing the matched state; its ``arrays`` are
                     read to obtain the dtype of the intermediate data.
        :raises RuntimeError: If no edge into the tasklet map exit carries the
                              intermediate array, or if the reduction tasklet
                              does not have exactly one incoming edge.
        """
        def gnode(nname):
            """Return the graph node matched to pattern node ``nname``."""
            return graph.nodes()[self.subgraph[nname]]

        expr_index = self.expr_index
        graph = sdfg.nodes()[self.state_id]
        # NOTE(review): ``tasklet`` and ``rmap_entry`` are assigned but not
        # read anywhere in this method.
        tasklet = gnode(MapReduceFusion._tasklet)
        tmap_exit = graph.nodes()[self.subgraph[MapReduceFusion._tmap_exit]]
        in_array = graph.nodes()[self.subgraph[MapReduceFusion._in_array]]
        if expr_index == 0:  # Reduce without outer map
            rmap_entry = graph.nodes()[self.subgraph[
                MapReduceFusion._rmap_in_entry]]
        elif expr_index == 1:  # Reduce with outer map
            rmap_out_entry = graph.nodes()[self.subgraph[
                MapReduceFusion._rmap_out_entry]]
            rmap_out_exit = graph.nodes()[self.subgraph[
                MapReduceFusion._rmap_out_exit]]
            rmap_in_entry = graph.nodes()[self.subgraph[
                MapReduceFusion._rmap_in_entry]]
            # NOTE(review): ``rmap_tasklet`` is only bound in this branch but
            # is used below on the path shared by expr_index 0 and 1; with
            # expr_index == 0 that use would raise NameError -- confirm.
            rmap_tasklet = graph.nodes()[self.subgraph[
                MapReduceFusion._rmap_in_tasklet]]

        if expr_index == 2:
            rmap_cr = graph.nodes()[self.subgraph[MapReduceFusion._reduce]]
        else:
            rmap_cr = graph.nodes()[self.subgraph[MapReduceFusion._rmap_in_cr]]
        out_array = gnode(MapReduceFusion._out_array)

        # Set nodes to remove according to the expression index
        # (the intermediate array is always first in the list).
        nodes_to_remove = [in_array]
        if expr_index == 0:
            nodes_to_remove.append(gnode(MapReduceFusion._rmap_in_entry))
        elif expr_index == 1:
            nodes_to_remove.append(gnode(MapReduceFusion._rmap_out_entry))
            nodes_to_remove.append(gnode(MapReduceFusion._rmap_in_entry))
            nodes_to_remove.append(gnode(MapReduceFusion._rmap_out_exit))
        else:
            nodes_to_remove.append(gnode(MapReduceFusion._reduce))

        # If no other edges lead to mapexit, remove it. Otherwise, keep
        # it and remove reduction incoming/outgoing edges
        if expr_index != 2 and len(graph.in_edges(tmap_exit)) == 1:
            nodes_to_remove.append(tmap_exit)

        # Find the edge into the tasklet map exit that writes the
        # intermediate array.
        memlet_edge = None
        for edge in graph.in_edges(tmap_exit):
            if edge.data.data == in_array.data:
                memlet_edge = edge
                break
        if memlet_edge is None:
            raise RuntimeError('Reduction memlet cannot be None')

        if expr_index == 0:  # Reduce without outer map
            # Index order does not matter, merge as-is
            pass
        elif expr_index == 1:  # Reduce with outer map
            tmap = tmap_exit.map
            perm_outer, perm_inner = MapReduceFusion.find_permutation(
                tmap, rmap_out_entry.map, rmap_in_entry.map, memlet_edge.data)

            # Split tasklet map into tmap_out -> tmap_in (according to
            # reduction)
            omap = nodes.Map(
                tmap.label + '_nonreduce',
                [p for i, p in enumerate(tmap.params) if i in perm_outer],
                [r for i, r in enumerate(tmap.range) if i in perm_outer],
                tmap.schedule, tmap.unroll, tmap.is_async)
            tmap.params = [
                p for i, p in enumerate(tmap.params) if i in perm_inner
            ]
            tmap.range = [
                r for i, r in enumerate(tmap.range) if i in perm_inner
            ]
            omap_entry = nodes.MapEntry(omap)
            # The outer reduction map exit is reused as the exit of the new
            # outer map.
            omap_exit = rmap_out_exit
            rmap_out_exit.map = omap

            # Reconnect graph to new map
            tmap_entry = graph.entry_node(tmap_exit)
            tmap_in_edges = list(graph.in_edges(tmap_entry))
            # NOTE(review): the loop variable is unused; the same call is
            # repeated once per original incoming edge.
            for e in tmap_in_edges:
                nxutil.change_edge_dest(graph, tmap_entry, omap_entry)
            # Connect the new outer entry to the tasklet map entry, reusing
            # the connectors and a shallow copy of each original memlet.
            for e in tmap_in_edges:
                graph.add_edge(omap_entry, e.src_conn, tmap_entry, e.dst_conn,
                               copy.copy(e.data))
        elif expr_index == 2:  # Reduce node
            # Find correspondence between map indices and array outputs
            tmap = tmap_exit.map
            perm = MapReduceFusion.find_permutation_reduce(
                tmap, rmap_cr, graph, memlet_edge.data)

            output_subset = [tmap.params[d] for d in perm]
            if len(output_subset) == 0:  # Output is a scalar
                output_subset = [0]

            # Capture the reduce node's outgoing edge before the node is
            # removed below.
            array_edge = graph.out_edges(rmap_cr)[0]

            # Delete relevant edges and nodes
            graph.remove_edge(memlet_edge)
            graph.remove_nodes_from(nodes_to_remove)

            # Add new edges and nodes
            #   From tasklet to map exit
            graph.add_edge(
                memlet_edge.src, memlet_edge.src_conn, memlet_edge.dst,
                memlet_edge.dst_conn,
                Memlet(out_array.data, memlet_edge.data.num_accesses,
                       subsets.Indices(output_subset), memlet_edge.data.veclen,
                       rmap_cr.wcr, rmap_cr.identity))

            #   From map exit to output array
            #   (the out connector reuses the suffix of the in connector,
            #   i.e. everything after its first three characters)
            graph.add_edge(
                memlet_edge.dst, 'OUT_' + memlet_edge.dst_conn[3:],
                array_edge.dst, array_edge.dst_conn,
                Memlet(array_edge.data.data, array_edge.data.num_accesses,
                       array_edge.data.subset, array_edge.data.veclen,
                       rmap_cr.wcr, rmap_cr.identity))

            return

        # Remove tmp array node prior to the others, so that a new one
        # can be created in its stead (see below)
        graph.remove_node(nodes_to_remove[0])
        nodes_to_remove = nodes_to_remove[1:]

        # Create tasklet -> tmp -> tasklet connection
        tmp = graph.add_array(
            'tmp',
            memlet_edge.data.subset.bounding_box_size(),
            sdfg.arrays[memlet_edge.data.data].dtype,
            transient=True)
        tasklet_tmp_memlet = copy.deepcopy(memlet_edge.data)
        tasklet_tmp_memlet.data = tmp.data
        tasklet_tmp_memlet.subset = ShapeProperty.to_string(tmp.shape)

        # Modify memlet to point to output array
        memlet_edge.data.data = out_array.data

        # Recover reduction axes from CR reduce subset: an index whose string
        # form contains '__i' is treated as a reduction axis.
        reduce_cr_subset = graph.in_edges(rmap_tasklet)[0].data.subset
        reduce_axes = []
        for ind, crvar in enumerate(reduce_cr_subset.indices):
            if '__i' in str(crvar):
                reduce_axes.append(ind)

        # Modify memlet access index by filtering out reduction axes
        if True:  # expr_index == 0:
            newindices = []
            for ind, ovar in enumerate(memlet_edge.data.subset.indices):
                if ind not in reduce_axes:
                    newindices.append(ovar)
        # If every axis is reduced, index element 0.
        if len(newindices) == 0:
            newindices = [0]

        memlet_edge.data.subset = subsets.Indices(newindices)

        graph.remove_edge(memlet_edge)

        # Route the original producer into the new transient array.
        graph.add_edge(memlet_edge.src, memlet_edge.src_conn, tmp,
                       memlet_edge.dst_conn, tasklet_tmp_memlet)

        red_edges = list(graph.in_edges(rmap_tasklet))
        if len(red_edges) != 1:
            raise RuntimeError('CR edge must be unique')

        # Feed the transient array into the reduction tasklet through the
        # connector of its single existing input edge.
        tmp_tasklet_memlet = copy.deepcopy(tasklet_tmp_memlet)
        graph.add_edge(tmp, None, rmap_tasklet, red_edges[0].dst_conn,
                       tmp_tasklet_memlet)

        for e in graph.edges_between(rmap_tasklet, rmap_cr):
            e.data.subset = memlet_edge.data.subset

        # Move output edges to point directly to CR node
        if expr_index == 1:
            # Set output memlet between CR node and outer reduction map to
            # contain the same subset as the one pointing to the CR node
            for e in graph.out_edges(rmap_cr):
                e.data.subset = memlet_edge.data.subset

            rmap_out = gnode(MapReduceFusion._rmap_out_exit)
            nxutil.change_edge_src(graph, rmap_out, omap_exit)

        # Remove nodes
        graph.remove_nodes_from(nodes_to_remove)

        # For unrelated outputs, connect original output to rmap_out
        if expr_index == 1 and tmap_exit not in nodes_to_remove:
            other_out_edges = list(graph.out_edges(tmap_exit))
            for e in other_out_edges:
                graph.remove_edge(e)
                graph.add_edge(e.src, e.src_conn, omap_exit, None, e.data)
                graph.add_edge(omap_exit, None, e.dst, e.dst_conn,
                               copy.copy(e.data))
Exemplo n.º 3
0
    def _stripmine(self, sdfg, graph, candidate):
        """Strip-mine one dimension of the matched map.

        A new outer map with a single parameter iterating over tiles is
        created around the matched map, and the range of dimension
        ``self.dim_idx`` of the matched map is rewritten to cover one tile
        (or, when ``self.strided`` is set, to start at the new parameter and
        step by the tile size). Edges previously attached to the matched map's
        entry/exit are moved to the new outer entry/exit, and new edges are
        created between the outer and inner scope nodes.

        :param sdfg: SDFG whose ``arrays`` descriptors are looked up by memlet
                     data name.
        :param graph: Graph (state) that contains the matched map.
        :param candidate: Mapping from pattern nodes to indices into
                          ``graph.nodes()``.
        :return: Tuple ``(target_dim, new_dim, new_map)``: the strip-mined
                 parameter, the new outer parameter and the new outer map.
        """

        # Retrieve map entry and exit nodes.
        map_entry = graph.nodes()[candidate[StripMining._map_entry]]
        map_exit = graph.exit_nodes(map_entry)[0]

        # Retrieve transformation properties.
        dim_idx = self.dim_idx
        new_dim_prefix = self.new_dim_prefix
        tile_size = self.tile_size
        divides_evenly = self.divides_evenly
        strided = self.strided

        # The stride between tiles defaults to the tile size.
        tile_stride = self.tile_stride
        if tile_stride is None or len(tile_stride) == 0:
            tile_stride = tile_size

        # Retrieve parameter and range of dimension to be strip-mined.
        target_dim = map_entry.map.params[dim_idx]
        td_from, td_to, td_step = map_entry.map.range[dim_idx]

        # Create new map. Replace by cloning???
        new_dim = self._find_new_dim(sdfg, graph, map_entry, new_dim_prefix,
                                     target_dim)
        nd_from = 0
        nd_to = symbolic.pystr_to_symbolic(
            'int_ceil(%s + 1 - %s, %s) - 1' %
            (symbolic.symstr(td_to), symbolic.symstr(td_from), tile_stride))
        nd_step = 1
        new_dim_range = (nd_from, nd_to, nd_step)
        new_map = nodes.Map(new_dim + '_' + map_entry.map.label, [new_dim],
                            subsets.Range([new_dim_range]))
        new_map_entry = nodes.MapEntry(new_map)
        new_map_exit = nodes.MapExit(new_map)

        # Change the range of the selected dimension to iterate over a single
        # tile
        if strided:
            td_from_new = symbolic.pystr_to_symbolic(new_dim)
            td_to_new_approx = td_to
            td_step = symbolic.pystr_to_symbolic(tile_size)
        else:
            td_from_new = symbolic.pystr_to_symbolic(
                '%s + %s * %s' %
                (symbolic.symstr(td_from), str(new_dim), tile_stride))
            td_to_new_exact = symbolic.pystr_to_symbolic(
                'min(%s + 1, %s + %s * %s + %s) - 1' %
                (symbolic.symstr(td_to), symbolic.symstr(td_from), tile_stride,
                 str(new_dim), tile_size))
            td_to_new_approx = symbolic.pystr_to_symbolic(
                '%s + %s * %s + %s - 1' %
                (symbolic.symstr(td_from), tile_stride, str(new_dim),
                 tile_size))
        # The exact (min-clamped) bound only exists in the non-strided case
        # and is only needed when the last tile may be partial.
        if divides_evenly or strided:
            td_to_new = td_to_new_approx
        else:
            td_to_new = dace.symbolic.SymExpr(td_to_new_exact,
                                              td_to_new_approx)
        map_entry.map.range[dim_idx] = (td_from_new, td_to_new, td_step)

        # The outer map inherits the original schedule; the inner map becomes
        # sequential.
        new_map.schedule = map_entry.map.schedule
        map_entry.map.schedule = dtypes.ScheduleType.Sequential

        # Redirect edges
        new_map_entry.in_connectors = dcpy(map_entry.in_connectors)
        nxutil.change_edge_dest(graph, map_entry, new_map_entry)
        new_map_exit.out_connectors = dcpy(map_exit.out_connectors)
        nxutil.change_edge_src(graph, map_exit, new_map_exit)

        # Create new entry edges, keyed by (data, in connector, out connector)
        new_in_edges = dict()
        entry_in_conn = set()
        entry_out_conn = set()
        for _src, src_conn, _dst, _, memlet in graph.out_edges(map_entry):
            if (src_conn is not None
                    and src_conn[:4] == 'OUT_' and not isinstance(
                        sdfg.arrays[memlet.data], dace.data.Scalar)):
                new_subset = calc_set_image(
                    map_entry.map.params,
                    map_entry.map.range,
                    memlet.subset,
                )
                conn = src_conn[4:]
                key = (memlet.data, 'IN_' + conn, 'OUT_' + conn)
                if key in new_in_edges:
                    # Same data through the same connector: merge subsets.
                    old_subset = new_in_edges[key].subset
                    new_in_edges[key].subset = calc_set_union(
                        old_subset, new_subset)
                else:
                    entry_in_conn.add('IN_' + conn)
                    entry_out_conn.add('OUT_' + conn)
                    new_memlet = dcpy(memlet)
                    new_memlet.subset = new_subset
                    new_memlet.num_accesses = new_memlet.num_elements()
                    new_in_edges[key] = new_memlet
            else:
                # Scalars and edges without an 'OUT_' connector are forwarded
                # with an unmodified copy of the memlet.
                if src_conn is not None and src_conn[:4] == 'OUT_':
                    conn = src_conn[4:]
                    in_conn = 'IN_' + conn
                    out_conn = 'OUT_' + conn
                else:
                    in_conn = src_conn
                    out_conn = src_conn
                if in_conn:
                    entry_in_conn.add(in_conn)
                if out_conn:
                    entry_out_conn.add(out_conn)
                new_in_edges[(memlet.data, in_conn, out_conn)] = dcpy(memlet)
        new_map_entry.out_connectors = entry_out_conn
        map_entry.in_connectors = entry_in_conn
        for (_, in_conn, out_conn), memlet in new_in_edges.items():
            graph.add_edge(new_map_entry, out_conn, map_entry, in_conn, memlet)

        # Create new exit edges, keyed by (data, in connector, out connector)
        new_out_edges = dict()
        exit_in_conn = set()
        exit_out_conn = set()
        for _src, _, _dst, dst_conn, memlet in graph.in_edges(map_exit):
            if (dst_conn is not None
                    and dst_conn[:3] == 'IN_' and not isinstance(
                        sdfg.arrays[memlet.data], dace.data.Scalar)):
                new_subset = calc_set_image(
                    map_entry.map.params,
                    map_entry.map.range,
                    memlet.subset,
                )
                conn = dst_conn[3:]
                key = (memlet.data, 'IN_' + conn, 'OUT_' + conn)
                if key in new_out_edges:
                    old_subset = new_out_edges[key].subset
                    new_out_edges[key].subset = calc_set_union(
                        old_subset, new_subset)
                else:
                    exit_in_conn.add('IN_' + conn)
                    exit_out_conn.add('OUT_' + conn)
                    new_memlet = dcpy(memlet)
                    new_memlet.subset = new_subset
                    new_memlet.num_accesses = new_memlet.num_elements()
                    new_out_edges[key] = new_memlet
            else:
                if dst_conn is not None and dst_conn[:3] == 'IN_':
                    conn = dst_conn[3:]
                    in_conn = 'IN_' + conn
                    out_conn = 'OUT_' + conn
                else:
                    # Use this edge's own connector. The previous code read
                    # ``src_conn``, a stale variable left over from the entry
                    # loop above.
                    in_conn = dst_conn
                    out_conn = dst_conn
                if in_conn:
                    exit_in_conn.add(in_conn)
                if out_conn:
                    exit_out_conn.add(out_conn)
                # Record in the exit-edge table. Storing into
                # ``new_in_edges`` here meant the edge was never re-created,
                # because the entry edges had already been added to the graph.
                new_out_edges[(memlet.data, in_conn, out_conn)] = dcpy(memlet)
        new_map_exit.in_connectors = exit_in_conn
        map_exit.out_connectors = exit_out_conn
        for (_, in_conn, out_conn), memlet in new_out_edges.items():
            graph.add_edge(map_exit, out_conn, new_map_exit, in_conn, memlet)

        # Return strip-mined dimension.
        return target_dim, new_dim, new_map
Exemplo n.º 4
0
    def apply(self, sdfg):
        """Move the matched state's input/output arrays to FPGA global memory.

        For every source and sink access node of the matched state that refers
        to an ``Array``, a transient ``'fpga_' + name`` array with
        ``FPGA_Global`` storage is created (once per data name). A
        ``'pre_' + label`` state copying each input into its FPGA array is
        inserted before the matched state, and a ``'post_' + label`` state
        copying each FPGA array back is inserted after it. Memlets inside the
        matched state that refer to moved data are renamed to the FPGA array,
        and ``fpga_update`` is finally called on the state.

        :param sdfg: SDFG that contains the matched state and receives the new
                     arrays and states.
        """
        state = sdfg.nodes()[self.subgraph[FPGATransformState._state]]

        # Find source/sink (data) nodes
        input_nodes = nxutil.find_source_nodes(state)
        output_nodes = nxutil.find_sink_nodes(state)

        # Maps original data name -> FPGA array created for it.
        fpga_data = {}

        # Input nodes may also be nodes with WCR memlets
        # We have to recur across nested SDFGs to find them
        wcr_input_nodes = set()
        # NOTE(review): ``stack`` is never used in this method.
        stack = []

        parent_sdfg = {state: sdfg}  # Map states to their parent SDFG
        for node, graph in state.all_nodes_recursive():
            if isinstance(graph, dace.SDFG):
                parent_sdfg[node] = graph
            if isinstance(node, dace.graph.nodes.AccessNode):
                for e in graph.all_edges(node):
                    if e.data.wcr is not None:
                        trace = dace.sdfg.trace_nested_access(
                            node, graph, parent_sdfg[graph])
                        for node_trace, state_trace, sdfg_trace in trace:
                            # Find the name of the accessed node in our scope
                            if state_trace == state and sdfg_trace == sdfg:
                                outer_name = node_trace.data
                                break
                        else:
                            # This does not trace back to the current state, so
                            # we don't care
                            continue
                        # NOTE(review): what is appended here is the ``.data``
                        # attribute of the traced node, not the node itself.
                        # The loop below skips every entry that is not an
                        # AccessNode and tests ``node not in wcr_input_nodes``
                        # against nodes -- presumably the node was meant to be
                        # stored; confirm.
                        input_nodes.append(outer_name)
                        wcr_input_nodes.add(outer_name)

        if input_nodes:
            # create pre_state
            pre_state = sd.SDFGState('pre_' + state.label, sdfg)

            for node in input_nodes:

                if not isinstance(node, dace.graph.nodes.AccessNode):
                    continue
                desc = node.desc(sdfg)
                if not isinstance(desc, dace.data.Array):
                    # TODO: handle streams
                    continue

                # Create the FPGA-side array once per data name, mirroring the
                # original descriptor.
                if node.data in fpga_data:
                    fpga_array = fpga_data[node.data]
                elif node not in wcr_input_nodes:
                    fpga_array = sdfg.add_array(
                        'fpga_' + node.data,
                        desc.shape,
                        desc.dtype,
                        materialize_func=desc.materialize_func,
                        transient=True,
                        storage=dtypes.StorageType.FPGA_Global,
                        allow_conflicts=desc.allow_conflicts,
                        strides=desc.strides,
                        offset=desc.offset)
                    fpga_data[node.data] = fpga_array

                # Copy the whole array (range 0..s-1 in every dimension) to
                # the FPGA array inside the pre-state.
                pre_node = pre_state.add_read(node.data)
                pre_fpga_node = pre_state.add_write('fpga_' + node.data)
                full_range = subsets.Range([(0, s - 1, 1) for s in desc.shape])
                mem = memlet.Memlet(node.data, full_range.num_elements(),
                                    full_range, 1)
                pre_state.add_edge(pre_node, None, pre_fpga_node, None, mem)

                # Replace the original access node in the state with a read
                # of the FPGA array.
                if node not in wcr_input_nodes:
                    fpga_node = state.add_read('fpga_' + node.data)
                    nxutil.change_edge_src(state, node, fpga_node)
                    state.remove_node(node)

            # Insert the pre-state between the state's predecessors and the
            # state itself.
            sdfg.add_node(pre_state)
            nxutil.change_edge_dest(sdfg, state, pre_state)
            sdfg.add_edge(pre_state, state, edges.InterstateEdge())

        if output_nodes:

            post_state = sd.SDFGState('post_' + state.label, sdfg)

            for node in output_nodes:

                if not isinstance(node, dace.graph.nodes.AccessNode):
                    continue
                desc = node.desc(sdfg)
                if not isinstance(desc, dace.data.Array):
                    # TODO: handle streams
                    continue

                # Reuse the FPGA array if the input pass already created it.
                if node.data in fpga_data:
                    fpga_array = fpga_data[node.data]
                else:
                    fpga_array = sdfg.add_array(
                        'fpga_' + node.data,
                        desc.shape,
                        desc.dtype,
                        materialize_func=desc.materialize_func,
                        transient=True,
                        storage=dtypes.StorageType.FPGA_Global,
                        allow_conflicts=desc.allow_conflicts,
                        strides=desc.strides,
                        offset=desc.offset)
                    fpga_data[node.data] = fpga_array
                # fpga_node = type(node)(fpga_array)

                # Copy the whole FPGA array back to the original array inside
                # the post-state.
                post_node = post_state.add_write(node.data)
                post_fpga_node = post_state.add_read('fpga_' + node.data)
                full_range = subsets.Range([(0, s - 1, 1) for s in desc.shape])
                mem = memlet.Memlet('fpga_' + node.data,
                                    full_range.num_elements(), full_range, 1)
                post_state.add_edge(post_fpga_node, None, post_node, None, mem)

                # Replace the original sink node in the state with a write to
                # the FPGA array.
                fpga_node = state.add_write('fpga_' + node.data)
                nxutil.change_edge_dest(state, node, fpga_node)
                state.remove_node(node)

            # Insert the post-state between the state and its successors.
            sdfg.add_node(post_state)
            nxutil.change_edge_src(sdfg, state, post_state)
            sdfg.add_edge(state, post_state, edges.InterstateEdge())

        # Vector length applied to renamed memlets. It is not reset per edge,
        # so a value found for one edge carries over to subsequent edges.
        veclen_ = 1

        # propagate vector info from a nested sdfg
        for src, src_conn, dst, dst_conn, mem in state.edges():
            # need to go inside the nested SDFG and grab the vector length
            if isinstance(dst, dace.graph.nodes.NestedSDFG):
                # this edge is going to the nested SDFG
                for inner_state in dst.sdfg.states():
                    for n in inner_state.nodes():
                        if isinstance(n, dace.graph.nodes.AccessNode
                                      ) and n.data == dst_conn:
                            # assuming all memlets have the same vector length
                            veclen_ = inner_state.all_edges(n)[0].data.veclen
            if isinstance(src, dace.graph.nodes.NestedSDFG):
                # this edge is coming from the nested SDFG
                for inner_state in src.sdfg.states():
                    for n in inner_state.nodes():
                        if isinstance(n, dace.graph.nodes.AccessNode
                                      ) and n.data == src_conn:
                            # assuming all memlets have the same vector length
                            veclen_ = inner_state.all_edges(n)[0].data.veclen

            # Rename memlets that refer to data moved to the FPGA.
            if mem.data is not None and mem.data in fpga_data:
                mem.data = 'fpga_' + mem.data
                mem.veclen = veclen_

        fpga_update(sdfg, state, 0)
Exemplo n.º 5
0
    def apply(self, sdfg):
        """Fuse the matched second state into the matched first state.

        Inter-state edges between the two states are removed; any assignments
        they carry are first copied onto every edge entering the first state.
        If one of the states is empty it is dropped and its edges are rewired.
        Otherwise the second state's nodes and edges are copied into the
        first state and access nodes with equal labels are merged.

        :param sdfg: The SDFG containing both states; modified in place.
        """
        state_a = sdfg.nodes()[self.subgraph[StateFusion._first_state]]
        state_b = sdfg.nodes()[self.subgraph[StateFusion._second_state]]

        # Drop the connecting inter-state edge(s), hoisting their assignments
        # onto the edges that enter the first state.
        for link in sdfg.edges_between(state_a, state_b):
            if link.data.assignments:
                for _, _, incoming in sdfg.in_edges(state_a):
                    incoming.assignments.update(link.data.assignments)
            sdfg.remove_edge(link)

        # Trivial case: the first state holds nothing, keep only the second.
        if state_a.is_empty():
            nxutil.change_edge_dest(sdfg, state_a, state_b)
            sdfg.remove_node(state_a)
            return

        # Trivial case: the second state holds nothing, keep only the first.
        if state_b.is_empty():
            nxutil.change_edge_src(sdfg, state_b, state_a)
            nxutil.change_edge_dest(sdfg, state_b, state_a)
            sdfg.remove_node(state_b)
            return

        # General case: both states contain dataflow.

        def access_nodes_of(candidates):
            # Keep only access (data) nodes, preserving the given order.
            return [c for c in candidates if isinstance(c, nodes.AccessNode)]

        def same_label(pool, reference):
            # First member of `pool` labelled like `reference`, else None.
            for member in pool:
                if member.label == reference.label:
                    return member
            return None

        a_sources = access_nodes_of(nxutil.find_source_nodes(state_a))
        a_sinks = access_nodes_of(nxutil.find_sink_nodes(state_a))
        b_sources = access_nodes_of(nxutil.find_source_nodes(state_b))

        # Sources of the first state whose label does not also appear among
        # its sinks.
        a_sources = [n for n in a_sources if same_label(a_sinks, n) is None]

        # Snapshot the first state's access nodes in reverse topological
        # order before the second state's nodes are mixed in.
        topo = list(nx.topological_sort(state_a._nx))
        a_data_reversed = access_nodes_of(reversed(topo))

        # Copy the second state's nodes and edges into the first state.
        for moved in state_b.nodes():
            state_a.add_node(moved)
        for u, u_conn, v, v_conn, payload in state_b.edges():
            state_a.add_edge(u, u_conn, v, v_conn, payload)

        # A source of the first state absorbs the equally-labelled source of
        # the second state.
        for keeper in a_sources:
            duplicate = same_label(b_sources, keeper)
            if duplicate is None:
                continue
            nxutil.change_edge_src(state_a, duplicate, keeper)
            state_a.remove_node(duplicate)
            b_sources.remove(duplicate)
            keeper.access = dtypes.AccessType.ReadWrite

        # A sink of the first state is replaced by the equally-labelled
        # source of the second state.
        for sink in a_sinks:
            successor = same_label(b_sources, sink)
            if successor is None:
                continue
            nxutil.change_edge_dest(state_a, sink, successor)
            state_a.remove_node(sink)
            b_sources.remove(successor)
            successor.access = dtypes.AccessType.ReadWrite

        # Second-state sources that still have no incoming edge are attached
        # to an equally-labelled access node of the first state, if any.
        for orphan in b_sources:
            if state_a.in_degree(orphan) != 0:
                continue
            anchor = same_label(a_data_reversed, orphan)
            if anchor:
                nxutil.change_edge_src(state_a, orphan, anchor)
                state_a.remove_node(orphan)
                anchor.access = dtypes.AccessType.ReadWrite

        # Rewire the second state's outgoing edges and delete it.
        nxutil.change_edge_src(sdfg, state_b, state_a)
        sdfg.remove_node(state_b)
        if Config.get_bool("debugprint"):
            StateFusion._states_fused += 1
Exemplo n.º 6
0
    def __stripmine(self, sdfg, graph, candidate):
        """Strip-mine one dimension of the matched map.

        A new outer map over tile indices is created and placed around the
        matched map. The range of the selected dimension in the matched map
        is rewritten to cover a single tile, or, when ``self.strided`` is
        set, to run from the tile index to the original end with a step of
        ``self.tile_size``. Non-scalar memlets leaving the map entry and
        entering the map exit(s) are re-created between the new outer nodes
        and the original ones.

        :param sdfg: SDFG owning ``graph``; used to look up the data
                     descriptors referenced by memlets.
        :param graph: State that contains the map.
        :param candidate: Match dictionary, indexed with
                          ``StripMining._map_entry``.
        :return: Tuple ``(target_dim, new_dim, new_map)``: the strip-mined
                 parameter name, the new tile parameter name and the new
                 outer map.
        """

        # Retrieve map entry and exit nodes.
        map_entry = graph.nodes()[candidate[StripMining._map_entry]]
        map_exits = graph.exit_nodes(map_entry)

        # Retrieve transformation properties.
        dim_idx = self.dim_idx
        new_dim_prefix = self.new_dim_prefix
        tile_size = self.tile_size
        divides_evenly = self.divides_evenly
        strided = self.strided

        # Retrieve parameter and range of dimension to be strip-mined.
        target_dim = map_entry.map.params[dim_idx]
        td_from, td_to, td_step = map_entry.map.range[dim_idx]

        # Create the outer map, iterating over tile indices
        # 0 .. int_ceil(to + 1 - from, tile_size) - 1.
        new_dim = new_dim_prefix + '_' + target_dim
        nd_from = 0
        nd_to = symbolic.pystr_to_symbolic(
            'int_ceil(%s + 1 - %s, %s) - 1' %
            (symbolic.symstr(td_to), symbolic.symstr(td_from), tile_size))
        nd_step = 1
        new_dim_range = (nd_from, nd_to, nd_step)
        new_map = nodes.Map(new_dim + '_' + map_entry.map.label, [new_dim],
                            subsets.Range([new_dim_range]))
        new_map_entry = nodes.MapEntry(new_map)

        # Change the range of the selected dimension to iterate over a single
        # tile
        if strided:
            td_from_new = symbolic.pystr_to_symbolic(new_dim)
            td_to_new_approx = td_to
            td_step = symbolic.pystr_to_symbolic(tile_size)
        else:
            td_from_new = symbolic.pystr_to_symbolic(
                '%s + %s * %s' %
                (symbolic.symstr(td_from), str(new_dim), tile_size))
            # Exact end clamps the last tile to the original range end.
            td_to_new_exact = symbolic.pystr_to_symbolic(
                'min(%s + 1, %s + %s * %s + %s) - 1' %
                (symbolic.symstr(td_to), symbolic.symstr(td_from), tile_size,
                 str(new_dim), tile_size))
            td_to_new_approx = symbolic.pystr_to_symbolic(
                '%s + %s * %s + %s - 1' %
                (symbolic.symstr(td_from), tile_size, str(new_dim), tile_size))
        if divides_evenly or strided:
            # td_to_new_exact is only defined (and only needed) otherwise.
            td_to_new = td_to_new_approx
        else:
            td_to_new = dace.symbolic.SymExpr(td_to_new_exact,
                                              td_to_new_approx)
        map_entry.map.range[dim_idx] = (td_from_new, td_to_new, td_step)

        # Make internal map's schedule to "not parallel"
        map_entry.map._schedule = dtypes.ScheduleType.Default

        # Redirect/create edges.
        # new_in_edges maps data name -> (src, src_conn, dst, dst_conn,
        # memlet, connector number) for the edge outer entry -> inner entry.
        new_in_edges = {}
        for _src, conn, _dest, _, memlet in graph.out_edges(map_entry):
            if not isinstance(sdfg.arrays[memlet.data], dace.data.Scalar):
                new_subset = calc_set_image(
                    map_entry.map.params,
                    map_entry.map.range,
                    memlet.subset,
                )
                if memlet.data in new_in_edges:
                    # Same data seen again: widen the existing memlet.
                    src, src_conn, dest, dest_conn, new_memlet, num = \
                        new_in_edges[memlet.data]
                    new_memlet.subset = calc_set_union(new_memlet.subset,
                                                       new_subset)
                    new_memlet.num_accesses = new_memlet.num_elements()
                    new_in_edges[memlet.data] = (src, src_conn, dest,
                                                 dest_conn, new_memlet,
                                                 min(num, int(conn[4:])))
                else:
                    new_memlet = dcpy(memlet)
                    new_memlet.subset = new_subset
                    new_memlet.num_accesses = new_memlet.num_elements()
                    # conn[4:] strips the 'OUT_' prefix of the connector.
                    new_in_edges[memlet.data] = (new_map_entry, None,
                                                 map_entry, None, new_memlet,
                                                 int(conn[4:]))
        nxutil.change_edge_dest(graph, map_entry, new_map_entry)

        # Same bookkeeping for edges inner exit -> outer exit.
        new_out_edges = {}
        new_exits = []
        for map_exit in map_exits:
            # NOTE(review): if an element of map_exits is not a MapExit,
            # new_exit keeps its previous value (or is unbound) below --
            # confirm what graph.exit_nodes can return.
            if isinstance(map_exit, nodes.MapExit):
                new_exit = nodes.MapExit(new_map)
                new_exits.append(new_exit)
            for _src, conn, _dest, _, memlet in graph.in_edges(map_exit):
                if not isinstance(sdfg.arrays[memlet.data], dace.data.Scalar):
                    new_subset = calc_set_image(
                        map_entry.map.params,
                        map_entry.map.range,
                        memlet.subset,
                    )
                    if memlet.data in new_out_edges:
                        src, src_conn, dest, dest_conn, new_memlet, num = \
                            new_out_edges[memlet.data]
                        new_memlet.subset = calc_set_union(
                            new_memlet.subset, new_subset)
                        new_memlet.num_accesses = new_memlet.num_elements()
                        new_out_edges[memlet.data] = (src, src_conn, dest,
                                                      dest_conn, new_memlet,
                                                      min(num, conn[4:]))
                    else:
                        new_memlet = dcpy(memlet)
                        new_memlet.subset = new_subset
                        new_memlet.num_accesses = new_memlet.num_elements()
                        new_out_edges[memlet.data] = (map_exit, None,
                                                      new_exit, None,
                                                      new_memlet, conn[4:])
            nxutil.change_edge_src(graph, map_exit, new_exit)

        # Renumber the entry connectors consecutively from 1.
        in_conn_nums = [e[5] for e in new_in_edges.values()]
        in_conn = {num: i + 1 for i, num in enumerate(in_conn_nums)}

        new_map_entry.in_connectors = {
            'IN_' + str(i + 1)
            for i in range(len(in_conn_nums))
        }
        new_map_entry.out_connectors = {
            'OUT_' + str(i + 1)
            for i in range(len(in_conn_nums))
        }

        for src, _, dst, _, memlet, num in new_in_edges.values():
            graph.add_edge(src, 'OUT_' + str(in_conn[num]), dst,
                           'IN_' + str(in_conn[num]), memlet)

        for new_exit in new_exits:
            # Only the edges that terminate at this exit node. The same
            # filtered list drives both the connector numbering and the edge
            # creation, so edges belonging to another exit are never looked
            # up in (or added with) this exit's numbering.
            exit_edges = [
                e for e in new_out_edges.values() if e[2] is new_exit
            ]
            out_conn = {e[5]: i + 1 for i, e in enumerate(exit_edges)}

            new_exit.in_connectors = {
                'IN_' + str(i + 1)
                for i in range(len(exit_edges))
            }
            new_exit.out_connectors = {
                'OUT_' + str(i + 1)
                for i in range(len(exit_edges))
            }

            for src, _, dst, _, memlet, num in exit_edges:
                graph.add_edge(src, 'OUT_' + str(out_conn[num]), dst,
                               'IN_' + str(out_conn[num]), memlet)

        # Return strip-mined dimension.
        return target_dim, new_dim, new_map
Exemplo n.º 7
0
    def apply(self, sdfg):
        """Merge the matched second state into the matched first state.

        The inter-state edges linking the two states are deleted after their
        assignments have been copied onto the edges entering the first
        state. An empty state is removed outright. Otherwise the contents of
        the second state are added to the first one and equally-labelled
        source/sink access nodes are unified.

        :param sdfg: The SDFG holding both states; modified in place.
        """
        head = sdfg.nodes()[self.subgraph[StateFusion._first_state]]
        tail = sdfg.nodes()[self.subgraph[StateFusion._second_state]]

        # Get rid of the edge(s) head -> tail, preserving their assignments
        # on whatever leads into head.
        for bridge in sdfg.edges_between(head, tail):
            if bridge.data.assignments:
                for _, _, upstream in sdfg.in_edges(head):
                    upstream.assignments.update(bridge.data.assignments)
            sdfg.remove_edge(bridge)

        # Empty head: everything that pointed at head now points at tail.
        if head.is_empty():
            nxutil.change_edge_dest(sdfg, head, tail)
            sdfg.remove_node(head)
            return

        # Empty tail: head inherits all of tail's inter-state edges.
        if tail.is_empty():
            nxutil.change_edge_src(sdfg, tail, head)
            nxutil.change_edge_dest(sdfg, tail, head)
            sdfg.remove_node(tail)
            return

        # Both states are non-empty from here on.

        def only_access_nodes(candidates):
            # Filter down to access (data) nodes, keeping the order.
            return [c for c in candidates if isinstance(c, nodes.AccessNode)]

        head_sources = only_access_nodes(nxutil.find_source_nodes(head))
        head_sinks = only_access_nodes(nxutil.find_sink_nodes(head))
        tail_sources = only_access_nodes(nxutil.find_source_nodes(tail))

        # Discard head sources whose label also shows up among head sinks.
        head_sources = [
            candidate for candidate in head_sources
            if not any(s.label == candidate.label for s in head_sinks)
        ]

        # Bring all of tail's nodes and edges over into head.
        for transplanted in tail.nodes():
            head.add_node(transplanted)
        for u, u_conn, v, v_conn, payload in tail.edges():
            head.add_edge(u, u_conn, v, v_conn, payload)

        # Head source + equally-labelled tail source: keep the head node.
        for survivor in head_sources:
            twin = next(
                (t for t in tail_sources if t.label == survivor.label), None)
            if twin is None:
                continue
            nxutil.change_edge_src(head, twin, survivor)
            head.remove_node(twin)
            tail_sources.remove(twin)

        # Head sink + equally-labelled tail source: keep the tail node.
        for retired in head_sinks:
            heir = next(
                (t for t in tail_sources if t.label == retired.label), None)
            if heir is None:
                continue
            nxutil.change_edge_dest(head, retired, heir)
            head.remove_node(retired)
            tail_sources.remove(heir)

        # Hand tail's outgoing inter-state edges to head and delete tail.
        nxutil.change_edge_src(sdfg, tail, head)
        sdfg.remove_node(tail)
        if Config.get_bool("debugprint"):
            StateFusion._states_fused += 1
Exemplo n.º 8
0
    def apply(self, sdfg):
        """Redirect the matched state's array accesses to FPGA-global copies.

        For every array access node that is a source (resp. sink) of the
        state, a transient ``fpga_<name>`` array with
        ``StorageType.FPGA_Global`` storage is created, a copy into (resp.
        out of) it is placed in a new ``pre_`` (resp. ``post_``) state, and
        the node in the matched state is replaced by a node for the new
        array. Memlets in the state that refer to a replaced array are
        renamed accordingly, and finally ``fpga_update`` is called on the
        state.

        :param sdfg: The SDFG containing the matched state; modified in
                     place.
        """
        state = sdfg.nodes()[self.subgraph[FPGATransformState._state]]

        # Find source/sink (data) nodes
        input_nodes = nxutil.find_source_nodes(state)
        output_nodes = nxutil.find_sink_nodes(state)

        # Original data name -> FPGA-side array created for it. Keyed by
        # node.data everywhere so inputs, outputs and the memlet renaming
        # below all agree on the key.
        fpga_data = {}

        if input_nodes:

            pre_state = sd.SDFGState('pre_' + state.label, sdfg)

            for node in input_nodes:

                if (not isinstance(node, dace.graph.nodes.AccessNode)
                        or not isinstance(node.desc(sdfg), dace.data.Array)):
                    # Only transfer array nodes
                    # TODO: handle streams
                    continue

                array = node.desc(sdfg)
                if node.data in fpga_data:
                    fpga_array = fpga_data[node.data]
                else:
                    fpga_array = sdfg.add_array(
                        'fpga_' + node.data,
                        array.dtype,
                        array.shape,
                        materialize_func=array.materialize_func,
                        transient=True,
                        storage=types.StorageType.FPGA_Global,
                        allow_conflicts=array.allow_conflicts,
                        access_order=array.access_order,
                        strides=array.strides,
                        offset=array.offset)
                    fpga_data[node.data] = fpga_array
                fpga_node = type(node)(fpga_array)

                # Copy original array -> FPGA array in the pre-state.
                pre_state.add_node(node)
                pre_state.add_node(fpga_node)
                full_range = subsets.Range([(0, s - 1, 1)
                                            for s in array.shape])
                mem = memlet.Memlet(array, full_range.num_elements(),
                                    full_range, 1)
                pre_state.add_edge(node, None, fpga_node, None, mem)

                # Swap the original node for the FPGA node in the state.
                state.add_node(fpga_node)
                nxutil.change_edge_src(state, node, fpga_node)
                state.remove_node(node)

            # Insert the pre-state in front of the matched state.
            sdfg.add_node(pre_state)
            nxutil.change_edge_dest(sdfg, state, pre_state)
            sdfg.add_edge(pre_state, state, edges.InterstateEdge())

        if output_nodes:

            post_state = sd.SDFGState('post_' + state.label, sdfg)

            for node in output_nodes:

                if (not isinstance(node, dace.graph.nodes.AccessNode)
                        or not isinstance(node.desc(sdfg), dace.data.Array)):
                    # Only transfer array nodes
                    # TODO: handle streams
                    continue

                array = node.desc(sdfg)
                if node.data in fpga_data:
                    fpga_array = fpga_data[node.data]
                else:
                    fpga_array = sdfg.add_array(
                        'fpga_' + node.data,
                        array.dtype,
                        array.shape,
                        materialize_func=array.materialize_func,
                        transient=True,
                        storage=types.StorageType.FPGA_Global,
                        allow_conflicts=array.allow_conflicts,
                        access_order=array.access_order,
                        strides=array.strides,
                        offset=array.offset)
                    fpga_data[node.data] = fpga_array
                fpga_node = type(node)(fpga_array)

                # Copy FPGA array -> original array in the post-state.
                post_state.add_node(node)
                post_state.add_node(fpga_node)
                full_range = subsets.Range([(0, s - 1, 1)
                                            for s in array.shape])
                mem = memlet.Memlet(fpga_array, full_range.num_elements(),
                                    full_range, 1)
                post_state.add_edge(fpga_node, None, node, None, mem)

                # Swap the original node for the FPGA node in the state.
                state.add_node(fpga_node)
                nxutil.change_edge_dest(state, node, fpga_node)
                state.remove_node(node)

            # Append the post-state after the matched state.
            sdfg.add_node(post_state)
            nxutil.change_edge_src(sdfg, state, post_state)
            sdfg.add_edge(state, post_state, edges.InterstateEdge())

        # Rename each memlet after its OWN data. Using the leftover loop
        # variable `node` here would give every memlet the name of the last
        # visited node, or fail if no loop above ever ran.
        for _, _, _, _, mem in state.edges():
            if mem.data is not None and mem.data in fpga_data:
                mem.data = 'fpga_' + mem.data

        fpga_update(state, 0)
Exemplo n.º 9
0
    def apply(self, sdfg):
        """Redirect the matched state's array accesses to FPGA-global copies.

        Source and sink array access nodes of the state are replaced by
        nodes for transient ``fpga_<name>`` arrays stored in
        ``StorageType.FPGA_Global``. Copies into those arrays are emitted in
        a new ``pre_`` state and copies back in a new ``post_`` state.
        Access nodes reached through write-conflict-resolution (WCR) memlets
        inside nested SDFGs are additionally copied in by the pre-state.
        Memlets referring to a replaced array are renamed and receive a
        vector length read from the nested SDFG they connect to, after which
        ``fpga_update`` is called.

        :param sdfg: The SDFG containing the matched state; modified in
                     place.
        """
        state = sdfg.nodes()[self.subgraph[FPGATransformState._state]]

        # Find source/sink (data) nodes
        input_nodes = nxutil.find_source_nodes(state)
        output_nodes = nxutil.find_sink_nodes(state)

        # Original data name -> value returned by sdfg.add_array for the
        # corresponding 'fpga_' array.
        fpga_data = {}

        # Input nodes may also be nodes with WCR memlets
        # We have to recur across nested SDFGs to find them
        wcr_input_nodes = set()

        for node, graph in state.all_nodes_recursive():
            if isinstance(node, dace.graph.nodes.AccessNode):
                for e in graph.all_edges(node):
                    if e.data.wcr is not None:
                        # This is an output node with wcr
                        # find the target in the parent sdfg

                        # following the structure State->SDFG->State-> SDFG
                        # from the current_state we have to go two levels up
                        parent_state = graph.parent.parent
                        if parent_state is not None:
                            for parent_edges in parent_state.edges():
                                if parent_edges.src_conn == e.dst.data or (
                                        isinstance(parent_edges.dst,
                                                   dace.graph.nodes.AccessNode)
                                        and e.dst.data
                                        == parent_edges.dst.data):
                                    # This must be copied to device
                                    input_nodes.append(parent_edges.dst)
                                    wcr_input_nodes.add(parent_edges.dst)

        if input_nodes:
            # create pre_state
            pre_state = sd.SDFGState('pre_' + state.label, sdfg)

            for node in input_nodes:
                if (not isinstance(node, dace.graph.nodes.AccessNode)
                        or not isinstance(node.desc(sdfg), dace.data.Array)):
                    # Only transfer array nodes
                    # TODO: handle streams
                    continue

                array = node.desc(sdfg)
                if node.data in fpga_data:
                    fpga_array = fpga_data[node.data]
                elif node not in wcr_input_nodes:
                    # WCR inputs are also outputs; no array is created for
                    # them here (the output pass below creates one when the
                    # node is a sink of the state).
                    fpga_array = sdfg.add_array(
                        'fpga_' + node.data,
                        array.shape,
                        array.dtype,
                        materialize_func=array.materialize_func,
                        transient=True,
                        storage=dtypes.StorageType.FPGA_Global,
                        allow_conflicts=array.allow_conflicts,
                        strides=array.strides,
                        offset=array.offset)
                    fpga_data[node.data] = fpga_array

                # Copy original array -> FPGA array in the pre-state.
                pre_node = pre_state.add_read(node.data)
                pre_fpga_node = pre_state.add_write('fpga_' + node.data)
                full_range = subsets.Range([(0, s - 1, 1)
                                            for s in array.shape])
                mem = memlet.Memlet(node.data, full_range.num_elements(),
                                    full_range, 1)
                pre_state.add_edge(pre_node, None, pre_fpga_node, None, mem)

                # WCR inputs stay in place; plain inputs are swapped for a
                # read of the FPGA array.
                if node not in wcr_input_nodes:
                    fpga_node = state.add_read('fpga_' + node.data)
                    nxutil.change_edge_src(state, node, fpga_node)
                    state.remove_node(node)

            # Insert the pre-state in front of the matched state.
            sdfg.add_node(pre_state)
            nxutil.change_edge_dest(sdfg, state, pre_state)
            sdfg.add_edge(pre_state, state, edges.InterstateEdge())

        if output_nodes:

            post_state = sd.SDFGState('post_' + state.label, sdfg)

            for node in output_nodes:
                if (not isinstance(node, dace.graph.nodes.AccessNode)
                        or not isinstance(node.desc(sdfg), dace.data.Array)):
                    # Only transfer array nodes
                    # TODO: handle streams
                    continue

                array = node.desc(sdfg)
                if node.data in fpga_data:
                    fpga_array = fpga_data[node.data]
                else:
                    fpga_array = sdfg.add_array(
                        'fpga_' + node.data,
                        array.shape,
                        array.dtype,
                        materialize_func=array.materialize_func,
                        transient=True,
                        storage=dtypes.StorageType.FPGA_Global,
                        allow_conflicts=array.allow_conflicts,
                        strides=array.strides,
                        offset=array.offset)
                    fpga_data[node.data] = fpga_array
                # fpga_node = type(node)(fpga_array)

                # Copy FPGA array -> original array in the post-state.
                post_node = post_state.add_write(node.data)
                post_fpga_node = post_state.add_read('fpga_' + node.data)
                full_range = subsets.Range([(0, s - 1, 1)
                                            for s in array.shape])
                mem = memlet.Memlet('fpga_' + node.data,
                                    full_range.num_elements(), full_range, 1)
                post_state.add_edge(post_fpga_node, None, post_node, None, mem)

                # Swap the original sink for a write to the FPGA array.
                fpga_node = state.add_write('fpga_' + node.data)
                nxutil.change_edge_dest(state, node, fpga_node)
                state.remove_node(node)

            # Append the post-state after the matched state.
            sdfg.add_node(post_state)
            nxutil.change_edge_src(sdfg, state, post_state)
            sdfg.add_edge(state, post_state, edges.InterstateEdge())

        # Not reset per edge: the last vector length found carries over to
        # later edges that do not touch a nested SDFG.
        veclen_ = 1

        # propagate vector info from a nested sdfg
        for src, src_conn, dst, dst_conn, mem in state.edges():
            # need to go inside the nested SDFG and grab the vector length
            if isinstance(dst, dace.graph.nodes.NestedSDFG):
                # this edge is going to the nested SDFG
                for inner_state in dst.sdfg.states():
                    for n in inner_state.nodes():
                        if isinstance(n, dace.graph.nodes.AccessNode
                                      ) and n.data == dst_conn:
                            # assuming all memlets have the same vector length
                            veclen_ = inner_state.all_edges(n)[0].data.veclen
            if isinstance(src, dace.graph.nodes.NestedSDFG):
                # this edge is coming from the nested SDFG
                for inner_state in src.sdfg.states():
                    for n in inner_state.nodes():
                        if isinstance(n, dace.graph.nodes.AccessNode
                                      ) and n.data == src_conn:
                            # assuming all memlets have the same vector length
                            veclen_ = inner_state.all_edges(n)[0].data.veclen

            # Point memlets of replaced arrays at their FPGA counterpart.
            if mem.data is not None and mem.data in fpga_data:
                mem.data = 'fpga_' + mem.data
                mem.veclen = veclen_

        fpga_update(sdfg, state, 0)
Exemplo n.º 10
0
    def apply(self, sdfg: sd.SDFG):

        #######################################################
        # Step 0: SDFG metadata

        # Find all input and output data descriptors
        input_nodes = []
        output_nodes = []
        global_code_nodes = [[] for _ in sdfg.nodes()]

        for i, state in enumerate(sdfg.nodes()):
            sdict = state.scope_dict()
            for node in state.nodes():
                if (isinstance(node, nodes.AccessNode)
                        and node.desc(sdfg).transient == False):
                    if (state.out_degree(node) > 0
                            and node.data not in input_nodes):
                        # Special case: nodes that lead to dynamic map ranges
                        # must stay on host
                        for e in state.out_edges(node):
                            last_edge = state.memlet_path(e)[-1]
                            if (isinstance(last_edge.dst, nodes.EntryNode)
                                    and last_edge.dst_conn and
                                    not last_edge.dst_conn.startswith('IN_')):
                                break
                        else:
                            input_nodes.append((node.data, node.desc(sdfg)))
                    if (state.in_degree(node) > 0
                            and node.data not in output_nodes):
                        output_nodes.append((node.data, node.desc(sdfg)))
                elif isinstance(node, nodes.CodeNode) and sdict[node] is None:
                    if not isinstance(node, nodes.EmptyTasklet):
                        global_code_nodes[i].append(node)

            # Input nodes may also be nodes with WCR memlets and no identity
            for e in state.edges():
                if e.data.wcr is not None and e.data.wcr_identity is None:
                    if (e.data.data not in input_nodes
                            and sdfg.arrays[e.data.data].transient == False):
                        input_nodes.append(
                            (e.data.data, sdfg.arrays[e.data.data]))

        start_state = sdfg.start_state
        end_states = sdfg.sink_nodes()

        #######################################################
        # Step 1: Create cloned GPU arrays and replace originals

        cloned_arrays = {}
        for inodename, inode in set(input_nodes):
            if isinstance(inode, data.Scalar):  # Scalars can remain on host
                continue
            newdesc = inode.clone()
            newdesc.storage = dtypes.StorageType.GPU_Global
            newdesc.transient = True
            name = sdfg.add_datadesc('gpu_' + inodename,
                                     newdesc,
                                     find_new_name=True)
            cloned_arrays[inodename] = name

        for onodename, onode in set(output_nodes):
            if onodename in cloned_arrays:
                continue
            newdesc = onode.clone()
            newdesc.storage = dtypes.StorageType.GPU_Global
            newdesc.transient = True
            name = sdfg.add_datadesc('gpu_' + onodename,
                                     newdesc,
                                     find_new_name=True)
            cloned_arrays[onodename] = name

        # Replace nodes
        for state in sdfg.nodes():
            for node in state.nodes():
                if (isinstance(node, nodes.AccessNode)
                        and node.data in cloned_arrays):
                    node.data = cloned_arrays[node.data]

        # Replace memlets
        for state in sdfg.nodes():
            for edge in state.edges():
                if edge.data.data in cloned_arrays:
                    edge.data.data = cloned_arrays[edge.data.data]

        #######################################################
        # Step 2: Create copy-in state
        excluded_copyin = self.exclude_copyin.split(',')

        copyin_state = sdfg.add_state(sdfg.label + '_copyin')
        sdfg.add_edge(copyin_state, start_state, ed.InterstateEdge())

        for nname, desc in dtypes.deduplicate(input_nodes):
            if nname in excluded_copyin or nname not in cloned_arrays:
                continue
            src_array = nodes.AccessNode(nname, debuginfo=desc.debuginfo)
            dst_array = nodes.AccessNode(cloned_arrays[nname],
                                         debuginfo=desc.debuginfo)
            copyin_state.add_node(src_array)
            copyin_state.add_node(dst_array)
            copyin_state.add_nedge(
                src_array, dst_array,
                memlet.Memlet.from_array(src_array.data, src_array.desc(sdfg)))

        #######################################################
        # Step 3: Create copy-out state
        excluded_copyout = self.exclude_copyout.split(',')

        copyout_state = sdfg.add_state(sdfg.label + '_copyout')
        for state in end_states:
            sdfg.add_edge(state, copyout_state, ed.InterstateEdge())

        for nname, desc in dtypes.deduplicate(output_nodes):
            if nname in excluded_copyout or nname not in cloned_arrays:
                continue
            src_array = nodes.AccessNode(cloned_arrays[nname],
                                         debuginfo=desc.debuginfo)
            dst_array = nodes.AccessNode(nname, debuginfo=desc.debuginfo)
            copyout_state.add_node(src_array)
            copyout_state.add_node(dst_array)
            copyout_state.add_nedge(
                src_array, dst_array,
                memlet.Memlet.from_array(dst_array.data, dst_array.desc(sdfg)))

        #######################################################
        # Step 4: Modify transient data storage

        for state in sdfg.nodes():
            sdict = state.scope_dict()
            for node in state.nodes():
                if isinstance(node,
                              nodes.AccessNode) and node.desc(sdfg).transient:
                    nodedesc = node.desc(sdfg)

                    # Special case: nodes that lead to dynamic map ranges must
                    # stay on host
                    if any(
                            isinstance(
                                state.memlet_path(e)[-1].dst, nodes.EntryNode)
                            for e in state.out_edges(node)):
                        continue

                    if sdict[node] is None:
                        # NOTE: the cloned arrays match too but it's the same
                        # storage so we don't care
                        nodedesc.storage = dtypes.StorageType.GPU_Global

                        # Try to move allocation/deallocation out of loops
                        if (self.toplevel_trans
                                and not isinstance(nodedesc, data.Stream)):
                            nodedesc.toplevel = True
                    else:
                        # Make internal transients registers
                        if self.register_trans:
                            nodedesc.storage = dtypes.StorageType.Register

        #######################################################
        # Step 5: Wrap free tasklets and nested SDFGs with a GPU map

        for state, gcodes in zip(sdfg.nodes(), global_code_nodes):
            for gcode in gcodes:
                if gcode.label in self.exclude_tasklets.split(','):
                    continue
                # Create map and connectors
                me, mx = state.add_map(gcode.label + '_gmap',
                                       {gcode.label + '__gmapi': '0:1'},
                                       schedule=dtypes.ScheduleType.GPU_Device)
                # Store in/out edges in lists so that they don't get corrupted
                # when they are removed from the graph
                in_edges = list(state.in_edges(gcode))
                out_edges = list(state.out_edges(gcode))
                me.in_connectors = set('IN_' + e.dst_conn for e in in_edges)
                me.out_connectors = set('OUT_' + e.dst_conn for e in in_edges)
                mx.in_connectors = set('IN_' + e.src_conn for e in out_edges)
                mx.out_connectors = set('OUT_' + e.src_conn for e in out_edges)

                # Create memlets through map
                for e in in_edges:
                    state.remove_edge(e)
                    state.add_edge(e.src, e.src_conn, me, 'IN_' + e.dst_conn,
                                   e.data)
                    state.add_edge(me, 'OUT_' + e.dst_conn, e.dst, e.dst_conn,
                                   e.data)
                for e in out_edges:
                    state.remove_edge(e)
                    state.add_edge(e.src, e.src_conn, mx, 'IN_' + e.src_conn,
                                   e.data)
                    state.add_edge(mx, 'OUT_' + e.src_conn, e.dst, e.dst_conn,
                                   e.data)

                # Map without inputs
                if len(in_edges) == 0:
                    state.add_nedge(me, gcode, memlet.EmptyMemlet())
        #######################################################
        # Step 6: Change all top-level maps and Reduce nodes to GPU schedule

        for i, state in enumerate(sdfg.nodes()):
            sdict = state.scope_dict()
            for node in state.nodes():
                if isinstance(node, (nodes.EntryNode, nodes.Reduce)):
                    if sdict[node] is None:
                        node.schedule = dtypes.ScheduleType.GPU_Device
                    elif (isinstance(node, nodes.EntryNode)
                          and self.sequential_innermaps):
                        node.schedule = dtypes.ScheduleType.Sequential

        #######################################################
        # Step 7: Introduce copy-out if data used in outgoing interstate edges

        for state in list(sdfg.nodes()):
            arrays_used = set()
            for e in sdfg.out_edges(state):
                # Used arrays = intersection between symbols and cloned arrays
                arrays_used.update(
                    set(e.data.condition_symbols())
                    & set(cloned_arrays.keys()))

            # Create a state and copy out used arrays
            if len(arrays_used) > 0:
                co_state = sdfg.add_state(state.label + '_icopyout')

                # Reconnect outgoing edges to after interim copyout state
                for e in sdfg.out_edges(state):
                    nxutil.change_edge_src(sdfg, state, co_state)
                # Add unconditional edge to interim state
                sdfg.add_edge(state, co_state, ed.InterstateEdge())

                # Add copy-out nodes
                for nname in arrays_used:
                    desc = sdfg.arrays[nname]
                    src_array = nodes.AccessNode(cloned_arrays[nname],
                                                 debuginfo=desc.debuginfo)
                    dst_array = nodes.AccessNode(nname,
                                                 debuginfo=desc.debuginfo)
                    co_state.add_node(src_array)
                    co_state.add_node(dst_array)
                    co_state.add_nedge(
                        src_array, dst_array,
                        memlet.Memlet.from_array(dst_array.data,
                                                 dst_array.desc(sdfg)))

        #######################################################
        # Step 8: Strict transformations
        if not self.strict_transform:
            return

        # Apply strict state fusions greedily.
        sdfg.apply_strict_transformations()