Example #1
0
def test_freq_xfmr_rel(num_spks):
    """Shape test for ``sse@freq_xfmr_rel`` trained in the time domain.

    Builds the relative-position frequency transformer on top of a
    log-CMVN spectrogram front-end (512-sample frames, 256-sample hop) and
    checks that both the batched forward pass and single-utterance
    inference return waveforms with the input's sample count.

    Args:
        num_spks: number of separated outputs; when > 1 the forward pass
            must return exactly ``num_spks`` outputs.
    """
    batch, num_samples = 4, 64000
    net_cls = aps_sse_nnet("sse@freq_xfmr_rel")
    front_end = EnhTransform(feats="spectrogram-log-cmvn",
                             frame_len=512,
                             frame_hop=256)
    xfmr_conf = dict(input_size=257,
                     enh_transform=front_end,
                     num_spks=num_spks,
                     num_bins=257,
                     att_dim=256,
                     nhead=4,
                     radius=256,
                     feedforward_dim=512,
                     att_dropout=0.1,
                     proj_dropout=0.1,
                     post_norm=True,
                     num_layers=3,
                     non_linear="sigmoid",
                     training_mode="time")
    xfmr = net_cls(**xfmr_conf)
    wav = th.rand(batch, num_samples)
    separated = xfmr(wav)
    multi_spk = num_spks > 1
    if multi_spk:
        # Multi-speaker forward returns one entry per speaker; check the first.
        assert len(separated) == num_spks
        separated = separated[0]
    assert separated.shape == th.Size([batch, num_samples])
    enhanced = xfmr.infer(wav[1])
    first = enhanced[0] if multi_spk else enhanced
    assert first.shape == th.Size([num_samples])
Example #2
0
def test_beam_att(enh_type, enh_kwargs):
    """Output-shape test for ``asr@enh_att`` with a multi-channel enhancement front-end.

    Args:
        enh_type: enhancement front-end type (supplied by the test's
            parametrization).
        enh_kwargs: keyword arguments forwarded to that front-end.
    """
    nnet_cls = aps_asr_nnet("asr@enh_att")
    vocab_size = 100
    batch_size = 4
    num_channels = 4
    # Empty feature spec: presumably the enhancement front-end consumes the
    # raw STFT rather than derived features -- TODO confirm in EnhTransform.
    enh_transform = EnhTransform(feats="",
                                 frame_len=512,
                                 frame_hop=256,
                                 window="sqrthann")
    beam_att_asr = nnet_cls(
        vocab_size=vocab_size,
        # The ASR input size depends on the front-end type:
        # 128 for "time_invar_att", 640 for every other type.
        asr_input_size=640 if enh_type != "time_invar_att" else 128,
        sos=0,
        eos=1,
        ctc=True,
        enh_type=enh_type,
        enh_kwargs=enh_kwargs,
        asr_transform=None,
        enh_transform=enh_transform,
        att_type="dot",
        att_kwargs={"att_dim": 512},
        enc_type="pytorch_rnn",
        enc_proj=256,
        enc_kwargs=default_rnn_enc_kwargs,
        dec_dim=512,
        dec_kwargs=default_rnn_dec_kwargs)
    # u is presumably the (padded) target length produced by gen_egs.
    x, x_len, y, y_len, u = gen_egs(vocab_size,
                                    batch_size,
                                    num_channels=num_channels)
    z, _, _, _ = beam_att_asr(x, x_len, y, y_len)
    # Tie the batch dimension to batch_size instead of repeating the literal.
    assert z.shape == th.Size([batch_size, u + 1, vocab_size - 1])
Example #3
0
def test_dense_unet(num_spks, non_linear):
    """Shape test for ``sse@dense_unet`` (complex in/out, time-domain training).

    Args:
        num_spks: number of separated outputs; when > 1 only the first is
            checked.
        non_linear: output non-linearity forwarded to the model.
    """
    unet_cls = aps_sse_nnet("sse@dense_unet")
    front_end = EnhTransform(feats="spectrogram-log-cmvn",
                             frame_len=512,
                             frame_hop=256)
    # Convolution settings as ";"-separated "h,w" pairs.
    # NOTE(review): P lists 9 pairs while K/S/O/enc_channel list 8 --
    # presumably the trailing pair is ignored by the model; confirm before
    # trimming it.
    conv_conf = dict(K="3,3;3,3;3,3;3,3;3,3;3,3;3,3;3,3",
                     S="1,1;2,1;2,1;2,1;2,1;2,1;2,1;2,1",
                     P="0,1;0,1;0,1;0,1;0,1;0,1;0,1;0,1;0,1",
                     O="0,0,0,0,0,0,0,0",
                     enc_channel="16,32,32,32,32,64,128,384",
                     dec_channel="32,16,32,32,32,32,64,128",
                     conv_dropout=0.3)
    rnn_conf = dict(rnn_hidden=512,
                    rnn_layers=2,
                    rnn_resize=384,
                    rnn_bidir=False,
                    rnn_dropout=0.2)
    dense_unet = unet_cls(**conv_conf,
                          **rnn_conf,
                          num_spks=num_spks,
                          num_dense_blocks=5,
                          enh_transform=front_end,
                          non_linear=non_linear,
                          inp_cplx=True,
                          out_cplx=True,
                          training_mode="time")
    mix = th.rand(4, 64000)
    est = dense_unet(mix)
    est = est[0] if num_spks > 1 else est
    assert est.shape == th.Size([4, 64000])
    single = dense_unet.infer(mix[1])
    single = single[0] if num_spks > 1 else single
    assert single.shape == th.Size([64000])
Example #4
0
def test_enh_transform(wav, feats, shape):
    """Check EnhTransform feature/STFT shapes and that features contain no NaN.

    Args:
        wav: multi-channel waveform array (NumPy); a leading batch axis is
            added before it is passed to the transform.
        feats: feature spec string passed to ``EnhTransform``.
        shape: expected feature tensor shape; its last entry must equal
            ``transform.feats_dim``.
    """
    transform = EnhTransform(feats=feats,
                             frame_len=512,
                             frame_hop=256,
                             ipd_index="0,1;0,2;0,3;0,4",
                             aug_prob=0.2)
    # Use a distinct name for the output so the `feats` spec argument is not
    # shadowed by a tensor.
    out_feats, stft, _ = transform(th.from_numpy(wav[None, ...]), None)
    assert out_feats.shape == th.Size(shape)
    assert not th.isnan(out_feats).any()
    assert stft.shape == th.Size([1, 5, 257, 366])
    assert transform.feats_dim == shape[-1]
Example #5
0
def toy_rnn(mode, num_spks):
    """Build a small two-layer ``sse@base_rnn`` model for tests.

    Args:
        mode: forwarded to the model as ``training_mode``.
        num_spks: number of speakers the model separates.

    Returns:
        The constructed network with a log-CMVN spectrogram front-end
        (512-sample frames, 256-sample hop) and 257 frequency bins.
    """
    front_end = EnhTransform(feats="spectrogram-log-cmvn",
                             frame_len=512,
                             frame_hop=256)
    rnn_cls = aps_sse_nnet("sse@base_rnn")
    rnn_conf = dict(enh_transform=front_end,
                    num_bins=257,
                    input_size=257,
                    num_layers=2,
                    num_spks=num_spks,
                    hidden=256,
                    training_mode=mode)
    return rnn_cls(**rnn_conf)
Example #6
0
def test_enh_ml(num_channels):
    """Run three training epochs of the ``sse@enh_ml`` task on random mixtures.

    Args:
        num_channels: number of channels in the synthetic mixture.
            NOTE(review): ipd_index references channels 0-2, so this
            presumably must be >= 3.
    """
    ml_cls = aps_sse_nnet("sse@rnn_enh_ml")
    front_end = EnhTransform(feats="spectrogram-log-cmvn-ipd",
                             frame_len=512,
                             frame_hop=256,
                             ipd_index="0,1;0,2")
    ml_net = ml_cls(enh_transform=front_end,
                    num_bins=257,
                    input_size=257 * 3,
                    input_proj=512,
                    num_layers=2,
                    hidden=512)
    task = aps_task("sse@enh_ml", ml_net)
    batch = dict(mix=th.rand(4, num_channels, 64000))
    run_epochs(task, batch, 3)
Example #7
0
def test_crn():
    """Shape test for ``sse@crn`` in masking mode with frequency-domain training.

    Uses 320/160 framing with ``round_pow_of_two=False``; the model is built
    for 161 bins (320 // 2 + 1, presumably because the FFT size stays at 320).
    The forward pass returns a (4, 161, 399) tensor and inference returns a
    waveform of the input length.
    """
    crn_cls = aps_sse_nnet("sse@crn")
    stft_front = EnhTransform(feats="spectrogram-log-cmvn",
                              frame_len=320,
                              frame_hop=160,
                              round_pow_of_two=False)
    model = crn_cls(161,
                    enh_transform=stft_front,
                    mode="masking",
                    training_mode="freq")
    mixture = th.rand(4, 64000)
    masked = model(mixture)
    assert masked.shape == th.Size([4, 161, 399])
    wav_out = model.infer(mixture[1])
    assert wav_out.shape == th.Size([64000])
Example #8
0
def test_mvdr_att(att_type, att_kwargs):
    """Output-shape test for ``asr@enh_att`` with an ``rnn_mask_mvdr`` front-end.

    Args:
        att_type: attention type for the ASR decoder (supplied by the test's
            parametrization).
        att_kwargs: keyword arguments forwarded to that attention module.
    """
    nnet_cls = aps_asr_nnet("asr@enh_att")
    vocab_size = 100
    batch_size = 4
    num_channels = 4
    enh_kwargs = {
        "rnn": "lstm",
        "num_layers": 2,
        "rnn_inp_proj": 512,
        "hidden_size": 512,
        "dropout": 0.2,
        "bidirectional": False,
        "mvdr_att_dim": 512,
        "mask_norm": True,
        "num_bins": 257
    }
    asr_transform = AsrTransform(feats="abs-mel-log-cmvn",
                                 frame_len=400,
                                 frame_hop=160,
                                 window="hamm")
    enh_transform = EnhTransform(feats="spectrogram-log-cmvn-ipd",
                                 frame_len=400,
                                 frame_hop=160,
                                 window="hamm",
                                 ipd_index="0,1;0,2;0,3",
                                 cos_ipd=True)
    # 257 * 4: presumably 257 bins x (log spectrogram + 3 IPD pairs listed
    # in ipd_index) -- keep in sync if ipd_index changes.
    mvdr_att_asr = nnet_cls(enh_input_size=257 * 4,
                            vocab_size=vocab_size,
                            sos=0,
                            eos=1,
                            ctc=True,
                            enh_type="rnn_mask_mvdr",
                            enh_kwargs=enh_kwargs,
                            asr_transform=asr_transform,
                            enh_transform=enh_transform,
                            att_type=att_type,
                            att_kwargs=att_kwargs,
                            enc_type="pytorch_rnn",
                            enc_proj=256,
                            enc_kwargs=default_rnn_enc_kwargs,
                            dec_dim=512,
                            dec_kwargs=default_rnn_dec_kwargs)
    # u is presumably the (padded) target length produced by gen_egs.
    x, x_len, y, y_len, u = gen_egs(vocab_size,
                                    batch_size,
                                    num_channels=num_channels)
    z, _, _, _ = mvdr_att_asr(x, x_len, y, y_len)
    # Tie the batch dimension to batch_size instead of repeating the literal.
    assert z.shape == th.Size([batch_size, u + 1, vocab_size - 1])
Example #9
0
def test_phasen():
    """Shape test for ``sse@phasen`` with a single TSB (``num_tsbs=1``).

    The forward pass returns two tensors, both shaped (4, 257, 249).
    Inference returns a waveform of the input length.
    """
    phasen_cls = aps_sse_nnet("sse@phasen")
    stft_front = EnhTransform(feats="", frame_len=512, frame_hop=256)
    phasen = phasen_cls(12,
                        4,
                        enh_transform=stft_front,
                        num_tsbs=1,
                        num_bins=257,
                        channel_r=5,
                        conv1d_kernel=9,
                        lstm_hidden=256,
                        linear_size=512)
    mix = th.rand(4, 64000)
    expected = th.Size([4, 257, 249])
    first, second = phasen(mix)
    for out in (first, second):
        assert out.shape == expected
    wav = phasen.infer(mix[1])
    assert wav.shape == th.Size([64000])
Example #10
0
def test_rnn_enh_ml(num_bins):
    """Shape and NaN checks for ``sse@rnn_enh_ml`` on a 5-channel input.

    Args:
        num_bins: number of frequency bins given to the model; its input
            size is ``num_bins * 4``.
    """
    ml_cls = aps_sse_nnet("sse@rnn_enh_ml")
    front_end = EnhTransform(feats="spectrogram-log-cmvn-ipd",
                             frame_len=512,
                             frame_hop=256,
                             ipd_index="0,1;0,2;0,3")
    model = ml_cls(enh_transform=front_end,
                   num_bins=num_bins,
                   input_size=num_bins * 4,
                   input_proj=512,
                   num_layers=2,
                   hidden=512)
    num_frames = 249
    mix = th.rand(2, 5, 64000)
    cplx_out, aux = model(mix)
    assert cplx_out.shape == th.Size([2, 5, num_bins, num_frames])
    # The first output is complex: check both components for NaN.
    nan_count = th.isnan(cplx_out.real).sum() + th.isnan(cplx_out.imag).sum()
    assert nan_count == 0
    assert aux.shape == th.Size([2, num_frames, num_bins])
    single = model.infer(mix[0])
    assert single.shape == th.Size([num_frames, num_bins])
Example #11
0
def test_base_rnn(num_spks, nonlinear):
    """Shape test for ``sse@base_rnn`` with a configurable output non-linearity.

    Args:
        num_spks: number of separated outputs; when > 1 only the first is
            checked.
        nonlinear: forwarded to the model as ``output_nonlinear``.
    """
    rnn_cls = aps_sse_nnet("sse@base_rnn")
    front_end = EnhTransform(feats="spectrogram-log-cmvn",
                             frame_len=512,
                             frame_hop=256)
    model = rnn_cls(enh_transform=front_end,
                    num_bins=257,
                    input_size=257,
                    input_proj=512,
                    num_layers=2,
                    hidden=512,
                    num_spks=num_spks,
                    output_nonlinear=nonlinear)

    def head(out):
        # With several speakers the output is indexable; check the first one.
        return out[0] if num_spks > 1 else out

    mix = th.rand(2, 64000)
    assert head(model(mix)).shape == th.Size([2, 257, 249])
    assert head(model.infer(mix[0])).shape == th.Size([64000])
Example #12
0
def test_dcunet(num_branch, cplx):
    """Shape test for ``sse@dcunet`` (6-layer encoder spec, "cat" skip connections).

    Args:
        num_branch: number of output branches; when > 1 only the first is
            checked.
        cplx: forwarded to the model as ``cplx``.
    """
    unet_cls = aps_sse_nnet("sse@dcunet")
    stft_front = EnhTransform(feats="", frame_len=512, frame_hop=256)
    layer_conf = dict(K="7,5;7,5;5,3;5,3;3,3;3,3",
                      S="2,1;2,1;2,1;2,1;2,1;2,1",
                      C="32,32,64,64,64,128",
                      P="1,1,1,1,1,0",
                      O="0,0,1,1,1,0")
    dcunet = unet_cls(enh_transform=stft_front,
                      **layer_conf,
                      num_branch=num_branch,
                      cplx=cplx,
                      causal_conv=False,
                      freq_padding=True,
                      connection="cat")
    multi_branch = num_branch > 1
    mix = th.rand(4, 64000)
    batch_out = dcunet(mix)
    if multi_branch:
        batch_out = batch_out[0]
    assert batch_out.shape == th.Size([4, 64000])
    utt_out = dcunet.infer(mix[1])
    if multi_branch:
        utt_out = utt_out[0]
    assert utt_out.shape == th.Size([64000])
Example #13
0
def test_dccrn(num_spks, cplx):
    """Shape test for ``sse@dccrn`` with a 7-layer encoder spec.

    Args:
        num_spks: number of separated outputs; when > 1 only the first is
            checked.
        cplx: forwarded to the model; also selects ``rnn_resize``
            (512 when truthy, 256 otherwise).
    """
    dccrn_cls = aps_sse_nnet("sse@dccrn")
    stft_front = EnhTransform(feats="spectrogram", frame_len=512, frame_hop=256)
    if cplx:
        resize = 512
    else:
        resize = 256
    dccrn = dccrn_cls(enh_transform=stft_front,
                      cplx=cplx,
                      K="3,3;3,3;3,3;3,3;3,3;3,3;3,3",
                      S="2,1;2,1;2,1;2,1;2,1;2,1;2,1",
                      P="1,1,1,1,1,0,0",
                      O="0,0,0,0,0,0,1",
                      C="16,32,64,64,128,128,256",
                      num_spks=num_spks,
                      rnn_resize=resize,
                      non_linear="sigmoid",
                      connection="cat")

    def first_spk(out):
        # Multi-speaker output is indexable; only the first speaker is checked.
        return out[0] if num_spks > 1 else out

    mix = th.rand(4, 64000)
    assert first_spk(dccrn(mix)).shape == th.Size([4, 64000])
    assert first_spk(dccrn.infer(mix[1])).shape == th.Size([64000])