Ejemplo n.º 1
0
def auto_optimize(sdfg: dace.SDFG, cuda, apply_strict=False):
    """ Optimize ``sdfg`` in place.

        Runs ``expand_onnx_nodes`` on the sdfg, then selects fast library
        implementations for the target device (never MKL). Optionally it
        finishes with inlining and the strict transformations.

        :param sdfg: the sdfg to optimize; it is modified in place.
        :param cuda: if truthy, target ``dace.DeviceType.GPU``; otherwise
                     target ``dace.DeviceType.CPU``.
        :param apply_strict: if ``True``, inline nested SDFGs and then apply
                             the strict transformations after optimization.
    """
    target_device = dace.DeviceType.GPU if cuda else dace.DeviceType.CPU

    expand_onnx_nodes(sdfg)

    # The MKL implementations are currently broken, so exclude them from the
    # implementations that may be selected.
    set_fast_implementations(sdfg, target_device, blocklist=["MKL"])

    if not apply_strict:
        return

    # Inlining first avoids a nondeterministic bug in redundant-array
    # removal that appears when strict transformations run on un-inlined SDFGs.
    sdfg.apply_transformations_repeated(interstate.InlineSDFG)
    sdfg.apply_strict_transformations()
Ejemplo n.º 2
0
def canonicalize_sdfg(sdfg: dace.SDFG, symbols=None):
    """ Canonicalize ``sdfg`` in place and specialize its symbols.

        Cleans up unnecessary subgraphs, fuses and nests parallel K-loops,
        rejects control-flow loops, fuses states, and specializes symbols and
        constants. It then runs the predication passes over every node that
        carries ``code``.

        :param sdfg: the sdfg to canonicalize (modified in place).
        :param symbols: mapping used to specialize the sdfg's symbols. If a
                        dict is passed, it is updated in place with the
                        sdfg's constants. If omitted, a fresh empty dict is
                        used for this call only.
        :return: the same ``sdfg`` object.
        :raises ValueError: if ``RemoveLoop`` applies at least once (control
                            flow loops are not supported), or if free symbols
                            remain after specialization.
    """
    # A fresh dict per call: the previous ``symbols={}`` default was mutated
    # below (``symbols.update``) and therefore leaked constants across calls.
    if symbols is None:
        symbols = {}

    # Clean up unnecessary subgraphs
    remove_scalar_transients(sdfg)
    remove_unused_sinks(sdfg)
    remove_constant_stencils(sdfg)
    split_condition_interstate_edges(sdfg)

    # Fuse and nest parallel K-loops
    sdfg.apply_transformations_repeated(MapFission, validate=False)
    standardize_data_layout(sdfg)
    sdfg.apply_transformations_repeated([NestK, InlineSDFG], validate=False)
    sdfg.apply_transformations_repeated([StencilFusion])

    # Remove loops; any successful application means the input had
    # control-flow loops, which are not supported.
    loops_removed = sdfg.apply_transformations_repeated([RemoveLoop],
                                                        validate=False)
    if loops_removed > 0:
        raise ValueError("Control flow loops not supported.")

    from dace.transformation.interstate import StateFusion
    sdfg.apply_transformations_repeated(StateFusion)
    sdfg.apply_strict_transformations()

    # Specialize symbols and constants
    sdfg.specialize(symbols)
    symbols.update(sdfg.constants)
    undefined_symbols = sdfg.free_symbols
    if undefined_symbols:
        # Sorted so the error message is deterministic across runs.
        raise ValueError("Missing symbols: {}".format(
            ", ".join(sorted(undefined_symbols))))

    for node, _ in sdfg.all_nodes_recursive():
        if isinstance(node, stencil.Stencil):
            node.shape = _specialize_symbols(node.shape, symbols)
        if isinstance(node, dace.sdfg.nodes.MapEntry):
            node.map.range = [
                _specialize_symbols(r, symbols) for r in node.map.range
            ]

        # Make transformation passes on tasklets and stencil libnodes
        if hasattr(node, 'code'):
            node.code.code = _predicate_statements(node.code.code)

    return sdfg


def _predicate_statements(stmts):
    """Return ``stmts`` after predication and repeated min/max predication."""
    new_code = [_Predicator().visit(stmt) for stmt in stmts]

    # min/max predication requires multiple passes (nested expressions):
    # keep running passes until one reports no rewrites (count <= 0).
    while True:
        pred = _MinMaxPredicator()
        tmp_code = [pred.visit(stmt) for stmt in new_code]
        # Some of the outputs may be (nested) lists; flatten them.
        new_code = list(_flatten_nested_lists(tmp_code))
        if pred.count <= 0:
            return new_code


def _flatten_nested_lists(items):
    """Yield the non-list elements of ``items``, recursing into nested lists."""
    for item in items:
        if isinstance(item, list):
            yield from _flatten_nested_lists(item)
        else:
            yield item