示例#1
0
 def dist(self, x, y, keepdims=False):
     """Return the distance between ``x`` and ``y``.

     The value is ``2 * atanh(sqrt(k) * ||w||) / sqrt(k)`` where
     ``w = self._mobius_add(-x, y)`` and the norm is taken over the last
     axis. The ``atanh`` argument is clipped to ``[-1 + eps, 1 - eps]``
     (``eps = utils.get_eps(x)``) so the result stays finite.

     Args:
         x: First point; ``self.k`` is cast to its dtype.
         y: Second point.
         keepdims: Forwarded to ``tf.linalg.norm``.

     Returns:
         The distance, reduced over the last axis.
     """
     curvature_root = tf.math.sqrt(tf.cast(self.k, x.dtype))
     diff_norm = tf.linalg.norm(
         self._mobius_add(-x, y), axis=-1, keepdims=keepdims
     )
     eps = utils.get_eps(x)
     # Keep the argument strictly inside (-1, 1), where atanh is finite.
     atanh_arg = tf.clip_by_value(
         curvature_root * diff_norm, -1.0 + eps, 1.0 - eps
     )
     return 2 * tf.math.atanh(atanh_arg) / curvature_root
示例#2
0
 def log(self, x, y):
     """Return the logarithmic map of ``y`` at base point ``x``.

     With ``w = self._mobius_add(-x, y)`` the result is
     ``2 * (w / ||w||) * atanh(sqrt(k) * ||w||) / (sqrt(k) * lambda_x)``,
     where ``lambda_x = self._lambda(x, keepdims=True)`` and the norm is
     taken over the last axis.

     ``||w||`` is floored at ``utils.get_eps(x)`` before it is used, so a
     vanishing difference (for example ``y == x``) yields a zero vector
     rather than ``0 / 0 = NaN``.

     Args:
         x: Base point; ``self.k`` is cast to its dtype.
         y: Point to map.

     Returns:
         A tensor with the shape of ``self._mobius_add(-x, y)``.
     """
     sqrt_k = tf.math.sqrt(tf.cast(self.k, x.dtype))
     x_y = self._mobius_add(-x, y)
     eps = utils.get_eps(x)
     # Floor the norm so the division below is never 0 / 0. The same
     # floored value feeds atanh, which keeps the ratio
     # atanh(sqrt_k * n) / (sqrt_k * n) close to its limit of 1.
     norm_x_y = tf.maximum(
         tf.linalg.norm(x_y, axis=-1, keepdims=True), eps
     )
     # Keep the argument strictly inside (-1, 1), where atanh is finite.
     tanh = tf.clip_by_value(sqrt_k * norm_x_y, -1.0 + eps, 1.0 - eps)
     lambda_x = self._lambda(x, keepdims=True)
     return 2 * (x_y / norm_x_y) * tf.math.atanh(tanh) / (sqrt_k * lambda_x)
示例#3
0
 def projx(self, x):
     """Project ``x`` so that ``sqrt(k) * ||x||`` stays below one.

     Rows (along the last axis) with ``sqrt(k) * ||x|| < 1`` are returned
     unchanged. Every other row is divided by
     ``sqrt(k) * ||x|| + 10 * utils.get_eps(x)``.

     Args:
         x: Tensor to project; ``self.k`` is cast to its dtype.

     Returns:
         A tensor with the same shape as ``x``.
     """
     curvature_root = tf.math.sqrt(tf.cast(self.k, x.dtype))
     radius = tf.linalg.norm(x, axis=-1, keepdims=True)
     scaled_radius = curvature_root * radius
     inside = scaled_radius < tf.ones_like(radius)
     # The small margin keeps the rescaled point strictly below the bound.
     rescaled = x / (scaled_radius + 10 * utils.get_eps(x))
     return tf.where(inside, x, rescaled)
示例#4
0
    def _mobius_scal_mul(self, x, r):
        r"""Compute the Möbius scalar multiplication of a non-zero point.

        For :math:`x \in \mathcal{D}^{n}_{k} \setminus \{0\}` and a scalar
        :math:`r`:

        .. math::

            x \otimes r = \frac{1}{\sqrt{k}}
            \tanh\left(r \operatorname{atanh}(\sqrt{k}\|x\|)\right)
            \frac{x}{\|x\|}

        The ``atanh`` argument is clipped to ``[-1 + eps, 1 - eps]`` with
        ``eps = utils.get_eps(x)``.

        Args:
            x: Point to scale; ``self.k`` is cast to its dtype. Must be
                non-zero along the last axis because the result divides
                by ``||x||``.
            r: Scalar multiplier.

        Returns:
            A tensor with the same shape as ``x``.
        """
        sqrt_k = tf.math.sqrt(tf.cast(self.k, x.dtype))
        norm_x = tf.linalg.norm(x, axis=-1, keepdims=True)
        eps = utils.get_eps(x)
        # Keep the argument strictly inside (-1, 1), where atanh is finite.
        atanh_arg = tf.clip_by_value(sqrt_k * norm_x, -1.0 + eps, 1.0 - eps)
        scale = (1 / sqrt_k) * tf.math.tanh(r * tf.math.atanh(atanh_arg))
        return scale * x / norm_x
示例#5
0
 def _diff_power(self, x, d, power):
     """Apply a divided-difference transform of ``log`` or ``exp`` to ``d``.

     ``x`` is eigendecomposed with ``tf.linalg.eigh`` into eigenvalues
     ``e`` and eigenvectors ``v``. ``d`` is rotated into the eigenbasis,
     multiplied element-wise by ``(f(e_i) - f(e_j)) / (e_i - e_j)`` with
     ``f`` being ``log`` or ``exp``, and rotated back.

     Where two eigenvalues differ by less than ``utils.get_eps(x)``
     (including the diagonal) the quotient is replaced by the derivative
     of ``f``: ``1 / e_i`` for ``"log"`` and ``exp(e_i)`` for ``"exp"``.

     NOTE(review): this has the shape of the usual divided-difference
     formula for the directional derivative of a matrix function at ``x``
     in direction ``d`` - confirm against the callers. The index layout
     above assumes ``utils.transposem`` swaps the last two axes.

     Args:
         x: Matrix (or batch of matrices) accepted by ``tf.linalg.eigh``.
         d: Matrix (or batch) to transform.
         power: Either ``"log"`` or ``"exp"``.

     Returns:
         ``v @ ((v^T @ d @ v) * numer / denom) @ v^T``.

     Raises:
         ValueError: If ``power`` is neither ``"log"`` nor ``"exp"``.
     """
     # Fail fast with a clear message. Previously an unknown value fell
     # through both branches and surfaced as an UnboundLocalError.
     if power not in ("log", "exp"):
         raise ValueError(
             "Unsupported power {!r}; expected 'log' or 'exp'".format(power)
         )
     e, v = tf.linalg.eigh(x)
     v_t = utils.transposem(v)
     e = tf.expand_dims(e, -2)
     if power == "log":
         pow_e = tf.math.log(e)
     else:
         pow_e = tf.math.exp(e)
     # Broadcast the eigenvalue row (and its image under f) to a square
     # matrix by multiplying with a column of ones.
     s = utils.transposem(tf.ones_like(e)) @ e
     pow_s = utils.transposem(tf.ones_like(pow_e)) @ pow_e
     s_t = utils.transposem(s)
     denom = s_t - s
     numer = utils.transposem(pow_s) - pow_s
     # Entries with (numerically) equal eigenvalues would be 0 / 0.
     degenerate = tf.math.abs(denom) < utils.get_eps(x)
     if power == "log":
         # d/de log(e) = 1 / e
         numer = tf.where(degenerate, tf.ones_like(numer), numer)
         denom = tf.where(degenerate, s_t, denom)
     else:
         # d/de exp(e) = exp(e)
         numer = tf.where(degenerate, utils.transposem(pow_s), numer)
         denom = tf.where(degenerate, tf.ones_like(denom), denom)
     t = v_t @ d @ v * numer / denom
     return v @ t @ v_t
 def from_poincare(self, x, k):
     """Inverse of the diffeomorphism to the Poincaré ball.

     Computes ``sqrt(k) * [1 + |x|^2, 2 * x] / (1 - |x|^2 + eps)``, where
     ``|x|^2`` is the squared norm over the last axis, the bracket is a
     concatenation along the last axis and ``eps = utils.get_eps(x)``.

     Args:
         x: Input tensor.
         k: Scalar, cast to the dtype of ``x``.

     Returns:
         A tensor whose last axis is one element longer than that of ``x``.
     """
     scale = tf.math.sqrt(tf.cast(k, x.dtype))
     sq_norm = tf.reduce_sum(x * x, axis=-1, keepdims=True)
     lifted = tf.concat([1 + sq_norm, 2 * x], axis=-1)
     # eps keeps the denominator away from zero when |x|^2 equals 1.
     denominator = 1.0 - sq_norm + utils.get_eps(x)
     return scale * lifted / denominator
 def norm(self, x, u, keepdims=False):
     """Return the norm of ``u`` induced by ``self.inner`` at ``x``.

     The inner product ``self.inner(x, u, u)`` is floored at
     ``utils.get_eps(x)`` before the square root is taken.

     Args:
         x: Base point.
         u: Vector whose norm is computed.
         keepdims: Forwarded to ``self.inner``.

     Returns:
         ``sqrt(max(inner(x, u, u), eps))``.
     """
     squared = self.inner(x, u, u, keepdims=keepdims)
     floor = utils.get_eps(x)
     return tf.math.sqrt(tf.maximum(squared, floor))
 def _check_vector_on_tangent(self, x, u, atol, rtol):
     """Check that ``self.inner(x, x, u)`` is close to zero.

     Args:
         x: Base point.
         u: Vector to test.
         atol: Absolute tolerance, passed through to ``utils.allclose``.
         rtol: Relative tolerance; ``None`` selects
             ``100 * utils.get_eps(x)``.

     Returns:
         The result of ``utils.allclose`` against a zero tensor.
     """
     product = self.inner(x, x, u)
     if rtol is None:
         rtol = 100 * utils.get_eps(x)
     target = tf.zeros_like(product)
     return utils.allclose(product, target, atol, rtol)
示例#9
0
 def log(self, x, y):
     """Return the logarithmic map of ``y`` at base point ``x``.

     ``y - x`` is projected with ``self.proju`` and rescaled so that its
     norm equals ``self.dist(x, y)``. Where the distance does not exceed
     ``utils.get_eps(x)`` the projected vector is returned unscaled.

     Args:
         x: Base point.
         y: Point to map.

     Returns:
         A tensor with the shape of ``self.proju(x, y - x)``.
     """
     tangent = self.proju(x, y - x)
     tangent_norm = self.norm(x, tangent, keepdims=True)
     geodesic = self.dist(x, y, keepdims=True)
     rescaled = tangent * geodesic / tangent_norm
     far_enough = geodesic > utils.get_eps(x)
     return tf.where(far_enough, rescaled, tangent)
示例#10
0
 def exp(self, x, u):
     """Return the exponential map of ``u`` at base point ``x``.

     Uses ``x * cos(|u|) + u * sin(|u|) / |u|`` with ``|u| = self.norm(x, u)``.
     Where ``|u|`` does not exceed ``utils.get_eps(x)`` the result falls
     back to ``self.projx(x + u)``.

     Args:
         x: Base point.
         u: Vector to map.

     Returns:
         A tensor selected element-wise between the two formulas above.
     """
     step = self.norm(x, u, keepdims=True)
     along_x = x * tf.math.cos(step)
     along_u = u * tf.math.sin(step) / step
     retracted = self.projx(x + u)
     use_closed_form = step > utils.get_eps(x)
     return tf.where(use_closed_form, along_x + along_u, retracted)
示例#11
0
 def norm(self, x, u, keepdims=False):
     """Return the Euclidean norm of ``u`` over its last axis.

     The result is floored at ``utils.get_eps(x)``, so it is never zero.

     Args:
         x: Only used to pick the epsilon via ``utils.get_eps``.
         u: Vector whose norm is computed.
         keepdims: Forwarded to ``tf.linalg.norm``.

     Returns:
         ``max(||u||, eps)``.
     """
     length = tf.linalg.norm(u, axis=-1, keepdims=keepdims)
     floor = utils.get_eps(x)
     return tf.maximum(length, floor)