def load_model_ensemble_and_task(
    filenames,
    arg_overrides: Optional[Dict[str, Any]] = None,
    task=None,
    strict=True,
    suffix="",
    num_shards=1,
    state=None,
):
    assert state is None or len(filenames) == 1

    from fairseq import tasks

    assert not (
        strict and num_shards > 1
    ), "Cannot load state dict with strict=True and checkpoint shards > 1"
    ensemble = []
    cfg = None
    for filename in filenames:
        orig_filename = filename
        model_shard_state = {"shard_weights": [], "shard_metadata": []}
        assert num_shards > 0
        st = time.time()
        for shard_idx in range(num_shards):
            filename = get_maybe_sharded_checkpoint_filename(
                orig_filename, suffix, shard_idx, num_shards
            )
            if not PathManager.exists(filename):
                raise IOError("Model file not found: {}".format(filename))
            if state is None:
                state = load_checkpoint_to_cpu(filename, arg_overrides)
            if "args" in state and state["args"] is not None:
                cfg = convert_namespace_to_omegaconf(state["args"])
            elif "cfg" in state and state["cfg"] is not None:
                cfg = state["cfg"]
            else:
                raise RuntimeError(
                    f"Neither args nor cfg exist in state keys = {state.keys()}"
                )

            if task is None:
                task = tasks.setup_task(cfg.task)

            if "task_state" in state:
                task.load_state_dict(state["task_state"])

            if "fsdp_metadata" in state and num_shards > 1:
                model_shard_state["shard_weights"].append(state["model"])
                model_shard_state["shard_metadata"].append(state["fsdp_metadata"])
                # check FSDP import before the code goes too far
                if not has_FSDP:
                    raise ImportError(
                        "Cannot find FullyShardedDataParallel. "
                        "Please install fairscale with: pip install fairscale"
                    )
                if shard_idx == num_shards - 1:
                    consolidated_model_state = FSDP.consolidate_shard_weights(
                        shard_weights=model_shard_state["shard_weights"],
                        shard_metadata=model_shard_state["shard_metadata"],
                    )
                    model = task.build_model(cfg.model)
                    if (
                        "optimizer_history" in state
                        and len(state["optimizer_history"]) > 0
                        and "num_updates" in state["optimizer_history"][-1]
                    ):
                        model.set_num_updates(
                            state["optimizer_history"][-1]["num_updates"]
                        )
                    model.load_state_dict(
                        consolidated_model_state, strict=strict, model_cfg=cfg.model
                    )
            else:
                # model parallel checkpoint or unsharded checkpoint
                # support old external tasks
                argspec = inspect.getfullargspec(task.build_model)
                if "from_checkpoint" in argspec.args:
                    model = task.build_model(cfg.model, from_checkpoint=True)
                else:
                    model = task.build_model(cfg.model)
                if (
                    "optimizer_history" in state
                    and len(state["optimizer_history"]) > 0
                    and "num_updates" in state["optimizer_history"][-1]
                ):
                    model.set_num_updates(state["optimizer_history"][-1]["num_updates"])
                model.load_state_dict(
                    state["model"], strict=strict, model_cfg=cfg.model
                )

            # reset state so it gets loaded for the next model in ensemble
            state = None
            if shard_idx % 10 == 0 and shard_idx > 0:
                elapsed = time.time() - st
                logger.info(
                    f"Loaded {shard_idx} shards in {elapsed:.2f}s, {elapsed / (shard_idx + 1):.2f}s/shard"
                )

        # build model for ensemble
        ensemble.append(model)
    return ensemble, cfg, task
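

# Usage sketch (illustrative, not part of the original module). It assumes this function is
# exposed as fairseq.checkpoint_utils.load_model_ensemble_and_task and that the checkpoint
# and data paths passed in are placeholders supplied by the caller.
def _example_load_ensemble(checkpoint_paths, data_dir):
    from fairseq import checkpoint_utils

    # Load every checkpoint in `checkpoint_paths` as one ensemble sharing a single task.
    models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
        checkpoint_paths,
        arg_overrides={"data": data_dir},
    )
    for model in models:
        model.eval()  # ensemble members come back as ordinary nn.Module instances
    return models, cfg, task
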
def load_model_ensemble_and_task(
    filenames,
    arg_overrides: Optional[Dict[str, Any]] = None,
    task=None,
    strict=True,
    suffix="",
    num_shards=1,
    state=None,
):
    assert state is None or len(filenames) == 1

    from fairseq import tasks

    assert not (
        strict and num_shards > 1
    ), "Cannot load state dict with strict=True and checkpoint shards > 1"
    ensemble = []
    cfg = None
    for filename in filenames:
        orig_filename = filename
        model_shard_state = {"shard_weights": [], "shard_metadata": []}
        assert num_shards > 0
        st = time.time()
        for shard_idx in range(num_shards):
            filename = get_maybe_sharded_checkpoint_filename(
                orig_filename, suffix, shard_idx, num_shards
            )
            if not PathManager.exists(filename):
                raise IOError("Model file not found: {}".format(filename))
            if state is None:
                state = load_checkpoint_to_cpu(filename, arg_overrides)
            if "args" in state and state["args"] is not None:
                cfg = convert_namespace_to_omegaconf(state["args"])
            elif "cfg" in state and state["cfg"] is not None:
                cfg = state["cfg"]
            else:
                raise RuntimeError(
                    f"Neither args nor cfg exist in state keys = {state.keys()}"
                )

            if task is None:
                task = tasks.setup_task(cfg.task)

            if "task_state" in state:
                task.load_state_dict(state["task_state"])

            if "fsdp_metadata" in state and num_shards > 1:
                model_shard_state["shard_weights"].append(state["model"])
                model_shard_state["shard_metadata"].append(state["fsdp_metadata"])
                # check FSDP import before the code goes too far
                if not has_FSDP:
                    raise ImportError(
                        "Cannot find FullyShardedDataParallel. "
                        "Please install fairscale with: pip install fairscale"
                    )
                if shard_idx == num_shards - 1:
                    consolidated_model_state = FSDP.consolidate_shard_weights(
                        shard_weights=model_shard_state["shard_weights"],
                        shard_metadata=model_shard_state["shard_metadata"],
                    )
                    model = task.build_model(cfg.model)
                    model.load_state_dict(
                        consolidated_model_state, strict=strict, model_cfg=cfg.model
                    )
            else:
                # model parallel checkpoint or unsharded checkpoint
                model = task.build_model(cfg.model)
                new_state_model = state["model"]
                # ===== Work-around for the current metadata loading/saving in PyTorch =====
                # If state["model"]["_metadata"] exists as a dictionary key, then
                # model.load_state_dict(strict=True) raises an error for the unexpected
                # "_metadata" key. To avoid this, the state dict must be an OrderedDict
                # that carries the metadata as a `_metadata` attribute rather than as a key.
                # TODO yuansg@ This issue should ideally be fixed in PyTorch.
                if new_state_model.get("_metadata", None) is not None:
                    new_metadata = new_state_model["_metadata"]
                    del state["model"]["_metadata"]
                else:
                    new_metadata = None
                # Construct the state dict contents.
                contents = OrderedDict(new_state_model)
                # Explicitly set _metadata on the state dict. PyTorch stores _metadata
                # implicitly for its models; reading state["model"] in fairseq does not
                # carry that metadata along.
                if new_metadata is None:
                    logger.warning(
                        '===Jit: state["model"] does not contain key "_metadata"====='
                    )
                    logger.warning(
                        "===Jit: we will be filling in with the current model's metadata instead."
                    )
                    # For models trained before this change, fall back to the freshly
                    # built model's metadata to stay backward compatible.
                    contents._metadata = model.state_dict()._metadata
                else:
                    contents._metadata = new_metadata
                # ===== End of work-around logic =====
                model.load_state_dict(
                    contents, strict=strict, model_cfg=cfg.model
                )

            # reset state so it gets loaded for the next model in ensemble
            state = None
            if shard_idx % 10 == 0 and shard_idx > 0:
                elapsed = time.time() - st
                logger.info(
                    f"Loaded {shard_idx} shards in {elapsed:.2f}s, {elapsed / (shard_idx + 1):.2f}s/shard"
                )

        # build model for ensemble
        ensemble.append(model)
    return ensemble, cfg, task
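

# Illustration of the "_metadata" work-around above (a simplified, hypothetical example,
# not part of the original module). Plain PyTorch reads state-dict metadata via
# getattr(state_dict, "_metadata", None), so the metadata must travel as an attribute of
# the OrderedDict; per the work-around comment, leaving "_metadata" in the dict as a
# regular key makes strict loading report it as an unexpected key.
def _example_metadata_attribute_vs_key():
    from collections import OrderedDict

    import torch.nn as nn

    module = nn.Linear(4, 4)

    # "_metadata" stored as a key: strict loading flags it as unexpected.
    as_key = OrderedDict(module.state_dict())
    as_key["_metadata"] = {"": {"version": 1}}
    try:
        module.load_state_dict(as_key, strict=True)
    except RuntimeError as err:
        print("strict load with a _metadata key fails:", err)

    # "_metadata" stored as an attribute: accepted, mirroring what the work-around builds.
    as_attr = OrderedDict(module.state_dict())
    as_attr._metadata = {"": {"version": 1}}
    module.load_state_dict(as_attr, strict=True)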