def load_model_ensemble_and_task(
    filenames, arg_overrides=None, task=None, strict=True, suffix="", num_shards=1
):
    from fairseq import tasks

    assert not (
        strict and num_shards > 1
    ), "Cannot load state dict with strict=True and checkpoint shards > 1"
    ensemble = []
    for filename in filenames:
        orig_filename = filename
        for shard_idx in range(num_shards):
            if num_shards == 1:
                filename = filename.replace(".pt", suffix + ".pt")
            else:
                filename = orig_filename[:-3] + f"_part{shard_idx}.pt"
            if not PathManager.exists(filename):
                raise IOError("Model file not found: {}".format(filename))
            state = load_checkpoint_to_cpu(filename, arg_overrides)
            if shard_idx == 0:
                args = state["args"]
                if task is None:
                    task = tasks.setup_task(args)

                # build model for ensemble
                model = task.build_model(args)
            model.load_state_dict(state["model"], strict=strict, args=args)
        ensemble.append(model)
    return ensemble, args, task
def load_pretrained_component_from_model(
    component: Union[FairseqEncoder, FairseqDecoder], checkpoint: str
):
    """
    Load a pretrained FairseqEncoder or FairseqDecoder from checkpoint into the
    provided `component` object. If state_dict fails to load, there may be a
    mismatch in the architecture of the corresponding `component` found in the
    `checkpoint` file.
    """
    if not PathManager.exists(checkpoint):
        raise IOError("Model file not found: {}".format(checkpoint))
    state = load_checkpoint_to_cpu(checkpoint)
    if isinstance(component, FairseqEncoder):
        component_type = "encoder"
    elif isinstance(component, FairseqDecoder):
        component_type = "decoder"
    else:
        raise ValueError(
            "component to load must be either a FairseqEncoder or "
            "FairseqDecoder. Loading other component types is not supported."
        )
    component_state_dict = OrderedDict()
    for key in state["model"].keys():
        if key.startswith(component_type):
            # encoder.input_layers.0.0.weight --> input_layers.0.0.weight
            component_subkey = key[len(component_type) + 1 :]
            component_state_dict[component_subkey] = state["model"][key]
    component.load_state_dict(component_state_dict, strict=True)
    return component
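# Hedged usage sketch (not part of the original module): loading a pretrained
# encoder into a freshly built model. The checkpoint path and the `model`
# object are hypothetical; only `load_pretrained_component_from_model` above
# comes from this module.
def _example_load_pretrained_encoder(model):
    # model.encoder is assumed to be a FairseqEncoder matching the checkpoint
    return load_pretrained_component_from_model(
        component=model.encoder,
        checkpoint="/path/to/pretrained_checkpoint.pt",  # hypothetical path
    )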
def load_model_ensemble_and_task(
    filenames, arg_overrides=None, task=None, strict=True, suffix=""
):
    from fairseq import tasks

    ensemble = []
    for filename in filenames:
        filename = filename.replace(".pt", suffix + ".pt")
        if not PathManager.exists(filename):
            raise IOError("Model file not found: {}".format(filename))
        state = load_checkpoint_to_cpu(filename, arg_overrides)
        args = state["args"]
        logger.info("[load_model_ensemble_and_task] data: {}".format(args.data))
        if task is None:
            task = tasks.setup_task(args)

        # build model for ensemble
        model = task.build_model(args)
        model.load_state_dict(state["model"], strict=strict, args=args)
        ensemble.append(model)
    return ensemble, args, task
def load_model_ensemble_and_task(
    filenames, arg_overrides=None, task=None, strict=True, suffix=""
):
    from fairseq import tasks

    ensemble = []
    for filename in filenames:
        filename = filename.replace(".pt", suffix + ".pt")
        if not PathManager.exists(filename):
            raise IOError("Model file not found: {}".format(filename))
        state = load_checkpoint_to_cpu(filename, arg_overrides)
        args = state["args"]
        if task is None:
            task = tasks.setup_task(args)

        # build model for ensemble
        model = task.build_model(args)
        states = state["model"]
        if hasattr(args, "mixout") and args.mixout > 0:
            # strip the mixout wrapper: rename "X._params_learned" keys back to "X"
            for k, v in list(states.items()):
                if "._params_learned" in k:
                    del states[k]
                    states[k.replace("._params_learned", "")] = v
        model.load_state_dict(states, strict=strict, args=args)
        ensemble.append(model)
    return ensemble, args, task
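# Minimal, self-contained sketch of the mixout key renaming above. Only the
# "._params_learned" suffix convention comes from the function; the key names
# in this example are made up.
def _strip_mixout_suffix_example():
    from collections import OrderedDict

    states = OrderedDict([("w._params_learned", 1.0), ("b", 0.0)])
    for k, v in list(states.items()):
        if "._params_learned" in k:
            del states[k]
            states[k.replace("._params_learned", "")] = v
    assert list(states.keys()) == ["b", "w"]  # renamed key is re-appended at the end
    return states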
def load_bert_state(model, checkpoint):
    print("Load pretrained data augmentation checkpoint (BERT)")
    if not PathManager.exists(checkpoint):
        raise IOError("Model file not found: {}".format(checkpoint))
    from torch.serialization import default_restore_location

    state = torch.load(
        checkpoint, map_location=lambda s, l: default_restore_location(s, "cpu")
    )

    def upgrade(obj):
        # recursively strip the leading "encoder." prefix from state dict keys
        if isinstance(obj, OrderedDict):
            oldkeys = list(obj.keys())
            for k in oldkeys:
                if k.startswith("encoder") and k != "encoder":
                    newkey = k.split(".", 1)[1]
                else:
                    newkey = k
                obj[newkey] = upgrade(obj[k])
                if k.startswith("encoder"):
                    del obj[k]
        # return obj in all cases so nested OrderedDicts are not replaced by None
        return obj

    upgrade(state["model"])
    try:
        model.load_state_dict(state["model"], strict=True)
    except Exception:
        raise Exception(
            "Cannot load model parameters from pretrained augmentation model checkpoint, "
            "please ensure that the architectures match"
        )
    return True
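# Standalone equivalent of the nested `upgrade` above for the common case of a
# flat state dict with tensor values. Key names here are illustrative; the
# "encoder." convention is the only part taken from the function above.
def _strip_encoder_prefix(state_dict):
    from collections import OrderedDict

    out = OrderedDict()
    for k, v in state_dict.items():
        out[k.split(".", 1)[1] if k.startswith("encoder.") else k] = v
    return out

# _strip_encoder_prefix({"encoder.layer.weight": 0, "lm_head.bias": 1})
#   -> OrderedDict([("layer.weight", 0), ("lm_head.bias", 1)])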
def load_model_ensemble_and_task(
    filenames,
    arg_overrides: Optional[Dict[str, Any]] = None,
    task=None,
    strict=True,
    suffix="",
    num_shards=1,
    state=None,
):
    assert state is None or len(filenames) == 1

    from fairseq import tasks

    assert not (
        strict and num_shards > 1
    ), "Cannot load state dict with strict=True and checkpoint shards > 1"
    ensemble = []
    cfg = None
    for filename in filenames:
        orig_filename = filename
        assert num_shards > 0
        for shard_idx in range(num_shards):
            if num_shards == 1:
                filename = filename.replace(".pt", suffix + ".pt")
            else:
                filename = orig_filename[:-3] + f"_part{shard_idx}.pt"
            if not PathManager.exists(filename):
                raise IOError("Model file not found: {}".format(filename))
            if state is None:
                state = load_checkpoint_to_cpu(filename, arg_overrides)
            if "args" in state and state["args"] is not None:
                cfg = convert_namespace_to_omegaconf(state["args"])
            elif "cfg" in state and state["cfg"] is not None:
                cfg = state["cfg"]
            else:
                raise RuntimeError(
                    f"Neither args nor cfg exist in state keys = {state.keys()}"
                )

            if task is None:
                task = tasks.setup_task(cfg.task)

            if "task_state" in state:
                task.load_state_dict(state["task_state"])

            # build model for ensemble
            model = task.build_model(cfg.model)

            model.load_state_dict(state["model"], strict=strict, model_cfg=cfg.model)

            # reset state so it gets loaded for the next model in ensemble
            state = None

        ensemble.append(model)
    return ensemble, cfg, task
def load_xlmt_model_ensemble(
    filenames,
    arg_overrides=None,
    strict=True,
    suffix="",
    num_shards=1,
    state=None,
    src_dict=None,
    tgt_dict=None,
):
    assert state is None or len(filenames) == 1

    from fairseq.models.xlmt_decoder_variant import XLMTDecoderVariantModel

    assert not (
        strict and num_shards > 1
    ), "Cannot load state dict with strict=True and checkpoint shards > 1"
    ensemble = []
    cfg = None
    for filename in filenames:
        orig_filename = filename
        assert num_shards > 0
        for shard_idx in range(num_shards):
            if num_shards == 1:
                filename = filename.replace(".pt", suffix + ".pt")
            else:
                filename = orig_filename[:-3] + f"_part{shard_idx}.pt"
            if not PathManager.exists(filename):
                raise IOError("Model file not found: {}".format(filename))
            if state is None:
                state = load_checkpoint_to_cpu(filename, arg_overrides)
            if "args" in state and state["args"] is not None:
                cfg = convert_namespace_to_omegaconf(state["args"])
            elif "cfg" in state and state["cfg"] is not None:
                cfg = state["cfg"]
            else:
                raise RuntimeError(
                    f"Neither args nor cfg exist in state keys = {state.keys()}"
                )

            # build model for ensemble
            model = XLMTDecoderVariantModel.build_model_without_task(
                cfg.model, src_dict, tgt_dict
            )
            state = expand_embedding_matrix(state, model)
            model.load_state_dict(state["model"], strict=strict, model_cfg=cfg.model)

            # reset state so it gets loaded for the next model in ensemble
            state = None

        ensemble.append(model)
    return ensemble, cfg
def get_maybe_sharded_checkpoint_filename(
    filename: str, suffix: str, shard_idx: int, num_shards: int
) -> str:
    orig_filename = filename
    filename = filename.replace(".pt", suffix + ".pt")
    fsdp_filename = filename[:-3] + f"-shard{shard_idx}.pt"
    model_parallel_filename = orig_filename[:-3] + f"_part{shard_idx}.pt"
    if PathManager.exists(fsdp_filename):
        return fsdp_filename
    elif num_shards > 1:
        return model_parallel_filename
    else:
        return filename
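# Self-contained check of the three filename shapes resolved above; the paths
# are illustrative only and nothing is read from disk.
def _example_shard_filenames():
    filename, suffix, shard_idx = "model.pt", "-rank0", 2  # hypothetical values
    suffixed = filename.replace(".pt", suffix + ".pt")
    assert suffixed == "model-rank0.pt"  # unsharded (possibly per-rank) checkpoint
    assert suffixed[:-3] + f"-shard{shard_idx}.pt" == "model-rank0-shard2.pt"  # FSDP
    assert filename[:-3] + f"_part{shard_idx}.pt" == "model_part2.pt"  # model parallel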
def load_feature_extractor(component, checkpoint):
    if not PathManager.exists(checkpoint):
        raise IOError("Model file not found: {}".format(checkpoint))
    state = checkpoint_utils.load_checkpoint_to_cpu(checkpoint)
    component_state_dict = OrderedDict()
    component_prefix = "feature_extractor"
    for key in state["model"].keys():
        if key.startswith(component_prefix):
            # feature_extractor.conv_layers.0.weight --> conv_layers.0.weight
            component_subkey = key[len(component_prefix) + 1 :]
            component_state_dict[component_subkey] = state["model"][key]
    component.load_state_dict(component_state_dict, strict=True)
    return component
def load_pretrained_speech_text_components(cls, checkpoint, component_pairs):
    if not PathManager.exists(checkpoint):
        raise IOError("Model file not found: {}".format(checkpoint))
    state = load_checkpoint_to_cpu(checkpoint)
    for component_type, component in component_pairs:
        if isinstance(component, nn.parameter.Parameter):
            component.data.copy_(state["model"][component_type])
        else:
            component_state_dict = OrderedDict()
            for key in state["model"].keys():
                if key.startswith(component_type):
                    component_subkey = key[len(component_type) + 1 :]
                    component_state_dict[component_subkey] = state["model"][key]
            component.load_state_dict(component_state_dict, strict=True)
    return state
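# Hedged usage sketch: `component_pairs` pairs checkpoint key prefixes with
# modules, or full key names with single Parameters. The model attributes and
# checkpoint path below are hypothetical.
def _example_component_pairs(cls, model):
    return load_pretrained_speech_text_components(
        cls,
        "/path/to/checkpoint.pt",  # hypothetical path
        component_pairs=(
            ("encoder", model.encoder),  # nn.Module: matches all "encoder.*" keys
            ("decoder.embed_tokens.weight", model.decoder.embed_tokens.weight),  # Parameter: exact key
        ),
    )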
def load_model_ensemble_and_task(filenames, arg_overrides=None, task=None):
    from fairseq import tasks

    ensemble = []
    for filename in filenames:
        if not PathManager.exists(filename):
            raise IOError("Model file not found: {}".format(filename))
        state = load_checkpoint_to_cpu(filename, arg_overrides)
        args = state["args"]
        if task is None:
            task = tasks.setup_task(args)

        # build model for ensemble
        model = task.build_model(args)
        model.load_state_dict(state["model"], strict=True, args=args)
        ensemble.append(model)
    return ensemble, args, task
def exists(prefix_path):
    return (
        PathManager.exists(indexed_dataset.index_file_path(prefix_path))
        and PathManager.exists(indexed_dataset.data_file_path(prefix_path))
        and PathManager.exists(vocab_file_path(prefix_path))
    )
def exists(path):
    return PathManager.exists(index_file_path(path)) and PathManager.exists(
        data_file_path(path)
    )
def exists(path):
    return PathManager.exists(path)
def load_model_ensemble_and_task(
    filenames,
    arg_overrides: Optional[Dict[str, Any]] = None,
    task=None,
    strict=True,
    suffix="",
    num_shards=1,
    state=None,
):
    assert state is None or len(filenames) == 1

    from fairseq import tasks

    assert not (
        strict and num_shards > 1
    ), "Cannot load state dict with strict=True and checkpoint shards > 1"
    ensemble = []
    cfg = None
    for filename in filenames:
        orig_filename = filename
        model_shard_state = {"shard_weights": [], "shard_metadata": []}
        assert num_shards > 0
        st = time.time()
        for shard_idx in range(num_shards):
            filename = get_maybe_sharded_checkpoint_filename(
                orig_filename, suffix, shard_idx, num_shards
            )

            if not PathManager.exists(filename):
                raise IOError("Model file not found: {}".format(filename))
            if state is None:
                state = load_checkpoint_to_cpu(filename, arg_overrides)
            if "args" in state and state["args"] is not None:
                cfg = convert_namespace_to_omegaconf(state["args"])
            elif "cfg" in state and state["cfg"] is not None:
                cfg = state["cfg"]
            else:
                raise RuntimeError(
                    f"Neither args nor cfg exist in state keys = {state.keys()}"
                )

            if task is None:
                task = tasks.setup_task(cfg.task)

            if "task_state" in state:
                task.load_state_dict(state["task_state"])

            if "fsdp_metadata" in state and num_shards > 1:
                model_shard_state["shard_weights"].append(state["model"])
                model_shard_state["shard_metadata"].append(state["fsdp_metadata"])
                # check FSDP import before the code goes too far
                if not has_FSDP:
                    raise ImportError(
                        "Cannot find FullyShardedDataParallel. "
                        "Please install fairscale with: pip install fairscale"
                    )
                if shard_idx == num_shards - 1:
                    consolidated_model_state = FSDP.consolidate_shard_weights(
                        shard_weights=model_shard_state["shard_weights"],
                        shard_metadata=model_shard_state["shard_metadata"],
                    )
                    model = task.build_model(cfg.model)
                    if (
                        "optimizer_history" in state
                        and len(state["optimizer_history"]) > 0
                        and "num_updates" in state["optimizer_history"][-1]
                    ):
                        model.set_num_updates(
                            state["optimizer_history"][-1]["num_updates"]
                        )
                    model.load_state_dict(
                        consolidated_model_state, strict=strict, model_cfg=cfg.model
                    )
            else:
                # model parallel checkpoint or unsharded checkpoint
                # support old external tasks
                argspec = inspect.getfullargspec(task.build_model)
                if "from_checkpoint" in argspec.args:
                    model = task.build_model(cfg.model, from_checkpoint=True)
                else:
                    model = task.build_model(cfg.model)
                if (
                    "optimizer_history" in state
                    and len(state["optimizer_history"]) > 0
                    and "num_updates" in state["optimizer_history"][-1]
                ):
                    model.set_num_updates(state["optimizer_history"][-1]["num_updates"])
                model.load_state_dict(state["model"], strict=strict, model_cfg=cfg.model)

            # reset state so it gets loaded for the next model in ensemble
            state = None
            if shard_idx % 10 == 0 and shard_idx > 0:
                elapsed = time.time() - st
                logger.info(
                    f"Loaded {shard_idx} shards in {elapsed:.2f}s, {elapsed / (shard_idx+1):.2f}s/shard"
                )

        # build model for ensemble
        ensemble.append(model)
    return ensemble, cfg, task
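# Hedged usage sketch (checkpoint paths and the override key are hypothetical):
def _example_load_ensemble():
    models, cfg, task = load_model_ensemble_and_task(
        ["checkpoints/model1.pt", "checkpoints/model2.pt"],
        arg_overrides={"data": "data-bin/example"},
    )
    for model in models:
        model.eval()  # disable dropout for inference
    return models, cfg, task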
def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
    from fairseq import meters

    # only one worker should attempt to create the required dir
    if trainer.data_parallel_rank == 0:
        os.makedirs(cfg.save_dir, exist_ok=True)

    prev_best = getattr(save_checkpoint, "best", val_loss)
    if val_loss is not None:
        best_function = max if cfg.maximize_best_checkpoint_metric else min
        save_checkpoint.best = best_function(val_loss, prev_best)

    if cfg.no_save:
        return

    trainer.consolidate_optimizer()  # TODO(SS): do we need this if no_save_optimizer_state

    if not trainer.should_save_checkpoint_on_current_rank:
        if trainer.always_call_state_dict_during_save_checkpoint:
            trainer.state_dict()
        return

    write_timer = meters.StopwatchMeter()
    write_timer.start()

    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = trainer.get_num_updates()

    logger.info(f"Preparing to save checkpoint for epoch {epoch} @ {updates} updates")

    def is_better(a, b):
        return a >= b if cfg.maximize_best_checkpoint_metric else a <= b

    suffix = trainer.checkpoint_suffix
    checkpoint_conds = collections.OrderedDict()
    checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = (
        end_of_epoch and not cfg.no_epoch_checkpoints and epoch % cfg.save_interval == 0
    )
    checkpoint_conds["checkpoint_{}_{}{}.pt".format(epoch, updates, suffix)] = (
        not end_of_epoch
        and cfg.save_interval_updates > 0
        and updates % cfg.save_interval_updates == 0
    )
    checkpoint_conds["checkpoint_best{}.pt".format(suffix)] = val_loss is not None and (
        not hasattr(save_checkpoint, "best")
        or is_better(val_loss, save_checkpoint.best)
    )
    if val_loss is not None and cfg.keep_best_checkpoints > 0:
        worst_best = getattr(save_checkpoint, "best", None)
        chkpts = checkpoint_paths(
            cfg.save_dir,
            pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format(
                cfg.best_checkpoint_metric, suffix
            ),
        )
        if len(chkpts) > 0:
            p = chkpts[-1] if cfg.maximize_best_checkpoint_metric else chkpts[0]
            worst_best = float(p.rsplit("_")[-1].replace("{}.pt".format(suffix), ""))
        # add random digits to resolve ties
        with data_utils.numpy_seed(epoch, updates, val_loss):
            rand_sfx = np.random.randint(0, cfg.keep_best_checkpoints)

        checkpoint_conds[
            "checkpoint.best_{}_{:.3f}{}{}.pt".format(
                cfg.best_checkpoint_metric, val_loss, rand_sfx, suffix
            )
        ] = worst_best is None or is_better(val_loss, worst_best)
    checkpoint_conds[
        "checkpoint_last{}.pt".format(suffix)
    ] = not cfg.no_last_checkpoints

    extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss}
    if hasattr(save_checkpoint, "best"):
        extra_state.update({"best": save_checkpoint.best})

    checkpoints = [
        os.path.join(cfg.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond
    ]
    if len(checkpoints) > 0 and trainer.should_save_checkpoint_on_current_rank:
        trainer.save_checkpoint(checkpoints[0], extra_state)
        for cp in checkpoints[1:]:
            if cfg.write_checkpoints_asynchronously:
                # TODO[ioPath]: Need to implement a delayed asynchronous
                # file copying/moving feature.
                logger.warning(
                    f"ioPath is not copying {checkpoints[0]} to {cp} "
                    "since async write mode is on."
                )
            else:
                assert PathManager.copy(
                    checkpoints[0], cp, overwrite=True
                ), f"Failed to copy {checkpoints[0]} to {cp}"

        write_timer.stop()
        logger.info(
            "Saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)".format(
                checkpoints[0], epoch, updates, val_loss, write_timer.sum
            )
        )

    if not end_of_epoch and cfg.keep_interval_updates > 0:
        # remove old checkpoints; checkpoints are sorted in descending order
        if cfg.keep_interval_updates_pattern == -1:
            checkpoints = checkpoint_paths(
                cfg.save_dir, pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix)
            )
        else:
            checkpoints = checkpoint_paths(
                cfg.save_dir,
                pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix),
                keep_match=True,
            )
            checkpoints = [
                x[0]
                for x in checkpoints
                if x[1] % cfg.keep_interval_updates_pattern != 0
            ]
        for old_chk in checkpoints[cfg.keep_interval_updates :]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
            elif PathManager.exists(old_chk):
                PathManager.rm(old_chk)

    if cfg.keep_last_epochs > 0:
        # remove old epoch checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_paths(
            cfg.save_dir, pattern=r"checkpoint(\d+){}\.pt".format(suffix)
        )
        for old_chk in checkpoints[cfg.keep_last_epochs :]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
            elif PathManager.exists(old_chk):
                PathManager.rm(old_chk)

    if cfg.keep_best_checkpoints > 0:
        # only keep the best N checkpoints according to validation metric
        checkpoints = checkpoint_paths(
            cfg.save_dir,
            pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format(
                cfg.best_checkpoint_metric, suffix
            ),
        )
        if not cfg.maximize_best_checkpoint_metric:
            checkpoints = checkpoints[::-1]
        for old_chk in checkpoints[cfg.keep_best_checkpoints :]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
            elif PathManager.exists(old_chk):
                PathManager.rm(old_chk)
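# For reference, the checkpoint names save_checkpoint can emit (suffix shown
# empty; epoch 3, update 12000, metric "loss", value 2.5 are example values):
#   checkpoint3.pt                   end-of-epoch checkpoint
#   checkpoint_3_12000.pt            mid-epoch checkpoint, every save_interval_updates
#   checkpoint_best.pt               best validation score so far
#   checkpoint.best_loss_2.5001.pt   one of the keep_best_checkpoints best
#                                    (trailing digit is the random tie-breaker)
#   checkpoint_last.pt               always written unless no_last_checkpoints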
def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args):
    """
    Load a checkpoint and restore the training iterator.

    *passthrough_args* will be passed through to
    ``trainer.get_train_iterator``.
    """
    reset_optimizer = cfg.reset_optimizer
    reset_lr_scheduler = cfg.reset_lr_scheduler
    optimizer_overrides = ast.literal_eval(cfg.optimizer_overrides)
    reset_meters = cfg.reset_meters
    reset_dataloader = cfg.reset_dataloader

    if cfg.finetune_from_model is not None and (
        reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader
    ):
        raise ValueError(
            "--finetune-from-model can not be set together with either --reset-optimizer"
            " or reset_lr_scheduler or reset_meters or reset_dataloader"
        )

    suffix = trainer.checkpoint_suffix
    if (
        cfg.restore_file == "checkpoint_last.pt"
    ):  # default value of restore_file is 'checkpoint_last.pt'
        checkpoint_path = os.path.join(
            cfg.save_dir, "checkpoint_last{}.pt".format(suffix)
        )
        first_launch = not PathManager.exists(checkpoint_path)
        if first_launch and getattr(cfg, "continue_once", None) is not None:
            checkpoint_path = cfg.continue_once
        elif cfg.finetune_from_model is not None and first_launch:
            # if there is no last checkpoint to restore, start the finetune from pretrained model
            # else just use usual logic to load checkpoint, e.g. restart from last checkpoint and etc.
            if PathManager.exists(cfg.finetune_from_model):
                checkpoint_path = cfg.finetune_from_model
                reset_optimizer = True
                reset_lr_scheduler = True
                reset_meters = True
                reset_dataloader = True
                logger.info(
                    f"loading pretrained model from {checkpoint_path}: "
                    "optimizer, lr scheduler, meters, dataloader will be reset"
                )
            else:
                raise ValueError(
                    f"--finetune-from-model {cfg.finetune_from_model} does not exist"
                )
    elif suffix is not None:
        checkpoint_path = cfg.restore_file.replace(".pt", suffix + ".pt")
    else:
        checkpoint_path = cfg.restore_file

    if cfg.restore_file != "checkpoint_last.pt" and cfg.finetune_from_model:
        raise ValueError(
            "--finetune-from-model and --restore-file (non-default value) "
            "can not be specified together: " + str(cfg)
        )

    extra_state = trainer.load_checkpoint(
        checkpoint_path,
        reset_optimizer,
        reset_lr_scheduler,
        optimizer_overrides,
        reset_meters=reset_meters,
    )

    if (
        extra_state is not None
        and "best" in extra_state
        and not reset_optimizer
        and not reset_meters
    ):
        save_checkpoint.best = extra_state["best"]

    if extra_state is not None and not reset_dataloader:
        # restore iterator from checkpoint
        itr_state = extra_state["train_iterator"]
        epoch_itr = trainer.get_train_iterator(
            epoch=itr_state["epoch"], load_dataset=True, **passthrough_args
        )
        epoch_itr.load_state_dict(itr_state)
    else:
        epoch_itr = trainer.get_train_iterator(
            epoch=1, load_dataset=True, **passthrough_args
        )

    trainer.lr_step(epoch_itr.epoch)

    return extra_state, epoch_itr
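# Hedged usage sketch at the start of training; building `cfg` and `trainer`
# is elided, and the attribute layout follows fairseq's train-loop conventions:
def _example_restore_training_state(cfg, trainer):
    extra_state, epoch_itr = load_checkpoint(cfg.checkpoint, trainer)
    # extra_state is None on a fresh run; otherwise the iterator resumes from
    # the saved epoch/update position
    return extra_state, epoch_itr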
def load_checkpoint(args, trainer, **passthrough_args):
    """
    Load a checkpoint and restore the training iterator.

    *passthrough_args* will be passed through to
    ``trainer.get_train_iterator``.
    """
    reset_optimizer = args.reset_optimizer
    reset_lr_scheduler = args.reset_lr_scheduler
    optimizer_overrides = eval(args.optimizer_overrides)
    reset_meters = args.reset_meters
    reset_dataloader = args.reset_dataloader

    if getattr(args, "finetune_from_model", None) is not None and (
        reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader
    ):
        raise ValueError(
            "--finetune-from-model can not be set together with either --reset-optimizer"
            " or reset_lr_scheduler or reset_meters or reset_dataloader"
        )

    suffix = getattr(args, "checkpoint_suffix", "")
    if args.restore_file == "checkpoint_last.pt":
        # default value of restore_file is 'checkpoint_last.pt'
        checkpoint_path = os.path.join(
            args.save_dir, "checkpoint_last{}.pt".format(suffix)
        )
        first_launch = not PathManager.exists(checkpoint_path)
        if getattr(args, "finetune_from_model", None) is not None and first_launch:
            # if there is no last checkpoint to restore, start the finetune from pretrained model
            # else just use usual logic to load checkpoint, e.g. restart from last checkpoint and etc.
            if PathManager.exists(args.finetune_from_model):
                checkpoint_path = args.finetune_from_model
                reset_optimizer = True
                reset_lr_scheduler = True
                reset_meters = True
                reset_dataloader = True
                logger.info(
                    f"loading pretrained model from {checkpoint_path}: "
                    "optimizer, lr scheduler, meters, dataloader will be reset"
                )
            else:
                raise ValueError(
                    f"--finetune-from-model {args.finetune_from_model} does not exist"
                )
    elif getattr(args, "model_parallel_size", 1) > 1:
        checkpoint_path = args.restore_file.replace(".pt", suffix + ".pt")
    else:
        checkpoint_path = args.restore_file

    if args.restore_file != "checkpoint_last.pt" and getattr(
        args, "finetune_from_model", None
    ):
        raise ValueError(
            "--finetune-from-model and --restore-file (non-default value) "
            "can not be specified together: " + str(args)
        )

    extra_state = trainer.load_checkpoint(
        checkpoint_path,
        reset_optimizer,
        reset_lr_scheduler,
        optimizer_overrides,
        reset_meters=reset_meters,
    )

    if (
        extra_state is not None
        and "best" in extra_state
        and not reset_optimizer
        and not reset_meters
    ):
        save_checkpoint.best = extra_state["best"]

    if extra_state is not None and not reset_dataloader:
        # restore iterator from checkpoint
        itr_state = extra_state["train_iterator"]
        epoch_itrs = trainer.get_train_iterator(
            epoch=itr_state["epoch"], load_dataset=True, **passthrough_args
        )
        epoch_itrs.load_state_dict(itr_state)
    else:
        epoch_itrs = trainer.get_train_iterator(
            epoch=1, load_dataset=True, **passthrough_args
        )

    if isinstance(epoch_itrs, list):
        trainer.lr_step(epoch_itrs[0].epoch)
    else:
        trainer.lr_step(epoch_itrs.epoch)

    return extra_state, epoch_itrs
def load_model_ensemble_and_task(
    filenames,
    arg_overrides: Optional[Dict[str, Any]] = None,
    task=None,
    strict=True,
    suffix="",
    num_shards=1,
    state=None,
):
    assert state is None or len(filenames) == 1

    from fairseq import tasks

    assert not (
        strict and num_shards > 1
    ), "Cannot load state dict with strict=True and checkpoint shards > 1"
    ensemble = []
    cfg = None
    for filename in filenames:
        orig_filename = filename
        model_shard_state = {"shard_weights": [], "shard_metadata": []}
        assert num_shards > 0
        st = time.time()
        for shard_idx in range(num_shards):
            filename = get_maybe_sharded_checkpoint_filename(
                orig_filename, suffix, shard_idx, num_shards
            )

            if not PathManager.exists(filename):
                raise IOError("Model file not found: {}".format(filename))
            if state is None:
                state = load_checkpoint_to_cpu(filename, arg_overrides)
            if "args" in state and state["args"] is not None:
                cfg = convert_namespace_to_omegaconf(state["args"])
            elif "cfg" in state and state["cfg"] is not None:
                cfg = state["cfg"]
            else:
                raise RuntimeError(
                    f"Neither args nor cfg exist in state keys = {state.keys()}"
                )

            if task is None:
                task = tasks.setup_task(cfg.task)

            if "task_state" in state:
                task.load_state_dict(state["task_state"])

            if "fsdp_metadata" in state and num_shards > 1:
                model_shard_state["shard_weights"].append(state["model"])
                model_shard_state["shard_metadata"].append(state["fsdp_metadata"])
                # check FSDP import before the code goes too far
                if not has_FSDP:
                    raise ImportError(
                        "Cannot find FullyShardedDataParallel. "
                        "Please install fairscale with: pip install fairscale"
                    )
                if shard_idx == num_shards - 1:
                    consolidated_model_state = FSDP.consolidate_shard_weights(
                        shard_weights=model_shard_state["shard_weights"],
                        shard_metadata=model_shard_state["shard_metadata"],
                    )
                    model = task.build_model(cfg.model)
                    model.load_state_dict(
                        consolidated_model_state, strict=strict, model_cfg=cfg.model
                    )
            else:
                # model parallel checkpoint or unsharded checkpoint
                model = task.build_model(cfg.model)
                new_state_model = state["model"]
                # ===== The following if-else statement is a work-around for the
                # current metadata loading/saving behavior of PyTorch. If
                # state["model"]["_metadata"] exists as a dictionary key, then
                # model.load_state_dict(strict=True) throws an error for the
                # unexpected "_metadata" key. To avoid this error, the state dict
                # must be an OrderedDict that carries _metadata as an attribute
                # (new_state_model._metadata) rather than as a key.
                # TODO yuansg@ This issue should ideally be fixed in PyTorch.
                if new_state_model.get("_metadata", None) is not None:
                    new_metadata = new_state_model.get("_metadata", None)
                    del state["model"]["_metadata"]
                else:
                    new_metadata = None
                # Construct state dict content.
                contents = OrderedDict(new_state_model)
                # We explicitly set _metadata on the state dict. The _metadata is
                # stored implicitly for PyTorch models; accessing state["model"] in
                # fairseq will not invoke metadata storage.
                if new_metadata is None:
                    logger.warning(
                        '===Jit: state["model"] does not contain key "_metadata"====='
                    )
                    logger.warning(
                        "===Jit: we will be filling in with current model's meta-data instead."
                    )
                    # For models trained before this diff, fall back to the current
                    # model's metadata to stay backward compatible.
                    contents._metadata = model.state_dict()._metadata
                else:
                    contents._metadata = new_metadata
                # ==== End of work-around logic =====
                model.load_state_dict(contents, strict=strict, model_cfg=cfg.model)

            # reset state so it gets loaded for the next model in ensemble
            state = None
            if shard_idx % 10 == 0 and shard_idx > 0:
                elapsed = time.time() - st
                logger.info(
                    f"Loaded {shard_idx} shards in {elapsed:.2f}s, {elapsed / (shard_idx+1):.2f}s/shard"
                )

        # build model for ensemble
        ensemble.append(model)
    return ensemble, cfg, task
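# Self-contained sketch of why the work-around above stores _metadata as an
# attribute rather than a key: load_state_dict(strict=True) rejects unknown
# *keys*, while torch reads version metadata via getattr(state_dict, "_metadata").
# The toy module below is purely illustrative.
def _example_metadata_attribute():
    from collections import OrderedDict

    import torch.nn as nn

    module = nn.Linear(2, 2)
    sd = module.state_dict()  # an OrderedDict carrying a _metadata attribute
    contents = OrderedDict(sd)  # plain copy: the attribute is NOT carried over
    contents._metadata = sd._metadata  # re-attach as an attribute, not a key
    module.load_state_dict(contents, strict=True)  # loads cleanly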