Example #1
class Model(jit.ScriptModule):
    def __init__(self):
        super().__init__()
        self.vocab = ScriptVocabulary(
            input_vocab,
            input_vocab.get_unk_index(),
            input_vocab.get_pad_index(),
        )
        self.model = traced_model
        self.output_layer = output_layer
        self.pad_idx = jit.Attribute(input_vocab.get_pad_index(), int)
        self.max_seq_len = jit.Attribute(max_seq_len, int)

    @jit.script_method
    def forward(
        self,
        texts: Optional[List[str]] = None,
        multi_texts: Optional[List[List[str]]] = None,
        tokens: Optional[List[List[str]]] = None,
        languages: Optional[List[str]] = None,
    ):
        if tokens is None:
            raise RuntimeError("tokens is required")

        # Truncate to max_seq_len, compute per-example lengths, map tokens
        # to vocabulary ids, then pad every row to the batch's max length.
        tokens = truncate_tokens(tokens, self.max_seq_len, self.vocab.pad_token)
        seq_lens = make_sequence_lengths(tokens)
        word_ids = self.vocab.lookup_indices_2d(tokens)
        word_ids = pad_2d(word_ids, seq_lens, self.pad_idx)
        logits = self.model(torch.tensor(word_ids), torch.tensor(seq_lens))
        return self.output_layer(logits)
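
The forward pass above leans on a handful of scripted helpers (truncate_tokens, make_sequence_lengths, pad_2d) from PyText's TorchScript utilities. For intuition, here is a minimal pure-Python sketch of what equivalents of those helpers would do; the bodies are illustrative assumptions, not the library's actual implementations:

from typing import List

def truncate_tokens_sketch(tokens: List[List[str]], max_len: int, pad: str) -> List[List[str]]:
    # Assumed behavior: clip each row to at most max_len tokens.
    # `pad` is unused here; it only mirrors the real helper's signature.
    return [row[:max_len] for row in tokens]

def make_sequence_lengths_sketch(tokens: List[List[str]]) -> List[int]:
    # One length per example in the batch.
    return [len(row) for row in tokens]

def pad_2d_sketch(ids: List[List[int]], seq_lens: List[int], pad_idx: int) -> List[List[int]]:
    # Right-pad every row with pad_idx up to the longest row in the batch.
    max_len = max(seq_lens) if seq_lens else 0
    return [row + [pad_idx] * (max_len - len(row)) for row in ids]

batch = [["hello", "world"], ["hi"]]
lens = make_sequence_lengths_sketch(batch)       # [2, 1]
padded = pad_2d_sketch([[4, 7], [5]], lens, 0)   # [[4, 7], [5, 0]]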
Example #2
class VocabTransform(ScriptTransform):
    def __init__(
        self, vocab_path: Optional[str] = None, vocab_list: Optional[List[str]] = None
    ):
        super().__init__()
        assert vocab_path or vocab_list, "vocab_path or vocab_list is required"
        assert not (
            vocab_path and vocab_list
        ), "vocab_path and vocab_list are mutual exclusive"

        if vocab_list:
            self.vocab = ScriptVocabulary(vocab_list)
        else:
            with PathManager.open(vocab_path) as f:
                special_token_replacements = {
                    "[UNK]": UNK,
                    "[PAD]": PAD,
                    "[CLS]": BOS,
                    "[MASK]": MASK,
                    "[SEP]": EOS,
                }
                vocab = build_fairseq_vocab(
                    f, special_token_replacements=special_token_replacements
                )
                self.vocab = ScriptVocabulary(
                    list(vocab),
                    pad_idx=vocab.get_pad_index(-1),
                    bos_idx=vocab.get_bos_index(-1),
                    eos_idx=vocab.get_eos_index(-1),
                    unk_idx=vocab.get_unk_index(-1),
                )

    def forward(self, tokens: List[List[str]]) -> List[List[int]]:
        return self.vocab.lookup_indices_2d(tokens)
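
A quick usage sketch, assuming ScriptTransform is callable like a standard nn.Module; the vocab entries and tokens below are made up, and per the ScriptVocabulary defaults exercised in the tests of Example #12, out-of-vocabulary tokens map to index 0:

# Hypothetical usage: build the transform from an in-memory vocab list.
transform = VocabTransform(vocab_list=["[UNK]", "[PAD]", "the", "cat"])
indices = transform([["the", "cat"], ["dog"]])
# Out-of-vocabulary "dog" maps to the UNK index (0 by default).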
Example #3
class ModelWithDenseFeat(jit.ScriptModule):
    def __init__(self):
        super().__init__()
        self.vocab = ScriptVocabulary(input_vocab, unk_idx=input_vocab.idx[UNK])
        self.normalizer = tensorizers["dense"].normalizer
        self.max_byte_len = jit.Attribute(max_byte_len, int)
        self.byte_offset_for_non_padding = jit.Attribute(
            byte_offset_for_non_padding, int
        )
        self.pad_idx = jit.Attribute(input_vocab.idx[PAD], int)
        self.model = traced_model
        self.output_layer = output_layer

    @jit.script_method
    def forward(self, tokens: List[List[str]], dense_feat: List[List[float]]):
        seq_lens = make_sequence_lengths(tokens)
        word_ids = self.vocab.lookup_indices_2d(tokens)
        word_ids = pad_2d(word_ids, seq_lens, self.pad_idx)
        token_bytes, _ = make_byte_inputs(
            tokens, self.max_byte_len, self.byte_offset_for_non_padding
        )
        dense_feat = self.normalizer.normalize(dense_feat)
        logits = self.model(
            torch.tensor(word_ids),
            token_bytes,
            torch.tensor(seq_lens),
            torch.tensor(dense_feat, dtype=torch.float),
        )
        return self.output_layer(logits)
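
make_byte_inputs turns each token into a fixed-width row of byte values. A plausible pure-Python sketch, assuming byte values are shifted by byte_offset_for_non_padding so that 0 stays reserved for padding; the real helper returns tensors and may differ in details:

from typing import List, Tuple

def make_byte_inputs_sketch(
    tokens: List[List[str]], max_byte_len: int, offset: int
) -> Tuple[List[List[List[int]]], List[int]]:
    # Assumed semantics: each token becomes a fixed-width row of UTF-8
    # byte values shifted by `offset`, so 0 stays reserved for padding.
    batch = []
    for row in tokens:
        encoded_row = []
        for tok in row:
            b = [c + offset for c in tok.encode("utf-8")[:max_byte_len]]
            encoded_row.append(b + [0] * (max_byte_len - len(b)))
        batch.append(encoded_row)
    return batch, [len(row) for row in tokens]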
Example #4
class ModelWithDenseFeat(jit.ScriptModule):
    def __init__(self):
        super().__init__()
        self.vocab = ScriptVocabulary(input_vocab,
                                      unk_idx=input_vocab.idx[UNK])
        self.normalizer = tensorizers["dense"].normalizer
        self.model = traced_model
        self.output_layer = output_layer
        self.pad_idx = jit.Attribute(input_vocab.idx[PAD], int)

    @jit.script_method
    def forward(
        self,
        texts: Optional[List[str]] = None,
        tokens: Optional[List[List[str]]] = None,
        languages: Optional[List[str]] = None,
        dense_feat: Optional[List[List[float]]] = None,
    ):
        if tokens is None:
            raise RuntimeError("tokens is required")

        seq_lens = make_sequence_lengths(tokens)
        word_ids = self.vocab.lookup_indices_2d(tokens)
        word_ids = pad_2d(word_ids, seq_lens, self.pad_idx)
        if dense_feat is not None:
            dense_feat = self.normalizer.normalize(dense_feat)
        else:
            raise RuntimeError("dense_feat is required")
        logits = self.model(
            torch.tensor(word_ids),
            torch.tensor(seq_lens),
            torch.tensor(dense_feat, dtype=torch.float),
        )
        return self.output_layer(logits)
Example #5
class Model(jit.ScriptModule):
    def __init__(self):
        super().__init__()
        self.vocab = ScriptVocabulary(input_vocab,
                                      unk_idx=input_vocab.idx[UNK])
        self.max_byte_len = jit.Attribute(max_byte_len, int)
        self.byte_offset_for_non_padding = jit.Attribute(
            byte_offset_for_non_padding, int)
        self.pad_idx = jit.Attribute(input_vocab.idx[PAD], int)
        self.model = traced_model
        self.output_layer = output_layer

    @jit.script_method
    def forward(
        self,
        texts: Optional[List[str]] = None,
        tokens: Optional[List[List[str]]] = None,
        languages: Optional[List[str]] = None,
    ):
        if tokens is None:
            raise RuntimeError("tokens is required")
        seq_lens = make_sequence_lengths(tokens)
        word_ids = self.vocab.lookup_indices_2d(tokens)
        word_ids = pad_2d(word_ids, seq_lens, self.pad_idx)
        token_bytes, _ = make_byte_inputs(
            tokens, self.max_byte_len,
            self.byte_offset_for_non_padding)
        logits = self.model(torch.tensor(word_ids), token_bytes,
                            torch.tensor(seq_lens))
        return self.output_layer(logits)
Example #6
class Model(jit.ScriptModule):
    def __init__(self):
        super().__init__()
        self.vocab = ScriptVocabulary(input_vocab,
                                      unk_idx=input_vocab.idx[UNK])
        self.model = traced_model
        self.output_layer = output_layer
        self.pad_idx = jit.Attribute(input_vocab.idx[PAD], int)
        self.max_seq_len = jit.Attribute(max_seq_len, int)

    @jit.script_method
    def forward(
        self,
        texts: Optional[List[str]] = None,
        tokens: Optional[List[List[str]]] = None,
        languages: Optional[List[str]] = None,
    ):
        if tokens is None:
            raise RuntimeError("tokens is required")

        # A negative max_seq_len disables truncation.
        trimmed_tokens: List[List[str]] = []
        if self.max_seq_len >= 0:
            for token in tokens:
                trimmed_tokens.append(token[0:self.max_seq_len])
        else:
            trimmed_tokens = tokens

        seq_lens = make_sequence_lengths(trimmed_tokens)
        word_ids = self.vocab.lookup_indices_2d(trimmed_tokens)
        word_ids = pad_2d(word_ids, seq_lens, self.pad_idx)
        logits = self.model(torch.tensor(word_ids),
                            torch.tensor(seq_lens))
        return self.output_layer(logits)
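
The inline trimming above uses a negative max_seq_len as a "no limit" sentinel, since slicing with a negative stop would silently drop tokens from the end of each row. The same logic restated as a standalone function:

from typing import List

def trim(tokens: List[List[str]], max_seq_len: int) -> List[List[str]]:
    # max_seq_len < 0 means "do not truncate"; a negative slice stop
    # would otherwise remove tokens from the end of each row.
    if max_seq_len >= 0:
        return [row[:max_seq_len] for row in tokens]
    return tokens

assert trim([["a", "b", "c"]], 2) == [["a", "b"]]
assert trim([["a", "b", "c"]], -1) == [["a", "b", "c"]]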
Example #7
class VocabTransform(nn.Module):
    def __init__(
        self,
        vocab_path: Optional[str] = None,
        vocab_list: Optional[List[str]] = None,
        special_token_replacements=SPECIAL_TOKEN_REPLACEMENT,
    ):
        super().__init__()
        assert vocab_path or vocab_list, "vocab_path or vocab_list is required"
        assert not (
            vocab_path and vocab_list
        ), "vocab_path and vocab_list are mutual exclusive"

        if vocab_list:
            self.vocab = ScriptVocabulary(vocab_list)
        else:
            with PathManager.open(vocab_path) as f:
                vocab = build_fairseq_vocab(
                    f, special_token_replacements=special_token_replacements
                )
                self.vocab = ScriptVocabulary(
                    list(vocab),
                    pad_idx=vocab.get_pad_index(-1),
                    bos_idx=vocab.get_bos_index(-1),
                    eos_idx=vocab.get_eos_index(-1),
                    unk_idx=vocab.get_unk_index(-1),
                    unk_token=vocab.unk_token,
                )

    def forward(self, tokens: List[List[str]]) -> List[List[int]]:
        return self.vocab.lookup_indices_2d(tokens)
Example #8
class ModelWithDenseFeat(jit.ScriptModule):
    def __init__(self):
        super().__init__()
        self.vocab = ScriptVocabulary(
            input_vocab,
            input_vocab.get_unk_index(),
            input_vocab.get_pad_index(),
        )
        self.normalizer = tensorizers["dense"].normalizer
        self.model = traced_model
        self.output_layer = output_layer
        self.pad_idx = jit.Attribute(input_vocab.get_pad_index(), int)
        self.max_seq_len = jit.Attribute(max_seq_len, int)
        self.tokenizer = scripted_tokenizer

    @jit.script_method
    def forward(
        self,
        texts: Optional[List[str]] = None,
        multi_texts: Optional[List[List[str]]] = None,
        tokens: Optional[List[List[str]]] = None,
        languages: Optional[List[str]] = None,
        dense_feat: Optional[List[List[float]]] = None,
    ):
        # PyTorch breaks with 2 'not None' checks right now.
        if texts is not None:
            if tokens is not None:
                raise RuntimeError("Can't set both tokens and texts")
            if self.tokenizer is not None:
                tokens = [
                    [t[0] for t in self.tokenizer.tokenize(text)]
                    for text in texts
                ]

        if tokens is None:
            raise RuntimeError("tokens is required")
        if dense_feat is None:
            raise RuntimeError("dense_feat is required")

        tokens = truncate_tokens(tokens, self.max_seq_len, self.vocab.pad_token)
        seq_lens = make_sequence_lengths(tokens)
        word_ids = self.vocab.lookup_indices_2d(tokens)
        word_ids = pad_2d(word_ids, seq_lens, self.pad_idx)
        dense_feat = self.normalizer.normalize(dense_feat)
        logits = self.model(
            torch.tensor(word_ids),
            torch.tensor(seq_lens),
            torch.tensor(dense_feat, dtype=torch.float),
        )
        return self.output_layer(logits)
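
A call sketch for the text path, assuming the surrounding scope provides everything __init__ needs and that a scripted tokenizer is attached; the input values are made up:

# Hypothetical invocation: raw texts are tokenized internally, so only
# texts and dense_feat need to be supplied.
m = ModelWithDenseFeat()
scores = m(
    texts=["hello world", "good morning"],
    dense_feat=[[0.1, 0.2], [0.3, 0.4]],
)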
Example #9
class Model(jit.ScriptModule):
    def __init__(self):
        super().__init__()
        self.vocab = ScriptVocabulary(input_vocab, unk_idx=input_vocab.idx[UNK])
        self.model = traced_model
        self.output_layer = output_layer
        self.pad_idx = jit.Attribute(input_vocab.idx[PAD], int)

    @jit.script_method
    def forward(self, tokens: List[List[str]]):
        seq_lens = make_sequence_lengths(tokens)
        word_ids = self.vocab.lookup_indices_2d(tokens)
        word_ids = pad_2d(word_ids, seq_lens, self.pad_idx)
        logits = self.model(torch.tensor(word_ids), torch.tensor(seq_lens))
        return self.output_layer(logits)
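
Because Model subclasses jit.ScriptModule, the instance is already TorchScript-compiled, so it can be serialized with the standard TorchScript API; a minimal sketch (the file name is illustrative):

import torch

m = Model()
torch.jit.save(m, "model.pt")        # serialize the compiled module
loaded = torch.jit.load("model.pt")  # reload without the Python class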
Example #10
class ModelWithDenseFeat(jit.ScriptModule):
    def __init__(self):
        super().__init__()
        self.vocab = ScriptVocabulary(
            input_vocab,
            input_vocab.get_unk_index(),
            input_vocab.get_pad_index(),
        )
        self.normalizer = tensorizers["dense"].normalizer
        self.max_seq_len = jit.Attribute(max_seq_len, int)
        self.max_byte_len = jit.Attribute(max_byte_len, int)
        self.byte_offset_for_non_padding = jit.Attribute(
            byte_offset_for_non_padding, int
        )
        self.pad_idx = jit.Attribute(input_vocab.get_pad_index(), int)
        self.model = traced_model
        self.output_layer = output_layer

    @jit.script_method
    def forward(
        self,
        texts: Optional[List[str]] = None,
        multi_texts: Optional[List[List[str]]] = None,
        tokens: Optional[List[List[str]]] = None,
        languages: Optional[List[str]] = None,
        dense_feat: Optional[List[List[float]]] = None,
    ):
        if tokens is None:
            raise RuntimeError("tokens is required")
        if dense_feat is None:
            raise RuntimeError("dense_feat is required")

        tokens = truncate_tokens(tokens, self.max_seq_len, self.vocab.pad_token)
        seq_lens = make_sequence_lengths(tokens)
        word_ids = self.vocab.lookup_indices_2d(tokens)
        word_ids = pad_2d(word_ids, seq_lens, self.pad_idx)
        token_bytes, _ = make_byte_inputs(
            tokens, self.max_byte_len, self.byte_offset_for_non_padding
        )
        dense_feat = self.normalizer.normalize(dense_feat)
        logits = self.model(
            torch.tensor(word_ids),
            token_bytes,
            torch.tensor(seq_lens),
            torch.tensor(dense_feat, dtype=torch.float),
        )
        return self.output_layer(logits)
Example #11
class VocabTransform(nn.Module):
    def __init__(
        self,
        vocab_path: Optional[str] = None,
        vocab_list: Optional[List[str]] = None,
        special_token_replacements=SPECIAL_TOKEN_REPLACEMENT,
        add_bos: bool = False,
        add_eos: bool = False,
        max_seq_len: int = 2 ** 30,
    ):
        super().__init__()
        assert vocab_path or vocab_list, "vocab_path or vocab_list is required"
        assert not (
            vocab_path and vocab_list
        ), "vocab_path and vocab_list are mutual exclusive"

        if vocab_list:
            self.vocab = ScriptVocabulary(vocab_list)
        else:
            with PathManager.open(vocab_path) as f:
                vocab = build_fairseq_vocab(
                    f, special_token_replacements=special_token_replacements
                )
                self.vocab = ScriptVocabulary(
                    list(vocab),
                    pad_idx=vocab.get_pad_index(-1),
                    bos_idx=vocab.get_bos_index(-1),
                    eos_idx=vocab.get_eos_index(-1),
                    unk_idx=vocab.get_unk_index(-1),
                    unk_token=vocab.unk_token,
                )
        # TODO T77728853 We need to combine truncate with BOS/EOS as they impact each other
        # Need to find a nicer way to do this, as this can't be chained.
        self.add_bos = add_bos
        self.add_eos = add_eos
        # Reserve room in max_seq_len for BOS/EOS when they are enabled.
        self.truncate_transform = TruncateTransform(max_seq_len - add_bos - add_eos)

    def forward(self, tokens: List[List[str]]) -> List[List[int]]:
        tokens_idx = self.vocab.lookup_indices_2d(tokens)
        tokens_idx = self.truncate_transform(tokens_idx)
        if self.add_bos:
            tokens_idx = [[self.vocab.bos_idx] + row for row in tokens_idx]
        if self.add_eos:
            tokens_idx = [row + [self.vocab.eos_idx] for row in tokens_idx]
        return tokens_idx
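
Restated in plain Python, the marker logic from this forward pass; the indices in the assertion are hypothetical:

from typing import List

def add_markers(
    rows: List[List[int]], bos_idx: int, eos_idx: int,
    add_bos: bool, add_eos: bool,
) -> List[List[int]]:
    # Mirrors the forward pass above: truncation has already reserved
    # room, then BOS/EOS indices are attached around each row.
    if add_bos:
        rows = [[bos_idx] + row for row in rows]
    if add_eos:
        rows = [row + [eos_idx] for row in rows]
    return rows

assert add_markers([[5, 6]], 1, 2, True, True) == [[1, 5, 6, 2]]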
Example #12
class VocabTest(unittest.TestCase):
    def setUp(self):
        vocab_list = ["UNK", "a", "b", "c", "d"]
        self.vocab = ScriptVocabulary(vocab_list)

    def test_vocab_lookup(self):
        # There are bugs with scripting this directly; eventually these can be simpler.
        class LookupWord(jit.ScriptModule):
            def __init__(self, vocab):
                super().__init__()
                self.vocab = vocab

            @jit.script_method
            def forward(self, word: str):
                return self.vocab.idx[word]

        lookup_word = LookupWord(self.vocab)

        self.assertEqual(1, lookup_word("a"))
        self.assertEqual(3, lookup_word("c"))
        with self.assertRaises(Exception):
            lookup_word("notaword")

    def test_vocab_idx_lookup(self):
        # There are bugs with scripting this directly; eventually these can be simpler.
        class LookupIndex(jit.ScriptModule):
            def __init__(self, vocab):
                super().__init__()
                self.vocab = vocab

            @jit.script_method
            def forward(self, i: int):
                return self.vocab.vocab[i]

        lookup_idx = LookupIndex(self.vocab)

        self.assertEqual("UNK", lookup_idx(0))
        self.assertEqual("b", lookup_idx(2))
        with self.assertRaises(Exception):
            lookup_idx(20)

    def test_lookup_1d(self):
        self.assertEqual(
            [1, 0, 3, 4], self.vocab.lookup_indices_1d(["a", "e", "c", "d"])
        )
        self.assertEqual([], self.vocab.lookup_indices_1d([]))

    def test_lookup_2d(self):
        self.assertEqual(
            [[1, 0, 3, 4], [], [2]],
            self.vocab.lookup_indices_2d([["a", "e", "c", "d"], [], ["b"]]),
        )
        self.assertEqual([], self.vocab.lookup_indices_2d([]))

    def test_custom_unk(self):
        vocab_list = ["a", "UNK", "b", "c", "d"]
        vocab = ScriptVocabulary(vocab_list, unk_idx=1)
        self.assertEqual([0, 1, 3, 4], vocab.lookup_indices_1d(["a", "e", "c", "d"]))

    def test_lookup_words_1d_cycle_heuristic(self):
        self.assertEqual(
            self.vocab.lookup_words_1d_cycle_heuristic(
                torch.tensor([1, 0, 0]), [], ["y", "z"]
            ),
            ["a", "y", "z"],
        )