def _get_codegen_gemm_opts(node, state, sdfg, adesc, bdesc, cdesc, alpha, beta, cdtype, func) -> Dict[str, Any]:
    """ Get option map for GEMM code generation (with column-major order).

    NOTE(review): another function with this exact name is defined later in this
    file; at import time the later definition replaces this one. Confirm which
    version is intended to be live.

    :param node: Library node being expanded. Optional boolean attributes
                 ``transA`` / ``transB`` mark the left / right operand as
                 transposed (missing attributes are treated as ``False``).
    :param state: State containing ``node``; forwarded to ``_get_matmul_operands``.
    :param sdfg: SDFG containing ``state``; forwarded to ``_get_matmul_operands``.
    :param adesc: Data descriptor of the left operand; its ``dtype`` is stored
                  under ``'xdtype'`` (``'ydtype'`` after an operand swap).
    :param bdesc: Data descriptor of the right operand; its ``dtype`` is stored
                  under ``'ydtype'`` (``'xdtype'`` after an operand swap).
    :param cdesc: Data descriptor of the output; its ``dtype`` is stored under
                  ``'cdtype'``.
    :param alpha: Stored verbatim under ``'alpha'``.
    :param beta: Stored verbatim under ``'beta'``.
    :param cdtype: Stored verbatim under ``'dtype'``.
    :param func: Stored verbatim under ``'func'``.
    :return: The dictionary produced by ``get_gemm_opts`` extended with operand
             names (``x``/``y``), operand dtypes, the sizes ``M``/``N``/``K`` and
             leading dimensions ``lda``/``ldb``/``ldc`` rendered via ``sym2cpp``,
             and ``stride_a``/``stride_b``/``stride_c``/``BATCH`` when
             ``_get_batchmm_opts`` returns options (otherwise ``BATCH`` is ``None``).
    """
    # Avoid import loops
    from dace.codegen.targets.common import sym2cpp
    from dace.libraries.blas.blas_helpers import get_gemm_opts

    def _swap_matrix_axes(values):
        # A transpose exchanges only the two trailing (row/column) entries; any
        # leading batch entries keep their position. For rank <= 2 this is
        # exactly the same as reversing the whole list.
        swapped = list(values)
        if len(swapped) >= 2:
            swapped[-2], swapped[-1] = swapped[-1], swapped[-2]
        return swapped

    (_, _, ashape, astride), (_, _, bshape, bstride), (_, _, cshape, cstride) = _get_matmul_operands(node, state, sdfg)

    # Present transposed operands to the option helpers as their transposed views.
    if getattr(node, 'transA', False):
        ashape = _swap_matrix_axes(ashape)
        astride = _swap_matrix_axes(astride)
    if getattr(node, 'transB', False):
        bshape = _swap_matrix_axes(bshape)
        bstride = _swap_matrix_axes(bstride)

    opt = get_gemm_opts(astride, bstride, cstride)
    bopt = _get_batchmm_opts(ashape, astride, bshape, bstride, cshape, cstride)

    opt['x'] = '_a'
    opt['y'] = '_b'
    opt['xdtype'] = adesc.dtype
    opt['ydtype'] = bdesc.dtype
    opt['cdtype'] = cdesc.dtype
    opt['M'] = sym2cpp(ashape[-2])  # rows of (possibly transposed) A
    opt['N'] = sym2cpp(bshape[-1])  # columns of (possibly transposed) B
    opt['K'] = sym2cpp(ashape[-1])  # inner dimension shared by A and B
    opt['lda'] = sym2cpp(opt['lda'])
    opt['ldb'] = sym2cpp(opt['ldb'])
    opt['ldc'] = sym2cpp(opt['ldc'])

    if opt['swap']:
        # get_gemm_opts requested that the operands be exchanged; swap every
        # per-operand entry together so A- and B-specific values stay paired.
        if bopt:
            bopt['sa'], bopt['sb'] = bopt['sb'], bopt['sa']
        opt['lda'], opt['ldb'] = opt['ldb'], opt['lda']
        opt['x'], opt['y'] = opt['y'], opt['x']
        opt['xdtype'], opt['ydtype'] = opt['ydtype'], opt['xdtype']
        opt['ta'], opt['tb'] = opt['tb'], opt['ta']
        opt['M'], opt['N'] = opt['N'], opt['M']

    opt['alpha'] = alpha
    opt['beta'] = beta
    opt['dtype'] = cdtype
    opt['func'] = func

    if bopt:
        opt['stride_a'] = sym2cpp(bopt['sa'])
        opt['stride_b'] = sym2cpp(bopt['sb'])
        opt['stride_c'] = sym2cpp(bopt['sc'])
        opt['BATCH'] = sym2cpp(bopt['b'])
    else:
        opt['BATCH'] = None

    return opt
def _get_codegen_gemm_opts(node, state, sdfg, adesc, bdesc, cdesc, alpha, beta, cdtype, func) -> Dict[str, Any]:
    """
    Build the option dictionary consumed by GEMM code generation (column-major order).

    NOTE(review): an earlier function in this file has the same name; because this
    definition comes later, it is the one bound at import time. Confirm intent.

    :param node: Library node being expanded; forwarded to ``_get_matmul_inputs``.
    :param state: State containing ``node``; forwarded to ``_get_matmul_inputs``.
    :param sdfg: SDFG containing ``state``; forwarded to ``_get_matmul_inputs``.
    :param adesc: Unused by this implementation (kept for interface compatibility).
    :param bdesc: Unused by this implementation (kept for interface compatibility).
    :param cdesc: Output descriptor; its ``strides`` and ``shape`` feed the GEMM and
                  batched-GEMM option helpers.
    :param alpha: Stored verbatim under ``'alpha'``.
    :param beta: Stored verbatim under ``'beta'``.
    :param cdtype: Stored verbatim under ``'dtype'``.
    :param func: Stored verbatim under ``'func'``.
    :return: The dictionary from ``get_gemm_opts`` augmented with operand names,
             ``M``/``N``/``K`` and ``lda``/``ldb``/``ldc`` as C++ expressions,
             and batch strides plus ``BATCH`` (``None`` when ``get_batchmm_opts``
             yields nothing).
    """
    # Local import keeps module import order free of cycles
    from dace.codegen.targets.common import sym2cpp

    left, right = _get_matmul_inputs(node, state, sdfg)
    _, _, a_shape, a_strides = left
    _, _, b_shape, b_strides = right

    options = get_gemm_opts(a_strides, b_strides, cdesc.strides)
    batch = get_batchmm_opts(a_shape, a_strides, b_shape, b_strides, cdesc.shape, cdesc.strides)

    options.update(
        x='_a',
        y='_b',
        M=sym2cpp(a_shape[-2]),
        N=sym2cpp(b_shape[-1]),
        K=sym2cpp(a_shape[-1]),
    )
    for lead_key in ('lda', 'ldb', 'ldc'):
        options[lead_key] = sym2cpp(options[lead_key])

    if options['swap']:
        # Operand order is exchanged: every A/B-paired entry trades places.
        if batch:
            batch['sa'], batch['sb'] = batch['sb'], batch['sa']
        for first, second in (('lda', 'ldb'), ('x', 'y'), ('ta', 'tb'), ('M', 'N')):
            options[first], options[second] = options[second], options[first]

    options.update(alpha=alpha, beta=beta, dtype=cdtype, func=func)

    if batch:
        options.update(
            stride_a=sym2cpp(batch['sa']),
            stride_b=sym2cpp(batch['sb']),
            stride_c=sym2cpp(batch['sc']),
        )
    options['BATCH'] = sym2cpp(batch['b']) if batch else None
    return options