def _create_summary(sense_place, ks_place, im_out_place, im_truth_place):
    """Register TensorBoard image summaries for input, output, truth, mask and maps.

    Args:
        sense_place: Sensitivity-map tensor passed to ``tf_util.model_forward``
            and ``tf_util.model_transpose``. The ``tf.slice`` below requires
            rank 5; the axis layout is presumably
            (batch, height, width, maps, coils) -- TODO confirm.
        ks_place: Input (measured) k-space tensor.
        im_out_place: Network output image tensor.
        im_truth_place: Ground-truth image tensor.

    Returns:
        None. Summaries are added to the default graph; each one emits at most
        ``FLAGS.num_summary_image`` images.
    """
    # Express every quantity in both domains: the adjoint of the input
    # k-space, and the forward model applied to the output / truth images.
    im_in = tf_util.model_transpose(ks_place, sense_place)
    mask_in = tf_util.kspace_mask(ks_place, dtype=tf.complex64)
    ks_out = tf_util.model_forward(im_out_place, sense_place)
    ks_ref = tf_util.model_forward(im_truth_place, sense_place)

    with tf.name_scope("input-output-truth"):
        # Sum-of-squares panels concatenated along axis 2, ordered
        # input | output | truth.
        ks_panels = [
            tf_util.sumofsq(ks, keep_dims=True) for ks in (ks_place, ks_out, ks_ref)
        ]
        # Log scale for k-space display; the offset avoids log(0).
        ks_montage = tf.log(tf.concat(ks_panels, axis=2) + 1e-6)
        tf.summary.image("kspace", ks_montage, max_outputs=FLAGS.num_summary_image)

        im_panels = [
            tf_util.sumofsq(im, keep_dims=True)
            for im in (im_in, im_out_place, im_truth_place)
        ]
        tf.summary.image(
            "image",
            tf.concat(im_panels, axis=2),
            max_outputs=FLAGS.num_summary_image,
        )

    with tf.name_scope("truth"):
        # Sum the truth image over its last axis, then keep the real part.
        truth_real = tf.real(tf.reduce_sum(im_truth_place, axis=-1, keep_dims=True))
        tf.summary.image("image_real", truth_real, max_outputs=FLAGS.num_summary_image)

    with tf.name_scope("mask"):
        tf.summary.image(
            "mask",
            tf_util.sumofsq(mask_in, keep_dims=True),
            max_outputs=FLAGS.num_summary_image,
        )

    with tf.name_scope("sensemap"):
        # Magnitude of the maps, keeping only index 0 of axis 3.
        maps = tf.slice(tf.abs(sense_place), [0, 0, 0, 0, 0], [-1, -1, -1, 1, -1])
        maps = tf.transpose(maps, [0, 1, 4, 2, 3])
        # Flatten everything after axis 1 so the remaining planes (presumably
        # one per coil) are laid side by side in a single image.
        maps = tf.reshape(maps, [tf.shape(maps)[0], tf.shape(maps)[1], -1])
        tf.summary.image(
            "image",
            tf.expand_dims(maps, axis=-1),
            max_outputs=FLAGS.num_summary_image,
        )
def unroll_fista(
    ks_input,
    sensemap,
    num_grad_steps=5,
    resblock_num_features=128,
    resblock_num_blocks=2,
    is_training=True,
    scope="MRI",
    mask_output=1,
    window=None,
    do_hardproj=True,
    num_summary_image=0,
    mask=None,
    verbose=False,
    conv="real",
    do_conjugate=False,
    activation="relu",
):
    """Create general unrolled network for MRI.

    x_{k+1} = S( x_k - 2 * t * A^T W (A x- b) )
            = S( x_k - 2 * t * (A^T W A x - A^T W b))

    ``A`` / ``A^T`` are ``tf_util.model_forward`` / ``tf_util.model_transpose``
    with ``sensemap``. ``W`` is the sampling ``mask`` (and ``window``, which
    multiplies k-space before every adjoint). ``S`` is ``prior_grad_res_net``.
    The step ``t`` is a trainable per-iteration variable, initialised to -2.0
    and added to ``x_k``.

    Args:
        ks_input: Measured k-space tensor (masked by ``mask`` before use).
        sensemap: Sensitivity maps for the forward/adjoint models; may be None.
        num_grad_steps: Number of unrolled iterations.
        resblock_num_features: ``num_features`` for ``prior_grad_res_net``.
        resblock_num_blocks: ``num_blocks`` for ``prior_grad_res_net``.
        is_training: Passed as ``training`` to ``prior_grad_res_net``.
        scope: Name of the enclosing variable scope.
        mask_output: Multiplier applied to the final k-space after the hard
            projection; None skips it.
        window: Multiplier applied to k-space before each adjoint; None -> 1.
        do_hardproj: If True, re-insert the measured samples into the final
            k-space and recompute the image from it.
        num_summary_image: If > 0, add per-iteration summaries (a max scalar and
            a concatenated sum-of-squares image).
        mask: Sampling mask; if None it is derived via ``tf_util.kspace_mask``.
        verbose: Print build progress.
        conv: Stored in the module-level global ``type_conv``.
        do_conjugate: Stored in the module-level global ``conjugate``.
        activation: Passed to ``prior_grad_res_net``.

    Returns:
        The reconstructed image tensor (identity op named "output_image").

    NOTE(review): another ``def unroll_fista`` appears later in this module
    (without ``conv`` / ``do_conjugate`` / ``activation``). It rebinds the name
    at import time, so this version is not reachable as ``unroll_fista`` --
    confirm which one is intended.
    """
    if window is None:
        window = 1
    summary_iter = None
    # NOTE(review): side effect -- mutates module-level globals; presumably
    # read by convolution helpers elsewhere in the module. Confirm.
    global type_conv
    type_conv = conv
    global conjugate
    conjugate = do_conjugate

    if verbose:
        print(
            "%s> Building FISTA unrolled network (%d steps)...."
            % (scope, num_grad_steps)
        )
        if sensemap is not None:
            print("%s> Using sensitivity maps..." % scope)

    with tf.variable_scope(scope):
        if mask is None:
            mask = tf_util.kspace_mask(ks_input, dtype=tf.complex64)
        # Apply the sampling mask to the measurements.
        ks_input = mask * ks_input
        ks_0 = ks_input

        # x0 = A^T W b
        im_0 = tf_util.model_transpose(ks_0 * window, sensemap)
        im_0 = tf.identity(im_0, name="input_image")

        # To be updated
        ks_k = ks_0
        im_k = im_0

        for i_step in range(num_grad_steps):
            iter_name = "iter_%02d" % i_step
            with tf.variable_scope(iter_name):
                # = S( x_k - 2 * t * (A^T W A x_k - A^T W b))
                # = S( x_k - 2 * t * (A^T W A x_k - x0))
                with tf.variable_scope("update"):
                    im_k_orig = im_k
                    # xk = A^T W A x_k (mask in k-space, window before adjoint)
                    ks_k = tf_util.model_forward(im_k, sensemap)
                    ks_k = mask * ks_k
                    im_k = tf_util.model_transpose(ks_k * window, sensemap)
                    # xk = A^T W A x_k - A^T W b, converted to real channels
                    im_k = tf_util.complex_to_channels(im_k - im_0)
                    im_k_orig = tf_util.complex_to_channels(im_k_orig)
                    # Update step: learned step size t, initialised to -2.0
                    t_update = tf.get_variable(
                        "t", dtype=tf.float32, initializer=tf.constant([-2.0])
                    )
                    im_k = im_k_orig + t_update * im_k

                with tf.variable_scope("prox"):
                    # Learned proximal operator; output channel count matches
                    # the input channel count.
                    num_channels_out = im_k.shape[-1]
                    im_k = prior_grad_res_net(
                        im_k,
                        training=is_training,
                        num_features=resblock_num_features,
                        num_blocks=resblock_num_blocks,
                        num_features_out=num_channels_out,
                        data_format="channels_last",
                        activation=activation
                    )
                    im_k = tf_util.channels_to_complex(im_k)

                im_k = tf.identity(im_k, name="image")
                if num_summary_image > 0:
                    with tf.name_scope("summary"):
                        # Per-iteration sum-of-squares image, concatenated
                        # along axis 2 with those of previous iterations.
                        tmp = tf_util.sumofsq(im_k, keep_dims=True)
                        if summary_iter is None:
                            summary_iter = tmp
                        else:
                            summary_iter = tf.concat(
                                (summary_iter, tmp), axis=2)
                        tf.summary.scalar("max/" + iter_name, tf.reduce_max(tmp))

        ks_k = tf_util.model_forward(im_k, sensemap)
        if do_hardproj:
            if verbose:
                print("%s> Final hard data projection..." % scope)
            # Final data projection: keep measured samples where mask is set,
            # network-predicted k-space elsewhere.
            ks_k = mask * ks_0 + (1 - mask) * ks_k
            if mask_output is not None:
                ks_k = ks_k * mask_output
            im_k = tf_util.model_transpose(ks_k * window, sensemap)

        # Named identities so the outputs can be fetched by name from the graph.
        ks_k = tf.identity(ks_k, name="output_kspace")
        im_k = tf.identity(im_k, name="output_image")

    if summary_iter is not None:
        tf.summary.image("iter/image", summary_iter, max_outputs=num_summary_image)

    return im_k
def unroll_fista(
    ks_input,
    sensemap,
    num_grad_steps=5,
    resblock_num_features=128,
    resblock_num_blocks=2,
    is_training=True,
    scope="MRI",
    mask_output=1,
    window=None,
    do_hardproj=True,
    num_summary_image=0,
    mask=None,
    verbose=False,
):
    """Create general unrolled network for MRI.

    x_{k+1} = S( x_k - 2 * t * A^T W (A x- b) )
            = S( x_k - 2 * t * (A^T W A x - A^T W b))

    ``A`` / ``A^T`` are ``tf_util.model_forward`` / ``tf_util.model_transpose``
    with ``sensemap``. ``W`` is the sampling ``mask`` (and ``window``, which
    multiplies k-space before every adjoint). ``S`` is ``prior_grad_res_net``,
    run in ``channels_first`` layout. The step ``t`` is a trainable
    per-iteration variable, initialised to -2.0 and added to ``x_k``.

    All unrolled iterations are placed on the last GPU reported by
    ``device_lib.list_local_devices()``. If no GPU is visible, the enclosing or
    default device placement is used instead of raising IndexError.

    Args:
        ks_input: Measured k-space tensor (masked by ``mask`` before use).
        sensemap: Sensitivity maps for the forward/adjoint models; may be None.
        num_grad_steps: Number of unrolled iterations.
        resblock_num_features: ``num_features`` for ``prior_grad_res_net``.
        resblock_num_blocks: ``num_blocks`` for ``prior_grad_res_net``.
        is_training: Passed as ``training`` to ``prior_grad_res_net``.
        scope: Name of the enclosing variable scope.
        mask_output: Multiplier applied to the final k-space after the hard
            projection; None skips it.
        window: Multiplier applied to k-space before each adjoint; None -> 1.
        do_hardproj: If True, re-insert the measured samples into the final
            k-space and recompute the image from it.
        num_summary_image: If > 0, per-iteration sum-of-squares images are
            accumulated. NOTE(review): the accumulated tensor is currently never
            written to a summary.
        mask: Sampling mask; if None it is derived via ``tf_util.kspace_mask``.
        verbose: Print build progress.

    Returns:
        The reconstructed image tensor (identity op named "output_image").

    NOTE(review): this definition rebinds an earlier ``unroll_fista`` in the
    same module that also accepts ``conv`` / ``do_conjugate`` / ``activation``.
    Callers passing those keywords will get a TypeError -- confirm which
    version is intended.
    """
    import contextlib

    # Device placement is loop-invariant, so decide it once. Indexing the GPU
    # list unconditionally crashed (IndexError) on machines without a GPU.
    gpu_names = [
        dev.name
        for dev in device_lib.list_local_devices()
        if dev.device_type == "GPU"
    ]
    iter_device = gpu_names[-1] if gpu_names else None

    if window is None:
        window = 1
    summary_iter = None

    if verbose:
        print(
            "%s> Building FISTA unrolled network (%d steps)...."
            % (scope, num_grad_steps)
        )
        if sensemap is not None:
            print("%s> Using sensitivity maps..." % scope)

    with tf.variable_scope(scope):
        if mask is None:
            mask = tf_util.kspace_mask(ks_input, dtype=tf.complex64)
        # Apply the sampling mask to the measurements.
        ks_0 = mask * ks_input

        # x0 = A^T W b
        im_0 = tf_util.model_transpose(ks_0 * window, sensemap)
        im_0 = tf.identity(im_0, name="input_image")

        im_k = im_0
        for i_step in range(num_grad_steps):
            iter_name = "iter_%02d" % i_step
            # tf.device() contexts are single-use, so build a fresh one per
            # iteration. An empty ExitStack is a no-op context (works on Python
            # versions without contextlib.nullcontext).
            device_scope = (
                tf.device(iter_device)
                if iter_device is not None
                else contextlib.ExitStack()
            )
            with device_scope, tf.variable_scope(iter_name):
                # = S( x_k - 2 * t * (A^T W A x_k - A^T W b))
                # = S( x_k - 2 * t * (A^T W A x_k - x0))
                with tf.variable_scope("update"):
                    im_k_orig = im_k
                    # A^T W A x_k (mask in k-space, window before the adjoint)
                    ks_k = tf_util.model_forward(im_k, sensemap)
                    ks_k = mask * ks_k
                    im_k = tf_util.model_transpose(ks_k * window, sensemap)
                    # A^T W A x_k - A^T W b, converted to real channels
                    im_k = tf_util.complex_to_channels(im_k - im_0)
                    im_k_orig = tf_util.complex_to_channels(im_k_orig)
                    # Learned step size t, initialised to -2.0
                    t_update = tf.get_variable(
                        "t", dtype=tf.float32, initializer=tf.constant([-2.0])
                    )
                    im_k = im_k_orig + t_update * im_k

                with tf.variable_scope("prox"):
                    num_channels_out = im_k.shape[-1]
                    # The prox network runs channels_first. The permutations
                    # assume rank-4 NHWC input -- TODO confirm.
                    im_k = tf.transpose(im_k, [0, 3, 1, 2])
                    im_k = prior_grad_res_net(
                        im_k,
                        training=is_training,
                        num_features=resblock_num_features,
                        num_blocks=resblock_num_blocks,
                        num_features_out=num_channels_out,
                        data_format="channels_first",
                    )
                    im_k = tf.transpose(im_k, [0, 2, 3, 1])
                    im_k = tf_util.channels_to_complex(im_k)

                im_k = tf.identity(im_k, name="image")
                if num_summary_image > 0:
                    with tf.name_scope("summary"):
                        tmp = tf_util.sumofsq(im_k, keep_dims=True)
                        if summary_iter is None:
                            summary_iter = tmp
                        else:
                            summary_iter = tf.concat((summary_iter, tmp), axis=2)

        ks_k = tf_util.model_forward(im_k, sensemap)
        if do_hardproj:
            if verbose:
                print("%s> Final hard data projection..." % scope)
            # Final data projection: keep measured samples where mask is set,
            # network-predicted k-space elsewhere.
            ks_k = mask * ks_0 + (1 - mask) * ks_k
            if mask_output is not None:
                ks_k = ks_k * mask_output
            im_k = tf_util.model_transpose(ks_k * window, sensemap)

        # Named identities so the outputs can be fetched by name from the graph.
        ks_k = tf.identity(ks_k, name="output_kspace")
        im_k = tf.identity(im_k, name="output_image")

    return im_k