예제 #1
0
def _single_tensor_adadelta(params: List[Tensor],
                            grads: List[Tensor],
                            square_avgs: List[Tensor],
                            acc_deltas: List[Tensor],
                            *,
                            lr: float,
                            rho: float,
                            eps: float,
                            weight_decay: float,
                            maximize: bool):

    for (param, grad, square_avg, acc_delta) in zip(params, grads, square_avgs, acc_deltas):
        grad = grad if not maximize else -grad

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        if torch.is_complex(param):
            square_avg = torch.view_as_real(square_avg)
            acc_delta = torch.view_as_real(acc_delta)
            grad = torch.view_as_real(grad)

        square_avg.mul_(rho).addcmul_(grad, grad, value=1 - rho)
        std = square_avg.add(eps).sqrt_()
        delta = acc_delta.add(eps).sqrt_().div_(std).mul_(grad)
        acc_delta.mul_(rho).addcmul_(delta, delta, value=1 - rho)
        if torch.is_complex(param):
            delta = torch.view_as_complex(delta)
        param.add_(delta, alpha=-lr)
예제 #2
0
파일: mathops.py 프로젝트: antsfamily/thlib
def matmulcc(A, B):
    """Complex matrix product, like ``A * B`` in MATLAB.

    Parameters
    ----------
    A : torch tensor
        Either a complex tensor, or a real tensor whose last dimension has
        size 2 (index 0 --> real part, index 1 --> imaginary part).
    B : torch tensor
        Same conventions as :attr:`A`.

    Returns
    -------
    torch tensor
        A complex tensor when both inputs are complex; otherwise a real
        tensor with the real/imaginary parts stacked on a new last dimension.

    Notes
    -----
    The complex code path is taken only when *both* inputs are complex; any
    other combination is indexed as a real (..., 2) representation.
    """
    if not (th.is_complex(A) and th.is_complex(B)):
        a_re, a_im = A[..., 0], A[..., 1]
        b_re, b_im = B[..., 0], B[..., 1]
        out_re = th.matmul(a_re, b_re) - th.matmul(a_im, b_im)
        out_im = th.matmul(a_re, b_im) + th.matmul(a_im, b_re)
        return th.stack((out_re, out_im), dim=-1)

    a_re, a_im = A.real, A.imag
    b_re, b_im = B.real, B.imag
    out_re = th.matmul(a_re, b_re) - th.matmul(a_im, b_im)
    out_im = th.matmul(a_re, b_im) + th.matmul(a_im, b_re)
    return out_re + 1j * out_im
예제 #3
0
def einsum(equation, *operands):
    """Einstein summation that accepts ``ComplexTensor`` and native tensors.

    Dispatch rules:

    * one operand (or a single list/tuple of operands): forwarded to
      ``FC.einsum`` when the first operand is a ``ComplexTensor``, otherwise
      to ``torch.einsum``;
    * three or more operands: all must share one dtype, then forwarded the
      same way;
    * exactly two operands: a real/complex mix of native tensors is evaluated
      as two real-valued einsums (real and imaginary parts) and recombined.

    Args:
        equation: The einsum subscript string.
        *operands: The tensors to contract, or a single list/tuple of them.

    Returns:
        The result of the contraction.

    Raises:
        ValueError: If no operand is supplied, or if three or more operands
            do not all have the same dtype.
    """
    # NOTE: Do not mix ComplexTensor and torch.complex in the input!
    # NOTE (wangyou): Until PyTorch 1.9.0, torch.einsum does not support
    # mixed input with complex and real tensors.
    if len(operands) == 1 and isinstance(operands[0], (tuple, list)):
        unpacked = operands[0]
        # An empty list/tuple used to crash with IndexError below.
        if len(unpacked) == 0:
            raise ValueError("0 or More than 2 operands are not supported.")
        complex_module = FC if isinstance(unpacked[0], ComplexTensor) else torch
        return complex_module.einsum(equation, *unpacked)

    if len(operands) == 0:
        # Previously surfaced as IndexError from ``operands[0]``.
        raise ValueError("0 or More than 2 operands are not supported.")

    if len(operands) == 1:
        complex_module = FC if isinstance(operands[0], ComplexTensor) else torch
        return complex_module.einsum(equation, *operands)

    if len(operands) != 2:
        op0 = operands[0]
        same_type = all(op.dtype == op0.dtype for op in operands[1:])
        if not same_type:
            raise ValueError("0 or More than 2 operands are not supported.")
        _einsum = FC.einsum if isinstance(op0, ComplexTensor) else torch.einsum
        return _einsum(equation, *operands)

    a, b = operands
    if isinstance(a, ComplexTensor) or isinstance(b, ComplexTensor):
        return FC.einsum(equation, a, b)

    if is_torch_1_9_plus and (torch.is_complex(a) or torch.is_complex(b)):
        if not torch.is_complex(a):
            # real x complex: contract against each part of ``b`` separately.
            o_real = torch.einsum(equation, a, b.real)
            o_imag = torch.einsum(equation, a, b.imag)
            return torch.complex(o_real, o_imag)
        if not torch.is_complex(b):
            # complex x real: contract each part of ``a`` separately.
            o_real = torch.einsum(equation, a.real, b)
            o_imag = torch.einsum(equation, a.imag, b)
            return torch.complex(o_real, o_imag)

    return torch.einsum(equation, a, b)
예제 #4
0
def adadelta(params: List[Tensor], grads: List[Tensor],
             square_avgs: List[Tensor], acc_deltas: List[Tensor], *, lr: float,
             rho: float, eps: float, weight_decay: float):
    r"""Functional Adadelta step applied to each parameter in turn.

    ``params``, ``square_avgs`` and ``acc_deltas`` are updated in place;
    ``grads`` is only read. Complex parameters are processed through
    ``torch.view_as_real`` views of their state and gradient.

    See :class:`~torch.optim.Adadelta` for details.
    """
    keep = 1 - rho
    for p, g, sq_avg, acc in zip(params, grads, square_avgs, acc_deltas):
        handle_complex = torch.is_complex(p)

        if weight_decay != 0:
            g = g.add(p, alpha=weight_decay)

        if handle_complex:
            sq_avg, acc, g = (torch.view_as_real(sq_avg),
                              torch.view_as_real(acc),
                              torch.view_as_real(g))

        # Running average of squared gradients.
        sq_avg.mul_(rho).addcmul_(g, g, value=keep)
        rms_grad = sq_avg.add(eps).sqrt_()
        # Update = RMS(previous updates) / RMS(gradients) * gradient.
        update = acc.add(eps).sqrt_().div_(rms_grad).mul_(g)
        # Running average of squared updates.
        acc.mul_(rho).addcmul_(update, update, value=keep)

        if handle_complex:
            update = torch.view_as_complex(update)
        p.add_(update, alpha=-lr)
예제 #5
0
def inner_product(val1: Tensor, val2: Tensor, dim: int = -1) -> Tensor:
    """Complex inner product, summed over every element.

    The element-wise product comes from ``conj_complex_mult(val2, val1)``,
    i.e. ``val1`` is the conjugated operand.

    Args:
        val1: A tensor for the inner product.
        val2: A second tensor for the inner product (same dtype).
        dim: For real inputs only, the size-2 dimension holding the
            real/imaginary parts.

    Returns:
        For complex inputs, a 0-dim complex tensor. For real inputs, a
        length-2 tensor holding the summed real and imaginary parts.

    Raises:
        ValueError: If the dtypes differ, or a real input does not have size
            2 along ``dim``.
    """
    if val1.dtype != val2.dtype:
        raise ValueError("val1 has different dtype than val2.")

    real_repr = not torch.is_complex(val1)
    if real_repr and not (val1.shape[dim] == val2.shape[dim] == 2):
        raise ValueError(
            "Real input does not have dimension size 2 at dim.")

    product = conj_complex_mult(val2, val1, dim=dim)

    if not real_repr:
        return torch.sum(product)

    # Sum the real and imaginary slices separately, then join them.
    summed_parts = tuple(
        product.select(dim, part).sum().view(1) for part in (0, 1))
    return torch.cat(summed_parts)
예제 #6
0
파일: mathops.py 프로젝트: antsfamily/thlib
def ebemulcc(A, B):
    """Element-by-element complex multiplication, like ``A .* B`` in MATLAB.

    Parameters
    ----------
    A : torch tensor
        Either a complex tensor, or a real tensor whose last dimension has
        size 2 (index 0 --> real part, index 1 --> imaginary part).
    B : torch tensor
        Same conventions as :attr:`A`.

    Returns
    -------
    torch tensor
        A complex tensor when both inputs are complex; otherwise a real
        tensor with the real/imaginary parts stacked on a new last dimension.

    Notes
    -----
    The complex code path is taken only when *both* inputs are complex; any
    other combination is indexed as a real (..., 2) representation.
    """
    if not (th.is_complex(A) and th.is_complex(B)):
        a_re, a_im = A[..., 0], A[..., 1]
        b_re, b_im = B[..., 0], B[..., 1]
        prod_re = a_re * b_re - a_im * b_im
        prod_im = a_re * b_im + a_im * b_re
        return th.stack((prod_re, prod_im), dim=-1)

    prod_re = A.real * B.real - A.imag * B.imag
    prod_im = A.real * B.imag + A.imag * B.real
    return prod_re + 1j * prod_im
예제 #7
0
def dist_proj(X, Y):
    """Distance between the projectors onto the column spaces of X and Y.

    For each 2-D input ``M`` the matrix ``M (M^H M)^{-1} M^H`` is formed, and
    the result is ``||P_X - P_Y||_F / sqrt(2)``.

    Args:
        X: 2-D real or complex tensor; ``X^H X`` must be invertible.
        Y: 2-D real or complex tensor with the same number of rows as ``X``;
            ``Y^H Y`` must be invertible.

    Returns:
        A real 0-dim tensor.
    """
    def _projector(M):
        # M (M^H M)^{-1} M^H
        M_h = M.conj().t()
        return torch.matmul(M, torch.matmul(torch.inverse(torch.matmul(M_h, M)), M_h))

    diff = _projector(X) - _projector(Y)
    if torch.is_complex(X) or torch.is_complex(Y):
        # Frobenius norm = sqrt(sum |d_ij|^2). The previous code summed every
        # entry of diff @ diff^H (not just its diagonal), which is a different
        # quantity and disagreed with the real-valued branch.
        return torch.sqrt(torch.sum(diff.abs() ** 2)) / np.sqrt(2)
    return torch.norm(diff) / np.sqrt(2)
예제 #8
0
def adagrad(params: List[Tensor], grads: List[Tensor],
            state_sums: List[Tensor], state_steps: List[int],
            has_sparse_grad: bool, *, lr: float, weight_decay: float,
            lr_decay: float, eps: float):
    r"""Functional API that performs Adagrad algorithm computation.

    ``params`` and ``state_sums`` are updated in place. The tensors in
    ``grads`` are never modified (weight decay is applied to a copy).

    Args:
        params: Parameters to update.
        grads: Gradient of each parameter.
        state_sums: Accumulated squared gradients per parameter.
        state_steps: Step count per parameter, used for the learning-rate decay.
        has_sparse_grad: Whether the gradients are sparse; selects the
            per-tensor sparse code path.
        lr: Base learning rate.
        weight_decay: Coefficient of the ``param`` term added to the gradient.
        lr_decay: Learning-rate decay factor.
        eps: Constant added to the denominator.

    Raises:
        RuntimeError: If ``weight_decay`` is non-zero together with sparse
            gradients.

    See :class:`~torch.optim.Adagrad` for details.
    """
    # Nothing to do, and foreach ops may reject empty lists.
    if len(params) == 0:
        return

    if weight_decay != 0:
        if has_sparse_grad:
            raise RuntimeError(
                "weight_decay option is not compatible with sparse gradients")
        # Out-of-place on purpose: the in-place variant silently overwrote the
        # caller's gradient tensors.
        grads = torch._foreach_add(grads, params, alpha=weight_decay)

    minus_clr = [-lr / (1 + (step - 1) * lr_decay) for step in state_steps]

    if has_sparse_grad:
        # sparse is not supported by multi_tensor. Fall back to optim.adagrad
        # implementation for sparse gradients
        for param, grad, state_sum, step_size in zip(params, grads, state_sums,
                                                     minus_clr):
            # the update is non-linear so indices must be unique
            grad = grad.coalesce()
            grad_indices = grad._indices()
            grad_values = grad._values()

            state_sum.add_(_make_sparse(grad, grad_indices,
                                        grad_values.pow(2)))
            std_sparse = state_sum.sparse_mask(grad)
            std_sparse_values = std_sparse._values().sqrt_().add_(eps)
            param.add_(
                _make_sparse(grad, grad_indices,
                             grad_values / std_sparse_values),
                alpha=step_size,
            )
        return

    # Complex tensors are handled through real (..., 2) views; in-place ops on
    # the views write through to the original state tensors.
    real_grads = [
        torch.view_as_real(g) if torch.is_complex(g) else g for g in grads
    ]
    real_sums = [
        torch.view_as_real(s) if torch.is_complex(s) else s
        for s in state_sums
    ]
    torch._foreach_addcmul_(real_sums, real_grads, real_grads, value=1)
    std = torch._foreach_add(torch._foreach_sqrt(real_sums), eps)
    updates = torch._foreach_div(torch._foreach_mul(real_grads, minus_clr), std)
    updates = [
        torch.view_as_complex(u) if torch.is_complex(p) else u
        for p, u in zip(params, updates)
    ]
    torch._foreach_add_(params, updates)
예제 #9
0
def _multi_tensor_adagrad(params: List[Tensor], grads: List[Tensor],
                          state_sums: List[Tensor], state_steps: List[Tensor],
                          *, lr: float, weight_decay: float, lr_decay: float,
                          eps: float, has_sparse_grad: bool):

    # Foreach functions will throw errors if given empty lists
    if len(params) == 0:
        return

    if has_sparse_grad is None:
        has_sparse_grad = any([grad.is_sparse for grad in grads])

    if has_sparse_grad:
        return _single_tensor_adagrad(params,
                                      grads,
                                      state_sums,
                                      state_steps,
                                      lr=lr,
                                      weight_decay=weight_decay,
                                      lr_decay=lr_decay,
                                      eps=eps,
                                      has_sparse_grad=has_sparse_grad)

    # Update steps
    torch._foreach_add_(state_steps, 1)

    if weight_decay != 0:
        torch._foreach_add_(grads, params, alpha=weight_decay)

    minus_clr = [-lr / (1 + (step - 1) * lr_decay) for step in state_steps]

    grads = [
        torch.view_as_real(x) if torch.is_complex(x) else x for x in grads
    ]
    state_sums = [
        torch.view_as_real(x) if torch.is_complex(x) else x for x in state_sums
    ]
    torch._foreach_addcmul_(state_sums, grads, grads, value=1)
    std = torch._foreach_add(torch._foreach_sqrt(state_sums), eps)
    toAdd = torch._foreach_div(torch._foreach_mul(grads, minus_clr), std)
    toAdd = [
        torch.view_as_complex(x) if torch.is_complex(params[i]) else x
        for i, x in enumerate(toAdd)
    ]
    torch._foreach_add_(params, toAdd)
    state_sums = [
        torch.view_as_complex(x) if torch.is_complex(params[i]) else x
        for i, x in enumerate(state_sums)
    ]
예제 #10
0
 def detect_complex(x):
     """Return True if ``x`` holds complex data.

     * list: True if any element is a Python ``complex``;
     * ``torch.Tensor`` (including subclasses such as ``nn.Parameter``):
       the result of ``torch.is_complex``;
     * anything else: True only for a ``complex`` scalar.
     """
     # isinstance instead of ``type(x) == ...`` so Tensor subclasses (e.g.
     # nn.Parameter) are no longer misreported as non-complex.
     if isinstance(x, list):
         return any(isinstance(v, complex) for v in x)
     if isinstance(x, torch.Tensor):
         return torch.is_complex(x)
     return isinstance(x, complex)
예제 #11
0
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Filter ``x`` along ``self.axis`` with the coefficients in ``self.coeffs``.

        Args:
            x: Tensor or array-like to filter. Must be real, in shape
                ``[Batch, chns, spatial1, spatial2, ...]``. It is converted to
                ``torch.float`` and stays on its current device.
        Returns:
            torch.Tensor: result of ``separable_filtering`` applied with
            ``self.coeffs`` on ``self.axis`` and a size-1 kernel of ones on
            every other spatial dimension.
        Raises:
            ValueError: if ``x`` is complex, or ``self.axis`` is negative or
                not smaller than the number of dimensions of ``x``.
        """
        source_device = x.device if isinstance(x, torch.Tensor) else None
        x = torch.as_tensor(x, device=source_device)
        if torch.is_complex(x):
            raise ValueError("x must be real.")
        x = x.to(dtype=torch.float)

        rank = len(x.shape)
        if not 0 <= self.axis <= rank - 1:
            raise ValueError("Invalid axis for shape of x.")

        # The first two dimensions are batch and channel; one kernel is needed
        # per remaining (spatial) dimension.
        kernels_before = self.axis - 2
        kernels_after = (rank - 2) - kernels_before - 1

        def _identity_kernel():
            return torch.ones(1, device=x.device, dtype=x.dtype)

        kernel_list = [_identity_kernel() for _ in range(kernels_before)]
        kernel_list.append(self.coeffs.to(device=x.device, dtype=x.dtype))
        kernel_list.extend(_identity_kernel() for _ in range(kernels_after))

        return separable_filtering(x, kernel_list, mode=self.mode)
예제 #12
0
    def forward(self, input: ComplexTensor, ilens: torch.Tensor):
        """Run ``self.stft.inverse`` on a complex spectrum.

        Args:
            input (ComplexTensor): spectrum [Batch, T, (C,) F]
            ilens (torch.Tensor): input lengths [Batch]

        Returns:
            The pair ``(wav, wav_lens)`` from ``self.stft.inverse``. For
            4-D input, channels are folded into the batch before the inverse
            and ``wav`` is reshaped back afterwards.

        Raises:
            TypeError: If ``input`` is not a ``ComplexTensor`` and, on
                PyTorch 1.9+, is not a native complex tensor either.
        """
        if not isinstance(input, ComplexTensor):
            if is_torch_1_9_plus and not torch.is_complex(input):
                raise TypeError("Only support complex tensors for stft decoder")

        bs = input.size(0)
        multi_channel = input.dim() == 4
        if multi_channel:
            # input: (Batch, T, C, F) -> (Batch * C, T, F)
            n_frames, n_freq = input.size(1), input.size(3)
            input = input.transpose(1, 2).reshape(-1, n_frames, n_freq)

        wav, wav_lens = self.stft.inverse(input, ilens)

        if multi_channel:
            # wav: (Batch * C, Nsamples) -> (Batch, Nsamples, C)
            wav = wav.reshape(bs, -1, wav.size(1)).transpose(1, 2)

        return wav, wav_lens
예제 #13
0
def conj_complex_mult(val1: Tensor, val2: Tensor, dim: int = -1) -> Tensor:
    """Multiply ``val1`` by the complex conjugate of ``val2``.

    Args:
        val1: A tensor to be multiplied.
        val2: A second tensor (same dtype) to be conjugated then multiplied.
        dim: For real inputs only, the size-2 dimension holding the
            real/imaginary parts.

    Returns:
        ``val1 * conj(val2)``; for real inputs the real and imaginary parts
        of the product are stacked along ``dim``.

    Raises:
        ValueError: If the dtypes differ, or a real input does not have size
            2 along ``dim``.
    """
    if val1.dtype != val2.dtype:
        raise ValueError("val1 has different dtype than val2.")

    if torch.is_complex(val1):
        return val1 * val2.conj()

    if not (val1.shape[dim] == val2.shape[dim] == 2):
        raise ValueError(
            "Real input does not have dimension size 2 at dim.")

    re1, im1 = (val1.select(dim, part) for part in (0, 1))
    re2, im2 = (val2.select(dim, part) for part in (0, 1))

    # (a + ib) * (c - id) = (ac + bd) + i(bc - ad)
    out_re = re1 * re2 + im1 * im2
    out_im = im1 * re2 - re1 * im2
    return torch.stack((out_re, out_im), dim)
예제 #14
0
    def forward(self, P, G):
        """Apply ``self.lossfn`` to the phase of the Fourier transforms of P and G.

        Steps: pick the transform axes (``self.axis``, or all but the first
        dimension when the input has more than two dimensions, else all);
        convert a real (real, imag) representation stored along
        ``self.caxis`` to complex; optionally divide both inputs by
        ``max|G| + 1e-16`` when ``self.norm`` is ``'max'``; FFT along each
        axis; compare the angles.
        """
        ndim = P.dim()
        axis = self.axis
        if axis is None:
            axis = list(range(1, ndim)) if ndim > 2 else list(range(0, ndim))
        axis = [a + ndim if a < 0 else a for a in axis]

        if self.caxis is not None and not th.is_complex(P):
            cdim = self.caxis + ndim if self.caxis < 0 else self.caxis
            if cdim != ndim - 1:
                # Move the real/imag dimension last and shift the transform
                # axes that sat behind it.
                order = [*range(0, cdim), *range(cdim + 1, ndim), cdim]
                P, G = P.permute(order), G.permute(order)
                axis = [a if a < cdim else a - 1 for a in axis]

            P = P[..., 0] + 1j * P[..., 1]
            G = G[..., 0] + 1j * G[..., 1]

        if self.norm in ('max', 'MAX', 'Max'):
            scale = G.abs().max() + 1e-16
            P, G = P / scale, G / scale

        for a in axis:
            P = th.fft.fft(P, n=None, dim=a)
            G = th.fft.fft(G, n=None, dim=a)

        return self.lossfn(P.angle(), G.angle())
예제 #15
0
파일: norm.py 프로젝트: aisari/torchsar
    def forward(self, X):
        """Return the log of the reduced p-norm-style measure of ``X``.

        The magnitude of ``X`` is taken first: ``abs`` for complex input,
        otherwise the root of the sum of squares over the last dimension
        (when ``self.caxis`` is None or the last dimension has size 2) or
        over ``self.caxis``. Then ``mean(|X|^p)^(1/p)`` is computed over
        ``self.axis`` (default: all but the first dimension when there are
        more than two dimensions, else all), reduced according to
        ``self.reduction`` (``'mean'``, ``'sum'``, anything else: no
        reduction), and the natural log is returned.
        """
        if th.is_complex(X):
            X = X.abs()
        elif (self.caxis is None) or X.shape[-1] == 2:
            X = X.pow(2).sum(axis=-1, keepdims=True).sqrt()
        else:
            X = X.pow(2).sum(axis=self.caxis, keepdims=True).sqrt()

        # The former test ``X.dtype is not th.float32 or th.double`` was always
        # true (``or th.double`` is truthy), so float64 data was silently
        # downcast. Only non-float32/float64 data is converted now.
        if X.dtype not in (th.float32, th.float64):
            X = X.to(th.float32)

        if self.axis is None:
            D = X.dim()
            axis = list(range(1, D)) if D > 2 else list(range(0, D))
        else:
            axis = self.axis

        X = th.mean(X.pow(self.p), axis=axis).pow(1. / self.p)

        if self.reduction == 'mean':
            F = th.mean(X)
        elif self.reduction == 'sum':
            F = th.sum(X)
        else:
            # Any other setting previously crashed on an unbound local;
            # return the unreduced values instead.
            F = X

        return th.log(F)
예제 #16
0
def perform_filter_operation(
    Y: Union[torch.Tensor, ComplexTensor],
    filter_matrix_conj: Union[torch.Tensor, ComplexTensor],
    taps,
    delay,
) -> Union[torch.Tensor, ComplexTensor]:
    """Subtract the filtered, delayed copies of ``Y`` from ``Y``.

    Args:
        Y : Complex-valued STFT signal of shape (F, C, T)
        filter_matrix_conj : filter matrix of shape (F, taps, C, C)
        taps : number of delayed copies to build
        delay : shift (in frames) of the first delayed copy

    Returns:
        ``Y`` minus the einsum ``"fpde,pfdt->fet"`` of the filter with the
        stacked delayed copies.

    Raises:
        ValueError: If ``Y`` is neither a ``ComplexTensor`` nor (on PyTorch
            1.9+) a native complex tensor.
    """
    if isinstance(Y, ComplexTensor):
        complex_module, pad_func = FC, FC.pad
    elif is_torch_1_9_plus and torch.is_complex(Y):
        complex_module, pad_func = torch, F.pad
    else:
        raise ValueError(
            "Please update your PyTorch version to 1.9+ for complex support.")

    n_frames = Y.size(-1)

    # Each copy drops its last ``shift`` frames and is zero-padded on the left
    # by the same amount, so every copy keeps the original length.
    delayed = []
    for tap in range(taps):
        shift = delay + tap
        delayed.append(
            pad_func(Y[:, :, :n_frames - shift], (shift, 0),
                     mode="constant",
                     value=0))

    # Y_tilde: (taps, F, C, T)
    Y_tilde = complex_module.stack(delayed, dim=0)
    reverb_tail = complex_module.einsum("fpde,pfdt->fet",
                                        (filter_matrix_conj, Y_tilde))
    return Y - reverb_tail
예제 #17
0
    def forward(self, X):
        """Return the negative contrast of ``X``.

        ``X`` is first turned into a squared magnitude (``X * conj(X)`` for
        complex input, sum of squares over the last dimension when that
        dimension has size 2). Statistics are then taken over every
        dimension except the first:

        * ``'way1'``: ``std(X) / (mean(X) + EPS)``
        * ``'way2'``: ``mean(X) / (mean(sqrt(X))^2 + EPS)``

        The result is reduced per ``self.reduction`` (``'mean'``, ``'sum'``,
        anything else: no reduction) and negated.

        Raises:
            ValueError: If ``self.mode`` is not one of the supported modes.
        """
        if th.is_complex(X):
            X = (X * X.conj()).real
        elif X.size(-1) == 2:
            X = X.pow(2).sum(axis=-1)

        axis = list(range(1, X.dim()))

        # The former test ``X.dtype is not th.float32 or th.double`` was always
        # true (``or th.double`` is truthy), so float64 data was silently
        # downcast. Only non-float32/float64 data is converted now.
        if X.dtype not in (th.float32, th.float64):
            X = X.to(th.float32)

        if self.mode in ['way1', 'WAY1']:
            Xmean = X.mean(axis=axis, keepdims=True)
            C = (X - Xmean).pow(2).mean(axis=axis,
                                        keepdims=True).sqrt() / (Xmean + EPS)
        elif self.mode in ['way2', 'WAY2']:
            C = X.mean(axis=axis, keepdims=True) / (
                (X.sqrt().mean(axis=axis, keepdims=True)).pow(2) + EPS)
        else:
            # Previously failed later with an unbound-local error.
            raise ValueError(f"Unsupported mode: {self.mode!r}")

        if self.reduction == 'mean':
            C = th.mean(C)
        if self.reduction == 'sum':
            C = th.sum(C)
        return -C
예제 #18
0
파일: entropy.py 프로젝트: aisari/torchsar
    def forward(self, X):
        """Return the entropy of the normalised squared magnitude of ``X``.

        The squared magnitude is ``X * conj(X)`` for complex input; for real
        input it is the sum of squares over the last dimension (when
        ``self.caxis`` is None or the last dimension has size 2) or over
        ``self.caxis``, keeping that dimension. It is normalised by its sum
        over ``self.axis`` (default: all but the first dimension when there
        are more than two dimensions, else all). The entropy uses ``log2``
        for the Shannon modes and ``log`` for the Natural modes, and is
        reduced per ``self.reduction`` (``'mean'`` / ``'sum'``).
        """
        if th.is_complex(X):
            X = (X * X.conj()).real
        else:
            use_last = (self.caxis is None) or X.shape[-1] == 2
            square_dim = -1 if use_last else self.caxis
            X = th.sum(X.pow(2), axis=square_dim, keepdims=True)

        axis = self.axis
        if axis is None:
            ndim = X.dim()
            axis = list(range(1, ndim)) if ndim > 2 else list(range(0, ndim))

        total = th.sum(X, axis, keepdims=True)
        prob = X / (total + EPS)

        if self.mode in ('Shannon', 'shannon', 'SHANNON'):
            S = -th.sum(prob * th.log2(prob + EPS), axis)
        if self.mode in ('Natural', 'natural', 'NATURAL'):
            S = -th.sum(prob * th.log(prob + EPS), axis)

        if self.reduction == 'mean':
            S = th.mean(S)
        if self.reduction == 'sum':
            S = th.sum(S)

        return S
예제 #19
0
    def forward(self, P, G):
        """Mean absolute difference between ``P`` and ``G``.

        A real (real, imag) representation stored along ``self.caxis`` is
        converted to complex first (only for tensors whose moved-last
        dimension has size 2). With ``self.norm`` equal to ``'max'`` both
        inputs are divided by ``max|G| + 1e-16``. The absolute difference is
        averaged over every dimension except the first, then reduced per
        ``self.reduction`` (``'mean'``, ``'sum'``, anything else: none).
        """
        ndim = P.dim()
        if self.caxis is not None and not th.is_complex(P):
            cdim = self.caxis + ndim if self.caxis < 0 else self.caxis
            if cdim != ndim - 1:
                # Bring the real/imag dimension to the end.
                order = [*range(0, cdim), *range(cdim + 1, ndim), cdim]
                P, G = P.permute(order), G.permute(order)

            if P.shape[-1] == 2:
                P = P[..., 0] + 1j * P[..., 1]
            if G.shape[-1] == 2:
                G = G[..., 0] + 1j * G[..., 1]

        if self.norm in ('max', 'MAX', 'Max'):
            scale = G.abs().max() + 1e-16
            P, G = P / scale, G / scale

        per_sample_dims = list(range(1, P.dim()))
        F = th.mean((P - G).abs(), axis=per_sample_dims)

        if self.reduction == 'mean':
            F = th.mean(F)
        if self.reduction == 'sum':
            F = th.sum(F)

        return F
예제 #20
0
def frobenius(X, p=2, reduction='mean'):
    r"""Mean-based p-norm of the magnitude of ``X``, per leading index.

    The magnitude is ``|X|`` for complex input, or the root of the sum of
    squares over the last dimension when that dimension has size 2. For each
    index of the first dimension the value

        .. math::
           \left(\operatorname{mean}\,|x|^p\right)^{1/p}

    is computed over all remaining dimensions.

    Parameters
    ----------
    X : torch tensor
        complex tensor, or real tensor (a last dimension of size 2 is read as
        real/imaginary parts).
    p : number, optional
        the exponent (default 2).
    reduction : str or None, optional
        ``'mean'`` or ``'sum'`` over the per-index values; any other value
        returns them unreduced.
    """
    if th.is_complex(X):
        X = ((X * X.conj()).real).sqrt()
    elif X.size(-1) == 2:
        X = X.pow(2).sum(axis=-1).sqrt()

    # The former test ``X.dtype is not th.float32 or th.double`` was always
    # true (``or th.double`` is truthy), so float64 data was silently
    # downcast. Only non-float32/float64 data is converted now.
    if X.dtype not in (th.float32, th.float64):
        X = X.to(th.float32)

    dim = list(range(1, X.dim()))
    X = th.mean(X.pow(p), axis=dim).pow(1. / p)

    if reduction == 'mean':
        return th.mean(X)
    if reduction == 'sum':
        return th.sum(X)
    # Any other setting previously crashed on an unbound local.
    return X
예제 #21
0
파일: ffts.py 프로젝트: aisari/torchsar
def ifft(x, n=None, axis=0, norm="backward", shift=False):
    """Inverse FFT for complex tensors or real (..., 2) representations.

    A non-complex input whose last dimension has size 2 is read as
    (real, imag) pairs: it is viewed as complex, transformed, and viewed back
    as real. Any other input is passed to the inverse FFT directly.

    Parameters
    ----------
    x : torch tensor
        complex tensor, or real tensor with last dimension 2 (real, imag).
    n : int, optional
        number of ifft points, forwarded to the inverse FFT (the default is
        None).
    axis : int, optional
        axis of ifft (the default is 0). For the real representation a
        negative axis is counted with the trailing size-2 dimension included.
    norm : str, optional
        normalization mode forwarded to the inverse FFT: ``"forward"``,
        ``"backward"`` (default) or ``"ortho"``. ``None`` is treated as
        ``"backward"``.
    shift : bool, optional
        if True, ``ifftshift`` is applied along :attr:`axis` both before and
        after the transform (the default is False).

    Returns
    -------
    y : torch tensor
        ifft result in the same representation as :attr:`x`.
    """
    norm = 'backward' if norm is None else norm

    realflag = (x.size(-1) == 2) and (not th.is_complex(x))
    if realflag:
        x = th.view_as_complex(x)
        # The trailing (real, imag) dimension is gone, so a negative axis
        # moves one step towards the end.
        if axis < 0:
            axis += 1

    if not shift:
        y = thfft.ifft(x, n=n, dim=axis, norm=norm)
    else:
        centered = thfft.ifftshift(x, dim=axis)
        transformed = thfft.ifft(centered, n=n, dim=axis, norm=norm)
        y = thfft.ifftshift(transformed, dim=axis)

    return th.view_as_real(y) if realflag else y
예제 #22
0
    def forward(self, X):
        """Return the negative contrast of ``X``.

        The squared magnitude is ``X * conj(X)`` for complex input; for real
        input it is the sum of squares over the last dimension (when
        ``self.caxis`` is None or the last dimension has size 2) or over
        ``self.caxis``, keeping that dimension. Statistics are taken over
        ``self.axis`` (default: all but the first dimension when there are
        more than two dimensions, else all):

        * ``'way1'``: ``std(X) / (mean(X) + EPS)``
        * ``'way2'``: ``mean(X) / (mean(sqrt(X))^2 + EPS)``

        The result is reduced per ``self.reduction`` (``'mean'``, ``'sum'``,
        anything else: no reduction) and negated.

        Raises:
            ValueError: If ``self.mode`` is not one of the supported modes.
        """
        if th.is_complex(X):
            X = (X * X.conj()).real
        elif (self.caxis is None) or X.shape[-1] == 2:
            X = th.sum(X.pow(2), axis=-1, keepdims=True)
        else:
            X = th.sum(X.pow(2), axis=self.caxis, keepdims=True)

        if self.axis is None:
            D = X.dim()
            axis = list(range(1, D)) if D > 2 else list(range(0, D))
        else:
            axis = self.axis

        # The former test ``X.dtype is not th.float32 or th.double`` was always
        # true (``or th.double`` is truthy), so float64 data was silently
        # downcast. Only non-float32/float64 data is converted now.
        if X.dtype not in (th.float32, th.float64):
            X = X.to(th.float32)

        if self.mode in ['way1', 'WAY1']:
            Xmean = X.mean(axis=axis, keepdims=True)
            C = (X - Xmean).pow(2).mean(axis=axis,
                                        keepdims=True).sqrt() / (Xmean + EPS)
        elif self.mode in ['way2', 'WAY2']:
            C = X.mean(axis=axis, keepdims=True) / (
                (X.sqrt().mean(axis=axis, keepdims=True)).pow(2) + EPS)
        else:
            # Previously failed later with an unbound-local error.
            raise ValueError(f"Unsupported mode: {self.mode!r}")

        if self.reduction == 'mean':
            C = th.mean(C)
        if self.reduction == 'sum':
            C = th.sum(C)
        return -C
예제 #23
0
def getLamdaGaplist(lambdas: torch.Tensor):
    """
    Return the differences between consecutive entries of ``lambdas`` along
    its first dimension. Complex input is reduced to its real part first.
    """
    values = torch.real(lambdas) if torch.is_complex(lambdas) else lambdas
    upper, lower = values[1:], values[:-1]
    return upper - lower
예제 #24
0
파일: entropy.py 프로젝트: aisari/torchsar
    def forward(self, X):
        """Return the entropy of the normalised squared magnitude of ``X``.

        The squared magnitude is ``X * conj(X)`` for complex input, or the
        sum of squares over the last dimension when that dimension has size
        2. Sums are taken over axes ``(0, 1)`` for 2-D data, ``(1, 2)`` for
        3-D data and ``(1, 2, 3)`` for 4-D data; other ranks are not
        handled. The entropy uses ``log2`` for the Shannon modes and ``log``
        for the Natural modes, and is reduced per ``self.reduction``
        (``'mean'`` / ``'sum'``).
        """
        if th.is_complex(X):
            X = (X * X.conj()).real
        elif X.size(-1) == 2:
            X = th.sum(X.pow(2), axis=-1)

        axes_by_rank = {2: (0, 1), 3: (1, 2), 4: (1, 2, 3)}
        rank = X.dim()
        if rank in axes_by_rank:
            axis = axes_by_rank[rank]

        total = th.sum(X, axis, keepdims=True)
        prob = X / (total + EPS)

        if self.mode in ('Shannon', 'shannon', 'SHANNON'):
            S = -th.sum(prob * th.log2(prob + EPS), axis)
        if self.mode in ('Natural', 'natural', 'NATURAL'):
            S = -th.sum(prob * th.log(prob + EPS), axis)

        if self.reduction == 'mean':
            S = th.mean(S)
        if self.reduction == 'sum':
            S = th.sum(S)

        return S
예제 #25
0
파일: utils.py 프로젝트: cxdcxd/pykeen
def negative_norm(
    x: torch.FloatTensor,
    p: Union[str, int] = 2,
    power_norm: bool = False,
) -> torch.FloatTensor:
    """Evaluate the negative norm of vectors along the last dimension.

    :param x: shape: (batch_size, num_heads, num_relations, num_tails, dim)
        The vectors.
    :param p:
        The p for the norm. cf. torch.norm. Must not be a string when
        ``power_norm`` is set or ``x`` is complex.
    :param power_norm:
        Whether to return $|x-y|_p^p$, cf. https://github.com/pytorch/pytorch/issues/28119

    :return: shape: (batch_size, num_heads, num_relations, num_tails)
        The scores.
    """
    manual = power_norm or torch.is_complex(x)
    if not manual:
        return -x.norm(p=p, dim=-1)

    # Manual computation: needed for the p-th power of the norm, and used as
    # a workaround for complex inputs.
    assert not isinstance(p, str)
    powered_sum = (x.abs() ** p).sum(dim=-1)
    if power_norm:
        return -powered_sum
    return -powered_sum ** (1 / p)
예제 #26
0
    def forward(self, X):
        """Sum of ``log2(1 + |FFT(X)| / self.p)`` over the transform axes.

        The transform axes are ``self.axis`` (default: all but the first
        dimension when there are more than two dimensions, else all).
        Non-complex input is read as a (real, imag) representation stored
        along ``self.caxis`` and converted to complex before the FFT. The
        result is reduced per ``self.reduction`` (``'mean'``, ``'sum'``,
        anything else: none).
        """
        ndim = X.dim()
        axis = self.axis
        if axis is None:
            axis = list(range(1, ndim)) if ndim > 2 else list(range(0, ndim))
        axis = [a + ndim if a < 0 else a for a in axis]

        if not th.is_complex(X):
            cdim = self.caxis + ndim if self.caxis < 0 else self.caxis
            if cdim != ndim - 1:
                # Move the real/imag dimension last and shift the transform
                # axes that sat behind it.
                order = [*range(0, cdim), *range(cdim + 1, ndim), cdim]
                X = X.permute(order)
                axis = [a if a < cdim else a - 1 for a in axis]

            X = X[..., 0] + 1j * X[..., 1]

        for a in axis:
            X = th.fft.fft(X, n=None, dim=a)

        magnitude = X.abs()
        S = th.sum(th.log2(1 + magnitude / self.p), axis)

        if self.reduction == 'mean':
            S = th.mean(S)
        if self.reduction == 'sum':
            S = th.sum(S)

        return S
예제 #27
0
def signal_framing(
    signal: Union[torch.Tensor, ComplexTensor],
    frame_length: int,
    frame_step: int,
    pad_value=0,
) -> Union[torch.Tensor, ComplexTensor]:
    """Expand the last dimension of ``signal`` into overlapping frames.

    The signal is right-padded with ``frame_length - 1`` values of
    ``pad_value``; frames of ``frame_length`` samples are then taken every
    ``frame_step`` samples. Complex inputs are framed separately on their
    real and imaginary parts and recombined in the same type.

    Args:
        signal : (B * F, D, T)
    Returns:
        torch.Tensor: (B * F, D, T, W)
    """
    if isinstance(signal, ComplexTensor):
        recombine = ComplexTensor
    elif is_torch_1_9_plus and torch.is_complex(signal):
        recombine = torch.complex
    else:
        recombine = None

    if recombine is not None:
        real = signal_framing(signal.real, frame_length, frame_step, pad_value)
        imag = signal_framing(signal.imag, frame_length, frame_step, pad_value)
        return recombine(real, imag)

    signal = F.pad(signal, (0, frame_length - 1), "constant", pad_value)

    # Flat list of sample indices: one run of ``frame_length`` consecutive
    # positions per frame start.
    stop = signal.size(-1) - frame_length + 1
    indices = [
        position
        for start in range(0, stop, frame_step)
        for position in range(start, start + frame_length)
    ]

    leading_shape = signal.size()[:-1]
    signal = signal[..., indices].view(*leading_shape, -1, frame_length)
    return signal
예제 #28
0
def batch_fft(data, normalize=False):
    """
    Compute the 2-D Fourier transform of every item in a batch.

    Args:
        data: input tensor, (NxHxW); real input is promoted to complex with a
            zero imaginary part.
        normalize: if True the transform uses ``norm='ortho'``, otherwise
            ``norm='backward'``.

    Returns:
        Batch fourier transform of input data (over dimensions 1 and 2).

    Raises:
        AttributeError: if ``data`` does not have exactly three dimensions.
    """
    dim = data.ndim - 1  # the leading dimension is the batch
    if dim != 2:
        raise AttributeError(f'Data must be 2d but it is {dim}d.')

    norm = 'ortho' if normalize else 'backward'
    transform_dims = tuple(range(1, dim + 1))  # skip the batch dimension

    if not torch.is_complex(data):
        data = torch.complex(data, torch.zeros_like(data))

    return fftn(data, dim=transform_dims, norm=norm)
예제 #29
0
    def forward(self, X, w=None):
        r"""Weighted sum over the second-to-last axis, raised to ``self.p``.

        The magnitude of ``X`` is taken when it is complex, or when it is
        real with a last dimension of size 2 (read as real/imag pairs).
        Then ``sum_last( (sum_{-2}(w * X)) ** p )`` is computed and reduced
        per ``self.reduction``.

        Parameters
        ----------
        X : torch tensor
            complex tensor, or real tensor (a last dimension of size 2 is
            treated as a complex representation).
        w : torch tensor, optional
            weights multiplied with the magnitude (the default is None,
            which uses ones along the second-to-last axis).

        Returns
        -------
        torch tensor
            the ``'mean'`` or ``'sum'`` of the per-sample values.
        """
        if th.is_complex(X):
            X = X.abs()
        elif X.shape[-1] == 2:
            X = th.view_as_complex(X).abs()

        if w is None:
            # Ones shaped to broadcast along the second-to-last axis.
            wshape = [1] * X.dim()
            wshape[-2] = X.size(-2)
            w = th.ones(wshape, device=X.device, dtype=X.dtype)

        weighted = th.sum(w * X, axis=-2)
        fv = th.sum(weighted.pow(self.p), axis=-1)

        if self.reduction == 'mean':
            C = th.mean(fv)
        if self.reduction == 'sum':
            C = th.sum(fv)
        return C
예제 #30
0
파일: adamax.py 프로젝트: huaxz1986/pytorch
def _multi_tensor_adamax(params: List[Tensor],
                         grads: List[Tensor],
                         exp_avgs: List[Tensor],
                         exp_infs: List[Tensor],
                         state_steps: List[Tensor],
                         *,
                         beta1: float,
                         beta2: float,
                         lr: float,
                         weight_decay: float,
                         eps: float,
                         maximize: bool):

    if len(params) == 0:
        return

    if maximize:
        grads = torch._foreach_neg(grads)

    params = [torch.view_as_real(x) if torch.is_complex(x) else x for x in params]
    grads = [torch.view_as_real(x) if torch.is_complex(x) else x for x in grads]
    exp_avgs = [torch.view_as_real(x) if torch.is_complex(x) else x for x in exp_avgs]
    exp_infs = [torch.view_as_real(x) if torch.is_complex(x) else x for x in exp_infs]

    # Update steps
    torch._foreach_add_(state_steps, 1)

    if weight_decay != 0:
        torch._foreach_add_(grads, params, alpha=weight_decay)

    # Update biased first moment estimate.
    torch._foreach_mul_(exp_avgs, beta1)
    torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1)

    # Update the exponentially weighted infinity norm.
    torch._foreach_mul_(exp_infs, beta2)

    for exp_inf, grad in zip(exp_infs, grads):
        norm_buf = torch.cat([
            exp_inf.unsqueeze(0),
            grad.abs().add_(eps).unsqueeze_(0)
        ], 0)
        torch.max(norm_buf, 0, keepdim=False, out=(exp_inf, exp_inf.new().long()))

    bias_corrections = [1 - beta1 ** step.item() for step in state_steps]
    clr = [-1 * (lr / bias_correction) for bias_correction in bias_corrections]
    torch._foreach_addcdiv_(params, exp_avgs, exp_infs, clr)