Exemplo n.º 1
0
 def test_parallel_transport(self):
     """Compare ladder parallel transport with the metric's own transport.

     Two random base points and two random tangent vectors per point are
     drawn on ``self.hypersphere``. The reference values come from
     ``metric.parallel_transport`` and ``metric.exp``; each ladder scheme
     is then run with two rung counts and checked against them.
     """
     n_samples = 2
     sphere = self.hypersphere
     base_point = sphere.random_uniform(n_samples)
     tan_vec_a = sphere.to_tangent(
         gs.random.rand(n_samples, 3), base_point)
     tan_vec_b = sphere.to_tangent(
         gs.random.rand(n_samples, 3), base_point)

     metric = sphere.metric
     expected = metric.parallel_transport(tan_vec_a, base_point, tan_vec_b)
     expected_point = metric.exp(tan_vec_b, base_point)

     # Bring every input to a common dtype before running the ladders.
     base_point = gs.cast(base_point, gs.float64)
     base_point, tan_vec_a, tan_vec_b = gs.convert_to_wider_dtype(
         [base_point, tan_vec_a, tan_vec_b])

     # (scheme, alpha, smallest rung count, tolerance on the vector)
     settings = (
         ("pole", 1, 1, 1e-5),
         ("schild", 2, 50, 1e-2),
     )
     for scheme, alpha, min_rungs, tol in settings:
         for n_rungs in (min_rungs, 11):
             ladder = metric.ladder_parallel_transport(
                 tan_vec_a,
                 base_point,
                 tan_vec_b,
                 n_rungs=n_rungs,
                 scheme=scheme,
                 alpha=alpha,
             )
             self.assertAllClose(
                 ladder["transported_tangent_vec"], expected,
                 rtol=tol, atol=tol)
             self.assertAllClose(ladder["end_point"], expected_point)
Exemplo n.º 2
0
    def test_convert_to_wider_dtype(self):
        """Check the dtype produced by ``gs.convert_to_wider_dtype``.

        Each case pairs a list of input arrays with the dtype that every
        converted array is expected to carry afterwards.
        """
        cases = [
            # default-dtype integers next to float32
            ([gs.array([1, 2]),
              gs.array([2.2, 3.3], dtype=gs.float32)], gs.float32),
            # default-dtype integers next to float64
            ([gs.array([1, 2]),
              gs.array([2.2, 3.3], dtype=gs.float64)], gs.float64),
            # float64 next to float32
            ([gs.array([11.11, 222.2], dtype=gs.float64),
              gs.array([2.2, 3.3], dtype=gs.float32)], gs.float64),
        ]

        for arrays, wider_dtype in cases:
            converted = gs.convert_to_wider_dtype(arrays)
            has_wider_dtype = [
                array.dtype == wider_dtype for array in converted]
            self.assertTrue(gs.all(has_wider_dtype))
Exemplo n.º 3
0
    def log(self, point, base_point, **kwargs):
        r"""Compute the Riemannian logarithm of point w.r.t. base_point.

        Given :math:`P, P'` in Gr(n, k), the logarithm from :math:`P`
        to :math:`P'` is induced by the infinitesimal rotation
        [Batzies2015]_:

        .. math::

            Y = \frac 1 2 \log \big((2 P' - 1)(2 P - 1)\big)

        The tangent vector :math:`X` at :math:`P`
        is then recovered by :math:`X = [Y, P]`.

        Here :math:`P` is `base_point` and :math:`P'` is `point`.

        Parameters
        ----------
        point : array-like, shape=[..., n, n]
            Point.
        base_point : array-like, shape=[..., n, n]
            Base point.
        **kwargs
            Accepted but not used by this implementation.

        Returns
        -------
        tangent_vec : array-like, shape=[..., n, n]
            Riemannian logarithm, a tangent vector at `base_point`.

        References
        ----------
        .. [Batzies2015] Batzies, Hüper, Machado, Leite.
            "Geometric Mean and Geodesic Regression on Grassmannians"
            Linear Algebra and its Applications, 466, 83-101, 2015.
        """
        general_linear = GeneralLinear(self.n)

        # Put the identity and both points on a common dtype first.
        identity, point, base_point = gs.convert_to_wider_dtype(
            [general_linear.identity, point, base_point])

        sym_point = 2 * point - identity
        sym_base = 2 * base_point - identity
        rotation = general_linear.compose(sym_point, sym_base)

        # Y in the docstring formula; the result is the bracket [Y, P].
        half_log = general_linear.log(rotation) / 2
        return Matrices.bracket(half_log, base_point)
Exemplo n.º 4
0
def _default_gradient_descent(points, metric, weights, max_iter, point_type,
                              epsilon, initial_step_size, verbose):
    """Estimate a weighted mean of points by gradient descent.

    Starting from the first point, the estimate is repeatedly moved with
    ``metric.exp`` along the weighted average of ``metric.log`` of all
    points, scaled by a step size that is halved whenever the norm of
    that average grows.

    Parameters
    ----------
    points : array-like
        Points to average; promoted to 2 dimensions when `point_type` is
        "vector" and to 3 dimensions otherwise.
    metric : object
        Provides `log`, `exp`, `squared_norm` and `dim`.
    weights : array-like or None
        One weight per point; ones are used when None.
    max_iter : int
        Maximum number of iterations.
    point_type : str
        "vector" selects the vector einsum; any other value selects the
        matrix one.
    epsilon : float
        Tolerance; iteration stops once the squared norm of the update
        is at most ``epsilon * metric.dim`` (or the variance is close
        to zero), but never before one update has been applied.
    initial_step_size : float
        Step size used for the first update.
    verbose : bool
        When true, log the iteration count, final variance and final
        squared distance.

    Returns
    -------
    mean : array-like
        Current estimate of the mean; the first point itself when only
        one point is given.
    """
    is_vector = point_type == "vector"
    points = gs.to_ndarray(points, to_ndim=2 if is_vector else 3)
    einsum_str = "n,nj->j" if is_vector else "n,nij->ij"
    n_points = gs.shape(points)[0]

    if weights is None:
        weights = gs.ones((n_points, ))

    mean = points[0]
    if n_points == 1:
        return mean

    sum_weights = gs.sum(weights)
    sq_dist_history = []
    sq_dist = 0.0
    var = 0.0
    norm_old = gs.linalg.norm(points)
    step = initial_step_size

    iteration = 0
    while iteration < max_iter:
        tangent_vecs = metric.log(point=points, base_point=mean)
        weights, tangent_vecs = gs.convert_to_wider_dtype(
            [weights, tangent_vecs])

        weighted_sq_norms = metric.squared_norm(tangent_vecs, mean) * weights
        var = gs.sum(weighted_sq_norms) / gs.sum(weights)

        tangent_mean = gs.einsum(einsum_str, weights, tangent_vecs)
        tangent_mean /= sum_weights
        norm = gs.linalg.norm(tangent_mean)

        sq_dist = metric.squared_norm(tangent_mean, mean)
        sq_dist_history.append(sq_dist)

        var_is_0 = gs.isclose(var, gs.array(0.0, dtype=var.dtype))
        sq_dist_is_small = gs.less_equal(sq_dist, epsilon * metric.dim)
        keep_going = ~gs.logical_or(var_is_0, sq_dist_is_small)
        # Always apply at least one update, even if already converged.
        if not keep_going and iteration != 0:
            break

        mean = metric.exp(step * tangent_mean, mean)
        iteration += 1

        # Track the smallest norm seen; shrink the step when it grows.
        if norm < norm_old:
            norm_old = norm
        elif norm > norm_old:
            step = step / 2.0

    if iteration == max_iter:
        warning_template = ("Maximum number of iterations {} reached. "
                            "The mean may be inaccurate")
        logging.warning(warning_template.format(max_iter))

    if verbose:
        info_template = "n_iter: {}, final variance: {}, final dist: {}"
        logging.info(info_template.format(iteration, var, sq_dist))

    return mean