def __init__(self, sents_src, sents_tgt, vocab_path):
     # __init__ typically loads all the data up front.
     super(BertDataset, self).__init__()
     self.sents_src = sents_src
     self.sents_tgt = sents_tgt
     self.word2idx = load_chinese_base_vocab(vocab_path)
     self.idx2word = {idx: word for word, idx in self.word2idx.items()}
     self.tokenizer = Tokenizer(self.word2idx)
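# A minimal sketch of the remaining Dataset protocol, assuming, as in
# bert_seq2seq, that Tokenizer.encode(src, tgt) returns a pair of
# (token_ids, token_type_ids); these two methods are illustrative,
# not the author's original code.
def __getitem__(self, i):
    # Encode source and target as one BERT-style pair: [CLS] src [SEP] tgt [SEP]
    token_ids, token_type_ids = self.tokenizer.encode(
        self.sents_src[i], self.sents_tgt[i])
    return {"token_ids": token_ids, "token_type_ids": token_type_ids}

def __len__(self):
    return len(self.sents_src)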
def read_corpus(dir_path, vocab_path):
    sents_src = []
    sents_tgt = []
    word2idx = load_chinese_base_vocab(vocab_path)
    tokenizer = Tokenizer(word2idx)
    files = os.listdir(dir_path)
    for file1 in files:
        # Build the full path first; os.path.isdir on the bare file name
        # would be resolved against the working directory, not dir_path.
        file_path = os.path.join(dir_path, file1)
        if not os.path.isdir(file_path):
            print(file_path)
            if not file_path.endswith(".csv"):
                continue
            df = pd.read_csv(file_path)
            for index, row in df.iterrows():
                if not isinstance(row[0], str) or not isinstance(row[3], str):
                    continue
                if len(row[0]) > 8 or len(row[0]) < 2:
                    # Skip poems whose titles are too long or too short.
                    continue
                if len(row[0].split(" ")) > 1:
                    # The title contains a space; keep only the part before it.
                    row[0] = row[0].split(" ")[0]
                encode_text = tokenizer.encode(row[3])[0]
                if word2idx["[UNK]"] in encode_text:
                    # Skip poems containing characters outside the vocabulary.
                    continue
                if len(row[3]) == 24 and (row[3][5] == ","
                                          or row[3][5] == "。"):
                    # 五言绝句: five-character quatrain (4 lines x 5 chars)
                    sents_src.append(row[0] + "##" + "五言绝句")
                    sents_tgt.append(row[3])
                elif len(row[3]) == 32 and (row[3][7] == ","
                                            or row[3][7] == "。"):
                    # 七言绝句: seven-character quatrain (4 lines x 7 chars)
                    sents_src.append(row[0] + "##" + "七言绝句")
                    sents_tgt.append(row[3])
                elif len(row[3]) == 48 and (row[3][5] == ","
                                            or row[3][5] == "。"):
                    # 五言律诗: five-character regulated verse (8 lines x 5 chars)
                    sents_src.append(row[0] + "##" + "五言律诗")
                    sents_tgt.append(row[3])
                elif len(row[3]) == 64 and (row[3][7] == ","
                                            or row[3][7] == "。"):
                    # 七言律诗: seven-character regulated verse (8 lines x 7 chars)
                    sents_src.append(row[0] + "##" + "七言律诗")
                    sents_tgt.append(row[3])

    print("诗句共: " + str(len(sents_src)) + "篇")
    return sents_src, sents_tgt
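# A hedged end-to-end sketch tying the two pieces together. The corpus
# directory and vocab path below are hypothetical placeholders, and the
# DataLoader batching/collation used in training is omitted here.
if __name__ == "__main__":
    sents_src, sents_tgt = read_corpus("./corpus/Poetry", "./vocab.txt")
    dataset = BertDataset(sents_src, sents_tgt, "./vocab.txt")
    print(dataset.sents_src[0], "->", dataset.sents_tgt[0])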
def __init__(self, vocab_path, target_size, model_name="roberta"):
    super(BertClsClassifier, self).__init__()
    self.word2ix = load_chinese_base_vocab(vocab_path)
    self.tokenizer = Tokenizer(self.word2ix)
    self.target_size = target_size
    config = None
    if model_name == "roberta":
        from bert_seq2seq.model.roberta_model import BertModel, BertConfig
        config = BertConfig(len(self.word2ix))
        self.bert = BertModel(config)
    elif model_name == "bert":
        from bert_seq2seq.model.bert_model import BertConfig, BertModel
        config = BertConfig(len(self.word2ix))
        self.bert = BertModel(config)
    else:
        raise ValueError("unsupported model_name: " + model_name)

    self.final_dense = nn.Linear(config.hidden_size, self.target_size)
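# A minimal sketch of a classification forward pass, not the author's code:
# it assumes self.bert returns the per-token hidden states as its first
# output, which may differ from the actual bert_seq2seq signature.
def forward(self, token_ids, labels=None):
    hidden_states, *_ = self.bert(token_ids)
    cls_hidden = hidden_states[:, 0]         # representation of the [CLS] token
    logits = self.final_dense(cls_hidden)    # shape: (batch, target_size)
    if labels is not None:
        loss = nn.functional.cross_entropy(logits, labels)
        return logits, loss
    return logits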
    def __init__(self, vocab_path, model_name="roberta"):
        super(Seq2SeqModel, self).__init__()
        self.word2ix = load_chinese_base_vocab(vocab_path)
        self.tokenizer = Tokenizer(self.word2ix)
        config = None
        if model_name == "roberta":
            from bert_seq2seq.model.roberta_model import BertModel, BertConfig, BertLMPredictionHead
            config = BertConfig(len(self.word2ix))
            self.bert = BertModel(config)
            # The LM head shares its weights with the input word embeddings.
            self.decoder = BertLMPredictionHead(
                config, self.bert.embeddings.word_embeddings.weight)
        elif model_name == "bert":
            from bert_seq2seq.model.bert_model import BertConfig, BertModel, BertLMPredictionHead
            config = BertConfig(len(self.word2ix))
            self.bert = BertModel(config)
            self.decoder = BertLMPredictionHead(
                config, self.bert.embeddings.word_embeddings.weight)
        else:
            raise ValueError("unsupported model_name: " + model_name)

        self.hidden_dim = config.hidden_size
        self.vocab_size = config.vocab_size
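# A hedged usage sketch for poem generation. bert_seq2seq exposes beam-search
# generation on its seq2seq model, but the exact method name and arguments
# below (generate, beam_size) are assumptions and may differ in your version.
if __name__ == "__main__":
    model = Seq2SeqModel("./vocab.txt", model_name="roberta")
    # The input format matches what read_corpus produces: "title##form"
    print(model.generate("春日##五言绝句", beam_size=3))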