def get_everything(self, sess, features):
    """Build the low-res input and truth images from ``features`` and evaluate them.

    Args:
        sess: TensorFlow session used to evaluate the tensors.
        features: Mapping indexed with the keys ``"ks_input"``,
            ``"sensemap"``, ``"ks_truth"`` and ``"mask_recon"``.

    Returns:
        Tuple ``(im_lowres, im_truth, sensemap, mask)`` of evaluated values.
        Both images have had their complex values split into channels by
        ``tf_util.complex_to_channels``.

    NOTE(review): every call adds new ops to the default graph, so calling
    this repeatedly in a loop grows the graph.
    """
    ks_input = features["ks_input"]
    sensemap = features["sensemap"]

    # Sampling mask derived from the input k-space. The previous
    # `if mask is None:` branch re-ran this identical call, which can never
    # yield a different result, so it was dead code.
    mask = tf_util.kspace_mask(ks_input, dtype=tf.complex64)
    ks_input = mask * ks_input

    # Uniform k-space weighting (no window).
    window = 1
    im_lowres = tf_util.model_transpose(ks_input * window, sensemap)
    im_lowres = tf.identity(im_lowres, name="low_res_image")
    # complex to channels
    im_lowres = tf_util.complex_to_channels(im_lowres)

    ks_truth = features["ks_truth"]
    mask_recon = features["mask_recon"]
    im_truth = tf_util.model_transpose(ks_truth * mask_recon, sensemap)
    im_truth = tf.identity(im_truth, name="truth_image")
    im_truth = tf_util.complex_to_channels(im_truth)

    im_lowres, im_truth, sensemap, mask = sess.run(
        [im_lowres, im_truth, sensemap, mask])
    return im_lowres, im_truth, sensemap, mask
def _create_summary(sense_place, ks_place, im_out_place, im_truth_place):
    """Register TensorBoard image summaries comparing input, output and truth.

    Args:
        sense_place: Sensitivity-map tensor (sliced as a rank-5 tensor in the
            "sensemap" summary).
        ks_place: Input k-space tensor.
        im_out_place: Reconstructed (output) image tensor.
        im_truth_place: Ground-truth image tensor.

    Every image summary is capped at ``FLAGS.num_summary_image`` outputs.
    """

    def side_by_side(*tensors):
        # Sum-of-squares of each tensor, laid out left to right along axis 2.
        return tf.concat(
            [tf_util.sumofsq(t, keep_dims=True) for t in tensors], axis=2)

    zero_filled = tf_util.model_transpose(ks_place, sense_place)
    sampling_mask = tf_util.kspace_mask(ks_place, dtype=tf.complex64)
    kspace_out = tf_util.model_forward(im_out_place, sense_place)
    kspace_true = tf_util.model_forward(im_truth_place, sense_place)

    n_images = FLAGS.num_summary_image

    with tf.name_scope("input-output-truth"):
        # Log magnitude; the small offset keeps zero-valued entries finite.
        kspace_panel = tf.log(
            side_by_side(ks_place, kspace_out, kspace_true) + 1e-6)
        tf.summary.image("kspace", kspace_panel, max_outputs=n_images)

        image_panel = side_by_side(zero_filled, im_out_place, im_truth_place)
        tf.summary.image("image", image_panel, max_outputs=n_images)

    with tf.name_scope("truth"):
        truth_real = tf.real(
            tf.reduce_sum(im_truth_place, axis=-1, keep_dims=True))
        tf.summary.image("image_real", truth_real, max_outputs=n_images)

    with tf.name_scope("mask"):
        mask_panel = tf_util.sumofsq(sampling_mask, keep_dims=True)
        tf.summary.image("mask", mask_panel, max_outputs=n_images)

    with tf.name_scope("sensemap"):
        # Keep only index 0 of axis 3; every other axis is kept whole.
        maps = tf.slice(
            tf.abs(sense_place), [0, 0, 0, 0, 0], [-1, -1, -1, 1, -1])
        maps = tf.transpose(maps, [0, 1, 4, 2, 3])
        dims = tf.shape(maps)
        maps = tf.reshape(maps, [dims[0], dims[1], -1])
        tf.summary.image(
            "image", tf.expand_dims(maps, axis=-1), max_outputs=n_images)
def get_images(self, features):
    """Build the low-resolution input image and truth image tensors.

    Args:
        features: Mapping indexed with the keys ``"ks_input"``,
            ``"sensemap"``, ``"ks_truth"`` and ``"mask_recon"``.

    Returns:
        Tuple ``(im_lowres, im_truth)`` of graph tensors with complex values
        split into channels by ``tf_util.complex_to_channels``. The complex
        images are also exposed as ops named ``"low_res_image"`` and
        ``"truth_image"``.
    """
    ks_input = features["ks_input"]
    sensemap = features["sensemap"]

    # Sampling mask derived from the input k-space. A previous
    # `if mask is None:` branch repeated this identical call and could never
    # change the outcome, so it was removed.
    mask = tf_util.kspace_mask(ks_input, dtype=tf.complex64)
    ks_input = mask * ks_input

    im_lowres = tf_util.model_transpose(ks_input, sensemap)
    im_lowres = tf.identity(im_lowres, name="low_res_image")
    # complex to channels
    im_lowres = tf_util.complex_to_channels(im_lowres)

    ks_truth = features["ks_truth"]
    mask_recon = features["mask_recon"]
    im_truth = tf_util.model_transpose(ks_truth * mask_recon, sensemap)
    im_truth = tf.identity(im_truth, name="truth_image")
    im_truth = tf_util.complex_to_channels(im_truth)
    return im_lowres, im_truth
def create_summary(self):
    """Create the TensorBoard summaries used during training.

    Sets ``self.input_image`` to ``tf_util.model_transpose(self.ks,
    self.sensemap)``. It then builds image, scalar and histogram summaries
    and stores these handles on the instance: ``d_loss_sum``,
    ``g_loss_sum``, ``gp_sum``, ``d_fake``, ``d_real`` and ``z_sum``. The
    merged summaries are stored on ``d_sum``, ``g_sum`` and ``train_sum``
    (``tf.summary.merge_all``).
    """
    # note that ks is based on the input data not on the rotated data
    # Input image to generator
    self.input_image = tf_util.model_transpose(self.ks, self.sensemap)

    # `==`, not `is`: an identity test against a string literal only
    # succeeds when both happen to be the same object, which is not the case
    # for strings built at runtime (e.g. parsed from command-line flags).
    if self.data_type == "knee":
        truth_image = tf_util.channels_to_complex(self.z_truth)
        # Magnitudes rotated by 90 degrees and flipped vertically for display.
        sum_input = tf.image.flip_up_down(
            tf.image.rot90(tf.abs(self.input_image)))
        sum_output = tf.image.flip_up_down(
            tf.image.rot90(tf.abs(self.output_image)))
        sum_truth = tf.image.flip_up_down(
            tf.image.rot90(tf.abs(truth_image)))
        train_out = tf.concat((sum_input, sum_output, sum_truth), axis=2)
        tf.summary.image("input-output-truth", train_out)

    loss_l1 = tf.reduce_mean(tf.abs(self.X_gen - self.z_truth))
    loss_l2 = tf.reduce_mean(tf.square(tf.abs(self.X_gen - self.z_truth)))
    tf.summary.scalar("l1", loss_l1)
    tf.summary.scalar("l2", loss_l2)

    # to check supervised/unsupervised
    y_real = tf_util.channels_to_complex(self.Y_real)
    y_real = tf.image.flip_up_down(tf.image.rot90(tf.abs(y_real)))
    tf.summary.image("y_real/mag", y_real)

    # Plot losses
    self.d_loss_sum = tf.summary.scalar("Discriminator_loss", self.d_loss)
    self.g_loss_sum = tf.summary.scalar("Generator_loss", self.g_loss)
    self.gp_sum = tf.summary.scalar("Gradient_penalty",
                                    self.gradient_penalty)
    self.d_fake = tf.summary.scalar("subloss/D_fake",
                                    tf.reduce_mean(self.d_logits_fake))
    self.d_real = tf.summary.scalar("subloss/D_real",
                                    tf.reduce_mean(self.d_logits_real))
    self.z_sum = tf.summary.histogram(
        "z", tf_util.complex_to_channels(self.input_image))

    self.d_sum = tf.summary.merge(
        [self.z_sum, self.d_loss_sum, self.d_fake, self.d_real])
    self.g_sum = tf.summary.merge([self.z_sum, self.g_loss_sum])
    self.train_sum = tf.summary.merge_all()
def read_real(self):
    """Read the "real" (undersampled) images from the "validate" split.

    Side effect: stores the iterator's ``masks`` on ``self.masks``.

    Returns:
        Tuple ``(img_real, data_num, real_mask)``:
        - ``img_real``: the result of ``get_truth_image`` on the next batch;
        - ``data_num``: the iterator's ``num_files``;
        - ``real_mask``: ``tf_util.kspace_mask`` of that batch's
          ``"ks_input"``.
    """
    iterator = mri_utils.Iterator(
        self.batch_size,
        self.mask_path,
        self.data_type,
        "validate",
        self.out_shape,
        verbose=self.verbose,
        data_dir=self.data_dir,
    )
    file_count = iterator.num_files
    next_batch = iterator.iterator.get_next()
    truth_image = iterator.get_truth_image(next_batch)
    self.masks = iterator.masks
    batch_mask = tf_util.kspace_mask(next_batch["ks_input"],
                                     dtype=tf.complex64)
    return truth_image, file_count, batch_mask
def unroll_fista(
    ks_input,
    sensemap,
    num_grad_steps=5,
    resblock_num_features=128,
    resblock_num_blocks=2,
    is_training=True,
    scope="MRI",
    mask_output=1,
    window=None,
    do_hardproj=True,
    num_summary_image=0,
    mask=None,
    verbose=False,
    conv="real",
    do_conjugate=False,
    activation="relu",
):
    """Create general unrolled network for MRI.

    x_{k+1} = S( x_k - 2 * t * A^T W (A x- b) )
            = S( x_k - 2 * t * (A^T W A x - A^T W b))

    Each of the ``num_grad_steps`` iterations runs two stages:
    - a gradient step on the data term, with a learnable step ``t``
      initialised to -2.0;
    - ``prior_grad_res_net`` applied as the proximal operator ``S``.

    Args:
        ks_input: Input k-space tensor; it is multiplied by ``mask``.
        sensemap: Sensitivity maps passed to ``tf_util.model_forward`` /
            ``tf_util.model_transpose``.
        num_grad_steps: Number of unrolled iterations.
        resblock_num_features: ``num_features`` of the prior network.
        resblock_num_blocks: ``num_blocks`` of the prior network.
        is_training: Forwarded as ``training`` to the prior network.
        scope: Name of the enclosing variable scope.
        mask_output: If not None, multiplies the final k-space estimate.
        window: k-space weighting applied before every adjoint; None means 1.
        do_hardproj: If True, sampled k-space locations of the final estimate
            are replaced with the measured data.
        num_summary_image: If > 0, a per-iteration sum-of-squares image
            summary and a "max/<iter>" scalar are recorded.
        mask: Sampling mask; derived from ``ks_input`` with
            ``tf_util.kspace_mask`` when None.
        verbose: Print a description of the network being built.
        conv: Stored in the module-level global ``type_conv``.
        do_conjugate: Stored in the module-level global ``conjugate``.
        activation: Forwarded to ``prior_grad_res_net``.

    Returns:
        The output image tensor (op name "output_image").

    NOTE: variable names (and therefore checkpoint compatibility) depend on
    the exact nesting of the variable scopes below.
    """
    if window is None:
        window = 1
    summary_iter = None
    # Published as module-level globals; presumably read by helpers elsewhere
    # in this module (not visible here) -- TODO confirm.
    global type_conv
    type_conv = conv
    global conjugate
    conjugate = do_conjugate
    if verbose:
        print(
            "%s> Building FISTA unrolled network (%d steps)...."
            % (scope, num_grad_steps)
        )
        if sensemap is not None:
            print("%s> Using sensitivity maps..." % scope)
    with tf.variable_scope(scope):
        if mask is None:
            mask = tf_util.kspace_mask(ks_input, dtype=tf.complex64)
        ks_input = mask * ks_input
        ks_0 = ks_input
        # x0 = A^T W b
        im_0 = tf_util.model_transpose(ks_0 * window, sensemap)
        im_0 = tf.identity(im_0, name="input_image")
        # To be updated
        ks_k = ks_0
        im_k = im_0
        for i_step in range(num_grad_steps):
            iter_name = "iter_%02d" % i_step
            with tf.variable_scope(iter_name):
                # = S( x_k - 2 * t * (A^T W A x_k - A^T W b))
                # = S( x_k - 2 * t * (A^T W A x_k - x0))
                with tf.variable_scope("update"):
                    im_k_orig = im_k
                    # xk = A^T A x_k
                    ks_k = tf_util.model_forward(im_k, sensemap)
                    ks_k = mask * ks_k
                    im_k = tf_util.model_transpose(ks_k * window, sensemap)
                    # xk = A^T A x_k - A^T b
                    im_k = tf_util.complex_to_channels(im_k - im_0)
                    im_k_orig = tf_util.complex_to_channels(im_k_orig)
                    # Update step: learnable step size, initialised to -2.0
                    t_update = tf.get_variable(
                        "t", dtype=tf.float32, initializer=tf.constant([-2.0])
                    )
                    im_k = im_k_orig + t_update * im_k
                with tf.variable_scope("prox"):
                    # Prior network outputs as many channels as it receives.
                    num_channels_out = im_k.shape[-1]
                    im_k = prior_grad_res_net(
                        im_k,
                        training=is_training,
                        num_features=resblock_num_features,
                        num_blocks=resblock_num_blocks,
                        num_features_out=num_channels_out,
                        data_format="channels_last",
                        activation=activation
                    )
                im_k = tf_util.channels_to_complex(im_k)
                im_k = tf.identity(im_k, name="image")
                if num_summary_image > 0:
                    with tf.name_scope("summary"):
                        # Iteration images are tiled left to right (axis 2).
                        tmp = tf_util.sumofsq(im_k, keep_dims=True)
                        if summary_iter is None:
                            summary_iter = tmp
                        else:
                            summary_iter = tf.concat(
                                (summary_iter, tmp), axis=2)
                        tf.summary.scalar("max/" + iter_name, tf.reduce_max(tmp))
        ks_k = tf_util.model_forward(im_k, sensemap)
        if do_hardproj:
            if verbose:
                print("%s> Final hard data projection..." % scope)
            # Final data projection: measured samples where mask is set,
            # network estimate elsewhere.
            ks_k = mask * ks_0 + (1 - mask) * ks_k
        if mask_output is not None:
            ks_k = ks_k * mask_output
        im_k = tf_util.model_transpose(ks_k * window, sensemap)
        ks_k = tf.identity(ks_k, name="output_kspace")
        im_k = tf.identity(im_k, name="output_image")
    if summary_iter is not None:
        tf.summary.image("iter/image", summary_iter, max_outputs=num_summary_image)
    return im_k
def unroll_fista(
    ks_input,
    sensemap,
    scope="MRI",
    num_grad_steps=5,
    num_resblocks=4,
    num_features=64,
    kernel_size=[3, 3, 5],
    is_training=True,
    mask_output=1,
    mask=None,
    window=None,
    do_hardproj=False,
    do_dense=False,
    do_separable=False,
    do_rnn=False,
    do_circular=True,
    batchnorm=False,
    leaky=False,
    fix_update=False,
    data_format="channels_first",
    verbose=False,
):
    """Create general unrolled network for MRI.

    x_{k+1} = S( x_k - 2 * t * A^T W (A x- b) )
            = S( x_k - 2 * t * (A^T W A x - A^T W b))

    Each of the ``num_grad_steps`` iterations runs two stages:
    - a gradient step on the data term, with step ``t`` either learnable
      (initialised to -2.0) or fixed at -2.0 when ``fix_update``;
    - ``prior_grad_res_net`` applied as the proximal operator ``S``.

    Args:
        ks_input: Input k-space tensor; it is multiplied by ``mask``.
        sensemap: Sensitivity maps for the forward/adjoint model.
        scope: Name of the enclosing variable scope.
        num_grad_steps: Number of unrolled iterations.
        num_resblocks, num_features, kernel_size: Prior-network size,
            forwarded to ``prior_grad_res_net``. ``kernel_size`` must have
            three entries (the verbose output formats it that way). It is not
            mutated here.
        is_training: Forwarded as ``training`` to the prior network.
        mask_output: If not None, multiplies the final k-space estimate.
        mask: Sampling mask; derived from ``ks_input`` when None.
        window: k-space weighting applied before every adjoint; None means 1.
        do_hardproj: If True, sampled k-space locations of the final estimate
            are replaced with the measured data.
        do_dense: Accumulate the prior's second output across iterations and
            concatenate it (along the channel axis) onto later prior inputs.
        do_separable, do_circular, batchnorm, leaky: Prior-network options
            (forwarded as ``separable``, ``circular``, ``batchnorm``,
            ``leaky``).
        do_rnn: Use one variable scope ("iter") with AUTO_REUSE for every
            iteration, i.e. share weights across iterations.
        fix_update: Use a constant step of -2.0 instead of a learned one.
        data_format: Layout fed to the prior network, "channels_first" or
            "channels_last".
        verbose: Print a description of the network being built.

    Returns:
        The output image tensor (op name "output_image").
    """
    # NOTE: an unused `device_lib.list_local_devices()` call used to sit here
    # (its result only fed commented-out per-step GPU placement). Listing
    # local devices initialises them, which in TF1 can reserve GPU memory,
    # so the call was dropped.
    if window is None:
        window = 1
    # Channel axis of the tensors given to the prior network. Dense
    # connections are concatenated along it; a hard-coded axis=1 was only
    # correct for channels_first.
    channel_axis = 1 if data_format == "channels_first" else -1
    summary_iter = {}

    if verbose:
        print("%s> Building FISTA unrolled network...." % scope)
        print("%s> Num of gradient steps: %d" % (scope, num_grad_steps))
        print(
            "%s> Prior: %d ResBlocks, %d features"
            % (scope, num_resblocks, num_features)
        )
        print("%s> Kernel size: [%d x %d x %d]" % ((scope,) + tuple(kernel_size)))
        if do_rnn:
            print("%s> Sharing weights across iterations..." % scope)
        if sensemap is not None:
            print("%s> Using sensitivity maps..." % scope)
        if do_dense:
            print("%s> Inserting dense connections..." % scope)
        if do_circular:
            print("%s> Using circular padding..." % scope)
        if do_separable:
            print("%s> Using depth-wise separable convolutions..." % scope)
        if not batchnorm:
            print("%s> Turning off batch normalization..." % scope)

    with tf.variable_scope(scope):
        if mask is None:
            mask = tf_util.kspace_mask(ks_input, dtype=tf.complex64)
        ks_input = mask * ks_input
        ks_0 = ks_input
        # x0 = A^T W b
        im_0 = tf_util.model_transpose(ks_0 * window, sensemap)
        im_0 = tf.identity(im_0, name="input_image")

        # To be updated
        ks_k = ks_0
        im_k = im_0
        im_dense = None

        for i_step in range(num_grad_steps):
            iter_name = "iter_%02d" % i_step
            scope_name = "iter" if do_rnn else iter_name
            with tf.variable_scope(
                scope_name, reuse=(tf.AUTO_REUSE if do_rnn else False)
            ):
                with tf.variable_scope("update"):
                    # = S( x_k - 2 * t * (A^T W A x_k - A^T W b))
                    # = S( x_k - 2 * t * (A^T W A x_k - x0))
                    im_k_orig = im_k
                    # xk = A^T A x_k
                    ks_k = tf_util.model_forward(im_k, sensemap)
                    ks_k = mask * ks_k
                    im_k = tf_util.model_transpose(ks_k * window, sensemap)
                    # xk = A^T A x_k - A^T b
                    im_k = tf_util.complex_to_channels(im_k - im_0)
                    im_k_orig = tf_util.complex_to_channels(im_k_orig)
                    # Update step
                    if fix_update:
                        t_update = -2.0
                    else:
                        t_update = tf.get_variable(
                            "t", dtype=tf.float32, initializer=tf.constant([-2.0])
                        )
                    im_k = im_k_orig + t_update * im_k

                with tf.variable_scope("prox"):
                    # complex_to_channels yields channels_last; the prior
                    # outputs as many channels as the image has.
                    num_channels_out = im_k.shape[-1]
                    if data_format == "channels_first":
                        im_k = tf.transpose(im_k, [0, 4, 1, 2, 3])
                    if im_dense is not None:
                        im_k = tf.concat([im_k, im_dense], axis=channel_axis)
                    im_k, im_dense_k = prior_grad_res_net(
                        im_k,
                        training=is_training,
                        num_features=num_features,
                        num_blocks=num_resblocks,
                        num_features_out=num_channels_out,
                        kernel_size=kernel_size,
                        data_format=data_format,
                        circular=do_circular,
                        separable=do_separable,
                        batchnorm=batchnorm,
                        leaky=leaky,
                    )
                    if do_dense:
                        # Assumes im_dense_k uses the same data_format as the
                        # prior input -- TODO confirm in prior_grad_res_net.
                        if im_dense is not None:
                            im_dense = tf.concat(
                                [im_dense, im_dense_k], axis=channel_axis)
                        else:
                            im_dense = im_dense_k
                    if data_format == "channels_first":
                        im_k = tf.transpose(im_k, [0, 2, 3, 4, 1])

                im_k = tf_util.channels_to_complex(im_k)
                im_k = tf.identity(im_k, name="image")
                with tf.name_scope("summary"):
                    summary_iter[iter_name] = im_k

        ks_k = tf_util.model_forward(im_k, sensemap)
        if do_hardproj:
            if verbose:
                print("%s> Final hard data projection..." % scope)
            # Measured samples where mask is set, network estimate elsewhere.
            ks_k = mask * ks_0 + (1 - mask) * ks_k
        if mask_output is not None:
            ks_k = ks_k * mask_output
        im_k = tf_util.model_transpose(ks_k * window, sensemap)
        ks_k = tf.identity(ks_k, name="output_kspace")
        im_k = tf.identity(im_k, name="output_image")

    return im_k
def generator(self, ks_input, sensemap, reuse=False):
    """Build the generator that maps input k-space to a channel-split image.

    Args:
        ks_input: Input k-space tensor. Its ``tf_util.kspace_mask`` is used
            as the sampling mask of the unrolled networks.
        sensemap: Sensitivity maps for the forward/adjoint model.
        reuse: If True, reuse the variables of the "generator" scope.

    Returns:
        Generator output with complex values split into channels.

    Raises:
        ValueError: if ``self.dims != 4`` and ``self.arch != "unrolled"``.
            Only the unrolled network is implemented for that case; before,
            this path reached ``return c_out`` with ``c_out`` unassigned.
    """
    if self.dims != 4 and self.arch != "unrolled":
        raise ValueError(
            "generator: arch %r is not implemented for dims=%r; only "
            "'unrolled' is supported" % (self.arch, self.dims))

    mask_example = tf_util.kspace_mask(ks_input, dtype=tf.complex64)
    with tf.variable_scope("generator") as scope:
        if reuse:
            scope.reuse_variables()

        # 2D data
        # batch, height, width, channels
        if self.dims == 4:
            if self.arch == "unrolled":
                c_out = unrolled.unroll_fista(
                    ks_input,
                    sensemap,
                    num_grad_steps=self.iterations,
                    resblock_num_features=self.g_dim,
                    resblock_num_blocks=self.res_blocks,
                    is_training=True,
                    scope="MRI",
                    mask_output=1,
                    window=None,
                    do_hardproj=True,
                    num_summary_image=6,
                    mask=mask_example,
                    verbose=False,
                )
                c_out = tf_util.complex_to_channels(c_out)
            else:
                # Residual CNN on the zero-filled adjoint image.
                z = tf_util.model_transpose(ks_input, sensemap)
                z = tf_util.complex_to_channels(z)

                res_size = self.g_dim
                kernel_size = 3
                num_channels = 2
                num_blocks = 5

                c = tf.layers.conv2d(
                    z,
                    res_size,
                    kernel_size,
                    padding="same",
                    activation=tf.nn.relu,
                    use_bias=True,
                )
                for _ in range(num_blocks):
                    c = tf.layers.conv2d(
                        c,
                        res_size,
                        kernel_size,
                        padding="same",
                        activation=tf.nn.relu,
                        use_bias=True,
                    )
                    c1 = tf.layers.conv2d(
                        c,
                        res_size,
                        kernel_size,
                        padding="same",
                        activation=tf.nn.relu,
                        use_bias=True,
                    )
                    c = tf.add(c, c1)
                c8 = tf.layers.conv2d(
                    c,
                    num_channels,
                    kernel_size,
                    padding="same",
                    activation=None,
                    use_bias=True,
                )
                # Global skip connection from the input image, then tanh.
                c_out = tf.add(c8, z)
                c_out = tf.nn.tanh(c_out)

        # 3D data
        # batch, height, width, channels, time frames
        else:
            c_out = unrolled_3d.unroll_fista(
                ks_input,
                sensemap,
                num_grad_steps=self.iterations,
                num_features=self.g_dim,
                num_resblocks=self.res_blocks,
                is_training=True,
                scope="MRI",
                mask_output=1,
                window=None,
                do_hardproj=True,
                mask=mask_example,
                verbose=False,
                data_format="channels_last",
                do_separable=self.do_separable,
            )
            c_out = tf_util.complex_to_channels(c_out)

    return c_out
def test(self):
    """Evaluate the generator on every test file and record output metrics.

    Each step makes one ``self.sess.run`` call that fetches three values:
    - the acceleration factor of ``self.ks``;
    - the generator output ``self.im_out``;
    - ``self.complex_truth``.

    For ``data_type == "knee"``, PSNR/NRMSE/SSIM from
    ``metrics.compute_all`` are accumulated and printed. The means and
    standard deviations are written to ``<self.log_dir>/output_metrics.txt``.

    Image/DICOM/GIF export for "DCE" and "DCE_2D" data is not available in
    this method. It needs the input and BART CS magnitude images, whose
    computation is disabled here. For those types only the acceleration
    factor is reported.
    """
    print("testing")
    print("number of test cases", self.input_files)
    if self.data_type in ("DCE", "DCE_2D"):
        print("%s data: image/GIF export is not implemented in test(); "
              "reporting acceleration only" % self.data_type)

    mask_input = tf_util.kspace_mask(self.ks, dtype=tf.complex64)
    numel = tf.cast(tf.size(mask_input), tf.float32)
    # Acceleration = number of k-space entries / sum of |mask| (the number of
    # sampled entries for a binary mask).
    acc = numel / tf.reduce_sum(tf.abs(mask_input))

    total_acc = []
    output_psnr = []
    output_nrmse = []
    output_ssim = []

    for step in range(self.input_files):
        print("test file #", step)
        # Fetch acc together with the images so both describe the same batch;
        # separate runs may each pull a new batch if self.ks comes from an
        # input pipeline.
        acc_run, output_image, complex_truth = self.sess.run(
            [acc, self.im_out, self.complex_truth])
        total_acc.append(acc_run)
        print(
            "total test acc:",
            np.round(np.mean(total_acc), decimals=2),
            np.round(np.std(total_acc), decimals=2),
        )

        # `==`, not `is`: identity against a string literal is unreliable.
        if self.data_type == "knee":
            output_image = np.squeeze(output_image)
            truth_image = np.squeeze(complex_truth)
            psnr, nrmse, ssim = metrics.compute_all(truth_image, output_image,
                                                    sos_axis=-1)
            output_psnr.append(psnr)
            output_nrmse.append(nrmse)
            output_ssim.append(ssim)
            print("output psnr, nrmse, ssim")
            print(
                np.round(np.mean(output_psnr), decimals=2),
                np.round(np.mean(output_nrmse), decimals=2),
                np.round(np.mean(output_ssim), decimals=2),
            )

    def stat_line(label, values, sep=" +\\- "):
        # "<label> = <mean><sep><std>"; sep keeps the literal " +\- " text.
        return label + " = " + str(np.mean(values)) + sep + str(np.std(values))

    report = "\n".join([
        stat_line("output psnr", output_psnr),
        stat_line("output nrmse", output_nrmse),
        stat_line("output ssim", output_ssim),
        stat_line("test acc", total_acc, sep=" +\\-"),
    ])
    txt_path = os.path.join(self.log_dir, "output_metrics.txt")
    with open(txt_path, "w") as f:
        f.write(report)
def unroll_fista(
    ks_input,
    sensemap,
    num_grad_steps=5,
    resblock_num_features=128,
    resblock_num_blocks=2,
    is_training=True,
    scope="MRI",
    mask_output=1,
    window=None,
    do_hardproj=True,
    num_summary_image=0,
    mask=None,
    verbose=False,
):
    """Create general unrolled network for MRI.

    x_{k+1} = S( x_k - 2 * t * A^T W (A x- b) )
            = S( x_k - 2 * t * (A^T W A x - A^T W b))

    Each of the ``num_grad_steps`` iterations runs two stages:
    - a gradient step on the data term, with a learnable step ``t``
      initialised to -2.0;
    - ``prior_grad_res_net`` (run in channels_first layout) as the proximal
      operator ``S``.

    Every iteration is placed on the last GPU listed by
    ``device_lib.list_local_devices()``. When no GPU is present, no explicit
    device placement is applied.

    Args:
        ks_input: Input k-space tensor; it is multiplied by ``mask``.
        sensemap: Sensitivity maps for the forward/adjoint model.
        num_grad_steps: Number of unrolled iterations.
        resblock_num_features: ``num_features`` of the prior network.
        resblock_num_blocks: ``num_blocks`` of the prior network.
        is_training: Forwarded as ``training`` to the prior network.
        scope: Name of the enclosing variable scope.
        mask_output: If not None, multiplies the final k-space estimate.
        window: k-space weighting applied before every adjoint; None means 1.
        do_hardproj: If True, sampled k-space locations of the final estimate
            are replaced with the measured data.
        num_summary_image: If > 0, per-iteration sum-of-squares images are
            concatenated. NOTE: emitting them as a summary is currently
            disabled, so they are not consumed.
        mask: Sampling mask; derived from ``ks_input`` when None.
        verbose: Print a description of the network being built.

    Returns:
        The output image tensor (op name "output_image").
    """
    # get list of GPU devices
    local_device_protos = device_lib.list_local_devices()
    device_list = [x.name for x in local_device_protos if x.device_type == "GPU"]
    # device_list[-1] raised IndexError on CPU-only hosts. An empty device
    # string adds no placement constraint, so use it as the fallback.
    cur_device = device_list[-1] if device_list else ""

    if window is None:
        window = 1
    summary_iter = None

    if verbose:
        print(
            "%s> Building FISTA unrolled network (%d steps)...."
            % (scope, num_grad_steps)
        )
        if sensemap is not None:
            print("%s> Using sensitivity maps..." % scope)

    with tf.variable_scope(scope):
        if mask is None:
            mask = tf_util.kspace_mask(ks_input, dtype=tf.complex64)
        ks_input = mask * ks_input
        ks_0 = ks_input
        # x0 = A^T W b
        im_0 = tf_util.model_transpose(ks_0 * window, sensemap)
        im_0 = tf.identity(im_0, name="input_image")

        # To be updated
        ks_k = ks_0
        im_k = im_0

        for i_step in range(num_grad_steps):
            iter_name = "iter_%02d" % i_step
            with tf.device(cur_device), tf.variable_scope(iter_name):
                # = S( x_k - 2 * t * (A^T W A x_k - A^T W b))
                # = S( x_k - 2 * t * (A^T W A x_k - x0))
                with tf.variable_scope("update"):
                    im_k_orig = im_k
                    # xk = A^T A x_k
                    ks_k = tf_util.model_forward(im_k, sensemap)
                    ks_k = mask * ks_k
                    im_k = tf_util.model_transpose(ks_k * window, sensemap)
                    # xk = A^T A x_k - A^T b
                    im_k = tf_util.complex_to_channels(im_k - im_0)
                    im_k_orig = tf_util.complex_to_channels(im_k_orig)
                    # Update step
                    t_update = tf.get_variable(
                        "t", dtype=tf.float32, initializer=tf.constant([-2.0])
                    )
                    im_k = im_k_orig + t_update * im_k

                with tf.variable_scope("prox"):
                    # Default layout is channels_last.
                    num_channels_out = im_k.shape[-1]
                    # Transpose channels_last to channels_first
                    im_k = tf.transpose(im_k, [0, 3, 1, 2])
                    im_k = prior_grad_res_net(
                        im_k,
                        training=is_training,
                        num_features=resblock_num_features,
                        num_blocks=resblock_num_blocks,
                        num_features_out=num_channels_out,
                        data_format="channels_first",
                    )
                    im_k = tf.transpose(im_k, [0, 2, 3, 1])

                im_k = tf_util.channels_to_complex(im_k)
                im_k = tf.identity(im_k, name="image")

                if num_summary_image > 0:
                    with tf.name_scope("summary"):
                        tmp = tf_util.sumofsq(im_k, keep_dims=True)
                        if summary_iter is None:
                            summary_iter = tmp
                        else:
                            summary_iter = tf.concat((summary_iter, tmp), axis=2)

        ks_k = tf_util.model_forward(im_k, sensemap)
        if do_hardproj:
            if verbose:
                print("%s> Final hard data projection..." % scope)
            # Final data projection
            ks_k = mask * ks_0 + (1 - mask) * ks_k
        if mask_output is not None:
            ks_k = ks_k * mask_output
        im_k = tf_util.model_transpose(ks_k * window, sensemap)
        ks_k = tf.identity(ks_k, name="output_kspace")
        im_k = tf.identity(im_k, name="output_image")

    return im_k
def test(self):
    """Evaluate the generator against ground truth and a BART CS baseline.

    Each step works on one batch.
    - A single ``self.sess.run`` call fetches the acceleration factor,
      ``self.im_in``, ``self.output_image``, ``self.complex_truth``,
      ``self.ks`` and ``self.sensemap``.
    - ``self.bart_cs`` then computes a compressed-sensing reconstruction with
      a per-``data_type`` l1 weight.
    - Both reconstructions are timed.
    - For "knee" data, PSNR/NRMSE/SSIM of the CS, output and input images
      are accumulated.
    - For "DCE" and "DCE_2D" data, the magnitude images are written into
      volumes. "DCE_2D" also writes one DICOM per (slice, frame).

    The volumes are saved with ``np.save`` and ``cfl.write``. Metrics are
    written to ``output_metrics.txt`` and ``cs_metrics.txt`` in
    ``self.log_dir``.

    Raises:
        ValueError: if ``self.data_type`` is not "knee", "DCE" or "DCE_2D"
            (no CS l1 weight is defined for it).
    """
    print("testing")
    # l1 regularisation weight of the BART CS baseline, per data type.
    l1_by_type = {"knee": 0.0035, "DCE": 0.01, "DCE_2D": 0.05}
    if self.data_type not in l1_by_type:
        raise ValueError(
            "test(): no CS l1 weight for data_type %r" % (self.data_type,))
    l1 = l1_by_type[self.data_type]

    # read in a correct test dicom file to change it later
    dicom_filename = pydicom.data.get_testdata_files("MR_small.dcm")[0]
    self.ds = pydicom.dcmread(dicom_filename)

    # DCE_2D bookkeeping: current time frame and slice index.
    frame = 0
    x_slice = 0
    print("number of test cases", self.input_files)

    mask_input = tf_util.kspace_mask(self.ks, dtype=tf.complex64)
    numel = tf.cast(tf.size(mask_input), tf.float32)
    acc = numel / tf.reduce_sum(tf.abs(mask_input))

    total_acc = []
    input_psnr, input_nrmse, input_ssim = [], [], []
    output_psnr, output_nrmse, output_ssim = [], [], []
    cs_psnr, cs_nrmse, cs_ssim = [], [], []
    # NOTE(review): volume size is hard-coded; for "DCE" the volume is
    # indexed by `step`, which must stay below self.max_frames.
    volume_shape = (self.max_frames, 192, 80, 180)
    input_volume = np.zeros(volume_shape)
    output_volume = np.zeros(volume_shape)
    cs_volume = np.zeros(volume_shape)
    model_time = []
    cs_time = []

    def rotate_image(img):
        # Squeezed magnitude image, reoriented for the DCE layouts.
        img = np.squeeze(np.absolute(img))
        if self.data_type == "DCE":
            img = np.transpose(img, axes=(1, 0, 2))
            img = np.flip(img, axis=2)  # flip the time
        if self.data_type == "DCE_2D":
            img = np.transpose(img, axes=(1, 0))
        return img

    def record(name, truth, image, psnrs, nrmses, ssims):
        # Append this batch's metrics and print the running means.
        psnr, nrmse, ssim = metrics.compute_all(truth, image, sos_axis=-1)
        psnrs.append(psnr)
        nrmses.append(nrmse)
        ssims.append(ssim)
        print(name + " psnr, nrmse, ssim")
        print(
            np.round(np.mean(psnrs), decimals=2),
            np.round(np.mean(nrmses), decimals=2),
            np.round(np.mean(ssims), decimals=2),
        )

    # NOTE(review): only half of self.input_files is visited -- confirm.
    for step in range(self.input_files // 2):
        # DCE_2D: iterator will see each slice followed by 18 time frames
        # then the next slice
        print("test file #", step)

        model_start_time = time.time()
        # acc is fetched in the same run as the images so it describes the
        # same batch (separate runs may each pull a new batch).
        (
            acc_run,
            input_image,
            output_image,
            complex_truth,
            ks_run,
            sensemap_run,
        ) = self.sess.run([
            acc,
            self.im_in,
            self.output_image,
            self.complex_truth,
            self.ks,
            self.sensemap,
        ])
        runtime = time.time() - model_start_time

        total_acc.append(acc_run)
        print(
            "total test acc:",
            np.round(np.mean(total_acc), decimals=2),
            np.round(np.std(total_acc), decimals=2),
        )

        # Step 1 is excluded from the timing statistics.
        # NOTE(review): perhaps the warm-up step 0 was intended -- confirm.
        if step != 1:
            model_time.append(runtime)
        print("GAN: %s seconds" % np.mean(model_time),
              "+/- %s" % np.std(model_time))

        cs_start_time = time.time()
        bart_test = self.bart_cs(ks_run, sensemap_run, l1=l1)
        runtime = time.time() - cs_start_time
        if step != 1:
            cs_time.append(runtime)
        print("CS: %s seconds" % np.mean(cs_time), "+/- %s" % np.std(cs_time))

        if self.data_type == "knee":
            input_image = np.squeeze(input_image)
            output_image = np.squeeze(output_image)
            truth_image = np.squeeze(complex_truth)
            cs_image = np.squeeze(bart_test)
            record("cs", truth_image, cs_image, cs_psnr, cs_nrmse, cs_ssim)
            record("output", truth_image, output_image,
                   output_psnr, output_nrmse, output_ssim)
            record("input", truth_image, input_image,
                   input_psnr, input_nrmse, input_ssim)

        mag_input = rotate_image(input_image)
        mag_output = rotate_image(output_image)
        mag_cs = rotate_image(bart_test)

        # x, y, z, time
        if self.data_type == "DCE":
            input_volume[step, :, :, :] = mag_input
            output_volume[step, :, :, :] = mag_output
            cs_volume[step, :, :, :] = mag_cs

        if self.data_type == "DCE_2D":
            input_volume[frame, x_slice, :, :] = mag_input
            output_volume[frame, x_slice, :, :] = mag_output
            cs_volume[frame, x_slice, :, :] = mag_cs
            new_filename = (self.log_dir + "/dicoms/" + "output_slice_" +
                            str(x_slice) + "_f" + str(frame) + ".dcm")
            # NOTE(review): writes the *input* magnitude under an
            # "output_slice_" name -- confirm this is intended.
            self.write_dicom(mag_input, new_filename, x_slice, frame)
            # Advance the frame; after the last frame move to the next slice.
            if frame == self.max_frames - 1:
                frame = 0
                x_slice += 1
            else:
                frame += 1
            print("slice", x_slice, "time frame", frame)

    filename = os.path.join(self.log_dir,
                            os.path.basename(self.search_str[:-11]))
    input_dir = filename + "_input" + ".npy"
    output_dir = filename + "_output" + ".npy"
    cs_dir = filename + "_cs" + ".npy"

    print("saving numpy volumes")
    np.save(input_dir, input_volume)
    np.save(output_dir, output_volume)
    np.save(cs_dir, cs_volume)
    print(output_dir)

    # NOTE(review): cfl.write receives the same ".npy" paths -- confirm.
    print("saving cfl volumes")
    cfl.write(input_dir, input_volume, "R")
    cfl.write(output_dir, output_volume, "R")
    cfl.write(cs_dir, cs_volume, "R")

    def stat_line(label, values, sep=" +\\- "):
        # "<label> = <mean><sep><std>"; sep keeps the literal " +\- " text.
        return label + " = " + str(np.mean(values)) + sep + str(np.std(values))

    output_lines = [
        stat_line("output psnr", output_psnr),
        stat_line("output nrmse", output_nrmse),
        stat_line("output ssim", output_ssim),
    ]
    input_lines = [
        stat_line("input psnr", input_psnr),
        stat_line("input nrmse", input_nrmse),
        stat_line("input ssim", input_ssim),
    ]
    acc_line = stat_line("test acc", total_acc, sep=" +\\-")

    # `==`, not `is`: identity against a string literal is unreliable.
    if self.data_type == "knee":
        print("\n".join(output_lines + [acc_line]))

    with open(os.path.join(self.log_dir, "output_metrics.txt"), "w") as f:
        f.write("\n".join(output_lines + input_lines + [acc_line]))

    # CS metrics are labelled "cs ..." (NRMSE/SSIM were mislabelled "output").
    with open(os.path.join(self.log_dir, "cs_metrics.txt"), "w") as f:
        f.write("\n".join([
            stat_line("cs psnr", cs_psnr),
            stat_line("cs nrmse", cs_nrmse),
            stat_line("cs ssim", cs_ssim),
        ]))