Example #1
    def correct(self,
                text,
                include_symbol=True,
                num_fragment=1,
                threshold=57,
                **kwargs):
        """
        句子改错
        :param text: str, query 文本
        :param include_symbol: bool, 是否包含标点符号
        :param num_fragment: 纠错候选集分段数, 1 / (num_fragment + 1)
        :param threshold: 语言模型纠错ppl阈值
        :param kwargs: ...
        :return: text (str)改正后的句子, list(wrong, right, begin_idx, end_idx)
        """
        text_new = ''
        details = []
        self.check_corrector_initialized()
        # Normalize encoding: utf-8 to unicode
        text = convert_to_unicode(text)
        # Split long text into short sentences
        blocks = self.split_2_short_text(text, include_symbol=include_symbol)
        for blk, idx in blocks:
            maybe_errors = self.detect_short(blk, idx)
            """
            错误纠正部分,是遍历所有的疑似错误位置,并使用音似、形似词典替换错误位置的词,然后通过语言模型计算句子困惑度,对所有候选集结果比较并排序,得到最优纠正词。
            """
            for cur_item, begin_idx, end_idx, err_type in maybe_errors:
                # Correct errors one at a time
                before_sent = blk[:(begin_idx - idx)]
                after_sent = blk[(end_idx - idx):]

                # Words in the custom confusion set map directly to their correction
                if err_type == ErrorType.confusion:
                    corrected_item = self.custom_confusion[cur_item]
                else:
                    # Generate all candidate corrections
                    candidates = self.generate_items(cur_item,
                                                     fragment=num_fragment)
                    if not candidates:
                        continue

                    # Let the language model pick the best correction
                    corrected_item = self.get_lm_correct_item(
                        cur_item,
                        candidates,
                        before_sent,
                        after_sent,
                        threshold=threshold)
                # output
                if corrected_item != cur_item:
                    blk = before_sent + corrected_item + after_sent
                    detail_word = [
                        cur_item, corrected_item, begin_idx, end_idx
                    ]
                    details.append(detail_word)
            text_new += blk
        details = sorted(details, key=operator.itemgetter(2))
        return text_new, details
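
A minimal usage sketch for this method; the Corrector class name and the sample sentence are assumptions, not taken from the source:

    # Hypothetical usage; Corrector is assumed to be the class defining
    # correct() above, with its confusion dict and language model loaded.
    corrector = Corrector()
    text_new, details = corrector.correct('少先队员因该为老人让坐', threshold=57)
    print(text_new)   # corrected sentence
    print(details)    # [[wrong, right, begin_idx, end_idx], ...]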
Example #2
    def electra_correct(self, text):
        """
        句子纠错
        :param text: 句子文本
        :return: corrected_text, list[list], [error_word, correct_word, begin_pos, end_pos]
        """
        text_new = ''
        details = []
        # Normalize encoding: utf-8 to unicode
        text = convert_to_unicode(text)
        # Split long text into short sentences
        blocks = self.split_2_short_text(text, include_symbol=True)
        for blk, start_idx in blocks:
            error_ids = self.electra_detect(blk)
            sentence_lst = list(blk)
            for idx in error_ids:
                s = sentence_lst[idx]
                if is_chinese_string(s):
                    # Handle a Chinese-character error
                    sentence_lst[idx] = self.mask
                    sentence_new = ''.join(sentence_lst)
                    # Fill-mask generator predicts [MASK]; top 5 by default
                    predicts = self.g_model(sentence_new)
                    top_tokens = []
                    for p in predicts:
                        token_id = p.get('token', 0)
                        token_str = self.g_model.tokenizer.convert_ids_to_tokens(
                            token_id)
                        top_tokens.append(token_str)

                    if top_tokens and (s not in top_tokens):
                        # Generate all candidate corrections
                        candidates = self.generate_items(s)
                        if candidates:
                            for token_str in top_tokens:
                                if token_str in candidates:
                                    details.append([
                                        s, token_str, start_idx + idx,
                                        start_idx + idx + 1
                                    ])
                                    sentence_lst[idx] = token_str
                                    break
                    # Restore the original character if nothing replaced the mask
                    if sentence_lst[idx] == self.mask:
                        sentence_lst[idx] = s

            blk_new = ''.join(sentence_lst)
            text_new += blk_new
        details = sorted(details, key=operator.itemgetter(2))
        return text_new, details
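
The `self.g_model(sentence_new)` call above returns prediction dicts carrying a 'token' id, matching the shape of a Hugging Face fill-mask pipeline. A standalone sketch of that masking step; 'bert-base-chinese' is an assumed stand-in for the actual generator model:

    from transformers import pipeline

    # Fill-mask sketch; the model name is an assumption.
    fill_mask = pipeline('fill-mask', model='bert-base-chinese')
    predicts = fill_mask('今天天气很[MASK]')  # top 5 predictions by default
    # 'token_str' gives the decoded token directly, without the
    # convert_ids_to_tokens round trip used above.
    top_tokens = [p['token_str'] for p in predicts]
    print(top_tokens)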
Example #3
    def macbert_correct(self, text):
        """
        句子纠错
        :param text: 句子文本
        :return: corrected_text, list[list], [error_word, correct_word, begin_pos, end_pos]
        """
        text_new = ''
        details = []
        # Normalize encoding: utf-8 to unicode
        text = convert_to_unicode(text)
        # Split long text into short sentences
        blocks = split_text_by_maxlen(text, maxlen=128)
        blocks = [block[0] for block in blocks]
        inputs = self.tokenizer(blocks, padding=True,
                                return_tensors='pt').to(device)
        with torch.no_grad():
            outputs = self.model(**inputs)

        def get_errors(corrected_text, origin_text):
            sub_details = []
            for i, ori_char in enumerate(origin_text):
                if ori_char in self.unk_tokens:
                    # Re-insert characters the model maps to [UNK]
                    corrected_text = corrected_text[:i] + ori_char + corrected_text[i:]
                    continue
                if i >= len(corrected_text):
                    continue
                if ori_char != corrected_text[i]:
                    if ori_char.lower() == corrected_text[i]:
                        # Keep the original casing of English letters
                        corrected_text = corrected_text[:i] + ori_char + corrected_text[i + 1:]
                        continue
                    sub_details.append((ori_char, corrected_text[i], i, i + 1))
            sub_details = sorted(sub_details, key=operator.itemgetter(2))
            return corrected_text, sub_details

        for ids, text in zip(outputs.logits, blocks):
            decode_tokens = self.tokenizer.decode(
                torch.argmax(ids, dim=-1),
                skip_special_tokens=True).replace(' ', '')
            corrected_text = decode_tokens[:len(text)]
            corrected_text, sub_details = get_errors(corrected_text, text)
            text_new += corrected_text
            details.extend(sub_details)
        return text_new, details
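
`get_errors` aligns the decoded output with the original string character by character. A simplified standalone sketch of that diff (unk-token and casing handling dropped; the sample sentence is an assumption):

    import operator

    def char_diff(corrected_text, origin_text):
        details = []
        for i, ori_char in enumerate(origin_text):
            if i < len(corrected_text) and ori_char != corrected_text[i]:
                details.append((ori_char, corrected_text[i], i, i + 1))
        return sorted(details, key=operator.itemgetter(2))

    print(char_diff('少先队员应该为老人让座', '少先队员因该为老人让坐'))
    # [('因', '应', 4, 5), ('坐', '座', 10, 11)]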
Example #4
 def detect(self, text):
     maybe_errors = []
     if not text.strip():
         return maybe_errors
     # Initialize
     self.check_detector_initialized()
     # Normalize encoding: utf-8 to unicode
     text = convert_to_unicode(text)
     # Normalize the text
     text = uniform(text)
     # Split long text into short sentences
     blocks = self.split_2_short_text(text)
     for blk, idx in blocks:
         maybe_errors += self.detect_short(blk, idx)
     return maybe_errors
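
A hypothetical call, assuming a Detector class exposing this method; each returned entry has the (item, begin_idx, end_idx, err_type) shape unpacked by correct() in Example #1:

    # Hypothetical usage; Detector and the sample sentence are assumptions.
    detector = Detector()
    for item, begin_idx, end_idx, err_type in detector.detect('少先队员因该为老人让坐'):
        print(item, begin_idx, end_idx, err_type)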
Example #5
    def bert_correct(self, text):
        """
        句子纠错
        :param text: 句子文本
        :return: corrected_text, list[list], [error_word, correct_word, begin_pos, end_pos]
        """
        text_new = ''
        details = []
        self.check_corrector_initialized()
        # Normalize encoding: utf-8 to unicode
        text = convert_to_unicode(text)
        # Split long text into short sentences
        blocks = self.split_text_by_maxlen(text, maxlen=128)
        for blk, start_idx in blocks:
            blk_new = ''
            for idx, s in enumerate(blk):
                # Handle a Chinese-character error
                if is_chinese_string(s):
                    sentence_lst = list(blk_new + blk[idx:])
                    sentence_lst[idx] = self.mask
                    sentence_new = ''.join(sentence_lst)
                    # Predict the masked token; top 5 by default
                    predicts = self.model(sentence_new)
                    top_tokens = []
                    for p in predicts:
                        token_id = p.get('token', 0)
                        token_str = self.model.tokenizer.convert_ids_to_tokens(
                            token_id)
                        top_tokens.append(token_str)

                    if top_tokens and (s not in top_tokens):
                        # Generate all candidate corrections
                        candidates = self.generate_items(s)
                        if candidates:
                            for token_str in top_tokens:
                                if token_str in candidates:
                                    details.append([
                                        s, token_str, start_idx + idx,
                                        start_idx + idx + 1
                                    ])
                                    s = token_str
                                    break
                blk_new += s
            text_new += blk_new
        details = sorted(details, key=operator.itemgetter(2))
        return text_new, details
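
Note the context trick in `sentence_lst = list(blk_new + blk[idx:])`: the sentence is rebuilt from the already-corrected prefix `blk_new`, so each later mask prediction is conditioned on earlier corrections rather than on the raw input.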
Example #6
    def ernie_correct(self, text, ernie_cut_type='char'):
        """
        句子纠错
        :param text: 句子文本
        :param ernie_cut_type: 切词类型(char/word)
        :return: corrected_text, list[list], [error_word, correct_word, begin_pos, end_pos]
        """
        text_new = ''
        details = []
        self.check_corrector_initialized()
        # Normalize encoding: utf-8 to unicode
        text = convert_to_unicode(text)
        # Split long text into short sentences
        blocks = self.split_text_by_maxlen(text, maxlen=512)
        for blk, start_idx in blocks:
            blk_new = ''
            blk = segment(blk, cut_type=ernie_cut_type, pos=False)
            for idx, s in enumerate(blk):
                # Handle a Chinese-token error
                if is_chinese_string(s):
                    sentence_lst = blk[:idx] + blk[idx:]  # shallow copy of the token list
                    sentence_lst[idx] = self.mask_token * len(s)
                    sentence_new = ' '.join(sentence_lst)
                    # Predict the masked tokens; top 5 by default
                    predicts = self.predict_mask(sentence_new)
                    top_tokens = []
                    for p in predicts:
                        top_tokens.append(p.get('token', ''))

                    if top_tokens and (s not in top_tokens):
                        # Generate all candidate corrections
                        candidates = self.generate_items(s)
                        if candidates:
                            for token_str in top_tokens:
                                if token_str in candidates:
                                    details.append([
                                        s, token_str, start_idx + idx,
                                        start_idx + idx + 1
                                    ])
                                    s = token_str
                                    break
                blk_new += s
            text_new += blk_new
        details = sorted(details, key=operator.itemgetter(2))
        return text_new, details
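
Because `segment()` can return multi-character words, the mask is repeated once per character (`self.mask_token * len(s)`). A rough sketch of the two cut types, with jieba standing in for the actual segment() helper (an assumption):

    import jieba

    text = '今天天气很好'
    word_tokens = list(jieba.cut(text))  # cut_type='word': multi-character words
    char_tokens = list(text)             # cut_type='char': single characters
    print(word_tokens, char_tokens)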
Example #7
    def correct(self, text):
        """
        句子改错
        :param text: 文本
        :return: 改正后的句子, list(wrong, right, begin_idx, end_idx)
        """
        text_new = ''
        details = []
        self.check_corrector_initialized()
        # Normalize encoding: utf-8 to unicode
        text = convert_to_unicode(text)
        # Split long text into short sentences
        blocks = self.split_2_short_text(text, include_symbol=True)
        for blk, idx in blocks:
            maybe_errors = self.detect_short(blk, idx)
            for cur_item, begin_idx, end_idx, err_type in maybe_errors:
                # Correct errors one at a time
                before_sent = blk[:(begin_idx - idx)]
                after_sent = blk[(end_idx - idx):]

                # Words in the custom confusion set map directly to their correction
                if err_type == ErrorType.confusion:
                    corrected_item = self.custom_confusion[cur_item]
                else:
                    # Generate all candidate corrections
                    candidates = self.generate_items(cur_item)
                    if not candidates:
                        continue
                    corrected_item = self.get_lm_correct_item(
                        cur_item, candidates, before_sent, after_sent)
                # output
                if corrected_item != cur_item:
                    blk = before_sent + corrected_item + after_sent
                    detail_word = [
                        cur_item, corrected_item, begin_idx, end_idx
                    ]
                    details.append(detail_word)
            text_new += blk
        details = sorted(details, key=operator.itemgetter(2))
        return text_new, details
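
This is the same pipeline as Example #1, using the default candidate-generation fragment count and perplexity threshold instead of caller-supplied values.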
Example #8
 def detect(self, sentence):
     maybe_errors = []
     if not sentence.strip():
         return maybe_errors
     # Initialize
     self.check_detector_initialized()
     # Normalize encoding: utf-8 to unicode
     sentence = convert_to_unicode(sentence)
     # Normalize the text
     sentence = uniform(sentence)
     # Split the long sentence into short blocks
     blocks = re_han.split(sentence)
     start_idx = 0
     for blk in blocks:
         if not blk:
             continue
         # Only runs of Chinese characters are checked; start_idx still
         # advances over non-Chinese runs to keep offsets aligned
         if re_han.match(blk):
             maybe_errors += self._detect_short(blk, start_idx)
         start_idx += len(blk)
     return maybe_errors
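
`re_han.split` with a capturing group yields alternating non-Chinese and Chinese runs, which is how `start_idx` stays aligned with the original text. A sketch with an assumed pattern for `re_han`:

    import re

    # Assumed pattern: capture runs of CJK characters.
    re_han = re.compile('([\u4e00-\u9fa5]+)')
    print(re_han.split('中文abc,更多中文'))
    # ['', '中文', 'abc,', '更多中文', '']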
Example #9
    def predict(self, text, **kwargs):
        details = []
        text_new = ''
        self.check_corrector_initialized()
        # Normalize encoding: utf-8 to unicode
        text = convert_to_unicode(text)
        # Split long text into short sentences
        blocks = split_text_by_maxlen(text, maxlen=128)
        for blk, start_idx in blocks:
            blk_new = ''
            for idx, s in enumerate(blk):
                # Handle a Chinese-character error
                if is_chinese_string(s):
                    sentence_lst = list(blk_new + blk[idx:])
                    sentence_lst[idx] = self.mask
                    # Predict the masked token; top 10 by default
                    predict_words = self.predict_mask_token(sentence_lst,
                                                            idx,
                                                            k=10)
                    top_tokens = []
                    for w, _ in predict_words:
                        top_tokens.append(w)

                    if top_tokens and (s not in top_tokens):
                        # Generate all candidate corrections
                        candidates = self.generate_items(s)
                        if candidates:
                            for token_str in top_tokens:
                                if token_str in candidates:
                                    details.append(
                                        (s, token_str, start_idx + idx,
                                         start_idx + idx + 1))
                                    s = token_str
                                    break
                blk_new += s
            text_new += blk_new
        details = sorted(details, key=operator.itemgetter(2))
        return text_new, details