def test_smoke(self, device, dtype):
    # Factorize a small batch of random square matrices and check that the
    # factors multiply back to the input: U @ diag(S) @ V^T ~= A.
    batch = torch.randn(5, 3, 3, device=device, dtype=dtype)
    left, sing_vals, right = _torch_svd_cast(batch)

    # Half precision gets a much looser tolerance than every other dtype.
    if dtype == torch.float16:
        tol_val: float = 1e-1
    else:
        tol_val = 1e-3

    rebuilt = left @ torch.diag_embed(sing_vals) @ right.transpose(-2, -1)
    assert_allclose(batch, rebuilt, atol=tol_val, rtol=tol_val)
def zca_mean(
    inp: torch.Tensor, dim: int = 0, unbiased: bool = True, eps: float = 1e-6, return_inverse: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    r"""Computes the ZCA whitening matrix and mean vector.

    The output can be used with :py:meth:`~kornia.color.linear_transform`.
    See :class:`~kornia.color.ZCAWhitening` for details.

    The samples dimension ``dim`` is moved to the front and every sample is
    flattened into a feature vector. The feature covariance ``C`` is
    decomposed as ``C = U diag(S) U^T``, and the whitening matrix is
    ``T = U diag((S + eps)^{-1/2}) U^T``. The optional inverse transform is
    ``U diag((S + eps)^{1/2}) U^T``.

    args:
        inp (torch.Tensor): input data tensor.
        dim (int): Specifies the dimension that serves as the samples dimension.
          Negative values count from the last dimension. Default = 0
        unbiased (bool): Whether to use the unbiased estimate of the covariance matrix
          (normalize by ``N - 1`` instead of ``N``, where ``N = inp.size(dim)``).
          With a single sample and ``unbiased=True`` this normalization divides by zero.
          Default = True
        eps (float): a small number added to the singular values for numerical stability.
          Default = 1e-6
        return_inverse (bool): Whether to return the inverse ZCA transform. Default = False

    shapes:
        - inp: :math:`(D_0,...,D_{\text{dim}},...,D_N)` is a batch of N-D tensors.
        - transform_matrix: :math:`(\Pi_{d=0,d\neq \text{dim}}^N D_d, \Pi_{d=0,d\neq \text{dim}}^N D_d)`
        - mean_vector: :math:`(1, \Pi_{d=0,d\neq \text{dim}}^N D_d)`
        - inv_transform: same shape as the transform matrix

    returns:
        Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        A tuple containing the ZCA matrix and the mean vector. If return_inverse is set to True,
        the third element is the inverse ZCA matrix, otherwise it is None.

    raises:
        TypeError: if ``inp`` is not a tensor, ``eps`` is not a float, or ``unbiased``,
          ``dim`` or ``return_inverse`` have the wrong type.
        IndexError: if ``dim`` is out of range for ``inp``.

    Examples:
        >>> x = torch.tensor([[0,1],[1,0],[-1,0],[0,-1]], dtype = torch.float32)
        >>> transform_matrix, mean_vector,_ = zca_mean(x) # Returns transformation matrix and data mean
        >>> x = torch.rand(3,20,2,2)
        >>> transform_matrix, mean_vector, inv_transform = zca_mean(x, dim = 1, return_inverse = True)
        >>> # transform_matrix.size() equals (12,12) and the mean vector.size equal (1,12)
    """

    if not isinstance(inp, torch.Tensor):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(inp)))

    if not isinstance(eps, float):
        raise TypeError(f"eps type is not a float. Got{type(eps)}")

    if not isinstance(unbiased, bool):
        raise TypeError(f"unbiased type is not bool. Got{type(unbiased)}")

    if not isinstance(dim, int):
        raise TypeError("Argument 'dim' must be of type int. Got {}".format(type(dim)))

    if not isinstance(return_inverse, bool):
        raise TypeError("Argument return_inverse must be of type bool {}".format(type(return_inverse)))

    inp_size = inp.size()
    ndim = len(inp_size)

    if dim >= ndim or dim < -ndim:
        raise IndexError(
            "Dimension out of range (expected to be in range of [{},{}], but got {}".format(-ndim, ndim - 1, dim)
        )

    if dim < 0:
        dim = ndim + dim

    # Move the samples dimension to the front; the remaining (feature) dims keep
    # their original relative order. Plain Python ints suffice for this bookkeeping.
    feat_dims: List[int] = [d for d in range(ndim) if d != dim]
    new_order: List[int] = [dim] + feat_dims
    inp_permute = inp.permute(new_order)

    N = inp_size[dim]

    # Length of each flattened sample vector (1 when inp is 1-D, i.e. no feature dims).
    num_features: int = 1
    for d in feat_dims:
        num_features *= inp_size[d]

    mean: torch.Tensor = torch.mean(inp_permute, dim=0, keepdim=True)
    mean = mean.reshape((1, num_features))

    # Rows are samples, columns are features, centered by the per-feature mean.
    inp_center_flat: torch.Tensor = inp_permute.reshape((N, num_features)) - mean

    cov = inp_center_flat.t().mm(inp_center_flat)

    if unbiased:
        cov = cov / float(N - 1)
    else:
        cov = cov / float(N)

    U, S, _ = _torch_svd_cast(cov)

    # S as a column vector broadcasts across the rows of U^T, so
    # `S_inv_root * U.t()` equals diag(S_inv_root) @ U^T without building the diagonal.
    S = S.reshape(-1, 1)
    S_inv_root: torch.Tensor = torch.rsqrt(S + eps)
    T: torch.Tensor = U.mm(S_inv_root * U.t())

    T_inv: Optional[torch.Tensor] = None
    if return_inverse:
        T_inv = U.mm(torch.sqrt(S + eps) * U.t())

    return T, mean, T_inv