def auto_optimize(sdfg: dace.SDFG, cuda, apply_strict=False):
    """ Run the automatic optimization pipeline on ``sdfg``.

        :param sdfg: the sdfg to optimize (inplace).
        :param cuda: whether to optimize for cuda; selects the GPU device
                     type when truthy and the CPU device type otherwise.
        :param apply_strict: whether to apply strict transformations to the
                             sdfg after optimization.
    """
    expand_onnx_nodes(sdfg)

    if cuda:
        target_device = dace.DeviceType.GPU
    else:
        target_device = dace.DeviceType.CPU

    # MKL is currently broken, so it is kept on the blocklist
    set_fast_implementations(sdfg, target_device, blocklist=["MKL"])

    if not apply_strict:
        return

    # there is a nondeterministic bug in redundant array that appears if
    # we don't apply inline first
    sdfg.apply_transformations_repeated(interstate.InlineSDFG)
    sdfg.apply_strict_transformations()
def canonicalize_sdfg(sdfg: dace.SDFG, symbols=None):
    """
    Bring ``sdfg`` into canonical form and specialize its symbols.

    Runs the cleanup, fission/fusion and nesting passes, rejects SDFGs on
    which ``RemoveLoop`` applies, specializes ``sdfg`` with ``symbols`` and
    finally rewrites the code of every node that has a ``code`` attribute
    using the predication visitors.

    :param sdfg: the SDFG to canonicalize (modified in place).
    :param symbols: mapping used to specialize the SDFG. A mapping passed by
                    the caller is updated in place with ``sdfg.constants``.
                    When omitted, a fresh empty dict is used for this call.
    :return: the same ``sdfg`` object that was passed in.
    :raises ValueError: if ``RemoveLoop`` was applied at least once, or if
                        the SDFG still has free symbols after
                        specialization.
    """
    # A ``{}`` default would be created once and shared between calls; the
    # ``update`` below would then leak constants from one SDFG into the next.
    if symbols is None:
        symbols = {}

    def flatten(items, out):
        # Append the leaves of arbitrarily nested lists to ``out``, in order
        for item in items:
            if isinstance(item, list):
                flatten(item, out)
            else:
                out.append(item)
        return out

    # Clean up unnecessary subgraphs
    remove_scalar_transients(sdfg)
    remove_unused_sinks(sdfg)
    remove_constant_stencils(sdfg)
    split_condition_interstate_edges(sdfg)

    # Fuse and nest parallel K-loops
    sdfg.apply_transformations_repeated(MapFission, validate=False)
    standardize_data_layout(sdfg)
    sdfg.apply_transformations_repeated([NestK, InlineSDFG], validate=False)
    sdfg.apply_transformations_repeated([StencilFusion])

    # Remove loops
    loops_removed = sdfg.apply_transformations_repeated([RemoveLoop],
                                                        validate=False)
    if loops_removed > 0:
        raise ValueError("Control flow loops not supported.")

    from dace.transformation.interstate import StateFusion
    sdfg.apply_transformations_repeated(StateFusion)
    sdfg.apply_strict_transformations()

    # Specialize symbols and constants
    sdfg.specialize(symbols)
    symbols.update(sdfg.constants)
    undefined_symbols = sdfg.free_symbols
    if undefined_symbols:
        raise ValueError("Missing symbols: {}".format(
            ", ".join(undefined_symbols)))

    for node, _ in sdfg.all_nodes_recursive():
        if isinstance(node, stencil.Stencil):
            node.shape = _specialize_symbols(node.shape, symbols)
        if isinstance(node, dace.sdfg.nodes.MapEntry):
            node.map.range = [
                _specialize_symbols(r, symbols) for r in node.map.range
            ]

        # Make transformation passes on tasklets and stencil libnodes
        if hasattr(node, 'code'):
            new_code = [_Predicator().visit(stmt) for stmt in node.code.code]

            # min/max predication requires multiple passes (nested
            # expressions): repeat until a pass predicates nothing
            minmax_predicated = 1
            while minmax_predicated > 0:
                pred = _MinMaxPredicator()
                visited = [pred.visit(stmt) for stmt in new_code]
                minmax_predicated = pred.count

                # Some of the outputs may be lists, flatten
                new_code = flatten(visited, [])

            node.code.code = new_code

    return sdfg