예제 #1
0
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """Check whether two consecutive states can be fused.

        :param graph: The state machine (SDFG) containing both states.
        :param candidate: Mapping from ``StateFusion._first_state`` /
                          ``StateFusion._second_state`` to node indices in
                          ``graph``.
        :param expr_index: Index of the matched pattern expression (unused).
        :param sdfg: The SDFG being transformed (unused here).
        :param strict: If True, additionally reject fusions where data read
                       or written in the first state is written in the
                       second state, unless sequential ordering is implied.
        :return: True if the fusion can be applied, False otherwise.
        """
        first_state = graph.nodes()[candidate[StateFusion._first_state]]
        second_state = graph.nodes()[candidate[StateFusion._second_state]]

        out_edges = graph.out_edges(first_state)
        in_edges = graph.in_edges(first_state)

        # First state must have only one output edge (with dst the second
        # state).
        if len(out_edges) != 1:
            return False
        # The interstate edge must not have a condition.
        if not out_edges[0].data.is_unconditional():
            return False
        # The interstate edge may have assignments, as long as there are input
        # edges to the first state that can absorb them.
        if out_edges[0].data.assignments and not in_edges:
            return False
        # There can be no state that has output edges pointing to both the
        # first and the second state. Such a case will produce a multi-graph.
        for src, _, _ in in_edges:
            for _, dst, _ in graph.out_edges(src):
                if dst == second_state:
                    return False

        if not strict:
            return True

        # If the second state has other input edges, there might be issues.
        # Exceptions are when none of the states contain dataflow, unless
        # the first state is an initial state (in which case the new initial
        # state would be ambiguous).
        second_in_edges = graph.in_edges(second_state)
        if ((not second_state.is_empty() or not first_state.is_empty()
             or len(in_edges) == 0) and len(second_in_edges) != 1):
            return False

        # Connected components (sets of nodes) of the first state's dataflow.
        first_cc = list(nx.weakly_connected_components(first_state._nx))

        # Source access nodes are inputs; every other access node counts as
        # an output.
        first_input = {
            node
            for node in nxutil.find_source_nodes(first_state)
            if isinstance(node, nodes.AccessNode)
        }
        first_output = {
            node
            for node in first_state.nodes()
            if isinstance(node, nodes.AccessNode) and node not in first_input
        }
        second_input = {
            node
            for node in nxutil.find_source_nodes(second_state)
            if isinstance(node, nodes.AccessNode)
        }
        second_output = {
            node
            for node in second_state.nodes()
            if isinstance(node, nodes.AccessNode) and node not in second_input
        }

        # Access nodes of the two states are matched by label; build the
        # label sets once instead of rescanning them for every node.
        second_input_labels = {node.label for node in second_input}
        second_output_labels = {node.label for node in second_output}

        # Output access nodes of the first state, per connected component.
        first_cc_output = [cc.intersection(first_output) for cc in first_cc]

        # Apply transformation in case all paths to the second state's
        # nodes go through the same access node, which implies sequential
        # behavior in SDFG semantics.
        first_sinks = list(first_state.sink_nodes())
        check_strict = len(first_cc)
        for cc_output in first_cc_output:
            out_nodes = [n for n in first_sinks if n in cc_output]
            # Branching exists, multiple paths may involve same access node
            # potentially causing data races
            if len(out_nodes) > 1:
                continue
            # Otherwise, check whether this component's outputs feed any of
            # the second state's inputs.
            if any(node.label in second_input_labels for node in cc_output):
                check_strict -= 1

        if check_strict > 0:
            # Data read in the first state is written in the second state.
            if any(node.label in second_output_labels for node in first_input):
                return False
            # Data written in both states (write-write dependency).
            if any(node.label in second_output_labels
                   for node in first_output):
                return False

        return True
예제 #2
0
    def apply(self, sdfg):
        """Wrap the contents of ``sdfg`` into a nested SDFG.

        The states of ``sdfg`` are added to a new nested SDFG. Non-transient
        source access nodes are renamed ``<data>_in`` and non-transient sink
        access nodes ``<data>_out``; these renamed names become the connectors
        of the resulting nested-SDFG node. ``sdfg`` is then emptied and given a
        single state in which access nodes of the original data are connected
        to those connectors with full-range memlets.

        If ``self.promote_global_trans`` is set, sinks written through a WCR
        memlet are also registered as inputs, and top-level transient access
        nodes are promoted to non-transient outputs.
        """
        # Copy SDFG to nested SDFG. The state objects themselves are shared
        # with ``sdfg``, so the renaming below is visible in both.
        nested_sdfg = dace.SDFG('nested_' + sdfg.label)
        nested_sdfg.add_nodes_from(sdfg.nodes())
        for src, dst, data in sdfg.edges():
            nested_sdfg.add_edge(src, dst, data)

        # Connector name ('<data>_in' / '<data>_out') -> original data name.
        input_orig = {}
        output_orig = {}
        # Connector name -> outer access node. The copy is taken before the
        # node is renamed, so it still refers to the original data.
        input_nodes = {}
        output_nodes = {}
        # Original (un-suffixed) data names already registered as inputs /
        # outputs; always compared against un-renamed ``node.data``.
        input_data = set()
        output_data = set()

        for state in sdfg.nodes():
            for node in nxutil.find_source_nodes(state):
                if (isinstance(node, nodes.AccessNode)
                        and not node.desc(sdfg).transient):
                    if node.data not in input_data:
                        input_orig[node.data + '_in'] = node.data
                        input_nodes[node.data + '_in'] = dc(node)
                        new_data = dc(node.desc(sdfg))
                        input_data.add(node.data)
                        sdfg.arrays.update({node.data + '_in': new_data})
                    node.data = node.data + '_in'

            for node in nxutil.find_sink_nodes(state):
                if (isinstance(node, nodes.AccessNode)
                        and not node.desc(sdfg).transient):
                    if node.data not in output_data:
                        output_orig[node.data + '_out'] = node.data
                        output_nodes[node.data + '_out'] = dc(node)
                        new_data = dc(node.desc(sdfg))
                        output_data.add(node.data)
                        sdfg.arrays.update({node.data + '_out': new_data})

                        # WCR fix: a sink written through a WCR memlet is also
                        # registered as an input (presumably because the
                        # reduction needs the existing value).
                        if self.promote_global_trans:
                            for edge in state.in_edges(node):
                                if sd._memlet_path(state, edge)[0].data.wcr:
                                    if node.data not in input_data:
                                        input_orig[node.data +
                                                   '_in'] = node.data
                                        input_nodes[node.data +
                                                    '_in'] = dc(node)
                                        new_data = dc(node.desc(sdfg))
                                        sdfg.arrays.update(
                                            {node.data + '_in': new_data})
                                        input_data.add(node.data)
                                    break

                    node.data = node.data + '_out'

            if self.promote_global_trans:
                # Promote top-level transients to (non-transient) outputs.
                scope_dict = state.scope_dict()
                for node in state.nodes():
                    if (isinstance(node, nodes.AccessNode)
                            and node.desc(sdfg).transient
                            and not scope_dict[node]):
                        if node.data not in output_data:
                            output_orig[node.data + '_out'] = node.data
                            output_nodes[node.data + '_out'] = dc(node)
                            new_data = dc(node.desc(sdfg))
                            output_data.add(node.data)
                            sdfg.arrays.update({node.data + '_out': new_data})
                        node.data = node.data + '_out'
                        node.desc(sdfg).transient = False

            # Point each memlet at the renamed data of its path endpoints.
            # Renamed names are exactly the keys of input_orig / output_orig.
            for edge in state.edges():
                _, _, _, _, mem = edge
                path = sd._memlet_path(state, edge)
                src = path[0].src
                dst = path[-1].dst
                if isinstance(src, nodes.AccessNode) and (
                        src.data in input_orig or src.data in output_orig):
                    mem.data = src.data
                if (isinstance(dst, nodes.AccessNode)
                        and dst.data in output_orig):
                    mem.data = dst.data

        sdfg.remove_nodes_from(sdfg.nodes())

        # NOTE(review): the '<data>_in' / '<data>_out' descriptors above are
        # registered on the outer ``sdfg.arrays``; nothing is added to
        # ``nested_sdfg.arrays`` here -- confirm this is intended.
        state = sdfg.add_state(sdfg.label)
        state.add_nodes_from(input_nodes.values())
        state.add_nodes_from(output_nodes.values())

        # Connectors are the renamed names, i.e. the keys used for the edges
        # below (the previous ``input_data.keys()`` failed: it is a set).
        nested_node = state.add_nested_sdfg(nested_sdfg, sdfg,
                                            set(input_nodes.keys()),
                                            set(output_nodes.keys()))
        for key, val in input_nodes.items():
            state.add_edge(
                val, None, nested_node, key,
                memlet.Memlet.simple(
                    val, str(subsets.Range.from_array(val.desc(sdfg)))))
        for key, val in output_nodes.items():
            state.add_edge(
                nested_node, key, val, None,
                memlet.Memlet.simple(
                    val, str(subsets.Range.from_array(val.desc(sdfg)))))
예제 #3
0
    def apply(self, sdfg):
        """Fuse the second matched state into the first one.

        Assignments on the connecting interstate edge(s) are moved onto the
        edges entering the first state. If either state is empty it is simply
        removed. Otherwise all nodes and edges of the second state are moved
        into the first state, and access nodes of the two states that share a
        label are merged (the surviving node is marked ReadWrite).
        """
        first_state = sdfg.nodes()[self.subgraph[StateFusion._first_state]]
        second_state = sdfg.nodes()[self.subgraph[StateFusion._second_state]]

        # Remove interstate edge(s), moving their assignments onto the edges
        # that enter the first state.
        interstate_edges = sdfg.edges_between(first_state, second_state)
        for edge in interstate_edges:
            if edge.data.assignments:
                for _, _, other_data in sdfg.in_edges(first_state):
                    other_data.assignments.update(edge.data.assignments)
            sdfg.remove_edge(edge)

        # Special case 1: first state is empty
        if first_state.is_empty():
            nxutil.change_edge_dest(sdfg, first_state, second_state)
            sdfg.remove_node(first_state)
            return

        # Special case 2: second state is empty
        if second_state.is_empty():
            nxutil.change_edge_src(sdfg, second_state, first_state)
            nxutil.change_edge_dest(sdfg, second_state, first_state)
            sdfg.remove_node(second_state)
            return

        # Normal case: both states are not empty

        # Find source/sink (data) nodes
        first_input = [
            node for node in nxutil.find_source_nodes(first_state)
            if isinstance(node, nodes.AccessNode)
        ]
        first_output = [
            node for node in nxutil.find_sink_nodes(first_state)
            if isinstance(node, nodes.AccessNode)
        ]
        second_input = [
            node for node in nxutil.find_source_nodes(second_state)
            if isinstance(node, nodes.AccessNode)
        ]

        # first input = first input - first output (matched by label)
        first_input = [
            node for node in first_input
            if not any(x.label == node.label for x in first_output)
        ]

        # Merge second state to first state
        # First keep a backup of the topological sorted order of the nodes
        order = [
            x for x in reversed(list(nx.topological_sort(first_state._nx)))
            if isinstance(x, nodes.AccessNode)
        ]
        for node in second_state.nodes():
            first_state.add_node(node)
        for src, src_conn, dst, dst_conn, data in second_state.edges():
            first_state.add_edge(src, src_conn, dst, dst_conn, data)

        # Merge common (data) nodes
        for node in first_input:
            try:
                old_node = next(x for x in second_input
                                if x.label == node.label)
            except StopIteration:
                continue
            nxutil.change_edge_src(first_state, old_node, node)
            first_state.remove_node(old_node)
            second_input.remove(old_node)
            node.access = dtypes.AccessType.ReadWrite
        for node in first_output:
            try:
                new_node = next(x for x in second_input
                                if x.label == node.label)
            except StopIteration:
                continue
            nxutil.change_edge_dest(first_state, node, new_node)
            first_state.remove_node(node)
            second_input.remove(new_node)
            new_node.access = dtypes.AccessType.ReadWrite

        # Check if any input nodes of the second state have to be merged with
        # non-input/output nodes of the first state. ``order`` was captured
        # before the merges above, which may have removed some of its nodes
        # from the state; redirecting edges to such a node would reference a
        # node no longer in the state, so only live nodes are considered.
        live_nodes = set(first_state.nodes())
        order = [x for x in order if x in live_nodes]
        for node in second_input:
            if first_state.in_degree(node) == 0:
                n = next((x for x in order if x.label == node.label), None)
                if n is not None:
                    nxutil.change_edge_src(first_state, node, n)
                    first_state.remove_node(node)
                    n.access = dtypes.AccessType.ReadWrite

        # Redirect edges and remove second state
        nxutil.change_edge_src(sdfg, second_state, first_state)
        sdfg.remove_node(second_state)
        if Config.get_bool("debugprint"):
            StateFusion._states_fused += 1
예제 #4
0
    def apply(self, sdfg):
        """Move the matched state's array data into FPGA global memory.

        Array access nodes that are sources of the state, or that are reached
        through a WCR memlet (possibly inside nested SDFGs), are copied into
        transient ``fpga_<data>`` arrays in a new pre-state. Array sinks are
        copied back from such arrays in a new post-state. Memlets in the state
        are then renamed to the FPGA arrays, and ``fpga_update`` is called.
        """
        state = sdfg.nodes()[self.subgraph[FPGATransformState._state]]

        # Find source/sink (data) nodes
        input_nodes = nxutil.find_source_nodes(state)
        output_nodes = nxutil.find_sink_nodes(state)

        # Data name -> FPGA array created for it.
        fpga_data = {}

        # Input nodes may also be nodes with WCR memlets.
        # We have to recur across nested SDFGs to find them.
        wcr_input_nodes = set()

        parent_sdfg = {state: sdfg}  # Map states to their parent SDFG
        for node, graph in state.all_nodes_recursive():
            if isinstance(graph, dace.SDFG):
                parent_sdfg[node] = graph
            if not isinstance(node, dace.graph.nodes.AccessNode):
                continue
            for e in graph.all_edges(node):
                if e.data.wcr is None:
                    continue
                trace = dace.sdfg.trace_nested_access(
                    node, graph, parent_sdfg[graph])
                for node_trace, state_trace, sdfg_trace in trace:
                    # Find the accessed node in our scope
                    if state_trace == state and sdfg_trace == sdfg:
                        outer_node = node_trace
                        break
                else:
                    # This does not trace back to the current state, so
                    # we don't care
                    continue
                # Record the node object itself: the loop below skips
                # anything that is not an AccessNode and checks membership
                # in wcr_input_nodes by node. Each node is recorded once, and
                # nodes already listed as sources keep their normal handling.
                if outer_node not in input_nodes:
                    input_nodes.append(outer_node)
                    wcr_input_nodes.add(outer_node)

        if input_nodes:
            # create pre_state
            pre_state = sd.SDFGState('pre_' + state.label, sdfg)

            for node in input_nodes:

                if not isinstance(node, dace.graph.nodes.AccessNode):
                    continue
                desc = node.desc(sdfg)
                if not isinstance(desc, dace.data.Array):
                    # TODO: handle streams
                    continue

                if node.data in fpga_data:
                    fpga_array = fpga_data[node.data]
                elif node not in wcr_input_nodes:
                    fpga_array = sdfg.add_array(
                        'fpga_' + node.data,
                        desc.shape,
                        desc.dtype,
                        materialize_func=desc.materialize_func,
                        transient=True,
                        storage=dtypes.StorageType.FPGA_Global,
                        allow_conflicts=desc.allow_conflicts,
                        strides=desc.strides,
                        offset=desc.offset)
                    fpga_data[node.data] = fpga_array

                # Host -> FPGA copy in the pre-state.
                pre_node = pre_state.add_read(node.data)
                pre_fpga_node = pre_state.add_write('fpga_' + node.data)
                full_range = subsets.Range([(0, s - 1, 1) for s in desc.shape])
                mem = memlet.Memlet(node.data, full_range.num_elements(),
                                    full_range, 1)
                pre_state.add_edge(pre_node, None, pre_fpga_node, None, mem)

                # WCR nodes stay in place here; when they are also sinks they
                # are replaced in the output loop below.
                if node not in wcr_input_nodes:
                    fpga_node = state.add_read('fpga_' + node.data)
                    nxutil.change_edge_src(state, node, fpga_node)
                    state.remove_node(node)

            sdfg.add_node(pre_state)
            nxutil.change_edge_dest(sdfg, state, pre_state)
            sdfg.add_edge(pre_state, state, edges.InterstateEdge())

        if output_nodes:

            post_state = sd.SDFGState('post_' + state.label, sdfg)

            for node in output_nodes:

                if not isinstance(node, dace.graph.nodes.AccessNode):
                    continue
                desc = node.desc(sdfg)
                if not isinstance(desc, dace.data.Array):
                    # TODO: handle streams
                    continue

                if node.data in fpga_data:
                    fpga_array = fpga_data[node.data]
                else:
                    fpga_array = sdfg.add_array(
                        'fpga_' + node.data,
                        desc.shape,
                        desc.dtype,
                        materialize_func=desc.materialize_func,
                        transient=True,
                        storage=dtypes.StorageType.FPGA_Global,
                        allow_conflicts=desc.allow_conflicts,
                        strides=desc.strides,
                        offset=desc.offset)
                    fpga_data[node.data] = fpga_array

                # FPGA -> host copy in the post-state.
                post_node = post_state.add_write(node.data)
                post_fpga_node = post_state.add_read('fpga_' + node.data)
                full_range = subsets.Range([(0, s - 1, 1) for s in desc.shape])
                mem = memlet.Memlet('fpga_' + node.data,
                                    full_range.num_elements(), full_range, 1)
                post_state.add_edge(post_fpga_node, None, post_node, None, mem)

                fpga_node = state.add_write('fpga_' + node.data)
                nxutil.change_edge_dest(state, node, fpga_node)
                state.remove_node(node)

            sdfg.add_node(post_state)
            nxutil.change_edge_src(sdfg, state, post_state)
            sdfg.add_edge(state, post_state, edges.InterstateEdge())

        # NOTE(review): veclen_ is never reset per edge, so a vector length
        # read from one nested SDFG also applies to every later edge -- confirm.
        veclen_ = 1

        # propagate vector info from a nested sdfg
        for src, src_conn, dst, dst_conn, mem in state.edges():
            # need to go inside the nested SDFG and grab the vector length
            if isinstance(dst, dace.graph.nodes.NestedSDFG):
                # this edge is going to the nested SDFG
                for inner_state in dst.sdfg.states():
                    for n in inner_state.nodes():
                        if isinstance(n, dace.graph.nodes.AccessNode
                                      ) and n.data == dst_conn:
                            # assuming all memlets have the same vector length
                            veclen_ = inner_state.all_edges(n)[0].data.veclen
            if isinstance(src, dace.graph.nodes.NestedSDFG):
                # this edge is coming from the nested SDFG
                for inner_state in src.sdfg.states():
                    for n in inner_state.nodes():
                        if isinstance(n, dace.graph.nodes.AccessNode
                                      ) and n.data == src_conn:
                            # assuming all memlets have the same vector length
                            veclen_ = inner_state.all_edges(n)[0].data.veclen

            if mem.data is not None and mem.data in fpga_data:
                mem.data = 'fpga_' + mem.data
                mem.veclen = veclen_

        fpga_update(sdfg, state, 0)
예제 #5
0
    def apply(self, sdfg):
        """Double-buffer loop-variable-dependent copies in the loop body.

        The loop variable is taken from the assignments on the body -> guard
        interstate edge; exactly one is supported. For every edge in ``body``
        from a source node to an access node whose memlet subset uses the loop
        variable, a buffer with a leading dimension of 2 is allocated, and an
        initial copy into slot 0 is placed in ``guard``. In ``body``, the
        destination is split into two nodes. The copy into the first node
        targets slot ``(i + 1) % 2``, and memlets on the original data use
        slot ``i % 2``.

        :raises NotImplementedError: For more than one loop variable,
            unsupported descriptor types, or unsupported subset types.
        """
        begin = sdfg.nodes()[self.subgraph[DoubleBuffering._begin]]
        guard = sdfg.nodes()[self.subgraph[DoubleBuffering._guard]]
        body = sdfg.nodes()[self.subgraph[DoubleBuffering._body]]
        end = sdfg.nodes()[self.subgraph[DoubleBuffering._end]]

        # The loop variable is assigned on the body -> guard edge.
        loop_vars = []
        for _, dst, e in sdfg.out_edges(body):
            if dst is guard:
                loop_vars.extend(e.assignments.keys())

        if len(loop_vars) != 1:
            raise NotImplementedError()

        loop_var = loop_vars[0]
        sym_var = dace.symbolic.pystr_to_symbolic(loop_var)

        def _prepend_slot(subset, idx):
            """Return ``subset`` as a Range with a leading ``idx:idx`` dim."""
            if isinstance(subset, dace.subsets.Range):
                return dace.subsets.Range([(idx, idx, 1)] + subset.ranges)
            if isinstance(subset, dace.subsets.Indices):
                return dace.subsets.Range([(idx, idx, 1)] +
                                          [(i, i, 1) for i in subset.indices])
            # Any other subset (e.g. None) previously left the range list
            # unbound or stale from an earlier edge.
            raise NotImplementedError()

        # Find source (data) nodes
        input_nodes = nxutil.find_source_nodes(body)

        copied_nodes = set()
        db_nodes = {}  # original destination node -> its buffer node in guard
        for node in input_nodes:
            for _, _, dst, _, mem in body.out_edges(node):
                if not (isinstance(dst, dace.graph.nodes.AccessNode)
                        and loop_var in mem.subset.free_symbols):
                    continue
                # Create new data and nodes in guard
                if node not in copied_nodes:
                    guard.add_node(node)
                    copied_nodes.add(node)
                if dst not in copied_nodes:
                    old_data = dst.desc(sdfg)
                    # NOTE(review): data descriptors may not expose a ``.data``
                    # name, and reusing the existing array's name would clash
                    # with it -- confirm the intended buffer name.
                    if isinstance(old_data, dace.data.Array):
                        new_shape = tuple([2] + list(old_data.shape))
                        sdfg.add_array(old_data.data,
                                       old_data.dtype,
                                       new_shape,
                                       transient=True)
                    elif isinstance(old_data, dace.data.Scalar):
                        sdfg.add_array(old_data.data,
                                       old_data.dtype, (2, ),
                                       transient=True)
                    else:
                        raise NotImplementedError()
                    buffer_node = dace.graph.nodes.AccessNode(old_data.data)
                    guard.add_node(buffer_node)
                    copied_nodes.add(dst)
                    db_nodes[dst] = buffer_node
                # Always use the buffer node of *this* destination; a variable
                # set in an earlier iteration may belong to another one.
                new_node = db_nodes[dst]

                # Create memlet in guard: the initial copy fills slot 0.
                new_mem = copy.deepcopy(mem)
                old_index = new_mem.other_subset
                if isinstance(old_index, dace.subsets.Range):
                    new_mem.other_subset = dace.subsets.Range(
                        [(0, 0, 1)] + old_index.ranges)
                elif isinstance(old_index, dace.subsets.Indices):
                    new_mem.other_subset = dace.subsets.Indices(
                        [0] + old_index.indices)
                # NOTE(review): other other_subset values (e.g. None) are left
                # unchanged here -- confirm.
                guard.add_edge(node, None, new_node, None, new_mem)

                # Create nodes, memlets in body
                first_node = copy.deepcopy(new_node)
                second_node = copy.deepcopy(new_node)
                body.add_nodes_from([first_node, second_node])
                dace.graph.nxutil.change_edge_dest(body, dst, first_node)
                dace.graph.nxutil.change_edge_src(body, dst, second_node)
                for src, _, dest, _, memm in body.edges():
                    if src is node and dest is first_node:
                        # Copy targets slot (i + 1) % 2.
                        memm.other_subset = _prepend_slot(
                            memm.other_subset, (sym_var + 1) % 2)
                    elif memm.data == dst.data:
                        # Accesses to the original data use slot i % 2.
                        memm.subset = _prepend_slot(memm.subset, sym_var % 2)
                        memm.data = first_node.data
                body.remove_node(dst)
예제 #6
0
    def apply(self, sdfg):
        """Fuse the second matched state into the first one.

        The interstate edge(s) between the two states are deleted after their
        assignments have been copied onto every edge entering the first state.
        An empty state is removed outright. Otherwise, the second state's
        nodes and edges are moved into the first state. First-state inputs and
        outputs are then merged with second-state inputs of the same label.
        """
        all_states = sdfg.nodes()
        first_state = all_states[self.subgraph[StateFusion._first_state]]
        second_state = all_states[self.subgraph[StateFusion._second_state]]

        # Drop the connecting interstate edge(s), keeping their assignments.
        connecting = sdfg.edges_between(first_state, second_state)
        for isedge in connecting:
            pending = isedge.data.assignments
            if pending:
                for _, _, incoming_data in sdfg.in_edges(first_state):
                    incoming_data.assignments.update(pending)
            sdfg.remove_edge(isedge)

        if first_state.is_empty():
            # Special case 1: nothing to merge, bypass the first state.
            nxutil.change_edge_dest(sdfg, first_state, second_state)
            sdfg.remove_node(first_state)
            return
        if second_state.is_empty():
            # Special case 2: nothing to merge, bypass the second state.
            nxutil.change_edge_src(sdfg, second_state, first_state)
            nxutil.change_edge_dest(sdfg, second_state, first_state)
            sdfg.remove_node(second_state)
            return

        def _data_nodes(candidates):
            """Keep only the access nodes of ``candidates``, in order."""
            return [n for n in candidates if isinstance(n, nodes.AccessNode)]

        sources_first = _data_nodes(nxutil.find_source_nodes(first_state))
        first_output = _data_nodes(nxutil.find_sink_nodes(first_state))
        second_input = _data_nodes(nxutil.find_source_nodes(second_state))

        # Inputs of the first state whose label is not also an output label.
        first_input = []
        for candidate in sources_first:
            if not any(out.label == candidate.label for out in first_output):
                first_input.append(candidate)

        # Move the whole second state into the first one.
        for moved in second_state.nodes():
            first_state.add_node(moved)
        for edge in second_state.edges():
            first_state.add_edge(*edge)

        def _take_second_input(label):
            """Pop the first remaining second-state input named ``label``."""
            for pos, candidate in enumerate(second_input):
                if candidate.label == label:
                    return second_input.pop(pos)
            return None

        # A first-state input absorbs the same-label second-state input.
        for survivor in first_input:
            twin = _take_second_input(survivor.label)
            if twin is not None:
                nxutil.change_edge_src(first_state, twin, survivor)
                first_state.remove_node(twin)

        # A first-state output is replaced by the same-label second-state
        # input.
        for producer in first_output:
            consumer = _take_second_input(producer.label)
            if consumer is not None:
                nxutil.change_edge_dest(first_state, producer, consumer)
                first_state.remove_node(producer)

        # Edges that left the second state now leave the fused state.
        nxutil.change_edge_src(sdfg, second_state, first_state)
        sdfg.remove_node(second_state)
        if Config.get_bool("debugprint"):
            StateFusion._states_fused += 1
예제 #7
0
    def apply(self, sdfg):
        """Move the matched state's array data into FPGA global memory.

        For every array access node that is a source (resp. sink) of the
        state, a transient ``fpga_<data>`` array with FPGA global storage is
        created once per data name. A new state before (resp. after) the
        matched state copies between the host array and the FPGA array. The
        matched state is rewired to access the FPGA array instead. Non-array
        nodes (e.g. streams) are left untouched.
        """
        state = sdfg.nodes()[self.subgraph[FPGATransformState._state]]

        # Find source/sink (data) nodes
        input_nodes = nxutil.find_source_nodes(state)
        output_nodes = nxutil.find_sink_nodes(state)

        # Data name -> FPGA array. Keyed by ``node.data`` everywhere so that
        # the lookup, the insertion and the memlet renaming below agree.
        fpga_data = {}

        if input_nodes:

            pre_state = sd.SDFGState('pre_' + state.label, sdfg)

            for node in input_nodes:

                if (not isinstance(node, dace.graph.nodes.AccessNode)
                        or not isinstance(node.desc(sdfg), dace.data.Array)):
                    # Only transfer array nodes
                    # TODO: handle streams
                    continue

                array = node.desc(sdfg)
                if node.data in fpga_data:
                    fpga_array = fpga_data[node.data]
                else:
                    fpga_array = sdfg.add_array(
                        'fpga_' + node.data,
                        array.dtype,
                        array.shape,
                        materialize_func=array.materialize_func,
                        transient=True,
                        storage=types.StorageType.FPGA_Global,
                        allow_conflicts=array.allow_conflicts,
                        access_order=array.access_order,
                        strides=array.strides,
                        offset=array.offset)
                    fpga_data[node.data] = fpga_array

                # Host -> FPGA copy in the pre-state. The host node is moved
                # there; each state gets its own FPGA node object rather than
                # sharing one instance between two states.
                pre_fpga_node = type(node)(fpga_array)
                pre_state.add_node(node)
                pre_state.add_node(pre_fpga_node)
                full_range = subsets.Range([(0, s - 1, 1)
                                            for s in array.shape])
                mem = memlet.Memlet(array, full_range.num_elements(),
                                    full_range, 1)
                pre_state.add_edge(node, None, pre_fpga_node, None, mem)

                fpga_node = type(node)(fpga_array)
                state.add_node(fpga_node)
                nxutil.change_edge_src(state, node, fpga_node)
                state.remove_node(node)

            sdfg.add_node(pre_state)
            nxutil.change_edge_dest(sdfg, state, pre_state)
            sdfg.add_edge(pre_state, state, edges.InterstateEdge())

        if output_nodes:

            post_state = sd.SDFGState('post_' + state.label, sdfg)

            for node in output_nodes:

                if (not isinstance(node, dace.graph.nodes.AccessNode)
                        or not isinstance(node.desc(sdfg), dace.data.Array)):
                    # Only transfer array nodes
                    # TODO: handle streams
                    continue

                array = node.desc(sdfg)
                if node.data in fpga_data:
                    fpga_array = fpga_data[node.data]
                else:
                    fpga_array = sdfg.add_array(
                        'fpga_' + node.data,
                        array.dtype,
                        array.shape,
                        materialize_func=array.materialize_func,
                        transient=True,
                        storage=types.StorageType.FPGA_Global,
                        allow_conflicts=array.allow_conflicts,
                        access_order=array.access_order,
                        strides=array.strides,
                        offset=array.offset)
                    fpga_data[node.data] = fpga_array

                # FPGA -> host copy in the post-state (separate node objects
                # per state, as above).
                post_fpga_node = type(node)(fpga_array)
                post_state.add_node(node)
                post_state.add_node(post_fpga_node)
                full_range = subsets.Range([(0, s - 1, 1)
                                            for s in array.shape])
                mem = memlet.Memlet(fpga_array, full_range.num_elements(),
                                    full_range, 1)
                post_state.add_edge(post_fpga_node, None, node, None, mem)

                fpga_node = type(node)(fpga_array)
                state.add_node(fpga_node)
                nxutil.change_edge_dest(state, node, fpga_node)
                state.remove_node(node)

            sdfg.add_node(post_state)
            nxutil.change_edge_src(sdfg, state, post_state)
            sdfg.add_edge(state, post_state, edges.InterstateEdge())

        # Rename each memlet to the FPGA array of its own data. The previous
        # code used the leftover loop variable ``node`` here, giving every
        # memlet the name of the last processed node.
        for _, _, _, _, mem in state.edges():
            if mem.data is not None and mem.data in fpga_data:
                mem.data = 'fpga_' + mem.data

        fpga_update(state, 0)
예제 #8
0
    def apply(self, sdfg):
        """Offload the matched state to the FPGA.

        Source and sink array access nodes of the matched state are replaced
        by access nodes to transient ``'fpga_' + name`` arrays with
        ``FPGA_Global`` storage. A new ``'pre_' + label`` state, inserted
        before the matched state, copies the inputs into those arrays, and a
        new ``'post_' + label`` state, inserted after it, copies the outputs
        back. Arrays that are written through WCR memlets inside nested SDFGs
        are copied in as well. Finally, memlets of the matched state are
        renamed to the ``fpga_`` arrays and ``fpga_update`` is applied.

        :param sdfg: SDFG containing the matched state; modified in place.
        """
        state = sdfg.nodes()[self.subgraph[FPGATransformState._state]]

        # Find source/sink (data) nodes
        input_nodes = nxutil.find_source_nodes(state)
        output_nodes = nxutil.find_sink_nodes(state)

        # Original array name -> descriptor returned by sdfg.add_array for
        # its 'fpga_' counterpart.
        fpga_data = {}

        # Input nodes may also be nodes with WCR memlets.
        # We have to recur across nested SDFGs to find them.
        wcr_input_nodes = set()

        for node, graph in state.all_nodes_recursive():
            if not isinstance(node, dace.graph.nodes.AccessNode):
                continue
            for e in graph.all_edges(node):
                if e.data.wcr is None:
                    continue
                # This is an output node with wcr: find the target in the
                # parent sdfg. Following the structure State->SDFG->State->SDFG,
                # from the current state we have to go two levels up.
                parent_state = graph.parent.parent
                if parent_state is None:
                    continue
                for parent_edge in parent_state.edges():
                    if parent_edge.src_conn == e.dst.data or (
                            isinstance(parent_edge.dst,
                                       dace.graph.nodes.AccessNode)
                            and e.dst.data == parent_edge.dst.data):
                        target = parent_edge.dst
                        # This must be copied to device. Several WCR edges
                        # (or several parent edges) can lead to the same
                        # node; record it only once so a single copy is
                        # emitted for it in the pre-state.
                        if target not in wcr_input_nodes:
                            input_nodes.append(target)
                            wcr_input_nodes.add(target)

        if input_nodes:
            # create pre_state
            pre_state = sd.SDFGState('pre_' + state.label, sdfg)

            for node in input_nodes:
                if (not isinstance(node, dace.graph.nodes.AccessNode)
                        or not isinstance(node.desc(sdfg), dace.data.Array)):
                    # Only transfer array nodes
                    # TODO: handle streams
                    continue

                array = node.desc(sdfg)
                # WCR targets are not registered here; as state sink nodes
                # they are expected to get their 'fpga_' array in the output
                # loop below.
                if (node.data not in fpga_data
                        and node not in wcr_input_nodes):
                    fpga_data[node.data] = sdfg.add_array(
                        'fpga_' + node.data,
                        array.shape,
                        array.dtype,
                        materialize_func=array.materialize_func,
                        transient=True,
                        storage=dtypes.StorageType.FPGA_Global,
                        allow_conflicts=array.allow_conflicts,
                        strides=array.strides,
                        offset=array.offset)

                # Host -> FPGA copy of the whole array.
                pre_node = pre_state.add_read(node.data)
                pre_fpga_node = pre_state.add_write('fpga_' + node.data)
                full_range = subsets.Range([(0, s - 1, 1)
                                            for s in array.shape])
                mem = memlet.Memlet(node.data, full_range.num_elements(),
                                    full_range, 1)
                pre_state.add_edge(pre_node, None, pre_fpga_node, None, mem)

                # WCR targets stay in place: they are replaced when the
                # output nodes are processed.
                if node not in wcr_input_nodes:
                    fpga_node = state.add_read('fpga_' + node.data)
                    nxutil.change_edge_src(state, node, fpga_node)
                    state.remove_node(node)

            sdfg.add_node(pre_state)
            nxutil.change_edge_dest(sdfg, state, pre_state)
            sdfg.add_edge(pre_state, state, edges.InterstateEdge())

        if output_nodes:

            post_state = sd.SDFGState('post_' + state.label, sdfg)

            for node in output_nodes:
                if (not isinstance(node, dace.graph.nodes.AccessNode)
                        or not isinstance(node.desc(sdfg), dace.data.Array)):
                    # Only transfer array nodes
                    # TODO: handle streams
                    continue

                array = node.desc(sdfg)
                if node.data not in fpga_data:
                    fpga_data[node.data] = sdfg.add_array(
                        'fpga_' + node.data,
                        array.shape,
                        array.dtype,
                        materialize_func=array.materialize_func,
                        transient=True,
                        storage=dtypes.StorageType.FPGA_Global,
                        allow_conflicts=array.allow_conflicts,
                        strides=array.strides,
                        offset=array.offset)

                # FPGA -> host copy of the whole array.
                post_node = post_state.add_write(node.data)
                post_fpga_node = post_state.add_read('fpga_' + node.data)
                full_range = subsets.Range([(0, s - 1, 1)
                                            for s in array.shape])
                mem = memlet.Memlet('fpga_' + node.data,
                                    full_range.num_elements(), full_range, 1)
                post_state.add_edge(post_fpga_node, None, post_node, None, mem)

                fpga_node = state.add_write('fpga_' + node.data)
                nxutil.change_edge_dest(state, node, fpga_node)
                state.remove_node(node)

            sdfg.add_node(post_state)
            nxutil.change_edge_src(sdfg, state, post_state)
            sdfg.add_edge(state, post_state, edges.InterstateEdge())

        # Propagate vector info from a nested SDFG.
        # NOTE(review): veclen_ is not reset per edge, so a vector length
        # found on a nested-SDFG edge also applies to every fpga memlet
        # visited after it (the result depends on edge order) -- confirm this
        # is intended.
        veclen_ = 1

        for src, src_conn, dst, dst_conn, mem in state.edges():
            # need to go inside the nested SDFG and grab the vector length
            if isinstance(dst, dace.graph.nodes.NestedSDFG):
                # this edge is going to the nested SDFG
                for inner_state in dst.sdfg.states():
                    for n in inner_state.nodes():
                        if isinstance(n, dace.graph.nodes.AccessNode
                                      ) and n.data == dst_conn:
                            # assuming all memlets have the same vector length
                            veclen_ = inner_state.all_edges(n)[0].data.veclen
            if isinstance(src, dace.graph.nodes.NestedSDFG):
                # this edge is coming from the nested SDFG
                for inner_state in src.sdfg.states():
                    for n in inner_state.nodes():
                        if isinstance(n, dace.graph.nodes.AccessNode
                                      ) and n.data == src_conn:
                            # assuming all memlets have the same vector length
                            veclen_ = inner_state.all_edges(n)[0].data.veclen

            # Redirect memlets of transferred arrays to their fpga_ copies.
            if mem.data is not None and mem.data in fpga_data:
                mem.data = 'fpga_' + mem.data
                mem.veclen = veclen_

        fpga_update(sdfg, state, 0)
예제 #9
0
    def apply(self, sdfg):
        """Move the whole contents of ``sdfg`` into a nested SDFG.

        A deep copy of ``sdfg`` becomes the nested SDFG. ``sdfg`` itself is
        cleared (arrays and states) and receives a single state that holds
        the nested SDFG node plus one outer access node per connector.

        Inside the nested SDFG, non-transient arrays accessed by source nodes
        are renamed to ``<name>_in`` and those accessed by sink nodes to
        ``<name>_out``. If ``self.promote_global_trans`` is set, transients
        accessed outside of any scope are also renamed to ``<name>_out``,
        made non-transient, and exposed as arrays of the outer SDFG.

        :param sdfg: The SDFG to nest; modified in place.
        """
        outer_sdfg = sdfg
        nested_sdfg = dc(sdfg)

        outer_sdfg.arrays.clear()
        outer_sdfg.remove_nodes_from(outer_sdfg.nodes())

        # Original array name -> array name inside the nested SDFG.
        inputs = {}
        outputs = {}
        transients = {}
        # Transient access nodes seen during promotion. They are renamed only
        # after every state has been scanned, once the promoted set is known.
        transient_nodes = []

        for state in nested_sdfg.nodes():

            for node in nxutil.find_source_nodes(state):
                if (isinstance(node, nodes.AccessNode)
                        and not node.desc(nested_sdfg).transient):
                    arrname = node.data
                    if arrname not in inputs:
                        arrobj = nested_sdfg.arrays[arrname]
                        nested_sdfg.arrays[arrname + '_in'] = arrobj
                        outer_sdfg.arrays[arrname] = dc(arrobj)
                        inputs[arrname] = arrname + '_in'
                    node.data = arrname + '_in'

            for node in nxutil.find_sink_nodes(state):
                if (isinstance(node, nodes.AccessNode)
                        and not node.desc(nested_sdfg).transient):
                    arrname = node.data
                    if arrname not in outputs:
                        arrobj = nested_sdfg.arrays[arrname]
                        nested_sdfg.arrays[arrname + '_out'] = arrobj
                        if arrname not in inputs:
                            outer_sdfg.arrays[arrname] = dc(arrobj)
                        outputs[arrname] = arrname + '_out'
                        # TODO: An earlier version also registered an
                        # '<name>_in' input copy for sink arrays written
                        # through WCR memlets; check whether that is still
                        # needed.
                    node.data = arrname + '_out'

            if self.promote_global_trans:
                scope_dict = state.scope_dict()
                for node in state.nodes():
                    if (isinstance(node, nodes.AccessNode)
                            and node.desc(nested_sdfg).transient):
                        arrname = node.data
                        # Only transients accessed outside of any scope are
                        # promoted.
                        if arrname not in transients and not scope_dict[node]:
                            arrobj = nested_sdfg.arrays[arrname]
                            nested_sdfg.arrays[arrname + '_out'] = arrobj
                            outer_sdfg.arrays[arrname] = dc(arrobj)
                            transients[arrname] = arrname + '_out'
                        transient_nodes.append(node)

        # Rename a transient node only if its array was promoted. A transient
        # that only appears inside scopes keeps its name; renaming it would
        # point the node at a '<name>_out' array that was never created.
        for node in transient_nodes:
            if node.data in transients:
                node.data = transients[node.data]

        for arrname in inputs:
            nested_sdfg.arrays.pop(arrname)
        for arrname in outputs:
            nested_sdfg.arrays.pop(arrname, None)
        for oldarrname, newarrname in transients.items():
            nested_sdfg.arrays.pop(oldarrname)
            nested_sdfg.arrays[newarrname].transient = False
            outer_sdfg.arrays[oldarrname].transient = False
        outputs.update(transients)

        # Rename memlets to the nested array names, using the access nodes
        # at either end of each memlet path.
        for state in nested_sdfg.nodes():
            for edge in state.edges():
                mem = edge.data
                path = state.memlet_path(edge)
                src = path[0].src
                dst = path[-1].dst
                new_name = None
                if isinstance(src, nodes.AccessNode):
                    if (mem.data in inputs
                            and src.data == inputs[mem.data]):
                        new_name = inputs[mem.data]
                    elif (mem.data in outputs
                          and src.data == outputs[mem.data]):
                        new_name = outputs[mem.data]
                # Also check the destination when the source did not match,
                # e.g. an AccessNode -> AccessNode copy whose memlet names
                # the written (output) array.
                if (new_name is None and isinstance(dst, nodes.AccessNode)
                        and mem.data in outputs
                        and dst.data == outputs[mem.data]):
                    new_name = outputs[mem.data]
                if new_name is not None:
                    mem.data = new_name

        outer_state = outer_sdfg.add_state(outer_sdfg.label)

        nested_node = outer_state.add_nested_sdfg(nested_sdfg, outer_sdfg,
                                                  inputs.values(),
                                                  outputs.values())
        for key, val in inputs.items():
            arrnode = outer_state.add_read(key)
            outer_state.add_edge(
                arrnode, None, nested_node, val,
                memlet.Memlet.from_array(key, arrnode.desc(outer_sdfg)))
        for key, val in outputs.items():
            arrnode = outer_state.add_write(key)
            outer_state.add_edge(
                nested_node, val, arrnode, None,
                memlet.Memlet.from_array(key, arrnode.desc(outer_sdfg)))
예제 #10
0
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """Check whether the two matched states can be fused into one.

        :param graph: The state graph containing the candidate states.
        :param candidate: Maps ``StateFusion._first_state`` and
                          ``StateFusion._second_state`` to indices into
                          ``graph.nodes()``.
        :param expr_index: Unused.
        :param sdfg: Unused.
        :param strict: If True, additionally reject candidates where the
                       second state has other predecessors (unless both
                       states are empty), or where data read or written in
                       the first state is written again in the second state
                       without a connecting dependency.
        :return: True if the states can be fused, False otherwise.
        """
        first_state = graph.nodes()[candidate[StateFusion._first_state]]
        second_state = graph.nodes()[candidate[StateFusion._second_state]]

        out_edges = graph.out_edges(first_state)
        in_edges = graph.in_edges(first_state)

        # First state must have only one output edge (with dst the second
        # state).
        if len(out_edges) != 1:
            return False
        # The interstate edge must not have a condition.
        # NOTE(review): assumes an unconditional edge renders its condition
        # as ''; confirm against InterstateEdge for this DaCe version.
        if out_edges[0].data.condition.as_string != '':
            return False
        # The interstate edge may have assignments, as long as there are input
        # edges to the first state, that can absorb them.
        if out_edges[0].data.assignments and not in_edges:
            return False
        # There can be no state that have output edges pointing to both the
        # first and the second state. Such a case will produce a multi-graph.
        for src, _, _ in in_edges:
            for _, dst, _ in graph.out_edges(src):
                if dst == second_state:
                    return False

        if strict:

            # If second state has other input edges, there might be issues
            second_in_edges = graph.in_edges(second_state)
            if ((not second_state.is_empty() or not first_state.is_empty())
                    and len(second_in_edges) != 1):
                return False

            # Weakly connected components (node sets) of the first state.
            first_cc = list(nx.weakly_connected_components(first_state._nx))

            # Find source/sink (data) nodes
            first_input = {
                node
                for node in nxutil.find_source_nodes(first_state)
                if isinstance(node, nodes.AccessNode)
            }
            first_output = {
                node
                for node in first_state.nodes() if
                isinstance(node, nodes.AccessNode) and node not in first_input
            }
            second_input = {
                node
                for node in nxutil.find_source_nodes(second_state)
                if isinstance(node, nodes.AccessNode)
            }
            second_output = {
                node
                for node in second_state.nodes() if
                isinstance(node, nodes.AccessNode) and node not in second_input
            }

            # Nodes are matched by label only; collect the labels once.
            second_input_labels = {node.label for node in second_input}
            second_output_labels = {node.label for node in second_output}

            # Count the connected components of the first state none of whose
            # output access nodes is read (by label) by a source access node
            # of the second state.
            check_strict = len(first_cc)
            for cc in first_cc:
                if any(node.label in second_input_labels
                       for node in cc.intersection(first_output)):
                    check_strict -= 1

            if check_strict > 0:
                # Check strict conditions
                # RW dependency
                if any(node.label in second_output_labels
                       for node in first_input):
                    return False
                # WW dependency
                if any(node.label in second_output_labels
                       for node in first_output):
                    return False

        return True