# Code example 1
def test_xfmr_transducer(enc_type, enc_kwargs):
    """Build a transformer-transducer ASR model and check its output shape."""
    nnet_cls = aps_asr_nnet("asr@xfmr_transducer")
    vocab_size, batch_size = 100, 4
    decoder_conf = dict(jot_dim=512,
                        att_dim=512,
                        nhead=8,
                        feedforward_dim=2048,
                        pos_dropout=0.1,
                        att_dropout=0.1,
                        num_layers=2)
    transform = AsrTransform(feats="fbank-log-cmvn",
                             frame_len=400,
                             frame_hop=160,
                             window="hamm")
    model = nnet_cls(input_size=80,
                     vocab_size=vocab_size,
                     blank=vocab_size - 1,
                     asr_transform=transform,
                     enc_type=enc_type,
                     enc_proj=512,
                     enc_kwargs=enc_kwargs,
                     dec_kwargs=decoder_conf)
    x, x_len, y, y_len, u = gen_egs(vocab_size, batch_size)
    out, _ = model(x, x_len, y, y_len)
    # lattice output: last two dims are (U + 1, vocab)
    assert out.shape[2:] == th.Size([u + 1, vocab_size])
# Code example 2
def test_beam_att(enh_type, enh_kwargs):
    """Build a beamforming front-end + attention ASR model and check logits."""
    nnet_cls = aps_asr_nnet("asr@enh_att")
    vocab_size, batch_size, num_channels = 100, 4, 4
    front_end = EnhTransform(feats="",
                             frame_len=512,
                             frame_hop=256,
                             window="sqrthann")
    # the time-invariant attention front-end emits a smaller feature dim
    asr_dim = 128 if enh_type == "time_invar_att" else 640
    model = nnet_cls(vocab_size=vocab_size,
                     asr_input_size=asr_dim,
                     sos=0,
                     eos=1,
                     ctc=True,
                     enh_type=enh_type,
                     enh_kwargs=enh_kwargs,
                     asr_transform=None,
                     enh_transform=front_end,
                     att_type="dot",
                     att_kwargs={"att_dim": 512},
                     enc_type="pytorch_rnn",
                     enc_proj=256,
                     enc_kwargs=default_rnn_enc_kwargs,
                     dec_dim=512,
                     dec_kwargs=default_rnn_dec_kwargs)
    x, x_len, y, y_len, u = gen_egs(vocab_size,
                                    batch_size,
                                    num_channels=num_channels)
    out, _, _, _ = model(x, x_len, y, y_len)
    assert out.shape == th.Size([4, u + 1, vocab_size - 1])
# Code example 3
def test_common_transducer(enc_type, enc_kwargs):
    """Build an RNN-T model with the given encoder and check its output shape."""
    nnet_cls = aps_asr_nnet("asr@transducer")
    vocab_size, batch_size = 100, 4
    decoder_conf = dict(embed_size=512,
                        enc_dim=512,
                        jot_dim=512,
                        dec_rnn="lstm",
                        dec_layers=2,
                        dec_hidden=512,
                        dec_dropout=0.1)
    transform = AsrTransform(feats="fbank-log-cmvn",
                             frame_len=400,
                             frame_hop=160,
                             window="hamm")
    # transformer/conformer encoders take no external projection layer
    xfmr_encoders = ("xfmr_abs", "xfmr_rel", "xfmr_xl", "cfmr_xl")
    model = nnet_cls(input_size=80,
                     vocab_size=vocab_size,
                     blank=vocab_size - 1,
                     asr_transform=transform,
                     enc_type=enc_type,
                     enc_proj=512 if enc_type not in xfmr_encoders else None,
                     enc_kwargs=enc_kwargs,
                     dec_kwargs=decoder_conf)
    x, x_len, y, y_len, u = gen_egs(vocab_size, batch_size)
    out, _ = model(x, x_len, y, y_len)
    assert out.shape[2:] == th.Size([u + 1, vocab_size])
# Code example 4
def test_att_encoder(enc_type, enc_kwargs):
    """Build an attention-based encoder-decoder ASR model and check logits."""
    nnet_cls = aps_asr_nnet("asr@att")
    vocab_size, batch_size = 100, 4
    transform = AsrTransform(feats="fbank-log-cmvn",
                             frame_len=400,
                             frame_hop=160,
                             window="hamm")
    model = nnet_cls(input_size=80,
                     vocab_size=vocab_size,
                     sos=0,
                     eos=1,
                     ctc=True,
                     asr_transform=transform,
                     att_type="ctx",
                     att_kwargs={"att_dim": 512},
                     enc_type=enc_type,
                     enc_proj=256,
                     enc_kwargs=enc_kwargs,
                     dec_type="rnn",
                     dec_dim=512,
                     dec_kwargs=default_rnn_dec_kwargs)
    x, x_len, y, y_len, u = gen_egs(vocab_size, batch_size)
    out, _, _, _ = model(x, x_len, y, y_len)
    assert out.shape == th.Size([4, u + 1, vocab_size - 1])
# Code example 5
def run(args):
    """Build the AM task from YAML config and launch the training worker."""
    # fix the random seed for reproducibility
    seed = set_seed(args.seed)
    if seed is not None:
        print(f"Set random seed as {seed}")

    conf, vocab = load_am_conf(args.conf, args.dict)

    print(f"Arguments in args:\n{pprint.pformat(vars(args))}", flush=True)
    print(f"Arguments in yaml:\n{pprint.pformat(conf)}", flush=True)

    asr_cls = aps_asr_nnet(conf["nnet"])
    feats_pipe = aps_transform("asr")(
        **conf["asr_transform"]) if "asr_transform" in conf else None
    front_end = aps_transform("enh")(
        **conf["enh_transform"]) if "enh_transform" in conf else None

    # pass whichever transforms are configured along to the network
    if front_end:
        nnet = asr_cls(enh_transform=front_end,
                       asr_transform=feats_pipe,
                       **conf["nnet_conf"])
    elif feats_pipe:
        nnet = asr_cls(asr_transform=feats_pipe, **conf["nnet_conf"])
    else:
        nnet = asr_cls(**conf["nnet_conf"])

    task = aps_task(conf["task"], nnet, **conf["task_conf"])
    train_worker(task, conf, vocab, args)
# Code example 6
def run(args):
    """Build the AM and hand training over to the generic trainer starter."""
    _ = set_seed(args.seed)
    conf, vocab = load_am_conf(args.conf, args.dict)

    print(f"Arguments in args:\n{pprint.pformat(vars(args))}", flush=True)
    print(f"Arguments in yaml:\n{pprint.pformat(conf)}", flush=True)

    asr_cls = aps_asr_nnet(conf["nnet"])
    feats_pipe = aps_transform("asr")(
        **conf["asr_transform"]) if "asr_transform" in conf else None
    front_end = aps_transform("enh")(
        **conf["enh_transform"]) if "enh_transform" in conf else None

    # pass whichever transforms are configured along to the network
    if front_end:
        nnet = asr_cls(enh_transform=front_end,
                       asr_transform=feats_pipe,
                       **conf["nnet_conf"])
    elif feats_pipe:
        nnet = asr_cls(asr_transform=feats_pipe, **conf["nnet_conf"])
    else:
        nnet = asr_cls(**conf["nnet_conf"])

    start_trainer(args.trainer,
                  conf,
                  nnet,
                  args,
                  reduction_tag="#tok",
                  other_loader_conf={"vocab_dict": vocab})
    # persist the vocabulary next to the checkpoints
    dump_dict(f"{args.checkpoint}/dict", vocab, reverse=False)
# Code example 7
def run(args):
    """Train an ASR acoustic model on a single node (non-distributed).

    Builds the train/dev data loaders, the network (with optional ASR
    feature and enhancement transforms), the task and the trainer from
    the YAML configuration, dumps the resolved config and vocabulary to
    the checkpoint directory, then runs training.
    """
    # set random seed for reproducibility
    seed = set_seed(args.seed)
    if seed is not None:
        print(f"Set random seed as {seed}")

    conf, vocab_dict = load_am_conf(args.conf, args.dict)
    print(f"Arguments in args:\n{pprint.pformat(vars(args))}", flush=True)
    print(f"Arguments in yaml:\n{pprint.pformat(conf)}", flush=True)
    data_conf = conf["data_conf"]
    # loader options shared between the train and valid loaders;
    # per-dataset "loader" settings from yaml override/extend them below
    load_conf = {
        "fmt": data_conf["fmt"],
        "vocab_dict": vocab_dict,
        "num_workers": args.num_workers,
        "max_batch_size": args.batch_size
    }
    load_conf.update(data_conf["loader"])
    trn_loader = aps_dataloader(train=True, **data_conf["train"], **load_conf)
    dev_loader = aps_dataloader(train=False, **data_conf["valid"], **load_conf)

    asr_cls = aps_asr_nnet(conf["nnet"])
    asr_transform = None
    enh_transform = None
    if "asr_transform" in conf:
        asr_transform = aps_transform("asr")(**conf["asr_transform"])
    if "enh_transform" in conf:
        enh_transform = aps_transform("enh")(**conf["enh_transform"])

    # forward only the configured transforms to the network constructor
    if enh_transform:
        nnet = asr_cls(enh_transform=enh_transform,
                       asr_transform=asr_transform,
                       **conf["nnet_conf"])
    elif asr_transform:
        nnet = asr_cls(asr_transform=asr_transform, **conf["nnet_conf"])
    else:
        nnet = asr_cls(**conf["nnet_conf"])

    task = aps_task(conf["task"], nnet, **conf["task_conf"])
    Trainer = aps_trainer(args.trainer, distributed=False)
    trainer = Trainer(task,
                      device_ids=args.device_id,
                      checkpoint=args.checkpoint,
                      resume=args.resume,
                      init=args.init,
                      save_interval=args.save_interval,
                      prog_interval=args.prog_interval,
                      tensorboard=args.tensorboard,
                      reduction_tag="#tok",
                      **conf["trainer_conf"])
    # dump yaml configuration (including command-line args) before training
    conf["cmd_args"] = vars(args)
    with open(f"{args.checkpoint}/train.yaml", "w") as f:
        yaml.dump(conf, f)
    # dump vocabulary next to the checkpoints
    dump_dict(f"{args.checkpoint}/dict", vocab_dict, reverse=False)

    trainer.run(trn_loader,
                dev_loader,
                num_epochs=args.epochs,
                eval_interval=args.eval_interval)
# Code example 8 — file: distributed_train_lm.py (project: xmpx/aps)
def run(args):
    """Build the LM task from YAML config and launch the training worker."""
    # fix the random seed for reproducibility
    seed = set_seed(args.seed)
    if seed is not None:
        print(f"Set random seed as {seed}")

    conf, vocab = load_lm_conf(args.conf, args.dict)
    print(f"Arguments in args:\n{pprint.pformat(vars(args))}", flush=True)
    print(f"Arguments in yaml:\n{pprint.pformat(conf)}", flush=True)

    lm = aps_asr_nnet(conf["nnet"])(**conf["nnet_conf"])
    task = aps_task(conf["task"], lm, **conf["task_conf"])

    train_worker(task, conf, vocab, args)
# Code example 9
def test_mvdr_att(att_type, att_kwargs):
    """Build an MVDR mask front-end + attention ASR model and check logits."""
    nnet_cls = aps_asr_nnet("asr@enh_att")
    vocab_size, batch_size, num_channels = 100, 4, 4
    mvdr_conf = dict(rnn="lstm",
                     num_layers=2,
                     rnn_inp_proj=512,
                     hidden_size=512,
                     dropout=0.2,
                     bidirectional=False,
                     mvdr_att_dim=512,
                     mask_norm=True,
                     num_bins=257)
    feats_pipe = AsrTransform(feats="abs-mel-log-cmvn",
                              frame_len=400,
                              frame_hop=160,
                              window="hamm")
    front_end = EnhTransform(feats="spectrogram-log-cmvn-ipd",
                             frame_len=400,
                             frame_hop=160,
                             window="hamm",
                             ipd_index="0,1;0,2;0,3",
                             cos_ipd=True)
    # 257 bins x 4 features (spectrogram + 3 IPD pairs)
    model = nnet_cls(enh_input_size=257 * 4,
                     vocab_size=vocab_size,
                     sos=0,
                     eos=1,
                     ctc=True,
                     enh_type="rnn_mask_mvdr",
                     enh_kwargs=mvdr_conf,
                     asr_transform=feats_pipe,
                     enh_transform=front_end,
                     att_type=att_type,
                     att_kwargs=att_kwargs,
                     enc_type="pytorch_rnn",
                     enc_proj=256,
                     enc_kwargs=default_rnn_enc_kwargs,
                     dec_dim=512,
                     dec_kwargs=default_rnn_dec_kwargs)
    x, x_len, y, y_len, u = gen_egs(vocab_size,
                                    batch_size,
                                    num_channels=num_channels)
    out, _, _, _ = model(x, x_len, y, y_len)
    assert out.shape == th.Size([4, u + 1, vocab_size - 1])
# Code example 10 — file: train_lm.py (project: xmpx/aps)
def run(args):
    """Train an NN language model on a single node (non-distributed).

    Builds the train/dev data loaders, the network, the task and the
    trainer from the YAML configuration, dumps the resolved config to
    the checkpoint directory, then runs training.
    """
    # set random seed for reproducibility
    seed = set_seed(args.seed)
    if seed is not None:
        print(f"Set random seed as {seed}")

    conf, vocab = load_lm_conf(args.conf, args.dict)
    print(f"Arguments in args:\n{pprint.pformat(vars(args))}", flush=True)
    print(f"Arguments in yaml:\n{pprint.pformat(conf)}", flush=True)

    data_conf = conf["data_conf"]
    # loader options shared between the train and valid loaders;
    # sos/eos token ids come from the vocabulary itself
    load_conf = {
        "vocab_dict": vocab,
        "num_workers": args.num_workers,
        "sos": vocab["<sos>"],
        "eos": vocab["<eos>"],
        "fmt": data_conf["fmt"],
        "max_batch_size": args.batch_size
    }
    load_conf.update(data_conf["loader"])
    trn_loader = aps_dataloader(train=True, **data_conf["train"], **load_conf)
    dev_loader = aps_dataloader(train=False, **data_conf["valid"], **load_conf)

    nnet = aps_asr_nnet(conf["nnet"])(**conf["nnet_conf"])
    task = aps_task(conf["task"], nnet, **conf["task_conf"])

    Trainer = aps_trainer(args.trainer, distributed=False)
    trainer = Trainer(task,
                      device_ids=args.device_id,
                      checkpoint=args.checkpoint,
                      resume=args.resume,
                      save_interval=args.save_interval,
                      prog_interval=args.prog_interval,
                      tensorboard=args.tensorboard,
                      reduction_tag="#tok",
                      **conf["trainer_conf"])
    # dump configuration (including command-line args) before training
    conf["cmd_args"] = vars(args)
    with open(f"{args.checkpoint}/train.yaml", "w") as f:
        yaml.dump(conf, f)

    trainer.run(trn_loader,
                dev_loader,
                num_epochs=args.epochs,
                eval_interval=args.eval_interval)
# Code example 11 — file: train_lm.py (project: yt752/aps)
def run(args):
    """Build the LM and launch training via the generic trainer starter."""
    _ = set_seed(args.seed)
    conf, vocab = load_lm_conf(args.conf, args.dict)

    print(f"Arguments in args:\n{pprint.pformat(vars(args))}", flush=True)
    print(f"Arguments in yaml:\n{pprint.pformat(conf)}", flush=True)

    lm = aps_asr_nnet(conf["nnet"])(**conf["nnet_conf"])

    # sos/eos token ids for the loader come from the vocabulary itself
    start_trainer(args.trainer,
                  conf,
                  lm,
                  args,
                  reduction_tag="#tok",
                  other_loader_conf={
                      "vocab_dict": vocab,
                      "sos": vocab["<sos>"],
                      "eos": vocab["<eos>"],
                  })
    # persist the vocabulary next to the checkpoints
    dump_dict(f"{args.checkpoint}/dict", vocab, reverse=False)
# Code example 12 — file: eval.py (project: xmpx/aps)
    def _load(self,
              cpt_dir: str,
              cpt_tag: str = "best",
              task: str = "asr") -> Tuple[int, nn.Module, Dict]:
        """Load a trained checkpoint and rebuild its network.

        Reads ``{cpt_tag}.pt.tar`` and ``train.yaml`` from ``cpt_dir``,
        reconstructs the network (with any configured transforms) and
        restores its weights. Also sets ``self.accept_raw`` as a side
        effect.

        Args:
            cpt_dir: checkpoint directory
            cpt_tag: checkpoint tag to load (default "best")
            task: "asr" or "sse", selecting which nnet registry to use

        Returns:
            Tuple of (epoch stored in checkpoint, network, parsed config).

        Raises:
            ValueError: if ``task`` is not "asr" or "sse".
        """
        if task not in ["asr", "sse"]:
            raise ValueError(f"Unknown task name: {task}")
        cpt_dir = pathlib.Path(cpt_dir)
        # load checkpoint on CPU; caller moves the model to a device later
        cpt = th.load(cpt_dir / f"{cpt_tag}.pt.tar", map_location="cpu")
        with open(cpt_dir / "train.yaml", "r") as f:
            conf = yaml.full_load(f)
            if task == "asr":
                net_cls = aps_asr_nnet(conf["nnet"])
            else:
                net_cls = aps_sse_nnet(conf["nnet"])
        asr_transform = None
        enh_transform = None
        self.accept_raw = False
        if "asr_transform" in conf:
            asr_transform = aps_transform("asr")(**conf["asr_transform"])
            # NOTE(review): presumably spectra_index != -1 means the transform
            # includes its own STFT/spectra stage so raw waveforms can be fed
            # directly — confirm against AsrTransform's definition
            self.accept_raw = asr_transform.spectra_index != -1
        if "enh_transform" in conf:
            enh_transform = aps_transform("enh")(**conf["enh_transform"])
            self.accept_raw = True
        # forward only the configured transforms to the network constructor
        if enh_transform and asr_transform:
            nnet = net_cls(enh_transform=enh_transform,
                           asr_transform=asr_transform,
                           **conf["nnet_conf"])
        elif asr_transform:
            nnet = net_cls(asr_transform=asr_transform, **conf["nnet_conf"])
        elif enh_transform:
            nnet = net_cls(enh_transform=enh_transform, **conf["nnet_conf"])
        else:
            nnet = net_cls(**conf["nnet_conf"])

        nnet.load_state_dict(cpt["model_state"])
        return cpt["epoch"], nnet, conf
# Code example 13
def test_xfmr_encoder(enc_type, enc_kwargs):
    """Build a transformer-decoder ASR model with the given encoder."""
    nnet_cls = aps_asr_nnet("asr@xfmr")
    vocab_size, batch_size = 100, 4
    transform = AsrTransform(feats="fbank-log-cmvn",
                             frame_len=400,
                             frame_hop=160,
                             window="hamm")
    # transformer/conformer encoders take no external projection layer
    xfmr_encoders = ("xfmr_abs", "xfmr_rel", "xfmr_xl", "cfmr_xl")
    model = nnet_cls(input_size=80,
                     vocab_size=vocab_size,
                     sos=0,
                     eos=1,
                     ctc=True,
                     asr_transform=transform,
                     enc_type=enc_type,
                     enc_proj=None if enc_type in xfmr_encoders else 512,
                     enc_kwargs=enc_kwargs,
                     dec_type="xfmr_abs",
                     dec_kwargs=default_xfmr_dec_kwargs)
    x, x_len, y, y_len, u = gen_egs(vocab_size, batch_size)
    out, _, _, _ = model(x, x_len, y, y_len)
    assert out.shape == th.Size([4, u + 1, vocab_size - 1])