def _save_and_score_plane(fixed_image_sitk, image_batch, rx, tx, index):
    """Resample the predicted plane, dump it next to the ground truth, and score it.

    Args:
        fixed_image_sitk: SimpleITK volume the predicted plane is resampled from.
        image_batch: ground-truth image batch from the input pipeline; its first
            element is written to disk (``main`` uses a batch size of 1).
        rx: predicted rotation, forwarded unchanged to ``resample_sitk``.
        tx: predicted translation, forwarded unchanged to ``resample_sitk``.
        index: sample index used in the ``imgdump/`` file names.

    Returns:
        Tuple ``(cc, mse, psnr, ssim)`` comparing the resampled prediction with
        the squeezed ground-truth image.
    """
    image_true = np.squeeze(image_batch)
    image_pred = resample_sitk(fixed_image_sitk, rx, tx)
    imageio.imsave('imgdump/image_{}_true.png'.format(index),
                   np.uint8(image_batch[0, ...]))
    imageio.imsave('imgdump/image_{}_pred.png'.format(index),
                   np.uint8(image_pred))
    return (calc_correlation(image_pred, image_true),
            calc_mse(image_pred, image_true),
            calc_psnr(image_pred, image_true),
            calc_ssim(image_pred, image_true))


def main(argv):
    """Evaluate a trained plane-regression network on one test subject.

    Reads the subject's TFRecord test file and NIfTI volume, restores the
    latest checkpoint in ``FLAGS.model_dir`` and, for ``FLAGS.n_iter`` samples,
    turns the network output (interpreted according to ``FLAGS.loss``:
    'PoseNet', 'AP' or 'SE3') into a rotation ``rx`` and translation ``tx``.
    It then resamples the predicted plane from the volume, writes true and
    predicted images to ``imgdump/``, and finally prints the median CC, MSE,
    PSNR and SSIM. In 'SE3' mode it also prints normalised per-component
    weights derived from the inverse covariance of the per-sample SE(3)
    errors.

    Args:
        argv: unused by this function.

    Raises:
        SystemExit: if ``FLAGS.loss`` is not one of the supported modes.
    """
    import os  # local import: only needed to create the image dump directory

    # Fail fast before building the pipeline or reading the volume.
    if FLAGS.loss not in ('PoseNet', 'AP', 'SE3'):
        print('Invalid Option:', FLAGS.loss)
        raise SystemExit

    # TF Record input pipeline for this subject, one sample per batch.
    datafiles = FLAGS.data_dir + '/test/' + FLAGS.subject_id + '.tfrecord'
    dataset = tf.data.TFRecordDataset(datafiles)
    dataset = dataset.map(_parse_function_ifind)
    dataset = dataset.batch(1)
    image, vec, qt, AP1, AP2, AP3 = dataset.make_one_shot_iterator().get_next()

    # Nifti Volume, intensities rescaled to [0, 255].
    # NOTE(review): the array round-trip rebuilds the image without the original
    # spacing/origin/direction — confirm this geometry reset is intended.
    subject_path = FLAGS.scan_dir + '/test/' + FLAGS.subject_id + '.nii.gz'
    fixed_image_sitk_tmp = sitk.ReadImage(subject_path, sitk.sitkFloat32)
    fixed_image_sitk = sitk.GetImageFromArray(
        sitk.GetArrayFromImage(fixed_image_sitk_tmp))
    fixed_image_sitk = sitk.RescaleIntensity(fixed_image_sitk, 0, 1) * 255.

    # Network Definition: output layout depends on the training loss.
    image_resized = tf.image.resize_images(image, size=[224, 224])
    if FLAGS.loss == 'PoseNet':
        # 4 quaternion components followed by 3 translation components.
        y_pred, _ = inception.inception_v3(image_resized, num_classes=7,
                                           is_training=False)
        quaternion_pred, translation_pred = tf.split(y_pred, [4, 3], axis=1)
    elif FLAGS.loss == 'AP':
        # Three 3-vectors: anchor points AP1, AP2, AP3.
        y_pred, _ = inception.inception_v3(image_resized, num_classes=9,
                                           is_training=False)
        AP1_pred, AP2_pred, AP3_pred = tf.split(y_pred, 3, axis=1)
    else:
        # 3 rotation-vector components followed by 3 translation components.
        y_pred, _ = inception.inception_v3(image_resized, num_classes=6,
                                           is_training=False)

    # imageio.imsave does not create missing directories.
    os.makedirs('imgdump', exist_ok=True)

    scores = []  # one (cc, mse, psnr, ssim) tuple per evaluated sample
    # The context manager guarantees the session is closed, even on error.
    with tf.Session() as sess:
        ckpt_file = tf.train.latest_checkpoint(FLAGS.model_dir)
        tf.train.Saver().restore(sess, ckpt_file)
        print('restoring parameters from', ckpt_file)

        # Predicted translations are multiplied by 60 before resampling
        # (scale presumably matches the training normalisation — TODO confirm).
        if FLAGS.loss == 'PoseNet':
            so3_group = SpecialOrthogonalGroup(3)
            for i in tqdm.tqdm(range(FLAGS.n_iter)):
                _image, _quaternion_pred, _translation_pred = sess.run(
                    [image, quaternion_pred, translation_pred])
                rx = so3_group.matrix_from_quaternion(_quaternion_pred)[0]
                tx = _translation_pred[0] * 60.
                scores.append(
                    _save_and_score_plane(fixed_image_sitk, _image, rx, tx, i))
        elif FLAGS.loss == 'AP':
            for i in tqdm.tqdm(range(FLAGS.n_iter)):
                _image, _AP1_pred, _AP2_pred, _AP3_pred = sess.run(
                    [image, AP1_pred, AP2_pred, AP3_pred])
                rx = matrix_from_anchor_points(_AP1_pred[0], _AP2_pred[0],
                                               _AP3_pred[0])
                tx = _AP2_pred[0] * 60.
                scores.append(
                    _save_and_score_plane(fixed_image_sitk, _image, rx, tx, i))
        else:
            so3_group = SpecialOrthogonalGroup(3)
            se3_group = SpecialEuclideanGroup(3)
            se3_errors = []
            for i in tqdm.tqdm(range(FLAGS.n_iter)):
                _image, _rvec, _tvec, _y_pred = sess.run(
                    [image, vec, AP2, y_pred])
                rx = so3_group.matrix_from_rotation_vector(_y_pred[0, :3])[0]
                tx = _y_pred[0, 3:] * 60.
                scores.append(
                    _save_and_score_plane(fixed_image_sitk, _image, rx, tx, i))
                _y_true = np.concatenate((_rvec, _tvec), axis=-1)
                se3_errors.append(
                    se3_group.compose(se3_group.inverse(_y_true), _y_pred))
            # A sample covariance needs at least two observations.
            if len(se3_errors) > 1:
                err_vec = np.vstack(se3_errors)
                err_weights = np.diag(np.linalg.inv(np.cov(err_vec.T)))
                err_weights = err_weights / np.linalg.norm(err_weights)
                print(err_weights)
            else:
                print('Not enough samples to estimate SE3 error weights.')

    if not scores:
        print('No samples were evaluated.')
        return
    # zip(*scores) regroups the per-sample tuples into one sequence per metric.
    for label, values in zip(('CC', 'MSE', 'PSNR', 'SSIM'), zip(*scores)):
        print(label + ':', np.median(np.stack(values)))
def _kc_quaternion_angle_deg(q_true, q_pred):
    """Return the angle, in degrees, between the rotations of two quaternions.

    Each argument is divided by its norm over all of its elements (with the
    batch size of 1 used by ``main`` that is the single quaternion). The
    angle is then ``2 * arccos(|<q1, q2>|)`` converted to degrees, so q and -q
    count as the same rotation.
    """
    q1 = q_true / np.linalg.norm(q_true)
    q2 = q_pred / np.linalg.norm(q_pred)
    d = abs(np.sum(np.multiply(q1, q2)))
    # Rounding can leave |<q1, q2>| a hair above 1 for nearly identical
    # rotations. arccos returns NaN there, and one NaN sample also turns the
    # median summary into NaN, so clamp to arccos's domain.
    d = np.clip(d, 0., 1.)
    return 2. * np.arccos(d) * 180. / np.pi


def _kc_run_until_exhausted(sess, fetches):
    """Yield ``sess.run(fetches)`` once per element of the one-shot input pipeline.

    Prints 'End of Test Data' and stops when the iterator raises
    ``tf.errors.OutOfRangeError``. Only the ``sess.run`` call is guarded, so
    errors raised while processing a sample are not mistaken for end-of-data.
    """
    while True:
        try:
            values = sess.run(fetches)
        except tf.errors.OutOfRangeError:
            print('End of Test Data')
            return
        yield values


def main(argv):
    """Evaluate a trained camera-pose network on the full test TFRecord file.

    Restores the latest checkpoint in ``FLAGS.model_dir`` and makes one pass
    over ``dataset_test.tfrecords``. For each sample it prints the translation
    error (Euclidean norm) and the rotation error in degrees, then prints the
    median of both. In 'SE3' mode each sample also reports the squared
    distance under a left-invariant SE(3) metric (identity inner product at
    the identity). At the end it prints normalised per-component weights
    derived from the inverse covariance of the per-sample SE(3) errors.

    Args:
        argv: unused by this function.

    Raises:
        SystemExit: if ``FLAGS.loss`` is neither 'PoseNet' nor 'SE3'.
    """
    # Fail fast before building the pipeline or the network.
    if FLAGS.loss not in ('PoseNet', 'SE3'):
        print('Invalid Option:', FLAGS.loss)
        raise SystemExit

    # TF Record input pipeline: single pass, one sample per batch.
    dataset = tf.data.TFRecordDataset(FLAGS.data_dir + '/dataset_test.tfrecords')
    dataset = dataset.map(_parse_function_kingscollege)
    dataset = dataset.batch(1)
    image, vec, pose_q, pose_x = dataset.make_one_shot_iterator().get_next()

    # Network Definition: output layout depends on the training loss.
    image_resized = tf.image.resize_images(image, size=[224, 224])
    if FLAGS.loss == 'PoseNet':
        # 4 quaternion components followed by 3 translation components.
        y_pred, _ = inception.inception_v3(image_resized, num_classes=7,
                                           is_training=False)
        quaternion_pred, translation_pred = tf.split(y_pred, [4, 3], axis=1)
    else:
        # 3 rotation-vector components followed by 3 translation components.
        y_pred, _ = inception.inception_v3(image_resized, num_classes=6,
                                           is_training=False)

    results = []  # one [translation error, rotation error (deg)] per sample
    # The context manager guarantees the session is closed, even on error.
    with tf.Session() as sess:
        ckpt_file = tf.train.latest_checkpoint(FLAGS.model_dir)
        tf.train.Saver().restore(sess, ckpt_file)
        print('restoring parameters from', ckpt_file)

        if FLAGS.loss == 'PoseNet':
            samples = _kc_run_until_exhausted(
                sess, [pose_q, pose_x, quaternion_pred, translation_pred])
            for i, (q_true, x_true, q_pred, x_pred) in enumerate(samples):
                theta = _kc_quaternion_angle_deg(q_true, q_pred)
                error_x = np.linalg.norm(x_true - x_pred)
                results.append([error_x, theta])
                print('Iteration:', i, 'Error XYZ (m):', error_x,
                      'Error Q (degrees):', theta)
        else:
            so3_group = SpecialOrthogonalGroup(3)
            se3_group = SpecialEuclideanGroup(3)
            metric = InvariantMetric(group=se3_group,
                                     inner_product_mat_at_identity=np.eye(6),
                                     left_or_right='left')
            se3_errors = []
            samples = _kc_run_until_exhausted(
                sess, [vec, pose_q, pose_x, y_pred])
            for i, (rvec, q_true, tvec, y) in enumerate(samples):
                q_pred = so3_group.quaternion_from_rotation_vector(y[0, :3])[0]
                theta = _kc_quaternion_angle_deg(q_true, q_pred)
                error_x = np.linalg.norm(tvec - y[0, 3:])
                results.append([error_x, theta])
                # SE3 compute: true pose as (rotation vector, translation).
                y_true = np.concatenate((rvec, tvec), axis=-1)
                se3_dist = metric.squared_dist(y, y_true)[0]
                se3_errors.append(
                    se3_group.compose(se3_group.inverse(y_true), y))
                print('Iteration:', i, 'Error XYZ (m):', error_x,
                      'Error Q (degrees):', theta, 'SE3 dist:', se3_dist)

            # Calculate SE3 Error Weights (a covariance needs >= 2 samples).
            if len(se3_errors) > 1:
                err_vec = np.vstack(se3_errors)
                err_weights = np.diag(np.linalg.inv(np.cov(err_vec.T)))
                err_weights = err_weights / np.linalg.norm(err_weights)
                print(err_weights)
            else:
                print('Not enough samples to estimate SE3 error weights.')

    if not results:
        print('No test samples were evaluated.')
        return
    median = np.median(np.stack(results), axis=0)
    print('Error XYZ (m):', median[0], 'Error Q (degrees):', median[1])