def dist(self, point_a, point_b):
    """Return the geodesic distance between two points.

    The distance is the angle between the two vectors. Its cosine is
    ``<a, b> / (|a| |b|)``, computed with the embedding metric, and is
    clipped to [-1, 1] so that rounding cannot push ``arccos`` out of its
    domain.

    Parameters
    ----------
    point_a : array-like, shape=[n_samples, dimension + 1] or
        shape=[1, dimension + 1]
    point_b : array-like, shape=[n_samples, dimension + 1] or
        shape=[1, dimension + 1]

    Returns
    -------
    dist : array-like, shape=[n_samples, 1] or shape=[1, 1]
    """
    embedding = self.embedding_metric
    norms_product = embedding.norm(point_a) * embedding.norm(point_b)
    cosine = embedding.inner_product(point_a, point_b) / norms_product
    return gs.arccos(gs.clip(cosine, -1, 1))
def dist(self, point_a, point_b):
    """Compute the geodesic distance between two points.

    The cosine of the angle between the points is obtained from the
    embedding metric (inner product divided by the product of norms),
    clipped to [-1, 1], and converted to an angle with ``arccos``.

    Parameters
    ----------
    point_a : array-like, shape=[..., dim + 1]
        First point on the hypersphere.
    point_b : array-like, shape=[..., dim + 1]
        Second point on the hypersphere.

    Returns
    -------
    dist : array-like, shape=[..., 1]
        Geodesic distance between the two points.
    """
    metric = self.embedding_metric
    norm_a, norm_b = metric.norm(point_a), metric.norm(point_b)
    cos_angle = gs.clip(
        metric.inner_product(point_a, point_b) / (norm_a * norm_b), -1, 1)
    return gs.arccos(cos_angle)
def log(self, point, base_point):
    """Riemannian logarithm of a point wrt a base point.

    Computes ``coef_1 * point - coef_2 * base_point`` where, for the angle
    between the two points, ``coef_1 = angle / sin(angle)`` and
    ``coef_2 = angle / tan(angle)``. Where the angle is close to 0, both
    ratios are replaced by polynomials in the angle built from
    INV_SIN_TAYLOR_COEFFS and INV_TAN_TAYLOR_COEFFS. Rows where ``point``
    and ``base_point`` are elementwise close are set to zero at the end.

    Parameters
    ----------
    point : array-like
        Point(s); promoted to 2-D (one row per sample).
    base_point : array-like
        Base point(s); promoted to 2-D (one row per sample).

    Returns
    -------
    log : array-like, 2-D
        Tangent vector(s) at base_point.
    """
    point = gs.to_ndarray(point, to_ndim=2)
    base_point = gs.to_ndarray(base_point, to_ndim=2)

    # Cosine of the angle between the points, normalized by both norms.
    # NOTE(review): the log formula below uses the unnormalized inputs;
    # presumably the points are expected to be unit-norm.
    norm_base_point = self.embedding_metric.norm(base_point)
    norm_point = self.embedding_metric.norm(point)
    inner_prod = self.embedding_metric.inner_product(base_point, point)
    cos_angle = inner_prod / (norm_base_point * norm_point)
    # Clip so rounding errors cannot push arccos out of its domain.
    cos_angle = gs.clip(cos_angle, -1., 1.)

    angle = gs.arccos(cos_angle)
    # Force angle into a column of shape (n, 1) for the 'ni,nj->nj' einsum
    # below (assumes inner_product yields one value per sample).
    angle = gs.to_ndarray(angle, to_ndim=1)
    angle = gs.to_ndarray(angle, to_ndim=2, axis=1)

    # Float 0/1 masks instead of boolean indexing; presumably for backends
    # that do not support boolean-index assignment.
    mask_0 = gs.isclose(angle, 0.)
    mask_else = gs.equal(mask_0, gs.array(False))
    mask_0_float = gs.cast(mask_0, gs.float32)
    mask_else_float = gs.cast(mask_else, gs.float32)

    coef_1 = gs.zeros_like(angle)
    coef_2 = gs.zeros_like(angle)

    # Near angle 0: polynomial replacements for angle/sin and angle/tan.
    coef_1 += mask_0_float * (1. + INV_SIN_TAYLOR_COEFFS[1] * angle**2
                              + INV_SIN_TAYLOR_COEFFS[3] * angle**4
                              + INV_SIN_TAYLOR_COEFFS[5] * angle**6
                              + INV_SIN_TAYLOR_COEFFS[7] * angle**8)
    coef_2 += mask_0_float * (1. + INV_TAN_TAYLOR_COEFFS[1] * angle**2
                              + INV_TAN_TAYLOR_COEFFS[3] * angle**4
                              + INV_TAN_TAYLOR_COEFFS[5] * angle**6
                              + INV_TAN_TAYLOR_COEFFS[7] * angle**8)

    # This avoids division by 0.
    # (Entries near 0 are shifted by 1 so sin/tan are non-zero there; their
    # contribution is then cancelled by mask_else_float.)
    angle += mask_0_float * 1.

    coef_1 += mask_else_float * angle / gs.sin(angle)
    coef_2 += mask_else_float * angle / gs.tan(angle)

    log = (gs.einsum('ni,nj->nj', coef_1, point)
           - gs.einsum('ni,nj->nj', coef_2, base_point))

    # Build a (n, 1) float mask that is 1 for rows where every coordinate of
    # point is close to base_point, and zero those rows of the result.
    mask_same_values = gs.isclose(point, base_point)
    mask_else = gs.equal(mask_same_values, gs.array(False))
    mask_else_float = gs.cast(mask_else, gs.float32)
    mask_else_float = gs.to_ndarray(mask_else_float, to_ndim=1)
    mask_else_float = gs.to_ndarray(mask_else_float, to_ndim=2)
    mask_not_same_points = gs.sum(mask_else_float, axis=1)
    mask_same_points = gs.isclose(mask_not_same_points, 0.)
    mask_same_points = gs.cast(mask_same_points, gs.float32)
    mask_same_points = gs.to_ndarray(mask_same_points, to_ndim=2, axis=1)

    mask_same_points_float = gs.cast(mask_same_points, gs.float32)

    log -= mask_same_points_float * log
    return log
def log(self, point, base_point, **kwargs):
    """Compute the Riemannian logarithm of a point.

    With ``angle = arccos(<base_point, point>)``, the logarithm is
    ``angle / sin(angle) * point - angle / tan(angle) * base_point``. Both
    coefficients are even functions of the angle and are evaluated from
    the squared angle through ``utils.taylor_exp_even_func``.

    Parameters
    ----------
    point : array-like, shape=[..., dim + 1]
        Point on the hypersphere.
    base_point : array-like, shape=[..., dim + 1]
        Point on the hypersphere.

    Returns
    -------
    log : array-like, shape=[..., dim + 1]
        Tangent vector at the base point equal to the Riemannian logarithm
        of point at the base point.
    """
    cos_angle = gs.clip(
        self.embedding_metric.inner_product(base_point, point), -1., 1.)
    angle_sq = gs.arccos(cos_angle) ** 2

    point_coef = utils.taylor_exp_even_func(
        angle_sq, utils.inv_sinc_close_0, order=5)
    base_coef = utils.taylor_exp_even_func(
        angle_sq, utils.inv_tanc_close_0, order=5)

    scaled_point = gs.einsum('...,...j->...j', point_coef, point)
    scaled_base = gs.einsum('...,...j->...j', base_coef, base_point)
    return scaled_point - scaled_base
def dist(self, point_a, point_b):
    r"""Geodesic distance between two points.

    The geodesic distance between two points :math:`x, y` corresponds to
    the Procrustes distance after alignment of the pre-shapes. With
    :math:`x` first aligned onto :math:`y`, it is computed as

    .. math::

        d(x, y) = \arccos(\operatorname{tr}(x y^T))

    where :math:`\operatorname{tr}` is the trace operator. The trace is
    clipped to [-1, 1] before taking the arccos.

    Parameters
    ----------
    point_a : array-like, shape=[..., k_landmarks, m_ambient]
        Point.
    point_b : array-like, shape=[..., k_landmarks, m_ambient]
        Point.

    Returns
    -------
    dist : array-like, shape=[...,]
        Distance.
    """
    aligned_a = self.preshape.align(point_a, point_b)
    # Frobenius inner product, i.e. tr(aligned_a @ point_b^T).
    cos_dist = gs.einsum('...ij,...ij->...', aligned_a, point_b)
    return gs.arccos(gs.clip(cos_dist, -1, 1))
def rotation_vector_from_quaternion(self, quaternion):
    """Convert unit quaternion(s) into rotation vector(s).

    The scalar part ``w`` of the quaternion gives the half angle
    ``arccos(w)``. The rotation vector is ``2 * half_angle * v / sin(half_angle)``
    with ``v`` the vector part. It is left at zero where the half angle is
    close to 0, and the result is passed through ``self.regularize``.
    Only valid for n == 3.
    """
    assert self.n == 3, ('The quaternion representation does not exist'
                         ' for rotations in %d dimensions.' % self.n)
    quaternion = gs.to_ndarray(quaternion, to_ndim=2)
    n_quaternions, _ = quaternion.shape

    rot_vec = gs.zeros_like(quaternion[:, 1:])

    half_angle = gs.arccos(gs.clip(quaternion[:, 0], -1, 1))
    half_angle = gs.to_ndarray(half_angle, to_ndim=2, axis=1)
    assert half_angle.shape == (n_quaternions, 1)

    # Identity rotations keep the zero vector; the rest are rescaled.
    is_identity = gs.squeeze(gs.isclose(half_angle, 0), axis=1)
    keep = ~is_identity

    axis = quaternion[keep, 1:] / gs.sin(half_angle[keep])
    rot_vec[keep] = 2 * half_angle[keep] * axis

    return self.regularize(rot_vec)
def extrinsic_to_spherical(self, point_extrinsic):
    """Convert point from extrinsic to spherical coordinates.

    Convert from the extrinsic coordinates, i.e. embedded in Euclidean
    space of dim 3, to spherical coordinates in the hypersphere.
    Spherical coordinates are defined from the north pole, i.e. angles
    [0., 0.] correspond to point [0., 0., 1.]. The polar angle ``theta``
    is ``arccos(z)``; the azimuth ``phi`` is ``arctan2(y, x)`` mapped into
    [0, 2 pi).
    Only implemented in dimension 2.

    Parameters
    ----------
    point_extrinsic : array-like, shape=[..., dim + 1]
        Point on the sphere, in extrinsic coordinates.

    Returns
    -------
    point_spherical : array_like, shape=[..., dim]
        Point on the sphere, in spherical coordinates relative to the
        north pole, stacked as [theta, phi].

    Raises
    ------
    NotImplementedError
        If the dimension of the hypersphere is not 2.
    """
    if self.dim != 2:
        raise NotImplementedError(
            "The conversion from extrinsic coordinates to "
            "spherical coordinates is implemented"
            " only in dimension 2.")

    # Clip so that a z coordinate slightly outside [-1, 1] due to rounding
    # does not make arccos return NaN.
    theta = gs.arccos(gs.clip(point_extrinsic[..., -1], -1., 1.))
    x = point_extrinsic[..., 0]
    y = point_extrinsic[..., 1]
    phi = gs.arctan2(y, x)
    # arctan2 returns values in (-pi, pi]; shift into [0, 2 pi).
    phi = gs.where(phi < 0, phi + 2 * gs.pi, phi)

    return gs.stack([theta, phi], axis=-1)
def log(self, point, base_point, **kwargs):
    """Compute the Riemannian logarithm of a point.

    With ``theta = arccos(<base_point, point>)``, the logarithm is
    ``theta / sin(theta) * point - theta / tan(theta) * base_point``.
    Both coefficients are even functions of theta. They are evaluated with
    ``utils.taylor_exp_even_func``, whose expansion is in powers of
    ``theta ** 2``, so the squared angle is passed. The result is then used
    directly: multiplying again by theta would make the logarithm vanish
    like theta ** 2 near 0 instead of like theta.

    Parameters
    ----------
    point : array-like, shape=[..., n_samples]
        Point on the hypersphere.
    base_point : array-like, shape=[..., n_samples]
        Point on the hypersphere.

    Returns
    -------
    log : array-like, shape=[..., n_samples]
        Tangent vector at the base point equal to the Riemannian logarithm
        of point at the base point.
    """
    inner_prod = self.inner_product(base_point, point)
    # Clip so rounding errors cannot push arccos out of its domain.
    cos_angle = gs.clip(inner_prod, -1.0, 1.0)
    squared_angle = gs.arccos(cos_angle) ** 2

    coef_1_ = utils.taylor_exp_even_func(
        squared_angle, utils.inv_sinc_close_0, order=5
    )
    coef_2_ = utils.taylor_exp_even_func(
        squared_angle, utils.inv_tanc_close_0, order=5
    )

    log = gs.einsum("...,...j->...j", coef_1_, point) - gs.einsum(
        "...,...j->...j", coef_2_, base_point
    )
    return log
def rotation_vector_from_quaternion(self, quaternion):
    """Convert a batch of unit quaternions into rotation vectors.

    The half angle is ``arccos(w)`` with ``w`` the (clipped) scalar part.
    The rotation vector is ``2 * half_angle * v / sin(half_angle)`` with
    ``v`` the vector part. Where the half angle is close to 0, the divisor
    is replaced by 1 and the result is masked to zero. This uses float
    masks rather than boolean indexing.

    Parameters
    ----------
    quaternion : array-like, shape=[n_quaternions, 4]
        Unit quaternions, scalar part first. A leading batch axis is
        required since the scalar part is read as ``quaternion[:, 0]``.

    Returns
    -------
    rot_vec : array-like, shape=[n_quaternions, 3]
        Regularized rotation vectors.
    """
    w = gs.clip(quaternion[:, 0], -1, 1)
    half_angle = gs.to_ndarray(gs.arccos(w), to_ndim=2, axis=1)

    is_zero = gs.isclose(half_angle, 0.)
    is_nonzero = ~is_zero
    zero_float = gs.cast(is_zero, gs.float32)
    nonzero_float = gs.cast(is_nonzero, gs.float32)

    # Use 1 as the divisor where the half angle vanishes to avoid 0 / 0.
    safe_sin = gs.sin(half_angle) * nonzero_float + zero_float
    rotation_axis = gs.divide(quaternion[:, 1:], safe_sin)

    rot_vec = gs.array(2 * half_angle * rotation_axis * nonzero_float)
    return self.regularize(rot_vec)
def dist(self, point_a, point_b):
    """Geodesic distance between two points.

    Returns ``arccos`` of the normalized inner product of the points,
    computed with the embedding metric and clipped to [-1, 1].
    """
    norm = self.embedding_metric.norm
    denominator = norm(point_a) * norm(point_b)
    numerator = self.embedding_metric.inner_product(point_a, point_b)
    return gs.arccos(gs.clip(numerator / denominator, -1, 1))
def log(self, point, base_point):
    """Compute the Riemannian logarithm of ``point`` at ``base_point``.

    The metric is the one obtained by embedding the n-dimensional sphere
    in the (n+1)-dimensional Euclidean space; the result is a tangent
    vector at ``base_point``. It equals ``coef_1 * point - coef_2 * base_point``
    with ``coef_1 = angle / sin(angle)`` and ``coef_2 = angle / tan(angle)``.
    Where the angle is close to 0, these ratios are replaced by polynomials
    in the angle built from INV_SIN_TAYLOR_COEFFS and INV_TAN_TAYLOR_COEFFS.

    :param point: point on the n-dimensional sphere
    :param base_point: point on the n-dimensional sphere
    :return log: tangent vector at base_point
    """
    point = gs.to_ndarray(point, to_ndim=2)
    base_point = gs.to_ndarray(base_point, to_ndim=2)

    metric = self.embedding_metric
    norm_base = metric.norm(base_point)
    norm_pt = metric.norm(point)
    cos_angle = metric.inner_product(base_point, point) / (norm_base * norm_pt)
    angle = gs.arccos(gs.clip(cos_angle, -1.0, 1.0))

    near_zero = gs.isclose(angle, 0.0)
    away_from_zero = gs.equal(near_zero, False)

    coef_1 = gs.zeros_like(angle)
    coef_2 = gs.zeros_like(angle)

    small = angle[near_zero]
    coef_1[near_zero] = (
        1. + INV_SIN_TAYLOR_COEFFS[1] * small ** 2
        + INV_SIN_TAYLOR_COEFFS[3] * small ** 4
        + INV_SIN_TAYLOR_COEFFS[5] * small ** 6
        + INV_SIN_TAYLOR_COEFFS[7] * small ** 8)
    coef_2[near_zero] = (
        1. + INV_TAN_TAYLOR_COEFFS[1] * small ** 2
        + INV_TAN_TAYLOR_COEFFS[3] * small ** 4
        + INV_TAN_TAYLOR_COEFFS[5] * small ** 6
        + INV_TAN_TAYLOR_COEFFS[7] * small ** 8)

    large = angle[away_from_zero]
    coef_1[away_from_zero] = large / gs.sin(large)
    coef_2[away_from_zero] = large / gs.tan(large)

    return (gs.einsum('ni,nj->nj', coef_1, point)
            - gs.einsum('ni,nj->nj', coef_2, base_point))
def dist(self, point_a, point_b):
    """Geodesic distance between two points.

    The angle between the points: ``arccos`` of their inner product
    divided by the product of their norms (embedding metric), clipped
    to [-1, 1].
    """
    # TODO(nina): case gs.dot(unit_vec, unit_vec) != 1
    metric = self.embedding_metric
    lengths = (metric.norm(point_a), metric.norm(point_b))
    cos_angle = (metric.inner_product(point_a, point_b)
                 / (lengths[0] * lengths[1]))
    cos_angle = gs.clip(cos_angle, -1, 1)
    return gs.arccos(cos_angle)
def dist(self, point_a, point_b):
    """Compute the Riemannian distance between point_a and point_b.

    The distance is ``arccos`` of the normalized inner product (embedding
    metric), with the cosine clipped to [-1, 1] to stay in the domain of
    ``arccos``.
    """
    # TODO(xxx): case gs.dot(unit_vec, unit_vec) != 1
    norm_of_a = self.embedding_metric.norm(point_a)
    norm_of_b = self.embedding_metric.norm(point_b)
    dot = self.embedding_metric.inner_product(point_a, point_b)
    clipped = gs.clip(dot / (norm_of_a * norm_of_b), -1, 1)
    return gs.arccos(clipped)
def log(self, point, base_point):
    """Riemannian logarithm of a point wrt a base point.

    Returns ``coef_1 * point - coef_2 * base_point`` where
    ``coef_1 = angle / sin(angle)`` and ``coef_2 = angle / tan(angle)`` for
    the angle between the points. Where the angle is close to 0, a
    polynomial in the angle built from INV_SIN_TAYLOR_COEFFS /
    INV_TAN_TAYLOR_COEFFS is used instead.
    """
    def taylor_poly(coeffs, values):
        # 1 + c1 x^2 + c3 x^4 + c5 x^6 + c7 x^8, accumulated left to right.
        total = 1.
        for k in (1, 3, 5, 7):
            total = total + coeffs[k] * values ** (k + 1)
        return total

    point = gs.to_ndarray(point, to_ndim=2)
    base_point = gs.to_ndarray(base_point, to_ndim=2)

    emb = self.embedding_metric
    norms = emb.norm(base_point) * emb.norm(point)
    cos_angle = emb.inner_product(base_point, point) / norms
    angle = gs.arccos(gs.clip(cos_angle, -1.0, 1.0))

    is_small = gs.isclose(angle, 0.0)
    is_large = gs.equal(is_small, gs.cast(gs.array(False), gs.int8))

    coef_1 = gs.zeros_like(angle)
    coef_2 = gs.zeros_like(angle)
    coef_1[is_small] = taylor_poly(INV_SIN_TAYLOR_COEFFS, angle[is_small])
    coef_2[is_small] = taylor_poly(INV_TAN_TAYLOR_COEFFS, angle[is_small])
    coef_1[is_large] = angle[is_large] / gs.sin(angle[is_large])
    coef_2[is_large] = angle[is_large] / gs.tan(angle[is_large])

    point_term = gs.einsum('ni,nj->nj', coef_1, point)
    base_term = gs.einsum('ni,nj->nj', coef_2, base_point)
    return point_term - base_term
def random_von_mises_fisher(self, mu=None, kappa=10, n_samples=1,
                            max_iter=100):
    """Sample with the von Mises-Fisher distribution.

    This distribution corresponds to the maximum entropy distribution
    given a mean. In dimension 2, a closed form expression is available.
    In larger dimension, rejection sampling is used according to [Wood94]_

    References
    ----------
    https://en.wikipedia.org/wiki/Von_Mises-Fisher_distribution

    .. [Wood94]   Wood, Andrew T. A. “Simulation of the von Mises Fisher
                  Distribution.” Communications in Statistics - Simulation
                  and Computation, June 27, 2007.
                  https://doi.org/10.1080/03610919408813161.

    Parameters
    ----------
    mu : array-like, shape=[dim + 1]
        Mean parameter of the distribution, a point on the sphere in
        extrinsic coordinates.
        Optional, default: the north pole [0., ..., 0., 1.].
    kappa : float
        Kappa parameter of the von Mises distribution.
        Optional, default: 10.
    n_samples : int
        Number of samples.
        Optional, default: 1.
    max_iter : int
        Maximum number of rejection-sampling rounds (dim > 2 only).
        Optional, default: 100.

    Returns
    -------
    point : array-like, shape=[n_samples, dim + 1] or [dim + 1]
        Points sampled on the sphere in extrinsic coordinates. A single
        sample is returned without the leading axis. When dim > 2 and
        max_iter is reached, fewer than n_samples points may be returned
        (a warning is logged).
    """
    dim = self.dim

    if dim == 2:
        angle = 2. * gs.pi * gs.random.rand(n_samples)
        angle = gs.to_ndarray(angle, to_ndim=2, axis=1)
        unit_vector = gs.hstack((gs.cos(angle), gs.sin(angle)))
        scalar = gs.random.rand(n_samples)

        coord_z = 1. + 1. / kappa * gs.log(
            scalar + (1. - scalar) * gs.exp(gs.array(-2. * kappa)))
        coord_z = gs.to_ndarray(coord_z, to_ndim=2, axis=1)

        coord_xy = gs.sqrt(1. - coord_z**2) * unit_vector

        # Samples centered on the north pole.
        sample = gs.hstack((coord_xy, coord_z))

        if mu is not None:
            # Rotate the north pole onto mu: axis e_z x mu, angle arccos(mu_z).
            rot_vec = gs.cross(gs.array([0., 0., 1.]), mu)
            axis_norm = gs.linalg.norm(rot_vec)
            # Clip so rounding cannot push arccos out of its domain.
            rot_angle = gs.arccos(gs.clip(mu[-1], -1., 1.))
            if not gs.isclose(axis_norm, 0.):
                rot_vec = rot_vec * (rot_angle / axis_norm)
            elif mu[-1] < 0:
                # mu is the south pole: the axis e_z x mu vanishes, so
                # rotate by pi around a horizontal axis instead of 0 / 0.
                rot_vec = gs.array([gs.pi, 0., 0.])
            else:
                # mu is the north pole: samples are already centered on it.
                rot_vec = None

            if rot_vec is not None:
                rot = SpecialOrthogonal(
                    3, 'vector').matrix_from_rotation_vector(rot_vec)
                sample = gs.matmul(sample, gs.transpose(rot))
    else:
        if mu is None:
            mu = gs.array([0.] * dim + [1.])
        # rejection sampling in the general case
        sqrt = gs.sqrt(4 * kappa**2. + dim**2)
        envelop_param = (-2 * kappa + sqrt) / dim
        node = (1. - envelop_param) / (1. + envelop_param)
        correction = kappa * node + dim * gs.log(1. - node**2)

        n_accepted, n_iter = 0, 0
        result = []
        while (n_accepted < n_samples) and (n_iter < max_iter):
            sym_beta = beta.rvs(
                dim / 2, dim / 2, size=n_samples - n_accepted)
            coord_z = (1 - (1 + envelop_param) * sym_beta) / (
                1 - (1 - envelop_param) * sym_beta)
            accept_tol = gs.random.rand(n_samples - n_accepted)
            criterion = (
                kappa * coord_z
                + dim * gs.log(1 - node * coord_z)
                - correction) > gs.log(accept_tol)
            result.append(coord_z[criterion])
            n_accepted += gs.sum(criterion)
            n_iter += 1
        if n_accepted < n_samples:
            logging.warning(
                'Maximum number of iteration reached in rejection '
                'sampling before n_samples were accepted.')
        coord_z = gs.concatenate(result)

        # Uniform unit directions orthogonal to mu, scaled so that each
        # sample has component coord_z along mu and unit norm overall.
        coord_rest = self.random_uniform(n_accepted)
        coord_rest = self.to_tangent(coord_rest, mu)
        coord_rest = self.projection(coord_rest)
        coord_rest = gs.einsum(
            '...,...i->...i', gs.sqrt(1 - coord_z**2), coord_rest)
        sample = coord_rest + coord_z[:, None] * mu[None, :]

    return sample if n_samples > 1 else sample[0]
def rotation_vector_from_matrix(self, rot_mat):
    """Convert rotation matrices to rotation vectors (axis-angle).

    In 3D, get the angle through the trace of the rotation matrix.
    The eigenvalues are:
    1, cos(angle) + i sin(angle), cos(angle) - i sin(angle)
    so that: trace = 1 + 2 cos(angle), -1 <= trace <= 3

    Get the rotation vector through the formula:
    S_r = angle / ( 2 * sin(angle) ) (R - R^T)
    Near angle 0, angle / (2 sin(angle)) is replaced by its expansion
    1/2 + angle^2 / 12, using trace - 3 ~ -angle^2.

    For the edge case where the angle is close to pi, R - R^T vanishes and
    the axis is recovered from the symmetric part. At angle pi,
    R = 2 n n^T - I for the unit axis n, hence
    R_aa - R_bb - R_cc + 1 = 4 n_a^2 and R_ab + R_ba = 4 n_a n_b.
    The index a is chosen per matrix as its largest diagonal element.

    In nD, the rotation vector stores the n(n-1)/2 values of the
    skew-symmetric matrix representing the rotation.

    Parameters
    ----------
    rot_mat : array-like, shape=[n_rot_mats, n, n] or [n, n]
        Rotation matrices; projected onto rotations first.

    Returns
    -------
    rot_vec : array-like
        Regularized rotation vectors, one row per matrix.
    """
    rot_mat = gs.to_ndarray(rot_mat, to_ndim=3)
    n_rot_mats, mat_dim_1, mat_dim_2 = rot_mat.shape
    assert mat_dim_1 == mat_dim_2 == self.n

    rot_mat = closest_rotation_matrix(rot_mat)

    if self.n == 3:
        trace = gs.trace(rot_mat, axis1=1, axis2=2)
        trace = gs.to_ndarray(trace, to_ndim=2, axis=1)
        assert trace.shape == (n_rot_mats, 1), trace.shape

        cos_angle = .5 * (trace - 1)
        cos_angle = gs.clip(cos_angle, -1, 1)
        angle = gs.arccos(cos_angle)

        rot_mat_transpose = gs.transpose(rot_mat, axes=(0, 2, 1))
        rot_vec = vector_from_skew_matrix(rot_mat - rot_mat_transpose)

        mask_0 = gs.isclose(angle, 0)
        mask_0 = gs.squeeze(mask_0, axis=1)
        rot_vec[mask_0] = (rot_vec[mask_0]
                           * (.5 - (trace[mask_0] - 3.) / 12.))

        mask_pi = gs.isclose(angle, gs.pi)
        mask_pi = gs.squeeze(mask_pi, axis=1)

        if gs.any(mask_pi):
            rot_mat_pi = rot_mat[mask_pi]
            n_pi = rot_mat_pi.shape[0]
            rows = gs.arange(n_pi)

            # Choose, for each matrix separately, the largest diagonal
            # element to avoid a square root of a negative number. A flat
            # argmax over several matrices would return an index >= 3.
            a = gs.argmax(
                gs.diagonal(rot_mat_pi, axis1=1, axis2=2), axis=1)
            b = gs.mod(a + 1, 3)
            c = gs.mod(a + 2, 3)

            # compute the axis vector
            sq_root = gs.sqrt(
                rot_mat_pi[rows, a, a]
                - rot_mat_pi[rows, b, b]
                - rot_mat_pi[rows, c, c] + 1.)
            rot_vec_pi = gs.zeros((n_pi, self.dimension))
            rot_vec_pi[rows, a] = sq_root / 2.
            rot_vec_pi[rows, b] = (
                (rot_mat_pi[rows, b, a] + rot_mat_pi[rows, a, b])
                / (2. * sq_root))
            rot_vec_pi[rows, c] = (
                (rot_mat_pi[rows, c, a] + rot_mat_pi[rows, a, c])
                / (2. * sq_root))

            # Normalize each axis on its own, not the whole batch at once.
            norm_pi = gs.linalg.norm(rot_vec_pi, axis=1)
            norm_pi = gs.to_ndarray(norm_pi, to_ndim=2, axis=1)
            rot_vec[mask_pi] = angle[mask_pi] * rot_vec_pi / norm_pi

        mask_else = ~mask_0 & ~mask_pi
        rot_vec[mask_else] = (angle[mask_else]
                              / (2. * gs.sin(angle[mask_else]))
                              * rot_vec[mask_else])
    else:
        skew_mat = self.embedding_manifold.group_log_from_identity(rot_mat)
        rot_vec = vector_from_skew_matrix(skew_mat)

    return self.regularize(rot_vec)
def log(self, point, base_point):
    """Compute the Riemannian logarithm of a point.

    Returns ``coef_1 * point - coef_2 * base_point`` where, for the angle
    between the two points, ``coef_1 = angle / sin(angle)`` and
    ``coef_2 = angle / tan(angle)``. Where the angle is close to 0, both
    ratios are replaced by polynomials in the angle built from
    INV_SIN_TAYLOR_COEFFS and INV_TAN_TAYLOR_COEFFS. Rows where ``point``
    and ``base_point`` are elementwise close are set to zero at the end.

    Parameters
    ----------
    point : array-like, shape=[..., dim + 1]
        Point on the hypersphere.
    base_point : array-like, shape=[..., dim + 1]
        Point on the hypersphere.

    Returns
    -------
    log : array-like, shape=[..., dim + 1]
        Tangent vector at the base point equal to the Riemannian logarithm
        of point at the base point.
        NOTE(review): the angle and masks below are forced to 2-D, so a
        single 1-D input appears to yield a (1, dim + 1) result; confirm.
    """
    # Cosine of the angle between the points, normalized by both norms.
    norm_base_point = self.embedding_metric.norm(base_point)
    norm_point = self.embedding_metric.norm(point)
    inner_prod = self.embedding_metric.inner_product(base_point, point)
    cos_angle = inner_prod / (norm_base_point * norm_point)
    # Clip so rounding errors cannot push arccos out of its domain.
    cos_angle = gs.clip(cos_angle, -1., 1.)

    angle = gs.arccos(cos_angle)
    # Force angle into a column (one row per sample) for the einsum below.
    angle = gs.to_ndarray(angle, to_ndim=1)
    angle = gs.to_ndarray(angle, to_ndim=2, axis=1)

    # Float 0/1 masks instead of boolean indexing; presumably for backends
    # that do not support boolean-index assignment.
    mask_0 = gs.isclose(angle, 0.)
    mask_else = gs.equal(mask_0, gs.array(False))
    mask_0_float = gs.cast(mask_0, gs.float32)
    mask_else_float = gs.cast(mask_else, gs.float32)

    coef_1 = gs.zeros_like(angle)
    coef_2 = gs.zeros_like(angle)

    # Near angle 0: polynomial replacements for angle/sin and angle/tan.
    coef_1 += mask_0_float * (1. + INV_SIN_TAYLOR_COEFFS[1] * angle**2
                              + INV_SIN_TAYLOR_COEFFS[3] * angle**4
                              + INV_SIN_TAYLOR_COEFFS[5] * angle**6
                              + INV_SIN_TAYLOR_COEFFS[7] * angle**8)
    coef_2 += mask_0_float * (1. + INV_TAN_TAYLOR_COEFFS[1] * angle**2
                              + INV_TAN_TAYLOR_COEFFS[3] * angle**4
                              + INV_TAN_TAYLOR_COEFFS[5] * angle**6
                              + INV_TAN_TAYLOR_COEFFS[7] * angle**8)

    # This avoids division by 0.
    # (Entries near 0 are shifted by 1 so sin/tan are non-zero there; their
    # contribution is then cancelled by mask_else_float.)
    angle += mask_0_float * 1.

    coef_1 += mask_else_float * angle / gs.sin(angle)
    coef_2 += mask_else_float * angle / gs.tan(angle)

    log = (gs.einsum('...i,...j->...j', coef_1, point)
           - gs.einsum('...i,...j->...j', coef_2, base_point))

    # Build a (n, 1) float mask that is 1 for rows where every coordinate of
    # point is close to base_point, and zero those rows of the result.
    mask_same_values = gs.isclose(point, base_point)
    mask_else = gs.equal(mask_same_values, gs.array(False))
    mask_else_float = gs.cast(mask_else, gs.float32)
    mask_else_float = gs.to_ndarray(mask_else_float, to_ndim=1)
    mask_else_float = gs.to_ndarray(mask_else_float, to_ndim=2)
    mask_not_same_points = gs.sum(mask_else_float, axis=1)
    mask_same_points = gs.isclose(mask_not_same_points, 0.)
    mask_same_points = gs.cast(mask_same_points, gs.float32)
    mask_same_points = gs.to_ndarray(mask_same_points, to_ndim=2, axis=1)

    mask_same_points_float = gs.cast(mask_same_points, gs.float32)

    log -= mask_same_points_float * log
    return log
def rotation_vector_from_matrix(self, rot_mat):
    r"""Convert rotation matrix (in 3D) to rotation vector (axis-angle).

    Get the angle through the trace of the rotation matrix:
    The eigenvalues are:
    :math:`\{1, \cos(angle) + i \sin(angle), \cos(angle) - i \sin(angle)\}`
    so that:
    :math:`trace = 1 + 2 \cos(angle), \{-1 \leq trace \leq 3\}`

    Get the rotation vector through the formula:
    :math:`S_r = \frac{angle}{2 \sin(angle)} (R - R^T)`
    Near angle 0 the factor is approximated by
    :math:`\frac{1/2}{1 - angle^2 / 6}`.

    For the edge case where the angle is close to pi (within 1e-2),
    the formulation is derived by using the following equality for the
    unit rotation axis :math:`r` (see the Axis-angle representation on
    Wikipedia):
    :math:`outer(r, r) = \frac{1}{2} (R + I_3)`

    This implementation handles 3x3 matrices only (it uses ``gs.eye(3)``).

    Parameters
    ----------
    rot_mat : array-like, shape=[n_rot_mats, 3, 3]
        Rotation matrices.

    Returns
    -------
    regularized_rot_vec : array-like, shape=[n_rot_mats, 3]
        Regularized rotation vectors.
    """
    n_rot_mats, _, _ = rot_mat.shape

    trace = gs.trace(rot_mat, axis1=1, axis2=2)
    trace = gs.to_ndarray(trace, to_ndim=2, axis=1)
    trace_num = gs.clip(trace, -1, 3)
    angle = gs.arccos(0.5 * (trace_num - 1))
    rot_mat_transpose = gs.transpose(rot_mat, axes=(0, 2, 1))
    rot_vec_not_pi = self.vector_from_skew_matrix(
        rot_mat - rot_mat_transpose)

    mask_0 = gs.cast(gs.isclose(angle, 0.), gs.float32)
    mask_pi = gs.cast(gs.isclose(angle, gs.pi, atol=1e-2), gs.float32)
    mask_else = (1 - mask_0) * (1 - mask_pi)

    # Scale factor angle / (2 sin(angle)); the pi case gets 0 / 1 here and
    # is handled separately below.
    numerator = 0.5 * mask_0 + angle * mask_else
    denominator = (
        (1 - angle**2 / 6) * mask_0
        + 2 * gs.sin(angle) * mask_else
        + mask_pi)

    rot_vec_not_pi = rot_vec_not_pi * numerator / denominator

    # Near pi, outer(r, r) = (R + I) / 2. Clamp its diagonal at 0 so the
    # square roots below stay real. set_diag's return value is used, since
    # some backends return a new array instead of writing in place.
    vector_outer = 0.5 * (gs.eye(3) + rot_mat)
    vector_outer = gs.set_diag(
        vector_outer,
        gs.maximum(0., gs.diagonal(vector_outer, axis1=1, axis2=2)))
    squared_diag_comp = gs.diagonal(vector_outer, axis1=1, axis2=2)
    diag_comp = gs.sqrt(squared_diag_comp)

    # The row of outer(r, r) with the largest norm fixes a consistent sign
    # for the components of r.
    norm_line = gs.linalg.norm(vector_outer, axis=2)
    max_line_index = gs.argmax(norm_line, axis=1)
    selected_line = gs.get_slice(
        vector_outer, (range(n_rot_mats), max_line_index))
    signs = gs.sign(selected_line)
    rot_vec_pi = angle * signs * diag_comp

    rot_vec = rot_vec_not_pi + mask_pi * rot_vec_pi

    return self.regularize(rot_vec)