Beispiel #1
0
    def expressions():
        """ Builds the pattern graphs that this transformation matches.

            Both patterns consist of a loop-guard state with two outgoing
            interstate edges: one into the loop-begin state and one into the
            exit state.
            The first pattern also contains an edge from the loop-begin state
            back to the guard (a loop whose body is a single state). The
            second pattern omits that edge.
            @return: A list holding the single-state pattern followed by the
                     multiple-state pattern.
        """
        guard = DetectLoop._loop_guard
        begin = DetectLoop._loop_begin
        exit_state = DetectLoop._exit_state

        def _guard_with_branches():
            # Guard state branching into the loop and out of it
            pattern = sd.SDFG('_')
            pattern.add_nodes_from([guard, begin, exit_state])
            pattern.add_edge(guard, begin, edges.InterstateEdge())
            pattern.add_edge(guard, exit_state, edges.InterstateEdge())
            return pattern

        # Case 1: Loop with one state, which jumps straight back to the guard
        single_state_loop = _guard_with_branches()
        single_state_loop.add_edge(begin, guard, edges.InterstateEdge())

        # Case 2: Loop with multiple states (no back-edge from state)
        multi_state_loop = _guard_with_branches()

        return [single_state_loop, multi_state_loop]
Beispiel #2
0
def _iterate_condition_source(param, end, step):
    """ Returns the source of the continuation condition of an iteration
        over an inclusive range that ends at `end`.

        With a non-negative step the loop runs while `param <= end`, written
        as `param < end + 1`. With a negative step it runs while
        `param >= end`, written as `param > end - 1`.
        @param param: The loop variable (formatted with %s).
        @param end: The inclusive end of the range.
        @param step: The range step; only its sign is used.
        @return: A parenthesized condition string, e.g. "(i < 10)".
    """
    if step >= 0:
        return '(%s < %s)' % (param, end + 1)
    return '(%s > %s)' % (param, end - 1)


def create_states_simple(pdp,
                         out_sdfg,
                         start_state=None,
                         end_state=None,
                         start_edge=None):
    """ Creates a state per primitive, with the knowledge that they can be 
        optimized later.
        @param pdp: A parsed dace program.
        @param out_sdfg: The output SDFG.
        @param start_state: The starting/parent state to connect from (for
                            recursive calls).
        @param end_state: The end/parent state to connect to (for
                          recursive calls).
        @param start_edge: The interstate edge used to connect `start_state`
                           to the first created state (for recursive calls,
                           e.g., carrying a loop or branch condition). If
                           None, an unconditional edge is used.
        @return: A dictionary mapping between a state and the list of dace
                 primitives included in it.
    """
    state_to_primitives = OrderedDict()

    # Create starting state and edge
    if start_state is None:
        start_state = out_sdfg.add_state('start')
        state_to_primitives[start_state] = []
    if start_edge is None:
        start_edge = ed.InterstateEdge()

    previous_state = start_state
    previous_edge = start_edge

    for i, primitive in enumerate(pdp.children):
        state = out_sdfg.add_state(primitive.name)
        state_to_primitives[state] = []
        # Edge that can be created on entry to control flow children
        entry_edge = None

        #########################################
        # Cases depending on primitive type
        #########################################

        # Nothing special happens with a dataflow node (nested states are
        # handled with a separate call to create_states_simple)
        if isinstance(primitive, astnodes._DataFlowNode):
            out_sdfg.add_edge(previous_state, state, previous_edge)
            state_to_primitives[state] = [primitive]
            previous_state = state
            previous_edge = ed.InterstateEdge()

        # Control flow needs to traverse into children nodes
        elif isinstance(primitive, astnodes._ControlFlowNode):
            # Iteration has >=3 states - begin, loop[...], end; and connects the
            # loop states, as well as the begin to end directly if the condition
            # did not evaluate to true
            if isinstance(primitive, astnodes._IterateNode):

                # The range end is inclusive, so the bound is shifted by one
                # in the direction of iteration (see helper above)
                condition = ast.parse(
                    _iterate_condition_source(
                        primitive.params[0], primitive.range[0][1],
                        primitive.range[0][2])).body[0]
                condition_neg = astutils.negate_expr(condition)

                # Loop-start state
                lstart_state = out_sdfg.add_state(primitive.name + '_start')
                state_to_primitives[lstart_state] = []
                out_sdfg.add_edge(previous_state, lstart_state, previous_edge)
                out_sdfg.add_edge(
                    lstart_state, state,
                    ed.InterstateEdge(
                        assignments={
                            primitive.params[0]: primitive.range[0][0]
                        }))

                # Loop-end state that jumps back to `state`
                loop_state = out_sdfg.add_state(primitive.name + '_end')
                state_to_primitives[loop_state] = []
                # Connect loop
                out_sdfg.add_edge(
                    loop_state, state,
                    ed.InterstateEdge(
                        assignments={
                            primitive.params[0]:
                            symbolic.pystr_to_symbolic(primitive.params[0]) +
                            primitive.range[0][2]
                        }))

                # End connection
                previous_state = state
                previous_edge = ed.InterstateEdge(condition=condition_neg)

                # Create children states
                cmap = create_states_simple(
                    primitive, out_sdfg, state, loop_state,
                    ed.InterstateEdge(condition=condition))
                state_to_primitives.update(cmap)

            # Loop is similar to iterate, but more general w.r.t. conditions
            elif isinstance(primitive, astnodes._LoopNode):
                loop_condition = primitive.condition

                # Entry
                out_sdfg.add_edge(previous_state, state, previous_edge)

                # Loop-end state that jumps back to `state`
                loop_state = out_sdfg.add_state(primitive.name + '_end')
                state_to_primitives[loop_state] = []

                # Loopback
                out_sdfg.add_edge(loop_state, state, ed.InterstateEdge())
                # End connection
                previous_state = state
                previous_edge = ed.InterstateEdge(
                    condition=astutils.negate_expr(loop_condition))
                entry_edge = ed.InterstateEdge(condition=loop_condition)

                # Create children states
                cmap = create_states_simple(primitive, out_sdfg, state,
                                            loop_state, entry_edge)
                state_to_primitives.update(cmap)

            elif isinstance(primitive, astnodes._IfNode):
                if_condition = primitive.condition
                # Check if we have an else node, otherwise add a skip condition
                # ourselves
                if (i + 1) < len(pdp.children) and isinstance(
                        pdp.children[i + 1], astnodes._ElseNode):
                    has_else = True
                    else_prim = pdp.children[i + 1]
                    else_condition = else_prim.condition
                else:
                    has_else = False
                    else_condition = astutils.negate_expr(primitive.condition)

                # End-of-branch state (converge to this)
                bend_state = out_sdfg.add_state(primitive.name + '_end')
                state_to_primitives[bend_state] = []

                # Entry
                out_sdfg.add_edge(previous_state, state, previous_edge)

                # Create children states
                cmap = create_states_simple(
                    primitive, out_sdfg, state, bend_state,
                    ed.InterstateEdge(condition=if_condition))
                state_to_primitives.update(cmap)

                # Handle 'else' condition
                if not has_else:
                    out_sdfg.add_edge(
                        state, bend_state,
                        ed.InterstateEdge(condition=else_condition))
                else:
                    # Recursively parse 'else' primitive's children
                    cmap = create_states_simple(
                        else_prim, out_sdfg, state, bend_state,
                        ed.InterstateEdge(condition=else_condition))
                    state_to_primitives.update(cmap)

                # Exit
                previous_state = bend_state
                previous_edge = ed.InterstateEdge()

            elif isinstance(primitive, astnodes._ElseNode):
                if i - 1 < 0 or not isinstance(pdp.children[i - 1],
                                               astnodes._IfNode):
                    raise SyntaxError('Found else state without matching if')

                # If 'else' state is correct, we already processed it together
                # with the preceding 'if', so drop the state created above
                del state_to_primitives[state]
                out_sdfg.remove_node(state)

    # Connect to end_state (and create it if necessary)
    if end_state is None:
        end_state = out_sdfg.add_state('end')
        state_to_primitives[end_state] = []
    out_sdfg.add_edge(previous_state, end_state, previous_edge)

    return state_to_primitives
Beispiel #3
0
class StateFusion(pattern_matching.Transformation):
    """ Implements the state-fusion transformation.

        State-fusion takes two states that are connected through a single edge,
        and fuses them into one state. If strict, only applies if no memory
        access hazards are created.
    """

    # Number of fusions performed by `apply` in the general case (both states
    # non-empty); only incremented when the "debugprint" flag is set.
    _states_fused = 0
    # Pattern nodes: first state -> (interstate edge) -> second state
    _first_state = sdfg.SDFGState()
    _edge = edges.InterstateEdge()
    _second_state = sdfg.SDFGState()

    @staticmethod
    def annotates_memlets():
        return False

    @staticmethod
    def expressions():
        return [
            nxutil.node_path_graph(StateFusion._first_state,
                                   StateFusion._second_state)
        ]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Checks whether the matched pair of states can be fused.

            @param graph: The SDFG in which the pattern was matched.
            @param candidate: Maps pattern nodes to node indices in `graph`.
            @param expr_index: Index of the matched expression (not used).
            @param sdfg: The SDFG (not used by this check).
            @param strict: If True, additionally reject fusions that may
                           create read-write or write-write hazards.
            @return: True if the states can be fused, False otherwise.
        """
        first_state = graph.nodes()[candidate[StateFusion._first_state]]
        second_state = graph.nodes()[candidate[StateFusion._second_state]]

        out_edges = graph.out_edges(first_state)
        in_edges = graph.in_edges(first_state)

        # First state must have only one output edge (with dst the second
        # state).
        if len(out_edges) != 1:
            return False
        # The interstate edge must not have a condition.
        if not out_edges[0].data.is_unconditional():
            return False
        # The interstate edge may have assignments, as long as there are input
        # edges to the first state that can absorb them.
        if out_edges[0].data.assignments and not in_edges:
            return False
        # There can be no state that has output edges pointing to both the
        # first and the second state. Such a case will produce a multi-graph.
        for src, _, _ in in_edges:
            for _, dst, _ in graph.out_edges(src):
                if dst == second_state:
                    return False

        if strict:
            # If second state has other input edges, there might be issues
            # Exceptions are when none of the states contain dataflow, unless
            # the first state is an initial state (in which case the new initial
            # state would be ambiguous).
            first_in_edges = graph.in_edges(first_state)
            second_in_edges = graph.in_edges(second_state)
            if ((not second_state.is_empty() or not first_state.is_empty()
                 or len(first_in_edges) == 0) and len(second_in_edges) != 1):
                return False

            # Weakly connected components of the first state's dataflow graph
            first_cc = list(nx.weakly_connected_components(first_state._nx))

            # Source access nodes are "inputs"; every other access node is
            # considered an "output"
            first_input = {
                node
                for node in nxutil.find_source_nodes(first_state)
                if isinstance(node, nodes.AccessNode)
            }
            first_output = {
                node
                for node in first_state.nodes() if
                isinstance(node, nodes.AccessNode) and node not in first_input
            }
            second_input = {
                node
                for node in nxutil.find_source_nodes(second_state)
                if isinstance(node, nodes.AccessNode)
            }
            second_output = {
                node
                for node in second_state.nodes() if
                isinstance(node, nodes.AccessNode) and node not in second_input
            }

            # Output access nodes, grouped by connected component
            first_cc_output = [
                cc.intersection(first_output) for cc in first_cc
            ]

            # Apply transformation in case all paths to the second state's
            # nodes go through the same access node, which implies sequential
            # behavior in SDFG semantics.
            second_input_labels = {x.label for x in second_input}
            first_sinks = first_state.sink_nodes()
            check_strict = len(first_cc)
            for cc_output in first_cc_output:
                out_nodes = [n for n in first_sinks if n in cc_output]
                # Branching exists, multiple paths may involve same access node
                # potentially causing data races
                if len(out_nodes) > 1:
                    continue

                # Otherwise, the component is sequentialized if one of its
                # outputs is read (by data label) by the second state
                if any(node.label in second_input_labels
                       for node in cc_output):
                    check_strict -= 1

            if check_strict > 0:
                # Check strict conditions
                second_output_labels = {x.label for x in second_output}
                # RW dependency: data read by the first state is written by
                # the second state
                if any(node.label in second_output_labels
                       for node in first_input):
                    return False
                # WW dependency: data written by both states
                if any(node.label in second_output_labels
                       for node in first_output):
                    return False

        return True

    @staticmethod
    def match_to_str(graph, candidate):
        first_state = graph.nodes()[candidate[StateFusion._first_state]]
        second_state = graph.nodes()[candidate[StateFusion._second_state]]

        return " -> ".join(state.label
                           for state in [first_state, second_state])

    def apply(self, sdfg):
        """ Fuses the second matched state into the first one.

            Assignments on the connecting interstate edge are moved onto the
            edges entering the first state. If either state is empty, it is
            simply removed and its edges are redirected. Otherwise the nodes
            and edges of the second state are copied into the first, and
            access nodes with matching data labels are merged.
            @param sdfg: The SDFG containing the matched states.
        """
        first_state = sdfg.nodes()[self.subgraph[StateFusion._first_state]]
        second_state = sdfg.nodes()[self.subgraph[StateFusion._second_state]]

        # Remove interstate edge(s), handing their assignments over to the
        # edges that enter the first state
        fused_edges = sdfg.edges_between(first_state, second_state)
        for edge in fused_edges:
            if edge.data.assignments:
                for _, _, other_data in sdfg.in_edges(first_state):
                    other_data.assignments.update(edge.data.assignments)
            sdfg.remove_edge(edge)

        # Special case 1: first state is empty
        if first_state.is_empty():
            nxutil.change_edge_dest(sdfg, first_state, second_state)
            sdfg.remove_node(first_state)
            return

        # Special case 2: second state is empty
        if second_state.is_empty():
            nxutil.change_edge_src(sdfg, second_state, first_state)
            nxutil.change_edge_dest(sdfg, second_state, first_state)
            sdfg.remove_node(second_state)
            return

        # Normal case: both states are not empty

        # Find source/sink (data) nodes
        first_input = [
            node for node in nxutil.find_source_nodes(first_state)
            if isinstance(node, nodes.AccessNode)
        ]
        first_output = [
            node for node in nxutil.find_sink_nodes(first_state)
            if isinstance(node, nodes.AccessNode)
        ]
        second_input = [
            node for node in nxutil.find_source_nodes(second_state)
            if isinstance(node, nodes.AccessNode)
        ]

        # first input = first input - first output (compared by data label)
        first_output_labels = {x.label for x in first_output}
        first_input = [
            node for node in first_input
            if node.label not in first_output_labels
        ]

        # Merge second state to first state
        # First keep a backup of the topological sorted order of the nodes
        # (reversed, so the most downstream access node is found first)
        order = [
            x for x in reversed(list(nx.topological_sort(first_state._nx)))
            if isinstance(x, nodes.AccessNode)
        ]
        for node in second_state.nodes():
            first_state.add_node(node)
        for src, src_conn, dst, dst_conn, mem in second_state.edges():
            first_state.add_edge(src, src_conn, dst, dst_conn, mem)

        # Merge common (data) nodes
        for node in first_input:
            old_node = next(
                (x for x in second_input if x.label == node.label), None)
            if old_node is None:
                continue
            nxutil.change_edge_src(first_state, old_node, node)
            first_state.remove_node(old_node)
            second_input.remove(old_node)
            node.access = dtypes.AccessType.ReadWrite
        for node in first_output:
            new_node = next(
                (x for x in second_input if x.label == node.label), None)
            if new_node is None:
                continue
            nxutil.change_edge_dest(first_state, node, new_node)
            first_state.remove_node(node)
            second_input.remove(new_node)
            new_node.access = dtypes.AccessType.ReadWrite
        # Check if any input nodes of the second state have to be merged with
        # non-input/output nodes of the first state.
        for node in second_input:
            if first_state.in_degree(node) == 0:
                n = next((x for x in order if x.label == node.label), None)
                if n:
                    nxutil.change_edge_src(first_state, node, n)
                    first_state.remove_node(node)
                    n.access = dtypes.AccessType.ReadWrite

        # Redirect edges and remove second state
        nxutil.change_edge_src(sdfg, second_state, first_state)
        sdfg.remove_node(second_state)
        if Config.get_bool("debugprint"):
            StateFusion._states_fused += 1
Beispiel #4
0
    def apply(self, sdfg):
        """ Replaces the matched map with a nested SDFG that executes the map
            body as a sequential loop.

            The nested SDFG has four states. `begin` assigns the range start
            to the loop variable. `guard` goes to `body` while the variable
            is `<= loop_to` (an inclusive end) and to `end` once it exceeds
            `loop_to`. `body` adds the step to the variable and returns to
            `guard`.
            Only the first map parameter and range (`params[0]` / `range[0]`)
            are used. Presumably the pattern only matches one-dimensional
            maps; TODO confirm.
            @param sdfg: The SDFG containing the state with the matched map.
        """
        # Retrieve map entry and exit nodes.
        graph = sdfg.nodes()[self.state_id]
        map_entry = graph.nodes()[self.subgraph[MapToForLoop._map_entry]]
        map_exits = graph.exit_nodes(map_entry)
        loop_idx = map_entry.map.params[0]
        loop_from, loop_to, loop_step = map_entry.map.range[0]

        nested_sdfg = dace.SDFG(graph.label + '_' + map_entry.map.label)

        # Construct nested SDFG
        begin = nested_sdfg.add_state('begin')
        guard = nested_sdfg.add_state('guard')
        body = nested_sdfg.add_state('body')
        end = nested_sdfg.add_state('end')

        # begin -> guard: initialize the loop variable
        nested_sdfg.add_edge(
            begin, guard,
            edges.InterstateEdge(assignments={str(loop_idx): str(loop_from)}))
        # guard -> body while the (inclusive) end has not been passed
        nested_sdfg.add_edge(
            guard,
            body,
            edges.InterstateEdge(condition = str(loop_idx) + ' <= ' + \
                                             str(loop_to))
        )
        # guard -> end once the loop variable exceeds the end
        nested_sdfg.add_edge(
            guard,
            end,
            edges.InterstateEdge(condition = str(loop_idx) + ' > ' + \
                                             str(loop_to))
        )
        # body -> guard: advance the loop variable by the step
        nested_sdfg.add_edge(
            body,
            guard,
            edges.InterstateEdge(assignments = {str(loop_idx): str(loop_idx) + \
                                                ' + ' +str(loop_step)})
        )

        # Add map contents
        # Nodes and edges strictly inside the map scope are moved into `body`.
        # The memlet objects are added as-is (not copied), so they are the
        # same objects used by the edges of `graph`.
        map_subgraph = graph.scope_subgraph(map_entry)
        for node in map_subgraph.nodes():
            if node is not map_entry and node not in map_exits:
                body.add_node(node)
        for src, src_conn, dst, dst_conn, memlet in map_subgraph.edges():
            if src is not map_entry and dst not in map_exits:
                body.add_edge(src, src_conn, dst, dst_conn, memlet)

        # Reconnect inputs
        # For every edge entering the map entry, create an `_in_<data>`
        # container sized by the (overapproximated) bounding box of the memlet
        # and an access node for it in `body`. The dicts are keyed by the
        # in-edge's position in graph.in_edges(map_entry).
        # NOTE(review): the descriptor is added to the outer `sdfg`, while the
        # access node lives in the nested SDFG's `body` state; confirm this
        # is intended rather than `nested_sdfg.add_array`/`add_scalar`.
        nested_in_data_nodes = {}
        nested_in_connectors = {}
        nested_in_memlets = {}
        for i, edge in enumerate(graph.in_edges(map_entry)):
            src, src_conn, dst, dst_conn, memlet = edge
            data_label = '_in_' + memlet.data
            memdata = sdfg.arrays[memlet.data]
            if isinstance(memdata, data.Array):
                data_array = sdfg.add_array(data_label, memdata.dtype, [
                    symbolic.overapproximate(r)
                    for r in memlet.bounding_box_size()
                ])
            elif isinstance(memdata, data.Scalar):
                data_array = sdfg.add_scalar(data_label, memdata.dtype)
            else:
                raise NotImplementedError()
            data_node = nodes.AccessNode(data_label)
            body.add_node(data_node)
            nested_in_data_nodes.update({i: data_node})
            nested_in_connectors.update({i: data_label})
            nested_in_memlets.update({i: memlet})
            # Rename memlets in the body that refer to the outer data. This
            # mutates the memlet objects in place (see "Add map contents").
            for _, _, _, _, old_memlet in body.edges():
                if old_memlet.data == memlet.data:
                    old_memlet.data = data_label
            #body.add_edge(data_node, None, dst, dst_conn, memlet)

        # Reconnect outputs
        # Same as for inputs, with `_out_<data>` containers per edge leaving
        # each map exit.
        # NOTE(review): `i` restarts at 0 for every map exit, so with more
        # than one exit node the dict entries of earlier exits are
        # overwritten; confirm only a single exit is expected.
        nested_out_data_nodes = {}
        nested_out_connectors = {}
        nested_out_memlets = {}
        for map_exit in map_exits:
            for i, edge in enumerate(graph.out_edges(map_exit)):
                src, src_conn, dst, dst_conn, memlet = edge
                data_label = '_out_' + memlet.data
                memdata = sdfg.arrays[memlet.data]
                if isinstance(memdata, data.Array):
                    data_array = sdfg.add_array(data_label, memdata.dtype, [
                        symbolic.overapproximate(r)
                        for r in memlet.bounding_box_size()
                    ])
                elif isinstance(memdata, data.Scalar):
                    data_array = sdfg.add_scalar(data_label, memdata.dtype)
                else:
                    raise NotImplementedError()
                data_node = nodes.AccessNode(data_label)
                body.add_node(data_node)
                nested_out_data_nodes.update({i: data_node})
                nested_out_connectors.update({i: data_label})
                nested_out_memlets.update({i: memlet})
                for _, _, _, _, old_memlet in body.edges():
                    if old_memlet.data == memlet.data:
                        old_memlet.data = data_label
                #body.add_edge(src, src_conn, data_node, None, memlet)

        # Add nested SDFG and reconnect it
        nested_node = graph.add_nested_sdfg(
            nested_sdfg, sdfg, set(nested_in_connectors.values()),
            set(nested_out_connectors.values()))

        # Outer data -> nested SDFG, one edge per original map input
        for i, edge in enumerate(graph.in_edges(map_entry)):
            src, src_conn, dst, dst_conn, memlet = edge
            graph.add_edge(src, src_conn, nested_node, nested_in_connectors[i],
                           nested_in_memlets[i])

        # Nested SDFG -> outer data, one edge per original map output
        for map_exit in map_exits:
            for i, edge in enumerate(graph.out_edges(map_exit)):
                src, src_conn, dst, dst_conn, memlet = edge
                graph.add_edge(nested_node, nested_out_connectors[i], dst,
                               dst_conn, nested_out_memlets[i])

        # Inside `body`, connect the new input access nodes to the former
        # consumers of the map entry. The index is parsed from the connector
        # name by dropping its first 4 characters and subtracting 1.
        # NOTE(review): assumes out-connectors look like "OUT_<n>" (1-based)
        # and that <n> matches the in-edge enumeration order above; confirm.
        for src, src_conn, dst, dst_conn, memlet in graph.out_edges(map_entry):
            i = int(src_conn[4:]) - 1
            new_memlet = dcpy(memlet)
            new_memlet.data = nested_in_data_nodes[i].data
            body.add_edge(nested_in_data_nodes[i], None, dst, dst_conn,
                          new_memlet)

        # Inside `body`, connect the former producers of each map exit to the
        # new output access nodes.
        # NOTE(review): assumes in-connectors look like "IN_<n>" (1-based);
        # confirm the naming convention.
        for map_exit in map_exits:
            for src, src_conn, dst, dst_conn, memlet in graph.in_edges(
                    map_exit):
                i = int(dst_conn[3:]) - 1
                new_memlet = dcpy(memlet)
                new_memlet.data = nested_out_data_nodes[i].data
                body.add_edge(src, src_conn, nested_out_data_nodes[i], None,
                              new_memlet)

        # Remove the original map scope (nodes of the scope subgraph) from
        # the outer state
        for node in map_subgraph:
            graph.remove_node(node)
Beispiel #5
0
    def apply(self, sdfg):
        """ Moves the data accessed by the matched state to FPGA memory.

            For each array source node, the state's read is redirected to an
            `fpga_<name>` transient in FPGA global memory. The copy into that
            transient happens in a new `pre_<label>` state. Array access nodes
            written through WCR memlets (including through nested SDFGs) are
            also treated as inputs, because their previous contents are read.
            For each array sink node, the state writes to `fpga_<name>`, and a
            new `post_<label>` state copies the data back. Memlets referring
            to moved data are renamed, and their vector length is taken from
            adjacent nested SDFGs where applicable.
            @param sdfg: The SDFG containing the matched state.
        """
        state = sdfg.nodes()[self.subgraph[FPGATransformState._state]]

        # Find source/sink (data) nodes
        input_nodes = nxutil.find_source_nodes(state)
        output_nodes = nxutil.find_sink_nodes(state)

        fpga_data = {}

        # Input nodes may also be nodes with WCR memlets
        # We have to recur across nested SDFGs to find them
        wcr_input_nodes = set()

        parent_sdfg = {state: sdfg}  # Map states to their parent SDFG
        for node, graph in state.all_nodes_recursive():
            if isinstance(graph, dace.SDFG):
                parent_sdfg[node] = graph
            if isinstance(node, dace.graph.nodes.AccessNode):
                for e in graph.all_edges(node):
                    if e.data.wcr is not None:
                        trace = dace.sdfg.trace_nested_access(
                            node, graph, parent_sdfg[graph])
                        for node_trace, state_trace, sdfg_trace in trace:
                            # Find the accessed node in our scope
                            if state_trace == state and sdfg_trace == sdfg:
                                outer_node = node_trace
                                break
                        else:
                            # This does not trace back to the current state, so
                            # we don't care
                            continue
                        # Record the access node itself, not its data name:
                        # the loop below skips non-AccessNode entries and
                        # checks `wcr_input_nodes` membership by node. Skip
                        # duplicates (a node may have several WCR edges).
                        if outer_node not in wcr_input_nodes:
                            input_nodes.append(outer_node)
                            wcr_input_nodes.add(outer_node)

        if input_nodes:
            # create pre_state
            pre_state = sd.SDFGState('pre_' + state.label, sdfg)

            for node in input_nodes:

                if not isinstance(node, dace.graph.nodes.AccessNode):
                    continue
                desc = node.desc(sdfg)
                if not isinstance(desc, dace.data.Array):
                    # TODO: handle streams
                    continue

                if node.data in fpga_data:
                    fpga_array = fpga_data[node.data]
                elif node not in wcr_input_nodes:
                    fpga_array = sdfg.add_array(
                        'fpga_' + node.data,
                        desc.shape,
                        desc.dtype,
                        materialize_func=desc.materialize_func,
                        transient=True,
                        storage=dtypes.StorageType.FPGA_Global,
                        allow_conflicts=desc.allow_conflicts,
                        strides=desc.strides,
                        offset=desc.offset)
                    fpga_data[node.data] = fpga_array

                # Copy the whole array into its FPGA counterpart before the
                # state runs
                pre_node = pre_state.add_read(node.data)
                pre_fpga_node = pre_state.add_write('fpga_' + node.data)
                full_range = subsets.Range([(0, s - 1, 1) for s in desc.shape])
                mem = memlet.Memlet(node.data, full_range.num_elements(),
                                    full_range, 1)
                pre_state.add_edge(pre_node, None, pre_fpga_node, None, mem)

                # WCR targets keep their node in the state; it is handled as an
                # output below
                if node not in wcr_input_nodes:
                    fpga_node = state.add_read('fpga_' + node.data)
                    nxutil.change_edge_src(state, node, fpga_node)
                    state.remove_node(node)

            sdfg.add_node(pre_state)
            nxutil.change_edge_dest(sdfg, state, pre_state)
            sdfg.add_edge(pre_state, state, edges.InterstateEdge())

        if output_nodes:

            post_state = sd.SDFGState('post_' + state.label, sdfg)

            for node in output_nodes:

                if not isinstance(node, dace.graph.nodes.AccessNode):
                    continue
                desc = node.desc(sdfg)
                if not isinstance(desc, dace.data.Array):
                    # TODO: handle streams
                    continue

                if node.data in fpga_data:
                    fpga_array = fpga_data[node.data]
                else:
                    fpga_array = sdfg.add_array(
                        'fpga_' + node.data,
                        desc.shape,
                        desc.dtype,
                        materialize_func=desc.materialize_func,
                        transient=True,
                        storage=dtypes.StorageType.FPGA_Global,
                        allow_conflicts=desc.allow_conflicts,
                        strides=desc.strides,
                        offset=desc.offset)
                    fpga_data[node.data] = fpga_array

                # Copy the whole FPGA array back after the state has run
                post_node = post_state.add_write(node.data)
                post_fpga_node = post_state.add_read('fpga_' + node.data)
                full_range = subsets.Range([(0, s - 1, 1) for s in desc.shape])
                mem = memlet.Memlet('fpga_' + node.data,
                                    full_range.num_elements(), full_range, 1)
                post_state.add_edge(post_fpga_node, None, post_node, None, mem)

                fpga_node = state.add_write('fpga_' + node.data)
                nxutil.change_edge_dest(state, node, fpga_node)
                state.remove_node(node)

            sdfg.add_node(post_state)
            nxutil.change_edge_src(sdfg, state, post_state)
            sdfg.add_edge(state, post_state, edges.InterstateEdge())

        # Vector length applied to renamed memlets. It starts at 1 and keeps
        # the last value found in a nested SDFG across subsequent edges.
        veclen_ = 1

        # propagate vector info from a nested sdfg
        for src, src_conn, dst, dst_conn, mem in state.edges():
            # need to go inside the nested SDFG and grab the vector length
            if isinstance(dst, dace.graph.nodes.NestedSDFG):
                # this edge is going to the nested SDFG
                for inner_state in dst.sdfg.states():
                    for n in inner_state.nodes():
                        if isinstance(n, dace.graph.nodes.AccessNode
                                      ) and n.data == dst_conn:
                            # assuming all memlets have the same vector length
                            veclen_ = inner_state.all_edges(n)[0].data.veclen
            if isinstance(src, dace.graph.nodes.NestedSDFG):
                # this edge is coming from the nested SDFG
                for inner_state in src.sdfg.states():
                    for n in inner_state.nodes():
                        if isinstance(n, dace.graph.nodes.AccessNode
                                      ) and n.data == src_conn:
                            # assuming all memlets have the same vector length
                            veclen_ = inner_state.all_edges(n)[0].data.veclen

            if mem.data is not None and mem.data in fpga_data:
                mem.data = 'fpga_' + mem.data
                mem.veclen = veclen_

        fpga_update(sdfg, state, 0)
Beispiel #6
0
    def apply(self, sdfg: sd.SDFG):
        """ Transforms ``sdfg`` in place so that it executes on a GPU.

            The steps performed are:
              0. Collect the non-transient data read/written in every state
                 (plus WCR targets without identity) and the code nodes that
                 are not inside any scope.
              1. Create a ``gpu_``-prefixed transient GPU_Global clone of each
                 collected descriptor and redirect access nodes and memlets
                 to the clones.
              2. Prepend a copy-in state (original arrays -> GPU clones).
              3. Append a copy-out state after every sink state
                 (GPU clones -> original arrays).
              4. Move top-level transients to GPU_Global storage and, when
                 ``register_trans`` is set, scoped transients to registers.
              5. Wrap each collected free code node in a one-iteration GPU
                 map.
              6. Schedule top-level maps on the GPU (scoped maps sequentially
                 when ``sequential_innermaps`` is set).
              7. When ``strict_transform`` is set, greedily apply strict
                 StateFusion and RedundantArray matches.

            @param sdfg: The SDFG to transform; it is modified in place.
        """

        #######################################################
        # Step 0: SDFG metadata

        # Find all input and output data descriptors, stored as
        # (name, descriptor) tuples. The lists hold tuples, so testing a bare
        # name against them never matches; the names are tracked separately
        # to avoid emitting duplicate copy-in/copy-out transfers.
        input_nodes = []
        output_nodes = []
        input_names = set()
        output_names = set()
        global_code_nodes = [[] for _ in sdfg.nodes()]

        for i, state in enumerate(sdfg.nodes()):
            sdict = state.scope_dict()
            for node in state.nodes():
                if (isinstance(node, nodes.AccessNode)
                        and not node.desc(sdfg).transient):
                    if (state.out_degree(node) > 0
                            and node.data not in input_names):
                        input_nodes.append((node.data, node.desc(sdfg)))
                        input_names.add(node.data)
                    if (state.in_degree(node) > 0
                            and node.data not in output_names):
                        output_nodes.append((node.data, node.desc(sdfg)))
                        output_names.add(node.data)
                elif isinstance(node, nodes.CodeNode) and sdict[node] is None:
                    if not isinstance(node, nodes.EmptyTasklet):
                        global_code_nodes[i].append(node)

            # Input nodes may also be nodes with WCR memlets and no identity
            for e in state.edges():
                if e.data.wcr is not None and e.data.wcr_identity is None:
                    wcr_name = e.data.data
                    if (wcr_name not in input_names
                            and not sdfg.arrays[wcr_name].transient):
                        # Store a (name, descriptor) tuple like every other
                        # entry: the loops below unpack each entry into two.
                        input_nodes.append((wcr_name, sdfg.arrays[wcr_name]))
                        input_names.add(wcr_name)

        start_state = sdfg.start_state
        end_states = sdfg.sink_nodes()

        #######################################################
        # Step 1: Create cloned GPU arrays and replace originals

        cloned_arrays = {}
        for inodename, inode in input_nodes:
            newdesc = inode.clone()
            newdesc.storage = types.StorageType.GPU_Global
            newdesc.transient = True
            sdfg.add_datadesc('gpu_' + inodename, newdesc)
            cloned_arrays[inodename] = 'gpu_' + inodename

        for onodename, onode in output_nodes:
            # Data that is both read and written shares a single clone
            if onodename in cloned_arrays:
                continue
            newdesc = onode.clone()
            newdesc.storage = types.StorageType.GPU_Global
            newdesc.transient = True
            sdfg.add_datadesc('gpu_' + onodename, newdesc)
            cloned_arrays[onodename] = 'gpu_' + onodename

        # Replace nodes
        for state in sdfg.nodes():
            for node in state.nodes():
                if (isinstance(node, nodes.AccessNode)
                        and node.data in cloned_arrays):
                    node.data = cloned_arrays[node.data]

        # Replace memlets
        for state in sdfg.nodes():
            for edge in state.edges():
                if edge.data.data in cloned_arrays:
                    edge.data.data = cloned_arrays[edge.data.data]

        #######################################################
        # Step 2: Create copy-in state

        copyin_state = sdfg.add_state(sdfg.label + '_copyin')
        sdfg.add_edge(copyin_state, start_state, ed.InterstateEdge())

        for nname, desc in input_nodes:
            src_array = nodes.AccessNode(nname, debuginfo=desc.debuginfo)
            dst_array = nodes.AccessNode(cloned_arrays[nname],
                                         debuginfo=desc.debuginfo)
            copyin_state.add_node(src_array)
            copyin_state.add_node(dst_array)
            copyin_state.add_nedge(
                src_array, dst_array,
                memlet.Memlet.from_array(src_array.data, src_array.desc(sdfg)))

        #######################################################
        # Step 3: Create copy-out state

        copyout_state = sdfg.add_state(sdfg.label + '_copyout')
        for state in end_states:
            sdfg.add_edge(state, copyout_state, ed.InterstateEdge())

        for nname, desc in output_nodes:
            src_array = nodes.AccessNode(cloned_arrays[nname],
                                         debuginfo=desc.debuginfo)
            dst_array = nodes.AccessNode(nname, debuginfo=desc.debuginfo)
            copyout_state.add_node(src_array)
            copyout_state.add_node(dst_array)
            copyout_state.add_nedge(
                src_array, dst_array,
                memlet.Memlet.from_array(dst_array.data, dst_array.desc(sdfg)))

        #######################################################
        # Step 4: Modify transient data storage

        for state in sdfg.nodes():
            sdict = state.scope_dict()
            for node in state.nodes():
                if isinstance(node,
                              nodes.AccessNode) and node.desc(sdfg).transient:
                    nodedesc = node.desc(sdfg)
                    if sdict[node] is None:
                        # NOTE: the cloned arrays match too but it's the same
                        # storage so we don't care
                        nodedesc.storage = types.StorageType.GPU_Global

                        # Try to move allocation/deallocation out of loops
                        if self.toplevel_trans:
                            nodedesc.toplevel = True
                    else:
                        # Make internal transients registers
                        if self.register_trans:
                            nodedesc.storage = types.StorageType.Register

        #######################################################
        # Step 5: Wrap free tasklets and nested SDFGs with a GPU map

        for state, gcodes in zip(sdfg.nodes(), global_code_nodes):
            for gcode in gcodes:
                # Create map and connectors
                me, mx = state.add_map(gcode.label + '_gmap',
                                       {gcode.label + '__gmapi': '0:1'},
                                       schedule=types.ScheduleType.GPU_Device)
                # Store in/out edges in lists so that they don't get corrupted
                # when they are removed from the graph
                in_edges = list(state.in_edges(gcode))
                out_edges = list(state.out_edges(gcode))
                me.in_connectors = set('IN_' + e.dst_conn for e in in_edges)
                me.out_connectors = set('OUT_' + e.dst_conn for e in in_edges)
                mx.in_connectors = set('IN_' + e.src_conn for e in out_edges)
                mx.out_connectors = set('OUT_' + e.src_conn for e in out_edges)

                # Create memlets through map
                for e in in_edges:
                    state.remove_edge(e)
                    state.add_edge(e.src, e.src_conn, me, 'IN_' + e.dst_conn,
                                   e.data)
                    state.add_edge(me, 'OUT_' + e.dst_conn, e.dst, e.dst_conn,
                                   e.data)
                for e in out_edges:
                    state.remove_edge(e)
                    state.add_edge(e.src, e.src_conn, mx, 'IN_' + e.src_conn,
                                   e.data)
                    state.add_edge(mx, 'OUT_' + e.src_conn, e.dst, e.dst_conn,
                                   e.data)

                # Map without inputs: keep the code node inside the map scope
                if len(in_edges) == 0:
                    state.add_nedge(me, gcode, memlet.EmptyMemlet())
        #######################################################
        # Step 6: Change all top-level maps to GPU maps

        for state in sdfg.nodes():
            sdict = state.scope_dict()
            for node in state.nodes():
                if isinstance(node, nodes.EntryNode):
                    if sdict[node] is None:
                        node.schedule = types.ScheduleType.GPU_Device
                    elif self.sequential_innermaps:
                        node.schedule = types.ScheduleType.Sequential

        #######################################################
        # Step 7: Strict transformations
        if not self.strict_transform:
            return

        # Apply strict state fusions greedily.
        opt = optimizer.SDFGOptimizer(sdfg, inplace=True)

        def _strict_matches():
            # Only state fusions and redundant-array removals are applied here
            return [
                match for match in opt.get_pattern_matches(strict=True)
                if isinstance(match, (StateFusion, RedundantArray))
            ]

        fusions = 0
        arrays = 0
        options = _strict_matches()
        while options:
            ssdfg = sdfg.sdfg_list[options[0].sdfg_id]
            options[0].apply(ssdfg)
            ssdfg.validate()
            if isinstance(options[0], StateFusion):
                fusions += 1
            if isinstance(options[0], RedundantArray):
                arrays += 1

            # Matches are recomputed since applying one invalidates the others
            options = _strict_matches()

        if Config.get_bool('debugprint') and (fusions > 0 or arrays > 0):
            print('Automatically applied {} strict state fusions and removed'
                  ' {} redundant arrays.'.format(fusions, arrays))
Beispiel #7
0
    def apply(self, sdfg):
        """ Moves the array data used by the matched state to FPGA memory.

            Source and sink array access nodes of the state are replaced with
            ``fpga_``-prefixed transient FPGA_Global arrays. A ``pre_`` state
            that copies the original arrays into them is inserted before the
            state, and a ``post_`` state that copies them back is inserted
            after it. Non-array data (e.g. streams) is skipped.

            @param sdfg: The SDFG containing the matched state; it is
                         modified in place.
        """
        state = sdfg.nodes()[self.subgraph[FPGATransformState._state]]

        # Find source/sink (data) nodes
        input_nodes = nxutil.find_source_nodes(state)
        output_nodes = nxutil.find_sink_nodes(state)

        # Original data name -> FPGA array created for it. Always keyed by the
        # data name so the input pass, the output pass and the memlet renaming
        # below agree on the key.
        fpga_data = {}

        if input_nodes:

            pre_state = sd.SDFGState('pre_' + state.label, sdfg)

            for node in input_nodes:

                if (not isinstance(node, dace.graph.nodes.AccessNode)
                        or not isinstance(node.desc(sdfg), dace.data.Array)):
                    # Only transfer array nodes
                    # TODO: handle streams
                    continue

                array = node.desc(sdfg)
                if node.data in fpga_data:
                    fpga_array = fpga_data[node.data]
                else:
                    fpga_array = sdfg.add_array(
                        'fpga_' + node.data,
                        array.dtype,
                        array.shape,
                        materialize_func=array.materialize_func,
                        transient=True,
                        storage=types.StorageType.FPGA_Global,
                        allow_conflicts=array.allow_conflicts,
                        access_order=array.access_order,
                        strides=array.strides,
                        offset=array.offset)
                    fpga_data[node.data] = fpga_array
                fpga_node = type(node)(fpga_array)

                # Copy the whole original array to the device before the state
                pre_state.add_node(node)
                pre_state.add_node(fpga_node)
                full_range = subsets.Range([(0, s - 1, 1)
                                            for s in array.shape])
                mem = memlet.Memlet(array, full_range.num_elements(),
                                    full_range, 1)
                pre_state.add_edge(node, None, fpga_node, None, mem)

                state.add_node(fpga_node)
                nxutil.change_edge_src(state, node, fpga_node)
                state.remove_node(node)

            sdfg.add_node(pre_state)
            nxutil.change_edge_dest(sdfg, state, pre_state)
            sdfg.add_edge(pre_state, state, edges.InterstateEdge())

        if output_nodes:

            post_state = sd.SDFGState('post_' + state.label, sdfg)

            for node in output_nodes:

                if (not isinstance(node, dace.graph.nodes.AccessNode)
                        or not isinstance(node.desc(sdfg), dace.data.Array)):
                    # Only transfer array nodes
                    # TODO: handle streams
                    continue

                array = node.desc(sdfg)
                if node.data in fpga_data:
                    fpga_array = fpga_data[node.data]
                else:
                    fpga_array = sdfg.add_array(
                        'fpga_' + node.data,
                        array.dtype,
                        array.shape,
                        materialize_func=array.materialize_func,
                        transient=True,
                        storage=types.StorageType.FPGA_Global,
                        allow_conflicts=array.allow_conflicts,
                        access_order=array.access_order,
                        strides=array.strides,
                        offset=array.offset)
                    fpga_data[node.data] = fpga_array
                fpga_node = type(node)(fpga_array)

                # Copy the whole FPGA array back after the state
                post_state.add_node(node)
                post_state.add_node(fpga_node)
                full_range = subsets.Range([(0, s - 1, 1)
                                            for s in array.shape])
                mem = memlet.Memlet(fpga_array, full_range.num_elements(),
                                    full_range, 1)
                post_state.add_edge(fpga_node, None, node, None, mem)

                state.add_node(fpga_node)
                nxutil.change_edge_dest(state, node, fpga_node)
                state.remove_node(node)

            sdfg.add_node(post_state)
            nxutil.change_edge_src(sdfg, state, post_state)
            sdfg.add_edge(state, post_state, edges.InterstateEdge())

        # Redirect memlets in the state to the FPGA copies, renaming each
        # memlet after its own data (not after whichever node was processed
        # last in the loops above).
        for _, _, _, _, mem in state.edges():
            if mem.data is not None and mem.data in fpga_data:
                mem.data = 'fpga_' + mem.data

        fpga_update(state, 0)
Beispiel #8
0
    def apply(self, sdfg):
        """ Moves the array data used by the matched state to FPGA memory.

            Source and sink array access nodes of the state are replaced with
            ``fpga_``-prefixed transient FPGA_Global arrays. A ``pre_`` state
            copying the original arrays in and a ``post_`` state copying the
            results out are inserted around the state. Data written with WCR
            inside nested SDFGs is also added to the copy-in set. Memlets of
            the state are then renamed to the FPGA arrays, and vector lengths
            are taken from memlets inside connected nested SDFGs.

            @param sdfg: The SDFG containing the matched state; it is
                         modified in place.
        """
        state = sdfg.nodes()[self.subgraph[FPGATransformState._state]]

        # Find source/sink (data) nodes
        input_nodes = nxutil.find_source_nodes(state)
        output_nodes = nxutil.find_sink_nodes(state)

        fpga_data = {}

        # Input nodes may also be nodes with WCR memlets
        # We have to recur across nested SDFGs to find them
        wcr_input_nodes = set()

        for node, graph in state.all_nodes_recursive():
            if not isinstance(node, dace.graph.nodes.AccessNode):
                continue
            for e in graph.all_edges(node):
                if e.data.wcr is None:
                    continue
                # This is an output node with wcr: find the target in the
                # parent SDFG. Following the structure State->SDFG->State->SDFG,
                # from the current state we have to go two levels up.
                parent_state = graph.parent.parent
                if parent_state is None:
                    continue
                for parent_edge in parent_state.edges():
                    if parent_edge.src_conn == e.dst.data or (
                            isinstance(parent_edge.dst,
                                       dace.graph.nodes.AccessNode)
                            and e.dst.data == parent_edge.dst.data):
                        # This must be copied to device. Several WCR edges can
                        # resolve to the same parent node; record it only once
                        # so that a single copy-in transfer is generated.
                        if parent_edge.dst not in wcr_input_nodes:
                            input_nodes.append(parent_edge.dst)
                            wcr_input_nodes.add(parent_edge.dst)

        if input_nodes:
            # create pre_state
            pre_state = sd.SDFGState('pre_' + state.label, sdfg)

            for node in input_nodes:
                if (not isinstance(node, dace.graph.nodes.AccessNode)
                        or not isinstance(node.desc(sdfg), dace.data.Array)):
                    # Only transfer array nodes
                    # TODO: handle streams
                    continue

                array = node.desc(sdfg)
                if node.data in fpga_data:
                    fpga_array = fpga_data[node.data]
                elif node not in wcr_input_nodes:
                    fpga_array = sdfg.add_array(
                        'fpga_' + node.data,
                        array.shape,
                        array.dtype,
                        materialize_func=array.materialize_func,
                        transient=True,
                        storage=dtypes.StorageType.FPGA_Global,
                        allow_conflicts=array.allow_conflicts,
                        strides=array.strides,
                        offset=array.offset)
                    fpga_data[node.data] = fpga_array

                # Copy the whole array to the device before the state
                pre_node = pre_state.add_read(node.data)
                pre_fpga_node = pre_state.add_write('fpga_' + node.data)
                full_range = subsets.Range([(0, s - 1, 1)
                                            for s in array.shape])
                mem = memlet.Memlet(node.data, full_range.num_elements(),
                                    full_range, 1)
                pre_state.add_edge(pre_node, None, pre_fpga_node, None, mem)

                # WCR targets live in the parent state and stay where they are
                if node not in wcr_input_nodes:
                    fpga_node = state.add_read('fpga_' + node.data)
                    nxutil.change_edge_src(state, node, fpga_node)
                    state.remove_node(node)

            sdfg.add_node(pre_state)
            nxutil.change_edge_dest(sdfg, state, pre_state)
            sdfg.add_edge(pre_state, state, edges.InterstateEdge())

        if output_nodes:

            post_state = sd.SDFGState('post_' + state.label, sdfg)

            for node in output_nodes:
                if (not isinstance(node, dace.graph.nodes.AccessNode)
                        or not isinstance(node.desc(sdfg), dace.data.Array)):
                    # Only transfer array nodes
                    # TODO: handle streams
                    continue

                array = node.desc(sdfg)
                if node.data in fpga_data:
                    fpga_array = fpga_data[node.data]
                else:
                    fpga_array = sdfg.add_array(
                        'fpga_' + node.data,
                        array.shape,
                        array.dtype,
                        materialize_func=array.materialize_func,
                        transient=True,
                        storage=dtypes.StorageType.FPGA_Global,
                        allow_conflicts=array.allow_conflicts,
                        strides=array.strides,
                        offset=array.offset)
                    fpga_data[node.data] = fpga_array

                # Copy the whole FPGA array back after the state
                post_node = post_state.add_write(node.data)
                post_fpga_node = post_state.add_read('fpga_' + node.data)
                full_range = subsets.Range([(0, s - 1, 1)
                                            for s in array.shape])
                mem = memlet.Memlet('fpga_' + node.data,
                                    full_range.num_elements(), full_range, 1)
                post_state.add_edge(post_fpga_node, None, post_node, None, mem)

                fpga_node = state.add_write('fpga_' + node.data)
                nxutil.change_edge_dest(state, node, fpga_node)
                state.remove_node(node)

            sdfg.add_node(post_state)
            nxutil.change_edge_src(sdfg, state, post_state)
            sdfg.add_edge(state, post_state, edges.InterstateEdge())

        def _nested_veclen(nsdfg_node, conn, current):
            # Veclen of the first memlet attached to the last access node named
            # `conn` inside the nested SDFG (assuming all memlets of that node
            # share one vector length); `current` if no such memlet exists.
            veclen = current
            for inner_state in nsdfg_node.sdfg.states():
                for n in inner_state.nodes():
                    if (isinstance(n, dace.graph.nodes.AccessNode)
                            and n.data == conn):
                        n_edges = inner_state.all_edges(n)
                        if n_edges:  # isolated access nodes carry no veclen
                            veclen = n_edges[0].data.veclen
            return veclen

        # NOTE(review): veclen_ is not reset per edge, so a value found for one
        # nested SDFG edge is also applied to later edges (order-dependent).
        # Kept as-is; confirm whether this is intended.
        veclen_ = 1

        # propagate vector info from a nested sdfg
        for src, src_conn, dst, dst_conn, mem in state.edges():
            # need to go inside the nested SDFG and grab the vector length
            if isinstance(dst, dace.graph.nodes.NestedSDFG):
                # this edge is going to the nested SDFG
                veclen_ = _nested_veclen(dst, dst_conn, veclen_)
            if isinstance(src, dace.graph.nodes.NestedSDFG):
                # this edge is coming from the nested SDFG
                veclen_ = _nested_veclen(src, src_conn, veclen_)

            if mem.data is not None and mem.data in fpga_data:
                mem.data = 'fpga_' + mem.data
                mem.veclen = veclen_

        fpga_update(sdfg, state, 0)
Beispiel #9
0
    def apply(self, sdfg: sd.SDFG):
        """ Transforms ``sdfg`` in place so that it executes on a GPU.

            The steps performed are:
              0. Collect the non-transient data read/written in every state
                 (plus WCR targets without identity) and the code nodes that
                 are not inside any scope. Data feeding dynamic map ranges
                 stays on the host.
              1. Create a ``gpu_``-prefixed transient GPU_Global clone of each
                 collected descriptor (scalar inputs excluded) and redirect
                 access nodes and memlets to the clones.
              2. Prepend a copy-in state (skipping ``exclude_copyin``).
              3. Append a copy-out state after every sink state (skipping
                 ``exclude_copyout``).
              4. Move top-level transients to GPU_Global storage and, when
                 ``register_trans`` is set, scoped transients to registers.
              5. Wrap free code nodes (except ``exclude_tasklets``) in a
                 one-iteration GPU map.
              6. Schedule top-level maps and Reduce nodes on the GPU.
              7. Copy out cloned data used in interstate edge conditions.
              8. When ``strict_transform`` is set, apply strict
                 transformations.

            @param sdfg: The SDFG to transform; it is modified in place.
        """

        #######################################################
        # Step 0: SDFG metadata

        # Find all input and output data descriptors, stored as
        # (name, descriptor) tuples in first-seen order. The lists hold
        # tuples, so testing a bare name against them never matches; names
        # are tracked separately to keep the lists free of duplicates.
        input_nodes = []
        output_nodes = []
        input_names = set()
        output_names = set()
        global_code_nodes = [[] for _ in sdfg.nodes()]

        for i, state in enumerate(sdfg.nodes()):
            sdict = state.scope_dict()
            for node in state.nodes():
                if (isinstance(node, nodes.AccessNode)
                        and not node.desc(sdfg).transient):
                    if (state.out_degree(node) > 0
                            and node.data not in input_names):
                        # Special case: nodes that lead to dynamic map ranges
                        # must stay on host
                        for e in state.out_edges(node):
                            last_edge = state.memlet_path(e)[-1]
                            if (isinstance(last_edge.dst, nodes.EntryNode)
                                    and last_edge.dst_conn and
                                    not last_edge.dst_conn.startswith('IN_')):
                                break
                        else:
                            input_nodes.append((node.data, node.desc(sdfg)))
                            input_names.add(node.data)
                    if (state.in_degree(node) > 0
                            and node.data not in output_names):
                        output_nodes.append((node.data, node.desc(sdfg)))
                        output_names.add(node.data)
                elif isinstance(node, nodes.CodeNode) and sdict[node] is None:
                    if not isinstance(node, nodes.EmptyTasklet):
                        global_code_nodes[i].append(node)

            # Input nodes may also be nodes with WCR memlets and no identity
            for e in state.edges():
                if e.data.wcr is not None and e.data.wcr_identity is None:
                    wcr_name = e.data.data
                    if (wcr_name not in input_names
                            and not sdfg.arrays[wcr_name].transient):
                        input_nodes.append((wcr_name, sdfg.arrays[wcr_name]))
                        input_names.add(wcr_name)

        start_state = sdfg.start_state
        end_states = sdfg.sink_nodes()

        #######################################################
        # Step 1: Create cloned GPU arrays and replace originals
        # (iterating the deduplicated lists keeps clone creation order
        # deterministic, unlike iterating a set of tuples)

        cloned_arrays = {}
        for inodename, inode in input_nodes:
            if isinstance(inode, data.Scalar):  # Scalars can remain on host
                continue
            newdesc = inode.clone()
            newdesc.storage = dtypes.StorageType.GPU_Global
            newdesc.transient = True
            name = sdfg.add_datadesc('gpu_' + inodename,
                                     newdesc,
                                     find_new_name=True)
            cloned_arrays[inodename] = name

        for onodename, onode in output_nodes:
            if onodename in cloned_arrays:
                continue
            newdesc = onode.clone()
            newdesc.storage = dtypes.StorageType.GPU_Global
            newdesc.transient = True
            name = sdfg.add_datadesc('gpu_' + onodename,
                                     newdesc,
                                     find_new_name=True)
            cloned_arrays[onodename] = name

        # Replace nodes
        for state in sdfg.nodes():
            for node in state.nodes():
                if (isinstance(node, nodes.AccessNode)
                        and node.data in cloned_arrays):
                    node.data = cloned_arrays[node.data]

        # Replace memlets
        for state in sdfg.nodes():
            for edge in state.edges():
                if edge.data.data in cloned_arrays:
                    edge.data.data = cloned_arrays[edge.data.data]

        #######################################################
        # Step 2: Create copy-in state
        excluded_copyin = self.exclude_copyin.split(',')

        copyin_state = sdfg.add_state(sdfg.label + '_copyin')
        sdfg.add_edge(copyin_state, start_state, ed.InterstateEdge())

        for nname, desc in input_nodes:
            if nname in excluded_copyin or nname not in cloned_arrays:
                continue
            src_array = nodes.AccessNode(nname, debuginfo=desc.debuginfo)
            dst_array = nodes.AccessNode(cloned_arrays[nname],
                                         debuginfo=desc.debuginfo)
            copyin_state.add_node(src_array)
            copyin_state.add_node(dst_array)
            copyin_state.add_nedge(
                src_array, dst_array,
                memlet.Memlet.from_array(src_array.data, src_array.desc(sdfg)))

        #######################################################
        # Step 3: Create copy-out state
        excluded_copyout = self.exclude_copyout.split(',')

        copyout_state = sdfg.add_state(sdfg.label + '_copyout')
        for state in end_states:
            sdfg.add_edge(state, copyout_state, ed.InterstateEdge())

        for nname, desc in output_nodes:
            if nname in excluded_copyout or nname not in cloned_arrays:
                continue
            src_array = nodes.AccessNode(cloned_arrays[nname],
                                         debuginfo=desc.debuginfo)
            dst_array = nodes.AccessNode(nname, debuginfo=desc.debuginfo)
            copyout_state.add_node(src_array)
            copyout_state.add_node(dst_array)
            copyout_state.add_nedge(
                src_array, dst_array,
                memlet.Memlet.from_array(dst_array.data, dst_array.desc(sdfg)))

        #######################################################
        # Step 4: Modify transient data storage

        for state in sdfg.nodes():
            sdict = state.scope_dict()
            for node in state.nodes():
                if isinstance(node,
                              nodes.AccessNode) and node.desc(sdfg).transient:
                    nodedesc = node.desc(sdfg)

                    # Special case: nodes that lead to dynamic map ranges must
                    # stay on host
                    if any(
                            isinstance(
                                state.memlet_path(e)[-1].dst, nodes.EntryNode)
                            for e in state.out_edges(node)):
                        continue

                    if sdict[node] is None:
                        # NOTE: the cloned arrays match too but it's the same
                        # storage so we don't care
                        nodedesc.storage = dtypes.StorageType.GPU_Global

                        # Try to move allocation/deallocation out of loops
                        if (self.toplevel_trans
                                and not isinstance(nodedesc, data.Stream)):
                            nodedesc.toplevel = True
                    else:
                        # Make internal transients registers
                        if self.register_trans:
                            nodedesc.storage = dtypes.StorageType.Register

        #######################################################
        # Step 5: Wrap free tasklets and nested SDFGs with a GPU map
        excluded_tasklets = self.exclude_tasklets.split(',')

        for state, gcodes in zip(sdfg.nodes(), global_code_nodes):
            for gcode in gcodes:
                if gcode.label in excluded_tasklets:
                    continue
                # Create map and connectors
                me, mx = state.add_map(gcode.label + '_gmap',
                                       {gcode.label + '__gmapi': '0:1'},
                                       schedule=dtypes.ScheduleType.GPU_Device)
                # Store in/out edges in lists so that they don't get corrupted
                # when they are removed from the graph
                in_edges = list(state.in_edges(gcode))
                out_edges = list(state.out_edges(gcode))
                me.in_connectors = set('IN_' + e.dst_conn for e in in_edges)
                me.out_connectors = set('OUT_' + e.dst_conn for e in in_edges)
                mx.in_connectors = set('IN_' + e.src_conn for e in out_edges)
                mx.out_connectors = set('OUT_' + e.src_conn for e in out_edges)

                # Create memlets through map
                for e in in_edges:
                    state.remove_edge(e)
                    state.add_edge(e.src, e.src_conn, me, 'IN_' + e.dst_conn,
                                   e.data)
                    state.add_edge(me, 'OUT_' + e.dst_conn, e.dst, e.dst_conn,
                                   e.data)
                for e in out_edges:
                    state.remove_edge(e)
                    state.add_edge(e.src, e.src_conn, mx, 'IN_' + e.src_conn,
                                   e.data)
                    state.add_edge(mx, 'OUT_' + e.src_conn, e.dst, e.dst_conn,
                                   e.data)

                # Map without inputs: keep the code node inside the map scope
                if len(in_edges) == 0:
                    state.add_nedge(me, gcode, memlet.EmptyMemlet())
        #######################################################
        # Step 6: Change all top-level maps and Reduce nodes to GPU schedule

        for state in sdfg.nodes():
            sdict = state.scope_dict()
            for node in state.nodes():
                if isinstance(node, (nodes.EntryNode, nodes.Reduce)):
                    if sdict[node] is None:
                        node.schedule = dtypes.ScheduleType.GPU_Device
                    elif (isinstance(node, nodes.EntryNode)
                          and self.sequential_innermaps):
                        node.schedule = dtypes.ScheduleType.Sequential

        #######################################################
        # Step 7: Introduce copy-out if data used in outgoing interstate edges

        for state in list(sdfg.nodes()):
            arrays_used = set()
            for e in sdfg.out_edges(state):
                # Used arrays = intersection between symbols and cloned arrays
                arrays_used.update(
                    set(e.data.condition_symbols())
                    & set(cloned_arrays.keys()))

            # Create a state and copy out used arrays
            if len(arrays_used) > 0:
                co_state = sdfg.add_state(state.label + '_icopyout')

                # Reconnect outgoing edges to after interim copyout state
                # (change_edge_src moves all outgoing edges at once)
                nxutil.change_edge_src(sdfg, state, co_state)
                # Add unconditional edge to interim state
                sdfg.add_edge(state, co_state, ed.InterstateEdge())

                # Add copy-out nodes
                for nname in arrays_used:
                    desc = sdfg.arrays[nname]
                    src_array = nodes.AccessNode(cloned_arrays[nname],
                                                 debuginfo=desc.debuginfo)
                    dst_array = nodes.AccessNode(nname,
                                                 debuginfo=desc.debuginfo)
                    co_state.add_node(src_array)
                    co_state.add_node(dst_array)
                    co_state.add_nedge(
                        src_array, dst_array,
                        memlet.Memlet.from_array(dst_array.data,
                                                 dst_array.desc(sdfg)))

        #######################################################
        # Step 8: Strict transformations
        if not self.strict_transform:
            return

        # Apply strict state fusions greedily.
        sdfg.apply_strict_transformations()
Beispiel #10
0
    def apply(self, sdfg):
        """ Fully unrolls the matched loop into a chain of state copies.

            For every value of the iteration variable, the loop-body states
            are copied with the variable replaced by the concrete value. The
            copies are chained with unconditional edges between the state
            before the loop and the loop exit state, and the guard and the
            original loop states are removed.

            @param sdfg: The SDFG containing the loop; modified in place.
            @raise ValueError: If ``count`` is 0 and the range is symbolic.
            @raise NotImplementedError: If ``count`` is not 0 (partial
                                        unrolling is not implemented).
        """
        # Obtain loop information
        guard: sd.SDFGState = sdfg.node(self.subgraph[DetectLoop._loop_guard])
        begin: sd.SDFGState = sdfg.node(self.subgraph[DetectLoop._loop_begin])
        after_state: sd.SDFGState = sdfg.node(
            self.subgraph[DetectLoop._exit_state])

        # Obtain iteration variable, range, and stride
        guard_inedges = sdfg.in_edges(guard)
        condition_edge = sdfg.edges_between(guard, begin)[0]
        itervar = list(guard_inedges[0].data.assignments.keys())[0]
        condition = condition_edge.data.condition_sympy()
        rng = LoopUnroll._loop_range(itervar, guard_inedges, condition)

        # Loop must be unrollable
        if self.count == 0 and any(
                symbolic.issymbolic(r, sdfg.constants) for r in rng):
            raise ValueError('Loop cannot be fully unrolled, size is symbolic')
        if self.count != 0:
            raise NotImplementedError  # TODO(later)

        # Find the state prior to the loop: its edge assigns the initial value
        if str(rng[0]) == guard_inedges[0].data.assignments[itervar]:
            before_state: sd.SDFGState = guard_inedges[0].src
            last_state: sd.SDFGState = guard_inedges[1].src
        else:
            before_state: sd.SDFGState = guard_inedges[1].src
            last_state: sd.SDFGState = guard_inedges[0].src

        # Get loop states
        loop_states = list(
            nxutil.dfs_topological_sort(
                sdfg,
                sources=[begin],
                condition=lambda _, child: child != guard))
        first_id = loop_states.index(begin)
        last_id = loop_states.index(last_state)
        loop_subgraph = gr.SubgraphView(sdfg, loop_states)

        # Evaluate the real values of the loop
        start, end, stride = (symbolic.evaluate(r, sdfg.constants)
                              for r in rng)

        # The range end is inclusive, so extend the exclusive `range` stop by
        # one step in the direction of iteration (decreasing loops have a
        # negative stride and need `end - 1`).
        stop = end + 1 if stride > 0 else end - 1

        # Create states for loop subgraph
        unrolled_states = []
        for i in range(start, stop, stride):
            # Using to/from JSON copies faster than deepcopy (which will also
            # copy the parent SDFG)
            new_states = [
                sd.SDFGState.from_json(s.to_json(), context={'sdfg': sdfg})
                for s in loop_states
            ]

            # Replace iterate with value in each state
            for state in new_states:
                state.set_label(state.label + '_%s_%d' % (itervar, i))
                state.replace(itervar, str(i))

            # Add subgraph to original SDFG
            for edge in loop_subgraph.edges():
                src = new_states[loop_states.index(edge.src)]
                dst = new_states[loop_states.index(edge.dst)]

                # Replace conditions in subgraph edges
                data: edges.InterstateEdge = copy.deepcopy(edge.data)
                if data.condition:
                    ASTFindReplace({itervar: str(i)}).visit(data.condition)

                sdfg.add_edge(src, dst, data)

            # Connect iterations with unconditional edges
            if len(unrolled_states) > 0:
                sdfg.add_edge(unrolled_states[-1][1], new_states[first_id],
                              edges.InterstateEdge())

            unrolled_states.append((new_states[first_id], new_states[last_id]))

        # Connect new states to before and after states without conditions
        if unrolled_states:
            sdfg.add_edge(before_state, unrolled_states[0][0],
                          edges.InterstateEdge())
            sdfg.add_edge(unrolled_states[-1][1], after_state,
                          edges.InterstateEdge())
        else:
            # The loop body never runs: bypass it so the exit state remains
            # reachable once the guard and loop states are removed.
            sdfg.add_edge(before_state, after_state, edges.InterstateEdge())

        # Remove old states from SDFG
        sdfg.remove_nodes_from([guard] + loop_states)
Beispiel #11
0
class StateFusion(pattern_matching.Transformation):
    """ Implements the state-fusion transformation.

        State-fusion takes two states that are connected through a single edge,
        and fuses them into one state. If strict, only applies if no memory
        access hazards are created.
    """

    _first_state = sdfg.SDFGState()
    _edge = edges.InterstateEdge()
    _second_state = sdfg.SDFGState()

    @staticmethod
    def annotates_memlets():
        """ Returns False: this transformation reports that it does not
            annotate memlets.
        """
        return False

    @staticmethod
    def expressions():
        """ Returns the list of patterns this transformation matches: a single
            graph built with ``nxutil.node_path_graph`` from the two pattern
            states (``_first_state`` followed by ``_second_state``).
        """
        return [
            nxutil.node_path_graph(StateFusion._first_state,
                                   StateFusion._second_state)
        ]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Checks whether the matched pair of states can be fused.

            @param graph: The graph of states the candidate was matched in.
            @param candidate: Mapping from the pattern states to indices into
                              ``graph.nodes()``.
            @param expr_index: Index of the matched expression (not used).
            @param sdfg: The SDFG being transformed (not used).
            @param strict: If True, also reject fusions where the second state
                           writes data that the first state reads or writes
                           (unless every connected component of the first
                           state feeds the second state).
            @return: True if the two states can be fused, False otherwise.
        """
        first_state = graph.nodes()[candidate[StateFusion._first_state]]
        second_state = graph.nodes()[candidate[StateFusion._second_state]]

        out_edges = graph.out_edges(first_state)
        in_edges = graph.in_edges(first_state)

        # First state must have only one output edge (with dst the second
        # state).
        if len(out_edges) != 1:
            return False
        # The interstate edge must not have a condition.
        if out_edges[0].data.condition.as_string != '':
            return False
        # The interstate edge may have assignments, as long as there are input
        # edges to the first state that can absorb them.
        if out_edges[0].data.assignments and not in_edges:
            return False
        # No state may have output edges pointing to both the first and the
        # second state; fusing them would produce a multi-graph.
        for src, _, _ in in_edges:
            if any(dst == second_state for _, dst, _ in graph.out_edges(src)):
                return False

        if strict:
            # If second state has other input edges, there might be issues
            second_in_edges = graph.in_edges(second_state)
            if ((not second_state.is_empty() or not first_state.is_empty())
                    and len(second_in_edges) != 1):
                return False

            # Weakly connected components of the first state.
            first_cc = list(nx.weakly_connected_components(first_state._nx))

            # Source access nodes are treated as inputs; every other access
            # node in the state is treated as an output.
            first_input = {
                node
                for node in nxutil.find_source_nodes(first_state)
                if isinstance(node, nodes.AccessNode)
            }
            first_output = {
                node
                for node in first_state.nodes() if
                isinstance(node, nodes.AccessNode) and node not in first_input
            }
            second_input = {
                node
                for node in nxutil.find_source_nodes(second_state)
                if isinstance(node, nodes.AccessNode)
            }
            # Data is compared by label, so only the labels of the second
            # state's access nodes are needed.
            second_input_labels = {node.label for node in second_input}
            second_output_labels = {
                node.label
                for node in second_state.nodes() if
                isinstance(node, nodes.AccessNode) and node not in second_input
            }

            # Number of first-state components none of whose outputs are read
            # (by label) by the second state.
            check_strict = sum(
                1 for cc in first_cc
                if not any(node.label in second_input_labels
                           for node in cc.intersection(first_output)))

            if check_strict > 0:
                # Check strict conditions
                # RW dependency: the second state writes data that the first
                # state reads.
                if any(node.label in second_output_labels
                       for node in first_input):
                    return False
                # WW dependency: both states write the same data.
                if any(node.label in second_output_labels
                       for node in first_output):
                    return False

        return True

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns a readable description of a match: the labels of the two
            matched states joined by ' -> '.
        """
        first_state = graph.nodes()[candidate[StateFusion._first_state]]
        second_state = graph.nodes()[candidate[StateFusion._second_state]]

        return ' -> '.join(state.label
                           for state in [first_state, second_state])

    def apply(self, sdfg):
        """ Fuses the matched second state into the first one, in place.

            The edges between the two states are removed. Any assignments they
            carry are copied onto every incoming edge of the first state. If
            either state is empty, it is removed and its edges are rewired.
            Otherwise the second state's nodes and edges are moved into the
            first state, and boundary access nodes that refer to the same data
            (by label) are merged.

            @param sdfg: The SDFG to transform.
        """
        first_state = sdfg.nodes()[self.subgraph[StateFusion._first_state]]
        second_state = sdfg.nodes()[self.subgraph[StateFusion._second_state]]

        # Remove interstate edge(s). The local name avoids shadowing the
        # ``edges`` module used by the class body.
        fused_edges = sdfg.edges_between(first_state, second_state)
        for edge in fused_edges:
            if edge.data.assignments:
                for _, _, other_data in sdfg.in_edges(first_state):
                    other_data.assignments.update(edge.data.assignments)
            sdfg.remove_edge(edge)

        # Special case 1: first state is empty
        if first_state.is_empty():
            nxutil.change_edge_dest(sdfg, first_state, second_state)
            sdfg.remove_node(first_state)
            return

        # Special case 2: second state is empty
        if second_state.is_empty():
            nxutil.change_edge_src(sdfg, second_state, first_state)
            nxutil.change_edge_dest(sdfg, second_state, first_state)
            sdfg.remove_node(second_state)
            return

        # Normal case: both states are not empty

        # Find source/sink (data) nodes
        first_input = [
            node for node in nxutil.find_source_nodes(first_state)
            if isinstance(node, nodes.AccessNode)
        ]
        first_output = [
            node for node in nxutil.find_sink_nodes(first_state)
            if isinstance(node, nodes.AccessNode)
        ]
        second_input = [
            node for node in nxutil.find_source_nodes(second_state)
            if isinstance(node, nodes.AccessNode)
        ]

        # first input = first input - first output (compared by label)
        first_output_labels = {node.label for node in first_output}
        first_input = [
            node for node in first_input
            if node.label not in first_output_labels
        ]

        # Lookup tables over the second state's source access nodes, built
        # before any node is removed.
        # - For first-state outputs, the first second-state input with the
        #   same label is the merge target (several outputs may share it).
        # - For first-state inputs, each second-state input may be merged at
        #   most once. Merging removes it from the state, so matching it again
        #   (e.g. when two first-state sources share a label) would try to
        #   remove a node that is no longer in the graph.
        second_input_by_label = {}
        unmerged_second_inputs = {}
        for node in second_input:
            second_input_by_label.setdefault(node.label, node)
            unmerged_second_inputs.setdefault(node.label, []).append(node)

        # Merge second state to first state
        for node in second_state.nodes():
            first_state.add_node(node)
        for src, src_conn, dst, dst_conn, data in second_state.edges():
            first_state.add_edge(src, src_conn, dst, dst_conn, data)

        # Merge common (data) nodes: the second state's reads of data that is
        # also a first-state input are re-sourced from the first state's node.
        for node in first_input:
            candidates = unmerged_second_inputs.get(node.label)
            if not candidates:
                continue
            old_node = candidates.pop(0)
            nxutil.change_edge_src(first_state, old_node, node)
            first_state.remove_node(old_node)
        # Writes of the first state that the second state reads are redirected
        # into the second state's access node.
        for node in first_output:
            new_node = second_input_by_label.get(node.label)
            if new_node is None:
                continue
            nxutil.change_edge_dest(first_state, node, new_node)
            first_state.remove_node(node)

        # Redirect edges and remove second state
        nxutil.change_edge_src(sdfg, second_state, first_state)
        sdfg.remove_node(second_state)