# Example #1
# 0
def bart_cs(bart_dir, ks, sensemap, l1=0.01):
    """Run an l1-wavelet regularized BART ``pics`` reconstruction.

    ``ks`` and ``sensemap`` are each squeezed and given a new singleton axis
    at position -2. They are written as CFL files ``file_ks`` and
    ``file_sensemap`` inside ``bart_dir``, and BART writes its output to
    ``bart_dir/file_img``.

    Args:
        bart_dir: Directory that holds the intermediate CFL files.
        ks: k-space array.
        sensemap: Sensitivity-map array.
        l1: Regularization weight substituted into ``-R W:3:0:<l1>``.

    Returns:
        Whatever ``load_recon(<img file>, <sensemap file>)`` returns.
    """
    # Both inputs get identical preprocessing: drop singleton axes, then
    # insert one new axis just before the last one.
    prepared_ks, prepared_maps = [
        np.expand_dims(np.squeeze(arr), -2) for arr in (ks, sensemap)
    ]
    paths = {
        key: os.path.join(bart_dir, "file_" + key)
        for key in ("ks", "sensemap", "img")
    }

    cfl.write(paths["ks"], prepared_ks, "R")
    cfl.write(paths["sensemap"], prepared_maps, "R")

    # l1-wavelet regularized PICS (-R W:...), iteration limit -i 100.
    pics_flags = "-S -e -R W:3:0:%f -i 100" % l1
    command = "{} pics {} {} {} {}".format(
        BIN_BART, pics_flags, paths["ks"], paths["sensemap"], paths["img"]
    )
    subprocess.check_call(["bash", "-c", command])
    return load_recon(paths["img"], paths["sensemap"])
# Example #2
# 0
def bart_pics(
    ks_input,
    verbose=False,
    sensemap=None,
    shape_e=2,
    do_cs=True,
    do_imag_reg=False,
    filename_ks_tmp="ks.tmp",
    filename_map_tmp="map.tmp",
    filename_im_tmp="im.tmp",
    filename_ks_out_tmp="ks_out.tmp",
):
    """BART PICS reconstruction followed by k-space resynthesis.

    Every step shells out (``bash -c``) to the ``bart`` command:

    1. ``ks_input`` is written to ``filename_ks_tmp``.
    2. Sensitivity maps: if ``sensemap`` is None they are estimated with
       ``bart ecalib -m <shape_e> -c 1e-9``; otherwise ``sensemap`` is
       written to ``filename_map_tmp``.
    3. ``bart pics`` runs with ``-l1`` (when ``do_cs``) or ``-l2``,
       ``-r 1e-2`` and ``-S``, plus ``-R R1:7:1e-1`` when ``do_imag_reg``.
       The image goes to ``filename_im_tmp``.
    4. ``bart fakeksp -r`` builds k-space from that image, the input k-space
       and the maps into ``filename_ks_out_tmp``. ``-r`` presumably restores
       the measured samples; confirm against the BART docs.

    Args:
        ks_input: k-space array handed to BART.
        verbose: If True, print a banner and every shell command run.
        sensemap: Optional precomputed sensitivity maps.
        shape_e: Number of ESPIRiT maps (``ecalib -m``).
        do_cs: Use l1 (True) or l2 (False) regularization.
        do_imag_reg: Add the ``R1:7:1e-1`` regularizer.
        filename_*_tmp: Paths of the intermediate CFL files.

    Returns:
        The k-space read from ``filename_ks_out_tmp``, squeezed and given a
        new leading axis.
    """

    def run(command):
        # Echo (when verbose) and execute one BART command through bash.
        if verbose:
            print("  %s" % command)
        subprocess.check_output(["bash", "-c", command])

    if verbose:
        print("PICS (l1-ESPIRiT) reconstruction...")

    cfl.write(filename_ks_tmp, ks_input)
    if sensemap is not None:
        cfl.write(filename_map_tmp, sensemap)
    else:
        run("bart ecalib -m %d -c 1e-9 %s %s"
            % (shape_e, filename_ks_tmp, filename_map_tmp))

    flags = ("-l1" if do_cs else "-l2") + " -r 1e-2"
    if do_imag_reg:
        flags += " -R R1:7:1e-1"

    run("bart pics %s -S %s %s %s"
        % (flags, filename_ks_tmp, filename_map_tmp, filename_im_tmp))
    run("bart fakeksp -r %s %s %s %s"
        % (filename_im_tmp, filename_ks_tmp, filename_map_tmp,
           filename_ks_out_tmp))

    resynthesized = np.squeeze(cfl.read(filename_ks_out_tmp))
    return np.expand_dims(resynthesized, axis=0)
    def bart_cs(self, ks, sensemap, l1=0.01):
        """Run a BART ``pics`` compressed-sensing baseline reconstruction.

        The regularizer and array layout depend on ``self.data_type``:

        * ``"knee"`` / ``"DCE_2D"``: l1-wavelet (``-R W:3:0:<l1>``). k-space
          is squeezed and given a singleton axis at position -2.
        * ``"DCE"``: low-rank (``-R L:7:7:<l1>``). k-space is squeezed, its
          last two axes are swapped (transpose ``[0, 1, 3, 2]``, so it must
          be 4-D after squeezing), and singleton axes are inserted at -2 and
          then at 2.

        In both cases the sensitivity maps are squeezed and given a singleton
        axis at -2. Intermediate CFL files are written to ``self.bart_dir``.

        Args:
            ks: k-space array (layout assumed to match the data type; TODO
                confirm against the data pipeline).
            sensemap: Sensitivity maps.
            l1: Regularization weight.

        Returns:
            The result of ``self.load_recon(img_dir, sense_dir)``.

        Raises:
            NotImplementedError: ``self.data_type`` has no BART recipe.
        """
        # Compare strings by value. The former `x is "knee" or "DCE_2D"` was
        # always true (a non-empty literal is truthy), so DCE never reached
        # its own low-rank branch.
        if self.data_type in ("knee", "DCE_2D"):
            cfl_ks = np.expand_dims(np.squeeze(ks), -2)
            # L1-wavelet regularized
            cmd_flags = "-S -e -R W:3:0:%f -i 100" % l1
        elif self.data_type == "DCE":
            cfl_ks = np.transpose(np.squeeze(ks), [0, 1, 3, 2])
            cfl_ks = np.expand_dims(cfl_ks, -2)
            cfl_ks = np.expand_dims(cfl_ks, 2)
            # Low-rank regularized (block setting might need to be 3:3)
            cmd_flags = "-S -e -R L:7:7:%f -i 100" % l1
        else:
            raise NotImplementedError(
                "implement bart for this data type: %r" % (self.data_type,))
        cfl_sensemap = np.expand_dims(np.squeeze(sensemap), -2)

        ks_dir = os.path.join(self.bart_dir, "file_ks")
        sense_dir = os.path.join(self.bart_dir, "file_sensemap")
        img_dir = os.path.join(self.bart_dir, "file_img")

        cfl.write(ks_dir, cfl_ks, "R")
        cfl.write(sense_dir, cfl_sensemap, "R")

        cmd = "%s pics %s %s %s %s" % (
            BIN_BART,
            cmd_flags,
            ks_dir,
            sense_dir,
            img_dir,
        )
        subprocess.check_call(["bash", "-c", cmd])
        return self.load_recon(img_dir, sense_dir)
# Example #4
# 0
def bart_espirit(
    ks_input,
    shape=None,
    verbose=False,
    shape_e=2,
    crop_value=None,
    cal_size=None,
    smooth=False,
    filename_ks_tmp="ks.tmp",
    filename_map_tmp="map.tmp",
):
    """Estimate sensitivity maps with BART ESPIRiT (``bart ecalib``).

    ks_input dimensions: [emaps, channels, kz, ky, kx]

    Args:
        ks_input: k-space array written to ``filename_ks_tmp``.
        shape: Optional (ky, kx)-style pair. Axes 2 and 3 of ``ks_input``
            are cropped and then zero-padded to it via ``recon.crop`` /
            ``recon.zeropad``.
        verbose: Print progress and the shell command.
        shape_e: Number of maps (``-m``).
        crop_value: Optional ``-c`` threshold.
        cal_size: Optional calibration size (``-r``).
        smooth: Add ``-S``.
        filename_ks_tmp, filename_map_tmp: Intermediate CFL paths.

    Returns:
        ``(sensemap, seconds)``: the maps read back from ``filename_map_tmp``
        and the wall-clock time spent in the ``ecalib`` call.
    """
    if verbose:
        print("Estimating sensitivity maps...")
    if shape is not None:
        target = (-1, -1, shape[0], shape[1], -1)
        # Fresh lists per call, mirroring two independent list literals.
        cropped = recon.crop(ks_input, list(target))
        ks_input = recon.zeropad(cropped, list(target))

    flag_parts = []
    if crop_value is not None:
        flag_parts.append("-c %f " % crop_value)
    if cal_size is not None:
        flag_parts.append("-r %d " % cal_size)
    if smooth:
        flag_parts.append("-S ")
    flags = "".join(flag_parts)

    cfl.write(filename_ks_tmp, ks_input)
    command = "bart ecalib -m %d %s %s %s" % (
        shape_e, flags, filename_ks_tmp, filename_map_tmp)
    if verbose:
        print("  %s" % command)

    started = timer()
    subprocess.check_output(["bash", "-c", command])
    elapsed = timer() - started

    return cfl.read(filename_map_tmp), elapsed
# Example #5
# 0
def setup_data_tfrecords_MFAST(
    dir_in_root,
    dir_out,
    data_divide=(0.75, 0.05, 0.2),
    min_shape=(80, 180),
    num_maps=1,
    crop_maps=False,
    verbose=False,
):
    """Set up training data as per-slice tfrecords split into train/validate/test.

    Each entry of ``dir_in_root`` is treated as a case directory. Processing
    a case works as follows:

    * It is skipped unless ``<case>/kspace.hdr`` exists.
    * It is skipped unless its squeezed k-space has 32 coils on axis 0.
    * Otherwise, the k-space is inverse-FFT'd along its last axis, cast to
      complex64, and written to ``<case>/kspace_resized``.
    * ESPIRiT maps are estimated from that file into ``<case>/sensemap``
      with ``BIN_BART ecalib``.
    * One ``<case>_xNNN.tfrecords`` file is written per index of the last
      k-space axis.

    Cases are shuffled and assigned to ``dir_out/{train,validate,test}``
    according to ``data_divide``.

    Example:
        setup_data_tfrecords_MFAST('/mnt/raid3/data/Studies_DCE/recon-ccomp6/',
            '/mnt/raid3/jycheng/Project/deepspirit/data/train/', verbose=True)

    Args:
        dir_in_root: Directory whose entries are the input cases.
        dir_out: Output root. ``train``, ``validate`` and ``test`` are
            created under it (parents too, if missing).
        data_divide: (train, validate, test) fractions. Only the first two
            are used; the test split receives every remaining case.
        min_shape: Currently unused; kept for backward compatibility.
        num_maps: Number of ESPIRiT maps (``ecalib -m``).
        crop_maps: If True, pass ``-c 1e-9`` to ``ecalib``.
        verbose: Print progress information.
    """
    if verbose:
        print("Directory names:")
        print("  Input root:  %s" % dir_in_root)
        print("  Output root: %s" % dir_out)

    file_kspace = "kspace"
    file_sensemap = "sensemap"

    case_list = os.listdir(dir_in_root)
    random.shuffle(case_list)
    num_cases = len(case_list)

    # Exclusive split boundaries: [0, i_train_1) -> train,
    # [i_train_1, i_validate_1) -> validate, the rest -> test. Validation
    # starts exactly at i_train_1; starting one past it (as before) gave the
    # validate split one extra case.
    i_train_1 = int(np.round(data_divide[0] * num_cases))
    i_validate_1 = i_train_1 + int(np.round(data_divide[1] * num_cases))

    for split in ("train", "validate", "test"):
        os.makedirs(os.path.join(dir_out, split), exist_ok=True)

    i_case = 0
    for case_name in case_list:
        file_kspace_i = os.path.join(dir_in_root, case_name, file_kspace)
        file_sensemap_i = os.path.join(dir_in_root, case_name, file_sensemap)

        if not os.path.exists(file_kspace_i + ".hdr"):
            print("skipping due to kspace not existing in this folder")
            continue

        kspace = np.squeeze(cfl.read(file_kspace_i))

        # Validate before assigning a split slot, so rejected cases do not
        # shift the train/validate/test proportions. `!=` (not `is not`):
        # int identity only matches by accident of CPython's small-int cache.
        num_coils = kspace.shape[0]
        if num_coils != 32:
            print("skipping due to incorrect number of coils")
            continue

        if i_case < i_train_1:
            split = "train"
        elif i_case < i_validate_1:
            split = "validate"
        else:
            split = "test"
        dir_out_i = os.path.join(dir_out, split)

        if verbose:
            print("Processing [%d] %s..." % (i_case, case_name))
        i_case += 1

        if verbose:
            print("  Slice shape: (%d, %d)" %
                  (kspace.shape[1], kspace.shape[2]))
            print("  Num channels: %d" % kspace.shape[0])
        shape_x = kspace.shape[-1]

        # NOTE(review): the original author questioned whether this inverse
        # FFT along the last axis belongs here ("should this be here?").
        kspace = fftc.ifftc(kspace, axis=-1)
        kspace = kspace.astype(np.complex64)
        if verbose:
            print("  kspace shape: %s" % (kspace.shape,))

        # Save the transformed k-space beside the original; ecalib below reads
        # this file, so the maps are estimated from the same data as the records.
        file_kspace_i = file_kspace_i + "_resized"
        cfl.write(file_kspace_i, kspace)

        cmd_flags = ""
        if crop_maps:
            cmd_flags += " -c 1e-9"
        cmd_flags += " -m %d" % num_maps
        cmd = "%s ecalib %s %s %s" % (
            BIN_BART,
            cmd_flags,
            file_kspace_i,
            file_sensemap_i,
        )
        if verbose:
            print("  Estimating sensitivity maps (bart espirit)...")
            print("    %s" % cmd)
        subprocess.check_call(["bash", "-c", cmd])
        sensemap = np.squeeze(cfl.read(file_sensemap_i))
        # Add a leading map axis; the records index sensemap as [m, c, z, y, x].
        # NOTE(review): assumes squeeze leaves 4 axes (num_maps == 1) — TODO
        # confirm the layout when num_maps > 1.
        sensemap = np.expand_dims(sensemap, axis=0).astype(np.complex64)

        if verbose:
            print("  Creating tfrecords (%d)..." % shape_x)
        for i_x in range(shape_x):
            file_out = os.path.join(dir_out_i,
                                    "%s_x%03d.tfrecords" % (case_name, i_x))
            kspace_x = kspace[:, :, :, i_x]
            sensemap_x = sensemap[:, :, :, :, i_x]

            example = tf.train.Example(features=tf.train.Features(
                feature={
                    "name": _bytes_feature(str.encode(case_name)),
                    "xslice": _int64_feature(i_x),
                    "ks_shape_x": _int64_feature(kspace.shape[3]),
                    "ks_shape_y": _int64_feature(kspace.shape[2]),
                    "ks_shape_z": _int64_feature(kspace.shape[1]),
                    "ks_shape_c": _int64_feature(kspace.shape[0]),
                    "map_shape_x": _int64_feature(sensemap.shape[4]),
                    "map_shape_y": _int64_feature(sensemap.shape[3]),
                    "map_shape_z": _int64_feature(sensemap.shape[2]),
                    "map_shape_c": _int64_feature(sensemap.shape[1]),
                    "map_shape_m": _int64_feature(sensemap.shape[0]),
                    # tobytes() is the non-deprecated name of tostring().
                    "ks": _bytes_feature(kspace_x.tobytes()),
                    "map": _bytes_feature(sensemap_x.tobytes()),
                }))

            # The context manager closes the writer even if write() raises.
            with tf.python_io.TFRecordWriter(file_out) as tf_writer:
                tf_writer.write(example.SerializeToString())
    def test(self):
        """Evaluate the network against a BART CS baseline on the test set.

        For ``self.input_files // 2`` steps this method does the following:

        * Evaluates the sampling acceleration.
        * Runs one batch through ``self.sess`` and times the network forward
          pass.
        * Times a ``self.bart_cs`` reconstruction of the same k-space.
        * Stores magnitude images in ``(self.max_frames, 192, 80, 180)``
          volumes.
        * For ``"knee"`` data, prints running PSNR/NRMSE/SSIM of the CS, output
          and input images against the ground truth.
        * For ``"DCE_2D"`` data, also writes one DICOM per (slice, frame).

        The volumes are saved as ``.npy`` and CFL files under
        ``self.log_dir``. Knee metrics are additionally written to
        ``output_metrics.txt`` and ``cs_metrics.txt``.
        """
        print("testing")
        # read in a correct test dicom file to change it later
        dicom_filename = pydicom.data.get_testdata_files("MR_small.dcm")[0]
        self.ds = pydicom.dcmread(dicom_filename)

        # DCE_2D write position: time frame within the current slice, and
        # slice index.
        frame = 0
        x_slice = 0
        print("number of test cases", self.input_files)

        total_acc = []
        mask_input = tf_util.kspace_mask(self.ks, dtype=tf.complex64)
        numel = tf.cast(tf.size(mask_input), tf.float32)
        # Acceleration = mask size / sum(|mask|), i.e. total / sampled points
        # assuming a 0/1 mask.
        acc = numel / tf.reduce_sum(tf.abs(mask_input))

        input_psnr = []
        input_nrmse = []
        input_ssim = []
        output_psnr = []
        output_nrmse = []
        output_ssim = []
        cs_psnr = []
        cs_nrmse = []
        cs_ssim = []

        # NOTE(review): per-frame volume size (192, 80, 180) is hard-coded.
        volume_shape = (self.max_frames, 192, 80, 180)
        input_volume = np.zeros(volume_shape)
        output_volume = np.zeros(volume_shape)
        cs_volume = np.zeros(volume_shape)

        model_time = []
        cs_time = []

        # Regularization weight for the BART CS baseline (loop-invariant).
        if self.data_type == "knee":
            l1 = 0.0035
        elif self.data_type == "DCE":
            l1 = 0.01
        elif self.data_type == "DCE_2D":
            l1 = 0.05

        def rotate_image(img):
            # Magnitude, squeezed, reoriented for display per data type.
            img = np.squeeze(np.absolute(img))
            if self.data_type == "DCE":
                img = np.transpose(img, axes=(1, 0, 2))
                img = np.flip(img, axis=2)  # flip the time
            elif self.data_type == "DCE_2D":
                img = np.transpose(img, axes=(1, 0))
            return img

        for step in range(self.input_files // 2):
            # DCE_2D: iterator will see each slice followed by its time
            # frames, then the next slice
            print("test file #", step)
            acc_run = self.sess.run(acc)
            total_acc.append(acc_run)
            print(
                "total test acc:",
                np.round(np.mean(total_acc), decimals=2),
                np.round(np.std(total_acc), decimals=2),
            )

            model_start_time = time.time()
            (
                input_image,
                output_image,
                complex_truth,
                ks_run,
                sensemap_run,
            ) = self.sess.run([
                self.im_in,
                self.output_image,
                self.complex_truth,
                self.ks,
                self.sensemap,
            ])
            runtime = time.time() - model_start_time
            # NOTE(review): step 1 is excluded from the timing statistics; if
            # the goal is to drop the warm-up run, step 0 may be intended.
            if step != 1:
                model_time.append(runtime)
            print("GAN: %s seconds" % np.mean(model_time),
                  "+/- %s" % np.std(model_time))

            cs_start_time = time.time()
            bart_test = self.bart_cs(ks_run, sensemap_run, l1=l1)
            runtime = time.time() - cs_start_time
            if step != 1:
                cs_time.append(runtime)
            print("CS: %s seconds" % np.mean(cs_time),
                  "+/- %s" % np.std(cs_time))

            if self.data_type == "knee":
                truth_image = np.squeeze(complex_truth)
                # Same metric bookkeeping for CS, network output and input,
                # in that order.
                for label, image, psnrs, nrmses, ssims in (
                    ("cs", np.squeeze(bart_test),
                     cs_psnr, cs_nrmse, cs_ssim),
                    ("output", np.squeeze(output_image),
                     output_psnr, output_nrmse, output_ssim),
                    ("input", np.squeeze(input_image),
                     input_psnr, input_nrmse, input_ssim),
                ):
                    psnr, nrmse, ssim = metrics.compute_all(truth_image,
                                                            image,
                                                            sos_axis=-1)
                    psnrs.append(psnr)
                    nrmses.append(nrmse)
                    ssims.append(ssim)

                    print("%s psnr, nrmse, ssim" % label)
                    print(
                        np.round(np.mean(psnrs), decimals=2),
                        np.round(np.mean(nrmses), decimals=2),
                        np.round(np.mean(ssims), decimals=2),
                    )

            mag_input = rotate_image(input_image)
            mag_output = rotate_image(output_image)
            mag_cs = rotate_image(bart_test)

            # x, y, z, time
            if self.data_type == "DCE":
                input_volume[step, :, :, :] = mag_input
                output_volume[step, :, :, :] = mag_output
                cs_volume[step, :, :, :] = mag_cs
            elif self.data_type == "DCE_2D":
                input_volume[frame, x_slice, :, :] = mag_input
                output_volume[frame, x_slice, :, :] = mag_output
                cs_volume[frame, x_slice, :, :] = mag_cs

                new_filename = (self.log_dir + "/dicoms/" + "output_slice_" +
                                str(x_slice) + "_f" + str(frame) + ".dcm")
                # NOTE(review): the *input* magnitude is written to a file
                # named "output_slice_*"; confirm whether mag_output was meant.
                self.write_dicom(mag_input, new_filename, x_slice, frame)

                # Advance the frame; after the last frame move to the next
                # slice.
                if frame == self.max_frames - 1:
                    frame = 0
                    x_slice += 1
                else:
                    frame += 1
                print("slice", x_slice, "time frame", frame)

        filename = os.path.join(self.log_dir,
                                os.path.basename(self.search_str[:-11]))
        input_dir = filename + "_input" + ".npy"
        output_dir = filename + "_output" + ".npy"
        cs_dir = filename + "_cs" + ".npy"
        print("saving numpy volumes")
        np.save(input_dir, input_volume)
        np.save(output_dir, output_volume)
        np.save(cs_dir, cs_volume)
        print(output_dir)
        print("saving cfl volumes")
        cfl.write(input_dir, input_volume, "R")
        cfl.write(output_dir, output_volume, "R")
        cfl.write(cs_dir, cs_volume, "R")

        if self.data_type == "knee":

            def summary(label, values, sep=" +\\- "):
                # "<label> = <mean><sep><std>"; "\\" keeps the literal
                # backslash the report has always contained.
                return (label + " = " + str(np.mean(values)) + sep +
                        str(np.std(values)))

            # The accuracy line has never had a space after "+\-".
            acc_line = summary("test acc", total_acc, sep=" +\\-")
            output_lines = [
                summary("output psnr", output_psnr),
                summary("output nrmse", output_nrmse),
                summary("output ssim", output_ssim),
            ]
            input_lines = [
                summary("input psnr", input_psnr),
                summary("input nrmse", input_nrmse),
                summary("input ssim", input_ssim),
            ]
            # CS metrics were previously mislabelled "output nrmse/ssim".
            cs_lines = [
                summary("cs psnr", cs_psnr),
                summary("cs nrmse", cs_nrmse),
                summary("cs ssim", cs_ssim),
            ]

            print("\n".join(output_lines + [acc_line]))
            txt_path = os.path.join(self.log_dir, "output_metrics.txt")
            with open(txt_path, "w") as f:
                f.write("\n".join(output_lines + input_lines + [acc_line]))
            txt_path = os.path.join(self.log_dir, "cs_metrics.txt")
            with open(txt_path, "w") as f:
                f.write("\n".join(cs_lines))