Example 1
    def __init__(
        self,
        add_bos_token: bool,
        add_eos_token: bool,
        use_eos_token_for_bos: bool,
        max_seq_len: int,
        vocab: Vocabulary,
        tokenizer: Optional[Tokenizer],
    ):
        super().__init__()

        if tokenizer is not None and hasattr(tokenizer, "torchscriptify"):
            try:
                self.tokenizer = tokenizer.torchscriptify()
            except NotImplementedError:
                # This is fine as long as the exported tokenizer is only used
                # in pre-tokenized mode
                self.tokenizer = None
        else:
            self.tokenizer = None

        self.do_nothing_tokenizer = ScriptDoNothingTokenizer()
        self.vocab = ScriptVocabulary(
            list(vocab),
            pad_idx=vocab.get_pad_index(),
            bos_idx=vocab.get_bos_index() if add_bos_token else -1,
            eos_idx=vocab.get_eos_index() if add_eos_token else -1,
        )
        self.vocab_lookup_1d = VocabLookup(self.vocab)

        self.add_bos_token = add_bos_token
        self.add_eos_token = add_eos_token
        self.use_eos_token_for_bos = use_eos_token_for_bos
        self.max_seq_len = max_seq_len
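
Given the constructor above, the 1-D lookup module can also be exercised on its own. The snippet below is a minimal standalone sketch: the import path and the placeholder vocabulary are assumptions, and VocabLookup is fed a list of (token_text, start_offset, end_offset) tuples, as the test examples further down suggest.

from pytext.torchscript.vocab import ScriptVocabulary, VocabLookup  # import path assumed

# Placeholder vocabulary; index 0 is assumed to act as the default UNK index.
vocab = ScriptVocabulary(["__UNK__", "__PAD__", "hello", "world"])
lookup = VocabLookup(vocab)

# Pre-tokenized input: one (text, start_offset, end_offset) tuple per token.
tokens = [("hello", 0, 5), ("world", 6, 11)]
token_ids, start_idxs, end_idxs = lookup(tokens)
# token_ids -> [2, 3] with the placeholder vocabulary above;
# start_idxs / end_idxs presumably echo the input offsets: [0, 6] / [5, 11]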
Example 2
    def test_lookup_tokens(self):
        _, rand_tokens = self._mock_tokenizer()
        vocab = self._mock_vocab()
        vocab_lookup = VocabLookup(vocab)
        token_ids, start_idxs, end_idxs = vocab_lookup(rand_tokens)

        for token_id, token in zip(token_ids, rand_tokens):
            self.assertEqual(token_id, int(token[0]) - 100)
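
The _mock_tokenizer and _mock_vocab helpers are not shown on this page. Judging from the assertion token_id == int(token[0]) - 100, the mock vocabulary maps a numeric token string "1xx" to the index xx; one way to get that behaviour is to build the vocabulary from the strings "100".."199". The helpers below are a hypothetical reconstruction (names, imports of random and the PyText torchscript classes, and details all assumed), not the actual test fixtures:

    def _mock_vocab(self):
        # Hypothetical: entries "100".."199", so looking up "1xx" yields index xx,
        # i.e. int(token) - 100, which is what the assertions above rely on.
        return ScriptVocabulary([str(i) for i in range(100, 200)])

    def _mock_tokenizer(self):
        # Hypothetical: pre-built (text, start, end) tuples whose text is a numeric
        # string covered by the mock vocabulary, plus a do-nothing tokenizer.
        rand_tokens = [(str(random.randint(100, 199)), i * 5, i * 5 + 3) for i in range(20)]
        return ScriptDoNothingTokenizer(), rand_tokens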
Example 3
    def test_lookup_tokens_with_bos_eos(self):
        _, rand_tokens = self._mock_tokenizer()
        vocab = self._mock_vocab()
        vocab_lookup = VocabLookup(vocab)
        token_ids, start_idxs, end_idxs = vocab_lookup(
            rand_tokens, bos_idx=201, eos_idx=202
        )
        self.assertEqual(token_ids[0], 201)
        self.assertEqual(token_ids[-1], 202)
        for token_id, token in zip(token_ids[1:-1], rand_tokens):
            self.assertEqual(token_id, int(token[0]) - 100)
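
Concretely, with the mock behaviour assumed above, a call with two tokens would produce ids with the BOS and EOS indices prepended and appended:

    # Worked example (hypothetical token values):
    tokens = [("103", 0, 3), ("157", 4, 7)]
    token_ids, start_idxs, end_idxs = vocab_lookup(tokens, bos_idx=201, eos_idx=202)
    # token_ids -> [201, 3, 57, 202]   (bos, int("103") - 100, int("157") - 100, eos)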
Example 4
    def __init__(
        self,
        tokenizer: torch.jit.ScriptModule,
        vocab: ScriptVocabulary,
        max_seq_len: int = 100,
    ):
        super().__init__()
        self.tokenizer = tokenizer
        self.vocab = vocab
        self.vocab_lookup = VocabLookup(vocab)
        self.max_seq_len = torch.jit.Attribute(max_seq_len, int)
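
The page does not show how this module uses vocab_lookup downstream. A plausible follow-up method (assumed here, not actual PyText source; the attribute names vocab.bos_idx / vocab.eos_idx and the tokenizer's tokenize signature are also assumptions) would tokenize the raw text, truncate to max_seq_len, and delegate to the lookup module:

    def numberize(self, text: str) -> Tuple[List[int], List[int], List[int]]:
        # Hypothetical method: tokenize, cap the token count at max_seq_len,
        # then map tokens to ids through the shared VocabLookup module.
        tokens = self.tokenizer.tokenize(text)[: self.max_seq_len]
        return self.vocab_lookup(
            tokens, bos_idx=self.vocab.bos_idx, eos_idx=self.vocab.eos_idx
        )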
Example 5
    def __init__(self, tokenizer: Tokenizer, vocab: Vocabulary, max_seq_len: int):
        super().__init__()
        self.tokenizer = tokenizer
        self.vocab = ScriptVocabulary(
            list(vocab),
            pad_idx=vocab.get_pad_index(),
            # -1 is the "no BOS/EOS token" sentinel, matching Example 1
            bos_idx=vocab.get_bos_index(-1),
            eos_idx=vocab.get_eos_index(-1),
            unk_idx=vocab.get_unk_index(),
        )
        self.vocab_lookup = VocabLookup(self.vocab)
        self.max_seq_len = max_seq_len