Esempio n. 1
0
 def retr(self, x, u):
     """Cayley-type retraction of ``x`` along ``u``.

     With ``A`` the skew-symmetric part of ``x @ u^T``, returns the solution
     ``y`` of the linear system ``(A + I) y = x - A x``.
     """
     skew = x @ utils.transposem(u)
     skew = (skew - utils.transposem(skew)) / 2.0
     batch_dims = tf.shape(skew)[:-2]
     identity = tf.eye(
         tf.shape(skew)[-1], batch_shape=batch_dims, dtype=x.dtype
     )
     return tf.linalg.solve(skew + identity, x - skew @ x)
Esempio n. 2
0
 def diff_from_cholesky(self, x, u):
     """Inverse of the differential of the diffeomorphism to the Cholesky space.

     Given a point ``x`` of ``self._cholesky`` and a tangent vector ``u`` at
     it, returns ``x @ u^T + u @ x^T`` (the directional derivative of
     ``L -> L @ L^T`` at ``x`` along ``u``). Both inputs are validated with
     ``tf.debugging.Assert`` before the result is computed.
     """
     cholesky = self._cholesky
     checks = [
         tf.debugging.Assert(cholesky.check_point_on_manifold(x), [x]),
         tf.debugging.Assert(cholesky.check_vector_on_tangent(x, u), [u]),
     ]
     with tf.control_dependencies(checks):
         x_ut = x @ utils.transposem(u)
         u_xt = u @ utils.transposem(x)
         return x_ut + u_xt
Esempio n. 3
0
 def inner(self, x, u, v, keepdims=False):
     """Inner product ``<u, v> - 0.5 * <x^T u, x^T v>`` of tangent vectors.

     Both Frobenius products are summed over the last two axes; ``keepdims``
     is forwarded to ``tf.reduce_sum``.
     """
     x_t = utils.transposem(x)
     axes = [-2, -1]
     full = tf.reduce_sum(u * v, axis=axes, keepdims=keepdims)
     vertical = tf.reduce_sum(
         (x_t @ u) * (x_t @ v), axis=axes, keepdims=keepdims
     )
     return full - 0.5 * vertical
Esempio n. 4
0
 def retr(self, x, u):
     """Iterative Cayley retraction of ``x`` along the tangent vector ``u``.

     Forms the skew-symmetric generator ``W = W_hat - W_hat^T`` with
     ``W_hat = (u - 0.5 * x x^T u) x^T``. When ``x^T x = I`` this gives
     ``W x = u - x sym(x^T u)``, i.e. ``W x = u`` for tangent ``u``. The
     Cayley transform ``(I - W/2)^{-1} (I + W/2) x`` is then approximated
     by ``self.num_iter`` fixed-point steps ``y <- x + W (x + y) / 2``,
     starting from ``y = x + u``.
     """
     xtu = utils.transposem(x) @ u
     # The 1/2 factor is required: without it W x reduces to (I - x x^T) u,
     # which drops the x (x^T u) part of a tangent vector (and is identically
     # zero for square x), so the retraction would not have velocity u.
     w_ = (u - x @ xtu / 2.0) @ utils.transposem(x)
     w = w_ - utils.transposem(w_)
     y = x + u
     for _ in range(self.num_iter):
         y = x + w @ ((x + y) / 2.0)
     return y
Esempio n. 5
0
 def geodesic(self, x, u, t):
     """Evaluate the Cayley-transform curve through ``x`` at time ``t``.

     Returns ``(I - t W / 2)^{-1} (I + t W / 2) x``, where ``W`` is the
     skew-symmetric generator ``W_hat - W_hat^T`` with
     ``W_hat = (u - 0.5 * x x^T u) x^T``. When ``x^T x = I`` this gives
     ``W x = u`` for tangent ``u``, so the curve starts with velocity ``u``.
     """
     xtu = utils.transposem(x) @ u
     # The 1/2 factor makes W x equal the tangent projection of u; without it
     # the curve's initial velocity loses the x (x^T u) component.
     w_ = (u - x @ xtu / 2.0) @ utils.transposem(x)
     w = w_ - utils.transposem(w_)
     eye = tf.linalg.eye(
         tf.shape(w)[-1], batch_shape=tf.shape(w)[:-2], dtype=x.dtype
     )
     half_tw = t * w / 2.0
     # Solve the linear system instead of forming an explicit inverse.
     return tf.linalg.solve(eye - half_tw, (eye + half_tw) @ x)
 def ptransp(self, x, y, v):
     """Transport the tangent vector ``v`` from ``x`` to ``y``.

     With the compact SVD ``self.log(x, y) = U diag(s) V^T``, returns
     ``(-x V sin(s) U^T + U cos(s) U^T) v + (I - U U^T) v``. This is the
     closed-form transport of Edelman, Arias & Smith (1998).
     NOTE(review): presumes ``self.log`` is the matching geodesic log map.
     """
     log_xy = self.log(x, y)
     # tf.linalg.svd returns V itself (log_xy = U diag(s) V^H), not V^T, so
     # the formula's ``x V`` is ``x @ v_svd`` with no extra transpose.
     s, u, v_svd = tf.linalg.svd(log_xy, full_matrices=False)
     cos_s = tf.linalg.diag(tf.math.cos(s))
     sin_s = tf.linalg.diag(tf.math.sin(s))
     # Contract U^T v first so no n x n intermediate is formed.
     ut_v = utils.transposem(u) @ v
     geod = (-x @ v_svd @ sin_s + u @ cos_s) @ ut_v
     proj = v - u @ ut_v
     return geod + proj
Esempio n. 7
0
 def geodesic(self, x, u, t):
     """Point at time ``t`` on the curve from ``x`` with velocity ``u``.

     Evaluates ``[x, u] @ expm(t * [[A, -S], [I, A]]) @ [expm(-t A); 0]``
     with ``A = x^T u`` and ``S = u^T u``.
     """
     a = utils.transposem(x) @ u
     s = utils.transposem(u) @ u
     identity = tf.eye(
         tf.shape(s)[-1], batch_shape=tf.shape(s)[:-2], dtype=x.dtype
     )
     # NOTE(review): assumes blockm(a, b, c, d) assembles [[a, b], [c, d]].
     generator = blockm(a, -s, identity, a)
     propagator = tf.linalg.expm(t * generator)
     selector = tf.concat(
         [tf.linalg.expm(-a * t), tf.zeros_like(s)], axis=-2
     )
     frame = tf.concat([x, u], axis=-1)
     return frame @ propagator @ selector
Esempio n. 8
0
 def _check_point_on_manifold(self, x, atol, rtol):
     """Check that ``x`` is symmetric positive definite.

     ``x`` passes when it is symmetric, has no negative eigenvalue (up to
     tolerance), and has no eigenvalue close to zero.
     """
     x_t = utils.transposem(x)
     # eigh reads only one triangle, so symmetry is checked separately.
     eigvals, _ = tf.linalg.eigh(x)
     is_symmetric = utils.allclose(x, x_t, atol, rtol)
     is_pos_vals = utils.allclose(eigvals, tf.abs(eigvals), atol, rtol)
     # Positive definiteness needs *every* eigenvalue away from zero.
     # Comparing the whole spectrum to zero only rejected matrices whose
     # eigenvalues were all ~0, so singular PSD inputs like diag(1, 0) passed.
     smallest = tf.reduce_min(tf.abs(eigvals))
     has_zero_val = utils.allclose(
         smallest, tf.zeros_like(smallest), atol, rtol
     )
     return is_symmetric & is_pos_vals & tf.logical_not(has_zero_val)
Esempio n. 9
0
 def _check_point_on_manifold(self, x, atol, rtol):
     """Return whether ``x`` is orthogonal with unit determinant."""
     gram = utils.transposem(x) @ x
     dims = tf.shape(gram)
     identity = tf.eye(dims[-1], batch_shape=dims[:-2], dtype=x.dtype)
     det = tf.linalg.det(x)
     orthogonal = utils.allclose(gram, identity, atol, rtol)
     unit_det = utils.allclose(det, tf.ones_like(det), atol, rtol)
     return orthogonal & unit_det
Esempio n. 10
0
 def diff_to_cholesky(self, x, u):
     """Differential of the diffeomorphism to the Cholesky space.

     With ``L = self.to_cholesky(x)`` and ``P = L^{-1} u L^{-T}``, returns
     ``L @ (strictly_lower(P) + 0.5 * diag(P))``. ``x`` and ``u`` are
     validated with ``tf.debugging.Assert`` first.
     NOTE(review): presumes ``to_cholesky`` yields the lower factor.
     """
     checks = [
         tf.debugging.Assert(self.check_point_on_manifold(x), [x]),
         tf.debugging.Assert(self.check_vector_on_tangent(x, u), [u]),
     ]
     with tf.control_dependencies(checks):
         chol = self.to_cholesky(x)
         chol_inv = tf.linalg.inv(chol)
         p = chol_inv @ u @ utils.transposem(chol_inv)
         split = self._cholesky._diag_and_strictly_lower(p)
         p_diag, p_strict = split
         return chol @ (p_strict + 0.5 * tf.linalg.diag(p_diag))
Esempio n. 11
0
 def _check_point_on_manifold(self, x, atol, rtol):
     """Check that ``x`` has orthonormal columns and full column rank.

     ``x^T x`` must be close to the identity, and the number of nonzero
     singular values must equal the number of columns of ``x``.
     """
     xtx = utils.transposem(x) @ x
     # Dynamic shapes: the static ``shape.as_list()`` / ``int(x.shape[-1])``
     # previously used fail when dimensions are unknown at trace time.
     xtx_shape = tf.shape(xtx)
     eye = tf.eye(xtx_shape[-1], batch_shape=xtx_shape[:-2], dtype=x.dtype)
     is_orthonormal = utils.allclose(xtx, eye, atol, rtol)
     s = tf.linalg.svd(x, compute_uv=False)
     # NOTE(review): count_nonzero counts exactly-nonzero singular values, so
     # this rank test is not tolerance-based.
     rank = tf.math.count_nonzero(s, axis=-1, dtype=tf.float32)
     k = tf.ones_like(rank) * tf.cast(tf.shape(x)[-1], rank.dtype)
     is_col_rank = utils.allclose(rank, k, atol, rtol)
     return is_orthonormal & is_col_rank
Esempio n. 12
0
 def _diff_power(self, x, d, power):
     e, v = tf.linalg.eigh(x)
     v_t = utils.transposem(v)
     e = tf.expand_dims(e, -2)
     if power == "log":
         pow_e = tf.math.log(e)
     elif power == "exp":
         pow_e = tf.math.exp(e)
     s = utils.transposem(tf.ones_like(e)) @ e
     pow_s = utils.transposem(tf.ones_like(pow_e)) @ pow_e
     denom = utils.transposem(s) - s
     numer = utils.transposem(pow_s) - pow_s
     abs_denom = tf.math.abs(denom)
     eps = utils.get_eps(x)
     if power == "log":
         numer = tf.where(abs_denom < eps, tf.ones_like(numer), numer)
         denom = tf.where(abs_denom < eps, utils.transposem(s), denom)
     elif power == "exp":
         numer = tf.where(abs_denom < eps, utils.transposem(pow_s), numer)
         denom = tf.where(abs_denom < eps, tf.ones_like(denom), denom)
     t = v_t @ d @ v * numer / denom
     return v @ t @ v_t
Esempio n. 13
0
 def call(self, inputs):
     """Apply ``log`` to the singular values of ``inputs``.

     For the SVD ``inputs = u diag(s) v^T`` returns ``u diag(log s) v^T``.
     NOTE(review): presumably inputs are SPD, making this the matrix log.
     """
     s, u, v = tf.linalg.svd(inputs)
     log_sigma = tf.linalg.diag(tf.math.log(s))
     return u @ log_sigma @ utils.transposem(v)
Esempio n. 14
0
 def call(self, inputs):
     """Clamp singular values of ``inputs`` from below at ``self.epsilon``.

     Returns ``u diag(max(s, epsilon)) v^T`` for ``inputs = u diag(s) v^T``.
     """
     s, u, v = tf.linalg.svd(inputs)
     clamped = tf.linalg.diag(tf.maximum(s, self.epsilon))
     return u @ clamped @ utils.transposem(v)
Esempio n. 15
0
 def call(self, inputs):
     """Return the congruence ``w^T @ inputs @ w`` with ``w = self.w``."""
     weights = self.w
     left = utils.transposem(weights) @ inputs
     return left @ weights
Esempio n. 16
0
 def log(self, x, y):
     """Logarithm map: skew-symmetric part of ``logm(x^T y)``."""
     log_rel = utils.logm(utils.transposem(x) @ y)
     return (log_rel - utils.transposem(log_rel)) / 2.0
Esempio n. 17
0
 def proju(self, x, u):
     """Return the skew-symmetric part of ``x^T u``."""
     rel = utils.transposem(x) @ u
     rel_t = utils.transposem(rel)
     return (rel - rel_t) / 2.0
Esempio n. 18
0
 def call(self, inputs):
     """Logarithm of rotation matrices.

     Returns ``skew(R) * theta / sin(theta)``, where ``skew(R) = (R - R^T) / 2``
     and ``theta = rot_angle(R)``. For ``theta < EPS`` the factor
     ``theta / sin(theta)`` is replaced by its limit 1, so the result there
     is ``skew(R)``.
     NOTE(review): still ill-conditioned as theta approaches pi.
     """
     thetas = rot_angle(inputs)[..., tf.newaxis, tf.newaxis]
     skew = (inputs - utils.transposem(inputs)) / 2.0
     small = thetas < EPS
     ones = tf.ones_like(thetas)
     # Substitute a harmless angle before dividing so neither tf.where branch
     # ever holds inf/NaN; otherwise NaN leaks into gradients through tf.where.
     safe_thetas = tf.where(small, ones, thetas)
     factor = tf.where(small, ones, safe_thetas / tf.math.sin(safe_thetas))
     return skew * factor
Esempio n. 19
0
 def call(self, inputs):
     """Return ``w^T @ inputs``, optionally inserting an axis at -3 first."""
     batch = tf.expand_dims(inputs, -3) if self._expand else inputs
     return utils.transposem(self.w) @ batch
Esempio n. 20
0
 def projx(self, x):
     """Map ``x`` to a symmetric positive semi-definite matrix.

     Symmetrizes ``x`` and rebuilds it from its SVD as
     ``v diag(max(s, 0)) v^T``.
     NOTE(review): singular values are already non-negative, so the clamp is
     a no-op and the result is |sym(x)|. Negative eigenvalues are flipped
     rather than clipped; confirm this is intended.
     """
     sym = (utils.transposem(x) + x) / 2.0
     s, _, v = tf.linalg.svd(sym)
     clamped = tf.linalg.diag(tf.maximum(s, 0.0))
     return v @ clamped @ utils.transposem(v)
Esempio n. 21
0
 def proju(self, x, u):
     """Return ``u - x sym(x^T u)``, where ``sym(A) = (A + A^T) / 2``."""
     rel = utils.transposem(x) @ u
     rel_sym = (utils.transposem(rel) + rel) / 2.0
     correction = x @ rel_sym
     return u - correction
Esempio n. 22
0
 def _check_point_on_manifold(self, x, atol, rtol):
     """Return whether ``x^T x`` is close to the identity."""
     gram = utils.transposem(x) @ x
     dims = tf.shape(gram)
     identity = tf.eye(dims[-1], batch_shape=dims[:-2], dtype=gram.dtype)
     return utils.allclose(gram, identity, atol, rtol)
Esempio n. 23
0
 def proju(self, x, u):
     """Return the symmetric part of ``u``; ``x`` is not used."""
     u_t = utils.transposem(u)
     return 0.5 * (u_t + u)
Esempio n. 24
0
 def ptransp(self, x, y, v):
     """Transport ``v`` from ``x`` to ``y`` as ``E v E^T``.

     Here ``E = sqrtm(y @ inv(x))``.
     """
     x_inv = tf.linalg.inv(x)
     transport = tf.linalg.sqrtm(y @ x_inv)
     return transport @ v @ utils.transposem(transport)
Esempio n. 25
0
 def proju(self, x, u):
     """Project ``u`` onto the tangent space at ``x`` via a skew generator.

     Builds ``W = W_hat - W_hat^T`` with ``W_hat = (u - 0.5 * x x^T u) x^T``
     and returns ``W x``. Assuming ``x^T x = I``, this equals
     ``u - x sym(x^T u)``, which leaves tangent vectors unchanged.
     """
     xtu = utils.transposem(x) @ u
     # Without the 1/2 factor W x reduces to (I - x x^T) u, which discards the
     # x (x^T u) part of genuine tangent vectors (identically zero for square x).
     w = (u - x @ xtu / 2.0) @ utils.transposem(x)
     return (w - utils.transposem(w)) @ x
Esempio n. 26
0
 def projx(self, x):
     """Rebuild sym(x) from its SVD and return its Cholesky factor.

     ``x`` is symmetrized, rebuilt as ``v diag(max(s, 0)) v^T`` and passed to
     ``tf.linalg.cholesky``.
     NOTE(review): the clamp is a no-op (singular values are >= 0), so the
     factored matrix is |sym(x)|; Cholesky fails if it is singular.
     """
     sym = (utils.transposem(x) + x) / 2.0
     s, _, v = tf.linalg.svd(sym)
     clamped = tf.linalg.diag(tf.maximum(s, 0.0))
     return tf.linalg.cholesky(v @ clamped @ utils.transposem(v))
Esempio n. 27
0
 def proju(self, x, u):
     """Lower-triangular part of the symmetrized ``u``; ``x`` is not used.

     Returns ``strictly_lower(sym(u)) + diag(sym(u))``, where
     ``sym(u) = (u + u^T) / 2``.
     NOTE(review): symmetrizing first mixes the upper triangle of ``u`` into
     the result; confirm this is intended rather than plain ``lower(u)``.
     """
     sym = (utils.transposem(u) + u) / 2.0
     diag_part, strict_lower = self._diag_and_strictly_lower(sym)
     return strict_lower + tf.linalg.diag(diag_part)
Esempio n. 28
0
 def _check_vector_on_tangent(self, x, u, atol, rtol):
     """Return whether ``u`` is skew-symmetric (``u^T + u`` close to zero)."""
     residual = utils.transposem(u) + u
     zeros = tf.zeros_like(residual)
     return utils.allclose(residual, zeros, atol, rtol)
Esempio n. 29
0
 def call(self, inputs):
     """Return the Gram-like product ``inputs @ inputs^T``."""
     inputs_t = utils.transposem(inputs)
     return inputs @ inputs_t
Esempio n. 30
0
 def proju(self, x, u):
     """Return ``u - x u^T x``."""
     u_t = utils.transposem(u)
     correction = x @ u_t @ x
     return u - correction