Example #1
    def calculate_metrics(self, output_image, bart_test, sample_truth):
        cs_psnr = []
        cs_nrmse = []
        cs_ssim = []
        output_psnr = []
        output_nrmse = []
        output_ssim = []

        complex_truth = tf_util.channels_to_complex(sample_truth)
        complex_truth = self.sess.run(complex_truth)

        psnr, nrmse, ssim = metrics.compute_all(complex_truth,
                                                bart_test,
                                                sos_axis=-1)
        cs_psnr.append(psnr)
        cs_nrmse.append(nrmse)
        cs_ssim.append(ssim)

        psnr, nrmse, ssim = metrics.compute_all(complex_truth,
                                                output_image,
                                                sos_axis=-1)
        output_psnr.append(psnr)
        output_nrmse.append(nrmse)
        output_ssim.append(ssim)
        return output_psnr, output_nrmse, output_ssim
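The `metrics.compute_all` helper used throughout these examples is not part of the listing. As a rough, stand-alone illustration of the first two values it is assumed to return for complex images (PSNR and NRMSE computed on magnitudes), one could write something like the sketch below; the function name `psnr_nrmse`, the choice of peak value, and the normalization are assumptions, not the repository's implementation.

import numpy as np

def psnr_nrmse(truth, recon):
    """PSNR (dB) and NRMSE between two equally shaped complex arrays (assumed convention)."""
    t = np.abs(np.asarray(truth))
    r = np.abs(np.asarray(recon))
    mse = np.mean((t - r) ** 2)
    psnr = 10.0 * np.log10((t.max() ** 2) / mse)          # peak taken from the ground truth
    nrmse = np.sqrt(mse) / np.sqrt(np.mean(t ** 2))       # normalized by the truth energy
    return psnr, nrmse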
def calculate_metrics(output, bart_test, truth):
    cs_psnr = []
    cs_nrmse = []
    cs_ssim = []
    output_psnr = []
    output_nrmse = []
    output_ssim = []

    # network output vs. ground truth
    psnr, nrmse, ssim = metrics.compute_all(truth, output, sos_axis=-1)
    output_psnr.append(psnr)
    output_nrmse.append(nrmse)
    output_ssim.append(ssim)

    # BART compressed-sensing reconstruction vs. ground truth
    psnr, nrmse, ssim = metrics.compute_all(truth, bart_test, sos_axis=-1)
    cs_psnr.append(psnr)
    cs_nrmse.append(nrmse)
    cs_ssim.append(ssim)

    print("cs psnr, nrmse, ssim")
    print(
        np.mean(cs_psnr),
        np.std(cs_psnr),
        np.mean(cs_nrmse),
        np.std(cs_nrmse),
        np.mean(cs_ssim),
        np.std(cs_ssim),
    )
    print(output_psnr)
    print("output psnr, nrmse, ssim")
    print(
        np.mean(output_psnr),
        np.std(output_psnr),
        np.mean(output_nrmse),
        np.std(output_nrmse),
        np.mean(output_ssim),
        np.std(output_ssim),
    )
Example #3
def calculate_metrics(output, bart_test, truth):
    cs_psnr = []
    cs_nrmse = []
    cs_ssim = []
    output_psnr = []
    output_nrmse = []
    output_ssim = []

    psnr, nrmse, ssim = metrics.compute_all(truth, output, sos_axis=-1)
    output_psnr.append(psnr)
    output_nrmse.append(nrmse)
    output_ssim.append(ssim)
def calculate_metrics(output, bart_test, truth):
    cs_psnr = []
    cs_nrmse = []
    cs_ssim = []
    output_psnr = []
    output_nrmse = []
    output_ssim = []

    # complex_truth = tf_util.channels_to_complex(sample_truth)
    # complex_truth = sess.run(complex_truth)

    # psnr, nrmse, ssim = metrics.compute_all(complex_truth, bart_test, sos_axis=-1)
    # cs_psnr.append(psnr)
    # cs_nrmse.append(nrmse)
    # cs_ssim.append(ssim)
    # print("truth", truth.shape)
    # print("output", output.shape)

    psnr, nrmse, ssim = metrics.compute_all(truth, output, sos_axis=-1)
    output_psnr.append(psnr)
    output_nrmse.append(nrmse)
    output_ssim.append(ssim)

    print("cs psnr, nrmse, ssim")
    print(
        np.mean(cs_psnr),
        np.std(cs_psnr),
        np.mean(cs_nrmse),
        np.std(cs_nrmse),
        np.mean(cs_ssim),
        np.std(cs_ssim),
    )
    print(output_psnr)
    print("output psnr, nrmse, ssim")
    print(
        np.mean(output_psnr),
        np.std(output_psnr),
        np.mean(output_nrmse),
        np.std(output_nrmse),
        np.mean(output_ssim),
        np.std(output_ssim),
    )
Example #5
def main(_):
    if FLAGS.batch_size != 1:
        print("Error: to test images, batch size must be 1")
        exit()

    model_dir = os.path.join(FLAGS.log_root, FLAGS.train_dir)
    if not os.path.exists(FLAGS.log_root):
        os.makedirs(FLAGS.log_root)
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    bart_dir = os.path.join(model_dir, "bart_recon")
    if not os.path.exists(bart_dir):
        os.makedirs(bart_dir)

    run_config = tf.ConfigProto()
    run_config.gpu_options.allow_growth = True

    with tf.Session(config=run_config) as sess:
        """Execute main function."""
        os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.device

        if not FLAGS.dataset_dir:
            raise ValueError("You must supply the dataset directory with " +
                             "--dataset_dir")

        if FLAGS.random_seed >= 0:
            random.seed(FLAGS.random_seed)
            np.random.seed(FLAGS.random_seed)

        tf.logging.set_verbosity(tf.logging.INFO)

        print("Preparing dataset...")
        out_shape = [FLAGS.shape_z, FLAGS.shape_y]

        test_dataset, num_files = mri_data.create_dataset(
            os.path.join(FLAGS.dataset_dir, "test"),
            FLAGS.mask_path,
            num_channels=FLAGS.num_channels,
            num_emaps=FLAGS.num_emaps,
            batch_size=FLAGS.batch_size,
            out_shape=out_shape,
        )
        # channels first: (batch, channels, z, y)
        # placeholders
        ks_shape = [None, FLAGS.shape_z, FLAGS.shape_y, FLAGS.num_channels]
        ks_place = tf.placeholder(tf.complex64, ks_shape)
        sense_shape = [
            None, FLAGS.shape_z, FLAGS.shape_y, 1, FLAGS.num_channels
        ]
        sense_place = tf.placeholder(tf.complex64, sense_shape)
        im_shape = [None, FLAGS.shape_z, FLAGS.shape_y, 1]
        im_truth_place = tf.placeholder(tf.complex64, im_shape)
        # run through unrolled
        im_out_place = mri_model.unroll_fista(
            ks_place,
            sense_place,
            is_training=True,
            verbose=True,
            do_hardproj=FLAGS.do_hard_proj,
            num_summary_image=FLAGS.num_summary_image,
            resblock_num_features=FLAGS.feat_map,
            num_grad_steps=FLAGS.num_grad_steps,
            conv=FLAGS.conv,
            do_conjugate=FLAGS.do_conjugate,
        )

        saver = tf.train.Saver()
        summary_writer = tf.summary.FileWriter(model_dir, sess.graph)

        # initialize model
        print("[*] initializing network...")
        if not load(model_dir, saver, sess):
            sess.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess, coord)

        # See how many parameters are in model
        total_parameters = 0
        for variable in tf.trainable_variables():
            variable_parameters = 1
            for dim in variable.get_shape():
                variable_parameters *= dim.value
            total_parameters += variable_parameters
        print("Total number of trainable parameters: %d" % total_parameters)

        test_iterator = test_dataset.make_one_shot_iterator()
        features, labels = test_iterator.get_next()

        ks_truth = labels
        ks_in = features["ks_input"]
        sense_in = features["sensemap"]
        mask_recon = features["mask_recon"]
        im_truth = tf_util.model_transpose(ks_truth * mask_recon, sense_in)

        total_summary = tf.summary.merge_all()

        output_psnr = []
        output_nrmse = []
        output_ssim = []
        cs_psnr = []
        cs_nrmse = []
        cs_ssim = []

        for test_file in range(num_files):
            ks_in_run, sense_in_run, im_truth_run = sess.run(
                [ks_in, sense_in, im_truth])
            im_out, total_summary_run = sess.run(
                [im_out_place, total_summary],
                feed_dict={
                    ks_place: ks_in_run,
                    sense_place: sense_in_run,
                    im_truth_place: im_truth_run,
                },
            )

            # CS recon
            bart_test = bart_cs(bart_dir, ks_in_run, sense_in_run, l1=0.007)
            # bart_test = None

            # handle batch dimension
            for b in range(FLAGS.batch_size):
                truth = im_truth_run[b, :, :, :]
                out = im_out[b, :, :, :]
                psnr, nrmse, ssim = metrics.compute_all(truth,
                                                        out,
                                                        sos_axis=-1)
                output_psnr.append(psnr)
                output_nrmse.append(nrmse)
                output_ssim.append(ssim)

            print("output mean +/ standard deviation psnr, nrmse, ssim")
            print(
                np.mean(output_psnr),
                np.std(output_psnr),
                np.mean(output_nrmse),
                np.std(output_nrmse),
                np.mean(output_ssim),
                np.std(output_ssim),
            )

            psnr, nrmse, ssim = metrics.compute_all(im_truth_run,
                                                    bart_test,
                                                    sos_axis=-1)
            cs_psnr.append(psnr)
            cs_nrmse.append(nrmse)
            cs_ssim.append(ssim)

            print("cs mean +/ standard deviation psnr, nrmse, ssim")
            print(
                np.mean(cs_psnr),
                np.std(cs_psnr),
                np.mean(cs_nrmse),
                np.std(cs_nrmse),
                np.mean(cs_ssim),
                np.std(cs_ssim),
            )
        print("End of testing loop")
        txt_path = os.path.join(model_dir, "metrics.txt")
        f = open(txt_path, "w")
        f.write("parameters = " + str(total_parameters) + "\n" +
                "output psnr = " + str(np.mean(output_psnr)) + " +\- " +
                str(np.std(output_psnr)) + "\n" + "output nrmse = " +
                str(np.mean(output_nrmse)) + " +\- " +
                str(np.std(output_nrmse)) + "\n" + "output ssim = " +
                str(np.mean(output_ssim)) + " +\- " +
                str(np.std(output_ssim)) + "\n"
                "cs psnr = " + str(np.mean(cs_psnr)) + " +\- " +
                str(np.std(cs_psnr)) + "\n" + "output nrmse = " +
                str(np.mean(cs_nrmse)) + " +\- " + str(np.std(cs_nrmse)) +
                "\n" + "output ssim = " + str(np.mean(cs_ssim)) + " +\- " +
                str(np.std(cs_ssim)))
        f.close()
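The `bart_cs` helper called in `main` is not defined in this listing. Below is a minimal sketch under stated assumptions: BART is installed and on the PATH, its bundled `cfl` Python module (`writecfl`/`readcfl`) is importable, and the k-space and sensitivity-map arrays are already in the dimension order that `bart pics` expects. The l1-wavelet flag and regularization weight mirror the `l1` argument used above; this is not the repository's implementation.

import os
import subprocess
import cfl  # BART's python/cfl.py (assumption)

def bart_cs(bart_dir, ks, sensemap, l1=0.007):
    """Run an l1-wavelet regularized BART `pics` reconstruction (sketch)."""
    ks_file = os.path.join(bart_dir, "ks")
    map_file = os.path.join(bart_dir, "map")
    out_file = os.path.join(bart_dir, "im_cs")
    cfl.writecfl(ks_file, ks)
    cfl.writecfl(map_file, sensemap)
    # parallel-imaging / compressed-sensing recon with l1-wavelet penalty
    subprocess.run(
        ["bart", "pics", "-l1", "-r", str(l1), ks_file, map_file, out_file],
        check=True,
    )
    return cfl.readcfl(out_file)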
def main(_):
    model_dir = os.path.join(FLAGS.log_root, FLAGS.train_dir)
    if not os.path.exists(FLAGS.log_root):
        os.makedirs(FLAGS.log_root)
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    bart_dir = os.path.join(model_dir, "bart_recon")
    if not os.path.exists(bart_dir):
        os.makedirs(bart_dir)

    # im_head = "/home/ekcole/Workspace/mfast_combined/"
    # im_dir = os.path.join(im_head, FLAGS.train_dir)
    image_dir = os.path.join(model_dir, "images")
    if not os.path.exists(image_dir):
        os.makedirs(image_dir)

    run_config = tf.ConfigProto()
    run_config.gpu_options.allow_growth = True

    with tf.Session(config=run_config) as sess:
        """Execute main function."""
        os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.device

        if not FLAGS.dataset_dir:
            raise ValueError("You must supply the dataset directory with " +
                             "--dataset_dir")

        if FLAGS.random_seed >= 0:
            random.seed(FLAGS.random_seed)
            np.random.seed(FLAGS.random_seed)

        tf.logging.set_verbosity(tf.logging.INFO)

        print("Preparing dataset...")
        out_shape = [FLAGS.shape_z, FLAGS.shape_y]

        test_dataset, num_files = mri_data.create_dataset(
            os.path.join(FLAGS.dataset_dir, "test_images"),
            FLAGS.mask_path,
            num_channels=FLAGS.num_channels,
            num_maps=FLAGS.num_emaps,
            batch_size=FLAGS.batch_size,
            out_shape=out_shape,
        )
        # channels first: (batch, channels, z, y)
        # placeholders
        ks_shape = [None, FLAGS.shape_z, FLAGS.shape_y, FLAGS.num_channels]
        ks_place = tf.placeholder(tf.complex64, ks_shape)
        sense_shape = [
            None, FLAGS.shape_z, FLAGS.shape_y, 1, FLAGS.num_channels
        ]
        sense_place = tf.placeholder(tf.complex64, sense_shape)
        im_shape = [None, FLAGS.shape_z, FLAGS.shape_y, 1]
        im_truth_place = tf.placeholder(tf.complex64, im_shape)
        # run through unrolled
        im_out_place = mri_model.unroll_fista(
            ks_place,
            sense_place,
            is_training=True,
            verbose=True,
            do_hardproj=FLAGS.do_hard_proj,
            num_summary_image=FLAGS.num_summary_image,
            resblock_num_features=FLAGS.feat_map,
            num_grad_steps=FLAGS.num_grad_steps,
            conv=FLAGS.conv,
        )
        saver = tf.train.Saver()
        summary_writer = tf.summary.FileWriter(model_dir, sess.graph)

        # initialize model
        print("[*] initializing network...")
        if not load(model_dir, saver, sess):
            sess.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess, coord)

        # See how many parameters are in model
        total_parameters = 0
        for variable in tf.trainable_variables():
            variable_parameters = 1
            for dim in variable.get_shape():
                variable_parameters *= dim.value
            total_parameters += variable_parameters
        print("Total number of trainable parameters: %d" % total_parameters)
        tf.summary.scalar("parameters/parameters", total_parameters)

        test_iterator = test_dataset.make_one_shot_iterator()
        features, labels = test_iterator.get_next()

        ks_truth = labels
        ks_in = features["ks_input"]
        sense_in = features["sensemap"]
        mask_recon = features["mask_recon"]
        im_in = tf_util.model_transpose(ks_in * mask_recon, sense_in)
        im_truth = tf_util.model_transpose(ks_truth * mask_recon, sense_in)

        total_summary = tf.summary.merge_all()

        output_psnr = []
        output_nrmse = []
        output_ssim = []
        cs_psnr = []
        cs_nrmse = []
        cs_ssim = []

        for test_file in range(num_files):
            ks_in_run, sense_in_run, im_truth_run, im_in_run = sess.run(
                [ks_in, sense_in, im_truth, im_in])
            im_out, total_summary_run = sess.run(
                [im_out_place, total_summary],
                feed_dict={
                    ks_place: ks_in_run,
                    sense_place: sense_in_run,
                    im_truth_place: im_truth_run,
                },
            )

            # CS recon
            bart_test = bart_cs(bart_dir, ks_in_run, sense_in_run, l1=FLAGS.l1)

            # print("rotating")
            # im_in_run = np.rot90(np.squeeze(im_in_run), k=3)
            # im_out = np.rot90(np.squeeze(im_out), k=3)
            # bart_test = np.rot90(np.squeeze(bart_test), k=3)
            # im_truth_run = np.rot90(np.squeeze(im_truth_run), k=3)

            # save magnitude input, output, cs, truth as .png
            # complex
            if FLAGS.conv == "complex":
                mag_images = np.squeeze(
                    np.absolute(
                        np.concatenate((im_out, bart_test, im_truth_run),
                                       axis=2)))
                phase_images = np.squeeze(
                    np.angle(
                        np.concatenate((im_out, bart_test, im_truth_run),
                                       axis=2)))
                diff_out = im_truth_run - im_out

                diff_cs = im_truth_run - bart_test

                diff_mag = np.squeeze(
                    np.absolute(np.concatenate((diff_out, diff_cs), axis=2)))
                diff_phase = np.squeeze(
                    np.angle(np.concatenate((diff_out, diff_cs), axis=2)))

            if FLAGS.conv == "real":
                mag_images = np.squeeze(
                    np.absolute(np.concatenate((im_in_run, im_out), axis=2)))
                phase_images = np.squeeze(
                    np.angle(np.concatenate((im_in_run, im_out), axis=2)))
                diff_in = im_truth_run - im_in_run
                diff_out = im_truth_run - im_out

                diff_mag = np.squeeze(
                    np.absolute(np.concatenate((diff_in, diff_out), axis=2)))
                diff_phase = np.squeeze(
                    np.angle(np.concatenate((diff_in, diff_out), axis=2)))

            filename = image_dir + "/mag_" + str(test_file) + ".png"
            scipy.misc.imsave(filename, mag_images)

            # filename = image_dir + "/diff_mag_" + str(test_file) + ".png"
            # scipy.misc.imsave(filename, diff_mag)

            filename = image_dir + "/phase_" + str(test_file) + ".png"
            scipy.misc.imsave(filename, phase_images)

            # filename = image_dir + "/diff_phase_" + str(test_file) + ".png"
            # scipy.misc.imsave(filename, diff_phase)

            filename = image_dir + "/diff_phase_" + str(test_file) + ".npy"
            np.save(filename, diff_phase)

            filename = image_dir + "/diff_mag_" + str(test_file) + ".npy"
            np.save(filename, diff_mag)

            psnr, nrmse, ssim = metrics.compute_all(im_truth_run,
                                                    im_out,
                                                    sos_axis=-1)
            output_psnr.append(psnr)
            output_nrmse.append(nrmse)
            output_ssim.append(ssim)

            print("output psnr, nrmse, ssim")
            print(
                np.mean(output_psnr),
                np.std(output_psnr),
                np.mean(output_nrmse),
                np.std(output_nrmse),
                np.mean(output_ssim),
                np.std(output_ssim),
            )

            psnr, nrmse, ssim = metrics.compute_all(im_truth_run,
                                                    bart_test,
                                                    sos_axis=-1)
            cs_psnr.append(psnr)
            cs_nrmse.append(nrmse)
            cs_ssim.append(ssim)

            print("cs psnr, nrmse, ssim")
            print(
                np.mean(cs_psnr),
                np.std(cs_psnr),
                np.mean(cs_nrmse),
                np.std(cs_nrmse),
                np.mean(cs_ssim),
                np.std(cs_ssim),
            )
        print("End of testing loop")
        txt_path = os.path.join(model_dir, "metrics.txt")
        f = open(txt_path, "w")
        f.write("parameters = " + str(total_parameters) + "\n"
                "output psnr = " + str(np.mean(output_psnr)) + " +\- " +
                str(np.std(output_psnr)) + "\n" + "output nrmse = " +
                str(np.mean(output_nrmse)) + " +\- " +
                str(np.std(output_nrmse)) + "\n" + "output ssim = " +
                str(np.mean(output_ssim)) + " +\- " +
                str(np.std(output_ssim)) + "\n"
                "cs psnr = " + str(np.mean(cs_psnr)) + " +\- " +
                str(np.std(cs_psnr)) + "\n" + "output nrmse = " +
                str(np.mean(cs_nrmse)) + " +\- " + str(np.std(cs_nrmse)) +
                "\n" + "output ssim = " + str(np.mean(cs_ssim)) + " +\- " +
                str(np.std(cs_ssim)))
        f.close()
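The `load` helper used to restore the model in both `main` functions is likewise not shown. A minimal TF1-style sketch of a latest-checkpoint restore (the repository's actual version may differ):

import tensorflow as tf

def load(model_dir, saver, sess):
    """Restore the latest checkpoint from model_dir; return True on success."""
    ckpt = tf.train.latest_checkpoint(model_dir)
    if ckpt is None:
        print("[*] no checkpoint found in %s" % model_dir)
        return False
    print("[*] restoring from %s" % ckpt)
    saver.restore(sess, ckpt)
    return True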
Example #7
    def test(self):
        print("testing")
        # read in a correct test dicom file to change it later
        # dicom_filename = get_testdata_files("MR_small.dcm")[0]
        # self.ds = pydicom.dcmread(dicom_filename)

        # 18 time frames in each DICOM
        max_frame = self.max_frames
        frame = 1
        case = 1
        gif = []
        print("number of test cases", self.input_files)

        total_acc = []
        mask_input = tf_util.kspace_mask(self.ks, dtype=tf.complex64)
        numel = tf.cast(tf.size(mask_input), tf.float32)
        acc = numel / tf.reduce_sum(tf.abs(mask_input))

        output_psnr = []
        output_nrmse = []
        output_ssim = []
        cs_psnr = []
        cs_nrmse = []
        cs_ssim = []

        for step in range(self.input_files):
            # whenever you run a lot of TF operations or sess.run() calls in a loop, CPU memory builds up

            print("test file #", step)

            acc_run = self.sess.run(acc)
            # acc = np.round(acc, decimals=2)
            # print("acceleration", acc_run)
            total_acc.append(acc_run)
            print(
                "total test acc:",
                np.round(np.mean(total_acc), decimals=2),
                np.round(np.std(total_acc), decimals=2),
            )

            # if (step < select*max_frame) or (step > (select+1)*max_frame):
            #     continue
            if self.data_type is "knee":
                # l1 = 0.015
                l1 = 0.02
            if self.data_type is "DCE":
                l1 = 0.05
            if self.data_type is "DCE_2D":
                l1 = 0.07

            # bart_test = self.bart_cs(sample_ks, sample_sensemap, l1=l1)
            # im_in = tf_util.model_transpose(self.ks, self.sensemap)
            # output_image, input_image, complex_truth = self.sess.run([self.im_out, im_in, self.complex_truth])
            output_image, complex_truth = self.sess.run(
                [self.im_out, self.complex_truth])
            if self.data_type is "knee":
                # input_image = np.squeeze(input_image)
                output_image = np.squeeze(output_image)
                truth_image = np.squeeze(complex_truth)

                psnr, nrmse, ssim = metrics.compute_all(truth_image,
                                                        output_image,
                                                        sos_axis=-1)
                # psnr, nrmse, ssim = metrics.compute_all(
                #     truth_image, input_image, sos_axis=-1
                # )

                # psnr = np.round(psnr, decimals=2)
                # nrmse = np.round(nrmse, decimals=2)
                # ssim = np.round(ssim, decimals=2)

                output_psnr.append(psnr)
                output_nrmse.append(nrmse)
                output_ssim.append(ssim)

                print("output psnr, nrmse, ssim")
                print(
                    np.round(np.mean(output_psnr), decimals=2),
                    np.round(np.mean(output_nrmse), decimals=2),
                    np.round(np.mean(output_ssim), decimals=2),
                )

                # psnr, nrmse, ssim = metrics.compute_all(
                #     complex_truth, bart_test, sos_axis=-1
                # )

                # psnr = np.round(psnr, decimals=2)
                # nrmse = np.round(nrmse, decimals=2)
                # ssim = np.round(ssim, decimals=2)

                # cs_psnr.append(psnr)
                # cs_nrmse.append(nrmse)
                # cs_ssim.append(ssim)
                # print("cs psnr, nrmse, ssim")
                # print(
                #     np.mean(psnr), np.mean(nrmse), np.mean(ssim),
                # )

            # mag_input = np.squeeze(np.absolute(input_image))
            # mag_cs = np.squeeze(np.absolute(bart_test))
            # mag_output = np.squeeze(np.absolute(output_image))

            # if self.data_type is "knee":
            #     # save as png
            #     mag_input = np.rot90(mag_input, k=3)
            #     # norm_max = np.amax(mag_input)
            #     mag_cs = np.rot90(mag_cs, k=3)
            #     mag_output = np.rot90(mag_output, k=3)

            #     mag_images = np.concatenate((mag_input, mag_cs, mag_output), axis=1)
            #     filename = self.image_dir + "/mag_" + str(step) + ".png"
            #     scipy.misc.imsave(filename, mag_images)
            if self.data_type is "DCE":
                gif = []
                for f in range(max_frame):
                    frame_input = mag_input[:, :, f]
                    frame_output = mag_output[:, :, f]
                    frame_cs = mag_cs[:, :, f]
                    rotated_input = np.rot90(frame_input, 1)
                    rotated_output = np.rot90(frame_output, 1)
                    rotated_cs = np.rot90(frame_cs, 1)

                    # normalize to help the CS brightness
                    newMax = np.max(rotated_input)
                    newMin = np.min(rotated_input)
                    oldMax = np.max(rotated_cs)
                    oldMin = np.min(rotated_cs)
                    rotated_cs = (rotated_cs - oldMin) * (newMax - newMin) / (
                        oldMax - oldMin) + newMin
                    full_images = np.concatenate(
                        (rotated_input, rotated_output, rotated_cs), axis=1)
                    #                 filename = self.log_dir + '/images/case' + str(step) + '_f' + str(f) + '.png'
                    #                 scipy.misc.imsave(filename, full_images)
                    new_filename = (self.log_dir + "/dicoms/" + str(step) +
                                    "_f" + str(f) + ".dcm")
                    self.write_dicom(full_images, new_filename, step, f)

                    # add each frame to a gif
                    gif.append(full_images)

                print("Saving gif")
                gif_path = self.log_dir + "/gifs/slice_" + str(step) + ".gif"
                imageio.mimsave(gif_path, gif, duration=0.2)

            # Save as PNG
            # filename = self.log_dir + '/images/case' + str(case) + '_f' + str(frame) + '.png'
            # scipy.misc.imsave(filename, saved)
            # Save as DICOM
            if self.data_type is "DCE_2D":
                rotated_input = np.rot90(mag_input, 1)
                rotated_output = np.rot90(mag_output, 1)
                rotated_cs = np.rot90(mag_cs, 1)

                full_images = np.concatenate(
                    (rotated_input, rotated_output, rotated_cs), axis=1)
                full_images = full_images * 10.0
                new_filename = (self.log_dir + "/dicoms/" + str(case) + "_f" +
                                str(frame) + ".dcm")
                dicom_output = np.squeeze(np.abs(full_images))
                self.write_dicom(dicom_output, new_filename, case, frame)

                # Save as PNG
                filename = (self.log_dir + "/images/case" + str(case) + "_f" +
                            str(frame) + ".png")
                scipy.misc.imsave(filename, full_images)

                if frame <= max_frame:
                    gif.append(full_images)
                    frame = frame + 1
                if frame > max_frame:
                    print("Max frame")
                    gif.append(full_images)
                    # gif = gif+100
                    # gif = gif.astype('uint8')
                    # timeMax = np.max(gif, axis=-1)
                    # gif = gif/np.max(gif)
                    # if self.shuffle == "False":
                    print("Saving gif")
                    gif_path = self.log_dir + \
                        "/gifs/new_gif" + str(case) + ".gif"
                    imageio.mimsave(gif_path, gif, "GIF", duration=0.2)

                    # return back to next case
                    frame = 1
                    case = case + 1
                    gif = []

                    # Create a gif
                    if frame <= max_frame:
                        mypath = self.log_dir + "/images/case" + str(case)
                        search_str = "*.png"
                        filenames = sorted(glob.glob(mypath + search_str))
                        # make and save gif
                        gif_path = self.log_dir + \
                            "/gifs/case_" + str(case) + ".gif"

                        images = []
                        for f in filenames:
                            image = scipy.misc.imread(f)
                            images.append(image)
                        filename_gif = mypath + "/case.gif"
                        imageio.mimsave(gif_path, images, duration=0.3)

        txt_path = os.path.join(self.log_dir, "output_metrics.txt")
        f = open(txt_path, "w")
        f.write("output psnr = " + str(np.mean(output_psnr)) + " +\- " +
                str(np.std(output_psnr)) + "\n" + "output nrmse = " +
                str(np.mean(output_nrmse)) + " +\- " +
                str(np.std(output_nrmse)) + "\n" + "output ssim = " +
                str(np.mean(output_ssim)) + " +\- " +
                str(np.std(output_ssim)) + "\n" + "test acc = " +
                str(np.mean(total_acc)) + " +\-" + str(np.std(total_acc)))
        f.close()
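The `write_dicom` method used to export frames is not included in this listing. Below is a minimal sketch using pydicom, written as it might appear inside the class, assuming `self.ds` holds the template dataset read from `MR_small.dcm` (as in the second `test` method below) and that a simple uint16 magnitude export is acceptable; the intensity scaling and the reuse of `SeriesNumber`/`InstanceNumber` for the case and frame indices are assumptions, not the repository's behavior.

import numpy as np

def write_dicom(self, image, filename, case, frame):
    """Write a magnitude image into a copy of the template DICOM dataset (sketch)."""
    img = np.abs(np.squeeze(image))
    img = (img / (img.max() + 1e-12) * 4095).astype(np.uint16)  # scale into a 12-bit range
    ds = self.ds
    ds.Rows, ds.Columns = img.shape
    ds.PixelData = img.tobytes()
    ds.SeriesNumber = case       # assumed tagging of case index
    ds.InstanceNumber = frame    # assumed tagging of time frame
    ds.save_as(filename)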
    def test(self):
        print("testing")
        # read in a correct test dicom file to change it later
        dicom_filename = pydicom.data.get_testdata_files("MR_small.dcm")[0]
        self.ds = pydicom.dcmread(dicom_filename)

        # 18 time frames in each DICOM
        max_frame = self.max_frames
        frame = 0
        x_slice = 0
        case = 1
        gif = []
        print("number of test cases", self.input_files)

        total_acc = []
        mask_input = tf_util.kspace_mask(self.ks, dtype=tf.complex64)
        numel = tf.cast(tf.size(mask_input), tf.float32)
        acc = numel / tf.reduce_sum(tf.abs(mask_input))

        input_psnr = []
        input_nrmse = []
        input_ssim = []
        output_psnr = []
        output_nrmse = []
        output_ssim = []
        cs_psnr = []
        cs_nrmse = []
        cs_ssim = []

        input_volume = np.zeros((self.max_frames, 192, 80, 180))
        output_volume = np.zeros((self.max_frames, 192, 80, 180))
        cs_volume = np.zeros((self.max_frames, 192, 80, 180))

        model_time = []
        cs_time = []
        for step in range(self.input_files // 2):
            # for step in range(20):
            # DCE_2D: iterator will see each slice followed by 18 time frames
            # then the next slice
            print("test file #", step)
            acc_run = self.sess.run(acc)
            total_acc.append(acc_run)
            print(
                "total test acc:",
                np.round(np.mean(total_acc), decimals=2),
                np.round(np.std(total_acc), decimals=2),
            )
            if self.data_type is "knee":
                # l1 = 0.015
                l1 = 0.0035
            if self.data_type is "DCE":
                # l1 = 0.05
                l1 = 0.01
            if self.data_type is "DCE_2D":
                l1 = 0.05

            model_start_time = time.time()
            (
                input_image,
                output_image,
                complex_truth,
                ks_run,
                sensemap_run,
            ) = self.sess.run([
                self.im_in,
                self.output_image,
                self.complex_truth,
                self.ks,
                self.sensemap,
            ])
            runtime = time.time() - model_start_time
            if step != 1:
                model_time.append(runtime)
            print("GAN: %s seconds" % np.mean(model_time),
                  "+/- %s" % np.std(model_time))

            # bart_test = np.zeros_like(output_image)
            cs_start_time = time.time()
            bart_test = self.bart_cs(ks_run, sensemap_run, l1=l1)
            runtime = time.time() - cs_start_time
            if step != 1:
                cs_time.append(runtime)
            print("CS: %s seconds" % np.mean(cs_time),
                  "+/- %s" % np.std(cs_time))

            if self.data_type is "knee":
                input_image = np.squeeze(input_image)
                output_image = np.squeeze(output_image)
                truth_image = np.squeeze(complex_truth)
                cs_image = np.squeeze(bart_test)

                psnr, nrmse, ssim = metrics.compute_all(truth_image,
                                                        cs_image,
                                                        sos_axis=-1)
                cs_psnr.append(psnr)
                cs_nrmse.append(nrmse)
                cs_ssim.append(ssim)

                print("cs psnr, nrmse, ssim")
                print(
                    np.round(np.mean(cs_psnr), decimals=2),
                    np.round(np.mean(cs_nrmse), decimals=2),
                    np.round(np.mean(cs_ssim), decimals=2),
                )

                psnr, nrmse, ssim = metrics.compute_all(truth_image,
                                                        output_image,
                                                        sos_axis=-1)

                output_psnr.append(psnr)
                output_nrmse.append(nrmse)
                output_ssim.append(ssim)

                print("output psnr, nrmse, ssim")
                print(
                    np.round(np.mean(output_psnr), decimals=2),
                    np.round(np.mean(output_nrmse), decimals=2),
                    np.round(np.mean(output_ssim), decimals=2),
                )

                psnr, nrmse, ssim = metrics.compute_all(truth_image,
                                                        input_image,
                                                        sos_axis=-1)
                input_psnr.append(psnr)
                input_nrmse.append(nrmse)
                input_ssim.append(ssim)

                print("input psnr, nrmse, ssim")
                print(
                    np.round(np.mean(input_psnr), decimals=2),
                    np.round(np.mean(input_nrmse), decimals=2),
                    np.round(np.mean(input_ssim), decimals=2),
                )

            def rotate_image(img):
                img = np.squeeze(np.absolute(img))
                if self.data_type is "DCE":
                    img = np.transpose(img, axes=(1, 0, 2))
                    img = np.flip(img, axis=2)  # flip the time
                if self.data_type is "DCE_2D":
                    img = np.transpose(img, axes=(1, 0))
                return img

            mag_input = rotate_image(input_image)
            mag_output = rotate_image(output_image)
            mag_cs = rotate_image(bart_test)

            # x, y, z, time
            if self.data_type is "DCE":
                input_volume[step, :, :, :] = mag_input
                output_volume[step, :, :, :] = mag_output
                cs_volume[step, :, :, :] = mag_cs
            if self.data_type is "DCE_2D":
                input_volume[frame, x_slice, :, :] = mag_input
                output_volume[frame, x_slice, :, :] = mag_output
                cs_volume[frame, x_slice, :, :] = mag_cs

                new_filename = (self.log_dir + "/dicoms/" + "output_slice_" +
                                str(x_slice) + "_f" + str(frame) + ".dcm")
                self.write_dicom(mag_input, new_filename, x_slice, frame)

                # increment frame
                # if frame is 17, go back to next slice
                if frame == self.max_frames - 1:
                    frame = 0
                    x_slice += 1
                else:
                    frame += 1
                print("slice", x_slice, "time frame", frame)

        in_sl = np.abs(input_volume[2, 0, :, :])

        filename = os.path.join(self.log_dir,
                                os.path.basename(self.search_str[:-11]))
        input_dir = filename + "_input" + ".npy"
        output_dir = filename + "_output" + ".npy"
        cs_dir = filename + "_cs" + ".npy"
        print("saving numpy volumes")
        np.save(input_dir, input_volume)
        np.save(output_dir, output_volume)
        np.save(cs_dir, cs_volume)
        print(output_dir)
        print("saving cfl volumes")
        cfl.write(input_dir, input_volume, "R")
        cfl.write(output_dir, output_volume, "R")
        cfl.write(cs_dir, cs_volume, "R")

        if self.data_type is "knee":
            print("output psnr = " + str(np.mean(output_psnr)) + " +\- " +
                  str(np.std(output_psnr)) + "\n" + "output nrmse = " +
                  str(np.mean(output_nrmse)) + " +\- " +
                  str(np.std(output_nrmse)) + "\n" + "output ssim = " +
                  str(np.mean(output_ssim)) + " +\- " +
                  str(np.std(output_ssim)) + "\n" + "test acc = " +
                  str(np.mean(total_acc)) + " +\-" + str(np.std(total_acc)))
            txt_path = os.path.join(self.log_dir, "output_metrics.txt")
            f = open(txt_path, "w")
            f.write("output psnr = " + str(np.mean(output_psnr)) + " +\- " +
                    str(np.std(output_psnr)) + "\n" + "output nrmse = " +
                    str(np.mean(output_nrmse)) + " +\- " +
                    str(np.std(output_nrmse)) + "\n" + "output ssim = " +
                    str(np.mean(output_ssim)) + " +\- " +
                    str(np.std(output_ssim)) + "\n" + "input psnr = " +
                    str(np.mean(input_psnr)) + " +\- " +
                    str(np.std(input_psnr)) + "\n" + "input nrmse = " +
                    str(np.mean(input_nrmse)) + " +\- " +
                    str(np.std(input_nrmse)) + "\n" + "input ssim = " +
                    str(np.mean(input_ssim)) + " +\- " +
                    str(np.std(input_ssim)) + "\n" + "test acc = " +
                    str(np.mean(total_acc)) + " +\-" + str(np.std(total_acc)))
            f.close()
            txt_path = os.path.join(self.log_dir, "cs_metrics.txt")
            f = open(txt_path, "w")
            f.write("cs psnr = " + str(np.mean(cs_psnr)) + " +\- " +
                    str(np.std(cs_psnr)) + "\n" + "output nrmse = " +
                    str(np.mean(cs_nrmse)) + " +\- " + str(np.std(cs_nrmse)) +
                    "\n" + "output ssim = " + str(np.mean(cs_ssim)) + " +\- " +
                    str(np.std(cs_ssim)))
            f.close()