def test_model_with_heads_multiple_heads(self):
    """Round-trip two heads of different types through save_head/load_head.

    A tagging head (with an explicit id2label mapping) and a classification
    head are added to the same model. Each is saved to, and immediately
    reloaded from, its own sub-directory of a temporary directory. An adapter
    is then added, and the tagging head's labels and label dict must still
    equal the fixtures it was created with.
    """
    bert = BertModelWithHeads(self.config)
    bert.add_tagging_head("test_head", num_labels=len(self.labels), id2label=self.label_map)
    bert.add_classification_head("second_head", num_labels=5)

    with TemporaryDirectory() as temp_dir:
        # One sub-directory per head: save it, then reload it from the same path.
        for head_name in ("test_head", "second_head"):
            head_dir = f"{temp_dir}/{head_name}"
            bert.save_head(head_dir, head_name)
            bert.load_head(head_dir)

    bert.add_adapter("sst-2", "text_task")

    self.assertEqual(bert.get_labels("test_head"), self.labels)
    self.assertEqual(bert.get_labels_dict("test_head"), self.label_map)
def test_multiple_heads_label(self):
    """Check label metadata when a second head is added after a head round-trip.

    A tagging head with an explicit id2label mapping is saved and reloaded,
    an adapter is added, and then a classification head is added without any
    labels. Two things must hold:

    - the new classification head carries the labels returned by
      ``get_default(2)``;
    - the previously loaded tagging head still carries its own labels.
    """
    model = BertModelWithHeads(self.config)
    model.add_tagging_head("test_head", num_labels=len(self.labels), id2label=self.label_map)
    with TemporaryDirectory() as temp_dir:
        model.save_head(temp_dir, "test_head")
        model.load_head(temp_dir)
    # adapter loaded for testing whether it changes label information
    model.add_adapter("sst-2", "text_task")
    model.add_classification_head("classification_head")

    # No id2label was passed for this head, so it should fall back to default labels.
    # NOTE(review): 2 presumably matches add_classification_head's default num_labels — confirm.
    default_label, default_label_dict = get_default(2)
    self.assertEqual(model.get_labels("classification_head"), default_label)
    self.assertEqual(model.get_labels_dict("classification_head"), default_label_dict)

    # Adding a second head (and an adapter) must not disturb the existing head's labels.
    self.assertEqual(model.get_labels("test_head"), self.labels)
    self.assertEqual(model.get_labels_dict("test_head"), self.label_map)