Example #1
def train_worker(task, conf, args):
    """
    Initialize training workers
    """
    # init torch/horovod backend
    distributed.init(args.distributed)
    rank = distributed.rank()

    Trainer = aps_trainer(args.trainer, distributed=True)
    trainer = Trainer(task,
                      rank=rank,
                      device_ids=args.device_ids,
                      checkpoint=args.checkpoint,
                      resume=args.resume,
                      init=args.init,
                      save_interval=args.save_interval,
                      prog_interval=args.prog_interval,
                      tensorboard=args.tensorboard,
                      **conf["trainer_conf"])

    # dump configurations
    if rank == 0:
        conf["cmd_args"] = vars(args)
        with open(f"{args.checkpoint}/train.yaml", "w") as f:
            yaml.dump(conf, f)

    num_process = len(args.device_ids.split(","))
    if num_process != distributed.world_size():
        raise RuntimeError(f"Number of processes != world size: {num_process} "
                           f"vs {distributed.world_size()}")
    data_conf = conf["data_conf"]
    load_conf = {
        "fmt": data_conf["fmt"],
        "num_workers": args.num_workers // num_process
    }
    load_conf.update(data_conf["loader"])
    trn_loader = aps_dataloader(train=True,
                                distributed=True,
                                batch_size=args.batch_size // num_process,
                                **load_conf,
                                **data_conf["train"])
    dev_loader = aps_dataloader(train=False,
                                distributed=False,
                                batch_size=args.batch_size //
                                args.dev_batch_factor,
                                **load_conf,
                                **data_conf["valid"])
    if args.eval_interval <= 0:
        raise RuntimeError("For distributed training of SE/SS model, "
                           "--eval-interval must be larger than 0")
    trainer.run(trn_loader,
                dev_loader,
                num_epochs=args.epochs,
                eval_interval=args.eval_interval)
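
The loader setup above splits the global batch size and data-loading workers evenly across the distributed processes. A toy arithmetic sketch of that split (all values here are assumed, not from the example):

# one process per device id, as in the world-size check above
device_ids = "0,1,2,3"                # -> num_process = 4
batch_size, num_workers = 32, 8       # toy global values
num_process = len(device_ids.split(","))
print(batch_size // num_process)      # 8 -> per-process batch size
print(num_workers // num_process)     # 2 -> loader workers per process
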
Example #2
 def __init__(self,
              dataset: dat.Dataset,
              num_workers: int = 4,
              chunk_size: int = 64000,
              batch_size: int = 16,
              distributed: bool = False,
              train: bool = True) -> None:
     self.dataset = dataset
     self.train = train
     self.batch_size = batch_size
     self.splitter = ChunkSplitter(chunk_size,
                                   train=train,
                                   hop=chunk_size // 2)
     if distributed:
         self.sampler = dat.DistributedSampler(
             dataset,
             shuffle=train,
             num_replicas=dist.world_size(),
             rank=dist.rank())
     else:
         self.sampler = None
     # the egs loader returns batches of whole utterances and supports
     # multiple loader workers
     # NOTE: this batch_size counts utterances (egs), not audio chunks
     self.eg_loader = dat.DataLoader(self.dataset,
                                     batch_size=min(batch_size, 64),
                                     num_workers=num_workers,
                                     sampler=self.sampler,
                                     shuffle=(train
                                              and self.sampler is None),
                                     collate_fn=self._collate)
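
For reference, a minimal standalone sketch (separate from the class above, with a toy list dataset and two assumed replicas) of how torch.utils.data.DistributedSampler partitions indices across ranks:

import torch.utils.data as dat

# with shuffle=False each rank simply takes every num_replicas-th index
dataset = list(range(8))
for rank in range(2):
    sampler = dat.DistributedSampler(dataset,
                                     num_replicas=2,
                                     rank=rank,
                                     shuffle=False)
    print(rank, list(sampler))
# 0 [0, 2, 4, 6]
# 1 [1, 3, 5, 7]
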
Example #3
def train_worker(task, conf, vocab_dict, args):
    """
    Initialize training workers
    """
    # init torch/horovod backend
    distributed.init(args.distributed)
    rank = distributed.rank()

    Trainer = aps_trainer(args.trainer, distributed=True)
    # construct trainer
    # NOTE: torch.distributed.launch provides the required environment
    # variables and expects init_method="env://" (see the standalone
    # sketch after this example)
    trainer = Trainer(task,
                      rank=rank,
                      device_ids=args.device_ids,
                      checkpoint=args.checkpoint,
                      resume=args.resume,
                      init=args.init,
                      save_interval=args.save_interval,
                      prog_interval=args.prog_interval,
                      tensorboard=args.tensorboard,
                      reduction_tag="#tok",
                      **conf["trainer_conf"])

    # dump configurations
    if rank == 0:
        conf["cmd_args"] = vars(args)
        with open(f"{args.checkpoint}/train.yaml", "w") as f:
            yaml.dump(conf, f)
        dump_dict(f"{args.checkpoint}/dict", vocab_dict, reverse=False)

    num_process = len(args.device_ids.split(","))
    if num_process != distributed.world_size():
        raise RuntimeError(f"Number of processes != world size: {num_process} "
                           f"vs {distributed.world_size()}")
    data_conf = conf["data_conf"]
    load_conf = {
        "fmt": data_conf["fmt"],
        "vocab_dict": vocab_dict,
        "num_workers": args.num_workers // num_process
    }
    load_conf.update(data_conf["loader"])
    trn_loader = aps_dataloader(train=True,
                                distributed=True,
                                max_batch_size=args.batch_size // num_process,
                                **load_conf,
                                **data_conf["train"])
    dev_loader = aps_dataloader(train=False,
                                distributed=False,
                                max_batch_size=args.batch_size //
                                args.dev_batch_factor,
                                **load_conf,
                                **data_conf["valid"])
    trainer.run(trn_loader,
                dev_loader,
                num_epochs=args.epochs,
                eval_interval=args.eval_interval)
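
The init_method="env://" comment above refers to the rendezvous contract sketched below. This is a standalone illustration using plain torch.distributed with single-process placeholder values, not part of the aps code:

import os
import torch.distributed as torch_dist

# torch.distributed.launch exports these variables for every worker;
# init_method="env://" tells init_process_group to read them
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")
torch_dist.init_process_group(backend="gloo", init_method="env://")
print(torch_dist.get_rank(), torch_dist.get_world_size())  # 0 1
torch_dist.destroy_process_group()
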
Example #4
def train_worker(task, conf, vocab_dict, args):
    """
    Initialize training workers
    """
    # init torch/horovod backend
    distributed.init(args.distributed)
    rank = distributed.rank()

    Trainer = aps_trainer(args.trainer, distributed=True)
    trainer = Trainer(task,
                      device_ids=args.device_ids,
                      checkpoint=args.checkpoint,
                      resume=args.resume,
                      save_interval=args.save_interval,
                      prog_interval=args.prog_interval,
                      tensorboard=args.tensorboard,
                      reduction_tag="#tok",
                      **conf["trainer_conf"])
    # dump configurations
    if rank == 0:
        conf["cmd_args"] = vars(args)
        with open(f"{args.checkpoint}/train.yaml", "w") as f:
            yaml.dump(conf, f)

    num_process = len(args.device_ids.split(","))
    if num_process != distributed.world_size():
        raise RuntimeError(f"Number of processes != world size: {num_process} "
                           f"vs {distributed.world_size()}")

    data_conf = conf["data_conf"]
    load_conf = {
        "vocab_dict": vocab_dict,
        "num_workers": args.num_workers // num_process,
        "sos": vocab_dict["<sos>"],
        "eos": vocab_dict["<eos>"],
        "fmt": data_conf["fmt"]
    }
    load_conf.update(data_conf["loader"])
    trn_loader = aps_dataloader(train=True,
                                distributed=True,
                                batch_size=args.batch_size // num_process,
                                **data_conf["train"],
                                **load_conf)
    dev_loader = aps_dataloader(train=False,
                                distributed=False,
                                batch_size=args.batch_size //
                                args.dev_batch_factor,
                                **data_conf["valid"],
                                **load_conf)
    trainer.run(trn_loader,
                dev_loader,
                num_epochs=args.epochs,
                eval_interval=args.eval_interval)
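
The lookups vocab_dict["<sos>"] and vocab_dict["<eos>"] above imply that vocab_dict maps each symbol to an integer id. A toy instance (the ids and words are made-up placeholders):

vocab_dict = {"<sos>": 0, "<eos>": 1, "hello": 2, "world": 3}
print(vocab_dict["<sos>"], vocab_dict["<eos>"])  # 0 1
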
Example #5
def derive_indices(num_batches: int,
                   seed: int = 0,
                   shuffle: bool = True,
                   distributed: bool = False) -> List[int]:
    """
    Return indices for BatchSampler
    """
    if distributed:
        rank = dist.rank()
        world_size = dist.world_size()
        num_batches = num_batches * world_size
    if shuffle:
        g = th.Generator()
        g.manual_seed(seed)
        indices = th.randperm(num_batches, generator=g).tolist()
    else:
        indices = th.arange(num_batches).tolist()
    if distributed:
        return indices[rank:num_batches:world_size]
    else:
        return indices
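
Assuming derive_indices is importable as defined above, the non-distributed paths can be exercised directly; the distributed branch's rank slicing is mimicked below in plain Python (world_size=2 is a toy value):

print(derive_indices(4, shuffle=False))  # [0, 1, 2, 3]
print(derive_indices(4, seed=0))         # same seeded permutation on every call

# the distributed branch keeps every world_size-th index of the shared list
indices = list(range(8))                 # num_batches scaled by world_size
print(indices[0:8:2], indices[1:8:2])    # rank 0 / rank 1 shares
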
Example #6
def start_trainer(trainer: str,
                  conf: Dict,
                  nnet: nn.Module,
                  args: Namespace,
                  reduction_tag: str = "none",
                  other_loader_conf: Dict = None):
    """
    Run the instance of the aps Trainer
    """
    is_distributed = args.distributed != "none"
    if is_distributed:
        # init torch/horovod backend
        distributed.init(args.distributed)
        rank = distributed.rank()
    else:
        rank = None

    task = aps_task(conf["task"], nnet, **conf["task_conf"])

    TrainerClass = aps_trainer(args.trainer, distributed=is_distributed)
    # construct trainer
    # NOTE: torch.distributed.launch provides the required environment
    # variables and expects init_method="env://" (see the standalone
    # sketch after Example #3)
    trainer = TrainerClass(task,
                           rank=rank,
                           device_ids=args.device_ids,
                           checkpoint=args.checkpoint,
                           resume=args.resume,
                           init=args.init,
                           save_interval=args.save_interval,
                           prog_interval=args.prog_interval,
                           tensorboard=args.tensorboard,
                           reduction_tag=reduction_tag,
                           **conf["trainer_conf"])
    # save cmd options
    if rank in [0, None]:
        conf["cmd_args"] = vars(args)
        with open(f"{args.checkpoint}/train.yaml", "w") as f:
            yaml.dump(conf, f)

    # check if #devices == world_size
    if is_distributed:
        num_process = len(args.device_ids.split(","))
        if num_process != distributed.world_size():
            raise RuntimeError(
                f"Number of processes != world size: {num_process} "
                f"vs {distributed.world_size()}")
    else:
        num_process = 1

    data_conf = conf["data_conf"]
    loader_conf = {
        "fmt": data_conf["fmt"],
        "num_workers": args.num_workers // num_process
    }
    loader_conf.update(data_conf["loader"])
    if other_loader_conf:
        loader_conf.update(other_loader_conf)

    trn_loader = aps_dataloader(train=True,
                                distributed=is_distributed,
                                max_batch_size=args.batch_size // num_process,
                                **loader_conf,
                                **data_conf["train"])
    dev_loader = aps_dataloader(train=False,
                                distributed=False,
                                max_batch_size=args.batch_size //
                                args.dev_batch_factor,
                                **loader_conf,
                                **data_conf["valid"])
    trainer.run(trn_loader,
                dev_loader,
                num_epochs=args.epochs,
                eval_interval=args.eval_interval)

    return trainer
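
start_trainer only assumes a particular shape for conf. A hypothetical minimal skeleton listing exactly the keys the function reads (every value is a placeholder):

conf = {
    "task": "...",         # name resolved by aps_task
    "task_conf": {},       # kwargs for the task
    "trainer_conf": {},    # kwargs for the Trainer
    "data_conf": {
        "fmt": "...",      # dataloader format tag
        "loader": {},      # options shared by both loaders
        "train": {},       # training split options
        "valid": {},       # validation split options
    },
}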