Exemplo n.º 1
0
    def test_load_adapter_from_hub(self):
        """Load the ``sts/mrpc@ukp`` adapter from the Hub under each config and sanity-check it.

        For both the "pfeiffer" and "houlsby" configs this verifies that:

        * loading reports no missing keys;
        * loading reports no unexpected keys, except legacy ``adapter_attention``
          weights shipped with old adapters;
        * the adapter is registered in ``model.config.adapters`` with a config
          whose hash matches the requested config;
        * the first model output for a ``(1, 128)`` id tensor has size ``[1, 2]``.
        """
        for config_name in ("pfeiffer", "houlsby"):
            with self.subTest(config=config_name):
                model = BertForSequenceClassification.from_pretrained("bert-base-uncased")

                info = {}
                name = model.load_adapter(
                    "sts/mrpc@ukp",
                    config=config_name,
                    version="1",
                    loading_info=info,
                )

                self.assertEqual(0, len(info["missing_keys"]))

                # Hotfix: old hub adapters contain unnecessary "adapter_attention"
                # weights, so those particular unexpected keys are tolerated.
                leftover = [
                    key for key in info["unexpected_keys"] if "adapter_attention" not in key
                ]
                self.assertEqual(0, len(leftover))

                self.assertIn(name, model.config.adapters.adapters)

                # The config stored on the model must hash identically to the one requested.
                self.assertEqual(
                    get_adapter_config_hash(AdapterConfig.load(config_name)),
                    get_adapter_config_hash(model.config.adapters.get(name)),
                )

                # Forward a (1, 128) id tensor (ids bounded by 1000) and check the output size.
                first_output = model(ids_tensor((1, 128), 1000))[0]
                self.assertEqual([1, 2], list(first_output.size()))
    def test_load_task_adapter_from_hub(self):
        """Check that a task adapter loaded from the Hub evaluates correctly on MRPC samples.

        For both the "pfeiffer" and "houlsby" configs, the ``sts/mrpc@ukp``
        adapter is loaded into a fresh ``bert-base-uncased`` model. The test
        then checks that:

        * no keys are missing or unexpected;
        * the adapter is registered in ``model.config.adapters`` but not in
          ``model.base_model.invertible_adapters``;
        * its stored config hashes identically to the requested config;
        * evaluating on the MRPC dev samples in ``tests/fixtures`` yields
          ``eval_acc`` above 0.9.
        """
        # The tokenizer and evaluation data do not depend on the adapter
        # config, so build them once instead of once per sub-test.
        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        data_args = GlueDataTrainingArguments(
            task_name="mrpc",
            data_dir="./tests/fixtures/tests_samples/MRPC",
            overwrite_cache=True)
        eval_dataset = GlueDataset(data_args,
                                   tokenizer=tokenizer,
                                   mode="dev")

        for config in ["pfeiffer", "houlsby"]:
            with self.subTest(config=config):
                model = BertForSequenceClassification.from_pretrained(
                    "bert-base-uncased")

                loading_info = {}
                adapter_name = model.load_adapter("sts/mrpc@ukp",
                                                  config=config,
                                                  version="1",
                                                  loading_info=loading_info)
                # NOTE(review): presumably needed to activate the loaded
                # adapter before evaluation — confirm against train_adapter.
                model.train_adapter(adapter_name)

                self.assertEqual(0, len(loading_info["missing_keys"]))
                self.assertEqual(0, len(loading_info["unexpected_keys"]))

                self.assertIn(adapter_name, model.config.adapters.adapters)
                self.assertNotIn(adapter_name,
                                 model.base_model.invertible_adapters)

                # check if config is valid
                expected_hash = get_adapter_config_hash(
                    AdapterConfig.load(config))
                real_hash = get_adapter_config_hash(
                    model.config.adapters.get(adapter_name))
                self.assertEqual(expected_hash, real_hash)

                training_args = TrainingArguments(output_dir="./examples",
                                                  no_cuda=True)

                # Evaluate with the name load_adapter actually returned rather
                # than a hard-coded "mrpc". The Trainer then always targets the
                # adapter that was just loaded.
                trainer = Trainer(
                    model=model,
                    args=training_args,
                    eval_dataset=eval_dataset,
                    compute_metrics=self._compute_glue_metrics("mrpc"),
                    adapter_names=[adapter_name],
                )
                result = trainer.evaluate()
                self.assertGreater(result["eval_acc"], 0.9)
Exemplo n.º 3
0
    def test_load_adapter_with_head_from_hub(self):
        """Load the ``qa/squad1@ukp`` adapter from the Hub into a ``BertModelWithHeads``.

        Uses the "houlsby" config and verifies that:

        * loading reports neither missing nor unexpected keys;
        * the adapter is registered in ``model.config.adapters`` with a config
          whose hash matches ``AdapterConfig.load("houlsby")``;
        * the first model output for a ``(1, 128)`` id tensor has size ``[1, 128]``.
        """
        model = BertModelWithHeads.from_pretrained("bert-base-uncased")

        info = {}
        name = model.load_adapter(
            "qa/squad1@ukp",
            config="houlsby",
            version="1",
            loading_info=info,
        )

        self.assertEqual(0, len(info["missing_keys"]))
        self.assertEqual(0, len(info["unexpected_keys"]))

        self.assertIn(name, model.config.adapters.adapters)

        # The stored adapter config must hash the same as a freshly loaded "houlsby" config.
        self.assertEqual(
            get_adapter_config_hash(AdapterConfig.load("houlsby")),
            get_adapter_config_hash(model.config.adapters.get(name)),
        )

        # Forward a (1, 128) id tensor (ids bounded by 1000) and check the output size.
        first_output = model(ids_tensor((1, 128), 1000))[0]
        self.assertEqual([1, 128], list(first_output.size()))