template_folder="templates", static_folder="./", static_url_path="", ) if "MODEL_DIR" not in os.environ: print("MODEL_DIR must be speicified before launching server") exit(1) model_dir = os.environ["MODEL_DIR"] src_tokenizer = CharTokenizer() src_tokenizer.load_vocab(os.path.join(model_dir, "src_vocab.json")) trg_vocab = Vocab() trg_vocab.load(os.path.join(model_dir, "trg_vocab.json")) model = ModelInterface.load_from_checkpoint( os.path.join(model_dir, "checkpoint.pt"), src_vocab=src_tokenizer.vocab, trg_vocab=trg_vocab, model_name="transformer", ).to("cuda" if torch.cuda.is_available() else "cpu") model = model.eval() @app.route("/", methods=["GET"]) def index(): return render_template("index.html")
parser.add_argument("--src_vocab_path", type=str, required=True, help="白话文词表路径") parser.add_argument("--trg_vocab_path", type=str, required=True, help="文言文词表路径") parser = ModelInterface.add_trainer_args(parser) args = parser.parse_args() if args.token_type == "char": src_tokenizer = CharTokenizer() elif args.token_type == "token": src_tokenizer = VernacularTokenTokenizer() src_tokenizer.load_vocab(args.src_vocab_path) trg_vocab = Vocab() trg_vocab.load(args.trg_vocab_path) model = ModelInterface.load_from_checkpoint( args.checkpoint_path, src_vocab=src_tokenizer.vocab, trg_vocab=trg_vocab, ) model = model.eval() while True: sent = input("原始白话文:") input_token_list = src_tokenizer.tokenize(sent, map_to_id=True) res_sent = model.inference( torch.LongTensor([input_token_list]), torch.LongTensor([len(input_token_list)]),
do_inference:
    vocab.filter_chars_by_cnt(min_cnt=2)

    # Report how many word tokens were removed by the frequency filter.
    filtered_num = unfiltered_vocab_size - vocab.size()
    logger.info('After filtering {} tokens, the final vocab size is {}'.format(
        filtered_num, vocab.size()))

    # Report how many character tokens were removed.
    filtered_num = unfiltered_char_size - vocab.get_char_vocab_size()
    logger.info('After filtering {} tokens, the final char vocab size is {}'.format(
        filtered_num, vocab.get_char_vocab_size()))

    import os
    vocab_file = 'first_third_baihuo_vocab.txt'

    # Reuse the cached vocabulary file if it exists; otherwise build the vocabulary
    # from the pretrained word embeddings and persist it for later runs.
    if os.path.exists(vocab_file):
        vocab.load_from_file(vocab_file)
    if not os.path.exists(vocab_file):
        vocab.load_pretrained_embeddings('/home/wujindou/sgns.merge.word')
        vocab.save()
        # The cache stores one "token<TAB>id" pair per line.
        writer = open(vocab_file, 'a+', encoding='utf-8')
        for word, idx in vocab.token2id.items():
            writer.write(word + '\t' + str(idx) + '\n')
        writer.close()

    logger.info('after load embedding vocab size is {}'.format(vocab.size()))
    import sys
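A small illustration of the cache format produced by the writer loop above (one `token<TAB>id` pair per line). The class's own `load_from_file` remains the real loader; this sketch only shows how the file can be inspected.

# Sketch: read back the token-id cache written above ("token\tid" per line).
token2id = {}
with open('first_third_baihuo_vocab.txt', encoding='utf-8') as f:
    for line in f:
        token, idx = line.rstrip('\n').split('\t')
        token2id[token] = int(idx)
print(len(token2id), 'cached tokens')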
class AncientPairDataModule(pl.LightningDataModule):
    def __init__(self, batch_size: int, data_dir: str, workers: int):
        super().__init__()
        self.data_dir = Path(data_dir)
        self.batch_size = batch_size
        self.workers = workers
        if not self.data_dir.exists():
            raise ValueError("Directory or file doesn't exist")
        if not self.data_dir.is_dir():
            raise ValueError("`data_dir` must be a path to a directory")

    @classmethod
    def add_data_args(cls, parent_parser: argparse.ArgumentParser):
        parser = parent_parser.add_argument_group("data")
        parser.add_argument("--data_dir", type=str, default="./data",
                            help="Directory holding the dataset")
        parser.add_argument("--batch_size", type=int, default=128,
                            help="Number of examples per batch")
        parser.add_argument("--workers", type=int, default=0,
                            help="Number of DataLoader worker processes")
        cls.parser = parser
        return parent_parser

    def prepare_data(self):
        """The dataset files are prepared ahead of time."""

    def setup(self, stage: Optional[str] = None):
        # Load the source/target vocabularies and build the train/valid/test datasets.
        self.src_vocab = Vocab()
        self.src_vocab.load(str(self.data_dir / "src_vocab.json"))
        self.src_vocab_size = len(self.src_vocab)
        self.trg_vocab = Vocab()
        self.trg_vocab.load(str(self.data_dir / "trg_vocab.json"))
        self.trg_vocab_size = len(self.trg_vocab)

        self.train_dataset = AncientPairDataset(
            str(self.data_dir / "train.tsv"), 128, self.src_vocab, self.trg_vocab,
        )
        self.valid_dataset = AncientPairDataset(
            str(self.data_dir / "valid.tsv"), 128, self.src_vocab, self.trg_vocab,
        )
        self.test_dataset = AncientPairDataset(
            str(self.data_dir / "test.tsv"), 128, self.src_vocab, self.trg_vocab,
        )
        logger.info(
            f"Dataset sizes:\n\t"
            f"train: {len(self.train_dataset)}, "
            f"valid: {len(self.valid_dataset)}, "
            f"test: {len(self.test_dataset)}",
        )

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.workers,
        )

    def val_dataloader(self):
        return DataLoader(
            self.valid_dataset,
            batch_size=self.batch_size,
            num_workers=self.workers,
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=self.workers,
        )
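A minimal usage sketch of the DataModule, assuming `./data` already contains the vocabulary JSONs and TSV splits that `setup()` loads; the Trainer wiring with a LightningModule (such as `ModelInterface`, constructed elsewhere) is an assumption and shown only as a comment.

# Sketch: build the DataModule and inspect one training batch.
# Assumes ./data holds src_vocab.json, trg_vocab.json, train.tsv, valid.tsv, test.tsv.
import pytorch_lightning as pl

dm = AncientPairDataModule(batch_size=128, data_dir="./data", workers=0)
dm.setup()                                 # loads vocabularies and the three splits
batch = next(iter(dm.train_dataloader()))  # one batch from train.tsv

# With a LightningModule (e.g. ModelInterface) the usual Trainer wiring applies:
#   trainer = pl.Trainer(max_epochs=10)
#   trainer.fit(model, datamodule=dm)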