def dist(self, x, y, keepdims=False):
    """Geodesic distance between points ``x`` and ``y`` of the ball.

    Evaluates ``2 * atanh(sqrt(k) * ||w||) / sqrt(k)``, where ``w`` is
    ``self._mobius_add(-x, y)`` and the norm is taken over the last axis.
    The atanh argument is clipped to ``[-1 + eps, 1 - eps]`` with
    ``eps = utils.get_eps(x)`` so the result stays finite near the boundary.

    Args:
        x, y: points on the manifold.
        keepdims: keep the reduced last axis (size 1) in the result.

    Returns:
        The distance tensor.
    """
    root_k = tf.math.sqrt(tf.cast(self.k, x.dtype))
    offset = self._mobius_add(-x, y)
    offset_len = tf.linalg.norm(offset, axis=-1, keepdims=keepdims)
    tol = utils.get_eps(x)
    # Keep the atanh argument strictly inside (-1, 1).
    clipped = tf.clip_by_value(root_k * offset_len, -1.0 + tol, 1.0 - tol)
    return 2 * tf.math.atanh(clipped) / root_k
def log(self, x, y):
    """Logarithmic map of ``y`` at base point ``x`` on the ball.

    Returns the tangent vector at ``x``

        ``2 * atanh(sqrt(k) * ||w||) / (sqrt(k) * lambda_x) * w / ||w||``

    where ``w = self._mobius_add(-x, y)``, ``||.||`` is taken over the last
    axis, and ``lambda_x = self._lambda(x, keepdims=True)``. The atanh
    argument is clipped to ``[-1 + eps, 1 - eps]`` with
    ``eps = utils.get_eps(x)``. When ``w`` is exactly zero (e.g. ``y == x``),
    the result is the zero vector.

    Args:
        x: base point.
        y: target point.

    Returns:
        Tangent vector at ``x`` with the same shape as ``w``.
    """
    sqrt_k = tf.math.sqrt(tf.cast(self.k, x.dtype))
    x_y = self._mobius_add(-x, y)
    norm_x_y = tf.linalg.norm(x_y, axis=-1, keepdims=True)
    eps = utils.get_eps(x)
    tanh = tf.clip_by_value(sqrt_k * norm_x_y, -1.0 + eps, 1.0 - eps)
    # atanh(z) / z -> 1 as z -> 0, but it is 0/0 at z == 0 exactly. Plain
    # division (x_y / norm_x_y) produced NaN for log(x, x). divide_no_nan
    # returns 0 there, which gives the correct zero tangent vector.
    scale = tf.math.divide_no_nan(tf.math.atanh(tanh), sqrt_k * norm_x_y)
    lambda_x = self._lambda(x, keepdims=True)
    return 2 * x_y * scale / lambda_x
def projx(self, x):
    """Project ``x`` into the open ball ``sqrt(k) * ||x|| < 1``.

    Points already inside the ball are returned unchanged. Any other point
    is divided by ``sqrt(k) * ||x|| + 10 * utils.get_eps(x)``, which places
    it just inside the boundary. The norm is taken over the last axis.
    """
    sqrt_k = tf.math.sqrt(tf.cast(self.k, x.dtype))
    scaled_norm = sqrt_k * tf.linalg.norm(x, axis=-1, keepdims=True)
    inside = scaled_norm < tf.ones_like(scaled_norm)
    # The extra 10 * eps keeps the rescaled point strictly off the boundary.
    rescaled = x / (scaled_norm + 10 * utils.get_eps(x))
    return tf.where(inside, x, rescaled)
def _mobius_scal_mul(self, x, r):
    r"""Compute the Möbius scalar multiplication of :math:`x \in \mathcal{D}^{n}_{k}` by :math:`r`.

    :math:`x \otimes r = \frac{1}{\sqrt{k}} \tanh\left(r \tanh^{-1}(\sqrt{k}\|x\|)\right) \frac{x}{\|x\|}`

    The norm is taken over the last axis. The inverse-tanh argument is
    clipped to ``[-1 + eps, 1 - eps]`` with ``eps = utils.get_eps(x)``.
    For :math:`x = 0` the result is the zero vector, which is the limit of
    the formula, rather than NaN.

    Args:
        x: point(s) of the ball.
        r: scalar multiplier (broadcast against ``x``).

    Returns:
        :math:`x \otimes r`, same shape as ``x``.
    """
    sqrt_k = tf.math.sqrt(tf.cast(self.k, x.dtype))
    norm_x = tf.linalg.norm(x, axis=-1, keepdims=True)
    eps = utils.get_eps(x)
    tan = tf.clip_by_value(sqrt_k * norm_x, -1.0 + eps, 1.0 - eps)
    # Unit direction of x. divide_no_nan maps the 0/0 at x == 0 to 0
    # instead of NaN.
    direction = tf.math.divide_no_nan(x, norm_x)
    return (1 / sqrt_k) * tf.math.tanh(r * tf.math.atanh(tan)) * direction
def _diff_power(self, x, d, power):
    """Directional derivative of the matrix function ``power`` of ``x`` along ``d``.

    ``x`` is eigendecomposed with ``tf.linalg.eigh`` (so it is treated as
    self-adjoint): ``x = V diag(e) V^T``. The derivative of ``f(x)`` in
    direction ``d`` is computed as ``V ((V^T d V) * F) V^T``, where ``*`` is
    elementwise and ``F[i, j] = (f(e_i) - f(e_j)) / (e_i - e_j)``. When
    ``|e_i - e_j| < utils.get_eps(x)`` (including ``i == j``), ``F[i, j]``
    is replaced by the derivative ``f'(e_i)``.

    Args:
        x: matrix (or batch of matrices) at which to differentiate.
        d: direction matrix; presumably the same shape as ``x``.
        power: ``"log"`` (``f = log``) or ``"exp"`` (``f = exp``).

    Returns:
        The directional derivative, with the shape of ``V @ ... @ V^T``.

    Raises:
        ValueError: if ``power`` is neither ``"log"`` nor ``"exp"``.
    """
    # Previously any other value left pow_e unbound (UnboundLocalError).
    if power not in ("log", "exp"):
        raise ValueError(
            f"Unsupported power {power!r}; expected 'log' or 'exp'"
        )
    e, v = tf.linalg.eigh(x)
    v_t = utils.transposem(v)
    e = tf.expand_dims(e, -2)
    pow_e = tf.math.log(e) if power == "log" else tf.math.exp(e)
    # Outer product with a ones vector: s[..., i, j] = e_j, and its
    # transpose holds e_i (assuming transposem swaps the last two axes).
    s = utils.transposem(tf.ones_like(e)) @ e
    pow_s = utils.transposem(tf.ones_like(pow_e)) @ pow_e
    denom = utils.transposem(s) - s
    numer = utils.transposem(pow_s) - pow_s
    # Where eigenvalues (nearly) coincide, the divided difference is ~0/0.
    # Substitute numer/denom pairs whose ratio is f'(e_i).
    near_equal = tf.math.abs(denom) < utils.get_eps(x)
    if power == "log":
        # d/de log(e) = 1 / e
        numer = tf.where(near_equal, tf.ones_like(numer), numer)
        denom = tf.where(near_equal, utils.transposem(s), denom)
    else:
        # d/de exp(e) = exp(e)
        numer = tf.where(near_equal, utils.transposem(pow_s), numer)
        denom = tf.where(near_equal, tf.ones_like(denom), denom)
    t = (v_t @ d @ v) * numer / denom
    return v @ t @ v_t
def from_poincare(self, x, k):
    """Inverse of the diffeomorphism to the Poincaré ball.

    Maps a point ``x`` of the ball to
    ``sqrt(k) * [1 + ||x||^2, 2 x] / (1 - ||x||^2 + eps)``,
    where ``||x||^2`` is summed over the last axis, the two parts are
    concatenated along that axis, and ``eps = utils.get_eps(x)`` guards the
    division near the boundary.

    Args:
        x: point(s) in the Poincaré ball.
        k: curvature parameter (cast to ``x.dtype``).

    Returns:
        Tensor whose last axis is one longer than that of ``x``.
    """
    scale = tf.math.sqrt(tf.cast(k, x.dtype))
    sq_norm = tf.reduce_sum(x * x, axis=-1, keepdims=True)
    lifted = scale * tf.concat([1 + sq_norm, 2 * x], axis=-1)
    denominator = 1.0 - sq_norm + utils.get_eps(x)
    return lifted / denominator
def norm(self, x, u, keepdims=False):
    """Length of tangent vector ``u`` at ``x`` under ``self.inner``.

    The squared length ``self.inner(x, u, u)`` is floored at
    ``utils.get_eps(x)`` before taking the square root, so the sqrt
    argument is never smaller than eps.
    """
    squared = self.inner(x, u, u, keepdims=keepdims)
    floor = utils.get_eps(x)
    return tf.math.sqrt(tf.maximum(squared, floor))
def _check_vector_on_tangent(self, x, u, atol, rtol):
    """Return whether ``u`` lies in the tangent space at ``x``.

    The test is ``utils.allclose(self.inner(x, x, u), 0, atol, rtol)``.
    When ``rtol`` is None it defaults to ``100 * utils.get_eps(x)``; ``atol``
    is forwarded unchanged.
    """
    pairing = self.inner(x, x, u)
    if rtol is None:
        rtol = 100 * utils.get_eps(x)
    target = tf.zeros_like(pairing)
    return utils.allclose(pairing, target, atol, rtol)
def log(self, x, y):
    """Logarithmic map of ``y`` at base point ``x``.

    Projects ``y - x`` onto the tangent space at ``x`` with ``self.proju``.
    The projection is then scaled by ``self.dist(x, y) / self.norm(x, ...)``.
    If the distance does not exceed ``utils.get_eps(x)``, the unscaled
    projection is returned instead.
    """
    tangent = self.proju(x, y - x)
    length = self.norm(x, tangent, keepdims=True)
    geo_dist = self.dist(x, y, keepdims=True)
    far_enough = geo_dist > utils.get_eps(x)
    return tf.where(far_enough, tangent * geo_dist / length, tangent)
def exp(self, x, u):
    """Exponential map of tangent vector ``u`` at base point ``x``.

    Returns ``x * cos(|u|) + u * sin(|u|) / |u|`` with
    ``|u| = self.norm(x, u)``. When ``|u|`` does not exceed
    ``utils.get_eps(x)``, it falls back to ``self.projx(x + u)``.
    """
    u_len = self.norm(x, u, keepdims=True)
    geodesic_pt = x * tf.math.cos(u_len) + u * tf.math.sin(u_len) / u_len
    fallback_pt = self.projx(x + u)
    use_geodesic = u_len > utils.get_eps(x)
    return tf.where(use_geodesic, geodesic_pt, fallback_pt)
def norm(self, x, u, keepdims=False):
    """Euclidean norm of ``u`` over its last axis, floored at ``utils.get_eps(x)``.

    ``x`` is only used to pick the eps value.
    """
    length = tf.linalg.norm(u, axis=-1, keepdims=keepdims)
    floor = utils.get_eps(x)
    return tf.maximum(length, floor)