def get_everything(self, sess, features):
    """Build and evaluate the zero-filled input and truth images for a batch.

    Args:
        sess: TensorFlow session used to evaluate the tensors.
        features: dict of tensors; reads "ks_input", "sensemap",
            "ks_truth" and "mask_recon".

    Returns:
        Tuple ``(im_lowres, im_truth, sensemap, mask)`` of evaluated values:
        the adjoint (``tf_util.model_transpose``) of the masked input k-space
        and of ``ks_truth * mask_recon``, both converted with
        ``tf_util.complex_to_channels``; the sensitivity maps; and the mask
        derived from ``ks_input`` by ``tf_util.kspace_mask``.
    """
    ks_input = features["ks_input"]
    sensemap = features["sensemap"]

    # kspace_mask is always called here, so a second identical call when it
    # returns None could never produce anything different.
    mask = tf_util.kspace_mask(ks_input, dtype=tf.complex64)
    ks_input = mask * ks_input

    im_lowres = tf_util.model_transpose(ks_input, sensemap)
    im_lowres = tf.identity(im_lowres, name="low_res_image")
    # complex to channels
    im_lowres = tf_util.complex_to_channels(im_lowres)

    ks_truth = features["ks_truth"]
    mask_recon = features["mask_recon"]
    im_truth = tf_util.model_transpose(ks_truth * mask_recon, sensemap)
    im_truth = tf.identity(im_truth, name="truth_image")
    im_truth = tf_util.complex_to_channels(im_truth)

    im_lowres, im_truth, sensemap, mask = sess.run(
        [im_lowres, im_truth, sensemap, mask])
    return im_lowres, im_truth, sensemap, mask
def get_undersampled(self, features):
    """Return the adjoint image of the raw input k-space.

    Args:
        features: dict of tensors; reads "ks_input" and "sensemap".

    Returns:
        ``tf_util.model_transpose(ks_input, sensemap)`` (identity-named
        "low_res_image") converted with ``tf_util.complex_to_channels``.
    """
    adjoint = tf_util.model_transpose(features["ks_input"], features["sensemap"])
    named = tf.identity(adjoint, name="low_res_image")
    return tf_util.complex_to_channels(named)
def _create_summary(sense_place, ks_place, im_out_place, im_truth_place):
    """Register TensorBoard image summaries for input, output and truth.

    Args:
        sense_place: sensitivity maps (sliced below as a 5-D tensor).
        ks_place: measured k-space.
        im_out_place: network output image.
        im_truth_place: ground-truth image.
    """

    def side_by_side(*tensors):
        # Sum-of-squares of each tensor, concatenated along axis 2.
        return tf.concat(
            [tf_util.sumofsq(t, keep_dims=True) for t in tensors], axis=2)

    im_in = tf_util.model_transpose(ks_place, sense_place)
    ks_mask = tf_util.kspace_mask(ks_place, dtype=tf.complex64)
    ks_out = tf_util.model_forward(im_out_place, sense_place)
    ks_gt = tf_util.model_forward(im_truth_place, sense_place)

    with tf.name_scope("input-output-truth"):
        # Log scale; the small offset avoids log(0).
        ks_panel = tf.log(side_by_side(ks_place, ks_out, ks_gt) + 1e-6)
        tf.summary.image("kspace", ks_panel,
                         max_outputs=FLAGS.num_summary_image)
        im_panel = side_by_side(im_in, im_out_place, im_truth_place)
        tf.summary.image("image", im_panel,
                         max_outputs=FLAGS.num_summary_image)

    with tf.name_scope("truth"):
        real_part = tf.real(
            tf.reduce_sum(im_truth_place, axis=-1, keep_dims=True))
        tf.summary.image("image_real", real_part,
                         max_outputs=FLAGS.num_summary_image)

    with tf.name_scope("mask"):
        tf.summary.image("mask", tf_util.sumofsq(ks_mask, keep_dims=True),
                         max_outputs=FLAGS.num_summary_image)

    with tf.name_scope("sensemap"):
        # Take only the first entry of axis 3, then tile the last axis
        # horizontally into one image per batch element.
        maps = tf.slice(tf.abs(sense_place), [0, 0, 0, 0, 0],
                        [-1, -1, -1, 1, -1])
        maps = tf.transpose(maps, [0, 1, 4, 2, 3])
        maps = tf.reshape(maps, [tf.shape(maps)[0], tf.shape(maps)[1], -1])
        tf.summary.image("image", tf.expand_dims(maps, axis=-1),
                         max_outputs=FLAGS.num_summary_image)
def get_truth_image(self, features):
    """Return the ground-truth image built from the masked truth k-space.

    Args:
        features: dict of tensors; reads "ks_truth", "sensemap" and
            "mask_recon".

    Returns:
        ``model_transpose(ks_truth * mask_recon, sensemap)`` (identity-named
        "truth_image") converted with ``tf_util.complex_to_channels``.
    """
    truth_ks, maps, recon_mask = (
        features[key] for key in ("ks_truth", "sensemap", "mask_recon"))
    truth = tf.identity(
        tf_util.model_transpose(truth_ks * recon_mask, maps),
        name="truth_image")
    # complex to channels
    return tf_util.complex_to_channels(truth)
def get_images(self, features):
    """Build the zero-filled input image and the truth image tensors.

    Args:
        features: dict of tensors; reads "ks_input", "sensemap",
            "ks_truth" and "mask_recon".

    Returns:
        Tuple ``(im_lowres, im_truth)``: the adjoint of the masked input
        k-space and of ``ks_truth * mask_recon``, both converted with
        ``tf_util.complex_to_channels``.
    """
    ks_input = features["ks_input"]
    sensemap = features["sensemap"]

    # kspace_mask is always called, so retrying the same call on None
    # would be redundant.
    mask = tf_util.kspace_mask(ks_input, dtype=tf.complex64)
    ks_input = mask * ks_input

    im_lowres = tf_util.model_transpose(ks_input, sensemap)
    im_lowres = tf.identity(im_lowres, name="low_res_image")
    # complex to channels
    im_lowres = tf_util.complex_to_channels(im_lowres)

    ks_truth = features["ks_truth"]
    mask_recon = features["mask_recon"]
    im_truth = tf_util.model_transpose(ks_truth * mask_recon, sensemap)
    im_truth = tf.identity(im_truth, name="truth_image")
    im_truth = tf_util.complex_to_channels(im_truth)

    return im_lowres, im_truth
def create_summary(self):
    """Register TensorBoard summaries for the GAN.

    Sets ``self.input_image`` (adjoint of ``self.ks``), scalar summaries for
    the discriminator/generator losses, gradient penalty and mean logits, a
    histogram of the input image, and the merged ops ``self.d_sum``,
    ``self.g_sum`` and ``self.train_sum``.  When ``self.data_type`` equals
    "knee" it also logs input/output/truth magnitude panels, L1/L2 errors of
    ``self.X_gen`` against ``self.z_truth`` and the real-image magnitude.
    """
    # Input image to generator
    self.input_image = tf_util.model_transpose(self.ks, self.sensemap)

    # Compare with "==": "is" checks object identity, which is not
    # guaranteed for equal strings (e.g. values parsed from flags).
    if self.data_type == "knee":

        def _display(img):
            # Magnitude, rotated and flipped for display.
            return tf.image.flip_up_down(tf.image.rot90(tf.abs(img)))

        truth_image = tf_util.channels_to_complex(self.z_truth)
        train_out = tf.concat(
            (_display(self.input_image), _display(self.output_image),
             _display(truth_image)),
            axis=2)
        tf.summary.image("input-output-truth", train_out)

        loss_l1 = tf.reduce_mean(tf.abs(self.X_gen - self.z_truth))
        loss_l2 = tf.reduce_mean(
            tf.square(tf.abs(self.X_gen - self.z_truth)))
        tf.summary.scalar("l1", loss_l1)
        tf.summary.scalar("l2", loss_l2)

        # to check supervised/unsupervised
        y_real = tf_util.channels_to_complex(self.Y_real)
        tf.summary.image("y_real/mag", _display(y_real))

    # Plot losses
    self.d_loss_sum = tf.summary.scalar("Discriminator_loss", self.d_loss)
    self.g_loss_sum = tf.summary.scalar("Generator_loss", self.g_loss)
    self.gp_sum = tf.summary.scalar("Gradient_penalty",
                                    self.gradient_penalty)
    self.d_fake = tf.summary.scalar("subloss/D_fake",
                                    tf.reduce_mean(self.d_logits_fake))
    self.d_real = tf.summary.scalar("subloss/D_real",
                                    tf.reduce_mean(self.d_logits_real))
    self.z_sum = tf.summary.histogram(
        "z", tf_util.complex_to_channels(self.input_image))

    self.d_sum = tf.summary.merge(
        [self.z_sum, self.d_loss_sum, self.d_fake, self.d_real])
    self.g_sum = tf.summary.merge([self.z_sum, self.g_loss_sum])
    self.train_sum = tf.summary.merge_all()
def unroll_fista(
    ks_input,
    sensemap,
    num_grad_steps=5,
    resblock_num_features=128,
    resblock_num_blocks=2,
    is_training=True,
    scope="MRI",
    mask_output=1,
    window=None,
    do_hardproj=True,
    num_summary_image=0,
    mask=None,
    verbose=False,
    conv="real",
    do_conjugate=False,
    activation="relu",
):
    """Create general unrolled network for MRI.

    Each of ``num_grad_steps`` iterations takes a gradient step on the
    data-consistency term and then applies a learned proximal network S
    (``prior_grad_res_net``):

        x_{k+1} = S( x_k - 2 * t * A^T W (A x_k - b) )
                = S( x_k - 2 * t * (A^T W A x_k - A^T W b) )

    A / A^T are ``tf_util.model_forward`` / ``tf_util.model_transpose`` with
    ``sensemap`` and W is ``window``.  The step is a trainable per-iteration
    variable "t", initialized to -2.0 and applied as
    ``x_k + t * (A^T W A x_k - x_0)``.

    Args:
        ks_input: measured k-space; multiplied by ``mask`` before use.
        sensemap: sensitivity maps for the forward/adjoint operators.
        num_grad_steps: number of unrolled iterations.
        resblock_num_features: ``num_features`` for prior_grad_res_net.
        resblock_num_blocks: ``num_blocks`` for prior_grad_res_net.
        is_training: forwarded as ``training`` to prior_grad_res_net.
        scope: variable scope wrapping the whole network.
        mask_output: multiplier applied to the output k-space during the
            final hard projection; skipped when None.
        window: k-space weighting W; None means 1.
        do_hardproj: if True, sampled k-space locations of the output are
            replaced by the measured data at the end.
        num_summary_image: if > 0, per-iteration sum-of-squares images are
            logged (and used as ``max_outputs``).
        mask: sampling mask; derived from ``ks_input`` with
            ``tf_util.kspace_mask`` when None.
        verbose: print construction progress.
        conv: stored in the module-global ``type_conv`` (presumably read by
            the convolution helpers elsewhere in the module -- TODO confirm).
        do_conjugate: stored in the module-global ``conjugate``.
        activation: forwarded to prior_grad_res_net.

    Returns:
        The complex output image tensor, named "output_image".
    """
    if window is None:
        window = 1
    summary_iter = None

    # NOTE(review): these module-level globals are side effects of building
    # the graph; they are presumably consumed by helpers defined elsewhere.
    global type_conv
    type_conv = conv
    global conjugate
    conjugate = do_conjugate

    if verbose:
        print(
            "%s> Building FISTA unrolled network (%d steps)...."
            % (scope, num_grad_steps)
        )
        if sensemap is not None:
            print("%s> Using sensitivity maps..." % scope)

    with tf.variable_scope(scope):
        if mask is None:
            mask = tf_util.kspace_mask(ks_input, dtype=tf.complex64)
        ks_input = mask * ks_input
        ks_0 = ks_input
        # x0 = A^T W b
        im_0 = tf_util.model_transpose(ks_0 * window, sensemap)
        im_0 = tf.identity(im_0, name="input_image")
        # To be updated
        ks_k = ks_0
        im_k = im_0

        for i_step in range(num_grad_steps):
            # Scope names (and hence checkpoint variable names) depend on
            # this format string.
            iter_name = "iter_%02d" % i_step
            with tf.variable_scope(iter_name):
                # = S( x_k - 2 * t * (A^T W A x_k - A^T W b))
                # = S( x_k - 2 * t * (A^T W A x_k - x0))
                with tf.variable_scope("update"):
                    im_k_orig = im_k
                    # xk = A^T A x_k
                    ks_k = tf_util.model_forward(im_k, sensemap)
                    ks_k = mask * ks_k
                    im_k = tf_util.model_transpose(ks_k * window, sensemap)
                    # xk = A^T A x_k - A^T b
                    im_k = tf_util.complex_to_channels(im_k - im_0)
                    im_k_orig = tf_util.complex_to_channels(im_k_orig)
                    # Update step
                    t_update = tf.get_variable(
                        "t", dtype=tf.float32, initializer=tf.constant([-2.0])
                    )
                    im_k = im_k_orig + t_update * im_k

                with tf.variable_scope("prox"):
                    num_channels_out = im_k.shape[-1]
                    im_k = prior_grad_res_net(
                        im_k,
                        training=is_training,
                        num_features=resblock_num_features,
                        num_blocks=resblock_num_blocks,
                        num_features_out=num_channels_out,
                        data_format="channels_last",
                        activation=activation
                    )
                    im_k = tf_util.channels_to_complex(im_k)

                im_k = tf.identity(im_k, name="image")
                if num_summary_image > 0:
                    with tf.name_scope("summary"):
                        # Collect one sum-of-squares image per iteration,
                        # concatenated along axis 2.
                        tmp = tf_util.sumofsq(im_k, keep_dims=True)
                        if summary_iter is None:
                            summary_iter = tmp
                        else:
                            summary_iter = tf.concat(
                                (summary_iter, tmp), axis=2)
                        tf.summary.scalar("max/" + iter_name, tf.reduce_max(tmp))

        ks_k = tf_util.model_forward(im_k, sensemap)
        if do_hardproj:
            if verbose:
                print("%s> Final hard data projection..." % scope)
            # Final data projection: keep measured samples where mask is
            # set, network prediction elsewhere.
            ks_k = mask * ks_0 + (1 - mask) * ks_k
            if mask_output is not None:
                ks_k = ks_k * mask_output
            im_k = tf_util.model_transpose(ks_k * window, sensemap)

        ks_k = tf.identity(ks_k, name="output_kspace")
        im_k = tf.identity(im_k, name="output_image")

    if summary_iter is not None:
        tf.summary.image("iter/image", summary_iter, max_outputs=num_summary_image)

    return im_k
def main(_):
    """Evaluate the unrolled network and a BART CS baseline on the test set.

    Prints running mean/std of PSNR, NRMSE and SSIM for the network output
    and for the CS reconstruction, and writes the final statistics to
    ``<log_root>/<train_dir>/metrics.txt``.
    """
    import sys

    # "!=" not "is not": identity of ints is an implementation detail.
    if FLAGS.batch_size != 1:
        print("Error: to test images, batch size must be 1")
        # Non-zero status so wrapping scripts can detect the failure.
        sys.exit(1)

    model_dir = os.path.join(FLAGS.log_root, FLAGS.train_dir)
    if not os.path.exists(FLAGS.log_root):
        os.makedirs(FLAGS.log_root)
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    bart_dir = os.path.join(model_dir, "bart_recon")
    if not os.path.exists(bart_dir):
        os.makedirs(bart_dir)

    # CUDA reads CUDA_VISIBLE_DEVICES when it is initialized, i.e. when the
    # first session is created; setting it afterwards has no effect.
    os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.device

    run_config = tf.ConfigProto()
    run_config.gpu_options.allow_growth = True

    with tf.Session(config=run_config) as sess:
        if not FLAGS.dataset_dir:
            raise ValueError("You must supply the dataset directory with " +
                             "--dataset_dir")

        if FLAGS.random_seed >= 0:
            random.seed(FLAGS.random_seed)
            np.random.seed(FLAGS.random_seed)

        tf.logging.set_verbosity(tf.logging.INFO)

        print("Preparing dataset...")
        out_shape = [FLAGS.shape_z, FLAGS.shape_y]
        test_dataset, num_files = mri_data.create_dataset(
            os.path.join(FLAGS.dataset_dir, "test"),
            FLAGS.mask_path,
            num_channels=FLAGS.num_channels,
            num_emaps=FLAGS.num_emaps,
            batch_size=FLAGS.batch_size,
            out_shape=out_shape,
        )

        # placeholders (channels last: batch, z, y, num_channels)
        ks_shape = [None, FLAGS.shape_z, FLAGS.shape_y, FLAGS.num_channels]
        ks_place = tf.placeholder(tf.complex64, ks_shape)
        sense_shape = [
            None, FLAGS.shape_z, FLAGS.shape_y, 1, FLAGS.num_channels
        ]
        sense_place = tf.placeholder(tf.complex64, sense_shape)
        im_shape = [None, FLAGS.shape_z, FLAGS.shape_y, 1]
        im_truth_place = tf.placeholder(tf.complex64, im_shape)

        # run through unrolled
        im_out_place = mri_model.unroll_fista(
            ks_place,
            sense_place,
            is_training=True,
            verbose=True,
            do_hardproj=FLAGS.do_hard_proj,
            num_summary_image=FLAGS.num_summary_image,
            resblock_num_features=FLAGS.feat_map,
            num_grad_steps=FLAGS.num_grad_steps,
            conv=FLAGS.conv,
            do_conjugate=FLAGS.do_conjugate,
        )

        saver = tf.train.Saver()
        summary_writer = tf.summary.FileWriter(model_dir, sess.graph)

        # initialize model
        print("[*] initializing network...")
        if not load(model_dir, saver, sess):
            sess.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess, coord)

        # See how many parameters are in model
        total_parameters = 0
        for variable in tf.trainable_variables():
            variable_parameters = 1
            for dim in variable.get_shape():
                variable_parameters *= dim.value
            total_parameters += variable_parameters
        print("Total number of trainable parameters: %d" % total_parameters)

        test_iterator = test_dataset.make_one_shot_iterator()
        features, labels = test_iterator.get_next()

        ks_truth = labels
        ks_in = features["ks_input"]
        sense_in = features["sensemap"]
        mask_recon = features["mask_recon"]
        im_truth = tf_util.model_transpose(ks_truth * mask_recon, sense_in)

        total_summary = tf.summary.merge_all()

        output_psnr = []
        output_nrmse = []
        output_ssim = []
        cs_psnr = []
        cs_nrmse = []
        cs_ssim = []

        for test_file in range(num_files):
            ks_in_run, sense_in_run, im_truth_run = sess.run(
                [ks_in, sense_in, im_truth])
            im_out, total_summary_run = sess.run(
                [im_out_place, total_summary],
                feed_dict={
                    ks_place: ks_in_run,
                    sense_place: sense_in_run,
                    im_truth_place: im_truth_run,
                },
            )

            # CS recon
            bart_test = bart_cs(bart_dir, ks_in_run, sense_in_run, l1=0.007)

            # handle batch dimension
            for b in range(FLAGS.batch_size):
                truth = im_truth_run[b, :, :, :]
                out = im_out[b, :, :, :]
                psnr, nrmse, ssim = metrics.compute_all(
                    truth, out, sos_axis=-1)
                output_psnr.append(psnr)
                output_nrmse.append(nrmse)
                output_ssim.append(ssim)

            print("output mean +/ standard deviation psnr, nrmse, ssim")
            print(
                np.mean(output_psnr),
                np.std(output_psnr),
                np.mean(output_nrmse),
                np.std(output_nrmse),
                np.mean(output_ssim),
                np.std(output_ssim),
            )

            psnr, nrmse, ssim = metrics.compute_all(
                im_truth_run, bart_test, sos_axis=-1)
            cs_psnr.append(psnr)
            cs_nrmse.append(nrmse)
            cs_ssim.append(ssim)
            print("cs mean +/ standard deviation psnr, nrmse, ssim")
            print(
                np.mean(cs_psnr),
                np.std(cs_psnr),
                np.mean(cs_nrmse),
                np.std(cs_nrmse),
                np.mean(cs_ssim),
                np.std(cs_ssim),
            )
        print("End of testing loop")

        # Separator written between mean and std (literal backslash kept so
        # the file format stays " +\- ").
        pm = " +\\- "
        rows = [
            ("output psnr", output_psnr),
            ("output nrmse", output_nrmse),
            ("output ssim", output_ssim),
            ("cs psnr", cs_psnr),
            ("cs nrmse", cs_nrmse),
            ("cs ssim", cs_ssim),
        ]
        lines = ["parameters = " + str(total_parameters)]
        lines.extend(
            label + " = " + str(np.mean(vals)) + pm + str(np.std(vals))
            for label, vals in rows)

        txt_path = os.path.join(model_dir, "metrics.txt")
        with open(txt_path, "w") as f:
            f.write("\n".join(lines))
def measure(self, X_gen, sensemap, real_mask):
    """Re-undersample a generated image with a random sampling mask.

    The generated image is mapped to k-space with ``tf_util.model_forward``
    and ``sensemap``. Each k-space slice is then:

    * flipped up/down and cropped/padded to ``(self.height, self.width)``;
    * rescaled using the energy of its central 5x5 patch;
    * multiplied by a mask drawn at random from ``self.masks``. The mask is
      randomly flipped and has a 20x20 calibration square forced to one.

    The result is mapped back with ``tf_util.model_transpose``.

    Slicing depends on ``self.data_type``:
      * "DCE": batch axis squeezed; a separate mask for each frame
        ``kspace[:, :, f, :]`` (``f < self.max_frames``). Frames are
        restacked on axis -2, and ``self.applied_mask`` holds the last
        frame's mask.
      * "DCE_2D": batch axis squeezed; one mask for the whole slice.
        ``self.applied_mask`` is set.
      * "knee": one mask per batch element ``kspace[i]``, further
        restricted to locations where that element's k-space is non-zero.
        Elements are concatenated on axis 0.

    Args:
        X_gen: generator output, converted with
            ``tf_util.channels_to_complex``.
        sensemap: sensitivity maps for the forward/adjoint operators.
        real_mask: unused; kept for interface compatibility.

    Returns:
        The re-measured image converted with
        ``tf_util.complex_to_channels``.

    Raises:
        ValueError: if ``self.data_type`` is not one of the values above.
    """
    name = "measure"
    random_seed = 0
    shape_cal = 20  # side of the calibration square forced into the mask
    shape_sc = 5  # side of the central k-space patch used for scaling

    image = tf_util.channels_to_complex(X_gen)
    kspace = tf_util.model_forward(image, sensemap)

    def _random_mask():
        # Draw one mask from self.masks at random and randomly flip it.
        mask_x = tf.random_shuffle(self.masks)
        mask_x = tf.slice(mask_x, [0, 0, 0], [1, -1, -1])
        mask_x = tf.image.random_flip_up_down(mask_x, seed=random_seed)
        mask_x = tf.image.random_flip_left_right(mask_x, seed=random_seed)
        # Transpose to store data as (kz, ky, channels)
        return tf.transpose(mask_x, [1, 2, 0])

    def _undersample(ks_x, restrict_to_data):
        # Returns (rescaled and masked k-space, mask that was applied).
        mask_x = _random_mask()
        ks_x = tf.image.flip_up_down(ks_x)
        # Initially set image size to be all the same
        ks_x = tf.image.resize_image_with_crop_or_pad(
            ks_x, self.height, self.width)
        mask_x = tf.image.resize_image_with_crop_or_pad(
            mask_x, self.height, self.width)

        with tf.name_scope("CalibRegion"):
            if self.verbose:
                print("%s> Including calib region (%d, %d)..."
                      % (name, shape_cal, shape_cal))
            mask_calib = tf.ones([shape_cal, shape_cal, 1],
                                 dtype=tf.complex64)
            mask_calib = tf.image.resize_image_with_crop_or_pad(
                mask_calib, self.height, self.width)
            mask_x = mask_x * (1 - mask_calib) + mask_calib

        if restrict_to_data:
            # Drop mask locations where this k-space slice has no data.
            mask_recon = tf.abs(ks_x) / tf.reduce_max(tf.abs(ks_x))
            mask_recon = tf.cast(mask_recon > 1e-7, dtype=tf.complex64)
            mask_x = mask_x * mask_recon

        # Assuming calibration region is fully sampled
        scale = tf.image.resize_image_with_crop_or_pad(
            ks_x, shape_sc, shape_sc)
        scale = tf.reduce_mean(tf.square(
            tf.abs(scale))) * (shape_sc * shape_sc / 1e5)
        scale = tf.cast(1.0 / tf.sqrt(scale), dtype=tf.complex64)
        ks_x = ks_x * scale
        return tf.multiply(ks_x, mask_x), mask_x

    # "==" rather than "is": string identity is not guaranteed for equal
    # values (e.g. a data_type parsed from the command line).
    if self.data_type == "DCE":
        # remove batch dimension
        kspace = tf.squeeze(kspace, axis=0)
        total_kspace = None
        # different mask for each frame
        for f in range(self.max_frames):
            ks_x, mask_x = _undersample(kspace[:, :, f, :],
                                        restrict_to_data=False)
            ks_x = tf.expand_dims(ks_x, axis=-2)
            self.applied_mask = tf.expand_dims(mask_x, axis=0)
            if total_kspace is None:
                total_kspace = ks_x
            else:
                total_kspace = tf.concat([total_kspace, ks_x], axis=-2)
        total_kspace = tf.expand_dims(total_kspace, axis=0)
    elif self.data_type == "DCE_2D":
        # remove batch dimension
        kspace = tf.squeeze(kspace, axis=0)
        ks_x, mask_x = _undersample(kspace, restrict_to_data=False)
        self.applied_mask = tf.expand_dims(mask_x, axis=0)
        total_kspace = tf.expand_dims(ks_x, axis=0)
    elif self.data_type == "knee":
        total_kspace = None
        for i in range(self.batch_size):
            ks_x, _ = _undersample(kspace[i, :, :], restrict_to_data=True)
            if total_kspace is None:
                total_kspace = ks_x
            else:
                total_kspace = tf.concat([total_kspace, ks_x], 0)
    else:
        raise ValueError(
            "measure: unsupported data_type %r" % (self.data_type,))

    x_measured = tf_util.model_transpose(total_kspace, sensemap,
                                         name="x_measured")
    x_measured = tf_util.complex_to_channels(x_measured)
    return x_measured
def main(_):
    """Evaluate the unrolled network and a BART CS baseline on test images.

    For every test file this saves magnitude/phase PNG panels and the
    magnitude/phase of the truth-minus-reconstruction differences (.npy)
    under ``<model_dir>/images``. Finally it writes the mean/std of PSNR,
    NRMSE and SSIM for the network output and for the CS reconstruction
    to ``<model_dir>/metrics.txt``.
    """
    model_dir = os.path.join(FLAGS.log_root, FLAGS.train_dir)
    if not os.path.exists(FLAGS.log_root):
        os.makedirs(FLAGS.log_root)
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    bart_dir = os.path.join(model_dir, "bart_recon")
    if not os.path.exists(bart_dir):
        os.makedirs(bart_dir)
    image_dir = os.path.join(model_dir, "images")
    if not os.path.exists(image_dir):
        os.makedirs(image_dir)

    # CUDA reads CUDA_VISIBLE_DEVICES when it is initialized, i.e. when the
    # first session is created; setting it afterwards has no effect.
    os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.device

    run_config = tf.ConfigProto()
    run_config.gpu_options.allow_growth = True

    with tf.Session(config=run_config) as sess:
        if not FLAGS.dataset_dir:
            raise ValueError("You must supply the dataset directory with " +
                             "--dataset_dir")

        if FLAGS.random_seed >= 0:
            random.seed(FLAGS.random_seed)
            np.random.seed(FLAGS.random_seed)

        tf.logging.set_verbosity(tf.logging.INFO)

        print("Preparing dataset...")
        out_shape = [FLAGS.shape_z, FLAGS.shape_y]
        test_dataset, num_files = mri_data.create_dataset(
            os.path.join(FLAGS.dataset_dir, "test_images"),
            FLAGS.mask_path,
            num_channels=FLAGS.num_channels,
            num_maps=FLAGS.num_emaps,
            batch_size=FLAGS.batch_size,
            out_shape=out_shape,
        )

        # placeholders (channels last: batch, z, y, num_channels)
        ks_shape = [None, FLAGS.shape_z, FLAGS.shape_y, FLAGS.num_channels]
        ks_place = tf.placeholder(tf.complex64, ks_shape)
        sense_shape = [
            None, FLAGS.shape_z, FLAGS.shape_y, 1, FLAGS.num_channels
        ]
        sense_place = tf.placeholder(tf.complex64, sense_shape)
        im_shape = [None, FLAGS.shape_z, FLAGS.shape_y, 1]
        im_truth_place = tf.placeholder(tf.complex64, im_shape)

        # run through unrolled
        im_out_place = mri_model.unroll_fista(
            ks_place,
            sense_place,
            is_training=True,
            verbose=True,
            do_hardproj=FLAGS.do_hard_proj,
            num_summary_image=FLAGS.num_summary_image,
            resblock_num_features=FLAGS.feat_map,
            num_grad_steps=FLAGS.num_grad_steps,
            conv=FLAGS.conv,
        )

        saver = tf.train.Saver()
        summary_writer = tf.summary.FileWriter(model_dir, sess.graph)

        # initialize model
        print("[*] initializing network...")
        if not load(model_dir, saver, sess):
            sess.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess, coord)

        # See how many parameters are in model
        total_parameters = 0
        for variable in tf.trainable_variables():
            variable_parameters = 1
            for dim in variable.get_shape():
                variable_parameters *= dim.value
            total_parameters += variable_parameters
        print("Total number of trainable parameters: %d" % total_parameters)
        tf.summary.scalar("parameters/parameters", total_parameters)

        test_iterator = test_dataset.make_one_shot_iterator()
        features, labels = test_iterator.get_next()

        ks_truth = labels
        ks_in = features["ks_input"]
        sense_in = features["sensemap"]
        mask_recon = features["mask_recon"]
        im_in = tf_util.model_transpose(ks_in * mask_recon, sense_in)
        im_truth = tf_util.model_transpose(ks_truth * mask_recon, sense_in)

        total_summary = tf.summary.merge_all()

        output_psnr = []
        output_nrmse = []
        output_ssim = []
        cs_psnr = []
        cs_nrmse = []
        cs_ssim = []

        for test_file in range(num_files):
            ks_in_run, sense_in_run, im_truth_run, im_in_run = sess.run(
                [ks_in, sense_in, im_truth, im_in])
            im_out, total_summary_run = sess.run(
                [im_out_place, total_summary],
                feed_dict={
                    ks_place: ks_in_run,
                    sense_place: sense_in_run,
                    im_truth_place: im_truth_run,
                },
            )

            # CS recon
            bart_test = bart_cs(bart_dir, ks_in_run, sense_in_run, l1=FLAGS.l1)

            # Panels placed side by side (axis 2) and truth-minus-recon
            # differences; which images are shown depends on FLAGS.conv.
            if FLAGS.conv == "complex":
                panels = (im_out, bart_test, im_truth_run)
                diffs = (im_truth_run - im_out, im_truth_run - bart_test)
            elif FLAGS.conv == "real":
                panels = (im_in_run, im_out)
                diffs = (im_truth_run - im_in_run, im_truth_run - im_out)
            else:
                raise ValueError(
                    "Unsupported --conv %r (expected 'complex' or 'real')"
                    % (FLAGS.conv,))

            stacked = np.concatenate(panels, axis=2)
            mag_images = np.squeeze(np.absolute(stacked))
            phase_images = np.squeeze(np.angle(stacked))
            stacked_diff = np.concatenate(diffs, axis=2)
            diff_mag = np.squeeze(np.absolute(stacked_diff))
            diff_phase = np.squeeze(np.angle(stacked_diff))

            # save magnitude/phase panels as .png, differences as .npy
            scipy.misc.imsave(
                os.path.join(image_dir, "mag_%d.png" % test_file), mag_images)
            scipy.misc.imsave(
                os.path.join(image_dir, "phase_%d.png" % test_file),
                phase_images)
            np.save(os.path.join(image_dir, "diff_phase_%d.npy" % test_file),
                    diff_phase)
            np.save(os.path.join(image_dir, "diff_mag_%d.npy" % test_file),
                    diff_mag)

            psnr, nrmse, ssim = metrics.compute_all(
                im_truth_run, im_out, sos_axis=-1)
            output_psnr.append(psnr)
            output_nrmse.append(nrmse)
            output_ssim.append(ssim)
            print("output psnr, nrmse, ssim")
            print(
                np.mean(output_psnr),
                np.std(output_psnr),
                np.mean(output_nrmse),
                np.std(output_nrmse),
                np.mean(output_ssim),
                np.std(output_ssim),
            )

            psnr, nrmse, ssim = metrics.compute_all(
                im_truth_run, bart_test, sos_axis=-1)
            cs_psnr.append(psnr)
            cs_nrmse.append(nrmse)
            cs_ssim.append(ssim)
            print("cs psnr, nrmse, ssim")
            print(
                np.mean(cs_psnr),
                np.std(cs_psnr),
                np.mean(cs_nrmse),
                np.std(cs_nrmse),
                np.mean(cs_ssim),
                np.std(cs_ssim),
            )
        print("End of testing loop")

        # Separator written between mean and std (literal backslash kept so
        # the file format stays " +\- ").
        pm = " +\\- "
        rows = [
            ("output psnr", output_psnr),
            ("output nrmse", output_nrmse),
            ("output ssim", output_ssim),
            ("cs psnr", cs_psnr),
            ("cs nrmse", cs_nrmse),
            ("cs ssim", cs_ssim),
        ]
        lines = ["parameters = " + str(total_parameters)]
        lines.extend(
            label + " = " + str(np.mean(vals)) + pm + str(np.std(vals))
            for label, vals in rows)

        txt_path = os.path.join(model_dir, "metrics.txt")
        with open(txt_path, "w") as f:
            f.write("\n".join(lines))
def measure(self, X_gen, sensemap, real_mask):
    """Re-undersample a generated image and map it back to image space.

    The generated image is mapped to k-space with ``tf_util.model_forward``
    and ``sensemap``. Each batch element is masked, and the masked k-space
    is mapped back with ``tf_util.model_transpose``.

    * "DCE", "DCE_2D", "mfast": each element is multiplied by a mask
      produced by ``mask.generate_perturbed2dvdkt`` (through
      ``tf.py_func``). The acceleration is drawn uniformly from [1, 6).
    * "knee": each element is flipped up/down and cropped/padded to
      ``(self.height, self.width)``. It is then rescaled using its central
      5x5 patch and multiplied by a randomly chosen, randomly flipped mask
      from ``self.masks``. That mask has a 20x20 calibration square forced
      to one and is restricted to locations where the k-space is non-zero.

    Masked elements are concatenated on axis 0.

    Args:
        X_gen: generator output, converted with
            ``tf_util.channels_to_complex``.
        sensemap: sensitivity maps for the forward/adjoint operators.
        real_mask: unused; kept for interface compatibility.

    Returns:
        The re-measured image converted with
        ``tf_util.complex_to_channels``.

    Raises:
        ValueError: if ``self.data_type`` is not one of the values above.
    """
    name = "measure"
    random_seed = 0
    shape_cal = 20  # side of the calibration square forced into the mask
    shape_sc = 5  # side of the central k-space patch used for scaling

    def _debug(*args):
        # Graph-construction diagnostics, only shown in verbose mode.
        if self.verbose:
            print(*args)

    image = tf_util.channels_to_complex(X_gen)
    kspace = tf_util.model_forward(image, sensemap)

    total_kspace = None
    # "==" / "in" rather than "is": string identity is not guaranteed for
    # equal values (e.g. a data_type parsed from the command line).
    if self.data_type in ("DCE", "DCE_2D", "mfast"):
        _debug("DCE measure")
        # data dimensions (loop invariant)
        shape_y = self.width
        shape_t = self.max_frames
        sim_partial_ky = 0.0
        accs = [1, 6]
        for i in range(self.batch_size):
            if self.dims == 4:
                ks_x = kspace[i, :, :]
            else:
                # 2D plus time
                ks_x = kspace[i, :, :, :]
            # A fresh random acceleration for every batch element.
            rand_accel = (accs[1] - accs[0]) * \
                tf.random_uniform([]) + accs[0]
            # ny, nt, accel, ncal, vd_degree, then sim_partial_ky
            fn_inputs = [
                shape_y,
                shape_t,
                rand_accel,
                10,
                2.0,
                sim_partial_ky,
            ]
            mask_x = tf.py_func(mask.generate_perturbed2dvdkt, fn_inputs,
                                tf.complex64)
            ks_x = ks_x * tf.reshape(mask_x, [1, shape_y, shape_t, 1])
            if total_kspace is None:
                total_kspace = ks_x
            else:
                total_kspace = tf.concat([total_kspace, ks_x], 0)
    elif self.data_type == "knee":
        _debug("knee")
        for i in range(self.batch_size):
            ks_x = kspace[i, :, :]
            _debug("ks_x", ks_x)
            # Randomly select mask
            mask_x = tf.random_shuffle(self.masks)
            _debug("self masks", mask_x)
            mask_x = tf.slice(mask_x, [0, 0, 0], [1, -1, -1])
            _debug("sliced masks", mask_x)
            # Augment sampling masks
            mask_x = tf.image.random_flip_up_down(mask_x, seed=random_seed)
            mask_x = tf.image.random_flip_left_right(mask_x, seed=random_seed)
            # Transpose to store data as (kz, ky, channels)
            mask_x = tf.transpose(mask_x, [1, 2, 0])
            _debug("transposed mask", mask_x)

            ks_x = tf.image.flip_up_down(ks_x)
            # Initially set image size to be all the same
            ks_x = tf.image.resize_image_with_crop_or_pad(
                ks_x, self.height, self.width)
            mask_x = tf.image.resize_image_with_crop_or_pad(
                mask_x, self.height, self.width)
            _debug("resized mask", mask_x)

            with tf.name_scope("CalibRegion"):
                if self.verbose:
                    print("%s> Including calib region (%d, %d)..."
                          % (name, shape_cal, shape_cal))
                mask_calib = tf.ones([shape_cal, shape_cal, 1],
                                     dtype=tf.complex64)
                mask_calib = tf.image.resize_image_with_crop_or_pad(
                    mask_calib, self.height, self.width)
                mask_x = mask_x * (1 - mask_calib) + mask_calib

            # Drop mask locations where this k-space slice has no data.
            mask_recon = tf.abs(ks_x) / tf.reduce_max(tf.abs(ks_x))
            mask_recon = tf.cast(mask_recon > 1e-7, dtype=tf.complex64)
            mask_x = mask_x * mask_recon
            _debug("mask x", mask_x)

            # Assuming calibration region is fully sampled
            scale = tf.image.resize_image_with_crop_or_pad(
                ks_x, shape_sc, shape_sc)
            scale = tf.reduce_mean(tf.square(
                tf.abs(scale))) * (shape_sc * shape_sc / 1e5)
            scale = tf.cast(1.0 / tf.sqrt(scale), dtype=tf.complex64)
            ks_x = ks_x * scale

            # Masked input
            ks_x = tf.multiply(ks_x, mask_x)

            if total_kspace is None:
                total_kspace = ks_x
            else:
                total_kspace = tf.concat([total_kspace, ks_x], 0)
    else:
        raise ValueError(
            "measure: unsupported data_type %r" % (self.data_type,))

    x_measured = tf_util.model_transpose(total_kspace, sensemap,
                                         name="x_measured")
    x_measured = tf_util.complex_to_channels(x_measured)
    return x_measured
def generator(self, ks_input, sensemap, reuse=False):
    """Build the generator that maps input k-space to an output image.

    For 2D data (``self.dims == 4``) the network is either the unrolled
    FISTA network (``self.arch == "unrolled"``) or a small residual CNN.
    The CNN is applied to the adjoint image, with a skip connection and a
    final tanh. For other ranks only the 3D unrolled network is available.

    Args:
        ks_input: input k-space; its ``tf_util.kspace_mask`` is passed to
            the unrolled networks as the sampling mask.
        sensemap: sensitivity maps for the forward/adjoint operators.
        reuse: reuse the variables of the "generator" scope.

    Returns:
        Output image converted with ``tf_util.complex_to_channels`` (the CNN
        path works directly on the channel representation).

    Raises:
        ValueError: for non-2D data when ``self.arch`` is not "unrolled"
            (previously this surfaced as an UnboundLocalError).
    """
    mask_example = tf_util.kspace_mask(ks_input, dtype=tf.complex64)

    with tf.variable_scope("generator") as scope:
        if reuse:
            scope.reuse_variables()

        if self.dims == 4:
            # 2D data
            # batch, height, width, channels
            if self.arch == "unrolled":
                c_out = unrolled.unroll_fista(
                    ks_input,
                    sensemap,
                    num_grad_steps=self.iterations,
                    resblock_num_features=self.g_dim,
                    resblock_num_blocks=self.res_blocks,
                    is_training=True,
                    scope="MRI",
                    mask_output=1,
                    window=None,
                    do_hardproj=True,
                    num_summary_image=6,
                    mask=mask_example,
                    verbose=False,
                )
                c_out = tf_util.complex_to_channels(c_out)
            else:
                z = tf_util.model_transpose(ks_input, sensemap)
                z = tf_util.complex_to_channels(z)

                res_size = self.g_dim
                kernel_size = 3
                num_channels = 2
                num_blocks = 5
                # Layer creation order below fixes the auto-generated layer
                # (and checkpoint variable) names; keep it stable.
                c = tf.layers.conv2d(
                    z,
                    res_size,
                    kernel_size,
                    padding="same",
                    activation=tf.nn.relu,
                    use_bias=True,
                )
                for _ in range(num_blocks):
                    c = tf.layers.conv2d(
                        c,
                        res_size,
                        kernel_size,
                        padding="same",
                        activation=tf.nn.relu,
                        use_bias=True,
                    )
                    c1 = tf.layers.conv2d(
                        c,
                        res_size,
                        kernel_size,
                        padding="same",
                        activation=tf.nn.relu,
                        use_bias=True,
                    )
                    c = tf.add(c, c1)

                c8 = tf.layers.conv2d(
                    c,
                    num_channels,
                    kernel_size,
                    padding="same",
                    activation=None,
                    use_bias=True,
                )
                # Residual connection to the adjoint image.
                c_out = tf.add(c8, z)
                c_out = tf.nn.tanh(c_out)
        elif self.arch == "unrolled":
            # 3D data
            # batch, height, width, channels, time frames
            c_out = unrolled_3d.unroll_fista(
                ks_input,
                sensemap,
                num_grad_steps=self.iterations,
                num_features=self.g_dim,
                num_resblocks=self.res_blocks,
                is_training=True,
                scope="MRI",
                mask_output=1,
                window=None,
                do_hardproj=True,
                mask=mask_example,
                verbose=False,
                data_format="channels_last",
                do_separable=self.do_separable,
            )
            c_out = tf_util.complex_to_channels(c_out)
        else:
            raise ValueError(
                "generator: arch %r is not supported for dims=%r "
                "(only 'unrolled')" % (self.arch, self.dims))

    return c_out
def build_model(self, mode):
    """Build the WGAN-GP training graph (supervised variant).

    The generator output ``self.X_gen`` is scored by the discriminator
    directly (``self.Y_fake = self.X_gen``; no measurement operator is
    applied). The real samples come from ``self.read_real()``.

    Args:
        mode: Forwarded to ``mri_utils.Iterator``, either train or test
            according to the inline comment at that call.

    Side effects:
        Sets the data, loss, optimizer, output, counter, saver and summary
        writer attributes on ``self``, then calls ``self.create_summary()``
        and ``self.initialize_model()``.
    """
    self.Y_real, self.data_num, _ = self.read_real()

    # Read in train or test input images to be reconstructed by generator
    train_iterator = mri_utils.Iterator(
        self.batch_size,
        self.mask_path,
        self.data_type,
        mode,  # either train or test
        self.out_shape,
        verbose=self.verbose,
        train_acc=self.train_acc,
        data_dir=self.data_dir,
    )
    self.input_files = train_iterator.num_files
    train_dataset = train_iterator.iterator.get_next()
    ks_truth = train_dataset["ks_truth"]
    ks_input = train_dataset["ks_input"]
    sensemap = train_dataset["sensemap"]

    image_truth = tf_util.model_transpose(ks_truth, sensemap)
    self.complex_truth = image_truth
    self.z_truth = tf_util.complex_to_channels(image_truth)
    self.ks = ks_input
    self.sensemap = sensemap

    # generate image X_gen
    self.X_gen = self.generator(self.ks, self.sensemap)
    # no measurement for supervised
    self.Y_fake = self.X_gen

    # output of discriminator for fake and real images
    self.d_logits_fake = self.discriminator(self.Y_fake, reuse=False)
    self.d_logits_real = self.discriminator(self.Y_real, reuse=True)

    # discriminator (critic) loss
    self.d_loss = tf.reduce_mean(self.d_logits_fake) - tf.reduce_mean(
        self.d_logits_real
    )
    # generator loss
    self.g_loss = -tf.reduce_mean(self.d_logits_fake)

    # Gradient Penalty. There is one interpolation coefficient per example,
    # broadcast over every remaining axis. For 4-D inputs this is
    # [batch, 1, 1, 1]; higher-rank inputs get the matching number of
    # singleton axes.
    ndims = self.Y_real.shape.ndims
    self.epsilon = tf.random_uniform(
        shape=[self.batch_size] + [1] * (ndims - 1), minval=0.0, maxval=1.0
    )
    Y_hat = self.Y_real + self.epsilon * (self.Y_fake - self.Y_real)
    D_Y_hat = self.discriminator(Y_hat, reuse=True)
    grad_D_Y_hat = tf.gradients(D_Y_hat, [Y_hat])[0]
    red_idx = list(range(1, Y_hat.shape.ndims))
    slopes = tf.sqrt(tf.reduce_sum(tf.square(grad_D_Y_hat), axis=red_idx))
    self.gradient_penalty = tf.reduce_mean((slopes - 1.0) ** 2)
    # updated discriminator loss
    self.d_loss = self.d_loss + 10.0 * self.gradient_penalty

    train_vars = tf.trainable_variables()
    for v in train_vars:
        tf.add_to_collection("reg_loss", tf.nn.l2_loss(v))

    self.generator_vars = [v for v in train_vars if "generator" in v.name]
    self.discriminator_vars = [
        v for v in train_vars if "discriminator" in v.name
    ]

    self.g_optimizer = tf.train.AdamOptimizer(
        learning_rate=self.lr, name="g_opt", beta1=self.beta1, beta2=self.beta2
    ).minimize(self.g_loss, var_list=self.generator_vars)

    self.d_optimizer = tf.train.AdamOptimizer(
        learning_rate=self.lr, name="d_opt", beta1=self.beta1, beta2=self.beta2
    ).minimize(self.d_loss, var_list=self.discriminator_vars)

    self.output_image = tf_util.channels_to_complex(self.X_gen)
    self.im_out = self.output_image
    self.mag_output = tf.abs(self.output_image)

    self.create_summary()

    with tf.variable_scope("counter"):
        self.counter = tf.get_variable(
            "counter",
            shape=[1],
            initializer=tf.constant_initializer([1]),
            dtype=tf.int32,
        )
        self.update_counter = tf.assign(self.counter, tf.add(self.counter, 1))

    self.saver = tf.train.Saver()
    self.summary_writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)
    self.initialize_model()
def unroll_fista(
    ks_input,
    sensemap,
    num_grad_steps=5,
    resblock_num_features=128,
    resblock_num_blocks=2,
    is_training=True,
    scope="MRI",
    mask_output=1,
    window=None,
    do_hardproj=True,
    num_summary_image=0,
    mask=None,
    verbose=False,
):
    """Create general unrolled network for MRI.

    x_{k+1} = S( x_k - 2 * t * A^T W (A x- b) )
            = S( x_k - 2 * t * (A^T W A x - A^T W b))

    Each of the ``num_grad_steps`` iterations performs a data-consistency
    gradient step with a learned step size ``t`` (initialized to -2.0). That
    is followed by a learned proximal step ``prior_grad_res_net``, run in
    channels_first layout.

    Args:
        ks_input: Input k-space; multiplied by ``mask`` before use.
        sensemap: Sensitivity maps for ``model_forward``/``model_transpose``.
        num_grad_steps: Number of unrolled iterations.
        resblock_num_features: ``num_features`` of each prox network.
        resblock_num_blocks: ``num_blocks`` of each prox network.
        is_training: Forwarded to ``prior_grad_res_net`` as ``training``.
        scope: Name of the enclosing variable scope.
        mask_output: If not None, multiplies the final k-space during the
            hard projection.
        window: Weighting W applied to k-space before each adjoint; ``None``
            means 1.
        do_hardproj: If True, put the masked input k-space back at sampled
            locations of the final k-space and recompute the image.
        num_summary_image: If > 0, per-iteration sum-of-squares images are
            concatenated into a local tensor (currently neither emitted as a
            summary nor returned).
        mask: Sampling mask. When None it is derived from ``ks_input`` with
            ``tf_util.kspace_mask``.
        verbose: Print build information.

    Returns:
        The output image (complex), named "output_image".
    """
    # get list of GPU devices
    local_device_protos = device_lib.list_local_devices()
    device_list = [x.name for x in local_device_protos if x.device_type == "GPU"]
    # Every iteration is placed on the last GPU. With no visible GPU the
    # empty list cannot be indexed, so use None: ops are then built without
    # an explicit device constraint.
    cur_device = device_list[-1] if device_list else None

    if window is None:
        window = 1
    summary_iter = None

    if verbose:
        print(
            "%s> Building FISTA unrolled network (%d steps)...."
            % (scope, num_grad_steps)
        )
        if sensemap is not None:
            print("%s> Using sensitivity maps..." % scope)

    with tf.variable_scope(scope):
        if mask is None:
            mask = tf_util.kspace_mask(ks_input, dtype=tf.complex64)
        ks_input = mask * ks_input
        ks_0 = ks_input
        # x0 = A^T W b
        im_0 = tf_util.model_transpose(ks_0 * window, sensemap)
        im_0 = tf.identity(im_0, name="input_image")
        # To be updated
        ks_k = ks_0
        im_k = im_0

        for i_step in range(num_grad_steps):
            iter_name = "iter_%02d" % i_step
            with tf.device(cur_device), tf.variable_scope(iter_name):
                # = S( x_k - 2 * t * (A^T W A x_k - A^T W b))
                # = S( x_k - 2 * t * (A^T W A x_k - x0))
                with tf.variable_scope("update"):
                    im_k_orig = im_k
                    # xk = A^T A x_k
                    ks_k = tf_util.model_forward(im_k, sensemap)
                    ks_k = mask * ks_k
                    im_k = tf_util.model_transpose(ks_k * window, sensemap)
                    # xk = A^T A x_k - A^T b
                    im_k = tf_util.complex_to_channels(im_k - im_0)
                    im_k_orig = tf_util.complex_to_channels(im_k_orig)
                    # Update step
                    t_update = tf.get_variable(
                        "t", dtype=tf.float32, initializer=tf.constant([-2.0])
                    )
                    im_k = im_k_orig + t_update * im_k

                with tf.variable_scope("prox"):
                    num_channels_out = im_k.shape[-1]
                    # Transpose channels_last to channels_first
                    im_k = tf.transpose(im_k, [0, 3, 1, 2])
                    im_k = prior_grad_res_net(
                        im_k,
                        training=is_training,
                        num_features=resblock_num_features,
                        num_blocks=resblock_num_blocks,
                        num_features_out=num_channels_out,
                        data_format="channels_first",
                    )
                    im_k = tf.transpose(im_k, [0, 2, 3, 1])
                    im_k = tf_util.channels_to_complex(im_k)

                im_k = tf.identity(im_k, name="image")
                if num_summary_image > 0:
                    with tf.name_scope("summary"):
                        tmp = tf_util.sumofsq(im_k, keep_dims=True)
                        if summary_iter is None:
                            summary_iter = tmp
                        else:
                            summary_iter = tf.concat((summary_iter, tmp), axis=2)

        ks_k = tf_util.model_forward(im_k, sensemap)
        if do_hardproj:
            if verbose:
                print("%s> Final hard data projection..." % scope)
            # Final data projection
            ks_k = mask * ks_0 + (1 - mask) * ks_k
            if mask_output is not None:
                ks_k = ks_k * mask_output
            im_k = tf_util.model_transpose(ks_k * window, sensemap)

        ks_k = tf.identity(ks_k, name="output_kspace")
        im_k = tf.identity(im_k, name="output_image")

    return im_k
def unroll_fista(
    ks_input,
    sensemap,
    scope="MRI",
    num_grad_steps=5,
    num_resblocks=4,
    num_features=64,
    kernel_size=None,
    is_training=True,
    mask_output=1,
    mask=None,
    window=None,
    do_hardproj=False,
    do_dense=False,
    do_separable=False,
    do_rnn=False,
    do_circular=True,
    batchnorm=False,
    leaky=False,
    fix_update=False,
    data_format="channels_first",
    verbose=False,
):
    """Create general unrolled network for MRI.

    x_{k+1} = S( x_k - 2 * t * A^T W (A x- b) )
            = S( x_k - 2 * t * (A^T W A x - A^T W b))

    Each iteration performs a data-consistency gradient step followed by a
    learned proximal step ``prior_grad_res_net``.

    Args:
        ks_input: Input k-space; multiplied by ``mask`` before use.
        sensemap: Sensitivity maps for ``model_forward``/``model_transpose``.
        scope: Name of the enclosing variable scope.
        num_grad_steps: Number of unrolled iterations.
        num_resblocks: ``num_blocks`` of each prox network.
        num_features: ``num_features`` of each prox network.
        kernel_size: Kernel size forwarded to ``prior_grad_res_net``.
            Defaults to ``[3, 3, 5]``; a fresh list is built per call, so the
            default is never shared between calls.
        is_training: Forwarded to ``prior_grad_res_net`` as ``training``.
        mask_output: If not None, multiplies the final k-space during the
            hard projection.
        mask: Sampling mask. When None it is derived from ``ks_input``.
        window: Weighting W applied before each adjoint; ``None`` means 1.
        do_hardproj: Re-impose the masked input k-space at the end.
        do_dense: Feed the accumulated dense features of earlier iterations
            into later prox networks (concatenated on the channel axis).
        do_separable, do_circular, batchnorm, leaky: Forwarded to
            ``prior_grad_res_net``.
        do_rnn: Share one set of variables (scope "iter") across iterations.
        fix_update: Use a constant step size -2.0 instead of a learned one.
        data_format: Layout the prox network runs in. The iterate is in
            channels_last layout and is transposed when this is
            "channels_first".
        verbose: Print build information.

    Returns:
        The output image (complex), named "output_image".
    """
    if kernel_size is None:
        kernel_size = [3, 3, 5]
    if window is None:
        window = 1
    # Per-iteration images; kept for the (currently disabled) extra return.
    summary_iter = {}
    # Dense connections must be concatenated along the channel axis of the
    # layout the prox network runs in. Assumes prior_grad_res_net returns
    # im_dense_k in that same data_format (TODO confirm).
    channel_axis = 1 if data_format == "channels_first" else -1

    if verbose:
        print("%s> Building FISTA unrolled network...." % scope)
        print("%s> Num of gradient steps: %d" % (scope, num_grad_steps))
        print(
            "%s> Prior: %d ResBlocks, %d features"
            % (scope, num_resblocks, num_features)
        )
        print("%s> Kernel size: [%d x %d x %d]" % ((scope,) + tuple(kernel_size)))
        if do_rnn:
            print("%s> Sharing weights across iterations..." % scope)
        if sensemap is not None:
            print("%s> Using sensitivity maps..." % scope)
        if do_dense:
            print("%s> Inserting dense connections..." % scope)
        if do_circular:
            print("%s> Using circular padding..." % scope)
        if do_separable:
            print("%s> Using depth-wise separable convolutions..." % scope)
        if not batchnorm:
            print("%s> Turning off batch normalization..." % scope)

    with tf.variable_scope(scope):
        if mask is None:
            mask = tf_util.kspace_mask(ks_input, dtype=tf.complex64)
        ks_input = mask * ks_input
        ks_0 = ks_input
        # x0 = A^T W b
        im_0 = tf_util.model_transpose(ks_0 * window, sensemap)
        im_0 = tf.identity(im_0, name="input_image")
        # To be updated
        ks_k = ks_0
        im_k = im_0
        im_dense = None

        for i_step in range(num_grad_steps):
            iter_name = "iter_%02d" % i_step
            # With do_rnn every iteration reuses the variables of scope "iter".
            scope_name = "iter" if do_rnn else iter_name

            with tf.variable_scope(
                scope_name, reuse=(tf.AUTO_REUSE if do_rnn else False)
            ):
                with tf.variable_scope("update"):
                    # = S( x_k - 2 * t * (A^T W A x_k - A^T W b))
                    # = S( x_k - 2 * t * (A^T W A x_k - x0))
                    im_k_orig = im_k
                    # xk = A^T A x_k
                    ks_k = tf_util.model_forward(im_k, sensemap)
                    ks_k = mask * ks_k
                    im_k = tf_util.model_transpose(ks_k * window, sensemap)
                    # xk = A^T A x_k - A^T b
                    im_k = tf_util.complex_to_channels(im_k - im_0)
                    im_k_orig = tf_util.complex_to_channels(im_k_orig)
                    # Update step
                    if fix_update:
                        t_update = -2.0
                    else:
                        t_update = tf.get_variable(
                            "t", dtype=tf.float32, initializer=tf.constant([-2.0])
                        )
                    im_k = im_k_orig + t_update * im_k

                with tf.variable_scope("prox"):
                    # default is channels_last
                    num_channels_out = im_k.shape[-1]
                    if data_format == "channels_first":
                        im_k = tf.transpose(im_k, [0, 4, 1, 2, 3])

                    if im_dense is not None:
                        im_k = tf.concat([im_k, im_dense], axis=channel_axis)

                    im_k, im_dense_k = prior_grad_res_net(
                        im_k,
                        training=is_training,
                        num_features=num_features,
                        num_blocks=num_resblocks,
                        num_features_out=num_channels_out,
                        kernel_size=kernel_size,
                        data_format=data_format,
                        circular=do_circular,
                        separable=do_separable,
                        batchnorm=batchnorm,
                        leaky=leaky,
                    )

                    if do_dense:
                        if im_dense is not None:
                            im_dense = tf.concat(
                                [im_dense, im_dense_k], axis=channel_axis
                            )
                        else:
                            im_dense = im_dense_k

                    if data_format == "channels_first":
                        im_k = tf.transpose(im_k, [0, 2, 3, 4, 1])
                    im_k = tf_util.channels_to_complex(im_k)

                im_k = tf.identity(im_k, name="image")
                summary_iter[iter_name] = im_k

        ks_k = tf_util.model_forward(im_k, sensemap)
        if do_hardproj:
            if verbose:
                print("%s> Final hard data projection..." % scope)
            ks_k = mask * ks_0 + (1 - mask) * ks_k
            if mask_output is not None:
                ks_k = ks_k * mask_output
            im_k = tf_util.model_transpose(ks_k * window, sensemap)

        ks_k = tf.identity(ks_k, name="output_kspace")
        im_k = tf.identity(im_k, name="output_image")

    # return im_k, ks_k, summary_iter
    return im_k
def main(_):
    """Train the unrolled MRI reconstruction network with an L1 image loss.

    Reads input k-space, sensitivity maps, reconstruction masks and
    ground-truth k-space from ``FLAGS.dataset_dir/train``. These are fed
    through placeholders into ``mri_model.unroll_fista``, and the mean
    absolute error to the ground-truth image is minimized with Adam until
    ``FLAGS.max_steps``. Checkpoints (every 500 steps) and TensorBoard
    summaries go to ``FLAGS.log_root/FLAGS.train_dir``.

    Raises:
        ValueError: If ``FLAGS.dataset_dir`` is empty.
    """
    if not FLAGS.dataset_dir:
        raise ValueError(
            "You must supply the dataset directory with " + "--dataset_dir"
        )

    # CUDA_VISIBLE_DEVICES is read when TensorFlow first initializes the
    # GPUs (at session creation), so it has to be set before that point.
    os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.device

    # path where model checkpoints and summaries will be saved
    model_dir = os.path.join(FLAGS.log_root, FLAGS.train_dir)
    if not os.path.exists(FLAGS.log_root):
        os.makedirs(FLAGS.log_root)
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    if FLAGS.random_seed >= 0:
        random.seed(FLAGS.random_seed)
        np.random.seed(FLAGS.random_seed)
    tf.logging.set_verbosity(tf.logging.INFO)

    # All session options live on the single config passed to the session,
    # so allow_growth actually takes effect.
    run_config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=True)
    run_config.gpu_options.allow_growth = True

    with tf.Session(config=run_config) as sess:
        print("Preparing dataset...")
        out_shape = [FLAGS.shape_z, FLAGS.shape_y]
        train_dataset, num_files = mri_data.create_dataset(
            os.path.join(FLAGS.dataset_dir, "train"),
            FLAGS.mask_path,
            num_channels=FLAGS.num_channels,
            num_emaps=FLAGS.num_emaps,
            batch_size=FLAGS.batch_size,
            out_shape=out_shape,
        )

        # channels last format: batch, z, y, channels
        # placeholders
        ks_shape = [None, FLAGS.shape_z, FLAGS.shape_y, FLAGS.num_channels]
        ks_place = tf.placeholder(tf.complex64, ks_shape)
        sense_shape = [None, FLAGS.shape_z, FLAGS.shape_y, 1, FLAGS.num_channels]
        sense_place = tf.placeholder(tf.complex64, sense_shape)
        im_shape = [None, FLAGS.shape_z, FLAGS.shape_y, 1]
        im_truth_place = tf.placeholder(tf.complex64, im_shape)

        # run through unrolled model
        im_out_place = mri_model.unroll_fista(
            ks_place,
            sense_place,
            is_training=True,
            verbose=True,
            do_hardproj=FLAGS.do_hard_proj,
            num_summary_image=FLAGS.num_summary_image,
            resblock_num_features=FLAGS.feat_map,
            num_grad_steps=FLAGS.num_grad_steps,
            conv=FLAGS.conv,
            do_conjugate=FLAGS.do_conjugate,
            activation=FLAGS.activation,
        )

        # tensorboard summary function
        _create_summary(sense_place, ks_place, im_out_place, im_truth_place)

        # define L1 loss between output and ground truth
        loss = tf.reduce_mean(tf.abs(im_out_place - im_truth_place), name="l1")
        tf.summary.scalar("loss/l1", loss)

        # optimize using Adam
        optimizer = tf.train.AdamOptimizer(
            learning_rate=FLAGS.learning_rate,
            name="opt",
            beta1=FLAGS.adam_beta1,
            beta2=FLAGS.adam_beta2,
        ).minimize(loss)

        # counter for saving checkpoints
        with tf.variable_scope("counter"):
            counter = tf.get_variable(
                "counter",
                shape=[1],
                initializer=tf.constant_initializer([0]),
                dtype=tf.int32,
            )
            update_counter = tf.assign(counter, tf.add(counter, 1))

        saver = tf.train.Saver()
        summary_writer = tf.summary.FileWriter(model_dir, sess.graph)

        # initialize model
        print("[*] initializing network...")
        if not load(model_dir, saver, sess):
            sess.run(tf.global_variables_initializer())

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess, coord)

        # calculate number of parameters in model
        total_parameters = 0
        for variable in tf.trainable_variables():
            variable_parameters = 1
            for dim in variable.get_shape():
                variable_parameters *= dim.value
            total_parameters += variable_parameters
        print("Total number of trainable parameters: %d" % total_parameters)
        tf.summary.scalar("parameters/parameters", total_parameters)

        # use iterator to go through TFrecord dataset
        train_iterator = train_dataset.make_one_shot_iterator()
        features, labels = train_iterator.get_next()

        ks_truth = labels  # ground truth kspace
        ks_in = features["ks_input"]  # input kspace
        sense_in = features["sensemap"]  # sensitivity maps
        mask_recon = features["mask_recon"]  # reconstruction mask

        # ground truth kspace to image domain
        im_truth = tf_util.model_transpose(ks_truth * mask_recon, sense_in)

        # gather summaries for tensorboard
        total_summary = tf.summary.merge_all()

        # counter has shape [1]; take its single element explicitly.
        start_step = int(sess.run(counter)[0])
        print("Start from step %d." % start_step)
        try:
            for step in range(start_step, FLAGS.max_steps):
                # evaluate input kspace, sensitivity maps, ground truth image
                ks_in_run, sense_in_run, im_truth_run = sess.run(
                    [ks_in, sense_in, im_truth]
                )
                # run optimizer and collect output image from model and
                # tensorboard summary
                im_out, total_summary_run, _ = sess.run(
                    [im_out_place, total_summary, optimizer],
                    feed_dict={
                        ks_place: ks_in_run,
                        sense_place: sense_in_run,
                        im_truth_place: im_truth_run,
                    },
                )
                print("step", step)
                # add summary to tensorboard
                summary_writer.add_summary(total_summary_run, step)

                # Record that this step is done *before* checkpointing, so a
                # restored checkpoint resumes at the next step instead of
                # repeating this one.
                sess.run(update_counter)

                # save checkpoint every 500 steps
                if step % 500 == 0:
                    print("saving checkpoint")
                    saver.save(sess, os.path.join(model_dir, "model.ckpt"))
        finally:
            # Stop and join the queue-runner threads and flush summaries.
            coord.request_stop()
            coord.join(threads)
            summary_writer.close()

        print("End of training loop")
def build_model(self, mode):
    """Build the WGAN-GP graph in which the critic sees measured generator output.

    The generator output ``self.X_gen`` is passed through
    ``self.measure(..., real_mask)`` before being scored against the real
    undersampled samples ``self.Y_real``.

    Args:
        mode: Forwarded to ``mri_utils.Iterator``, either train or test
            according to the inline comment at that call. In "test" mode the
            tfrecord search pattern is chosen from ``self.data_type``.

    Side effects:
        Sets the data, loss, optimizer, output, counter, saver and summary
        writer attributes on ``self``, then calls ``self.create_summary()``
        and ``self.initialize_model()``. Exits the process with status 1
        when no real or no input files are found.
    """
    import sys

    # Read in real undersampled image Yr
    self.Y_real, self.data_num, real_mask = self.read_real()
    if self.data_num == 0:
        print("Error: no training files found")
        # Non-zero status so calling scripts can detect the failure.
        sys.exit(1)

    if mode == "test":
        if self.data_type == "DCE_2D":
            self.search_str = "/17Dec16_Ex21068_Ser13*.tfrecords"
        else:
            # "knee" and every other data type use all records.
            self.search_str = "/*.tfrecords"
    # NOTE(review): outside test mode self.search_str must already be set
    # elsewhere (presumably __init__) -- confirm.

    # Read in train or test input images to be reconstructed by generator
    train_iterator = mri_utils.Iterator(
        self.batch_size,
        self.mask_path,
        self.data_type,
        mode,  # either train or test
        self.out_shape,
        verbose=self.verbose,
        train_acc=self.train_acc,
        search_str=self.search_str,
        data_dir=self.data_dir,
    )
    self.input_files = train_iterator.num_files
    if self.input_files == 0:
        print("Error: no input files found")
        sys.exit(1)

    train_dataset = train_iterator.iterator.get_next()
    ks_truth = train_dataset["ks_truth"]
    ks_input = train_dataset["ks_input"]
    sensemap = train_dataset["sensemap"]

    self.im_in = tf_util.model_transpose(ks_input, sensemap)
    self.complex_truth = tf_util.model_transpose(ks_truth, sensemap)
    self.z_truth = tf_util.complex_to_channels(self.complex_truth)
    self.ks = ks_input
    self.sensemap = sensemap

    # generate image X_gen
    self.X_gen = self.generator(self.ks, self.sensemap)
    # measure X_gen
    self.Y_fake = self.measure(self.X_gen, self.sensemap, real_mask)

    # output of discriminator for fake and real images
    self.d_logits_fake = self.discriminator(self.Y_fake, reuse=False)
    self.d_logits_real = self.discriminator(self.Y_real, reuse=True)

    # discriminator (critic) loss
    self.d_loss = tf.reduce_mean(self.d_logits_fake) - tf.reduce_mean(
        self.d_logits_real
    )
    # generator loss
    self.g_loss = -tf.reduce_mean(self.d_logits_fake)

    # Gradient Penalty. There is one interpolation coefficient per example,
    # broadcast over every remaining axis. For 4-D inputs this is
    # [batch, 1, 1, 1]; higher-rank inputs get the matching number of
    # singleton axes.
    ndims = self.Y_real.shape.ndims
    self.epsilon = tf.random_uniform(
        shape=[self.batch_size] + [1] * (ndims - 1), minval=0.0, maxval=1.0
    )
    Y_hat = self.Y_real + self.epsilon * (self.Y_fake - self.Y_real)
    D_Y_hat = self.discriminator(Y_hat, reuse=True)
    grad_D_Y_hat = tf.gradients(D_Y_hat, [Y_hat])[0]
    red_idx = list(range(1, Y_hat.shape.ndims))
    slopes = tf.sqrt(tf.reduce_sum(tf.square(grad_D_Y_hat), axis=red_idx))
    self.gradient_penalty = tf.reduce_mean((slopes - 1.0) ** 2)
    # updated discriminator loss
    self.d_loss = self.d_loss + 10.0 * self.gradient_penalty

    train_vars = tf.trainable_variables()
    for v in train_vars:
        tf.add_to_collection("reg_loss", tf.nn.l2_loss(v))

    self.generator_vars = [v for v in train_vars if "generator" in v.name]
    self.discriminator_vars = [
        v for v in train_vars if "discriminator" in v.name
    ]

    local_device_protos = device_lib.list_local_devices()
    device_list = [
        x.name for x in local_device_protos if x.device_type == "GPU"
    ]
    # Generator optimizer on the last GPU, discriminator optimizer on the
    # first (the same GPU when only one is visible). With no GPU, build the
    # ops without an explicit device instead of indexing an empty list.
    g_device = device_list[-1] if device_list else None
    d_device = device_list[0] if device_list else None

    with tf.device(g_device):
        self.g_optimizer = tf.train.AdamOptimizer(
            learning_rate=self.lr, name="g_opt", beta1=self.beta1, beta2=self.beta2
        ).minimize(self.g_loss, var_list=self.generator_vars)

    with tf.device(d_device):
        self.d_optimizer = tf.train.AdamOptimizer(
            learning_rate=self.lr, name="d_opt", beta1=self.beta1, beta2=self.beta2
        ).minimize(self.d_loss, var_list=self.discriminator_vars)

    self.output_image = tf_util.channels_to_complex(self.X_gen)
    self.im_out = self.output_image
    self.mag_output = tf.abs(self.output_image)

    self.create_summary()

    with tf.variable_scope("counter"):
        self.counter = tf.get_variable(
            "counter",
            shape=[1],
            initializer=tf.constant_initializer([1]),
            dtype=tf.int32,
        )
        self.update_counter = tf.assign(self.counter, tf.add(self.counter, 1))

    self.saver = tf.train.Saver()
    self.summary_writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)
    self.initialize_model()