def pytorch_orgqr_with_roundtrip(v, tau, checks=True):
    """Form the explicit Q factor from Householder reflectors via ``torch.orgqr``.

    ``torch.orgqr`` is always executed on the CPU. Inputs living on any other
    device are copied to the CPU, processed there, and the result is moved
    back to the original device.

    Args:
        v: 2-D tensor of shape (m, n) with m >= n holding the Householder
            vectors (e.g. the first output of ``torch.geqrf``).
        tau: 1-D tensor with n scalar factors (e.g. the second output of
            ``torch.geqrf``), on the same device as ``v``.
        checks: if True, validate the arguments with ``assert`` statements.
            Note that asserts are stripped when Python runs with ``-O``.

    Returns:
        The (m, n) tensor Q, on the same device as ``v``.

    Raises:
        AssertionError: if ``checks`` is True and an argument is invalid.
    """
    if checks:
        assert torch.is_tensor(v) and torch.is_tensor(tau)
        assert v.dim() == 2 and tau.dim() == 1
        assert v.shape[0] >= v.shape[1]
        assert v.shape[1] == tau.numel()
        assert v.device == tau.device
    # Compare the device *type*, not the device object against a string:
    # depending on the PyTorch version a torch.device does not compare equal
    # to 'cpu', which made this fast path unreachable.
    if v.device.type == 'cpu':
        return torch.orgqr(v, tau)
    # Round-trip through the CPU and return the result on the caller's device.
    device = v.device
    out = torch.orgqr(v.cpu(), tau.cpu())
    return out.to(device)
Esempio n. 2
0
def basis(A):
    """Return orthogonal basis of A columns.

    Computes the reduced QR factorization A = Q R and returns Q. Q has
    orthonormal columns that span the column space of A when A has full
    column rank. For an input of shape (*, m, n), Q has shape
    (*, m, min(m, n)).
    """
    # torch.linalg.qr works on CPU and CUDA alike, so a single code path keeps
    # both devices consistent. The former CPU-only path,
    # torch.orgqr(*torch.geqrf(A)), raised for wide inputs (m < n) because
    # orgqr/householder_product requires m >= n.
    return torch.linalg.qr(A, mode='reduced').Q
Esempio n. 3
0
def basis(A):
    """Return orthogonal basis of A columns.

    Computes the reduced QR factorization A = Q R and returns Q, whose
    orthonormal columns span the column space of A when A has full column
    rank. For an input of shape (*, m, n), Q has shape (*, m, min(m, n)).
    """
    # torch.linalg.qr(mode='reduced') is the non-deprecated replacement for
    # torch.qr(A, some=True). It runs on every device, so one code path serves
    # CPU and CUDA. The former CPU path, torch.orgqr(*torch.geqrf(A)), raised
    # for wide inputs (m < n).
    Q, _ = torch.linalg.qr(A, mode='reduced')
    return Q
Esempio n. 4
0
print('Error:', torch.linalg.norm(M - Q @ R).item())
print('Orthogonality:', torch.linalg.norm(torch.eye(n,n, device = device) - Q @ Q.t()).item())


# --- torch.qr benchmark ---
# CUDA work is launched asynchronously, so synchronize on both sides of the
# timed region; otherwise the interval may cover only the kernel launch
# rather than the factorization. perf_counter is the monotonic,
# high-resolution clock intended for measuring intervals.
print('--- torch.qr ---')
if M.is_cuda:
    torch.cuda.synchronize()
begin = time.perf_counter()
Q, R = torch.qr(M)
if M.is_cuda:
    torch.cuda.synchronize()
end = time.perf_counter()

print('Time:', end - begin)
print('Error:', torch.linalg.norm(M - Q @ R).item())
# Q @ Q.t() equals the identity only when Q is square; presumably M is n x n
# here (see torch.eye(n, n)) — TODO confirm.
print('Orthogonality:', torch.linalg.norm(torch.eye(n,n, device = device) - Q @ Q.t()).item())



# The raw LAPACK path is only benchmarked for CPU tensors. Test the tensor's
# device type instead of `device == 'cpu'`: comparing a torch.device with a
# plain string is not reliable across PyTorch versions and could silently
# skip this section.
if M.device.type == 'cpu':

    print('--- LAPACK geqrf + orgqr ---')
    begin = time.perf_counter()
    a, tau = torch.geqrf(M)
    Q = torch.orgqr(a, tau)
    end = time.perf_counter()
    print('Time:', end - begin)

    # NumPy comparison, currently disabled (note: it would rebind M to an ndarray).
    #print('--- NUMPY linalg.qr ---')
    #M = M.numpy()
    #begin = time.time()
    #Q, R = np.linalg.qr(M)
    #end = time.time()
    #print('Time:', end - begin)