def _fit_riemannian(self, X, y, weights=None, compute_training_score=False):
    """Estimate the parameters using a Riemannian gradient descent.

    Estimate the intercept and the coefficient defining the
    geodesic regression model, using the Riemannian gradient.

    The learning rate is adapted heuristically:
    - It is halved when the gradient contains NaNs.
    - It is halved when a candidate step does not decrease the loss. The
      step is then rejected and retried from the current iterate.
    - It is doubled on accepted steps whose index is a multiple of 5.

    Parameters
    ----------
    X : {array-like, sparse matrix}, shape=[...,]
        Training input samples.
    y : array-like, shape=[..., {dim, [n,n]}]
        Training target values.
    weights : array-like, shape=[...,]
        Weights associated to the points.
        Optional, default: None.
    compute_training_score : bool
        Whether to compute R^2. The score is computed from the loss
        evaluated at the returned parameters.
        Optional, default: False.

    Returns
    -------
    self : object
        Returns self.
    """
    shape = (
        y.shape[-1:] if self.space.default_point_type == "vector" else y.shape[-2:]
    )

    if hasattr(self.metric, "parallel_transport"):

        def vector_transport(tan_a, tan_b, base_point, _):
            return self.metric.parallel_transport(tan_a, tan_b, base_point)

    else:

        def vector_transport(tan_a, _, __, point):
            # Fallback when no parallel transport is available:
            # project onto the tangent space at the new point.
            return self.space.to_tangent(tan_a, point)

    objective_with_grad = gs.autodiff.value_and_grad(
        lambda params: self._loss(X, y, params, shape, weights)
    )

    lr = self.learning_rate
    intercept_init, coef_init = self.initialize_parameters(y)
    intercept_hat = intercept_hat_new = self.space.projection(intercept_init)
    coef_hat = coef_hat_new = self.space.to_tangent(coef_init, intercept_hat)
    param = gs.vstack([gs.flatten(intercept_hat), gs.flatten(coef_hat)])

    current_loss = [math.inf]
    # Loss of the last *accepted* iterate. current_loss[-1] may instead hold
    # the loss of a rejected candidate step.
    accepted_loss = math.inf
    current_grad = gs.zeros_like(param)
    current_iter = n_evals = 0
    for i in range(self.max_iter):
        loss, grad = objective_with_grad(param)
        n_evals += 1
        if gs.any(gs.isnan(grad)):
            logging.warning(f"NaN encountered in gradient at iter {current_iter}")
            lr /= 2
            grad = current_grad
        elif loss >= current_loss[-1] and i > 0:
            # Candidate did not decrease the loss: reject it and retry from
            # the current iterate with a smaller step.
            lr /= 2
        else:
            if not current_iter % 5:
                lr *= 2
            coef_hat = coef_hat_new
            intercept_hat = intercept_hat_new
            accepted_loss = loss
            current_iter += 1

        if abs(loss - current_loss[-1]) < self.tol:
            if self.verbose:
                logging.info(f"Tolerance threshold reached at iter {current_iter}")
            break

        grad_intercept, grad_coef = gs.split(grad, 2)
        riem_grad_intercept = self.space.to_tangent(
            gs.reshape(grad_intercept, shape), intercept_hat
        )
        riem_grad_coef = self.space.to_tangent(
            gs.reshape(grad_coef, shape), intercept_hat
        )

        intercept_hat_new = self.metric.exp(
            -lr * riem_grad_intercept, intercept_hat
        )
        coef_hat_new = vector_transport(
            coef_hat - lr * riem_grad_coef,
            -lr * riem_grad_intercept,
            intercept_hat,
            intercept_hat_new,
        )

        param = gs.vstack([gs.flatten(intercept_hat_new), gs.flatten(coef_hat_new)])

        current_loss.append(loss)
        current_grad = grad

    self.intercept_ = self.space.projection(intercept_hat)
    self.coef_ = self.space.to_tangent(coef_hat, self.intercept_)

    if self.verbose:
        logging.info(
            f"Number of gradient evaluations: {n_evals}, "
            f"Number of gradient iterations: {current_iter}, "
            f"loss at termination: {accepted_loss}"
        )

    if compute_training_score:
        # Evaluate the loss at the parameters actually returned. The last
        # recorded loss may belong to a rejected step, or be inf when no
        # iteration ran.
        final_param = gs.vstack(
            [gs.flatten(self.intercept_), gs.flatten(self.coef_)]
        )
        final_loss = self._loss(X, y, final_param, shape, weights)
        variance = gs.sum(self.metric.squared_dist(y, self.intercept_))
        self.training_score_ = 1 - 2 * final_loss / variance

    return self
def _check_bandwidth(bandwidth):
    """Check that the bandwidth is a positive real number.

    Parameters
    ----------
    bandwidth : float or array-like
        Bandwidth value(s) to validate.

    Returns
    -------
    bandwidth : array-like
        The bandwidth converted to a float array.

    Raises
    ------
    ValueError
        If any bandwidth value is not strictly positive (NaN included).
    """
    # Convert first so that plain Python sequences are compared
    # element-wise (``[1.0, 2.0] <= 0`` would raise a TypeError).
    bandwidth = gs.array(bandwidth, dtype=float)
    # Test ``> 0`` rather than ``<= 0`` so that NaN, for which every
    # comparison is False, is rejected as well.
    if not gs.all(bandwidth > 0):
        raise ValueError("The bandwidth should be a positive real number.")
    return bandwidth
def rotation_vector_from_matrix(self, rot_mat):
    """
    In 3D, convert rotation matrix to rotation vector
    (axis-angle representation).

    Get the angle through the trace of the rotation matrix:
    The eigenvalues are:
    1, cos(angle) + i sin(angle), cos(angle) - i sin(angle)
    so that: trace = 1 + 2 cos(angle), -1 <= trace <= 3

    Get the rotation vector through the formula:
    S_r = angle / ( 2 * sin(angle) ) (R - R^T)

    For the edge case where the angle is close to 0, the factor
    angle / (2 sin(angle)) is replaced by its Taylor expansion
    1/2 + angle^2 / 12, written in terms of the trace.

    For the edge case where the angle is close to pi,
    the formulation is derived by going from rotation matrix to unit
    quaternion to axis-angle:
    r = angle * v / |v|, where (w, v) is a unit quaternion.
    This is done separately for each such matrix, since the pivot
    (largest diagonal element) differs from one matrix to the next.

    In nD, the rotation vector stores the n(n-1)/2 values of the
    skew-symmetric matrix representing the rotation.

    :param rot_mat: rotation matrix, or stack of rotation matrices,
        each of size n x n
    :return rot_vec: rotation vector(s), passed through self.regularize
    """
    rot_mat = gs.to_ndarray(rot_mat, to_ndim=3)
    n_rot_mats, mat_dim_1, mat_dim_2 = rot_mat.shape
    assert mat_dim_1 == mat_dim_2 == self.n

    rot_mat = closest_rotation_matrix(rot_mat)

    if self.n == 3:
        trace = gs.trace(rot_mat, axis1=1, axis2=2)
        trace = gs.to_ndarray(trace, to_ndim=2, axis=1)
        assert trace.shape == (n_rot_mats, 1), trace.shape

        cos_angle = .5 * (trace - 1)
        cos_angle = gs.clip(cos_angle, -1, 1)
        angle = gs.arccos(cos_angle)

        rot_mat_transpose = gs.transpose(rot_mat, axes=(0, 2, 1))
        rot_vec = vector_from_skew_matrix(rot_mat - rot_mat_transpose)

        mask_0 = gs.isclose(angle, 0)
        mask_0 = gs.squeeze(mask_0, axis=1)
        # Taylor expansion of angle / (2 sin(angle)) around 0,
        # using angle^2 ~ 3 - trace.
        rot_vec[mask_0] = (rot_vec[mask_0]
                           * (.5 - (trace[mask_0] - 3.) / 12.))

        mask_pi = gs.isclose(angle, gs.pi)
        mask_pi = gs.squeeze(mask_pi, axis=1)

        if gs.any(mask_pi):
            # Near pi, R - R^T vanishes, so recover the axis from the
            # symmetric part of each matrix individually.
            rot_mats_pi = rot_mat[mask_pi]
            angles_pi = angle[mask_pi]
            diagonals_pi = gs.diagonal(rot_mats_pi, axis1=1, axis2=2)
            rot_vec_pi = gs.zeros((rot_mats_pi.shape[0], self.dimension))
            for k, (mat, diag, ang) in enumerate(
                    zip(rot_mats_pi, diagonals_pi, angles_pi)):
                # choose the largest diagonal element of this matrix
                # to avoid a square root of a negative number
                a = int(gs.argmax(diag))
                b = (a + 1) % 3
                c = (a + 2) % 3

                # compute the axis vector
                sq_root = gs.sqrt(mat[a, a] - mat[b, b] - mat[c, c] + 1.)
                axis = gs.zeros(self.dimension)
                axis[a] = sq_root / 2.
                axis[b] = (mat[b, a] + mat[a, b]) / (2. * sq_root)
                axis[c] = (mat[c, a] + mat[a, c]) / (2. * sq_root)

                # normalize per matrix, not over the whole batch
                rot_vec_pi[k] = ang * axis / gs.linalg.norm(axis)
            rot_vec[mask_pi] = rot_vec_pi

        mask_else = ~mask_0 & ~mask_pi
        rot_vec[mask_else] = (angle[mask_else]
                              / (2. * gs.sin(angle[mask_else]))
                              * rot_vec[mask_else])
    else:
        skew_mat = self.embedding_manifold.group_log_from_identity(rot_mat)
        rot_vec = vector_from_skew_matrix(skew_mat)

    return self.regularize(rot_vec)
def _check_distance(distance):
    """Check that the distance is a non-negative real number.

    Parameters
    ----------
    distance : float or array-like
        Distance value(s) to validate.

    Returns
    -------
    distance : array-like
        The distance converted to a float array.

    Raises
    ------
    ValueError
        If any distance value is negative or NaN.
    """
    # Convert first so that plain Python sequences are compared
    # element-wise (``[1.0, 2.0] < 0`` would raise a TypeError).
    distance = gs.array(distance, dtype=float)
    # Test ``>= 0`` rather than ``< 0`` so that NaN, for which every
    # comparison is False, is rejected as well.
    if not gs.all(distance >= 0):
        raise ValueError("The distance should be a non-negative real number.")
    return distance
def christoffels(self, base_point):
    r"""Compute the Christoffel symbols.

    Compute the Christoffel symbols of the Fisher information metric.

    For computation purposes, when ``gs.polygamma(1, kappa) - 1 / kappa``
    becomes too small to compute reliably (below ``gs.atol``), it is
    replaced by its close lower-bound equivalent ``1 / (2 * kappa**2)``,
    as per the second reference. Substituting it into the exact formulas
    gives:

    - Gamma^1_{11} ~ kappa**2 * polygamma(2, kappa) + 1
    - Gamma^1_{22} ~ -kappa**2 / gamma**2

    References
    ----------
    .. [AD2008] Arwini, K. A., & Dodson, C. T. (2008).
        Information geometry (pp. 31-54). Springer Berlin Heidelberg.
    .. [GQ2015] Guo, B. N., Qi, F., Zhao, J. L., & Luo, Q. M. (2015).
        Sharp inequalities for polygamma functions.
        Mathematica Slovaca, 65(1), 103-120.

    Parameters
    ----------
    base_point : array-like, shape=[..., 2]
        Base point.

    Returns
    -------
    christoffels : array-like, shape=[..., 2, 2, 2]
        Christoffel symbols, with the contravariant index on
        the first dimension:
        :math:`\text{christoffels}[..., i, j, k] = \Gamma^i_{jk}`.

    Raises
    ------
    ValueError
        If any value of kappa exceeds 4e15.
    """
    base_point = gs.to_ndarray(base_point, to_ndim=2)
    kappa, gamma = base_point[:, 0], base_point[:, 1]

    if gs.any(kappa > 4e15):
        raise ValueError(
            "Christoffels computation overflows with values of kappa. "
            "All values of kappa < 4e15 work.")

    shape = kappa.shape

    # polygamma(1, kappa) - 1 / kappa tends to 0 like 1 / (2 kappa^2);
    # computed once and reused by both branches below.
    trigamma_gap = gs.polygamma(1, kappa) - 1 / kappa
    tetragamma = gs.polygamma(2, kappa)
    is_computable = trigamma_gap > gs.atol
    # Float exponents promote integer inputs instead of overflowing
    # (kappa^2) or raising (kappa^-2). They also avoid the precision
    # loss of a float32 cast in a term that nearly cancels tetragamma.
    kappa_sq = kappa ** 2.
    inv_kappa_sq = kappa ** -2.

    c111 = gs.where(
        is_computable,
        (tetragamma + inv_kappa_sq) / (2 * trigamma_gap),
        kappa_sq * tetragamma + 1,
    )
    c122 = gs.where(
        is_computable,
        -1 / (2 * gamma**2 * trigamma_gap),
        -kappa_sq / gamma**2,
    )

    c1 = gs.squeeze(
        from_vector_to_diagonal_matrix(gs.transpose(gs.array([c111, c122]))))
    c2 = gs.squeeze(
        gs.transpose(
            gs.array([[gs.zeros(shape), 1 / (2 * kappa)],
                      [1 / (2 * kappa), -1 / gamma]])))

    christoffels = gs.array([c1, c2])

    # With several base points, move the batch axis to the front.
    if len(christoffels.shape) == 4:
        christoffels = gs.transpose(christoffels, [1, 0, 2, 3])

    return gs.squeeze(christoffels)