def __getitem__(self, i):
    """Load one k-space slice, build its target image and pass both to ``self.transform``.

    Steps:
      * read ``self.examples[i]`` as ``(fname, slice_id)`` and load that slice of
        the HDF5 ``"kspace"`` dataset;
      * split complex values into a trailing real/imag axis of size 2;
      * if ``self.random_rotate`` is set, rotate by ``self.random_angles[i]``;
      * center-crop to ``self.image_shape``;
      * compute the target with an inverse 2-D FFT;
      * divide both arrays by a fixed normalization constant.

    Args:
        i: Index into ``self.examples``.

    Returns:
        Whatever ``self.transform`` returns. It is called with
        ``(kspace, torch.zeros(W), target, dict(attrs), fname.name, slice_id)``,
        where ``kspace`` and ``target`` are numpy arrays with a trailing axis of
        size 2.
    """
    fname, slice_id = self.examples[i]
    with h5py.File(fname, 'r') as data:
        kspace = data["kspace"][slice_id]
        kspace = np.stack([kspace.real, kspace.imag], axis=-1)
        if self.random_rotate:
            kspace = ndimage.rotate(kspace, self.random_angles[i], reshape=False, mode='nearest')
        # (H, W, 2) -> (2, H, W) for cropping, then back to (H', W', 2).
        kspace = torch.from_numpy(kspace).permute(2, 0, 1)
        kspace = self.center_crop(kspace, self.image_shape).permute(1, 2, 0)
        kspace = fastmri.ifftshift(kspace, dim=(0, 1))
        # Equivalent of the legacy ``torch.ifft(kspace, 2, normalized=False)``
        # (removed in PyTorch 1.8): a 2-D inverse FFT over (H, W) with 1/N
        # scaling. ``view_as_complex`` needs a contiguous last dim of size 2,
        # and the permute above may leave the tensor non-contiguous.
        target = torch.view_as_real(
            torch.fft.ifft2(torch.view_as_complex(kspace.contiguous()), norm="backward")
        )
        target = fastmri.ifftshift(target, dim=(0, 1))
        # Normalize using mean of k-space in training data
        target /= 7.072103529760345e-07
        kspace /= 7.072103529760345e-07
        kspace = kspace.numpy()
        target = target.numpy()
        # Still inside the ``with``: ``data.attrs`` needs the file to be open.
        return self.transform(
            kspace,
            torch.zeros(kspace.shape[1]),
            target,
            dict(data.attrs),
            fname.name,
            slice_id,
        )
def __getitem__(self, i):
    """Load one k-space slice, build its target image and pass both to ``self.transform``.

    Args:
        i: Index into ``self.examples``. Each entry is ``(fname, slice_id)``.

    Returns:
        Whatever ``self.transform`` returns. It is called with
        ``(kspace, torch.zeros(W), target, dict(attrs), fname.name, slice_id)``,
        where ``kspace`` and ``target`` are numpy arrays with a trailing
        real/imag axis of size 2.
    """
    fname, slice_id = self.examples[i]
    with h5py.File(fname, "r") as data:
        kspace = data["kspace"][slice_id]
        kspace = torch.from_numpy(np.stack([kspace.real, kspace.imag], axis=-1))
        kspace = fastmri.ifftshift(kspace, dim=(0, 1))
        # 2-D inverse FFT over (H, W) with 1/N scaling. This is what the legacy
        # ``torch.ifft(kspace, 2, normalized=False)`` computed. Note that
        # ``torch.fft.ifft(kspace, 2)`` is NOT equivalent: there ``2`` is the
        # signal length ``n``, so it would do a 1-D transform over the
        # real/imag axis.
        target = torch.view_as_real(
            torch.fft.ifft2(torch.view_as_complex(kspace.contiguous()), norm="backward")
        )
        target = fastmri.ifftshift(target, dim=(0, 1))
        # Normalize using mean of k-space in training data
        target /= 7.072103529760345e-07
        kspace /= 7.072103529760345e-07

        # Environment expects numpy arrays. The code above was used with an older
        # version of the environment to generate the results of the MICCAI'20 paper.
        # So, to keep this consistent with the version in the paper, we convert
        # the tensors back to numpy rather than changing the original code.
        kspace = kspace.numpy()
        target = target.numpy()
        # Still inside the ``with``: ``data.attrs`` needs the file to be open.
        return self.transform(
            kspace,
            torch.zeros(kspace.shape[1]),
            target,
            dict(data.attrs),
            fname.name,
            slice_id,
        )
def ifft_permute_maybe_shift(x: torch.Tensor, normalized: bool = False, ifft_shift: bool = False) -> torch.Tensor:
    """Inverse 2-D FFT of a channels-first real/imag tensor, optionally ifftshift-ed.

    Args:
        x: Tensor of shape ``(N, 2, H, W)``. Channel 0 holds the real part and
            channel 1 the imaginary part.
        normalized: If ``True``, use orthonormal scaling (1/sqrt(H*W)).
            Otherwise scale by 1/(H*W). This matches the legacy
            ``torch.ifft(..., normalized=...)``.
        ifft_shift: If ``True``, apply ``fastmri.ifftshift`` over H and W after
            the transform.

    Returns:
        Tensor of shape ``(N, 2, H, W)`` holding the real and imaginary parts
        of the result.
    """
    x = x.permute(0, 2, 3, 1)  # NCHW -> NHW2
    # Replacement for ``torch.ifft(x, 2, normalized=normalized)``, which was
    # removed in PyTorch 1.8. ``contiguous`` is needed because the permuted
    # view does not have a unit-stride last dim, which ``view_as_complex``
    # requires.
    norm = "ortho" if normalized else "backward"
    y = torch.view_as_real(
        torch.fft.ifft2(torch.view_as_complex(x.contiguous()), dim=(1, 2), norm=norm)
    )
    if ifft_shift:
        y = fastmri.ifftshift(y, dim=(1, 2))
    return y.permute(0, 3, 1, 2)  # NHW2 -> NCHW
def forward(self, target, full_kspace, idx=None):
    """Run ``self.num_step`` iterative reconstruction steps.

    The mask and k-space estimate are initialised from ``full_kspace``. Each
    step calls ``self.step_forward``. After each step the new reconstruction
    is mapped back to k-space (ifftshift followed by an unnormalised 2-D FFT),
    and that k-space is used as ``pred_kspace`` in the next step.

    Args:
        target (torch.Tensor): Ground-truth image. It is passed to every
            ``self.step_forward`` call. Its shape is used to initialise the
            running reconstruction (``torch.zeros_like(target)``).
            Presumably NCHW — TODO confirm.
        full_kspace (torch.Tensor): Fully sampled k-space. It is passed to
            ``self._init_mask``, ``self._init_kspace`` and every step.
            Presumably NHW2 (real/imag last) — TODO confirm.
        idx: Unused.

    Returns:
        dict with keys:
            ``'output'``: list of per-step reconstructions;
            ``'mask'``: the mask after the last step (detached);
            ``'zero_filled_recon'``: list of per-step ``'zero_recon'`` values;
            ``'uncertainty_maps'``: list of per-step uncertainty maps.
    """
    old_mask = self._init_mask(full_kspace)
    pred_kspace = self._init_kspace(full_kspace, old_mask)
    old_recon = torch.zeros_like(target)
    uncertainty_map = torch.zeros_like(old_mask)

    reconstructions = []
    zero_filled = []
    uncertainty_maps = []
    for i in range(self.num_step):
        pred_dict = self.step_forward(full_kspace, pred_kspace, old_mask, old_recon, target, i, uncertainty_map)
        new_img = pred_dict['output']
        old_recon = new_img
        old_mask = pred_dict['mask'].detach()
        uncertainty_map = pred_dict['uncertainty_map']

        reconstructions.append(new_img)
        zero_filled.append(pred_dict['zero_recon'])
        uncertainty_maps.append(uncertainty_map)

        # Transform back to k-space: NCHW -> NHWC with a zero imaginary part.
        # Assumes new_img has a single channel so the last dim is 2 (real, imag).
        image = torch.cat([new_img, torch.zeros_like(new_img)], dim=1).permute(0, 2, 3, 1)
        image_for_kspace = fastmri.ifftshift(image, dim=(1, 2))
        # Equivalent of the legacy ``Tensor.fft(2, normalized=False)`` (removed
        # in PyTorch 1.8): an unnormalised forward 2-D FFT over H and W.
        pred_kspace = torch.view_as_real(
            torch.fft.fft2(torch.view_as_complex(image_for_kspace.contiguous()), dim=(1, 2))
        )

    pred_dict = {
        'output': reconstructions,
        'mask': old_mask,
        'zero_filled_recon': zero_filled,
        'uncertainty_maps': uncertainty_maps,
    }
    return pred_dict
def _init_kspace(self, data, mask):
    """Build the initial k-space estimate from the masked measurements.

    Steps:
      * apply the mask to ``data``;
      * reconstruct an image with ``self.reconstructor`` from the inverse FFT
        of the masked k-space;
      * map that reconstruction back to k-space.

    Args:
        data: NHWC k-space (input image in k-space). Presumably C == 2
            (real, imag) — TODO confirm.
        mask: NCHW sampling mask.

    Returns:
        NHW2 k-space of the initial reconstruction.
    """
    kspace = data * mask.permute(0, 2, 3, 1)
    init_img = transforms.fftshift(transforms.ifft2(kspace), dim=(1, 2)).permute(0, -1, 1, 2)
    recon = self.reconstructor(init_img, None, mask.detach())
    # NCHW -> NHWC with a zero imaginary part. Assumes recon has a single
    # channel so the last dim is 2 (real, imag).
    image = torch.cat([recon, torch.zeros_like(recon)], dim=1).permute(0, 2, 3, 1)
    image_for_kspace = fastmri.ifftshift(image, dim=(1, 2))
    # Equivalent of the legacy ``Tensor.fft(2, normalized=False)`` (removed in
    # PyTorch 1.8): an unnormalised forward 2-D FFT over H and W.
    pred_kspace = torch.view_as_real(
        torch.fft.fft2(torch.view_as_complex(image_for_kspace.contiguous()), dim=(1, 2))
    )
    return pred_kspace
def test_ifftshift(shape):
    """Check that ``fastmri.ifftshift`` matches ``np.fft.ifftshift`` for an array of ``shape``."""
    # ``np.prod`` replaces ``np.product``, which was deprecated in NumPy 1.25
    # and removed in NumPy 2.0.
    x = np.arange(np.prod(shape)).reshape(shape)
    out_torch = fastmri.ifftshift(torch.from_numpy(x)).numpy()
    out_numpy = np.fft.ifftshift(x)
    assert np.allclose(out_torch, out_numpy)
def sample_low_frequency_mask(
    mask_args: Dict[str, Any],
    kspace_shapes: List[Tuple[int, ...]],
    rng: np.random.RandomState,
    attrs: Optional[List[Dict[str, Any]]] = None,
) -> torch.Tensor:
    """Samples low frequency masks.

    Each returned mask activates some number of the lowest k-space frequency
    columns. That number may differ between masks in the batch and may be
    random, depending on ``mask_args``. Active columns are 1s and inactive
    columns are 0s.

    ``mask_args`` is a dictionary with the following keys:

        - *"max_width"(int)*: The number of columns of the returned mask.
        - *"min_cols"(int)*: The minimum number of low-frequency columns to
          activate per side.
        - *"max_cols"(int)*: The maximum number of low-frequency columns to
          activate per side (inclusive).
        - *"width_dim"(int)*: Which dimension of each entry of
          ``kspace_shapes`` is the k-space width.
        - *"centered"(bool)*: Whether the low frequencies are in the center
          of the k-space (``True``) or at the edges (``False``). In the
          latter case the effective region, including any padding columns,
          is passed through ``fastmri.ifftshift``.
        - *"apply_attrs_padding"(optional(bool))*: If ``True`` (and ``attrs``
          is given), columns ``[0, attrs[i]["padding_left"])`` and
          ``[attrs[i]["padding_right"], width_i)`` are also set to 1.

    For every batch element ``i``, a per-side count ``n_i`` is drawn from
    ``[mask_args["min_cols"], mask_args["max_cols"]]`` (inclusive). A
    contiguous, roughly centered block of ``2 * n_i`` columns is then set to
    1 inside the element's effective region.

    ``batch_size = len(kspace_shapes)``. The i-th mask has
    ``width_i = kspace_shapes[i][mask_args["width_dim"]]`` *effective*
    columns. The returned tensor always has ``mask_args["max_width"]``
    columns. Any columns past ``width_i`` are filled with 1s.

    The result has ``mask_args["width_dim"] + 2`` dimensions, with shape
    ``[batch_size, 1, ..., 1, mask_args["max_width"]]``. For example, with
    ``mask_args["width_dim"] = 1`` and ``mask_args["max_width"] = 368`` the
    shape is ``[batch_size, 1, 368]``.

    Args:
        mask_args(dict(str,any)): Mask configuration, as described above.
        kspace_shapes(list(tuple(int,...))): Shapes of the k-space data the
            masks will be applied to, one per batch element.
        rng(``np.random.RandomState``): Random number generator used to draw
            the per-side column counts.
        attrs(optional(list(dict(str,any)))): One dict per batch element,
            read only when ``mask_args["apply_attrs_padding"]`` is true. Each
            dict must contain ``"padding_left"`` and ``"padding_right"``.

    Returns:
        ``torch.Tensor``: The generated low frequency masks.
    """
    width_dim = mask_args["width_dim"]
    max_width = mask_args["max_width"]
    batch_size = len(kspace_shapes)

    widths = [shape[width_dim] for shape in kspace_shapes]
    mask = torch.zeros(batch_size, max_width)
    per_side = rng.randint(mask_args["min_cols"], mask_args["max_cols"] + 1, size=batch_size)
    use_attrs_padding = bool(attrs) and mask_args.get("apply_attrs_padding", False)

    for row, (width, n_low) in enumerate(zip(widths, per_side)):
        # When padding is requested, the padded high-frequency edges are
        # marked active in addition to the low-frequency block.
        if use_attrs_padding:
            left_edge = attrs[row]["padding_left"]
            right_edge = attrs[row]["padding_right"]
        else:
            left_edge, right_edge = 0, width

        row_mask = mask[row]  # view: writes go straight into ``mask``
        start = (width - 2 * n_low + 1) // 2
        row_mask[start:start + 2 * n_low] = 1
        row_mask[:left_edge] = 1
        row_mask[right_edge:width] = 1
        if not mask_args["centered"]:
            row_mask[:width] = fastmri.ifftshift(row_mask[:width])
        # Columns beyond this element's true width are always active.
        row_mask[width:max_width] = 1

    view_shape = [batch_size] + [1] * (width_dim + 1)
    view_shape[width_dim + 1] = max_width
    return mask.view(*view_shape)
def loss(self, pred_dict, target_dict, meta, loss_type):
    """Compute the training loss and a dict of scalar values for logging.

    Args:
        pred_dict (dict): Model outputs with keys
            ``'output'``: list of per-step reconstructions. The last entry
                is used for the reconstruction and k-space losses.
            ``'zero_filled_recon'``: list of per-step zero-filled
                reconstructions. The last entry is used for ``zero_loss``.
            ``'uncertainty_maps'``: list aligned with ``'output'``. Only
                read when ``self.with_uncertainty`` is set.
        target_dict (dict): Ground truth with keys
            ``'target'``: the fully sampled image;
            ``'label'``: class labels, only used by ``'xentropy'``;
            ``'kspace'``: the fully sampled k-space. Presumably the last dim
                is 2 (real, imag) — TODO confirm.
        meta (dict): Loss weights ``'recon_weight'`` and ``'kspace_weight'``,
            plus ``'uncertainty_weight'`` when ``self.with_uncertainty`` is set.
        loss_type (str): One of ``'l1'``, ``'ssim'``, ``'psnr'``, ``'xentropy'``.

    Returns:
        tuple: ``(loss, log_dict)``. ``loss`` is
        ``recon_weight * reconstruction_loss + kspace_weight * kspace_loss``,
        plus ``uncertainty_weight * nll_loss`` when ``self.with_uncertainty``
        is set. ``log_dict`` maps names to Python scalars. Its
        ``'Total Loss'`` equals the returned ``loss``.

    Raises:
        NotImplementedError: If ``loss_type`` is not supported.
    """
    target = target_dict['target']
    label = target_dict['label']
    pred = pred_dict['output'][-1]
    zero_filled = pred_dict['zero_filled_recon'][-1]
    gt_kspace = target_dict['kspace']

    # Gaussian NLL summed over every step's reconstruction.
    nll_loss = 0
    if self.with_uncertainty:
        for i, step_pred in enumerate(pred_dict['output']):
            uncertainty_map = pred_dict['uncertainty_maps'][i]
            nll_loss += torch.mean(
                compute_gaussian_nll_loss(step_pred, target, uncertainty_map))

    # SSIM/PSNR are similarities, so they are negated to make them losses.
    if loss_type == 'l1':
        # ``reduction='mean'`` replaces the deprecated ``size_average=True``.
        reconstruction_loss = F.l1_loss(pred, target, reduction='mean')
        zero_loss = F.l1_loss(zero_filled, target, reduction='mean')
    elif loss_type == 'ssim':
        reconstruction_loss = -torch.mean(compute_ssim_torch(pred, target))
        zero_loss = -torch.mean(compute_ssim_torch(zero_filled, target))
    elif loss_type == 'psnr':
        reconstruction_loss = -torch.mean(compute_psnr_torch(pred, target))
        zero_loss = -torch.mean(compute_psnr_torch(zero_filled, target))
    elif loss_type == 'xentropy':
        criterion = nn.CrossEntropyLoss()
        reconstruction_loss = criterion(pred, label)
        zero_loss = torch.from_numpy(np.array([0]))
    else:
        raise NotImplementedError

    # k-space loss: map the reconstruction to k-space (NCHW -> NHWC with a
    # zero imaginary part, ifftshift, unnormalised 2-D FFT). Then compare
    # k-space magnitudes with SSIM. Assumes pred has a single channel so the
    # last dim is 2.
    image = torch.cat([pred, torch.zeros_like(pred)], dim=1).permute(0, 2, 3, 1)
    image_for_kspace = fastmri.ifftshift(image, dim=(1, 2))
    # Equivalent of the legacy ``Tensor.fft(2, normalized=False)`` (removed in
    # PyTorch 1.8).
    pred_kspace = torch.view_as_real(
        torch.fft.fft2(torch.view_as_complex(image_for_kspace.contiguous()), dim=(1, 2))
    )
    pred_kspace = torch.norm(pred_kspace, dim=-1, keepdim=True)
    gt_kspace = torch.norm(gt_kspace, dim=-1, keepdim=True)
    pred_kspace = pred_kspace.permute(0, 3, 1, 2)
    gt_kspace = gt_kspace.permute(0, 3, 1, 2)
    kspace_loss = -torch.mean(compute_ssim_torch(pred_kspace, gt_kspace))

    loss = reconstruction_loss * meta['recon_weight'] + kspace_loss * meta['kspace_weight']
    if self.with_uncertainty:
        loss += nll_loss * meta['uncertainty_weight']

    # Built after all terms are added so 'Total Loss' matches the returned loss.
    log_dict = {
        'Total Loss': loss.item(),
        'Zero Filled Loss': zero_loss.item(),
        'K-space Loss': kspace_loss.item(),
        'Recon loss': reconstruction_loss.item()
    }
    if self.with_uncertainty:
        log_dict.update({'Uncertainty Loss': nll_loss.item()})
    return loss, log_dict