def notch_filter(notch_freq, filter_width=101, notch_width=0.05):
    """Returns a notch filter constructed from a high-pass and low-pass filter.

    (from https://tomroelandts.com/articles/
    how-to-create-simple-band-pass-and-band-reject-filters)

    A windowed-sinc low-pass filter and a windowed-sinc high-pass filter
    (built by spectral inversion of a low-pass) are summed to obtain a
    band-reject kernel. Both halves are tapered with a symmetric Blackman
    window, so the kernel is symmetric about its centre tap (linear phase).
    Its coefficients sum to 1, i.e. the gain at DC is 1.

    Arguments
    ---------
    notch_freq : float
        frequency to put notch as a fraction of the
        sampling rate / 2. The range of possible inputs is 0 to 1.
        NOTE: to keep the low-pass cutoff strictly positive, the rejected
        band spans ``notch_freq`` to ``notch_freq + 2 * notch_width``
        (it is centred on ``notch_freq + notch_width``).
    filter_width : int
        Filter width in samples; must be odd. Longer filters have
        smaller transition bands, but are more inefficient.
    notch_width : float
        Width of the notch, as a fraction of the sampling_rate / 2.

    Returns
    -------
    torch.Tensor
        The filter kernel, shaped ``(1, filter_width, 1)``.

    Example
    -------
    >>> from speechbrain.dataio.dataio import read_audio
    >>> signal = read_audio('samples/audio_samples/example1.wav')
    >>> signal = signal.unsqueeze(0).unsqueeze(2)
    >>> kernel = notch_filter(0.25)
    >>> notched_signal = convolve1d(signal, kernel)
    """
    import math

    # Check inputs
    assert 0 < notch_freq <= 1
    assert filter_width % 2 != 0
    pad = filter_width // 2
    inputs = torch.arange(filter_width) - pad

    # Avoid frequencies that are too low: after this shift the low-pass
    # cutoff (notch_freq - notch_width) equals the caller's notch_freq > 0.
    notch_freq += notch_width

    # Define sinc function, avoiding division by zero
    def sinc(x):
        """sin(x) / x, with the centre tap (x == 0) set to its limit 1."""

        def _sinc(x):
            return torch.sin(x) / x

        # The zero is at the middle index
        return torch.cat([_sinc(x[:pad]), torch.ones(1), _sinc(x[pad + 1:])])

    # A symmetric window (periodic=False) keeps the FIR kernel symmetric;
    # the default periodic window is meant for spectral analysis, not filter
    # design. It is identical for both halves, so build it once.
    window = torch.blackman_window(filter_width, periodic=False)

    # Compute a low-pass filter with cutoff (notch_freq - notch_width).
    # For a cutoff f expressed as a fraction of Nyquist the ideal impulse
    # response is sin(pi * f * n) / (pi * n); the constant scale is removed
    # by the normalisation to unit DC gain below.
    hlpf = sinc(math.pi * (notch_freq - notch_width) * inputs)
    hlpf *= window
    hlpf /= torch.sum(hlpf)

    # Compute a high-pass filter with cutoff (notch_freq + notch_width) by
    # spectral inversion: negate a unit-DC-gain low-pass and add a delta.
    hhpf = sinc(math.pi * (notch_freq + notch_width) * inputs)
    hhpf *= window
    hhpf /= -torch.sum(hhpf)
    hhpf[pad] += 1

    # Adding filters creates notch filter
    return (hlpf + hhpf).view(1, -1, 1)
Example #2
0
def get_window(name, window_length, squared=False):
    """
    Returns a windowing function.

    Arguments:
    ----------
        name (str)                  : name of the window; one of 'hann',
                                      'hamming' or 'blackman'
        window_length (int)         : length of the window
        squared (bool)              : if true, square the window

    Returns:
    ----------
        torch.FloatTensor           : window of size `window_length`, built
                                      with the torch default (periodic) form

    Raises:
    ----------
        ValueError                  : if `name` is not a supported window
    """
    if name == "hann":
        window = torch.hann_window(window_length)
    elif name == "hamming":
        window = torch.hamming_window(window_length)
    elif name == "blackman":
        window = torch.blackman_window(window_length)
    else:
        raise ValueError("Invalid window name {}".format(name))
    if squared:
        # In-place is safe: `window` was freshly allocated above.
        window *= window
    return window
Example #3
0
def get_window(window_type: str,
               window_length_in_samp: int,
               device: Optional[torch.device] = None) -> torch.Tensor:
    """Build a symmetric (non-periodic) window of the requested type.

    The window is generated in float64 and then cast to float32; the extra
    precision is used to achieve parity with the
    ``scipy.signal.windows.get_window`` implementation.

    Args:
        window_type: one of "bartlett", "blackman", "hamming" or "hann".
        window_length_in_samp: number of samples in the window.
        device: device on which the window is created.

    Returns:
        A float32 tensor of length ``window_length_in_samp``.

    Raises:
        ValueError: if ``window_type`` is not one of the supported names.
    """
    # Ordered (name, factory) pairs; scanned with `==` rather than looked up
    # in a dict so that comparison semantics match a plain if/elif chain.
    window_factories = (
        ("bartlett", torch.bartlett_window),
        ("blackman", torch.blackman_window),
        ("hamming", torch.hamming_window),
        ("hann", torch.hann_window),
    )
    for candidate, make_window in window_factories:
        if window_type != candidate:
            continue
        high_precision = make_window(
            window_length_in_samp,
            periodic=False,
            dtype=torch.float64,
            device=device,
        )
        return high_precision.to(torch.float32)
    raise ValueError(f"Unknown window type: {window_type}")
Example #4
0
 def spectral_ops(self):
     """Call a set of spectral ops on small inputs and return their outputs.

     Returns, in order: ``torch.stft`` of a random 1-D signal, ``torch.istft``
     of a random spectrogram, and the bartlett, blackman, hamming, hann and
     kaiser windows (float32, default periodic form).

     PyTorch >= 2.0 requires ``return_complex`` to be given for real
     ``stft`` inputs and only accepts complex ``istft`` inputs, so the
     trailing size-2 dimension of ``b`` is reinterpreted as (real, imag)
     with ``torch.view_as_complex``.
     """
     a = torch.randn(10)
     b = torch.randn(10, 8, 4, 2)
     return (
         torch.stft(a, 8, return_complex=True),
         # (10, 8, 4, 2) real -> (10, 8, 4) complex: batch, frequency bins,
         # frames. 8 bins == n_fft, so istft infers a two-sided spectrum.
         torch.istft(torch.view_as_complex(b), 8),
         torch.bartlett_window(2, dtype=torch.float),
         torch.blackman_window(2, dtype=torch.float),
         torch.hamming_window(4, dtype=torch.float),
         torch.hann_window(4, dtype=torch.float),
         torch.kaiser_window(4, dtype=torch.float),
     )
Example #5
0
def dereverb_wpe_torch(
    audio: torch.Tensor,
    n_fft: int = 512,
    hop_length: int = 128,
    taps: int = 10,
    delay: int = 3,
    iterations: int = 3,
    statistics_mode: str = "full",
) -> torch.Tensor:
    """Dereverberate ``audio`` with WPE using ``nara_wpe.torch_wpe.wpe_v6``.

    The signal is transformed with ``torch.stft`` (Blackman window), passed
    through ``wpe_v6`` and resynthesised with ``torch.istft``.

    :param audio: 2-D tensor. ``torch.stft`` treats dim 0 as a batch
        dimension, so this is presumably ``(channels, samples)`` — confirm.
    :param n_fft: FFT size; also the window length.
    :param hop_length: STFT hop size in samples.
    :param taps: forwarded to ``wpe_v6``.
    :param delay: forwarded to ``wpe_v6``.
    :param iterations: forwarded to ``wpe_v6``.
    :param statistics_mode: forwarded to ``wpe_v6``.
    :return: the dereverberated signal, with the same shape as ``audio``.
    :raises ImportError: if ``nara_wpe`` is not installed.
    """
    if not is_module_available("nara_wpe"):
        raise ImportError(
            "Please install nara_wpe first using 'pip install git+https://github.com/fgnt/nara_wpe' "
            "(at the time of writing, only GitHub version has a PyTorch implementation)."
        )

    from nara_wpe.torch_wpe import wpe_v6

    assert audio.ndim == 2

    # Create the window next to the signal: stft/istft reject a window that
    # lives on a different device (e.g. CPU window with CUDA audio).
    window = torch.blackman_window(n_fft, dtype=audio.dtype, device=audio.device)
    Y = torch.stft(
        audio,
        n_fft=n_fft,
        hop_length=hop_length,
        return_complex=True,
        window=window,
    )
    # stft output is (batch, freq, frames); move frequency to the front for
    # wpe_v6 (presumably it expects (freq, channels, frames) — confirm).
    Y = Y.permute(1, 0, 2)
    Z = wpe_v6(
        Y,
        taps=taps,
        delay=delay,
        iterations=iterations,
        statistics_mode=statistics_mode,
    )
    # Without `length`, istft drops the trailing samples whenever the input
    # length is not a multiple of hop_length.
    z = torch.istft(
        Z.permute(1, 0, 2),
        n_fft=n_fft,
        hop_length=hop_length,
        window=window,
        length=audio.shape[-1],
    )
    return z