def test_xfmr_transducer(enc_type, enc_kwargs):
    # Transformer transducer: joint network output should cover U+1 label steps over the vocabulary
    nnet_cls = aps_asr_nnet("asr@xfmr_transducer")
    vocab_size = 100
    batch_size = 4
    dec_kwargs = {
        "jot_dim": 512,
        "att_dim": 512,
        "nhead": 8,
        "feedforward_dim": 2048,
        "pos_dropout": 0.1,
        "att_dropout": 0.1,
        "num_layers": 2
    }
    asr_transform = AsrTransform(feats="fbank-log-cmvn",
                                 frame_len=400,
                                 frame_hop=160,
                                 window="hamm")
    xfmr_rnnt = nnet_cls(input_size=80,
                         vocab_size=vocab_size,
                         blank=vocab_size - 1,
                         asr_transform=asr_transform,
                         enc_type=enc_type,
                         enc_proj=512,
                         enc_kwargs=enc_kwargs,
                         dec_kwargs=dec_kwargs)
    x, x_len, y, y_len, u = gen_egs(vocab_size, batch_size)
    z, _ = xfmr_rnnt(x, x_len, y, y_len)
    assert z.shape[2:] == th.Size([u + 1, vocab_size])
def test_beam_att(enh_type, enh_kwargs):
    # Multi-channel enhancement front-end + attention-based ASR back-end
    nnet_cls = aps_asr_nnet("asr@enh_att")
    vocab_size = 100
    batch_size = 4
    num_channels = 4
    enh_transform = EnhTransform(feats="",
                                 frame_len=512,
                                 frame_hop=256,
                                 window="sqrthann")
    beam_att_asr = nnet_cls(
        vocab_size=vocab_size,
        asr_input_size=640 if enh_type != "time_invar_att" else 128,
        sos=0,
        eos=1,
        ctc=True,
        enh_type=enh_type,
        enh_kwargs=enh_kwargs,
        asr_transform=None,
        enh_transform=enh_transform,
        att_type="dot",
        att_kwargs={"att_dim": 512},
        enc_type="pytorch_rnn",
        enc_proj=256,
        enc_kwargs=default_rnn_enc_kwargs,
        dec_dim=512,
        dec_kwargs=default_rnn_dec_kwargs)
    x, x_len, y, y_len, u = gen_egs(vocab_size,
                                    batch_size,
                                    num_channels=num_channels)
    z, _, _, _ = beam_att_asr(x, x_len, y, y_len)
    assert z.shape == th.Size([4, u + 1, vocab_size - 1])
def test_common_transducer(enc_type, enc_kwargs):
    # RNN-T style transducer with a configurable encoder (enc_proj only used for non-transformer encoders)
    nnet_cls = aps_asr_nnet("asr@transducer")
    vocab_size = 100
    batch_size = 4
    dec_kwargs = {
        "embed_size": 512,
        "enc_dim": 512,
        "jot_dim": 512,
        "dec_rnn": "lstm",
        "dec_layers": 2,
        "dec_hidden": 512,
        "dec_dropout": 0.1
    }
    asr_transform = AsrTransform(feats="fbank-log-cmvn",
                                 frame_len=400,
                                 frame_hop=160,
                                 window="hamm")
    xfmr_encoders = ["xfmr_abs", "xfmr_rel", "xfmr_xl", "cfmr_xl"]
    rnnt = nnet_cls(input_size=80,
                    vocab_size=vocab_size,
                    blank=vocab_size - 1,
                    asr_transform=asr_transform,
                    enc_type=enc_type,
                    enc_proj=None if enc_type in xfmr_encoders else 512,
                    enc_kwargs=enc_kwargs,
                    dec_kwargs=dec_kwargs)
    x, x_len, y, y_len, u = gen_egs(vocab_size, batch_size)
    z, _ = rnnt(x, x_len, y, y_len)
    assert z.shape[2:] == th.Size([u + 1, vocab_size])
def test_att_encoder(enc_type, enc_kwargs):
    # Attention-based encoder-decoder ASR with a configurable encoder
    nnet_cls = aps_asr_nnet("asr@att")
    vocab_size = 100
    batch_size = 4
    asr_transform = AsrTransform(feats="fbank-log-cmvn",
                                 frame_len=400,
                                 frame_hop=160,
                                 window="hamm")
    att_asr = nnet_cls(input_size=80,
                       vocab_size=vocab_size,
                       sos=0,
                       eos=1,
                       ctc=True,
                       asr_transform=asr_transform,
                       att_type="ctx",
                       att_kwargs={"att_dim": 512},
                       enc_type=enc_type,
                       enc_proj=256,
                       enc_kwargs=enc_kwargs,
                       dec_type="rnn",
                       dec_dim=512,
                       dec_kwargs=default_rnn_dec_kwargs)
    x, x_len, y, y_len, u = gen_egs(vocab_size, batch_size)
    z, _, _, _ = att_asr(x, x_len, y, y_len)
    assert z.shape == th.Size([4, u + 1, vocab_size - 1])
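# The encoder tests above receive (enc_type, enc_kwargs) from pytest parametrization.
# A minimal sketch of what that decoration could look like (an assumption for illustration;
# the repo's actual parameter lists may differ). It reuses "pytorch_rnn" and
# default_rnn_enc_kwargs, which already appear in the other tests in this file.
import pytest  # normally imported at the top of the test module


@pytest.mark.parametrize("enc_type,enc_kwargs",
                         [("pytorch_rnn", default_rnn_enc_kwargs)])
def test_att_encoder_sketch(enc_type, enc_kwargs):
    # delegate to the shape-checking test defined above
    test_att_encoder(enc_type, enc_kwargs)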
def run(args):
    # set random seed
    seed = set_seed(args.seed)
    if seed is not None:
        print(f"Set random seed as {seed}")

    conf, vocab = load_am_conf(args.conf, args.dict)
    print(f"Arguments in args:\n{pprint.pformat(vars(args))}", flush=True)
    print(f"Arguments in yaml:\n{pprint.pformat(conf)}", flush=True)

    asr_cls = aps_asr_nnet(conf["nnet"])
    asr_transform = None
    enh_transform = None
    if "asr_transform" in conf:
        asr_transform = aps_transform("asr")(**conf["asr_transform"])
    if "enh_transform" in conf:
        enh_transform = aps_transform("enh")(**conf["enh_transform"])

    if enh_transform:
        nnet = asr_cls(enh_transform=enh_transform,
                       asr_transform=asr_transform,
                       **conf["nnet_conf"])
    elif asr_transform:
        nnet = asr_cls(asr_transform=asr_transform, **conf["nnet_conf"])
    else:
        nnet = asr_cls(**conf["nnet_conf"])

    task = aps_task(conf["task"], nnet, **conf["task_conf"])
    train_worker(task, conf, vocab, args)
def run(args):
    _ = set_seed(args.seed)
    conf, vocab = load_am_conf(args.conf, args.dict)
    print(f"Arguments in args:\n{pprint.pformat(vars(args))}", flush=True)
    print(f"Arguments in yaml:\n{pprint.pformat(conf)}", flush=True)

    asr_cls = aps_asr_nnet(conf["nnet"])
    asr_transform = None
    enh_transform = None
    if "asr_transform" in conf:
        asr_transform = aps_transform("asr")(**conf["asr_transform"])
    if "enh_transform" in conf:
        enh_transform = aps_transform("enh")(**conf["enh_transform"])

    if enh_transform:
        nnet = asr_cls(enh_transform=enh_transform,
                       asr_transform=asr_transform,
                       **conf["nnet_conf"])
    elif asr_transform:
        nnet = asr_cls(asr_transform=asr_transform, **conf["nnet_conf"])
    else:
        nnet = asr_cls(**conf["nnet_conf"])

    other_loader_conf = {
        "vocab_dict": vocab,
    }
    start_trainer(args.trainer,
                  conf,
                  nnet,
                  args,
                  reduction_tag="#tok",
                  other_loader_conf=other_loader_conf)
    dump_dict(f"{args.checkpoint}/dict", vocab, reverse=False)
def run(args):
    # set random seed
    seed = set_seed(args.seed)
    if seed is not None:
        print(f"Set random seed as {seed}")

    conf, vocab_dict = load_am_conf(args.conf, args.dict)
    print(f"Arguments in args:\n{pprint.pformat(vars(args))}", flush=True)
    print(f"Arguments in yaml:\n{pprint.pformat(conf)}", flush=True)

    data_conf = conf["data_conf"]
    load_conf = {
        "fmt": data_conf["fmt"],
        "vocab_dict": vocab_dict,
        "num_workers": args.num_workers,
        "max_batch_size": args.batch_size
    }
    load_conf.update(data_conf["loader"])
    trn_loader = aps_dataloader(train=True, **data_conf["train"], **load_conf)
    dev_loader = aps_dataloader(train=False, **data_conf["valid"], **load_conf)

    asr_cls = aps_asr_nnet(conf["nnet"])
    asr_transform = None
    enh_transform = None
    if "asr_transform" in conf:
        asr_transform = aps_transform("asr")(**conf["asr_transform"])
    if "enh_transform" in conf:
        enh_transform = aps_transform("enh")(**conf["enh_transform"])

    if enh_transform:
        nnet = asr_cls(enh_transform=enh_transform,
                       asr_transform=asr_transform,
                       **conf["nnet_conf"])
    elif asr_transform:
        nnet = asr_cls(asr_transform=asr_transform, **conf["nnet_conf"])
    else:
        nnet = asr_cls(**conf["nnet_conf"])

    task = aps_task(conf["task"], nnet, **conf["task_conf"])
    Trainer = aps_trainer(args.trainer, distributed=False)
    trainer = Trainer(task,
                      device_ids=args.device_id,
                      checkpoint=args.checkpoint,
                      resume=args.resume,
                      init=args.init,
                      save_interval=args.save_interval,
                      prog_interval=args.prog_interval,
                      tensorboard=args.tensorboard,
                      reduction_tag="#tok",
                      **conf["trainer_conf"])

    # dump yaml configurations
    conf["cmd_args"] = vars(args)
    with open(f"{args.checkpoint}/train.yaml", "w") as f:
        yaml.dump(conf, f)
    # dump dict
    dump_dict(f"{args.checkpoint}/dict", vocab_dict, reverse=False)

    trainer.run(trn_loader,
                dev_loader,
                num_epochs=args.epochs,
                eval_interval=args.eval_interval)
def run(args):
    # set random seed
    seed = set_seed(args.seed)
    if seed is not None:
        print(f"Set random seed as {seed}")

    conf, vocab = load_lm_conf(args.conf, args.dict)
    print(f"Arguments in args:\n{pprint.pformat(vars(args))}", flush=True)
    print(f"Arguments in yaml:\n{pprint.pformat(conf)}", flush=True)

    nnet = aps_asr_nnet(conf["nnet"])(**conf["nnet_conf"])
    task = aps_task(conf["task"], nnet, **conf["task_conf"])
    train_worker(task, conf, vocab, args)
def test_mvdr_att(att_type, att_kwargs):
    # Mask-based MVDR beamforming front-end + attention ASR back-end
    nnet_cls = aps_asr_nnet("asr@enh_att")
    vocab_size = 100
    batch_size = 4
    num_channels = 4
    enh_kwargs = {
        "rnn": "lstm",
        "num_layers": 2,
        "rnn_inp_proj": 512,
        "hidden_size": 512,
        "dropout": 0.2,
        "bidirectional": False,
        "mvdr_att_dim": 512,
        "mask_norm": True,
        "num_bins": 257
    }
    asr_transform = AsrTransform(feats="abs-mel-log-cmvn",
                                 frame_len=400,
                                 frame_hop=160,
                                 window="hamm")
    enh_transform = EnhTransform(feats="spectrogram-log-cmvn-ipd",
                                 frame_len=400,
                                 frame_hop=160,
                                 window="hamm",
                                 ipd_index="0,1;0,2;0,3",
                                 cos_ipd=True)
    mvdr_att_asr = nnet_cls(enh_input_size=257 * 4,
                            vocab_size=vocab_size,
                            sos=0,
                            eos=1,
                            ctc=True,
                            enh_type="rnn_mask_mvdr",
                            enh_kwargs=enh_kwargs,
                            asr_transform=asr_transform,
                            enh_transform=enh_transform,
                            att_type=att_type,
                            att_kwargs=att_kwargs,
                            enc_type="pytorch_rnn",
                            enc_proj=256,
                            enc_kwargs=default_rnn_enc_kwargs,
                            dec_dim=512,
                            dec_kwargs=default_rnn_dec_kwargs)
    x, x_len, y, y_len, u = gen_egs(vocab_size,
                                    batch_size,
                                    num_channels=num_channels)
    z, _, _, _ = mvdr_att_asr(x, x_len, y, y_len)
    assert z.shape == th.Size([4, u + 1, vocab_size - 1])
def run(args):
    # set random seed
    seed = set_seed(args.seed)
    if seed is not None:
        print(f"Set random seed as {seed}")

    conf, vocab = load_lm_conf(args.conf, args.dict)
    print(f"Arguments in args:\n{pprint.pformat(vars(args))}", flush=True)
    print(f"Arguments in yaml:\n{pprint.pformat(conf)}", flush=True)

    data_conf = conf["data_conf"]
    load_conf = {
        "vocab_dict": vocab,
        "num_workers": args.num_workers,
        "sos": vocab["<sos>"],
        "eos": vocab["<eos>"],
        "fmt": data_conf["fmt"],
        "max_batch_size": args.batch_size
    }
    load_conf.update(data_conf["loader"])
    trn_loader = aps_dataloader(train=True, **data_conf["train"], **load_conf)
    dev_loader = aps_dataloader(train=False, **data_conf["valid"], **load_conf)

    nnet = aps_asr_nnet(conf["nnet"])(**conf["nnet_conf"])
    task = aps_task(conf["task"], nnet, **conf["task_conf"])

    Trainer = aps_trainer(args.trainer, distributed=False)
    trainer = Trainer(task,
                      device_ids=args.device_id,
                      checkpoint=args.checkpoint,
                      resume=args.resume,
                      save_interval=args.save_interval,
                      prog_interval=args.prog_interval,
                      tensorboard=args.tensorboard,
                      reduction_tag="#tok",
                      **conf["trainer_conf"])

    # dump configurations
    conf["cmd_args"] = vars(args)
    with open(f"{args.checkpoint}/train.yaml", "w") as f:
        yaml.dump(conf, f)

    trainer.run(trn_loader,
                dev_loader,
                num_epochs=args.epochs,
                eval_interval=args.eval_interval)
def run(args):
    _ = set_seed(args.seed)
    conf, vocab = load_lm_conf(args.conf, args.dict)
    print(f"Arguments in args:\n{pprint.pformat(vars(args))}", flush=True)
    print(f"Arguments in yaml:\n{pprint.pformat(conf)}", flush=True)

    nnet = aps_asr_nnet(conf["nnet"])(**conf["nnet_conf"])
    other_loader_conf = {
        "vocab_dict": vocab,
        "sos": vocab["<sos>"],
        "eos": vocab["<eos>"],
    }
    start_trainer(args.trainer,
                  conf,
                  nnet,
                  args,
                  reduction_tag="#tok",
                  other_loader_conf=other_loader_conf)
    dump_dict(f"{args.checkpoint}/dict", vocab, reverse=False)
def _load(self,
          cpt_dir: str,
          cpt_tag: str = "best",
          task: str = "asr") -> Tuple[int, nn.Module, Dict]:
    if task not in ["asr", "sse"]:
        raise ValueError(f"Unknown task name: {task}")
    cpt_dir = pathlib.Path(cpt_dir)
    # load checkpoint
    cpt = th.load(cpt_dir / f"{cpt_tag}.pt.tar", map_location="cpu")
    with open(cpt_dir / "train.yaml", "r") as f:
        conf = yaml.full_load(f)

    if task == "asr":
        net_cls = aps_asr_nnet(conf["nnet"])
    else:
        net_cls = aps_sse_nnet(conf["nnet"])

    asr_transform = None
    enh_transform = None
    self.accept_raw = False
    if "asr_transform" in conf:
        asr_transform = aps_transform("asr")(**conf["asr_transform"])
        # if no STFT layer
        self.accept_raw = asr_transform.spectra_index != -1
    if "enh_transform" in conf:
        enh_transform = aps_transform("enh")(**conf["enh_transform"])
        self.accept_raw = True

    if enh_transform and asr_transform:
        nnet = net_cls(enh_transform=enh_transform,
                       asr_transform=asr_transform,
                       **conf["nnet_conf"])
    elif asr_transform:
        nnet = net_cls(asr_transform=asr_transform, **conf["nnet_conf"])
    elif enh_transform:
        nnet = net_cls(enh_transform=enh_transform, **conf["nnet_conf"])
    else:
        nnet = net_cls(**conf["nnet_conf"])

    nnet.load_state_dict(cpt["model_state"])
    return cpt["epoch"], nnet, conf
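# Usage sketch (assumption, not taken from the source): a caller typically unpacks the
# tuple returned by _load and puts the network into eval mode before decoding. The
# wrapper object and checkpoint directory below are hypothetical.
#
#     epoch, nnet, conf = wrapper._load("exp/asr/1a", cpt_tag="best", task="asr")
#     nnet.eval()
#     print(f"Loaded checkpoint trained for {epoch} epochs")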
def test_xfmr_encoder(enc_type, enc_kwargs):
    # Transformer-decoder ASR with a configurable encoder (enc_proj only used for non-transformer encoders)
    nnet_cls = aps_asr_nnet("asr@xfmr")
    vocab_size = 100
    batch_size = 4
    asr_transform = AsrTransform(feats="fbank-log-cmvn",
                                 frame_len=400,
                                 frame_hop=160,
                                 window="hamm")
    xfmr_encoders = ["xfmr_abs", "xfmr_rel", "xfmr_xl", "cfmr_xl"]
    xfmr_asr = nnet_cls(
        input_size=80,
        vocab_size=vocab_size,
        sos=0,
        eos=1,
        ctc=True,
        asr_transform=asr_transform,
        enc_type=enc_type,
        enc_proj=512 if enc_type not in xfmr_encoders else None,
        enc_kwargs=enc_kwargs,
        dec_type="xfmr_abs",
        dec_kwargs=default_xfmr_dec_kwargs)
    x, x_len, y, y_len, u = gen_egs(vocab_size, batch_size)
    z, _, _, _ = xfmr_asr(x, x_len, y, y_len)
    assert z.shape == th.Size([4, u + 1, vocab_size - 1])
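# The tests above rely on a shared gen_egs helper that returns (x, x_len, y, y_len, u).
# Below is a minimal sketch of what such a helper could look like, under the assumption
# that x is a batch of random waveforms and y a batch of random label sequences; the
# repo's actual gen_egs may differ ("th" is torch, imported at the top of the test module).
def _gen_egs_sketch(vocab_size, batch_size, num_channels=1):
    num_samples = 16000 * 2  # roughly 2s of 16kHz audio
    if num_channels == 1:
        x = th.rand(batch_size, num_samples)
    else:
        x = th.rand(batch_size, num_channels, num_samples)
    x_len = th.full((batch_size,), num_samples, dtype=th.int64)
    # random label sequences of length u, avoiding sos/eos/blank ids at the vocabulary edges
    u = 16
    y = th.randint(1, vocab_size - 1, (batch_size, u))
    y_len = th.full((batch_size,), u, dtype=th.int64)
    return x, x_len, y, y_len, u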