示例#1
0
    def __call__(self, sample):
        """Attach a coil sensitivity map to ``sample`` under the key ``"sensitivity_map"``.

        Parameters
        ----------
        sample: dict
            For ``type_of_map == "unit"`` it must contain a named ``"kspace"`` tensor; the
            ``"rss_estimate"`` branch reads whatever ``self.estimate_acs_image`` needs.

        Returns
        -------
        dict
            The same ``sample`` object, updated in place. In ``"unit"`` mode an already present
            ``"sensitivity_map"`` is kept unaltered; in ``"rss_estimate"`` mode it is overwritten.

        Raises
        ------
        NotImplementedError
            In ``"unit"`` mode, when the last named dimension of k-space is not ``"complex"``.
        ValueError
            If ``self.type_of_map`` is neither ``"unit"`` nor ``"rss_estimate"``.
        """
        if self.type_of_map == "unit":
            # TODO(kp) Figure out a way to skip this class entirely if sensitivity map already in sample and est = false
            # (kp) Keep a sensitivity map that already exists from being altered.
            if "sensitivity_map" in sample:
                return sample
            kspace = sample["kspace"]
            # TODO(jt): Named variant, this assumes the complex channel is last.
            # Validate before allocating anything.
            if kspace.names[-1] != "complex":
                raise NotImplementedError("Assuming last channel is complex.")
            # Unit map: channel 0 of the complex axis (presumably the real part) is 1, the rest is 0.
            # Allocate directly on the k-space device instead of on the CPU followed by a copy.
            sensitivity_map = torch.zeros(
                kspace.shape, dtype=torch.float32, device=kspace.device
            )
            sensitivity_map[..., 0] = 1.0
            sample["sensitivity_map"] = sensitivity_map.refine_names(*kspace.names)

        elif self.type_of_map == "rss_estimate":
            # Normalize the ACS coil images by their root-sum-of-squares over the coil dimension.
            # T.safe_divide presumably guards against a zero RSS — TODO confirm.
            acs_image = self.estimate_acs_image(sample)
            acs_image_rss = T.root_sum_of_squares(acs_image, dim="coil").align_as(
                acs_image
            )
            sample["sensitivity_map"] = T.safe_divide(acs_image, acs_image_rss)
        else:
            raise ValueError(
                f"Expected type of map to be either `unit` or `rss_estimate`. Got {self.type_of_map}."
            )

        return sample
示例#2
0
    def estimate_sensitivity_map(self, sample):
        """Estimate coil sensitivity maps from the ACS (auto-calibration) part of k-space.

        The ACS k-space (``sample["acs_mask"]`` applied to ``sample[self.kspace_key]``) is mapped to
        image space with ``self.backward_operator``. Each coil image is then divided by the
        root-sum-of-squares over the ``"coil"`` dimension.

        Parameters
        ----------
        sample: dict
            Must contain ``self.kspace_key`` and ``"acs_mask"``.

        Returns
        -------
        torch.Tensor or dict
            The estimated sensitivity map for multi-coil data. For single-coil data the ``sample``
            dict itself is returned unchanged (see the review note below).
        """
        kspace_data = sample[self.kspace_key]

        # Assumes dimension 0 of k-space is the coil dimension — TODO confirm.
        if kspace_data.shape[0] == 1:
            # NOTE(review): this path returns the whole `sample` dict, whereas the multi-coil path
            # returns a tensor. A caller that stores the result as the sensitivity map would store a
            # dict here — confirm against the caller and return a consistent type.
            warnings.warn(
                "Single-coil data, skipping estimation of sensitivity map. "
                "This warning will be displayed only once."
            )
            return sample

        if "sensitivity_map" in sample:
            # This method does not write to `sample`; the warning is about the caller storing the result.
            warnings.warn(
                "`sensitivity_map` is given, but will be overwritten. "
                "This warning will be displayed only once."
            )

        kspace_acs = T.apply_mask(kspace_data, sample["acs_mask"], return_mask=False)

        # Get complex-valued data solution
        image = self.backward_operator(kspace_acs)
        rss_image = T.root_sum_of_squares(image, dim="coil").align_as(image)

        # Same normalization the sibling sensitivity-map transform performs with T.safe_divide
        # (replaces the hand-written torch.where guard against a zero RSS).
        sensitivity_map = T.safe_divide(image, rss_image)
        return sensitivity_map
示例#3
0
    def __call__(self, sample: Dict[str, Any]):
        """Crop ``sample`` (image-space or k-space crop) and apply the undersampling mask.

        Parameters
        ----------
        sample: dict
            Needs ``"kspace"`` and ``"filename"``; ``"sensitivity_map"`` and ``"sampling_mask"``
            are optional.

        Returns
        -------
        data dictionary
            ``sample`` updated in place: ``"kspace"`` is removed and ``"target"``,
            ``"masked_kspace"`` and ``"sampling_mask"`` are written (``"sensitivity_map"`` too,
            when one is available).
        """
        kspace = sample["kspace"]
        sensitivity_map = sample.get("sensitivity_map", None)
        filename = sample["filename"]

        dataset_mask_given = "sampling_mask" in sample
        if dataset_mask_given and self.mask_func is not None:
            warnings.warn(
                f"`sampling_mask` is passed by the Dataset class, yet `mask_func` is also set. "
                f"This will be ignored and the `sampling_mask` will be used instead. "
                f"Be aware of this as it can lead to unexpected results. "
                f"This warning will be issued only once."
            )
        # A mask supplied by the dataset takes precedence over `self.mask_func`.
        mask_func = sample["sampling_mask"] if dataset_mask_given else self.mask_func

        # Per-file seed derived from the characters of the filename.
        seed = tuple(map(ord, str(filename))) if self.use_seed else None

        crop_in_image_space = np.random.random() >= self.kspace_crop_probability
        if crop_in_image_space:
            kspace, backprojected_kspace, sensitivity_map = self.__random_image_crop(
                kspace, sensitivity_map
            )
            masked_kspace, sampling_mask = T.apply_mask(kspace, mask_func, seed)
        else:
            # Mask the full k-space first, then crop k-space, mask and sensitivity map together.
            masked_kspace, sampling_mask = T.apply_mask(kspace, mask_func, seed)
            cropped = self.__central_kspace_crop(
                kspace, masked_kspace, sampling_mask, sensitivity_map
            )
            kspace, masked_kspace, sampling_mask, backprojected_kspace, sensitivity_map = cropped

        sample["target"] = T.root_sum_of_squares(backprojected_kspace, dim="coil")
        del sample["kspace"]
        sample.update(masked_kspace=masked_kspace, sampling_mask=sampling_mask)

        if sensitivity_map is not None:
            sample["sensitivity_map"] = sensitivity_map

        return sample
示例#4
0
def test_root_sum_of_squares_complex(shape, dims):
    """Compare ``transforms.root_sum_of_squares`` with a NumPy reference on complex input.

    A trailing axis of size 2 (the complex channel) is appended to ``shape``;
    ``tensor_to_complex_numpy`` presumably folds it into a complex NumPy array — TODO confirm.
    A string ``dims`` is a named dimension. It is mapped to NumPy axis 0, which assumes
    ``create_input`` puts the coil dimension first — TODO confirm.
    """
    # Unpacking also accepts tuple shapes (``tuple + list`` would raise TypeError).
    shape = [*shape, 2]
    data = create_input(shape, named=True)  # noqa
    out_torch = transforms.root_sum_of_squares(data, dims).numpy()
    input_numpy = tensor_to_complex_numpy(data)
    # The NumPy reference has no named dimensions. Every other call site names the coil dimension
    # "coil", so accept that spelling as well as the legacy "coils".
    numpy_dims = 0 if isinstance(dims, str) and dims in ("coil", "coils") else dims
    out_numpy = np.sqrt(np.sum(np.abs(input_numpy) ** 2, numpy_dims))
    assert np.allclose(out_torch, out_numpy)
示例#5
0
    def __call__(self, sample):
        """Store the root-sum-of-squares image of the ACS region as ``sample["body_coil_image"]``.

        Parameters
        ----------
        sample: dict
            Must contain ``"kspace"`` and ``"filename"``.

        Returns
        -------
        dict
            The same ``sample``, updated in place.
        """
        kspace = sample["kspace"]
        # The ACS mask is built from this sample's own k-space shape, because k-space may have been
        # cropped. The leading dimension is excluded (presumably the coil axis — TODO confirm).
        seed = tuple(map(ord, str(sample["filename"]))) if self.use_seed else None
        acs_mask = self.mask_func(kspace.shape[1:], seed, return_acs=True)

        # Keep only the ACS lines; the trailing `+ 0.0` is kept from the original expression.
        acs_image = self.backward_operator(acs_mask * kspace + 0.0)

        sample["body_coil_image"] = T.root_sum_of_squares(acs_image, dim="coil")
        return sample
示例#6
0
    def __call__(self, sample: Dict[str, Any]):
        """Optionally crop in image space, then apply ``sample["sampling_mask"]`` to k-space.

        Parameters
        ----------
        sample: dict
            Must contain ``"kspace"`` and ``"sampling_mask"``. Optional ``"sensitivity_map"`` and
            ``"input_image"`` are cropped together with the back-projected k-space when
            ``self.crop`` is set.

        Returns
        -------
        data dictionary
            ``sample`` updated in place with ``"target"``, ``"masked_kspace"``,
            ``"sampling_mask"`` and the (possibly cropped) ``"kspace"``.
        """
        kspace = sample["kspace"]
        sampling_mask = sample["sampling_mask"]
        backprojected_kspace = self.backward_operator(kspace)

        # TODO: Also create a kspace-like crop function
        if self.crop:
            # Image-space croppable objects. Only the keys actually present are cropped, and the
            # outputs are mapped back to exactly those keys. Indexing the full key list would
            # mislabel outputs or raise IndexError when one of them is missing.
            present_keys = [
                key for key in ("sensitivity_map", "input_image") if key in sample
            ]
            cropped_output = self.crop_func(
                [backprojected_kspace, *[sample[key] for key in present_keys]],
                self.crop,
                contiguous=True,
            )
            backprojected_kspace = cropped_output[0]
            for idx, key in enumerate(present_keys, start=1):
                sample[key] = cropped_output[idx]

            # Compute new k-space for the cropped input_image
            kspace = self.forward_operator(backprojected_kspace)

        masked_kspace, sampling_mask = T.apply_mask(kspace, sampling_mask)

        sample["target"] = T.root_sum_of_squares(backprojected_kspace, dim="coil")
        sample["masked_kspace"] = masked_kspace
        sample["sampling_mask"] = sampling_mask
        sample["kspace"] = kspace  # The cropped kspace
        # The (possibly cropped) sensitivity map already lives in `sample`. It must not be
        # overwritten with the tensor read before cropping.
        return sample
示例#7
0
    def __call__(self, sample):
        """Compute a reconstruction target from ``sample[self.kspace_key]``.

        ``self.type_reconstruction`` (case-insensitive) selects the method:

        - ``'complex'``: sum of the coil images over the named ``'coil'`` dimension;
        - ``'rss'``: root-sum-of-squares of the coil images over ``'coil'``;
        - ``'sense'``: not implemented.

        The result is stored in ``sample[self.target_key]`` and ``sample`` is returned.

        Raises
        ------
        ValueError
            If the reconstruction type is unknown, or SENSE is requested without a
            ``'sensitivity_map'`` in ``sample``.
        NotImplementedError
            If SENSE is requested.
        """
        type_reconstruction = self.type_reconstruction.lower()
        # Fail loudly instead of silently returning a sample without a target.
        if type_reconstruction not in ('complex', 'rss', 'sense'):
            raise ValueError(
                f'Expected type of reconstruction to be either `complex`, `rss` or `sense`. '
                f'Got {self.type_reconstruction}.')

        # Validate SENSE before running the (potentially costly) backward operator.
        if type_reconstruction == 'sense':
            if 'sensitivity_map' not in sample:
                raise ValueError('Sensitivity map is required for SENSE reconstruction.')
            raise NotImplementedError('SENSE is not implemented.')

        # Get complex-valued image solution
        image = self.backward_operator(sample[self.kspace_key])

        if type_reconstruction == 'complex':
            sample[self.target_key] = image.sum('coil')
        else:
            sample[self.target_key] = transforms.root_sum_of_squares(image, dim='coil')

        return sample
示例#8
0
    def __call__(self, sample):
        """Compute a reconstruction target from ``sample[self.kspace_key]``.

        ``self.type_reconstruction`` (case-insensitive) selects the method:

        - ``"complex"``: sum of the coil images over the named ``"coil"`` dimension;
        - ``"rss"``: root-sum-of-squares of the coil images over ``"coil"``;
        - ``"sense"``: not implemented.

        The result is stored in ``sample[self.target_key]`` and ``sample`` is returned.

        Raises
        ------
        ValueError
            If the reconstruction type is unknown, or SENSE is requested without a
            ``"sensitivity_map"`` in ``sample``.
        NotImplementedError
            If SENSE is requested.
        """
        type_reconstruction = self.type_reconstruction.lower()
        # Fail loudly instead of silently returning a sample without a target.
        if type_reconstruction not in ("complex", "rss", "sense"):
            raise ValueError(
                f"Expected type of reconstruction to be either `complex`, `rss` or `sense`. "
                f"Got {self.type_reconstruction}."
            )

        # Validate SENSE before running the (potentially costly) backward operator.
        if type_reconstruction == "sense":
            if "sensitivity_map" not in sample:
                raise ValueError(
                    "Sensitivity map is required for SENSE reconstruction.")
            raise NotImplementedError("SENSE is not implemented.")

        # Get complex-valued data solution
        image = self.backward_operator(sample[self.kspace_key])

        if type_reconstruction == "complex":
            sample[self.target_key] = image.sum("coil")
        else:
            sample[self.target_key] = T.root_sum_of_squares(image, dim="coil")

        return sample
示例#9
0
    def __call__(self, sample: Dict[str, Any]):
        """Crop ``sample`` (image-space or k-space crop) and apply ``self.mask_func``.

        Parameters
        ----------
        sample: dict
            Needs ``'kspace'`` and ``'filename'``; ``'sensitivity_map'`` is optional.

        Returns
        -------
        data dictionary
            ``sample`` updated in place: ``'kspace'`` is removed and ``'target'``,
            ``'masked_kspace'`` and ``'sampling_mask'`` are written (``'sensitivity_map'`` too,
            when one is available).

        Raises
        ------
        NotImplementedError
            If the sample already carries a ``'sampling_mask'`` while ``self.mask_func`` is set.
        """
        kspace = sample['kspace']
        sensitivity_map = sample.get('sensitivity_map', None)
        filename = sample['filename']

        if 'sampling_mask' in sample and self.mask_func is not None:
            # Combining a precomputed mask with `mask_func` is not supported here. Raise one accurate
            # error instead of first warning that the mask "will be used instead".
            raise NotImplementedError(
                '`sampling_mask` is passed by the Dataset class, yet `mask_func` is also set. '
                'Applying a precomputed `sampling_mask` is not supported by this transform.')
        # NOTE(review): with a 'sampling_mask' in the sample but `mask_func` None, None is passed to
        # `transforms.apply_mask` below — confirm that this is handled there.

        # Per-file seed derived from the characters of the filename.
        seed = None if not self.use_seed else tuple(map(ord, str(filename)))

        if np.random.random() >= self.kspace_crop_probability:
            kspace, backprojected_kspace, sensitivity_map = self.__random_image_crop(kspace, sensitivity_map)
            masked_kspace, sampling_mask = transforms.apply_mask(kspace, self.mask_func, seed)

        else:
            # Mask the full k-space first, then crop k-space, mask and sensitivity map together.
            masked_kspace, sampling_mask = transforms.apply_mask(kspace, self.mask_func, seed)
            kspace, masked_kspace, sampling_mask, backprojected_kspace, sensitivity_map = self.__central_kspace_crop(
                kspace, masked_kspace, sampling_mask, sensitivity_map)

        sample['target'] = transforms.root_sum_of_squares(backprojected_kspace, dim='coil')
        del sample['kspace']
        sample['masked_kspace'] = masked_kspace
        sample['sampling_mask'] = sampling_mask

        if sensitivity_map is not None:
            sample['sensitivity_map'] = sensitivity_map

        return sample
示例#10
0
    def __call__(self, sample):
        """Set ``sample["sensitivity_map"]`` according to ``self.type_of_map``.

        Parameters
        ----------
        sample: dict
            For ``type_of_map == "unit"`` it must contain a named ``"kspace"`` tensor; the
            ``"rss_estimate"`` branch reads whatever ``self.estimate_acs_image`` needs.

        Returns
        -------
        dict
            The same ``sample`` object, with ``"sensitivity_map"`` written (overwritten if present).

        Raises
        ------
        NotImplementedError
            In ``"unit"`` mode, when the last named dimension of k-space is not ``"complex"``.
        ValueError
            If ``self.type_of_map`` is neither ``"unit"`` nor ``"rss_estimate"``.
        """
        if self.type_of_map == "unit":
            kspace = sample["kspace"]
            # TODO(jt): Named variant, this assumes the complex channel is last.
            # Validate before allocating anything.
            if kspace.names[-1] != "complex":
                raise NotImplementedError("Assuming last channel is complex.")
            # Unit map: channel 0 of the complex axis (presumably the real part) is 1, the rest is 0.
            # Allocate directly on the k-space device instead of on the CPU followed by a copy.
            sensitivity_map = torch.zeros(
                kspace.shape, dtype=torch.float32, device=kspace.device)
            sensitivity_map[..., 0] = 1.0
            sample["sensitivity_map"] = sensitivity_map.refine_names(*kspace.names)

        elif self.type_of_map == "rss_estimate":
            # Normalize the ACS coil images by their root-sum-of-squares over the coil dimension.
            # T.safe_divide presumably guards against a zero RSS — TODO confirm.
            acs_image = self.estimate_acs_image(sample)
            acs_image_rss = T.root_sum_of_squares(
                acs_image, dim="coil").align_as(acs_image)
            sample["sensitivity_map"] = T.safe_divide(acs_image, acs_image_rss)
        else:
            raise ValueError(
                f"Expected type of map to be either `unit` or `rss_estimate`. Got {self.type_of_map}."
            )

        return sample
示例#11
0
def test_root_sum_of_squares_real(shape, dims):
    """Check ``transforms.root_sum_of_squares`` against a NumPy reference on real-valued input."""
    data = create_input(shape, named=True)  # noqa
    actual = transforms.root_sum_of_squares(data, dims).numpy()
    squared = data.numpy() ** 2
    expected = np.sqrt(squared.sum(axis=dims))
    assert np.allclose(actual, expected)