def auto_optimize(sdfg: dace.SDFG, cuda, apply_strict=False):
    """
    Automatically optimize ``sdfg`` in place.

    :param sdfg: the sdfg to optimize (inplace).
    :param cuda: whether to optimize for cuda.
    :param apply_strict: whether to apply strict transformations to the sdfg
                         after optimization.
    """
    expand_onnx_nodes(sdfg)

    target = dace.DeviceType.GPU if cuda else dace.DeviceType.CPU
    # MKL is currently broken, so keep it off the implementation list
    set_fast_implementations(sdfg, target, blocklist=["MKL"])

    if apply_strict:
        # there is a nondeterministic bug in redundant array that appears if
        # we don't apply inline first
        sdfg.apply_transformations_repeated(interstate.InlineSDFG)
        sdfg.apply_strict_transformations()
def forward(node: onnx_op.ONNXOp, state: SDFGState,
            sdfg: SDFG) -> typing.Union[Node, SDFG]:
    """Expand a Gemm ONNX node into a nested SDFG computing ``A @ B.T (+ C)``.

    Only the configuration ``alpha == 1.0 and beta == 1.0 and transA == 0
    and transB == 1`` is supported.

    :param node: the ONNX Gemm node to expand.
    :param state: the state containing ``node``.
    :param sdfg: the parent SDFG.
    :return: the nested SDFG implementing the node.
    :raises ValueError: if the node uses an unsupported Gemm configuration.
    """
    node.validate(sdfg, state)

    # FIX: was a bare `assert`, which is stripped under `python -O` and gives
    # an uninformative error; raise explicitly instead.
    if not (node.alpha == 1.0 and node.beta == 1.0 and node.transA == 0
            and node.transB == 1):
        raise ValueError(
            "Unsupported Gemm configuration: expected alpha == 1.0, "
            "beta == 1.0, transA == 0, transB == 1, got alpha={}, beta={}, "
            "transA={}, transB={}".format(node.alpha, node.beta, node.transA,
                                          node.transB))

    # the gemm libnode is broken for now, so we just do it manually
    if "C" in node.in_connectors:

        def prog(A, B, C, Y):
            Y[:] = A @ np.transpose(B) + C
    else:

        def prog(A, B, Y):
            Y[:] = A @ np.transpose(B)

    nested = program_for_node(prog, sdfg, state, node)
    nested.apply_strict_transformations()
    return nested
def vectorize(sdfg: dace.SDFG, par: str, ignored_conns: list = None):
    """Vectorize maps over parameter ``par`` in ``sdfg`` for SVE.

    :param sdfg: the SDFG to vectorize (modified in place).
    :param par: the map parameter name whose innermost maps get the
                ``SVE_Map`` schedule.
    :param ignored_conns: optional list of connector names to skip
                          (currently unused in this body — TODO confirm).
    :return: the modified ``sdfg``.
    :raises NotImplementedError: if the SDFG arrays mix data type sizes.
    """
    # FIX: mutable default argument (was `ignored_conns: list = []`), which
    # is shared across calls; use a None sentinel instead.
    if ignored_conns is None:
        ignored_conns = []

    input_bits = {sdfg.arrays[a].dtype.bytes * 8 for a in sdfg.arrays}
    if len(input_bits) > 1:
        raise NotImplementedError('Different data type sizes as inputs')
    input_bit_width = next(iter(input_bits))

    sdfg.apply_strict_transformations()

    # FIXME: Hardcoded for the demo machine (512 bits)
    # NOTE(review): true division yields a float (e.g. 512 / 32 == 16.0);
    # presumably SVE_LEN tolerates this — confirm, or switch to `//`.
    util.SVE_LEN.set(512 / input_bit_width)

    # Mark the target maps as SVE maps and vectorize their out connectors.
    for node, dfg in sdfg.all_nodes_recursive():
        if isinstance(node, dace.nodes.MapEntry):
            if node.params[-1] == par:
                node.schedule = dace.ScheduleType.SVE_Map
                for c in node.out_connectors:
                    edges = get_connector_edges(dfg, node, c, False)
                    vectorize_connector(sdfg, dfg, node, par, c, False)
                    for e in edges:
                        vectorize_connector(sdfg, dfg, e.dst, par, e.dst_conn,
                                            True)

    for edge, dfg in sdfg.all_edges_recursive():
        if not isinstance(dfg, dace.SDFGState):
            continue
        # Force every output connector within the graph to be a vector
        #if edge.data.wcr is None:
        #    continue
        scope = util.get_sve_scope(sdfg, dfg, edge.src)
        if scope is not None:
            vectorize_connector(sdfg, dfg, edge.src, par, edge.src_conn, False)

    # Then use a tweaked (but incorrect) version of infer_connector_types
    infer_connector_types(sdfg)
    return sdfg
def canonicalize_sdfg(sdfg: dace.SDFG, symbols=None):
    """Canonicalize ``sdfg`` in place: clean up subgraphs, fuse/nest K-loops,
    remove control-flow loops, specialize symbols, and predicate
    tasklet/stencil code.

    :param sdfg: the SDFG to canonicalize (modified in place).
    :param symbols: optional mapping of symbol names to values used for
                    specialization; if a dict is passed, it is updated in
                    place with ``sdfg.constants`` (pre-existing behavior).
    :return: the canonicalized ``sdfg``.
    :raises ValueError: if control-flow loops remain or symbols are missing.
    """
    # FIX: mutable default argument (was `symbols={}`). Because this function
    # calls `symbols.update(...)`, the shared default dict leaked state
    # across calls that omitted the argument. A passed-in dict is still
    # mutated, preserving caller-visible behavior.
    if symbols is None:
        symbols = {}

    # Clean up unnecessary subgraphs
    remove_scalar_transients(sdfg)
    remove_unused_sinks(sdfg)
    remove_constant_stencils(sdfg)
    split_condition_interstate_edges(sdfg)

    # Fuse and nest parallel K-loops
    sdfg.apply_transformations_repeated(MapFission, validate=False)
    standardize_data_layout(sdfg)
    sdfg.apply_transformations_repeated([NestK, InlineSDFG], validate=False)
    sdfg.apply_transformations_repeated([StencilFusion])

    # Remove loops
    loops_removed = sdfg.apply_transformations_repeated([RemoveLoop],
                                                        validate=False)
    if loops_removed > 0:
        raise ValueError("Control flow loops not supported.")
    from dace.transformation.interstate import StateFusion
    sdfg.apply_transformations_repeated(StateFusion)
    sdfg.apply_strict_transformations()

    # Specialize symbols and constants
    sdfg.specialize(symbols)
    symbols.update(sdfg.constants)
    undefined_symbols = sdfg.free_symbols
    if len(undefined_symbols) != 0:
        raise ValueError("Missing symbols: {}".format(
            ", ".join(undefined_symbols)))

    for node, _ in sdfg.all_nodes_recursive():
        if isinstance(node, stencil.Stencil):
            node.shape = _specialize_symbols(node.shape, symbols)
        if isinstance(node, dace.sdfg.nodes.MapEntry):
            ranges = []
            for r in node.map.range:
                ranges.append(_specialize_symbols(r, symbols))
            node.map.range = ranges
        # Make transformation passes on tasklets and stencil libnodes
        if hasattr(node, 'code'):
            new_code = [_Predicator().visit(stmt) for stmt in node.code.code]
            # min/max predication requires multiple passes (nested
            # expressions)
            minmax_predicated = 1
            while minmax_predicated > 0:
                pred = _MinMaxPredicator()
                tmp_code = [pred.visit(stmt) for stmt in new_code]
                minmax_predicated = pred.count
                # Some of the outputs may be lists, flatten
                new_code = []

                def flatten(val):
                    for v in val:
                        if isinstance(v, list):
                            flatten(v)
                        else:
                            new_code.append(v)

                flatten(tmp_code)
            node.code.code = new_code
    return sdfg