def get_fastmri_slices_from_dir(image_dir,
                                batch_size,
                                fs=False,
                                corruption_frac=1.0):
    """Load a random batch of k-space slices from one FastMRI volume.

    A volume is chosen uniformly at random among the files in image_dir whose
    acquisition matches the requested fat-suppression setting. batch_size
    slices are then drawn from it independently (with replacement).

    Args:
      image_dir(str): Directory containing 3D MRI volumes; each volume is stored as a '.h5' file in the FastMRI format with a 'kspace' dataset of shape (num_slices, h, w)
      batch_size(int): Number of slices in the batch
      fs(Boolean): Whether to read images with fat suppression (True) or without (False)
      corruption_frac(float): Forwarded to get_undersampling_mask to build each slice's mask

    Returns:
      kspace_fulls: complex array of shape (batch_size, h, w) with the fully-sampled k-space slices
      kspace_full_crops: complex array of shape (batch_size, n_crop, n_crop) with each slice passed through models.crop_320
      kspace_masks: complex array of shape (batch_size, h, w) with mask * fully-sampled k-space
      masks: float array of shape (batch_size, h, w) with the undersampling masks

    Raises:
      ValueError: If no volume in image_dir matches the requested fs setting.

    """
    # Side length of the cropped k-space. Assumes models.crop_320 yields a
    # 320x320 (re/im) slice, as its name suggests. TODO confirm.
    n_crop = 320

    image_names = os.listdir(image_dir)

    # Visit the volumes in a random order and use the first matching one.
    # This picks uniformly among the matching volumes, like a retry loop would,
    # but it terminates (with an error) when no volume matches.
    for img_i in np.random.permutation(len(image_names)):
        with h5py.File(os.path.join(image_dir, image_names[img_i]), "r") as f:
            if ('CORPDFS' in f.attrs['acquisition']) != fs:
                continue

            kspace = f['kspace']
            n_slices = kspace.shape[0]
            slice_shape = (kspace.shape[1], kspace.shape[2])

            kspace_fulls = np.empty((batch_size,) + slice_shape, dtype=complex)
            kspace_masks = np.empty((batch_size,) + slice_shape, dtype=complex)
            masks = np.empty((batch_size,) + slice_shape)
            kspace_full_crops = np.empty((batch_size, n_crop, n_crop),
                                         dtype=complex)

            for i in range(batch_size):
                slice_i = np.random.randint(0, n_slices)

                # Read the slice once and reuse it for every derived array.
                kspace_full = kspace[slice_i, :, :]

                mask = get_undersampling_mask(kspace_full.shape,
                                              corruption_frac)

                kspace_full_crop = models.crop_320(
                    utils.split_reim(np.expand_dims(kspace_full, 0))[0, :, :, :])
                kspace_full_crop = utils.join_reim(
                    np.expand_dims(kspace_full_crop, 0))[0, :, :]

                kspace_fulls[i, ...] = kspace_full
                kspace_full_crops[i, ...] = kspace_full_crop
                kspace_masks[i, ...] = mask * kspace_full
                masks[i, ...] = mask

            # Every array is an in-memory NumPy copy, so returning from inside
            # the `with` (which closes the file) is safe.
            return kspace_fulls, kspace_full_crops, kspace_masks, masks

    raise ValueError(
        f"No volume in {image_dir!r} matches the requested fs={fs} setting")
def generate_noisy_data(images,
                        input_domain,
                        output_domain,
                        corruption_frac,
                        batch_size=10):
    """Generator that yields batches of noisy input and correct output data.

    For corrupted inputs, Gaussian noise with standard deviation
    corruption_frac is added independently to every real/imaginary component
    of the k-space data.

    Args:
      images(float): Numpy array of input images, of shape (num_images,n,n)
      input_domain(str): The domain of the network input; 'FREQ' or 'IMAGE'
      output_domain(str): The domain of the network output; 'FREQ' or 'IMAGE'
      corruption_frac(float): Standard deviation of the noise added to each real/imaginary k-space component
      batch_size(int, optional): Number of input-output pairs in each batch (Default value = 10)

    Yields:
      tuple: (inputs, outputs), the corrupted input data and the ground-truth output data, both numpy arrays of shape (batch_size,n,n,2). Images are sampled with replacement. A domain other than 'FREQ'/'IMAGE' leaves the corresponding array empty.

    """
    images = utils.split_reim(images)
    spectra = utils.convert_to_frequency_domain(images)
    n = images.shape[1]

    while True:
        batch_inds = np.random.randint(0, images.shape[0], batch_size)

        inputs = np.empty((0, n, n, 2))
        outputs = np.empty((0, n, n, 2))

        for j in batch_inds:
            true_img = np.expand_dims(images[j, :, :, :], 0)
            true_k = np.expand_dims(spectra[j, :, :, :], 0)

            # np.random.normal's `scale` is a standard deviation, not a variance.
            noise = np.random.normal(loc=0.0,
                                     scale=corruption_frac,
                                     size=true_k.shape)

            corrupt_k = true_k + noise
            corrupt_img = utils.convert_to_image_domain(corrupt_k)

            if input_domain == 'FREQ':
                inputs = np.append(inputs, corrupt_k, axis=0)
            elif input_domain == 'IMAGE':
                inputs = np.append(inputs, corrupt_img, axis=0)

            if output_domain == 'FREQ':
                outputs = np.append(outputs, true_k, axis=0)
            elif output_domain == 'IMAGE':
                outputs = np.append(outputs, true_img, axis=0)

        yield (inputs, outputs)
# Example #3
# 0
    def test_split_reim(self):
        """split_reim stacks real parts then imaginary parts along a new last axis."""
        # Distinct real and imaginary values, so a swapped channel order is caught.
        compl_array = np.full((1, 3, 3), 1 + 2j)
        real_array = np.full((1, 3, 3), 1.0)
        imag_array = np.full((1, 3, 3), 2.0)
        split_array = np.stack([real_array, imag_array], 3)

        # assert_array_equal raises AssertionError on mismatch and returns None
        # otherwise, so it needs no unittest assertion wrapper.
        np.testing.assert_array_equal(utils.split_reim(compl_array),
                                      split_array)
# Example #4
# 0
def get_mri_spectra_stats(images):
    """Return the pixelwise mean and (floored) standard deviation of MRI spectra.

    The images are split into real/imaginary channels and transformed to
    k-space. Statistics are then taken across the image axis.

    Args:
      images(float): Numpy array of shape (num_images, n, n) containing input images.

    Returns:
      float: Numpy array of shape (1, n, n, 2) with the pixelwise mean of the real and imaginary parts of the Fourier spectra
      float: Numpy array of shape (1, n, n, 2) with the pixelwise standard deviation of the real and imaginary parts of the Fourier spectra, floored at 1

    """
    spectra = utils.convert_to_frequency_domain(utils.split_reim(images))

    mean = np.mean(spectra, axis=0, keepdims=True)
    # Every std value below 1 is raised to 1. Presumably this keeps later
    # division by the std well-conditioned; verify against callers.
    std = np.maximum(np.std(spectra, axis=0, keepdims=True), 1)

    return mean, std
# Example #5
# 0
def generate_undersampled_data(images,
                               input_domain,
                               output_domain,
                               corruption_frac,
                               enforce_dc,
                               batch_size=10):
    """Generator that yields batches of undersampled input and correct output data.

    For each corrupted input, exactly int(n * corruption_frac) distinct k-space
    lines are chosen at random and set to zero. These are the rows along axis 1
    of each spectrum. Each input/output pair is divided by the same factor: the
    maximum value of the corrupted image's real/imaginary channels.

    Args:
      images(float): Numpy array of input images, of shape (num_images, n, n)
      input_domain(str): The domain of the network input; 'FREQ' or 'IMAGE'
      output_domain(str): The domain of the network output; 'FREQ' or 'IMAGE'
      corruption_frac(float): Fraction of k-space lines to zero
      enforce_dc(bool): If True, yield the sampling masks together with the inputs
      batch_size(int, optional): Number of input-output pairs in each batch (Default value = 10)

    Yields:
      ((inputs, masks), outputs) if enforce_dc else (inputs, outputs). inputs and outputs are numpy arrays of shape (batch_size,n,n,2), and images are sampled with replacement. masks is only filled when input_domain == 'FREQ'; otherwise it is empty. A domain other than 'FREQ'/'IMAGE' leaves the corresponding array empty.

    """
    images = utils.split_reim(images)
    spectra = utils.convert_to_frequency_domain(images)
    n = images.shape[1]

    while True:
        batch_inds = np.random.randint(0, images.shape[0], batch_size)

        inputs = np.empty((0, n, n, 2))
        outputs = np.empty((0, n, n, 2))
        masks = np.empty((0, n, n, 2))

        for j in batch_inds:
            true_img = np.expand_dims(images[j, :, :, :], 0)
            true_k = np.expand_dims(spectra[j, :, :, :], 0)

            num_points = int(n * corruption_frac)
            coord_list = np.random.choice(n, num_points, replace=False)

            # Zero the selected lines in a single vectorized assignment.
            mask = np.ones(true_k.shape)
            mask[0, coord_list] = 0
            corrupt_k = true_k.copy()
            corrupt_k[0, coord_list] = 0
            corrupt_img = utils.convert_to_image_domain(corrupt_k)

            # NOTE(review): this is the max over signed real/imaginary values,
            # not over magnitudes. The image_dir-based generators normalize by
            # |image| instead. Kept as-is to preserve the training scaling.
            nf = np.max(corrupt_img)

            if input_domain == 'FREQ':
                inputs = np.append(inputs, corrupt_k / nf, axis=0)
                masks = np.append(masks, mask, axis=0)
            elif input_domain == 'IMAGE':
                inputs = np.append(inputs, corrupt_img / nf, axis=0)

            if output_domain == 'FREQ':
                outputs = np.append(outputs, true_k / nf, axis=0)
            elif output_domain == 'IMAGE':
                outputs = np.append(outputs, true_img / nf, axis=0)

        if enforce_dc:
            yield ((inputs, masks), outputs)
        else:
            yield (inputs, outputs)
# Example #6
# 0
def generate_undersampled_motion_data(images,
                                      input_domain,
                                      output_domain,
                                      us_frac,
                                      mot_frac,
                                      max_htrans,
                                      max_vtrans,
                                      max_rot,
                                      batch_size=10):
    """Generator that yields batches of motion-corrupted, undersampled input and correct output data.

    For corrupted inputs, a random number of k-space lines (below
    mot_frac * n) is selected at which motion occurs. A random
    translation/rotation is generated for each line and applied with
    motion.add_rotation_and_translations. The motion-corrupted k-space is then
    multiplied by an undersampling mask built in the style of fastMRI's
    subsampling: a fully-sampled central band plus randomly kept lines.

    Args:
      images(float): Numpy array of input images, of shape (num_images,n,n)
      input_domain(str): The domain of the network input; 'FREQ' or 'IMAGE'
      output_domain(str): The domain of the network output; 'FREQ' or 'IMAGE'
      us_frac(float): Undersampling fraction. The acceleration factor is int(1 / (1 - us_frac)), e.g. 0.75 -> 4x. A value of exactly 1 is special-cased to an all-ones (fully-sampled) mask.
      mot_frac(float): Upper bound on the fraction of lines at which motion occurs; the actual count is int(u * mot_frac * n) for u uniform in [0, 1).
      max_htrans(float): Maximum fraction of image width for a translation.
      max_vtrans(float): Maximum fraction of image height for a translation.
      max_rot(float): Maximum fraction of 360 degrees for a rotation.
      batch_size(int, optional): Number of input-output pairs in each batch (Default value = 10)

    Yields:
      tuple: (inputs, outputs), the corrupted input data and the correct output data, both numpy arrays of shape (batch_size,n,n,2). A domain other than 'FREQ'/'IMAGE' leaves the corresponding array empty.

    """
    def get_us_motion_mask(arr_shape, us_frac):
        """Return the transpose of a fastMRI-style column-undersampling mask built with shape arr_shape.

        Based on https://github.com/facebookresearch/fastMRI/blob/master/common/subsample.py.
        """
        num_cols = arr_shape[1]
        if (us_frac != 1):
            acceleration = int(1 / (1 - us_frac))
            center_fraction = (1 - us_frac) * 0.08 / 0.25

            # Create the mask
            num_low_freqs = int(round(num_cols * center_fraction))
            prob = (num_cols / acceleration - num_low_freqs) / \
                (num_cols - num_low_freqs)
            mask_inds = np.random.uniform(size=num_cols) < prob
            pad = (num_cols - num_low_freqs + 1) // 2
            mask_inds[pad:pad + num_low_freqs] = True

            mask = np.zeros(arr_shape)
            mask[:, mask_inds] = 1

            return np.fft.ifftshift(mask).T

        else:
            return (np.ones(arr_shape)).T

    # Copy of the complex-valued images. Slices of this copy, not of the
    # caller's array, are handed to the motion simulator.
    complex_images = images.copy()
    num_images = images.shape[0]
    n = images.shape[1]

    # Motion bounds in pixels / degrees; these do not change between samples.
    max_htrans_pix = n * max_htrans
    max_vtrans_pix = n * max_vtrans
    max_rot_deg = 360 * max_rot

    while True:
        batch_inds = np.random.randint(0, num_images, batch_size)

        inputs = np.empty((0, n, n, 2))
        outputs = np.empty((0, n, n, 2))

        for j in batch_inds:
            num_points = int(np.random.random() * mot_frac * n)
            coord_list = np.sort(
                np.random.choice(n, size=num_points, replace=False))

            # Uniform random motion parameters in [-max, max) for each line.
            num_pix = np.zeros((num_points, 2))
            num_pix[:, 0] = np.random.random(num_points) * (
                2 * max_htrans_pix) - max_htrans_pix
            num_pix[:, 1] = np.random.random(num_points) * (
                2 * max_vtrans_pix) - max_vtrans_pix
            angle = np.random.random(num_points) * \
                (2 * max_rot_deg) - max_rot_deg

            corrupt_k, true_k = motion.add_rotation_and_translations(
                complex_images[j, :, :], coord_list, angle, num_pix)
            true_k = utils.split_reim(np.expand_dims(true_k, 0))
            true_img = utils.convert_to_image_domain(true_k)
            corrupt_k = utils.split_reim(np.expand_dims(corrupt_k, 0))

            # Apply the same undersampling mask to both real/imaginary channels.
            mask = get_us_motion_mask(true_img.shape[1:3], us_frac)
            r_mask = np.expand_dims(
                np.repeat(mask[:, :, np.newaxis], 2, axis=-1), 0)
            corrupt_k *= r_mask
            corrupt_img = utils.convert_to_image_domain(corrupt_k)

            # NOTE(review): max over signed real/imaginary values, not magnitude.
            nf = np.max(corrupt_img)

            if input_domain == 'FREQ':
                inputs = np.append(inputs, corrupt_k / nf, axis=0)
            elif input_domain == 'IMAGE':
                inputs = np.append(inputs, corrupt_img / nf, axis=0)

            if output_domain == 'FREQ':
                outputs = np.append(outputs, true_k / nf, axis=0)
            elif output_domain == 'IMAGE':
                outputs = np.append(outputs, true_img / nf, axis=0)

        yield (inputs, outputs)
# Example #7
# 0
def generate_motion_data(images,
                         input_domain,
                         output_domain,
                         mot_frac,
                         max_htrans,
                         max_vtrans,
                         max_rot,
                         batch_size=10):
    """Generator that yields batches of motion-corrupted input and correct output data.

    For corrupted inputs, int(mot_frac * n) k-space lines are selected at which
    motion occurs. A random translation/rotation is generated for each line and
    applied with motion.add_rotation_and_translations. Each input/output pair is
    divided by the maximum of the corrupted image's real/imaginary channels.

    Args:
      images(float): Numpy array of input images, of shape (num_images,n,n)
      input_domain(str): The domain of the network input; 'FREQ' or 'IMAGE'
      output_domain(str): The domain of the network output; 'FREQ' or 'IMAGE'
      mot_frac(float): Fraction of lines at which motion occurs.
      max_htrans(float): Maximum fraction of image width for a translation.
      max_vtrans(float): Maximum fraction of image height for a translation.
      max_rot(float): Maximum fraction of 360 degrees for a rotation.
      batch_size(int, optional): Number of input-output pairs in each batch (Default value = 10)

    Yields:
      tuple: (inputs, outputs), the corrupted input data and the correct output data, both numpy arrays of shape (batch_size,n,n,2). A domain other than 'FREQ'/'IMAGE' leaves the corresponding array empty.

    """
    # Keep the complex-valued images for the motion simulator; `images` is
    # rebound to the real/imaginary split version below.
    complex_images = images.copy()
    images = utils.split_reim(images)
    n = images.shape[1]

    # Motion bounds in pixels / degrees; these do not change between samples.
    num_points = int(mot_frac * n)
    max_htrans_pix = n * max_htrans
    max_vtrans_pix = n * max_vtrans
    max_rot_deg = 360 * max_rot

    while True:
        batch_inds = np.random.randint(0, images.shape[0], batch_size)

        inputs = np.empty((0, n, n, 2))
        outputs = np.empty((0, n, n, 2))

        for j in batch_inds:
            # NOTE(review): true_img comes from the input images, while true_k
            # comes from the motion helper. generate_undersampled_motion_data
            # derives true_img from true_k instead; confirm they agree.
            true_img = np.expand_dims(images[j, :, :, :], 0)

            coord_list = np.sort(
                np.random.choice(n, size=num_points, replace=False))

            # Uniform random motion parameters in [-max, max) for each line.
            num_pix = np.zeros((num_points, 2))
            num_pix[:, 0] = np.random.random(num_points) * (
                2 * max_htrans_pix) - max_htrans_pix
            num_pix[:, 1] = np.random.random(num_points) * (
                2 * max_vtrans_pix) - max_vtrans_pix
            angle = np.random.random(num_points) * \
                (2 * max_rot_deg) - max_rot_deg

            corrupt_k, true_k = motion.add_rotation_and_translations(
                complex_images[j, :, :], coord_list, angle, num_pix)
            corrupt_k = utils.split_reim(np.expand_dims(corrupt_k, 0))
            true_k = utils.split_reim(np.expand_dims(true_k, 0))

            corrupt_img = utils.convert_to_image_domain(corrupt_k)

            # NOTE(review): max over signed real/imaginary values, not magnitude.
            nf = np.max(corrupt_img)

            if input_domain == 'FREQ':
                inputs = np.append(inputs, corrupt_k / nf, axis=0)
            elif input_domain == 'IMAGE':
                inputs = np.append(inputs, corrupt_img / nf, axis=0)

            if output_domain == 'FREQ':
                outputs = np.append(outputs, true_k / nf, axis=0)
            elif output_domain == 'IMAGE':
                outputs = np.append(outputs, true_img / nf, axis=0)

        yield (inputs, outputs)
def generate_undersampled_data(image_dir,
                               input_domain,
                               output_domain,
                               corruption_frac,
                               enforce_dc,
                               fs=False,
                               batch_size=16):
    """Generator that yields batches of undersampled input and correct output data.

    For corrupted inputs, each k-space slice read from image_dir is multiplied
    by a mask from get_undersampling_mask and scaled by 500. Each example is
    then divided by the maximum absolute value of its corrupted image.

    NOTE(review): the get_fastmri_slices_from_dir defined alongside this code
    returns four arrays, but this function unpacks two (images, kspace).
    Against that version the call raises ValueError. Confirm which helper
    version this generator targets.
    NOTE(review): another generate_undersampled_data with a different signature
    appears after this one. If both live in the same module, the later
    definition shadows this one.

    Args:
      image_dir(str): Directory containing 3D MRI volumes
      input_domain(str): The domain of the network input; 'FREQ' or 'IMAGE'
      output_domain(str): The domain of the network output; 'FREQ' or 'IMAGE'
      corruption_frac(float): Probability with which to zero a line in k-space (forwarded to get_undersampling_mask)
      enforce_dc(bool): If True, yield ((inputs, masks), outputs) instead of (inputs, outputs)
      fs(Bool, optional): Whether to read images with fat suppression (True) or without (False)
      batch_size(int, optional): Number of input-output pairs in each batch

    Returns:
      inputs: Tuple of corrupted input data and ground truth output data, both numpy arrays of shape (batch_size,n,n,2).

    """

    while True:
        images, kspace = get_fastmri_slices_from_dir(image_dir, batch_size, fs)

        images = utils.split_reim(images)
        spectra = utils.convert_to_frequency_domain(images)

        n = images.shape[1]

        inputs = np.empty((0, n, n, 2))
        outputs = np.empty((0, n, n, 2))
        masks = np.empty((0, n, n, 2))

        for j in range(batch_size):
            true_img = np.expand_dims(images[j, :, :, :], 0)
            true_k = np.expand_dims(spectra[j, :, :, :], 0)
            mask = get_undersampling_mask(kspace[j, :, :].shape,
                                          corruption_frac)
            # Same mask duplicated across the real/imaginary channels.
            r_mask = np.expand_dims(
                np.repeat(mask[:, :, np.newaxis], 2, axis=-1), 0)

            # NOTE(review): num_points and coord_list are not used below, but
            # np.random.choice still advances the global NumPy RNG.
            num_points = int(n * corruption_frac)
            coord_list = np.random.choice(n, num_points, replace=False)

            corrupt_k = kspace[j, :, :] * mask

            # Bring majority of values to 0-1 range.
            # NOTE(review): only the corrupted k-space is scaled by 500;
            # true_k / true_img are not, before both are divided by nf.
            corrupt_k = utils.split_reim(np.expand_dims(corrupt_k, 0)) * 500

            corrupt_img = utils.convert_to_image_domain(corrupt_k)

            # Per-example normalization factor: max |value| of the corrupted image.
            nf = np.max(np.abs(corrupt_img))

            if (input_domain == 'FREQ'):
                inputs = np.append(inputs, corrupt_k / nf, axis=0)
                masks = np.append(masks, r_mask, axis=0)
            elif (input_domain == 'IMAGE'):
                inputs = np.append(inputs, corrupt_img / nf, axis=0)

            if (output_domain == 'FREQ'):
                outputs = np.append(outputs, true_k / nf, axis=0)
            elif (output_domain == 'IMAGE'):
                outputs = np.append(outputs, true_img / nf, axis=0)

        if (enforce_dc):
            yield ((inputs, masks), outputs)
        else:
            yield (inputs, outputs)
def generate_undersampled_data(image_dir,
                               input_domain,
                               output_domain,
                               corruption_frac,
                               enforce_dc,
                               batch_size,
                               fs=False):
    """Generator that yields batches of undersampled input and correct output data.

    Each batch comes from get_fastmri_slices_from_dir, which provides:
      - fully-sampled k-space,
      - its cropped version,
      - the undersampled (masked) k-space,
      - the masks.
    Every example is divided by its own normalization factor: the 95th
    percentile of the magnitude of its undersampled image.

    Args:
      image_dir(str): Directory containing 3D MRI volumes
      input_domain(str): The domain of the network input; 'FREQ' or 'IMAGE'
      output_domain(str): The domain of the network output; 'FREQ' or 'IMAGE'
      corruption_frac(float): Undersampling parameter forwarded to get_undersampling_mask (probability with which to zero a line in k-space)
      enforce_dc(bool): If True, the mask is yielded as an additional model input
      batch_size(int): Number of input-output pairs in each batch
      fs(Bool, optional): Whether to read images with fat suppression (True) or without (False)

    Yields:
      If enforce_dc: ({'input': inp, 'mask': mask}, {'output': output, 'output_crop': output_crop}).
      Otherwise: (inp, {'output': output, 'output_crop': output_crop}).
      All arrays carry a trailing real/imaginary axis of size 2.

    Raises:
      ValueError: If input_domain or output_domain is not 'FREQ' or 'IMAGE'.

    """
    valid_domains = ('FREQ', 'IMAGE')
    if input_domain not in valid_domains:
        raise ValueError(
            f"input_domain must be one of {valid_domains}, got {input_domain!r}")
    if output_domain not in valid_domains:
        raise ValueError(
            f"output_domain must be one of {valid_domains}, got {output_domain!r}")

    while True:
        kspace_full, kspace_full_crop, kspace_mask, mask = get_fastmri_slices_from_dir(
            image_dir, batch_size, fs, corruption_frac=corruption_frac)

        # Duplicate the mask across the real/imaginary channels.
        mask = np.repeat(np.expand_dims(mask, -1), 2, axis=-1)

        kspace_full = utils.split_reim(kspace_full)
        kspace_full_crop = utils.split_reim(kspace_full_crop)
        kspace_mask = utils.split_reim(kspace_mask)

        # Bring majority of values to 0-1 range: one factor per example, the
        # 95th percentile of |undersampled image|.
        corrupt_img = utils.convert_to_image_domain(kspace_mask)
        nf = np.percentile(np.abs(corrupt_img), 95, axis=(1, 2, 3))
        nf = nf[:, np.newaxis, np.newaxis, np.newaxis]

        if input_domain == 'FREQ':
            inp = kspace_mask / nf
        else:
            inp = corrupt_img / nf

        if output_domain == 'FREQ':
            output = kspace_full / nf
            output_crop = kspace_full_crop / nf
        else:
            # The ground truth must come from the fully-sampled k-space, not
            # from the undersampled one.
            output = utils.convert_to_image_domain(kspace_full) / nf
            output_crop = utils.convert_to_image_domain(kspace_full_crop) / nf

        if enforce_dc:
            yield ({
                'input': inp,
                'mask': mask
            }, {
                'output': output,
                'output_crop': output_crop
            })
        else:
            yield (inp, {'output': output, 'output_crop': output_crop})
# Example #10
# 0
def generate_uniform_undersampled_data(images,
                                       input_domain,
                                       output_domain,
                                       corruption_frac,
                                       enforce_dc,
                                       batch_size=10):
    """Generator that yields batches of uniformly undersampled input and correct output data.

    Corrupted inputs use a regular line-zeroing pattern, with
    s = int(1 / (1 - corruption_frac)). The first s - 1 lines of every group of
    s are zeroed, except in a central band of roughly int(32 / s) percent of the
    lines, which is kept fully sampled. Each input/output pair is divided by the
    maximum value of the corrupted image's real/imaginary channels.

    Args:
      images(float): Numpy array of input images, of shape (num_images, n, n)
      input_domain(str): The domain of the network input; 'FREQ', 'IMAGE', 'MAG', 'COMPLEX_K' or 'COMPLEX_I'
      output_domain(str): The domain of the network output; same options as input_domain
      corruption_frac(float): Approximate fraction of k-space lines to zero; must be below 1
      enforce_dc(bool): If True, yield the sampling masks together with the inputs
      batch_size(int, optional): Number of input-output pairs in each batch (Default value = 10)

    Yields:
      ((inputs, masks), outputs) if enforce_dc else (inputs, outputs).
      - 'FREQ'/'IMAGE' arrays have shape (batch_size, n, n, 2); 'MAG' and 'COMPLEX_*' arrays have a single trailing channel.
      - masks is filled for 'FREQ' inputs (shape (batch_size, n, n, 2)) and for 'COMPLEX_*' inputs (shape (batch_size, n, n)); otherwise it is empty.

    """
    images = utils.split_reim(images)
    spectra = utils.convert_to_frequency_domain(images)
    n = images.shape[1]

    # 'MAG' and 'COMPLEX_*' domains carry one channel; 'FREQ'/'IMAGE' carry
    # separate real and imaginary channels.
    n_ch_in = 1 if (input_domain == 'MAG' or 'COMPLEX' in input_domain) else 2
    n_ch_out = 1 if (output_domain == 'MAG' or 'COMPLEX' in output_domain) else 2

    # The zeroed-line pattern depends only on n and corruption_frac, so build
    # it once. Lists are used so an empty range on either side of the
    # fully-sampled central band cannot break np.concatenate.
    s = int(1 / (1 - corruption_frac))
    arc_lines = int(32 / s)
    arc_low = int((50 - int(arc_lines / 2)) * n / 100)
    arc_high = int((50 + int(arc_lines / 2)) * n / 100)
    zeroed_lines = [start + offset
                    for start in range(0, arc_low, s)
                    for offset in range(s - 1)]
    zeroed_lines += [start + offset
                     for start in range(arc_high, n - (s - 1), s)
                     for offset in range(s - 1)]
    coord_list = np.array(zeroed_lines, dtype=int)

    while True:
        batch_inds = np.random.randint(0, images.shape[0], batch_size)

        inputs = np.empty((0, n, n, n_ch_in))
        outputs = np.empty((0, n, n, n_ch_out))
        if 'COMPLEX' in input_domain:
            masks = np.empty((0, n, n))
        else:
            masks = np.empty((0, n, n, n_ch_in))

        for idx in batch_inds:
            true_img = np.expand_dims(images[idx, :, :, :], 0)
            true_k = np.expand_dims(spectra[idx, :, :, :], 0)

            mask = np.ones(true_k.shape)
            mask[0, coord_list] = 0
            corrupt_k = true_k.copy()
            corrupt_k[0, coord_list] = 0
            corrupt_img = utils.convert_to_image_domain(corrupt_k)

            # NOTE(review): max over signed real/imaginary values, not magnitude.
            nf = np.max(corrupt_img)

            if input_domain == 'FREQ':
                inputs = np.append(inputs, corrupt_k / nf, axis=0)
                masks = np.append(masks, mask, axis=0)
            elif input_domain == 'IMAGE':
                inputs = np.append(inputs, corrupt_img / nf, axis=0)
            elif input_domain == 'MAG':
                mag_img = np.expand_dims(
                    np.abs(utils.join_reim(corrupt_img)), -1)
                inputs = np.append(inputs, mag_img / nf, axis=0)
            elif input_domain == 'COMPLEX_K':
                complex_k = np.expand_dims(utils.join_reim(corrupt_k), -1)
                inputs = np.append(inputs, complex_k / nf, axis=0)
            elif input_domain == 'COMPLEX_I':
                complex_img = np.expand_dims(utils.join_reim(corrupt_img), -1)
                inputs = np.append(inputs, complex_img / nf, axis=0)

            if output_domain == 'FREQ':
                outputs = np.append(outputs, true_k / nf, axis=0)
            elif output_domain == 'IMAGE':
                outputs = np.append(outputs, true_img / nf, axis=0)
            elif output_domain == 'MAG':
                mag_true = np.expand_dims(np.abs(utils.join_reim(true_img)),
                                          -1)
                outputs = np.append(outputs, mag_true / nf, axis=0)
            elif output_domain == 'COMPLEX_K':
                # Ground truth is appended to `outputs`, never to `inputs`.
                complex_true_k = np.expand_dims(utils.join_reim(true_k), -1)
                outputs = np.append(outputs, complex_true_k / nf, axis=0)
            elif output_domain == 'COMPLEX_I':
                complex_true_img = np.expand_dims(utils.join_reim(true_img),
                                                  -1)
                outputs = np.append(outputs, complex_true_img / nf, axis=0)

            if 'COMPLEX' in input_domain:
                # Complex-valued inputs take a single-channel mask.
                masks = np.append(masks, mask[:, :, :, 0], axis=0)

        if enforce_dc:
            yield ((inputs, masks), outputs)
        else:
            yield (inputs, outputs)