示例#1
0
def plot_img_from_k_diff(k1, k2, i, vmin=-1, vmax=1, ax=None):
    """Plot the difference of two k-space arrays after mapping both to image space.

    Each input is passed through ``utils.convert_to_image_domain`` and the two
    results are handed to ``plot_img_diff`` together with the remaining
    arguments.

    Args:
      k1: First k-space array.
      k2: Second k-space array.
      i: Forwarded unchanged to ``plot_img_diff`` (presumably the index of the
        example to display -- confirm against ``plot_img_diff``).
      vmin(optional): Forwarded to ``plot_img_diff`` (Default value = -1)
      vmax(optional): Forwarded to ``plot_img_diff`` (Default value = 1)
      ax(optional): Forwarded to ``plot_img_diff`` (Default value = None)

    """
    first_img = utils.convert_to_image_domain(k1)
    second_img = utils.convert_to_image_domain(k2)
    plot_img_diff(first_img, second_img, i, vmin=vmin, vmax=vmax, ax=ax)
def generate_noisy_data(images,
                        input_domain,
                        output_domain,
                        corruption_frac,
                        batch_size=10):
    """Generator that yields batches of noisy input and correct output data.

    For corrupted inputs, Gaussian noise with zero mean and standard deviation
    corruption_frac is added to every element of the split real/imaginary
    k-space array.

    Args:
      images(float): Numpy array of input images, of shape (num_images,n,n)
      input_domain(str): The domain of the network input; 'FREQ' or 'IMAGE'
      output_domain(str): The domain of the network output; 'FREQ' or 'IMAGE'
      corruption_frac(float): Standard deviation of the noise to be added
      batch_size(int, optional): Number of input-output pairs in each batch (Default value = 10)

    Yields:
      Tuple (inputs, outputs) of corrupted input data and ground truth output
      data, both numpy arrays of shape (batch_size,n,n,2). If a domain string
      is not one of the recognised values, the corresponding array is left
      empty with shape (0,n,n,2).

    """
    images = utils.split_reim(images)
    spectra = utils.convert_to_frequency_domain(images)

    n = images.shape[1]
    # Seeding each list with an empty float array keeps the result shape and
    # dtype well defined even when nothing is appended.
    empty_batch = np.empty((0, n, n, 2))

    while True:
        batch_inds = np.random.randint(0, images.shape[0], batch_size)

        input_parts = [empty_batch]
        output_parts = [empty_batch]

        for j in batch_inds:
            true_img = np.expand_dims(images[j, :, :, :], 0)
            true_k = np.expand_dims(spectra[j, :, :, :], 0)

            noise = np.random.normal(loc=0.0,
                                     scale=corruption_frac,
                                     size=true_k.shape)

            corrupt_k = true_k + noise
            corrupt_img = utils.convert_to_image_domain(corrupt_k)

            if input_domain == 'FREQ':
                input_parts.append(corrupt_k)
            elif input_domain == 'IMAGE':
                input_parts.append(corrupt_img)

            if output_domain == 'FREQ':
                output_parts.append(true_k)
            elif output_domain == 'IMAGE':
                output_parts.append(true_img)

        # A single concatenation per batch avoids re-copying the growing
        # arrays on every sample.
        inputs = np.concatenate(input_parts, axis=0)
        outputs = np.concatenate(output_parts, axis=0)

        yield (inputs, outputs)
示例#3
0
    def test_convert_to_image_domain(self):
        """Check that image-domain conversion followed by frequency-domain conversion reproduces the input exactly."""
        # All-ones real and imaginary parts, stacked on a trailing channel axis.
        ones_plane = np.full((1, 3, 3), 1.0)
        split_array = np.stack([ones_plane, ones_plane], 3)

        round_trip = utils.convert_to_frequency_domain(
            utils.convert_to_image_domain(split_array))

        # assert_array_equal returns None on success and raises on mismatch.
        comparison = np.testing.assert_array_equal(split_array, round_trip)
        self.assertIsNone(comparison)
# Command-line interface: the experiment directory name and the path to the
# HDF5 file holding the test set.
parser = argparse.ArgumentParser(description='Run inference on test set.')
parser.add_argument('exp_str', help='Number of experiment directory')
parser.add_argument('test_file', help='Location of h5 containing the test set')

args = parser.parse_args()

# NOTE(review): exp_str is hard-coded here and is not taken from args.exp_str,
# so the `exp_str == 'undersample'` branch below never runs -- confirm intent.
exp_str = 'undersample_motion'
test_data_path = args.test_file
exp_dir = filepaths.TRAIN_DIR + args.exp_str

# Read the network inputs from the 'inputs' dataset and derive image-domain
# and magnitude versions of them.
# NOTE(review): the file handle `f` is never closed in the visible code.
f = h5py.File(test_data_path, 'r')
freq_inputs = f['inputs'][()]
if (exp_str == 'undersample'):
    masks = f['masks'][()]
img_inputs = utils.convert_to_image_domain(freq_inputs)
# Magnitude with a trailing singleton axis (presumably join_reim rebuilds a
# complex array from the split real/imaginary channels -- confirm in utils).
mag_inputs = np.expand_dims(np.abs(utils.join_reim(img_inputs)), -1)

# Same processing for the ground-truth labels stored under 'outputs'.
freq_label = f['outputs'][()]
img_label = utils.convert_to_image_domain(freq_label)
mag_label = np.expand_dims(np.abs(utils.join_reim(img_label)), -1)

# For every entry in the experiment directory, load its config and model and
# restore the weights of the checkpoint reported as best.
for exp in os.listdir(exp_dir):
    model_dir = os.path.join(exp_dir, exp)
    config_path = os.path.join(model_dir,
                               visualization_lib.get_config_path(model_dir))
    config, model = visualization_lib.load_model(config_path)

    best_ckpt = visualization_lib.get_best_ckpt(model_dir)
    # Checkpoint files are named 'cp-<zero-padded 4-digit number>.ckpt'.
    model.load_weights(
        os.path.join(model_dir, 'cp-' + str(best_ckpt).zfill(4) + '.ckpt'))
示例#5
0
def generate_undersampled_data(images,
                               input_domain,
                               output_domain,
                               corruption_frac,
                               enforce_dc,
                               batch_size=10):
    """Generator that yields batches of undersampled input and correct output data.

    For corrupted inputs, int(n * corruption_frac) distinct k-space lines are
    chosen uniformly at random (without replacement) and set to zero. Inputs
    and outputs of each sample are divided by the maximum value of that
    sample's corrupted image.

    Args:
      images(float): Numpy array of input images, of shape (num_images, n, n)
      input_domain(str): The domain of the network input; 'FREQ' or 'IMAGE'
      output_domain(str): The domain of the network output; 'FREQ' or 'IMAGE'
      corruption_frac(float): Fraction of k-space lines to zero
      enforce_dc(bool): If true, yield ((inputs, masks), outputs) instead of (inputs, outputs)
      batch_size(int, optional): Number of input-output pairs in each batch (Default value = 10)

    Yields:
      Tuple of corrupted input data and ground truth output data, both numpy
      arrays of shape (batch_size,n,n,2). With enforce_dc the first element is
      the pair (inputs, masks); masks are only collected when input_domain is
      'FREQ' and are otherwise an empty array of shape (0,n,n,2).

    """
    images = utils.split_reim(images)
    spectra = utils.convert_to_frequency_domain(images)

    n = images.shape[1]
    num_points = int(n * corruption_frac)
    # Seeding each list with an empty float array keeps the result shape and
    # dtype well defined even when nothing is appended.
    empty_batch = np.empty((0, n, n, 2))

    while True:
        batch_inds = np.random.randint(0, images.shape[0], batch_size)

        input_parts = [empty_batch]
        output_parts = [empty_batch]
        mask_parts = [empty_batch]

        for j in batch_inds:
            true_img = np.expand_dims(images[j, :, :, :], 0)
            true_k = np.expand_dims(spectra[j, :, :, :], 0)

            coord_list = np.random.choice(n, num_points, replace=False)

            # Zero the selected lines in one indexed assignment; the mask
            # records which lines were kept (1) and which were removed (0).
            corrupt_k = true_k.copy()
            corrupt_k[0, coord_list, :, :] = 0
            mask = np.ones(true_k.shape)
            mask[0, coord_list, :, :] = 0
            corrupt_img = utils.convert_to_image_domain(corrupt_k)

            # Per-sample normalisation factor.
            nf = np.max(corrupt_img)

            if input_domain == 'FREQ':
                input_parts.append(corrupt_k / nf)
                mask_parts.append(mask)
            elif input_domain == 'IMAGE':
                input_parts.append(corrupt_img / nf)

            if output_domain == 'FREQ':
                output_parts.append(true_k / nf)
            elif output_domain == 'IMAGE':
                output_parts.append(true_img / nf)

        # A single concatenation per batch avoids re-copying the growing
        # arrays on every sample.
        inputs = np.concatenate(input_parts, axis=0)
        outputs = np.concatenate(output_parts, axis=0)
        masks = np.concatenate(mask_parts, axis=0)

        if enforce_dc:
            yield ((inputs, masks), outputs)
        else:
            yield (inputs, outputs)
示例#6
0
def generate_undersampled_motion_data(images,
                                      input_domain,
                                      output_domain,
                                      us_frac,
                                      mot_frac,
                                      max_htrans,
                                      max_vtrans,
                                      max_rot,
                                      batch_size=10):
    """Generator that yields batches of motion-corrupted, undersampled input and correct output data.

    For corrupted inputs, select some lines at which motion occurs; randomly
    generate and apply translation/rotations at those lines, then multiply the
    corrupted k-space by a random undersampling mask. Inputs and outputs of
    each sample are divided by the maximum value of that sample's corrupted
    image.

    Args:
      images(float): Numpy array of input images, of shape (num_images,n,n)
      input_domain(str): The domain of the network input; 'FREQ' or 'IMAGE'
      output_domain(str): The domain of the network output; 'FREQ' or 'IMAGE'
      us_frac(float): Undersampling fraction; the acceleration factor is int(1 / (1 - us_frac)). A value of exactly 1 disables undersampling.
      mot_frac(float): Maximum fraction of lines at which motion occurs; the actual count is drawn per sample as int(u * mot_frac * n) with u uniform in [0, 1).
      max_htrans(float): Maximum fraction of image width for a translation.
      max_vtrans(float): Maximum fraction of image height for a translation.
      max_rot(float): Maximum fraction of 360 degrees for a rotation.
      batch_size(int, optional): Number of input-output pairs in each batch (Default value = 10)

    Yields:
      Tuple (inputs, outputs) of corrupted input data and correct output data,
      both numpy arrays of shape (batch_size,n,n,2).

    """
    def get_us_motion_mask(arr_shape, us_frac):
        """ Based on https://github.com/facebookresearch/fastMRI/blob/master/common/subsample.py. """
        num_cols = arr_shape[1]
        if us_frac == 1:
            return (np.ones(arr_shape)).T

        acceleration = int(1 / (1 - us_frac))
        center_fraction = (1 - us_frac) * 0.08 / 0.25

        # Always keep a central band of columns; every other column is kept
        # with probability `prob`, chosen so that the expected number of kept
        # columns is num_cols / acceleration.
        num_low_freqs = int(round(num_cols * center_fraction))
        prob = (num_cols / acceleration - num_low_freqs) / \
            (num_cols - num_low_freqs)
        mask_inds = np.random.uniform(size=num_cols) < prob
        pad = (num_cols - num_low_freqs + 1) // 2
        mask_inds[pad:pad + num_low_freqs] = True

        mask = np.zeros(arr_shape)
        mask[:, mask_inds] = 1

        return np.fft.ifftshift(mask).T

    reim_images = images.copy()
    # The split array is only needed for its shape here; the ground truth is
    # taken from the k-space returned by the motion simulation.
    images = utils.split_reim(images)

    n = images.shape[1]
    max_htrans_pix = n * max_htrans
    max_vtrans_pix = n * max_vtrans
    max_rot_deg = 360 * max_rot
    # Seeding each list with an empty float array keeps the result shape and
    # dtype well defined even when nothing is appended.
    empty_batch = np.empty((0, n, n, 2))

    while True:
        batch_inds = np.random.randint(0, images.shape[0], batch_size)

        input_parts = [empty_batch]
        output_parts = [empty_batch]

        for j in batch_inds:
            num_points = int(np.random.random() * mot_frac * n)
            coord_list = np.sort(
                np.random.choice(n, size=num_points, replace=False))

            # Uniform draws in [-max, max) for each motion event.
            num_pix = np.zeros((num_points, 2))
            num_pix[:, 0] = np.random.random(num_points) * (
                2 * max_htrans_pix) - max_htrans_pix
            num_pix[:, 1] = np.random.random(num_points) * (
                2 * max_vtrans_pix) - max_vtrans_pix
            angle = np.random.random(num_points) * \
                (2 * max_rot_deg) - max_rot_deg

            corrupt_k, true_k = motion.add_rotation_and_translations(
                reim_images[j, :, :], coord_list, angle, num_pix)
            true_k = utils.split_reim(np.expand_dims(true_k, 0))
            true_img = utils.convert_to_image_domain(true_k)
            corrupt_k = utils.split_reim(np.expand_dims(corrupt_k, 0))

            # Apply the same 2-D mask to both the real and imaginary channel.
            mask = get_us_motion_mask(true_img.shape[1:3], us_frac)
            r_mask = np.expand_dims(
                np.repeat(mask[:, :, np.newaxis], 2, axis=-1), 0)

            corrupt_k *= r_mask
            corrupt_img = utils.convert_to_image_domain(corrupt_k)

            # Per-sample normalisation factor.
            nf = np.max(corrupt_img)

            if input_domain == 'FREQ':
                input_parts.append(corrupt_k / nf)
            elif input_domain == 'IMAGE':
                input_parts.append(corrupt_img / nf)

            if output_domain == 'FREQ':
                output_parts.append(true_k / nf)
            elif output_domain == 'IMAGE':
                output_parts.append(true_img / nf)

        # A single concatenation per batch avoids re-copying the growing
        # arrays on every sample.
        inputs = np.concatenate(input_parts, axis=0)
        outputs = np.concatenate(output_parts, axis=0)

        yield (inputs, outputs)
示例#7
0
def generate_motion_data(images,
                         input_domain,
                         output_domain,
                         mot_frac,
                         max_htrans,
                         max_vtrans,
                         max_rot,
                         batch_size=10):
    """Generator that yields batches of motion-corrupted input and correct output data.

    For corrupted inputs, select some lines at which motion occurs; randomly
    generate and apply translation/rotations at those lines. Inputs and
    outputs of each sample are divided by the maximum value of that sample's
    corrupted image.

    Args:
      images(float): Numpy array of input images, of shape (num_images,n,n)
      input_domain(str): The domain of the network input; 'FREQ' or 'IMAGE'
      output_domain(str): The domain of the network output; 'FREQ' or 'IMAGE'
      mot_frac(float): Fraction of lines at which motion occurs.
      max_htrans(float): Maximum fraction of image width for a translation.
      max_vtrans(float): Maximum fraction of image height for a translation.
      max_rot(float): Maximum fraction of 360 degrees for a rotation.
      batch_size(int, optional): Number of input-output pairs in each batch (Default value = 10)

    Yields:
      Tuple (inputs, outputs) of corrupted input data and correct output data,
      both numpy arrays of shape (batch_size,n,n,2).

    """
    reim_images = images.copy()
    images = utils.split_reim(images)

    n = images.shape[1]
    num_points = int(mot_frac * n)
    max_htrans_pix = n * max_htrans
    max_vtrans_pix = n * max_vtrans
    max_rot_deg = 360 * max_rot
    # Seeding each list with an empty float array keeps the result shape and
    # dtype well defined even when nothing is appended.
    empty_batch = np.empty((0, n, n, 2))

    while True:
        batch_inds = np.random.randint(0, images.shape[0], batch_size)

        input_parts = [empty_batch]
        output_parts = [empty_batch]

        for j in batch_inds:
            true_img = np.expand_dims(images[j, :, :, :], 0)

            coord_list = np.sort(
                np.random.choice(n, size=num_points, replace=False))

            # Uniform draws in [-max, max) for each motion event.
            num_pix = np.zeros((num_points, 2))
            num_pix[:, 0] = np.random.random(num_points) * (
                2 * max_htrans_pix) - max_htrans_pix
            num_pix[:, 1] = np.random.random(num_points) * (
                2 * max_vtrans_pix) - max_vtrans_pix
            angle = np.random.random(num_points) * \
                (2 * max_rot_deg) - max_rot_deg

            corrupt_k, true_k = motion.add_rotation_and_translations(
                reim_images[j, :, :], coord_list, angle, num_pix)
            corrupt_k = utils.split_reim(np.expand_dims(corrupt_k, 0))
            true_k = utils.split_reim(np.expand_dims(true_k, 0))

            corrupt_img = utils.convert_to_image_domain(corrupt_k)

            # Per-sample normalisation factor.
            nf = np.max(corrupt_img)

            if input_domain == 'FREQ':
                input_parts.append(corrupt_k / nf)
            elif input_domain == 'IMAGE':
                input_parts.append(corrupt_img / nf)

            if output_domain == 'FREQ':
                output_parts.append(true_k / nf)
            elif output_domain == 'IMAGE':
                output_parts.append(true_img / nf)

        # A single concatenation per batch avoids re-copying the growing
        # arrays on every sample.
        inputs = np.concatenate(input_parts, axis=0)
        outputs = np.concatenate(output_parts, axis=0)

        yield (inputs, outputs)
示例#8
0
def generate_undersampled_data(image_dir,
                               input_domain,
                               output_domain,
                               corruption_frac,
                               enforce_dc,
                               fs=False,
                               batch_size=16):
    """Generator that yields batches of undersampled input and correct output data.

    For corrupted inputs, the k-space of each slice is multiplied by a mask
    obtained from get_undersampling_mask for the given corruption_frac.
    Inputs and outputs of each sample are divided by the maximum magnitude of
    that sample's corrupted image.

    Args:
      image_dir(str): Directory containing 3D MRI volumes
      input_domain(str): The domain of the network input; 'FREQ' or 'IMAGE'
      output_domain(str): The domain of the network output; 'FREQ' or 'IMAGE'
      corruption_frac(float): Undersampling parameter passed to get_undersampling_mask
      enforce_dc(bool): If true, yield ((inputs, masks), outputs) instead of (inputs, outputs)
      fs(Bool, optional): Whether to read images with fat suppression (True) or without (False)
      batch_size(int, optional): Number of input-output pairs in each batch (Default value = 16)

    Yields:
      Tuple of corrupted input data and ground truth output data, both numpy
      arrays of shape (batch_size,n,n,2). With enforce_dc the first element is
      the pair (inputs, masks); masks are only collected when input_domain is
      'FREQ' and are otherwise an empty array of shape (0,n,n,2).

    """

    while True:
        images, kspace = get_fastmri_slices_from_dir(image_dir, batch_size, fs)

        images = utils.split_reim(images)
        spectra = utils.convert_to_frequency_domain(images)

        n = images.shape[1]
        # Seeding each list with an empty float array keeps the result shape
        # and dtype well defined even when nothing is appended.
        empty_batch = np.empty((0, n, n, 2))

        input_parts = [empty_batch]
        output_parts = [empty_batch]
        mask_parts = [empty_batch]

        for j in range(batch_size):
            true_img = np.expand_dims(images[j, :, :, :], 0)
            true_k = np.expand_dims(spectra[j, :, :, :], 0)
            mask = get_undersampling_mask(kspace[j, :, :].shape,
                                          corruption_frac)
            # Duplicate the 2-D mask over the real and imaginary channel.
            r_mask = np.expand_dims(
                np.repeat(mask[:, :, np.newaxis], 2, axis=-1), 0)

            corrupt_k = kspace[j, :, :] * mask

            # Bring majority of values to 0-1 range.
            # NOTE(review): only the corrupted k-space is scaled by 500; the
            # ground truth (true_k / true_img) is not -- confirm intended.
            corrupt_k = utils.split_reim(np.expand_dims(corrupt_k, 0)) * 500

            corrupt_img = utils.convert_to_image_domain(corrupt_k)

            # Per-sample normalisation factor.
            nf = np.max(np.abs(corrupt_img))

            if input_domain == 'FREQ':
                input_parts.append(corrupt_k / nf)
                mask_parts.append(r_mask)
            elif input_domain == 'IMAGE':
                input_parts.append(corrupt_img / nf)

            if output_domain == 'FREQ':
                output_parts.append(true_k / nf)
            elif output_domain == 'IMAGE':
                output_parts.append(true_img / nf)

        # A single concatenation per batch avoids re-copying the growing
        # arrays on every sample.
        inputs = np.concatenate(input_parts, axis=0)
        outputs = np.concatenate(output_parts, axis=0)
        masks = np.concatenate(mask_parts, axis=0)

        if enforce_dc:
            yield ((inputs, masks), outputs)
        else:
            yield (inputs, outputs)
示例#9
0
def plot_img_from_k(k, i, vmin=None, vmax=None, ax=None, fftshift=False):
    """Convert a k-space array to image space and plot it.

    Args:
      k: k-space array, converted with ``utils.convert_to_image_domain``.
      i: Forwarded unchanged to ``plot_img`` (presumably the index of the
        example to display -- confirm against ``plot_img``).
      vmin(optional): Forwarded to ``plot_img`` (Default value = None)
      vmax(optional): Forwarded to ``plot_img`` (Default value = None)
      ax(optional): Forwarded to ``plot_img`` (Default value = None)
      fftshift(bool, optional): If true, apply ``np.fft.fftshift`` over axes 1
        and 2 of the converted image before plotting (Default value = False)

    """
    image = utils.convert_to_image_domain(k)
    to_plot = np.fft.fftshift(image, axes=(1, 2)) if fftshift else image
    plot_img(to_plot, i, vmin=vmin, vmax=vmax, ax=ax)
def generate_undersampled_data(image_dir,
                               input_domain,
                               output_domain,
                               corruption_frac,
                               enforce_dc,
                               batch_size,
                               fs=False):
    """Generator that yields batches of undersampled input and correct output data.

    Each batch is loaded with get_fastmri_slices_from_dir, which supplies the
    full k-space, a cropped full k-space, the undersampled k-space and the
    undersampling mask. All inputs and outputs of a sample are divided by the
    95th percentile of the magnitude of that sample's corrupted image.

    Args:
      image_dir(str): Directory containing 3D MRI volumes
      input_domain(str): The domain of the network input; 'FREQ' or 'IMAGE'
      output_domain(str): The domain of the network output; 'FREQ' or 'IMAGE'
      corruption_frac(float): Undersampling parameter passed to get_fastmri_slices_from_dir
      enforce_dc(bool): If true, the first yielded element is a dict with keys 'input' and 'mask'; otherwise it is the input array alone
      batch_size(int): Number of input-output pairs in each batch
      fs(Bool, optional): Whether to read images with fat suppression (True) or without (False)

    Yields:
      Tuple (inputs, targets). `targets` is a dict with keys 'output' and
      'output_crop' holding the ground truth in the requested output domain.

    Raises:
      ValueError: If input_domain or output_domain is not 'FREQ' or 'IMAGE'.

    """
    # Fail with a clear message instead of an UnboundLocalError at yield time.
    if input_domain not in ('FREQ', 'IMAGE'):
        raise ValueError(
            "input_domain must be 'FREQ' or 'IMAGE', got {!r}".format(
                input_domain))
    if output_domain not in ('FREQ', 'IMAGE'):
        raise ValueError(
            "output_domain must be 'FREQ' or 'IMAGE', got {!r}".format(
                output_domain))

    while True:
        kspace_full, kspace_full_crop, kspace_mask, mask = get_fastmri_slices_from_dir(
            image_dir, batch_size, fs, corruption_frac=corruption_frac)

        # Duplicate the mask over the real and imaginary channel.
        mask = np.repeat(np.expand_dims(mask, -1), 2, axis=-1)

        kspace_full = utils.split_reim(kspace_full)
        kspace_full_crop = utils.split_reim(kspace_full_crop)
        kspace_mask = utils.split_reim(kspace_mask)

        # Bring majority of values to 0-1 range.
        corrupt_img = utils.convert_to_image_domain(kspace_mask)
        nf = np.percentile(np.abs(corrupt_img), 95, axis=(1, 2, 3))
        nf = nf[:, np.newaxis, np.newaxis, np.newaxis]

        if input_domain == 'FREQ':
            inp = kspace_mask / nf
        else:
            # Reuse the image already computed for the normalisation factor.
            inp = corrupt_img / nf

        if output_domain == 'FREQ':
            output = kspace_full / nf
            output_crop = kspace_full_crop / nf
        else:
            # The ground truth must come from the fully sampled k-space, as in
            # the 'FREQ' branch, not from the undersampled input.
            output = utils.convert_to_image_domain(kspace_full) / nf
            output_crop = utils.convert_to_image_domain(kspace_full_crop) / nf

        if enforce_dc:
            yield ({
                'input': inp,
                'mask': mask
            }, {
                'output': output,
                'output_crop': output_crop
            })
        else:
            yield (inp, {'output': output, 'output_crop': output_crop})
示例#11
0
def generate_uniform_undersampled_data(images,
                                       input_domain,
                                       output_domain,
                                       corruption_frac,
                                       enforce_dc,
                                       batch_size=10):
    """Generator that yields batches of uniformly undersampled input and correct output data.

    For corrupted inputs, a fixed, regular pattern of k-space lines is set to
    zero: with s = int(1 / (1 - corruption_frac)), the first s - 1 lines of
    every group of s lines are zeroed outside a central band, and the central
    band is left untouched. Inputs and outputs of each sample are divided by
    the maximum value of that sample's corrupted image.

    Args:
      images(float): Numpy array of input images, of shape (num_images, n, n)
      input_domain(str): The domain of the network input; 'FREQ', 'IMAGE', 'MAG', 'COMPLEX_K' or 'COMPLEX_I'
      output_domain(str): The domain of the network output; 'FREQ', 'IMAGE', 'MAG', 'COMPLEX_K' or 'COMPLEX_I'
      corruption_frac(float): Undersampling fraction that determines the spacing s of retained lines
      enforce_dc(bool): If true, yield ((inputs, masks), outputs) instead of (inputs, outputs)
      batch_size(int, optional): Number of input-output pairs in each batch (Default value = 10)

    Yields:
      Tuple of corrupted input data and ground truth output data. Arrays have
      shape (batch_size,n,n,2) for 'FREQ'/'IMAGE' and (batch_size,n,n,1) for
      'MAG' and the 'COMPLEX_*' domains. With enforce_dc the first element is
      the pair (inputs, masks); masks are collected for 'FREQ' inputs with
      shape (batch_size,n,n,2) and for 'COMPLEX_*' inputs with shape
      (batch_size,n,n), and are empty otherwise.

    """
    images = utils.split_reim(images)
    spectra = utils.convert_to_frequency_domain(images)

    n = images.shape[1]

    complex_input = 'COMPLEX' in input_domain
    n_ch_in = 1 if (input_domain == 'MAG' or complex_input) else 2
    n_ch_out = 1 if (output_domain == 'MAG' or
                     'COMPLEX' in output_domain) else 2

    # The set of zeroed lines depends only on n and corruption_frac, so it is
    # computed once instead of once per sample.
    s = int(1 / (1 - corruption_frac))
    arc_lines = int(32 / s)
    arc_low = int((50 - int(arc_lines / 2)) * n / 100)
    arc_high = int((50 + int(arc_lines / 2)) * n / 100)
    group_starts = list(range(0, arc_low, s)) + \
        list(range(arc_high, n - (s - 1), s))
    coord_list = np.array(
        [start + offset for start in group_starts for offset in range(s - 1)],
        dtype=int)

    empty_inputs = np.empty((0, n, n, n_ch_in))
    empty_outputs = np.empty((0, n, n, n_ch_out))
    empty_masks = np.empty((0, n, n)) if complex_input else empty_inputs

    while True:
        batch_inds = np.random.randint(0, images.shape[0], batch_size)

        # Seeding each list with an empty array keeps the result shape well
        # defined even when nothing is appended.
        input_parts = [empty_inputs]
        output_parts = [empty_outputs]
        mask_parts = [empty_masks]

        for j in batch_inds:
            true_img = np.expand_dims(images[j, :, :, :], 0)
            true_k = np.expand_dims(spectra[j, :, :, :], 0)

            # Zero the selected lines in one indexed assignment; the mask
            # records which lines were kept (1) and which were removed (0).
            corrupt_k = true_k.copy()
            corrupt_k[0, coord_list, :, :] = 0
            mask = np.ones(true_k.shape)
            mask[0, coord_list, :, :] = 0
            corrupt_img = utils.convert_to_image_domain(corrupt_k)

            # Per-sample normalisation factor, taken from the split
            # real/imaginary image before any domain-specific conversion.
            nf = np.max(corrupt_img)

            if input_domain == 'FREQ':
                input_parts.append(corrupt_k / nf)
                mask_parts.append(mask)
            elif input_domain == 'IMAGE':
                input_parts.append(corrupt_img / nf)
            elif input_domain == 'MAG':
                corrupt_mag = np.expand_dims(
                    np.abs(utils.join_reim(corrupt_img)), -1)
                input_parts.append(corrupt_mag / nf)
            elif input_domain == 'COMPLEX_K':
                complex_k = np.expand_dims(utils.join_reim(corrupt_k), -1)
                input_parts.append(complex_k / nf)
            elif input_domain == 'COMPLEX_I':
                complex_img = np.expand_dims(utils.join_reim(corrupt_img), -1)
                input_parts.append(complex_img / nf)

            if output_domain == 'FREQ':
                output_parts.append(true_k / nf)
            elif output_domain == 'IMAGE':
                output_parts.append(true_img / nf)
            elif output_domain == 'MAG':
                true_mag = np.expand_dims(np.abs(utils.join_reim(true_img)),
                                          -1)
                output_parts.append(true_mag / nf)
            elif output_domain == 'COMPLEX_K':
                # Ground truth belongs in the outputs, not appended to the
                # inputs.
                true_complex_k = np.expand_dims(utils.join_reim(true_k), -1)
                output_parts.append(true_complex_k / nf)
            elif output_domain == 'COMPLEX_I':
                true_complex_img = np.expand_dims(utils.join_reim(true_img),
                                                  -1)
                output_parts.append(true_complex_img / nf)

            if complex_input:
                # Complex inputs carry a single channel, so keep one mask
                # plane per sample.
                mask_parts.append(mask[:, :, :, 0])

        # A single concatenation per batch avoids re-copying the growing
        # arrays on every sample.
        inputs = np.concatenate(input_parts, axis=0)
        outputs = np.concatenate(output_parts, axis=0)
        masks = np.concatenate(mask_parts, axis=0)

        if enforce_dc:
            yield ((inputs, masks), outputs)
        else:
            yield (inputs, outputs)