Example #1
def run(args):
    # set random seed
    seed = set_seed(args.seed)
    if seed is not None:
        print(f"Set random seed as {seed}")

    conf, vocab_dict = load_am_conf(args.conf, args.dict)
    print(f"Arguments in args:\n{pprint.pformat(vars(args))}", flush=True)
    print(f"Arguments in yaml:\n{pprint.pformat(conf)}", flush=True)
    data_conf = conf["data_conf"]
    load_conf = {
        "fmt": data_conf["fmt"],
        "vocab_dict": vocab_dict,
        "num_workers": args.num_workers,
        "max_batch_size": args.batch_size
    }
    load_conf.update(data_conf["loader"])
    trn_loader = aps_dataloader(train=True, **data_conf["train"], **load_conf)
    dev_loader = aps_dataloader(train=False, **data_conf["valid"], **load_conf)

    asr_cls = aps_asr_nnet(conf["nnet"])
    asr_transform = None
    enh_transform = None
    if "asr_transform" in conf:
        asr_transform = aps_transform("asr")(**conf["asr_transform"])
    if "enh_transform" in conf:
        enh_transform = aps_transform("enh")(**conf["enh_transform"])

    if enh_transform:
        nnet = asr_cls(enh_transform=enh_transform,
                       asr_transform=asr_transform,
                       **conf["nnet_conf"])
    elif asr_transform:
        nnet = asr_cls(asr_transform=asr_transform, **conf["nnet_conf"])
    else:
        nnet = asr_cls(**conf["nnet_conf"])

    task = aps_task(conf["task"], nnet, **conf["task_conf"])
    Trainer = aps_trainer(args.trainer, distributed=False)
    trainer = Trainer(task,
                      device_ids=args.device_id,
                      checkpoint=args.checkpoint,
                      resume=args.resume,
                      init=args.init,
                      save_interval=args.save_interval,
                      prog_interval=args.prog_interval,
                      tensorboard=args.tensorboard,
                      reduction_tag="#tok",
                      **conf["trainer_conf"])
    # dump yaml configurations
    conf["cmd_args"] = vars(args)
    with open(f"{args.checkpoint}/train.yaml", "w") as f:
        yaml.dump(conf, f)
    # dump dict
    dump_dict(f"{args.checkpoint}/dict", vocab_dict, reverse=False)

    trainer.run(trn_loader,
                dev_loader,
                num_epochs=args.epochs,
                eval_interval=args.eval_interval)
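
For context, run() expects an argparse-style namespace carrying the attributes it reads (seed, conf, dict, batch_size, trainer, checkpoint, ...). Below is a minimal sketch of a command-line driver; the flag names and defaults are assumptions that simply mirror those attributes, not the real aps entry point.

import argparse

def main():
    # Hypothetical CLI wrapper around run(); every flag is inferred from the
    # attributes accessed in the example above.
    parser = argparse.ArgumentParser(description="AM training driver (illustrative)")
    parser.add_argument("--conf", required=True, help="training yaml configuration")
    parser.add_argument("--dict", required=True, help="vocabulary dictionary")
    parser.add_argument("--seed", type=int, default=777)
    parser.add_argument("--trainer", default="ddp")
    parser.add_argument("--device-id", type=str, default="0")
    parser.add_argument("--checkpoint", required=True)
    parser.add_argument("--resume", default="")
    parser.add_argument("--init", default="")
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--num-workers", type=int, default=4)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--save-interval", type=int, default=-1)
    parser.add_argument("--prog-interval", type=int, default=100)
    parser.add_argument("--eval-interval", type=int, default=-1)
    parser.add_argument("--tensorboard", action="store_true")
    run(parser.parse_args())

if __name__ == "__main__":
    main()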
Example #2
def train_worker(task, conf, vocab_dict, args):
    """
    Initialize training workers
    """
    # init torch/horovod backend
    distributed.init(args.distributed)
    rank = distributed.rank()

    Trainer = aps_trainer(args.trainer, distributed=True)
    # construct trainer
    # torch.distributed.launch will provide
    # environment variables, and requires that you use init_method="env://".
    trainer = Trainer(task,
                      rank=rank,
                      device_ids=args.device_ids,
                      checkpoint=args.checkpoint,
                      resume=args.resume,
                      init=args.init,
                      save_interval=args.save_interval,
                      prog_interval=args.prog_interval,
                      tensorboard=args.tensorboard,
                      reduction_tag="#tok",
                      **conf["trainer_conf"])

    # dump configurations
    if rank == 0:
        conf["cmd_args"] = vars(args)
        with open(f"{args.checkpoint}/train.yaml", "w") as f:
            yaml.dump(conf, f)
        dump_dict(f"{args.checkpoint}/dict", vocab_dict, reverse=False)

    num_process = len(args.device_ids.split(","))
    if num_process != distributed.world_size():
        raise RuntimeError(f"Number of process != world size: {num_process} " +
                           f"vs {distributed.world_size()}")
    data_conf = conf["data_conf"]
    load_conf = {
        "fmt": data_conf["fmt"],
        "vocab_dict": vocab_dict,
        "num_workers": args.num_workers // num_process
    }
    load_conf.update(data_conf["loader"])
    trn_loader = aps_dataloader(train=True,
                                distributed=True,
                                max_batch_size=args.batch_size // num_process,
                                **load_conf,
                                **data_conf["train"])
    dev_loader = aps_dataloader(train=False,
                                distributed=False,
                                max_batch_size=args.batch_size //
                                args.dev_batch_factor,
                                **load_conf,
                                **data_conf["valid"])
    trainer.run(trn_loader,
                dev_loader,
                num_epochs=args.epochs,
                eval_interval=args.eval_interval)
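
As the comment in this example notes, the worker expects the rank/world-size environment variables that torch.distributed.launch exports together with init_method="env://". For illustration only, a self-spawning launcher could provide those variables itself; the address, port, and helper names below are assumptions, not the actual aps launch path.

import os
import torch.multiprocessing as mp

def _worker_entry(local_rank, world_size, task, conf, vocab_dict, args):
    # Manually export what torch.distributed.launch would normally provide,
    # so that init_method="env://" can resolve the process group.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    os.environ["RANK"] = str(local_rank)
    os.environ["LOCAL_RANK"] = str(local_rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    train_worker(task, conf, vocab_dict, args)

def launch(task, conf, vocab_dict, args):
    # One worker process per device listed in --device-ids (e.g. "0,1,2,3")
    world_size = len(args.device_ids.split(","))
    mp.spawn(_worker_entry,
             args=(world_size, task, conf, vocab_dict, args),
             nprocs=world_size)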
Example #3
def train_worker(task, conf, args):
    """
    Initialize training workers
    """
    # init torch/horovod backend
    distributed.init(args.distributed)
    rank = distributed.rank()

    Trainer = aps_trainer(args.trainer, distributed=True)
    trainer = Trainer(task,
                      rank=distributed.rank(),
                      device_ids=args.device_ids,
                      checkpoint=args.checkpoint,
                      resume=args.resume,
                      init=args.init,
                      save_interval=args.save_interval,
                      prog_interval=args.prog_interval,
                      tensorboard=args.tensorboard,
                      **conf["trainer_conf"])

    # dump configurations
    if rank == 0:
        conf["cmd_args"] = vars(args)
        with open(f"{args.checkpoint}/train.yaml", "w") as f:
            yaml.dump(conf, f)

    num_process = len(args.device_ids.split(","))
    if num_process != distributed.world_size():
        raise RuntimeError(f"Number of process != world size: {num_process} " +
                           f"vs {distributed.world_size()}")
    data_conf = conf["data_conf"]
    load_conf = {
        "fmt": data_conf["fmt"],
        "num_workers": args.num_workers // num_process
    }
    load_conf.update(data_conf["loader"])
    trn_loader = aps_dataloader(train=True,
                                distributed=True,
                                batch_size=args.batch_size // num_process,
                                **load_conf,
                                **data_conf["train"])
    dev_loader = aps_dataloader(train=False,
                                distributed=False,
                                batch_size=args.batch_size //
                                args.dev_batch_factor,
                                **load_conf,
                                **data_conf["valid"])
    if args.eval_interval <= 0:
        raise RuntimeError("For distributed training of SE/SS model, "
                           "--eval-interval must be larger than 0")
    trainer.run(trn_loader,
                dev_loader,
                num_epochs=args.epochs,
                eval_interval=args.eval_interval)
Example #4
def train_worker(task, conf, vocab_dict, args):
    """
    Initialize training workers
    """
    # init torch/horovod backend
    distributed.init(args.distributed)
    rank = distributed.rank()

    Trainer = aps_trainer(args.trainer, distributed=True)
    trainer = Trainer(task,
                      device_ids=args.device_ids,
                      checkpoint=args.checkpoint,
                      resume=args.resume,
                      save_interval=args.save_interval,
                      prog_interval=args.prog_interval,
                      tensorboard=args.tensorboard,
                      reduction_tag="#tok",
                      **conf["trainer_conf"])
    # dump configurations
    if rank == 0:
        conf["cmd_args"] = vars(args)
        with open(f"{args.checkpoint}/train.yaml", "w") as f:
            yaml.dump(conf, f)

    num_process = len(args.device_ids.split(","))
    if num_process != distributed.world_size():
        raise RuntimeError(f"Number of process != world size: {num_process} " +
                           f"vs {distributed.world_size()}")

    data_conf = conf["data_conf"]
    load_conf = {
        "vocab_dict": vocab_dict,
        "num_workers": args.num_workers // num_process,
        "sos": vocab_dict["<sos>"],
        "eos": vocab_dict["<eos>"],
        "fmt": data_conf["fmt"]
    }
    load_conf.update(data_conf["loader"])
    trn_loader = aps_dataloader(train=True,
                                distributed=True,
                                batch_size=args.batch_size // num_process,
                                **data_conf["train"],
                                **load_conf)
    dev_loader = aps_dataloader(train=False,
                                distributed=False,
                                batch_size=args.batch_size //
                                args.dev_batch_factor,
                                **data_conf["valid"],
                                **load_conf)
    trainer.run(trn_loader,
                dev_loader,
                num_epochs=args.epochs,
                eval_interval=args.eval_interval)
Example #5
File: train_ss.py Project: xmpx/aps
def run(args):
    # set random seed
    seed = set_seed(args.seed)
    if seed is not None:
        print(f"Set random seed as {seed}")

    conf = load_ss_conf(args.conf)
    print(f"Arguments in args:\n{pprint.pformat(vars(args))}", flush=True)
    print(f"Arguments in yaml:\n{pprint.pformat(conf)}", flush=True)

    data_conf = conf["data_conf"]
    load_conf = {
        "fmt": data_conf["fmt"],
        "batch_size": args.batch_size,
        "num_workers": args.num_workers,
    }
    load_conf.update(data_conf["loader"])
    trn_loader = aps_dataloader(train=True, **load_conf, **data_conf["train"])
    dev_loader = aps_dataloader(train=False, **load_conf, **data_conf["valid"])

    sse_cls = aps_sse_nnet(conf["nnet"])
    if "enh_transform" in conf:
        enh_transform = aps_transform("enh")(**conf["enh_transform"])
        nnet = sse_cls(enh_transform=enh_transform, **conf["nnet_conf"])
    else:
        nnet = sse_cls(**conf["nnet_conf"])

    task = aps_task(conf["task"], nnet, **conf["task_conf"])

    Trainer = aps_trainer(args.trainer, distributed=False)
    trainer = Trainer(task,
                      device_ids=args.device_id,
                      checkpoint=args.checkpoint,
                      resume=args.resume,
                      init=args.init,
                      save_interval=args.save_interval,
                      prog_interval=args.prog_interval,
                      tensorboard=args.tensorboard,
                      **conf["trainer_conf"])

    # dump configurations
    conf["cmd_args"] = vars(args)
    with open(f"{args.checkpoint}/train.yaml", "w") as f:
        yaml.dump(conf, f)

    trainer.run(trn_loader,
                dev_loader,
                num_epochs=args.epochs,
                eval_interval=args.eval_interval)
Example #6
File: train_lm.py Project: xmpx/aps
def run(args):
    # set random seed
    seed = set_seed(args.seed)
    if seed is not None:
        print(f"Set random seed as {seed}")

    conf, vocab = load_lm_conf(args.conf, args.dict)
    print(f"Arguments in args:\n{pprint.pformat(vars(args))}", flush=True)
    print(f"Arguments in yaml:\n{pprint.pformat(conf)}", flush=True)

    data_conf = conf["data_conf"]
    load_conf = {
        "vocab_dict": vocab,
        "num_workers": args.num_workers,
        "sos": vocab["<sos>"],
        "eos": vocab["<eos>"],
        "fmt": data_conf["fmt"],
        "max_batch_size": args.batch_size
    }
    load_conf.update(data_conf["loader"])
    trn_loader = aps_dataloader(train=True, **data_conf["train"], **load_conf)
    dev_loader = aps_dataloader(train=False, **data_conf["valid"], **load_conf)

    nnet = aps_asr_nnet(conf["nnet"])(**conf["nnet_conf"])
    task = aps_task(conf["task"], nnet, **conf["task_conf"])

    Trainer = aps_trainer(args.trainer, distributed=False)
    trainer = Trainer(task,
                      device_ids=args.device_id,
                      checkpoint=args.checkpoint,
                      resume=args.resume,
                      save_interval=args.save_interval,
                      prog_interval=args.prog_interval,
                      tensorboard=args.tensorboard,
                      reduction_tag="#tok",
                      **conf["trainer_conf"])
    # dump configurations
    conf["cmd_args"] = vars(args)
    with open(f"{args.checkpoint}/train.yaml", "w") as f:
        yaml.dump(conf, f)

    trainer.run(trn_loader,
                dev_loader,
                num_epochs=args.epochs,
                eval_interval=args.eval_interval)
Example #7
def test_lm_utt_loader(batch_size, obj):
    egs_dir = "data/dataloader/lm"
    loader = aps_dataloader(fmt="lm@utt",
                            sos=1,
                            eos=2,
                            text=f"{egs_dir}/{obj}",
                            vocab_dict=load_dict(f"{egs_dir}/dict"),
                            max_batch_size=batch_size,
                            min_batch_size=batch_size)
    for egs in loader:
        assert egs["src"].shape == egs["tgt"].shape
        assert egs["src"].shape[0] == batch_size
Example #8
def test_ss_online_loader(batch_size, chunk_size, num_workers):
    egs_dir = "data/dataloader/se"
    loader = aps_dataloader(fmt="se@online",
                            simu_cfg=f"{egs_dir}/online.opts",
                            sr=16000,
                            max_batch_size=batch_size,
                            chunk_size=chunk_size,
                            num_workers=num_workers)
    for egs in loader:
        assert egs["mix"].shape == th.Size([batch_size, chunk_size])
        assert len(egs["ref"]) == 2
        assert egs["ref"][0].shape == th.Size([batch_size, chunk_size])
Example #9
def test_lm_bptt_loader(batch_size, obj):
    egs_dir = "data/dataloader/lm"
    loader = aps_dataloader(fmt="lm@bptt",
                            sos=1,
                            eos=2,
                            text=f"{egs_dir}/{obj}",
                            vocab_dict=load_dict(f"{egs_dir}/dict"),
                            bptt_size=10,
                            max_batch_size=batch_size)
    for egs in loader:
        print(egs)
        assert egs["src"].shape == egs["tgt"].shape
        assert egs["src"].shape == th.Size([batch_size, 10])
Example #10
def test_ss_chunk_loader(batch_size, chunk_size, num_workers):
    egs_dir = "data/dataloader/se"
    loader = aps_dataloader(fmt="se@chunk",
                            mix_scp=f"{egs_dir}/wav.1.scp",
                            ref_scp=f"{egs_dir}/wav.1.scp",
                            sr=16000,
                            max_batch_size=batch_size,
                            chunk_size=chunk_size,
                            num_workers=num_workers)
    for egs in loader:
        assert egs["mix"].shape == th.Size([batch_size, chunk_size])
        assert egs["ref"].shape == th.Size([batch_size, chunk_size])
    loader = aps_dataloader(fmt="se@chunk",
                            mix_scp=f"{egs_dir}/wav.1.scp",
                            ref_scp=f"{egs_dir}/wav.1.scp,{egs_dir}/wav.1.scp",
                            sr=16000,
                            max_batch_size=batch_size,
                            chunk_size=chunk_size,
                            num_workers=num_workers)
    for egs in loader:
        assert egs["mix"].shape == th.Size([batch_size, chunk_size])
        assert len(egs["ref"]) == 2
        assert egs["ref"][0].shape == th.Size([batch_size, chunk_size])
Example #11
def test_am_raw_loader_const(batch_size, num_workers):
    egs_dir = "data/dataloader/am"
    loader = aps_dataloader(fmt="am@raw",
                            wav_scp=f"{egs_dir}/egs.wav.scp",
                            text=f"{egs_dir}/egs.fake.text",
                            utt2dur=f"{egs_dir}/egs.utt2dur",
                            vocab_dict=load_dict(f"{egs_dir}/dict"),
                            train=False,
                            sr=16000,
                            max_batch_size=batch_size,
                            batch_mode="constraint",
                            num_workers=num_workers)
    for egs in loader:
        for key in ["src_pad", "tgt_pad", "tgt_len", "src_len"]:
            assert key in egs
        assert egs["tgt_pad"].shape[-1] == egs["tgt_len"].max().item()
Example #12
def test_am_kaldi_loader(batch_size, num_workers):
    egs_dir = "data/dataloader/am"
    loader = aps_dataloader(fmt="am@kaldi",
                            feats_scp=f"{egs_dir}/egs.fbank.scp",
                            text=f"{egs_dir}/egs.fake.text",
                            vocab_dict=load_dict(f"{egs_dir}/dict"),
                            utt2num_frames=f"{egs_dir}/egs.fbank.num_frames",
                            train=False,
                            adapt_dur=900,
                            num_workers=num_workers,
                            min_batch_size=1,
                            max_batch_size=batch_size)
    for egs in loader:
        for key in ["src_pad", "tgt_pad", "tgt_len", "src_len"]:
            assert key in egs

        assert egs["src_pad"].shape == th.Size(
            [batch_size, egs["src_len"][0].item(), 80])
        assert egs["tgt_pad"].shape == th.Size(
            [batch_size, egs["tgt_len"].max().item()])
Example #13
def test_am_raw_loader(batch_size, num_workers):
    egs_dir = "data/dataloader/am"
    loader = aps_dataloader(fmt="am@raw",
                            wav_scp=f"{egs_dir}/egs.wav.scp",
                            text=f"{egs_dir}/egs.fake.text",
                            utt2dur=f"{egs_dir}/egs.utt2dur",
                            vocab_dict=load_dict(f"{egs_dir}/dict"),
                            train=False,
                            sr=16000,
                            adapt_dur=10,
                            num_workers=num_workers,
                            max_batch_size=batch_size,
                            min_batch_size=1)
    for egs in loader:
        for key in ["src_pad", "tgt_pad", "tgt_len", "src_len"]:
            assert key in egs
        assert egs["src_pad"].shape == th.Size(
            [batch_size, egs["src_len"][0].item()])
        assert egs["tgt_pad"].shape == th.Size(
            [batch_size, egs["tgt_len"].max().item()])