def train_worker(task, conf, args):
    """
    Initialize training workers
    """
    # init torch/horovod backend
    distributed.init(args.distributed)
    rank = distributed.rank()
    Trainer = aps_trainer(args.trainer, distributed=True)
    trainer = Trainer(task,
                      rank=rank,
                      device_ids=args.device_ids,
                      checkpoint=args.checkpoint,
                      resume=args.resume,
                      init=args.init,
                      save_interval=args.save_interval,
                      prog_interval=args.prog_interval,
                      tensorboard=args.tensorboard,
                      **conf["trainer_conf"])
    # dump configurations (rank 0 only)
    if rank == 0:
        conf["cmd_args"] = vars(args)
        with open(f"{args.checkpoint}/train.yaml", "w") as f:
            yaml.dump(conf, f)
    # one process per device is expected
    num_process = len(args.device_ids.split(","))
    if num_process != distributed.world_size():
        raise RuntimeError(f"Number of process != world size: {num_process} " +
                           f"vs {distributed.world_size()}")
    data_conf = conf["data_conf"]
    load_conf = {
        "fmt": data_conf["fmt"],
        "num_workers": args.num_workers // num_process
    }
    load_conf.update(data_conf["loader"])
    trn_loader = aps_dataloader(train=True,
                                distributed=True,
                                batch_size=args.batch_size // num_process,
                                **load_conf,
                                **data_conf["train"])
    dev_loader = aps_dataloader(train=False,
                                distributed=False,
                                batch_size=args.batch_size //
                                args.dev_batch_factor,
                                **load_conf,
                                **data_conf["valid"])
    if args.eval_interval <= 0:
        raise RuntimeError("For distributed training of SE/SS models, "
                           "--eval-interval must be larger than 0")
    trainer.run(trn_loader,
                dev_loader,
                num_epochs=args.epochs,
                eval_interval=args.eval_interval)

def __init__(self,
             dataset: dat.Dataset,
             num_workers: int = 4,
             chunk_size: int = 64000,
             batch_size: int = 16,
             distributed: bool = False,
             train: bool = True) -> None:
    self.dataset = dataset
    self.train = train
    self.batch_size = batch_size
    # split each utterance into fixed-size chunks (hop = chunk_size // 2)
    self.splitter = ChunkSplitter(chunk_size,
                                  train=train,
                                  hop=chunk_size // 2)
    if distributed:
        # each rank iterates over its own shard of the dataset
        self.sampler = dat.DistributedSampler(dataset,
                                              shuffle=train,
                                              num_replicas=dist.world_size(),
                                              rank=dist.rank())
    else:
        self.sampler = None
    # egs-level loader: returns a batch of utterances (egs), using multiple workers
    # NOTE: batch_size here is the number of utterances loaded at once, not the
    # number of audio chunks per training batch
    self.eg_loader = dat.DataLoader(self.dataset,
                                    batch_size=min(batch_size, 64),
                                    num_workers=num_workers,
                                    sampler=self.sampler,
                                    shuffle=(train and self.sampler is None),
                                    collate_fn=self._collate)

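
# The following is a small, self-contained sketch (not aps code) of the
# sampler/loader wiring used in the constructor above: when a
# DistributedSampler is attached, each rank iterates over a disjoint shard of
# the egs-level dataset and the DataLoader's own shuffle flag stays off.
# ToyEgsDataset and _toy_collate are hypothetical stand-ins for the real
# Dataset/_collate used by aps.
import torch.utils.data as dat


class ToyEgsDataset(dat.Dataset):

    def __init__(self, num_egs: int = 10) -> None:
        self.num_egs = num_egs

    def __len__(self) -> int:
        return self.num_egs

    def __getitem__(self, index: int) -> dict:
        # a real "eg" would carry utterance-level audio and labels
        return {"key": f"utt-{index}"}


def _toy_collate(egs):
    # keep egs as a plain list; chunk-level batching would happen later
    return egs


if __name__ == "__main__":
    dataset = ToyEgsDataset()
    # explicit num_replicas/rank, so no process group is needed for this demo
    sampler = dat.DistributedSampler(dataset,
                                     num_replicas=2,
                                     rank=0,
                                     shuffle=True)
    loader = dat.DataLoader(dataset,
                            batch_size=4,
                            sampler=sampler,
                            shuffle=False,
                            collate_fn=_toy_collate)
    for egs in loader:
        print([eg["key"] for eg in egs])
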
def train_worker(task, conf, vocab_dict, args):
    """
    Initialize training workers
    """
    # init torch/horovod backend
    distributed.init(args.distributed)
    rank = distributed.rank()
    Trainer = aps_trainer(args.trainer, distributed=True)
    # construct trainer
    # NOTE: torch.distributed.launch provides the required environment
    # variables and requires init_method="env://"
    trainer = Trainer(task,
                      rank=rank,
                      device_ids=args.device_ids,
                      checkpoint=args.checkpoint,
                      resume=args.resume,
                      init=args.init,
                      save_interval=args.save_interval,
                      prog_interval=args.prog_interval,
                      tensorboard=args.tensorboard,
                      reduction_tag="#tok",
                      **conf["trainer_conf"])
    # dump configurations & vocabulary (rank 0 only)
    if rank == 0:
        conf["cmd_args"] = vars(args)
        with open(f"{args.checkpoint}/train.yaml", "w") as f:
            yaml.dump(conf, f)
        dump_dict(f"{args.checkpoint}/dict", vocab_dict, reverse=False)
    # one process per device is expected
    num_process = len(args.device_ids.split(","))
    if num_process != distributed.world_size():
        raise RuntimeError(f"Number of process != world size: {num_process} " +
                           f"vs {distributed.world_size()}")
    data_conf = conf["data_conf"]
    load_conf = {
        "fmt": data_conf["fmt"],
        "vocab_dict": vocab_dict,
        "num_workers": args.num_workers // num_process
    }
    load_conf.update(data_conf["loader"])
    trn_loader = aps_dataloader(train=True,
                                distributed=True,
                                max_batch_size=args.batch_size // num_process,
                                **load_conf,
                                **data_conf["train"])
    dev_loader = aps_dataloader(train=False,
                                distributed=False,
                                max_batch_size=args.batch_size //
                                args.dev_batch_factor,
                                **load_conf,
                                **data_conf["valid"])
    trainer.run(trn_loader,
                dev_loader,
                num_epochs=args.epochs,
                eval_interval=args.eval_interval)

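
# A minimal, self-contained sketch (not the aps `distributed` wrapper) of the
# "env://" pattern referenced in the comment above: torch.distributed.launch
# (or torchrun) exports RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT, and
# init_process_group(init_method="env://") picks them up. Run it under the
# launcher, e.g. `torchrun --nproc_per_node=2 this_script.py`.
import os

import torch.distributed as torch_dist


def init_from_env(backend: str = "nccl") -> int:
    # reads RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT from the environment
    torch_dist.init_process_group(backend=backend, init_method="env://")
    rank = torch_dist.get_rank()
    world_size = torch_dist.get_world_size()
    print(f"rank {rank}/{world_size}, "
          f"local rank {os.environ.get('LOCAL_RANK', '0')}")
    return rank
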
def train_worker(task, conf, vocab_dict, args):
    """
    Initialize training workers
    """
    # init torch/horovod backend
    distributed.init(args.distributed)
    rank = distributed.rank()
    Trainer = aps_trainer(args.trainer, distributed=True)
    trainer = Trainer(task,
                      device_ids=args.device_ids,
                      checkpoint=args.checkpoint,
                      resume=args.resume,
                      save_interval=args.save_interval,
                      prog_interval=args.prog_interval,
                      tensorboard=args.tensorboard,
                      reduction_tag="#tok",
                      **conf["trainer_conf"])
    # dump configurations (rank 0 only)
    if rank == 0:
        conf["cmd_args"] = vars(args)
        with open(f"{args.checkpoint}/train.yaml", "w") as f:
            yaml.dump(conf, f)
    # one process per device is expected
    num_process = len(args.device_ids.split(","))
    if num_process != distributed.world_size():
        raise RuntimeError(f"Number of process != world size: {num_process} " +
                           f"vs {distributed.world_size()}")
    data_conf = conf["data_conf"]
    load_conf = {
        "vocab_dict": vocab_dict,
        "num_workers": args.num_workers // num_process,
        "sos": vocab_dict["<sos>"],
        "eos": vocab_dict["<eos>"],
        "fmt": data_conf["fmt"]
    }
    load_conf.update(data_conf["loader"])
    trn_loader = aps_dataloader(train=True,
                                distributed=True,
                                batch_size=args.batch_size // num_process,
                                **data_conf["train"],
                                **load_conf)
    dev_loader = aps_dataloader(train=False,
                                distributed=False,
                                batch_size=args.batch_size //
                                args.dev_batch_factor,
                                **data_conf["valid"],
                                **load_conf)
    trainer.run(trn_loader,
                dev_loader,
                num_epochs=args.epochs,
                eval_interval=args.eval_interval)

def derive_indices(num_batches: int,
                   seed: int = 0,
                   shuffle: bool = True,
                   distributed: bool = False) -> List[int]:
    """
    Return indices for BatchSampler
    """
    if distributed:
        rank = dist.rank()
        world_size = dist.world_size()
        num_batches = num_batches * world_size
    if shuffle:
        g = th.Generator()
        g.manual_seed(seed)
        indices = th.randperm(num_batches, generator=g).tolist()
    else:
        indices = th.arange(num_batches).tolist()
    if distributed:
        return indices[rank:num_batches:world_size]
    else:
        return indices

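
# A quick local sketch of how derive_indices shards batch indices across
# ranks: every rank draws the same permutation (same seed), then takes a
# rank-offset strided slice. rank/world_size are passed in explicitly here
# instead of being queried via dist.rank()/dist.world_size().
import torch as th


def _derive_indices_local(num_batches: int,
                          seed: int = 0,
                          rank: int = 0,
                          world_size: int = 1):
    total = num_batches * world_size
    g = th.Generator()
    g.manual_seed(seed)
    indices = th.randperm(total, generator=g).tolist()
    return indices[rank:total:world_size]


if __name__ == "__main__":
    # the two ranks see disjoint halves of the same permutation
    print(_derive_indices_local(4, seed=0, rank=0, world_size=2))
    print(_derive_indices_local(4, seed=0, rank=1, world_size=2))
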
def start_trainer(trainer: str,
                  conf: Dict,
                  nnet: nn.Module,
                  args: Namespace,
                  reduction_tag: str = "none",
                  other_loader_conf: Dict = None):
    """
    Run an instance of the aps Trainer and return it
    """
    is_distributed = args.distributed != "none"
    if is_distributed:
        # init torch/horovod backend
        distributed.init(args.distributed)
        rank = distributed.rank()
    else:
        rank = None
    task = aps_task(conf["task"], nnet, **conf["task_conf"])
    TrainerClass = aps_trainer(args.trainer, distributed=is_distributed)
    # construct trainer
    # NOTE: torch.distributed.launch provides the required environment
    # variables and requires init_method="env://"
    trainer = TrainerClass(task,
                           rank=rank,
                           device_ids=args.device_ids,
                           checkpoint=args.checkpoint,
                           resume=args.resume,
                           init=args.init,
                           save_interval=args.save_interval,
                           prog_interval=args.prog_interval,
                           tensorboard=args.tensorboard,
                           reduction_tag=reduction_tag,
                           **conf["trainer_conf"])
    # save cmd options (rank 0 in distributed mode, always otherwise)
    if rank in [0, None]:
        conf["cmd_args"] = vars(args)
        with open(f"{args.checkpoint}/train.yaml", "w") as f:
            yaml.dump(conf, f)
    # check if #devices == world_size
    if is_distributed:
        num_process = len(args.device_ids.split(","))
        if num_process != distributed.world_size():
            raise RuntimeError(
                f"Number of process != world size: {num_process} " +
                f"vs {distributed.world_size()}")
    else:
        num_process = 1
    data_conf = conf["data_conf"]
    loader_conf = {
        "fmt": data_conf["fmt"],
        "num_workers": args.num_workers // num_process
    }
    loader_conf.update(data_conf["loader"])
    if other_loader_conf:
        loader_conf.update(other_loader_conf)
    trn_loader = aps_dataloader(train=True,
                                distributed=is_distributed,
                                max_batch_size=args.batch_size // num_process,
                                **loader_conf,
                                **data_conf["train"])
    dev_loader = aps_dataloader(train=False,
                                distributed=False,
                                max_batch_size=args.batch_size //
                                args.dev_batch_factor,
                                **loader_conf,
                                **data_conf["valid"])
    trainer.run(trn_loader,
                dev_loader,
                num_epochs=args.epochs,
                eval_interval=args.eval_interval)
    return trainer

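
# A small stand-alone sketch (not part of aps) of the device/batch bookkeeping
# start_trainer performs: the number of entries in --device-ids must match the
# distributed world size, and the global batch size is split evenly per rank.
def per_rank_batch_size(device_ids: str, world_size: int,
                        batch_size: int) -> int:
    num_process = len(device_ids.split(","))
    if num_process != world_size:
        raise RuntimeError(f"Number of process != world size: {num_process} "
                           f"vs {world_size}")
    return batch_size // num_process


# e.g. per_rank_batch_size("0,1,2,3", 4, 64) -> 16 utterances per rank
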