from collections import OrderedDict

import transformers
# BasicTokenizer and WordpieceTokenizer live in transformers' BERT tokenization
# module; older releases exposed them as transformers.tokenization_bert instead.
from transformers.models.bert.tokenization_bert import (BasicTokenizer,
                                                         WordpieceTokenizer)

# SerializationMixin, BASE_CLASS_FIELDS and clean_accents are assumed to be
# defined elsewhere in the surrounding module; they are not part of this
# excerpt.


class SerializableBertTokenizer(transformers.BertTokenizer,
                                SerializationMixin):
    serialization_fields = list(BASE_CLASS_FIELDS) + [
        "vocab",
        "do_basic_tokenize",
        "do_lower_case",
        "never_split",
        "tokenize_chinese_chars",
    ]

    @classmethod
    def blank(cls):
        self = cls.__new__(cls)
        for field in self.serialization_fields:
            setattr(self, field, None)
        self.ids_to_tokens = None
        self.basic_tokenizer = None
        self.wordpiece_tokenizer = None
        return self

    def prepare_for_serialization(self):
        if self.basic_tokenizer is not None:
            self.do_lower_case = self.basic_tokenizer.do_lower_case
            self.never_split = self.basic_tokenizer.never_split
            self.tokenize_chinese_chars = self.basic_tokenizer.tokenize_chinese_chars
        super().prepare_for_serialization()

    def finish_deserializing(self):
        self.ids_to_tokens = OrderedDict([(ids, tok)
                                          for tok, ids in self.vocab.items()])
        if self.do_basic_tokenize:
            self.basic_tokenizer = BasicTokenizer(
                do_lower_case=self.do_lower_case,
                never_split=self.never_split,
                tokenize_chinese_chars=self.tokenize_chinese_chars,
            )
        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab,
                                                      unk_token=self.unk_token)
        super().finish_deserializing()

    def clean_token(self, text):
        if self.do_basic_tokenize:
            text = self.basic_tokenizer._clean_text(text)
        text = text.strip()
        return clean_accents(text)

    def clean_wp_token(self, token):
        return token.replace("##", "", 1).strip()

    def add_special_tokens(self, segments):
        output = []
        for segment in segments:
            output.extend(segment)
            if segment:
                output.append(self.sep_token)
        if output:
            # Only prepend the CLS token when there is something to prepend it
            # to; an otherwise empty output stays empty.
            output.insert(0, self.cls_token)
        return output

    def fix_alignment(self, segments):
        """Turn a nested segment alignment into an alignment for the whole input,
        by offsetting and accounting for special tokens."""
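        # Each non-empty segment is preceded by one special token in the flat
        # sequence ([CLS] for the first, the previous segment's trailing [SEP]
        # afterwards), which `offset += 1` accounts for; `offset += len(seen)`
        # accounts for the wordpieces of the segment itself.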
        offset = 0
        output = []
        for segment in segments:
            if segment:
                offset += 1
            seen = set()
            for idx_group in segment:
                output.append([idx + offset for idx in idx_group])
                seen.update(idx_group)
            offset += len(seen)
        return output
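

# A minimal usage sketch (not part of the original class): given a tokenizer
# loaded the usual way, e.g. SerializableBertTokenizer.from_pretrained(...),
# it shows how add_special_tokens() and fix_alignment() combine per-segment
# wordpieces and per-segment alignments into one flat sequence. The wordpieces
# below are hard-coded purely for illustration, and the expected outputs assume
# the default BERT special tokens [CLS]/[SEP].
def _flatten_segments_demo(tokenizer):
    wp_segments = [["hel", "##lo", "world"], ["again"]]  # wordpieces per segment
    alignments = [[[0, 1], [2]], [[0]]]  # token -> wordpiece indices per segment
    flat = tokenizer.add_special_tokens(wp_segments)
    # ['[CLS]', 'hel', '##lo', 'world', '[SEP]', 'again', '[SEP]']
    align = tokenizer.fix_alignment(alignments)
    # [[1, 2], [3], [5]]  -- positions in `flat`
    return flat, align
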
Example No. 2
    def read_sequence(self, dataset, path, is_train, max_sents):
        """
        Reads conllu-like files. It relies heavily on reader_utils.seqs2data.
        Can also read sentence classification tasks for which the labels should
        be specified in the comments.
        Note that this read corresponds to a variety of task_types, but the
        differences between them during data reading are kept minimal
        """
        data = []
        word_idx = self.datasets[dataset]['word_idx']
        sent_counter = 0
        tknzr = BasicTokenizer()

        for sent, full_data in seqs2data(path, self.do_lowercase):
            task2type = {}
            sent_counter += 1
            if max_sents != 0 and sent_counter > max_sents:
                break
            sent_tasks = {}
            tokens = [token[word_idx] for token in sent]
            # Replace tokens that the BasicTokenizer would clean away entirely
            # (e.g. control characters) with the wordpiece tokenizer's UNK token.
            for token_idx, token in enumerate(tokens):
                if len(tknzr._clean_text(token)) == 0:
                    tokens[token_idx] = self.tokenizer.tokenizer.unk_token
            sent_tasks['tokens'] = [Token(token) for token in tokens]

            col_idxs = {'word_idx': word_idx}
            for task in self.datasets[dataset]['tasks']:
                sent_tasks[task] = []
                task_type = self.datasets[dataset]['tasks'][task]['task_type']
                task_idx = self.datasets[dataset]['tasks'][task]['column_idx']
                task2type[task] = task_type
                col_idxs[task] = task_idx
                if task_type == 'classification' and task_idx == -1:
                    start = '# ' + task + ': '
                    for line in full_data:
                        if line[0].startswith(start):
                            sent_tasks[task] = line[0][len(start):]
                elif task_type in ['seq', 'multiseq', 'seq_bio']:
                    for word_data in sent:
                        sent_tasks[task].append(word_data[task_idx])
                elif task_type == 'string2string':
                    for word_data in sent:
                        task_label = gen_lemma_rule(word_data[word_idx],
                                                    word_data[task_idx])
                        sent_tasks[task].append(task_label)
                elif task_type == 'dependency':
                    heads = []
                    rels = []
                    for word_data in sent:
                        if not word_data[task_idx].isdigit():
                            logger.error(
                                "Your dependency file " + path +
                                " seems to contain invalid structures sentence "
                                + str(sent_counter) +
                                " contains a non-integer head: " +
                                word_data[task_idx] +
                                "\nIf you directly used UD data, this could be due to special EUD constructions which we do not support, you can clean your conllu file by using scripts/misc/cleanconl.py"
                            )
                            exit(1)
                        heads.append(int(word_data[task_idx]))
                        rels.append(word_data[task_idx + 1])
                    sent_tasks[task] = list(zip(rels, heads))
                else:
                    logger.error('Task type ' + task_type + ' for task ' +
                                 task + ' in dataset ' + dataset +
                                 ' is unknown')
            data.append(
                self.text_to_instance(sent_tasks, full_data, col_idxs,
                                      is_train, task2type, dataset))
        return data
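

# Hedged illustration (not from the original reader): the configuration shape
# that read_sequence() expects under self.datasets, reconstructed from the
# lookups above. The dataset name, task names and column indices below are
# invented (they mimic standard CoNLL-U columns); only the keys 'word_idx',
# 'tasks', 'task_type' and 'column_idx' come from the code itself.
EXAMPLE_DATASET_CONFIG = {
    "ewt": {
        "word_idx": 1,  # column holding the word form
        "tasks": {
            "upos": {"task_type": "seq", "column_idx": 3},
            "lemma": {"task_type": "string2string", "column_idx": 2},
            # heads read from column 6, relations from column 7 (column_idx + 1)
            "dependency": {"task_type": "dependency", "column_idx": 6},
            # column_idx == -1: the label is read from a '# sentiment: ...' comment line
            "sentiment": {"task_type": "classification", "column_idx": -1},
        },
    },
}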