コード例 #1
0
    def test_sequ_classification_model_head_labels(self):
        """Save the classification head to disk, reload it, and check its labels.

        The head is written to a throwaway directory and loaded back into the
        same model instance. Afterwards, the model's label list and label
        mapping must equal the fixture values ``self.labels`` and
        ``self.label_map``.
        """
        classifier = BertForSequenceClassification(self.config)

        with TemporaryDirectory() as head_dir:
            # The directory is only needed for the save/load round trip;
            # it is removed when the context exits.
            classifier.save_head(head_dir)
            classifier.load_head(head_dir)

        # These checks read in-memory model state, so they can run after the
        # temporary directory has been cleaned up.
        self.assertEqual(self.labels, classifier.get_labels())
        self.assertDictEqual(self.label_map, classifier.get_labels_dict())