Exemplo n.º 1
0
 def expressions():
     """
     Builds the pattern graph for this transformation: one state in which
     the map entry has an empty-memlet edge to each of the two matched nodes.

     :return: A single-element list holding the pattern state.
     """
     pattern = sd.SDFGState()
     # Edges are added in the same order as the pattern nodes are declared
     for target in (DeduplicateAccess._node1, DeduplicateAccess._node2):
         pattern.add_nedge(DeduplicateAccess._map_entry, target, Memlet())
     return [pattern]
Exemplo n.º 2
0
class StartStateElimination(transformation.Transformation):
    """
    Start-state elimination drops a redundant, content-free state that has no
    incoming edges and exactly one unconditional outgoing edge. It is only
    applicable inside nested SDFGs; assignments on the removed edge are moved
    into the symbol mapping of the enclosing nested SDFG node.
    """

    start_state = sdfg.SDFGState()

    @staticmethod
    def expressions():
        """ Returns the single-state pattern to match. """
        return [sdutil.node_path_graph(StartStateElimination.start_state)]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """
        Checks whether the matched state is a removable start state.

        :return: True only if the graph has a parent, the state has no
                 incoming edges, exactly one unconditional outgoing edge,
                 and contains no nodes.
        """
        state = graph.nodes()[candidate[StartStateElimination.start_state]]

        # Top-level SDFGs (no parent) are never transformed
        if not graph.parent:
            return False

        successors = graph.out_edges(state)
        predecessors = graph.in_edges(state)

        # A start state has nothing leading into it, and we require a single
        # way out of it
        if len(predecessors) != 0 or len(successors) != 1:
            return False

        # The sole outgoing edge must be unconditional and the state empty
        sole_edge = successors[0]
        removable = (sole_edge.data.is_unconditional()
                     and not (state.number_of_nodes() > 0))
        return bool(removable)

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns the label of the matched state. """
        matched = graph.nodes()[candidate[StartStateElimination.start_state]]
        return matched.label

    def apply(self, sdfg):
        """
        Removes the matched state after copying the assignments of its
        outgoing edge into the parent nested SDFG node's symbol mapping.
        """
        state = sdfg.nodes()[self.subgraph[StartStateElimination.start_state]]
        nsdfg_node = sdfg.parent_nsdfg_node
        out_edge = sdfg.out_edges(state)[0]
        # Preserve the edge's assignments as symbol mappings on the parent
        for symbol, value in out_edge.data.assignments.items():
            nsdfg_node.symbol_mapping[symbol] = value
        sdfg.remove_node(state)
Exemplo n.º 3
0
class EndStateElimination(transformation.Transformation):
    """
    End-state elimination drops a redundant, content-free state that has no
    outgoing edges and exactly one unconditional incoming edge.
    """

    _end_state = sdfg.SDFGState()

    @staticmethod
    def expressions():
        """ Returns the single-state pattern to match. """
        return [sdutil.node_path_graph(EndStateElimination._end_state)]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """
        Checks whether the matched state is a removable end state.

        :return: True only if the state has no outgoing edges, exactly one
                 unconditional incoming edge, and contains no nodes.
        """
        state = graph.nodes()[candidate[EndStateElimination._end_state]]

        successors = graph.out_edges(state)
        predecessors = graph.in_edges(state)

        # An end state has nothing after it, and we require a single way in
        if len(successors) != 0 or len(predecessors) != 1:
            return False

        # The sole incoming edge must be unconditional and the state empty
        sole_edge = predecessors[0]
        removable = (sole_edge.data.is_unconditional()
                     and not (state.number_of_nodes() > 0))
        return bool(removable)

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns the label of the matched state. """
        matched = graph.nodes()[candidate[EndStateElimination._end_state]]
        return matched.label

    def apply(self, sdfg):
        """
        Removes the matched state, then removes every symbol assigned on its
        incoming edge that is reported in ``sdfg.free_symbols`` afterwards.
        """
        state = sdfg.nodes()[self.subgraph[EndStateElimination._end_state]]
        # Keep a handle on the incoming edge's assignments before the state
        # (and with it the edge) is deleted
        incoming = sdfg.in_edges(state)[0]
        assigned = incoming.data.assignments
        sdfg.remove_node(state)
        # Handle orphan symbols (due to the deletion of the incoming edge)
        for name in assigned:
            if name in sdfg.free_symbols:
                sdfg.remove_symbol(name)
Exemplo n.º 4
0
class StateAssignElimination(transformation.Transformation):
    """
    State assign elimination removes all assignments on the edge leading into
    a final state and substitutes the assigned values into that state's
    contents instead.
    """

    _end_state = sdfg.SDFGState()

    @staticmethod
    def expressions():
        """ Returns the single-state pattern to match. """
        return [sdutil.node_path_graph(StateAssignElimination._end_state)]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """
        Checks whether the matched state is an end state whose only incoming
        edge carries at least one assignment.
        """
        state = graph.nodes()[candidate[StateAssignElimination._end_state]]

        successors = graph.out_edges(state)
        predecessors = graph.in_edges(state)

        # An end state has nothing after it, and we require a single way in
        if len(successors) != 0 or len(predecessors) != 1:
            return False

        # There has to be something to eliminate
        return len(predecessors[0].data.assignments) != 0

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns the label of the matched state. """
        matched = graph.nodes()[candidate[StateAssignElimination._end_state]]
        return matched.label

    def apply(self, sdfg):
        """
        Replaces every assigned name inside the matched state with its
        assigned value and clears the assignments of the incoming edge.
        """
        state = sdfg.nodes()[self.subgraph[StateAssignElimination._end_state]]
        incoming = sdfg.in_edges(state)[0]
        # Inter-state assignments that read another value assigned on the
        # same edge (e.g., {m: n, n: m}) are undefined behavior, so each
        # name can safely be substituted on its own.
        for name, value in incoming.data.assignments.items():
            state.replace(name, value)
        # The edge no longer needs to assign anything
        incoming.data.assignments = {}
Exemplo n.º 5
0
class StateFusion(pattern_matching.Transformation):
    """ Implements the state-fusion transformation.

        State-fusion takes two states that are connected through a single edge,
        and fuses them into one state. If strict, only applies if no memory
        access hazards are created.
    """

    # Number of fusions performed; only incremented in `apply` when the
    # "debugprint" configuration entry is enabled.
    _states_fused = 0
    # Pattern nodes used to look up the matched states in a candidate.
    _first_state = sdfg.SDFGState()
    # NOTE(review): `_edge` is not referenced anywhere else in this class.
    _edge = sdfg.InterstateEdge()
    _second_state = sdfg.SDFGState()

    @staticmethod
    def annotates_memlets():
        """ This transformation reports that it does not annotate memlets. """
        return False

    @staticmethod
    def expressions():
        """ Returns the pattern to match: a path of two connected states. """
        return [
            sdutil.node_path_graph(StateFusion._first_state,
                                   StateFusion._second_state)
        ]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """
        Decides whether the two matched states may be fused.

        :param graph: The state machine in which the match was found.
        :param candidate: Mapping from pattern nodes to node indices in
                          `graph`.
        :param expr_index: Index of the matched pattern expression (unused).
        :param sdfg: The SDFG being transformed (unused in this method).
        :param strict: If True, additionally run the data-dependency
                       (hazard) checks between the two states.
        :return: True if the states can be fused, False otherwise.
        """
        first_state = graph.nodes()[candidate[StateFusion._first_state]]
        second_state = graph.nodes()[candidate[StateFusion._second_state]]

        out_edges = graph.out_edges(first_state)
        in_edges = graph.in_edges(first_state)

        # First state must have only one output edge (with dst the second
        # state).
        if len(out_edges) != 1:
            return False
        # The interstate edge must not have a condition.
        if not out_edges[0].data.is_unconditional():
            return False
        # The interstate edge may have assignments, as long as there are input
        # edges to the first state that can absorb them.
        if out_edges[0].data.assignments:
            if not in_edges:
                return False
            # Fail if symbol is set before the state to fuse
            # TODO: Also fail if symbol is used in the dataflow of that state
            new_assignments = set(out_edges[0].data.assignments.keys())
            if any((new_assignments & set(e.data.assignments.keys()))
                   for e in in_edges):
                return False

        # There can be no state that have output edges pointing to both the
        # first and the second state. Such a case will produce a multi-graph.
        for src, _, _ in in_edges:
            for _, dst, _ in graph.out_edges(src):
                if dst == second_state:
                    return False

        if strict:
            # If second state has other input edges, there might be issues
            # Exceptions are when none of the states contain dataflow, unless
            # the first state is an initial state (in which case the new initial
            # state would be ambiguous).
            # NOTE: `first_in_edges` is the same query as `in_edges` above.
            first_in_edges = graph.in_edges(first_state)
            second_in_edges = graph.in_edges(second_state)
            if ((not second_state.is_empty() or not first_state.is_empty()
                 or len(first_in_edges) == 0) and len(second_in_edges) != 1):
                return False

            # Get connected components.
            first_cc = [
                cc_nodes
                for cc_nodes in nx.weakly_connected_components(first_state._nx)
            ]
            second_cc = [
                cc_nodes for cc_nodes in nx.weakly_connected_components(
                    second_state._nx)
            ]

            # Find source/sink (data) nodes
            # Inputs are the source access nodes of a state; "outputs" are
            # all remaining access nodes of that state (not only its sinks).
            first_input = {
                node
                for node in sdutil.find_source_nodes(first_state)
                if isinstance(node, nodes.AccessNode)
            }
            first_output = {
                node
                for node in first_state.nodes() if
                isinstance(node, nodes.AccessNode) and node not in first_input
            }
            second_input = {
                node
                for node in sdutil.find_source_nodes(second_state)
                if isinstance(node, nodes.AccessNode)
            }
            second_output = {
                node
                for node in second_state.nodes() if
                isinstance(node, nodes.AccessNode) and node not in second_input
            }

            # Find source/sink (data) nodes by connected component
            # NOTE: only `first_cc_output` is read below; the other three
            # per-component lists are computed but not used in this method.
            first_cc_input = [cc.intersection(first_input) for cc in first_cc]
            first_cc_output = [
                cc.intersection(first_output) for cc in first_cc
            ]
            second_cc_input = [
                cc.intersection(second_input) for cc in second_cc
            ]
            second_cc_output = [
                cc.intersection(second_output) for cc in second_cc
            ]

            # Apply transformation in case all paths to the second state's
            # nodes go through the same access node, which implies sequential
            # behavior in SDFG semantics.
            # `check_strict` starts at the number of connected components of
            # the first state and is decremented once for every component
            # whose single sink output has a same-labeled input in the second
            # state. The hazard checks below run only if it stays positive.
            check_strict = len(first_cc)
            for cc_output in first_cc_output:
                out_nodes = [
                    n for n in first_state.sink_nodes() if n in cc_output
                ]
                # Branching exists, multiple paths may involve same access node
                # potentially causing data races
                if len(out_nodes) > 1:
                    continue

                # Otherwise, check if any of the second state's connected
                # components for matching input
                for node in out_nodes:
                    if (next(
                        (x for x in second_input if x.label == node.label),
                            None) is not None):
                        check_strict -= 1
                        break

            if check_strict > 0:
                # Check strict conditions
                # RW dependency
                # (data read as input in the first state has a same-labeled
                # output node in the second state)
                for node in first_input:
                    if (next(
                        (x for x in second_output if x.label == node.label),
                            None) is not None):
                        return False
                # WW dependency
                # (both states have a same-labeled output node)
                for node in first_output:
                    if (next(
                        (x for x in second_output if x.label == node.label),
                            None) is not None):
                        return False

        return True

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns the labels of both matched states joined by " -> ". """
        first_state = graph.nodes()[candidate[StateFusion._first_state]]
        second_state = graph.nodes()[candidate[StateFusion._second_state]]

        return " -> ".join(state.label
                           for state in [first_state, second_state])

    def apply(self, sdfg):
        """
        Fuses the second matched state into the first one.

        Assignments on the connecting edge(s) are copied onto every incoming
        edge of the first state before the connecting edge(s) are removed.
        If either state is empty, it is simply removed and its edges are
        redirected; otherwise the contents of the second state are merged
        into the first.

        :param sdfg: The SDFG to modify in place.
        """
        first_state = sdfg.nodes()[self.subgraph[StateFusion._first_state]]
        second_state = sdfg.nodes()[self.subgraph[StateFusion._second_state]]

        # Remove interstate edge(s)
        edges = sdfg.edges_between(first_state, second_state)
        for edge in edges:
            if edge.data.assignments:
                # Push the assignments up to all edges entering the first
                # state, so they still take place before the fused state
                for src, dst, other_data in sdfg.in_edges(first_state):
                    other_data.assignments.update(edge.data.assignments)
            sdfg.remove_edge(edge)

        # Special case 1: first state is empty
        if first_state.is_empty():
            sdutil.change_edge_dest(sdfg, first_state, second_state)
            sdfg.remove_node(first_state)
            return

        # Special case 2: second state is empty
        if second_state.is_empty():
            sdutil.change_edge_src(sdfg, second_state, first_state)
            sdutil.change_edge_dest(sdfg, second_state, first_state)
            sdfg.remove_node(second_state)
            return

        # Normal case: both states are not empty

        # Find source/sink (data) nodes
        first_input = [
            node for node in sdutil.find_source_nodes(first_state)
            if isinstance(node, nodes.AccessNode)
        ]
        first_output = [
            node for node in sdutil.find_sink_nodes(first_state)
            if isinstance(node, nodes.AccessNode)
        ]
        second_input = [
            node for node in sdutil.find_source_nodes(second_state)
            if isinstance(node, nodes.AccessNode)
        ]

        # first input = first input - first output
        # NOTE: the filtered `first_input` is not read again in this method.
        first_input = [
            node for node in first_input
            if next((x for x in first_output
                     if x.label == node.label), None) is None
        ]

        # Merge second state to first state
        # First keep a backup of the topological sorted order of the nodes
        # (access nodes only, in reverse topological order, so that lookups
        # below find the latest access node with a given label first)
        order = [
            x for x in reversed(list(nx.topological_sort(first_state._nx)))
            if isinstance(x, nodes.AccessNode)
        ]
        for node in second_state.nodes():
            first_state.add_node(node)
        for src, src_conn, dst, dst_conn, data in second_state.edges():
            first_state.add_edge(src, src_conn, dst, dst_conn, data)

        # Merge common (data) nodes
        for node in second_input:
            if first_state.in_degree(node) == 0:
                n = next((x for x in order if x.label == node.label), None)
                if n:
                    # Reconnect the second state's readers to the first
                    # state's access node and drop the duplicate; the
                    # surviving node is then marked as read-write
                    sdutil.change_edge_src(first_state, node, n)
                    first_state.remove_node(node)
                    n.access = dtypes.AccessType.ReadWrite

        # Redirect edges and remove second state
        sdutil.change_edge_src(sdfg, second_state, first_state)
        sdfg.remove_node(second_state)
        if Config.get_bool("debugprint"):
            StateFusion._states_fused += 1
Exemplo n.º 6
0
class DetectLoop(pattern_matching.Transformation):
    """ Detects a for-loop construct (guard, loop body, exit) in an SDFG. """

    _loop_guard = sd.SDFGState()
    _loop_begin = sd.SDFGState()
    _exit_state = sd.SDFGState()

    @staticmethod
    def expressions():
        """
        Returns the two patterns to match: a loop whose body is one state
        (with a back-edge to the guard), and a loop with a multi-state body
        (no direct back-edge from the first body state).
        """
        guard = DetectLoop._loop_guard
        body = DetectLoop._loop_begin
        after = DetectLoop._exit_state

        # Case 1: Loop with one state
        single = sd.SDFG('_')
        single.add_nodes_from([guard, body, after])
        single.add_edge(guard, body, edges.InterstateEdge())
        single.add_edge(guard, after, edges.InterstateEdge())
        single.add_edge(body, guard, edges.InterstateEdge())

        # Case 2: Loop with multiple states (no back-edge from state)
        multi = sd.SDFG('_')
        multi.add_nodes_from([guard, body, after])
        multi.add_edge(guard, body, edges.InterstateEdge())
        multi.add_edge(guard, after, edges.InterstateEdge())

        return [single, multi]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """
        Checks whether the matched guard/begin states form a for-loop.

        :return: True if the guard has two incoming edges that each assign
                 the same single variable, two assignment-free outgoing
                 edges with negated conditions, all states reachable from
                 the loop body are dominated by the guard, and at least one
                 of them has an edge back to the guard.
        """
        guard = graph.node(candidate[DetectLoop._loop_guard])
        begin = graph.node(candidate[DetectLoop._loop_begin])

        # A for-loop guard is entered from exactly two edges (init and
        # increment) ...
        incoming = graph.in_edges(guard)
        if len(incoming) != 2:
            return False
        # ... and left through exactly two edges (loop and exit-loop)
        outgoing = graph.out_edges(guard)
        if len(outgoing) != 2:
            return False

        # Each incoming edge sets exactly one variable, and it has to be the
        # same one on both
        if any(len(e.data.assignments) != 1 for e in incoming):
            return False
        itervar = next(iter(incoming[0].data.assignments.keys()))
        if itervar not in incoming[1].data.assignments:
            return False

        # Outgoing edges carry no assignments and their conditions are the
        # negation of each other
        if any(e.data.assignments for e in outgoing):
            return False
        first_cond, second_cond = [e.data.condition_sympy() for e in outgoing]
        if first_cond != sp.Not(second_cond):
            return False

        # All nodes inside loop must be dominated by loop guard
        dominators = nx.dominance.immediate_dominators(sdfg.nx,
                                                       sdfg.start_state)
        loop_states = nxutil.dfs_topological_sort(
            sdfg, sources=[begin], condition=lambda _, child: child != guard)

        def _dominated_by_guard(state):
            # Walk up the dominator tree; reaching the guard means the state
            # is inside the loop, reaching the root without passing through
            # the guard means it is not.
            walker = state
            while walker != dominators[walker]:
                if walker == guard:
                    return True
                walker = dominators[walker]
            return False

        saw_backedge = False
        for state in loop_states:
            if any(e.dst == guard for e in graph.out_edges(state)):
                saw_backedge = True
            if not _dominated_by_guard(state):
                return False

        return saw_backedge

    @staticmethod
    def match_to_str(graph, candidate):
        """
        Returns the labels of the guard, first body state and exit state,
        followed by the name of the first variable assigned on the guard's
        first incoming edge.
        """
        guard, begin, sexit = [
            graph.node(candidate[pattern_node])
            for pattern_node in (DetectLoop._loop_guard,
                                 DetectLoop._loop_begin,
                                 DetectLoop._exit_state)
        ]
        ind = list(graph.in_edges(guard)[0].data.assignments.keys())[0]

        labels = ' -> '.join(state.label for state in [guard, begin, sexit])
        return labels + ' (for loop over "%s")' % ind

    def apply(self, sdfg):
        """ Detection only: applying this transformation changes nothing. """
        pass
Exemplo n.º 7
0
class SymbolAliasPromotion(transformation.Transformation):
    """
    SymbolAliasPromotion moves inter-state assignments that create symbolic
    aliases to the previous inter-state edge according to the topological order.
    The purpose of this transformation is to iteratively move symbolic aliases
    together, so that true duplicates can be easily removed.
    """

    _first_state = sdfg.SDFGState()
    _second_state = sdfg.SDFGState()

    @staticmethod
    def expressions():
        """ Returns the pattern to match: a path of two connected states. """
        return [
            sdutil.node_path_graph(SymbolAliasPromotion._first_state,
                                   SymbolAliasPromotion._second_state)
        ]

    @staticmethod
    def _promotable_aliases(sdfg, fstate, edge, in_edge):
        """
        Collects the alias assignments on `edge` that may be moved to
        `in_edge`. Shared by `can_be_applied` and `apply` so that both
        always agree on the set of assignments.

        :param sdfg: The SDFG being transformed.
        :param fstate: The first matched state (between the two edges).
        :param edge: Data of the edge from the first to the second state.
        :param in_edge: Data of the unique edge entering the first state.
        :return: A new dictionary {symbol: assigned value} of the
                 assignments that can be promoted (possibly empty).
        """
        candidates = _alias_assignments(sdfg, edge)
        if not candidates:
            return {}

        # The condition does not depend on the candidate, evaluate it once
        condsyms = {str(s) for s in edge.condition_sympy().free_symbols}

        promotable = {}
        for k, v in candidates.items():
            # Skip symbols that are taking part in the edge's condition
            if k in condsyms:
                continue
            # Skip symbols that are set in the in_edge
            # with a different assignment
            if k in in_edge.assignments and in_edge.assignments[k] != v:
                continue
            # Skip symbols whose assignment (RHS) is a symbol
            # and is set in the in_edge.
            if v in sdfg.symbols and v in in_edge.assignments:
                continue
            # Skip symbols whose assignment (RHS) is a scalar
            # and is accessed in the first state.
            if (v in sdfg.arrays and isinstance(sdfg.arrays[v], dt.Scalar)
                    and any(
                        isinstance(n, nodes.AccessNode) and n.data == v
                        for n in fstate.nodes())):
                continue
            promotable[k] = v

        return promotable

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """
        Checks whether at least one alias assignment can be promoted from
        the edge between the matched states to the edge entering the first
        state.

        :param graph: The state machine in which the match was found.
        :param candidate: Mapping from pattern nodes to node indices in
                          `graph`.
        :param expr_index: Index of the matched pattern expression (unused).
        :param sdfg: The SDFG being transformed.
        :param strict: Unused.
        :return: True if the transformation can be applied.
        """
        fstate = graph.nodes()[candidate[SymbolAliasPromotion._first_state]]
        sstate = graph.nodes()[candidate[SymbolAliasPromotion._second_state]]

        # For the topological order to be unambiguous:
        # 1. First state must have unique input edge.
        in_fedges = graph.in_edges(fstate)
        if len(in_fedges) != 1:
            return False
        # 2. There must be a unique edge from the first state to the second
        # one and no edge from the second state to the first one.
        edges = graph.edges_between(fstate, sstate)
        if len(edges) != 1:
            return False
        # Any back edge disqualifies the match. (Testing for more than one
        # back edge could never trigger, since the first state has exactly
        # one incoming edge.)
        if len(graph.edges_between(sstate, fstate)) > 0:
            return False

        promotable = SymbolAliasPromotion._promotable_aliases(
            sdfg, fstate, edges[0].data, in_fedges[0].data)

        # Applicable only if there are assignments to promote
        return len(promotable) > 0

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns the label of the second matched state. """
        state = graph.nodes()[candidate[SymbolAliasPromotion._second_state]]
        return state.label

    def apply(self, sdfg):
        """
        Moves every promotable alias assignment from the edge between the
        matched states to the edge entering the first state.

        :param sdfg: The SDFG to modify in place.
        """
        fstate = sdfg.nodes()[self.subgraph[SymbolAliasPromotion._first_state]]
        sstate = sdfg.nodes()[self.subgraph[SymbolAliasPromotion._second_state]]

        edge = sdfg.edges_between(fstate, sstate)[0].data
        in_edge = sdfg.in_edges(fstate)[0].data

        promotable = SymbolAliasPromotion._promotable_aliases(
            sdfg, fstate, edge, in_edge)

        for k, v in promotable.items():
            del edge.assignments[k]
            in_edge.assignments[k] = v
Exemplo n.º 8
0
class FPGATransformState(pattern_matching.Transformation):
    """ Implements the FPGATransformState transformation.

        Surrounds a state with a pre-state and a post-state that copy its
        input and output arrays to and from FPGA-global transient arrays,
        and redirects the state's dataflow to those transients.
    """

    _state = sd.SDFGState()

    @staticmethod
    def expressions():
        """ Returns the single-state pattern to match. """
        return [nxutil.node_path_graph(FPGATransformState._state)]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """
        Checks whether the matched state can be transformed.

        :param graph: The state machine in which the match was found.
        :param candidate: Mapping from pattern nodes to node indices in
                          `graph`.
        :param expr_index: Index of the matched pattern expression (unused).
        :param sdfg: The SDFG being transformed.
        :param strict: Unused.
        :return: False if any access node uses non-default storage, or any
                 map has more than 3 dimensions or a disallowed schedule
                 (itself or in an enclosing scope); True otherwise.
        """
        state = graph.nodes()[candidate[FPGATransformState._state]]

        # Map schedules that are disallowed to transform to FPGAs
        disallowed_schedules = (types.ScheduleType.MPI,
                                types.ScheduleType.GPU_Device,
                                types.ScheduleType.FPGA_Device,
                                types.ScheduleType.GPU_ThreadBlock)
        # Schedules that are disallowed on any enclosing scope
        disallowed_parent_schedules = (types.ScheduleType.GPU_Device,
                                       types.ScheduleType.FPGA_Device,
                                       types.ScheduleType.GPU_ThreadBlock)

        for node in state.nodes():

            if (isinstance(node, nodes.AccessNode)
                    and node.desc(sdfg).storage != types.StorageType.Default):
                return False

            if not isinstance(node, nodes.MapEntry):
                continue

            map_entry = node
            candidate_map = map_entry.map

            # No more than 3 dimensions
            if candidate_map.range.dims() > 3:
                return False

            if candidate_map.schedule in disallowed_schedules:
                return False

            # Recursively check parent for FPGA schedules
            sdict = state.scope_dict()
            current_node = map_entry
            while current_node is not None:
                if current_node.map.schedule in disallowed_parent_schedules:
                    return False
                current_node = sdict[current_node]

        return True

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns the label of the matched state. """
        state = graph.nodes()[candidate[FPGATransformState._state]]

        return state.label

    @staticmethod
    def _get_fpga_array(sdfg, node, array, fpga_data):
        """
        Returns the FPGA-global transient that mirrors the array accessed by
        `node`, creating it on first use.

        :param sdfg: The SDFG to add the transient to.
        :param node: Access node whose data name identifies the array.
        :param array: Data descriptor of `node`.
        :param fpga_data: Cache {data name: FPGA array}, updated in place.
                          Keyed by `node.data`, which is also what the
                          memlet renaming in `apply` looks up.
        :return: The (possibly newly created) FPGA array.
        """
        if node.data not in fpga_data:
            fpga_data[node.data] = sdfg.add_array(
                'fpga_' + node.data,
                array.dtype,
                array.shape,
                materialize_func=array.materialize_func,
                transient=True,
                storage=types.StorageType.FPGA_Global,
                allow_conflicts=array.allow_conflicts,
                access_order=array.access_order,
                strides=array.strides,
                offset=array.offset)
        return fpga_data[node.data]

    def apply(self, sdfg):
        """
        Applies the transformation to the matched state.

        For every source (resp. sink) access node of the state that refers
        to an array, a copy to (resp. from) an 'fpga_'-prefixed transient is
        placed in a new pre-state (resp. post-state), and the node in the
        state is replaced by an access node of the transient. Memlets in the
        state that refer to a transferred array are renamed accordingly.

        :param sdfg: The SDFG to modify in place.
        """
        state = sdfg.nodes()[self.subgraph[FPGATransformState._state]]

        # Find source/sink (data) nodes
        input_nodes = nxutil.find_source_nodes(state)
        output_nodes = nxutil.find_sink_nodes(state)

        # Maps original data names to their FPGA counterparts
        fpga_data = {}

        if input_nodes:

            pre_state = sd.SDFGState('pre_' + state.label, sdfg)

            for node in input_nodes:

                if (not isinstance(node, dace.graph.nodes.AccessNode)
                        or not isinstance(node.desc(sdfg), dace.data.Array)):
                    # Only transfer array nodes
                    # TODO: handle streams
                    continue

                array = node.desc(sdfg)
                fpga_array = FPGATransformState._get_fpga_array(
                    sdfg, node, array, fpga_data)
                fpga_node = type(node)(fpga_array)

                # Copy the whole array host -> FPGA in the pre-state
                pre_state.add_node(node)
                pre_state.add_node(fpga_node)
                full_range = subsets.Range([(0, s - 1, 1)
                                            for s in array.shape])
                mem = memlet.Memlet(array, full_range.num_elements(),
                                    full_range, 1)
                pre_state.add_edge(node, None, fpga_node, None, mem)

                # Read from the FPGA copy inside the state
                state.add_node(fpga_node)
                nxutil.change_edge_src(state, node, fpga_node)
                state.remove_node(node)

            sdfg.add_node(pre_state)
            nxutil.change_edge_dest(sdfg, state, pre_state)
            sdfg.add_edge(pre_state, state, edges.InterstateEdge())

        if output_nodes:

            post_state = sd.SDFGState('post_' + state.label, sdfg)

            for node in output_nodes:

                if (not isinstance(node, dace.graph.nodes.AccessNode)
                        or not isinstance(node.desc(sdfg), dace.data.Array)):
                    # Only transfer array nodes
                    # TODO: handle streams
                    continue

                array = node.desc(sdfg)
                fpga_array = FPGATransformState._get_fpga_array(
                    sdfg, node, array, fpga_data)
                fpga_node = type(node)(fpga_array)

                # Copy the whole array FPGA -> host in the post-state
                post_state.add_node(node)
                post_state.add_node(fpga_node)
                full_range = subsets.Range([(0, s - 1, 1)
                                            for s in array.shape])
                mem = memlet.Memlet(fpga_array, full_range.num_elements(),
                                    full_range, 1)
                post_state.add_edge(fpga_node, None, node, None, mem)

                # Write to the FPGA copy inside the state
                state.add_node(fpga_node)
                nxutil.change_edge_dest(state, node, fpga_node)
                state.remove_node(node)

            sdfg.add_node(post_state)
            nxutil.change_edge_src(sdfg, state, post_state)
            sdfg.add_edge(state, post_state, edges.InterstateEdge())

        # Point each memlet at the FPGA copy of the array it refers to. The
        # new name is derived from the memlet's own data, not from whichever
        # node a previous loop happened to visit last.
        for src, _, dst, _, mem in state.edges():
            if mem.data is not None and mem.data in fpga_data:
                mem.data = 'fpga_' + mem.data

        fpga_update(state, 0)
Exemplo n.º 9
0
class FPGATransformState(pattern_matching.Transformation):
    """ Implements the FPGATransformState transformation.

        Matches a single state and rewrites it to work on FPGA copies of its
        data: every array that appears as a source or sink access node of the
        state gets a transient ``fpga_<name>`` container with ``FPGA_Global``
        storage, a ``pre_<label>`` state that copies the inputs into those
        containers is inserted before the state, a ``post_<label>`` state
        that copies the outputs back is inserted after it, and
        ``fpga_update`` is finally called on the state.
    """

    _state = sd.SDFGState()

    @staticmethod
    def expressions():
        """ Returns the list of pattern graphs to match (a single state). """
        return [sdutil.node_path_graph(FPGATransformState._state)]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Checks whether the matched state may be transformed.

            :param graph: The graph in which the pattern was matched.
            :param candidate: Mapping from pattern nodes to node indices in
                              ``graph``.
            :param expr_index: Index of the matched expression (unused).
            :param sdfg: The SDFG that contains the matched state.
            :param strict: Unused by this transformation.
            :return: True if none of the unsupported constructs checked
                     below is present in the state, False otherwise.
        """
        state = graph.nodes()[candidate[FPGATransformState._state]]

        # TODO: Support most of these cases
        for edge, _ in state.all_edges_recursive():
            # Code->Code memlets are disallowed (for now)
            if (isinstance(edge.src, nodes.CodeNode)
                    and isinstance(edge.dst, nodes.CodeNode)):
                return False

        for node, owner in state.all_nodes_recursive():
            # Consume scopes are currently unsupported
            if isinstance(node, (nodes.ConsumeEntry, nodes.ConsumeExit)):
                return False

            if not isinstance(node, nodes.AccessNode):
                continue

            # The traversal is recursive, so the descriptor has to be looked
            # up in the SDFG that owns the node's graph; the top-level SDFG
            # does not necessarily define containers of nested SDFGs.
            nodedesc = owner.parent.arrays[node.data]
            if not isinstance(nodedesc, data.Stream):
                continue

            # Streams have strict conditions due to code generator limitations
            if nodedesc.storage in [
                    dtypes.StorageType.CPU_Heap,
                    dtypes.StorageType.CPU_Pinned,
                    dtypes.StorageType.CPU_ThreadLocal
            ]:
                return False

            # Cannot allocate FIFO from CPU code
            if owner.scope_dict()[node] is None:
                return False

            # Arrays of streams cannot have symbolic size on FPGA
            if dace.symbolic.issymbolic(nodedesc.total_size,
                                        owner.parent.constants):
                return False

            # Streams cannot be unbounded on FPGA
            if nodedesc.buffer_size < 1:
                return False

        # Schedules that must not appear on a map or on any enclosing scope
        device_schedules = (dtypes.ScheduleType.GPU_Device,
                            dtypes.ScheduleType.FPGA_Device,
                            dtypes.ScheduleType.GPU_ThreadBlock)

        for node in state.nodes():

            # Only containers with default storage can be moved
            if (isinstance(node, nodes.AccessNode)
                    and node.desc(sdfg).storage != dtypes.StorageType.Default):
                return False

            if not isinstance(node, nodes.MapEntry):
                continue

            candidate_map = node.map

            # No more than 3 dimensions
            if candidate_map.range.dims() > 3:
                return False

            # Map schedules that are disallowed to transform to FPGAs
            if (candidate_map.schedule == dtypes.ScheduleType.MPI
                    or candidate_map.schedule in device_schedules):
                return False

            # Recursively check parent for FPGA schedules
            sdict = state.scope_dict()
            current_node = node
            while current_node is not None:
                if current_node.map.schedule in device_schedules:
                    return False
                current_node = sdict[current_node]

        return True

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns the label of the matched state. """
        state = graph.nodes()[candidate[FPGATransformState._state]]

        return state.label

    @staticmethod
    def _add_fpga_array(sdfg, name, desc):
        """ Adds the transient FPGA_Global counterpart ``fpga_<name>`` of the
            array descriptor ``desc`` to ``sdfg`` and returns whatever
            ``sdfg.add_array`` returns. """
        return sdfg.add_array('fpga_' + name,
                              desc.shape,
                              desc.dtype,
                              materialize_func=desc.materialize_func,
                              transient=True,
                              storage=dtypes.StorageType.FPGA_Global,
                              allow_conflicts=desc.allow_conflicts,
                              strides=desc.strides,
                              offset=desc.offset)

    @staticmethod
    def _full_copy_memlet(data_name, desc):
        """ Creates a memlet on ``data_name`` that covers the whole shape of
            ``desc``. """
        full_range = subsets.Range([(0, s - 1, 1) for s in desc.shape])
        return memlet.Memlet(data_name, full_range.num_elements(), full_range,
                             1)

    @staticmethod
    def _nested_veclen(nsdfg_node, connector, default):
        """ Returns the vector length of the memlets attached to the access
            nodes named ``connector`` inside the nested SDFG of
            ``nsdfg_node``. If several nodes match, the last one visited
            wins; if none matches (or the node has no edges), ``default`` is
            returned. """
        veclen = default
        for inner_state in nsdfg_node.sdfg.states():
            for n in inner_state.nodes():
                if (isinstance(n, dace.sdfg.nodes.AccessNode)
                        and n.data == connector):
                    inner_edges = inner_state.all_edges(n)
                    # An isolated access node carries no memlet to read from
                    if inner_edges:
                        # assuming all memlets have the same vector length
                        veclen = inner_edges[0].data.veclen
        return veclen

    def apply(self, sdfg):
        """ Applies the transformation on the matched state of ``sdfg``.

            :param sdfg: The SDFG that contains the matched state. It is
                         modified in place: FPGA containers and the pre/post
                         copy states are added to it.
        """
        state = sdfg.nodes()[self.subgraph[FPGATransformState._state]]

        # Find source/sink (data) nodes
        input_nodes = sdutil.find_source_nodes(state)
        output_nodes = sdutil.find_sink_nodes(state)

        fpga_data = {}

        # Input nodes may also be nodes with WCR memlets
        # We have to recur across nested SDFGs to find them
        wcr_input_nodes = set()

        parent_sdfg = {state: sdfg}  # Map states to their parent SDFG
        for node, graph in state.all_nodes_recursive():
            if isinstance(graph, dace.SDFG):
                parent_sdfg[node] = graph
            if not isinstance(node, dace.sdfg.nodes.AccessNode):
                continue
            for e in graph.all_edges(node):
                if e.data.wcr is None:
                    continue
                trace = dace.sdfg.trace_nested_access(node, graph,
                                                      parent_sdfg[graph])
                for node_trace, state_trace, sdfg_trace in trace:
                    # Find the name of the accessed node in our scope
                    if state_trace == state and sdfg_trace == sdfg:
                        outer_node = node_trace
                        break
                else:
                    # This does not trace back to the current state, so
                    # we don't care
                    continue
                # Register every traced node once, even if it has several
                # WCR edges, so that it is copied a single time
                if outer_node not in wcr_input_nodes:
                    input_nodes.append(outer_node)
                    wcr_input_nodes.add(outer_node)

        if input_nodes:
            # create pre_state
            pre_state = sd.SDFGState('pre_' + state.label, sdfg)

            for node in input_nodes:

                if not isinstance(node, dace.sdfg.nodes.AccessNode):
                    continue
                desc = node.desc(sdfg)
                if not isinstance(desc, dace.data.Array):
                    # TODO: handle streams
                    continue

                is_wcr_input = node in wcr_input_nodes

                # Containers of WCR nodes are not created here; for sink
                # nodes this happens in the output handling below
                if node.data not in fpga_data and not is_wcr_input:
                    fpga_data[node.data] = (
                        FPGATransformState._add_fpga_array(
                            sdfg, node.data, desc))

                # Copy the host container to its FPGA counterpart
                pre_node = pre_state.add_read(node.data)
                pre_fpga_node = pre_state.add_write('fpga_' + node.data)
                mem = FPGATransformState._full_copy_memlet(node.data, desc)
                pre_state.add_edge(pre_node, None, pre_fpga_node, None, mem)

                # WCR nodes stay in the state; pure inputs are replaced by
                # a read of the FPGA container
                if not is_wcr_input:
                    fpga_node = state.add_read('fpga_' + node.data)
                    sdutil.change_edge_src(state, node, fpga_node)
                    state.remove_node(node)

            sdfg.add_node(pre_state)
            sdutil.change_edge_dest(sdfg, state, pre_state)
            sdfg.add_edge(pre_state, state, sd.InterstateEdge())

        if output_nodes:

            post_state = sd.SDFGState('post_' + state.label, sdfg)

            for node in output_nodes:

                if not isinstance(node, dace.sdfg.nodes.AccessNode):
                    continue
                desc = node.desc(sdfg)
                if not isinstance(desc, dace.data.Array):
                    # TODO: handle streams
                    continue

                if node.data not in fpga_data:
                    fpga_data[node.data] = (
                        FPGATransformState._add_fpga_array(
                            sdfg, node.data, desc))

                # Copy the FPGA container back to the host container
                post_node = post_state.add_write(node.data)
                post_fpga_node = post_state.add_read('fpga_' + node.data)
                mem = FPGATransformState._full_copy_memlet(
                    'fpga_' + node.data, desc)
                post_state.add_edge(post_fpga_node, None, post_node, None, mem)

                fpga_node = state.add_write('fpga_' + node.data)
                sdutil.change_edge_dest(state, node, fpga_node)
                state.remove_node(node)

            sdfg.add_node(post_state)
            sdutil.change_edge_src(sdfg, state, post_state)
            sdfg.add_edge(state, post_state, sd.InterstateEdge())

        # NOTE(review): the vector length is not reset between edges, so an
        # edge that does not touch a nested SDFG reuses the value found for
        # an earlier edge. This matches the previous behavior - confirm it is
        # intended.
        veclen_ = 1

        # propagate vector info from a nested sdfg
        for src, src_conn, dst, dst_conn, mem in state.edges():
            # need to go inside the nested SDFG and grab the vector length
            if isinstance(dst, dace.sdfg.nodes.NestedSDFG):
                # this edge is going to the nested SDFG
                veclen_ = FPGATransformState._nested_veclen(
                    dst, dst_conn, veclen_)
            if isinstance(src, dace.sdfg.nodes.NestedSDFG):
                # this edge is coming from the nested SDFG
                veclen_ = FPGATransformState._nested_veclen(
                    src, src_conn, veclen_)

            # Redirect memlets of moved containers to the FPGA copies
            if mem.data is not None and mem.data in fpga_data:
                mem.data = 'fpga_' + mem.data
                mem.veclen = veclen_

        fpga_update(sdfg, state, 0)
Exemplo n.º 10
0
class StateFusion(transformation.Transformation):
    """ Implements the state-fusion transformation.

        State-fusion takes two states that are connected through a single edge,
        and fuses them into one state. If strict, only applies if no memory
        access hazards are created.
    """

    _first_state = sdfg.SDFGState()
    _second_state = sdfg.SDFGState()

    @staticmethod
    def annotates_memlets():
        """ Reports that this transformation does not annotate memlets. """
        return False

    @staticmethod
    def expressions():
        """ Returns the pattern to match: two states connected by a path. """
        return [
            sdutil.node_path_graph(StateFusion._first_state,
                                   StateFusion._second_state)
        ]

    @staticmethod
    def find_fused_components(first_cc_input, first_cc_output, second_cc_input,
                              second_cc_output) -> List[CCDesc]:
        """
        Groups the connected components of the two states into the components
        that would exist after fusion.
        :param first_cc_input: Per connected component of the first state,
                               the set of its input access nodes.
        :param first_cc_output: Per connected component of the first state,
                                the set of its output access nodes.
        :param second_cc_input: Same as ``first_cc_input``, for the second
                                state.
        :param second_cc_output: Same as ``first_cc_output``, for the second
                                 state.
        :return: One ``CCDesc`` per fused component, built from the sets of
                 data names (first inputs, first outputs, second inputs,
                 second outputs) of the components it groups.
        """
        # Make a bipartite graph out of the first and second components
        g = nx.DiGraph()
        g.add_nodes_from((0, i) for i in range(len(first_cc_output)))
        g.add_nodes_from((1, i) for i in range(len(second_cc_output)))
        # Find matching nodes in second state: components are linked if an
        # output name of the first is an input name of the second
        for i, cc1 in enumerate(first_cc_output):
            outnames1 = {n.data for n in cc1}
            for j, cc2 in enumerate(second_cc_input):
                inpnames2 = {n.data for n in cc2}
                if len(outnames1 & inpnames2) > 0:
                    g.add_edge((0, i), (1, j))

        # Construct result out of connected components of the bipartite graph
        result = []
        for cc in nx.weakly_connected_components(g):
            input1, output1, input2, output2 = set(), set(), set(), set()
            for gind, cind in cc:
                if gind == 0:
                    input1 |= {n.data for n in first_cc_input[cind]}
                    output1 |= {n.data for n in first_cc_output[cind]}
                else:
                    input2 |= {n.data for n in second_cc_input[cind]}
                    output2 |= {n.data for n in second_cc_output[cind]}
            result.append(CCDesc(input1, output1, input2, output2))

        return result

    @staticmethod
    def memlets_intersect(graph_a: SDFGState, group_a: List[nodes.AccessNode],
                          inputs_a: bool, graph_b: SDFGState,
                          group_b: List[nodes.AccessNode],
                          inputs_b: bool) -> bool:
        """
        Performs an all-pairs check for subset intersection on two
        groups of nodes. If group intersects or result is indeterminate,
        returns True as a precaution.
        :param graph_a: The graph in which the first set of nodes reside.
        :param group_a: The first set of nodes to check.
        :param inputs_a: If True, checks inputs of the first group.
        :param graph_b: The graph in which the second set of nodes reside.
        :param group_b: The second set of nodes to check.
        :param inputs_b: If True, checks inputs of the second group.
        :returns True if subsets intersect or result is indeterminate.
        """
        # Set traversal functions; each falls back to the memlet's other
        # subset when the preferred one is not set
        src_subset = lambda e: (e.data.src_subset if e.data.src_subset is
                                not None else e.data.dst_subset)
        dst_subset = lambda e: (e.data.dst_subset if e.data.dst_subset is
                                not None else e.data.src_subset)
        if inputs_a:
            edges_a = [e for n in group_a for e in graph_a.out_edges(n)]
            subset_a = src_subset
        else:
            edges_a = [e for n in group_a for e in graph_a.in_edges(n)]
            subset_a = dst_subset
        if inputs_b:
            edges_b = [e for n in group_b for e in graph_b.out_edges(n)]
            subset_b = src_subset
        else:
            edges_b = [e for n in group_b for e in graph_b.in_edges(n)]
            subset_b = dst_subset

        # Simple all-pairs check
        for ea in edges_a:
            for eb in edges_b:
                result = subsets.intersects(subset_a(ea), subset_b(eb))
                if result is True or result is None:
                    return True
        return False

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """
        Checks whether the two matched states can be fused.
        :param graph: The state machine in which the pattern was matched.
        :param candidate: Mapping from pattern nodes to node indices in
                          ``graph``.
        :param expr_index: Index of the matched expression (unused).
        :param sdfg: The SDFG being transformed (unused in the checks).
        :param strict: If True, additionally rejects fusions that may
                       introduce memory access hazards.
        :return: True if the states can be fused, False otherwise.
        """
        first_state = graph.nodes()[candidate[StateFusion._first_state]]
        second_state = graph.nodes()[candidate[StateFusion._second_state]]

        out_edges = graph.out_edges(first_state)
        in_edges = graph.in_edges(first_state)

        # First state must have only one output edge (with dst the second
        # state).
        if len(out_edges) != 1:
            return False
        # The interstate edge must not have a condition.
        if not out_edges[0].data.is_unconditional():
            return False
        # The interstate edge may have assignments, as long as there are input
        # edges to the first state that can absorb them.
        if out_edges[0].data.assignments:
            if not in_edges:
                return False
            # Fail if symbol is set before the state to fuse
            new_assignments = set(out_edges[0].data.assignments.keys())
            if any((new_assignments & set(e.data.assignments.keys()))
                   for e in in_edges):
                return False
            # Fail if symbol is used in the dataflow of that state
            if len(new_assignments & first_state.free_symbols) > 0:
                return False

        # There can be no state that have output edges pointing to both the
        # first and the second state. Such a case will produce a multi-graph.
        for src, _, _ in in_edges:
            for _, dst, _ in graph.out_edges(src):
                if dst == second_state:
                    return False

        if strict:
            # If second state has other input edges, there might be issues
            # Exceptions are when none of the states contain dataflow, unless
            # the first state is an initial state (in which case the new initial
            # state would be ambiguous).
            second_in_edges = graph.in_edges(second_state)
            if ((not second_state.is_empty() or not first_state.is_empty()
                 or len(in_edges) == 0) and len(second_in_edges) != 1):
                return False

            # Get connected components.
            first_cc = [
                cc_nodes
                for cc_nodes in nx.weakly_connected_components(first_state._nx)
            ]
            second_cc = [
                cc_nodes for cc_nodes in nx.weakly_connected_components(
                    second_state._nx)
            ]

            # Find source/sink (data) nodes. Every access node that is not a
            # source node counts as an output here.
            first_input = {
                node
                for node in sdutil.find_source_nodes(first_state)
                if isinstance(node, nodes.AccessNode)
            }
            first_output = {
                node
                for node in first_state.nodes() if
                isinstance(node, nodes.AccessNode) and node not in first_input
            }
            second_input = {
                node
                for node in sdutil.find_source_nodes(second_state)
                if isinstance(node, nodes.AccessNode)
            }
            second_output = {
                node
                for node in second_state.nodes() if
                isinstance(node, nodes.AccessNode) and node not in second_input
            }

            # Find source/sink (data) nodes by connected component
            first_cc_input = [cc.intersection(first_input) for cc in first_cc]
            first_cc_output = [
                cc.intersection(first_output) for cc in first_cc
            ]
            second_cc_input = [
                cc.intersection(second_input) for cc in second_cc
            ]
            second_cc_output = [
                cc.intersection(second_output) for cc in second_cc
            ]

            # Apply transformation in case all paths to the second state's
            # nodes go through the same access node, which implies sequential
            # behavior in SDFG semantics.
            first_output_names = {node.data for node in first_output}
            second_input_names = {node.data for node in second_input}

            # If any second input appears more than once, fail
            if len(second_input) > len(second_input_names):
                return False

            # If any first output that is an input to the second state
            # appears in more than one CC, fail
            matches = first_output_names & second_input_names
            for match in matches:
                cc_appearances = 0
                for cc in first_cc_output:
                    if any(n.data == match for n in cc):
                        cc_appearances += 1
                if cc_appearances > 1:
                    return False

            # Recreate fused connected component correspondences, and then
            # check for hazards
            resulting_ccs: List[CCDesc] = StateFusion.find_fused_components(
                first_cc_input, first_cc_output, second_cc_input,
                second_cc_output)

            # Check for data races
            for fused_cc in resulting_ccs:
                # Write-Write hazard - data is output of both first and second
                # states, without a read in between
                write_write_candidates = (
                    (fused_cc.first_outputs & fused_cc.second_outputs) -
                    fused_cc.second_inputs)
                if len(write_write_candidates) > 0:
                    # If we have potential candidates, check if there is a
                    # path from the first write to the second write (in that
                    # case, there is no hazard):
                    # Find the leaf (topological) instances of the matches
                    order = [
                        x for x in reversed(
                            list(nx.topological_sort(first_state._nx)))
                        if isinstance(x, nodes.AccessNode)
                        and x.data in fused_cc.first_outputs
                    ]
                    # Those nodes will be the connection points upon fusion
                    match_nodes = {
                        next(n for n in order if n.data == match)
                        for match in (fused_cc.first_outputs
                                      & fused_cc.second_inputs)
                    }
                else:
                    match_nodes = set()

                for cand in write_write_candidates:
                    nodes_first = [n for n in first_output if n.data == cand]
                    nodes_second = [n for n in second_output if n.data == cand]

                    # If there is a path for the candidate that goes through
                    # the match nodes in both states, there is no conflict
                    fail = False
                    path_found = False
                    for match in match_nodes:
                        for node in nodes_first:
                            path_to = nx.has_path(first_state._nx, node, match)
                            if not path_to:
                                continue
                            path_found = True
                            node2 = next(n for n in second_input
                                         if n.data == match.data)
                            if not all(
                                    nx.has_path(second_state._nx, node2, n)
                                    for n in nodes_second):
                                fail = True
                                break
                        if fail or path_found:
                            break

                    # Check for intersection (if None, fusion is ok)
                    if fail or not path_found:
                        if StateFusion.memlets_intersect(
                                first_state, nodes_first, False, second_state,
                                nodes_second, False):
                            return False
                # End of write-write hazard check

                first_inout = fused_cc.first_inputs | fused_cc.first_outputs
                for other_cc in resulting_ccs:
                    if other_cc is fused_cc:
                        continue
                    # If an input/output of a connected component in the first
                    # state is an output of another connected component in the
                    # second state, we have a potential data race (Read-Write
                    # or Write-Write)
                    for d in first_inout:
                        if d in other_cc.second_outputs:
                            # Check for intersection (if None, fusion is ok)
                            nodes_second = [
                                n for n in second_output if n.data == d
                            ]
                            # Read-Write race
                            if d in fused_cc.first_inputs:
                                nodes_first = [
                                    n for n in first_input if n.data == d
                                ]
                                if StateFusion.memlets_intersect(
                                        first_state, nodes_first, True,
                                        second_state, nodes_second, False):
                                    return False
                            # Write-Write race
                            if d in fused_cc.first_outputs:
                                nodes_first = [
                                    n for n in first_output if n.data == d
                                ]
                                if StateFusion.memlets_intersect(
                                        first_state, nodes_first, False,
                                        second_state, nodes_second, False):
                                    return False
                    # End of data race check

        return True

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns the labels of the two matched states joined by an arrow. """
        first_state = graph.nodes()[candidate[StateFusion._first_state]]
        second_state = graph.nodes()[candidate[StateFusion._second_state]]

        return " -> ".join(state.label
                           for state in [first_state, second_state])

    def apply(self, sdfg):
        """
        Fuses the second matched state into the first one.
        :param sdfg: The SDFG containing both states; modified in place.
        """
        first_state = sdfg.nodes()[self.subgraph[StateFusion._first_state]]
        second_state = sdfg.nodes()[self.subgraph[StateFusion._second_state]]

        # Remove interstate edge(s), moving their assignments onto the edges
        # that enter the first state
        edges = sdfg.edges_between(first_state, second_state)
        for edge in edges:
            if edge.data.assignments:
                for src, dst, other_data in sdfg.in_edges(first_state):
                    other_data.assignments.update(edge.data.assignments)
            sdfg.remove_edge(edge)

        # Special case 1: first state is empty
        if first_state.is_empty():
            sdutil.change_edge_dest(sdfg, first_state, second_state)
            sdfg.remove_node(first_state)
            return

        # Special case 2: second state is empty
        if second_state.is_empty():
            sdutil.change_edge_src(sdfg, second_state, first_state)
            sdutil.change_edge_dest(sdfg, second_state, first_state)
            sdfg.remove_node(second_state)
            return

        # Normal case: both states are not empty

        # Find the source (data) nodes of the second state; these are the
        # candidates for merging with access nodes of the first state
        second_input = [
            node for node in sdutil.find_source_nodes(second_state)
            if isinstance(node, nodes.AccessNode)
        ]

        # Merge second state to first state
        # First keep a backup of the topological sorted order of the nodes
        # (reversed, so that the last access to each container comes first)
        order = [
            x for x in reversed(list(nx.topological_sort(first_state._nx)))
            if isinstance(x, nodes.AccessNode)
        ]
        for node in second_state.nodes():
            first_state.add_node(node)
        for src, src_conn, dst, dst_conn, data in second_state.edges():
            first_state.add_edge(src, src_conn, dst, dst_conn, data)

        # Merge common (data) nodes
        for node in second_input:
            if first_state.in_degree(node) == 0:
                n = next((x for x in order if x.label == node.label), None)
                if n:
                    sdutil.change_edge_src(first_state, node, n)
                    first_state.remove_node(node)
                    n.access = dtypes.AccessType.ReadWrite

        # Redirect edges and remove second state
        sdutil.change_edge_src(sdfg, second_state, first_state)
        sdfg.remove_node(second_state)
Exemplo n.º 11
0
class FPGATransformState(transformation.Transformation):
    """ Implements the FPGATransformState transformation.

        Matches a single state. When applied, the externally visible array
        accesses of the state are redirected to transient ``fpga_``-prefixed
        arrays with ``FPGA_Global`` storage. A ``pre_`` state that copies the
        inputs into those arrays and a ``post_`` state that copies the outputs
        back are inserted around the state, and ``fpga_update`` is finally
        invoked on the state.
    """

    _state = sd.SDFGState()

    @staticmethod
    def expressions():
        """ Returns the pattern to match: a graph with a single state. """
        return [sdutil.node_path_graph(FPGATransformState._state)]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, permissive=False):
        """ Checks whether the matched state can be transformed.

            :param graph: Graph in which the match was found.
            :param candidate: Mapping from pattern nodes to node indices in
                              ``graph``.
            :param expr_index: Index of the matched expression (unused).
            :param sdfg: SDFG used to look up data descriptors.
            :param permissive: Not consulted by this check.
            :return: False if any of the conditions below is violated, True
                     otherwise.
        """
        state = graph.nodes()[candidate[FPGATransformState._state]]

        # The loop variable is called ``node_graph`` so that the ``graph``
        # argument is not shadowed
        for node, node_graph in state.all_nodes_recursive():
            # Consume scopes are currently unsupported
            if isinstance(node, (nodes.ConsumeEntry, nodes.ConsumeExit)):
                return False

            if not isinstance(node, nodes.AccessNode):
                continue
            nodedesc = node_graph.parent.arrays[node.data]

            # Streams have strict conditions due to code generator limitations
            if not isinstance(nodedesc, data.Stream):
                continue

            if nodedesc.storage in (dtypes.StorageType.CPU_Heap,
                                    dtypes.StorageType.CPU_Pinned,
                                    dtypes.StorageType.CPU_ThreadLocal):
                return False

            # Cannot allocate FIFO from CPU code
            if node_graph.scope_dict()[node] is None:
                return False

            # Arrays of streams cannot have symbolic size on FPGA
            if dace.symbolic.issymbolic(nodedesc.total_size,
                                        node_graph.parent.constants):
                return False

            # Streams cannot be unbounded on FPGA
            if nodedesc.buffer_size < 1:
                return False

        # Map schedules that are disallowed to transform to FPGAs
        disallowed_schedules = (dtypes.ScheduleType.MPI,
                                dtypes.ScheduleType.GPU_Device,
                                dtypes.ScheduleType.FPGA_Device,
                                dtypes.ScheduleType.GPU_ThreadBlock)
        # Schedules that are disallowed anywhere in the enclosing scopes
        disallowed_scope_schedules = (dtypes.ScheduleType.GPU_Device,
                                      dtypes.ScheduleType.FPGA_Device,
                                      dtypes.ScheduleType.GPU_ThreadBlock)
        # The scope dictionary does not change while iterating; query it once
        sdict = state.scope_dict()

        for node in state.nodes():
            # Only data with default or register storage is accepted
            if (isinstance(node, nodes.AccessNode)
                    and node.desc(sdfg).storage
                    not in (dtypes.StorageType.Default,
                            dtypes.StorageType.Register)):
                return False

            if not isinstance(node, nodes.MapEntry):
                continue

            if node.map.schedule in disallowed_schedules:
                return False

            # Walk up the enclosing scopes and check their schedules as well
            current_node = node
            while current_node is not None:
                if current_node.map.schedule in disallowed_scope_schedules:
                    return False
                current_node = sdict[current_node]

        return True

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns the label of the matched state. """
        state = graph.nodes()[candidate[FPGATransformState._state]]

        return state.label

    @staticmethod
    def _add_fpga_array(sdfg, node, desc, fpga_data):
        """ Registers the FPGA counterpart of the array accessed by ``node``.

            A transient ``FPGA_Global`` array named ``'fpga_' + node.data``
            is added to ``sdfg`` with the shape, dtype, strides, offset and
            conflict settings of ``desc``. The contents of ``desc.location``
            are copied to element ``[1]`` of the value returned by
            ``add_array`` (presumably the new descriptor) and then cleared
            on ``desc``. The result is recorded in ``fpga_data`` under
            ``node.data``.
        """
        fpga_array = sdfg.add_array('fpga_' + node.data,
                                    desc.shape,
                                    desc.dtype,
                                    transient=True,
                                    storage=dtypes.StorageType.FPGA_Global,
                                    allow_conflicts=desc.allow_conflicts,
                                    strides=desc.strides,
                                    offset=desc.offset)
        fpga_array[1].location = copy.copy(desc.location)
        desc.location.clear()
        fpga_data[node.data] = fpga_array

    def apply(self, sdfg):
        """ Applies the transformation to the matched state of ``sdfg``.

            :param sdfg: The SDFG that contains the matched state. It is
                         modified in place: arrays are added and up to two
                         states (``pre_``/``post_``) are inserted.
        """
        state = sdfg.nodes()[self.subgraph[FPGATransformState._state]]

        # Find source/sink (data) nodes that are relevant outside this FPGA
        # kernel
        shared_transients = set(sdfg.shared_transients())
        input_nodes = [
            n for n in sdutil.find_source_nodes(state)
            if isinstance(n, nodes.AccessNode) and
            (not sdfg.arrays[n.data].transient or n.data in shared_transients)
        ]
        output_nodes = [
            n for n in sdutil.find_sink_nodes(state)
            if isinstance(n, nodes.AccessNode) and
            (not sdfg.arrays[n.data].transient or n.data in shared_transients)
        ]

        # Maps original data names to the FPGA arrays created for them
        fpga_data = {}

        # Input nodes may also be nodes with WCR memlets
        # We have to recur across nested SDFGs to find them
        wcr_input_nodes = set()

        parent_sdfg = {state: sdfg}  # Map states to their parent SDFG
        for node, graph in state.all_nodes_recursive():
            if isinstance(graph, dace.SDFG):
                parent_sdfg[node] = graph
            if isinstance(node, dace.sdfg.nodes.AccessNode):
                for e in graph.in_edges(node):
                    if e.data.wcr is not None:
                        trace = dace.sdfg.trace_nested_access(
                            node, graph, parent_sdfg[graph])
                        for node_trace, memlet_trace, state_trace, sdfg_trace in trace:
                            # Find the name of the accessed node in our scope
                            if state_trace == state and sdfg_trace == sdfg:
                                _, outer_node = node_trace
                                if outer_node is not None:
                                    break
                        else:
                            # This does not trace back to the current state, so
                            # we don't care
                            continue
                        input_nodes.append(outer_node)
                        wcr_input_nodes.add(outer_node)

        if input_nodes:
            # Create the state that copies inputs to the FPGA arrays
            pre_state = sd.SDFGState('pre_' + state.label, sdfg)

            for node in input_nodes:

                if not isinstance(node, dace.sdfg.nodes.AccessNode):
                    continue
                desc = node.desc(sdfg)
                if not isinstance(desc, dace.data.Array):
                    # TODO: handle streams
                    continue

                # WCR nodes do not register an array in this phase.
                # NOTE(review): their pre-state copy below still targets
                # 'fpga_' + node.data - confirm the array is always
                # registered by the output phase.
                if (node.data not in fpga_data
                        and node not in wcr_input_nodes):
                    self._add_fpga_array(sdfg, node, desc, fpga_data)

                pre_node = pre_state.add_read(node.data)
                pre_fpga_node = pre_state.add_write('fpga_' + node.data)
                mem = memlet.Memlet(data=node.data,
                                    subset=subsets.Range.from_array(desc))
                pre_state.add_edge(pre_node, None, pre_fpga_node, None, mem)

                # WCR nodes stay in the state; plain inputs are replaced by
                # a read of the FPGA array
                if node not in wcr_input_nodes:
                    fpga_node = state.add_read('fpga_' + node.data)
                    sdutil.change_edge_src(state, node, fpga_node)
                    state.remove_node(node)

            sdfg.add_node(pre_state)
            sdutil.change_edge_dest(sdfg, state, pre_state)
            sdfg.add_edge(pre_state, state, sd.InterstateEdge())

        if output_nodes:
            # Create the state that copies results back from the FPGA arrays
            post_state = sd.SDFGState('post_' + state.label, sdfg)

            for node in output_nodes:

                if not isinstance(node, dace.sdfg.nodes.AccessNode):
                    continue
                desc = node.desc(sdfg)
                if not isinstance(desc, dace.data.Array):
                    # TODO: handle streams
                    continue

                if node.data not in fpga_data:
                    self._add_fpga_array(sdfg, node, desc, fpga_data)

                post_node = post_state.add_write(node.data)
                post_fpga_node = post_state.add_read('fpga_' + node.data)
                mem = memlet.Memlet(f"fpga_{node.data}", None,
                                    subsets.Range.from_array(desc))
                post_state.add_edge(post_fpga_node, None, post_node, None, mem)

                fpga_node = state.add_write('fpga_' + node.data)
                sdutil.change_edge_dest(state, node, fpga_node)
                state.remove_node(node)

            sdfg.add_node(post_state)
            sdutil.change_edge_src(sdfg, state, post_state)
            sdfg.add_edge(state, post_state, sd.InterstateEdge())

        # propagate memlet info from a nested sdfg
        for _, _, _, _, mem in state.edges():
            if mem.data is not None and mem.data in fpga_data:
                mem.data = 'fpga_' + mem.data
        fpga_update(sdfg, state, 0)
Exemplo n.º 12
0
class StateFusion(pattern_matching.Transformation):
    """ Implements the state-fusion transformation.

        State-fusion takes two states that are connected through a single edge,
        and fuses them into one state. If strict, only applies if no memory
        access hazards are created.
    """

    # Number of fusions performed while the "debugprint" option was enabled
    _states_fused = 0
    _first_state = sdfg.SDFGState()
    _edge = edges.InterstateEdge()
    _second_state = sdfg.SDFGState()

    @staticmethod
    def annotates_memlets():
        """ Returns False for this transformation. """
        return False

    @staticmethod
    def expressions():
        """ Returns the pattern to match: two consecutive states. """
        return [
            nxutil.node_path_graph(StateFusion._first_state,
                                   StateFusion._second_state)
        ]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Checks whether the two matched states can be fused.

            :param graph: State graph in which the match was found.
            :param candidate: Mapping from pattern nodes to node indices in
                              ``graph``.
            :param expr_index: Index of the matched expression (unused).
            :param sdfg: The SDFG (unused by this check).
            :param strict: If True, additionally reject fusions that could
                           create read-write or write-write hazards.
            :return: True if the states can be fused, False otherwise.
        """
        first_state = graph.nodes()[candidate[StateFusion._first_state]]
        second_state = graph.nodes()[candidate[StateFusion._second_state]]

        out_edges = graph.out_edges(first_state)
        in_edges = graph.in_edges(first_state)

        # First state must have only one output edge (with dst the second
        # state).
        if len(out_edges) != 1:
            return False
        # The interstate edge must not have a condition.
        if out_edges[0].data.condition.as_string != "":
            return False
        # The interstate edge may have assignments, as long as there are input
        # edges to the first state, that can absorb them.
        if out_edges[0].data.assignments and not in_edges:
            return False
        # There can be no state that have output edges pointing to both the
        # first and the second state. Such a case will produce a multi-graph.
        for src, _, _ in in_edges:
            for _, dst, _ in graph.out_edges(src):
                if dst == second_state:
                    return False

        if not strict:
            return True

        # If second state has other input edges, there might be issues
        # Exceptions are when none of the states contain dataflow, unless
        # the first state is an initial state (in which case the new initial
        # state would be ambiguous).
        second_in_edges = graph.in_edges(second_state)
        if ((not second_state.is_empty() or not first_state.is_empty()
             or len(in_edges) == 0) and len(second_in_edges) != 1):
            return False

        # Get connected components of the first state.
        first_cc = list(nx.weakly_connected_components(first_state._nx))

        # Find input (source) data nodes; every other data node counts as
        # an output
        first_input = {
            node
            for node in nxutil.find_source_nodes(first_state)
            if isinstance(node, nodes.AccessNode)
        }
        first_output = {
            node
            for node in first_state.nodes()
            if isinstance(node, nodes.AccessNode) and node not in first_input
        }
        second_input = {
            node
            for node in nxutil.find_source_nodes(second_state)
            if isinstance(node, nodes.AccessNode)
        }
        second_output = {
            node
            for node in second_state.nodes()
            if isinstance(node, nodes.AccessNode) and node not in second_input
        }

        # Nodes are matched across states by label; sets make each lookup
        # constant-time
        second_input_labels = {node.label for node in second_input}
        second_output_labels = {node.label for node in second_output}

        # Count the components of the first state none of whose outputs is
        # read as an input of the second state
        check_strict = len(first_cc)
        for cc in first_cc:
            if any(node.label in second_input_labels
                   for node in cc.intersection(first_output)):
                check_strict -= 1

        if check_strict > 0:
            # Check strict conditions
            # RW dependency
            if any(node.label in second_output_labels
                   for node in first_input):
                return False
            # WW dependency
            if any(node.label in second_output_labels
                   for node in first_output):
                return False

        return True

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns the labels of the two matched states joined by an arrow. """
        first_state = graph.nodes()[candidate[StateFusion._first_state]]
        second_state = graph.nodes()[candidate[StateFusion._second_state]]

        return " -> ".join(state.label
                           for state in [first_state, second_state])

    def apply(self, sdfg):
        """ Fuses the second matched state into the first one.

            :param sdfg: The SDFG that contains both states. It is modified
                         in place; one of the two states is removed.
        """
        first_state = sdfg.nodes()[self.subgraph[StateFusion._first_state]]
        second_state = sdfg.nodes()[self.subgraph[StateFusion._second_state]]

        # Remove interstate edge(s); their assignments are moved to the
        # edges entering the first state
        for edge in sdfg.edges_between(first_state, second_state):
            if edge.data.assignments:
                for _, _, other_data in sdfg.in_edges(first_state):
                    other_data.assignments.update(edge.data.assignments)
            sdfg.remove_edge(edge)

        # Special case 1: first state is empty
        if first_state.is_empty():
            nxutil.change_edge_dest(sdfg, first_state, second_state)
            sdfg.remove_node(first_state)
            return

        # Special case 2: second state is empty
        if second_state.is_empty():
            nxutil.change_edge_src(sdfg, second_state, first_state)
            nxutil.change_edge_dest(sdfg, second_state, first_state)
            sdfg.remove_node(second_state)
            return

        # Normal case: both states are not empty

        # Find source/sink (data) nodes
        first_output = [
            node for node in nxutil.find_sink_nodes(first_state)
            if isinstance(node, nodes.AccessNode)
        ]
        first_output_labels = {node.label for node in first_output}
        # first input = first input - first output
        first_input = [
            node for node in nxutil.find_source_nodes(first_state)
            if isinstance(node, nodes.AccessNode)
            and node.label not in first_output_labels
        ]
        second_input = [
            node for node in nxutil.find_source_nodes(second_state)
            if isinstance(node, nodes.AccessNode)
        ]

        # Merge second state to first state
        for node in second_state.nodes():
            first_state.add_node(node)
        for src, src_conn, dst, dst_conn, data in second_state.edges():
            first_state.add_edge(src, src_conn, dst, dst_conn, data)

        # Merge common (data) nodes
        # Inputs of both states: keep the node of the first state
        for node in first_input:
            old_node = next(
                (x for x in second_input if x.label == node.label), None)
            if old_node is None:
                continue
            nxutil.change_edge_src(first_state, old_node, node)
            first_state.remove_node(old_node)
            second_input.remove(old_node)
        # Outputs of the first state read by the second: keep the node of
        # the second state
        for node in first_output:
            new_node = next(
                (x for x in second_input if x.label == node.label), None)
            if new_node is None:
                continue
            nxutil.change_edge_dest(first_state, node, new_node)
            first_state.remove_node(node)
            second_input.remove(new_node)

        # Redirect edges and remove second state
        nxutil.change_edge_src(sdfg, second_state, first_state)
        sdfg.remove_node(second_state)
        if Config.get_bool("debugprint"):
            StateFusion._states_fused += 1

    @staticmethod
    def print_debuginfo():
        """ Prints the number of state fusions counted so far. """
        print("Automatically fused {} states using StateFusion transform.".
              format(StateFusion._states_fused))
Exemplo n.º 13
0
    def apply(self, sdfg):
        """ Moves the array inputs and outputs of the matched state to
            ``fpga_``-prefixed transient arrays in ``FPGA_Global`` storage.

            A ``pre_`` state copying the inputs and a ``post_`` state copying
            the outputs back are inserted around the state, the memlets of
            the state are renamed accordingly, and ``fpga_update`` is called
            on the state.

            :param sdfg: The SDFG that contains the matched state. It is
                         modified in place.
        """
        state = sdfg.nodes()[self.subgraph[FPGATransformState._state]]

        # Find source/sink (data) nodes
        input_nodes = nxutil.find_source_nodes(state)
        output_nodes = nxutil.find_sink_nodes(state)

        # Maps original data names to the FPGA arrays created for them
        fpga_data = {}

        def fpga_array_for(node, array):
            """ Returns the FPGA array for ``node``, creating it on first
                use. Entries are keyed by ``node.data`` in both phases. """
            if node.data not in fpga_data:
                fpga_data[node.data] = sdfg.add_array(
                    'fpga_' + node.data,
                    array.dtype,
                    array.shape,
                    materialize_func=array.materialize_func,
                    transient=True,
                    storage=types.StorageType.FPGA_Global,
                    allow_conflicts=array.allow_conflicts,
                    access_order=array.access_order,
                    strides=array.strides,
                    offset=array.offset)
            return fpga_data[node.data]

        if input_nodes:

            pre_state = sd.SDFGState('pre_' + state.label, sdfg)

            for node in input_nodes:

                # Only transfer array nodes
                # TODO: handle streams
                if not isinstance(node, dace.graph.nodes.AccessNode):
                    continue
                array = node.desc(sdfg)
                if not isinstance(array, dace.data.Array):
                    continue

                fpga_array = fpga_array_for(node, array)
                fpga_node = type(node)(fpga_array)

                # Copy the whole array to the FPGA in the pre-state
                pre_state.add_node(node)
                pre_state.add_node(fpga_node)
                full_range = subsets.Range([(0, s - 1, 1)
                                            for s in array.shape])
                mem = memlet.Memlet(array, full_range.num_elements(),
                                    full_range, 1)
                pre_state.add_edge(node, None, fpga_node, None, mem)

                # Read from the FPGA array inside the state instead
                state.add_node(fpga_node)
                nxutil.change_edge_src(state, node, fpga_node)
                state.remove_node(node)

            sdfg.add_node(pre_state)
            nxutil.change_edge_dest(sdfg, state, pre_state)
            sdfg.add_edge(pre_state, state, edges.InterstateEdge())

        if output_nodes:

            post_state = sd.SDFGState('post_' + state.label, sdfg)

            for node in output_nodes:

                # Only transfer array nodes
                # TODO: handle streams
                if not isinstance(node, dace.graph.nodes.AccessNode):
                    continue
                array = node.desc(sdfg)
                if not isinstance(array, dace.data.Array):
                    continue

                fpga_array = fpga_array_for(node, array)
                fpga_node = type(node)(fpga_array)

                # Copy the whole array back from the FPGA in the post-state
                post_state.add_node(node)
                post_state.add_node(fpga_node)
                full_range = subsets.Range([(0, s - 1, 1)
                                            for s in array.shape])
                mem = memlet.Memlet(fpga_array, full_range.num_elements(),
                                    full_range, 1)
                post_state.add_edge(fpga_node, None, node, None, mem)

                # Write to the FPGA array inside the state instead
                state.add_node(fpga_node)
                nxutil.change_edge_dest(state, node, fpga_node)
                state.remove_node(node)

            sdfg.add_node(post_state)
            nxutil.change_edge_src(sdfg, state, post_state)
            sdfg.add_edge(state, post_state, edges.InterstateEdge())

        # Rename each memlet after its own data; using the loop variable of
        # the loops above would give every memlet the name of the last node
        for _, _, _, _, mem in state.edges():
            if mem.data is not None and mem.data in fpga_data:
                mem.data = 'fpga_' + mem.data

        fpga_update(state, 0)
Exemplo n.º 14
0
    def apply(self, _, sdfg):
        """ Moves the externally visible array accesses of ``self.state`` to
            ``fpga_``-prefixed transient arrays in ``FPGA_Global`` storage.

            A ``pre_`` state copying the inputs and a ``post_`` state copying
            the outputs back are inserted around the state, the memlets of
            the state are renamed accordingly, and ``fpga_update`` is called
            on the state.

            :param _: Unused.
            :param sdfg: The SDFG that contains the state. It is modified in
                         place.
        """
        state = self.state

        # Find source/sink (data) nodes that are relevant outside this FPGA
        # kernel
        shared_transients = set(sdfg.shared_transients())
        input_nodes = [
            n for n in sdutil.find_source_nodes(state)
            if isinstance(n, nodes.AccessNode) and
            (not sdfg.arrays[n.data].transient or n.data in shared_transients)
        ]
        output_nodes = [
            n for n in sdutil.find_sink_nodes(state)
            if isinstance(n, nodes.AccessNode) and
            (not sdfg.arrays[n.data].transient or n.data in shared_transients)
        ]

        # Maps original data names to the FPGA arrays created for them
        fpga_data = {}

        def add_fpga_array(node, desc):
            """ Registers the FPGA counterpart of the array behind ``node``
                and moves the location entries of ``desc`` over to it. """
            fpga_array = sdfg.add_array(
                'fpga_' + node.data,
                desc.shape,
                desc.dtype,
                transient=True,
                storage=dtypes.StorageType.FPGA_Global,
                allow_conflicts=desc.allow_conflicts,
                strides=desc.strides,
                offset=desc.offset)
            fpga_array[1].location = copy.copy(desc.location)
            desc.location.clear()
            fpga_data[node.data] = fpga_array

        # Input nodes may also be nodes with WCR memlets
        # We have to recur across nested SDFGs to find them
        wcr_input_nodes = set()

        parent_sdfg = {state: sdfg}  # Map states to their parent SDFG
        for node, graph in state.all_nodes_recursive():
            if isinstance(graph, dace.SDFG):
                parent_sdfg[node] = graph
            if isinstance(node, dace.sdfg.nodes.AccessNode):
                for e in graph.in_edges(node):
                    if e.data.wcr is not None:
                        trace = dace.sdfg.trace_nested_access(
                            node, graph, parent_sdfg[graph])
                        for node_trace, memlet_trace, state_trace, sdfg_trace in trace:
                            # Find the name of the accessed node in our scope
                            if state_trace == state and sdfg_trace == sdfg:
                                _, outer_node = node_trace
                                if outer_node is not None:
                                    break
                        else:
                            # This does not trace back to the current state, so
                            # we don't care
                            continue
                        input_nodes.append(outer_node)
                        wcr_input_nodes.add(outer_node)

        if input_nodes:
            # Create the state that copies inputs to the FPGA arrays
            pre_state = sd.SDFGState('pre_' + state.label, sdfg)

            for node in input_nodes:

                if not isinstance(node, dace.sdfg.nodes.AccessNode):
                    continue
                desc = node.desc(sdfg)
                if not isinstance(desc, dace.data.Array):
                    # TODO: handle streams
                    continue

                # WCR nodes do not register an array in this phase
                if (node.data not in fpga_data
                        and node not in wcr_input_nodes):
                    add_fpga_array(node, desc)

                pre_node = pre_state.add_read(node.data)
                pre_fpga_node = pre_state.add_write('fpga_' + node.data)
                mem = memlet.Memlet(data=node.data,
                                    subset=subsets.Range.from_array(desc))
                pre_state.add_edge(pre_node, None, pre_fpga_node, None, mem)

                # WCR nodes stay in the state; plain inputs are replaced by
                # a read of the FPGA array
                if node not in wcr_input_nodes:
                    fpga_node = state.add_read('fpga_' + node.data)
                    sdutil.change_edge_src(state, node, fpga_node)
                    state.remove_node(node)

            sdfg.add_node(pre_state)
            sdutil.change_edge_dest(sdfg, state, pre_state)
            sdfg.add_edge(pre_state, state, sd.InterstateEdge())

        if output_nodes:
            # Create the state that copies results back from the FPGA arrays
            post_state = sd.SDFGState('post_' + state.label, sdfg)

            for node in output_nodes:

                if not isinstance(node, dace.sdfg.nodes.AccessNode):
                    continue
                desc = node.desc(sdfg)
                if not isinstance(desc, dace.data.Array):
                    # TODO: handle streams
                    continue

                if node.data not in fpga_data:
                    add_fpga_array(node, desc)

                post_node = post_state.add_write(node.data)
                post_fpga_node = post_state.add_read('fpga_' + node.data)
                mem = memlet.Memlet(f"fpga_{node.data}", None,
                                    subsets.Range.from_array(desc))
                post_state.add_edge(post_fpga_node, None, post_node, None, mem)

                fpga_node = state.add_write('fpga_' + node.data)
                sdutil.change_edge_dest(state, node, fpga_node)
                state.remove_node(node)

            sdfg.add_node(post_state)
            sdutil.change_edge_src(sdfg, state, post_state)
            sdfg.add_edge(state, post_state, sd.InterstateEdge())

        # propagate memlet info from a nested sdfg
        for _, _, _, _, mem in state.edges():
            if mem.data is not None and mem.data in fpga_data:
                mem.data = 'fpga_' + mem.data
        fpga_update(sdfg, state, 0)
Exemplo n.º 15
0
class DoubleBuffering(pattern_matching.Transformation):
    """ Implements the double buffering pattern, which pipelines reading
        and processing data by creating a second copy of the memory.

        The matched pattern is a four-state loop skeleton with the edges
        ``begin -> guard``, ``guard -> body``, ``body -> guard`` and
        ``guard -> end``. """

    _begin = sd.SDFGState()
    _guard = sd.SDFGState()
    _body = sd.SDFGState()
    _end = sd.SDFGState()

    @staticmethod
    def expressions():
        """ Returns a one-element list holding the loop-skeleton pattern
            graph built from the four pattern states. """
        for_loop_graph = dace.graph.graph.OrderedDiGraph()
        for_loop_graph.add_nodes_from([
            DoubleBuffering._begin, DoubleBuffering._guard,
            DoubleBuffering._body, DoubleBuffering._end
        ])
        for_loop_graph.add_edge(DoubleBuffering._begin, DoubleBuffering._guard,
                                None)
        for_loop_graph.add_edge(DoubleBuffering._guard, DoubleBuffering._body,
                                None)
        for_loop_graph.add_edge(DoubleBuffering._body, DoubleBuffering._guard,
                                None)
        for_loop_graph.add_edge(DoubleBuffering._guard, DoubleBuffering._end,
                                None)

        return [for_loop_graph]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Returns True only if the begin, guard and end states are empty
            and the body state is not. """
        begin = graph.nodes()[candidate[DoubleBuffering._begin]]
        guard = graph.nodes()[candidate[DoubleBuffering._guard]]
        body = graph.nodes()[candidate[DoubleBuffering._body]]
        end = graph.nodes()[candidate[DoubleBuffering._end]]

        if not begin.is_empty():
            return False
        if not guard.is_empty():
            return False
        if not end.is_empty():
            return False
        if body.is_empty():
            return False

        return True

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns the labels of the four matched states, comma-separated,
            in the order begin, guard, body, end. """
        begin = graph.nodes()[candidate[DoubleBuffering._begin]]
        guard = graph.nodes()[candidate[DoubleBuffering._guard]]
        body = graph.nodes()[candidate[DoubleBuffering._body]]
        end = graph.nodes()[candidate[DoubleBuffering._end]]

        return ', '.join(state.label for state in [begin, guard, body, end])

    @staticmethod
    def _with_buffer_dim(buffer_index, old_index):
        """ Returns a Range equal to ``old_index`` with one extra leading
            dimension that selects the single element ``buffer_index``.

            :raise NotImplementedError: If ``old_index`` is neither a Range
                                        nor an Indices subset.
        """
        new_ranges = [(buffer_index, buffer_index, 1)]
        if isinstance(old_index, dace.subsets.Range):
            new_ranges = new_ranges + old_index.ranges
        elif isinstance(old_index, dace.subsets.Indices):
            for index in old_index.indices:
                new_ranges.append((index, index, 1))
        else:
            # Previously this case fell through and reused a stale (or
            # unbound) list of ranges from an earlier edge.
            raise NotImplementedError(
                'Unsupported subset type for double buffering')
        return dace.subsets.Range(new_ranges)

    def apply(self, sdfg):
        """ Rewrites the matched loop so that copies between access nodes in
            the body whose memlet subset depends on the loop variable go
            through a buffer with an extra leading dimension of size 2,
            indexed by the loop variable modulo 2.

            :param sdfg: The SDFG containing the matched states.
            :raise NotImplementedError: If the body -> guard edge does not
                                        assign exactly one variable, or if an
                                        unsupported data descriptor or subset
                                        type is encountered.
        """
        guard = sdfg.nodes()[self.subgraph[DoubleBuffering._guard]]
        body = sdfg.nodes()[self.subgraph[DoubleBuffering._body]]

        # The loop variable is whatever the body -> guard back-edge assigns
        loop_vars = []
        for _, dst, e in sdfg.out_edges(body):
            if dst is guard:
                loop_vars.extend(e.assignments.keys())

        if len(loop_vars) != 1:
            raise NotImplementedError(
                'Double buffering requires exactly one loop variable')

        loop_var = loop_vars[0]
        sym_var = dace.symbolic.pystr_to_symbolic(loop_var)

        # Find source (data) nodes
        input_nodes = nxutil.find_source_nodes(body)

        copied_nodes = set()
        db_nodes = {}  # Original destination node -> buffered node in guard
        for node in input_nodes:
            for _, _, dst, _, mem in body.out_edges(node):
                if not (isinstance(dst, dace.graph.nodes.AccessNode)
                        and loop_var in mem.subset.free_symbols):
                    continue

                # Create new data and nodes in guard
                if node not in copied_nodes:
                    guard.add_node(node)
                    copied_nodes.add(node)
                if dst in copied_nodes:
                    # Reuse the buffered node created for this destination
                    # instead of whichever node was created last.
                    new_node = db_nodes[dst]
                else:
                    old_data = dst.desc(sdfg)
                    # NOTE(review): confirm the positional order
                    # (name, dtype, shape) and the use of ``old_data.data``
                    # as the array name against SDFG.add_array's signature.
                    if isinstance(old_data, dace.data.Array):
                        new_shape = tuple([2] + list(old_data.shape))
                        sdfg.add_array(old_data.data,
                                       old_data.dtype,
                                       new_shape,
                                       transient=True)
                    elif isinstance(old_data, dace.data.Scalar):
                        # A one-element tuple: ``(2)`` is just the int 2
                        sdfg.add_array(old_data.data,
                                       old_data.dtype, (2, ),
                                       transient=True)
                    else:
                        raise NotImplementedError(
                            'Only arrays and scalars can be double-buffered')
                    new_node = dace.graph.nodes.AccessNode(old_data.data)
                    guard.add_node(new_node)
                    copied_nodes.add(dst)
                    db_nodes[dst] = new_node

                # Create memlet in guard: the initial copy fills buffer 0
                new_mem = copy.deepcopy(mem)
                old_index = new_mem.other_subset
                if isinstance(old_index, dace.subsets.Range):
                    new_ranges = [(0, 0, 1)] + old_index.ranges
                    new_mem.other_subset = dace.subsets.Range(new_ranges)
                elif isinstance(old_index, dace.subsets.Indices):
                    new_indices = [0] + old_index.indices
                    new_mem.other_subset = dace.subsets.Indices(new_indices)
                guard.add_edge(node, None, new_node, None, new_mem)

                # Create nodes, memlets in body: writes go to ``first_node``
                # and reads come from ``second_node``
                first_node = copy.deepcopy(new_node)
                second_node = copy.deepcopy(new_node)
                body.add_nodes_from([first_node, second_node])
                dace.graph.nxutil.change_edge_dest(body, dst, first_node)
                dace.graph.nxutil.change_edge_src(body, dst, second_node)
                for src, _, dest, _, memm in body.edges():
                    if src is node and dest is first_node:
                        # Prefetch into the buffer half not currently in use
                        memm.other_subset = DoubleBuffering._with_buffer_dim(
                            (sym_var + 1) % 2, memm.other_subset)
                    elif memm.data == dst.data:
                        # Accesses to the old data use the current half
                        memm.subset = DoubleBuffering._with_buffer_dim(
                            sym_var % 2, memm.subset)
                        memm.data = first_node.data
                body.remove_node(dst)
Exemplo n.º 16
0
    def apply(self, sdfg):
        """ Redirects the matched state to FPGA copies of its data.

            For every array access node that is a source or sink of the
            state (and for outer nodes reached through write-conflict
            resolution memlets in nested SDFGs), a transient array named
            ``fpga_<name>`` with ``FPGA_Global`` storage is added. A
            ``pre_<label>`` state copies inputs into those arrays, a
            ``post_<label>`` state copies outputs back, and memlets in the
            state are renamed to the ``fpga_`` arrays. Finally
            ``fpga_update`` is called on the state.

            :param sdfg: The SDFG containing the matched state.
        """
        state = sdfg.nodes()[self.subgraph[FPGATransformState._state]]

        # Find source/sink (data) nodes
        input_nodes = sdutil.find_source_nodes(state)
        output_nodes = sdutil.find_sink_nodes(state)

        fpga_data = {}

        # Input nodes may also be nodes with WCR memlets
        # We have to recur across nested SDFGs to find them
        wcr_input_nodes = set()

        parent_sdfg = {state: sdfg}  # Map states to their parent SDFG
        for node, graph in state.all_nodes_recursive():
            if isinstance(graph, dace.SDFG):
                parent_sdfg[node] = graph
            if isinstance(node, dace.sdfg.nodes.AccessNode):
                for e in graph.all_edges(node):
                    if e.data.wcr is not None:
                        trace = dace.sdfg.trace_nested_access(
                            node, graph, parent_sdfg[graph])
                        for node_trace, state_trace, sdfg_trace in trace:
                            # Find the name of the accessed node in our scope
                            if state_trace == state and sdfg_trace == sdfg:
                                outer_node = node_trace
                                break
                        else:
                            # This does not trace back to the current state, so
                            # we don't care. The ``else`` belongs to the
                            # ``for``: without it, an unbound or stale
                            # ``outer_node`` would be registered below.
                            continue
                        input_nodes.append(outer_node)
                        wcr_input_nodes.add(outer_node)

        if input_nodes:
            # create pre_state
            pre_state = sd.SDFGState('pre_' + state.label, sdfg)

            for node in input_nodes:

                if not isinstance(node, dace.sdfg.nodes.AccessNode):
                    continue
                desc = node.desc(sdfg)
                if not isinstance(desc, dace.data.Array):
                    # TODO: handle streams
                    continue

                if node.data in fpga_data:
                    fpga_array = fpga_data[node.data]
                elif node not in wcr_input_nodes:
                    fpga_array = sdfg.add_array(
                        'fpga_' + node.data,
                        desc.shape,
                        desc.dtype,
                        materialize_func=desc.materialize_func,
                        transient=True,
                        storage=dtypes.StorageType.FPGA_Global,
                        allow_conflicts=desc.allow_conflicts,
                        strides=desc.strides,
                        offset=desc.offset)
                    fpga_data[node.data] = fpga_array

                # Copy the whole host array into its FPGA counterpart
                pre_node = pre_state.add_read(node.data)
                pre_fpga_node = pre_state.add_write('fpga_' + node.data)
                full_range = subsets.Range([(0, s - 1, 1) for s in desc.shape])
                mem = memlet.Memlet(node.data, full_range.num_elements(),
                                    full_range, 1)
                pre_state.add_edge(pre_node, None, pre_fpga_node, None, mem)

                # WCR nodes are sinks of the state; they are replaced in the
                # output pass below rather than here
                if node not in wcr_input_nodes:
                    fpga_node = state.add_read('fpga_' + node.data)
                    sdutil.change_edge_src(state, node, fpga_node)
                    state.remove_node(node)

            sdfg.add_node(pre_state)
            sdutil.change_edge_dest(sdfg, state, pre_state)
            sdfg.add_edge(pre_state, state, sd.InterstateEdge())

        if output_nodes:

            post_state = sd.SDFGState('post_' + state.label, sdfg)

            for node in output_nodes:

                if not isinstance(node, dace.sdfg.nodes.AccessNode):
                    continue
                desc = node.desc(sdfg)
                if not isinstance(desc, dace.data.Array):
                    # TODO: handle streams
                    continue

                if node.data in fpga_data:
                    fpga_array = fpga_data[node.data]
                else:
                    fpga_array = sdfg.add_array(
                        'fpga_' + node.data,
                        desc.shape,
                        desc.dtype,
                        materialize_func=desc.materialize_func,
                        transient=True,
                        storage=dtypes.StorageType.FPGA_Global,
                        allow_conflicts=desc.allow_conflicts,
                        strides=desc.strides,
                        offset=desc.offset)
                    fpga_data[node.data] = fpga_array

                # Copy the whole FPGA array back to the host array
                post_node = post_state.add_write(node.data)
                post_fpga_node = post_state.add_read('fpga_' + node.data)
                full_range = subsets.Range([(0, s - 1, 1) for s in desc.shape])
                mem = memlet.Memlet('fpga_' + node.data,
                                    full_range.num_elements(), full_range, 1)
                post_state.add_edge(post_fpga_node, None, post_node, None, mem)

                fpga_node = state.add_write('fpga_' + node.data)
                sdutil.change_edge_dest(state, node, fpga_node)
                state.remove_node(node)

            sdfg.add_node(post_state)
            sdutil.change_edge_src(sdfg, state, post_state)
            sdfg.add_edge(state, post_state, sd.InterstateEdge())

        veclen_ = 1

        # propagate vector info from a nested sdfg
        for src, src_conn, dst, dst_conn, mem in state.edges():
            # need to go inside the nested SDFG and grab the vector length
            if isinstance(dst, dace.sdfg.nodes.NestedSDFG):
                # this edge is going to the nested SDFG
                for inner_state in dst.sdfg.states():
                    for n in inner_state.nodes():
                        if isinstance(n, dace.sdfg.nodes.AccessNode
                                      ) and n.data == dst_conn:
                            # assuming all memlets have the same vector length
                            veclen_ = inner_state.all_edges(n)[0].data.veclen
            if isinstance(src, dace.sdfg.nodes.NestedSDFG):
                # this edge is coming from the nested SDFG
                for inner_state in src.sdfg.states():
                    for n in inner_state.nodes():
                        if isinstance(n, dace.sdfg.nodes.AccessNode
                                      ) and n.data == src_conn:
                            # assuming all memlets have the same vector length
                            veclen_ = inner_state.all_edges(n)[0].data.veclen

            # Note: veclen_ carries over from earlier edges in this loop
            if mem.data is not None and mem.data in fpga_data:
                mem.data = 'fpga_' + mem.data
                mem.veclen = veclen_

        fpga_update(sdfg, state, 0)
Exemplo n.º 17
0
class StateAssignElimination(transformation.Transformation):
    """
    State assign elimination takes the assignments on the single edge that
    enters a state, substitutes the assigned values into the state's
    contents, and removes those assignments from the edge.
    """

    _end_state = sdfg.SDFGState()

    @staticmethod
    def expressions():
        """ Returns the pattern: a path graph made of the single state. """
        pattern_state = StateAssignElimination._end_state
        return [sdutil.node_path_graph(pattern_state)]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """
        Returns True if the matched state has exactly one incoming edge
        carrying at least one eligible assignment and, unless the state has
        no outgoing edges, none of the assigned symbols appear in any other
        edge or any other state.
        """
        target = graph.nodes()[candidate[StateAssignElimination._end_state]]

        outgoing = graph.out_edges(target)
        incoming = graph.in_edges(target)

        # Exactly one way into the state is required
        if len(incoming) != 1:
            return False
        entry_edge = incoming[0]

        candidates = _assignments_to_consider(sdfg, entry_edge)
        if not candidates:
            # Nothing to eliminate
            return False

        if not outgoing:
            # A final state: no later edge can observe the symbols
            return True

        # The symbols must not be set or used by any other edge ...
        symbols = set(candidates.keys())
        for other_edge in sdfg.edges():
            if other_edge is not entry_edge and (
                    other_edge.data.free_symbols & symbols):
                return False

        # ... nor by any other state
        for other_state in sdfg.nodes():
            if other_state is not target and (
                    other_state.free_symbols & symbols):
                return False

        return True

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns the label of the matched state. """
        return graph.nodes()[candidate[
            StateAssignElimination._end_state]].label

    def apply(self, sdfg):
        """
        Substitutes each eligible assignment into the matched state, deletes
        it from the incoming edge and, when no edge uses the symbol any
        longer, removes the symbol and replaces remaining occurrences in the
        SDFG.
        """
        target = sdfg.nodes()[self.subgraph[
            StateAssignElimination._end_state]]
        entry_edge = sdfg.in_edges(target)[0]

        # An assignment that reads another assigned value is undefined
        # behavior (e.g., {m: n, n: m}), so every assignment can be
        # substituted independently of the others.
        candidates = _assignments_to_consider(sdfg, entry_edge)
        eliminated = set()
        for symbol_name, value in candidates.items():
            target.replace(symbol_name, value)
            eliminated.add(symbol_name)

        replacements = {}
        for symbol_name in eliminated:
            # Drop the assignment from the edge
            del entry_edge.data.assignments[symbol_name]

            still_used = any(symbol_name in other.data.free_symbols
                             for other in sdfg.edges())
            if still_used:
                continue

            # No edge mentions the symbol any more: remove it and schedule
            # a replacement if it is still free in the SDFG
            if symbol_name in sdfg.symbols:
                sdfg.remove_symbol(symbol_name)
            if symbol_name in sdfg.free_symbols:
                replacements[symbol_name] = candidates[symbol_name]

        def _replace_in_sdfg(mapping):
            for old, new in mapping.items():
                sdfg.replace(str(old), str(new))

        if replacements:
            symbolic.safe_replace(replacements, _replace_in_sdfg)
Exemplo n.º 18
0
class DetectLoop(transformation.Transformation):
    """ Detects a for-loop construct from an SDFG. """

    _loop_guard = sd.SDFGState()
    _loop_begin = sd.SDFGState()
    _exit_state = sd.SDFGState()

    @staticmethod
    def expressions():
        """ Returns two patterns: a loop whose first state jumps straight
            back to the guard, and a loop without that back-edge. """
        guard = DetectLoop._loop_guard
        begin = DetectLoop._loop_begin
        exit_state = DetectLoop._exit_state

        # Case 1: Loop with one state
        single_state = sd.SDFG('_')
        single_state.add_nodes_from([guard, begin, exit_state])
        for src, dst in ((guard, begin), (guard, exit_state), (begin,
                                                               guard)):
            single_state.add_edge(src, dst, sd.InterstateEdge())

        # Case 2: Loop with multiple states (no back-edge from state)
        multi_state = sd.SDFG('_')
        multi_state.add_nodes_from([guard, begin, exit_state])
        for src, dst in ((guard, begin), (guard, exit_state)):
            multi_state.add_edge(src, dst, sd.InterstateEdge())

        return [single_state, multi_state]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, permissive=False):
        """ Returns True if the matched guard/begin states form a loop: the
            guard has at least two incoming and exactly two outgoing edges
            with mutually negated conditions, every state reachable from the
            loop start without crossing the guard is dominated by the guard,
            and exactly one variable is assigned on all guard-incoming edges
            and on the detected back-edge. """
        guard = graph.node(candidate[DetectLoop._loop_guard])
        begin = graph.node(candidate[DetectLoop._loop_begin])

        # The guard needs at least two incoming edges (init and increment)
        guard_inedges = graph.in_edges(guard)
        if len(guard_inedges) < 2:
            return False
        # ... and exactly two outgoing edges (loop and exit-loop)
        guard_outedges = graph.out_edges(guard)
        if len(guard_outedges) != 2:
            return False

        # Intersect the variables assigned on every edge entering the guard
        itvar = None
        for in_edge in guard_inedges:
            assigned = in_edge.data.assignments.keys()
            if itvar is None:
                itvar = set(assigned)
            else:
                itvar &= assigned
        if itvar is None:
            return False

        # The two exit conditions must be negations of each other
        first_cond = guard_outedges[0].data.condition_sympy()
        negated_second = sp.Not(guard_outedges[1].data.condition_sympy())
        if first_cond != negated_second:
            return False

        dominators = nx.dominance.immediate_dominators(sdfg.nx,
                                                       sdfg.start_state)

        def _dominated_by_guard(start):
            # Walk up the dominator tree; reaching the root before meeting
            # the guard means the node is not inside the loop.
            current = start
            while current != dominators[current]:
                if current == guard:
                    return True
                current = dominators[current]
            return False

        loop_nodes = sdutil.dfs_conditional(
            sdfg, sources=[begin], condition=lambda _, child: child != guard)
        backedge = None
        for node in loop_nodes:
            to_guard = next(
                (e for e in graph.out_edges(node) if e.dst == guard), None)
            if to_guard is not None:
                backedge = to_guard

            if not _dominated_by_guard(node):
                return False

        if backedge is None:
            return False

        # The back-edge must assign the iteration variable as well; exactly
        # one consistent iteration variable has to remain
        itvar &= backedge.data.assignments.keys()
        return len(itvar) == 1

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns the guard, begin and exit labels joined by arrows,
            followed by the first variable assigned on the guard's first
            incoming edge. """
        matched = [
            graph.node(candidate[pattern_state])
            for pattern_state in (DetectLoop._loop_guard,
                                  DetectLoop._loop_begin,
                                  DetectLoop._exit_state)
        ]
        first_inedge = graph.in_edges(matched[0])[0]
        ind = list(first_inedge.data.assignments.keys())[0]

        labels = ' -> '.join(state.label for state in matched)
        return labels + ' (for loop over "%s")' % ind

    def apply(self, sdfg):
        """ Detection only: applying the transformation does nothing. """
        pass
Exemplo n.º 19
0
    def apply(self, sdfg):
        """ Redirects the matched state to FPGA copies of its data.

            Array access nodes that are sources or sinks of the state get a
            transient ``fpga_<name>`` array with ``FPGA_Global`` storage.
            A ``pre_<label>`` state copies inputs in, a ``post_<label>``
            state copies outputs out, memlets of the state are renamed to
            the ``fpga_`` arrays, and ``fpga_update`` is called last.

            :param sdfg: The SDFG containing the matched state.
        """
        AccessNode = dace.graph.nodes.AccessNode
        NestedSDFG = dace.graph.nodes.NestedSDFG

        state = sdfg.nodes()[self.subgraph[FPGATransformState._state]]

        # Source and sink nodes of the state are the transfer candidates
        input_nodes = nxutil.find_source_nodes(state)
        output_nodes = nxutil.find_sink_nodes(state)

        fpga_data = {}

        # Nodes written with WCR memlets must be copied in as well; look for
        # them recursively, including inside nested SDFGs
        wcr_input_nodes = set()

        for node, graph in state.all_nodes_recursive():
            if not isinstance(node, AccessNode):
                continue
            for e in graph.all_edges(node):
                if e.data.wcr is None:
                    continue
                # Nesting is State -> SDFG -> State -> SDFG, so the state
                # that contains the nested SDFG node is two levels up
                parent_state = graph.parent.parent
                if parent_state is None:
                    continue
                for outer_edge in parent_state.edges():
                    if (outer_edge.src_conn == e.dst.data
                            or (isinstance(outer_edge.dst, AccessNode)
                                and e.dst.data == outer_edge.dst.data)):
                        # This must be copied to device
                        input_nodes.append(outer_edge.dst)
                        wcr_input_nodes.add(outer_edge.dst)

        if input_nodes:
            pre_state = sd.SDFGState('pre_' + state.label, sdfg)

            for node in input_nodes:
                # Only array nodes are transferred (TODO: handle streams)
                if not isinstance(node, AccessNode):
                    continue
                array = node.desc(sdfg)
                if not isinstance(array, dace.data.Array):
                    continue

                is_wcr = node in wcr_input_nodes
                if node.data not in fpga_data and not is_wcr:
                    fpga_data[node.data] = sdfg.add_array(
                        'fpga_' + node.data,
                        array.shape,
                        array.dtype,
                        materialize_func=array.materialize_func,
                        transient=True,
                        storage=dtypes.StorageType.FPGA_Global,
                        allow_conflicts=array.allow_conflicts,
                        strides=array.strides,
                        offset=array.offset)

                # Host -> FPGA copy of the entire array
                host_read = pre_state.add_read(node.data)
                fpga_write = pre_state.add_write('fpga_' + node.data)
                whole = subsets.Range([(0, dim - 1, 1)
                                       for dim in array.shape])
                copy_in = memlet.Memlet(node.data, whole.num_elements(),
                                        whole, 1)
                pre_state.add_edge(host_read, None, fpga_write, None, copy_in)

                if not is_wcr:
                    replacement = state.add_read('fpga_' + node.data)
                    nxutil.change_edge_src(state, node, replacement)
                    state.remove_node(node)

            sdfg.add_node(pre_state)
            nxutil.change_edge_dest(sdfg, state, pre_state)
            sdfg.add_edge(pre_state, state, edges.InterstateEdge())

        if output_nodes:
            post_state = sd.SDFGState('post_' + state.label, sdfg)

            for node in output_nodes:
                # Only array nodes are transferred (TODO: handle streams)
                if not isinstance(node, AccessNode):
                    continue
                array = node.desc(sdfg)
                if not isinstance(array, dace.data.Array):
                    continue

                if node.data not in fpga_data:
                    fpga_data[node.data] = sdfg.add_array(
                        'fpga_' + node.data,
                        array.shape,
                        array.dtype,
                        materialize_func=array.materialize_func,
                        transient=True,
                        storage=dtypes.StorageType.FPGA_Global,
                        allow_conflicts=array.allow_conflicts,
                        strides=array.strides,
                        offset=array.offset)

                # FPGA -> host copy of the entire array
                host_write = post_state.add_write(node.data)
                fpga_read = post_state.add_read('fpga_' + node.data)
                whole = subsets.Range([(0, dim - 1, 1)
                                       for dim in array.shape])
                copy_out = memlet.Memlet('fpga_' + node.data,
                                         whole.num_elements(), whole, 1)
                post_state.add_edge(fpga_read, None, host_write, None,
                                    copy_out)

                replacement = state.add_write('fpga_' + node.data)
                nxutil.change_edge_dest(state, node, replacement)
                state.remove_node(node)

            sdfg.add_node(post_state)
            nxutil.change_edge_src(sdfg, state, post_state)
            sdfg.add_edge(state, post_state, edges.InterstateEdge())

        def _inner_veclen(nested, connector, current):
            # Vector length of the first edge of the last inner access node
            # named like the connector; ``current`` if there is none.
            # Assumes all memlets have the same vector length.
            for inner_state in nested.sdfg.states():
                for n in inner_state.nodes():
                    if isinstance(n, AccessNode) and n.data == connector:
                        current = inner_state.all_edges(n)[0].data.veclen
            return current

        veclen_ = 1

        # Propagate vector info from nested SDFGs; the value carries over
        # from one edge to the next
        for src, src_conn, dst, dst_conn, mem in state.edges():
            if isinstance(dst, NestedSDFG):
                # Edge going into the nested SDFG
                veclen_ = _inner_veclen(dst, dst_conn, veclen_)
            if isinstance(src, NestedSDFG):
                # Edge coming out of the nested SDFG
                veclen_ = _inner_veclen(src, src_conn, veclen_)

            if mem.data is not None and mem.data in fpga_data:
                mem.data = 'fpga_' + mem.data
                mem.veclen = veclen_

        fpga_update(sdfg, state, 0)