def test_slice_again(self):
        original = Embedding(10, 10)
        already_sliced = SlicedEmbedding.slice(original, 5, True, False)
        slice_it_again = SlicedEmbedding.slice(already_sliced, 5, True, False)

        self.assertTrue(type(slice_it_again) is SlicedEmbedding)
        self.assertIs(already_sliced, slice_it_again)

    def test_lookup_embeddings(self):
        original_embedding = torch.nn.Embedding(10, 2)
        sliced_embedding = SlicedEmbedding.slice(original_embedding, 5, True,
                                                 False)

        batch_for_original = torch.tensor([[0, 1], [8, 9]])
        batch_for_sliced = torch.tensor([[0, 1], [8, 9]])

        lookup_original = original_embedding(batch_for_original)
        lookup_sliced = sliced_embedding(batch_for_sliced)

        self.assertTrue(torch.all(lookup_original == lookup_sliced).item())
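
# For reference, a minimal sketch of the lookup the test above exercises.
# This is an assumption about how SlicedEmbedding routes ids, not the actual
# implementation: ids below the cut hit the first table, the rest hit the
# second table (shifted so they index from zero there).
import torch


def sliced_lookup(first_embedding, second_embedding, cut, input_ids):
    below = input_ids < cut
    out = torch.empty(*input_ids.shape, first_embedding.embedding_dim)
    out[below] = first_embedding(input_ids[below])
    out[~below] = second_embedding(input_ids[~below] - cut)
    return out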
def prepare_model_for_testing(checkpoint_path: str,
                              lexical_checkpoint: Optional[str] = None,
                              always_use_finetuned_lexical: bool = False):
    try:
        # First, try to load a Lightning checkpoint from disk.
        model = NLIFinetuneModel.load_from_checkpoint(checkpoint_path)
    except Exception:
        # Fall back to treating the path as a pretrained model identifier.
        model = NLIFinetuneModel(pretrained_model=checkpoint_path,
                                 num_classes=3,
                                 train_lexical_strategy='none',
                                 train_dataset='assin2',
                                 eval_dataset='assin2',
                                 data_dir='./data',
                                 batch_size=32,
                                 max_seq_length=128,
                                 tokenizer_name=None)

    model_training = model.hparams.train_lexical_strategy

    if model_training == 'none' or always_use_finetuned_lexical:
        return model

    lexical_model = LexicalTrainingModel.load_from_checkpoint(
        lexical_checkpoint)

    # Set up the target lexical tokenizer
    model.tokenizer = lexical_model.tokenizer

    if model_training == 'freeze-all':
        model.bert.set_input_embeddings(
            lexical_model.bert.get_input_embeddings())

    elif model_training == 'freeze-nonspecial':
        # In this case, the embedding should be a SlicedEmbedding,
        # so we just take the parts we want to form a new embedding.

        # Special tokens from the finetuned model
        model_weights = model.bert.get_input_embeddings().get_first_weigths()

        # Other tokens (language specific) from the lexical aligned model
        lexical_embeddings = lexical_model.bert.get_input_embeddings()
        target_weights = lexical_embeddings.get_second_weigths()

        tobe = SlicedEmbedding(model_weights, target_weights, True,
                               True)  # For testing, both are frozen

        model.bert.set_input_embeddings(tobe)

    return model
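
# Hypothetical usage; both checkpoint paths are placeholders. The lexical
# checkpoint is only consulted when the finetuned model was trained with a
# lexical strategy other than 'none'.
model = prepare_model_for_testing(
    'checkpoints/nli-finetuned.ckpt',
    lexical_checkpoint='checkpoints/lexical-aligned.ckpt')
model.eval()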

    def test_slice(self):
        original = Embedding(10, 10)
        cut = 5

        sliced = SlicedEmbedding.slice(original, cut, True, False)

        self.assertTrue(type(sliced) is SlicedEmbedding)

        first = sliced.first_embedding
        second = sliced.second_embedding

        self.assertFalse(first.weight.requires_grad)
        self.assertTrue(second.weight.requires_grad)

        self.assertTrue(
            torch.all(original.weight[:cut] == first.weight).item())

        self.assertTrue(
            torch.all(original.weight[cut:] == second.weight).item())

    def test_freeze_final_positions(self):
        """
        We emulate a simple training loop to check whether embeddings
        are being updated correctly.
        """
        original_embedding = torch.nn.Embedding(10, 2)
        sliced_embedding = SlicedEmbedding.slice(original_embedding, 5, False,
                                                 True)

        # We clone the original weights, since they are updated in place
        original_values = original_embedding.weight.clone()

        data = torch.tensor([[[0, 1, 5], [2, 3, 6]], [[0, 1, 9], [4, 5, 7]]])
        labels = torch.tensor([0, 1])  # fixed targets for every batch

        model = SimpleNetwork(sliced_embedding)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

        for _ in range(2):
            model.train()
            optimizer.zero_grad()

            for batch in data:
                outputs = model(batch)
                loss = torch.nn.functional.cross_entropy(outputs, labels)

                loss.backward()
                optimizer.step()

        trainable_original = original_values[:5]
        frozen_original = original_values[5:]

        trainable_actual = model.embeddings.first_embedding.weight
        frozen_actual = model.embeddings.second_embedding.weight

        self.assertTrue(torch.all(frozen_original == frozen_actual).item())
        self.assertFalse(
            torch.all(trainable_original == trainable_actual).item())
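
# Hypothetical stand-in for the SimpleNetwork helper used above; the real
# helper is defined elsewhere, so the architecture here is an assumption.
# It embeds token ids, mean-pools over the sequence, and emits two-class
# logits, matching the (batch, seq) inputs and integer labels in the test.
import torch


class SimpleNetwork(torch.nn.Module):
    def __init__(self, embeddings, embedding_dim=2, num_classes=2):
        super().__init__()
        self.embeddings = embeddings  # the SlicedEmbedding under test
        self.classifier = torch.nn.Linear(embedding_dim, num_classes)

    def forward(self, input_ids):
        hidden = self.embeddings(input_ids)  # (batch, seq, dim)
        pooled = hidden.mean(dim=1)          # (batch, dim)
        return self.classifier(pooled)       # (batch, num_classes) logits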
    def __setup_lexical_for_training(self):
        if self.hparams.train_strategy == 'train-all-lexical':
            # We freeze all parameters in this model,
            # then unfreeze only the ones we want to train.
            self._freeze_parameters()

            # Train Word Embeddings Only
            input_embeddings = self.bert.get_input_embeddings()

            for parameter in input_embeddings.parameters():
                parameter.requires_grad = True

            # We also train the head (output embeddings), since they are tied
            output_embeddings = self.bert.get_output_embeddings()

            for parameter in output_embeddings.parameters():
                parameter.requires_grad = True

        elif self.hparams.train_strategy == 'train-non-special':
            self._freeze_parameters()

            # Train Word Embeddings Only, skipping special tokens
            input_embeddings = self.bert.get_input_embeddings()
            last_special_token = max(self.tokenizer.all_special_ids)

            new_input_embeddings = SlicedEmbedding.slice(
                input_embeddings, last_special_token + 1, True, False)

            self.bert.set_input_embeddings(new_input_embeddings)

            # Handling output embeddings
            output_embeddings = self.bert.get_output_embeddings()

            new_output_embeddings = SlicedOutputEmbedding(
                output_embeddings, last_special_token + 1, True, False)

            self.bert.cls.predictions.decoder = new_output_embeddings
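
# Illustrative only: how a cut point like the one above is derived. With a
# standard BERT vocabulary ([PAD]=0, [UNK]=100, [CLS]=101, [SEP]=102,
# [MASK]=103), max(all_special_ids) + 1 == 104, so rows 0-103 stay frozen
# while the remaining (language-specific) rows train. Note this relies on
# the special tokens occupying the lowest rows of the vocabulary.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-multilingual-cased')
cut = max(tokenizer.all_special_ids) + 1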

    def test_slice_again_different_cut(self):
        original = Embedding(10, 10)
        already_sliced = SlicedEmbedding.slice(original, 5, True, False)

        with self.assertRaises(NotImplementedError):
            SlicedEmbedding.slice(already_sliced, 6, True, False)