def squared_dist(self, point_a, point_b):
    """Compute the Bures-Wasserstein squared distance.

    Compute the Riemannian squared distance between point_a and point_b
    as tr(A) + tr(B) - 2 tr(sqrtm(A B)).

    Parameters
    ----------
    point_a : array-like, shape=[..., n, n]
        Point.
    point_b : array-like, shape=[..., n, n]
        Point.

    Returns
    -------
    squared_dist : array-like, shape=[...]
        Riemannian squared distance.
    """
    product = gs.matmul(point_a, point_b)
    sqrt_product = gs.linalg.sqrtm(product)
    # Trace over the last two axes so that batched inputs [..., n, n]
    # give one value per matrix; gs.trace defaults to axes (0, 1), which
    # would sum across samples for ndim > 2.
    trace_a = gs.trace(point_a, axis1=-2, axis2=-1)
    trace_b = gs.trace(point_b, axis1=-2, axis2=-1)
    trace_prod = gs.trace(sqrt_product, axis1=-2, axis2=-1)
    result = trace_a + trace_b - 2 * trace_prod
    return result
def inner_product(self, tangent_vec_a, tangent_vec_b, base_point):
    """Compute the power-Euclidean inner product.

    For power p == 1 this is tr(A B). Otherwise both tangent vectors are
    first mapped through ``self.space.differential_power(p, ., base_point)``
    and the trace of their product is divided by p ** 2.

    Parameters
    ----------
    tangent_vec_a : array-like, shape=[..., n, n]
    tangent_vec_b : array-like, shape=[..., n, n]
    base_point : array-like, shape=[..., n, n]
        Only used when p != 1.

    Returns
    -------
    inner_product : array-like, shape=[...]
    """
    power = self.power_euclidean
    space = self.space

    if power != 1:
        tangent_vec_a = space.differential_power(
            power, tangent_vec_a, base_point)
        tangent_vec_b = space.differential_power(
            power, tangent_vec_b, base_point)

    trace_of_product = gs.trace(
        gs.einsum('...ij,...jk->...ik', tangent_vec_a, tangent_vec_b),
        axis1=-2, axis2=-1)

    if power == 1:
        return trace_of_product
    return trace_of_product / (power ** 2)
def inner_product_at_identity(self, tangent_vec_a, tangent_vec_b):
    """Compute the inner product on the tangent space at the identity.

    Parameters
    ----------
    tangent_vec_a : array-like
        Tangent vector(s) at the identity, shape=[n_samples, dim] (or
        [dim]) for 'vector' representation, shape=[n_samples, n, n] (or
        [n, n]) for 'matrix' representation.
    tangent_vec_b : array-like
        Same layout as tangent_vec_a.

    Returns
    -------
    inner_prod : array-like
        'vector' representation: shape=[n_samples, 1], the bilinear form
        a_i^T M_i b_i with M = self.inner_product_mat_at_identity.
        'matrix' representation: shape=[n_samples], the Frobenius inner
        product tr(A_i^T B_i) of each pair.
    """
    assert self.group.point_representation in ('vector', 'matrix')

    if self.group.point_representation == 'vector':
        tangent_vec_a = gs.to_ndarray(tangent_vec_a, to_ndim=2)
        tangent_vec_b = gs.to_ndarray(tangent_vec_b, to_ndim=2)

        # NOTE(review): the 'ij,ijk,ik' subscripts assume
        # inner_product_mat_at_identity has one [dim, dim] matrix per
        # sample (same leading size as the tangent vectors) — confirm.
        inner_prod = gs.einsum(
            'ij,ijk,ik->i',
            tangent_vec_a, self.inner_product_mat_at_identity, tangent_vec_b)
        inner_prod = gs.to_ndarray(inner_prod, to_ndim=2, axis=1)

    elif self.group.point_representation == 'matrix':
        logging.warning(
            'Only the canonical inner product -Frobenius inner product-'
            ' is implemented for Lie groups whose elements are represented'
            ' by matrices.')
        tangent_vec_a = gs.to_ndarray(tangent_vec_a, to_ndim=3)
        tangent_vec_b = gs.to_ndarray(tangent_vec_b, to_ndim=3)
        aux_prod = gs.matmul(
            gs.transpose(tangent_vec_a, axes=(0, 2, 1)), tangent_vec_b)
        # Trace each [n, n] matrix (axes 1, 2). The default gs.trace axes
        # (0, 1) would mix different samples of the batch together.
        inner_prod = gs.trace(aux_prod, axis1=1, axis2=2)

    return inner_prod
def test_inner_product(self):
    """The metric's inner product should equal tr(A^T B) for this pair."""
    base_point = gs.array([
        [1., 2., 3.],
        [0., 0., 0.],
        [3., 1., 1.]])
    vec_a = gs.array([
        [1., 2., 3.],
        [0., -10., 0.],
        [30., 1., 1.]])
    vec_b = gs.array([
        [1., 4., 3.],
        [5., 0., 0.],
        [3., 1., 1.]])

    result = self.metric.inner_product(vec_a, vec_b, base_point=base_point)

    frobenius = gs.matmul(gs.transpose(vec_a), vec_b)
    expected = gs.trace(frobenius)
    self.assertAllClose(result, expected)
def inner_product(self, tangent_vec_a, tangent_vec_b, base_point):
    r"""Compute the inner product on the tangent space at a base point.

    Canonical inner product on the tangent space at `base_point`, which
    is different from the inner product induced by the embedding.

    .. math::

        \langle\Delta, \tilde{\Delta}\rangle_{U}=\operatorname{tr}
        \left(\Delta^{T}\left(I-\frac{1}{2} U U^{T}\right)
        \tilde{\Delta}\right)

    Parameters
    ----------
    tangent_vec_a : array-like, shape=[n_samples, n, k] or [n, k]
    tangent_vec_b : array-like, shape=[n_samples, n, k] or [n, k]
    base_point : array-like, shape=[n_samples, n, k] or [n, k]

    Returns
    -------
    inner_prod : array-like, shape=[n_samples, 1]

    References
    ----------
    .. [RLSMRZ2017] R Zimmermann. A matrix-algebraic algorithm for the
      Riemannian logarithm on the Stiefel manifold under the canonical
      metric. SIAM Journal on Matrix Analysis and Applications 38 (2),
      322-342, 2017. https://epubs.siam.org/doi/pdf/10.1137/16M1074485
    """
    tangent_vec_a = gs.to_ndarray(tangent_vec_a, to_ndim=3)
    tangent_vec_b = gs.to_ndarray(tangent_vec_b, to_ndim=3)
    base_point = gs.to_ndarray(base_point, to_ndim=3)

    # Middle factor I - U U^T / 2 of the canonical metric.
    weight = gs.eye(self.n) - 0.5 * gs.matmul(
        base_point, gs.transpose(base_point, axes=(0, 2, 1)))
    weighted_a = gs.matmul(
        gs.transpose(tangent_vec_a, axes=(0, 2, 1)), weight)

    traces = gs.trace(
        gs.matmul(weighted_a, tangent_vec_b), axis1=1, axis2=2)
    return gs.to_ndarray(traces, to_ndim=2, axis=1)
def test_trace(self):
    """gs.trace must agree with numpy's trace for several axis choices."""
    base_list = [[[22., 55.],
                  [33., 88.]],
                 [[34., 12.],
                  [67., 35.]]]
    np_array = _np.array(base_list)
    gs_array = gs.array(base_list)

    # Default axes, explicit positive axes, explicit negative axes.
    axis_options = (
        {},
        {'axis1': 1, 'axis2': 2},
        {'axis1': -1, 'axis2': -2},
    )
    for axis_kwargs in axis_options:
        np_result = _np.trace(np_array, **axis_kwargs)
        gs_result = gs.trace(gs_array, **axis_kwargs)
        self.assertAllCloseToNp(gs_result, np_result)
def inner_product(self, tangent_vec_a, tangent_vec_b, base_point):
    """Compute the Procrustes inner-product.

    The value is half the trace of (M_a @ tangent_vec_b), where M_a is
    ``self.space.inverse_differential_power(2, tangent_vec_a, base_point)``.

    Parameters
    ----------
    tangent_vec_a : array-like, shape=[..., n, n]
        Tangent vector at base point.
    tangent_vec_b : array-like, shape=[..., n, n]
        Tangent vector at base point.
    base_point : array-like, shape=[..., n, n]
        Base point.

    Returns
    -------
    inner_product : array-like, shape=[...,]
        Inner-product.
    """
    space = self.space
    transformed_a = space.inverse_differential_power(
        2, tangent_vec_a, base_point)
    batched_product = gs.einsum(
        '...ij,...jk->...ik', transformed_a, tangent_vec_b)
    return gs.trace(batched_product, axis1=-2, axis2=-1) / 2
def inner_product(self, tangent_vec_a, tangent_vec_b, base_point):
    """Compute the Log-Euclidean inner-product.

    Both tangent vectors are mapped through
    ``self.space.differential_log(., base_point)`` and the trace of the
    product of the two results is returned.

    Parameters
    ----------
    tangent_vec_a : array-like, shape=[..., n, n]
        Tangent vector at base point.
    tangent_vec_b : array-like, shape=[..., n, n]
        Tangent vector at base point.
    base_point : array-like, shape=[..., n, n]
        Base point.

    Returns
    -------
    inner_product : array-like, shape=[...,]
        Inner-product.
    """
    space = self.space
    log_diff_a, log_diff_b = (
        space.differential_log(tangent_vec, base_point)
        for tangent_vec in (tangent_vec_a, tangent_vec_b))
    return gs.trace(
        gs.einsum('...ij,...jk->...ik', log_diff_a, log_diff_b),
        axis1=-2, axis2=-1)
def inner_product(self, tangent_vec_a, tangent_vec_b, base_point):
    """Compute the Log-Euclidean inner product.

    Single matrices are promoted to batches of one, and any batch of size
    one is tiled to match the others before the computation.

    Parameters
    ----------
    tangent_vec_a : array-like, shape=[n_samples, n, n] or [n, n]
    tangent_vec_b : array-like, shape=[n_samples, n, n] or [n, n]
    base_point : array-like, shape=[n_samples, n, n] or [n, n]

    Returns
    -------
    inner_product : array-like, shape=[n_samples, 1]
    """
    tangent_vec_a = gs.to_ndarray(tangent_vec_a, to_ndim=3)
    n_a, _, _ = tangent_vec_a.shape
    tangent_vec_b = gs.to_ndarray(tangent_vec_b, to_ndim=3)
    n_b, _, _ = tangent_vec_b.shape
    base_point = gs.to_ndarray(base_point, to_ndim=3)
    n_base, _, _ = base_point.shape

    spd_space = self.space

    # Compatible batch sizes: every size other than 1 is the same value.
    assert len({size for size in (n_a, n_b, n_base) if size != 1}) <= 1

    if n_a == 1:
        tangent_vec_a = gs.tile(
            tangent_vec_a, (gs.maximum(n_base, n_b), 1, 1))
    if n_b == 1:
        tangent_vec_b = gs.tile(
            tangent_vec_b, (gs.maximum(n_base, n_a), 1, 1))
    if n_base == 1:
        base_point = gs.tile(
            base_point, (gs.maximum(n_a, n_b), 1, 1))

    log_diff_a = spd_space.differential_log(tangent_vec_a, base_point)
    log_diff_b = spd_space.differential_log(tangent_vec_b, base_point)

    traces = gs.trace(
        gs.matmul(log_diff_a, log_diff_b), axis1=1, axis2=2)
    return gs.to_ndarray(traces, to_ndim=2, axis=1)
def inner_product_at_identity(self, tangent_vec_a, tangent_vec_b):
    """Compute inner product matrix at tangent space at identity.

    Parameters
    ----------
    tangent_vec_a : array-like
        Tangent vector(s) at the identity: shape=[n_samples, dim] (or
        [dim]) when default_point_type is 'vector', shape=[n_samples, n, n]
        (or [n, n]) when it is 'matrix'.
    tangent_vec_b : array-like
        Same layout as tangent_vec_a.

    Returns
    -------
    inner_prod : array-like
        'vector': shape=[n_samples, 1], a_i^T M b_i where M is the first
        matrix of self.inner_product_mat_at_identity.
        'matrix': shape=[n_samples], Frobenius product tr(A_i^T B_i).
    """
    assert self.group.default_point_type in ('vector', 'matrix')

    if self.group.default_point_type == 'vector':
        tangent_vec_a = gs.to_ndarray(tangent_vec_a, to_ndim=2)
        tangent_vec_b = gs.to_ndarray(tangent_vec_b, to_ndim=2)

        n_tangent_vec_a = tangent_vec_a.shape[0]
        n_tangent_vec_b = tangent_vec_b.shape[0]

        assert (tangent_vec_a.shape == tangent_vec_b.shape
                or n_tangent_vec_a == 1
                or n_tangent_vec_b == 1)

        # Repeat a single tangent vector to match the other batch size.
        if n_tangent_vec_a == 1:
            tangent_vec_a = gs.array([tangent_vec_a[0]] * n_tangent_vec_b)

        if n_tangent_vec_b == 1:
            tangent_vec_b = gs.array([tangent_vec_b[0]] * n_tangent_vec_a)

        # Only the first stored inner-product matrix is used, repeated
        # once per sample.
        inner_product_mat_at_identity = gs.array(
            [self.inner_product_mat_at_identity[0]]
            * max(n_tangent_vec_a, n_tangent_vec_b))

        tangent_vec_a = gs.to_ndarray(tangent_vec_a, to_ndim=2)
        tangent_vec_b = gs.to_ndarray(tangent_vec_b, to_ndim=2)
        inner_product_mat_at_identity = gs.to_ndarray(
            inner_product_mat_at_identity, to_ndim=3)

        inner_prod = gs.einsum(
            'nj,njk,nk->n',
            tangent_vec_a, inner_product_mat_at_identity, tangent_vec_b)
        inner_prod = gs.to_ndarray(inner_prod, to_ndim=2, axis=1)

    elif self.group.default_point_type == 'matrix':
        logging.warning(
            'Only the canonical inner product -Frobenius inner product-'
            ' is implemented for Lie groups whose elements are represented'
            ' by matrices.')
        tangent_vec_a = gs.to_ndarray(tangent_vec_a, to_ndim=3)
        tangent_vec_b = gs.to_ndarray(tangent_vec_b, to_ndim=3)
        aux_prod = gs.matmul(
            gs.transpose(tangent_vec_a, axes=(0, 2, 1)), tangent_vec_b)
        # Trace each [n, n] matrix (axes 1, 2). The default gs.trace axes
        # (0, 1) would sum across samples of the batch.
        inner_prod = gs.trace(aux_prod, axis1=1, axis2=2)

    return inner_prod
def inner_product(self, tangent_vec_a, tangent_vec_b, base_point):
    """Compute the affine-invariant inner product.

    Compute tr(P^{-1} A P^{-1} B) for tangent vectors A, B at the base
    point P.

    Parameters
    ----------
    tangent_vec_a : array-like, shape=[n_samples, n, n] or [n, n]
    tangent_vec_b : array-like, shape=[n_samples, n, n] or [n, n]
    base_point : array-like, shape=[n_samples, n, n] or [n, n]

    Returns
    -------
    inner_product : array-like, shape=[n_samples, 1]
    """
    # Promote single matrices to batches of one: the trace below uses
    # axes 1 and 2, which a 2D input does not have.
    tangent_vec_a = gs.to_ndarray(tangent_vec_a, to_ndim=3)
    tangent_vec_b = gs.to_ndarray(tangent_vec_b, to_ndim=3)
    base_point = gs.to_ndarray(base_point, to_ndim=3)

    inv_base_point = gs.linalg.inv(base_point)

    aux_a = gs.matmul(inv_base_point, tangent_vec_a)
    aux_b = gs.matmul(inv_base_point, tangent_vec_b)
    inner_product = gs.trace(gs.matmul(aux_a, aux_b), axis1=1, axis2=2)
    inner_product = gs.to_ndarray(inner_product, to_ndim=2, axis=1)
    return inner_product
def inner_product(self, tangent_vec_a, tangent_vec_b, base_point):
    """Compute the Procrustes inner product.

    Single matrices are promoted to batches of one and any batch of size
    one is tiled to match the others. The result is half the trace of
    (M_a @ tangent_vec_b), where M_a is
    ``self.space.inverse_differential_power(2, tangent_vec_a, base_point)``.

    Parameters
    ----------
    tangent_vec_a : array-like, shape=[n_samples, n, n] or [n, n]
    tangent_vec_b : array-like, shape=[n_samples, n, n] or [n, n]
    base_point : array-like, shape=[n_samples, n, n] or [n, n]

    Returns
    -------
    inner_product : array-like, shape=[n_samples]
    """
    tangent_vec_a = gs.to_ndarray(tangent_vec_a, to_ndim=3)
    n_a, _, _ = tangent_vec_a.shape
    tangent_vec_b = gs.to_ndarray(tangent_vec_b, to_ndim=3)
    n_b, _, _ = tangent_vec_b.shape
    base_point = gs.to_ndarray(base_point, to_ndim=3)
    n_base, _, _ = base_point.shape

    # Compatible batch sizes: every size other than 1 is the same value.
    assert len({size for size in (n_a, n_b, n_base) if size != 1}) <= 1

    if n_a == 1:
        tangent_vec_a = gs.tile(
            tangent_vec_a, (gs.maximum(n_base, n_b), 1, 1))
    if n_b == 1:
        tangent_vec_b = gs.tile(
            tangent_vec_b, (gs.maximum(n_base, n_a), 1, 1))
    if n_base == 1:
        base_point = gs.tile(
            base_point, (gs.maximum(n_a, n_b), 1, 1))

    spd_space = self.space
    transformed_a = spd_space.inverse_differential_power(
        2, tangent_vec_a, base_point)
    return gs.trace(
        gs.matmul(transformed_a, tangent_vec_b), axis1=1, axis2=2) / 2
def _aux_inner_product(self, tangent_vec_a, tangent_vec_b, inv_base_point):
    """Compute the inner product (auxiliary).

    Compute tr(P^{-1} A P^{-1} B) given the already-inverted base point
    P^{-1}, so callers can reuse one inversion for several products.

    Parameters
    ----------
    tangent_vec_a : array-like, shape=[..., n, n]
    tangent_vec_b : array-like, shape=[..., n, n]
    inv_base_point : array-like, shape=[..., n, n]

    Returns
    -------
    inner_product : array-like, shape=[...]
        One scalar per matrix (the trace removes the last two axes).
    """
    aux_a = gs.matmul(inv_base_point, tangent_vec_a)
    aux_b = gs.matmul(inv_base_point, tangent_vec_b)
    # Last two axes: equals axes (1, 2) for 3D batches and also accepts a
    # single [n, n] matrix.
    inner_product = gs.trace(gs.matmul(aux_a, aux_b), axis1=-2, axis2=-1)
    return inner_product
def _aux_inner_product(tangent_vec_a, tangent_vec_b, inv_base_point):
    """Compute the inner-product (auxiliary).

    Compute tr(P^{-1} A P^{-1} B) given the already-inverted base point
    P^{-1}.

    Parameters
    ----------
    tangent_vec_a : array-like, shape=[..., n, n]
    tangent_vec_b : array-like, shape=[..., n, n]
    inv_base_point : array-like, shape=[..., n, n]

    Returns
    -------
    inner_product : array-like, shape=[...]
        One scalar per matrix; the trace removes the last two axes.
    """
    aux_a = gs.einsum('...ij,...jk->...ik', inv_base_point, tangent_vec_a)
    aux_b = gs.einsum('...ij,...jk->...ik', inv_base_point, tangent_vec_b)
    prod = gs.einsum('...ij,...jk->...ik', aux_a, aux_b)
    inner_product = gs.trace(prod, axis1=-2, axis2=-1)
    return inner_product
def inner_product(self, tangent_vec_a, tangent_vec_b, base_point):
    """Compute the affine-invariant inner product.

    Compute tr(P^{-1} A P^{-1} B) for tangent vectors A, B at base point
    P. Single matrices are promoted to batches of one and any batch of
    size one is tiled to match the others.

    Parameters
    ----------
    tangent_vec_a : array-like, shape=[n_samples, n, n] or [n, n]
    tangent_vec_b : array-like, shape=[n_samples, n, n] or [n, n]
    base_point : array-like, shape=[n_samples, n, n] or [n, n]

    Returns
    -------
    inner_product : array-like, shape=[n_samples, 1]
    """
    tangent_vec_a = gs.to_ndarray(tangent_vec_a, to_ndim=3)
    n_a, _, _ = tangent_vec_a.shape
    tangent_vec_b = gs.to_ndarray(tangent_vec_b, to_ndim=3)
    n_b, _, _ = tangent_vec_b.shape
    base_point = gs.to_ndarray(base_point, to_ndim=3)
    n_base, _, _ = base_point.shape

    # Compatible batch sizes: every size other than 1 is the same value.
    assert len({size for size in (n_a, n_b, n_base) if size != 1}) <= 1

    if n_a == 1:
        tangent_vec_a = gs.tile(
            tangent_vec_a, (gs.maximum(n_base, n_b), 1, 1))
    if n_b == 1:
        tangent_vec_b = gs.tile(
            tangent_vec_b, (gs.maximum(n_base, n_a), 1, 1))
    if n_base == 1:
        base_point = gs.tile(
            base_point, (gs.maximum(n_a, n_b), 1, 1))

    inverse = gs.linalg.inv(base_point)
    left = gs.matmul(inverse, tangent_vec_a)
    right = gs.matmul(inverse, tangent_vec_b)

    traces = gs.trace(gs.matmul(left, right), axis1=1, axis2=2)
    return gs.to_ndarray(traces, to_ndim=2, axis=1)
def inner_product(self, tangent_vec_a, tangent_vec_b, base_point):
    """Compute the canonical inner product at base_point.

    This is tr(A^T (I - U U^T / 2) B) with U = base_point, which differs
    from the inner product induced by the embedding.
    Formula from: http://noodle.med.yale.edu/hdtag/notes/steifel_notes.pdf

    Parameters
    ----------
    tangent_vec_a : array-like, shape=[n_samples, n, k] or [n, k]
    tangent_vec_b : array-like, shape=[n_samples, n, k] or [n, k]
    base_point : array-like, shape=[n_samples, n, k] or [n, k]

    Returns
    -------
    inner_prod : array-like, shape=[n_samples, 1]
    """
    tangent_vec_a = gs.to_ndarray(tangent_vec_a, to_ndim=3)
    tangent_vec_b = gs.to_ndarray(tangent_vec_b, to_ndim=3)
    base_point = gs.to_ndarray(base_point, to_ndim=3)

    swap_last_axes = (0, 2, 1)
    middle = gs.eye(self.n) - 0.5 * gs.matmul(
        base_point, gs.transpose(base_point, axes=swap_last_axes))
    left = gs.matmul(
        gs.transpose(tangent_vec_a, axes=swap_last_axes), middle)

    traces = gs.trace(gs.matmul(left, tangent_vec_b), axis1=1, axis2=2)
    return gs.to_ndarray(traces, to_ndim=2, axis=1)
def rotation_vector_from_matrix(self, rot_mat):
    """Convert rotation matrices to rotation vectors.

    In 3D, convert rotation matrix to rotation vector
    (axis-angle representation).

    Get the angle through the trace of the rotation matrix:
    The eigenvalues are:
    1, cos(angle) + i sin(angle), cos(angle) - i sin(angle)
    so that: trace = 1 + 2 cos(angle), -1 <= trace <= 3

    Get the rotation vector through the formula:
    S_r = angle / (2 * sin(angle)) * (R - R^T)
    where S_r is the skew-symmetric matrix of the rotation vector r.

    For the edge case where the angle is close to pi,
    the axis is rebuilt from the diagonal and symmetric off-diagonal
    entries of R (the rotation matrix -> unit quaternion -> axis-angle
    route): r = angle * v / |v|, where (w, v) is a unit quaternion.

    In nD, the rotation vector stores the n(n-1)/2 values of the
    skew-symmetric matrix representing the rotation.

    Parameters
    ----------
    rot_mat : array-like, shape=[n_rot_mats, n, n] or [n, n]
        Rotation matrices; promoted to 3D and first projected with
        closest_rotation_matrix.

    Returns
    -------
    rot_vec : array-like
        Rotation vectors, passed through self.regularize.
    """
    rot_mat = gs.to_ndarray(rot_mat, to_ndim=3)
    n_rot_mats, mat_dim_1, mat_dim_2 = rot_mat.shape
    assert mat_dim_1 == mat_dim_2 == self.n

    rot_mat = closest_rotation_matrix(rot_mat)

    if self.n == 3:
        trace = gs.trace(rot_mat, axis1=1, axis2=2)
        trace = gs.to_ndarray(trace, to_ndim=2, axis=1)
        assert trace.shape == (n_rot_mats, 1), trace.shape

        cos_angle = .5 * (trace - 1)
        # Clip guards arccos against round-off slightly outside [-1, 1].
        cos_angle = gs.clip(cos_angle, -1, 1)
        angle = gs.arccos(cos_angle)

        # vector of (R - R^T), later rescaled by angle / (2 sin(angle)).
        rot_mat_transpose = gs.transpose(rot_mat, axes=(0, 2, 1))
        rot_vec = vector_from_skew_matrix(rot_mat - rot_mat_transpose)

        mask_0 = gs.isclose(angle, 0)
        mask_0 = gs.squeeze(mask_0, axis=1)

        # Near angle 0: second-order expansion of angle / (2 sin(angle)),
        # i.e. 1/2 + angle**2 / 12, using trace - 3 ~ -angle**2.
        # NOTE(review): in-place masked assignment requires a mutable
        # array backend (e.g. numpy) — confirm for other backends.
        rot_vec[mask_0] = (rot_vec[mask_0]
                           * (.5 - (trace[mask_0] - 3.) / 12.))

        mask_pi = gs.isclose(angle, gs.pi)
        mask_pi = gs.squeeze(mask_pi, axis=1)

        # choose the largest diagonal element
        # to avoid a square root of a negative number
        # NOTE(review): argmax has no axis argument, so with several
        # matrices near pi it indexes the flattened diagonals and a single
        # index `a` is shared by all of them — looks correct only for at
        # most one such matrix per call; confirm.
        a = 0
        if gs.any(mask_pi):
            a = gs.argmax(gs.diagonal(rot_mat[mask_pi], axis1=1, axis2=2))
        b = gs.mod(a + 1, 3)
        c = gs.mod(a + 2, 3)

        # compute the axis vector
        sq_root = gs.sqrt(
            (rot_mat[mask_pi, a, a]
             - rot_mat[mask_pi, b, b]
             - rot_mat[mask_pi, c, c] + 1.))
        rot_vec_pi = gs.zeros((sum(mask_pi), self.dimension))
        rot_vec_pi[:, a] = sq_root / 2.
        rot_vec_pi[:, b] = (
            (rot_mat[mask_pi, b, a]
             + rot_mat[mask_pi, a, b])
            / (2. * sq_root))
        rot_vec_pi[:, c] = (
            (rot_mat[mask_pi, c, a]
             + rot_mat[mask_pi, a, c])
            / (2. * sq_root))

        # Rescale the axis to have length `angle`.
        # NOTE(review): gs.linalg.norm(rot_vec_pi) has no axis, so it is
        # the norm of all near-pi rows together, not per row — confirm.
        rot_vec[mask_pi] = (angle[mask_pi] * rot_vec_pi
                            / gs.linalg.norm(rot_vec_pi))

        mask_else = ~mask_0 & ~mask_pi

        # Generic case: exact factor angle / (2 sin(angle)).
        rot_vec[mask_else] = (angle[mask_else]
                              / (2. * gs.sin(angle[mask_else]))
                              * rot_vec[mask_else])
    else:
        # nD: matrix logarithm at the identity, then its skew-symmetric
        # entries as a vector.
        skew_mat = self.embedding_manifold.group_log_from_identity(rot_mat)
        rot_vec = vector_from_skew_matrix(skew_mat)

    return self.regularize(rot_vec)
def _aux_inner_product(self, tangent_vec_a, tangent_vec_b, inv_base_point):
    """Compute the inner product (auxiliary).

    Compute tr(P^{-1} A P^{-1} B) given the already-inverted base point
    P^{-1}.

    Parameters
    ----------
    tangent_vec_a : array-like, shape=[..., n, n]
    tangent_vec_b : array-like, shape=[..., n, n]
    inv_base_point : array-like, shape=[..., n, n]

    Returns
    -------
    inner_product : array-like, shape=[...]
        One scalar per matrix.
    """
    aux_a = gs.matmul(inv_base_point, tangent_vec_a)
    aux_b = gs.matmul(inv_base_point, tangent_vec_b)
    # Last two axes: identical to axes (1, 2) for 3D batches, and also
    # works for a single [n, n] matrix.
    inner_product = gs.trace(gs.matmul(aux_a, aux_b), axis1=-2, axis2=-1)
    return inner_product
def rotation_vector_from_matrix(self, rot_mat):
    r"""Convert rotation matrix (in 3D) to rotation vector (axis-angle).

    Get the angle through the trace of the rotation matrix:
    The eigenvalues are:
    :math:`\{1, \cos(angle) + i \sin(angle), \cos(angle) - i \sin(angle)\}`
    so that:
    :math:`trace = 1 + 2 \cos(angle), \{-1 \leq trace \leq 3\}`

    Get the rotation vector through the formula:
    :math:`S_r = \frac{angle}{2 \sin(angle)} (R - R^T)`

    For the edge case where the angle is close to pi,
    the formulation is derived by using the following equality (see the
    Axis-angle representation on Wikipedia):
    :math:`outer(r, r) = \frac{1}{2} (R + I_3)`

    In nD, the rotation vector stores the :math:`n(n-1)/2` values of the
    skew-symmetric matrix representing the rotation.
    NOTE(review): this body uses ``gs.eye(3)`` and therefore only handles
    3x3 matrices.

    Parameters
    ----------
    rot_mat : array-like, shape=[n_rot_mats, 3, 3]
        Must be 3D: the shape is unpacked into exactly three values.

    Returns
    -------
    regularized_rot_vec : array-like, shape=[n_rot_mats, 3]
    """
    n_rot_mats, _, _ = rot_mat.shape

    trace = gs.trace(rot_mat, axis1=1, axis2=2)
    trace = gs.to_ndarray(trace, to_ndim=2, axis=1)
    # Clip so arccos stays defined despite round-off.
    trace_num = gs.clip(trace, -1, 3)
    angle = gs.arccos(0.5 * (trace_num - 1))
    rot_mat_transpose = gs.transpose(rot_mat, axes=(0, 2, 1))
    rot_vec_not_pi = self.vector_from_skew_matrix(
        rot_mat - rot_mat_transpose)
    # Float 0/1 masks so every case is combined arithmetically below
    # instead of with in-place masked assignment.
    mask_0 = gs.cast(gs.isclose(angle, 0.), gs.float32)
    mask_pi = gs.cast(gs.isclose(angle, gs.pi, atol=1e-2), gs.float32)
    mask_else = (1 - mask_0) * (1 - mask_pi)

    # Scale factor angle / (2 sin(angle)). Near 0 it uses
    # 0.5 / (1 - angle**2 / 6), since 2 sin(angle) ~ 2 angle (1 - angle**2/6).
    # Near pi the numerator is 0 and the denominator 1, so the non-pi
    # vector contributes nothing there.
    numerator = 0.5 * mask_0 + angle * mask_else
    denominator = (1 - angle**2 / 6) * mask_0 + 2 * gs.sin(angle) * mask_else + mask_pi

    rot_vec_not_pi = rot_vec_not_pi * numerator / denominator

    # Near pi: outer(r_unit, r_unit) = (R + I) / 2, so the diagonal gives
    # the squared axis components. Negative round-off values are clamped
    # to 0 before the square root.
    # NOTE(review): the return value of gs.set_diag is discarded; this
    # relies on it modifying vector_outer in place, which may not hold on
    # every backend — confirm.
    vector_outer = 0.5 * (gs.eye(3) + rot_mat)
    gs.set_diag(
        vector_outer, gs.maximum(
            0., gs.diagonal(vector_outer, axis1=1, axis2=2)))
    squared_diag_comp = gs.diagonal(vector_outer, axis1=1, axis2=2)
    diag_comp = gs.sqrt(squared_diag_comp)
    # The row with the largest norm fixes the relative signs of the axis
    # components (up to the global sign of the axis).
    norm_line = gs.linalg.norm(vector_outer, axis=2)
    max_line_index = gs.argmax(norm_line, axis=1)
    selected_line = gs.get_slice(
        vector_outer, (range(n_rot_mats), max_line_index))
    signs = gs.sign(selected_line)
    rot_vec_pi = angle * signs * diag_comp

    # Only one of the two contributions is non-zero for each matrix.
    rot_vec = rot_vec_not_pi + mask_pi * rot_vec_pi

    return self.regularize(rot_vec)