def test_parallel_transport(self):
    """Check ladder schemes against the closed-form parallel transport.

    The pole ladder and Schild's ladder are both run on the
    hypersphere and compared to the exact transport given by the
    metric; each scheme has its own tolerance and minimal number
    of rungs.
    """
    n_samples = 2
    sphere = self.hypersphere
    metric = sphere.metric

    base_point = sphere.random_uniform(n_samples)
    tangent_a = sphere.to_tangent(gs.random.rand(n_samples, 3), base_point)
    tangent_b = sphere.to_tangent(gs.random.rand(n_samples, 3), base_point)

    # Exact transport of tangent_a along the geodesic directed by tangent_b.
    expected = metric.parallel_transport(tangent_a, base_point, tangent_b)
    expected_point = metric.exp(tangent_b, base_point)

    base_point = gs.cast(base_point, gs.float64)
    base_point, tangent_a, tangent_b = gs.convert_to_wider_dtype(
        [base_point, tangent_a, tangent_b])

    # (scheme, alpha, minimal n_rungs, tolerance): Schild's ladder is
    # only first-order accurate, hence more rungs and a looser tolerance.
    configurations = (("pole", 1, 1, 1e-5), ("schild", 2, 50, 1e-2))
    for scheme, alpha, min_n, tol in configurations:
        for n_rungs in (min_n, 11):
            ladder = metric.ladder_parallel_transport(
                tangent_a,
                base_point,
                tangent_b,
                n_rungs=n_rungs,
                scheme=scheme,
                alpha=alpha,
            )
            self.assertAllClose(
                ladder["transported_tangent_vec"], expected,
                rtol=tol, atol=tol)
            self.assertAllClose(ladder["end_point"], expected_point)
def test_convert_to_wider_dtype(self):
    """Check that every array is promoted to the widest dtype in the list.

    Three cases are covered: int with float32, int with float64, and
    float64 mixed with float32 (order-independent promotion).
    """
    cases = [
        ([gs.array([1, 2]),
          gs.array([2.2, 3.3], dtype=gs.float32)], gs.float32),
        ([gs.array([1, 2]),
          gs.array([2.2, 3.3], dtype=gs.float64)], gs.float64),
        ([gs.array([11.11, 222.2], dtype=gs.float64),
          gs.array([2.2, 3.3], dtype=gs.float32)], gs.float64),
    ]
    for arrays, widest in cases:
        converted = gs.convert_to_wider_dtype(arrays)
        self.assertTrue(gs.all([a.dtype == widest for a in converted]))
def log(self, point, base_point, **kwargs):
    r"""Compute the Riemannian logarithm of point w.r.t. base_point.

    Given :math:`P, P'` in Gr(n, k) the logarithm from :math:`P` to
    :math:`P'` is induced by the infinitesimal rotation [Batzies2015]_:

    .. math::

        Y = \frac 1 2 \log \big((2 P' - 1)(2 P - 1)\big)

    The tangent vector :math:`X` at :math:`P` is then recovered by
    :math:`X = [Y, P]`.

    Parameters
    ----------
    point : array-like, shape=[..., n, n]
        Point.
    base_point : array-like, shape=[..., n, n]
        Base point.

    Returns
    -------
    tangent_vec : array-like, shape=[..., n, n]
        Riemannian logarithm, a tangent vector at `base_point`.

    References
    ----------
    .. [Batzies2015] Batzies, Hüper, Machado, Leite.
        "Geometric Mean and Geodesic Regression on Grassmannians"
        Linear Algebra and its Applications, 466, 83-101, 2015.
    """
    GLn = GeneralLinear(self.n)
    id_n = GLn.identity
    # Promote all three arrays to a common dtype so the matrix ops below
    # do not fail on mixed-precision inputs.
    id_n, point, base_point = gs.convert_to_wider_dtype(
        [id_n, point, base_point])
    # Reflections 2P' - I and 2P - I associated with each projector.
    sym2 = 2 * point - id_n
    sym1 = 2 * base_point - id_n
    # Rotation mapping base_point onto point; its half-log is the
    # infinitesimal rotation Y of the formula above.
    rot = GLn.compose(sym2, sym1)
    # X = [Y, P]: bracket of the infinitesimal rotation with the base point.
    return Matrices.bracket(GLn.log(rot) / 2, base_point)
def _default_gradient_descent(points, metric, weights, max_iter, point_type,
                              epsilon, initial_step_size, verbose):
    """Perform default gradient descent to estimate the weighted Fréchet mean.

    Iteratively follows the weighted mean of the Riemannian logarithms of
    the points at the current estimate, with a step size that is halved
    whenever the gradient norm increases.

    Parameters
    ----------
    points : array-like
        Sample points; shape [n, dim] for "vector" point type, otherwise
        [n, i, j] matrices.
    metric : RiemannianMetric
        Metric providing `log`, `exp`, `squared_norm` and `dim`.
    weights : array-like, shape=[n,] or None
        Point weights; uniform weights are used when None.
    max_iter : int
        Maximum number of descent iterations.
    point_type : str
        Either "vector" or matrix-like (any other value).
    epsilon : float
        Tolerance on the squared norm of the update, scaled by `metric.dim`.
    initial_step_size : float
        Step size used at the first iteration.
    verbose : bool
        If True, log the number of iterations and final variance/distance.

    Returns
    -------
    mean : array-like
        Estimate of the (weighted) Fréchet mean of `points`.
    """
    # Choose the einsum pattern matching the point representation.
    if point_type == "vector":
        points = gs.to_ndarray(points, to_ndim=2)
        einsum_str = "n,nj->j"
    else:
        points = gs.to_ndarray(points, to_ndim=3)
        einsum_str = "n,nij->ij"
    n_points = gs.shape(points)[0]

    if weights is None:
        weights = gs.ones((n_points, ))

    # Initial guess: the first sample. A single point is its own mean.
    mean = points[0]

    if n_points == 1:
        return mean
    sum_weights = gs.sum(weights)

    sq_dists_between_iterates = []
    iteration = 0
    sq_dist = 0.0
    var = 0.0

    # NOTE(review): initial reference norm is the norm of the whole point
    # array, not of a gradient — it only seeds the step-halving comparison
    # below with a large value; confirm this is intentional.
    norm_old = gs.linalg.norm(points)
    step = initial_step_size

    while iteration < max_iter:
        # Tangent vectors from the current mean to every sample.
        logs = metric.log(point=points, base_point=mean)

        # Promote weights/logs to a common dtype before mixing them.
        weights, logs = gs.convert_to_wider_dtype([weights, logs])

        var = gs.sum(
            metric.squared_norm(logs, mean) * weights) / gs.sum(weights)

        # Weighted mean of the logs = descent direction (Karcher flow).
        tangent_mean = gs.einsum(einsum_str, weights, logs)
        tangent_mean /= sum_weights
        norm = gs.linalg.norm(tangent_mean)

        sq_dist = metric.squared_norm(tangent_mean, mean)
        sq_dists_between_iterates.append(sq_dist)

        # Stop when the variance vanishes (all points coincide with the
        # mean) or the update is small — but always run iteration 0.
        var_is_0 = gs.isclose(var, gs.array(0.0, dtype=var.dtype))
        sq_dist_is_small = gs.less_equal(sq_dist, epsilon * metric.dim)
        condition = ~gs.logical_or(var_is_0, sq_dist_is_small)
        if not (condition or iteration == 0):
            break

        # Move the estimate along the descent direction.
        estimate_next = metric.exp(step * tangent_mean, mean)
        mean = estimate_next
        iteration += 1

        # Adaptive step: keep track of the best gradient norm seen; halve
        # the step when the gradient norm increases (overshoot).
        if norm < norm_old:
            norm_old = norm
        elif norm > norm_old:
            step = step / 2.0

    if iteration == max_iter:
        logging.warning("Maximum number of iterations {} reached. "
                        "The mean may be inaccurate".format(max_iter))

    if verbose:
        logging.info("n_iter: {}, final variance: {}, final dist: {}".format(
            iteration, var, sq_dist))

    return mean