def pytorch_orgqr_with_roundtrip(v, tau, checks=True):
    """Form the explicit Q factor from Householder reflectors via ``torch.orgqr``.

    ``v`` and ``tau`` are the outputs of a Householder QR such as
    ``torch.geqrf``. For tensors that are not on the CPU, the computation is
    done on the CPU and the result is moved back to the original device.

    Args:
        v: 2-D tensor of Householder reflectors with ``v.shape[0] >= v.shape[1]``.
        tau: 1-D tensor of reflector scaling factors, one per column of ``v``.
        checks: when True, validate the inputs with ``assert`` statements.

    Returns:
        The tensor produced by ``torch.orgqr(v, tau)``, on the same device as ``v``.

    Raises:
        AssertionError: if ``checks`` is True and the inputs are inconsistent.
    """
    if checks:
        assert torch.is_tensor(v) and torch.is_tensor(tau)
        assert v.dim() == 2 and tau.dim() == 1
        assert v.shape[0] >= v.shape[1]
        assert v.shape[1] == tau.numel()
        assert v.device == tau.device
    # Compare the device *kind*: a torch.device compared against the plain
    # string 'cpu' is not a reliable test, which made this fast path dead code.
    if v.device.type == 'cpu':
        return torch.orgqr(v, tau)
    # Non-CPU input: compute on the CPU, then move the result back.
    device = v.device
    out = torch.orgqr(v.cpu(), tau.cpu())
    return out.to(device)
def basis(A):
    """Return a matrix whose columns are an orthonormal basis of A's columns.

    On the CPU the basis is formed from LAPACK-style Householder reflectors
    (``torch.geqrf`` followed by ``torch.orgqr``). On CUDA it is taken from
    the Q factor of ``torch.linalg.qr``.
    """
    if not A.is_cuda:
        reflectors, tau = torch.geqrf(A)
        return torch.orgqr(reflectors, tau)
    # The original note says torch.orgqr is not available in CUDA (this may
    # depend on the PyTorch version), so use the QR factorisation there.
    q_factor, _ = torch.linalg.qr(A)
    return q_factor
def basis(A):
    """Return orthogonal basis of A columns.

    On the CPU the basis is built from Householder reflectors
    (``torch.geqrf`` + ``torch.orgqr``). On CUDA it is the reduced Q factor
    from ``torch.linalg.qr``.

    Args:
        A: 2-D tensor whose column span is wanted.

    Returns:
        Tensor Q whose columns are orthonormal and span the columns of ``A``.
    """
    # NOTE(review): a `basis` with the same behavior is defined earlier in
    # this file; being later, this definition is the one that takes effect.
    if A.is_cuda:
        # The original note says torch.orgqr is not available in CUDA (this
        # may depend on the PyTorch version). torch.linalg.qr replaces the
        # deprecated torch.qr; its default mode='reduced' matches some=True.
        Q, _ = torch.linalg.qr(A)
    else:
        Q = torch.orgqr(*torch.geqrf(A))
    return Q
# Report reconstruction error and orthogonality of the (Q, R) factors computed
# just before this section.
print('Error:', torch.linalg.norm(M - Q @ R).item())
print('Orthogonality:', torch.linalg.norm(torch.eye(n,n, device = device) - Q @ Q.t()).item())

# CUDA kernels run asynchronously: without synchronising around each timed
# region, the measured time would mostly be kernel-launch overhead rather than
# the factorisation itself. torch.device(...) accepts both a str and a
# torch.device, so the device kind is tested via `.type`.
_bench_device = torch.device(device)
_on_cuda = _bench_device.type == 'cuda'

print('--- torch.qr ---')
if _on_cuda:
    torch.cuda.synchronize(_bench_device)
begin = time.perf_counter()
Q, R = torch.qr(M)
if _on_cuda:
    torch.cuda.synchronize(_bench_device)
end = time.perf_counter()
print('Time:', end - begin)
print('Error:', torch.linalg.norm(M - Q @ R).item())
print('Orthogonality:', torch.linalg.norm(torch.eye(n,n, device = device) - Q @ Q.t()).item())

if _bench_device.type == 'cpu':
    print('--- LAPACK geqrf + orgqr ---')
    begin = time.perf_counter()
    a, tau = torch.geqrf(M)
    Q = torch.orgqr(a, tau)
    end = time.perf_counter()
    print('Time:', end - begin)

#print('--- NUMPY linalg.qr ---')
#M = M.numpy()
#begin = time.time()
#Q, R = np.linalg.qr(M)
#end = time.time()
#print('Time:', end - begin)