def group_log(self, point, base_point=None, point_type=None):
    """Compute the group logarithm of `point` relative to `base_point`.

    Parameters
    ----------
    point : array-like, shape=[n_samples, {dimension,[n,n]}]
        Point(s) whose group logarithm is computed.
    base_point : array-like, shape=[n_samples, {dimension,[n,n]}], optional
        Base point(s) of the logarithm. Defaults to the group identity.
    point_type : str, {'vector', 'matrix'}, optional
        Representation of the points. Defaults to
        ``self.default_point_type``.

    Returns
    -------
    tangent_vec : array-like, shape=[n_samples, {dimension,[n,n]}]
        Group logarithm of `point` at `base_point`.

    Raises
    ------
    AssertionError
        If `point` and `base_point` have different shapes and neither of
        them contains a single sample.
    """
    if point_type is None:
        point_type = self.default_point_type

    # Regularize the identity exactly like `base_point` is regularized below,
    # so the `allclose` dispatch test compares like with like (mirroring the
    # group exponential).
    identity = self.get_identity(point_type=point_type)
    identity = self.regularize(identity, point_type=point_type)
    if base_point is None:
        base_point = identity

    # Promote inputs to batched arrays: rank 2 for vectors, rank 3 for
    # matrices.
    if point_type == "vector":
        point = gs.to_ndarray(point, to_ndim=2)
        base_point = gs.to_ndarray(base_point, to_ndim=2)
    if point_type == "matrix":
        point = gs.to_ndarray(point, to_ndim=3)
        base_point = gs.to_ndarray(base_point, to_ndim=3)

    point = self.regularize(point, point_type=point_type)
    base_point = self.regularize(base_point, point_type=point_type)

    n_points = point.shape[0]
    n_base_points = base_point.shape[0]
    assert (point.shape == base_point.shape
            or n_points == 1
            or n_base_points == 1), (
        "point and base_point must have the same shape, or one of them "
        "must contain a single sample; got shapes {} and {}.".format(
            point.shape, base_point.shape))

    # A single sample on either side is repeated to match the other batch.
    if n_points == 1:
        point = gs.array([point[0]] * n_base_points)
    if n_base_points == 1:
        base_point = gs.array([base_point[0]] * n_points)

    # Take the identity-specific path only when the base points are all
    # close to the (regularized) identity.
    result = gs.cond(
        pred=gs.allclose(base_point, identity),
        true_fn=lambda: self.group_log_from_identity(
            point, point_type=point_type),
        false_fn=lambda: self.group_log_not_from_identity(
            point, base_point, point_type),
    )
    return result
def exp(self, tangent_vec, base_point=None, point_type=None):
    """Exponentiate `tangent_vec` at `base_point` using the group exponential.

    Parameters
    ----------
    tangent_vec : array-like, shape=[n_samples, {dimension,[n,n]}]
        Tangent vector(s) to exponentiate.
    base_point : array-like, shape=[n_samples, {dimension,[n,n]}], optional
        Point(s) at which the exponential is taken. Defaults to the
        regularized group identity.
    point_type : str, {'vector', 'matrix'}, optional
        Representation of the inputs. Defaults to
        ``self.default_point_type``.

    Returns
    -------
    result : array-like, shape=[n_samples, {dimension,[n,n]}]
        The exponentiated tangent vector(s).
    """
    if point_type is None:
        point_type = self.default_point_type

    identity = self.regularize(
        self.get_identity(point_type=point_type), point_type=point_type)
    # The default base point is the already-regularized identity; it is
    # regularized once more below, just like a user-supplied base point.
    base_point = self.regularize(
        identity if base_point is None else base_point,
        point_type=point_type)

    # Batched rank per representation; unknown types are left untouched.
    for type_name, ndim in (("vector", 2), ("matrix", 3)):
        if point_type == type_name:
            tangent_vec = gs.to_ndarray(tangent_vec, to_ndim=ndim)
            base_point = gs.to_ndarray(base_point, to_ndim=ndim)

    n_vecs, n_bases = tangent_vec.shape[0], base_point.shape[0]
    assert (tangent_vec.shape == base_point.shape
            or n_vecs == 1
            or n_bases == 1)

    # Repeat a lone sample so both sides share the same batch size.
    if n_vecs == 1:
        tangent_vec = gs.array([tangent_vec[0]] * n_bases)
    if n_bases == 1:
        base_point = gs.array([base_point[0]] * n_vecs)

    def _from_identity():
        return self.exp_from_identity(tangent_vec, point_type=point_type)

    def _not_from_identity():
        return self.exp_not_from_identity(tangent_vec, base_point, point_type)

    return gs.cond(pred=gs.allclose(base_point, identity),
                   true_fn=_from_identity,
                   false_fn=_not_from_identity)
def group_log(self, point, base_point=None, point_type=None):
    """Compute the group logarithm of `point` at `base_point`.

    Parameters
    ----------
    point : array-like, shape=[n_samples, {dimension,[n,n]}]
        Point(s) whose group logarithm is computed.
    base_point : array-like, shape=[n_samples, {dimension,[n,n]}], optional
        Base point(s) of the logarithm. Defaults to the group identity.
    point_type : str, {'vector', 'matrix'}, optional
        Representation of the points. Defaults to
        ``self.default_point_type``.

    Returns
    -------
    tangent_vec : array-like, shape=[n_samples, {dimension,[n,n]}]
        Group logarithm of `point` at `base_point`.

    Raises
    ------
    AssertionError
        If `point` and `base_point` have different shapes and neither of
        them contains a single sample.
    """
    if point_type is None:
        point_type = self.default_point_type

    # Regularize the identity exactly like `base_point` is regularized below,
    # so the `allclose` dispatch test compares like with like (mirroring
    # `group_exp`).
    identity = self.get_identity(point_type=point_type)
    identity = self.regularize(identity, point_type=point_type)
    if base_point is None:
        base_point = identity

    # Promote inputs to batched arrays: rank 2 for vectors, rank 3 for
    # matrices.
    if point_type == 'vector':
        point = gs.to_ndarray(point, to_ndim=2)
        base_point = gs.to_ndarray(base_point, to_ndim=2)
    if point_type == 'matrix':
        point = gs.to_ndarray(point, to_ndim=3)
        base_point = gs.to_ndarray(base_point, to_ndim=3)

    point = self.regularize(point, point_type=point_type)
    base_point = self.regularize(base_point, point_type=point_type)

    n_points = point.shape[0]
    n_base_points = base_point.shape[0]
    assert (point.shape == base_point.shape
            or n_points == 1
            or n_base_points == 1), (
        'point and base_point must have the same shape, or one of them '
        'must contain a single sample; got shapes {} and {}.'.format(
            point.shape, base_point.shape))

    # A single sample on either side is repeated to match the other batch.
    if n_points == 1:
        point = gs.array([point[0]] * n_base_points)
    if n_base_points == 1:
        base_point = gs.array([base_point[0]] * n_points)

    # Take the identity-specific path only when the base points are all
    # close to the (regularized) identity.
    result = gs.cond(pred=gs.allclose(base_point, identity),
                     true_fn=lambda: self.group_log_from_identity(
                         point, point_type=point_type),
                     false_fn=lambda: self.group_log_not_from_identity(
                         point, base_point, point_type))
    return result
def group_exp(self, tangent_vec, base_point=None, point_type=None):
    """Compute the group exponential of `tangent_vec` at `base_point`.

    Parameters
    ----------
    tangent_vec : array-like, shape=[n_samples, {dimension,[n,n]}]
        Tangent vector(s) to exponentiate.
    base_point : array-like, shape=[n_samples, {dimension,[n,n]}], optional
        Point(s) at which the exponential is taken. Defaults to the
        regularized group identity.
    point_type : str, {'vector', 'matrix'}, optional
        Representation of the inputs. Defaults to
        ``self.default_point_type``.

    Returns
    -------
    result : array-like, shape=[n_samples, {dimension,[n,n]}]
        Group exponential of `tangent_vec` at `base_point`.
    """
    def _repeat_first(array, count):
        # Stack `count` copies of the first sample of `array`.
        return gs.array([array[0]] * count)

    if point_type is None:
        point_type = self.default_point_type

    reg_identity = self.get_identity(point_type=point_type)
    reg_identity = self.regularize(reg_identity, point_type=point_type)

    # Default to the regularized identity; it is regularized once more below,
    # just like a user-supplied base point.
    start = reg_identity if base_point is None else base_point
    start = self.regularize(start, point_type=point_type)

    vecs = tangent_vec
    for kind, rank in (('vector', 2), ('matrix', 3)):
        if point_type == kind:
            vecs = gs.to_ndarray(vecs, to_ndim=rank)
            start = gs.to_ndarray(start, to_ndim=rank)

    n_vecs = vecs.shape[0]
    n_starts = start.shape[0]
    assert (vecs.shape == start.shape
            or n_vecs == 1
            or n_starts == 1)

    # Broadcast a lone sample on either side against the other batch.
    if n_vecs == 1:
        vecs = _repeat_first(vecs, n_starts)
    if n_starts == 1:
        start = _repeat_first(start, n_vecs)

    at_identity = gs.allclose(start, reg_identity)
    return gs.cond(
        pred=at_identity,
        true_fn=lambda: self.group_exp_from_identity(
            vecs, point_type=point_type),
        false_fn=lambda: self.group_exp_not_from_identity(
            vecs, start, point_type))