Beispiel #1
0
    def get_everything(self, sess, features):
        """Build and evaluate the input image, truth image, maps and mask.

        The input k-space is multiplied by the mask returned by
        ``tf_util.kspace_mask``; its adjoint (``tf_util.model_transpose``) is
        named "low_res_image". The truth image is the adjoint of
        ``ks_truth * mask_recon`` and is named "truth_image". Both images are
        converted with ``tf_util.complex_to_channels`` and evaluated together
        with the sensitivity maps and the mask in a single ``sess.run``.

        Args:
            sess: session used to evaluate the tensors.
            features: dict providing "ks_input", "sensemap", "ks_truth" and
                "mask_recon" tensors.

        Returns:
            Tuple ``(im_lowres, im_truth, sensemap, mask)`` of evaluated values.
        """
        ks_input = features["ks_input"]
        sensemap = features["sensemap"]
        # Mask computed from the input k-space itself. (The previous
        # "if mask is None" fallback re-ran this exact call and was dead code.)
        mask = tf_util.kspace_mask(ks_input, dtype=tf.complex64)
        ks_input = mask * ks_input

        im_lowres = tf_util.model_transpose(ks_input, sensemap)
        im_lowres = tf.identity(im_lowres, name="low_res_image")
        # complex to channels
        im_lowres = tf_util.complex_to_channels(im_lowres)

        ks_truth = features["ks_truth"]
        mask_recon = features["mask_recon"]
        im_truth = tf_util.model_transpose(ks_truth * mask_recon, sensemap)
        im_truth = tf.identity(im_truth, name="truth_image")
        im_truth = tf_util.complex_to_channels(im_truth)

        im_lowres, im_truth, sensemap, mask = sess.run(
            [im_lowres, im_truth, sensemap, mask])
        return im_lowres, im_truth, sensemap, mask
Beispiel #2
0
 def get_undersampled(self, features):
     """Return the adjoint image of the input k-space in channel format.

     ``features["ks_input"]`` is passed through ``tf_util.model_transpose``
     with ``features["sensemap"]`` (no mask is applied here), the complex
     result is named "low_res_image", and it is converted with
     ``tf_util.complex_to_channels``.
     """
     kspace, maps = features["ks_input"], features["sensemap"]
     adjoint = tf.identity(tf_util.model_transpose(kspace, maps),
                           name="low_res_image")
     return tf_util.complex_to_channels(adjoint)
Beispiel #3
0
def _create_summary(sense_place, ks_place, im_out_place, im_truth_place):
    """Add TensorBoard image summaries comparing input, output and truth.

    Every image summary is limited to ``FLAGS.num_summary_image`` outputs.

    * Name scope ``input-output-truth``:
      * "kspace" is the log of ``tf_util.sumofsq`` of the input k-space and of
        the k-space of the output and truth images (``tf_util.model_forward``),
        concatenated along axis 2.
      * "image" is ``tf_util.sumofsq`` of the adjoint input image, the output
        image and the truth image, concatenated along axis 2.
    * Name scope ``truth``: "image_real" is the real part of the truth image
      summed over its last axis.
    * Name scope ``mask``: "mask" is ``tf_util.sumofsq`` of
      ``tf_util.kspace_mask`` applied to the input k-space.
    * Name scope ``sensemap``: "image" is the magnitude of index 0 along axis 3
      of the (rank-5) sensitivity maps, with the remaining trailing axes
      flattened into axis 2.

    Args:
        sense_place: sensitivity-map tensor.
        ks_place: input k-space tensor.
        im_out_place: output image tensor.
        im_truth_place: truth image tensor.
    """
    max_images = FLAGS.num_summary_image
    maps, ks_in = sense_place, ks_place
    im_out, im_ref = im_out_place, im_truth_place

    im_in = tf_util.model_transpose(ks_in, maps)
    mask_in = tf_util.kspace_mask(ks_in, dtype=tf.complex64)
    ks_out = tf_util.model_forward(im_out, maps)
    ks_ref = tf_util.model_forward(im_ref, maps)

    def side_by_side(*tensors):
        # sumofsq of each tensor (in argument order), concatenated on axis 2.
        return tf.concat(
            [tf_util.sumofsq(t, keep_dims=True) for t in tensors], axis=2)

    with tf.name_scope("input-output-truth"):
        # Small offset keeps log() finite where the k-space is zero.
        kspace_view = tf.log(side_by_side(ks_in, ks_out, ks_ref) + 1e-6)
        tf.summary.image("kspace", kspace_view, max_outputs=max_images)
        image_view = side_by_side(im_in, im_out, im_ref)
        tf.summary.image("image", image_view, max_outputs=max_images)

    with tf.name_scope("truth"):
        ref_real = tf.real(tf.reduce_sum(im_ref, axis=-1, keep_dims=True))
        tf.summary.image("image_real", ref_real, max_outputs=max_images)

    with tf.name_scope("mask"):
        mask_view = tf_util.sumofsq(mask_in, keep_dims=True)
        tf.summary.image("mask", mask_view, max_outputs=max_images)

    with tf.name_scope("sensemap"):
        first_map = tf.slice(tf.abs(maps), [0, 0, 0, 0, 0],
                             [-1, -1, -1, 1, -1])
        first_map = tf.transpose(first_map, [0, 1, 4, 2, 3])
        first_map = tf.reshape(
            first_map,
            [tf.shape(first_map)[0], tf.shape(first_map)[1], -1])
        tf.summary.image("image", tf.expand_dims(first_map, axis=-1),
                         max_outputs=max_images)
Beispiel #4
0
    def get_truth_image(self, features):
        """Return the truth image tensor in channel format.

        ``features["ks_truth"] * features["mask_recon"]`` is passed through
        ``tf_util.model_transpose`` with ``features["sensemap"]``; the complex
        result is named "truth_image" and converted with
        ``tf_util.complex_to_channels``.
        """
        truth_ks, maps, recon_mask = (
            features[key] for key in ("ks_truth", "sensemap", "mask_recon"))
        truth = tf_util.model_transpose(truth_ks * recon_mask, maps)
        return tf_util.complex_to_channels(
            tf.identity(truth, name="truth_image"))
Beispiel #5
0
    def get_images(self, features):
        """Build the masked-input image and the truth image tensors.

        Args:
            features: dict providing "ks_input", "sensemap", "ks_truth" and
                "mask_recon" tensors.

        Returns:
            ``(im_lowres, im_truth)``. These are ``tf_util.model_transpose`` of
            the input k-space multiplied by ``tf_util.kspace_mask`` of itself,
            and of ``ks_truth * mask_recon``. Each complex image is named via
            ``tf.identity`` ("low_res_image" / "truth_image") and then
            converted with ``tf_util.complex_to_channels``.
        """
        ks_input = features["ks_input"]
        sensemap = features["sensemap"]
        # Mask computed from the input k-space itself. (The previous
        # "if mask is None" fallback re-ran this exact call and was dead code.)
        mask = tf_util.kspace_mask(ks_input, dtype=tf.complex64)
        ks_input = mask * ks_input

        im_lowres = tf_util.model_transpose(ks_input, sensemap)
        im_lowres = tf.identity(im_lowres, name="low_res_image")
        # complex to channels
        im_lowres = tf_util.complex_to_channels(im_lowres)

        ks_truth = features["ks_truth"]
        mask_recon = features["mask_recon"]
        im_truth = tf_util.model_transpose(ks_truth * mask_recon, sensemap)
        im_truth = tf.identity(im_truth, name="truth_image")
        im_truth = tf_util.complex_to_channels(im_truth)

        return im_lowres, im_truth
Beispiel #6
0
    def create_summary(self):
        """Build TensorBoard summaries for the generator/discriminator model.

        Sets ``self.input_image`` to the adjoint of ``self.ks`` through
        ``self.sensemap``.

        When ``self.data_type == "knee"`` it also logs:
        * an input/output/truth magnitude image;
        * L1 and L2 losses between ``self.X_gen`` and ``self.z_truth``;
        * a ``y_real`` magnitude image.

        The following are always logged: the discriminator/generator losses,
        the gradient penalty, the mean discriminator logits, and a histogram
        of the input image. The merged summaries are stored on ``self.d_sum``,
        ``self.g_sum`` and ``self.train_sum``.
        """

        def _display(image):
            # Magnitude, rotated by 90 degrees and flipped up/down for viewing.
            return tf.image.flip_up_down(tf.image.rot90(tf.abs(image)))

        # Input image to generator. Note that ks is based on the input data,
        # not on the rotated data.
        self.input_image = tf_util.model_transpose(self.ks, self.sensemap)

        # Compare by value: ``is`` tests object identity and can be False for
        # equal strings that are not the same object (e.g. read from flags).
        if self.data_type == "knee":
            truth_image = tf_util.channels_to_complex(self.z_truth)

            train_out = tf.concat(
                (_display(self.input_image), _display(self.output_image),
                 _display(truth_image)),
                axis=2)
            tf.summary.image("input-output-truth", train_out)

            residual = tf.abs(self.X_gen - self.z_truth)
            loss_l1 = tf.reduce_mean(residual)
            loss_l2 = tf.reduce_mean(tf.square(residual))
            tf.summary.scalar("l1", loss_l1)
            tf.summary.scalar("l2", loss_l2)

            # to check supervised/unsupervised
            y_real = tf_util.channels_to_complex(self.Y_real)
            tf.summary.image("y_real/mag", _display(y_real))

        # Plot losses
        self.d_loss_sum = tf.summary.scalar("Discriminator_loss", self.d_loss)
        self.g_loss_sum = tf.summary.scalar("Generator_loss", self.g_loss)
        self.gp_sum = tf.summary.scalar("Gradient_penalty",
                                        self.gradient_penalty)
        self.d_fake = tf.summary.scalar("subloss/D_fake",
                                        tf.reduce_mean(self.d_logits_fake))
        self.d_real = tf.summary.scalar("subloss/D_real",
                                        tf.reduce_mean(self.d_logits_real))
        self.z_sum = tf.summary.histogram(
            "z", tf_util.complex_to_channels(self.input_image))
        self.d_sum = tf.summary.merge(
            [self.z_sum, self.d_loss_sum, self.d_fake, self.d_real])
        self.g_sum = tf.summary.merge([self.z_sum, self.g_loss_sum])
        self.train_sum = tf.summary.merge_all()
def unroll_fista(
    ks_input,
    sensemap,
    num_grad_steps=5,
    resblock_num_features=128,
    resblock_num_blocks=2,
    is_training=True,
    scope="MRI",
    mask_output=1,
    window=None,
    do_hardproj=True,
    num_summary_image=0,
    mask=None,
    verbose=False,
    conv="real",
    do_conjugate=False,
    activation="relu",
):
    """Create general unrolled network for MRI.

    Each unrolled iteration performs a data-consistency update followed by a
    learned proximal step (``prior_grad_res_net``):

        x_{k+1} = S( x_k - 2 * t * A^T W (A x_k - b) )
                = S( x_k - 2 * t * (A^T W A x_k - A^T W b) )

    Here A is ``tf_util.model_forward`` with the sensitivity maps followed by
    the k-space mask, W is ``window``, and b is the masked input k-space. In
    the code the factor in front of ``(A^T W A x_k - x_0)`` is a trainable
    per-iteration variable ``t`` initialised to -2.0.

    Args:
        ks_input: input k-space; it is multiplied by ``mask`` before use.
        sensemap: sensitivity maps passed to ``tf_util.model_forward`` and
            ``tf_util.model_transpose``.
        num_grad_steps: number of unrolled iterations (variable scopes
            ``iter_00``, ``iter_01``, ...).
        resblock_num_features: ``num_features`` for ``prior_grad_res_net``.
        resblock_num_blocks: ``num_blocks`` for ``prior_grad_res_net``.
        is_training: forwarded as ``training`` to ``prior_grad_res_net``.
        scope: variable scope wrapping the whole network.
        mask_output: if not None, multiplies the final k-space after the hard
            data projection (only used when ``do_hardproj`` is True).
        window: k-space weighting applied before every adjoint; None means 1.
        do_hardproj: if True, combine the measured k-space (``mask * ks_0``)
            with the prediction (``(1 - mask) * ks_k``) and recompute the
            output image from it.
        num_summary_image: when > 0, add per-iteration summaries; also used
            as ``max_outputs`` of the "iter/image" summary.
        mask: k-space mask; computed from ``ks_input`` with
            ``tf_util.kspace_mask`` when None.
        verbose: print progress messages while building the graph.
        conv: stored in the module-level global ``type_conv``.
        do_conjugate: stored in the module-level global ``conjugate``.
        activation: forwarded to ``prior_grad_res_net``.

    Returns:
        The output image tensor (named "output_image"). The output k-space
        tensor ("output_kspace") is created in the graph but not returned.
    """
    if window is None:
        window = 1
    # Accumulates per-iteration sumofsq images side by side (axis 2).
    summary_iter = None

    # NOTE(review): conv / do_conjugate are published through module-level
    # globals instead of being passed explicitly; presumably read by helpers
    # such as prior_grad_res_net elsewhere in the module -- confirm. Every
    # call overwrites these globals.
    global type_conv
    type_conv = conv
    global conjugate
    conjugate = do_conjugate

    if verbose:
        print(
            "%s> Building FISTA unrolled network (%d steps)...."
            % (scope, num_grad_steps)
        )
        if sensemap is not None:
            print("%s>   Using sensitivity maps..." % scope)
    with tf.variable_scope(scope):
        # Default mask is derived from the input k-space itself.
        if mask is None:
            mask = tf_util.kspace_mask(ks_input, dtype=tf.complex64)
        ks_input = mask * ks_input
        ks_0 = ks_input
        # x0 = A^T W b
        im_0 = tf_util.model_transpose(ks_0 * window, sensemap)
        im_0 = tf.identity(im_0, name="input_image")
        # To be updated
        ks_k = ks_0
        im_k = im_0

        for i_step in range(num_grad_steps):
            iter_name = "iter_%02d" % i_step
            with tf.variable_scope(iter_name):
                # = S( x_k - 2 * t * (A^T W A x_k - A^T W b))
                # = S( x_k - 2 * t * (A^T W A x_k - x0))
                with tf.variable_scope("update"):
                    im_k_orig = im_k
                    # xk = A^T W A x_k (forward model, mask, weighted adjoint)
                    ks_k = tf_util.model_forward(im_k, sensemap)
                    ks_k = mask * ks_k
                    im_k = tf_util.model_transpose(ks_k * window, sensemap)
                    # xk = A^T W A x_k - x0, converted to real channels
                    im_k = tf_util.complex_to_channels(im_k - im_0)
                    im_k_orig = tf_util.complex_to_channels(im_k_orig)
                    # Update step: learned per-iteration step size
                    # (initialised to -2.0, i.e. -2 * t with t = 1).
                    t_update = tf.get_variable(
                        "t", dtype=tf.float32, initializer=tf.constant([-2.0])
                    )
                    im_k = im_k_orig + t_update * im_k

                with tf.variable_scope("prox"):
                    # Learned proximal step; num_features_out is set to the
                    # input's channel count.
                    num_channels_out = im_k.shape[-1]
                    im_k = prior_grad_res_net(
                        im_k,
                        training=is_training,
                        num_features=resblock_num_features,
                        num_blocks=resblock_num_blocks,
                        num_features_out=num_channels_out,
                        data_format="channels_last",
                        activation=activation
                    )
                    im_k = tf_util.channels_to_complex(im_k)

                im_k = tf.identity(im_k, name="image")
                if num_summary_image > 0:
                    with tf.name_scope("summary"):
                        tmp = tf_util.sumofsq(im_k, keep_dims=True)
                        if summary_iter is None:
                            summary_iter = tmp
                        else:
                            summary_iter = tf.concat(
                                (summary_iter, tmp), axis=2)
                        tf.summary.scalar("max/" + iter_name,
                                          tf.reduce_max(tmp))

        ks_k = tf_util.model_forward(im_k, sensemap)
        if do_hardproj:
            if verbose:
                print("%s>   Final hard data projection..." % scope)
            # Final data projection: measured samples where mask is set,
            # network prediction elsewhere (assumes a 0/1 mask -- confirm).
            ks_k = mask * ks_0 + (1 - mask) * ks_k
            if mask_output is not None:
                ks_k = ks_k * mask_output
            im_k = tf_util.model_transpose(ks_k * window, sensemap)

        ks_k = tf.identity(ks_k, name="output_kspace")
        im_k = tf.identity(im_k, name="output_image")

    if summary_iter is not None:
        tf.summary.image("iter/image", summary_iter,
                         max_outputs=num_summary_image)

    return im_k
Beispiel #8
0
def _format_metrics_report(total_parameters, output_psnr, output_nrmse,
                           output_ssim, cs_psnr, cs_nrmse, cs_ssim):
    r"""Format the evaluation metrics written to ``metrics.txt``.

    The first line is ``parameters = <total_parameters>``. Each following line
    is ``<label> = <mean> +\- <std>`` for one metric list, with the labels
    ``output psnr``, ``output nrmse``, ``output ssim``, ``cs psnr``,
    ``cs nrmse`` and ``cs ssim``. Every line ends with a newline.
    """
    lines = ["parameters = " + str(total_parameters)]
    for label, values in (
            ("output psnr", output_psnr),
            ("output nrmse", output_nrmse),
            ("output ssim", output_ssim),
            ("cs psnr", cs_psnr),
            # These two were previously mislabelled "output nrmse/ssim".
            ("cs nrmse", cs_nrmse),
            ("cs ssim", cs_ssim),
    ):
        # Raw string: "+\-" keeps the original separator without relying on
        # an invalid escape sequence.
        lines.append(label + " = " + str(np.mean(values)) + r" +\- " +
                     str(np.std(values)))
    return "\n".join(lines) + "\n"


def main(_):
    """Evaluate the unrolled network and a BART CS baseline on the test set.

    The model is restored from ``FLAGS.log_root/FLAGS.train_dir`` (or its
    variables are initialised if ``load`` fails). Every test file is
    reconstructed with the network and with ``bart_cs``, and running
    PSNR/NRMSE/SSIM mean and standard deviation are printed. A summary is
    written to ``metrics.txt`` in the model directory.
    """
    import sys

    if FLAGS.batch_size != 1:
        sys.stderr.write("Error: to test images, batch size must be 1\n")
        sys.exit(1)
    if not FLAGS.dataset_dir:
        raise ValueError("You must supply the dataset directory with " +
                         "--dataset_dir")

    # Must be set before the session is created: CUDA reads it when the GPU
    # devices are initialised, so setting it inside the session is too late.
    os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.device

    model_dir = os.path.join(FLAGS.log_root, FLAGS.train_dir)
    bart_dir = os.path.join(model_dir, "bart_recon")
    for directory in (FLAGS.log_root, model_dir, bart_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    if FLAGS.random_seed >= 0:
        random.seed(FLAGS.random_seed)
        np.random.seed(FLAGS.random_seed)

    tf.logging.set_verbosity(tf.logging.INFO)

    run_config = tf.ConfigProto()
    run_config.gpu_options.allow_growth = True

    with tf.Session(config=run_config) as sess:
        print("Preparing dataset...")
        out_shape = [FLAGS.shape_z, FLAGS.shape_y]

        test_dataset, num_files = mri_data.create_dataset(
            os.path.join(FLAGS.dataset_dir, "test"),
            FLAGS.mask_path,
            num_channels=FLAGS.num_channels,
            num_emaps=FLAGS.num_emaps,
            batch_size=FLAGS.batch_size,
            out_shape=out_shape,
        )
        # placeholders
        ks_shape = [None, FLAGS.shape_z, FLAGS.shape_y, FLAGS.num_channels]
        ks_place = tf.placeholder(tf.complex64, ks_shape)
        sense_shape = [
            None, FLAGS.shape_z, FLAGS.shape_y, 1, FLAGS.num_channels
        ]
        sense_place = tf.placeholder(tf.complex64, sense_shape)
        im_shape = [None, FLAGS.shape_z, FLAGS.shape_y, 1]
        im_truth_place = tf.placeholder(tf.complex64, im_shape)
        # run through unrolled
        # NOTE(review): is_training=True at test time -- confirm intended.
        im_out_place = mri_model.unroll_fista(
            ks_place,
            sense_place,
            is_training=True,
            verbose=True,
            do_hardproj=FLAGS.do_hard_proj,
            num_summary_image=FLAGS.num_summary_image,
            resblock_num_features=FLAGS.feat_map,
            num_grad_steps=FLAGS.num_grad_steps,
            conv=FLAGS.conv,
            do_conjugate=FLAGS.do_conjugate,
        )

        saver = tf.train.Saver()
        summary_writer = tf.summary.FileWriter(model_dir, sess.graph)

        # initialize model
        print("[*] initializing network...")
        if not load(model_dir, saver, sess):
            sess.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess, coord)

        # See how many parameters are in model
        total_parameters = 0
        for variable in tf.trainable_variables():
            variable_parameters = 1
            for dim in variable.get_shape():
                variable_parameters *= dim.value
            total_parameters += variable_parameters
        print("Total number of trainable parameters: %d" % total_parameters)

        test_iterator = test_dataset.make_one_shot_iterator()
        features, labels = test_iterator.get_next()

        ks_truth = labels
        ks_in = features["ks_input"]
        sense_in = features["sensemap"]
        mask_recon = features["mask_recon"]
        im_truth = tf_util.model_transpose(ks_truth * mask_recon, sense_in)

        # merge_all() returns None when no summaries exist (e.g. when
        # num_summary_image == 0), and Session.run rejects a None fetch.
        total_summary = tf.summary.merge_all()
        fetches = [im_out_place]
        if total_summary is not None:
            fetches.append(total_summary)

        output_psnr = []
        output_nrmse = []
        output_ssim = []
        cs_psnr = []
        cs_nrmse = []
        cs_ssim = []

        for _ in range(num_files):
            ks_in_run, sense_in_run, im_truth_run = sess.run(
                [ks_in, sense_in, im_truth])
            im_out = sess.run(
                fetches,
                feed_dict={
                    ks_place: ks_in_run,
                    sense_place: sense_in_run,
                    im_truth_place: im_truth_run,
                },
            )[0]

            # CS recon
            bart_test = bart_cs(bart_dir, ks_in_run, sense_in_run, l1=0.007)

            # handle batch dimension
            for b in range(FLAGS.batch_size):
                truth = im_truth_run[b, :, :, :]
                out = im_out[b, :, :, :]
                psnr, nrmse, ssim = metrics.compute_all(truth,
                                                        out,
                                                        sos_axis=-1)
                output_psnr.append(psnr)
                output_nrmse.append(nrmse)
                output_ssim.append(ssim)

            print("output mean +/ standard deviation psnr, nrmse, ssim")
            print(
                np.mean(output_psnr),
                np.std(output_psnr),
                np.mean(output_nrmse),
                np.std(output_nrmse),
                np.mean(output_ssim),
                np.std(output_ssim),
            )

            psnr, nrmse, ssim = metrics.compute_all(im_truth_run,
                                                    bart_test,
                                                    sos_axis=-1)
            cs_psnr.append(psnr)
            cs_nrmse.append(nrmse)
            cs_ssim.append(ssim)

            print("cs mean +/ standard deviation psnr, nrmse, ssim")
            print(
                np.mean(cs_psnr),
                np.std(cs_psnr),
                np.mean(cs_nrmse),
                np.std(cs_nrmse),
                np.mean(cs_ssim),
                np.std(cs_ssim),
            )
        print("End of testing loop")

        coord.request_stop()
        coord.join(threads)
        summary_writer.close()

        txt_path = os.path.join(model_dir, "metrics.txt")
        with open(txt_path, "w") as f:
            f.write(_format_metrics_report(total_parameters, output_psnr,
                                           output_nrmse, output_ssim,
                                           cs_psnr, cs_nrmse, cs_ssim))
    def measure(self, X_gen, sensemap, real_mask):
        """Simulate an undersampled acquisition of ``X_gen``; return its adjoint.

        ``X_gen`` is converted to a complex image with
        ``tf_util.channels_to_complex`` and projected to k-space with
        ``tf_util.model_forward``. Depending on ``self.data_type``, it is then
        re-sampled with a mask drawn via ``tf.random_shuffle(self.masks)``:

        * ``"DCE"``: batch dimension squeezed; one mask per frame for
          ``self.max_frames`` frames; frames concatenated on axis -2.
        * ``"DCE_2D"``: batch dimension squeezed; a single mask.
        * ``"knee"``: one mask per batch element (``self.batch_size``). The
          mask is additionally zeroed where ``|k| / max|k| <= 1e-7``.

        In all cases:
        * the mask is randomly flipped and the k-space flipped up/down;
        * both are cropped/padded to ``(self.height, self.width)``;
        * a centred 20x20 block of ones is merged into the mask;
        * the k-space is scaled using the energy of its central 5x5 region
          before masking.

        ``self.applied_mask`` is set for the DCE variants.

        Args:
            X_gen: generated image in channel format.
            sensemap: sensitivity maps.
            real_mask: unused; kept for interface compatibility.

        Returns:
            ``tf_util.model_transpose`` of the masked k-space, converted with
            ``tf_util.complex_to_channels``.

        Raises:
            ValueError: if ``self.data_type`` is not one of the above.
        """
        name = "measure"
        random_seed = 0
        shape_cal = 20  # side of the calibration block forced into the mask
        shape_sc = 5  # side of the central k-space region used for scaling

        def _random_flips(mask_x):
            # Randomly flip a (1, kz, ky) mask, then transpose to (kz, ky, 1).
            mask_x = tf.image.random_flip_up_down(mask_x, seed=random_seed)
            mask_x = tf.image.random_flip_left_right(mask_x, seed=random_seed)
            return tf.transpose(mask_x, [1, 2, 0])

        def _flip_and_resize(ks_x, mask_x):
            # Flip k-space up/down and crop/pad both to a common image size.
            ks_x = tf.image.flip_up_down(ks_x)
            ks_x = tf.image.resize_image_with_crop_or_pad(
                ks_x, self.height, self.width)
            mask_x = tf.image.resize_image_with_crop_or_pad(
                mask_x, self.height, self.width)
            return ks_x, mask_x

        def _add_calib_region(mask_x):
            # Set the central shape_cal x shape_cal block of the mask to one.
            if shape_cal <= 0:
                return mask_x
            with tf.name_scope("CalibRegion"):
                if self.verbose:
                    print("%s>  Including calib region (%d, %d)..." %
                          (name, shape_cal, shape_cal))
                mask_calib = tf.ones([shape_cal, shape_cal, 1],
                                     dtype=tf.complex64)
                mask_calib = tf.image.resize_image_with_crop_or_pad(
                    mask_calib, self.height, self.width)
                return mask_x * (1 - mask_calib) + mask_calib

        def _scale(ks_x):
            # Normalise by the mean energy of the central shape_sc x shape_sc
            # region (assumed to be fully sampled).
            center = tf.image.resize_image_with_crop_or_pad(
                ks_x, shape_sc, shape_sc)
            energy = tf.reduce_mean(tf.square(
                tf.abs(center))) * (shape_sc * shape_sc / 1e5)
            return ks_x * tf.cast(1.0 / tf.sqrt(energy), dtype=tf.complex64)

        image = tf_util.channels_to_complex(X_gen)
        kspace = tf_util.model_forward(image, sensemap)
        total_kspace = None

        # Compare by value: ``is`` on strings tests object identity, which can
        # be False for equal strings (e.g. ones read from flags or config).
        # Masking/scaling used to be indented under "if shape_cal > 0"; they
        # are unconditional now (identical while shape_cal == 20).
        if self.data_type == "DCE":
            # remove batch dimension; different mask for each frame
            kspace = tf.squeeze(kspace, axis=0)
            for f in range(self.max_frames):
                ks_x = kspace[:, :, f, :]
                mask_x = tf.random_shuffle(self.masks)
                mask_x = tf.expand_dims(mask_x[0, :, :], axis=0)
                mask_x = _random_flips(mask_x)
                ks_x, mask_x = _flip_and_resize(ks_x, mask_x)
                mask_x = _add_calib_region(mask_x)
                ks_x = tf.multiply(_scale(ks_x), mask_x)
                ks_x = tf.expand_dims(ks_x, axis=-2)
                # Overwritten every frame: ends up holding the last frame's mask.
                self.applied_mask = tf.expand_dims(mask_x, axis=0)
                if total_kspace is None:
                    total_kspace = ks_x
                else:
                    total_kspace = tf.concat([total_kspace, ks_x], axis=-2)
            total_kspace = tf.expand_dims(total_kspace, axis=0)
        elif self.data_type == "DCE_2D":
            # remove batch dimension; a single mask for the whole k-space
            ks_x = tf.squeeze(kspace, axis=0)
            mask_x = tf.random_shuffle(self.masks)
            mask_x = tf.expand_dims(mask_x[0, :, :], axis=0)
            mask_x = _random_flips(mask_x)
            ks_x, mask_x = _flip_and_resize(ks_x, mask_x)
            mask_x = _add_calib_region(mask_x)
            ks_x = tf.multiply(_scale(ks_x), mask_x)
            self.applied_mask = tf.expand_dims(mask_x, axis=0)
            total_kspace = tf.expand_dims(ks_x, axis=0)
        elif self.data_type == "knee":
            for i in range(self.batch_size):
                ks_x = kspace[i, :, :]
                mask_x = tf.random_shuffle(self.masks)
                mask_x = tf.slice(mask_x, [0, 0, 0], [1, -1, -1])
                mask_x = _random_flips(mask_x)
                ks_x, mask_x = _flip_and_resize(ks_x, mask_x)
                mask_x = _add_calib_region(mask_x)
                # Keep only locations where the k-space is (relatively) non-zero.
                mask_recon = tf.abs(ks_x) / tf.reduce_max(tf.abs(ks_x))
                mask_recon = tf.cast(mask_recon > 1e-7, dtype=tf.complex64)
                mask_x = mask_x * mask_recon
                # Masked input
                ks_x = tf.multiply(_scale(ks_x), mask_x)
                if total_kspace is None:
                    total_kspace = ks_x
                else:
                    total_kspace = tf.concat([total_kspace, ks_x], 0)
        else:
            raise ValueError("Unsupported data_type for measure: %r" %
                             (self.data_type,))

        x_measured = tf_util.model_transpose(total_kspace,
                                             sensemap,
                                             name="x_measured")
        x_measured = tf_util.complex_to_channels(x_measured)
        return x_measured
def format_metrics_report(total_parameters, named_metrics):
    """Format evaluation metrics as the text written to ``metrics.txt``.

    Args:
      total_parameters: Number of trainable parameters in the model.
      named_metrics: Iterable of ``(label, values)`` pairs. Each pair becomes
        one ``"<label> = <mean> +\\- <std>"`` line, with mean/std computed by
        ``np.mean`` / ``np.std``.

    Returns:
      The report, one entry per line, without a trailing newline.
    """
    # The "+\-" separator is kept byte-for-byte from the existing report
    # format; the backslash is escaped so the literal is not an invalid
    # escape sequence.
    plus_minus = " +\\- "
    lines = ["parameters = " + str(total_parameters)]
    for label, values in named_metrics:
        lines.append(label + " = " + str(np.mean(values)) + plus_minus +
                     str(np.std(values)))
    return "\n".join(lines)


def main(_):
    """Evaluate the unrolled network and a BART CS baseline on the test set.

    For every test batch this runs the network, runs a compressed-sensing
    reconstruction through ``bart_cs``, saves magnitude/phase PNGs and
    magnitude/phase difference arrays (``.npy``) under ``<model_dir>/images``,
    and accumulates PSNR / NRMSE / SSIM against the reference image. Mean and
    standard deviation of each metric are written to
    ``<model_dir>/metrics.txt``.

    Raises:
      ValueError: If ``--dataset_dir`` is empty or ``--conv`` is neither
        "complex" nor "real".
    """
    model_dir = os.path.join(FLAGS.log_root, FLAGS.train_dir)
    bart_dir = os.path.join(model_dir, "bart_recon")
    image_dir = os.path.join(model_dir, "images")
    for directory in (FLAGS.log_root, model_dir, bart_dir, image_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    if not FLAGS.dataset_dir:
        raise ValueError("You must supply the dataset directory with " +
                         "--dataset_dir")
    # Image saving below only handles these two layouts; fail up front rather
    # than with a NameError after the first reconstruction has run.
    if FLAGS.conv not in ("complex", "real"):
        raise ValueError("Unsupported --conv value %r; expected 'complex' "
                         "or 'real'" % (FLAGS.conv,))

    # Must be set before the session is created: CUDA reads this variable
    # when the GPUs are initialized, so setting it afterwards has no effect.
    os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.device

    if FLAGS.random_seed >= 0:
        random.seed(FLAGS.random_seed)
        np.random.seed(FLAGS.random_seed)

    tf.logging.set_verbosity(tf.logging.INFO)

    run_config = tf.ConfigProto()
    run_config.gpu_options.allow_growth = True

    with tf.Session(config=run_config) as sess:
        print("Preparing dataset...")
        out_shape = [FLAGS.shape_z, FLAGS.shape_y]

        test_dataset, num_files = mri_data.create_dataset(
            os.path.join(FLAGS.dataset_dir, "test_images"),
            FLAGS.mask_path,
            num_channels=FLAGS.num_channels,
            num_maps=FLAGS.num_emaps,
            batch_size=FLAGS.batch_size,
            out_shape=out_shape,
        )

        # Placeholders, channels last: (batch, z, y, channels).
        ks_shape = [None, FLAGS.shape_z, FLAGS.shape_y, FLAGS.num_channels]
        ks_place = tf.placeholder(tf.complex64, ks_shape)
        sense_shape = [
            None, FLAGS.shape_z, FLAGS.shape_y, 1, FLAGS.num_channels
        ]
        sense_place = tf.placeholder(tf.complex64, sense_shape)
        im_shape = [None, FLAGS.shape_z, FLAGS.shape_y, 1]
        im_truth_place = tf.placeholder(tf.complex64, im_shape)

        # Build the unrolled reconstruction network on the placeholders.
        im_out_place = mri_model.unroll_fista(
            ks_place,
            sense_place,
            is_training=True,
            verbose=True,
            do_hardproj=FLAGS.do_hard_proj,
            num_summary_image=FLAGS.num_summary_image,
            resblock_num_features=FLAGS.feat_map,
            num_grad_steps=FLAGS.num_grad_steps,
            conv=FLAGS.conv,
        )
        saver = tf.train.Saver()
        summary_writer = tf.summary.FileWriter(model_dir, sess.graph)

        # Restore a checkpoint if one exists, otherwise start from scratch.
        print("[*] initializing network...")
        if not load(model_dir, saver, sess):
            sess.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess, coord)

        # Count trainable parameters.
        total_parameters = 0
        for variable in tf.trainable_variables():
            variable_parameters = 1
            for dim in variable.get_shape():
                variable_parameters *= dim.value
            total_parameters += variable_parameters
        print("Total number of trainable parameters: %d" % total_parameters)
        tf.summary.scalar("parameters/parameters", total_parameters)

        test_iterator = test_dataset.make_one_shot_iterator()
        features, labels = test_iterator.get_next()

        ks_truth = labels
        ks_in = features["ks_input"]
        sense_in = features["sensemap"]
        mask_recon = features["mask_recon"]
        im_in = tf_util.model_transpose(ks_in * mask_recon, sense_in)
        im_truth = tf_util.model_transpose(ks_truth * mask_recon, sense_in)

        total_summary = tf.summary.merge_all()

        output_psnr, output_nrmse, output_ssim = [], [], []
        cs_psnr, cs_nrmse, cs_ssim = [], [], []

        for test_file in range(num_files):
            ks_in_run, sense_in_run, im_truth_run, im_in_run = sess.run(
                [ks_in, sense_in, im_truth, im_in])
            im_out, _ = sess.run(
                [im_out_place, total_summary],
                feed_dict={
                    ks_place: ks_in_run,
                    sense_place: sense_in_run,
                    im_truth_place: im_truth_run,
                },
            )

            # Compressed-sensing baseline reconstruction.
            bart_test = bart_cs(bart_dir, ks_in_run, sense_in_run, l1=FLAGS.l1)

            # Panels and difference images are concatenated side by side
            # along axis 2.
            if FLAGS.conv == "complex":
                panels = np.concatenate((im_out, bart_test, im_truth_run),
                                        axis=2)
                diffs = np.concatenate(
                    (im_truth_run - im_out, im_truth_run - bart_test), axis=2)
            else:  # "real" (validated above)
                panels = np.concatenate((im_in_run, im_out), axis=2)
                diffs = np.concatenate(
                    (im_truth_run - im_in_run, im_truth_run - im_out), axis=2)
            mag_images = np.squeeze(np.absolute(panels))
            phase_images = np.squeeze(np.angle(panels))
            diff_mag = np.squeeze(np.absolute(diffs))
            diff_phase = np.squeeze(np.angle(diffs))

            # NOTE: scipy.misc.imsave was removed in SciPy 1.2; this call
            # needs an older SciPy (imageio.imwrite is the usual replacement).
            scipy.misc.imsave(
                os.path.join(image_dir, "mag_" + str(test_file) + ".png"),
                mag_images)
            scipy.misc.imsave(
                os.path.join(image_dir, "phase_" + str(test_file) + ".png"),
                phase_images)
            np.save(
                os.path.join(image_dir,
                             "diff_phase_" + str(test_file) + ".npy"),
                diff_phase)
            np.save(
                os.path.join(image_dir, "diff_mag_" + str(test_file) + ".npy"),
                diff_mag)

            psnr, nrmse, ssim = metrics.compute_all(im_truth_run,
                                                    im_out,
                                                    sos_axis=-1)
            output_psnr.append(psnr)
            output_nrmse.append(nrmse)
            output_ssim.append(ssim)

            print("output psnr, nrmse, ssim")
            print(
                np.mean(output_psnr),
                np.std(output_psnr),
                np.mean(output_nrmse),
                np.std(output_nrmse),
                np.mean(output_ssim),
                np.std(output_ssim),
            )

            psnr, nrmse, ssim = metrics.compute_all(im_truth_run,
                                                    bart_test,
                                                    sos_axis=-1)
            cs_psnr.append(psnr)
            cs_nrmse.append(nrmse)
            cs_ssim.append(ssim)

            print("cs psnr, nrmse, ssim")
            print(
                np.mean(cs_psnr),
                np.std(cs_psnr),
                np.mean(cs_nrmse),
                np.std(cs_nrmse),
                np.mean(cs_ssim),
                np.std(cs_ssim),
            )
        print("End of testing loop")

        report = format_metrics_report(total_parameters, [
            ("output psnr", output_psnr),
            ("output nrmse", output_nrmse),
            ("output ssim", output_ssim),
            ("cs psnr", cs_psnr),
            ("cs nrmse", cs_nrmse),
            ("cs ssim", cs_ssim),
        ])
        with open(os.path.join(model_dir, "metrics.txt"), "w") as f:
            f.write(report)

        # Stop any queue-runner threads and flush the event file.
        coord.request_stop()
        coord.join(threads)
        summary_writer.close()
Beispiel #11
0
    def measure(self, X_gen, sensemap, real_mask):
        """Simulate an undersampled acquisition of the generator output.

        The channel-form images are converted to complex, mapped to k-space
        with ``tf_util.model_forward``, multiplied per batch element by a
        randomly generated (DCE/mfast) or randomly selected (knee) sampling
        mask, and mapped back to image space with ``tf_util.model_transpose``.

        Args:
          X_gen: Generator output in real/imaginary channel form.
          sensemap: Sensitivity maps for the forward/adjoint model.
          real_mask: Unused; kept for interface compatibility.

        Returns:
          The masked ("measured") images in channel form.

        Raises:
          ValueError: If ``self.data_type`` is not "DCE", "DCE_2D", "mfast"
            or "knee".
        """
        name = "measure"
        random_seed = 0
        # Compare strings by value: ``is`` tests object identity and is False
        # for an equal string built at runtime (e.g. parsed from flags).
        is_dce = self.data_type in ("DCE", "DCE_2D", "mfast")
        is_knee = self.data_type == "knee"
        if not (is_dce or is_knee):
            raise ValueError("measure: unsupported data_type %r" %
                             (self.data_type,))

        image = tf_util.channels_to_complex(X_gen)
        kspace = tf_util.model_forward(image, sensemap)

        total_kspace = None
        if is_dce:
            if self.verbose:
                print("DCE measure")
            for i in range(self.batch_size):
                if self.dims == 4:
                    ks_x = kspace[i, :, :]
                else:
                    # 2D plus time
                    ks_x = kspace[i, :, :, :]

                # Data dimensions.
                shape_y = self.width
                shape_t = self.max_frames
                sim_partial_ky = 0.0

                # Acceleration drawn uniformly from [accs[0], accs[1]).
                accs = [1, 6]
                rand_accel = ((accs[1] - accs[0]) * tf.random_uniform([]) +
                              accs[0])
                # Positional arguments: ny, nt, accel, ncal, vd_degree, and
                # the simulated partial-ky fraction.
                fn_inputs = [
                    shape_y,
                    shape_t,
                    rand_accel,
                    10,
                    2.0,
                    sim_partial_ky,
                ]
                mask_x = tf.py_func(mask.generate_perturbed2dvdkt, fn_inputs,
                                    tf.complex64)
                ks_x = ks_x * tf.reshape(mask_x, [1, shape_y, shape_t, 1])

                if total_kspace is not None:
                    total_kspace = tf.concat([total_kspace, ks_x], 0)
                else:
                    total_kspace = ks_x
        else:  # knee
            if self.verbose:
                print("knee")
            for i in range(self.batch_size):
                ks_x = kspace[i, :, :]
                if self.verbose:
                    print("ks_x", ks_x)
                # Randomly select one of the stored masks.
                mask_x = tf.random_shuffle(self.masks)
                if self.verbose:
                    print("self masks", mask_x)
                mask_x = tf.slice(mask_x, [0, 0, 0], [1, -1, -1])
                if self.verbose:
                    print("sliced masks", mask_x)
                # Augment sampling masks.
                mask_x = tf.image.random_flip_up_down(mask_x, seed=random_seed)
                mask_x = tf.image.random_flip_left_right(mask_x,
                                                         seed=random_seed)
                # Transpose to store data as (kz, ky, channels).
                mask_x = tf.transpose(mask_x, [1, 2, 0])
                if self.verbose:
                    print("transposed mask", mask_x)
                ks_x = tf.image.flip_up_down(ks_x)
                # Bring every example to the same image size.
                ks_x = tf.image.resize_image_with_crop_or_pad(
                    ks_x, self.height, self.width)
                mask_x = tf.image.resize_image_with_crop_or_pad(
                    mask_x, self.height, self.width)
                if self.verbose:
                    print("resized mask", mask_x)
                shape_cal = 20
                if shape_cal > 0:
                    with tf.name_scope("CalibRegion"):
                        if self.verbose:
                            print("%s>  Including calib region (%d, %d)..." %
                                  (name, shape_cal, shape_cal))
                        # Force a fully sampled shape_cal x shape_cal centre.
                        mask_calib = tf.ones([shape_cal, shape_cal, 1],
                                             dtype=tf.complex64)
                        mask_calib = tf.image.resize_image_with_crop_or_pad(
                            mask_calib, self.height, self.width)
                        mask_x = mask_x * (1 - mask_calib) + mask_calib

                    # Only keep sample locations where k-space is non-zero.
                    mask_recon = tf.abs(ks_x) / tf.reduce_max(tf.abs(ks_x))
                    mask_recon = tf.cast(mask_recon > 1e-7, dtype=tf.complex64)
                    mask_x = mask_x * mask_recon
                    if self.verbose:
                        print("mask x", mask_x)
                    # Normalize using the central shape_sc x shape_sc region,
                    # assuming the calibration region is fully sampled.
                    shape_sc = 5
                    scale = tf.image.resize_image_with_crop_or_pad(
                        ks_x, shape_sc, shape_sc)
                    scale = tf.reduce_mean(tf.square(
                        tf.abs(scale))) * (shape_sc * shape_sc / 1e5)
                    scale = tf.cast(1.0 / tf.sqrt(scale), dtype=tf.complex64)
                    ks_x = ks_x * scale
                    # Masked input.
                    ks_x = tf.multiply(ks_x, mask_x)
                # NOTE(review): ks_x here has no batch axis (it was sliced with
                # kspace[i]), so concatenating on axis 0 appears to stack along
                # the first spatial axis unless batch_size == 1 — confirm
                # against what model_transpose expects.
                if total_kspace is not None:
                    total_kspace = tf.concat([total_kspace, ks_x], 0)
                else:
                    total_kspace = ks_x

        x_measured = tf_util.model_transpose(total_kspace,
                                             sensemap,
                                             name="x_measured")
        x_measured = tf_util.complex_to_channels(x_measured)
        return x_measured
Beispiel #12
0
    def generator(self, ks_input, sensemap, reuse=False):
        """Reconstruct images from (undersampled) k-space.

        For 2D data (``self.dims == 4``) this uses either the unrolled
        network (``self.arch == "unrolled"``) or a small residual CNN applied
        to the zero-filled adjoint image. For 3D data only the unrolled
        network is available.

        Args:
          ks_input: Input k-space; its sampling mask is derived with
            ``tf_util.kspace_mask``.
          sensemap: Sensitivity maps for the forward/adjoint model.
          reuse: If True, reuse the variables of the "generator" scope.

        Returns:
          The reconstructed images in real/imaginary channel form.

        Raises:
          NotImplementedError: If the data is 3D (``self.dims != 4``) and
            ``self.arch`` is not "unrolled".
        """
        mask_example = tf_util.kspace_mask(ks_input, dtype=tf.complex64)
        with tf.variable_scope("generator") as scope:
            if reuse:
                scope.reuse_variables()
            if self.dims == 4:
                # 2D data: batch, height, width, channels
                if self.arch == "unrolled":
                    c_out = unrolled.unroll_fista(
                        ks_input,
                        sensemap,
                        num_grad_steps=self.iterations,
                        resblock_num_features=self.g_dim,
                        resblock_num_blocks=self.res_blocks,
                        is_training=True,
                        scope="MRI",
                        mask_output=1,
                        window=None,
                        do_hardproj=True,
                        num_summary_image=6,
                        mask=mask_example,
                        verbose=False,
                    )
                    c_out = tf_util.complex_to_channels(c_out)
                else:
                    # Residual CNN on the zero-filled adjoint image.
                    z = tf_util.model_transpose(ks_input, sensemap)
                    z = tf_util.complex_to_channels(z)
                    res_size = self.g_dim
                    kernel_size = 3
                    num_channels = 2
                    num_blocks = 5
                    c = tf.layers.conv2d(
                        z,
                        res_size,
                        kernel_size,
                        padding="same",
                        activation=tf.nn.relu,
                        use_bias=True,
                    )
                    for _ in range(num_blocks):
                        c = tf.layers.conv2d(
                            c,
                            res_size,
                            kernel_size,
                            padding="same",
                            activation=tf.nn.relu,
                            use_bias=True,
                        )
                        c1 = tf.layers.conv2d(
                            c,
                            res_size,
                            kernel_size,
                            padding="same",
                            activation=tf.nn.relu,
                            use_bias=True,
                        )
                        c = tf.add(c, c1)
                    c8 = tf.layers.conv2d(
                        c,
                        num_channels,
                        kernel_size,
                        padding="same",
                        activation=None,
                        use_bias=True,
                    )
                    # Global skip connection, then squash to (-1, 1).
                    c_out = tf.add(c8, z)
                    c_out = tf.nn.tanh(c_out)
            else:
                # 3D data: batch, height, width, channels, time frames
                if self.arch != "unrolled":
                    # Previously fell through and failed with
                    # UnboundLocalError on ``c_out``.
                    raise NotImplementedError(
                        "generator: arch %r is not supported for 3D data; "
                        "only 'unrolled' is implemented" % (self.arch,))
                c_out = unrolled_3d.unroll_fista(
                    ks_input,
                    sensemap,
                    num_grad_steps=self.iterations,
                    num_features=self.g_dim,
                    num_resblocks=self.res_blocks,
                    is_training=True,
                    scope="MRI",
                    mask_output=1,
                    window=None,
                    do_hardproj=True,
                    mask=mask_example,
                    verbose=False,
                    data_format="channels_last",
                    do_separable=self.do_separable,
                )
                c_out = tf_util.complex_to_channels(c_out)
            return c_out
Beispiel #13
0
    def build_model(self, mode):
        """Build the WGAN-GP graph: generator, critic, losses and optimizers.

        Reads real images via ``self.read_real()`` and undersampled inputs via
        ``mri_utils.Iterator``, reconstructs them with ``self.generator``, and
        trains the critic (``self.discriminator``) with a Wasserstein loss plus
        a gradient penalty of weight 10. Also creates summaries, a step
        counter, the saver and the summary writer, then calls
        ``self.initialize_model()``.

        Args:
          mode: Dataset split passed to ``mri_utils.Iterator`` ("train" or
            "test").
        """
        self.Y_real, self.data_num, real_mask = self.read_real()

        # Read in train or test input images to be reconstructed by generator.
        train_iterator = mri_utils.Iterator(
            self.batch_size,
            self.mask_path,
            self.data_type,
            mode,  # either train or test
            self.out_shape,
            verbose=self.verbose,
            train_acc=self.train_acc,
            data_dir=self.data_dir)
        self.input_files = train_iterator.num_files
        train_dataset = train_iterator.iterator.get_next()
        ks_truth = train_dataset["ks_truth"]
        ks_input = train_dataset["ks_input"]
        sensemap = train_dataset["sensemap"]
        image_truth = tf_util.model_transpose(ks_truth, sensemap)
        self.complex_truth = image_truth

        image_truth = tf_util.complex_to_channels(image_truth)
        self.z_truth = image_truth
        self.ks = ks_input
        self.sensemap = sensemap

        # Generate image X_gen; no measurement step in the supervised setup.
        self.X_gen = self.generator(self.ks, self.sensemap)
        self.Y_fake = self.X_gen

        # Critic outputs for fake and real images.
        self.d_logits_fake = self.discriminator(self.Y_fake, reuse=False)
        self.d_logits_real = self.discriminator(self.Y_real, reuse=True)

        # Wasserstein critic and generator losses.
        self.d_loss = tf.reduce_mean(self.d_logits_fake) - tf.reduce_mean(
            self.d_logits_real)
        self.g_loss = -tf.reduce_mean(self.d_logits_fake)

        # Gradient penalty on random interpolates between real and fake.
        # epsilon has one value per example, broadcast over all remaining
        # axes, so this works for 4-D (2D) and 5-D (2D + time) tensors alike.
        ndims = self.Y_real.shape.ndims
        self.epsilon = tf.random_uniform(
            shape=[self.batch_size] + [1] * (ndims - 1),
            minval=0.0,
            maxval=1.0)
        Y_hat = self.Y_real + self.epsilon * (self.Y_fake - self.Y_real)
        D_Y_hat = self.discriminator(Y_hat, reuse=True)
        grad_D_Y_hat = tf.gradients(D_Y_hat, [Y_hat])[0]
        slopes = tf.sqrt(
            tf.reduce_sum(tf.square(grad_D_Y_hat),
                          axis=list(range(1, Y_hat.shape.ndims))))
        self.gradient_penalty = tf.reduce_mean((slopes - 1.0)**2)
        self.d_loss = self.d_loss + 10.0 * self.gradient_penalty

        train_vars = tf.trainable_variables()
        for v in train_vars:
            tf.add_to_collection("reg_loss", tf.nn.l2_loss(v))
        self.generator_vars = [v for v in train_vars if "generator" in v.name]
        self.discriminator_vars = [
            v for v in train_vars if "discriminator" in v.name
        ]

        self.g_optimizer = tf.train.AdamOptimizer(
            learning_rate=self.lr,
            name="g_opt",
            beta1=self.beta1,
            beta2=self.beta2).minimize(self.g_loss,
                                       var_list=self.generator_vars)
        self.d_optimizer = tf.train.AdamOptimizer(
            learning_rate=self.lr,
            name="d_opt",
            beta1=self.beta1,
            beta2=self.beta2).minimize(self.d_loss,
                                       var_list=self.discriminator_vars)

        self.output_image = tf_util.channels_to_complex(self.X_gen)
        self.im_out = self.output_image
        self.mag_output = tf.abs(self.output_image)
        self.create_summary()

        with tf.variable_scope("counter"):
            self.counter = tf.get_variable(
                "counter",
                shape=[1],
                initializer=tf.constant_initializer([1]),
                dtype=tf.int32,
            )
            self.update_counter = tf.assign(self.counter,
                                            tf.add(self.counter, 1))
        self.saver = tf.train.Saver()
        self.summary_writer = tf.summary.FileWriter(self.log_dir,
                                                    self.sess.graph)
        self.initialize_model()
def unroll_fista(
    ks_input,
    sensemap,
    num_grad_steps=5,
    resblock_num_features=128,
    resblock_num_blocks=2,
    is_training=True,
    scope="MRI",
    mask_output=1,
    window=None,
    do_hardproj=True,
    num_summary_image=0,
    mask=None,
    verbose=False,
):
    """Create general unrolled network for MRI (2D).

    Each unrolled iteration computes
        x_{k+1} = S( x_k - 2 * t * A^T W (A x_k - b) )
                = S( x_k - 2 * t * (A^T W A x_k - A^T W b) )
    with a learned step variable "t" (initialized to -2.0 and applied as
    ``x_k + t * (A^T A x_k - x_0)``) and a learned residual prior S
    (``prior_grad_res_net``).

    Args:
      ks_input: Measured k-space.
      sensemap: Sensitivity maps for ``model_forward``/``model_transpose``.
      num_grad_steps: Number of unrolled iterations.
      resblock_num_features: Feature maps in the prior network.
      resblock_num_blocks: Residual blocks in the prior network.
      is_training: Passed to the prior network.
      scope: Variable scope name.
      mask_output: Multiplier applied to k-space after the final hard
        projection (ignored if None or ``do_hardproj`` is False).
      window: Multiplier applied to k-space before every adjoint; defaults
        to 1.
      do_hardproj: If True, replace sampled k-space locations of the output
        by the measured data.
      num_summary_image: If > 0, per-iteration sum-of-squares images are
        computed (they are currently not emitted as summaries).
      mask: Sampling mask; derived from ``ks_input`` when None.
      verbose: Print construction progress.

    Returns:
      The reconstructed complex image.
    """
    # Place all iterations on the last visible GPU. Without a GPU fall back to
    # default placement (tf.device(None) clears inherited device scopes)
    # instead of failing with IndexError.
    local_device_protos = device_lib.list_local_devices()
    device_list = [x.name for x in local_device_protos if x.device_type == "GPU"]
    cur_device = device_list[-1] if device_list else None

    if window is None:
        window = 1
    summary_iter = None

    if verbose:
        print(
            "%s> Building FISTA unrolled network (%d steps)...."
            % (scope, num_grad_steps)
        )
        if sensemap is not None:
            print("%s>   Using sensitivity maps..." % scope)

    with tf.variable_scope(scope):
        if mask is None:
            mask = tf_util.kspace_mask(ks_input, dtype=tf.complex64)
        ks_input = mask * ks_input
        ks_0 = ks_input
        # x0 = A^T W b
        im_0 = tf_util.model_transpose(ks_0 * window, sensemap)
        im_0 = tf.identity(im_0, name="input_image")
        # To be updated
        im_k = im_0

        for i_step in range(num_grad_steps):
            iter_name = "iter_%02d" % i_step

            with tf.device(cur_device), tf.variable_scope(iter_name):
                # = S( x_k - 2 * t * (A^T W A x_k - A^T W b))
                # = S( x_k - 2 * t * (A^T W A x_k - x0))
                with tf.variable_scope("update"):
                    im_k_orig = im_k

                    # xk = A^T A x_k
                    ks_k = tf_util.model_forward(im_k, sensemap)
                    ks_k = mask * ks_k
                    im_k = tf_util.model_transpose(ks_k * window, sensemap)
                    # xk = A^T A x_k - A^T b
                    im_k = tf_util.complex_to_channels(im_k - im_0)
                    im_k_orig = tf_util.complex_to_channels(im_k_orig)
                    # Update step
                    t_update = tf.get_variable(
                        "t", dtype=tf.float32, initializer=tf.constant([-2.0])
                    )
                    im_k = im_k_orig + t_update * im_k

                with tf.variable_scope("prox"):
                    num_channels_out = im_k.shape[-1]
                    # The prior runs channels_first; transpose in and out.
                    im_k = tf.transpose(im_k, [0, 3, 1, 2])
                    im_k = prior_grad_res_net(
                        im_k,
                        training=is_training,
                        num_features=resblock_num_features,
                        num_blocks=resblock_num_blocks,
                        num_features_out=num_channels_out,
                        data_format="channels_first",
                    )
                    im_k = tf.transpose(im_k, [0, 2, 3, 1])

                    im_k = tf_util.channels_to_complex(im_k)

                im_k = tf.identity(im_k, name="image")
                if num_summary_image > 0:
                    with tf.name_scope("summary"):
                        tmp = tf_util.sumofsq(im_k, keep_dims=True)
                        if summary_iter is None:
                            summary_iter = tmp
                        else:
                            summary_iter = tf.concat((summary_iter, tmp), axis=2)

        ks_k = tf_util.model_forward(im_k, sensemap)
        if do_hardproj:
            if verbose:
                print("%s>   Final hard data projection..." % scope)
            # Final data projection: keep measured samples where mask == 1.
            ks_k = mask * ks_0 + (1 - mask) * ks_k
            if mask_output is not None:
                ks_k = ks_k * mask_output
            im_k = tf_util.model_transpose(ks_k * window, sensemap)

        ks_k = tf.identity(ks_k, name="output_kspace")
        im_k = tf.identity(im_k, name="output_image")

    return im_k
Beispiel #15
0
def unroll_fista(
    ks_input,
    sensemap,
    scope="MRI",
    num_grad_steps=5,
    num_resblocks=4,
    num_features=64,
    kernel_size=None,
    is_training=True,
    mask_output=1,
    mask=None,
    window=None,
    do_hardproj=False,
    do_dense=False,
    do_separable=False,
    do_rnn=False,
    do_circular=True,
    batchnorm=False,
    leaky=False,
    fix_update=False,
    data_format="channels_first",
    verbose=False,
):
    """Create general unrolled network for MRI (3D / 2D + time).

    Each unrolled iteration computes
        x_{k+1} = S( x_k - 2 * t * A^T W (A x_k - b) )
                = S( x_k - 2 * t * (A^T W A x_k - A^T W b) )
    with step "t" (a learned variable initialized to -2.0, or fixed at -2.0
    when ``fix_update``) and a learned residual prior S
    (``prior_grad_res_net``).

    Args:
      ks_input: Measured k-space.
      sensemap: Sensitivity maps for ``model_forward``/``model_transpose``.
      scope: Variable scope name.
      num_grad_steps: Number of unrolled iterations.
      num_resblocks: Residual blocks in the prior network.
      num_features: Feature maps in the prior network.
      kernel_size: 3-element convolution kernel size; defaults to [3, 3, 5].
      is_training: Passed to the prior network.
      mask_output: Multiplier applied to k-space after the final hard
        projection (ignored if None or ``do_hardproj`` is False).
      mask: Sampling mask; derived from ``ks_input`` when None.
      window: Multiplier applied to k-space before every adjoint; defaults
        to 1.
      do_hardproj: If True, replace sampled k-space locations of the output
        by the measured data.
      do_dense: Concatenate prior features from earlier iterations.
      do_separable, do_circular, batchnorm, leaky: Passed to the prior.
      do_rnn: Share one set of weights across all iterations.
      fix_update: Use a constant step instead of a learned one.
      data_format: Layout used inside the prior ("channels_first" or
        "channels_last"); inputs/outputs are channels last either way.
      verbose: Print construction progress.

    Returns:
      The reconstructed complex image.
    """
    # None default avoids sharing one mutable list across calls.
    if kernel_size is None:
        kernel_size = [3, 3, 5]
    if window is None:
        window = 1
    # Axis used to stack dense-connection features onto the prior input.
    channel_axis = 1 if data_format == "channels_first" else -1

    if verbose:
        print("%s> Building FISTA unrolled network...." % scope)
        print("%s>   Num of gradient steps: %d" % (scope, num_grad_steps))
        print(
            "%s>   Prior: %d ResBlocks, %d features"
            % (scope, num_resblocks, num_features)
        )
        print("%s>   Kernel size: [%d x %d x %d]" % ((scope,) + tuple(kernel_size)))
        if do_rnn:
            print("%s>   Sharing weights across iterations..." % scope)
        if sensemap is not None:
            print("%s>   Using sensitivity maps..." % scope)
        if do_dense:
            print("%s>   Inserting dense connections..." % scope)
        if do_circular:
            print("%s>   Using circular padding..." % scope)
        if do_separable:
            print("%s>   Using depth-wise separable convolutions..." % scope)
        if not batchnorm:
            print("%s>   Turning off batch normalization..." % scope)

    with tf.variable_scope(scope):
        if mask is None:
            mask = tf_util.kspace_mask(ks_input, dtype=tf.complex64)
        ks_input = mask * ks_input
        ks_0 = ks_input
        # x0 = A^T W b
        im_0 = tf_util.model_transpose(ks_0 * window, sensemap)
        im_0 = tf.identity(im_0, name="input_image")
        # To be updated
        im_k = im_0
        im_dense = None

        for i_step in range(num_grad_steps):
            iter_name = "iter_%02d" % i_step
            # With do_rnn every iteration reuses the same "iter" scope.
            scope_name = "iter" if do_rnn else iter_name

            with tf.variable_scope(
                scope_name, reuse=(tf.AUTO_REUSE if do_rnn else False)
            ):
                with tf.variable_scope("update"):
                    # = S( x_k - 2 * t * (A^T W A x_k - A^T W b))
                    # = S( x_k - 2 * t * (A^T W A x_k - x0))
                    im_k_orig = im_k
                    # xk = A^T A x_k
                    ks_k = tf_util.model_forward(im_k, sensemap)
                    ks_k = mask * ks_k
                    im_k = tf_util.model_transpose(ks_k * window, sensemap)
                    # xk = A^T A x_k - A^T b
                    im_k = tf_util.complex_to_channels(im_k - im_0)
                    im_k_orig = tf_util.complex_to_channels(im_k_orig)
                    # Update step
                    if fix_update:
                        t_update = -2.0
                    else:
                        t_update = tf.get_variable(
                            "t", dtype=tf.float32, initializer=tf.constant([-2.0])
                        )
                    im_k = im_k_orig + t_update * im_k

                with tf.variable_scope("prox"):
                    # im_k is channels_last at this point.
                    num_channels_out = im_k.shape[-1]
                    if data_format == "channels_first":
                        im_k = tf.transpose(im_k, [0, 4, 1, 2, 3])

                    if im_dense is not None:
                        im_k = tf.concat([im_k, im_dense], axis=channel_axis)

                    im_k, im_dense_k = prior_grad_res_net(
                        im_k,
                        training=is_training,
                        num_features=num_features,
                        num_blocks=num_resblocks,
                        num_features_out=num_channels_out,
                        kernel_size=kernel_size,
                        data_format=data_format,
                        circular=do_circular,
                        separable=do_separable,
                        batchnorm=batchnorm,
                        leaky=leaky,
                    )

                    if do_dense:
                        if im_dense is not None:
                            im_dense = tf.concat([im_dense, im_dense_k],
                                                 axis=channel_axis)
                        else:
                            im_dense = im_dense_k

                    if data_format == "channels_first":
                        im_k = tf.transpose(im_k, [0, 2, 3, 4, 1])

                    im_k = tf_util.channels_to_complex(im_k)

                im_k = tf.identity(im_k, name="image")

        ks_k = tf_util.model_forward(im_k, sensemap)
        if do_hardproj:
            if verbose:
                print("%s>   Final hard data projection..." % scope)
            # Keep measured samples where mask == 1.
            ks_k = mask * ks_0 + (1 - mask) * ks_k
            if mask_output is not None:
                ks_k = ks_k * mask_output
            im_k = tf_util.model_transpose(ks_k * window, sensemap)

        ks_k = tf.identity(ks_k, name="output_kspace")
        im_k = tf.identity(im_k, name="output_image")

    return im_k
def main(_):
    """Train the unrolled FISTA reconstruction network.

    Reads training TFRecords from ``FLAGS.dataset_dir``/train, feeds each
    batch through placeholders into ``mri_model.unroll_fista``, minimises the
    L1 distance between the network output and the ground-truth image with
    Adam, and writes TensorBoard summaries and checkpoints to
    ``FLAGS.log_root``/``FLAGS.train_dir``. Training starts from the step
    stored in the ``counter`` variable, which is restored together with the
    other variables when ``load`` finds a checkpoint.

    Args:
        _: Unused positional argument (argv) supplied by the app runner.

    Raises:
        ValueError: If ``FLAGS.dataset_dir`` is empty.
    """
    # Fail fast, before any directories are created or GPUs are claimed.
    if not FLAGS.dataset_dir:
        raise ValueError(
            "You must supply the dataset directory with " + "--dataset_dir"
        )

    # CUDA reads CUDA_VISIBLE_DEVICES when it initialises, which happens when
    # the first tf.Session is created; setting it inside the session block has
    # no effect, so it must be set before the session is opened.
    os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.device

    # path where model checkpoints and summaries will be saved
    model_dir = os.path.join(FLAGS.log_root, FLAGS.train_dir)
    if not os.path.exists(FLAGS.log_root):
        os.makedirs(FLAGS.log_root)
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    if FLAGS.random_seed >= 0:
        random.seed(FLAGS.random_seed)
        np.random.seed(FLAGS.random_seed)
        # Graph-level seed so TF random ops (e.g. variable initializers) are
        # reproducible as well; must be set before the graph is built.
        tf.set_random_seed(FLAGS.random_seed)

    tf.logging.set_verbosity(tf.logging.INFO)

    # A single ConfigProto carries every session option. Previously
    # allow_growth was set on a config that was never passed to tf.Session.
    run_config = tf.ConfigProto(allow_soft_placement=True,
                                log_device_placement=True)
    run_config.gpu_options.allow_growth = True

    with tf.Session(config=run_config) as sess:
        print("Preparing dataset...")
        out_shape = [FLAGS.shape_z, FLAGS.shape_y]
        train_dataset, num_files = mri_data.create_dataset(
            os.path.join(FLAGS.dataset_dir, "train"),
            FLAGS.mask_path,
            num_channels=FLAGS.num_channels,
            num_emaps=FLAGS.num_emaps,
            batch_size=FLAGS.batch_size,
            out_shape=out_shape,
        )

        # channels last format: batch, z, y, channels
        # placeholders
        ks_shape = [None, FLAGS.shape_z, FLAGS.shape_y, FLAGS.num_channels]
        ks_place = tf.placeholder(tf.complex64, ks_shape)
        sense_shape = [None, FLAGS.shape_z,
                       FLAGS.shape_y, 1, FLAGS.num_channels]
        sense_place = tf.placeholder(tf.complex64, sense_shape)
        im_shape = [None, FLAGS.shape_z, FLAGS.shape_y, 1]
        im_truth_place = tf.placeholder(tf.complex64, im_shape)

        # run through unrolled model
        im_out_place = mri_model.unroll_fista(
            ks_place,
            sense_place,
            is_training=True,
            verbose=True,
            do_hardproj=FLAGS.do_hard_proj,
            num_summary_image=FLAGS.num_summary_image,
            resblock_num_features=FLAGS.feat_map,
            num_grad_steps=FLAGS.num_grad_steps,
            conv=FLAGS.conv,
            do_conjugate=FLAGS.do_conjugate,
            activation=FLAGS.activation
        )

        # tensorboard summary function
        _create_summary(sense_place, ks_place, im_out_place, im_truth_place)

        # define L1 loss between output and ground truth
        loss = tf.reduce_mean(tf.abs(im_out_place - im_truth_place), name="l1")
        # registered in the default summary collection; picked up by merge_all
        tf.summary.scalar("loss/l1", loss)

        # optimize using Adam
        optimizer = tf.train.AdamOptimizer(
            learning_rate=FLAGS.learning_rate,
            name="opt",
            beta1=FLAGS.adam_beta1,
            beta2=FLAGS.adam_beta2,
        ).minimize(loss)

        # counter for saving checkpoints; persisted so training can resume
        with tf.variable_scope("counter"):
            counter = tf.get_variable(
                "counter",
                shape=[1],
                initializer=tf.constant_initializer([0]),
                dtype=tf.int32,
            )
            update_counter = tf.assign(counter, tf.add(counter, 1))

        saver = tf.train.Saver()
        summary_writer = tf.summary.FileWriter(model_dir, sess.graph)

        # initialize model (only when no checkpoint could be restored)
        print("[*] initializing network...")
        if not load(model_dir, saver, sess):
            sess.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess, coord)

        try:
            # calculate number of parameters in model
            total_parameters = 0
            for variable in tf.trainable_variables():
                variable_parameters = 1
                for dim in variable.get_shape():
                    variable_parameters *= dim.value
                total_parameters += variable_parameters
            print("Total number of trainable parameters: %d" % total_parameters)
            tf.summary.scalar("parameters/parameters", total_parameters)

            # use iterator to go through TFrecord dataset
            train_iterator = train_dataset.make_one_shot_iterator()
            features, labels = train_iterator.get_next()

            ks_truth = labels  # ground truth kspace
            ks_in = features["ks_input"]  # input kspace
            sense_in = features["sensemap"]  # sensitivity maps
            mask_recon = features["mask_recon"]  # reconstruction mask

            # ground truth kspace to image domain
            im_truth = tf_util.model_transpose(ks_truth * mask_recon, sense_in)

            # gather summaries for tensorboard
            total_summary = tf.summary.merge_all()

            # counter has shape [1]; take its element explicitly instead of
            # relying on implicit size-1 ndarray -> int conversion.
            start_step = int(sess.run(counter)[0])
            print("Start from step %d." % start_step)
            for step in range(start_step, FLAGS.max_steps):
                # evaluate input kspace, sensitivity maps, ground truth image
                ks_in_run, sense_in_run, im_truth_run = sess.run(
                    [ks_in, sense_in, im_truth]
                )
                # run optimizer and collect tensorboard summary
                total_summary_run, _ = sess.run(
                    [total_summary, optimizer],
                    feed_dict={
                        ks_place: ks_in_run,
                        sense_place: sense_in_run,
                        im_truth_place: im_truth_run,
                    },
                )
                print("step", step)
                # add summary to tensorboard
                summary_writer.add_summary(total_summary_run, step)

                # Advance the persisted counter BEFORE checkpointing, so a
                # restored run resumes at the next step instead of repeating
                # a step whose weight update is already in the checkpoint.
                sess.run(update_counter)

                # save checkpoint every 500 steps
                if step % 500 == 0:
                    print("saving checkpoint")
                    saver.save(sess, os.path.join(model_dir, "model.ckpt"))

            # Persist the final weights; otherwise up to 499 steps since the
            # last periodic checkpoint would be lost.
            print("saving checkpoint")
            saver.save(sess, os.path.join(model_dir, "model.ckpt"))
        finally:
            # Stop and join any queue-runner threads and flush pending
            # summaries, even if training raised.
            coord.request_stop()
            coord.join(threads)
            summary_writer.close()
        print("End of training loop")
    def build_model(self, mode):
        """Build the unsupervised WGAN-GP reconstruction graph.

        Reads real undersampled measurements via ``self.read_real()`` and the
        k-space / sensitivity-map batches to reconstruct via
        ``mri_utils.Iterator``. It then wires up:

        - the generator (``self.generator``);
        - the measurement operator (``self.measure``);
        - the discriminator (``self.discriminator``);
        - the WGAN losses with a gradient penalty;
        - one Adam optimizer per network;
        - summaries (``self.create_summary``) and a step counter;
        - a ``tf.train.Saver`` and a ``tf.summary.FileWriter``.

        Finally it calls ``self.initialize_model()``.

        Args:
            mode: When "test", the TFRecord search pattern is chosen from
                ``self.data_type``; any other value uses "/*.tfrecords".
                Also forwarded to ``mri_utils.Iterator`` (train or test).

        Side effects:
            Calls ``exit()`` when ``self.read_real()`` reports zero files or
            the iterator finds no input files. Sets many ``self`` attributes,
            e.g. ``Y_real``, ``data_num``, ``search_str``, ``input_files``,
            ``im_in``, ``complex_truth``, ``z_truth``, ``ks``, ``sensemap``,
            ``X_gen``, ``Y_fake``, ``d_loss``, ``g_loss``, ``g_optimizer``,
            ``d_optimizer``, ``output_image``, ``counter``, ``saver`` and
            ``summary_writer``.
        """
        # Read in real undersampled image Yr
        self.Y_real, self.data_num, real_mask = self.read_real()
        if self.data_num == 0:
            print("Error: no training files found")
            exit()

        # NOTE(review): real_image is not used anywhere in the visible part of
        # this method; confirm whether it can be removed.
        real_image = tf.abs(tf_util.channels_to_complex(self.Y_real))

        # Choose which TFRecord files the input iterator globs for.
        # NOTE(review): in "test" mode with a data_type other than "DCE_2D" or
        # "knee", self.search_str is not assigned here, so a previously set
        # value is used (or an AttributeError is raised below) -- confirm.
        if mode == "test":
            if self.data_type == "DCE_2D":
                # self.search_str = "/14Jul16_Ex19493_Ser5*.tfrecords"
                self.search_str = "/17Dec16_Ex21068_Ser13*.tfrecords"
                # self.search_str = "/*.tfrecords"
            if self.data_type == "knee":
                self.search_str = "/*.tfrecords"
        else:
            self.search_str = "/*.tfrecords"

        # Read in train or test input images to be reconstructed by generator
        train_iterator = mri_utils.Iterator(
            self.batch_size,
            self.mask_path,
            self.data_type,
            mode,  # either train or test
            self.out_shape,
            verbose=self.verbose,
            train_acc=self.train_acc,
            search_str=self.search_str,
            data_dir=self.data_dir)
        self.input_files = train_iterator.num_files

        if self.input_files == 0:
            print("Error: no input files found")
            exit()

        train_dataset = train_iterator.iterator.get_next()
        ks_truth = train_dataset["ks_truth"]
        ks_input = train_dataset["ks_input"]
        sensemap = train_dataset["sensemap"]

        # Adjoint (k-space -> image) of the undersampled input and of the
        # truth k-space, using the sensitivity maps.
        self.im_in = tf_util.model_transpose(ks_input, sensemap)
        self.complex_truth = tf_util.model_transpose(ks_truth, sensemap)

        # Truth image with real/imag split into channels.
        self.z_truth = tf_util.complex_to_channels(self.complex_truth)
        self.ks = ks_input
        self.sensemap = sensemap

        # generate image X_gen
        self.X_gen = self.generator(self.ks, self.sensemap)

        # measure X_gen (with the mask returned by read_real) so the critic
        # compares it against real measurements Y_real
        self.Y_fake = self.measure(self.X_gen, self.sensemap, real_mask)

        # output of discriminator for fake and real images; the second call
        # reuses the variables created by the first
        self.d_logits_fake = self.discriminator(self.Y_fake, reuse=False)
        self.d_logits_real = self.discriminator(self.Y_real, reuse=True)

        # discriminator (critic) loss: mean D(fake) - mean D(real)
        self.d_loss = tf.reduce_mean(self.d_logits_fake) - tf.reduce_mean(
            self.d_logits_real)

        # generator loss: -mean D(fake)
        self.g_loss = -tf.reduce_mean(self.d_logits_fake)

        # Gradient Penalty
        # Per-example interpolation weight in [0, 1). The [batch, 1, 1, 1]
        # shape assumes Y_real / Y_fake are rank-4 with batch first -- TODO
        # confirm against read_real / measure.
        self.epsilon = tf.random_uniform(shape=[self.batch_size, 1, 1, 1],
                                         minval=0.0,
                                         maxval=1.0)
        # Random points on the line between real and fake measurements.
        Y_hat = self.Y_real + self.epsilon * (self.Y_fake - self.Y_real)
        D_Y_hat = self.discriminator(Y_hat, reuse=True)
        grad_D_Y_hat = tf.gradients(D_Y_hat, [Y_hat])[0]
        # Reduce over every non-batch axis -> one gradient L2 norm per example.
        red_idx = range(1, Y_hat.shape.ndims)
        # (reduction_indices is the deprecated alias of `axis`.)
        slopes = tf.sqrt(
            tf.reduce_sum(tf.square(grad_D_Y_hat),
                          reduction_indices=list(red_idx)))
        # Penalise deviation of the gradient norm from 1.
        self.gradient_penalty = tf.reduce_mean((slopes - 1.0)**2)
        # updated discriminator loss (penalty weight 10.0)
        self.d_loss = self.d_loss + 10.0 * self.gradient_penalty

        train_vars = tf.trainable_variables()
        # The L2 loss of every trainable variable is added to the "reg_loss"
        # collection; it is not consumed in the visible part of this method.
        for v in train_vars:
            # v = tf.cast(v, tf.float32)
            tf.add_to_collection("reg_loss", tf.nn.l2_loss(v))
        # Split variables by scope-name substring so each optimizer updates
        # only its own network.
        self.generator_vars = [v for v in train_vars if "generator" in v.name]
        self.discriminator_vars = [
            v for v in train_vars if "discriminator" in v.name
        ]

        local_device_protos = device_lib.list_local_devices()
        device_list = [
            x.name for x in local_device_protos if x.device_type == "GPU"
        ]

        # Generator optimizer on the last visible GPU, discriminator optimizer
        # on the first (the same device when only one GPU is visible).
        # NOTE(review): device_list[-1] raises IndexError when no GPU is visible.
        cur_device = device_list[-1]
        # cur_device = device_list[0]
        with tf.device(cur_device):
            self.g_optimizer = tf.train.AdamOptimizer(
                learning_rate=self.lr,
                name="g_opt",
                beta1=self.beta1,
                beta2=self.beta2).minimize(self.g_loss,
                                           var_list=self.generator_vars)

        cur_device = device_list[0]
        with tf.device(cur_device):
            self.d_optimizer = tf.train.AdamOptimizer(
                learning_rate=self.lr,
                name="d_opt",
                beta1=self.beta1,
                beta2=self.beta2).minimize(self.d_loss,
                                           var_list=self.discriminator_vars)

        # Complex-valued generator output and its magnitude.
        self.output_image = tf_util.channels_to_complex(self.X_gen)
        self.im_out = self.output_image
        self.mag_output = tf.abs(self.output_image)
        self.create_summary()

        # Persisted step counter, initialised to 1 (not 0).
        with tf.variable_scope("counter"):
            self.counter = tf.get_variable(
                "counter",
                shape=[1],
                initializer=tf.constant_initializer([1]),
                dtype=tf.int32,
            )
            self.update_counter = tf.assign(self.counter,
                                            tf.add(self.counter, 1))
        self.saver = tf.train.Saver()
        self.summary_writer = tf.summary.FileWriter(self.log_dir,
                                                    self.sess.graph)
        self.initialize_model()