def _run_benchmark_pytext_script_vocab(toks, v: PytextScriptVocabulary):
    # list-of-lists lookup
    if isinstance(toks, list) and isinstance(toks[0], list):
        for tokens_list in toks:
            v.lookup_indices_1d(tokens_list)
    # single-token lookup
    elif isinstance(toks, list):
        for token in toks:
            v.lookup_indices_1d([token])
    else:
        raise RuntimeError(
            "Received tokens of incorrect type {}.".format(type(toks))
        )

class LabelTransform(nn.Module):
    def __init__(self, label_names: List[str]):
        super().__init__()
        self.label_vocab = ScriptVocabulary(sorted(label_names))

    def forward(self, labels: List[str]) -> List[int]:
        return self.label_vocab.lookup_indices_1d(labels)
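
# A minimal usage sketch, not from the original source. It assumes
# ScriptVocabulary behaves as in the VocabTest cases further below: tokens are
# indexed by their position in the supplied (here, sorted) list, OOV tokens map
# to unk_idx (0 by default), and the module survives torch.jit.script.
import torch

label_transform = LabelTransform(["positive", "neutral", "negative"])
scripted = torch.jit.script(label_transform)
print(scripted(["negative", "positive"]))  # sorted vocab -> [0, 2]
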
class VocabTransform(Transform):
    def __init__(self, vocab: Vocabulary):
        super().__init__()
        self.vocab = ScriptVocabulary(
            list(vocab),
            pad_idx=vocab.get_pad_index(-1),
            bos_idx=vocab.get_bos_index(-1),
            eos_idx=vocab.get_eos_index(-1),
            unk_idx=vocab.get_unk_index(-1),
        )

    def forward(self, tokens: Tokens) -> Dict[str, torch.Tensor]:
        token_ids: List[int] = self.vocab.lookup_indices_1d(tokens.token_texts)
        return {
            "token_ids": torch.tensor(token_ids, dtype=torch.long),
            "start_ids": torch.tensor(tokens.start_ids, dtype=torch.long),
            "end_ids": torch.tensor(tokens.end_ids, dtype=torch.long),
        }

    @property
    def is_jitable(self) -> bool:
        return True
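
# Hypothetical stand-in for the Tokens type consumed by VocabTransform.forward
# above; the field names are inferred from the attribute accesses in forward(),
# and the real definition may differ.
from typing import List, NamedTuple

class TokensStandIn(NamedTuple):
    token_texts: List[str]
    start_ids: List[int]
    end_ids: List[int]

# Assuming a compatible Vocabulary instance `vocab`:
# batch = VocabTransform(vocab).forward(
#     TokensStandIn(token_texts=["hello", "world"], start_ids=[0, 6], end_ids=[5, 11])
# )
# -> {"token_ids": ..., "start_ids": ..., "end_ids": ...}, each a 1-D LongTensor
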
class Seq2SeqJIT(torch.nn.Module):
    def __init__(
        self,
        src_dict,
        tgt_dict,
        sequence_generator,
        filter_eos_bos,
        copy_unk_token=False,
        dictfeat_dict=None,
    ):
        super().__init__()
        self.source_vocab = ScriptVocabulary(
            src_dict._vocab,
            src_dict.get_unk_index(),
            bos_idx=src_dict.get_bos_index(-1),
            eos_idx=src_dict.get_eos_index(-1),
        )
        self.target_vocab = ScriptVocabulary(
            tgt_dict._vocab,
            tgt_dict.get_unk_index(),
            bos_idx=tgt_dict.get_bos_index(),
            eos_idx=tgt_dict.get_eos_index(),
        )
        if dictfeat_dict:
            self.dictfeat_vocab = ScriptVocabulary(
                dictfeat_dict._vocab,
                # We want to use the index of the source pad token
                pad_idx=dictfeat_dict.idx[src_dict[src_dict.get_pad_index()]],
            )
        else:
            # Optional types in TorchScript are a bit of a pain, so it's
            # more convenient to use an empty vocab than None in this case.
            self.dictfeat_vocab = ScriptVocabulary([])
        self.sequence_generator = sequence_generator

        self.copy_unk_token: bool = copy_unk_token
        self.unk_idx: int = self.source_vocab.unk_idx
        self.filter_eos_bos: bool = filter_eos_bos

    def prepare_generator_inputs(
        self,
        word_ids: List[int],
        dict_feat: Optional[Tuple[List[str], List[float], List[int]]] = None,
        contextual_token_embedding: Optional[List[float]] = None,
    ) -> Tuple[
        torch.Tensor,
        Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]],
        Optional[torch.Tensor],
        torch.Tensor,
    ]:
        src_len = len(word_ids)
        dict_tensors: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None
        if dict_feat is not None:
            dict_tokens, dict_weights, dict_lengths = dict_feat
            dict_ids = self.dictfeat_vocab.lookup_indices_1d(dict_tokens)
            dict_tensors = (
                torch.tensor([dict_ids]),
                torch.tensor([dict_weights], dtype=torch.float),
                torch.tensor([dict_lengths]),
            )

        contextual_embedding_tensor: Optional[torch.Tensor] = None
        if contextual_token_embedding is not None:
            assert (
                len(contextual_token_embedding) % src_len == 0
                and len(contextual_token_embedding) > 0
            ), (
                f"Incorrect size for contextual embeddings: "
                f"{len(contextual_token_embedding)}, expected a "
                f"non-zero multiple of input token count {src_len}"
            )
            contextual_embedding_tensor = torch.tensor(
                [contextual_token_embedding], dtype=torch.float
            )

        return (
            torch.tensor(word_ids).reshape(-1, 1),
            dict_tensors,
            contextual_embedding_tensor,
            torch.tensor([src_len]),
        )

    def forward(
        self,
        src_tokens: List[str],
        dict_feat: Optional[Tuple[List[str], List[float], List[int]]] = None,
        contextual_token_embedding: Optional[List[float]] = None,
    ) -> List[Tuple[List[str], float, List[float]]]:
        word_ids = self.source_vocab.lookup_indices_1d(src_tokens)

        # Find the UNK token in the source utterance, if one exists.
        # If there are multiple, we select the first UNK token.
        single_unk_token: Optional[str] = get_single_unk_token(
            src_tokens, word_ids, self.copy_unk_token, self.unk_idx
        )

        (
            words,
            dict_tensors,
            contextual_embedding_tensor,
            src_lengths,
        ) = self.prepare_generator_inputs(
            word_ids, dict_feat, contextual_token_embedding
        )
        hypos_etc = self.sequence_generator(
            words, dict_tensors, contextual_embedding_tensor, src_lengths
        )

        hypos_list: List[Tuple[List[str], float, List[float]]] = []
        filter_token_list: List[int] = []
        if self.filter_eos_bos:
            filter_token_list = [
                self.target_vocab.bos_idx,
                self.target_vocab.eos_idx,
            ]

        for seq in hypos_etc:
            hypothesis = seq[0]
            stringified = self.target_vocab.lookup_words_1d(
                hypothesis,
                filter_token_list=filter_token_list,
                possible_unk_token=single_unk_token,
            )
            hypos_list.append((stringified, seq[1], seq[2]))
        return hypos_list
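
# Hedged sketch of the copy-unk helper used in forward() above, inferred from
# the in-code comment ("if there are multiple, we select the first UNK token");
# the real pytext get_single_unk_token may differ.
from typing import List, Optional

def get_single_unk_token_sketch(
    src_tokens: List[str], word_ids: List[int], copy_unk_token: bool, unk_idx: int
) -> Optional[str]:
    if copy_unk_token:
        for token, word_id in zip(src_tokens, word_ids):
            if word_id == unk_idx:
                # First source token that mapped to UNK: lookup_words_1d can
                # emit it verbatim instead of the UNK placeholder string.
                return token
    return None
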
def _run_benchmark_lists_pytext_script_vocab(
    tok_lists: List[List[str]], v: PytextScriptVocabulary
):
    for tokens_list in tok_lists:
        v.lookup_indices_1d(tokens_list)

def _run_benchmark_pytext_script_vocab(toks: List[str], v: PytextScriptVocabulary):
    for token in toks:
        v.lookup_indices_1d([token])
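
# Hedged driver sketch for the benchmark helpers above. The vocab construction
# and token source are assumptions, not from the original benchmark script
# (PytextScriptVocabulary is presumably pytext's ScriptVocabulary under an alias).
import time

def _time_it(name, fn, *args):
    start = time.monotonic()
    fn(*args)
    print(f"{name}: {time.monotonic() - start:.3f}s")

v = PytextScriptVocabulary(["UNK", "a", "b", "c", "d"])  # assumed constructor
toks = ["a", "b", "e"] * 10_000
_time_it("single-token lookups", _run_benchmark_pytext_script_vocab, toks, v)
_time_it("list lookups", _run_benchmark_lists_pytext_script_vocab, [toks] * 10, v)
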
class VocabTest(unittest.TestCase):
    def setUp(self):
        vocab_list = ["UNK", "a", "b", "c", "d"]
        self.vocab = ScriptVocabulary(vocab_list)

    def test_vocab_lookup(self):
        # There are bugs with just making this a script; eventually these can be simpler.
        class LookupWord(jit.ScriptModule):
            def __init__(self, vocab):
                super().__init__()
                self.vocab = vocab

            @jit.script_method
            def forward(self, word: str):
                return self.vocab.idx[word]

        lookup_word = LookupWord(self.vocab)
        self.assertEqual(1, lookup_word("a"))
        self.assertEqual(3, lookup_word("c"))
        with self.assertRaises(Exception):
            lookup_word("notaword")

    def test_vocab_idx_lookup(self):
        # There are bugs with just making this a script; eventually these can be simpler.
        class LookupIndex(jit.ScriptModule):
            def __init__(self, vocab):
                super().__init__()
                self.vocab = vocab

            @jit.script_method
            def forward(self, i: int):
                return self.vocab.vocab[i]

        lookup_idx = LookupIndex(self.vocab)
        self.assertEqual("UNK", lookup_idx(0))
        self.assertEqual("b", lookup_idx(2))
        with self.assertRaises(Exception):
            lookup_idx(20)

    def test_lookup_1d(self):
        self.assertEqual(
            [1, 0, 3, 4], self.vocab.lookup_indices_1d(["a", "e", "c", "d"])
        )
        self.assertEqual([], self.vocab.lookup_indices_1d([]))

    def test_lookup_2d(self):
        self.assertEqual(
            [[1, 0, 3, 4], [], [2]],
            self.vocab.lookup_indices_2d([["a", "e", "c", "d"], [], ["b"]]),
        )
        self.assertEqual([], self.vocab.lookup_indices_2d([]))

    def test_custom_unk(self):
        vocab_list = ["a", "UNK", "b", "c", "d"]
        vocab = ScriptVocabulary(vocab_list, unk_idx=1)
        self.assertEqual(
            [0, 1, 3, 4], vocab.lookup_indices_1d(["a", "e", "c", "d"])
        )

    def test_lookup_words_1d_cycle_heuristic(self):
        self.assertEqual(
            self.vocab.lookup_words_1d_cycle_heuristic(
                torch.tensor([1, 0, 0]), [], ["y", "z"]
            ),
            ["a", "y", "z"],
        )
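
# Hedged pure-Python sketch of the cycle heuristic exercised by the last test
# above: each UNK position is replaced by the next entry of `unk_replacements`,
# cycling through them. This is an illustration, not pytext's implementation.
def cycle_heuristic_sketch(ids, vocab_list, unk_idx, unk_replacements):
    out, next_unk = [], 0
    for i in ids:
        if i == unk_idx and unk_replacements:
            out.append(unk_replacements[next_unk % len(unk_replacements)])
            next_unk += 1
        else:
            out.append(vocab_list[i])
    return out

# Mirrors test_lookup_words_1d_cycle_heuristic: ids [1, 0, 0] with
# replacements ["y", "z"] -> ["a", "y", "z"].
assert cycle_heuristic_sketch(
    [1, 0, 0], ["UNK", "a", "b", "c", "d"], 0, ["y", "z"]
) == ["a", "y", "z"]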