def test_complex_center_crop(shape, target_shape):
    """Cropping a complex tensor yields ``target_shape`` plus the trailing size-2 axis.

    Args:
        shape: Spatial/batch shape (a list) of the input; a trailing ``2`` is
            appended to make it a real/imaginary-stacked complex tensor.
        target_shape: Expected shape (a list) of the cropped output, without
            the trailing complex axis.
    """
    complex_shape = shape + [2]
    data = create_input(complex_shape)
    cropped = transforms.complex_center_crop(data, target_shape).numpy()
    expected = target_shape + [2]
    assert list(cropped.shape) == expected
def save_zero_filled(data_dir, out_dir, which_challenge):
    """Compute zero-filled magnitude reconstructions for every ``*.h5`` file and save them.

    Args:
        data_dir: ``pathlib.Path`` of a directory containing HDF5 volumes with
            ``"kspace"`` and ``"ismrmrd_header"`` datasets.
        out_dir: Destination passed to ``fastmri.save_reconstructions``.
        which_challenge: ``"multicoil"`` combines coils with root-sum-of-squares
            over dim 1; any other value keeps the magnitude image as is.
    """
    reconstructions = {}

    # sorted() gives a deterministic processing order (glob order is arbitrary).
    for fname in tqdm(sorted(data_dir.glob("*.h5"))):
        with h5py.File(fname, "r") as hf:
            et_root = etree.fromstring(hf["ismrmrd_header"][()])
            masked_kspace = transforms.to_tensor(hf["kspace"][()])

        # Extract the target image width/height from the ISMRMRD header.
        # The target size is the reconstruction matrix (reconSpace). encodedSpace
        # describes the acquired k-space matrix, which is generally larger
        # (confirm for your data), so cropping to it would not give the
        # target size.
        rec = ["encoding", "reconSpace", "matrixSize"]
        crop_size = (
            int(et_query(et_root, rec + ["x"])),
            int(et_query(et_root, rec + ["y"])),
        )

        # inverse Fourier Transform to get zero filled solution
        image = fastmri.ifft2c(masked_kspace)

        # check for FLAIR 203: image narrower than the requested crop, so fall
        # back to a square crop of the image's own size along dim -2
        if image.shape[-2] < crop_size[1]:
            crop_size = (image.shape[-2], image.shape[-2])

        # crop input image
        image = transforms.complex_center_crop(image, crop_size)

        # absolute value
        image = fastmri.complex_abs(image)

        # apply Root-Sum-of-Squares if multicoil data
        if which_challenge == "multicoil":
            image = fastmri.rss(image, dim=1)

        reconstructions[fname.name] = image

    fastmri.save_reconstructions(reconstructions, out_dir)
def to_cropped_image(masked_kspace, target, attrs):
    """Turn masked k-space into a normalized, cropped magnitude image and target.

    Args:
        masked_kspace: Undersampled k-space tensor accepted by ``fastmri.ifft2c``.
        target: Ground-truth image (``np.ndarray`` or tensor) or ``None``.
        attrs: Mapping providing ``"recon_size"``, used as the crop size
            when ``target`` is ``None``.

    Returns:
        ``(image, target)``. ``image`` is instance-normalized and clamped to
        [-6, 6]. ``target`` is cropped, normalized with the image's mean/std
        and clamped, or ``torch.Tensor([0])`` when no target is given.
    """
    zero_filled = fastmri.ifft2c(masked_kspace)

    # The crop size follows the target when available, else the recon size.
    if target is None:
        crop_size = (attrs["recon_size"][0], attrs["recon_size"][1])
    else:
        crop_size = (target.shape[-2], target.shape[-1])

    # FLAIR 203: image smaller than the requested width -> square crop of that size.
    side = zero_filled.shape[-2]
    if side < crop_size[1]:
        crop_size = (side, side)

    magnitude = fastmri.complex_abs(T.complex_center_crop(zero_filled, crop_size))
    magnitude, mean, std = T.normalize_instance(magnitude, eps=1e-11)
    magnitude = magnitude.clamp(-6, 6)

    if target is None:
        return magnitude, torch.Tensor([0])

    if isinstance(target, np.ndarray):
        target = T.to_tensor(target)
    target = T.center_crop(target, crop_size)
    target = T.normalize(target, mean, std, eps=1e-11)
    return magnitude, target.clamp(-6, 6)
def save_zero_filled(data_dir, out_dir, which_challenge):
    """Compute zero-filled magnitude reconstructions for every ``*.h5`` file and save them.

    Args:
        data_dir: ``pathlib.Path`` of a directory containing HDF5 volumes with
            ``"kspace"`` and ``"ismrmrd_header"`` datasets.
        out_dir: Destination passed to ``fastmri.save_reconstructions``.
        which_challenge: ``"multicoil"`` combines coils with root-sum-of-squares
            over dim 1; any other value keeps the magnitude image as is.
    """
    reconstructions = {}

    # Only regular *.h5 files: iterdir() would also yield sub-directories and
    # unrelated files, which h5py.File cannot open.
    volumes = sorted(p for p in data_dir.glob("*.h5") if p.is_file())
    for f in volumes:
        with h5py.File(f, "r") as hf:
            enc = ismrmrd.xsd.CreateFromDocument(hf["ismrmrd_header"][()]).encoding[0]
            masked_kspace = transforms.to_tensor(hf["kspace"][()])

        # extract target image width, height from ismrmrd header
        crop_size = (enc.reconSpace.matrixSize.x, enc.reconSpace.matrixSize.y)

        # inverse Fourier Transform to get zero filled solution
        image = fastmri.ifft2c(masked_kspace)

        # check for FLAIR 203: image narrower than the requested crop, so fall
        # back to a square crop of the image's own size along dim -2
        if image.shape[-2] < crop_size[1]:
            crop_size = (image.shape[-2], image.shape[-2])

        # crop input image
        image = transforms.complex_center_crop(image, crop_size)

        # absolute value
        image = fastmri.complex_abs(image)

        # apply Root-Sum-of-Squares if multicoil data
        if which_challenge == "multicoil":
            image = fastmri.rss(image, dim=1)

        reconstructions[f.name] = image

    fastmri.save_reconstructions(reconstructions, out_dir)
def load_data_from_pathlist(path):
    """Build normalized target / undersampled-input image stacks from HDF5 volumes.

    Only the first ``len(path) // 3`` entries of ``path`` are used. For every
    slice of those volumes, the k-space is converted to a float tensor and
    undersampled with ``RandomMaskFunc(center_fractions=[0.08], accelerations=[4])``.
    It is then inverse Fourier transformed, centre-cropped to (320, 320) and
    converted to magnitude. All inputs are then normalized with a single
    dataset-wide mean/std (``T.normalize_instance``). The same statistics are
    applied to the targets, and both are clamped to [-6, 6].

    Args:
        path: Sequence of file paths. Each is passed to ``load_dataset``, which
            is expected to return ``(kspace, num_slices, target)``
            (TODO confirm against ``load_dataset``).

    Returns:
        ``(target_image_tensor, total_sampled_image_tensor)``, each holding the
        slices of all used volumes concatenated along dim 0.

    Raises:
        ValueError: if ``path`` has fewer than 3 entries, so no volume is used.
    """
    use_num = len(path) // 3
    if use_num == 0:
        raise ValueError(
            f"need at least 3 paths (a third of them are used), got {len(path)}"
        )

    # One mask function for every volume. It is created without a seed, so
    # each apply_mask call still draws its own random mask.
    mask_func = RandomMaskFunc(center_fractions=[0.08], accelerations=[4])

    total_target_list = []
    total_sampled_image_list = []
    for h5_path in path[:use_num]:
        total_kspace, slices_num, target = load_dataset(h5_path)

        # Process slices one at a time instead of first materializing every
        # slice's k-space tensor in a separate list.
        sampled_image_list = []
        for i in range(slices_num):
            slice_kspace_tensor = T.to_tensor(total_kspace[i]).float()
            masked_kspace, _ = T.apply_mask(slice_kspace_tensor, mask_func)
            # inverse FFT gives the complex zero-filled image
            sampled_image = fastmri.ifft2c(masked_kspace)
            sampled_image = T.complex_center_crop(sampled_image, (320, 320))
            sampled_image_list.append(fastmri.complex_abs(sampled_image))

        # (slices_num, 320, 320) magnitude images for this volume
        total_sampled_image_list.append(torch.stack(sampled_image_list, dim=0))
        total_target_list.append(T.to_tensor(target))

    total_target = torch.cat(total_target_list, dim=0)
    total_sampled_image_tensor = torch.cat(total_sampled_image_list, dim=0)

    # Dataset-wide normalization of the inputs; targets reuse the same stats.
    total_sampled_image_tensor, mean, std = T.normalize_instance(
        total_sampled_image_tensor, eps=1e-11
    )
    total_sampled_image_tensor = total_sampled_image_tensor.clamp(-6, 6)
    target_image_tensor = T.normalize(total_target, mean, std, eps=1e-11)
    target_image_tensor = target_image_tensor.clamp(-6, 6)

    return target_image_tensor, total_sampled_image_tensor
def _base_fastmri_unet_transform(
    kspace,
    mask,
    ground_truth,
    attrs,
    which_challenge="singlecoil",
):
    """Produce the normalized zero-filled U-Net input for one k-space slice.

    Args:
        kspace: Array convertible with ``fastmri_transforms.to_tensor``.
        mask: Tensor-like mask (``unsqueeze`` is called on it). Its last axis is
            trimmed to ``kspace.shape[-2]`` to allow variable-size masks.
        ground_truth: Target image used only for its spatial size, or ``None``.
        attrs: Mapping with ``"recon_size"``, the crop size when
            ``ground_truth`` is ``None``.
        which_challenge: ``"multicoil"`` applies root-sum-of-squares.

    Returns:
        ``(image.unsqueeze(0), mean, std)``, where ``image`` is
        instance-normalized and clamped to [-6, 6].
    """
    kspace_tensor = fastmri_transforms.to_tensor(kspace)
    # accounting for variable size masks
    trimmed_mask = mask[..., :kspace_tensor.shape[-2]]
    # "+ 0.0" presumably normalizes any -0.0 entries produced by the multiply to +0.0.
    masked_kspace = kspace_tensor * trimmed_mask.unsqueeze(-1) + 0.0

    # inverse Fourier transform to get zero filled solution
    zero_filled = fastmri.ifft2c(masked_kspace)

    if ground_truth is None:
        crop_size = (attrs["recon_size"][0], attrs["recon_size"][1])
    else:
        crop_size = (ground_truth.shape[-2], ground_truth.shape[-1])

    # FLAIR 203: image smaller than the requested width -> square crop of that size.
    side = zero_filled.shape[-2]
    if side < crop_size[1]:
        crop_size = (side, side)

    # noinspection PyTypeChecker
    cropped = fastmri_transforms.complex_center_crop(zero_filled, crop_size)
    magnitude = fastmri.complex_abs(cropped)
    if which_challenge == "multicoil":
        magnitude = fastmri.rss(magnitude)

    magnitude, mean, std = fastmri_transforms.normalize_instance(magnitude, eps=1e-11)
    return magnitude.clamp(-6, 6).unsqueeze(0), mean, std
def visualize_kspace(kspace, dim=None, crop=False, output_filepath=None):
    """Show or save the magnitude image reconstructed from ``kspace``.

    Args:
        kspace: k-space tensor accepted by ``fastmri.ifft2c``.
        dim: If given, coils are combined with ``fastmri.rss`` along this axis.
            This is applied whether or not ``crop`` is set.
        crop: If True, centre-crop the image to a square whose side is
            ``image.shape[-2]``, then instance-normalize and clamp to [-6, 6].
        output_filepath: ``str`` or ``pathlib.Path``. If given, the image is
            saved there (creating parent directories) and the figure is
            closed. Otherwise it is displayed with ``plt.show()``.
    """
    from pathlib import Path

    # Inverse FFT takes the k-space to the complex image domain.
    image = fastmri.ifft2c(kspace)
    if crop:
        side = image.shape[-2]
        image = T.complex_center_crop(image, (side, side))

    # Compute absolute value to get a real image
    image = fastmri.complex_abs(image)
    # Combine coils on both paths so crop=True also honours `dim`.
    if dim is not None:
        image = fastmri.rss(image, dim=dim)
    if crop:
        image, _, _ = T.normalize_instance(image, eps=1e-11)
        image = image.clamp(-6, 6)

    # detach/cpu so tensors that require grad or live on a GPU can be plotted.
    img = np.abs(image.detach().cpu().numpy())

    if output_filepath is None:
        plt.imshow(img, cmap='gray')
        plt.show()
        return

    output_filepath = Path(output_filepath)
    output_filepath.parent.mkdir(parents=True, exist_ok=True)
    plt.imshow(img, cmap='gray')
    plt.axis("off")
    plt.savefig(output_filepath, bbox_inches="tight", pad_inches=0)
    # Close the figure so repeated calls do not keep drawing into (and
    # accumulating) the same current figure.
    plt.close()
def __call__(self, kspace, mask, target, attrs, fname, slice_num, lr_size=(160, 160)):
    """
    Args:
        kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2)
            for multi-coil data or (rows, cols, 2) for single coil data.
        mask (numpy.array): Mask from the test dataset (not used here).
        target (numpy.array): Target image.
        attrs (dict): Acquisition related information stored in the HDF5
            object.
        fname (str): File name.
        slice_num (int): Serial number of the slice.
        lr_size (tuple): Size of the central k-space block kept when building
            the low-resolution input. Defaults to (160, 160).

    Returns:
        (tuple): tuple containing:
            LR_image (torch.Tensor): Low-resolution magnitude input. The
                fully-sampled image is cropped, its k-space is centre-cropped
                to ``lr_size`` and transformed back. The result is then
                instance-normalized and clamped to [-6, 6].
            target (torch.Tensor): Target cropped like the image, normalized
                with the LR image's mean/std and clamped, or
                ``torch.Tensor([0])`` when no target is given.
            mean (float): Mean of the LR image used for normalization.
            std (float): Standard deviation of the LR image used for
                normalization.
            fname (str): File name.
            slice_num (int): Serial number of the slice.
    """
    kspace = transforms.to_tensor(kspace)
    image = fastmri.ifft2c(kspace)

    # crop input to correct size
    if target is not None:
        crop_size = (target.shape[-2], target.shape[-1])
    else:
        crop_size = (attrs["recon_size"][0], attrs["recon_size"][1])

    # check for FLAIR 203
    if image.shape[-2] < crop_size[1]:
        crop_size = (image.shape[-2], image.shape[-2])

    image = transforms.complex_center_crop(image, crop_size)

    # Low-resolution input: keep only the central lr_size block of the cropped
    # image's k-space and go back to the image domain.
    lr_kspace = transforms.complex_center_crop(fastmri.fft2c(image), lr_size)
    LR_image = fastmri.complex_abs(fastmri.ifft2c(lr_kspace))

    # normalize input
    LR_image, mean, std = transforms.normalize_instance(LR_image, eps=1e-11)
    LR_image = LR_image.clamp(-6, 6)

    # normalize target with the LR image's statistics
    if target is not None:
        target = transforms.to_tensor(target)
        target = transforms.center_crop(target, crop_size)
        target = transforms.normalize(target, mean, std, eps=1e-11)
        target = target.clamp(-6, 6)
    else:
        target = torch.Tensor([0])

    return LR_image, target, mean, std, fname, slice_num
# Copy the central `num_middle_slices` slices of each volume in `files` to
# `dest_data_folder`. Each slice's k-space is centre-cropped to (size, size)
# and stored together with its magnitude reconstruction, the ISMRMRD header
# and the reconstruction norm.
dest_path = Path(dest_data_folder)
for fname in sorted(files):
    print(fname)
    # Context managers guarantee both HDF5 files are closed, also on error;
    # the source file is only read, so it is opened read-only.
    with h5py.File(fname, 'r') as orig, h5py.File(dest_path / fname.name, 'a') as dest:
        volume_kspace = orig['kspace'][()]
        total_slices = volume_kspace.shape[0]

        # Window around the central slice, clamped to valid indices so a
        # request larger than the volume cannot produce negative indices
        # (which would silently wrap to the end) or run past the last slice.
        center = total_slices // 2
        half = num_middle_slices // 2
        first = max(center - half, 0)
        stop = min(center + half + 1, total_slices)

        kspace_list = []
        reconstruction_list = []
        for i in range(first, stop):
            # Convert from numpy array to pytorch tensor
            slice_kspace = to_tensor(volume_kspace[i])
            kspace_crop = complex_center_crop(slice_kspace, (size, size))
            # Inverse Fourier Transform gives the complex image; keep its magnitude
            reconstruction = fastmri.complex_abs(fastmri.ifft2c(kspace_crop))
            kspace_list.append(tensor_to_complex_np(kspace_crop))
            reconstruction_list.append(reconstruction)

        stacked_kspace = np.stack(kspace_list)
        stacked_reconstruction_esc = np.stack(reconstruction_list)
        dest['kspace'] = stacked_kspace
        dest['reconstruction_esc'] = stacked_reconstruction_esc
        dest['ismrmrd_header'] = orig['ismrmrd_header'][()]
        dest.attrs['norm'] = np.linalg.norm(stacked_reconstruction_esc)
def __call__(self, kspace, mask, target, attrs, fname, slice_num):
    """Convert one k-space slice into a normalized zero-filled image/target pair.

    Args:
        kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2)
            for multi-coil data or (rows, cols, 2) for single coil data.
        mask (numpy.array): Mask from the test dataset.
        target (numpy.array): Target image.
        attrs (dict): Acquisition related information stored in the HDF5
            object.
        fname (str): File name.
        slice_num (int): Serial number of the slice.

    Returns:
        (tuple): tuple containing:
            image (torch.Tensor): Zero-filled input image.
            target (torch.Tensor): Target image converted to a torch Tensor.
            mean (float): Mean value used for normalization.
            std (float): Standard deviation value used for normalization.
            fname (str): File name.
            slice_num (int): Serial number of the slice.
    """
    kspace = transforms.to_tensor(kspace)

    # Undersample only when a mask function is configured; with use_seed the
    # mask is derived deterministically from the file name.
    if not self.mask_func:
        masked_kspace = kspace
    else:
        seed = tuple(map(ord, fname)) if self.use_seed else None
        masked_kspace, mask = transforms.apply_mask(kspace, self.mask_func, seed)

    zero_filled = fastmri.ifft2c(masked_kspace)

    if target is None:
        crop_size = (attrs["recon_size"][0], attrs["recon_size"][1])
    else:
        crop_size = (target.shape[-2], target.shape[-1])

    # FLAIR 203: image smaller than the requested width -> square crop of that size.
    side = zero_filled.shape[-2]
    if side < crop_size[1]:
        crop_size = (side, side)

    image = fastmri.complex_abs(transforms.complex_center_crop(zero_filled, crop_size))
    if self.which_challenge == "multicoil":
        image = fastmri.rss(image)

    image, mean, std = transforms.normalize_instance(image, eps=1e-11)
    image = image.clamp(-6, 6)

    if target is None:
        target = torch.Tensor([0])
    else:
        target = transforms.to_tensor(target)
        target = transforms.center_crop(target, crop_size)
        target = transforms.normalize(target, mean, std, eps=1e-11).clamp(-6, 6)

    return image, target, mean, std, fname, slice_num
def __call__(
    self,
    kspace: np.ndarray,
    mask: np.ndarray,
    target: np.ndarray,
    attrs: Dict,
    fname: str,
    slice_num: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, str, int, float]:
    """
    Args:
        kspace: Input k-space of shape (num_coils, rows, cols) for multi-coil
            data or (rows, cols) for single coil data.
        mask: Mask from the test dataset.
        target: Target image (ignored in test mode).
        attrs: Acquisition related information stored in the HDF5 object.
        fname: File name.
        slice_num: Serial number of the slice.

    Returns:
        tuple containing:
            image: Zero-filled input image, instance-normalized and clamped
                to [-6, 6].
            target: Target image converted to a torch.Tensor, or
                ``torch.Tensor([0])`` in test mode / without a target.
            mean: Mean value used for normalization.
            std: Standard deviation value used for normalization.
            fname: File name.
            slice_num: Serial number of the slice.
            max_value: ``attrs["max"]`` if present, else 0.0.
    """
    kspace = T.to_tensor(kspace)
    max_value = attrs.get("max", 0.0)

    if self.mask_func:
        seed = tuple(map(ord, fname)) if self.use_seed else None
        masked_kspace, mask = T.apply_mask(kspace, self.mask_func, seed)
    else:
        masked_kspace = kspace

    # inverse Fourier transform to get zero filled solution
    image = fastmri.ifft2c(masked_kspace)
    side = image.shape[-2]

    if self.test_mode:
        # Test mode always uses a square crop of the image's own size.
        crop_size = (side, side)
    else:
        if target is None:
            crop_size = (attrs["recon_size"][0], attrs["recon_size"][1])
        else:
            crop_size = (target.shape[-2], target.shape[-1])
        # check for FLAIR 203
        if side < crop_size[1]:
            crop_size = (side, side)

    image = fastmri.complex_abs(T.complex_center_crop(image, crop_size))
    if self.which_challenge == "multicoil":
        image = fastmri.rss(image)

    image, mean, std = T.normalize_instance(image, eps=1e-11)
    image = image.clamp(-6, 6)

    if self.test_mode or target is None:
        target = torch.Tensor([0])
    else:
        target = T.to_tensor(target)
        target = T.center_crop(target, crop_size)
        target = T.normalize(target, mean, std, eps=1e-11)
        target = target.clamp(-6, 6)

    return image, target, mean, std, fname, slice_num, max_value