def vector_transport(self, u, vec1, vec2):
    """Transports tangent vectors along direction vectors.

    Each vector ``vec1`` (tangent at ``u``) is carried over to the tangent
    space at the point reached by retracting ``u`` along ``vec2``.

    Args:
        u: complex valued tensor of shape (..., n, n), a set of points from
            the manifold, starting points.
        vec1: complex valued tensor of shape (..., n, n), a set of vectors
            to be transported.
        vec2: complex valued tensor of shape (..., n, n), a set of direction
            vectors.

    Returns:
        complex valued tensor of shape (..., n, n), a set of transported
        vectors.

    Note:
        The complexity O(n^3).
    """
    metric = self._metric
    if metric == 'log_euclidean':
        eigvals, eigvecs = tf.linalg.eigh(u)
        # Image of the base point in S: U diag(log(lambda)) U^H.
        log_u = eigvecs @ tf.linalg.diag(tf.math.log(eigvals)) @ adj(eigvecs)
        # Move in S by adding the pulled-back direction (geodesic in S).
        shifted = log_u + _pull_back_log(vec2, eigvecs, eigvals)
        # Eigendecomposition of the destination in S; exponentiating its
        # eigenvalues gives the eigenvalues of the destination point.
        log_dest_vals, dest_vecs = tf.linalg.eigh(shifted)
        dest_vals = tf.exp(log_dest_vals)
        # Pull vec1 back at u, then push it forward at the destination.
        pulled = _pull_back_log(vec1, eigvecs, eigvals)
        return _push_forward_log(pulled, dest_vecs, dest_vals)
    if metric == 'log_cholesky':
        destination = self.retraction(u, vec2)
        chol_u = tf.linalg.cholesky(u)
        chol_u_inv = tf.linalg.inv(chol_u)
        # Diagonal matrix holding 1 / diag(L).
        inv_diag = tf.linalg.diag(1 / tf.linalg.diag_part(chol_u))
        tangent = _pull_back_chol(vec1, chol_u, chol_u_inv)
        chol_dest = tf.linalg.cholesky(destination)
        # Transport in Cholesky space: _lower(X) is kept unchanged
        # (presumably the strictly lower part), while the diagonal of X is
        # rescaled by diag(K) / diag(L).
        rescaled_diag = (tf.linalg.band_part(chol_dest, 0, 0) * inv_diag
                         * tf.linalg.band_part(tangent, 0, 0))
        moved = _lower(tangent) + rescaled_diag
        # Map the Cholesky-space vector T back via K T^H + T K^H.
        return chol_dest @ adj(moved) + moved @ adj(chol_dest)
def retraction_transport(self, u, vec1, vec2):
    """Performs a retraction and a vector transport simultaneously.

    Args:
        u: complex valued tensor of shape (..., n, n), a set of points from
            the manifold, starting points.
        vec1: complex valued tensor of shape (..., n, n), a set of vectors
            to be transported.
        vec2: complex valued tensor of shape (..., n, n), a set of direction
            vectors.

    Returns:
        two complex valued tensors of shape (..., n, n), a set of
        transported points and vectors.

    Raises:
        ValueError: if the metric is neither 'log_euclidean' nor
            'log_cholesky'.
    """
    # Read the same attribute as the sibling methods (retraction, inner,
    # vector_transport); the previous `self.metric` was inconsistent with them.
    if self._metric == 'log_euclidean':
        lmbd, U = tf.linalg.eigh(u)
        # geodesic in S: add the pulled-back direction to log(u)
        Su = U @ tf.linalg.diag(tf.math.log(lmbd)) @ adj(U)
        Svec2 = _pull_back_log(vec2, U, lmbd)
        Sresult = Su + Svec2
        # eig decomposition of the new point in S
        log_new_lmbd, new_U = tf.linalg.eigh(Sresult)
        # eigenvalues of the new point; computed once and reused for both
        # the new point and the transported vector
        new_lmbd = tf.exp(log_new_lmbd)
        # new point from S++
        new_point = new_U @ tf.linalg.diag(new_lmbd) @ adj(new_U)
        # transported vector: pull back at u, push forward at the new point
        new_vec1 = _push_forward_log(_pull_back_log(vec1, U, lmbd),
                                     new_U, new_lmbd)
        return new_point, new_vec1
    if self._metric == 'log_cholesky':
        v = self.retraction(u, vec2)
        L = tf.linalg.cholesky(u)
        inv_L = tf.linalg.inv(L)
        # diagonal matrix holding 1 / diag(L)
        inv_diag_L = tf.linalg.diag(1 / tf.linalg.diag_part(L))
        X = _pull_back_chol(vec1, L, inv_L)
        K = tf.linalg.cholesky(v)
        # transport in Cholesky space: keep _lower(X), rescale the diagonal
        # of X by diag(K) / diag(L)
        L_transport = _lower(X) + tf.linalg.band_part(K, 0, 0) * \
            inv_diag_L * tf.linalg.band_part(X, 0, 0)
        return v, K @ adj(L_transport) + L_transport @ adj(K)
    # Previously this fell through and returned None, which only failed
    # later at the caller's tuple unpacking with an obscure TypeError.
    raise ValueError(f"Unknown metric: {self._metric!r}")
def inner(self, u, vec1, vec2):
    """Computes the Riemannian inner product of pairs of tangent vectors.

    Args:
        u: complex valued tensor of shape (..., n, n), a set of points from
            the manifold.
        vec1: complex valued tensor of shape (..., n, n), a set of tangent
            vectors from the manifold.
        vec2: complex valued tensor of shape (..., n, n), a set of tangent
            vectors from the manifold.

    Returns:
        complex valued tensor of shape (..., 1, 1), manifold wise inner
        product (real part only, cast back to the dtype of u).

    Note:
        The complexity O(n^3) for both inner products.
    """
    metric = self._metric
    if metric == 'log_euclidean':
        eigvals, eigvecs = tf.linalg.eigh(u)
        lhs = _pull_back_log(vec1, eigvecs, eigvals)
        rhs = _pull_back_log(vec2, eigvecs, eigvals)
        # Frobenius product of the pulled-back vectors.
        total = tf.reduce_sum(tf.math.conj(lhs) * rhs, axis=(-2, -1),
                              keepdims=True)
        return tf.cast(tf.math.real(total), dtype=u.dtype)
    if metric == 'log_cholesky':
        chol = tf.linalg.cholesky(u)
        chol_inv = tf.linalg.inv(chol)
        lhs = _pull_back_chol(vec1, chol, chol_inv)
        rhs = _pull_back_chol(vec2, chol, chol_inv)
        # Entry-wise weights: ones wherever _lower keeps entries of an
        # all-ones matrix (presumably the strictly lower part), plus
        # 1 / L_ii^2 on the diagonal.
        weights = (_lower(tf.ones(tf.shape(u)[-2:], dtype=u.dtype))
                   + tf.linalg.diag(1 / (tf.linalg.diag_part(chol) ** 2)))
        total = tf.reduce_sum(tf.math.conj(lhs) * weights * rhs,
                              axis=(-2, -1), keepdims=True)
        return tf.cast(tf.math.real(total), dtype=u.dtype)
def retraction(self, u, vec):
    """Moves points of the manifold along direction vectors.

    Args:
        u: complex valued tensor of shape (..., n, n), a set of points to be
            transported.
        vec: complex valued tensor of shape (..., n, n), a set of direction
            vectors.

    Returns:
        complex valued tensor of shape (..., n, n), a set of transported
        points.

    Note:
        The complexity O(n^3).
    """
    metric = self._metric
    if metric == 'log_euclidean':
        eigvals, eigvecs = tf.linalg.eigh(u)
        # Image of u in S: U diag(log(lambda)) U^H.
        log_u = eigvecs @ tf.linalg.diag(tf.math.log(eigvals)) @ adj(eigvecs)
        # Step in S (geodesic in S), then return to the manifold via expm.
        step = _pull_back_log(vec, eigvecs, eigvals)
        return tf.linalg.expm(log_u + step)
    if metric == 'log_cholesky':
        chol = tf.linalg.cholesky(u)
        chol_inv = tf.linalg.inv(chol)
        tangent = _pull_back_chol(vec, chol, chol_inv)
        # Diagonal matrix holding 1 / diag(L).
        inv_diag = tf.linalg.diag(1 / tf.linalg.diag_part(chol))
        # _lower parts are added directly; the diagonal becomes
        # L_ii * exp(X_ii / L_ii). exp turns the zero off-diagonal entries
        # into ones, but the product with the diagonal-only
        # band_part(chol, 0, 0) zeroes them again.
        new_chol = (_lower(chol) + _lower(tangent)
                    + tf.linalg.band_part(chol, 0, 0)
                    * tf.exp(tf.linalg.band_part(tangent, 0, 0) * inv_diag))
        return new_chol @ adj(new_chol)