def _create_summary(sense_place, ks_place, im_out_place, im_truth_place):
    """Register TensorBoard image summaries comparing input, output and truth.

    Emits, under separate name scopes:
      * "input-output-truth": log k-space and image-domain panels, each built
        from sum-of-squares of (input | output | truth) concatenated on axis 2;
      * "truth": real part of the truth image summed over its last axis;
      * "mask": sum-of-squares of the sampling mask derived from the input
        k-space;
      * "sensemap": magnitude of the first map (index 0 of axis 3), with
        axis 4 unrolled alongside the last spatial axis.

    Args:
        sense_place: sensitivity-map tensor used by the tf_util forward/adjoint
            operators (sliced as a rank-5 tensor below).
        ks_place: measured (input) k-space tensor.
        im_out_place: reconstructed image tensor.
        im_truth_place: ground-truth image tensor.
    """
    maps = sense_place
    kspace_in = ks_place
    recon = im_out_place
    truth = im_truth_place

    # Zero-filled input image, sampling mask, and k-space of output/truth.
    recon_input = tf_util.model_transpose(kspace_in, maps)
    sampling_mask = tf_util.kspace_mask(kspace_in, dtype=tf.complex64)
    kspace_recon = tf_util.model_forward(recon, maps)
    kspace_truth = tf_util.model_forward(truth, maps)

    def side_by_side(*tensors):
        # Sum-of-squares of each tensor, placed next to each other on axis 2.
        return tf.concat(
            tuple(tf_util.sumofsq(t, keep_dims=True) for t in tensors), axis=2)

    with tf.name_scope("input-output-truth"):
        # Small offset keeps log() finite where k-space is exactly zero.
        kspace_panel = tf.log(
            side_by_side(kspace_in, kspace_recon, kspace_truth) + 1e-6)
        tf.summary.image("kspace", kspace_panel,
                         max_outputs=FLAGS.num_summary_image)

        image_panel = side_by_side(recon_input, recon, truth)
        tf.summary.image("image", image_panel,
                         max_outputs=FLAGS.num_summary_image)

    with tf.name_scope("truth"):
        truth_real = tf.real(
            tf.reduce_sum(truth, axis=-1, keep_dims=True))
        tf.summary.image("image_real", truth_real,
                         max_outputs=FLAGS.num_summary_image)

    with tf.name_scope("mask"):
        mask_panel = tf_util.sumofsq(sampling_mask, keep_dims=True)
        tf.summary.image("mask", mask_panel,
                         max_outputs=FLAGS.num_summary_image)

    with tf.name_scope("sensemap"):
        first_map = tf.slice(tf.abs(maps), [0] * 5, [-1, -1, -1, 1, -1])
        first_map = tf.transpose(first_map, [0, 1, 4, 2, 3])
        map_shape = tf.shape(first_map)
        # Flatten everything after axis 1 into one wide image row.
        tiled = tf.reshape(first_map, [map_shape[0], map_shape[1], -1])
        tf.summary.image("image", tf.expand_dims(tiled, axis=-1),
                         max_outputs=FLAGS.num_summary_image)
def create_summary(self):
    """Create TensorBoard summaries for training and merge them.

    Reads ``self.ks``, ``self.sensemap``, ``self.output_image``,
    ``self.z_truth``, ``self.X_gen``, ``self.Y_real``, the loss tensors
    (``d_loss``, ``g_loss``, ``gradient_penalty``) and discriminator logits.

    Sets on ``self``:
        input_image: A^T applied to ``self.ks`` (input image to the generator).
        d_loss_sum, g_loss_sum, gp_sum, d_fake, d_real, z_sum: summary ops.
        d_sum: merge of z, discriminator loss and D_fake/D_real summaries.
        g_sum: merge of z and generator loss summaries.
        train_sum: ``tf.summary.merge_all()``.
    """
    # Input image to generator. Note that ks is based on the input data,
    # not on the rotated data.
    self.input_image = tf_util.model_transpose(self.ks, self.sensemap)

    # Use "==" rather than "is": identity of equal strings is not guaranteed
    # (e.g. for values parsed from flags), which would silently skip this.
    if self.data_type == "knee":
        truth_image = tf_util.channels_to_complex(self.z_truth)
        sum_input = tf.image.flip_up_down(
            tf.image.rot90(tf.abs(self.input_image)))
        sum_output = tf.image.flip_up_down(
            tf.image.rot90(tf.abs(self.output_image)))
        sum_truth = tf.image.flip_up_down(
            tf.image.rot90(tf.abs(truth_image)))
        train_out = tf.concat((sum_input, sum_output, sum_truth), axis=2)
        tf.summary.image("input-output-truth", train_out)

    loss_l1 = tf.reduce_mean(tf.abs(self.X_gen - self.z_truth))
    loss_l2 = tf.reduce_mean(tf.square(tf.abs(self.X_gen - self.z_truth)))
    tf.summary.scalar("l1", loss_l1)
    tf.summary.scalar("l2", loss_l2)

    # To check supervised/unsupervised behaviour on the real data.
    y_real = tf_util.channels_to_complex(self.Y_real)
    y_real = tf.image.flip_up_down(tf.image.rot90(tf.abs(y_real)))
    tf.summary.image("y_real/mag", y_real)

    # Plot losses
    self.d_loss_sum = tf.summary.scalar("Discriminator_loss", self.d_loss)
    self.g_loss_sum = tf.summary.scalar("Generator_loss", self.g_loss)
    self.gp_sum = tf.summary.scalar("Gradient_penalty",
                                    self.gradient_penalty)
    self.d_fake = tf.summary.scalar("subloss/D_fake",
                                    tf.reduce_mean(self.d_logits_fake))
    self.d_real = tf.summary.scalar("subloss/D_real",
                                    tf.reduce_mean(self.d_logits_real))
    self.z_sum = tf.summary.histogram(
        "z", tf_util.complex_to_channels(self.input_image))

    self.d_sum = tf.summary.merge(
        [self.z_sum, self.d_loss_sum, self.d_fake, self.d_real])
    self.g_sum = tf.summary.merge([self.z_sum, self.g_loss_sum])
    self.train_sum = tf.summary.merge_all()
def unroll_fista(
    ks_input,
    sensemap,
    num_grad_steps=5,
    resblock_num_features=128,
    resblock_num_blocks=2,
    is_training=True,
    scope="MRI",
    mask_output=1,
    window=None,
    do_hardproj=True,
    num_summary_image=0,
    mask=None,
    verbose=False,
    conv="real",
    do_conjugate=False,
    activation="relu",
):
    """Create general unrolled network for MRI.

    x_{k+1} = S( x_k - 2 * t * A^T W (A x- b) )
            = S( x_k - 2 * t * (A^T W A x - A^T W b))

    Each of ``num_grad_steps`` iterations applies a gradient ("update") step
    with a learned step size ``t`` (initialized to -2.0) followed by a learned
    proximal ("prox") step implemented by ``prior_grad_res_net``.

    Args:
        ks_input: measured k-space.
        sensemap: sensitivity maps for tf_util.model_forward/model_transpose.
        num_grad_steps: number of unrolled iterations.
        resblock_num_features: feature count passed to prior_grad_res_net.
        resblock_num_blocks: block count passed to prior_grad_res_net.
        is_training: forwarded as ``training`` to prior_grad_res_net.
        scope: top-level variable scope name.
        mask_output: multiplied into the final k-space unless None.
        window: k-space weighting W; None means 1.
        do_hardproj: if True, re-insert measured samples before the final
            adjoint (hard data consistency).
        num_summary_image: if > 0, build per-iteration image summaries.
        mask: sampling mask; derived from ``ks_input`` when None.
        verbose: print build information.
        conv: stored in the module-level global ``type_conv``.
        do_conjugate: stored in the module-level global ``conjugate``.
        activation: forwarded to prior_grad_res_net.

    Returns:
        The final complex image tensor (op name "output_image").
    """
    if window is None:
        window = 1
    summary_iter = None

    # NOTE(review): these module-level globals are set as a side effect,
    # presumably read by conv helpers elsewhere in the module — confirm.
    global type_conv
    type_conv = conv
    global conjugate
    conjugate = do_conjugate

    if verbose:
        print(
            "%s> Building FISTA unrolled network (%d steps)...."
            % (scope, num_grad_steps)
        )
        if sensemap is not None:
            print("%s> Using sensitivity maps..." % scope)
    with tf.variable_scope(scope):
        if mask is None:
            mask = tf_util.kspace_mask(ks_input, dtype=tf.complex64)
        ks_input = mask * ks_input
        ks_0 = ks_input
        # x0 = A^T W b
        im_0 = tf_util.model_transpose(ks_0 * window, sensemap)
        im_0 = tf.identity(im_0, name="input_image")
        # To be updated
        ks_k = ks_0
        im_k = im_0

        for i_step in range(num_grad_steps):
            iter_name = "iter_%02d" % i_step
            with tf.variable_scope(iter_name):
                # = S( x_k - 2 * t * (A^T W A x_k - A^T W b))
                # = S( x_k - 2 * t * (A^T W A x_k - x0))
                with tf.variable_scope("update"):
                    im_k_orig = im_k
                    # xk = A^T A x_k
                    ks_k = tf_util.model_forward(im_k, sensemap)
                    ks_k = mask * ks_k
                    im_k = tf_util.model_transpose(ks_k * window, sensemap)
                    # xk = A^T A x_k - A^T b
                    im_k = tf_util.complex_to_channels(im_k - im_0)
                    im_k_orig = tf_util.complex_to_channels(im_k_orig)
                    # Update step: learned step size per iteration
                    t_update = tf.get_variable(
                        "t", dtype=tf.float32, initializer=tf.constant([-2.0])
                    )
                    im_k = im_k_orig + t_update * im_k

                with tf.variable_scope("prox"):
                    # Proximal network outputs as many channels as it gets.
                    num_channels_out = im_k.shape[-1]
                    im_k = prior_grad_res_net(
                        im_k,
                        training=is_training,
                        num_features=resblock_num_features,
                        num_blocks=resblock_num_blocks,
                        num_features_out=num_channels_out,
                        data_format="channels_last",
                        activation=activation
                    )
                    im_k = tf_util.channels_to_complex(im_k)

                im_k = tf.identity(im_k, name="image")
                if num_summary_image > 0:
                    with tf.name_scope("summary"):
                        # Append this iterate's sum-of-squares image to a
                        # horizontal strip (axis 2).
                        tmp = tf_util.sumofsq(im_k, keep_dims=True)
                        if summary_iter is None:
                            summary_iter = tmp
                        else:
                            summary_iter = tf.concat(
                                (summary_iter, tmp), axis=2)
                        tf.summary.scalar("max/" + iter_name, tf.reduce_max(tmp))

        ks_k = tf_util.model_forward(im_k, sensemap)
        if do_hardproj:
            if verbose:
                print("%s> Final hard data projection..." % scope)
            # Final data projection: keep measured samples, fill the rest
            ks_k = mask * ks_0 + (1 - mask) * ks_k
            if mask_output is not None:
                ks_k = ks_k * mask_output
            im_k = tf_util.model_transpose(ks_k * window, sensemap)

        ks_k = tf.identity(ks_k, name="output_kspace")
        im_k = tf.identity(im_k, name="output_image")

        if summary_iter is not None:
            tf.summary.image("iter/image", summary_iter, max_outputs=num_summary_image)

    return im_k
def unroll_fista(
    ks_input,
    sensemap,
    scope="MRI",
    num_grad_steps=5,
    num_resblocks=4,
    num_features=64,
    kernel_size=None,
    is_training=True,
    mask_output=1,
    mask=None,
    window=None,
    do_hardproj=False,
    do_dense=False,
    do_separable=False,
    do_rnn=False,
    do_circular=True,
    batchnorm=False,
    leaky=False,
    fix_update=False,
    data_format="channels_first",
    verbose=False,
):
    """Create general unrolled network for MRI.

    x_{k+1} = S( x_k - 2 * t * A^T W (A x- b) )
            = S( x_k - 2 * t * (A^T W A x - A^T W b))

    Args:
        ks_input: measured k-space.
        sensemap: sensitivity maps for tf_util.model_forward/model_transpose.
        scope: top-level variable scope name.
        num_grad_steps: number of unrolled iterations.
        num_resblocks: residual blocks per proximal network.
        num_features: features per residual block.
        kernel_size: convolution kernel size (3 entries); defaults to
            ``[3, 3, 5]``.
        is_training: forwarded as ``training`` to prior_grad_res_net.
        mask_output: multiplied into the final k-space unless None.
        mask: sampling mask; derived from ``ks_input`` when None.
        window: k-space weighting W; None means 1.
        do_hardproj: re-insert measured samples before the final adjoint.
        do_dense: concatenate earlier prox features into later iterations.
        do_separable, do_circular, batchnorm, leaky: forwarded to
            prior_grad_res_net.
        do_rnn: share one "iter" variable scope (AUTO_REUSE) across steps.
        fix_update: use a constant step size of -2.0 instead of learning it.
        data_format: "channels_first" transposes around the prox network.
        verbose: print build information.

    Returns:
        The final complex image tensor (op name "output_image").
    """
    # Avoid a shared mutable default; behaviour matches the old [3, 3, 5].
    if kernel_size is None:
        kernel_size = [3, 3, 5]
    if window is None:
        window = 1

    if verbose:
        print("%s> Building FISTA unrolled network...." % scope)
        print("%s> Num of gradient steps: %d" % (scope, num_grad_steps))
        print(
            "%s> Prior: %d ResBlocks, %d features"
            % (scope, num_resblocks, num_features)
        )
        print("%s> Kernel size: [%d x %d x %d]" % ((scope,) + tuple(kernel_size)))
        if do_rnn:
            print("%s> Sharing weights across iterations..." % scope)
        if sensemap is not None:
            print("%s> Using sensitivity maps..." % scope)
        if do_dense:
            print("%s> Inserting dense connections..." % scope)
        if do_circular:
            print("%s> Using circular padding..." % scope)
        if do_separable:
            print("%s> Using depth-wise separable convolutions..." % scope)
        if not batchnorm:
            print("%s> Turning off batch normalization..." % scope)

    with tf.variable_scope(scope):
        if mask is None:
            mask = tf_util.kspace_mask(ks_input, dtype=tf.complex64)
        ks_input = mask * ks_input
        ks_0 = ks_input
        # x0 = A^T W b
        im_0 = tf_util.model_transpose(ks_0 * window, sensemap)
        im_0 = tf.identity(im_0, name="input_image")
        # To be updated
        ks_k = ks_0
        im_k = im_0
        # Accumulated dense features from previous prox steps (do_dense).
        im_dense = None

        for i_step in range(num_grad_steps):
            iter_name = "iter_%02d" % i_step
            # With do_rnn every step reuses the same "iter" variables.
            scope_name = "iter" if do_rnn else iter_name

            with tf.variable_scope(
                scope_name, reuse=(tf.AUTO_REUSE if do_rnn else False)
            ):
                with tf.variable_scope("update"):
                    # = S( x_k - 2 * t * (A^T W A x_k - A^T W b))
                    # = S( x_k - 2 * t * (A^T W A x_k - x0))
                    im_k_orig = im_k
                    # xk = A^T A x_k
                    ks_k = tf_util.model_forward(im_k, sensemap)
                    ks_k = mask * ks_k
                    im_k = tf_util.model_transpose(ks_k * window, sensemap)
                    # xk = A^T A x_k - A^T b
                    im_k = tf_util.complex_to_channels(im_k - im_0)
                    im_k_orig = tf_util.complex_to_channels(im_k_orig)
                    # Update step
                    if fix_update:
                        t_update = -2.0
                    else:
                        t_update = tf.get_variable(
                            "t", dtype=tf.float32, initializer=tf.constant([-2.0])
                        )
                    im_k = im_k_orig + t_update * im_k

                with tf.variable_scope("prox"):
                    # Channel count of the (channels_last) input is the
                    # number of output features of the prox network.
                    num_channels_out = im_k.shape[-1]
                    if data_format == "channels_first":
                        im_k = tf.transpose(im_k, [0, 4, 1, 2, 3])
                    if im_dense is not None:
                        im_k = tf.concat([im_k, im_dense], axis=1)

                    im_k, im_dense_k = prior_grad_res_net(
                        im_k,
                        training=is_training,
                        num_features=num_features,
                        num_blocks=num_resblocks,
                        num_features_out=num_channels_out,
                        kernel_size=kernel_size,
                        data_format=data_format,
                        circular=do_circular,
                        separable=do_separable,
                        batchnorm=batchnorm,
                        leaky=leaky,
                    )

                    if do_dense:
                        if im_dense is not None:
                            im_dense = tf.concat([im_dense, im_dense_k], axis=1)
                        else:
                            im_dense = im_dense_k

                    if data_format == "channels_first":
                        im_k = tf.transpose(im_k, [0, 2, 3, 4, 1])
                    im_k = tf_util.channels_to_complex(im_k)

                im_k = tf.identity(im_k, name="image")

        ks_k = tf_util.model_forward(im_k, sensemap)
        if do_hardproj:
            if verbose:
                print("%s> Final hard data projection..." % scope)
            # Keep measured samples, fill the rest from the estimate.
            ks_k = mask * ks_0 + (1 - mask) * ks_k
            if mask_output is not None:
                ks_k = ks_k * mask_output
            im_k = tf_util.model_transpose(ks_k * window, sensemap)

        ks_k = tf.identity(ks_k, name="output_kspace")
        im_k = tf.identity(im_k, name="output_image")

    return im_k
def measure(self, X_gen, sensemap, real_mask):
    """Simulate an undersampled k-space measurement of generated images.

    Args:
        X_gen: generated images in channel form (converted with
            tf_util.channels_to_complex).
        sensemap: sensitivity maps used by tf_util.model_forward/transpose.
        real_mask: unused; kept for interface compatibility.

    Returns:
        Zero-filled image of the masked k-space, in channel form.

    Raises:
        ValueError: if ``self.data_type`` is not "DCE", "DCE_2D", "mfast"
            or "knee".
    """
    name = "measure"
    random_seed = 0

    image = tf_util.channels_to_complex(X_gen)
    kspace = tf_util.model_forward(image, sensemap)
    total_kspace = None

    # Compare strings with "=="/"in"; "is" checks object identity, which is
    # not guaranteed for equal strings (e.g. values parsed from flags).
    if self.data_type in ("DCE", "DCE_2D", "mfast"):
        if self.verbose:
            print("DCE measure")
        for i in range(self.batch_size):
            if self.dims == 4:
                ks_x = kspace[i, :, :]
            else:
                # 2D plus time
                ks_x = kspace[i, :, :, :]

            # Data dimensions
            shape_y = self.width
            shape_t = self.max_frames
            sim_partial_ky = 0.0
            # Acceleration drawn uniformly from [accs[0], accs[1]).
            accs = [1, 6]
            rand_accel = (accs[1] - accs[0]) * tf.random_uniform([]) + accs[0]
            # ny, nt, accel, ncal, vd_degree, partial_ky
            fn_inputs = [
                shape_y,
                shape_t,
                rand_accel,
                10,
                2.0,
                sim_partial_ky,
            ]
            mask_x = tf.py_func(mask.generate_perturbed2dvdkt, fn_inputs,
                                tf.complex64)
            ks_x = ks_x * tf.reshape(mask_x, [1, shape_y, shape_t, 1])

            if total_kspace is not None:
                total_kspace = tf.concat([total_kspace, ks_x], 0)
            else:
                total_kspace = ks_x

    elif self.data_type == "knee":
        if self.verbose:
            print("knee")
        for i in range(self.batch_size):
            ks_x = kspace[i, :, :]
            if self.verbose:
                print("ks_x", ks_x)

            # Randomly select mask
            mask_x = tf.random_shuffle(self.masks)
            mask_x = tf.slice(mask_x, [0, 0, 0], [1, -1, -1])
            # Augment sampling masks
            # NOTE(review): both flips use the same op seed, so their random
            # draws are presumably identical (flip both or neither) — confirm
            # whether that is intended.
            mask_x = tf.image.random_flip_up_down(mask_x, seed=random_seed)
            mask_x = tf.image.random_flip_left_right(mask_x, seed=random_seed)
            # Transpose to store data as (kz, ky, channels)
            mask_x = tf.transpose(mask_x, [1, 2, 0])
            ks_x = tf.image.flip_up_down(ks_x)

            # Initially set image size to be all the same
            ks_x = tf.image.resize_image_with_crop_or_pad(
                ks_x, self.height, self.width)
            mask_x = tf.image.resize_image_with_crop_or_pad(
                mask_x, self.height, self.width)
            if self.verbose:
                print("resized mask", mask_x)

            shape_cal = 20
            if shape_cal > 0:
                with tf.name_scope("CalibRegion"):
                    if self.verbose:
                        print("%s> Including calib region (%d, %d)..." %
                              (name, shape_cal, shape_cal))
                    # Force a fully sampled central calibration square.
                    mask_calib = tf.ones([shape_cal, shape_cal, 1],
                                         dtype=tf.complex64)
                    mask_calib = tf.image.resize_image_with_crop_or_pad(
                        mask_calib, self.height, self.width)
                    mask_x = mask_x * (1 - mask_calib) + mask_calib

            # Drop mask entries where the k-space is (numerically) empty.
            mask_recon = tf.abs(ks_x) / tf.reduce_max(tf.abs(ks_x))
            mask_recon = tf.cast(mask_recon > 1e-7, dtype=tf.complex64)
            mask_x = mask_x * mask_recon
            if self.verbose:
                print("mask x", mask_x)

            # Assuming calibration region is fully sampled, normalize by the
            # energy of the central shape_sc x shape_sc patch.
            shape_sc = 5
            scale = tf.image.resize_image_with_crop_or_pad(
                ks_x, shape_sc, shape_sc)
            scale = tf.reduce_mean(tf.square(
                tf.abs(scale))) * (shape_sc * shape_sc / 1e5)
            scale = tf.cast(1.0 / tf.sqrt(scale), dtype=tf.complex64)
            ks_x = ks_x * scale

            # Masked input
            ks_x = tf.multiply(ks_x, mask_x)
            if total_kspace is not None:
                total_kspace = tf.concat([total_kspace, ks_x], 0)
            else:
                total_kspace = ks_x

    else:
        raise ValueError(
            "%s> unsupported data_type: %r" % (name, self.data_type))

    x_measured = tf_util.model_transpose(total_kspace, sensemap,
                                         name="x_measured")
    x_measured = tf_util.complex_to_channels(x_measured)
    return x_measured
def unroll_fista(
    ks_input,
    sensemap,
    num_grad_steps=5,
    resblock_num_features=128,
    resblock_num_blocks=2,
    is_training=True,
    scope="MRI",
    mask_output=1,
    window=None,
    do_hardproj=True,
    num_summary_image=0,
    mask=None,
    verbose=False,
):
    """Create general unrolled network for MRI.

    x_{k+1} = S( x_k - 2 * t * A^T W (A x- b) )
            = S( x_k - 2 * t * (A^T W A x - A^T W b))

    All iterations are placed on the last visible GPU, or on "/cpu:0" when
    no GPU is visible.

    Args:
        ks_input: measured k-space.
        sensemap: sensitivity maps for tf_util.model_forward/model_transpose.
        num_grad_steps: number of unrolled iterations.
        resblock_num_features: feature count passed to prior_grad_res_net.
        resblock_num_blocks: block count passed to prior_grad_res_net.
        is_training: forwarded as ``training`` to prior_grad_res_net.
        scope: top-level variable scope name.
        mask_output: multiplied into the final k-space unless None.
        window: k-space weighting W; None means 1.
        do_hardproj: re-insert measured samples before the final adjoint.
        num_summary_image: if > 0, build per-iteration sum-of-squares images.
        mask: sampling mask; derived from ``ks_input`` when None.
        verbose: print build information.

    Returns:
        The final complex image tensor (op name "output_image").
    """
    # get list of GPU devices
    local_device_protos = device_lib.list_local_devices()
    device_list = [x.name for x in local_device_protos if x.device_type == "GPU"]
    # Pin every step to the last GPU. device_list[-1] would raise IndexError
    # on a machine without GPUs, so fall back to the CPU there.
    cur_device = device_list[-1] if device_list else "/cpu:0"

    if window is None:
        window = 1
    summary_iter = None

    if verbose:
        print(
            "%s> Building FISTA unrolled network (%d steps)...."
            % (scope, num_grad_steps)
        )
        if sensemap is not None:
            print("%s> Using sensitivity maps..." % scope)

    with tf.variable_scope(scope):
        if mask is None:
            mask = tf_util.kspace_mask(ks_input, dtype=tf.complex64)
        ks_input = mask * ks_input
        ks_0 = ks_input
        # x0 = A^T W b
        im_0 = tf_util.model_transpose(ks_0 * window, sensemap)
        im_0 = tf.identity(im_0, name="input_image")
        # To be updated
        ks_k = ks_0
        im_k = im_0

        for i_step in range(num_grad_steps):
            iter_name = "iter_%02d" % i_step
            with tf.device(cur_device), tf.variable_scope(iter_name):
                # = S( x_k - 2 * t * (A^T W A x_k - A^T W b))
                # = S( x_k - 2 * t * (A^T W A x_k - x0))
                with tf.variable_scope("update"):
                    im_k_orig = im_k
                    # xk = A^T A x_k
                    ks_k = tf_util.model_forward(im_k, sensemap)
                    ks_k = mask * ks_k
                    im_k = tf_util.model_transpose(ks_k * window, sensemap)
                    # xk = A^T A x_k - A^T b
                    im_k = tf_util.complex_to_channels(im_k - im_0)
                    im_k_orig = tf_util.complex_to_channels(im_k_orig)
                    # Update step: learned step size per iteration
                    t_update = tf.get_variable(
                        "t", dtype=tf.float32, initializer=tf.constant([-2.0])
                    )
                    im_k = im_k_orig + t_update * im_k

                with tf.variable_scope("prox"):
                    # Default is channels last
                    num_channels_out = im_k.shape[-1]
                    # Transpose channels_last to channels_first for the prior.
                    im_k = tf.transpose(im_k, [0, 3, 1, 2])
                    im_k = prior_grad_res_net(
                        im_k,
                        training=is_training,
                        num_features=resblock_num_features,
                        num_blocks=resblock_num_blocks,
                        num_features_out=num_channels_out,
                        data_format="channels_first",
                    )
                    im_k = tf.transpose(im_k, [0, 2, 3, 1])
                    im_k = tf_util.channels_to_complex(im_k)

                im_k = tf.identity(im_k, name="image")
                if num_summary_image > 0:
                    with tf.name_scope("summary"):
                        # Horizontal strip (axis 2) of per-step images.
                        tmp = tf_util.sumofsq(im_k, keep_dims=True)
                        if summary_iter is None:
                            summary_iter = tmp
                        else:
                            summary_iter = tf.concat((summary_iter, tmp), axis=2)

        ks_k = tf_util.model_forward(im_k, sensemap)
        if do_hardproj:
            if verbose:
                print("%s> Final hard data projection..." % scope)
            # Final data projection: keep measured samples, fill the rest
            ks_k = mask * ks_0 + (1 - mask) * ks_k
            if mask_output is not None:
                ks_k = ks_k * mask_output
            im_k = tf_util.model_transpose(ks_k * window, sensemap)

        ks_k = tf.identity(ks_k, name="output_kspace")
        im_k = tf.identity(im_k, name="output_image")

    return im_k
def measure(self, X_gen, sensemap, real_mask):
    """Simulate an undersampled k-space measurement of generated images.

    For "DCE" data every frame gets its own randomly chosen and flipped mask;
    for "DCE_2D" a single mask is applied; for "knee" each batch element gets
    its own mask. A fully sampled 20x20 calibration square is always forced
    into the mask, and k-space is normalized by the energy of its central 5x5
    patch before masking.

    Args:
        X_gen: generated images in channel form (converted with
            tf_util.channels_to_complex).
        sensemap: sensitivity maps used by tf_util.model_forward/transpose.
        real_mask: unused; kept for interface compatibility.

    Returns:
        Zero-filled image of the masked k-space, in channel form.

    Side effects:
        For "DCE" / "DCE_2D", sets ``self.applied_mask`` (for "DCE", the mask
        of the last frame).

    Raises:
        ValueError: if ``self.data_type`` is not "DCE", "DCE_2D" or "knee".
    """
    name = "measure"
    random_seed = 0

    def _calibrate_scale_and_mask(ks_x, mask_x, drop_empty):
        # Force a fully sampled calibration square into mask_x, optionally
        # drop mask entries where ks_x is empty, normalize ks_x by its central
        # energy and apply the mask. Returns (masked ks_x, final mask_x).
        shape_cal = 20
        if shape_cal > 0:
            with tf.name_scope("CalibRegion"):
                if self.verbose:
                    print("%s> Including calib region (%d, %d)..."
                          % (name, shape_cal, shape_cal))
                mask_calib = tf.ones([shape_cal, shape_cal, 1],
                                     dtype=tf.complex64)
                mask_calib = tf.image.resize_image_with_crop_or_pad(
                    mask_calib, self.height, self.width)
                mask_x = mask_x * (1 - mask_calib) + mask_calib

        if drop_empty:
            mask_recon = tf.abs(ks_x) / tf.reduce_max(tf.abs(ks_x))
            mask_recon = tf.cast(mask_recon > 1e-7, dtype=tf.complex64)
            mask_x = mask_x * mask_recon

        # Assuming calibration region is fully sampled
        shape_sc = 5
        scale = tf.image.resize_image_with_crop_or_pad(
            ks_x, shape_sc, shape_sc)
        scale = tf.reduce_mean(tf.square(
            tf.abs(scale))) * (shape_sc * shape_sc / 1e5)
        scale = tf.cast(1.0 / tf.sqrt(scale), dtype=tf.complex64)
        ks_x = ks_x * scale
        return tf.multiply(ks_x, mask_x), mask_x

    def _random_dce_mask():
        # Pick one mask from self.masks at random and randomly flip it.
        # NOTE(review): both flips share the same op seed — confirm intended.
        mask_x = tf.random_shuffle(self.masks)
        mask_x = mask_x[0, :, :]
        mask_x = tf.expand_dims(mask_x, axis=0)
        # Augment sampling masks
        mask_x = tf.image.random_flip_up_down(mask_x, seed=random_seed)
        mask_x = tf.image.random_flip_left_right(mask_x, seed=random_seed)
        # Transpose to store data as (kz, ky, channels)
        return tf.transpose(mask_x, [1, 2, 0])

    image = tf_util.channels_to_complex(X_gen)
    kspace = tf_util.model_forward(image, sensemap)
    total_kspace = None

    # Compare strings with "=="; "is" checks object identity, which is not
    # guaranteed for equal strings (e.g. values parsed from flags).
    if self.data_type == "DCE":
        # remove batch dimension
        kspace = tf.squeeze(kspace, axis=0)
        # different mask for each frame
        for f in range(self.max_frames):
            ks_x = kspace[:, :, f, :]
            mask_x = _random_dce_mask()
            ks_x = tf.image.flip_up_down(ks_x)
            # Initially set image size to be all the same
            ks_x = tf.image.resize_image_with_crop_or_pad(
                ks_x, self.height, self.width)
            mask_x = tf.image.resize_image_with_crop_or_pad(
                mask_x, self.height, self.width)

            ks_x, mask_x = _calibrate_scale_and_mask(
                ks_x, mask_x, drop_empty=False)
            # Re-insert the frame axis before concatenating frames.
            ks_x = tf.expand_dims(ks_x, axis=-2)
            self.applied_mask = tf.expand_dims(mask_x, axis=0)
            if total_kspace is not None:
                total_kspace = tf.concat([total_kspace, ks_x], axis=-2)
            else:
                total_kspace = ks_x
        # Restore the batch dimension.
        total_kspace = tf.expand_dims(total_kspace, axis=0)

    elif self.data_type == "DCE_2D":
        # remove batch dimension
        kspace = tf.squeeze(kspace, axis=0)
        ks_x = kspace
        mask_x = _random_dce_mask()
        ks_x = tf.image.flip_up_down(ks_x)
        # Initially set image size to be all the same
        ks_x = tf.image.resize_image_with_crop_or_pad(
            ks_x, self.height, self.width)
        mask_x = tf.image.resize_image_with_crop_or_pad(
            mask_x, self.height, self.width)

        ks_x, mask_x = _calibrate_scale_and_mask(
            ks_x, mask_x, drop_empty=False)
        self.applied_mask = tf.expand_dims(mask_x, axis=0)
        # Restore the batch dimension.
        total_kspace = tf.expand_dims(ks_x, axis=0)

    elif self.data_type == "knee":
        for i in range(self.batch_size):
            ks_x = kspace[i, :, :]
            # Randomly select mask
            mask_x = tf.random_shuffle(self.masks)
            mask_x = tf.slice(mask_x, [0, 0, 0], [1, -1, -1])
            # Augment sampling masks
            mask_x = tf.image.random_flip_up_down(mask_x, seed=random_seed)
            mask_x = tf.image.random_flip_left_right(mask_x, seed=random_seed)
            # Transpose to store data as (kz, ky, channels)
            mask_x = tf.transpose(mask_x, [1, 2, 0])
            ks_x = tf.image.flip_up_down(ks_x)
            # Initially set image size to be all the same
            ks_x = tf.image.resize_image_with_crop_or_pad(
                ks_x, self.height, self.width)
            mask_x = tf.image.resize_image_with_crop_or_pad(
                mask_x, self.height, self.width)

            ks_x, mask_x = _calibrate_scale_and_mask(
                ks_x, mask_x, drop_empty=True)
            if total_kspace is not None:
                total_kspace = tf.concat([total_kspace, ks_x], 0)
            else:
                total_kspace = ks_x

    else:
        raise ValueError(
            "%s> unsupported data_type: %r" % (name, self.data_type))

    x_measured = tf_util.model_transpose(total_kspace, sensemap,
                                         name="x_measured")
    x_measured = tf_util.complex_to_channels(x_measured)
    return x_measured