class Model(jit.ScriptModule):
    def __init__(self):
        super().__init__()
        self.vocab = ScriptVocabulary(
            input_vocab,
            input_vocab.get_unk_index(),
            input_vocab.get_pad_index(),
        )
        self.model = traced_model
        self.output_layer = output_layer
        self.pad_idx = jit.Attribute(input_vocab.get_pad_index(), int)
        self.max_seq_len = jit.Attribute(max_seq_len, int)

    @jit.script_method
    def forward(
        self,
        texts: Optional[List[str]] = None,
        multi_texts: Optional[List[List[str]]] = None,
        tokens: Optional[List[List[str]]] = None,
        languages: Optional[List[str]] = None,
    ):
        if tokens is None:
            raise RuntimeError("tokens is required")

        tokens = truncate_tokens(tokens, self.max_seq_len, self.vocab.pad_token)
        seq_lens = make_sequence_lengths(tokens)
        word_ids = self.vocab.lookup_indices_2d(tokens)
        word_ids = pad_2d(word_ids, seq_lens, self.pad_idx)
        logits = self.model(torch.tensor(word_ids), torch.tensor(seq_lens))
        return self.output_layer(logits)
class VocabTransform(ScriptTransform):
    def __init__(
        self, vocab_path: Optional[str] = None, vocab_list: Optional[List[str]] = None
    ):
        super().__init__()
        assert vocab_path or vocab_list, "vocab_path or vocab_list is required"
        assert not (
            vocab_path and vocab_list
        ), "vocab_path and vocab_list are mutually exclusive"

        if vocab_list:
            self.vocab = ScriptVocabulary(vocab_list)
        else:
            with PathManager.open(vocab_path) as f:
                special_token_replacements = {
                    "[UNK]": UNK,
                    "[PAD]": PAD,
                    "[CLS]": BOS,
                    "[MASK]": MASK,
                    "[SEP]": EOS,
                }
                vocab = build_fairseq_vocab(
                    f, special_token_replacements=special_token_replacements
                )
                self.vocab = ScriptVocabulary(
                    list(vocab),
                    pad_idx=vocab.get_pad_index(-1),
                    bos_idx=vocab.get_bos_index(-1),
                    eos_idx=vocab.get_eos_index(-1),
                    unk_idx=vocab.get_unk_index(-1),
                )

    def forward(self, tokens: List[List[str]]) -> List[List[int]]:
        return self.vocab.lookup_indices_2d(tokens)
class ModelWithDenseFeat(jit.ScriptModule): def __init__(self): super().__init__() self.vocab = ScriptVocabulary(input_vocab, unk_idx=input_vocab.idx[UNK]) self.normalizer = tensorizers["dense"].normalizer self.max_byte_len = jit.Attribute(max_byte_len, int) self.byte_offset_for_non_padding = jit.Attribute( byte_offset_for_non_padding, int ) self.pad_idx = jit.Attribute(input_vocab.idx[PAD], int) self.model = traced_model self.output_layer = output_layer @jit.script_method def forward(self, tokens: List[List[str]], dense_feat: List[List[float]]): seq_lens = make_sequence_lengths(tokens) word_ids = self.vocab.lookup_indices_2d(tokens) word_ids = pad_2d(word_ids, seq_lens, self.pad_idx) token_bytes, _ = make_byte_inputs( tokens, self.max_byte_len, self.byte_offset_for_non_padding ) dense_feat = self.normalizer.normalize(dense_feat) logits = self.model( torch.tensor(word_ids), token_bytes, torch.tensor(seq_lens), torch.tensor(dense_feat, dtype=torch.float), ) return self.output_layer(logits)
class ModelWithDenseFeat(jit.ScriptModule): def __init__(self): super().__init__() self.vocab = ScriptVocabulary(input_vocab, unk_idx=input_vocab.idx[UNK]) self.normalizer = tensorizers["dense"].normalizer self.model = traced_model self.output_layer = output_layer self.pad_idx = jit.Attribute(input_vocab.idx[PAD], int) @jit.script_method def forward( self, texts: Optional[List[str]] = None, tokens: Optional[List[List[str]]] = None, languages: Optional[List[str]] = None, dense_feat: Optional[List[List[float]]] = None, ): if tokens is None: raise RuntimeError("tokens is required") seq_lens = make_sequence_lengths(tokens) word_ids = self.vocab.lookup_indices_2d(tokens) word_ids = pad_2d(word_ids, seq_lens, self.pad_idx) if dense_feat is not None: dense_feat = self.normalizer.normalize(dense_feat) else: raise RuntimeError("dense is required") logits = self.model( torch.tensor(word_ids), torch.tensor(seq_lens), torch.tensor(dense_feat, dtype=torch.float), ) return self.output_layer(logits)
class Model(jit.ScriptModule):
    def __init__(self):
        super().__init__()
        self.vocab = ScriptVocabulary(input_vocab, unk_idx=input_vocab.idx[UNK])
        self.max_byte_len = jit.Attribute(max_byte_len, int)
        self.byte_offset_for_non_padding = jit.Attribute(
            byte_offset_for_non_padding, int
        )
        self.pad_idx = jit.Attribute(input_vocab.idx[PAD], int)
        self.model = traced_model
        self.output_layer = output_layer

    @jit.script_method
    def forward(
        self,
        texts: Optional[List[str]] = None,
        tokens: Optional[List[List[str]]] = None,
        languages: Optional[List[str]] = None,
    ):
        if tokens is None:
            raise RuntimeError("tokens is required")

        seq_lens = make_sequence_lengths(tokens)
        word_ids = self.vocab.lookup_indices_2d(tokens)
        word_ids = pad_2d(word_ids, seq_lens, self.pad_idx)
        token_bytes, _ = make_byte_inputs(
            tokens, self.max_byte_len, self.byte_offset_for_non_padding
        )
        logits = self.model(
            torch.tensor(word_ids), token_bytes, torch.tensor(seq_lens)
        )
        return self.output_layer(logits)
class Model(jit.ScriptModule):
    def __init__(self):
        super().__init__()
        self.vocab = ScriptVocabulary(input_vocab, unk_idx=input_vocab.idx[UNK])
        self.model = traced_model
        self.output_layer = output_layer
        self.pad_idx = jit.Attribute(input_vocab.idx[PAD], int)
        self.max_seq_len = jit.Attribute(max_seq_len, int)

    @jit.script_method
    def forward(
        self,
        texts: Optional[List[str]] = None,
        tokens: Optional[List[List[str]]] = None,
        languages: Optional[List[str]] = None,
    ):
        if tokens is None:
            raise RuntimeError("tokens is required")

        trimmed_tokens: List[List[str]] = []
        if self.max_seq_len >= 0:
            for token in tokens:
                trimmed_tokens.append(token[0 : self.max_seq_len])
        else:
            trimmed_tokens = tokens

        seq_lens = make_sequence_lengths(trimmed_tokens)
        word_ids = self.vocab.lookup_indices_2d(trimmed_tokens)
        word_ids = pad_2d(word_ids, seq_lens, self.pad_idx)
        logits = self.model(torch.tensor(word_ids), torch.tensor(seq_lens))
        return self.output_layer(logits)
class VocabTransform(nn.Module):
    def __init__(
        self,
        vocab_path: Optional[str] = None,
        vocab_list: Optional[List[str]] = None,
        special_token_replacements=SPECIAL_TOKEN_REPLACEMENT,
    ):
        super().__init__()
        assert vocab_path or vocab_list, "vocab_path or vocab_list is required"
        assert not (
            vocab_path and vocab_list
        ), "vocab_path and vocab_list are mutually exclusive"

        if vocab_list:
            self.vocab = ScriptVocabulary(vocab_list)
        else:
            with PathManager.open(vocab_path) as f:
                vocab = build_fairseq_vocab(
                    f, special_token_replacements=special_token_replacements
                )
                self.vocab = ScriptVocabulary(
                    list(vocab),
                    pad_idx=vocab.get_pad_index(-1),
                    bos_idx=vocab.get_bos_index(-1),
                    eos_idx=vocab.get_eos_index(-1),
                    unk_idx=vocab.get_unk_index(-1),
                    unk_token=vocab.unk_token,
                )

    def forward(self, tokens: List[List[str]]) -> List[List[int]]:
        return self.vocab.lookup_indices_2d(tokens)
class ModelWithDenseFeat(jit.ScriptModule): def __init__(self): super().__init__() self.vocab = ScriptVocabulary( input_vocab, input_vocab.get_unk_index(), input_vocab.get_pad_index(), ) self.normalizer = tensorizers["dense"].normalizer self.model = traced_model self.output_layer = output_layer self.pad_idx = jit.Attribute(input_vocab.get_pad_index(), int) self.max_seq_len = jit.Attribute(max_seq_len, int) self.tokenizer = scripted_tokenizer @jit.script_method def forward( self, texts: Optional[List[str]] = None, multi_texts: Optional[List[List[str]]] = None, tokens: Optional[List[List[str]]] = None, languages: Optional[List[str]] = None, dense_feat: Optional[List[List[float]]] = None, ): # PyTorch breaks with 2 'not None' checks right now. if texts is not None: if tokens is not None: raise RuntimeError("Can't set both tokens and texts") if self.tokenizer is not None: tokens = [ [t[0] for t in self.tokenizer.tokenize(text)] for text in texts ] if tokens is None: raise RuntimeError("tokens is required") if dense_feat is None: raise RuntimeError("dense_feat is required") tokens = truncate_tokens(tokens, self.max_seq_len, self.vocab.pad_token) seq_lens = make_sequence_lengths(tokens) word_ids = self.vocab.lookup_indices_2d(tokens) word_ids = pad_2d(word_ids, seq_lens, self.pad_idx) dense_feat = self.normalizer.normalize(dense_feat) logits = self.model( torch.tensor(word_ids), torch.tensor(seq_lens), torch.tensor(dense_feat, dtype=torch.float), ) return self.output_layer(logits)
class Model(jit.ScriptModule):
    def __init__(self):
        super().__init__()
        self.vocab = ScriptVocabulary(input_vocab, unk_idx=input_vocab.idx[UNK])
        self.model = traced_model
        self.output_layer = output_layer
        self.pad_idx = jit.Attribute(input_vocab.idx[PAD], int)

    @jit.script_method
    def forward(self, tokens: List[List[str]]):
        seq_lens = make_sequence_lengths(tokens)
        word_ids = self.vocab.lookup_indices_2d(tokens)
        word_ids = pad_2d(word_ids, seq_lens, self.pad_idx)
        logits = self.model(torch.tensor(word_ids), torch.tensor(seq_lens))
        return self.output_layer(logits)
class ModelWithDenseFeat(jit.ScriptModule): def __init__(self): super().__init__() self.vocab = ScriptVocabulary( input_vocab, input_vocab.get_unk_index(), input_vocab.get_pad_index(), ) self.normalizer = tensorizers["dense"].normalizer self.max_seq_len = jit.Attribute(max_seq_len, int) self.max_byte_len = jit.Attribute(max_byte_len, int) self.byte_offset_for_non_padding = jit.Attribute( byte_offset_for_non_padding, int ) self.pad_idx = jit.Attribute(input_vocab.get_pad_index(), int) self.model = traced_model self.output_layer = output_layer @jit.script_method def forward( self, texts: Optional[List[str]] = None, multi_texts: Optional[List[List[str]]] = None, tokens: Optional[List[List[str]]] = None, languages: Optional[List[str]] = None, dense_feat: Optional[List[List[float]]] = None, ): if tokens is None: raise RuntimeError("tokens is required") if dense_feat is None: raise RuntimeError("dense_feat is required") tokens = truncate_tokens(tokens, self.max_seq_len, self.vocab.pad_token) seq_lens = make_sequence_lengths(tokens) word_ids = self.vocab.lookup_indices_2d(tokens) word_ids = pad_2d(word_ids, seq_lens, self.pad_idx) token_bytes, _ = make_byte_inputs( tokens, self.max_byte_len, self.byte_offset_for_non_padding ) dense_feat = self.normalizer.normalize(dense_feat) logits = self.model( torch.tensor(word_ids), token_bytes, torch.tensor(seq_lens), torch.tensor(dense_feat, dtype=torch.float), ) return self.output_layer(logits)
class VocabTransform(nn.Module):
    def __init__(
        self,
        vocab_path: Optional[str] = None,
        vocab_list: Optional[List[str]] = None,
        special_token_replacements=SPECIAL_TOKEN_REPLACEMENT,
        add_bos: bool = False,
        add_eos: bool = False,
        max_seq_len: int = 2 ** 30,
    ):
        super().__init__()
        assert vocab_path or vocab_list, "vocab_path or vocab_list is required"
        assert not (
            vocab_path and vocab_list
        ), "vocab_path and vocab_list are mutually exclusive"

        if vocab_list:
            self.vocab = ScriptVocabulary(vocab_list)
        else:
            with PathManager.open(vocab_path) as f:
                vocab = build_fairseq_vocab(
                    f, special_token_replacements=special_token_replacements
                )
                self.vocab = ScriptVocabulary(
                    list(vocab),
                    pad_idx=vocab.get_pad_index(-1),
                    bos_idx=vocab.get_bos_index(-1),
                    eos_idx=vocab.get_eos_index(-1),
                    unk_idx=vocab.get_unk_index(-1),
                    unk_token=vocab.unk_token,
                )

        # TODO T77728853 We need to combine truncation with BOS/EOS handling, as
        # they affect each other. Need a nicer way to do this; it can't be chained.
        self.add_bos = add_bos
        self.add_eos = add_eos
        # Make room for BOS and EOS within max_seq_len when they are enabled.
        self.truncate_transform = TruncateTransform(max_seq_len - add_bos - add_eos)

    def forward(self, tokens: List[List[str]]) -> List[List[int]]:
        tokens_idx = self.vocab.lookup_indices_2d(tokens)
        tokens_idx = self.truncate_transform(tokens_idx)
        if self.add_bos:
            tokens_idx = [[self.vocab.bos_idx] + row for row in tokens_idx]
        if self.add_eos:
            tokens_idx = [row + [self.vocab.eos_idx] for row in tokens_idx]
        return tokens_idx
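A minimal usage sketch for the vocab_list path of the transform above. The vocabulary contents are hypothetical; it assumes ScriptVocabulary built from a plain list maps out-of-vocabulary tokens to index 0 (consistent with the vocab tests below) and that TruncateTransform caps each row at the configured length.

# Hypothetical vocabulary; index 0 serves as the UNK slot by default.
transform = VocabTransform(
    vocab_list=["__UNKNOWN__", "the", "cat", "sat", "down"],
    max_seq_len=3,
)
# "on" is out of vocabulary, so it maps to index 0; the first row is
# truncated from four indices down to max_seq_len = 3.
token_ids = transform([["the", "cat", "sat", "down"], ["cat", "on", "the"]])
# token_ids == [[1, 2, 3], [2, 0, 1]]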
class VocabTest(unittest.TestCase):
    def setUp(self):
        vocab_list = ["UNK", "a", "b", "c", "d"]
        self.vocab = ScriptVocabulary(vocab_list)

    def test_vocab_lookup(self):
        # There are bugs with just making this a script, eventually these can be simpler
        class LookupWord(jit.ScriptModule):
            def __init__(self, vocab):
                super().__init__()
                self.vocab = vocab

            @jit.script_method
            def forward(self, word: str):
                return self.vocab.idx[word]

        lookup_word = LookupWord(self.vocab)

        self.assertEqual(1, lookup_word("a"))
        self.assertEqual(3, lookup_word("c"))
        with self.assertRaises(Exception):
            lookup_word("notaword")

    def test_vocab_idx_lookup(self):
        # There are bugs with just making this a script, eventually these can be simpler
        class LookupIndex(jit.ScriptModule):
            def __init__(self, vocab):
                super().__init__()
                self.vocab = vocab

            @jit.script_method
            def forward(self, i: int):
                return self.vocab.vocab[i]

        lookup_idx = LookupIndex(self.vocab)

        self.assertEqual("UNK", lookup_idx(0))
        self.assertEqual("b", lookup_idx(2))
        with self.assertRaises(Exception):
            lookup_idx(20)

    def test_lookup_1d(self):
        self.assertEqual(
            [1, 0, 3, 4], self.vocab.lookup_indices_1d(["a", "e", "c", "d"])
        )
        self.assertEqual([], self.vocab.lookup_indices_1d([]))

    def test_lookup_2d(self):
        self.assertEqual(
            [[1, 0, 3, 4], [], [2]],
            self.vocab.lookup_indices_2d([["a", "e", "c", "d"], [], ["b"]]),
        )
        self.assertEqual([], self.vocab.lookup_indices_2d([]))

    def test_custom_unk(self):
        vocab_list = ["a", "UNK", "b", "c", "d"]
        vocab = ScriptVocabulary(vocab_list, unk_idx=1)
        self.assertEqual([0, 1, 3, 4], vocab.lookup_indices_1d(["a", "e", "c", "d"]))

    def test_lookup_words_1d_cycle_heuristic(self):
        self.assertEqual(
            self.vocab.lookup_words_1d_cycle_heuristic(
                torch.tensor([1, 0, 0]), [], ["y", "z"]
            ),
            ["a", "y", "z"],
        )