Exemplo n.º 1
0
    def __call__(self, sample: Dict[str, Any]):
        """
        Mask one sample and crop it, either with a random image-space crop or a central k-space crop.

        Parameters
        ----------
        sample: dict
            Must hold "kspace" and "filename"; "sensitivity_map" and "sampling_mask" are optional.
            A "sampling_mask" already present in the sample takes precedence over ``self.mask_func``.

        Returns
        -------
        data dictionary
            ``sample`` itself, modified in place: "kspace" is removed and "target", "masked_kspace" and
            "sampling_mask" are written; "sensitivity_map" is written back when it is not None.
        """
        kspace = sample["kspace"]
        sensitivity_map = sample.get("sensitivity_map", None)
        filename = sample["filename"]

        dataset_provides_mask = "sampling_mask" in sample
        if dataset_provides_mask and self.mask_func is not None:
            warnings.warn(
                "`sampling_mask` is passed by the Dataset class, yet `mask_func` is also set. "
                "This will be ignored and the `sampling_mask` will be used instead. "
                "Be aware of this as it can lead to unexpected results. "
                "This warning will be issued only once."
            )
        # A mask supplied by the Dataset overrides the configured mask function.
        mask_func = sample["sampling_mask"] if dataset_provides_mask else self.mask_func

        # Seed built from the filename's character codes, so the same file always yields the same seed.
        seed = tuple(map(ord, str(filename))) if self.use_seed else None

        use_image_crop = np.random.random() >= self.kspace_crop_probability
        if use_image_crop:
            # Crop in image space first, then mask the cropped k-space.
            kspace, backprojected_kspace, sensitivity_map = self.__random_image_crop(kspace, sensitivity_map)
            masked_kspace, sampling_mask = T.apply_mask(kspace, mask_func, seed)
        else:
            # Mask the full k-space first, then crop every k-space quantity centrally.
            masked_kspace, sampling_mask = T.apply_mask(kspace, mask_func, seed)
            _, masked_kspace, sampling_mask, backprojected_kspace, sensitivity_map = self.__central_kspace_crop(
                kspace, masked_kspace, sampling_mask, sensitivity_map
            )

        sample["target"] = T.root_sum_of_squares(backprojected_kspace, dim="coil")
        del sample["kspace"]
        sample.update(masked_kspace=masked_kspace, sampling_mask=sampling_mask)

        if sensitivity_map is not None:
            sample["sensitivity_map"] = sensitivity_map

        return sample
Exemplo n.º 2
0
    def estimate_sensitivity_map(self, sample):
        """
        Estimate coil sensitivity maps from the auto-calibration (ACS) part of the k-space.

        The k-space ``sample[self.kspace_key]`` is masked with ``sample["acs_mask"]`` and transformed with
        ``self.backward_operator``. Each coil image is then divided by the root-sum-of-squares over the
        "coil" dimension; positions where that root-sum-of-squares is zero are set to zero.

        Parameters
        ----------
        sample: dict
            Must contain ``self.kspace_key`` and, for multi-coil data, "acs_mask".

        Returns
        -------
        torch.Tensor or dict
            The estimated sensitivity map, refined to the names of the backprojected image.
            For single-coil data (``shape[0] == 1``) estimation is skipped and ``sample`` itself is returned.
        """
        kspace_data = sample[self.kspace_key]

        # NOTE(review): assumes the coil dimension is the first one — confirm.
        if kspace_data.shape[0] == 1:
            warnings.warn(
                "Single-coil data, skipping estimation of sensitivity map. "
                "This warning will be displayed only once."
            )
            # NOTE(review): this returns the input dict, whereas the multi-coil path returns a tensor.
            # Callers that store the result as a sensitivity map should handle this case — confirm intent.
            return sample

        if "sensitivity_map" in sample:
            warnings.warn(
                "`sensitivity_map` is given, but will be overwritten. "
                "This warning will be displayed only once."
            )

        kspace_acs = T.apply_mask(kspace_data, sample["acs_mask"], return_mask=False)

        # Get complex-valued data solution
        image = self.backward_operator(kspace_acs)
        # align_as restores the reduced "coil" dimension (size 1) so the division below broadcasts.
        rss_image = T.root_sum_of_squares(image, dim="coil").align_as(image)

        # TODO(jt): Safe divide. The division is evaluated everywhere; zero-RSS positions are then replaced by 0.
        zero = torch.zeros(1, dtype=rss_image.dtype, device=rss_image.device)
        sensitivity_map = torch.where(
            rss_image.rename(None) == 0,
            zero,
            (image / rss_image).rename(None),
        ).refine_names(*image.names)
        return sensitivity_map
Exemplo n.º 3
0
    def __call__(self, sample: Dict[str, Any]):
        """
        Mask one sample and crop it, either with a random image-space crop or a central k-space crop.

        Parameters
        ----------
        sample: dict
            Must hold "kspace" and "filename"; "sensitivity_map" is optional.

        Returns
        -------
        data dictionary
            ``sample`` itself, modified in place: "kspace" is removed and "target", "masked_kspace" and
            "sampling_mask" are written; "sensitivity_map" is written back when it is not None.

        Raises
        ------
        NotImplementedError
            If ``sample`` already contains a "sampling_mask" while ``self.mask_func`` is also set.
        """
        kspace = sample['kspace']
        sensitivity_map = sample.get('sensitivity_map', None)
        filename = sample['filename']

        if 'sampling_mask' in sample and self.mask_func is not None:
            # Combining a Dataset-provided mask with `mask_func` is not supported by this transform;
            # fail loudly instead of silently replacing the provided mask.
            raise NotImplementedError('`sampling_mask` is passed by the Dataset class, yet `mask_func` is also set. '
                                      'Using a Dataset-provided `sampling_mask` together with `mask_func` '
                                      'is not supported by this transform.')
        # NOTE(review): if the sample carries a `sampling_mask` and `self.mask_func` is None, the mask below is
        # computed from `self.mask_func` and overwrites the provided one — confirm this is intended.

        # Seed built from the filename's character codes, so the same file always yields the same seed.
        seed = None if not self.use_seed else tuple(map(ord, str(filename)))

        if np.random.random() >= self.kspace_crop_probability:
            # Crop in image space first, then mask the cropped k-space.
            kspace, backprojected_kspace, sensitivity_map = self.__random_image_crop(kspace, sensitivity_map)
            masked_kspace, sampling_mask = transforms.apply_mask(kspace, self.mask_func, seed)

        else:
            # Mask the full k-space first, then crop every k-space quantity centrally.
            masked_kspace, sampling_mask = transforms.apply_mask(kspace, self.mask_func, seed)
            kspace, masked_kspace, sampling_mask, backprojected_kspace, sensitivity_map = self.__central_kspace_crop(
                kspace, masked_kspace, sampling_mask, sensitivity_map)

        sample['target'] = transforms.root_sum_of_squares(backprojected_kspace, dim='coil')
        del sample['kspace']
        sample['masked_kspace'] = masked_kspace
        sample['sampling_mask'] = sampling_mask

        if sensitivity_map is not None:
            sample['sensitivity_map'] = sensitivity_map

        return sample
Exemplo n.º 4
0
def test_apply_mask_fastmri(shape, center_fractions, accelerations):
    """Check that `transforms.apply_mask` reproduces a seeded FastMRIMaskFunc mask and zeroes unsampled entries."""
    fastmri_mask_func = FastMRIMaskFunc(
        center_fractions=center_fractions,
        accelerations=accelerations,
        uniform_range=False,
    )
    # Reference mask drawn directly from the mask function, with the same seed handed to apply_mask below.
    reference_mask = fastmri_mask_func(shape[1:], seed=123)
    input_data = create_input(shape, named=True)

    masked, applied_mask = transforms.apply_mask(input_data, fastmri_mask_func, seed=123)

    assert masked.shape == input_data.shape
    assert applied_mask.shape == reference_mask.shape

    applied_np = applied_mask.numpy()
    masked_np = masked.numpy()
    assert np.all(reference_mask.numpy() == applied_np)
    # Wherever the mask is zero, the masked output must be zero as well.
    assert np.all(np.where(applied_np == 0, 0, masked_np) == masked_np)
Exemplo n.º 5
0
    def __call__(self, sample: Dict[str, Any]):
        """
        Optionally crop the sample in image space, then apply its sampling mask to the k-space.

        Parameters
        ----------
        sample: dict
            Must contain "kspace" and "sampling_mask"; may contain "sensitivity_map" and "input_image",
            which are cropped together with the backprojected k-space when ``self.crop`` is set.

        Returns
        -------
        data dictionary
            ``sample`` itself, updated in place with "target", "masked_kspace", "sampling_mask" and "kspace"
            (the k-space recomputed from the cropped image when cropping is enabled). Any present
            "sensitivity_map" / "input_image" entries hold their cropped versions when cropping is enabled.
        """
        kspace = sample["kspace"]
        sampling_mask = sample["sampling_mask"]
        backprojected_kspace = self.backward_operator(kspace)

        # TODO: Also create a kspace-like crop function
        if self.crop:
            # Image-space croppable objects. Only keys actually present are cropped, and the SAME list is used
            # to write results back, so each cropped output lands under its own key.
            croppable_keys = [key for key in ("sensitivity_map", "input_image") if key in sample]
            cropped_output = self.crop_func(
                [backprojected_kspace, *[sample[key] for key in croppable_keys]],
                self.crop,
                contiguous=True,
            )
            backprojected_kspace = cropped_output[0]
            for offset, key in enumerate(croppable_keys, start=1):
                sample[key] = cropped_output[offset]

            # Compute new k-space for the cropped image.
            kspace = self.forward_operator(backprojected_kspace)
            # NOTE(review): `sampling_mask` is not cropped; if it was built for the uncropped k-space its shape
            # may not match the cropped `kspace` — confirm.

        masked_kspace, sampling_mask = T.apply_mask(kspace, sampling_mask)

        sample["target"] = T.root_sum_of_squares(backprojected_kspace, dim="coil")
        sample["masked_kspace"] = masked_kspace
        sample["sampling_mask"] = sampling_mask
        sample["kspace"] = kspace  # The cropped kspace

        # Any "sensitivity_map" is already stored in `sample` (cropped when cropping is enabled); it is
        # deliberately not re-assigned from a reference captured before cropping.
        return sample