Example #1
def __call__(self, arg):
    """Apply the filter to a (possibly batched) signal."""
    N = arg.shape[-1]
    if N not in self._cached:
        # Build and cache the operator for this signal size,
        # reporting the cache miss in strict mode.
        if self.strict:
            print(f"caching sparse operator for size {N}")
        self.cache(N)
    # Read the spectral operator and window from the cache.
    d, w = self._cached[N]
    if d.dim() == arg.dim():
        F_arg = rfft(w * arg)
        return irfft(d * F_arg)
    # Batched input: transform along dim 1.
    F_arg = rfft(w * arg, dim=1)
    return irfft(d * F_arg, dim=1)
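The pattern above precomputes, per signal length, a frequency-domain diagonal d and a window w, then filters by pointwise multiplication between rfft and irfft. A minimal standalone sketch of that pattern (the helper name and the identity-filter check are illustrative, not part of the original class):

import torch
from torch.fft import rfft, irfft

def apply_freq_filter(x: torch.Tensor, d: torch.Tensor) -> torch.Tensor:
    # x: real signal of shape (..., N); d: frequency response of shape
    # (N // 2 + 1,), broadcast across leading batch dims.
    return irfft(d * rfft(x), n=x.shape[-1])

x = torch.randn(4, 256)
d = torch.ones(129)  # all-pass response: output equals input
assert torch.allclose(apply_freq_filter(x, d), x, atol=1e-5)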
Example #2
def convolve1d(signal: torch.Tensor, kernel: torch.Tensor) -> torch.Tensor:
    """
    Computes the 1-d convolution of ``signal`` with ``kernel`` using FFTs.
    Both ``signal`` and ``kernel`` must be 1-dimensional.

    :param torch.Tensor signal: A signal to convolve.
    :param torch.Tensor kernel: A convolution kernel.
    :return: torch.Tensor. The full convolution of signal with kernel; the
        output tensor has size m + n - 1, where m is the length of the
        signal and n is the length of the kernel.
    """
    assert (signal.ndim == 1
            and kernel.ndim == 1), "signal and kernel must be 1-dimensional"
    m = signal.size(-1)
    n = kernel.size(-1)

    # Compute the linear convolution via FFT: zero-padding to the full
    # output length makes the circular convolution linear.
    padded_size = m + n - 1
    # Round up to a composite size for a cheaper FFT.
    fast_fft_size = next_fast_len(padded_size)
    f_signal = rfft(signal, n=fast_fft_size)
    f_kernel = rfft(kernel, n=fast_fft_size)
    f_result = f_signal * f_kernel
    result = irfft(f_result, n=fast_fft_size)

    return result[:padded_size]
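A quick check that the result matches a direct full convolution, assuming next_fast_len and the torch.fft functions are imported as in the surrounding module:

import torch

signal = torch.tensor([1., 2., 3.])
kernel = torch.tensor([1., 1.])
# Direct full convolution of these inputs is [1, 3, 5, 3] (length 3 + 2 - 1).
out = convolve1d(signal, kernel)
assert torch.allclose(out, torch.tensor([1., 3., 5., 3.]), atol=1e-5)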
Example #3
def _istft(x,
           n_fft,
           ola_weight,
           win_length,
           window,
           hop_length,
           center,
           normalized,
           onesided,
           pad_mode,
           return_complex,
           norm_envelope=None):
    """
    A helper function to do istft.
    """
    if onesided:
        x = fft.irfft(x,
                      n=n_fft,
                      dim=-2,
                      norm='ortho' if normalized else 'backward')
    else:
        x = fft.ifft(x,
                     n=n_fft,
                     dim=-2,
                     norm='ortho' if normalized else 'backward').real

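    # Overlap-add the frames back into a time-domain waveform.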
    x, norm_envelope = _ola(x,
                            hop_length,
                            ola_weight,
                            padding=n_fft // 2 if center else 0,
                            norm_envelope=norm_envelope)
    return x, norm_envelope
Example #4
def fft_convolve(signal, kernel):
    # Zero-pad the signal on the right and the kernel on the left (both
    # are assumed to share the same trailing length) so the circular FFT
    # convolution below becomes linear; keep the second half, which holds
    # the causally aligned output.
    signal = nn.functional.pad(signal, (0, signal.shape[-1]))
    kernel = nn.functional.pad(kernel, (kernel.shape[-1], 0))

    output = fft.irfft(fft.rfft(signal) * fft.rfft(kernel))
    output = output[..., output.shape[-1] // 2:]

    return output
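A shape-only sanity check under the snippet's implicit assumption that signal and kernel share the same trailing length:

import torch
import torch.nn as nn
from torch import fft

signal = torch.randn(1, 1024)
kernel = torch.randn(1, 1024)
out = fft_convolve(signal, kernel)
print(out.shape)  # torch.Size([1, 1024])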
Example #5
    def __call__(self, image: torch.Tensor,
                 context: ExpressionContext) -> torch.Tensor:
        # Per-channel noise scale, shaped (3, 1) for broadcasting.
        std = 15. * torch.Tensor(context(self.std)).to(image.device).reshape(
            3, 1)

        # Perturb the real and imaginary spectra of each flattened
        # channel with Gaussian noise.
        space = fft.rfft(image.reshape(3, -1))
        space.real = space.real + torch.randn(space.shape).to(
            image.device) * std
        space.imag = space.imag + torch.randn(space.shape).to(
            image.device) * std

        # Pass n explicitly so odd-length channels round-trip to the
        # original number of elements.
        return fft.irfft(space, n=image[0].numel()).reshape(*image.shape)
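With no noise added, the rfft/irfft pair is an identity, which makes the explicit n= easy to verify, including for odd channel lengths (shapes here are illustrative):

import torch
from torch import fft

image = torch.rand(3, 7, 9)  # 63 elements per channel (odd)
space = fft.rfft(image.reshape(3, -1))
restored = fft.irfft(space, n=image[0].numel()).reshape(*image.shape)
assert torch.allclose(restored, image, atol=1e-5)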
Example #6
def amp_to_impulse_response(amp, target_size):
    # Interpret ``amp`` as a real, zero-phase magnitude response and
    # convert it to a time-domain impulse response.
    amp = torch.stack([amp, torch.zeros_like(amp)], -1)
    amp = torch.view_as_complex(amp)
    amp = fft.irfft(amp)

    filter_size = amp.shape[-1]

    # Center the impulse response, then apply a Hann window to reduce
    # spectral leakage.
    amp = torch.roll(amp, filter_size // 2, -1)
    win = torch.hann_window(filter_size, dtype=amp.dtype, device=amp.device)

    amp = amp * win

    # Zero-pad to the target size and undo the centering shift.
    amp = nn.functional.pad(amp, (0, int(target_size) - int(filter_size)))
    amp = torch.roll(amp, -filter_size // 2, -1)

    return amp
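For instance, a flat one-sided magnitude response of 65 bins produces a 128-sample windowed impulse response, zero-padded here to 256 samples:

import torch
import torch.nn as nn
from torch import fft

amp = torch.ones(1, 65)  # rfft bins of a size-128 filter
ir = amp_to_impulse_response(amp, 256)
print(ir.shape)  # torch.Size([1, 256])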
Example #7
def conv1d(signal, kernel, mode='fft_circular', cut=False, cut_lim=150):
    """
    signal: (M, N)
    kernel: (N,)
    """
    kernel_size = int(kernel.shape[-1])

    if mode == 'direct':
        # Time-domain convolution; flip the kernel because F.conv1d
        # actually computes cross-correlation.
        conved = F.conv1d(signal.unsqueeze(1),
                          kernel.flip(0).unsqueeze(0).unsqueeze(0),
                          padding=kernel_size - 1)[:, 0]

    elif mode == 'fft_circular':
        # Circular convolution via pointwise multiplication of spectra.
        conved = irfft(rfft(signal) * rfft(kernel), signal.shape[-1])

    else:
        raise ValueError(f"unknown mode: {mode!r}")

    if cut:
        conved = conved[:, cut_lim:kernel_size + cut_lim]

    return conved
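A minimal call in the circular mode, assuming rfft/irfft come from torch.fft and F is torch.nn.functional:

import torch

signal = torch.randn(4, 256)  # (M, N) batch
kernel = torch.randn(256)     # length-N kernel
out = conv1d(signal, kernel)  # mode='fft_circular'
print(out.shape)              # torch.Size([4, 256])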
Example #8
def idct(x, dim=-1):
    """
    Inverse discrete cosine transform of type II, scaled to be orthonormal.

    This is the inverse of :func:`dct_ii` , and is equivalent to
    :func:`scipy.fftpack.idct` with ``norm="ortho"``.

    :param Tensor x: The input signal.
    :param int dim: Dimension along which to compute DCT.
    :rtype: Tensor
    """
    if dim >= 0:
        dim -= x.dim()
    if dim != -1:
        y = x.reshape(x.shape[:dim + 1] + (-1, )).transpose(-1, -2)
        return idct(y).transpose(-1, -2).reshape(x.shape)

    N = x.size(-1)
    scale = torch.cat([
        x.new_tensor([math.sqrt(N)]),
        x.new_full((N - 1, ), math.sqrt(0.5 * N))
    ])
    x = x * scale
    # Step 1, solve X = cos(k) * Yr + sin(k) * Yi
    # We know that Y[1:] is conjugate to Y[:0:-1], hence
    # X[:0:-1] = sin(k) * Yr[1:] + cos(k) * Yi[1:]
    # So Yr[1:] = cos(k) * X[1:] + sin(k) * X[:0:-1]
    # and Yi[1:] = sin(k) * X[1:] - cos(k) * X[:0:-1]
    # In addition, Yi[0] = 0, Yr[0] = X[0]
    # In other words, Y = complex_mul(e^ik, X - i[0, X[:0:-1]])
    M = N // 2 + 1  # half size
    xi = torch.nn.functional.pad(-x[..., N - M + 1:], (0, 1)).flip(-1)
    X = torch.stack([x[..., :M], xi], dim=-1)
    coef_real = torch.cos(
        torch.linspace(0, 0.5 * math.pi, N + 1, dtype=x.dtype,
                       device=x.device))
    coef = torch.stack([coef_real[:M], coef_real[-M:].flip(-1)], dim=-1)
    Y = as_complex(coef) * as_complex(X)
    # Step 2: invert the length-N one-sided spectrum.
    y = irfft(Y, n=N)
    # Step 3: interleave y with its reverse to undo the even/odd index
    # permutation used by the FFT-based DCT.
    return torch.stack([y, y.flip(-1)],
                       dim=-1).reshape(x.shape[:-1] + (-1, ))[..., :N]
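Assuming idct and its helpers (as_complex, irfft) are in scope, a round trip against SciPy's orthonormal DCT-II checks the equivalence stated in the docstring:

import torch
from scipy.fftpack import dct

x = torch.randn(8, dtype=torch.float64)
X = torch.from_numpy(dct(x.numpy(), norm="ortho"))
assert torch.allclose(idct(X), x, atol=1e-10)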
Example #9
def autocorrelation(input, dim=0):
    """
    Computes the autocorrelation of samples along dimension ``dim``.

    Reference: https://en.wikipedia.org/wiki/Autocorrelation#Efficient_computation

    :param torch.Tensor input: the input tensor.
    :param int dim: the dimension to calculate autocorrelation.
    :returns torch.Tensor: autocorrelation of ``input``.
    """
    if (not input.is_cuda) and (not torch.backends.mkl.is_available()):
        raise NotImplementedError(
            "For CPU tensor, this method is only supported "
            "with MKL installed.")

    # Adapted from Stan implementation
    # https://github.com/stan-dev/math/blob/develop/stan/math/prim/mat/fun/autocorrelation.hpp
    N = input.size(dim)
    M = next_fast_len(N)
    M2 = 2 * M

    # transpose dim with -1 for Fourier transform
    input = input.transpose(dim, -1)

    # centering and padding x
    centered_signal = input - input.mean(dim=-1, keepdim=True)

    # Fourier transform
    freqvec = torch.view_as_real(rfft(centered_signal, n=M2))
    # take square of magnitude of freqvec (or freqvec x freqvec*)
    freqvec_gram = freqvec.pow(2).sum(-1)
    # inverse Fourier transform
    autocorr = irfft(freqvec_gram, n=M2)

    # truncate and normalize the result, then transpose back to original shape
    autocorr = autocorr[..., :N]
    autocorr = autocorr / torch.arange(N, 0, -1, dtype=input.dtype,
                                       device=input.device)
    autocorr = autocorr / autocorr[..., :1]
    return autocorr.transpose(dim, -1)
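For white noise the normalized autocorrelation is exactly 1 at lag 0 and close to 0 at small lags, which gives a quick sanity check:

import torch

x = torch.randn(10000)
ac = autocorrelation(x)
assert abs(ac[0] - 1.0) < 1e-6      # lag 0 is 1 by normalization
assert ac[1:100].abs().max() < 0.1  # small-lag coefficients near zero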
Example #10
def convolve(signal, kernel, mode='full'):
    """
    Computes the 1-d convolution of signal by kernel using FFTs.
    The two arguments should have the same rightmost dim, but may otherwise be
    arbitrarily broadcastable.

    :param torch.Tensor signal: A signal to convolve.
    :param torch.Tensor kernel: A convolution kernel.
    :param str mode: One of: 'full', 'valid', 'same'.
    :return: A tensor with broadcasted shape. Letting ``m = signal.size(-1)``
        and ``n = kernel.size(-1)``, the rightmost size of the result will be:
        ``m + n - 1`` if mode is 'full';
        ``max(m, n) - min(m, n) + 1`` if mode is 'valid'; or
        ``max(m, n)`` if mode is 'same'.
    :rtype torch.Tensor:
    """
    m = signal.size(-1)
    n = kernel.size(-1)
    if mode == 'full':
        truncate = m + n - 1
    elif mode == 'valid':
        truncate = max(m, n) - min(m, n) + 1
    elif mode == 'same':
        truncate = max(m, n)
    else:
        raise ValueError('Unknown mode: {}'.format(mode))

    # Compute the convolution via FFT.
    padded_size = m + n - 1
    # Round up to a composite size for a cheaper FFT.
    fast_fft_size = next_fast_len(padded_size)
    f_signal = rfft(signal, n=fast_fft_size)
    f_kernel = rfft(kernel, n=fast_fft_size)
    f_result = f_signal * f_kernel
    result = irfft(f_result, n=fast_fft_size)

    # Truncate to the requested mode, centered within the full output.
    start_idx = (padded_size - truncate) // 2
    return result[..., start_idx:start_idx + truncate]
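The three modes on a short example, assuming rfft, irfft, and next_fast_len are in scope (values verified against a direct computation):

import torch

signal = torch.tensor([1., 2., 3., 4.])
kernel = torch.tensor([1., 1., 1.])
print(convolve(signal, kernel, mode='full'))   # [1., 3., 6., 9., 7., 4.]
print(convolve(signal, kernel, mode='same'))   # [3., 6., 9., 7.]
print(convolve(signal, kernel, mode='valid'))  # [6., 9.]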
Example #11
def autocorrelation(input, dim=0):
    """
    Computes the autocorrelation of samples along dimension ``dim``.

    Reference: https://en.wikipedia.org/wiki/Autocorrelation#Efficient_computation

    Implementation copied from `pyro <https://github.com/pyro-ppl/pyro/blob/dev/pyro/ops/stats.py>`_.

    :param torch.Tensor input: the input tensor.
    :param int dim: the dimension to calculate autocorrelation.
    :returns torch.Tensor: autocorrelation of ``input``.
    """
    # Adapted from Stan implementation
    # https://github.com/stan-dev/math/blob/develop/stan/math/prim/mat/fun/autocorrelation.hpp
    N = input.size(dim)
    M = next_fast_len(N)
    M2 = 2 * M

    # transpose dim with -1 for Fourier transform
    input = input.transpose(dim, -1)

    # centering and padding x
    centered_signal = input - input.mean(dim=-1, keepdim=True)

    # Fourier transform
    freqvec = torch.view_as_real(rfft(centered_signal, n=M2))
    # take square of magnitude of freqvec (or freqvec x freqvec*)
    freqvec_gram = freqvec.pow(2).sum(-1)
    # inverse Fourier transform
    autocorr = irfft(freqvec_gram, n=M2)

    # truncate and normalize the result, then transpose back to original shape
    autocorr = autocorr[..., :N]
    autocorr = autocorr / torch.arange(N, 0, -1, dtype=input.dtype,
                                       device=input.device)
    autocorr = autocorr / autocorr[..., :1]
    return autocorr.transpose(dim, -1)
Example #12
                            onesided=True)
        if norm == 'forward':
            output /= float(n)

        # Make complex and move back dimension to its original position
        if _torch_has_complex:
            output = torch.view_as_complex(output)
            output = utils.movedim(output, -1, dim)
        else:
            output = utils.movedim(output, -2, dim if dim > 0 else dim - 1)

        return output


if _torch_has_fft_module:
    irfft = lambda *a, real=None, **k: fft_mod.irfft(*a, **k)
else:

    def irfft(input, n=None, dim=-1, norm='backward', real=None):
        """One dimensional complex-to-real inverse Fourier transform.

        The input is interpreted as a one-sided Hermitian signal in the
        Fourier domain, as produced by rfft(). By the Hermitian property,
        the output will be real-valued.

        Notes
        -----
        .. The correct interpretation of the Hermitian input depends on the
           length of the original data, as given by `n`. This is because each
           input shape could correspond to either an odd or even length signal.
           By default, the signal is assumed to be even length and odd signals
Example #13
def _fft_convnd(input: Tensor, weight: Tensor, bias: Optional[Tensor],
                stride: Tuple[int], padding: Tuple[int], dilation: Tuple[int],
                groups: int) -> Tensor:

    output_size = _conv_shape(input.shape[2:], weight.shape[2:], stride,
                              padding, dilation)
    reversed_padding_repeated_twice = _reverse_repeat_tuple(padding, 2)
    padded_input = F.pad(input, reversed_padding_repeated_twice)

    s: List[int] = []
    weight_s: List[int] = []
    for i, (x_size, w_size, d, st) in enumerate(
            zip(padded_input.shape[2:], weight.shape[2:], dilation, stride)):
        s_size = max(x_size, w_size * d)

        # find s size that can be divided by stride and dilation
        rfft_even = 2 if i == len(stride) - 1 else 1
        factor = _lcm(st * rfft_even, d * rfft_even)

        offset = s_size % factor
        if offset:
            s_size += factor - offset
        s.append(s_size)
        weight_s.append(s_size // d)

    X = rfftn(padded_input, s=s)

    W = rfft(weight, n=weight_s[-1])
    # Handle dilation along the last dim.
    if dilation[-1] > 1:
        W_neg_freq = W.flip(-1)[..., 1:]
        W_neg_freq.imag.mul_(-1)

        tmp = [W]
        for i in range(1, dilation[-1]):
            if i % 2:
                tmp.append(W_neg_freq)
            else:
                tmp.append(W[..., 1:])

        W = torch.cat(tmp, -1)

    if len(weight_s) > 1:
        W = fftn(W, s=weight_s[:-1], dim=tuple(range(2, W.ndim - 1)))
        repeats = (1, 1) + dilation[:-1] + (1, )
        W.imag.mul_(-1)
        if sum(repeats) > W.ndim:
            W = W.repeat(*repeats)
    else:
        W.imag.mul_(-1)

    Y = _complex_matmul(X, W, groups)

    # handle stride
    if len(stride) > 1:
        for i, st in enumerate(stride[:-1]):
            if st > 1:
                Y = Y.reshape(*Y.shape[:i + 2], st, -1,
                              *Y.shape[i + 3:]).mean(i + 2)

            Y = ifft(Y, dim=i + 2)
            Y = Y.as_strided(
                Y.shape[:i + 2] + output_size[i:i + 1] + Y.shape[i + 3:],
                Y.stride())

    if stride[-1] > 1:
        n_fft = Y.size(-1) * 2 - 2
        new_n_fft = n_fft // stride[-1]
        step_size = new_n_fft // 2
        strided_Y_size = step_size + 1

        unfolded_Y_real = Y.real.unfold(-1, strided_Y_size, step_size)
        unfolded_Y_imag = Y.imag[..., 1:].unfold(-1, strided_Y_size - 2,
                                                 step_size)
        Y_pos_real, Y_pos_imag = unfolded_Y_real[..., ::2, :].sum(
            -2), unfolded_Y_imag[..., ::2, :].sum(-2)
        Y_neg_real, Y_neg_imag = unfolded_Y_real[..., 1::2, :].sum(-2).flip(
            -1), unfolded_Y_imag[..., 1::2, :].sum(-2).flip(-1)

        Y_real = Y_pos_real.add_(Y_neg_real)
        Y_imag = Y_pos_imag.add_(Y_neg_imag, alpha=-1)
        Y_imag = F.pad(Y_imag, [1, 1])

        Y = torch.view_as_complex(torch.stack((Y_real, Y_imag),
                                              -1)).div_(stride[-1])

    output = irfft(Y)

    # Remove extra padded values
    output = output[..., :output_size[-1]].contiguous()

    # Optionally, add a bias term before returning.
    if bias is not None:
        output += bias[(slice(None), ) + (None, ) * (output.ndim - 2)]

    return output
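Assuming the private helpers used above (_conv_shape, _lcm, _complex_matmul, _reverse_repeat_tuple) are in scope, the result should match F.conv1d; note the kernel spectrum is conjugated via W.imag.mul_(-1), so this computes cross-correlation, as PyTorch convolutions do:

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 32)
w = torch.randn(4, 3, 5)
ref = F.conv1d(x, w, padding=2)
out = _fft_convnd(x, w, None, stride=(1,), padding=(2,),
                  dilation=(1,), groups=1)
assert torch.allclose(out, ref, atol=1e-4)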
Example #14
def _new_irfft(x: torch.Tensor, length: int):
    # ``x`` holds (real, imag) pairs in its last dim, as the legacy
    # torch.rfft API produced; view it as complex for the new API.
    x = torch.view_as_complex(x)
    return new_fft.irfft(x, length, dim=-1)
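A round trip under the snippet's assumed layout, where new_fft is torch.fft and the input stores (real, imag) pairs in its last dim:

import torch
from torch import fft as new_fft

x = torch.randn(10)
pairs = torch.view_as_real(new_fft.rfft(x))  # legacy (..., n_freq, 2) layout
assert torch.allclose(_new_irfft(pairs, 10), x, atol=1e-6)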
Example #15
def convolve1d(
    waveform,
    kernel,
    padding=0,
    pad_type="constant",
    stride=1,
    groups=1,
    use_fft=False,
    rotation_index=0,
):
    """Use torch.nn.functional to perform 1d padding and conv.

    Arguments
    ---------
    waveform : tensor
        The tensor to perform operations on.
    kernel : tensor
        The filter to apply during convolution
    padding : int or tuple
        The padding (pad_left, pad_right) to apply.
        If an integer is passed instead, this is passed
        to the conv1d function and pad_type is ignored.
    pad_type : str
        The type of padding to use. Passed directly to
        `torch.nn.functional.pad`, see PyTorch documentation
        for available options.
    stride : int
        The number of units to move each time convolution is applied.
        Passed to conv1d. Has no effect if `use_fft` is True.
    groups : int
        This option is passed to `conv1d` to split the input into groups for
        convolution. Input channels should be divisible by number of groups.
    use_fft : bool
        When `use_fft` is passed `True`, then compute the convolution in the
        spectral domain using complex multiply. This is more efficient on CPU
        when the size of the kernel is large (e.g. reverberation). WARNING:
        Without padding, circular convolution occurs. This makes little
        difference in the case of reverberation, but can matter for
        other kernels.
    rotation_index : int
        This option only applies if `use_fft` is true. If so, the kernel is
        rolled by this amount before convolution to shift the output location.

    Returns
    -------
    The convolved waveform.

    Example
    -------
    >>> from speechbrain.dataio.dataio import read_audio
    >>> signal = read_audio('samples/audio_samples/example1.wav')
    >>> signal = signal.unsqueeze(0).unsqueeze(2)
    >>> kernel = torch.rand(1, 10, 1)
    >>> signal = convolve1d(signal, kernel, padding=(9, 0))
    """
    if len(waveform.shape) != 3:
        raise ValueError("Convolve1D expects a 3-dimensional tensor")

    # Move time dimension last, which pad and fft and conv expect.
    waveform = waveform.transpose(2, 1)
    kernel = kernel.transpose(2, 1)

    # Padding can be a tuple (left_pad, right_pad) or an int
    if isinstance(padding, tuple):
        waveform = torch.nn.functional.pad(
            input=waveform,
            pad=padding,
            mode=pad_type,
        )

    # This approach uses FFT, which is more efficient if the kernel is large
    if use_fft:

        # Pad kernel to same length as signal, ensuring correct alignment
        zero_length = waveform.size(-1) - kernel.size(-1)

        # Handle case where signal is shorter
        if zero_length < 0:
            kernel = kernel[..., :zero_length]
            zero_length = 0

        # Perform rotation to ensure alignment
        zeros = torch.zeros(kernel.size(0),
                            kernel.size(1),
                            zero_length,
                            device=kernel.device)
        after_index = kernel[..., rotation_index:]
        before_index = kernel[..., :rotation_index]
        kernel = torch.cat((after_index, zeros, before_index), dim=-1)

        # Multiply in frequency domain to convolve in time domain
        if version.parse(torch.__version__) > version.parse("1.6.0"):
            import torch.fft as fft

            result = fft.rfft(waveform) * fft.rfft(kernel)
            convolved = fft.irfft(result, n=waveform.size(-1))
        else:
            f_signal = torch.rfft(waveform, 1)
            f_kernel = torch.rfft(kernel, 1)
            sig_real, sig_imag = f_signal.unbind(-1)
            ker_real, ker_imag = f_kernel.unbind(-1)
            f_result = torch.stack(
                [
                    sig_real * ker_real - sig_imag * ker_imag,
                    sig_real * ker_imag + sig_imag * ker_real,
                ],
                dim=-1,
            )
            convolved = torch.irfft(f_result,
                                    1,
                                    signal_sizes=[waveform.size(-1)])

    # Use the implementation given by torch, which should be efficient on GPU
    else:
        convolved = torch.nn.functional.conv1d(
            input=waveform,
            weight=kernel,
            stride=stride,
            groups=groups,
            padding=padding if not isinstance(padding, tuple) else 0,
        )

    # Return time dimension to the second dimension.
    return convolved.transpose(2, 1)
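A quick identity check of the FFT path: a one-sample delta kernel with rotation_index=0 should leave the waveform unchanged:

import torch

signal = torch.randn(1, 100, 1)  # (batch, time, channels)
kernel = torch.zeros(1, 1, 1)
kernel[0, 0, 0] = 1.0            # delta kernel
out = convolve1d(signal, kernel, use_fft=True)
assert torch.allclose(out, signal, atol=1e-5)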
Example #16
def irfft(input, signal_ndim, normalized=False, signal_sizes=None):
    # Compatibility shim: map the old torch.irfft signature onto the
    # torch.fft module (only the 1-d, last-dim case is handled).
    norm = "ortho" if normalized else "backward"
    n = None if signal_sizes is None else signal_sizes[0]
    return fft.irfft(input, n=n, dim=-1, norm=norm)
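Usage mirrors the legacy torch.irfft call, assuming fft here is torch.fft:

import torch
from torch import fft

x = torch.randn(16)
X = fft.rfft(x)
# The old-style call irfft(X, 1, signal_sizes=[16]) maps to fft.irfft(X, n=16).
assert torch.allclose(irfft(X, 1, signal_sizes=[16]), x, atol=1e-6)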