Example #1
from typing import Optional

import torch

# Note: the helpers used below (is_f_contig, check_same_device, cublas_handle, choose_fn,
# cublas_stream, cublasDtrsm, cublasStrsm, create_fortran, create_C, cuda_transpose) are
# falkon's internal CUDA / linear-algebra utilities.
def cuda_trsm(A: torch.Tensor,
              v: torch.Tensor,
              alpha: float,
              lower: int,
              transpose: int,
              stream: Optional[torch.cuda.Stream] = None) -> torch.Tensor:
    if not is_f_contig(A, strict=False):
        raise ValueError("A must be f-contiguous for CUDA TRSM to work.")
    if not check_same_device(A, v):
        raise ValueError("A and v must be on the same CUDA device.")
    if not A.is_cuda:
        raise ValueError("A and v must be CUDA tensors!")

    device = A.device
    s = stream
    if stream is None:
        s = torch.cuda.current_stream(device=device)
    cublas_hdl = cublas_handle(device.index)
    trsm_fn = choose_fn(A.dtype, cublasDtrsm, cublasStrsm, "TRSM")

    # noinspection PyProtectedMember
    with torch.cuda.device(device), torch.cuda.stream(s), cublas_stream(
            cublas_hdl, s._as_parameter_):
        # Deal with copying v, which may not be F-contiguous.
        vF = create_fortran(v.size(), v.dtype, device)
        if is_f_contig(v, strict=False):
            # We can just make a copy of v
            vF.copy_(v)
            # Sync is necessary here for correctness. Not sure why! TODO: Is it still needed?
            s.synchronize()
        else:
            vF = cuda_transpose(input=v, output=vF.T).T

        uplo = 'L' if lower else 'U'
        trans = 'T' if transpose else 'N'
        trsm_fn(cublas_hdl,
                side='L',
                uplo=uplo,
                trans=trans,
                diag='N',
                m=vF.shape[0],
                n=vF.shape[1],
                alpha=alpha,
                A=A.data_ptr(),
                lda=A.stride(1),
                B=vF.data_ptr(),
                ldb=vF.stride(1))
        if is_f_contig(v, strict=False):
            vout = vF
        else:
            vout = create_C(v.size(), v.dtype, device)
            vout = cuda_transpose(input=vF, output=vout.T).T
    return vout
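
A minimal usage sketch for the routine above (an assumption-laden illustration, not part of the library's examples): it assumes a CUDA device is available and that cuda_trsm is in scope, and it solves the lower-triangular system A @ X = alpha * B, checking the result against a dense multiply.

import torch

n, k = 512, 16
A = torch.randn(n, n, dtype=torch.float64, device="cuda:0").tril()
A += n * torch.eye(n, dtype=torch.float64, device="cuda:0")  # keep A well conditioned
A = A.T.contiguous().T  # column-major (F-contiguous) copy, as cuda_trsm requires
B = torch.randn(n, k, dtype=torch.float64, device="cuda:0")

X = cuda_trsm(A, B, alpha=1.0, lower=1, transpose=0)
assert torch.allclose(A @ X, B, atol=1e-6)  # A @ X should reproduce B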
Example #2
    def test_rect(self, rect, order, dtype):
        # Transpose a rectangular matrix on the GPU and check that both the values and
        # the memory layout (strides) of the result match the NumPy reference transpose.
        from falkon.la_helpers.cuda_la_helpers import cuda_transpose
        mat = fix_mat(rect, order=order, dtype=dtype, copy=True, numpy=True)
        exp_mat_out = np.copy(mat.T, order=order)

        mat = move_tensor(torch.from_numpy(mat), "cuda:0")
        mat_out = move_tensor(torch.from_numpy(exp_mat_out), "cuda:0")
        mat_out.fill_(0.0)

        cuda_transpose(input=mat, output=mat_out)

        mat_out = move_tensor(mat_out, "cpu").numpy()
        assert mat_out.strides == exp_mat_out.strides
        np.testing.assert_allclose(exp_mat_out, mat_out)
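
For context, a minimal sketch of the call being tested (assuming a CUDA device and that the import path shown in the test is current): cuda_transpose writes the transpose of its input into a preallocated output tensor of the transposed shape.

import torch
from falkon.la_helpers.cuda_la_helpers import cuda_transpose  # path as used in the test above

src = torch.randn(3, 5, device="cuda:0")
dst = torch.empty(5, 3, device="cuda:0")
cuda_transpose(input=src, output=dst)  # fills dst with src.T
assert torch.allclose(dst, src.T)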
Example #3
def cuda_trsm(A: torch.Tensor, v: torch.Tensor, alpha: float, lower: int,
              transpose: int) -> torch.Tensor:
    if not is_f_contig(A, strict=False):
        raise ValueError("A must be f-contiguous for CUDA TRSM to work.")
    if not check_same_device(A, v):
        raise ValueError("A and v must be on the same CUDA device.")
    if not A.is_cuda:
        raise ValueError("A and v must be CUDA tensors!")

    s = torch.cuda.Stream(device=A.device)
    cublas_hdl = cublas_handle(A.device.index)
    trsm_fn = choose_fn(A.dtype, cublasDtrsm, cublasStrsm, "TRSM")

    with torch.cuda.device(A.device), torch.cuda.stream(s), cublas_stream(
            cublas_hdl, s._as_parameter_):
        # Deal with copying v, which may not be F-contiguous.
        vF = create_fortran(v.size(), v.dtype, v.device)
        if is_f_contig(v, strict=False):
            # We can just make a copy of v
            vF.copy_(v)
        else:
            vF = cuda_transpose(input=v, output=vF.T).T

        uplo = 'L' if lower else 'U'
        trans = 'T' if transpose else 'N'
        trsm_fn(cublas_hdl,
                side='L',
                uplo=uplo,
                trans=trans,
                diag='N',
                m=vF.shape[0],
                n=vF.shape[1],
                alpha=alpha,
                A=A.data_ptr(),
                lda=A.stride(1),
                B=vF.data_ptr(),
                ldb=vF.stride(1))
        if not is_f_contig(v, strict=False):
            vout = create_C(v.size(), v.dtype, v.device)
            vout = cuda_transpose(input=vF, output=vout.T).T
        else:
            vout = vF
        s.synchronize()
    return vout
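
Both versions of cuda_trsm above require A to be F-contiguous (column-major), which a freshly allocated PyTorch tensor is not. A small sketch of how a caller might satisfy that precondition; the double-transpose trick is plain PyTorch, and the helper name as_fortran is only illustrative (falkon ships its own tensor helpers for this).

import torch

def as_fortran(t: torch.Tensor) -> torch.Tensor:
    # For a 2D tensor, t is F-contiguous exactly when its transpose is C-contiguous.
    if t.T.is_contiguous():
        return t
    # Copy through the transpose to obtain column-major storage of the same values.
    return t.T.contiguous().T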