def get_fastmri_slices_from_dir(image_dir, batch_size, fs=False,
                                corruption_frac=1.0):
    """Load a random batch of k-space slices, and masked copies, from FastMRI volumes.

    Volume files are drawn uniformly at random (with replacement) until one
    whose 'acquisition' attribute matches the requested fat-suppression
    setting is found. `batch_size` slices are then drawn at random (with
    replacement) from that single volume.

    Args:
      image_dir(str): Directory containing 3D MRI volumes; each volume is
        stored as a '.h5' file in the FastMRI format with a 'kspace' dataset
        indexed as [slice, row, col]
      batch_size(int): Number of slices to draw
      fs(Boolean): Whether to read images with fat suppression (True) or
        without (False)
      corruption_frac(float): Passed to get_undersampling_mask to control how
        much of k-space is zeroed

    Returns:
      tuple of four numpy arrays:
        kspace_fulls: complex, (batch_size, rows, cols); full k-space slices
        kspace_full_crops: complex, (batch_size, n_crop, n_crop); the slices
          passed through models.crop_320
        kspace_masks: complex, (batch_size, rows, cols); mask * kspace_fulls
        masks: float, (batch_size, rows, cols); the undersampling masks
    """
    image_names = os.listdir(image_dir)
    while True:
        img_i = np.random.randint(0, len(image_names))
        img = image_names[img_i]
        with h5py.File(os.path.join(image_dir, img), "r") as f:
            # Retry with another random volume if the fat-suppression setting
            # does not match.
            # NOTE(review): if no volume in image_dir matches `fs`, this loop
            # never terminates.
            if ('CORPDFS' in f.attrs['acquisition']) != fs:
                continue
            kspace = f['kspace']
            n_slices = kspace.shape[0]
            slice_shape = (batch_size, kspace.shape[1], kspace.shape[2])
            kspace_fulls = np.empty(slice_shape, dtype=complex)
            kspace_masks = np.empty(slice_shape, dtype=complex)
            masks = np.empty(slice_shape)
            # NOTE(review): n_crop is not defined in this function; presumably
            # a module-level constant matching the crop size of
            # models.crop_320 - confirm.
            kspace_full_crops = np.empty((batch_size, n_crop, n_crop),
                                         dtype=complex)
            for i in range(batch_size):
                slice_i = np.random.randint(0, n_slices)
                # Read the slice from disk once and reuse it for the masked copy.
                kspace_full = kspace[slice_i, :, :]
                mask = get_undersampling_mask(kspace_full.shape, corruption_frac)
                kspace_mask = mask * kspace_full
                kspace_full_crop = models.crop_320(
                    utils.split_reim(np.expand_dims(kspace_full, 0))[0, :, :, :])
                kspace_full_crop = utils.join_reim(
                    np.expand_dims(kspace_full_crop, 0))[0, :, :]
                kspace_fulls[i, ...] = kspace_full
                kspace_full_crops[i, ...] = kspace_full_crop
                kspace_masks[i, ...] = kspace_mask
                masks[i, ...] = mask
            return kspace_fulls, kspace_full_crops, kspace_masks, masks
def generate_noisy_data(images,
                        input_domain,
                        output_domain,
                        corruption_frac,
                        batch_size=10):
    """Generator that yields batches of noisy input and correct output data.

    For corrupted inputs, Gaussian noise with standard deviation
    corruption_frac is added independently to the real and the imaginary
    channel of every k-space pixel.

    Args:
      images(float): Numpy array of input images, of shape (num_images,n,n)
      input_domain(str): The domain of the network input; 'FREQ' or 'IMAGE'
      output_domain(str): The domain of the network output; 'FREQ' or 'IMAGE'
      corruption_frac(float): Standard deviation of the noise added to each
        real/imaginary k-space value
      batch_size(int, optional): Number of input-output pairs in each batch
        (Default value = 10)

    Yields:
      tuple: (inputs, outputs), the corrupted input data and ground truth
        output data, both numpy arrays of shape (batch_size,n,n,2).
    """
    images = utils.split_reim(images)
    spectra = utils.convert_to_frequency_domain(images)
    n = images.shape[1]
    while True:
        batch_inds = np.random.randint(0, images.shape[0], batch_size)
        inputs = np.empty((0, n, n, 2))
        outputs = np.empty((0, n, n, 2))
        for j in batch_inds:
            true_img = np.expand_dims(images[j, :, :, :], 0)
            true_k = np.expand_dims(spectra[j, :, :, :], 0)
            # `scale` is the standard deviation of np.random.normal.
            noise = np.random.normal(loc=0.0, scale=corruption_frac,
                                     size=true_k.shape)
            corrupt_k = true_k + noise
            if input_domain == 'FREQ':
                inputs = np.append(inputs, corrupt_k, axis=0)
            elif input_domain == 'IMAGE':
                corrupt_img = utils.convert_to_image_domain(corrupt_k)
                inputs = np.append(inputs, corrupt_img, axis=0)
            if output_domain == 'FREQ':
                outputs = np.append(outputs, true_k, axis=0)
            elif output_domain == 'IMAGE':
                outputs = np.append(outputs, true_img, axis=0)
        yield (inputs, outputs)
def test_split_reim(self):
    """split_reim turns a complex (1, 3, 3) array into a real (1, 3, 3, 2) one.

    Channel 0 must hold the real part and channel 1 the imaginary part.
    """
    complex_input = np.full((1, 3, 3), 1 + 1j)
    # Both the real and the imaginary part of 1 + 1j are 1.
    expected = np.ones((1, 3, 3, 2))
    actual = utils.split_reim(complex_input)
    np.testing.assert_array_equal(actual, expected)
def get_mri_spectra_stats(images):
    """Compute the pixelwise mean and (floored) std of MRI Fourier spectra.

    Args:
      images(float): Numpy array of shape (num_images, n, n) containing
        input images.

    Returns:
      float: Numpy array of shape (1, n, n, 2) with the pixelwise mean of the
        real and imaginary parts of the Fourier spectra of the input images
      float: Numpy array of shape (1, n, n, 2) with the pixelwise standard
        deviation of the real and imaginary parts of those spectra, clipped
        from below at 1 (presumably so dividing by it never amplifies values)
    """
    spectra = utils.convert_to_frequency_domain(utils.split_reim(images))
    stats_axis = 0
    pixel_mean = np.mean(spectra, axis=stats_axis, keepdims=True)
    raw_std = np.std(spectra, axis=stats_axis, keepdims=True)
    pixel_std = np.maximum(raw_std, 1)
    return pixel_mean, pixel_std
def generate_undersampled_data(images,
                               input_domain,
                               output_domain,
                               corruption_frac,
                               enforce_dc,
                               batch_size=10):
    """Generator that yields batches of undersampled input and correct output data.

    For corrupted inputs, int(n * corruption_frac) distinct k-space rows are
    chosen uniformly at random (independently per sample) and set to zero.
    Every sample is divided by the maximum value of its corrupted image.

    Args:
      images(float): Numpy array of input images, of shape (num_images, n, n)
      input_domain(str): The domain of the network input; 'FREQ' or 'IMAGE'
      output_domain(str): The domain of the network output; 'FREQ' or 'IMAGE'
      corruption_frac(float): Fraction of k-space rows to zero
      enforce_dc(bool): If True, also yield the per-sample sampling masks
        (1 = kept, 0 = zeroed), shape (batch_size,n,n,2)
      batch_size(int, optional): Number of input-output pairs in each batch
        (Default value = 10)

    Yields:
      tuple: ((inputs, masks), outputs) if enforce_dc, else (inputs, outputs);
        inputs and outputs are numpy arrays of shape (batch_size,n,n,2).
    """
    images = utils.split_reim(images)
    spectra = utils.convert_to_frequency_domain(images)
    n = images.shape[1]
    num_points = int(n * corruption_frac)
    while True:
        batch_inds = np.random.randint(0, images.shape[0], batch_size)
        inputs = np.empty((0, n, n, 2))
        outputs = np.empty((0, n, n, 2))
        masks = np.empty((0, n, n, 2))
        for j in batch_inds:
            true_img = np.expand_dims(images[j, :, :, :], 0)
            true_k = np.expand_dims(spectra[j, :, :, :], 0)
            # Distinct row indices to drop for this sample.
            coord_list = np.random.choice(n, num_points, replace=False)
            corrupt_k = true_k.copy()
            corrupt_k[0, coord_list, :, :] = 0
            mask = np.ones(true_k.shape)
            mask[0, coord_list, :, :] = 0
            corrupt_img = utils.convert_to_image_domain(corrupt_k)
            # NOTE(review): nf is the max over the split real/imag values, not
            # the magnitude; it is 0 if every row is dropped.
            nf = np.max(corrupt_img)
            if input_domain == 'FREQ':
                inputs = np.append(inputs, corrupt_k / nf, axis=0)
            elif input_domain == 'IMAGE':
                inputs = np.append(inputs, corrupt_img / nf, axis=0)
            # Record the mask for every sample so enforce_dc also works with
            # image-domain inputs (otherwise `masks` would stay empty).
            masks = np.append(masks, mask, axis=0)
            if output_domain == 'FREQ':
                outputs = np.append(outputs, true_k / nf, axis=0)
            elif output_domain == 'IMAGE':
                outputs = np.append(outputs, true_img / nf, axis=0)
        if enforce_dc:
            yield ((inputs, masks), outputs)
        else:
            yield (inputs, outputs)
def generate_undersampled_motion_data(images,
                                      input_domain,
                                      output_domain,
                                      us_frac,
                                      mot_frac,
                                      max_htrans,
                                      max_vtrans,
                                      max_rot,
                                      batch_size=10):
    """Generator that yields batches of motion-corrupted, undersampled input and correct output data.

    For corrupted inputs, a random number (up to mot_frac * n) of lines at
    which motion occurs is selected, and random translations/rotations are
    generated and applied at those lines. The corrupted k-space is then
    multiplied by a random line-undersampling mask with a fully sampled
    central band. Every sample is divided by the maximum value of its
    corrupted image.

    Args:
      images(float): Numpy array of input images, of shape (num_images,n,n)
      input_domain(str): The domain of the network input; 'FREQ' or 'IMAGE'
      output_domain(str): The domain of the network output; 'FREQ' or 'IMAGE'
      us_frac(float): Undersampling fraction. The acceleration factor is
        int(1 / (1 - us_frac)); us_frac == 1 is special-cased to a fully
        sampled mask.
      mot_frac(float): Upper bound on the fraction of lines at which motion
        occurs.
      max_htrans(float): Maximum fraction of image width for a translation.
      max_vtrans(float): Maximum fraction of image height for a translation.
      max_rot(float): Maximum fraction of 360 degrees for a rotation.
      batch_size(int, optional): Number of input-output pairs in each batch
        (Default value = 10)

    Yields:
      tuple: (inputs, outputs), the corrupted input data and correct output
        data, both numpy arrays of shape (batch_size,n,n,2).
    """

    def get_us_motion_mask(arr_shape, us_frac):
        """Build a random line-undersampling mask with a fully sampled center.

        Based on https://github.com/facebookresearch/fastMRI/blob/master/common/subsample.py.
        Columns are selected, ifftshifted, and then transposed, so the
        returned mask keeps or drops whole rows.
        """
        num_cols = arr_shape[1]
        if us_frac == 1:
            # Avoid dividing by zero below; no undersampling in this case.
            return (np.ones(arr_shape)).T
        acceleration = int(1 / (1 - us_frac))
        center_fraction = (1 - us_frac) * 0.08 / 0.25
        num_low_freqs = int(round(num_cols * center_fraction))
        prob = (num_cols / acceleration - num_low_freqs) / \
            (num_cols - num_low_freqs)
        mask_inds = np.random.uniform(size=num_cols) < prob
        pad = (num_cols - num_low_freqs + 1) // 2
        mask_inds[pad:pad + num_low_freqs] = True
        mask = np.zeros(arr_shape)
        mask[:, mask_inds] = 1
        return np.fft.ifftshift(mask).T

    # Untouched copy of the raw images, handed to the motion simulator.
    raw_images = images.copy()
    images = utils.split_reim(images)
    n = images.shape[1]
    max_htrans_pix = n * max_htrans
    max_vtrans_pix = n * max_vtrans
    max_rot_deg = 360 * max_rot
    while True:
        batch_inds = np.random.randint(0, images.shape[0], batch_size)
        inputs = np.empty((0, n, n, 2))
        outputs = np.empty((0, n, n, 2))
        for j in batch_inds:
            num_points = int(np.random.random() * mot_frac * n)
            coord_list = np.sort(
                np.random.choice(n, size=num_points, replace=False))
            num_pix = np.zeros((num_points, 2))
            num_pix[:, 0] = np.random.random(num_points) * (
                2 * max_htrans_pix) - max_htrans_pix
            num_pix[:, 1] = np.random.random(num_points) * (
                2 * max_vtrans_pix) - max_vtrans_pix
            angle = np.random.random(num_points) * \
                (2 * max_rot_deg) - max_rot_deg

            corrupt_k, true_k = motion.add_rotation_and_translations(
                raw_images[j, :, :], coord_list, angle, num_pix)
            true_k = utils.split_reim(np.expand_dims(true_k, 0))
            true_img = utils.convert_to_image_domain(true_k)
            corrupt_k = utils.split_reim(np.expand_dims(corrupt_k, 0))

            mask = get_us_motion_mask(true_img.shape[1:3], us_frac)
            # Duplicate the mask across the real/imaginary channels.
            r_mask = np.expand_dims(
                np.repeat(mask[:, :, np.newaxis], 2, axis=-1), 0)
            corrupt_k *= r_mask
            corrupt_img = utils.convert_to_image_domain(corrupt_k)

            nf = np.max(corrupt_img)
            if input_domain == 'FREQ':
                inputs = np.append(inputs, corrupt_k / nf, axis=0)
            elif input_domain == 'IMAGE':
                inputs = np.append(inputs, corrupt_img / nf, axis=0)
            if output_domain == 'FREQ':
                outputs = np.append(outputs, true_k / nf, axis=0)
            elif output_domain == 'IMAGE':
                outputs = np.append(outputs, true_img / nf, axis=0)
        yield (inputs, outputs)
def generate_motion_data(images,
                         input_domain,
                         output_domain,
                         mot_frac,
                         max_htrans,
                         max_vtrans,
                         max_rot,
                         batch_size=10):
    """Generator that yields batches of motion-corrupted input and correct output data.

    For corrupted inputs, int(mot_frac * n) lines at which motion occurs are
    selected, and random translations/rotations are generated and applied at
    those lines. Every sample is divided by the maximum value of its
    corrupted image.

    Args:
      images(float): Numpy array of input images, of shape (num_images,n,n)
      input_domain(str): The domain of the network input; 'FREQ' or 'IMAGE'
      output_domain(str): The domain of the network output; 'FREQ' or 'IMAGE'
      mot_frac(float): Fraction of lines at which motion occurs.
      max_htrans(float): Maximum fraction of image width for a translation.
      max_vtrans(float): Maximum fraction of image height for a translation.
      max_rot(float): Maximum fraction of 360 degrees for a rotation.
      batch_size(int, optional): Number of input-output pairs in each batch
        (Default value = 10)

    Yields:
      tuple: (inputs, outputs), the corrupted input data and correct output
        data, both numpy arrays of shape (batch_size,n,n,2).
    """
    # Untouched copy of the raw images, handed to the motion simulator.
    raw_images = images.copy()
    images = utils.split_reim(images)
    n = images.shape[1]
    num_points = int(mot_frac * n)
    max_htrans_pix = n * max_htrans
    max_vtrans_pix = n * max_vtrans
    max_rot_deg = 360 * max_rot
    while True:
        batch_inds = np.random.randint(0, images.shape[0], batch_size)
        inputs = np.empty((0, n, n, 2))
        outputs = np.empty((0, n, n, 2))
        for j in batch_inds:
            coord_list = np.sort(
                np.random.choice(n, size=num_points, replace=False))
            num_pix = np.zeros((num_points, 2))
            num_pix[:, 0] = np.random.random(num_points) * (
                2 * max_htrans_pix) - max_htrans_pix
            num_pix[:, 1] = np.random.random(num_points) * (
                2 * max_vtrans_pix) - max_vtrans_pix
            angle = np.random.random(num_points) * \
                (2 * max_rot_deg) - max_rot_deg

            corrupt_k, true_k = motion.add_rotation_and_translations(
                raw_images[j, :, :], coord_list, angle, num_pix)
            corrupt_k = utils.split_reim(np.expand_dims(corrupt_k, 0))
            true_k = utils.split_reim(np.expand_dims(true_k, 0))
            corrupt_img = utils.convert_to_image_domain(corrupt_k)

            nf = np.max(corrupt_img)
            if input_domain == 'FREQ':
                inputs = np.append(inputs, corrupt_k / nf, axis=0)
            elif input_domain == 'IMAGE':
                inputs = np.append(inputs, corrupt_img / nf, axis=0)
            if output_domain == 'FREQ':
                outputs = np.append(outputs, true_k / nf, axis=0)
            elif output_domain == 'IMAGE':
                # Derive the target from true_k with the same transform used
                # for corrupt_img, so input and target share one convention.
                true_img = utils.convert_to_image_domain(true_k)
                outputs = np.append(outputs, true_img / nf, axis=0)
        yield (inputs, outputs)
def generate_undersampled_data(image_dir,
                               input_domain,
                               output_domain,
                               corruption_frac,
                               enforce_dc,
                               fs=False,
                               batch_size=16):
    """Generator that yields batches of undersampled input and correct output data.

    For corrupted inputs, the k-space slices returned by the loader are
    multiplied by a mask from get_undersampling_mask(..., corruption_frac).
    The corrupted k-space is scaled by 500, and each sample is divided by the
    maximum magnitude of its corrupted image.

    NOTE(review): this function has the same name as other definitions in
    SOURCE; if they share a module, only the last one is bound at import.

    Args:
      image_dir(str): Directory containing 3D MRI volumes
      input_domain(str): The domain of the network input; 'FREQ' or 'IMAGE'
      output_domain(str): The domain of the network output; 'FREQ' or 'IMAGE'
      corruption_frac(float): Passed to get_undersampling_mask to control how
        much of k-space is zeroed
      enforce_dc(bool): If True, also yield the masks (populated only for
        'FREQ' input)
      fs(Bool, optional): Whether to read images with fat suppression (True)
        or without (False)
      batch_size(int, optional): Number of input-output pairs in each batch

    Yields:
      tuple: ((inputs, masks), outputs) if enforce_dc, else (inputs, outputs);
        inputs and outputs are numpy arrays of shape (batch_size,n,n,2).
    """
    while True:
        # NOTE(review): get_fastmri_slices_from_dir as visible in SOURCE
        # returns four arrays, so this two-name unpacking would raise
        # ValueError. Confirm which loader version this generator pairs with.
        images, kspace = get_fastmri_slices_from_dir(image_dir, batch_size, fs)
        images = utils.split_reim(images)
        spectra = utils.convert_to_frequency_domain(images)
        n = images.shape[1]
        inputs = np.empty((0, n, n, 2))
        outputs = np.empty((0, n, n, 2))
        masks = np.empty((0, n, n, 2))
        for j in range(batch_size):
            true_img = np.expand_dims(images[j, :, :, :], 0)
            true_k = np.expand_dims(spectra[j, :, :, :], 0)
            mask = get_undersampling_mask(kspace[j, :, :].shape, corruption_frac)
            # Duplicate the mask across the real/imaginary channels.
            r_mask = np.expand_dims(
                np.repeat(mask[:, :, np.newaxis], 2, axis=-1), 0)
            # NOTE(review): num_points and coord_list are never used below,
            # but np.random.choice still advances the global RNG.
            num_points = int(n * corruption_frac)
            coord_list = np.random.choice(n, num_points, replace=False)
            corrupt_k = kspace[j, :, :] * mask
            # Bring majority of values to 0-1 range.
            corrupt_k = utils.split_reim(np.expand_dims(corrupt_k, 0)) * 500
            corrupt_img = utils.convert_to_image_domain(corrupt_k)
            # NOTE(review): corrupt_k is scaled by 500 but true_k/true_img are
            # not, so input and target end up on different scales.
            nf = np.max(np.abs(corrupt_img))
            if (input_domain == 'FREQ'):
                inputs = np.append(inputs, corrupt_k / nf, axis=0)
                # Masks are only collected for frequency-domain inputs.
                masks = np.append(masks, r_mask, axis=0)
            elif (input_domain == 'IMAGE'):
                inputs = np.append(inputs, corrupt_img / nf, axis=0)
            if (output_domain == 'FREQ'):
                outputs = np.append(outputs, true_k / nf, axis=0)
            elif (output_domain == 'IMAGE'):
                outputs = np.append(outputs, true_img / nf, axis=0)
        if (enforce_dc):
            yield ((inputs, masks), outputs)
        else:
            yield (inputs, outputs)
def generate_undersampled_data(image_dir,
                               input_domain,
                               output_domain,
                               corruption_frac,
                               enforce_dc,
                               batch_size,
                               fs=False):
    """Generator that yields batches of undersampled input and correct output data.

    Each batch comes from get_fastmri_slices_from_dir, which masks k-space
    with get_undersampling_mask(..., corruption_frac). Every sample is
    divided by the 95th percentile of the magnitude of its corrupted image.

    Args:
      image_dir(str): Directory containing 3D MRI volumes
      input_domain(str): The domain of the network input; 'FREQ' or 'IMAGE'
      output_domain(str): The domain of the network output; 'FREQ' or 'IMAGE'
      corruption_frac(float): Passed through to the loader to control how
        much of k-space is zeroed
      enforce_dc(bool): If True, the input dict also carries the mask
      batch_size(int): Number of input-output pairs in each batch
      fs(Bool, optional): Whether to read images with fat suppression (True)
        or without (False)

    Yields:
      tuple: ({'input': inp, 'mask': mask}, {'output': ..., 'output_crop': ...})
        if enforce_dc, else (inp, {'output': ..., 'output_crop': ...}).
        Arrays carry a trailing real/imaginary channel axis of size 2.
    """
    while True:
        kspace_full, kspace_full_crop, kspace_mask, mask = get_fastmri_slices_from_dir(
            image_dir, batch_size, fs, corruption_frac=corruption_frac)
        # Duplicate the mask across the real/imaginary channel axis.
        mask = np.expand_dims(mask, -1)
        mask = np.repeat(mask, 2, axis=-1)

        kspace_full = utils.split_reim(kspace_full)
        kspace_full_crop = utils.split_reim(kspace_full_crop)
        kspace_mask = utils.split_reim(kspace_mask)

        # Bring majority of values to 0-1 range: one normalization factor per
        # sample, the 95th percentile of |corrupted image|.
        corrupt_img = utils.convert_to_image_domain(kspace_mask)
        nf = np.percentile(np.abs(corrupt_img), 95, axis=(1, 2, 3))
        nf = nf[:, np.newaxis, np.newaxis, np.newaxis]

        if input_domain == 'FREQ':
            inp = kspace_mask / nf
        elif input_domain == 'IMAGE':
            inp = corrupt_img / nf

        if output_domain == 'FREQ':
            output = kspace_full / nf
            output_crop = kspace_full_crop / nf
        elif output_domain == 'IMAGE':
            # The target is the fully sampled image, not the undersampled input.
            output = utils.convert_to_image_domain(kspace_full) / nf
            output_crop = utils.convert_to_image_domain(kspace_full_crop) / nf

        if enforce_dc:
            yield ({
                'input': inp,
                'mask': mask
            }, {
                'output': output,
                'output_crop': output_crop
            })
        else:
            yield (inp, {'output': output, 'output_crop': output_crop})
def _uniform_undersampling_rows(n, corruption_frac):
    """Return the k-space row indices zeroed by the regular undersampling pattern.

    With s = int(1 / (1 - corruption_frac)), rows are processed in blocks of
    s, starting at row 0 up to arc_low and again from arc_high upward. In
    each block the first s - 1 rows are returned (to be zeroed) and the last
    row is kept. The central band between the two ranges is left untouched.
    Raises ZeroDivisionError when corruption_frac == 1.
    """
    s = int(1 / (1 - corruption_frac))
    arc_lines = int(32 / s)
    arc_low = int((50 - int(arc_lines / 2)) * n / 100)
    arc_high = int((50 + int(arc_lines / 2)) * n / 100)
    low = [i + k for i in range(0, arc_low, s) for k in range(s - 1)]
    high = [i + k for i in range(arc_high, n - (s - 1), s) for k in range(s - 1)]
    # An explicit int dtype keeps empty results usable as an index array.
    return np.array(low + high, dtype=int)


def generate_uniform_undersampled_data(images,
                                       input_domain,
                                       output_domain,
                                       corruption_frac,
                                       enforce_dc,
                                       batch_size=10):
    """Generator that yields batches of regularly undersampled input and correct output data.

    For corrupted inputs, the same deterministic set of k-space rows is
    zeroed for every sample. Outside a central band, one row in every
    int(1 / (1 - corruption_frac)) is kept. Every sample is divided by the
    maximum value of its corrupted image (real/imag split form).

    Args:
      images(float): Numpy array of input images, of shape (num_images, n, n)
      input_domain(str): The domain of the network input; 'FREQ', 'IMAGE',
        'MAG', 'COMPLEX_K' or 'COMPLEX_I'
      output_domain(str): The domain of the network output; 'FREQ', 'IMAGE',
        'MAG', 'COMPLEX_K' or 'COMPLEX_I'
      corruption_frac(float): Approximate fraction of k-space rows to zero;
        must be < 1
      enforce_dc(bool): If True, also yield the sampling masks. They are
        collected for 'FREQ' and 'COMPLEX_*' inputs only.
      batch_size(int, optional): Number of input-output pairs in each batch
        (Default value = 10)

    Yields:
      tuple: ((inputs, masks), outputs) if enforce_dc, else (inputs, outputs).
        Arrays have shape (batch_size,n,n,c), with c = 1 for 'MAG'/'COMPLEX_*'
        domains and c = 2 otherwise.
    """
    images = utils.split_reim(images)
    spectra = utils.convert_to_frequency_domain(images)
    n = images.shape[1]
    # 'MAG' and 'COMPLEX_*' domains use a single channel; 'FREQ'/'IMAGE' keep
    # separate real and imaginary channels.
    n_ch_in = 1 if (input_domain == 'MAG' or 'COMPLEX' in input_domain) else 2
    n_ch_out = 1 if (output_domain == 'MAG' or 'COMPLEX' in output_domain) else 2
    # The pattern depends only on n and corruption_frac, so build it once.
    coord_list = _uniform_undersampling_rows(n, corruption_frac)
    while True:
        batch_inds = np.random.randint(0, images.shape[0], batch_size)
        inputs = np.empty((0, n, n, n_ch_in))
        outputs = np.empty((0, n, n, n_ch_out))
        if 'COMPLEX' in input_domain:
            masks = np.empty((0, n, n))
        else:
            masks = np.empty((0, n, n, n_ch_in))
        for j in batch_inds:
            true_img = np.expand_dims(images[j, :, :, :], 0)
            true_k = np.expand_dims(spectra[j, :, :, :], 0)
            mask = np.ones(true_k.shape)
            corrupt_k = true_k.copy()
            corrupt_k[0, coord_list, :, :] = 0
            mask[0, coord_list, :, :] = 0
            corrupt_img = utils.convert_to_image_domain(corrupt_k)
            nf = np.max(corrupt_img)

            if input_domain == 'FREQ':
                inputs = np.append(inputs, corrupt_k / nf, axis=0)
                masks = np.append(masks, mask, axis=0)
            elif input_domain == 'IMAGE':
                inputs = np.append(inputs, corrupt_img / nf, axis=0)
            elif input_domain == 'MAG':
                corrupt_img = np.expand_dims(
                    np.abs(utils.join_reim(corrupt_img)), -1)
                inputs = np.append(inputs, corrupt_img / nf, axis=0)
            elif input_domain == 'COMPLEX_K':
                corrupt_k = np.expand_dims(utils.join_reim(corrupt_k), -1)
                inputs = np.append(inputs, corrupt_k / nf, axis=0)
            elif input_domain == 'COMPLEX_I':
                corrupt_img = np.expand_dims(utils.join_reim(corrupt_img), -1)
                inputs = np.append(inputs, corrupt_img / nf, axis=0)

            if output_domain == 'FREQ':
                outputs = np.append(outputs, true_k / nf, axis=0)
            elif output_domain == 'IMAGE':
                outputs = np.append(outputs, true_img / nf, axis=0)
            elif output_domain == 'MAG':
                true_img = np.expand_dims(np.abs(utils.join_reim(true_img)), -1)
                outputs = np.append(outputs, true_img / nf, axis=0)
            elif output_domain == 'COMPLEX_K':
                true_k = np.expand_dims(utils.join_reim(true_k), -1)
                outputs = np.append(outputs, true_k / nf, axis=0)
            elif output_domain == 'COMPLEX_I':
                true_img = np.expand_dims(utils.join_reim(true_img), -1)
                outputs = np.append(outputs, true_img / nf, axis=0)

            if 'COMPLEX' in input_domain:
                # Complex inputs have one channel, so keep a 2D mask per sample.
                mask = mask[:, :, :, 0]
                masks = np.append(masks, mask, axis=0)
        if enforce_dc:
            yield ((inputs, masks), outputs)
        else:
            yield (inputs, outputs)