random.shuffle(r) ############################################################################### # Create Database print 'Creating PoseNet Dataset.' SO3_GROUP = SpecialOrthogonalGroup(3) writer = tf.python_io.TFRecordWriter(out_file) for i in tqdm(r): pose_q = np.array(poses[i][3:7]) pose_x = np.array(poses[i][0:3]) rot_vec = SO3_GROUP.rotation_vector_from_quaternion(pose_q)[0] pose = np.concatenate((rot_vec, pose_x), axis=0) X = imageio.imread(images[i]) X = X[::4, ::4, :] #X = exposure.equalize_hist(X) img_raw = X.tostring() #.astype('float32').tostring() pose_raw = pose.astype('float32').tostring() pose_q_raw = pose_q.astype('float32').tostring() pose_x_raw = pose_x.astype('float32').tostring() example = tf.train.Example(features=tf.train.Features( feature={ 'height': _int64_feature(X.shape[0]), 'width': _int64_feature(X.shape[1]),
def main(args):
    """Convert a pose-labelled image list into a shuffled TFRecord file.

    Reads the label file ``FLAGS.root_dir + '/' + FLAGS.dataset``: three
    header lines (skipped), then one whitespace-separated record per line of
    the form ``fname p0 p1 p2 p3 p4 p5 p6``.  ``p0..p2`` are used as the
    position (``pose_x``) and ``p3..p6`` as the quaternion (``pose_q``).

    For every record, in random order, the image
    ``FLAGS.root_dir + '/' + fname`` is loaded, subsampled by 4 along both
    spatial axes, optionally histogram-equalised (``FLAGS.hist_norm``) and
    written to ``FLAGS.out_file`` as a ``tf.train.Example`` with the
    features ``height``, ``width``, ``channel``, ``image``, ``pose``
    (rotation vector followed by position, float32), ``pose_q`` and
    ``pose_x`` (both float32).

    Args:
        args: Unused; present for the app-runner calling convention.

    Raises:
        ValueError: If a non-blank label line does not have exactly 8 fields
            or a pose value is not a valid float.
    """
    poses = []
    images = []

    # Processing image labels
    logger.info('Processing Image Lables')
    with open(FLAGS.root_dir + '/' + FLAGS.dataset) as f:
        # Skip the 3 header lines.  The default argument keeps a file that is
        # shorter than the header from raising StopIteration.
        for _ in range(3):
            next(f, None)
        for line in f:
            fields = line.split()
            if not fields:
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                continue
            if len(fields) != 8:
                raise ValueError(
                    'Expected 8 fields (fname + 7 pose values), got {}: {!r}'
                    .format(len(fields), line))
            poses.append(tuple(float(value) for value in fields[1:]))
            images.append(FLAGS.root_dir + '/' + fields[0])

    # A single shuffle already yields a uniformly random permutation.
    r = list(range(len(images)))
    random.shuffle(r)

    # Writing TFRecords
    logger.info('Writing TFRecords')
    SO3_GROUP = SpecialOrthogonalGroup(3)
    writer = tf.python_io.TFRecordWriter(FLAGS.out_file)
    try:
        for i in tqdm(r):
            pose_q = np.array(poses[i][3:7])
            pose_x = np.array(poses[i][0:3])

            # The conversion returns a batch; [0] takes the single result.
            rot_vec = SO3_GROUP.rotation_vector_from_quaternion(pose_q)[0]
            pose = np.concatenate((rot_vec, pose_x), axis=0)

            logger.info('Processing Image: %s', images[i])
            X = imageio.imread(images[i])
            # Keep every 4th pixel along both spatial axes.  The indexing
            # requires a 3-D (height, width, channel) array.
            X = X[::4, ::4, :]
            if FLAGS.hist_norm:
                # NOTE(review): the equalised array is serialised with
                # whatever dtype equalize_hist returns, which may differ from
                # the raw image dtype -- readers must decode accordingly.
                X = exposure.equalize_hist(X)

            # tobytes() is the supported spelling of the deprecated
            # tostring(); the produced bytes are identical.
            img_raw = X.tobytes()
            pose_raw = pose.astype('float32').tobytes()
            pose_q_raw = pose_q.astype('float32').tobytes()
            pose_x_raw = pose_x.astype('float32').tobytes()

            example = tf.train.Example(features=tf.train.Features(feature={
                'height': _int64_feature(X.shape[0]),
                'width': _int64_feature(X.shape[1]),
                'channel': _int64_feature(X.shape[2]),
                'image': _bytes_feature(img_raw),
                'pose': _bytes_feature(pose_raw),
                'pose_q': _bytes_feature(pose_q_raw),
                'pose_x': _bytes_feature(pose_x_raw)}))
            writer.write(example.SerializeToString())
    finally:
        # Always flush and release the output file, even if an image fails.
        writer.close()

    # The message must be the first argument: extra positional arguments are
    # treated as %-format parameters, so the former ('\n', 'text') call made
    # logging report a formatting error instead of emitting the text.
    logger.info('Creating Dataset Success.')
def main(argv):
    """Evaluate a trained pose-regression network on one test subject.

    Reads ``FLAGS.data_dir/test/<FLAGS.subject_id>.tfrecord`` one record at a
    time and loads the volume ``FLAGS.scan_dir/test/<FLAGS.subject_id>.nii.gz``.
    Depending on ``FLAGS.loss`` an Inception-v3 network with a 7-output
    ('PoseNet': quaternion + translation), 9-output ('AP': three anchor
    points) or 6-output ('SE3') head is built and the latest checkpoint in
    ``FLAGS.model_dir`` is restored.

    For each of ``FLAGS.n_iter`` records the network is run
    ``FLAGS.n_samples`` times on the same resized image, the predictions are
    averaged, the volume is resampled with the predicted rotation and
    translation via ``resample_sitk``, and both the input image and the
    resampled image are written to ``imgdump/image_<i>_{true,pred}.png``.
    ``calc_psnr``, ``calc_mse``, ``calc_ssim`` and ``calc_correlation`` are
    then called on the pair; their return values are discarded here.

    Args:
        argv: Unused.

    Raises:
        SystemExit: If ``FLAGS.loss`` is not 'PoseNet', 'AP' or 'SE3'.
    """
    # TF Record
    datafiles = FLAGS.data_dir + '/test/' + FLAGS.subject_id + '.tfrecord'
    dataset = tf.data.TFRecordDataset(datafiles)
    dataset = dataset.map(_parse_function_ifind)
    # dataset = dataset.repeat()
    # dataset = dataset.shuffle(FLAGS.queue_buffer)
    # NOTE(review): without repeat(), asking for more than the number of
    # records in the file (FLAGS.n_iter) presumably ends in an
    # out-of-range error from the iterator -- confirm.
    dataset = dataset.batch(1)
    image, vec, qt, AP1, AP2, AP3 = dataset.make_one_shot_iterator().get_next()

    # Nifti Volume
    subject_path = FLAGS.scan_dir + '/test/' + FLAGS.subject_id + '.nii.gz'
    fixed_image_sitk_tmp = sitk.ReadImage(subject_path, sitk.sitkFloat32)
    # Round-trip through a plain array; presumably done to drop the file's
    # spatial metadata (origin/spacing/direction) -- TODO confirm intent.
    fixed_image_sitk = sitk.GetImageFromArray(
        sitk.GetArrayFromImage(fixed_image_sitk_tmp))
    fixed_image_sitk = sitk.RescaleIntensity(fixed_image_sitk, 0, 1)  # * 255.

    # Network Definition
    # The network reads from a placeholder; the dataset image is resized in
    # the graph, fetched as a numpy array and then fed back through feed_dict.
    image_input = tf.placeholder(shape=[1, 224, 224, 1], dtype=tf.float32)
    image_resized = tf.image.resize_images(image, size=[224, 224])

    if FLAGS.loss == 'PoseNet':

        # 7 outputs, split into a 4-vector quaternion and a 3-vector translation.
        y_pred, _ = inception.inception_v3(image_input, num_classes=7)
        quaternion_pred, translation_pred = tf.split(y_pred, [4, 3], axis=1)

        # NOTE(review): the session is never closed.
        sess = tf.Session()

        ckpt_file = tf.train.latest_checkpoint(FLAGS.model_dir)
        tf.train.Saver().restore(sess, ckpt_file)
        print('restoring parameters from', ckpt_file)

        SO3_GROUP = SpecialOrthogonalGroup(3)

        for i in range(FLAGS.n_iter):

            # Fetch image and labels in a single run() so they come from the
            # same record.  The ground-truth quaternion and translation (the
            # latter taken from AP2) are fetched but not used below.
            _image, _image_resized, _quaternion_true, _translation_true = \
                sess.run([image, image_resized, qt, AP2], )

            _quaternion_pred_sample = []
            _translation_pred_sample = []

            # Repeated forward passes on the same input; the samples only
            # differ if the network is stochastic at inference time
            # (e.g. dropout left active) -- TODO confirm.
            for j in range(FLAGS.n_samples):
                _quaternion_pred_i, _translation_pred_i = \
                    sess.run([quaternion_pred, translation_pred],
                             feed_dict={image_input: _image_resized})
                _quaternion_pred_sample.append(_quaternion_pred_i)
                _translation_pred_sample.append(_translation_pred_i)
                # NOTE(review): assumed to be a per-sample debug print inside
                # the sampling loop -- confirm placement.
                print(_quaternion_pred_i, _translation_pred_i)

            # Average the rotation samples on SO(3): quaternions -> rotation
            # vectors -> mean under the left canonical metric -> quaternion.
            _quaternion_pred_sample = np.vstack(_quaternion_pred_sample)
            _rotvec_pred_sample = SO3_GROUP.rotation_vector_from_quaternion(
                _quaternion_pred_sample)
            _rotvec_pred = SO3_GROUP.left_canonical_metric.mean(
                _rotvec_pred_sample)

            _quaternion_pred = SO3_GROUP.quaternion_from_rotation_vector(
                _rotvec_pred)
            # Translation samples are averaged element-wise over the samples.
            _translation_pred = np.mean(np.vstack(_translation_pred_sample), axis=0)

            # _quaternion_pred_variance = SO3_GROUP.left_canonical_metric.variance(_rotvec_pred_sample)
            # Computed but not used further in this function.
            _translation_pred_variance = np.var(
                np.vstack(_translation_pred_sample), axis=0)

            rx = SO3_GROUP.matrix_from_quaternion(_quaternion_pred)[0]
            # NOTE(review): _translation_pred is already reduced over axis 0,
            # so [0] looks like it selects a single element rather than the
            # whole translation vector (the SE3 branch passes _y_pred[0, 3:])
            # -- confirm what resample_sitk expects.  The factor 60. is an
            # undocumented scale; presumably undoes a training-time
            # normalisation -- TODO confirm.
            tx = _translation_pred[0] * 60.

            image_true = np.squeeze(_image)
            image_pred = resample_sitk(fixed_image_sitk, rx, tx)

            imageio.imsave('imgdump/image_{}_true.png'.format(i), _image[0, ...])
            imageio.imsave('imgdump/image_{}_pred.png'.format(i), image_pred)

            # Return values are discarded; presumably the helpers report the
            # metrics themselves -- verify.
            calc_psnr(image_pred, image_true)
            calc_mse(image_pred, image_true)
            calc_ssim(image_pred, image_true)
            calc_correlation(image_pred, image_true)

    elif FLAGS.loss == 'AP':

        # 9 outputs, split evenly into three anchor-point predictions.
        y_pred, _ = inception.inception_v3(image_input, num_classes=9)
        AP1_pred, AP2_pred, AP3_pred = tf.split(y_pred, 3, axis=1)

        # NOTE(review): the session is never closed.
        sess = tf.Session()

        ckpt_file = tf.train.latest_checkpoint(FLAGS.model_dir)
        tf.train.Saver().restore(sess, ckpt_file)
        print('restoring parameters from', ckpt_file)

        for i in range(FLAGS.n_iter):

            _image, _image_resized, _AP1, _AP2, _AP3 = \
                sess.run([image, image_resized, AP1, AP2, AP3])

            _AP1_sample = []
            _AP2_sample = []
            _AP3_sample = []

            # Repeated forward passes on the same input (see the note in the
            # PoseNet branch about stochastic inference).
            for j in range(FLAGS.n_samples):
                _AP1_pred_i, _AP2_pred_i, _AP3_pred_i = \
                    sess.run([AP1_pred, AP2_pred, AP3_pred],
                             feed_dict={image_input: _image_resized})
                _AP1_sample.append(_AP1_pred_i)
                _AP2_sample.append(_AP2_pred_i)
                _AP3_sample.append(_AP3_pred_i)

            # Element-wise mean of each anchor point over the samples.
            _AP1_pred = np.mean(np.vstack(_AP1_sample), axis=0)
            _AP2_pred = np.mean(np.vstack(_AP2_sample), axis=0)
            _AP3_pred = np.mean(np.vstack(_AP3_sample), axis=0)

            # Variances and distances below are computed but not used further
            # in this function.
            _AP1_pred_variance = np.var(np.vstack(_AP1_sample), axis=0)
            _AP2_pred_variance = np.var(np.vstack(_AP2_sample), axis=0)
            _AP3_pred_variance = np.var(np.vstack(_AP3_sample), axis=0)

            dist_ap1 = np.linalg.norm(_AP1 - _AP1_pred)
            dist_ap2 = np.linalg.norm(_AP2 - _AP2_pred)
            dist_ap3 = np.linalg.norm(_AP3 - _AP3_pred)

            # NOTE(review): the means are already reduced over axis 0, so the
            # [0] indexing looks like it passes single elements rather than
            # whole anchor points -- confirm against matrix_from_anchor_points
            # and resample_sitk.
            rx = matrix_from_anchor_points(_AP1_pred[0], _AP2_pred[0], _AP3_pred[0])
            # The second anchor point, scaled by 60., is used as the translation.
            tx = _AP2_pred[0] * 60.

            image_true = np.squeeze(_image)
            image_pred = resample_sitk(fixed_image_sitk, rx, tx)

            imageio.imsave('imgdump/image_{}_true.png'.format(i), _image[0, ...])
            imageio.imsave('imgdump/image_{}_pred.png'.format(i), image_pred)

            # Return values are discarded; presumably the helpers report the
            # metrics themselves -- verify.
            calc_psnr(image_pred, image_true)
            calc_mse(image_pred, image_true)
            calc_ssim(image_pred, image_true)
            calc_correlation(image_pred, image_true)

    elif FLAGS.loss == 'SE3':

        # 6 outputs; the first three are used as a rotation vector and the
        # last three as the translation (see the slicing below).
        y_pred, _ = inception.inception_v3(image_input, num_classes=6)

        # NOTE(review): the session is never closed.
        sess = tf.Session()

        ckpt_file = tf.train.latest_checkpoint(FLAGS.model_dir)
        tf.train.Saver().restore(sess, ckpt_file)
        print('restoring parameters from', ckpt_file)

        SO3_GROUP = SpecialOrthogonalGroup(3)
        SE3_GROUP = SpecialEuclideanGroup(3)

        for i in range(FLAGS.n_iter):

            print(i)

            # The ground-truth rotation vector and translation (_rvec, _tvec)
            # are fetched but not used below.
            _image, _image_resized, _rvec, _tvec = \
                sess.run([image, image_resized, vec, AP2])

            _y_pred_sample = []

            # Repeated forward passes on the same input (see the note in the
            # PoseNet branch about stochastic inference).
            for j in range(FLAGS.n_samples):
                _y_pred_i = sess.run([y_pred], feed_dict={image_input: _image_resized})
                # run() was given a one-element list, so unwrap the result.
                _y_pred_sample.append(_y_pred_i[0])

            _y_pred_sample = np.vstack(_y_pred_sample)

            # Mean of the samples under the SE(3) left canonical metric; the
            # variance is computed but not used further in this function.
            _y_pred = SE3_GROUP.left_canonical_metric.mean(_y_pred_sample)
            _y_pred_variance = SE3_GROUP.left_canonical_metric.variance(
                _y_pred_sample)

            rx = SO3_GROUP.matrix_from_rotation_vector(_y_pred[0, :3])[0]
            tx = _y_pred[0, 3:] * 60.

            image_true = np.squeeze(_image)
            image_pred = resample_sitk(fixed_image_sitk, rx, tx)

            imageio.imsave('imgdump/image_{}_true.png'.format(i), _image[0, ...])
            imageio.imsave('imgdump/image_{}_pred.png'.format(i), image_pred)

            # Return values are discarded; presumably the helpers report the
            # metrics themselves -- verify.
            calc_psnr(image_pred, image_true)
            calc_mse(image_pred, image_true)
            calc_ssim(image_pred, image_true)
            calc_correlation(image_pred, image_true)

    else:
        print('Invalid Option:', FLAGS.loss)
        raise SystemExit