Example #1
0
    def test_model_with_heads_tagging_head_labels(self):
        """Round-trip a tagging head through save/load and check its labels.

        The tagging head is created with ``self.labels`` and the mapping
        ``self.label_map`` (passed as ``id2label``), saved to a temporary
        directory and loaded back. An adapter is then added. The labels
        reported by the model (no head name given) must still equal the
        ones the head was created with.
        """
        head_name = "test_head"
        bert = BertModelWithHeads(self.config)
        bert.add_tagging_head(
            head_name,
            num_labels=len(self.labels),
            id2label=self.label_map,
        )

        with TemporaryDirectory() as save_dir:
            bert.save_head(save_dir, head_name)
            bert.load_head(save_dir)

        # The adapter is added (not loaded) only so the assertions below also
        # cover that adding an adapter leaves the label information unchanged.
        bert.add_adapter("sst-2", "text_task")

        self.assertEqual(self.labels, bert.get_labels())
        self.assertDictEqual(self.label_map, bert.get_labels_dict())
Example #2
0
    def test_model_with_heads_multiple_heads(self):
        """Save and reload two heads, then check the tagging head's labels.

        A tagging head with explicit labels ("test_head") and a 5-label
        classification head ("second_head") are added. Each head is saved to
        its own sub-directory of a temporary directory and loaded back, and
        then an adapter is added. The tagging head's label list and label
        mapping must still match the values it was created with.
        """
        import os  # local import: only needed to build the per-head paths

        model = BertModelWithHeads(self.config)
        model.add_tagging_head("test_head",
                               num_labels=len(self.labels),
                               id2label=self.label_map)
        model.add_classification_head("second_head", num_labels=5)
        with TemporaryDirectory() as temp_dir:
            # Each head goes into its own sub-directory named after the head,
            # so the two saved heads are kept apart on disk. os.path.join
            # avoids hard-coding the "/" path separator.
            for head_name in ("test_head", "second_head"):
                head_dir = os.path.join(temp_dir, head_name)
                model.save_head(head_dir, head_name)
                model.load_head(head_dir)
        # Labels are checked after adding an adapter, so the assertions also
        # cover that adding an adapter leaves them unchanged.
        model.add_adapter("sst-2", "text_task")

        self.assertEqual(model.get_labels("test_head"), self.labels)
        self.assertEqual(model.get_labels_dict("test_head"), self.label_map)
Example #3
0
    def test_multiple_heads_label(self):
        """A classification head added later reports the default labels.

        A tagging head with explicit labels is saved, reloaded and followed
        by an adapter. A classification head is then added without any label
        arguments, and its labels and label dict must equal the values
        returned by ``get_default(2)``.
        """
        tagging_head = "test_head"
        bert = BertModelWithHeads(self.config)
        bert.add_tagging_head(tagging_head,
                              num_labels=len(self.labels),
                              id2label=self.label_map)
        with TemporaryDirectory() as save_dir:
            bert.save_head(save_dir, tagging_head)
            bert.load_head(save_dir)
        # The adapter is added (not loaded) to check that it does not disturb
        # the label information of heads.
        bert.add_adapter("sst-2", "text_task")
        bert.add_classification_head("classification_head")
        expected_labels, expected_label_dict = get_default(2)

        self.assertEqual(bert.get_labels("classification_head"),
                         expected_labels)
        self.assertEqual(bert.get_labels_dict("classification_head"),
                         expected_label_dict)