from collections import OrderedDict

# Assumed imports: "pytt" is taken here to be the pytorch_transformers package,
# whose tokenization_bert module provides BasicTokenizer and WordpieceTokenizer.
# SerializationMixin, BASE_CLASS_FIELDS and clean_accents are package-internal
# helpers that are not shown in this listing.
import pytorch_transformers as pytt
from pytorch_transformers.tokenization_bert import BasicTokenizer, WordpieceTokenizer


class SerializableBertTokenizer(pytt.BertTokenizer, SerializationMixin):
    serialization_fields = list(BASE_CLASS_FIELDS) + [
        "vocab",
        "do_basic_tokenize",
        "do_lower_case",
        "never_split",
        "tokenize_chinese_chars",
    ]

    @classmethod
    def blank(cls):
        # Create a bare, uninitialised instance; every serialized field starts
        # as None and is filled in later during deserialization.
        self = cls.__new__(cls)
        for field in self.serialization_fields:
            setattr(self, field, None)
        self.ids_to_tokens = None
        self.basic_tokenizer = None
        self.wordpiece_tokenizer = None
        return self

    def prepare_for_serialization(self):
        # Copy the basic tokenizer's settings onto the instance so they are
        # captured by serialization_fields.
        if self.basic_tokenizer is not None:
            self.do_lower_case = self.basic_tokenizer.do_lower_case
            self.never_split = self.basic_tokenizer.never_split
            self.tokenize_chinese_chars = self.basic_tokenizer.tokenize_chinese_chars

    def finish_deserializing(self):
        # Rebuild the derived state (reverse vocab lookup and sub-tokenizers)
        # from the restored fields.
        self.ids_to_tokens = OrderedDict([(ids, tok)
                                          for tok, ids in self.vocab.items()])
        if self.do_basic_tokenize:
            self.basic_tokenizer = BasicTokenizer(
                do_lower_case=self.do_lower_case,
                never_split=self.never_split,
                tokenize_chinese_chars=self.tokenize_chinese_chars,
            )
        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab,
                                                      unk_token=self.unk_token)

    def clean_token(self, text):
        # Normalize a raw token: drop control characters, then strip whitespace
        # and accents.
        if self.do_basic_tokenize:
            text = self.basic_tokenizer._clean_text(text)
        text = text.strip()
        return clean_accents(text)

    def clean_wp_token(self, token):
        # Strip the WordPiece continuation marker ("##") and surrounding whitespace.
        return token.replace("##", "", 1).strip()

    def add_special_tokens(self, tokens):
        # Wrap a single wordpiece sequence in the cls/sep special tokens.
        return [self.cls_token] + tokens + [self.sep_token]
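
The snippet below is a usage sketch rather than part of the original source: it assumes the tokenizer is loaded through the inherited BertTokenizer.from_pretrained (the "bert-base-uncased" name is illustrative) and that the serialization mixin ultimately performs a field-by-field round trip like the manual one shown here.

# Hedged sketch: the dict-based round trip stands in for whatever wire format
# SerializationMixin actually uses.
tokenizer = SerializableBertTokenizer.from_pretrained("bert-base-uncased")

tokenizer.prepare_for_serialization()           # push basic-tokenizer flags onto self
state = {f: getattr(tokenizer, f) for f in tokenizer.serialization_fields}

restored = SerializableBertTokenizer.blank()    # bare instance, all fields None
for field, value in state.items():
    setattr(restored, field, value)
restored.finish_deserializing()                 # rebuild ids_to_tokens and the sub-tokenizers

assert restored.vocab == tokenizer.vocab
assert restored.clean_wp_token("##gging") == "gging"
assert tokenizer.add_special_tokens(["hu", "##gging"]) == ["[CLS]", "hu", "##gging", "[SEP]"]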
Example #2
class SerializableBertTokenizer(pytt.BertTokenizer, SerializationMixin):
    serialization_fields = list(BASE_CLASS_FIELDS) + [
        "vocab",
        "do_basic_tokenize",
        "do_lower_case",
        "never_split",
        "tokenize_chinese_chars",
    ]

    @classmethod
    def blank(cls):
        self = cls.__new__(cls)
        for field in self.serialization_fields:
            setattr(self, field, None)
        self.ids_to_tokens = None
        self.basic_tokenizer = None
        self.wordpiece_tokenizer = None
        return self

    def prepare_for_serialization(self):
        if self.basic_tokenizer is not None:
            self.do_lower_case = self.basic_tokenizer.do_lower_case
            self.never_split = self.basic_tokenizer.never_split
            self.tokenize_chinese_chars = self.basic_tokenizer.tokenize_chinese_chars

    def finish_deserializing(self):
        self.ids_to_tokens = OrderedDict([(ids, tok)
                                          for tok, ids in self.vocab.items()])
        if self.do_basic_tokenize:
            self.basic_tokenizer = BasicTokenizer(
                do_lower_case=self.do_lower_case,
                never_split=self.never_split,
                tokenize_chinese_chars=self.tokenize_chinese_chars,
            )
        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab,
                                                      unk_token=self.unk_token)

    def clean_token(self, text):
        if self.do_basic_tokenize:
            text = self.basic_tokenizer._clean_text(text)
        text = text.strip()
        return clean_accents(text)

    def clean_wp_token(self, token):
        return token.replace("##", "", 1).strip()

    def add_special_tokens(self, segments):
        # Join multiple segments, appending the sep token after every non-empty
        # segment and prefixing the cls token when there is any content at all.
        output = []
        for segment in segments:
            output.extend(segment)
            if segment:
                output.append(self.sep_token)
        if output:
            # Only prepend the cls token when there is any content;
            # an otherwise empty output stays empty.
            output.insert(0, self.cls_token)
        return output

    def fix_alignment(self, segments):
        """Turn a nested segment alignment into an alignment for the whole input,
        by offsetting and accounting for special tokens."""
        offset = 0
        output = []
        for segment in segments:
            if segment:
                offset += 1
            seen = set()
            for idx_group in segment:
                output.append([idx + offset for idx in idx_group])
                seen.update({idx for idx in idx_group})
            offset += len(seen)
        return output
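
As a worked example (the input below is assumed, not taken from the original source): fix_alignment reads no instance state, so even a blank() instance can run it.

# Segment 1: two original tokens, mapping to wordpieces 0 and 1-2;
# segment 2: a single original token, mapping to wordpiece 0.
segments = [[[0], [1, 2]], [[0]]]

tok = SerializableBertTokenizer.blank()   # fix_alignment uses no instance attributes
aligned = tok.fix_alignment(segments)

# add_special_tokens would join the wordpieces as
#     [CLS] w0 w1 w2 [SEP] w0' [SEP]      (positions 0..6)
# so every index group is shifted past [CLS] and any earlier [SEP] tokens.
assert aligned == [[1], [2, 3], [5]]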