def expansion(node, state, sdfg):
    node.validate(sdfg, state)
    dtype = node.dtype
    func = to_blastype(dtype.type).lower() + 'gemm'
    if dtype == dace.float32:
        alpha = "1.0f"
        beta = "0.0f"
    elif dtype == dace.float64:
        alpha = "1.0"
        beta = "0.0"
    elif dtype == dace.complex64:
        alpha = "dace::blas::BlasConstants::Get().Complex64Pone()"
        beta = "dace::blas::BlasConstants::Get().Complex64Zero()"
    elif dtype == dace.complex128:
        alpha = "dace::blas::BlasConstants::Get().Complex128Pone()"
        beta = "dace::blas::BlasConstants::Get().Complex128Zero()"
    else:
        raise ValueError("Unsupported type for BLAS GEMM: " + str(dtype))

    (_, adesc, ashape, astrides), (_, bdesc, bshape, bstrides) = _get_matmul_inputs(node, state, sdfg)
    cdesc = sdfg.arrays[state.out_edges(node)[0].data.data]
    opt = _get_codegen_gemm_opts(node, state, sdfg, adesc, bdesc, cdesc, alpha, beta, cdesc.dtype.ctype, func)

    # Adaptations for MKL/BLAS API
    opt['ta'] = 'CblasNoTrans' if opt['ta'] == 'N' else 'CblasTrans'
    opt['tb'] = 'CblasNoTrans' if opt['tb'] == 'N' else 'CblasTrans'

    if not opt['BATCH']:
        code = ("cblas_{func}(CblasColMajor, {ta}, {tb}, "
                "{M}, {N}, {K}, {alpha}, {x}, {lda}, {y}, {ldb}, {beta}, "
                "_c, {ldc});").format_map(opt)
    else:
        code = '''
    for (int __ib = 0; __ib < {BATCH}; ++__ib) {{
        cblas_{func}(CblasColMajor, {ta}, {tb}, {M}, {N}, {K}, {alpha},
                     (({dtype}*){x}) + __ib*{stride_a}, {lda},
                     (({dtype}*){y}) + __ib*{stride_b}, {ldb},
                     {beta},
                     (({dtype}*)_c) + __ib*{stride_c}, {ldc});
    }}'''.format_map(opt)

    tasklet = dace.sdfg.nodes.Tasklet(node.name,
                                      node.in_connectors,
                                      node.out_connectors,
                                      code,
                                      language=dace.dtypes.Language.CPP)
    return tasklet

def expansion(node, state, sdfg):
    node.validate(sdfg, state)
    (_, adesc, ashape, astrides), (_, bdesc, bshape, bstrides), _ = _get_matmul_operands(node, state, sdfg)
    dtype = adesc.dtype.base_type
    func = to_blastype(dtype.type).lower() + 'gemm'

    alpha = f'{dtype.ctype}({node.alpha})'
    beta = f'{dtype.ctype}({node.beta})'

    # Deal with complex input constants
    if isinstance(node.alpha, complex):
        alpha = f'{dtype.ctype}({node.alpha.real}, {node.alpha.imag})'
    if isinstance(node.beta, complex):
        beta = f'{dtype.ctype}({node.beta.real}, {node.beta.imag})'

    cdesc = sdfg.arrays[state.out_edges(node)[0].data.data]
    check_access(dtypes.ScheduleType.CPU_Multicore, adesc, bdesc, cdesc)

    opt = _get_codegen_gemm_opts(node, state, sdfg, adesc, bdesc, cdesc, alpha, beta, dtype.ctype, func)

    # Adaptations for BLAS API
    opt['ta'] = 'CblasNoTrans' if opt['ta'] == 'N' else 'CblasTrans'
    opt['tb'] = 'CblasNoTrans' if opt['tb'] == 'N' else 'CblasTrans'

    code = ''
    if dtype in (dace.complex64, dace.complex128):
        code = f'''
        {dtype.ctype} alpha = {alpha};
        {dtype.ctype} beta = {beta};
        '''
        opt['alpha'] = '&alpha'
        opt['beta'] = '&beta'

    code += ("cblas_{func}(CblasColMajor, {ta}, {tb}, "
             "{M}, {N}, {K}, {alpha}, {x}, {lda}, {y}, {ldb}, {beta}, "
             "_c, {ldc});").format_map(opt)

    tasklet = dace.sdfg.nodes.Tasklet(
        node.name,
        node.in_connectors,
        node.out_connectors,
        code,
        language=dace.dtypes.Language.CPP,
    )
    return tasklet

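# --- Example (editorial sketch, not part of the library) ---------------------
# A minimal sketch of how a CBLAS-based expansion like the one above is
# typically reached: build a Gemm library node, pick the 'MKL' implementation,
# and expand. Node and implementation names ('Gemm', 'MKL', '_a'/'_b'/'_c')
# follow the code above; treat the exact registration details as assumptions.
import dace
from dace.libraries import blas

def _gemm_example_sdfg():
    sdfg = dace.SDFG('gemm_example')
    sdfg.add_array('A', [32, 64], dace.float64)
    sdfg.add_array('B', [64, 16], dace.float64)
    sdfg.add_array('C', [32, 16], dace.float64)
    state = sdfg.add_state()

    node = blas.Gemm('gemm')          # defaults: alpha = 1, beta = 0
    node.implementation = 'MKL'       # selects the expansion defined above
    state.add_node(node)

    state.add_edge(state.add_read('A'), None, node, '_a',
                   dace.Memlet.from_array('A', sdfg.arrays['A']))
    state.add_edge(state.add_read('B'), None, node, '_b',
                   dace.Memlet.from_array('B', sdfg.arrays['B']))
    state.add_edge(node, '_c', state.add_write('C'), None,
                   dace.Memlet.from_array('C', sdfg.arrays['C']))

    sdfg.expand_library_nodes()       # replaces the node with the cblas_dgemm tasklet
    return sdfg
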
def expansion(node, parent_state, parent_sdfg, n=None, **kwargs):
    (desc_x, stride_x, rows_x, cols_x), desc_result = node.validate(parent_sdfg, parent_state)
    dtype = desc_x.dtype.base_type
    lapack_dtype = blas_helpers.to_blastype(dtype.type).lower()
    if desc_x.dtype.veclen > 1:
        raise NotImplementedError
    n = n or node.n
    uplo = "'L'" if node._lower else "'U'"
    code = f"_res = LAPACKE_{lapack_dtype}potrf(LAPACK_ROW_MAJOR, {uplo}, {rows_x}, _xin, {stride_x});"
    tasklet = dace.sdfg.nodes.Tasklet(node.name,
                                      node.in_connectors,
                                      node.out_connectors,
                                      code,
                                      language=dace.dtypes.Language.CPP)
    return tasklet

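# --- Example (editorial sketch) ----------------------------------------------
# NumPy reference for what the generated LAPACKE_?potrf call computes: an
# in-place Cholesky factorization of a symmetric positive-definite matrix,
# storing the lower ('L') or upper ('U') triangular factor.
import numpy as np

def potrf_reference(x, lower=True):
    L = np.linalg.cholesky(x)           # lower factor: x = L @ L.conj().T
    return L if lower else L.conj().T   # 'U' variant stores the upper factor

_A = np.array([[4.0, 2.0], [2.0, 3.0]])
_L = potrf_reference(_A)
np.testing.assert_allclose(_L @ _L.T, _A)
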
def expansion(node, parent_state, parent_sdfg, **kwargs):
    a, b, c, desca, descb, gdescc, ldesc = node.validate(parent_sdfg, parent_state)
    dtype = a.dtype.base_type
    lapack_dtype_str = blas_helpers.to_blastype(dtype.type).lower()

    transa = 'N' if node._transa == 'T' else 'T'
    code = f"""
        const double zero = 0.0E+0, one = 1.0E+0;
        const char trans = '{transa}';
        MKL_INT grows = (trans == 'T' ? {node._m} : {node._n});
        MKL_INT gcols = 1;
        MKL_INT a_rows = {node._n};
        MKL_INT a_cols = {node._m};
        MKL_INT b_rows = (trans == 'T' ? {node._n} : {node._m});
        MKL_INT b_cols = 1;
        MKL_INT brows = grows / __state->__mkl_scalapack_size;
        MKL_INT bcols = 1;
        MKL_INT a_brows = _a_block_sizes[1];
        MKL_INT a_bcols = _a_block_sizes[0];
        MKL_INT b_brows = _b_block_sizes[0];
        MKL_INT b_bcols = 1;
        MKL_INT mloc = numroc(&grows, &brows, &__state->__mkl_scalapack_myprow, &__state->__mkl_int_zero, &__state->__mkl_scalapack_prows);
        MKL_INT a_mloc = numroc(&a_rows, &a_brows, &__state->__mkl_scalapack_myprow, &__state->__mkl_int_zero, &__state->__mkl_scalapack_prows);
        MKL_INT a_nloc = numroc(&a_cols, &a_bcols, &__state->__mkl_scalapack_mypcol, &__state->__mkl_int_zero, &__state->__mkl_scalapack_pcols);
        MKL_INT b_mloc = numroc(&b_rows, &b_brows, &__state->__mkl_scalapack_myprow, &__state->__mkl_int_zero, &__state->__mkl_scalapack_prows);
        MKL_INT info;
        MKL_INT _a_ldesc[9], _b_ldesc[9], _c_ldesc[9];
        MKL_INT a_lld = a_mloc;
        descinit(_a_ldesc, &a_rows, &a_cols, &a_brows, &a_bcols, &__state->__mkl_int_zero, &__state->__mkl_int_zero, &__state->__mkl_scalapack_context, &a_lld, &info);
        MKL_INT b_lld = b_mloc;
        descinit(_b_ldesc, &b_rows, &b_cols, &b_mloc, &b_bcols, &__state->__mkl_int_zero, &__state->__mkl_int_zero, &__state->__mkl_scalapack_context, &b_lld, &info);
        MKL_INT c_lld = mloc;
        descinit(_c_ldesc, &grows, &gcols, &mloc, &bcols, &__state->__mkl_int_zero, &__state->__mkl_int_zero, &__state->__mkl_scalapack_context, &c_lld, &info);
        MKL_INT _m = a_rows, _n = a_cols;
        p{lapack_dtype_str}gemv(&trans, &_m, &_n, &one, _a, &__state->__mkl_int_one, &__state->__mkl_int_one, _a_ldesc,
                                _b, &__state->__mkl_int_one, &__state->__mkl_int_one, _b_ldesc, &__state->__mkl_int_one,
                                &zero, _c, &__state->__mkl_int_one, &__state->__mkl_int_one, _c_ldesc, &__state->__mkl_int_one);
    """
    tasklet = dace.sdfg.nodes.Tasklet(node.name,
                                      node.in_connectors,
                                      node.out_connectors,
                                      code,
                                      language=dace.dtypes.Language.CPP)
    return tasklet

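# --- Example (editorial sketch) ----------------------------------------------
# Rough NumPy reference for the global result of the p?gemv call emitted above
# (alpha = 1, beta = 0), ignoring the block-cyclic distribution and the
# row-/column-major transpose flip handled by the expansion.
import numpy as np

def pgemv_reference(A, x, trans='N'):
    op_a = A.T if trans == 'T' else A
    return op_a @ x   # global matrix-vector product; each rank owns a block of it
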
def expansion(node, parent_state, parent_sdfg, n=None, **kwargs):
    ((desc_a, stride_a, rows_a, cols_a),
     (desc_rhs, stride_rhs, rows_rhs, cols_rhs),
     desc_ipiv, desc_res) = node.validate(parent_sdfg, parent_state)
    dtype = desc_a.dtype.base_type
    lapack_dtype = blas_helpers.to_blastype(dtype.type).lower()
    if desc_a.dtype.veclen > 1:
        raise NotImplementedError
    n = n or node.n
    code = (f"_res = LAPACKE_{lapack_dtype}getrs(LAPACK_ROW_MAJOR, 'N', {rows_a}, {cols_rhs}, "
            f"_a, {stride_a}, _ipiv, _rhs_in, {stride_rhs});")
    tasklet = dace.sdfg.nodes.Tasklet(node.name,
                                      node.in_connectors,
                                      node.out_connectors,
                                      code,
                                      language=dace.dtypes.Language.CPP)
    return tasklet

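# --- Example (editorial sketch) ----------------------------------------------
# The getrf/getrs pairing the expansion above relies on: _a and _ipiv come from
# a prior LU factorization, and getrs solves A x = b with them. SciPy's raw
# LAPACK wrappers stand in here for the generated LAPACKE calls.
import numpy as np
from scipy.linalg import lapack

_A = np.array([[3.0, 1.0], [1.0, 2.0]])
_b = np.array([9.0, 8.0])

_lu, _ipiv, _info = lapack.dgetrf(_A)       # factorization step (Getrf node)
_x, _info = lapack.dgetrs(_lu, _ipiv, _b)   # solve step (this expansion)
np.testing.assert_allclose(_A @ _x, _b)
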
def expansion(node, parent_state, parent_sdfg, n=None, **kwargs):
    (desc_x, stride_x, rows_x, cols_x), desc_ipiv, desc_result = node.validate(parent_sdfg, parent_state)
    dtype = desc_x.dtype.base_type
    lapack_dtype = blas_helpers.to_blastype(dtype.type).lower()
    cast = ""
    if lapack_dtype == 'c':
        cast = "(MKL_Complex8*)"
    elif lapack_dtype == 'z':
        cast = "(MKL_Complex16*)"
    if desc_x.dtype.veclen > 1:
        raise NotImplementedError
    n = n or node.n
    code = f"_res = LAPACKE_{lapack_dtype}getri(LAPACK_ROW_MAJOR, {rows_x}, {cast}_xin, {stride_x}, _ipiv);"
    tasklet = dace.sdfg.nodes.Tasklet(node.name,
                                      node.in_connectors,
                                      node.out_connectors,
                                      code,
                                      language=dace.dtypes.Language.CPP)
    return tasklet

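# --- Example (editorial sketch) ----------------------------------------------
# The getrf/getri pairing behind the expansion above: getri turns an LU
# factorization into the explicit inverse. SciPy's raw LAPACK wrappers stand in
# for the generated LAPACKE calls.
import numpy as np
from scipy.linalg import lapack

_A = np.array([[3.0, 1.0], [1.0, 2.0]])
_lu, _ipiv, _info = lapack.dgetrf(_A)     # factorize first (Getrf node)
_inv, _info = lapack.dgetri(_lu, _ipiv)   # this expansion: overwrite LU with A^-1
np.testing.assert_allclose(_A @ _inv, np.eye(2), atol=1e-12)
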
def expansion(node, state, sdfg):
    node.validate(sdfg, state)
    (_, adesc, ashape, astrides), (_, bdesc, bshape, bstrides), _ = _get_matmul_operands(node, state, sdfg)
    dtype = adesc.dtype.base_type
    func = to_blastype(dtype.type).lower() + 'gemm'

    # TODO: Fix w.r.t. other alpha/beta values
    if dtype == dace.float32:
        alpha = "1.0f"
        beta = "0.0f"
    elif dtype == dace.float64:
        alpha = "1.0"
        beta = "0.0"
    elif dtype == dace.complex64:
        alpha = "dace::blas::BlasConstants::Get().Complex64Pone()"
        beta = "dace::blas::BlasConstants::Get().Complex64Zero()"
    elif dtype == dace.complex128:
        alpha = "dace::blas::BlasConstants::Get().Complex128Pone()"
        beta = "dace::blas::BlasConstants::Get().Complex128Zero()"
    else:
        raise ValueError("Unsupported type for BLAS GEMM: " + str(dtype))

    cdesc = sdfg.arrays[state.out_edges(node)[0].data.data]
    opt = _get_codegen_gemm_opts(node, state, sdfg, adesc, bdesc, cdesc, alpha, beta, cdesc.dtype.ctype, func)

    # Adaptations for MKL/BLAS API
    opt['ta'] = 'CblasNoTrans' if opt['ta'] == 'N' else 'CblasTrans'
    opt['tb'] = 'CblasNoTrans' if opt['tb'] == 'N' else 'CblasTrans'

    code = ("cblas_{func}(CblasColMajor, {ta}, {tb}, "
            "{M}, {N}, {K}, {alpha}, {x}, {lda}, {y}, {ldb}, {beta}, "
            "_c, {ldc});").format_map(opt)

    tasklet = dace.sdfg.nodes.Tasklet(node.name,
                                      node.in_connectors,
                                      node.out_connectors,
                                      code,
                                      language=dace.dtypes.Language.CPP)
    return tasklet

def expansion(node, state, sdfg):
    node.validate(sdfg, state)

    # Find inputs and output
    adesc, bdesc, cdesc = None, None, None
    for e in state.in_edges(node):
        if e.dst_conn == '_a':
            anode = state.memlet_path(e)[0].src
            if isinstance(anode, dace.sdfg.nodes.AccessNode):
                adesc: dt.Array = sdfg.arrays[anode.data]
        elif e.dst_conn == '_b':
            bnode = state.memlet_path(e)[0].src
            if isinstance(bnode, dace.sdfg.nodes.AccessNode):
                bdesc: dt.Array = sdfg.arrays[bnode.data]
    for e in state.out_edges(node):
        if e.src_conn == '_c':
            cnode = state.memlet_path(e)[-1].dst
            if isinstance(cnode, dace.sdfg.nodes.AccessNode):
                cdesc: dt.Array = sdfg.arrays[cnode.data]
    if not adesc or not bdesc or not cdesc:
        raise ValueError('Unsupported input/output arrays')

    dtype = adesc.dtype.base_type
    func = '%sgemm' % to_blastype(dtype.type)
    if dtype == dace.float16:
        cdtype = '__half'
        factort = 'Half'
    elif dtype == dace.float32:
        cdtype = 'float'
        factort = 'Float'
    elif dtype == dace.float64:
        cdtype = 'double'
        factort = 'Double'
    elif dtype == dace.complex64:
        cdtype = 'cuComplex'
        factort = 'Complex64'
    elif dtype == dace.complex128:
        cdtype = 'cuDoubleComplex'
        factort = 'Complex128'
    else:
        raise ValueError("Unsupported type: " + str(dtype))

    call_prefix = environments.cublas.cuBLAS.handle_setup_code(node)
    call_suffix = ''

    # Handle alpha / beta
    constants = {
        1.0: f"__state->cublas_handle.Constants(__dace_cuda_device).{factort}Pone()",
        #-1.0: f"__state->cublas_handle.Constants(__dace_cuda_device).{factort}Mone()",
        0.0: f"__state->cublas_handle.Constants(__dace_cuda_device).{factort}Zero()",
    }
    if node.alpha not in constants or node.beta not in constants:
        # Deal with complex input constants
        if isinstance(node.alpha, complex):
            alpha = f'{dtype.ctype}({node.alpha.real}, {node.alpha.imag})'
        else:
            alpha = f'{dtype.ctype}({node.alpha})'
        if isinstance(node.beta, complex):
            beta = f'{dtype.ctype}({node.beta.real}, {node.beta.imag})'
        else:
            beta = f'{dtype.ctype}({node.beta})'

        # Set pointer mode to host
        call_prefix += f'''cublasSetPointerMode(__dace_cublas_handle, CUBLAS_POINTER_MODE_HOST);
        {dtype.ctype} alpha = {alpha};
        {dtype.ctype} beta = {beta};
        '''
        call_suffix += '''
        cublasSetPointerMode(__dace_cublas_handle, CUBLAS_POINTER_MODE_DEVICE);
        '''
        alpha = f'({cdtype} *)&alpha'
        beta = f'({cdtype} *)&beta'
    else:
        alpha = constants[node.alpha]
        beta = constants[node.beta]

    # Set up options for code formatting
    opt = _get_codegen_gemm_opts(node, state, sdfg, adesc, bdesc, cdesc, alpha, beta, cdtype, func)

    # Matrix multiplication
    call = '''cublas{func}(__dace_cublas_handle,
        CUBLAS_OP_{ta}, CUBLAS_OP_{tb},
        {M}, {N}, {K},
        {alpha},
        ({dtype}*){x}, {lda},
        ({dtype}*){y}, {ldb},
        {beta},
        ({dtype}*)_c, {ldc});'''

    code = (call_prefix + call.format_map(opt) + call_suffix)
    tasklet = dace.sdfg.nodes.Tasklet(
        node.name,
        node.in_connectors,
        node.out_connectors,
        code,
        language=dace.dtypes.Language.CPP,
    )

    # If buffers are not on the GPU, copy them
    if any(desc.storage not in [dace.StorageType.GPU_Global, dace.StorageType.CPU_Pinned]
           for desc in [adesc, bdesc, cdesc]):
        nsdfg = dace.SDFG('nested_gemm')
        for name, desc in [('_a', adesc), ('_b', bdesc), ('_c', cdesc)]:
            if isinstance(desc, dt.View):
                dcopy = desc.as_array()
            else:
                dcopy = dc(desc)
            dcopy.lifetime = dtypes.AllocationLifetime.Scope
            dcopy_gpu = dc(dcopy)
            dcopy.transient = False
            nsdfg.add_datadesc(name, dcopy)
            dcopy_gpu.transient = True
            dcopy_gpu.storage = dace.StorageType.GPU_Global
            nsdfg.add_datadesc(name + '_gpu', dcopy_gpu)
        nstate = nsdfg.add_state()
        a = nstate.add_read('_a')
        ga = nstate.add_access('_a_gpu')
        b = nstate.add_read('_b')
        gb = nstate.add_access('_b_gpu')
        c = nstate.add_write('_c')
        gc = nstate.add_access('_c_gpu')

        # Reset code and connectors
        tasklet.in_connectors = {"_conn" + k: None for k in tasklet.in_connectors}
        tasklet.out_connectors = {"_conn" + k: None for k in tasklet.out_connectors}
        call = '''cublas{func}(__dace_cublas_handle,
            CUBLAS_OP_{ta}, CUBLAS_OP_{tb},
            {M}, {N}, {K},
            {alpha},
            ({dtype}*){x}, {lda},
            ({dtype}*){y}, {ldb},
            {beta},
            ({dtype}*)_conn_c, {ldc});'''
        opt['x'] = '_conn' + opt['x']
        opt['y'] = '_conn' + opt['y']
        tasklet.code.as_string = (call_prefix + call.format_map(opt) + call_suffix)

        nstate.add_node(tasklet)
        nstate.add_nedge(a, ga, dace.Memlet.from_array('_a', adesc))
        nstate.add_nedge(b, gb, dace.Memlet.from_array('_b', bdesc))
        nstate.add_edge(ga, None, tasklet, '_conn_a', dace.Memlet.from_array('_a_gpu', adesc))
        nstate.add_edge(gb, None, tasklet, '_conn_b', dace.Memlet.from_array('_b_gpu', bdesc))
        nstate.add_edge(tasklet, '_conn_c', gc, None, dace.Memlet.from_array('_c_gpu', cdesc))
        nstate.add_nedge(gc, c, dace.Memlet.from_array('_c', cdesc))

        if node.beta != 0.0:
            rc = nstate.add_read('_c')
            rgc = nstate.add_access('_c_gpu')
            tasklet.add_in_connector('_conn_cin')
            nstate.add_nedge(rc, rgc, dace.Memlet('_c'))
            nstate.add_edge(rgc, None, tasklet, '_conn_cin', dace.Memlet('_c_gpu'))

        return nsdfg
    # End of copy to GPU

    return tasklet

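# --- Example (editorial sketch) ----------------------------------------------
# One way the cuBLAS expansion above is typically selected: set the library's
# default implementation, lower the program to the GPU, and expand. The
# 'cuBLAS' key and the module-level default_implementation switch follow the
# surrounding library's conventions and should be treated as assumptions; if
# the arrays are not in GPU memory, the expansion instead returns the
# nested_gemm SDFG that stages them through the *_gpu transients.
import dace
from dace.libraries import blas

blas.default_implementation = 'cuBLAS'   # assumed library-level switch

@dace.program
def matmul(A: dace.float64[64, 32], B: dace.float64[32, 16], C: dace.float64[64, 16]):
    C[:] = A @ B

def _cublas_gemm_example():
    sdfg = matmul.to_sdfg()
    sdfg.apply_gpu_transformations()     # place A, B, C in GPU global memory
    sdfg.expand_library_nodes()          # emits the cublasDgemm tasklet above
    return sdfg
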
def expansion(node, state, sdfg):
    node.validate(sdfg, state)

    # Find inputs and output
    adesc, bdesc, cdesc = None, None, None
    for e in state.in_edges(node):
        if e.dst_conn == '_a':
            anode = state.memlet_path(e)[0].src
            if isinstance(anode, dace.sdfg.nodes.AccessNode):
                adesc: dt.Array = sdfg.arrays[anode.data]
        elif e.dst_conn == '_b':
            bnode = state.memlet_path(e)[0].src
            if isinstance(bnode, dace.sdfg.nodes.AccessNode):
                bdesc: dt.Array = sdfg.arrays[bnode.data]
    for e in state.out_edges(node):
        if e.src_conn == '_c':
            cnode = state.memlet_path(e)[-1].dst
            if isinstance(cnode, dace.sdfg.nodes.AccessNode):
                cdesc: dt.Array = sdfg.arrays[cnode.data]
    if not adesc or not bdesc or not cdesc:
        raise ValueError('Unsupported input/output arrays')

    needs_copy = any(desc.storage not in (dace.StorageType.GPU_Global, dace.StorageType.CPU_Pinned)
                     for desc in (adesc, bdesc, cdesc))

    dtype = cdesc.dtype.base_type
    func = '%sgemm' % to_blastype(dtype.type)
    if dtype == dace.float16:
        cdtype = '__half'
        factort = 'Half'
    elif dtype == dace.float32:
        cdtype = 'float'
        factort = 'Float'
    elif dtype == dace.float64:
        cdtype = 'double'
        factort = 'Double'
    elif dtype == dace.complex64:
        cdtype = 'cuComplex'
        factort = 'Complex64'
    elif dtype == dace.complex128:
        cdtype = 'cuDoubleComplex'
        factort = 'Complex128'
    else:
        raise ValueError("Unsupported type: " + str(dtype))

    call_prefix = environments.cublas.cuBLAS.handle_setup_code(node)
    call_suffix = ''

    # Handle alpha / beta
    constants = {
        1.0: f"__state->cublas_handle.Constants(__dace_cuda_device).{factort}Pone()",
        0.0: f"__state->cublas_handle.Constants(__dace_cuda_device).{factort}Zero()",
    }
    if node.alpha not in constants:
        # Deal with complex input constants
        if isinstance(node.alpha, complex):
            alpha = f'{dtype.ctype}({node.alpha.real}, {node.alpha.imag})'
        else:
            alpha = f'{dtype.ctype}({node.alpha})'

        # Set pointer mode to host
        call_prefix += f'''cublasSetPointerMode(__dace_cublas_handle, CUBLAS_POINTER_MODE_HOST);
        {dtype.ctype} alpha = {alpha};
        {dtype.ctype} beta = 0;
        '''
        call_suffix += '''
        cublasSetPointerMode(__dace_cublas_handle, CUBLAS_POINTER_MODE_DEVICE);
        '''
        beta = f'({cdtype} *)&beta'
        alpha = f'({cdtype} *)&alpha'
    else:
        alpha = constants[node.alpha]
        beta = "__state->cublas_handle.Constants(__dace_cuda_device).%sZero()" % factort

    # Set up options for code formatting
    opt = _get_codegen_gemm_opts(node, state, sdfg, adesc, bdesc, cdesc, alpha, beta, cdtype, func)
    opt['array_prefix'] = '_' if needs_copy else ''

    # Matrix multiplication
    if (node.compute_type is None and node.accumulator_type is None and node.algorithm is None):
        call = '''cublas{func}StridedBatched(__dace_cublas_handle,
            CUBLAS_OP_{ta}, CUBLAS_OP_{tb},
            {M}, {N}, {K},
            {alpha},
            ({dtype}*){array_prefix}{x}, {lda}, {stride_a},
            ({dtype}*){array_prefix}{y}, {ldb}, {stride_b},
            {beta},
            ({dtype}*){array_prefix}_c, {ldc}, {stride_c},
            {BATCH});'''.format_map(opt)
    else:
        if node.compute_type is not None:
            acctype = node.compute_type
        elif node.accumulator_type is not None:
            acc_dtype: dtypes.typeclass = node.accumulator_type
            acctype = f'CUBLAS_COMPUTE_{to_cublas_computetype(acc_dtype)}'
        else:
            acctype = f'CUBLAS_COMPUTE_{to_cublas_computetype(dtype)}'

        algorithm = 'CUBLAS_GEMM_DEFAULT_TENSOR_OP'
        if node.algorithm is not None:
            algorithm = node.algorithm

        call = f'''
        cublasGemmStridedBatchedEx(__dace_cublas_handle,
            CUBLAS_OP_{opt['ta']}, CUBLAS_OP_{opt['tb']},
            {opt['M']}, {opt['N']}, {opt['K']},
            {alpha},
            {opt['array_prefix']}{opt['x']}, {dtype_to_cudadatatype(opt['xdtype'])}, {opt['lda']}, {opt['stride_a']},
            {opt['array_prefix']}{opt['y']}, {dtype_to_cudadatatype(opt['ydtype'])}, {opt['ldb']}, {opt['stride_b']},
            {beta},
            {opt['array_prefix']}_c, {dtype_to_cudadatatype(opt['cdtype'])}, {opt['ldc']}, {opt['stride_c']},
            {opt['BATCH']},
            {acctype}, {algorithm});
        '''

    code = call_prefix + call + call_suffix
    tasklet = dace.sdfg.nodes.Tasklet(node.name,
                                      node.in_connectors,
                                      node.out_connectors,
                                      code,
                                      language=dace.dtypes.Language.CPP)

    # If buffers are not on the GPU, copy them
    if needs_copy:
        nsdfg = dace.SDFG('nested_batched_matmul')
        tasklet = dace.sdfg.nodes.Tasklet(node.name, {
            '__a': dtypes.pointer(adesc.dtype),
            '__b': dtypes.pointer(bdesc.dtype)
        }, {'__c': dtypes.pointer(cdesc.dtype)},
                                          code,
                                          language=dace.dtypes.Language.CPP)
        for name, desc in [('_a', adesc), ('_b', bdesc), ('_c', cdesc)]:
            if isinstance(desc, dt.View):
                dcopy = desc.as_array()
            else:
                dcopy = dc(desc)
            dcopy.transient = False
            dcopy.lifetime = dtypes.AllocationLifetime.Scope
            dcopy_gpu = dc(dcopy)
            nsdfg.add_datadesc(name, dcopy)
            dcopy_gpu.transient = True
            dcopy_gpu.storage = dace.StorageType.GPU_Global
            nsdfg.add_datadesc(name + '_gpu', dcopy_gpu)
        nstate = nsdfg.add_state()
        a = nstate.add_read('_a')
        ga = nstate.add_access('_a_gpu')
        b = nstate.add_read('_b')
        gb = nstate.add_access('_b_gpu')
        c = nstate.add_write('_c')
        gc = nstate.add_access('_c_gpu')

        nstate.add_node(tasklet)
        nstate.add_nedge(a, ga, dace.Memlet.from_array('_a', adesc))
        nstate.add_nedge(b, gb, dace.Memlet.from_array('_b', bdesc))
        nstate.add_edge(ga, None, tasklet, '__a', dace.Memlet.from_array('_a_gpu', adesc))
        nstate.add_edge(gb, None, tasklet, '__b', dace.Memlet.from_array('_b_gpu', bdesc))
        nstate.add_edge(tasklet, '__c', gc, None, dace.Memlet.from_array('_c_gpu', cdesc))
        nstate.add_nedge(gc, c, dace.Memlet.from_array('_c', cdesc))

        return nsdfg
    # End of copy to GPU

    return tasklet

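# --- Example (editorial sketch) ----------------------------------------------
# A batched multiply that maps onto the strided-batched expansion above: a 3-D
# matrix product becomes a batched-matmul library node, and with the cuBLAS
# implementation every batch slice advances the A/B/C pointers by stride_a/b/c
# inside a single cublas<T>gemmStridedBatched call. The implementation key and
# default_implementation switch are assumptions, as in the previous sketch.
import dace
from dace.libraries import blas

blas.default_implementation = 'cuBLAS'   # assumed library-level switch

@dace.program
def batched(A: dace.float32[8, 32, 16], B: dace.float32[8, 16, 24],
            C: dace.float32[8, 32, 24]):
    C[:] = A @ B

def _batched_matmul_example():
    sdfg = batched.to_sdfg()
    sdfg.apply_gpu_transformations()
    sdfg.expand_library_nodes()          # emits the StridedBatched tasklet above
    return sdfg
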
def expansion(node, state, sdfg):
    node.validate(sdfg, state)
    dtype = node.dtype
    func = '%sgemm' % to_blastype(dtype.type)
    if dtype == dace.float16:
        cdtype = '__half'
        factort = 'Half'
    elif dtype == dace.float32:
        cdtype = 'float'
        factort = 'Float'
    elif dtype == dace.float64:
        cdtype = 'double'
        factort = 'Double'
    elif dtype == dace.complex64:
        cdtype = 'cuComplex'
        factort = 'Complex64'
    elif dtype == dace.complex128:
        cdtype = 'cuDoubleComplex'
        factort = 'Complex128'
    else:
        raise ValueError("Unsupported type: " + str(dtype))

    alpha = "dace::blas::CublasConstants::Get(__dace_cuda_device).%sPone()" % factort
    beta = "dace::blas::CublasConstants::Get(__dace_cuda_device).%sZero()" % factort

    # Find inputs and output
    adesc, bdesc, cdesc = None, None, None
    for e in state.in_edges(node):
        if e.dst_conn == '_a':
            anode = state.memlet_path(e)[0].src
            if isinstance(anode, dace.sdfg.nodes.AccessNode):
                adesc: Array = sdfg.arrays[anode.data]
        elif e.dst_conn == '_b':
            bnode = state.memlet_path(e)[0].src
            if isinstance(bnode, dace.sdfg.nodes.AccessNode):
                bdesc: Array = sdfg.arrays[bnode.data]
    for e in state.out_edges(node):
        if e.src_conn == '_c':
            cnode = state.memlet_path(e)[-1].dst
            if isinstance(cnode, dace.sdfg.nodes.AccessNode):
                cdesc: Array = sdfg.arrays[cnode.data]
    if not adesc or not bdesc or not cdesc:
        raise ValueError('Unsupported input/output arrays')

    # Set up options for code formatting
    opt = _get_codegen_gemm_opts(node, state, sdfg, adesc, bdesc, cdesc, alpha, beta, cdtype, func)

    # Matrix multiplication
    call = '''cublas{func}(__dace_cublas_handle,
        CUBLAS_OP_{ta}, CUBLAS_OP_{tb},
        {M}, {N}, {K},
        {alpha},
        ({dtype}*){x}, {lda},
        ({dtype}*){y}, {ldb},
        {beta},
        ({dtype}*)_c, {ldc});'''

    code = (environments.cublas.cuBLAS.handle_setup_code(node) + call.format_map(opt))
    tasklet = dace.sdfg.nodes.Tasklet(node.name,
                                      node.in_connectors,
                                      node.out_connectors,
                                      code,
                                      language=dace.dtypes.Language.CPP)

    # If buffers are not on the GPU, copy them
    # TODO: This creates variable shadowing
    if any(desc.storage not in [dace.StorageType.GPU_Global, dace.StorageType.CPU_Pinned]
           for desc in [adesc, bdesc, cdesc]):
        nsdfg = dace.SDFG('nested_gemm')
        for name, desc in [('_a', adesc), ('_b', bdesc), ('_c', cdesc)]:
            dcopy = dc(desc)
            dcopy.transient = False
            nsdfg.add_datadesc(name, dcopy)
            dcopy_gpu = dc(desc)
            dcopy_gpu.transient = True
            dcopy_gpu.storage = dace.StorageType.GPU_Global
            nsdfg.add_datadesc(name + '_gpu', dcopy_gpu)
        nstate = nsdfg.add_state()
        a = nstate.add_read('_a')
        ga = nstate.add_access('_a_gpu')
        b = nstate.add_read('_b')
        gb = nstate.add_access('_b_gpu')
        c = nstate.add_write('_c')
        gc = nstate.add_access('_c_gpu')

        tasklet.in_connectors = {"_conn" + k: None for k in tasklet.in_connectors}
        tasklet.out_connectors = {"_conn" + k: None for k in tasklet.out_connectors}

        nstate.add_node(tasklet)
        nstate.add_nedge(a, ga, dace.Memlet.from_array('_a', adesc))
        nstate.add_nedge(b, gb, dace.Memlet.from_array('_b', bdesc))
        nstate.add_edge(ga, None, tasklet, '_conn_a', dace.Memlet.from_array('_a_gpu', adesc))
        nstate.add_edge(gb, None, tasklet, '_conn_b', dace.Memlet.from_array('_b_gpu', bdesc))
        nstate.add_edge(tasklet, '_conn_c', gc, None, dace.Memlet.from_array('_c_gpu', cdesc))
        nstate.add_nedge(gc, c, dace.Memlet.from_array('_c', cdesc))

        return nsdfg
    # End of copy to GPU

    return tasklet

def expansion(node, state, sdfg):
    node.validate(sdfg, state)

    # Find inputs and output
    adesc, bdesc, cdesc = None, None, None
    for e in state.in_edges(node):
        if e.dst_conn == '_a':
            anode = state.memlet_path(e)[0].src
            if isinstance(anode, dace.sdfg.nodes.AccessNode):
                adesc: dt.Array = sdfg.arrays[anode.data]
        elif e.dst_conn == '_b':
            bnode = state.memlet_path(e)[0].src
            if isinstance(bnode, dace.sdfg.nodes.AccessNode):
                bdesc: dt.Array = sdfg.arrays[bnode.data]
    for e in state.out_edges(node):
        if e.src_conn == '_c':
            cnode = state.memlet_path(e)[-1].dst
            if isinstance(cnode, dace.sdfg.nodes.AccessNode):
                cdesc: dt.Array = sdfg.arrays[cnode.data]
    if not adesc or not bdesc or not cdesc:
        raise ValueError('Unsupported input/output arrays')

    dtype = cdesc.dtype.base_type
    func = '%sgemm' % to_blastype(dtype.type)
    if dtype == dace.float16:
        cdtype = '__half'
        factort = 'Half'
    elif dtype == dace.float32:
        cdtype = 'float'
        factort = 'Float'
    elif dtype == dace.float64:
        cdtype = 'double'
        factort = 'Double'
    elif dtype == dace.complex64:
        cdtype = 'cuComplex'
        factort = 'Complex64'
    elif dtype == dace.complex128:
        cdtype = 'cuDoubleComplex'
        factort = 'Complex128'
    else:
        raise ValueError("Unsupported type: " + str(dtype))

    alpha = "__state->cublas_handle.Constants(__dace_cuda_device).%sPone()" % factort
    beta = "__state->cublas_handle.Constants(__dace_cuda_device).%sZero()" % factort

    # Set up options for code formatting
    opt = _get_codegen_gemm_opts(node, state, sdfg, adesc, bdesc, cdesc, alpha, beta, cdtype, func)

    # Matrix multiplication
    call = '''cublas{func}StridedBatched(__dace_cublas_handle,
        CUBLAS_OP_{ta}, CUBLAS_OP_{tb},
        {M}, {N}, {K},
        {alpha},
        ({dtype}*){array_prefix}{x}, {lda}, {stride_a},
        ({dtype}*){array_prefix}{y}, {ldb}, {stride_b},
        {beta},
        ({dtype}*){array_prefix}_c, {ldc}, {stride_c},
        {BATCH});'''

    opt['array_prefix'] = ''
    code = (environments.cublas.cuBLAS.handle_setup_code(node) + call.format_map(opt))
    tasklet = dace.sdfg.nodes.Tasklet(node.name,
                                      node.in_connectors,
                                      node.out_connectors,
                                      code,
                                      language=dace.dtypes.Language.CPP)

    # If buffers are not on the GPU, copy them
    # TODO: doesn't work when storage is Default and Default=GPU_Global
    if any(desc.storage not in [dace.StorageType.GPU_Global, dace.StorageType.CPU_Pinned]
           for desc in [adesc, bdesc, cdesc]):
        nsdfg = dace.SDFG('nested_batched_matmul')
        opt['array_prefix'] = '_'
        code = (environments.cublas.cuBLAS.handle_setup_code(node) + call.format_map(opt))
        tasklet = dace.sdfg.nodes.Tasklet(node.name, {
            '__a': dtypes.pointer(adesc.dtype),
            '__b': dtypes.pointer(bdesc.dtype)
        }, {'__c': dtypes.pointer(cdesc.dtype)},
                                          code,
                                          language=dace.dtypes.Language.CPP)
        for name, desc in [('_a', adesc), ('_b', bdesc), ('_c', cdesc)]:
            if isinstance(desc, dt.View):
                dcopy = desc.as_array()
            else:
                dcopy = dc(desc)
            dcopy.transient = False
            dcopy.lifetime = dtypes.AllocationLifetime.Scope
            dcopy_gpu = dc(dcopy)
            nsdfg.add_datadesc(name, dcopy)
            dcopy_gpu.transient = True
            dcopy_gpu.storage = dace.StorageType.GPU_Global
            nsdfg.add_datadesc(name + '_gpu', dcopy_gpu)
        nstate = nsdfg.add_state()
        a = nstate.add_read('_a')
        ga = nstate.add_access('_a_gpu')
        b = nstate.add_read('_b')
        gb = nstate.add_access('_b_gpu')
        c = nstate.add_write('_c')
        gc = nstate.add_access('_c_gpu')

        nstate.add_node(tasklet)
        nstate.add_nedge(a, ga, dace.Memlet.from_array('_a', adesc))
        nstate.add_nedge(b, gb, dace.Memlet.from_array('_b', bdesc))
        nstate.add_edge(ga, None, tasklet, '__a', dace.Memlet.from_array('_a_gpu', adesc))
        nstate.add_edge(gb, None, tasklet, '__b', dace.Memlet.from_array('_b_gpu', bdesc))
        nstate.add_edge(tasklet, '__c', gc, None, dace.Memlet.from_array('_c_gpu', cdesc))
        nstate.add_nedge(gc, c, dace.Memlet.from_array('_c', cdesc))

        return nsdfg
    # End of copy to GPU

    return tasklet