def test_fit_lm_only(self):
    """
    Ensure a texts-only fit followed by an annotated fit does not error out
    Ensure predict_proba returns a list of lists of dicts with a dict under 'confidence'
    Ensure every metric reports a 'Named Entity' entry
    Ensure the model can be saved, reloaded and asked to predict again
    """
    # Each entry of self.texts is joined into a single raw document string.
    documents = ["".join(pieces) for pieces in self.texts]
    corpus, spans = finetune_to_indico_sequence(
        documents,
        self.texts,
        self.labels,
        none_value=self.model.config.pad_token,
    )
    fit_docs, held_out_docs, fit_spans, held_out_spans = train_test_split(
        corpus, spans, test_size=0.1
    )

    # First fit sees texts only (no annotations); the second fit is supervised.
    # NOTE(review): presumably the first call is language-model-only training,
    # as the test name suggests -- confirm against SequenceLabeler.fit.
    self.model.fit(fit_docs)
    self.model.fit(fit_docs, fit_spans)

    predicted = self.model.predict(held_out_docs)
    proba_output = self.model.predict_proba(held_out_docs)

    # Walk down the nesting one level at a time so a failure points at the
    # exact level whose type is wrong.
    self.assertIsInstance(proba_output, list)
    first_document = proba_output[0]
    self.assertIsInstance(first_document, list)
    first_entry = first_document[0]
    self.assertIsInstance(first_entry, dict)
    self.assertIsInstance(first_entry['confidence'], dict)

    # All four metrics are computed before any of them is asserted on.
    metric_functions = (
        sequence_labeling_token_precision,
        sequence_labeling_token_recall,
        sequence_labeling_overlap_precision,
        sequence_labeling_overlap_recall,
    )
    per_class_results = [
        metric(held_out_spans, predicted) for metric in metric_functions
    ]
    for result in per_class_results:
        self.assertIn('Named Entity', result)

    # Save/load round trip: only checks that the reloaded model can predict
    # without raising; the returned predictions are not inspected.
    self.model.save(self.save_file)
    restored = SequenceLabeler.load(self.save_file)
    restored.predict(held_out_docs)
def test_fit_predict(self):
    """
    Ensure model training does not error out
    Ensure predict_proba returns a list of lists of dicts with a dict under "confidence"
    Ensure every metric reports a "Named Entity" entry
    Ensure weighting "Named Entity" by 100.0 yields higher token recall for that class
    """
    # Each entry of self.texts is joined into a single raw document string.
    documents = ["".join(pieces) for pieces in self.texts]
    corpus, spans = finetune_to_indico_sequence(
        documents,
        self.texts,
        self.labels,
        none_value=self.model.config.pad_token,
    )
    fit_docs, held_out_docs, fit_spans, held_out_spans = train_test_split(
        corpus, spans, test_size=0.1
    )

    # Model built from the default config plus a class weight of 100.0 on
    # "Named Entity"; trained and scored before the baseline model.
    weighted_labeler = SequenceLabeler(
        **self.default_config(class_weights={"Named Entity": 100.0})
    )
    weighted_labeler.fit(fit_docs, fit_spans)
    weighted_predicted = weighted_labeler.predict(held_out_docs)
    weighted_recall = sequence_labeling_token_recall(
        held_out_spans, weighted_predicted
    )

    # Baseline: the fixture model, trained on the same split.
    self.model.fit(fit_docs, fit_spans)
    predicted = self.model.predict(held_out_docs)
    proba_output = self.model.predict_proba(held_out_docs)

    # Walk down the nesting one level at a time so a failure points at the
    # exact level whose type is wrong.
    self.assertIsInstance(proba_output, list)
    first_document = proba_output[0]
    self.assertIsInstance(first_document, list)
    first_entry = first_document[0]
    self.assertIsInstance(first_entry, dict)
    self.assertIsInstance(first_entry["confidence"], dict)

    # All four metrics are computed before any of them is asserted on.
    metric_functions = (
        sequence_labeling_token_precision,
        sequence_labeling_token_recall,
        sequence_labeling_overlap_precision,
        sequence_labeling_overlap_recall,
    )
    per_class_results = [
        metric(held_out_spans, predicted) for metric in metric_functions
    ]
    for result in per_class_results:
        self.assertIn("Named Entity", result)
    # Second entry is the token-recall result (see metric_functions order).
    baseline_recall = per_class_results[1]

    self.model.save(self.save_file)

    # The reweighted model must recall strictly more "Named Entity" tokens
    # than the baseline on the same held-out documents.
    self.assertGreater(
        weighted_recall["Named Entity"], baseline_recall["Named Entity"]
    )