Пример #1
0
    def test_loading_adapter_weights_without_prefix(self):
        """Adapter weights saved from a bare ``BertModel`` load into a ``BertModelWithHeads``.

        The adapter is saved from the bare model and loaded into a model-with-heads
        whose encoder lives under its ``.bert`` attribute. Presumably the bare
        model's keys therefore lack the ``bert.`` prefix that loading has to map.
        The test checks three things:

        * no keys are reported missing or unexpected;
        * the adapter is registered on the receiving model;
        * both models still produce identical outputs.
        """
        model_base, model_with_head_base = create_twin_models(BertModel)

        model_with_head = BertModelWithHeads(model_with_head_base.config)
        # swap in the twin encoder so both models start from identical weights
        model_with_head.bert = model_with_head_base

        model_base.add_adapter("dummy", AdapterType.text_task)

        with tempfile.TemporaryDirectory() as temp_dir:
            model_base.save_adapter(temp_dir, "dummy")

            loading_info = {}
            adapter_name = model_with_head.load_adapter(temp_dir, loading_info=loading_info)

        self.assertEqual(0, len(loading_info["missing_keys"]))
        self.assertEqual(0, len(loading_info["unexpected_keys"]))
        # the loaded adapter must actually be registered on the receiving model
        self.assertIn(adapter_name, model_with_head.config.adapters.adapters)

        # check equal output; force eval mode so dropout cannot make the comparison
        # nondeterministic (a freshly constructed torch module is in training mode)
        model_base.eval()
        model_with_head.eval()
        in_data = ids_tensor((1, 128), 1000)
        output1 = model_with_head(in_data)
        output2 = model_base(in_data)
        self.assertEqual(len(output1), len(output2))
        self.assertTrue(torch.equal(output1[0], output2[0]))
Пример #2
0
    def test_load_adapter_with_head_from_hub(self):
        """Load the ``qa/squad1@ukp`` adapter (with its head) onto bert-base-uncased.

        The test verifies three things:

        * the load reports no missing or unexpected keys;
        * the adapter is registered with a config matching the "houlsby" preset;
        * a forward pass yields a first output of shape (1, 128).

        Presumably this needs network access (or a warm download cache).
        """
        model = BertModelWithHeads.from_pretrained("bert-base-uncased")

        info = {}
        name = model.load_adapter("qa/squad1@ukp", config="houlsby", version="1", loading_info=info)

        for key_kind in ("missing_keys", "unexpected_keys"):
            self.assertEqual(0, len(info[key_kind]))

        self.assertIn(name, model.config.adapters.adapters)

        # the stored adapter config must hash identically to the "houlsby" preset
        reference_config = AdapterConfig.load("houlsby")
        loaded_config = model.config.adapters.get(name)
        self.assertEqual(get_adapter_config_hash(reference_config), get_adapter_config_hash(loaded_config))

        # feed one sequence of 128 token ids and check the first output's shape
        token_ids = ids_tensor((1, 128), 1000)
        first_output = model(token_ids)[0]
        self.assertEqual((1, 128), tuple(first_output.size()))
Пример #3
0
    def test_model_with_heads_multiple_heads(self):
        """Two heads of different types keep their labels across a save/load round trip.

        A tagging head with explicit labels and a 5-class classification head
        without labels are each saved to, and reloaded from, their own
        sub-directory. An adapter is added afterwards. The label information of
        BOTH heads is then checked.
        """
        model = BertModelWithHeads(self.config)
        model.add_tagging_head("test_head",
                               num_labels=len(self.labels),
                               id2label=self.label_map)
        model.add_classification_head("second_head", num_labels=5)
        with TemporaryDirectory() as temp_dir:
            # separate sub-directories keep the two heads' saved files apart
            model.save_head(temp_dir + "/test_head", "test_head")
            model.load_head(temp_dir + "/test_head")
            model.save_head(temp_dir + "/second_head", "second_head")
            model.load_head(temp_dir + "/second_head")
        # adapter added only to check it leaves the heads' label information intact
        model.add_adapter("sst-2", "text_task")

        self.assertEqual(model.get_labels("test_head"), self.labels)
        self.assertEqual(model.get_labels_dict("test_head"), self.label_map)
        # the classification head was added without id2label, so it should report
        # the generated default labels for its 5 classes (same expectation that
        # test_multiple_heads_label applies to a default classification head)
        default_label, default_label_dict = get_default(5)
        self.assertEqual(model.get_labels("second_head"), default_label)
        self.assertEqual(model.get_labels_dict("second_head"), default_label_dict)
Пример #4
0
    def test_multiple_heads_label(self):
        """A classification head added with default arguments exposes default labels.

        A labelled tagging head is first round-tripped through disk and an
        adapter is added. A classification head is then created without
        ``num_labels`` or ``id2label``. Its labels must equal ``get_default(2)``.
        """
        tagging_name = "test_head"
        classifier_name = "classification_head"

        model = BertModelWithHeads(self.config)
        model.add_tagging_head(tagging_name, num_labels=len(self.labels), id2label=self.label_map)
        with TemporaryDirectory() as head_dir:
            model.save_head(head_dir, tagging_name)
            model.load_head(head_dir)

        # the adapter is only here to make sure adding one leaves label info untouched
        model.add_adapter("sst-2", "text_task")
        model.add_classification_head(classifier_name)

        expected_labels, expected_label_dict = get_default(2)
        self.assertEqual(model.get_labels(classifier_name), expected_labels)
        self.assertEqual(model.get_labels_dict(classifier_name), expected_label_dict)
Пример #5
0
    def test_model_with_heads_tagging_head_labels(self):
        """A tagging head's labels survive a save/load round trip.

        The model has only this one head. Calling the label getters without a head
        name must return the labels it was created with, even after an adapter is
        added.
        """
        head = "test_head"
        model = BertModelWithHeads(self.config)
        model.add_tagging_head(head, num_labels=len(self.labels), id2label=self.label_map)

        with TemporaryDirectory() as save_dir:
            model.save_head(save_dir, head)
            model.load_head(save_dir)

        # adding an adapter must not alter the head's label information
        model.add_adapter("sst-2", "text_task")

        self.assertEqual(self.labels, model.get_labels())
        self.assertDictEqual(self.label_map, model.get_labels_dict())