def test_auto_set_save_adapters(self):
    """AdapterTrainer should auto-detect that the model is set up for fusion training.

    Builds a tiny BERT classifier, registers two adapters plus a fusion layer,
    activates fusion training, and checks the trainer picks this up via its
    ``train_adapter_fusion`` flag.
    """
    import tempfile

    model = BertForSequenceClassification(
        BertConfig(
            hidden_size=32,
            num_hidden_layers=4,
            num_attention_heads=4,
            intermediate_size=37,
        )
    )
    model.add_adapter("adapter1")
    model.add_adapter("adapter2")
    model.add_adapter_fusion(Fuse("adapter1", "adapter2"))
    model.train_adapter_fusion(Fuse("adapter1", "adapter2"))

    # Use a temporary directory instead of the hard-coded relative path
    # "./examples": the original leaked trainer artifacts into the CWD and
    # could collide with a real "examples" directory in the repository.
    with tempfile.TemporaryDirectory() as tmp_dir:
        training_args = TrainingArguments(
            output_dir=tmp_dir,
        )
        trainer = AdapterTrainer(
            model=model,
            args=training_args,
        )
        self.assertTrue(trainer.train_adapter_fusion)
class AdapterCompositionTest(unittest.TestCase):
    """Exercises adapter composition blocks (Split, Stack, Fuse, Parallel,
    BatchSplit) on a small BERT classifier, checking that training passes run
    and that outputs have the expected shape/values."""

    def setUp(self):
        self.model = BertForSequenceClassification(BertConfig())
        # Register four plain adapters used by the composition tests below.
        for name in ("a", "b", "c", "d"):
            self.model.add_adapter(name)
        self.model.to(torch_device)
        self.model.train()

    def training_pass(self):
        # Forward/backward with a single sample.
        batch = {
            "input_ids": ids_tensor((1, 128), 1000),
            "labels": torch.ones(1, dtype=torch.long),
        }
        self.model(**batch).loss.backward()

    def batched_training_pass(self):
        # Forward/backward with a batch of four samples.
        batch = {}
        batch["input_ids"] = ids_tensor((4, 128), 1000)
        batch["labels"] = torch.ones(4, dtype=torch.long)
        self.model(**batch).loss.backward()

    def test_simple_split(self):
        # pass over split setup
        self.model.set_active_adapters(Split("a", "b", 64))
        self.training_pass()

    def test_stacked_split(self):
        # split the sequence between two stacked compositions
        composition = Split(Stack("a", "b"), Stack("c", "d"), split_index=64)
        self.model.set_active_adapters(composition)
        self.training_pass()

    def test_stacked_fusion(self):
        self.model.add_adapter_fusion(Fuse("b", "d"))
        # fuse the outputs of two stacks
        self.model.set_active_adapters(Fuse(Stack("a", "b"), Stack("c", "d")))
        self.training_pass()

    def test_mixed_stack(self):
        self.model.add_adapter_fusion(Fuse("a", "b"))
        composition = Stack("a", Split("c", "d", split_index=64), Fuse("a", "b"))
        self.model.set_active_adapters(composition)
        self.training_pass()

    def test_nested_split(self):
        # a split nested inside another split
        composition = Split(Split("a", "b", split_index=32), "c", split_index=64)
        self.model.set_active_adapters(composition)
        self.training_pass()

    def test_parallel(self):
        self.model.set_active_adapters(Parallel("a", "b", "c", "d"))
        logits = self.model(input_ids=ids_tensor((1, 128), 1000)).logits
        # one output row per parallel branch
        self.assertEqual(logits.shape, (4, 2))

    def test_nested_parallel(self):
        self.model.set_active_adapters(Stack("a", Parallel(Stack("b", "c"), "d")))
        logits = self.model(input_ids=ids_tensor((1, 128), 1000)).logits
        self.assertEqual(logits.shape, (2, 2))

    def test_batch_split(self):
        self.model.set_active_adapters(BatchSplit("a", "b", "c", batch_sizes=[1, 1, 2]))
        self.batched_training_pass()

    def test_batch_split_int(self):
        self.model.set_active_adapters(BatchSplit("a", "b", batch_sizes=2))
        self.batched_training_pass()

    def test_nested_batch_split(self):
        self.model.set_active_adapters(Stack("a", BatchSplit("b", "c", batch_sizes=[2, 2])))
        self.batched_training_pass()

    def test_batch_split_invalid(self):
        # batch sizes that exceed the actual batch must raise.
        self.model.set_active_adapters(BatchSplit("a", "b", batch_sizes=[3, 4]))
        with self.assertRaises(IndexError):
            self.batched_training_pass()

    def test_batch_split_equivalent(self):
        # A BatchSplit run must match running each adapter on its own slice.
        self.model.set_active_adapters("a")
        self.model.eval()
        input_ids = ids_tensor((2, 128), 1000)
        output_a = self.model(input_ids[:1])
        self.model.set_active_adapters("b")
        output_b = self.model(input_ids[1:2])

        self.model.set_active_adapters(BatchSplit("a", "b", batch_sizes=[1, 1]))
        combined = self.model(input_ids)

        self.assertTrue(torch.allclose(output_a[0], combined[0][0], atol=1e-6))
        self.assertTrue(torch.allclose(output_b[0], combined[0][1], atol=1e-6))