Example #1
import os
import pickle

# load_config, LMDataset, LM and print_time_info are project-level helpers;
# their imports are omitted in the original snippet.


def train(args):
    # Read the training configuration stored alongside the model.
    config = load_config(args.model_dir)

    # Build the training dataset and its vocabulary from the raw training file.
    train_dataset = LMDataset(config["train_file"],
                              vocab_file=config["vocab_file"])

    # Persist the vocabulary so validation and test runs reuse the same mapping.
    vocab_dump_path = os.path.join(args.model_dir, "vocab.pkl")
    with open(vocab_dump_path, 'wb') as fp:
        pickle.dump(train_dataset.vocab, fp)

    # The validation set reuses the dumped vocabulary instead of rebuilding it.
    valid_dataset = LMDataset(config["valid_file"], vocab_dump=vocab_dump_path)

    config["vocab_size"] = len(train_dataset.vocab)
    model = LM(config, args.model_dir)

    # Optionally resume training from a specific checkpoint.
    if args.epoch is not None:
        print_time_info("Loading checkpoint {} from model_dir".format(
            args.epoch))
        model.load_model(args.model_dir, args.epoch)

    # Train, periodically evaluating on the validation set.
    model.train(epochs=config["train_epochs"],
                batch_size=config["batch_size"],
                data_engine=train_dataset,
                valid_data_engine=valid_dataset,
                train_decoder_epochs=config.get("train_decoder_epochs", 0),
                max_iter_per_epoch=config.get("max_iter_per_epoch", 100000))
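
A minimal sketch of how train() might be wired to a command line, assuming an argparse interface; the flag names model_dir and epoch are inferred from the attribute accesses above, and the real entry point may differ:

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    # Hypothetical CLI; only model_dir and --epoch are actually read by train().
    parser.add_argument("model_dir", help="directory holding the config and checkpoints")
    parser.add_argument("--epoch", type=int, default=None,
                        help="checkpoint epoch to resume from (optional)")
    train(parser.parse_args())
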
Example #2
import os

# load_config, DATASETS, LM and print_time_info are project-level helpers;
# their imports are omitted in the original snippet.


def test(args):
    config = load_config(args.model_dir)
    # Pick the dataset class by name; "text" is the default.
    dataset_cls = DATASETS[config.get("dataset_cls", "text")]

    # Vocabulary dumped by train(); reused here so token ids match.
    vocab_dump_path = os.path.join(args.model_dir, "vocab.pkl")

    # Prefer a test file passed on the command line; otherwise use the config.
    test_file = config["test_file"] if len(args.test_file) == 0 else args.test_file
    test_dataset = dataset_cls(test_file,
                               vocab_dump=vocab_dump_path,
                               **config.get("dataset_args", {}))

    config["vocab_size"] = len(test_dataset.vocab)
    model = LM(config, args.model_dir)

    # Load the requested checkpoint, or the most recent one if no epoch is given.
    if args.epoch is not None:
        print_time_info("Loading checkpoint {} from model_dir".format(
            args.epoch))
        epoch = model.load_model(args.model_dir, args.epoch)
    else:
        print_time_info("Loading last checkpoint from model_dir")
        epoch = model.load_model(args.model_dir)

    # Evaluate on the test set and report the loss.
    loss = model.test(batch_size=config["batch_size"],
                      data_engine=test_dataset)
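
For completeness, a hedged sketch of calling test() programmatically; argparse.Namespace stands in for the parsed CLI arguments, the attribute names (model_dir, epoch, test_file) are inferred from the body above, and the model path is purely illustrative:

from argparse import Namespace

# Evaluate the latest checkpoint; an empty test_file makes test() fall back
# to config["test_file"]. "models/my_lm" is a hypothetical directory.
test(Namespace(model_dir="models/my_lm", epoch=None, test_file=""))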