def expansion(node, state, sdfg):
    """
    Expand the node into a C++ tasklet that issues one ``cblas_<t>gemm`` call.

    The scaling factors are read from ``node.alpha`` and ``node.beta`` and
    emitted as C++ constructor expressions of the operand base type.

    :param node: Library node being expanded.
    :param state: State that contains ``node``.
    :param sdfg: SDFG that owns ``state`` and the operand descriptors.
    :return: A C++ tasklet that reuses the node's name and connectors.
    """
    node.validate(sdfg, state)

    operand_a, operand_b, _ = _get_matmul_operands(node, state, sdfg)
    _, adesc, _, _ = operand_a
    _, bdesc, _, _ = operand_b

    base_type = adesc.dtype.base_type
    func = to_blastype(base_type.type).lower() + 'gemm'
    ctype = base_type.ctype

    def scalar_literal(value):
        # Complex constants need their real and imaginary parts spelled out
        if isinstance(value, complex):
            return f'{ctype}({value.real}, {value.imag})'
        return f'{ctype}({value})'

    alpha = scalar_literal(node.alpha)
    beta = scalar_literal(node.beta)

    cdesc = sdfg.arrays[state.out_edges(node)[0].data.data]
    check_access(dtypes.ScheduleType.CPU_Multicore, adesc, bdesc, cdesc)

    opt = _get_codegen_gemm_opts(node, state, sdfg, adesc, bdesc, cdesc, alpha, beta, ctype, func)

    # The CBLAS interface takes enum names instead of the one-letter flags
    for flag in ('ta', 'tb'):
        if opt[flag] == 'N':
            opt[flag] = 'CblasNoTrans'
        else:
            opt[flag] = 'CblasTrans'

    prologue = ''
    if base_type in (dace.complex64, dace.complex128):
        # Complex scalars are declared locally and handed over by address
        prologue = f' {ctype} alpha = {alpha}; {ctype} beta = {beta}; '
        opt['alpha'] = '&alpha'
        opt['beta'] = '&beta'

    gemm_call = ("cblas_{func}(CblasColMajor, {ta}, {tb}, {M}, {N}, {K}, "
                 "{alpha}, {x}, {lda}, {y}, {ldb}, {beta}, _c, {ldc});").format_map(opt)

    return dace.sdfg.nodes.Tasklet(
        node.name,
        node.in_connectors,
        node.out_connectors,
        prologue + gemm_call,
        language=dace.dtypes.Language.CPP,
    )
def expansion(node, state, sdfg):
    """
    Expand the node into a C++ tasklet that loops over the batch dimension and
    issues one ``cblas_<t>gemm`` call per batch entry.

    The scaling factors are fixed to one (alpha) and zero (beta).

    :param node: Library node being expanded.
    :param state: State that contains ``node``.
    :param sdfg: SDFG that owns ``state`` and the operand descriptors.
    :return: A C++ tasklet that reuses the node's name and connectors.
    :raises ValueError: If the output base type is not float32, float64,
                        complex64 or complex128.
    """
    node.validate(sdfg, state)
    # Only the descriptors are needed here; shapes and strides are unused
    (_, adesc, _, _), (_, bdesc, _, _), _ = _get_matmul_operands(node, state, sdfg)
    cdesc = sdfg.arrays[state.out_edges(node)[0].data.data]
    check_access(dtypes.ScheduleType.CPU_Multicore, adesc, bdesc, cdesc)

    dtype = cdesc.dtype.base_type
    func = to_blastype(dtype.type).lower() + 'gemm'
    if dtype == dace.float32:
        alpha = "1.0f"
        beta = "0.0f"
    elif dtype == dace.float64:
        alpha = "1.0"
        beta = "0.0"
    elif dtype == dace.complex64:
        alpha = "dace::blas::BlasConstants::Get().Complex64Pone()"
        beta = "dace::blas::BlasConstants::Get().Complex64Zero()"
    elif dtype == dace.complex128:
        alpha = "dace::blas::BlasConstants::Get().Complex128Pone()"
        beta = "dace::blas::BlasConstants::Get().Complex128Zero()"
    else:
        # The message previously said "dot product", which this expansion is not
        raise ValueError("Unsupported type for BLAS batched matrix multiplication: " + str(dtype))

    opt = _get_codegen_gemm_opts(node, state, sdfg, adesc, bdesc, cdesc, alpha, beta, cdesc.dtype.ctype, func)

    # Adaptations for MKL/BLAS API
    opt['ta'] = 'CblasNoTrans' if opt['ta'] == 'N' else 'CblasTrans'
    opt['tb'] = 'CblasNoTrans' if opt['tb'] == 'N' else 'CblasTrans'

    # Each iteration offsets the three operand pointers by their batch stride
    code = (' for (int __ib = 0; __ib < {BATCH}; ++__ib) {{ '
            'cblas_{func}(CblasColMajor, {ta}, {tb}, {M}, {N}, {K}, {alpha}, '
            '(({dtype}*){x}) + __ib*{stride_a}, {lda}, '
            '(({dtype}*){y}) + __ib*{stride_b}, {ldb}, '
            '{beta}, '
            '(({dtype}*)_c) + __ib*{stride_c}, {ldc}); }}').format_map(opt)

    tasklet = dace.sdfg.nodes.Tasklet(node.name,
                                      node.in_connectors,
                                      node.out_connectors,
                                      code,
                                      language=dace.dtypes.Language.CPP)
    return tasklet
def expansion(node, state, sdfg):
    """
    Expand the node into a C++ tasklet that issues one ``cblas_<t>gemm`` call.

    When the node carries the default scaling factors (alpha equal to one and
    beta equal to zero, or no such attributes at all) the predefined constants
    are emitted. Any other value is emitted as a C++ constructor expression of
    the operand base type; for complex types the scalars are declared as
    locals and passed by address.

    :param node: Library node being expanded.
    :param state: State that contains ``node``.
    :param sdfg: SDFG that owns ``state`` and the operand descriptors.
    :return: A C++ tasklet that reuses the node's name and connectors.
    :raises ValueError: If the operand base type is not float32, float64,
                        complex64 or complex128.
    """
    node.validate(sdfg, state)
    # Only the descriptors are needed here; shapes and strides are unused
    (_, adesc, _, _), (_, bdesc, _, _), _ = _get_matmul_operands(node, state, sdfg)
    dtype = adesc.dtype.base_type
    func = to_blastype(dtype.type).lower() + 'gemm'

    if dtype == dace.float32:
        alpha = "1.0f"
        beta = "0.0f"
    elif dtype == dace.float64:
        alpha = "1.0"
        beta = "0.0"
    elif dtype == dace.complex64:
        alpha = "dace::blas::BlasConstants::Get().Complex64Pone()"
        beta = "dace::blas::BlasConstants::Get().Complex64Zero()"
    elif dtype == dace.complex128:
        alpha = "dace::blas::BlasConstants::Get().Complex128Pone()"
        beta = "dace::blas::BlasConstants::Get().Complex128Zero()"
    else:
        # The message previously said "dot product", which this expansion is not
        raise ValueError("Unsupported type for BLAS GEMM: " + str(dtype))

    # Honor non-default scaling factors instead of silently computing A@B.
    # getattr keeps the previous behavior for nodes without these attributes.
    node_alpha = getattr(node, 'alpha', 1)
    node_beta = getattr(node, 'beta', 0)
    custom_scalars = bool(node_alpha != 1 or node_beta != 0)
    if custom_scalars:

        def scalar_literal(value):
            # Complex constants need their real and imaginary parts spelled out
            if isinstance(value, complex):
                return f'{dtype.ctype}({value.real}, {value.imag})'
            return f'{dtype.ctype}({value})'

        alpha = scalar_literal(node_alpha)
        beta = scalar_literal(node_beta)

    cdesc = sdfg.arrays[state.out_edges(node)[0].data.data]
    opt = _get_codegen_gemm_opts(node, state, sdfg, adesc, bdesc, cdesc, alpha, beta, cdesc.dtype.ctype, func)

    # Adaptations for MKL/BLAS API
    opt['ta'] = 'CblasNoTrans' if opt['ta'] == 'N' else 'CblasTrans'
    opt['tb'] = 'CblasNoTrans' if opt['tb'] == 'N' else 'CblasTrans'

    code = ''
    if custom_scalars and dtype in (dace.complex64, dace.complex128):
        # The complex CBLAS routines take alpha/beta by pointer, so the custom
        # values are materialized as locals and passed by address
        code = f'{dtype.ctype} alpha = {alpha};\n{dtype.ctype} beta = {beta};\n'
        opt['alpha'] = '&alpha'
        opt['beta'] = '&beta'

    code += ("cblas_{func}(CblasColMajor, {ta}, {tb}, "
             "{M}, {N}, {K}, {alpha}, {x}, {lda}, {y}, {ldb}, {beta}, "
             "_c, {ldc});").format_map(opt)

    tasklet = dace.sdfg.nodes.Tasklet(node.name,
                                      node.in_connectors,
                                      node.out_connectors,
                                      code,
                                      language=dace.dtypes.Language.CPP)
    return tasklet
def expansion(node, state, sdfg):
    """
    Expand the node into a cuBLAS GEMM call.

    If ``node.alpha`` and ``node.beta`` are both one of the predefined
    constants (1.0 / 0.0), the constants stored with the cuBLAS handle are
    used. Otherwise the scalars are declared in the tasklet, and the pointer
    mode is switched to host before the call and back to device after it.

    If any operand is stored neither in ``GPU_Global`` nor in ``CPU_Pinned``
    storage, the call is wrapped in a nested SDFG that copies the operands
    through ``GPU_Global`` transients.

    :param node: Library node being expanded.
    :param state: State that contains ``node``.
    :param sdfg: SDFG that owns ``state`` and the operand descriptors.
    :return: A C++ tasklet, or a nested SDFG when copies are required.
    :raises ValueError: If an operand does not lead to an access node, or if
                        the operand base type is unsupported.
    """
    node.validate(sdfg, state)

    # Resolve the descriptors found at the far ends of the operand memlet paths
    descs = {'_a': None, '_b': None, '_c': None}
    for edge in state.in_edges(node):
        if edge.dst_conn in ('_a', '_b'):
            origin = state.memlet_path(edge)[0].src
            if isinstance(origin, dace.sdfg.nodes.AccessNode):
                descs[edge.dst_conn] = sdfg.arrays[origin.data]
    for edge in state.out_edges(node):
        if edge.src_conn == '_c':
            target = state.memlet_path(edge)[-1].dst
            if isinstance(target, dace.sdfg.nodes.AccessNode):
                descs['_c'] = sdfg.arrays[target.data]
    adesc, bdesc, cdesc = descs['_a'], descs['_b'], descs['_c']
    if not (adesc and bdesc and cdesc):
        raise ValueError('Unsupported input/output arrays')

    dtype = adesc.dtype.base_type
    func = '%sgemm' % to_blastype(dtype.type)

    # (DaCe type, CUDA-side type name, name infix of the constant accessors)
    supported = (
        (dace.float16, '__half', 'Half'),
        (dace.float32, 'float', 'Float'),
        (dace.float64, 'double', 'Double'),
        (dace.complex64, 'cuComplex', 'Complex64'),
        (dace.complex128, 'cuDoubleComplex', 'Complex128'),
    )
    for candidate, cdtype, factort in supported:
        if dtype == candidate:
            break
    else:
        raise ValueError("Unsupported type: " + str(dtype))

    call_prefix = environments.cublas.cuBLAS.handle_setup_code(node)
    call_suffix = ''

    # Scalars for which a predefined constant is available
    constants = {
        1.0: f"__state->cublas_handle.Constants(__dace_cuda_device).{factort}Pone()",
        0.0: f"__state->cublas_handle.Constants(__dace_cuda_device).{factort}Zero()",
    }
    if node.alpha in constants and node.beta in constants:
        alpha = constants[node.alpha]
        beta = constants[node.beta]
    else:

        def host_literal(value):
            # Complex constants need their real and imaginary parts spelled out
            if isinstance(value, complex):
                return f'{dtype.ctype}({value.real}, {value.imag})'
            return f'{dtype.ctype}({value})'

        # Declare the scalars in the tasklet and read them in host pointer mode
        call_prefix += (f'cublasSetPointerMode(__dace_cublas_handle, CUBLAS_POINTER_MODE_HOST); '
                        f'{dtype.ctype} alpha = {host_literal(node.alpha)}; '
                        f'{dtype.ctype} beta = {host_literal(node.beta)}; ')
        call_suffix += ' cublasSetPointerMode(__dace_cublas_handle, CUBLAS_POINTER_MODE_DEVICE); '
        alpha = f'({cdtype} *)&alpha'
        beta = f'({cdtype} *)&beta'

    opt = _get_codegen_gemm_opts(node, state, sdfg, adesc, bdesc, cdesc, alpha, beta, cdtype, func)

    direct_call = ('cublas{func}(__dace_cublas_handle, CUBLAS_OP_{ta}, CUBLAS_OP_{tb}, '
                   '{M}, {N}, {K}, {alpha}, ({dtype}*){x}, {lda}, ({dtype}*){y}, {ldb}, '
                   '{beta}, ({dtype}*)_c, {ldc});')
    tasklet = dace.sdfg.nodes.Tasklet(
        node.name,
        node.in_connectors,
        node.out_connectors,
        call_prefix + direct_call.format_map(opt) + call_suffix,
        language=dace.dtypes.Language.CPP,
    )

    # Nothing more to do when every operand is already in one of these storages
    gpu_ready = (dace.StorageType.GPU_Global, dace.StorageType.CPU_Pinned)
    if all(desc.storage in gpu_ready for desc in (adesc, bdesc, cdesc)):
        return tasklet

    # Otherwise wrap the call in a nested SDFG that copies through GPU buffers
    nsdfg = dace.SDFG('nested_gemm')
    for name, desc in (('_a', adesc), ('_b', bdesc), ('_c', cdesc)):
        host_desc = desc.as_array() if isinstance(desc, dt.View) else dc(desc)
        host_desc.lifetime = dtypes.AllocationLifetime.Scope
        device_desc = dc(host_desc)
        host_desc.transient = False
        nsdfg.add_datadesc(name, host_desc)
        device_desc.transient = True
        device_desc.storage = dace.StorageType.GPU_Global
        nsdfg.add_datadesc(name + '_gpu', device_desc)

    nstate = nsdfg.add_state()
    a = nstate.add_read('_a')
    ga = nstate.add_access('_a_gpu')
    b = nstate.add_read('_b')
    gb = nstate.add_access('_b_gpu')
    c = nstate.add_write('_c')
    gc = nstate.add_access('_c_gpu')

    # Rename the connectors and regenerate the call so that it refers to them
    tasklet.in_connectors = {"_conn" + conn: None for conn in tasklet.in_connectors}
    tasklet.out_connectors = {"_conn" + conn: None for conn in tasklet.out_connectors}
    opt['x'] = '_conn' + opt['x']
    opt['y'] = '_conn' + opt['y']
    nested_call = ('cublas{func}(__dace_cublas_handle, CUBLAS_OP_{ta}, CUBLAS_OP_{tb}, '
                   '{M}, {N}, {K}, {alpha}, ({dtype}*){x}, {lda}, ({dtype}*){y}, {ldb}, \n'
                   '{beta}, ({dtype}*)_conn_c, {ldc});')
    tasklet.code.as_string = call_prefix + nested_call.format_map(opt) + call_suffix

    nstate.add_node(tasklet)
    nstate.add_nedge(a, ga, dace.Memlet.from_array('_a', adesc))
    nstate.add_nedge(b, gb, dace.Memlet.from_array('_b', bdesc))
    nstate.add_edge(ga, None, tasklet, '_conn_a', dace.Memlet.from_array('_a_gpu', adesc))
    nstate.add_edge(gb, None, tasklet, '_conn_b', dace.Memlet.from_array('_b_gpu', bdesc))
    nstate.add_edge(tasklet, '_conn_c', gc, None, dace.Memlet.from_array('_c_gpu', cdesc))
    nstate.add_nedge(gc, c, dace.Memlet.from_array('_c', cdesc))

    if node.beta != 0.0:
        # With a nonzero beta the existing contents of C are copied in as well
        rc = nstate.add_read('_c')
        rgc = nstate.add_access('_c_gpu')
        tasklet.add_in_connector('_conn_cin')
        nstate.add_nedge(rc, rgc, dace.Memlet('_c'))
        nstate.add_edge(rgc, None, tasklet, '_conn_cin', dace.Memlet('_c_gpu'))

    return nsdfg
def expansion(node, state, sdfg):
    """
    Expand the node into a cuBLAS strided-batched GEMM call.

    ``cublas<t>gemmStridedBatched`` is emitted when none of
    ``node.compute_type``, ``node.accumulator_type`` and ``node.algorithm`` is
    set; otherwise ``cublasGemmStridedBatchedEx`` is emitted. Beta is always
    zero; alpha is taken from ``node.alpha``.

    If any operand is stored neither in ``GPU_Global`` nor in ``CPU_Pinned``
    storage, the call is wrapped in a nested SDFG that copies the operands
    through ``GPU_Global`` transients.

    :param node: Library node being expanded.
    :param state: State that contains ``node``.
    :param sdfg: SDFG that owns ``state`` and the operand descriptors.
    :return: A C++ tasklet, or a nested SDFG when copies are required.
    :raises ValueError: If an operand does not lead to an access node, or if
                        the output base type is unsupported.
    """
    node.validate(sdfg, state)

    # Resolve the descriptors found at the far ends of the operand memlet paths
    descs = {'_a': None, '_b': None, '_c': None}
    for edge in state.in_edges(node):
        if edge.dst_conn in ('_a', '_b'):
            origin = state.memlet_path(edge)[0].src
            if isinstance(origin, dace.sdfg.nodes.AccessNode):
                descs[edge.dst_conn] = sdfg.arrays[origin.data]
    for edge in state.out_edges(node):
        if edge.src_conn == '_c':
            target = state.memlet_path(edge)[-1].dst
            if isinstance(target, dace.sdfg.nodes.AccessNode):
                descs['_c'] = sdfg.arrays[target.data]
    adesc, bdesc, cdesc = descs['_a'], descs['_b'], descs['_c']
    if not (adesc and bdesc and cdesc):
        raise ValueError('Unsupported input/output arrays')

    gpu_ready = (dace.StorageType.GPU_Global, dace.StorageType.CPU_Pinned)
    needs_copy = not all(desc.storage in gpu_ready for desc in (adesc, bdesc, cdesc))

    dtype = cdesc.dtype.base_type
    func = '%sgemm' % to_blastype(dtype.type)

    # (DaCe type, CUDA-side type name, name infix of the constant accessors)
    supported = (
        (dace.float16, '__half', 'Half'),
        (dace.float32, 'float', 'Float'),
        (dace.float64, 'double', 'Double'),
        (dace.complex64, 'cuComplex', 'Complex64'),
        (dace.complex128, 'cuDoubleComplex', 'Complex128'),
    )
    for candidate, cdtype, factort in supported:
        if dtype == candidate:
            break
    else:
        raise ValueError("Unsupported type: " + str(dtype))

    call_prefix = environments.cublas.cuBLAS.handle_setup_code(node)
    call_suffix = ''

    # Scalars for which a predefined constant is available
    constants = {
        1.0: f"__state->cublas_handle.Constants(__dace_cuda_device).{factort}Pone()",
        0.0: f"__state->cublas_handle.Constants(__dace_cuda_device).{factort}Zero()",
    }
    if node.alpha in constants:
        alpha = constants[node.alpha]
        beta = "__state->cublas_handle.Constants(__dace_cuda_device).%sZero()" % factort
    else:
        # Complex constants need their real and imaginary parts spelled out
        if isinstance(node.alpha, complex):
            alpha_literal = f'{dtype.ctype}({node.alpha.real}, {node.alpha.imag})'
        else:
            alpha_literal = f'{dtype.ctype}({node.alpha})'

        # Declare the scalars in the tasklet and read them in host pointer mode
        call_prefix += (f'cublasSetPointerMode(__dace_cublas_handle, CUBLAS_POINTER_MODE_HOST); '
                        f'{dtype.ctype} alpha = {alpha_literal}; '
                        f'{dtype.ctype} beta = 0; \n')
        call_suffix += ' cublasSetPointerMode(__dace_cublas_handle, CUBLAS_POINTER_MODE_DEVICE); '
        beta = f'({cdtype} *)&beta'
        alpha = f'({cdtype} *)&alpha'

    opt = _get_codegen_gemm_opts(node, state, sdfg, adesc, bdesc, cdesc, alpha, beta, cdtype, func)
    # The nested variant names its connectors with one extra leading underscore
    prefix = '_' if needs_copy else ''
    opt['array_prefix'] = prefix

    if (node.compute_type is not None or node.accumulator_type is not None or node.algorithm is not None):
        # Extended API: pick the compute type and the algorithm
        if node.compute_type is not None:
            acctype = node.compute_type
        elif node.accumulator_type is not None:
            acctype = f'CUBLAS_COMPUTE_{to_cublas_computetype(node.accumulator_type)}'
        else:
            acctype = f'CUBLAS_COMPUTE_{to_cublas_computetype(dtype)}'

        algorithm = node.algorithm if node.algorithm is not None else 'CUBLAS_GEMM_DEFAULT_TENSOR_OP'

        call = (f" cublasGemmStridedBatchedEx(__dace_cublas_handle, "
                f"CUBLAS_OP_{opt['ta']}, CUBLAS_OP_{opt['tb']}, "
                f"{opt['M']}, {opt['N']}, {opt['K']}, "
                f"{alpha}, "
                f"{prefix}{opt['x']}, {dtype_to_cudadatatype(opt['xdtype'])}, {opt['lda']}, {opt['stride_a']}, "
                f"{prefix}{opt['y']}, {dtype_to_cudadatatype(opt['ydtype'])}, {opt['ldb']}, {opt['stride_b']}, "
                f"{beta}, "
                f"{prefix}_c, {dtype_to_cudadatatype(opt['cdtype'])}, {opt['ldc']}, {opt['stride_c']}, "
                f"{opt['BATCH']}, {acctype}, {algorithm}); ")
    else:
        call = ('cublas{func}StridedBatched(__dace_cublas_handle, CUBLAS_OP_{ta}, CUBLAS_OP_{tb}, '
                '{M}, {N}, {K}, {alpha}, '
                '({dtype}*){array_prefix}{x}, {lda}, {stride_a}, '
                '({dtype}*){array_prefix}{y}, {ldb}, {stride_b}, '
                '{beta}, '
                '({dtype}*){array_prefix}_c, {ldc}, {stride_c}, {BATCH});').format_map(opt)

    code = call_prefix + call + call_suffix
    tasklet = dace.sdfg.nodes.Tasklet(node.name,
                                      node.in_connectors,
                                      node.out_connectors,
                                      code,
                                      language=dace.dtypes.Language.CPP)
    if not needs_copy:
        return tasklet

    # Wrap the call in a nested SDFG that copies the operands through GPU buffers
    nsdfg = dace.SDFG('nested_batched_matmul')
    tasklet = dace.sdfg.nodes.Tasklet(node.name, {
        '__a': dtypes.pointer(adesc.dtype),
        '__b': dtypes.pointer(bdesc.dtype)
    }, {'__c': dtypes.pointer(cdesc.dtype)},
                                      code,
                                      language=dace.dtypes.Language.CPP)

    for name, desc in (('_a', adesc), ('_b', bdesc), ('_c', cdesc)):
        host_desc = desc.as_array() if isinstance(desc, dt.View) else dc(desc)
        host_desc.transient = False
        host_desc.lifetime = dtypes.AllocationLifetime.Scope
        device_desc = dc(host_desc)
        nsdfg.add_datadesc(name, host_desc)
        device_desc.transient = True
        device_desc.storage = dace.StorageType.GPU_Global
        nsdfg.add_datadesc(name + '_gpu', device_desc)

    nstate = nsdfg.add_state()
    a = nstate.add_read('_a')
    ga = nstate.add_access('_a_gpu')
    b = nstate.add_read('_b')
    gb = nstate.add_access('_b_gpu')
    c = nstate.add_write('_c')
    gc = nstate.add_access('_c_gpu')
    nstate.add_node(tasklet)
    nstate.add_nedge(a, ga, dace.Memlet.from_array('_a', adesc))
    nstate.add_nedge(b, gb, dace.Memlet.from_array('_b', bdesc))
    nstate.add_edge(ga, None, tasklet, '__a', dace.Memlet.from_array('_a_gpu', adesc))
    nstate.add_edge(gb, None, tasklet, '__b', dace.Memlet.from_array('_b_gpu', bdesc))
    nstate.add_edge(tasklet, '__c', gc, None, dace.Memlet.from_array('_c_gpu', cdesc))
    nstate.add_nedge(gc, c, dace.Memlet.from_array('_c', cdesc))

    return nsdfg
def expansion(node, state, sdfg):
    """
    Expand the node into a cuBLAS GEMM call with alpha fixed to one and beta
    fixed to zero.

    If any operand is stored neither in ``GPU_Global`` nor in ``CPU_Pinned``
    storage, the call is wrapped in a nested SDFG that copies the operands
    through ``GPU_Global`` transients. In that case the tasklet connectors are
    renamed with a ``_conn`` prefix so they do not share names with the nested
    SDFG's arrays, and the generated call refers to the renamed connectors.

    :param node: Library node being expanded; its ``dtype`` selects the routine.
    :param state: State that contains ``node``.
    :param sdfg: SDFG that owns ``state`` and the operand descriptors.
    :return: A C++ tasklet, or a nested SDFG when copies are required.
    :raises ValueError: If ``node.dtype`` is unsupported, or if an operand does
                        not lead to an access node.
    """
    node.validate(sdfg, state)

    dtype = node.dtype
    func = '%sgemm' % to_blastype(dtype.type)
    if dtype == dace.float16:
        cdtype = '__half'
        factort = 'Half'
    elif dtype == dace.float32:
        cdtype = 'float'
        factort = 'Float'
    elif dtype == dace.float64:
        cdtype = 'double'
        factort = 'Double'
    elif dtype == dace.complex64:
        cdtype = 'cuComplex'
        factort = 'Complex64'
    elif dtype == dace.complex128:
        cdtype = 'cuDoubleComplex'
        factort = 'Complex128'
    else:
        raise ValueError("Unsupported type: " + str(dtype))

    alpha = "dace::blas::CublasConstants::Get(__dace_cuda_device).%sPone()" % factort
    beta = "dace::blas::CublasConstants::Get(__dace_cuda_device).%sZero()" % factort

    # Find inputs and output
    adesc, bdesc, cdesc = None, None, None
    for e in state.in_edges(node):
        if e.dst_conn == '_a':
            anode = state.memlet_path(e)[0].src
            if isinstance(anode, dace.sdfg.nodes.AccessNode):
                adesc = sdfg.arrays[anode.data]
        elif e.dst_conn == '_b':
            bnode = state.memlet_path(e)[0].src
            if isinstance(bnode, dace.sdfg.nodes.AccessNode):
                bdesc = sdfg.arrays[bnode.data]
    for e in state.out_edges(node):
        if e.src_conn == '_c':
            cnode = state.memlet_path(e)[-1].dst
            if isinstance(cnode, dace.sdfg.nodes.AccessNode):
                cdesc = sdfg.arrays[cnode.data]
    if not adesc or not bdesc or not cdesc:
        raise ValueError('Unsupported input/output arrays')

    # Set up options for code formatting
    opt = _get_codegen_gemm_opts(node, state, sdfg, adesc, bdesc, cdesc, alpha, beta, cdtype, func)

    needs_copy = any(desc.storage not in (dace.StorageType.GPU_Global, dace.StorageType.CPU_Pinned)
                     for desc in (adesc, bdesc, cdesc))

    # In the nested variant the connectors are renamed, so the generated call
    # must use the renamed operands. Previously it kept the unprefixed names,
    # which inside the nested SDFG are the host-side arrays.
    conn_prefix = '_conn' if needs_copy else ''
    if needs_copy:
        opt['x'] = conn_prefix + opt['x']
        opt['y'] = conn_prefix + opt['y']

    # Matrix multiplication
    call = ('cublas{func}(__dace_cublas_handle, CUBLAS_OP_{ta}, CUBLAS_OP_{tb}, '
            '{M}, {N}, {K}, {alpha}, ({dtype}*){x}, {lda}, ({dtype}*){y}, {ldb}, '
            '{beta}, ({dtype}*)' + conn_prefix + '_c, {ldc});')

    code = (environments.cublas.cuBLAS.handle_setup_code(node) + call.format_map(opt))
    tasklet = dace.sdfg.nodes.Tasklet(node.name,
                                      node.in_connectors,
                                      node.out_connectors,
                                      code,
                                      language=dace.dtypes.Language.CPP)

    if not needs_copy:
        return tasklet

    # If buffers are not on the GPU, copy them
    nsdfg = dace.SDFG('nested_gemm')
    for name, desc in [('_a', adesc), ('_b', bdesc), ('_c', cdesc)]:
        dcopy = dc(desc)
        dcopy.transient = False
        nsdfg.add_datadesc(name, dcopy)
        dcopy_gpu = dc(desc)
        dcopy_gpu.transient = True
        dcopy_gpu.storage = dace.StorageType.GPU_Global
        nsdfg.add_datadesc(name + '_gpu', dcopy_gpu)

    nstate = nsdfg.add_state()
    a = nstate.add_read('_a')
    ga = nstate.add_access('_a_gpu')
    b = nstate.add_read('_b')
    gb = nstate.add_access('_b_gpu')
    c = nstate.add_write('_c')
    gc = nstate.add_access('_c_gpu')

    # Rename the connectors to match the operand names used in the call
    tasklet.in_connectors = {"_conn" + k: None for k in tasklet.in_connectors}
    tasklet.out_connectors = {"_conn" + k: None for k in tasklet.out_connectors}

    nstate.add_node(tasklet)
    nstate.add_nedge(a, ga, dace.Memlet.from_array('_a', adesc))
    nstate.add_nedge(b, gb, dace.Memlet.from_array('_b', bdesc))
    nstate.add_edge(ga, None, tasklet, '_conn_a', dace.Memlet.from_array('_a_gpu', adesc))
    nstate.add_edge(gb, None, tasklet, '_conn_b', dace.Memlet.from_array('_b_gpu', bdesc))
    nstate.add_edge(tasklet, '_conn_c', gc, None, dace.Memlet.from_array('_c_gpu', cdesc))
    nstate.add_nedge(gc, c, dace.Memlet.from_array('_c', cdesc))

    return nsdfg
def expansion(node, state, sdfg):
    """
    Expand the node into a ``cublas<t>gemmStridedBatched`` call with alpha
    fixed to one and beta fixed to zero.

    If any operand is stored neither in ``GPU_Global`` nor in ``CPU_Pinned``
    storage, the call is wrapped in a nested SDFG that copies the operands
    through ``GPU_Global`` transients.

    :param node: Library node being expanded.
    :param state: State that contains ``node``.
    :param sdfg: SDFG that owns ``state`` and the operand descriptors.
    :return: A C++ tasklet, or a nested SDFG when copies are required.
    :raises ValueError: If an operand does not lead to an access node, or if
                        the output base type is unsupported.
    """
    node.validate(sdfg, state)

    # Resolve the descriptors found at the far ends of the operand memlet paths
    descs = {'_a': None, '_b': None, '_c': None}
    for edge in state.in_edges(node):
        if edge.dst_conn in ('_a', '_b'):
            origin = state.memlet_path(edge)[0].src
            if isinstance(origin, dace.sdfg.nodes.AccessNode):
                descs[edge.dst_conn] = sdfg.arrays[origin.data]
    for edge in state.out_edges(node):
        if edge.src_conn == '_c':
            target = state.memlet_path(edge)[-1].dst
            if isinstance(target, dace.sdfg.nodes.AccessNode):
                descs['_c'] = sdfg.arrays[target.data]
    adesc, bdesc, cdesc = descs['_a'], descs['_b'], descs['_c']
    if not (adesc and bdesc and cdesc):
        raise ValueError('Unsupported input/output arrays')

    dtype = cdesc.dtype.base_type
    func = '%sgemm' % to_blastype(dtype.type)

    # (DaCe type, CUDA-side type name, name infix of the constant accessors)
    supported = (
        (dace.float16, '__half', 'Half'),
        (dace.float32, 'float', 'Float'),
        (dace.float64, 'double', 'Double'),
        (dace.complex64, 'cuComplex', 'Complex64'),
        (dace.complex128, 'cuDoubleComplex', 'Complex128'),
    )
    for candidate, cdtype, factort in supported:
        if dtype == candidate:
            break
    else:
        raise ValueError("Unsupported type: " + str(dtype))

    alpha = "__state->cublas_handle.Constants(__dace_cuda_device).%sPone()" % factort
    beta = "__state->cublas_handle.Constants(__dace_cuda_device).%sZero()" % factort

    opt = _get_codegen_gemm_opts(node, state, sdfg, adesc, bdesc, cdesc, alpha, beta, cdtype, func)

    call = ('cublas{func}StridedBatched(__dace_cublas_handle, CUBLAS_OP_{ta}, CUBLAS_OP_{tb}, '
            '{M}, {N}, {K}, {alpha}, '
            '({dtype}*){array_prefix}{x}, {lda}, {stride_a}, '
            '({dtype}*){array_prefix}{y}, {ldb}, {stride_b}, '
            '{beta}, '
            '({dtype}*){array_prefix}_c, {ldc}, {stride_c}, {BATCH});')

    def render(array_prefix):
        # Produce the tasklet code with the given prefix on the operand names
        opt['array_prefix'] = array_prefix
        return environments.cublas.cuBLAS.handle_setup_code(node) + call.format_map(opt)

    tasklet = dace.sdfg.nodes.Tasklet(node.name,
                                      node.in_connectors,
                                      node.out_connectors,
                                      render(''),
                                      language=dace.dtypes.Language.CPP)

    # TODO: doesn't work when storage is Default and Default=GPU_Global
    gpu_ready = (dace.StorageType.GPU_Global, dace.StorageType.CPU_Pinned)
    if all(desc.storage in gpu_ready for desc in (adesc, bdesc, cdesc)):
        return tasklet

    # Wrap the call in a nested SDFG that copies the operands through GPU buffers
    nsdfg = dace.SDFG('nested_batched_matmul')
    # The nested tasklet names its connectors with one extra leading underscore
    tasklet = dace.sdfg.nodes.Tasklet(node.name, {
        '__a': dtypes.pointer(adesc.dtype),
        '__b': dtypes.pointer(bdesc.dtype)
    }, {'__c': dtypes.pointer(cdesc.dtype)},
                                      render('_'),
                                      language=dace.dtypes.Language.CPP)

    for name, desc in (('_a', adesc), ('_b', bdesc), ('_c', cdesc)):
        host_desc = desc.as_array() if isinstance(desc, dt.View) else dc(desc)
        host_desc.transient = False
        host_desc.lifetime = dtypes.AllocationLifetime.Scope
        device_desc = dc(host_desc)
        nsdfg.add_datadesc(name, host_desc)
        device_desc.transient = True
        device_desc.storage = dace.StorageType.GPU_Global
        nsdfg.add_datadesc(name + '_gpu', device_desc)

    nstate = nsdfg.add_state()
    a = nstate.add_read('_a')
    ga = nstate.add_access('_a_gpu')
    b = nstate.add_read('_b')
    gb = nstate.add_access('_b_gpu')
    c = nstate.add_write('_c')
    gc = nstate.add_access('_c_gpu')
    nstate.add_node(tasklet)
    nstate.add_nedge(a, ga, dace.Memlet.from_array('_a', adesc))
    nstate.add_nedge(b, gb, dace.Memlet.from_array('_b', bdesc))
    nstate.add_edge(ga, None, tasklet, '__a', dace.Memlet.from_array('_a_gpu', adesc))
    nstate.add_edge(gb, None, tasklet, '__b', dace.Memlet.from_array('_b_gpu', bdesc))
    nstate.add_edge(tasklet, '__c', gc, None, dace.Memlet.from_array('_c_gpu', cdesc))
    nstate.add_nedge(gc, c, dace.Memlet.from_array('_c', cdesc))

    return nsdfg