Example #1
    def _resize_token_embeddings(self, state_dict, msg=None):
        """
        Resize the token embeddings when adding extra special tokens.

        Modified from TGA._resize_token_embeddings to access the correct modules
        within RAG.
        """
        # map extra special tokens carefully
        new_size = self.model.embeddings.weight.size()[0]
        orig_size = state_dict['embeddings.weight'].size()[0]
        logging.info(
            f'Resizing token embeddings from {orig_size} to {new_size}')
        if new_size <= orig_size:
            # new size should be greater than original size,
            # as we are adding special tokens
            raise RuntimeError(msg)

        for emb_weights in [
                'embeddings.weight',
                'seq2seq_encoder.embeddings.weight',
                'seq2seq_decoder.embeddings.weight',
        ]:
            # get new_embs
            old_embs = state_dict[emb_weights]
            new_embs = recursive_getattr(self.model,
                                         emb_weights).to(old_embs.device)
            # copy over old weights
            new_embs.data[:orig_size, :] = old_embs.data[:orig_size, :]
            # reset in state dict
            state_dict[emb_weights] = new_embs

        return state_dict
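The core operation in this snippet (and in the variants that follow) is to allocate a larger embedding table and copy the pretrained rows into its first orig_size slots, leaving the newly added special-token rows at their fresh initialization. A minimal standalone sketch of that pattern, with made-up sizes purely for illustration:

import torch
import torch.nn as nn

orig_size, new_size, dim = 100, 104, 16   # hypothetical vocab sizes and embedding dim

old_emb = nn.Embedding(orig_size, dim)    # stands in for the checkpointed embeddings
new_emb = nn.Embedding(new_size, dim)     # freshly initialized, larger table

with torch.no_grad():
    # copy the pretrained rows; the extra special-token rows keep their random init
    new_emb.weight[:orig_size, :] = old_emb.weight

assert torch.equal(new_emb.weight[:orig_size], old_emb.weight)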
Example #2
    def _resize_token_embeddings(self, state_dict, msg=None):
        """
        Resize the token embeddings when adding extra special tokens.

        H/t TransformerGenerator._resize_token_embeddings for inspiration.
        """
        # map extra special tokens carefully
        new_size = self.model.encoder_ctxt.embeddings.weight.size()[0]
        orig_size = state_dict['encoder_ctxt.embeddings.weight'].size()[0]
        logging.info(f'Resizing token embeddings from {orig_size} to {new_size}')
        if new_size <= orig_size:
            # new size should be greater than original size,
            # as we are adding special tokens
            raise RuntimeError(msg)

        for emb_weights in [
            'encoder_ctxt.embeddings.weight',
            'encoder_cand.embeddings.weight',
        ]:
            # get new_embs
            old_embs = state_dict[emb_weights]
            new_embs = recursive_getattr(self.model, emb_weights).to(old_embs.device)
            # copy over old weights
            new_embs.data[:orig_size, :] = old_embs.data[:orig_size, :]
            # reset in state dict
            state_dict[emb_weights] = new_embs

        return state_dict
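Both snippets above resolve dotted parameter names such as 'encoder_ctxt.embeddings.weight' against the live model via recursive_getattr. A helper with that behavior can be sketched as a chain of getattr calls; the actual ParlAI utility may differ in details such as error handling:

from functools import reduce

def recursive_getattr(obj, path: str):
    # Sketch of a recursive_getattr-style helper, not the ParlAI implementation:
    # resolves 'a.b.c' as getattr(getattr(getattr(obj, 'a'), 'b'), 'c').
    return reduce(getattr, path.split('.'), obj)

# e.g. recursive_getattr(model, 'encoder_ctxt.embeddings.weight') returns the
# same tensor as model.encoder_ctxt.embeddings.weight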
Example #3
    def _resize_token_embeddings(self, state_dict, msg=None):
        """
        Resize the token embeddings when adding extra special tokens.

        Switch to `base_model`.
        """
        # map extra special tokens carefully
        key = self.base_model_key
        base_model = getattr(self.model, key)
        new_size = base_model.embeddings.weight.size()[0]
        orig_size = state_dict[f'{key}.embeddings.weight'].size()[0]
        logging.info(f'Resizing token embeddings from {orig_size} to {new_size}')
        if new_size <= orig_size:
            # new size should be greater than original size,
            # as we are adding special tokens
            raise RuntimeError(msg)

        for emb_weights in [
            'embeddings.weight',
            'encoder.embeddings.weight',
            'decoder.embeddings.weight',
        ]:
            # get new_embs
            old_embs = state_dict[f"{key}.{emb_weights}"]
            new_embs = recursive_getattr(base_model, emb_weights).to(old_embs.device)
            # copy over old weights
            new_embs.data[:orig_size, :] = old_embs.data[:orig_size, :]
            # reset in state dict
            state_dict[f"{key}.{emb_weights}"] = new_embs

        if self.opt['copy_cand_embeddings']:
            state_dict = self.update_cand_embeddings(state_dict)

        return state_dict
Example #4
    def load_dpr_model(bert_model: BertModel, pretrained_dpr_path: str,
                       encoder_type: str):
        """
        Load saved state from a pretrained DPR model directly into the given
        bert_model.

        :param bert_model:
            BERT model into which to load the saved state
        :param pretrained_dpr_path:
            path to the pretrained DPR BERT model
        :param encoder_type:
            whether we're loading a document or query encoder.
        """
        saved_state = torch.load(pretrained_dpr_path, map_location='cpu')
        model_to_load = (bert_model.module
                         if hasattr(bert_model, 'module') else bert_model)

        prefix = 'question_model.' if encoder_type == 'query' else 'ctx_model.'
        prefix_len = len(prefix)
        encoder_state = {
            key[prefix_len:]: value
            for (key, value) in saved_state['model_dict'].items()
            if key.startswith(prefix)
        }
        encoder_state.update({
            k: v
            for k, v in saved_state['model_dict'].items() if 'encode_proj' in k
        })
        try:
            model_to_load.load_state_dict(encoder_state)
        except RuntimeError:
            for key in BERT_COMPATIBILITY_KEYS:
                encoder_state[key] = recursive_getattr(model_to_load, key)
            model_to_load.load_state_dict(encoder_state)
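A possible call site for the loader above, assuming load_dpr_model is in scope as a plain function and using a placeholder checkpoint path; the Hugging Face BertModel import is shown only to make the sketch self-contained:

from transformers import BertModel

bert = BertModel.from_pretrained('bert-base-uncased')
# encoder_type selects which prefix ('question_model.' or 'ctx_model.') is
# stripped from the DPR checkpoint's 'model_dict'; the path is a placeholder.
load_dpr_model(bert, '/path/to/dpr_biencoder_checkpoint.cp', encoder_type='query')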