def train(args):
    config = load_config(args.model_dir)

    # Build the training set and dump its vocabulary so the validation set
    # (and later the test set) is built with exactly the same token mapping.
    train_dataset = LMDataset(config["train_file"],
                              vocab_file=config["vocab_file"])
    vocab_dump_path = os.path.join(args.model_dir, "vocab.pkl")
    with open(vocab_dump_path, 'wb') as fp:
        pickle.dump(train_dataset.vocab, fp)
    valid_dataset = LMDataset(config["valid_file"],
                              vocab_dump=vocab_dump_path)
    config["vocab_size"] = len(train_dataset.vocab)

    model = LM(config, args.model_dir)
    # Optionally resume from a specific checkpoint epoch.
    if args.epoch is not None:
        print_time_info("Loading checkpoint {} from model_dir".format(
            args.epoch))
        model.load_model(args.model_dir, args.epoch)

    model.train(epochs=config["train_epochs"],
                batch_size=config["batch_size"],
                data_engine=train_dataset,
                valid_data_engine=valid_dataset,
                train_decoder_epochs=config.get("train_decoder_epochs", 0),
                max_iter_per_epoch=config.get("max_iter_per_epoch", 100000))
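
# For reference, a minimal config satisfying the keys read in train() above and
# test() below might look like the sketch here. This is an assumption for
# illustration only; the actual schema is whatever load_config and the repo's
# own config files define.
#
# {
#     "train_file": "data/train.txt",
#     "valid_file": "data/valid.txt",
#     "test_file": "data/test.txt",
#     "vocab_file": "data/vocab.txt",
#     "dataset_cls": "text",
#     "dataset_args": {},
#     "train_epochs": 10,
#     "batch_size": 32,
#     "train_decoder_epochs": 0,
#     "max_iter_per_epoch": 100000
# }
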
def test(args):
    config = load_config(args.model_dir)

    # Resolve the dataset class from the config (defaulting to the plain text
    # dataset) and reuse the vocabulary dumped during training.
    dataset_cls = DATASETS[config.get("dataset_cls", "text")]
    vocab_dump_path = os.path.join(args.model_dir, "vocab.pkl")
    # A test file passed on the command line overrides the one in the config.
    test_file = config["test_file"] if len(
        args.test_file) == 0 else args.test_file
    test_dataset = dataset_cls(test_file,
                               vocab_dump=vocab_dump_path,
                               **(config.get("dataset_args", {})))
    config["vocab_size"] = len(test_dataset.vocab)

    model = LM(config, args.model_dir)
    # Load the requested checkpoint, or the last one if no epoch is given.
    if args.epoch is not None:
        print_time_info("Loading checkpoint {} from model_dir".format(
            args.epoch))
        epoch = model.load_model(args.model_dir, args.epoch)
    else:
        print_time_info("Loading last checkpoint from model_dir")
        epoch = model.load_model(args.model_dir)

    loss = model.test(batch_size=config["batch_size"],
                      data_engine=test_dataset)
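
# A minimal entry-point sketch showing how train()/test() could be wired to a
# command line. The attribute names (model_dir, epoch, test_file) are the ones
# the functions above actually read; the --test mode flag and the argparse
# defaults are assumptions, not necessarily this repo's real CLI.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("model_dir", type=str,
                        help="directory containing the config and checkpoints")
    parser.add_argument("--test", action="store_true",
                        help="run evaluation instead of training")
    parser.add_argument("--epoch", type=int, default=None,
                        help="checkpoint epoch to load (default: latest)")
    parser.add_argument("--test_file", type=str, default="",
                        help="override the test file from the config")
    args = parser.parse_args()

    if args.test:
        test(args)
    else:
        train(args)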