def train_worker(task, conf, vocab_dict, args):
    """
    Initialize one distributed training worker (ASR variant).

    Args:
        task: task instance wrapping the nnet and loss computation
        conf: full training configuration (trainer/data/task sections)
        vocab_dict: token -> index mapping dumped alongside checkpoints
        args: parsed command-line namespace
    """
    # bring up the torch/horovod distributed backend first
    distributed.init(args.distributed)
    rank = distributed.rank()
    trainer_cls = aps_trainer(args.trainer, distributed=True)
    # construct trainer
    # torch.distributed.launch will provide
    # environment variables, and requires that you use init_method="env://".
    trainer = trainer_cls(task,
                          rank=rank,
                          device_ids=args.device_ids,
                          checkpoint=args.checkpoint,
                          resume=args.resume,
                          init=args.init,
                          save_interval=args.save_interval,
                          prog_interval=args.prog_interval,
                          tensorboard=args.tensorboard,
                          reduction_tag="#tok",
                          **conf["trainer_conf"])
    # only rank 0 persists the resolved configuration and the dictionary
    if rank == 0:
        conf["cmd_args"] = vars(args)
        with open(f"{args.checkpoint}/train.yaml", "w") as f:
            yaml.dump(conf, f)
        dump_dict(f"{args.checkpoint}/dict", vocab_dict, reverse=False)
    # one process per visible device; must agree with the backend's world size
    num_process = len(args.device_ids.split(","))
    world_size = distributed.world_size()
    if num_process != world_size:
        raise RuntimeError(f"Number of process != world size: {num_process} " +
                           f"vs {world_size}")
    data_conf = conf["data_conf"]
    # shared loader options; per-split options come from data_conf
    shared_loader_kwargs = {
        "fmt": data_conf["fmt"],
        "vocab_dict": vocab_dict,
        "num_workers": args.num_workers // num_process
    }
    shared_loader_kwargs.update(data_conf["loader"])
    # training batches are split across processes; dev runs on one process
    trn_loader = aps_dataloader(train=True,
                                distributed=True,
                                max_batch_size=args.batch_size // num_process,
                                **shared_loader_kwargs,
                                **data_conf["train"])
    dev_loader = aps_dataloader(train=False,
                                distributed=False,
                                max_batch_size=args.batch_size //
                                args.dev_batch_factor,
                                **shared_loader_kwargs,
                                **data_conf["valid"])
    trainer.run(trn_loader,
                dev_loader,
                num_epochs=args.epochs,
                eval_interval=args.eval_interval)
def train_worker(task, conf, args):
    """
    Initialize one distributed training worker (SE/SS variant).

    Args:
        task: task instance wrapping the nnet and loss computation
        conf: full training configuration (trainer/data/task sections)
        args: parsed command-line namespace

    Raises:
        RuntimeError: if --eval-interval is not positive, or if the
            number of devices disagrees with the backend world size
    """
    # fail fast: validate cheap preconditions BEFORE constructing the
    # trainer and dataloaders (original checked eval_interval last)
    if args.eval_interval <= 0:
        raise RuntimeError("For distributed training of SE/SS model, "
                           "--eval-interval must be larger than 0")
    # init torch/horovod backend
    distributed.init(args.distributed)
    rank = distributed.rank()
    Trainer = aps_trainer(args.trainer, distributed=True)
    trainer = Trainer(task,
                      rank=rank,  # reuse the cached rank (was re-queried)
                      device_ids=args.device_ids,
                      checkpoint=args.checkpoint,
                      resume=args.resume,
                      init=args.init,
                      save_interval=args.save_interval,
                      prog_interval=args.prog_interval,
                      tensorboard=args.tensorboard,
                      **conf["trainer_conf"])
    # dump configurations (rank 0 only, to avoid concurrent writes)
    if rank == 0:
        conf["cmd_args"] = vars(args)
        with open(f"{args.checkpoint}/train.yaml", "w") as f:
            yaml.dump(conf, f)
    # one process per visible device; must agree with the world size
    num_process = len(args.device_ids.split(","))
    if num_process != distributed.world_size():
        raise RuntimeError(f"Number of process != world size: {num_process} " +
                           f"vs {distributed.world_size()}")
    data_conf = conf["data_conf"]
    # loader options shared between train/valid splits
    load_conf = {
        "fmt": data_conf["fmt"],
        "num_workers": args.num_workers // num_process
    }
    load_conf.update(data_conf["loader"])
    # training batches are split across processes; dev runs on one process
    trn_loader = aps_dataloader(train=True,
                                distributed=True,
                                batch_size=args.batch_size // num_process,
                                **load_conf,
                                **data_conf["train"])
    dev_loader = aps_dataloader(train=False,
                                distributed=False,
                                batch_size=args.batch_size //
                                args.dev_batch_factor,
                                **load_conf,
                                **data_conf["valid"])
    trainer.run(trn_loader,
                dev_loader,
                num_epochs=args.epochs,
                eval_interval=args.eval_interval)
def train_worker(task, conf, vocab_dict, args):
    """
    Initialize one distributed training worker (seq2seq variant).

    Args:
        task: task instance wrapping the nnet and loss computation
        conf: full training configuration (trainer/data/task sections)
        vocab_dict: token -> index mapping, must contain <sos> and <eos>
        args: parsed command-line namespace

    Raises:
        RuntimeError: if <sos>/<eos> are missing from the dictionary, or
            if the number of devices disagrees with the world size
    """
    # init torch/horovod backend
    distributed.init(args.distributed)
    rank = distributed.rank()
    Trainer = aps_trainer(args.trainer, distributed=True)
    trainer = Trainer(task,
                      device_ids=args.device_ids,
                      checkpoint=args.checkpoint,
                      resume=args.resume,
                      save_interval=args.save_interval,
                      prog_interval=args.prog_interval,
                      tensorboard=args.tensorboard,
                      reduction_tag="#tok",
                      **conf["trainer_conf"])
    # dump configurations (rank 0 only, to avoid concurrent writes)
    if rank == 0:
        conf["cmd_args"] = vars(args)
        with open(f"{args.checkpoint}/train.yaml", "w") as f:
            yaml.dump(conf, f)
    # one process per visible device; must agree with the world size
    num_process = len(args.device_ids.split(","))
    if num_process != distributed.world_size():
        raise RuntimeError(f"Number of process != world size: {num_process} " +
                           f"vs {distributed.world_size()}")
    # validate the dictionary up front: a plain lookup would raise an
    # opaque KeyError deep inside the loader setup
    for token in ("<sos>", "<eos>"):
        if token not in vocab_dict:
            raise RuntimeError(f"Missing {token} token in the dictionary")
    data_conf = conf["data_conf"]
    # loader options shared between train/valid splits
    load_conf = {
        "vocab_dict": vocab_dict,
        "num_workers": args.num_workers // num_process,
        "sos": vocab_dict["<sos>"],
        "eos": vocab_dict["<eos>"],
        "fmt": data_conf["fmt"]
    }
    load_conf.update(data_conf["loader"])
    # splat order normalized to match the sibling workers
    # (load_conf first, then the per-split options)
    trn_loader = aps_dataloader(train=True,
                                distributed=True,
                                batch_size=args.batch_size // num_process,
                                **load_conf,
                                **data_conf["train"])
    dev_loader = aps_dataloader(train=False,
                                distributed=False,
                                batch_size=args.batch_size //
                                args.dev_batch_factor,
                                **load_conf,
                                **data_conf["valid"])
    trainer.run(trn_loader,
                dev_loader,
                num_epochs=args.epochs,
                eval_interval=args.eval_interval)
def start_trainer(trainer: str,
                  conf: Dict,
                  nnet: nn.Module,
                  args: Namespace,
                  reduction_tag: str = "none",
                  other_loader_conf: Dict = None):
    """
    Run the instance of the aps Trainer.

    Args:
        trainer: name of the trainer class to instantiate
        conf: full training configuration (task/trainer/data sections)
        nnet: network module to be wrapped by the task
        args: parsed command-line namespace
        reduction_tag: tag used by the trainer for loss reduction
        other_loader_conf: extra dataloader options merged on top of the
            shared ones (NOTE(review): default is None despite the Dict
            annotation — treated as "no extras")

    Returns:
        The constructed trainer instance after training finishes.
        (BUGFIX: the original annotated `-> None` while returning the
        trainer; the wrong annotation has been removed.)

    Raises:
        RuntimeError: if, in distributed mode, the number of devices
            disagrees with the backend world size
    """
    is_distributed = args.distributed != "none"
    if is_distributed:
        # init torch/horovod backend
        distributed.init(args.distributed)
        rank = distributed.rank()
    else:
        rank = None
    task = aps_task(conf["task"], nnet, **conf["task_conf"])
    TrainerClass = aps_trainer(args.trainer, distributed=is_distributed)
    # construct trainer
    # torch.distributed.launch will provide
    # environment variables, and requires that you use init_method="env://".
    trainer = TrainerClass(task,
                           rank=rank,
                           device_ids=args.device_ids,
                           checkpoint=args.checkpoint,
                           resume=args.resume,
                           init=args.init,
                           save_interval=args.save_interval,
                           prog_interval=args.prog_interval,
                           tensorboard=args.tensorboard,
                           reduction_tag=reduction_tag,
                           **conf["trainer_conf"])
    # save cmd options (rank 0 in distributed mode, or single-process)
    if rank in [0, None]:
        conf["cmd_args"] = vars(args)
        with open(f"{args.checkpoint}/train.yaml", "w") as f:
            yaml.dump(conf, f)
    # check if #devices == world_size
    if is_distributed:
        num_process = len(args.device_ids.split(","))
        if num_process != distributed.world_size():
            raise RuntimeError(
                f"Number of process != world size: {num_process} " +
                f"vs {distributed.world_size()}")
    else:
        num_process = 1
    data_conf = conf["data_conf"]
    # loader options shared between train/valid splits
    loader_conf = {
        "fmt": data_conf["fmt"],
        "num_workers": args.num_workers // num_process
    }
    loader_conf.update(data_conf["loader"])
    if other_loader_conf:
        loader_conf.update(other_loader_conf)
    # training batches are split across processes; dev runs on one process
    trn_loader = aps_dataloader(train=True,
                                distributed=is_distributed,
                                max_batch_size=args.batch_size // num_process,
                                **loader_conf,
                                **data_conf["train"])
    dev_loader = aps_dataloader(train=False,
                                distributed=False,
                                max_batch_size=args.batch_size //
                                args.dev_batch_factor,
                                **loader_conf,
                                **data_conf["valid"])
    trainer.run(trn_loader,
                dev_loader,
                num_epochs=args.epochs,
                eval_interval=args.eval_interval)
    return trainer