Example #1
    def test_smoke(self, device, dtype):
        # Decompose a batch of random 3x3 matrices and check that
        # U @ diag(S) @ V^T reconstructs the input.
        a = torch.randn(5, 3, 3, device=device, dtype=dtype)
        u, s, v = _torch_svd_cast(a)

        # Half precision needs a looser tolerance.
        tol_val: float = 1e-1 if dtype == torch.float16 else 1e-3
        assert_allclose(a,
                        u @ torch.diag_embed(s) @ v.transpose(-2, -1),
                        atol=tol_val,
                        rtol=tol_val)
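
For reference, a minimal standalone sketch of the same reconstruction check, assuming _torch_svd_cast can be imported from kornia.utils.helpers (the exact module path may differ between kornia versions):

import torch
from kornia.utils.helpers import _torch_svd_cast  # assumed import path

a = torch.randn(5, 3, 3)
u, s, v = _torch_svd_cast(a)  # SVD wrapper that casts dtypes SVD does not support

# Reconstruct the input from its factors: A = U @ diag(S) @ V^T
a_rec = u @ torch.diag_embed(s) @ v.transpose(-2, -1)
print(torch.allclose(a, a_rec, atol=1e-5, rtol=1e-5))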
Example #2
from typing import List, Optional, Tuple

import torch

from kornia.utils.helpers import _torch_svd_cast  # internal helper; import path may vary by version


def zca_mean(
    inp: torch.Tensor,
    dim: int = 0,
    unbiased: bool = True,
    eps: float = 1e-6,
    return_inverse: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    r"""

    Computes the ZCA whitening matrix and mean vector. The output can be used with
    :py:meth:`~kornia.color.linear_transform`

    See :class:`~kornia.color.ZCAWhitening` for details.


    args:
        inp (torch.Tensor) : input data tensor
        dim (int): Specifies the dimension that serves as the samples dimension. Default = 0
        unbiased (bool): Whether to use the unbiased estimate of the covariance matrix. Default = True
        eps (float) : a small number used for numerical stability. Default = 0
        return_inverse (bool): Whether to return the inverse ZCA transform.

    shapes:
        - inp: :math:`(D_0,...,D_{\text{dim}},...,D_N)` is a batch of N-D tensors.
        - transform_matrix: :math:`(\Pi_{d=0,d\neq \text{dim}}^N D_d, \Pi_{d=0,d\neq \text{dim}}^N D_d)`
        - mean_vector: :math:`(1, \Pi_{d=0,d\neq \text{dim}}^N D_d)`
        - inv_transform: same shape as the transform matrix

    returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        A tuple containing the ZCA matrix and the mean vector. If return_inverse is set to True,
        then it returns the inverse ZCA matrix, otherwise it returns None.

    Examples:
        >>> x = torch.tensor([[0,1],[1,0],[-1,0],[0,-1]], dtype = torch.float32)
        >>> transform_matrix, mean_vector,_ = zca_mean(x) # Returns transformation matrix and data mean
        >>> x = torch.rand(3,20,2,2)
        >>> transform_matrix, mean_vector, inv_transform = zca_mean(x, dim = 1, return_inverse = True)
        >>> # transform_matrix.size() equals (12,12) and the mean vector.size equal (1,12)

    """

    if not isinstance(inp, torch.Tensor):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(inp)))

    if not isinstance(eps, float):
        raise TypeError(f"eps type is not a float. Got{type(eps)}")

    if not isinstance(unbiased, bool):
        raise TypeError(f"unbiased type is not bool. Got{type(unbiased)}")

    if not isinstance(dim, int):
        raise TypeError("Argument 'dim' must be of type int. Got {}".format(
            type(dim)))

    if not isinstance(return_inverse, bool):
        raise TypeError(
            "Argument 'return_inverse' must be of type bool. Got {}".format(
                type(return_inverse)))

    inp_size = inp.size()

    if dim >= len(inp_size) or dim < -len(inp_size):
        raise IndexError(
            "Dimension out of range (expected to be in range of [{},{}], but got {})"
            .format(-len(inp_size),
                    len(inp_size) - 1, dim))

    if dim < 0:
        dim = len(inp_size) + dim

    # Move the sample dimension to the front; the remaining (feature)
    # dimensions are flattened into a single axis further below.
    feat_dims = torch.cat(
        [torch.arange(0, dim),
         torch.arange(dim + 1, len(inp_size))])

    new_order: List[int] = torch.cat([torch.tensor([dim]), feat_dims]).tolist()

    inp_permute = inp.permute(new_order)

    N = inp_size[dim]
    feature_sizes = torch.tensor(inp_size[0:dim] + inp_size[dim + 1::])
    num_features: int = int(torch.prod(feature_sizes).item())

    mean: torch.Tensor = torch.mean(inp_permute, dim=0, keepdim=True)

    mean = mean.reshape((1, num_features))

    inp_center_flat: torch.Tensor = inp_permute.reshape(
        (N, num_features)) - mean

    # Covariance of the centered, flattened data.
    cov = inp_center_flat.t().mm(inp_center_flat)

    if unbiased:
        cov = cov / float(N - 1)
    else:
        cov = cov / float(N)

    U, S, _ = _torch_svd_cast(cov)

    # ZCA whitening matrix: T = U @ diag(1/sqrt(S + eps)) @ U^T
    S = S.reshape(-1, 1)
    S_inv_root: torch.Tensor = torch.rsqrt(S + eps)
    T: torch.Tensor = (U).mm(S_inv_root * U.t())

    T_inv: Optional[torch.Tensor] = None
    if return_inverse:
        # Inverse transform: T_inv = U @ diag(sqrt(S + eps)) @ U^T
        T_inv = (U).mm(torch.sqrt(S + eps) * U.t())

    return T, mean, T_inv
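
As a usage note, here is a minimal sketch of applying the returned transform to whiten data and then undoing it. It applies the transform manually with plain tensor ops rather than through kornia.color.linear_transform, and the reshape mirrors the flattening done inside zca_mean:

import torch

# Assumes zca_mean is defined as above.
x = torch.rand(100, 3, 4, 4)              # 100 samples, 3*4*4 = 48 features
T, mean, T_inv = zca_mean(x, dim=0, return_inverse=True)

x_flat = x.reshape(100, -1)               # (N, num_features)
x_white = (x_flat - mean).mm(T)           # whitened data: roughly identity covariance
x_back = x_white.mm(T_inv) + mean         # approximately recovers x_flat
print(torch.allclose(x_flat, x_back, atol=1e-4))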