def root_sum_of_squares(data: torch.Tensor, dim: Union[int, str] = "coil") -> torch.Tensor:
    r"""
    Compute the root sum of squares (RSS) transform along a given (perhaps named) dimension of the input tensor.

    $$x_{\textrm{rss}} = \sqrt{\sum_{i \in \textrm{coil}} |x_i|^2}$$

    Parameters
    ----------
    data : torch.Tensor
        Input named tensor. If it has a dimension named ``"complex"``, it is treated as complex-valued:
        that dimension must be the last one (checked via ``assert_complex(..., complex_last=True)``),
        and the squares along it are summed to form :math:`|x_i|^2`.
    dim : Union[int, str]
        Coil dimension to reduce over. Default: ``"coil"``.

    Returns
    -------
    torch.Tensor :
        RSS of the input tensor, with ``dim`` (and the ``"complex"`` dimension, if present) reduced.
    """
    # Raw docstring (r""") so that "\textrm" and "\sqrt" are not parsed as escape sequences.
    assert_named(data)
    if "complex" in data.names:
        assert_complex(data, complex_last=True)
        # |x|^2 = re^2 + im^2: reduce the "complex" axis first, then the coil axis.
        # NOTE: an integer `dim` is therefore applied to the tensor *without* the (last) complex
        # axis, so negative integer indices are shifted by one relative to `data`.
        return torch.sqrt((data**2).sum("complex").sum(dim))
    return torch.sqrt((data**2).sum(dim))
def modulus_if_complex(data: torch.Tensor) -> torch.Tensor:
    """
    Return the modulus of `data` if it is complex-valued, otherwise `data` unchanged.

    A tensor is considered complex-valued when one of its dimension names is ``"complex"``.

    Parameters
    ----------
    data : torch.Tensor
        Named input tensor.

    Returns
    -------
    torch.Tensor
        ``modulus(data)`` when a ``"complex"`` dimension is present, else the input tensor itself.
    """
    # TODO: This can be merged with modulus if the tensor is real.
    assert_named(data)
    has_complex_axis = "complex" in data.names
    return modulus(data) if has_complex_axis else data
def safe_divide(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """
    Divide a and b safely, set the output to zero where the divisor b is zero.

    Parameters
    ----------
    a : torch.Tensor
        Named numerator tensor. Its dimension names define the names of the output.
    b : torch.Tensor
        Named divisor tensor. It is aligned to ``a`` with ``align_as`` before dividing.

    Returns
    -------
    torch.Tensor:
        The (broadcast) quotient ``a / b`` named like ``a``, with zeros wherever ``b == 0``.
    """
    assert_named(a)
    assert_named(b)

    names = a.names
    numerator = a.rename(None)
    denominator = b.align_as(a).rename(None)

    zero_mask = denominator == 0
    # Replace zero divisors by one *before* dividing. Dividing by the raw `b` and masking afterwards
    # gives the right forward values, but autograd still differentiates the discarded a / 0 branch,
    # which yields 0 * inf = NaN gradients for both a and b.
    safe_denominator = torch.where(zero_mask, torch.ones_like(denominator), denominator)
    quotient = numerator / safe_denominator
    data = torch.where(zero_mask, torch.zeros_like(quotient), quotient)
    return data.refine_names(*names)