Beispiel #1
0
def symbols_in_ast(tree: ast.AST):
    """ Collects the identifiers of all ``ast.Name`` nodes under ``tree``,
        leaving out nodes that are the direct callee of an ``ast.Call``.

        :param tree: Root of the AST to inspect.
        :return: The collected identifiers, passed through
                 ``dtypes.deduplicate``.
    """
    callee_nodes = set()
    found = []
    for current in ast.walk(tree):
        # A callee is registered when its parent Call is visited, which
        # ast.walk yields before the callee itself.
        if current in callee_nodes:
            continue
        if isinstance(current, ast.Call):
            callee_nodes.add(current.func)
        elif isinstance(current, ast.Name):
            found.append(current.id)
    return dtypes.deduplicate(found)
Beispiel #2
0
def symbols_in_ast(tree):
    """ Collects the identifiers of ``ast.Name`` nodes reachable from
        ``tree`` through expression fields and lists, never descending into
        a field called ``func``.

        :param tree: AST node whose attributes seed the traversal.
        :return: The collected identifiers, passed through
                 ``dtypes.deduplicate``.
    """
    # Work list of (field name, value) pairs, processed last-in first-out
    pending = list(vars(tree).items())
    names = []
    while pending:
        field, value = pending.pop()
        if field == "func":
            # Callee of a function call: not a symbol
            continue
        if isinstance(value, ast.Name):
            names.append(value.id)
        elif isinstance(value, ast.expr):
            pending.extend(vars(value).items())
        elif isinstance(value, list):
            # List elements inherit the field name of the list itself
            pending.extend((field, item) for item in value)
    return dtypes.deduplicate(names)
Beispiel #3
0
    def apply(self, sdfg: sd.SDFG):
        """ Transforms the given SDFG in place so that it runs on the GPU.

            The steps, in order: collect non-transient input/output data,
            clone it into ``GPU_Global`` transients and redirect access nodes
            and memlets to the clones, add copy-in and copy-out states,
            change the storage of transients, wrap free code nodes in GPU
            maps, reschedule scopes and library nodes, add interim copy-out
            states for cloned arrays used on interstate edges, and optionally
            apply strict transformations.

            :param sdfg: The SDFG to transform.
        """

        #######################################################
        # Step 0: SDFG metadata

        # Find all input and output data descriptors
        input_nodes = []
        output_nodes = []
        # Names already recorded in the two lists above. The lists hold
        # (name, descriptor) tuples, so testing a bare name against them
        # would never match; the names are tracked separately instead.
        input_names = set()
        output_names = set()
        global_code_nodes = [[] for _ in sdfg.nodes()]

        for i, state in enumerate(sdfg.nodes()):
            sdict = state.scope_dict()
            for node in state.nodes():
                if (isinstance(node, nodes.AccessNode)
                        and not node.desc(sdfg).transient):
                    if (state.out_degree(node) > 0
                            and node.data not in input_names):
                        # Special case: nodes that lead to top-level dynamic
                        # map ranges must stay on host
                        for e in state.out_edges(node):
                            last_edge = state.memlet_path(e)[-1]
                            if (isinstance(last_edge.dst, nodes.EntryNode)
                                    and last_edge.dst_conn and
                                    not last_edge.dst_conn.startswith('IN_')
                                    and sdict[last_edge.dst] is None):
                                break
                        else:
                            # No edge hit the special case above
                            input_names.add(node.data)
                            input_nodes.append((node.data, node.desc(sdfg)))
                    if (state.in_degree(node) > 0
                            and node.data not in output_names):
                        output_names.add(node.data)
                        output_nodes.append((node.data, node.desc(sdfg)))
                elif isinstance(node, nodes.CodeNode) and sdict[node] is None:
                    if not isinstance(node,
                                      (nodes.LibraryNode, nodes.NestedSDFG)):
                        global_code_nodes[i].append(node)

            # Input nodes may also be nodes with WCR memlets and no identity
            for e in state.edges():
                if e.data.wcr is not None:
                    if (e.data.data not in input_names
                            and not sdfg.arrays[e.data.data].transient):
                        input_names.add(e.data.data)
                        input_nodes.append(
                            (e.data.data, sdfg.arrays[e.data.data]))

        start_state = sdfg.start_state
        end_states = sdfg.sink_nodes()

        #######################################################
        # Step 1: Create cloned GPU arrays and replace originals

        # Maps original data name -> name of its GPU clone
        cloned_arrays = {}
        for inodename, inode in set(input_nodes):
            if isinstance(inode, data.Scalar):  # Scalars can remain on host
                continue
            if inode.storage == dtypes.StorageType.GPU_Global:
                continue
            newdesc = inode.clone()
            newdesc.storage = dtypes.StorageType.GPU_Global
            newdesc.transient = True
            name = sdfg.add_datadesc('gpu_' + inodename,
                                     newdesc,
                                     find_new_name=True)
            cloned_arrays[inodename] = name

        for onodename, onode in set(output_nodes):
            if onodename in cloned_arrays:
                continue
            if onode.storage == dtypes.StorageType.GPU_Global:
                continue
            newdesc = onode.clone()
            newdesc.storage = dtypes.StorageType.GPU_Global
            newdesc.transient = True
            name = sdfg.add_datadesc('gpu_' + onodename,
                                     newdesc,
                                     find_new_name=True)
            cloned_arrays[onodename] = name

        # Replace nodes
        for state in sdfg.nodes():
            for node in state.nodes():
                if (isinstance(node, nodes.AccessNode)
                        and node.data in cloned_arrays):
                    node.data = cloned_arrays[node.data]

        # Replace memlets
        for state in sdfg.nodes():
            for edge in state.edges():
                if edge.data.data in cloned_arrays:
                    edge.data.data = cloned_arrays[edge.data.data]

        #######################################################
        # Step 2: Create copy-in state
        excluded_copyin = self.exclude_copyin.split(',')

        copyin_state = sdfg.add_state(sdfg.label + '_copyin')
        sdfg.add_edge(copyin_state, start_state, sd.InterstateEdge())

        for nname, desc in dtypes.deduplicate(input_nodes):
            if nname in excluded_copyin or nname not in cloned_arrays:
                continue
            src_array = nodes.AccessNode(nname, debuginfo=desc.debuginfo)
            dst_array = nodes.AccessNode(cloned_arrays[nname],
                                         debuginfo=desc.debuginfo)
            copyin_state.add_node(src_array)
            copyin_state.add_node(dst_array)
            copyin_state.add_nedge(
                src_array, dst_array,
                memlet.Memlet.from_array(src_array.data, src_array.desc(sdfg)))

        #######################################################
        # Step 3: Create copy-out state
        excluded_copyout = self.exclude_copyout.split(',')

        copyout_state = sdfg.add_state(sdfg.label + '_copyout')
        for state in end_states:
            sdfg.add_edge(state, copyout_state, sd.InterstateEdge())

        for nname, desc in dtypes.deduplicate(output_nodes):
            if nname in excluded_copyout or nname not in cloned_arrays:
                continue
            src_array = nodes.AccessNode(cloned_arrays[nname],
                                         debuginfo=desc.debuginfo)
            dst_array = nodes.AccessNode(nname, debuginfo=desc.debuginfo)
            copyout_state.add_node(src_array)
            copyout_state.add_node(dst_array)
            copyout_state.add_nedge(
                src_array, dst_array,
                memlet.Memlet.from_array(dst_array.data, dst_array.desc(sdfg)))

        #######################################################
        # Step 4: Modify transient data storage

        for state in sdfg.nodes():
            sdict = state.scope_dict()
            for node in state.nodes():
                if isinstance(node,
                              nodes.AccessNode) and node.desc(sdfg).transient:
                    nodedesc = node.desc(sdfg)

                    # Special case: nodes that lead to dynamic map ranges must
                    # stay on host
                    if any(
                            isinstance(
                                state.memlet_path(e)[-1].dst, nodes.EntryNode)
                            for e in state.out_edges(node)):
                        continue

                    gpu_storage = [
                        dtypes.StorageType.GPU_Global,
                        dtypes.StorageType.GPU_Shared,
                        dtypes.StorageType.CPU_Pinned
                    ]
                    if sdict[
                            node] is None and nodedesc.storage not in gpu_storage:
                        # NOTE: the cloned arrays match too but it's the same
                        # storage so we don't care
                        nodedesc.storage = dtypes.StorageType.GPU_Global

                        # Try to move allocation/deallocation out of loops
                        if (self.toplevel_trans
                                and not isinstance(nodedesc, data.Stream)):
                            nodedesc.lifetime = dtypes.AllocationLifetime.SDFG
                    elif nodedesc.storage not in gpu_storage:
                        # Make internal transients registers
                        if self.register_trans:
                            nodedesc.storage = dtypes.StorageType.Register

        #######################################################
        # Step 5: Wrap free tasklets and nested SDFGs with a GPU map

        # Parsed once instead of once per code node
        excluded_tasklets = self.exclude_tasklets.split(',')

        for state, gcodes in zip(sdfg.nodes(), global_code_nodes):
            for gcode in gcodes:
                if gcode.label in excluded_tasklets:
                    continue
                # Create map and connectors
                me, mx = state.add_map(gcode.label + '_gmap',
                                       {gcode.label + '__gmapi': '0:1'},
                                       schedule=dtypes.ScheduleType.GPU_Device)
                # Store in/out edges in lists so that they don't get corrupted
                # when they are removed from the graph
                in_edges = list(state.in_edges(gcode))
                out_edges = list(state.out_edges(gcode))
                me.in_connectors = {('IN_' + e.dst_conn): None
                                    for e in in_edges}
                me.out_connectors = {('OUT_' + e.dst_conn): None
                                     for e in in_edges}
                mx.in_connectors = {('IN_' + e.src_conn): None
                                    for e in out_edges}
                mx.out_connectors = {('OUT_' + e.src_conn): None
                                     for e in out_edges}

                # Create memlets through map
                for e in in_edges:
                    state.remove_edge(e)
                    state.add_edge(e.src, e.src_conn, me, 'IN_' + e.dst_conn,
                                   e.data)
                    state.add_edge(me, 'OUT_' + e.dst_conn, e.dst, e.dst_conn,
                                   e.data)
                for e in out_edges:
                    state.remove_edge(e)
                    state.add_edge(e.src, e.src_conn, mx, 'IN_' + e.src_conn,
                                   e.data)
                    state.add_edge(mx, 'OUT_' + e.src_conn, e.dst, e.dst_conn,
                                   e.data)

                # Map without inputs
                if len(in_edges) == 0:
                    state.add_nedge(me, gcode, memlet.Memlet())
        #######################################################
        # Step 6: Change all top-level maps and library nodes to GPU schedule

        for state in sdfg.nodes():
            sdict = state.scope_dict()
            for node in state.nodes():
                if isinstance(node, (nodes.EntryNode, nodes.LibraryNode)):
                    if sdict[node] is None:
                        node.schedule = dtypes.ScheduleType.GPU_Device
                    elif self.sequential_innermaps:
                        # Nested scopes/library nodes run sequentially
                        node.schedule = dtypes.ScheduleType.Sequential

        #######################################################
        # Step 7: Introduce copy-out if data used in outgoing interstate edges

        # Iterate over a copy, since states are added inside the loop
        for state in list(sdfg.nodes()):
            arrays_used = set()
            for e in sdfg.out_edges(state):
                # Used arrays = intersection between symbols and cloned arrays
                arrays_used.update(
                    set(e.data.free_symbols)
                    & set(cloned_arrays.keys()))

            # Create a state and copy out used arrays
            if len(arrays_used) > 0:
                co_state = sdfg.add_state(state.label + '_icopyout')

                # Reconnect outgoing edges to after interim copyout state
                # NOTE(review): the helper is invoked once per outgoing edge
                # with identical arguments; presumably a single call would
                # suffice -- confirm against sdutil.change_edge_src.
                for e in sdfg.out_edges(state):
                    sdutil.change_edge_src(sdfg, state, co_state)
                # Add unconditional edge to interim state
                sdfg.add_edge(state, co_state, sd.InterstateEdge())

                # Add copy-out nodes
                for nname in arrays_used:
                    desc = sdfg.arrays[nname]
                    src_array = nodes.AccessNode(cloned_arrays[nname],
                                                 debuginfo=desc.debuginfo)
                    dst_array = nodes.AccessNode(nname,
                                                 debuginfo=desc.debuginfo)
                    co_state.add_node(src_array)
                    co_state.add_node(dst_array)
                    co_state.add_nedge(
                        src_array, dst_array,
                        memlet.Memlet.from_array(dst_array.data,
                                                 dst_array.desc(sdfg)))

        #######################################################
        # Step 8: Strict transformations
        if not self.strict_transform:
            return

        # Apply strict state fusions greedily.
        sdfg.apply_strict_transformations()
Beispiel #4
0
    def generate_module(self, sdfg, state, name, subgraph, parameters,
                        symbol_parameters, module_stream, entry_stream,
                        host_stream):
        """Generates a module that will run as a dataflow function in the FPGA
           kernel.

           The call to the module is written to ``entry_stream`` and the
           module definition to ``module_stream``.

           :param sdfg: SDFG that contains ``state``.
           :param state: State the module is generated for.
           :param name: Name used to build the module's function name.
           :param subgraph: Subgraph whose code forms the module body.
           :param parameters: Iterable of ``(is_output, name, descriptor,
                              interface_id)`` tuples.
           :param symbol_parameters: Mapping from symbol name to descriptor;
                                     each entry becomes a scalar argument.
           :param module_stream: Stream that receives the module definition.
           :param entry_stream: Stream that receives the call site.
           :param host_stream: Not used by this method.
           :raises cgx.CodegenError: If an unrolled map range has more than
                                     three components.
        """

        state_id = sdfg.node_id(state)

        kernel_args_call = []
        kernel_args_module = []
        added = set()

        parameters = list(sorted(parameters, key=lambda t: t[1]))
        arrays = dtypes.deduplicate(
            [p for p in parameters if not isinstance(p[2], dace.data.Scalar)])
        scalars = [p for p in parameters if isinstance(p[2], dace.data.Scalar)]
        scalars += ((False, k, v, None) for k, v in symbol_parameters.items())
        scalars = dace.dtypes.deduplicate(sorted(scalars, key=lambda t: t[1]))
        for is_output, pname, p, interface_id in itertools.chain(
                arrays, scalars):
            if isinstance(p, dace.data.Array):
                arr_name = "{}_{}".format(pname, "out" if is_output else "in")
                # Add interface ID to called module, but not to the module
                # arguments
                argname = arr_name
                if interface_id is not None:
                    argname = arr_name + "_%d" % interface_id

                kernel_args_call.append(argname)
                dtype = p.dtype
                kernel_args_module.append("{} {}*{}".format(
                    dtype.ctype, "const " if not is_output else "", arr_name))
            else:
                # Don't make duplicate arguments for other types than arrays
                if pname in added:
                    continue
                added.add(pname)
                if isinstance(p, dace.data.Stream):
                    kernel_args_call.append(
                        p.as_arg(with_types=False, name=pname))
                    if p.is_stream_array():
                        kernel_args_module.append(
                            "dace::FIFO<{}, {}, {}> {}[{}]".format(
                                p.dtype.base_type.ctype, p.veclen,
                                p.buffer_size, pname, p.size_string()))
                    else:
                        kernel_args_module.append(
                            "dace::FIFO<{}, {}, {}> &{}".format(
                                p.dtype.base_type.ctype, p.veclen,
                                p.buffer_size, pname))
                else:
                    kernel_args_call.append(
                        p.as_arg(with_types=False, name=pname))
                    kernel_args_module.append(
                        p.as_arg(with_types=True, name=pname))

        # create a unique module name to prevent name clashes
        module_function_name = f"module_{name}_{sdfg.sdfg_id}"

        # Unrolling processing elements: if the first scope of the subgraph
        # is an unrolled map, generate a processing element for each iteration
        scope_children = subgraph.scope_children()
        top_scopes = [
            n for n in scope_children[None]
            if isinstance(n, dace.sdfg.nodes.EntryNode)
        ]
        unrolled_loops = 0
        if len(top_scopes) == 1:
            scope = top_scopes[0]
            if scope.unroll:
                self._unrolled_pes.add(scope.map)
                map_params = list(scope.map.params)
                # Add one argument per map parameter. Extending the list
                # with a joined string would add it character by character.
                kernel_args_call += map_params
                kernel_args_module += ["int " + p for p in map_params]
                for p, r in zip(scope.map.params, scope.map.range):
                    if len(r) > 3:
                        raise cgx.CodegenError("Strided unroll not supported")
                    # The loop is closed after the module call is emitted
                    entry_stream.write(
                        "for (size_t {param} = {begin}; {param} < {end}; "
                        "{param} += {increment}) {{\n#pragma HLS UNROLL".
                        format(param=p,
                               begin=r[0],
                               end=r[1] + 1,
                               increment=r[2]))
                    unrolled_loops += 1

        # Generate caller code in top-level function
        entry_stream.write(
            "HLSLIB_DATAFLOW_FUNCTION({}, {});".format(
                module_function_name, ", ".join(kernel_args_call)), sdfg,
            state_id)

        # Close the loops opened for the unrolled map parameters
        for _ in range(unrolled_loops):
            entry_stream.write("}")

        # ----------------------------------------------------------------------
        # Generate kernel code
        # ----------------------------------------------------------------------

        self._dispatcher.defined_vars.enter_scope(subgraph)

        module_body_stream = CodeIOStream()

        module_body_stream.write(
            "void {}({}) {{".format(module_function_name,
                                    ", ".join(kernel_args_module)), sdfg,
            state_id)

        # Construct ArrayInterface wrappers to pack input and output pointers
        # to the same global array
        in_args = {
            argname
            for out, argname, arg, _ in parameters
            if isinstance(arg, dace.data.Array)
            and arg.storage == dace.dtypes.StorageType.FPGA_Global and not out
        }
        out_args = {
            argname
            for out, argname, arg, _ in parameters
            if isinstance(arg, dace.data.Array)
            and arg.storage == dace.dtypes.StorageType.FPGA_Global and out
        }
        if len(in_args) > 0 or len(out_args) > 0:
            # Add ArrayInterface objects to wrap input and output pointers to
            # the same array
            module_body_stream.write("\n")
            interfaces_added = set()
            for _, argname, arg, _ in parameters:
                if argname in interfaces_added:
                    continue
                interfaces_added.add(argname)
                has_in_ptr = argname in in_args
                has_out_ptr = argname in out_args
                if not has_in_ptr and not has_out_ptr:
                    continue
                in_ptr = ("{}_in".format(argname) if has_in_ptr else "nullptr")
                out_ptr = ("{}_out".format(argname)
                           if has_out_ptr else "nullptr")
                ctype = "dace::ArrayInterface<{}>".format(arg.dtype.ctype)
                module_body_stream.write("{} {}({}, {});".format(
                    ctype, argname, in_ptr, out_ptr))
                self._dispatcher.defined_vars.add(argname,
                                                  DefinedType.ArrayInterface,
                                                  ctype,
                                                  allow_shadowing=True)
            module_body_stream.write("\n")

        # Allocate local transients
        data_to_allocate = (set(subgraph.top_level_transients()) -
                            set(sdfg.shared_transients()) -
                            set([p[1] for p in parameters]))
        allocated = set()
        for node in subgraph.nodes():
            if not isinstance(node, dace.sdfg.nodes.AccessNode):
                continue
            if node.data not in data_to_allocate or node.data in allocated:
                continue
            allocated.add(node.data)
            self._dispatcher.dispatch_allocate(sdfg, state, state_id, node,
                                               module_stream,
                                               module_body_stream)

        self._dispatcher.dispatch_subgraph(sdfg,
                                           subgraph,
                                           state_id,
                                           module_stream,
                                           module_body_stream,
                                           skip_entry_node=False)

        module_stream.write(module_body_stream.getvalue(), sdfg, state_id)
        module_stream.write("}\n\n")

        self._dispatcher.defined_vars.exit_scope(subgraph)
Beispiel #5
0
 def successors(self, node: NodeT) -> Iterable[NodeT]:
     """Returns an iterable of nodes that the passed node has edges leading
     to (the destinations of its outgoing edges), passed through
     ``deduplicate``"""
     return deduplicate([e.dst for e in self.out_edges(node)])
Beispiel #6
0
 def predecessors(self, node: NodeT) -> Iterable[NodeT]:
     """Returns the sources of all edges entering the passed node, passed
     through ``deduplicate``"""
     incoming = self.in_edges(node)
     sources = [edge.src for edge in incoming]
     return deduplicate(sources)
Beispiel #7
0
    def generate_code(
        self,
        sdfg: SDFG,
        schedule: Optional[dtypes.ScheduleType],
        sdfg_id: str = ""
    ) -> Tuple[str, str, Set[TargetCodeGenerator], Set[str]]:
        """ Generate frame code for a given SDFG, calling registered targets'
            code generation callbacks for them to generate their own code.
            :param sdfg: The SDFG to generate code for.
            :param schedule: The schedule the SDFG is currently located, or
                             None if the SDFG is top-level.
            :param sdfg_id: An optional string id given to the SDFG label
            :return: A tuple of the generated global frame code, local frame
                     code, and a set of targets that have been used in the
                     generation of this SDFG.
        """

        sdfg_label = sdfg.name + sdfg_id

        global_stream = CodeIOStream()
        callsite_stream = CodeIOStream()

        # Set default storage/schedule types in SDFG
        set_default_schedule_and_storage_types(sdfg, schedule)

        is_top_level = sdfg.parent is None

        # Generate code
        ###########################

        # Invoke all instrumentation providers
        for instr in self._dispatcher.instrumentation.values():
            if instr is not None:
                instr.on_sdfg_begin(sdfg, callsite_stream, global_stream)

        # Allocate outer-level transients
        shared_transients = sdfg.shared_transients()
        allocated = set()
        for state in sdfg.nodes():
            for node in state.data_nodes():
                if (node.data in shared_transients
                        and node.data not in allocated):
                    self._dispatcher.dispatch_allocate(sdfg, state, None, node,
                                                       global_stream,
                                                       callsite_stream)
                    self._dispatcher.dispatch_initialize(
                        sdfg, state, None, node, global_stream,
                        callsite_stream)
                    allocated.add(node.data)

        # Allocate inter-state variables
        assigned, _ = sdfg.interstate_symbols()
        for isvarName, isvarType in assigned.items():
            # Skip symbols that have been declared as outer-level transients
            if isvarName in allocated:
                continue
            callsite_stream.write(
                '%s;\n' %
                (isvarType.signature(with_types=True, name=isvarName)), sdfg)

        # Initialize parameter arrays
        for argnode in dtypes.deduplicate(sdfg.input_arrays() +
                                          sdfg.output_arrays()):
            # Ignore transient arrays
            if argnode.desc(sdfg).transient: continue
            self._dispatcher.dispatch_initialize(sdfg, sdfg, None, argnode,
                                                 global_stream,
                                                 callsite_stream)

        callsite_stream.write('\n', sdfg)

        states_topological = list(sdfg.topological_sort(sdfg.start_state))

        # {edge: [dace.edges.ControlFlow]}
        control_flow = {e: [] for e in sdfg.edges()}

        if dace.config.Config.get_bool('optimizer', 'detect_control_flow'):

            ####################################################################
            # Loop detection procedure

            all_cycles = list(sdfg.find_cycles())  # Returns a list of lists
            # Order according to topological sort
            all_cycles = [
                sorted(c, key=lambda x: states_topological.index(x))
                for c in all_cycles
            ]
            # Group in terms of starting node
            starting_nodes = [c[0] for c in all_cycles]
            # Order cycles according to starting node in topological sort
            starting_nodes = sorted(starting_nodes,
                                    key=lambda x: states_topological.index(x))
            cycles_by_node = [[c for c in all_cycles if c[0] == n]
                              for n in starting_nodes]
            for cycles in cycles_by_node:

                # Use arbitrary cycle to find the first and last nodes
                first_node = cycles[0][0]
                last_node = cycles[0][-1]

                if not first_node.is_empty():
                    # The entry node should not contain any computations
                    continue

                if not all([c[-1] == last_node for c in cycles]):
                    # There are multiple back edges: not a for or while loop
                    continue

                previous_edge = [
                    e for e in sdfg.in_edges(first_node) if e.src != last_node
                ]
                if len(previous_edge) != 1:
                    # No single starting point: not a for or while
                    continue
                previous_edge = previous_edge[0]

                back_edge = sdfg.edges_between(last_node, first_node)
                if len(back_edge) != 1:
                    raise RuntimeError("Expected exactly one edge in cycle")
                back_edge = back_edge[0]

                # Build a set of all nodes in all cycles associated with this
                # set of start and end node
                internal_nodes = functools.reduce(
                    lambda a, b: a | b, [set(c)
                                         for c in cycles]) - {first_node}

                exit_edge = [
                    e for e in sdfg.out_edges(first_node)
                    if e.dst not in internal_nodes | {first_node}
                ]
                if len(exit_edge) != 1:
                    # No single stopping condition: not a for or while
                    # (we don't support continue or break)
                    continue
                exit_edge = exit_edge[0]

                entry_edge = [
                    e for e in sdfg.out_edges(first_node) if e != exit_edge
                ]
                if len(entry_edge) != 1:
                    # No single starting condition: not a for or while
                    continue
                entry_edge = entry_edge[0]

                # Make sure this is not already annotated to be another construct
                if (len(control_flow[entry_edge]) != 0
                        or len(control_flow[back_edge]) != 0):
                    continue

                # Nested loops case I - previous edge of internal loop is a
                # loop-entry of an external loop (first state in a loop is
                # another loop)
                if (len(control_flow[previous_edge]) == 1
                        and isinstance(control_flow[previous_edge][0],
                                       dace.graph.edges.LoopEntry)):
                    # Nested loop, mark parent scope
                    loop_parent = control_flow[previous_edge][0].scope
                # Nested loops case II - exit edge of internal loop is a
                # back-edge of an external loop (last state in a loop is another
                # loop)
                elif (len(control_flow[exit_edge]) == 1
                      and isinstance(control_flow[exit_edge][0],
                                     dace.graph.edges.LoopBack)):
                    # Nested loop, mark parent scope
                    loop_parent = control_flow[exit_edge][0].scope
                elif (len(control_flow[exit_edge]) == 0
                      or len(control_flow[previous_edge]) == 0):
                    loop_parent = None
                else:
                    continue

                if entry_edge == back_edge:
                    # No entry check (we don't support do-loops)
                    # TODO: do we want to add some support for self-loops?
                    continue

                # Now we make sure that there is no other way to exit this
                # cycle, by checking that there's no reachable node *not*
                # included in any cycle between the first and last node.
                if any([len(set(c) - internal_nodes) > 1 for c in cycles]):
                    continue

                # This is a loop! Generate the necessary annotation objects.
                loop_scope = dace.graph.edges.LoopScope(internal_nodes)

                if ((len(previous_edge.data.assignments) > 0
                     or len(back_edge.data.assignments) > 0) and
                    (len(control_flow[previous_edge]) == 0 or
                     (len(control_flow[previous_edge]) == 1 and
                      control_flow[previous_edge][0].scope == loop_parent))):
                    # Generate assignment edge, if available
                    control_flow[previous_edge].append(
                        dace.graph.edges.LoopAssignment(
                            loop_scope, previous_edge))
                # Assign remaining control flow constructs
                control_flow[entry_edge].append(
                    dace.graph.edges.LoopEntry(loop_scope, entry_edge))
                control_flow[exit_edge].append(
                    dace.graph.edges.LoopExit(loop_scope, exit_edge))
                control_flow[back_edge].append(
                    dace.graph.edges.LoopBack(loop_scope, back_edge))

            ###################################################################
            # If/then/else detection procedure

            candidates = [
                n for n in states_topological if sdfg.out_degree(n) == 2
            ]
            for candidate in candidates:

                # A valid if occurs when then are no reachable nodes for either
                # path that does not pass through a common dominator.
                dominators = nx.dominance.dominance_frontiers(
                    sdfg.nx, candidate)

                left_entry, right_entry = sdfg.out_edges(candidate)
                if (len(control_flow[left_entry]) > 0
                        or len(control_flow[right_entry]) > 0):
                    # Already assigned to a control flow construct
                    # TODO: carefully allow this in some cases
                    continue

                left, right = left_entry.dst, right_entry.dst
                dominator = dominators[left] & dominators[right]
                if len(dominator) != 1:
                    # There must be a single dominator across both branches,
                    # unless one of the nodes _is_ the next dominator
                    # if (len(dominator) == 0 and dominators[left] == {right}
                    #         or dominators[right] == {left}):
                    #     dominator = dominators[left] | dominators[right]
                    # else:
                    #     continue
                    continue
                dominator = next(iter(dominator))  # Exactly one dominator

                exit_edges = sdfg.in_edges(dominator)
                if len(exit_edges) != 2:
                    # There must be a single entry and a single exit. This
                    # could be relaxed in the future.
                    continue

                left_exit, right_exit = exit_edges
                if (len(control_flow[left_exit]) > 0
                        or len(control_flow[right_exit]) > 0):
                    # Already assigned to a control flow construct
                    # TODO: carefully allow this in some cases
                    continue

                # Now traverse from the source and verify that all possible paths
                # pass through the dominator
                left_nodes = sdfg.all_nodes_between(left, dominator)
                if left_nodes is None:
                    # Not all paths lead to the next dominator
                    continue
                right_nodes = sdfg.all_nodes_between(right, dominator)
                if right_nodes is None:
                    # Not all paths lead to the next dominator
                    continue
                all_nodes = left_nodes | right_nodes

                # Make sure there is no overlap between left and right nodes
                if len(left_nodes & right_nodes) > 0:
                    continue

                # This is a valid if/then/else construct. Generate annotations
                if_then_else = dace.graph.edges.IfThenElse(
                    candidate, dominator)

                # Arbitrarily assign then/else to the two branches. If one edge
                # has no dominator but leads to the dominator, it means there's
                # only a then clause (and no else).
                has_else = False
                if len(dominators[left]) == 1:
                    then_scope = dace.graph.edges.IfThenScope(
                        if_then_else, left_nodes)
                    else_scope = dace.graph.edges.IfElseScope(
                        if_then_else, right_nodes)
                    control_flow[left_entry].append(
                        dace.graph.edges.IfEntry(then_scope, left_entry))
                    control_flow[left_exit].append(
                        dace.graph.edges.IfExit(then_scope, left_exit))
                    control_flow[right_exit].append(
                        dace.graph.edges.IfExit(else_scope, right_exit))
                    if len(dominators[right]) == 1:
                        control_flow[right_entry].append(
                            dace.graph.edges.IfEntry(else_scope, right_entry))
                        has_else = True
                else:
                    then_scope = dace.graph.edges.IfThenScope(
                        if_then_else, right_nodes)
                    else_scope = dace.graph.edges.IfElseScope(
                        if_then_else, left_nodes)
                    control_flow[right_entry].append(
                        dace.graph.edges.IfEntry(then_scope, right_entry))
                    control_flow[right_exit].append(
                        dace.graph.edges.IfExit(then_scope, right_exit))
                    control_flow[left_exit].append(
                        dace.graph.edges.IfExit(else_scope, left_exit))

        #######################################################################
        # Generate actual program body

        states_generated = set()  # For sanity check
        generated_edges = set()
        self.generate_states(sdfg, "sdfg", control_flow,
                             global_stream, callsite_stream,
                             set(states_topological), states_generated,
                             generated_edges)

        #######################################################################

        # Sanity check
        if len(states_generated) != len(sdfg.nodes()):
            raise RuntimeError(
                "Not all states were generated in SDFG {}!"
                "\n  Generated: {}\n  Missing: {}".format(
                    sdfg.label, [s.label for s in states_generated],
                    [s.label for s in (set(sdfg.nodes()) - states_generated)]))

        # Deallocate transients
        shared_transients = sdfg.shared_transients()
        deallocated = set()
        for state in sdfg.nodes():
            for node in state.data_nodes():
                if (node.data in shared_transients
                        and node.data not in deallocated):
                    self._dispatcher.dispatch_deallocate(
                        sdfg, state, None, node, global_stream,
                        callsite_stream)
                    deallocated.add(node.data)

        # Now that we have all the information about dependencies, generate
        # header and footer
        if is_top_level:
            header_stream = CodeIOStream()
            header_global_stream = CodeIOStream()
            footer_stream = CodeIOStream()
            footer_global_stream = CodeIOStream()
            self.generate_header(sdfg, self._dispatcher.used_environments,
                                 header_global_stream, header_stream)

            # Open program function
            function_signature = 'void __program_%s_internal(%s)\n{\n' % (
                sdfg.name, sdfg.signature())

            self.generate_footer(sdfg, self._dispatcher.used_environments,
                                 footer_global_stream, footer_stream)

            header_global_stream.write(global_stream.getvalue())
            header_global_stream.write(footer_global_stream.getvalue())
            generated_header = header_global_stream.getvalue()

            all_code = CodeIOStream()
            all_code.write(function_signature)
            all_code.write(header_stream.getvalue())
            all_code.write(callsite_stream.getvalue())
            all_code.write(footer_stream.getvalue())
            generated_code = all_code.getvalue()
        else:
            generated_header = global_stream.getvalue()
            generated_code = callsite_stream.getvalue()

        # Return the generated global and local code strings
        return (generated_header, generated_code,
                self._dispatcher.used_targets,
                self._dispatcher.used_environments)
    def apply_pass(
            self, sdfg: SDFG, pipeline_results: Dict[str, Any]
    ) -> Optional[Dict[SDFGState, Set[str]]]:
        """
        Eliminates dead dataflow from the states of an SDFG, modifying the states in place.

        States are visited in the reverse of ``cfg.stateorder_topological_sort``. For every state, data that
        is written there but not read by any reachable state is collected, nodes are tested for deadness
        in reverse topological order via ``_is_node_dead``, and the dead nodes are removed together with the
        memlet paths leading into them. For each modified state, the read set stored in the ``'AccessSets'``
        entry of ``pipeline_results`` is replaced so that states processed later see the updated reads.

        :param sdfg: The SDFG to modify.
        :param pipeline_results: Dictionary of prior pass results, keyed by pass name. The entries
                                 ``'StateReachability'`` and ``'AccessSets'`` are required.
        :return: A dictionary mapping each modified state to the set of elements removed from it, or None if
                 nothing was removed. NOTE(review): the annotation promises data descriptor names, but the sets
                 are filled with the removed node objects themselves -- confirm what callers expect.
        """
        # Results of the analysis passes this pass builds upon
        state_reach: Dict[SDFGState, Set[SDFGState]] = pipeline_results['StateReachability']
        rw_sets: Dict[SDFGState, Tuple[Set[str], Set[str]]] = pipeline_results['AccessSets']
        removed: Dict[SDFGState, Set[str]] = defaultdict(set)

        # Walk the states back to front
        ordered_states = list(cfg.stateorder_topological_sort(sdfg))
        for state in reversed(ordered_states):
            # ---------------------------------------------------------------
            # Analysis
            # ---------------------------------------------------------------

            # Data written in this state that none of the reachable states read
            written = rw_sets[state][1]
            later_reads: Set[str] = set()
            for succ in state_reach[state]:
                later_reads.update(rw_sets[succ][0])
            unread_writes: Set[str] = {name for name in written if name not in later_reads}

            # Deadness is propagated backwards: the helper receives the dead nodes found so far,
            # so the list has to be grown one node at a time.
            dead_nodes: List[nodes.Node] = []
            for node in sdutil.dfs_topological_sort(state, reverse=True):
                if self._is_node_dead(node, sdfg, state, dead_nodes, unread_writes):
                    dead_nodes.append(node)

            # A scope exit stays alive unless its matching entry node is dead as well
            revived = {
                node
                for node in dead_nodes
                if isinstance(node, nodes.ExitNode) and state.entry_node(node) not in dead_nodes
            }
            dead_nodes = dtypes.deduplicate([node for node in dead_nodes if node not in revived])

            if not dead_nodes:
                continue

            # Surviving scope exits that are about to lose at least one predecessor
            orphaned_exits: Set[nodes.Node] = {
                node
                for node in state.nodes()
                if isinstance(node, nodes.ExitNode) and node not in dead_nodes
                and any(pred in dead_nodes for pred in state.predecessors(node))
            }

            # Removal can break scopes in two ways:
            #   1. two scope exits lose their connection to each other, and
            #   2. a node inside a scope that fed only dead nodes loses its connection to the scope exit.
            # ``remove_memlet_path`` deals with (1); (2) is repaired here with empty memlets.
            if orphaned_exits:
                children_of = state.scope_children()
                for scope_exit in orphaned_exits:
                    scope_entry = state.entry_node(scope_exit)
                    for child in children_of[scope_entry]:
                        if child is scope_exit:
                            continue
                        # For a nested scope, the connection is made from its exit node
                        tail = state.exit_node(child) if isinstance(child, nodes.EntryNode) else child
                        # All successors dead: the node would be cut off from the scope exit
                        if all(succ in dead_nodes for succ in state.successors(tail)):
                            state.add_nedge(tail, scope_exit, Memlet())

            # ---------------------------------------------------------------
            # Removal
            # ---------------------------------------------------------------
            pruned_conns: Dict[nodes.NestedSDFG, Set[str]] = defaultdict(set)
            for node in dead_nodes:
                # Drop every memlet path that ends in the dead node
                for edge in state.in_edges(node):
                    for leaf in state.memlet_tree(edge).leaves():
                        # Remember nested SDFG outputs that fed the dead node, for connector pruning
                        if isinstance(leaf.src, nodes.NestedSDFG):
                            pruned_conns[leaf.src].add(leaf.src_conn)
                        state.remove_memlet_path(leaf)

                # Finally drop the node itself
                state.remove_node(node)

            removed[state].update(dead_nodes)

            # Access nodes left without any edge are removed too
            data_nodes = set(state.data_nodes())
            for node in data_nodes:
                if state.degree(node) != 0:
                    continue
                state.remove_node(node)
                removed[state].add(node)

            # If a recorded connector name is not among the nested SDFG's input connectors, the
            # nested data is made transient (dead dataflow elimination would then remove it
            # internally as necessary)
            for nsdfg_node, conns in pruned_conns.items():
                for conn in conns:
                    if conn not in nsdfg_node.in_connectors:
                        nsdfg_node.sdfg.arrays[conn].transient = True

            # Refresh this state's read set so that predecessor states can reuse it
            data_nodes -= removed[state]
            still_read = {n.data for n in data_nodes if state.out_degree(n) > 0}
            rw_sets[state] = (still_read, rw_sets[state][1])

        return removed or None