def _generic_fmm(proc_idx, queue, device_id):
    # Unpack the function arguments
    a: ArgsFmm = queue.get()
    X1: torch.Tensor = a.X1
    X2: torch.Tensor = a.X2
    cuda_inputs = X1.is_cuda
    out = a.out
    kernel, gpu_dtype = a.kernel, a.gpu_dtype
    max_mem = a.max_mem
    num_streams = a.num_streams

    # flags and local variables
    change_dtype = gpu_dtype != X1.dtype
    X1_equal_X2 = _gpu_tns_same_memory(X1, X2)
    use_gpu_bufs = change_dtype or not cuda_inputs
    stride = "F" if is_f_contig(out, strict=True) else "C"
    j_iter = 0
    dts = sizeof_dtype(gpu_dtype)
    tc_device = torch.device('cuda:%d' % (int(device_id)))
    avail_mem = max_mem / dts

    # Choose block sizes n, m such that we won't run out of GPU memory
    ntot, d = X1.shape
    mtot = X2.shape[0]
    extra_mem = kernel.extra_mem()
    if cuda_inputs and not change_dtype:
        # No allocation will be performed by us. Only in-kernel stuff.
        n, m = select_dim_over_nm(max_n=ntot, max_m=mtot, d=d,
                                  coef_nd=extra_mem.get('nd', 0),
                                  coef_md=extra_mem.get('md', 0),
                                  coef_nm=extra_mem.get('nm', 0),
                                  coef_n=extra_mem.get('n', 0),
                                  coef_m=extra_mem.get('m', 0),
                                  rest=extra_mem.get('d', 0),
                                  max_mem=avail_mem)
    else:
        n, m = select_dim_over_nm(
            max_n=ntot, max_m=mtot, d=d,
            coef_nd=num_streams * (extra_mem.get('nd', 0) + 1),
            coef_md=num_streams * (extra_mem.get('md', 0) + 1),
            coef_nm=num_streams * (extra_mem.get('nm', 0) + 1),
            coef_n=extra_mem.get('n', 0),
            coef_m=extra_mem.get('m', 0),
            rest=extra_mem.get('d', 0),
            max_mem=avail_mem)

    # Create streams
    streams = [tcd.Stream(device=tc_device) for _ in range(num_streams)]

    # Create buffers
    if use_gpu_bufs:
        gX1 = create_same_stride((n, d), X1, gpu_dtype, tc_device)
        gX2_list = [create_same_stride((m, d), X2, gpu_dtype, tc_device)
                    for _ in range(num_streams)]
        gout_list = [create_same_stride((n, m), out, gpu_dtype, tc_device)
                     for _ in range(num_streams)]
    if not cuda_inputs:
        cpu_buf_list = [create_same_stride((n, m), out, gpu_dtype, 'cpu', pin_memory=True)
                        for _ in range(num_streams)]

    # Define helpers for the copy-back operations (from cpu_buf to output)
    copy_ops = [None] * num_streams

    def wrap_copy_op(stream_idx):
        if copy_ops[stream_idx] is not None:
            copy_ops[stream_idx]()
            copy_ops[stream_idx] = None

    def do_copy_op(output, buf, i_, ic_, j_, jc_):
        # This function will also do the type conversion
        output[i_:i_ + ic_, j_:j_ + jc_].copy_(buf[:ic_, :jc_])

    # Kernel computation begin
    with tcd.device(tc_device):
        for i in range(0, ntot, n):
            ic = min(n, ntot - i)

            with tcd.stream(streams[j_iter % len(streams)]):
                X1_chunk = X1.narrow(0, i, ic)
                if use_gpu_bufs:
                    cur_gX1 = gX1.narrow(0, 0, ic)
                    cur_gX1.copy_(X1_chunk, non_blocking=True)
                else:
                    cur_gX1 = X1_chunk

            for j in range(0, mtot, m):
                jc = min(m, mtot - j)
                # Choose the buffers for this inner iteration
                stream_id = j_iter % len(streams)
                stream = streams[stream_id]
                if use_gpu_bufs:
                    gX2 = gX2_list[stream_id]
                    gout = gout_list[stream_id]
                if not cuda_inputs:
                    cpu_buf = cpu_buf_list[stream_id]

                # Sync for buffers we must use now (e.g. 2 previous iters)
                with tcd.stream(stream):  # Inner-loop
                    stream.synchronize()
                    wrap_copy_op(stream_id)

                    if X1_equal_X2 and j < i:  # Shortcut for symmetric kernels
                        jc = min(m, mtot - j)
                        out[i:i + ic, j:j + jc].copy_(out[j:j + jc, i:i + ic].T,
                                                      non_blocking=True)
                        j_iter += 1
                        continue

                    # Copy (CPU->GPU)
                    X2_chunk = X2.narrow(0, j, jc)
                    if use_gpu_bufs:
                        cur_gX2 = gX2.narrow(0, 0, jc)
                        cur_gX2.copy_(X2_chunk, non_blocking=True)
                    else:
                        cur_gX2 = X2_chunk

                    if use_gpu_bufs:
                        cur_gout = gout[:ic, :jc]
                    else:
                        cur_gout = out[i:i + ic, j:j + jc]
                    cur_gout.fill_(0.0)

                    # Compute
                    ddd = kernel._prepare(cur_gX1, cur_gX2)
                    kernel._apply(cur_gX1, cur_gX2.T, cur_gout)
                    cur_gout = kernel._finalize(cur_gout, ddd)

                    # Copy Back (GPU->CPU)
                    if not cuda_inputs:
                        # copy_ does not care about the contiguity of copies, as long as
                        # it's consistent. However, in case of C-contiguous inputs it will
                        # create an intermediate array which is undesired. We use
                        # cuda_memcpy2d_async which works well with C-contiguous arrays.
                        if stride == "F":
                            copy_to_host(ic, jc, cur_gout, 0, 0, cpu_buf, 0, 0, s=stream)
                        else:
                            cuda_memcpy2d_async(
                                dst=cpu_buf.data_ptr(), dpitch=cpu_buf.stride(0) * dts,
                                src=cur_gout.data_ptr(), spitch=cur_gout.stride(0) * dts,
                                width=jc * dts, height=ic, stream=stream._as_parameter_)
                        copy_ops[stream_id] = partial(do_copy_op, out, cpu_buf, i, ic, j, jc)
                    elif change_dtype:
                        out.narrow(0, i, ic).narrow(1, j, jc).copy_(cur_gout, non_blocking=True)
                    j_iter += 1

        for i in range(num_streams):
            streams[i].synchronize()
            wrap_copy_op(i)
    return out

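# The tiling scheme used by `_generic_fmm` is easier to see in isolation. The sketch below
# is a hedged illustration only: it assumes a plain Gaussian kernel and fixed tile sizes,
# and omits ArgsFmm, CUDA streams, pinned buffers and the memory-based `select_dim_over_nm`
# sizing. It shows how the (ntot, mtot) output is filled one (n, m) tile at a time so only
# small blocks of X1, X2 and out ever need to reside on the device.
def _example_blocked_gaussian_kernel(X1, X2, out, n=1024, m=1024, sigma=1.0):
    import torch
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    ntot, mtot = X1.shape[0], X2.shape[0]
    for i in range(0, ntot, n):
        ic = min(n, ntot - i)
        gX1 = X1[i:i + ic].to(device, non_blocking=True)
        for j in range(0, mtot, m):
            jc = min(m, mtot - j)
            gX2 = X2[j:j + jc].to(device, non_blocking=True)
            # Squared pairwise distances for this tile, then the Gaussian kernel
            d2 = torch.cdist(gX1, gX2).pow_(2)
            out[i:i + ic, j:j + jc].copy_(torch.exp(-d2 / (2 * sigma ** 2)))
    return out

# Hypothetical usage:
#   X1, X2 = torch.randn(3000, 10), torch.randn(2500, 10)
#   out = _example_blocked_gaussian_kernel(X1, X2, torch.empty(3000, 2500))
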
def par_lauum_f_lower(A: torch.Tensor,
                      block_allocs: List[BlockAlloc],
                      my_rows: List[int],
                      barrier: threading.Barrier,
                      device_id: int,
                      cublas_handle,
                      independent_output: bool):
    N = A.shape[0]

    lauum_fn = choose_fn(A.dtype, scll.dlauum, scll.slauum, "Lapack LAUUM")
    trmm_fn = choose_fn(A.dtype, cublasDtrmm, cublasStrmm, "cuBlas TRMM")
    gemm_fn = choose_fn(A.dtype, cublasDgemm, cublasSgemm, "cuBlas GEMM")
    syrk_fn = choose_fn(A.dtype, cublasDsyrk, cublasSsyrk, "cuBlas SYRK")

    tc_device = torch.device('cuda:%d' % (device_id))
    s1 = torch.cuda.Stream(device=tc_device)
    s2 = torch.cuda.Stream(device=tc_device)
    cublasSetStream(cublas_handle, s1._as_parameter_)

    max_block_size = max(ba.length for ba in block_allocs)
    my_rows = sorted(my_rows)

    with torch.cuda.device(tc_device), torch.cuda.stream(s1):
        # Preallocate 2 columns
        whole_col_b = create_fortran((A.shape[0], max_block_size), A.dtype, tc_device)
        whole_col_r = create_fortran((A.shape[0], max_block_size), A.dtype, tc_device)
        temp_bb = create_fortran((max_block_size, max_block_size), A.dtype, 'cpu',
                                 pin_memory=True)

        for b in range(len(block_allocs)):
            bb = block_allocs[b]
            # Load col b.
            # Instead of loading the whole column only load the last rows
            # as necessary by inspecting the minimum value in my_rows which is >= b.
            try:
                min_row = min([r for r in my_rows if r >= b])
                b_start = block_allocs[min_row].start
                col_b = copy_to_device(N - b_start, bb.length, A, b_start, bb.start,
                                       whole_col_b, 0, 0, s1)
            except ValueError:
                pass  # No column here
            if not independent_output:
                barrier.wait()

            for r in my_rows:
                if r < b:
                    continue
                if r == b:
                    # SYRK on g_b[bb.length:, :] with output replacing g_b[:bb.length, :]
                    # C = beta*C + alpha * op(A) @ op(A).T
                    if b_start + bb.length < N:
                        syrk_fn(cublas_handle, uplo='L', trans='T',
                                n=bb.length, k=col_b.shape[0] - bb.length,
                                alpha=1.0, A=col_b[bb.length:, :].data_ptr(), lda=col_b.stride(1),
                                beta=0.0, C=col_b.data_ptr(), ldc=col_b.stride(1))

                    # CPU LAUUM on A[bb.start:bb.end, bb.start:bb.end].
                    # This is a bit messy, should do cleanup.
                    Abb = A[bb.start:bb.end, bb.start:bb.end]  # L\U
                    if independent_output:
                        Abb_np = Abb.numpy().copy(order="F")
                        # Make symmetric: L\L
                        copy_triang(Abb_np, upper=False)
                        uu, info = lauum_fn(Abb_np, lower=1, overwrite_c=True)  # LAU\L
                        Abb.copy_(torch.from_numpy(uu.T))  # L\LAU
                    else:
                        uu, info = lauum_fn(Abb.numpy(), lower=1, overwrite_c=False)  # LAU\L
                        if b_start + bb.length < N:
                            zero_triang(uu, upper=True)
                        Abb.copy_(torch.from_numpy(uu))

                    if b_start + bb.length < N:
                        # It is IMPORTANT to do the copy on s1 and then sync it.
                        tbb = copy_to_host(bb.length, bb.length, col_b, 0, 0, temp_bb, 0, 0, s1)
                        s1.synchronize()
                        if independent_output:
                            Abb.add_(torch.triu(tbb.T))
                        else:
                            Abb.add_(tbb)
                else:  # r > b
                    br = block_allocs[r]

                    # Load column r. Since r > b this column will be shorter than column b
                    col_r = copy_to_device(N - br.start, br.length, A, br.start, br.start,
                                           whole_col_r, 0, 0, s1)
                    # Restrict column b to only the last 'r' rows
                    ccb = col_b[br.start - b_start:, :]

                    # TRMM on g_r[0:br.length, :] which is triangular (r*r)
                    #     and cur_g_b[0:br.length, :]
                    #     output is a r*b matrix and should be stored in a separate g_out block
                    # Could store output in the first rows of g_b
                    # C = alpha * op(A) @ B -- A triangular
                    trmm_fn(handle=cublas_handle, side='L', uplo='L', trans='T', diag='N',
                            m=br.length, n=bb.length, alpha=1.0,
                            A=col_r.data_ptr(), lda=col_r.stride(1),
                            B=ccb.data_ptr(), ldb=ccb.stride(1),
                            C=ccb.data_ptr(), ldc=ccb.stride(1))

                    # GEMM on g_r[br.length:, :].T and cur_g_b[bb.length:, :]
                    #     output is the same r*b matrix as before, outputs need to be summed.
                    # C = alpha * op(A) @ op(B) + beta * C
                    if br.end < N:
                        gemm_fn(handle=cublas_handle, transa='T', transb='N',
                                m=br.length, n=bb.length, k=col_r.shape[0] - br.length,
                                alpha=1.0, A=col_r[br.length:, :].data_ptr(), lda=col_r.stride(1),
                                B=ccb[br.length:, :].data_ptr(), ldb=ccb.stride(1),
                                beta=1.0, C=ccb.data_ptr(), ldc=ccb.stride(1))

                    # Copy back to A[r, b]
                    if independent_output:
                        _temp_cpu = copy_to_host(br.length, bb.length, ccb, 0, 0,
                                                 temp_bb, 0, 0, s1)
                        s1.synchronize()
                        A[bb.start:bb.end, br.start:br.end].copy_(_temp_cpu.T)
                    else:
                        s1.synchronize()
                        copy_to_host(br.length, bb.length, ccb, 0, 0, A, br.start, bb.start, s2)
                        s2.synchronize()

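# What does the LAUUM step above compute? For the lower-triangular case, LAPACK's ?lauum
# replaces L with the lower triangle of L^T @ L (the in-place product of a triangular factor
# with its transpose); the upper triangle of the array is left untouched. A small, hedged
# self-check against an explicit matmul, assuming only NumPy and SciPy are available
# (the function name is illustrative, not part of the library):
def _example_check_lauum_lower(n=6):
    import numpy as np
    from scipy.linalg import lapack as scll

    L = np.tril(np.random.randn(n, n)).astype(np.float64, order="F")
    uu, info = scll.dlauum(L, lower=1, overwrite_c=False)
    assert info == 0
    # Only the lower triangle of the result is meaningful; here the strict upper triangle
    # stays zero because L was strictly lower-triangular to begin with.
    np.testing.assert_allclose(np.tril(uu), np.tril(L.T @ L))
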
def _ic_cholesky(A, upper, device, cusolver_handle):
    """Cholesky factorization of matrix `A` on the GPU.

    Uses the cuSOLVER library for the implementation of the POTRF function.

    Parameters
    ----------
    A : [n, n] CPU or GPU array (column-contiguous)
        The (positive definite) matrix which should be factorized.
    upper : bool
        Whether we need to factorize the upper or lower portion of `A`. The other side
        of the matrix will not be touched.
    device : int
        The GPU device on which to run the factorization.
    cusolver_handle
        Pointer to the cuSOLVER context, which needs to be initialized before calling
        the function.

    Returns
    -------
    A : [n, n] CPU or GPU array (column-contiguous)
        The factorization of A which overwrites the upper (or lower) triangular part
        of the matrix A. This is not a copy of the original matrix.
    """
    # Check library initialization
    if cusolver_handle is None:
        raise RuntimeError("CuSolver must be initialized "
                           "before running in-core Cholesky.")
    if not is_f_contig(A):
        raise RuntimeError("Cholesky input must be F-contiguous")

    uplo = 'U' if upper else 'L'
    n = A.shape[0]
    tc_device = torch.device("cuda:%d" % (device))

    # Choose functions by dtype
    potrf_buf_size = choose_fn(A.dtype, cusolverDnDpotrf_bufferSize,
                               cusolverDnSpotrf_bufferSize, "POTRF Buffer size")
    potrf_fn = choose_fn(A.dtype, cusolverDnDpotrf, cusolverDnSpotrf, "POTRF")

    # noinspection PyUnresolvedReferences
    with torch.cuda.device(tc_device):
        # Copy A to device memory
        if A.is_cuda:
            Agpu = A
        else:
            Agpu = create_fortran(A.shape, A.dtype, tc_device)
            copy_to_device(n, n, A, 0, 0, Agpu, 0, 0)

        # Create workspace buffer
        potrf_bsize = potrf_buf_size(handle=cusolver_handle, uplo=uplo, n=n,
                                     A=Agpu.data_ptr(), lda=n)
        potrf_wspace = create_fortran((potrf_bsize,), A.dtype, tc_device)
        dev_info = torch.tensor(4, dtype=torch.int32, device=tc_device)

        # Run cholesky
        potrf_fn(handle=cusolver_handle, uplo=uplo, n=n, A=Agpu.data_ptr(), lda=n,
                 workspace=potrf_wspace.data_ptr(), Lwork=potrf_bsize, devInfo=dev_info)
        torch.cuda.synchronize()

        # Copy back to CPU
        if not A.is_cuda:
            copy_to_host(n, n, Agpu, 0, 0, A, 0, 0)
            del Agpu
        del potrf_wspace, dev_info
    return A

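# The in-core POTRF wrapper above is expected to leave the Cholesky factor in the requested
# triangle of `A` and leave the other triangle untouched. The hedged reference below shows
# what that result should look like using only PyTorch built-ins; it deliberately does not
# call _ic_cholesky itself, since that requires an initialized cuSOLVER handle:
def _example_expected_cholesky(n=4):
    import torch

    M = torch.randn(n, n, dtype=torch.float64)
    A = M @ M.T + n * torch.eye(n, dtype=torch.float64)  # positive definite by construction
    L = torch.linalg.cholesky(A)                          # lower factor with L @ L.T == A
    # _ic_cholesky(A_f, upper=False, device=0, cusolver_handle=handle) would be expected to
    # overwrite tril(A_f) with L while keeping triu(A_f, 1) as it was.
    assert torch.allclose(L @ L.T, A)
    return L
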
def par_lauum_f_lower(A: torch.Tensor,
                      block_allocs: List[BlockAlloc],
                      my_rows: List[int],
                      barrier: threading.Barrier,
                      device_id: int,
                      cublas_handle,
                      independent_output: bool):
    N = A.shape[0]
    is_cuda = A.device.type == "cuda"

    trmm_fn = choose_fn(A.dtype, cublasDtrmm, cublasStrmm, "cuBlas TRMM")
    gemm_fn = choose_fn(A.dtype, cublasDgemm, cublasSgemm, "cuBlas GEMM")
    syrk_fn = choose_fn(A.dtype, cublasDsyrk, cublasSsyrk, "cuBlas SYRK")

    tc_device = torch.device('cuda:%d' % (device_id))
    s1 = torch.cuda.Stream(device=tc_device)
    s3 = torch.cuda.Stream(device=tc_device)

    max_block_size = max(ba.length for ba in block_allocs)
    my_rows = sorted(my_rows)

    with torch.cuda.device(tc_device), torch.cuda.stream(s1), \
            cublas_stream(cublas_handle, s1._as_parameter_):
        # Pre-allocate b-col, syrk-out, lauum-out
        mem_needed = N * max_block_size + 2 * (max_block_size ** 2)
        if not is_cuda:
            # Also pre-allocate r-col
            mem_needed += N * max_block_size
        f_gpu = torch.empty(size=(mem_needed,), dtype=A.dtype, device=tc_device)
        f_offset = 0
        whole_col_b, f_offset = _extract_flat(f_gpu, (N, max_block_size), other=A,
                                              offset=f_offset)
        syrk_out, f_offset = _extract_flat(f_gpu, (max_block_size, max_block_size), other=A,
                                           offset=f_offset)
        lauum_out, f_offset = _extract_flat(f_gpu, (max_block_size, max_block_size), other=A,
                                            offset=f_offset)
        if not is_cuda:
            temp_bb = create_fortran((max_block_size, max_block_size), A.dtype, 'cpu',
                                     pin_memory=True)
            whole_col_r, f_offset = _extract_flat(f_gpu, (N, max_block_size), other=A,
                                                  offset=f_offset)
        syrk_out.fill_(0.0)

        for b in range(len(block_allocs)):
            bb = block_allocs[b]
            # Load col b.
            # Instead of loading the whole column only load the last rows
            # as necessary by inspecting the minimum value in my_rows which is >= b.
            try:
                min_row = min([r for r in my_rows if r >= b])
                b_start = block_allocs[min_row].start
                if is_cuda:
                    col_b = whole_col_b[b_start:, :bb.length]
                    col_b.copy_(A[b_start:N, bb.start:bb.end])
                else:
                    col_b: torch.Tensor = copy_to_device(
                        N - b_start, bb.length, A, b_start, bb.start, whole_col_b, 0, 0, s1)
            except ValueError:
                pass  # No column here
            if not independent_output:
                # Wait for the copy to device to succeed. After the barrier, other threads
                # may modify the part of col_b which we need!
                s1.synchronize()
                barrier.wait()

            for r in my_rows:
                if r == b:
                    # SYRK on col_b[bb.length:, :] with output into syrk_out[:bb.length, :bb.length]
                    # C = beta*C + alpha * op(A) @ op(A).T
                    if b_start + bb.length < N:
                        cur_syrk_out = syrk_out[:bb.length, :bb.length]
                        syrk_fn(cublas_handle, uplo='L', trans='T',
                                n=bb.length, k=col_b.shape[0] - bb.length,
                                alpha=1.0, A=col_b[bb.length:, :].data_ptr(), lda=col_b.stride(1),
                                beta=0.0, C=cur_syrk_out.data_ptr(), ldc=syrk_out.stride(1))

                    with torch.cuda.stream(s3):
                        if independent_output:
                            s1.synchronize()  # we need col_b to be loaded
                        cur_lauum_out = lauum_out[:bb.length, :bb.length]
                        # Note that col_b[:bb.length, :bb.length] == Abb
                        if independent_output:
                            # In the independent-output case we need to preserve tril(Abb) instead!
                            cur_lauum_out.copy_(col_b[:bb.length, :bb.length].T)
                        else:
                            # In the normal case we need triu(Abb) to be preserved in the
                            # upper triangle of lauum_out
                            cur_lauum_out.copy_(col_b[:bb.length, :bb.length])

                        # LAUUM on col_b[:bb.length, :bb.length], into lauum_out[:bb.length, :bb.length]
                        cuda_lauum(n=bb.length, A=col_b[:bb.length, :bb.length], lda=col_b.stride(1),
                                   B=cur_lauum_out, ldb=max_block_size, lower=True)
                    s1.wait_stream(s3)  # all subsequent work will need cur_lauum_out

                    # Add outputs of SYRK and LAUUM (only if SYRK was performed)
                    if b_start + bb.length < N:
                        cur_lauum_out.add_(cur_syrk_out)

                    # Copy lauum_out into the original matrix, while preserving the other side
                    # of the triangular matrix. This depends on the `independent_output` flag.
                    Abb = A[bb.start:bb.end, bb.start:bb.end]
                    if independent_output:
                        # Covers both cuda and non-cuda cases, since we have different orderings.
                        Abb.copy_(cur_lauum_out.T)
                    elif is_cuda:
                        Abb.copy_(cur_lauum_out)
                    else:
                        copy_to_host(bb.length, bb.length, cur_lauum_out, 0, 0, Abb, 0, 0, s=s1)
                elif r > b:
                    br = block_allocs[r]

                    # Load column r. Since r > b this column will be shorter than column b
                    if is_cuda:
                        # If col_r is already on the GPU no copy is needed.
                        col_r = A[br.start:N, br.start:br.end]
                    else:
                        col_r = copy_to_device(N - br.start, br.length, A, br.start, br.start,
                                               whole_col_r, 0, 0, s1)
                    # Restrict column b to only the last 'r' rows
                    ccb = col_b[br.start - b_start:, :]

                    # TRMM on g_r[0:br.length, :] which is triangular (r*r)
                    #     and cur_g_b[0:br.length, :]
                    #     output is a r*b matrix stored in the first rows of ccb
                    # C = alpha * op(A) @ B -- A triangular
                    trmm_fn(handle=cublas_handle, side='L', uplo='L', trans='T', diag='N',
                            m=br.length, n=bb.length, alpha=1.0,
                            A=col_r.data_ptr(), lda=col_r.stride(1),
                            B=ccb.data_ptr(), ldb=ccb.stride(1),
                            C=ccb.data_ptr(), ldc=ccb.stride(1))

                    # GEMM on g_r[br.length:, :].T and cur_g_b[bb.length:, :]
                    #     output is the same r*b matrix as before, outputs need to be summed.
                    # C = alpha * op(A) @ op(B) + beta * C
                    if br.end < N:
                        gemm_fn(handle=cublas_handle, transa='T', transb='N',
                                m=br.length, n=bb.length, k=col_r.shape[0] - br.length,
                                alpha=1.0, A=col_r[br.length:, :].data_ptr(), lda=col_r.stride(1),
                                B=ccb[br.length:, :].data_ptr(), ldb=ccb.stride(1),
                                beta=1.0, C=ccb.data_ptr(), ldc=ccb.stride(1))

                    # Copy back to A[r, b]
                    if independent_output:
                        if is_cuda:
                            A[bb.start:bb.end, br.start:br.end].copy_(ccb[:br.length, :bb.length].T)
                        else:
                            _temp_cpu = copy_to_host(br.length, bb.length, ccb, 0, 0,
                                                     temp_bb, 0, 0, s1)
                            s1.synchronize()  # must wait for data to be on the CPU
                            A[bb.start:bb.end, br.start:br.end].copy_(_temp_cpu.T)
                    elif is_cuda:
                        A[br.start:br.end, bb.start:bb.end].copy_(ccb[:br.length, :bb.length])
                    else:
                        copy_to_host(br.length, bb.length, ccb, 0, 0, A, br.start, bb.start, s1)
        s1.synchronize()

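# The pre-allocation above relies on a helper that carves fixed-size 2D views out of one
# flat 1D buffer. The real _extract_flat is defined elsewhere in the library; the version
# below is only a guess at its behaviour (column-major views plus an updated offset),
# shown to make the flat-buffer bookkeeping readable:
def _example_extract_flat(flat_tensor, size, other, offset):
    # `other` is only used for a dtype consistency check in this sketch.
    assert flat_tensor.dtype == other.dtype
    numel = size[0] * size[1]
    if flat_tensor.numel() - offset < numel:
        raise RuntimeError("Not enough space in the preallocated flat buffer.")
    chunk = flat_tensor[offset:offset + numel]
    # (rows, cols) view with column-major strides, matching how the preallocated
    # columns are handed to cuBLAS above.
    out = chunk.view(size[1], size[0]).T
    return out, offset + numel
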
def par_lauum_f_lower(A: torch.Tensor,
                      block_allocs: List[BlockAlloc],
                      my_rows: List[int],
                      barrier: threading.Barrier,
                      device_id: int,
                      cublas_handle,
                      independent_output: bool):
    N = A.shape[0]
    is_cuda = A.device.type == "cuda"

    trmm_fn = choose_fn(A.dtype, cublasDtrmm, cublasStrmm, "cuBlas TRMM")
    gemm_fn = choose_fn(A.dtype, cublasDgemm, cublasSgemm, "cuBlas GEMM")
    syrk_fn = choose_fn(A.dtype, cublasDsyrk, cublasSsyrk, "cuBlas SYRK")

    tc_device = torch.device('cuda:%d' % (device_id))
    s1 = torch.cuda.Stream(device=tc_device)
    s3 = torch.cuda.Stream(device=tc_device)

    max_block_size = max(ba.length for ba in block_allocs)
    my_rows = sorted(my_rows)

    with torch.cuda.device(tc_device), torch.cuda.stream(s1), \
            cublas_stream(cublas_handle, s1._as_parameter_):
        # Preallocate 2 columns
        if not is_cuda:
            whole_col_b = create_fortran((A.shape[0], max_block_size), A.dtype, tc_device)
            whole_col_r = create_fortran((A.shape[0], max_block_size), A.dtype, tc_device)
        syrk_out = create_fortran((max_block_size, max_block_size), A.dtype, tc_device)
        lauum_out = create_fortran((max_block_size, max_block_size), A.dtype, tc_device)
        temp_bb = create_fortran((max_block_size, max_block_size), A.dtype, 'cpu',
                                 pin_memory=True)

        for b in range(len(block_allocs)):
            bb = block_allocs[b]
            # Load col b.
            # Instead of loading the whole column only load the last rows
            # as necessary by inspecting the minimum value in my_rows which is >= b.
            try:
                min_row = min([r for r in my_rows if r >= b])
                b_start = block_allocs[min_row].start
                if is_cuda:
                    col_b: torch.Tensor = A[b_start:N, bb.start:bb.end]
                else:
                    col_b: torch.Tensor = copy_to_device(
                        N - b_start, bb.length, A, b_start, bb.start, whole_col_b, 0, 0, s1)
            except ValueError:
                pass  # No column here
            if not independent_output:
                barrier.wait()

            for r in my_rows:
                if r == b:
                    # SYRK on col_b[bb.length:, :] with output into syrk_out[:bb.length, :bb.length]
                    # C = beta*C + alpha * op(A) @ op(A).T
                    if b_start + bb.length < N:
                        cur_syrk_out = syrk_out[:bb.length, :bb.length]
                        syrk_fn(cublas_handle, uplo='L', trans='T',
                                n=bb.length, k=col_b.shape[0] - bb.length,
                                alpha=1.0, A=col_b[bb.length:, :].data_ptr(), lda=col_b.stride(1),
                                beta=0.0, C=cur_syrk_out.data_ptr(), ldc=syrk_out.stride(1))

                    with torch.cuda.stream(s3):
                        cur_lauum_out = lauum_out[:bb.length, :bb.length]
                        # Note that col_b[:bb.length, :bb.length] == Abb
                        if independent_output:
                            # In the independent-output case we need to preserve tril(Abb) instead!
                            cur_lauum_out.copy_(col_b[:bb.length, :bb.length].T)
                        else:
                            # In the normal case we need triu(Abb) to be preserved in the
                            # upper triangle of lauum_out
                            cur_lauum_out.copy_(col_b[:bb.length, :bb.length])

                        # LAUUM on col_b[:bb.length, :bb.length], into lauum_out[:bb.length, :bb.length]
                        cuda_lauum_lower(n=bb.length,
                                         A=col_b[:bb.length, :bb.length], lda=col_b.stride(1),
                                         B=cur_lauum_out, ldb=max_block_size)
                        s3.synchronize()

                    # Add outputs of SYRK and LAUUM (only if SYRK was performed)
                    if b_start + bb.length < N:
                        s1.synchronize()
                        cur_lauum_out.add_(cur_syrk_out)

                    # Copy lauum_out into the original matrix, while preserving the other side
                    # of the triangular matrix. This depends on the `independent_output` flag.
                    Abb = A[bb.start:bb.end, bb.start:bb.end]
                    if independent_output:
                        Abb.copy_(cur_lauum_out.T)
                    else:
                        copy_to_host(bb.length, bb.length, cur_lauum_out, 0, 0, Abb, 0, 0, s=s1)
                elif r > b:
                    br = block_allocs[r]

                    # Load column r. Since r > b this column will be shorter than column b
                    if is_cuda:
                        col_r = A[br.start:N, br.start:br.end]
                    else:
                        col_r = copy_to_device(N - br.start, br.length, A, br.start, br.start,
                                               whole_col_r, 0, 0, s1)
                    # Restrict column b to only the last 'r' rows
                    ccb = col_b[br.start - b_start:, :]

                    # TRMM on g_r[0:br.length, :] which is triangular (r*r)
                    #     and cur_g_b[0:br.length, :]
                    #     output is a r*b matrix and should be stored in a separate g_out block
                    # Could store output in the first rows of g_b
                    # C = alpha * op(A) @ B -- A triangular
                    trmm_fn(handle=cublas_handle, side='L', uplo='L', trans='T', diag='N',
                            m=br.length, n=bb.length, alpha=1.0,
                            A=col_r.data_ptr(), lda=col_r.stride(1),
                            B=ccb.data_ptr(), ldb=ccb.stride(1),
                            C=ccb.data_ptr(), ldc=ccb.stride(1))

                    # GEMM on g_r[br.length:, :].T and cur_g_b[bb.length:, :]
                    #     output is the same r*b matrix as before, outputs need to be summed.
                    # C = alpha * op(A) @ op(B) + beta * C
                    if br.end < N:
                        gemm_fn(handle=cublas_handle, transa='T', transb='N',
                                m=br.length, n=bb.length, k=col_r.shape[0] - br.length,
                                alpha=1.0, A=col_r[br.length:, :].data_ptr(), lda=col_r.stride(1),
                                B=ccb[br.length:, :].data_ptr(), ldb=ccb.stride(1),
                                beta=1.0, C=ccb.data_ptr(), ldc=ccb.stride(1))

                    # Copy back to A[r, b]
                    if independent_output:
                        if is_cuda:
                            A[bb.start:bb.end, br.start:br.end].copy_(ccb[:br.length, :bb.length].T)
                        else:
                            _temp_cpu = copy_to_host(br.length, bb.length, ccb, 0, 0,
                                                     temp_bb, 0, 0, s1)
                            s1.synchronize()
                            A[bb.start:bb.end, br.start:br.end].copy_(_temp_cpu.T)
                    elif not is_cuda:
                        copy_to_host(br.length, bb.length, ccb, 0, 0, A, br.start, bb.start, s1)
        s1.synchronize()