Example #1
0
def test_common_transducer(enc_type, enc_kwargs):
    """Build an ``asr@transducer`` model and check the shape of its output.

    Transformer-family encoders (see ``xfmr_encoders`` below) are built
    without an encoder projection; every other encoder type is projected
    to 512 dims. After one forward pass on generated examples, the trailing
    two dims of the first output must be ``(u + 1, vocab_size)``.

    Args:
        enc_type: encoder type name forwarded to the model.
        enc_kwargs: keyword arguments for that encoder.
    """
    nnet_cls = aps_asr_nnet("asr@transducer")
    vocab_size, batch_size = 100, 4
    dec_kwargs = dict(embed_size=512,
                      enc_dim=512,
                      jot_dim=512,
                      dec_rnn="lstm",
                      dec_layers=2,
                      dec_hidden=512,
                      dec_dropout=0.1)
    asr_transform = AsrTransform(feats="fbank-log-cmvn",
                                 frame_len=400,
                                 frame_hop=160,
                                 window="hamm")
    xfmr_encoders = ["xfmr_abs", "xfmr_rel", "xfmr_xl", "cfmr_xl"]
    # transformer-style encoders are used without an extra projection layer
    proj_dim = 512 if enc_type not in xfmr_encoders else None
    model_conf = {
        "input_size": 80,
        "vocab_size": vocab_size,
        "blank": vocab_size - 1,
        "asr_transform": asr_transform,
        "enc_type": enc_type,
        "enc_proj": proj_dim,
        "enc_kwargs": enc_kwargs,
        "dec_kwargs": dec_kwargs,
    }
    transducer = nnet_cls(**model_conf)
    x, x_len, y, y_len, u = gen_egs(vocab_size, batch_size)
    logits, _ = transducer(x, x_len, y, y_len)
    # only the trailing (label, vocab) dims are checked here
    expected_tail = th.Size([u + 1, vocab_size])
    assert logits.shape[2:] == expected_tail
Example #2
0
def test_xfmr_transducer(enc_type, enc_kwargs):
    """Build an ``asr@xfmr_transducer`` model and check the shape of its output.

    The encoder output is always projected to 512 dims. The decoder is
    configured with the transformer-style settings in ``dec_kwargs``. After
    one forward pass on generated examples, the trailing two dims of the
    first output must be ``(u + 1, vocab_size)``.

    Args:
        enc_type: encoder type name forwarded to the model.
        enc_kwargs: keyword arguments for that encoder.
    """
    nnet_cls = aps_asr_nnet("asr@xfmr_transducer")
    vocab_size, batch_size = 100, 4
    dec_kwargs = dict(jot_dim=512,
                      att_dim=512,
                      nhead=8,
                      feedforward_dim=2048,
                      pos_dropout=0.1,
                      att_dropout=0.1,
                      num_layers=2)
    asr_transform = AsrTransform(feats="fbank-log-cmvn",
                                 frame_len=400,
                                 frame_hop=160,
                                 window="hamm")
    model_conf = {
        "input_size": 80,
        "vocab_size": vocab_size,
        "blank": vocab_size - 1,
        "asr_transform": asr_transform,
        "enc_type": enc_type,
        "enc_proj": 512,
        "enc_kwargs": enc_kwargs,
        "dec_kwargs": dec_kwargs,
    }
    transducer = nnet_cls(**model_conf)
    x, x_len, y, y_len, u = gen_egs(vocab_size, batch_size)
    logits, _ = transducer(x, x_len, y, y_len)
    expected_tail = th.Size([u + 1, vocab_size])
    assert logits.shape[2:] == expected_tail
Example #3
0
def test_att_encoder(enc_type, enc_kwargs):
    """Build an ``asr@att`` model with the given encoder and check its output shape.

    The model uses a ``ctx`` attention, a 256-dim encoder projection, CTC
    enabled and an RNN decoder configured by ``default_rnn_dec_kwargs``.
    After one forward pass on generated examples, the first output must
    have shape ``(batch_size, u + 1, vocab_size - 1)``.

    Args:
        enc_type: encoder type name forwarded to the model.
        enc_kwargs: keyword arguments for that encoder.
    """
    nnet_cls = aps_asr_nnet("asr@att")
    vocab_size = 100
    batch_size = 4
    asr_transform = AsrTransform(feats="fbank-log-cmvn",
                                 frame_len=400,
                                 frame_hop=160,
                                 window="hamm")
    att_asr = nnet_cls(input_size=80,
                       vocab_size=vocab_size,
                       sos=0,
                       eos=1,
                       ctc=True,
                       asr_transform=asr_transform,
                       att_type="ctx",
                       att_kwargs={"att_dim": 512},
                       enc_type=enc_type,
                       enc_proj=256,
                       enc_kwargs=enc_kwargs,
                       dec_type="rnn",
                       dec_dim=512,
                       dec_kwargs=default_rnn_dec_kwargs)
    x, x_len, y, y_len, u = gen_egs(vocab_size, batch_size)
    z, _, _, _ = att_asr(x, x_len, y, y_len)
    # The leading dim was a literal 4 that merely equalled batch_size. Tie it
    # to the variable used to generate the examples so the check follows it.
    assert z.shape == th.Size([batch_size, u + 1, vocab_size - 1])
Example #4
0
def test_asr_transform(wav, feats, shape):
    """Check that ``AsrTransform`` yields NaN-free features of the expected shape.

    Args:
        wav: numpy waveform array; a leading axis of size 1 is added before
            it is fed to the transform.
        feats: feature pipeline description passed as ``AsrTransform(feats=...)``.
        shape: expected shape of the extracted features. Its last entry must
            also equal ``transform.feats_dim``.
    """
    transform = AsrTransform(feats=feats,
                             frame_len=400,
                             frame_hop=160,
                             use_power=True,
                             pre_emphasis=0.96,
                             aug_prob=0.5,
                             aug_mask_zero=False)
    # Use a separate name instead of rebinding the ``feats`` parameter (a
    # feature-type string) to the output tensor.
    out_feats, _ = transform(th.from_numpy(wav[None, ...]), None)
    assert out_feats.shape == th.Size(shape)
    assert not th.isnan(out_feats).any()
    assert transform.feats_dim == shape[-1]
Example #5
0
def test_mvdr_att(att_type, att_kwargs):
    """Check the output shape of an ``asr@enh_att`` model with an MVDR front-end.

    The model combines an ``rnn_mask_mvdr`` enhancement front-end (fed by a
    spectrogram + cosine-IPD ``EnhTransform``) with a ``pytorch_rnn`` encoder
    and the given attention type. It runs one forward pass on
    ``num_channels``-channel generated examples. The first output must have
    shape ``(batch_size, u + 1, vocab_size - 1)``.

    Args:
        att_type: attention type name forwarded to the model.
        att_kwargs: keyword arguments for the attention module.
    """
    nnet_cls = aps_asr_nnet("asr@enh_att")
    vocab_size = 100
    batch_size = 4
    num_channels = 4
    enh_kwargs = {
        "rnn": "lstm",
        "num_layers": 2,
        "rnn_inp_proj": 512,
        "hidden_size": 512,
        "dropout": 0.2,
        "bidirectional": False,
        "mvdr_att_dim": 512,
        "mask_norm": True,
        "num_bins": 257
    }
    asr_transform = AsrTransform(feats="abs-mel-log-cmvn",
                                 frame_len=400,
                                 frame_hop=160,
                                 window="hamm")
    enh_transform = EnhTransform(feats="spectrogram-log-cmvn-ipd",
                                 frame_len=400,
                                 frame_hop=160,
                                 window="hamm",
                                 ipd_index="0,1;0,2;0,3",
                                 cos_ipd=True)
    # NOTE(review): 257 * 4 presumably matches the enh_transform output size
    # (257 bins x (spectrogram + 3 IPD pairs)). Confirm before changing
    # num_channels or ipd_index.
    mvdr_att_asr = nnet_cls(enh_input_size=257 * 4,
                            vocab_size=vocab_size,
                            sos=0,
                            eos=1,
                            ctc=True,
                            enh_type="rnn_mask_mvdr",
                            enh_kwargs=enh_kwargs,
                            asr_transform=asr_transform,
                            enh_transform=enh_transform,
                            att_type=att_type,
                            att_kwargs=att_kwargs,
                            enc_type="pytorch_rnn",
                            enc_proj=256,
                            enc_kwargs=default_rnn_enc_kwargs,
                            dec_dim=512,
                            dec_kwargs=default_rnn_dec_kwargs)
    x, x_len, y, y_len, u = gen_egs(vocab_size,
                                    batch_size,
                                    num_channels=num_channels)
    z, _, _, _ = mvdr_att_asr(x, x_len, y, y_len)
    # The leading dim was a literal 4 that merely equalled batch_size. Tie it
    # to the variable used to generate the examples so the check follows it.
    assert z.shape == th.Size([batch_size, u + 1, vocab_size - 1])
Example #6
0
def debug_visualize_feature():
    """Extract augmented fbank+delta features from ``egs1_wav`` and plot them.

    Prints the configured transform, then passes ``feats[0, 1]`` to
    ``aps.plot.plot_feature`` under the name ``"egs"``. Because
    ``delta_as_channel=True``, index 1 on dim 1 is presumably the delta
    channel; verify against ``AsrTransform``.
    """
    transform_conf = {
        "feats": "fbank-log-cmvn-delta",
        "frame_len": 400,
        "frame_hop": 160,
        "use_power": True,
        "pre_emphasis": 0.97,
        "num_mels": 80,
        "min_freq": 20,
        "aug_prob": 1,
        "norm_per_band": True,
        "aug_freq_args": (40, 1),
        "aug_time_args": (100, 1),
        "aug_mask_zero": False,
        "delta_as_channel": True,
    }
    feat_transform = AsrTransform(**transform_conf)
    batch_wav = th.from_numpy(egs1_wav[None, ...])
    egs_feats, _ = feat_transform(batch_wav, None)
    print(feat_transform)
    # imported lazily: only this debug helper needs the plotting module
    from aps.plot import plot_feature
    selected = egs_feats[0, 1].numpy()
    plot_feature(selected, "egs")
Example #7
0
def test_xfmr_encoder(enc_type, enc_kwargs):
    """Build an ``asr@xfmr`` model with the given encoder and check its output shape.

    Transformer-family encoders (see ``xfmr_encoders`` below) are built
    without an encoder projection; other encoders are projected to 512 dims.
    The decoder is ``xfmr_abs``, configured by ``default_xfmr_dec_kwargs``.
    The first output must have shape ``(batch_size, u + 1, vocab_size - 1)``.

    Args:
        enc_type: encoder type name forwarded to the model.
        enc_kwargs: keyword arguments for that encoder.
    """
    nnet_cls = aps_asr_nnet("asr@xfmr")
    vocab_size = 100
    batch_size = 4
    asr_transform = AsrTransform(feats="fbank-log-cmvn",
                                 frame_len=400,
                                 frame_hop=160,
                                 window="hamm")
    xfmr_encoders = ["xfmr_abs", "xfmr_rel", "xfmr_xl", "cfmr_xl"]
    xfmr_asr = nnet_cls(
        input_size=80,
        vocab_size=vocab_size,
        sos=0,
        eos=1,
        ctc=True,
        asr_transform=asr_transform,
        enc_type=enc_type,
        # transformer-style encoders are used without an extra projection
        enc_proj=512 if enc_type not in xfmr_encoders else None,
        enc_kwargs=enc_kwargs,
        dec_type="xfmr_abs",
        dec_kwargs=default_xfmr_dec_kwargs)
    x, x_len, y, y_len, u = gen_egs(vocab_size, batch_size)
    z, _, _, _ = xfmr_asr(x, x_len, y, y_len)
    # The leading dim was a literal 4 that merely equalled batch_size. Tie it
    # to the variable used to generate the examples so the check follows it.
    assert z.shape == th.Size([batch_size, u + 1, vocab_size - 1])