Exemple #1
0
    def setUp(self):
        """Create the spaces shared by the tests and open a new figure.

        Builds SO(3) and SE(3), the 1- and 2-dimensional hyperspheres and
        the 2-dimensional hyperbolic space, then calls ``plt.figure()`` so
        that plotting starts on a newly created figure.
        """
        self.n_samples = 10

        # Lie groups acting on 3D space.
        self.SO3_GROUP = SpecialOrthogonalGroup(n=3)
        self.SE3_GROUP = SpecialEuclideanGroup(n=3)

        # Low-dimensional Riemannian manifolds.
        self.S1, self.S2 = Hypersphere(dimension=1), Hypersphere(dimension=2)
        self.H2 = HyperbolicSpace(dimension=2)

        plt.figure()
Exemple #2
0
    def setUp(self):
        """Seed the RNG and build GL(n) and SO(n) (n = 3) for the tests."""
        gs.random.seed(1234)

        self.n = 3
        self.n_samples = 2
        dim = self.n

        self.group = GeneralLinearGroup(n=dim)
        # Rotation matrices are invertible, so SO(n) serves as a source of
        # invertible test matrices.
        self.so3_group = SpecialOrthogonalGroup(n=dim)

        warnings.simplefilter('ignore', category=ImportWarning)
Exemple #3
0
    def setUp(self):
        """Build SO(3), four invariant metrics on it, and sample points."""
        gs.random.seed(1234)

        group = SpecialOrthogonalGroup(n=3)

        def make_metric(mat, side):
            # Invariant metric on `group` whose inner product at the
            # identity is `mat`; `side` is 'left' or 'right'.
            return InvariantMetric(
                group=group,
                inner_product_mat_at_identity=mat,
                left_or_right=side)

        # Diagonal left and right invariant metrics.
        diag_mat_at_identity = gs.eye(group.dimension)
        left_diag_metric = make_metric(diag_mat_at_identity, 'left')
        right_diag_metric = make_metric(diag_mat_at_identity, 'right')

        # General left and right invariant metrics.
        # TODO(nina): replace by general SPD matrix
        sym_mat_at_identity = gs.eye(group.dimension)
        left_metric = make_metric(sym_mat_at_identity, 'left')
        right_metric = make_metric(sym_mat_at_identity, 'right')

        self.group = group
        # NOTE(review): 'right_diag_metric' is named differently from
        # 'left_diag'; kept as-is because tests may look it up by this key.
        self.metrics = {
            'left_diag': left_diag_metric,
            'right_diag_metric': right_diag_metric,
            'left': left_metric,
            'right': right_metric,
        }
        self.left_diag_metric = left_diag_metric
        self.right_diag_metric = right_diag_metric
        self.left_metric = left_metric
        self.right_metric = right_metric

        # General-case points.
        self.point_1 = tf.convert_to_tensor([-0.2, 0.9, 0.5])
        self.point_2 = tf.convert_to_tensor([0., 2., -0.1])
        # Edge case: rotation angle below epsilon. Unlike point_1 and
        # point_2, this tensor is 2-D (one row of 3 values).
        self.point_small = tf.convert_to_tensor([[-1e-7, 0., -7 * 1e-8]])
Exemple #4
0
    def __init__(self, n):
        """Build the special Euclidean group SE(n).

        The group is initialized with identity ``gs.zeros(dimension)`` and
        exposes its rotation part SO(n) and translation part R^n. Points
        use the 'vector' representation for n == 3 and 'matrix' otherwise.

        :param n: dimension of the ambient Euclidean space; an int > 1.
        """
        assert isinstance(n, int) and n > 1

        # n(n-1)/2 rotational plus n translational degrees of freedom.
        dim = int((n * (n - 1)) / 2 + n)
        self.n = n
        self.dimension = dim
        super(SpecialEuclideanGroup, self).__init__(
            dimension=dim,
            identity=gs.zeros(dim))

        # TODO(nina): keep the names rotations and translations here?
        self.rotations = SpecialOrthogonalGroup(n=n)
        self.translations = EuclideanSpace(dimension=n)
        self.point_representation = 'matrix' if n != 3 else 'vector'
    def __init__(self, n):
        """Build SE(n); this implementation only supports n == 3.

        :param n: dimension of the ambient Euclidean space.
        :raises NotImplementedError: if ``n != 3``.
        """
        assert n > 1

        # Compare by value. The previous `n is not 3` tested object identity:
        # it emits a SyntaxWarning on Python >= 3.8 and rejects equal values
        # held in other objects (e.g. numpy.int64(3)).
        if n != 3:
            raise NotImplementedError('Only SE(3) is implemented.')

        self.n = n
        # n(n-1)/2 rotational plus n translational degrees of freedom.
        self.dimension = int((n * (n - 1)) / 2 + n)
        super(SpecialEuclideanGroup, self).__init__(
                          dimension=self.dimension,
                          identity=np.zeros(self.dimension))
        # TODO(nina): keep the names rotations and translations here?
        self.rotations = SpecialOrthogonalGroup(n=n)
        self.translations = EuclideanSpace(dimension=n)
Exemple #6
0
    def __init__(self, n, point_type=None):
        """Build the special Euclidean group SE(n).

        :param n: dimension of the ambient Euclidean space; an int > 1.
        :param point_type: default point representation. When None,
            'vector' is used for n == 3 and 'matrix' otherwise.
        """
        assert isinstance(n, int) and n > 1

        self.n = n
        # n(n-1)/2 rotational plus n translational degrees of freedom.
        self.dimension = int((n * (n - 1)) / 2 + n)

        if point_type is not None:
            self.default_point_type = point_type
        elif n == 3:
            self.default_point_type = 'vector'
        else:
            self.default_point_type = 'matrix'

        super(SpecialEuclideanGroup, self).__init__(dimension=self.dimension)

        # TODO(nina): keep the names rotations and translations here?
        self.rotations = SpecialOrthogonalGroup(n=n)
        self.translations = EuclideanSpace(dimension=n)
Exemple #7
0
    def __init__(self, n, point_type=None, epsilon=0.):
        """Build the special Euclidean group SE(n).

        :param n: dimension of the ambient Euclidean space; an int > 1.
        :param point_type: default point representation. When None,
            'vector' is chosen for n == 3 and 'matrix' otherwise.
        :param epsilon: stored on the instance and forwarded to the SO(n)
            rotation subgroup.
        """
        assert isinstance(n, int) and n > 1

        self.n = n
        # n(n-1)/2 rotational plus n translational degrees of freedom.
        self.dimension = int((n * (n - 1)) / 2 + n)
        self.epsilon = epsilon

        if point_type is None:
            point_type = 'vector' if n == 3 else 'matrix'
        self.default_point_type = point_type

        super(SpecialEuclideanGroup, self).__init__(dimension=self.dimension)

        self.rotations = SpecialOrthogonalGroup(n=n, epsilon=epsilon)
        self.translations = EuclideanSpace(dimension=n)
Exemple #8
0
def main(args):
    """Convert a pose-annotated image list into a shuffled TFRecord file.

    Reads ``FLAGS.root_dir/FLAGS.dataset``. That file has three header
    lines, then one record per line: an image file name followed by seven
    floats (a 3-vector translation, then a 4-component quaternion).
    Records are written to ``FLAGS.out_file`` in random order.

    Each record holds:
    - the image, downsampled by 4 in both spatial axes;
    - its shape;
    - the quaternion (``pose_q``);
    - the translation (``pose_x``);
    - ``pose``: the rotation vector concatenated with the translation.

    :param args: unused.
    :raises ValueError: if a non-blank data line does not have 8 fields.
    """
    poses = []
    images = []

    # Processing Image Labels
    logger.info('Processing Image Lables')
    with open(FLAGS.root_dir + '/' + FLAGS.dataset) as f:
        for _ in range(3):  # skip the 3 header lines
            next(f)
        for line in f:
            fields = line.split()
            if not fields:
                continue  # tolerate blank lines, e.g. at end of file
            if len(fields) != 8:
                raise ValueError(
                    'Expected 8 fields (name + 7 pose values), got {}: {!r}'
                    .format(len(fields), line))
            poses.append(tuple(float(v) for v in fields[1:]))
            images.append(FLAGS.root_dir + '/' + fields[0])

    # A single shuffle already yields a uniformly random permutation.
    order = list(range(len(images)))
    random.shuffle(order)

    # Writing TFRecords
    logger.info('Writing TFRecords')

    so3_group = SpecialOrthogonalGroup(3)

    # The context manager closes the writer even if an image fails to load.
    with tf.python_io.TFRecordWriter(FLAGS.out_file) as writer:
        for i in tqdm(order):
            pose_q = np.array(poses[i][3:7])
            pose_x = np.array(poses[i][0:3])

            rot_vec = so3_group.rotation_vector_from_quaternion(pose_q)[0]
            pose = np.concatenate((rot_vec, pose_x), axis=0)

            logger.info('Processing Image: %s', images[i])
            X = imageio.imread(images[i])
            X = X[::4, ::4, :]
            if FLAGS.hist_norm:
                # NOTE(review): equalize_hist presumably returns a float
                # image, so the stored bytes then have a different dtype
                # than without hist_norm - confirm the reader handles it.
                X = exposure.equalize_hist(X)

            # tobytes() is the non-deprecated spelling of tostring().
            img_raw = X.tobytes()
            pose_raw = pose.astype('float32').tobytes()
            pose_q_raw = pose_q.astype('float32').tobytes()
            pose_x_raw = pose_x.astype('float32').tobytes()

            example = tf.train.Example(features=tf.train.Features(feature={
                'height':   _int64_feature(X.shape[0]),
                'width':    _int64_feature(X.shape[1]),
                'channel':  _int64_feature(X.shape[2]),
                'image':    _bytes_feature(img_raw),
                'pose':     _bytes_feature(pose_raw),
                'pose_q':   _bytes_feature(pose_q_raw),
                'pose_x':   _bytes_feature(pose_x_raw)}))

            writer.write(example.SerializeToString())

    # The previous call passed '\n' as the format string plus a stray
    # argument, which made logging report a formatting error instead.
    logger.info('Creating Dataset Success.')
"""Unit tests for visualization module."""

import matplotlib
matplotlib.use('Agg')  # NOQA
import unittest

import geomstats.visualization as visualization
from geomstats.hypersphere import Hypersphere
from geomstats.special_euclidean_group import SpecialEuclideanGroup
from geomstats.special_orthogonal_group import SpecialOrthogonalGroup

# Spaces shared by all visualization tests: SO(3), SE(3) and the 2-sphere.
SO3_GROUP, SE3_GROUP, S2 = (
    SpecialOrthogonalGroup(n=3),
    SpecialEuclideanGroup(n=3),
    Hypersphere(dimension=2),
)

# TODO(nina): add tests for examples


class TestVisualizationMethods(unittest.TestCase):
    """Smoke tests: plotting random group elements must not raise."""

    _multiprocess_can_split_ = True

    def setUp(self):
        self.n_samples = 10

    def _plot_random(self, group, space):
        """Plot ``self.n_samples`` uniformly random elements of *group*."""
        samples = group.random_uniform(self.n_samples)
        visualization.plot(samples, space=space)

    def test_plot_points_so3(self):
        self._plot_random(SO3_GROUP, 'SO3_GROUP')

    def test_plot_points_se3(self):
        self._plot_random(SE3_GROUP, 'SE3_GROUP')
"""
Loss functions for pose prediction on SE(3).
"""
import numpy as np

import geomstats.lie_group as lie_group
from geomstats.special_euclidean_group import SpecialEuclideanGroup
from geomstats.special_orthogonal_group import SpecialOrthogonalGroup

# Module-level groups; SE3 also supplies the default metric of `loss`.
SE3, SO3 = SpecialEuclideanGroup(n=3), SpecialOrthogonalGroup(n=3)


def loss(y_pred,
         y_true,
         metric=SE3.left_canonical_metric,
         representation='vector'):
    """
    Loss function given by a riemannian metric on a Lie group,
    by default the left-invariant canonical metric.

    :param y_pred: predicted poses, either a 1-D array (a single pose) or
        a 2-D array with one pose per row.
    :param y_true: ground-truth poses, same layout as ``y_pred``.
    :param metric: metric used to compare poses. The default,
        ``SE3.left_canonical_metric``, is bound once at import time.
    :param representation: 'vector' or 'quaternion'. With 'quaternion',
        the first 4 columns are read as a quaternion and replaced by the
        corresponding rotation vector; the remaining columns are kept.
        'vector' presumably means rotation vector + translation - confirm.

    NOTE(review): this excerpt ends right after the input normalization
    below. The part that evaluates ``metric`` on ``y_pred``/``y_true`` and
    returns the loss is not visible here; as shown, the function returns
    None. Confirm against the full source.
    """
    # Promote a single pose to a batch of one so the column slicing
    # below works uniformly.
    if y_pred.ndim == 1:
        y_pred = np.expand_dims(y_pred, axis=0)
    if y_true.ndim == 1:
        y_true = np.expand_dims(y_true, axis=0)

    if representation == 'quaternion':
        # Swap the 4 quaternion columns for a rotation vector, keeping
        # the trailing columns unchanged.
        y_pred_rot_vec = SO3.rotation_vector_from_quaternion(y_pred[:, :4])
        y_pred = np.hstack([y_pred_rot_vec, y_pred[:, 4:]])
        y_true_rot_vec = SO3.rotation_vector_from_quaternion(y_true[:, :4])
        y_true = np.hstack([y_true_rot_vec, y_true[:, 4:]])
Exemple #11
0
def main(argv):
    """Evaluate a pose regressor by re-slicing the subject's volume.

    For each of ``FLAGS.n_iter`` test images:
    - the network predicts a pose, in the format selected by
      ``FLAGS.loss`` ('PoseNet', 'AP' or 'SE3');
    - the NIfTI volume is resampled at that pose;
    - the true and predicted images are written to ``imgdump/``;
    - CC, MSE, PSNR and SSIM between them are recorded.

    The medians of those metrics are printed at the end. With 'SE3',
    normalized inverse-variance weights of the SE(3) errors are also
    printed.

    :param argv: unused.
    :raises SystemExit: with a non-zero status if ``FLAGS.loss`` is not a
        supported option.
    """
    import os

    # Fail fast, and with a non-zero exit status: a bare `raise SystemExit`
    # after a print would report success.
    if FLAGS.loss not in ('PoseNet', 'AP', 'SE3'):
        raise SystemExit('Invalid Option: {}'.format(FLAGS.loss))

    # TF Record
    datafiles = FLAGS.data_dir + '/test/' + FLAGS.subject_id + '.tfrecord'
    dataset = tf.data.TFRecordDataset(datafiles)
    dataset = dataset.map(_parse_function_ifind)
    dataset = dataset.batch(1)
    image, vec, qt, AP1, AP2, AP3 = dataset.make_one_shot_iterator().get_next()

    # Nifti Volume
    subject_path = FLAGS.scan_dir + '/test/' + FLAGS.subject_id + '.nii.gz'
    fixed_image_sitk_tmp = sitk.ReadImage(subject_path, sitk.sitkFloat32)
    # Round-trip through a numpy array; presumably to reset the image's
    # spatial metadata to defaults - confirm.
    fixed_image_sitk = sitk.GetImageFromArray(
        sitk.GetArrayFromImage(fixed_image_sitk_tmp))
    fixed_image_sitk = sitk.RescaleIntensity(fixed_image_sitk, 0, 1) * 255.

    # Network Definition
    image_resized = tf.image.resize_images(image, size=[224, 224])

    # Measurements
    cc, mse, psnr, ssim = [], [], [], []

    def restore_session():
        # Open a session and load the latest checkpoint. Must be called
        # after the network is built so the Saver sees its variables.
        sess = tf.Session()
        ckpt_file = tf.train.latest_checkpoint(FLAGS.model_dir)
        tf.train.Saver().restore(sess, ckpt_file)
        print('restoring parameters from', ckpt_file)
        return sess

    def evaluate_pose(i, _image, rx, tx):
        # Resample the volume at rotation `rx` / translation `tx`, dump
        # both images and record the similarity metrics.
        image_true = np.squeeze(_image)
        image_pred = resample_sitk(fixed_image_sitk, rx, tx)

        imageio.imsave('imgdump/image_{}_true.png'.format(i),
                       np.uint8(_image[0, ...]))
        imageio.imsave('imgdump/image_{}_pred.png'.format(i),
                       np.uint8(image_pred))

        cc.append(calc_correlation(image_pred, image_true))
        mse.append(calc_mse(image_pred, image_true))
        psnr.append(calc_psnr(image_pred, image_true))
        ssim.append(calc_ssim(image_pred, image_true))

    # imsave fails if the output directory does not exist.
    os.makedirs('imgdump', exist_ok=True)

    if FLAGS.loss == 'PoseNet':

        y_pred, _ = inception.inception_v3(image_resized,
                                           num_classes=7,
                                           is_training=False)
        quaternion_pred, translation_pred = tf.split(y_pred, [4, 3], axis=1)

        SO3_GROUP = SpecialOrthogonalGroup(3)

        with restore_session() as sess:
            for i in tqdm.tqdm(range(FLAGS.n_iter)):
                _image, _, _, _quaternion_pred, _translation_pred = \
                    sess.run([image, qt, AP2, quaternion_pred,
                              translation_pred])

                rx = SO3_GROUP.matrix_from_quaternion(_quaternion_pred)[0]
                tx = _translation_pred[0] * 60.
                evaluate_pose(i, _image, rx, tx)

    elif FLAGS.loss == 'AP':

        y_pred, _ = inception.inception_v3(image_resized,
                                           num_classes=9,
                                           is_training=False)
        AP1_pred, AP2_pred, AP3_pred = tf.split(y_pred, 3, axis=1)

        with restore_session() as sess:
            for i in tqdm.tqdm(range(FLAGS.n_iter)):
                _image, _, _, _, _AP1_pred, _AP2_pred, _AP3_pred = \
                    sess.run([image, AP1, AP2, AP3,
                              AP1_pred, AP2_pred, AP3_pred])

                rx = matrix_from_anchor_points(_AP1_pred[0], _AP2_pred[0],
                                               _AP3_pred[0])
                tx = _AP2_pred[0] * 60.
                evaluate_pose(i, _image, rx, tx)

    elif FLAGS.loss == 'SE3':

        y_pred, _ = inception.inception_v3(image_resized,
                                           num_classes=6,
                                           is_training=False)

        SO3_GROUP = SpecialOrthogonalGroup(3)
        SE3_GROUP = SpecialEuclideanGroup(3)
        se3_errors = []

        with restore_session() as sess:
            for i in tqdm.tqdm(range(FLAGS.n_iter)):
                _image, _rvec, _tvec, _y_pred = \
                    sess.run([image, vec, AP2, y_pred])

                rx = SO3_GROUP.matrix_from_rotation_vector(_y_pred[0, :3])[0]
                tx = _y_pred[0, 3:] * 60.
                evaluate_pose(i, _image, rx, tx)

                # Group error between ground truth and prediction.
                _y_true = np.concatenate((_rvec, _tvec), axis=-1)
                se3_errors.append(
                    SE3_GROUP.compose(SE3_GROUP.inverse(_y_true), _y_pred))

        # A covariance estimate needs at least two samples.
        if len(se3_errors) > 1:
            err_vec = np.vstack(se3_errors)
            err_weights = np.diag(np.linalg.inv(np.cov(err_vec.T)))
            err_weights = err_weights / np.linalg.norm(err_weights)
            print(err_weights)
        else:
            print('Not enough samples to estimate SE3 error weights.')

    if not cc:
        print('No test images were evaluated.')
        return

    print('CC:', np.median(np.stack(cc)))
    print('MSE:', np.median(np.stack(mse)))
    print('PSNR:', np.median(np.stack(psnr)))
    print('SSIM:', np.median(np.stack(ssim)))
Exemple #12
0
def main(argv):
    """Evaluate a pose regressor by averaging repeated network outputs.

    Each of ``FLAGS.n_iter`` test images is handled as follows:
    - the resized image is fed ``FLAGS.n_samples`` times through the
      network and the predictions are averaged;
    - the NIfTI volume is resampled at the averaged pose;
    - the true and predicted images are written to ``imgdump/``;
    - CC, MSE, PSNR and SSIM between them are recorded.

    The network is built without ``is_training=False``, presumably so that
    dropout stays active and repeated runs give different samples -
    confirm. The metric medians are printed at the end.

    :param argv: unused.
    :raises SystemExit: with a non-zero status if ``FLAGS.loss`` is not
        'PoseNet', 'AP' or 'SE3'.
    """
    import os

    # Fail fast, and with a non-zero exit status: a bare `raise SystemExit`
    # after a print would report success.
    if FLAGS.loss not in ('PoseNet', 'AP', 'SE3'):
        raise SystemExit('Invalid Option: {}'.format(FLAGS.loss))

    # TF Record
    datafiles = FLAGS.data_dir + '/test/' + FLAGS.subject_id + '.tfrecord'
    dataset = tf.data.TFRecordDataset(datafiles)
    dataset = dataset.map(_parse_function_ifind)
    dataset = dataset.batch(1)
    image, vec, qt, AP1, AP2, AP3 = dataset.make_one_shot_iterator().get_next()

    # Nifti Volume
    subject_path = FLAGS.scan_dir + '/test/' + FLAGS.subject_id + '.nii.gz'
    fixed_image_sitk_tmp = sitk.ReadImage(subject_path, sitk.sitkFloat32)
    fixed_image_sitk = sitk.GetImageFromArray(
        sitk.GetArrayFromImage(fixed_image_sitk_tmp))
    fixed_image_sitk = sitk.RescaleIntensity(fixed_image_sitk, 0, 1)  # * 255.

    # Network Definition. The network reads from a placeholder so that the
    # same resized test image can be fed several times.
    image_input = tf.placeholder(shape=[1, 224, 224, 1], dtype=tf.float32)
    image_resized = tf.image.resize_images(image, size=[224, 224])

    # Measurements (previously computed and then discarded).
    cc, mse, psnr, ssim = [], [], [], []

    def restore_session():
        # Open a session and load the latest checkpoint. Must be called
        # after the network is built so the Saver sees its variables.
        sess = tf.Session()
        ckpt_file = tf.train.latest_checkpoint(FLAGS.model_dir)
        tf.train.Saver().restore(sess, ckpt_file)
        print('restoring parameters from', ckpt_file)
        return sess

    def evaluate_pose(i, _image, rx, tx):
        # Resample the volume at rotation `rx` / translation `tx`, dump
        # both images and record the similarity metrics.
        image_true = np.squeeze(_image)
        image_pred = resample_sitk(fixed_image_sitk, rx, tx)

        imageio.imsave('imgdump/image_{}_true.png'.format(i), _image[0,
                                                                     ...])
        imageio.imsave('imgdump/image_{}_pred.png'.format(i), image_pred)

        psnr.append(calc_psnr(image_pred, image_true))
        mse.append(calc_mse(image_pred, image_true))
        ssim.append(calc_ssim(image_pred, image_true))
        cc.append(calc_correlation(image_pred, image_true))

    # imsave fails if the output directory does not exist.
    os.makedirs('imgdump', exist_ok=True)

    if FLAGS.loss == 'PoseNet':

        y_pred, _ = inception.inception_v3(image_input, num_classes=7)
        quaternion_pred, translation_pred = tf.split(y_pred, [4, 3], axis=1)

        SO3_GROUP = SpecialOrthogonalGroup(3)

        with restore_session() as sess:
            for i in range(FLAGS.n_iter):

                _image, _image_resized, _quaternion_true, _translation_true = \
                    sess.run([image, image_resized, qt, AP2])

                _quaternion_pred_sample = []
                _translation_pred_sample = []
                for _ in range(FLAGS.n_samples):
                    _quaternion_pred_i, _translation_pred_i = \
                        sess.run([quaternion_pred, translation_pred],
                                 feed_dict={image_input: _image_resized})
                    _quaternion_pred_sample.append(_quaternion_pred_i)
                    _translation_pred_sample.append(_translation_pred_i)
                    print(_quaternion_pred_i, _translation_pred_i)

                # Average rotations with the metric's mean in rotation-vector
                # form, then convert back to a quaternion.
                _rotvec_pred_sample = SO3_GROUP.rotation_vector_from_quaternion(
                    np.vstack(_quaternion_pred_sample))
                _rotvec_pred = SO3_GROUP.left_canonical_metric.mean(
                    _rotvec_pred_sample)
                _quaternion_pred = SO3_GROUP.quaternion_from_rotation_vector(
                    _rotvec_pred)

                _translation_pred_sample = np.vstack(_translation_pred_sample)
                # Shape (3,): one averaged translation. Indexing it with [0]
                # (as before) kept only its first component.
                _translation_pred = np.mean(_translation_pred_sample, axis=0)
                # Spread of the sampled translations; not yet reported.
                _translation_pred_variance = np.var(_translation_pred_sample,
                                                    axis=0)

                rx = SO3_GROUP.matrix_from_quaternion(_quaternion_pred)[0]
                tx = _translation_pred * 60.
                evaluate_pose(i, _image, rx, tx)

    elif FLAGS.loss == 'AP':

        y_pred, _ = inception.inception_v3(image_input, num_classes=9)
        AP1_pred, AP2_pred, AP3_pred = tf.split(y_pred, 3, axis=1)

        with restore_session() as sess:
            for i in range(FLAGS.n_iter):

                _image, _image_resized, _AP1, _AP2, _AP3 = \
                    sess.run([image, image_resized, AP1, AP2, AP3])

                _AP1_sample, _AP2_sample, _AP3_sample = [], [], []
                for _ in range(FLAGS.n_samples):
                    _AP1_pred_i, _AP2_pred_i, _AP3_pred_i = \
                        sess.run([AP1_pred, AP2_pred, AP3_pred],
                                 feed_dict={image_input: _image_resized})
                    _AP1_sample.append(_AP1_pred_i)
                    _AP2_sample.append(_AP2_pred_i)
                    _AP3_sample.append(_AP3_pred_i)

                _AP1_sample = np.vstack(_AP1_sample)
                _AP2_sample = np.vstack(_AP2_sample)
                _AP3_sample = np.vstack(_AP3_sample)

                # Each mean has shape (3,): one averaged anchor point. The
                # previous [0] indexing reduced each to a scalar.
                _AP1_pred = np.mean(_AP1_sample, axis=0)
                _AP2_pred = np.mean(_AP2_sample, axis=0)
                _AP3_pred = np.mean(_AP3_sample, axis=0)

                # Per-coordinate spread of the samples; not yet reported.
                _AP1_pred_variance = np.var(_AP1_sample, axis=0)
                _AP2_pred_variance = np.var(_AP2_sample, axis=0)
                _AP3_pred_variance = np.var(_AP3_sample, axis=0)

                rx = matrix_from_anchor_points(_AP1_pred, _AP2_pred, _AP3_pred)
                tx = _AP2_pred * 60.
                evaluate_pose(i, _image, rx, tx)

    elif FLAGS.loss == 'SE3':

        y_pred, _ = inception.inception_v3(image_input, num_classes=6)

        SO3_GROUP = SpecialOrthogonalGroup(3)
        SE3_GROUP = SpecialEuclideanGroup(3)

        with restore_session() as sess:
            for i in range(FLAGS.n_iter):

                print(i)

                _image, _image_resized, _rvec, _tvec = \
                    sess.run([image, image_resized, vec, AP2])

                _y_pred_sample = np.vstack([
                    sess.run(y_pred, feed_dict={image_input: _image_resized})
                    for _ in range(FLAGS.n_samples)])
                _y_pred = SE3_GROUP.left_canonical_metric.mean(_y_pred_sample)
                # Spread of the samples; not yet reported.
                _y_pred_variance = SE3_GROUP.left_canonical_metric.variance(
                    _y_pred_sample)

                rx = SO3_GROUP.matrix_from_rotation_vector(_y_pred[0, :3])[0]
                tx = _y_pred[0, 3:] * 60.
                evaluate_pose(i, _image, rx, tx)

    if not cc:
        print('No test images were evaluated.')
        return

    print('CC:', np.median(np.stack(cc)))
    print('MSE:', np.median(np.stack(mse)))
    print('PSNR:', np.median(np.stack(psnr)))
    print('SSIM:', np.median(np.stack(ssim)))
Exemple #13
0
    def setUp(self):
        """Silence ImportWarning and create SO(3) for the tests."""
        warnings.simplefilter('ignore', category=ImportWarning)

        self.so3_group, self.n_samples = SpecialOrthogonalGroup(n=3), 2
Exemple #14
0
 def setUp(self):
     """Seed the RNG and build GL(3) and SO(3) for the tests."""
     gs.random.seed(1234)
     dim = 3
     self.group = GeneralLinearGroup(n=dim)
     # Rotation matrices are invertible, so SO(3) provides invertible
     # test matrices.
     self.so3_group = SpecialOrthogonalGroup(n=dim)
def _quaternion_angle_error_deg(q_true, q_pred):
    """Return the angle in degrees between two quaternion orientations.

    Both inputs are normalized first. The absolute value of the dot
    product makes q and -q compare equal. The dot product is capped at 1
    because rounding can push it slightly above 1, where ``arccos`` would
    return NaN.
    """
    q1 = q_true / np.linalg.norm(q_true)
    q2 = q_pred / np.linalg.norm(q_pred)
    d = min(abs(np.sum(np.multiply(q1, q2))), 1.)
    return 2. * np.arccos(d) * 180. / np.pi


def main(argv):
    """Report pose-regression errors on the test TFRecord split.

    Runs the restored network over ``FLAGS.data_dir/dataset_test.tfrecords``
    until the dataset is exhausted. For every sample it prints the
    translation error and the quaternion angle error (and, for 'SE3', the
    squared SE(3) distance). At the end it prints the medians of both
    errors. For 'SE3' it also prints normalized inverse-variance weights
    of the SE(3) errors.

    :param argv: unused.
    :raises SystemExit: with a non-zero status if ``FLAGS.loss`` is not
        'PoseNet' or 'SE3'.
    """
    # Fail fast, and with a non-zero exit status: a bare `raise SystemExit`
    # after a print would report success.
    if FLAGS.loss not in ('PoseNet', 'SE3'):
        raise SystemExit('Invalid Option: {}'.format(FLAGS.loss))

    # TF Record
    dataset = tf.data.TFRecordDataset(FLAGS.data_dir +
                                      '/dataset_test.tfrecords')
    dataset = dataset.map(_parse_function_kingscollege)
    dataset = dataset.batch(1)
    image, vec, pose_q, pose_x = dataset.make_one_shot_iterator().get_next()

    # Network Definition
    image_resized = tf.image.resize_images(image, size=[224, 224])

    def restore_session():
        # Open a session and load the latest checkpoint. Must be called
        # after the network is built so the Saver sees its variables.
        sess = tf.Session()
        ckpt_file = tf.train.latest_checkpoint(FLAGS.model_dir)
        tf.train.Saver().restore(sess, ckpt_file)
        print('restoring parameters from', ckpt_file)
        return sess

    # One [translation error, angle error in degrees] pair per sample.
    results = []

    if FLAGS.loss == 'PoseNet':

        y_pred, _ = inception.inception_v3(image_resized,
                                           num_classes=7,
                                           is_training=False)
        quaternion_pred, translation_pred = tf.split(y_pred, [4, 3], axis=1)

        with restore_session() as sess:
            i = 0
            try:
                while True:
                    _, _quaternion_true, _translation_true, \
                        _quaternion_pred, _translation_pred = \
                        sess.run([image, pose_q, pose_x,
                                  quaternion_pred, translation_pred])

                    theta = _quaternion_angle_error_deg(_quaternion_true,
                                                        _quaternion_pred)
                    error_x = np.linalg.norm(
                        _translation_true - _translation_pred)
                    results.append([error_x, theta])

                    print('Iteration:', i, 'Error XYZ (m):', error_x,
                          'Error Q (degrees):', theta)
                    i += 1

            except tf.errors.OutOfRangeError:
                print('End of Test Data')

    elif FLAGS.loss == 'SE3':

        y_pred, _ = inception.inception_v3(image_resized,
                                           num_classes=6,
                                           is_training=False)

        SO3_GROUP = SpecialOrthogonalGroup(3)
        SE3_GROUP = SpecialEuclideanGroup(3)
        metric = InvariantMetric(group=SE3_GROUP,
                                 inner_product_mat_at_identity=np.eye(6),
                                 left_or_right='left')

        se3_errors = []

        with restore_session() as sess:
            i = 0
            try:
                while True:
                    _, _rvec, _qvec, _tvec, _y_pred = \
                        sess.run([image, vec, pose_q, pose_x, y_pred])

                    _quaternion_pred = SO3_GROUP.quaternion_from_rotation_vector(
                        _y_pred[0, :3])[0]

                    theta = _quaternion_angle_error_deg(_qvec, _quaternion_pred)
                    error_x = np.linalg.norm(_tvec - _y_pred[0, 3:])
                    results.append([error_x, theta])

                    # SE3 compute
                    _y_true = np.concatenate((_rvec, _tvec), axis=-1)
                    se3_dist = metric.squared_dist(_y_pred, _y_true)[0]
                    se3_errors.append(
                        SE3_GROUP.compose(SE3_GROUP.inverse(_y_true), _y_pred))

                    print('Iteration:', i, 'Error XYZ (m):', error_x,
                          'Error Q (degrees):', theta, 'SE3 dist:', se3_dist)
                    i += 1

            except tf.errors.OutOfRangeError:
                print('End of Test Data')

        # Calculate SE3 Error Weights (a covariance needs >= 2 samples).
        if len(se3_errors) > 1:
            err_vec = np.vstack(se3_errors)
            err_weights = np.diag(np.linalg.inv(np.cov(err_vec.T)))
            err_weights = err_weights / np.linalg.norm(err_weights)
            print(err_weights)
        else:
            print('Not enough samples to estimate SE3 error weights.')

    if not results:
        print('No test samples were evaluated.')
        return

    results = np.median(np.stack(results), axis=0)
    print('Error XYZ (m):', results[0], 'Error Q (degrees):', results[1])