def __call__(
    self,
    kspace: np.ndarray,
    mask: np.ndarray,
    target: np.ndarray,
    attrs: Dict,
    fname: str,
    slice_num: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, str, int, float, torch.Tensor]:
    """
    Convert one slice to tensors and undersample its k-space.

    Args:
        kspace: Input k-space of shape (num_coils, rows, cols) for multi-coil
            data.
        mask: Mask from the test dataset. Only used when ``self.mask_func`` is
            not set; it is then reshaped to lie along axis ``-2`` of the
            converted k-space (presumably the column axis in front of a
            trailing real/imaginary dimension added by ``to_tensor`` -- TODO
            confirm).
        target: Target image, or None.
        attrs: Acquisition related information stored in the HDF5 object.
            Must contain "padding_left", "padding_right" and "recon_size",
            plus "max" when ``target`` is given.
        fname: File name. Also used to seed the mask function when
            ``self.use_seed`` is set.
        slice_num: Serial number of the slice.

    Returns:
        tuple containing:
            masked_kspace: k-space after applying sampling mask.
            mask: The applied sampling mask, as a byte tensor.
            target: The target image (``torch.tensor(0)`` if not given).
            fname: File name.
            slice_num: The slice index.
            max_value: Maximum image value (``attrs["max"]``, or 0.0 when
                there is no target).
            crop_size: The size to crop the final image, taken from
                ``attrs["recon_size"]``.
    """
    if target is not None:
        target = to_tensor(target)
        max_value = attrs["max"]
    else:
        target = torch.tensor(0)
        max_value = 0.0

    kspace = to_tensor(kspace)
    # A filename-derived seed makes the generated mask reproducible per volume.
    seed = None if not self.use_seed else tuple(map(ord, fname))
    acq_start = attrs["padding_left"]
    acq_end = attrs["padding_right"]
    crop_size = torch.tensor([attrs["recon_size"][0], attrs["recon_size"][1]])

    if self.mask_func:
        masked_kspace, mask = apply_mask(kspace, self.mask_func, seed, (acq_start, acq_end))
    else:
        masked_kspace = kspace
        # Broadcastable mask shape: 1 on every axis except axis -2.
        mask_shape = [1] * kspace.ndim
        mask_shape[-2] = kspace.shape[-2]
        mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32))
        # Zero everything outside the acquired range [acq_start, acq_end).
        # Axis -2 is addressed from the end so this works for any rank.
        mask[..., :acq_start, :] = 0
        mask[..., acq_end:, :] = 0

    return (
        masked_kspace,
        mask.byte(),
        target,
        fname,
        slice_num,
        max_value,
        crop_size,
    )
def __call__(
    self,
    kspace: np.ndarray,
    sensitivity_map: np.ndarray,
    mask: np.ndarray,
    eta: np.ndarray,
    target: np.ndarray,
    attrs: Dict,
    fname: str,
    slice_idx: int,
) -> Tuple[
    torch.Tensor,
    Union[Union[List, torch.Tensor], torch.Tensor],
    Union[Optional[torch.Tensor], Any],
    Union[List, Any],
    Union[Optional[torch.Tensor], Any],
    Union[torch.Tensor, Any],
    str,
    int,
    Union[Union[List, torch.Tensor], Any],
]:
    """
    Apply the data transform.

    Converts the inputs to tensors, optionally zero-fills k-space, builds a
    max-normalized magnitude target, optionally crops, undersamples k-space
    (with the stored mask or with one or more mask functions) and optionally
    normalizes the results.

    Parameters
    ----------
    kspace: The kspace.
    sensitivity_map: The sensitivity map. May be None or empty (size 0), e.g. when sensitivity maps are
        estimated automatically; every sensitivity-map step is then skipped.
    mask: The stored mask. Only used when ``self.mask_func`` is None; may itself be None.
    eta: The initial estimation. May be None or empty, in which case an empty tensor is returned.
    target: The target. Only used when ``self.coil_combination_method`` is neither "RSS" nor "SENSE".
    attrs: The attributes. Optional keys read here: "padding_left", "padding_right", "target", "target_rss".
    fname: The file name. Also used to derive the masking seed when ``self.use_seed`` is set.
    slice_idx: The slice number.

    Returns
    -------
    Tuple (kspace, masked_kspace, sensitivity_map, mask, eta, target, fname, slice_idx, acc).
    When ``self.mask_func`` is a list, masked_kspace, mask and acc are lists with one entry per mask function.

    Raises
    ------
    ValueError
        If no target can be computed or found.
    """
    fft_kwargs = {
        "centered": self.fft_centered,
        "normalization": self.fft_normalization,
        "spatial_dims": self.spatial_dims,
    }

    def _is_nonempty(x) -> bool:
        # ``.size`` is an int on numpy arrays but a method on torch tensors,
        # so each type is checked explicitly.
        if x is None:
            return False
        if isinstance(x, torch.Tensor):
            return x.numel() != 0
        return x.size != 0

    def _crop_image(x):
        """Center-crop image-domain data to ``self.crop_size`` (through k-space if ``self.kspace_crop``)."""
        if self.kspace_crop:
            return ifft2(complex_center_crop(fft2(x, **fft_kwargs), self.crop_size), **fft_kwargs)
        return complex_center_crop(x, self.crop_size)

    def _crop_kspace(y):
        """Center-crop k-space data to ``self.crop_size`` (through image space unless ``self.kspace_crop``)."""
        if self.kspace_crop:
            return complex_center_crop(y, self.crop_size)
        return fft2(complex_center_crop(ifft2(y, **fft_kwargs), self.crop_size), **fft_kwargs)

    def _normalize_kspace(y):
        """Round-trip k-space through image space, rescaling the image to unit max magnitude where applicable."""
        if self.fft_normalization in ("orthogonal", "orthogonal_norm_only", "ortho"):
            imspace = ifft2(y, **fft_kwargs)
            imspace = imspace / torch.max(torch.abs(imspace))
            return fft2(imspace, **fft_kwargs)
        if self.fft_normalization in ("fft_norm", "backward"):
            # NOTE(review): these modes only round-trip through image space and
            # do not divide by the max -- confirm this is intended.
            return fft2(ifft2(y, **fft_kwargs), **fft_kwargs)
        imspace = torch.fft.ifftn(torch.view_as_complex(y), dim=[-2, -1], norm=None)
        imspace = imspace / torch.max(torch.abs(imspace))
        return torch.view_as_real(torch.fft.fftn(imspace, dim=[-2, -1], norm=None))

    kspace = to_tensor(kspace)

    # This condition is necessary in case of auto estimation of sense maps.
    if _is_nonempty(sensitivity_map):
        sensitivity_map = to_tensor(sensitivity_map)

    # Apply zero-filling on kspace.
    if self.kspace_zero_filling_size is not None and self.kspace_zero_filling_size not in ("", "None"):
        # floor(|requested - current| / 2) on each side: an odd difference ends one sample short,
        # and because of abs() a smaller requested size also pads.
        padding_top = int(np.floor_divide(abs(int(self.kspace_zero_filling_size[0]) - kspace.shape[1]), 2))
        padding_bottom = padding_top
        padding_left = int(np.floor_divide(abs(int(self.kspace_zero_filling_size[1]) - kspace.shape[2]), 2))
        padding_right = padding_left
        pad = (padding_left, padding_right, padding_top, padding_bottom)

        kspace = torch.view_as_real(
            torch.nn.functional.pad(torch.view_as_complex(kspace), pad=pad, mode="constant", value=0)
        )

        # Apply the same k-space zero-filling to the sensitivity maps, when present.
        if _is_nonempty(sensitivity_map):
            sensitivity_map = fft2(sensitivity_map, **fft_kwargs)
            sensitivity_map = torch.view_as_real(
                torch.nn.functional.pad(torch.view_as_complex(sensitivity_map), pad=pad, mode="constant", value=0)
            )
            sensitivity_map = ifft2(sensitivity_map, **fft_kwargs)

    # Initial estimation.
    eta = to_tensor(eta) if _is_nonempty(eta) else torch.tensor([])

    # RSS / SENSE recompute the target from k-space; otherwise use the given target or one stored in attrs.
    coil_combination = self.coil_combination_method.upper()
    if coil_combination == "RSS":
        target = rss(ifft2(kspace, **fft_kwargs), dim=self.coil_dim)
    elif coil_combination == "SENSE":
        if _is_nonempty(sensitivity_map):
            target = sense(ifft2(kspace, **fft_kwargs), sensitivity_map, dim=self.coil_dim)
        # NOTE(review): without sensitivity maps ``target`` is left as passed in, and
        # torch.view_as_complex below rejects it unless it is already a tensor.
    elif _is_nonempty(target):
        target = to_tensor(target)
    elif "target" in attrs or "target_rss" in attrs:
        # Prefer "target"; use "target_rss" when it is the only one present.
        target = torch.tensor(attrs["target"] if "target" in attrs else attrs["target_rss"])
    else:
        raise ValueError("No target found")

    target = torch.view_as_complex(target)
    target = torch.abs(target / torch.max(torch.abs(target)))

    seed = tuple(map(ord, fname)) if self.use_seed else None
    acq_start = attrs["padding_left"] if "padding_left" in attrs else 0
    acq_end = attrs["padding_right"] if "padding_right" in attrs else 0

    # Whether cropping is configured; this stays the same for the whole call.
    crop_enabled = self.crop_size is not None and self.crop_size not in ("", "None")

    if crop_enabled:
        # Never crop beyond the target's own size.
        h = min(int(self.crop_size[0]), target.shape[0])
        w = min(int(self.crop_size[1]), target.shape[1])
        # NOTE(review): this overwrites ``self.crop_size``, so later calls on this
        # instance are capped at the smallest size seen so far -- confirm intended.
        self.crop_size = (int(h), int(w))

        target = center_crop(target, self.crop_size)
        if _is_nonempty(sensitivity_map):
            sensitivity_map = _crop_image(sensitivity_map)
        if eta is not None and eta.ndim > 2:
            eta = _crop_image(eta)

    # Optionally crop k-space before undersampling, so the mask is applied at the cropped size.
    if crop_enabled and self.crop_before_masking:
        kspace = _crop_kspace(kspace)

    # Undersample kspace if undersampling is enabled.
    if self.mask_func is None:
        masked_kspace = kspace
        acc = torch.tensor([np.around(mask.size / mask.sum())]) if mask is not None else torch.tensor([1])

        if mask is None:
            mask = torch.ones(
                [masked_kspace.shape[-3], masked_kspace.shape[-2]], dtype=torch.float32  # type: ignore
            )
        else:
            mask = torch.from_numpy(mask)
            if mask.shape[0] == masked_kspace.shape[2]:  # type: ignore
                # A 2D mask stored transposed is flipped back. A 1D mask of this
                # length already lies along the column axis and is kept as is.
                if mask.ndim == 2:
                    mask = mask.permute(1, 0)
            elif mask.shape[0] != masked_kspace.shape[1]:  # type: ignore
                # Incompatible mask: fall back to full sampling.
                mask = torch.ones(
                    [masked_kspace.shape[-3], masked_kspace.shape[-2]], dtype=torch.float32  # type: ignore
                )

        if mask.ndim == 1:
            mask = mask.unsqueeze(0)
        if mask.shape[-2] != 1 and crop_enabled:
            # Crop a loaded 2D mask.
            # NOTE(review): this also runs when ``self.crop_before_masking`` is False,
            # in which case masked_kspace is not cropped yet -- confirm the shapes broadcast.
            mask = center_crop(mask, self.crop_size)
        mask = mask.unsqueeze(0).unsqueeze(-1)

        if self.shift_mask:
            mask = torch.fft.fftshift(mask, dim=[-3, -2])

        masked_kspace = masked_kspace * mask
        mask = mask.byte()
    elif isinstance(self.mask_func, list):
        masked_kspace, mask, acc = [], [], []
        for mask_func in self.mask_func:
            _masked_kspace, _mask, _acc = apply_mask(
                kspace,
                mask_func,
                seed,
                (acq_start, acq_end),
                shift=self.shift_mask,
                half_scan_percentage=self.half_scan_percentage,
                center_scale=self.mask_center_scale,
            )
            masked_kspace.append(_masked_kspace)
            mask.append(_mask.byte())
            acc.append(_acc)
    else:
        # NOTE(review): ``self.mask_func`` is indexed although it is not a ``list`` here;
        # presumably another sequence type (tuple / config list) -- confirm.
        masked_kspace, mask, acc = apply_mask(
            kspace,
            self.mask_func[0],  # type: ignore
            seed,
            (acq_start, acq_end),
            shift=self.shift_mask,
            half_scan_percentage=self.half_scan_percentage,
            center_scale=self.mask_center_scale,
        )
        mask = mask.byte()

    # Cropping after masking (per mask function when several are used).
    if crop_enabled and not self.crop_before_masking:
        if isinstance(masked_kspace, list):
            masked_kspace = [_crop_kspace(y) for y in masked_kspace]
            mask = [center_crop(m.squeeze(-1), self.crop_size).unsqueeze(-1) for m in mask]
        else:
            masked_kspace = _crop_kspace(masked_kspace)
            mask = center_crop(mask.squeeze(-1), self.crop_size).unsqueeze(-1)

    # Normalize by the max value.
    if self.normalize_inputs:
        if isinstance(masked_kspace, list):
            masked_kspace = [_normalize_kspace(y) for y in masked_kspace]
        else:
            masked_kspace = _normalize_kspace(masked_kspace)

        if _is_nonempty(sensitivity_map):
            sensitivity_map = sensitivity_map / torch.max(torch.abs(sensitivity_map))

        if _is_nonempty(eta) and eta.ndim > 2:
            eta = eta / torch.max(torch.abs(eta))

        target = target / torch.max(torch.abs(target))

    return kspace, masked_kspace, sensitivity_map, mask, eta, target, fname, slice_idx, acc