def _upgrade_state_dict(state): """Helper for upgrading old model checkpoints.""" from fairseq import models, registry, tasks # add optimizer_history if 'optimizer_history' not in state: state['optimizer_history'] = [ { 'criterion_name': 'CrossEntropyCriterion', 'best_loss': state['best_loss'], }, ] state['last_optimizer_state'] = state['optimizer'] del state['optimizer'] del state['best_loss'] # move extra_state into sub-dictionary if 'epoch' in state and 'extra_state' not in state: state['extra_state'] = { 'epoch': state['epoch'], 'batch_offset': state['batch_offset'], 'val_loss': state['val_loss'], } del state['epoch'] del state['batch_offset'] del state['val_loss'] # reduce optimizer history's memory usage (only keep the last state) if 'optimizer' in state['optimizer_history'][-1]: state['last_optimizer_state'] = state['optimizer_history'][-1]['optimizer'] for optim_hist in state['optimizer_history']: del optim_hist['optimizer'] # record the optimizer class name if 'optimizer_name' not in state['optimizer_history'][-1]: state['optimizer_history'][-1]['optimizer_name'] = 'FairseqNAG' # move best_loss into lr_scheduler_state if 'lr_scheduler_state' not in state['optimizer_history'][-1]: state['optimizer_history'][-1]['lr_scheduler_state'] = { 'best': state['optimizer_history'][-1]['best_loss'], } del state['optimizer_history'][-1]['best_loss'] # keep track of number of updates if 'num_updates' not in state['optimizer_history'][-1]: state['optimizer_history'][-1]['num_updates'] = 0 # old model checkpoints may not have separate source/target positions if hasattr(state['args'], 'max_positions') and not hasattr(state['args'], 'max_source_positions'): state['args'].max_source_positions = state['args'].max_positions state['args'].max_target_positions = state['args'].max_positions # use stateful training data iterator if 'train_iterator' not in state['extra_state']: state['extra_state']['train_iterator'] = { 'epoch': state['extra_state']['epoch'], 'iterations_in_epoch': state['extra_state'].get('batch_offset', 0), } # default to translation task if not hasattr(state['args'], 'task'): state['args'].task = 'translation' # set any missing default values in the task, model or other registries registry.set_defaults(state['args'], tasks.TASK_REGISTRY[state['args'].task]) registry.set_defaults(state['args'], models.ARCH_MODEL_REGISTRY[state['args'].arch]) for registry_name, REGISTRY in registry.REGISTRIES.items(): choice = getattr(state['args'], registry_name, None) if choice is not None: cls = REGISTRY['registry'][choice] registry.set_defaults(state['args'], cls) return state
def _upgrade_state_dict(state): """Helper for upgrading old model checkpoints.""" from fairseq import models, registry, tasks # add optimizer_history if "optimizer_history" not in state: state["optimizer_history"] = [{ "criterion_name": "CrossEntropyCriterion", "best_loss": state["best_loss"] }] state["last_optimizer_state"] = state["optimizer"] del state["optimizer"] del state["best_loss"] # move extra_state into sub-dictionary if "epoch" in state and "extra_state" not in state: state["extra_state"] = { "epoch": state["epoch"], "batch_offset": state["batch_offset"], "val_loss": state["val_loss"], } del state["epoch"] del state["batch_offset"] del state["val_loss"] # reduce optimizer history's memory usage (only keep the last state) if "optimizer" in state["optimizer_history"][-1]: state["last_optimizer_state"] = state["optimizer_history"][-1][ "optimizer"] for optim_hist in state["optimizer_history"]: del optim_hist["optimizer"] # record the optimizer class name if "optimizer_name" not in state["optimizer_history"][-1]: state["optimizer_history"][-1]["optimizer_name"] = "FairseqNAG" # move best_loss into lr_scheduler_state if "lr_scheduler_state" not in state["optimizer_history"][-1]: state["optimizer_history"][-1]["lr_scheduler_state"] = { "best": state["optimizer_history"][-1]["best_loss"] } del state["optimizer_history"][-1]["best_loss"] # keep track of number of updates if "num_updates" not in state["optimizer_history"][-1]: state["optimizer_history"][-1]["num_updates"] = 0 # old model checkpoints may not have separate source/target positions if hasattr(state["args"], "max_positions") and not hasattr( state["args"], "max_source_positions"): state["args"].max_source_positions = state["args"].max_positions state["args"].max_target_positions = state["args"].max_positions # use stateful training data iterator if "train_iterator" not in state["extra_state"]: state["extra_state"]["train_iterator"] = { "epoch": state["extra_state"]["epoch"], "iterations_in_epoch": state["extra_state"].get("batch_offset", 0), } # default to translation task if not hasattr(state["args"], "task"): state["args"].task = "translation" # --raw-text and --lazy-load are deprecated if getattr(state["args"], "raw_text", False): state["args"].dataset_impl = "raw" elif getattr(state["args"], "lazy_load", False): state["args"].dataset_impl = "lazy" # epochs start at 1 if state["extra_state"]["train_iterator"] is not None: state["extra_state"]["train_iterator"]["epoch"] = max( state["extra_state"]["train_iterator"].get("epoch", 1), 1, ) # set any missing default values in the task, model or other registries registry.set_defaults(state["args"], tasks.TASK_REGISTRY[state["args"].task]) registry.set_defaults(state["args"], models.ARCH_MODEL_REGISTRY[state["args"].arch]) for registry_name, REGISTRY in registry.REGISTRIES.items(): choice = getattr(state["args"], registry_name, None) if choice is not None: cls = REGISTRY["registry"][choice] registry.set_defaults(state["args"], cls) return state