def __init__(
    self,
    model: SentenceTransformer,
    decoder_name_or_path: str = 'bert-base-uncased',
    tie_encoder_decoder: bool = True
):
    """
    :param model: SentenceTransformer model
    :param decoder_name_or_path: Model name or path for initializing a decoder (compatible with Hugging Face Transformers)
    :param tie_encoder_decoder: whether to tie the trainable parameters of encoder and decoder
    """
    super(DenoisingAutoEncoderLoss, self).__init__()
    self.encoder = model  # This will be the final model used at inference time.
    self.tokenizer_encoder = model.tokenizer
    self.tokenizer_decoder = AutoTokenizer.from_pretrained(decoder_name_or_path)
    # If the encoder and decoder tokenizers differ, decoder inputs must be re-tokenized.
    self.need_retokenization = type(self.tokenizer_encoder) is not type(self.tokenizer_decoder)

    # Configure the decoder as a causal LM with cross-attention over the encoder output.
    decoder_config = AutoConfig.from_pretrained(decoder_name_or_path)
    decoder_config.is_decoder = True
    decoder_config.add_cross_attention = True
    kwargs_decoder = {'config': decoder_config}
    self.decoder = AutoModelForCausalLM.from_pretrained(decoder_name_or_path, **kwargs_decoder)
    assert model[0].auto_model.config.hidden_size == decoder_config.hidden_size, 'Hidden sizes do not match!'

    if self.tokenizer_decoder.pad_token is None:
        # Needed by GPT-2, etc.
        self.tokenizer_decoder.pad_token = self.tokenizer_decoder.eos_token
        self.decoder.config.pad_token_id = self.decoder.config.eos_token_id

    # Weight tying is only possible when both sides use the same tokenizer.
    if tie_encoder_decoder and not self.need_retokenization:
        decoder_base_model_prefix = self.decoder.base_model_prefix
        PreTrainedModel._tie_encoder_decoder_weights(
            model[0].auto_model,
            self.decoder._modules[decoder_base_model_prefix],
            self.decoder.base_model_prefix
        )
def __init__(
    self,
    model: SentenceTransformer,
    decoder_name_or_path: str = None,
    tie_encoder_decoder: bool = True
):
    """
    :param model: SentenceTransformer model
    :param decoder_name_or_path: Model name or path for initializing a decoder (compatible with Hugging Face Transformers)
    :param tie_encoder_decoder: whether to tie the trainable parameters of encoder and decoder
    """
    super(DenoisingAutoEncoderLoss, self).__init__()
    self.encoder = model  # This will be the final model used at inference time.
    self.tokenizer_encoder = model.tokenizer

    encoder_name_or_path = model[0].auto_model.config._name_or_path
    if decoder_name_or_path is None:
        assert tie_encoder_decoder, "Must specify decoder_name_or_path when tie_encoder_decoder=False!"
    if tie_encoder_decoder:
        if decoder_name_or_path:
            logger.warning('When tie_encoder_decoder=True, decoder_name_or_path will be ignored.')
        decoder_name_or_path = encoder_name_or_path

    self.tokenizer_decoder = AutoTokenizer.from_pretrained(decoder_name_or_path)
    # If the encoder and decoder tokenizers differ, decoder inputs must be re-tokenized.
    self.need_retokenization = type(self.tokenizer_encoder) is not type(self.tokenizer_decoder)

    # Configure the decoder as a causal LM with cross-attention over the encoder output.
    decoder_config = AutoConfig.from_pretrained(decoder_name_or_path)
    decoder_config.is_decoder = True
    decoder_config.add_cross_attention = True
    kwargs_decoder = {'config': decoder_config}
    try:
        self.decoder = AutoModelForCausalLM.from_pretrained(decoder_name_or_path, **kwargs_decoder)
    except ValueError as e:
        logger.error(
            f'Model name or path "{decoder_name_or_path}" does not support being used as a decoder. '
            'Please make sure the decoder model has an "XXXLMHead" class.'
        )
        raise e
    assert model[0].auto_model.config.hidden_size == decoder_config.hidden_size, 'Hidden sizes do not match!'

    if self.tokenizer_decoder.pad_token is None:
        # Needed by GPT-2, etc.
        self.tokenizer_decoder.pad_token = self.tokenizer_decoder.eos_token
        self.decoder.config.pad_token_id = self.decoder.config.eos_token_id

    if len(AutoTokenizer.from_pretrained(encoder_name_or_path)) != len(self.tokenizer_encoder):
        logger.warning(
            'The vocabulary of the encoder has been changed. One might need to change the decoder vocabulary, too.'
        )

    if tie_encoder_decoder:
        assert not self.need_retokenization, "The tokenizers should be the same when tie_encoder_decoder=True."
        if len(self.tokenizer_encoder) != len(self.tokenizer_decoder):  # The vocabulary has been changed.
            self.tokenizer_decoder = self.tokenizer_encoder
            self.decoder.resize_token_embeddings(len(self.tokenizer_decoder))
            logger.warning(
                'Since the encoder vocabulary has been changed and tie_encoder_decoder=True, '
                'the new vocabulary has also been used for the decoder.'
            )
        decoder_base_model_prefix = self.decoder.base_model_prefix
        PreTrainedModel._tie_encoder_decoder_weights(
            model[0].auto_model,
            self.decoder._modules[decoder_base_model_prefix],
            self.decoder.base_model_prefix
        )
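# --- Usage sketch (illustrative, not part of the original source) ---
# A minimal example of wiring this loss into TSDAE-style training with
# sentence-transformers. The model name, toy sentences, and training
# settings below are assumptions for illustration only.
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, models, datasets, losses

# Build an encoder whose pooled sentence embedding conditions the decoder.
word_embedding = models.Transformer('bert-base-uncased')
pooling = models.Pooling(word_embedding.get_word_embedding_dimension(), 'cls')
model = SentenceTransformer(modules=[word_embedding, pooling])

train_sentences = ['A sentence to denoise.', 'Another example sentence.']
# DenoisingAutoEncoderDataset pairs each sentence with a noisy (token-deleted) copy.
train_dataset = datasets.DenoisingAutoEncoderDataset(train_sentences)
train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)

# With tie_encoder_decoder=True the decoder reuses the encoder's weights,
# so decoder_name_or_path can be omitted.
train_loss = losses.DenoisingAutoEncoderLoss(model, tie_encoder_decoder=True)

model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=1,
    weight_decay=0,
    scheduler='constantlr'
)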