def save_zero_filled(data_dir, out_dir, which_challenge):
    """Write zero-filled reconstructions for every ``*.h5`` file in a folder.

    For each file the ``kspace`` dataset is inverse-Fourier-transformed,
    center-cropped to the matrix size read from the ISMRMRD header
    (``encoding/encodedSpace/matrixSize``), converted to magnitude and, for
    multicoil data, combined with Root-Sum-of-Squares along ``dim=1``.

    Args:
        data_dir: Directory object providing ``glob`` (e.g. ``pathlib.Path``)
            that contains the input ``*.h5`` files.
        out_dir: Destination handed to ``fastmri.save_reconstructions``.
        which_challenge: ``"multicoil"`` enables the RSS coil combination;
            any other value skips it.
    """
    matrix_size_path = ["encoding", "encodedSpace", "matrixSize"]
    results = {}

    for h5_path in tqdm(list(data_dir.glob("*.h5"))):
        # Both datasets are read fully into memory, so the file can be
        # closed before any processing happens.
        with h5py.File(h5_path, "r") as hf:
            header_root = etree.fromstring(hf["ismrmrd_header"][()])
            masked_kspace = transforms.to_tensor(hf["kspace"][()])

        # target image width, height from the ismrmrd header
        target_x = int(et_query(header_root, matrix_size_path + ["x"]))
        target_y = int(et_query(header_root, matrix_size_path + ["y"]))
        crop_size = (target_x, target_y)

        # inverse Fourier transform gives the zero-filled solution
        zero_filled = fastmri.ifft2c(masked_kspace)

        # "FLAIR 203" check: when the image's second-to-last dimension is
        # smaller than the requested crop, fall back to a square crop of
        # that size.
        available = zero_filled.shape[-2]
        if available < crop_size[1]:
            crop_size = (available, available)

        cropped = transforms.complex_center_crop(zero_filled, crop_size)
        magnitude = fastmri.complex_abs(cropped)

        # Root-Sum-of-Squares across coils for multicoil data
        if which_challenge == "multicoil":
            magnitude = fastmri.rss(magnitude, dim=1)

        results[h5_path.name] = magnitude

    fastmri.save_reconstructions(results, out_dir)
def save_zero_filled(data_dir, out_dir, which_challenge):
    """Write zero-filled reconstructions for every entry of ``data_dir``.

    Each entry returned by ``data_dir.iterdir()`` is opened as an HDF5 file.
    Its ``kspace`` dataset is inverse-Fourier-transformed, center-cropped to
    the ``reconSpace`` matrix size of the first encoding in the ISMRMRD
    header, converted to magnitude and, for multicoil data, combined with
    Root-Sum-of-Squares along ``dim=1``.

    Args:
        data_dir: Directory object providing ``iterdir`` (e.g.
            ``pathlib.Path``). Every entry is opened; there is no filtering
            by extension.
        out_dir: Destination handed to ``fastmri.save_reconstructions``.
        which_challenge: ``"multicoil"`` enables the RSS coil combination;
            any other value skips it.
    """
    results = {}

    for entry in data_dir.iterdir():
        # Both datasets are read fully into memory, so the file can be
        # closed before any processing happens.
        with h5py.File(entry, "r") as hf:
            header = ismrmrd.xsd.CreateFromDocument(hf["ismrmrd_header"][()])
            encoding = header.encoding[0]
            masked_kspace = transforms.to_tensor(hf["kspace"][()])

        # target image width, height from the ismrmrd header
        recon_matrix = encoding.reconSpace.matrixSize
        crop_size = (recon_matrix.x, recon_matrix.y)

        # inverse Fourier transform gives the zero-filled solution
        zero_filled = fastmri.ifft2c(masked_kspace)

        # "FLAIR 203" check: when the image's second-to-last dimension is
        # smaller than the requested crop, fall back to a square crop of
        # that size.
        available = zero_filled.shape[-2]
        if available < crop_size[1]:
            crop_size = (available, available)

        cropped = transforms.complex_center_crop(zero_filled, crop_size)
        magnitude = fastmri.complex_abs(cropped)

        # Root-Sum-of-Squares across coils for multicoil data
        if which_challenge == "multicoil":
            magnitude = fastmri.rss(magnitude, dim=1)

        results[entry.name] = magnitude

    fastmri.save_reconstructions(results, out_dir)
def forward(self, masked_kspace, mask):
    """Run the cascades on masked k-space and return the RSS magnitude image.

    Sensitivity maps are estimated once with ``self.sens_net``; every module
    in ``self.cascades`` then refines the running k-space estimate, always
    receiving the original ``masked_kspace`` as reference.

    Args:
        masked_kspace: Under-sampled k-space tensor.
        mask: Sampling mask passed to the sensitivity network and cascades.

    Returns:
        Root-Sum-of-Squares (over ``dim=1``) of the magnitude of the inverse
        Fourier transform of the final k-space estimate.
    """
    sens_maps = self.sens_net(masked_kspace, mask)

    # Start from a copy so the caller's tensor is never the running estimate.
    estimate = masked_kspace.clone()
    for cascade in self.cascades:
        estimate = cascade(estimate, masked_kspace, mask, sens_maps)

    coil_images = fastmri.ifft2c(estimate)
    coil_magnitudes = fastmri.complex_abs(coil_images)
    return fastmri.rss(coil_magnitudes, dim=1)
def forward(
    self,
    masked_kspace: torch.Tensor,
    mask: torch.Tensor,
    num_low_frequencies: Optional[int] = None,
) -> torch.Tensor:
    """Run the cascades on masked k-space and return the RSS magnitude image.

    Sensitivity maps are estimated once with ``self.sens_net``; every module
    in ``self.cascades`` then refines the running k-space estimate, always
    receiving the original ``masked_kspace`` as reference.

    Args:
        masked_kspace: Under-sampled k-space tensor.
        mask: Sampling mask passed to the sensitivity network and cascades.
        num_low_frequencies: Forwarded unchanged to ``self.sens_net``;
            ``None`` by default.

    Returns:
        Root-Sum-of-Squares (over ``dim=1``) of the magnitude of the inverse
        Fourier transform of the final k-space estimate.
    """
    sens_maps = self.sens_net(masked_kspace, mask, num_low_frequencies)

    # Start from a copy so the caller's tensor is never the running estimate.
    estimate = masked_kspace.clone()
    for cascade in self.cascades:
        estimate = cascade(estimate, masked_kspace, mask, sens_maps)

    coil_images = fastmri.ifft2c(estimate)
    coil_magnitudes = fastmri.complex_abs(coil_images)
    return fastmri.rss(coil_magnitudes, dim=1)
def visualize_reconstruction(filepath, output_filepath=None):
    """Display or save the RSS-combined image stored in a reconstruction file.

    The ``reconstruction`` dataset is read from the HDF5 file, squeezed,
    combined with Root-Sum-of-Squares along ``dim=0`` and rendered in
    grayscale with matplotlib.

    Args:
        filepath: Path of the HDF5 file holding a ``reconstruction`` dataset.
        output_filepath: Optional destination for the rendered figure. It
            must expose ``.parent`` (e.g. a ``pathlib.Path``); missing parent
            directories are created. When ``None`` the figure is shown
            interactively instead of being saved.
    """
    # Open read-only and close the handle as soon as the data is in memory;
    # the dataset is fully materialised by ``[()]``.
    with h5py.File(filepath, "r") as hf:
        recons = hf['reconstruction'][()].squeeze()

    recons_rss = fastmri.rss(T.to_tensor(recons), dim=0)
    img = np.abs(recons_rss.numpy())

    if output_filepath is None:
        plt.imshow(img, cmap='gray')
        plt.show()
        return

    # exist_ok avoids the race between checking for and creating the folder.
    output_filepath.parent.mkdir(parents=True, exist_ok=True)
    plt.imshow(img, cmap='gray')
    plt.axis("off")
    plt.savefig(output_filepath, bbox_inches="tight", pad_inches=0)
def _base_fastmri_unet_transform(
    kspace,
    mask,
    ground_truth,
    attrs,
    which_challenge="singlecoil",
):
    """Turn masked k-space into a normalised zero-filled U-Net input.

    Args:
        kspace: k-space array accepted by ``fastmri_transforms.to_tensor``.
        mask: Sampling mask; its last dimension is trimmed to
            ``kspace.shape[-2]`` of the converted tensor before use.
        ground_truth: Optional target image. When given, its last two
            dimensions define the crop size.
        attrs: Mapping providing ``"recon_size"``; consulted only when
            ``ground_truth`` is ``None``.
        which_challenge: ``"multicoil"`` applies Root-Sum-of-Squares to the
            magnitude image; any other value skips it.

    Returns:
        Tuple ``(image, mean, std)`` where ``image`` is the normalised image
        clamped to ``[-6, 6]`` with a new leading dimension, and ``mean`` /
        ``std`` are the statistics returned by ``normalize_instance``.
    """
    kspace_t = fastmri_transforms.to_tensor(kspace)

    # accounting for variable size masks
    trimmed_mask = mask[..., :kspace_t.shape[-2]]
    # NOTE(review): the trailing "+ 0.0" presumably turns -0.0 into +0.0 -
    # confirm before removing.
    masked_kspace = kspace_t * trimmed_mask.unsqueeze(-1) + 0.0

    # inverse Fourier transform gives the zero-filled solution
    image = fastmri.ifft2c(masked_kspace)

    # crop size comes from the ground truth when available, else from attrs
    if ground_truth is None:
        crop_size = (attrs["recon_size"][0], attrs["recon_size"][1])
    else:
        crop_size = (ground_truth.shape[-2], ground_truth.shape[-1])

    # "FLAIR 203" check: fall back to a square crop when the image's
    # second-to-last dimension is smaller than the requested crop.
    available = image.shape[-2]
    if available < crop_size[1]:
        crop_size = (available, available)

    # noinspection PyTypeChecker
    image = fastmri_transforms.complex_center_crop(image, crop_size)
    image = fastmri.complex_abs(image)

    # Root-Sum-of-Squares across coils for multicoil data
    if which_challenge == "multicoil":
        image = fastmri.rss(image)

    # normalise, then clip outliers
    image, mean, std = fastmri_transforms.normalize_instance(image, eps=1e-11)
    clipped = image.clamp(-6, 6)

    return clipped.unsqueeze(0), mean, std
def visualize_kspace(kspace, dim=None, crop=False, output_filepath=None):
    """Display or save the magnitude image obtained from a k-space tensor.

    Despite the name, the input is first passed through ``fastmri.ifft2c``,
    so what is rendered is the magnitude of the inverse-transformed data.

    Args:
        kspace: k-space tensor accepted by ``fastmri.ifft2c``.
        dim: When not ``None``, Root-Sum-of-Squares is applied along this
            dimension after the magnitude is taken.
        crop: When truthy, the transformed data is center-cropped to a
            square whose side is its ``shape[-2]``, then the magnitude is
            instance-normalised and clamped to ``[-6, 6]``.
        output_filepath: Optional destination exposing ``.parent`` (e.g. a
            ``pathlib.Path``). When ``None`` the figure is shown
            interactively instead of being saved.
    """
    image = fastmri.ifft2c(kspace)

    if crop:
        side = image.shape[-2]
        image = T.complex_center_crop(image, (side, side))

    # Compute absolute value to get a real image
    image = fastmri.complex_abs(image)

    if crop:
        image, _, _ = T.normalize_instance(image, eps=1e-11)
        image = image.clamp(-6, 6)

    if dim is not None:
        image = fastmri.rss(image, dim=dim)

    img = np.abs(image.numpy())

    if output_filepath is None:
        plt.imshow(img, cmap='gray')
        plt.show()
        return

    target_dir = output_filepath.parent
    if not target_dir.exists():
        target_dir.mkdir(parents=True)
    plt.imshow(img, cmap='gray')
    plt.axis("off")
    plt.savefig(output_filepath, bbox_inches="tight", pad_inches=0)
def __call__(
    self,
    kspace: np.ndarray,
    mask: np.ndarray,
    target: np.ndarray,
    attrs: Dict,
    fname: str,
    slice_num: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, str, int, float]:
    """
    Build a normalised zero-filled input image (and target) for one slice.

    Args:
        kspace: Input k-space of shape (num_coils, rows, cols) for
            multi-coil data or (rows, cols) for single coil data.
        mask: Mask from the test dataset. It is not applied here; it is
            only overwritten by the mask produced from ``self.mask_func``.
        target: Target image, or ``None`` when no target is available.
        attrs: Acquisition related information stored in the HDF5 object.
            ``"max"`` is read when present and ``"recon_size"`` is read
            when ``target`` is ``None``.
        fname: File name; its characters seed the mask when
            ``self.use_seed`` is set.
        slice_num: Serial number of the slice.

    Returns:
        tuple containing:
            image: Zero-filled input image, normalised and clamped to
                [-6, 6].
            target: Target image converted to a torch.Tensor, cropped,
                normalised with the image statistics and clamped, or
                ``torch.Tensor([0])`` when ``target`` is ``None``.
            mean: Mean value used for normalization.
            std: Standard deviation value used for normalization.
            fname: File name.
            slice_num: Serial number of the slice.
            max_value: ``attrs["max"]`` when present, else 0.0.
    """
    kspace_t = to_tensor(kspace)

    # check for max value
    if "max" in attrs.keys():
        max_value = attrs["max"]
    else:
        max_value = 0.0

    # apply mask
    if self.mask_func:
        seed = tuple(map(ord, fname)) if self.use_seed else None
        masked_kspace, mask = apply_mask(kspace_t, self.mask_func, seed)
    else:
        masked_kspace = kspace_t

    # inverse Fourier transform gives the zero-filled solution
    image = fastmri.ifft2c(masked_kspace)

    # crop size comes from the target when available, else from attrs
    if target is None:
        crop_size = (attrs["recon_size"][0], attrs["recon_size"][1])
    else:
        crop_size = (target.shape[-2], target.shape[-1])

    # "FLAIR 203" check: fall back to a square crop when the image's
    # second-to-last dimension is smaller than the requested crop.
    available = image.shape[-2]
    if available < crop_size[1]:
        crop_size = (available, available)

    image = fastmri.complex_abs(complex_center_crop(image, crop_size))

    # Root-Sum-of-Squares across coils for multicoil data
    if self.which_challenge == "multicoil":
        image = fastmri.rss(image)

    # normalise input, then clip outliers
    image, mean, std = normalize_instance(image, eps=1e-11)
    image = image.clamp(-6, 6)

    # normalise target with the statistics of the input image
    if target is None:
        target_t = torch.Tensor([0])
    else:
        target_t = center_crop(to_tensor(target), crop_size)
        target_t = normalize(target_t, mean, std, eps=1e-11)
        target_t = target_t.clamp(-6, 6)

    return image, target_t, mean, std, fname, slice_num, max_value
# SSIM loss loss = fastmri.SSIMLoss() print(loss(slice_image_abs.unsqueeze(1), slice_image_abs.unsqueeze(1), data_range=slice_image_abs.max().reshape(-1))) # In[15]: show_coils(slice_image_abs, [0], cmap='gray') # As we can see, each coil in a multi-coil MRI scan focusses on a different region of the image. These coils can be combined into the full image using the Root-Sum-of-Squares (RSS) transform. # In[16]: slice_image_rss = fastmri.rss(slice_image_abs, dim=0) # In[17]: plt.imshow(np.abs(slice_image_rss.numpy()), cmap='gray') # So far, we have been looking at fully-sampled data. We can simulate under-sampled data by creating a mask and applying it to k-space. # In[18]: from fastmri.data.subsample import RandomMaskFunc mask_func = RandomMaskFunc(center_fractions=[0.08], accelerations=[4]) # Create the mask function object
def test_root_sum_of_squares(shape, dim):
    """Check ``fastmri.rss`` against a plain NumPy root-sum-of-squares.

    Args:
        shape: Shape forwarded to ``create_input`` to build the test tensor.
        dim: Dimension along which the reduction is performed.
    """
    tensor = create_input(shape)

    actual = fastmri.rss(tensor, dim).numpy()

    squared = tensor.numpy() ** 2
    expected = np.sqrt(squared.sum(dim))

    assert np.allclose(actual, expected)
def __call__(self, kspace, mask, target, attrs, fname, slice_num):
    """
    Build a normalised zero-filled input image (and target) for one slice.

    Args:
        kspace (numpy.array): Input k-space of shape (num_coils, rows,
            cols, 2) for multi-coil data or (rows, cols, 2) for single
            coil data.
        mask (numpy.array): Mask from the test dataset. It is not applied
            here; it is only overwritten by the mask produced from
            ``self.mask_func``.
        target (numpy.array): Target image, or ``None`` when no target is
            available.
        attrs (dict): Acquisition related information stored in the HDF5
            object. ``"recon_size"`` is read when ``target`` is ``None``.
        fname (str): File name; its characters seed the mask when
            ``self.use_seed`` is set.
        slice_num (int): Serial number of the slice.

    Returns:
        (tuple): tuple containing:
            image (torch.Tensor): Zero-filled input image, normalised and
                clamped to [-6, 6].
            target (torch.Tensor): Target image converted to a torch
                Tensor, cropped, normalised with the image statistics and
                clamped, or ``torch.Tensor([0])`` when ``target`` is
                ``None``.
            mean (float): Mean value used for normalization.
            std (float): Standard deviation value used for normalization.
            fname (str): File name.
            slice_num (int): Serial number of the slice.
    """
    kspace_t = transforms.to_tensor(kspace)

    # apply mask
    if self.mask_func:
        seed = tuple(map(ord, fname)) if self.use_seed else None
        masked_kspace, mask = transforms.apply_mask(
            kspace_t, self.mask_func, seed)
    else:
        masked_kspace = kspace_t

    # inverse Fourier transform gives the zero-filled solution
    image = fastmri.ifft2c(masked_kspace)

    # crop size comes from the target when available, else from attrs
    if target is None:
        crop_size = (attrs["recon_size"][0], attrs["recon_size"][1])
    else:
        crop_size = (target.shape[-2], target.shape[-1])

    # "FLAIR 203" check: fall back to a square crop when the image's
    # second-to-last dimension is smaller than the requested crop.
    available = image.shape[-2]
    if available < crop_size[1]:
        crop_size = (available, available)

    image = transforms.complex_center_crop(image, crop_size)
    image = fastmri.complex_abs(image)

    # Root-Sum-of-Squares across coils for multicoil data
    if self.which_challenge == "multicoil":
        image = fastmri.rss(image)

    # normalise input, then clip outliers
    image, mean, std = transforms.normalize_instance(image, eps=1e-11)
    image = image.clamp(-6, 6)

    # normalise target with the statistics of the input image
    if target is None:
        target_t = torch.Tensor([0])
    else:
        target_t = transforms.center_crop(transforms.to_tensor(target), crop_size)
        target_t = transforms.normalize(target_t, mean, std, eps=1e-11)
        target_t = target_t.clamp(-6, 6)

    return image, target_t, mean, std, fname, slice_num
def __call__(
    self,
    kspace: np.ndarray,
    mask: np.ndarray,
    target: np.ndarray,
    attrs: Dict,
    fname: str,
    slice_num: int,
) -> Tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    str,
    int,
    float,
]:
    """
    Build a weak/strong pair of zero-filled input images for one slice.

    Labelled samples (``attrs["is_label"]`` truthy) are masked with
    ``self.strong_mask_func`` only, and the resulting image is returned in
    both image slots. Unlabelled samples are masked once with
    ``self.weak_mask_func`` and once with ``self.strong_mask_func``; the two
    images are stacked and normalised jointly so they share one mean/std.
    When ``self.use_seed`` is set, both masks are seeded from ``fname``.

    Args:
        kspace: Input k-space of shape (num_coils, rows, cols) for
            multi-coil data or (rows, cols) for single coil data.
        mask: Mask from the test dataset. It is not used by this method.
        target: Target image, or ``None`` when no target is available.
        attrs: Acquisition related information stored in the HDF5 object.
            ``"is_label"`` is required, ``"max"`` is read when present and
            ``"recon_size"`` is read when ``target`` is ``None``.
        fname: File name.
        slice_num: Serial number of the slice.

    Returns:
        tuple containing:
            weak_image: Weakly masked zero-filled image (for labelled
                samples, the same tensor as ``strong_image``).
            strong_image: Strongly masked zero-filled image.
            target: Target converted to a torch.Tensor, cropped, normalised
                with the returned statistics and clamped to [-6, 6], or
                ``torch.Tensor([0])`` when ``target`` is ``None``.
            mean: Mean value used for normalization.
            std: Standard deviation value used for normalization.
            fname: File name.
            slice_num: Serial number of the slice.
            max_value: ``attrs["max"]`` when present, else 0.0.

    Raises:
        KeyError: If ``attrs`` has no ``"is_label"`` entry.
    """
    kspace = to_tensor(kspace)

    # check for max value
    max_value = attrs["max"] if "max" in attrs.keys() else 0.0

    # crop input to correct size
    if target is not None:
        crop_size = (target.shape[-2], target.shape[-1])
    else:
        crop_size = (attrs["recon_size"][0], attrs["recon_size"][1])

    if attrs["is_label"]:
        # Labelled sample: a single strongly masked image fills both slots.
        image, crop_size = self._zero_filled_image(
            kspace, self.strong_mask_func, fname, crop_size
        )
        image, label_mean, label_std = normalize_instance(image, eps=1e-11)
        labeled_image = image.clamp(-6, 6)
        labeled_target = self._normalized_target(
            target, crop_size, label_mean, label_std
        )
        return (
            labeled_image,
            labeled_image,
            labeled_target,
            label_mean,
            label_std,
            fname,
            slice_num,
            max_value,
        )

    # Unlabelled sample: weak and strong views of the same k-space. The crop
    # size possibly adjusted by the first view is carried into the second.
    weak_image, crop_size = self._zero_filled_image(
        kspace, self.weak_mask_func, fname, crop_size
    )
    strong_image, crop_size = self._zero_filled_image(
        kspace, self.strong_mask_func, fname, crop_size
    )

    # Normalise both views jointly so they share the same statistics.
    image_cat = torch.stack([weak_image, strong_image], dim=0)
    image_cat, unlabel_mean, unlabel_std = normalize_instance(image_cat, eps=1e-11)
    image_cat = image_cat.clamp(-6, 6)
    weak_image, strong_image = image_cat[0], image_cat[1]

    unlabelled_target = self._normalized_target(
        target, crop_size, unlabel_mean, unlabel_std
    )

    return (
        weak_image,
        strong_image,
        unlabelled_target,
        unlabel_mean,
        unlabel_std,
        fname,
        slice_num,
        max_value,
    )

def _zero_filled_image(self, kspace, mask_func, fname, crop_size):
    """Mask k-space, inverse-transform, crop and take the magnitude image.

    Returns the un-normalised image together with the crop size actually
    used, which differs from the input when the FLAIR 203 check triggers.
    """
    if mask_func:
        seed = None if not self.use_seed else tuple(map(ord, fname))
        masked_kspace, _ = apply_mask(kspace, mask_func, seed)
    else:
        masked_kspace = kspace

    # inverse Fourier transform to get zero filled solution
    image = fastmri.ifft2c(masked_kspace)

    # check for FLAIR 203: fall back to a square crop when the image's
    # second-to-last dimension is smaller than the requested crop.
    if image.shape[-2] < crop_size[1]:
        crop_size = (image.shape[-2], image.shape[-2])

    image = complex_center_crop(image, crop_size)

    # absolute value
    image = fastmri.complex_abs(image)

    # apply Root-Sum-of-Squares if multicoil data
    if self.which_challenge == "multicoil":
        image = fastmri.rss(image)

    return image, crop_size

def _normalized_target(self, target, crop_size, mean, std):
    """Crop, normalise and clamp the target, or return a zero placeholder."""
    if target is None:
        return torch.Tensor([0])
    target = to_tensor(target)
    target = center_crop(target, crop_size)
    target = normalize(target, mean, std, eps=1e-11)
    return target.clamp(-6, 6)