Example #1
    def test_save_load_language_model(self):
        """
        Ensure fine-tuning, saving, and loading do not cause errors
        Ensure generated text is a string and extra tokens appear only after fine-tuning
        """
        save_file = "tests/saved-models/test-save-load"
        model = Classifier()

        lm_out = model.generate_text("The quick brown fox", max_length=6)
        start_id = model.input_pipeline.text_encoder.start_token
        start_token = model.input_pipeline.text_encoder.decoder[start_id]
        # Models that have not been fine-tuned do not use extra tokens
        self.assertNotIn(start_token, lm_out)

        train_sample = self.dataset.sample(n=self.n_sample)
        model.fit(train_sample.Text, train_sample.Target)
        lm_out = model.generate_text("", max_length=5)
        self.assertIn(start_token.lower(), lm_out.lower())
        self.assertIsInstance(lm_out, str)
        model.save(save_file)

        model = Classifier.load(save_file)
        lm_out_2 = model.generate_text("Indico RULE")
        self.assertIsInstance(lm_out_2, str)

        # Both the fine-tuned model and its reloaded copy use extra tokens
        self.assertIn("{}Indico RULE".format(start_token).lower(), lm_out_2.lower())
Example #2
    def test_language_model(self):
        """
        Ensure text generation does not cause errors
        Ensure generated text is a string prefixed with the start token
        """
        model = Classifier(verbose=False)
        lm_out = model.generate_text("", max_length=5)
        self.assertIsInstance(lm_out, str)
        lm_out_2 = model.generate_text("Indico RULE")
        self.assertIsInstance(lm_out_2, str)
        self.assertIn('_start_Indico RULE'.lower(), lm_out_2.lower())
Example #3
    def test_language_model(self):
        """
        Ensure text generation does not cause errors
        Ensure generated text is a string prefixed with the start token
        """
        model = Classifier()
        lm_out = model.generate_text("", max_length=5)
        self.assertIsInstance(lm_out, str)
        lm_out_2 = model.generate_text("Indico RULE")
        self.assertIsInstance(lm_out_2, str)
        start_id = model.input_pipeline.text_encoder.start
        start_token = model.input_pipeline.text_encoder.decoder[start_id]
        self.assertIn('{}Indico RULE'.format(start_token).lower(), lm_out_2.lower())
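The start-token lookup above resolves a token id back to its string form through the encoder's decoder mapping. A toy illustration of that id-to-string round trip (the vocabulary dict is invented for illustration, not finetune's real encoder):

    # Toy stand-in for a text encoder's vocabulary tables.
    encoder = {"_start_": 0, "_classify_": 1}
    decoder = {token_id: token for token, token_id in encoder.items()}

    start_id = encoder["_start_"]
    start_token = decoder[start_id]
    assert start_token == "_start_"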
Example #4
    def test_save_load_language_model(self):
        """
        Ensure fine-tuning, saving, and loading do not cause errors
        Ensure the reloaded model still generates text with the start token
        """
        save_file = 'tests/saved-models/test-save-load'
        model = Classifier(verbose=False)
        train_sample = self.dataset.sample(n=self.n_sample)
        model.fit(train_sample.Text, train_sample.Target)
        lm_out = model.generate_text("", max_length=5)
        self.assertIsInstance(lm_out, str)
        model.save(save_file)
        model = Classifier.load(save_file)
        lm_out_2 = model.generate_text("Indico RULE")
        self.assertIsInstance(lm_out_2, str)
        self.assertIn('_start_Indico RULE'.lower(), lm_out_2.lower())
Example #5
    def test_save_load_language_model(self):
        """
        Ensure fine-tuning, saving, and loading do not cause errors
        Ensure the reloaded model still generates text with the start token
        """
        save_file = 'tests/saved-models/test-save-load'
        model = Classifier()
        train_sample = self.dataset.sample(n=self.n_sample)
        model.fit(train_sample.Text, train_sample.Target)
        lm_out = model.generate_text("", max_length=5)
        self.assertIsInstance(lm_out, str)
        model.save(save_file)
        model = Classifier.load(save_file)
        lm_out_2 = model.generate_text("Indico RULE")
        self.assertIsInstance(lm_out_2, str)
        start_id = model.input_pipeline.text_encoder.start
        start_token = model.input_pipeline.text_encoder.decoder[start_id]
        self.assertIn('{}Indico RULE'.format(start_token).lower(), lm_out_2.lower())
Example #6
    def test_early_termination_lm(self):
        """
        Ensure generation stops at the first _classify_ token
        """
        model = Classifier(verbose=False)

        # A dirty mock to make every model inference output a hundred _classify_ tokens
        def load_mock(*args, **kwargs):
            model.sess = MagicMock()
            model.sess.run = MagicMock(
                return_value=100 * [model.encoder['_classify_']]
            )

        model.saver.initialize = load_mock
        lm_out = model.generate_text()
        self.assertEqual(lm_out, '_start__classify_')
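For reference, a standalone sketch of the MagicMock pattern this test relies on (plain unittest.mock, no finetune dependency; the token strings are illustrative). A mock constructed with a fixed return_value ignores its arguments and answers every call with the same object, so each decoding step "predicts" the same hundred _classify_ tokens and generation stops at the first one.

    from unittest.mock import MagicMock

    # A mock with a fixed return_value answers every call with the
    # very same object, regardless of the arguments it receives.
    fake_run = MagicMock(return_value=100 * ["_classify_"])

    assert fake_run("logits", feed_dict={}) == 100 * ["_classify_"]
    assert fake_run() is fake_run()  # identical list object on every call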
Example #7
    def test_early_termination_lm(self):
        """
        Ensure generation stops at the first _classify_ token
        """
        model = Classifier(verbose=False)

        # A dirty mock to make every model inference output a hundred _classify_ tokens
        fake_estimator = MagicMock()
        model.get_estimator = lambda *args, **kwargs: fake_estimator
        fake_estimator.predict = MagicMock(
            return_value=iter([{
                "GEN_TEXT": 100 * [ENCODER['_classify_']]
            }])
        )

        lm_out = model.generate_text()
        self.assertEqual(lm_out, '_start__classify_')
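The iter(...) wrapper matters here: estimator predict methods are consumed as iterators of prediction dicts (as with tf.estimator.Estimator.predict), so the mock must hand back something next() can be called on. A minimal sketch of the same pattern, with an invented prediction dict:

    from unittest.mock import MagicMock

    # predict() is expected to yield prediction dicts, so the mock wraps
    # a single dict in iter(...) to mimic a one-element generator.
    fake_predict = MagicMock(return_value=iter([
        {"GEN_TEXT": ["_start_", "_classify_"]}
    ]))

    first = next(fake_predict())
    assert first["GEN_TEXT"] == ["_start_", "_classify_"]

Because return_value stores a single iterator, it is exhausted after one pass; that is acceptable in a test that consumes predict exactly once.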
Example #8
    def test_generate_text_stop_early(self):
        """
        Ensure generation stops at the first _classify_ token
        """
        model = Classifier()

        # A dirty mock to make every model inference output a hundred _classify_ tokens
        fake_estimator = MagicMock()
        model.get_estimator = lambda *args, **kwargs: (fake_estimator, [])
        model.input_pipeline.text_encoder._lazy_init()
        fake_estimator.predict = MagicMock(return_value=iter([{
            PredictMode.GENERATE_TEXT:
                100 * [model.input_pipeline.text_encoder["_classify_"]]
        }]))
        start_id = model.input_pipeline.text_encoder.start
        start_token = model.input_pipeline.text_encoder.decoder[start_id]
        lm_out = model.generate_text(use_extra_toks=True)
        self.assertEqual(lm_out, "{}_classify_".format(start_token))
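All of these snippets assume a surrounding unittest.TestCase that provides self.dataset and self.n_sample, plus imports such as MagicMock and the library's Classifier; names like ENCODER and PredictMode are likewise assumed to come from the library's own modules. A minimal sketch of that scaffolding, with placeholder data and a hypothetical class name:

    import unittest
    from unittest.mock import MagicMock

    import pandas as pd
    from finetune import Classifier  # assumed public import path


    class TestLanguageModel(unittest.TestCase):
        n_sample = 100  # placeholder sample size

        def setUp(self):
            # Placeholder data: any dataframe with Text and Target columns
            # satisfies the self.dataset.sample(...) calls above.
            self.dataset = pd.DataFrame({
                "Text": ["example document one", "example document two"] * 100,
                "Target": [0, 1] * 100,
            })


    if __name__ == "__main__":
        unittest.main()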