def _resize_token_embeddings(self, state_dict, msg=None): """ Resize the token embeddings when are adding extra special tokens. Modify TGA._resize_token_embeddings to access correct modules within RAG. """ # map extra special tokens carefully new_size = self.model.embeddings.weight.size()[0] orig_size = state_dict['embeddings.weight'].size()[0] logging.info( f'Resizing token embeddings from {orig_size} to {new_size}') if new_size <= orig_size: # new size should be greater than original size, # as we are adding special tokens raise RuntimeError(msg) for emb_weights in [ 'embeddings.weight', 'seq2seq_encoder.embeddings.weight', 'seq2seq_decoder.embeddings.weight', ]: # get new_embs old_embs = state_dict[emb_weights] new_embs = recursive_getattr(self.model, emb_weights).to(old_embs.device) # copy over old weights new_embs.data[:orig_size, :] = old_embs.data[:orig_size, :] # reset in state dict state_dict[emb_weights] = new_embs return state_dict
def _resize_token_embeddings(self, state_dict, msg=None): """ Resize the token embeddings when adding extra special tokens. H/t TransformerGenerator._resize_token_embeddings for inspiration. """ # map extra special tokens carefully new_size = self.model.encoder_ctxt.embeddings.weight.size()[0] orig_size = state_dict['encoder_ctxt.embeddings.weight'].size()[0] logging.info(f'Resizing token embeddings from {orig_size} to {new_size}') if new_size <= orig_size: # new size should be greater than original size, # as we are adding special tokens raise RuntimeError(msg) for emb_weights in [ 'encoder_ctxt.embeddings.weight', 'encoder_cand.embeddings.weight', ]: # get new_embs old_embs = state_dict[emb_weights] new_embs = recursive_getattr(self.model, emb_weights).to(old_embs.device) # copy over old weights new_embs.data[:orig_size, :] = old_embs.data[:orig_size, :] # reset in state dict state_dict[emb_weights] = new_embs return state_dict
def _resize_token_embeddings(self, state_dict, msg=None): """ Resize the token embeddings when are adding extra special tokens. Switch to `base_model`. """ # map extra special tokens carefully key = self.base_model_key base_model = getattr(self.model, key) new_size = base_model.embeddings.weight.size()[0] orig_size = state_dict[f'{key}.embeddings.weight'].size()[0] logging.info(f'Resizing token embeddings from {orig_size} to {new_size}') if new_size <= orig_size: # new size should be greater than original size, # as we are adding special tokens raise RuntimeError(msg) for emb_weights in [ 'embeddings.weight', 'encoder.embeddings.weight', 'decoder.embeddings.weight', ]: # get new_embs old_embs = state_dict[f"{key}.{emb_weights}"] new_embs = recursive_getattr(base_model, emb_weights).to(old_embs.device) # copy over old weights new_embs.data[:orig_size, :] = old_embs.data[:orig_size, :] # reset in state dict state_dict[f"{key}.{emb_weights}"] = new_embs if self.opt['copy_cand_embeddings']: state_dict = self.update_cand_embeddings(state_dict) return state_dict
def load_dpr_model(bert_model: BertModel, pretrained_dpr_path: str, encoder_type: str):
    """
    Load saved state from pretrained DPR model directly into given bert_model.

    :param bert_model: bert model to load
    :param pretrained_dpr_path: path to pretrained DPR BERT Model
    :param encoder_type: whether we're loading a document or query encoder.
    """
    saved_state = torch.load(pretrained_dpr_path, map_location='cpu')
    model_to_load = (
        bert_model.module if hasattr(bert_model, 'module') else bert_model
    )

    prefix = 'question_model.' if encoder_type == 'query' else 'ctx_model.'
    prefix_len = len(prefix)
    encoder_state = {
        key[prefix_len:]: value
        for (key, value) in saved_state['model_dict'].items()
        if key.startswith(prefix)
    }
    encoder_state.update(
        {k: v for k, v in saved_state['model_dict'].items() if 'encode_proj' in k}
    )
    try:
        model_to_load.load_state_dict(encoder_state)
    except RuntimeError:
        for key in BERT_COMPATIBILITY_KEYS:
            encoder_state[key] = recursive_getattr(model_to_load, key)
        model_to_load.load_state_dict(encoder_state)
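# A rough usage sketch for load_dpr_model. The checkpoint path below is
# illustrative only, and the 'query' encoder type is just one of the two valid
# options; BertModel is assumed to come from Hugging Face transformers, and
# BERT_COMPATIBILITY_KEYS / recursive_getattr are assumed to be defined
# alongside the function above.
from transformers import BertModel

bert = BertModel.from_pretrained('bert-base-uncased')
load_dpr_model(
    bert_model=bert,
    pretrained_dpr_path='/path/to/dpr_biencoder.checkpoint',  # hypothetical path
    encoder_type='query',
)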