def __init__(self, args, tgt_dict=None):
    self.apply_mask = args.apply_mask

    arg_overrides = {
        "dropout": args.dropout,
        "activation_dropout": args.activation_dropout,
        "dropout_input": args.dropout_input,
        "attention_dropout": args.attention_dropout,
        "mask_length": args.mask_length,
        "mask_prob": args.mask_prob,
        "mask_selection": args.mask_selection,
        "mask_other": args.mask_other,
        "no_mask_overlap": args.no_mask_overlap,
        "mask_channel_length": args.mask_channel_length,
        "mask_channel_prob": args.mask_channel_prob,
        "mask_channel_selection": args.mask_channel_selection,
        "mask_channel_other": args.mask_channel_other,
        "no_mask_channel_overlap": args.no_mask_channel_overlap,
        "encoder_layerdrop": args.layerdrop,
        "feature_grad_mult": args.feature_grad_mult,
    }

    if getattr(args, "w2v_args", None) is None:
        # NOTE: the pretrained wav2vec 2.0 checkpoint path is hard-coded here
        args.w2v_path = '../libri/wav2vec2_small.pt'
        print('load Wav2VecEncoder from {}'.format(args.w2v_path))
        state = checkpoint_utils.load_checkpoint_to_cpu(args.w2v_path, arg_overrides)
        w2v_args = state["args"]
        # w2v_path points at the pretrained model, which should not itself carry a w2v_path
        assert getattr(w2v_args, "w2v_path", None) is None
    else:
        state = None
        w2v_args = args.w2v_args

    assert args.normalize == w2v_args.normalize, \
        'Fine-tuning works best when data normalization is the same'

    w2v_args.data = args.data
    task = tasks.setup_task(w2v_args)
    model = task.build_model(w2v_args)

    if state is not None and not args.no_pretrained_weights:
        print('restore Wav2VecEncoder from {}'.format(args.w2v_path))
        model.load_state_dict(state["model"], strict=True)

    model.remove_pretraining_modules()

    super().__init__(task.source_dictionary)

    d = w2v_args.encoder_embed_dim
    self.d = d
    self.w2v_model = model
    self.final_dropout = nn.Dropout(args.final_dropout)
    self.freeze_finetune_updates = args.freeze_finetune_updates
    self.num_updates = 0

    if tgt_dict is not None:
        self.proj = Linear(d, len(tgt_dict))
    else:
        self.proj = None
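# Usage sketch (hedged): assuming this __init__ belongs to a Wav2VecEncoder-style
# FairseqEncoder (as the print statements suggest) and that `args` carries the
# fine-tuning hyperparameters referenced above, construction would look roughly like:
#
#   encoder = Wav2VecEncoder(args, tgt_dict=task.target_dictionary)
#
# The projection to `len(tgt_dict)` outputs is only created when a target
# dictionary is supplied; otherwise the encoder returns raw features.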
def load_generator(cls, filename, args, arg_overrides=None, task=None):
    if arg_overrides is None:
        arg_overrides = {}
    if not os.path.exists(filename):
        raise IOError("Model file not found: {}".format(filename))
    state = checkpoint_utils.load_checkpoint_to_cpu(filename, arg_overrides)
    args = state["args"]
    model = task.build_model(args)
    model.load_state_dict(state["model"], strict=True)
    return model, args
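# Usage sketch (hedged): load_generator rebuilds the generator model from a saved
# checkpoint and returns it together with the args stored in that checkpoint. The
# class name, path, and override below are hypothetical placeholders:
#
#   generator, gen_args = SomeGeneratorModel.load_generator(
#       "checkpoints/generator.pt", args,
#       arg_overrides={"data": data_path}, task=task,
#   )
#   generator.eval()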
def load_models_and_criterions(filenames, data_path, arg_overrides=None, task=None, model_state=None):
    models = []
    criterions = []

    if arg_overrides is None:
        arg_overrides = {}
    # always override the WER args and data path with the current run's settings
    arg_overrides['wer_args'] = None
    arg_overrides['data'] = data_path

    if filenames is None:
        # a preloaded state dict was passed in directly; use a single dummy entry
        assert model_state is not None
        filenames = [0]
    else:
        filenames = filenames.split(":")

    for filename in filenames:
        if model_state is None:
            if not os.path.exists(filename):
                raise IOError("Model file not found: {}".format(filename))
            state = checkpoint_utils.load_checkpoint_to_cpu(filename, arg_overrides)
        else:
            state = model_state

        args = state["args"]
        if task is None:
            task = tasks.setup_task(args)

        model = task.build_model(args)
        print('model restore state from {}'.format(filename))
        model.load_state_dict(state["model"], strict=True)
        models.append(model)

        criterion = task.build_criterion(args)
        if "criterion" in state:
            criterion.load_state_dict(state["criterion"], strict=True)
        criterions.append(criterion)

    return models, criterions, args
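# Usage sketch (hedged): multiple checkpoints are passed as a single
# colon-separated string; the paths and override values below are hypothetical:
#
#   models, criterions, model_args = load_models_and_criterions(
#       "ckpt_a.pt:ckpt_b.pt",
#       data_path="/path/to/data",
#       arg_overrides={"dropout": 0.0},
#   )
#   for m in models:
#       m.eval()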
def load_checkpoint(
    self,
    filename,
    reset_optimizer=False,
    reset_lr_scheduler=False,
    optimizer_overrides=None,
    reset_meters=False,
):
    """Load all training state from a checkpoint file."""
    extra_state, self._optim_history, last_optim_state = None, [], None

    bexists = PathManager.isfile(filename)
    if bexists:
        state = checkpoint_utils.load_checkpoint_to_cpu(filename)

        # load model parameters
        try:
            self.get_model().load_state_dict(
                state["model"], strict=True, args=self.args
            )
            if utils.has_parameters(self.get_criterion()):
                self.get_criterion().load_state_dict(
                    state["criterion"], strict=True
                )
        except Exception:
            raise Exception(
                "Cannot load model parameters from checkpoint {}; "
                "please ensure that the architectures match.".format(filename)
            )

        extra_state = state["extra_state"]
        self._optim_history = state["optimizer_history"]
        last_optim_state = state.get("last_optimizer_state", None)

    if last_optim_state is not None and not reset_optimizer:
        # rebuild optimizer after loading model, since params may have changed
        self._build_optimizer()

        # only reload optimizer and lr_scheduler if they match
        last_optim = self._optim_history[-1]
        assert (
            last_optim["criterion_name"] == self.get_criterion().__class__.__name__
        ), "Criterion does not match; please reset the optimizer (--reset-optimizer)."
        assert (
            last_optim["optimizer_name"] == self.optimizer.__class__.__name__
        ), "Optimizer does not match; please reset the optimizer (--reset-optimizer)."

        if not reset_lr_scheduler:
            self.lr_scheduler.load_state_dict(last_optim["lr_scheduler_state"])
        self.optimizer.load_state_dict(last_optim_state, optimizer_overrides)

        self.set_num_updates(last_optim["num_updates"])

    if extra_state is not None:
        epoch = extra_state["train_iterator"]["epoch"]
        logger.info(
            "loaded checkpoint {} (epoch {} @ {} updates)".format(
                filename, epoch, self.get_num_updates()
            )
        )

        if "previous_training_time" in extra_state:
            self._previous_training_time = extra_state["previous_training_time"]
            self._start_time = time.time()

        self.lr_step(epoch)

        if "metrics" in extra_state and not reset_meters:
            metrics.load_state_dict(extra_state["metrics"])

            # reset TimeMeters, since their start times don't make sense anymore
            for meter in metrics.get_meters("default"):
                if isinstance(meter, meters.TimeMeter):
                    meter.reset()
    else:
        logger.info("no existing checkpoint found {}".format(filename))

    return extra_state
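# Usage sketch (hedged): within a fairseq-style training loop, resuming from a
# checkpoint via this Trainer method would look roughly like this (the path is
# a placeholder):
#
#   extra_state = trainer.load_checkpoint(
#       "checkpoints/checkpoint_last.pt",
#       reset_optimizer=False,
#       reset_lr_scheduler=False,
#   )
#   if extra_state is not None:
#       epoch = extra_state["train_iterator"]["epoch"]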