class AdapterCompositionTest(unittest.TestCase):
    """Smoke tests for adapter composition blocks on a BERT classifier.

    Each test activates a different composition of the adapters "a".."d"
    (Stack, Split, Fuse, Parallel, and nestings thereof) and checks that a
    forward pass (and, where applicable, a backward pass) runs through.
    """

    def setUp(self):
        """Build a fresh model with four adapters, on the test device, in train mode."""
        self.model = BertForSequenceClassification(BertConfig())
        for adapter_name in ("a", "b", "c", "d"):
            self.model.add_adapter(adapter_name)
        self.model.to(torch_device)
        self.model.train()

    def _make_inputs(self, with_labels=False):
        """Return the keyword inputs shared by all tests.

        Args:
            with_labels: When true, also include a ``labels`` tensor so that
                the model output carries a loss.
        """
        # NOTE(review): ids_tensor is assumed to return token ids of shape
        # (1, 128) drawn from a vocabulary of 1000 -- confirm against helper.
        inputs = {"input_ids": ids_tensor((1, 128), 1000)}
        if with_labels:
            # Create the labels on the same device the model was moved to in
            # setUp; a CPU-only tensor would mismatch on accelerator devices.
            inputs["labels"] = torch.ones(
                1, dtype=torch.long, device=torch_device
            )
        return inputs

    def training_pass(self):
        """Run one forward and backward pass with the active adapter setup."""
        loss = self.model(**self._make_inputs(with_labels=True)).loss
        loss.backward()

    def test_simple_split(self):
        # pass over split setup
        self.model.set_active_adapters(Split("a", "b", 64))
        self.training_pass()

    def test_stacked_split(self):
        # split into two stacks
        self.model.set_active_adapters(
            Split(Stack("a", "b"), Stack("c", "d"), split_index=64)
        )
        self.training_pass()

    def test_stacked_fusion(self):
        self.model.add_fusion(Fuse("b", "d"))
        # fuse two stacks
        self.model.set_active_adapters(Fuse(Stack("a", "b"), Stack("c", "d")))
        self.training_pass()

    def test_mixed_stack(self):
        self.model.add_fusion(Fuse("a", "b"))
        # stack of a single adapter, a split and a fusion
        self.model.set_active_adapters(
            Stack("a", Split("c", "d", split_index=64), Fuse("a", "b"))
        )
        self.training_pass()

    def test_nested_split(self):
        # split whose left half is itself a split
        self.model.set_active_adapters(
            Split(Split("a", "b", split_index=32), "c", split_index=64)
        )
        self.training_pass()

    def test_parallel(self):
        self.model.set_active_adapters(Parallel("a", "b", "c", "d"))
        logits = self.model(**self._make_inputs()).logits
        # One input row and four parallel adapters: four rows of logits expected.
        self.assertEqual(logits.shape, (4, 2))

    def test_nested_parallel(self):
        self.model.set_active_adapters(
            Stack("a", Parallel(Stack("b", "c"), "d"))
        )
        logits = self.model(**self._make_inputs()).logits
        # One input row and two parallel branches: two rows of logits expected.
        self.assertEqual(logits.shape, (2, 2))