Code Example #1
def main(args):
    import_user_module(args)
    ckpt_path1 = args.model1_root + '/checkpoints/checkpoint_best.pt'
    ckpt_path2 = args.model2_root + '/checkpoints/checkpoint_best.pt'
    state1 = load_checkpoint_to_cpu(ckpt_path1)
    state2 = load_checkpoint_to_cpu(ckpt_path2)

    enc_emb1 = state1['model']['encoder.embed_tokens.weight']
    enc_emb2 = state2['model']['encoder.embed_tokens.weight']
    check = enc_emb1 == enc_emb2

    print(check[:6])
    print(enc_emb1[:6, :5])
    print(enc_emb2[:6, :5])
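For reference, every example on this page consumes the dictionary returned by load_checkpoint_to_cpu. The minimal sketch below (the checkpoint path is hypothetical) lists only the keys that these examples actually access.

from fairseq import checkpoint_utils

# Hypothetical path; the keys below are the ones used throughout these examples.
state = checkpoint_utils.load_checkpoint_to_cpu('/path/to/checkpoint_best.pt')

weights = state['model']                        # OrderedDict of parameter tensors
cfg = state.get('cfg')                          # omegaconf config (newer checkpoints) ...
args = state.get('args')                        # ... or an argparse Namespace (older ones)
extra_state = state.get('extra_state')          # training metadata, e.g. train_meters
optim_history = state.get('optimizer_history')  # criterion/optimizer history entries
last_optim_state = state.get('last_optimizer_state')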
Code Example #2
    def load_model_vocab(self, args):

        filename = args.model_path
        if not os.path.exists(filename):
            raise IOError("Model file not found: {}".format(filename))

        state = checkpoint_utils.load_checkpoint_to_cpu(filename)

        task_args = state["cfg"]["task"]
        task_args.data = args.data_bin

        task = self.set_up_task(task_args)

        # build model for ensemble
        self.model = task.build_model(state["cfg"]["model"])
        self.model.load_state_dict(state["model"], strict=True)
        self.model.eval()
        self.model.share_memory()

        if self.gpu:
            self.model.cuda()

        # Set dictionary
        self.dict = {}
        self.dict["tgt"] = task.target_dictionary
Code Example #3
    def load_from_pretrained(self, filename, prefix, args):

        state_dict = load_checkpoint_to_cpu(filename)['model']

        if prefix:
            state_dict = {key[len(prefix):]: value for key, value in state_dict.items() if key.startswith(prefix)}

        model_vocab_size = self.decoder.sentence_encoder.embed_tokens.weight.shape[0]
        ckpt_vocab_size = state_dict['decoder.sentence_encoder.embed_tokens.weight'].shape[0]
        diff = model_vocab_size - ckpt_vocab_size

        model_pos_size = self.decoder.sentence_encoder.embed_positions.weight.shape[0]
        ckpt_pos_size = state_dict['decoder.sentence_encoder.embed_positions.weight'].shape[0]
        diff_pos_size = model_pos_size - ckpt_pos_size

        new_state_dict = {}
        for n, c in state_dict.items():
            if n in ['decoder.sentence_encoder.embed_tokens.weight', 'decoder.lm_head.weight'] and diff > 0:
                new_weight = torch.Tensor(c.shape[0]+diff, c.shape[1])
                new_weight.data.normal_(mean=0.0, std=0.02)
                new_weight[:-diff] = c
                new_state_dict[n] = new_weight
            elif n == 'decoder.lm_head.bias' and diff > 0:
                new_weight = torch.zeros(c.shape[0]+diff)
                new_weight[:-diff] = c
                new_state_dict[n] = new_weight
            elif n == 'decoder.sentence_encoder.embed_positions.weight' and diff_pos_size < 0:
                new_weight = c[:c.shape[0] + diff_pos_size]
                new_state_dict[n] = new_weight
            else:
                new_state_dict[n] = c

        missing_keys, unexpected_keys = super().load_state_dict(new_state_dict, strict=False, args=args)
        handle_state_dict_keys(missing_keys, unexpected_keys)
Code Example #4
def upgrade_state_dict_with_xlm_weights(
    state_dict: Dict[str, Any], pretrained_xlm_checkpoint: str
) -> Dict[str, Any]:
    """
    Load XLM weights into a Transformer encoder or decoder model.

    Args:
        state_dict: state dict for either TransformerEncoder or
            TransformerDecoder
        pretrained_xlm_checkpoint: checkpoint to load XLM weights from

    Raises:
        AssertionError: If architecture (num layers, attention heads, etc.)
            does not match between the current Transformer encoder or
            decoder and the pretrained_xlm_checkpoint
    """
    if not os.path.exists(pretrained_xlm_checkpoint):
        raise IOError(f"Model file not found: {pretrained_xlm_checkpoint}")

    state = checkpoint_utils.load_checkpoint_to_cpu(pretrained_xlm_checkpoint)
    xlm_state_dict = state["model"]
    for key in xlm_state_dict.keys():

        for search_key in ["embed_tokens", "embed_positions", "layers"]:
            if search_key in key:
                subkey = key[key.find(search_key):]
                assert subkey in state_dict, (
                    f"{str(state_dict.keys())} Transformer encoder / decoder "
                    f"state_dict does not contain {subkey}. Cannot "
                    f"load {key} from pretrained XLM checkpoint "
                    f"{pretrained_xlm_checkpoint} into Transformer."
                )

                state_dict[subkey] = xlm_state_dict[key]
    return state_dict
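A minimal usage sketch for the function above, assuming model.encoder is a fairseq TransformerEncoder and the checkpoint path is hypothetical: upgrade the encoder's state dict with the XLM weights, then load it back into the module.

# model.encoder is assumed to be a fairseq TransformerEncoder; the path is hypothetical.
encoder_state = model.encoder.state_dict()
encoder_state = upgrade_state_dict_with_xlm_weights(
    encoder_state, '/path/to/xlm_checkpoint.pt'
)
model.encoder.load_state_dict(encoder_state)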
Code Example #5
    def build_encoder(cls, args):
        _args = copy.deepcopy(args)
        if not args.adaptor_proj and not args.encoder_proj:  # V0 arch
            if args.w2v_path:
                state = checkpoint_utils.load_checkpoint_to_cpu(args.w2v_path)
                if state.get("cfg") is not None:
                    encoder_embed_dim = state["cfg"]._content["model"]["encoder_embed_dim"]
                elif state.get("args") is not None:
                    encoder_embed_dim = state["args"].encoder_embed_dim
                else:
                    raise ValueError(f"Invalid config in {args.w2v_path}")
                _args.decoder_embed_dim = encoder_embed_dim
                del state
            else:
                _args.decoder_embed_dim = args.encoder_embed_dim

        encoder = Wav2VecEncoderWithAdaptor(_args)
        encoder = cls.maybe_load_pretrained(
            encoder, getattr(args, "load_pretrained_encoder_from", None)
        )
        if args.remove_weight_norm:
            # remove weight norm for EMA usage
            logger.warning("Removing weight norm from wav2vec encoder")
            remove_weight_norm_from_model(encoder)

        return encoder
Code Example #6
    def __init__(self, cfg: Wav2Vec2AsrConfig, tgt_dict=None):
        self.mask = cfg.apply_mask

        state = checkpoint_utils.load_checkpoint_to_cpu(cfg.w2v_path)
        w2v_args = state.get("cfg", None)
        if w2v_args is None:
            w2v_args = convert_namespace_to_omegaconf(state["args"])
        cfg.w2v_args = w2v_args

        self.mask_prob = cfg.mask_prob
        self.mask_selection = cfg.mask_selection
        self.mask_other = cfg.mask_other
        self.mask_length = cfg.mask_length
        self.no_mask_overlap = cfg.no_mask_overlap
        self.mask_min_space = cfg.mask_min_space

        self.mask_channel_prob = cfg.mask_channel_prob
        self.mask_channel_selection = cfg.mask_channel_selection
        self.mask_channel_other = cfg.mask_channel_other
        self.mask_channel_length = cfg.mask_channel_length
        self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
        self.mask_channel_min_space = cfg.mask_channel_min_space
        '''
        assert cfg.normalize == w2v_args.task.normalize, (
            "Fine-tuning works best when data normalization is the same. "
            "Please check that --normalize is set or unset for both pre-training and here"
        )
        '''
        w2v_args.task.data = cfg.data
        task = tasks.setup_task(w2v_args.task)
        model = task.build_model(w2v_args.model)

        if state is not None and not cfg.no_pretrained_weights:
            model.load_state_dict(state["model"], strict=True)

        model.remove_pretraining_modules()

        super().__init__(task.source_dictionary)

        d = w2v_args.model.encoder_embed_dim

        self.w2v_model = model

        self.final_dropout = nn.Dropout(cfg.final_dropout)
        self.freeze_finetune_updates = cfg.freeze_finetune_updates
        self.num_updates = 0

        self.lstm = nn.LSTM(input_size=d,
                            hidden_size=1024,
                            num_layers=2,
                            batch_first=True,
                            bidirectional=True)
        d = 2048

        if tgt_dict is not None:
            self.proj = Linear(d, len(tgt_dict))
        elif getattr(cfg, "decoder_embed_dim", d) != d:
            self.proj = Linear(d, cfg.decoder_embed_dim)
        else:
            self.proj = None
Code Example #7
    def build_model(cls, cfg: HubertSeq2SeqConfig, task: FairseqTask):
        """Build a new model instance."""

        assert (cfg.autoregressive
                ), "Please set task.autoregressive=true for seq2seq asr models"

        src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

        def build_embedding(dictionary, embed_dim):
            num_embeddings = len(dictionary)
            padding_idx = dictionary.pad()
            emb = Embedding(num_embeddings, embed_dim, padding_idx)
            return emb

        decoder_embed_tokens = build_embedding(tgt_dict, cfg.decoder_embed_dim)

        encoder = cls.build_encoder(cfg, task)
        decoder = cls.build_decoder(cfg, tgt_dict, decoder_embed_tokens)

        model = HubertSeq2SeqModel(encoder, decoder)

        if cfg["seq2seq_path"]:
            state = checkpoint_utils.load_checkpoint_to_cpu(cfg.seq2seq_path)
            state = state["model"]
            if cfg["reset_dict"]:
                del state["decoder.embed_out"]
                del state["decoder.embed_tokens.weight"]
            model.load_state_dict(state, strict=False)
        return model
Code Example #8
def load_models_and_criterions(filenames, arg_overrides=None, task=None):
    models = []
    criterions = []
    for filename in filenames:
        if not os.path.exists(filename):
            raise IOError("Model file not found: {}".format(filename))
        state = checkpoint_utils.load_checkpoint_to_cpu(
            filename, arg_overrides)

        args = state["args"]
        if task is None:
            task = tasks.setup_task(args)

        # build model for ensemble
        model = task.build_model(args)
        # model.decoder.load_state_dict(state['model'],strict=False)
        # model_state = {k: v for k, v in state['model'].items() if 'encoder' not in k}
        # print(model_state.keys())
        # model.load_state_dict(model_state, strict=False)
        model.load_state_dict(state["model"], strict=True)
        models.append(model)

        criterion = task.build_criterion(args)
        if "criterion" in state:
            criterion.load_state_dict(state["criterion"], strict=True)
        criterions.append(criterion)
    return models, criterions, args
Code Example #9
File: wav2vec2_asr.py Project: yuchenlin/fairseq
    def __init__(self, args, tgt_dict=None):
        self.apply_mask = args.apply_mask

        arg_overrides = {
            "dropout": args.dropout,
            "activation_dropout": args.activation_dropout,
            "dropout_input": args.dropout_input,
            "attention_dropout": args.attention_dropout,
            "mask_length": args.mask_length,
            "mask_prob": args.mask_prob,
            "mask_selection": args.mask_selection,
            "mask_other": args.mask_other,
            "no_mask_overlap": args.no_mask_overlap,
            "mask_channel_length": args.mask_channel_length,
            "mask_channel_prob": args.mask_channel_prob,
            "mask_channel_selection": args.mask_channel_selection,
            "mask_channel_other": args.mask_channel_other,
            "no_mask_channel_overlap": args.no_mask_channel_overlap,
            "encoder_layerdrop": args.layerdrop,
            "feature_grad_mult": args.feature_grad_mult,
        }

        if getattr(args, "w2v_args", None) is None:
            state = checkpoint_utils.load_checkpoint_to_cpu(
                args.w2v_path, arg_overrides
            )
            args.w2v_args = w2v_args = state.get("args", None) or state["cfg"].model
        else:
            state = None
            w2v_args = args.w2v_args

        assert (
            args.normalize == w2v_args.normalize
        ), "Fine-tuning works best when data normalization is the same"

        w2v_args.data = args.data
        task = tasks.setup_task(w2v_args)
        model = task.build_model(w2v_args)

        if state is not None and not args.no_pretrained_weights:
            model.load_state_dict(state["model"], strict=True)

        model.remove_pretraining_modules()

        super().__init__(task.source_dictionary)

        d = w2v_args.encoder_embed_dim

        self.w2v_model = model

        self.final_dropout = nn.Dropout(args.final_dropout)
        self.freeze_finetune_updates = args.freeze_finetune_updates
        self.num_updates = 0

        if tgt_dict is not None:
            self.proj = Linear(d, len(tgt_dict))
        elif getattr(args, "decoder_embed_dim", d) != d:
            self.proj = Linear(d, args.decoder_embed_dim)
        else:
            self.proj = None
Code Example #10
 def __init__(self, args, exit_after_mask=False):
     super().__init__()
     self.args = args
     self.iterations = args.decoding_iterations
     self.end_iteration = args.end_iteration
     self.exit_after_mask = exit_after_mask
     self.baseline_model = None
     self.masker = getattr(args, "masker", False)
     self.progressive = hasattr(args, "progressive") and args.progressive
     if getattr(args, "ensemble", False):
         from nsml import DATASET_PATH
         from fairseq import checkpoint_utils
         data_token = "en-de"
         pretrained_path = "{}/train/pretrained_models/maskPredict_{}/checkpoint_best.pt".format(
             DATASET_PATH,
             data_token.split(".")[-1].replace("-", "_"))
         print("| loading", pretrained_path)
         state = checkpoint_utils.load_checkpoint_to_cpu(pretrained_path)
         baseline_model = args.taskobj.build_model(args)
         baseline_model.load_state_dict(state["model"], strict=True)
         if torch.cuda.is_available():
             baseline_model.cuda()
         self.baseline_model = baseline_model
         if args.fp16:
             self.baseline_model.half()
Code Example #11
File: infoxlm.py Project: microsoft/unilm
  def build_model(cls, args, task):

    model_fast = RobertaModel.build_model(args, task)
    model_slow = RobertaModel.build_model(args, task)

    if args.roberta_model_path != "":
      state = checkpoint_utils.load_checkpoint_to_cpu(args.roberta_model_path)
      model_fast.load_state_dict(state["model"], strict=True, args=args)
      model_slow.load_state_dict(state["model"], strict=True, args=args)
    else:
      model_slow.load_state_dict(model_fast.state_dict(), strict=True, args=args)

    proj = None
    if args.use_proj:
      # NOTE: always use share_proj
      langs = ["share_lang"]
      proj = build_projection_dict(langs, args.encoder_embed_dim, args.activation_fn, args.fp16)

    if "xlco_queue_size" in args:
      xlco_queue_size = args.xlco_queue_size
    else: xlco_queue_size = 1
    print("xlco_queue_size is set as %d" % xlco_queue_size, flush=True)
    queue = torch.randn(xlco_queue_size, args.encoder_embed_dim)

    return cls(model_fast, model_slow, queue, proj=proj)
Code Example #12
    def load_pretrained_checkpoint(
        self,
        filename,
    ):
        """Load all training state from a checkpoint file."""
        extra_state, self._optim_history = None, None

        state = checkpoint_utils.load_checkpoint_to_cpu(filename)

        # load model parameters
        try:
            self.get_model().load_state_dict(
                state["model"], strict=False, args=self.args
            )

        except Exception:
            raise Exception(
                "Cannot load model parameters from checkpoint {}; "
                "please ensure that the architectures match.".format(
                    filename)
            )

        print("warm start from {}".format(filename))

        return extra_state
Code Example #13
File: masked_lm.py Project: eastonYi/fairseq
    def build_model(cls, args, task, dictionary=None):
        """Build a new model instance."""
        # make sure all arguments are present in older models
        base_architecture(args)

        if not hasattr(args, 'max_positions'):
            args.max_positions = args.tokens_per_sample

        logger.info(args)

        if task is None:
            assert dictionary
            encoder = MaskedLMEncoder(args, dictionary)
        else:
            encoder = MaskedLMEncoder(args, task.dictionary)

        if getattr(args, "lm_path", None):
            print('load masked_lm from {}'.format(args.lm_path))
            state = checkpoint_utils.load_checkpoint_to_cpu(args.lm_path)
            lm_args = state["args"]
            lm_args.data = args.data
            assert getattr(lm_args, "lm_path", None) is None
            task = tasks.setup_task(lm_args)
            encoder = task.build_model(lm_args)
            print('restore masked_lm from {}'.format(args.lm_path))
            encoder.load_state_dict(state["model"], strict=False)

        return cls(args, encoder)
Code Example #14
File: fairseq.py Project: OpenNMT/CTranslate2
    def _load(self):
        import torch
        import fairseq
        from fairseq import checkpoint_utils

        with torch.no_grad():
            checkpoint = checkpoint_utils.load_checkpoint_to_cpu(
                self._model_path)
            args = checkpoint["args"] or checkpoint["cfg"]["model"]

            args.data = self._data_dir
            if self._fixed_dictionary is not None:
                args.fixed_dictionary = self._fixed_dictionary

            if self._source_lang is not None:
                args.source_lang = self._source_lang

            if self._target_lang is not None:
                args.target_lang = self._target_lang

            model_spec = _get_model_spec(args)
            model_spec.with_source_eos = True
            model_spec.with_target_bos = False

            task = fairseq.tasks.setup_task(args)
            model = fairseq.models.build_model(args, task)
            model.eval()
            model.load_state_dict(checkpoint["model"])

            set_transformer_spec(model_spec, model)
            model_spec.register_source_vocabulary(
                _get_vocab(task.source_dictionary))
            model_spec.register_target_vocabulary(
                _get_vocab(task.target_dictionary))
            return model_spec
Code Example #15
File: xmasked_seq2seq.py Project: ljw9609/NMT-MASS
 def build_model(self, args):
     from fairseq import models
     model = models.build_model(args, self)
     if args.reload_checkpoint is not None:
         filename = args.reload_checkpoint
         if os.path.exists(filename):
             state = checkpoint_utils.load_checkpoint_to_cpu(filename)
             model.load_state_dict(state['model'], strict=False)
     return model
Code Example #16
    def load_checkpoint(self,
                        filename,
                        reset_optimizer=False,
                        reset_lr_scheduler=False,
                        optimizer_overrides=None):
        """Load all training state from a checkpoint file."""
        extra_state, self._optim_history, last_optim_state = None, [], None

        if os.path.exists(filename):
            state = checkpoint_utils.load_checkpoint_to_cpu(filename)

            # load model parameters
            try:
                # TODO this should be a command line flag
                self.get_model().load_state_dict(state['model'], strict=False)
            except Exception:
                raise Exception(
                    'Cannot load model parameters from checkpoint, '
                    'please ensure that the architectures match.')

            extra_state = state['extra_state']
            self._optim_history = state['optimizer_history']
            last_optim_state = state['last_optimizer_state']

        if last_optim_state is not None and not reset_optimizer:
            # rebuild optimizer after loading model, since params may have changed
            self._build_optimizer()

            # only reload optimizer and lr_scheduler if they match
            last_optim = self._optim_history[-1]
            assert last_optim['criterion_name'] == self.criterion.__class__.__name__, \
                'Criterion does not match; please reset the optimizer (--reset-optimizer).'
            assert last_optim['optimizer_name'] == self.optimizer.__class__.__name__, \
                'Optimizer does not match; please reset the optimizer (--reset-optimizer).'

            if not reset_lr_scheduler:
                self.lr_scheduler.load_state_dict(
                    last_optim['lr_scheduler_state'])
            self.optimizer.load_state_dict(last_optim_state,
                                           optimizer_overrides)

            self._num_updates = last_optim['num_updates']

        if extra_state is not None and 'train_meters' in extra_state:
            self.meters.update(extra_state['train_meters'])
            del extra_state['train_meters']

            # reset TimeMeters, since their start times don't make sense anymore
            for meter in self.meters.values():
                if isinstance(meter, TimeMeter):
                    meter.reset()

        return extra_state
Code Example #17
    def __init__(self, args, tgt_dict):
        super().__init__()
        self.args = args

        feature_enc_layers = eval(args.conv_feature_layers)
        self.embed = feature_enc_layers[-1][0]

        self.feature_extractor = ConvFeatureExtractionModel(
            conv_layers=feature_enc_layers,
            dropout=0.0,
            mode=args.extractor_mode,
            conv_bias=args.conv_bias,
        )

        self.post_extract_proj = (
            nn.Linear(self.embed, args.encoder_embed_dim)
            if self.embed != args.encoder_embed_dim
            else None
        )

        self.mask_prob = args.mask_prob
        self.mask_selection = args.mask_selection
        self.mask_other = args.mask_other
        self.mask_length = args.mask_length
        self.no_mask_overlap = args.no_mask_overlap
        self.mask_min_space = args.mask_min_space

        self.mask_channel_prob = args.mask_channel_prob
        self.mask_channel_selection = args.mask_channel_selection
        self.mask_channel_other = args.mask_channel_other
        self.mask_channel_length = args.mask_channel_length
        self.no_mask_channel_overlap = args.no_mask_channel_overlap
        self.mask_channel_min_space = args.mask_channel_min_space

        self.dropout_input = nn.Dropout(args.dropout_input)
        self.dropout_features = nn.Dropout(args.dropout_features)

        self.feature_grad_mult = args.feature_grad_mult

        self.mask_emb = nn.Parameter(
            torch.FloatTensor(args.encoder_embed_dim).uniform_()
        )

        self.encoder = TransformerEncoder(args)
        self.layer_norm = LayerNorm(self.embed)

        self.phone_proj = nn.Linear(args.encoder_embed_dim, len(tgt_dict))

        if getattr(args, "w2v_path", None):
            print('load Wav2VecEncoder from {}'.format(args.w2v_path))
            state = checkpoint_utils.load_checkpoint_to_cpu(args.w2v_path)
            self.load_state_dict(state["model"], strict=False)
Code Example #18
    def _load_models(self, args):
        state = checkpoint_utils.load_checkpoint_to_cpu(args.path)
        state["args"].data = args.data
        task = tasks.setup_task(state["args"])
        model = task.build_model(state["args"])
        model.load_state_dict(state["model"], strict=True, args=state["args"])
        model.make_generation_fast_()
        if args.fp16:
            model.half()
        if args.use_cuda:
            model.cuda()

        return [model]
Code Example #19
 def build_encoder(cls, args):
     _args = copy.deepcopy(args)
     state = checkpoint_utils.load_checkpoint_to_cpu(args.w2v_path)
     if state.get("cfg") is not None:
         encoder_embed_dim = state["cfg"]._content["model"][
             "encoder_embed_dim"]
     elif state.get("args") is not None:
         encoder_embed_dim = state["args"].encoder_embed_dim
     else:
         raise ValueError(f"Invalid config in {args.w2v_path}")
     _args.decoder_embed_dim = encoder_embed_dim
     encoder = Wav2VecEncoderWithAdaptor(_args)
     return encoder
Code Example #20
  def build_model(cls, args, task):
    reload_roberta_base(args)
    if not hasattr(args, 'max_positions'):
      args.max_positions = args.tokens_per_sample

    encoder = RobertaEncoder(args, task.source_dictionary)
    model = cls(args, encoder)

    if args.roberta_model_path != "":
      state = checkpoint_utils.load_checkpoint_to_cpu(args.roberta_model_path)
      model.load_state_dict(state["model"], strict=False, args=args)
    
    print(model.__class__)
    return model
Code Example #21
def init_tmodel(source_path, target_path, modified_path):
    """
    Args:
        source_path: A fairseq.Language_model that whose params will be initialized with the params
                    from the Transformer model.
        target_path: A fairseq.Transformer model that has been trained on the Translation task
        modified_path: A string object denoting the path to where you wish to store the model
    """
    encoder_state = checkpoint_utils.load_checkpoint_to_cpu(source_path)
    translation_state = checkpoint_utils.load_checkpoint_to_cpu(target_path)

    filtered_state = []

    for key in encoder_state['model'].keys():
        filtered_state.append((key, encoder_state['model'][key]))

    # Remove the linear and layer norm layers to maintain compatibility
    filtered_state.pop()
    filtered_state.pop()
    filtered_state.pop()
    filtered_state.pop()

    list_translation_state = []

    for key in translation_state['model'].keys():
        list_translation_state.append((key, translation_state['model'][key]))

    for index, key in enumerate(list_translation_state):
      if key[0].startswith('encoder'):
        list_translation_state[index] = filtered_state[index]

    list_translation_state_dict = OrderedDict(list_translation_state)
    translation_state['model'] = list_translation_state_dict
    
    checkpoint_utils.torch_persistent_save(translation_state, modified_path)

    return 
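A usage sketch for init_tmodel with hypothetical paths: the translation checkpoint's encoder parameters are overwritten with the language model's and the result is written out as a new checkpoint.

# All paths are hypothetical.
init_tmodel(
    source_path='lm/checkpoints/checkpoint_best.pt',
    target_path='mt/checkpoints/checkpoint_best.pt',
    modified_path='mt_warm_start/checkpoint_init.pt',
)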
Code Example #22
            def load_feature_extractor(component, checkpoint):
                if not PathManager.exists(checkpoint):
                    raise IOError(
                        "Model file not found: {}".format(checkpoint))
                state = checkpoint_utils.load_checkpoint_to_cpu(checkpoint)
                component_state_dict = OrderedDict()

                component_prefix = "feature_extractor"
                for key in state["model"].keys():
                    if key.startswith(component_prefix):
                        component_subkey = key[len(component_prefix) + 1:]
                        component_state_dict[component_subkey] = state[
                            "model"][key]
                component.load_state_dict(component_state_dict, strict=True)
                return component
Code Example #23
 def __init__(self, input_feat_per_channel,
              vggblock_config=DEFAULT_ENC_VGGBLOCK_CONFIG,
              transformer_config=DEFAULT_ENC_TRANSFORMER_CONFIG,
              encoder_output_dim=512, in_channels=1, transformer_context=None,
              transformer_sampling=None):
     super().__init__(input_feat_per_channel, vggblock_config, transformer_config,
                      encoder_output_dim, in_channels, transformer_context,
                      transformer_sampling)
     wav2vec_checkpoint = HOME + '/data/fairseq-data/wav2vec_models/checkpoint_last.pt'
     # wav2vec_checkpoint = '/tmp/checkpoint_last.pt'
     cp = checkpoint_utils.load_checkpoint_to_cpu(wav2vec_checkpoint)
     model = Wav2VecModel.build_model(cp['args'], task=None)
     model.load_state_dict(cp['model'])
     freeze_module_params(model)
     self.wav2vec_model = model
Code Example #24
File: infer.py Project: trentontemple/fairseq
def load_models_and_criterions(filenames,
                               data_path,
                               arg_overrides=None,
                               task=None,
                               model_state=None):
    models = []
    criterions = []

    if arg_overrides is None:
        arg_overrides = {}

    arg_overrides["wer_args"] = None
    arg_overrides["data"] = data_path

    if filenames is None:
        assert model_state is not None
        filenames = [0]
    else:
        filenames = filenames.split(":")

    for filename in filenames:
        if model_state is None:
            if not os.path.exists(filename):
                raise IOError("Model file not found: {}".format(filename))
            state = checkpoint_utils.load_checkpoint_to_cpu(
                filename, arg_overrides)
        else:
            state = model_state

        if "cfg" in state:
            cfg = state["cfg"]
        else:
            cfg = convert_namespace_to_omegaconf(state["args"])

        if task is None:
            if hasattr(cfg.task, 'data'):
                cfg.task.data = data_path
            task = tasks.setup_task(cfg.task)

        model = task.build_model(cfg.model)
        model.load_state_dict(state["model"], strict=True)
        models.append(model)

        criterion = task.build_criterion(cfg.criterion)
        if "criterion" in state:
            criterion.load_state_dict(state["criterion"], strict=True)
        criterions.append(criterion)
    return models, criterions, task
Code Example #25
File: xm_transformer.py Project: xuqiantong/fairseq
    def build_encoder(cls, args):
        _args = copy.deepcopy(args)
        if not args.adaptor_proj:  # V0 arch
            state = checkpoint_utils.load_checkpoint_to_cpu(args.w2v_path)
            if state.get("cfg") is not None:
                encoder_embed_dim = state["cfg"]._content["model"][
                    "encoder_embed_dim"]
            elif state.get("args") is not None:
                encoder_embed_dim = state["args"].encoder_embed_dim
            else:
                raise ValueError(f"Invalid config in {args.w2v_path}")
            _args.decoder_embed_dim = encoder_embed_dim
            del state

        encoder = Wav2VecEncoderWithAdaptor(_args)
        return cls.maybe_load_pretrained(
            encoder, getattr(args, "load_pretrained_encoder_from", None))
Code Example #26
 def load_pretrained_speech_text_components(cls, checkpoint,
                                            component_pairs):
     if not PathManager.exists(checkpoint):
         raise IOError("Model file not found: {}".format(checkpoint))
     state = load_checkpoint_to_cpu(checkpoint)
     for component_type, component in component_pairs:
         if isinstance(component, nn.parameter.Parameter):
             component.data.copy_(state["model"][component_type])
         else:
             component_state_dict = OrderedDict()
             for key in state["model"].keys():
                 if key.startswith(component_type):
                     component_subkey = key[len(component_type) + 1:]
                     component_state_dict[component_subkey] = state[
                         "model"][key]
             component.load_state_dict(component_state_dict, strict=True)
     return state
Code Example #27
File: train.py Project: xssstory/STAS
def update_args(args):
    import os
    from fairseq.checkpoint_utils import load_checkpoint_to_cpu
    bart_large_cnn_path = os.path.join(
        os.path.dirname(os.path.dirname(args.pretrained_doc_model_path)),
        'bart.large.cnn/model.pt')
    state = load_checkpoint_to_cpu(bart_large_cnn_path)
    new_args = state['args']
    no_update_args = [
        'source_lang', 'target_lang', 'task', 'data', 'save_dir',
        'update_freq', 'log_interval', 'dataset_impl'
    ]
    for k, v in new_args.__dict__.items():
        if k not in no_update_args and not k.startswith('distributed'):
            if getattr(args, k, None) != v:
                print('| WARNING: update {} in args from {} to {}'.format(
                    k, getattr(args, k, None), v))
                setattr(args, k, v)
Code Example #28
def main(args):
    import_user_module(args)
    ckpt_path = args.model_root + '/checkpoints/checkpoint_best.pt'
    # state = torch.load(
    #     ckpt_path, map_location=lambda s, l: default_restore_location(s, 'cpu'),
    # )
    state = load_checkpoint_to_cpu(ckpt_path)

    enc_emb = state['model']['encoder.embed_tokens.weight']
    enc_emb_output_path = args.model_root + '/embeddings/encoder.indomain'
    output_trained_embeddings_to_file(enc_emb, args.srcdict, 
                                      enc_emb_output_path)
    
    if args.tgtdict:
        dec_emb = state['model']['decoder.embed_tokens.weight']
        dec_emb_output_path = args.model_root + '/embeddings/decoder.indomain'
        output_trained_embeddings_to_file(dec_emb, args.tgtdict, 
                                          dec_emb_output_path)
Code Example #29
File: infer.py Project: mtanana/LyssnFairSeq
def load_models_and_criterions(filenames,
                               data_path,
                               arg_overrides=None,
                               task=None,
                               model_state=None):
    models = []
    criterions = []

    if arg_overrides is None:
        arg_overrides = {}

    arg_overrides['wer_args'] = None
    arg_overrides['data'] = data_path

    if filenames is None:
        assert model_state is not None
        filenames = [0]
    else:
        filenames = filenames.split(":")

    for filename in filenames:

        if model_state is None:
            if not os.path.exists(filename):
                raise IOError("Model file not found: {}".format(filename))
            state = checkpoint_utils.load_checkpoint_to_cpu(
                filename, arg_overrides)
        else:
            state = model_state

        args = state["args"]
        if task is None:
            task = tasks.setup_task(args)

        model = task.build_model(args)
        model.load_state_dict(state["model"], strict=True)

        models.append(model)

        criterion = task.build_criterion(args)
        if "criterion" in state:
            criterion.load_state_dict(state["criterion"], strict=True)
        criterions.append(criterion)
    return models, criterions, args
Code Example #30
def run_maybe_distributed_reptile(meta_learning_args, downstream_args,
                                  load_meta_tasks_fn, fine_tune_args):
    seed = downstream_args.seed
    if torch.cuda.is_available() and not meta_learning_args.cpu:
        torch.cuda.set_device(meta_learning_args.device_id)
    torch.manual_seed(seed)
    meta_train_tasks, meta_dev_tasks = load_meta_tasks_fn()
    # build model and criterion
    print('| training on {} GPUs'.format(
        meta_learning_args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        meta_learning_args.max_tokens,
        meta_learning_args.max_sentences,
    ))
    # Reptile training loop
    print('setup for meta-learning task')
    # Meta-learning task list
    meta_learning_task = tasks.setup_task(args=meta_learning_args,
                                          meta_train_tasks=meta_train_tasks,
                                          meta_dev_tasks=meta_dev_tasks,
                                          meta_test_tasks=None)
    print('building meta-learning model...')
    model = meta_learning_task.build_model(
        meta_learning_args)  # Transformer RAW
    state = load_checkpoint_to_cpu(meta_learning_args.restore_file)
    model.load_state_dict(state['model'], strict=False)
    meta_learning_criterion = meta_learning_task.build_criterion(
        meta_learning_args)  # MAML, FoMAML
    print(model)
    print('| model {}, criterion {}'.format(
        meta_learning_args.arch, meta_learning_criterion.__class__.__name__))
    print('| num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))
    reptile_function = build_reptile_function(
        is_baseline=meta_learning_args.baseline,
        is_curriculum=meta_learning_args.is_curriculum)
    reptile_function(model=model,
                     meta_learning_task=meta_learning_task,
                     meta_learning_args=meta_learning_args,
                     meta_learning_criterion=meta_learning_criterion,
                     fine_tune_args=fine_tune_args)