    def apply(self, _, sdfg):
        # Redirects the externally visible inputs and outputs of ``self.state``
        # to ``fpga_``-prefixed arrays. Inputs are copied into those arrays in a
        # new ``pre_<label>`` state and outputs are copied back in a new
        # ``post_<label>`` state; memlets of ``state`` that refer to a
        # redirected array are renamed accordingly.
        #
        # :param _: Unused positional argument.
        # :param sdfg: The SDFG that contains ``self.state``.
        #
        # NOTE(review): several statements below are incomplete in this view
        # (unclosed brackets, ``if`` headers without a body). The missing
        # lines need to be restored before this method can be parsed.
        state = self.state

        # Find source/sink (data) nodes that are relevant outside this FPGA
        # kernel
        # "Relevant" here means: the array is not transient, or it is one of
        # the SDFG's shared transients.
        shared_transients = set(sdfg.shared_transients())
        input_nodes = [
            n for n in sdutil.find_source_nodes(state)
            if isinstance(n, nodes.AccessNode) and
            (not sdfg.arrays[n.data].transient or n.data in shared_transients)
        # NOTE(review): the closing bracket of the comprehension above is not
        # visible in this view.
        output_nodes = [
            n for n in sdutil.find_sink_nodes(state)
            if isinstance(n, nodes.AccessNode) and
            (not sdfg.arrays[n.data].transient or n.data in shared_transients)
        # NOTE(review): the closing bracket of the comprehension above is not
        # visible in this view.

        # Maps an original array name to the value returned by
        # ``sdfg.add_array`` for its ``fpga_`` counterpart.
        fpga_data = {}

        # Input nodes may also be nodes with WCR memlets
        # We have to recur across nested SDFGs to find them
        # NOTE(review): nothing visible in this method adds elements to
        # ``wcr_input_nodes``, and ``stack`` is never read.
        wcr_input_nodes = set()
        stack = []

        parent_sdfg = {state: sdfg}  # Map states to their parent SDFG
        for node, graph in state.all_nodes_recursive():
            if isinstance(graph, dace.SDFG):
                parent_sdfg[node] = graph
            if isinstance(node, dace.sdfg.nodes.AccessNode):
                for e in graph.in_edges(node):
                    # Only incoming edges that carry a write-conflict
                    # resolution (WCR) are traced outwards.
                    if e.data.wcr is not None:
                        trace = dace.sdfg.trace_nested_access(
                            node, graph, parent_sdfg[graph])
                        for node_trace, memlet_trace, state_trace, sdfg_trace in trace:
                            # Find the name of the accessed node in our scope
                            if state_trace == state and sdfg_trace == sdfg:
                                _, outer_node = node_trace
                                # NOTE(review): the body of the following
                                # ``if`` is not visible in this view.
                                if outer_node is not None:
                            # This does not trace back to the current state, so
                            # we don't care
        if input_nodes:
            # create pre_state
            pre_state = sd.SDFGState('pre_' + state.label, sdfg)

            for node in input_nodes:

                # NOTE(review): the bodies of the next two ``if`` statements
                # are not visible in this view.
                if not isinstance(node, dace.sdfg.nodes.AccessNode):
                desc = node.desc(sdfg)
                if not isinstance(desc, dace.data.Array):
                    # TODO: handle streams

                # Reuse an ``fpga_`` array that was already registered for
                # this data name; otherwise create one (unless the node is a
                # WCR input).
                if node.data in fpga_data:
                    fpga_array = fpga_data[node.data]
                elif node not in wcr_input_nodes:
                    # NOTE(review): the remaining arguments and the closing
                    # parenthesis of this call are not visible in this view.
                    fpga_array = sdfg.add_array(
                        'fpga_' + node.data,
                    fpga_array[1].location = copy.copy(desc.location)
                    fpga_data[node.data] = fpga_array

                # Copy edge in the pre-state: original array -> ``fpga_`` array.
                pre_node = pre_state.add_read(node.data)
                pre_fpga_node = pre_state.add_write('fpga_' + node.data)
                # NOTE(review): the remaining arguments of this ``Memlet``
                # call are not visible in this view.
                mem = memlet.Memlet(data=node.data,
                pre_state.add_edge(pre_node, None, pre_fpga_node, None, mem)

                # Inside the kernel state, read from the ``fpga_`` array
                # instead of the original node.
                if node not in wcr_input_nodes:
                    fpga_node = state.add_read('fpga_' + node.data)
                    sdutil.change_edge_src(state, node, fpga_node)

            # Splice the pre-state in front of ``state`` in the state machine.
            sdutil.change_edge_dest(sdfg, state, pre_state)
            sdfg.add_edge(pre_state, state, sd.InterstateEdge())

        if output_nodes:

            post_state = sd.SDFGState('post_' + state.label, sdfg)

            for node in output_nodes:

                # NOTE(review): the bodies of the next two ``if`` statements
                # are not visible in this view.
                if not isinstance(node, dace.sdfg.nodes.AccessNode):
                desc = node.desc(sdfg)
                if not isinstance(desc, dace.data.Array):
                    # TODO: handle streams

                # NOTE(review): an ``else:`` header appears to be missing
                # between the two assignments below, and the ``add_array``
                # call is not closed in this view.
                if node.data in fpga_data:
                    fpga_array = fpga_data[node.data]
                    fpga_array = sdfg.add_array(
                        'fpga_' + node.data,
                    fpga_array[1].location = copy.copy(desc.location)
                    fpga_data[node.data] = fpga_array
                # fpga_node = type(node)(fpga_array)

                # Copy edge in the post-state: ``fpga_`` array -> original array.
                post_node = post_state.add_write(node.data)
                post_fpga_node = post_state.add_read('fpga_' + node.data)
                # NOTE(review): the remaining arguments of this ``Memlet``
                # call are not visible in this view.
                mem = memlet.Memlet(f"fpga_{node.data}", None,
                post_state.add_edge(post_fpga_node, None, post_node, None, mem)

                # Inside the kernel state, write to the ``fpga_`` array
                # instead of the original node.
                fpga_node = state.add_write('fpga_' + node.data)
                sdutil.change_edge_dest(state, node, fpga_node)

            # Splice the post-state after ``state`` in the state machine.
            sdutil.change_edge_src(sdfg, state, post_state)
            sdfg.add_edge(state, post_state, sd.InterstateEdge())

        # propagate memlet info from a nested sdfg
        # Every memlet in ``state`` that names a redirected array is renamed
        # to the ``fpga_`` counterpart.
        for src, src_conn, dst, dst_conn, mem in state.edges():
            if mem.data is not None and mem.data in fpga_data:
                mem.data = 'fpga_' + mem.data
        # ``fpga_update`` is defined outside this view; presumably it applies
        # FPGA schedules/storage to the state -- TODO confirm.
        fpga_update(sdfg, state, 0)
    def apply(self, sdfg):
        # Redirects all source and sink nodes of the matched state to
        # ``fpga_``-prefixed arrays. Inputs are copied into those arrays in a
        # new ``pre_<label>`` state and outputs are copied back in a new
        # ``post_<label>`` state; memlets that refer to a redirected array are
        # renamed and receive a vector length taken from nested SDFGs.
        #
        # :param sdfg: The SDFG whose node list contains the matched state.
        #
        # NOTE(review): several statements below are incomplete in this view
        # (unclosed calls, ``if`` headers without a body). The missing lines
        # need to be restored before this method can be parsed.
        state = sdfg.nodes()[self.subgraph[FPGATransformState._state]]

        # Find source/sink (data) nodes
        input_nodes = sdutil.find_source_nodes(state)
        output_nodes = sdutil.find_sink_nodes(state)

        # Maps an original array name to the value returned by
        # ``sdfg.add_array`` for its ``fpga_`` counterpart.
        fpga_data = {}

        # Input nodes may also be nodes with WCR memlets
        # We have to recur across nested SDFGs to find them
        # NOTE(review): nothing visible in this method adds elements to
        # ``wcr_input_nodes``, and ``stack`` is never read.
        wcr_input_nodes = set()
        stack = []

        parent_sdfg = {state: sdfg}  # Map states to their parent SDFG
        for node, graph in state.all_nodes_recursive():
            if isinstance(graph, dace.SDFG):
                parent_sdfg[node] = graph
            if isinstance(node, dace.sdfg.nodes.AccessNode):
                for e in graph.all_edges(node):
                    # Only edges that carry a write-conflict resolution (WCR)
                    # are traced outwards.
                    if e.data.wcr is not None:
                        trace = dace.sdfg.trace_nested_access(
                            node, graph, parent_sdfg[graph])
                        for node_trace, state_trace, sdfg_trace in trace:
                            # Find the name of the accessed node in our scope
                            if state_trace == state and sdfg_trace == sdfg:
                                # NOTE(review): ``outer_node`` is not used by
                                # any statement visible in this view.
                                outer_node = node_trace
                                # This does not trace back to the current state, so
                                # we don't care

        if input_nodes:
            # create pre_state
            pre_state = sd.SDFGState('pre_' + state.label, sdfg)

            for node in input_nodes:

                # NOTE(review): the bodies of the next two ``if`` statements
                # are not visible in this view.
                if not isinstance(node, dace.sdfg.nodes.AccessNode):
                desc = node.desc(sdfg)
                if not isinstance(desc, dace.data.Array):
                    # TODO: handle streams

                # Reuse an ``fpga_`` array that was already registered for
                # this data name; otherwise create one (unless the node is a
                # WCR input).
                if node.data in fpga_data:
                    fpga_array = fpga_data[node.data]
                elif node not in wcr_input_nodes:
                    # NOTE(review): the remaining arguments and the closing
                    # parenthesis of this call are not visible in this view.
                    fpga_array = sdfg.add_array(
                        'fpga_' + node.data,
                    fpga_data[node.data] = fpga_array

                # Copy edge in the pre-state: original array -> ``fpga_``
                # array, covering index 0..s-1 in every dimension of the
                # descriptor's shape.
                pre_node = pre_state.add_read(node.data)
                pre_fpga_node = pre_state.add_write('fpga_' + node.data)
                full_range = subsets.Range([(0, s - 1, 1) for s in desc.shape])
                mem = memlet.Memlet(node.data, full_range.num_elements(),
                                    full_range, 1)
                pre_state.add_edge(pre_node, None, pre_fpga_node, None, mem)

                # Inside the kernel state, read from the ``fpga_`` array
                # instead of the original node.
                if node not in wcr_input_nodes:
                    fpga_node = state.add_read('fpga_' + node.data)
                    sdutil.change_edge_src(state, node, fpga_node)

            # Splice the pre-state in front of ``state`` in the state machine.
            sdutil.change_edge_dest(sdfg, state, pre_state)
            sdfg.add_edge(pre_state, state, sd.InterstateEdge())

        if output_nodes:

            post_state = sd.SDFGState('post_' + state.label, sdfg)

            for node in output_nodes:

                # NOTE(review): the bodies of the next two ``if`` statements
                # are not visible in this view.
                if not isinstance(node, dace.sdfg.nodes.AccessNode):
                desc = node.desc(sdfg)
                if not isinstance(desc, dace.data.Array):
                    # TODO: handle streams

                # NOTE(review): an ``else:`` header appears to be missing
                # between the two assignments below, and the ``add_array``
                # call is not closed in this view.
                if node.data in fpga_data:
                    fpga_array = fpga_data[node.data]
                    fpga_array = sdfg.add_array(
                        'fpga_' + node.data,
                    fpga_data[node.data] = fpga_array
                # fpga_node = type(node)(fpga_array)

                # Copy edge in the post-state: ``fpga_`` array -> original
                # array, covering the full shape of the descriptor.
                post_node = post_state.add_write(node.data)
                post_fpga_node = post_state.add_read('fpga_' + node.data)
                full_range = subsets.Range([(0, s - 1, 1) for s in desc.shape])
                mem = memlet.Memlet('fpga_' + node.data,
                                    full_range.num_elements(), full_range, 1)
                post_state.add_edge(post_fpga_node, None, post_node, None, mem)

                # Inside the kernel state, write to the ``fpga_`` array
                # instead of the original node.
                fpga_node = state.add_write('fpga_' + node.data)
                sdutil.change_edge_dest(state, node, fpga_node)

            # Splice the post-state after ``state`` in the state machine.
            sdutil.change_edge_src(sdfg, state, post_state)
            sdfg.add_edge(state, post_state, sd.InterstateEdge())

        # NOTE(review): ``veclen_`` is initialized once and not reset per
        # edge, so a value found for one edge carries over to later edges.
        veclen_ = 1

        # propagate vector info from a nested sdfg
        for src, src_conn, dst, dst_conn, mem in state.edges():
            # need to go inside the nested SDFG and grab the vector length
            if isinstance(dst, dace.sdfg.nodes.NestedSDFG):
                # this edge is going to the nested SDFG
                # Look for inner access nodes named like the connector.
                for inner_state in dst.sdfg.states():
                    for n in inner_state.nodes():
                        if isinstance(n, dace.sdfg.nodes.AccessNode
                                      ) and n.data == dst_conn:
                            # assuming all memlets have the same vector length
                            veclen_ = inner_state.all_edges(n)[0].data.veclen
            if isinstance(src, dace.sdfg.nodes.NestedSDFG):
                # this edge is coming from the nested SDFG
                for inner_state in src.sdfg.states():
                    for n in inner_state.nodes():
                        if isinstance(n, dace.sdfg.nodes.AccessNode
                                      ) and n.data == src_conn:
                            # assuming all memlets have the same vector length
                            veclen_ = inner_state.all_edges(n)[0].data.veclen

            # Rename memlets of redirected arrays and stamp the vector length.
            if mem.data is not None and mem.data in fpga_data:
                mem.data = 'fpga_' + mem.data
                mem.veclen = veclen_

        # ``fpga_update`` is defined outside this view; presumably it applies
        # FPGA schedules/storage to the state -- TODO confirm.
        fpga_update(sdfg, state, 0)
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        # Decides whether the two candidate states may be fused. Returns
        # ``False`` as soon as one of the checks below fails and ``True`` if
        # none of them does.
        #
        # :param graph: State machine graph; queried for nodes and for the
        #               in/out edges of the candidate states.
        # :param candidate: Mapping from ``StateFusion._first_state`` and
        #                   ``StateFusion._second_state`` to indices into
        #                   ``graph.nodes()``.
        # :param expr_index: Not used by the code visible here.
        # :param sdfg: Not used by the code visible here.
        # :param strict: If true, additionally run the dataflow-based checks.
        # :return: ``True`` if fusion is allowed, ``False`` otherwise.
        #
        # NOTE(review): inside the ``strict`` branch several comprehensions
        # are unclosed or lack their element expression, and one ``if`` has no
        # body. The missing lines need to be restored before this can be
        # parsed.
        first_state = graph.nodes()[candidate[StateFusion._first_state]]
        second_state = graph.nodes()[candidate[StateFusion._second_state]]

        out_edges = graph.out_edges(first_state)
        in_edges = graph.in_edges(first_state)

        # First state must have only one output edge (with dst the second
        # state).
        if len(out_edges) != 1:
            return False
        # The interstate edge must not have a condition.
        if not out_edges[0].data.is_unconditional():
            return False
        # The interstate edge may have assignments, as long as there are input
        # edges to the first state that can absorb them.
        if out_edges[0].data.assignments:
            if not in_edges:
                return False
            # Fail if symbol is set before the state to fuse
            # TODO: Also fail if symbol is used in the dataflow of that state
            new_assignments = set(out_edges[0].data.assignments.keys())
            if any((new_assignments & set(e.data.assignments.keys()))
                   for e in in_edges):
                return False

        # There can be no state that has output edges pointing to both the
        # first and the second state. Such a case will produce a multi-graph.
        for src, _, _ in in_edges:
            for _, dst, _ in graph.out_edges(src):
                if dst == second_state:
                    return False

        if strict:
            # If second state has other input edges, there might be issues
            # Exceptions are when none of the states contain dataflow, unless
            # the first state is an initial state (in which case the new initial
            # state would be ambiguous).
            first_in_edges = graph.in_edges(first_state)
            second_in_edges = graph.in_edges(second_state)
            if ((not second_state.is_empty() or not first_state.is_empty()
                 or len(first_in_edges) == 0) and len(second_in_edges) != 1):
                return False

            # Get connected components.
            # NOTE(review): the element expressions and/or closing brackets of
            # the comprehensions in this section are not visible in this view.
            first_cc = [
                for cc_nodes in nx.weakly_connected_components(first_state._nx)
            second_cc = [
                cc_nodes for cc_nodes in nx.weakly_connected_components(

            # Find source/sink (data) nodes
            # "Output" here is every access node of the state that is not one
            # of its source access nodes.
            first_input = {
                for node in sdutil.find_source_nodes(first_state)
                if isinstance(node, nodes.AccessNode)
            first_output = {
                for node in first_state.nodes() if
                isinstance(node, nodes.AccessNode) and node not in first_input
            second_input = {
                for node in sdutil.find_source_nodes(second_state)
                if isinstance(node, nodes.AccessNode)
            second_output = {
                for node in second_state.nodes() if
                isinstance(node, nodes.AccessNode) and node not in second_input

            # Find source/sink (data) nodes by connected component
            # NOTE(review): only ``first_cc_output`` is read by the code
            # visible below; the other three lists are not used here.
            first_cc_input = [cc.intersection(first_input) for cc in first_cc]
            first_cc_output = [
                cc.intersection(first_output) for cc in first_cc
            second_cc_input = [
                cc.intersection(second_input) for cc in second_cc
            second_cc_output = [
                cc.intersection(second_output) for cc in second_cc

            # Apply transformation in case all paths to the second state's
            # nodes go through the same access node, which implies sequential
            # behavior in SDFG semantics.
            # ``check_strict`` starts at the number of connected components of
            # the first state and is decremented for every sink node whose
            # label matches an input of the second state.
            check_strict = len(first_cc)
            for cc_output in first_cc_output:
                out_nodes = [
                    n for n in first_state.sink_nodes() if n in cc_output
                # Branching exists, multiple paths may involve same access node
                # potentially causing data races
                # NOTE(review): the body of the following ``if`` is not
                # visible in this view.
                if len(out_nodes) > 1:

                # Otherwise, check if any of the second state's connected
                # components for matching input
                for node in out_nodes:
                    if (next(
                        (x for x in second_input if x.label == node.label),
                            None) is not None):
                        check_strict -= 1

            # A positive counter means the label-matching decrements above did
            # not cover every component, so the explicit dependency checks run.
            if check_strict > 0:
                # Check strict conditions
                # RW dependency
                # An input of the first state must not share a label with an
                # output of the second state.
                for node in first_input:
                    if (next(
                        (x for x in second_output if x.label == node.label),
                            None) is not None):
                        return False
                # WW dependency
                # An output of the first state must not share a label with an
                # output of the second state.
                for node in first_output:
                    if (next(
                        (x for x in second_output if x.label == node.label),
                            None) is not None):
                        return False

        return True
    def apply(self, sdfg):
        # Fuses the matched second state into the matched first state: the
        # second state's edges are added to the first state, access nodes that
        # refer to the same label are merged, and the second state's outgoing
        # interstate edges are redirected to originate from the first state.
        #
        # :param sdfg: The SDFG whose node list contains both matched states.
        #
        # NOTE(review): several statements below are incomplete in this view
        # (loop headers without a body, unclosed comprehensions), and no
        # ``return`` is visible after either special case. The missing lines
        # need to be restored before this method can be parsed.
        first_state = sdfg.nodes()[self.subgraph[StateFusion._first_state]]
        second_state = sdfg.nodes()[self.subgraph[StateFusion._second_state]]

        # Remove interstate edge(s)
        # NOTE(review): the body of the inner ``for`` loop and the actual edge
        # removal are not visible in this view.
        edges = sdfg.edges_between(first_state, second_state)
        for edge in edges:
            if edge.data.assignments:
                for src, dst, other_data in sdfg.in_edges(first_state):

        # Special case 1: first state is empty
        # Incoming interstate edges of the first state are redirected to the
        # second state.
        if first_state.is_empty():
            sdutil.change_edge_dest(sdfg, first_state, second_state)

        # Special case 2: second state is empty
        # All interstate edges of the second state are redirected to the
        # first state.
        if second_state.is_empty():
            sdutil.change_edge_src(sdfg, second_state, first_state)
            sdutil.change_edge_dest(sdfg, second_state, first_state)

        # Normal case: both states are not empty

        # Find source/sink (data) nodes
        # NOTE(review): the closing brackets of the comprehensions below are
        # not visible in this view.
        first_input = [
            node for node in sdutil.find_source_nodes(first_state)
            if isinstance(node, nodes.AccessNode)
        first_output = [
            node for node in sdutil.find_sink_nodes(first_state)
            if isinstance(node, nodes.AccessNode)
        second_input = [
            node for node in sdutil.find_source_nodes(second_state)
            if isinstance(node, nodes.AccessNode)

        # first input = first input - first output
        # The difference is taken by label, not by node identity.
        first_input = [
            node for node in first_input
            if next((x for x in first_output
                     if x.label == node.label), None) is None

        # Merge second state to first state
        # First keep a backup of the topological sorted order of the nodes
        # ``order`` lists the first state's access nodes in reverse
        # topological order, i.e. topologically last nodes come first.
        order = [
            x for x in reversed(list(nx.topological_sort(first_state._nx)))
            if isinstance(x, nodes.AccessNode)
        # NOTE(review): the body of the following ``for`` loop is not visible
        # in this view.
        for node in second_state.nodes():
        for src, src_conn, dst, dst_conn, data in second_state.edges():
            first_state.add_edge(src, src_conn, dst, dst_conn, data)

        # Merge common (data) nodes
        # A second-state input without incoming edges is replaced by the first
        # entry of ``order`` that has the same label; that node is then marked
        # as read-write.
        for node in second_input:
            if first_state.in_degree(node) == 0:
                n = next((x for x in order if x.label == node.label), None)
                if n:
                    sdutil.change_edge_src(first_state, node, n)
                    n.access = dtypes.AccessType.ReadWrite

        # Redirect edges and remove second state
        # NOTE(review): the removal of the second state itself is not visible
        # in this view.
        sdutil.change_edge_src(sdfg, second_state, first_state)
        # Count fused states only when the "debugprint" option is enabled.
        if Config.get_bool("debugprint"):
            StateFusion._states_fused += 1
    def apply(self, sdfg):
        # Fuses the matched second state into the matched first state: the
        # second state's edges are added to the first state, top-level access
        # nodes that refer to the same data are merged, the second state's
        # outgoing interstate edges are redirected to originate from the first
        # state, and the SDFG's start state is updated where needed.
        #
        # :param sdfg: The SDFG that contains both matched states.
        #
        # NOTE(review): several statements below are incomplete in this view
        # (unclosed calls and comprehensions, ``if``/``for`` headers without a
        # body, apparently missing ``else:`` headers), and no ``return`` is
        # visible after either special case. The missing lines need to be
        # restored before this method can be parsed.

        # ``self.subgraph`` may hold the states themselves; the ``isinstance``
        # check below covers that case.
        # NOTE(review): an ``else:`` header appears to be missing before the
        # two ``sdfg.node(`` assignments, and those calls are not closed in
        # this view.
        if isinstance(self.subgraph[StateFusion.first_state], SDFGState):
            first_state: SDFGState = self.subgraph[StateFusion.first_state]
            second_state: SDFGState = self.subgraph[StateFusion.second_state]
            first_state: SDFGState = sdfg.node(
            second_state: SDFGState = sdfg.node(

        # Remove interstate edge(s)
        # NOTE(review): the body of the inner ``for`` loop and the actual edge
        # removal are not visible in this view.
        edges = sdfg.edges_between(first_state, second_state)
        for edge in edges:
            if edge.data.assignments:
                for src, dst, other_data in sdfg.in_edges(first_state):

        # Special case 1: first state is empty
        # Incoming interstate edges of the first state are redirected to the
        # second state, which also takes over the start-state role if the
        # first state had it.
        if first_state.is_empty():
            sdutil.change_edge_dest(sdfg, first_state, second_state)
            if sdfg.start_state == first_state:
                sdfg.start_state = sdfg.node_id(second_state)

        # Special case 2: second state is empty
        # All interstate edges of the second state are redirected to the
        # first state, which also takes over the start-state role if the
        # second state had it.
        if second_state.is_empty():
            sdutil.change_edge_src(sdfg, second_state, first_state)
            sdutil.change_edge_dest(sdfg, second_state, first_state)
            if sdfg.start_state == second_state:
                sdfg.start_state = sdfg.node_id(first_state)

        # Normal case: both states are not empty

        # Find source/sink (data) nodes
        # NOTE(review): the closing brackets of the comprehensions below are
        # not visible in this view.
        first_input = [
            node for node in sdutil.find_source_nodes(first_state)
            if isinstance(node, nodes.AccessNode)
        first_output = [
            node for node in sdutil.find_sink_nodes(first_state)
            if isinstance(node, nodes.AccessNode)
        second_input = [
            node for node in sdutil.find_source_nodes(second_state)
            if isinstance(node, nodes.AccessNode)

        # Captured before the merge, while ``second_state`` is still intact.
        top2 = top_level_nodes(second_state)

        # first input = first input - first output
        # The difference is taken by data name, not by node identity.
        first_input = [
            node for node in first_input
            if next((x for x in first_output
                     if x.data == node.data), None) is None

        # Merge second state to first state
        # First keep a backup of the topological sorted order of the nodes
        # ``order`` lists the first state's access nodes whose scope-dict
        # entry is ``None``, in reverse topological order.
        sdict = first_state.scope_dict()
        order = [
            x for x in reversed(list(nx.topological_sort(first_state._nx)))
            if isinstance(x, nodes.AccessNode) and sdict[x] is None
        for node in second_state.nodes():
            if isinstance(node, nodes.NestedSDFG):
                # update parent information
                node.sdfg.parent = first_state
        for src, src_conn, dst, dst_conn, data in second_state.edges():
            first_state.add_edge(src, src_conn, dst, dst_conn, data)

        # Computed after the second state's edges were added to the first.
        top = top_level_nodes(first_state)

        # Merge common (data) nodes
        for node in second_input:

            # merge only top level nodes, skip everything else
            # NOTE(review): the body of the following ``if`` is not visible in
            # this view.
            if node not in top2:

            if first_state.in_degree(node) == 0:
                # Candidates are top-level nodes from ``order`` that refer to
                # the same data as ``node``.
                candidates = [
                    x for x in order if x.data == node.data and x in top
                # NOTE(review): the body of the zero-candidates branch, an
                # ``else:`` header for the multi-candidate case, and the
                # control flow around the fallback assignment are not visible
                # in this view.
                if len(candidates) == 0:
                elif len(candidates) == 1:
                    n = candidates[0]
                    # Choose first candidate that intersects memlets
                    for cand in candidates:
                        if StateFusion.memlets_intersect(
                                first_state, [cand], False, second_state,
                            [node], True):
                            n = cand
                        # No node intersects, use topologically-last node
                        n = candidates[0]

                # Reroute the merged node's outgoing edges to the chosen node
                # and mark that node as read-write.
                sdutil.change_edge_src(first_state, node, n)
                n.access = dtypes.AccessType.ReadWrite

        # Redirect edges and remove second state
        # NOTE(review): the removal of the second state itself is not visible
        # in this view.
        sdutil.change_edge_src(sdfg, second_state, first_state)
        if sdfg.start_state == second_state:
            sdfg.start_state = sdfg.node_id(first_state)
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        # Workaround for supporting old and new conventions
        if isinstance(candidate[StateFusion.first_state], SDFGState):
            first_state: SDFGState = candidate[StateFusion.first_state]
            second_state: SDFGState = candidate[StateFusion.second_state]
            first_state: SDFGState = graph.node(
            second_state: SDFGState = graph.node(

        out_edges = graph.out_edges(first_state)
        in_edges = graph.in_edges(first_state)

        # First state must have only one output edge (with dst the second
        # state).
        if len(out_edges) != 1:
            return False
        # If both states have more than one incoming edge, some control flow
        # may become ambiguous
        if len(in_edges) > 1 and graph.in_degree(second_state) > 1:
            return False
        # The interstate edge must not have a condition.
        if not out_edges[0].data.is_unconditional():
            return False
        # The interstate edge may have assignments, as long as there are input
        # edges to the first state that can absorb them.
        if out_edges[0].data.assignments:
            if not in_edges:
                return False
            # Fail if symbol is set before the state to fuse
            new_assignments = set(out_edges[0].data.assignments.keys())
            if any((new_assignments & set(e.data.assignments.keys()))
                   for e in in_edges):
                return False
            # Fail if symbol is used in the dataflow of that state
            if len(new_assignments & first_state.free_symbols) > 0:
                return False
            # Fail if assignments have free symbols that are updated in the
            # first state
            freesyms = out_edges[0].data.free_symbols
            if freesyms and any(n.data in freesyms for n in first_state.nodes()
                                if isinstance(n, nodes.AccessNode)
                                and first_state.in_degree(n) > 0):
                return False
            # Fail if symbols assigned on the first edge are free symbols on the
            # second edge
            symbols_used = set(out_edges[0].data.free_symbols)
            for e in in_edges:
                if e.data.assignments.keys() & symbols_used:
                    return False

        # There can be no state that have output edges pointing to both the
        # first and the second state. Such a case will produce a multi-graph.
        for src, _, _ in in_edges:
            for _, dst, _ in graph.out_edges(src):
                if dst == second_state:
                    return False

        if strict:

            # NOTE: This is quick fix for MPI Waitall (probably also needed for
            # Wait), until we have a better SDFG representation of the buffer
            # dependencies.
                from dace.libraries.mpi import Waitall
                next(node for node in first_state.nodes()
                     if isinstance(node, Waitall) or node.label == '_Waitall_')
                return False
            except StopIteration:
                from dace.libraries.mpi import Waitall
                next(node for node in second_state.nodes()
                     if isinstance(node, Waitall) or node.label == '_Waitall_')
                return False
            except StopIteration:

            # If second state has other input edges, there might be issues
            # Exceptions are when none of the states contain dataflow, unless
            # the first state is an initial state (in which case the new initial
            # state would be ambiguous).
            first_in_edges = graph.in_edges(first_state)
            second_in_edges = graph.in_edges(second_state)
            if ((not second_state.is_empty() or not first_state.is_empty()
                 or len(first_in_edges) == 0) and len(second_in_edges) != 1):
                return False

            # Get connected components.
            first_cc = [
                for cc_nodes in nx.weakly_connected_components(first_state._nx)
            second_cc = [
                for cc_nodes in nx.weakly_connected_components(second_state._nx)

            # Find source/sink (data) nodes
            first_input = {
                for node in sdutil.find_source_nodes(first_state)
                if isinstance(node, nodes.AccessNode)
            first_output = {
                for node in first_state.scope_children()[None] if
                isinstance(node, nodes.AccessNode) and node not in first_input
            second_input = {
                for node in sdutil.find_source_nodes(second_state)
                if isinstance(node, nodes.AccessNode)
            second_output = {
                for node in second_state.scope_children()[None] if
                isinstance(node, nodes.AccessNode) and node not in second_input

            # Find source/sink (data) nodes by connected component
            first_cc_input = [cc.intersection(first_input) for cc in first_cc]
            first_cc_output = [cc.intersection(first_output) for cc in first_cc]
            second_cc_input = [
                cc.intersection(second_input) for cc in second_cc
            second_cc_output = [
                cc.intersection(second_output) for cc in second_cc

            # Apply transformation in case all paths to the second state's
            # nodes go through the same access node, which implies sequential
            # behavior in SDFG semantics.
            first_output_names = {node.data for node in first_output}
            second_input_names = {node.data for node in second_input}

            # If any second input appears more than once, fail
            if len(second_input) > len(second_input_names):
                return False

            # If any first output that is an input to the second state
            # appears in more than one CC, fail
            matches = first_output_names & second_input_names
            for match in matches:
                cc_appearances = 0
                for cc in first_cc_output:
                    if len([n for n in cc if n.data == match]) > 0:
                        cc_appearances += 1
                if cc_appearances > 1:
                    return False

            # Recreate fused connected component correspondences, and then
            # check for hazards
            resulting_ccs: List[CCDesc] = StateFusion.find_fused_components(
                first_cc_input, first_cc_output, second_cc_input,

            # Check for data races
            for fused_cc in resulting_ccs:
                # Write-Write hazard - data is output of both first and second
                # states, without a read in between
                write_write_candidates = (
                    (fused_cc.first_outputs & fused_cc.second_outputs) -
                # Find the leaf (topological) instances of the matches
                order = [
                    x for x in reversed(
                    if isinstance(x, nodes.AccessNode)
                    and x.data in fused_cc.first_outputs
                # Those nodes will be the connection points upon fusion
                match_nodes = {
                    next(n for n in order if n.data == match)
                    for match in (fused_cc.first_outputs
                                    & fused_cc.second_inputs)

                # If we have potential candidates, check if there is a
                # path from the first write to the second write (in that
                # case, there is no hazard):
                for cand in write_write_candidates:
                    nodes_first = [n for n in first_output if n.data == cand]
                    nodes_second = [n for n in second_output if n.data == cand]

                    # If there is a path for the candidate that goes through
                    # the match nodes in both states, there is no conflict
                    fail = False
                    path_found = False
                    for match in match_nodes:
                        for node in nodes_first:
                            path_to = nx.has_path(first_state._nx, node, match)
                            if not path_to:
                            path_found = True
                            node2 = next(n for n in second_input
                                         if n.data == match.data)
                            if not all(
                                    nx.has_path(second_state._nx, node2, n)
                                    for n in nodes_second):
                                fail = True
                        if fail or path_found:

                    # Check for intersection (if None, fusion is ok)
                    if fail or not path_found:
                        if StateFusion.memlets_intersect(
                                first_state, nodes_first, False, second_state,
                                nodes_second, False):
                            return False
                # End of write-write hazard check

                first_inout = fused_cc.first_inputs | fused_cc.first_outputs
                for other_cc in resulting_ccs:
                    # NOTE: Special handling for `other_cc is fused_cc`
                    if other_cc is fused_cc:
                        # Checking for potential Read-Write data races
                        for d in first_inout:
                            if d in other_cc.second_outputs:
                                nodes_second = [
                                    n for n in second_output if n.data == d
                                # Read-Write race
                                if d in fused_cc.first_inputs:
                                    nodes_first = [
                                        n for n in first_input if n.data == d
                                    nodes_first = []
                                for n2 in nodes_second:
                                    for e in second_state.in_edges(n2):
                                        path = second_state.memlet_path(e)
                                        src = path[0].src
                                        if src in second_input and src.data in fused_cc.first_outputs:
                                            for n1 in fused_cc.first_output_nodes:
                                                if n1.data == src.data:
                                                    for n0 in nodes_first:
                                                        if not nx.has_path(
                                                                n0, n1):
                                                            return False
                    # If an input/output of a connected component in the first
                    # state is an output of another connected component in the
                    # second state, we have a potential data race (Read-Write
                    # or Write-Write)
                    for d in first_inout:
                        if d in other_cc.second_outputs:
                            # Check for intersection (if None, fusion is ok)
                            nodes_second = [
                                n for n in second_output if n.data == d
                            # Read-Write race
                            if d in fused_cc.first_inputs:
                                nodes_first = [
                                    n for n in first_input if n.data == d
                                if StateFusion.memlets_intersect(
                                        first_state, nodes_first, True,
                                        second_state, nodes_second, False):
                                    return False
                            # Write-Write race
                            if d in fused_cc.first_outputs:
                                nodes_first = [
                                    n for n in first_output if n.data == d
                                if StateFusion.memlets_intersect(
                                        first_state, nodes_first, False,
                                        second_state, nodes_second, False):
                                    return False
                    # End of data race check

                # Read-after-write dependencies: if there is an output of the
                # second state that is an input of the first, ensure all paths
                # from the input of the first state lead to the output.
                # Otherwise, there may be a RAW due to topological sort or
                # concurrency.
                second_inout = ((fused_cc.first_inputs | fused_cc.first_outputs)
                                & fused_cc.second_outputs)
                for inout in second_inout:
                    nodes_first = [
                        n for n in match_nodes
                        if n.data == inout
                    if any(first_state.out_degree(n) > 0 for n in nodes_first):
                        return False

                # Read-after-write dependencies: if there is more than one first
                # output with the same data, make sure it can be unambiguously
                # connected to the second state
                if (len(fused_cc.first_output_nodes) > len(
                    for inpnode in fused_cc.second_input_nodes:
                        found = None
                        for outnode in fused_cc.first_output_nodes:
                            if outnode.data != inpnode.data:
                            if StateFusion.memlets_intersect(
                                    first_state, [outnode], False, second_state,
                                [inpnode], True):
                                # If found more than once, either there is a
                                # path from one to another or it is ambiguous
                                if found is not None:
                                    if nx.has_path(first_state.nx, outnode,
                                        # Found is a descendant, continue
                                    elif nx.has_path(first_state.nx, found,
                                        # New node is a descendant, set as found
                                        found = outnode
                                        # No path: ambiguous match
                                        return False
                                found = outnode

        return True
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """Decide whether the two candidate states can be fused into one.

        The structural checks on the interstate edges are always performed.
        When ``strict`` is true, the dataflow of both states is additionally
        inspected for write-write and read-write hazards that fusing the
        states could introduce.

        NOTE(review): this method takes neither ``self`` nor ``cls``, so it
        is presumably meant to carry ``@staticmethod``; no decorator is
        visible in the reviewed text -- confirm.

        :param graph: Graph that contains both states. It is queried through
                      ``nodes()``, ``in_edges()`` and ``out_edges()``.
        :param candidate: Mapping from ``StateFusion._first_state`` and
                          ``StateFusion._second_state`` to indices into
                          ``graph.nodes()``.
        :param expr_index: Not used by this method.
        :param sdfg: Not used by this method.
        :param strict: If True, also run the dataflow hazard checks.
        :return: True if none of the checks rejects the fusion, otherwise
                 False.
        """
        first_state = graph.nodes()[candidate[StateFusion._first_state]]
        second_state = graph.nodes()[candidate[StateFusion._second_state]]

        out_edges = graph.out_edges(first_state)
        in_edges = graph.in_edges(first_state)

        # First state must have only one output edge (with dst the second
        # state).
        if len(out_edges) != 1:
            return False
        # The interstate edge must not have a condition.
        if not out_edges[0].data.is_unconditional():
            return False
        # The interstate edge may have assignments, as long as there are input
        # edges to the first state that can absorb them.
        if out_edges[0].data.assignments:
            if not in_edges:
                return False
            # Fail if symbol is set before the state to fuse
            new_assignments = set(out_edges[0].data.assignments.keys())
            if any((new_assignments & set(e.data.assignments.keys()))
                   for e in in_edges):
                return False
            # Fail if symbol is used in the dataflow of that state
            if new_assignments & first_state.free_symbols:
                return False

        # There can be no state that have output edges pointing to both the
        # first and the second state. Such a case will produce a multi-graph.
        for src, _, _ in in_edges:
            for _, dst, _ in graph.out_edges(src):
                if dst == second_state:
                    return False

        if strict:
            # If second state has other input edges, there might be issues
            # Exceptions are when none of the states contain dataflow, unless
            # the first state is an initial state (in which case the new initial
            # state would be ambiguous).
            second_in_edges = graph.in_edges(second_state)
            if ((not second_state.is_empty() or not first_state.is_empty()
                 or len(in_edges) == 0) and len(second_in_edges) != 1):
                return False

            # Get connected components.
            first_cc = list(nx.weakly_connected_components(first_state._nx))
            second_cc = list(nx.weakly_connected_components(second_state._nx))

            # Find source/sink (data) nodes. Every access node that is not a
            # source node of its state is treated as an output.
            first_input = {
                node
                for node in sdutil.find_source_nodes(first_state)
                if isinstance(node, nodes.AccessNode)
            }
            first_output = {
                node
                for node in first_state.nodes()
                if isinstance(node, nodes.AccessNode)
                and node not in first_input
            }
            second_input = {
                node
                for node in sdutil.find_source_nodes(second_state)
                if isinstance(node, nodes.AccessNode)
            }
            second_output = {
                node
                for node in second_state.nodes()
                if isinstance(node, nodes.AccessNode)
                and node not in second_input
            }

            # Find source/sink (data) nodes by connected component
            first_cc_input = [cc.intersection(first_input) for cc in first_cc]
            first_cc_output = [
                cc.intersection(first_output) for cc in first_cc
            ]
            second_cc_input = [
                cc.intersection(second_input) for cc in second_cc
            ]
            second_cc_output = [
                cc.intersection(second_output) for cc in second_cc
            ]

            # Apply transformation in case all paths to the second state's
            # nodes go through the same access node, which implies sequential
            # behavior in SDFG semantics.
            first_output_names = {node.data for node in first_output}
            second_input_names = {node.data for node in second_input}

            # If any second input appears more than once, fail
            if len(second_input) > len(second_input_names):
                return False

            # If any first output that is an input to the second state
            # appears in more than one CC, fail
            matches = first_output_names & second_input_names
            for match in matches:
                cc_appearances = sum(
                    1 for cc in first_cc_output
                    if any(n.data == match for n in cc))
                if cc_appearances > 1:
                    return False

            # Recreate fused connected component correspondences, and then
            # check for hazards
            resulting_ccs: List[CCDesc] = StateFusion.find_fused_components(
                first_cc_input, first_cc_output, second_cc_input,
                second_cc_output)

            # Check for data races
            for fused_cc in resulting_ccs:
                # Write-Write hazard - data is output of both first and second
                # states, without a read in between
                # NOTE(review): the subtracted operand was missing from the
                # reviewed text; `second_inputs` is restored from the comment
                # above ("without a read in between") -- confirm.
                write_write_candidates = (
                    (fused_cc.first_outputs & fused_cc.second_outputs) -
                    fused_cc.second_inputs)
                if write_write_candidates:
                    # If we have potential candidates, check if there is a
                    # path from the first write to the second write (in that
                    # case, there is no hazard):
                    # Find the leaf (topological) instances of the matches
                    # NOTE(review): the argument of `reversed` was missing
                    # from the reviewed text; a topological sort of the first
                    # state's graph is restored from the comment above --
                    # confirm.
                    order = [
                        x for x in reversed(
                            list(nx.topological_sort(first_state._nx)))
                        if isinstance(x, nodes.AccessNode)
                        and x.data in fused_cc.first_outputs
                    ]
                    # Those nodes will be the connection points upon fusion
                    match_nodes = {
                        next(n for n in order if n.data == match)
                        for match in (fused_cc.first_outputs
                                      & fused_cc.second_inputs)
                    }
                else:
                    match_nodes = set()

                for cand in write_write_candidates:
                    nodes_first = [n for n in first_output if n.data == cand]
                    nodes_second = [n for n in second_output if n.data == cand]

                    # If there is a path for the candidate that goes through
                    # the match nodes in both states, there is no conflict
                    fail = False
                    path_found = False
                    for match in match_nodes:
                        for node in nodes_first:
                            if not nx.has_path(first_state._nx, node, match):
                                continue
                            path_found = True
                            node2 = next(n for n in second_input
                                         if n.data == match.data)
                            if not all(
                                    nx.has_path(second_state._nx, node2, n)
                                    for n in nodes_second):
                                fail = True
                                break
                        # One decisive match node is enough; stop searching.
                        if fail or path_found:
                            break

                    # Check for intersection (if None, fusion is ok)
                    if fail or not path_found:
                        if StateFusion.memlets_intersect(
                                first_state, nodes_first, False, second_state,
                                nodes_second, False):
                            return False
                # End of write-write hazard check

                first_inout = fused_cc.first_inputs | fused_cc.first_outputs
                for other_cc in resulting_ccs:
                    # Only compare against the *other* fused components.
                    if other_cc is fused_cc:
                        continue
                    # If an input/output of a connected component in the first
                    # state is an output of another connected component in the
                    # second state, we have a potential data race (Read-Write
                    # or Write-Write)
                    for d in first_inout:
                        if d not in other_cc.second_outputs:
                            continue
                        # Check for intersection (if None, fusion is ok)
                        nodes_second = [
                            n for n in second_output if n.data == d
                        ]
                        # Read-Write race
                        if d in fused_cc.first_inputs:
                            nodes_first = [
                                n for n in first_input if n.data == d
                            ]
                            if StateFusion.memlets_intersect(
                                    first_state, nodes_first, True,
                                    second_state, nodes_second, False):
                                return False
                        # Write-Write race
                        if d in fused_cc.first_outputs:
                            nodes_first = [
                                n for n in first_output if n.data == d
                            ]
                            if StateFusion.memlets_intersect(
                                    first_state, nodes_first, False,
                                    second_state, nodes_second, False):
                                return False
                    # End of data race check

        return True