Esempio n. 1
0
    def apply(self, sdfg: SDFG) -> Union[Any, None]:
        """Remove nested-SDFG connectors whose data is never accessed inside.

        An input connector is a pruning candidate when the nested SDFG never
        reads its data; an output connector when the nested SDFG never writes
        it. An input connector that shares its name with a WCR output
        connector is kept if the source of its (first) incoming edge has
        predecessors in this state. For every pruned connector the outer
        memlet path is removed (with orphaned nodes), and the inner data
        descriptor is deleted when the nested SDFG neither reads nor writes
        that name.
        """
        state = sdfg.node(self.state_id)
        nsdfg = self.nsdfg(sdfg)
        inner_sdfg = nsdfg.sdfg

        reads, writes = inner_sdfg.read_and_write_sets()
        unread_inputs = nsdfg.in_connectors.keys() - reads
        unwritten_outputs = nsdfg.out_connectors.keys() - writes
        # Names touched inside the nested SDFG must keep their descriptors.
        touched = reads | writes

        # WCR outputs: keep the same-named input connector when whatever
        # feeds it has predecessors in this state.
        for out_edge in state.out_edges(nsdfg):
            conn = out_edge.src_conn
            if out_edge.data.wcr is None or conn not in unread_inputs:
                continue
            feeding_edge = next(
                iter(state.in_edges_by_connector(nsdfg, conn)))
            if state.in_degree(feeding_edge.src) > 0:
                unread_inputs.remove(conn)

        def _drop(connectors, edges_by_connector):
            # Remove each connector's outer memlet paths and, if the data is
            # now unused, purge its descriptor from the nested SDFG.
            for conn in connectors:
                for edge in edges_by_connector(nsdfg, conn):
                    state.remove_memlet_path(edge, remove_orphans=True)
                    if conn in inner_sdfg.arrays and conn not in touched:
                        inner_sdfg.remove_data(conn)

        _drop(unread_inputs, state.in_edges_by_connector)
        _drop(unwritten_outputs, state.out_edges_by_connector)
Esempio n. 2
0
File: papi.py Progetto: mfkiwl/dace
    def get_out_memlet_costs(sdfg: dace.SDFG, state_id: int, node: nodes.Node,
                             dfg: StateGraphView):
        """Estimate the bytes written through the outgoing memlets of ``node``.

        Only edges from a code node whose memlet path ends at an access node
        are considered, and only memlets whose subset has zero data
        dimensions contribute. A memlet with a WCR counts three times its
        byte size (assumed read old, read new, write); otherwise once. Paths
        ending in a different scope that is contained in ``node``'s scope are
        skipped because that inner scope accounts for them.

        Returns 0 immediately if a considered edge has no source connector.
        """
        scope_dict = sdfg.node(state_id).scope_dict()
        is_code_node = isinstance(node, nodes.CodeNode)

        total = 0
        for edge in dfg.out_edges(node):
            _, src_conn, _, _, memlet = edge
            dst_node = dfg.memlet_path(edge)[-1].dst

            if not (is_code_node and isinstance(dst_node, nodes.AccessNode)):
                continue

            # Memlets pointing into an array in an inner scope are handled
            # by the inner scope.
            if (scope_dict[node] != scope_dict[dst_node]
                    and scope_contains_scope(scope_dict, node, dst_node)):
                continue

            if not src_conn:
                # This would normally raise a syntax error
                return 0

            if memlet.subset.data_dims() != 0:
                continue

            byte_size = PAPIUtils.get_memlet_byte_size(sdfg, memlet)
            # write_and_resolve: every reduction is assumed to cost 3
            # accesses of the same size (read old, read new, write)
            total += (3 * byte_size) if memlet.wcr is not None else byte_size
        return total
Esempio n. 3
0
    def apply(self, sdfg: dace.SDFG) -> None:
        """Fuse the ``left`` and ``right`` horizontal executions into one node.

        A new ``HorizontalExecutionLibraryNode`` is created. Its OIR body and
        declarations are ``left``'s followed by ``right``'s, and it uses
        ``left``'s iteration space. All edges of ``left`` and ``right`` are
        rewired onto the new node. Edges between the two are dropped, both
        direct ones and those through nodes on simple paths from ``left`` to
        ``right``. Both original nodes are then removed.

        Afterwards each intermediate node is:
          - removed if it became isolated;
          - removed together with the fused node's ``IN_<label>`` connector if
            it is no longer written and only read once by the fused node;
          - marked write-only if it is still written but no longer read.
        """
        state = sdfg.node(self.state_id)
        left = self.left(sdfg)
        right = self.right(sdfg)

        # Merge source locations
        dinfo = self._merge_source_locations(left, right)

        # Merge OIR nodes: right's statements/declarations follow left's
        res = HorizontalExecutionLibraryNode(
            oir_node=oir.HorizontalExecution(
                body=left.as_oir().body + right.as_oir().body,
                declarations=left.as_oir().declarations +
                right.as_oir().declarations,
            ),
            iteration_space=left.iteration_space,
            debuginfo=dinfo,
        )
        state.add_node(res)

        # Nodes lying strictly between left and right on any simple path
        intermediate_accesses = {
            n
            for path in nx.all_simple_paths(state.nx, left, right)
            for n in path[1:-1]
        }

        # Rewire edges and connectors to the fused node; drop left->right
        # dependencies (direct, or through an intermediate node)
        for edge in state.edges_between(left, right):
            state.remove_edge_and_connectors(edge)
        for acc in intermediate_accesses:
            for edge in state.in_edges(acc):
                if edge.src is not left:
                    rewire_edge(state, edge, dst=res)
                else:
                    state.remove_edge_and_connectors(edge)
            for edge in state.out_edges(acc):
                if edge.dst is not right:
                    rewire_edge(state, edge, src=res)
                else:
                    state.remove_edge_and_connectors(edge)
        for edge in state.in_edges(left):
            rewire_edge(state, edge, dst=res)
        for edge in state.out_edges(right):
            rewire_edge(state, edge, src=res)
        for edge in state.out_edges(left):
            rewire_edge(state, edge, src=res)
        for edge in state.in_edges(right):
            rewire_edge(state, edge, dst=res)
        state.remove_node(left)
        state.remove_node(right)

        # Clean up intermediates that lost their producer or consumer
        for acc in intermediate_accesses:
            if not state.in_edges(acc):
                if not state.out_edges(acc):
                    # Only connected left to right: no longer needed
                    state.remove_node(acc)
                else:
                    assert (len(state.edges_between(acc, res)) == 1
                            and len(state.out_edges(acc))
                            == 1), "Previously written array now read-only."
                    state.remove_node(acc)
                    res.remove_in_connector("IN_" + acc.label)
            elif not state.out_edges(acc):
                # Still written by the fused node but no longer read. (This
                # previously tested the bound method ``state.out_edges``
                # itself, which is always truthy, so it never ran.)
                acc.access = dace.AccessType.WriteOnly
Esempio n. 4
0
    def apply(self, sdfg: SDFG) -> Union[Any, None]:
        """Prune nested-SDFG connectors whose data is not accessed inside it.

        Candidates are in-connectors whose data the nested SDFG never reads
        and out-connectors whose data it never writes. A candidate input is
        kept (added to ``do_not_prune``) if any of its outer memlet paths
        starts at a node with predecessors. A candidate output is kept if
        any of its outer memlet paths ends at a node with successors. An
        input connector sharing its name with a WCR output is also kept when
        the source of its first incoming edge has predecessors.

        For all other candidates, the outer memlet paths are removed with
        orphaned nodes. Candidates that are not kept and whose data the
        nested SDFG neither reads nor writes also have their descriptors
        removed from the nested SDFG. If ``self.remove_unused_containers`` is
        set, the method additionally tries to remove containers from the
        ancestors of the nested SDFG (see the comment on that section).
        """

        state = sdfg.node(self.state_id)
        nsdfg = self.nsdfg(sdfg)

        # Candidates: inputs never read / outputs never written inside the
        # nested SDFG
        read_set, write_set = nsdfg.sdfg.read_and_write_sets()
        prune_in = nsdfg.in_connectors.keys() - read_set
        prune_out = nsdfg.out_connectors.keys() - write_set

        # Detect which nodes are used, so we can delete unused nodes after the
        # connectors have been pruned
        all_data_used = read_set | write_set
        # Add WCR outputs to "do not prune" input list: the same-named input
        # connector survives when its feeding node has predecessors
        for e in state.out_edges(nsdfg):
            if e.data.wcr is not None and e.src_conn in prune_in:
                if (state.in_degree(
                        next(
                            iter(state.in_edges_by_connector(
                                nsdfg, e.src_conn))).src) > 0):
                    prune_in.remove(e.src_conn)
        # Candidates whose outer memlet path connects further into the state;
        # their edges are kept and their descriptors are not purged below
        do_not_prune = set()
        for conn in prune_in:
            # Keep the input if any memlet path into it starts at a node that
            # itself has predecessors
            if any(
                    state.in_degree(state.memlet_path(e)[0].src) > 0
                    for e in state.in_edges(nsdfg) if e.dst_conn == conn):
                do_not_prune.add(conn)
                continue
            for e in state.in_edges_by_connector(nsdfg, conn):
                state.remove_memlet_path(e, remove_orphans=True)

        for conn in prune_out:
            # Keep the output if any memlet path out of it ends at a node that
            # itself has successors
            if any(
                    state.out_degree(state.memlet_path(e)[-1].dst) > 0
                    for e in state.out_edges(nsdfg) if e.src_conn == conn):
                do_not_prune.add(conn)
                continue
            for e in state.out_edges_by_connector(nsdfg, conn):
                state.remove_memlet_path(e, remove_orphans=True)

        # Purge descriptors only after all memlet paths have been handled
        for conn in prune_in:
            if conn in nsdfg.sdfg.arrays and conn not in all_data_used and conn not in do_not_prune:
                # If the data is now unused, we can purge it from the SDFG
                nsdfg.sdfg.remove_data(conn)
        for conn in prune_out:
            if conn in nsdfg.sdfg.arrays and conn not in all_data_used and conn not in do_not_prune:
                # If the data is now unused, we can purge it from the SDFG
                nsdfg.sdfg.remove_data(conn)

        if self.remove_unused_containers:
            # Remove unused containers from parent SDFGs.
            # For every container name of ``sdfg``, walk up from the nested
            # SDFG through its ancestors (presumably ``sdfg`` first) and call
            # ``remove_data``; a ValueError stops the walk for that name.
            # NOTE(review): this appears to rely on ``remove_data`` raising
            # ValueError for containers still in use. Confirm, because
            # otherwise same-named containers in ancestors (including
            # non-transient ones) could be removed whether or not they
            # relate to the pruned connectors.
            containers = list(sdfg.arrays.keys())
            for name in containers:
                s = nsdfg.sdfg
                while s.parent_sdfg:
                    s = s.parent_sdfg
                    try:
                        s.remove_data(name)
                    except ValueError:
                        break
Esempio n. 5
0
    def apply(self, sdfg: dace.SDFG):
        """Rewrite ``self.array`` to use a buffer along the loop axis.

        The loop variable is taken as the first assignment on the first
        incoming interstate edge of the guard state. ``_get_loop_axis`` and
        ``_get_buffer_size`` determine the loop-indexed axis and the buffer
        length, and ``_replace_indices`` is then called on every state. The
        array descriptor is resized to ``buffer_size`` along that axis only
        when its extent there equals its total size.
        """
        matched = self.subgraph
        # Looked up as part of the matched pattern; not otherwise used here
        before_state = sdfg.node(matched[self._before_state])
        loop_state = sdfg.node(matched[self._loop_state])
        guard_state = sdfg.node(matched[self._guard_state])
        guard_entry_edge = sdfg.in_edges(guard_state)[0]
        loop_var = next(iter(guard_entry_edge.data.assignments))

        loop_axis = self._get_loop_axis(loop_state, loop_var)
        buffer_size = self._get_buffer_size(loop_state, loop_var, loop_axis)
        self._replace_indices(sdfg.states(), loop_var, loop_axis, buffer_size)

        desc = sdfg.arrays[self.array]
        # TODO: generalize beyond descriptors whose loop-axis extent is their
        # whole size
        if desc.shape[loop_axis] != desc.total_size:
            return
        desc.shape = tuple(
            buffer_size if axis == loop_axis else extent
            for axis, extent in enumerate(desc.shape))
        desc.total_size = buffer_size
Esempio n. 6
0
    def apply(self, sdfg: dace.SDFG):
        """Replicate the first map for the second one and delete the original.

        The matched nodes are resolved first. Then ``_update_map_connectors``
        and ``_replicate_first_map`` are applied, and finally every node
        returned by ``all_nodes_between`` the first map's entry and exit is
        removed together with that exit node.
        """
        state = sdfg.node(self.state_id)
        pattern = self.subgraph

        def matched(key):
            # Resolve a pattern key to the concrete node in this state
            return state.node(pattern[key])

        first_map_entry = matched(self._first_map_entry)
        _first_tasklet = matched(self._first_tasklet)  # resolved but unused
        first_map_exit = matched(self._first_map_exit)
        array_access = matched(self._array_access)
        second_map_entry = matched(self._second_map_entry)

        self._update_map_connectors(state, array_access, first_map_entry,
                                    second_map_entry)
        self._replicate_first_map(sdfg, array_access, first_map_entry,
                                  first_map_exit, second_map_entry)

        stale_nodes = state.all_nodes_between(first_map_entry, first_map_exit)
        state.remove_nodes_from(stale_nodes | {first_map_exit})
Esempio n. 7
0
    def apply(self, sdfg: SDFG):
        """Convert an augmented assignment in a tasklet into a WCR memlet.

        Supports C++ tasklets only. It looks for code of the form
        ``out = in <op> <expr>;``, where ``in`` is an input connector whose
        memlet subset equals the output memlet's subset. On a match:
          - the code becomes ``out = <expr>;``;
          - the output memlet gets the write-conflict resolution
            ``lambda a,b: a <op> b``, with asymmetric operators first mapped
            through ``AugAssignToWCR._EXPR_MAP``;
          - the matched input edge is removed.

        With ``expr_index == 0`` the pattern is ``input -> tasklet -> output``.
        In that case, if the input has predecessors and the output has no
        successors, the tasklet is first moved to a new state after the
        current one. Otherwise the tasklet sits between a map entry and exit.

        If the output container is non-transient inside a nested SDFG, the
        WCR is propagated along the outer memlet paths of the enclosing SDFGs.

        Raises:
            NotImplementedError: if the tasklet is not written in C++.
            ValueError: if no input edge matches the augmented-assignment
                pattern with a subset equal to the output's.
        """
        input: nodes.AccessNode = self.input(sdfg)
        tasklet: nodes.Tasklet = self.tasklet(sdfg)
        output: nodes.AccessNode = self.output(sdfg)
        state: SDFGState = sdfg.node(self.state_id)

        # If state fission is necessary to keep semantics, do it first
        if (self.expr_index == 0 and state.in_degree(input) > 0
                and state.out_degree(output) == 0):
            newstate = sdfg.add_state_after(state)
            newstate.add_node(tasklet)
            new_input, new_output = None, None

            # Keep old edges for after we remove tasklet from the original state
            in_edges = list(state.in_edges(tasklet))
            out_edges = list(state.out_edges(tasklet))

            # Recreate the tasklet's neighbours as fresh access nodes in the
            # new state
            for e in in_edges:
                r = newstate.add_read(e.src.data)
                newstate.add_edge(r, e.src_conn, e.dst, e.dst_conn, e.data)
                if e.src is input:
                    new_input = r
            for e in out_edges:
                w = newstate.add_write(e.dst.data)
                newstate.add_edge(e.src, e.src_conn, w, e.dst_conn, e.data)
                if e.dst is output:
                    new_output = w

            # Remove tasklet and resulting isolated nodes
            state.remove_node(tasklet)
            for e in in_edges:
                if state.degree(e.src) == 0:
                    state.remove_node(e.src)
            for e in out_edges:
                if state.degree(e.dst) == 0:
                    state.remove_node(e.dst)

            # Reset state and nodes for rest of transformation
            input = new_input
            output = new_output
            state = newstate
        # End of state fission

        if self.expr_index == 0:
            inedges = state.edges_between(input, tasklet)
            outedge = state.edges_between(tasklet, output)[0]
        else:
            me = self.map_entry(sdfg)
            mx = self.map_exit(sdfg)

            inedges = state.edges_between(me, tasklet)
            outedge = state.edges_between(tasklet, mx)[0]

        # Get relevant output connector
        outconn = outedge.src_conn

        # Regex character class built from the supported operators
        ops = '[%s]' % ''.join(
            re.escape(o) for o in AugAssignToWCR._EXPRESSIONS)

        # Change tasklet code
        if tasklet.language is dtypes.Language.Python:
            raise NotImplementedError
        elif tasklet.language is dtypes.Language.CPP:
            cstr = tasklet.code.as_string.strip()
            for edge in inedges:
                inconn = edge.dst_conn
                # Match ``out = in <op> <expr>;`` for this input connector
                match = re.match(
                    r'^\s*%s\s*=\s*%s\s*(%s)(.*);$' %
                    (re.escape(outconn), re.escape(inconn), ops), cstr)
                if match is None:
                    continue
                # The input must read exactly the elements the output writes
                if edge.data.subset != outedge.data.subset:
                    continue
                op = match.group(1)
                expr = match.group(2)

                # Map asymmetric WCRs to symmetric ones if possible
                if op in AugAssignToWCR._EXPR_MAP:
                    op, newexpr = AugAssignToWCR._EXPR_MAP[op]
                    expr = newexpr.format(expr=expr)

                tasklet.code.code = '%s = %s;' % (outconn, expr)
                inedge = edge
                break
            else:
                # Without this, execution fell through and crashed later with
                # an UnboundLocalError on ``op``/``inedge``
                raise ValueError(
                    'AugAssignToWCR: no input edge of tasklet "%s" matches '
                    'an augmented assignment to "%s"' %
                    (tasklet.label, outconn))
        else:
            raise NotImplementedError

        # Change output edge
        outedge.data.wcr = f'lambda a,b: a {op} b'

        if self.expr_index == 0:
            # Remove input node and connector
            state.remove_edge_and_connectors(inedge)
            if state.degree(input) == 0:
                state.remove_node(input)
        else:
            # Remove input edge and dst connector, but not necessarily src
            state.remove_memlet_path(inedge)

        # If outedge leads to non-transient, and this is a nested SDFG,
        # propagate outwards
        sd = sdfg
        while (not sd.arrays[outedge.data.data].transient
               and sd.parent_nsdfg_node is not None):
            nsdfg = sd.parent_nsdfg_node
            nstate = sd.parent
            sd = sd.parent_sdfg
            outedge = next(
                iter(nstate.out_edges_by_connector(nsdfg, outedge.data.data)))
            # Rebinding ``outedge`` leaves it on the last edge of the outer
            # path, whose data container the next loop check inspects
            for outedge in nstate.memlet_path(outedge):
                outedge.data.wcr = f'lambda a,b: a {op} b'
Esempio n. 8
0
    def apply(self, sdfg: dace.SDFG):
        """Expand a multi-parameter map into a chain of single-parameter maps.

        The matched map keeps only its first parameter and range. Every
        remaining parameter gets its own map (``Sequential`` schedule), nested
        inside in parameter order. The method then:
          - routes edges leaving the original entry through the new entries;
          - re-routes dynamic map-range inputs through each map whose range
            references the input's connector name;
          - routes edges entering the original exit through the new exits,
            innermost first;
          - consolidates edges of the innermost new scope.

        NOTE(review): ``entries[-1]`` assumes the map has at least two
        parameters; presumably guaranteed by the match — confirm.

        Returns:
            The map entries, outermost (the original entry) first.

        Raises:
            ValueError: if the innermost new scope cannot be found.
        """
        # Extract the map and its entry and exit nodes.
        graph = sdfg.node(self.state_id)
        map_entry = self.map_entry(sdfg)
        map_exit = graph.exit_node(map_entry)
        current_map = map_entry.map

        # Create new maps
        new_maps = [
            nodes.Map(current_map.label + '_' + str(param), [param],
                      subsets.Range([param_range]),
                      schedule=dtypes.ScheduleType.Sequential) for param,
            param_range in zip(current_map.params[1:], current_map.range[1:])
        ]
        # The original map keeps only its first dimension
        current_map.params = [current_map.params[0]]
        current_map.range = subsets.Range([current_map.range[0]])

        # Create new map entries and exits
        entries = [nodes.MapEntry(new_map) for new_map in new_maps]
        exits = [nodes.MapExit(new_map) for new_map in new_maps]

        # Create edges, abiding by the following rules:
        # 1. If there are no edges coming from the outside, use empty memlets
        # 2. Edges with IN_* connectors replicate along the maps
        # 3. Edges for dynamic map ranges replicate until reaching range(s)
        for edge in graph.out_edges(map_entry):
            graph.remove_edge(edge)
            graph.add_memlet_path(map_entry,
                                  *entries,
                                  edge.dst,
                                  src_conn=edge.src_conn,
                                  memlet=edge.data,
                                  dst_conn=edge.dst_conn)

        # Modify dynamic map ranges
        dynamic_edges = dace.sdfg.dynamic_map_inputs(graph, map_entry)
        for edge in dynamic_edges:
            # Remove old edge and connector
            graph.remove_edge(edge)
            edge.dst.remove_in_connector(edge.dst_conn)

            # Propagate to each range it belongs to; ``path`` grows by one
            # map per iteration, so each matching map gets a path from the
            # source through all enclosing maps
            path = []
            for mapnode in [map_entry] + entries:
                path.append(mapnode)
                if any(edge.dst_conn in map(str, symbolic.symlist(r))
                       for r in mapnode.map.range):
                    graph.add_memlet_path(edge.src,
                                          *path,
                                          memlet=edge.data,
                                          src_conn=edge.src_conn,
                                          dst_conn=edge.dst_conn)

        # Create new map exits (innermost exit first)
        for edge in graph.in_edges(map_exit):
            graph.remove_edge(edge)
            graph.add_memlet_path(edge.src,
                                  *exits[::-1],
                                  map_exit,
                                  memlet=edge.data,
                                  src_conn=edge.src_conn,
                                  dst_conn=edge.dst_conn)

        # Locate the scope-tree node of the innermost new map by walking up
        # from the scope leaves
        from dace.sdfg.scope import ScopeTree
        scope = None
        queue: List[ScopeTree] = graph.scope_leaves()
        while len(queue) > 0:
            tnode = queue.pop()
            if tnode.entry == entries[-1]:
                scope = tnode
                break
            elif tnode.parent is not None:
                queue.append(tnode.parent)
        else:
            raise ValueError('Cannot find scope in state')

        consolidate_edges(sdfg, scope)

        return [map_entry] + entries
 def apply(self, sdfg: dace.SDFG):
     """Tag the detected loop's guard state with its loop index variable.

     The index name is the first key of the assignments on the first
     incoming interstate edge of the guard state. It is stored in the
     guard's ``_LOOPINDEX`` attribute.
     """
     guard_state = sdfg.node(self.subgraph[ld.DetectLoop._loop_guard])
     first_in_edge = sdfg.in_edges(guard_state)[0]
     guard_state._LOOPINDEX = next(iter(first_in_edge.data.assignments))
Esempio n. 10
0
    def generate_scope(self, sdfg: dace.SDFG, scope: ScopeSubgraphView,
                       state_id: int, function_stream: CodeIOStream,
                       callsite_stream: CodeIOStream):
        """Emit C++ loops for a map scope, vectorizing the innermost with SVE.

        Every map parameter except the last becomes a plain C++ ``for`` loop
        with an inclusive upper bound. The last parameter becomes an SVE
        ``do``/``while`` loop driven by the predicate ``__pg_<param>``.

        The counter type is chosen from the byte size (1, 2, 4 or 8) of a
        data type taken from ``sdfg.arrays``. Dynamic map inputs whose data
        name differs from their connector are defined before the loops. The
        scope body is dispatched without its entry and exit nodes, and the
        loops are closed in reverse order. Everything is wrapped in an
        enclosing ``{ ... }`` block.
        """
        entry_node = scope.source_nodes()[0]

        # NOTE(review): ``list(set(...))[0]`` picks an arbitrary dtype when
        # the SDFG's arrays have different dtypes (set order), and raises
        # IndexError when there are no arrays. Presumably callers guarantee a
        # single uniform dtype — confirm.
        loop_type = list(set([sdfg.arrays[a].dtype for a in sdfg.arrays]))[0]
        ltype_size = loop_type.bytes

        long_type = copy.copy(dace.int64)
        long_type.ctype = 'int64_t'

        # Counter type matching the element width; KeyError for other sizes
        self.counter_type = {
            1: dace.int8,
            2: dace.int16,
            4: dace.int32,
            8: long_type
        }[ltype_size]

        callsite_stream.write('{')

        # Define all input connectors of the map entry
        state_dfg = sdfg.node(state_id)
        for e in dace.sdfg.dynamic_map_inputs(state_dfg, entry_node):
            if e.data.data != e.dst_conn:
                callsite_stream.write(
                    self.cpu_codegen.memlet_definition(
                        sdfg, e.data, False, e.dst_conn,
                        e.dst.in_connectors[e.dst_conn]), sdfg, state_id,
                    entry_node)

        # We only create an SVE do-while in the innermost loop
        for param, rng in zip(entry_node.map.params, entry_node.map.range):
            begin, end, stride = (sym2cpp(r) for r in rng)

            # One defined-variables scope per loop; closed in the footer loop
            self.dispatcher.defined_vars.enter_scope(sdfg)

            # Check whether we are in the innermost loop
            if param != entry_node.map.params[-1]:
                # Default C++ for-loop
                callsite_stream.write(
                    f'for(auto {param} = {begin}; {param} <= {end}; {param} += {stride}) {{'
                )
            else:
                # Generate the SVE loop header

                # The name of our loop predicate is always __pg_{param}
                self.dispatcher.defined_vars.add('__pg_' + param,
                                                 DefinedType.Scalar, 'svbool_t')

                # Declare our counting variable (e.g. i) and precompute the loop predicate for our range
                callsite_stream.write(
                    f'''{self.counter_type} {param} = {begin};
                    svbool_t __pg_{param} = svwhilele_b{ltype_size * 8}({param}, ({self.counter_type}) {end});
                    do {{''', sdfg, state_id, entry_node)

        # Dispatch the subgraph generation
        self.dispatcher.dispatch_subgraph(sdfg,
                                          scope,
                                          state_id,
                                          function_stream,
                                          callsite_stream,
                                          skip_entry_node=True,
                                          skip_exit_node=True)

        # Close the loops from above (in reverse)
        for param, rng in zip(reversed(entry_node.map.params),
                              reversed(entry_node.map.range)):
            # The innermost loop is SVE and needs a special while-footer, otherwise we just add the closing bracket
            if param != entry_node.map.params[-1]:
                # Close the default C++ for-loop
                callsite_stream.write('}')
            else:
                # Generate the SVE loop footer

                _, end, stride = (sym2cpp(r) for r in rng)

                # Increase the counting variable (according to the number of processed elements)
                # Then recompute the loop predicate and test for it
                callsite_stream.write(
                    f'''{param} += svcntp_b{ltype_size * 8}(__pg_{param}, __pg_{param}) * {stride};
                    __pg_{param} = svwhilele_b{ltype_size * 8}({param}, ({self.counter_type}) {end});
                    }} while(svptest_any(svptrue_b{ltype_size * 8}(), __pg_{param}));''',
                    sdfg, state_id, entry_node)

            self.dispatcher.defined_vars.exit_scope(sdfg)

        # Close the enclosing block opened before the loops
        callsite_stream.write('}')
Esempio n. 11
0
    def apply(self, sdfg: dace.SDFG):
        """Fuse stencil ``b`` into stencil ``a`` through their intermediate array.

        ``stencil_a`` takes over from ``stencil_b``:
          - its output fields and boundary conditions;
          - its outgoing edges;
          - any inputs ``stencil_a`` does not already have (for shared
            inputs, the access lists are merged).
        ``stencil_a``'s intermediate output connector and its edges are
        removed. For Python stencils, ``stencil_a``'s statements are renamed
        via ``ReplaceSubscript`` and ``stencil_b``'s statements are appended.
        Finally the intermediate array node, its descriptor and ``stencil_b``
        are removed.

        Raises:
            NotImplementedError: for C++ stencil code.
            ValueError: for any other unrecognized language.

        NOTE(review): the language check runs after the graph and
        ``stencil_a`` have already been modified, so a raise leaves a
        partially fused state.
        """
        graph: dace.SDFGState = sdfg.node(self.state_id)
        stencil_a: Stencil = graph.node(
            self.subgraph[StencilFusion._stencil_a])
        stencil_b: Stencil = graph.node(
            self.subgraph[StencilFusion._stencil_b])
        array: nodes.AccessNode = graph.node(
            self.subgraph[StencilFusion._tmp_array])

        # Connector names of the intermediate on each side of the array
        intermediate_name = graph.in_edges(array)[0].src_conn
        intermediate_name_b = graph.out_edges(array)[0].dst_conn

        # Replace outputs of first stencil with outputs of second stencil
        # In node and in connectors, reconnect
        stencil_a.output_fields = stencil_b.output_fields
        stencil_a.boundary_conditions = stencil_b.boundary_conditions
        for edge in list(graph.out_edges(stencil_a)):
            if edge.src_conn == intermediate_name:
                graph.remove_edge(edge)
                # NOTE(review): a second edge on the same connector would make
                # this ``del`` raise KeyError; presumably excluded by the match
                del stencil_a._out_connectors[intermediate_name]
        for edge in graph.out_edges(stencil_b):
            stencil_a.add_out_connector(edge.src_conn)
            graph.add_edge(stencil_a, edge.src_conn, edge.dst, edge.dst_conn,
                           edge.data)

        # Add other stencil inputs of the second stencil to the first
        # In node and in connectors, reconnect
        for edge in graph.in_edges(stencil_b):
            # Skip edge to remove
            if edge.dst_conn == intermediate_name_b:
                continue
            if edge.dst_conn not in stencil_a.accesses:
                stencil_a.accesses[edge.dst_conn] = stencil_b.accesses[
                    edge.dst_conn]
                stencil_a.add_in_connector(edge.dst_conn)
                graph.add_edge(edge.src, edge.src_conn, stencil_a,
                               edge.dst_conn, edge.data)
            else:
                # If same input is accessed in both stencils, only append the
                # inputs that are new to stencil_a
                for access in stencil_b.accesses[edge.dst_conn][1]:
                    if access not in stencil_a.accesses[edge.dst_conn][1]:
                        stencil_a.accesses[edge.dst_conn][1].append(access)

        # Add second stencil's statements to first stencil, replacing the input
        # to the second stencil with the name of the output access
        if stencil_a.code.language == dace.Language.Python:
            # Replace first stencil's output with connector name
            for i, stmt in enumerate(stencil_a.code.code):
                stencil_a.code.code[i] = ReplaceSubscript({
                    intermediate_name:
                    intermediate_name_b
                }).visit(stmt)

            # Append second stencil's contents, using connector name instead of
            # accessing the intermediate transient
            # TODO: Use offsetted stencil
            # NOTE(review): the mapping below is the identity
            # (intermediate_name_b -> intermediate_name_b); confirm whether a
            # different source name was intended.
            for i, stmt in enumerate(stencil_b.code.code):
                stencil_a.code.code.append(
                    ReplaceSubscript({
                        intermediate_name_b: intermediate_name_b
                    }).visit(stmt))

        elif stencil_a.code.language == dace.Language.CPP:
            raise NotImplementedError
        else:
            raise ValueError('Unrecognized language: %s' %
                             stencil_a.code.language)

        # Remove array from graph
        graph.remove_node(array)
        del sdfg.arrays[array.data]

        # Remove 2nd stencil
        graph.remove_node(stencil_b)
Esempio n. 12
0
    def apply(self, sdfg: SDFG) -> nodes.MapEntry:
        """Tile the matched map with an inner warp-sized thread-block map.

        A ``warp_tile`` map over ``__tid`` in ``0:warp_size`` (schedule
        ``GPU_ThreadBlock``) is inserted directly inside the matched map.
        Each immediately internal map is skipped if its last dimension is
        definitely smaller than the warp size. Otherwise its last dimension
        is strided by ``warp_size`` and offset by ``__tid``, and ``__tid`` is
        mapped into the owning nested SDFG when needed.

        If ``replicate_maps`` is set, an internal map with a single output is
        replicated when that output is a transient inside the tile with
        several outgoing edges. Each extra consumer then gets its own copy of
        the data. WCR outputs of internal maps are redirected to a
        thread-local scalar (initialised from the output), followed by a
        ``dace::warpReduce`` tasklet. Finally, the new scope is nested into
        its own SDFG.

        Returns:
            The entry node of the new ``warp_tile`` map.

        Raises:
            NotImplementedError: for custom reductions, or WCR outputs with
                more than one element.
        """
        me: nodes.MapEntry = self.mapentry(sdfg)
        graph = sdfg.node(self.state_id)

        # Add new map within map
        mx = graph.exit_node(me)
        new_me, new_mx = graph.add_map('warp_tile',
                                       dict(__tid=f'0:{self.warp_size}'),
                                       dtypes.ScheduleType.GPU_ThreadBlock)
        __tid = symbolic.pystr_to_symbolic('__tid')
        for e in graph.out_edges(me):
            xfh.reconnect_edge_through_map(graph, e, new_me, True)
        for e in graph.in_edges(mx):
            xfh.reconnect_edge_through_map(graph, e, new_mx, False)

        # Stride and offset all internal maps
        maps_to_stride = xfh.get_internal_scopes(graph, new_me, immediate=True)
        for nstate, nmap in maps_to_stride:
            nsdfg = nstate.parent
            nsdfg_node = nsdfg.parent_nsdfg_node

            # Map cannot be partitioned across a warp. The explicit
            # ``== True`` presumably keeps undecidable symbolic comparisons
            # from skipping the map.
            if (nmap.range.size()[-1] < self.warp_size) == True:
                continue

            if nsdfg is not sdfg and nsdfg_node is not None:
                nsdfg_node.symbol_mapping['__tid'] = __tid
                if '__tid' not in nsdfg.symbols:
                    nsdfg.add_symbol('__tid', dtypes.int32)
            nmap.range[-1] = (nmap.range[-1][0], nmap.range[-1][1] - __tid,
                              nmap.range[-1][2] * self.warp_size)
            subgraph = nstate.scope_subgraph(nmap)
            subgraph.replace(nmap.params[-1], f'{nmap.params[-1]} + __tid')
            inner_map_exit = nstate.exit_node(nmap)
            # If requested, replicate maps with multiple dependent maps
            if self.replicate_maps:
                destinations = [
                    nstate.memlet_path(edge)[-1].dst
                    for edge in nstate.out_edges(inner_map_exit)
                ]

                for dst in destinations:
                    # Transformation will not replicate map with more than one
                    # output
                    if len(destinations) != 1:
                        break
                    if not isinstance(dst, nodes.AccessNode):
                        continue  # Not leading to access node
                    if not xfh.contained_in(nstate, dst, new_me):
                        continue  # Memlet path goes out of map
                    if not nsdfg.arrays[dst.data].transient:
                        continue  # Cannot modify non-transients
                    for edge in nstate.out_edges(dst)[1:]:
                        rep_subgraph = xfh.replicate_scope(
                            nsdfg, nstate, subgraph)
                        rep_edge = nstate.out_edges(
                            rep_subgraph.sink_nodes()[0])[0]
                        # Add copy of data. The descriptor comes from
                        # ``nsdfg`` (the SDFG owning ``nstate``, as checked
                        # above); ``sdfg`` may not hold ``dst.data`` when the
                        # map lives in a nested SDFG.
                        newdesc = copy.deepcopy(nsdfg.arrays[dst.data])
                        newname = nsdfg.add_datadesc(dst.data,
                                                     newdesc,
                                                     find_new_name=True)
                        newaccess = nstate.add_access(newname)
                        # Redirect edges
                        xfh.redirect_edge(nstate,
                                          rep_edge,
                                          new_dst=newaccess,
                                          new_data=newname)
                        xfh.redirect_edge(nstate,
                                          edge,
                                          new_src=newaccess,
                                          new_data=newname)

            # If has WCR, add warp-collaborative reduction on outputs
            for out_edge in nstate.out_edges(inner_map_exit):
                if out_edge.data.wcr is not None:
                    ctype = nsdfg.arrays[out_edge.data.data].dtype.ctype
                    redtype = detect_reduction_type(out_edge.data.wcr)
                    if redtype == dtypes.ReductionType.Custom:
                        raise NotImplementedError
                    credtype = ('dace::ReductionType::' +
                                str(redtype)[str(redtype).find('.') + 1:])

                    # Add local access between thread-local and warp reduction
                    name = nsdfg._find_new_name(out_edge.data.data)
                    nsdfg.add_scalar(name,
                                     nsdfg.arrays[out_edge.data.data].dtype,
                                     transient=True)

                    # Initialize thread-local to global value
                    read = nstate.add_read(out_edge.data.data)
                    write = nstate.add_write(name)
                    edge = nstate.add_nedge(read, write,
                                            copy.deepcopy(out_edge.data))
                    edge.data.wcr = None
                    xfh.state_fission(nsdfg,
                                      SubgraphView(nstate, [read, write]))

                    # Redirect the map output into the thread-local scalar
                    newnode = nstate.add_access(name)
                    nstate.remove_edge(out_edge)
                    edge = nstate.add_edge(out_edge.src, out_edge.src_conn,
                                           newnode, None,
                                           copy.deepcopy(out_edge.data))
                    for e in nstate.memlet_path(edge):
                        e.data.data = name
                        e.data.subset = subsets.Range([(0, 0, 1)])

                    if out_edge.data.subset.num_elements(
                    ) == 1:  # One element: tasklet
                        wrt = nstate.add_tasklet(
                            'warpreduce', {'__a'}, {'__out'},
                            f'__out = dace::warpReduce<{credtype}, {ctype}>::reduce(__a);',
                            dtypes.Language.CPP)
                        nstate.add_edge(newnode, None, wrt, '__a',
                                        Memlet(name))
                        out_edge.data.wcr = None
                        nstate.add_edge(wrt, '__out', out_edge.dst, None,
                                        out_edge.data)
                    else:  # More than one element: mapped tasklet
                        raise NotImplementedError
            # End of WCR to warp reduction

        # Make nested SDFG out of new scope
        xfh.nest_state_subgraph(sdfg, graph,
                                graph.scope_subgraph(new_me, False, False))

        return new_me
Esempio n. 13
0
    def apply(self, sdfg: dace.SDFG):
        """Absorb the map surrounding a stencil node into the stencil itself.

        The stencil dimension indexed by the map's first parameter is found
        from the first 3-D memlet of the stencil whose range references it.
        That dimension is resized to the map's number of iterations. External
        edges into the map entry and out of the map exit are connected
        directly to the stencil, with memlets propagated through the map, and
        the map is removed. Accesses and output fields that did not use the
        dimension are marked as using it, and the stencil code is updated
        with ``DimensionAdder``.

        Raises:
            ValueError: if the stencil code is not Python, or no 3-D memlet
                of the stencil references the map parameter. Both checks run
                before the graph is modified.
        """
        graph: dace.SDFGState = sdfg.node(self.state_id)
        map_entry: nodes.MapEntry = graph.node(self.subgraph[NestK._map_entry])
        stencil: Stencil = graph.node(self.subgraph[NestK._stencil])

        # Validate before touching the graph, so an unsupported stencil does
        # not leave the state half-rewritten (this check used to run only
        # after the map had already been removed).
        if stencil.code.language != dace.Language.Python:
            raise ValueError(
                'For NestK to work, Stencil code language must be Python')

        # Find dimension index: the first dimension of the first 3-D memlet
        # whose range mentions the map parameter.
        pname = map_entry.map.params[0]
        dim_index = None
        for edge in graph.all_edges(stencil):
            if edge.data.data is None:  # Empty memlet
                continue
            if len(edge.data.subset) != 3:
                continue
            # Plain integer bounds have no ``free_symbols``; treat them as
            # symbol-free instead of failing.
            dim_index = next(
                (i for i, rng in enumerate(edge.data.subset.ndrange())
                 if any(pname in map(str, getattr(r, 'free_symbols', ()))
                        for r in rng)), None)
            if dim_index is not None:
                break
        if dim_index is None:
            raise ValueError(
                'NestK: no 3-dimensional memlet of stencil "%s" references '
                'map parameter "%s"' % (stencil.label, pname))

        map_exit = graph.exit_node(map_entry)

        # Reconnect external edges directly to stencil node
        for edge in graph.in_edges(map_entry):
            # Find matching internal edges
            tree = graph.memlet_tree(edge)
            for child in tree.children:
                memlet = propagation.propagate_memlet(graph, child.edge.data,
                                                      map_entry, False)
                graph.add_edge(edge.src, edge.src_conn, stencil,
                               child.edge.dst_conn, memlet)
        for edge in graph.out_edges(map_exit):
            # Find matching internal edges
            tree = graph.memlet_tree(edge)
            for child in tree.children:
                memlet = propagation.propagate_memlet(graph, child.edge.data,
                                                      map_entry, False)
                graph.add_edge(stencil, child.edge.src_conn, edge.dst,
                               edge.dst_conn, memlet)

        # Remove map
        graph.remove_nodes_from([map_entry, map_exit])

        # Reshape stencil node computation based on nested map range
        stencil.shape[dim_index] = map_entry.map.range.num_elements()

        # Add dimensions to access and output fields; remember the names that
        # newly gained the dimension so the code can be updated to match
        add_dims = set()
        for edge in graph.in_edges(stencil):
            if edge.data.data and len(edge.data.subset) == 3:
                if stencil.accesses[edge.dst_conn][0][dim_index] is False:
                    add_dims.add(edge.dst_conn)
                stencil.accesses[edge.dst_conn][0][dim_index] = True
        for edge in graph.out_edges(stencil):
            if edge.data.data and len(edge.data.subset) == 3:
                if stencil.output_fields[edge.src_conn][0][dim_index] is False:
                    add_dims.add(edge.src_conn)
                stencil.output_fields[edge.src_conn][0][dim_index] = True

        # Change all instances in the code as well
        for i, stmt in enumerate(stencil.code.code):
            stencil.code.code[i] = DimensionAdder(add_dims,
                                                  dim_index).visit(stmt)