Example #1

import os

import torch
from flask import Flask, render_template

# CharTokenizer, Vocab, and ModelInterface come from the project's own
# modules; their import paths are not shown in the original snippet.

app = Flask(
    __name__,
    template_folder="templates",
    static_folder="./",
    static_url_path="",
)

if "MODEL_DIR" not in os.environ:
    print("MODEL_DIR must be speicified before launching server")
    exit(1)

model_dir = os.environ["MODEL_DIR"]

src_tokenizer = CharTokenizer()
src_tokenizer.load_vocab(os.path.join(model_dir, "src_vocab.json"))

trg_vocab = Vocab()
trg_vocab.load(os.path.join(model_dir, "trg_vocab.json"))

model = ModelInterface.load_from_checkpoint(
    os.path.join(model_dir, "checkpoint.pt"),
    src_vocab=src_tokenizer.vocab,
    trg_vocab=trg_vocab,
    model_name="transformer",
).to("cuda" if torch.cuda.is_available() else "cpu")

model = model.eval()


@app.route("/", methods=["GET"])
def index():
    return render_template("index.html")
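
A translation endpoint is the natural companion to the index route above. The following is a minimal sketch that reuses the tokenize/inference calls shown in Example #2; the /translate route name, the JSON request shape, and the response key are assumptions, not part of the original server:

from flask import jsonify, request


@app.route("/translate", methods=["POST"])
def translate():
    # Assumed request body: {"text": "<vernacular sentence>"}
    sent = request.get_json(force=True)["text"]
    input_token_list = src_tokenizer.tokenize(sent, map_to_id=True)
    res_sent = model.inference(
        torch.LongTensor([input_token_list]),
        torch.LongTensor([len(input_token_list)]),
    )
    return jsonify({"translation": res_sent})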
Example #2

    # The snippet begins mid-function; the parser construction and the
    # --checkpoint_path / --token_type flags used further down are assumed
    # reconstructions.
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint_path", type=str, required=True,
                        help="path to the model checkpoint")
    parser.add_argument("--token_type", type=str, default="char",
                        choices=["char", "token"],
                        help="source tokenization scheme")
    parser.add_argument("--src_vocab_path", type=str, required=True,
                        help="path to the vernacular Chinese vocab")
    parser.add_argument("--trg_vocab_path", type=str, required=True,
                        help="path to the classical Chinese vocab")

    parser = ModelInterface.add_trainer_args(parser)

    args = parser.parse_args()

    if args.token_type == "char":
        src_tokenizer = CharTokenizer()
    elif args.token_type == "token":
        src_tokenizer = VernacularTokenTokenizer()
    else:
        raise ValueError(f"Unsupported token_type: {args.token_type}")

    src_tokenizer.load_vocab(args.src_vocab_path)

    trg_vocab = Vocab()
    trg_vocab.load(args.trg_vocab_path)

    model = ModelInterface.load_from_checkpoint(
        args.checkpoint_path,
        src_vocab=src_tokenizer.vocab,
        trg_vocab=trg_vocab,
    )

    model = model.eval()
    while True:
        sent = input("Source vernacular text: ")

        input_token_list = src_tokenizer.tokenize(sent, map_to_id=True)
        res_sent = model.inference(
            torch.LongTensor([input_token_list]),
            torch.LongTensor([len(input_token_list)]),
        )
        print(res_sent)
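
inference takes a batch of id sequences plus a tensor of their lengths, so translating several sentences at once requires padding them to a common length first. A minimal sketch, assuming a pad id of 0 (the project's actual pad index is not shown in these examples):

import torch

def pad_batch(id_lists, pad_id=0):
    # Pad variable-length id lists into one rectangular LongTensor and
    # return the lengths tensor that inference() expects alongside it.
    lengths = torch.LongTensor([len(ids) for ids in id_lists])
    batch = torch.full((len(id_lists), int(lengths.max())), pad_id, dtype=torch.long)
    for i, ids in enumerate(id_lists):
        batch[i, :len(ids)] = torch.LongTensor(ids)
    return batch, lengths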
Example #3

    if do_inference:
        unfiltered_vocab_size = vocab.size()
        unfiltered_char_size = vocab.get_char_vocab_size()
        vocab.filter_chars_by_cnt(min_cnt=2)
        filtered_num = unfiltered_vocab_size - vocab.size()
        logger.info(
            'After filtering {} tokens, the final vocab size is {}'.format(
                filtered_num, vocab.size()))

        filtered_num = unfiltered_char_size - vocab.get_char_vocab_size()
        logger.info(
            'After filtering {} chars, the final char vocab size is {}'.format(
                filtered_num, vocab.get_char_vocab_size()))

    import os

    vocab_file = 'first_third_baihuo_vocab.txt'
    if os.path.exists(vocab_file):
        vocab.load_from_file(vocab_file)
    else:
        vocab.load_pretrained_embeddings('/home/wujindou/sgns.merge.word')

    if not os.path.exists(vocab_file):
        vocab.save()
        with open(vocab_file, 'a+', encoding='utf-8') as writer:
            for word, idx in vocab.token2id.items():
                writer.write(word + '\t' + str(idx) + '\n')

    logger.info('after loading embeddings, the vocab size is {}'.format(vocab.size()))
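
The writer above emits one token<TAB>id line per vocab entry, so the file can be read back without the project's Vocab class. A stand-alone reader for that format (the function name is ours, not the project's):

def read_vocab_tsv(path):
    # Parse a 'token\tid' file such as first_third_baihuo_vocab.txt
    # back into a token -> id dict.
    token2id = {}
    with open(path, encoding='utf-8') as f:
        for line in f:
            token, idx = line.rstrip('\n').split('\t')
            token2id[token] = int(idx)
    return token2id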
Example #4
class AncientPairDataModule(pl.LightningDataModule):
    def __init__(self, batch_size: int, data_dir: str, workers: int):
        super().__init__()

        self.data_dir = Path(data_dir)
        self.batch_size = batch_size
        self.workers = workers

        if not self.data_dir.exists():
            raise ValueError("`data_dir` doesn't exist")
        if not self.data_dir.is_dir():
            raise ValueError("`data_dir` must be a path to a directory")

    @classmethod
    def add_data_args(cls, parent_parser: argparse.ArgumentParser):
        parser = parent_parser.add_argument_group("data")

        parser.add_argument("--data_dir",
                            type=str,
                            default="./data",
                            help="数据存储路径")
        parser.add_argument("--batch_size",
                            type=int,
                            default=128,
                            help="一个batch的大小")
        parser.add_argument("--workers",
                            type=int,
                            default=0,
                            help="读取dataset的worker数")

        cls.parser = parser

        return parent_parser

    def prepare_data(self):
        """数据已提前准备完成"""

    def setup(self, stage: Optional[str] = None):
        self.src_vocab = Vocab()
        self.src_vocab.load(str(self.data_dir / "src_vocab.json"))
        self.src_vocab_size = len(self.src_vocab)

        self.trg_vocab = Vocab()
        self.trg_vocab.load(str(self.data_dir / "trg_vocab.json"))
        self.trg_vocab_size = len(self.trg_vocab)

        self.train_dataset = AncientPairDataset(
            str(self.data_dir / "train.tsv"),
            128,
            self.src_vocab,
            self.trg_vocab,
        )
        self.valid_dataset = AncientPairDataset(
            str(self.data_dir / "valid.tsv"),
            128,
            self.src_vocab,
            self.trg_vocab,
        )
        self.test_dataset = AncientPairDataset(
            str(self.data_dir / "test.tsv"),
            128,
            self.src_vocab,
            self.trg_vocab,
        )

        logger.info(
            f"Dataset sizes:\n\t"
            f"train: {len(self.train_dataset)}, "
            f"valid: {len(self.valid_dataset)}, "
            f"test: {len(self.test_dataset)}")

    def train_dataloader(self):
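        # Note: shuffle is not enabled here; pass shuffle=True to the
        # DataLoader if train.tsv is not already shuffled.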
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.workers,
        )

    def val_dataloader(self):
        return DataLoader(
            self.valid_dataset,
            batch_size=self.batch_size,
            num_workers=self.workers,
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=self.workers,
        )
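
Wiring the datamodule into a training run is not shown in these examples. A minimal sketch under standard PyTorch Lightning conventions; the ModelInterface constructor arguments are assumptions mirroring the load_from_checkpoint call in Example #1:

import pytorch_lightning as pl

dm = AncientPairDataModule(batch_size=128, data_dir="./data", workers=0)
dm.setup()  # populates src_vocab / trg_vocab before the model needs them

model = ModelInterface(
    src_vocab=dm.src_vocab,   # assumed constructor signature,
    trg_vocab=dm.trg_vocab,   # mirroring Example #1's checkpoint load
    model_name="transformer",
)

trainer = pl.Trainer(max_epochs=10)
trainer.fit(model, datamodule=dm)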