def test_auto_set_save_adapters(self):
    """Check that the trainer reports fusion training for a fusion-trained model.

    Builds a small BERT sequence classifier, registers two adapters, adds an
    adapter fusion over them and calls ``train_adapter_fusion`` for the same
    pair. An ``AdapterTrainer`` constructed for that model is then expected
    to expose a truthy ``train_adapter_fusion`` attribute.
    """
    adapter_names = ("adapter1", "adapter2")

    # Deliberately tiny configuration values for the model under test.
    config = BertConfig(
        hidden_size=32,
        num_hidden_layers=4,
        num_attention_heads=4,
        intermediate_size=37,
    )
    model = BertForSequenceClassification(config)

    for adapter_name in adapter_names:
        model.add_adapter(adapter_name)

    # A separate Fuse block is built for each call, as two distinct objects.
    model.add_adapter_fusion(Fuse(*adapter_names))
    model.train_adapter_fusion(Fuse(*adapter_names))

    trainer = AdapterTrainer(
        model=model,
        args=TrainingArguments(output_dir="./examples"),
    )

    self.assertTrue(trainer.train_adapter_fusion)
# Example #2
# 0
class AdapterCompositionTest(unittest.TestCase):
    """Smoke tests for running a BERT classifier under adapter compositions.

    Each test activates one composition block (``Split``, ``Stack``, ``Fuse``,
    ``Parallel``, ``BatchSplit`` or a nesting of these) over the adapters
    registered in ``setUp`` and then runs the model on generated input ids.
    """

    def setUp(self):
        """Create a default-config model with adapters "a"-"d" in train mode."""
        self.model = BertForSequenceClassification(BertConfig())
        for adapter_name in ("a", "b", "c", "d"):
            self.model.add_adapter(adapter_name)
        self.model.to(torch_device)
        self.model.train()

    def _input_ids(self, batch_size):
        """Return ``ids_tensor((batch_size, 128), 1000)`` placed on ``torch_device``."""
        # The model lives on torch_device (see setUp); moving the ids there
        # explicitly keeps the tests independent of where ids_tensor allocates.
        return ids_tensor((batch_size, 128), 1000).to(torch_device)

    def _forward_backward(self, batch_size):
        """Run one forward pass with all-ones labels and backpropagate the loss."""
        inputs = {
            "input_ids": self._input_ids(batch_size),
            # Labels must be on the same device as the model; a plain
            # torch.ones(...) would be created on the default (CPU) device.
            "labels": torch.ones(batch_size, dtype=torch.long, device=torch_device),
        }
        loss = self.model(**inputs).loss
        loss.backward()

    def training_pass(self):
        """Forward/backward pass on a single example."""
        self._forward_backward(batch_size=1)

    def batched_training_pass(self):
        """Forward/backward pass on a batch of four examples."""
        self._forward_backward(batch_size=4)

    def test_simple_split(self):
        # pass over split setup
        self.model.set_active_adapters(Split("a", "b", 64))

        self.training_pass()

    def test_stacked_split(self):
        # split into two stacks
        self.model.set_active_adapters(Split(Stack("a", "b"), Stack("c", "d"), split_index=64))

        self.training_pass()

    def test_stacked_fusion(self):
        self.model.add_adapter_fusion(Fuse("b", "d"))

        # fuse two stacks
        self.model.set_active_adapters(Fuse(Stack("a", "b"), Stack("c", "d")))

        self.training_pass()

    def test_mixed_stack(self):
        self.model.add_adapter_fusion(Fuse("a", "b"))

        # stack a single adapter, a split and a fusion on top of each other
        self.model.set_active_adapters(Stack("a", Split("c", "d", split_index=64), Fuse("a", "b")))

        self.training_pass()

    def test_nested_split(self):
        # split whose first part is itself a split
        self.model.set_active_adapters(Split(Split("a", "b", split_index=32), "c", split_index=64))

        self.training_pass()

    def test_parallel(self):
        self.model.set_active_adapters(Parallel("a", "b", "c", "d"))

        logits = self.model(input_ids=self._input_ids(1)).logits
        # one input row, four parallel adapters -> four rows of logits expected
        self.assertEqual(logits.shape, (4, 2))

    def test_nested_parallel(self):
        self.model.set_active_adapters(Stack("a", Parallel(Stack("b", "c"), "d")))

        logits = self.model(input_ids=self._input_ids(1)).logits
        # one input row, two parallel branches -> two rows of logits expected
        self.assertEqual(logits.shape, (2, 2))

    def test_batch_split(self):
        self.model.set_active_adapters(BatchSplit("a", "b", "c", batch_sizes=[1, 1, 2]))
        self.batched_training_pass()

    def test_batch_split_int(self):
        self.model.set_active_adapters(BatchSplit("a", "b", batch_sizes=2))
        self.batched_training_pass()

    def test_nested_batch_split(self):
        self.model.set_active_adapters(Stack("a", BatchSplit("b", "c", batch_sizes=[2, 2])))
        self.batched_training_pass()

    def test_batch_split_invalid(self):
        # batch_sizes sum to 7 while the batched pass feeds only 4 examples
        self.model.set_active_adapters(BatchSplit("a", "b", batch_sizes=[3, 4]))
        with self.assertRaises(IndexError):
            self.batched_training_pass()

    def test_batch_split_equivalent(self):
        """BatchSplit output rows should match running each adapter on its own row."""
        self.model.set_active_adapters("a")
        self.model.eval()
        input_ids = self._input_ids(2)
        output_a = self.model(input_ids[:1])

        self.model.set_active_adapters("b")
        output_b = self.model(input_ids[1:2])

        self.model.set_active_adapters(BatchSplit("a", "b", batch_sizes=[1, 1]))
        output = self.model(input_ids)

        self.assertTrue(torch.allclose(output_a[0], output[0][0], atol=1e-6))
        self.assertTrue(torch.allclose(output_b[0], output[0][1], atol=1e-6))