def save_zero_filled(data_dir, out_dir, which_challenge):
    """Write zero-filled reconstructions for every ``*.h5`` file in ``data_dir``.

    For each file the k-space is inverse Fourier transformed, center-cropped
    to the matrix size read from the ISMRMRD header
    (``encoding/encodedSpace/matrixSize``), reduced to magnitude and, for the
    multicoil challenge, coil-combined with root-sum-of-squares over dim 1.

    Args:
        data_dir: Directory (``pathlib.Path``-like) searched with
            ``glob("*.h5")``.
        out_dir: Destination handed to ``fastmri.save_reconstructions``.
        which_challenge: ``"multicoil"`` enables RSS coil combination; any
            other value keeps the per-slice magnitude images unchanged.
    """
    # Header path is loop-invariant; ``+`` below builds a fresh list each time.
    size_path = ["encoding", "encodedSpace", "matrixSize"]
    outputs = {}

    for path in tqdm(list(data_dir.glob("*.h5"))):
        with h5py.File(path, "r") as h5:
            header = etree.fromstring(h5["ismrmrd_header"][()])
            kspace = transforms.to_tensor(h5["kspace"][()])
            crop_x = int(et_query(header, size_path + ["x"]))
            crop_y = int(et_query(header, size_path + ["y"]))

        zero_filled = fastmri.ifft2c(kspace)

        # "FLAIR 203" case: when dim -2 of the image is smaller than the
        # requested crop, fall back to a square crop of that size.
        available = zero_filled.shape[-2]
        if available < crop_y:
            crop_x = crop_y = available

        cropped = transforms.complex_center_crop(zero_filled, (crop_x, crop_y))
        magnitude = fastmri.complex_abs(cropped)

        if which_challenge == "multicoil":
            magnitude = fastmri.rss(magnitude, dim=1)

        outputs[path.name] = magnitude

    fastmri.save_reconstructions(outputs, out_dir)
Exemple #2
0
def save_zero_filled(data_dir, out_dir, which_challenge):
    """Write zero-filled reconstructions for the ``*.h5`` files in ``data_dir``.

    Only entries matching ``*.h5`` are processed. Other directory entries
    (non-HDF5 files, subdirectories) are skipped rather than handed to
    ``h5py.File``, which would raise and abort the whole run.

    For each file the k-space is inverse Fourier transformed, center-cropped
    to ``encoding[0].reconSpace.matrixSize`` from the ISMRMRD header, reduced
    to magnitude and, for the multicoil challenge, coil-combined with
    root-sum-of-squares over dim 1.

    Args:
        data_dir: Directory (``pathlib.Path``-like) containing the HDF5 files.
        out_dir: Destination handed to ``fastmri.save_reconstructions``.
        which_challenge: ``"multicoil"`` enables RSS coil combination.
    """
    reconstructions = {}

    # sorted() gives a deterministic processing order across filesystems.
    for f in sorted(data_dir.glob("*.h5")):
        with h5py.File(f, "r") as hf:
            enc = ismrmrd.xsd.CreateFromDocument(hf["ismrmrd_header"][()]).encoding[0]
            masked_kspace = transforms.to_tensor(hf["kspace"][()])

        # extract target image width, height from ismrmrd header
        crop_size = (enc.reconSpace.matrixSize.x, enc.reconSpace.matrixSize.y)

        # inverse Fourier Transform to get zero filled solution
        image = fastmri.ifft2c(masked_kspace)

        # check for FLAIR 203: if dim -2 is smaller than the requested crop,
        # fall back to a square crop of that size
        if image.shape[-2] < crop_size[1]:
            crop_size = (image.shape[-2], image.shape[-2])

        # crop input image
        image = transforms.complex_center_crop(image, crop_size)

        # absolute value
        image = fastmri.complex_abs(image)

        # apply Root-Sum-of-Squares if multicoil data
        if which_challenge == "multicoil":
            image = fastmri.rss(image, dim=1)

        reconstructions[f.name] = image

    fastmri.save_reconstructions(reconstructions, out_dir)
Exemple #3
0
    def forward(self, masked_kspace, mask):
        """Reconstruct an image from undersampled k-space.

        Sensitivity maps are estimated once by ``self.sens_net`` and shared by
        every cascade. Each cascade refines the running k-space estimate given
        the measured ``masked_kspace`` and the sampling ``mask``.

        Returns:
            Root-sum-of-squares over dim 1 of the magnitude of the inverse
            FFT of the final k-space estimate.
        """
        sens_maps = self.sens_net(masked_kspace, mask)

        # Start from a copy so the measured k-space is never modified.
        estimate = masked_kspace.clone()
        for step in self.cascades:
            estimate = step(estimate, masked_kspace, mask, sens_maps)

        magnitudes = fastmri.complex_abs(fastmri.ifft2c(estimate))
        return fastmri.rss(magnitudes, dim=1)
Exemple #4
0
    def forward(
        self,
        masked_kspace: torch.Tensor,
        mask: torch.Tensor,
        num_low_frequencies: Optional[int] = None,
    ) -> torch.Tensor:
        """Reconstruct an image from undersampled k-space.

        Args:
            masked_kspace: Measured (undersampled) k-space.
            mask: Sampling mask matching ``masked_kspace``.
            num_low_frequencies: Forwarded unchanged to ``self.sens_net``.

        Returns:
            Root-sum-of-squares over dim 1 of the magnitude of the inverse
            FFT of the k-space estimate produced by the last cascade.
        """
        sens_maps = self.sens_net(masked_kspace, mask, num_low_frequencies)

        # Start from a copy so the measured k-space is never modified.
        estimate = masked_kspace.clone()
        for step in self.cascades:
            estimate = step(estimate, masked_kspace, mask, sens_maps)

        magnitudes = fastmri.complex_abs(fastmri.ifft2c(estimate))
        return fastmri.rss(magnitudes, dim=1)
def visualize_reconstruction(filepath, output_filepath=None):
    """Show or save a root-sum-of-squares view of a stored reconstruction.

    Reads the ``reconstruction`` dataset from the HDF5 file at ``filepath``,
    squeezes singleton dimensions and combines along dim 0 with RSS. The
    magnitude is drawn in grayscale.

    Args:
        filepath: Path to an HDF5 file containing a ``reconstruction`` dataset.
        output_filepath: Optional ``str`` or ``pathlib.Path``. When given, the
            image is saved there without axes (parent directories are created
            as needed) and the figure is closed afterwards. When ``None``, the
            image is displayed with ``plt.show()``.
    """
    from pathlib import Path

    # Open read-only and close the handle as soon as the data is in memory.
    with h5py.File(filepath, "r") as hf:
        recons = hf["reconstruction"][()].squeeze()

    # NOTE(review): RSS is taken over dim 0 of the squeezed array; confirm
    # that dim 0 is the axis meant to be combined for these files.
    recons_rss = fastmri.rss(T.to_tensor(recons), dim=0)
    img = np.abs(recons_rss.numpy())

    plt.imshow(img, cmap='gray')
    if output_filepath is None:
        plt.show()
        return

    output_filepath = Path(output_filepath)
    # exist_ok avoids the race between an existence check and mkdir.
    output_filepath.parent.mkdir(parents=True, exist_ok=True)
    plt.axis("off")
    plt.savefig(output_filepath, bbox_inches="tight", pad_inches=0)
    # Close the figure so later calls do not draw over this image.
    plt.close()
Exemple #6
0
def _base_fastmri_unet_transform(
    kspace,
    mask,
    ground_truth,
    attrs,
    which_challenge="singlecoil",
):
    """Build the normalized zero-filled U-Net input for one k-space slice.

    Args:
        kspace: Raw k-space, converted with ``fastmri_transforms.to_tensor``.
        mask: Sampling mask. It is trimmed along its last dim to
            ``kspace.shape[-2]`` (after conversion) and broadcast over the
            trailing real/imaginary dim.
        ground_truth: Optional target whose last two dims set the crop size.
        attrs: Mapping providing ``"recon_size"`` when ``ground_truth`` is None.
        which_challenge: ``"multicoil"`` enables RSS coil combination (dim 0).

    Returns:
        ``(image, mean, std)``: the instance-normalized image clamped to
        [-6, 6] with a leading singleton dim, plus the normalization stats.
    """
    kspace_t = fastmri_transforms.to_tensor(kspace)

    # accounting for variable size masks: trim to this slice's k-space width
    trimmed_mask = mask[..., :kspace_t.shape[-2]]
    undersampled = kspace_t * trimmed_mask.unsqueeze(-1) + 0.0

    zero_filled = fastmri.ifft2c(undersampled)

    if ground_truth is None:
        crop_size = (attrs["recon_size"][0], attrs["recon_size"][1])
    else:
        crop_size = (ground_truth.shape[-2], ground_truth.shape[-1])

    # "FLAIR 203" case: if dim -2 of the image is smaller than the requested
    # crop, fall back to a square crop of that size.
    available = zero_filled.shape[-2]
    if available < crop_size[1]:
        crop_size = (available, available)

    # noinspection PyTypeChecker
    cropped = fastmri_transforms.complex_center_crop(zero_filled, crop_size)
    magnitude = fastmri.complex_abs(cropped)

    if which_challenge == "multicoil":
        magnitude = fastmri.rss(magnitude)

    normalized, mean, std = fastmri_transforms.normalize_instance(
        magnitude, eps=1e-11
    )
    return normalized.clamp(-6, 6).unsqueeze(0), mean, std
def visualize_kspace(kspace, dim=None, crop=False, output_filepath=None):
    """Show or save the magnitude image reconstructed from k-space.

    Args:
        kspace: Tensor accepted by ``fastmri.ifft2c`` (presumably with a
            trailing real/imaginary dim of size 2 — confirm with callers).
        dim: If not ``None``, combine the magnitude images with
            root-sum-of-squares over this dimension.
        crop: If True, center-crop to a square whose side is ``shape[-2]`` of
            the complex image, then instance-normalize and clamp to [-6, 6].
        output_filepath: Optional ``str`` or ``pathlib.Path``. When given, the
            image is saved there without axes (parent directories are created
            as needed) and the figure is closed afterwards. When ``None``, the
            image is displayed with ``plt.show()``.
    """
    from pathlib import Path

    # Despite the parameter name, everything below operates in image space.
    image = fastmri.ifft2c(kspace)
    if crop:
        side = image.shape[-2]
        image = T.complex_center_crop(image, (side, side))
        image = fastmri.complex_abs(image)
        # Normalization stats are not needed for display.
        image, _, _ = T.normalize_instance(image, eps=1e-11)
        image = image.clamp(-6, 6)
    else:
        # Compute absolute value to get a real image
        image = fastmri.complex_abs(image)
    if dim is not None:
        image = fastmri.rss(image, dim=dim)
    img = np.abs(image.numpy())

    plt.imshow(img, cmap='gray')
    if output_filepath is None:
        plt.show()
        return

    output_filepath = Path(output_filepath)
    # exist_ok avoids the race between an existence check and mkdir.
    output_filepath.parent.mkdir(parents=True, exist_ok=True)
    plt.axis("off")
    plt.savefig(output_filepath, bbox_inches="tight", pad_inches=0)
    # Close the figure so later calls do not draw over this image.
    plt.close()
Exemple #8
0
    def __call__(
        self,
        kspace: np.ndarray,
        mask: np.ndarray,
        target: np.ndarray,
        attrs: Dict,
        fname: str,
        slice_num: int,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, str,
               int, float]:
        """
        Args:
            kspace: Input k-space of shape (num_coils, rows, cols) for
                multi-coil data or (rows, cols) for single coil data.
            mask: Mask from the test dataset. Not read here: when
                ``self.mask_func`` is set a fresh mask is generated, otherwise
                ``kspace`` is used as given.
            target: Target image, or None.
            attrs: Acquisition related information stored in the HDF5 object.
                ``"recon_size"`` is required when ``target`` is None;
                ``"max"`` is optional.
            fname: File name (also the mask seed source when
                ``self.use_seed`` is set).
            slice_num: Serial number of the slice.

        Returns:
            tuple containing:
                image: Zero-filled input image, instance-normalized and
                    clamped to [-6, 6].
                target: Target image converted to a torch.Tensor, normalized
                    with the image's mean/std and clamped to [-6, 6], or
                    ``torch.Tensor([0])`` when no target is available.
                mean: Mean value used for normalization.
                std: Standard deviation value used for normalization.
                fname: File name.
                slice_num: Serial number of the slice.
                max_value: ``attrs["max"]`` if present, else 0.0.
        """
        kspace = to_tensor(kspace)

        # "max" is optional in attrs
        max_value = attrs.get("max", 0.0)

        # apply mask
        if self.mask_func:
            seed = None if not self.use_seed else tuple(map(ord, fname))
            masked_kspace, mask = apply_mask(kspace, self.mask_func, seed)
        else:
            masked_kspace = kspace

        # inverse Fourier transform to get zero filled solution
        image = fastmri.ifft2c(masked_kspace)

        # crop input to correct size
        if target is not None:
            crop_size = (target.shape[-2], target.shape[-1])
        else:
            crop_size = (attrs["recon_size"][0], attrs["recon_size"][1])

        # check for FLAIR 203: if dim -2 is smaller than the requested crop,
        # fall back to a square crop of that size
        if image.shape[-2] < crop_size[1]:
            crop_size = (image.shape[-2], image.shape[-2])

        image = complex_center_crop(image, crop_size)

        # absolute value
        image = fastmri.complex_abs(image)

        # apply Root-Sum-of-Squares if multicoil data
        if self.which_challenge == "multicoil":
            image = fastmri.rss(image)

        # normalize input
        image, mean, std = normalize_instance(image, eps=1e-11)
        image = image.clamp(-6, 6)

        # normalize target with the input's statistics
        if target is not None:
            target = to_tensor(target)
            target = center_crop(target, crop_size)
            target = normalize(target, mean, std, eps=1e-11)
            target = target.clamp(-6, 6)
        else:
            target = torch.Tensor([0])

        return image, target, mean, std, fname, slice_num, max_value
Exemple #9
0
# SSIM loss
loss = fastmri.SSIMLoss()
# Sanity check: the image is compared with itself. unsqueeze(1) inserts a
# size-1 dim (presumably the channel dim SSIMLoss expects — confirm), and
# data_range is the image maximum reshaped to 1-D. If SSIMLoss is 1 - SSIM,
# the printed value should be ~0 (confirm against fastmri docs).
print(loss(slice_image_abs.unsqueeze(1), slice_image_abs.unsqueeze(1), data_range=slice_image_abs.max().reshape(-1)))
# In[15]:


# Display the selected coil image(s); [0] presumably selects coil index 0.
show_coils(slice_image_abs, [0], cmap='gray')


# As we can see, each coil in a multi-coil MRI scan focuses on a different region of the image. These coils can be combined into the full image using the Root-Sum-of-Squares (RSS) transform.

# In[16]:


# RSS over dim 0, the dimension holding the individual coil images here.
slice_image_rss = fastmri.rss(slice_image_abs, dim=0)


# In[17]:


plt.imshow(np.abs(slice_image_rss.numpy()), cmap='gray')


# So far, we have been looking at fully-sampled data. We can simulate under-sampled data by creating a mask and applying it to k-space.

# In[18]:


from fastmri.data.subsample import RandomMaskFunc
# center_fractions=[0.08] and accelerations=[4]: a single (center fraction,
# acceleration) setting of 0.08 and 4x.
mask_func = RandomMaskFunc(center_fractions=[0.08], accelerations=[4])  # Create the mask function object
Exemple #10
0
def test_root_sum_of_squares(shape, dim):
    """fastmri.rss must match a NumPy sqrt(sum(x**2)) reference along ``dim``."""
    data = create_input(shape)
    actual = fastmri.rss(data, dim).numpy()

    squares = data.numpy() ** 2
    expected = np.sqrt(squares.sum(axis=dim))

    assert np.allclose(actual, expected)
Exemple #11
0
    def __call__(self, kspace, mask, target, attrs, fname, slice_num):
        """Turn one k-space slice into a normalized zero-filled training sample.

        Args:
            kspace (numpy.array): Input k-space, converted with
                ``transforms.to_tensor``; (num_coils, rows, cols) for
                multi-coil data or (rows, cols) for single coil data
                (presumably complex-valued — confirm).
            mask (numpy.array): Mask from the test dataset. Not read here: a
                new mask is generated when ``self.mask_func`` is set,
                otherwise ``kspace`` is used as given.
            target (numpy.array): Target image, or None.
            attrs (dict): Acquisition related information stored in the HDF5
                object; ``"recon_size"`` is needed when ``target`` is None.
            fname (str): File name (also the mask seed source when
                ``self.use_seed`` is set).
            slice_num (int): Serial number of the slice.

        Returns:
            (tuple): tuple containing:
                image (torch.Tensor): Zero-filled input image, normalized and
                    clamped to [-6, 6].
                target (torch.Tensor): Target normalized with the image's
                    statistics, or ``torch.Tensor([0])`` without a target.
                mean (float): Mean value used for normalization.
                std (float): Standard deviation value used for normalization.
                fname (str): File name.
                slice_num (int): Serial number of the slice.
        """
        kspace_t = transforms.to_tensor(kspace)

        if not self.mask_func:
            undersampled = kspace_t
        else:
            seed = tuple(map(ord, fname)) if self.use_seed else None
            undersampled, _ = transforms.apply_mask(
                kspace_t, self.mask_func, seed)

        zero_filled = fastmri.ifft2c(undersampled)

        if target is None:
            crop_size = (attrs["recon_size"][0], attrs["recon_size"][1])
        else:
            crop_size = (target.shape[-2], target.shape[-1])

        # "FLAIR 203" case: if dim -2 of the image is smaller than the
        # requested crop, fall back to a square crop of that size.
        available = zero_filled.shape[-2]
        if available < crop_size[1]:
            crop_size = (available, available)

        cropped = transforms.complex_center_crop(zero_filled, crop_size)
        magnitude = fastmri.complex_abs(cropped)

        if self.which_challenge == "multicoil":
            magnitude = fastmri.rss(magnitude)

        image, mean, std = transforms.normalize_instance(magnitude, eps=1e-11)
        image = image.clamp(-6, 6)

        if target is None:
            target_t = torch.Tensor([0])
        else:
            target_t = transforms.center_crop(
                transforms.to_tensor(target), crop_size)
            target_t = transforms.normalize(target_t, mean, std, eps=1e-11)
            target_t = target_t.clamp(-6, 6)

        return image, target_t, mean, std, fname, slice_num
Exemple #12
0
    def __call__(
        self,
        kspace: np.ndarray,
        mask: np.ndarray,
        target: np.ndarray,
        attrs: Dict,
        fname: str,
        slice_num: int,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
               torch.Tensor, str, int, float]:
        """Produce a labelled or an unlabelled (weak/strong) training sample.

        Args:
            kspace: Input k-space of shape (num_coils, rows, cols) for
                multi-coil data or (rows, cols) for single coil data.
            mask: Mask from the test dataset. Not read here: masks are
                generated by ``self.strong_mask_func`` /
                ``self.weak_mask_func`` when those are set.
            target: Target image, or None.
            attrs: Acquisition related information stored in the HDF5 object.
                Must contain ``"is_label"``; ``"recon_size"`` is required
                when ``target`` is None; ``"max"`` is optional.
            fname: File name (also the mask seed source when
                ``self.use_seed`` is set; weak and strong masks share it).
            slice_num: Serial number of the slice.

        Returns:
            An 8-tuple ``(first_image, second_image, target, mean, std,
            fname, slice_num, max_value)``.

            * Labelled slice (``attrs["is_label"]`` truthy): both images are
              the same strongly-masked zero-filled image, instance-normalized
              on its own and clamped to [-6, 6].
            * Unlabelled slice: ``first_image`` uses ``self.weak_mask_func``
              and ``second_image`` uses ``self.strong_mask_func``. The two are
              stacked and normalized with shared statistics, then clamped.

            ``target`` is normalized with the returned mean/std and clamped,
            or is ``torch.Tensor([0])`` when no target is given.
            ``max_value`` is ``attrs["max"]`` if present, else 0.0.
        """
        kspace = to_tensor(kspace)

        # "max" is optional in attrs
        max_value = attrs.get("max", 0.0)

        # crop input to correct size
        if target is not None:
            crop_size = (target.shape[-2], target.shape[-1])
        else:
            crop_size = (attrs["recon_size"][0], attrs["recon_size"][1])

        def zero_filled(mask_func, crop):
            # Mask (if mask_func is set), inverse-FFT, crop, magnitude, RSS.
            # Returns the image and the crop size actually used.
            if mask_func:
                seed = None if not self.use_seed else tuple(map(ord, fname))
                masked_kspace, _ = apply_mask(kspace, mask_func, seed)
            else:
                masked_kspace = kspace
            image = fastmri.ifft2c(masked_kspace)
            # check for FLAIR 203: if dim -2 is smaller than the requested
            # crop, fall back to a square crop of that size
            if image.shape[-2] < crop[1]:
                crop = (image.shape[-2], image.shape[-2])
            image = complex_center_crop(image, crop)
            image = fastmri.complex_abs(image)
            if self.which_challenge == "multicoil":
                image = fastmri.rss(image)
            return image, crop

        def normalized_target(mean, std, crop):
            # Normalize the target with the input's statistics, or return a
            # placeholder when no target is available.
            if target is None:
                return torch.Tensor([0])
            out = to_tensor(target)
            out = center_crop(out, crop)
            out = normalize(out, mean, std, eps=1e-11)
            return out.clamp(-6, 6)

        if attrs["is_label"]:
            image, crop_size = zero_filled(self.strong_mask_func, crop_size)
            image, label_mean, label_std = normalize_instance(image, eps=1e-11)
            labeled_image = image.clamp(-6, 6)
            labeled_target = normalized_target(label_mean, label_std,
                                               crop_size)
            return (labeled_image, labeled_image, labeled_target, label_mean,
                    label_std, fname, slice_num, max_value)

        # Unlabelled slice: weak then strong view of the same k-space.
        weak_image, crop_size = zero_filled(self.weak_mask_func, crop_size)
        strong_image, crop_size = zero_filled(self.strong_mask_func, crop_size)

        # Normalize both views jointly so they share mean/std.
        image_cat = torch.stack([weak_image, strong_image], dim=0)
        image_cat, unlabel_mean, unlabel_std = normalize_instance(image_cat,
                                                                  eps=1e-11)
        image_cat = image_cat.clamp(-6, 6)
        weak_image, strong_image = image_cat[0], image_cat[1]

        unlabelled_target = normalized_target(unlabel_mean, unlabel_std,
                                              crop_size)

        return (weak_image, strong_image, unlabelled_target, unlabel_mean,
                unlabel_std, fname, slice_num, max_value)