Пример #1
0
def test_stft_windows(fb_config):
    """Check that STFTFB accepts a matching window and rejects a mismatched one.

    A window with exactly ``kernel_size`` samples must be accepted, while a
    window that is one sample longer must trigger an ``AssertionError``.
    """
    kernel_size = fb_config["kernel_size"]

    # A window whose length equals the kernel size is valid.
    valid_win = np.hanning(kernel_size)
    STFTFB(**fb_config, window=valid_win)

    # Build the oversized window *outside* the ``raises`` block so that only
    # the filterbank construction is allowed to satisfy the expectation.
    oversized_win = np.hanning(kernel_size + 1)
    with pytest.raises(AssertionError):
        STFTFB(**fb_config, window=oversized_win)
Пример #2
0
def test_griffinlim(fb_config, feed_istft, feed_angle):
    """Smoke-test ``griffin_lim`` on a randomly masked spectrogram.

    The optional ISTFT decoder and the initial angles are only supplied when
    the corresponding ``feed_istft`` / ``feed_angle`` flags are truthy.
    """
    encoder = Encoder(STFTFB(**fb_config))
    if feed_istft:
        decoder = Decoder(STFTFB(**fb_config))
    else:
        decoder = None

    # Random waveform -> spectrogram -> random sigmoid mask applied on it.
    waveform = torch.randn(2, 1, 8000)
    spectrum = encoder(waveform)
    mask = torch.sigmoid(torch.randn_like(spectrum))
    masked = spectrum * mask

    magnitude = transforms.mag(masked, -2)
    if feed_angle:
        init_angles = transforms.angle(masked, -2)
    else:
        init_angles = None

    # Only checks that the call runs without raising.
    griffin_lim(magnitude, encoder, angles=init_angles, istft_dec=decoder,
                n_iter=3)
Пример #3
0
def test_pmsqe_pit(n_src, sample_rate):
    """Run ``SingleSrcPMSQE`` through ``PITLossWrapper`` on random spectra."""
    # The STFT size depends on the sample rate: 512 at 16 kHz, 256 otherwise,
    # with a stride of half the kernel size in both cases.
    size = 512 if sample_rate == 16000 else 256
    stft = Encoder(STFTFB(kernel_size=size, n_filters=size, stride=size // 2))

    # Random reference and estimate with ``n_src`` sources each.
    ref = torch.randn(2, n_src, 16000)
    est = torch.randn(2, n_src, 16000)
    ref_spec = transforms.mag(stft(ref))
    est_spec = transforms.mag(stft(est))

    pit_loss = PITLossWrapper(SingleSrcPMSQE(sample_rate=sample_rate),
                              pit_from="pw_pt")
    # The forward pass only has to complete without raising.
    pit_loss(est_spec, ref_spec)
Пример #4
0
def test_perfect_resyn_window(fb_config, analysis_window_name,
                              use_torch_window):
    """Check that analysis + matching synthesis window reconstruct the input.

    The signal is compared away from the borders: ``kernel_size`` samples are
    dropped at both ends before the comparison.
    """
    kernel_size = fb_config["kernel_size"]
    analysis_window = get_window(analysis_window_name, kernel_size)
    if use_torch_window:
        analysis_window = torch.Tensor(analysis_window)

    encoder = Encoder(STFTFB(**fb_config, window=analysis_window))
    # Derive the synthesis window from the encoder's own window and stride.
    synth_window = perfect_synthesis_window(encoder.filterbank.window,
                                            encoder.stride)
    decoder = Decoder(STFTFB(**fb_config, window=synth_window))

    signal = torch.ones(1, 1, 32000)
    inner = slice(kernel_size, -kernel_size)
    reconstructed = decoder(encoder(signal))[:, :, inner]
    expected = signal[:, :, inner]
    testing.assert_allclose(expected, reconstructed)
Пример #5
0
def test_pmsqe(sample_rate):
    """Check ``SingleSrcPMSQE`` output size and support for transposed input."""
    # The STFT size depends on the sample rate: 512 at 16 kHz, 256 otherwise,
    # with a stride of half the kernel size in both cases.
    size = 512 if sample_rate == 16000 else 256
    stft = Encoder(STFTFB(kernel_size=size, n_filters=size, stride=size // 2))

    ref = torch.randn(2, 1, 16000)
    est = torch.randn(2, 1, 16000)
    ref_spec = transforms.mag(stft(ref))
    est_spec = transforms.mag(stft(est))

    pmsqe = SingleSrcPMSQE(sample_rate=sample_rate)
    loss = pmsqe(est_spec, ref_spec)
    # One loss value per batch element.
    assert loss.shape[0] == ref.shape[0]

    # Swapping dims 1 and 2 of both inputs must give the same loss.
    loss_transposed = pmsqe(est_spec.transpose(1, 2),
                            ref_spec.transpose(1, 2))
    assert_allclose(loss, loss_transposed)
Пример #6
0
def test_misi(fb_config, feed_istft, feed_angle):
    """Run ``misi`` on randomly masked mixture spectra and check output dims.

    The optional ISTFT decoder and the initial angles are only supplied when
    the corresponding ``feed_istft`` / ``feed_angle`` flags are truthy.
    """
    n_src = 3
    encoder = Encoder(STFTFB(**fb_config))
    if feed_istft:
        decoder = Decoder(STFTFB(**fb_config))
    else:
        decoder = None

    # Mixture waveform and its spectrogram with an extra axis at position 1.
    mixture = torch.randn(2, 1, 8000)
    mix_spec = encoder(mixture).unsqueeze(1)

    # Mask shape: same as the spectrogram, with axis 1 scaled by ``n_src``.
    mask_shape = [
        dim * n_src if axis == 1 else dim
        for axis, dim in enumerate(mix_spec.shape)
    ]
    masks = torch.sigmoid(torch.randn(*mask_shape))
    masked = mix_spec * masks

    # Split into magnitude and (optionally) angle.
    magnitude = transforms.mag(masked, -2)
    if feed_angle:
        init_angles = transforms.angle(masked, -2)
    else:
        init_angles = None

    est_wavs = misi(mixture, magnitude, encoder, angles=init_angles,
                    istft_dec=decoder, n_iter=2)
    # The last dim is not known in advance because ISTFT(STFT()) cuts the end.
    assert est_wavs.shape[:-1] == (2, n_src)
Пример #7
0
def test_stft_def(fb_config):
    """Compare filters built directly with those built by ``make_enc_dec``."""
    filterbank = STFTFB(**fb_config)
    direct_enc = Encoder(filterbank)
    direct_dec = Decoder(filterbank)
    factory_enc, factory_dec = make_enc_dec("stft", **fb_config)

    # Encoder pair first, then decoder pair.
    pairs = ((direct_enc, factory_enc), (direct_dec, factory_dec))
    for direct, from_factory in pairs:
        testing.assert_allclose(direct.filterbank.filters(),
                                from_factory.filterbank.filters())
Пример #8
0
def test_pcen_forward(n_channels, batch_size):
    """Check that ``PCEN`` returns a tensor shaped like its magnitude input."""
    waveform = torch.randn(batch_size, n_channels, 16000 * 10)

    encoder = Encoder(STFTFB(kernel_size=256, n_filters=256, stride=128))
    magnitude = transforms.mag(encoder(waveform))

    pcen_layer = PCEN(n_channels=n_channels)
    output = pcen_layer(magnitude)

    assert output.shape == magnitude.shape
Пример #9
0
    def __init__(self,
                 n_filters=None,
                 windows_size=None,
                 hops_size=None,
                 alpha=1.0):
        """Build one STFT encoder per entry of ``n_filters``.

        Args:
            n_filters: Sequence passed as the first positional argument of
                ``STFTFB``. Defaults to
                ``[2048, 1024, 512, 256, 128, 64, 32]``.
            windows_size: Sequence passed as the second positional argument
                of ``STFTFB`` (presumably the kernel size -- confirm against
                ``STFTFB``). Defaults to
                ``[2048, 1024, 512, 256, 128, 64, 32]``.
            hops_size: Sequence passed as the third positional argument of
                ``STFTFB`` (presumably the stride -- confirm against
                ``STFTFB``). Defaults to
                ``[1024, 512, 256, 128, 64, 32, 16]``.
            alpha: Stored on the instance; not used further in this
                constructor.
        """
        super().__init__()

        # Fall back to the built-in configuration for arguments left as None.
        if n_filters is None:
            n_filters = [2048, 1024, 512, 256, 128, 64, 32]
        if windows_size is None:
            windows_size = [2048, 1024, 512, 256, 128, 64, 32]
        if hops_size is None:
            hops_size = [1024, 512, 256, 128, 64, 32, 16]

        self.alpha = alpha
        self.n_filters = n_filters
        self.windows_size = windows_size
        self.hops_size = hops_size

        # The number of encoders follows ``n_filters``; the other two
        # sequences are indexed at the same position.
        stft_encoders = [
            Encoder(STFTFB(n_filt, windows_size[idx], hops_size[idx]))
            for idx, n_filt in enumerate(n_filters)
        ]
        self.encoders = nn.ModuleList(stft_encoders)
Пример #10
0
def test_filter_shape(fb_config):
    """Check that STFTFB filters have shape (n_filters + 2, 1, kernel_size)."""
    expected_shape = (
        fb_config["n_filters"] + 2,
        1,
        fb_config["kernel_size"],
    )
    filterbank = STFTFB(**fb_config)
    assert filterbank.filters().shape == expected_shape
Пример #11
0
def test_stft_def_error(n_filters):
    """Check the ``ValueError`` message raised by ``STFTFB`` for this input."""
    with pytest.raises(ValueError) as exc_info:
        STFTFB(n_filters, n_filters)
    expected_message = f"n_filters must be even, got {n_filters}"
    assert str(exc_info.value) == expected_message