Example #1
0
def _get_codegen_gemm_opts(node, state, sdfg, adesc, bdesc, cdesc, alpha, beta,
                           cdtype, func) -> Dict[str, Any]:
    """
    Assemble the option map used by GEMM code generation (column-major order).

    Inputs come from these sources:

    - Shapes and strides of A, B and C come from ``_get_matmul_operands``.
      If ``node`` has a truthy ``transA`` (``transB``) attribute, the shape
      and stride lists of A (B) are reversed first.
    - Leading dimensions and the ``swap``/``ta``/``tb`` flags come from
      ``get_gemm_opts``.
    - Batch strides come from ``_get_batchmm_opts``.

    When ``opt['swap']`` is truthy, the two inputs exchange roles. This
    swaps their pointer names, dtypes, leading dimensions, transpose flags,
    the A/B batch strides, and the M/N extents. K is left unchanged.

    :return: The option dictionary. ``M``, ``N``, ``K``, ``lda``, ``ldb``,
             ``ldc`` and any batch strides are converted with ``sym2cpp``.
             ``BATCH`` is ``None`` when ``_get_batchmm_opts`` returns a
             falsy value.
    """
    # Imported here rather than at module level to avoid import loops
    from dace.codegen.targets.common import sym2cpp
    from dace.libraries.blas.blas_helpers import get_gemm_opts

    ((_, _, a_shape, a_strides),
     (_, _, b_shape, b_strides),
     (_, _, c_shape, c_strides)) = _get_matmul_operands(node, state, sdfg)

    # A transposed operand is described by its reversed shape and strides
    if getattr(node, 'transA', False):
        a_shape, a_strides = list(reversed(a_shape)), list(reversed(a_strides))
    if getattr(node, 'transB', False):
        b_shape, b_strides = list(reversed(b_shape)), list(reversed(b_strides))

    opt = get_gemm_opts(a_strides, b_strides, c_strides)
    batch = _get_batchmm_opts(a_shape, a_strides, b_shape, b_strides, c_shape,
                              c_strides)

    opt.update(
        x='_a',
        y='_b',
        xdtype=adesc.dtype,
        ydtype=bdesc.dtype,
        cdtype=cdesc.dtype,
        M=sym2cpp(a_shape[-2]),
        N=sym2cpp(b_shape[-1]),
        K=sym2cpp(a_shape[-1]),
    )
    for ld_key in ('lda', 'ldb', 'ldc'):
        opt[ld_key] = sym2cpp(opt[ld_key])

    if opt['swap']:
        # The inputs exchange roles: every A/B-paired option trades places,
        # as do the M and N extents.
        if batch:
            batch['sa'], batch['sb'] = batch['sb'], batch['sa']
        for left, right in (('lda', 'ldb'), ('x', 'y'), ('xdtype', 'ydtype'),
                            ('ta', 'tb'), ('M', 'N')):
            opt[left], opt[right] = opt[right], opt[left]

    opt.update(alpha=alpha, beta=beta, dtype=cdtype, func=func)
    if not batch:
        opt['BATCH'] = None
    else:
        opt.update(
            stride_a=sym2cpp(batch['sa']),
            stride_b=sym2cpp(batch['sb']),
            stride_c=sym2cpp(batch['sc']),
            BATCH=sym2cpp(batch['b']),
        )
    return opt
Example #2
0
def _get_codegen_gemm_opts(node, state, sdfg, adesc, bdesc, cdesc, alpha, beta,
                           cdtype, func) -> Dict[str, Any]:
    """ Get option map for GEMM code generation (with column-major order).

    Inputs come from these sources:

    - Shapes and strides of A and B come from ``_get_matmul_inputs``.
    - The shape and strides of C come from ``cdesc``.
    - Leading dimensions and the ``swap``/``ta``/``tb`` flags are produced
      by ``get_gemm_opts``.
    - Batch strides are produced by ``get_batchmm_opts``.

    When ``opt['swap']`` is truthy, the A/B-related options exchange places.
    These are the pointer names, element types, leading dimensions,
    transpose flags and batch strides; M and N are swapped too. K is
    unchanged.

    :param adesc: Descriptor of input A; its ``dtype`` is stored as ``xdtype``.
    :param bdesc: Descriptor of input B; its ``dtype`` is stored as ``ydtype``.
    :param cdesc: Descriptor of output C; supplies its shape and strides, and
                  its ``dtype`` is stored as ``cdtype``.
    :param alpha: Stored verbatim as ``opt['alpha']``.
    :param beta: Stored verbatim as ``opt['beta']``.
    :param cdtype: Stored verbatim as ``opt['dtype']``.
    :param func: Stored verbatim as ``opt['func']``.
    :return: The option dictionary. ``M``, ``N``, ``K``, ``lda``, ``ldb``,
             ``ldc`` and any batch strides are converted with ``sym2cpp``.
             ``BATCH`` is ``None`` when ``get_batchmm_opts`` returns a falsy
             value.
    """
    # Avoid import loops
    from dace.codegen.targets.common import sym2cpp

    (_, _, ashape, astride), (_, _, bshape,
                              bstride) = _get_matmul_inputs(node, state, sdfg)
    opt = get_gemm_opts(astride, bstride, cdesc.strides)
    bopt = get_batchmm_opts(ashape, astride, bshape, bstride, cdesc.shape,
                            cdesc.strides)
    opt['x'] = '_a'
    opt['y'] = '_b'
    # Record operand/output element types so the option map carries the same
    # keys as the other variant of this function (adesc/bdesc were otherwise
    # accepted but unused).
    opt['xdtype'] = adesc.dtype
    opt['ydtype'] = bdesc.dtype
    opt['cdtype'] = cdesc.dtype
    opt['M'] = sym2cpp(ashape[-2])
    opt['N'] = sym2cpp(bshape[-1])
    opt['K'] = sym2cpp(ashape[-1])
    opt['lda'] = sym2cpp(opt['lda'])
    opt['ldb'] = sym2cpp(opt['ldb'])
    opt['ldc'] = sym2cpp(opt['ldc'])

    if opt['swap']:
        # The inputs exchange roles; every A/B-paired option must follow,
        # including the element types so they stay attached to x/y.
        if bopt:
            bopt['sa'], bopt['sb'] = bopt['sb'], bopt['sa']
        opt['lda'], opt['ldb'] = opt['ldb'], opt['lda']
        opt['x'], opt['y'] = opt['y'], opt['x']
        opt['xdtype'], opt['ydtype'] = opt['ydtype'], opt['xdtype']
        opt['ta'], opt['tb'] = opt['tb'], opt['ta']
        opt['M'], opt['N'] = opt['N'], opt['M']

    opt['alpha'] = alpha
    opt['beta'] = beta
    opt['dtype'] = cdtype
    opt['func'] = func
    if bopt:
        opt['stride_a'] = sym2cpp(bopt['sa'])
        opt['stride_b'] = sym2cpp(bopt['sb'])
        opt['stride_c'] = sym2cpp(bopt['sc'])
        opt['BATCH'] = sym2cpp(bopt['b'])
    else:
        opt['BATCH'] = None

    return opt