Example #1
    def load_state_dict(self, state_dict, strict=True, args=None):
        """Copies parameters and buffers from *state_dict* into this module and
        its descendants.

        Overrides the method in :class:`nn.Module`. Compared with that method
        this additionally "upgrades" *state_dicts* from old checkpoints.
        """
        self.upgrade_state_dict(state_dict)
        new_state_dict = prune_state_dict(state_dict, args)
        return super().load_state_dict(new_state_dict, strict)
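
To see the override pattern in isolation, here is a minimal, self-contained sketch in plain PyTorch (no fairseq): a module whose upgrade_state_dict rewrites a legacy key name in place before delegating to nn.Module.load_state_dict. The legacy name 'fc' and the rename rule are invented for illustration.

import torch
import torch.nn as nn

class UpgradingModule(nn.Module):
    """Toy module that 'upgrades' old checkpoints before loading them."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    def upgrade_state_dict(self, state_dict):
        # Hypothetical legacy rename: old checkpoints stored 'proj' as 'fc'.
        for old_key in [k for k in state_dict if k.startswith("fc.")]:
            state_dict["proj." + old_key[len("fc."):]] = state_dict.pop(old_key)

    def load_state_dict(self, state_dict, strict=True):
        self.upgrade_state_dict(state_dict)
        return super().load_state_dict(state_dict, strict=strict)

old_ckpt = {"fc.weight": torch.zeros(4, 4), "fc.bias": torch.zeros(4)}
UpgradingModule().load_state_dict(old_ckpt)  # loads despite the legacy key names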
Example #2
    def load_state_dict(
        self,
        state_dict,
        strict=True,
        model_cfg: Optional[DictConfig] = None,
        args: Optional[Namespace] = None,
    ):

        if model_cfg is None and args is not None:
            logger.warning(
                "using 'args' is deprecated, please update your code to use dataclass config"
            )
            model_cfg = convert_namespace_to_omegaconf(args).model

        self.upgrade_state_dict(state_dict)

        from fairseq.checkpoint_utils import prune_state_dict

        new_state_dict = prune_state_dict(state_dict, model_cfg)
        if not model_cfg.ape:
            # 'ape' disabled: adapt the checkpoint's position table to the
            # model's sequence length instead of requiring an exact match.
            model_seq_len = self.state_dict()['encoder.deit.pos_embed'].shape[1]
            ckpt_seq_len = new_state_dict['encoder.deit.pos_embed'].shape[1]
            logger.warning(
                'Resizing encoder.deit.pos_embed from {:d} positions '
                '(checkpoint) to {:d} (model)'.format(ckpt_seq_len, model_seq_len))
            if model_seq_len <= ckpt_seq_len:
                # Model is shorter or equal: truncate the checkpoint table.
                new_state_dict['encoder.deit.pos_embed'] = new_state_dict[
                    'encoder.deit.pos_embed'][:, :model_seq_len, :]
            else:
                # Model is longer: copy the checkpoint into the model's
                # freshly initialized table and keep the remainder as-is.
                t = self.state_dict()['encoder.deit.pos_embed']
                t[:, :ckpt_seq_len, :] = new_state_dict['encoder.deit.pos_embed']
                new_state_dict['encoder.deit.pos_embed'] = t

        if getattr(model_cfg, "reset_dictionary", False):
            logger.info(
                'Resetting token embedding weights and output projection '
                'while loading the pretrained model'
            )
            del new_state_dict['decoder.embed_tokens.weight']
            del new_state_dict['decoder.output_projection.weight']

        # Load non-strictly: keys may have been removed or resized above.
        return super().load_state_dict(new_state_dict, strict=False)
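
The position-embedding logic above is easier to follow as a pure function. This standalone sketch restates it; the shapes in the usage lines are illustrative, not taken from the snippet.

import torch

def resize_pos_embed(ckpt_pos, model_pos):
    """Match a checkpoint position table to the model's sequence length,
    mirroring Example #2: truncate when the model is shorter, otherwise
    copy the checkpoint into the model's freshly initialized table."""
    model_len, ckpt_len = model_pos.shape[1], ckpt_pos.shape[1]
    if model_len <= ckpt_len:
        return ckpt_pos[:, :model_len, :]
    out = model_pos.clone()
    out[:, :ckpt_len, :] = ckpt_pos
    return out

ckpt = torch.randn(1, 197, 768)   # illustrative checkpoint length
model = torch.randn(1, 577, 768)  # illustrative (longer) model length
assert resize_pos_embed(ckpt, model).shape == (1, 577, 768)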
Example #3
    def load_state_dict(self, state_dict, strict=True, args=None):
        """Copies parameters and buffers from *state_dict* into this module and
        its descendants.

        Overrides the method in :class:`nn.Module`. Compared with that method
        this additionally "upgrades" *state_dicts* from old checkpoints.
        """
        self.upgrade_state_dict(state_dict)
        new_state_dict = prune_state_dict(state_dict, args)
        changed_state_dict = new_state_dict.copy()

        # Previously, a fixed set of structure-attention keys was remapped
        # under the 'encoder.' prefix; kept here, commented out, for reference:
        # for key in [
        #         'structure_att.tp_linear.weight', 'structure_att.tp_linear.bias',
        #         'structure_att.tc_linear.weight', 'structure_att.tc_linear.bias',
        #         'structure_att.fi_linear.weight', 'structure_att.bilinear._weight_matrix',
        #         'structure_att.bilinear._bias', 'structure_att.fzlinear.weight',
        #         'structure_att.fzlinear.bias', 'str_to_enc_linear.weight',
        #         'str_to_enc_linear.bias', 'structure_att.exparam']:
        #     changed_state_dict['encoder.' + key] = new_state_dict[key]
        #     del changed_state_dict[key]
        return super().load_state_dict(changed_state_dict, strict)
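
The commented-out block amounts to moving a fixed set of keys under an 'encoder.' prefix. A small helper sketch that does the same thing; the key list would come from the snippet above, and the usage line uses a dummy value.

def prefix_keys(state_dict, keys, prefix="encoder."):
    """Return a copy of *state_dict* with *prefix* prepended to *keys*."""
    out = dict(state_dict)
    for key in keys:
        if key in out:
            out[prefix + key] = out.pop(key)
    return out

remapped = prefix_keys({"structure_att.exparam": 0}, ["structure_att.exparam"])
assert "encoder.structure_att.exparam" in remapped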
Example #4
    def load_state_dict(
        self,
        state_dict,
        strict=True,
        model_cfg: Optional[DictConfig] = None,
        args: Optional[Namespace] = None,
    ):
        """Copies parameters and buffers from *state_dict* into this module and
        its descendants.

        Overrides the method in :class:`nn.Module`. Compared with that method
        this additionally "upgrades" *state_dicts* from old checkpoints.
        """

        if model_cfg is None and args is not None:
            logger.warn("using 'args' is deprecated, please update your code to use dataclass config")
            model_cfg = convert_namespace_to_omegaconf(args).model

        self.upgrade_state_dict(state_dict)
        new_state_dict = prune_state_dict(state_dict, model_cfg)
        return super().load_state_dict(new_state_dict, strict)
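
Examples #2 and #4 share the same deprecation shim for the legacy argparse path. Below is a simplified, dependency-free sketch of that pattern; vars(args) stands in for fairseq's convert_namespace_to_omegaconf(args).model, which returns an OmegaConf node rather than a plain dict.

from argparse import Namespace
import warnings

def resolve_model_cfg(model_cfg, args):
    """Prefer the new config object; fall back to legacy argparse args."""
    if model_cfg is None and args is not None:
        warnings.warn("'args' is deprecated; pass a dataclass config instead")
        model_cfg = vars(args)  # simplified stand-in for the OmegaConf conversion
    return model_cfg

cfg = resolve_model_cfg(None, Namespace(dropout=0.1))
assert cfg["dropout"] == 0.1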