def _fit_riemannian(self, X, y, weights=None, compute_training_score=False):
        """Estimate the parameters using a Riemannian gradient descent.

        Estimate the intercept and the coefficient defining the
        geodesic regression model, using the Riemannian gradient.

        The step size is adapted on the fly. A trial step is accepted only
        if it lowers the loss of the current estimate; the learning rate is
        then doubled on every fifth accepted step (starting with the first).
        Otherwise the trial is discarded, the learning rate is halved and a
        shorter step is taken from the last accepted parameters along their
        own gradient (backtracking).

        Parameters
        ----------
        X : {array-like, sparse matrix}, shape=[...,]
            Training input samples.
        y : array-like, shape=[..., {dim, [n,n]}]
            Training target values.
        weights : array-like, shape=[...,]
            Weights associated to the points.
            Optional, default: None.
        compute_training_score : bool
            Whether to compute R^2.
            Optional, default: False.

        Returns
        -------
        self : object
            Returns self.
        """
        shape = (
            y.shape[-1:] if self.space.default_point_type == "vector" else y.shape[-2:]
        )
        if hasattr(self.metric, "parallel_transport"):

            def vector_transport(tan_a, tan_b, base_point, _):
                return self.metric.parallel_transport(tan_a, tan_b, base_point)

        else:

            # Without parallel transport, fall back to projecting onto the
            # tangent space at the new point.
            def vector_transport(tan_a, _, __, point):
                return self.space.to_tangent(tan_a, point)

        objective_with_grad = gs.autodiff.value_and_grad(
            lambda params: self._loss(X, y, params, shape, weights)
        )

        lr = self.learning_rate
        intercept_init, coef_init = self.initialize_parameters(y)
        intercept_hat = intercept_hat_new = self.space.projection(intercept_init)
        coef_hat = coef_hat_new = self.space.to_tangent(coef_init, intercept_hat)
        param = gs.vstack([gs.flatten(intercept_hat), gs.flatten(coef_hat)])

        # Loss and Euclidean gradient at the last *accepted* parameters.
        # Rejected trial losses must not overwrite these. Otherwise a later
        # trial could be accepted while still being worse than the current
        # estimate, and the reported loss would not belong to the returned
        # parameters.
        best_loss = math.inf
        best_grad = gs.zeros_like(param)
        n_evaluations = n_accepted = 0
        for _ in range(self.max_iter):
            loss, grad = objective_with_grad(param)
            n_evaluations += 1
            previous_loss = best_loss
            if gs.any(gs.isnan(grad)) or gs.any(gs.isnan(loss)):
                # A NaN loss would otherwise pass the `loss >= best_loss`
                # test (NaN comparisons are False) and be accepted.
                logging.warning(f"NaN encountered in gradient at iter {n_accepted}")
                lr /= 2
            elif n_accepted > 0 and loss >= best_loss:
                lr /= 2
            else:
                if not n_accepted % 5:
                    lr *= 2
                coef_hat = coef_hat_new
                intercept_hat = intercept_hat_new
                best_loss, best_grad = loss, grad
                n_accepted += 1
            if abs(loss - previous_loss) < self.tol:
                if self.verbose:
                    logging.info(f"Tolerance threshold reached at iter {n_accepted}")
                break

            # Always step from the last accepted parameters along *their*
            # gradient. After a rejection this retries with the halved rate.
            grad_intercept, grad_coef = gs.split(best_grad, 2)
            riem_grad_intercept = self.space.to_tangent(
                gs.reshape(grad_intercept, shape), intercept_hat
            )
            riem_grad_coef = self.space.to_tangent(
                gs.reshape(grad_coef, shape), intercept_hat
            )

            intercept_hat_new = self.metric.exp(
                -lr * riem_grad_intercept, intercept_hat
            )
            coef_hat_new = vector_transport(
                coef_hat - lr * riem_grad_coef,
                -lr * riem_grad_intercept,
                intercept_hat,
                intercept_hat_new,
            )

            param = gs.vstack([gs.flatten(intercept_hat_new), gs.flatten(coef_hat_new)])

        self.intercept_ = self.space.projection(intercept_hat)
        self.coef_ = self.space.to_tangent(coef_hat, self.intercept_)

        if self.verbose:
            logging.info(
                f"Number of gradient evaluations: {n_evaluations}, "
                f"Number of gradient iterations: {n_accepted}"
                f" loss at termination: {best_loss}"
            )
        if compute_training_score:
            variance = gs.sum(self.metric.squared_dist(y, self.intercept_))
            # best_loss is the loss of the returned (accepted) parameters.
            self.training_score_ = 1 - 2 * best_loss / variance

        return self
示例#2
0
def _check_bandwidth(bandwidth):
    """Check if the bandwidth is a positive real number."""
    if gs.any(bandwidth <= 0):
        raise ValueError("The bandwidth should be a positive real number.")
    bandwidth = gs.array(bandwidth, dtype=float)
    return bandwidth
    def rotation_vector_from_matrix(self, rot_mat):
        """
        In 3D, convert rotation matrix to rotation vector
        (axis-angle representation).

        Get the angle through the trace of the rotation matrix:
        The eigenvalues are:
        1, cos(angle) + i sin(angle), cos(angle) - i sin(angle)
        so that: trace = 1 + 2 cos(angle), -1 <= trace <= 3

        Get the rotation vector through the formula:
        S_r = angle / ( 2 * sin(angle) ) (R - R^T)

        For the edge case where the angle is close to 0, the factor
        angle / (2 sin(angle)) is replaced by its Taylor expansion.

        For the edge case where the angle is close to pi,
        the formulation is derived by going from rotation matrix to unit
        quaternion to axis-angle:
         r = angle * v / |v|, where (w, v) is a unit quaternion.
        The pivot (largest diagonal element) is chosen separately for
        each such matrix, and each axis is normalized separately.

        In nD, the rotation vector stores the n(n-1)/2 values of the
        skew-symmetric matrix representing the rotation.

        :param rot_mat: rotation matrix, or stack of rotation matrices,
            of size n x n
        :return rot_vec: rotation vector(s), passed through self.regularize
        """
        rot_mat = gs.to_ndarray(rot_mat, to_ndim=3)
        n_rot_mats, mat_dim_1, mat_dim_2 = rot_mat.shape
        assert mat_dim_1 == mat_dim_2 == self.n

        rot_mat = closest_rotation_matrix(rot_mat)

        if self.n == 3:
            trace = gs.trace(rot_mat, axis1=1, axis2=2)
            trace = gs.to_ndarray(trace, to_ndim=2, axis=1)
            assert trace.shape == (n_rot_mats, 1), trace.shape

            cos_angle = .5 * (trace - 1)
            cos_angle = gs.clip(cos_angle, -1, 1)
            angle = gs.arccos(cos_angle)

            rot_mat_transpose = gs.transpose(rot_mat, axes=(0, 2, 1))
            rot_vec = vector_from_skew_matrix(rot_mat - rot_mat_transpose)

            # angle ~ 0: angle / (2 sin(angle)) ~ 1/2 + angle^2 / 12, and
            # angle^2 ~ -(trace - 3) near the identity.
            mask_0 = gs.isclose(angle, 0)
            mask_0 = gs.squeeze(mask_0, axis=1)
            rot_vec[mask_0] = (rot_vec[mask_0] * (.5 -
                                                  (trace[mask_0] - 3.) / 12.))

            mask_pi = gs.isclose(angle, gs.pi)
            mask_pi = gs.squeeze(mask_pi, axis=1)

            if gs.any(mask_pi):
                rot_mat_pi = rot_mat[mask_pi]
                n_pi = rot_mat_pi.shape[0]
                diagonals_pi = gs.diagonal(rot_mat_pi, axis1=1, axis2=2)
                rot_vec_pi = gs.zeros((n_pi, self.dimension))
                # Handle each matrix on its own: a single flattened argmax
                # over the whole stack could index past axis 2 or pick
                # another matrix's pivot.
                for k in range(n_pi):
                    mat = rot_mat_pi[k]
                    # choose the largest diagonal element
                    # to avoid a square root of a negative number
                    a = int(gs.argmax(diagonals_pi[k]))
                    b = (a + 1) % 3
                    c = (a + 2) % 3

                    # compute the axis vector
                    sq_root = gs.sqrt(
                        mat[a, a] - mat[b, b] - mat[c, c] + 1.)
                    rot_vec_pi[k, a] = sq_root / 2.
                    rot_vec_pi[k, b] = (
                        (mat[b, a] + mat[a, b]) / (2. * sq_root))
                    rot_vec_pi[k, c] = (
                        (mat[c, a] + mat[a, c]) / (2. * sq_root))
                    # normalize this axis only, not the whole stack
                    rot_vec_pi[k] = (rot_vec_pi[k] /
                                     gs.linalg.norm(rot_vec_pi[k]))

                rot_vec[mask_pi] = angle[mask_pi] * rot_vec_pi

            mask_else = ~mask_0 & ~mask_pi
            rot_vec[mask_else] = (angle[mask_else] /
                                  (2. * gs.sin(angle[mask_else])) *
                                  rot_vec[mask_else])
        else:
            skew_mat = self.embedding_manifold.group_log_from_identity(rot_mat)
            rot_vec = vector_from_skew_matrix(skew_mat)

        return self.regularize(rot_vec)
示例#4
0
def _check_distance(distance):
    """Check if the distance if a non-negative real number."""
    if gs.any(distance < 0):
        raise ValueError("The distance should be a non-negative real number.")
    distance = gs.array(distance, dtype=float)
    return distance
示例#5
0
    def christoffels(self, base_point):
        """Compute the Christoffel symbols.

        Compute the Christoffel symbols of the Fisher information metric.
        Where ``gs.polygamma(1, kappa) - 1 / kappa`` is not larger than
        ``gs.atol``, it is too small to be used reliably as a denominator.
        There, the symbols that depend on it are replaced by an equivalent
        (close lower-bound) expression, following the second reference.

        References
        ----------
        .. [AD2008] Arwini, K. A., & Dodson, C. T. (2008).
            Information geometry (pp. 31-54). Springer Berlin Heidelberg.

        .. [GQ2015] Guo, B. N., Qi, F., Zhao, J. L., & Luo, Q. M. (2015).
            Sharp inequalities for polygamma functions.
            Mathematica Slovaca, 65(1), 103-120.

        Parameters
        ----------
        base_point : array-like, shape=[..., 2]
            Base point, whose two coordinates are read as (kappa, gamma).

        Returns
        -------
        christoffels : array-like, shape=[..., 2, 2, 2]
            Christoffel symbols, with the contravariant index on
            the first dimension:
            :math:`christoffels[..., i, j, k] = Gamma^i_{jk}`.

        Raises
        ------
        ValueError
            If any value of kappa is larger than 4e15, for which the
            computation overflows.
        """
        base_point = gs.to_ndarray(base_point, to_ndim=2)

        kappa, gamma = base_point[:, 0], base_point[:, 1]

        if gs.any(kappa > 4e15):
            raise ValueError(
                "Christoffels computation overflows with values of kappa. "
                "All values of kappa < 4e15 work.")

        shape = kappa.shape

        # Evaluate the polygamma terms once; both symbols reuse them.
        trigamma_gap = gs.polygamma(1, kappa) - 1 / kappa
        tetragamma = gs.polygamma(2, kappa)
        well_conditioned = trigamma_gap > gs.atol

        # `kappa ** -2.` stays in the working precision (no float32 cast):
        # for large kappa, polygamma(2, kappa) ~ -kappa ** -2, so the
        # numerator is a cancellation and a float32 term would dominate its
        # error. The float exponent also lets integer-valued kappa through
        # (integer ** negative integer raises in numpy).
        # NOTE(review): with polygamma(1, k) - 1/k ~ 1 / (2 k^2), the
        # large-kappa limits of c111 and c122 would be
        # (k^2 polygamma(2, k) + 1) and -k^2 / gamma^2. The fallbacks below
        # are a quarter of these. Presumably this is the bound from
        # [GQ2015] — confirm.
        c111 = gs.where(
            well_conditioned,
            (tetragamma + kappa ** -2.) / (2 * trigamma_gap),
            0.25 * (kappa**2 * tetragamma + 1),
        )

        c122 = gs.where(
            well_conditioned,
            -1 / (2 * gamma**2 * trigamma_gap),
            -(kappa**2) / (4 * gamma**2),
        )

        # Gamma^1: only the diagonal entries (1,1) and (2,2) are non-zero.
        c1 = gs.squeeze(
            from_vector_to_diagonal_matrix(gs.transpose(gs.array([c111,
                                                                  c122]))))

        # Gamma^2: off-diagonal 1 / (2 kappa), (2,2) entry -1 / gamma.
        c2 = gs.squeeze(
            gs.transpose(
                gs.array([[gs.zeros(shape), 1 / (2 * kappa)],
                          [1 / (2 * kappa), -1 / gamma]])))

        christoffels = gs.array([c1, c2])

        # For a batch of base points, move the batch axis in front of the
        # contravariant index.
        if len(christoffels.shape) == 4:
            christoffels = gs.transpose(christoffels, [1, 0, 2, 3])

        return gs.squeeze(christoffels)