예제 #1
0
def main():
    """
    Command-line entry point for Joey NMT.

    Parses the positional ``mode`` (one of train / test / translate) and
    ``config_path`` arguments together with the optional ``--ckpt``,
    ``--output_path`` and ``--save_attention`` flags, then hands off to
    the matching top-level function.
    """
    parser = argparse.ArgumentParser("Joey NMT")
    parser.add_argument("mode",
                        choices=["train", "test", "translate"],
                        help="train a model or test or translate")
    parser.add_argument("config_path",
                        type=str,
                        help="path to YAML config file")
    parser.add_argument("--ckpt",
                        type=str,
                        help="checkpoint for prediction")
    parser.add_argument("--output_path",
                        type=str,
                        help="path for saving translation output")
    parser.add_argument("--save_attention",
                        action="store_true",
                        help="save attention visualizations")
    opts = parser.parse_args()

    mode = opts.mode
    if mode == "train":
        train(cfg_file=opts.config_path)
        return
    if mode == "test":
        test(cfg_file=opts.config_path,
             ckpt=opts.ckpt,
             output_path=opts.output_path,
             save_attention=opts.save_attention)
        return
    if mode == "translate":
        translate(cfg_file=opts.config_path,
                  ckpt=opts.ckpt,
                  output_path=opts.output_path)
        return

    # argparse ``choices`` already rejects any other mode; kept as a guard.
    raise ValueError("Unknown mode")
예제 #2
0
def train(cfg_file: str) -> None:
    """
    Main training function. After training, also test on test data if given.

    Loads the YAML config, creates the model directory and logger, seeds
    the RNGs, loads data and vocabularies, builds the model and trains it.
    Finally it calls ``test`` with the best checkpoint on the dev data
    (and the test data, if available).

    The config copy and the vocabulary files are written to the directory
    returned by ``make_model_dir``. The best checkpoint and the hypothesis
    output path are also resolved inside that directory.

    :param cfg_file: path to configuration yaml file
    """
    cfg = load_config(cfg_file)
    train_cfg = cfg["training"]

    # make logger
    model_dir = make_model_dir(train_cfg["model_dir"],
                               overwrite=train_cfg.get("overwrite", False))
    _ = make_logger(model_dir, mode="train")    # version string returned
    # TODO: save version number in model checkpoints

    # set the random seed
    set_seed(seed=train_cfg.get("random_seed", 42))

    # load the data
    train_data, dev_data, test_data, src_vocab, trg_vocab = load_data(
        data_cfg=cfg["data"])

    # build an encoder-decoder model
    model = build_model(cfg["model"], src_vocab=src_vocab, trg_vocab=trg_vocab)

    # for training management, e.g. early stopping and model selection
    trainer = TrainManager(model=model, config=cfg)

    # store copy of original training config in model dir
    shutil.copy2(cfg_file, os.path.join(model_dir, "config.yaml"))

    # log all entries of config
    log_cfg(cfg)

    log_data_info(train_data=train_data, valid_data=dev_data,
                  test_data=test_data, src_vocab=src_vocab, trg_vocab=trg_vocab)

    logger.info(str(model))

    # store the vocabs in the directory actually created above (not the raw
    # config value), so they sit next to the config copy and checkpoints
    src_vocab.to_file(os.path.join(model_dir, "src_vocab.txt"))
    trg_vocab.to_file(os.path.join(model_dir, "trg_vocab.txt"))

    # train the model
    trainer.train_and_validate(train_data=train_data, valid_data=dev_data)

    # predict with the best model on validation and test
    # (if test data is available)
    best_iter = trainer.stats.best_ckpt_iter
    ckpt = os.path.join(model_dir, "{}.ckpt".format(best_iter))
    output_path = os.path.join(model_dir, "{:08d}.hyps".format(best_iter))
    datasets_to_test = {"dev": dev_data, "test": test_data,
                        "src_vocab": src_vocab, "trg_vocab": trg_vocab}
    test(cfg_file, ckpt=ckpt, output_path=output_path,
         datasets=datasets_to_test)
예제 #3
0
def train(cfg_file: str) -> None:
    """
    Main training function. After training, also test on test data if given.

    Loads the YAML config, seeds the RNGs, loads the data splits and the
    per-field vocabularies, builds the model and trains it. Finally it calls
    ``test`` with the best checkpoint.

    The config copy and one ``<field>_vocab.txt`` per vocabulary are written
    to ``trainer.model_dir``. The best checkpoint and the hypothesis output
    path are resolved there too, so all artefacts share one directory.

    :param cfg_file: path to configuration yaml file
    """
    cfg = load_config(cfg_file)
    train_cfg = cfg["training"]
    data_cfg = cfg["data"]

    # set the random seed
    set_seed(seed=train_cfg.get("random_seed", 42))

    # load the data
    data = load_data(data_cfg)
    train_data = data["train_data"]
    dev_data = data["dev_data"]
    test_data = data["test_data"]
    vocabs = data["vocabs"]

    # build an encoder-decoder model
    model = build_model(cfg["model"], vocabs=vocabs)

    # for training management, e.g. early stopping and model selection
    trainer = TrainManager(model=model, config=cfg)
    model_dir = trainer.model_dir

    # store copy of original training config in model dir
    shutil.copy2(cfg_file, join(model_dir, "config.yaml"))

    # log all entries of config
    log_cfg(cfg, trainer.logger)

    log_data_info(
        train_data=train_data,
        valid_data=dev_data,
        test_data=test_data,
        vocabs=vocabs,
        logging_function=trainer.logger.info)

    trainer.logger.info(str(model))

    # store the vocabs in the trainer's model dir (the same directory used
    # for the config copy and checkpoints), not the raw config value
    for field_name, vocab in vocabs.items():
        vocab.to_file(join(model_dir, field_name + "_vocab.txt"))

    # train the model
    trainer.train_and_validate(train_data=train_data, valid_data=dev_data)

    # predict with the best model on validation (and test, if available)
    best_iter = trainer.best_ckpt_iteration
    ckpt = join(model_dir, str(best_iter) + ".ckpt")
    output_path = join(model_dir, "{:08d}.hyps".format(best_iter))
    test(cfg_file, ckpt=ckpt, output_path=output_path, logger=trainer.logger)
예제 #4
0
def train(cfg_file: str) -> None:
    """
    Main training function. After training, also test on test data if given.

    Loads the YAML config, seeds the RNGs, loads data and vocabularies,
    builds the model and trains it. Finally it calls ``test`` with the
    best checkpoint on the dev data (and the test data, if available).

    The config copy and the vocabulary files are written to
    ``trainer.model_dir``. The best checkpoint and the hypothesis output
    path are resolved there too.

    :param cfg_file: path to configuration yaml file
    """
    cfg = load_config(cfg_file)

    # set the random seed
    set_seed(seed=cfg["training"].get("random_seed", 42))

    # load the data
    train_data, dev_data, test_data, src_vocab, trg_vocab = load_data(
        data_cfg=cfg["data"])

    # build an encoder-decoder model
    model = build_model(cfg["model"], src_vocab=src_vocab, trg_vocab=trg_vocab)

    # for training management, e.g. early stopping and model selection
    trainer = TrainManager(model=model, config=cfg)
    model_dir = trainer.model_dir

    # store copy of original training config in model dir
    shutil.copy2(cfg_file, os.path.join(model_dir, "config.yaml"))

    # log all entries of config
    log_cfg(cfg, trainer.logger)

    log_data_info(train_data=train_data,
                  valid_data=dev_data,
                  test_data=test_data,
                  src_vocab=src_vocab,
                  trg_vocab=trg_vocab,
                  logging_function=trainer.logger.info)

    trainer.logger.info(str(model))

    # store the vocabs in the trainer's model dir (the same directory used
    # for the config copy and checkpoints), not the raw config value
    src_vocab.to_file(os.path.join(model_dir, "src_vocab.txt"))
    trg_vocab.to_file(os.path.join(model_dir, "trg_vocab.txt"))

    # train the model
    trainer.train_and_validate(train_data=train_data, valid_data=dev_data)

    # predict with the best model on validation and test
    # (if test data is available)
    best_iter = trainer.best_ckpt_iteration
    ckpt = os.path.join(model_dir, "{}.ckpt".format(best_iter))
    output_path = os.path.join(model_dir, "{:08d}.hyps".format(best_iter))
    test(cfg_file, ckpt=ckpt, output_path=output_path, logger=trainer.logger)
예제 #5
0
def train(cfg_file: str) -> None:
    """
    Main training function. After training, also test on test data if given.

    If ``data.shard_data`` is set, the training corpus is first split into
    ``data.n_shards`` shards under ``data.shard_path``. When ``data.n_shards``
    is at least 1, the full training split is not loaded. Training is then
    driven by a ``ShardedEpochDatasetIterator`` instead.

    The config copy and vocabulary files are written to
    ``trainer.model_dir``. The best checkpoint and the hypothesis output
    path are resolved there too.

    :param cfg_file: path to configuration yaml file
    :raises ValueError: if ``data.shard_data`` is set but ``data.n_shards``
        is missing or not positive
    """
    cfg = load_config(cfg_file)
    data_cfg = cfg["data"]
    train_cfg = cfg["training"]

    # set the random seed
    set_seed(seed=train_cfg.get("random_seed", 42))

    n_shards = data_cfg.get("n_shards", 0)
    do_shard = data_cfg.get("shard_data", False)
    # the whole training split is only loaded when no shards are requested
    load_train_whole = n_shards < 1

    # ``shard_path`` is only used for sharding, so only require it (and
    # create its parent directory) when sharding is actually in play
    if do_shard or not load_train_whole:
        shards_dir = os.path.dirname(data_cfg["shard_path"])
        # dirname is "" for a bare file name; os.makedirs("") would raise
        if shards_dir:
            os.makedirs(shards_dir, exist_ok=True)

    if do_shard:
        # explicit check instead of ``assert`` so it is not stripped by -O
        if not n_shards > 0:
            raise ValueError("n_shards needs to exist and be at least 1")
        shard_data(path=data_cfg["train"],
                   src_lang=data_cfg["src"],
                   tgt_lang=data_cfg["trg"],
                   n_shards=n_shards,
                   shard_path=data_cfg["shard_path"])

    # load the data (the training split is skipped when shards are used)
    (train_data, dev_data, test_data, src_vocab, trg_vocab, src_field,
     trg_field) = load_data(data_cfg=data_cfg, load_train=load_train_whole)

    if load_train_whole:
        sharded_iterator = None
    else:
        max_len = data_cfg["max_sent_length"]
        sharded_iterator = ShardedEpochDatasetIterator(
            n_shards=n_shards,
            percent_to_sample=data_cfg.get("percent_to_sample_from_shard",
                                           1.0),
            data_path=data_cfg["train"],
            shard_path=data_cfg["shard_path"],
            extensions=(data_cfg["src"], data_cfg["trg"]),
            fields=(src_field, trg_field),
            n_epochs=train_cfg["epochs"],
            # keep only examples whose src and trg both fit within max_len
            filter_pred=lambda ex: (len(vars(ex)["src"]) <= max_len
                                    and len(vars(ex)["trg"]) <= max_len))

    # build an encoder-decoder model
    model = build_model(cfg["model"], src_vocab=src_vocab, trg_vocab=trg_vocab)

    # for training management, e.g. early stopping and model selection
    trainer = TrainManager(model=model, config=cfg)
    model_dir = trainer.model_dir

    # store copy of original training config in model dir
    shutil.copy2(cfg_file, os.path.join(model_dir, "config.yaml"))

    # log all entries of config
    log_cfg(cfg, trainer.logger)
    # data statistics are only logged when the training split was loaded
    if load_train_whole:
        log_data_info(train_data=train_data,
                      valid_data=dev_data,
                      test_data=test_data,
                      src_vocab=src_vocab,
                      trg_vocab=trg_vocab,
                      logging_function=trainer.logger.info)

    trainer.logger.info(str(model))

    # store the vocabs in the trainer's model dir (the same directory used
    # for the config copy and checkpoints), not the raw config value
    src_vocab.to_file(os.path.join(model_dir, "src_vocab.txt"))
    trg_vocab.to_file(os.path.join(model_dir, "trg_vocab.txt"))

    # train the model
    trainer.train_and_validate(train_data=train_data,
                               valid_data=dev_data,
                               sharded_iterator=sharded_iterator)

    # predict with the best model on validation and test
    # (if test data is available)
    best_iter = trainer.best_ckpt_iteration
    ckpt = os.path.join(model_dir, "{}.ckpt".format(best_iter))
    output_path = os.path.join(model_dir, "{:08d}.hyps".format(best_iter))
    test(cfg_file, ckpt=ckpt, output_path=output_path, logger=trainer.logger)