예제 #1
0
    def apply(self, sdfg):
        """Move the data of the matched state into FPGA global memory.

        For every array accessed by a source or sink access node of the
        matched state, a transient ``fpga_<name>`` array with ``FPGA_Global``
        storage is added to ``sdfg``. A ``pre_<label>`` state copies input
        arrays to their FPGA counterparts before the matched state, and a
        ``post_<label>`` state copies output arrays back after it. Access
        nodes in the matched state are replaced by nodes of the FPGA arrays,
        memlets are renamed (and given a vector length), and ``fpga_update``
        is called on the state.

        Access nodes that are written through write-conflict-resolution (WCR)
        memlets, including accesses from inside nested SDFGs, are treated as
        inputs as well. Non-array data (e.g. streams) is skipped.

        :param sdfg: The SDFG that contains the matched state.
        """
        state = sdfg.nodes()[self.subgraph[FPGATransformState._state]]

        # Find source/sink (data) nodes
        input_nodes = sdutil.find_source_nodes(state)
        output_nodes = sdutil.find_sink_nodes(state)

        # Maps original data names to the result of ``sdfg.add_array`` for
        # their FPGA counterparts
        fpga_data = {}

        # Input nodes may also be nodes with WCR memlets
        # We have to recur across nested SDFGs to find them
        wcr_input_nodes = set()

        parent_sdfg = {state: sdfg}  # Map states to their parent SDFG
        for node, graph in state.all_nodes_recursive():
            if isinstance(graph, dace.SDFG):
                parent_sdfg[node] = graph
            if isinstance(node, dace.sdfg.nodes.AccessNode):
                for e in graph.all_edges(node):
                    if e.data.wcr is not None:
                        trace = dace.sdfg.trace_nested_access(
                            node, graph, parent_sdfg[graph])
                        for node_trace, state_trace, sdfg_trace in trace:
                            # Find the name of the accessed node in our scope
                            if state_trace == state and sdfg_trace == sdfg:
                                outer_node = node_trace
                                break
                        else:
                            # This does not trace back to the current state,
                            # so we don't care. This ``else`` belongs to the
                            # ``for`` loop: it only runs when no trace
                            # element matched (no ``break``).
                            continue
                        # Several WCR edges can resolve to the same outer
                        # node; register it only once to avoid duplicate
                        # copies in the pre-state
                        if outer_node not in wcr_input_nodes:
                            input_nodes.append(outer_node)
                            wcr_input_nodes.add(outer_node)

        if input_nodes:
            # create pre_state
            pre_state = sd.SDFGState('pre_' + state.label, sdfg)

            for node in input_nodes:

                if not isinstance(node, dace.sdfg.nodes.AccessNode):
                    continue
                desc = node.desc(sdfg)
                if not isinstance(desc, dace.data.Array):
                    # TODO: handle streams
                    continue

                if node.data in fpga_data:
                    fpga_array = fpga_data[node.data]
                elif node not in wcr_input_nodes:
                    fpga_array = sdfg.add_array(
                        'fpga_' + node.data,
                        desc.shape,
                        desc.dtype,
                        materialize_func=desc.materialize_func,
                        transient=True,
                        storage=dtypes.StorageType.FPGA_Global,
                        allow_conflicts=desc.allow_conflicts,
                        strides=desc.strides,
                        offset=desc.offset)
                    fpga_data[node.data] = fpga_array
                # NOTE(review): for WCR-only nodes no FPGA array is created
                # here; presumably it is created when the node is processed
                # as an output below — confirm.

                # Copy the whole host array into its FPGA counterpart
                pre_node = pre_state.add_read(node.data)
                pre_fpga_node = pre_state.add_write('fpga_' + node.data)
                full_range = subsets.Range([(0, s - 1, 1) for s in desc.shape])
                mem = memlet.Memlet(node.data, full_range.num_elements(),
                                    full_range, 1)
                pre_state.add_edge(pre_node, None, pre_fpga_node, None, mem)

                # WCR nodes stay in place (they are also written); plain
                # inputs are replaced by a read of the FPGA array
                if node not in wcr_input_nodes:
                    fpga_node = state.add_read('fpga_' + node.data)
                    sdutil.change_edge_src(state, node, fpga_node)
                    state.remove_node(node)

            sdfg.add_node(pre_state)
            sdutil.change_edge_dest(sdfg, state, pre_state)
            sdfg.add_edge(pre_state, state, sd.InterstateEdge())

        if output_nodes:

            post_state = sd.SDFGState('post_' + state.label, sdfg)

            for node in output_nodes:

                if not isinstance(node, dace.sdfg.nodes.AccessNode):
                    continue
                desc = node.desc(sdfg)
                if not isinstance(desc, dace.data.Array):
                    # TODO: handle streams
                    continue

                if node.data in fpga_data:
                    fpga_array = fpga_data[node.data]
                else:
                    fpga_array = sdfg.add_array(
                        'fpga_' + node.data,
                        desc.shape,
                        desc.dtype,
                        materialize_func=desc.materialize_func,
                        transient=True,
                        storage=dtypes.StorageType.FPGA_Global,
                        allow_conflicts=desc.allow_conflicts,
                        strides=desc.strides,
                        offset=desc.offset)
                    fpga_data[node.data] = fpga_array
                # fpga_node = type(node)(fpga_array)

                # Copy the whole FPGA array back to the host array
                post_node = post_state.add_write(node.data)
                post_fpga_node = post_state.add_read('fpga_' + node.data)
                full_range = subsets.Range([(0, s - 1, 1) for s in desc.shape])
                mem = memlet.Memlet('fpga_' + node.data,
                                    full_range.num_elements(), full_range, 1)
                post_state.add_edge(post_fpga_node, None, post_node, None, mem)

                # Replace the output node by a write to the FPGA array
                fpga_node = state.add_write('fpga_' + node.data)
                sdutil.change_edge_dest(state, node, fpga_node)
                state.remove_node(node)

            sdfg.add_node(post_state)
            sdutil.change_edge_src(sdfg, state, post_state)
            sdfg.add_edge(state, post_state, sd.InterstateEdge())

        veclen_ = 1

        # propagate vector info from a nested sdfg
        # NOTE(review): ``veclen_`` is not reset per edge, so a vector length
        # found on one nested-SDFG edge is also applied to every later edge
        # in iteration order — confirm this is intended.
        for src, src_conn, dst, dst_conn, mem in state.edges():
            # need to go inside the nested SDFG and grab the vector length
            if isinstance(dst, dace.sdfg.nodes.NestedSDFG):
                # this edge is going to the nested SDFG
                for inner_state in dst.sdfg.states():
                    for n in inner_state.nodes():
                        if isinstance(n, dace.sdfg.nodes.AccessNode
                                      ) and n.data == dst_conn:
                            # assuming all memlets have the same vector length
                            veclen_ = inner_state.all_edges(n)[0].data.veclen
            if isinstance(src, dace.sdfg.nodes.NestedSDFG):
                # this edge is coming from the nested SDFG
                for inner_state in src.sdfg.states():
                    for n in inner_state.nodes():
                        if isinstance(n, dace.sdfg.nodes.AccessNode
                                      ) and n.data == src_conn:
                            # assuming all memlets have the same vector length
                            veclen_ = inner_state.all_edges(n)[0].data.veclen

            # Point memlets of moved data to the FPGA arrays
            if mem.data is not None and mem.data in fpga_data:
                mem.data = 'fpga_' + mem.data
                mem.veclen = veclen_

        fpga_update(sdfg, state, 0)
예제 #2
0
    def apply(self, sdfg):
        """Fuse the second matched state into the first matched state.

        Assignments on the interstate edge(s) between the two states are
        copied onto every edge entering the first state, and those edges are
        removed. If either state is empty, it is removed and its edges are
        redirected to the other state. Otherwise all nodes and edges of the
        second state are moved into the first state. Each top-level source
        access node of the second state is then merged with the
        topologically-last top-level access node of the same label in the
        first state.

        :param sdfg: The SDFG containing the two matched states.
        """
        first_state = sdfg.nodes()[self.subgraph[StateFusion._first_state]]
        second_state = sdfg.nodes()[self.subgraph[StateFusion._second_state]]

        # Remove interstate edge(s), moving their assignments to the edges
        # that enter the first state
        edges = sdfg.edges_between(first_state, second_state)
        for edge in edges:
            if edge.data.assignments:
                for src, dst, other_data in sdfg.in_edges(first_state):
                    other_data.assignments.update(edge.data.assignments)
            sdfg.remove_edge(edge)

        # Special case 1: first state is empty
        if first_state.is_empty():
            sdutil.change_edge_dest(sdfg, first_state, second_state)
            sdfg.remove_node(first_state)
            return

        # Special case 2: second state is empty
        if second_state.is_empty():
            sdutil.change_edge_src(sdfg, second_state, first_state)
            sdutil.change_edge_dest(sdfg, second_state, first_state)
            sdfg.remove_node(second_state)
            return

        # Normal case: both states are not empty

        # Source (data) nodes of the second state are merge candidates
        second_input = [
            node for node in sdutil.find_source_nodes(second_state)
            if isinstance(node, nodes.AccessNode)
        ]

        # Merge second state to first state
        # First keep a backup of the topological sorted order of the
        # top-level access nodes. Nodes inside a scope (e.g., a map) must not
        # be merged with top-level nodes of the second state, as that would
        # connect edges across scope boundaries.
        sdict = first_state.scope_dict()
        order = [
            x for x in reversed(list(nx.topological_sort(first_state._nx)))
            if isinstance(x, nodes.AccessNode) and sdict[x] is None
        ]
        for node in second_state.nodes():
            if isinstance(node, nodes.NestedSDFG):
                # The nested SDFG now lives in the first state
                node.sdfg.parent = first_state
            first_state.add_node(node)
        for src, src_conn, dst, dst_conn, data in second_state.edges():
            first_state.add_edge(src, src_conn, dst, dst_conn, data)

        # Merge common (data) nodes
        for node in second_input:
            if first_state.in_degree(node) == 0:
                # ``order`` is reversed, so this picks the topologically-last
                # matching node of the first state
                n = next((x for x in order if x.label == node.label), None)
                if n is not None:
                    sdutil.change_edge_src(first_state, node, n)
                    first_state.remove_node(node)
                    n.access = dtypes.AccessType.ReadWrite

        # Redirect edges and remove second state
        sdutil.change_edge_src(sdfg, second_state, first_state)
        sdfg.remove_node(second_state)
        if Config.get_bool("debugprint"):
            StateFusion._states_fused += 1
예제 #3
0
    def apply(self, _, sdfg):
        """Move the externally relevant data of ``self.state`` to FPGA memory.

        Source and sink access nodes of the state whose data is non-transient
        or a shared transient are given a transient ``fpga_<name>``
        counterpart with ``FPGA_Global`` storage. Access nodes written through
        write-conflict-resolution (WCR) memlets, also from nested SDFGs, are
        treated as inputs too. A ``pre_<label>`` state copies inputs to the
        FPGA arrays, and a ``post_<label>`` state copies outputs back. Memlets
        in the state are renamed to the FPGA arrays before ``fpga_update`` is
        called. Non-array data (e.g. streams) is skipped.

        :param _: Unused by this method.
        :param sdfg: The SDFG that contains ``self.state``.
        """
        state = self.state

        # Find source/sink (data) nodes that are relevant outside this FPGA
        # kernel
        shared_transients = set(sdfg.shared_transients())
        input_nodes = [
            n for n in sdutil.find_source_nodes(state)
            if isinstance(n, nodes.AccessNode) and
            (not sdfg.arrays[n.data].transient or n.data in shared_transients)
        ]
        output_nodes = [
            n for n in sdutil.find_sink_nodes(state)
            if isinstance(n, nodes.AccessNode) and
            (not sdfg.arrays[n.data].transient or n.data in shared_transients)
        ]

        # Maps original data names to the result of ``sdfg.add_array`` for
        # their FPGA counterparts
        fpga_data = {}

        # Input nodes may also be nodes with WCR memlets
        # We have to recur across nested SDFGs to find them
        wcr_input_nodes = set()

        parent_sdfg = {state: sdfg}  # Map states to their parent SDFG
        for node, graph in state.all_nodes_recursive():
            if isinstance(graph, dace.SDFG):
                parent_sdfg[node] = graph
            if isinstance(node, dace.sdfg.nodes.AccessNode):
                for e in graph.in_edges(node):
                    if e.data.wcr is not None:
                        trace = dace.sdfg.trace_nested_access(
                            node, graph, parent_sdfg[graph])
                        for node_trace, memlet_trace, state_trace, sdfg_trace in trace:
                            # Find the name of the accessed node in our scope
                            if state_trace == state and sdfg_trace == sdfg:
                                _, outer_node = node_trace
                                if outer_node is not None:
                                    break
                        else:
                            # This does not trace back to the current state, so
                            # we don't care
                            continue
                        # Several WCR edges can resolve to the same outer
                        # node; register it only once to avoid duplicate
                        # copies in the pre-state
                        if outer_node not in wcr_input_nodes:
                            input_nodes.append(outer_node)
                            wcr_input_nodes.add(outer_node)
        if input_nodes:
            # create pre_state
            pre_state = sd.SDFGState('pre_' + state.label, sdfg)

            for node in input_nodes:

                if not isinstance(node, dace.sdfg.nodes.AccessNode):
                    continue
                desc = node.desc(sdfg)
                if not isinstance(desc, dace.data.Array):
                    # TODO: handle streams
                    continue

                if node.data in fpga_data:
                    fpga_array = fpga_data[node.data]
                elif node not in wcr_input_nodes:
                    fpga_array = sdfg.add_array(
                        'fpga_' + node.data,
                        desc.shape,
                        desc.dtype,
                        transient=True,
                        storage=dtypes.StorageType.FPGA_Global,
                        allow_conflicts=desc.allow_conflicts,
                        strides=desc.strides,
                        offset=desc.offset)
                    # Move the location information to the FPGA descriptor
                    fpga_array[1].location = copy.copy(desc.location)
                    desc.location.clear()
                    fpga_data[node.data] = fpga_array

                # Copy the whole host array into its FPGA counterpart
                pre_node = pre_state.add_read(node.data)
                pre_fpga_node = pre_state.add_write('fpga_' + node.data)
                mem = memlet.Memlet(data=node.data,
                                    subset=subsets.Range.from_array(desc))
                pre_state.add_edge(pre_node, None, pre_fpga_node, None, mem)

                # WCR nodes stay in place (they are also written); plain
                # inputs are replaced by a read of the FPGA array
                if node not in wcr_input_nodes:
                    fpga_node = state.add_read('fpga_' + node.data)
                    sdutil.change_edge_src(state, node, fpga_node)
                    state.remove_node(node)

            sdfg.add_node(pre_state)
            sdutil.change_edge_dest(sdfg, state, pre_state)
            sdfg.add_edge(pre_state, state, sd.InterstateEdge())

        if output_nodes:

            post_state = sd.SDFGState('post_' + state.label, sdfg)

            for node in output_nodes:

                if not isinstance(node, dace.sdfg.nodes.AccessNode):
                    continue
                desc = node.desc(sdfg)
                if not isinstance(desc, dace.data.Array):
                    # TODO: handle streams
                    continue

                if node.data in fpga_data:
                    fpga_array = fpga_data[node.data]
                else:
                    fpga_array = sdfg.add_array(
                        'fpga_' + node.data,
                        desc.shape,
                        desc.dtype,
                        transient=True,
                        storage=dtypes.StorageType.FPGA_Global,
                        allow_conflicts=desc.allow_conflicts,
                        strides=desc.strides,
                        offset=desc.offset)
                    # Move the location information to the FPGA descriptor
                    fpga_array[1].location = copy.copy(desc.location)
                    desc.location.clear()
                    fpga_data[node.data] = fpga_array

                # Copy the whole FPGA array back to the host array
                post_node = post_state.add_write(node.data)
                post_fpga_node = post_state.add_read('fpga_' + node.data)
                mem = memlet.Memlet(data='fpga_' + node.data,
                                    subset=subsets.Range.from_array(desc))
                post_state.add_edge(post_fpga_node, None, post_node, None, mem)

                # Replace the output node by a write to the FPGA array
                fpga_node = state.add_write('fpga_' + node.data)
                sdutil.change_edge_dest(state, node, fpga_node)
                state.remove_node(node)

            sdfg.add_node(post_state)
            sdutil.change_edge_src(sdfg, state, post_state)
            sdfg.add_edge(state, post_state, sd.InterstateEdge())

        # Rename memlets of moved data so they refer to the FPGA arrays
        for _, _, _, _, mem in state.edges():
            if mem.data is not None and mem.data in fpga_data:
                mem.data = 'fpga_' + mem.data
        fpga_update(sdfg, state, 0)
예제 #4
0
    def apply(self, sdfg):
        """Fuse the second matched state into the first matched state.

        Assignments on the interstate edge(s) between the two states are
        copied onto every edge entering the first state, and those edges are
        removed. If either state is empty, it is removed and its edges are
        redirected to the other state. Otherwise all nodes and edges of the
        second state are moved into the first state. Each top-level source
        access node of the second state is then merged with a top-level
        access node of the same data in the first state: preferably the first
        candidate (in reverse topological order) whose memlets intersect,
        else the topologically-last one. If the removed state was the SDFG's
        start state, the remaining state becomes the start state.

        :param sdfg: The SDFG containing the two matched states.
        """
        if isinstance(self.subgraph[StateFusion.first_state], SDFGState):
            first_state: SDFGState = self.subgraph[StateFusion.first_state]
            second_state: SDFGState = self.subgraph[StateFusion.second_state]
        else:
            first_state: SDFGState = sdfg.node(
                self.subgraph[StateFusion.first_state])
            second_state: SDFGState = sdfg.node(
                self.subgraph[StateFusion.second_state])

        # Remember the start state before modifying the graph. The start
        # state is set via a node index (``sdfg.node_id``), so once a state
        # has been removed, comparing ``sdfg.start_state`` against it can no
        # longer tell whether the removed state used to be the start state.
        start_state = sdfg.start_state

        # Remove interstate edge(s), moving their assignments to the edges
        # that enter the first state
        edges = sdfg.edges_between(first_state, second_state)
        for edge in edges:
            if edge.data.assignments:
                for src, dst, other_data in sdfg.in_edges(first_state):
                    other_data.assignments.update(edge.data.assignments)
            sdfg.remove_edge(edge)

        # Special case 1: first state is empty
        if first_state.is_empty():
            sdutil.change_edge_dest(sdfg, first_state, second_state)
            sdfg.remove_node(first_state)
            if start_state is first_state:
                sdfg.start_state = sdfg.node_id(second_state)
            return

        # Special case 2: second state is empty
        if second_state.is_empty():
            sdutil.change_edge_src(sdfg, second_state, first_state)
            sdutil.change_edge_dest(sdfg, second_state, first_state)
            sdfg.remove_node(second_state)
            if start_state is second_state:
                sdfg.start_state = sdfg.node_id(first_state)
            return

        # Normal case: both states are not empty

        # Source (data) nodes of the second state are merge candidates
        second_input = [
            node for node in sdutil.find_source_nodes(second_state)
            if isinstance(node, nodes.AccessNode)
        ]

        top2 = top_level_nodes(second_state)

        # Merge second state to first state
        # First keep a backup of the topological sorted order of the
        # top-level access nodes
        sdict = first_state.scope_dict()
        order = [
            x for x in reversed(list(nx.topological_sort(first_state._nx)))
            if isinstance(x, nodes.AccessNode) and sdict[x] is None
        ]
        for node in second_state.nodes():
            if isinstance(node, nodes.NestedSDFG):
                # update parent information
                node.sdfg.parent = first_state
            first_state.add_node(node)
        for src, src_conn, dst, dst_conn, data in second_state.edges():
            first_state.add_edge(src, src_conn, dst, dst_conn, data)

        top = top_level_nodes(first_state)

        # Merge common (data) nodes
        for node in second_input:

            # merge only top level nodes, skip everything else
            if node not in top2:
                continue

            if first_state.in_degree(node) == 0:
                candidates = [
                    x for x in order if x.data == node.data and x in top
                ]
                if len(candidates) == 0:
                    continue
                elif len(candidates) == 1:
                    n = candidates[0]
                else:
                    # Choose first candidate that intersects memlets
                    for cand in candidates:
                        if StateFusion.memlets_intersect(
                                first_state, [cand], False, second_state,
                            [node], True):
                            n = cand
                            break
                    else:
                        # No node intersects, use topologically-last node
                        n = candidates[0]

                sdutil.change_edge_src(first_state, node, n)
                first_state.remove_node(node)
                n.access = dtypes.AccessType.ReadWrite

        # Redirect edges and remove second state
        sdutil.change_edge_src(sdfg, second_state, first_state)
        sdfg.remove_node(second_state)
        if start_state is second_state:
            sdfg.start_state = sdfg.node_id(first_state)