def __init__(
    self,
    add_bos_token: bool,
    add_eos_token: bool,
    use_eos_token_for_bos: bool,
    max_seq_len: int,
    vocab: Vocabulary,
    tokenizer: Optional[Tokenizer],
):
    """Build scriptable tensorizer state from an eager-mode vocab and tokenizer."""
    super().__init__()
    # Keep a tokenizer only if it can be converted to TorchScript. A
    # NotImplementedError is tolerated: the exported tensorizer can still be
    # used in pre-tokenized mode without one.
    scripted_tokenizer = None
    if tokenizer is not None and hasattr(tokenizer, "torchscriptify"):
        try:
            scripted_tokenizer = tokenizer.torchscriptify()
        except NotImplementedError:
            scripted_tokenizer = None
    self.tokenizer = scripted_tokenizer
    self.do_nothing_tokenizer = ScriptDoNothingTokenizer()
    # -1 marks BOS/EOS as disabled when the corresponding flag is off.
    bos_idx = vocab.get_bos_index() if add_bos_token else -1
    eos_idx = vocab.get_eos_index() if add_eos_token else -1
    self.vocab = ScriptVocabulary(
        list(vocab),
        pad_idx=vocab.get_pad_index(),
        bos_idx=bos_idx,
        eos_idx=eos_idx,
    )
    self.vocab_lookup_1d = VocabLookup(self.vocab)
    self.add_bos_token = add_bos_token
    self.add_eos_token = add_eos_token
    self.use_eos_token_for_bos = use_eos_token_for_bos
    self.max_seq_len = max_seq_len
def test_lookup_tokens(self):
    """VocabLookup must map each mock token to its expected id.

    Fix: assert the output length explicitly. The original loop used
    zip(token_ids, rand_tokens), which truncates to the shorter sequence —
    the test would pass vacuously if vocab_lookup dropped tokens or
    returned an empty list.
    """
    _, rand_tokens = self._mock_tokenizer()
    vocab = self._mock_vocab()
    vocab_lookup = VocabLookup(vocab)
    token_ids, start_idxs, end_idxs = vocab_lookup(rand_tokens)
    # One id per input token — guards against silent zip() truncation.
    self.assertEqual(len(token_ids), len(rand_tokens))
    for token_id, token in zip(token_ids, rand_tokens):
        # Mock vocab ids appear to be the token text minus an offset of 100
        # (presumably defined in _mock_vocab; not visible in this chunk).
        self.assertEqual(token_id, int(token[0]) - 100)
def test_lookup_tokens_with_bos_eos(self):
    """VocabLookup with bos_idx/eos_idx must bracket the ids with 201/202.

    Fix: assert the output length explicitly. The original loop used
    zip(token_ids[1:-1], rand_tokens), which truncates to the shorter
    sequence and would pass vacuously on a length mismatch.
    """
    _, rand_tokens = self._mock_tokenizer()
    vocab = self._mock_vocab()
    vocab_lookup = VocabLookup(vocab)
    token_ids, start_idxs, end_idxs = vocab_lookup(
        rand_tokens, bos_idx=201, eos_idx=202
    )
    # BOS + EOS add exactly two ids — guards against silent zip() truncation.
    self.assertEqual(len(token_ids), len(rand_tokens) + 2)
    self.assertEqual(token_ids[0], 201)
    self.assertEqual(token_ids[-1], 202)
    for token_id, token in zip(token_ids[1:-1], rand_tokens):
        # Mock vocab ids appear to be the token text minus an offset of 100
        # (presumably defined in _mock_vocab; not visible in this chunk).
        self.assertEqual(token_id, int(token[0]) - 100)
def __init__(
    self,
    tokenizer: torch.jit.ScriptModule,
    vocab: ScriptVocabulary,
    max_seq_len: int = 100,
):
    """Wrap an already-scripted tokenizer and vocab for sequence tensorization."""
    super().__init__()
    self.tokenizer = tokenizer
    self.vocab = vocab
    self.vocab_lookup = VocabLookup(vocab)
    # torch.jit.Attribute pins the attribute's type to int for TorchScript.
    self.max_seq_len = torch.jit.Attribute(max_seq_len, int)
def __init__(self, tokenizer: Tokenizer, vocab: Vocabulary, max_seq_len: int):
    """Snapshot an eager vocab into a ScriptVocabulary; keep the tokenizer as-is."""
    super().__init__()
    self.tokenizer = tokenizer
    # -1 looks like the fallback returned when the vocab lacks a BOS/EOS entry
    # (presumably — Vocabulary.get_*_index signatures are not visible here).
    special_token_indices = {
        "pad_idx": vocab.get_pad_index(),
        "bos_idx": vocab.get_bos_index(-1),
        "eos_idx": vocab.get_eos_index(-1),
        "unk_idx": vocab.get_unk_index(),
    }
    self.vocab = ScriptVocabulary(list(vocab), **special_token_indices)
    self.vocab_lookup = VocabLookup(self.vocab)
    self.max_seq_len = max_seq_len