def run(args):
    # set random seed
    seed = set_seed(args.seed)
    if seed is not None:
        print(f"Set random seed as {seed}")

    conf, vocab_dict = load_am_conf(args.conf, args.dict)
    print(f"Arguments in args:\n{pprint.pformat(vars(args))}", flush=True)
    print(f"Arguments in yaml:\n{pprint.pformat(conf)}", flush=True)

    # training & validation dataloaders
    data_conf = conf["data_conf"]
    load_conf = {
        "fmt": data_conf["fmt"],
        "vocab_dict": vocab_dict,
        "num_workers": args.num_workers,
        "max_batch_size": args.batch_size
    }
    load_conf.update(data_conf["loader"])
    trn_loader = aps_dataloader(train=True, **data_conf["train"], **load_conf)
    dev_loader = aps_dataloader(train=False, **data_conf["valid"], **load_conf)

    # build the AM, with optional ASR/enhancement front-end transforms
    asr_cls = aps_asr_nnet(conf["nnet"])
    asr_transform = None
    enh_transform = None
    if "asr_transform" in conf:
        asr_transform = aps_transform("asr")(**conf["asr_transform"])
    if "enh_transform" in conf:
        enh_transform = aps_transform("enh")(**conf["enh_transform"])

    if enh_transform:
        nnet = asr_cls(enh_transform=enh_transform,
                       asr_transform=asr_transform,
                       **conf["nnet_conf"])
    elif asr_transform:
        nnet = asr_cls(asr_transform=asr_transform, **conf["nnet_conf"])
    else:
        nnet = asr_cls(**conf["nnet_conf"])

    # wrap the network in a task and construct the (single-node) trainer
    task = aps_task(conf["task"], nnet, **conf["task_conf"])
    Trainer = aps_trainer(args.trainer, distributed=False)
    trainer = Trainer(task,
                      device_ids=args.device_id,
                      checkpoint=args.checkpoint,
                      resume=args.resume,
                      init=args.init,
                      save_interval=args.save_interval,
                      prog_interval=args.prog_interval,
                      tensorboard=args.tensorboard,
                      reduction_tag="#tok",
                      **conf["trainer_conf"])

    # dump yaml configurations
    conf["cmd_args"] = vars(args)
    with open(f"{args.checkpoint}/train.yaml", "w") as f:
        yaml.dump(conf, f)
    # dump dict
    dump_dict(f"{args.checkpoint}/dict", vocab_dict, reverse=False)

    trainer.run(trn_loader,
                dev_loader,
                num_epochs=args.epochs,
                eval_interval=args.eval_interval)
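# Illustrative only: a minimal argparse front end for run() above. The flag
# names are assumptions inferred from the args.* attributes that run() reads;
# the toolkit's real entry script may define them differently.
import argparse

def _make_parser():
    parser = argparse.ArgumentParser(
        description="Command to train an AM (sketch)")
    parser.add_argument("--conf", required=True, help="training yaml")
    parser.add_argument("--dict", required=True, help="vocabulary dict")
    parser.add_argument("--seed", type=int, default=777)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--num-workers", type=int, default=4)
    parser.add_argument("--trainer", default="ddp")
    parser.add_argument("--device-id", default="0")
    parser.add_argument("--checkpoint", required=True)
    parser.add_argument("--resume", default="")
    parser.add_argument("--init", default="")
    parser.add_argument("--save-interval", type=int, default=-1)
    parser.add_argument("--prog-interval", type=int, default=100)
    parser.add_argument("--tensorboard", action="store_true")
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--eval-interval", type=int, default=-1)
    return parser

# run(_make_parser().parse_args())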
def train_worker(task, conf, vocab_dict, args):
    """
    Initialize training workers
    """
    # init torch/horovod backend
    distributed.init(args.distributed)
    rank = distributed.rank()
    Trainer = aps_trainer(args.trainer, distributed=True)
    # construct trainer: torch.distributed.launch will provide the
    # environment variables and requires that you use init_method="env://"
    trainer = Trainer(task,
                      rank=rank,
                      device_ids=args.device_ids,
                      checkpoint=args.checkpoint,
                      resume=args.resume,
                      init=args.init,
                      save_interval=args.save_interval,
                      prog_interval=args.prog_interval,
                      tensorboard=args.tensorboard,
                      reduction_tag="#tok",
                      **conf["trainer_conf"])

    # dump configurations (rank 0 only)
    if rank == 0:
        conf["cmd_args"] = vars(args)
        with open(f"{args.checkpoint}/train.yaml", "w") as f:
            yaml.dump(conf, f)
        dump_dict(f"{args.checkpoint}/dict", vocab_dict, reverse=False)

    num_process = len(args.device_ids.split(","))
    if num_process != distributed.world_size():
        raise RuntimeError(f"Number of processes != world size: {num_process} "
                           f"vs {distributed.world_size()}")

    # shard dataloader workers & batch size across the processes
    data_conf = conf["data_conf"]
    load_conf = {
        "fmt": data_conf["fmt"],
        "vocab_dict": vocab_dict,
        "num_workers": args.num_workers // num_process
    }
    load_conf.update(data_conf["loader"])
    trn_loader = aps_dataloader(train=True,
                                distributed=True,
                                max_batch_size=args.batch_size // num_process,
                                **load_conf,
                                **data_conf["train"])
    dev_loader = aps_dataloader(train=False,
                                distributed=False,
                                max_batch_size=args.batch_size //
                                args.dev_batch_factor,
                                **load_conf,
                                **data_conf["valid"])
    trainer.run(trn_loader,
                dev_loader,
                num_epochs=args.epochs,
                eval_interval=args.eval_interval)
def train_worker(task, conf, args):
    """
    Initialize training workers
    """
    # init torch/horovod backend
    distributed.init(args.distributed)
    rank = distributed.rank()
    Trainer = aps_trainer(args.trainer, distributed=True)
    trainer = Trainer(task,
                      rank=rank,
                      device_ids=args.device_ids,
                      checkpoint=args.checkpoint,
                      resume=args.resume,
                      init=args.init,
                      save_interval=args.save_interval,
                      prog_interval=args.prog_interval,
                      tensorboard=args.tensorboard,
                      **conf["trainer_conf"])

    # dump configurations (rank 0 only)
    if rank == 0:
        conf["cmd_args"] = vars(args)
        with open(f"{args.checkpoint}/train.yaml", "w") as f:
            yaml.dump(conf, f)

    num_process = len(args.device_ids.split(","))
    if num_process != distributed.world_size():
        raise RuntimeError(f"Number of processes != world size: {num_process} "
                           f"vs {distributed.world_size()}")

    data_conf = conf["data_conf"]
    load_conf = {
        "fmt": data_conf["fmt"],
        "num_workers": args.num_workers // num_process
    }
    load_conf.update(data_conf["loader"])
    trn_loader = aps_dataloader(train=True,
                                distributed=True,
                                batch_size=args.batch_size // num_process,
                                **load_conf,
                                **data_conf["train"])
    dev_loader = aps_dataloader(train=False,
                                distributed=False,
                                batch_size=args.batch_size //
                                args.dev_batch_factor,
                                **load_conf,
                                **data_conf["valid"])
    if args.eval_interval <= 0:
        raise RuntimeError("For distributed training of the SE/SS model, "
                           "--eval-interval must be larger than 0")
    trainer.run(trn_loader,
                dev_loader,
                num_epochs=args.epochs,
                eval_interval=args.eval_interval)
def train_worker(task, conf, vocab_dict, args):
    """
    Initialize training workers
    """
    # init torch/horovod backend
    distributed.init(args.distributed)
    rank = distributed.rank()
    Trainer = aps_trainer(args.trainer, distributed=True)
    trainer = Trainer(task,
                      device_ids=args.device_ids,
                      checkpoint=args.checkpoint,
                      resume=args.resume,
                      save_interval=args.save_interval,
                      prog_interval=args.prog_interval,
                      tensorboard=args.tensorboard,
                      reduction_tag="#tok",
                      **conf["trainer_conf"])

    # dump configurations (rank 0 only)
    if rank == 0:
        conf["cmd_args"] = vars(args)
        with open(f"{args.checkpoint}/train.yaml", "w") as f:
            yaml.dump(conf, f)

    num_process = len(args.device_ids.split(","))
    if num_process != distributed.world_size():
        raise RuntimeError(f"Number of processes != world size: {num_process} "
                           f"vs {distributed.world_size()}")

    data_conf = conf["data_conf"]
    load_conf = {
        "vocab_dict": vocab_dict,
        "num_workers": args.num_workers // num_process,
        "sos": vocab_dict["<sos>"],
        "eos": vocab_dict["<eos>"],
        "fmt": data_conf["fmt"]
    }
    load_conf.update(data_conf["loader"])
    trn_loader = aps_dataloader(train=True,
                                distributed=True,
                                batch_size=args.batch_size // num_process,
                                **data_conf["train"],
                                **load_conf)
    dev_loader = aps_dataloader(train=False,
                                distributed=False,
                                batch_size=args.batch_size //
                                args.dev_batch_factor,
                                **data_conf["valid"],
                                **load_conf)
    trainer.run(trn_loader,
                dev_loader,
                num_epochs=args.epochs,
                eval_interval=args.eval_interval)
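# Illustrative only: the train_worker() variants above call distributed.init()
# and therefore need a working process group. The comment in the first variant
# points at torch.distributed.launch; a torch.multiprocessing.spawn launch is
# another option, sketched here under the assumption that the "torch" backend
# reads the usual env:// variables:
import os
import torch.multiprocessing as mp

def _spawn_one(local_rank, task, conf, vocab_dict, args):
    # hypothetical glue: expose rank/world size for init_method="env://"
    os.environ["RANK"] = str(local_rank)
    os.environ["LOCAL_RANK"] = str(local_rank)
    os.environ["WORLD_SIZE"] = str(len(args.device_ids.split(",")))
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    train_worker(task, conf, vocab_dict, args)

# nprocs must match len(args.device_ids.split(",")) -- see the world-size
# check inside train_worker():
# mp.spawn(_spawn_one, args=(task, conf, vocab_dict, args), nprocs=num_process)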
def run(args):
    # set random seed
    seed = set_seed(args.seed)
    if seed is not None:
        print(f"Set random seed as {seed}")

    conf = load_ss_conf(args.conf)
    print(f"Arguments in args:\n{pprint.pformat(vars(args))}", flush=True)
    print(f"Arguments in yaml:\n{pprint.pformat(conf)}", flush=True)

    data_conf = conf["data_conf"]
    load_conf = {
        "fmt": data_conf["fmt"],
        "batch_size": args.batch_size,
        "num_workers": args.num_workers,
    }
    load_conf.update(data_conf["loader"])
    trn_loader = aps_dataloader(train=True, **load_conf, **data_conf["train"])
    dev_loader = aps_dataloader(train=False, **load_conf, **data_conf["valid"])

    sse_cls = aps_sse_nnet(conf["nnet"])
    if "enh_transform" in conf:
        enh_transform = aps_transform("enh")(**conf["enh_transform"])
        nnet = sse_cls(enh_transform=enh_transform, **conf["nnet_conf"])
    else:
        nnet = sse_cls(**conf["nnet_conf"])

    task = aps_task(conf["task"], nnet, **conf["task_conf"])
    Trainer = aps_trainer(args.trainer, distributed=False)
    trainer = Trainer(task,
                      device_ids=args.device_id,
                      checkpoint=args.checkpoint,
                      resume=args.resume,
                      init=args.init,
                      save_interval=args.save_interval,
                      prog_interval=args.prog_interval,
                      tensorboard=args.tensorboard,
                      **conf["trainer_conf"])

    # dump configurations
    conf["cmd_args"] = vars(args)
    with open(f"{args.checkpoint}/train.yaml", "w") as f:
        yaml.dump(conf, f)

    trainer.run(trn_loader,
                dev_loader,
                num_epochs=args.epochs,
                eval_interval=args.eval_interval)
def run(args):
    # set random seed
    seed = set_seed(args.seed)
    if seed is not None:
        print(f"Set random seed as {seed}")

    conf, vocab = load_lm_conf(args.conf, args.dict)
    print(f"Arguments in args:\n{pprint.pformat(vars(args))}", flush=True)
    print(f"Arguments in yaml:\n{pprint.pformat(conf)}", flush=True)

    data_conf = conf["data_conf"]
    load_conf = {
        "vocab_dict": vocab,
        "num_workers": args.num_workers,
        "sos": vocab["<sos>"],
        "eos": vocab["<eos>"],
        "fmt": data_conf["fmt"],
        "max_batch_size": args.batch_size
    }
    load_conf.update(data_conf["loader"])
    trn_loader = aps_dataloader(train=True, **data_conf["train"], **load_conf)
    dev_loader = aps_dataloader(train=False, **data_conf["valid"], **load_conf)

    nnet = aps_asr_nnet(conf["nnet"])(**conf["nnet_conf"])
    task = aps_task(conf["task"], nnet, **conf["task_conf"])
    Trainer = aps_trainer(args.trainer, distributed=False)
    trainer = Trainer(task,
                      device_ids=args.device_id,
                      checkpoint=args.checkpoint,
                      resume=args.resume,
                      save_interval=args.save_interval,
                      prog_interval=args.prog_interval,
                      tensorboard=args.tensorboard,
                      reduction_tag="#tok",
                      **conf["trainer_conf"])

    # dump configurations
    conf["cmd_args"] = vars(args)
    with open(f"{args.checkpoint}/train.yaml", "w") as f:
        yaml.dump(conf, f)

    trainer.run(trn_loader,
                dev_loader,
                num_epochs=args.epochs,
                eval_interval=args.eval_interval)
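# Illustrative only: every run()/train_worker() above reads the same top-level
# yaml layout. A sketch of the structure (keys taken from the accesses above;
# values are placeholders, not a working recipe):
example_conf = {
    "data_conf": {
        "fmt": "am@raw",  # dataloader format tag
        "loader": {},     # extra kwargs merged into load_conf
        "train": {},      # training split options (e.g. wav_scp/text)
        "valid": {},      # validation split options
    },
    "nnet": "<registered-nnet-name>",
    "nnet_conf": {},
    "task": "<registered-task-name>",
    "task_conf": {},
    "trainer_conf": {},
    # optional front-end transforms picked up by the AM/SS scripts:
    # "asr_transform": {...},
    # "enh_transform": {...},
}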
def test_lm_utt_loader(batch_size, obj):
    egs_dir = "data/dataloader/lm"
    loader = aps_dataloader(fmt="lm@utt",
                            sos=1,
                            eos=2,
                            text=f"{egs_dir}/{obj}",
                            vocab_dict=load_dict(f"{egs_dir}/dict"),
                            max_batch_size=batch_size,
                            min_batch_size=batch_size)
    for egs in loader:
        assert egs["src"].shape == egs["tgt"].shape
        assert egs["src"].shape[0] == batch_size
def test_ss_online_loader(batch_size, chunk_size, num_workers):
    egs_dir = "data/dataloader/se"
    loader = aps_dataloader(fmt="se@online",
                            simu_cfg=f"{egs_dir}/online.opts",
                            sr=16000,
                            max_batch_size=batch_size,
                            chunk_size=chunk_size,
                            num_workers=num_workers)
    for egs in loader:
        assert egs["mix"].shape == th.Size([batch_size, chunk_size])
        assert len(egs["ref"]) == 2
        assert egs["ref"][0].shape == th.Size([batch_size, chunk_size])
def test_lm_bptt_loader(batch_size, obj):
    egs_dir = "data/dataloader/lm"
    loader = aps_dataloader(fmt="lm@bptt",
                            sos=1,
                            eos=2,
                            text=f"{egs_dir}/{obj}",
                            vocab_dict=load_dict(f"{egs_dir}/dict"),
                            bptt_size=10,
                            max_batch_size=batch_size)
    for egs in loader:
        print(egs)
        assert egs["src"].shape == egs["tgt"].shape
        assert egs["src"].shape == th.Size([batch_size, 10])
def test_ss_chunk_loader(batch_size, chunk_size, num_workers):
    egs_dir = "data/dataloader/se"
    # single reference
    loader = aps_dataloader(fmt="se@chunk",
                            mix_scp=f"{egs_dir}/wav.1.scp",
                            ref_scp=f"{egs_dir}/wav.1.scp",
                            sr=16000,
                            max_batch_size=batch_size,
                            chunk_size=chunk_size,
                            num_workers=num_workers)
    for egs in loader:
        assert egs["mix"].shape == th.Size([batch_size, chunk_size])
        assert egs["ref"].shape == th.Size([batch_size, chunk_size])
    # two references (comma-separated ref_scp)
    loader = aps_dataloader(fmt="se@chunk",
                            mix_scp=f"{egs_dir}/wav.1.scp",
                            ref_scp=f"{egs_dir}/wav.1.scp,{egs_dir}/wav.1.scp",
                            sr=16000,
                            max_batch_size=batch_size,
                            chunk_size=chunk_size,
                            num_workers=num_workers)
    for egs in loader:
        assert egs["mix"].shape == th.Size([batch_size, chunk_size])
        assert len(egs["ref"]) == 2
        assert egs["ref"][0].shape == th.Size([batch_size, chunk_size])
def test_am_raw_loader_const(batch_size, num_workers):
    egs_dir = "data/dataloader/am"
    loader = aps_dataloader(fmt="am@raw",
                            wav_scp=f"{egs_dir}/egs.wav.scp",
                            text=f"{egs_dir}/egs.fake.text",
                            utt2dur=f"{egs_dir}/egs.utt2dur",
                            vocab_dict=load_dict(f"{egs_dir}/dict"),
                            train=False,
                            sr=16000,
                            max_batch_size=batch_size,
                            batch_mode="constraint",
                            num_workers=num_workers)
    for egs in loader:
        for key in ["src_pad", "tgt_pad", "tgt_len", "src_len"]:
            assert key in egs
        assert egs["tgt_pad"].shape[-1] == egs["tgt_len"].max().item()
def test_am_kaldi_loader(batch_size, num_workers):
    egs_dir = "data/dataloader/am"
    loader = aps_dataloader(fmt="am@kaldi",
                            feats_scp=f"{egs_dir}/egs.fbank.scp",
                            text=f"{egs_dir}/egs.fake.text",
                            vocab_dict=load_dict(f"{egs_dir}/dict"),
                            utt2num_frames=f"{egs_dir}/egs.fbank.num_frames",
                            train=False,
                            adapt_dur=900,
                            num_workers=num_workers,
                            min_batch_size=1,
                            max_batch_size=batch_size)
    for egs in loader:
        for key in ["src_pad", "tgt_pad", "tgt_len", "src_len"]:
            assert key in egs
        assert egs["src_pad"].shape == th.Size(
            [batch_size, egs["src_len"][0].item(), 80])
        assert egs["tgt_pad"].shape == th.Size(
            [batch_size, egs["tgt_len"].max().item()])
def test_am_raw_loader(batch_size, num_workers):
    egs_dir = "data/dataloader/am"
    loader = aps_dataloader(fmt="am@raw",
                            wav_scp=f"{egs_dir}/egs.wav.scp",
                            text=f"{egs_dir}/egs.fake.text",
                            utt2dur=f"{egs_dir}/egs.utt2dur",
                            vocab_dict=load_dict(f"{egs_dir}/dict"),
                            train=False,
                            sr=16000,
                            adapt_dur=10,
                            num_workers=num_workers,
                            max_batch_size=batch_size,
                            min_batch_size=1)
    for egs in loader:
        for key in ["src_pad", "tgt_pad", "tgt_len", "src_len"]:
            assert key in egs
        assert egs["src_pad"].shape == th.Size(
            [batch_size, egs["src_len"][0].item()])
        assert egs["tgt_pad"].shape == th.Size(
            [batch_size, egs["tgt_len"].max().item()])
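# Illustrative only: the test functions above take parameters (batch_size,
# obj, chunk_size, num_workers) that pytest does not inject by itself, so the
# original suite presumably drives them via parametrization. A sketch with
# hypothetical values:
import pytest

@pytest.mark.parametrize("batch_size", [2, 4])
@pytest.mark.parametrize("num_workers", [0, 2])
def test_am_raw_loader_param(batch_size, num_workers):
    # delegates to the loader test defined above
    test_am_raw_loader(batch_size, num_workers)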