Beispiel #1
0
def test_pmsqe_pit(n_src, sample_rate):
    """Smoke-test SingleSrcPMSQE wrapped in PITLossWrapper (``pit_from="pw_pt"``).

    Magnitude spectrograms of random ``n_src``-source signals are computed
    with an STFT whose size depends on ``sample_rate``; the test only checks
    that the wrapped loss's forward pass runs.
    """
    # 512-point STFT with hop 256 at 16 kHz, 256-point with hop 128 otherwise.
    n_fft = 512 if sample_rate == 16000 else 256
    stft = Encoder(
        STFTFB(kernel_size=n_fft, n_filters=n_fft, stride=n_fft // 2))

    reference = torch.randn(2, n_src, 16000)
    estimate = torch.randn(2, n_src, 16000)
    reference_mag = transforms.mag(stft(reference))
    estimate_mag = transforms.mag(stft(estimate))

    pit_pmsqe = PITLossWrapper(
        SingleSrcPMSQE(sample_rate=sample_rate), pit_from="pw_pt")
    # Only the absence of an exception is checked.
    pit_pmsqe(estimate_mag, reference_mag)
def test_pinv_of(fb_class):
    """Check ``pinv_of`` yields identical filters from a wrapper or a filterbank.

    The pseudo-inverse is built both from an Encoder/Decoder instance and from
    the bare filterbank, for Decoder and for Encoder, and one encode/decode
    pass through the pseudo-inverse decoder is run.
    """
    filterbank = fb_class(n_filters=500, kernel_size=16, stride=8)

    analysis = Encoder(filterbank)
    # pinv_of accepts an Encoder/Decoder instance or a bare Filterbank.
    synth_from_wrapper = Decoder.pinv_of(analysis)
    synth_from_fb = Decoder.pinv_of(filterbank)
    assert_allclose(synth_from_wrapper.filters(), synth_from_fb.filters())

    # Forward pass through encoder followed by its pseudo-inverse decoder.
    signal = torch.randn(1, 1, 32000)
    synth_from_wrapper(analysis(signal))

    synthesis = Decoder(filterbank)
    analysis_from_wrapper = Encoder.pinv_of(synthesis)
    analysis_from_fb = Encoder.pinv_of(filterbank)
    assert_allclose(analysis_from_wrapper.filters(), analysis_from_fb.filters())
Beispiel #3
0
def test_pmsqe(sample_rate):
    """Check SingleSrcPMSQE output shape and invariance to transposed inputs.

    The loss must return one value per batch element, and give the same
    values when both spectrograms are transposed along dims 1 and 2.
    """
    # 512-point STFT with hop 256 at 16 kHz, 256-point with hop 128 otherwise.
    n_fft = 512 if sample_rate == 16000 else 256
    stft = Encoder(
        STFTFB(kernel_size=n_fft, n_filters=n_fft, stride=n_fft // 2))

    reference = torch.randn(2, 1, 16000)
    estimate = torch.randn(2, 1, 16000)
    reference_mag = transforms.mag(stft(reference))
    estimate_mag = transforms.mag(stft(estimate))
    pmsqe = SingleSrcPMSQE(sample_rate=sample_rate)

    loss = pmsqe(estimate_mag, reference_mag)
    # One loss value per batch element.
    assert loss.shape[0] == reference.shape[0]

    loss_transposed = pmsqe(estimate_mag.transpose(1, 2),
                            reference_mag.transpose(1, 2))
    assert_allclose(loss, loss_transposed)
Beispiel #4
0
def test_stft_def(fb_config):
    """Check that direct STFTFB wrapping matches ``make_enc_dec("stft", ...)``."""
    direct_fb = STFTFB(**fb_config)
    manual_enc, manual_dec = Encoder(direct_fb), Decoder(direct_fb)
    factory_enc, factory_dec = make_enc_dec("stft", **fb_config)

    for ours, theirs in ((manual_enc, factory_enc),
                         (manual_dec, factory_dec)):
        testing.assert_allclose(ours.filterbank.filters(),
                                theirs.filterbank.filters())
Beispiel #5
0
def test_griffinlim(fb_config, feed_istft, feed_angle):
    """Smoke-test griffin_lim with and without an explicit iSTFT / initial phase."""
    stft = Encoder(STFTFB(**fb_config))
    istft = Decoder(STFTFB(**fb_config)) if feed_istft else None

    waveform = torch.randn(2, 1, 8000)
    spec = stft(waveform)
    # Random soft mask with values in (0, 1), same shape as the spectrogram.
    masked = spec * torch.sigmoid(torch.randn_like(spec))

    magnitude = transforms.mag(masked, -2)
    init_phase = transforms.angle(masked, -2) if feed_angle else None
    griffin_lim(magnitude, stft, angles=init_phase, istft_dec=istft,
                n_iter=3)
def test_fb_def_and_forward_all_dims(fb_class, fb_config):
    """Run a single-channel 3D encode/decode and check the output shapes.

    NOTE(review): despite the name, only a ``(3, 1, 32000)`` input is
    exercised here; other ranks are not tested in this function.
    """
    encoder = Encoder(fb_class(**fb_config))
    decoder = Decoder(fb_class(**fb_config))

    batch = torch.randn(3, 1, 32000)
    features = encoder(batch)
    # First two dims: batch size, then the filterbank's output feature count.
    assert features.shape[:2] == (3, encoder.filterbank.n_feats_out)

    resynth = decoder(features)
    # The last (time) dim may differ after decoding; compare the others only.
    assert resynth.shape[:-1] == batch.shape[:-1]
Beispiel #7
0
def test_melgram_encoder(n_filters, n_mels, ndim):
    """Check MelGramFB encoder output shapes for inputs of rank ``ndim``."""
    if n_mels is None:
        # Default: one mel band per one-sided frequency bin.
        n_mels = n_filters // 2 + 1
    melgram_fb = MelGramFB(n_filters=n_filters, kernel_size=n_filters,
                           n_mels=n_mels)
    enc = Encoder(melgram_fb)

    leading = tuple(random.randint(2, 3) for _ in range(ndim - 1))
    wav = torch.randn(leading + (4000,))
    mel_spec = enc(wav)

    # The time axis is replaced by two trailing dims, the first of size n_mels.
    assert mel_spec.shape[:-2] == wav.shape[:-1]
    assert mel_spec.shape[-2] == n_mels
    # get_config() must not raise.
    melgram_fb.get_config()
Beispiel #8
0
def test_pcen_forward(n_channels, batch_size):
    """Check that PCEN keeps the shape of a multichannel magnitude spectrogram."""
    audio = torch.randn(batch_size, n_channels, 16000 * 10)

    stft = Encoder(STFTFB(kernel_size=256, n_filters=256, stride=128))
    mag_spec = transforms.mag(stft(audio))

    energy = PCEN(n_channels=n_channels)(mag_spec)
    assert energy.shape == mag_spec.shape
def test_fb_forward_multichannel(fb_class, fb_config, ndim):
    """Check encoder/decoder shapes for inputs with ``ndim`` leading dims."""
    enc = Encoder(fb_class(**fb_config))
    dec = Decoder(fb_class(**fb_config))

    leading = tuple(random.randint(2, 4) for _ in range(ndim))
    inp = torch.randn(leading + (4000,))

    tf_out = enc(inp)
    # Leading dims are preserved and followed by the feature dim.
    expected_prefix = leading + (enc.filterbank.n_feats_out,)
    assert tf_out.shape[:ndim + 1] == expected_prefix

    out = dec(tf_out)
    # The decoded time length may differ; compare all other dims.
    assert out.shape[:-1] == inp.shape[:-1]
Beispiel #10
0
    def __init__(self, fb_conf, mask_conf):
        """Create the learned filterbanks, normalization layer and masker.

        Args:
            fb_conf (dict): keyword arguments forwarded to ``FreeFB``; the
                ``"n_filters"`` key is also read here.
            mask_conf (dict): masker settings; the keys ``"n_src"``,
                ``"n_units"``, ``"n_layers"`` and ``"dropout"`` are read.
        """
        super().__init__()
        n_filters = fb_conf["n_filters"]
        n_units = mask_conf["n_units"]
        self.n_src = mask_conf["n_src"]
        self.n_filters = n_filters

        # Two analysis filterbanks and one synthesis filterbank, all FreeFB
        # (a plain nn.Conv1D could be used as well). Construction order is
        # kept so parameter initialization order is unchanged.
        self.encoder_sig = Encoder(FreeFB(**fb_conf))
        self.encoder_relu = Encoder(FreeFB(**fb_conf))
        self.decoder = Decoder(FreeFB(**fb_conf))
        self.bn_layer = GlobLN(n_filters)

        # Masker: bidirectional LSTM -> linear -> sigmoid, emitting
        # n_src * n_filters values per step.
        rnn = SingleRNN(
            "lstm",
            n_filters,
            hidden_size=n_units,
            n_layers=mask_conf["n_layers"],
            bidirectional=True,
            dropout=mask_conf["dropout"],
        )
        # Input size is 2 * n_units because the RNN is bidirectional.
        projection = nn.Linear(2 * n_units, self.n_src * n_filters)
        self.masker = nn.Sequential(rnn, projection, nn.Sigmoid())
Beispiel #11
0
def test_perfect_resyn_window(fb_config, analysis_window_name,
                              use_torch_window):
    """Check STFT -> iSTFT with a perfect-synthesis window reconstructs input.

    One kernel length is trimmed from each edge before comparing.
    """
    kernel_size = fb_config["kernel_size"]
    analysis_window = get_window(analysis_window_name, kernel_size)
    if use_torch_window:
        analysis_window = torch.Tensor(analysis_window)

    enc = Encoder(STFTFB(**fb_config, window=analysis_window))
    # Synthesis window derived from the encoder's window and hop size.
    synthesis_window = perfect_synthesis_window(enc.filterbank.window,
                                                enc.stride)
    dec = Decoder(STFTFB(**fb_config, window=synthesis_window))

    inner = slice(kernel_size, -kernel_size)
    inp_wav = torch.ones(1, 1, 32000)
    out_wav = dec(enc(inp_wav))[:, :, inner]
    testing.assert_allclose(inp_wav[:, :, inner], out_wav)
Beispiel #12
0
def test_misi(fb_config, feed_istft, feed_angle):
    """Check MISI returns one waveform per source for a masked mixture."""
    n_src = 3
    stft = Encoder(STFTFB(**fb_config))
    istft = Decoder(STFTFB(**fb_config)) if feed_istft else None

    mixture = torch.randn(2, 1, 8000)
    # New axis at dim 1 so n_src masks broadcast against the mixture spec.
    mix_spec = stft(mixture).unsqueeze(1)
    mask_shape = (mix_spec.shape[0], mix_spec.shape[1] * n_src,
                  *mix_spec.shape[2:])
    masks = torch.sigmoid(torch.randn(*mask_shape))
    source_specs = mix_spec * masks

    mags = transforms.mag(source_specs, -2)
    phases = transforms.angle(source_specs, -2) if feed_angle else None
    est_wavs = misi(mixture, mags, stft, angles=phases, istft_dec=istft,
                    n_iter=2)
    # The time length is not checked: ISTFT(STFT(x)) may trim the end.
    assert est_wavs.shape[:-1] == (2, n_src)
Beispiel #13
0
    def __init__(self,
                 n_filters=None,
                 windows_size=None,
                 hops_size=None,
                 alpha=1.0):
        """Build one STFT encoder per resolution.

        Args:
            n_filters (list of int, optional): number of STFT filters per
                resolution. Defaults to ``[2048, 1024, 512, 256, 128, 64, 32]``.
            windows_size (list of int, optional): kernel size per resolution.
                Defaults to ``[2048, 1024, 512, 256, 128, 64, 32]``.
            hops_size (list of int, optional): stride per resolution.
                Defaults to ``[1024, 512, 256, 128, 64, 32, 16]``.
            alpha (float): stored as ``self.alpha``; how it is used is not
                visible in this method.

        Raises:
            ValueError: if ``n_filters``, ``windows_size`` and ``hops_size``
                do not all have the same length.
        """
        super().__init__()

        # None sentinels avoid sharing mutable list defaults across instances.
        if windows_size is None:
            windows_size = [2048, 1024, 512, 256, 128, 64, 32]
        if n_filters is None:
            n_filters = [2048, 1024, 512, 256, 128, 64, 32]
        if hops_size is None:
            hops_size = [1024, 512, 256, 128, 64, 32, 16]

        # Without this check, mismatched lengths would either silently drop
        # trailing resolutions or fail later with an opaque IndexError.
        if not len(n_filters) == len(windows_size) == len(hops_size):
            raise ValueError(
                "n_filters, windows_size and hops_size must have the same "
                f"length, got {len(n_filters)}, {len(windows_size)} and "
                f"{len(hops_size)}."
            )

        self.windows_size = windows_size
        self.n_filters = n_filters
        self.hops_size = hops_size
        self.alpha = alpha

        # STFTFB positional order: (n_filters, kernel_size, stride).
        self.encoders = nn.ModuleList(
            Encoder(STFTFB(n_filt, win, hop))
            for n_filt, win, hop in zip(n_filters, windows_size, hops_size))
def test_fb_def_and_forward_lowdim(fb_class, fb_config):
    """Check encoder/decoder consistency across 1D, 2D, 3D and 4D inputs."""
    enc = Encoder(fb_class(**fb_config))
    dec = Decoder(fb_class(**fb_config))

    inp = torch.randn(1, 1, 16000)
    tf_out = enc(inp)

    # 2D encoder input: must match the 3D result and emit a UserWarning
    # (the warning concerns the rank of the STFT / iSTFT outputs).
    with pytest.warns(UserWarning):
        assert_allclose(enc(inp), enc(inp[0]))
    # 1D encoder input: matches the 3D result without its first dim.
    assert_allclose(tf_out[0], enc(inp[0, 0]))

    out = dec(tf_out)
    # 4D decoder input: repeat(1, 2, 1, 1) prepends a dim and tiles it twice
    # along dim 1; the first slice must decode like the 3D input.
    out_4d = dec(tf_out.repeat(1, 2, 1, 1))
    assert_allclose(out, out_4d[:, 0:1])
    # 2D decoder input.
    assert_allclose(out[0, 0], dec(tf_out[0]))
    assert tf_out.shape[1] == enc.filterbank.n_feats_out
def test_torch_stft(
    n_fft_next_pow,
    hop_ratio,
    win_length,
    window,
    center,
    pad_mode,
    normalized,
    sample_rate,
    pass_length,
    wav_shape,
):
    """Compare TorchSTFTFB encoder/decoder with ``torch.stft`` / ``torch.istft``.

    The forward transform must match ``torch.stft``. For the inverse: if
    ``torch.istft`` raises a RuntimeError (the overlap-add had zeros), the
    asteroid decoder must emit a RuntimeWarning instead; otherwise its output
    is compared to ``torch.istft`` on the leading samples of the time axis.
    """
    # Larger windows get 10x looser tolerances.
    RTOL = 1e-3 if win_length > 256 else 1e-4
    ATOL = 1e-4 if win_length > 256 else 1e-5
    wav = torch.randn(wav_shape, dtype=torch.float32)
    output_len = wav.shape[-1] if pass_length else None
    n_fft = win_length if not n_fft_next_pow else next_power_of_2(win_length)
    hop_length = win_length // hop_ratio

    # NOTE: without overlap (hop_ratio == 1) and with a near-zero window, the
    # signal cannot be restored; torch.istft is then expected to raise, which
    # is handled by the RuntimeError branch below.
    window = None if window is None else get_window(
        window, win_length, fftbins=True)

    fb = TorchSTFTFB.from_torch_args(
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        center=center,
        pad_mode=pad_mode,
        normalized=normalized,
        onesided=True,
        sample_rate=sample_rate,
    )

    stft = Encoder(fb)
    istft = Decoder(fb)

    spec = torch.stft(
        wav,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=fb.torch_window,
        center=center,
        pad_mode=pad_mode,
        normalized=normalized,
        onesided=True,
    )

    spec_asteroid = stft(wav)
    torch_spec = to_asteroid(spec.float())
    assert_allclose(spec_asteroid, torch_spec, rtol=RTOL, atol=ATOL)

    try:
        wav_back = torch.istft(
            spec,
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            window=fb.torch_window,
            center=center,
            normalized=normalized,
            onesided=True,
            length=output_len,
        )
    except RuntimeError:
        # If there was a RuntimeError, the OLA had zeros. So we cannot unit test
        # But we can make sure that istft raises a warning about it.
        with pytest.warns(RuntimeWarning):
            _ = istft(spec_asteroid, length=output_len)
    else:
        # If there was no RuntimeError, we unit-test against the results.
        wav_back_asteroid = istft(spec_asteroid, length=output_len)
        # Asteroid always returns a signal at least as long.
        assert wav_back_asteroid.shape[-1] >= wav_back.shape[-1]
        # Compare the left part of the signal: trim the TIME (last) axis.
        # A bare ``[:n]`` would slice the first (batch) dim instead.
        assert_allclose(wav_back_asteroid[..., :wav_back.shape[-1]],
                        wav_back.float(),
                        rtol=RTOL,
                        atol=ATOL)