Exemplo n.º 1
0
    def test_model_with_heads_tagging_head_labels(self):
        """Saving and reloading a tagging head keeps its labels and id->label map.

        After the save/load round trip, an adapter is added to the model to
        check that doing so does not change the label information reported by
        ``get_labels()`` / ``get_labels_dict()`` (called without a head name).
        """
        head_name = "test_head"
        model = BertModelWithHeads(self.config)
        model.add_tagging_head(
            head_name, num_labels=len(self.labels), id2label=self.label_map
        )
        # Round-trip the head through a throwaway directory (removed on exit).
        with TemporaryDirectory() as tmp:
            model.save_head(tmp, head_name)
            model.load_head(tmp)
        # The adapter is only added (not loaded from disk); it must not
        # disturb the label information checked below.
        model.add_adapter("sst-2", "text_task")

        expected_labels, expected_map = self.labels, self.label_map
        self.assertEqual(expected_labels, model.get_labels())
        self.assertDictEqual(expected_map, model.get_labels_dict())
Exemplo n.º 2
0
    def test_model_with_heads_multiple_heads(self):
        """Label info of one head survives saving/reloading several heads.

        A tagging head with custom labels and a classification head are each
        saved to their own sub-directory and loaded back. An adapter is then
        added, and the tagging head's labels and id->label map are checked.
        """
        import os

        model = BertModelWithHeads(self.config)
        model.add_tagging_head("test_head",
                               num_labels=len(self.labels),
                               id2label=self.label_map)
        model.add_classification_head("second_head", num_labels=5)
        with TemporaryDirectory() as temp_dir:
            # Each head gets its own sub-directory so the saves don't collide;
            # os.path.join instead of "+" "/" keeps the path portable.
            for head_name in ("test_head", "second_head"):
                head_dir = os.path.join(temp_dir, head_name)
                model.save_head(head_dir, head_name)
                model.load_head(head_dir)
        # Added only to check that it does not change label information.
        model.add_adapter("sst-2", "text_task")

        self.assertEqual(model.get_labels("test_head"), self.labels)
        self.assertEqual(model.get_labels_dict("test_head"), self.label_map)
Exemplo n.º 3
0
    def test_multiple_heads_label(self):
        """A classification head added after a labelled tagging head gets default labels.

        A tagging head with custom labels is saved and reloaded, and an adapter
        is added. A classification head is then created without ``num_labels``.
        Its labels must equal ``get_default(2)`` rather than anything taken from
        the tagging head.
        """
        tagging_head = "test_head"
        classification_head = "classification_head"
        model = BertModelWithHeads(self.config)
        model.add_tagging_head(tagging_head,
                               num_labels=len(self.labels),
                               id2label=self.label_map)
        with TemporaryDirectory() as tmp:
            model.save_head(tmp, tagging_head)
            model.load_head(tmp)
        # The adapter is only here to check it does not alter label information.
        model.add_adapter("sst-2", "text_task")
        # No num_labels given; the test expects labels matching get_default(2).
        model.add_classification_head(classification_head)
        expected_labels, expected_dict = get_default(2)

        self.assertEqual(model.get_labels(classification_head),
                         expected_labels)
        self.assertEqual(model.get_labels_dict(classification_head),
                         expected_dict)
Exemplo n.º 4
0
    def test_loading_adapter_weights_with_prefix(self):
        """Adapter weights saved from a model with heads load into a bare BertModel.

        The model with heads wraps one of two models from
        ``create_twin_models`` (presumably identically initialised — confirm
        against that helper). Its adapter is saved, then loaded into the other,
        bare model. The test expects no missing or unexpected keys and
        identical first outputs for the same input.
        """
        base_model, base_for_head = create_twin_models(BertModel)

        head_model = BertModelWithHeads(base_for_head.config)
        head_model.bert = base_for_head

        head_model.add_adapter("dummy", AdapterType.text_task)

        with tempfile.TemporaryDirectory() as tmp:
            head_model.save_adapter(tmp, "dummy")

            info = {}
            base_model.load_adapter(tmp, loading_info=info)

        # Weights saved from the wrapped model must map cleanly onto the bare one.
        for key in ("missing_keys", "unexpected_keys"):
            self.assertEqual(0, len(info[key]))

        # Feed the same input ids through both models and compare outputs.
        input_ids = ids_tensor((1, 128), 1000)
        out_heads = head_model(input_ids)
        out_base = base_model(input_ids)
        self.assertEqual(len(out_heads), len(out_base))
        self.assertTrue(torch.equal(out_heads[0], out_base[0]))