Beispiel #1
0
    def __init__(self, vocab_path, size="base"):
        """Build the T5 model and its tokenizer from a vocabulary file.

        Args:
            vocab_path: Path to the vocabulary file, read with
                ``load_chinese_base_vocab``.
            size: Model size selector. ``"base"`` builds the model from
                ``T5Config()``, ``"small"`` from ``T5SmallConfig()``.

        Raises:
            ValueError: If ``size`` is neither ``"base"`` nor ``"small"``.
                (ValueError subclasses Exception, so callers that caught
                the previous generic Exception still catch it.)
        """
        super().__init__()
        if size == "base":
            config = T5Config()
        elif size == "small":
            config = T5SmallConfig()
        else:
            # Report the rejected value so misconfiguration is easy to spot.
            raise ValueError(f"not support this model type: {size!r}")
        self.model = T5ForConditionalGeneration(config)

        self.word2idx = load_chinese_base_vocab(vocab_path)
        self.tokenizer = T5PegasusTokenizer(self.word2idx)
        # Special-token ids: [CLS] serves as begin-of-sequence, [SEP] as
        # end-of-sequence, [UNK] for out-of-vocabulary tokens.
        self.bos_id = self.word2idx["[CLS]"]
        self.eos_id = self.word2idx["[SEP]"]
        self.unk_id = self.word2idx["[UNK]"]
import json
import time
import glob
import bert_seq2seq
from torch.utils.data import Dataset, DataLoader
from bert_seq2seq.t5_ch import T5Model
from bert_seq2seq.tokenizer import T5PegasusTokenizer, load_chinese_base_vocab
from bert_seq2seq.t5_ch import T5Model

# Pretrained Chinese T5 checkpoint files, plus where the fine-tuned weights
# are written (presumably the classical-Chinese translation model, given the
# file name and the corpus directory read below).
vocab_path = "./state_dict/t5-chinese/vocab.txt"
model_path = "./state_dict/t5-chinese/pytorch_model.bin"
model_save_path = "./state_dict/t5_ancient_trans_model.bin"

# Training hyperparameters.
batch_size, lr = 8, 1e-5

# Vocabulary mapping and tokenizer, built once at import time.
word2idx = load_chinese_base_vocab(vocab_path)
tokenizer = T5PegasusTokenizer(word2idx)


def read_corpus():
    """
    读原始数据
    """
    src = []
    tgt = []
    data_path = glob.glob("./corpus/文言文翻译/*")
    for p in data_path:
        dir = p.split("/")[:-1]
        dir = "/".join(dir)
        # print(dir)
        name = p.split("/")[-1]
        if "翻译" in name: