示例#1
0
 def test_fit_lm_only(self):
     """
     Ensure LM-only fitting followed by supervised fitting runs end to end.

     Also checks that predict/predict_proba return the expected nested
     structure, that every token/overlap metric reports the expected class,
     and that the model survives a save/load round trip.
     """
     joined_docs = ["".join(chunk) for chunk in self.texts]
     indico_texts, indico_annotations = finetune_to_indico_sequence(
         joined_docs, self.texts, self.labels,
         none_value=self.model.config.pad_token)
     train_texts, test_texts, train_annotations, test_annotations = train_test_split(
         indico_texts, indico_annotations, test_size=0.1)

     # First fit with no labels (language-model only), then supervised fit.
     self.model.fit(train_texts)
     self.model.fit(train_texts, train_annotations)

     predictions = self.model.predict(test_texts)
     probas = self.model.predict_proba(test_texts)

     # predict_proba yields, per document, a list of span dicts each
     # carrying a per-class 'confidence' mapping.
     self.assertIsInstance(probas, list)
     self.assertIsInstance(probas[0], list)
     self.assertIsInstance(probas[0][0], dict)
     self.assertIsInstance(probas[0][0]['confidence'], dict)

     # Every metric variant should report the labeled class.
     for metric_fn in (
         sequence_labeling_token_precision,
         sequence_labeling_token_recall,
         sequence_labeling_overlap_precision,
         sequence_labeling_overlap_recall,
     ):
         self.assertIn('Named Entity', metric_fn(test_annotations, predictions))

     # Round-trip through save/load and make sure inference still works.
     self.model.save(self.save_file)
     reloaded = SequenceLabeler.load(self.save_file)
     predictions = reloaded.predict(test_texts)
示例#2
0
    def test_fit_predict(self):
        """
        Fit and predict end to end.

        Verifies the nested structure of predict_proba output, that every
        metric variant reports the labeled class, that saving succeeds, and
        that heavily upweighting 'Named Entity' raises its token recall
        relative to the unweighted baseline.
        """
        joined_docs = ["".join(doc) for doc in self.texts]
        texts, annotations = finetune_to_indico_sequence(
            joined_docs, self.texts, self.labels, none_value=self.model.config.pad_token
        )
        (
            train_texts,
            test_texts,
            train_annotations,
            test_annotations,
        ) = train_test_split(texts, annotations, test_size=0.1)

        # Upweight the entity class so the reweighted model trades
        # precision for recall on that class.
        reweighted_model = SequenceLabeler(
            **self.default_config(class_weights={"Named Entity": 100.0})
        )
        reweighted_model.fit(train_texts, train_annotations)
        reweighted_predictions = reweighted_model.predict(test_texts)
        reweighted_token_recall = sequence_labeling_token_recall(
            test_annotations, reweighted_predictions
        )

        # Baseline (unweighted) model.
        self.model.fit(train_texts, train_annotations)
        predictions = self.model.predict(test_texts)
        probas = self.model.predict_proba(test_texts)

        # Per document: a list of span dicts, each with a per-class
        # 'confidence' mapping.
        self.assertIsInstance(probas, list)
        self.assertIsInstance(probas[0], list)
        self.assertIsInstance(probas[0][0], dict)
        self.assertIsInstance(probas[0][0]["confidence"], dict)

        token_precision = sequence_labeling_token_precision(
            test_annotations, predictions
        )
        token_recall = sequence_labeling_token_recall(test_annotations, predictions)
        overlap_precision = sequence_labeling_overlap_precision(
            test_annotations, predictions
        )
        overlap_recall = sequence_labeling_overlap_recall(test_annotations, predictions)

        for metric_result in (
            token_precision,
            token_recall,
            overlap_precision,
            overlap_recall,
        ):
            self.assertIn("Named Entity", metric_result)

        self.model.save(self.save_file)

        # Class reweighting should have improved recall on the target class.
        self.assertGreater(
            reweighted_token_recall["Named Entity"], token_recall["Named Entity"]
        )
示例#3
0
 def _evaluate_sequence_preds(self, preds, includes_context):
     """
     Check token-level precision/recall of *preds* against the stored
     training labels (self.trainY_seq).

     When context is included the model is expected to score perfectly;
     otherwise scores may only be bounded above by 1.0.
     """
     precision_by_class = sequence_labeling_token_precision(
         self.trainY_seq, preds)
     recall_by_class = sequence_labeling_token_recall(self.trainY_seq, preds)
     self.assertIn("IMPORTANT", precision_by_class)
     self.assertIn("IMPORTANT", recall_by_class)
     mean_precision = np.mean(list(precision_by_class.values()))
     mean_recall = np.mean(list(recall_by_class.values()))
     if not includes_context:
         # Without context the model may fall short of perfect scores.
         self.assertLessEqual(mean_precision, 1.0)
         self.assertLessEqual(mean_recall, 1.0)
     else:
         self.assertEqual(mean_precision, 1.0)
         self.assertEqual(mean_recall, 1.0)
示例#4
0
 def test_auxiliary_sequence_labeler(self):
     """
     Ensure model training does not error out.
     Ensure model returns reasonable predictions (mean token
     precision/recall above 0.6 for the labeled class).
     """
     (trainX, testX, trainY, testY) = self.dataset
     model = SequenceLabeler(**self.default_config())
     model.fit(trainX, trainY)
     preds = model.predict(testX)
     # FIX: the metric functions take (true_annotations, predictions) —
     # every other call site in this file passes ground truth first.
     # Passing (preds, testY) silently swapped precision and recall.
     token_precision = sequence_labeling_token_precision(testY, preds)
     token_recall = sequence_labeling_token_recall(testY, preds)
     self.assertIn("Named Entity", token_precision)
     self.assertIn("Named Entity", token_recall)
     token_precision = np.mean(list(token_precision.values()))
     token_recall = np.mean(list(token_recall.values()))
     self.assertGreater(token_precision, 0.6)
     self.assertGreater(token_recall, 0.6)