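# Module-level context assumed by the constructor below (a sketch; the actual
# file defines its imports and logger at the top):
import logging
from typing import Optional

from sentence_transformers import SentenceTransformer
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, PreTrainedModel

logger = logging.getLogger(__name__)
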
def __init__(self,
             model: SentenceTransformer,
             decoder_name_or_path: Optional[str] = None,
             tie_encoder_decoder: bool = True):
        """
        :param model: SentenceTransformer model
        :param decoder_name_or_path: Model name or path for initializing a decoder (compatible with Huggingface's Transformers)
        :param tie_encoder_decoder: whether to tie the trainable parameters of encoder and decoder
        """
        super(DenoisingAutoEncoderLoss, self).__init__()
        self.encoder = model  # This is the final model used at inference time.
        self.tokenizer_encoder = model.tokenizer

        encoder_name_or_path = model[0].auto_model.config._name_or_path
        if decoder_name_or_path is None:
            assert tie_encoder_decoder, 'Must specify decoder_name_or_path when tie_encoder_decoder=False!'
        if tie_encoder_decoder:
            if decoder_name_or_path:
                logger.warning(
                    'When tie_encoder_decoder=True, the decoder_name_or_path argument is ignored; '
                    'the encoder checkpoint is reused as the decoder.'
                )
            # Weight tying requires encoder and decoder to share the same architecture and vocabulary.
            decoder_name_or_path = encoder_name_or_path

        self.tokenizer_decoder = AutoTokenizer.from_pretrained(decoder_name_or_path)
        # If the two tokenizers differ, decoder inputs and labels must be
        # re-tokenized with the decoder's own tokenizer in the forward pass.
        self.need_retokenization = type(self.tokenizer_encoder) is not type(self.tokenizer_decoder)

        decoder_config = AutoConfig.from_pretrained(decoder_name_or_path)
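        # Turn the checkpoint into a decoder: causal self-attention masking plus
        # cross-attention layers that attend to the encoder's sentence embedding.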
        decoder_config.is_decoder = True
        decoder_config.add_cross_attention = True
        kwargs_decoder = {'config': decoder_config}
        try:
            self.decoder = AutoModelForCausalLM.from_pretrained(decoder_name_or_path, **kwargs_decoder)
        except ValueError as e:
            logger.error(
                f'Model name or path "{decoder_name_or_path}" cannot be used as a decoder. '
                'Please make sure the model has a causal LM head (an "XXXLMHead" class) in Transformers.'
            )
            raise e
        # The decoder cross-attends to the encoder's sentence embedding, so the
        # hidden sizes must match.
        assert model[0].auto_model.config.hidden_size == decoder_config.hidden_size, 'Hidden sizes do not match!'
        if self.tokenizer_decoder.pad_token is None:
            # Needed by GPT-2, etc.
            self.tokenizer_decoder.pad_token = self.tokenizer_decoder.eos_token
            self.decoder.config.pad_token_id = self.decoder.config.eos_token_id

        if len(AutoTokenizer.from_pretrained(encoder_name_or_path)) != len(self.tokenizer_encoder):
            logger.warning(
                'The encoder vocabulary has been changed (e.g. tokens were added). '
                'The decoder vocabulary might need to be changed accordingly.'
            )

        if tie_encoder_decoder:
            assert not self.need_retokenization, 'The tokenizers should be the same when tie_encoder_decoder=True.'
            if len(self.tokenizer_encoder) != len(self.tokenizer_decoder):
                # The encoder vocabulary has changed (e.g. tokens were added);
                # mirror the change on the decoder side before tying weights.
                self.tokenizer_decoder = self.tokenizer_encoder
                self.decoder.resize_token_embeddings(len(self.tokenizer_decoder))
                logger.warning(
                    'Since the encoder vocabulary has been changed and tie_encoder_decoder=True, '
                    'the new vocabulary is now also used for the decoder.'
                )
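            # Tying shares the base transformer weights between encoder and
            # decoder; only the cross-attention layers and the LM head remain
            # decoder-specific trainable parameters.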
            decoder_base_model_prefix = self.decoder.base_model_prefix
            PreTrainedModel._tie_encoder_decoder_weights(
                model[0].auto_model,
                self.decoder._modules[decoder_base_model_prefix],
                decoder_base_model_prefix
            )
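
# Minimal usage sketch (illustrative, not part of the class above): TSDAE-style
# training of a SentenceTransformer with this loss. The toy corpus and
# hyperparameters are assumptions; DenoisingAutoEncoderDataset applies deletion
# noise to produce (corrupted, original) sentence pairs.
if __name__ == '__main__':
    from torch.utils.data import DataLoader
    from sentence_transformers import losses
    from sentence_transformers.datasets import DenoisingAutoEncoderDataset

    model = SentenceTransformer('bert-base-uncased')
    train_sentences = [
        'TSDAE reconstructs a sentence from its corrupted version.',
        'The encoder produces a single embedding per sentence.',
    ]
    train_dataset = DenoisingAutoEncoderDataset(train_sentences)
    train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)
    train_loss = losses.DenoisingAutoEncoderLoss(model, tie_encoder_decoder=True)
    model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=1)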