コード例 #1
0
ファイル: test_sequence.py プロジェクト: seeker1943/finetune
 def test_fit_lm_only(self):
     """
     Fit on text alone, then on text plus annotations, and check the outputs.

     The model is first fitted with documents only (presumably the
     language-model-only step the test name refers to) and then fitted again
     with the matching annotations. Afterwards the test checks:

     * the nesting of ``predict_proba`` output (list -> list -> dict, with a
       dict under the ``'confidence'`` key),
     * that each token/overlap precision and recall result contains the
       ``'Named Entity'`` key,
     * that the model can be saved, loaded back through
       ``SequenceLabeler.load`` and used for prediction without raising.
     """
     documents = list(map("".join, self.texts))
     texts, annotations = finetune_to_indico_sequence(
         documents,
         self.texts,
         self.labels,
         none_value=self.model.config.pad_token,
     )

     # Hold out 10% of the documents for prediction and scoring.
     split = train_test_split(texts, annotations, test_size=0.1)
     fit_texts, held_out_texts, fit_labels, held_out_labels = split

     # First pass without targets, second pass with the annotations.
     self.model.fit(fit_texts)
     self.model.fit(fit_texts, fit_labels)

     predicted = self.model.predict(held_out_texts)
     probas = self.model.predict_proba(held_out_texts)

     # Walk down the nested structure one level at a time so a failure
     # points at the first level that has the wrong type.
     self.assertIsInstance(probas, list)
     first_doc = probas[0]
     self.assertIsInstance(first_doc, list)
     first_span = first_doc[0]
     self.assertIsInstance(first_span, dict)
     self.assertIsInstance(first_span['confidence'], dict)

     # All four metrics are computed before any of them is asserted on.
     metric_fns = (
         sequence_labeling_token_precision,
         sequence_labeling_token_recall,
         sequence_labeling_overlap_precision,
         sequence_labeling_overlap_recall,
     )
     per_class_results = [score(held_out_labels, predicted) for score in metric_fns]
     for result in per_class_results:
         self.assertIn('Named Entity', result)

     # Round-trip through disk; the reloaded model only has to predict
     # without raising, its output is not inspected.
     self.model.save(self.save_file)
     restored = SequenceLabeler.load(self.save_file)
     restored.predict(held_out_texts)
コード例 #2
0
    def test_fit_predict(self):
        """
        Train a default and a class-reweighted model and compare them.

        A second ``SequenceLabeler`` is built from ``self.default_config`` with
        a class weight of 100.0 for "Named Entity", fitted on the training
        split and scored for token recall on the held-out split. The default
        ``self.model`` is then fitted on the same data and the test checks:

        * the nesting of ``predict_proba`` output (list -> list -> dict, with a
          dict under the "confidence" key),
        * that each token/overlap precision and recall result contains the
          "Named Entity" key,
        * that, after saving the default model, the reweighted model's
          "Named Entity" token recall is strictly greater than the default
          model's.
        """
        entity_label = "Named Entity"

        documents = list(map("".join, self.texts))
        texts, annotations = finetune_to_indico_sequence(
            documents,
            self.texts,
            self.labels,
            none_value=self.model.config.pad_token,
        )

        # Hold out 10% of the documents for prediction and scoring.
        split = train_test_split(texts, annotations, test_size=0.1)
        fit_texts, held_out_texts, fit_labels, held_out_labels = split

        # Reweighted model: trained and scored before the default model.
        weighted_config = self.default_config(class_weights={entity_label: 100.0})
        weighted_model = SequenceLabeler(**weighted_config)
        weighted_model.fit(fit_texts, fit_labels)
        weighted_recall = sequence_labeling_token_recall(
            held_out_labels, weighted_model.predict(held_out_texts)
        )

        # Default model on the same split.
        self.model.fit(fit_texts, fit_labels)
        baseline_predictions = self.model.predict(held_out_texts)
        probas = self.model.predict_proba(held_out_texts)

        # Walk down the nested structure one level at a time so a failure
        # points at the first level that has the wrong type.
        self.assertIsInstance(probas, list)
        first_doc = probas[0]
        self.assertIsInstance(first_doc, list)
        first_span = first_doc[0]
        self.assertIsInstance(first_span, dict)
        self.assertIsInstance(first_span["confidence"], dict)

        # All four metrics are computed before any of them is asserted on.
        metric_fns = (
            sequence_labeling_token_precision,
            sequence_labeling_token_recall,
            sequence_labeling_overlap_precision,
            sequence_labeling_overlap_recall,
        )
        baseline_results = [
            score(held_out_labels, baseline_predictions) for score in metric_fns
        ]
        for result in baseline_results:
            self.assertIn(entity_label, result)

        self.model.save(self.save_file)

        # The second entry corresponds to sequence_labeling_token_recall.
        baseline_recall = baseline_results[1]
        self.assertGreater(weighted_recall[entity_label], baseline_recall[entity_label])