def _use_qr(u, params): """Uses QR decomposition.""" a, b, c = params m, n = u.shape y = jnp.concatenate([jnp.sqrt(c) * u, jnp.eye(n)]) q, _ = lax_linalg.qr(y, full_matrices=False) q1 = q[:m, :] q2 = (q[m:, :]).T.conj() e = b / c u = (e * u + (a - e) / jnp.sqrt(c) * jnp.einsum('ij,jk->ik', q1, q2)) return u
def qr(a, mode="reduced"):
    """Computes the QR decomposition of ``a``.

    Args:
      a: array-like input; converted with ``jnp.asarray`` and promoted with
        ``_promote_arg_dtypes`` before factorization.
      mode: ``"reduced"`` (default), ``"r"`` or ``"full"`` request the reduced
        factorization; ``"complete"`` requests the full one.

    Returns:
      ``r`` alone when ``mode == "r"``; otherwise the tuple ``(q, r)``.

    Raises:
      ValueError: if ``mode`` is not one of the supported values.
    """
    if mode in ("reduced", "r", "full"):
        full_matrices = False
    elif mode == "complete":
        full_matrices = True
    else:
        raise ValueError(f"Unsupported QR decomposition mode '{mode}'")
    a = _promote_arg_dtypes(jnp.asarray(a))
    # Pass full_matrices by keyword, like the other lax_linalg.qr call sites;
    # recent JAX declares it keyword-only, and the keyword form also works on
    # older signatures.
    q, r = lax_linalg.qr(a, full_matrices=full_matrices)
    if mode == "r":
        return r
    return q, r
def _qr(a, mode, pivoting):
    """QR decomposition following SciPy-style ``mode`` names.

    Args:
      a: array-like input; converted with ``jnp.asarray`` and promoted with
        ``_promote_dtypes_inexact`` before factorization.
      mode: ``"full"`` or ``"r"`` request the full factorization;
        ``"economic"`` requests the reduced one.
      pivoting: must be falsy; pivoted QR is not supported.

    Returns:
      The 1-tuple ``(r,)`` when ``mode == "r"``; otherwise ``(q, r)``.

    Raises:
      NotImplementedError: if ``pivoting`` is truthy.
      ValueError: if ``mode`` is not one of the supported values.
    """
    if pivoting:
        raise NotImplementedError(
            "The pivoting=True case of qr is not implemented.")

    if mode == "economic":
        full_matrices = False
    elif mode in ("full", "r"):
        full_matrices = True
    else:
        raise ValueError(f"Unsupported QR decomposition mode '{mode}'")

    (promoted,) = _promote_dtypes_inexact(jnp.asarray(a))
    q, r = lax_linalg.qr(promoted, full_matrices=full_matrices)
    return (r,) if mode == "r" else (q, r)
def _qr(a, mode, pivoting):
    """QR decomposition following SciPy-style ``mode`` names.

    Args:
      a: array-like input; converted with ``jnp.asarray`` and promoted with
        ``np_linalg._promote_arg_dtypes`` before factorization.
      mode: ``"full"`` or ``"r"`` request the full factorization;
        ``"economic"`` requests the reduced one.
      pivoting: must be falsy; pivoted QR is not supported.

    Returns:
      ``r`` alone when ``mode == "r"``; otherwise the tuple ``(q, r)``.

    Raises:
      NotImplementedError: if ``pivoting`` is truthy.
      ValueError: if ``mode`` is not one of the supported values.
    """
    if pivoting:
        raise NotImplementedError(
            "The pivoting=True case of qr is not implemented.")
    if mode in ("full", "r"):
        full_matrices = True
    elif mode == "economic":
        full_matrices = False
    else:
        raise ValueError(f"Unsupported QR decomposition mode '{mode}'")
    a = np_linalg._promote_arg_dtypes(jnp.asarray(a))
    # Pass full_matrices by keyword; recent JAX declares it keyword-only.
    q, r = lax_linalg.qr(a, full_matrices=full_matrices)
    if mode == "r":
        # NOTE(review): scipy.linalg.qr returns the 1-tuple (R,) for
        # mode="r". This returns a bare r, which callers may rely on; confirm
        # before switching to (r,).
        return r
    return q, r
def qr(a, mode="reduced"):
    """Computes the QR decomposition of ``a`` using NumPy-style mode names.

    Args:
      a: array-like input; converted with ``jnp.asarray`` and promoted with
        ``_promote_dtypes_inexact``.
      mode: ``"raw"`` returns the raw ``lax_linalg.geqrf`` outputs;
        ``"reduced"`` (default), ``"r"`` and ``"full"`` request the reduced
        factorization; ``"complete"`` requests the full one.

    Returns:
      For ``"raw"``: ``(_T(h), taus)`` from ``geqrf``. For ``"r"``: ``r``
      alone. Otherwise the tuple ``(q, r)``.

    Raises:
      ValueError: if ``mode`` is not one of the supported values.
    """
    (arr,) = _promote_dtypes_inexact(jnp.asarray(a))

    if mode == "raw":
        # Hand back geqrf's outputs directly; the first one is passed through
        # _T (presumably a last-two-axes transpose — confirm in this module).
        packed, taus = lax_linalg.geqrf(arr)
        return _T(packed), taus

    if mode == "complete":
        full_matrices = True
    elif mode in ("reduced", "r", "full"):
        full_matrices = False
    else:
        raise ValueError(f"Unsupported QR decomposition mode '{mode}'")

    q, r = lax_linalg.qr(arr, full_matrices=full_matrices)
    return r if mode == "r" else (q, r)
def _use_qr(u, m, n, params):
    """One QDWH iteration step computed via a QR decomposition.

    Operates on a statically padded matrix whose logical (dynamic) size may
    be smaller than its storage size.

    Args:
      u: a matrix with static (padded) shape ``M x N``.
      m, n: the dynamic shape of the matrix, where ``m <= M`` and ``n <= N``.
      params: the QDWH parameters ``(a, b, c)``.

    Returns:
      The updated matrix, with the same static shape as ``u``.
    """
    a, b, c = params
    rows, cols = u.shape
    sqrt_c = jnp.sqrt(c)

    # Stack sqrt(c) * u on top of an identity, joined at dynamic row m.
    stacked = _dynamic_concat(sqrt_c * u, jnp.eye(cols, dtype=u.dtype), m)
    q, _ = lax_linalg.qr(stacked, full_matrices=False)

    # Unpadded equivalent: q1 = q[:m, :]
    top = _mask(lax.slice(q, (0, 0), (rows, cols)), (m, n))
    # Unpadded equivalent: q2 = q[m:, :].T.conj()
    bottom = lax.dynamic_slice_in_dim(q, m, cols, axis=0)
    bottom_h = _mask(bottom, (n, n)).T.conj()

    e = b / c
    product = jnp.einsum('ij,jk->ik', top, bottom_h)
    return e * u + (a - e) / sqrt_c * product