Example #1
def __call__(self, arg):
    """Apply the filter to a (possibly batched) signal."""
    N = arg.shape[-1]
    if N not in self._cached:
        # build and cache the sparse operator for this size
        if self.strict:
            print(f"caching sparse operator for size {N}")
        self.cache(N)
    # look up the cached (frequency response, window) pair
    d, w = self._cached[N]
    if d.dim() == arg.dim():
        F_arg = rfft(w * arg)
        return irfft(d * F_arg, n=N)  # n=N so odd lengths round-trip
    # batched: transform along the signal dimension
    F_arg = rfft(w * arg, dim=1)
    return irfft(d * F_arg, n=N, dim=1)
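The method above depends on its class's `self._cached` dict and `self.cache` builder. A minimal standalone sketch of the same cache-and-apply pattern (all names here are illustrative, not from the original project):

import torch
from torch.fft import rfft, irfft, rfftfreq

_cache = {}

def apply_lowpass(x, cutoff=0.25):
    """Apply a cached ideal low-pass response along the last dimension."""
    n = x.shape[-1]
    if n not in _cache:
        _cache[n] = (rfftfreq(n) <= cutoff).to(x.dtype)  # build once per size
    return irfft(_cache[n] * rfft(x), n=n)  # n=n so odd lengths round-trip

y = apply_lowpass(torch.randn(8, 100))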
Example #2
def convolve1d(signal: torch.Tensor, kernel: torch.Tensor) -> torch.Tensor:
    """
    Computes the 1-d convolution of signal by kernel using FFTs.
    Both signal and kernel must be 1-dimensional.
    :param torch.Tensor signal: A signal to convolve.
    :param torch.Tensor kernel: A convolution kernel.
    :return torch.Tensor: The full convolution of signal with kernel, i.e.
        the output tensor will have size m + n - 1, where m is the length of
        the signal and n is the length of the kernel.
    """
    assert (signal.ndim == 1
            and kernel.ndim == 1), "signal and kernel must be 1-dimensional"
    m = signal.size(-1)
    n = kernel.size(-1)

    # Compute convolution using fft.
    padded_size = m + n - 1
    # Round up for cheaper fft.
    fast_fft_size = next_fast_len(padded_size)
    f_signal = rfft(signal, n=fast_fft_size)
    f_kernel = rfft(kernel, n=fast_fft_size)
    f_result = f_signal * f_kernel
    result = irfft(f_result, n=fast_fft_size)

    return result[:padded_size]
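A quick check of the snippet above against numpy's full convolution (this assumes `rfft`/`irfft` come from `torch.fft` and that `next_fast_len` behaves like `scipy.fft.next_fast_len`):

import numpy as np
import torch

s, k = torch.randn(100), torch.randn(10)
out = convolve1d(s, k)
assert out.shape == (109,)  # m + n - 1
ref = torch.from_numpy(np.convolve(s.numpy(), k.numpy()))
assert torch.allclose(out, ref, atol=1e-4)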
Example #3
def _istft(x,
           n_fft,
           ola_weight,
           win_length,
           window,
           hop_length,
           center,
           normalized,
           onesided,
           pad_mode,
           return_complex,
           norm_envelope=None):
    """
    A helper function to do istft.
    """
    if onesided:
        x = fft.irfft(x,
                      n=n_fft,
                      dim=-2,
                      norm='ortho' if normalized else 'backward')
    else:
        x = fft.ifft(x,
                     n=n_fft,
                     dim=-2,
                     norm='ortho' if normalized else 'backward').real

    x, norm_envelope = _ola(x,
                            hop_length,
                            ola_weight,
                            padding=n_fft // 2 if center else 0,
                            norm_envelope=norm_envelope)
    return x, norm_envelope
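`_ola` and `ola_weight` are project-internal helpers not shown here, so only the spectral half can be checked in isolation. A round-trip of the onesided branch (per-frame transforms along dim=-2):

import torch
from torch import fft

n_fft = 64
frames = torch.randn(n_fft, 10)  # (fft bins, time frames)
spec = fft.rfft(frames, n=n_fft, dim=-2, norm='ortho')
recon = fft.irfft(spec, n=n_fft, dim=-2, norm='ortho')
assert torch.allclose(recon, frames, atol=1e-5)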
Example #4
def fft_convolve(signal, kernel):
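    # Zero-pad the signal on the right and the kernel on the left, each to
    # twice its length; dropping the first half after the inverse transform
    # then removes the circular wrap-around, leaving a linear convolution.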
    signal = nn.functional.pad(signal, (0, signal.shape[-1]))
    kernel = nn.functional.pad(kernel, (kernel.shape[-1], 0))

    output = fft.irfft(fft.rfft(signal) * fft.rfft(kernel))
    output = output[..., output.shape[-1] // 2:]

    return output
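For equal-length inputs this returns the first N samples of the full linear convolution. A sanity check, assuming the snippet's `nn`/`fft` aliases for `torch.nn` and `torch.fft`:

import numpy as np
import torch

n = 32
x, k = torch.randn(n), torch.randn(n)
out = fft_convolve(x, k)
ref = np.convolve(x.numpy(), k.numpy())  # length 2n - 1
assert torch.allclose(out, torch.from_numpy(ref[:n]), atol=1e-4)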
Example #5
    def __call__(self, image: torch.Tensor,
                 context: ExpressionContext) -> torch.Tensor:
        std = 15. * torch.Tensor(context(self.std)).to(image.device).reshape(
            3, 1)

        flat = image.reshape(3, -1)
        space = fft.rfft(flat)
        space.real = space.real + torch.randn_like(space.real) * std
        space.imag = space.imag + torch.randn_like(space.imag) * std

        # n= restores the original length even when it is odd
        return fft.irfft(space, n=flat.shape[-1]).reshape(*image.shape)
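The method relies on `ExpressionContext` and `self.std` from its host project. A standalone sketch of the same idea, perturbing an image's rfft spectrum with Gaussian noise (the 0.1 scale is illustrative):

import torch
from torch import fft

image = torch.rand(3, 64, 64)
flat = image.reshape(3, -1)
spec = fft.rfft(flat)
noise = torch.randn_like(spec.real) + 1j * torch.randn_like(spec.real)
noisy = fft.irfft(spec + 0.1 * noise, n=flat.shape[-1]).reshape(*image.shape)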
Example #6
def amp_to_impulse_response(amp, target_size):
    amp = torch.stack([amp, torch.zeros_like(amp)], -1)
    amp = torch.view_as_complex(amp)
    amp = fft.irfft(amp)

    filter_size = amp.shape[-1]

    amp = torch.roll(amp, filter_size // 2, -1)
    win = torch.hann_window(filter_size, dtype=amp.dtype, device=amp.device)

    amp = amp * win

    amp = nn.functional.pad(amp, (0, int(target_size) - int(filter_size)))
    amp = torch.roll(amp, -filter_size // 2, -1)

    return amp
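For example, a flat magnitude response over 65 rfft bins becomes a centered, windowed impulse response padded to the requested length, ready for `fft_convolve` above (assuming the snippet's `torch`/`nn`/`fft` imports):

amp = torch.ones(65)  # |H(f)| over n_fft // 2 + 1 bins
ir = amp_to_impulse_response(amp, target_size=128)
assert ir.shape == (128,)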
Example #7
def conv1d(signal, kernel, mode='fft_circular', cut=False, cut_lim=150):
    """
    signal M x N
    kernel N
    """
    kernel_size = int(kernel.shape[-1])

    if mode == 'direct':
        conved = F.conv1d(signal.unsqueeze(1),
                          kernel.flip(0).unsqueeze(0).unsqueeze(0),
                          padding=kernel_size - 1)[:, 0]

    elif mode == 'fft_circular':
        conved = irfft(rfft(signal) * rfft(kernel), signal.shape[-1])

    else:
        raise ValueError(f'Unknown mode: {mode}')

    if cut:
        conved = conved[:, cut_lim:kernel_size + cut_lim]

    return conved
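In 'fft_circular' mode the result is a circular convolution, which a slow reference via rolled copies confirms (assuming `rfft`/`irfft` from `torch.fft` and `F` as `torch.nn.functional`):

import torch

sig, ker = torch.randn(4, 32), torch.randn(32)
out = conv1d(sig, ker)
ref = sum(ker[j] * sig.roll(j, -1) for j in range(32))
assert torch.allclose(out, ref, atol=1e-4)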
Example #8
def idct(x, dim=-1):
    """
    Inverse discrete cosine transform of type II, scaled to be orthonormal.

    This is the inverse of :func:`dct_ii` , and is equivalent to
    :func:`scipy.fftpack.idct` with ``norm="ortho"``.

    :param Tensor x: The input signal.
    :param int dim: Dimension along which to compute DCT.
    :rtype: Tensor
    """
    if dim >= 0:
        dim -= x.dim()
    if dim != -1:
        y = x.reshape(x.shape[:dim + 1] + (-1, )).transpose(-1, -2)
        return idct(y).transpose(-1, -2).reshape(x.shape)

    N = x.size(-1)
    scale = torch.cat([
        x.new_tensor([math.sqrt(N)]),
        x.new_full((N - 1, ), math.sqrt(0.5 * N))
    ])
    x = x * scale
    # Step 1, solve X = cos(k) * Yr + sin(k) * Yi
    # We know that Y[1:] is conjugate to Y[:0:-1], hence
    # X[:0:-1] = sin(k) * Yr[1:] + cos(k) * Yi[1:]
    # So Yr[1:] = cos(k) * X[1:] + sin(k) * X[:0:-1]
    # and Yi[1:] = sin(k) * X[1:] - cos(k) * X[:0:-1]
    # In addition, Yi[0] = 0, Yr[0] = X[0]
    # In other words, Y = complex_mul(e^ik, X - i[0, X[:0:-1]])
    M = N // 2 + 1  # half size
    xi = torch.nn.functional.pad(-x[..., N - M + 1:], (0, 1)).flip(-1)
    X = torch.stack([x[..., :M], xi], dim=-1)
    coef_real = torch.cos(
        torch.linspace(0, 0.5 * math.pi, N + 1, dtype=x.dtype,
                       device=x.device))
    coef = torch.stack([coef_real[:M], coef_real[-M:].flip(-1)], dim=-1)
    Y = as_complex(coef) * as_complex(X)
    # Step 2
    y = irfft(Y, n=N)
    # Step 3
    return torch.stack([y, y.flip(-1)],
                       dim=-1).reshape(x.shape[:-1] + (-1, ))[..., :N]
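`as_complex` is the surrounding module's helper for assembling complex tensors from stacked real/imaginary parts. A spot-check against scipy, which the docstring names as the reference behavior:

import scipy.fftpack
import torch

x = torch.randn(7, 8, dtype=torch.float64)
ref = torch.from_numpy(scipy.fftpack.idct(x.numpy(), norm='ortho'))
assert torch.allclose(idct(x), ref, atol=1e-9)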
Example #9
def autocorrelation(input, dim=0):
    """
    Computes the autocorrelation of samples at dimension ``dim``.

    Reference: https://en.wikipedia.org/wiki/Autocorrelation#Efficient_computation

    :param torch.Tensor input: the input tensor.
    :param int dim: the dimension to calculate autocorrelation.
    :returns torch.Tensor: autocorrelation of ``input``.
    """
    if (not input.is_cuda) and (not torch.backends.mkl.is_available()):
        raise NotImplementedError(
            "For CPU tensor, this method is only supported "
            "with MKL installed.")

    # Adapted from Stan implementation
    # https://github.com/stan-dev/math/blob/develop/stan/math/prim/mat/fun/autocorrelation.hpp
    N = input.size(dim)
    M = next_fast_len(N)
    M2 = 2 * M

    # transpose dim with -1 for Fourier transform
    input = input.transpose(dim, -1)

    # centering and padding x
    centered_signal = input - input.mean(dim=-1, keepdim=True)

    # Fourier transform
    freqvec = torch.view_as_real(rfft(centered_signal, n=M2))
    # take square of magnitude of freqvec (or freqvec x freqvec*)
    freqvec_gram = freqvec.pow(2).sum(-1)
    # inverse Fourier transform
    autocorr = irfft(freqvec_gram, n=M2)

    # truncate and normalize the result, then transpose back to original shape
    autocorr = autocorr[..., :N]
    autocorr = autocorr / torch.arange(N, 0, -1, dtype=input.dtype,
                                       device=input.device)
    autocorr = autocorr / autocorr[..., :1]
    return autocorr.transpose(dim, -1)
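The result is normalized to 1 at lag zero, and for a smooth signal adjacent lags stay close to 1 (on CPU this requires MKL, as the guard above enforces):

import torch

x = torch.sin(torch.arange(1000.) * 0.01)
ac = autocorrelation(x)
assert torch.isclose(ac[0], torch.tensor(1.0))
assert ac[1] > 0.99  # neighboring samples are highly correlated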
Example #10
def convolve(signal, kernel, mode='full'):
    """
    Computes the 1-d convolution of signal by kernel using FFTs.
    The two arguments should have the same rightmost dim, but may otherwise be
    arbitrarily broadcastable.

    :param torch.Tensor signal: A signal to convolve.
    :param torch.Tensor kernel: A convolution kernel.
    :param str mode: One of: 'full', 'valid', 'same'.
    :return: A tensor with broadcasted shape. Letting ``m = signal.size(-1)``
        and ``n = kernel.size(-1)``, the rightmost size of the result will be:
        ``m + n - 1`` if mode is 'full';
        ``max(m, n) - min(m, n) + 1`` if mode is 'valid'; or
        ``max(m, n)`` if mode is 'same'.
    :rtype torch.Tensor:
    """
    m = signal.size(-1)
    n = kernel.size(-1)
    if mode == 'full':
        truncate = m + n - 1
    elif mode == 'valid':
        truncate = max(m, n) - min(m, n) + 1
    elif mode == 'same':
        truncate = max(m, n)
    else:
        raise ValueError('Unknown mode: {}'.format(mode))

    # Compute convolution using fft.
    padded_size = m + n - 1
    # Round up for cheaper fft.
    fast_fft_size = next_fast_len(padded_size)
    f_signal = rfft(signal, n=fast_fft_size)
    f_kernel = rfft(kernel, n=fast_fft_size)
    f_result = f_signal * f_kernel
    result = irfft(f_result, n=fast_fft_size)

    start_idx = (padded_size - truncate) // 2
    return result[..., start_idx:start_idx + truncate]
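Shape behavior for the three modes, with broadcasting across a batch dimension (same assumed helpers as above):

import torch

s = torch.randn(5, 100)
k = torch.randn(1, 10)
assert convolve(s, k, mode='full').shape == (5, 109)   # m + n - 1
assert convolve(s, k, mode='same').shape == (5, 100)   # max(m, n)
assert convolve(s, k, mode='valid').shape == (5, 91)   # max(m, n) - min(m, n) + 1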
Example #11
def autocorrelation(input, dim=0):
    """
    Computes the autocorrelation of samples at dimension ``dim``.

    Reference: https://en.wikipedia.org/wiki/Autocorrelation#Efficient_computation

    Implementation copied from `pyro <https://github.com/pyro-ppl/pyro/blob/dev/pyro/ops/stats.py>`_.

    :param torch.Tensor input: the input tensor.
    :param int dim: the dimension to calculate autocorrelation.
    :returns torch.Tensor: autocorrelation of ``input``.
    """
    # Adapted from Stan implementation
    # https://github.com/stan-dev/math/blob/develop/stan/math/prim/mat/fun/autocorrelation.hpp
    N = input.size(dim)
    M = next_fast_len(N)
    M2 = 2 * M

    # transpose dim with -1 for Fourier transform
    input = input.transpose(dim, -1)

    # centering and padding x
    centered_signal = input - input.mean(dim=-1, keepdim=True)

    # Fourier transform
    freqvec = torch.view_as_real(rfft(centered_signal, n=M2))
    # take square of magnitude of freqvec (or freqvec x freqvec*)
    freqvec_gram = freqvec.pow(2).sum(-1)
    # inverse Fourier transform
    autocorr = irfft(freqvec_gram, n=M2)

    # truncate and normalize the result, then transpose back to original shape
    autocorr = autocorr[..., :N]
    autocorr = autocorr / torch.arange(N, 0, -1, dtype=input.dtype,
                                       device=input.device)
    autocorr = autocorr / autocorr[..., :1]
    return autocorr.transpose(dim, -1)
Example #12
                            onesided=True)
        if norm == 'forward':
            output /= float(n)

        # Make complex and move back dimension to its original position
        if _torch_has_complex:
            output = torch.view_as_complex(output)
            output = utils.movedim(output, -1, dim)
        else:
            output = utils.movedim(output, -2, dim if dim > 0 else dim - 1)

        return output


if _torch_has_fft_module:
    irfft = lambda *a, real=None, **k: fft_mod.irfft(*a, **k)
else:

    def irfft(input, n=None, dim=-1, norm='backward', real=None):
        """One dimensional complex-to-real inverse Fourier transform.

        The input is interpreted as a one-sided Hermitian signal in the
        Fourier domain, as produced by rfft(). By the Hermitian property,
        the output will be real-valued.

        Notes
        -----
        .. The correct interpretation of the Hermitian input depends on the
           length of the original data, as given by `n`. This is because each
           input shape could correspond to either an odd or even length signal.
           By default, the signal is assumed to be even length and odd signals
Example #13
def _fft_convnd(input: Tensor, weight: Tensor, bias: Optional[Tensor],
                stride: Tuple[int], padding: Tuple[int], dilation: Tuple[int],
                groups: int) -> Tensor:

    output_size = _conv_shape(input.shape[2:], weight.shape[2:], stride,
                              padding, dilation)
    reversed_padding_repeated_twice = _reverse_repeat_tuple(padding, 2)
    padded_input = F.pad(input, reversed_padding_repeated_twice)

    s: List[int] = []
    weight_s: List[int] = []
    for i, (x_size, w_size, d, st) in enumerate(
            zip(padded_input.shape[2:], weight.shape[2:], dilation, stride)):
        s_size = max(x_size, w_size * d)

        # find s size that can be divided by stride and dilation
        rfft_even = 2 if i == len(stride) - 1 else 1
        factor = _lcm(st * rfft_even, d * rfft_even)

        offset = s_size % factor
        if offset:
            s_size += factor - offset
        s.append(s_size)
        weight_s.append(s_size // d)

    X = rfftn(padded_input, s=s)

    W = rfft(weight, n=weight_s[-1])
    # handle dilation for the last dim
    if dilation[-1] > 1:
        W_neg_freq = W.flip(-1)[..., 1:]
        W_neg_freq.imag.mul_(-1)

        tmp = [W]
        for i in range(1, dilation[-1]):
            if i % 2:
                tmp.append(W_neg_freq)
            else:
                tmp.append(W[..., 1:])

        W = torch.cat(tmp, -1)

    if len(weight_s) > 1:
        W = fftn(W, s=weight_s[:-1], dim=tuple(range(2, W.ndim - 1)))
        repeats = (1, 1) + dilation[:-1] + (1, )
        W.imag.mul_(-1)
        if sum(repeats) > W.ndim:
            W = W.repeat(*repeats)
    else:
        W.imag.mul_(-1)

    Y = _complex_matmul(X, W, groups)

    # handle stride
    if len(stride) > 1:
        for i, st in enumerate(stride[:-1]):
            if st > 1:
                Y = Y.reshape(*Y.shape[:i + 2], st, -1,
                              *Y.shape[i + 3:]).mean(i + 2)

            Y = ifft(Y, dim=i + 2)
            Y = Y.as_strided(
                Y.shape[:i + 2] + output_size[i:i + 1] + Y.shape[i + 3:],
                Y.stride())

    if stride[-1] > 1:
        n_fft = Y.size(-1) * 2 - 2
        new_n_fft = n_fft // stride[-1]
        step_size = new_n_fft // 2
        strided_Y_size = step_size + 1

        unfolded_Y_real = Y.real.unfold(-1, strided_Y_size, step_size)
        unfolded_Y_imag = Y.imag[..., 1:].unfold(-1, strided_Y_size - 2,
                                                 step_size)
        Y_pos_real, Y_pos_imag = unfolded_Y_real[..., ::2, :].sum(
            -2), unfolded_Y_imag[..., ::2, :].sum(-2)
        Y_neg_real, Y_neg_imag = unfolded_Y_real[..., 1::2, :].sum(-2).flip(
            -1), unfolded_Y_imag[..., 1::2, :].sum(-2).flip(-1)

        Y_real = Y_pos_real.add_(Y_neg_real)
        Y_imag = Y_pos_imag.add_(Y_neg_imag, alpha=-1)
        Y_imag = F.pad(Y_imag, [1, 1])

        Y = torch.view_as_complex(torch.stack((Y_real, Y_imag),
                                              -1)).div_(stride[-1])

    output = irfft(Y)

    # Remove extra padded values
    output = output[..., :output_size[-1]].contiguous()

    # Optionally, add a bias term before returning.
    if bias is not None:
        output += bias[(slice(None), ) + (None, ) * (output.ndim - 2)]

    return output
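`_conv_shape`, `_reverse_repeat_tuple`, `_lcm`, and `_complex_matmul` are the module's private helpers, so the function is not runnable in isolation. As a self-contained sketch of the core idea (illustrative names; stride 1, dilation 1, one group), FFT-based 2-d 'valid' convolution can be checked against `F.conv2d`:

import torch
import torch.nn.functional as F
from torch.fft import rfftn, irfftn

def fft_conv2d_valid(x, w):
    # x: (B, C_in, H, W), w: (C_out, C_in, kH, kW)
    s = [x.size(-2), x.size(-1)]
    X = rfftn(x, s=s, dim=(-2, -1))
    W = rfftn(w, s=s, dim=(-2, -1)).conj()     # conj turns convolution into correlation
    Y = torch.einsum('bchw,ochw->bohw', X, W)  # contract over input channels
    y = irfftn(Y, s=s, dim=(-2, -1))
    kH, kW = w.shape[-2:]
    return y[..., :x.size(-2) - kH + 1, :x.size(-1) - kW + 1]

x = torch.randn(2, 3, 16, 16)
w = torch.randn(4, 3, 5, 5)
assert torch.allclose(fft_conv2d_valid(x, w), F.conv2d(x, w), atol=1e-3)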
Example #14
def _new_irfft(x: torch.Tensor, length: int):
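    # adapt the old stacked real/imag (..., 2) layout to the complex torch.fft API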
    x = torch.view_as_complex(x)
    return new_fft.irfft(x, length, dim=-1)
Example #15
def convolve1d(
    waveform,
    kernel,
    padding=0,
    pad_type="constant",
    stride=1,
    groups=1,
    use_fft=False,
    rotation_index=0,
):
    """Use torch.nn.functional to perform 1d padding and conv.

    Arguments
    ---------
    waveform : tensor
        The tensor to perform operations on.
    kernel : tensor
        The filter to apply during convolution
    padding : int or tuple
        The padding (pad_left, pad_right) to apply.
        If an integer is passed instead, this is passed
        to the conv1d function and pad_type is ignored.
    pad_type : str
        The type of padding to use. Passed directly to
        `torch.nn.functional.pad`, see PyTorch documentation
        for available options.
    stride : int
        The number of units to move each time convolution is applied.
        Passed to conv1d. Has no effect if `use_fft` is True.
    groups : int
        This option is passed to `conv1d` to split the input into groups for
        convolution. Input channels should be divisible by number of groups.
    use_fft : bool
        When `use_fft` is passed `True`, then compute the convolution in the
        spectral domain using complex multiply. This is more efficient on CPU
        when the size of the kernel is large (e.g. reverberation). WARNING:
        Without padding, circular convolution occurs. This makes little
        difference in the case of reverberation, but may make more difference
        with different kernels.
    rotation_index : int
        This option only applies if `use_fft` is true. If so, the kernel is
        rolled by this amount before convolution to shift the output location.

    Returns
    -------
    The convolved waveform.

    Example
    -------
    >>> from speechbrain.dataio.dataio import read_audio
    >>> signal = read_audio('samples/audio_samples/example1.wav')
    >>> signal = signal.unsqueeze(0).unsqueeze(2)
    >>> kernel = torch.rand(1, 10, 1)
    >>> signal = convolve1d(signal, kernel, padding=(9, 0))
    """
    if len(waveform.shape) != 3:
        raise ValueError("Convolve1D expects a 3-dimensional tensor")

    # Move time dimension last, which pad and fft and conv expect.
    waveform = waveform.transpose(2, 1)
    kernel = kernel.transpose(2, 1)

    # Padding can be a tuple (left_pad, right_pad) or an int
    if isinstance(padding, tuple):
        waveform = torch.nn.functional.pad(
            input=waveform,
            pad=padding,
            mode=pad_type,
        )

    # This approach uses FFT, which is more efficient if the kernel is large
    if use_fft:

        # Pad kernel to same length as signal, ensuring correct alignment
        zero_length = waveform.size(-1) - kernel.size(-1)

        # Handle case where signal is shorter
        if zero_length < 0:
            kernel = kernel[..., :zero_length]
            zero_length = 0

        # Perform rotation to ensure alignment
        zeros = torch.zeros(kernel.size(0),
                            kernel.size(1),
                            zero_length,
                            device=kernel.device)
        after_index = kernel[..., rotation_index:]
        before_index = kernel[..., :rotation_index]
        kernel = torch.cat((after_index, zeros, before_index), dim=-1)

        # Multiply in frequency domain to convolve in time domain
        if version.parse(torch.__version__) > version.parse("1.6.0"):
            import torch.fft as fft

            result = fft.rfft(waveform) * fft.rfft(kernel)
            convolved = fft.irfft(result, n=waveform.size(-1))
        else:
            f_signal = torch.rfft(waveform, 1)
            f_kernel = torch.rfft(kernel, 1)
            sig_real, sig_imag = f_signal.unbind(-1)
            ker_real, ker_imag = f_kernel.unbind(-1)
            f_result = torch.stack(
                [
                    sig_real * ker_real - sig_imag * ker_imag,
                    sig_real * ker_imag + sig_imag * ker_real,
                ],
                dim=-1,
            )
            convolved = torch.irfft(f_result,
                                    1,
                                    signal_sizes=[waveform.size(-1)])

    # Use the implementation given by torch, which should be efficient on GPU
    else:
        convolved = torch.nn.functional.conv1d(
            input=waveform,
            weight=kernel,
            stride=stride,
            groups=groups,
            padding=padding if not isinstance(padding, tuple) else 0,
        )

    # Return time dimension to the second dimension.
    return convolved.transpose(2, 1)
Example #16
def irfft(input, signal_ndim, normalized=False, signal_sizes=None):
    norm = "ortho" if normalized else "backward"
    n = None if signal_sizes is None else signal_sizes[0]
    return fft.irfft(input, n=n, dim=-1, norm=norm)
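Round-trip with the old-style signature (`signal_ndim` is accepted for compatibility but only a 1-d transform is performed; `signal_sizes` pins the output length, which matters for odd lengths):

import torch
from torch import fft

x = torch.randn(3, 17)  # odd length: signal_sizes is required to recover it
X = fft.rfft(x)
assert torch.allclose(irfft(X, 1, signal_sizes=[17]), x, atol=1e-6)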