def main(args):
    import_user_module(args)
    ckpt_path1 = args.model1_root + '/checkpoints/checkpoint_best.pt'
    ckpt_path2 = args.model2_root + '/checkpoints/checkpoint_best.pt'
    state1 = load_checkpoint_to_cpu(ckpt_path1)
    state2 = load_checkpoint_to_cpu(ckpt_path2)
    enc_emb1 = state1['model']['encoder.embed_tokens.weight']
    enc_emb2 = state2['model']['encoder.embed_tokens.weight']
    check = enc_emb1 == enc_emb2
    print(check[:6])
    print(enc_emb1[:6, :5])
    print(enc_emb2[:6, :5])
def load_model_vocab(self, args):
    filename = args.model_path
    if not os.path.exists(filename):
        raise IOError("Model file not found: {}".format(filename))
    state = checkpoint_utils.load_checkpoint_to_cpu(filename)
    task_args = state["cfg"]["task"]
    task_args.data = args.data_bin
    task = self.set_up_task(task_args)

    # build model for ensemble
    self.model = task.build_model(state["cfg"]["model"])
    self.model.load_state_dict(state["model"], strict=True)
    self.model.eval()
    self.model.share_memory()
    if self.gpu:
        self.model.cuda()

    # Set dictionary
    self.dict = {}
    self.dict["tgt"] = task.target_dictionary
def load_from_pretrained(self, filename, prefix, args):
    state_dict = load_checkpoint_to_cpu(filename)['model']
    if prefix:
        state_dict = {
            key[len(prefix):]: value
            for key, value in state_dict.items()
            if key.startswith(prefix)
        }
    model_vocab_size = self.decoder.sentence_encoder.embed_tokens.weight.shape[0]
    ckpt_vocab_size = state_dict['decoder.sentence_encoder.embed_tokens.weight'].shape[0]
    diff = model_vocab_size - ckpt_vocab_size
    model_pos_size = self.decoder.sentence_encoder.embed_positions.weight.shape[0]
    ckpt_pos_size = state_dict['decoder.sentence_encoder.embed_positions.weight'].shape[0]
    diff_pos_size = model_pos_size - ckpt_pos_size
    new_state_dict = {}
    for n, c in state_dict.items():
        if n in ['decoder.sentence_encoder.embed_tokens.weight', 'decoder.lm_head.weight'] and diff > 0:
            # grow the token embedding / output projection with newly initialized rows
            new_weight = torch.Tensor(c.shape[0] + diff, c.shape[1])
            new_weight.data.normal_(mean=0.0, std=0.02)
            new_weight[:-diff] = c
            new_state_dict[n] = new_weight
        elif n == 'decoder.lm_head.bias' and diff > 0:
            new_weight = torch.zeros(c.shape[0] + diff)
            new_weight[:-diff] = c
            new_state_dict[n] = new_weight
        elif n == 'decoder.sentence_encoder.embed_positions.weight' and diff_pos_size < 0:
            # truncate positional embeddings if the checkpoint has more positions than the model
            new_weight = c[:c.shape[0] + diff_pos_size]
            new_state_dict[n] = new_weight
        else:
            new_state_dict[n] = c
    missing_keys, unexpected_keys = super().load_state_dict(new_state_dict, strict=False, args=args)
    handle_state_dict_keys(missing_keys, unexpected_keys)
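# A minimal standalone sketch (not from the original source; the helper name
# and sizes are hypothetical) of the vocab-resizing trick used above: rows
# appended to a checkpoint embedding are initialized with normal(0, 0.02),
# matching the scheme in load_from_pretrained.
def _example_grow_embedding(ckpt_weight, model_vocab_size):
    import torch
    diff = model_vocab_size - ckpt_weight.shape[0]
    if diff <= 0:
        return ckpt_weight
    new_weight = torch.empty(model_vocab_size, ckpt_weight.shape[1])
    new_weight.normal_(mean=0.0, std=0.02)
    new_weight[:-diff] = ckpt_weight  # keep pretrained rows, randomize only new ones
    return new_weight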
def upgrade_state_dict_with_xlm_weights(
    state_dict: Dict[str, Any], pretrained_xlm_checkpoint: str
) -> Dict[str, Any]:
    """
    Load XLM weights into a Transformer encoder or decoder model.

    Args:
        state_dict: state dict for either TransformerEncoder or TransformerDecoder
        pretrained_xlm_checkpoint: checkpoint to load XLM weights from

    Raises:
        AssertionError: If architecture (num layers, attention heads, etc.)
            does not match between the current Transformer encoder or decoder
            and the pretrained_xlm_checkpoint
    """
    if not os.path.exists(pretrained_xlm_checkpoint):
        raise IOError(f"Model file not found: {pretrained_xlm_checkpoint}")

    state = checkpoint_utils.load_checkpoint_to_cpu(pretrained_xlm_checkpoint)
    xlm_state_dict = state["model"]
    for key in xlm_state_dict.keys():
        for search_key in ["embed_tokens", "embed_positions", "layers"]:
            if search_key in key:
                subkey = key[key.find(search_key):]
                assert subkey in state_dict, (
                    f"{str(state_dict.keys())} Transformer encoder / decoder "
                    f"state_dict does not contain {subkey}. Cannot "
                    f"load {key} from pretrained XLM checkpoint "
                    f"{pretrained_xlm_checkpoint} into Transformer."
                )
                state_dict[subkey] = xlm_state_dict[key]
    return state_dict
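# Hedged usage sketch (not from the original source; `encoder` stands in for
# any TransformerEncoder/TransformerDecoder instance and the checkpoint path
# is hypothetical): upgrade the module's own state dict with the XLM weights,
# then load it back.
def _example_warm_start_from_xlm(encoder, xlm_ckpt_path="/path/to/xlm_checkpoint.pt"):
    state_dict = encoder.state_dict()
    state_dict = upgrade_state_dict_with_xlm_weights(state_dict, xlm_ckpt_path)
    encoder.load_state_dict(state_dict)
    return encoder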
def build_encoder(cls, args):
    _args = copy.deepcopy(args)
    if not args.adaptor_proj and not args.encoder_proj:  # V0 arch
        if args.w2v_path:
            state = checkpoint_utils.load_checkpoint_to_cpu(args.w2v_path)
            if state.get("cfg") is not None:
                encoder_embed_dim = state["cfg"]._content["model"]["encoder_embed_dim"]
            elif state.get("args") is not None:
                encoder_embed_dim = state["args"].encoder_embed_dim
            else:
                raise ValueError(f"Invalid config in {args.w2v_path}")
            _args.decoder_embed_dim = encoder_embed_dim
            del state
        else:
            _args.decoder_embed_dim = args.encoder_embed_dim
    encoder = Wav2VecEncoderWithAdaptor(_args)
    encoder = cls.maybe_load_pretrained(
        encoder, getattr(args, "load_pretrained_encoder_from", None)
    )
    if args.remove_weight_norm:
        # remove the weight norm for EMA usage
        logger.warning("Removing weight norm from wav2vec encoder")
        remove_weight_norm_from_model(encoder)
    return encoder
def __init__(self, cfg: Wav2Vec2AsrConfig, tgt_dict=None):
    self.mask = cfg.apply_mask

    state = checkpoint_utils.load_checkpoint_to_cpu(cfg.w2v_path)
    w2v_args = state.get("cfg", None)
    if w2v_args is None:
        w2v_args = convert_namespace_to_omegaconf(state["args"])
    cfg.w2v_args = w2v_args

    self.mask_prob = cfg.mask_prob
    self.mask_selection = cfg.mask_selection
    self.mask_other = cfg.mask_other
    self.mask_length = cfg.mask_length
    self.no_mask_overlap = cfg.no_mask_overlap
    self.mask_min_space = cfg.mask_min_space
    self.mask_channel_prob = cfg.mask_channel_prob
    self.mask_channel_selection = cfg.mask_channel_selection
    self.mask_channel_other = cfg.mask_channel_other
    self.mask_channel_length = cfg.mask_channel_length
    self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
    self.mask_channel_min_space = cfg.mask_channel_min_space

    '''
    assert cfg.normalize == w2v_args.task.normalize, (
        "Fine-tuning works best when data normalization is the same. "
        "Please check that --normalize is set or unset for both pre-training and here"
    )
    '''

    w2v_args.task.data = cfg.data
    task = tasks.setup_task(w2v_args.task)
    model = task.build_model(w2v_args.model)

    if state is not None and not cfg.no_pretrained_weights:
        model.load_state_dict(state["model"], strict=True)
    model.remove_pretraining_modules()

    super().__init__(task.source_dictionary)

    d = w2v_args.model.encoder_embed_dim
    self.w2v_model = model
    self.final_dropout = nn.Dropout(cfg.final_dropout)
    self.freeze_finetune_updates = cfg.freeze_finetune_updates
    self.num_updates = 0
    self.lstm = nn.LSTM(
        input_size=d,
        hidden_size=1024,
        num_layers=2,
        batch_first=True,
        bidirectional=True,
    )
    d = 2048  # bidirectional LSTM output: 2 * hidden_size

    if tgt_dict is not None:
        self.proj = Linear(d, len(tgt_dict))
    elif getattr(cfg, "decoder_embed_dim", d) != d:
        self.proj = Linear(d, cfg.decoder_embed_dim)
    else:
        self.proj = None
def build_model(cls, cfg: HubertSeq2SeqConfig, task: FairseqTask):
    """Build a new model instance."""
    assert cfg.autoregressive, "Please set task.autoregressive=true for seq2seq asr models"
    src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

    def build_embedding(dictionary, embed_dim):
        num_embeddings = len(dictionary)
        padding_idx = dictionary.pad()
        emb = Embedding(num_embeddings, embed_dim, padding_idx)
        return emb

    decoder_embed_tokens = build_embedding(tgt_dict, cfg.decoder_embed_dim)
    encoder = cls.build_encoder(cfg, task)
    decoder = cls.build_decoder(cfg, tgt_dict, decoder_embed_tokens)
    model = HubertSeq2SeqModel(encoder, decoder)

    if cfg["seq2seq_path"]:
        state = checkpoint_utils.load_checkpoint_to_cpu(cfg.seq2seq_path)
        state = state["model"]
        if cfg["reset_dict"]:
            # drop the output and embedding weights so they are re-initialized
            del state["decoder.embed_out"]
            del state["decoder.embed_tokens.weight"]
        model.load_state_dict(state, strict=False)
    return model
def load_models_and_criterions(filenames, arg_overrides=None, task=None):
    models = []
    criterions = []
    for filename in filenames:
        if not os.path.exists(filename):
            raise IOError("Model file not found: {}".format(filename))
        state = checkpoint_utils.load_checkpoint_to_cpu(filename, arg_overrides)
        args = state["args"]
        if task is None:
            task = tasks.setup_task(args)

        # build model for ensemble
        model = task.build_model(args)
        model.load_state_dict(state["model"], strict=True)
        models.append(model)

        criterion = task.build_criterion(args)
        if "criterion" in state:
            criterion.load_state_dict(state["criterion"], strict=True)
        criterions.append(criterion)
    return models, criterions, args
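# Hedged usage sketch (checkpoint paths are hypothetical): load an ensemble of
# models with their matching criterions, then switch the models to eval mode.
def _example_load_ensemble():
    filenames = ["/path/to/checkpoint1.pt", "/path/to/checkpoint2.pt"]
    models, criterions, args = load_models_and_criterions(filenames)
    for model in models:
        model.eval()
    return models, criterions, args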
def __init__(self, args, tgt_dict=None):
    self.apply_mask = args.apply_mask

    arg_overrides = {
        "dropout": args.dropout,
        "activation_dropout": args.activation_dropout,
        "dropout_input": args.dropout_input,
        "attention_dropout": args.attention_dropout,
        "mask_length": args.mask_length,
        "mask_prob": args.mask_prob,
        "mask_selection": args.mask_selection,
        "mask_other": args.mask_other,
        "no_mask_overlap": args.no_mask_overlap,
        "mask_channel_length": args.mask_channel_length,
        "mask_channel_prob": args.mask_channel_prob,
        "mask_channel_selection": args.mask_channel_selection,
        "mask_channel_other": args.mask_channel_other,
        "no_mask_channel_overlap": args.no_mask_channel_overlap,
        "encoder_layerdrop": args.layerdrop,
        "feature_grad_mult": args.feature_grad_mult,
    }

    if getattr(args, "w2v_args", None) is None:
        state = checkpoint_utils.load_checkpoint_to_cpu(args.w2v_path, arg_overrides)
        args.w2v_args = w2v_args = state.get("args", None) or state["cfg"].model
    else:
        state = None
        w2v_args = args.w2v_args

    assert (
        args.normalize == w2v_args.normalize
    ), "Fine-tuning works best when data normalization is the same"

    w2v_args.data = args.data
    task = tasks.setup_task(w2v_args)
    model = task.build_model(w2v_args)

    if state is not None and not args.no_pretrained_weights:
        model.load_state_dict(state["model"], strict=True)
    model.remove_pretraining_modules()

    super().__init__(task.source_dictionary)

    d = w2v_args.encoder_embed_dim
    self.w2v_model = model
    self.final_dropout = nn.Dropout(args.final_dropout)
    self.freeze_finetune_updates = args.freeze_finetune_updates
    self.num_updates = 0

    if tgt_dict is not None:
        self.proj = Linear(d, len(tgt_dict))
    elif getattr(args, "decoder_embed_dim", d) != d:
        self.proj = Linear(d, args.decoder_embed_dim)
    else:
        self.proj = None
def __init__(self, args, exit_after_mask=False):
    super().__init__()
    self.args = args
    self.iterations = args.decoding_iterations
    self.end_iteration = args.end_iteration
    self.exit_after_mask = exit_after_mask
    self.baseline_model = None
    self.masker = getattr(args, "masker", False)
    self.progressive = hasattr(args, "progressive") and args.progressive
    if getattr(args, "ensemble", False):
        from nsml import DATASET_PATH
        from fairseq import checkpoint_utils
        data_token = "en-de"
        pretrained_path = "{}/train/pretrained_models/maskPredict_{}/checkpoint_best.pt".format(
            DATASET_PATH, data_token.split(".")[-1].replace("-", "_")
        )
        print("| loading", pretrained_path)
        state = checkpoint_utils.load_checkpoint_to_cpu(pretrained_path)
        baseline_model = args.taskobj.build_model(args)
        baseline_model.load_state_dict(state["model"], strict=True)
        if torch.cuda.is_available():
            baseline_model.cuda()
        self.baseline_model = baseline_model
        if args.fp16:
            self.baseline_model.half()
def build_model(cls, args, task):
    model_fast = RobertaModel.build_model(args, task)
    model_slow = RobertaModel.build_model(args, task)
    if args.roberta_model_path != "":
        state = checkpoint_utils.load_checkpoint_to_cpu(args.roberta_model_path)
        model_fast.load_state_dict(state["model"], strict=True, args=args)
        model_slow.load_state_dict(state["model"], strict=True, args=args)
    else:
        model_slow.load_state_dict(model_fast.state_dict(), strict=True, args=args)

    proj = None
    if args.use_proj:
        # NOTE: the projection is always shared (share_proj)
        langs = ["share_lang"]
        proj = build_projection_dict(langs, args.encoder_embed_dim, args.activation_fn, args.fp16)

    if "xlco_queue_size" in args:
        xlco_queue_size = args.xlco_queue_size
    else:
        xlco_queue_size = 1
    print("xlco_queue_size is set as %d" % xlco_queue_size, flush=True)
    queue = torch.randn(xlco_queue_size, args.encoder_embed_dim)
    return cls(model_fast, model_slow, queue, proj=proj)
def load_pretrained_checkpoint(self, filename):
    """Load all training state from a checkpoint file."""
    extra_state, self._optim_history = None, None
    state = checkpoint_utils.load_checkpoint_to_cpu(filename)

    # load model parameters
    try:
        self.get_model().load_state_dict(state["model"], strict=False, args=self.args)
    except Exception:
        raise Exception(
            "Cannot load model parameters from checkpoint {}; "
            "please ensure that the architectures match.".format(filename)
        )
    print("warm start from {}".format(filename))
    return extra_state
def build_model(cls, args, task, dictionary=None):
    """Build a new model instance."""
    # make sure all arguments are present in older models
    base_architecture(args)
    if not hasattr(args, 'max_positions'):
        args.max_positions = args.tokens_per_sample
    logger.info(args)
    if task is None:
        assert dictionary
        encoder = MaskedLMEncoder(args, dictionary)
    else:
        encoder = MaskedLMEncoder(args, task.dictionary)
    if getattr(args, "lm_path", None):
        print('load masked_lm from {}'.format(args.lm_path))
        state = checkpoint_utils.load_checkpoint_to_cpu(args.lm_path)
        lm_args = state["args"]
        lm_args.data = args.data
        assert getattr(lm_args, "lm_path", None) is None
        task = tasks.setup_task(lm_args)
        encoder = task.build_model(lm_args)
        print('restore masked_lm from {}'.format(args.lm_path))
        encoder.load_state_dict(state["model"], strict=False)
    return cls(args, encoder)
def _load(self):
    import torch
    import fairseq
    from fairseq import checkpoint_utils

    with torch.no_grad():
        checkpoint = checkpoint_utils.load_checkpoint_to_cpu(self._model_path)
        args = checkpoint["args"] or checkpoint["cfg"]["model"]
        args.data = self._data_dir
        if self._fixed_dictionary is not None:
            args.fixed_dictionary = self._fixed_dictionary
        if self._source_lang is not None:
            args.source_lang = self._source_lang
        if self._target_lang is not None:
            args.target_lang = self._target_lang
        model_spec = _get_model_spec(args)
        model_spec.with_source_eos = True
        model_spec.with_target_bos = False
        task = fairseq.tasks.setup_task(args)
        model = fairseq.models.build_model(args, task)
        model.eval()
        model.load_state_dict(checkpoint["model"])
        set_transformer_spec(model_spec, model)
        model_spec.register_source_vocabulary(_get_vocab(task.source_dictionary))
        model_spec.register_target_vocabulary(_get_vocab(task.target_dictionary))
        return model_spec
def build_model(self, args):
    from fairseq import models

    model = models.build_model(args, self)
    if args.reload_checkpoint is not None:
        filename = args.reload_checkpoint
        if os.path.exists(filename):
            state = checkpoint_utils.load_checkpoint_to_cpu(filename)
            model.load_state_dict(state['model'], strict=False)
    return model
def load_checkpoint(self, filename, reset_optimizer=False, reset_lr_scheduler=False,
                    optimizer_overrides=None):
    """Load all training state from a checkpoint file."""
    extra_state, self._optim_history, last_optim_state = None, [], None
    if os.path.exists(filename):
        state = checkpoint_utils.load_checkpoint_to_cpu(filename)

        # load model parameters
        try:
            # TODO this should be a command line flag
            self.get_model().load_state_dict(state['model'], strict=False)
        except Exception:
            raise Exception(
                'Cannot load model parameters from checkpoint, '
                'please ensure that the architectures match.'
            )

        extra_state = state['extra_state']
        self._optim_history = state['optimizer_history']
        last_optim_state = state['last_optimizer_state']

    if last_optim_state is not None and not reset_optimizer:
        # rebuild optimizer after loading model, since params may have changed
        self._build_optimizer()

        # only reload optimizer and lr_scheduler if they match
        last_optim = self._optim_history[-1]
        assert last_optim['criterion_name'] == self.criterion.__class__.__name__, \
            'Criterion does not match; please reset the optimizer (--reset-optimizer).'
        assert last_optim['optimizer_name'] == self.optimizer.__class__.__name__, \
            'Optimizer does not match; please reset the optimizer (--reset-optimizer).'

        if not reset_lr_scheduler:
            self.lr_scheduler.load_state_dict(last_optim['lr_scheduler_state'])
        self.optimizer.load_state_dict(last_optim_state, optimizer_overrides)

        self._num_updates = last_optim['num_updates']

    if extra_state is not None and 'train_meters' in extra_state:
        self.meters.update(extra_state['train_meters'])
        del extra_state['train_meters']

        # reset TimeMeters, since their start times don't make sense anymore
        for meter in self.meters.values():
            if isinstance(meter, TimeMeter):
                meter.reset()

    return extra_state
def __init__(self, args, tgt_dict):
    super().__init__()
    self.args = args
    feature_enc_layers = eval(args.conv_feature_layers)
    self.embed = feature_enc_layers[-1][0]
    self.feature_extractor = ConvFeatureExtractionModel(
        conv_layers=feature_enc_layers,
        dropout=0.0,
        mode=args.extractor_mode,
        conv_bias=args.conv_bias,
    )
    self.post_extract_proj = (
        nn.Linear(self.embed, args.encoder_embed_dim)
        if self.embed != args.encoder_embed_dim
        else None
    )
    self.mask_prob = args.mask_prob
    self.mask_selection = args.mask_selection
    self.mask_other = args.mask_other
    self.mask_length = args.mask_length
    self.no_mask_overlap = args.no_mask_overlap
    self.mask_min_space = args.mask_min_space
    self.mask_channel_prob = args.mask_channel_prob
    self.mask_channel_selection = args.mask_channel_selection
    self.mask_channel_other = args.mask_channel_other
    self.mask_channel_length = args.mask_channel_length
    self.no_mask_channel_overlap = args.no_mask_channel_overlap
    self.mask_channel_min_space = args.mask_channel_min_space
    self.dropout_input = nn.Dropout(args.dropout_input)
    self.dropout_features = nn.Dropout(args.dropout_features)
    self.feature_grad_mult = args.feature_grad_mult
    self.mask_emb = nn.Parameter(
        torch.FloatTensor(args.encoder_embed_dim).uniform_()
    )
    self.encoder = TransformerEncoder(args)
    self.layer_norm = LayerNorm(self.embed)
    self.phone_proj = nn.Linear(args.encoder_embed_dim, len(tgt_dict))
    if getattr(args, "w2v_path", None):
        print('load Wav2VecEncoder from {}'.format(args.w2v_path))
        state = checkpoint_utils.load_checkpoint_to_cpu(args.w2v_path)
        self.load_state_dict(state["model"], strict=False)
def _load_models(self, args):
    state = checkpoint_utils.load_checkpoint_to_cpu(args.path)
    state["args"].data = args.data
    task = tasks.setup_task(state["args"])
    model = task.build_model(state["args"])
    model.load_state_dict(state["model"], strict=True, args=state["args"])
    model.make_generation_fast_()
    if args.fp16:
        model.half()
    if args.use_cuda:
        model.cuda()
    return [model]
def build_encoder(cls, args):
    _args = copy.deepcopy(args)
    state = checkpoint_utils.load_checkpoint_to_cpu(args.w2v_path)
    if state.get("cfg") is not None:
        encoder_embed_dim = state["cfg"]._content["model"]["encoder_embed_dim"]
    elif state.get("args") is not None:
        encoder_embed_dim = state["args"].encoder_embed_dim
    else:
        raise ValueError(f"Invalid config in {args.w2v_path}")
    _args.decoder_embed_dim = encoder_embed_dim
    encoder = Wav2VecEncoderWithAdaptor(_args)
    return encoder
def build_model(cls, args, task):
    reload_roberta_base(args)
    if not hasattr(args, 'max_positions'):
        args.max_positions = args.tokens_per_sample
    encoder = RobertaEncoder(args, task.source_dictionary)
    model = cls(args, encoder)
    if args.roberta_model_path != "":
        state = checkpoint_utils.load_checkpoint_to_cpu(args.roberta_model_path)
        model.load_state_dict(state["model"], strict=False, args=args)
    print(model.__class__)
    return model
def init_tmodel(source_path, target_path, modified_path):
    """
    Initialize the encoder of a Transformer translation model with the
    parameters of a pretrained fairseq language model.

    Args:
        source_path: Path to the fairseq language model checkpoint whose
            parameters will be copied into the Transformer encoder.
        target_path: Path to the fairseq Transformer checkpoint that has
            been trained on the translation task.
        modified_path: Path where the modified checkpoint will be stored.
    """
    encoder_state = checkpoint_utils.load_checkpoint_to_cpu(source_path)
    translation_state = checkpoint_utils.load_checkpoint_to_cpu(target_path)
    filtered_state = []
    for key in encoder_state['model'].keys():
        filtered_state.append((key, encoder_state['model'][key]))
    # Remove the linear and layer norm layers to maintain compatibility
    filtered_state.pop()
    filtered_state.pop()
    filtered_state.pop()
    filtered_state.pop()
    list_translation_state = []
    for key in translation_state['model'].keys():
        list_translation_state.append((key, translation_state['model'][key]))
    for index, key in enumerate(list_translation_state):
        if key[0].startswith('encoder'):
            list_translation_state[index] = filtered_state[index]
    list_translation_state_dict = OrderedDict(list_translation_state)
    translation_state['model'] = list_translation_state_dict
    checkpoint_utils.torch_persistent_save(translation_state, modified_path)
    return
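# Hedged usage sketch (all three paths are hypothetical): splice the language
# model's encoder parameters into the translation checkpoint and write the
# modified checkpoint to disk for later warm-starting.
def _example_init_tmodel():
    init_tmodel(
        source_path="/path/to/lm/checkpoint_best.pt",
        target_path="/path/to/translation/checkpoint_best.pt",
        modified_path="/path/to/output/checkpoint_warm.pt",
    )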
def load_feature_extractor(component, checkpoint):
    if not PathManager.exists(checkpoint):
        raise IOError("Model file not found: {}".format(checkpoint))
    state = checkpoint_utils.load_checkpoint_to_cpu(checkpoint)
    component_state_dict = OrderedDict()
    component_prefix = "feature_extractor"
    for key in state["model"].keys():
        if key.startswith(component_prefix):
            # strip the "feature_extractor." prefix (prefix plus the dot)
            component_subkey = key[len(component_prefix) + 1:]
            component_state_dict[component_subkey] = state["model"][key]
    component.load_state_dict(component_state_dict, strict=True)
    return component
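# Hedged usage sketch (the model attribute and checkpoint path are
# hypothetical): restore only the convolutional feature extractor from a full
# wav2vec checkpoint, leaving the rest of the model untouched.
def _example_restore_feature_extractor(wav2vec_model):
    return load_feature_extractor(
        wav2vec_model.feature_extractor, "/path/to/wav2vec_checkpoint.pt"
    )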
def __init__(self, input_feat_per_channel,
             vggblock_config=DEFAULT_ENC_VGGBLOCK_CONFIG,
             transformer_config=DEFAULT_ENC_TRANSFORMER_CONFIG,
             encoder_output_dim=512, in_channels=1,
             transformer_context=None, transformer_sampling=None):
    super().__init__(input_feat_per_channel, vggblock_config, transformer_config,
                     encoder_output_dim, in_channels, transformer_context,
                     transformer_sampling)
    wav2vec_checkpoint = HOME + '/data/fairseq-data/wav2vec_models/checkpoint_last.pt'
    cp = checkpoint_utils.load_checkpoint_to_cpu(wav2vec_checkpoint)
    model = Wav2VecModel.build_model(cp['args'], task=None)
    model.load_state_dict(cp['model'])
    freeze_module_params(model)
    self.wav2vec_model = model
def load_models_and_criterions(filenames, data_path, arg_overrides=None, task=None, model_state=None):
    models = []
    criterions = []

    if arg_overrides is None:
        arg_overrides = {}
    arg_overrides["wer_args"] = None
    arg_overrides["data"] = data_path

    if filenames is None:
        assert model_state is not None
        filenames = [0]
    else:
        filenames = filenames.split(":")

    for filename in filenames:
        if model_state is None:
            if not os.path.exists(filename):
                raise IOError("Model file not found: {}".format(filename))
            state = checkpoint_utils.load_checkpoint_to_cpu(filename, arg_overrides)
        else:
            state = model_state

        if "cfg" in state:
            cfg = state["cfg"]
        else:
            cfg = convert_namespace_to_omegaconf(state["args"])

        if task is None:
            if hasattr(cfg.task, 'data'):
                cfg.task.data = data_path
            task = tasks.setup_task(cfg.task)

        model = task.build_model(cfg.model)
        model.load_state_dict(state["model"], strict=True)
        models.append(model)

        criterion = task.build_criterion(cfg.criterion)
        if "criterion" in state:
            criterion.load_state_dict(state["criterion"], strict=True)
        criterions.append(criterion)
    return models, criterions, task
def build_encoder(cls, args):
    _args = copy.deepcopy(args)
    if not args.adaptor_proj:  # V0 arch
        state = checkpoint_utils.load_checkpoint_to_cpu(args.w2v_path)
        if state.get("cfg") is not None:
            encoder_embed_dim = state["cfg"]._content["model"]["encoder_embed_dim"]
        elif state.get("args") is not None:
            encoder_embed_dim = state["args"].encoder_embed_dim
        else:
            raise ValueError(f"Invalid config in {args.w2v_path}")
        _args.decoder_embed_dim = encoder_embed_dim
        del state
    encoder = Wav2VecEncoderWithAdaptor(_args)
    return cls.maybe_load_pretrained(
        encoder, getattr(args, "load_pretrained_encoder_from", None)
    )
def load_pretrained_speech_text_components(cls, checkpoint, component_pairs):
    if not PathManager.exists(checkpoint):
        raise IOError("Model file not found: {}".format(checkpoint))
    state = load_checkpoint_to_cpu(checkpoint)
    for component_type, component in component_pairs:
        if isinstance(component, nn.parameter.Parameter):
            component.data.copy_(state["model"][component_type])
        else:
            component_state_dict = OrderedDict()
            for key in state["model"].keys():
                if key.startswith(component_type):
                    component_subkey = key[len(component_type) + 1:]
                    component_state_dict[component_subkey] = state["model"][key]
            component.load_state_dict(component_state_dict, strict=True)
    return state
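# Hedged usage sketch (assumes the function above is exposed as a @classmethod
# on the model class; the component names, modules, and path are hypothetical):
# each pair maps a key prefix in the checkpoint's model dict to the module or
# Parameter that should receive those weights.
def _example_load_speech_text_components(model):
    component_pairs = [
        ("encoder", model.encoder),
        ("decoder", model.decoder),
    ]
    return type(model).load_pretrained_speech_text_components(
        "/path/to/speech_text_checkpoint.pt", component_pairs
    )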
def update_args(args):
    import os
    from fairseq.checkpoint_utils import load_checkpoint_to_cpu

    bart_large_cnn_path = os.path.join(
        os.path.dirname(os.path.dirname(args.pretrained_doc_model_path)),
        'bart.large.cnn/model.pt')
    state = load_checkpoint_to_cpu(bart_large_cnn_path)
    new_args = state['args']
    no_update_args = [
        'source_lang', 'target_lang', 'task', 'data', 'save_dir',
        'update_freq', 'log_interval', 'dataset_impl'
    ]
    for k, v in new_args.__dict__.items():
        if k not in no_update_args and not k.startswith('distributed'):
            if getattr(args, k, None) != v:
                print('| WARNING: update {} in args from {} to {}'.format(
                    k, getattr(args, k, None), v))
                setattr(args, k, v)
def main(args):
    import_user_module(args)
    ckpt_path = args.model_root + '/checkpoints/checkpoint_best.pt'
    state = load_checkpoint_to_cpu(ckpt_path)
    enc_emb = state['model']['encoder.embed_tokens.weight']
    enc_emb_output_path = args.model_root + '/embeddings/encoder.indomain'
    output_trained_embeddings_to_file(enc_emb, args.srcdict, enc_emb_output_path)
    if args.tgtdict:
        dec_emb = state['model']['decoder.embed_tokens.weight']
        dec_emb_output_path = args.model_root + '/embeddings/decoder.indomain'
        output_trained_embeddings_to_file(dec_emb, args.tgtdict, dec_emb_output_path)
def load_models_and_criterions(filenames, data_path, arg_overrides=None, task=None, model_state=None):
    models = []
    criterions = []

    if arg_overrides is None:
        arg_overrides = {}
    arg_overrides['wer_args'] = None
    arg_overrides['data'] = data_path

    if filenames is None:
        assert model_state is not None
        filenames = [0]
    else:
        filenames = filenames.split(":")

    for filename in filenames:
        if model_state is None:
            if not os.path.exists(filename):
                raise IOError("Model file not found: {}".format(filename))
            state = checkpoint_utils.load_checkpoint_to_cpu(filename, arg_overrides)
        else:
            state = model_state

        args = state["args"]
        if task is None:
            task = tasks.setup_task(args)

        model = task.build_model(args)
        model.load_state_dict(state["model"], strict=True)
        models.append(model)

        criterion = task.build_criterion(args)
        if "criterion" in state:
            criterion.load_state_dict(state["criterion"], strict=True)
        criterions.append(criterion)
    return models, criterions, args
def run_maybe_distributed_reptile(meta_learning_args, downstream_args, load_meta_tasks_fn, fine_tune_args):
    seed = downstream_args.seed
    if torch.cuda.is_available() and not meta_learning_args.cpu:
        torch.cuda.set_device(meta_learning_args.device_id)
    torch.manual_seed(seed)

    meta_train_tasks, meta_dev_tasks = load_meta_tasks_fn()

    # build model and criterion
    print('| training on {} GPUs'.format(meta_learning_args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        meta_learning_args.max_tokens,
        meta_learning_args.max_sentences,
    ))

    # Reptile training loop
    print('setup for meta-learning task')
    # Meta-learning task list
    meta_learning_task = tasks.setup_task(
        args=meta_learning_args,
        meta_train_tasks=meta_train_tasks,
        meta_dev_tasks=meta_dev_tasks,
        meta_test_tasks=None,
    )
    print('building meta-learning model...')
    model = meta_learning_task.build_model(meta_learning_args)  # Transformer RAW
    state = load_checkpoint_to_cpu(meta_learning_args.restore_file)
    model.load_state_dict(state['model'], strict=False)
    meta_learning_criterion = meta_learning_task.build_criterion(meta_learning_args)  # MAML, FoMAML
    print(model)
    print('| model {}, criterion {}'.format(
        meta_learning_args.arch, meta_learning_criterion.__class__.__name__))
    print('| num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))
    reptile_function = build_reptile_function(
        is_baseline=meta_learning_args.baseline,
        is_curriculum=meta_learning_args.is_curriculum,
    )
    reptile_function(
        model=model,
        meta_learning_task=meta_learning_task,
        meta_learning_args=meta_learning_args,
        meta_learning_criterion=meta_learning_criterion,
        fine_tune_args=fine_tune_args,
    )