def test_stft_def(fb_config):
    """Build an STFT encoder/decoder pair in two ways and check they agree.

    One pair wraps a hand-built ``STFTFB``; the other comes from the
    ``make_enc_dec('stft', ...)`` factory. Matching components must hold
    the same filters.
    """
    shared_fb = STFTFB(**fb_config)
    manual_pair = (Encoder(shared_fb), Decoder(shared_fb))
    factory_enc, factory_dec = make_enc_dec('stft', **fb_config)

    for manual, factory in zip(manual_pair, (factory_enc, factory_dec)):
        testing.assert_allclose(manual.filterbank.filters,
                                factory.filterbank.filters)
def test_pinv_of(fb_class):
    """Check ``pinv_of`` on both the Decoder and Encoder wrappers.

    A pseudo-inverse can be requested either from an Encoder/Decoder wrapper
    or directly from the filterbank. Both routes must expose the same filters.
    The pinv decoder is also run once on encoder output.
    """
    filterbank = fb_class(n_filters=500, kernel_size=16, stride=8)

    # Decoder side: pinv from a wrapped Encoder vs. from the raw filterbank.
    fwd_encoder = Encoder(filterbank)
    dec_via_wrapper = Decoder.pinv_of(fwd_encoder)
    dec_via_fb = Decoder.pinv_of(filterbank)
    assert_allclose(dec_via_wrapper.filters, dec_via_fb.filters)

    # Run a forward pass so the pinv filters actually get computed.
    signal = torch.randn(1, 1, 32000)
    dec_via_wrapper(fwd_encoder(signal))

    # Encoder side: same check, starting from a wrapped Decoder.
    fwd_decoder = Decoder(filterbank)
    enc_via_wrapper = Encoder.pinv_of(fwd_decoder)
    enc_via_fb = Encoder.pinv_of(filterbank)
    assert_allclose(enc_via_wrapper.filters, enc_via_fb.filters)
def test_fb_def_and_forward(fb_class, fb_config):
    """Test filterbank definition and encoder/decoder forward passes.

    Covers the following:
    - a 3D round trip;
    - decoding the spectrum repeated into a 4D tensor;
    - ``get_config`` on both wrappers;
    - the size of the encoder's feature axis.
    """
    encoder = Encoder(fb_class(**fb_config))
    decoder = Decoder(fb_class(**fb_config))

    wav = torch.randn(1, 1, 32000)
    spec = encoder(wav)
    recon = decoder(spec)

    # The 3D spectrum repeated into 4D must decode, slice by slice,
    # exactly like the 3D spectrum itself.
    recon_4d = decoder(spec.repeat(1, 2, 1, 1))
    assert_allclose(recon, recon_4d[:, 0])

    # get_config must run without raising on both wrappers.
    decoder.get_config()
    encoder.get_config()

    assert spec.shape[1] == encoder.filterbank.n_feats_out
def test_griffinlim(fb_config, feed_istft, feed_angle):
    """Smoke-test ``griffin_lim``, optionally with a given iSTFT and phase."""
    stft = Encoder(STFTFB(**fb_config))
    if feed_istft:
        istft = Decoder(STFTFB(**fb_config))
    else:
        istft = None

    wav = torch.randn(2, 1, 8000)
    spec = stft(wav)
    masked_spec = spec * torch.sigmoid(torch.randn_like(spec))

    mag = transforms.take_mag(masked_spec, -2)
    if feed_angle:
        angles = transforms.angle(masked_spec, -2)
    else:
        angles = None
    griffin_lim(mag, stft, angles=angles, istft_dec=istft, n_iter=3)
def test_fb_def_and_forward_all_dims(fb_class, fb_config):
    """Check encoder/decoder output shapes for a batched single-channel input.

    Only a 3D ``(3, 1, 32000)`` input is exercised here. The encoder output
    must start with ``(3, n_feats_out)``. The decoder must restore every
    input dimension except the last (time) one.
    """
    encoder = Encoder(fb_class(**fb_config))
    decoder = Decoder(fb_class(**fb_config))

    batch = torch.randn(3, 1, 32000)
    encoded = encoder(batch)
    assert encoded.shape[:2] == (3, encoder.filterbank.n_feats_out)

    decoded = decoder(encoded)
    # Only the time axis is allowed to differ after the round trip.
    assert decoded.shape[:-1] == batch.shape[:-1]
def test_fb_forward_multichannel(fb_class, fb_config, ndim):
    """Check encoder/decoder shapes when the input has extra leading dims.

    The input has ``ndim`` leading dims, each of random size in [2, 4],
    followed by 4000 time samples.
    """
    encoder = Encoder(fb_class(**fb_config))
    decoder = Decoder(fb_class(**fb_config))

    leading_dims = tuple(random.randint(2, 4) for _ in range(ndim))
    wav = torch.randn(leading_dims + (4000, ))

    spec = encoder(wav)
    expected_prefix = leading_dims + (encoder.filterbank.n_feats_out, )
    assert spec.shape[:ndim + 1] == expected_prefix

    recon = decoder(spec)
    assert recon.shape[:-1] == wav.shape[:-1]  # Time axis can differ
def test_perfect_resyn_window(fb_config, analysis_window_name):
    """Check reconstruction with a synthesis window matched to the analysis one.

    The encoder uses the named analysis window. The decoder uses the window
    returned by ``perfect_synthesis_window``. ``kernel_size`` samples are
    trimmed from both ends before comparing (presumably to skip boundary
    effects).
    """
    win_len = fb_config["kernel_size"]
    analysis_win = get_window(analysis_window_name, win_len)
    encoder = Encoder(STFTFB(**fb_config, window=analysis_win))

    synthesis_win = perfect_synthesis_window(encoder.filterbank.window,
                                             encoder.stride)
    decoder = Decoder(STFTFB(**fb_config, window=synthesis_win))

    interior = slice(win_len, -win_len)
    wav = torch.ones(1, 1, 32000)
    recon = decoder(encoder(wav))
    testing.assert_allclose(wav[:, :, interior], recon[:, :, interior])
def test_misi(fb_config, feed_istft, feed_angle):
    """Smoke-test ``misi`` and check the leading dims of its estimates."""
    stft = Encoder(STFTFB(**fb_config))
    istft = Decoder(STFTFB(**fb_config)) if feed_istft else None
    n_src = 3

    # Mixture waveform and its spectrum, with a new size-1 axis at dim 1.
    wav = torch.randn(2, 1, 8000)
    spec = stft(wav).unsqueeze(1)

    # Grow dim 1 to n_src and draw one random sigmoid mask per source;
    # multiplying broadcasts the mixture spectrum across sources.
    mask_shape = list(spec.shape)
    mask_shape[1] = mask_shape[1] * n_src
    masks = torch.sigmoid(torch.randn(*mask_shape))
    masked_specs = spec * masks

    mag = transforms.take_mag(masked_specs, -2)
    angles = transforms.angle(masked_specs, -2) if feed_angle else None

    est_wavs = misi(wav, mag, stft, angles=angles, istft_dec=istft, n_iter=2)
    # Time length is left unchecked: ISTFT(STFT()) can cut the end.
    assert est_wavs.shape[:-1] == (2, n_src)
def __init__(self, fb_conf, mask_conf):
    """Build the encoders, decoder, normalization layer and RNN masker.

    Args:
        fb_conf (dict): keyword arguments forwarded to every ``FreeFB``.
            ``'n_filters'`` is also read directly.
        mask_conf (dict): masker settings. The keys read here are
            ``'n_src'``, ``'n_units'``, ``'n_layers'`` and ``'dropout'``.
    """
    super().__init__()
    self.n_src = mask_conf['n_src']
    self.n_filters = fb_conf['n_filters']

    # Two encoders and one decoder, each with its own FreeFB instance
    # (plain nn.Conv1d layers could be used instead). The names suggest
    # sigmoid/relu branches; how they are combined is defined elsewhere.
    self.encoder_sig = Encoder(FreeFB(**fb_conf))
    self.encoder_relu = Encoder(FreeFB(**fb_conf))
    self.decoder = Decoder(FreeFB(**fb_conf))
    self.bn_layer = GlobLN(self.n_filters)

    # Bidirectional LSTM -> linear projection to n_src * n_filters -> sigmoid.
    # The factor 2 presumably accounts for both RNN directions being
    # concatenated in SingleRNN's output.
    rnn_units = mask_conf['n_units']
    recurrent = SingleRNN('lstm', self.n_filters,
                          hidden_size=rnn_units,
                          n_layers=mask_conf['n_layers'],
                          bidirectional=True,
                          dropout=mask_conf['dropout'])
    self.masker = nn.Sequential(
        recurrent,
        nn.Linear(2 * rnn_units, self.n_src * self.n_filters),
        nn.Sigmoid(),
    )
def test_fb_def_and_forward_lowdim(fb_class, fb_config):
    """Test encoder/decoder forward passes on lower- and higher-rank inputs.

    Starting from a 3D ``(1, 1, 16000)`` waveform, this checks that:

    - a 2D encoder input matches the 3D result and emits a ``UserWarning``;
    - a 1D encoder input matches the first batch item of the 3D result;
    - a 4D decoder input (3D spectrum repeated) matches the 3D result;
    - a 2D decoder input matches the 3D result;
    - the encoder's feature axis has ``n_feats_out`` entries.
    """
    enc = Encoder(fb_class(**fb_config))
    dec = Decoder(fb_class(**fb_config))

    inp = torch.randn(1, 1, 16000)
    # Single 3D encoding reused as the reference for every comparison,
    # instead of re-running the encoder on the same input each time.
    tf_out = enc(inp)

    # 2D encoder input. STFT(2D) gives 3D, and iSTFT of that 3D output also
    # gives 3D rather than 2D, so a UserWarning is expected here.
    with pytest.warns(UserWarning):
        assert_allclose(tf_out, enc(inp[0]))
    # 1D encoder input.
    assert_allclose(tf_out[0], enc(inp[0, 0]))

    out = dec(tf_out)
    # 4D decoder input.
    out_4d = dec(tf_out.repeat(1, 2, 1, 1))
    assert_allclose(out, out_4d[:, 0])
    # 2D decoder input.
    assert_allclose(out[0, 0], dec(tf_out[0]))

    assert tf_out.shape[1] == enc.filterbank.n_feats_out