Example #1
    def __getitem__(self, i):
        fname, slice_id = self.examples[i]
        with h5py.File(fname, 'r') as data:

            kspace = data["kspace"][slice_id]
            kspace = np.stack([kspace.real, kspace.imag], axis=-1)
            if self.random_rotate:
                kspace = ndimage.rotate(kspace,
                                        self.random_angles[i],
                                        reshape=False,
                                        mode='nearest')

            kspace = torch.from_numpy(kspace).permute(2, 0, 1)
            kspace = self.center_crop(kspace,
                                      self.image_shape).permute(1, 2, 0)

            kspace = fastmri.ifftshift(kspace, dim=(0, 1))

            target = torch.ifft(kspace, 2, normalized=False)
            target = fastmri.ifftshift(target, dim=(0, 1))

            # Normalize using mean of k-space in training data
            target /= 7.072103529760345e-07
            kspace /= 7.072103529760345e-07

            kspace = kspace.numpy()
            target = target.numpy()

            return self.transform(kspace, torch.zeros(kspace.shape[1]), target,
                                  dict(data.attrs), fname.name, slice_id)
Example #2
    def __getitem__(self, i):
        fname, slice_id = self.examples[i]
        with h5py.File(fname, "r") as data:
            kspace = data["kspace"][slice_id]
            kspace = torch.from_numpy(np.stack([kspace.real, kspace.imag], axis=-1))
            kspace = fastmri.ifftshift(kspace, dim=(0, 1))
            # Equivalent of the legacy ``torch.ifft(kspace, 2, normalized=False)``:
            # a 2D inverse FFT over (H, W), with real/imag kept in the last dimension.
            target = torch.view_as_real(
                torch.fft.ifftn(torch.view_as_complex(kspace.contiguous()), dim=(0, 1))
            )
            target = fastmri.ifftshift(target, dim=(0, 1))
            # Normalize using mean of k-space in training data
            target /= 7.072103529760345e-07
            kspace /= 7.072103529760345e-07

            # Environment expects numpy arrays. The code above was used with an older
            # version of the environment to generate the results of the MICCAI'20 paper.
            # So, to keep this consistent with the version in the paper, we convert
            # the tensors back to numpy rather than changing the original code.
            kspace = kspace.numpy()
            target = target.numpy()
            return self.transform(
                kspace,
                torch.zeros(kspace.shape[1]),
                target,
                dict(data.attrs),
                fname.name,
                slice_id,
            )
Example #3
def ifft_permute_maybe_shift(x: torch.Tensor,
                             normalized: bool = False,
                             ifft_shift: bool = False) -> torch.Tensor:
    x = x.permute(0, 2, 3, 1)
    y = torch.ifft(x, 2, normalized=normalized)
    if ifft_shift:
        y = fastmri.ifftshift(y, dim=(1, 2))
    return y.permute(0, 3, 1, 2)
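
Note that ``torch.ifft`` (used in Example #3) was removed in PyTorch 1.8 in favor of the ``torch.fft`` module. Below is a minimal sketch of an equivalent helper for newer PyTorch versions, assuming the same NCHW layout with real and imaginary parts stacked in the channel dimension; the ``_v2`` name is illustrative and not part of fastMRI or the original repository.

import torch
import fastmri


def ifft_permute_maybe_shift_v2(x: torch.Tensor,
                                normalized: bool = False,
                                ifft_shift: bool = False) -> torch.Tensor:
    # x: (N, 2, H, W), channel 0 = real part, channel 1 = imaginary part.
    x = x.permute(0, 2, 3, 1).contiguous()        # -> (N, H, W, 2)
    x = torch.view_as_complex(x)                  # -> complex tensor (N, H, W)
    norm = "ortho" if normalized else "backward"  # maps the legacy `normalized` flag
    y = torch.view_as_real(torch.fft.ifftn(x, dim=(1, 2), norm=norm))
    if ifft_shift:
        y = fastmri.ifftshift(y, dim=(1, 2))
    return y.permute(0, 3, 1, 2)                  # back to (N, 2, H, W)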
Example #4
    def forward(self, target, full_kspace, idx=None):
        """
        Args:
            input (torch.Tensor): Input tensor of shape NHWC
        Returns:
            (torch.Tensor): Output tensor of shape NCHW
        """
        old_mask = self._init_mask(full_kspace)
        pred_kspace = self._init_kspace(full_kspace, old_mask)

        old_recon = torch.zeros_like(target)
        uncertainty_map = torch.zeros_like(old_mask)

        reconstructions = []
        zero_filled = []
        uncertainty_maps = []

        for i in range(self.num_step):
            pred_dict = self.step_forward(full_kspace, pred_kspace, old_mask,
                                          old_recon, target, i,
                                          uncertainty_map)

            new_img, old_recon, old_mask, uncertainty_map = (
                pred_dict['output'], pred_dict['output'],
                pred_dict['mask'].detach(), pred_dict['uncertainty_map'])

            reconstructions.append(new_img)
            zero_filled.append(pred_dict['zero_recon'])
            uncertainty_maps.append(uncertainty_map)

            # transform back to kspace
            # NCHW -> NHWC
            image = torch.cat([new_img, torch.zeros_like(new_img)],
                              dim=1).permute(0, 2, 3, 1)
            image_for_kspace = fastmri.ifftshift(image, dim=(1, 2))
            pred_kspace = image_for_kspace.fft(2, normalized=False)

        pred_dict = {
            'output': reconstructions,
            'mask': old_mask,
            'zero_filled_recon': zero_filled,
            'uncertainty_maps': uncertainty_maps
        }

        return pred_dict
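
The loop above maps each reconstruction back to k-space with the legacy ``Tensor.fft(2, normalized=False)`` call, which was also removed in PyTorch 1.8. Below is a minimal sketch of the same image-to-k-space step using the ``torch.fft`` module, assuming the reconstruction is a single-channel NCHW tensor treated as the real part and k-space is stored as an NHW2 real/imaginary pair; the helper name is illustrative.

import torch
import fastmri


def image_to_kspace(new_img: torch.Tensor) -> torch.Tensor:
    # new_img: (N, 1, H, W) real-valued reconstruction; the imaginary part is zero.
    image = torch.cat([new_img, torch.zeros_like(new_img)], dim=1)  # (N, 2, H, W)
    image = image.permute(0, 2, 3, 1)                               # (N, H, W, 2)
    image = fastmri.ifftshift(image, dim=(1, 2)).contiguous()
    # Unnormalized forward 2D FFT over (H, W), matching `fft(2, normalized=False)`.
    kspace = torch.fft.fftn(torch.view_as_complex(image), dim=(1, 2), norm="backward")
    return torch.view_as_real(kspace)                               # (N, H, W, 2)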
Example #5
    def _init_kspace(self, data, mask):
        """
            data: NHWC (input image in kspace)
            mask: NCHW
            kspace: NHW2

        """
        kspace = data * mask.permute(0, 2, 3, 1)
        init_img = transforms.fftshift(transforms.ifft2(kspace),
                                       dim=(1, 2)).permute(0, -1, 1, 2)
        recon = self.reconstructor(init_img, None, mask.detach())

        image = torch.cat([recon, torch.zeros_like(recon)],
                          dim=1).permute(0, 2, 3, 1)
        image_for_kspace = fastmri.ifftshift(image, dim=(1, 2))
        pred_kspace = image_for_kspace.fft(2, normalized=False)

        return pred_kspace
Example #6
def test_ifftshift(shape):
    x = np.arange(np.prod(shape)).reshape(shape)
    out_torch = fastmri.ifftshift(torch.from_numpy(x)).numpy()
    out_numpy = np.fft.ifftshift(x)

    assert np.allclose(out_torch, out_numpy)
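
For intuition about what this test verifies: ``ifftshift`` is the inverse of ``fftshift``, and the two differ on odd-length axes, so matching NumPy's convention exactly matters. A small self-contained illustration (the values are arbitrary):

import numpy as np
import torch
import fastmri

x = np.arange(5)
print(np.fft.fftshift(x))    # [3 4 0 1 2]
print(np.fft.ifftshift(x))   # [2 3 4 0 1]  -- differs from fftshift for odd lengths

t = torch.arange(5)
# fastmri.ifftshift follows the NumPy convention (shift by -(n // 2) on each axis).
assert np.array_equal(fastmri.ifftshift(t).numpy(), np.fft.ifftshift(x))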
Example #7
def sample_low_frequency_mask(
    mask_args: Dict[str, Any],
    kspace_shapes: List[Tuple[int, ...]],
    rng: np.random.RandomState,
    attrs: Optional[List[Dict[str, Any]]] = None,
) -> torch.Tensor:
    """Samples low frequency masks.

    Returns masks in which a number of the lowest k-space frequency columns are active.
    The number of active frequencies need not be the same for every mask in the batch, and
    it can be sampled at random, depending on the given ``mask_args``. Active columns are
    represented as 1s in the mask, and inactive columns as 0s.

    The distribution and shape of the masks can be controlled by ``mask_args``. This is a
    dictionary with the following keys:

        - *"max_width"(int)*: The maximum width of the masks.
        - *"min_cols"(int)*: The minimum number of low frequencies columns to activate per side.
        - *"max_cols"(int)*: The maximum number of low frequencies columns to activate
          per side (inclusive).
        - *"width_dim"(int)*: Indicates which of the dimensions in ``kspace_shapes``
          corresponds to the k-space width.
        - *"centered"(bool)*: Specifies if the low frequencies are in the center of the
          k-space (``True``) or on the edges (``False``).
        - *"apply_attrs_padding"(optional(bool))*: If ``True``, the function will read
          keys ``"padding_left"`` and ``"padding_right"`` from ``attrs`` and set all
          corresponding high-frequency columns to 1.

    The number of 1s in the effective region of the mask (see next paragraph) is sampled
    between ``mask_args["min_cols"]`` and ``mask_args["max_cols"]`` (inclusive).
    The number of dimensions for the mask tensor will be ``mask_args["width_dim"] + 2``.
    The size will be ``[batch_size, 1, ..., 1, mask_args["max_width"]]``. For example, with
    ``mask_args["width_dim"] = 1`` and ``mask_args["max_width"] = 368``, output tensor
    has shape ``[batch_size, 1, 368]``.

    This function supports simultaneously sampling masks for k-spaces with different numbers
    of columns. This is controlled by argument ``kspace_shapes``. From this list, the function
    obtains 1) ``batch_size = len(kspace_shapes)``, and 2) the width of the k-space for each
    element in the batch. The i-th mask will have ``kspace_shapes[i][mask_args["width_dim"]]``
    *effective* columns.


    Note:
        The returned mask tensor always has ``mask_args["max_width"]`` columns. However, for
        any element ``i`` such that
        ``kspace_shapes[i][mask_args["width_dim"]] < mask_args["max_width"]``, the function
        pads the extra k-space columns with 1s. The rest of the columns are filled out as if
        the mask had the width indicated by ``kspace_shapes[i]``.

    Args:
        mask_args(dict(str,any)): Specifies configuration options for the masks, as explained
            above.

        kspace_shapes(list(tuple(int,...))): Specifies the shapes of the k-space data on
            which this mask will be applied, as explained above.

        rng(``np.random.RandomState``): A random number generator to sample the masks.

        attrs(list(dict(str,any)), optional): Used to determine any high-frequency padding.
            When ``mask_args["apply_attrs_padding"]`` is ``True``, each entry must contain
            keys ``"padding_left"`` and ``"padding_right"``.

    Returns:
        ``torch.Tensor``: The generated low frequency masks.

    """
    batch_size = len(kspace_shapes)
    num_cols = [shape[mask_args["width_dim"]] for shape in kspace_shapes]
    mask = torch.zeros(batch_size, mask_args["max_width"])
    num_low_freqs = rng.randint(mask_args["min_cols"],
                                mask_args["max_cols"] + 1,
                                size=batch_size)
    for i in range(batch_size):
        # If padding needs to be accounted for, only add low frequency lines
        # beyond the padding
        if attrs and mask_args.get("apply_attrs_padding", False):
            padding_left = attrs[i]["padding_left"]
            padding_right = attrs[i]["padding_right"]
        else:
            padding_left, padding_right = 0, num_cols[i]

        pad = (num_cols[i] - 2 * num_low_freqs[i] + 1) // 2
        mask[i, pad:pad + 2 * num_low_freqs[i]] = 1
        mask[i, :padding_left] = 1
        mask[i, padding_right:num_cols[i]] = 1

        if not mask_args["centered"]:
            mask[i, :num_cols[i]] = fastmri.ifftshift(mask[i, :num_cols[i]])
        mask[i, num_cols[i]:mask_args["max_width"]] = 1

    mask_shape = [batch_size] + [1] * (mask_args["width_dim"] + 1)
    mask_shape[mask_args["width_dim"] + 1] = mask_args["max_width"]
    return mask.view(*mask_shape)
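
A hedged usage sketch for the function above, with an illustrative configuration built from the docstring; the specific values are arbitrary, and ``sample_low_frequency_mask`` is assumed to be defined or importable in scope.

import numpy as np

# Illustrative configuration; keys follow the docstring above.
mask_args = {
    "max_width": 368,  # every returned mask has 368 columns
    "min_cols": 8,     # at least 8 low-frequency columns per side
    "max_cols": 12,    # at most 12 low-frequency columns per side (inclusive)
    "width_dim": 1,    # the k-space width is the second entry of each shape
    "centered": True,  # low frequencies sit in the center of k-space
}
# Two k-space slices of different widths; the narrower one gets padded with 1s.
kspace_shapes = [(640, 368), (640, 322)]
rng = np.random.RandomState(0)

mask = sample_low_frequency_mask(mask_args, kspace_shapes, rng)
print(mask.shape)        # torch.Size([2, 1, 368]), i.e. width_dim + 2 dimensions
print(mask[1, 0, 322:])  # all 1s: padding beyond the second k-space's width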
Example #8
    def loss(self, pred_dict, target_dict, meta, loss_type):
        """
        Args:
            pred_dict:
                output: per-step reconstructions from the downsampled k-space measurement, NCHW
                zero_filled_recon: per-step zero-filled reconstructions, NCHW
                uncertainty_maps: per-step uncertainty maps (when uncertainty is enabled)
                energy: negative entropy of the probability mask
                mask: the binarized sampling mask (used for visualization)

            target_dict:
                target: original fully sampled image, NCHW
                label: classification label (used for the 'xentropy' loss)
                kspace: fully sampled ground-truth k-space

            meta:
                recon_weight: weight of the reconstruction loss
                kspace_weight: weight of the k-space SSIM loss
                uncertainty_weight: weight of the Gaussian NLL loss (when uncertainty is enabled)
                entropy_weight: weight of the entropy loss (to encourage exploration)

            loss_type: one of 'l1', 'ssim', 'psnr', or 'xentropy'
        """
        target = target_dict['target']
        label = target_dict['label']
        pred = pred_dict['output'][-1]
        zero_filled = pred_dict['zero_filled_recon'][-1]
        gt_kspace = target_dict['kspace']

        nll_loss = 0
        if self.with_uncertainty:
            for i in range(len(pred_dict['output'])):
                pred = pred_dict['output'][i]
                uncertainty_map = pred_dict['uncertainty_maps'][i]
                nll_loss += torch.mean(
                    compute_gaussian_nll_loss(pred, target, uncertainty_map))

        if loss_type == 'l1':
            reconstruction_loss = F.l1_loss(pred, target, reduction='mean')
            zero_loss = F.l1_loss(zero_filled, target, reduction='mean')
        elif loss_type == 'ssim':
            reconstruction_loss = -torch.mean(compute_ssim_torch(pred, target))
            zero_loss = -torch.mean(compute_ssim_torch(zero_filled, target))
        elif loss_type == 'psnr':
            reconstruction_loss = -torch.mean(compute_psnr_torch(pred, target))
            zero_loss = -torch.mean(compute_psnr_torch(zero_filled, target))
        elif loss_type == 'xentropy':
            criterion = nn.CrossEntropyLoss()
            reconstruction_loss = criterion(pred, label)
            zero_loss = torch.from_numpy(np.array([0]))
        else:
            raise NotImplementedError

        # k-space loss
        image = torch.cat([pred, torch.zeros_like(pred)],
                          dim=1).permute(0, 2, 3, 1)
        image_for_kspace = fastmri.ifftshift(image, dim=(1, 2))
        pred_kspace = image_for_kspace.fft(2, normalized=False)

        pred_kspace = torch.norm(pred_kspace, dim=-1, keepdim=True)
        gt_kspace = torch.norm(gt_kspace, dim=-1, keepdim=True)
        pred_kspace = pred_kspace.permute(0, 3, 1, 2)
        gt_kspace = gt_kspace.permute(0, 3, 1, 2)

        kspace_loss = -torch.mean(compute_ssim_torch(pred_kspace, gt_kspace))

        loss = reconstruction_loss * meta['recon_weight'] + kspace_loss * meta[
            'kspace_weight']  # + 10*zero_loss

        log_dict = {
            'Total Loss': loss.item(),
            'Zero Filled Loss': zero_loss.item(),
            'K-space Loss': kspace_loss.item(),
            'Recon loss': reconstruction_loss.item()
        }

        if self.with_uncertainty:
            loss += nll_loss * meta['uncertainty_weight']
            log_dict.update({'Uncertainty Loss': nll_loss.item()})

        return loss, log_dict