def training_step(self, batch, batch_idx):
    """Run one Noise2Noise-style self-supervised training step on k-space.

    The acquired k-space is split into two parts with ``mask1`` (the last
    element of ``batch``) and its complement ``mask2``. Each part is turned
    into a zero-filled magnitude image, both images go through the network
    together, and each prediction is supervised on the k-space samples of the
    *other* part.

    Args:
        batch: tuple of seven items. The first is the sub-sampled k-space (real/imag
            in the last dimension, as consumed by ``fastmri.ifft2c``). The last is the
            partition mask. The remaining five are unused here.
        batch_idx: index of the batch (unused).

    Returns:
        Scalar training loss, also logged under ``"loss"``.
    """
    subsampled_kspace, _, _, _, _, _, mask1 = batch
    # Complement of mask1 (1 - mask1); assumes a binary 0/1 mask -- TODO confirm.
    mask2 = -(mask1 - 1)
    # "+ 0.0" presumably normalises -0.0 to 0.0 after masking (fastMRI idiom).
    subsampled_kspace1 = subsampled_kspace * mask1 + 0.0
    subsampled_kspace2 = subsampled_kspace * mask2 + 0.0

    image1 = fastmri.complex_abs(fastmri.ifft2c(subsampled_kspace1))
    image2 = fastmri.complex_abs(fastmri.ifft2c(subsampled_kspace2))

    # NOTE(review): indexing [0] / [1] below assumes a batch size of 1.
    output_image = self(torch.vstack((image1, image2)))

    def _masked_kspace(image, mask):
        # Go back to k-space with the same centred FFT convention that
        # produced the inputs (fastmri.ifft2c). A plain torch.fft.fft2 gives
        # an uncentred, differently-scaled spectrum that does not line up
        # with the centred masks or with the reference k-space.
        complex_image = torch.stack((image, torch.zeros_like(image)), dim=-1)
        return fastmri.fft2c(complex_image) * mask + 0.0

    output_kspace1 = _masked_kspace(output_image[0, :, :], mask2)
    output_kspace2 = _masked_kspace(output_image[1, :, :], mask1)

    loss = l1_l2_loss(output_kspace1, subsampled_kspace2) + l1_l2_loss(
        output_kspace2, subsampled_kspace1
    )
    self.log("loss", loss.detach())
    return loss
def forward(self, kspace_pred: torch.Tensor, ref_kspace: torch.Tensor):
    """
    Penalise differences in image-gradient spread between prediction and reference.

    Both k-space tensors are converted to magnitude images with
    ``fastmri.ifft2c`` followed by ``fastmri.complex_abs``. For horizontal and
    vertical neighbour differences the term
    ``std(pred_diff) - 1.25 * std(ref_diff)`` is formed. The loss is
    ``self.tv_weight`` times the sum of the absolute values of the two terms.
    No k-space data-consistency term is computed.

    Inputs:
    - kspace_pred: predicted k-space, documented as shape (N, H, W, 2).
    - ref_kspace : input masked k-space in the same layout.

    Returns:
    - loss: scalar tensor.
    """
    pred_img = fastmri.complex_abs(fastmri.ifft2c(kspace_pred))
    ref_img = fastmri.complex_abs(fastmri.ifft2c(ref_kspace))

    def spread_gap(pred_diff, ref_diff):
        # Standard deviation of the predicted differences minus 1.25x that of the reference.
        pred_std = torch.sqrt(torch.var(pred_diff.view(-1)))
        ref_std = torch.sqrt(torch.var(ref_diff.view(-1)))
        return pred_std - 1.25 * ref_std

    horizontal_gap = spread_gap(
        pred_img[:, :, :-1] - pred_img[:, :, 1:],
        ref_img[:, :, :-1] - ref_img[:, :, 1:],
    )
    vertical_gap = spread_gap(
        pred_img[:, :-1, :] - pred_img[:, 1:, :],
        ref_img[:, :-1, :] - ref_img[:, 1:, :],
    )
    return self.tv_weight * (torch.abs(horizontal_gap) + torch.abs(vertical_gap))
def _hist_loss_range(values: torch.Tensor):
    """Return ``(min, max)`` of *values* as Python floats (no autograd link)."""
    return values.min().item(), values.max().item()


def _hist_loss_distance(pred, ref, bins, pred_range, ref_range):
    """Return the L2 norm of the histogram difference of *pred* and *ref*, divided by ``len(pred)``."""
    hist_pred = differentiable_histogram(pred, bins=bins, min=pred_range[0], max=pred_range[1])
    hist_ref = differentiable_histogram(ref, bins=bins, min=ref_range[0], max=ref_range[1])
    return torch.norm((hist_pred - hist_ref) / len(pred))


def hist_loss(current_kspace: torch.Tensor, masked_kspace: torch.Tensor, bins: int = 5):
    """
    Histogram-matching loss between the images of two k-space tensors.

    Both inputs are converted to magnitude images (``fastmri.ifft2c`` +
    ``fastmri.complex_abs``). Three histogram distances are summed:

    - horizontal neighbour differences, each histogram over its own min/max;
    - vertical neighbour differences, each histogram over its own min/max;
    - raw intensities, both histograms over the reference image's min/max.

    Each distance is ``||hist_pred - hist_ref|| / number_of_values``.

    NOTE(review): for the difference terms, the predicted and reference
    histograms use different bin edges -- confirm this is intended.

    Inputs:
    - current_kspace: predicted k-space (real/imag in the last dimension).
    - masked_kspace: reference (input masked) k-space, same layout.
    - bins: number of histogram bins passed to ``differentiable_histogram``.

    Returns:
    - loss: scalar tensor.
    """
    output = fastmri.complex_abs(fastmri.ifft2c(current_kspace))
    gt = fastmri.complex_abs(fastmri.ifft2c(masked_kspace))

    hdiff_pred = (output[:, :, :-1] - output[:, :, 1:]).view(-1)
    hdiff_gt = (gt[:, :, :-1] - gt[:, :, 1:]).view(-1)
    hdiff_hist_loss = _hist_loss_distance(
        hdiff_pred, hdiff_gt, bins, _hist_loss_range(hdiff_pred), _hist_loss_range(hdiff_gt)
    )

    vdiff_pred = (output[:, :-1, :] - output[:, 1:, :]).view(-1)
    vdiff_gt = (gt[:, :-1, :] - gt[:, 1:, :]).view(-1)
    vdiff_hist_loss = _hist_loss_distance(
        vdiff_pred, vdiff_gt, bins, _hist_loss_range(vdiff_pred), _hist_loss_range(vdiff_gt)
    )

    output = output.view(-1)
    gt = gt.view(-1)
    # Intensity histograms share the reference range so their bins line up.
    gt_range = _hist_loss_range(gt)
    gt_hist_loss = _hist_loss_distance(output, gt, bins, gt_range, gt_range)

    return hdiff_hist_loss + vdiff_hist_loss + gt_hist_loss
def load_data(file_dir_path):
    """Load every file under *file_dir_path* as fully-sampled and under-sampled magnitude images.

    For each file returned by ``get_files``:
    - each slice's k-space is converted to a magnitude image
      (``fastmri.ifft2c`` + ``fastmri.complex_abs``);
    - a randomly under-sampled copy is made with a ``RandomMaskFunc``
      (8% centre fraction, 4x acceleration).

    Args:
        file_dir_path: directory passed to ``get_files``.

    Returns:
        ``(total_image_tensor, total_sampled_image_tensor)``: per-file slice
        stacks, stacked along a new leading dimension. ``torch.stack``
        requires every file to have the same slice count and image size.
        Both shapes are printed before returning.
    """
    file_path = get_files(file_dir_path)
    # Create the mask function object once; it draws a new random mask per call.
    mask_func = RandomMaskFunc(center_fractions=[0.08], accelerations=[4])

    total_image_list = []
    total_sampled_image_list = []
    for h5_path in file_path:
        # Each iteration loads its own file.
        total_kspace, slices_num = load_dataset(h5_path)
        image_list = []
        sampled_image_list = []
        for i in range(slices_num):
            slice_kspace_tensor = T.to_tensor(total_kspace[i])  # convert numpy to tensor
            # Fully-sampled magnitude image.
            image_list.append(fastmri.complex_abs(fastmri.ifft2c(slice_kspace_tensor)))
            # Randomly under-sampled (zero-filled) magnitude image.
            masked_kspace, _ = T.apply_mask(slice_kspace_tensor, mask_func)
            sampled_image_list.append(fastmri.complex_abs(fastmri.ifft2c(masked_kspace)))
        total_image_list.append(torch.stack(image_list, dim=0))
        total_sampled_image_list.append(torch.stack(sampled_image_list, dim=0))

    total_image_tensor = torch.stack(total_image_list, dim=0)
    total_sampled_image_tensor = torch.stack(total_sampled_image_list, dim=0)
    print(total_image_tensor.shape)
    print(total_sampled_image_tensor.shape)
    return total_image_tensor, total_sampled_image_tensor
def to_cropped_image(masked_kspace, target, attrs):
    """Return a normalised, centre-cropped zero-filled image and its matching target.

    The crop size comes from *target*'s last two dims when a target is given,
    otherwise from ``attrs["recon_size"]``. It shrinks to a square when the
    image's second-to-last dimension is smaller (the "FLAIR 203" case). The
    image is instance-normalised and clamped to [-6, 6]. The target is cropped,
    normalised with the image's mean/std and clamped the same way. Without a
    target, ``torch.Tensor([0])`` is returned in its place.
    """
    # inverse Fourier transform to get zero filled solution
    zero_filled = fastmri.ifft2c(masked_kspace)

    if target is None:
        crop_size = (attrs["recon_size"][0], attrs["recon_size"][1])
    else:
        crop_size = (target.shape[-2], target.shape[-1])

    width = zero_filled.shape[-2]
    if width < crop_size[1]:  # check for FLAIR 203
        crop_size = (width, width)

    magnitude = fastmri.complex_abs(T.complex_center_crop(zero_filled, crop_size))
    image, mean, std = T.normalize_instance(magnitude, eps=1e-11)
    image = image.clamp(-6, 6)

    if target is None:
        return image, torch.Tensor([0])

    if isinstance(target, np.ndarray):
        target = T.to_tensor(target)
    cropped_target = T.center_crop(target, crop_size)
    normalized_target = T.normalize(cropped_target, mean, std, eps=1e-11)
    return image, normalized_target.clamp(-6, 6)
def save_zero_filled(data_dir, out_dir, which_challenge):
    """Write zero-filled magnitude reconstructions for every entry of *data_dir*.

    Each entry of ``data_dir.iterdir()`` is opened as HDF5 and processed:
    1. inverse-FFT the k-space with ``fastmri.ifft2c``;
    2. centre-crop to the ``reconSpace`` matrix size parsed from the ISMRMRD
       header, squared down when the image is narrower (the "FLAIR 203" case);
    3. take the magnitude;
    4. for the multicoil challenge, combine coils by root-sum-of-squares over dim 1.

    Results are keyed by file name and saved with ``fastmri.save_reconstructions``.
    """
    reconstructions = {}
    for entry in data_dir.iterdir():
        with h5py.File(entry, "r") as hf:
            encoding = ismrmrd.xsd.CreateFromDocument(hf["ismrmrd_header"][()]).encoding[0]
            kspace_tensor = transforms.to_tensor(hf["kspace"][()])

            # target image width, height from the ismrmrd header
            recon_matrix = encoding.reconSpace.matrixSize
            crop_size = (recon_matrix.x, recon_matrix.y)

            zero_filled = fastmri.ifft2c(kspace_tensor)
            width = zero_filled.shape[-2]
            if width < crop_size[1]:  # check for FLAIR 203
                crop_size = (width, width)

            magnitude = fastmri.complex_abs(
                transforms.complex_center_crop(zero_filled, crop_size)
            )
            if which_challenge == "multicoil":
                magnitude = fastmri.rss(magnitude, dim=1)

            reconstructions[entry.name] = magnitude

    fastmri.save_reconstructions(reconstructions, out_dir)
def save_zero_filled(data_dir, out_dir, which_challenge):
    """Write zero-filled magnitude reconstructions for every ``*.h5`` file in *data_dir*.

    For each file:
    1. inverse-FFT the k-space with ``fastmri.ifft2c``;
    2. centre-crop to the reconstruction matrix size (``reconSpace``) read from
       the ISMRMRD XML header, squared down when the image is narrower
       (the "FLAIR 203" case);
    3. take the magnitude;
    4. for the multicoil challenge, combine coils by root-sum-of-squares over dim 1.

    Results are keyed by file name and saved with ``fastmri.save_reconstructions``.
    """
    reconstructions = {}

    for fname in tqdm(list(data_dir.glob("*.h5"))):
        with h5py.File(fname, "r") as hf:
            et_root = etree.fromstring(hf["ismrmrd_header"][()])
            masked_kspace = transforms.to_tensor(hf["kspace"][()])

            # Extract the target image width/height from the ismrmrd header.
            # The target is the *reconstruction* matrix (reconSpace), not the
            # acquired k-space grid (encodedSpace).
            enc = ["encoding", "reconSpace", "matrixSize"]
            crop_size = (
                int(et_query(et_root, enc + ["x"])),
                int(et_query(et_root, enc + ["y"])),
            )

            # inverse Fourier Transform to get zero filled solution
            image = fastmri.ifft2c(masked_kspace)

            # check for FLAIR 203
            if image.shape[-2] < crop_size[1]:
                crop_size = (image.shape[-2], image.shape[-2])

            # crop input image
            image = transforms.complex_center_crop(image, crop_size)

            # absolute value
            image = fastmri.complex_abs(image)

            # apply Root-Sum-of-Squares if multicoil data
            if which_challenge == "multicoil":
                image = fastmri.rss(image, dim=1)

            reconstructions[fname.name] = image

    fastmri.save_reconstructions(reconstructions, out_dir)
def __call__(self, target, fname, slice_num=0, attrs=None, seed=None):
    """Build a (low-frequency input, target) training pair from a real image array.

    Steps:
    1. Unless the last axis of *target* already has size 2, append an all-zero
       imaginary channel to treat the image as complex.
    2. Transform to k-space with ``fastmri.fft2c``.
    3. Mask with ``self.mask_func`` (``hamming=True``).
    4. Return to a magnitude image, add a leading dim, instance-normalise and
       clamp to [-6, 6].
    5. Normalise the target with the same mean/std.

    NOTE(review): the original comment says *target* may be [H, W, 3]; with
    three channels the concatenation yields six, and the assert below fails.

    Returns:
        ``(image, target, mean, std, fname, slice_num, max_value)``, with
        ``max_value`` fixed at 0.0.
    """
    if target.shape[2] == 2:
        complex_img = target
    else:
        complex_img = np.concatenate((target, np.zeros_like(target)), axis=2)
    assert complex_img.shape[-1] == 2

    kspace = fastmri.fft2c(to_tensor(complex_img))
    center_kspace, _ = apply_mask(kspace, self.mask_func, hamming=True, seed=seed)

    low_freq = fastmri.complex_abs(fastmri.ifft2c(center_kspace)).unsqueeze(0)
    image, mean, std = normalize_instance(low_freq, eps=1e-11)
    image = image.clamp(-6, 6)

    # channels-first target, normalised with the input's statistics
    target_tensor = to_tensor(np.transpose(target, (2, 0, 1)))
    target_tensor = normalize(target_tensor, mean, std, eps=1e-11).clamp(-6, 6).squeeze(0)

    max_value = 0.0
    return image, target_tensor, mean, std, fname, slice_num, max_value
def load_data_from_pathlist(path):
    """Load the first third of the files in *path* as (target, under-sampled image) tensors.

    For each of the first ``len(path) // 3`` files:
    - each slice's k-space is cast to float, randomly under-sampled with
      ``RandomMaskFunc`` (8% centre fraction, 4x acceleration) and
      inverse-FFT'd;
    - the result is centre-cropped to 320x320 and converted to magnitude.

    Slices of all files are concatenated along dim 0. The images are then
    normalised with a single dataset-wide mean/std (``T.normalize_instance``)
    and clamped to [-6, 6]. The targets returned by ``load_dataset`` are
    normalised with the same statistics and clamped.

    Returns:
        ``(target_image_tensor, total_sampled_image_tensor)``. If fewer than
        three paths are given, no files are used and ``torch.cat`` raises.
    """
    use_num = len(path) // 3
    # Create the mask function object once; it draws a new random mask per call.
    mask_func = RandomMaskFunc(center_fractions=[0.08], accelerations=[4])

    total_target_list = []
    total_sampled_image_list = []
    for h5_num in range(use_num):
        total_kspace, slices_num, target = load_dataset(path[h5_num])
        sampled_image_list = []
        for i in range(slices_num):
            slice_kspace_tensor = T.to_tensor(total_kspace[i]).float()  # numpy -> float tensor
            masked_kspace, _ = T.apply_mask(slice_kspace_tensor, mask_func)
            # inverse FFT to the complex image, then crop to 320x320
            sampled_image = T.complex_center_crop(fastmri.ifft2c(masked_kspace), (320, 320))
            sampled_image_list.append(fastmri.complex_abs(sampled_image))
        total_sampled_image_list.append(torch.stack(sampled_image_list, dim=0))  # [slices, 320, 320]
        total_target_list.append(T.to_tensor(target))

    total_target = torch.cat(total_target_list, dim=0)
    total_sampled_image_tensor = torch.cat(total_sampled_image_list, dim=0)

    # Dataset-wide normalisation; targets reuse the images' statistics.
    total_sampled_image_tensor, mean, std = T.normalize_instance(
        total_sampled_image_tensor, eps=1e-11
    )
    total_sampled_image_tensor = total_sampled_image_tensor.clamp(-6, 6)
    target_image_tensor = T.normalize(total_target, mean, std, eps=1e-11)
    target_image_tensor = target_image_tensor.clamp(-6, 6)
    return target_image_tensor, total_sampled_image_tensor
def test_complex_abs(shape):
    """Check ``fastmri.complex_abs`` against ``np.abs`` of the equivalent complex array.

    *shape* is extended with a trailing axis of size 2 (real/imaginary parts)
    before ``create_input`` builds the tensor.
    """
    data = create_input(shape + [2])
    actual = fastmri.complex_abs(data).numpy()
    expected = np.abs(transforms.tensor_to_complex_np(data))
    assert np.allclose(actual, expected)
def forward(self, masked_kspace, mask):
    """Reconstruct a magnitude image from masked k-space.

    Steps:
    1. Estimate sensitivity maps with ``self.sens_net``.
    2. Refine a copy of the input k-space through every module in
       ``self.cascades``.
    3. Return the root-sum-of-squares over dim 1 of the magnitude images.
    """
    sensitivity = self.sens_net(masked_kspace, mask)

    estimate = masked_kspace.clone()
    for refine in self.cascades:
        estimate = refine(estimate, masked_kspace, mask, sensitivity)

    coil_images = fastmri.complex_abs(fastmri.ifft2c(estimate))
    return fastmri.rss(coil_images, dim=1)
def visualize_kspace(kspace, dim=None, crop=False, output_filepath=None):
    """Show or save the image reconstructed from *kspace*.

    Args:
        kspace: k-space tensor with real/imag in the last dimension, as used by
            ``fastmri.ifft2c``.
        dim: if given (and ``crop`` is False), combine along this dimension with
            root-sum-of-squares. Ignored when ``crop`` is True.
        crop: if True, centre-crop to a square whose side equals the
            second-to-last dimension of the complex image, then
            instance-normalise and clamp to [-6, 6].
        output_filepath: ``pathlib.Path`` to save to. Parent directories are
            created if needed. When None, the image is shown with ``plt.show()``.
    """
    image = fastmri.ifft2c(kspace)
    if crop:
        crop_size = (image.shape[-2], image.shape[-2])
        image = T.complex_center_crop(image, crop_size)
        image = fastmri.complex_abs(image)
        image, _, _ = T.normalize_instance(image, eps=1e-11)
        image = image.clamp(-6, 6)
    else:
        # Compute absolute value to get a real image
        image = fastmri.complex_abs(image)
        if dim is not None:
            image = fastmri.rss(image, dim=dim)

    # detach/cpu so tensors that require grad or live on a GPU can be plotted
    img = np.abs(image.detach().cpu().numpy())

    if output_filepath is None:
        plt.imshow(img, cmap='gray')
        plt.show()
        return

    # exist_ok avoids a race between an exists() check and mkdir()
    output_filepath.parent.mkdir(parents=True, exist_ok=True)
    plt.imshow(img, cmap='gray')
    plt.axis("off")
    plt.savefig(output_filepath, bbox_inches="tight", pad_inches=0)
    # Close the figure so repeated calls neither leak figures nor draw on top of each other.
    plt.close()
def forward(
    self,
    masked_kspace: torch.Tensor,
    mask: torch.Tensor,
    num_low_frequencies: Optional[int] = None,
) -> torch.Tensor:
    """Reconstruct a magnitude image from masked k-space.

    Steps:
    1. Estimate sensitivity maps with ``self.sens_net``, passing
       ``num_low_frequencies`` through.
    2. Refine a copy of the input k-space through every module in
       ``self.cascades``.
    3. Return the root-sum-of-squares over dim 1 of the magnitude images.
    """
    sensitivity = self.sens_net(masked_kspace, mask, num_low_frequencies)

    estimate = masked_kspace.clone()
    for refine in self.cascades:
        estimate = refine(estimate, masked_kspace, mask, sensitivity)

    coil_images = fastmri.complex_abs(fastmri.ifft2c(estimate))
    return fastmri.rss(coil_images, dim=1)
def print_slice(complex_slice, name=None):
    """
    Plot the magnitude of a complex slice produced by a CRNN-i unit.

    Intended only for a 4-D complex input of shape (1, 2, 320, 320), where the
    size-2 axis holds the real/imaginary parts. If *name* is truthy, a banner
    with it is printed first. The input tensor is copied, detached and moved to
    the CPU, so it is not modified.
    """
    if name:
        print("***********" + name + "**********")
    # Move the real/imag axis last (as complex_abs expects) and drop singleton dims.
    channels_last = (
        complex_slice.clone().detach().cpu().permute(0, 2, 3, 1).squeeze()
    )
    magnitude = fastmri.complex_abs(channels_last)
    plt.imshow(magnitude, cmap='gray')
    plt.show()
def test_step(self, batch, batch_idx):
    """Reconstruct one test batch from its zero-filled magnitude image.

    The masked k-space (first element of *batch*) is inverse-FFT'd, converted
    to magnitude without cropping, and passed through ``self.forward``.

    Returns:
        dict with ``"fname"``, ``"slice"`` and the ``"output"`` as a NumPy array on the CPU.
    """
    masked_kspace, _, fname, slice_num, _, _, _ = batch
    image = fastmri.complex_abs(fastmri.ifft2c(masked_kspace))
    output = self.forward(image)
    return {
        "fname": fname,
        "slice": slice_num,
        "output": output.cpu().numpy(),
    }
def validation_step(self, batch, batch_idx):
    """Run the model on a validation batch and collect outputs for metric computation.

    The model output (treated as k-space) is inverse-FFT'd and converted to
    magnitude. Output and target are then cropped to their common smallest size.
    ``val_loss`` is ``self.loss`` between the predicted and the input masked k-space.
    """
    masked_kspace, mask, target, fname, slice_num, max_value, _ = batch

    kspace_pred = self(masked_kspace, mask)
    magnitude = fastmri.complex_abs(fastmri.ifft2c(kspace_pred))
    target, magnitude = transforms.center_crop_to_smallest(target, magnitude)

    return dict(
        batch_idx=batch_idx,
        fname=fname,
        slice_num=slice_num,
        max_value=max_value,
        output=magnitude,
        target=target,
        val_loss=self.loss(kspace_pred, masked_kspace),
    )
def test_step(self, batch, batch_idx):
    """Reconstruct one test batch and centre-crop it to the requested size.

    Only the first crop size of the batch is used (batch size is always 1 for
    varnet). The crop becomes square when the output is narrower than the
    requested width (the "FLAIR 203" case).
    """
    masked_kspace, mask, _, fname, slice_num, _, crop_sizes = batch
    crop_size = crop_sizes[0]  # always have a batch size of 1 for varnet

    kspace_pred = self(masked_kspace, mask)
    output = fastmri.complex_abs(fastmri.ifft2c(kspace_pred))

    width = output.shape[-1]
    if width < crop_size[1]:  # check for FLAIR 203
        crop_size = (width, width)
    cropped = transforms.center_crop(output, crop_size)

    return {"fname": fname, "slice": slice_num, "output": cropped.cpu().numpy()}
def _base_fastmri_unet_transform(
    kspace,
    mask,
    ground_truth,
    attrs,
    which_challenge="singlecoil",
):
    """Turn raw k-space and a mask into a normalised zero-filled U-Net input.

    Steps:
    1. Trim the mask to the k-space width and apply it.
    2. Inverse-FFT, then centre-crop to the ground-truth size (or
       ``attrs["recon_size"]`` when there is no ground truth). The crop is
       squared down for narrow images (the "FLAIR 203" case).
    3. Take the magnitude, and apply RSS for the multicoil challenge.
    4. Instance-normalise and clamp to [-6, 6].

    Returns:
        ``(image.unsqueeze(0), mean, std)``.
    """
    kspace = fastmri_transforms.to_tensor(kspace)
    # accounting for variable size masks
    trimmed_mask = mask[..., : kspace.shape[-2]]
    masked_kspace = kspace * trimmed_mask.unsqueeze(-1) + 0.0

    zero_filled = fastmri.ifft2c(masked_kspace)

    if ground_truth is None:
        crop_size = (attrs["recon_size"][0], attrs["recon_size"][1])
    else:
        crop_size = (ground_truth.shape[-2], ground_truth.shape[-1])

    width = zero_filled.shape[-2]
    if width < crop_size[1]:  # check for FLAIR 203
        crop_size = (width,) * 2

    # noinspection PyTypeChecker
    cropped = fastmri_transforms.complex_center_crop(zero_filled, crop_size)
    magnitude = fastmri.complex_abs(cropped)
    if which_challenge == "multicoil":
        magnitude = fastmri.rss(magnitude)

    image, mean, std = fastmri_transforms.normalize_instance(magnitude, eps=1e-11)
    return image.clamp(-6, 6).unsqueeze(0), mean, std
def training_step(self, batch, batch_idx):
    """Run one self-supervised training step with a train/loss k-space split.

    The acquired k-space is split with ``mask_loss`` (the last element of
    *batch*) and its complement ``mask_train``:
    1. The network sees the zero-filled magnitude image of the ``mask_train``
       samples.
    2. Its output is transformed back to k-space.
    3. The result is compared, on the ``mask_loss`` samples only, against the
       acquired data.

    Returns:
        Scalar training loss, also logged under ``"loss"``.
    """
    subsampled_kspace, _, _, _, _, _, mask_loss = batch
    # Complement of mask_loss (1 - mask_loss); assumes a binary 0/1 mask -- TODO confirm.
    mask_train = -(mask_loss - 1)
    subsampled_kspace_train = subsampled_kspace * mask_train + 0.0
    subsampled_kspace_loss = subsampled_kspace * mask_loss + 0.0

    image_train = fastmri.complex_abs(fastmri.ifft2c(subsampled_kspace_train))
    output_image = self(image_train)

    # Return to k-space with the same centred FFT convention (fastmri.fft2c)
    # that produced the input; plain torch.fft.fft2 yields an uncentred,
    # differently-scaled spectrum that does not match the centred masks.
    complex_output = torch.stack((output_image, torch.zeros_like(output_image)), dim=-1)
    output_kspace_loss = fastmri.fft2c(complex_output) * mask_loss + 0.0

    loss = l1_l2_loss(output_kspace_loss, subsampled_kspace_loss)
    self.log("loss", loss.detach())
    return loss
def __call__(
    self,
    kspace: np.ndarray,
    mask: np.ndarray,
    target: np.ndarray,
    attrs: Dict,
    fname: str,
    slice_num: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, str, int, float]:
    """
    Args:
        kspace: Input k-space of shape (num_coils, rows, cols) for multi-coil
            data or (rows, cols) for single coil data.
        mask: Mask from the test dataset (not used by this transform).
        target: Target image.
        attrs: Acquisition related information stored in the HDF5 object. Must
            contain ``"is_label"``. ``"max"`` is optional, and ``"recon_size"``
            is needed when *target* is None.
        fname: File name.
        slice_num: Serial number of the slice.

    Returns:
        tuple of eight items:
            weak_image: Normalised zero-filled image. For labelled slices this
                is the strongly-masked image, returned in both image slots.
            strong_image: Normalised strongly-masked zero-filled image.
            target: Target normalised with the image statistics, or
                ``torch.Tensor([0])`` without a target.
            mean: Mean used for normalisation.
            std: Standard deviation used for normalisation.
            fname: File name.
            slice_num: Serial number of the slice.
            max_value: ``attrs["max"]`` if present, else 0.0.

    For unlabelled slices, the weak and strong images are normalised jointly,
    so they share one mean/std.
    """
    kspace = to_tensor(kspace)

    # check for max value
    max_value = attrs["max"] if "max" in attrs.keys() else 0.0

    # crop input to correct size
    if target is not None:
        crop_size = (target.shape[-2], target.shape[-1])
    else:
        crop_size = (attrs["recon_size"][0], attrs["recon_size"][1])

    is_label = attrs["is_label"]

    def masked(mask_func):
        # Apply mask_func to the full k-space; without a mask function, use it unmasked.
        if not mask_func:
            return kspace
        seed = None if not self.use_seed else tuple(map(ord, fname))
        masked_kspace, _ = apply_mask(kspace, mask_func, seed)
        return masked_kspace

    def zero_filled(masked_kspace, crop):
        # Zero-filled magnitude image plus the (possibly shrunk) crop size used.
        image = fastmri.ifft2c(masked_kspace)
        # check for FLAIR 203
        if image.shape[-2] < crop[1]:
            crop = (image.shape[-2], image.shape[-2])
        image = fastmri.complex_abs(complex_center_crop(image, crop))
        # apply Root-Sum-of-Squares if multicoil data
        if self.which_challenge == "multicoil":
            image = fastmri.rss(image)
        return image, crop

    def normalized_target(mean, std, crop):
        # Target cropped and normalised with the image statistics, or a placeholder.
        if target is None:
            return torch.Tensor([0])
        target_tensor = center_crop(to_tensor(target), crop)
        return normalize(target_tensor, mean, std, eps=1e-11).clamp(-6, 6)

    if is_label:
        # Labelled slice: only the strong mask is applied.
        image, crop_size = zero_filled(masked(self.strong_mask_func), crop_size)
        image, label_mean, label_std = normalize_instance(image, eps=1e-11)
        labeled_image = image.clamp(-6, 6)
        labeled_target = normalized_target(label_mean, label_std, crop_size)
        return (labeled_image, labeled_image, labeled_target, label_mean,
                label_std, fname, slice_num, max_value)

    # Unlabelled slice: weak and strong views of the same k-space.
    weak_image, crop_size = zero_filled(masked(self.weak_mask_func), crop_size)
    strong_image, crop_size = zero_filled(masked(self.strong_mask_func), crop_size)

    # Normalise both views jointly so that they share statistics.
    image_cat = torch.stack([weak_image, strong_image], dim=0)
    image_cat, unlabel_mean, unlabel_std = normalize_instance(image_cat, eps=1e-11)
    image_cat = image_cat.clamp(-6, 6)
    weak_image, strong_image = image_cat[0], image_cat[1]

    unlabelled_target = normalized_target(unlabel_mean, unlabel_std, crop_size)
    return (weak_image, strong_image, unlabelled_target, unlabel_mean,
            unlabel_std, fname, slice_num, max_value)
def forward(self, kspace_pred: torch.Tensor, ref_kspace: torch.Tensor):
    """
    Histogram and gradient-statistics loss between predicted and reference images.

    Both k-space tensors are converted to magnitude images
    (``fastmri.ifft2c`` + ``fastmri.complex_abs``). The result is
    ``intensity_weight * gt_hist + hist_weight * diff_hist + tv_weight * tv``:

    - gt_hist: histogram distance of raw intensities, with both histograms over
      the reference image's min/max;
    - diff_hist: sum of the histogram distances of horizontal and vertical
      neighbour differences, each histogram over its own min/max;
    - tv: ``|std(pred_diff) - 1.5 * std(ref_diff)|`` summed over the
      horizontal and vertical directions.

    Each histogram distance is ``||hist_pred - hist_ref|| / number_of_values``,
    using ``self.differentiable_histogram`` with ``self.bins`` bins. No k-space
    data-consistency term is computed.

    Inputs:
    - kspace_pred: predicted k-space, documented as shape (N, H, W, 2).
    - ref_kspace : input masked k-space in the same layout.

    Returns:
    - loss: scalar tensor.
    """
    output = fastmri.complex_abs(fastmri.ifft2c(kspace_pred))
    gt = fastmri.complex_abs(fastmri.ifft2c(ref_kspace))

    def value_range(values):
        # (min, max) as Python floats, used only as histogram bin limits.
        return values.min().item(), values.max().item()

    def hist_distance(pred, ref, pred_range, ref_range):
        hist_pred = self.differentiable_histogram(
            pred, bins=self.bins, min=pred_range[0], max=pred_range[1])
        hist_ref = self.differentiable_histogram(
            ref, bins=self.bins, min=ref_range[0], max=ref_range[1])
        return torch.norm((hist_pred - hist_ref) / len(pred))

    hdiff_pred = (output[:, :, :-1] - output[:, :, 1:]).view(-1)
    hdiff_gt = (gt[:, :, :-1] - gt[:, :, 1:]).view(-1)
    hdiff_hist_loss = hist_distance(
        hdiff_pred, hdiff_gt, value_range(hdiff_pred), value_range(hdiff_gt))

    vdiff_pred = (output[:, :-1, :] - output[:, 1:, :]).view(-1)
    vdiff_gt = (gt[:, :-1, :] - gt[:, 1:, :]).view(-1)
    vdiff_hist_loss = hist_distance(
        vdiff_pred, vdiff_gt, value_range(vdiff_pred), value_range(vdiff_gt))

    output = output.view(-1)
    gt = gt.view(-1)
    gt_range = value_range(gt)
    gt_hist_loss = hist_distance(output, gt, gt_range, gt_range)

    hdiff_var_loss = torch.sqrt(torch.var(hdiff_pred)) - 1.5 * torch.sqrt(torch.var(hdiff_gt))
    vdiff_var_loss = torch.sqrt(torch.var(vdiff_pred)) - 1.5 * torch.sqrt(torch.var(vdiff_gt))
    tv_loss = torch.abs(hdiff_var_loss) + torch.abs(vdiff_var_loss)

    diff_hist_loss = hdiff_hist_loss + vdiff_hist_loss
    return (self.intensity_weight * gt_hist_loss
            + self.hist_weight * diff_hist_loss
            + self.tv_weight * tv_loss)
# The fastMRI repo contains some utlity functions to convert k-space into image space. These functions work on PyTorch Tensors. The to_tensor function can convert Numpy arrays to PyTorch Tensors. # In[9]: import fastmri from fastmri.data import transforms as T # In[10]: slice_kspace2 = T.to_tensor(slice_kspace) # Convert from numpy array to pytorch tensor slice_image = fastmri.ifft2c(slice_kspace2) # Apply Inverse Fourier Transform to get the complex image slice_image_abs = fastmri.complex_abs(slice_image) # Compute absolute value to get a real image # SSIM loss loss = fastmri.SSIMLoss() print(loss(slice_image_abs.unsqueeze(1), slice_image_abs.unsqueeze(1), data_range=slice_image_abs.max().reshape(-1))) # In[15]: show_coils(slice_image_abs, [0], cmap='gray') # As we can see, each coil in a multi-coil MRI scan focusses on a different region of the image. These coils can be combined into the full image using the Root-Sum-of-Squares (RSS) transform. # In[16]:
def __call__(self, kspace, mask, target, attrs, fname, slice_num):
    """
    Args:
        kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for multi-coil
            data or (rows, cols, 2) for single coil data.
        mask (numpy.array): Mask from the test dataset (not used by this transform).
        target (numpy.array): Target image.
        attrs (dict): Acquisition related information stored in the HDF5 object.
        fname (str): File name.
        slice_num (int): Serial number of the slice.
    Returns:
        (tuple): tuple containing:
            LR_image (torch.Tensor): Normalised low-resolution input image. It is
                built from the central 160x160 block of the k-space of the
                centre-cropped full image.
            target (torch.Tensor): Target image cropped and normalised with the
                LR image's statistics, or ``torch.Tensor([0])`` when absent.
            mean (float): Mean value used for normalization.
            std (float): Standard deviation value used for normalization.
            fname (str): File name.
            slice_num (int): Serial number of the slice.

    NOTE(review): the LR image is reconstructed on a smaller (160x160) grid. If
    fft2c/ifft2c are orthonormal, its intensity scale differs from the
    full-resolution target's, so the target normalised with LR statistics may
    not be zero-mean -- confirm this is intended.
    """
    full_res = fastmri.ifft2c(transforms.to_tensor(kspace))

    if target is None:
        crop_size = (attrs["recon_size"][0], attrs["recon_size"][1])
    else:
        crop_size = (target.shape[-2], target.shape[-1])

    width = full_res.shape[-2]
    if width < crop_size[1]:  # check for FLAIR 203
        crop_size = (width, width)
    full_res = transforms.complex_center_crop(full_res, crop_size)

    # Low-resolution input: keep only the central 160x160 block of k-space.
    central_kspace = transforms.complex_center_crop(fastmri.fft2c(full_res), (160, 160))
    LR_image = fastmri.complex_abs(fastmri.ifft2c(central_kspace))
    LR_image, mean, std = transforms.normalize_instance(LR_image, eps=1e-11)
    LR_image = LR_image.clamp(-6, 6)

    if target is None:
        target = torch.Tensor([0])
    else:
        target = transforms.center_crop(transforms.to_tensor(target), crop_size)
        target = transforms.normalize(target, mean, std, eps=1e-11).clamp(-6, 6)
    return LR_image, target, mean, std, fname, slice_num
def __call__(self, kspace, mask, target, attrs, fname, slice_num):
    """
    Args:
        kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for multi-coil
            data or (rows, cols, 2) for single coil data.
        mask (numpy.array): Mask from the test dataset (replaced when ``self.mask_func`` is set).
        target (numpy.array): Target image.
        attrs (dict): Acquisition related information stored in the HDF5 object.
        fname (str): File name (also the mask seed when ``self.use_seed`` is set).
        slice_num (int): Serial number of the slice.
    Returns:
        (tuple): tuple containing:
            image (torch.Tensor): Zero-filled input image, normalised and clamped to [-6, 6].
            target (torch.Tensor): Target image converted to a torch Tensor, or
                ``torch.Tensor([0])`` when absent.
            mean (float): Mean value used for normalization.
            std (float): Standard deviation value used for normalization.
            fname (str): File name.
            slice_num (int): Serial number of the slice.
    """
    kspace = transforms.to_tensor(kspace)

    masked_kspace = kspace
    if self.mask_func:
        seed = tuple(map(ord, fname)) if self.use_seed else None
        masked_kspace, mask = transforms.apply_mask(kspace, self.mask_func, seed)

    # inverse Fourier transform to get zero filled solution
    zero_filled = fastmri.ifft2c(masked_kspace)

    crop_size = (
        (target.shape[-2], target.shape[-1])
        if target is not None
        else (attrs["recon_size"][0], attrs["recon_size"][1])
    )
    width = zero_filled.shape[-2]
    if width < crop_size[1]:  # check for FLAIR 203
        crop_size = (width, width)

    image = fastmri.complex_abs(transforms.complex_center_crop(zero_filled, crop_size))
    if self.which_challenge == "multicoil":
        image = fastmri.rss(image)

    image, mean, std = transforms.normalize_instance(image, eps=1e-11)
    image = image.clamp(-6, 6)

    if target is None:
        return image, torch.Tensor([0]), mean, std, fname, slice_num

    target = transforms.center_crop(transforms.to_tensor(target), crop_size)
    target = transforms.normalize(target, mean, std, eps=1e-11).clamp(-6, 6)
    return image, target, mean, std, fname, slice_num
def __call__(
    self,
    kspace: np.ndarray,
    mask: np.ndarray,
    target: np.ndarray,
    attrs: Dict,
    fname: str,
    slice_num: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, str, int, float]:
    """
    Args:
        kspace: Input k-space of shape (num_coils, rows, cols) for multi-coil
            data or (rows, cols) for single coil data.
        mask: Mask from the test dataset (replaced when ``self.mask_func`` is set).
        target: Target image.
        attrs: Acquisition related information stored in the HDF5 object.
        fname: File name (also the mask seed when ``self.use_seed`` is set).
        slice_num: Serial number of the slice.

    Returns:
        tuple containing:
            image: Zero-filled input image, normalised and clamped to [-6, 6].
            target: Target image converted to a torch.Tensor, or
                ``torch.Tensor([0])`` when absent.
            mean: Mean value used for normalization.
            std: Standard deviation value used for normalization.
            fname: File name.
            slice_num: Serial number of the slice.
            max_value: ``attrs["max"]`` if present, else 0.0.
    """
    kspace = to_tensor(kspace)
    max_value = attrs["max"] if "max" in attrs.keys() else 0.0

    if not self.mask_func:
        masked_kspace = kspace
    else:
        seed = tuple(map(ord, fname)) if self.use_seed else None
        masked_kspace, mask = apply_mask(kspace, self.mask_func, seed)

    # inverse Fourier transform to get zero filled solution
    zero_filled = fastmri.ifft2c(masked_kspace)

    if target is None:
        crop_size = (attrs["recon_size"][0], attrs["recon_size"][1])
    else:
        crop_size = (target.shape[-2], target.shape[-1])

    width = zero_filled.shape[-2]
    if width < crop_size[1]:  # check for FLAIR 203
        crop_size = (width, width)

    magnitude = fastmri.complex_abs(complex_center_crop(zero_filled, crop_size))
    if self.which_challenge == "multicoil":
        magnitude = fastmri.rss(magnitude)

    image, mean, std = normalize_instance(magnitude, eps=1e-11)
    image = image.clamp(-6, 6)

    if target is None:
        target_tensor = torch.Tensor([0])
    else:
        target_tensor = center_crop(to_tensor(target), crop_size)
        target_tensor = normalize(target_tensor, mean, std, eps=1e-11).clamp(-6, 6)

    return image, target_tensor, mean, std, fname, slice_num, max_value
dest = h5py.File(dest_path / fname.name, 'a') volume_kspace = orig['kspace'][()] kspace_list = [] reconstruction_list = [] total_slices = volume_kspace.shape[0] for i in range(total_slices // 2 - num_middle_slices // 2, total_slices // 2 + num_middle_slices // 2 + 1): slice_kspace = volume_kspace[i] slice_kspace2 = to_tensor( slice_kspace) # Convert from numpy array to pytorch tensor kspace_crop = complex_center_crop(slice_kspace2, (size, size)) ift = fastmri.ifft2c( kspace_crop ) # Apply Inverse Fourier Transform to get the complex image reconstruction = fastmri.complex_abs(ift) kspace_list.append(tensor_to_complex_np(kspace_crop)) reconstruction_list.append(reconstruction) stacked_kspace = np.stack(kspace_list) stacked_reconstruction_esc = np.stack(reconstruction_list) dest['kspace'] = stacked_kspace dest['reconstruction_esc'] = stacked_reconstruction_esc dest['ismrmrd_header'] = orig['ismrmrd_header'][()] dest.attrs['norm'] = np.linalg.norm(stacked_reconstruction_esc) dest.attrs['max'] = np.max(stacked_reconstruction_esc) dest.close() orig.close()