Example #1
0
    def apply(self, graph: SDFGState, sdfg: SDFG):
        """
        Nests the matched part of the state into its own SDFG and applies
        ``GPUTransformSDFG`` to that nested SDFG.

        :param graph: The state in which the match was found.
        :param sdfg: The SDFG that contains the state.
        """
        # Pick what to nest: the full scope of the matched map entry when
        # ``expr_index`` is 0, otherwise only the matched reduce node
        if self.expr_index == 0:
            to_nest = graph.scope_subgraph(self.map_entry)
        else:
            to_nest = SubgraphView(graph, [self.reduce])
        nested = helpers.nest_state_subgraph(sdfg,
                                             graph,
                                             to_nest,
                                             full_data=self.fullcopy)

        # Imported locally to avoid import loops
        from dace.transformation.interstate import GPUTransformSDFG
        gpu_xform = GPUTransformSDFG(sdfg, 0, -1, {}, 0)
        # Forward this transformation's options to the SDFG-wide transformation
        gpu_xform.register_trans = self.register_trans
        gpu_xform.sequential_innermaps = self.sequential_innermaps
        gpu_xform.toplevel_trans = self.toplevel_trans

        inner_sdfg = nested.sdfg
        gpu_xform.apply(inner_sdfg, inner_sdfg)

        # Inline back as necessary
        sdfg.simplify()
Example #2
0
    def dry_run(sdfg: SDFG, *args, **kwargs) -> Any:
        """
        Runs the SDFG once with data instrumentation enabled, unless an
        instrumented data report that is consistent with the given arguments
        already exists.

        :param sdfg: The SDFG to run.
        :param args: Positional arguments, matched by position to
                     ``sdfg.arg_names``.
        :param kwargs: Keyword arguments forwarded to the SDFG call.
        :return: The result of calling the SDFG, or None if an existing data
                 report was kept and the SDFG was not run.
        """
        # Fold positional arguments into the keyword arguments by name
        kwargs.update(zip(sdfg.arg_names, args))

        # Check existing instrumented data for shape mismatch
        dreport = sdfg.get_instrumented_data()
        if dreport is not None:
            for data in dreport.keys():
                rep_arr = dreport.get_first_version(data)
                sdfg_arr = sdfg.arrays[data]
                # Potential shape mismatch
                if rep_arr.shape != sdfg_arr.shape:
                    # Check given data first. Reported data that was not
                    # passed as an argument cannot be compared; ``get`` yields
                    # None (which has no shape) instead of raising KeyError
                    given = kwargs.get(data)
                    if hasattr(given, 'shape') and rep_arr.shape != given.shape:
                        sdfg.clear_data_reports()
                        dreport = None
                        break

        # A usable report already exists: nothing to run
        if dreport is not None:
            return None

        # No valid instrumented data available yet: run in data
        # instrumentation mode, saving all non-transient access nodes
        for state in sdfg.nodes():
            for node in state.nodes():
                if isinstance(node, dace.nodes.AccessNode) and not node.desc(sdfg).transient:
                    node.instrument = dace.DataInstrumentationType.Save

        try:
            result = sdfg(**kwargs)
        finally:
            # Disable data instrumentation from now on, even if the run
            # failed, so that the SDFG is not left in instrumentation mode
            for state in sdfg.nodes():
                for node in state.nodes():
                    if isinstance(node, dace.nodes.AccessNode):
                        node.instrument = dace.DataInstrumentationType.No_Instrumentation

        return result
Example #3
0
def fuse_states(sdfg: SDFG,
                strict: bool = True,
                progress: bool = False) -> int:
    """
    Applies the StateFusion transformation wherever possible in an SDFG and
    all of its sub-SDFGs, repeating until no more fusions can be applied.
    :param sdfg: The SDFG to transform.
    :param strict: Passed as the ``strict`` flag to the applicability check
                   (True by default).
    :param progress: If True, shows a ``tqdm`` progress bar whose total is
                     the number of inter-state edges (may be inaccurate,
                     requires ``tqdm``).
    :return: The total number of states fused.
    """
    from dace.transformation.interstate import StateFusion  # Avoid import loop

    pbar = None
    if progress:
        from tqdm import tqdm
        # Every inter-state edge is a potential fusion
        total_edges = sum(sd.number_of_edges()
                          for sd in sdfg.all_sdfgs_recursive())
        pbar = tqdm(total=total_edges)

    num_fused = 0
    for sd in sdfg.all_sdfgs_recursive():
        sdfg_id = sd.sdfg_id
        fused_in_pass = True
        while fused_in_pass:
            fused_in_pass = False
            # States that took part in a fusion during this pass; edges that
            # touch them are stale and are retried in the next pass
            touched = set()
            for first, second in list(sd.nx.edges):
                if first in touched or second in touched:
                    continue
                candidate = {
                    StateFusion.first_state: first,
                    StateFusion.second_state: second,
                }
                fusion = StateFusion(sdfg_id, -1, candidate, 0, override=True)
                if not fusion.can_be_applied(sd, candidate, 0, sd,
                                             strict=strict):
                    continue
                fusion.apply(sd)
                fused_in_pass = True
                num_fused += 1
                if pbar is not None:
                    pbar.update(1)
                touched.add(first)
                touched.add(second)

    if pbar is not None:
        pbar.close()
    if config.Config.get_bool('debugprint'):
        print(f'Applied {num_fused} State Fusions')
    return num_fused
Example #4
0
def fuse_states(sdfg: SDFG) -> int:
    """
    Applies the StateFusion transformation wherever possible in an SDFG and
    all of its sub-SDFGs, repeating until no more fusions can be applied.
    :param sdfg: The SDFG to transform.
    :return: The total number of states fused.
    """
    from dace.transformation.interstate import StateFusion  # Avoid import loop

    total_fused = 0
    for sd in sdfg.all_sdfgs_recursive():
        sdfg_id = sd.sdfg_id
        # Sweep over a snapshot of the edges until a sweep fuses nothing
        while True:
            fused_this_sweep = 0
            # States already involved in a fusion during this sweep
            consumed = set()
            for src_state, dst_state in list(sd.nx.edges):
                if src_state in consumed or dst_state in consumed:
                    continue
                candidate = {
                    StateFusion.first_state: src_state,
                    StateFusion.second_state: dst_state,
                }
                fusion = StateFusion(sdfg_id, -1, candidate, 0, override=True)
                if not fusion.can_be_applied(sd, candidate, 0, sd,
                                             strict=True):
                    continue
                fusion.apply(sd)
                fused_this_sweep += 1
                consumed.add(src_state)
                consumed.add(dst_state)
            if fused_this_sweep == 0:
                break
            total_fused += fused_this_sweep

    if config.Config.get_bool('debugprint'):
        print(f'Applied {total_fused} State Fusions')
    return total_fused
Example #5
0
    def apply(self, graph: SDFGState, sdfg: SDFG):
        """
        Removes the two matched input transposes of a product node and adds a
        single transpose on the product's output instead.

        The inputs of ``transpose_a`` are reconnected to the product's ``_b``
        connector and those of ``transpose_b`` to ``_a``; the product then
        writes to a new temporary transient that is transposed into the
        original destinations.

        :param graph: The state in which the match was found.
        :param sdfg: The SDFG that contains the state.
        """
        import dace.libraries.blas as blas

        product = self.a_times_b

        # Bypass each input transpose: its sources feed the product directly,
        # with the operand roles swapped
        for in_src, in_conn, _, _, in_memlet in graph.in_edges(self.transpose_a):
            graph.add_edge(in_src, in_conn, product, '_b', in_memlet)
        graph.remove_node(self.transpose_a)
        for in_src, in_conn, _, _, in_memlet in graph.in_edges(self.transpose_b):
            graph.add_edge(in_src, in_conn, product, '_a', in_memlet)
        graph.remove_node(self.transpose_b)
        graph.remove_node(self.at)
        graph.remove_node(self.bt)

        # The temporary has the swapped extents of the first output memlet
        for _, _, _, _, out_memlet in graph.out_edges(product):
            out_subset = dcpy(out_memlet.subset)
            out_subset.squeeze()
            extent = out_subset.size()
            shape = [extent[1], extent[0]]
            break
        tmp_name, tmp_arr = sdfg.add_temp_transient(shape, product.dtype)
        tmp_acc = graph.add_access(tmp_name)
        transpose_c = blas.Transpose('_Transpose_', product.dtype)

        # Redirect every former output of the product to come from the new
        # transpose node
        for out_edge in graph.out_edges(product):
            _, _, out_dst, out_dst_conn, out_memlet = out_edge
            graph.remove_edge(out_edge)
            graph.add_edge(transpose_c, '_out', out_dst, out_dst_conn,
                           out_memlet)

        # product -> temporary -> transpose
        graph.add_edge(product, '_c', tmp_acc, None,
                       dace.Memlet.from_array(tmp_name, tmp_arr))
        graph.add_edge(tmp_acc, None, transpose_c, '_inp',
                       dace.Memlet.from_array(tmp_name, tmp_arr))
Example #6
0
def structured_control_flow_tree(
        sdfg: SDFG, dispatch_state: Callable[[SDFGState], str]) -> ControlFlow:
    """
    Builds a structured control-flow tree (i.e., with constructs such as
    branches and loops) from an SDFG, which can be used to generate its code
    in a compiler- and human-friendly way.
    :param sdfg: The SDFG to iterate over.
    :param dispatch_state: A function that dispatches code generation for a
                           single state.
    :return: Control-flow block representing the entire SDFG.
    """
    # Avoid import loops
    from dace.sdfg.analysis import cfg

    # Parent tree, back-edges and dominance frontier of the state machine
    ptree = cfg.state_parent_tree(sdfg)
    back_edges = cfg.back_edges(sdfg)
    adf = cfg.acyclic_dominance_frontier(sdfg)

    # Map every branching state to the single state where its branches merge
    branch_merges: Dict[SDFGState, SDFGState] = {}
    for state in sdfg.nodes():
        oedges = sdfg.out_edges(state)
        # Zero or one successor: not a branch
        if len(oedges) <= 1:
            continue
        # Natural loop: exactly one of the two successors has this state as
        # its parent
        if len(oedges) == 2:
            first_is_child = ptree[oedges[0].dst] == state
            second_is_child = ptree[oedges[1].dst] == state
            if first_is_child != second_is_child:
                continue

        # Union the frontiers of all successors; a successor with an empty
        # frontier contributes itself
        merge_candidates = set()
        for oedge in oedges:
            merge_candidates |= adf[oedge.dst] or {oedge.dst}
        if len(merge_candidates) == 1:
            branch_merges[state] = next(iter(merge_candidates))

    root_block = GeneralBlock(dispatch_state, [], [])
    _structured_control_flow_traversal(sdfg, sdfg.start_state, ptree,
                                       branch_merges, back_edges,
                                       dispatch_state, root_block)
    return root_block
Example #7
0
def load_precompiled_sdfg(folder: str):
    """
    Loads a pre-compiled SDFG from an output folder (e.g. ".dacecache/program").
    The folder must contain a file called "program.sdfg" and a subfolder
    called "build" that holds the shared object.

    :param folder: Path to SDFG output folder.
    :return: A callable CompiledSDFG object.
    """
    from dace.codegen import compiled_sdfg as csdfg

    sdfg = SDFG.from_file(os.path.join(folder, 'program.sdfg'))

    # The library is named after the SDFG, with the configured extension
    lib_ext = config.Config.get('compiler', 'library_extension')
    lib_path = os.path.join(folder, 'build', f'lib{sdfg.name}.{lib_ext}')
    library = csdfg.ReloadableDLL(lib_path, sdfg.name)

    return csdfg.CompiledSDFG(sdfg, library)
Example #8
0
def inline_sdfgs(sdfg: SDFG,
                 strict: bool = True,
                 progress: bool = False) -> int:
    """
    Applies the InlineSDFG transformation to every nested SDFG node where it
    can be applied, visiting the SDFG hierarchy in reverse order.
    :param sdfg: The SDFG to transform.
    :param strict: Passed as the ``strict`` flag to the applicability check
                   (True by default).
    :param progress: If True, shows a ``tqdm`` progress bar of inlining (may
                     be inaccurate, requires ``tqdm``).
    :return: The total number of SDFGs inlined.
    """
    from dace.transformation.interstate import InlineSDFG  # Avoid import loop

    all_sdfgs = list(sdfg.all_sdfgs_recursive())
    pbar = None
    if progress:
        from tqdm import tqdm
        pbar = tqdm(total=len(all_sdfgs))

    num_inlined = 0
    for sd in reversed(all_sdfgs):
        sdfg_id = sd.sdfg_id
        for state_id, state in enumerate(sd.nodes()):
            for node in state.nodes():
                if not isinstance(node, NestedSDFG):
                    continue
                # Node IDs change as SDFGs are inlined, so look them up anew
                # for every candidate
                candidate = {InlineSDFG._nested_sdfg: state.node_id(node)}
                inliner = InlineSDFG(sdfg_id,
                                     state_id,
                                     candidate,
                                     0,
                                     override=True)
                if not inliner.can_be_applied(state,
                                              candidate,
                                              0,
                                              sd,
                                              strict=strict):
                    continue
                inliner.apply(sd)
                num_inlined += 1
                if pbar is not None:
                    pbar.update(1)

    if pbar is not None:
        pbar.close()
    if config.Config.get_bool('debugprint'):
        print(f'Inlined {num_inlined} SDFGs')
    return num_inlined
Example #9
0
def consolidate_edges(sdfg: SDFG, starting_scope=None) -> int:
    """
    Union scope-entering memlets relating to the same data node in all states.
    This effectively reduces the number of connectors and allows more
    transformations to be performed, at the cost of losing the individual
    per-tasklet memlets.
    :param sdfg: The SDFG to consolidate.
    :param starting_scope: If given, only this scope and its ancestors are
                           consolidated, in the state that contains its entry
                           node. Otherwise, all scopes of all states are
                           consolidated.
    :return: Number of edges removed.
    """
    from dace.sdfg.propagation import propagate_memlets_sdfg, propagate_memlets_scope

    removed = 0
    for state in sdfg.nodes():
        # Only the state that holds the starting scope is relevant
        if starting_scope and starting_scope.entry not in state.nodes():
            continue

        # Work bottom-up, one level of the scope tree at a time
        level = [starting_scope] if starting_scope else state.scope_leaves()
        while len(level) > 0:
            parents = []
            for scope in level:
                removed += consolidate_edges_scope(state, scope.entry)
                removed += consolidate_edges_scope(state, scope.exit)
                if scope.parent is not None:
                    parents.append(scope.parent)
            level = parents

        if starting_scope is not None:
            # Repropagate memlets from this scope outwards
            propagate_memlets_scope(sdfg, state, starting_scope)

            # The scope lives in a single state; the rest can be skipped
            break

    # Without a starting scope, repropagate memlets in the whole SDFG
    if starting_scope is None:
        propagate_memlets_sdfg(sdfg)

    return removed
Example #10
0
def get_next_nonempty_states(sdfg: SDFG, state: SDFGState) -> Set[SDFGState]:
    """
    From the given state, return the next set of states that are reachable
    in the SDFG, skipping empty states. Traversal continues only through
    empty states, so it stops at the first non-empty state on each path.
    This function is used to determine whether synchronization should happen
    at the end of a GPU state.
    :param sdfg: The SDFG that contains the state.
    :param state: The state to start from.
    :return: A set of reachable non-empty states.
    """
    reachable: Set[SDFGState] = set()

    # From every successor, keep descending for as long as the parent state
    # is empty
    for succ in sdfg.successors(state):
        reachable.update(
            dfs_conditional(sdfg,
                            sources=[succ],
                            condition=lambda parent, _: parent.is_empty()))

    # The traversal also yields the empty states it passed through; drop them
    return {s for s in reachable if not s.is_empty()}
Example #11
0
def consolidate_edges(sdfg: SDFG) -> int:
    """
    Union scope-entering memlets relating to the same data node in all states.
    This effectively reduces the number of connectors and allows more
    transformations to be performed, at the cost of losing the individual
    per-tasklet memlets.
    :param sdfg: The SDFG to consolidate.
    :return: Number of edges removed.
    """
    removed = 0
    for state in sdfg.nodes():
        # Work bottom-up, one level of the scope tree at a time, starting
        # from the leaf scopes
        level = state.scope_leaves()
        while len(level) > 0:
            parents = []
            for scope in level:
                removed += consolidate_edges_scope(state, scope.entry)
                removed += consolidate_edges_scope(state, scope.exit)
                if scope.parent is not None:
                    parents.append(scope.parent)
            level = parents

    return removed
Example #12
0
def _structured_control_flow_traversal(
        sdfg: SDFG,
        start: SDFGState,
        ptree: Dict[SDFGState, SDFGState],
        branch_merges: Dict[SDFGState, SDFGState],
        back_edges: List[Edge[InterstateEdge]],
        dispatch_state: Callable[[SDFGState], str],
        parent_block: GeneralBlock,
        stop: SDFGState = None,
        generate_children_of: SDFGState = None) -> Set[SDFGState]:
    """
    Helper function for ``structured_control_flow_tree``. Traverses the
    states reachable from ``start`` and appends the control-flow elements it
    recognizes (single states, branches, loops) to ``parent_block``,
    recursing into branch and loop bodies.
    :param sdfg: SDFG.
    :param start: Starting state for traversal.
    :param ptree: State parent tree (computed from ``state_parent_tree``).
    :param branch_merges: Dictionary mapping from branch state to its merge
                          state.
    :param back_edges: Back-edges of the SDFG state machine; used to tell the
                       increment edge of a loop guard apart from its init
                       edges.
    :param dispatch_state: A function that dispatches code generation for a
                           single state.
    :param parent_block: The block to append children to.
    :param stop: Stopping state to not traverse through (merge state of a
                 branch or guard state of a loop).
    :param generate_children_of: If given, states for which ``_child_of``
                                 (with this state and ``ptree``) is false are
                                 skipped.
    :return: The set of states visited by this traversal (including those
             visited by recursive calls), excluding ``stop``.
    """
    # Traverse states in custom order
    visited = set()
    if stop is not None:
        # Marking the stop state as visited prevents traversing through it
        visited.add(stop)
    stack = [start]
    while stack:
        node = stack.pop()
        if (generate_children_of is not None
                and not _child_of(node, generate_children_of, ptree)):
            continue
        if node in visited:
            continue
        visited.add(node)
        stateblock = SingleState(dispatch_state, node)

        oe = sdfg.out_edges(node)
        if len(oe) == 0:  # End state
            # If there are no remaining nodes, this is the last state and it can
            # be marked as such
            if len(stack) == 0:
                stateblock.last_state = True
            parent_block.elements.append(stateblock)
            continue
        elif len(oe) == 1:  # No traversal change
            stack.append(oe[0].dst)
            parent_block.elements.append(stateblock)
            continue

        # Potential branch or loop
        if node in branch_merges:
            mergestate = branch_merges[node]

            # Add branching node and ignore outgoing edges
            parent_block.elements.append(stateblock)
            parent_block.edges_to_ignore.extend(oe)
            stateblock.last_state = True

            # Parse all outgoing edges recursively first
            cblocks: Dict[Edge[InterstateEdge], GeneralBlock] = {}
            for branch in oe:
                cblocks[branch] = GeneralBlock(dispatch_state, [], [])
                # Each branch is traversed up to (not including) the merge
                # state and collected into its own block
                visited |= _structured_control_flow_traversal(
                    sdfg,
                    branch.dst,
                    ptree,
                    branch_merges,
                    back_edges,
                    dispatch_state,
                    cblocks[branch],
                    stop=mergestate,
                    generate_children_of=node)

            # Classify branch type:
            branch_block = None
            # If there are 2 out edges, one negation of the other:
            #   * if/else in case both branches are not merge state
            #   * if without else in case one branch is merge state
            if (len(oe) == 2 and oe[0].data.condition_sympy() == sp.Not(
                    oe[1].data.condition_sympy())):
                # If without else
                if oe[0].dst is mergestate:
                    branch_block = IfScope(dispatch_state, sdfg, node,
                                           oe[1].data.condition,
                                           cblocks[oe[1]])
                elif oe[1].dst is mergestate:
                    branch_block = IfScope(dispatch_state, sdfg, node,
                                           oe[0].data.condition,
                                           cblocks[oe[0]])
                else:
                    branch_block = IfScope(dispatch_state, sdfg, node,
                                           oe[0].data.condition,
                                           cblocks[oe[0]], cblocks[oe[1]])
            else:
                # If there are 2 or more edges (one is not the negation of the
                # other):
                switch = _cases_from_branches(oe, cblocks)
                if switch:
                    # If all edges are of form "x == y" for a single x and
                    # integer y, it is a switch/case
                    branch_block = SwitchCaseScope(dispatch_state, sdfg, node,
                                                   switch[0], switch[1])
                else:
                    # Otherwise, create if/else if/.../else goto exit chain
                    branch_block = IfElseChain(dispatch_state, sdfg, node,
                                               [(e.data.condition, cblocks[e])
                                                for e in oe])
            # End of branch classification
            parent_block.elements.append(branch_block)
            # Continue after the branch, unless the merge state is where this
            # traversal must stop
            if mergestate != stop:
                stack.append(mergestate)

        elif len(oe) == 2:  # Potential loop
            # TODO(later): Recognize do/while loops
            # If loop, traverse body, then exit
            body_start = None
            loop_exit = None
            scope = None
            # The out edge whose destination has this node as its parent
            # enters the loop body; the other edge leaves the loop
            if ptree[oe[0].dst] == node and ptree[oe[1].dst] != node:
                scope = _loop_from_structure(sdfg, node, oe[0], oe[1],
                                             back_edges, dispatch_state)
                body_start = oe[0].dst
                loop_exit = oe[1].dst
            elif ptree[oe[1].dst] == node and ptree[oe[0].dst] != node:
                scope = _loop_from_structure(sdfg, node, oe[1], oe[0],
                                             back_edges, dispatch_state)
                body_start = oe[1].dst
                loop_exit = oe[0].dst

            if scope:
                # Traverse the loop body, stopping when the guard is reached
                visited |= _structured_control_flow_traversal(
                    sdfg,
                    body_start,
                    ptree,
                    branch_merges,
                    back_edges,
                    dispatch_state,
                    scope.body,
                    stop=node,
                    generate_children_of=node)

                # Add branching node and ignore outgoing edges
                parent_block.elements.append(stateblock)
                parent_block.edges_to_ignore.extend(oe)

                parent_block.elements.append(scope)

                # If for loop, ignore certain edges
                if isinstance(scope, ForScope):
                    # Mark init edge(s) to ignore in parent_block and all children
                    _ignore_recursive([
                        e for e in sdfg.in_edges(node) if e not in back_edges
                    ], parent_block)
                    # Mark back edge for ignoring in all children of loop body
                    _ignore_recursive(
                        [e for e in sdfg.in_edges(node) if e in back_edges],
                        scope.body)

                stack.append(loop_exit)
                continue

            # No proper loop detected: Unstructured control flow
            parent_block.elements.append(stateblock)
            stack.extend([e.dst for e in oe])
        else:  # No merge state: Unstructured control flow
            parent_block.elements.append(stateblock)
            stack.extend([e.dst for e in oe])

    # The stop state was only added to block traversal; it was not visited
    return visited - {stop}
Example #13
0
def _loop_from_structure(
        sdfg: SDFG, guard: SDFGState, enter_edge: Edge[InterstateEdge],
        leave_edge: Edge[InterstateEdge],
        back_edges: List[Edge[InterstateEdge]],
        dispatch_state: Callable[[SDFGState],
                                 str]) -> Union[ForScope, WhileScope, None]:
    """
    Helper method that constructs the correct structured loop construct from a
    set of states. Can construct for or while loops.
    :param sdfg: The SDFG that contains the loop.
    :param guard: The guard state of the potential loop.
    :param enter_edge: The out edge of the guard that enters the loop body.
    :param leave_edge: The out edge of the guard that leaves the loop.
    :param back_edges: Back-edges of the SDFG state machine; an in-edge of the
                       guard that is a back-edge is treated as the increment
                       edge, all other in-edges as init edges.
    :param dispatch_state: A function that dispatches code generation for a
                           single state.
    :return: A ``ForScope`` or ``WhileScope`` with an empty body block, or
             None if the states do not form a supported loop structure.
    """

    body = GeneralBlock(dispatch_state, [], [])

    guard_inedges = sdfg.in_edges(guard)
    increment_edges = [e for e in guard_inedges if e in back_edges]
    init_edges = [e for e in guard_inedges if e not in back_edges]

    # If no back edge found (or more than one, indicating a "continue"
    # statement), disregard
    if len(increment_edges) != 1:
        return None
    increment_edge = increment_edges[0]

    # Mark increment edge to be ignored in body
    body.edges_to_ignore.append(increment_edge)

    # Outgoing edges must be a negation of each other
    if enter_edge.data.condition_sympy() != (sp.Not(
            leave_edge.data.condition_sympy())):
        return None

    # Body of guard state must be empty
    if not guard.is_empty():
        return None

    if not increment_edge.data.is_unconditional():
        return None
    if len(enter_edge.data.assignments) > 0:
        return None

    condition = enter_edge.data.condition

    # Detect whether this loop is a for loop:
    # All incoming edges to the guard must set the same variable
    itvars = None
    for iedge in guard_inedges:
        if itvars is None:
            itvars = set(iedge.data.assignments.keys())
        else:
            itvars &= iedge.data.assignments.keys()
    # A for loop also needs at least one init edge to take the initial value
    # from. Without one (the back edge is the guard's only in-edge), indexing
    # ``init_edges[0]`` would raise IndexError, so such loops are treated as
    # while loops instead.
    if itvars and len(itvars) == 1 and init_edges:
        itvar = next(iter(itvars))
        init = init_edges[0].data.assignments[itvar]

        # Check that all init edges are the same and that increment edge only
        # increments
        if (all(e.data.assignments[itvar] == init for e in init_edges)
                and len(increment_edge.data.assignments) == 1):
            update = increment_edge.data.assignments[itvar]
            return ForScope(dispatch_state, itvar, guard, init, condition,
                            update, body)

    # Otherwise, it is a while loop
    return WhileScope(dispatch_state, guard, condition, body)
Example #14
0
    def apply(self, graph: SDFGState, sdfg: SDFG):
        """
        Inserts a new access node on an edge between ``node_a`` and
        ``node_b``, so that the data moved over that edge goes through a
        separate (by default newly created, transient) array.

        :param graph: The state in which the match was found.
        :param sdfg: The SDFG that contains the state.
        :return: The newly created access node.
        :raises NameError: If there is no edge between the two nodes.
        """
        node_a = self.node_a
        node_b = self.node_b
        prefix = self.prefix

        # Determine direction of new memlet
        propagate_forward = sd.scope_contains_scope(graph.scope_dict(), node_a,
                                                    node_b)

        # Without a user-given array, use the first edge that carries data
        # and has no write-conflict resolution
        array = self.array
        if array is None or len(array) == 0:
            array = next(e.data.data
                         for e in graph.edges_between(node_a, node_b)
                         if e.data.data is not None and e.data.wcr is None)

        # Find the edge that moves the requested array
        original_edge = next((e for e in graph.edges_between(node_a, node_b)
                              if array == e.data.data), None)
        if original_edge is None:
            # Not found: fall back to the first edge between the nodes
            original_edge = next(iter(graph.edges_between(node_a, node_b)),
                                 None)
            if original_edge is None:
                raise NameError('Array %s not found!' % array)
            warnings.warn('Array %s not found! Using array %s instead.' %
                          (array, original_edge.data.data))
            array = original_edge.data.data
        invariant_memlet = original_edge.data

        if self.create_array:
            # Add transient array, sized by the bounding box of the memlet
            local_shape = [
                symbolic.overapproximate(r).simplify()
                for r in invariant_memlet.bounding_box_size()
            ]
            new_data, _ = sdfg.add_transient(
                name=prefix + invariant_memlet.data,
                shape=local_shape,
                dtype=sdfg.arrays[invariant_memlet.data].dtype,
                find_new_name=True)
        else:
            new_data = prefix + invariant_memlet.data
        data_node = nodes.AccessNode(new_data)
        # Store as fields so that other transformations can use them
        self._local_name = new_data
        self._data_node = data_node

        to_data_mm = copy.deepcopy(invariant_memlet)
        from_data_mm = copy.deepcopy(invariant_memlet)
        # Start index of each range of the original subset
        offset = subsets.Indices([r[0] for r in invariant_memlet.subset])

        # Reconnect, assuming one edge to the access node
        graph.remove_edge(original_edge)
        edge_in = graph.add_edge(node_a, original_edge.src_conn, data_node,
                                 None, to_data_mm)
        edge_out = graph.add_edge(data_node, None, node_b,
                                  original_edge.dst_conn, from_data_mm)
        # The edge on the inner side of the new node is the one whose memlet
        # tree has to refer to the new array
        new_edge = edge_out if propagate_forward else edge_in

        # Offset all edges in the memlet tree (including the new edge)
        for tree_edge in graph.memlet_tree(new_edge):
            tree_edge.data.subset.offset(offset, True)
            tree_edge.data.data = new_data

        return data_node
Example #15
0
def inline_sdfgs(sdfg: SDFG,
                 permissive: bool = False,
                 progress: bool = None,
                 multistate: bool = True) -> int:
    """
    Inlines all possible nested SDFGs (or sub-SDFGs) using an optimized
    routine that uses the structure of the SDFG hierarchy.
    :param sdfg: The SDFG to transform.
    :param permissive: If True, operates in permissive mode, which ignores some
                       checks.
    :param progress: If True, prints out a progress bar of inlining (may be
                     inaccurate, requires ``tqdm``). If None, prints out
                     progress if over 5 seconds have passed. If False, never
                     shows progress bar. If ``tqdm`` is not installed, no
                     progress bar is shown (with a warning if True was given).
    :param multistate: If True (default), first tries to inline each nested
                       SDFG with ``InlineMultistateSDFG`` and falls back to
                       ``InlineSDFG`` if that cannot be applied. If False,
                       only ``InlineSDFG`` is used.
    :return: The total number of SDFGs inlined.
    """
    # Avoid import loops
    from dace.transformation.interstate import InlineSDFG, InlineMultistateSDFG

    if progress is True or progress is None:
        try:
            from tqdm import tqdm
        except ImportError:
            # The progress bar is purely cosmetic: continue without it rather
            # than failing later when trying to call a missing ``tqdm``
            if progress is True:
                import warnings
                warnings.warn('progress=True requires tqdm, which is not '
                              'installed; continuing without a progress bar')
            progress = False

    # Transformations to try on every nested SDFG node, in order, together
    # with the pattern node that each of them matches on
    inliners = []
    if multistate:
        inliners.append((InlineMultistateSDFG,
                         InlineMultistateSDFG.nested_sdfg))
    inliners.append((InlineSDFG, InlineSDFG._nested_sdfg))

    counter = 0
    sdfgs = list(sdfg.all_sdfgs_recursive())
    pbar = None
    if progress is True:
        pbar = tqdm(total=len(sdfgs), desc='Inlining SDFGs')

    start = time.time()

    for sd in reversed(sdfgs):
        sdfg_id = sd.sdfg_id
        for state in sd.nodes():
            for node in state.nodes():
                # In automatic mode, only show the bar once inlining has
                # taken more than 5 seconds
                if progress is None and (time.time() - start) > 5:
                    progress = True
                    pbar = tqdm(total=len(sdfgs),
                                desc='Inlining SDFG',
                                initial=counter)

                if not isinstance(node, NestedSDFG):
                    continue
                # We have to reevaluate every time due to changing IDs
                node_id = state.node_id(node)
                state_id = sd.node_id(state)
                for xform_type, pattern_node in inliners:
                    candidate = {pattern_node: node_id}
                    inliner = xform_type(sdfg_id,
                                         state_id,
                                         candidate,
                                         0,
                                         override=True)
                    if inliner.can_be_applied(state,
                                              candidate,
                                              0,
                                              sd,
                                              permissive=permissive):
                        inliner.apply(sd)
                        counter += 1
                        if pbar is not None:
                            pbar.update(1)
                        # Inlined: do not try the remaining transformations
                        break
    if pbar is not None:
        pbar.close()
    if config.Config.get_bool('debugprint') and counter > 0:
        print(f'Inlined {counter} SDFGs')
    return counter
Example #16
0
def fuse_states(sdfg: SDFG,
                permissive: bool = False,
                progress: bool = None) -> int:
    """
    Fuses all possible states of an SDFG (and all sub-SDFGs) using an optimized
    routine that uses the structure of the StateFusion transformation.
    :param sdfg: The SDFG to transform.
    :param permissive: If True, operates in permissive mode, which ignores some
                       race condition checks.
    :param progress: If True, prints out a progress bar of fusion (may be
                     inaccurate, requires ``tqdm``). If None, prints out
                     progress if over 5 seconds have passed. If False, never
                     shows progress bar. If ``tqdm`` is not installed, no
                     progress bar is shown (with a warning if True was given).
    :return: The total number of states fused.
    """
    from dace.transformation.interstate import StateFusion  # Avoid import loop

    if progress is True or progress is None:
        try:
            from tqdm import tqdm
        except ImportError:
            # The progress bar is purely cosmetic: continue without it rather
            # than failing later when trying to call a missing ``tqdm``
            if progress is True:
                import warnings
                warnings.warn('progress=True requires tqdm, which is not '
                              'installed; continuing without a progress bar')
            progress = False

    counter = 0
    fusible_states = 0
    if progress is True or progress is None:
        # Every inter-state edge is a potential fusion
        fusible_states = sum(sd.number_of_edges()
                             for sd in sdfg.all_sdfgs_recursive())

    pbar = None
    if progress is True:
        pbar = tqdm(total=fusible_states, desc='Fusing states')

    start = time.time()

    for sd in sdfg.all_sdfgs_recursive():
        sdfg_id = sd.sdfg_id
        while True:
            edges = list(sd.nx.edges)
            applied = 0
            # States that took part in a fusion during this pass; edges that
            # touch them are stale and are retried in the next pass
            skip_nodes = set()
            for u, v in edges:
                # In automatic mode, only show the bar once fusion has taken
                # more than 5 seconds
                if progress is None and (time.time() - start) > 5:
                    progress = True
                    pbar = tqdm(total=fusible_states,
                                desc='Fusing states',
                                initial=counter)

                if u in skip_nodes or v in skip_nodes:
                    continue
                candidate = {
                    StateFusion.first_state: u,
                    StateFusion.second_state: v
                }
                sf = StateFusion(sdfg_id, -1, candidate, 0, override=True)
                if sf.can_be_applied(sd,
                                     candidate,
                                     0,
                                     sd,
                                     permissive=permissive):
                    sf.apply(sd)
                    applied += 1
                    counter += 1
                    if pbar is not None:
                        pbar.update(1)
                    skip_nodes.add(u)
                    skip_nodes.add(v)
            if applied == 0:
                break
    if pbar is not None:
        pbar.close()
    if config.Config.get_bool('debugprint') and counter > 0:
        print(f'Applied {counter} State Fusions')
    return counter