Example #1
# Assumes the `tokenization` module from the original BERT repo
# (https://github.com/google-research/bert) is importable; with the
# bert-tensorflow pip package this is `from bert import tokenization`.
import tokenization

def __init__(self, vocab_file):
    # BERT special symbols: [CLS] doubles as the "go" token, [SEP] as
    # end-of-sequence, and [PAD] as the padding token.
    self.SYMBOL_GO = '[CLS]'
    self.SYMBOL_EOS = '[SEP]'
    self.SYMBOL_PAD = '[PAD]'
    self.more_tokens = []
    self.tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file,
                                                do_lower_case=True)
    # Cache the integer ids of the special symbols from the vocabulary.
    self.pad_id = self.tokenizer.vocab[self.SYMBOL_PAD]
    self.eos_id = self.tokenizer.vocab[self.SYMBOL_EOS]
    self.go_id = self.tokenizer.vocab[self.SYMBOL_GO]
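
A minimal usage sketch, assuming the constructor above belongs to a vocabulary wrapper class (`BertVocab` is a hypothetical name) and that a standard BERT vocab file is on disk:

# `BertVocab` is a hypothetical name for the class owning __init__ above.
vocab = BertVocab(vocab_file='vocab.txt')
# With the standard BERT-base uncased vocab this prints 101 102 0.
print(vocab.go_id, vocab.eos_id, vocab.pad_id)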
Example #2
    def build_ids(self, text_a, text_b=None, **kwargs):
        """Converts a text (or a text pair) into BERT input ids, attention
        mask, and segment ids, each padded to `self.maxlen`."""
        # Note: a FullTokenizer is built on every call; caching one
        # instance would be cheaper.
        tokenizer = tokenization.FullTokenizer(vocab_file=self.vocab_file,
                                               do_lower_case=True)

        tokens_a = tokenizer.tokenize(text_a)
        tokens_b = None
        if text_b:
            tokens_b = tokenizer.tokenize(text_b)
        if tokens_b:
            # Modifies `tokens_a` and `tokens_b` in place so that the total
            # length is less than the specified length (a sketch of this
            # helper follows the method).
            # Account for [CLS], [SEP], [SEP] with "- 3".
            self._truncate_seq_pair(tokens_a, tokens_b, self.maxlen - 3)
        else:
            # Account for [CLS] and [SEP] with "- 2"
            if len(tokens_a) > self.maxlen - 2:
                tokens_a = tokens_a[0:(self.maxlen - 2)]

        # Assemble the final sequence: [CLS] tokens_a [SEP] (+ tokens_b [SEP]).
        # Sentence A and its surrounding markers get segment id 0.
        tokens = []
        segment_ids = []
        tokens.append("[CLS]")
        segment_ids.append(0)
        for token in tokens_a:
            tokens.append(token)
            segment_ids.append(0)
        tokens.append("[SEP]")
        segment_ids.append(0)

        # Sentence B and its trailing [SEP] get segment id 1.
        if tokens_b:
            for token in tokens_b:
                tokens.append(token)
                segment_ids.append(1)
            tokens.append("[SEP]")
            segment_ids.append(1)

        input_ids = tokenizer.convert_tokens_to_ids(tokens)

        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        input_mask = [1] * len(input_ids)

        # Zero-pad up to the sequence length (id 0 is [PAD] in standard
        # BERT vocabularies).
        while len(input_ids) < self.maxlen:
            input_ids.append(0)
            input_mask.append(0)
            segment_ids.append(0)

        assert len(input_ids) == self.maxlen
        assert len(input_mask) == self.maxlen
        assert len(segment_ids) == self.maxlen
        return input_ids, input_mask, segment_ids
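
The method calls a `_truncate_seq_pair` helper that is not shown above. Below is a sketch of it, modeled on the helper of the same name in the reference BERT repo, written as a method here because the snippet invokes it via `self`; trimming one token at a time from the longer sequence keeps the pair as balanced as possible:

    def _truncate_seq_pair(self, tokens_a, tokens_b, max_length):
        """Truncates a token-list pair in place until the total fits."""
        while len(tokens_a) + len(tokens_b) > max_length:
            # Always trim the currently longer sequence.
            if len(tokens_a) > len(tokens_b):
                tokens_a.pop()
            else:
                tokens_b.pop()

And a minimal usage sketch, assuming the enclosing class (hypothetically `BertInputBuilder`) sets `self.vocab_file` and `self.maxlen` in its constructor:

# `BertInputBuilder` is a hypothetical name for the class owning build_ids.
builder = BertInputBuilder(vocab_file='vocab.txt', maxlen=128)
input_ids, input_mask, segment_ids = builder.build_ids("how are you?",
                                                       "fine, thanks")
# All three outputs are Python lists of length builder.maxlen.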