예제 #1
0
def test_conjugate(shape):
    """Check ``transforms.conjugate`` against ``np.conjugate``.

    A complex array of the requested ``shape`` is built with real part
    ``0, 1, ..., size - 1`` and imaginary part ``1, 2, ..., size`` (both laid
    out in ``shape``). It is converted with ``transforms.to_tensor``, given
    dimension names via ``add_names``, conjugated with
    ``transforms.conjugate``, converted back with ``tensor_to_complex_numpy``
    and compared with NumPy's conjugate of the same data.

    Parameters
    ----------
    shape : tuple of int
        Shape of the test array; presumably supplied by a pytest
        parametrization outside this view.
    """
    # ``np.product`` was deprecated in NumPy 1.25 and removed in NumPy 2.0;
    # ``np.prod`` is the supported equivalent. Build the base grid only once.
    size = np.prod(shape)
    real = np.arange(size).reshape(shape)
    data = real + 1j * (real + 1)

    torch_tensor = transforms.to_tensor(data)
    torch_tensor = add_names(torch_tensor, named=True)

    out_torch = tensor_to_complex_numpy(transforms.conjugate(torch_tensor))
    out_numpy = np.conjugate(data)
    assert np.allclose(out_torch, out_numpy)
예제 #2
0
# Zero-fill the unsampled k-space positions (sampling mask == 0) in both the
# predicted and the measured k-space. Dimension names are stripped for
# ``torch.where`` and restored on the residual via ``refine_names``.
is_unsampled = sampling_mask.rename(None) == 0
zero_fill = torch.tensor([0.0], dtype=masked_kspace.dtype).to(masked_kspace.device)

predicted_kspace = transforms.fft2(mul).rename(None)
mr_forward = torch.where(is_unsampled, zero_fill, predicted_kspace)

measured_kspace = torch.where(is_unsampled, zero_fill, masked_kspace.rename(None))
error = (mr_forward - measured_kspace).refine_names(*mul_names)

# Back to image space, then coil-combine with the conjugated sensitivity maps.
mr_backward = transforms.ifft2(error)

conjugate_maps = transforms.conjugate(sensitivity_map)
out = transforms.complex_multiplication(conjugate_maps, mr_backward).sum("coil")

# NumPy reference (kept for manual comparison):
# mul_numpy = sensitivity_map_numpy * input_image_numpy
# mr_forward_numpy = sampling_mask_numpy * numpy_fft(mul_numpy)
# error_numpy = mr_forward_numpy - sampling_mask_numpy * masked_kspace_numpy
# mr_backward_numpy = numpy_ifft(error_numpy)
# out_numpy = (sensitivity_map_numpy.conjugate() * mr_backward_numpy).sum(1)

# np.allclose(tensor_to_complex_numpy(out), out_numpy)

# numpy 2
mr_backward_numpy = numpy_ifft(
    sampling_mask_numpy *
    numpy_fft(sensitivity_map_numpy * input_image_numpy[:, np.newaxis, ...]) -
예제 #3
0
파일: mri_models.py 프로젝트: wdika/direct
 def compute_sense_init(self, kspace, sensitivity_map):
     """Coil-combine ``kspace`` into a single image estimate.

     Applies ``self.backward_operator`` to ``kspace``, multiplies the result
     by the complex conjugate of ``sensitivity_map`` and sums over the
     ``"coil"`` dimension.

     Parameters
     ----------
     kspace : torch.Tensor
         Multi-coil k-space; presumably a named tensor whose backward
         transform carries a ``"coil"`` dimension -- confirm with callers.
     sensitivity_map : torch.Tensor
         Coil sensitivity maps; conjugated before the multiplication.

     Returns
     -------
     torch.Tensor
         The coil-combined image (``"coil"`` dimension summed out).
     """
     conjugated_maps = T.conjugate(sensitivity_map)
     coil_images = self.backward_operator(kspace)
     combined = T.complex_multiplication(conjugated_maps, coil_images)
     return combined.sum("coil")
예제 #4
0
파일: mri_models.py 프로젝트: wdika/direct
    def forward(
        self,
        input_image,
        masked_kspace,
        sensitivity_map,
        sampling_mask,
        loglikelihood_scaling=None,
    ):
        r"""
        Defines the MRI loglikelihood assuming one noise vector for the complex images for all coils.
        $$ \frac{1}{\sigma^2} \sum_{i}^{\text{num coils}}
            S_i^{\text{H}} \mathcal{F}^{-1} P^T (P \mathcal{F} S_i x_\tau - y_\tau)$$
        for each time step $\tau$

        Parameters
        ----------
        input_image : torch.tensor
            Initial or previous iteration of image (named tensor). If one of its names is
            ``"slice"``, ``self.ndim`` is set to 3.
        masked_kspace : torch.tensor
            Masked k-space (named tensor).
        sensitivity_map : torch.tensor
            Coil sensitivity maps (named tensor). Required: it is used unconditionally.
        sampling_mask : torch.tensor
            Sampling mask; positions where it equals 0 are zero-filled in both the predicted
            and the measured k-space.
        loglikelihood_scaling : torch.tensor, float or None
            Multiplier for loglikelihood, for instance for the k-space noise. A tensor is
            aligned to ``self.names_data_complex_last``; a plain number is used as is;
            ``None`` (default) means no scaling (factor 1.0).

        Returns
        -------
        torch.Tensor
            The expression above with the ``"coil"`` dimension summed out, aligned to
            ``self.names_image_complex_channel``.
        """
        if "slice" in input_image.names:
            # NOTE(review): this mutates module state and is never reset here; confirm the
            # module is not shared between 2D and 3D inputs.
            self.ndim = 3

        input_image = input_image.align_to(*self.names_image_complex_last)
        sensitivity_map = sensitivity_map.align_to(*self.names_data_complex_last)
        masked_kspace = masked_kspace.align_to(*self.names_data_complex_last)

        # The default ``None`` (and a plain float) used to crash on ``.align_to``. Only tensors
        # go through named-dimension alignment; ``None`` becomes a neutral factor of 1.0.
        if loglikelihood_scaling is None:
            loglikelihood_scaling = 1.0
        if isinstance(loglikelihood_scaling, torch.Tensor):
            loglikelihood_scaling = loglikelihood_scaling.align_to(*self.names_data_complex_last)
            forward_scaling = loglikelihood_scaling.align_as(sensitivity_map)
        else:
            forward_scaling = loglikelihood_scaling

        # We multiply by the loglikelihood_scaling here to prevent fp16 information loss,
        # as this value is typically <<1, and the operators are linear.
        # Per the original author the aligned tensors carry names like
        # ('batch', 'coil', 'height', 'width', 'complex') -- TODO confirm against
        # self.names_data_complex_last.
        mul = forward_scaling * T.complex_multiplication(
            sensitivity_map, input_image.align_as(sensitivity_map)
        )

        # TODO: Named tensor: this needs a fix once this exists.
        # Names are dropped for ``torch.where`` (presumably unsupported for named tensors) and
        # restored on the residual with ``refine_names``.
        mul_names = mul.names
        unsampled = sampling_mask.rename(None) == 0
        zero = torch.tensor([0.0], dtype=masked_kspace.dtype, device=masked_kspace.device)

        mr_forward = torch.where(unsampled, zero, self.forward_operator(mul).rename(None))
        error = mr_forward - loglikelihood_scaling * torch.where(
            unsampled, zero, masked_kspace.rename(None)
        )

        error = error.refine_names(*mul_names)
        mr_backward = self.backward_operator(error)

        # ``sensitivity_map`` was already dereferenced above, so the former
        # ``sensitivity_map is None`` branch was unreachable and has been removed.
        out = T.complex_multiplication(T.conjugate(sensitivity_map), mr_backward).sum("coil")

        return out.align_to(*self.names_image_complex_channel)  # noqa
예제 #5
0
mul_names = mul.names

# Zero-fill unsampled k-space positions (sampling mask == 0) in both the predicted
# and the measured k-space; names are stripped for ``torch.where`` and restored
# on the residual afterwards.
is_unsampled = sampling_mask.rename(None) == 0
zero_fill = torch.tensor([0.], dtype=masked_kspace.dtype).to(masked_kspace.device)

predicted_kspace = transforms.fft2(mul).rename(None)
mr_forward = torch.where(is_unsampled, zero_fill, predicted_kspace)

measured_kspace = torch.where(is_unsampled, zero_fill, masked_kspace.rename(None))
error = (mr_forward - measured_kspace).refine_names(*mul_names)

mr_backward = transforms.ifft2(error)

conjugate_maps = transforms.conjugate(sensitivity_map)
out = transforms.complex_multiplication(conjugate_maps, mr_backward).sum('coil')


# NumPy reference, variant 1 (kept for manual comparison):
# mul_numpy = sensitivity_map_numpy * input_image_numpy
# mr_forward_numpy = sampling_mask_numpy * numpy_fft(mul_numpy)
# error_numpy = mr_forward_numpy - sampling_mask_numpy * masked_kspace_numpy
# mr_backward_numpy = numpy_ifft(error_numpy)
# out_numpy = (sensitivity_map_numpy.conjugate() * mr_backward_numpy).sum(1)

# np.allclose(tensor_to_complex_numpy(out), out_numpy)

# NumPy reference, variant 2: broadcast the image over a new axis 1 (presumably
# the coil axis, which is also the axis summed at the end).
coil_images_numpy = sensitivity_map_numpy * input_image_numpy[:, np.newaxis, ...]
residual_kspace_numpy = (sampling_mask_numpy * numpy_fft(coil_images_numpy)
                         - sampling_mask_numpy * masked_kspace_numpy)
mr_backward_numpy = numpy_ifft(residual_kspace_numpy)
out_numpy = (sensitivity_map_numpy.conjugate() * mr_backward_numpy).sum(1)