def expansion(node, state, sdfg):
    """Expand the GEMM library node into a C++ tasklet that calls CBLAS ``?gemm``.

    :param node: Library node being expanded. Its ``alpha``/``beta`` scalars,
                 ``name`` and connectors are read here.
    :param state: State that contains ``node``; used to locate its first
                  outgoing edge.
    :param sdfg: SDFG whose ``arrays`` mapping provides the output descriptor.
    :return: A C++ tasklet carrying the generated ``cblas_*gemm`` call, with
             the same name and connectors as ``node``.
    """
    node.validate(sdfg, state)

    # Only the data descriptors of the two inputs are used by this expansion.
    (_, adesc, _, _), (_, bdesc, _, _), _ = _get_matmul_operands(node, state, sdfg)

    dtype = adesc.dtype.base_type
    ctype = dtype.ctype
    func = to_blastype(dtype.type).lower() + 'gemm'

    def as_literal(scalar):
        # Python complex constants need an explicit (real, imag) constructor
        # call; everything else is wrapped in a single-argument conversion.
        if isinstance(scalar, complex):
            return f'{ctype}({scalar.real}, {scalar.imag})'
        return f'{ctype}({scalar})'

    alpha = as_literal(node.alpha)
    beta = as_literal(node.beta)

    out_edge = state.out_edges(node)[0]
    cdesc = sdfg.arrays[out_edge.data.data]

    check_access(dtypes.ScheduleType.CPU_Multicore, adesc, bdesc, cdesc)

    opt = _get_codegen_gemm_opts(node, state, sdfg, adesc, bdesc, cdesc, alpha, beta, ctype, func)

    # Adaptations for BLAS API
    for key in ('ta', 'tb'):
        opt[key] = 'CblasNoTrans' if opt[key] == 'N' else 'CblasTrans'

    if dtype in (dace.complex64, dace.complex128):
        # Complex scalars are materialized as locals in the generated code and
        # handed to the BLAS call by address.
        prologue = f''' {ctype} alpha = {alpha}; {ctype} beta = {beta}; '''
        opt['alpha'] = '&alpha'
        opt['beta'] = '&beta'
    else:
        prologue = ''

    call = ("cblas_{func}(CblasColMajor, {ta}, {tb}, {M}, {N}, {K}, "
            "{alpha}, {x}, {lda}, {y}, {ldb}, {beta}, _c, {ldc});").format_map(opt)

    return dace.sdfg.nodes.Tasklet(
        node.name,
        node.in_connectors,
        node.out_connectors,
        prologue + call,
        language=dace.dtypes.Language.CPP,
    )
def expansion(node, state, sdfg):
    """Expand the batched matrix-multiplication node into a C++ tasklet.

    The generated code loops over the batch dimension and issues one CBLAS
    ``?gemm`` call per batch entry, offsetting each operand pointer by its
    per-batch stride. ``alpha`` and ``beta`` are fixed constants chosen from
    the output element type.

    :param node: Library node being expanded. Its ``name`` and connectors are
                 reused for the resulting tasklet.
    :param state: State that contains ``node``; used to locate its first
                  outgoing edge.
    :param sdfg: SDFG whose ``arrays`` mapping provides the output descriptor.
    :return: A C++ tasklet carrying the generated batched ``cblas_*gemm`` loop.
    :raises ValueError: If the output base type is not one of float32,
                        float64, complex64 or complex128.
    """
    node.validate(sdfg, state)

    # Only the data descriptors of the two inputs are used by this expansion.
    (_, adesc, _, _), (_, bdesc, _, _), _ = _get_matmul_operands(node, state, sdfg)
    cdesc = sdfg.arrays[state.out_edges(node)[0].data.data]

    check_access(dtypes.ScheduleType.CPU_Multicore, adesc, bdesc, cdesc)

    dtype = cdesc.dtype.base_type
    func = to_blastype(dtype.type).lower() + 'gemm'

    # (element type, alpha expression, beta expression) for every supported
    # type. Real types use plain literals; complex types use the accessors on
    # dace::blas::BlasConstants.
    scalar_table = (
        (dace.float32, "1.0f", "0.0f"),
        (dace.float64, "1.0", "0.0"),
        (dace.complex64, "dace::blas::BlasConstants::Get().Complex64Pone()",
         "dace::blas::BlasConstants::Get().Complex64Zero()"),
        (dace.complex128, "dace::blas::BlasConstants::Get().Complex128Pone()",
         "dace::blas::BlasConstants::Get().Complex128Zero()"),
    )
    for supported, one, zero in scalar_table:
        if dtype == supported:
            alpha, beta = one, zero
            break
    else:
        # This expansion emits batched GEMM, so the message names that
        # operation (it previously said "dot product").
        raise ValueError("Unsupported type for BLAS batched matrix multiplication: " + str(dtype))

    opt = _get_codegen_gemm_opts(node, state, sdfg, adesc, bdesc, cdesc, alpha, beta, cdesc.dtype.ctype, func)

    # Adaptations for MKL/BLAS API
    for key in ('ta', 'tb'):
        opt[key] = 'CblasNoTrans' if opt[key] == 'N' else 'CblasTrans'

    # One gemm call per batch entry; every operand pointer is advanced by its
    # own batch stride.
    template = (
        " for (int __ib = 0; __ib < {BATCH}; ++__ib) {{ "
        "cblas_{func}(CblasColMajor, {ta}, {tb}, {M}, {N}, {K}, {alpha}, "
        "(({dtype}*){x}) + __ib*{stride_a}, {lda}, "
        "(({dtype}*){y}) + __ib*{stride_b}, {ldb}, {beta}, "
        "(({dtype}*)_c) + __ib*{stride_c}, {ldc}); }}"
    )
    code = template.format_map(opt)

    return dace.sdfg.nodes.Tasklet(
        node.name,
        node.in_connectors,
        node.out_connectors,
        code,
        language=dace.dtypes.Language.CPP,
    )