Example #1
def _run_benchmark_pytext_script_vocab(toks, v: PytextScriptVocabulary):
    # list lookup
    if isinstance(toks, list) and isinstance(toks[0], list):
        for tokens_list in toks:
            v.lookup_indices_1d(tokens_list)
    # single token lookup
    elif isinstance(toks, list):
        for token in toks:
            v.lookup_indices_1d([token])
    else:
        raise RuntimeError("Received tokens of incorrect type {}.".format(
            type(toks)))
Example #2
class LabelTransform(nn.Module):
    def __init__(self, label_names: List[str]):
        super().__init__()

        self.label_vocab = ScriptVocabulary(sorted(label_names))

    def forward(self, labels: List[str]) -> List[int]:
        return self.label_vocab.lookup_indices_1d(labels)
Example #3
class VocabTransform(Transform):
    def __init__(self, vocab: Vocabulary):
        super().__init__()
        self.vocab = ScriptVocabulary(
            list(vocab),
            pad_idx=vocab.get_pad_index(-1),
            bos_idx=vocab.get_bos_index(-1),
            eos_idx=vocab.get_eos_index(-1),
            unk_idx=vocab.get_unk_index(-1),
        )

    def forward(self, tokens: Tokens) -> Dict[str, torch.Tensor]:
        token_ids: List[int] = self.vocab.lookup_indices_1d(tokens.token_texts)
        return {
            "token_ids": torch.tensor(token_ids, dtype=torch.long),
            "start_ids": torch.tensor(tokens.start_ids, dtype=torch.long),
            "end_ids": torch.tensor(tokens.end_ids, dtype=torch.long),
        }

    @property
    def is_jitable(self) -> bool:
        return True
Example #4
class Seq2SeqJIT(torch.nn.Module):
    def __init__(
        self,
        src_dict,
        tgt_dict,
        sequence_generator,
        filter_eos_bos,
        copy_unk_token=False,
        dictfeat_dict=None,
    ):
        super().__init__()
        self.source_vocab = ScriptVocabulary(
            src_dict._vocab,
            src_dict.get_unk_index(),
            bos_idx=src_dict.get_bos_index(-1),
            eos_idx=src_dict.get_eos_index(-1),
        )
        self.target_vocab = ScriptVocabulary(
            tgt_dict._vocab,
            tgt_dict.get_unk_index(),
            bos_idx=tgt_dict.get_bos_index(),
            eos_idx=tgt_dict.get_eos_index(),
        )
        if dictfeat_dict:
            self.dictfeat_vocab = ScriptVocabulary(
                dictfeat_dict._vocab,
                # We want to use the index for the source pad token
                pad_idx=dictfeat_dict.idx[src_dict[src_dict.get_pad_index()]],
            )
        else:
            # Optional types in TorchScript are a bit of a pain, so it's
            # more convenient to use an empty vocabulary here than a None
            # value.
            self.dictfeat_vocab = ScriptVocabulary([])
        self.sequence_generator = sequence_generator

        self.copy_unk_token: bool = copy_unk_token
        self.unk_idx: int = self.source_vocab.unk_idx
        self.filter_eos_bos: bool = filter_eos_bos

    def prepare_generator_inputs(
        self,
        word_ids: List[int],
        dict_feat: Optional[Tuple[List[str], List[float], List[int]]] = None,
        contextual_token_embedding: Optional[List[float]] = None,
    ) -> Tuple[torch.Tensor, Optional[Tuple[
            torch.Tensor, torch.Tensor, torch.Tensor]], Optional[torch.Tensor],
               torch.Tensor, ]:
        src_len = len(word_ids)
        dict_tensors: Optional[Tuple[torch.Tensor, torch.Tensor,
                                     torch.Tensor]] = None
        if dict_feat is not None:
            dict_tokens, dict_weights, dict_lengths = dict_feat
            dict_ids = self.dictfeat_vocab.lookup_indices_1d(dict_tokens)
            dict_tensors = (
                torch.tensor([dict_ids]),
                torch.tensor([dict_weights], dtype=torch.float),
                torch.tensor([dict_lengths]),
            )

        contextual_embedding_tensor: Optional[torch.Tensor] = None
        if contextual_token_embedding is not None:
            assert (len(contextual_token_embedding) % src_len == 0
                    and len(contextual_token_embedding) > 0), (
                        f"Incorrect size for contextual embeddings: "
                        f"{len(contextual_token_embedding)}, Expected a "
                        f"non-zero multiple of input token count {src_len} ")
            contextual_embedding_tensor = torch.tensor(
                [contextual_token_embedding], dtype=torch.float)
        return (
            torch.tensor(word_ids).reshape(-1, 1),
            dict_tensors,
            contextual_embedding_tensor,
            torch.tensor([src_len]),
        )

    def forward(
        self,
        src_tokens: List[str],
        dict_feat: Optional[Tuple[List[str], List[float], List[int]]] = None,
        contextual_token_embedding: Optional[List[float]] = None,
    ) -> List[Tuple[List[str], float, List[float]]]:

        word_ids = self.source_vocab.lookup_indices_1d(src_tokens)

        # Find the unk token in the source utterance, if one exists.
        # If there are multiple, select the first unk token.
        single_unk_token: Optional[str] = get_single_unk_token(
            src_tokens, word_ids, self.copy_unk_token, self.unk_idx)

        (
            words,
            dict_tensors,
            contextual_embedding_tensor,
            src_lengths,
        ) = self.prepare_generator_inputs(word_ids, dict_feat,
                                          contextual_token_embedding)
        hypos_etc = self.sequence_generator(words, dict_tensors,
                                            contextual_embedding_tensor,
                                            src_lengths)
        hypos_list: List[Tuple[List[str], float, List[float]]] = []

        filter_token_list: List[int] = []
        if self.filter_eos_bos:
            filter_token_list = [
                self.target_vocab.bos_idx, self.target_vocab.eos_idx
            ]

        for seq in hypos_etc:
            hypothesis = seq[0]
            stringified = self.target_vocab.lookup_words_1d(
                hypothesis,
                filter_token_list=filter_token_list,
                possible_unk_token=single_unk_token,
            )
            hypos_list.append((stringified, seq[1], seq[2]))
        return hypos_list
Example #5
def _run_benchmark_lists_pytext_script_vocab(tok_lists: List[List[str]],
                                             v: PytextScriptVocabulary):
    for tokens_list in tok_lists:
        v.lookup_indices_1d(tokens_list)
Example #6
def _run_benchmark_pytext_script_vocab(toks: List[str],
                                       v: PytextScriptVocabulary):
    for token in toks:
        v.lookup_indices_1d([token])
Example #7
def test_custom_unk(self):
    vocab_list = ["a", "UNK", "b", "c", "d"]
    vocab = ScriptVocabulary(vocab_list, unk_idx=1)
    self.assertEqual([0, 1, 3, 4], vocab.lookup_indices_1d(["a", "e", "c", "d"]))
Example #8
class VocabTest(unittest.TestCase):
    def setUp(self):
        vocab_list = ["UNK", "a", "b", "c", "d"]
        self.vocab = ScriptVocabulary(vocab_list)

    def test_vocab_lookup(self):
        # There are bugs with just making this a script; eventually these can be simpler
        class LookupWord(jit.ScriptModule):
            def __init__(self, vocab):
                super().__init__()
                self.vocab = vocab

            @jit.script_method
            def forward(self, word: str):
                return self.vocab.idx[word]

        lookup_word = LookupWord(self.vocab)

        self.assertEqual(1, lookup_word("a"))
        self.assertEqual(3, lookup_word("c"))
        with self.assertRaises(Exception):
            lookup_word("notaword")

    def test_vocab_idx_lookup(self):
        # There are bugs with just making this a script; eventually these can be simpler
        class LookupIndex(jit.ScriptModule):
            def __init__(self, vocab):
                super().__init__()
                self.vocab = vocab

            @jit.script_method
            def forward(self, i: int):
                return self.vocab.vocab[i]

        lookup_idx = LookupIndex(self.vocab)

        self.assertEqual("UNK", lookup_idx(0))
        self.assertEqual("b", lookup_idx(2))
        with self.assertRaises(Exception):
            lookup_idx(20)

    def test_lookup_1d(self):
        self.assertEqual(
            [1, 0, 3, 4], self.vocab.lookup_indices_1d(["a", "e", "c", "d"])
        )
        self.assertEqual([], self.vocab.lookup_indices_1d([]))

    def test_lookup_2d(self):
        self.assertEqual(
            [[1, 0, 3, 4], [], [2]],
            self.vocab.lookup_indices_2d([["a", "e", "c", "d"], [], ["b"]]),
        )
        self.assertEqual([], self.vocab.lookup_indices_2d([]))

    def test_custom_unk(self):
        vocab_list = ["a", "UNK", "b", "c", "d"]
        vocab = ScriptVocabulary(vocab_list, unk_idx=1)
        self.assertEqual([0, 1, 3, 4], vocab.lookup_indices_1d(["a", "e", "c", "d"]))

    def test_lookup_words_1d_cycle_heuristic(self):
        self.assertEqual(
            self.vocab.lookup_words_1d_cycle_heuristic(
                torch.tensor([1, 0, 0]), [], ["y", "z"]
            ),
            ["a", "y", "z"],
        )
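
The examples above all follow the same core pattern: build a ScriptVocabulary from a list of tokens and call lookup_indices_1d to map tokens to integer ids inside a TorchScript-compatible module. The sketch below distills that pattern into a self-contained module; the pytext.torchscript.vocab import path and the TokenIndexer wrapper name are assumptions made for illustration, not code taken from the examples.

from typing import List

import torch

from pytext.torchscript.vocab import ScriptVocabulary  # assumed import path


class TokenIndexer(torch.nn.Module):
    def __init__(self, vocab_list: List[str]):
        super().__init__()
        # As in the tests above, index 0 ("UNK" here) is the default
        # index returned for out-of-vocabulary tokens.
        self.vocab = ScriptVocabulary(vocab_list)

    def forward(self, tokens: List[str]) -> List[int]:
        return self.vocab.lookup_indices_1d(tokens)


# The wrapper can be scripted just like the transforms above.
indexer = torch.jit.script(TokenIndexer(["UNK", "a", "b", "c"]))
assert indexer(["a", "c", "oov"]) == [1, 3, 0]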