def upgrade_state_dict_with_xlm_weights(
        state_dict: Dict[str, Any],
        pretrained_xlm_checkpoint: str) -> Dict[str, Any]:
    """
    Copy matching XLM parameters into a Transformer encoder/decoder state dict.

    Args:
        state_dict: state dict for either TransformerEncoder or
            TransformerDecoder
        pretrained_xlm_checkpoint: checkpoint to load XLM weights from

    Returns:
        The (mutated) ``state_dict`` with XLM weights filled in.

    Raises:
        IOError: If ``pretrained_xlm_checkpoint`` does not exist on disk.
        AssertionError: If architecture (num layers, attention heads, etc.)
            does not match between the current Transformer encoder or
            decoder and the pretrained_xlm_checkpoint
    """
    if not os.path.exists(pretrained_xlm_checkpoint):
        raise IOError(f"Model file not found: {pretrained_xlm_checkpoint}")

    checkpoint = utils.load_checkpoint_to_cpu(pretrained_xlm_checkpoint)
    xlm_state_dict = checkpoint["model"]
    # Only parameter groups shared between XLM and a Transformer
    # encoder/decoder are transferred; everything else is ignored.
    for key in xlm_state_dict:
        for marker in ("embed_tokens", "embed_positions", "layers"):
            start = key.find(marker)
            if start < 0:
                continue
            # Strip the XLM-specific prefix so the key lines up with the
            # Transformer state dict's naming.
            subkey = key[start:]
            assert subkey in state_dict, (
                f"{str(state_dict.keys())} Transformer encoder / decoder "
                f"state_dict does not contain {subkey}. Cannot "
                f"load {key} from pretrained XLM checkpoint "
                f"{pretrained_xlm_checkpoint} into Transformer.")
            state_dict[subkey] = xlm_state_dict[key]
    return state_dict
# --- Example #2 (scraped sample divider; score: 0) ---
    def load_pretrained_model(path,
                              src_dict_path,
                              tgt_dict_path,
                              ctx_dict_path,
                              arg_overrides=None):
        """Build a context-aware translation model and load checkpoint weights.

        Args:
            path: checkpoint file to load weights and saved args from.
            src_dict_path: path to the source-side dictionary file.
            tgt_dict_path: path to the target-side dictionary file.
            ctx_dict_path: path to the context-side dictionary file.
            arg_overrides: optional dict of args to override over the
                checkpoint's saved args.

        Returns:
            The model built by :class:`TranslationContextTask` with the
            checkpoint weights loaded (``strict=True``).
        """
        checkpoint = utils.load_checkpoint_to_cpu(path)
        state_dict = checkpoint['model']
        args = utils.override_model_args(checkpoint['args'], arg_overrides)

        src_dict = Dictionary.load(src_dict_path)
        tgt_dict = Dictionary.load(tgt_dict_path)
        ctx_dict = Dictionary.load(ctx_dict_path)
        # All three vocabularies must agree on the special-symbol ids.
        assert src_dict.pad() == tgt_dict.pad() == ctx_dict.pad()
        assert src_dict.eos() == tgt_dict.eos() == ctx_dict.eos()
        assert src_dict.unk() == tgt_dict.unk() == ctx_dict.unk()

        task = TranslationContextTask(args, src_dict, tgt_dict, ctx_dict)
        model = task.build_model(args)
        model.upgrade_state_dict(state_dict)
        model.load_state_dict(state_dict, strict=True)
        return model
    def load_pretrained_model(path, arg_overrides=None):
        """Build a SMILES property-prediction model from a checkpoint.

        Args:
            path: checkpoint file to load weights and saved args from.
            arg_overrides: optional dict of args to override over the
                checkpoint's saved args.

        Returns:
            The model built by :class:`SmilePropertyPredictionTask` with the
            checkpoint weights loaded (``strict=True``).
        """
        checkpoint = utils.load_checkpoint_to_cpu(path)
        state_dict = checkpoint['model']
        args = utils.override_model_args(checkpoint['args'], arg_overrides)

        # Pick the vocabulary matching the checkpoint's dictionary-type flag.
        dictionary = (SmileDictionary.load()
                      if args.smile_dic_type == 'short'
                      else GeneralSmileDictionary.load())

        task = SmilePropertyPredictionTask(args, dictionary)
        model = task.build_model(args)
        model.upgrade_state_dict(state_dict)
        model.load_state_dict(state_dict, strict=True)
        return model
    def load_pretrained_generator(self, path, arg_overrides=None):
        """Load a pretrained generator model for this task's dictionaries.

        Args:
            path: checkpoint file to load weights and saved args from.
            arg_overrides: optional dict of args to override over the
                checkpoint's saved args; when ``None``, the saved args are
                used unchanged.

        Returns:
            The model built by :class:`MaskMLETask` with the checkpoint
            weights loaded (``strict=True``).
        """
        checkpoint = utils.load_checkpoint_to_cpu(path)
        args = checkpoint['args']
        state_dict = checkpoint['model']
        # Idiomatic None test (was: `if not(arg_overrides is None):`).
        if arg_overrides is not None:
            args = utils.override_model_args(args, arg_overrides)
        src_dict = self.source_dictionary
        tgt_dict = self.target_dictionary
        # Source and target vocabularies must share special-symbol ids.
        assert src_dict.pad() == tgt_dict.pad()
        assert src_dict.eos() == tgt_dict.eos()
        assert src_dict.unk() == tgt_dict.unk()

        task = MaskMLETask(args, src_dict, tgt_dict)
        model = task.build_model(args)
        model.upgrade_state_dict(state_dict)
        model.load_state_dict(state_dict, strict=True)
        return model
# --- Example #5 (scraped sample divider; score: 0) ---
    def load_pretrained_model(path,
                              src_dict_path,
                              tgt_dict_path,
                              arg_overrides=None):
        """Build a BERT-source translation model and load checkpoint weights.

        Args:
            path: checkpoint file to load weights and saved args from.
            src_dict_path: unused here; the source vocabulary comes from the
                BERT tokenizer named in the checkpoint args.
            tgt_dict_path: path to the target-side dictionary file.
            arg_overrides: optional dict of args to override over the
                checkpoint's saved args.

        Returns:
            The model built by :class:`BertTranslationTask` with the
            checkpoint weights loaded (``strict=True``).
        """
        checkpoint = utils.load_checkpoint_to_cpu(path)
        state_dict = checkpoint['model']
        args = utils.override_model_args(checkpoint['args'], arg_overrides)

        src_dict = BertBasedDictionary(args.bert_name)
        tgt_dict = Dictionary.load(tgt_dict_path)
        # Both vocabularies must agree on the special-symbol ids.
        assert src_dict.pad() == tgt_dict.pad()
        assert src_dict.eos() == tgt_dict.eos()
        assert src_dict.unk() == tgt_dict.unk()

        task = BertTranslationTask(args, src_dict, tgt_dict)
        model = task.build_model(args)
        model.upgrade_state_dict(state_dict)
        model.load_state_dict(state_dict, strict=True)
        return model