def test_freq_xfmr_rel(num_spks):
    """Check output shapes of sse@freq_xfmr_rel (training_mode="time").

    Runs a batched forward pass and a single-utterance ``infer`` call and
    verifies the returned waveforms have the input length.
    """
    xfmr_cls = aps_sse_nnet("sse@freq_xfmr_rel")
    enh_transform = EnhTransform(feats="spectrogram-log-cmvn",
                                 frame_len=512,
                                 frame_hop=256)
    config = dict(input_size=257,
                  enh_transform=enh_transform,
                  num_spks=num_spks,
                  num_bins=257,
                  att_dim=256,
                  nhead=4,
                  radius=256,
                  feedforward_dim=512,
                  att_dropout=0.1,
                  proj_dropout=0.1,
                  post_norm=True,
                  num_layers=3,
                  non_linear="sigmoid",
                  training_mode="time")
    model = xfmr_cls(**config)
    wav = th.rand(4, 64000)
    est = model(wav)
    if num_spks > 1:
        # with several speakers the model returns one entry per speaker
        assert len(est) == num_spks
        est = est[0]
    assert est.shape == th.Size([4, 64000])
    single = model.infer(wav[1])
    if num_spks > 1:
        single = single[0]
    assert single.shape == th.Size([64000])
def test_beam_att(enh_type, enh_kwargs):
    """Build asr@enh_att with the given enhancement front-end and check
    the shape of the first output of a forward pass on generated examples.
    """
    asr_cls = aps_asr_nnet("asr@enh_att")
    vocab_size, batch_size, num_channels = 100, 4, 4
    front_end = EnhTransform(feats="",
                             frame_len=512,
                             frame_hop=256,
                             window="sqrthann")
    # the "time_invar_att" front-end feeds a smaller ASR input
    asr_input_size = 128 if enh_type == "time_invar_att" else 640
    model = asr_cls(vocab_size=vocab_size,
                    asr_input_size=asr_input_size,
                    sos=0,
                    eos=1,
                    ctc=True,
                    enh_type=enh_type,
                    enh_kwargs=enh_kwargs,
                    asr_transform=None,
                    enh_transform=front_end,
                    att_type="dot",
                    att_kwargs={"att_dim": 512},
                    enc_type="pytorch_rnn",
                    enc_proj=256,
                    enc_kwargs=default_rnn_enc_kwargs,
                    dec_dim=512,
                    dec_kwargs=default_rnn_dec_kwargs)
    egs = gen_egs(vocab_size, batch_size, num_channels=num_channels)
    wav, wav_len, tokens, tokens_len, max_len = egs
    logits, _, _, _ = model(wav, wav_len, tokens, tokens_len)
    assert logits.shape == th.Size([4, max_len + 1, vocab_size - 1])
def test_dense_unet(num_spks, non_linear):
    """Check waveform output shapes of sse@dense_unet for forward and infer."""
    unet_cls = aps_sse_nnet("sse@dense_unet")
    spec_transform = EnhTransform(feats="spectrogram-log-cmvn",
                                  frame_len=512,
                                  frame_hop=256)
    # NOTE(review): P has 9 ";"-separated entries while K and S have 8
    # (and O / the channel lists have 8) — confirm the extra entry is
    # intended or ignored by the model.
    config = dict(K="3,3;3,3;3,3;3,3;3,3;3,3;3,3;3,3",
                  S="1,1;2,1;2,1;2,1;2,1;2,1;2,1;2,1",
                  P="0,1;0,1;0,1;0,1;0,1;0,1;0,1;0,1;0,1",
                  O="0,0,0,0,0,0,0,0",
                  enc_channel="16,32,32,32,32,64,128,384",
                  dec_channel="32,16,32,32,32,32,64,128",
                  conv_dropout=0.3,
                  num_spks=num_spks,
                  rnn_hidden=512,
                  rnn_layers=2,
                  rnn_resize=384,
                  rnn_bidir=False,
                  rnn_dropout=0.2,
                  num_dense_blocks=5,
                  enh_transform=spec_transform,
                  non_linear=non_linear,
                  inp_cplx=True,
                  out_cplx=True,
                  training_mode="time")
    model = unet_cls(**config)
    wav = th.rand(4, 64000)
    est = model(wav)
    first = est[0] if num_spks > 1 else est
    assert first.shape == th.Size([4, 64000])
    single = model.infer(wav[1])
    first_single = single[0] if num_spks > 1 else single
    assert first_single.shape == th.Size([64000])
def test_enh_transform(wav, feats, shape):
    """Check EnhTransform feature/STFT shapes and that features contain no NaN.

    ``feats`` is the feature-type string passed to EnhTransform and ``shape``
    the expected feature shape; ``wav`` is indexed with ``[None, ...]`` and
    passed through ``th.from_numpy``, so it is expected to be a numpy array.
    """
    transform = EnhTransform(feats=feats,
                             frame_len=512,
                             frame_hop=256,
                             ipd_index="0,1;0,2;0,3;0,4",
                             aug_prob=0.2)
    batch = th.from_numpy(wav[None, ...])
    # keep the output under its own name instead of rebinding ``feats``
    feat_out, stft, _ = transform(batch, None)
    assert feat_out.shape == th.Size(shape)
    assert not th.isnan(feat_out).any()
    assert stft.shape == th.Size([1, 5, 257, 366])
    assert transform.feats_dim == shape[-1]
def toy_rnn(mode, num_spks):
    """Return a small two-layer sse@base_rnn model.

    Args:
        mode: forwarded as ``training_mode``.
        num_spks: number of speakers the model is configured for.
    """
    spec_transform = EnhTransform(feats="spectrogram-log-cmvn",
                                  frame_len=512,
                                  frame_hop=256)
    rnn_cls = aps_sse_nnet("sse@base_rnn")
    config = dict(enh_transform=spec_transform,
                  num_bins=257,
                  input_size=257,
                  num_layers=2,
                  num_spks=num_spks,
                  hidden=256,
                  training_mode=mode)
    return rnn_cls(**config)
def test_enh_ml(num_channels):
    """Run a few epochs of the sse@enh_ml task on random multi-channel audio."""
    rnn_ml_cls = aps_sse_nnet("sse@rnn_enh_ml")
    ipd_transform = EnhTransform(feats="spectrogram-log-cmvn-ipd",
                                 frame_len=512,
                                 frame_hop=256,
                                 ipd_index="0,1;0,2")
    # input_size is 3 * 257; presumably spectrogram plus one block per
    # IPD pair in ipd_index — confirm against EnhTransform.feats_dim
    model = rnn_ml_cls(enh_transform=ipd_transform,
                       num_bins=257,
                       input_size=257 * 3,
                       input_proj=512,
                       num_layers=2,
                       hidden=512)
    ml_task = aps_task("sse@enh_ml", model)
    batch = {"mix": th.rand(4, num_channels, 64000)}
    run_epochs(ml_task, batch, 3)
def test_crn():
    """Check sse@crn frequency-domain output and waveform inference shapes."""
    crn_cls = aps_sse_nnet("sse@crn")
    # 320-sample frames with round_pow_of_two=False; the model is built with
    # 161 bins (320 // 2 + 1)
    spec_transform = EnhTransform(feats="spectrogram-log-cmvn",
                                  frame_len=320,
                                  frame_hop=160,
                                  round_pow_of_two=False)
    model = crn_cls(161,
                    enh_transform=spec_transform,
                    mode="masking",
                    training_mode="freq")
    wav = th.rand(4, 64000)
    spec_out = model(wav)
    assert spec_out.shape == th.Size([4, 161, 399])
    enhanced = model.infer(wav[1])
    assert enhanced.shape == th.Size([64000])
def test_mvdr_att(att_type, att_kwargs):
    """Build asr@enh_att with an rnn_mask_mvdr front-end and check the shape
    of the first forward-pass output for the given attention configuration.
    """
    asr_cls = aps_asr_nnet("asr@enh_att")
    vocab_size, batch_size, num_channels = 100, 4, 4
    mvdr_kwargs = dict(rnn="lstm",
                       num_layers=2,
                       rnn_inp_proj=512,
                       hidden_size=512,
                       dropout=0.2,
                       bidirectional=False,
                       mvdr_att_dim=512,
                       mask_norm=True,
                       num_bins=257)
    asr_front = AsrTransform(feats="abs-mel-log-cmvn",
                             frame_len=400,
                             frame_hop=160,
                             window="hamm")
    enh_front = EnhTransform(feats="spectrogram-log-cmvn-ipd",
                             frame_len=400,
                             frame_hop=160,
                             window="hamm",
                             ipd_index="0,1;0,2;0,3",
                             cos_ipd=True)
    model = asr_cls(enh_input_size=257 * 4,
                    vocab_size=vocab_size,
                    sos=0,
                    eos=1,
                    ctc=True,
                    enh_type="rnn_mask_mvdr",
                    enh_kwargs=mvdr_kwargs,
                    asr_transform=asr_front,
                    enh_transform=enh_front,
                    att_type=att_type,
                    att_kwargs=att_kwargs,
                    enc_type="pytorch_rnn",
                    enc_proj=256,
                    enc_kwargs=default_rnn_enc_kwargs,
                    dec_dim=512,
                    dec_kwargs=default_rnn_dec_kwargs)
    egs = gen_egs(vocab_size, batch_size, num_channels=num_channels)
    wav, wav_len, tokens, tokens_len, max_len = egs
    logits, _, _, _ = model(wav, wav_len, tokens, tokens_len)
    assert logits.shape == th.Size([4, max_len + 1, vocab_size - 1])
def test_phasen():
    """Check the two forward outputs of sse@phasen and its inference shape."""
    phasen_cls = aps_sse_nnet("sse@phasen")
    stft_only = EnhTransform(feats="", frame_len=512, frame_hop=256)
    model = phasen_cls(12,
                       4,
                       enh_transform=stft_only,
                       num_tsbs=1,
                       num_bins=257,
                       channel_r=5,
                       conv1d_kernel=9,
                       lstm_hidden=256,
                       linear_size=512)
    wav = th.rand(4, 64000)
    first_out, second_out = model(wav)
    expected = th.Size([4, 257, 249])
    assert first_out.shape == expected
    assert second_out.shape == expected
    enhanced = model.infer(wav[1])
    assert enhanced.shape == th.Size([64000])
def test_rnn_enh_ml(num_bins):
    """Check sse@rnn_enh_ml output shapes and that its complex output has no NaN."""
    rnn_ml_cls = aps_sse_nnet("sse@rnn_enh_ml")
    ipd_transform = EnhTransform(feats="spectrogram-log-cmvn-ipd",
                                 frame_len=512,
                                 frame_hop=256,
                                 ipd_index="0,1;0,2;0,3")
    model = rnn_ml_cls(enh_transform=ipd_transform,
                       num_bins=num_bins,
                       input_size=num_bins * 4,
                       input_proj=512,
                       num_layers=2,
                       hidden=512)
    wav = th.rand(2, 5, 64000)
    cplx_out, second_out = model(wav)
    assert cplx_out.shape == th.Size([2, 5, num_bins, 249])
    # the first output exposes .real/.imag; check both parts for NaN
    nan_count = th.isnan(cplx_out.real).sum() + th.isnan(cplx_out.imag).sum()
    assert nan_count == 0
    assert second_out.shape == th.Size([2, 249, num_bins])
    inferred = model.infer(wav[0])
    assert inferred.shape == th.Size([249, num_bins])
def test_base_rnn(num_spks, nonlinear):
    """Check sse@base_rnn forward (spectral) and infer (waveform) shapes."""
    rnn_cls = aps_sse_nnet("sse@base_rnn")
    spec_transform = EnhTransform(feats="spectrogram-log-cmvn",
                                  frame_len=512,
                                  frame_hop=256)
    config = dict(enh_transform=spec_transform,
                  num_bins=257,
                  input_size=257,
                  input_proj=512,
                  num_layers=2,
                  hidden=512,
                  num_spks=num_spks,
                  output_nonlinear=nonlinear)
    model = rnn_cls(**config)
    multi_spk = num_spks > 1
    wav = th.rand(2, 64000)
    est = model(wav)
    first = est[0] if multi_spk else est
    assert first.shape == th.Size([2, 257, 249])
    single = model.infer(wav[0])
    first_single = single[0] if multi_spk else single
    assert first_single.shape == th.Size([64000])
def test_dcunet(num_branch, cplx):
    """Check waveform output shapes of sse@dcunet for forward and infer."""
    unet_cls = aps_sse_nnet("sse@dcunet")
    stft_only = EnhTransform(feats="", frame_len=512, frame_hop=256)
    config = dict(enh_transform=stft_only,
                  K="7,5;7,5;5,3;5,3;3,3;3,3",
                  S="2,1;2,1;2,1;2,1;2,1;2,1",
                  C="32,32,64,64,64,128",
                  P="1,1,1,1,1,0",
                  O="0,0,1,1,1,0",
                  num_branch=num_branch,
                  cplx=cplx,
                  causal_conv=False,
                  freq_padding=True,
                  connection="cat")
    model = unet_cls(**config)
    multi_branch = num_branch > 1
    wav = th.rand(4, 64000)
    est = model(wav)
    first = est[0] if multi_branch else est
    assert first.shape == th.Size([4, 64000])
    single = model.infer(wav[1])
    first_single = single[0] if multi_branch else single
    assert first_single.shape == th.Size([64000])
def test_dccrn(num_spks, cplx):
    """Check waveform output shapes of sse@dccrn for forward and infer."""
    dccrn_cls = aps_sse_nnet("sse@dccrn")
    spec_transform = EnhTransform(feats="spectrogram",
                                  frame_len=512,
                                  frame_hop=256)
    # the complex variant is configured with a larger rnn_resize
    resize = 512 if cplx else 256
    config = dict(enh_transform=spec_transform,
                  cplx=cplx,
                  K="3,3;3,3;3,3;3,3;3,3;3,3;3,3",
                  S="2,1;2,1;2,1;2,1;2,1;2,1;2,1",
                  P="1,1,1,1,1,0,0",
                  O="0,0,0,0,0,0,1",
                  C="16,32,64,64,128,128,256",
                  num_spks=num_spks,
                  rnn_resize=resize,
                  non_linear="sigmoid",
                  connection="cat")
    model = dccrn_cls(**config)
    multi_spk = num_spks > 1
    wav = th.rand(4, 64000)
    est = model(wav)
    first = est[0] if multi_spk else est
    assert first.shape == th.Size([4, 64000])
    single = model.infer(wav[1])
    first_single = single[0] if multi_spk else single
    assert first_single.shape == th.Size([64000])