def on_save_checkpoint(self, trainer, pl_module, checkpoint):
    """
    Override the default on_save_checkpoint to save the best model if needed.

    Parameters
    ----------
    trainer: The trainer object.
    pl_module: The PyTorch-Lightning module.
    checkpoint: The checkpoint object.
    """
    output = super().on_save_checkpoint(trainer, pl_module, checkpoint)
    if not self.always_save_mridc:
        return output

    # Load the best model and then re-save it
    app_state = AppState()
    if app_state.model_parallel_size is not None and app_state.model_parallel_size > 1:
        raise ValueError("always_save_mridc is not implemented for model parallel models.")

    # since we are creating tarfile artifacts we need to update the .mridc path
    app_state.model_restore_path = os.path.abspath(
        os.path.expanduser(os.path.join(self.dirpath, self.prefix + self.postfix))
    )

    if self.save_best_model:
        if not os.path.exists(self.best_model_path):
            return output
        if self.best_model_path == self.previous_best_path:
            return output

        self.previous_best_path = self.best_model_path
        old_state_dict = deepcopy(pl_module.state_dict())
        checkpoint = torch.load(self.best_model_path, map_location="cpu")
        if "state_dict" in checkpoint:
            checkpoint = checkpoint["state_dict"]

        # load the best weights into the module, export the .mridc file, then restore the original weights
        pl_module.load_state_dict(checkpoint, strict=True)
        pl_module.save_to(save_path=app_state.model_restore_path)
        pl_module.load_state_dict(old_state_dict, strict=True)
    else:
        pl_module.save_to(save_path=app_state.model_restore_path)
    return output
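# ----------------------------------------------------------------------------------------------
# Illustrative sketch (not part of the library): how the behaviour above is typically enabled.
# The key names under `checkpoint_callback_params` are assumptions inferred from the attributes
# read by this callback (self.always_save_mridc, self.save_best_model); verify them against your
# ExpManagerConfig / CallbackParams definition before use.
#
#   exp_manager(trainer, {
#       "create_checkpoint_callback": True,
#       "checkpoint_callback_params": {
#           "monitor": "val_loss",
#           "always_save_mridc": True,   # assumed key: re-save a .mridc archive on every checkpoint
#           "save_best_model": True,     # assumed key: the saved .mridc tracks the best checkpoint
#       },
#   })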
def inject_model_parallel_rank(filepath):
    """Injects tensor/pipeline model parallel ranks into the filepath. Does nothing if not using model parallelism."""
    filepath = uninject_model_parallel_rank(filepath)
    app_state = AppState()
    if app_state.model_parallel_size is not None and app_state.model_parallel_size > 1:
        # filepath needs to be updated to include mp_rank
        dirname = os.path.dirname(filepath)
        basename = os.path.basename(filepath)
        if app_state.pipeline_model_parallel_size is None or app_state.pipeline_model_parallel_size == 1:
            filepath = f"{dirname}/mp_rank_{app_state.tensor_model_parallel_rank:02d}/{basename}"
        else:
            filepath = (
                f"{dirname}/tp_rank_{app_state.tensor_model_parallel_rank:02d}_pp_rank_"
                f"{app_state.pipeline_model_parallel_rank:03d}/{basename}"
            )
    return filepath
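# ----------------------------------------------------------------------------------------------
# Illustrative sketch (not part of the library): the path rewrite performed by
# inject_model_parallel_rank. The AppState attributes below mirror the ones read by the function;
# setting them by hand like this is an assumption made purely for demonstration.
if __name__ == "__main__":
    _demo_state = AppState()
    _demo_state.model_parallel_size = 2            # assumed: 2-way tensor model parallelism
    _demo_state.tensor_model_parallel_rank = 1
    _demo_state.pipeline_model_parallel_size = 1   # no pipeline parallelism

    # "/ckpts/model--last.ckpt" -> "/ckpts/mp_rank_01/model--last.ckpt"
    print(inject_model_parallel_rank("/ckpts/model--last.ckpt"))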
def _del_model_without_trainer(self, filepath: str) -> None:
    """
    Delete a model without a trainer.

    Parameters
    ----------
    filepath: The path to the model to delete.
    """
    app_state = AppState()
    if app_state.model_parallel_size is not None and app_state.model_parallel_size > 1:
        # filepath needs to be updated to include mp_rank
        filepath = mridc.utils.model_utils.inject_model_parallel_rank(filepath)  # type: ignore

    # each model parallel rank needs to remove its model
    if is_global_rank_zero() or (app_state.model_parallel_size is not None and app_state.data_parallel_rank == 0):
        try:
            self._fs.rm(filepath)
            logging.info(f"Removed checkpoint: {filepath}")
        except FileNotFoundError:
            logging.info(f"Tried to remove checkpoint: {filepath} but failed.")
def test_restore_from_save_restore_connector_extracted_dir(self):
    class MySaveRestoreConnector(save_restore_connector.SaveRestoreConnector):
        def save_to(self, model, save_path: str):
            save_path = save_path.replace(".mridc", "_XYZ.mridc")
            super().save_to(model, save_path)

    class MockModelV2(MockModel):
        pass

    with tempfile.TemporaryDirectory() as extracted_tempdir:
        with tempfile.TemporaryDirectory() as tmpdir:
            # Update config
            cfg = _mock_model_config()

            # Create model
            save_path = os.path.join(tmpdir, "save_custom.mridc")
            model_with_custom_connector = MockModel(cfg=cfg.model, trainer=None)
            model_with_custom_connector._save_restore_connector = MySaveRestoreConnector()
            model_with_custom_connector.save_to(save_path)

            mridc_filepath = os.path.join(tmpdir, "save_custom_XYZ.mridc")
            assert os.path.exists(mridc_filepath)

            # extract the contents to this dir apriori
            # simulate by extracting now before calling restore_from
            connector = MySaveRestoreConnector()
            MySaveRestoreConnector._unpack_mridc_file(mridc_filepath, extracted_tempdir)
            assert get_size(extracted_tempdir) > 0

        # delete the old directory and preserve only the new extracted directory (escape scope of old dir)

        # next, set the model's extracted directory path
        connector.model_extracted_dir = extracted_tempdir

        # note, we pass in the "old" mridc_filepath, stored somewhere other than the extracted directory
        # this mridc_filepath is no longer valid, and has been deleted.
        restored_model = MockModelV2.restore_from(mridc_filepath, save_restore_connector=connector)
        assert type(restored_model) == MockModelV2
        assert type(restored_model._save_restore_connector) == MySaveRestoreConnector

        # assert models have correct restoration information and paths
        appstate = AppState()
        original_metadata = appstate.get_model_metadata_from_guid(model_with_custom_connector.model_guid)
        assert original_metadata.restoration_path is None

        restored_metadata = appstate.get_model_metadata_from_guid(restored_model.model_guid)
        assert restored_metadata.restoration_path is not None

        # assert that the restore path was the path of the pre-extracted directory,
        # irrespective of whether an old `mridc_filepath` (which doesn't exist anymore) was passed to restore_from.
        assert extracted_tempdir in restored_metadata.restoration_path
        assert extracted_tempdir not in mridc_filepath
        assert not os.path.exists(mridc_filepath)

        # test for parameter equality
        model_with_custom_connector = model_with_custom_connector.to("cpu")
        restored_model = restored_model.to("cpu")

        original_state_dict = model_with_custom_connector.state_dict()
        restored_state_dict = restored_model.state_dict()
        for orig, restored in zip(original_state_dict.keys(), restored_state_dict.keys()):
            assert (original_state_dict[orig] - restored_state_dict[restored]).abs().mean() < 1e-6
def load_config_and_state_dict(
    self,
    calling_cls,
    restore_path: str,
    override_config_path: Optional[Union[OmegaConf, str]] = None,
    map_location: Optional[torch.device] = None,
    strict: bool = True,
    return_config: bool = False,
    trainer: Trainer = None,
):
    """
    Restores a model instance (weights and configuration) from a .mridc file.

    Parameters
    ----------
    calling_cls: Class of the model to be restored.
    restore_path: Path to the .mridc file from which the model should be instantiated.
    override_config_path: Path to a yaml config that will override the internal config file, or an
        OmegaConf/DictConfig object representing the model config.
    map_location: Optional torch.device() to map the instantiated model to a device. By default (None), it will
        select a GPU if available, falling back to CPU otherwise.
    strict: Passed to load_state_dict. By default, True.
    return_config: If set to True, will return just the underlying config of the restored model as an OmegaConf
        DictConfig object without instantiating the model.
    trainer: Optional trainer object to be used for model parallelism.

    Example
    -------
    ```
    model = mridc.collections.asr.models.EncDecCTCModel.restore_from('asr.mridc')
    assert isinstance(model, mridc.collections.asr.models.EncDecCTCModel)
    ```

    Returns
    -------
    An instance of type cls or its underlying config (if return_config is set).
    """
    # Get path where the command is executed - the artifacts will be "retrieved" there (original .mridc behavior)
    cwd = os.getcwd()

    if map_location is None:
        map_location = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    app_state = AppState()
    with tempfile.TemporaryDirectory() as tmpdir:
        try:
            # Check if self.model_extracted_dir is set, and is a valid path
            if self.model_extracted_dir is not None and os.path.isdir(self.model_extracted_dir):
                # Log that MRIDC will use the provided `model_extracted_dir`
                logging.info(
                    f"Restoration will occur within pre-extracted directory : `{self.model_extracted_dir}`."
                )
                # Override `tmpdir` above with the pre-extracted `model_extracted_dir`
                tmpdir = self.model_extracted_dir
            else:
                # Extract the .mridc file into the temporary directory
                self._unpack_mridc_file(path2file=restore_path, out_folder=tmpdir)

            # Change current working directory to the temporary directory
            os.chdir(tmpdir)
            if override_config_path is None:
                config_yaml = os.path.join(tmpdir, self.model_config_yaml)
            else:
                # can be str path or OmegaConf / DictConfig object
                config_yaml = override_config_path

            if not isinstance(config_yaml, (OmegaConf, DictConfig)):
                conf = OmegaConf.load(config_yaml)
            else:
                conf = config_yaml
                if override_config_path is not None:
                    # Resolve the override config
                    conf = OmegaConf.to_container(conf, resolve=True)
                    conf = OmegaConf.create(conf)

            # If override is top level config, extract just `model` from it
            if "model" in conf:
                conf = conf.model

            if return_config:
                instance = conf
                return instance

            if app_state.model_parallel_rank is not None and app_state.model_parallel_size > 1:
                model_weights = self._inject_model_parallel_rank_for_ckpt(tmpdir, self.model_weights_ckpt)
            else:
                model_weights = os.path.join(tmpdir, self.model_weights_ckpt)

            OmegaConf.set_struct(conf, True)
            os.chdir(cwd)

            # get the class
            calling_cls._set_model_restore_state(is_being_restored=True, folder=tmpdir)  # type: ignore
            instance = calling_cls.from_config_dict(config=conf, trainer=trainer)
            instance = instance.to(map_location)

            # add load_state_dict override
            if app_state.model_parallel_size is not None and app_state.model_parallel_size > 1:
                model_weights = self._inject_model_parallel_rank_for_ckpt(tmpdir, self.model_weights_ckpt)
            instance.load_state_dict(
                self._load_state_dict_from_disk(model_weights, map_location=map_location), strict=strict
            )

            logging.info(f"Model {instance.__class__.__name__} was successfully restored from {restore_path}.")
            instance._set_model_restore_state(is_being_restored=False)  # type: ignore
        finally:
            os.chdir(cwd)

    return instance
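# ----------------------------------------------------------------------------------------------
# Illustrative sketch (not part of the library): load_config_and_state_dict is normally reached
# through the public ModelPT.restore_from() API rather than called directly. The model class and
# file name below are hypothetical placeholders.
#
#   # Inspect only the config, without instantiating the model or loading weights:
#   cfg = SomeModel.restore_from("model.mridc", return_config=True)
#
#   # Full restore onto CPU, tolerating missing/unexpected keys in the state dict:
#   model = SomeModel.restore_from("model.mridc", map_location=torch.device("cpu"), strict=False)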
def _handle_artifacts(self, model, mridc_file_folder):
    """
    This method is called by ModelPT.save_to() and ModelPT.load_from(). It will handle all artifacts and save
    them to the mridc file.

    Parameters
    ----------
    model: ModelPT object whose artifacts are handled.
    mridc_file_folder: Path to the folder where the .mridc file contents are assembled.
    """
    tarfile_artifacts = []
    app_state = AppState()
    for conf_path, artiitem in model.artifacts.items():
        if artiitem.path_type == mridc.utils.model_utils.ArtifactPathType.LOCAL_PATH:
            if not os.path.exists(artiitem.path):
                raise FileNotFoundError(f"Artifact {conf_path} not found at location: {artiitem.path}")

            # Generate new unique artifact name and copy it to mridc_file_folder
            # Note uuid.uuid4().hex is guaranteed to be 32 character long
            artifact_base_name = os.path.basename(artiitem.path)
            artifact_uniq_name = f"{uuid.uuid4().hex}_{artifact_base_name}"
            shutil.copy2(artiitem.path, os.path.join(mridc_file_folder, artifact_uniq_name))

            # Update artifacts registry
            artiitem.hashed_path = f"mridc:{artifact_uniq_name}"
            model.artifacts[conf_path] = artiitem
        elif artiitem.path_type == mridc.utils.model_utils.ArtifactPathType.TAR_PATH:
            # process all tarfile artifacts in one go, so preserve key-value pair
            tarfile_artifacts.append((conf_path, artiitem))
        else:
            raise ValueError("Directly referencing artifacts from other mridc files isn't supported yet")

    # Process current tarfile artifacts by unpacking the previous tarfile and extracting the artifacts
    # that are currently required.
    model_metadata = app_state.get_model_metadata_from_guid(model.model_guid)
    if tarfile_artifacts and model_metadata.restoration_path is not None:
        # Need to step into the mridc archive to extract the file.
        # Get path where the command is executed - the artifacts will be "retrieved" there (original .mridc behavior)
        cwd = os.getcwd()
        try:
            # Step into the mridc archive to try and find the file
            with tempfile.TemporaryDirectory() as archive_dir:
                self._unpack_mridc_file(path2file=model_metadata.restoration_path, out_folder=archive_dir)
                os.chdir(archive_dir)
                for conf_path, artiitem in tarfile_artifacts:
                    # Get basename and copy it to mridc_file_folder
                    if "mridc:" in artiitem.path:
                        artifact_base_name = artiitem.path.split("mridc:")[1]
                    else:
                        artifact_base_name = os.path.basename(artiitem.path)

                    # no need to hash here as we are in tarfile_artifacts which are already hashed
                    artifact_uniq_name = artifact_base_name
                    shutil.copy2(artifact_base_name, os.path.join(mridc_file_folder, artifact_uniq_name))

                    # Update artifacts registry
                    new_artiitem = mridc.utils.model_utils.ArtifactItem()
                    new_artiitem.path = f"mridc:{artifact_uniq_name}"
                    new_artiitem.path_type = mridc.utils.model_utils.ArtifactPathType.TAR_PATH
                    model.artifacts[conf_path] = new_artiitem
        finally:
            # change back working directory
            os.chdir(cwd)
def register_artifact(model, config_path: str, src: str, verify_src_exists: bool = True):
    """
    Register model artifacts with this function. These artifacts (files) will be included inside the .mridc file
    when model.save_to("mymodel.mridc") is called.

    How it works:

    1. It always returns an existing absolute path which can be used during the Model constructor call.
       EXCEPTION: src is None or "", in which case nothing will be done and src will be returned.
    2. It will add a (config_path, model_utils.ArtifactItem()) pair to model.artifacts.
       If "src" is a local existing path, then it will be returned in absolute path form.
       elif "src" starts with "mridc:unique_artifact_name", the .mridc file will be untarred to a temporary folder
       location and an actual existing path will be returned.
       else an error will be raised.

    WARNING: use .register_artifact calls in your models' constructors. The returned path is not guaranteed to
    exist after you have exited your model's constructor.

    Parameters
    ----------
    model: ModelPT object to register artifact for.
    config_path: Artifact key. Usually corresponds to the model config.
    src: Path to artifact.
    verify_src_exists: If set to False, then the artifact is optional and register_artifact will return None even
        if src is not found. Defaults to True.

    Returns
    -------
    If src is not None or empty, it always returns an absolute path which is guaranteed to exist during the model
    instance lifetime.
    """
    app_state = AppState()

    artifact_item = mridc.utils.model_utils.ArtifactItem()  # type: ignore

    # This is for backward compatibility, if the src object exists simply inside the tarfile
    # without its key having been overridden, this pathway will be used.
    src_obj_name = os.path.basename(src)
    if app_state.mridc_file_folder is not None:
        src_obj_path = os.path.abspath(os.path.join(app_state.mridc_file_folder, src_obj_name))
    else:
        src_obj_path = src_obj_name

    # src is a local existing path - register artifact and return exact same path for usage by the model
    if os.path.exists(os.path.abspath(src)):
        return_path = os.path.abspath(src)
        artifact_item.path_type = mridc.utils.model_utils.ArtifactPathType.LOCAL_PATH  # type: ignore
    elif src.startswith("mridc:"):
        return_path = os.path.abspath(os.path.join(app_state.mridc_file_folder, src[len("mridc:"):]))
        artifact_item.path_type = mridc.utils.model_utils.ArtifactPathType.TAR_PATH  # type: ignore
    elif os.path.exists(src_obj_path):
        return_path = src_obj_path
        artifact_item.path_type = mridc.utils.model_utils.ArtifactPathType.TAR_PATH  # type: ignore
    elif verify_src_exists:
        raise FileNotFoundError(
            f"src path does not exist or it is not a path in mridc file. src value I got was: {src}. "
            f"Absolute: {os.path.abspath(src)}"
        )
    else:
        # artifact is optional and we simply return None
        return None

    if not os.path.exists(return_path):
        raise AssertionError

    artifact_item.path = os.path.abspath(src)
    model.artifacts[config_path] = artifact_item
    # we were called by ModelPT
    if hasattr(model, "cfg"):
        with open_dict(model._cfg):
            OmegaConf.update(model.cfg, config_path, return_path)
    return return_path
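# ----------------------------------------------------------------------------------------------
# Illustrative sketch (not part of the library): registering an artifact from a model constructor.
# ModelPT.register_artifact is assumed to delegate to the function above; the config key and file
# name are hypothetical.
#
#   class MyModel(ModelPT):
#       def __init__(self, cfg, trainer=None):
#           super().__init__(cfg=cfg, trainer=trainer)
#           # Returns an absolute path valid for the lifetime of this instance; the file is
#           # bundled into the archive when model.save_to("my_model.mridc") is called.
#           vocab_path = self.register_artifact("tokenizer.vocab_file", cfg.tokenizer.vocab_file)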
def exp_manager(trainer: Trainer, cfg: Optional[Union[DictConfig, Dict]] = None) -> Optional[Path]:
    """
    exp_manager is a helper function used to manage folders for experiments. It follows the pytorch lightning
    paradigm of exp_dir/model_or_experiment_name/version. If the lightning trainer has a logger, exp_manager will
    get exp_dir, name, and version from the logger. Otherwise, it will use the exp_dir and name arguments to create
    the logging directory. exp_manager also allows for explicit folder creation via explicit_log_dir.

    The version can be a datetime string or an integer. The datetime version can be disabled if
    use_datetime_version is set to False. It optionally creates TensorBoardLogger, WandBLogger, and ModelCheckpoint
    objects from pytorch lightning. It copies sys.argv and, if available, git information to the logging directory.
    It creates a log file for each process to log their output into.

    exp_manager additionally has a resume feature (resume_if_exists) which can be used to continue training from
    the constructed log_dir. When you need to continue the training repeatedly (like on a cluster where you need
    multiple consecutive jobs), you need to avoid creating the version folders. Therefore, from v1.0.0, when
    resume_if_exists is set to True, creating the version folders is skipped.

    Parameters
    ----------
    trainer: The lightning trainer object.
    cfg: Can have the following keys:
        - explicit_log_dir: Can be used to override exp_dir/name/version folder creation. Defaults to None, which
            will use exp_dir, name, and version to construct the logging directory.
        - exp_dir: The base directory to create the logging directory. Defaults to None, which logs to
            ./mridc_experiments.
        - name: The name of the experiment. Defaults to None which turns into "default" via name = name or
            "default".
        - version: The version of the experiment. Defaults to None which uses either a datetime string or
            lightning's TensorboardLogger system of using version_{int}.
        - use_datetime_version: Whether to use a datetime string for version. Defaults to True.
        - resume_if_exists: Whether this experiment is resuming from a previous run. If True, it sets
            trainer._checkpoint_connector.resume_from_checkpoint_fit_path so that the trainer should auto-resume.
            exp_manager will move files under log_dir to log_dir/run_{int}. Defaults to False. From v1.0.0, when
            resume_if_exists is True, we do not create version folders to make it easier to find the log folder for
            next runs.
        - resume_past_end: exp_manager errors out if resume_if_exists is True and a checkpoint matching \*end.ckpt
            exists, indicating a previous training run fully completed. This behaviour can be disabled, in which
            case the \*end.ckpt will be loaded, by setting resume_past_end to True. Defaults to False.
        - resume_ignore_no_checkpoint: exp_manager errors out if resume_if_exists is True and no checkpoint could
            be found. This behaviour can be disabled, in which case exp_manager will print a message and continue
            without restoring, by setting resume_ignore_no_checkpoint to True. Defaults to False.
        - create_tensorboard_logger: Whether to create a tensorboard logger and attach it to the pytorch lightning
            trainer. Defaults to True.
        - summary_writer_kwargs: A dictionary of kwargs that can be passed to lightning's TensorboardLogger class.
            Note that log_dir is passed by exp_manager and cannot exist in this dict. Defaults to None.
        - create_wandb_logger: Whether to create a Weights and Biases logger and attach it to the pytorch lightning
            trainer. Defaults to False.
        - wandb_logger_kwargs: A dictionary of kwargs that can be passed to lightning's WandBLogger class. Note
            that name and project are required parameters if create_wandb_logger is True. Defaults to None.
        - create_checkpoint_callback: Whether to create a ModelCheckpoint callback and attach it to the pytorch
            lightning trainer. The ModelCheckpoint saves the top 3 models with the best "val_loss", the most recent
            checkpoint under \*last.ckpt, and the final checkpoint after training completes under \*end.ckpt.
            Defaults to True.
        - files_to_copy: A list of files to copy to the experiment logging directory. Defaults to None which copies
            no files.
        - log_local_rank_0_only: Whether to only create log files for local rank 0. Defaults to False. Set this to
            True if you are using DDP with many GPUs and do not want many log files in your exp dir.
        - log_global_rank_0_only: Whether to only create log files for global rank 0. Defaults to False. Set this
            to True if you are using DDP with many GPUs and do not want many log files in your exp dir.

    Returns
    -------
    The final logging directory where logging files are saved. Usually the concatenation of exp_dir, name, and
    version.
    """
    # Add rank information to logger
    # Note: trainer.global_rank and trainer.is_global_zero are not set until trainer.fit, so have to hack around it
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    global_rank = trainer.node_rank * trainer.num_devices + local_rank
    logging.rank = global_rank

    if cfg is None:
        logging.error("exp_manager did not receive a cfg argument. It will be disabled.")
        return None

    if trainer.fast_dev_run:
        logging.info("Trainer was called with fast_dev_run. exp_manager will return without any functionality.")
        return None

    # Ensure passed cfg is compliant with ExpManagerConfig
    schema = OmegaConf.structured(ExpManagerConfig)
    if isinstance(cfg, dict):
        cfg = OmegaConf.create(cfg)
    elif not isinstance(cfg, DictConfig):
        raise ValueError(f"cfg was type: {type(cfg)}. Expected either a dict or a DictConfig")
    cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=True))
    cfg = OmegaConf.merge(schema, cfg)

    # Ensures that trainer options are compliant with MRIDC and exp_manager arguments
    error_checks(trainer, cfg)

    log_dir, exp_dir, name, version = get_log_dir(
        trainer=trainer,
        exp_dir=cfg.exp_dir,
        name=cfg.name,
        version=cfg.version,
        explicit_log_dir=cfg.explicit_log_dir,
        use_datetime_version=cfg.use_datetime_version,
        resume_if_exists=cfg.resume_if_exists,
    )

    if cfg.resume_if_exists:
        check_resume(trainer, str(log_dir), cfg.resume_past_end, cfg.resume_ignore_no_checkpoint)

    checkpoint_name = name
    # If name returned from get_log_dir is "", use cfg.name for checkpointing
    if checkpoint_name is None or checkpoint_name == "":
        checkpoint_name = cfg.name or "default"
    cfg.name = name  # Used for configure_loggers so that the log_dir is properly set even if name is ""
    cfg.version = version

    # update app_state with log_dir, exp_dir, etc
    app_state = AppState()
    app_state.log_dir = log_dir
    app_state.exp_dir = exp_dir
    app_state.name = name
    app_state.version = version
    app_state.checkpoint_name = checkpoint_name
    app_state.create_checkpoint_callback = cfg.create_checkpoint_callback
    app_state.checkpoint_callback_params = cfg.checkpoint_callback_params

    # Create the logging directory if it does not exist
    os.makedirs(log_dir, exist_ok=True)  # Cannot limit creation to global zero as all ranks write to own log file
    logging.info(f"Experiments will be logged at {log_dir}")
    trainer._default_root_dir = log_dir

    if cfg.log_local_rank_0_only is True and cfg.log_global_rank_0_only is True:
        raise ValueError(
            "Cannot set both log_local_rank_0_only and log_global_rank_0_only to True. "
            "Please set either one or neither."
        )

    # This is set if the env var MRIDC_TESTING is set to True.
    mridc_testing = get_envbool(MRIDC_ENV_VARNAME_TESTING, False)

    log_file = log_dir / f"mridc_log_globalrank-{global_rank}_localrank-{local_rank}.txt"

    # Handle logging to file. Log only on local rank 0 / global rank 0 if requested, otherwise on all ranks.
    if cfg.log_local_rank_0_only and not mridc_testing:
        if local_rank == 0:
            logging.add_file_handler(log_file)
    elif cfg.log_global_rank_0_only and not mridc_testing:
        if global_rank == 0:
            logging.add_file_handler(log_file)
    else:
        # Log on all ranks.
        logging.add_file_handler(log_file)

    # For some reason, LearningRateLogger requires trainer to have a logger. Safer to create logger on all ranks,
    # not just global rank 0.
    if cfg.create_tensorboard_logger or cfg.create_wandb_logger:
        configure_loggers(
            trainer,
            [Path(exp_dir)],
            cfg.name,
            cfg.version,
            cfg.create_tensorboard_logger,
            cfg.summary_writer_kwargs,
            cfg.create_wandb_logger,
            cfg.wandb_logger_kwargs,
        )

    # add loggers timing callbacks
    if cfg.log_step_timing:
        timing_callback = TimingCallback(timer_kwargs=cfg.step_timing_kwargs or {})
        trainer.callbacks.insert(0, timing_callback)

    if cfg.create_checkpoint_callback:
        configure_checkpointing(
            trainer, log_dir, checkpoint_name, cfg.resume_if_exists, cfg.checkpoint_callback_params
        )

    if is_global_rank_zero():
        # Move files_to_copy to folder and add git information if present
        if cfg.files_to_copy:
            for _file in cfg.files_to_copy:
                copy(Path(_file), log_dir)

        # Create files for cmd args and git info
        with open(log_dir / "cmd-args.log", "w", encoding="utf-8") as _file:
            _file.write(" ".join(sys.argv))

        # Try to get git hash
        git_repo, git_hash = get_git_hash()
        if git_repo:
            with open(log_dir / "git-info.log", "w", encoding="utf-8") as _file:
                _file.write(f"commit hash: {git_hash}")
                _file.write(get_git_diff())

        # Add err_file logging to global_rank zero
        logging.add_err_file_handler(log_dir / "mridc_error_log.txt")

        # Add lightning file logging to global_rank zero
        add_filehandlers_to_pl_logger(log_dir / "lightning_logs.txt", log_dir / "mridc_error_log.txt")

    return log_dir
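# ----------------------------------------------------------------------------------------------
# Illustrative sketch (not part of the library): a minimal exp_manager call. The Trainer arguments
# are assumptions; the cfg keys are limited to options documented in the docstring above.
if __name__ == "__main__":
    from pytorch_lightning import Trainer

    _trainer = Trainer(accelerator="cpu", devices=1, max_epochs=1)
    _log_dir = exp_manager(
        _trainer,
        {
            "exp_dir": "./mridc_experiments",
            "name": "example_run",
            "create_tensorboard_logger": True,
            "create_checkpoint_callback": True,
            "resume_if_exists": False,
        },
    )
    print(f"Logging to: {_log_dir}")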