def test_slice_again(self):
    original = Embedding(10, 10)
    already_sliced = SlicedEmbedding.slice(original, 5, True, False)
    slice_it_again = SlicedEmbedding.slice(already_sliced, 5, True, False)

    self.assertTrue(type(slice_it_again) is SlicedEmbedding)
    self.assertEqual(already_sliced, slice_it_again)

def test_lookup_embeddings(self):
    original_embedding = torch.nn.Embedding(10, 2)
    sliced_embedding = SlicedEmbedding.slice(original_embedding, 5, True, False)

    batch_for_original = torch.tensor([[0, 1], [8, 9]])
    batch_for_sliced = torch.tensor([[0, 1], [8, 9]])

    lookup_original = original_embedding(batch_for_original)
    lookup_sliced = sliced_embedding(batch_for_sliced)

    self.assertTrue(torch.all(lookup_original == lookup_sliced).item())

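# The lookup test above only passes if SlicedEmbedding.forward sends ids below
# the cut to the first table and ids at or above the cut to the second table,
# then stitches the two lookups back together. The helper below is a hedged
# sketch of that routing, reconstructed from the observed test behaviour; it is
# not the repository's actual forward implementation, and the first/second/cut
# arguments are assumptions made for illustration.
def _sliced_lookup_sketch(first, second, cut, input_ids):
    mask = input_ids < cut
    output = torch.empty(*input_ids.shape, first.embedding_dim)
    # Ids below the cut index the first table directly ...
    output[mask] = first(input_ids[mask])
    # ... while ids at or above the cut are shifted into the second table's range.
    output[~mask] = second(input_ids[~mask] - cut)
    return output
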
def prepare_model_for_testing(checkpoint_path: str,
                              lexical_checkpoint: str = None,
                              always_use_finetuned_lexical: bool = False):
    try:
        model = NLIFinetuneModel.load_from_checkpoint(checkpoint_path)
    except Exception:
        model = NLIFinetuneModel(pretrained_model=checkpoint_path,
                                 num_classes=3,
                                 train_lexical_strategy='none',
                                 train_dataset='assin2',
                                 eval_dataset='assin2',
                                 data_dir='./data',
                                 batch_size=32,
                                 max_seq_length=128,
                                 tokenizer_name=None)

    model_training = model.hparams.train_lexical_strategy

    if model_training == 'none' or always_use_finetuned_lexical:
        return model

    lexical_model = LexicalTrainingModel.load_from_checkpoint(lexical_checkpoint)

    # Set up the target lexical tokenizer
    model.tokenizer = lexical_model.tokenizer

    if model_training == 'freeze-all':
        model.bert.set_input_embeddings(
            lexical_model.bert.get_input_embeddings())
    elif model_training == 'freeze-nonspecial':
        # In this case, the embedding should be a sliced embedding,
        # so we just take the parts we want to form a new embedding.

        # Special-token rows come from the fine-tuned model
        model_weights = model.bert.get_input_embeddings().get_first_weigths()

        # The other (language-specific) rows come from the lexically aligned model
        lexical_embeddings = lexical_model.bert.get_input_embeddings()
        target_weights = lexical_embeddings.get_second_weigths()

        # For testing, both slices are frozen
        combined_embedding = SlicedEmbedding(model_weights, target_weights,
                                             True, True)
        model.bert.set_input_embeddings(combined_embedding)

    return model

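# Hedged usage sketch for prepare_model_for_testing. Both checkpoint paths are
# hypothetical placeholders, not files that ship with this repository.
def _example_prepare_model():
    model = prepare_model_for_testing(
        'checkpoints/nli-finetuned.ckpt',                       # hypothetical path
        lexical_checkpoint='checkpoints/lexical-aligned.ckpt')  # hypothetical path
    model.eval()
    return model
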
def test_slice(self):
    original = Embedding(10, 10)
    cut = 5
    sliced = SlicedEmbedding.slice(original, cut, True, False)

    self.assertTrue(type(sliced) is SlicedEmbedding)

    first = sliced.first_embedding
    second = sliced.second_embedding

    self.assertFalse(first.weight.requires_grad)
    self.assertTrue(second.weight.requires_grad)

    self.assertTrue(
        torch.all(original.weight[:cut] == first.weight).item())
    self.assertTrue(
        torch.all(original.weight[cut:] == second.weight).item())

def test_freeze_final_positions(self):
    """
    We emulate a simple training loop to check whether embeddings are
    being updated correctly.
    """
    original_embedding = torch.nn.Embedding(10, 2)
    sliced_embedding = SlicedEmbedding.slice(original_embedding, 5, False, True)

    # We clone the original weights, since they are updated in place
    original_values = original_embedding.weight.clone()

    data = torch.tensor([[[0, 1, 5], [2, 3, 6]], [[0, 1, 9], [4, 5, 7]]])
    labels = torch.tensor([[0], [1]])  # Always fixed

    model = SimpleNetwork(sliced_embedding)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    for epoch in range(2):
        model.train()
        optimizer.zero_grad()
        for batch in data:
            outputs = model(batch)
            loss = torch.nn.functional.cross_entropy(outputs, labels)
            loss.backward()
            optimizer.step()

    trainable_original = original_values[:5]
    frozen_original = original_values[5:]

    trainable_actual = model.embeddings.first_embedding.weight
    frozen_actual = model.embeddings.second_embedding.weight

    self.assertTrue(torch.all(frozen_original == frozen_actual).item())
    self.assertFalse(
        torch.all(trainable_original == trainable_actual).item())

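# The test above relies on a SimpleNetwork defined elsewhere in the test suite.
# The class below is a hypothetical stand-in that shows one way such a module
# could satisfy the interface the test uses: it keeps the (sliced) embedding
# under `self.embeddings` and produces logits that cross_entropy accepts
# together with the (batch, 1) label tensor. It is not the repository's actual
# SimpleNetwork, and the layer sizes are hardcoded to match the toy data above.
class _SimpleNetworkSketch(torch.nn.Module):
    def __init__(self, embeddings, seq_length=3, embedding_dim=2, num_classes=2):
        super().__init__()
        self.embeddings = embeddings
        self.classifier = torch.nn.Linear(seq_length * embedding_dim, num_classes)

    def forward(self, input_ids):
        embedded = self.embeddings(input_ids)                    # (batch, seq, dim)
        logits = self.classifier(embedded.flatten(start_dim=1))  # (batch, classes)
        # Trailing dimension so cross_entropy accepts (batch, 1) class-index labels.
        return logits.unsqueeze(-1)                              # (batch, classes, 1)
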
def __setup_lexical_for_training(self):
    if self.hparams.train_strategy == 'train-all-lexical':
        # We freeze all parameters in this model,
        # then unlock the ones we want to train.
        self._freeze_parameters()

        # Train the word embeddings only
        input_embeddings = self.bert.get_input_embeddings()
        for parameter in input_embeddings.parameters():
            parameter.requires_grad = True

        # We also train the head (output embeddings, since they are tied)
        output_embeddings = self.bert.get_output_embeddings()
        for parameter in output_embeddings.parameters():
            parameter.requires_grad = True
    elif self.hparams.train_strategy == 'train-non-special':
        self._freeze_parameters()

        # Train the word embeddings only, skipping special tokens
        input_embeddings = self.bert.get_input_embeddings()
        last_special_token = max(self.tokenizer.all_special_ids)

        new_input_embeddings = SlicedEmbedding.slice(
            input_embeddings, last_special_token + 1, True, False)
        self.bert.set_input_embeddings(new_input_embeddings)

        # Handle the output embeddings the same way
        output_embeddings = self.bert.get_output_embeddings()
        new_output_embeddings = SlicedOutputEmbedding(
            output_embeddings, last_special_token + 1, True, False)
        self.bert.cls.predictions.decoder = new_output_embeddings

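# The cut point used above comes straight from the tokenizer: every id up to
# the largest special-token id stays frozen, and the trainable slice starts
# right after it. The helper below is a hedged, standalone illustration; the
# model name is an example, not a checkpoint this repository pins.
def _special_token_cut_example(pretrained_name='bert-base-multilingual-cased'):
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(pretrained_name)
    # BERT-style vocabularies keep their special tokens ([PAD], [UNK], [CLS],
    # [SEP], [MASK]) at small ids, so the frozen first slice is small.
    last_special_token = max(tokenizer.all_special_ids)
    return last_special_token + 1
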
def test_slice_again_different_cut(self):
    original = Embedding(10, 10)
    already_sliced = SlicedEmbedding.slice(original, 5, True, False)

    with self.assertRaises(NotImplementedError):
        SlicedEmbedding.slice(already_sliced, 6, True, False)

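# Taken together, the slice() tests pin down its contract: a fresh slice splits
# the weight matrix at the cut with the requested freeze flags, re-slicing at
# the same cut is a no-op, and re-slicing at a different cut raises
# NotImplementedError. The function below is a hedged reconstruction of that
# contract from the tests alone; it assumes the existing cut can be read from
# first_embedding.num_embeddings and that the SlicedEmbedding constructor takes
# two weight tensors plus two freeze flags (as used in prepare_model_for_testing).
# It is not a copy of the repository's implementation.
def _slice_sketch(embedding, cut, freeze_first, freeze_second):
    if isinstance(embedding, SlicedEmbedding):
        # Re-slicing: only a no-op at the original cut is supported.
        if embedding.first_embedding.num_embeddings != cut:
            raise NotImplementedError(
                'Re-slicing with a different cut is not supported.')
        return embedding
    # Split the weight matrix at the cut; the freeze flags decide which of the
    # two resulting tables keeps requires_grad=True.
    return SlicedEmbedding(embedding.weight[:cut].clone(),
                           embedding.weight[cut:].clone(),
                           freeze_first, freeze_second)
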