Beispiel #1
0
def _generic_fmm(proc_idx, queue, device_id):
    """Compute a kernel matrix tile by tile on a single GPU.

    Reads one ``ArgsFmm`` object from ``queue`` and writes the kernel between
    the rows of ``a.X1`` and the rows of ``a.X2`` into ``a.out``. The work is
    split into ``n x m`` tiles whose size is chosen from ``a.max_mem``; tiles
    are dispatched round-robin over ``a.num_streams`` CUDA streams so that
    transfers and computation can overlap.

    Parameters
    ----------
    proc_idx
        Index of the calling worker. Not used by this function.
    queue
        Queue holding the ``ArgsFmm`` arguments (a single ``get`` is made).
    device_id
        Index of the CUDA device to run on.

    Returns
    -------
    out : torch.Tensor
        ``a.out``, filled in place.
    """
    # Unpack the function arguments
    a: ArgsFmm = queue.get()
    X1: torch.Tensor = a.X1
    X2: torch.Tensor = a.X2
    cuda_inputs = X1.is_cuda
    out = a.out
    kernel, gpu_dtype = a.kernel, a.gpu_dtype
    max_mem = a.max_mem
    num_streams = a.num_streams

    # flags and local variables
    change_dtype = gpu_dtype != X1.dtype
    X1_equal_X2 = _gpu_tns_same_memory(X1, X2)
    # GPU staging buffers are needed whenever the inputs cannot be used as they are.
    use_gpu_bufs = change_dtype or not cuda_inputs
    stride = "F" if is_f_contig(out, strict=True) else "C"
    j_iter = 0
    dts = sizeof_dtype(gpu_dtype)
    tc_device = torch.device('cuda:%d' % (int(device_id)))
    # max_mem is divided by the element size: the block-size solver works in elements
    # (presumably max_mem is in bytes).
    avail_mem = max_mem / dts

    # Choose block sizes n, m such that we won't run out of GPU memory
    ntot, d = X1.shape
    mtot = X2.shape[0]
    extra_mem = kernel.extra_mem()
    if cuda_inputs and not change_dtype:
        # No allocation will be performed by us. Only in-kernel stuff.
        n, m = select_dim_over_nm(max_n=ntot,
                                  max_m=mtot,
                                  d=d,
                                  coef_nd=extra_mem.get('nd', 0),
                                  coef_md=extra_mem.get('md', 0),
                                  coef_nm=extra_mem.get('nm', 0),
                                  coef_n=extra_mem.get('n', 0),
                                  coef_m=extra_mem.get('m', 0),
                                  rest=extra_mem.get('d', 0),
                                  max_mem=avail_mem)
    else:
        n, m = select_dim_over_nm(
            max_n=ntot,
            max_m=mtot,
            d=d,
            coef_nd=num_streams * (extra_mem.get('nd', 0) + 1),
            coef_md=num_streams * (extra_mem.get('md', 0) + 1),
            coef_nm=num_streams * (extra_mem.get('nm', 0) + 1),
            coef_n=extra_mem.get('n', 0),
            coef_m=extra_mem.get('m', 0),
            rest=extra_mem.get('d', 0),
            max_mem=avail_mem)
    if X1_equal_X2:
        # The symmetric shortcut below fills tile (i, j) with the transpose of tile (j, i).
        # That tile is only guaranteed to be fully computed (and flushed into `out` by the
        # per-row-block synchronization) when row and column tiles coincide, i.e. n == m.
        # Shrinking the larger tile dimension should only lower the memory requirement.
        n = m = min(n, m)

    # Create streams
    streams = [tcd.Stream(device=tc_device) for _ in range(num_streams)]

    # Create buffers
    if use_gpu_bufs:
        gX1 = create_same_stride((n, d), X1, gpu_dtype, tc_device)
        gX2_list = [
            create_same_stride((m, d), X2, gpu_dtype, tc_device)
            for _ in range(num_streams)
        ]
        gout_list = [
            create_same_stride((n, m), out, gpu_dtype, tc_device)
            for _ in range(num_streams)
        ]
    if not cuda_inputs:
        cpu_buf_list = [
            create_same_stride((n, m), out, gpu_dtype, 'cpu', pin_memory=True)
            for _ in range(num_streams)
        ]

    # Define helpers for the copy-back operations (from cpu_buf to output).
    # Copies are deferred until the stream that filled cpu_buf has been synchronized.
    copy_ops = [None] * num_streams

    def wrap_copy_op(stream_idx):
        if copy_ops[stream_idx] is not None:
            copy_ops[stream_idx]()
            copy_ops[stream_idx] = None

    def do_copy_op(output, buf, i_, ic_, j_, jc_):
        # This function will also do the type conversion
        output[i_:i_ + ic_, j_:j_ + jc_].copy_(buf[:ic_, :jc_])

    # Kernel computation begin
    with tcd.device(tc_device):
        for i in range(0, ntot, n):
            ic = min(n, ntot - i)

            # The X1 chunk is loaded on the stream used by the first inner iteration,
            # whose `stream.synchronize()` below therefore also waits for this copy.
            with tcd.stream(streams[j_iter % len(streams)]):
                X1_chunk = X1.narrow(0, i, ic)
                if use_gpu_bufs:
                    cur_gX1 = gX1.narrow(0, 0, ic)
                    cur_gX1.copy_(X1_chunk, non_blocking=True)
                else:
                    cur_gX1 = X1_chunk

            for j in range(0, mtot, m):
                jc = min(m, mtot - j)
                # Choose the buffers for this inner iteration
                stream_id = j_iter % len(streams)
                stream = streams[stream_id]
                if use_gpu_bufs:
                    gX2 = gX2_list[stream_id]
                    gout = gout_list[stream_id]
                if not cuda_inputs:
                    cpu_buf = cpu_buf_list[stream_id]

                # Sync for buffers we must use now (e.g. 2 previous iters)
                with tcd.stream(stream):  # Inner-loop
                    stream.synchronize()
                    wrap_copy_op(stream_id)

                    if X1_equal_X2 and j < i:  # Shortcut for symmetric kernels
                        out[i:i + ic, j:j + jc].copy_(out[j:j + jc,
                                                          i:i + ic].T,
                                                      non_blocking=True)
                        j_iter += 1
                        continue

                    # Copy (CPU->GPU)
                    X2_chunk = X2.narrow(0, j, jc)
                    if use_gpu_bufs:
                        cur_gX2 = gX2.narrow(0, 0, jc)
                        cur_gX2.copy_(X2_chunk, non_blocking=True)
                    else:
                        cur_gX2 = X2_chunk

                    if use_gpu_bufs:
                        cur_gout = gout[:ic, :jc]
                    else:
                        cur_gout = out[i:i + ic, j:j + jc]
                    cur_gout.fill_(0.0)

                    # Compute
                    ddd = kernel._prepare(cur_gX1, cur_gX2)
                    kernel._apply(cur_gX1, cur_gX2.T, cur_gout)
                    cur_gout = kernel._finalize(cur_gout, ddd)

                    # Copy Back (GPU->CPU)
                    if not cuda_inputs:
                        # copy_ does not care about the contiguity of copies, as long as it's consistent
                        # however, in case of C-contiguous inputs it will create an intermediate array
                        # which is undesired. We use cuda_memcpy2d_async which works well with C-contiguous
                        # arrays.
                        if stride == "F":
                            copy_to_host(ic,
                                         jc,
                                         cur_gout,
                                         0,
                                         0,
                                         cpu_buf,
                                         0,
                                         0,
                                         s=stream)
                        else:
                            cuda_memcpy2d_async(dst=cpu_buf.data_ptr(),
                                                dpitch=cpu_buf.stride(0) * dts,
                                                src=cur_gout.data_ptr(),
                                                spitch=cur_gout.stride(0) *
                                                dts,
                                                width=jc * dts,
                                                height=ic,
                                                stream=stream._as_parameter_)
                        copy_ops[stream_id] = partial(do_copy_op, out, cpu_buf,
                                                      i, ic, j, jc)
                    elif change_dtype:
                        out.narrow(0, i,
                                   ic).narrow(1, j,
                                              jc).copy_(cur_gout,
                                                        non_blocking=True)
                j_iter += 1

            # Wait for every stream and flush pending host copies before the X1
            # buffer is overwritten by the next row block.
            for s_idx in range(num_streams):
                streams[s_idx].synchronize()
                wrap_copy_op(s_idx)

    return out
Beispiel #2
0
def par_lauum_f_lower(A: torch.Tensor, block_allocs: List[BlockAlloc],
                      my_rows: List[int], barrier: threading.Barrier,
                      device_id: int, cublas_handle, independent_output: bool):
    """Per-thread worker of a blocked LAUUM on the lower triangle of ``A``.

    ``A`` is split into the blocks described by ``block_allocs`` and this
    thread handles the block-rows listed in ``my_rows``. For every
    block-column ``b``, the needed rows of column ``b`` are staged on the GPU.
    The diagonal block is computed with a CPU LAPACK LAUUM plus a cuBLAS SYRK
    of the rows below it; off-diagonal blocks use cuBLAS TRMM and GEMM.
    Results are written back into ``A`` in place.

    Parameters
    ----------
    A
        Square CPU matrix (it is accessed through ``.numpy()``), modified in place.
    block_allocs
        Block descriptors with ``start``, ``end`` and ``length`` attributes.
    my_rows
        Block-row indices processed by this thread.
    barrier
        Barrier shared with the other worker threads; waited on once per
        block-column when ``independent_output`` is False.
    device_id
        Index of the CUDA device to use.
    cublas_handle
        cuBLAS handle; its stream is set to an internal stream by this function.
    independent_output
        If True, off-diagonal results are written transposed into the upper
        triangle (``A[b, r]``); if False they overwrite the lower triangle
        (``A[r, b]``).
    """
    N = A.shape[0]

    lauum_fn = choose_fn(A.dtype, scll.dlauum, scll.slauum, "Lapack LAUUM")
    trmm_fn = choose_fn(A.dtype, cublasDtrmm, cublasStrmm, "cuBlas TRMM")
    gemm_fn = choose_fn(A.dtype, cublasDgemm, cublasSgemm, "cuBlas GEMM")
    syrk_fn = choose_fn(A.dtype, cublasDsyrk, cublasSsyrk, "cuBlas SYRK")

    tc_device = torch.device('cuda:%d' % (device_id))
    # s1: loads to the device and cuBLAS work. s2: asynchronous copies back into A.
    s1 = torch.cuda.Stream(device=tc_device)
    s2 = torch.cuda.Stream(device=tc_device)
    cublasSetStream(cublas_handle, s1._as_parameter_)

    max_block_size = max(ba.length for ba in block_allocs)
    my_rows = sorted(my_rows)

    with torch.cuda.device(tc_device), torch.cuda.stream(s1):
        # Preallocate 2 columns
        whole_col_b = create_fortran((A.shape[0], max_block_size), A.dtype,
                                     tc_device)
        whole_col_r = create_fortran((A.shape[0], max_block_size), A.dtype,
                                     tc_device)
        temp_bb = create_fortran((max_block_size, max_block_size),
                                 A.dtype,
                                 'cpu',
                                 pin_memory=True)

        for b in range(len(block_allocs)):
            bb = block_allocs[b]
            # Load col b.
            # Instead of loading the whole column only load the last rows
            # as necessary by inspecting the minimum value in my_rows which is >= b.
            rows_from_b = [r for r in my_rows if r >= b]
            if rows_from_b:
                b_start = block_allocs[min(rows_from_b)].start
                col_b = copy_to_device(N - b_start, bb.length, A, b_start,
                                       bb.start, whole_col_b, 0, 0, s1)
            # Otherwise there is no column here: every r in my_rows is < b, so the
            # loop over my_rows below does nothing for this b.
            if not independent_output:
                # The load above is asynchronous on s1. Once the barrier is released,
                # other threads write their results into column b of A, so our copy
                # of column b must be complete before we let them through.
                s1.synchronize()
                barrier.wait()

            for r in my_rows:
                if r < b:
                    continue
                if r == b:
                    # SYRK on g_b[bb.length:, :] with output replacing g_b[:bb.length, :]
                    # C = beta*C + alpha * op(A) @ op(A).T
                    if b_start + bb.length < N:
                        syrk_fn(cublas_handle,
                                uplo='L',
                                trans='T',
                                n=bb.length,
                                k=col_b.shape[0] - bb.length,
                                alpha=1.0,
                                A=col_b[bb.length:, :].data_ptr(),
                                lda=col_b.stride(1),
                                beta=0.0,
                                C=col_b.data_ptr(),
                                ldc=col_b.stride(1))
                    # CPU LAUUM on A[bb.start:bb.end, bb.start:bb.end]. This is a bit messy, should do cleanup.
                    Abb = A[bb.start:bb.end, bb.start:bb.end]  # L\U
                    if independent_output:
                        Abb_np = Abb.numpy().copy(order="F")
                        # Make symmetric: L\L
                        copy_triang(Abb_np, upper=False)
                        uu, info = lauum_fn(Abb_np, lower=1,
                                            overwrite_c=True)  # LAU\L
                        Abb.copy_(torch.from_numpy(uu.T))  # L\LAU
                    else:
                        uu, info = lauum_fn(Abb.numpy(),
                                            lower=1,
                                            overwrite_c=False)  # LAU\L
                        if b_start + bb.length < N:
                            zero_triang(uu, upper=True)
                        Abb.copy_(torch.from_numpy(uu))
                    if b_start + bb.length < N:
                        # It is IMPORTANT to do the copy on s1 and then sync it.
                        tbb = copy_to_host(bb.length, bb.length, col_b, 0, 0,
                                           temp_bb, 0, 0, s1)
                        s1.synchronize()
                        if independent_output:
                            Abb.add_(torch.triu(tbb.T))
                        else:
                            Abb.add_(tbb)
                else:  # r > b
                    br = block_allocs[r]

                    # Load column r. Since r > b this column will be shorter than column b
                    col_r = copy_to_device(N - br.start, br.length, A,
                                           br.start, br.start, whole_col_r, 0,
                                           0, s1)
                    # Restrict column b to only the last 'r' rows
                    ccb = col_b[br.start - b_start:, :]

                    # TRMM on g_r[0:br.length, :] which is triangular (r*r)
                    #         and cur_g_b[0:br.length, :]
                    #         output is a r*b matrix, written in place into the first rows of ccb
                    # C = alpha * op(A) @ B -- A triangular
                    trmm_fn(handle=cublas_handle,
                            side='L',
                            uplo='L',
                            trans='T',
                            diag='N',
                            m=br.length,
                            n=bb.length,
                            alpha=1.0,
                            A=col_r.data_ptr(),
                            lda=col_r.stride(1),
                            B=ccb.data_ptr(),
                            ldb=ccb.stride(1),
                            C=ccb.data_ptr(),
                            ldc=ccb.stride(1))

                    # GEMM on g_r[br.length:, :].T and cur_g_b[bb.length:, :]
                    #         output  is the same r*b matrix as before, outputs need to be summed.
                    # C = alpha * op(A) @ op(B) + beta * C
                    if br.end < N:
                        gemm_fn(handle=cublas_handle,
                                transa='T',
                                transb='N',
                                m=br.length,
                                n=bb.length,
                                k=col_r.shape[0] - br.length,
                                alpha=1.0,
                                A=col_r[br.length:, :].data_ptr(),
                                lda=col_r.stride(1),
                                B=ccb[br.length:, :].data_ptr(),
                                ldb=ccb.stride(1),
                                beta=1.0,
                                C=ccb.data_ptr(),
                                ldc=ccb.stride(1))
                    # Copy back to A[r, b]
                    if independent_output:
                        _temp_cpu = copy_to_host(br.length, bb.length, ccb, 0,
                                                 0, temp_bb, 0, 0, s1)
                        s1.synchronize()
                        A[bb.start:bb.end, br.start:br.end].copy_(_temp_cpu.T)
                    else:
                        s1.synchronize()
                        copy_to_host(br.length, bb.length, ccb, 0, 0, A,
                                     br.start, bb.start, s2)
            # Pending write-backs on s2 must land before the next column is processed.
            s2.synchronize()
Beispiel #3
0
def _ic_cholesky(A, upper, device, cusolver_handle):
    """Cholesky factorization of matrix `A` on the GPU

    Uses the cuSOLVER library for implementation of the POTRF function.

    Parameters:
    -----------
    A : [n, n] CPU or GPU array (column-contiguous)
        The (positive definite) matrix which should be factorized. A GPU array
        may be a view whose column stride (leading dimension) exceeds `n`.
    upper : bool
        Whether we need to factorize the upper or lower portion of `A`. The other side
        of the matrix will not be touched.
    device : int
        The GPU device on which to run the factorization
    cusolver_handle
        Pointer to the cuSOLVER context, which needs to be initialized before calling
        the function.

    Returns:
    --------
    A : [n, n] CPU or GPU array (column-contiguous)
        The factorization of A which overwrites the upper (or lower) triangular part
        of the matrix A. This is not a copy of the original matrix.

    Raises:
    -------
    RuntimeError
        If `cusolver_handle` is None, or if `A` is not F-contiguous.
    """
    # Check library initialization
    if cusolver_handle is None:
        raise RuntimeError("CuSolver must be initialized "
                           "before running in-core Cholesky.")
    if not is_f_contig(A):
        raise RuntimeError("Cholesky input must be F-contiguous")

    uplo = 'U' if upper else 'L'
    n = A.shape[0]

    tc_device = torch.device("cuda:%d" % (device))
    # Choose functions by dtype
    potrf_buf_size = choose_fn(A.dtype, cusolverDnDpotrf_bufferSize,
                               cusolverDnSpotrf_bufferSize,
                               "POTRF Buffer size")
    potrf_fn = choose_fn(A.dtype, cusolverDnDpotrf, cusolverDnSpotrf, "POTRF")

    # noinspection PyUnresolvedReferences
    with torch.cuda.device(tc_device):
        # Copy A to device memory
        if A.is_cuda:
            Agpu = A
        else:
            Agpu = create_fortran(A.shape, A.dtype, tc_device)
            copy_to_device(n, n, A, 0, 0, Agpu, 0, 0)

        # Leading dimension of the device matrix. A GPU input may be a view into a
        # larger column-major matrix (is_f_contig is called non-strict above), in which
        # case its column stride is larger than n. POTRF also requires lda >= max(1, n).
        lda = max(1, n, Agpu.stride(1))

        # Create workspace buffer
        potrf_bsize = potrf_buf_size(handle=cusolver_handle,
                                     uplo=uplo,
                                     n=n,
                                     A=Agpu.data_ptr(),
                                     lda=lda)
        potrf_wspace = create_fortran((potrf_bsize, ), A.dtype, tc_device)
        # NOTE(review): dev_info is passed as a tensor (the other pointer arguments use
        # .data_ptr()) and is never read back, so a failed factorization (e.g. a matrix
        # that is not positive definite) is not reported here. Confirm what the POTRF
        # wrapper expects and whether callers check for failure.
        dev_info = torch.tensor(4, dtype=torch.int32, device=tc_device)

        # Run cholesky
        potrf_fn(handle=cusolver_handle,
                 uplo=uplo,
                 n=n,
                 A=Agpu.data_ptr(),
                 lda=lda,
                 workspace=potrf_wspace.data_ptr(),
                 Lwork=potrf_bsize,
                 devInfo=dev_info)
        torch.cuda.synchronize()

        # Copy back to CPU
        if not A.is_cuda:
            copy_to_host(n, n, Agpu, 0, 0, A, 0, 0)
            del Agpu
        del potrf_wspace, dev_info
    return A
Beispiel #4
0
def par_lauum_f_lower(A: torch.Tensor,
                      block_allocs: List[BlockAlloc],
                      my_rows: List[int],
                      barrier: threading.Barrier,
                      device_id: int,
                      cublas_handle,
                      independent_output: bool):
    """Per-thread worker of a blocked LAUUM on the lower triangle of ``A``.

    ``A`` (on the CPU or on the GPU) is split into the blocks described by
    ``block_allocs``, and this thread processes the block-rows in ``my_rows``.
    For each block-column ``b``, the needed rows of column ``b`` are staged in
    a private GPU buffer. The diagonal block is computed with ``cuda_lauum``
    on a side stream, plus a cuBLAS SYRK of the rows below it. Off-diagonal
    blocks are computed with cuBLAS TRMM and GEMM. Results are written back
    into ``A`` in place.

    Parameters
    ----------
    A
        Square matrix, modified in place. Presumably column-major (the
        function name and the use of ``stride(1)`` as leading dimension
        suggest so) — confirm.
    block_allocs
        Block descriptors with ``start``, ``end`` and ``length`` attributes.
    my_rows
        Block-row indices handled by this thread.
    barrier
        Barrier shared between the worker threads; waited on once per
        block-column when ``independent_output`` is False.
    device_id
        Index of the CUDA device to use.
    cublas_handle
        cuBLAS handle, bound to stream ``s1`` through ``cublas_stream`` for
        the body of the main ``with`` statement.
    independent_output
        If True, off-diagonal results are written transposed into the upper
        triangle (``A[b, r]``) and the lower part of each diagonal block is
        preserved. If False, results go into the lower triangle (``A[r, b]``)
        and the upper part of each diagonal block is preserved.
    """
    N = A.shape[0]
    is_cuda = A.device.type == "cuda"

    trmm_fn = choose_fn(A.dtype, cublasDtrmm, cublasStrmm, "cuBlas TRMM")
    gemm_fn = choose_fn(A.dtype, cublasDgemm, cublasSgemm, "cuBlas GEMM")
    syrk_fn = choose_fn(A.dtype, cublasDsyrk, cublasSsyrk, "cuBlas SYRK")

    # s1: main stream (column loads, cuBLAS calls, write-back to A).
    # s3: side stream for the diagonal-block LAUUM, so it can overlap with the SYRK on s1.
    tc_device = torch.device('cuda:%d' % (device_id))
    s1 = torch.cuda.Stream(device=tc_device)
    s3 = torch.cuda.Stream(device=tc_device)

    max_block_size = max(ba.length for ba in block_allocs)
    my_rows = sorted(my_rows)

    # Inside this block torch work is issued on s1 by default and cuBLAS calls go to s1.
    with torch.cuda.device(tc_device), torch.cuda.stream(s1), cublas_stream(cublas_handle, s1._as_parameter_):
        # Pre allocate b-col, syrk-out, lauum-out
        # (a single flat GPU allocation, carved into the individual buffers below)
        mem_needed = N * max_block_size + 2 * (max_block_size ** 2)
        if not is_cuda:
            # Also pre alloc r-col
            mem_needed += N * max_block_size
        f_gpu = torch.empty(size=(mem_needed,), dtype=A.dtype, device=tc_device)
        f_offset = 0
        # NOTE(review): _extract_flat presumably returns a view of the requested shape that
        # starts at `offset` in f_gpu (laid out like `other`), plus the next free offset — confirm.
        whole_col_b, f_offset = _extract_flat(f_gpu, (N, max_block_size), other=A, offset=f_offset)
        syrk_out, f_offset = _extract_flat(f_gpu, (max_block_size, max_block_size), other=A, offset=f_offset)
        lauum_out, f_offset = _extract_flat(f_gpu, (max_block_size, max_block_size), other=A, offset=f_offset)
        if not is_cuda:
            # Pinned host buffer, used for the transposed write-back of off-diagonal blocks.
            temp_bb = create_fortran((max_block_size, max_block_size), A.dtype, 'cpu', pin_memory=True)
            whole_col_r, f_offset = _extract_flat(f_gpu, (N, max_block_size), other=A, offset=f_offset)
        syrk_out.fill_(0.0)

        for b in range(len(block_allocs)):
            bb = block_allocs[b]
            # Load col b.
            # Instead of loading the whole column only load the last rows
            # as necessary by inspecting the minimum value in my_rows which is >= b.
            try:
                # min() of an empty list raises ValueError: none of my_rows is >= b.
                min_row = min([r for r in my_rows if r >= b])
                b_start = block_allocs[min_row].start
                if is_cuda:
                    # Copy into the private buffer even for GPU inputs: TRMM/GEMM below
                    # write their output in place into (views of) col_b.
                    col_b = whole_col_b[b_start:, :bb.length]
                    col_b.copy_(A[b_start:N, bb.start:bb.end])
                else:
                    col_b: torch.Tensor = copy_to_device(
                        N - b_start, bb.length, A, b_start, bb.start, whole_col_b, 0, 0, s1)
            except ValueError:
                pass  # No column here
            if not independent_output:
                # wait for copy to device to succeed. After barrier other threads may modify
                # the part of col_b which we need!
                s1.synchronize()
                barrier.wait()

            # Rows r < b fall through both branches below and are skipped.
            for r in my_rows:
                if r == b:
                    # SYRK on col_b[bb.length:, :] with output into syrk_out[:bb.length, :bb.length]
                    # C = beta*C + alpha * op(A) @ op(A).T
                    if b_start + bb.length < N:
                        cur_syrk_out = syrk_out[:bb.length, :bb.length]
                        syrk_fn(cublas_handle, uplo='L', trans='T',
                                n=bb.length, k=col_b.shape[0] - bb.length,
                                alpha=1.0, A=col_b[bb.length:, :].data_ptr(), lda=col_b.stride(1),
                                beta=0.0, C=cur_syrk_out.data_ptr(), ldc=syrk_out.stride(1))

                    # The LAUUM runs on s3 so that it can overlap with the SYRK issued on s1.
                    with torch.cuda.stream(s3):
                        if independent_output:
                            s1.synchronize()  # we need col_b to be loaded
                        cur_lauum_out = lauum_out[:bb.length, :bb.length]
                        # Note that col_b[:bb.length, :bb.length] == Abb
                        if independent_output:
                            # In the independent output case we need to preserve tril(Abb) instead!
                            cur_lauum_out.copy_(col_b[:bb.length, :bb.length].T)
                        else:
                            # In normal case we need triu(Abb) to be preserved in the upper triangle of lauum_out
                            cur_lauum_out.copy_(col_b[:bb.length, :bb.length])

                        # LAUUM on col_b[:bb.length, :bb.length], into lauum_out[:bb.length, :bb.length]
                        # NOTE(review): ldb=max_block_size assumes lauum_out's column stride equals
                        # max_block_size (true for a column-major layout) — confirm.
                        cuda_lauum(n=bb.length, A=col_b[:bb.length, :bb.length], lda=col_b.stride(1),
                                   B=cur_lauum_out, ldb=max_block_size, lower=True)
                    s1.wait_stream(s3)  # all subsequent work will need cur_lauum_out

                    # Add outputs of SYRK and LAUUM (only if SYRK was performed)
                    if b_start + bb.length < N:
                        cur_lauum_out.add_(cur_syrk_out)

                    # Copy lauum_out into the original matrix, while preserving the other side
                    # of the triangular matrix. This depends on the `independent_output` flag.
                    Abb = A[bb.start:bb.end, bb.start:bb.end]
                    if independent_output:
                        # cuda and non-cuda cases, since we have different orderings.
                        Abb.copy_(cur_lauum_out.T)
                    elif is_cuda:
                        Abb.copy_(cur_lauum_out)
                    else:
                        copy_to_host(bb.length, bb.length, cur_lauum_out, 0, 0, Abb, 0, 0, s=s1)
                elif r > b:
                    br = block_allocs[r]

                    # Load column r. Since r > b this column will be shorter than column b
                    if is_cuda:  # If col_r is already in GPU no copy needed.
                        col_r = A[br.start:N, br.start:br.end]
                    else:
                        col_r = copy_to_device(N - br.start, br.length, A, br.start, br.start,
                                               whole_col_r, 0, 0, s1)
                    # Restrict column b to only the last 'r' rows
                    ccb = col_b[br.start - b_start:, :]

                    # TRMM on g_r[0:br.length, :] which is triangular (r*r)
                    #         and cur_g_b[0:br.length, :]
                    #         output is a r*b matrix stored in the first rows of ccb
                    # C = alpha * op(A) @ B -- A triangular
                    trmm_fn(
                        handle=cublas_handle,
                        side='L', uplo='L', trans='T', diag='N',
                        m=br.length, n=bb.length,
                        alpha=1.0, A=col_r.data_ptr(), lda=col_r.stride(1),
                        B=ccb.data_ptr(), ldb=ccb.stride(1),
                        C=ccb.data_ptr(), ldc=ccb.stride(1))

                    # GEMM on g_r[br.length:, :].T and cur_g_b[bb.length:, :]
                    #         output  is the same r*b matrix as before, outputs need to be summed.
                    # C = alpha * op(A) @ op(B) + beta * C
                    if br.end < N:
                        gemm_fn(handle=cublas_handle,
                                transa='T', transb='N',
                                m=br.length, n=bb.length, k=col_r.shape[0] - br.length,
                                alpha=1.0, A=col_r[br.length:, :].data_ptr(), lda=col_r.stride(1),
                                B=ccb[br.length:, :].data_ptr(), ldb=ccb.stride(1),
                                beta=1.0, C=ccb.data_ptr(), ldc=ccb.stride(1))
                    # Copy back to A[r, b] (or transposed into A[b, r] for independent output)
                    if independent_output:
                        if is_cuda:
                            A[bb.start:bb.end, br.start:br.end].copy_(ccb[:br.length, :bb.length].T)
                        else:
                            _temp_cpu = copy_to_host(br.length, bb.length, ccb, 0, 0, temp_bb, 0, 0, s1)
                            s1.synchronize()  # must wait for data to be onto CPU.
                            A[bb.start:bb.end, br.start:br.end].copy_(_temp_cpu.T)
                    elif is_cuda:
                        A[br.start:br.end, bb.start:bb.end].copy_(ccb[:br.length, :bb.length])
                    else:
                        copy_to_host(br.length, bb.length, ccb, 0, 0, A, br.start, bb.start, s1)
            # Drain s1 (which already waited on s3) before the buffers are reused for the next column.
            s1.synchronize()
Beispiel #5
0
def par_lauum_f_lower(A: torch.Tensor, block_allocs: List[BlockAlloc],
                      my_rows: List[int], barrier: threading.Barrier,
                      device_id: int, cublas_handle, independent_output: bool):
    """Per-thread worker of a blocked LAUUM on the lower triangle of ``A``.

    ``A`` (on the CPU or on the GPU) is split into the blocks described by
    ``block_allocs``, and this thread processes the block-rows in ``my_rows``.
    For each block-column ``b``, the needed rows of column ``b`` are always
    staged in a private GPU buffer. TRMM/GEMM write their output in place into
    that buffer, so ``A`` itself is only touched by the explicit write-backs.
    The diagonal block is computed with ``cuda_lauum_lower`` on a side stream,
    plus a cuBLAS SYRK of the rows below it. Off-diagonal blocks are computed
    with cuBLAS TRMM and GEMM. Results are written back into ``A`` in place.

    Parameters
    ----------
    A
        Square matrix, modified in place. Presumably column-major (the
        function name and the use of ``stride(1)`` as leading dimension
        suggest so) — confirm.
    block_allocs
        Block descriptors with ``start``, ``end`` and ``length`` attributes.
    my_rows
        Block-row indices handled by this thread.
    barrier
        Barrier shared between the worker threads; waited on once per
        block-column when ``independent_output`` is False.
    device_id
        Index of the CUDA device to use.
    cublas_handle
        cuBLAS handle, bound to stream ``s1`` through ``cublas_stream`` for
        the body of the main ``with`` statement.
    independent_output
        If True, off-diagonal results are written transposed into the upper
        triangle (``A[b, r]``) and the lower part of each diagonal block is
        preserved. If False, results go into the lower triangle (``A[r, b]``)
        and the upper part of each diagonal block is preserved.
    """
    N = A.shape[0]
    is_cuda = A.device.type == "cuda"

    trmm_fn = choose_fn(A.dtype, cublasDtrmm, cublasStrmm, "cuBlas TRMM")
    gemm_fn = choose_fn(A.dtype, cublasDgemm, cublasSgemm, "cuBlas GEMM")
    syrk_fn = choose_fn(A.dtype, cublasDsyrk, cublasSsyrk, "cuBlas SYRK")

    # s1: main stream (column loads, cuBLAS calls, write-back to A).
    # s3: side stream for the diagonal-block LAUUM, so it can overlap with the SYRK on s1.
    tc_device = torch.device('cuda:%d' % (device_id))
    s1 = torch.cuda.Stream(device=tc_device)
    s3 = torch.cuda.Stream(device=tc_device)

    max_block_size = max(ba.length for ba in block_allocs)
    my_rows = sorted(my_rows)

    with torch.cuda.device(tc_device), torch.cuda.stream(s1), cublas_stream(
            cublas_handle, s1._as_parameter_):
        # Column b is always staged in its own GPU buffer, also when A is already on
        # the GPU: TRMM and GEMM below overwrite parts of column b in place, and doing
        # that directly on A would destroy input that this and other threads still read
        # (and, with independent_output, the lower triangle that must be preserved).
        whole_col_b = create_fortran((A.shape[0], max_block_size), A.dtype,
                                     tc_device)
        if not is_cuda:
            whole_col_r = create_fortran((A.shape[0], max_block_size), A.dtype,
                                         tc_device)
            # Pinned host buffer for the transposed write-back of off-diagonal blocks.
            temp_bb = create_fortran((max_block_size, max_block_size),
                                     A.dtype,
                                     'cpu',
                                     pin_memory=True)
        syrk_out = create_fortran((max_block_size, max_block_size), A.dtype,
                                  tc_device)
        lauum_out = create_fortran((max_block_size, max_block_size), A.dtype,
                                   tc_device)

        for b in range(len(block_allocs)):
            bb = block_allocs[b]
            # Load col b.
            # Instead of loading the whole column only load the last rows
            # as necessary by inspecting the minimum value in my_rows which is >= b.
            rows_from_b = [r for r in my_rows if r >= b]
            if rows_from_b:
                b_start = block_allocs[min(rows_from_b)].start
                if is_cuda:
                    col_b = whole_col_b[:N - b_start, :bb.length]
                    col_b.copy_(A[b_start:N, bb.start:bb.end])
                else:
                    col_b = copy_to_device(N - b_start, bb.length, A, b_start,
                                           bb.start, whole_col_b, 0, 0, s1)
                # The LAUUM on s3 reads col_b: order it after the load just issued on s1
                # (but not after the SYRK issued later on s1, so the two can overlap).
                s3.wait_stream(s1)
            # Otherwise every r in my_rows is < b and the loop below does nothing.
            if not independent_output:
                # The load above is asynchronous on s1. After the barrier other threads
                # write their results into column b of A, so our copy must be complete.
                s1.synchronize()
                barrier.wait()

            for r in my_rows:
                if r == b:
                    # SYRK on col_b[bb.length:, :] with output into syrk_out[:bb.length, :bb.length]
                    # C = beta*C + alpha * op(A) @ op(A).T
                    if b_start + bb.length < N:
                        cur_syrk_out = syrk_out[:bb.length, :bb.length]
                        syrk_fn(cublas_handle,
                                uplo='L',
                                trans='T',
                                n=bb.length,
                                k=col_b.shape[0] - bb.length,
                                alpha=1.0,
                                A=col_b[bb.length:, :].data_ptr(),
                                lda=col_b.stride(1),
                                beta=0.0,
                                C=cur_syrk_out.data_ptr(),
                                ldc=syrk_out.stride(1))

                    with torch.cuda.stream(s3):
                        cur_lauum_out = lauum_out[:bb.length, :bb.length]
                        # Note that col_b[:bb.length, :bb.length] == Abb
                        if independent_output:
                            # In the independent output case we need to preserve tril(Abb) instead!
                            cur_lauum_out.copy_(
                                col_b[:bb.length, :bb.length].T)
                        else:
                            # In normal case we need triu(Abb) to be preserved in the upper triangle of lauum_out
                            cur_lauum_out.copy_(col_b[:bb.length, :bb.length])

                        # LAUUM on col_b[:bb.length, :bb.length], into lauum_out[:bb.length, :bb.length]
                        cuda_lauum_lower(n=bb.length,
                                         A=col_b[:bb.length, :bb.length],
                                         lda=col_b.stride(1),
                                         B=cur_lauum_out,
                                         ldb=max_block_size)
                    s3.synchronize()

                    # Add outputs of SYRK and LAUUM (only if SYRK was performed)
                    if b_start + bb.length < N:
                        s1.synchronize()
                        cur_lauum_out.add_(cur_syrk_out)

                    # Copy lauum_out into the original matrix, while preserving the other side
                    # of the triangular matrix. This depends on the `independent_output` flag.
                    Abb = A[bb.start:bb.end, bb.start:bb.end]
                    if independent_output:
                        Abb.copy_(cur_lauum_out.T)
                    elif is_cuda:
                        # Device-to-device copy: copy_to_host is only meant for CPU targets.
                        Abb.copy_(cur_lauum_out)
                    else:
                        copy_to_host(bb.length,
                                     bb.length,
                                     cur_lauum_out,
                                     0,
                                     0,
                                     Abb,
                                     0,
                                     0,
                                     s=s1)
                elif r > b:
                    br = block_allocs[r]

                    # Load column r. Since r > b this column will be shorter than column b.
                    # Column r is only read here, so GPU inputs can be used directly.
                    if is_cuda:
                        col_r = A[br.start:N, br.start:br.end]
                    else:
                        col_r = copy_to_device(N - br.start, br.length, A,
                                               br.start, br.start, whole_col_r,
                                               0, 0, s1)
                    # Restrict column b to only the last 'r' rows
                    ccb = col_b[br.start - b_start:, :]

                    # TRMM on g_r[0:br.length, :] which is triangular (r*r)
                    #         and cur_g_b[0:br.length, :]
                    #         output is a r*b matrix stored in the first rows of ccb
                    # C = alpha * op(A) @ B -- A triangular
                    trmm_fn(handle=cublas_handle,
                            side='L',
                            uplo='L',
                            trans='T',
                            diag='N',
                            m=br.length,
                            n=bb.length,
                            alpha=1.0,
                            A=col_r.data_ptr(),
                            lda=col_r.stride(1),
                            B=ccb.data_ptr(),
                            ldb=ccb.stride(1),
                            C=ccb.data_ptr(),
                            ldc=ccb.stride(1))

                    # GEMM on g_r[br.length:, :].T and cur_g_b[bb.length:, :]
                    #         output  is the same r*b matrix as before, outputs need to be summed.
                    # C = alpha * op(A) @ op(B) + beta * C
                    if br.end < N:
                        gemm_fn(handle=cublas_handle,
                                transa='T',
                                transb='N',
                                m=br.length,
                                n=bb.length,
                                k=col_r.shape[0] - br.length,
                                alpha=1.0,
                                A=col_r[br.length:, :].data_ptr(),
                                lda=col_r.stride(1),
                                B=ccb[br.length:, :].data_ptr(),
                                ldb=ccb.stride(1),
                                beta=1.0,
                                C=ccb.data_ptr(),
                                ldc=ccb.stride(1))
                    # Copy back to A[r, b] (or transposed into A[b, r] for independent output)
                    if independent_output:
                        if is_cuda:
                            A[bb.start:bb.end, br.start:br.end].copy_(
                                ccb[:br.length, :bb.length].T)
                        else:
                            _temp_cpu = copy_to_host(br.length, bb.length, ccb,
                                                     0, 0, temp_bb, 0, 0, s1)
                            s1.synchronize()  # the data must be on the CPU before reading it
                            A[bb.start:bb.end,
                              br.start:br.end].copy_(_temp_cpu.T)
                    elif is_cuda:
                        A[br.start:br.end, bb.start:bb.end].copy_(
                            ccb[:br.length, :bb.length])
                    else:
                        copy_to_host(br.length, bb.length, ccb, 0, 0, A,
                                     br.start, bb.start, s1)
            # Drain s1 before the column buffers are reused for the next b.
            s1.synchronize()