Example #1
def root_sum_of_squares(data: torch.Tensor,
                        dim: Union[int, str] = "coil") -> torch.Tensor:
    """
    Compute the root sum of squares (RSS) transform along a given (perhaps named) dimension of the input tensor.

    $$x_{\textrm{rss}} = \sqrt{\sum_{i \in \textrm{coil}} |x_i|^2}$$

    Parameters
    ----------
    data : torch.Tensor
        Input tensor

    dim : Union[int, str]
        Coil dimension.

    Returns
    -------
    torch.Tensor : RSS of the input tensor.
    """
    assert_named(data)
    if "complex" in data.names:
        assert_complex(data, complex_last=True)
        return torch.sqrt((data**2).sum("complex").sum(dim))
    else:
        return torch.sqrt((data**2).sum(dim))
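
A minimal usage sketch (not from the library itself); it assumes the helpers used above (assert_named, assert_complex) are importable, and the shapes and dimension names are illustrative:

import torch

# Hypothetical 8-coil input with an explicit "complex" dimension.
multicoil = torch.randn(8, 320, 320, 2, names=("coil", "height", "width", "complex"))
rss_image = root_sum_of_squares(multicoil, dim="coil")  # named ("height", "width") tensor
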
Example #2
def fft2_new(
    data: torch.Tensor,
    dim: Tuple[str, ...] = ("height", "width"),
    centered: bool = True,
    normalized: bool = True,
) -> torch.Tensor:
    """
    Apply centered two-dimensional Fast Fourier Transform. Can be performed in half precision when
    input shapes are powers of two.

    Version for PyTorch >= 1.7.0.

    Parameters
    ----------
    data : torch.Tensor
        Complex-valued input tensor.
    dim : tuple of str
        Named dimensions over which to compute the fft.
    centered : bool
        Whether to apply a centered fft (the center of k-space is placed in the center rather than in the corners).
        For the FastMRI dataset this has to be True and for the Calgary-Campinas dataset False.
    normalized : bool
        Whether to normalize the fft ("ortho" norm). For the FastMRI dataset this has to be True and for the Calgary-Campinas dataset False.
    Returns
    -------
    torch.Tensor: the fft of the input tensor.
    """
    assert_complex(data)
    names = data.names
    data = view_as_complex(data)
    if centered:
        data = ifftshift(data, dim=dim)
    # Verify that the fft is possible for this dtype and shape (half precision requires power-of-two shapes); otherwise raise.
    if verify_fft_dtype_possible(data, dim):
        data = torch.fft.fftn(
            data.rename(None),
            dim=_dims_to_index(dim, data.names),
            norm="ortho" if normalized else None,
        )
    else:
        raise ValueError("Currently half precision FFT is not supported.")

    if any(names):
        data = data.refine_names(*names[:-1])  # type: ignore

    if centered:
        data = fftshift(data, dim=dim)

    data = view_as_real(data)
    return data
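
A hedged usage sketch for the forward transform; it assumes the helpers referenced above (view_as_complex, view_as_real, ifftshift, fftshift, verify_fft_dtype_possible, _dims_to_index) come from the same module, and the shape is illustrative:

import torch

# Hypothetical complex-valued image, transformed to k-space with FastMRI-style settings.
image = torch.randn(320, 320, 2, names=("height", "width", "complex"))
kspace = fft2_new(image, dim=("height", "width"), centered=True, normalized=True)
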
Example #3
def apply_mask(
    kspace: torch.Tensor,
    mask_func: Union[Callable, torch.Tensor],
    seed: Optional[int] = None,
    return_mask: bool = True,
) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
    """
    Subsample kspace by setting kspace to zero as given by a binary mask.

    Parameters
    ----------
    kspace : torch.Tensor
        k-space as a complex-valued tensor.
    mask_func : callable or torch.Tensor
        Masking function that takes a shape and returns a mask of that shape (or one that can be broadcast to it).
        Can also be a precomputed sampling mask.
    seed : int, optional
        Seed for the random number generator.
    return_mask : bool
        If True, the mask is returned together with the masked k-space.

    Returns
    -------
    masked k-space (torch.Tensor), mask (torch.Tensor); only the masked k-space if return_mask is False.
    """
    # TODO: Split the function to apply_mask_func and apply_mask

    assert_complex(kspace, enforce_named=True)
    names = kspace.names
    kspace = kspace.rename(None)

    if not isinstance(mask_func, torch.Tensor):
        shape = np.array(kspace.shape)[1:]  # The first dimension is always the coil dimension.
        mask = mask_func(shape, seed)
    else:
        mask = mask_func

    masked_kspace = torch.where(
        mask == 0, torch.tensor([0.0], dtype=kspace.dtype), kspace
    )

    mask = mask.refine_names(*names)
    masked_kspace = masked_kspace.refine_names(*names)
    if not return_mask:
        return masked_kspace

    return masked_kspace, mask
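
A usage sketch with a precomputed sampling mask instead of a mask function; the undersampling pattern and shapes are made-up assumptions:

import torch

kspace = torch.randn(8, 320, 320, 2, names=("coil", "height", "width", "complex"))
# Hypothetical column mask, broadcastable over the coil and complex dimensions.
sampling_mask = (torch.rand(1, 1, 320, 1) > 0.5).float()
masked_kspace, mask = apply_mask(kspace, sampling_mask)
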
Example #4
def modulus(data: torch.Tensor) -> torch.Tensor:
    """
    Compute modulus of complex input data. Assumes there is a dimension called "complex" in the data.

    Parameters
    ----------
    data : torch.Tensor
        Input tensor with a named "complex" dimension.

    Returns
    -------
    torch.Tensor: modulus of data.
    """
    assert_complex(data, enforce_named=True, complex_last=False)
    # TODO: Named tensors typing not yet fully supported in pytorch.
    return (data**2).sum("complex").sqrt()  # noqa
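
For illustration, a short sketch of how this might be called (shape is an assumption):

import torch

image = torch.randn(320, 320, 2, names=("height", "width", "complex"))
magnitude = modulus(image)  # real-valued tensor named ("height", "width")
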
Example #5
def tensor_to_complex_numpy(data: torch.Tensor) -> np.ndarray:
    """
    Converts a complex-valued PyTorch tensor to a complex NumPy array.
    The last axis denotes the real and imaginary parts respectively.

    Parameters
    ----------
    data : torch.Tensor
        Input data

    Returns
    -------
    Complex valued np.ndarray
    """
    assert_complex(data)
    data = data.detach().cpu().numpy()
    return data[..., 0] + 1j * data[..., 1]
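
A small usage sketch; only the trailing real/imaginary axis matters, the leading shape is an assumption:

import torch

real_imag = torch.randn(320, 320, 2)
complex_array = tensor_to_complex_numpy(real_imag)  # complex np.ndarray of shape (320, 320)
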
Example #6
def ifft2_old(
    data: torch.Tensor,
    dim: Tuple[str, ...] = ("height", "width"),
    centered: bool = True,
    normalized: bool = True,
) -> torch.Tensor:
    """
    Apply centered two-dimensional Inverse Fast Fourier Transform. Can be performed in half precision when
    input shapes are powers of two.

    Legacy version using the pre-1.8 `torch.ifft` API.

    Parameters
    ----------
    data : torch.Tensor
        Complex-valued input tensor.
    dim : tuple of str
        Named dimensions over which to compute the ifft.
    centered : bool
        Whether to apply a centered ifft (the center of k-space is placed in the center rather than in the corners).
        For the FastMRI dataset this has to be True and for the Calgary-Campinas dataset False.
    normalized : bool
        Whether to normalize the ifft. For the FastMRI dataset this has to be True and for the Calgary-Campinas dataset False.

    Returns
    -------
    torch.Tensor: the ifft of the input tensor.
    """
    assert_complex(data)

    if centered:
        data = ifftshift(data, dim=dim)
    names = data.names
    # TODO: Fix when ifft supports named tensors
    # Verify whether half precision and if ifft is possible in this shape. Else do a typecast.
    if verify_fft_dtype_possible(data, dim):
        data = torch.ifft(data.rename(None), 2, normalized=normalized)
    else:
        data = torch.ifft(data.rename(None).float(), 2, normalized=normalized).type(
            data.type()
        )

    if any(names):
        data = data.refine_names(*names)

    if centered:
        data = fftshift(data, dim=dim)
    return data
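
A usage sketch for the legacy variant; since `torch.ifft` only exists in PyTorch versions before 1.8, this assumes such an installation plus the library's shift helpers:

import torch

kspace = torch.randn(320, 320, 2, names=("height", "width", "complex"))
image = ifft2_old(kspace, dim=("height", "width"), centered=True, normalized=True)
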
Example #7
def conjugate(data: torch.Tensor) -> torch.Tensor:
    """
    Compute the complex conjugate of a torch tensor where the last axis denotes the real and complex part.

    Parameters
    ----------
    data : torch.Tensor

    Returns
    -------
    torch.Tensor
    """
    assert_complex(data, enforce_named=True)
    names = data.names
    # Clone is required as the data in the next line is changed in-place.
    data = data.rename(None).clone()
    data[..., 1] = data[..., 1] * -1.0
    data = data.refine_names(*names)
    return data
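
A brief usage sketch (illustrative shape):

import torch

z = torch.randn(320, 320, 2, names=("height", "width", "complex"))
z_conj = conjugate(z)  # same shape and names, imaginary part negated
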
Example #8
def complex_multiplication(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Multiplies two complex-valued tensors. Assumes the tensor has a named dimension "complex".

    Parameters
    ----------
    x : torch.Tensor
        Input data
    y : torch.Tensor
        Input data

    Returns
    -------
    torch.Tensor
    """
    assert_complex(x, enforce_named=True, complex_last=True)
    assert_complex(y, enforce_named=True, complex_last=True)
    # multiplication = torch.view_as_complex(x.rename(None)) * torch.view_as_complex(
    #     y.rename(None)
    # )
    # return torch.view_as_real(multiplication).refine_names(*x.names)
    # TODO: Unsqueezing is not yet supported for named tensors, fix when it is.
    complex_index = x.names.index("complex")

    real_part = x.select("complex", 0) * y.select("complex", 0) - x.select(
        "complex", 1
    ) * y.select("complex", 1)
    imaginary_part = x.select("complex", 0) * y.select("complex", 1) + x.select(
        "complex", 1
    ) * y.select("complex", 0)

    real_part = real_part.rename(None)
    imaginary_part = imaginary_part.rename(None)

    multiplication = torch.cat(
        [
            real_part.unsqueeze(dim=complex_index),
            imaginary_part.unsqueeze(dim=complex_index),
        ],
        dim=complex_index,
    )

    return multiplication.refine_names(*x.names)
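
A usage sketch assuming both operands carry the same named dimensions with "complex" last:

import torch

x = torch.randn(320, 320, 2, names=("height", "width", "complex"))
y = torch.randn(320, 320, 2, names=("height", "width", "complex"))
xy = complex_multiplication(x, y)  # elementwise complex product, names of x preserved
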
Example #9
def view_as_complex(data: torch.Tensor) -> torch.Tensor:
    """Named version of `torch.view_as_complex()`: the output keeps all input names except the trailing "complex" dimension."""
    assert_complex(data)
    return torch.view_as_complex(data.rename(None)).refine_names(*data.names[:-1])
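
Finally, a usage sketch for the named wrapper (assuming a PyTorch build in which named tensors accept complex dtypes, as the library code above already relies on):

import torch

data = torch.randn(320, 320, 2, names=("height", "width", "complex"))
z = view_as_complex(data)  # complex64 tensor named ("height", "width")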