def test_conjugate(shape):
    """Check that ``transforms.conjugate`` matches ``np.conjugate`` on a named complex tensor.

    Parameters
    ----------
    shape : tuple of int
        Shape of the complex test array; presumably supplied by a pytest parametrization
        outside this view.
    """
    # np.product was deprecated in NumPy 1.25 and removed in NumPy 2.0; np.prod is the
    # supported spelling and computes the same value.
    values = np.arange(np.prod(shape)).reshape(shape)
    # Real part 0..N-1, imaginary part 1..N, so no element is purely real and the
    # conjugate differs from the input everywhere.
    data = values + 1j * (values + 1)

    torch_tensor = transforms.to_tensor(data)
    torch_tensor = add_names(torch_tensor, named=True)

    out_torch = tensor_to_complex_numpy(transforms.conjugate(torch_tensor))
    out_numpy = np.conjugate(data)
    assert np.allclose(out_torch, out_numpy)
mr_forward = torch.where( sampling_mask.rename(None) == 0, torch.tensor([0.0], dtype=masked_kspace.dtype).to(masked_kspace.device), transforms.fft2(mul).rename(None), ) error = mr_forward - torch.where( sampling_mask.rename(None) == 0, torch.tensor([0.0], dtype=masked_kspace.dtype).to(masked_kspace.device), masked_kspace.rename(None), ) error = error.refine_names(*mul_names) mr_backward = transforms.ifft2(error) out = transforms.complex_multiplication(transforms.conjugate(sensitivity_map), mr_backward).sum("coil") # numpy # mul_numpy = sensitivity_map_numpy * input_image_numpy # mr_forward_numpy = sampling_mask_numpy * numpy_fft(mul_numpy) # error_numpy = mr_forward_numpy - sampling_mask_numpy * masked_kspace_numpy # mr_backward_numpy = numpy_ifft(error_numpy) # out_numpy = (sensitivity_map_numpy.conjugate() * mr_backward_numpy).sum(1) # np.allclose(tensor_to_complex_numpy(out), out_numpy) # numpy 2 mr_backward_numpy = numpy_ifft( sampling_mask_numpy * numpy_fft(sensitivity_map_numpy * input_image_numpy[:, np.newaxis, ...]) -
def compute_sense_init(self, kspace, sensitivity_map):
    """Combine multi-coil k-space into a single image using conjugate sensitivity maps.

    The k-space is transformed with ``self.backward_operator``. Each coil's result is
    complex-multiplied by the conjugate of that coil's sensitivity map, and the products
    are summed over the "coil" dimension.

    Parameters
    ----------
    kspace : torch.Tensor
        Named multi-coil k-space tensor with a "coil" dimension.
    sensitivity_map : torch.Tensor
        Named coil sensitivity maps, compatible with the output of ``self.backward_operator``.

    Returns
    -------
    torch.Tensor
        Coil-combined image (the "coil" dimension is reduced away).
    """
    conjugate_maps = T.conjugate(sensitivity_map)
    coil_images = self.backward_operator(kspace)
    weighted_images = T.complex_multiplication(conjugate_maps, coil_images)
    return weighted_images.sum("coil")
def forward(
    self,
    input_image,
    masked_kspace,
    sensitivity_map,
    sampling_mask,
    loglikelihood_scaling=None,
):
    r"""
    Defines the MRI loglikelihood assuming one noise vector for the complex images for all coils.

    $$ \frac{1}{\sigma^2} \sum_{i}^{\text{num coils}}
        S_i^{\text{H}} \mathcal{F}^{-1} P^T (P \mathcal{F} S_i x_\tau - y_\tau) $$

    for each time step $\tau$.

    Parameters
    ----------
    input_image : torch.Tensor
        Initial or previous iteration of image (named tensor). If its names include
        "slice", all inputs are aligned to the 3D name orders first.
    masked_kspace : torch.Tensor
        Masked k-space (named tensor).
    sensitivity_map : torch.Tensor
        Coil sensitivity maps (named tensor). Required: it expands the image to per-coil
        images and is used again for the conjugate coil combination.
    sampling_mask : torch.Tensor
        Sampling mask. Where it equals zero, both the predicted and the measured k-space
        are set to zero.
    loglikelihood_scaling : torch.Tensor, optional
        Multiplier for the loglikelihood, for instance for the k-space noise. It must be a
        named tensor because it is aligned with ``align_as``/``align_to``. If None, no
        scaling is applied (equivalent to a multiplier of one).

    Returns
    -------
    torch.Tensor
        The image-domain term above, aligned to ``self.names_image_complex_channel``.
    """
    if "slice" in input_image.names:
        # NOTE(review): forward() mutates module state here and never resets ndim for
        # inputs without a "slice" name; confirm this is intended.
        self.ndim = 3
        input_image = input_image.align_to(*self.names_image_complex_last)
        sensitivity_map = sensitivity_map.align_to(*self.names_data_complex_last)
        masked_kspace = masked_kspace.align_to(*self.names_data_complex_last)
        if loglikelihood_scaling is not None:
            loglikelihood_scaling = loglikelihood_scaling.align_to(*self.names_data_complex_last)

    # S_i x: expand the image to per-coil images.
    mul = T.complex_multiplication(sensitivity_map, input_image.align_as(sensitivity_map))
    if loglikelihood_scaling is not None:
        # We multiply by the loglikelihood_scaling here to prevent fp16 information loss,
        # as this value is typically <<1, and the operators are linear.
        mul = loglikelihood_scaling.align_as(sensitivity_map) * mul

    # TODO: Named tensor: this needs a fix once this exists.
    # Names are dropped (rename(None)) around torch.where and restored with refine_names.
    mul_names = mul.names
    unsampled = sampling_mask.rename(None) == 0
    zero = torch.tensor([0.0], dtype=masked_kspace.dtype).to(masked_kspace.device)

    # P F S_i x: predicted k-space, zeroed at unsampled positions.
    mr_forward = torch.where(unsampled, zero, self.forward_operator(mul).rename(None))
    # P y: measured k-space, zeroed at unsampled positions.
    measured = torch.where(unsampled, zero, masked_kspace.rename(None))
    if loglikelihood_scaling is not None:
        # Scale the measurement the same way as the prediction so the residual is consistent.
        measured = loglikelihood_scaling * measured

    error = mr_forward - measured
    error = error.refine_names(*mul_names)

    # S_i^H F^{-1} (...), summed over coils.
    mr_backward = self.backward_operator(error)
    out = T.complex_multiplication(T.conjugate(sensitivity_map), mr_backward).sum("coil")
    return out.align_to(*self.names_image_complex_channel)  # noqa
mul_names = mul.names mr_forward = torch.where( sampling_mask.rename(None) == 0, torch.tensor([0.], dtype=masked_kspace.dtype).to(masked_kspace.device), transforms.fft2(mul).rename(None)) error = mr_forward - torch.where( sampling_mask.rename(None) == 0, torch.tensor([0.], dtype=masked_kspace.dtype).to(masked_kspace.device), masked_kspace.rename(None)) error = error.refine_names(*mul_names) mr_backward = transforms.ifft2(error) out = transforms.complex_multiplication(transforms.conjugate(sensitivity_map), mr_backward).sum('coil') # numpy # mul_numpy = sensitivity_map_numpy * input_image_numpy # mr_forward_numpy = sampling_mask_numpy * numpy_fft(mul_numpy) # error_numpy = mr_forward_numpy - sampling_mask_numpy * masked_kspace_numpy # mr_backward_numpy = numpy_ifft(error_numpy) # out_numpy = (sensitivity_map_numpy.conjugate() * mr_backward_numpy).sum(1) # np.allclose(tensor_to_complex_numpy(out), out_numpy) # numpy 2 mr_backward_numpy = numpy_ifft(sampling_mask_numpy * numpy_fft( sensitivity_map_numpy * input_image_numpy[:, np.newaxis, ...]) - sampling_mask_numpy * masked_kspace_numpy) out_numpy = (sensitivity_map_numpy.conjugate() * mr_backward_numpy).sum(1)