Esempio n. 1
0
def main(_):
    """Evaluate a trained unrolled-FISTA model against a BART CS baseline.

    Restores variables from ``FLAGS.log_root/FLAGS.train_dir`` via ``load``
    (initializing them fresh if ``load`` returns a falsy value), then for
    every example of ``FLAGS.dataset_dir/test``:

    * reconstructs it with the network and with ``bart_cs``;
    * prints the running mean/std of PSNR, NRMSE and SSIM for both.

    When the loop finishes, the final statistics are written to
    ``metrics.txt`` in the model directory.

    Args:
        _: Unused (conventionally the argv list passed by the app runner).

    Raises:
        ValueError: If ``--dataset_dir`` was not supplied.
        SystemExit: With status 1 if ``FLAGS.batch_size`` is not 1.
    """
    import sys

    # Metrics are accumulated per example and the BART recon below is run on
    # the whole batch, so only batch_size == 1 is supported. Compare by value:
    # `is not 1` tests object identity, not equality.
    if FLAGS.batch_size != 1:
        print("Error: to test images, batch size must be 1")
        sys.exit(1)

    if not FLAGS.dataset_dir:
        raise ValueError("You must supply the dataset directory with " +
                         "--dataset_dir")

    # Select the GPU before any session initializes devices; changing this
    # variable inside an already-open session does not affect device choice.
    os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.device

    model_dir = os.path.join(FLAGS.log_root, FLAGS.train_dir)
    bart_dir = os.path.join(model_dir, "bart_recon")
    for directory in (FLAGS.log_root, model_dir, bart_dir):
        os.makedirs(directory, exist_ok=True)

    run_config = tf.ConfigProto()
    run_config.gpu_options.allow_growth = True

    with tf.Session(config=run_config) as sess:
        if FLAGS.random_seed >= 0:
            random.seed(FLAGS.random_seed)
            np.random.seed(FLAGS.random_seed)

        tf.logging.set_verbosity(tf.logging.INFO)

        print("Preparing dataset...")
        out_shape = [FLAGS.shape_z, FLAGS.shape_y]

        test_dataset, num_files = mri_data.create_dataset(
            os.path.join(FLAGS.dataset_dir, "test"),
            FLAGS.mask_path,
            num_channels=FLAGS.num_channels,
            num_emaps=FLAGS.num_emaps,
            batch_size=FLAGS.batch_size,
            out_shape=out_shape,
        )

        # Placeholders, channels last: (batch, z, y, channels).
        ks_shape = [None, FLAGS.shape_z, FLAGS.shape_y, FLAGS.num_channels]
        ks_place = tf.placeholder(tf.complex64, ks_shape)
        sense_shape = [
            None, FLAGS.shape_z, FLAGS.shape_y, 1, FLAGS.num_channels
        ]
        sense_place = tf.placeholder(tf.complex64, sense_shape)
        im_shape = [None, FLAGS.shape_z, FLAGS.shape_y, 1]
        im_truth_place = tf.placeholder(tf.complex64, im_shape)

        # Run through the unrolled network.
        # NOTE(review): is_training=True during testing; confirm this is
        # intended if the model has train-only behavior.
        im_out_place = mri_model.unroll_fista(
            ks_place,
            sense_place,
            is_training=True,
            verbose=True,
            do_hardproj=FLAGS.do_hard_proj,
            num_summary_image=FLAGS.num_summary_image,
            resblock_num_features=FLAGS.feat_map,
            num_grad_steps=FLAGS.num_grad_steps,
            conv=FLAGS.conv,
            do_conjugate=FLAGS.do_conjugate,
        )

        saver = tf.train.Saver()
        summary_writer = tf.summary.FileWriter(model_dir, sess.graph)

        # Initialize model: restore a checkpoint if one loads, else init.
        print("[*] initializing network...")
        if not load(model_dir, saver, sess):
            sess.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess, coord)

        # Count trainable parameters (assumes fully defined static shapes).
        total_parameters = 0
        for variable in tf.trainable_variables():
            variable_parameters = 1
            for dim in variable.get_shape():
                variable_parameters *= dim.value
            total_parameters += variable_parameters
        print("Total number of trainable parameters: %d" % total_parameters)

        test_iterator = test_dataset.make_one_shot_iterator()
        features, labels = test_iterator.get_next()

        ks_truth = labels
        ks_in = features["ks_input"]
        sense_in = features["sensemap"]
        mask_recon = features["mask_recon"]
        # Ground-truth k-space mapped to the image domain.
        im_truth = tf_util.model_transpose(ks_truth * mask_recon, sense_in)

        total_summary = tf.summary.merge_all()

        output_psnr, output_nrmse, output_ssim = [], [], []
        cs_psnr, cs_nrmse, cs_ssim = [], [], []

        try:
            for _file_idx in range(num_files):
                ks_in_run, sense_in_run, im_truth_run = sess.run(
                    [ks_in, sense_in, im_truth])
                im_out, total_summary_run = sess.run(
                    [im_out_place, total_summary],
                    feed_dict={
                        ks_place: ks_in_run,
                        sense_place: sense_in_run,
                        im_truth_place: im_truth_run,
                    },
                )

                # CS recon (baseline) on the same inputs.
                bart_test = bart_cs(bart_dir, ks_in_run, sense_in_run,
                                    l1=0.007)

                # Network metrics, one entry per batch element.
                for b in range(FLAGS.batch_size):
                    truth = im_truth_run[b, :, :, :]
                    out = im_out[b, :, :, :]
                    psnr, nrmse, ssim = metrics.compute_all(truth,
                                                            out,
                                                            sos_axis=-1)
                    output_psnr.append(psnr)
                    output_nrmse.append(nrmse)
                    output_ssim.append(ssim)

                print("output mean +/ standard deviation psnr, nrmse, ssim")
                print(
                    np.mean(output_psnr),
                    np.std(output_psnr),
                    np.mean(output_nrmse),
                    np.std(output_nrmse),
                    np.mean(output_ssim),
                    np.std(output_ssim),
                )

                psnr, nrmse, ssim = metrics.compute_all(im_truth_run,
                                                        bart_test,
                                                        sos_axis=-1)
                cs_psnr.append(psnr)
                cs_nrmse.append(nrmse)
                cs_ssim.append(ssim)

                print("cs mean +/ standard deviation psnr, nrmse, ssim")
                print(
                    np.mean(cs_psnr),
                    np.std(cs_psnr),
                    np.mean(cs_nrmse),
                    np.std(cs_nrmse),
                    np.mean(cs_ssim),
                    np.std(cs_ssim),
                )
        finally:
            # Always shut down input threads and flush the event file.
            coord.request_stop()
            coord.join(threads)
            summary_writer.close()
        print("End of testing loop")

        def _mean_std(values):
            # "<mean> +\- <std>", the report's established format.
            return str(np.mean(values)) + " +\\- " + str(np.std(values))

        report = [
            "parameters = " + str(total_parameters),
            "output psnr = " + _mean_std(output_psnr),
            "output nrmse = " + _mean_std(output_nrmse),
            "output ssim = " + _mean_std(output_ssim),
            "cs psnr = " + _mean_std(cs_psnr),
            "cs nrmse = " + _mean_std(cs_nrmse),
            "cs ssim = " + _mean_std(cs_ssim),
        ]
        txt_path = os.path.join(model_dir, "metrics.txt")
        with open(txt_path, "w") as f:
            f.write("\n".join(report) + "\n")
def main(_):
    """Train the unrolled-FISTA reconstruction network with an L1 image loss.

    Builds the model on placeholders and minimizes the mean absolute
    difference between the network output and the ground-truth image with
    Adam. Summaries are written to TensorBoard every step. Checkpoints go to
    ``FLAGS.log_root/FLAGS.train_dir`` every 500 steps and once more at the
    end. When ``load`` restores a checkpoint, training resumes from the step
    stored in the ``counter`` variable.

    Args:
        _: Unused (conventionally the argv list passed by the app runner).

    Raises:
        ValueError: If ``--dataset_dir`` was not supplied.
    """
    if not FLAGS.dataset_dir:
        raise ValueError(
            "You must supply the dataset directory with " + "--dataset_dir"
        )

    # Path where model checkpoints and summaries will be saved.
    model_dir = os.path.join(FLAGS.log_root, FLAGS.train_dir)
    os.makedirs(FLAGS.log_root, exist_ok=True)
    os.makedirs(model_dir, exist_ok=True)
    checkpoint_path = os.path.join(model_dir, "model.ckpt")

    # Select the GPU before any session initializes devices; changing this
    # variable inside an already-open session does not affect device choice.
    os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.device

    # A single config carries both the placement options and allow_growth,
    # so that all of them actually reach the session.
    run_config = tf.ConfigProto(allow_soft_placement=True,
                                log_device_placement=True)
    run_config.gpu_options.allow_growth = True

    with tf.Session(config=run_config) as sess:
        if FLAGS.random_seed >= 0:
            random.seed(FLAGS.random_seed)
            np.random.seed(FLAGS.random_seed)

        tf.logging.set_verbosity(tf.logging.INFO)

        print("Preparing dataset...")
        out_shape = [FLAGS.shape_z, FLAGS.shape_y]
        train_dataset, num_files = mri_data.create_dataset(
            os.path.join(FLAGS.dataset_dir, "train"),
            FLAGS.mask_path,
            num_channels=FLAGS.num_channels,
            num_emaps=FLAGS.num_emaps,
            batch_size=FLAGS.batch_size,
            out_shape=out_shape,
        )

        # Placeholders, channels last: (batch, z, y, channels).
        ks_shape = [None, FLAGS.shape_z, FLAGS.shape_y, FLAGS.num_channels]
        ks_place = tf.placeholder(tf.complex64, ks_shape)
        sense_shape = [None, FLAGS.shape_z,
                       FLAGS.shape_y, 1, FLAGS.num_channels]
        sense_place = tf.placeholder(tf.complex64, sense_shape)
        im_shape = [None, FLAGS.shape_z, FLAGS.shape_y, 1]
        im_truth_place = tf.placeholder(tf.complex64, im_shape)

        # Run through the unrolled model.
        im_out_place = mri_model.unroll_fista(
            ks_place,
            sense_place,
            is_training=True,
            verbose=True,
            do_hardproj=FLAGS.do_hard_proj,
            num_summary_image=FLAGS.num_summary_image,
            resblock_num_features=FLAGS.feat_map,
            num_grad_steps=FLAGS.num_grad_steps,
            conv=FLAGS.conv,
            do_conjugate=FLAGS.do_conjugate,
            activation=FLAGS.activation
        )

        # TensorBoard summaries.
        _create_summary(sense_place, ks_place, im_out_place, im_truth_place)

        # L1 loss between output and ground truth.
        loss = tf.reduce_mean(tf.abs(im_out_place - im_truth_place), name="l1")
        tf.summary.scalar("loss/l1", loss)

        # Optimize using Adam; `train_op` is the minimize op run every step.
        train_op = tf.train.AdamOptimizer(
            learning_rate=FLAGS.learning_rate,
            name="opt",
            beta1=FLAGS.adam_beta1,
            beta2=FLAGS.adam_beta2,
        ).minimize(loss)

        # Persisted step counter (shape [1]) so training can resume.
        with tf.variable_scope("counter"):
            counter = tf.get_variable(
                "counter",
                shape=[1],
                initializer=tf.constant_initializer([0]),
                dtype=tf.int32,
            )
            update_counter = tf.assign(counter, tf.add(counter, 1))

        saver = tf.train.Saver()
        summary_writer = tf.summary.FileWriter(model_dir, sess.graph)

        # Initialize model: restore a checkpoint if one loads, else init.
        print("[*] initializing network...")
        if not load(model_dir, saver, sess):
            sess.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess, coord)

        # Count trainable parameters (assumes fully defined static shapes).
        total_parameters = 0
        for variable in tf.trainable_variables():
            variable_parameters = 1
            for dim in variable.get_shape():
                variable_parameters *= dim.value
            total_parameters += variable_parameters
        print("Total number of trainable parameters: %d" % total_parameters)
        tf.summary.scalar("parameters/parameters", total_parameters)

        # Iterate through the TFRecord dataset.
        train_iterator = train_dataset.make_one_shot_iterator()
        features, labels = train_iterator.get_next()

        ks_truth = labels  # ground truth kspace
        ks_in = features["ks_input"]  # input kspace
        sense_in = features["sensemap"]  # sensitivity maps
        mask_recon = features["mask_recon"]  # reconstruction mask

        # Ground-truth k-space mapped to the image domain.
        im_truth = tf_util.model_transpose(ks_truth * mask_recon, sense_in)

        # Gather all summaries for TensorBoard.
        total_summary = tf.summary.merge_all()

        # `counter` has shape [1]; index it instead of converting the whole
        # array to a scalar (deprecated in NumPy for ndim > 0).
        start_step = int(sess.run(counter)[0])
        print("Start from step %d." % start_step)
        try:
            for step in range(start_step, FLAGS.max_steps):
                # Evaluate input kspace, sensitivity maps, ground truth image.
                ks_in_run, sense_in_run, im_truth_run = sess.run(
                    [ks_in, sense_in, im_truth]
                )
                # Run the optimizer; also fetch the output and summaries.
                im_out, total_summary_run, _ = sess.run(
                    [im_out_place, total_summary, train_op],
                    feed_dict={
                        ks_place: ks_in_run,
                        sense_place: sense_in_run,
                        im_truth_place: im_truth_run,
                    },
                )
                print("step", step)
                summary_writer.add_summary(total_summary_run, step)

                # Advance the persisted counter *before* checkpointing, so a
                # resumed run starts at the next step instead of redoing this.
                sess.run(update_counter)

                if step % 500 == 0:
                    print("saving checkpoint")
                    saver.save(sess, checkpoint_path)

            # Keep the final weights even if max_steps % 500 != 0.
            print("saving checkpoint")
            saver.save(sess, checkpoint_path)
        finally:
            # Always shut down input threads and flush the event file.
            coord.request_stop()
            coord.join(threads)
            summary_writer.close()
        print("End of training loop")
Esempio n. 3
0
def create_model(sess, features, labels, masks, MY, s, architecture='resnet'):
    """Build the generator (train and test paths) and the discriminator.

    The generator is either the unrolled FISTA network (when
    ``FLAGS.unrolled > 0``) or ``_generator_model_with_scale``. It is built
    twice under the ``gene`` scope with shared variables:

    * once on the given tensors (training path);
    * once on placeholders (test path).

    The discriminator is applied under the ``disc`` scope to the labels and
    to the training-path generator output.

    Args:
        sess: TF session, forwarded to the generator/discriminator builders.
        features: Input tensor (for SR/CS, the input image). Its static shape
            is read as (batch, rows, cols, channels).
        labels: Ground-truth tensor (for SR/CS, the ground-truth image);
            channels 0 and 1 are combined into a magnitude when
            ``FLAGS.use_phase`` is false.
        masks: Sampling mask(s) passed to the generator.
        MY: Measurement tensor fed to the training-path generator.
        s: Sensitivity tensor fed to the training-path generator.
        architecture: Unused here ('aec' = encoder-decoder,
            'resnet' = upside-down resnet); kept for interface compatibility.

    Returns:
        list: [gene_minput, gene_mMY, gene_ms, gene_moutput, gene_output_abs,
        gene_var_list, gene_layers, gene_mlayers, disc_real_output,
        disc_fake_output, disc_var_list, disc_real_input]
    """
    rows = int(features.get_shape()[1])
    cols = int(features.get_shape()[2])
    channels = int(features.get_shape()[3])

    # Test-path inputs. NOTE(review): the hard-coded 8 is presumably the
    # coil count expected downstream -- TODO confirm.
    gene_minput = tf.placeholder(
        tf.float32, shape=[FLAGS.batch_size, rows, cols, channels])
    gene_mMY = tf.placeholder(tf.complex64,
                              shape=[FLAGS.batch_size, 8, rows, cols])
    gene_ms = tf.placeholder(tf.complex64,
                             shape=[FLAGS.batch_size, 8, rows, cols])

    # With unmasked input ("nomask") the data-consistency layers are removed.
    num_dc_layers = 0 if FLAGS.sampling_pattern != "nomask" else -1

    def function_generator(x, y, z, w):
        return _generator_model_with_scale(
            sess, x, y, z, w, num_dc_layers=num_dc_layers,
            layer_output_skip=7)

    rbs = 2

    def unrolled_generator(ks, sensemap, is_training):
        # Same scope="MRI" for both paths; weights are shared via the
        # reuse flag set on the enclosing 'gene' scope.
        return mri_model.unroll_fista(
            ks,
            sensemap,
            num_grad_steps=FLAGS.unrolled,
            resblock_num_features=64,
            resblock_num_blocks=rbs,
            is_training=is_training,
            scope="MRI",
            mask_output=1,
            window=None,
            do_hardproj=True,
            num_summary_image=0,
            mask=masks,
            verbose=False)

    with tf.variable_scope('gene') as scope:
        # Training path.
        if FLAGS.unrolled > 0:
            gene_output_1 = unrolled_generator(MY, s, is_training=True)
            gene_var_list, gene_layers = None, ['empty']
        else:
            gene_output_1, gene_var_list, gene_layers = function_generator(
                features, masks, MY, s)
        # Every later generator call in this scope reuses these variables.
        scope.reuse_variables()

        gene_output = tf.reshape(gene_output_1,
                                 [FLAGS.batch_size, rows, cols, 2])

        # Test path, fed through the placeholders above.
        if FLAGS.unrolled > 0:
            gene_moutput_1 = unrolled_generator(gene_mMY, gene_ms,
                                                is_training=False)
            gene_mlayers = None
        else:
            gene_moutput_1, _, gene_mlayers = function_generator(
                gene_minput, masks, gene_mMY, gene_ms)

        gene_moutput = tf.reshape(gene_moutput_1,
                                  [FLAGS.batch_size, rows, cols, 2])

    # Discriminator input from real data: the labels as-is, or the magnitude
    # of channels 0/1 when phase is not used.
    if FLAGS.use_phase:
        disc_real_input = tf.identity(labels, name='disc_real_input')
    else:
        disc_real_input = tf.sqrt(labels[:, :, :, 0]**2 +
                                  labels[:, :, :, 1]**2)
        disc_real_input = tf.expand_dims(disc_real_input, -1)

    # TBD: Is there a better way to instance the discriminator?
    with tf.variable_scope('disc', reuse=tf.AUTO_REUSE) as scope:
        disc_real_output, disc_var_list, disc_layers_X = \
                _discriminator_model(sess, features, disc_real_input,
                                     hybrid_disc=FLAGS.hybrid_disc)

        scope.reuse_variables()
        # Fake input gets the same treatment as the real input above.
        if FLAGS.use_phase:
            gene_output_abs = gene_output
        else:
            gene_output_abs = tf.sqrt(gene_output[:, :, :, 0]**2 +
                                      gene_output[:, :, :, 1]**2)
            gene_output_abs = tf.expand_dims(gene_output_abs, -1)

        disc_fake_output, _, disc_layers_Z = _discriminator_model(
            sess, features, gene_output_abs, hybrid_disc=FLAGS.hybrid_disc)

    return [
        gene_minput, gene_mMY, gene_ms, gene_moutput, gene_output_abs,
        gene_var_list, gene_layers, gene_mlayers, disc_real_output,
        disc_fake_output, disc_var_list, disc_real_input
    ]