def ab_decomposition(
    u: jnp.ndarray, v: jnp.ndarray
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Decompose v as ``v = u @ a + u_orth @ b``.

    ``u_orth`` consists of the trailing ``n - m`` columns of the complete QR
    factor of ``u``, i.e. an orthonormal basis of the orthogonal complement of
    the column space of ``u``. The reconstruction ``u @ a + u_orth @ b == v``
    holds when ``u`` has orthonormal columns. If, in addition, ``v`` is a
    tangent vector at ``u``, then the matrix ``a`` is skew-hermitian.

    Args:
        u: array like of shape (..., n, m).
        v: array like of shape (..., n, m).

    Returns:
        elements of decomposition a, b, and u_orth:
            a: shape (..., m, m), equal to u^H v.
            b: shape (..., n - m, m), equal to u_orth^H v.
            u_orth: shape (..., n, n - m).
    """
    m = u.shape[-1]
    # jnp.linalg.qr and `@` both operate over leading batch axes, so there is
    # no need to flatten the batch and restore it with reshape(-1, ...); that
    # round trip cannot infer -1 when the batch is empty.
    q = jnp.linalg.qr(u, mode="complete")[0]
    u_orth = q[..., m:]
    a = jnp.swapaxes(u.conj(), -1, -2) @ v
    b = jnp.swapaxes(u_orth.conj(), -1, -2) @ v
    return a, b, u_orth
def _apply(
    self,
    iter: jnp.ndarray,
    grad: jnp.ndarray,
    state: Tuple[jnp.ndarray],
    param: jnp.ndarray,
    precond: Union[None, jnp.ndarray] = None,
    use_precond=False,
) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray]]:
    """Perform one Riemannian Adam (or AMSGrad when ``self.ams``) step.

    Args:
        iter: step index; bias correction uses ``iter + 1``, so the first
            step is expected to be ``iter = 0``.
        grad: gradient at ``param``; its complex conjugate is converted to a
            Riemannian gradient by ``self.manifold.egrad_to_rgrad``.
        state: ``(momentum, v)``, or ``(momentum, v, v_hat)`` when
            ``self.ams`` is set, as returned by the previous step.
        param: current point on the manifold.
        precond: optional preconditioner; forwarded to the manifold methods
            only when ``use_precond`` is True.
        use_precond: whether to pass ``precond`` to ``egrad_to_rgrad`` and
            ``inner``.

    Returns:
        The new point and the new state tuple (same layout as ``state``).
    """
    # Forward the preconditioner only when requested, so the manifold methods
    # are called with exactly the same arguments as without preconditioning.
    extra = (precond,) if use_precond else ()
    rgrad = self.manifold.egrad_to_rgrad(param, grad.conj(), *extra)
    momentum = self.beta1 * state[0] + (1 - self.beta1) * rgrad
    v = self.beta2 * state[1] + (1 - self.beta2) * self.manifold.inner(
        param, rgrad, rgrad, *extra
    )
    # Bias correction
    lr_corr = (
        self.learning_rate
        * jnp.sqrt(1 - self.beta2 ** (iter + 1))
        / (1 - self.beta1 ** (iter + 1))
    )
    if self.ams:
        # AMSGrad: running maximum of the real part of the second moment;
        # the imaginary part of the current v is carried over unchanged.
        v_hat = jax.lax.complex(
            jnp.maximum(jnp.real(v), jnp.real(state[2])), jnp.imag(v)
        )
        denom = jnp.sqrt(v_hat) + self.eps
    else:
        denom = jnp.sqrt(v) + self.eps
    search_dir = -lr_corr * momentum / denom
    # Move along search_dir and transport the momentum to the new point.
    param, momentum = self.manifold.retraction_transport(
        param, momentum, search_dir
    )
    if self.ams:
        return param, (momentum, v, v_hat)
    return param, (momentum, v)
def inner(
    self,
    u: jnp.ndarray,
    vec1: jnp.ndarray,
    vec2: jnp.ndarray,
    precond: Union[None, jnp.ndarray] = None,
) -> jnp.ndarray:
    """Returns manifold wise inner product of vectors from a tangent space.

    Args:
        u: complex valued tensor of shape (..., n, p), a set of points from
            the complex Stiefel manifold.
        vec1: complex valued tensor of shape (..., n, p), a set of tangent
            vectors from the complex Stiefel manifold.
        vec2: complex valued tensor of shape (..., n, p), a set of tangent
            vectors from the complex Stiefel manifold.
        precond: complex valued tensor of shape (..., p, p), optional
            preconditioner representing natural metric of an isometric tensor
            network. It is used only when ``self._use_precond`` is set, and in
            that case it replaces the metric selected by ``self._metric``.

    Returns:
        tensor of shape (..., 1, 1): the real part of the manifold wise inner
        product, cast back to ``u.dtype``.

    Raises:
        ValueError: if no preconditioner is used and ``self._metric`` is
            neither 'euclidean' nor 'canonical'.

    Note:
        The complexity for the 'euclidean' metric is O(pn), the complexity
        for the 'canonical' metric is O(np^2).
    """
    if self._use_precond:
        # Re tr(vec1^H vec2 precond)
        s_sq = (vec1.conj() * (vec2 @ precond)).sum(keepdims=True, axis=(-2, -1))
    else:
        if self._metric not in ("euclidean", "canonical"):
            raise ValueError(
                f"Unknown metric {self._metric!r}; "
                "expected 'euclidean' or 'canonical'."
            )
        # Euclidean part, tr(vec1^H vec2), shared by both metrics.
        s_sq = (vec1.conj() * vec2).sum(keepdims=True, axis=(-2, -1))
        if self._metric == "canonical":
            # Subtract 1/2 tr(vec1^H u u^H vec2), computed through two p x p
            # products instead of an n x n projector.
            vec1_dag_u = adj(vec1) @ u
            u_dag_vec2 = adj(u) @ vec2
            s_sq_2 = (u_dag_vec2 * transp(vec1_dag_u)).sum(axis=(-2, -1), keepdims=True)
            s_sq = s_sq - 0.5 * s_sq_2
    return jnp.real(s_sq).astype(dtype=u.dtype)
def _apply(
    self,
    iter: jnp.ndarray,
    grad: jnp.ndarray,
    state: Tuple[jnp.ndarray],
    param: jnp.ndarray,
    precond: Union[None, jnp.ndarray] = None,
    use_precond=False,
) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray]]:
    """Perform one Riemannian gradient-descent step, optionally with momentum.

    Args:
        iter: step index (not used by this update rule).
        grad: gradient at ``param``; its complex conjugate is converted to a
            Riemannian gradient by the manifold.
        state: ``(momentum,)`` when ``self.use_momentum`` is set; otherwise it
            is passed through untouched.
        param: current point on the manifold.
        precond: optional preconditioner, forwarded to ``egrad_to_rgrad`` only
            when ``use_precond`` is True.
        use_precond: whether to forward ``precond``.

    Returns:
        The new point and the new state.
    """
    manifold = self.manifold
    conj_grad = grad.conj()
    riem_grad = (
        manifold.egrad_to_rgrad(param, conj_grad, precond)
        if use_precond
        else manifold.egrad_to_rgrad(param, conj_grad)
    )

    # Plain descent: retract along the negative gradient, state unchanged.
    if not self.use_momentum:
        return manifold.retraction(param, -self.learning_rate * riem_grad), state

    # Exponential moving average of the gradient, moved along and then
    # transported to the new point.
    velocity = self.momentum * state[0] + (1 - self.momentum) * riem_grad
    new_param, velocity = manifold.retraction_transport(
        param, velocity, -self.learning_rate * velocity
    )
    return new_param, (velocity,)
def adj(a: jnp.ndarray) -> jnp.ndarray:
    """Returns adjoint (conjugate transpose) matrix over the last two axes.

    Leading batch axes are left in place.

    Args:
        a: complex valued tensor of shape (..., n1, n2)

    Returns:
        complex valued tensor of shape (..., n2, n1)
    """
    # Swapping only the last two axes handles any batch rank directly; the
    # previous reshape(-1, n1, n2) round trip could not infer -1 when n1 or
    # n2 is zero, and it made extra copies.
    return a.swapaxes(-2, -1).conj()
def symmetrize(A: jnp.ndarray) -> jnp.ndarray:
    """Return the Hermitian part ``(A + A^H) / 2`` of ``A``.

    For real input this is the symmetric part. The conjugate transpose is
    taken over the last two axes only, so a batch of matrices of shape
    (..., n, n) is symmetrized matrix by matrix.

    Args:
        A: array of shape (..., n, n).

    Returns:
        array of the same shape whose last two axes are Hermitian.
    """
    if A.ndim < 2:
        # `.T` is the identity for 0-D/1-D arrays; keep the previous result
        # ((A + conj(A)) / 2) for such inputs.
        return (A + A.conj().T) / 2
    # `.T` would reverse *all* axes and scramble batch dimensions, so swap
    # only the matrix axes.
    return (A + A.conj().swapaxes(-1, -2)) / 2