示例#1
0
    def get_everything(self, sess, features):
        """Build the zero-filled and ground-truth images and evaluate them.

        Args:
            sess: session used to evaluate the tensors.
            features: mapping with the "ks_input", "sensemap", "ks_truth" and
                "mask_recon" tensors.

        Returns:
            Tuple ``(im_lowres, im_truth, sensemap, mask)`` of evaluated
            values. Both images are converted with
            ``tf_util.complex_to_channels`` before evaluation.
        """
        ks_input = features["ks_input"]
        sensemap = features["sensemap"]
        # The sampling mask is always derived from the measured k-space here,
        # so it can never be None and needs no fallback.
        mask = tf_util.kspace_mask(ks_input, dtype=tf.complex64)
        ks_input = mask * ks_input

        # Zero-filled image from the masked k-space. (The former `window = 1`
        # multiplier was a no-op and has been dropped.)
        im_lowres = tf_util.model_transpose(ks_input, sensemap)
        # Give the tensor a stable name in the graph.
        im_lowres = tf.identity(im_lowres, name="low_res_image")
        # complex to channels
        im_lowres = tf_util.complex_to_channels(im_lowres)

        ks_truth = features["ks_truth"]
        mask_recon = features["mask_recon"]

        # Reference image from the ground-truth k-space restricted by
        # mask_recon.
        im_truth = tf_util.model_transpose(ks_truth * mask_recon, sensemap)
        im_truth = tf.identity(im_truth, name="truth_image")
        im_truth = tf_util.complex_to_channels(im_truth)

        im_lowres, im_truth, sensemap, mask = sess.run(
            [im_lowres, im_truth, sensemap, mask])

        return im_lowres, im_truth, sensemap, mask
示例#2
0
def _create_summary(sense_place, ks_place, im_out_place, im_truth_place):
    """Register TensorBoard image summaries for a reconstruction.

    Args:
        sense_place: sensitivity-map tensor used by the forward/adjoint models.
        ks_place: measured k-space tensor.
        im_out_place: reconstructed image tensor.
        im_truth_place: ground-truth image tensor.

    Every image summary is capped at ``FLAGS.num_summary_image`` outputs.
    """

    def _sos(tensor):
        # Sum-of-squares magnitude, keeping the reduced axis.
        return tf_util.sumofsq(tensor, keep_dims=True)

    im_in = tf_util.model_transpose(ks_place, sense_place)
    in_mask = tf_util.kspace_mask(ks_place, dtype=tf.complex64)
    ks_out = tf_util.model_forward(im_out_place, sense_place)
    ks_gt = tf_util.model_forward(im_truth_place, sense_place)

    with tf.name_scope("input-output-truth"):
        # Three panels side by side along axis 2 (the width axis of the
        # [batch, height, width, channels] layout tf.summary.image expects).
        ks_panels = tf.concat((_sos(ks_place), _sos(ks_out), _sos(ks_gt)),
                              axis=2)
        # Log scale makes low-energy k-space visible; epsilon avoids log(0).
        tf.summary.image("kspace",
                         tf.log(ks_panels + 1e-6),
                         max_outputs=FLAGS.num_summary_image)
        im_panels = tf.concat(
            (_sos(im_in), _sos(im_out_place), _sos(im_truth_place)), axis=2)
        tf.summary.image("image",
                         im_panels,
                         max_outputs=FLAGS.num_summary_image)

    with tf.name_scope("truth"):
        # Real part of the truth image summed over its last axis.
        truth_real = tf.real(
            tf.reduce_sum(im_truth_place, axis=-1, keep_dims=True))
        tf.summary.image("image_real",
                         truth_real,
                         max_outputs=FLAGS.num_summary_image)

    with tf.name_scope("mask"):
        tf.summary.image("mask",
                         _sos(in_mask),
                         max_outputs=FLAGS.num_summary_image)

    with tf.name_scope("sensemap"):
        # Magnitude of the maps, keeping only the first entry along axis 3.
        maps = tf.slice(tf.abs(sense_place), [0, 0, 0, 0, 0],
                        [-1, -1, -1, 1, -1])
        maps = tf.transpose(maps, [0, 1, 4, 2, 3])
        # Flatten everything after the first two axes into one tiled axis.
        batch_dim = tf.shape(maps)[0]
        rows_dim = tf.shape(maps)[1]
        maps = tf.reshape(maps, [batch_dim, rows_dim, -1])
        tf.summary.image("image",
                         tf.expand_dims(maps, axis=-1),
                         max_outputs=FLAGS.num_summary_image)
示例#3
0
    def get_images(self, features):
        """Build the zero-filled input image and the ground-truth image.

        Args:
            features: mapping with the "ks_input", "sensemap", "ks_truth" and
                "mask_recon" tensors.

        Returns:
            Tuple ``(im_lowres, im_truth)`` of tensors, both converted with
            ``tf_util.complex_to_channels``.
        """
        ks_input = features["ks_input"]
        sensemap = features["sensemap"]
        # The mask is always derived from the measured k-space, so it is never
        # None here and needs no fallback.
        mask = tf_util.kspace_mask(ks_input, dtype=tf.complex64)
        ks_input = mask * ks_input

        # Zero-filled image from the masked k-space, with a stable graph name.
        im_lowres = tf_util.model_transpose(ks_input, sensemap)
        im_lowres = tf.identity(im_lowres, name="low_res_image")
        # complex to channels
        im_lowres = tf_util.complex_to_channels(im_lowres)

        ks_truth = features["ks_truth"]
        mask_recon = features["mask_recon"]

        # Reference image from the ground-truth k-space restricted by
        # mask_recon.
        im_truth = tf_util.model_transpose(ks_truth * mask_recon, sensemap)
        im_truth = tf.identity(im_truth, name="truth_image")
        im_truth = tf_util.complex_to_channels(im_truth)

        return im_lowres, im_truth
示例#4
0
    def create_summary(self):
        """Create the TensorBoard summaries for the GAN training graph.

        Always adds scalar summaries for the discriminator and generator
        losses, the gradient penalty and the mean discriminator logits, plus a
        histogram of the generator input image. For "knee" data it also adds
        input/output/truth image panels, L1/L2 errors between self.X_gen and
        self.z_truth, and the magnitude of self.Y_real.

        Sets self.input_image, self.d_loss_sum, self.g_loss_sum, self.gp_sum,
        self.d_fake, self.d_real, self.z_sum, self.d_sum, self.g_sum and
        self.train_sum (the merge of every summary in the graph).
        """
        # Input image to generator.
        # Note that ks is based on the input data, not on the rotated data.
        self.input_image = tf_util.model_transpose(self.ks, self.sensemap)

        # Compare with ==: ``is`` tests object identity and is False for an
        # equal string that is a different object (e.g. one parsed from flags).
        if self.data_type == "knee":
            truth_image = tf_util.channels_to_complex(self.z_truth)

            # flip_up_down(rot90(.)) reorients the magnitude images for display.
            sum_input = tf.image.flip_up_down(
                tf.image.rot90(tf.abs(self.input_image)))
            sum_output = tf.image.flip_up_down(
                tf.image.rot90(tf.abs(self.output_image)))
            sum_truth = tf.image.flip_up_down(
                tf.image.rot90(tf.abs(truth_image)))

            train_out = tf.concat((sum_input, sum_output, sum_truth), axis=2)
            tf.summary.image("input-output-truth", train_out)

            loss_l1 = tf.reduce_mean(tf.abs(self.X_gen - self.z_truth))
            loss_l2 = tf.reduce_mean(
                tf.square(tf.abs(self.X_gen - self.z_truth)))
            tf.summary.scalar("l1", loss_l1)
            tf.summary.scalar("l2", loss_l2)

            # to check supervised/unsupervised
            y_real = tf_util.channels_to_complex(self.Y_real)
            y_real = tf.image.flip_up_down(tf.image.rot90(tf.abs(y_real)))
            tf.summary.image("y_real/mag", y_real)

        # Plot losses
        self.d_loss_sum = tf.summary.scalar("Discriminator_loss", self.d_loss)
        self.g_loss_sum = tf.summary.scalar("Generator_loss", self.g_loss)
        self.gp_sum = tf.summary.scalar("Gradient_penalty",
                                        self.gradient_penalty)
        self.d_fake = tf.summary.scalar("subloss/D_fake",
                                        tf.reduce_mean(self.d_logits_fake))
        self.d_real = tf.summary.scalar("subloss/D_real",
                                        tf.reduce_mean(self.d_logits_real))
        self.z_sum = tf.summary.histogram(
            "z", tf_util.complex_to_channels(self.input_image))
        self.d_sum = tf.summary.merge(
            [self.z_sum, self.d_loss_sum, self.d_fake, self.d_real])
        self.g_sum = tf.summary.merge([self.z_sum, self.g_loss_sum])
        self.train_sum = tf.summary.merge_all()
示例#5
0
    def read_real(self):
        """Set up the iterator over the "real" undersampled validation images.

        Also stores the iterator's masks on ``self.masks``.

        Returns:
            Tuple ``(img_real, data_num, real_mask)``: the truth image tensor
            for the next batch, the number of files behind the iterator, and
            the sampling mask derived from that batch's "ks_input" k-space.
        """
        val_iter = mri_utils.Iterator(
            self.batch_size,
            self.mask_path,
            self.data_type,
            "validate",
            self.out_shape,
            verbose=self.verbose,
            data_dir=self.data_dir,
        )
        file_count = val_iter.num_files
        next_batch = val_iter.iterator.get_next()

        truth_img = val_iter.get_truth_image(next_batch)
        self.masks = val_iter.masks

        batch_mask = tf_util.kspace_mask(next_batch["ks_input"],
                                         dtype=tf.complex64)
        return truth_img, file_count, batch_mask
def unroll_fista(
    ks_input,
    sensemap,
    num_grad_steps=5,
    resblock_num_features=128,
    resblock_num_blocks=2,
    is_training=True,
    scope="MRI",
    mask_output=1,
    window=None,
    do_hardproj=True,
    num_summary_image=0,
    mask=None,
    verbose=False,
    conv="real",
    do_conjugate=False,
    activation="relu",
):
    """Create general unrolled network for MRI.
    x_{k+1} = S( x_k - 2 * t * A^T W (A x- b) )
            = S( x_k - 2 * t * (A^T W A x - A^T W b))

    Each of the ``num_grad_steps`` iterations takes a gradient step on the
    data-consistency term (learned scalar step ``t``, initialized to -2.0)
    and then applies a learned proximal operator S (``prior_grad_res_net``).

    Args:
        ks_input: measured k-space; multiplied by ``mask`` before use.
        sensemap: sensitivity maps passed to ``tf_util.model_forward`` and
            ``tf_util.model_transpose``.
        num_grad_steps: number of unrolled iterations.
        resblock_num_features: ``num_features`` of each prior.
        resblock_num_blocks: ``num_blocks`` of each prior.
        is_training: forwarded as ``training`` to each prior.
        scope: name of the enclosing variable scope.
        mask_output: k-space multiplier applied after the final hard
            projection (only when ``do_hardproj``); ``None`` skips it.
        window: k-space weighting W; ``None`` means 1.
        do_hardproj: put the measured samples back into the output k-space
            at the sampled locations and recompute the image from it.
        num_summary_image: if > 0, emit an "iter/image" summary with the
            sum-of-squares image of every iteration concatenated along axis 2.
        mask: sampling mask; derived with ``tf_util.kspace_mask`` if None.
        verbose: print a description of the network being built.
        conv: stored in the module-global ``type_conv``.
        do_conjugate: stored in the module-global ``conjugate``.
        activation: forwarded to each prior.

    Returns:
        The complex output image tensor, named "output_image".
    """
    if window is None:
        window = 1
    summary_iter = None

    # NOTE(review): module-level state, presumably read by prior_grad_res_net
    # or its conv helpers — confirm. Every call overwrites the previous values.
    global type_conv
    type_conv = conv
    global conjugate
    conjugate = do_conjugate

    if verbose:
        print(
            "%s> Building FISTA unrolled network (%d steps)...."
            % (scope, num_grad_steps)
        )
        if sensemap is not None:
            print("%s>   Using sensitivity maps..." % scope)
    with tf.variable_scope(scope):
        if mask is None:
            mask = tf_util.kspace_mask(ks_input, dtype=tf.complex64)
        ks_input = mask * ks_input
        ks_0 = ks_input
        # x0 = A^T W b
        im_0 = tf_util.model_transpose(ks_0 * window, sensemap)
        im_0 = tf.identity(im_0, name="input_image")
        # To be updated
        ks_k = ks_0
        im_k = im_0

        for i_step in range(num_grad_steps):
            # One variable scope per iteration, so weights are not shared.
            iter_name = "iter_%02d" % i_step
            with tf.variable_scope(iter_name):
                # = S( x_k - 2 * t * (A^T W A x_k - A^T W b))
                # = S( x_k - 2 * t * (A^T W A x_k - x0))
                with tf.variable_scope("update"):
                    im_k_orig = im_k
                    # xk = A^T A x_k
                    ks_k = tf_util.model_forward(im_k, sensemap)
                    ks_k = mask * ks_k
                    im_k = tf_util.model_transpose(ks_k * window, sensemap)
                    # xk = A^T A x_k - A^T b
                    im_k = tf_util.complex_to_channels(im_k - im_0)
                    im_k_orig = tf_util.complex_to_channels(im_k_orig)
                    # Update step
                    # Learned step t; its initial value -2.0 gives the
                    # x_k - 2 * (A^T W A x_k - x0) step of the docstring.
                    t_update = tf.get_variable(
                        "t", dtype=tf.float32, initializer=tf.constant([-2.0])
                    )
                    im_k = im_k_orig + t_update * im_k

                with tf.variable_scope("prox"):
                    # The prior works on the channel form and outputs the
                    # same number of channels it receives.
                    num_channels_out = im_k.shape[-1]
                    im_k = prior_grad_res_net(
                        im_k,
                        training=is_training,
                        num_features=resblock_num_features,
                        num_blocks=resblock_num_blocks,
                        num_features_out=num_channels_out,
                        data_format="channels_last",
                        activation=activation
                    )
                    im_k = tf_util.channels_to_complex(im_k)

                im_k = tf.identity(im_k, name="image")
                if num_summary_image > 0:
                    with tf.name_scope("summary"):
                        # Sum-of-squares image of this iteration, appended
                        # to the running strip along axis 2.
                        tmp = tf_util.sumofsq(im_k, keep_dims=True)
                        if summary_iter is None:
                            summary_iter = tmp
                        else:
                            summary_iter = tf.concat(
                                (summary_iter, tmp), axis=2)
                        tf.summary.scalar("max/" + iter_name,
                                          tf.reduce_max(tmp))

        # k-space of the network's final image estimate.
        ks_k = tf_util.model_forward(im_k, sensemap)
        if do_hardproj:
            if verbose:
                print("%s>   Final hard data projection..." % scope)
            # Final data projection
            # Measured samples where mask is set, network prediction elsewhere
            # (this reading assumes a 0/1 mask).
            ks_k = mask * ks_0 + (1 - mask) * ks_k
            if mask_output is not None:
                ks_k = ks_k * mask_output
            im_k = tf_util.model_transpose(ks_k * window, sensemap)

        # Named output tensors; only the image is returned, the k-space is
        # reachable by its graph name "output_kspace".
        ks_k = tf.identity(ks_k, name="output_kspace")
        im_k = tf.identity(im_k, name="output_image")

    if summary_iter is not None:
        tf.summary.image("iter/image", summary_iter,
                         max_outputs=num_summary_image)

    return im_k
示例#7
0
def unroll_fista(
    ks_input,
    sensemap,
    scope="MRI",
    num_grad_steps=5,
    num_resblocks=4,
    num_features=64,
    kernel_size=None,
    is_training=True,
    mask_output=1,
    mask=None,
    window=None,
    do_hardproj=False,
    do_dense=False,
    do_separable=False,
    do_rnn=False,
    do_circular=True,
    batchnorm=False,
    leaky=False,
    fix_update=False,
    data_format="channels_first",
    verbose=False,
):
    """Create general unrolled network for MRI.

    x_{k+1} = S( x_k - 2 * t * A^T W (A x- b) )
            = S( x_k - 2 * t * (A^T W A x - A^T W b))

    Each of the ``num_grad_steps`` iterations takes a gradient step on the
    data-consistency term and then applies a learned proximal operator S
    (``prior_grad_res_net``).

    Args:
        ks_input: measured k-space; multiplied by ``mask`` before use.
        sensemap: sensitivity maps passed to ``tf_util.model_forward`` and
            ``tf_util.model_transpose``.
        scope: name of the enclosing variable scope.
        num_grad_steps: number of unrolled iterations.
        num_resblocks: ``num_blocks`` of each prior.
        num_features: ``num_features`` of each prior.
        kernel_size: three-entry convolution kernel size forwarded to the
            prior. Defaults to ``[3, 3, 5]``; a fresh list is built per call
            rather than sharing a mutable default argument.
        is_training: forwarded as ``training`` to each prior.
        mask_output: k-space multiplier applied after the final hard
            projection (only when ``do_hardproj``); ``None`` skips it.
        mask: sampling mask; derived with ``tf_util.kspace_mask`` if None.
        window: k-space weighting W; ``None`` means 1.
        do_hardproj: put the measured samples back into the output k-space
            at the sampled locations and recompute the image from it.
        do_dense: concatenate the extra outputs of all previous priors onto
            each prior's input (dense connections).
        do_separable, do_circular, batchnorm, leaky: forwarded to the prior.
        do_rnn: reuse one set of weights for every iteration.
        fix_update: use the constant step t = -2.0 instead of a learned one.
        data_format: "channels_first" moves the last (channel) axis to axis 1
            around the prior; otherwise the channel axis stays last.
        verbose: print a description of the network being built.

    Returns:
        The complex output image tensor, named "output_image".
    """
    if kernel_size is None:
        kernel_size = [3, 3, 5]
    if window is None:
        window = 1

    # The per-iteration GPU placement has been disabled, so the
    # device_lib.list_local_devices() call that only fed it is no longer made
    # on every graph build.

    if verbose:
        print("%s> Building FISTA unrolled network...." % scope)
        print("%s>   Num of gradient steps: %d" % (scope, num_grad_steps))
        print(
            "%s>   Prior: %d ResBlocks, %d features"
            % (scope, num_resblocks, num_features)
        )
        print("%s>   Kernel size: [%d x %d x %d]" % ((scope,) + tuple(kernel_size)))
        if do_rnn:
            print("%s>   Sharing weights across iterations..." % scope)
        if sensemap is not None:
            print("%s>   Using sensitivity maps..." % scope)
        if do_dense:
            print("%s>   Inserting dense connections..." % scope)
        if do_circular:
            print("%s>   Using circular padding..." % scope)
        if do_separable:
            print("%s>   Using depth-wise separable convolutions..." % scope)
        if not batchnorm:
            print("%s>   Turning off batch normalization..." % scope)

    with tf.variable_scope(scope):
        if mask is None:
            mask = tf_util.kspace_mask(ks_input, dtype=tf.complex64)
        ks_input = mask * ks_input
        ks_0 = ks_input
        # x0 = A^T W b
        im_0 = tf_util.model_transpose(ks_0 * window, sensemap)
        im_0 = tf.identity(im_0, name="input_image")
        # To be updated
        im_k = im_0
        # Accumulated extra prior outputs; only grows when do_dense.
        im_dense = None

        for i_step in range(num_grad_steps):
            # With do_rnn every iteration opens the same "iter" scope with
            # AUTO_REUSE, so the weights are shared across iterations.
            scope_name = "iter" if do_rnn else "iter_%02d" % i_step

            with tf.variable_scope(
                scope_name, reuse=(tf.AUTO_REUSE if do_rnn else False)
            ):
                with tf.variable_scope("update"):
                    # = S( x_k - 2 * t * (A^T W A x_k - A^T W b))
                    # = S( x_k - 2 * t * (A^T W A x_k - x0))
                    im_k_orig = im_k
                    # xk = A^T A x_k
                    ks_k = tf_util.model_forward(im_k, sensemap)
                    ks_k = mask * ks_k
                    im_k = tf_util.model_transpose(ks_k * window, sensemap)
                    # xk = A^T A x_k - A^T b
                    im_k = tf_util.complex_to_channels(im_k - im_0)
                    im_k_orig = tf_util.complex_to_channels(im_k_orig)
                    # Update step
                    if fix_update:
                        t_update = -2.0
                    else:
                        t_update = tf.get_variable(
                            "t", dtype=tf.float32, initializer=tf.constant([-2.0])
                        )
                    im_k = im_k_orig + t_update * im_k

                with tf.variable_scope("prox"):
                    # default is channels_last
                    num_channels_out = im_k.shape[-1]
                    if data_format == "channels_first":
                        im_k = tf.transpose(im_k, [0, 4, 1, 2, 3])

                    if im_dense is not None:
                        im_k = tf.concat([im_k, im_dense], axis=1)

                    im_k, im_dense_k = prior_grad_res_net(
                        im_k,
                        training=is_training,
                        num_features=num_features,
                        num_blocks=num_resblocks,
                        num_features_out=num_channels_out,
                        kernel_size=kernel_size,
                        data_format=data_format,
                        circular=do_circular,
                        separable=do_separable,
                        batchnorm=batchnorm,
                        leaky=leaky,
                    )

                    if do_dense:
                        if im_dense is not None:
                            im_dense = tf.concat([im_dense, im_dense_k], axis=1)
                        else:
                            im_dense = im_dense_k

                    if data_format == "channels_first":
                        im_k = tf.transpose(im_k, [0, 2, 3, 4, 1])

                    im_k = tf_util.channels_to_complex(im_k)

                im_k = tf.identity(im_k, name="image")

        ks_k = tf_util.model_forward(im_k, sensemap)
        if do_hardproj:
            if verbose:
                print("%s>   Final hard data projection..." % scope)
            # Measured samples where mask is set, network prediction elsewhere.
            ks_k = mask * ks_0 + (1 - mask) * ks_k
            if mask_output is not None:
                ks_k = ks_k * mask_output
            im_k = tf_util.model_transpose(ks_k * window, sensemap)

        # Only the image is returned; the k-space keeps its graph name.
        ks_k = tf.identity(ks_k, name="output_kspace")
        im_k = tf.identity(im_k, name="output_image")

    return im_k
示例#8
0
    def generator(self, ks_input, sensemap, reuse=False):
        """Build the generator that maps measured k-space to an image.

        Args:
            ks_input: measured k-space tensor.
            sensemap: sensitivity maps for the forward/adjoint MRI models.
            reuse: reuse the variables of an existing "generator" scope.

        Returns:
            The generator output in channel form: the unrolled network's image
            converted with ``tf_util.complex_to_channels``, or, for the 2D
            non-unrolled architecture, the tanh of a residual CNN's output.

        Raises:
            ValueError: for 3D data (``self.dims != 4``) with an architecture
                other than "unrolled", which is not implemented.
        """
        mask_example = tf_util.kspace_mask(ks_input, dtype=tf.complex64)
        with tf.variable_scope("generator") as scope:
            if reuse:
                scope.reuse_variables()
            # 2D data
            # batch, height, width, channels
            if self.dims == 4:
                if self.arch == "unrolled":
                    c_out = unrolled.unroll_fista(
                        ks_input,
                        sensemap,
                        num_grad_steps=self.iterations,
                        resblock_num_features=self.g_dim,
                        resblock_num_blocks=self.res_blocks,
                        is_training=True,
                        scope="MRI",
                        mask_output=1,
                        window=None,
                        do_hardproj=True,
                        num_summary_image=6,
                        mask=mask_example,
                        verbose=False,
                    )
                    c_out = tf_util.complex_to_channels(c_out)
                else:
                    # Plain residual CNN on the zero-filled image.
                    z = tf_util.model_transpose(ks_input, sensemap)
                    z = tf_util.complex_to_channels(z)
                    res_size = self.g_dim
                    kernel_size = 3
                    num_channels = 2
                    # could try tf.nn.tanh instead
                    # act = tf.nn.sigmoid
                    num_blocks = 5
                    c = tf.layers.conv2d(
                        z,
                        res_size,
                        kernel_size,
                        padding="same",
                        activation=tf.nn.relu,
                        use_bias=True,
                    )
                    for _ in range(num_blocks):
                        c = tf.layers.conv2d(
                            c,
                            res_size,
                            kernel_size,
                            padding="same",
                            activation=tf.nn.relu,
                            use_bias=True,
                        )
                        c1 = tf.layers.conv2d(
                            c,
                            res_size,
                            kernel_size,
                            padding="same",
                            activation=tf.nn.relu,
                            use_bias=True,
                        )
                        c = tf.add(c, c1)
                    c8 = tf.layers.conv2d(
                        c,
                        num_channels,
                        kernel_size,
                        padding="same",
                        activation=None,
                        use_bias=True,
                    )
                    # Global skip connection from the zero-filled input.
                    c_out = tf.add(c8, z)
                    c_out = tf.nn.tanh(c_out)

            # 3D data
            # batch, height, width, channels, time frames
            else:
                # Only the unrolled network exists for 3D data; fail clearly
                # instead of with an UnboundLocalError on c_out.
                if self.arch != "unrolled":
                    raise ValueError(
                        "Unsupported generator architecture %r for 3D data; "
                        "only 'unrolled' is implemented." % (self.arch,))
                c_out = unrolled_3d.unroll_fista(
                    ks_input,
                    sensemap,
                    num_grad_steps=self.iterations,
                    num_features=self.g_dim,
                    num_resblocks=self.res_blocks,
                    is_training=True,
                    scope="MRI",
                    mask_output=1,
                    window=None,
                    do_hardproj=True,
                    mask=mask_example,
                    verbose=False,
                    data_format="channels_last",
                    do_separable=self.do_separable,
                )
                c_out = tf_util.complex_to_channels(c_out)
            return c_out
示例#9
0
    def test(self):
        """Evaluate the generator on every test file and write metrics/exports.

        For each of ``self.input_files`` steps this evaluates the sampling
        acceleration, the network output ``self.im_out`` and the ground truth
        ``self.complex_truth`` in a single session run. For "knee" data it
        accumulates PSNR / NRMSE / SSIM via ``metrics.compute_all``; for "DCE"
        and "DCE_2D" data it writes DICOM / PNG / GIF exports. Mean and std of
        the metrics and of the acceleration are written to
        ``<log_dir>/output_metrics.txt``.
        """
        print("testing")

        # 18 time frames in each DICOM
        max_frame = self.max_frames
        frame = 1
        case = 1
        gif = []
        print("number of test cases", self.input_files)

        total_acc = []
        mask_input = tf_util.kspace_mask(self.ks, dtype=tf.complex64)
        numel = tf.cast(tf.size(mask_input), tf.float32)
        # Acceleration: k-space size over the summed mask magnitude (the
        # number of sampled points if the mask is 0/1).
        acc = numel / tf.reduce_sum(tf.abs(mask_input))

        output_psnr = []
        output_nrmse = []
        output_ssim = []

        # NOTE(review): the DCE / DCE_2D exports below use mag_input,
        # mag_output and mag_cs, but the code that produced them (zero-filled
        # input via tf_util.model_transpose(self.ks, self.sensemap) and the
        # BART CS reconstruction self.bart_cs(..., l1=...), with l1 = 0.02 for
        # knee, 0.05 for DCE and 0.07 for DCE_2D) is disabled, so those
        # branches raise NameError until it is restored.

        for step in range(self.input_files):
            print("test file #", step)

            # Evaluate the acceleration and the images in ONE run. If self.ks
            # and the outputs are fed by a dataset iterator (presumably —
            # verify), each separate sess.run consumes a new batch, which
            # would pair the acceleration with a different example.
            acc_run, output_image, complex_truth = self.sess.run(
                [acc, self.im_out, self.complex_truth])
            total_acc.append(acc_run)
            print(
                "total test acc:",
                np.round(np.mean(total_acc), decimals=2),
                np.round(np.std(total_acc), decimals=2),
            )

            # Compare with ==: ``is`` tests identity and is False for an equal
            # string that is a different object (e.g. one parsed from flags).
            if self.data_type == "knee":
                output_image = np.squeeze(output_image)
                truth_image = np.squeeze(complex_truth)

                psnr, nrmse, ssim = metrics.compute_all(truth_image,
                                                        output_image,
                                                        sos_axis=-1)

                output_psnr.append(psnr)
                output_nrmse.append(nrmse)
                output_ssim.append(ssim)

                print("output psnr, nrmse, ssim")
                print(
                    np.round(np.mean(output_psnr), decimals=2),
                    np.round(np.mean(output_nrmse), decimals=2),
                    np.round(np.mean(output_ssim), decimals=2),
                )

            if self.data_type == "DCE":
                gif = []
                for f in range(max_frame):
                    frame_input = mag_input[:, :, f]
                    frame_output = mag_output[:, :, f]
                    frame_cs = mag_cs[:, :, f]
                    rotated_input = np.rot90(frame_input, 1)
                    rotated_output = np.rot90(frame_output, 1)
                    rotated_cs = np.rot90(frame_cs, 1)

                    # Rescale the CS frame into the input frame's intensity
                    # range to help the CS brightness.
                    in_max = np.max(rotated_input)
                    in_min = np.min(rotated_input)
                    cs_max = np.max(rotated_cs)
                    cs_min = np.min(rotated_cs)
                    rotated_cs = (rotated_cs - cs_min) * (in_max - in_min) / (
                        cs_max - cs_min) + in_min
                    full_images = np.concatenate(
                        (rotated_input, rotated_output, rotated_cs), axis=1)
                    new_filename = (self.log_dir + "/dicoms/" + str(step) +
                                    "_f" + str(f) + ".dcm")
                    self.write_dicom(full_images, new_filename, step, f)

                    # add each frame to a gif
                    gif.append(full_images)

                print("Saving gif")
                gif_path = self.log_dir + "/gifs/slice_" + str(step) + ".gif"
                imageio.mimsave(gif_path, gif, duration=0.2)

            if self.data_type == "DCE_2D":
                rotated_input = np.rot90(mag_input, 1)
                rotated_output = np.rot90(mag_output, 1)
                rotated_cs = np.rot90(mag_cs, 1)

                full_images = np.concatenate(
                    (rotated_input, rotated_output, rotated_cs), axis=1)
                full_images = full_images * 10.0
                # Save as DICOM
                new_filename = (self.log_dir + "/dicoms/" + str(case) + "_f" +
                                str(frame) + ".dcm")
                dicom_output = np.squeeze(np.abs(full_images))
                self.write_dicom(dicom_output, new_filename, case, frame)

                # Save as PNG
                # NOTE(review): scipy.misc.imsave / imread were removed in
                # SciPy 1.2; this path needs an older SciPy.
                filename = (self.log_dir + "/images/case" + str(case) + "_f" +
                            str(frame) + ".png")
                scipy.misc.imsave(filename, full_images)

                if frame <= max_frame:
                    gif.append(full_images)
                    frame = frame + 1
                if frame > max_frame:
                    print("Max frame")
                    gif.append(full_images)
                    print("Saving gif")
                    gif_path = self.log_dir + \
                        "/gifs/new_gif" + str(case) + ".gif"
                    imageio.mimsave(gif_path, gif, "GIF", duration=0.2)

                    # return back to next case
                    frame = 1
                    case = case + 1
                    gif = []

                    # Create a gif
                    # NOTE(review): frame was just reset to 1 and case was
                    # incremented, so this always runs and globs the PNGs of
                    # the next case, which have not been written yet.
                    if frame <= max_frame:
                        mypath = self.log_dir + "/images/case" + str(case)
                        search_str = "*.png"
                        filenames = sorted(glob.glob(mypath + search_str))
                        # make and save gif
                        gif_path = self.log_dir + \
                            "/gifs/case_" + str(case) + ".gif"

                        images = []
                        for png_path in filenames:
                            image = scipy.misc.imread(png_path)
                            images.append(image)
                        imageio.mimsave(gif_path, images, duration=0.3)

        # Raw strings keep the literal "+\-" separator without relying on an
        # invalid escape sequence.
        report = (
            "output psnr = " + str(np.mean(output_psnr)) + r" +\- " +
            str(np.std(output_psnr)) + "\n" + "output nrmse = " +
            str(np.mean(output_nrmse)) + r" +\- " +
            str(np.std(output_nrmse)) + "\n" + "output ssim = " +
            str(np.mean(output_ssim)) + r" +\- " +
            str(np.std(output_ssim)) + "\n" + "test acc = " +
            str(np.mean(total_acc)) + r" +\-" + str(np.std(total_acc)))
        txt_path = os.path.join(self.log_dir, "output_metrics.txt")
        with open(txt_path, "w") as metrics_file:
            metrics_file.write(report)
示例#10
0
def unroll_fista(
    ks_input,
    sensemap,
    num_grad_steps=5,
    resblock_num_features=128,
    resblock_num_blocks=2,
    is_training=True,
    scope="MRI",
    mask_output=1,
    window=None,
    do_hardproj=True,
    num_summary_image=0,
    mask=None,
    verbose=False,
):
    """Create a general unrolled (FISTA-style) network for MRI reconstruction.

    Each of the ``num_grad_steps`` iterations takes a learned gradient step on
    the data-consistency term and then applies a learned proximal operator
    ``S`` (``prior_grad_res_net``)::

        x_{k+1} = S( x_k - 2 * t * A^T W (A x_k - b) )
                = S( x_k - 2 * t * (A^T W A x_k - A^T W b) )

    In the code, ``A`` is ``tf_util.model_forward`` followed by multiplication
    with ``mask``. ``A^T W`` is ``tf_util.model_transpose`` applied to
    ``k-space * window``. The step ``-2 * t`` is learned directly as the
    per-iteration variable ``t``, which is initialised to -2.0.

    Args:
        ks_input: Measured k-space. It is multiplied by ``mask`` before use.
        sensemap: Sensitivity maps passed to the forward/transpose model.
        num_grad_steps: Number of unrolled iterations.
        resblock_num_features: ``num_features`` for ``prior_grad_res_net``.
        resblock_num_blocks: ``num_blocks`` for ``prior_grad_res_net``.
        is_training: Forwarded to ``prior_grad_res_net`` as ``training``.
        scope: Variable scope wrapping the whole network.
        mask_output: Multiplied into the output k-space after the final hard
            projection (only when ``do_hardproj``). ``None`` skips it.
        window: K-space weighting ``W`` applied before every transpose.
            ``None`` means 1.
        do_hardproj: If True, replace the sampled k-space locations of the
            output with the measured data and recompute the image.
        num_summary_image: If > 0, a per-iteration sum-of-squares tensor is
            built. It is currently not emitted as a summary.
        mask: Sampling mask. It is derived from ``ks_input`` when ``None``.
        verbose: Print progress messages while building the graph.

    Returns:
        The reconstructed image tensor, named ``output_image``.
    """
    # All unrolled iterations are pinned to the last visible GPU. With no GPU
    # present, tf.device(None) keeps default placement instead of failing
    # with an IndexError on an empty device list.
    local_device_protos = device_lib.list_local_devices()
    device_list = [x.name for x in local_device_protos if x.device_type == "GPU"]
    cur_device = device_list[-1] if device_list else None

    if window is None:
        window = 1
    summary_iter = None

    if verbose:
        print(
            "%s> Building FISTA unrolled network (%d steps)...."
            % (scope, num_grad_steps)
        )
        if sensemap is not None:
            print("%s>   Using sensitivity maps..." % scope)

    with tf.variable_scope(scope):
        if mask is None:
            mask = tf_util.kspace_mask(ks_input, dtype=tf.complex64)
        ks_input = mask * ks_input
        ks_0 = ks_input
        # x0 = A^T W b
        im_0 = tf_util.model_transpose(ks_0 * window, sensemap)
        im_0 = tf.identity(im_0, name="input_image")
        # Current iterates, rebuilt by every unrolled step.
        ks_k = ks_0
        im_k = im_0

        for i_step in range(num_grad_steps):
            iter_name = "iter_%02d" % i_step

            with tf.device(cur_device), tf.variable_scope(iter_name):
                # = S( x_k - 2 * t * (A^T W A x_k - x0) )
                with tf.variable_scope("update"):
                    im_k_orig = im_k

                    # A^T W A x_k
                    ks_k = tf_util.model_forward(im_k, sensemap)
                    ks_k = mask * ks_k
                    im_k = tf_util.model_transpose(ks_k * window, sensemap)
                    # Gradient direction A^T W A x_k - A^T W b, split into
                    # real channels so it can be scaled and fed to the CNN.
                    im_k = tf_util.complex_to_channels(im_k - im_0)
                    im_k_orig = tf_util.complex_to_channels(im_k_orig)
                    # Learned step size (plays the role of -2t).
                    t_update = tf.get_variable(
                        "t", dtype=tf.float32, initializer=tf.constant([-2.0])
                    )
                    im_k = im_k_orig + t_update * im_k

                with tf.variable_scope("prox"):
                    num_channels_out = im_k.shape[-1]
                    # The prior network runs channels_first: transpose in
                    # from channels_last and back out afterwards.
                    im_k = tf.transpose(im_k, [0, 3, 1, 2])
                    im_k = prior_grad_res_net(
                        im_k,
                        training=is_training,
                        num_features=resblock_num_features,
                        num_blocks=resblock_num_blocks,
                        num_features_out=num_channels_out,
                        data_format="channels_first",
                    )
                    im_k = tf.transpose(im_k, [0, 2, 3, 1])

                    im_k = tf_util.channels_to_complex(im_k)

                im_k = tf.identity(im_k, name="image")
                if num_summary_image > 0:
                    # NOTE: summary_iter is built but not currently written
                    # out as a tf.summary.image.
                    with tf.name_scope("summary"):
                        tmp = tf_util.sumofsq(im_k, keep_dims=True)
                        if summary_iter is None:
                            summary_iter = tmp
                        else:
                            summary_iter = tf.concat((summary_iter, tmp), axis=2)

        ks_k = tf_util.model_forward(im_k, sensemap)
        if do_hardproj:
            if verbose:
                print("%s>   Final hard data projection..." % scope)
            # Final data projection: keep measured samples where mask is set,
            # network prediction elsewhere.
            ks_k = mask * ks_0 + (1 - mask) * ks_k
            if mask_output is not None:
                ks_k = ks_k * mask_output
            im_k = tf_util.model_transpose(ks_k * window, sensemap)

        ks_k = tf.identity(ks_k, name="output_kspace")
        im_k = tf.identity(im_k, name="output_image")

    return im_k
    def test(self):
        """Evaluate the model and a BART CS baseline on the test set.

        For each of ``self.input_files // 2`` test examples this method:

        * evaluates the acceleration factor of the sampling mask;
        * runs the network (``self.output_image``) and ``self.bart_cs`` and
          prints their running mean/std wall-clock time;
        * for ``data_type == "knee"``, accumulates PSNR / NRMSE / SSIM of the
          CS, network-output and input images against the truth;
        * for ``"DCE"`` / ``"DCE_2D"``, stores magnitude images into
          preallocated volumes. ``"DCE_2D"`` also writes one DICOM per frame.

        Afterwards the volumes are saved with ``np.save`` and ``cfl.write``.
        For ``"knee"``, summary metrics are printed and written to
        ``output_metrics.txt`` and ``cs_metrics.txt`` in ``self.log_dir``.

        Raises:
            ValueError: If ``self.data_type`` has no configured CS l1 weight.
        """
        print("testing")
        # Template DICOM dataset stored on self.ds (presumably used by
        # self.write_dicom -- confirm).
        dicom_filename = pydicom.data.get_testdata_files("MR_small.dcm")[0]
        self.ds = pydicom.dcmread(dicom_filename)

        # BART CS l1 regularization weight per data type. Strings are
        # compared by value; `is` depends on interning and can silently fail.
        l1_by_type = {"knee": 0.0035, "DCE": 0.01, "DCE_2D": 0.05}
        if self.data_type not in l1_by_type:
            raise ValueError(
                "unsupported data_type %r; expected one of %s"
                % (self.data_type, sorted(l1_by_type)))
        l1 = l1_by_type[self.data_type]

        # DCE_2D position: the iterator yields every time frame of one slice
        # before moving on to the next slice.
        frame = 0
        x_slice = 0
        print("number of test cases", self.input_files)

        # Acceleration factor = total k-space points / sampled points.
        total_acc = []
        mask_input = tf_util.kspace_mask(self.ks, dtype=tf.complex64)
        numel = tf.cast(tf.size(mask_input), tf.float32)
        acc = numel / tf.reduce_sum(tf.abs(mask_input))

        input_psnr, input_nrmse, input_ssim = [], [], []
        output_psnr, output_nrmse, output_ssim = [], [], []
        cs_psnr, cs_nrmse, cs_ssim = [], [], []

        # Magnitude volumes; the trailing (192, 80, 180) size is hard-coded.
        volume_shape = (self.max_frames, 192, 80, 180)
        input_volume = np.zeros(volume_shape)
        output_volume = np.zeros(volume_shape)
        cs_volume = np.zeros(volume_shape)

        model_time = []
        cs_time = []

        def rotate_image(img):
            """Return |img|, squeezed and reoriented for self.data_type."""
            img = np.squeeze(np.absolute(img))
            if self.data_type == "DCE":
                img = np.transpose(img, axes=(1, 0, 2))
                img = np.flip(img, axis=2)  # flip the time
            elif self.data_type == "DCE_2D":
                img = np.transpose(img, axes=(1, 0))
            return img

        for step in range(self.input_files // 2):
            print("test file #", step)
            total_acc.append(self.sess.run(acc))
            print(
                "total test acc:",
                np.round(np.mean(total_acc), decimals=2),
                np.round(np.std(total_acc), decimals=2),
            )

            model_start_time = time.time()
            (
                input_image,
                output_image,
                complex_truth,
                ks_run,
                sensemap_run,
            ) = self.sess.run([
                self.im_in,
                self.output_image,
                self.complex_truth,
                self.ks,
                self.sensemap,
            ])
            runtime = time.time() - model_start_time
            # Step 1 is excluded from timing statistics (presumably as a
            # warm-up; note the loop starts at 0 -- confirm intent).
            if step != 1:
                model_time.append(runtime)
            print("GAN: %s seconds" % np.mean(model_time),
                  "+/- %s" % np.std(model_time))

            cs_start_time = time.time()
            bart_test = self.bart_cs(ks_run, sensemap_run, l1=l1)
            runtime = time.time() - cs_start_time
            if step != 1:
                cs_time.append(runtime)
            print("CS: %s seconds" % np.mean(cs_time),
                  "+/- %s" % np.std(cs_time))

            if self.data_type == "knee":
                truth_image = np.squeeze(complex_truth)
                # Same evaluation order (cs, output, input) and log text for
                # each reconstruction.
                for name, image, psnrs, nrmses, ssims in (
                    ("cs", bart_test, cs_psnr, cs_nrmse, cs_ssim),
                    ("output", output_image,
                     output_psnr, output_nrmse, output_ssim),
                    ("input", input_image,
                     input_psnr, input_nrmse, input_ssim),
                ):
                    psnr, nrmse, ssim = metrics.compute_all(
                        truth_image, np.squeeze(image), sos_axis=-1)
                    psnrs.append(psnr)
                    nrmses.append(nrmse)
                    ssims.append(ssim)
                    print(name + " psnr, nrmse, ssim")
                    print(
                        np.round(np.mean(psnrs), decimals=2),
                        np.round(np.mean(nrmses), decimals=2),
                        np.round(np.mean(ssims), decimals=2),
                    )

            mag_input = rotate_image(input_image)
            mag_output = rotate_image(output_image)
            mag_cs = rotate_image(bart_test)

            # DCE indexes the volume's first axis by step; DCE_2D by
            # (frame, slice).
            if self.data_type == "DCE":
                input_volume[step, :, :, :] = mag_input
                output_volume[step, :, :, :] = mag_output
                cs_volume[step, :, :, :] = mag_cs
            elif self.data_type == "DCE_2D":
                input_volume[frame, x_slice, :, :] = mag_input
                output_volume[frame, x_slice, :, :] = mag_output
                cs_volume[frame, x_slice, :, :] = mag_cs

                # NOTE(review): the file is named "output_slice_*" but the
                # input magnitude is written -- confirm which is intended.
                new_filename = (self.log_dir + "/dicoms/" + "output_slice_" +
                                str(x_slice) + "_f" + str(frame) + ".dcm")
                self.write_dicom(mag_input, new_filename, x_slice, frame)

                # Advance to the next time frame; after the last frame wrap
                # around to frame 0 of the next slice.
                if frame == self.max_frames - 1:
                    frame = 0
                    x_slice += 1
                else:
                    frame += 1
                print("slice", x_slice, "time frame", frame)

        # Output prefix: search_str without its last 11 characters
        # (presumably a file-pattern suffix -- confirm).
        filename = os.path.join(self.log_dir,
                                os.path.basename(self.search_str[:-11]))
        input_dir = filename + "_input" + ".npy"
        output_dir = filename + "_output" + ".npy"
        cs_dir = filename + "_cs" + ".npy"
        print("saving numpy volumes")
        np.save(input_dir, input_volume)
        np.save(output_dir, output_volume)
        np.save(cs_dir, cs_volume)
        print(output_dir)
        print("saving cfl volumes")
        # NOTE(review): the cfl base names reuse the ".npy" paths above.
        cfl.write(input_dir, input_volume, "R")
        cfl.write(output_dir, output_volume, "R")
        cfl.write(cs_dir, cs_volume, "R")

        if self.data_type == "knee":

            def stat_line(label, values, sep=" +\\- "):
                # "label = mean<sep>std", the layout used by all reports.
                return (label + " = " + str(np.mean(values)) + sep +
                        str(np.std(values)))

            output_lines = [
                stat_line("output psnr", output_psnr),
                stat_line("output nrmse", output_nrmse),
                stat_line("output ssim", output_ssim),
            ]
            input_lines = [
                stat_line("input psnr", input_psnr),
                stat_line("input nrmse", input_nrmse),
                stat_line("input ssim", input_ssim),
            ]
            cs_lines = [
                stat_line("cs psnr", cs_psnr),
                stat_line("cs nrmse", cs_nrmse),
                stat_line("cs ssim", cs_ssim),
            ]
            acc_line = stat_line("test acc", total_acc, sep=" +\\-")

            print("\n".join(output_lines + [acc_line]))
            txt_path = os.path.join(self.log_dir, "output_metrics.txt")
            with open(txt_path, "w") as f:
                f.write("\n".join(output_lines + input_lines + [acc_line]))
            txt_path = os.path.join(self.log_dir, "cs_metrics.txt")
            with open(txt_path, "w") as f:
                f.write("\n".join(cs_lines))