def setup_lexical_for_testing(strategy: str, model: PreTrainedModel,
                              tokenizer: PreTrainedTokenizer,
                              target_lexical: Union[str, Dict]):
    """
    Setup the model lexical part according to `strategy`.
    Strategy can be:

    - original: The model's lexical will be used as is.
    - target_lexical_full: The lexical located at `target_lexical`
                           will be used in model\'s, including special tokens.
    - target_lexical_keep_special: The lexical located at `target_lexical`
                                   will be used, but the model\'s
                                   embeddings for special tokens
                                   will be preserved.
    """
    assert strategy in testing_lexical_strategies
    assert model is not None
    assert tokenizer is not None

    if strategy == 'original':
        return  # Nothing to do here

    if isinstance(target_lexical, str):
        target_lexical = torch.load(target_lexical)

    if strategy == 'target-lexical':
        model.set_input_embeddings(
            new_like(model.get_input_embeddings(), target_lexical))
    elif strategy == 'target-lexical-original-special':
        assert model.get_input_embeddings().embedding_dim == \
            target_lexical['weight'].shape[1]

        # We cut right after the last known special token,
        # so we take its index + 1
        # (for instance, if the last special id is 103, we cut at 104).
        # This is because SlicedEmbedding slices with [:cut].
        bert_special_tokens_cut = max(tokenizer.all_special_ids) + 1

        model_weights = model.get_input_embeddings().weight
        target_weights = target_lexical['weight']

        tobe = SlicedEmbedding(model_weights[:bert_special_tokens_cut],
                               target_weights[bert_special_tokens_cut:], True,
                               True)  # For testing, both parts are frozen

        model.set_input_embeddings(tobe)
    else:
        raise NotImplementedError(f'strategy {strategy} is not implemented')
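

# Example usage of `setup_lexical_for_testing` (illustrative sketch; the model
# name and the checkpoint path below are placeholders, not part of this module,
# and the checkpoint is assumed to be an embedding state dict saved with
# torch.save):
#
#     from transformers import AutoModelForMaskedLM, AutoTokenizer
#
#     model = AutoModelForMaskedLM.from_pretrained('bert-base-uncased')
#     tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
#     setup_lexical_for_testing('target-lexical-original-special', model,
#                               tokenizer, 'path/to/target_lexical.pt')

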
def slice_lexical_embedding(model: PreTrainedModel,
                            tokenizer: PreTrainedTokenizer, freeze_first: bool,
                            freeze_second: bool):
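    """
    Replace the model's input embedding with a SlicedEmbedding split right
    after the last special token id: `freeze_first` controls the part holding
    ids up to and including the last special token, `freeze_second` controls
    the remaining vocabulary.
    """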
    # We cut right after the last known special token,
    # so we take its index + 1
    # (for instance, if the last special id is 103, we cut at 104).
    # This is because SlicedEmbedding slices with [:cut].
    bert_special_tokens_cut = max(tokenizer.all_special_ids) + 1
    original_embeddings = model.get_input_embeddings()

    tobe_embeddings = SlicedEmbedding.slice(original_embeddings,
                                            bert_special_tokens_cut,
                                            freeze_first_part=freeze_first,
                                            freeze_second_part=freeze_second)

    model.set_input_embeddings(tobe_embeddings)
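

# Example usage of `slice_lexical_embedding` (illustrative sketch; the model
# name is a placeholder, not part of this module). Here the special-token rows
# are frozen while the rest of the vocabulary stays trainable:
#
#     from transformers import AutoModelForMaskedLM, AutoTokenizer
#
#     model = AutoModelForMaskedLM.from_pretrained('bert-base-uncased')
#     tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
#     slice_lexical_embedding(model, tokenizer,
#                             freeze_first=True, freeze_second=False)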