Example #1
    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        """
        Override the default on_save_checkpoint to save the best model if needed.

        Parameters
        ----------
        trainer: The trainer object.
        pl_module: The PyTorch-Lightning module.
        checkpoint: The checkpoint object.
        """
        output = super().on_save_checkpoint(trainer, pl_module, checkpoint)
        if not self.always_save_mridc:
            return output
        # Load the best model and then re-save it
        app_state = AppState()

        if app_state.model_parallel_size is not None and app_state.model_parallel_size > 1:
            raise ValueError(
                "always_save_mridc is not implemented for model parallel models."
            )

        # since we are creating tarfile artifacts we need to update .mridc path
        app_state.model_restore_path = os.path.abspath(
            os.path.expanduser(
                os.path.join(self.dirpath, self.prefix + self.postfix)))

        if self.save_best_model:
            if not os.path.exists(self.best_model_path):
                return output

            if self.best_model_path == self.previous_best_path:
                return output

            # Track the last exported best checkpoint so it is not re-saved on every call
            self.previous_best_path = self.best_model_path
            old_state_dict = deepcopy(pl_module.state_dict())
            checkpoint = torch.load(self.best_model_path, map_location="cpu")
            if "state_dict" in checkpoint:
                checkpoint = checkpoint["state_dict"]

            # load the best weights into the module, export them to .mridc, then restore the previous weights
            pl_module.load_state_dict(checkpoint, strict=True)
            pl_module.save_to(save_path=app_state.model_restore_path)
            pl_module.load_state_dict(old_state_dict, strict=True)
        else:
            pl_module.save_to(save_path=app_state.model_restore_path)
        return output
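To trigger this override in practice, the two flags it reads (always_save_mridc and save_best_model) would normally be supplied through the experiment manager's checkpoint settings (see Example #8). A minimal sketch, assuming these names are accepted as checkpoint_callback_params keys; this is not a verified mridc config:

# Sketch only: the key names below are assumptions based on the attributes read above.
exp_manager_cfg = {
    "create_checkpoint_callback": True,
    "checkpoint_callback_params": {
        "save_best_model": True,    # re-export the best checkpoint as a .mridc archive
        "always_save_mridc": True,  # write a .mridc archive on every checkpoint save
    },
}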
Example #2
def inject_model_parallel_rank(filepath):
    """Injects tensor/pipeline model parallel ranks into the filepath. Does nothing if not using model parallelism."""
    filepath = uninject_model_parallel_rank(filepath)
    app_state = AppState()
    if app_state.model_parallel_size is not None and app_state.model_parallel_size > 1:
        # filepath needs to be updated to include mp_rank
        dirname = os.path.dirname(filepath)
        basename = os.path.basename(filepath)
        if app_state.pipeline_model_parallel_size is None or app_state.pipeline_model_parallel_size == 1:
            filepath = f"{dirname}/mp_rank_{app_state.tensor_model_parallel_rank:02d}/{basename}"
        else:
            filepath = (
                f"{dirname}/tp_rank_{app_state.tensor_model_parallel_rank:02d}_pp_rank_"
                f"{app_state.pipeline_model_parallel_rank:03d}/{basename}")
    return filepath
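To make the rank-injection format concrete, here is a self-contained sketch; the helper _format_mp_path is hypothetical and only mirrors the f-strings above (the real function reads the ranks from AppState):

import os


def _format_mp_path(filepath: str, tp_rank: int, pp_rank=None) -> str:
    """Hypothetical helper mirroring the rank-injection path format shown above."""
    dirname, basename = os.path.dirname(filepath), os.path.basename(filepath)
    if pp_rank is None:
        # tensor parallelism only
        return f"{dirname}/mp_rank_{tp_rank:02d}/{basename}"
    # tensor + pipeline parallelism
    return f"{dirname}/tp_rank_{tp_rank:02d}_pp_rank_{pp_rank:03d}/{basename}"


assert _format_mp_path("/ckpts/model.ckpt", 0) == "/ckpts/mp_rank_00/model.ckpt"
assert _format_mp_path("/ckpts/model.ckpt", 1, 3) == "/ckpts/tp_rank_01_pp_rank_003/model.ckpt"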
Example #3
    def _del_model_without_trainer(self, filepath: str) -> None:
        """
        Delete a model without a trainer.

        Parameters
        ----------
        filepath: The path to the model to delete.
        """
        app_state = AppState()
        if app_state.model_parallel_size is not None and app_state.model_parallel_size > 1:
            # filepath needs to be updated to include mp_rank
            filepath = mridc.utils.model_utils.inject_model_parallel_rank(
                filepath)  # type: ignore

        # each model parallel rank needs to remove its model
        if is_global_rank_zero() or (app_state.model_parallel_size is not None
                                     and app_state.data_parallel_rank == 0):
            try:
                self._fs.rm(filepath)
                logging.info(f"Removed checkpoint: {filepath}")
            except FileNotFoundError:
                logging.info(
                    f"Tried to remove checkpoint {filepath}, but it was not found.")
Example #4
    def test_restore_from_save_restore_connector_extracted_dir(self):
        class MySaveRestoreConnector(save_restore_connector.SaveRestoreConnector):
            def save_to(self, model, save_path: str):
                save_path = save_path.replace(".mridc", "_XYZ.mridc")
                super().save_to(model, save_path)

        class MockModelV2(MockModel):
            pass

        with tempfile.TemporaryDirectory() as extracted_tempdir:
            with tempfile.TemporaryDirectory() as tmpdir:
                # Update config
                cfg = _mock_model_config()

                # Create model
                save_path = os.path.join(tmpdir, "save_custom.mridc")
                model_with_custom_connector = MockModel(cfg=cfg.model, trainer=None)
                model_with_custom_connector._save_restore_connector = MySaveRestoreConnector()
                model_with_custom_connector.save_to(save_path)

                mridc_filepath = os.path.join(tmpdir, "save_custom_XYZ.mridc")
                assert os.path.exists(mridc_filepath)

                # extract the contents into extracted_tempdir a priori,
                # simulating a pre-extracted archive before restore_from is called
                connector = MySaveRestoreConnector()
                MySaveRestoreConnector._unpack_mridc_file(mridc_filepath, extracted_tempdir)
                assert get_size(extracted_tempdir) > 0

            # the inner `with` block has exited, so tmpdir (and the .mridc file inside it) is deleted;
            # only the pre-extracted directory remains

            # next, set the model's extracted directory path
            connector.model_extracted_dir = extracted_tempdir

            # note, we pass in the "old" mridc_filepath, stored somewhere other than the extracted directory
            # this mridc_filepath is no longer valid, and has been deleted.
            restored_model = MockModelV2.restore_from(mridc_filepath, save_restore_connector=connector)
        assert type(restored_model) == MockModelV2
        assert type(restored_model._save_restore_connector) == MySaveRestoreConnector

        # assert models have correct restoration information and paths
        appstate = AppState()
        original_metadata = appstate.get_model_metadata_from_guid(model_with_custom_connector.model_guid)
        assert original_metadata.restoration_path is None

        restored_metadata = appstate.get_model_metadata_from_guid(restored_model.model_guid)
        assert restored_metadata.restoration_path is not None

        # assert that the restore path was the path of the pre-extracted directory
        # irrespective of whether an old `mridc_filepath` (which doesn't exist anymore) was passed to restore_from.
        assert extracted_tempdir in restored_metadata.restoration_path
        assert extracted_tempdir not in mridc_filepath
        assert not os.path.exists(mridc_filepath)

        # test for parameter equality
        model_with_custom_connector = model_with_custom_connector.to("cpu")
        restored_model = restored_model.to("cpu")

        original_state_dict = model_with_custom_connector.state_dict()
        restored_state_dict = restored_model.state_dict()
        for orig, restored in zip(original_state_dict.keys(), restored_state_dict.keys()):
            assert (original_state_dict[orig] - restored_state_dict[restored]).abs().mean() < 1e-6
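The zip-based comparison above pairs keys by iteration order; a slightly more defensive sketch (not part of the original test) matches parameters by name and uses torch.allclose with an explicit tolerance:

import torch


def assert_state_dicts_close(original, restored, atol=1e-6):
    # Sketch: compare parameters by key rather than by position, with an explicit tolerance
    assert original.keys() == restored.keys()
    for key in original:
        assert torch.allclose(original[key], restored[key], atol=atol), key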
Example #5
    def load_config_and_state_dict(
        self,
        calling_cls,
        restore_path: str,
        override_config_path: Optional[Union[OmegaConf, str]] = None,
        map_location: Optional[torch.device] = None,
        strict: bool = True,
        return_config: bool = False,
        trainer: Trainer = None,
    ):
        """
        Restores a model instance (weights and configuration) from a .mridc file.

        Parameters
        ----------
        calling_cls: Class of the model to be restored.
        restore_path: path to .mridc file from which model should be instantiated
        override_config_path: path to a yaml config that will override the internal config file or an
        OmegaConf/DictConfig object representing the model config.
        map_location: Optional torch.device() to map the instantiated model to a device. By default (None), it will
        select a GPU if available, falling back to CPU otherwise.
        strict: Passed to load_state_dict. By default, True.
        return_config: If set to true, will return just the underlying config of the restored model as an OmegaConf
        DictConfig object without instantiating the model.
        trainer: Optional trainer object to be used for model parallelism.

        Example
        -------
            ```
            model = mridc.collections.asr.models.EncDecCTCModel.restore_from('asr.mridc')
            assert isinstance(model, mridc.collections.asr.models.EncDecCTCModel)
            ```

        Returns
        -------
        An instance of type cls or its underlying config (if return_config is set).
        """
        # Get path where the command is executed - the artifacts will be "retrieved" there
        # (original .mridc behavior)
        cwd = os.getcwd()

        if map_location is None:
            if torch.cuda.is_available():
                map_location = torch.device("cuda")
            else:
                map_location = torch.device("cpu")

        app_state = AppState()
        with tempfile.TemporaryDirectory() as tmpdir:
            try:
                # Check if self.model_extracted_dir is set, and is a valid path
                if self.model_extracted_dir is not None and os.path.isdir(
                        self.model_extracted_dir):
                    # Log that MRIDC will use the provided `model_extracted_dir`
                    logging.info(
                        "Restoration will occur within pre-extracted directory : "
                        f"`{self.model_extracted_dir}`.")
                    # Override `tmpdir` above with the pre-extracted `model_extracted_dir`
                    tmpdir = self.model_extracted_dir
                else:
                    # Extract the .mridc file into the temporary directory
                    self._unpack_mridc_file(path2file=restore_path,
                                            out_folder=tmpdir)

                # Change current working directory to the temporary directory
                os.chdir(tmpdir)
                if override_config_path is None:
                    config_yaml = os.path.join(tmpdir, self.model_config_yaml)
                else:
                    # can be str path or OmegaConf / DictConfig object
                    config_yaml = override_config_path
                if not isinstance(config_yaml, (OmegaConf, DictConfig)):
                    conf = OmegaConf.load(config_yaml)
                else:
                    conf = config_yaml
                    if override_config_path is not None:
                        # Resolve the override config
                        conf = OmegaConf.to_container(conf, resolve=True)
                        conf = OmegaConf.create(conf)
                # If override is top level config, extract just `model` from it
                if "model" in conf:
                    conf = conf.model

                if return_config:
                    instance = conf
                    return instance
                if app_state.model_parallel_rank is not None and app_state.model_parallel_size > 1:
                    model_weights = self._inject_model_parallel_rank_for_ckpt(
                        tmpdir, self.model_weights_ckpt)
                else:
                    model_weights = os.path.join(tmpdir,
                                                 self.model_weights_ckpt)
                OmegaConf.set_struct(conf, True)
                os.chdir(cwd)
                # mark the class as being restored, then build the instance from its config
                calling_cls._set_model_restore_state(
                    is_being_restored=True, folder=tmpdir)  # type: ignore
                instance = calling_cls.from_config_dict(config=conf,
                                                        trainer=trainer)
                instance = instance.to(map_location)
                # for model parallel models, re-resolve the rank-specific weights path before loading
                if app_state.model_parallel_size is not None and app_state.model_parallel_size > 1:
                    model_weights = self._inject_model_parallel_rank_for_ckpt(
                        tmpdir, self.model_weights_ckpt)
                instance.load_state_dict(self._load_state_dict_from_disk(
                    model_weights, map_location=map_location),
                                         strict=strict)
                logging.info(
                    f"Model {instance.__class__.__name__} was successfully restored from {restore_path}."
                )
                instance._set_model_restore_state(
                    is_being_restored=False)  # type: ignore
            finally:
                os.chdir(cwd)

        return instance
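Following the docstring's own example, a hedged usage sketch of the two restore modes handled above; MyModel and model.mridc are placeholders, and it is assumed that restore_from forwards return_config, map_location, and strict to this connector method:

import torch

# Placeholders: MyModel stands for any restorable model class, "model.mridc" for a saved archive.
cfg = MyModel.restore_from("model.mridc", return_config=True)  # config only, no model instantiation
model = MyModel.restore_from(
    "model.mridc", map_location=torch.device("cpu"), strict=False)  # full restore onto CPU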
Example #6
    def _handle_artifacts(self, model, mridc_file_folder):
        """
        This method is called by ModelPT.save_to() and ModelPT.load_from(). It will handle all artifacts and save them
        to the mridc file.

        Parameters
        ----------
        model: ModelPT object to register artifact for.
        mridc_file_folder: Path to the mridc file.
        """
        tarfile_artifacts = []
        app_state = AppState()
        for conf_path, artiitem in model.artifacts.items():
            if artiitem.path_type == mridc.utils.model_utils.ArtifactPathType.LOCAL_PATH:
                if not os.path.exists(artiitem.path):
                    raise FileNotFoundError(
                        f"Artifact {conf_path} not found at location: {artiitem.path}"
                    )

                # Generate a new unique artifact name and copy it to mridc_file_folder
                # Note: uuid.uuid4().hex is guaranteed to be 32 characters long
                artifact_base_name = os.path.basename(artiitem.path)
                artifact_uniq_name = f"{uuid.uuid4().hex}_{artifact_base_name}"
                shutil.copy2(
                    artiitem.path,
                    os.path.join(mridc_file_folder, artifact_uniq_name))

                # Update artifacts registry
                artiitem.hashed_path = f"mridc:{artifact_uniq_name}"
                model.artifacts[conf_path] = artiitem

            elif artiitem.path_type == mridc.utils.model_utils.ArtifactPathType.TAR_PATH:
                # process all tarfile artifacts in one go later, so preserve the key-value pair for now
                tarfile_artifacts.append((conf_path, artiitem))

            else:
                raise ValueError(
                    "Directly referencing artifacts from other mridc files isn't supported yet"
                )

        # Process the collected tarfile artifacts by unpacking the previous tarfile and extracting the artifacts
        # that are currently required.
        model_metadata = app_state.get_model_metadata_from_guid(
            model.model_guid)
        if tarfile_artifacts and model_metadata.restoration_path is not None:
            # Need to step into mridc archive to extract file
            # Get path where the command is executed - the artifacts will be "retrieved" there
            # (original .mridc behavior)
            cwd = os.getcwd()
            try:
                # Step into the mridc archive to try and find the file
                with tempfile.TemporaryDirectory() as archive_dir:
                    self._unpack_mridc_file(
                        path2file=model_metadata.restoration_path,
                        out_folder=archive_dir)
                    os.chdir(archive_dir)
                    for conf_path, artiitem in tarfile_artifacts:
                        # Get basename and copy it to mridc_file_folder
                        if "mridc:" in artiitem.path:
                            artifact_base_name = artiitem.path.split(
                                "mridc:")[1]
                        else:
                            artifact_base_name = os.path.basename(
                                artiitem.path)
                        # no need to hash here as we are in tarfile_artifacts which are already hashed
                        artifact_uniq_name = artifact_base_name
                        shutil.copy2(
                            artifact_base_name,
                            os.path.join(mridc_file_folder,
                                         artifact_uniq_name))

                        # Update artifacts registry
                        new_artiitem = mridc.utils.model_utils.ArtifactItem()
                        new_artiitem.path = f"mridc:{artifact_uniq_name}"
                        new_artiitem.path_type = mridc.utils.model_utils.ArtifactPathType.TAR_PATH
                        model.artifacts[conf_path] = new_artiitem
            finally:
                # change back working directory
                os.chdir(cwd)
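The unique-name scheme used above for LOCAL_PATH artifacts can be reproduced in isolation; a small sketch (the artifact path is made up):

import os
import uuid

artifact_path = "/data/vocab.txt"  # made-up local artifact
artifact_base_name = os.path.basename(artifact_path)
# 32-character uuid4 hex prefix plus the original basename, as in _handle_artifacts
artifact_uniq_name = f"{uuid.uuid4().hex}_{artifact_base_name}"
print(artifact_uniq_name)  # e.g. "3f2c...9e1_vocab.txt"
# the registry entry then stores hashed_path = "mridc:" + artifact_uniq_name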
Example #7
    def register_artifact(model,
                          config_path: str,
                          src: str,
                          verify_src_exists: bool = True):
        """
        Register model artifacts with this function. These artifacts (files) will be included inside .mridc file
        when model.save_to("mymodel.mridc") is called.

        How it works:
        1. It always returns an existing absolute path which can be used during the Model constructor call.
        EXCEPTION: if src is None or "", nothing is done and src is returned unchanged.
        2. It adds a (config_path, model_utils.ArtifactItem()) pair to self.artifacts. If "src" is an existing local
        path, it is returned in absolute form. If "src" starts with "mridc:" (an artifact stored inside a .mridc
        file), the .mridc file is untarred to a temporary folder and an actual existing path is returned. Otherwise
        an error is raised.

        WARNING: use .register_artifact calls in your models' constructors.
        The returned path is not guaranteed to exist after you have exited your model's constructor.

        Parameters
        ----------
        model: ModelPT object to register artifact for.
        config_path: Artifact key. Usually corresponds to the model config.
        src: Path to artifact.
        verify_src_exists: If set to False, then the artifact is optional and register_artifact will return None
         even if src is not found. Defaults to True.

        Returns
        -------
        If src is not None or empty, an absolute path which is guaranteed to exist for the lifetime of the model
        instance.
        app_state = AppState()

        artifact_item = mridc.utils.model_utils.ArtifactItem()  # type: ignore

        # This is for backward compatibility, if the src objects exists simply inside the tarfile
        # without its key having been overridden, this pathway will be used.
        src_obj_name = os.path.basename(src)
        if app_state.mridc_file_folder is not None:
            src_obj_path = os.path.abspath(
                os.path.join(app_state.mridc_file_folder, src_obj_name))
        else:
            src_obj_path = src_obj_name

        # src is a local existing path - register artifact and return exact same path for usage by the model
        if os.path.exists(os.path.abspath(src)):
            return_path = os.path.abspath(src)
            artifact_item.path_type = mridc.utils.model_utils.ArtifactPathType.LOCAL_PATH  # type: ignore

        elif src.startswith("mridc:"):
            return_path = os.path.abspath(
                os.path.join(app_state.mridc_file_folder, src[len("mridc:"):]))
            artifact_item.path_type = mridc.utils.model_utils.ArtifactPathType.TAR_PATH  # type: ignore

        elif os.path.exists(src_obj_path):
            return_path = src_obj_path
            artifact_item.path_type = mridc.utils.model_utils.ArtifactPathType.TAR_PATH  # type: ignore
        elif verify_src_exists:
            raise FileNotFoundError(
                f"src path does not exist or it is not a path in mridc file. src value I got was: {src}. "
                f"Absolute: {os.path.abspath(src)}")
        else:
            # artifact is optional and we simply return None
            return None

        if not os.path.exists(return_path):
            raise AssertionError(f"Artifact path does not exist: {return_path}")

        artifact_item.path = os.path.abspath(src)
        model.artifacts[config_path] = artifact_item
        # we were called by ModelPT
        if hasattr(model, "cfg"):
            with open_dict(model._cfg):
                OmegaConf.update(model.cfg, config_path, return_path)
        return return_path
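A hedged sketch of the calling pattern the docstring recommends: register artifacts from inside the model constructor. MyModel, its config keys, and build_tokenizer are illustrative only, and it is assumed that ModelPT exposes a register_artifact method that delegates to the connector function above:

class MyModel(ModelPT):  # hypothetical ModelPT subclass
    def __init__(self, cfg, trainer=None):
        super().__init__(cfg=cfg, trainer=trainer)
        # Registering in the constructor keeps the returned absolute path valid for this
        # instance's lifetime and packs the file when model.save_to("mymodel.mridc") is called.
        vocab_path = self.register_artifact("tokenizer.vocab_file", cfg.tokenizer.vocab_file)
        self.tokenizer = build_tokenizer(vocab_path)  # hypothetical helper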
Example #8
def exp_manager(
        trainer: Trainer,
        cfg: Optional[Union[DictConfig, Dict]] = None) -> Optional[Path]:
    """
    exp_manager is a helper function used to manage folders for experiments. It follows the pytorch lightning \
    paradigm of exp_dir/model_or_experiment_name/version. If the lightning trainer has a logger, exp_manager will \
    get exp_dir, name, and version from the logger. Otherwise, it will use the exp_dir and name arguments to create \
    the logging directory. exp_manager also allows for explicit folder creation via explicit_log_dir.

    The version can be a datetime string or an integer. The datetime version can be disabled by setting \
    use_datetime_version to False. exp_manager optionally creates TensorBoardLogger, WandBLogger, and \
    ModelCheckpoint objects from pytorch lightning. It copies sys.argv and, if available, git information to the \
    logging directory, and it creates a log file for each process to log its output into.

    exp_manager additionally has a resume feature (resume_if_exists) which can be used to continue training from \
    the constructed log_dir. When you need to continue training repeatedly (for example on a cluster that requires \
    multiple consecutive jobs), you need to avoid creating the version folders. Therefore, from v1.0.0, when \
    resume_if_exists is set to True, the version folders are not created.

    Parameters
    ----------
    trainer: The lightning trainer object.
    cfg: Can have the following keys:
        - explicit_log_dir: Can be used to override exp_dir/name/version folder creation. Defaults to None, which \
        will use exp_dir, name, and version to construct the logging directory.
        - exp_dir: The base directory to create the logging directory. Defaults to None, which logs to \
         ./mridc_experiments.
        - name: The name of the experiment. Defaults to None which turns into "default" via name = name or "default".
        - version: The version of the experiment. Defaults to None which uses either a datetime string or lightning's \
         TensorboardLogger system of using version_{int}.
        - use_datetime_version: Whether to use a datetime string for version. Defaults to True.
        - resume_if_exists: Whether this experiment is resuming from a previous run. If True, it sets \
        trainer._checkpoint_connector.resume_from_checkpoint_fit_path so that the trainer auto-resumes. \
        exp_manager will move files under log_dir to log_dir/run_{int}. Defaults to False. From v1.0.0, when \
        resume_if_exists is True, version folders are not created, to make it easier to find the log folder for \
        subsequent runs.
        - resume_past_end: exp_manager errors out if resume_if_exists is True and a checkpoint matching \*end.ckpt \
        is found, indicating that a previous training run fully completed. This behaviour can be disabled by \
        setting resume_past_end to True, in which case the \*end.ckpt will be loaded. Defaults to False.
        - resume_ignore_no_checkpoint: exp_manager errors out if resume_if_exists is True and no checkpoint could be \
        found. This behaviour can be disabled by setting resume_ignore_no_checkpoint to True, in which case \
        exp_manager will print a message and continue without restoring. Defaults to False.
        - create_tensorboard_logger: Whether to create a tensorboard logger and attach it to the pytorch lightning \
        trainer. Defaults to True.
        - summary_writer_kwargs: A dictionary of kwargs that can be passed to lightning's TensorboardLogger class. \
        Note that log_dir is passed by exp_manager and cannot exist in this dict. Defaults to None.
        - create_wandb_logger: Whether to create a Weights and Biases logger and attach it to the pytorch lightning \
        trainer. Defaults to False.
        - wandb_logger_kwargs: A dictionary of kwargs that can be passed to lightning's WandBLogger class. Note that \
         name and project are required parameters if create_wandb_logger is True. Defaults to None.
        - create_checkpoint_callback: Whether to create a ModelCheckpoint callback and attach it to the pytorch \
        lightning trainer. The ModelCheckpoint saves the top 3 models with the best "val_loss", the most recent \
        checkpoint under \*last.ckpt, and the final checkpoint after training completes under \*end.ckpt. \
        Defaults to True.
        - files_to_copy: A list of files to copy to the experiment logging directory. Defaults to None which copies \
        no files.
        - log_local_rank_0_only: Whether to only create log files for local rank 0. Defaults to False. Set this to \
        True if you are using DDP with many GPUs and do not want many log files in your exp dir.
        - log_global_rank_0_only: Whether to only create log files for global rank 0. Defaults to False. Set this to \
        True if you are using DDP with many GPUs and do not want many log files in your exp dir.

    Returns
    -------
    The final logging directory where logging files are saved. Usually the concatenation of exp_dir, name, and version.
    """
    # Add rank information to logger
    # Note: trainer.global_rank and trainer.is_global_zero are not set until trainer.fit, so have to hack around it
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    global_rank = trainer.node_rank * trainer.num_devices + local_rank
    logging.rank = global_rank

    if cfg is None:
        logging.error(
            "exp_manager did not receive a cfg argument. It will be disabled.")
        return None

    if trainer.fast_dev_run:
        logging.info(
            "Trainer was called with fast_dev_run. exp_manager will return without any functionality."
        )
        return None

    # Ensure passed cfg is compliant with ExpManagerConfig
    schema = OmegaConf.structured(ExpManagerConfig)
    if isinstance(cfg, dict):
        cfg = OmegaConf.create(cfg)
    elif not isinstance(cfg, DictConfig):
        raise ValueError(
            f"cfg was type: {type(cfg)}. Expected either a dict or a DictConfig"
        )
    cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=True))
    cfg = OmegaConf.merge(schema, cfg)

    error_checks(
        trainer, cfg
    )  # Ensures that trainer options are compliant with MRIDC and exp_manager arguments

    log_dir, exp_dir, name, version = get_log_dir(
        trainer=trainer,
        exp_dir=cfg.exp_dir,
        name=cfg.name,
        version=cfg.version,
        explicit_log_dir=cfg.explicit_log_dir,
        use_datetime_version=cfg.use_datetime_version,
        resume_if_exists=cfg.resume_if_exists,
    )

    if cfg.resume_if_exists:
        check_resume(trainer, str(log_dir), cfg.resume_past_end,
                     cfg.resume_ignore_no_checkpoint)

    checkpoint_name = name
    # If name returned from get_log_dir is "", use cfg.name for checkpointing
    if checkpoint_name is None or checkpoint_name == "":
        checkpoint_name = cfg.name or "default"
    cfg.name = name  # Used for configure_loggers so that the log_dir is properly set even if name is ""
    cfg.version = version

    # update app_state with log_dir, exp_dir, etc
    app_state = AppState()
    app_state.log_dir = log_dir
    app_state.exp_dir = exp_dir
    app_state.name = name
    app_state.version = version
    app_state.checkpoint_name = checkpoint_name
    app_state.create_checkpoint_callback = cfg.create_checkpoint_callback
    app_state.checkpoint_callback_params = cfg.checkpoint_callback_params

    # Create the logging directory if it does not exist
    os.makedirs(
        log_dir, exist_ok=True
    )  # Cannot limit creation to global zero as all ranks write to own log file
    logging.info(f"Experiments will be logged at {log_dir}")
    trainer._default_root_dir = log_dir

    if cfg.log_local_rank_0_only is True and cfg.log_global_rank_0_only is True:
        raise ValueError(
            "Cannot set both log_local_rank_0_only and log_global_rank_0_only to True. "
            "Please set either one or neither.")

    # This is set if the env var MRIDC_TESTING is set to True.
    mridc_testing = get_envbool(MRIDC_ENV_VARNAME_TESTING, False)

    log_file = log_dir / f"mridc_log_globalrank-{global_rank}_localrank-{local_rank}.txt"

    # Handle logging to file
    if cfg.log_local_rank_0_only is True and not mridc_testing:
        if local_rank == 0:
            logging.add_file_handler(log_file)
    elif cfg.log_global_rank_0_only is True and not mridc_testing:
        if global_rank == 0:
            logging.add_file_handler(log_file)
    else:
        # logs on all ranks otherwise
        logging.add_file_handler(log_file)

    # For some reason, LearningRateLogger requires trainer to have a logger. Safer to create logger on all ranks
    # not just global rank 0.
    if cfg.create_tensorboard_logger or cfg.create_wandb_logger:
        configure_loggers(
            trainer,
            [Path(exp_dir)],
            cfg.name,
            cfg.version,
            cfg.create_tensorboard_logger,
            cfg.summary_writer_kwargs,
            cfg.create_wandb_logger,
            cfg.wandb_logger_kwargs,
        )

    # add loggers timing callbacks
    if cfg.log_step_timing:
        timing_callback = TimingCallback(
            timer_kwargs=cfg.step_timing_kwargs or {})
        trainer.callbacks.insert(0, timing_callback)

    if cfg.create_checkpoint_callback:
        configure_checkpointing(trainer, log_dir, checkpoint_name,
                                cfg.resume_if_exists,
                                cfg.checkpoint_callback_params)

    if is_global_rank_zero():
        # Move files_to_copy to folder and add git information if present
        if cfg.files_to_copy:
            for _file in cfg.files_to_copy:
                copy(Path(_file), log_dir)

        # Create files for cmd args and git info
        with open(log_dir / "cmd-args.log", "w", encoding="utf-8") as _file:
            _file.write(" ".join(sys.argv))

        # Try to get git hash
        git_repo, git_hash = get_git_hash()
        if git_repo:
            with open(log_dir / "git-info.log", "w",
                      encoding="utf-8") as _file:
                _file.write(f"commit hash: {git_hash}")
                _file.write(get_git_diff())

        # Add err_file logging to global_rank zero
        logging.add_err_file_handler(log_dir / "mridc_error_log.txt")

        # Add lightning file logging to global_rank zero
        add_filehandlers_to_pl_logger(log_dir / "lightning_logs.txt",
                                      log_dir / "mridc_error_log.txt")

    return log_dir
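A hedged end-to-end usage sketch; the directory and experiment names are placeholders and the keys mirror the cfg documentation above:

import pytorch_lightning as pl

trainer = pl.Trainer(max_epochs=10)
log_dir = exp_manager(
    trainer,
    {
        "exp_dir": "./mridc_experiments",  # base directory for logs
        "name": "my_experiment",           # experiment name
        "use_datetime_version": True,      # datetime-based version folder
        "resume_if_exists": False,         # set True to auto-resume from log_dir
        "create_tensorboard_logger": True,
        "create_checkpoint_callback": True,
    },
)
# trainer.fit(model) would then write checkpoints and TensorBoard events under log_dir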