def main():
    ap = argparse.ArgumentParser("Joey NMT")
    ap.add_argument("mode", choices=["train", "test"], help="train a model or test")
    ap.add_argument("config_path", type=str, help="path to YAML config file")
    ap.add_argument("--ckpt", type=str, help="checkpoint for prediction")
    ap.add_argument(
        "--output_path", type=str, help="path for saving translation output"
    )
    ap.add_argument(
        "--gpu_id", type=str, default="0", help="gpu to run your job on"
    )
    args = ap.parse_args()

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id

    if args.mode == "train":
        train(cfg_file=args.config_path)
    elif args.mode == "test":
        test(cfg_file=args.config_path, ckpt=args.ckpt, output_path=args.output_path)
    else:
        raise ValueError("Unknown mode")
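# Example invocations of the CLI defined in main() above. These are
# illustrative sketches only: the script/module name and all paths are
# placeholders, not files shipped with the repository; the argument names
# are exactly those registered on the parser.
#
#   python training.py train configs/example.yaml --gpu_id 0
#   python training.py test configs/example.yaml \
#       --ckpt models/example_model/best.ckpt \
#       --output_path models/example_model/test_hyps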
def train(cfg_file: str) -> None:
    """
    Main training function. After training, also test on test data if given.

    :param cfg_file: path to configuration yaml file
    """
    cfg = load_config(cfg_file)

    # set the random seed
    set_seed(seed=cfg["training"].get("random_seed", 42))

    train_data, dev_data, test_data, gls_vocab, txt_vocab = load_data(
        data_cfg=cfg["data"]
    )

    # build model and load parameters into it
    do_recognition = cfg["training"].get("recognition_loss_weight", 1.0) > 0.0
    do_translation = cfg["training"].get("translation_loss_weight", 1.0) > 0.0
    model = build_model(
        cfg=cfg["model"],
        gls_vocab=gls_vocab,
        txt_vocab=txt_vocab,
        sgn_dim=sum(cfg["data"]["feature_size"])
        if isinstance(cfg["data"]["feature_size"], list)
        else cfg["data"]["feature_size"],
        do_recognition=do_recognition,
        do_translation=do_translation,
    )

    # for training management, e.g. early stopping and model selection
    trainer = TrainManager(model=model, config=cfg)

    # store copy of original training config in model dir
    shutil.copy2(cfg_file, trainer.model_dir + "/config.yaml")

    # log all entries of config
    log_cfg(cfg, trainer.logger)

    log_data_info(
        train_data=train_data,
        valid_data=dev_data,
        test_data=test_data,
        gls_vocab=gls_vocab,
        txt_vocab=txt_vocab,
        logging_function=trainer.logger.info,
    )

    trainer.logger.info(str(model))

    # store the vocabs
    gls_vocab_file = "{}/gls.vocab".format(cfg["training"]["model_dir"])
    gls_vocab.to_file(gls_vocab_file)
    txt_vocab_file = "{}/txt.vocab".format(cfg["training"]["model_dir"])
    txt_vocab.to_file(txt_vocab_file)

    # train the model
    trainer.train_and_validate(train_data=train_data, valid_data=dev_data)

    # delete to speed things up as we don't need training data anymore
    del train_data, dev_data, test_data

    # predict with the best model on validation and test data
    # (if test data is available)
    ckpt = "{}/{}.ckpt".format(trainer.model_dir, trainer.best_ckpt_iteration)
    output_name = "best.IT_{:08d}".format(trainer.best_ckpt_iteration)
    output_path = os.path.join(trainer.model_dir, output_name)
    logger = trainer.logger
    del trainer

    test(cfg_file, ckpt=ckpt, output_path=output_path, logger=logger)
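# Minimal sketch of the YAML keys that train() reads above. Values and the
# "models/example_model" path are illustrative assumptions, not an official
# or complete configuration; remaining fields are consumed by load_data()
# and build_model() and are elided here.
#
#   data:
#       feature_size: 1024          # int, or list of ints summed for sgn_dim
#       ...
#   model:
#       ...
#   training:
#       random_seed: 42
#       model_dir: "models/example_model"
#       recognition_loss_weight: 1.0   # > 0.0 enables the recognition head
#       translation_loss_weight: 1.0   # > 0.0 enables the translation head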