Code example #1
def _convert_config(cfg: "OmegaConf"):
    """Recursive function converting the configuration from old hydra format to the new one."""
    if not _HAS_HYDRA:
        logging.error(
            "This function requires Hydra/OmegaConf and it was not installed.")
        sys.exit(1)

    # Get rid of cls -> _target_.
    if "cls" in cfg and "_target_" not in cfg:
        cfg._target_ = cfg.pop("cls")  # type: ignore

    # Get rid of params.
    if "params" in cfg:
        params = cfg.pop("params")  # type: ignore
        for param_key, param_val in params.items():
            cfg[param_key] = param_val

    # Recursion.
    try:
        for _, sub_cfg in cfg.items():  # type: ignore
            if isinstance(sub_cfg, DictConfig):
                _convert_config(sub_cfg)  # type: ignore
    except OmegaConfBaseException as e:
        logging.warning(
            f"Skipped conversion for config/subconfig:\n{cfg}\n Reason: {e}.")
Code example #2
def error_checks(trainer: Trainer,
                 cfg: Optional[Union[DictConfig, Dict]] = None):
    """
    Checks that the passed trainer is compliant with MRIDC and exp_manager's passed configuration. Checks that it:
        - Raises an error when hydra has changed the working directory. This causes issues with lightning's DDP.
        - Raises an error when the trainer has loggers defined but create_tensorboard_logger or create_wandb_logger is True.
        - Logs error messages when 1) run on multi-node without SLURM, and 2) run on multi-gpu without DDP.
    """
    if HydraConfig.initialized() and get_original_cwd() != os.getcwd():
        raise ValueError(
            "Hydra changed the working directory. This interferes with ExpManger's functionality. Please pass "
            "hydra.run.dir=. to your python script.")

    if trainer.logger is not None and (
            cfg.create_tensorboard_logger
            or cfg.create_wandb_logger):  # type: ignore
        raise LoggerMisconfigurationError(
            "The pytorch lightning trainer that was passed to exp_manager contained a logger, and either "
            "create_tensorboard_logger or create_wandb_logger was set to True. These can only be used if trainer does "
            "not already have a logger.")

    if trainer.num_nodes > 1 and not check_slurm(trainer):  # type: ignore
        logging.error(
            "You are running multi-node training without SLURM handling the processes."
            " Please note that this is not tested in MRIDC and could result in errors."
        )

    if trainer.num_devices > 1 and not isinstance(trainer.strategy,
                                                  DDPStrategy):  # type: ignore
        logging.error(
            "You are running multi-gpu without ddp.Please note that this is not tested in MRIDC and could result in "
            "errors.")
Code example #3
def convert_model_config_to_dict_config(
        cfg: Union[DictConfig, MRIDCConfig]) -> DictConfig:
    """
    Converts its input into a standard DictConfig.

    Possible input values are:
        - DictConfig
        - A dataclass which is a subclass of MRIDCConfig

    Parameters
    ----------
    cfg: A dict-like object.

    Returns
    -------
    The equivalent DictConfig.
    """
    if not _HAS_HYDRA:
        logging.error(
            "This function requires Hydra/OmegaConf and it was not installed.")
        sys.exit(1)
    if not isinstance(cfg, (OmegaConf, DictConfig)) and is_dataclass(cfg):
        cfg = OmegaConf.structured(cfg)

    if not isinstance(cfg, DictConfig):
        raise ValueError(
            f"cfg constructor argument must be of type DictConfig/dict but got {type(cfg)} instead."
        )

    config = OmegaConf.to_container(cfg, resolve=True)
    config = OmegaConf.create(config)
    return config
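
A short usage sketch (only omegaconf is required; the dataclass is a hypothetical stand-in for an MRIDCConfig subclass) showing the dataclass-to-DictConfig path this function implements:

from dataclasses import dataclass

from omegaconf import OmegaConf


@dataclass
class SchedulerConfig:  # hypothetical stand-in for an MRIDCConfig subclass
    name: str = "CosineAnnealing"
    warmup_steps: int = 1000


# The same two steps as above: structure the dataclass, then resolve and rebuild it.
cfg = OmegaConf.structured(SchedulerConfig())
cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=True))

print(type(cfg).__name__)   # DictConfig
print(cfg.warmup_steps)     # 1000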
Code example #4
def _add_subconfig_keys(model_cfg: "DictConfig", update_cfg: "DictConfig",
                        subconfig_key: str):
    """
    For certain sub-configs, the default values specified by the MRIDCConfig class are insufficient.
    In order to support every potential value in the merge with the `update_cfg`, it would require explicit
    definition of all possible cases.
    An example of such a case is Optimizers, and their equivalent Schedulers. All optimizers share a few basic details
    - such as name and lr, but almost all require additional parameters - such as weight decay.
    It is impractical to create a config for every single optimizer + every single scheduler combination.
    In such a case, we perform a dual merge. The Optim and Sched Dataclass contain the bare minimum essential
    components. The extra values are provided via update_cfg.
    In order to enable the merge, we first need to update the update sub-config to incorporate the keys, with dummy
    temporary values (merge update config with model config). This is done on a copy of the update sub-config, as the
    actual override values might be overridden by the MRIDCConfig defaults.
    Then we perform a merge of this temporary sub-config with the actual override config in a later step (merge
    model_cfg with original update_cfg, done outside this function).

    Parameters
    ----------
    model_cfg: A DictConfig instantiated from the MRIDCConfig subclass.
    update_cfg: A DictConfig that mirrors the structure of `model_cfg`, used to update its default values.
    subconfig_key: A str key used to check and update the sub-config.

    Returns
    -------
    A ModelPT DictConfig with additional keys added to the sub-config.
    """
    if not _HAS_HYDRA:
        logging.error(
            "This function requires Hydra/Omegaconf and it was not installed.")
        sys.exit(1)
    with open_dict(model_cfg.model):
        # Create copy of original model sub config
        if subconfig_key in update_cfg.model:
            if subconfig_key not in model_cfg.model:
                # create the key as a placeholder
                model_cfg.model[subconfig_key] = None

            subconfig = copy.deepcopy(model_cfg.model[subconfig_key])
            update_subconfig = copy.deepcopy(update_cfg.model[subconfig_key])

            # Add the keys and update temporary values, will be updated during full merge
            subconfig = OmegaConf.merge(update_subconfig, subconfig)
            # Update sub config
            model_cfg.model[subconfig_key] = subconfig

    return model_cfg
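
To make the dual-merge idea concrete, here is a self-contained omegaconf sketch (key names and values are illustrative) of what happens to an `optim` sub-config:

from omegaconf import OmegaConf

# Defaults, as a bare-minimum Optim dataclass would provide them.
default_optim = OmegaConf.create({"name": "adam", "lr": 1e-3})
# User override that carries an extra argument the defaults do not know about.
update_optim = OmegaConf.create({"name": "adamw", "lr": 3e-4, "weight_decay": 0.01})

# Step performed here: merge the defaults on top of the override so that every key exists,
# accepting temporarily wrong (default) values for the overlapping keys.
with_keys = OmegaConf.merge(update_optim, default_optim)
print(OmegaConf.to_yaml(with_keys))
# name: adam
# lr: 0.001
# weight_decay: 0.01

# Later step (outside this function): the full merge with the original update restores the overrides.
final = OmegaConf.merge(with_keys, update_optim)
print(OmegaConf.to_yaml(final))
# name: adamw
# lr: 0.0003
# weight_decay: 0.01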
Code example #5
def check_explicit_log_dir(trainer: Trainer,
                           explicit_log_dir: List[Union[Path, str]],
                           exp_dir: str, name: str,
                           version: str) -> Tuple[Path, str, str, str]:
    """
    Checks that the passed arguments are compatible with explicit_log_dir.

    Parameters
    ----------
    trainer: The trainer to check.
    explicit_log_dir: The explicit log dir to check.
    exp_dir: The experiment directory to check.
    name: The experiment name to check.
    version: The experiment version to check.

    Returns
    -------
    The log_dir, exp_dir, name, and version that should be used.

    Raises
    ------
    LoggerMisconfigurationError
    """
    if trainer.logger is not None:
        raise LoggerMisconfigurationError(
            "The pytorch lightning trainer that was passed to exp_manager contained a logger and explicit_log_dir: "
            f"{explicit_log_dir} was pass to exp_manager. Please remove the logger from the lightning trainer."
        )
    # Checking only (explicit_log_dir) vs (exp_dir and version).
    # The `name` will be used as the actual name of checkpoint/archive.
    if exp_dir or version:
        logging.error(
            f"exp_manager received explicit_log_dir: {explicit_log_dir} and at least one of exp_dir: {exp_dir}, "
            f"or version: {version}. Please note that exp_dir, name, and version will be ignored."
        )
    if is_global_rank_zero() and Path(str(explicit_log_dir)).exists():
        logging.warning(
            f"Exp_manager is logging to {explicit_log_dir}, but it already exists."
        )
    return Path(str(explicit_log_dir)), str(explicit_log_dir), "", ""
Code example #6
def maybe_update_config_version(cfg: "DictConfig"):
    """
    Recursively convert Hydra 0.x configs to Hydra 1.x configs.
    Changes include:
    -   `cls` -> `_target_`.
    -   `params` -> drop params and shift all arguments to parent.
    -   `target` -> `_target_` cannot be performed due to ModelPT injecting `target` inside class.

    Parameters
    ----------
    cfg: Any Hydra compatible DictConfig

    Returns
    -------
    An updated DictConfig that conforms to Hydra 1.x format.
    """
    if not _HAS_HYDRA:
        logging.error(
            "This function requires Hydra/OmegaConf and it was not installed.")
        sys.exit(1)
    if cfg is not None and not isinstance(cfg, DictConfig):
        try:
            temp_cfg = OmegaConf.create(cfg)
            cfg = temp_cfg
        except OmegaConfBaseException:
            # Cannot be cast to DictConfig, skip updating.
            return cfg

    # Make a copy of model config.
    cfg = copy.deepcopy(cfg)
    OmegaConf.set_struct(cfg, False)

    # Convert config.
    _convert_config(cfg)  # type: ignore

    # Update model config.
    OmegaConf.set_struct(cfg, True)

    return cfg
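
Assuming the mridc package is installed, a typical call looks like this (the import path follows the usage in examples #10 and #12 above):

from omegaconf import OmegaConf

from mridc.utils import model_utils  # module path as referenced in examples #10 and #12

old = OmegaConf.create(
    {"model": {"optim": {"cls": "torch.optim.Adam", "params": {"lr": 1e-3}}}}
)
new = model_utils.maybe_update_config_version(old)

# The nested node is rewritten on the returned copy:
# new.model.optim._target_ == "torch.optim.Adam" and new.model.optim.lr == 1e-3,
# while `cls` and `params` are gone.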
Code example #7
def _update_subconfig(model_cfg: "DictConfig", update_cfg: "DictConfig",
                      subconfig_key: str, drop_missing_subconfigs: bool):
    """
    Updates the MRIDCConfig DictConfig such that:
        1)  If the sub-config key exists in the `update_cfg`, but does not exist in ModelPT config:
            - Add the sub-config from update_cfg to ModelPT config
        2) If the sub-config key does not exist in `update_cfg`, but exists in ModelPT config:
            - Remove the sub-config from the ModelPT config, but only if the `drop_missing_subconfigs` flag is set.

    Parameters
    ----------
    model_cfg: A DictConfig instantiated from the MRIDCConfig subclass.
    update_cfg: A DictConfig that mirrors the structure of `model_cfg`, used to update its default values.
    subconfig_key: A str key used to check and update the sub-config.
    drop_missing_subconfigs: A bool flag, whether to allow deletion of the MRIDCConfig sub-config, if its mirror
    sub-config does not exist in the `update_cfg`.

    Returns
    -------
    The updated DictConfig for the MRIDCConfig
    """
    if not _HAS_HYDRA:
        logging.error(
            "This function requires Hydra/Omegaconf and it was not installed.")
        sys.exit(1)
    with open_dict(model_cfg.model):
        # If update config has the key, but model cfg doesn't have the key
        # Add the update cfg subconfig to the model cfg
        if subconfig_key in update_cfg.model and subconfig_key not in model_cfg.model:
            model_cfg.model[subconfig_key] = update_cfg.model[subconfig_key]

        # If update config does not have the key, but model cfg has the key
        # Remove the model cfg subconfig in order to match layout of update cfg
        if subconfig_key not in update_cfg.model and subconfig_key in model_cfg.model and drop_missing_subconfigs:
            model_cfg.model.pop(subconfig_key)

    return model_cfg
Code example #8
def update_model_config(model_cls: MRIDCConfig,
                        update_cfg: "DictConfig",
                        drop_missing_subconfigs: bool = True):
    """
    Helper function that updates the default values of a ModelPT config class with the values in a DictConfig that \
    mirrors the structure of the config class. Assumes the `update_cfg` is a DictConfig (either generated manually, \
    via hydra or instantiated via yaml/model.cfg). This update_cfg is then used to override the default values \
    preset inside the ModelPT config class. If `drop_missing_subconfigs` is set, certain sub-configs of the \
    ModelPT config class will be removed, if they are not found in the mirrored `update_cfg`. The following \
    sub-configs are subject to potential removal:
        -   `train_ds`
        -   `validation_ds`
        -   `test_ds`
        -   `optim` + nested sched

    Parameters
    ----------
    model_cls: A subclass of MRIDCConfig that details in entirety all the parameters that constitute the MRIDC Model.
    update_cfg: A DictConfig that mirrors the structure of the MRIDCConfig data class. Used to update the default \
    values of the config class.
    drop_missing_subconfigs: Bool which determines whether to drop certain sub-configs from the MRIDCConfig class, \
    if the corresponding sub-config is missing from `update_cfg`.

    Returns
    -------
    A DictConfig with updated values that can be used to instantiate the MRIDC Model along with supporting \
    infrastructure.
    """
    if not _HAS_HYDRA:
        logging.error(
            "This function requires Hydra/Omegaconf and it was not installed.")
        sys.exit(1)
    if not (is_dataclass(model_cls) or isinstance(model_cls, DictConfig)):
        raise ValueError(
            "`model_cfg` must be a dataclass or a structured OmegaConf object")

    if not isinstance(update_cfg, DictConfig):
        update_cfg = OmegaConf.create(update_cfg)

    if is_dataclass(model_cls):
        model_cls = OmegaConf.structured(model_cls)

    # Update optional configs
    model_cls = _update_subconfig(
        model_cls,
        update_cfg,
        subconfig_key="train_ds",
        drop_missing_subconfigs=drop_missing_subconfigs)
    model_cls = _update_subconfig(
        model_cls,
        update_cfg,
        subconfig_key="validation_ds",
        drop_missing_subconfigs=drop_missing_subconfigs)
    model_cls = _update_subconfig(
        model_cls,
        update_cfg,
        subconfig_key="test_ds",
        drop_missing_subconfigs=drop_missing_subconfigs)
    model_cls = _update_subconfig(
        model_cls,
        update_cfg,
        subconfig_key="optim",
        drop_missing_subconfigs=drop_missing_subconfigs)

    # Add optim and sched additional keys to model cls
    model_cls = _add_subconfig_keys(model_cls,
                                    update_cfg,
                                    subconfig_key="optim")

    # Perform full merge of model config class and update config
    # Remove ModelPT artifact `target`
    if "target" in update_cfg.model and "target" not in model_cls.model:  # type: ignore
        with open_dict(update_cfg.model):
            update_cfg.model.pop("target")

    # Remove ModelPT artifact `mridc_version`
    if "mridc_version" in update_cfg.model and "mridc_version" not in model_cls.model:  # type: ignore
        with open_dict(update_cfg.model):
            update_cfg.model.pop("mridc_version")

    return OmegaConf.merge(model_cls, update_cfg)
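
A hedged usage sketch. The dataclasses below are hypothetical stand-ins for an MRIDCConfig subclass, and the import path is a guess; adjust it to wherever update_model_config is defined in your mridc checkout:

from dataclasses import dataclass, field
from typing import Any, Dict, Optional

from omegaconf import OmegaConf

# Assumption: adjust this import to the module that defines update_model_config in your mridc version.
from mridc.utils.model_utils import update_model_config


@dataclass
class _Model:  # hypothetical minimal model section
    train_ds: Optional[Dict[str, Any]] = None
    validation_ds: Optional[Dict[str, Any]] = None
    test_ds: Optional[Dict[str, Any]] = None
    optim: Dict[str, Any] = field(default_factory=lambda: {"name": "adam", "lr": 1e-3})


@dataclass
class _ModelPTConfig:  # hypothetical minimal MRIDCConfig-like class
    model: _Model = field(default_factory=_Model)


update = OmegaConf.create(
    {"model": {"train_ds": {"batch_size": 8}, "optim": {"name": "adamw", "lr": 3e-4, "weight_decay": 0.01}}}
)

merged = update_model_config(_ModelPTConfig(), update, drop_missing_subconfigs=True)
# validation_ds and test_ds are dropped (absent from `update`); train_ds and optim take the
# override values, including the extra weight_decay key added via _add_subconfig_keys.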
Code example #9
def assert_dataclass_signature_match(
    cls: "class_type",  # type: ignore
    datacls: "dataclass",  # type: ignore
    ignore_args: Optional[List[str]] = None,
    remap_args: Optional[Dict[str, str]] = None,
):
    """
    Analyses the signature of a provided class and its respective data class,
    asserting that the dataclass signature matches the class __init__ signature.
    Note:
        This is not a value based check. This function only checks if all argument
        names exist on both class and dataclass and logs mismatches.

    Parameters
    ----------
    cls: Any class type - but not an instance of a class. Pass type(x) where x is an instance
        if class type is not easily available.
    datacls: A corresponding dataclass for the above class.
    ignore_args: (Optional) A list of string argument names which are forcibly ignored,
        even if mismatched in the signature. Useful when a dataclass is a superset of the
        arguments of a class.
    remap_args: (Optional) A dictionary, mapping an argument name that exists (in either the
        class or its dataclass), to another name. Useful when argument names are mismatched between
        a class and its dataclass due to indirect instantiation via a helper method.

    Returns
    -------
    A tuple containing information about the analysis:
        1) A bool value which is True if the signatures matched exactly / after ignoring values.
            False otherwise.
        2) A set of arguments names that exist in the class, but *do not* exist in the dataclass.
            If exact signature match occurs, this will be None instead.
        3) A set of argument names that exist in the data class, but *do not* exist in the class itself.
            If exact signature match occurs, this will be None instead.
    """
    class_sig = inspect.signature(cls.__init__)

    class_params = dict(**class_sig.parameters)
    class_params.pop("self")

    dataclass_sig = inspect.signature(datacls)

    dataclass_params = dict(**dataclass_sig.parameters)
    dataclass_params.pop("_target_", None)

    class_params = set(class_params.keys())  # type: ignore
    dataclass_params = set(dataclass_params.keys())  # type: ignore

    if remap_args is not None:
        for original_arg, new_arg in remap_args.items():
            if original_arg in class_params:
                class_params.remove(original_arg)  # type: ignore
                class_params.add(new_arg)  # type: ignore
                logging.info(
                    f"Remapped {original_arg} -> {new_arg} in {cls.__name__}")

            if original_arg in dataclass_params:
                dataclass_params.remove(original_arg)  # type: ignore
                dataclass_params.add(new_arg)  # type: ignore
                logging.info(
                    f"Remapped {original_arg} -> {new_arg} in {datacls.__name__}"
                )

    if ignore_args is not None:
        ignore_args = set(ignore_args)  # type: ignore

        class_params = class_params - ignore_args  # type: ignore
        dataclass_params = dataclass_params - ignore_args  # type: ignore
        logging.info(f"Removing ignored arguments - {ignore_args}")

    intersection: Set[type] = set.intersection(
        class_params, dataclass_params)  # type: ignore
    subset_cls = class_params - intersection  # type: ignore
    subset_datacls = dataclass_params - intersection  # type: ignore

    if (len(class_params) != len(dataclass_params)
        ) or len(subset_cls) > 0 or len(subset_datacls) > 0:
        logging.error(f"Class {cls.__name__} arguments do not match "
                      f"Dataclass {datacls.__name__}!")

        if len(subset_cls) > 0:
            logging.error(f"Class {cls.__name__} has additional arguments :\n"
                          f"{subset_cls}")

        if len(subset_datacls):
            logging.error(
                f"Dataclass {datacls.__name__} has additional arguments :\n{subset_datacls}"
            )

        return False, subset_cls, subset_datacls
    return True, None, None
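
A small usage sketch. The class and dataclass below are hypothetical, and the import path is an assumption; the function itself is the one shown above:

from dataclasses import dataclass

# Assumption: adjust this import to the module that defines assert_dataclass_signature_match.
from mridc.utils.model_utils import assert_dataclass_signature_match


class Conv2dBlock:  # hypothetical class, used only for illustration
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3):
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size


@dataclass
class Conv2dBlockConfig:  # hypothetical mirror dataclass; `kernel_size` is missing on purpose
    in_channels: int = 1
    out_channels: int = 32


matched, cls_only, datacls_only = assert_dataclass_signature_match(Conv2dBlock, Conv2dBlockConfig)
print(matched)       # False
print(cls_only)      # {'kernel_size'}
print(datacls_only)  # set()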
Code example #10
File: common.py  Project: wdika/mridc
    @classmethod
    def from_config_dict(cls, config: "DictConfig", trainer: Optional[Trainer] = None):
        """Instantiates object using DictConfig-based configuration"""
        # Resolve the config dict
        if _HAS_HYDRA:
            if isinstance(config, DictConfig):
                config = OmegaConf.to_container(config, resolve=True)
                config = OmegaConf.create(config)
                OmegaConf.set_struct(config, True)

            config = mridc.utils.model_utils.maybe_update_config_version(config)  # type: ignore

        # Hydra 0.x API
        if ("cls" in config or "target" in config) and "params" in config and _HAS_HYDRA:
            # regular hydra-based instantiation
            instance = hydra.utils.instantiate(config=config)
        elif "_target_" in config and _HAS_HYDRA:
            # regular hydra-based instantiation
            instance = hydra.utils.instantiate(config=config)
        else:
            instance = None
            prev_error = ""

            # Attempt class path resolution from config `target` class (if it exists)
            if "target" in config:
                target_cls = config["target"]  # No guarantee that this is a omegaconf class
                imported_cls = None
                try:
                    # try to import the target class
                    imported_cls = mridc.utils.model_utils.import_class_by_path(target_cls)  # type: ignore

                    # use subclass instead
                    if issubclass(cls, imported_cls):
                        imported_cls = cls
                        if accepts_trainer := Serialization._inspect_signature_for_trainer(imported_cls):
                            if trainer is None:
                                # Create a dummy PL trainer object
                                cfg_trainer = TrainerConfig(
                                    gpus=1, accelerator="ddp", num_nodes=1, logger=False, checkpoint_callback=False
                                )
                                trainer = Trainer(cfg_trainer)
                            instance = imported_cls(cfg=config, trainer=trainer)  # type: ignore
                        else:
                            instance = imported_cls(cfg=config)  # type: ignore

                except Exception as e:
                    tb = traceback.format_exc()
                    prev_error = f"Model instantiation failed.\nTarget class: {target_cls}\nError: {e}\n{tb}"

                    logging.debug(prev_error + "\n falling back to 'cls'.")
            # target class resolution was unsuccessful, fall back to current `cls`
            if instance is None:
                try:
                    if accepts_trainer := Serialization._inspect_signature_for_trainer(cls):
                        instance = cls(cfg=config, trainer=trainer)  # type: ignore
                    else:
                        instance = cls(cfg=config)  # type: ignore
                except Exception as e:
                    # report saved errors, if any, and raise the current error
                    if prev_error:
                        logging.error(f"{prev_error}")
                    raise e from e

        if not hasattr(instance, "_cfg"):
            instance._cfg = config
        return instance
Code example #11
def resolve_test_dataloaders(model: "ModelPT"):
    """
    Helper method that operates on the ModelPT class to automatically support
    multiple dataloaders for the test set.
    It does so by first resolving the path to one/more data files via `resolve_dataset_name_from_cfg()`.
    If this resolution fails, it assumes the data loader is prepared to manually support (or not support)
    multiple data loaders, and simply calls the appropriate setup method.
    If resolution succeeds:
        Checks if provided path is to a single file or a list of files.
        If a single file is provided, simply tags that file as such and loads it via the setup method.
        If multiple files are provided:
            Inject a new manifest path at index "i" into the resolved key.
            Calls the appropriate setup method to set the data loader.
            Collects the initialized data loader in a list and preserves it.
            Once all data loaders are processed, assigns the list of loaded loaders to the ModelPT.
            Finally, assigns a list of unique names resolved from the file paths to the ModelPT.

    Parameters
    ----------
    model: ModelPT subclass, which requires >=1 Test Dataloaders to be set up.
    """
    if not _HAS_HYDRA:
        logging.error(
            "This function requires Hydra/OmegaConf and it was not installed.")
        sys.exit(1)
    cfg = copy.deepcopy(model._cfg)
    dataloaders: List[Any] = []

    # Process test_dl_idx
    if "test_dl_idx" in cfg.test_ds:
        cfg = OmegaConf.to_container(cfg)
        test_dl_idx = cfg["test_ds"].pop("test_dl_idx")
        cfg = OmegaConf.create(cfg)
    else:
        test_dl_idx = 0

    # Set test_dl_idx
    model._test_dl_idx = test_dl_idx

    ds_key = resolve_dataset_name_from_cfg(cfg.test_ds)

    if ds_key is None:
        logging.debug(
            f"Could not resolve file path from provided config - {cfg.test_ds}. "
            "Disabling support for multi-dataloaders.")

        model.setup_test_data(cfg.test_ds)
        return

    ds_values = cfg.test_ds[ds_key]

    if isinstance(ds_values, (list, tuple, ListConfig)):

        for ds_value in ds_values:
            cfg.test_ds[ds_key] = ds_value
            model.setup_test_data(cfg.test_ds)
            dataloaders.append(model.test_dl)

        model.test_dl = dataloaders  # type: ignore
        model.test_names = [parse_dataset_as_name(ds)
                            for ds in ds_values]  # type: ignore

        unique_names_check(name_list=model.test_names)
        return
    model.setup_test_data(cfg.test_ds)
    model.test_names = [parse_dataset_as_name(ds_values)]

    unique_names_check(name_list=model.test_names)
Code example #12
def configure_checkpointing(trainer: Trainer, log_dir: Path, name: str,
                            resume: bool, params: "DictConfig"):
    """Adds ModelCheckpoint to trainer. Raises CheckpointMisconfigurationError if trainer already has a ModelCheckpoint
    callback or if trainer.weights_save_path was passed to Trainer.
    """
    for callback in trainer.callbacks:
        if isinstance(callback, ModelCheckpoint):
            raise CheckpointMisconfigurationError(
                "The pytorch lightning trainer that was passed to exp_manager contained a ModelCheckpoint "
                "and create_checkpoint_callback was set to True. Please either set create_checkpoint_callback "
                "to False, or remove ModelCheckpoint from the lightning trainer"
            )
    if Path(trainer.weights_save_path) != Path.cwd():
        raise CheckpointMisconfigurationError(
            "The pytorch lightning was passed weights_save_path. This variable is ignored by exp_manager"
        )

    # Create the callback and attach it to trainer
    if "filepath" in params:
        if params.filepath is not None:
            logging.warning(
                "filepath is deprecated. Please switch to dirpath and filename instead"
            )
            if params.dirpath is None:
                params.dirpath = Path(params.filepath).parent
            if params.filename is None:
                params.filename = Path(params.filepath).name
        with open_dict(params):
            del params["filepath"]
    if params.dirpath is None:
        params.dirpath = Path(log_dir / "checkpoints")
    if params.filename is None:
        params.filename = f"{name}--{{{params.monitor}:.4f}}-{{epoch}}"
    if params.prefix is None:
        params.prefix = name
    MRIDCModelCheckpoint.CHECKPOINT_NAME_LAST = f"{params.filename}-last"

    logging.debug(params.dirpath)
    logging.debug(params.filename)
    logging.debug(params.prefix)

    if "val" in params.monitor:
        if (trainer.max_epochs is not None and trainer.max_epochs != -1
                and trainer.max_epochs < trainer.check_val_every_n_epoch):
            logging.error(
                "The checkpoint callback was told to monitor a validation value but trainer.max_epochs("
                f"{trainer.max_epochs}) was less than trainer.check_val_every_n_epoch("
                f"{trainer.check_val_every_n_epoch}). It is very likely this run will fail with "
                f"ModelCheckpoint(monitor='{params.monitor}') not found in the returned metrics. Please ensure that "
                "validation is run within trainer.max_epochs.")
        elif trainer.max_steps is not None:
            logging.warning(
                "The checkpoint callback was told to monitor a validation value and trainer's max_steps was set to "
                f"{trainer.max_steps}. Please ensure that max_steps will run for at least "
                f"{trainer.check_val_every_n_epoch} epochs to ensure that checkpointing will not error out."
            )

    checkpoint_callback = MRIDCModelCheckpoint(n_resume=resume, **params)
    checkpoint_callback.last_model_path = trainer._checkpoint_connector.resume_from_checkpoint_fit_path or ""
    if "mp_rank" in checkpoint_callback.last_model_path or "tp_rank" in checkpoint_callback.last_model_path:
        checkpoint_callback.last_model_path = mridc.utils.model_utils.uninject_model_parallel_rank(  # type: ignore
            checkpoint_callback.last_model_path)
    trainer.callbacks.append(checkpoint_callback)
Code example #13
def exp_manager(
        trainer: Trainer,
        cfg: Optional[Union[DictConfig, Dict]] = None) -> Optional[Path]:
    """
    exp_manager is a helper function used to manage folders for experiments. It follows the pytorch lightning \
    paradigm of exp_dir/model_or_experiment_name/version. If the lightning trainer has a logger, exp_manager will \
    get exp_dir, name, and version from the logger. Otherwise, it will use the exp_dir and name arguments to create \
    the logging directory. exp_manager also allows for explicit folder creation via explicit_log_dir.

    The version can be a datetime string or an integer. The datetime version can be disabled by setting \
    use_datetime_version to False. It optionally creates TensorBoardLogger, WandBLogger, ModelCheckpoint objects from pytorch \
    lightning. It copies sys.argv, and git information if available to the logging directory. It creates a log file \
    for each process to log their output into.

    exp_manager additionally has a resume feature (resume_if_exists) which can be used to continue training from \
    the constructed log_dir. When you need to continue training repeatedly (like on a cluster where you need \
    multiple consecutive jobs), you need to avoid creating version folders. Therefore, from v1.0.0, when \
    resume_if_exists is set to True, creating the version folders is skipped.

    Parameters
    ----------
    trainer: The lightning trainer object.
    cfg: Can have the following keys:
        - explicit_log_dir: Can be used to override exp_dir/name/version folder creation. Defaults to None, which \
        will use exp_dir, name, and version to construct the logging directory.
        - exp_dir: The base directory to create the logging directory. Defaults to None, which logs to \
         ./mridc_experiments.
        - name: The name of the experiment. Defaults to None which turns into "default" via name = name or "default".
        - version: The version of the experiment. Defaults to None which uses either a datetime string or lightning's \
         TensorboardLogger system of using version_{int}.
        - use_datetime_version: Whether to use a datetime string for version. Defaults to True.
        - resume_if_exists: Whether this experiment is resuming from a previous run. If True, it sets \
        trainer._checkpoint_connector.resume_from_checkpoint_fit_path so that the trainer should auto-resume. \
        exp_manager will move files under log_dir to log_dir/run_{int}. Defaults to False. From v1.0.0, when \
        resume_if_exists is True, we would not create version folders to make it easier to find the log folder for \
        next runs.
        - resume_past_end: exp_manager errors out if resume_if_exists is True and a checkpoint matching \*end.ckpt \
        exists, indicating a previous training run fully completed. This behaviour can be disabled, in which case the \
        \*end.ckpt will be loaded, by setting resume_past_end to True. Defaults to False.
        - resume_ignore_no_checkpoint: exp_manager errors out if resume_if_exists is True and no checkpoint could be \
         found. This behaviour can be disabled, in which case exp_manager will print a message and continue without \
         restoring, by setting resume_ignore_no_checkpoint to True. Defaults to False.
        - create_tensorboard_logger: Whether to create a tensorboard logger and attach it to the pytorch lightning \
        trainer. Defaults to True.
        - summary_writer_kwargs: A dictionary of kwargs that can be passed to lightning's TensorboardLogger class. \
        Note that log_dir is passed by exp_manager and cannot exist in this dict. Defaults to None.
        - create_wandb_logger: Whether to create a Weights and Biases logger and attach it to the pytorch lightning \
        trainer. Defaults to False.
        - wandb_logger_kwargs: A dictionary of kwargs that can be passed to lightning's WandBLogger class. Note that \
         name and project are required parameters if create_wandb_logger is True. Defaults to None.
        - create_checkpoint_callback: Whether to create a ModelCheckpoint callback and attach it to the pytorch \
        lightning trainer. The ModelCheckpoint saves the top 3 models with the best "val_loss", the most recent \
        checkpoint under \*last.ckpt, and the final checkpoint after training completes under \*end.ckpt. \
        Defaults to True.
        - files_to_copy: A list of files to copy to the experiment logging directory. Defaults to None which copies \
        no files.
        - log_local_rank_0_only: Whether to only create log files for local rank 0. Defaults to False. Set this to \
        True if you are using DDP with many GPUs and do not want many log files in your exp dir.
        - log_global_rank_0_only: Whether to only create log files for global rank 0. Defaults to False. Set this to \
        True if you are using DDP with many GPUs and do not want many log files in your exp dir.

    Returns
    -------
    The final logging directory where logging files are saved. Usually the concatenation of exp_dir, name, and version.
    """
    # Add rank information to logger
    # Note: trainer.global_rank and trainer.is_global_zero are not set until trainer.fit, so have to hack around it
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    global_rank = trainer.node_rank * trainer.num_devices + local_rank
    logging.rank = global_rank

    if cfg is None:
        logging.error(
            "exp_manager did not receive a cfg argument. It will be disabled.")
        return None

    if trainer.fast_dev_run:
        logging.info(
            "Trainer was called with fast_dev_run. exp_manager will return without any functionality."
        )
        return None

    # Ensure passed cfg is compliant with ExpManagerConfig
    schema = OmegaConf.structured(ExpManagerConfig)
    if isinstance(cfg, dict):
        cfg = OmegaConf.create(cfg)
    elif not isinstance(cfg, DictConfig):
        raise ValueError(
            f"cfg was type: {type(cfg)}. Expected either a dict or a DictConfig"
        )
    cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=True))
    cfg = OmegaConf.merge(schema, cfg)

    error_checks(
        trainer, cfg
    )  # Ensures that trainer options are compliant with MRIDC and exp_manager arguments

    log_dir, exp_dir, name, version = get_log_dir(
        trainer=trainer,
        exp_dir=cfg.exp_dir,
        name=cfg.name,
        version=cfg.version,
        explicit_log_dir=cfg.explicit_log_dir,
        use_datetime_version=cfg.use_datetime_version,
        resume_if_exists=cfg.resume_if_exists,
    )

    if cfg.resume_if_exists:
        check_resume(trainer, str(log_dir), cfg.resume_past_end,
                     cfg.resume_ignore_no_checkpoint)

    checkpoint_name = name
    # If name returned from get_log_dir is "", use cfg.name for checkpointing
    if checkpoint_name is None or checkpoint_name == "":
        checkpoint_name = cfg.name or "default"
    cfg.name = name  # Used for configure_loggers so that the log_dir is properly set even if name is ""
    cfg.version = version

    # update app_state with log_dir, exp_dir, etc
    app_state = AppState()
    app_state.log_dir = log_dir
    app_state.exp_dir = exp_dir
    app_state.name = name
    app_state.version = version
    app_state.checkpoint_name = checkpoint_name
    app_state.create_checkpoint_callback = cfg.create_checkpoint_callback
    app_state.checkpoint_callback_params = cfg.checkpoint_callback_params

    # Create the logging directory if it does not exist
    os.makedirs(
        log_dir, exist_ok=True
    )  # Cannot limit creation to global zero as all ranks write to own log file
    logging.info(f"Experiments will be logged at {log_dir}")
    trainer._default_root_dir = log_dir

    if cfg.log_local_rank_0_only is True and cfg.log_global_rank_0_only is True:
        raise ValueError(
            "Cannot set both log_local_rank_0_only and log_global_rank_0_only to True."
            "Please set either one or neither.")

    # This is set if the env var MRIDC_TESTING is set to True.
    mridc_testing = get_envbool(MRIDC_ENV_VARNAME_TESTING, False)

    log_file = log_dir / f"mridc_log_globalrank-{global_rank}_localrank-{local_rank}.txt"

    # Handle logging to file
    if cfg.log_local_rank_0_only and not mridc_testing:
        if local_rank == 0:
            logging.add_file_handler(log_file)
    elif cfg.log_global_rank_0_only and not mridc_testing:
        if global_rank == 0:
            logging.add_file_handler(log_file)
    else:
        # Log on all ranks; each rank writes to its own log file.
        logging.add_file_handler(log_file)

    # For some reason, LearningRateLogger requires trainer to have a logger. Safer to create logger on all ranks
    # not just global rank 0.
    if cfg.create_tensorboard_logger or cfg.create_wandb_logger:
        configure_loggers(
            trainer,
            [Path(exp_dir)],
            cfg.name,
            cfg.version,
            cfg.create_tensorboard_logger,
            cfg.summary_writer_kwargs,
            cfg.create_wandb_logger,
            cfg.wandb_logger_kwargs,
        )

    # add loggers timing callbacks
    if cfg.log_step_timing:
        timing_callback = TimingCallback(
            timer_kwargs=cfg.step_timing_kwargs or {})
        trainer.callbacks.insert(0, timing_callback)

    if cfg.create_checkpoint_callback:
        configure_checkpointing(trainer, log_dir, checkpoint_name,
                                cfg.resume_if_exists,
                                cfg.checkpoint_callback_params)

    if is_global_rank_zero():
        # Move files_to_copy to folder and add git information if present
        if cfg.files_to_copy:
            for _file in cfg.files_to_copy:
                copy(Path(_file), log_dir)

        # Create files for cmd args and git info
        with open(log_dir / "cmd-args.log", "w", encoding="utf-8") as _file:
            _file.write(" ".join(sys.argv))

        # Try to get git hash
        git_repo, git_hash = get_git_hash()
        if git_repo:
            with open(log_dir / "git-info.log", "w",
                      encoding="utf-8") as _file:
                _file.write(f"commit hash: {git_hash}")
                _file.write(get_git_diff())

        # Add err_file logging to global_rank zero
        logging.add_err_file_handler(log_dir / "mridc_error_log.txt")

        # Add lightning file logging to global_rank zero
        add_filehandlers_to_pl_logger(log_dir / "lightning_logs.txt",
                                      log_dir / "mridc_error_log.txt")

    return log_dir
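
A minimal usage sketch (assuming mridc and pytorch_lightning are installed; the experiment name and directory are placeholders, and the import path is an assumption, so adjust it to your mridc version):

from pytorch_lightning import Trainer

# Assumption: exp_manager is importable from mridc.utils.exp_manager; adjust if it lives elsewhere.
from mridc.utils.exp_manager import exp_manager

trainer = Trainer(max_epochs=10, devices=1, accelerator="cpu")

log_dir = exp_manager(
    trainer,
    {
        "exp_dir": "./mridc_experiments",      # base directory for all experiments
        "name": "unet_recon",                  # hypothetical experiment name
        "create_tensorboard_logger": True,
        "create_checkpoint_callback": True,
        "resume_if_exists": False,
    },
)
print(log_dir)  # e.g. mridc_experiments/unet_recon/<datetime or version_0>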
Code example #14
File: export.py  Project: wdika/mridc
    def export(
        self,
        output: str,
        input_example=None,
        verbose=False,
        export_params=True,
        do_constant_folding=True,
        onnx_opset_version=None,
        try_script: bool = False,
        training=TrainingMode.EVAL,
        check_trace: bool = False,
        use_dynamic_axes: bool = True,
        dynamic_axes=None,
        check_tolerance=0.01,
    ):
        """
        Export the module to a file.

        Parameters
        ----------
        output: The output file path.
        input_example: A dictionary of input names and values.
        verbose: If True, print out the export process.
        export_params: If True, export the parameters of the module.
        do_constant_folding: If True, do constant folding.
        onnx_opset_version: The ONNX opset version to use.
        try_script: If True, try to export as TorchScript.
        training: Training mode for the export.
        check_trace: If True, check the trace of the exported model.
        use_dynamic_axes: If True, use dynamic axes for the export.
        dynamic_axes: A dictionary of input names and dynamic axes.
        check_tolerance: The tolerance for the check_trace.
        """
        my_args = locals().copy()
        my_args.pop("self")

        exportables = []
        for m in self.modules():  # type: ignore
            if isinstance(m, Exportable):
                exportables.append(m)

        qual_name = self.__module__ + "." + self.__class__.__qualname__
        format = get_export_format(output)
        output_descr = f"{qual_name} exported to {format}"

        # Pytorch's default for None is too low, can't pass None through
        if onnx_opset_version is None:
            onnx_opset_version = 13

        try:
            # Disable typechecks
            typecheck.set_typecheck_enabled(enabled=False)

            # Allow user to completely override forward method to export
            forward_method, old_forward_method = wrap_forward_method(self)

            # Set module mode
            with torch.onnx.select_model_mode_for_export(
                    self, training), torch.inference_mode(
                    ), torch.jit.optimized_execution(True):

                if input_example is None:
                    input_example = self.input_module.input_example()

                # Remove i/o examples from args we propagate to enclosed Exportables
                my_args.pop("output")
                my_args.pop("input_example")

                # Run (possibly overridden) prepare methods before calling forward()
                for ex in exportables:
                    ex._prepare_for_export(**my_args, noreplace=True)
                self._prepare_for_export(output=output,
                                         input_example=input_example,
                                         **my_args)

                input_list, input_dict = parse_input_example(input_example)
                input_names = self.input_names
                output_names = self.output_names
                output_example = tuple(self.forward(
                    *input_list, **input_dict))  # type: ignore

                jitted_model = None
                if try_script:
                    try:
                        jitted_model = torch.jit.script(self)
                    except Exception as e:
                        logging.error(f"jit.script() failed!\n{e}")

                if format == ExportFormat.TORCHSCRIPT:
                    if jitted_model is None:
                        jitted_model = torch.jit.trace_module(
                            self,
                            {
                                "forward":
                                tuple(input_list) + tuple(input_dict.values())
                            },
                            strict=True,
                            check_trace=check_trace,
                            check_tolerance=check_tolerance,
                        )
                    if not self.training:  # type: ignore
                        jitted_model = torch.jit.optimize_for_inference(
                            jitted_model)
                    if verbose:
                        logging.info(f"JIT code:\n{jitted_model.code}")
                    jitted_model.save(output)
                elif format == ExportFormat.ONNX:
                    if jitted_model is None:
                        jitted_model = self

                    # dynamic axis is a mapping from input/output_name => list of "dynamic" indices
                    if dynamic_axes is None and use_dynamic_axes:
                        dynamic_axes = get_dynamic_axes(
                            self.input_module.input_types, input_names)
                        dynamic_axes.update(
                            get_dynamic_axes(self.output_module.output_types,
                                             output_names))

                    torch.onnx.export(
                        jitted_model,
                        input_example,
                        output,
                        input_names=input_names,
                        output_names=output_names,
                        verbose=verbose,
                        export_params=export_params,
                        do_constant_folding=do_constant_folding,
                        dynamic_axes=dynamic_axes,
                        opset_version=onnx_opset_version,
                    )

                    if check_trace:
                        verify_runtime(output, input_list, input_dict,
                                       input_names, output_names,
                                       output_example)

                else:
                    raise ValueError(
                        f"Encountered unknown export format {format}.")
        finally:
            typecheck.set_typecheck_enabled(enabled=True)
            if forward_method:
                type(self).forward = old_forward_method  # type: ignore
            self._export_teardown()
        return [output], [output_descr]
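
Finally, a hedged usage sketch of this export entry point. The wrapper below assumes `model` mixes in the Exportable interface shown above; the output path and option values are illustrative only:

from typing import List, Tuple


def export_to_onnx(model, output_path: str = "model.onnx") -> Tuple[List[str], List[str]]:
    """Export an Exportable MRIDC module to ONNX via the method shown above.

    `model` is assumed to provide the `export()` method from the Exportable mixin;
    the defaults below are illustrative, not mandated by mridc.
    """
    model.eval()  # export in eval mode (matches the TrainingMode.EVAL default)
    return model.export(
        output_path,
        onnx_opset_version=13,   # same value the method falls back to when None is passed
        check_trace=True,        # re-run the exported graph and compare against eager outputs
        use_dynamic_axes=True,   # derive dynamic axes from the module's input/output types
    )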