def _convert_config(cfg: "OmegaConf"):
    """Recursive function converting the configuration from the old hydra format to the new one."""
    if not _HAS_HYDRA:
        logging.error("This function requires Hydra/OmegaConf and it was not installed.")
        sys.exit(1)

    # Get rid of cls -> _target_.
    if "cls" in cfg and "_target_" not in cfg:
        cfg._target_ = cfg.pop("cls")  # type: ignore

    # Get rid of params.
    if "params" in cfg:
        params = cfg.pop("params")  # type: ignore
        for param_key, param_val in params.items():
            cfg[param_key] = param_val

    # Recursion.
    try:
        for _, sub_cfg in cfg.items():  # type: ignore
            if isinstance(sub_cfg, DictConfig):
                _convert_config(sub_cfg)  # type: ignore
    except OmegaConfBaseException as e:
        logging.warning(f"Skipped conversion for config/subconfig:\n{cfg}\n Reason: {e}.")

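# Illustrative usage sketch (not part of the module API; the class path and keys below
# are made up) showing the in-place `cls`/`params` -> `_target_` conversion performed above:
#
# >>> old_cfg = OmegaConf.create({"cls": "my.module.MyClass", "params": {"lr": 1e-3}})
# >>> _convert_config(old_cfg)
# >>> print(OmegaConf.to_yaml(old_cfg))
# _target_: my.module.MyClass
# lr: 0.001
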
def error_checks(trainer: Trainer, cfg: Optional[Union[DictConfig, Dict]] = None):
    """
    Checks that the passed trainer is compliant with MRIDC and exp_manager's passed configuration. Checks that:

    - An error is raised when hydra has changed the working directory, since this causes issues with lightning's DDP.
    - An error is raised when the trainer has a logger defined but create_tensorboard_logger or create_wandb_logger
        is True.
    - Error messages are printed when 1) run on multi-node and not Slurm, and 2) run on multi-gpu without DDP.
    """
    if HydraConfig.initialized() and get_original_cwd() != os.getcwd():
        raise ValueError(
            "Hydra changed the working directory. This interferes with ExpManager's functionality. Please pass "
            "hydra.run.dir=. to your python script."
        )
    if trainer.logger is not None and (cfg.create_tensorboard_logger or cfg.create_wandb_logger):  # type: ignore
        raise LoggerMisconfigurationError(
            "The pytorch lightning trainer that was passed to exp_manager contained a logger, and either "
            "create_tensorboard_logger or create_wandb_logger was set to True. These can only be used if trainer does "
            "not already have a logger."
        )
    if trainer.num_nodes > 1 and not check_slurm(trainer):  # type: ignore
        logging.error(
            "You are running multi-node training without SLURM handling the processes."
            " Please note that this is not tested in MRIDC and could result in errors."
        )
    if trainer.num_devices > 1 and not isinstance(trainer.strategy, DDPStrategy):  # type: ignore
        logging.error(
            "You are running multi-gpu without ddp. Please note that this is not tested in MRIDC and could result in "
            "errors."
        )

def convert_model_config_to_dict_config(cfg: Union[DictConfig, MRIDCConfig]) -> DictConfig:
    """
    Converts its input into a standard DictConfig.

    Possible input values are:
    - DictConfig
    - A dataclass which is a subclass of MRIDCConfig

    Parameters
    ----------
    cfg: A dict-like object.

    Returns
    -------
    The equivalent DictConfig.
    """
    if not _HAS_HYDRA:
        logging.error("This function requires Hydra/OmegaConf and it was not installed.")
        sys.exit(1)

    if not isinstance(cfg, (OmegaConf, DictConfig)) and is_dataclass(cfg):
        cfg = OmegaConf.structured(cfg)

    if not isinstance(cfg, DictConfig):
        raise ValueError(f"cfg constructor argument must be of type DictConfig/dict but got {type(cfg)} instead.")

    config = OmegaConf.to_container(cfg, resolve=True)
    config = OmegaConf.create(config)
    return config

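# Illustrative sketch (the config subclass and its field are hypothetical): any dataclass
# deriving from MRIDCConfig is converted to a resolved DictConfig.
#
# >>> @dataclass
# ... class MyModelConfig(MRIDCConfig):
# ...     dropout: float = 0.0
# >>> cfg = convert_model_config_to_dict_config(MyModelConfig())
# >>> isinstance(cfg, DictConfig)
# True
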
def _add_subconfig_keys(model_cfg: "DictConfig", update_cfg: "DictConfig", subconfig_key: str):
    """
    For certain sub-configs, the default values specified by the MRIDCConfig class are insufficient. Supporting
    every potential value in the merge with the `update_cfg` would require explicit definition of all possible cases.

    An example of such a case is Optimizers, and their equivalent Schedulers. All optimizers share a few basic
    details - such as name and lr, but almost all require additional parameters - such as weight decay. It is
    impractical to create a config for every single optimizer + every single scheduler combination.

    In such a case, we perform a dual merge. The Optim and Sched dataclasses contain the bare minimum essential
    components. The extra values are provided via update_cfg.

    In order to enable the merge, we first need to update the update sub-config to incorporate the keys, with dummy
    temporary values (merge update config with model config). This is done on a copy of the update sub-config, as
    the actual override values might be overridden by the MRIDCConfig defaults.

    Then we perform a merge of this temporary sub-config with the actual override config in a later step (merge
    model_cfg with original update_cfg, done outside this function).

    Parameters
    ----------
    model_cfg: A DictConfig instantiated from the MRIDCConfig subclass.
    update_cfg: A DictConfig that mirrors the structure of `model_cfg`, used to update its default values.
    subconfig_key: A str key used to check and update the sub-config.

    Returns
    -------
    A ModelPT DictConfig with additional keys added to the sub-config.
    """
    if not _HAS_HYDRA:
        logging.error("This function requires Hydra/OmegaConf and it was not installed.")
        sys.exit(1)

    with open_dict(model_cfg.model):
        # Create a copy of the original model sub-config.
        if subconfig_key in update_cfg.model:
            if subconfig_key not in model_cfg.model:
                # Create the key as a placeholder.
                model_cfg.model[subconfig_key] = None

            subconfig = copy.deepcopy(model_cfg.model[subconfig_key])
            update_subconfig = copy.deepcopy(update_cfg.model[subconfig_key])

            # Add the keys and update temporary values; they will be updated during the full merge.
            subconfig = OmegaConf.merge(update_subconfig, subconfig)
            # Update the sub-config.
            model_cfg.model[subconfig_key] = subconfig

    return model_cfg

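# Illustrative sketch of the dual merge described above (field names are hypothetical):
# the defaults only define `name` and `lr`, while `weight_decay` comes from the user
# override, so all three keys survive into the later full merge.
#
# >>> model_cfg = OmegaConf.create({"model": {"optim": {"name": "adam", "lr": 1e-3}}})
# >>> update_cfg = OmegaConf.create({"model": {"optim": {"lr": 5e-4, "weight_decay": 0.01}}})
# >>> merged = _add_subconfig_keys(model_cfg, update_cfg, "optim")
# >>> sorted(merged.model.optim.keys())
# ['lr', 'name', 'weight_decay']
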
def check_explicit_log_dir(
    trainer: Trainer, explicit_log_dir: List[Union[Path, str]], exp_dir: str, name: str, version: str
) -> Tuple[Path, str, str, str]:
    """
    Checks that the passed arguments are compatible with explicit_log_dir.

    Parameters
    ----------
    trainer: The trainer to check.
    explicit_log_dir: The explicit log dir to check.
    exp_dir: The experiment directory to check.
    name: The experiment name to check.
    version: The experiment version to check.

    Returns
    -------
    The log_dir, exp_dir, name, and version that should be used.

    Raises
    ------
    LoggerMisconfigurationError
    """
    if trainer.logger is not None:
        raise LoggerMisconfigurationError(
            "The pytorch lightning trainer that was passed to exp_manager contained a logger and explicit_log_dir: "
            f"{explicit_log_dir} was passed to exp_manager. Please remove the logger from the lightning trainer."
        )
    # Checking only (explicit_log_dir) vs (exp_dir and version).
    # The `name` will be used as the actual name of the checkpoint/archive.
    if exp_dir or version:
        logging.error(
            f"exp_manager received explicit_log_dir: {explicit_log_dir} and at least one of exp_dir: {exp_dir}, "
            f"or version: {version}. Please note that exp_dir, name, and version will be ignored."
        )
    if is_global_rank_zero() and Path(str(explicit_log_dir)).exists():
        logging.warning(f"Exp_manager is logging to {explicit_log_dir}, but it already exists.")
    return Path(str(explicit_log_dir)), str(explicit_log_dir), "", ""

def maybe_update_config_version(cfg: "DictConfig"):
    """
    Recursively convert Hydra 0.x configs to Hydra 1.x configs.

    Changes include:
    - `cls` -> `_target_`.
    - `params` -> drop params and shift all arguments to parent.
    - `target` -> `_target_` cannot be performed due to ModelPT injecting `target` inside class.

    Parameters
    ----------
    cfg: Any Hydra compatible DictConfig.

    Returns
    -------
    An updated DictConfig that conforms to Hydra 1.x format.
    """
    if not _HAS_HYDRA:
        logging.error("This function requires Hydra/OmegaConf and it was not installed.")
        sys.exit(1)

    if cfg is not None and not isinstance(cfg, DictConfig):
        try:
            temp_cfg = OmegaConf.create(cfg)
            cfg = temp_cfg
        except OmegaConfBaseException:
            # Cannot be cast to DictConfig, skip updating.
            return cfg

    # Make a copy of the model config.
    cfg = copy.deepcopy(cfg)
    OmegaConf.set_struct(cfg, False)

    # Convert the config.
    _convert_config(cfg)  # type: ignore

    # Update the model config.
    OmegaConf.set_struct(cfg, True)

    return cfg

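# Illustrative sketch (keys are made up): plain dicts are also accepted and returned as
# upgraded DictConfigs, leaving the caller's object untouched.
#
# >>> new_cfg = maybe_update_config_version({"optim": {"cls": "torch.optim.Adam", "params": {"lr": 1e-3}}})
# >>> new_cfg.optim._target_
# 'torch.optim.Adam'
# >>> new_cfg.optim.lr
# 0.001
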
def _update_subconfig(
    model_cfg: "DictConfig", update_cfg: "DictConfig", subconfig_key: str, drop_missing_subconfigs: bool
):
    """
    Updates the MRIDCConfig DictConfig such that:
    1) If the sub-config key exists in `update_cfg`, but does not exist in the ModelPT config:
        - Add the sub-config from update_cfg to the ModelPT config.
    2) If the sub-config key does not exist in `update_cfg`, but exists in the ModelPT config:
        - Remove the sub-config from the ModelPT config; iff the `drop_missing_subconfigs` flag is set.

    Parameters
    ----------
    model_cfg: A DictConfig instantiated from the MRIDCConfig subclass.
    update_cfg: A DictConfig that mirrors the structure of `model_cfg`, used to update its default values.
    subconfig_key: A str key used to check and update the sub-config.
    drop_missing_subconfigs: A bool flag, whether to allow deletion of the MRIDCConfig sub-config, if its mirror
        sub-config does not exist in the `update_cfg`.

    Returns
    -------
    The updated DictConfig for the MRIDCConfig.
    """
    if not _HAS_HYDRA:
        logging.error("This function requires Hydra/OmegaConf and it was not installed.")
        sys.exit(1)

    with open_dict(model_cfg.model):
        # If the update config has the key but the model config doesn't,
        # add the update config's sub-config to the model config.
        if subconfig_key in update_cfg.model and subconfig_key not in model_cfg.model:
            model_cfg.model[subconfig_key] = update_cfg.model[subconfig_key]

        # If the update config does not have the key but the model config does,
        # remove the model config's sub-config in order to match the layout of the update config.
        if subconfig_key not in update_cfg.model and subconfig_key in model_cfg.model and drop_missing_subconfigs:
            model_cfg.model.pop(subconfig_key)

    return model_cfg

def update_model_config(model_cls: MRIDCConfig, update_cfg: "DictConfig", drop_missing_subconfigs: bool = True):
    """
    Helper function that updates the default values of a ModelPT config class with the values in a DictConfig that \
    mirrors the structure of the config class. Assumes the `update_cfg` is a DictConfig (either generated manually, \
    via hydra or instantiated via yaml/model.cfg). This update_cfg is then used to override the default values \
    preset inside the ModelPT config class.

    If `drop_missing_subconfigs` is set, certain sub-configs of the ModelPT config class will be removed, if they \
    are not found in the mirrored `update_cfg`. The following sub-configs are subject to potential removal:
    - `train_ds`
    - `validation_ds`
    - `test_ds`
    - `optim` + nested sched

    Parameters
    ----------
    model_cls: A subclass of MRIDCConfig that details in entirety all the parameters that constitute the MRIDC Model.
    update_cfg: A DictConfig that mirrors the structure of the MRIDCConfig data class. Used to update the default \
        values of the config class.
    drop_missing_subconfigs: Bool which determines whether to drop certain sub-configs from the MRIDCConfig class, \
        if the corresponding sub-config is missing from `update_cfg`.

    Returns
    -------
    A DictConfig with updated values that can be used to instantiate the MRIDC Model along with supporting \
    infrastructure.
    """
    if not _HAS_HYDRA:
        logging.error("This function requires Hydra/OmegaConf and it was not installed.")
        sys.exit(1)

    if not (is_dataclass(model_cls) or isinstance(model_cls, DictConfig)):
        raise ValueError("`model_cfg` must be a dataclass or a structured OmegaConf object")

    if not isinstance(update_cfg, DictConfig):
        update_cfg = OmegaConf.create(update_cfg)

    if is_dataclass(model_cls):
        model_cls = OmegaConf.structured(model_cls)

    # Update optional configs.
    model_cls = _update_subconfig(
        model_cls, update_cfg, subconfig_key="train_ds", drop_missing_subconfigs=drop_missing_subconfigs
    )
    model_cls = _update_subconfig(
        model_cls, update_cfg, subconfig_key="validation_ds", drop_missing_subconfigs=drop_missing_subconfigs
    )
    model_cls = _update_subconfig(
        model_cls, update_cfg, subconfig_key="test_ds", drop_missing_subconfigs=drop_missing_subconfigs
    )
    model_cls = _update_subconfig(
        model_cls, update_cfg, subconfig_key="optim", drop_missing_subconfigs=drop_missing_subconfigs
    )

    # Add optim and sched additional keys to model cls.
    model_cls = _add_subconfig_keys(model_cls, update_cfg, subconfig_key="optim")

    # Perform a full merge of the model config class and the update config.
    # Remove the ModelPT artifact `target`.
    if "target" in update_cfg.model and "target" not in model_cls.model:  # type: ignore
        with open_dict(update_cfg.model):
            update_cfg.model.pop("target")

    # Remove the ModelPT artifact `mridc_version`.
    if "mridc_version" in update_cfg.model and "mridc_version" not in model_cls.model:  # type: ignore
        with open_dict(update_cfg.model):
            update_cfg.model.pop("mridc_version")

    return OmegaConf.merge(model_cls, update_cfg)

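# Illustrative sketch (the config subclass and the override keys are hypothetical): the
# defaults of an MRIDCConfig subclass are merged with a user-provided override that
# mirrors its structure.
#
# >>> update = OmegaConf.create(
# ...     {"model": {"optim": {"name": "adam", "lr": 1e-3, "weight_decay": 0.01}}}
# ... )
# >>> cfg = update_model_config(MyModelConfig(), update, drop_missing_subconfigs=True)
# >>> cfg.model.optim.weight_decay
# 0.01
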
def assert_dataclass_signature_match(
    cls: "class_type",  # type: ignore
    datacls: "dataclass",  # type: ignore
    ignore_args: Optional[List[str]] = None,
    remap_args: Optional[Dict[str, str]] = None,
):
    """
    Analyses the signature of a provided class and its respective data class, asserting that the dataclass signature
    matches the class __init__ signature.

    Note: This is not a value based check. This function only checks if all argument names exist on both class and
    dataclass and logs mismatches.

    Parameters
    ----------
    cls: Any class type - but not an instance of a class. Pass type(x) where x is an instance if the class type is
        not easily available.
    datacls: A corresponding dataclass for the above class.
    ignore_args: (Optional) A list of string argument names which are forcibly ignored, even if mismatched in the
        signature. Useful when a dataclass is a superset of the arguments of a class.
    remap_args: (Optional) A dictionary, mapping an argument name that exists (in either the class or its dataclass),
        to another name. Useful when argument names are mismatched between a class and its dataclass due to indirect
        instantiation via a helper method.

    Returns
    -------
    A tuple containing information about the analysis:
    1) A bool value which is True if the signatures matched exactly / after ignoring values. False otherwise.
    2) A set of argument names that exist in the class, but *do not* exist in the dataclass. If an exact signature
        match occurs, this will be None instead.
    3) A set of argument names that exist in the data class, but *do not* exist in the class itself. If an exact
        signature match occurs, this will be None instead.
    """
    class_sig = inspect.signature(cls.__init__)

    class_params = dict(**class_sig.parameters)
    class_params.pop("self")

    dataclass_sig = inspect.signature(datacls)

    dataclass_params = dict(**dataclass_sig.parameters)
    dataclass_params.pop("_target_", None)

    class_params = set(class_params.keys())  # type: ignore
    dataclass_params = set(dataclass_params.keys())  # type: ignore

    if remap_args is not None:
        for original_arg, new_arg in remap_args.items():
            if original_arg in class_params:
                class_params.remove(original_arg)  # type: ignore
                class_params.add(new_arg)  # type: ignore
                logging.info(f"Remapped {original_arg} -> {new_arg} in {cls.__name__}")

            if original_arg in dataclass_params:
                dataclass_params.remove(original_arg)  # type: ignore
                dataclass_params.add(new_arg)  # type: ignore
                logging.info(f"Remapped {original_arg} -> {new_arg} in {datacls.__name__}")

    if ignore_args is not None:
        ignore_args = set(ignore_args)  # type: ignore

        class_params = class_params - ignore_args  # type: ignore
        dataclass_params = dataclass_params - ignore_args  # type: ignore
        logging.info(f"Removing ignored arguments - {ignore_args}")

    intersection: Set[type] = set.intersection(class_params, dataclass_params)  # type: ignore
    subset_cls = class_params - intersection  # type: ignore
    subset_datacls = dataclass_params - intersection  # type: ignore

    if (len(class_params) != len(dataclass_params)) or len(subset_cls) > 0 or len(subset_datacls) > 0:
        logging.error(f"Class {cls.__name__} arguments do not match Dataclass {datacls.__name__}!")

        if len(subset_cls) > 0:
            logging.error(f"Class {cls.__name__} has additional arguments :\n{subset_cls}")

        if len(subset_datacls) > 0:
            logging.error(f"Dataclass {datacls.__name__} has additional arguments :\n{subset_datacls}")

        return False, subset_cls, subset_datacls
    return True, None, None

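# Illustrative sketch (the class and dataclass below are made up, and `dataclass` is
# assumed to be imported): a matching __init__ / dataclass field set returns (True, None, None).
#
# >>> class Linear:
# ...     def __init__(self, in_features: int, out_features: int): ...
# >>> @dataclass
# ... class LinearConfig:
# ...     in_features: int = 1
# ...     out_features: int = 1
# >>> assert_dataclass_signature_match(Linear, LinearConfig)
# (True, None, None)
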
def from_config_dict(cls, config: "DictConfig", trainer: Optional[Trainer] = None):
    """Instantiates an object using DictConfig-based configuration."""
    # Resolve the config dict.
    if _HAS_HYDRA:
        if isinstance(config, DictConfig):
            config = OmegaConf.to_container(config, resolve=True)
            config = OmegaConf.create(config)
            OmegaConf.set_struct(config, True)

        config = mridc.utils.model_utils.maybe_update_config_version(config)  # type: ignore

    # Hydra 0.x API
    if ("cls" in config or "target" in config) and "params" in config and _HAS_HYDRA:
        # regular hydra-based instantiation
        instance = hydra.utils.instantiate(config=config)
    # Hydra 1.x API
    elif "_target_" in config and _HAS_HYDRA:
        # regular hydra-based instantiation
        instance = hydra.utils.instantiate(config=config)
    else:
        instance = None
        prev_error = ""

        # Attempt class path resolution from the config `target` class (if it exists).
        if "target" in config:
            target_cls = config["target"]  # No guarantee that this is an omegaconf class
            imported_cls = None
            try:
                # try to import the target class
                imported_cls = mridc.utils.model_utils.import_class_by_path(target_cls)  # type: ignore

                # if the calling class (cls) is a subclass of the imported class, use the subclass instead
                if issubclass(cls, imported_cls):
                    imported_cls = cls

                if accepts_trainer := Serialization._inspect_signature_for_trainer(imported_cls):
                    if trainer is None:
                        # Create a dummy PL trainer object
                        cfg_trainer = TrainerConfig(
                            gpus=1, accelerator="ddp", num_nodes=1, logger=False, checkpoint_callback=False
                        )
                        trainer = Trainer(cfg_trainer)

                    instance = imported_cls(cfg=config, trainer=trainer)  # type: ignore
                else:
                    instance = imported_cls(cfg=config)  # type: ignore
            except Exception as e:
                tb = traceback.format_exc()
                prev_error = f"Model instantiation failed.\nTarget class: {target_cls}\nError: {e}\n{tb}"
                logging.debug(prev_error + "\n falling back to 'cls'.")

        # target class resolution was unsuccessful, fall back to the current `cls`
        if instance is None:
            try:
                if accepts_trainer := Serialization._inspect_signature_for_trainer(cls):
                    instance = cls(cfg=config, trainer=trainer)  # type: ignore
                else:
                    instance = cls(cfg=config)  # type: ignore
            except Exception as e:
                # report saved errors, if any, and raise the current error
                if prev_error:
                    logging.error(f"{prev_error}")
                raise e from e

    if not hasattr(instance, "_cfg"):
        instance._cfg = config
    return instance

def resolve_test_dataloaders(model: "ModelPT"):
    """
    Helper method that operates on the ModelPT class to automatically support multiple dataloaders for the test set.

    It does so by first resolving the path to one/more data files via `resolve_dataset_name_from_cfg()`. If this
    resolution fails, it assumes the data loader is prepared to manually support / not support multiple data loaders
    and simply calls the appropriate setup method.

    If resolution succeeds:
    - Checks if the provided path is to a single file or a list of files.
    - If a single file is provided, simply tags that file as such and loads it via the setup method.
    - If multiple files are provided:
        - Injects a new manifest path at index "i" into the resolved key.
        - Calls the appropriate setup method to set the data loader.
        - Collects the initialized data loader in a list and preserves it.
    - Once all data loaders are processed, assigns the list of loaded loaders to the ModelPT.
    - Finally, assigns a list of unique names resolved from the file paths to the ModelPT.

    Parameters
    ----------
    model: ModelPT subclass, which requires >=1 Test Dataloaders to be set up.
    """
    if not _HAS_HYDRA:
        logging.error("This function requires Hydra/OmegaConf and it was not installed.")
        sys.exit(1)

    cfg = copy.deepcopy(model._cfg)
    dataloaders: List[Any] = []

    # Process test_dl_idx.
    if "test_dl_idx" in cfg.test_ds:
        cfg = OmegaConf.to_container(cfg)
        test_dl_idx = cfg["test_ds"].pop("test_dl_idx")
        cfg = OmegaConf.create(cfg)
    else:
        test_dl_idx = 0

    # Set test_dl_idx.
    model._test_dl_idx = test_dl_idx

    ds_key = resolve_dataset_name_from_cfg(cfg.test_ds)
    if ds_key is None:
        logging.debug(
            f"Could not resolve file path from provided config - {cfg.test_ds}. "
            "Disabling support for multi-dataloaders."
        )
        model.setup_test_data(cfg.test_ds)
        return

    ds_values = cfg.test_ds[ds_key]

    if isinstance(ds_values, (list, tuple, ListConfig)):
        for ds_value in ds_values:
            cfg.test_ds[ds_key] = ds_value
            model.setup_test_data(cfg.test_ds)
            dataloaders.append(model.test_dl)

        model.test_dl = dataloaders  # type: ignore
        model.test_names = [parse_dataset_as_name(ds) for ds in ds_values]  # type: ignore

        unique_names_check(name_list=model.test_names)
        return

    model.setup_test_data(cfg.test_ds)
    model.test_names = [parse_dataset_as_name(ds_values)]

    unique_names_check(name_list=model.test_names)

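# Illustrative sketch (the `data_path` key and file paths are hypothetical, and assumed to
# be the entry resolved by `resolve_dataset_name_from_cfg`): a list-valued dataset entry in
# `test_ds` yields one test dataloader and one resolved name per file.
#
# >>> model._cfg.test_ds.data_path = ["/data/test_a.h5", "/data/test_b.h5"]
# >>> resolve_test_dataloaders(model)
# >>> len(model.test_dl)
# 2
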
def configure_checkpointing(trainer: Trainer, log_dir: Path, name: str, resume: bool, params: "DictConfig"):
    """
    Adds a ModelCheckpoint callback to the trainer.

    Raises CheckpointMisconfigurationError if the trainer already has a ModelCheckpoint callback or if
    trainer.weights_save_path was passed to the Trainer.
    """
    for callback in trainer.callbacks:
        if isinstance(callback, ModelCheckpoint):
            raise CheckpointMisconfigurationError(
                "The pytorch lightning trainer that was passed to exp_manager contained a ModelCheckpoint "
                "and create_checkpoint_callback was set to True. Please either set create_checkpoint_callback "
                "to False, or remove ModelCheckpoint from the lightning trainer"
            )

    if Path(trainer.weights_save_path) != Path.cwd():
        raise CheckpointMisconfigurationError(
            "The pytorch lightning trainer was passed weights_save_path. This variable is ignored by exp_manager"
        )

    # Create the callback and attach it to the trainer.
    if "filepath" in params:
        if params.filepath is not None:
            logging.warning("filepath is deprecated. Please switch to dirpath and filename instead")
            if params.dirpath is None:
                params.dirpath = Path(params.filepath).parent
            if params.filename is None:
                params.filename = Path(params.filepath).name
        with open_dict(params):
            del params["filepath"]

    if params.dirpath is None:
        params.dirpath = Path(log_dir / "checkpoints")
    if params.filename is None:
        params.filename = f"{name}--{{{params.monitor}:.4f}}-{{epoch}}"
    if params.prefix is None:
        params.prefix = name

    MRIDCModelCheckpoint.CHECKPOINT_NAME_LAST = f"{params.filename}-last"

    logging.debug(params.dirpath)
    logging.debug(params.filename)
    logging.debug(params.prefix)

    if "val" in params.monitor:
        if (
            trainer.max_epochs is not None
            and trainer.max_epochs != -1
            and trainer.max_epochs < trainer.check_val_every_n_epoch
        ):
            logging.error(
                "The checkpoint callback was told to monitor a validation value but trainer.max_epochs("
                f"{trainer.max_epochs}) was less than trainer.check_val_every_n_epoch("
                f"{trainer.check_val_every_n_epoch}). It is very likely this run will fail with "
                f"ModelCheckpoint(monitor='{params.monitor}') not found in the returned metrics. Please ensure that "
                "validation is run within trainer.max_epochs."
            )
        elif trainer.max_steps is not None:
            logging.warning(
                "The checkpoint callback was told to monitor a validation value and trainer's max_steps was set to "
                f"{trainer.max_steps}. Please ensure that max_steps will run for at least "
                f"{trainer.check_val_every_n_epoch} epochs to ensure that checkpointing will not error out."
            )

    checkpoint_callback = MRIDCModelCheckpoint(n_resume=resume, **params)
    checkpoint_callback.last_model_path = trainer._checkpoint_connector.resume_from_checkpoint_fit_path or ""
    if "mp_rank" in checkpoint_callback.last_model_path or "tp_rank" in checkpoint_callback.last_model_path:
        checkpoint_callback.last_model_path = mridc.utils.model_utils.uninject_model_parallel_rank(  # type: ignore
            checkpoint_callback.last_model_path
        )
    trainer.callbacks.append(checkpoint_callback)

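# Illustrative sketch (values are examples): the `params` DictConfig consumed above is
# expected to carry the usual pytorch lightning ModelCheckpoint arguments plus the
# dirpath/filename/prefix keys accessed in this function.
#
# >>> params = OmegaConf.create(
# ...     {"monitor": "val_loss", "mode": "min", "save_top_k": 3,
# ...      "dirpath": None, "filename": None, "prefix": None}
# ... )
# >>> configure_checkpointing(trainer, log_dir, "default", resume=False, params=params)
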
def exp_manager(trainer: Trainer, cfg: Optional[Union[DictConfig, Dict]] = None) -> Optional[Path]:
    """
    exp_manager is a helper function used to manage folders for experiments. It follows the pytorch lightning \
    paradigm of exp_dir/model_or_experiment_name/version. If the lightning trainer has a logger, exp_manager will \
    get exp_dir, name, and version from the logger. Otherwise, it will use the exp_dir and name arguments to create \
    the logging directory. exp_manager also allows for explicit folder creation via explicit_log_dir.

    The version can be a datetime string or an integer. Datetime version can be disabled if use_datetime_version \
    is set to False. It optionally creates TensorBoardLogger, WandBLogger, ModelCheckpoint objects from pytorch \
    lightning. It copies sys.argv, and git information if available to the logging directory. It creates a log file \
    for each process to log their output into.

    exp_manager additionally has a resume feature (resume_if_exists) which can be used to continue training from \
    the constructed log_dir. When you need to continue the training repeatedly (like on a cluster on which you need \
    multiple consecutive jobs), you need to avoid creating the version folders. Therefore, from v1.0.0, when \
    resume_if_exists is set to True, creating the version folders is ignored.

    Parameters
    ----------
    trainer: The lightning trainer object.
    cfg: Can have the following keys:

        - explicit_log_dir: Can be used to override exp_dir/name/version folder creation. Defaults to None, which \
            will use exp_dir, name, and version to construct the logging directory.
        - exp_dir: The base directory to create the logging directory. Defaults to None, which logs to \
            ./mridc_experiments.
        - name: The name of the experiment. Defaults to None which turns into "default" via name = name or "default".
        - version: The version of the experiment. Defaults to None which uses either a datetime string or lightning's \
            TensorboardLogger system of using version_{int}.
        - use_datetime_version: Whether to use a datetime string for version. Defaults to True.
        - resume_if_exists: Whether this experiment is resuming from a previous run. If True, it sets \
            trainer._checkpoint_connector.resume_from_checkpoint_fit_path so that the trainer should auto-resume. \
            exp_manager will move files under log_dir to log_dir/run_{int}. Defaults to False. From v1.0.0, when \
            resume_if_exists is True, we would not create version folders to make it easier to find the log folder \
            for next runs.
        - resume_past_end: exp_manager errors out if resume_if_exists is True and a checkpoint matching \*end.ckpt \
            indicating a previous training run fully completed exists. This behaviour can be disabled, in which case \
            the \*end.ckpt will be loaded, by setting resume_past_end to True. Defaults to False.
        - resume_ignore_no_checkpoint: exp_manager errors out if resume_if_exists is True and no checkpoint could be \
            found. This behaviour can be disabled, in which case exp_manager will print a message and continue \
            without restoring, by setting resume_ignore_no_checkpoint to True. Defaults to False.
        - create_tensorboard_logger: Whether to create a tensorboard logger and attach it to the pytorch lightning \
            trainer. Defaults to True.
        - summary_writer_kwargs: A dictionary of kwargs that can be passed to lightning's TensorboardLogger class. \
            Note that log_dir is passed by exp_manager and cannot exist in this dict. Defaults to None.
        - create_wandb_logger: Whether to create a Weights and Biases logger and attach it to the pytorch lightning \
            trainer. Defaults to False.
        - wandb_logger_kwargs: A dictionary of kwargs that can be passed to lightning's WandBLogger class. Note that \
            name and project are required parameters if create_wandb_logger is True. Defaults to None.
        - create_checkpoint_callback: Whether to create a ModelCheckpoint callback and attach it to the pytorch \
            lightning trainer. The ModelCheckpoint saves the top 3 models with the best "val_loss", the most recent \
            checkpoint under \*last.ckpt, and the final checkpoint after training completes under \*end.ckpt. \
            Defaults to True.
        - files_to_copy: A list of files to copy to the experiment logging directory. Defaults to None which copies \
            no files.
        - log_local_rank_0_only: Whether to only create log files for local rank 0. Defaults to False. Set this to \
            True if you are using DDP with many GPUs and do not want many log files in your exp dir.
        - log_global_rank_0_only: Whether to only create log files for global rank 0. Defaults to False. Set this to \
            True if you are using DDP with many GPUs and do not want many log files in your exp dir.

    Returns
    -------
    The final logging directory where logging files are saved. Usually the concatenation of exp_dir, name, and version.
    """
    # Add rank information to logger.
    # Note: trainer.global_rank and trainer.is_global_zero are not set until trainer.fit, so have to hack around it.
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    global_rank = trainer.node_rank * trainer.num_devices + local_rank
    logging.rank = global_rank

    if cfg is None:
        logging.error("exp_manager did not receive a cfg argument. It will be disabled.")
        return None

    if trainer.fast_dev_run:
        logging.info("Trainer was called with fast_dev_run. exp_manager will return without any functionality.")
        return None

    # Ensure the passed cfg is compliant with ExpManagerConfig.
    schema = OmegaConf.structured(ExpManagerConfig)
    if isinstance(cfg, dict):
        cfg = OmegaConf.create(cfg)
    elif not isinstance(cfg, DictConfig):
        raise ValueError(f"cfg was type: {type(cfg)}. Expected either a dict or a DictConfig")

    cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=True))
    cfg = OmegaConf.merge(schema, cfg)

    # Ensures that trainer options are compliant with MRIDC and exp_manager arguments.
    error_checks(trainer, cfg)

    log_dir, exp_dir, name, version = get_log_dir(
        trainer=trainer,
        exp_dir=cfg.exp_dir,
        name=cfg.name,
        version=cfg.version,
        explicit_log_dir=cfg.explicit_log_dir,
        use_datetime_version=cfg.use_datetime_version,
        resume_if_exists=cfg.resume_if_exists,
    )

    if cfg.resume_if_exists:
        check_resume(trainer, str(log_dir), cfg.resume_past_end, cfg.resume_ignore_no_checkpoint)

    checkpoint_name = name
    # If the name returned from get_log_dir is "", use cfg.name for checkpointing.
    if checkpoint_name is None or checkpoint_name == "":
        checkpoint_name = cfg.name or "default"
    cfg.name = name  # Used for configure_loggers so that the log_dir is properly set even if name is ""
    cfg.version = version

    # Update app_state with log_dir, exp_dir, etc.
    app_state = AppState()
    app_state.log_dir = log_dir
    app_state.exp_dir = exp_dir
    app_state.name = name
    app_state.version = version
    app_state.checkpoint_name = checkpoint_name
    app_state.create_checkpoint_callback = cfg.create_checkpoint_callback
    app_state.checkpoint_callback_params = cfg.checkpoint_callback_params

    # Create the logging directory if it does not exist.
    # Cannot limit creation to global zero as all ranks write to their own log file.
    os.makedirs(log_dir, exist_ok=True)
    logging.info(f"Experiments will be logged at {log_dir}")
    trainer._default_root_dir = log_dir

    if cfg.log_local_rank_0_only is True and cfg.log_global_rank_0_only is True:
        raise ValueError(
            "Cannot set both log_local_rank_0_only and log_global_rank_0_only to True. "
            "Please set either one or neither."
        )

    # This is set if the env var MRIDC_TESTING is set to True.
    mridc_testing = get_envbool(MRIDC_ENV_VARNAME_TESTING, False)

    log_file = log_dir / f"mridc_log_globalrank-{global_rank}_localrank-{local_rank}.txt"

    # Handle logging to file. Logs local rank 0 only.
    if local_rank == 0 and cfg.log_local_rank_0_only and not mridc_testing:
        logging.add_file_handler(log_file)
    elif global_rank == 0 and cfg.log_global_rank_0_only and mridc_testing:
        logging.add_file_handler(log_file)
    else:
        logging.add_file_handler(log_file)

    # For some reason, LearningRateLogger requires trainer to have a logger. Safer to create the logger on all
    # ranks, not just global rank 0.
    if cfg.create_tensorboard_logger or cfg.create_wandb_logger:
        configure_loggers(
            trainer,
            [Path(exp_dir)],
            cfg.name,
            cfg.version,
            cfg.create_tensorboard_logger,
            cfg.summary_writer_kwargs,
            cfg.create_wandb_logger,
            cfg.wandb_logger_kwargs,
        )

    # Add loggers timing callbacks.
    if cfg.log_step_timing:
        timing_callback = TimingCallback(timer_kwargs=cfg.step_timing_kwargs or {})
        trainer.callbacks.insert(0, timing_callback)

    if cfg.create_checkpoint_callback:
        configure_checkpointing(
            trainer, log_dir, checkpoint_name, cfg.resume_if_exists, cfg.checkpoint_callback_params
        )

    if is_global_rank_zero():
        # Move files_to_copy to the log folder and add git information if present.
        if cfg.files_to_copy:
            for _file in cfg.files_to_copy:
                copy(Path(_file), log_dir)

        # Create files for cmd args and git info.
        with open(log_dir / "cmd-args.log", "w", encoding="utf-8") as _file:
            _file.write(" ".join(sys.argv))

        # Try to get the git hash.
        git_repo, git_hash = get_git_hash()
        if git_repo:
            with open(log_dir / "git-info.log", "w", encoding="utf-8") as _file:
                _file.write(f"commit hash: {git_hash}")
                _file.write(get_git_diff())

        # Add err_file logging to global_rank zero.
        logging.add_err_file_handler(log_dir / "mridc_error_log.txt")

        # Add lightning file logging to global_rank zero.
        add_filehandlers_to_pl_logger(log_dir / "lightning_logs.txt", log_dir / "mridc_error_log.txt")

    return log_dir

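# Illustrative usage sketch (directory and experiment names are examples): the cfg dict
# mirrors the documented keys above and is validated against ExpManagerConfig.
#
# >>> trainer = Trainer(max_epochs=10)
# >>> log_dir = exp_manager(
# ...     trainer,
# ...     {"exp_dir": "./mridc_experiments", "name": "my_experiment", "create_tensorboard_logger": True},
# ... )
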
def export(
    self,
    output: str,
    input_example=None,
    verbose=False,
    export_params=True,
    do_constant_folding=True,
    onnx_opset_version=None,
    try_script: bool = False,
    training=TrainingMode.EVAL,
    check_trace: bool = False,
    use_dynamic_axes: bool = True,
    dynamic_axes=None,
    check_tolerance=0.01,
):
    """
    Export the module to a file.

    Parameters
    ----------
    output: The output file path.
    input_example: A dictionary of input names and values.
    verbose: If True, print out the export process.
    export_params: If True, export the parameters of the module.
    do_constant_folding: If True, do constant folding.
    onnx_opset_version: The ONNX opset version to use.
    try_script: If True, try to export as TorchScript.
    training: Training mode for the export.
    check_trace: If True, check the trace of the exported model.
    use_dynamic_axes: If True, use dynamic axes for the export.
    dynamic_axes: A dictionary of input names and dynamic axes.
    check_tolerance: The tolerance for the check_trace.
    """
    my_args = locals().copy()
    my_args.pop("self")

    exportables = []
    for m in self.modules():  # type: ignore
        if isinstance(m, Exportable):
            exportables.append(m)

    qual_name = self.__module__ + "." + self.__class__.__qualname__
    format = get_export_format(output)
    output_descr = f"{qual_name} exported to {format}"

    # Pytorch's default for None is too low, can't pass None through.
    if onnx_opset_version is None:
        onnx_opset_version = 13

    try:
        # Disable typechecks.
        typecheck.set_typecheck_enabled(enabled=False)

        # Allow the user to completely override the forward method to export.
        forward_method, old_forward_method = wrap_forward_method(self)

        # Set module mode.
        with torch.onnx.select_model_mode_for_export(
            self, training
        ), torch.inference_mode(), torch.jit.optimized_execution(True):
            if input_example is None:
                input_example = self.input_module.input_example()

            # Remove i/o examples from args we propagate to enclosed Exportables.
            my_args.pop("output")
            my_args.pop("input_example")

            # Run (possibly overridden) prepare methods before calling forward().
            for ex in exportables:
                ex._prepare_for_export(**my_args, noreplace=True)
            self._prepare_for_export(output=output, input_example=input_example, **my_args)

            input_list, input_dict = parse_input_example(input_example)
            input_names = self.input_names
            output_names = self.output_names
            output_example = tuple(self.forward(*input_list, **input_dict))  # type: ignore

            jitted_model = None
            if try_script:
                try:
                    jitted_model = torch.jit.script(self)
                except Exception as e:
                    logging.error(f"jit.script() failed!\n{e}")

            if format == ExportFormat.TORCHSCRIPT:
                if jitted_model is None:
                    jitted_model = torch.jit.trace_module(
                        self,
                        {"forward": tuple(input_list) + tuple(input_dict.values())},
                        strict=True,
                        check_trace=check_trace,
                        check_tolerance=check_tolerance,
                    )
                if not self.training:  # type: ignore
                    jitted_model = torch.jit.optimize_for_inference(jitted_model)
                if verbose:
                    logging.info(f"JIT code:\n{jitted_model.code}")
                jitted_model.save(output)
            elif format == ExportFormat.ONNX:
                if jitted_model is None:
                    jitted_model = self

                # dynamic axes is a mapping from input/output_name => list of "dynamic" indices
                if dynamic_axes is None and use_dynamic_axes:
                    dynamic_axes = get_dynamic_axes(self.input_module.input_types, input_names)
                    dynamic_axes.update(get_dynamic_axes(self.output_module.output_types, output_names))

                torch.onnx.export(
                    jitted_model,
                    input_example,
                    output,
                    input_names=input_names,
                    output_names=output_names,
                    verbose=verbose,
                    export_params=export_params,
                    do_constant_folding=do_constant_folding,
                    dynamic_axes=dynamic_axes,
                    opset_version=onnx_opset_version,
                )

                if check_trace:
                    verify_runtime(output, input_list, input_dict, input_names, output_names, output_example)
            else:
                raise ValueError(f"Encountered unknown export format {format}.")
    finally:
        typecheck.set_typecheck_enabled(enabled=True)
        if forward_method:
            type(self).forward = old_forward_method  # type: ignore
        self._export_teardown()
    return [output], [output_descr]

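# Illustrative usage sketch (model class and file name are hypothetical): exporting an
# Exportable MRIDC module to ONNX, with the export format inferred from the file extension.
#
# >>> model = MyExportableModel(cfg=cfg)
# >>> paths, descriptions = model.export("my_model.onnx", check_trace=True)
# >>> paths
# ['my_model.onnx']
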