Пример #1
0
 def test_make_prediction_2(self):
     """Testing make_prediction() with a single request whose texts are numeric strings."""
     module = ScriptPyTextEmbeddingModule(
         self._mock_model(),
         self._mock_tensoriser(),
     )
     request = (["123", "12"], None, None, None, None)
     res = module.make_prediction([request])
     # The first (and only) result row is compared position by position.
     for position, expected in enumerate((123, 12)):
         self.assertEqual(res[0][position], expected)
Пример #2
0
 def test_make_prediction_7(self):
     """Testing make_prediction() with texts that are all non-numeric ("foo", "bar").

     Should raise RuntimeError.
     """
     module = ScriptPyTextEmbeddingModule(
         self._mock_model(),
         self._mock_tensoriser(),
     )
     non_numeric_request = (["foo", "bar"], None, None, None, None)
     with self.assertRaises(RuntimeError):
         module.make_prediction([non_numeric_request])
Пример #3
0
 def test_make_prediction_4(self):
     """Testing make_prediction() with the texts argument set to [None].

     Should raise RuntimeError.
     """
     module = ScriptPyTextEmbeddingModule(
         self._mock_model(),
         self._mock_tensoriser(),
     )
     none_text_request = ([None], None, None, None, None)
     with self.assertRaises(RuntimeError):
         module.make_prediction([none_text_request])
Пример #4
0
 def test_make_prediction_6(self):
     """Testing make_prediction() with a batch of two requests, each holding two numeric texts."""
     module = ScriptPyTextEmbeddingModule(
         self._mock_model(),
         self._mock_tensoriser(),
     )
     requests = [
         (["12345", "129"], None, None, None, None),
         (["95", "12"], None, None, None, None),
     ]
     res = module.make_prediction(requests)
     # One expected row per request, one expected value per text.
     expected_rows = [(12345, 129), (95, 12)]
     for row_idx, expected_row in enumerate(expected_rows):
         for col_idx, expected in enumerate(expected_row):
             self.assertEqual(res[row_idx][col_idx], expected)
Пример #5
0
 def setUp(self) -> None:
     """Create the fixture: a batch size of 4 and a module built from the mock model and tensoriser."""
     self.batch_size = 4
     scripted_model = self._mock_model()
     tensoriser = self._mock_tensoriser()
     self.module = ScriptPyTextEmbeddingModule(scripted_model, tensoriser)
Пример #6
0
class ScriptPyTextEmbeddingModuleTest(unittest.TestCase):
    def _mock_model(self):
        class MockModel(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, inp: torch.Tensor) -> torch.Tensor:
                return inp

        model = MockModel()
        return torch.jit.script(model)

    def _mock_tensoriser(self):
        """Return a fresh MockTensoriser instance for the module under test.

        The mock splits texts on whitespace, converts every token to an
        integer with int() and packs all integers into one int tensor.

        NOTE(review): the explicit local annotations suggest this class is
        compiled by TorchScript once handed to ScriptPyTextEmbeddingModule;
        confirm before restructuring the code.
        """
        class MockTensoriser(TensorizerScriptImpl):
            def __init__(self):
                super().__init__()

            def tokenize(
                self,
                row: Optional[List[str]],
                row_pre_tokenized: Optional[List[List[str]]],
            ) -> List[List[Tuple[str, int, int]]]:
                """Return, per text, a list of (token, start, end) tuples.

                `row_pre_tokenized` takes precedence over `row`; when both are
                None an empty list is returned.
                """
                tokens: List[List[Tuple[str, int, int]]] = []

                # handle row_pre_tokenized
                # this mock implementation will return early and use `row_pre_tokenized` if it's available;
                # otherwise, it falls back to `row`.
                if row_pre_tokenized is not None:
                    for text in row_pre_tokenized:
                        res: List[Tuple[str, int, int]] = []
                        # total length of the tokens seen so far in this text
                        prev: int = 0
                        for i, w in enumerate(text):
                            # start = one separator per preceding token (i) plus
                            # their combined length (prev); end is inclusive
                            res.append((w, i + prev, i + prev + len(w) - 1))
                            prev += len(w)
                        tokens.append(res)
                    return tokens

                # handle row if row_pre_tokenized is not defined
                if row is not None:
                    for text in row:
                        res: List[Tuple[str, int, int]] = []
                        # total length of the tokens seen so far in this text
                        prev: int = 0
                        # same offset arithmetic as above, on whitespace-split tokens
                        for i, w in enumerate(text.split()):
                            res.append((w, i + prev, i + prev + len(w) - 1))
                            prev += len(w)
                        tokens.append(res)

                return tokens

            def numberize(self, inp: List[List[Tuple[str, int,
                                                     int]]]) -> List[int]:
                """Flatten the tokenized texts into one list, converting each token with int()."""
                res: List[int] = []
                for row in inp:
                    # the start/end offsets are ignored; only the token text is converted
                    res.extend([int(x) for x, _, _ in row])
                return res

            def tensorize(self, n: List[int]) -> torch.Tensor:
                """Wrap the list of integers in a tensor of dtype torch.int."""
                return torch.tensor(n, dtype=torch.int)

            def forward(self, inputs: ScriptBatchInput) -> torch.Tensor:
                """Tokenize and numberize the texts of every batch element into one flat tensor.

                Only `inputs.texts` is read and `row_pre_tokenized` is always
                passed as None, so the pre-tokenized branch of tokenize() is
                never taken from here.
                """
                res: List[int] = []
                # batch_size() and get_texts_by_index() are not defined in this
                # class; presumably inherited from TensorizerScriptImpl.
                for idx in range(self.batch_size(inputs)):
                    res.extend(
                        self.numberize(
                            self.tokenize(
                                self.get_texts_by_index(inputs.texts, idx),
                                None)))
                return self.tensorize(res)

        return MockTensoriser()

    def get_random_string_with_n_tokens(self, n) -> str:
        return " ".join(
            random.choice(string.ascii_uppercase) for _ in range(n))

    def setUp(self) -> None:
        """Create the per-test fixture.

        `batch_size` is the batch size requested from make_batch() by the
        batching tests; `module` wraps the mock model and mock tensoriser.
        """
        self.batch_size = 4
        scripted_model = self._mock_model()
        tensoriser = self._mock_tensoriser()
        self.module = ScriptPyTextEmbeddingModule(scripted_model, tensoriser)

    def test_make_prediction_1(self) -> None:
        """Testing make_prediction() with texts that mix a numeric and a non-numeric entry.

        Should raise RuntimeError.
        """
        partly_invalid_request = (["1", "foo"], None, None, None, None)
        with self.assertRaises(RuntimeError):
            self.module.make_prediction([partly_invalid_request])

    def test_make_prediction_2(self) -> None:
        """Testing make_prediction() with a single request whose texts are numeric strings."""
        request = (["123", "12"], None, None, None, None)
        res = self.module.make_prediction([request])
        # The first (and only) result row is compared position by position.
        for position, expected in enumerate((123, 12)):
            self.assertEqual(res[0][position], expected)

    def test_make_prediction_3(self) -> None:
        """Testing make_prediction() with a texts argument holding exactly one string."""
        single_text_request = (["1"], None, None, None, None)
        res = self.module.make_prediction([single_text_request])
        first_value = res[0][0]
        self.assertEqual(first_value, 1)

    def test_make_prediction_4(self) -> None:
        """Testing make_prediction() with the texts argument set to [None].

        Should raise RuntimeError.
        """
        none_text_request = ([None], None, None, None, None)
        with self.assertRaises(RuntimeError):
            self.module.make_prediction([none_text_request])

    def test_make_prediction_5(self) -> None:
        """Testing make_prediction() with both texts and tokens set to a list of None.

        Should raise RuntimeError.
        """
        none_entries = [None, None]
        # texts is the first tuple field and tokens the third.
        all_none_request = (none_entries, None, [None, None], None, None)
        with self.assertRaises(RuntimeError):
            self.module.make_prediction([all_none_request])

    def test_make_prediction_6(self) -> None:
        """Testing make_prediction() with a batch of two requests, each holding two numeric texts."""
        requests = [
            (["12345", "129"], None, None, None, None),
            (["95", "12"], None, None, None, None),
        ]
        res = self.module.make_prediction(requests)
        # One expected row per request, one expected value per text.
        expected_rows = [(12345, 129), (95, 12)]
        for row_idx, expected_row in enumerate(expected_rows):
            for col_idx, expected in enumerate(expected_row):
                self.assertEqual(res[row_idx][col_idx], expected)

    def test_make_prediction_7(self) -> None:
        """Testing make_prediction() with texts that MockTensoriser cannot convert ("foo", "bar").

        Should raise RuntimeError.
        """
        non_numeric_request = (["foo", "bar"], None, None, None, None)
        with self.assertRaises(RuntimeError):
            self.module.make_prediction([non_numeric_request])

    def test_make_batch_runtime_invalid(self) -> None:
        """Testing make_batch() with malformed mega-batches.

        Every case must raise RuntimeError. Each case runs in its own
        subTest so that one failing case neither hides the remaining ones
        nor leaves it unclear which input was at fault.
        """
        invalid_inputs = [
            {
                "mega_batch": None,  # missing params
                "goals": {},
            },
            {
                "mega_batch": [()],  # missing params
                "goals": {},
            },
            {
                # NOTE(review): `(["1 2"])` is a parenthesised list, not a
                # 1-tuple; the element is still not a valid request tuple.
                "mega_batch": [(["1 2"])],  # missing params
                "goals": {},
            },
            {
                "mega_batch": [(["1 2"], ["3 4"])],  # missing params
                "goals": {},
            },
            {
                "mega_batch": [(
                    ["1 2"],
                    ["3 4"],
                    ["5 6"],
                    # missing params
                )],
                "goals": {},
            },
            {
                "mega_batch": [(
                    ["1 2"],
                    ["3 4"],
                    ["5 6"],
                    ["7 8"],
                    [1.0],
                    # missing params
                )],
                "goals": {},
            },
            {
                "mega_batch": [(
                    ["1 2"],
                    ["3 4"],
                    ["5 6"],
                    ["7 8"],
                    [[1.0]],
                    # missing params
                )],
                "goals": {},
            },
            {
                "mega_batch": [(
                    [["1 2"]],  # invalid
                    [["3 4"]],
                    [["5 6"]],
                    ["7 8"],
                    [[1.0]],
                    1,
                )],
                "goals": {},
            },
            {
                "mega_batch": [(
                    ["1 2"],
                    [["3 4"]],
                    [["5 6"]],
                    ["7 8"],
                    [1.0],  # invalid
                    1,
                )],
                "goals": {},
            },
            {
                "mega_batch": [(
                    ["1 2"],
                    [["3 4"]],
                    [["5 6"]],
                    ["7 8"],
                    [[1.0]],
                    [1],  # invalid
                )],
                "goals": {},
            },
        ]
        for invalid_input in invalid_inputs:
            mega_batch = invalid_input["mega_batch"]
            goals = invalid_input["goals"]
            with self.subTest(mega_batch=mega_batch, goals=goals):
                with self.assertRaises(
                        RuntimeError,
                        msg="mega_batch: {}, goals: {}".format(
                            mega_batch, goals),
                ):
                    self.module.make_batch(mega_batch, goals)

    def test_make_batch_trivial_input(self) -> None:
        """Testing make_batch() with empty and single-element mega-batches.

        An empty mega-batch must yield no batches; a single element must come
        back unchanged as the only member of a single batch. Each case runs in
        its own subTest so that a failure identifies the offending input and
        does not stop the remaining cases.
        """
        trivial_inputs = [
            {
                "mega_batch": [],
                "goals": {},
                "output": [],
            },
            {
                "mega_batch": [(None, None, None, None, None, 1)],
                "goals": {},
                "output": [[(None, None, None, None, None, 1)]],
            },
            {
                "mega_batch": [(["1 2"], None, None, None, None, 1)],
                "goals": {},
                "output": [[(["1 2"], None, None, None, None, 1)]],
            },
            {
                "mega_batch": [(
                    None,
                    [["3 4"]],
                    [["5 6"]],
                    ["7 8"],
                    [[1.0]],
                    0,
                )],
                "goals": {},
                "output": [[
                    (
                        None,
                        [["3 4"]],
                        [["5 6"]],
                        ["7 8"],
                        [[1.0]],
                        0,
                    ),
                ]],
            },
        ]
        for trivial_input in trivial_inputs:
            mega_batch = trivial_input["mega_batch"]
            with self.subTest(mega_batch=mega_batch):
                output = self.module.make_batch(mega_batch,
                                                trivial_input["goals"])
                self.assertEqual(
                    output,
                    trivial_input["output"],
                )

    def test_make_batch_with_pre_tokenized_input(self) -> None:
        """Testing make_batch() ordering for pre-tokenized input and for raw text input.

        In both cases the element with fewer tokens is expected first in the
        single returned batch, although it is passed in second.
        """
        # case: only pre-tokenized tokens are passed in (texts are None), so
        #       the ordering can only come from the pre-tokenized input
        longer_tokens = (None, None, [["1", "1", "1"]], None, None, 1)
        shorter_tokens = (None, None, [["1", "1"]], None, None, 2)
        token_batches = self.module.make_batch(
            [longer_tokens, shorter_tokens], {})
        self.assertEqual(token_batches, [[shorter_tokens, longer_tokens]])

        # case: only raw text is passed in, so the ordering is derived from
        #       the raw text as split by the mock tokenizer
        longer_text = (["1 1 1"], None, None, None, None, 1)
        shorter_text = (["1 1"], None, None, None, None, 2)
        text_batches = self.module.make_batch([longer_text, shorter_text], {})
        self.assertEqual(text_batches, [[shorter_text, longer_text]])

    def test_make_batch_returns_multiple_batches_for_input_mega_batch(
            self) -> None:
        """Testing make_batch() with a mega-batch larger than the requested batch size.

        All elements carry identical texts, and the batches are expected to
        be consecutive slices of the input in its original order, with the
        remainder in the final batch.
        """
        # case: multiple batches are returned for a large input mega-batch
        row_fields = (["1 2"], [["3 4"]], None, ["7 8"], [[1.0]])
        # three full batches plus a final batch that is one element short
        mega_batch_len = self.batch_size * 3 + self.batch_size - 1
        mega_batch = [(*row_fields, idx) for idx in range(mega_batch_len)]

        batches = self.module.make_batch(mega_batch,
                                         {"batchsize": str(self.batch_size)})

        expected_batches = [
            mega_batch[start:start + self.batch_size]
            for start in range(0, len(mega_batch), self.batch_size)
        ]
        # Comparing the whole list (instead of indexing batch by batch) also
        # fails when surplus batches are returned.
        self.assertEqual(batches, expected_batches)

    def test_make_batch_returns_one_batch_with_input_tuples_sorted(
            self) -> None:
        """Testing that make_batch() returns the input tuples of a single batch in sorted order.

        Element `idx` gets a text of `batch_size - idx` tokens, so a higher
        index means a shorter text; the expected batch lists the elements by
        descending index.
        """
        # case: input tuples are returned in a sorted order
        mega_batch = [
            (
                [self.get_random_string_with_n_tokens(self.batch_size - idx)],
                None,
                None,
                None,
                None,
                idx,
            )
            for idx in range(self.batch_size)
        ]
        expected_batch = sorted(
            mega_batch,
            key=lambda element: element[-1],
            reverse=True,
        )
        goals = {"batchsize": str(self.batch_size)}
        batches = self.module.make_batch(mega_batch, goals)
        self.assertEqual(batches, [expected_batch])

    def test_make_batch_returns_multiple_batches_with_input_tuples_sorted(
            self) -> None:
        """Testing that make_batch() returns several batches whose tuples follow the sorted order.

        Every element has five short texts (0 to 4 tokens) plus one long text
        whose token count shrinks as the element index grows, so the expected
        order is by descending index. The sorted elements are then expected
        in consecutive batches of `batch_size`, remainder last.
        """
        # case: multiple batches are returned (w/ tuples sorted) for a large input mega-batch
        mega_batch = []
        # three full batches plus a final batch that is one element short,
        # so that multiple batches are returned
        mega_batch_len = (self.batch_size * 3) + self.batch_size - 1
        for idx in range(mega_batch_len):
            # each input has multiple text fields of varying lengths:
            # first a few short strings ...
            multi_text = [
                self.get_random_string_with_n_tokens(short_text_len)
                for short_text_len in range(5)
            ]
            # ... then a string with many tokens (to verify sorted order)
            multi_text.append(
                self.get_random_string_with_n_tokens(mega_batch_len * 2 - idx))
            mega_batch.append((multi_text, None, None, None, None, idx))
        sorted_mega_batch = sorted(mega_batch,
                                   key=lambda x: x[-1],
                                   reverse=True)

        batches = self.module.make_batch(mega_batch,
                                         {"batchsize": str(self.batch_size)})

        expected_batches = [
            sorted_mega_batch[start:start + self.batch_size]
            for start in range(0, len(sorted_mega_batch), self.batch_size)
        ]
        # Comparing the whole list (instead of indexing batch by batch) also
        # fails when surplus batches are returned.
        self.assertEqual(batches, expected_batches)