Example #1
  def test_convert_tokens_to_ids(self):
    vocab_tokens = [
        "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
        "##ing"
    ]

    vocab = {}
    for (i, token) in enumerate(vocab_tokens):
      vocab[token] = i

    self.assertAllEqual(
        tokenization.convert_tokens_to_ids(
            vocab, ["un", "##want", "##ed", "runn", "##ing"]), [7, 4, 5, 8, 9])
Example #2
    def tokenize(self, text, tags_list=None, test_id=None):
        """
        使用wordpiece对norm token切分后得到的token
        :return:
        """
        norm_data = self.__norm_tokenizer.tokenize(text, tags_list, test_id)
        self.return_tag = norm_data['norm_tags'] is not None
        new_tokens, new_sents, new_tags = [], [], None
        tok_to_orig_index, orig_to_tok_index = [], []

        input_ids, tag_ids = [], None
        input_mask, segment_ids = [], []
        type_vocab = None
        new_chars, char_ids = [], []
        if self.return_tag:
            new_tags, tag_ids = [], []
            type_vocab = build_tag_vocab()

        ch_tokenizer = CharTokenizer(self.max_char_len)

        norm_sents = norm_data['tokens_by_sent']
        norm_tags = norm_data['norm_tags']

        for sent_id, sent in enumerate(norm_sents):
            tok_to_orig, orig_to_tok = [], []
            new_sent, new_char, new_char_ids = [], [], []
            new_sent.append("[CLS]")
            # tokenize() with no argument presumably returns the placeholder
            # (chars, char_ids) pair used for special and padded positions.
            cls_char, cls_char_ids = ch_tokenizer.tokenize()
            new_char.append(cls_char)
            new_char_ids.append(cls_char_ids)
            new_tag = []
            if self.return_tag:
                new_tag.append("BOS")
            for word_id, word in enumerate(sent):
                sub_tokens = self.__wordpiece_tokenizer.tokenize(word)
                if len(new_sent) + len(sub_tokens) >= self.max_seq_len:
                    norm_sents.insert(sent_id + 1, sent[word_id:])
                    if self.return_tag:
                        norm_tags.insert(sent_id + 1,
                                         norm_tags[sent_id][word_id:])
                    break
                orig_to_tok.append(len(new_sent))
                (c, c_ids) = ch_tokenizer.tokenize(word)
                for idx, sub_token in enumerate(sub_tokens):
                    tok_to_orig.append(word_id)
                    new_sent.append(sub_token)
                    new_tokens.append(sub_token)
                    new_char.append(c)
                    new_char_ids.append(c_ids)
                    if self.return_tag:
                        if idx == 0:
                            new_tag.append(norm_tags[sent_id][word_id])
                        else:
                            new_tag.append('X')
            new_sent.append("[SEP]")
            sep_char, sep_char_ids = ch_tokenizer.tokenize()
            new_char.append(sep_char)
            new_char_ids.append(sep_char_ids)
            if self.return_tag:
                new_tag.append("EOS")
            orig_to_tok_index.append(orig_to_tok)
            tok_to_orig_index.append(tok_to_orig)
            assert len(new_sent) == len(tok_to_orig) + 2
            mask = [1] * len(new_sent)

            if len(new_sent) > self.max_seq_len:
                raise ValueError("The sentence is longger than: %d" %
                                 (self.max_seq_len))
            while len(new_sent) < self.max_seq_len:
                new_sent.append("[PAD]")
                mask.append(0)
                pad_char, pad_char_ids = ch_tokenizer.tokenize()
                new_char.append(pad_char)
                new_char_ids.append(pad_char_ids)
                if self.return_tag:
                    new_tag.append("O")

            assert len(new_sent) == self.max_seq_len
            assert len(new_char) == self.max_seq_len
            assert len(new_char_ids) == self.max_seq_len

            new_sents.append(new_sent)
            new_chars.append(new_char)
            char_ids.append(new_char_ids)
            input_ids.append(
                tokenization.convert_tokens_to_ids(self.vocab, new_sent))
            input_mask.append(mask)
            segment_ids.append([0] * self.max_seq_len)
            if self.return_tag:
                assert len(new_sent) == len(new_tag)
                new_tags.append(new_tag)
                tag_ids.append(
                    tokenization.convert_tokens_to_ids(type_vocab, new_tag))

        assert len(new_sents) == len(tok_to_orig_index)
        assert len(new_sents) == len(orig_to_tok_index)
        if self.return_tag:
            assert len(new_sents) == len(new_tags)

        wp_data = {}
        wp_data['tokens'] = new_tokens
        wp_data['tokens_by_sent'] = new_sents
        wp_data['tok_to_orig_index'] = tok_to_orig_index
        wp_data['orig_to_tok_index'] = orig_to_tok_index
        # norm_data includes the pre-split tokens, tags and char_to_word_offset
        wp_data['norm_data'] = norm_data
        wp_data['wp_tags'] = new_tags
        wp_data["input_ids"] = input_ids
        wp_data["input_mask"] = input_mask
        wp_data["segment_ids"] = segment_ids
        wp_data["tag_ids"] = tag_ids
        wp_data["chars"] = new_chars
        wp_data["char_ids"] = char_ids

        return wp_data

    def convert_tokens_to_ids(self, tokens):
        return tokenization.convert_tokens_to_ids(self.vocab, tokens)
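
The heart of the tokenize method above is the bookkeeping between original words and their WordPiece pieces (tok_to_orig_index, orig_to_tok_index, and the 'X' tag on continuation pieces). Below is a self-contained sketch of that pattern; split_word is a toy stand-in for the WordPiece tokenizer, and the example sentence and tag names are assumptions for illustration only.

# Toy stand-in for the WordPiece tokenizer used by the class above.
def split_word(word):
    pieces = {"unwanted": ["un", "##want", "##ed"], "running": ["runn", "##ing"]}
    return pieces.get(word, [word])

sent = ["he", "was", "running", "unwanted", "errands"]
tags = ["O", "O", "B-ACT", "O", "O"]          # illustrative tag names

new_sent, new_tag = ["[CLS]"], ["BOS"]
tok_to_orig, orig_to_tok = [], []
for word_id, word in enumerate(sent):
    orig_to_tok.append(len(new_sent))         # position of the word's first piece
    for idx, sub in enumerate(split_word(word)):
        tok_to_orig.append(word_id)           # every piece points back to its word
        new_sent.append(sub)
        new_tag.append(tags[word_id] if idx == 0 else "X")
new_sent.append("[SEP]")
new_tag.append("EOS")

# [CLS] and [SEP] have no source word, hence the "+ 2" in the method's assertion.
assert len(new_sent) == len(tok_to_orig) + 2
assert len(new_sent) == len(new_tag)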