예제 #1
0
def root_sum_of_squares(data: torch.Tensor,
                        dim: Union[int, str] = "coil") -> torch.Tensor:
    r"""
    Compute the root sum of squares (RSS) transform along a given (perhaps named) dimension of the input tensor.

    $$x_{\textrm{rss}} = \sqrt{\sum_{i \in \textrm{coil}} |x_i|^2}$$

    If the tensor has a named ``"complex"`` dimension, :math:`|x_i|^2` is obtained by summing the
    squared components over that dimension before reducing over ``dim``.

    Parameters
    ----------
    data : torch.Tensor
        Named input tensor. If ``"complex"`` is one of its names, it must pass
        ``assert_complex(data, complex_last=True)``.

    dim : Union[int, str]
        Dimension to reduce over (the coil dimension), by name or by index. Default: ``"coil"``.

    Returns
    -------
    torch.Tensor : RSS of the input tensor.
    """
    assert_named(data)
    if "complex" in data.names:
        assert_complex(data, complex_last=True)
        # Sum the squared complex components first to get |x_i|^2, then reduce over `dim`.
        # NOTE(review): an integer `dim` indexes the tensor *after* the "complex" dim has been
        # summed away — confirm negative indices behave as callers expect.
        return torch.sqrt((data**2).sum("complex").sum(dim))
    # No "complex" dimension: the tensor is treated as real-valued, so x_i**2 stands in for |x_i|^2.
    return torch.sqrt((data**2).sum(dim))
예제 #2
0
def modulus_if_complex(data: torch.Tensor) -> torch.Tensor:
    """
    Take the modulus of a named tensor only when it carries a ``"complex"`` dimension.

    Parameters
    ----------
    data : torch.Tensor
        Named input tensor. It is passed to ``modulus`` only if ``"complex"`` appears in ``data.names``.

    Returns
    -------
    torch.Tensor
        ``modulus(data)`` when ``data`` has a ``"complex"`` dimension, otherwise ``data`` itself (unchanged).
    """
    # TODO: This can be merged with modulus if the tensor is real.
    assert_named(data)
    has_complex_axis = "complex" in data.names
    return modulus(data) if has_complex_axis else data
예제 #3
0
def safe_divide(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """
    Divide a and b safely, set the output to zero where the divisor b is zero.

    Parameters
    ----------
    a : torch.Tensor
        Named dividend.
    b : torch.Tensor
        Named divisor. It is aligned to ``a`` with ``align_as`` before dividing.

    Returns
    -------
    torch.Tensor: the division, with names refined to ``a.names``; entries where ``b == 0`` are zero.

    """
    assert_named(a)
    assert_named(b)

    b = b.align_as(a)
    a_plain = a.rename(None)
    b_plain = b.rename(None)
    divisor_is_zero = b_plain == 0

    # Replace zero divisors by one *before* dividing. Masking only the quotient with torch.where
    # would still evaluate a / 0 and back-propagate NaN/inf gradients from the discarded entries.
    safe_b = torch.where(divisor_is_zero, torch.ones_like(b_plain), b_plain)
    quotient = a_plain / safe_b

    # zeros_like matches the quotient's dtype and device, avoiding a host-to-device copy per call.
    data = torch.where(divisor_is_zero, torch.zeros_like(quotient), quotient)
    return data.refine_names(*a.names)