class PreProcess:

    def __init__(self, logger, args, examples, for_prediction=False):

        self.logger = logger
        self.args = args
        self.for_prediction = for_prediction

        reverse_qa = self.args["reverse_qa"]
        self.examples = examples
        if reverse_qa:
            self.reverse_qa()

        self.max_seq_length = self.args["max_seq_length"]

        # self.logger.info("Prepare to build tokenizer ……")
        self.c_token = CToken(
            self.args["vocab_name"], self.args["vocab_format"], self.args["vocab_type"], self.args["do_lowercase"]
        )
        self.engineer_token = EngnieerTokenizer(
            self.args["vocab_name"], self.args["vocab_format"], self.args["vocab_type"], self.args["do_lowercase"]
        )
        self.logger.info("Successfully build tokenizer")
        # Build the tokenizers from the specified vocabulary

        self.func(half_width)
        self.func(lower)
        if self.args["pretrained_model_type"] == "ernie":
            self.func(punctuation_replace_for_ernie)
            self.func(translate_for_ernie)
        self.func(split_unk, self.c_token.vocab)

        self.features = []

    def get_vocab_size(self):
        """
        Return the size of the vocabulary in use.
        """

        return len(self.c_token.vocab)

    def func(self, util_func, vocab=None):
        """
        Apply a text-normalization utility function to the question and answer of every example.
        """

        if vocab is None:
            for i in range(len(self.examples)):
                self.examples[i].question = util_func(self.examples[i].question)
                self.examples[i].answer = util_func(self.examples[i].answer)
        else:
            for i in range(len(self.examples)):
                self.examples[i].question = util_func(self.examples[i].question, vocab)
                self.examples[i].answer = util_func(self.examples[i].answer, vocab)

    def exams_tokenize(self, examples, token_id=2):
        """
        Tokenize the question and answer of every Example in the list and return the results.
        :param examples: list of Examples to process
        :param token_id: return tokens when token_id=1, return ids when token_id=2
        """

        ques_tokens = []
        ans_tokens = []
        ques_ids = []
        ans_ids = []
        for example in examples:
            q_tokens = self.c_token.tokenize(example.question)
            ques_tokens.append(q_tokens)
            a_tokens = self.c_token.tokenize(example.answer)
            ans_tokens.append(a_tokens)
            if token_id == 2:
                q_ids = self.c_token.convert_tokens_to_ids(q_tokens)
                ques_ids.append(q_ids)
                a_ids = self.c_token.convert_tokens_to_ids(a_tokens)
                ans_ids.append(a_ids)
        if token_id == 1:
            return ques_tokens, ans_tokens
        elif token_id == 2:
            return ques_ids, ans_ids

    def save_tokens(self, tokens, file_name, file_format="pickle", file_type="datap"):
        """
        Save the tokens to the specified location.
        """

        save_file(tokens, file_type, file_name, file_format)

    def splice_ques_ans(self, ques_ids, ans_ids, special_char=None):
        """
        Concatenate the question and answer id sequences, and return the maximum sentence length and the total number of tokens.
        """

        vocab = self.c_token.vocab
        if special_char is None:
            special_char = {"CLS": vocab["[CLS]"], "SEP": vocab["[SEP]"],
                            "MASK": vocab["[MASK]"], "PAD": vocab["[PAD]"]}

        l1 = len(ques_ids)
        l2 = len(ans_ids)
        if l1 != l2:
            raise Exception("Different number of Questions and Answers")
            # 发现问题答案数量不匹配,返回错误信息
        batch_tokens = []
        max_len = 0
        total_token_num = 0
        for i in range(l1):
            if len(ques_ids[i]) + len(ans_ids[i]) > self.max_seq_length - 3:
                ques_ids[i], ans_ids[i] = self._truncate_seq_pair(ques_ids[i], ans_ids[i])

            sent = [special_char["CLS"]] + ques_ids[i] + [special_char["SEP"]] + ans_ids[i] + [special_char["SEP"]]

            batch_tokens.append(sent)
            max_len = max(max_len, len(sent))
            total_token_num += len(sent)

        return batch_tokens, max_len, total_token_num

    def mask(self, batch_tokens, max_len, total_token_num, special_char=None):
        """
        Apply masking and return the masked tokens together with the mask labels and positions.
        """

        vocab = self.c_token.vocab
        if special_char is None:
            special_char = {"CLS": vocab["[CLS]"], "SEP": vocab["[SEP]"],
                            "MASK": vocab["[MASK]"], "PAD": vocab["[PAD]"]}

        vocab_size = len(self.c_token.vocab)
        return mask(batch_tokens, max_len, total_token_num, vocab_size, special_char)

    def pad_batch_data(self,
                       batch_tokens,
                       max_len,
                       pad_idx=0,
                       return_pos=False,
                       return_sent=False, sep_id=None,
                       return_input_mask=False,
                       # return_max_len=False,
                       # return_num_token=False
                       ):
        """
        Pad every sentence to the maximum sequence length and generate the corresponding position ids and input mask.
        """

        if sep_id is None:
            sep_id = self.c_token.vocab["[SEP]"]
        return pad_batch_data(batch_tokens, max_len, pad_idx, return_pos, return_sent, sep_id, return_input_mask)

    '''
    def prepare_batch_data(self,
                           insts,
                           max_len,
                           total_token_num,
                           voc_size=0,
                           pad_id=None,
                           cls_id=None,
                           sep_id=None,
                           mask_id=None,
                           # return_input_mask=True,
                           # return_max_len=True,
                           # return_num_token=False
                           ):
        """
        Build the data tensor, position tensor and self-attention mask (shape: batch_size * max_len * max_len).
        """

        return prepare_batch_data(insts, max_len, total_token_num, voc_size, pad_id, cls_id, sep_id, mask_id)
    '''

    def get_tokens(self, file_name, file_format="pickle", file_type="datap"):
        """
        Tokenize the examples and cache the result in the specified file; if the file already exists, read it directly.
        file_name="" disables file caching.
        """

        if file_name != "" and os.path.exists(get_fullurl(file_type, file_name, file_format)):
            self.logger.info("Get tokens from file")
            self.logger.info("File location: " + get_fullurl(file_type, file_name, file_format))
            batch_tokens = read_file(file_type, file_name, file_format)
            total_token_num = 0
            for sent in batch_tokens:
                total_token_num += len(sent)

        else:
            self.logger.info("Start caching output of tokenizing")
            ques_ids, ans_ids = self.exams_tokenize(self.examples)
            self.logger.info("  - Complete tokenizing")
            batch_tokens, _, total_token_num = self.splice_ques_ans(ques_ids, ans_ids)
            self.logger.info("  - Complete splicing question and answer")
            if file_name != "":
                self.save_tokens(batch_tokens, file_name)
                self.logger.info("  - Complete cache of tokenize results")
                self.logger.info("    File location: " + "dataset_processed/" + file_name)
            self.logger.info("Finish caching")

        return batch_tokens, total_token_num

    def get_engineer_ids(self, cache_name=""):
        """
        Build the feature-engineering ids.
        """
        engineer_ids = None
        if cache_name != "":
            cache_name = cache_name + "_engineer"
        if cache_name != "" and os.path.exists(
                get_fullurl(file_type="engineer", file_name=cache_name, file_format="pickle")):
            engineer_ids = read_file(file_type="engineer", file_name=cache_name, file_format="pickle")
        else:
            sent_ids = []
            count = 0
            for example in self.examples:
                question = example.question
                answer = example.answer
                entity_same = 0
                # if self.engineer_token.entity_sim(question,answer):
                #     entity_same = 1
                ques_id_list = self.engineer_token.convert_tokens_to_ids(self.engineer_token.tokenize(question))
                ans_id_list = self.engineer_token.convert_tokens_to_ids(self.engineer_token.tokenize(answer))
                special_char = {"CLS": 1, "SEP": 2, "MASK": 3, "PAD": 0}
                if len(ques_id_list) + len(ans_id_list) > self.max_seq_length - 3:
                    ques_id_list, ans_id_list = self._truncate_seq_pair(ques_id_list, ans_id_list)
                sent = [special_char["CLS"]] + ques_id_list + [special_char["SEP"]] + ans_id_list + [special_char["SEP"]]
                # pad_sent = sent + [special_char["PAD"]] * (self.max_seq_length - len(sent)) + [entity_same]
                pad_sent = sent + [special_char["PAD"]] * (self.max_seq_length - len(sent))
                sent_ids.append(pad_sent)
                count += 1
                if count % 1000 == 0:
                    print("engineer_ids has get {}".format(count))
            engineer_ids = np.array(sent_ids)
            # engineer_ids = engineer_ids.astype("int64").reshape([-1, self.max_seq_length+1, 1])
            engineer_ids = engineer_ids.astype("int64").reshape([-1, self.max_seq_length, 1])
            save_file(content=engineer_ids, file_type="engineer", file_name=cache_name, file_format="pickle")
        return engineer_ids

    def prepare_batch_data(self, cache_filename="", file_format="pickle", file_type="datap"):
        """
        First load batch_tokens and total_token_num from the specified cache file (or build them),
        then apply masking and padding to batch_tokens and build the remaining id tensors.
        """

        batch_tokens, total_token_num = self.get_tokens(cache_filename, file_format=file_format, file_type=file_type)

        self.logger.info("Start data-preprocessing before batching")

        if self.args["is_mask"]:
            batch_tokens, mask_label, mask_pos = self.mask(batch_tokens, self.max_seq_length, total_token_num)
            self.logger.info("  - Complete masking tokens")

        out = self.pad_batch_data(batch_tokens, self.max_seq_length,
                                  return_pos=True, return_sent=True, return_input_mask=True)
        src_ids, pos_ids, sent_ids, input_masks = out[0], out[1], out[2], out[3]
        engineer_ids = np.random.randint(0, 20, (len(self.examples), self.max_seq_length, 1))
        if self.args['use_engineer']:
            engineer_ids = self.get_engineer_ids(cache_filename)
        qas_ids = []
        labels = []
        temp = {"Yes": 0, "No": 1, "Depends": 2}
        for example in self.examples:
            qas_ids.append(example.qas_id)
        if not self.for_prediction:
            for example in self.examples:
                try:
                    labels.append(label_map[example.yes_or_no])
                except KeyError as err:
                    # The training label is not one of Yes / No / Depends.
                    raise KeyError("Error in labels of train-dataset") from err
        else:
            labels = [3] * len(self.examples)
        self.logger.info("  - Complete filling the tokens to max_seq_length, and getting other ids")

        self.features = []
        for i in range(len(self.examples)):
            self.features.append(Feature(
                qas_ids[i], src_ids[i], pos_ids[i], sent_ids[i], input_masks[i], labels[i], engineer_ids[i]
            ))
        self.logger.info("  - Complete constructing features object")
        self.logger.info("Finish data-preprocessing")

    def sample_generator(self):

        self.logger.info("Preprocessing a new round of data of {}".format(len(self.features)))
        if self.args["shuffle"]:
            random.shuffle(self.features)
        if not self.for_prediction:
            for feature in self.features:
                if self.args['use_engineer']:
                    yield feature.qas_id, feature.src_id, feature.pos_id, feature.sent_id, feature.input_mask, feature.label, feature.engineer_id
                else:
                    yield feature.qas_id, feature.src_id, feature.pos_id, feature.sent_id, feature.input_mask, feature.label
        else:
            for feature in self.features:
                if self.args['use_engineer']:
                    yield feature.qas_id, feature.src_id, feature.pos_id, feature.sent_id, feature.input_mask, feature.engineer_id
                else:
                    yield feature.qas_id, feature.src_id, feature.pos_id, feature.sent_id, feature.input_mask

    def batch_generator(self):

        reader = fluid.io.batch(self.sample_generator, batch_size=self.args["batch_size"])
        return reader

    def reverse_qa(self):
        """
        Swap the question and the answer of every example.
        """

        for example in self.examples:
            a = example.answer
            example.answer = example.question
            example.question = a

    def _truncate_seq_pair(self, tokens_a, tokens_b):
        """
        Truncate question-answer pairs that are too long.
        """

        max_length = self.max_seq_length - 3

        while True:
            total_length = len(tokens_a) + len(tokens_b)
            if total_length <= max_length:
                break
            if len(tokens_a) > len(tokens_b):
                tokens_a.pop()
            else:
                tokens_b.pop()

        return tokens_a, tokens_b
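
    # Usage sketch (assumptions: `args` is a dict with the keys referenced above,
    # `logger` is a standard logging.Logger, and `examples` is a list of objects
    # exposing .question, .answer, .qas_id and, for training data, .yes_or_no):
    #
    #     processor = PreProcess(logger, args, examples, for_prediction=False)
    #     processor.prepare_batch_data(cache_filename="")  # "" disables file caching
    #     reader = processor.batch_generator()
    #     for batch in reader():
    #         ...  # per sample: qas_id, src_id, pos_id, sent_id, input_mask, label[, engineer_id]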
class ProcessorForPretraining:
    def __init__(
        self,
        args,
        logger,
        docs,
    ):
        self.logger = logger
        self.args = args
        self.docs = docs
        self.max_seq_length = self.args["max_seq_length"]
        self.sentence_split = []
        # self.logger.info("Prepare to build tokenizer ……")
        self.tokenizer = CToken(self.args["vocab_name"],
                                self.args["vocab_format"],
                                self.args["vocab_type"],
                                self.args["do_lowercase"])
        self.logger.info("Successfully build tokenizer")
        # Build the tokenizer from the specified vocabulary
        self.batch_size = args['batch_size']
        self.features = []

    def get_vocab_size(self):
        """
        Return the size of the vocabulary in use.
        """

        return len(self.tokenizer.vocab)

    def merge_sentences(self, sentence_1, sentence_2, special_char=None):
        vocab = self.tokenizer.vocab
        if special_char is None:
            special_char = {
                "CLS": vocab["[CLS]"],
                "SEP": vocab["[SEP]"],
                "MASK": vocab["[MASK]"],
                "PAD": vocab["[PAD]"]
            }

        sent = [special_char["CLS"]] + sentence_1 + [
            special_char["SEP"]
        ] + sentence_2 + [special_char["SEP"]]
        if len(sent) > self.max_seq_length:
            sent = sent[:self.max_seq_length - 1] + [special_char["SEP"]]
        return sent

    def mask(self, batch_tokens, max_len, total_token_num, special_char=None):
        """
        Apply masking and return the masked tokens together with the mask labels and positions.
        """
        vocab = self.tokenizer.vocab
        if special_char is None:
            special_char = {
                "CLS": vocab["[CLS]"],
                "SEP": vocab["[SEP]"],
                "MASK": vocab["[MASK]"],
                "PAD": vocab["[PAD]"]
            }

        vocab_size = len(self.tokenizer.vocab)
        return mask(batch_tokens, max_len, total_token_num, vocab_size,
                    special_char)

    def split_docs_to_sentence(self):
        num = 0
        for doc in self.docs:
            sentences = doc.split('。')
            sentence_tokenized = []
            for sentence in sentences:
                if len(sentence) == 0:
                    continue
                sentence = self.tokenizer.tokenize(sentence)
                sentence = self.tokenizer.convert_tokens_to_ids(sentence)
                if len(sentence) > (self.max_seq_length / 2 - 3):
                    split_length = int(self.max_seq_length / 2) - 3
                    sentence_split = [
                        sentence[i:i + split_length]
                        for i in range(0, len(sentence), split_length)
                    ]
                    sentence_tokenized += sentence_split
                else:
                    sentence_tokenized.append(sentence)

            if len(sentence_tokenized) == 1:
                continue
            for i in range(len(sentence_tokenized) - 1):
                sentence_1 = sentence_tokenized[i]
                sentence_2 = sentence_tokenized[i + 1]
                self.sentence_split.append([sentence_1, sentence_2])
            num += 1
            if num % 3000 == 0:
                self.logger.info('{} docs split'.format(num))
        self.logger.info('total sentence pairs after splitting: {}'.format(
            len(self.sentence_split)))

    def convert_docs_to_features(self):
        if os.path.exists(
                get_fullurl(file_type='datap',
                            file_name='pretrain_corpus_feature',
                            file_format='pickle')):
            self.logger.info('load features from file')
            features = read_file('datap', 'pretrain_corpus_feature', 'pickle')
            self.features = features
            self.logger.info('{} features loaded'.format(len(self.features)))
            return

        self.split_docs_to_sentence()
        features = []
        num = 0
        for sentence_pair in self.sentence_split:
            sentence_1, sentence_2 = sentence_pair
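            # With probability 0.5 swap the sentence order; reverse_label records whether
            # the pair was swapped (a sentence-order-prediction style objective).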
            prob_reverse = np.random.rand()
            if prob_reverse < 0.5:
                reverse_label = 0
                sent_merged = self.merge_sentences(sentence_1, sentence_2)
            else:
                reverse_label = 1
                sent_merged = self.merge_sentences(sentence_2, sentence_1)
            features.append([sent_merged, reverse_label])
            num += 1
            if num % 5000 == 0:
                self.logger.info('{} features created'.format(num))
        save_file(features, 'datap', 'pretrain_corpus_feature', 'pickle')
        self.features = features

    def pad_batch_data(
        self,
        batch_tokens,
        max_len,
        pad_idx=0,
        return_pos=False,
        return_sent=False,
        sep_id=2,
        return_input_mask=False,
        # return_max_len=False,
        # return_num_token=False
    ):
        """
        Pad every sentence to the maximum sequence length and generate the corresponding position ids and input mask.
        """

        return pad_batch_data(batch_tokens, max_len, pad_idx, return_pos,
                              return_sent, sep_id, return_input_mask)

    def data_generator(self):
        src_ids = []
        reverse_labels = []
        total_token_num = 0
        for sentence, reverse_label in self.features:
            reverse_labels.append(reverse_label)
            src_ids.append(sentence)
            total_token_num += len(sentence)
            if len(src_ids) == self.batch_size:
                src_ids, mask_labels, mask_pos = self.mask(
                    src_ids, self.max_seq_length, total_token_num)
                out = self.pad_batch_data(
                    src_ids,
                    self.max_seq_length,
                    return_pos=True,
                    return_sent=True,
                    return_input_mask=True,
                    sep_id=self.tokenizer.convert_tokens_to_ids(['[SEP]'])[0])
                src_ids, pos_ids, sent_ids, input_masks = out[0], out[1], out[2], out[3]
                reverse_labels = np.array(reverse_labels).reshape([-1, 1])
                yield src_ids, pos_ids, sent_ids, input_masks, mask_labels, mask_pos, reverse_labels

                src_ids = []
                reverse_labels = []
                total_token_num = 0
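
    # Usage sketch (assumptions: `docs` is a list of raw-text documents and `args`
    # carries the vocab / batch_size / max_seq_length keys used above). Note that
    # data_generator yields only full batches, so leftover samples are dropped.
    #
    #     processor = ProcessorForPretraining(args, logger, docs)
    #     processor.convert_docs_to_features()
    #     for batch in processor.data_generator():
    #         src_ids, pos_ids, sent_ids, input_masks, mask_labels, mask_pos, reverse_labels = batch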
class ProcessorForMultiTask(object):
    def __init__(self, args, logger, examples, feature_file_name, task_id=0, is_prediction=False):
        self.logger = logger
        self.args = args
        self.examples = examples
        self.task_id = task_id
        self.is_prediction = is_prediction
        self.feature_file_name = feature_file_name
        self.max_seq_length = self.args["max_seq_length"]
        # self.logger.info("Prepare to build tokenizer ……")
        self.tokenizer = CToken(
            self.args["vocab_name"], self.args["vocab_format"], self.args["vocab_type"], self.args["do_lowercase"]
        )
        self.logger.info("Successfully build tokenizer")
        # Build the tokenizer from the specified vocabulary
        self.batch_size = args['batch_size']
        self.features = []

    def get_vocab_size(self):
        """
        Return the size of the vocabulary in use.
        """

        return len(self.tokenizer.vocab)

    def merge_sentences(self, sentence_1, sentence_2, special_char=None):
        vocab = self.tokenizer.vocab
        if special_char is None:
            special_char = {"CLS": vocab["[CLS]"], "SEP": vocab["[SEP]"],
                            "MASK": vocab["[MASK]"], "PAD": vocab["[PAD]"]}

        if len(sentence_1) + len(sentence_2) > self.max_seq_length - 3:
            sentence_1, sentence_2 = self._truncate_seq_pair(sentence_1, sentence_2)

        sent = [special_char["CLS"]] + sentence_1 + [special_char["SEP"]] + sentence_2 + [special_char["SEP"]]
        return sent

    def mask(self, batch_tokens, max_len, total_token_num, special_char=None):
        """
        Apply masking and return the masked tokens together with the mask labels and positions.
        """
        vocab = self.tokenizer.vocab
        if special_char is None:
            special_char = {"CLS": vocab["[CLS]"], "SEP": vocab["[SEP]"],
                            "MASK": vocab["[MASK]"], "PAD": vocab["[PAD]"]}

        vocab_size = len(self.tokenizer.vocab)
        return mask(batch_tokens, max_len, total_token_num, vocab_size, special_char)

    def convert_examples_to_features(self):
        if os.path.exists(get_fullurl(file_type='datap', file_name=self.feature_file_name, file_format='pickle')):
            self.logger.info('load features from file')
            features = read_file('datap', self.feature_file_name, 'pickle')
            self.features = features
            self.logger.info('{} features loaded'.format(len(self.features)))
            return

        features = []
        labels = []
        labels_for_reverse = []
        src_ids = []
        qas_ids = []
        label_map = {"Yes": 0, "No": 1, "Depends": 2}
        for example in self.examples:
            if self.is_prediction:
                labels.append(0)
            else:
                labels.append(label_map[example.yes_or_no])
            question_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(example.question))
            answer_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(example.answer))
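            # task_id semantics as implemented below: 0 -> randomly swap question/answer order
            # and record the order in labels_for_reverse, 1 -> always question first (reverse
            # label 0), 2 -> always answer first (reverse label 1).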
            if self.task_id == 0:
                prob_reverse = np.random.rand()
                if prob_reverse > 0.5:
                    src_id = self.merge_sentences(question_id, answer_id)
                    labels_for_reverse.append(1)
                else:
                    src_id = self.merge_sentences(answer_id, question_id)
                    labels_for_reverse.append(0)
            elif self.task_id == 1:
                src_id = self.merge_sentences(question_id, answer_id)
                labels_for_reverse.append(0)
            elif self.task_id == 2:
                src_id = self.merge_sentences(answer_id, question_id)
                labels_for_reverse.append(1)
            src_ids.append(src_id)
            qas_ids.append(example.qas_id)
        src_ids, pos_ids, sent_ids, input_masks = self.pad_batch_data(
            src_ids, self.max_seq_length,
            return_pos=True, return_sent=True,
            return_input_mask=True,
            sep_id=self.tokenizer.vocab['[SEP]'])
        labels = np.array(labels).reshape([-1, 1])
        labels_for_reverse = np.array(labels_for_reverse).reshape([-1, 1])
        for i in range(len(src_ids)):
            features.append(FeatureForMultiTask(qas_ids[i],
                                                src_ids[i],
                                                pos_ids[i],
                                                sent_ids[i],
                                                input_masks[i],
                                                labels[i],
                                                labels_for_reverse[i]))

        save_file(features, 'datap', self.feature_file_name, 'pickle')
        self.features = features

    def pad_batch_data(self,
                       batch_tokens,
                       max_len,
                       pad_idx=0,
                       return_pos=False,
                       return_sent=False, sep_id=2,
                       return_input_mask=False,
                       # return_max_len=False,
                       # return_num_token=False
                       ):
        """
        Pad every sentence to the maximum sequence length and generate the corresponding position ids and input mask.
        """

        return pad_batch_data(batch_tokens, max_len, pad_idx, return_pos, return_sent, sep_id, return_input_mask)

    def sample_generator(self):
        self.logger.info("Preprocessing a new round of data of {}".format(len(self.features)))
        if self.args["shuffle"]:
            random.shuffle(self.features)
        for feature in self.features:
            if self.is_prediction:
                yield feature.qas_id, feature.src_id, feature.pos_id, \
                    feature.sent_id, feature.input_mask
            else:
                yield feature.qas_id, feature.src_id, feature.pos_id, \
                      feature.sent_id, feature.input_mask, feature.label, feature.label_1

    def batch_generator(self):
        reader = fluid.io.batch(self.sample_generator, batch_size=self.args["batch_size"])
        return reader

    def _truncate_seq_pair(self, tokens_a, tokens_b):
        """截短过长的问答对."""
        max_length = self.max_seq_length - 3

        while True:
            total_length = len(tokens_a) + len(tokens_b)
            if total_length <= max_length:
                break
            if len(tokens_a) > len(tokens_b):
                tokens_a.pop()
            else:
                tokens_b.pop()

        return tokens_a, tokens_b
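
    # Usage sketch (assumptions: `examples` as for PreProcess; "multitask_features" is a
    # hypothetical cache name stored under the 'datap' file type):
    #
    #     processor = ProcessorForMultiTask(args, logger, examples,
    #                                       feature_file_name="multitask_features",
    #                                       task_id=0, is_prediction=False)
    #     processor.convert_examples_to_features()
    #     reader = processor.batch_generator()
    #     for batch in reader():
    #         ...  # per sample: qas_id, src_id, pos_id, sent_id, input_mask, label, label_1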
class ProcessorForPretrainingQa:
    def __init__(self, args, logger, questions, answers):
        self.logger = logger
        self.args = args
        assert len(questions) == len(answers)
        self.questions = questions
        self.answers = answers
        self.max_seq_length = self.args["max_seq_length"]
        self.qa_pair = []
        # self.logger.info("Prepare to build tokenizer ……")
        self.tokenizer = CToken(
            self.args["vocab_name"], self.args["vocab_format"], self.args["vocab_type"], self.args["do_lowercase"]
        )
        self.logger.info("Successfully build tokenizer")
        # Build the tokenizer from the specified vocabulary
        self.batch_size = args['batch_size']
        self.features = []

    def get_vocab_size(self):
        """
        Return the size of the vocabulary in use.
        """

        return len(self.tokenizer.vocab)

    def merge_qa(self, question, answer, special_char=None):
        vocab = self.tokenizer.vocab
        if special_char is None:
            special_char = {"CLS": vocab["[CLS]"], "SEP": vocab["[SEP]"],
                            "MASK": vocab["[MASK]"], "PAD": vocab["[PAD]"]}

        sent = [special_char["CLS"]] + question + [special_char["SEP"]] + answer + [special_char["SEP"]]
        if len(sent) > self.max_seq_length:
            sent = sent[:self.max_seq_length - 1] + [special_char["SEP"]]
        return sent

    def mask(self, batch_tokens, max_len, total_token_num, special_char=None):
        """
        Apply masking and return the masked tokens together with the mask labels and positions.
        """
        vocab = self.tokenizer.vocab
        if special_char is None:
            special_char = {"CLS": vocab["[CLS]"], "SEP": vocab["[SEP]"],
                            "MASK": vocab["[MASK]"], "PAD": vocab["[PAD]"]}

        vocab_size = len(self.tokenizer.vocab)
        return mask(batch_tokens, max_len, total_token_num, vocab_size, special_char)

    def split_qa_to_qair(self):
        num = 0
        length = len(self.questions)
        for i in range(length):
            question = self.questions[i]
            answer = self.answers[i]
            question = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(question))
            answer = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(answer))
            question, answer = self._truncate_seq_pair(question, answer)
            self.qa_pair.append([question, answer])
            num += 1
            if num % 3000 == 0:
                self.logger.info('{} pairs generated'.format(num))
        self.logger.info('total qa_pair generated: {}'.format(len(self.qa_pair)))

    def convert_docs_to_features(self):
        if os.path.exists(get_fullurl(file_type='datap', file_name='pretrain_corpus_feature', file_format='pickle')):
            self.logger.info('load features from file')
            features = read_file('datap', 'pretrain_corpus_feature', 'pickle')
            self.features = features
            self.logger.info('{} features loaded'.format(len(self.features)))
            return

        self.split_qa_to_qair()
        features = []
        num = 0
        for index, qa_pair in enumerate(self.qa_pair):
            question, answer = qa_pair
            random_index = randrange(0, len(self.qa_pair))
            while random_index == index:
                random_index = randrange(0, len(self.qa_pair))
            _, otheranswer = self.qa_pair[random_index]
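            # With probability 0.5 replace the real answer with the randomly sampled one;
            # otheranswer_label = 1 marks the mismatched pair (a next-answer-prediction
            # style negative sample).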
            prob_otheranswer = np.random.rand()
            if prob_otheranswer < 0.5:
                otheranswer_label = 0
                qa_merged = self.merge_qa(question, answer)
            else:
                otheranswer_label = 1
                qa_merged = self.merge_qa(question, otheranswer)
            features.append([qa_merged, otheranswer_label])
            num += 1
            if num % 5000 == 0:
                self.logger.info('{} features created'.format(num))
        save_file(features, 'datap', 'pretrain_corpus_feature', 'pickle')
        self.features = features

    def pad_batch_data(self,
                       batch_tokens,
                       max_len,
                       pad_idx=0,
                       return_pos=False,
                       return_sent=False, sep_id=2,
                       return_input_mask=False,
                       # return_max_len=False,
                       # return_num_token=False
                       ):
        """
        Pad every sentence to the maximum sequence length and generate the corresponding position ids and input mask.
        """

        return pad_batch_data(batch_tokens, max_len, pad_idx, return_pos, return_sent, sep_id, return_input_mask)

    def data_generator(self):
        src_ids = []
        otheranswer_labels = []
        total_token_num = 0
        index = 0
        for sentence, otheranswer_label in self.features:
            index += 1
            otheranswer_labels.append(otheranswer_label)
            src_ids.append(sentence)
            total_token_num += len(sentence)
            if len(src_ids) == self.batch_size or index == len(self.features):
                src_ids, mask_labels, mask_pos = self.mask(src_ids, self.max_seq_length, total_token_num)
                out = self.pad_batch_data(src_ids, self.max_seq_length,
                                          return_pos=True, return_sent=True, return_input_mask=True,
                                          sep_id=self.tokenizer.convert_tokens_to_ids(['[SEP]'])[0])
                src_ids, pos_ids, sent_ids, input_masks = out[0], out[1], out[2], out[3]
                otheranswer_labels = np.array(otheranswer_labels).reshape([-1, 1])
                yield src_ids, pos_ids, sent_ids, input_masks, mask_labels, mask_pos, otheranswer_labels

                src_ids = []
                otheranswer_labels = []
                total_token_num = 0

    def _truncate_seq_pair(self, tokens_a, tokens_b):
        """
        Truncate question-answer pairs that are too long.
        """
        max_length = self.max_seq_length - 3
        while True:
            total_length = len(tokens_a) + len(tokens_b)
            if total_length <= max_length:
                break
            if len(tokens_a) > len(tokens_b):
                tokens_a.pop()
            else:
                tokens_b.pop()

        return tokens_a, tokens_b
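
    # Usage sketch (assumptions: `questions` and `answers` are parallel lists of raw strings;
    # data_generator also yields a final partial batch):
    #
    #     processor = ProcessorForPretrainingQa(args, logger, questions, answers)
    #     processor.convert_docs_to_features()
    #     for batch in processor.data_generator():
    #         src_ids, pos_ids, sent_ids, input_masks, mask_labels, mask_pos, otheranswer_labels = batch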