def root_sum_of_squares(data: torch.Tensor, dim: Union[int, str] = "coil") -> torch.Tensor:
    r"""
    Compute the root sum of squares (RSS) transform along a given (perhaps named) dimension of the input tensor.

    $$x_{\textrm{rss}} = \sqrt{\sum_{i \in \textrm{coil}} |x_i|^2}$$

    Parameters
    ----------
    data : torch.Tensor
        Input tensor; it is checked with `assert_named`. When one of its dimensions is called "complex",
        it is additionally checked with `assert_complex(..., complex_last=True)` and that dimension is
        summed out together with `dim`.
    dim : Union[int, str]
        Coil dimension.

    Returns
    -------
    torch.Tensor : RSS of the input tensor.
    """
    # NOTE: the docstring is a raw string so the LaTeX backslashes (\textrm, \sqrt, \sum) are kept
    # literally instead of being parsed as (partly invalid) escape sequences.
    assert_named(data)
    if "complex" in data.names:
        assert_complex(data, complex_last=True)
        # |x|^2 = re^2 + im^2: collapse the "complex" axis first, then the coil axis.
        # NOTE(review): an integer `dim` is applied after "complex" has been removed — confirm that
        # callers passing negative integers expect this.
        return torch.sqrt((data**2).sum("complex").sum(dim))
    return torch.sqrt((data**2).sum(dim))
def fft2_new(
    data: torch.Tensor,
    dim: Tuple[str, ...] = ("height", "width"),
    centered: bool = True,
    normalized: bool = True,
) -> torch.Tensor:
    """
    Apply centered two-dimensional (forward) Fast Fourier Transform.
    Version for PyTorch >= 1.7.0.

    Parameters
    ----------
    data : torch.Tensor
        Complex-valued input tensor.
    dim : tuple of str
        Names of the dimensions over which to compute.
    centered : bool
        Whether to apply a centered fft (center of kspace is in the center versus in the corners).
        For FastMRI dataset this has to be true and for the Calgary-Campinas dataset false.
    normalized : bool
        Whether to normalize the fft ("ortho" norm). For the FastMRI this has to be true and for the
        Calgary-Campinas dataset false.

    Returns
    -------
    torch.Tensor: the fft of the input.

    Raises
    ------
    ValueError
        When `verify_fft_dtype_possible` rejects the input for the requested dimensions.
    """
    assert_complex(data)
    # Keep the original names (including the trailing complex axis) to restore them afterwards.
    names = data.names
    data = view_as_complex(data)

    if centered:
        data = ifftshift(data, dim=dim)

    # Verify whether the fft is possible for this dtype and shape; no typecast fallback is done here.
    if verify_fft_dtype_possible(data, dim):
        # torch.fft.fftn is called on the unnamed tensor, so the named dims are mapped to indices.
        data = torch.fft.fftn(
            data.rename(None),
            dim=_dims_to_index(dim, data.names),
            norm="ortho" if normalized else None,
        )
    else:
        raise ValueError("Currently half precision FFT is not supported.")

    if any(names):
        # The tensor is in complex view here, so the last (complex) name is dropped.
        data = data.refine_names(*names[:-1])

    if centered:
        data = fftshift(data, dim=dim)

    data = view_as_real(data)
    return data
def apply_mask(
    kspace: torch.Tensor,
    mask_func: Union[Callable, torch.Tensor],
    seed: Optional[int] = None,
    return_mask: bool = True,
) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
    """
    Subsample kspace by setting kspace to zero as given by a binary mask.

    Parameters
    ----------
    kspace : torch.Tensor
        k-space as a complex-valued (named) tensor.
    mask_func : callable or torch.tensor
        Masking function, taking a shape and a seed and returning a mask with this shape or one that
        can be broadcasted as such. Can also be a sampling mask.
    seed : int
        Seed for the random number generator; only forwarded to `mask_func` when it is a callable.
    return_mask : bool
        If true, mask will be returned

    Returns
    -------
    masked data (torch.Tensor), mask (torch.Tensor) when `return_mask` is true,
    otherwise only the masked data (torch.Tensor).
    """
    # TODO: Split the function to apply_mask_func and apply_mask
    assert_complex(kspace, enforce_named=True)
    names = kspace.names
    kspace = kspace.rename(None)

    if isinstance(mask_func, torch.Tensor):
        mask = mask_func
    else:
        # The first dimension is always the coil dimension.
        shape = np.array(kspace.shape)[1:]
        mask = mask_func(shape, seed)

    # The fill value must live on the same device as kspace, otherwise torch.where fails for
    # non-CPU inputs.
    zero = torch.zeros(1, dtype=kspace.dtype, device=kspace.device)
    masked_kspace = torch.where(mask == 0, zero, kspace)

    mask = mask.refine_names(*names)
    masked_kspace = masked_kspace.refine_names(*names)

    if not return_mask:
        return masked_kspace
    return masked_kspace, mask
def modulus(data: torch.Tensor) -> torch.Tensor:
    """
    Return the modulus of complex-valued data stored with a named "complex" dimension.

    Parameters
    ----------
    data : torch.Tensor
        Named tensor holding a dimension called "complex".

    Returns
    -------
    torch.Tensor: square root of the squared entries summed over the "complex" dimension.
    """
    assert_complex(data, enforce_named=True, complex_last=False)

    # TODO: Named tensors typing not yet fully supported in pytorch.
    squared_magnitude = (data**2).sum("complex")  # noqa
    return torch.sqrt(squared_magnitude)
def tensor_to_complex_numpy(data: torch.Tensor) -> np.ndarray:
    """
    Convert a complex pytorch tensor into a complex numpy array.

    The last axis of the input holds the real part at index 0 and the imaginary part at index 1.

    Parameters
    ----------
    data : torch.Tensor
        Input data

    Returns
    -------
    Complex valued np.ndarray
    """
    assert_complex(data)

    # Detach from the graph and move to host memory before handing over to numpy.
    array = data.detach().cpu().numpy()
    real_part = array[..., 0]
    imaginary_part = array[..., 1]
    return real_part + 1j * imaginary_part
def ifft2_old(
    data: torch.Tensor,
    dim: Tuple[str, ...] = ("height", "width"),
    centered: bool = True,
    normalized: bool = True,
) -> torch.Tensor:
    """
    Centered two-dimensional Inverse Fast Fourier Transform based on the legacy `torch.ifft` call.

    Parameters
    ----------
    data : torch.Tensor
        Complex-valued input tensor.
    dim : tuple, list or int
        Dimensions used for the shifts and for the dtype/shape check.
    centered : bool
        Whether to apply a centered ifft (center of kspace is in the center versus in the corners).
        For FastMRI dataset this has to be true and for the Calgary-Campinas dataset false.
    normalized : bool
        Whether to normalize the ifft. For the FastMRI this has to be true and for the Calgary-Campinas
        dataset false.

    Returns
    -------
    torch.Tensor: the ifft of the input.
    """
    assert_complex(data)

    shifted = ifftshift(data, dim=dim) if centered else data

    # TODO: Fix when ifft supports named tensors
    names = shifted.names

    if verify_fft_dtype_possible(shifted, dim):
        # The dtype/shape check passed: transform the tensor as it is.
        transformed = torch.ifft(shifted.rename(None), 2, normalized=normalized)
    else:
        # Otherwise compute in float and cast back to the input type afterwards.
        as_float = shifted.rename(None).float()
        transformed = torch.ifft(as_float, 2, normalized=normalized).type(shifted.type())

    if any(names):
        transformed = transformed.refine_names(*names)

    if not centered:
        return transformed
    return fftshift(transformed, dim=dim)
def conjugate(data: torch.Tensor) -> torch.Tensor:
    """
    Return the complex conjugate of a named tensor whose last axis holds the real and imaginary part.

    Parameters
    ----------
    data : torch.Tensor

    Returns
    -------
    torch.Tensor
    """
    assert_complex(data, enforce_named=True)
    original_names = data.names

    # Work on a copy: the imaginary channel is overwritten in-place below and the caller's
    # tensor must stay untouched.
    conjugated = data.rename(None).clone()
    negated_imaginary = conjugated[..., 1] * -1.0
    conjugated[..., 1] = negated_imaginary

    return conjugated.refine_names(*original_names)
def complex_multiplication(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Multiply two complex-valued tensors that carry a named dimension "complex".

    Parameters
    ----------
    x : torch.Tensor
        Input data
    y : torch.Tensor
        Input data

    Returns
    -------
    torch.Tensor
        Product of `x` and `y`, carrying the dimension names of `x`.
    """
    assert_complex(x, enforce_named=True, complex_last=True)
    assert_complex(y, enforce_named=True, complex_last=True)

    complex_axis = x.names.index("complex")

    x_real, x_imag = x.select("complex", 0), x.select("complex", 1)
    y_real, y_imag = y.select("complex", 0), y.select("complex", 1)

    # (a + ib)(c + id) = (ac - bd) + i(ad + bc)
    real_part = x_real * y_real - x_imag * y_imag
    imaginary_part = x_real * y_imag + x_imag * y_real

    # TODO: Unsqueezing is not yet supported for named tensors, fix when it is.
    # Names are dropped before re-inserting the complex axis and restored afterwards.
    product = torch.stack(
        [real_part.rename(None), imaginary_part.rename(None)],
        dim=complex_axis,
    )
    return product.refine_names(*x.names)
def view_as_complex(data):
    """
    Named version of `torch.view_as_complex()`.

    The names are stripped for the torch call and all names except the last one are
    restored on the result.
    """
    assert_complex(data)

    unnamed = data.rename(None)
    as_complex = torch.view_as_complex(unnamed)
    remaining_names = data.names[:-1]
    return as_complex.refine_names(*remaining_names)