Example #1
 def load_same_pinyin(path, sep='\t'):
     """
     Load homophone characters (characters with the same pinyin).
     :param path:
     :param sep:
     :return:
     """
     result = dict()
     if not os.path.exists(path):
         logger.warning("file not found: " + path)
         return result
     with codecs.open(path, 'r', encoding='utf-8') as f:
         for line in f:
             line = line.strip()
             if line.startswith('#'):
                 continue
             parts = line.split(sep)
             if parts and len(parts) > 2:
                 key_char = parts[0]
                 same_pron_same_tone = set(parts[1])
                 same_pron_diff_tone = set(parts[2])
                 value = same_pron_same_tone.union(same_pron_diff_tone)
                 if key_char and value:
                     result[key_char] = value
     return result
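A minimal usage sketch for the loader above; the file name and its contents are illustrative assumptions inferred from the parsing logic (tab-separated columns: key character, same pinyin and tone, same pinyin with a different tone):

    # same_pinyin.txt might contain a line such as (tab-separated):
    # 他	她它塔	踏蹋
    same_pinyin = load_same_pinyin('same_pinyin.txt')
    print(same_pinyin.get('他'))  # {'她', '它', '塔', '踏', '蹋'}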
Example #2
    def __init__(self, config):
        super().__init__(config)

        if not config.is_decoder:
            logger.warning("If you want to use `BertGenerationDecoder` as a standalone, add `is_decoder=True`.")

        self.bert = BertGenerationEncoder(config)
        self.lm_head = BertGenerationOnlyLMHead(config)

        self.init_weights()
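A brief usage sketch for this decoder; the checkpoint name follows the transformers documentation, and setting is_decoder=True on the config avoids the warning emitted in __init__ above:

    # Usage sketch (standard transformers API; checkpoint name per the docs):
    from transformers import BertGenerationConfig, BertGenerationDecoder

    config = BertGenerationConfig.from_pretrained(
        "google/bert_for_seq_generation_L-24_bbc_encoder")
    config.is_decoder = True  # mark as decoder to avoid the warning above
    decoder = BertGenerationDecoder.from_pretrained(
        "google/bert_for_seq_generation_L-24_bbc_encoder", config=config)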
Example #3
 def load_same_stroke(path, sep='\t'):
     """
     Load visually similar (same-stroke) characters.
     :param path:
     :param sep:
     :return:
     """
     result = dict()
     if not os.path.exists(path):
         logger.warning("file not found: " + path)
         return result
     with codecs.open(path, 'r', encoding='utf-8') as f:
         for line in f:
             line = line.strip()
             if line.startswith('#'):
                 continue
             parts = line.split(sep)
             if parts and len(parts) > 1:
                 for i, c in enumerate(parts):
                     result[c] = set(parts[:i] + parts[i + 1:])
     return result
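Likewise, a minimal usage sketch; the file name and line contents are assumptions. Every character on a line maps to the set of its look-alikes:

    # same_stroke.txt might contain a line such as (tab-separated):
    # 己	已	巳
    same_stroke = load_same_stroke('same_stroke.txt')
    print(same_stroke.get('己'))  # {'已', '巳'}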
Example #4
    def detect_short(self, sentence, start_idx=0):
        """
        Detect suspected errors in the sentence: [word, position, error type].
        :param sentence:
        :param start_idx:
        :return: list[list], [error_word, begin_pos, end_pos, error_type]
        """
        maybe_errors = []
        # initialize the detector
        self.check_detector_initialized()
        # add custom confusion-set matches to the suspected errors
        for confuse in self.custom_confusion:
            idx = sentence.find(confuse)
            if idx > -1:
                maybe_err = [
                    confuse, idx + start_idx, idx + len(confuse) + start_idx,
                    ErrorType.confusion
                ]
                self._add_maybe_error_item(maybe_err, maybe_errors)

        if self.is_word_error_detect:
            # word segmentation
            tokens = self.tokenizer.tokenize(sentence)
            # treat out-of-vocabulary words as suspected errors
            for token, begin_idx, end_idx in tokens:
                # pass filter word
                if self.is_filter_token(token):
                    continue
                # pass in dict
                if token in self.word_freq:
                    continue
                maybe_err = [
                    token, begin_idx + start_idx, end_idx + start_idx,
                    ErrorType.word
                ]
                self._add_maybe_error_item(maybe_err, maybe_errors)

        if self.is_char_error_detect:
            # use the language model to detect suspected erroneous characters
            try:
                ngram_avg_scores = []
                for n in [2, 3]:
                    scores = []
                    for i in range(len(sentence) - n + 1):
                        word = sentence[i:i + n]
                        score = self.ngram_score(list(word))
                        scores.append(score)
                    if not scores:
                        continue
                    # pad both ends so every character gets a full window of scores
                    for _ in range(n - 1):
                        scores.insert(0, scores[0])
                        scores.append(scores[-1])
                    avg_scores = [
                        sum(scores[i:i + n]) / len(scores[i:i + n])
                        for i in range(len(sentence))
                    ]
                    ngram_avg_scores.append(avg_scores)

                if ngram_avg_scores:
                    # average the 2-gram and 3-gram score lists
                    sent_scores = list(
                        np.average(np.array(ngram_avg_scores), axis=0))
                    # collect suspected erroneous characters
                    for i in self._get_maybe_error_index(sent_scores):
                        token = sentence[i]
                        # pass filter word
                        if self.is_filter_token(token):
                            continue
                        # pass in stop word dict
                        if token in self.stopwords:
                            continue
                        # token, begin_idx, end_idx, error_type
                        maybe_err = [
                            token, i + start_idx, i + start_idx + 1,
                            ErrorType.char
                        ]
                        self._add_maybe_error_item(maybe_err, maybe_errors)
            except IndexError as ie:
                logger.warning("index error, sentence:" + sentence + str(ie))
            except Exception as e:
                logger.warning("detect error, sentence:" + sentence + str(e))
        return sorted(maybe_errors, key=lambda k: k[1], reverse=False)
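To make the window-padding step concrete, a tiny worked example with made-up scores: for n = 2 over a four-character sentence, three bigram scores are computed, then one copy of the first and last score is added at each end so the window averages line up with the characters:

    # Worked example of the score padding (scores are made up):
    n, sent_len = 2, 4
    scores = [0.2, 0.9, 0.4]          # one score per bigram
    for _ in range(n - 1):            # pad both ends once for n = 2
        scores.insert(0, scores[0])
        scores.append(scores[-1])
    # scores is now [0.2, 0.2, 0.9, 0.4, 0.4]
    avg = [sum(scores[i:i + n]) / n for i in range(sent_len)]
    print(avg)                        # [0.2, 0.55, 0.65, 0.4]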
Example #5
    def detect_short(self, sentence, start_idx=0):
        """
        Detect suspected errors in the sentence: [word, position, error type].
        :param sentence:
        :param start_idx:
        :return: list[list], [error_word, begin_pos, end_pos, error_type]
        """
        maybe_errors = []
        # 初始化
        self.check_detector_initialized()
        # 自定义混淆集加入疑似错误词典
        """
        直接在句子中遍历是否在混淆集中有出现,出现则直接添加到错误列表中。严格的匹配逻辑,可以通过修改混淆集文件,进行词的添加删除。
        """
        for confuse in self.custom_confusion:
            idx = sentence.find(confuse)
            if idx > -1:
                maybe_err = [
                    confuse, idx + start_idx, idx + len(confuse) + start_idx,
                    ErrorType.confusion
                ]
                self._add_maybe_error_item(maybe_err, maybe_errors)
        """
        错误检测部分先通过结巴中文分词器切词,由于句子中含有错别字,所以切词结果往往会有切分错误的情况,
        这样从字粒度和词粒度两方面检测错误, 整合这两种粒度的疑似错误结果,形成疑似错误位置候选集;

        词级别搜索:
        依次进行切词,然后遍历每个词,若词不在词典中,也认为是错误。
        这类词包括一些实体,一些错词,一些没有在词典中出现过,但是是正确的词等。
        这条规则比较严格,错词不放过,但是也错杀了一些其他正确词。
        但是优点同第一,可以灵活修改词典。因此,这步需要一个好的预先构造的词典。
        """
        if self.is_word_error_detect:
            # word segmentation
            tokens = self.tokenizer.tokenize(sentence)
            # treat out-of-vocabulary words as suspected errors
            for token, begin_idx, end_idx in tokens:
                # pass filter word
                if self.is_filter_token(token):
                    continue
                # pass in dict
                if token in self.word_freq:
                    continue
                maybe_err = [
                    token, begin_idx + start_idx, end_idx + start_idx,
                    ErrorType.word
                ]
                self._add_maybe_error_item(maybe_err, maybe_errors)
        """
        与词级别搜索不同,字级别不需要进行切词,依次进行打分。分数由一个基于人民日报语料预训练好的语言模型得出。 
        具体计算步骤如下: 
        1. 计算基于字的2-gram和3-gram的得分列表,二者取平均得到sent的每个字的分数。 
        2. 根据每个字的平均得分列表,找到可能的错误字的位置。

        根据每个字的平均得分列表,找到可能的错误字的位置(self._get_maybe_error_index);
        因此,这里要考虑找错的具体逻辑。代码中的实现是基于类似平均绝对离差(MAD)的统计概念,
        这里也是一个策略上的改进的方向,甚至多种策略的共同组合判断。
        """
        if self.is_char_error_detect:
            # use the language model to detect suspected erroneous characters
            try:
                ngram_avg_scores = []
                for n in [2, 3]:
                    scores = []
                    for i in range(len(sentence) - n + 1):
                        word = sentence[i:i + n]
                        # ngram_score is computed by the underlying language model
                        score = self.ngram_score(list(word))
                        scores.append(score)
                    if not scores:
                        continue
                    # pad both ends so every character gets a full window of scores
                    for _ in range(n - 1):
                        scores.insert(0, scores[0])
                        scores.append(scores[-1])
                    avg_scores = [
                        sum(scores[i:i + n]) / len(scores[i:i + n])
                        for i in range(len(sentence))
                    ]
                    ngram_avg_scores.append(avg_scores)

                if ngram_avg_scores:
                    # average the 2-gram and 3-gram score lists
                    sent_scores = list(
                        np.average(np.array(ngram_avg_scores), axis=0))
                    # collect suspected erroneous characters
                    for i in self._get_maybe_error_index(sent_scores):
                        token = sentence[i]
                        # pass filter word
                        if self.is_filter_token(token):
                            continue
                        # pass in stop word dict
                        if token in self.stopwords:
                            continue
                        # token, begin_idx, end_idx, error_type
                        maybe_err = [
                            token, i + start_idx, i + start_idx + 1,
                            ErrorType.char
                        ]
                        self._add_maybe_error_item(maybe_err, maybe_errors)
            except IndexError as ie:
                logger.warning("index error, sentence:" + sentence + str(ie))
            except Exception as e:
                logger.warning("detect error, sentence:" + sentence + str(e))
        return sorted(maybe_errors, key=lambda k: k[1], reverse=False)
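The comments above describe self._get_maybe_error_index only in outline, so here is a standalone sketch of the MAD-style outlier test; the ratio and threshold values are illustrative assumptions, not necessarily the project's defaults:

    import numpy as np

    def get_maybe_error_index(scores, ratio=0.6745, threshold=2):
        """Indices whose score is an abnormally low outlier (MAD-style test);
        the ratio and threshold defaults here are illustrative assumptions."""
        scores = np.array(scores, dtype=float)
        median = np.median(scores)             # robust center
        abs_dev = np.abs(scores - median)      # deviation from the median
        mad = np.median(abs_dev)               # median absolute deviation
        if mad == 0:
            return []
        y_score = ratio * abs_dev / mad        # normalized deviation
        # flag characters that deviate strongly AND score below the median
        return list(np.where((y_score > threshold) & (scores < median))[0])

    print(get_maybe_error_index([0.5, 0.52, 0.48, 0.1, 0.49]))  # [3]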
Example #6
    def detect(self, sentence):
        """
        Detect suspected errors in the sentence: [word, position, error type].
        :param sentence:
        :return: [error_word, begin_pos, end_pos, error_type]
        """
        maybe_errors = []
        if not sentence.strip():
            return maybe_errors
        self.check_detector_initialized()
        # normalize the text
        sentence = uniform(sentence)
        # word segmentation
        tokens = self.tokenizer.tokenize(sentence)
        # print(tokens)
        # add custom confusion-set matches to the suspected errors
        for confuse in self.custom_confusion:
            idx = sentence.find(confuse)
            if idx > -1:
                maybe_err = [
                    confuse, idx, idx + len(confuse), error_type["confusion"]
                ]
                self._add_maybe_error_item(maybe_err, maybe_errors)

        if self.is_word_error_detect:
            # treat out-of-vocabulary words as suspected errors
            for word, begin_idx, end_idx in tokens:
                # pass blank
                if not word.strip():
                    continue
                # pass punctuation
                if word in PUNCTUATION_LIST:
                    continue
                # pass num
                if word.isdigit():
                    continue
                # pass alpha
                if is_alphabet_string(word.lower()):
                    continue
                # pass in dict
                if word in self.word_freq:
                    continue
                maybe_err = [word, begin_idx, end_idx, error_type["word"]]
                self._add_maybe_error_item(maybe_err, maybe_errors)

        if self.is_char_error_detect:
            # use the language model to detect suspected erroneous characters
            ngram_avg_scores = []
            try:
                for n in [2, 3]:
                    scores = []
                    for i in range(len(sentence) - n + 1):
                        word = sentence[i:i + n]
                        score = self.ngram_score(list(word))
                        scores.append(score)
                    if not scores:
                        continue
                    # pad both ends so every character gets a full window of scores
                    for _ in range(n - 1):
                        scores.insert(0, scores[0])
                        scores.append(scores[-1])
                    avg_scores = [
                        sum(scores[i:i + n]) / len(scores[i:i + n])
                        for i in range(len(sentence))
                    ]
                    ngram_avg_scores.append(avg_scores)

                # average the 2-gram and 3-gram score lists
                sent_scores = list(
                    np.average(np.array(ngram_avg_scores), axis=0))
                # collect suspected erroneous characters
                for i in self._get_maybe_error_index(sent_scores):
                    maybe_err = [sentence[i], i, i + 1, error_type["char"]]
                    self._add_maybe_error_item(maybe_err, maybe_errors)
            except IndexError as ie:
                logger.warning("index error, sentence:" + sentence + str(ie))
            except Exception as e:
                logger.warning("detect error, sentence:" + sentence + str(e))
        return sorted(maybe_errors, key=lambda k: k[1], reverse=False)
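A hypothetical end-to-end sketch of calling this detector; the class name and the printed output are illustrative, and the return format is [error_word, begin_pos, end_pos, error_type] as documented above:

    # Hypothetical usage; actual class name and output depend on the project.
    d = Detector()
    print(d.detect('少先队员因该为老人让坐'))
    # e.g. [['因该', 4, 6, 'word'], ['坐', 10, 11, 'char']]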
Example #7
    os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
)
default_cache_path = os.path.join(hf_cache_home, "transformers")

# Onetime move from the old location to the new one if no ENV variable has been set.
if (
        os.path.isdir(old_default_cache_path)
        and not os.path.isdir(default_cache_path)
        and "PYTORCH_PRETRAINED_BERT_CACHE" not in os.environ
        and "PYTORCH_TRANSFORMERS_CACHE" not in os.environ
        and "TRANSFORMERS_CACHE" not in os.environ
):
    logger.warning(
        "In Transformers v4.0.0, the default path to cache downloaded models changed from "
        "'~/.cache/torch/transformers' to '~/.cache/huggingface/transformers'. Since you don't seem to have overridden "
        "and '~/.cache/torch/transformers' is a directory that exists, we're moving it to "
        "'~/.cache/huggingface/transformers' to avoid redownloading models you have already in the cache. You should "
        "only see this message once."
    )
    shutil.move(old_default_cache_path, default_cache_path)

PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)

WEIGHTS_NAME = "pytorch_model.bin"
TF2_WEIGHTS_NAME = "tf_model.h5"
TF_WEIGHTS_NAME = "model.ckpt"
FLAX_WEIGHTS_NAME = "flax_model.msgpack"
CONFIG_NAME = "config.json"
MODEL_CARD_NAME = "modelcard.json"
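The three assignments above form a fallback chain, so setting any one of the legacy environment variables still takes effect. A small illustration (the path is hypothetical, and the assumption is that the constants live in transformers.file_utils in this version):

    import os
    # Set the oldest variable before importing; all three constants inherit it.
    os.environ["PYTORCH_PRETRAINED_BERT_CACHE"] = "/data/hf-cache"
    from transformers.file_utils import TRANSFORMERS_CACHE
    print(TRANSFORMERS_CACHE)  # /data/hf-cache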
Example #8
    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
        r"""
        Instantiate one of the tokenizer classes of the library from a pretrained model vocabulary.

        The tokenizer class to instantiate is selected based on the :obj:`model_type` property of the config object
        (either passed as an argument or loaded from :obj:`pretrained_model_name_or_path` if possible), or when it's
        missing, by falling back to using pattern matching on :obj:`pretrained_model_name_or_path`:

        List options

        Params:
            pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
                Can be either:

                    - A string, the `model id` of a predefined tokenizer hosted inside a model repo on huggingface.co.
                      Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
                      a user or organization name, like ``dbmdz/bert-base-german-cased``.
                    - A path to a `directory` containing vocabulary files required by the tokenizer, for instance saved
                      using the :func:`~transformers.PreTrainedTokenizer.save_pretrained` method, e.g.,
                      ``./my_model_directory/``.
                    - A path or url to a single saved vocabulary file if and only if the tokenizer only requires a
                      single vocabulary file (like Bert or XLNet), e.g.: ``./my_model_directory/vocab.txt``. (Not
                      applicable to all derived classes)
            inputs (additional positional arguments, `optional`):
                Will be passed along to the Tokenizer ``__init__()`` method.
            config (:class:`~transformers.PretrainedConfig`, `optional`):
                The configuration object used to determine the tokenizer class to instantiate.
            cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding
                the cached versions if they exist.
            resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
                file exists.
            proxies (:obj:`Dict[str, str]`, `optional`):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            revision (:obj:`str`, `optional`, defaults to :obj:`"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id; since
                huggingface.co stores models and other artifacts in git-based repositories, ``revision`` can be any
                identifier allowed by git.
            subfolder (:obj:`str`, `optional`):
                In case the relevant files are located inside a subfolder of the model repo on huggingface.co (e.g. for
                facebook/rag-token-base), specify it here.
            use_fast (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether or not to try to load the fast version of the tokenizer.
            kwargs (additional keyword arguments, `optional`):
                Will be passed to the Tokenizer ``__init__()`` method. Can be used to set special tokens like
                ``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``,
                ``mask_token``, ``additional_special_tokens``. See parameters in the ``__init__()`` for more details.

        Examples::

            >>> from transformers import AutoTokenizer

            >>> # Download vocabulary from huggingface.co and cache.
            >>> tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

            >>> # Download vocabulary from huggingface.co (user-uploaded) and cache.
            >>> tokenizer = AutoTokenizer.from_pretrained('dbmdz/bert-base-german-cased')

            >>> # If vocabulary files are in a directory (e.g. tokenizer was saved using `save_pretrained('./test/saved_model/')`)
            >>> tokenizer = AutoTokenizer.from_pretrained('./test/bert_saved_model/')

        """
        config = kwargs.pop("config", None)
        if not isinstance(config, PretrainedConfig):
            config = AutoConfig.from_pretrained(pretrained_model_name_or_path,
                                                **kwargs)

        use_fast = kwargs.pop("use_fast", True)

        if config.tokenizer_class is not None:
            tokenizer_class = None
            if use_fast and not config.tokenizer_class.endswith("Fast"):
                tokenizer_class_candidate = f"{config.tokenizer_class}Fast"
                tokenizer_class = tokenizer_class_from_name(
                    tokenizer_class_candidate)
            if tokenizer_class is None:
                tokenizer_class_candidate = config.tokenizer_class
                tokenizer_class = tokenizer_class_from_name(
                    tokenizer_class_candidate)

            if tokenizer_class is None:
                raise ValueError(
                    "Tokenizer class {} does not exist or is not currently imported."
                    .format(tokenizer_class_candidate))
            return tokenizer_class.from_pretrained(
                pretrained_model_name_or_path, *inputs, **kwargs)

        # if model is an encoder decoder, the encoder tokenizer class is used by default
        if isinstance(config, EncoderDecoderConfig):
            if type(config.decoder) is not type(config.encoder):  # noqa: E721
                logger.warning(
                    f"The encoder model config class: {config.encoder.__class__} is different from the decoder model "
                    f"config class: {config.decoder.__class__}. It is not recommended to use the "
                    "`AutoTokenizer.from_pretrained()` method in this case. Please use the encoder and decoder "
                    "specific tokenizer classes.")
            config = config.encoder

        if type(config) in TOKENIZER_MAPPING.keys():
            tokenizer_class_py, tokenizer_class_fast = TOKENIZER_MAPPING[type(
                config)]
            if tokenizer_class_fast and (use_fast
                                         or tokenizer_class_py is None):
                return tokenizer_class_fast.from_pretrained(
                    pretrained_model_name_or_path, *inputs, **kwargs)
            else:
                if tokenizer_class_py is not None:
                    return tokenizer_class_py.from_pretrained(
                        pretrained_model_name_or_path, *inputs, **kwargs)
                else:
                    raise ValueError(
                        "This tokenizer cannot be instantiated. Please make sure you have `sentencepiece` installed "
                        "in order to use this tokenizer.")

        raise ValueError(
            "Unrecognized configuration class {} to build an AutoTokenizer.\n"
            "Model type should be one of {}.".format(
                config.__class__,
                ", ".join(c.__name__ for c in TOKENIZER_MAPPING.keys())))