# Imports assumed by the snippets below; DecodingConfig is defined alongside
# InferConfig in the inference script itself and is therefore not imported here.
import ast
import logging
import os
from dataclasses import dataclass, field
from typing import Any

from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.configs import (
    CheckpointConfig, CommonConfig, CommonEvalConfig, DatasetConfig,
    DistributedTrainingConfig, GenerationConfig,
)
from fairseq.dataclass.utils import gen_parser_from_dataclass
from fairseq.file_io import PathManager

logger = logging.getLogger(__name__)


@dataclass
class InferConfig(FairseqDataclass):
    task: Any = None
    decoding: DecodingConfig = DecodingConfig()
    common: CommonConfig = CommonConfig()
    common_eval: CommonEvalConfig = CommonEvalConfig()
    checkpoint: CheckpointConfig = CheckpointConfig()
    generation: GenerationConfig = GenerationConfig()
    distributed_training: DistributedTrainingConfig = DistributedTrainingConfig()
    dataset: DatasetConfig = DatasetConfig()
@dataclass
class InferConfig(FairseqDataclass):
    task: Any = None
    decoding: DecodingConfig = DecodingConfig()
    common: CommonConfig = CommonConfig()
    common_eval: CommonEvalConfig = CommonEvalConfig()
    checkpoint: CheckpointConfig = CheckpointConfig()
    distributed_training: DistributedTrainingConfig = DistributedTrainingConfig()
    dataset: DatasetConfig = DatasetConfig()
    is_ax: bool = field(
        default=False,
        metadata={
            "help": "if true, assumes we are using ax for tuning and returns a tuple for ax to consume"
        },
    )
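# Usage sketch (not from the source): in fairseq these nested configs are
# normally populated by Hydra/OmegaConf from YAML plus command-line overrides,
# but the dataclass can also be built directly, as below. The checkpoint path
# is a placeholder; CommonEvalConfig.path and DatasetConfig.gen_subset are
# existing fairseq fields, while _example_build_infer_config is a hypothetical
# helper used only for illustration.
def _example_build_infer_config() -> "InferConfig":
    return InferConfig(
        common_eval=CommonEvalConfig(path="/path/to/checkpoint.pt"),  # model to evaluate
        dataset=DatasetConfig(gen_subset="test"),                     # split to decode
    )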
def add_checkpoint_args(parser):
    group = parser.add_argument_group("checkpoint")
    # fmt: off
    gen_parser_from_dataclass(group, CheckpointConfig())
    # fmt: on
    return group
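# Usage sketch (not from the source): add_checkpoint_args attaches every
# CheckpointConfig field to a plain argparse parser via gen_parser_from_dataclass.
# The flag spellings below (--save-dir, --keep-last-epochs) mirror the
# CheckpointConfig field names; treat the exact spellings as an assumption.
def _example_parse_checkpoint_args():
    import argparse

    parser = argparse.ArgumentParser(description="checkpoint options demo")
    add_checkpoint_args(parser)
    return parser.parse_args(["--save-dir", "checkpoints", "--keep-last-epochs", "5"])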
def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args):
    """
    Load a checkpoint and restore the training iterator.

    *passthrough_args* will be passed through to
    ``trainer.get_train_iterator``.
    """
    reset_optimizer = cfg.reset_optimizer
    reset_lr_scheduler = cfg.reset_lr_scheduler
    optimizer_overrides = ast.literal_eval(cfg.optimizer_overrides)
    reset_meters = cfg.reset_meters
    reset_dataloader = cfg.reset_dataloader

    if cfg.finetune_from_model is not None and (
        reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader
    ):
        raise ValueError(
            "--finetune-from-model cannot be set together with --reset-optimizer, "
            "--reset-lr-scheduler, --reset-meters or --reset-dataloader"
        )

    suffix = trainer.checkpoint_suffix
    if (
        cfg.restore_file == "checkpoint_last.pt"
    ):  # default value of restore_file is 'checkpoint_last.pt'
        checkpoint_path = os.path.join(
            cfg.save_dir, "checkpoint_last{}.pt".format(suffix)
        )
        first_launch = not PathManager.exists(checkpoint_path)
        if first_launch and cfg.get("continue_once", None) is not None:
            checkpoint_path = cfg.continue_once
        elif cfg.finetune_from_model is not None and first_launch:
            # if there is no last checkpoint to restore, start fine-tuning from the
            # pretrained model; otherwise fall back to the usual logic and restart
            # from the last checkpoint
            if PathManager.exists(cfg.finetune_from_model):
                checkpoint_path = cfg.finetune_from_model
                reset_optimizer = True
                reset_lr_scheduler = True
                reset_meters = True
                reset_dataloader = True
                logger.info(
                    f"loading pretrained model from {checkpoint_path}: "
                    "optimizer, lr scheduler, meters, dataloader will be reset"
                )
            else:
                raise ValueError(
                    f"--finetune-from-model {cfg.finetune_from_model} does not exist"
                )
    elif suffix is not None:
        checkpoint_path = cfg.restore_file.replace(".pt", suffix + ".pt")
    else:
        checkpoint_path = cfg.restore_file

    if cfg.restore_file != "checkpoint_last.pt" and cfg.finetune_from_model:
        raise ValueError(
            "--finetune-from-model and --restore-file (non-default value) "
            "cannot be specified together: " + str(cfg)
        )

    extra_state = trainer.load_checkpoint(
        checkpoint_path,
        reset_optimizer,
        reset_lr_scheduler,
        optimizer_overrides,
        reset_meters=reset_meters,
    )

    if (
        extra_state is not None
        and "best" in extra_state
        and not reset_optimizer
        and not reset_meters
    ):
        save_checkpoint.best = extra_state["best"]

    if extra_state is not None and not reset_dataloader:
        # restore iterator from checkpoint
        itr_state = extra_state["train_iterator"]
        epoch_itr = trainer.get_train_iterator(
            epoch=itr_state["epoch"], load_dataset=True, **passthrough_args
        )
        epoch_itr.load_state_dict(itr_state)
    else:
        epoch_itr = trainer.get_train_iterator(
            epoch=1, load_dataset=True, **passthrough_args
        )

    trainer.lr_step(epoch_itr.epoch)

    return extra_state, epoch_itr
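# Usage sketch (not from the source): a minimal resume loop modeled on how
# fairseq's train entry point consumes load_checkpoint. `trainer` is assumed to
# be a fairseq Trainer, `cfg` a full config whose `checkpoint` attribute is a
# CheckpointConfig, and `train_one_epoch` a caller-supplied callable that stands
# in for the real per-epoch training logic.
def _example_resume_training(cfg, trainer, train_one_epoch, max_epoch=10):
    extra_state, epoch_itr = load_checkpoint(cfg.checkpoint, trainer)
    while epoch_itr.next_epoch_idx <= max_epoch:
        train_one_epoch(cfg, trainer, epoch_itr)
        # advance to the next epoch's iterator, reloading the dataset as needed
        epoch_itr = trainer.get_train_iterator(
            epoch_itr.next_epoch_idx, load_dataset=True
        )
    return extra_state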