def test_save_load_language_model(self):
    """
    Ensure saving + loading does not cause errors
    Ensure saving + loading does not change predictions
    """
    save_file = "tests/saved-models/test-save-load"
    model = Classifier()
    lm_out = model.generate_text("The quick brown fox", 6)
    start_id = model.input_pipeline.text_encoder.start_token
    start_token = model.input_pipeline.text_encoder.decoder[start_id]
    # Non-finetuned models do not use extra tokens. Compare
    # case-insensitively so a casing mismatch between the token and the
    # generated text cannot make this check pass spuriously.
    self.assertNotIn(start_token.lower(), lm_out.lower())
    train_sample = self.dataset.sample(n=self.n_sample)
    model.fit(train_sample.Text, train_sample.Target)
    lm_out = model.generate_text("", 5)
    self.assertIn(start_token.lower(), lm_out.lower())
    self.assertEqual(type(lm_out), str)
    model.save(save_file)
    model = Classifier.load(save_file)
    lm_out_2 = model.generate_text("Indico RULE")
    self.assertEqual(type(lm_out_2), str)
    # Both of these models use extra tokens, so the start token must
    # prefix the echoed prompt.
    self.assertIn("{}Indico RULE".format(start_token).lower(), lm_out_2.lower())
def test_language_model(self):
    """
    Ensure text generation on a fresh model does not cause errors
    Ensure generated output is a string that echoes the prompt
    """
    model = Classifier(verbose=False)
    lm_out = model.generate_text("", max_length=5)
    self.assertEqual(type(lm_out), str)
    lm_out_2 = model.generate_text("Indico RULE")
    self.assertEqual(type(lm_out_2), str)
    # Lowercase BOTH sides: lowering only the needle made the assertion
    # depend on the generator's output casing.
    self.assertIn('_start_Indico RULE'.lower(), lm_out_2.lower())
def test_language_model(self):
    """
    Ensure text generation on a fresh model does not cause errors
    Ensure generated output is a string prefixed with the start token
    """
    model = Classifier()
    lm_out = model.generate_text("", max_length=5)
    self.assertEqual(type(lm_out), str)
    lm_out_2 = model.generate_text("Indico RULE").lower()
    self.assertEqual(type(lm_out_2), str)
    start_id = model.input_pipeline.text_encoder.start
    start_token = model.input_pipeline.text_encoder.decoder[start_id]
    # lm_out_2 is already lowercased at assignment — a second .lower()
    # on the haystack was redundant.
    self.assertIn('{}Indico RULE'.format(start_token).lower(), lm_out_2)
def test_save_load_language_model(self):
    """
    Ensure saving + loading does not cause errors
    Ensure saving + loading does not change predictions
    """
    save_file = 'tests/saved-models/test-save-load'
    model = Classifier(verbose=False)
    train_sample = self.dataset.sample(n=self.n_sample)
    model.fit(train_sample.Text, train_sample.Target)
    lm_out = model.generate_text("", 5)
    self.assertEqual(type(lm_out), str)
    model.save(save_file)
    model = Classifier.load(save_file)
    lm_out_2 = model.generate_text("Indico RULE")
    self.assertEqual(type(lm_out_2), str)
    # Lowercase the haystack too — the needle was lowered but the
    # generated text was compared case-sensitively, so the check could
    # fail on casing alone.
    self.assertIn('_start_Indico RULE'.lower(), lm_out_2.lower())
def test_save_load_language_model(self):
    """
    Round-trip a finetuned classifier through save/load and verify that
    text generation still works afterwards and that the output is a
    string prefixed (case-insensitively) with the encoder's start token.
    """
    checkpoint_path = 'tests/saved-models/test-save-load'
    model = Classifier()
    sample = self.dataset.sample(n=self.n_sample)
    model.fit(sample.Text, sample.Target)
    generated = model.generate_text("", 5)
    self.assertEqual(type(generated), str)
    model.save(checkpoint_path)
    model = Classifier.load(checkpoint_path)
    generated_after_load = model.generate_text("Indico RULE")
    self.assertEqual(type(generated_after_load), str)
    encoder = model.input_pipeline.text_encoder
    start_token = encoder.decoder[encoder.start]
    expected_prefix = '{}Indico RULE'.format(start_token).lower()
    self.assertIn(expected_prefix, generated_after_load.lower())
def test_early_termination_lm(self):
    """
    Generation must stop at the first _classify_ token even though the
    mocked session keeps producing them.
    """
    model = Classifier(verbose=False)

    # A dirty mock: once "initialized", every session run returns one
    # hundred _classify_ tokens.
    def patched_initialize(*args, **kwargs):
        model.sess = MagicMock()
        model.sess.run = MagicMock(
            return_value=100 * [model.encoder['_classify_']]
        )

    model.saver.initialize = patched_initialize
    generated = model.generate_text()
    self.assertEqual(generated, '_start__classify_')
def test_early_termination_lm(self):
    """
    Generation must stop at the first _classify_ token even though the
    mocked estimator keeps producing them.
    """
    model = Classifier(verbose=False)
    # A dirty mock: the estimator always predicts one hundred
    # _classify_ tokens.
    stub_estimator = MagicMock()
    model.get_estimator = lambda *args, **kwargs: stub_estimator
    prediction = {"GEN_TEXT": 100 * [ENCODER['_classify_']]}
    stub_estimator.predict = MagicMock(return_value=iter([prediction]))
    self.assertEqual(model.generate_text(), '_start__classify_')
def test_generate_text_stop_early(self):
    """
    Generation must terminate at the first _classify_ token instead of
    consuming the full mocked prediction stream.
    """
    model = Classifier()
    # A dirty mock: the estimator always predicts one hundred
    # _classify_ tokens.
    stub_estimator = MagicMock()
    model.get_estimator = lambda *args, **kwargs: (stub_estimator, [])
    encoder = model.input_pipeline.text_encoder
    encoder._lazy_init()
    classify_id = encoder["_classify_"]
    stub_estimator.predict = MagicMock(
        return_value=iter([{PredictMode.GENERATE_TEXT: 100 * [classify_id]}])
    )
    start_token = encoder.decoder[encoder.start]
    generated = model.generate_text(use_extra_toks=True)
    self.assertEqual(generated, "{}_classify_".format(start_token))