Exemplo n.º 1
0
class TestGeneralLinearGroupMethods(geomstats.tests.TestCase):
    """Tests for the group operations of GeneralLinearGroup with n = 3."""

    _multiprocess_can_split_ = True

    def setUp(self):
        """Seed the RNG and build GL(3), plus SO(3) to produce test matrices."""
        gs.random.seed(1234)
        self.n = 3
        self.n_samples = 2
        self.group = GeneralLinearGroup(n=self.n)
        # We generate invertible matrices using so3_group
        self.so3_group = SpecialOrthogonalGroup(n=self.n)

        warnings.simplefilter('ignore', category=ImportWarning)

    @geomstats.tests.np_only
    def test_belongs(self):
        """
        A rotation matrix belongs to the matrix Lie group
        of invertible matrices.
        """
        rot_vec = gs.array([0.2, -0.1, 0.1])
        rot_mat = self.so3_group.matrix_from_rotation_vector(rot_vec)
        result = self.group.belongs(rot_mat)
        expected = gs.array([True])

        self.assertAllClose(result, expected)

    def test_compose(self):
        """Composing with the identity on either side is a no-op."""
        # 1. Composition by identity, on the right
        # Expect the original transformation
        rot_vec = gs.array([0.2, -0.1, 0.1])
        mat = self.so3_group.matrix_from_rotation_vector(rot_vec)

        result = self.group.compose(mat, self.group.identity)
        expected = helper.to_matrix(mat)

        self.assertAllClose(result, expected)

        # 2. Composition by identity, on the left
        # Expect the original transformation
        rot_vec = gs.array([0.2, 0.1, -0.1])
        mat = self.so3_group.matrix_from_rotation_vector(rot_vec)

        result = self.group.compose(self.group.identity, mat)
        expected = mat

        self.assertAllClose(result, expected)

    def test_inverse(self):
        """The group inverse matches a hand-computed 3x3 matrix inverse."""
        mat = gs.array([[1., 2., 3.], [4., 5., 6.], [7., 8., 10.]])
        result = self.group.inverse(mat)
        expected = 1. / 3. * gs.array([[-2., -4., 3.], [-2., 11., -6.],
                                       [3., -6., 3.]])
        expected = helper.to_matrix(expected)

        self.assertAllClose(result, expected)

    def test_compose_and_inverse(self):
        """Composing a matrix with its inverse on either side gives identity."""
        # 1. Compose transformation by its inverse on the right
        # Expect the group identity
        rot_vec = gs.array([0.2, 0.1, 0.1])
        mat = self.so3_group.matrix_from_rotation_vector(rot_vec)
        inv_mat = self.group.inverse(mat)

        result = self.group.compose(mat, inv_mat)
        expected = self.group.identity
        expected = helper.to_matrix(expected)

        self.assertAllClose(result, expected)

        # 2. Compose transformation by its inverse on the left
        # Expect the group identity
        rot_vec = gs.array([0.7, 0.1, 0.1])
        mat = self.so3_group.matrix_from_rotation_vector(rot_vec)
        inv_mat = self.group.inverse(mat)

        result = self.group.compose(inv_mat, mat)
        expected = self.group.identity
        expected = helper.to_matrix(expected)

        self.assertAllClose(result, expected)

    @geomstats.tests.np_and_tf_only
    def test_group_log_and_exp(self):
        """group_exp(group_log(p)) recovers p for a scaled identity."""
        point = 5 * gs.eye(self.n)

        group_log = self.group.group_log(point)
        result = self.group.group_exp(group_log)
        expected = point
        expected = helper.to_matrix(expected)

        self.assertAllClose(result, expected)

    @geomstats.tests.np_and_tf_only
    def test_group_exp_vectorization(self):
        """group_exp of a batch of diagonal matrices exponentiates entries."""
        point = gs.array([[[2., 0., 0.], [0., 3., 0.], [0., 0., 4.]],
                          [[1., 0., 0.], [0., 5., 0.], [0., 0., 6.]]])

        expected = gs.array([[[7.38905609, 0., 0.], [0., 20.0855369, 0.],
                              [0., 0., 54.5981500]],
                             [[2.718281828, 0., 0.], [0., 148.413159, 0.],
                              [0., 0., 403.42879349]]])

        result = self.group.group_exp(point)

        self.assertAllClose(result, expected, rtol=1e-3)

    @geomstats.tests.np_and_tf_only
    def test_group_log_vectorization(self):
        """group_log of a batch of diagonal matrices takes log of entries."""
        point = gs.array([[[2., 0., 0.], [0., 3., 0.], [0., 0., 4.]],
                          [[1., 0., 0.], [0., 5., 0.], [0., 0., 6.]]])

        expected = gs.array([[[0.693147180, 0., 0.], [0., 1.09861228866, 0.],
                              [0., 0., 1.38629436]],
                             [[0., 0., 0.], [0., 1.609437912, 0.],
                              [0., 0., 1.79175946]]])

        result = self.group.group_log(point)

        self.assertAllClose(result, expected, atol=1e-4)

    @geomstats.tests.np_and_tf_only
    def test_expm_and_logm_vectorization_symmetric(self):
        """group_exp(group_log(p)) recovers a batch of diagonal matrices."""
        point = gs.array([[[2., 0., 0.], [0., 3., 0.], [0., 0., 4.]],
                          [[1., 0., 0.], [0., 5., 0.], [0., 0., 6.]]])
        result = self.group.group_exp(self.group.group_log(point))
        expected = point

        self.assertAllClose(result, expected)
Exemplo n.º 2
0
class SpecialEuclideanGroup(LieGroup):
    """
    Class for the special euclidean group SE(n),
    i.e. the Lie group of rigid transformations.

    In the 'vector' representation (selected when n == 3), a point is a
    vector of length self.dimension: the first self.rotations.dimension
    entries are a rotation vector, the last n entries are a translation.
    """

    def __init__(self, n):
        """
        Build SE(n) for an integer n > 1.

        The dimension is n * (n - 1) / 2 + n. The identity is the zero
        vector of that length.
        """
        assert isinstance(n, int) and n > 1

        self.n = n
        self.dimension = int((n * (n - 1)) / 2 + n)
        super(SpecialEuclideanGroup, self).__init__(
                          dimension=self.dimension,
                          identity=gs.zeros(self.dimension))
        # TODO(nina): keep the names rotations and translations here?
        self.rotations = SpecialOrthogonalGroup(n=n)
        self.translations = EuclideanSpace(dimension=n)
        self.point_representation = 'vector' if n == 3 else 'matrix'

    def belongs(self, point):
        """
        Evaluate if a point belongs to SE(n).

        Only the length of the last axis is checked against self.dimension.
        A single bool is returned for the whole batch.
        """
        point = gs.to_ndarray(point, to_ndim=2)
        _, point_dim = point.shape
        return point_dim == self.dimension

    def regularize(self, point):
        """
        Regularize a point to the canonical representation
        chosen for SE(n).

        The rotation part is regularized by SO(n). The translation part is
        copied unchanged. Requires the 'vector' point representation.
        """
        assert self.point_representation == 'vector'

        point = gs.to_ndarray(point, to_ndim=2)
        assert self.belongs(point)

        rotations = self.rotations
        dim_rotations = rotations.dimension

        regularized_point = gs.zeros_like(point)
        rot_vec = point[:, :dim_rotations]
        regularized_point[:, :dim_rotations] = rotations.regularize(rot_vec)
        regularized_point[:, dim_rotations:] = point[:, dim_rotations:]

        return regularized_point

    def regularize_tangent_vec_at_identity(self, tangent_vec, metric=None):
        """Regularize a tangent vector at the identity of SE(n)."""
        return self.regularize_tangent_vec(tangent_vec, self.identity, metric)

    def regularize_tangent_vec(self, tangent_vec, base_point, metric=None):
        """
        Regularize a tangent vector at base_point.

        The rotation part is regularized by SO(n) with respect to the
        restriction of `metric` (default: self.left_canonical_metric) to the
        rotation block. The translation part is copied unchanged.
        """
        if metric is None:
            metric = self.left_canonical_metric

        tangent_vec = gs.to_ndarray(tangent_vec, to_ndim=2)
        base_point = gs.to_ndarray(base_point, to_ndim=2)

        rotations = self.rotations
        dim_rotations = rotations.dimension

        rot_tangent_vec = tangent_vec[:, :dim_rotations]
        rot_base_point = base_point[:, :dim_rotations]

        # Restrict the inner product at identity to the rotation block
        # to obtain an invariant metric on SO(n).
        metric_mat = metric.inner_product_mat_at_identity
        rot_metric_mat = metric_mat[:, :dim_rotations, :dim_rotations]
        rot_metric = InvariantMetric(
                               group=rotations,
                               inner_product_mat_at_identity=rot_metric_mat,
                               left_or_right=metric.left_or_right)

        regularized_vec = gs.zeros_like(tangent_vec)
        regularized_vec[:, :dim_rotations] = rotations.regularize_tangent_vec(
                                                 tangent_vec=rot_tangent_vec,
                                                 base_point=rot_base_point,
                                                 metric=rot_metric)
        regularized_vec[:, dim_rotations:] = tangent_vec[:, dim_rotations:]

        return regularized_vec

    def compose(self, point_1, point_2):
        """
        Compose two elements of SE(n).

        Formula:
        point_1 . point_2 = [R1 * R2, (R1 * t2) + t1]
        where:
        R1, R2 are rotation matrices,
        t1, t2 are translation vectors.

        Both inputs must have the same number of points, or one of them must
        contain a single point that is broadcast against the other.
        """
        rotations = self.rotations
        dim_rotations = rotations.dimension

        point_1 = self.regularize(point_1)
        point_2 = self.regularize(point_2)

        n_points_1, _ = point_1.shape
        n_points_2, _ = point_2.shape

        assert (point_1.shape == point_2.shape
                or n_points_1 == 1
                or n_points_2 == 1)

        rot_vec_1 = point_1[:, :dim_rotations]
        rot_mat_1 = rotations.matrix_from_rotation_vector(rot_vec_1)
        rot_mat_1 = so_group.closest_rotation_matrix(rot_mat_1)

        rot_vec_2 = point_2[:, :dim_rotations]
        rot_mat_2 = rotations.matrix_from_rotation_vector(rot_vec_2)
        rot_mat_2 = so_group.closest_rotation_matrix(rot_mat_2)

        translation_1 = point_1[:, dim_rotations:]
        translation_2 = point_2[:, dim_rotations:]

        n_compositions = gs.maximum(n_points_1, n_points_2)
        composition_rot_mat = gs.matmul(rot_mat_1, rot_mat_2)
        composition_rot_vec = rotations.rotation_vector_from_matrix(
                                                          composition_rot_mat)
        composition_translation = gs.zeros((n_compositions, self.n))
        for i in range(n_compositions):
            translation_1_i = (translation_1[0] if n_points_1 == 1
                               else translation_1[i])
            rot_mat_1_i = (rot_mat_1[0] if n_points_1 == 1
                           else rot_mat_1[i])
            translation_2_i = (translation_2[0] if n_points_2 == 1
                               else translation_2[i])
            # t2 . R1^T as a row vector equals (R1 t2)^T.
            composition_translation[i] = (gs.dot(translation_2_i,
                                                 gs.transpose(rot_mat_1_i))
                                          + translation_1_i)

        composition = gs.zeros((n_compositions, self.dimension))
        composition[:, :dim_rotations] = composition_rot_vec
        composition[:, dim_rotations:] = composition_translation

        composition = self.regularize(composition)
        return composition

    def inverse(self, point):
        """
        Compute the group inverse in SE(n).

        Formula:
        (R, t)^{-1} = (R^{-1}, R^{-1}.(-t))
        """
        rotations = self.rotations
        dim_rotations = rotations.dimension

        point = self.regularize(point)
        n_points, _ = point.shape

        rot_vec = point[:, :dim_rotations]
        translation = point[:, dim_rotations:]

        inverse_point = gs.zeros_like(point)
        # Negating the rotation vector inverts the rotation.
        inverse_rotation = -rot_vec

        inv_rot_mat = rotations.matrix_from_rotation_vector(inverse_rotation)

        inverse_translation = gs.zeros((n_points, self.n))
        for i in range(n_points):
            inverse_translation[i] = gs.dot(-translation[i],
                                            gs.transpose(inv_rot_mat[i]))

        inverse_point[:, :dim_rotations] = inverse_rotation
        inverse_point[:, dim_rotations:] = inverse_translation

        inverse_point = self.regularize(inverse_point)
        return inverse_point

    def jacobian_translation(self, point, left_or_right='left'):
        """
        Compute the jacobian matrix of the differential
        of the left/right translations from the identity to point in SE(n).

        Returns an array of shape (n_points, dimension, dimension).
        - 'left': block diagonal, with the SO(n) left jacobian and the
          rotation matrix of `point`.
        - 'right': the SO(n) right jacobian, minus the skew matrix of the
          rotation vector in the lower-left block, and the identity in the
          translation block.
        """
        assert self.belongs(point)
        assert left_or_right in ('left', 'right')

        dim = self.dimension
        rotations = self.rotations
        dim_rotations = rotations.dimension

        point = self.regularize(point)
        n_points, _ = point.shape

        rot_vec = point[:, :dim_rotations]

        jacobian = gs.zeros((n_points,) + (dim,) * 2)

        if left_or_right == 'left':
            jacobian_rot = self.rotations.jacobian_translation(
                                                      point=rot_vec,
                                                      left_or_right='left')
            jacobian_trans = self.rotations.matrix_from_rotation_vector(
                    rot_vec)

            jacobian[:, :dim_rotations, :dim_rotations] = jacobian_rot
            jacobian[:, dim_rotations:, dim_rotations:] = jacobian_trans

        else:
            jacobian_rot = self.rotations.jacobian_translation(
                                                      point=rot_vec,
                                                      left_or_right='right')

            inv_skew_mat = - so_group.skew_matrix_from_vector(rot_vec)
            jacobian[:, :dim_rotations, :dim_rotations] = jacobian_rot
            jacobian[:, dim_rotations:, :dim_rotations] = inv_skew_mat
            jacobian[:, dim_rotations:, dim_rotations:] = gs.eye(self.n)

        assert jacobian.ndim == 3
        return jacobian

    def group_exp_from_identity(self, tangent_vec):
        """
        Compute the group exponential of the tangent vector at the identity.

        The rotation part of the result is the regularized rotation vector.
        The translation part is t + c1 * K t + c2 * K^2 t, where K is the
        skew matrix of the rotation vector, theta its norm,
        c1 = (1 - cos(theta)) / theta^2 and c2 = (theta - sin(theta)) / theta^3.
        Taylor expansions are used when theta is close to 0.
        """
        tangent_vec = gs.to_ndarray(tangent_vec, to_ndim=2)

        rotations = self.rotations
        dim_rotations = rotations.dimension

        rot_vec = tangent_vec[:, :dim_rotations]
        rot_vec = self.rotations.regularize(rot_vec)
        translation = tangent_vec[:, dim_rotations:]

        angle = gs.linalg.norm(rot_vec, axis=1)
        angle = gs.to_ndarray(angle, to_ndim=2, axis=1)

        mask_close_pi = gs.isclose(angle, gs.pi)
        mask_close_pi = gs.squeeze(mask_close_pi, axis=1)
        rot_vec[mask_close_pi] = rotations.regularize(
                                       rot_vec[mask_close_pi])

        skew_mat = so_group.skew_matrix_from_vector(rot_vec)
        sq_skew_mat = gs.matmul(skew_mat, skew_mat)

        mask_0 = gs.equal(angle, 0)
        mask_close_0 = gs.isclose(angle, 0) & ~mask_0

        mask_0 = gs.squeeze(mask_0, axis=1)
        mask_close_0 = gs.squeeze(mask_close_0, axis=1)

        mask_else = ~mask_0 & ~mask_close_0

        coef_1 = gs.zeros_like(angle)
        coef_2 = gs.zeros_like(angle)

        # Limits of the closed forms at angle == 0.
        coef_1[mask_0] = 1. / 2.
        coef_2[mask_0] = 1. / 6.

        # Taylor expansions near 0, avoiding cancellation in the closed forms.
        coef_1[mask_close_0] = (1. / 2. - angle[mask_close_0] ** 2 / 24.
                                + angle[mask_close_0] ** 4 / 720.
                                - angle[mask_close_0] ** 6 / 40320.)
        coef_2[mask_close_0] = (1. / 6. - angle[mask_close_0] ** 2 / 120.
                                + angle[mask_close_0] ** 4 / 5040.
                                - angle[mask_close_0] ** 6 / 362880.)

        coef_1[mask_else] = ((1. - gs.cos(angle[mask_else]))
                             / angle[mask_else] ** 2)
        coef_2[mask_else] = ((angle[mask_else] - gs.sin(angle[mask_else]))
                             / angle[mask_else] ** 3)

        n_tangent_vecs, _ = tangent_vec.shape
        group_exp_translation = gs.zeros((n_tangent_vecs, self.n))
        for i in range(n_tangent_vecs):
            translation_i = translation[i]
            term_1_i = coef_1[i] * gs.dot(translation_i,
                                          gs.transpose(skew_mat[i]))
            term_2_i = coef_2[i] * gs.dot(translation_i,
                                          gs.transpose(sq_skew_mat[i]))

            group_exp_translation[i] = translation_i + term_1_i + term_2_i

        group_exp = gs.zeros_like(tangent_vec)
        group_exp[:, :dim_rotations] = rot_vec
        group_exp[:, dim_rotations:] = group_exp_translation

        group_exp = self.regularize(group_exp)
        return group_exp

    def group_log_from_identity(self, point):
        """
        Compute the group logarithm of the point at the identity.

        The rotation part of the result is the regularized rotation vector.
        The translation part is t - 1/2 * K t + c2 * K^2 t, where K is the
        skew matrix of the rotation vector, theta its norm,
        c2 = (1 - psi) / theta^2 and
        psi = theta * sin(theta) / (2 * (1 - cos(theta))).
        Taylor expansions are used when theta is close to 0 or to pi.
        """
        assert self.belongs(point)
        point = self.regularize(point)

        rotations = self.rotations
        dim_rotations = rotations.dimension

        rot_vec = point[:, :dim_rotations]
        angle = gs.linalg.norm(rot_vec, axis=1)
        angle = gs.to_ndarray(angle, to_ndim=2, axis=1)

        translation = point[:, dim_rotations:]

        group_log = gs.zeros_like(point)
        group_log[:, :dim_rotations] = rot_vec
        skew_rot_vec = so_group.skew_matrix_from_vector(rot_vec)
        sq_skew_rot_vec = gs.matmul(skew_rot_vec, skew_rot_vec)

        mask_close_0 = gs.isclose(angle, 0)
        mask_close_0 = gs.squeeze(mask_close_0, axis=1)

        mask_close_pi = gs.isclose(angle, gs.pi)
        mask_close_pi = gs.squeeze(mask_close_pi, axis=1)

        mask_else = ~mask_close_0 & ~mask_close_pi

        coef_1 = - 0.5 * gs.ones_like(angle)
        coef_2 = gs.zeros_like(angle)

        coef_2[mask_close_0] = (1. / 12. + angle[mask_close_0] ** 2 / 720.
                                + angle[mask_close_0] ** 4 / 30240.
                                + angle[mask_close_0] ** 6 / 1209600.)

        delta_angle = angle[mask_close_pi] - gs.pi
        coef_2[mask_close_pi] = (1. / PI2
                                 + (PI2 - 8.) * delta_angle / (4. * PI3)
                                 - ((PI2 - 12.)
                                    * delta_angle ** 2 / (4. * PI4))
                                 + ((-192. + 12. * PI2 + PI4)
                                    * delta_angle ** 3 / (48. * PI5))
                                 - ((-240. + 12. * PI2 + PI4)
                                    * delta_angle ** 4 / (48. * PI6))
                                 + ((-2880. + 120. * PI2 + 10. * PI4 + PI6)
                                    * delta_angle ** 5 / (480. * PI7))
                                 - ((-3360 + 120. * PI2 + 10. * PI4 + PI6)
                                    * delta_angle ** 6 / (480. * PI8)))

        psi = (0.5 * angle[mask_else]
               * gs.sin(angle[mask_else]) / (1 - gs.cos(angle[mask_else])))
        coef_2[mask_else] = (1 - psi) / (angle[mask_else] ** 2)

        n_points, _ = point.shape
        group_log_translation = gs.zeros((n_points, self.n))
        for i in range(n_points):
            translation_i = translation[i]
            term_1_i = coef_1[i] * gs.dot(translation_i,
                                          gs.transpose(skew_rot_vec[i]))
            term_2_i = coef_2[i] * gs.dot(translation_i,
                                          gs.transpose(sq_skew_rot_vec[i]))
            group_log_translation[i] = translation_i + term_1_i + term_2_i

        group_log[:, dim_rotations:] = group_log_translation

        assert group_log.ndim == 2

        return group_log

    def random_uniform(self, n_samples=1):
        """
        Sample in SE(n) with the uniform distribution.

        The rotation vector and the translation are sampled independently by
        SO(n) and the Euclidean space, then concatenated and regularized.
        """
        random_rot_vec = self.rotations.random_uniform(n_samples)
        random_translation = self.translations.random_uniform(n_samples)

        random_transfo = gs.concatenate([random_rot_vec, random_translation],
                                        axis=1)
        random_transfo = self.regularize(random_transfo)
        return random_transfo

    def exponential_matrix(self, rot_vec):
        """
        Compute, for each rotation vector, the matrix
        V = I + c1 * K + c2 * K^2.

        Here K is the skew matrix of the rotation vector, theta its norm,
        c1 = (1 - cos(theta)) / theta^2 and
        c2 = (1 - sin(theta) / theta) / theta^2.
        These are the same coefficients that group_exp_from_identity applies
        to the translation part.

        Returns an array of shape (n_rot_vecs, n, n).
        """

        rot_vec = self.rotations.regularize(rot_vec)
        n_rot_vecs, _ = rot_vec.shape

        angle = gs.linalg.norm(rot_vec, axis=1)
        angle = gs.to_ndarray(angle, to_ndim=2, axis=1)

        skew_rot_vec = so_group.skew_matrix_from_vector(rot_vec)

        coef_1 = gs.empty_like(angle)
        coef_2 = gs.empty_like(coef_1)

        mask_0 = gs.equal(angle, 0)
        mask_0 = gs.squeeze(mask_0, axis=1)
        mask_close_to_0 = gs.isclose(angle, 0)
        mask_close_to_0 = gs.squeeze(mask_close_to_0, axis=1)
        mask_else = ~mask_0 & ~mask_close_to_0

        # Taylor expansions near 0:
        # (1 - cos t) / t^2 = 1/2 - t^2/24 + O(t^4)
        # (1 - sin(t)/t) / t^2 = 1/6 - t^2/120 + O(t^4)
        coef_1[mask_close_to_0] = (1. / 2.
                                   - angle[mask_close_to_0] ** 2 / 24.)
        coef_2[mask_close_to_0] = (1. / 6.
                                   - angle[mask_close_to_0] ** 2 / 120.)

        # At angle exactly 0 the rotation vector is zero, so its skew matrix
        # vanishes and these values do not change the result. Use the limits
        # so that the coefficients stay continuous.
        coef_1[mask_0] = 1. / 2.
        coef_2[mask_0] = 1. / 6.

        coef_1[mask_else] = (angle[mask_else] ** (-2)
                             * (1. - gs.cos(angle[mask_else])))
        coef_2[mask_else] = (angle[mask_else] ** (-2)
                             * (1. - (gs.sin(angle[mask_else])
                                      / angle[mask_else])))

        term_1 = gs.zeros((n_rot_vecs, self.n, self.n))
        term_2 = gs.zeros_like(term_1)

        for i in range(n_rot_vecs):
            term_1[i] = gs.eye(self.n) + skew_rot_vec[i] * coef_1[i]
            term_2[i] = gs.matmul(skew_rot_vec[i], skew_rot_vec[i]) * coef_2[i]

        exponential_mat = term_1 + term_2
        assert exponential_mat.ndim == 3

        return exponential_mat

    def group_exponential_barycenter(self, points, weights=None):
        """
        Compute the group exponential barycenter in SE(n).

        The rotation part is SO(n)'s group exponential barycenter of the
        rotation vectors. The translation part solves a weighted linear
        system built from exponential_matrix. Weights default to ones.
        Returns an array of shape (1, dimension).
        """

        n_points = points.shape[0]
        assert n_points > 0

        if weights is None:
            weights = gs.ones((n_points, 1))

        weights = gs.to_ndarray(weights, to_ndim=2, axis=1)
        n_weights, _ = weights.shape
        assert n_points == n_weights

        dim = self.dimension
        rotations = self.rotations
        dim_rotations = rotations.dimension

        rotation_vectors = points[:, :dim_rotations]
        translations = points[:, dim_rotations:dim]
        assert rotation_vectors.shape == (n_points, dim_rotations)
        assert translations.shape == (n_points, self.n)

        mean_rotation = rotations.group_exponential_barycenter(
                                                points=rotation_vectors,
                                                weights=weights)
        mean_rotation_mat = rotations.matrix_from_rotation_vector(
                    mean_rotation)

        matrix = gs.zeros((1,) + (self.n,) * 2)
        translation_aux = gs.zeros((1, self.n))

        inv_rot_mats = rotations.matrix_from_rotation_vector(
                -rotation_vectors)
        # TODO(nina): this is the same mat multiplied several times
        matrix_aux = gs.matmul(mean_rotation_mat, inv_rot_mats)
        assert matrix_aux.shape == (n_points,) + (dim_rotations,) * 2

        vec_aux = rotations.rotation_vector_from_matrix(matrix_aux)
        matrix_aux = self.exponential_matrix(vec_aux)
        matrix_aux = gs.linalg.inv(matrix_aux)

        for i in range(n_points):
            matrix += weights[i] * matrix_aux[i]
            translation_aux += weights[i] * gs.dot(gs.matmul(
                                                        matrix_aux[i],
                                                        inv_rot_mats[i]),
                                                   translations[i])

        mean_translation = gs.dot(translation_aux,
                                  gs.transpose(gs.linalg.inv(matrix),
                                               axes=(0, 2, 1)))

        exp_bar = gs.zeros((1, dim))
        exp_bar[0, :dim_rotations] = mean_rotation
        exp_bar[0, dim_rotations:dim] = mean_translation

        return exp_bar
Exemplo n.º 3
0
def _restore_latest_checkpoint(sess):
    """Restore saver-managed variables from the newest checkpoint in FLAGS.model_dir."""
    ckpt_file = tf.train.latest_checkpoint(FLAGS.model_dir)
    tf.train.Saver().restore(sess, ckpt_file)
    print('restoring parameters from', ckpt_file)


def _save_and_score(i, _image, fixed_image_sitk, rx, tx):
    """Resample the volume with (rx, tx), dump both slices, compute metrics.

    Writes imgdump/image_{i}_true.png and imgdump/image_{i}_pred.png, then
    calls the PSNR / MSE / SSIM / correlation helpers on (pred, true).
    """
    image_true = np.squeeze(_image)
    image_pred = resample_sitk(fixed_image_sitk, rx, tx)

    imageio.imsave('imgdump/image_{}_true.png'.format(i), _image[0, ...])
    imageio.imsave('imgdump/image_{}_pred.png'.format(i), image_pred)

    calc_psnr(image_pred, image_true)
    calc_mse(image_pred, image_true)
    calc_ssim(image_pred, image_true)
    calc_correlation(image_pred, image_true)


def main(argv):
    """Evaluate a trained pose regressor on the test slices of one subject.

    For each of FLAGS.n_iter slices from the subject's test TFRecord:
    1. Run the network FLAGS.n_samples times on the same input and average
       the predictions. The samples are presumably stochastic (e.g. dropout
       active); TODO confirm.
    2. Resample the subject's NIfTI volume at the predicted pose.
    3. Save the images and compute the similarity metrics.

    FLAGS.loss selects the output parametrisation: 'PoseNet', 'AP' or 'SE3'.
    """

    # TF Record
    datafiles = FLAGS.data_dir + '/test/' + FLAGS.subject_id + '.tfrecord'
    dataset = tf.data.TFRecordDataset(datafiles)
    dataset = dataset.map(_parse_function_ifind)
    # dataset = dataset.repeat()
    # dataset = dataset.shuffle(FLAGS.queue_buffer)
    dataset = dataset.batch(1)
    image, vec, qt, AP1, AP2, AP3 = dataset.make_one_shot_iterator().get_next()

    # Nifti Volume
    subject_path = FLAGS.scan_dir + '/test/' + FLAGS.subject_id + '.nii.gz'
    fixed_image_sitk_tmp = sitk.ReadImage(subject_path, sitk.sitkFloat32)
    fixed_image_sitk = sitk.GetImageFromArray(
        sitk.GetArrayFromImage(fixed_image_sitk_tmp))
    fixed_image_sitk = sitk.RescaleIntensity(fixed_image_sitk, 0, 1)  # * 255.

    # Network Definition
    image_input = tf.placeholder(shape=[1, 224, 224, 1], dtype=tf.float32)
    image_resized = tf.image.resize_images(image, size=[224, 224])

    if FLAGS.loss == 'PoseNet':

        y_pred, _ = inception.inception_v3(image_input, num_classes=7)
        quaternion_pred, translation_pred = tf.split(y_pred, [4, 3], axis=1)

        sess = tf.Session()
        _restore_latest_checkpoint(sess)

        SO3_GROUP = SpecialOrthogonalGroup(3)

        for i in range(FLAGS.n_iter):

            _image, _image_resized, _quaternion_true, _translation_true = \
                sess.run([image, image_resized, qt, AP2])

            _quaternion_pred_sample = []
            _translation_pred_sample = []
            for _ in range(FLAGS.n_samples):
                _quaternion_pred_i, _translation_pred_i = \
                    sess.run([quaternion_pred, translation_pred],
                             feed_dict={image_input: _image_resized})
                _quaternion_pred_sample.append(_quaternion_pred_i)
                _translation_pred_sample.append(_translation_pred_i)
                print(_quaternion_pred_i, _translation_pred_i)

            _quaternion_pred_sample = np.vstack(_quaternion_pred_sample)
            _rotvec_pred_sample = SO3_GROUP.rotation_vector_from_quaternion(
                _quaternion_pred_sample)
            _rotvec_pred = SO3_GROUP.left_canonical_metric.mean(
                _rotvec_pred_sample)

            _quaternion_pred = SO3_GROUP.quaternion_from_rotation_vector(
                _rotvec_pred)
            # keepdims=True keeps the (1, 3) batch layout, so that [0] below
            # selects the translation vector rather than its first component.
            _translation_pred = np.mean(np.vstack(_translation_pred_sample),
                                        axis=0, keepdims=True)

            # _quaternion_pred_variance = SO3_GROUP.left_canonical_metric.variance(_rotvec_pred_sample)
            # NOTE(review): computed but not used below.
            _translation_pred_variance = np.var(
                np.vstack(_translation_pred_sample), axis=0)

            rx = SO3_GROUP.matrix_from_quaternion(_quaternion_pred)[0]
            tx = _translation_pred[0] * 60.

            _save_and_score(i, _image, fixed_image_sitk, rx, tx)

    elif FLAGS.loss == 'AP':

        y_pred, _ = inception.inception_v3(image_input, num_classes=9)
        AP1_pred, AP2_pred, AP3_pred = tf.split(y_pred, 3, axis=1)

        sess = tf.Session()
        _restore_latest_checkpoint(sess)

        for i in range(FLAGS.n_iter):

            _image, _image_resized, _AP1, _AP2, _AP3 = \
                sess.run([image, image_resized, AP1, AP2, AP3])

            _AP1_sample = []
            _AP2_sample = []
            _AP3_sample = []
            for _ in range(FLAGS.n_samples):
                _AP1_pred_i, _AP2_pred_i, _AP3_pred_i = \
                    sess.run([AP1_pred, AP2_pred, AP3_pred],
                             feed_dict={image_input: _image_resized})
                _AP1_sample.append(_AP1_pred_i)
                _AP2_sample.append(_AP2_pred_i)
                _AP3_sample.append(_AP3_pred_i)

            # keepdims=True keeps the (1, 3) batch layout, so that [0] below
            # selects each anchor point rather than its first coordinate.
            _AP1_pred = np.mean(np.vstack(_AP1_sample), axis=0, keepdims=True)
            _AP2_pred = np.mean(np.vstack(_AP2_sample), axis=0, keepdims=True)
            _AP3_pred = np.mean(np.vstack(_AP3_sample), axis=0, keepdims=True)

            # NOTE(review): variances and distances are computed but not
            # used below.
            _AP1_pred_variance = np.var(np.vstack(_AP1_sample), axis=0)
            _AP2_pred_variance = np.var(np.vstack(_AP2_sample), axis=0)
            _AP3_pred_variance = np.var(np.vstack(_AP3_sample), axis=0)

            dist_ap1 = np.linalg.norm(_AP1 - _AP1_pred)
            dist_ap2 = np.linalg.norm(_AP2 - _AP2_pred)
            dist_ap3 = np.linalg.norm(_AP3 - _AP3_pred)

            rx = matrix_from_anchor_points(_AP1_pred[0], _AP2_pred[0],
                                           _AP3_pred[0])
            tx = _AP2_pred[0] * 60.

            _save_and_score(i, _image, fixed_image_sitk, rx, tx)

    elif FLAGS.loss == 'SE3':

        y_pred, _ = inception.inception_v3(image_input, num_classes=6)

        sess = tf.Session()
        _restore_latest_checkpoint(sess)

        SO3_GROUP = SpecialOrthogonalGroup(3)
        SE3_GROUP = SpecialEuclideanGroup(3)

        for i in range(FLAGS.n_iter):

            print(i)

            _image, _image_resized, _rvec, _tvec = \
                sess.run([image, image_resized, vec, AP2])

            _y_pred_sample = []
            for _ in range(FLAGS.n_samples):
                _y_pred_i = sess.run([y_pred],
                                     feed_dict={image_input: _image_resized})
                _y_pred_sample.append(_y_pred_i[0])

            _y_pred_sample = np.vstack(_y_pred_sample)
            _y_pred = SE3_GROUP.left_canonical_metric.mean(_y_pred_sample)
            # NOTE(review): computed but not used below.
            _y_pred_variance = SE3_GROUP.left_canonical_metric.variance(
                _y_pred_sample)

            # Layout of _y_pred: [rotation vector (3), translation (3)].
            rx = SO3_GROUP.matrix_from_rotation_vector(_y_pred[0, :3])[0]
            tx = _y_pred[0, 3:] * 60.

            _save_and_score(i, _image, fixed_image_sitk, rx, tx)

    else:
        print('Invalid Option:', FLAGS.loss)
        raise SystemExit
Exemplo n.º 4
0
def _restore_latest_checkpoint(sess):
    """Restore all saved variables into ``sess`` from FLAGS.model_dir.

    Must be called after the network graph has been built, since the
    ``tf.train.Saver`` collects the variables that exist at creation time.
    """
    ckpt_file = tf.train.latest_checkpoint(FLAGS.model_dir)
    tf.train.Saver().restore(sess, ckpt_file)
    print('restoring parameters from', ckpt_file)


def _save_and_score(i, _image, image_pred, metrics):
    """Save the i-th true/predicted image pair and record their similarity.

    Args:
        i: iteration index, used in the output file names.
        _image: batch fetched from the dataset (batch size 1); its first
            element is saved and its squeezed version is the reference.
        image_pred: image resampled from the volume at the predicted pose.
        metrics: dict of lists keyed 'CC', 'MSE', 'PSNR', 'SSIM'; one score
            is appended to each.
    """
    image_true = np.squeeze(_image)
    imageio.imsave('imgdump/image_{}_true.png'.format(i),
                   np.uint8(_image[0, ...]))
    imageio.imsave('imgdump/image_{}_pred.png'.format(i),
                   np.uint8(image_pred))

    metrics['CC'].append(calc_correlation(image_pred, image_true))
    metrics['MSE'].append(calc_mse(image_pred, image_true))
    metrics['PSNR'].append(calc_psnr(image_pred, image_true))
    metrics['SSIM'].append(calc_ssim(image_pred, image_true))


def main(argv):
    """Evaluate a trained pose-regression network on one test subject.

    Restores the newest checkpoint from ``FLAGS.model_dir`` and predicts a
    pose for each of ``FLAGS.n_iter`` test images. The subject volume is
    resampled at the predicted pose. True/predicted image pairs are saved
    to ``imgdump/``, and the median CC, MSE, PSNR and SSIM are printed.

    The network head depends on ``FLAGS.loss``:
      * ``'PoseNet'``: quaternion + translation (7 outputs),
      * ``'AP'``: three anchor points (9 outputs),
      * ``'SE3'``: rotation vector + translation (6 outputs). This mode also
        prints normalized inverse-variance weights of the SE(3) errors.

    Args:
        argv: command-line arguments (unused).

    Raises:
        SystemExit: if ``FLAGS.loss`` is not a supported option.
    """
    import os  # local import: only used to create the image dump directory

    # Fail fast, before any data or volume is loaded.
    if FLAGS.loss not in ('PoseNet', 'AP', 'SE3'):
        print('Invalid Option:', FLAGS.loss)
        raise SystemExit

    # TF Record
    datafiles = FLAGS.data_dir + '/test/' + FLAGS.subject_id + '.tfrecord'
    dataset = tf.data.TFRecordDataset(datafiles)
    dataset = dataset.map(_parse_function_ifind)
    dataset = dataset.batch(1)
    # The quaternion and the first/third anchor points are not needed for
    # evaluation; all components still come from the same get_next() op.
    image, vec, _qt, _ap1, AP2, _ap3 = \
        dataset.make_one_shot_iterator().get_next()

    # Nifti Volume
    subject_path = FLAGS.scan_dir + '/test/' + FLAGS.subject_id + '.nii.gz'
    fixed_image_sitk_tmp = sitk.ReadImage(subject_path, sitk.sitkFloat32)
    # Rebuilding the image from its voxel array drops the original spatial
    # metadata (origin/spacing/direction) -- presumably so that
    # resample_sitk works in voxel coordinates; confirm.
    fixed_image_sitk = sitk.GetImageFromArray(
        sitk.GetArrayFromImage(fixed_image_sitk_tmp))
    fixed_image_sitk = sitk.RescaleIntensity(fixed_image_sitk, 0, 1) * 255.

    # Network Definition
    image_resized = tf.image.resize_images(image, size=[224, 224])

    # imageio.imsave does not create missing directories.
    os.makedirs('imgdump', exist_ok=True)

    # Measurements
    metrics = {'CC': [], 'MSE': [], 'PSNR': [], 'SSIM': []}

    # NOTE(review): predicted translations are multiplied by 60 below;
    # presumably targets were scaled by 1/60 during training -- confirm.
    if FLAGS.loss == 'PoseNet':

        y_pred, _ = inception.inception_v3(image_resized,
                                           num_classes=7,
                                           is_training=False)
        quaternion_pred, translation_pred = tf.split(y_pred, [4, 3], axis=1)

        with tf.Session() as sess:
            _restore_latest_checkpoint(sess)
            so3_group = SpecialOrthogonalGroup(3)

            for i in tqdm.tqdm(range(FLAGS.n_iter)):
                _image, _quaternion_pred, _translation_pred = sess.run(
                    [image, quaternion_pred, translation_pred])

                rx = so3_group.matrix_from_quaternion(_quaternion_pred)[0]
                tx = _translation_pred[0] * 60.

                image_pred = resample_sitk(fixed_image_sitk, rx, tx)
                _save_and_score(i, _image, image_pred, metrics)

    elif FLAGS.loss == 'AP':

        y_pred, _ = inception.inception_v3(image_resized,
                                           num_classes=9,
                                           is_training=False)
        AP1_pred, AP2_pred, AP3_pred = tf.split(y_pred, 3, axis=1)

        with tf.Session() as sess:
            _restore_latest_checkpoint(sess)

            for i in tqdm.tqdm(range(FLAGS.n_iter)):
                _image, _AP1_pred, _AP2_pred, _AP3_pred = sess.run(
                    [image, AP1_pred, AP2_pred, AP3_pred])

                rx = matrix_from_anchor_points(_AP1_pred[0], _AP2_pred[0],
                                               _AP3_pred[0])
                tx = _AP2_pred[0] * 60.

                image_pred = resample_sitk(fixed_image_sitk, rx, tx)
                _save_and_score(i, _image, image_pred, metrics)

    else:  # FLAGS.loss == 'SE3' (validated above)

        y_pred, _ = inception.inception_v3(image_resized,
                                           num_classes=6,
                                           is_training=False)

        with tf.Session() as sess:
            _restore_latest_checkpoint(sess)
            so3_group = SpecialOrthogonalGroup(3)
            se3_group = SpecialEuclideanGroup(3)
            se3_errors = []

            for i in tqdm.tqdm(range(FLAGS.n_iter)):
                _image, _rvec, _tvec, _y_pred = \
                    sess.run([image, vec, AP2, y_pred])

                rx = so3_group.matrix_from_rotation_vector(_y_pred[0, :3])[0]
                tx = _y_pred[0, 3:] * 60.

                image_pred = resample_sitk(fixed_image_sitk, rx, tx)
                _save_and_score(i, _image, image_pred, metrics)

                # Error as the group element taking the truth to the
                # prediction: y_true^{-1} . y_pred.
                _y_true = np.concatenate((_rvec, _tvec), axis=-1)
                se3_errors.append(
                    se3_group.compose(se3_group.inverse(_y_true), _y_pred))

        # np.cov needs at least two samples for a sample covariance.
        if len(se3_errors) > 1:
            err_vec = np.vstack(se3_errors)
            err_weights = np.diag(np.linalg.inv(np.cov(err_vec.T)))
            err_weights = err_weights / np.linalg.norm(err_weights)
            print(err_weights)

    # np.stack raises on an empty list, e.g. when FLAGS.n_iter == 0.
    if not metrics['CC']:
        print('No images evaluated; FLAGS.n_iter =', FLAGS.n_iter)
        return

    for name in ('CC', 'MSE', 'PSNR', 'SSIM'):
        print(name + ':', np.median(np.stack(metrics[name])))
Exemplo n.º 5
0
class TestBackendNumpy(unittest.TestCase):
    """
    Tests of the numpy backend's matrix exponential and logarithm.

    The backend is chosen from the GEOMSTATS_BACKEND environment variable
    when ``gs`` is (re)loaded. This class therefore forces it to 'numpy'
    and reloads ``gs`` for its tests, then restores the previous setting.
    """
    _multiprocess_can_split_ = True

    @classmethod
    def setUpClass(cls):
        # .get(): the variable may be unset. Indexing would raise KeyError
        # and abort the whole class.
        cls.initial_backend = os.environ.get('GEOMSTATS_BACKEND')
        os.environ['GEOMSTATS_BACKEND'] = 'numpy'
        importlib.reload(gs)

    @classmethod
    def tearDownClass(cls):
        # Restore the environment exactly as it was, including "unset".
        if cls.initial_backend is None:
            os.environ.pop('GEOMSTATS_BACKEND', None)
        else:
            os.environ['GEOMSTATS_BACKEND'] = cls.initial_backend
        importlib.reload(gs)

    def setUp(self):
        warnings.simplefilter('ignore', category=ImportWarning)

        self.so3_group = SpecialOrthogonalGroup(n=3)
        self.n_samples = 2

    def _assert_all_close(self, result, expected):
        """Assert gs.allclose(result, expected), showing both on failure."""
        self.assertTrue(
            gs.allclose(result, expected),
            '\nresult:\n{}\nexpected:\n{}'.format(result, expected))

    def test_logm(self):
        point = gs.array([[2., 0., 0.], [0., 3., 0.], [0., 0., 4.]])
        result = gs.linalg.logm(point)
        # logm of a diagonal matrix is the diagonal of the logs.
        expected = gs.array([[0.693147180, 0., 0.], [0., 1.098612288, 0.],
                             [0., 0., 1.38629436]])

        self._assert_all_close(result, expected)

    def test_expm_and_logm(self):
        point = gs.array([[2., 0., 0.], [0., 3., 0.], [0., 0., 4.]])
        result = gs.linalg.expm(gs.linalg.logm(point))
        expected = point

        self._assert_all_close(result, expected)

    def test_expm_vectorization(self):
        point = gs.array([[[2., 0., 0.], [0., 3., 0.], [0., 0., 4.]],
                          [[1., 0., 0.], [0., 5., 0.], [0., 0., 6.]]])

        # expm of a diagonal matrix is the diagonal of the exponentials.
        expected = gs.array([[[7.38905609, 0., 0.], [0., 20.0855369, 0.],
                              [0., 0., 54.5981500]],
                             [[2.718281828, 0., 0.], [0., 148.413159, 0.],
                              [0., 0., 403.42879349]]])

        result = gs.linalg.expm(point)

        self._assert_all_close(result, expected)

    def test_logm_vectorization_diagonal(self):
        point = gs.array([[[2., 0., 0.], [0., 3., 0.], [0., 0., 4.]],
                          [[1., 0., 0.], [0., 5., 0.], [0., 0., 6.]]])

        expected = gs.array([[[0.693147180, 0., 0.], [0., 1.09861228866, 0.],
                              [0., 0., 1.38629436]],
                             [[0., 0., 0.], [0., 1.609437912, 0.],
                              [0., 0., 1.79175946]]])

        result = gs.linalg.logm(point)

        self._assert_all_close(result, expected)

    def test_expm_and_logm_vectorization_random_rotation(self):
        point = self.so3_group.random_uniform(self.n_samples)
        point = self.so3_group.matrix_from_rotation_vector(point)

        result = gs.linalg.expm(gs.linalg.logm(point))
        expected = point

        self._assert_all_close(result, expected)

    def test_expm_and_logm_vectorization(self):
        point = gs.array([[[2., 0., 0.], [0., 3., 0.], [0., 0., 4.]],
                          [[1., 0., 0.], [0., 5., 0.], [0., 0., 6.]]])
        result = gs.linalg.expm(gs.linalg.logm(point))
        expected = point

        self._assert_all_close(result, expected)
Exemplo n.º 6
0
class SpecialEuclideanGroup(LieGroup):
    """
    Class for the special euclidean group SE(n),
    i.e. the Lie group of rigid transformations.

    'vector' representation (the default when n == 3): a point is a row
    vector of length ``dimension = n(n-1)/2 + n``. It holds the rotation
    vector of the SO(n) part followed by the translation. Methods accept a
    single point or a 2-D batch of points.

    'matrix' representation (the default when n != 3): only partially
    supported; most methods raise NotImplementedError for it.
    """
    def __init__(self, n, point_type=None, epsilon=0.):
        """
        Parameters
        ----------
        n : int, > 1
            Dimension of the space the rigid transformations act on.
        point_type : {'vector', 'matrix'}, optional
            Default representation of points. If None, uses 'vector' when
            n == 3 and 'matrix' otherwise.
        epsilon : float, optional
            Forwarded to the underlying SpecialOrthogonalGroup.
        """
        assert isinstance(n, int) and n > 1

        self.n = n
        self.dimension = int((n * (n - 1)) / 2 + n)

        self.epsilon = epsilon

        self.default_point_type = point_type
        if point_type is None:
            self.default_point_type = 'vector' if n == 3 else 'matrix'

        super(SpecialEuclideanGroup, self).__init__(dimension=self.dimension)

        self.rotations = SpecialOrthogonalGroup(n=n, epsilon=epsilon)
        self.translations = EuclideanSpace(dimension=n)

    def get_identity(self, point_type=None):
        """
        Get the identity of the group,
        as a vector if point_type == 'vector',
        as a matrix if point_type == 'matrix'.

        The vector identity is the zero vector: zero rotation vector and
        zero translation. The matrix identity is the (n+1) x (n+1) identity
        of the homogeneous-coordinates representation of SE(n).
        """
        if point_type is None:
            point_type = self.default_point_type

        # Dispatch on the requested representation, not on the default one.
        if point_type == 'matrix':
            return gs.eye(self.n + 1)
        return gs.zeros(self.dimension)

    identity = property(get_identity)

    def belongs(self, point, point_type=None):
        """
        Evaluate if a point belongs to SE(n).

        In the 'vector' representation, only the length of each point is
        checked. Returns a boolean array of shape (n_points, 1).
        The 'matrix' representation is not implemented.
        """
        if point_type is None:
            point_type = self.default_point_type

        if point_type == 'vector':
            point = gs.to_ndarray(point, to_ndim=2)
            n_points, point_dim = point.shape
            belongs = point_dim == self.dimension
            belongs = gs.to_ndarray(belongs, to_ndim=1)
            belongs = gs.to_ndarray(belongs, to_ndim=2, axis=1)
            belongs = gs.tile(belongs, (n_points, 1))
        elif point_type == 'matrix':
            point = gs.to_ndarray(point, to_ndim=3)
            raise NotImplementedError()

        return belongs

    def regularize(self, point, point_type=None):
        """
        Regularize a point to the canonical representation
        chosen for SE(n).

        In the 'vector' representation, the rotation-vector part is
        regularized by SO(n) and the translation is left unchanged.
        In the 'matrix' representation, a copy of the point is returned.
        """
        if point_type is None:
            point_type = self.default_point_type

        if point_type == 'vector':
            point = gs.to_ndarray(point, to_ndim=2)

            rotations = self.rotations
            dim_rotations = rotations.dimension

            rot_vec = point[:, :dim_rotations]
            regularized_rot_vec = rotations.regularize(rot_vec,
                                                       point_type=point_type)

            translation = point[:, dim_rotations:]

            regularized_point = gs.concatenate(
                [regularized_rot_vec, translation], axis=1)

        elif point_type == 'matrix':
            point = gs.to_ndarray(point, to_ndim=3)
            regularized_point = gs.copy(point)

        return regularized_point

    def regularize_tangent_vec_at_identity(self,
                                           tangent_vec,
                                           metric=None,
                                           point_type=None):
        """
        Regularize a tangent vector at the identity of SE(n).
        See regularize_tangent_vec.
        """
        if point_type is None:
            point_type = self.default_point_type

        return self.regularize_tangent_vec(tangent_vec,
                                           self.identity,
                                           metric,
                                           point_type=point_type)

    def regularize_tangent_vec(self,
                               tangent_vec,
                               base_point,
                               metric=None,
                               point_type=None):
        """
        Regularize a tangent vector at base_point.

        In the 'vector' representation, the rotation part is regularized by
        SO(n). That step uses an invariant metric built from the rotation
        block of ``metric`` (default: the left canonical metric). The
        translation part is kept as is. In the 'matrix' representation, the
        tangent vector is returned unchanged.
        """
        if point_type is None:
            point_type = self.default_point_type

        if metric is None:
            metric = self.left_canonical_metric

        if point_type == 'vector':
            tangent_vec = gs.to_ndarray(tangent_vec, to_ndim=2)
            base_point = gs.to_ndarray(base_point, to_ndim=2)

            rotations = self.rotations
            dim_rotations = rotations.dimension

            rot_tangent_vec = tangent_vec[:, :dim_rotations]
            rot_base_point = base_point[:, :dim_rotations]

            # Restrict the inner-product matrix to its rotation block.
            metric_mat = metric.inner_product_mat_at_identity
            rot_metric_mat = metric_mat[:, :dim_rotations, :dim_rotations]
            rot_metric = InvariantMetric(
                group=rotations,
                inner_product_mat_at_identity=rot_metric_mat,
                left_or_right=metric.left_or_right)

            rotations_vec = rotations.regularize_tangent_vec(
                tangent_vec=rot_tangent_vec,
                base_point=rot_base_point,
                metric=rot_metric,
                point_type=point_type)

            regularized_vec = gs.concatenate(
                [rotations_vec, tangent_vec[:, dim_rotations:]], axis=1)

        elif point_type == 'matrix':
            regularized_vec = tangent_vec

        return regularized_vec

    def compose(self, point_1, point_2, point_type=None):
        """
        Compose two elements of SE(n).

        Formula:
        point_1 . point_2 = [R1 * R2, (R1 * t2) + t1]
        where:
        R1, R2 are rotation matrices,
        t1, t2 are translation vectors.

        Batches must have the same size, or one of them a single point,
        which is then broadcast against the other.
        """
        if point_type is None:
            point_type = self.default_point_type

        rotations = self.rotations
        dim_rotations = rotations.dimension

        point_1 = self.regularize(point_1, point_type=point_type)
        point_2 = self.regularize(point_2, point_type=point_type)

        if point_type == 'vector':
            n_points_1, _ = point_1.shape
            n_points_2, _ = point_2.shape

            assert (point_1.shape == point_2.shape or n_points_1 == 1
                    or n_points_2 == 1)

            if n_points_1 == 1:
                point_1 = gs.stack([point_1[0]] * n_points_2)

            if n_points_2 == 1:
                point_2 = gs.stack([point_2[0]] * n_points_1)

            rot_vec_1 = point_1[:, :dim_rotations]
            rot_mat_1 = rotations.matrix_from_rotation_vector(rot_vec_1)

            rot_vec_2 = point_2[:, :dim_rotations]
            rot_mat_2 = rotations.matrix_from_rotation_vector(rot_vec_2)

            translation_1 = point_1[:, dim_rotations:]
            translation_2 = point_2[:, dim_rotations:]

            composition_rot_mat = gs.matmul(rot_mat_1, rot_mat_2)
            composition_rot_vec = rotations.rotation_vector_from_matrix(
                composition_rot_mat)

            # Batched R1 . t2, then + t1.
            composition_translation = gs.einsum('ij,ikj->ik', translation_2,
                                                rot_mat_1) + translation_1

            composition = gs.concatenate(
                (composition_rot_vec, composition_translation), axis=1)

        elif point_type == 'matrix':
            raise NotImplementedError()

        composition = self.regularize(composition, point_type=point_type)
        return composition

    def inverse(self, point, point_type=None):
        """
        Compute the group inverse in SE(n).

        Formula:
        (R, t)^{-1} = (R^{-1}, R^{-1}.(-t))
        """
        if point_type is None:
            point_type = self.default_point_type

        rotations = self.rotations
        dim_rotations = rotations.dimension

        # Regularize in the requested representation, consistently with
        # the other methods.
        point = self.regularize(point, point_type=point_type)

        if point_type == 'vector':
            rot_vec = point[:, :dim_rotations]
            translation = point[:, dim_rotations:]

            # The inverse rotation has the opposite rotation vector.
            inverse_rotation = -rot_vec

            inv_rot_mat = rotations.matrix_from_rotation_vector(
                inverse_rotation)

            # Batched R^{-1} . (-t).
            inverse_translation = gs.einsum(
                'ni,nij->nj', -translation,
                gs.transpose(inv_rot_mat, axes=(0, 2, 1)))

            inverse_point = gs.concatenate(
                [inverse_rotation, inverse_translation], axis=1)

        elif point_type == 'matrix':
            raise NotImplementedError()

        inverse_point = self.regularize(inverse_point, point_type=point_type)
        return inverse_point

    def jacobian_translation(self,
                             point,
                             left_or_right='left',
                             point_type=None):
        """
        Compute the jacobian matrix of the differential
        of the left/right translations from the identity to point in SE(n).

        Returns an array of shape (n_points, dimension, dimension) in the
        'vector' representation. The top-left block is the SO(n) jacobian.
        The bottom-right block is the rotation matrix (left translation) or
        the identity (right translation). The bottom-left block is zero
        (left) or minus the skew matrix of the rotation vector (right).
        """
        if point_type is None:
            point_type = self.default_point_type

        assert left_or_right in ('left', 'right')

        rotations = self.rotations
        translations = self.translations
        dim_rotations = rotations.dimension
        dim_translations = translations.dimension

        point = self.regularize(point, point_type=point_type)

        if point_type == 'vector':
            n_points, _ = point.shape

            rot_vec = point[:, :dim_rotations]

            jacobian_rot = self.rotations.jacobian_translation(
                point=rot_vec,
                left_or_right=left_or_right,
                point_type=point_type)
            block_zeros_1 = gs.zeros(
                (n_points, dim_rotations, dim_translations))
            jacobian_block_line_1 = gs.concatenate(
                [jacobian_rot, block_zeros_1], axis=2)

            if left_or_right == 'left':
                rot_mat = self.rotations.matrix_from_rotation_vector(rot_vec)
                jacobian_trans = rot_mat
                block_zeros_2 = gs.zeros(
                    (n_points, dim_translations, dim_rotations))
                jacobian_block_line_2 = gs.concatenate(
                    [block_zeros_2, jacobian_trans], axis=2)

            else:
                inv_skew_mat = -self.rotations.skew_matrix_from_vector(rot_vec)
                eye = gs.to_ndarray(gs.eye(self.n), to_ndim=3)
                eye = gs.tile(eye, [n_points, 1, 1])
                jacobian_block_line_2 = gs.concatenate([inv_skew_mat, eye],
                                                       axis=2)

            jacobian = gs.concatenate(
                [jacobian_block_line_1, jacobian_block_line_2], axis=1)

            assert gs.ndim(jacobian) == 3

        elif point_type == 'matrix':
            raise NotImplementedError()

        return jacobian

    def group_exp_from_identity(self, tangent_vec, point_type=None):
        """
        Compute the group exponential of the tangent vector at the identity.

        'vector' representation: the rotation part of the result is the
        regularized rotation vector. The translation is
        V t = t + c1 K t + c2 K^2 t, where K is the skew matrix of the
        rotation vector, theta its norm, c1 = (1 - cos theta) / theta^2 and
        c2 = (theta - sin theta) / theta^3. Taylor expansions are used near
        theta = 0.
        """
        if point_type is None:
            point_type = self.default_point_type

        if point_type == 'vector':
            tangent_vec = gs.to_ndarray(tangent_vec, to_ndim=2)

            rotations = self.rotations
            dim_rotations = rotations.dimension

            rot_vec = tangent_vec[:, :dim_rotations]
            rot_vec = self.rotations.regularize(rot_vec, point_type=point_type)
            translation = tangent_vec[:, dim_rotations:]

            angle = gs.linalg.norm(rot_vec, axis=1)
            angle = gs.to_ndarray(angle, to_ndim=2, axis=1)

            skew_mat = self.rotations.skew_matrix_from_vector(rot_vec)
            sq_skew_mat = gs.matmul(skew_mat, skew_mat)

            # Three disjoint regimes: exactly 0, close to 0, and the rest.
            mask_0 = gs.equal(angle, 0.)
            mask_close_0 = gs.isclose(angle, 0.) & ~mask_0
            mask_else = ~mask_0 & ~mask_close_0

            mask_0_float = gs.cast(mask_0, gs.float32)
            mask_close_0_float = gs.cast(mask_close_0, gs.float32)
            mask_else_float = gs.cast(mask_else, gs.float32)

            # Avoid division by zero in the masked-out closed-form branch.
            angle += mask_0_float * gs.ones_like(angle)

            coef_1 = gs.zeros_like(angle)
            coef_2 = gs.zeros_like(angle)

            coef_1 += mask_0_float * 1. / 2. * gs.ones_like(angle)
            coef_2 += mask_0_float * 1. / 6. * gs.ones_like(angle)

            coef_1 += mask_close_0_float * (
                TAYLOR_COEFFS_1_AT_0[0] + TAYLOR_COEFFS_1_AT_0[2] * angle**2 +
                TAYLOR_COEFFS_1_AT_0[4] * angle**4 +
                TAYLOR_COEFFS_1_AT_0[6] * angle**6)
            coef_2 += mask_close_0_float * (
                TAYLOR_COEFFS_2_AT_0[0] + TAYLOR_COEFFS_2_AT_0[2] * angle**2 +
                TAYLOR_COEFFS_2_AT_0[4] * angle**4 +
                TAYLOR_COEFFS_2_AT_0[6] * angle**6)

            coef_1 += mask_else_float * ((1. - gs.cos(angle)) / angle**2)
            coef_2 += mask_else_float * ((angle - gs.sin(angle)) / angle**3)

            # Per-sample accumulation through one-hot masks keeps the code
            # free of item assignment.
            n_tangent_vecs, _ = tangent_vec.shape
            group_exp_translation = gs.zeros((n_tangent_vecs, self.n))
            for i in range(n_tangent_vecs):
                translation_i = translation[i]
                term_1_i = coef_1[i] * gs.dot(translation_i,
                                              gs.transpose(skew_mat[i]))
                term_2_i = coef_2[i] * gs.dot(translation_i,
                                              gs.transpose(sq_skew_mat[i]))
                mask_i_float = gs.get_mask_i_float(i, n_tangent_vecs)
                group_exp_translation += mask_i_float * (translation_i +
                                                         term_1_i + term_2_i)

            group_exp = gs.concatenate([rot_vec, group_exp_translation],
                                       axis=1)

            group_exp = self.regularize(group_exp, point_type=point_type)
            return group_exp
        elif point_type == 'matrix':
            raise NotImplementedError()

    def group_log_from_identity(self, point, point_type=None):
        """
        Compute the group logarithm of the point at the identity.

        'vector' representation: the rotation part is the rotation vector.
        The translation is t - 1/2 K t + c2 K^2 t, where K is the skew
        matrix of the rotation vector, theta its norm,
        c2 = (1 - psi) / theta^2 and
        psi = theta sin(theta) / (2 (1 - cos theta)). Taylor expansions
        are used near theta = 0 and theta = pi.
        """
        if point_type is None:
            point_type = self.default_point_type

        point = self.regularize(point, point_type=point_type)

        rotations = self.rotations
        dim_rotations = rotations.dimension

        if point_type == 'vector':
            rot_vec = point[:, :dim_rotations]
            angle = gs.linalg.norm(rot_vec, axis=1)
            angle = gs.to_ndarray(angle, to_ndim=2, axis=1)

            translation = point[:, dim_rotations:]

            skew_rot_vec = rotations.skew_matrix_from_vector(rot_vec)
            sq_skew_rot_vec = gs.matmul(skew_rot_vec, skew_rot_vec)

            mask_close_0 = gs.isclose(angle, 0.)
            mask_close_pi = gs.isclose(angle, gs.pi)
            mask_else = ~mask_close_0 & ~mask_close_pi

            mask_close_0_float = gs.cast(mask_close_0, gs.float32)
            mask_close_pi_float = gs.cast(mask_close_pi, gs.float32)
            mask_else_float = gs.cast(mask_else, gs.float32)

            # Avoid division by ~0 in psi for (near-)zero angles.
            mask_0 = gs.isclose(angle, 0., atol=1e-6)
            mask_0_float = gs.cast(mask_0, gs.float32)
            angle += mask_0_float * gs.ones_like(angle)

            coef_1 = -0.5 * gs.ones_like(angle)
            coef_2 = gs.zeros_like(angle)

            # Taylor expansion of c2 at theta = 0.
            coef_2 += mask_close_0_float * (1. / 12. + angle**2 / 720. +
                                            angle**4 / 30240. +
                                            angle**6 / 1209600.)

            # Taylor expansion of c2 at theta = pi.
            delta_angle = angle - gs.pi
            coef_2 += mask_close_pi_float * (
                1. / PI2 + (PI2 - 8.) * delta_angle / (4. * PI3) -
                ((PI2 - 12.) * delta_angle**2 / (4. * PI4)) +
                ((-192. + 12. * PI2 + PI4) * delta_angle**3 / (48. * PI5)) -
                ((-240. + 12. * PI2 + PI4) * delta_angle**4 / (48. * PI6)) +
                ((-2880. + 120. * PI2 + 10. * PI4 + PI6) * delta_angle**5 /
                 (480. * PI7)) -
                ((-3360 + 120. * PI2 + 10. * PI4 + PI6) * delta_angle**6 /
                 (480. * PI8)))

            psi = 0.5 * angle * gs.sin(angle) / (1 - gs.cos(angle))
            coef_2 += mask_else_float * (1 - psi) / (angle**2)

            n_points, _ = point.shape
            group_log_translation = gs.zeros((n_points, self.n))
            for i in range(n_points):
                translation_i = translation[i]
                term_1_i = coef_1[i] * gs.dot(translation_i,
                                              gs.transpose(skew_rot_vec[i]))
                term_2_i = coef_2[i] * gs.dot(translation_i,
                                              gs.transpose(sq_skew_rot_vec[i]))
                mask_i_float = gs.get_mask_i_float(i, n_points)
                group_log_translation += mask_i_float * (translation_i +
                                                         term_1_i + term_2_i)

            group_log = gs.concatenate([rot_vec, group_log_translation],
                                       axis=1)

            assert gs.ndim(group_log) == 2

        elif point_type == 'matrix':
            raise NotImplementedError()

        return group_log

    def random_uniform(self, n_samples=1, point_type=None):
        """
        Sample in SE(n) with the uniform distribution.

        The rotation part is sampled by SO(n).random_uniform and the
        translation part by EuclideanSpace.random_uniform.
        """
        if point_type is None:
            point_type = self.default_point_type

        random_rot_vec = self.rotations.random_uniform(n_samples,
                                                       point_type=point_type)
        random_translation = self.translations.random_uniform(n_samples)

        if point_type == 'vector':
            random_transfo = gs.concatenate(
                [random_rot_vec, random_translation], axis=1)

        elif point_type == 'matrix':
            raise NotImplementedError()

        random_transfo = self.regularize(random_transfo, point_type=point_type)
        return random_transfo

    def exponential_matrix(self, rot_vec):
        """
        For each rotation vector, with angle theta = |rot_vec| and skew
        matrix K, compute the n x n matrix

            I + (1 - cos theta) / theta^2 * K
              + (theta - sin theta) / theta^3 * K^2,

        i.e. the matrix mapping the translational part of a tangent vector
        at the identity to the translation of its group exponential. These
        are the same coefficients as in group_exp_from_identity.

        Returns an array of shape (n_rot_vecs, n, n). Uses boolean-mask item
        assignment, so it presumably requires the numpy backend.
        """
        rot_vec = self.rotations.regularize(rot_vec)
        n_rot_vecs, _ = rot_vec.shape

        angle = gs.linalg.norm(rot_vec, axis=1)
        angle = gs.to_ndarray(angle, to_ndim=2, axis=1)

        skew_rot_vec = self.rotations.skew_matrix_from_vector(rot_vec)

        coef_1 = gs.empty_like(angle)
        coef_2 = gs.empty_like(coef_1)

        mask_0 = gs.equal(angle, 0)
        mask_0 = gs.squeeze(mask_0, axis=1)
        mask_close_to_0 = gs.isclose(angle, 0)
        mask_close_to_0 = gs.squeeze(mask_close_to_0, axis=1)
        mask_else = ~mask_0 & ~mask_close_to_0

        # Taylor expansions at 0:
        #   (1 - cos x) / x^2 = 1/2 - x^2/24  + O(x^4)
        #   (x - sin x) / x^3 = 1/6 - x^2/120 + O(x^4)
        coef_1[mask_close_to_0] = (1. / 2. - angle[mask_close_to_0]**2 / 24.)
        coef_2[mask_close_to_0] = (1. / 6. - angle[mask_close_to_0]**2 / 120.)

        # At theta == 0 the rotation vector, hence K, is zero, so the value
        # of the coefficients does not affect the result.
        coef_1[mask_0] = 0
        coef_2[mask_0] = 0

        coef_1[mask_else] = (angle[mask_else]**(-2) *
                             (1. - gs.cos(angle[mask_else])))
        coef_2[mask_else] = (angle[mask_else]**(-2) *
                             (1. -
                              (gs.sin(angle[mask_else]) / angle[mask_else])))

        term_1 = gs.zeros((n_rot_vecs, self.n, self.n))
        term_2 = gs.zeros_like(term_1)

        for i in range(n_rot_vecs):
            term_1[i] = gs.eye(self.n) + skew_rot_vec[i] * coef_1[i]
            term_2[i] = gs.matmul(skew_rot_vec[i], skew_rot_vec[i]) * coef_2[i]

        exponential_mat = term_1 + term_2
        assert exponential_mat.ndim == 3

        return exponential_mat

    def group_exponential_barycenter(self,
                                     points,
                                     weights=None,
                                     point_type=None):
        """
        Compute the group exponential barycenter in SE(n).

        The rotation part is the SO(n) group exponential barycenter of the
        rotation vectors. The translation part solves a weighted linear
        system built from exponential_matrix. Weights default to 1 for
        every point. Returns an array of shape (1, dimension).
        Only the 'vector' representation is implemented.
        """
        if point_type is None:
            point_type = self.default_point_type

        n_points = points.shape[0]
        assert n_points > 0

        if weights is None:
            weights = gs.ones((n_points, 1))

        weights = gs.to_ndarray(weights, to_ndim=2, axis=1)
        n_weights, _ = weights.shape
        assert n_points == n_weights

        dim = self.dimension
        rotations = self.rotations
        dim_rotations = rotations.dimension

        if point_type == 'vector':
            rotation_vectors = points[:, :dim_rotations]
            translations = points[:, dim_rotations:dim]
            assert rotation_vectors.shape == (n_points, dim_rotations)
            assert translations.shape == (n_points, self.n)

            mean_rotation = rotations.group_exponential_barycenter(
                points=rotation_vectors, weights=weights)
            mean_rotation_mat = rotations.matrix_from_rotation_vector(
                mean_rotation)

            matrix = gs.zeros((1, ) + (self.n, ) * 2)
            translation_aux = gs.zeros((1, self.n))

            inv_rot_mats = rotations.matrix_from_rotation_vector(
                -rotation_vectors)
            matrix_aux = gs.matmul(mean_rotation_mat, inv_rot_mats)
            # Rotation matrices are n x n; dim_rotations only equals n
            # when n == 3.
            assert matrix_aux.shape == (n_points, ) + (self.n, ) * 2

            vec_aux = rotations.rotation_vector_from_matrix(matrix_aux)
            matrix_aux = self.exponential_matrix(vec_aux)
            matrix_aux = gs.linalg.inv(matrix_aux)

            for i in range(n_points):
                matrix += weights[i] * matrix_aux[i]
                translation_aux += weights[i] * gs.dot(
                    gs.matmul(matrix_aux[i], inv_rot_mats[i]), translations[i])

            mean_translation = gs.dot(
                translation_aux,
                gs.transpose(gs.linalg.inv(matrix), axes=(0, 2, 1)))

            exp_bar = gs.zeros((1, dim))
            exp_bar[0, :dim_rotations] = mean_rotation
            exp_bar[0, dim_rotations:dim] = mean_translation

        elif point_type == 'matrix':
            # SE(n) defines no rotation_vector_from_matrix /
            # matrix_from_rotation_vector, so the former conversion-based
            # branch could only fail with AttributeError.
            raise NotImplementedError()
        return exp_bar
Exemplo n.º 7
0
class TestGeneralLinearGroupMethods(unittest.TestCase):
    """Tests of GeneralLinearGroup using random SO(3) rotation matrices."""
    _multiprocess_can_split_ = True

    def setUp(self):
        gs.random.seed(1234)
        dim = 3
        self.group = GeneralLinearGroup(n=dim)
        # Rotation matrices of SO(3) serve as convenient invertible matrices.
        self.so3_group = SpecialOrthogonalGroup(n=dim)

    def _random_rotation_matrix(self):
        """Draw a random rotation vector and return its rotation matrix."""
        rot_vec = self.so3_group.random_uniform()
        return self.so3_group.matrix_from_rotation_vector(rot_vec)

    @staticmethod
    def _scaled_atol(expected):
        """Return RTOL scaled by the norm of expected, or RTOL if it is 0."""
        norm = gs.linalg.norm(expected)
        return RTOL * norm if norm != 0 else RTOL

    def test_belongs(self):
        """
        A rotation matrix belongs to the matrix Lie group
        of invertible matrices.
        """
        rot_mat = self._random_rotation_matrix()

        self.assertTrue(self.group.belongs(rot_mat))

    def test_compose(self):
        # 1. Right composition with the identity leaves the matrix unchanged.
        mat = self._random_rotation_matrix()
        result = self.group.compose(mat, self.group.identity)

        self.assertTrue(gs.allclose(result, mat))

        # 2. Left composition with the identity leaves the matrix unchanged.
        mat = self._random_rotation_matrix()
        result = self.group.compose(self.group.identity, mat)

        self.assertTrue(
            gs.allclose(result, mat, atol=self._scaled_atol(mat)),
            '\nresult:\n{}'
            '\nexpected:\n{}'.format(result, mat))

    def test_compose_and_inverse(self):
        # 1. mat . mat^{-1} is the group identity.
        mat = self._random_rotation_matrix()
        inv_mat = self.group.inverse(mat)
        result = self.group.compose(mat, inv_mat)
        expected = self.group.identity

        self.assertTrue(
            gs.allclose(result, expected, atol=self._scaled_atol(expected)),
            '\nresult:\n{}'
            '\nexpected:\n{}'.format(result, expected))

        # 2. mat^{-1} . mat is the group identity.
        mat = self._random_rotation_matrix()
        inv_mat = self.group.inverse(mat)
        result = self.group.compose(inv_mat, mat)
        expected = self.group.identity

        self.assertTrue(
            gs.allclose(result, expected, atol=self._scaled_atol(expected)))
Exemplo n.º 8
0
class TestGeneralLinearGroupTensorFlow(tf.test.TestCase):
    """Tests of GeneralLinearGroup run with the tensorflow backend of gs."""
    _multiprocess_can_split_ = True

    def setUp(self):
        # Let tf.test.TestCase run its own per-test setup first.
        super(TestGeneralLinearGroupTensorFlow, self).setUp()
        gs.random.seed(1234)
        n = 3
        self.group = GeneralLinearGroup(n=n)
        # We generate invertible matrices using so3_group
        self.so3_group = SpecialOrthogonalGroup(n=n)

    @classmethod
    def setUpClass(cls):
        """Switch geomstats to the tensorflow backend for this class."""
        super(TestGeneralLinearGroupTensorFlow, cls).setUpClass()
        # Remember the backend selected before this class runs, so that
        # tearDownClass restores it instead of always forcing 'numpy'
        # ('numpy' remains the fallback when no backend was set).
        cls._previous_backend = os.environ.get('GEOMSTATS_BACKEND', 'numpy')
        os.environ['GEOMSTATS_BACKEND'] = 'tensorflow'
        importlib.reload(gs)

    @classmethod
    def tearDownClass(cls):
        """Restore the geomstats backend active before this class ran."""
        os.environ['GEOMSTATS_BACKEND'] = cls._previous_backend
        importlib.reload(gs)
        super(TestGeneralLinearGroupTensorFlow, cls).tearDownClass()

    def test_belongs(self):
        """
        A rotation matrix belongs to the matrix Lie group
        of invertible matrices.
        """
        rot_vec = tf.convert_to_tensor([0.2, -0.1, 0.1])
        rot_mat = self.so3_group.matrix_from_rotation_vector(rot_vec)
        result = self.group.belongs(rot_mat)
        expected = tf.convert_to_tensor([True])

        with self.test_session():
            self.assertAllClose(gs.eval(result), gs.eval(expected))

    def test_compose(self):
        # 1. Composition by identity, on the right
        # Expect the original transformation
        rot_vec = tf.convert_to_tensor([0.2, -0.1, 0.1])
        mat = self.so3_group.matrix_from_rotation_vector(rot_vec)

        result = self.group.compose(mat, self.group.identity)
        expected = helper.to_matrix(mat)

        with self.test_session():
            self.assertAllClose(gs.eval(result), gs.eval(expected))

        # 2. Composition by identity, on the left
        # Expect the original transformation
        rot_vec = tf.convert_to_tensor([0.2, 0.1, -0.1])
        mat = self.so3_group.matrix_from_rotation_vector(rot_vec)

        result = self.group.compose(self.group.identity, mat)
        expected = mat

        with self.test_session():
            self.assertAllClose(gs.eval(result), gs.eval(expected))

    def test_inverse(self):
        mat = tf.convert_to_tensor([[1., 2., 3.], [4., 5., 6.], [7., 8., 10.]])
        result = self.group.inverse(mat)
        # Inverse computed by hand (det(mat) = -3).
        expected = 1. / 3. * tf.convert_to_tensor(
            [[-2., -4., 3.], [-2., 11., -6.], [3., -6., 3.]])
        expected = helper.to_matrix(expected)

        with self.test_session():
            self.assertAllClose(gs.eval(result), gs.eval(expected))

    def test_compose_and_inverse(self):
        # 1. Compose transformation by its inverse on the right
        # Expect the group identity
        rot_vec = tf.convert_to_tensor([0.2, 0.1, 0.1])
        mat = self.so3_group.matrix_from_rotation_vector(rot_vec)
        inv_mat = self.group.inverse(mat)

        result = self.group.compose(mat, inv_mat)
        expected = self.group.identity
        expected = helper.to_matrix(expected)

        with self.test_session():
            self.assertAllClose(gs.eval(result), gs.eval(expected))

        # 2. Compose transformation by its inverse on the left
        # Expect the group identity
        rot_vec = tf.convert_to_tensor([0.7, 0.1, 0.1])
        mat = self.so3_group.matrix_from_rotation_vector(rot_vec)
        inv_mat = self.group.inverse(mat)

        result = self.group.compose(inv_mat, mat)
        expected = self.group.identity
        expected = helper.to_matrix(expected)

        with self.test_session():
            self.assertAllClose(gs.eval(result), gs.eval(expected))
class SpecialEuclideanGroup(LieGroup):

    def __init__(self, n):
        """
        Build the special euclidean group SE(n) in the vector
        representation [r t] (rotation vector r, translation t).

        :param n: dimension of the underlying euclidean space; must be
            greater than 1, and only n = 3 is currently supported.
        :raises NotImplementedError: if n is not 3.
        """
        assert n > 1

        # Compare by value: `is not` on an int relies on CPython's small-int
        # caching (it rejects e.g. numpy.int64(3)) and emits a SyntaxWarning
        # on Python >= 3.8.
        if n != 3:
            raise NotImplementedError('Only SE(3) is implemented.')

        self.n = n
        # n(n-1)/2 rotation coordinates plus n translation coordinates.
        self.dimension = int((n * (n - 1)) / 2 + n)
        super(SpecialEuclideanGroup, self).__init__(
                          dimension=self.dimension,
                          identity=np.zeros(self.dimension))
        # TODO(nina): keep the names rotations and translations here?
        self.rotations = SpecialOrthogonalGroup(n=n)
        self.translations = EuclideanSpace(dimension=n)

    def belongs(self, point):
        """
        Check that the transformation belongs to
        the special euclidean group.

        Only the number of coordinates is inspected: after promotion to a
        2d array, the result is True when the size of the last axis equals
        self.dimension. A single boolean is returned for the whole batch.
        """
        as_rows = vectorization.expand_dims(point, to_ndim=2)
        _, n_coords = as_rows.shape
        return self.dimension == n_coords

    def regularize(self, point):
        """
        Regularize an element of the group SE(3),
        by extracting the rotation vector r from the input [r t]
        and using self.rotations.regularize.

        Input with a non-floating dtype is returned as float64, so that
        the regularized rotation vector is not truncated.

        :param point: 6d vector, element in SE(3) represented as [r t].
        :returns regularized_point: 6d vector, element in SE(3)
        with regularized rotation.
        """
        point = vectorization.expand_dims(point, to_ndim=2)
        assert self.belongs(point)

        rotations = self.rotations
        dim_rotations = rotations.dimension

        # rotations.regularize may return non-integral values; an integer
        # output array would silently truncate them.
        if np.issubdtype(point.dtype, np.inexact):
            regularized_point = np.zeros_like(point)
        else:
            regularized_point = np.zeros_like(point, dtype=np.float64)
        rot_vec = point[:, :dim_rotations]
        regularized_point[:, :dim_rotations] = rotations.regularize(rot_vec)
        regularized_point[:, dim_rotations:] = point[:, dim_rotations:]

        return regularized_point

    def regularize_tangent_vec_at_identity(self, tangent_vec, metric=None):
        """
        Regularize a tangent vector at the identity of SE(3).

        Shortcut for regularize_tangent_vec with base_point set to
        self.identity; metric is forwarded unchanged.
        """
        identity = self.identity
        return self.regularize_tangent_vec(tangent_vec, identity, metric)

    def regularize_tangent_vec(self, tangent_vec, base_point, metric=None):
        """
        Regularize a tangent vector to SE(3) at base_point.

        The rotation part r of tangent_vec = [r t] is regularized with
        self.rotations.regularize_tangent_vec, using an InvariantMetric on
        the rotations built from the rotation block of
        metric.inner_product_mat_at_identity; the translation part t is
        copied unchanged. Input with a non-floating dtype is returned as
        float64.

        :param tangent_vec: 6d vector(s) [r t] tangent to SE(3) at base_point.
        :param base_point: 6d vector(s) [r t], element(s) of SE(3).
        :param metric: metric exposing inner_product_mat_at_identity and
            left_or_right; defaults to self.left_canonical_metric.
        :returns regularized_vec: 6d vector(s) with regularized rotation part.
        """
        if metric is None:
            metric = self.left_canonical_metric

        tangent_vec = vectorization.expand_dims(tangent_vec, to_ndim=2)
        base_point = vectorization.expand_dims(base_point, to_ndim=2)

        rotations = self.rotations
        dim_rotations = rotations.dimension

        rot_tangent_vec = tangent_vec[:, :dim_rotations]
        rot_base_point = base_point[:, :dim_rotations]

        # Restrict the metric to the rotation coordinates.
        metric_mat = metric.inner_product_mat_at_identity
        rot_metric_mat = metric_mat[:dim_rotations, :dim_rotations]
        rot_metric = InvariantMetric(
                               group=rotations,
                               inner_product_mat_at_identity=rot_metric_mat,
                               left_or_right=metric.left_or_right)

        # An integer output array would truncate the regularized rotation.
        if np.issubdtype(tangent_vec.dtype, np.inexact):
            regularized_vec = np.zeros_like(tangent_vec)
        else:
            regularized_vec = np.zeros_like(tangent_vec, dtype=np.float64)
        regularized_vec[:, :dim_rotations] = rotations.regularize_tangent_vec(
                                                 tangent_vec=rot_tangent_vec,
                                                 base_point=rot_base_point,
                                                 metric=rot_metric)
        regularized_vec[:, dim_rotations:] = tangent_vec[:, dim_rotations:]

        return regularized_vec

    def compose(self, point_1, point_2):
        """
        Compose two elements of group SE(3).

        With point_i = [r_i t_i] and R_i the rotation matrix of r_i:
        point_1 . point_2 = [R1 * R2, (R1 * t2) + t1]

        Both inputs are regularized first. Either input may contain a
        single point, which is then paired with every point of the other.

        :param point_1, point_2: 6d vectors (or batches) elements of SE(3)
        :return composition: regularized composition of point_1 and point_2
        """
        rotations = self.rotations
        dim_rot = rotations.dimension

        point_1 = self.regularize(point_1)
        point_2 = self.regularize(point_2)

        n_1, _ = point_1.shape
        n_2, _ = point_2.shape
        assert (point_1.shape == point_2.shape
                or n_1 == 1
                or n_2 == 1)

        rot_mats_1 = so_group.closest_rotation_matrix(
            rotations.matrix_from_rotation_vector(point_1[:, :dim_rot]))
        rot_mats_2 = so_group.closest_rotation_matrix(
            rotations.matrix_from_rotation_vector(point_2[:, :dim_rot]))
        trans_1 = point_1[:, dim_rot:]
        trans_2 = point_2[:, dim_rot:]

        n_out = np.maximum(n_1, n_2)
        rot_vecs_out = rotations.rotation_vector_from_matrix(
            np.matmul(rot_mats_1, rot_mats_2))

        trans_out = np.zeros((n_out, self.n))
        for k in range(n_out):
            # A single point is broadcast against the other batch.
            idx_1 = 0 if n_1 == 1 else k
            idx_2 = 0 if n_2 == 1 else k
            rotated = np.dot(trans_2[idx_2], np.transpose(rot_mats_1[idx_1]))
            trans_out[k] = rotated + trans_1[idx_1]

        composition = np.zeros((n_out, self.dimension))
        composition[:, :dim_rot] = rot_vecs_out
        composition[:, dim_rot:] = trans_out

        return self.regularize(composition)

    def inverse(self, point):
        """
        Compute the group inverse in SE(3).

        Formula:
        (R, t)^{-1} = (R^{-1}, R^{-1}.(-t))
        where the rotation vector of R^{-1} is taken as -r.

        :param point: 6d vector (or batch) element in SE(3)
        :returns inverse_point: regularized 6d vector(s), inverse of point
        """
        dim_rot = self.rotations.dimension

        point = self.regularize(point)
        n_points, _ = point.shape

        neg_rot_vecs = -point[:, :dim_rot]
        neg_translations = -point[:, dim_rot:]
        inv_rot_mats = self.rotations.matrix_from_rotation_vector(
            neg_rot_vecs)

        inverse_translations = np.zeros((n_points, self.n))
        for k, (inv_mat, neg_t) in enumerate(
                zip(inv_rot_mats, neg_translations)):
            inverse_translations[k] = np.dot(neg_t, np.transpose(inv_mat))

        inverse_point = np.zeros_like(point)
        inverse_point[:, :dim_rot] = neg_rot_vecs
        inverse_point[:, dim_rot:] = inverse_translations

        return self.regularize(inverse_point)

    def jacobian_translation(self, point, left_or_right='left'):
        """
        Compute the jacobian matrix of the differential
        of the left/right translations
        from the identity to point in the Lie group SE(3).

        For point = [r t], with R the rotation matrix of r:
        - left translation h -> point . h = [R R_h, R t_h + t], giving
          [[J_rot_left, 0], [0, R]];
        - right translation h -> h . point = [R_h R, R_h t + t_h], giving
          [[J_rot_right, 0], [-skew(t), I]],
        where J_rot_* is self.rotations.jacobian_translation of r.

        :param point: 6D vector element of SE(3)
        :param left_or_right: 'left' or 'right'
        :returns jacobian: array of 6x6 matrices, one per point
        """
        assert self.belongs(point)
        assert left_or_right in ('left', 'right')

        dim = self.dimension
        rotations = self.rotations
        dim_rotations = rotations.dimension

        point = self.regularize(point)
        n_points, _ = point.shape

        rot_vec = point[:, :dim_rotations]
        translation = point[:, dim_rotations:]

        jacobian = np.zeros((n_points,) + (dim,) * 2)

        if left_or_right == 'left':
            jacobian_rot = self.rotations.jacobian_translation(
                                                      point=rot_vec,
                                                      left_or_right='left')
            # d(R t_h + t) / d t_h = R
            jacobian_trans = self.rotations.matrix_from_rotation_vector(
                    rot_vec)

            jacobian[:, :dim_rotations, :dim_rotations] = jacobian_rot
            jacobian[:, dim_rotations:, dim_rotations:] = jacobian_trans

        else:
            jacobian_rot = self.rotations.jacobian_translation(
                                                      point=rot_vec,
                                                      left_or_right='right')

            # Near the identity R_h ~ I + skew(r_h), so
            # R_h t ~ t + r_h x t = t - skew(t) r_h: the coupling block is
            # -skew of the translation of point, not of its rotation vector.
            inv_skew_mat = - so_group.skew_matrix_from_vector(translation)
            jacobian[:, :dim_rotations, :dim_rotations] = jacobian_rot
            jacobian[:, dim_rotations:, :dim_rotations] = inv_skew_mat
            jacobian[:, dim_rotations:, dim_rotations:] = np.eye(self.n)

        assert jacobian.ndim == 3
        return jacobian

    def group_exp_from_identity(self,
                                tangent_vec):
        """
        Compute the group exponential of tangent_vec from the identity.

        For tangent_vec = [r t], with r regularized, angle = |r| and
        K = skew(r), the result is the regularization of
        [r, t + coef_1 * K t + coef_2 * K^2 t], where
        coef_1 = (1 - cos(angle)) / angle^2 and
        coef_2 = (angle - sin(angle)) / angle^3. The limits 1/2 and 1/6
        are used at angle == 0, and Taylor expansions for angles close
        to 0. Input with a non-floating dtype is returned as float64.

        :param tangent_vec: 6d vector(s) [r t], tangent to SE(3) at the
            identity.
        :returns group_exp: 6d vector(s), element(s) of SE(3).
        """
        tangent_vec = vectorization.expand_dims(tangent_vec, to_ndim=2)

        rotations = self.rotations
        dim_rotations = rotations.dimension

        rot_vec = tangent_vec[:, :dim_rotations]
        rot_vec = self.rotations.regularize(rot_vec)
        translation = tangent_vec[:, dim_rotations:]

        angle = np.linalg.norm(rot_vec, axis=1)
        angle = vectorization.expand_dims(angle, to_ndim=2, axis=1)

        # NOTE(review): rot_vec was already regularized above; the reason
        # for this second pass on angles close to pi is not visible here.
        mask_close_pi = np.isclose(angle, np.pi)
        mask_close_pi = np.squeeze(mask_close_pi, axis=1)
        rot_vec[mask_close_pi] = rotations.regularize(
                                       rot_vec[mask_close_pi])

        skew_mat = so_group.skew_matrix_from_vector(rot_vec)
        sq_skew_mat = np.matmul(skew_mat, skew_mat)

        mask_0 = np.equal(angle, 0)
        mask_close_0 = np.isclose(angle, 0) & ~mask_0

        mask_0 = np.squeeze(mask_0, axis=1)
        mask_close_0 = np.squeeze(mask_close_0, axis=1)

        mask_else = ~mask_0 & ~mask_close_0

        coef_1 = np.zeros_like(angle)
        coef_2 = np.zeros_like(angle)

        coef_1[mask_0] = 1. / 2.
        coef_2[mask_0] = 1. / 6.

        coef_1[mask_close_0] = (1. / 2. - angle[mask_close_0] ** 2 / 24.
                                + angle[mask_close_0] ** 4 / 720.
                                - angle[mask_close_0] ** 6 / 40320.)
        coef_2[mask_close_0] = (1. / 6. - angle[mask_close_0] ** 2 / 120.
                                + angle[mask_close_0] ** 4 / 5040.
                                - angle[mask_close_0] ** 6 / 362880.)

        coef_1[mask_else] = ((1. - np.cos(angle[mask_else]))
                             / angle[mask_else] ** 2)
        coef_2[mask_else] = ((angle[mask_else] - np.sin(angle[mask_else]))
                             / angle[mask_else] ** 3)

        n_tangent_vecs, _ = tangent_vec.shape
        group_exp_translation = np.zeros((n_tangent_vecs, self.n))
        for i in range(n_tangent_vecs):
            translation_i = translation[i]
            term_1_i = coef_1[i] * np.dot(translation_i,
                                          np.transpose(skew_mat[i]))
            term_2_i = coef_2[i] * np.dot(translation_i,
                                          np.transpose(sq_skew_mat[i]))

            group_exp_translation[i] = translation_i + term_1_i + term_2_i

        # The rotation and translation parts computed above are generally
        # non-integral; an integer output array would truncate them.
        if np.issubdtype(tangent_vec.dtype, np.inexact):
            group_exp = np.zeros_like(tangent_vec)
        else:
            group_exp = np.zeros_like(tangent_vec, dtype=np.float64)
        group_exp[:, :dim_rotations] = rot_vec
        group_exp[:, dim_rotations:] = group_exp_translation

        group_exp = self.regularize(group_exp)
        return group_exp

    def group_log_from_identity(self,
                                point):
        """
        Compute the group logarithm of point, from the identity.

        For the regularized point = [r t], with angle = |r| and
        K = skew(r), the result is [r, t + coef_1 * K t + coef_2 * K^2 t],
        where coef_1 = -1/2 and coef_2 = (1 - psi) / angle^2 with
        psi = (angle / 2) * sin(angle) / (1 - cos(angle)). Series
        expansions of coef_2 are used for angles close to 0 and to pi.
        Unlike group_exp_from_identity, the result is not passed through
        self.regularize.

        :param point: 6d vector(s) [r t], element(s) of SE(3).
        :returns group_log: 2d array of 6d tangent vectors at the identity.
        """
        assert self.belongs(point)
        point = self.regularize(point)

        rotations = self.rotations
        dim_rotations = rotations.dimension

        rot_vec = point[:, :dim_rotations]
        angle = np.linalg.norm(rot_vec, axis=1)
        angle = vectorization.expand_dims(angle, to_ndim=2, axis=1)

        translation = point[:, dim_rotations:]

        # The rotation part of the logarithm is the rotation vector itself.
        group_log = np.zeros_like(point)
        group_log[:, :dim_rotations] = rot_vec
        skew_rot_vec = so_group.skew_matrix_from_vector(rot_vec)
        sq_skew_rot_vec = np.matmul(skew_rot_vec, skew_rot_vec)

        mask_close_0 = np.isclose(angle, 0)
        mask_close_0 = np.squeeze(mask_close_0, axis=1)

        mask_close_pi = np.isclose(angle, np.pi)
        mask_close_pi = np.squeeze(mask_close_pi, axis=1)

        mask_else = ~mask_close_0 & ~mask_close_pi

        coef_1 = - 0.5 * np.ones_like(angle)
        coef_2 = np.zeros_like(angle)

        # Taylor expansion of (1 - psi) / angle**2 about angle = 0.
        coef_2[mask_close_0] = (1. / 12. + angle[mask_close_0] ** 2 / 720.
                                + angle[mask_close_0] ** 4 / 30240.
                                + angle[mask_close_0] ** 6 / 1209600.)

        # Expansion in delta_angle = angle - pi; its leading term 1 / PI2
        # matches (1 - psi) / angle**2 at angle = pi (where psi = 0).
        # NOTE(review): PI2..PI8 are module-level constants defined outside
        # this chunk (presumably powers of pi) - confirm; higher-order
        # coefficients not re-derived here.
        delta_angle = angle[mask_close_pi] - np.pi
        coef_2[mask_close_pi] = (1. / PI2
                                 + (PI2 - 8.) * delta_angle / (4. * PI3)
                                 - ((PI2 - 12.)
                                    * delta_angle ** 2 / (4. * PI4))
                                 + ((-192. + 12. * PI2 + PI4)
                                    * delta_angle ** 3 / (48. * PI5))
                                 - ((-240. + 12. * PI2 + PI4)
                                    * delta_angle ** 4 / (48. * PI6))
                                 + ((-2880. + 120. * PI2 + 10. * PI4 + PI6)
                                    * delta_angle ** 5 / (480. * PI7))
                                 - ((-3360 + 120. * PI2 + 10. * PI4 + PI6)
                                    * delta_angle ** 6 / (480. * PI8)))

        # Closed form away from 0 and pi.
        psi = (0.5 * angle[mask_else]
               * np.sin(angle[mask_else]) / (1 - np.cos(angle[mask_else])))
        coef_2[mask_else] = (1 - psi) / (angle[mask_else] ** 2)

        n_points, _ = point.shape
        group_log_translation = np.zeros((n_points, self.n))
        for i in range(n_points):
            translation_i = translation[i]
            # np.dot(t, K^T) computes K t for a single point.
            term_1_i = coef_1[i] * np.dot(translation_i,
                                          np.transpose(skew_rot_vec[i]))
            term_2_i = coef_2[i] * np.dot(translation_i,
                                          np.transpose(sq_skew_rot_vec[i]))
            group_log_translation[i] = translation_i + term_1_i + term_2_i

        group_log[:, dim_rotations:] = group_log_translation

        assert group_log.ndim == 2
        return group_log

    def random_uniform(self, n_samples=1):
        """
        Sample n_samples elements of SE(3) as 6d vectors [r t].

        The rotation part is drawn with self.rotations.random_uniform and
        the translation part with self.translations.random_uniform, in
        that order; the concatenation is then regularized.

        :param n_samples: number of elements to draw.
        :returns: regularized array of the sampled elements.
        """
        parts = [self.rotations.random_uniform(n_samples),
                 self.translations.random_uniform(n_samples)]
        return self.regularize(np.concatenate(parts, axis=1))

    def exponential_matrix(self, rot_vec):
        """
        Compute, for each rotation vector r, the matrix
        I + coef_1 * K + coef_2 * K^2, with K = skew(r), angle = |r|,
        coef_1 = (1 - cos(angle)) / angle^2 and
        coef_2 = (1 - sin(angle) / angle) / angle^2.

        These are the same coefficients group_exp_from_identity applies to
        the translation part. This is not the rotation matrix exp(K).

        :param rot_vec: 3D rotation vector(s); regularized first
        :returns exponential_mat: 3d array of 3x3 matrices, one per vector
        """
        rot_vec = self.rotations.regularize(rot_vec)
        n_rot_vecs, _ = rot_vec.shape

        angle = np.linalg.norm(rot_vec, axis=1)
        angle = vectorization.expand_dims(angle, to_ndim=2, axis=1)

        skew_rot_vec = so_group.skew_matrix_from_vector(rot_vec)

        coef_1 = np.empty_like(angle)
        coef_2 = np.empty_like(coef_1)

        mask_0 = np.equal(angle, 0)
        mask_0 = np.squeeze(mask_0, axis=1)
        mask_close_to_0 = np.isclose(angle, 0)
        mask_close_to_0 = np.squeeze(mask_close_to_0, axis=1)
        mask_else = ~mask_0 & ~mask_close_to_0

        # Taylor expansions about angle = 0, in even powers of the angle
        # (same series as in group_exp_from_identity).
        coef_1[mask_close_to_0] = (1. / 2.
                                   - angle[mask_close_to_0] ** 2 / 24.
                                   + angle[mask_close_to_0] ** 4 / 720.
                                   - angle[mask_close_to_0] ** 6 / 40320.)
        coef_2[mask_close_to_0] = (1. / 6.
                                   - angle[mask_close_to_0] ** 2 / 120.
                                   + angle[mask_close_to_0] ** 4 / 5040.
                                   - angle[mask_close_to_0] ** 6 / 362880.)

        # Limits at angle == 0. The skew matrix of the zero vector is zero,
        # so the result is the identity either way; using the limits keeps
        # the coefficients continuous.
        coef_1[mask_0] = 1. / 2.
        coef_2[mask_0] = 1. / 6.

        coef_1[mask_else] = (angle[mask_else] ** (-2)
                             * (1. - np.cos(angle[mask_else])))
        coef_2[mask_else] = (angle[mask_else] ** (-2)
                             * (1. - (np.sin(angle[mask_else])
                                      / angle[mask_else])))

        term_1 = np.zeros((n_rot_vecs, self.n, self.n))
        term_2 = np.zeros_like(term_1)

        for i in range(n_rot_vecs):
            term_1[i] = np.eye(self.n) + skew_rot_vec[i] * coef_1[i]
            term_2[i] = np.matmul(skew_rot_vec[i], skew_rot_vec[i]) * coef_2[i]

        exponential_mat = term_1 + term_2
        assert exponential_mat.ndim == 3
        return exponential_mat

    def group_exponential_barycenter(self, points, weights=None):
        """
        Compute the group exponential barycenter.

        The rotation part is delegated to
        self.rotations.group_exponential_barycenter. The translation part
        is inv(sum_i w_i A_i) applied to sum_i w_i A_i R_i^{-1} t_i, where
        R_i is the rotation matrix of point i, t_i its translation and
        A_i = inv(self.exponential_matrix(log(R_mean R_i^{-1}))).

        :param points: SE3 data points, Nx6 array
        :param weights: data point weights, Nx1 array; defaults to ones
        :returns exp_bar: 1x6 array, the barycenter [r_mean t_mean]
        """

        n_points = points.shape[0]
        assert n_points > 0

        if weights is None:
            weights = np.ones((n_points, 1))

        weights = vectorization.expand_dims(weights, to_ndim=2, axis=1)
        n_weights, _ = weights.shape
        assert n_points == n_weights

        dim = self.dimension
        rotations = self.rotations
        dim_rotations = rotations.dimension

        rotation_vectors = points[:, :dim_rotations]
        translations = points[:, dim_rotations:dim]
        assert rotation_vectors.shape == (n_points, dim_rotations)
        assert translations.shape == (n_points, self.n)

        mean_rotation = rotations.group_exponential_barycenter(
                                                points=rotation_vectors,
                                                weights=weights)
        mean_rotation_mat = rotations.matrix_from_rotation_vector(
                    mean_rotation)

        # Accumulators for sum_i w_i A_i and sum_i w_i A_i R_i^{-1} t_i.
        matrix = np.zeros((1,) + (self.n,) * 2)
        translation_aux = np.zeros((1, self.n))

        inv_rot_mats = rotations.matrix_from_rotation_vector(
                -rotation_vectors)
        # TODO(nina): this is the same mat multiplied several times
        matrix_aux = np.matmul(mean_rotation_mat, inv_rot_mats)
        assert matrix_aux.shape == (n_points,) + (dim_rotations,) * 2

        # A_i = inv(exponential_matrix(rotation vector of R_mean R_i^{-1}))
        vec_aux = rotations.rotation_vector_from_matrix(matrix_aux)
        matrix_aux = self.exponential_matrix(vec_aux)
        matrix_aux = np.linalg.inv(matrix_aux)

        for i in range(n_points):
            matrix += weights[i] * matrix_aux[i]
            translation_aux += weights[i] * np.dot(np.matmul(
                                                        matrix_aux[i],
                                                        inv_rot_mats[i]),
                                                   translations[i])

        # np.dot of a (1, n) array with a (1, n, n) array yields shape
        # (1, 1, n); the assignment into exp_bar below relies on numpy
        # dropping those leading unit axes.
        mean_translation = np.dot(translation_aux,
                                  np.transpose(np.linalg.inv(matrix),
                                               axes=(0, 2, 1)))

        exp_bar = np.zeros((1, dim))
        exp_bar[0, :dim_rotations] = mean_rotation
        exp_bar[0, dim_rotations:dim] = mean_translation

        return exp_bar