from typing import Dict, Union

import torch
from transformers import PreTrainedModel, PreTrainedTokenizer

# `testing_lexical_strategies`, `new_like` and `SlicedEmbedding` are assumed to be
# defined (or imported) elsewhere in this project.


def setup_lexical_for_testing(strategy: str,
                              model: PreTrainedModel,
                              tokenizer: PreTrainedTokenizer,
                              target_lexical: Union[str, Dict]):
    """
    Set up the model's lexical part (input embeddings) according to `strategy`.

    Strategy can be:
    - original: the model's lexical is used as is.
    - target-lexical: the lexical located at `target_lexical` replaces the
      model's, including special tokens.
    - target-lexical-original-special: the lexical located at `target_lexical`
      is used, but the model's embeddings for special tokens are preserved.
    """
    assert strategy in testing_lexical_strategies
    assert model is not None
    assert tokenizer is not None

    if strategy == 'original':
        return  # Nothing to do here

    # `target_lexical` may also be given as a path to a saved embedding state dict.
    if isinstance(target_lexical, str):
        target_lexical = torch.load(target_lexical)

    if strategy == 'target-lexical':
        model.set_input_embeddings(
            new_like(model.get_input_embeddings(), target_lexical))
    elif strategy == 'target-lexical-original-special':
        assert model.get_input_embeddings().embedding_dim == \
            target_lexical['weight'].shape[1]
        # We cut right after the last known special token, so we take the
        # highest special-token index + 1 (for instance, if the last one is
        # 103, we get 104). This is because SlicedEmbedding cuts on [:cut].
        bert_special_tokens_cut = sorted(tokenizer.all_special_ids)[-1] + 1
        model_weights = model.get_input_embeddings().weight
        target_weights = target_lexical['weight']
        # For testing, both parts are frozen.
        tobe = SlicedEmbedding(model_weights[:bert_special_tokens_cut],
                               target_weights[bert_special_tokens_cut:],
                               True, True)
        model.set_input_embeddings(tobe)
    else:
        raise NotImplementedError(f'strategy {strategy} is not implemented')
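# Usage sketch (illustration only; the checkpoint name and the path below are
# placeholders, not part of the original code). `target_lexical` can be either an
# already-loaded dict holding a 'weight' tensor (e.g. an nn.Embedding state dict)
# or a path to one saved with torch.save:
#
#   model = AutoModel.from_pretrained('bert-base-multilingual-cased')
#   tokenizer = AutoTokenizer.from_pretrained('bert-base-multilingual-cased')
#   setup_lexical_for_testing('target-lexical-original-special',
#                             model, tokenizer, 'path/to/target_lexical.pt')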

def slice_lexical_embedding(model: PreTrainedModel,
                            tokenizer: PreTrainedTokenizer,
                            freeze_first: bool,
                            freeze_second: bool):
    # We cut right after the last known special token, so we take the highest
    # special-token index + 1 (for instance, if the last one is 103, we get
    # 104). This is because SlicedEmbedding cuts on [:cut].
    bert_special_tokens_cut = max(tokenizer.all_special_ids) + 1
    original_embeddings = model.get_input_embeddings()
    tobe_embeddings = SlicedEmbedding.slice(original_embeddings,
                                            bert_special_tokens_cut,
                                            freeze_first_part=freeze_first,
                                            freeze_second_part=freeze_second)
    model.set_input_embeddings(tobe_embeddings)
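
# Minimal sketch of the embedding-slicing idea used above (illustration only; the
# project's actual `SlicedEmbedding` is defined elsewhere and may differ). Token
# ids below `cut` are looked up in the first weight block, the remaining ids in
# the second block, and each block can be frozen independently.
class _SlicedEmbeddingSketch(torch.nn.Module):
    def __init__(self, first_weight: torch.Tensor, second_weight: torch.Tensor,
                 freeze_first: bool, freeze_second: bool):
        super().__init__()
        self.cut = first_weight.shape[0]
        self.embedding_dim = first_weight.shape[1]
        self.first = torch.nn.Parameter(first_weight.detach().clone(),
                                        requires_grad=not freeze_first)
        self.second = torch.nn.Parameter(second_weight.detach().clone(),
                                         requires_grad=not freeze_second)

    @classmethod
    def slice(cls, embedding: torch.nn.Embedding, cut: int,
              freeze_first_part: bool, freeze_second_part: bool):
        # Split an existing embedding at `cut` into two independently freezable parts.
        weight = embedding.weight
        return cls(weight[:cut], weight[cut:], freeze_first_part, freeze_second_part)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        # Concatenating the two blocks restores a full-vocabulary lookup table;
        # gradients only reach the block(s) that were not frozen.
        weight = torch.cat([self.first, self.second], dim=0)
        return torch.nn.functional.embedding(input_ids, weight)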

def freeze_all_tokens(model: PreTrainedModel, tokenizer: PreTrainedTokenizer):
    # Freeze the entire input-embedding matrix (`tokenizer` is unused here).
    model.get_input_embeddings().weight.requires_grad = False
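
# Usage sketch (assumptions: a BERT-style checkpoint is used; the model name below
# is a placeholder and not prescribed by the original code).
if __name__ == '__main__':
    from transformers import AutoModel, AutoTokenizer

    model = AutoModel.from_pretrained('bert-base-multilingual-cased')
    tokenizer = AutoTokenizer.from_pretrained('bert-base-multilingual-cased')

    # Keep the rows up to (and including) the last special token trainable and
    # freeze every row above them.
    slice_lexical_embedding(model, tokenizer, freeze_first=False, freeze_second=True)

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'trainable parameters after slicing: {trainable}')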