def test_get_model_and_tokenizer_simple(self, auto_model_mock, auto_tokenizer_mock):
        """Requesting random weights re-initializes the loaded model once.

        The model loader is mocked, so the model name never has to exist.
        ``init_weights`` is expected on the object produced by calling
        ``.to(...)`` on the mocked model.
        """
        transformers_extractor.get_model_and_tokenizer(
            "non-existent model", random_weights=True
        )

        # Object obtained by chaining ``.to(...)`` onto the mocked model.
        placed_model = auto_model_mock.return_value.to.return_value
        placed_model.init_weights.assert_called_once()
Example #2
0
    def test_get_model_and_tokenizer_custom(self, auto_model_mock, auto_tokenizer_mock):
        """A ``"model,tokenizer"`` description loads the tokenizer by its own name.

        The tokenizer mock returns the sentinel tensor only when it is asked
        for ``"custom-tokenizer"``. Getting that sentinel back therefore shows
        that the second half of the comma-separated description was used.
        """
        # Plain torch tensors stand in for the real objects, so the
        # ``.to(device)`` call on the model needs no extra mocking.
        expected_model = torch.rand((5, 2))
        expected_tokenizer = torch.rand((5,))

        def fake_tokenizer_loader(*args, **kwargs):
            # Only the custom tokenizer name yields the sentinel. Any other
            # name gets a fresh random tensor, which almost surely won't match.
            if args[0] == "custom-tokenizer":
                return expected_tokenizer
            return torch.rand((5,))

        auto_model_mock.return_value = expected_model
        auto_tokenizer_mock.side_effect = fake_tokenizer_loader

        model, tokenizer = transformers_extractor.get_model_and_tokenizer(
            "non-existent model,custom-tokenizer"
        )

        self.assertTrue(torch.equal(model, expected_model))
        self.assertTrue(torch.equal(tokenizer, expected_tokenizer))
Example #3
0
    def test_get_model_and_tokenizer_normal(self, auto_model_mock, auto_tokenizer_mock):
        """A single model name yields both the model and the tokenizer.

        Both mocks hand back known tensors. The test checks that
        ``get_model_and_tokenizer`` returns tensors equal to them.
        """
        # Plain torch tensors stand in for the real objects, so the
        # ``.to(device)`` call on the model works without mocking it.
        expected_model = torch.rand((5, 2))
        expected_tokenizer = torch.rand((5,))
        auto_model_mock.return_value = expected_model
        auto_tokenizer_mock.return_value = expected_tokenizer

        loaded = transformers_extractor.get_model_and_tokenizer("non-existent model")
        model, tokenizer = loaded

        for actual, wanted in ((model, expected_model), (tokenizer, expected_tokenizer)):
            self.assertTrue(torch.equal(actual, wanted))
Example #4
0
 def __init__(self, model_name):
     """Remember ``model_name`` and load its model and tokenizer.

     Args:
         model_name: passed unchanged to ``get_model_and_tokenizer``, which is
             expected to return a ``(model, tokenizer)`` pair.
     """
     self.model_name = model_name
     model, tokenizer = get_model_and_tokenizer(model_name)
     self.model = model
     self.tokenizer = tokenizer