def get_everything(self, sess, features):
    """Run the graph and fetch input/truth images plus sensemap and mask.

    Args:
        sess: tf.Session used to evaluate the tensors.
        features: dict with "ks_input", "ks_truth", "sensemap", "mask_recon".

    Returns:
        Tuple of numpy arrays (im_lowres, im_truth, sensemap, mask); images
        are real-valued 2-channel (real/imag) tensors.
    """
    ks_input = features["ks_input"]
    sensemap = features["sensemap"]
    # Sampling mask derived from the non-zero k-space locations.
    # Fixed: removed the dead `if mask is None:` guard that re-issued the
    # identical kspace_mask() call -- it could never change the result.
    mask = tf_util.kspace_mask(ks_input, dtype=tf.complex64)
    ks_input = mask * ks_input
    window = 1
    # Zero-filled ("low res") reconstruction: A^T (W b).
    im_lowres = tf_util.model_transpose(ks_input * window, sensemap)
    # im_lowres = tf_util.ifft2c(ks_input)
    im_lowres = tf.identity(im_lowres, name="low_res_image")
    # complex to channels
    im_lowres = tf_util.complex_to_channels(im_lowres)
    ks_truth = features["ks_truth"]
    mask_recon = features["mask_recon"]
    # Fully-sampled reference image (restricted to the recon mask).
    im_truth = tf_util.model_transpose(ks_truth * mask_recon, sensemap)
    im_truth = tf.identity(im_truth, name="truth_image")
    im_truth = tf_util.complex_to_channels(im_truth)
    im_lowres, im_truth, sensemap, mask = sess.run(
        [im_lowres, im_truth, sensemap, mask])
    return im_lowres, im_truth, sensemap, mask
def get_undersampled(self, features):
    """Return the zero-filled reconstruction of the input k-space.

    Args:
        features: dict with "ks_input" (undersampled k-space) and "sensemap".

    Returns:
        Real-valued 2-channel (real/imag) image tensor.
    """
    undersampled_ks = features["ks_input"]
    coil_maps = features["sensemap"]
    zero_filled = tf_util.model_transpose(undersampled_ks, coil_maps)
    zero_filled = tf.identity(zero_filled, name="low_res_image")
    return tf_util.complex_to_channels(zero_filled)
def _batch_norm_relu(tf_input, data_format="channels_last", training=False, activation="relu"):
    """Batch normalization followed by a (possibly complex-valued) activation.

    Args:
        tf_input: real-valued tensor with 2 channels (real/imag interleaved).
        data_format: "channels_last" or "channels_first", forwarded to the
            batch-norm layer (and modrelu).
        training: whether batch norm runs in training mode.
        activation: one of "relu", "crelu", "zrelu", "modrelu", "cardioid".

    Returns:
        Activated tensor in the same 2-channel real representation.

    Bug fixed: the original condition `if (activation == "relu" or "crelu")`
    is always truthy ("crelu" is a non-empty string), so the complex-valued
    activations below it were unreachable.
    """
    tf_output = _batch_norm(tf_input, data_format=data_format, training=training)
    if activation in ("relu", "crelu"):
        tf_output = tf.nn.relu(tf_output)
    else:
        # Convert two channels to complex-valued in preparation for
        # complex-valued activation functions.
        tf_output = tf_util.channels_to_complex(tf_output)
        if activation == "zrelu":
            tf_output = complex_utils.zrelu(tf_output)
        elif activation == "modrelu":
            tf_output = complex_utils.modrelu(tf_output, data_format)
        elif activation == "cardioid":
            tf_output = complex_utils.cardioid(tf_output)
        # Convert complex back to two channels.
        tf_output = tf_util.complex_to_channels(tf_output)
    return tf_output
def get_truth_image(self, features):
    """Build the fully-sampled ground-truth image as a 2-channel tensor.

    Args:
        features: dict with "ks_truth", "sensemap", and "mask_recon".

    Returns:
        Real-valued 2-channel (real/imag) truth image tensor.
    """
    truth_ks = features["ks_truth"]
    coil_maps = features["sensemap"]
    recon_mask = features["mask_recon"]
    # A^T applied to the masked fully-sampled k-space.
    truth = tf_util.model_transpose(truth_ks * recon_mask, coil_maps)
    truth = tf.identity(truth, name="truth_image")
    # complex to channels
    return tf_util.complex_to_channels(truth)
def get_images(self, features):
    """Build (zero-filled input image, truth image) tensors from features.

    Args:
        features: dict with "ks_input", "ks_truth", "sensemap", "mask_recon".

    Returns:
        Tuple (im_lowres, im_truth) of real-valued 2-channel tensors.
    """
    ks_input = features["ks_input"]
    sensemap = features["sensemap"]
    # Sampling mask from the non-zero k-space locations.
    # Fixed: removed the dead `if mask is None:` guard that simply repeated
    # the identical kspace_mask() call and could never change the result.
    mask = tf_util.kspace_mask(ks_input, dtype=tf.complex64)
    ks_input = mask * ks_input
    im_lowres = tf_util.model_transpose(ks_input, sensemap)
    im_lowres = tf.identity(im_lowres, name="low_res_image")
    # complex to channels
    im_lowres = tf_util.complex_to_channels(im_lowres)
    ks_truth = features["ks_truth"]
    mask_recon = features["mask_recon"]
    im_truth = tf_util.model_transpose(ks_truth * mask_recon, sensemap)
    im_truth = tf.identity(im_truth, name="truth_image")
    im_truth = tf_util.complex_to_channels(im_truth)
    return im_lowres, im_truth
def create_summary(self):
    """Build TensorBoard summaries for losses, images, and diagnostics.

    Fixed: `self.data_type is "knee"` compared a string by identity, which
    is implementation-dependent in Python; replaced with `==`.
    """
    # note that ks is based on the input data not on the rotated data
    output_ks = tf_util.model_forward(self.output_image, self.sensemap)
    # Input image to generator
    self.input_image = tf_util.model_transpose(self.ks, self.sensemap)
    if self.data_type == "knee":
        truth_image = tf_util.channels_to_complex(self.z_truth)
        # Rotate/flip only for display orientation.
        sum_input = tf.image.flip_up_down(
            tf.image.rot90(tf.abs(self.input_image)))
        sum_output = tf.image.flip_up_down(
            tf.image.rot90(tf.abs(self.output_image)))
        sum_truth = tf.image.flip_up_down(
            tf.image.rot90(tf.abs(truth_image)))
        train_out = tf.concat((sum_input, sum_output, sum_truth), axis=2)
        tf.summary.image("input-output-truth", train_out)
    # NOTE(review): mask_input is currently unused; kept for parity with the
    # original graph.
    mask_input = tf_util.kspace_mask(self.ks, dtype=tf.complex64)
    # Pixel-wise losses logged for monitoring only (not optimized directly).
    loss_l1 = tf.reduce_mean(tf.abs(self.X_gen - self.z_truth))
    loss_l2 = tf.reduce_mean(
        tf.square(tf.abs(self.X_gen - self.z_truth)))
    tf.summary.scalar("l1", loss_l1)
    tf.summary.scalar("l2", loss_l2)
    # to check supervised/unsupervised
    y_real = tf_util.channels_to_complex(self.Y_real)
    y_real = tf.image.flip_up_down(tf.image.rot90(tf.abs(y_real)))
    tf.summary.image("y_real/mag", y_real)
    # Plot losses
    self.d_loss_sum = tf.summary.scalar("Discriminator_loss", self.d_loss)
    self.g_loss_sum = tf.summary.scalar("Generator_loss", self.g_loss)
    self.gp_sum = tf.summary.scalar("Gradient_penalty", self.gradient_penalty)
    self.d_fake = tf.summary.scalar("subloss/D_fake",
                                    tf.reduce_mean(self.d_logits_fake))
    self.d_real = tf.summary.scalar("subloss/D_real",
                                    tf.reduce_mean(self.d_logits_real))
    self.z_sum = tf.summary.histogram(
        "z", tf_util.complex_to_channels(self.input_image))
    self.d_sum = tf.summary.merge(
        [self.z_sum, self.d_loss_sum, self.d_fake, self.d_real])
    self.g_sum = tf.summary.merge([self.z_sum, self.g_loss_sum])
    self.train_sum = tf.summary.merge_all()
def _conv2d(
    tf_input,
    num_features=128,
    kernel_size=3,
    data_format="channels_last",
    circular=True,
    conjugate=False,
):
    """Conv2d with option for circular convolution.

    Reads the module-level globals `type_conv` ("real"/"complex") set by
    unroll_fista to choose between a standard real-valued convolution and a
    complex-valued one.

    Bugs fixed:
      * conjugate branch used the unassigned name `real_out` (NameError);
        the real part was assigned to `tf_real` and never used.
      * `int(num_features) // np.sqrt(2)` yields a float, which is not a
        valid filter count for tf.layers.conv2d; now truncated to int.
    """
    if data_format == "channels_last":
        # (batch, z, y, channels)
        axis_z = 1
        axis_y = 2
        axis_c = 3
    else:
        # (batch, channels, z, y)
        axis_c = 1
        axis_z = 2
        axis_y = 3
    # Equivalent to (kernel_size - 1) // 2 for odd kernel sizes.
    pad = int((kernel_size - 0.5) / 2)
    tf_output = tf_input
    if circular:
        with tf.name_scope("circular_pad"):
            tf_output = _circular_pad(tf_output, pad, axis_z)
            tf_output = _circular_pad(tf_output, pad, axis_y)
    if type_conv == "real":
        print("real convolution")
        # Scale feature count down by sqrt(2) to roughly match the parameter
        # count of the complex branch; must be an int for conv2d.
        num_features = int(int(num_features) // np.sqrt(2))
        tf_output = tf.layers.conv2d(
            tf_output,
            num_features,
            kernel_size,
            padding="same",
            use_bias=False,
            data_format=data_format,
        )
    if type_conv == "complex":
        print("complex convolution")
        # channels to complex
        tf_output = tf_util.channels_to_complex(tf_output)
        if num_features != 2:
            num_features = num_features // 2
        tf_output = complex_utils.complex_conv(
            tf_output, num_features=num_features, kernel_size=kernel_size)
        if conjugate == True and num_features != 2:
            print("conjugation")
            # Conjugate the output: duplicate the real part and append the
            # negated imaginary part alongside the original.
            real_out = tf_util.getReal(tf_output, data_format)
            imag_out = tf_util.getImag(tf_output, data_format)
            imag_conj = -1 * imag_out
            real_out = tf.concat([real_out, real_out], axis=-1)
            imag_out = tf.concat([imag_out, imag_conj], axis=-1)
            tf_output = tf.concat([real_out, imag_out], axis=-1)
        # complex to channels
        tf_output = tf_util.complex_to_channels(tf_output)
    if circular:
        shape_input = tf.shape(tf_input)
        shape_z = shape_input[axis_z]
        shape_y = shape_input[axis_y]
        with tf.name_scope("circular_crop"):
            if data_format == "channels_last":
                tf_output = tf_output[
                    :, pad: (shape_z + pad), pad: (shape_y + pad), :
                ]
            else:
                tf_output = tf_output[
                    :, :, pad: (shape_z + pad), pad: (shape_y + pad)
                ]
        # add all needed attributes to tensor
    else:
        with tf.name_scope("non_circular"):
            # No-op slice; kept to preserve the original graph structure.
            tf_output = tf_output[:, :, :, :]
    return tf_output
def unroll_fista(
    ks_input,
    sensemap,
    num_grad_steps=5,
    resblock_num_features=128,
    resblock_num_blocks=2,
    is_training=True,
    scope="MRI",
    mask_output=1,
    window=None,
    do_hardproj=True,
    num_summary_image=0,
    mask=None,
    verbose=False,
    conv="real",
    do_conjugate=False,
    activation="relu",
):
    """Create general unrolled network for MRI.

    x_{k+1} = S( x_k - 2 * t * A^T W (A x- b) )
            = S( x_k - 2 * t * (A^T W A x - A^T W b))

    Each unrolled iteration is a gradient step on the data-consistency term
    followed by a learned proximal operator (prior_grad_res_net).

    Args:
        ks_input: input (undersampled) k-space, complex-valued.
        sensemap: coil sensitivity maps (or None).
        num_grad_steps: number of unrolled iterations.
        resblock_num_features / resblock_num_blocks: prior network size.
        mask: sampling mask; derived from ks_input's nonzeros when None.
        window: k-space weighting W; defaults to 1 (no weighting).
        mask_output: extra mask applied to the final k-space (if not None).
        do_hardproj: replace sampled k-space locations with the measured data
            at the end.
        conv / do_conjugate: stored into module globals consumed by _conv2d.

    Returns:
        Output image tensor (complex-valued), named "output_image".
    """
    if window is None:
        window = 1
    summary_iter = None
    # Module-level globals read by _conv2d to select the convolution type.
    global type_conv
    type_conv = conv
    global conjugate
    conjugate = do_conjugate
    if verbose:
        print(
            "%s> Building FISTA unrolled network (%d steps)...."
            % (scope, num_grad_steps)
        )
        if sensemap is not None:
            print("%s> Using sensitivity maps..." % scope)
    with tf.variable_scope(scope):
        if mask is None:
            mask = tf_util.kspace_mask(ks_input, dtype=tf.complex64)
        ks_input = mask * ks_input
        ks_0 = ks_input
        # x0 = A^T W b
        im_0 = tf_util.model_transpose(ks_0 * window, sensemap)
        im_0 = tf.identity(im_0, name="input_image")
        # To be updated
        ks_k = ks_0
        im_k = im_0
        for i_step in range(num_grad_steps):
            iter_name = "iter_%02d" % i_step
            with tf.variable_scope(iter_name):
                # = S( x_k - 2 * t * (A^T W A x_k - A^T W b))
                # = S( x_k - 2 * t * (A^T W A x_k - x0))
                with tf.variable_scope("update"):
                    im_k_orig = im_k
                    # xk = A^T A x_k
                    ks_k = tf_util.model_forward(im_k, sensemap)
                    ks_k = mask * ks_k
                    im_k = tf_util.model_transpose(ks_k * window, sensemap)
                    # xk = A^T A x_k - A^T b
                    im_k = tf_util.complex_to_channels(im_k - im_0)
                    im_k_orig = tf_util.complex_to_channels(im_k_orig)
                    # Update step: learnable step size t, initialized at -2.
                    t_update = tf.get_variable(
                        "t", dtype=tf.float32, initializer=tf.constant([-2.0])
                    )
                    im_k = im_k_orig + t_update * im_k
                with tf.variable_scope("prox"):
                    # Learned proximal operator (residual network).
                    num_channels_out = im_k.shape[-1]
                    im_k = prior_grad_res_net(
                        im_k,
                        training=is_training,
                        num_features=resblock_num_features,
                        num_blocks=resblock_num_blocks,
                        num_features_out=num_channels_out,
                        data_format="channels_last",
                        activation=activation
                    )
                    im_k = tf_util.channels_to_complex(im_k)
                im_k = tf.identity(im_k, name="image")
                if num_summary_image > 0:
                    with tf.name_scope("summary"):
                        # Accumulate per-iteration magnitude images side by side.
                        tmp = tf_util.sumofsq(im_k, keep_dims=True)
                        if summary_iter is None:
                            summary_iter = tmp
                        else:
                            summary_iter = tf.concat(
                                (summary_iter, tmp), axis=2)
                        tf.summary.scalar("max/" + iter_name,
                                          tf.reduce_max(tmp))
        ks_k = tf_util.model_forward(im_k, sensemap)
        if do_hardproj:
            if verbose:
                print("%s> Final hard data projection..." % scope)
            # Final data projection: keep measured samples, fill the rest.
            ks_k = mask * ks_0 + (1 - mask) * ks_k
            if mask_output is not None:
                ks_k = ks_k * mask_output
            im_k = tf_util.model_transpose(ks_k * window, sensemap)
        ks_k = tf.identity(ks_k, name="output_kspace")
        im_k = tf.identity(im_k, name="output_image")
    if summary_iter is not None:
        tf.summary.image("iter/image", summary_iter,
                         max_outputs=num_summary_image)
    return im_k
def unroll_fista(
    ks_input,
    sensemap,
    scope="MRI",
    num_grad_steps=5,
    num_resblocks=4,
    num_features=64,
    kernel_size=[3, 3, 5],
    is_training=True,
    mask_output=1,
    mask=None,
    window=None,
    do_hardproj=False,
    do_dense=False,
    do_separable=False,
    do_rnn=False,
    do_circular=True,
    batchnorm=False,
    leaky=False,
    fix_update=False,
    data_format="channels_first",
    verbose=False,
):
    """Create general unrolled network for MRI.

    x_{k+1} = S( x_k - 2 * t * A^T W (A x- b) )
            = S( x_k - 2 * t * (A^T W A x - A^T W b))

    3D (2D + time) variant. Supports weight sharing across iterations
    (do_rnn), dense connections between iteration priors (do_dense),
    circular padding, separable convolutions, and a fixed vs. learned
    gradient-step size (fix_update).

    Returns:
        Output image tensor (complex-valued), named "output_image".
    """
    # get list of GPU devices (currently unused -- device pinning below is
    # commented out).
    local_device_protos = device_lib.list_local_devices()
    device_list = [x.name for x in local_device_protos
                   if x.device_type == "GPU"]
    if window is None:
        window = 1
    # Per-iteration output images keyed by iteration name.
    summary_iter = {}
    if verbose:
        print("%s> Building FISTA unrolled network...." % scope)
        print("%s> Num of gradient steps: %d" % (scope, num_grad_steps))
        print(
            "%s> Prior: %d ResBlocks, %d features"
            % (scope, num_resblocks, num_features)
        )
        print("%s> Kernel size: [%d x %d x %d]"
              % ((scope,) + tuple(kernel_size)))
        if do_rnn:
            print("%s> Sharing weights across iterations..." % scope)
        if sensemap is not None:
            print("%s> Using sensitivity maps..." % scope)
        if do_dense:
            print("%s> Inserting dense connections..." % scope)
        if do_circular:
            print("%s> Using circular padding..." % scope)
        if do_separable:
            print("%s> Using depth-wise separable convolutions..." % scope)
        if not batchnorm:
            print("%s> Turning off batch normalization..." % scope)
    with tf.variable_scope(scope):
        if mask is None:
            mask = tf_util.kspace_mask(ks_input, dtype=tf.complex64)
        ks_input = mask * ks_input
        ks_0 = ks_input
        # x0 = A^T W b
        im_0 = tf_util.model_transpose(ks_0 * window, sensemap)
        im_0 = tf.identity(im_0, name="input_image")
        # To be updated
        ks_k = ks_0
        im_k = im_0
        im_dense = None
        for i_step in range(num_grad_steps):
            iter_name = "iter_%02d" % i_step
            # With do_rnn, a single shared scope + AUTO_REUSE ties the
            # weights of every iteration together.
            if do_rnn:
                scope_name = "iter"
            else:
                scope_name = iter_name
            # figure out which GPU to use for this step
            # i_device = int(len(device_list) * i_step / num_grad_steps)
            # cur_device = device_list[i_device]
            # with tf.device(cur_device):
            with tf.variable_scope(
                scope_name, reuse=(tf.AUTO_REUSE if do_rnn else False)
            ):
                with tf.variable_scope("update"):
                    # = S( x_k - 2 * t * (A^T W A x_k - A^T W b))
                    # = S( x_k - 2 * t * (A^T W A x_k - x0))
                    im_k_orig = im_k
                    # xk = A^T A x_k
                    ks_k = tf_util.model_forward(im_k, sensemap)
                    ks_k = mask * ks_k
                    im_k = tf_util.model_transpose(ks_k * window, sensemap)
                    # xk = A^T A x_k - A^T b
                    im_k = tf_util.complex_to_channels(im_k - im_0)
                    im_k_orig = tf_util.complex_to_channels(im_k_orig)
                    # Update step: fixed -2 or a learnable scalar step size.
                    if fix_update:
                        t_update = -2.0
                    else:
                        t_update = tf.get_variable(
                            "t", dtype=tf.float32,
                            initializer=tf.constant([-2.0])
                        )
                    im_k = im_k_orig + t_update * im_k
                with tf.variable_scope("prox"):
                    # default is channels_last
                    num_channels_out = im_k.shape[-1]
                    if data_format == "channels_first":
                        # (batch, z, y, t, ch) -> (batch, ch, z, y, t)
                        im_k = tf.transpose(im_k, [0, 4, 1, 2, 3])
                    # NOTE(review): dense features are concatenated on axis 1,
                    # which is the channel axis only for channels_first --
                    # do_dense presumably assumes channels_first; confirm.
                    if im_dense is not None:
                        im_k = tf.concat([im_k, im_dense], axis=1)
                    im_k, im_dense_k = prior_grad_res_net(
                        im_k,
                        training=is_training,
                        num_features=num_features,
                        num_blocks=num_resblocks,
                        num_features_out=num_channels_out,
                        kernel_size=kernel_size,
                        data_format=data_format,
                        circular=do_circular,
                        separable=do_separable,
                        batchnorm=batchnorm,
                        leaky=leaky,
                    )
                    if do_dense:
                        if im_dense is not None:
                            im_dense = tf.concat(
                                [im_dense, im_dense_k], axis=1)
                        else:
                            im_dense = im_dense_k
                    if data_format == "channels_first":
                        # Back to channels_last.
                        im_k = tf.transpose(im_k, [0, 2, 3, 4, 1])
                    im_k = tf_util.channels_to_complex(im_k)
                im_k = tf.identity(im_k, name="image")
                with tf.name_scope("summary"):
                    # tmp = tf_util.sumofsq(im_k, keep_dims=True)
                    summary_iter[iter_name] = im_k
        ks_k = tf_util.model_forward(im_k, sensemap)
        if do_hardproj:
            if verbose:
                print("%s> Final hard data projection..." % scope)
            # Keep measured samples, fill unmeasured ones from the network.
            ks_k = mask * ks_0 + (1 - mask) * ks_k
            if mask_output is not None:
                ks_k = ks_k * mask_output
            im_k = tf_util.model_transpose(ks_k * window, sensemap)
        ks_k = tf.identity(ks_k, name="output_kspace")
        im_k = tf.identity(im_k, name="output_image")
    # return im_k, ks_k, summary_iter
    return im_k
def measure(self, X_gen, sensemap, real_mask):
    """Simulate a new undersampled measurement of the generated image.

    Forward-models X_gen into k-space, applies a freshly drawn sampling mask
    per batch element (py_func-generated for DCE-like data, randomly
    shuffled/flipped stored masks for knee data), and returns the
    zero-filled reconstruction of the re-measured k-space.

    Args:
        X_gen: generated image, 2-channel real representation.
        sensemap: coil sensitivity maps.
        real_mask: unused here; kept for interface compatibility.

    Returns:
        2-channel real image tensor named from "x_measured".

    Fixed: string comparisons used `is` (identity, implementation-dependent)
    instead of `==`; removed the unused local `verbose` and a large block of
    commented-out dead code.
    """
    name = "measure"
    random_seed = 0
    image = tf_util.channels_to_complex(X_gen)
    kspace = tf_util.model_forward(image, sensemap)
    total_kspace = None
    if (self.data_type == "DCE"
            or self.data_type == "DCE_2D"
            or self.data_type == "mfast"):
        print("DCE measure")
        for i in range(self.batch_size):
            if self.dims == 4:
                ks_x = kspace[i, :, :]
            else:
                # 2D plus time
                ks_x = kspace[i, :, :, :]
            # data dimensions
            shape_y = self.width
            shape_t = self.max_frames
            sim_partial_ky = 0.0
            accs = [1, 6]
            # Random acceleration factor drawn uniformly from [accs[0], accs[1]).
            rand_accel = (accs[1] - accs[0]) * \
                tf.random_uniform([]) + accs[0]
            fn_inputs = [
                shape_y,
                shape_t,
                rand_accel,
                10,
                2.0,
                sim_partial_ky,
            ]  # ny, nt, accel, ncal, vd_degree
            # Variable-density perturbed kt mask generated on the fly.
            mask_x = tf.py_func(mask.generate_perturbed2dvdkt,
                                fn_inputs, tf.complex64)
            ks_x = ks_x * tf.reshape(mask_x, [1, shape_y, shape_t, 1])
            if total_kspace is not None:
                total_kspace = tf.concat([total_kspace, ks_x], 0)
            else:
                total_kspace = ks_x
    if self.data_type == "knee":
        print("knee")
        for i in range(self.batch_size):
            ks_x = kspace[i, :, :]
            print("ks_x", ks_x)
            # Randomly select mask
            mask_x = tf.random_shuffle(self.masks)
            print("self masks", mask_x)
            mask_x = tf.slice(mask_x, [0, 0, 0], [1, -1, -1])
            print("sliced masks", mask_x)
            # Augment sampling masks
            mask_x = tf.image.random_flip_up_down(mask_x, seed=random_seed)
            mask_x = tf.image.random_flip_left_right(mask_x, seed=random_seed)
            # Tranpose to store data as (kz, ky, channels)
            mask_x = tf.transpose(mask_x, [1, 2, 0])
            print("transposed mask", mask_x)
            ks_x = tf.image.flip_up_down(ks_x)
            # Initially set image size to be all the same
            ks_x = tf.image.resize_image_with_crop_or_pad(
                ks_x, self.height, self.width)
            mask_x = tf.image.resize_image_with_crop_or_pad(
                mask_x, self.height, self.width)
            print("resized mask", mask_x)
            shape_cal = 20
            if shape_cal > 0:
                with tf.name_scope("CalibRegion"):
                    if self.verbose:
                        print("%s> Including calib region (%d, %d)..."
                              % (name, shape_cal, shape_cal))
                    # Force a fully-sampled calibration region at the center.
                    mask_calib = tf.ones([shape_cal, shape_cal, 1],
                                         dtype=tf.complex64)
                    mask_calib = tf.image.resize_image_with_crop_or_pad(
                        mask_calib, self.height, self.width)
                    mask_x = mask_x * (1 - mask_calib) + mask_calib
                # Restrict the mask to locations with actual signal.
                mask_recon = tf.abs(ks_x) / tf.reduce_max(tf.abs(ks_x))
                mask_recon = tf.cast(mask_recon > 1e-7, dtype=tf.complex64)
                mask_x = mask_x * mask_recon
                print("mask x", mask_x)
            # Assuming calibration region is fully sampled
            shape_sc = 5
            # Normalize k-space by the energy of the central region.
            scale = tf.image.resize_image_with_crop_or_pad(
                ks_x, shape_sc, shape_sc)
            scale = tf.reduce_mean(tf.square(
                tf.abs(scale))) * (shape_sc * shape_sc / 1e5)
            scale = tf.cast(1.0 / tf.sqrt(scale), dtype=tf.complex64)
            ks_x = ks_x * scale
            # Masked input
            ks_x = tf.multiply(ks_x, mask_x)
            if total_kspace is not None:
                total_kspace = tf.concat([total_kspace, ks_x], 0)
            else:
                total_kspace = ks_x
    x_measured = tf_util.model_transpose(total_kspace, sensemap,
                                         name="x_measured")
    x_measured = tf_util.complex_to_channels(x_measured)
    return x_measured
def generator(self, ks_input, sensemap, reuse=False):
    """Build the generator network that reconstructs an image from k-space.

    Dispatches on self.dims (4 => 2D data, otherwise 2D+time) and self.arch
    ("unrolled" => unrolled FISTA network, else a plain residual CNN on the
    zero-filled image).

    Args:
        ks_input: undersampled k-space tensor.
        sensemap: coil sensitivity maps.
        reuse: reuse variables of an already-built generator scope.

    Returns:
        Reconstructed image as a 2-channel real tensor.
    """
    # Sampling mask passed into the unrolled network for data consistency.
    mask_example = tf_util.kspace_mask(ks_input, dtype=tf.complex64)
    with tf.variable_scope("generator") as scope:
        if reuse:
            scope.reuse_variables()
        # 2D data
        # batch, height, width, channels
        if self.dims == 4:
            if self.arch == "unrolled":
                c_out = unrolled.unroll_fista(
                    ks_input,
                    sensemap,
                    num_grad_steps=self.iterations,
                    resblock_num_features=self.g_dim,
                    resblock_num_blocks=self.res_blocks,
                    is_training=True,
                    scope="MRI",
                    mask_output=1,
                    window=None,
                    do_hardproj=True,
                    num_summary_image=6,
                    mask=mask_example,
                    verbose=False,
                )
                c_out = tf_util.complex_to_channels(c_out)
            else:
                # Plain residual CNN on the zero-filled reconstruction.
                z = tf_util.model_transpose(ks_input, sensemap)
                z = tf_util.complex_to_channels(z)
                res_size = self.g_dim
                kernel_size = 3
                num_channels = 2
                # could try tf.nn.tanh instead
                # act = tf.nn.sigmoid
                num_blocks = 5
                c = tf.layers.conv2d(
                    z,
                    res_size,
                    kernel_size,
                    padding="same",
                    activation=tf.nn.relu,
                    use_bias=True,
                )
                # Stack of residual conv blocks.
                for i in range(num_blocks):
                    c = tf.layers.conv2d(
                        c,
                        res_size,
                        kernel_size,
                        padding="same",
                        activation=tf.nn.relu,
                        use_bias=True,
                    )
                    c1 = tf.layers.conv2d(
                        c,
                        res_size,
                        kernel_size,
                        padding="same",
                        activation=tf.nn.relu,
                        use_bias=True,
                    )
                    c = tf.add(c, c1)
                # Project back to 2 channels and add global skip from input.
                c8 = tf.layers.conv2d(
                    c,
                    num_channels,
                    kernel_size,
                    padding="same",
                    activation=None,
                    use_bias=True,
                )
                c_out = tf.add(c8, z)
                c_out = tf.nn.tanh(c_out)
        # 3D data
        # batch, height, width, channels, time frames
        else:
            if self.arch == "unrolled":
                c_out = unrolled_3d.unroll_fista(
                    ks_input,
                    sensemap,
                    num_grad_steps=self.iterations,
                    num_features=self.g_dim,
                    num_resblocks=self.res_blocks,
                    is_training=True,
                    scope="MRI",
                    mask_output=1,
                    window=None,
                    do_hardproj=True,
                    mask=mask_example,
                    verbose=False,
                    data_format="channels_last",
                    do_separable=self.do_separable,
                )
                c_out = tf_util.complex_to_channels(c_out)
        return c_out
def build_model(self, mode):
    """Build the supervised WGAN-GP training graph.

    Wires the data pipeline, generator, discriminator, WGAN losses with
    gradient penalty, Adam optimizers, summaries, a step counter, the saver,
    and the summary writer, then initializes the model.

    Args:
        mode: "train" or "test"; forwarded to the data iterator.
    """
    self.Y_real, self.data_num, real_mask = self.read_real()
    # Read in train or test input images to be reconstructed by generator
    train_iterator = mri_utils.Iterator(
        self.batch_size,
        self.mask_path,
        self.data_type,
        mode,  # either train or test
        self.out_shape,
        verbose=self.verbose,
        train_acc=self.train_acc,
        data_dir=self.data_dir)
    self.input_files = train_iterator.num_files
    train_dataset = train_iterator.iterator.get_next()
    ks_truth = train_dataset["ks_truth"]
    ks_input = train_dataset["ks_input"]
    sensemap = train_dataset["sensemap"]
    # Ground-truth image (complex) and its 2-channel representation.
    image_truth = tf_util.model_transpose(ks_truth, sensemap)
    self.complex_truth = image_truth
    image_truth = tf_util.complex_to_channels(image_truth)
    self.z_truth = image_truth
    self.ks = ks_input
    self.sensemap = sensemap
    # generate image X_gen
    self.X_gen = self.generator(self.ks, self.sensemap)
    # no measurement for supervised
    self.Y_fake = self.X_gen
    # output of discriminator for fake and real images
    self.d_logits_fake = self.discriminator(self.Y_fake, reuse=False)
    self.d_logits_real = self.discriminator(self.Y_real, reuse=True)
    # discriminator loss (WGAN critic loss)
    self.d_loss = tf.reduce_mean(self.d_logits_fake) - tf.reduce_mean(
        self.d_logits_real)
    # add total variation loss
    # tv_loss = -tf.reduce_sum(tf.image.total_variation(self.Y_fake))
    # generator loss
    self.g_loss = -tf.reduce_mean(self.d_logits_fake)  # + tv_loss
    # Gradient Penalty: penalize the critic's gradient norm at random
    # interpolates between real and fake samples.
    self.epsilon = tf.random_uniform(shape=[self.batch_size, 1, 1, 1],
                                     minval=0.0,
                                     maxval=1.0)
    Y_hat = self.Y_real + self.epsilon * (self.Y_fake - self.Y_real)
    D_Y_hat = self.discriminator(Y_hat, reuse=True)
    grad_D_Y_hat = tf.gradients(D_Y_hat, [Y_hat])[0]
    # Reduce over all non-batch axes.
    red_idx = range(1, Y_hat.shape.ndims)
    slopes = tf.sqrt(
        tf.reduce_sum(tf.square(grad_D_Y_hat),
                      reduction_indices=list(red_idx)))
    self.gradient_penalty = tf.reduce_mean((slopes - 1.0)**2)
    # updated discriminator loss
    self.d_loss = self.d_loss + 10.0 * self.gradient_penalty
    train_vars = tf.trainable_variables()
    for v in train_vars:
        # v = tf.cast(v, tf.float32)
        tf.add_to_collection("reg_loss", tf.nn.l2_loss(v))
    # Split variables by sub-network for the two optimizers.
    self.generator_vars = [v for v in train_vars if "generator" in v.name]
    self.discriminator_vars = [
        v for v in train_vars if "discriminator" in v.name
    ]
    self.g_optimizer = tf.train.AdamOptimizer(
        learning_rate=self.lr,
        name="g_opt",
        beta1=self.beta1,
        beta2=self.beta2).minimize(self.g_loss,
                                   var_list=self.generator_vars)
    self.d_optimizer = tf.train.AdamOptimizer(
        learning_rate=self.lr,
        name="d_opt",
        beta1=self.beta1,
        beta2=self.beta2).minimize(self.d_loss,
                                   var_list=self.discriminator_vars)
    self.output_image = tf_util.channels_to_complex(self.X_gen)
    self.im_out = self.output_image
    self.mag_output = tf.abs(self.output_image)
    self.create_summary()
    # Persistent global step counter stored in the checkpoint.
    with tf.variable_scope("counter"):
        self.counter = tf.get_variable(
            "counter",
            shape=[1],
            initializer=tf.constant_initializer([1]),
            dtype=tf.int32,
        )
        self.update_counter = tf.assign(self.counter,
                                        tf.add(self.counter, 1))
    self.saver = tf.train.Saver()
    self.summary_writer = tf.summary.FileWriter(self.log_dir,
                                                self.sess.graph)
    self.initialize_model()
def unroll_fista(
    ks_input,
    sensemap,
    num_grad_steps=5,
    resblock_num_features=128,
    resblock_num_blocks=2,
    is_training=True,
    scope="MRI",
    mask_output=1,
    window=None,
    do_hardproj=True,
    num_summary_image=0,
    mask=None,
    verbose=False,
):
    """Create general unrolled network for MRI.

    x_{k+1} = S( x_k - 2 * t * A^T W (A x- b) )
            = S( x_k - 2 * t * (A^T W A x - A^T W b))

    Each iteration is pinned to a single device. Robustness fix: the
    original unconditionally indexed device_list[-1], raising IndexError on
    hosts with no GPU; now falls back to CPU. Also removed the unused
    per-step `i_device` computation (the per-step device choice was already
    commented out in favor of the last device).
    """
    # get list of GPU devices
    local_device_protos = device_lib.list_local_devices()
    device_list = [x.name for x in local_device_protos
                   if x.device_type == "GPU"]
    if window is None:
        window = 1
    summary_iter = None
    if verbose:
        print(
            "%s> Building FISTA unrolled network (%d steps)...."
            % (scope, num_grad_steps)
        )
        if sensemap is not None:
            print("%s> Using sensitivity maps..." % scope)
    # All iterations run on the last GPU; fall back to CPU if none found.
    cur_device = device_list[-1] if device_list else "/cpu:0"
    with tf.variable_scope(scope):
        if mask is None:
            mask = tf_util.kspace_mask(ks_input, dtype=tf.complex64)
        ks_input = mask * ks_input
        ks_0 = ks_input
        # x0 = A^T W b
        im_0 = tf_util.model_transpose(ks_0 * window, sensemap)
        im_0 = tf.identity(im_0, name="input_image")
        # To be updated
        ks_k = ks_0
        im_k = im_0
        for i_step in range(num_grad_steps):
            iter_name = "iter_%02d" % i_step
            with tf.device(cur_device), tf.variable_scope(iter_name):
                # = S( x_k - 2 * t * (A^T W A x_k - A^T W b))
                # = S( x_k - 2 * t * (A^T W A x_k - x0))
                with tf.variable_scope("update"):
                    im_k_orig = im_k
                    # xk = A^T A x_k
                    ks_k = tf_util.model_forward(im_k, sensemap)
                    ks_k = mask * ks_k
                    im_k = tf_util.model_transpose(ks_k * window, sensemap)
                    # xk = A^T A x_k - A^T b
                    im_k = tf_util.complex_to_channels(im_k - im_0)
                    im_k_orig = tf_util.complex_to_channels(im_k_orig)
                    # Update step: learnable step size, initialized at -2.
                    t_update = tf.get_variable(
                        "t", dtype=tf.float32, initializer=tf.constant([-2.0])
                    )
                    im_k = im_k_orig + t_update * im_k
                with tf.variable_scope("prox"):
                    num_channels_out = im_k.shape[-1]
                    # Transpose channels_last to channels_first for the prior.
                    im_k = tf.transpose(im_k, [0, 3, 1, 2])
                    im_k = prior_grad_res_net(
                        im_k,
                        training=is_training,
                        num_features=resblock_num_features,
                        num_blocks=resblock_num_blocks,
                        num_features_out=num_channels_out,
                        data_format="channels_first",
                    )
                    im_k = tf.transpose(im_k, [0, 2, 3, 1])
                    im_k = tf_util.channels_to_complex(im_k)
                im_k = tf.identity(im_k, name="image")
                if num_summary_image > 0:
                    with tf.name_scope("summary"):
                        tmp = tf_util.sumofsq(im_k, keep_dims=True)
                        if summary_iter is None:
                            summary_iter = tmp
                        else:
                            summary_iter = tf.concat((summary_iter, tmp),
                                                     axis=2)
                        # tf.summary.scalar("max/" + iter_name, tf.reduce_max(tmp))
        ks_k = tf_util.model_forward(im_k, sensemap)
        if do_hardproj:
            if verbose:
                print("%s> Final hard data projection..." % scope)
            # Final data projection
            ks_k = mask * ks_0 + (1 - mask) * ks_k
            if mask_output is not None:
                ks_k = ks_k * mask_output
            im_k = tf_util.model_transpose(ks_k * window, sensemap)
        ks_k = tf.identity(ks_k, name="output_kspace")
        im_k = tf.identity(im_k, name="output_image")
    # if summary_iter is not None:
    #     tf.summary.image("iter/image", summary_iter, max_outputs=num_summary_image)
    return im_k
def measure(self, X_gen, sensemap, real_mask):
    """Re-measure the generated image with freshly drawn sampling masks.

    Forward-models X_gen into k-space and applies data-type-specific
    undersampling:
      * "DCE": one randomly chosen, randomly flipped stored mask per time
        frame (batch dimension squeezed, then restored).
      * "DCE_2D": a single randomly chosen mask for the whole frame.
      * "knee": one randomly chosen mask per batch element, with a forced
        fully-sampled calibration region and a signal-support mask.

    Args:
        X_gen: generated image, 2-channel real representation.
        sensemap: coil sensitivity maps.
        real_mask: unused here; kept for interface compatibility.

    Returns:
        2-channel real image tensor from "x_measured".

    Fixed: string comparisons used `is` (identity, implementation-dependent)
    instead of `==`; removed the unused local `verbose` and the large block
    of commented-out dead code.
    """
    name = "measure"
    random_seed = 0
    image = tf_util.channels_to_complex(X_gen)
    kspace = tf_util.model_forward(image, sensemap)
    total_kspace = None
    if self.data_type == "DCE":
        # remove batch dimension
        kspace = tf.squeeze(kspace, axis=0)
        # different mask for each frame
        for f in range(self.max_frames):
            ks_x = kspace[:, :, f, :]
            # Randomly select mask
            mask_x = tf.random_shuffle(self.masks)
            mask_x = mask_x[0, :, :]
            mask_x = tf.expand_dims(mask_x, axis=0)
            # Augment sampling masks
            mask_x = tf.image.random_flip_up_down(mask_x, seed=random_seed)
            mask_x = tf.image.random_flip_left_right(mask_x,
                                                     seed=random_seed)
            # Tranpose to store data as (kz, ky, channels)
            mask_x = tf.transpose(mask_x, [1, 2, 0])
            ks_x = tf.image.flip_up_down(ks_x)
            # Initially set image size to be all the same
            ks_x = tf.image.resize_image_with_crop_or_pad(
                ks_x, self.height, self.width)
            mask_x = tf.image.resize_image_with_crop_or_pad(
                mask_x, self.height, self.width)
            shape_cal = 20
            if shape_cal > 0:
                with tf.name_scope("CalibRegion"):
                    if self.verbose:
                        print("%s> Including calib region (%d, %d)..."
                              % (name, shape_cal, shape_cal))
                    # Force a fully-sampled calibration region at center.
                    mask_calib = tf.ones([shape_cal, shape_cal, 1],
                                         dtype=tf.complex64)
                    mask_calib = tf.image.resize_image_with_crop_or_pad(
                        mask_calib, self.height, self.width)
                    mask_x = mask_x * (1 - mask_calib) + mask_calib
            # Assuming calibration region is fully sampled:
            # normalize k-space by the energy of the central region.
            shape_sc = 5
            scale = tf.image.resize_image_with_crop_or_pad(
                ks_x, shape_sc, shape_sc)
            scale = tf.reduce_mean(tf.square(
                tf.abs(scale))) * (shape_sc * shape_sc / 1e5)
            scale = tf.cast(1.0 / tf.sqrt(scale), dtype=tf.complex64)
            ks_x = ks_x * scale
            ks_x = tf.multiply(ks_x, mask_x)
            # Collect frames along a new time axis.
            ks_x = tf.expand_dims(ks_x, axis=-2)
            self.applied_mask = tf.expand_dims(mask_x, axis=0)
            if total_kspace is not None:
                total_kspace = tf.concat([total_kspace, ks_x], axis=-2)
            else:
                total_kspace = ks_x
        # Restore the batch dimension.
        total_kspace = tf.expand_dims(total_kspace, axis=0)
    if self.data_type == "DCE_2D":
        # remove batch dimension
        kspace = tf.squeeze(kspace, axis=0)
        ks_x = kspace
        # Randomly select mask
        mask_x = tf.random_shuffle(self.masks)
        mask_x = mask_x[0, :, :]
        mask_x = tf.expand_dims(mask_x, axis=0)
        # Augment sampling masks
        mask_x = tf.image.random_flip_up_down(mask_x, seed=random_seed)
        mask_x = tf.image.random_flip_left_right(mask_x, seed=random_seed)
        # Tranpose to store data as (kz, ky, channels)
        mask_x = tf.transpose(mask_x, [1, 2, 0])
        ks_x = tf.image.flip_up_down(ks_x)
        # Initially set image size to be all the same
        ks_x = tf.image.resize_image_with_crop_or_pad(
            ks_x, self.height, self.width)
        mask_x = tf.image.resize_image_with_crop_or_pad(
            mask_x, self.height, self.width)
        shape_cal = 20
        if shape_cal > 0:
            with tf.name_scope("CalibRegion"):
                if self.verbose:
                    print("%s> Including calib region (%d, %d)..."
                          % (name, shape_cal, shape_cal))
                mask_calib = tf.ones([shape_cal, shape_cal, 1],
                                     dtype=tf.complex64)
                mask_calib = tf.image.resize_image_with_crop_or_pad(
                    mask_calib, self.height, self.width)
                mask_x = mask_x * (1 - mask_calib) + mask_calib
        # Assuming calibration region is fully sampled
        shape_sc = 5
        scale = tf.image.resize_image_with_crop_or_pad(
            ks_x, shape_sc, shape_sc)
        scale = tf.reduce_mean(tf.square(
            tf.abs(scale))) * (shape_sc * shape_sc / 1e5)
        scale = tf.cast(1.0 / tf.sqrt(scale), dtype=tf.complex64)
        ks_x = ks_x * scale
        ks_x = tf.multiply(ks_x, mask_x)
        self.applied_mask = tf.expand_dims(mask_x, axis=0)
        total_kspace = ks_x
        # Restore the batch dimension.
        total_kspace = tf.expand_dims(total_kspace, axis=0)
    if self.data_type == "knee":
        for i in range(self.batch_size):
            ks_x = kspace[i, :, :]
            # Randomly select mask
            mask_x = tf.random_shuffle(self.masks)
            mask_x = tf.slice(mask_x, [0, 0, 0], [1, -1, -1])
            # Augment sampling masks
            mask_x = tf.image.random_flip_up_down(mask_x, seed=random_seed)
            mask_x = tf.image.random_flip_left_right(mask_x,
                                                     seed=random_seed)
            # Tranpose to store data as (kz, ky, channels)
            mask_x = tf.transpose(mask_x, [1, 2, 0])
            ks_x = tf.image.flip_up_down(ks_x)
            # Initially set image size to be all the same
            ks_x = tf.image.resize_image_with_crop_or_pad(
                ks_x, self.height, self.width)
            mask_x = tf.image.resize_image_with_crop_or_pad(
                mask_x, self.height, self.width)
            shape_cal = 20
            if shape_cal > 0:
                with tf.name_scope("CalibRegion"):
                    if self.verbose:
                        print("%s> Including calib region (%d, %d)..."
                              % (name, shape_cal, shape_cal))
                    mask_calib = tf.ones([shape_cal, shape_cal, 1],
                                         dtype=tf.complex64)
                    mask_calib = tf.image.resize_image_with_crop_or_pad(
                        mask_calib, self.height, self.width)
                    mask_x = mask_x * (1 - mask_calib) + mask_calib
                # Restrict the mask to locations with actual signal.
                mask_recon = tf.abs(ks_x) / tf.reduce_max(tf.abs(ks_x))
                mask_recon = tf.cast(mask_recon > 1e-7, dtype=tf.complex64)
                mask_x = mask_x * mask_recon
            # Assuming calibration region is fully sampled
            shape_sc = 5
            scale = tf.image.resize_image_with_crop_or_pad(
                ks_x, shape_sc, shape_sc)
            scale = tf.reduce_mean(tf.square(
                tf.abs(scale))) * (shape_sc * shape_sc / 1e5)
            scale = tf.cast(1.0 / tf.sqrt(scale), dtype=tf.complex64)
            ks_x = ks_x * scale
            # Masked input
            ks_x = tf.multiply(ks_x, mask_x)
            if total_kspace is not None:
                total_kspace = tf.concat([total_kspace, ks_x], 0)
            else:
                total_kspace = ks_x
    x_measured = tf_util.model_transpose(total_kspace, sensemap,
                                         name="x_measured")
    x_measured = tf_util.complex_to_channels(x_measured)
    return x_measured
def build_model(self, mode):
    """Construct the full WGAN-GP training/testing graph.

    Wires together: the "real" measured data pipeline, the input iterator,
    the generator/discriminator, the Wasserstein losses with gradient
    penalty, per-device Adam optimizers, summaries, a step counter, the
    saver, and the summary writer, then initializes the model.

    Args:
        mode: dataset split selector, either "train" or "test" (passed
            through to mri_utils.Iterator; "test" also narrows the
            tfrecords search pattern for DCE_2D data).

    Side effects: sets many attributes on self (Y_real, X_gen, d_loss,
    g_loss, optimizers, saver, summary_writer, ...) and calls
    self.initialize_model().
    """
    # Read in real undersampled image Yr
    self.Y_real, self.data_num, real_mask = self.read_real()
    if self.data_num == 0:
        print("Error: no training files found")
        exit()
    # NOTE(review): real_image is computed but not used anywhere in this
    # method — possibly a leftover from a removed summary. Confirm before
    # deleting.
    real_image = tf.abs(tf_util.channels_to_complex(self.Y_real))
    # Choose the tfrecords glob pattern for the input iterator.
    # NOTE(review): indentation reconstructed — the trailing `else` is
    # assumed to attach to `if mode == "test"` (i.e. train mode uses all
    # files). Confirm against the original file; if it instead attaches to
    # the knee check, the DCE_2D pattern below would be clobbered.
    if mode == "test":
        if self.data_type == "DCE_2D":
            # self.search_str = "/14Jul16_Ex19493_Ser5*.tfrecords"
            self.search_str = "/17Dec16_Ex21068_Ser13*.tfrecords"
            # self.search_str = "/*.tfrecords"
        if self.data_type == "knee":
            self.search_str = "/*.tfrecords"
    else:
        self.search_str = "/*.tfrecords"
    # Read in train or test input images to be reconstructed by generator
    train_iterator = mri_utils.Iterator(
        self.batch_size,
        self.mask_path,
        self.data_type,
        mode,  # either train or test
        self.out_shape,
        verbose=self.verbose,
        train_acc=self.train_acc,
        search_str=self.search_str,
        data_dir=self.data_dir)
    self.input_files = train_iterator.num_files
    if self.input_files == 0:
        print("Error: no input files found")
        exit()
    # One batch of the input pipeline: undersampled k-space, fully-sampled
    # "truth" k-space, and coil sensitivity maps.
    train_dataset = train_iterator.iterator.get_next()
    ks_truth = train_dataset["ks_truth"]
    ks_input = train_dataset["ks_input"]
    sensemap = train_dataset["sensemap"]
    # Adjoint of the imaging model: k-space -> image domain.
    self.im_in = tf_util.model_transpose(ks_input, sensemap)
    self.complex_truth = tf_util.model_transpose(ks_truth, sensemap)
    self.z_truth = tf_util.complex_to_channels(self.complex_truth)
    self.ks = ks_input
    self.sensemap = sensemap
    # generate image X_gen
    self.X_gen = self.generator(self.ks, self.sensemap)
    # measure X_gen (simulate the acquisition so fake data is comparable
    # to the real measured data)
    self.Y_fake = self.measure(self.X_gen, self.sensemap, real_mask)
    # output of discriminator for fake and real images
    self.d_logits_fake = self.discriminator(self.Y_fake, reuse=False)
    self.d_logits_real = self.discriminator(self.Y_real, reuse=True)
    # discriminator loss (Wasserstein critic loss)
    self.d_loss = tf.reduce_mean(self.d_logits_fake) - tf.reduce_mean(
        self.d_logits_real)
    # generator loss
    self.g_loss = -tf.reduce_mean(self.d_logits_fake)
    # Gradient Penalty (WGAN-GP): penalize deviation of the critic's
    # gradient norm from 1 on random interpolates between real and fake.
    self.epsilon = tf.random_uniform(shape=[self.batch_size, 1, 1, 1],
                                     minval=0.0, maxval=1.0)
    Y_hat = self.Y_real + self.epsilon * (self.Y_fake - self.Y_real)
    D_Y_hat = self.discriminator(Y_hat, reuse=True)
    grad_D_Y_hat = tf.gradients(D_Y_hat, [Y_hat])[0]
    # Reduce over all non-batch axes to get one gradient norm per sample.
    red_idx = range(1, Y_hat.shape.ndims)
    slopes = tf.sqrt(
        tf.reduce_sum(tf.square(grad_D_Y_hat),
                      reduction_indices=list(red_idx)))
    self.gradient_penalty = tf.reduce_mean((slopes - 1.0)**2)
    # updated discriminator loss (10.0 is the standard WGAN-GP lambda)
    self.d_loss = self.d_loss + 10.0 * self.gradient_penalty
    train_vars = tf.trainable_variables()
    for v in train_vars:
        # v = tf.cast(v, tf.float32)
        tf.add_to_collection("reg_loss", tf.nn.l2_loss(v))
    # Split trainable variables by name scope so each optimizer only
    # updates its own network.
    self.generator_vars = [v for v in train_vars if "generator" in v.name]
    self.discriminator_vars = [
        v for v in train_vars if "discriminator" in v.name
    ]
    # Pin generator and discriminator optimizers to (possibly different)
    # GPUs. NOTE(review): device_list[-1] / device_list[0] assumes at
    # least one GPU is present; this raises IndexError on a CPU-only host.
    local_device_protos = device_lib.list_local_devices()
    device_list = [
        x.name for x in local_device_protos if x.device_type == "GPU"
    ]
    cur_device = device_list[-1]
    # cur_device = device_list[0]
    with tf.device(cur_device):
        self.g_optimizer = tf.train.AdamOptimizer(
            learning_rate=self.lr,
            name="g_opt",
            beta1=self.beta1,
            beta2=self.beta2).minimize(self.g_loss,
                                       var_list=self.generator_vars)
    cur_device = device_list[0]
    with tf.device(cur_device):
        self.d_optimizer = tf.train.AdamOptimizer(
            learning_rate=self.lr,
            name="d_opt",
            beta1=self.beta1,
            beta2=self.beta2).minimize(self.d_loss,
                                       var_list=self.discriminator_vars)
    # Expose the generator output in complex and magnitude form.
    self.output_image = tf_util.channels_to_complex(self.X_gen)
    self.im_out = self.output_image
    self.mag_output = tf.abs(self.output_image)
    self.create_summary()
    # Persistent global step counter stored in the graph so it survives
    # checkpoint save/restore.
    with tf.variable_scope("counter"):
        self.counter = tf.get_variable(
            "counter",
            shape=[1],
            initializer=tf.constant_initializer([1]),
            dtype=tf.int32,
        )
        self.update_counter = tf.assign(self.counter,
                                        tf.add(self.counter, 1))
    self.saver = tf.train.Saver()
    self.summary_writer = tf.summary.FileWriter(self.log_dir,
                                                self.sess.graph)
    self.initialize_model()