def __call__(self, target, fname, slice_num=0, attrs=None, seed=None):
    """Build a masked-k-space magnitude input and a matching normalized target.

    Args:
        target: Numpy array of shape [H, W, C]. Only C == 1 (a zero imaginary
            channel is appended) or C == 2 (already a real/imaginary pair) is
            accepted.
        fname: File name; returned unchanged.
        slice_num: Slice index; returned unchanged.
        attrs: Not used by this transform; kept for interface compatibility.
        seed: Seed forwarded to ``apply_mask`` together with ``self.mask_func``.

    Returns:
        tuple: ``(image, target, mean, std, fname, slice_num, max_value)``.
        ``image`` is the normalized magnitude of the masked reconstruction with
        a leading axis added by ``unsqueeze(0)``, clamped to [-6, 6].
        ``target`` is channels-first, normalized with the same ``mean``/``std``,
        clamped to [-6, 6], and has axis 0 squeezed. ``max_value`` is always 0.0.

    Raises:
        ValueError: If the last axis of ``target`` does not have 1 or 2 channels.
    """
    # The FFT path needs a (real, imag) pair in the last axis. A 1-channel
    # image gets an all-zero imaginary channel. Anything else (e.g. 3-channel
    # RGB) would turn into 6 channels after concatenation, so reject it
    # explicitly rather than with an `assert` that `python -O` removes.
    n_channels = target.shape[2]
    if n_channels not in (1, 2):
        raise ValueError(
            f"expected target with 1 or 2 channels in axis 2, got shape {target.shape}"
        )
    img = target
    if n_channels != 2:
        img = np.concatenate((target, np.zeros_like(target)), axis=2)

    img = to_tensor(img)
    kspace = fastmri.fft2c(img)
    center_kspace, _ = apply_mask(kspace, self.mask_func, hamming=True, seed=seed)

    # Magnitude of the masked reconstruction plus a leading channel axis.
    # Presumably [1, H, W] if apply_mask preserves the k-space shape.
    img_LF = fastmri.complex_abs(fastmri.ifft2c(center_kspace)).unsqueeze(0)
    image, mean, std = normalize_instance(img_LF, eps=1e-11)
    image = image.clamp(-6, 6)

    # Channels-first target, normalized with the *input's* statistics so that
    # input and target share one scale.
    target = to_tensor(np.transpose(target, (2, 0, 1)))
    target = normalize(target, mean, std, eps=1e-11)
    target = target.clamp(-6, 6).squeeze(0)

    # NOTE(review): max_value is a fixed placeholder. Callers that use it
    # (e.g. for SSIM/PSNR data range) will see 0.0; confirm this is intended.
    max_value = 0.0
    return image, target, mean, std, fname, slice_num, max_value
def data_consistency_kspace(self, prediction, k_space_slice, mask, crop_size=(320, 320)):
    """Enforce k-space data consistency on a network prediction.

    Args:
        prediction: Net (or block) output holding a complex image as two
            channels, laid out as [N, 2, H, W].
        k_space_slice: Initially sampled k-space, in fastMRI's [..., H, W, 2]
            complex layout.
        mask: Sampling mask broadcastable against ``k_space_slice``. Assumed
            binary: 1 marks acquired locations, 0 missing ones.
        crop_size: (height, width) of the image region that is processed and
            returned. Defaults to the historical (320, 320).

    Returns:
        The corrected complex image as [N, 2, crop_h, crop_w]. In k-space,
        locations with mask == 1 keep the values of ``k_space_slice`` and
        locations with mask == 0 take the values predicted from
        ``prediction``.
    """
    crop_h, crop_w = crop_size
    # Keep the top-left crop_h x crop_w window, then move channels last:
    # [N, 2, h, w] -> [N, h, w, 2] (the complex layout fastmri expects).
    prediction = prediction[:, :, 0:crop_h, 0:crop_w].permute(0, 2, 3, 1)
    # Pad back to the acquired k-space grid (presumably the spatial size of
    # k_space_slice) so the FFTs line up.
    prediction = self.proper_padding(prediction, k_space_slice)
    k_space_prediction = fastmri.fft2c(prediction)
    # Keep acquired samples; fill unacquired ones from the prediction.
    k_space_out = (1 - mask) * k_space_prediction + mask * k_space_slice
    prediction = fastmri.ifft2c(k_space_out)
    prediction = transforms.complex_center_crop(prediction, (crop_h, crop_w))
    # Back to channels-first: [N, h, w, 2] -> [N, 2, h, w].
    return prediction.permute(0, 3, 1, 2)
def add_ge_oversampling(masked_kspace):
    """Zero-pad k-space data along its readout axis in the image domain.

    The data is moved to image space, padded with ``readout_len // 2`` zeros
    on both sides of dimension -3, and transformed back to k-space.

    Args:
        masked_kspace: Complex k-space tensor whose dimension -3 is treated as
            the readout axis (fastmri's [..., H, W, 2] layout).

    Returns:
        K-space tensor whose dimension -3 is longer by ``2 * (readout_len // 2)``.
    """
    half_readout = masked_kspace.shape[-3] // 2
    # F.pad takes (left, right) pairs starting from the LAST dimension:
    # nothing on dims -1 and -2, symmetric zeros on dim -3.
    padding = (0, 0, 0, 0, half_readout, half_readout)
    # Nesting the calls lets each intermediate tensor be released as soon as
    # it has been consumed.
    return fastmri.fft2c(F.pad(fastmri.ifft2c(masked_kspace), padding))
def test_fft2(shape):
    """Check ``fastmri.fft2c`` against NumPy's centered, orthonormal 2-D FFT.

    Args:
        shape: Spatial shape of the complex input (list or tuple of ints). A
            trailing dimension of size 2 for real/imaginary parts is appended.
    """
    # `[*shape, 2]` accepts lists and tuples; `shape + [2]` fails on tuples.
    x = create_input([*shape, 2])

    out_torch = fastmri.fft2c(x).numpy()
    out_torch = out_torch[..., 0] + 1j * out_torch[..., 1]

    # Reference: ifftshift -> orthonormal fft2 -> fftshift over the last two axes.
    input_numpy = transforms.tensor_to_complex_np(x)
    input_numpy = np.fft.ifftshift(input_numpy, (-2, -1))
    out_numpy = np.fft.fft2(input_numpy, norm="ortho")
    out_numpy = np.fft.fftshift(out_numpy, (-2, -1))

    assert np.allclose(out_torch, out_numpy)
def forward(
    self,
    current_kspace: torch.Tensor,
    ref_kspace: torch.Tensor,
    mask: torch.Tensor,
) -> torch.Tensor:
    """Run one unrolled k-space update step.

    The update is
    ``current - where(mask, current - ref, 0) * dc_weight - fft2c(model(ifft2c(current)))``.

    Args:
        current_kspace: Current k-space estimate.
        ref_kspace: Reference k-space that the soft data-consistency term pulls toward.
        mask: Boolean condition for ``torch.where``; the residual is kept only where it is True.

    Returns:
        The updated k-space estimate.
    """
    # Soft data consistency: residual against the reference, zeroed off-mask.
    zero = torch.zeros(1, 1, 1, 1).to(current_kspace)
    residual = current_kspace - ref_kspace
    dc_term = torch.where(mask, residual, zero) * self.dc_weight

    # Learned regularizer applied in image space, mapped back to k-space.
    image = fastmri.ifft2c(current_kspace)
    refined = self.model(image)
    model_term = fastmri.fft2c(refined)

    return current_kspace - dc_term - model_term
def forward(
    self,
    current_kspace: torch.Tensor,
    ref_kspace: torch.Tensor,
    mask: torch.Tensor,
) -> torch.Tensor:
    """Run one unrolled k-space update step with an extra histogram-loss term.

    The update is ``current - soft_dc - soft_histc - model_term``, where
    ``soft_histc`` is ``hist_weight`` times the gradient of
    ``hist_loss(current, ref)`` with respect to ``current``.

    Args:
        current_kspace: Current k-space estimate.
        ref_kspace: Reference k-space.
        mask: Boolean condition for ``torch.where``; the data-consistency
            residual is kept only where it is True.

    Returns:
        The updated k-space estimate.
    """
    zero = torch.zeros(1, 1, 1, 1).to(current_kspace)
    soft_dc = torch.where(mask, current_kspace - ref_kspace, zero) * self.dc_weight

    # Gradient of the histogram loss w.r.t. the current k-space. It is taken on
    # detached copies so it never links back into the outer graph. Grad mode is
    # forced on so this also works under torch.no_grad() (e.g. evaluation).
    with torch.enable_grad():
        kspace_var = current_kspace.detach().clone().requires_grad_(True)
        hist_value = hist_loss(kspace_var, ref_kspace.detach().clone())
        # torch.autograd.grad returns only d(loss)/d(kspace_var). Unlike
        # .backward(), it does not accumulate into `.grad` of any other leaf
        # tensors (e.g. parameters) that hist_loss may involve.
        (hist_grad,) = torch.autograd.grad(hist_value, kspace_var)
        soft_histc = hist_grad * self.hist_weight

    model_term = fastmri.fft2c(self.model(fastmri.ifft2c(current_kspace)))
    return current_kspace - soft_dc - soft_histc - model_term
def forward(self, mr_img, mk_space, mask):
    """Apply ``self.n_cascade`` Laplacian-pyramid refinement cascades.

    Each cascade does the following:
    - decomposes the current image into one Gaussian and two Laplacian levels;
    - runs three feature branches;
    - fuses each branch with its pyramid level via ``self.pixel_conv``
      (real and imaginary channels separately);
    - rebuilds the image with two ``self.lapl_rec`` calls;
    - finishes with a k-space data-consistency step.

    Args:
        mr_img: Current complex image estimate (real/imag in the last axis).
        mk_space: Measured k-space, added unchanged in the data-consistency step.
        mask: Sampling mask; the predicted k-space is weighted by ``1 - mask``.

    Returns:
        The image estimate after the last cascade (``mr_img`` itself if
        ``n_cascade`` is 0).
    """

    def fuse(branch, level):
        # pixel_conv runs on real and imaginary parts independently; the two
        # results are re-stacked into a trailing complex axis.
        real_part = self.pixel_conv(branch[..., 0], level[..., 0])
        imag_part = self.pixel_conv(branch[..., 1], level[..., 1])
        return torch.stack((real_part, imag_part), dim=-1)

    for _ in range(self.n_cascade):
        # Laplacian decomposition: gaussian_3 is the downsampled image; lap_1
        # and lap_2 hold the residual detail relative to the Gaussian levels.
        # NOTE(review): this runs on a CPU copy of mr_img; confirm lapl_dec
        # returns tensors on the same device as the branch outputs.
        gaussian_3, lap_1, lap_2 = self.lapl_dec(mr_img.cpu())

        # Rearrange spatial/channel dims (presumably a pixel-unshuffle by 4),
        # then extract shallow features.
        features = self.convBlock1(self.shuffle_down_4(mr_img))

        # Three branches at different scales.
        branch_1 = self.branch1(self.shuffle_up_4(features))
        branch_2 = self.branch2(self.shuffle_up_2(features))
        branch_3 = self.branch3(features)

        fused_1 = fuse(branch_1, lap_1)
        fused_2 = fuse(branch_2, lap_2)
        fused_3 = fuse(branch_3, gaussian_3)

        # Rebuild coarse-to-fine (per the original author, lapl_rec performs a
        # 2x linear upsample so the branches can be added correctly).
        output = self.lapl_rec(self.lapl_rec(fused_3, fused_2), fused_1)

        # Data consistency: keep predicted k-space where mask is 0 and add
        # mk_space (presumably already zero outside the sampled locations).
        predicted_k = fastmri.fft2c(output)
        mr_img = fastmri.ifft2c((1.0 - mask) * predicted_k + mk_space)
    return mr_img
def sens_expand(self, x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
    """Weight ``x`` by ``sens_maps`` and transform the result to k-space.

    Performs a complex multiply (``fastmri.complex_mul``) followed by a
    centered 2-D FFT (``fastmri.fft2c``). With broadcasting over a coil axis
    this presumably yields one k-space per sensitivity map.
    """
    weighted_images = fastmri.complex_mul(x, sens_maps)
    return fastmri.fft2c(weighted_images)
def sens_expand(x):
    """Complex-multiply ``x`` by ``sens_maps`` and return the centered 2-D FFT.

    ``sens_maps`` is not a parameter: it is resolved from the enclosing or
    module scope when the function is called.
    """
    weighted = fastmri.complex_mul(x, sens_maps)
    return fastmri.fft2c(weighted)
def __call__(self, kspace, mask, target, attrs, fname, slice_num, lr_size=(160, 160)):
    """
    Args:
        kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2)
            for multi-coil data or (rows, cols, 2) for single coil data.
        mask (numpy.array): Mask from the test dataset. Not used by this transform.
        target (numpy.array): Target image, or None.
        attrs (dict): Acquisition related information stored in the HDF5 object.
            ``attrs["recon_size"]`` gives the crop size when ``target`` is None.
        fname (str): File name.
        slice_num (int): Serial number of the slice.
        lr_size (tuple): (rows, cols) of the central k-space region kept to
            form the low-resolution input. Defaults to (160, 160).

    Returns:
        (tuple): tuple containing:
            LR_image (torch.Tensor): Low-resolution magnitude image built from
                the central ``lr_size`` k-space of the cropped image, then
                normalized and clamped to [-6, 6].
            target (torch.Tensor): Target cropped to the same size, normalized
                with LR_image's mean/std and clamped, or ``torch.Tensor([0])``
                when no target is given.
            mean: Mean used for normalization (from ``transforms.normalize_instance``).
            std: Standard deviation used for normalization.
            fname (str): File name.
            slice_num (int): Serial number of the slice.
    """
    kspace = transforms.to_tensor(kspace)
    image = fastmri.ifft2c(kspace)

    # Crop input to the target size, or to the recorded reconstruction size.
    if target is not None:
        crop_size = (target.shape[-2], target.shape[-1])
    else:
        crop_size = (attrs["recon_size"][0], attrs["recon_size"][1])

    # Check for sFLAIR 203: if the image width (dim -2 of [..., H, W, 2]) is
    # smaller than the requested crop width, fall back to a square crop.
    if image.shape[-2] < crop_size[1]:
        crop_size = (image.shape[-2], image.shape[-2])

    image = transforms.complex_center_crop(image, crop_size)

    # Low-resolution input: keep only the central lr_size region of k-space.
    lr_kspace = transforms.complex_center_crop(fastmri.fft2c(image), lr_size)
    LR_image = fastmri.complex_abs(fastmri.ifft2c(lr_kspace))

    # Normalize the input. Its statistics are reused for the target.
    # NOTE(review): cropping k-space changes the image intensity scale, so the
    # target is normalized with stats from a differently scaled image. This is
    # consistent only if outputs are denormalized with the same mean/std.
    LR_image, mean, std = transforms.normalize_instance(LR_image, eps=1e-11)
    LR_image = LR_image.clamp(-6, 6)

    if target is not None:
        target = transforms.to_tensor(target)
        target = transforms.center_crop(target, crop_size)
        target = transforms.normalize(target, mean, std, eps=1e-11)
        target = target.clamp(-6, 6)
    else:
        target = torch.Tensor([0])

    return LR_image, target, mean, std, fname, slice_num