Example #1
def _convert_config(cfg: "OmegaConf"):
    """
    Recursively convert a configuration from the old Hydra format to the new one.

    The conversion is done in place: a ``cls`` key is renamed to ``_target_`` (unless ``_target_`` is already
    present), and the entries of a ``params`` sub-config are moved up into ``cfg`` itself. Every ``DictConfig``
    value is then converted the same way.

    Parameters
    ----------
    cfg: The config node to convert. It is modified in place.

    Raises
    ------
    ModuleNotFoundError: If Hydra/OmegaConf is not installed.
    """
    if not _HAS_HYDRA:
        raise ModuleNotFoundError(
            "This function requires Hydra/OmegaConf and it was not installed.")

    # Get rid of cls -> _target_.
    if "cls" in cfg and "_target_" not in cfg:
        cfg._target_ = cfg.pop("cls")  # type: ignore

    # Get rid of params.
    if "params" in cfg:
        params = cfg.pop("params")  # type: ignore
        for param_key, param_val in params.items():
            cfg[param_key] = param_val

    # Recursion. Iterating can fail on nodes OmegaConf cannot resolve; such a node is reported and left
    # unconverted instead of aborting the whole conversion.
    try:
        for _, sub_cfg in cfg.items():  # type: ignore
            if isinstance(sub_cfg, DictConfig):
                _convert_config(sub_cfg)  # type: ignore
    except OmegaConfBaseException as e:
        logging.warning(
            f"Skipped conversion for config/subconfig:\n{cfg}\n Reason: {e}.")
Example #2
def error_checks(trainer: Trainer,
                 cfg: Optional[Union[DictConfig, Dict]] = None):
    """
    Checks that the passed trainer is compliant with MRIDC and exp_manager's passed configuration. Checks that:
        - Throws error when hydra has changed the working directory. This causes issues with lightning's DDP
        - Throws error when trainer has loggers defined but create_tensorboard_logger or create_WandB_logger is True
        - Prints error messages when 1) run on multi-node and not Slurm, and 2) run on multi-gpu without DDP

    Parameters
    ----------
    trainer: The lightning trainer to check.
    cfg: The exp_manager configuration; ``create_tensorboard_logger`` and ``create_wandb_logger`` are read from it
        when the trainer already has a logger.

    Raises
    ------
    ValueError: If Hydra changed the working directory.
    LoggerMisconfigurationError: If the trainer has a logger and the configuration asks for another one.
    """
    if HydraConfig.initialized() and get_original_cwd() != os.getcwd():
        raise ValueError(
            "Hydra changed the working directory. This interferes with ExpManger's functionality. Please pass "
            "hydra.run.dir=. to your python script.")

    if trainer.logger is not None and (
            cfg.create_tensorboard_logger  # type: ignore
            or cfg.create_wandb_logger):  # type: ignore
        raise LoggerMisconfigurationError(
            "The pytorch lightning trainer that was passed to exp_manager contained a logger, and either "
            "create_tensorboard_logger or create_wandb_logger was set to True. These can only be used if trainer does "
            "not already have a logger.")

    # The two checks below only report the untested setup; they do not stop the run.
    if trainer.num_nodes > 1 and not check_slurm(trainer):  # type: ignore
        logging.error(
            "You are running multi-node training without SLURM handling the processes."
            " Please note that this is not tested in MRIDC and could result in errors."
        )

    if trainer.num_devices > 1 and not isinstance(trainer.strategy,
                                                  DDPStrategy):  # type: ignore
        logging.error(
            "You are running multi-gpu without ddp.Please note that this is not tested in MRIDC and could result in "
            "errors.")
Example #3
def convert_model_config_to_dict_config(
        cfg: Union[DictConfig, MRIDCConfig]) -> DictConfig:
    """
    Converts its input into a standard DictConfig.

    Possible input values are:
        - DictConfig
        - A dataclass which is a subclass of MRIDCConfig

    Parameters
    ----------
    cfg: A dict-like object.

    Returns
    -------
    The equivalent DictConfig.

    Raises
    ------
    ModuleNotFoundError: If Hydra/OmegaConf is not installed.
    ValueError: If ``cfg`` is neither a DictConfig nor a dataclass.
    """
    if not _HAS_HYDRA:
        raise ModuleNotFoundError(
            "This function requires Hydra/OmegaConf and it was not installed.")
    if not isinstance(cfg, (OmegaConf, DictConfig)) and is_dataclass(cfg):
        cfg = OmegaConf.structured(cfg)

    if not isinstance(cfg, DictConfig):
        raise ValueError(
            f"cfg constructor argument must be of type DictConfig/dict but got {type(cfg)} instead."
        )

    # Round-trip through a plain container with resolve=True so the result holds resolved values.
    resolved = OmegaConf.to_container(cfg, resolve=True)
    return OmegaConf.create(resolved)
Example #4
def _add_subconfig_keys(model_cfg: "DictConfig", update_cfg: "DictConfig",
                        subconfig_key: str):
    """
    For certain sub-configs, the default values specified by the MRIDCConfig class is insufficient.
    In order to support every potential value in the merge between the `update_cfg`, it would require explicit
    definition of all possible cases.
    An example of such a case is Optimizers, and their equivalent Schedulers. All optimizers share a few basic details
    - such as name and lr, but almost all require additional parameters - such as weight decay.
    It is impractical to create a config for every single optimizer + every single scheduler combination.
    In such a case, we perform a dual merge. The Optim and Sched Dataclass contain the bare minimum essential
    components. The extra values are provided via update_cfg.
    In order to enable the merge, we first need to update the update sub-config to incorporate the keys, with dummy
    temporary values (merge update config with model config). This is done on a copy of the update sub-config, as the
    actual override values might be overridden by the MRIDCConfig defaults.
    Then we perform a merge of this temporary sub-config with the actual override config in a later step (merge
    model_cfg with original update_cfg, done outside this function).

    Parameters
    ----------
    model_cfg: A DictConfig instantiated from the MRIDCConfig subclass.
    update_cfg: A DictConfig that mirrors the structure of `model_cfg`, used to update its default values.
    subconfig_key: A str key used to check and update the sub-config.

    Returns
    -------
    A ModelPT DictConfig with additional keys added to the sub-config.

    Raises
    ------
    ModuleNotFoundError: If Hydra/OmegaConf is not installed.
    """
    if not _HAS_HYDRA:
        raise ModuleNotFoundError(
            "This function requires Hydra/Omegaconf and it was not installed.")
    with open_dict(model_cfg.model):
        # Create copy of original model sub config
        if subconfig_key in update_cfg.model:
            if subconfig_key not in model_cfg.model:
                # create the key as a placeholder
                model_cfg.model[subconfig_key] = None

            subconfig = copy.deepcopy(model_cfg.model[subconfig_key])
            update_subconfig = copy.deepcopy(update_cfg.model[subconfig_key])

            # Add the keys and update temporary values, will be updated during full merge
            subconfig = OmegaConf.merge(update_subconfig, subconfig)
            # Update sub config
            model_cfg.model[subconfig_key] = subconfig

    return model_cfg
Example #5
def check_explicit_log_dir(trainer: Trainer,
                           explicit_log_dir: List[Union[Path, str]],
                           exp_dir: str, name: str,
                           version: str) -> Tuple[Path, str, str, str]:
    """
    Checks that the passed arguments are compatible with explicit_log_dir.

    Parameters
    ----------
    trainer: The trainer to check.
    explicit_log_dir: The explicit log dir to check.
    exp_dir: The experiment directory to check.
    name: The experiment name to check.
    version: The experiment version to check.

    Returns
    -------
    The log_dir, exp_dir, name, and version that should be used. The log_dir and exp_dir are both
    `explicit_log_dir`; name and version are empty strings.

    Raises
    ------
    LoggerMisconfigurationError: If the trainer already has a logger.
    """
    if trainer.logger is not None:
        raise LoggerMisconfigurationError(
            "The pytorch lightning trainer that was passed to exp_manager contained a logger and explicit_log_dir: "
            f"{explicit_log_dir} was pass to exp_manager. Please remove the logger from the lightning trainer."
        )
    # Checking only (explicit_log_dir) vs (exp_dir and version).
    # The `name` will be used as the actual name of checkpoint/archive.
    if exp_dir or version:
        logging.error(
            f"exp_manager received explicit_log_dir: {explicit_log_dir} and at least one of exp_dir: {exp_dir}, "
            f"or version: {version}. Please note that exp_dir, name, and version will be ignored."
        )
    log_dir = str(explicit_log_dir)
    if is_global_rank_zero() and Path(log_dir).exists():
        logging.warning(
            f"Exp_manager is logging to {explicit_log_dir}, but it already exists."
        )
    return Path(log_dir), log_dir, "", ""
Example #6
def maybe_update_config_version(cfg: "DictConfig"):
    """
    Recursively convert Hydra 0.x configs to Hydra 1.x configs.
    Changes include:
    -   `cls` -> `_target_`.
    -   `params` -> drop params and shift all arguments to parent.
    -   `target` -> `_target_` cannot be performed due to ModelPT injecting `target` inside class.

    Parameters
    ----------
    cfg: Any Hydra compatible DictConfig

    Returns
    -------
    An updated DictConfig that conforms to Hydra 1.x format. The input is returned unchanged when it is None or
    cannot be cast to an OmegaConf config.

    Raises
    ------
    ModuleNotFoundError: If Hydra/OmegaConf is not installed.
    """
    if not _HAS_HYDRA:
        raise ModuleNotFoundError(
            "This function requires Hydra/OmegaConf and it was not installed.")

    # Nothing to convert; the struct-flag and conversion calls below cannot operate on None.
    if cfg is None:
        return cfg

    if not isinstance(cfg, DictConfig):
        try:
            cfg = OmegaConf.create(cfg)
        except OmegaConfBaseException:
            # Cannot be cast to DictConfig, skip updating.
            return cfg

    # Make a copy of model config.
    cfg = copy.deepcopy(cfg)
    # Struct mode is switched off while converting because the conversion adds and removes keys.
    OmegaConf.set_struct(cfg, False)

    # Convert config.
    _convert_config(cfg)  # type: ignore

    # Update model config.
    OmegaConf.set_struct(cfg, True)

    return cfg
Example #7
def _update_subconfig(model_cfg: "DictConfig", update_cfg: "DictConfig",
                      subconfig_key: str, drop_missing_subconfigs: bool):
    """
    Updates the MRIDCConfig DictConfig such that:
        1)  If the sub-config key exists in the `update_cfg`, but does not exist in ModelPT config:
            - Add the sub-config from update_cfg to ModelPT config
        2) If the sub-config key does not exist in `update_cfg`, but exists in ModelPT config:
            - Remove the sub-config from the ModelPT config; iff the `drop_missing_subconfigs` flag is set.

    Parameters
    ----------
    model_cfg: A DictConfig instantiated from the MRIDCConfig subclass.
    update_cfg: A DictConfig that mirrors the structure of `model_cfg`, used to update its default values.
    subconfig_key: A str key used to check and update the sub-config.
    drop_missing_subconfigs: A bool flag, whether to allow deletion of the MRIDCConfig sub-config, if its mirror
        sub-config does not exist in the `update_cfg`.

    Returns
    -------
    The updated DictConfig for the MRIDCConfig

    Raises
    ------
    ModuleNotFoundError: If Hydra/OmegaConf is not installed.
    """
    if not _HAS_HYDRA:
        raise ModuleNotFoundError(
            "This function requires Hydra/Omegaconf and it was not installed.")
    with open_dict(model_cfg.model):
        # If update config has the key, but model cfg doesnt have the key
        # Add the update cfg subconfig to the model cfg
        if subconfig_key in update_cfg.model and subconfig_key not in model_cfg.model:
            model_cfg.model[subconfig_key] = update_cfg.model[subconfig_key]

        # If update config does not the key, but model cfg has the key
        # Remove the model cfg subconfig in order to match layout of update cfg
        if subconfig_key not in update_cfg.model and subconfig_key in model_cfg.model and drop_missing_subconfigs:
            model_cfg.model.pop(subconfig_key)

    return model_cfg
Example #8
def update_model_config(model_cls: MRIDCConfig,
                        update_cfg: "DictConfig",
                        drop_missing_subconfigs: bool = True):
    """
    Helper class that updates the default values of a ModelPT config class with the values in a DictConfig that
    mirrors the structure of the config class. Assumes the `update_cfg` is a DictConfig (either generated manually,
    via hydra or instantiated via yaml/model.cfg). This update_cfg is then used to override the default values
    preset inside the ModelPT config class. If `drop_missing_subconfigs` is set, the certain sub-configs of the
    ModelPT config class will be removed, if they are not found in the mirrored `update_cfg`. The following
    sub-configs are subject to potential removal:
        -   `train_ds`
        -   `validation_ds`
        -   `test_ds`
        -   `optim` + nested sched

    Parameters
    ----------
    model_cls: A subclass of MRIDC, that details in entirety all the parameters that constitute the MRIDC Model.
    update_cfg: A DictConfig that mirrors the structure of the MRIDCConfig data class. Used to update the default
        values of the config class.
    drop_missing_subconfigs: Bool which determines whether to drop certain sub-configs from the MRIDCConfig class,
        if the corresponding sub-config is missing from `update_cfg`.

    Returns
    -------
    A DictConfig with updated values that can be used to instantiate the MRIDC Model.

    Raises
    ------
    ModuleNotFoundError: If Hydra/OmegaConf is not installed.
    ValueError: If `model_cls` is neither a dataclass nor a DictConfig.
    """
    if not _HAS_HYDRA:
        raise ModuleNotFoundError(
            "This function requires Hydra/Omegaconf and it was not installed.")
    if not (is_dataclass(model_cls) or isinstance(model_cls, DictConfig)):
        raise ValueError(
            "`model_cfg` must be a dataclass or a structured OmegaConf object")

    if not isinstance(update_cfg, DictConfig):
        update_cfg = OmegaConf.create(update_cfg)

    if is_dataclass(model_cls):
        model_cls = OmegaConf.structured(model_cls)

    # Update optional configs
    for subconfig_key in ("train_ds", "validation_ds", "test_ds", "optim"):
        model_cls = _update_subconfig(
            model_cls,
            update_cfg,
            subconfig_key=subconfig_key,
            drop_missing_subconfigs=drop_missing_subconfigs)

    # Add optim and sched additional keys to model cls
    model_cls = _add_subconfig_keys(model_cls,
                                    update_cfg,
                                    subconfig_key="optim")

    # Perform full merge of model config class and update config
    # Remove ModelPT artifacts `target` and `mridc_version`: when the config class does not declare them they are
    # dropped from the update config before the merge.
    for artifact in ("target", "mridc_version"):
        if artifact in update_cfg.model and artifact not in model_cls.model:  # type: ignore
            with open_dict(update_cfg.model):
                update_cfg.model.pop(artifact)

    return OmegaConf.merge(model_cls, update_cfg)
Example #9
def assert_dataclass_signature_match(
    cls: "class_type",  # type: ignore
    datacls: "dataclass",  # type: ignore
    ignore_args: Optional[List[str]] = None,
    remap_args: Optional[Dict[str, str]] = None,
):
    """
    Analyses the signature of a provided class and its respective data class,
    asserting that the dataclass signature matches the class __init__ signature.

    Note:
        This is not a value based check. This function only checks if all argument
        names exist on both class and dataclass and logs mismatches.

    Parameters
    ----------
    cls: Any class type - but not an instance of a class. Pass type(x) where x is an instance
        if class type is not easily available.
    datacls: A corresponding dataclass for the above class.
    ignore_args: (Optional) A list of string argument names which are forcibly ignored,
        even if mismatched in the signature. Useful when a dataclass is a superset of the
        arguments of a class.
    remap_args: (Optional) A dictionary, mapping an argument name that exists (in either the
        class or its dataclass), to another name. Useful when argument names are mismatched between
        a class and its dataclass due to indirect instantiation via a helper method.

    Returns
    -------
    A tuple containing information about the analysis:
        1) A bool value which is True if the signatures matched exactly / after ignoring values.
            False otherwise.
        2) A set of arguments names that exist in the class, but *do not* exist in the dataclass.
            If exact signature match occurs, this will be None instead.
        3) A set of argument names that exist in the data class, but *do not* exist in the class itself.
            If exact signature match occurs, this will be None instead.
    """
    class_params: Set[str] = set(inspect.signature(cls.__init__).parameters)
    # `cls.__init__` is inspected unbound, so its signature starts with `self`; the dataclass constructor
    # signature has no such entry and it must not count as a mismatch.
    class_params.discard("self")

    dataclass_params: Set[str] = set(inspect.signature(datacls).parameters)
    dataclass_params.discard("_target_")

    if remap_args is not None:
        for original_arg, new_arg in remap_args.items():
            if original_arg in class_params:
                class_params.remove(original_arg)
                class_params.add(new_arg)
                logging.info(
                    f"Remapped {original_arg} -> {new_arg} in {cls.__name__}")

            if original_arg in dataclass_params:
                dataclass_params.remove(original_arg)
                dataclass_params.add(new_arg)
                logging.info(
                    f"Remapped {original_arg} -> {new_arg} in {datacls.__name__}"
                )

    if ignore_args is not None:
        ignored = set(ignore_args)

        class_params -= ignored
        dataclass_params -= ignored
        logging.info(f"Removing ignored arguments - {ignored}")

    subset_cls = class_params - dataclass_params
    subset_datacls = dataclass_params - class_params

    if subset_cls or subset_datacls:
        logging.error(f"Class {cls.__name__} arguments do not match "
                      f"Dataclass {datacls.__name__}!")

        if subset_cls:
            logging.error(f"Class {cls.__name__} has additional arguments :\n"
                          f"{subset_cls}")

        if subset_datacls:
            logging.error(
                f"Dataclass {datacls.__name__} has additional arguments :\n{subset_datacls}"
            )

        return False, subset_cls, subset_datacls
    return True, None, None
Example #10
    @classmethod
    def from_config_dict(cls, config: "DictConfig", trainer: Optional[Trainer] = None):
        """
        Instantiates object using DictConfig-based configuration.

        When Hydra is available and the config carries a Hydra target (`cls`/`target` together with `params`, or
        `_target_`), the object is built by `hydra.utils.instantiate`. Otherwise the class named by the config's
        `target` entry is imported and, if `cls` is a subclass of it, `cls` is instantiated. If that yields no
        instance, `cls` is instantiated directly.

        Parameters
        ----------
        config: The configuration used to build the object.
        trainer: Optional lightning trainer, forwarded to constructors that accept one.

        Returns
        -------
        The instantiated object. Its `_cfg` attribute is set to the config when the object does not already
        have one.
        """
        # Resolve the config dict
        if _HAS_HYDRA:
            if isinstance(config, DictConfig):
                config = OmegaConf.to_container(config, resolve=True)
                config = OmegaConf.create(config)
                OmegaConf.set_struct(config, True)

            config = mridc.utils.model_utils.maybe_update_config_version(config)  # type: ignore

        # Hydra 0.x API
        if ("cls" in config or "target" in config) and "params" in config and _HAS_HYDRA:
            # regular hydra-based instantiation
            instance = hydra.utils.instantiate(config=config)
        # Hydra 1.x API
        elif "_target_" in config and _HAS_HYDRA:
            # regular hydra-based instantiation
            instance = hydra.utils.instantiate(config=config)
        else:
            instance = None
            prev_error = ""

            # Attempt class path resolution from config `target` class (if it exists)
            if "target" in config:
                target_cls = config["target"]  # No guarantee that this is a omegaconf class
                imported_cls = None
                try:
                    # try to import the target class
                    imported_cls = mridc.utils.model_utils.import_class_by_path(target_cls)  # type: ignore

                    # if calling class (cls) is subclass of imported class,
                    # use subclass instead
                    if issubclass(cls, imported_cls):
                        imported_cls = cls
                        if Serialization._inspect_signature_for_trainer(imported_cls):
                            if trainer is None:
                                # Create a dummy PL trainer object
                                # NOTE(review): the config object is passed positionally to Trainer - confirm
                                # this is the intended way to build it.
                                cfg_trainer = TrainerConfig(
                                    gpus=1, accelerator="ddp", num_nodes=1, logger=False, checkpoint_callback=False
                                )
                                trainer = Trainer(cfg_trainer)
                            instance = imported_cls(cfg=config, trainer=trainer)  # type: ignore
                        else:
                            instance = imported_cls(cfg=config)  # type: ignore

                except Exception as e:
                    # Keep the failure so it can be reported if the fallback below fails as well.
                    tb = traceback.format_exc()
                    prev_error = f"Model instantiation failed.\nTarget class: {target_cls}\nError: {e}\n{tb}"

                    logging.debug(prev_error + "\n falling back to 'cls'.")
            # target class resolution was unsuccessful, fall back to current `cls`
            if instance is None:
                try:
                    if Serialization._inspect_signature_for_trainer(cls):
                        instance = cls(cfg=config, trainer=trainer)  # type: ignore
                    else:
                        instance = cls(cfg=config)  # type: ignore
                except Exception:
                    # report saved errors, if any, and raise the current error
                    if prev_error:
                        logging.error(prev_error)
                    raise

        if not hasattr(instance, "_cfg"):
            instance._cfg = config
        return instance
Example #11
def resolve_test_dataloaders(model: "ModelPT"):
    """
    Helper method that operates on the ModelPT class to automatically support
    multiple dataloaders for the test set.
    It does so by first resolving the path to one/more data files via `resolve_dataset_name_from_cfg()`.
    If this resolution fails, it assumes the data loader is prepared to manually support / not support
    multiple data loaders and simply calls the appropriate setup method.
    If resolution succeeds:
        Checks if provided path is to a single file or a list of files.
        If a single file is provided, simply tags that file as such and loads it via the setup method.
        If multiple files are provided:
            Inject a new manifest path at index "i" into the resolved key.
            Calls the appropriate setup method to set the data loader.
            Collects the initialized data loader in a list and preserves it.
            Once all data loaders are processed, assigns the list of loaded loaders to the ModelPT.
            Finally, assigns a list of unique names resolved from the file paths to the ModelPT.

    Parameters
    ----------
    model: ModelPT subclass, which requires >=1 Test Dataloaders to be setup.

    Raises
    ------
    ModuleNotFoundError: If Hydra/OmegaConf is not installed.
    """
    if not _HAS_HYDRA:
        raise ModuleNotFoundError(
            "This function requires Hydra/OmegaConf and it was not installed.")
    cfg = copy.deepcopy(model._cfg)
    dataloaders: List[Any] = []

    # process test_dl_idx: it is removed from the working copy of the config and defaults to 0 when absent
    if "test_dl_idx" in cfg.test_ds:
        cfg = OmegaConf.to_container(cfg)
        test_dl_idx = cfg["test_ds"].pop("test_dl_idx")
        cfg = OmegaConf.create(cfg)
    else:
        test_dl_idx = 0

    # Set test_dl_idx on the model
    model._test_dl_idx = test_dl_idx

    ds_key = resolve_dataset_name_from_cfg(cfg.test_ds)

    # NOTE(review): `setup_test_data` is assumed to be the model's test-set setup method - confirm against
    # the ModelPT class.
    if ds_key is None:
        logging.debug(
            f"Could not resolve file path from provided config - {cfg.test_ds}. "
            "Disabling support for multi-dataloaders.")

        model.setup_test_data(cfg.test_ds)
        return

    ds_values = cfg.test_ds[ds_key]

    if isinstance(ds_values, (list, tuple, ListConfig)):

        for ds_value in ds_values:
            # Point the config at one dataset at a time and collect the loader set up for it.
            cfg.test_ds[ds_key] = ds_value
            model.setup_test_data(cfg.test_ds)
            dataloaders.append(model.test_dl)

        model.test_dl = dataloaders  # type: ignore
        model.test_names = [parse_dataset_as_name(ds)
                            for ds in ds_values]  # type: ignore
        return

    model.setup_test_data(cfg.test_ds)
    model.test_names = [parse_dataset_as_name(ds_values)]

Example #12
def configure_checkpointing(trainer: Trainer, log_dir: Path, name: str,
                            resume: bool, params: "DictConfig"):
    """
    Adds ModelCheckpoint to trainer. Raises CheckpointMisconfigurationError if trainer already has a ModelCheckpoint
    callback or if trainer.weights_save_path was passed to Trainer.

    Parameters
    ----------
    trainer: The lightning trainer the checkpoint callback is attached to.
    log_dir: The logging directory; checkpoints default to its "checkpoints" sub-directory.
    name: The experiment name, used for the default checkpoint file name and prefix.
    resume: Forwarded to the checkpoint callback as `n_resume`.
    params: The checkpoint callback parameters. Unset `dirpath`, `filename` and `prefix` entries are filled in
        and a deprecated `filepath` entry is removed.

    Raises
    ------
    CheckpointMisconfigurationError: If the trainer already has a ModelCheckpoint callback, or if
        trainer.weights_save_path differs from the current working directory.
    """
    for callback in trainer.callbacks:
        if isinstance(callback, ModelCheckpoint):
            raise CheckpointMisconfigurationError(
                "The pytorch lightning trainer that was passed to exp_manager contained a ModelCheckpoint "
                "and create_checkpoint_callback was set to True. Please either set create_checkpoint_callback "
                "to False, or remove ModelCheckpoint from the lightning trainer"
            )
    if Path(trainer.weights_save_path) != Path.cwd():
        raise CheckpointMisconfigurationError(
            "The pytorch lightning was passed weights_save_path. This variable is ignored by exp_manager"
        )

    # Create the callback and attach it to trainer
    if "filepath" in params:
        if params.filepath is not None:
            logging.warning(
                "filepath is deprecated. Please switch to dirpath and filename instead"
            )
            # Split the deprecated filepath into the dirpath / filename it stands for, unless those are given.
            if params.dirpath is None:
                params.dirpath = Path(params.filepath).parent
            if params.filename is None:
                params.filename = Path(params.filepath).name
        with open_dict(params):
            del params["filepath"]
    if params.dirpath is None:
        params.dirpath = Path(log_dir / "checkpoints")
    if params.filename is None:
        params.filename = f"{name}--{{{params.monitor}:.4f}}-{{epoch}}"
    if params.prefix is None:
        params.prefix = name
    MRIDCModelCheckpoint.CHECKPOINT_NAME_LAST = f"{params.filename}-last"

    if "val" in params.monitor:
        if (trainer.max_epochs is not None and trainer.max_epochs != -1
                and trainer.max_epochs < trainer.check_val_every_n_epoch):
            logging.error(
                "The checkpoint callback was told to monitor a validation value but trainer.max_epochs("
                f"{trainer.max_epochs}) was less than trainer.check_val_every_n_epoch("
                f"{trainer.check_val_every_n_epoch}). It is very likely this run will fail with "
                f"ModelCheckpoint(monitor='{params.monitor}') not found in the returned metrics. Please ensure that "
                "validation is run within trainer.max_epochs.")
        elif trainer.max_steps is not None:
            logging.warning(
                "The checkpoint callback was told to monitor a validation value and trainer's max_steps was set to "
                f"{trainer.max_steps}. Please ensure that max_steps will run for at least "
                f"{trainer.check_val_every_n_epoch} epochs to ensure that checkpointing will not error out."
            )

    checkpoint_callback = MRIDCModelCheckpoint(n_resume=resume, **params)
    checkpoint_callback.last_model_path = trainer._checkpoint_connector.resume_from_checkpoint_fit_path or ""
    if "mp_rank" in checkpoint_callback.last_model_path or "tp_rank" in checkpoint_callback.last_model_path:
        checkpoint_callback.last_model_path = mridc.utils.model_utils.uninject_model_parallel_rank(  # type: ignore
            checkpoint_callback.last_model_path)
    trainer.callbacks.append(checkpoint_callback)
Example #13
def exp_manager(
        trainer: Trainer,
        cfg: Optional[Union[DictConfig, Dict]] = None) -> Optional[Path]:
    """
    Prepare the logging directory, loggers and checkpointing for an experiment.

    exp_manager is a helper function used to manage folders for experiments. It follows the pytorch lightning
    paradigm of exp_dir/model_or_experiment_name/version. If the lightning trainer has a logger, exp_manager will
    get exp_dir, name, and version from the logger. Otherwise, it will use the exp_dir and name arguments to create
    the logging directory. exp_manager also allows for explicit folder creation via explicit_log_dir.

    The version can be a datetime string or an integer. Datetime version can be disabled if use_datetime_version
    is set to False. It optionally creates TensorBoardLogger, WandBLogger, ModelCheckpoint objects from pytorch
    lightning. It copies sys.argv, and git information if available to the logging directory. It creates a log file
    for each process to log their output into.

    exp_manager additionally has a resume feature (resume_if_exists) which can be used to continue training from
    the constructed log_dir. When you need to continue the training repeatedly (like on a cluster where you need
    multiple consecutive jobs), you need to avoid creating the version folders. Therefore, from v1.0.0, when
    resume_if_exists is set to True, creating the version folders is ignored.

    Parameters
    ----------
    trainer: The lightning trainer object.
    cfg: A dict or DictConfig that is validated against ExpManagerConfig. Can have the following keys:
        - explicit_log_dir: Can be used to override exp_dir/name/version folder creation. Defaults to None, which
          will use exp_dir, name, and version to construct the logging directory.
        - exp_dir: The base directory to create the logging directory. Defaults to None, in which case the
          location is chosen by get_log_dir.
        - name: The name of the experiment. Defaults to None which turns into "default" via name = name or "default".
        - version: The version of the experiment. Defaults to None which uses either a datetime string or lightning's
          TensorboardLogger system of using version_{int}.
        - use_datetime_version: Whether to use a datetime string for version. Defaults to True.
        - resume_if_exists: Whether this experiment is resuming from a previous run. If True, it sets
          trainer._checkpoint_connector.resume_from_checkpoint_fit_path so that the trainer should auto-resume.
          exp_manager will move files under log_dir to log_dir/run_{int}. Defaults to False. From v1.0.0, when
          resume_if_exists is True, we would not create version folders to make it easier to find the log folder for
          next runs.
        - resume_past_end: exp_manager errors out if resume_if_exists is True and a checkpoint matching *end.ckpt
          exists, indicating a previous training run fully completed. This behaviour can be disabled, in which case
          the *end.ckpt will be loaded, by setting resume_past_end to True. Defaults to False.
        - resume_ignore_no_checkpoint: exp_manager errors out if resume_if_exists is True and no checkpoint could be
          found. This behaviour can be disabled, in which case exp_manager will print a message and continue without
          restoring, by setting resume_ignore_no_checkpoint to True. Defaults to False.
        - create_tensorboard_logger: Whether to create a tensorboard logger and attach it to the pytorch lightning
          trainer. Defaults to True.
        - summary_writer_kwargs: A dictionary of kwargs that can be passed to lightning's TensorboardLogger class.
          Note that log_dir is passed by exp_manager and cannot exist in this dict. Defaults to None.
        - create_wandb_logger: Whether to create a Weights and Biases logger and attach it to the pytorch lightning
          trainer. Defaults to False.
        - wandb_logger_kwargs: A dictionary of kwargs that can be passed to lightning's WandBLogger class. Note that
          name and project are required parameters if create_wandb_logger is True. Defaults to None.
        - create_checkpoint_callback: Whether to create a ModelCheckpoint callback and attach it to the pytorch
          lightning trainer. The ModelCheckpoint saves the top 3 models with the best "val_loss", the most recent
          checkpoint under *last.ckpt, and the final checkpoint after training completes under *end.ckpt.
          Defaults to True.
        - checkpoint_callback_params: Parameters forwarded to configure_checkpointing when
          create_checkpoint_callback is set.
        - log_step_timing / step_timing_kwargs: When log_step_timing is set, a TimingCallback built with
          step_timing_kwargs is inserted as the first trainer callback.
        - files_to_copy: A list of files to copy to the experiment logging directory. Defaults to None which copies
          no files.
        - log_local_rank_0_only: Whether to only create log files for local rank 0. Defaults to False. Set this to
          True if you are using DDP with many GPUs and do not want many log files in your exp dir.
        - log_global_rank_0_only: Whether to only create log files for global rank 0. Defaults to False. Set this to
          True if you are using DDP with many GPUs and do not want many log files in your exp dir.

    Returns
    -------
    The final logging directory where logging files are saved. Usually the concatenation of exp_dir, name, and
    version. None is returned when cfg is None or when the trainer runs with fast_dev_run.

    Raises
    ------
    ValueError
        If cfg is neither a dict nor a DictConfig, or if both log_local_rank_0_only and log_global_rank_0_only
        are True.
    """
    # Add rank information to logger
    # Note: trainer.global_rank and trainer.is_global_zero are not set until trainer.fit, so have to hack around it
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    global_rank = trainer.node_rank * trainer.num_devices + local_rank
    logging.rank = global_rank

    if cfg is None:
        logging.error(
            "exp_manager did not receive a cfg argument. It will be disabled.")
        return None

    if trainer.fast_dev_run:
        logging.info(
            "Trainer was called with fast_dev_run. exp_manager will return without any functionality."
        )
        return None

    # Ensure passed cfg is compliant with ExpManagerConfig
    schema = OmegaConf.structured(ExpManagerConfig)
    if isinstance(cfg, dict):
        cfg = OmegaConf.create(cfg)
    elif not isinstance(cfg, DictConfig):
        raise ValueError(
            f"cfg was type: {type(cfg)}. Expected either a dict or a DictConfig"
        )
    # Resolve interpolations first so the merged config holds plain values.
    cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=True))
    cfg = OmegaConf.merge(schema, cfg)

    # Ensures that trainer options are compliant with MRIDC and exp_manager arguments
    error_checks(trainer, cfg)

    # NOTE(review): the argument list of this call was missing from the file; it is rebuilt from the cfg keys
    # documented above and passed by keyword - confirm against get_log_dir's signature.
    log_dir, exp_dir, name, version = get_log_dir(
        trainer=trainer,
        exp_dir=cfg.exp_dir,
        name=cfg.name,
        version=cfg.version,
        explicit_log_dir=cfg.explicit_log_dir,
        use_datetime_version=cfg.use_datetime_version,
        resume_if_exists=cfg.resume_if_exists,
    )

    if cfg.resume_if_exists:
        # NOTE(review): the trailing argument was missing; resume_ignore_no_checkpoint is the remaining
        # documented resume option - confirm against check_resume's signature.
        check_resume(trainer, str(log_dir), cfg.resume_past_end,
                     cfg.resume_ignore_no_checkpoint)

    checkpoint_name = name
    # If name returned from get_log_dir is "", use cfg.name for checkpointing
    if not checkpoint_name:
        checkpoint_name = cfg.name or "default"
    cfg.name = name  # Used for configure_loggers so that the log_dir is properly set even if name is ""
    cfg.version = version

    # update app_state with log_dir, exp_dir, etc
    app_state = AppState()
    app_state.log_dir = log_dir
    app_state.exp_dir = exp_dir
    app_state.name = name
    app_state.version = version
    app_state.checkpoint_name = checkpoint_name
    app_state.create_checkpoint_callback = cfg.create_checkpoint_callback
    app_state.checkpoint_callback_params = cfg.checkpoint_callback_params

    # Create the logging directory if it does not exist.
    # Cannot limit creation to global zero as all ranks write to own log file
    os.makedirs(log_dir, exist_ok=True)
    logging.info(f"Experiments will be logged at {log_dir}")
    trainer._default_root_dir = log_dir

    if cfg.log_local_rank_0_only is True and cfg.log_global_rank_0_only is True:
        raise ValueError(
            "Cannot set both log_local_rank_0_only and log_global_rank_0_only to True. "
            "Please set either one or neither.")

    # This is set if the env var MRIDC_TESTING is set to True.
    mridc_testing = get_envbool(MRIDC_ENV_VARNAME_TESTING, False)

    log_file = log_dir / f"mridc_log_globalrank-{global_rank}_localrank-{local_rank}.txt"

    # Handle logging to file. The rank-0-only flags restrict which processes get a file handler; they are
    # bypassed in testing mode, and without either flag every rank logs to its own file.
    # NOTE(review): the previous global-rank condition tested `mridc_testing` without negation, unlike the
    # local-rank one, and there was no fallback branch, so the default configuration attached no log file at
    # all. Both are treated as defects here - confirm.
    if mridc_testing:
        attach_file_handler = True
    elif cfg.log_local_rank_0_only:
        attach_file_handler = local_rank == 0
    elif cfg.log_global_rank_0_only:
        attach_file_handler = global_rank == 0
    else:
        attach_file_handler = True
    if attach_file_handler:
        logging.add_file_handler(log_file)

    # For some reason, LearningRateLogger requires trainer to have a logger. Safer to create logger on all ranks
    # not just global rank 0.
    if cfg.create_tensorboard_logger or cfg.create_wandb_logger:
        # NOTE(review): this call was missing from the file; the positional order follows the upstream
        # exp_manager implementation - confirm against configure_loggers' signature.
        configure_loggers(
            trainer,
            exp_dir,
            cfg.name,
            cfg.version,
            cfg.create_tensorboard_logger,
            cfg.summary_writer_kwargs,
            cfg.create_wandb_logger,
            cfg.wandb_logger_kwargs,
        )

    # add loggers timing callbacks
    if cfg.log_step_timing:
        timing_callback = TimingCallback(
            timer_kwargs=cfg.step_timing_kwargs or {})
        trainer.callbacks.insert(0, timing_callback)

    if cfg.create_checkpoint_callback:
        configure_checkpointing(trainer, log_dir, checkpoint_name,
                                cfg.resume_if_exists,
                                cfg.checkpoint_callback_params)

    if is_global_rank_zero():
        # Move files_to_copy to folder and add git information if present
        if cfg.files_to_copy:
            for src_file in cfg.files_to_copy:
                copy(Path(src_file), log_dir)

        # Create files for cmd args and git info
        with open(log_dir / "cmd-args.log", "w", encoding="utf-8") as args_log:
            args_log.write(" ".join(sys.argv))

        # Try to get git hash
        git_repo, git_hash = get_git_hash()
        if git_repo:
            with open(log_dir / "git-info.log", "w",
                      encoding="utf-8") as git_log:
                git_log.write(f"commit hash: {git_hash}")

        # Add err_file logging to global_rank zero
        logging.add_err_file_handler(log_dir / "mridc_error_log.txt")

        # Add lightning file logging to global_rank zero
        add_filehandlers_to_pl_logger(log_dir / "lightning_logs.txt",
                                      log_dir / "mridc_error_log.txt")

    return log_dir
Example #14
    def export(
        # Exports this module to the file at `output`; the target format is derived from the
        # output path via get_export_format, and TorchScript and ONNX branches exist below.
        # NOTE(review): the parameter list looks truncated. The body and the parameter notes
        # below also use `self`, input_example, verbose, export_params, do_constant_folding,
        # onnx_opset_version, training, dynamic_axes and check_tolerance, and the closing `):`
        # and the docstring delimiters are absent from this view. Confirm against the full file.
        output: str,
        try_script: bool = False,
        check_trace: bool = False,
        use_dynamic_axes: bool = True,
        Export the module to a file.

        output: The output file path.
        input_example: A dictionary of input names and values.
        verbose: If True, print out the export process.
        export_params: If True, export the parameters of the module.
        do_constant_folding: If True, do constant folding.
        onnx_opset_version: The ONNX opset version to use.
        try_script: If True, try to export as TorchScript.
        training: Training mode for the export.
        check_trace: If True, check the trace of the exported model.
        use_dynamic_axes: If True, use dynamic axes for the export.
        dynamic_axes: A dictionary of input names and dynamic axes.
        check_tolerance: The tolerance for the check_trace.
        # Snapshot the call arguments; they are forwarded to nested Exportables further down.
        my_args = locals().copy()

        # Walk the submodules looking for Exportable instances.
        # NOTE(review): the body of the isinstance check is missing here - presumably it
        # collects `m` into `exportables`; confirm.
        exportables = []
        for m in self.modules():  # type: ignore
            if isinstance(m, Exportable):

        # Human-readable description returned to the caller alongside the output path.
        # NOTE: `format` shadows the builtin of the same name inside this method.
        qual_name = self.__module__ + "." + self.__class__.__qualname__
        format = get_export_format(output)
        output_descr = f"{qual_name} exported to {format}"

        # Pytorch's default for None is too low, can't pass None through
        if onnx_opset_version is None:
            onnx_opset_version = 13

        # NOTE(review): the deeper indentation from here on, together with the dedented cleanup
        # near the end, suggests an enclosing try/finally that is not visible - confirm.
            # Disable typechecks

            # Allow user to completely override forward method to export
            forward_method, old_forward_method = wrap_forward_method(self)

            # Set module mode
            with torch.onnx.select_model_mode_for_export(
                    self, training), torch.inference_mode(
                    ), torch.jit.optimized_execution(True):

                # Fall back to the input module's own example when the caller supplied none.
                if input_example is None:
                    input_example = self.input_module.input_example()

                # Remove i/o examples from args we propagate to enclosed Exportables
                # NOTE(review): no statement follows the comment above in this view - confirm.

                # Run (possibly overridden) prepare methods before calling forward()
                for ex in exportables:
                    ex._prepare_for_export(**my_args, noreplace=True)

                input_list, input_dict = parse_input_example(input_example)
                input_names = self.input_names
                output_names = self.output_names
                # Run one forward pass on the example inputs; the result is not referenced again
                # in the visible code.
                output_example = tuple(self.forward(
                    *input_list, **input_dict))  # type: ignore

                # Optionally attempt scripting first; a failure is logged and jitted_model
                # keeps its None value.
                # NOTE(review): the `try:` matching the `except` below is not visible.
                jitted_model = None
                if try_script:
                        jitted_model = torch.jit.script(self)
                    except Exception as e:
                        logging.error(f"jit.script() failed!\n{e}")

                if format == ExportFormat.TORCHSCRIPT:
                    # Trace the module when no scripted model is available.
                    # NOTE(review): the trace_module and optimize_for_inference calls below are
                    # truncated (arguments and closing parentheses missing) - confirm.
                    if jitted_model is None:
                        jitted_model = torch.jit.trace_module(
                                tuple(input_list) + tuple(input_dict.values())
                    if not self.training:  # type: ignore
                        jitted_model = torch.jit.optimize_for_inference(
                    if verbose:
                        logging.info(f"JIT code:\n{jitted_model.code}")
                elif format == ExportFormat.ONNX:
                    # Use the module itself when no scripted model is available.
                    if jitted_model is None:
                        jitted_model = self

                    # dynamic axis is a mapping from input/output_name => list of "dynamic" indices
                    if dynamic_axes is None and use_dynamic_axes:
                        dynamic_axes = get_dynamic_axes(
                            self.input_module.input_types, input_names)

                    # NOTE(review): no export call is visible in this branch before the optional
                    # verification below, and the verify_runtime call is truncated - confirm.

                    if check_trace:
                        verify_runtime(output, input_list, input_dict,
                                       input_names, output_names,

                    # NOTE(review): presumably the fallback branch of the format dispatch; the
                    # `else:` line is not visible.
                    raise ValueError(
                        f"Encountered unknown export format {format}.")
            # Put the original forward back when wrap_forward_method returned a truthy
            # forward_method.
            if forward_method:
                type(self).forward = old_forward_method  # type: ignore
        # Single-element lists: the output path and its description.
        return [output], [output_descr]