Example #1
0
def _create_summary(sense_place, ks_place, im_out_place, im_truth_place):
    """Register TensorBoard image summaries for a reconstruction graph.

    Image summaries are added under four name scopes:

    * ``input-output-truth``: input / output / truth panels concatenated
      along axis 2, once for k-space (log of the ``tf_util.sumofsq`` value
      plus 1e-6) and once for image space.
    * ``truth``: real part of the truth image summed over its last axis.
    * ``mask``: ``tf_util.sumofsq`` of the mask returned by
      ``tf_util.kspace_mask`` for the input k-space.
    * ``sensemap``: magnitude of the sensitivity maps, taking index 0 of
      axis 3 and flattening the remaining trailing axes.

    Args:
        sense_place: Sensitivity-map tensor (sliced as rank 5 below).
        ks_place: Input k-space tensor.
        im_out_place: Reconstructed (output) image tensor.
        im_truth_place: Ground-truth image tensor.

    Returns:
        None. The summaries are registered through ``tf.summary.image``.
    """

    def _emit(tag, tensor):
        # Every image summary shares the same max_outputs setting.
        tf.summary.image(tag, tensor, max_outputs=FLAGS.num_summary_image)

    def _side_by_side(first, second, third):
        # Reduce each tensor with sumofsq, then tile the three along axis 2.
        panels = [
            tf_util.sumofsq(tensor, keep_dims=True)
            for tensor in (first, second, third)
        ]
        return tf.concat(panels, axis=2)

    # Derived tensors: image from input k-space, sampling mask, and the
    # k-space of both the output and the truth image.
    zero_filled = tf_util.model_transpose(ks_place, sense_place)
    sampling_mask = tf_util.kspace_mask(ks_place, dtype=tf.complex64)
    ks_recon = tf_util.model_forward(im_out_place, sense_place)
    ks_reference = tf_util.model_forward(im_truth_place, sense_place)

    with tf.name_scope("input-output-truth"):
        ks_panels = _side_by_side(ks_place, ks_recon, ks_reference)
        # Small offset keeps the log finite where k-space is zero.
        _emit("kspace", tf.log(ks_panels + 1e-6))
        _emit("image",
              _side_by_side(zero_filled, im_out_place, im_truth_place))

    with tf.name_scope("truth"):
        truth_sum = tf.reduce_sum(im_truth_place, axis=-1, keep_dims=True)
        _emit("image_real", tf.real(truth_sum))

    with tf.name_scope("mask"):
        _emit("mask", tf_util.sumofsq(sampling_mask, keep_dims=True))

    with tf.name_scope("sensemap"):
        # Keep only index 0 along axis 3, then move the last axis ahead of
        # axes 2 and 3 so the reshape below lays the maps out side by side.
        maps = tf.slice(tf.abs(sense_place), [0, 0, 0, 0, 0],
                        [-1, -1, -1, 1, -1])
        maps = tf.transpose(maps, [0, 1, 4, 2, 3])
        target_shape = [tf.shape(maps)[0], tf.shape(maps)[1], -1]
        maps = tf.reshape(maps, target_shape)
        _emit("image", tf.expand_dims(maps, axis=-1))
Example #2
0
    def create_summary(self):
        """Register TensorBoard summaries for training.

        For ``data_type == "knee"`` this adds an input/output/truth image
        panel, L1/L2 scalars between ``self.X_gen`` and ``self.z_truth``, and
        a magnitude image of ``self.Y_real``. For every data type it adds the
        discriminator/generator loss scalars, the gradient penalty, the mean
        fake/real discriminator logits and a histogram of the generator
        input.

        Side effects:
            Sets ``self.input_image``, ``self.d_loss_sum``,
            ``self.g_loss_sum``, ``self.gp_sum``, ``self.d_fake``,
            ``self.d_real``, ``self.z_sum``, ``self.d_sum``, ``self.g_sum``
            and ``self.train_sum``.
        """
        # note that ks is based on the input data not on the rotated data
        # NOTE(review): output_ks is never read below; the call is kept so
        # that graph construction is unchanged.
        output_ks = tf_util.model_forward(self.output_image, self.sensemap)

        # Input image to generator
        self.input_image = tf_util.model_transpose(self.ks, self.sensemap)

        # Compare by value: `is` on a string literal relies on interning and
        # silently fails when data_type comes from flags or a config file.
        if self.data_type == "knee":
            truth_image = tf_util.channels_to_complex(self.z_truth)

            # Rotate and flip the magnitude images before display.
            sum_input = tf.image.flip_up_down(
                tf.image.rot90(tf.abs(self.input_image)))
            sum_output = tf.image.flip_up_down(
                tf.image.rot90(tf.abs(self.output_image)))
            sum_truth = tf.image.flip_up_down(
                tf.image.rot90(tf.abs(truth_image)))

            train_out = tf.concat((sum_input, sum_output, sum_truth), axis=2)
            tf.summary.image("input-output-truth", train_out)

            # NOTE(review): mask_input is never read below; kept so that
            # graph construction is unchanged.
            mask_input = tf_util.kspace_mask(self.ks, dtype=tf.complex64)

            loss_l1 = tf.reduce_mean(tf.abs(self.X_gen - self.z_truth))
            loss_l2 = tf.reduce_mean(
                tf.square(tf.abs(self.X_gen - self.z_truth)))
            tf.summary.scalar("l1", loss_l1)
            tf.summary.scalar("l2", loss_l2)

            # to check supervised/unsupervised
            y_real = tf_util.channels_to_complex(self.Y_real)
            y_real = tf.image.flip_up_down(tf.image.rot90(tf.abs(y_real)))
            tf.summary.image("y_real/mag", y_real)

        # Plot losses
        self.d_loss_sum = tf.summary.scalar("Discriminator_loss", self.d_loss)
        self.g_loss_sum = tf.summary.scalar("Generator_loss", self.g_loss)
        self.gp_sum = tf.summary.scalar("Gradient_penalty",
                                        self.gradient_penalty)
        self.d_fake = tf.summary.scalar("subloss/D_fake",
                                        tf.reduce_mean(self.d_logits_fake))
        self.d_real = tf.summary.scalar("subloss/D_real",
                                        tf.reduce_mean(self.d_logits_real))
        self.z_sum = tf.summary.histogram(
            "z", tf_util.complex_to_channels(self.input_image))
        # Per-network merged summaries, plus one op covering everything
        # registered so far.
        self.d_sum = tf.summary.merge(
            [self.z_sum, self.d_loss_sum, self.d_fake, self.d_real])
        self.g_sum = tf.summary.merge([self.z_sum, self.g_loss_sum])
        self.train_sum = tf.summary.merge_all()
def unroll_fista(
    ks_input,
    sensemap,
    num_grad_steps=5,
    resblock_num_features=128,
    resblock_num_blocks=2,
    is_training=True,
    scope="MRI",
    mask_output=1,
    window=None,
    do_hardproj=True,
    num_summary_image=0,
    mask=None,
    verbose=False,
    conv="real",
    do_conjugate=False,
    activation="relu",
):
    """Build an unrolled proximal-gradient network for MRI reconstruction.

    Each of the ``num_grad_steps`` iterations performs
        x_{k+1} = S( x_k - 2 * t * A^T W (A x - b) )
                = S( x_k - 2 * t * (A^T W A x - A^T W b))
    where ``tf_util.model_forward`` plays the role of A,
    ``tf_util.model_transpose`` the role of A^T, ``window`` is W, ``t`` is a
    per-iteration variable initialised to -2.0 and S is
    ``prior_grad_res_net``.

    Args:
        ks_input: Measured k-space tensor (b).
        sensemap: Sensitivity maps passed to the forward/transpose model.
        num_grad_steps: Number of unrolled iterations.
        resblock_num_features: ``num_features`` for ``prior_grad_res_net``.
        resblock_num_blocks: ``num_blocks`` for ``prior_grad_res_net``.
        is_training: Forwarded as ``training`` to ``prior_grad_res_net``.
        scope: Name of the enclosing variable scope.
        mask_output: Optional multiplier applied to the projected k-space.
        window: Optional k-space weighting; ``None`` means 1.
        do_hardproj: If true, re-insert the measured samples at the end.
        num_summary_image: When > 0, emit per-iteration image summaries.
        mask: Sampling mask; computed from ``ks_input`` when ``None``.
        verbose: Print progress messages while building the graph.
        conv: Stored in the module-level global ``type_conv``.
        do_conjugate: Stored in the module-level global ``conjugate``.
        activation: Forwarded to ``prior_grad_res_net``.

    Returns:
        The reconstructed complex image tensor, named ``output_image``.
    """
    # These module-level globals are written here for code outside this
    # function to read.
    global type_conv, conjugate
    type_conv = conv
    conjugate = do_conjugate

    window = 1 if window is None else window
    montage = None  # per-iteration images, concatenated along axis 2

    if verbose:
        print(
            "%s> Building FISTA unrolled network (%d steps)...."
            % (scope, num_grad_steps)
        )
        if sensemap is not None:
            print("%s>   Using sensitivity maps..." % scope)

    with tf.variable_scope(scope):
        if mask is None:
            mask = tf_util.kspace_mask(ks_input, dtype=tf.complex64)
        ks_0 = mask * ks_input
        # x0 = A^T W b
        im_0 = tf.identity(
            tf_util.model_transpose(ks_0 * window, sensemap),
            name="input_image",
        )
        im_k = im_0

        for step in range(num_grad_steps):
            step_tag = "iter_%02d" % step
            with tf.variable_scope(step_tag):
                # = S( x_k - 2 * t * (A^T W A x_k - x0))
                with tf.variable_scope("update"):
                    previous = im_k
                    # A^T W A x_k
                    ks_k = mask * tf_util.model_forward(im_k, sensemap)
                    normal_img = tf_util.model_transpose(
                        ks_k * window, sensemap)
                    # A^T W A x_k - A^T W b, as real channels
                    residual = tf_util.complex_to_channels(normal_img - im_0)
                    previous = tf_util.complex_to_channels(previous)
                    # Learned step size for this iteration
                    step_size = tf.get_variable(
                        "t", dtype=tf.float32, initializer=tf.constant([-2.0])
                    )
                    im_k = previous + step_size * residual

                with tf.variable_scope("prox"):
                    out_channels = im_k.shape[-1]
                    im_k = tf_util.channels_to_complex(
                        prior_grad_res_net(
                            im_k,
                            training=is_training,
                            num_features=resblock_num_features,
                            num_blocks=resblock_num_blocks,
                            num_features_out=out_channels,
                            data_format="channels_last",
                            activation=activation
                        )
                    )

                im_k = tf.identity(im_k, name="image")
                if num_summary_image > 0:
                    with tf.name_scope("summary"):
                        panel = tf_util.sumofsq(im_k, keep_dims=True)
                        if montage is None:
                            montage = panel
                        else:
                            montage = tf.concat((montage, panel), axis=2)
                        tf.summary.scalar("max/" + step_tag,
                                          tf.reduce_max(panel))

        ks_k = tf_util.model_forward(im_k, sensemap)
        if do_hardproj:
            if verbose:
                print("%s>   Final hard data projection..." % scope)
            # Keep measured samples, fill the rest from the network output.
            ks_k = mask * ks_0 + (1 - mask) * ks_k
            if mask_output is not None:
                ks_k = ks_k * mask_output
            im_k = tf_util.model_transpose(ks_k * window, sensemap)

        ks_k = tf.identity(ks_k, name="output_kspace")
        im_k = tf.identity(im_k, name="output_image")

    if montage is not None:
        tf.summary.image("iter/image", montage,
                         max_outputs=num_summary_image)

    return im_k
Example #4
0
def unroll_fista(
    ks_input,
    sensemap,
    scope="MRI",
    num_grad_steps=5,
    num_resblocks=4,
    num_features=64,
    kernel_size=None,
    is_training=True,
    mask_output=1,
    mask=None,
    window=None,
    do_hardproj=False,
    do_dense=False,
    do_separable=False,
    do_rnn=False,
    do_circular=True,
    batchnorm=False,
    leaky=False,
    fix_update=False,
    data_format="channels_first",
    verbose=False,
):
    """Create general unrolled network for MRI.

    x_{k+1} = S( x_k - 2 * t * A^T W (A x- b) )
            = S( x_k - 2 * t * (A^T W A x - A^T W b))

    ``tf_util.model_forward`` plays the role of A, ``tf_util.model_transpose``
    the role of A^T, ``window`` is W and S is ``prior_grad_res_net``.

    Args:
        ks_input: Measured k-space tensor (b).
        sensemap: Sensitivity maps passed to the forward/transpose model.
        scope: Name of the enclosing variable scope.
        num_grad_steps: Number of unrolled iterations.
        num_resblocks: ``num_blocks`` for ``prior_grad_res_net``.
        num_features: ``num_features`` for ``prior_grad_res_net``.
        kernel_size: Three-element kernel size for ``prior_grad_res_net``.
            ``None`` (the default) means ``[3, 3, 5]``.
        is_training: Forwarded as ``training`` to ``prior_grad_res_net``.
        mask_output: Optional multiplier applied to the projected k-space.
        mask: Sampling mask; computed from ``ks_input`` when ``None``.
        window: Optional k-space weighting; ``None`` means 1.
        do_hardproj: If true, re-insert the measured samples at the end.
        do_dense: Feed the extra output of earlier prox networks into later
            ones.
        do_separable: Forwarded as ``separable`` to ``prior_grad_res_net``.
        do_rnn: Share variables across iterations (single "iter" scope).
        do_circular: Forwarded as ``circular`` to ``prior_grad_res_net``.
        batchnorm: Forwarded to ``prior_grad_res_net``.
        leaky: Forwarded to ``prior_grad_res_net``.
        fix_update: Use the constant step -2.0 instead of a variable.
        data_format: ``"channels_first"`` or channels-last layout for the
            prox network.
        verbose: Print the configuration while building the graph.

    Returns:
        The reconstructed complex image tensor, named ``output_image``.
    """
    # A list default would be shared between calls; build a fresh one.
    if kernel_size is None:
        kernel_size = [3, 3, 5]

    # The local GPU list used to be enumerated here, but the per-step device
    # placement that consumed it is disabled, so the query has been dropped.

    if window is None:
        window = 1
    # Per-iteration images keyed by iteration name. Not returned at present.
    summary_iter = {}

    if verbose:
        print("%s> Building FISTA unrolled network...." % scope)
        print("%s>   Num of gradient steps: %d" % (scope, num_grad_steps))
        print(
            "%s>   Prior: %d ResBlocks, %d features"
            % (scope, num_resblocks, num_features)
        )
        print("%s>   Kernel size: [%d x %d x %d]" % ((scope,) + tuple(kernel_size)))
        if do_rnn:
            print("%s>   Sharing weights across iterations..." % scope)
        if sensemap is not None:
            print("%s>   Using sensitivity maps..." % scope)
        if do_dense:
            print("%s>   Inserting dense connections..." % scope)
        if do_circular:
            print("%s>   Using circular padding..." % scope)
        if do_separable:
            print("%s>   Using depth-wise separable convolutions..." % scope)
        if not batchnorm:
            print("%s>   Turning off batch normalization..." % scope)

    with tf.variable_scope(scope):
        if mask is None:
            mask = tf_util.kspace_mask(ks_input, dtype=tf.complex64)
        ks_input = mask * ks_input
        ks_0 = ks_input
        # x0 = A^T W b
        im_0 = tf_util.model_transpose(ks_0 * window, sensemap)
        im_0 = tf.identity(im_0, name="input_image")
        # To be updated
        ks_k = ks_0
        im_k = im_0
        im_dense = None

        for i_step in range(num_grad_steps):
            iter_name = "iter_%02d" % i_step
            # With do_rnn every iteration reuses the same "iter" scope.
            scope_name = "iter" if do_rnn else iter_name

            with tf.variable_scope(
                scope_name, reuse=(tf.AUTO_REUSE if do_rnn else False)
            ):
                with tf.variable_scope("update"):
                    # = S( x_k - 2 * t * (A^T W A x_k - A^T W b))
                    # = S( x_k - 2 * t * (A^T W A x_k - x0))
                    im_k_orig = im_k
                    # xk = A^T A x_k
                    ks_k = tf_util.model_forward(im_k, sensemap)
                    ks_k = mask * ks_k
                    im_k = tf_util.model_transpose(ks_k * window, sensemap)
                    # xk = A^T A x_k - A^T b
                    im_k = tf_util.complex_to_channels(im_k - im_0)
                    im_k_orig = tf_util.complex_to_channels(im_k_orig)
                    # Update step
                    if fix_update:
                        t_update = -2.0
                    else:
                        t_update = tf.get_variable(
                            "t", dtype=tf.float32, initializer=tf.constant([-2.0])
                        )
                    im_k = im_k_orig + t_update * im_k

                with tf.variable_scope("prox"):
                    # default is channels_last
                    num_channels_out = im_k.shape[-1]
                    if data_format == "channels_first":
                        im_k = tf.transpose(im_k, [0, 4, 1, 2, 3])

                    # NOTE(review): axis=1 is the channel axis only for
                    # channels_first; confirm before combining do_dense with
                    # a channels-last data_format.
                    if im_dense is not None:
                        im_k = tf.concat([im_k, im_dense], axis=1)

                    im_k, im_dense_k = prior_grad_res_net(
                        im_k,
                        training=is_training,
                        num_features=num_features,
                        num_blocks=num_resblocks,
                        num_features_out=num_channels_out,
                        kernel_size=kernel_size,
                        data_format=data_format,
                        circular=do_circular,
                        separable=do_separable,
                        batchnorm=batchnorm,
                        leaky=leaky,
                    )

                    if do_dense:
                        # Accumulate the extra outputs for later iterations.
                        if im_dense is not None:
                            im_dense = tf.concat([im_dense, im_dense_k], axis=1)
                        else:
                            im_dense = im_dense_k

                    if data_format == "channels_first":
                        im_k = tf.transpose(im_k, [0, 2, 3, 4, 1])

                    im_k = tf_util.channels_to_complex(im_k)

                im_k = tf.identity(im_k, name="image")

                with tf.name_scope("summary"):
                    summary_iter[iter_name] = im_k

        ks_k = tf_util.model_forward(im_k, sensemap)
        if do_hardproj:
            if verbose:
                print("%s>   Final hard data projection..." % scope)
            # Keep measured samples, fill the rest from the network output.
            ks_k = mask * ks_0 + (1 - mask) * ks_k
            if mask_output is not None:
                ks_k = ks_k * mask_output
            im_k = tf_util.model_transpose(ks_k * window, sensemap)

        ks_k = tf.identity(ks_k, name="output_kspace")
        im_k = tf.identity(im_k, name="output_image")

    return im_k
Example #5
0
    def measure(self, X_gen, sensemap, real_mask):
        """Re-measure a generated image with a freshly drawn sampling mask.

        ``X_gen`` is converted to a complex image and passed through
        ``tf_util.model_forward``; each batch element is then multiplied by a
        random mask, the results are concatenated along axis 0 and mapped
        back with ``tf_util.model_transpose``.

        * ``data_type`` in ``("DCE", "DCE_2D", "mfast")``: the mask comes from
          ``mask.generate_perturbed2dvdkt`` with an acceleration drawn
          uniformly from [1, 6).
        * ``data_type == "knee"``: a mask is drawn from ``self.masks``,
          randomly flipped, given a fully sampled 20x20 calibration region,
          and the k-space is rescaled from its central 5x5 region.

        Args:
            X_gen: Generated image in real-channel form.
            sensemap: Sensitivity maps for the forward/transpose model.
            real_mask: Unused.

        Returns:
            The re-measured image in real-channel form.
        """
        name = "measure"
        random_seed = 0
        image = tf_util.channels_to_complex(X_gen)
        kspace = tf_util.model_forward(image, sensemap)

        # NOTE(review): total_kspace stays None when data_type matches
        # neither branch below, and is then passed to model_transpose as-is.
        total_kspace = None
        # Compare by value: `is` on string literals depends on interning.
        if self.data_type in ("DCE", "DCE_2D", "mfast"):
            print("DCE measure")
            # data dimensions
            shape_y = self.width
            shape_t = self.max_frames
            sim_partial_ky = 0.0
            accs = [1, 6]
            for i in range(self.batch_size):
                if self.dims == 4:
                    ks_x = kspace[i, :, :]
                else:
                    # 2D plus time
                    ks_x = kspace[i, :, :, :]

                # A new random acceleration op per batch element.
                rand_accel = (accs[1] - accs[0]) * \
                    tf.random_uniform([]) + accs[0]
                fn_inputs = [
                    shape_y,
                    shape_t,
                    rand_accel,
                    10,
                    2.0,
                    sim_partial_ky,
                ]  # ny, nt, accel, ncal, vd_degree
                mask_x = tf.py_func(mask.generate_perturbed2dvdkt, fn_inputs,
                                    tf.complex64)
                ks_x = ks_x * tf.reshape(mask_x, [1, shape_y, shape_t, 1])

                if total_kspace is not None:
                    total_kspace = tf.concat([total_kspace, ks_x], 0)
                else:
                    total_kspace = ks_x
        if self.data_type == "knee":
            print("knee")
            for i in range(self.batch_size):
                ks_x = kspace[i, :, :]
                print("ks_x", ks_x)
                # Randomly select mask
                mask_x = tf.random_shuffle(self.masks)
                print("self masks", mask_x)
                mask_x = tf.slice(mask_x, [0, 0, 0], [1, -1, -1])
                print("sliced masks", mask_x)
                # Augment sampling masks
                mask_x = tf.image.random_flip_up_down(mask_x, seed=random_seed)
                mask_x = tf.image.random_flip_left_right(mask_x,
                                                         seed=random_seed)
                # Transpose to store data as (kz, ky, channels)
                mask_x = tf.transpose(mask_x, [1, 2, 0])
                print("transposed mask", mask_x)
                ks_x = tf.image.flip_up_down(ks_x)
                # Initially set image size to be all the same
                ks_x = tf.image.resize_image_with_crop_or_pad(
                    ks_x, self.height, self.width)
                mask_x = tf.image.resize_image_with_crop_or_pad(
                    mask_x, self.height, self.width)
                print("resized mask", mask_x)
                shape_cal = 20
                if shape_cal > 0:
                    with tf.name_scope("CalibRegion"):
                        if self.verbose:
                            print("%s>  Including calib region (%d, %d)..." %
                                  (name, shape_cal, shape_cal))
                        # Force the central calibration block to be sampled.
                        mask_calib = tf.ones([shape_cal, shape_cal, 1],
                                             dtype=tf.complex64)
                        mask_calib = tf.image.resize_image_with_crop_or_pad(
                            mask_calib, self.height, self.width)
                        mask_x = mask_x * (1 - mask_calib) + mask_calib

                    # Drop locations where the k-space is (relatively) zero.
                    mask_recon = tf.abs(ks_x) / tf.reduce_max(tf.abs(ks_x))
                    mask_recon = tf.cast(mask_recon > 1e-7, dtype=tf.complex64)
                    mask_x = mask_x * mask_recon
                    print("mask x", mask_x)
                    # Assuming calibration region is fully sampled
                    shape_sc = 5
                    scale = tf.image.resize_image_with_crop_or_pad(
                        ks_x, shape_sc, shape_sc)
                    scale = tf.reduce_mean(tf.square(
                        tf.abs(scale))) * (shape_sc * shape_sc / 1e5)
                    scale = tf.cast(1.0 / tf.sqrt(scale), dtype=tf.complex64)
                    ks_x = ks_x * scale
                    # Masked input
                    ks_x = tf.multiply(ks_x, mask_x)
                if total_kspace is not None:
                    total_kspace = tf.concat([total_kspace, ks_x], 0)
                else:
                    total_kspace = ks_x

        x_measured = tf_util.model_transpose(total_kspace,
                                             sensemap,
                                             name="x_measured")
        x_measured = tf_util.complex_to_channels(x_measured)
        return x_measured
def unroll_fista(
    ks_input,
    sensemap,
    num_grad_steps=5,
    resblock_num_features=128,
    resblock_num_blocks=2,
    is_training=True,
    scope="MRI",
    mask_output=1,
    window=None,
    do_hardproj=True,
    num_summary_image=0,
    mask=None,
    verbose=False,
):
    """Create general unrolled network for MRI.

    x_{k+1} = S( x_k - 2 * t * A^T W (A x- b) )
            = S( x_k - 2 * t * (A^T W A x - A^T W b))

    ``tf_util.model_forward`` plays the role of A, ``tf_util.model_transpose``
    the role of A^T, ``window`` is W and S is ``prior_grad_res_net`` run in
    channels-first layout. Every iteration is placed on the last local GPU;
    when no GPU is present no device is assigned.

    Args:
        ks_input: Measured k-space tensor (b).
        sensemap: Sensitivity maps passed to the forward/transpose model.
        num_grad_steps: Number of unrolled iterations.
        resblock_num_features: ``num_features`` for ``prior_grad_res_net``.
        resblock_num_blocks: ``num_blocks`` for ``prior_grad_res_net``.
        is_training: Forwarded as ``training`` to ``prior_grad_res_net``.
        scope: Name of the enclosing variable scope.
        mask_output: Optional multiplier applied to the projected k-space.
        window: Optional k-space weighting; ``None`` means 1.
        do_hardproj: If true, re-insert the measured samples at the end.
        num_summary_image: When > 0, build the per-iteration image montage
            (it is not emitted as a summary at present).
        mask: Sampling mask; computed from ``ks_input`` when ``None``.
        verbose: Print progress messages while building the graph.

    Returns:
        The reconstructed complex image tensor, named ``output_image``.
    """
    # get list of GPU devices
    local_device_protos = device_lib.list_local_devices()
    device_list = [x.name for x in local_device_protos if x.device_type == "GPU"]
    # Pin all iterations to the last GPU. Indexing an empty list used to
    # raise IndexError on CPU-only hosts; tf.device(None) assigns no device.
    cur_device = device_list[-1] if device_list else None

    if window is None:
        window = 1
    summary_iter = None

    if verbose:
        print(
            "%s> Building FISTA unrolled network (%d steps)...."
            % (scope, num_grad_steps)
        )
        if sensemap is not None:
            print("%s>   Using sensitivity maps..." % scope)

    with tf.variable_scope(scope):
        if mask is None:
            mask = tf_util.kspace_mask(ks_input, dtype=tf.complex64)
        ks_input = mask * ks_input
        ks_0 = ks_input
        # x0 = A^T W b
        im_0 = tf_util.model_transpose(ks_0 * window, sensemap)
        im_0 = tf.identity(im_0, name="input_image")
        # To be updated
        ks_k = ks_0
        im_k = im_0

        for i_step in range(num_grad_steps):
            iter_name = "iter_%02d" % i_step

            with tf.device(cur_device), tf.variable_scope(iter_name):
                # = S( x_k - 2 * t * (A^T W A x_k - A^T W b))
                # = S( x_k - 2 * t * (A^T W A x_k - x0))
                with tf.variable_scope("update"):
                    im_k_orig = im_k

                    # xk = A^T A x_k
                    ks_k = tf_util.model_forward(im_k, sensemap)
                    ks_k = mask * ks_k
                    im_k = tf_util.model_transpose(ks_k * window, sensemap)
                    # xk = A^T A x_k - A^T b
                    im_k = tf_util.complex_to_channels(im_k - im_0)
                    im_k_orig = tf_util.complex_to_channels(im_k_orig)
                    # Update step
                    t_update = tf.get_variable(
                        "t", dtype=tf.float32, initializer=tf.constant([-2.0])
                    )
                    im_k = im_k_orig + t_update * im_k

                with tf.variable_scope("prox"):
                    num_channels_out = im_k.shape[-1]

                    # Transpose channels_last to channels_first
                    im_k = tf.transpose(im_k, [0, 3, 1, 2])
                    im_k = prior_grad_res_net(
                        im_k,
                        training=is_training,
                        num_features=resblock_num_features,
                        num_blocks=resblock_num_blocks,
                        num_features_out=num_channels_out,
                        data_format="channels_first",
                    )
                    # Back to channels_last
                    im_k = tf.transpose(im_k, [0, 2, 3, 1])

                    im_k = tf_util.channels_to_complex(im_k)

                im_k = tf.identity(im_k, name="image")
                if num_summary_image > 0:
                    with tf.name_scope("summary"):
                        tmp = tf_util.sumofsq(im_k, keep_dims=True)
                        if summary_iter is None:
                            summary_iter = tmp
                        else:
                            summary_iter = tf.concat((summary_iter, tmp), axis=2)

        ks_k = tf_util.model_forward(im_k, sensemap)
        if do_hardproj:
            if verbose:
                print("%s>   Final hard data projection..." % scope)
            # Final data projection
            ks_k = mask * ks_0 + (1 - mask) * ks_k
            if mask_output is not None:
                ks_k = ks_k * mask_output
            im_k = tf_util.model_transpose(ks_k * window, sensemap)

        ks_k = tf.identity(ks_k, name="output_kspace")
        im_k = tf.identity(im_k, name="output_image")

    return im_k
    def measure(self, X_gen, sensemap, real_mask):
        """Re-measure a generated image with a freshly drawn sampling mask.

        ``X_gen`` is converted to a complex image and passed through
        ``tf_util.model_forward``. A mask is drawn from ``self.masks``,
        randomly flipped and given a fully sampled 20x20 calibration region;
        the k-space is rescaled from its central 5x5 region, masked, and
        mapped back with ``tf_util.model_transpose``.

        * ``data_type == "DCE"``: the batch axis is squeezed out and a
          different mask is drawn for each of ``self.max_frames`` frames;
          frames are re-stacked along axis -2.
        * ``data_type == "DCE_2D"``: the batch axis is squeezed out and one
          mask is applied to the whole k-space.
        * ``data_type == "knee"``: one mask per batch element, additionally
          restricted to locations where the k-space is non-zero; elements
          are concatenated along axis 0.

        Args:
            X_gen: Generated image in real-channel form.
            sensemap: Sensitivity maps for the forward/transpose model.
            real_mask: Unused.

        Returns:
            The re-measured image in real-channel form.

        Side effects:
            The DCE branches set ``self.applied_mask`` (in the "DCE" branch,
            to the mask of the last frame).
        """
        name = "measure"
        random_seed = 0
        image = tf_util.channels_to_complex(X_gen)
        kspace = tf_util.model_forward(image, sensemap)
        # NOTE(review): total_kspace stays None when data_type matches none
        # of the branches below, and is then passed to model_transpose as-is.
        total_kspace = None
        # Compare by value: `is` on string literals depends on interning.
        if self.data_type == "DCE":
            # remove batch dimension
            kspace = tf.squeeze(kspace, axis=0)
            # different mask for each frame
            for f in range(self.max_frames):
                ks_x = kspace[:, :, f, :]
                # Randomly select mask
                mask_x = tf.random_shuffle(self.masks)
                mask_x = mask_x[0, :, :]
                mask_x = tf.expand_dims(mask_x, axis=0)
                # Augment sampling masks
                mask_x = tf.image.random_flip_up_down(mask_x, seed=random_seed)
                mask_x = tf.image.random_flip_left_right(mask_x,
                                                         seed=random_seed)
                # Transpose to store data as (kz, ky, channels)
                mask_x = tf.transpose(mask_x, [1, 2, 0])
                ks_x = tf.image.flip_up_down(ks_x)
                # Initially set image size to be all the same
                ks_x = tf.image.resize_image_with_crop_or_pad(
                    ks_x, self.height, self.width)
                mask_x = tf.image.resize_image_with_crop_or_pad(
                    mask_x, self.height, self.width)
                shape_cal = 20
                if shape_cal > 0:
                    with tf.name_scope("CalibRegion"):
                        if self.verbose:
                            print("%s>  Including calib region (%d, %d)..." %
                                  (name, shape_cal, shape_cal))
                        # Force the central calibration block to be sampled.
                        mask_calib = tf.ones([shape_cal, shape_cal, 1],
                                             dtype=tf.complex64)
                        mask_calib = tf.image.resize_image_with_crop_or_pad(
                            mask_calib, self.height, self.width)
                        mask_x = mask_x * (1 - mask_calib) + mask_calib

                    # Assuming calibration region is fully sampled
                    shape_sc = 5
                    scale = tf.image.resize_image_with_crop_or_pad(
                        ks_x, shape_sc, shape_sc)
                    scale = tf.reduce_mean(tf.square(
                        tf.abs(scale))) * (shape_sc * shape_sc / 1e5)
                    scale = tf.cast(1.0 / tf.sqrt(scale), dtype=tf.complex64)
                    ks_x = ks_x * scale
                    ks_x = tf.multiply(ks_x, mask_x)
                    # Re-insert the frame axis so frames can be stacked.
                    ks_x = tf.expand_dims(ks_x, axis=-2)
                    self.applied_mask = tf.expand_dims(mask_x, axis=0)
                if total_kspace is not None:
                    total_kspace = tf.concat([total_kspace, ks_x], axis=-2)
                else:
                    total_kspace = ks_x
            # Restore the batch dimension.
            total_kspace = tf.expand_dims(total_kspace, axis=0)
        if self.data_type == "DCE_2D":
            # remove batch dimension
            kspace = tf.squeeze(kspace, axis=0)
            ks_x = kspace
            # Randomly select mask
            mask_x = tf.random_shuffle(self.masks)
            mask_x = mask_x[0, :, :]
            mask_x = tf.expand_dims(mask_x, axis=0)
            # Augment sampling masks
            mask_x = tf.image.random_flip_up_down(mask_x, seed=random_seed)
            mask_x = tf.image.random_flip_left_right(mask_x, seed=random_seed)
            # Transpose to store data as (kz, ky, channels)
            mask_x = tf.transpose(mask_x, [1, 2, 0])
            ks_x = tf.image.flip_up_down(ks_x)
            # Initially set image size to be all the same
            ks_x = tf.image.resize_image_with_crop_or_pad(
                ks_x, self.height, self.width)
            mask_x = tf.image.resize_image_with_crop_or_pad(
                mask_x, self.height, self.width)
            shape_cal = 20
            if shape_cal > 0:
                with tf.name_scope("CalibRegion"):
                    if self.verbose:
                        print("%s>  Including calib region (%d, %d)..." %
                              (name, shape_cal, shape_cal))
                    # Force the central calibration block to be sampled.
                    mask_calib = tf.ones([shape_cal, shape_cal, 1],
                                         dtype=tf.complex64)
                    mask_calib = tf.image.resize_image_with_crop_or_pad(
                        mask_calib, self.height, self.width)
                    mask_x = mask_x * (1 - mask_calib) + mask_calib

                # Assuming calibration region is fully sampled
                shape_sc = 5
                scale = tf.image.resize_image_with_crop_or_pad(
                    ks_x, shape_sc, shape_sc)
                scale = tf.reduce_mean(tf.square(
                    tf.abs(scale))) * (shape_sc * shape_sc / 1e5)
                scale = tf.cast(1.0 / tf.sqrt(scale), dtype=tf.complex64)
                ks_x = ks_x * scale
                ks_x = tf.multiply(ks_x, mask_x)
                self.applied_mask = tf.expand_dims(mask_x, axis=0)
                total_kspace = ks_x
            # Restore the batch dimension.
            total_kspace = tf.expand_dims(total_kspace, axis=0)

        if self.data_type == "knee":
            for i in range(self.batch_size):
                ks_x = kspace[i, :, :]
                # Randomly select mask
                mask_x = tf.random_shuffle(self.masks)
                mask_x = tf.slice(mask_x, [0, 0, 0], [1, -1, -1])
                # Augment sampling masks
                mask_x = tf.image.random_flip_up_down(mask_x, seed=random_seed)
                mask_x = tf.image.random_flip_left_right(mask_x,
                                                         seed=random_seed)
                # Transpose to store data as (kz, ky, channels)
                mask_x = tf.transpose(mask_x, [1, 2, 0])
                ks_x = tf.image.flip_up_down(ks_x)
                # Initially set image size to be all the same
                ks_x = tf.image.resize_image_with_crop_or_pad(
                    ks_x, self.height, self.width)
                mask_x = tf.image.resize_image_with_crop_or_pad(
                    mask_x, self.height, self.width)

                shape_cal = 20
                if shape_cal > 0:
                    with tf.name_scope("CalibRegion"):
                        if self.verbose:
                            print("%s>  Including calib region (%d, %d)..." %
                                  (name, shape_cal, shape_cal))
                        # Force the central calibration block to be sampled.
                        mask_calib = tf.ones([shape_cal, shape_cal, 1],
                                             dtype=tf.complex64)
                        mask_calib = tf.image.resize_image_with_crop_or_pad(
                            mask_calib, self.height, self.width)
                        mask_x = mask_x * (1 - mask_calib) + mask_calib

                    # Drop locations where the k-space is (relatively) zero.
                    mask_recon = tf.abs(ks_x) / tf.reduce_max(tf.abs(ks_x))
                    mask_recon = tf.cast(mask_recon > 1e-7, dtype=tf.complex64)
                    mask_x = mask_x * mask_recon

                    # Assuming calibration region is fully sampled
                    shape_sc = 5
                    scale = tf.image.resize_image_with_crop_or_pad(
                        ks_x, shape_sc, shape_sc)
                    scale = tf.reduce_mean(tf.square(
                        tf.abs(scale))) * (shape_sc * shape_sc / 1e5)
                    scale = tf.cast(1.0 / tf.sqrt(scale), dtype=tf.complex64)
                    ks_x = ks_x * scale
                    # Masked input
                    ks_x = tf.multiply(ks_x, mask_x)

                if total_kspace is not None:
                    total_kspace = tf.concat([total_kspace, ks_x], 0)
                else:
                    total_kspace = ks_x

        x_measured = tf_util.model_transpose(total_kspace,
                                             sensemap,
                                             name="x_measured")
        x_measured = tf_util.complex_to_channels(x_measured)
        return x_measured