def test_pmsqe_pit(n_src, sample_rate):
    """Smoke-test SingleSrcPMSQE wrapped in PITLossWrapper with ``pit_from="pw_pt"``."""
    # 512-point STFT at 16 kHz, 256-point otherwise; hop is half the kernel.
    frame = 512 if sample_rate == 16000 else 256
    stft = Encoder(STFTFB(kernel_size=frame, n_filters=frame, stride=frame // 2))
    # Random references and estimates of shape (2, n_src, 16000).
    ref, est = torch.randn(2, n_src, 16000), torch.randn(2, n_src, 16000)
    ref_spec = transforms.mag(stft(ref))
    est_spec = transforms.mag(stft(est))
    pit_pmsqe = PITLossWrapper(SingleSrcPMSQE(sample_rate=sample_rate), pit_from="pw_pt")
    # The only check is that the forward pass does not raise.
    pit_pmsqe(est_spec, ref_spec)
def test_pinv_of(fb_class):
    """``pinv_of`` gives the same filters whether built from a module or a bare filterbank."""
    fb = fb_class(n_filters=500, kernel_size=16, stride=8)

    # Decoder side: pseudo-inverse of an Encoder vs. of the raw filterbank.
    encoder = Encoder(fb)
    dec_from_module = Decoder.pinv_of(encoder)
    dec_from_fb = Decoder.pinv_of(fb)
    assert_allclose(dec_from_module.filters(), dec_from_fb.filters())
    # The pseudo-inverse decoder must be able to consume the encoder's output.
    signal = torch.randn(1, 1, 32000)
    _ = dec_from_module(encoder(signal))

    # Encoder side: same comparison, starting from a Decoder.
    decoder = Decoder(fb)
    enc_from_module = Encoder.pinv_of(decoder)
    enc_from_fb = Encoder.pinv_of(fb)
    assert_allclose(enc_from_module.filters(), enc_from_fb.filters())
def test_pmsqe(sample_rate):
    """SingleSrcPMSQE yields one value per batch item and tolerates transposed inputs."""
    # 512-point STFT at 16 kHz, 256-point otherwise; hop is half the kernel.
    if sample_rate == 16000:
        kernel = 512
    else:
        kernel = 256
    stft = Encoder(STFTFB(kernel_size=kernel, n_filters=kernel, stride=kernel // 2))
    ref, est = torch.randn(2, 1, 16000), torch.randn(2, 1, 16000)
    ref_spec = transforms.mag(stft(ref))
    est_spec = transforms.mag(stft(est))
    pmsqe = SingleSrcPMSQE(sample_rate=sample_rate)
    loss_value = pmsqe(est_spec, ref_spec)
    # First dimension of the loss matches the batch size.
    assert loss_value.shape[0] == ref.shape[0]
    # Swapping axes 1 and 2 of both spectrograms must not change the loss.
    swapped_value = pmsqe(est_spec.transpose(1, 2), ref_spec.transpose(1, 2))
    assert_allclose(loss_value, swapped_value)
def test_stft_def(fb_config):
    """Hand-built STFT encoder/decoder and ``make_enc_dec("stft")`` share identical filters."""
    manual_fb = STFTFB(**fb_config)
    manual_enc, manual_dec = Encoder(manual_fb), Decoder(manual_fb)
    factory_enc, factory_dec = make_enc_dec("stft", **fb_config)
    for built, made in ((manual_enc, factory_enc), (manual_dec, factory_dec)):
        testing.assert_allclose(built.filterbank.filters(), made.filterbank.filters())
def test_griffinlim(fb_config, feed_istft, feed_angle):
    """Griffin-Lim runs with or without an explicit iSTFT decoder and initial angles."""
    stft = Encoder(STFTFB(**fb_config))
    istft = Decoder(STFTFB(**fb_config)) if feed_istft else None
    wav = torch.randn(2, 1, 8000)
    spec = stft(wav)
    # Random soft mask with values in (0, 1).
    masked_spec = spec * torch.sigmoid(torch.randn_like(spec))
    mag = transforms.mag(masked_spec, -2)
    angles = transforms.angle(masked_spec, -2) if feed_angle else None
    griffin_lim(mag, stft, angles=angles, istft_dec=istft, n_iter=3)
def test_fb_def_and_forward_all_dims(fb_class, fb_config):
    """Encode/decode a 3D single-channel batch and check the leading output shapes."""
    enc = Encoder(fb_class(**fb_config))
    dec = Decoder(fb_class(**fb_config))
    batch = torch.randn(3, 1, 32000)
    tf_rep = enc(batch)
    # Encoder output starts with (batch, n_feats_out).
    assert tf_rep.shape[:2] == (3, enc.filterbank.n_feats_out)
    recon = dec(tf_rep)
    # Leading dims must round-trip; the time length is allowed to differ.
    assert recon.shape[:-1] == batch.shape[:-1]
def test_melgram_encoder(n_filters, n_mels, ndim):
    """MelGramFB encoder maps (..., time) inputs of any rank to (..., n_mels, frames).

    Args:
        n_filters: STFT size used for both ``n_filters`` and ``kernel_size``.
        n_mels: number of mel bands; ``None`` falls back to ``n_filters // 2 + 1``.
        ndim: rank of the input waveform tensor (time axis included).
    """
    if n_mels is None:
        n_mels = n_filters // 2 + 1
    melgram_fb = MelGramFB(n_filters=n_filters, kernel_size=n_filters, n_mels=n_mels)
    enc = Encoder(melgram_fb)
    # ndim - 1 random leading dims of size 2 or 3, then 4000 time samples.
    tensor_shape = tuple(random.randint(2, 3) for _ in range(ndim - 1)) + (4000,)
    wav = torch.randn(tensor_shape)
    mel_spec = enc(wav)
    # Leading dims are preserved; the encoder appends (n_mels, frames).
    assert wav.shape[:-1] == mel_spec.shape[:-2]
    assert mel_spec.shape[-2] == n_mels
    # get_config() used to be called with its result discarded; check it is a mapping.
    assert isinstance(melgram_fb.get_config(), dict)
def test_pcen_forward(n_channels, batch_size):
    """PCEN output has the same shape as the magnitude spectrogram it receives."""
    n_samples = 16000 * 10
    audio = torch.randn(batch_size, n_channels, n_samples)
    enc = Encoder(STFTFB(kernel_size=256, n_filters=256, stride=128))
    mag_spec = transforms.mag(enc(audio))
    energy = PCEN(n_channels=n_channels)(mag_spec)
    assert energy.shape == mag_spec.shape
def test_fb_forward_multichannel(fb_class, fb_config, ndim):
    """Encoder/decoder round trip on inputs with ``ndim`` random leading dimensions."""
    enc = Encoder(fb_class(**fb_config))
    dec = Decoder(fb_class(**fb_config))
    leading = tuple(random.randint(2, 4) for _ in range(ndim))
    tensor_shape = leading + (4000,)
    inp = torch.randn(tensor_shape)
    tf_out = enc(inp)
    # n_feats_out must follow directly after the leading dimensions.
    expected_prefix = tensor_shape[:-1] + (enc.filterbank.n_feats_out,)
    assert tf_out.shape[:ndim + 1] == expected_prefix
    out = dec(tf_out)
    # Leading dims round-trip; the time length is allowed to differ.
    assert out.shape[:-1] == inp.shape[:-1]
def __init__(self, fb_conf, mask_conf):
    """Build the encoders, decoder, normalization layer and masker.

    Args:
        fb_conf (dict): keyword arguments forwarded to every ``FreeFB``;
            must contain ``n_filters``.
        mask_conf (dict): masker settings; reads ``n_src``, ``n_units``,
            ``n_layers`` and ``dropout``.
    """
    super().__init__()
    self.n_src = mask_conf["n_src"]
    self.n_filters = fb_conf["n_filters"]
    # Two independent FreeFB encoders and one FreeFB decoder, all from fb_conf.
    self.encoder_sig = Encoder(FreeFB(**fb_conf))
    self.encoder_relu = Encoder(FreeFB(**fb_conf))
    self.decoder = Decoder(FreeFB(**fb_conf))
    self.bn_layer = GlobLN(self.n_filters)
    # Bidirectional LSTM -> linear projection to n_src * n_filters -> sigmoid.
    rnn = SingleRNN(
        "lstm",
        self.n_filters,
        hidden_size=mask_conf["n_units"],
        n_layers=mask_conf["n_layers"],
        bidirectional=True,
        dropout=mask_conf["dropout"],
    )
    # Bidirectional output concatenates both directions, hence 2 * n_units.
    projection = nn.Linear(2 * mask_conf["n_units"], self.n_src * self.n_filters)
    self.masker = nn.Sequential(rnn, projection, nn.Sigmoid())
def test_perfect_resyn_window(fb_config, analysis_window_name, use_torch_window):
    """An STFT with a matching synthesis window reconstructs a constant signal."""
    kernel_size = fb_config["kernel_size"]
    analysis_window = get_window(analysis_window_name, kernel_size)
    if use_torch_window:
        analysis_window = torch.Tensor(analysis_window)
    enc = Encoder(STFTFB(**fb_config, window=analysis_window))
    # Synthesis window is derived from the encoder's window and stride.
    synth_window = perfect_synthesis_window(enc.filterbank.window, enc.stride)
    dec = Decoder(STFTFB(**fb_config, window=synth_window))
    inp_wav = torch.ones(1, 1, 32000)
    # One kernel length at each edge is excluded from the comparison.
    trim = slice(kernel_size, -kernel_size)
    out_wav = dec(enc(inp_wav))[:, :, trim]
    testing.assert_allclose(inp_wav[:, :, trim], out_wav)
def test_misi(fb_config, feed_istft, feed_angle):
    """MISI returns one waveform per source for a randomly masked mixture."""
    stft = Encoder(STFTFB(**fb_config))
    istft = Decoder(STFTFB(**fb_config)) if feed_istft else None
    n_src = 3
    wav = torch.randn(2, 1, 8000)
    # Insert a singleton axis at position 1 so the spec broadcasts against n_src masks.
    spec = stft(wav).unsqueeze(1)
    mask_shape = list(spec.shape)
    mask_shape[1] = mask_shape[1] * n_src
    masked_specs = spec * torch.sigmoid(torch.randn(*mask_shape))
    mag = transforms.mag(masked_specs, -2)
    angles = transforms.angle(masked_specs, -2) if feed_angle else None
    est_wavs = misi(wav, mag, stft, angles=angles, istft_dec=istft, n_iter=2)
    # Time length is not checked: ISTFT(STFT()) may cut the end of the signal.
    assert est_wavs.shape[:-1] == (2, n_src)
def __init__(self, n_filters=None, windows_size=None, hops_size=None, alpha=1.0):
    """Create one STFT encoder per resolution.

    Args:
        n_filters (list[int], optional): filter count per resolution.
            Defaults to ``[2048, 1024, 512, 256, 128, 64, 32]``.
        windows_size (list[int], optional): kernel size per resolution.
            Defaults to ``[2048, 1024, 512, 256, 128, 64, 32]``.
        hops_size (list[int], optional): stride per resolution.
            Defaults to ``[1024, 512, 256, 128, 64, 32, 16]``.
        alpha (float): stored as ``self.alpha``.
    """
    super().__init__()
    # Fresh lists on every call, so instances never share default containers.
    windows_size = [2048, 1024, 512, 256, 128, 64, 32] if windows_size is None else windows_size
    n_filters = [2048, 1024, 512, 256, 128, 64, 32] if n_filters is None else n_filters
    hops_size = [1024, 512, 256, 128, 64, 32, 16] if hops_size is None else hops_size
    self.windows_size = windows_size
    self.n_filters = n_filters
    self.hops_size = hops_size
    self.alpha = alpha
    # The number of encoders is driven by len(n_filters); the other lists are indexed alongside it.
    self.encoders = nn.ModuleList(
        [
            Encoder(STFTFB(n_filters[idx], windows_size[idx], hops_size[idx]))
            for idx in range(len(n_filters))
        ]
    )
def test_fb_def_and_forward_lowdim(fb_class, fb_config):
    """Encoder/decoder give consistent results for 1D, 2D, 3D and 4D inputs."""
    enc = Encoder(fb_class(**fb_config))
    dec = Decoder(fb_class(**fb_config))
    inp = torch.randn(1, 1, 16000)
    tf_out = enc(inp)
    # 2D encoder input: a UserWarning is expected, result must match the 3D path.
    with pytest.warns(UserWarning):
        assert_allclose(enc(inp), enc(inp[0]))
    # 1D encoder input matches the first item of the 3D output.
    assert_allclose(enc(inp)[0], enc(inp[0, 0]))
    out = dec(tf_out)
    # 4D decoder input: two copies along axis 1; the first must match the 3D result.
    out_4d = dec(tf_out.repeat(1, 2, 1, 1))
    assert_allclose(out, out_4d[:, 0:1])
    # 2D decoder input.
    assert_allclose(out[0, 0], dec(tf_out[0]))
    assert tf_out.shape[1] == enc.filterbank.n_feats_out
def test_torch_stft(
    n_fft_next_pow,
    hop_ratio,
    win_length,
    window,
    center,
    pad_mode,
    normalized,
    sample_rate,
    pass_length,
    wav_shape,
):
    """Compare a TorchSTFTFB-based Encoder/Decoder against torch.stft / torch.istft.

    The forward STFT must match ``torch.stft`` within tolerance. For the
    inverse, if ``torch.istft`` raises ``RuntimeError`` (zeros in the
    overlap-add normalization), the asteroid decoder must emit a
    ``RuntimeWarning`` instead. Otherwise, the leading part of the asteroid
    reconstruction along the time axis must match ``torch.istft``.
    """
    # Larger windows get 10x looser tolerances.
    RTOL = 1e-3 if win_length > 256 else 1e-4
    ATOL = 1e-4 if win_length > 256 else 1e-5
    wav = torch.randn(wav_shape, dtype=torch.float32)
    output_len = wav.shape[-1] if pass_length else None
    n_fft = win_length if not n_fft_next_pow else next_power_of_2(win_length)
    hop_length = win_length // hop_ratio
    window = None if window is None else get_window(window, win_length, fftbins=True)
    # NOTE: without overlap (hop_ratio == 1) and with near-zero window samples,
    # the signal cannot be restored. That case surfaces as a RuntimeError from
    # torch.istft and is handled in the except branch below.

    fb = TorchSTFTFB.from_torch_args(
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        center=center,
        pad_mode=pad_mode,
        normalized=normalized,
        onesided=True,
        sample_rate=sample_rate,
    )
    stft = Encoder(fb)
    istft = Decoder(fb)

    spec = torch.stft(
        wav,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=fb.torch_window,
        center=center,
        pad_mode=pad_mode,
        normalized=normalized,
        onesided=True,
    )
    spec_asteroid = stft(wav)
    torch_spec = to_asteroid(spec.float())
    assert_allclose(spec_asteroid, torch_spec, rtol=RTOL, atol=ATOL)

    try:
        wav_back = torch.istft(
            spec,
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            window=fb.torch_window,
            center=center,
            normalized=normalized,
            onesided=True,
            length=output_len,
        )
    except RuntimeError:
        # The OLA normalization had zeros, so no numeric comparison is possible,
        # but the asteroid decoder must warn about it.
        with pytest.warns(RuntimeWarning):
            _ = istft(spec_asteroid, length=output_len)
    else:
        # No RuntimeError: compare numerically against torch.istft.
        wav_back_asteroid = istft(spec_asteroid, length=output_len)
        # Asteroid never returns a shorter signal.
        assert wav_back_asteroid.shape[-1] >= wav_back.shape[-1]
        # Compare the left part of the signal: trim the TIME axis (last dim),
        # not the first one, so multi-dimensional inputs are handled correctly.
        assert_allclose(
            wav_back_asteroid[..., : wav_back.shape[-1]],
            wav_back.float(),
            rtol=RTOL,
            atol=ATOL,
        )