def load_checkpoint_to_cpu(path, arg_overrides=None, load_on_all_ranks=False):
    """Loads a checkpoint to CPU (with upgrading for backward compatibility).

    If doing single-GPU training or if the checkpoint is only being loaded by at
    most one process on each node (current default behavior is for only rank 0
    to read the checkpoint from disk), load_on_all_ranks should be False to
    avoid errors from torch.distributed not having been initialized or
    torch.distributed.barrier() hanging.

    If all processes on each node may be loading the checkpoint
    simultaneously, load_on_all_ranks should be set to True to avoid I/O
    conflicts.

    There's currently no support for > 1 but < all processes loading the
    checkpoint on each node.

    Args:
        path: checkpoint location; resolved to a local file through
            PathManager.get_local_path().
        arg_overrides: optional mapping of name -> value. Each item is set as
            an attribute on state["args"] (when present) and the mapping is
            passed to overwrite_args_by_name() for state["cfg"] (when present).
        load_on_all_ranks: when True, torch.distributed.barrier() is called
            after the stale local copy has been removed and before it is
            fetched again.

    Returns:
        The loaded state after it has been passed through
        _upgrade_state_dict().
    """
    local_path = PathManager.get_local_path(path)
    # The locally cached file returned by get_local_path() may be stale for
    # remote files that are periodically updated/overwritten (ex:
    # checkpoint_last.pt) - so we remove the local copy, sync across processes
    # (if needed), and then download a fresh copy.
    if local_path != path and PathManager.path_requires_pathmanager(path):
        try:
            os.remove(local_path)
        except FileNotFoundError:
            # With potentially multiple processes removing the same file, the
            # file being missing is benign (missing_ok isn't available until
            # Python 3.8).
            pass
        if load_on_all_ranks:
            torch.distributed.barrier()
        local_path = PathManager.get_local_path(path)

    # NOTE(review): torch.load deserializes via pickle; only load checkpoints
    # that come from a trusted source.
    with open(local_path, "rb") as f:
        state = torch.load(f, map_location=torch.device("cpu"))

    if "args" in state and state["args"] is not None and arg_overrides is not None:
        args = state["args"]
        for arg_name, arg_val in arg_overrides.items():
            setattr(args, arg_name, arg_val)

    if "cfg" in state and state["cfg"] is not None:
        # hack to be able to set Namespace in dict config. this should be
        # removed when we update to newer omegaconf version that supports
        # object flags, or when we migrate all existing models
        from omegaconf import _utils

        old_primitive = _utils.is_primitive_type
        _utils.is_primitive_type = lambda _: True
        try:
            state["cfg"] = OmegaConf.create(state["cfg"])
        finally:
            # Always undo the module-level patch, even if create() raises;
            # otherwise every later omegaconf call in this process would see
            # the permissive is_primitive_type.
            _utils.is_primitive_type = old_primitive

        OmegaConf.set_struct(state["cfg"], True)

        if arg_overrides is not None:
            overwrite_args_by_name(state["cfg"], arg_overrides)

    state = _upgrade_state_dict(state)
    return state
def load_checkpoint_to_cpu(path, arg_overrides=None):
    """Read a checkpoint onto the CPU and upgrade it for backward compatibility.

    The path is resolved through PathManager.get_local_path(), the file is
    deserialized with torch.load() mapped to the CPU device, and any
    ``arg_overrides`` are applied to the stored ``args`` / ``cfg`` entries
    before the state is handed to _upgrade_state_dict().
    """
    resolved_path = PathManager.get_local_path(path)
    cpu_device = torch.device("cpu")
    with open(resolved_path, "rb") as checkpoint_file:
        state = torch.load(checkpoint_file, map_location=cpu_device)

    # Legacy argparse-style namespace: patch each overridden attribute in place.
    stored_args = state["args"] if "args" in state else None
    if stored_args is not None and arg_overrides is not None:
        for name, value in arg_overrides.items():
            setattr(stored_args, name, value)

    # Config object: delegate the override logic to the shared helper.
    stored_cfg = state["cfg"] if "cfg" in state else None
    if stored_cfg is not None and arg_overrides is not None:
        overwrite_args_by_name(stored_cfg, arg_overrides)

    return _upgrade_state_dict(state)
def load_checkpoint_to_cpu(path, arg_overrides=None):
    """Loads a checkpoint to CPU (with upgrading for backward compatibility).

    Args:
        path: checkpoint location; resolved to a local file through
            PathManager.get_local_path().
        arg_overrides: optional mapping of name -> value. Each item is set as
            an attribute on state["args"] (when present) and the mapping is
            passed to overwrite_args_by_name() for state["cfg"] (when present).

    Returns:
        The loaded state after it has been passed through
        _upgrade_state_dict().
    """
    import logging
    import os

    # HACK: machine-specific redirect for one wav2vec checkpoint.
    # NOTE(review): presumably a saved config recorded the path from a
    # different machine - confirm, and move this mapping into configuration.
    # The redirect is only applied when the requested file is actually
    # missing, so hosts where the original path exists keep working.
    recorded_path = "/home/ubuntu/project/model/wav2vec_small.pt"
    fallback_path = "/home/ras306/Classwork/Project/Wav2Vec/model/wav2vec_small.pt"
    if path == recorded_path and not os.path.exists(recorded_path):
        logging.getLogger(__name__).warning(
            "checkpoint %s not found; loading %s instead",
            recorded_path,
            fallback_path,
        )
        path = fallback_path

    # NOTE(review): torch.load deserializes via pickle; only load checkpoints
    # that come from a trusted source.
    with open(PathManager.get_local_path(path), "rb") as f:
        state = torch.load(f, map_location=torch.device("cpu"))

    if "args" in state and state["args"] is not None and arg_overrides is not None:
        args = state["args"]
        for arg_name, arg_val in arg_overrides.items():
            setattr(args, arg_name, arg_val)

    if "cfg" in state and state["cfg"] is not None and arg_overrides is not None:
        overwrite_args_by_name(state["cfg"], arg_overrides)

    state = _upgrade_state_dict(state)
    return state