import math
import pickle
import random

import numpy as np
import pkuseg
import torch

# Project-local dependencies (imported from this repository's own modules;
# exact module paths are not shown here): Tokenizer, CharsVocabPath, VocabPath,
# SentenceLength, ModelClass, MaskRate, WordGenTimes, RanWrongDivisor,
# OceEvalPath, OcnEvalPath, TnewsEvalPath


class TrainDataGenerator(object):
    def __init__(self, oce_corpus_path, ocn_corpus_path, tnews_corpus_path,
                 c2n_pickle_path):
        self.oce_data_tuple = []
        self.ocn_data_tuple = []
        self.tnews_data_tuple = []
        self.tokenizer = Tokenizer(CharsVocabPath)
        with open(c2n_pickle_path, 'rb') as f:
            self.classes2num = pickle.load(f)

        with open(oce_corpus_path, 'r', encoding='utf-8') as f:
            for line in f:
                if line:
                    line = line.strip()
                    line = line.split('\t')
                    if line[0] and line[1]:
                        self.oce_data_tuple.append(
                            [self.classes2num[line[0]], line[1]])
        with open(ocn_corpus_path, 'r', encoding='utf-8') as f:
            for line in f:
                if line:
                    line = line.strip()
                    line = line.split('\t')
                    if line[0] and line[1]:
                        self.ocn_data_tuple.append(
                            [self.classes2num[line[0]] - 7, line[1]])
        with open(tnews_corpus_path, 'r', encoding='utf-8') as f:
            for line in f:
                if line:
                    line = line.strip()
                    line = line.split('\t')
                    if line[0] and line[1]:
                        self.tnews_data_tuple.append(
                            [self.classes2num[line[0]] - 10, line[1]])

        self.source_oce_data = self.oce_data_tuple
        self.source_ocn_data = self.ocn_data_tuple
        self.source_tnews_data = self.tnews_data_tuple

        random.shuffle(self.oce_data_tuple)
        random.shuffle(self.ocn_data_tuple)
        random.shuffle(self.tnews_data_tuple)

    def get_length(self):
        return len(self.oce_data_tuple), len(self.ocn_data_tuple), len(
            self.tnews_data_tuple)

    def ret_batch(self):
        self.oce_data_tuple = self.source_oce_data
        self.ocn_data_tuple = self.source_ocn_data
        self.tnews_data_tuple = self.source_tnews_data
        random.shuffle(self.oce_data_tuple)
        random.shuffle(self.ocn_data_tuple)
        random.shuffle(self.tnews_data_tuple)

    def gen_next_batch(self, oce_batch_size, ocn_batch_size, tnews_batch_size):
        output = {}
        batch_max_len = 0
        if len(self.oce_data_tuple) >= oce_batch_size and \
                len(self.ocn_data_tuple) >= ocn_batch_size and \
                len(self.tnews_data_tuple) >= tnews_batch_size:

            oce_current_tuple = self.oce_data_tuple[:oce_batch_size]
            ocn_current_tuple = self.ocn_data_tuple[:ocn_batch_size]
            tnews_current_tuple = self.tnews_data_tuple[:tnews_batch_size]

            self.oce_data_tuple = self.oce_data_tuple[oce_batch_size:]
            self.ocn_data_tuple = self.ocn_data_tuple[ocn_batch_size:]
            self.tnews_data_tuple = self.tnews_data_tuple[tnews_batch_size:]
        else:
            return None

        type_list = []
        label_list = []
        tokens_list = []
        segments_list = []

        for x in oce_current_tuple:
            type_list.append([0])
            label_list.append(x[0])
            token_ids = self.tokenizer.tokens_to_ids(['[CLS]'] +
                                                     x[1].split(' '))
            if len(token_ids) > batch_max_len:
                batch_max_len = len(token_ids)
            tokens_list.append(token_ids)
            segments_list.append([1] * len(token_ids))
        for x in ocn_current_tuple:
            type_list.append([1])
            label_list.append(x[0])
            token_ids = self.tokenizer.tokens_to_ids(['[CLS]'] +
                                                     x[1].split(' '))
            if len(token_ids) > batch_max_len:
                batch_max_len = len(token_ids)
            tokens_list.append(token_ids)
            segments_list.append([1] * len(token_ids))
        for x in tnews_current_tuple:
            type_list.append([2])
            label_list.append(x[0])
            token_ids = self.tokenizer.tokens_to_ids(['[CLS]'] +
                                                     x[1].split(' '))
            if len(token_ids) > batch_max_len:
                batch_max_len = len(token_ids)
            tokens_list.append(token_ids)
            segments_list.append([1] * len(token_ids))

        batch_max_len = min(batch_max_len, SentenceLength)

        for i, tokens in enumerate(tokens_list):
            if len(tokens) < batch_max_len:
                tokens_list[i] = tokens_list[i] + [0] * (batch_max_len -
                                                         len(tokens))
                segments_list[i] = segments_list[i] + [0] * (batch_max_len -
                                                             len(tokens))
            else:
                tokens_list[i] = tokens_list[i][:batch_max_len]
                segments_list[i] = segments_list[i][:batch_max_len]

        output['type_id'] = type_list
        output['input_token_ids'] = tokens_list
        output['position_ids'] = [[
            x for x in range(batch_max_len)
        ] for i in range(oce_batch_size + ocn_batch_size + tnews_batch_size)]
        output['segment_ids'] = segments_list
        output['token_ids_labels'] = label_list
        instance = {
            k: torch.tensor(v, dtype=torch.long)
            for k, v in output.items()
        }
        return instance
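

# A minimal usage sketch (not part of the original example): the paths and
# batch sizes are placeholders, and gen_next_batch() is assumed to be called
# in a loop until it returns None, at which point ret_batch() restores and
# reshuffles the data for a new epoch.
def demo_train_generator(oce_path, ocn_path, tnews_path, c2n_pickle_path):
    generator = TrainDataGenerator(oce_path, ocn_path, tnews_path,
                                   c2n_pickle_path)
    print('corpus sizes:', generator.get_length())
    while True:
        batch = generator.gen_next_batch(oce_batch_size=16,
                                         ocn_batch_size=16,
                                         tnews_batch_size=16)
        if batch is None:
            # Not enough samples left for a full mixed batch; reshuffle.
            generator.ret_batch()
            break
        # Every value is a torch.LongTensor; input_token_ids has shape
        # [16 + 16 + 16, batch_max_len] for the three stacked sub-batches.
        print(batch['input_token_ids'].shape)

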
class DataFactory(object):
    def __init__(self):
        self.tokenizer = Tokenizer(VocabPath)
        self.seg = pkuseg.pkuseg()
        self.vocab_size = self.tokenizer._vocab_size
        self.token_pad_id = self.tokenizer._token_pad_id
        self.token_cls_id = self.tokenizer._token_start_id
        self.token_sep_id = self.tokenizer._token_end_id
        self.token_mask_id = self.tokenizer._token_mask_id

    def __token_process(self, token_id):
        """
        以80%的几率替换为[MASK],以10%的几率保持不变,
        以10%的几率替换为一个随机token。
        """
        rand = np.random.random()
        if rand <= 0.8:
            return self.token_mask_id
        elif rand <= 0.9:
            return token_id
        else:
            return np.random.randint(0, self.vocab_size)

    def texts_to_ids(self, texts):
        texts_ids = []
        for text in texts:
            # Process each sentence
            if ModelClass == 'RobertaMlm':
                # Note: RoBERTa does not mask every single character here;
                # masking is applied per character or per whole word
                words = self.seg.cut(text)
                for word in words:
                    # tokenize() wraps the word with [CLS] and [SEP]; strip them here
                    word_tokens = self.tokenizer.tokenize(text=word)[1:-1]
                    words_ids = self.tokenizer.tokens_to_ids(word_tokens)
                    texts_ids.append(words_ids)
            else:
                for word in text:
                    # tokenize() wraps the character with [CLS] and [SEP]; strip them here
                    word_tokens = self.tokenizer.tokenize(text=word)[1:-1]
                    words_ids = self.tokenizer.tokens_to_ids(word_tokens)
                    texts_ids.append(words_ids)
        return texts_ids

    def ids_to_mask(self, texts_ids):
        instances = []
        total_ids = []
        total_masks = []
        # Draw one random number per character/word to decide whether it is masked
        mask_rates = np.random.random(len(texts_ids))

        for i, word_id in enumerate(texts_ids):
            # Accumulate this character's/word's ids into the flat sequence
            total_ids.extend(word_id)
            if mask_rates[i] < MaskRate:
                # word_id may cover a single character or a whole word
                for sub_id in word_id:
                    total_masks.append(self.__token_process(sub_id))
            else:
                total_masks.extend([0] * len(word_id))

        # Each instance is at most SentenceLength (512) tokens, so the paragraph
        # is cut into chunks of 510 = 512 - 2 ids, reserving positions for
        # [CLS] and [SEP]
        for i in range(math.ceil(len(total_ids) / (SentenceLength - 2))):
            tmp_ids = [self.token_cls_id]
            tmp_masks = [self.token_pad_id]
            tmp_ids.extend(
                total_ids[i * (SentenceLength - 2):min((i + 1) *
                                                       (SentenceLength -
                                                        2), len(total_ids))])
            tmp_masks.extend(total_masks[i * (SentenceLength - 2):min(
                (i + 1) * (SentenceLength - 2), len(total_masks))])
            # Pad chunks shorter than SentenceLength (512)
            diff = SentenceLength - len(tmp_ids)
            if diff == 1:
                tmp_ids.append(self.token_sep_id)
                tmp_masks.append(self.token_pad_id)
            else:
                # Append the end token ([SEP])
                tmp_ids.append(self.token_sep_id)
                tmp_masks.append(self.token_pad_id)
                # Fill the remainder with padding
                tmp_ids.extend([self.token_pad_id] * (diff - 1))
                tmp_masks.extend([self.token_pad_id] * (diff - 1))
            instances.append([tmp_ids, tmp_masks])
        return instances

    def ids_all_mask(self, texts_ids, tokenid2count):
        instances = []
        tmp_ids = [101]  # hard-coded [CLS] id

        # Flatten the ids and truncate so that [CLS] + ids fit in SentenceLength - 1
        for token_ids in texts_ids:
            if isinstance(token_ids, list):
                for token_id in token_ids:
                    tmp_ids.append(token_id)
                    if len(tmp_ids) == SentenceLength - 1:
                        break
            else:
                tmp_ids.append(token_ids)
                if len(tmp_ids) == SentenceLength - 1:
                    break
            if len(tmp_ids) == SentenceLength - 1:
                break

        tmp_ids.append(102)  # hard-coded [SEP] id
        input_length = len(tmp_ids) - 2
        if len(tmp_ids) < SentenceLength:
            for i in range(SentenceLength - len(tmp_ids)):
                tmp_ids.append(0)

        for i in range(1, input_length + 1):
            # If a token occurs only rarely, force extra training instances for it
            if tokenid2count[tmp_ids[i]] < WordGenTimes:
                for j in range(WordGenTimes - tokenid2count[tmp_ids[i]]):
                    tmp_masks = [0] * SentenceLength
                    rand_num = np.random.randint(672, 7992)
                    tmp_masks[i] = rand_num
                    instances.append([tmp_ids, tmp_masks])
            tmp_masks = [0] * SentenceLength
            if random.random() < RanWrongDivisor:
                rand_num = np.random.randint(672, 7992)
                tmp_masks[i] = rand_num
            else:
                tmp_masks[i] = tmp_ids[i]
            instances.append([tmp_ids, tmp_masks])
        return instances
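

# A hedged sketch (not from the original example) of the MLM preprocessing
# flow: texts_to_ids() segments each sentence into characters or pkuseg words
# and converts them to ids; ids_to_mask() then draws one probability per
# character/word, applies the 80/10/10 masking scheme to selected words, and
# cuts the id stream into SentenceLength windows framed by [CLS]/[SEP] and padding.
def demo_data_factory(texts):
    factory = DataFactory()
    texts_ids = factory.texts_to_ids(texts)
    instances = factory.ids_to_mask(texts_ids)
    for token_ids, mask_labels in instances:
        # Both lists come back SentenceLength long; positions whose mask
        # label is 0 were not selected for masking.
        assert len(token_ids) == len(mask_labels) == SentenceLength
    return instances

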
class EvalDataGenerator(object):
    def __init__(self, corpus_path, c2n_pickle_path):
        self.data_tuple = []
        self.corpus_path = corpus_path
        if self.corpus_path == OceEvalPath:
            self.type_id = 0
        if self.corpus_path == OcnEvalPath:
            self.type_id = 1
        if self.corpus_path == TnewsEvalPath:
            self.type_id = 2
        self.tokenizer = Tokenizer(CharsVocabPath)
        with open(c2n_pickle_path, 'rb') as f:
            self.classes2num = pickle.load(f)
        with open(self.corpus_path, 'r', encoding='utf-8') as f:
            for line in f:
                if line:
                    line = line.strip()
                    line = line.split('\t')
                    if line[0] and line[1]:
                        self.data_tuple.append(
                            [self.classes2num[line[0]], line[1]])
        self.source_eval_data = self.data_tuple
        random.shuffle(self.data_tuple)

    def reset_batch(self):
        self.data_tuple = self.source_eval_data
        random.shuffle(self.data_tuple)

    def gen_next_batch(self, batch_size):
        output = {}
        batch_max_len = 0
        if len(self.data_tuple) >= batch_size:
            current_tuple = self.data_tuple[:batch_size]
            self.data_tuple = self.data_tuple[batch_size:]
        else:
            return None

        label_list = []
        tokens_list = []
        segments_list = []

        for x in current_tuple:
            if self.type_id == 0:
                label_list.append(x[0])
            if self.type_id == 1:
                label_list.append(x[0] - 7)
            if self.type_id == 2:
                label_list.append(x[0] - 10)
            token_ids = self.tokenizer.tokens_to_ids(['[CLS]'] +
                                                     x[1].split(' '))
            if len(token_ids) > batch_max_len:
                batch_max_len = len(token_ids)
            tokens_list.append(token_ids)
            segments_list.append([1] * len(token_ids))

        batch_max_len = min(batch_max_len, SentenceLength)

        for i, tokens in enumerate(tokens_list):
            if len(tokens) < batch_max_len:
                tokens_list[i] = tokens_list[i] + [0] * (batch_max_len -
                                                         len(tokens))
                segments_list[i] = segments_list[i] + [0] * (batch_max_len -
                                                             len(tokens))
            else:
                tokens_list[i] = tokens_list[i][:batch_max_len]
                segments_list[i] = segments_list[i][:batch_max_len]

        output['type_id'] = [self.type_id]
        output['input_token_ids'] = tokens_list
        output['position_ids'] = [[x for x in range(batch_max_len)]]
        output['segment_ids'] = segments_list
        output['token_ids_labels'] = label_list
        instance = {
            k: torch.tensor(v, dtype=torch.long)
            for k, v in output.items()
        }
        return instance
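

# A minimal evaluation-loop sketch under the same assumptions as above; the
# corpus path and batch size are placeholders, and the path is expected to be
# one of OceEvalPath / OcnEvalPath / TnewsEvalPath so that type_id is defined.
def demo_eval_generator(eval_corpus_path, c2n_pickle_path, batch_size=32):
    generator = EvalDataGenerator(eval_corpus_path, c2n_pickle_path)
    while True:
        batch = generator.gen_next_batch(batch_size)
        if batch is None:
            # Fewer than batch_size samples remain; restore and reshuffle.
            generator.reset_batch()
            break
        # batch['input_token_ids'] has shape [batch_size, batch_max_len] and
        # batch['type_id'] is a one-element tensor identifying the task.
        print(batch['input_token_ids'].shape, batch['token_ids_labels'])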