def restore_from(
    self,
    calling_cls,
    restore_path: str,
    override_config_path: Optional[Union[OmegaConf, str]] = None,
    map_location: Optional[torch.device] = None,
    strict: bool = True,
    return_config: bool = False,
    trainer: Trainer = None,
):
    """
    Restores a model instance (weights and configuration) from a .mridc file.

    Parameters
    ----------
    calling_cls: The class of the model to be restored.
    restore_path: Path to the .mridc file from which the model should be instantiated.
    override_config_path: Path to a yaml config that will override the internal config file or an
        OmegaConf/DictConfig object representing the model config.
    map_location: Optional torch.device() to map the instantiated model to a device. By default (None), it will
        select a GPU if available, falling back to CPU otherwise.
    strict: Passed to load_state_dict. By default, True.
    return_config: If set to True, will return just the underlying config of the restored model as an
        OmegaConf/DictConfig object without instantiating the model.
    trainer: Optional trainer object to be used for restoring the model.

    Returns
    -------
    An instance of type cls or its underlying config (if return_config is set).
    """
    # Load the config and state dict. Artifacts are "retrieved" relative to the directory from which the command
    # is executed (original .mridc behavior).
    loaded_params = self.load_config_and_state_dict(
        calling_cls,
        restore_path,
        override_config_path,
        map_location,
        strict,
        return_config,
        trainer,
    )
    if not isinstance(loaded_params, tuple):
        return loaded_params
    _, instance, state_dict = loaded_params
    self.load_instance_with_state_dict(instance, state_dict, strict)
    logging.info(f"Model {instance.__class__.__name__} was successfully restored from {restore_path}.")
    return instance
def load_state_dict(self, state_dict): """Load the state of the optimizer.""" # Optimizer. optimizer_key = "optimizer" if optimizer_key not in state_dict: optimizer_key = "optimizer_state_dict" logging.info("***WARNING*** loading optimizer from " "an old checkpoint ...") self.optimizer.load_state_dict(state_dict[optimizer_key]) # Copy data for the main params. fp32_from_float16_params_key = "fp32_from_fp16_params" if fp32_from_float16_params_key not in state_dict: fp32_from_float16_params_key = "fp32_from_fp16" for current_group, saved_group in zip( self.fp32_from_float16_groups, state_dict[fp32_from_float16_params_key]): for current_param, saved_param in zip(current_group, saved_group): current_param.data.copy_(saved_param.data)
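# Illustrative sketch (not part of the library): the optimizer loader above accepts both the current and the
# legacy checkpoint layouts. A state dict shaped like either of the following would be restored, with the legacy
# keys ("optimizer_state_dict" / "fp32_from_fp16") triggering the warning path. The tensors and group structure
# below are made up for illustration only.
import torch

current_style_ckpt = {
    "optimizer": {"state": {}, "param_groups": []},              # preferred key
    "fp32_from_fp16_params": [[torch.zeros(2, 2)]],              # one group with one fp32 master weight
}

legacy_style_ckpt = {
    "optimizer_state_dict": {"state": {}, "param_groups": []},   # old checkpoints
    "fp32_from_fp16": [[torch.zeros(2, 2)]],
}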
def initialize_distributed(args, backend="nccl"): """ Initialize distributed training. Parameters ---------- args: The arguments object. backend: The backend to use. default: "nccl" Returns ------- local_rank: The local rank of the process. rank: The rank of the process. world_size: The number of processes. """ # Get local rank in case it is provided. local_rank = args.local_rank # Get rank and world size. rank = int(os.getenv("RANK", "0")) world_size = int(os.getenv("WORLD_SIZE", "1")) logging.info( f"Initializing torch.distributed with local_rank: {local_rank}, rank: {rank}, world_size: {world_size}" ) # Set the device id. device = rank % torch.cuda.device_count() if local_rank is not None: device = local_rank torch.cuda.set_device(device) # Call the init process. init_method = "tcp://" master_ip = os.getenv("MASTER_ADDR", "localhost") master_port = os.getenv("MASTER_PORT", "6000") init_method += f"{master_ip}:{master_port}" torch.distributed.init_process_group(backend=backend, world_size=world_size, rank=rank, init_method=init_method) return local_rank, rank, world_size
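# Usage sketch for initialize_distributed(), assuming the standard torch.distributed environment variables are
# provided by the launcher (e.g. torchrun sets RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT). The argparse
# namespace below is an illustrative assumption, not part of the function above. Note that the function calls
# torch.cuda.set_device(), so at least one visible CUDA device is required.
import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int, default=None)
    args = parser.parse_args()

    local_rank, rank, world_size = initialize_distributed(args, backend="nccl")
    print(f"local_rank={local_rank}, rank={rank}, world_size={world_size}")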
def _del_model_without_trainer(self, filepath: str) -> None: """ Delete a model without a trainer. Parameters ---------- filepath: The path to the model to delete. """ app_state = AppState() if app_state.model_parallel_size is not None and app_state.model_parallel_size > 1: # filepath needs to be updated to include mp_rank filepath = mridc.utils.model_utils.inject_model_parallel_rank( filepath) # type: ignore # each model parallel rank needs to remove its model if is_global_rank_zero() or (app_state.model_parallel_size is not None and app_state.data_parallel_rank == 0): try: self._fs.rm(filepath) logging.info(f"Removed checkpoint: {filepath}") except FileNotFoundError: logging.info( f"Tried to remove checkpoint: {filepath} but failed.")
def main(cfg: DictConfig) -> None:
    """
    Main function for training and running a model.

    Parameters
    ----------
    cfg: Configuration (yaml) file.
        DictConfig
    """
    logging.info(f"Config: {OmegaConf.to_yaml(cfg)}")

    trainer = pl.Trainer(**cfg.trainer)
    exp_manager(trainer, cfg.get("exp_manager", None))

    model_name = (cfg.model["model_name"]).upper()

    if model_name == "CASCADENET":
        model = CascadeNet(cfg.model, trainer=trainer)
    elif model_name == "CIRIM":
        model = CIRIM(cfg.model, trainer=trainer)
    elif model_name == "CRNNET":
        model = CRNNet(cfg.model, trainer=trainer)
    elif model_name == "DUNET":
        model = DUNet(cfg.model, trainer=trainer)
    elif model_name in ("E2EVN", "VN"):
        model = VarNet(cfg.model, trainer=trainer)
    elif model_name == "JOINTICNET":
        model = JointICNet(cfg.model, trainer=trainer)
    elif model_name == "KIKINET":
        model = KIKINet(cfg.model, trainer=trainer)
    elif model_name == "LPDNET":
        model = LPDNet(cfg.model, trainer=trainer)
    elif model_name == "MULTIDOMAINNET":
        model = MultiDomainNet(cfg.model, trainer=trainer)
    elif model_name == "PICS":
        model = PICS(cfg.model, trainer=trainer)
    elif model_name == "RVN":
        model = RecurrentVarNet(cfg.model, trainer=trainer)
    elif model_name == "UNET":
        model = UNet(cfg.model, trainer=trainer)
    elif model_name == "VSNET":
        model = VSNet(cfg.model, trainer=trainer)
    elif model_name == "XPDNET":
        model = XPDNet(cfg.model, trainer=trainer)
    elif model_name == "ZF":
        model = ZF(cfg.model, trainer=trainer)
    else:
        raise NotImplementedError(
            f"{model_name} is not implemented in MRIDC. You can use one of the following methods: "
            "CASCADENET, CIRIM, CRNNET, DUNET, E2EVN, JOINTICNET, KIKINET, LPDNET, MULTIDOMAINNET, PICS, RVN, "
            "UNET, VSNET, XPDNET, or ZF (Zero-Filled).\n"
            "If you implemented a new model, please consider adding it through a PR on GitHub."
        )

    if cfg.get("pretrained", None):
        checkpoint = cfg.get("checkpoint", None)
        logging.info(f"Loading pretrained model from {checkpoint}")
        model.load_state_dict(torch.load(checkpoint)["state_dict"])

    if cfg.get("mode", None) == "train":
        logging.info("Validating")
        trainer.validate(model)
        logging.info("Training")
        trainer.fit(model)
    else:
        logging.info("Evaluating")
        trainer.test(model)
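# Minimal configuration sketch for main(), built programmatically instead of from a yaml file. Only the keys
# that main() itself reads (trainer, model.model_name, mode, pretrained/checkpoint, exp_manager) are shown; all
# concrete values below are assumptions, and a real run needs the full model configuration expected by the
# chosen reconstruction method.
from omegaconf import OmegaConf

example_cfg = OmegaConf.create(
    {
        "trainer": {"max_epochs": 1, "devices": 1, "accelerator": "gpu"},
        "exp_manager": {"exp_dir": "./mridc_experiments", "name": "cirim_example"},
        "model": {"model_name": "CIRIM"},  # selects the CIRIM branch above
        "mode": "train",                   # anything else runs trainer.test(model)
        "pretrained": False,               # set True and provide "checkpoint" to warm-start
    }
)
# main(example_cfg)  # uncomment only once the full model config is filled in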
def assert_dataclass_signature_match( cls: "class_type", # type: ignore datacls: "dataclass", # type: ignore ignore_args: Optional[List[str]] = None, remap_args: Optional[Dict[str, str]] = None, ): """ Analyses the signature of a provided class and its respective data class, asserting that the dataclass signature matches the class __init__ signature. Note: This is not a value based check. This function only checks if all argument names exist on both class and dataclass and logs mismatches. Parameters ---------- cls: Any class type - but not an instance of a class. Pass type(x) where x is an instance if class type is not easily available. datacls: A corresponding dataclass for the above class. ignore_args: (Optional) A list of string argument names which are forcibly ignored, even if mismatched in the signature. Useful when a dataclass is a superset of the arguments of a class. remap_args: (Optional) A dictionary, mapping an argument name that exists (in either the class or its dataclass), to another name. Useful when argument names are mismatched between a class and its dataclass due to indirect instantiation via a helper method. Returns ------- A tuple containing information about the analysis: 1) A bool value which is True if the signatures matched exactly / after ignoring values. False otherwise. 2) A set of arguments names that exist in the class, but *do not* exist in the dataclass. If exact signature match occurs, this will be None instead. 3) A set of argument names that exist in the data class, but *do not* exist in the class itself. If exact signature match occurs, this will be None instead. """ class_sig = inspect.signature(cls.__init__) class_params = dict(**class_sig.parameters) class_params.pop("self") dataclass_sig = inspect.signature(datacls) dataclass_params = dict(**dataclass_sig.parameters) dataclass_params.pop("_target_", None) class_params = set(class_params.keys()) # type: ignore dataclass_params = set(dataclass_params.keys()) # type: ignore if remap_args is not None: for original_arg, new_arg in remap_args.items(): if original_arg in class_params: class_params.remove(original_arg) # type: ignore class_params.add(new_arg) # type: ignore logging.info( f"Remapped {original_arg} -> {new_arg} in {cls.__name__}") if original_arg in dataclass_params: dataclass_params.remove(original_arg) # type: ignore dataclass_params.add(new_arg) # type: ignore logging.info( f"Remapped {original_arg} -> {new_arg} in {datacls.__name__}" ) if ignore_args is not None: ignore_args = set(ignore_args) # type: ignore class_params = class_params - ignore_args # type: ignore dataclass_params = dataclass_params - ignore_args # type: ignore logging.info(f"Removing ignored arguments - {ignore_args}") intersection: Set[type] = set.intersection( class_params, dataclass_params) # type: ignore subset_cls = class_params - intersection # type: ignore subset_datacls = dataclass_params - intersection # type: ignore if (len(class_params) != len(dataclass_params) ) or len(subset_cls) > 0 or len(subset_datacls) > 0: logging.error(f"Class {cls.__name__} arguments do not match " f"Dataclass {datacls.__name__}!") if len(subset_cls) > 0: logging.error(f"Class {cls.__name__} has additional arguments :\n" f"{subset_cls}") if len(subset_datacls): logging.error( f"Dataclass {datacls.__name__} has additional arguments :\n{subset_datacls}" ) return False, subset_cls, subset_datacls return True, None, None
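# Small self-contained example of assert_dataclass_signature_match(), using throwaway types defined here for
# illustration. The dataclass deliberately carries an extra `momentum` field, so the helper reports a mismatch
# instead of an exact match.
from dataclasses import dataclass


class ToyOptimizer:
    def __init__(self, lr: float, weight_decay: float = 0.0):
        self.lr = lr
        self.weight_decay = weight_decay


@dataclass
class ToyOptimizerConfig:
    lr: float = 1e-3
    weight_decay: float = 0.0
    momentum: float = 0.9  # not accepted by ToyOptimizer.__init__


matched, cls_only, datacls_only = assert_dataclass_signature_match(ToyOptimizer, ToyOptimizerConfig)
# matched is False, cls_only is an empty set, datacls_only contains {"momentum"}.
# Passing ignore_args=["momentum"] would make the signatures match.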
def verify_runtime( output, input_list, input_dict, input_names, output_names, output_example, check_tolerance=0.01, ): """ Verify runtime output with onnxrt. Parameters ---------- output: The output of the module. input_list: The input list of the module. input_dict: The input dict of the module. input_names: The input names of the module. output_names: The output names of the module. output_example: The output example of the module. check_tolerance: The tolerance for the check. Returns ------- The runtime output. """ # Verify the model can be read, and is valid onnx_model = onnx.load(output) input_names = [node.name for node in onnx_model.graph.input] # skipcq: PYL-W0622 global ort_available if not ort_available: logging.warning( f"ONNX generated at {output}, not verified - please install onnxruntime_gpu package.\n" ) onnx.checker.check_model(onnx_model, full_check=True) return onnx_session_opt = onnxruntime.SessionOptions() onnx_session_opt.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL sess = onnxruntime.InferenceSession(onnx_model.SerializeToString(), sess_options=onnx_session_opt, providers=["CUDAExecutionProvider"]) ort_out = sess.run(output_names, to_onnxrt_input(input_names, input_dict, input_list)) all_good = True for i, out in enumerate(ort_out[0]): expected = output_example[i] if torch.is_tensor(expected): tout = torch.from_numpy(out) if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=100 * check_tolerance): all_good = False logging.info( f"onnxruntime results mismatch! PyTorch(expected):\n{expected}\nONNXruntime:\n{tout}" ) status = "SUCCESS" if all_good else "FAIL" logging.info( f"ONNX generated at {output} verified with onnxruntime : {status}") return all_good
def maybe_download_from_cloud(url, filename, subfolder=None, cache_dir=None, refresh_cache=False) -> str: """ Download a file from a URL if it does not exist in the cache. Parameters ---------- url: URL to download the file from. str filename: What to download. The request will be issued to url/filename str subfolder: Subfolder within cache_dir. The file will be stored in cache_dir/subfolder. Subfolder can be empty. str cache_dir: A cache directory where to download. If not present, this function will attempt to create it. str, If None (default), then it will be $HOME/.cache/torch/mridc refresh_cache: If True and cached file is present, it will delete it and re-fetch bool Returns ------- If successful - absolute local path to the downloaded file else empty string. """ if cache_dir is None: cache_location = Path.joinpath(Path.home(), ".cache/torch/mridc") else: cache_location = cache_dir if subfolder is not None: destination = Path.joinpath(cache_location, subfolder) else: destination = cache_location if not os.path.exists(destination): os.makedirs(destination, exist_ok=True) destination_file = Path.joinpath(destination, filename) if os.path.exists(destination_file): logging.info(f"Found existing object {destination_file}.") if refresh_cache: logging.info("Asked to refresh the cache.") logging.info(f"Deleting file: {destination_file}") os.remove(destination_file) else: logging.info(f"Re-using file from: {destination_file}") return str(destination_file) # download file wget_uri = url + filename logging.info(f"Downloading from: {wget_uri} to {str(destination_file)}") # NGC links do not work everytime so we try and wait i = 0 max_attempts = 3 while i < max_attempts: i += 1 try: wget.download(wget_uri, str(destination_file)) if os.path.exists(destination_file): return str(destination_file) return "" except Exception as e: logging.info( f"Download from cloud failed. Attempt {i} of {max_attempts}") logging.info(f"Error: {e}") sleep(0.05) continue raise ValueError("Not able to download url right now, please try again.")
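# Usage sketch for maybe_download_from_cloud(). The URL and filename below are placeholders to show how the two
# arguments are combined (the request goes to url + filename); they do not point at a real artifact. On success
# the absolute path of the cached file is returned; a failed download yields an empty string or a ValueError, as
# implemented above.
cached_path = maybe_download_from_cloud(
    url="https://example.com/models/",  # placeholder base URL (assumption)
    filename="example_model.mridc",     # placeholder artifact name (assumption)
    subfolder="example",                # stored under cache_dir/example/
    cache_dir=None,                     # defaults to $HOME/.cache/torch/mridc
    refresh_cache=False,                # set True to force a re-download
)
print(cached_path)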
def from_pretrained(
    cls,
    model_name: str,
    refresh_cache: bool = False,
    override_config_path: Optional[str] = None,
    map_location: Optional[torch.device] = None,
    strict: bool = True,
    return_config: bool = False,
    trainer: Optional[Trainer] = None,
    save_restore_connector: SaveRestoreConnector = None,
):
    """
    Instantiates a model instance from a pretrained model name, downloading the corresponding .mridc file from
    the cloud (or reusing a locally cached copy). Use restore_from() to instantiate from a local .mridc file.

    Parameters
    ----------
    model_name: String key which will be used to find the module.
    refresh_cache: If set to True, then when fetching from cloud, this will re-fetch the file from cloud even if
        it is already found in a cache locally.
    override_config_path: Path to a yaml config that will override the internal config file.
    map_location: Optional torch.device() to map the instantiated model to a device. By default (None), it will
        select a GPU if available, falling back to CPU otherwise.
    strict: Passed to load_state_dict. By default, True.
    return_config: If set to True, will return just the underlying config of the restored model as an
        OmegaConf/DictConfig object without instantiating the model.
    trainer: Optional Trainer object to use for restoring the model.
    save_restore_connector: Optional SaveRestoreConnector object to use for restoring the model.

    Returns
    -------
    A model instance of a particular model class or its underlying config (if return_config is set).
    """
    if save_restore_connector is None:
        save_restore_connector = SaveRestoreConnector()

    location_in_the_cloud = None
    description = None
    models = cls.list_available_models()
    if models is not None:
        for pretrained_model_info in cls.list_available_models():  # type: ignore
            found = False
            if pretrained_model_info.pretrained_model_name == model_name:
                found = True
            elif pretrained_model_info.aliases is not None:
                for alias in pretrained_model_info.aliases:
                    if alias == model_name:
                        found = True
                        break
            if found:
                location_in_the_cloud = pretrained_model_info.location
                description = pretrained_model_info.description
                class_ = pretrained_model_info.class_
                break

    if location_in_the_cloud is None:
        raise FileNotFoundError(
            f"Model {model_name} was not found. "
            "Check cls.list_available_models() for the list of all available models."
        )

    filename = location_in_the_cloud.split("/")[-1]
    url = location_in_the_cloud.replace(filename, "")
    cache_dir = Path.joinpath(mridc.utils.model_utils.resolve_cache_dir(), f"{filename[:-5]}")  # type: ignore

    # If either the description or the location in the cloud changes, this will force a re-download of the model.
    cache_subfolder = hashlib.sha512((location_in_the_cloud + description).encode("utf-8")).hexdigest()

    # If the file exists in cache_folder/subfolder, it will be re-used, unless refresh_cache is True.
    mridc_model_file_in_cache = maybe_download_from_cloud(
        url=url, filename=filename, cache_dir=cache_dir, subfolder=cache_subfolder, refresh_cache=refresh_cache
    )

    logging.info("Instantiating model from pre-trained checkpoint")

    if class_ is None:
        class_ = cls

    return class_.restore_from(
        restore_path=mridc_model_file_in_cache,
        override_config_path=override_config_path,
        map_location=map_location,
        strict=strict,
        return_config=return_config,
        trainer=trainer,
        save_restore_connector=save_restore_connector,
    )
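# Usage sketch for from_pretrained(). It is a classmethod-style entry point, so it is called on a concrete
# model class; the class and model names below are placeholders (assumptions), and the real identifiers come
# from cls.list_available_models() of the model class you are using.
#
# model = SomeMRIDCModel.from_pretrained("some_pretrained_model_name")
#
# # Fetch only the configuration without instantiating the model:
# cfg = SomeMRIDCModel.from_pretrained("some_pretrained_model_name", return_config=True)
#
# # Force a fresh download even if a cached copy exists:
# model = SomeMRIDCModel.from_pretrained("some_pretrained_model_name", refresh_cache=True)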
def prepare_lr_scheduler( optimizer: optim.Optimizer, scheduler_config: Union[Dict[str, Any], DictConfig, None], train_dataloader: Optional[dataloader.DataLoader] = None, ) -> Optional[Dict[str, Any]]: """ Constructs an LR Scheduler (optionally) for a given optimizer, based on a config with the following schema. Parameters ---------- optimizer: The optimizer to use for the scheduler. name: <name of optimizer> lr: <maximal learning rate> # <additional optimizer arguments> args: name: auto # special keyword, resolves to correct optimizer config for given optimizer name # cls: mridc.core.config.optimizers.NovogradParams # explicit instantiation by class path params: # optional override parameters for the optimizer config betas: [0.8, 0.5] weight_decay: 0.001 scheduler_config: The scheduler config. name: <name of scheduler> iters_per_batch: null # computed at runtime; mandatory to have max_steps: null # computed at runtime or explicitly set here; mandatory to have # pytorch lightning args <mandatory> monitor: val_loss reduce_on_plateau: false # <scheduler config override> args: name: auto # special keyword, resolves to correct optimizer config for given optimizer name # cls: mridc.core.config.schedulers.CosineAnnealingParams # explicit instantiation by class path params: # optional override parameters for the optimizer config warmup_steps: null warmup_ratio: null min_lr: 0.0 last_epoch: -1 train_dataloader: Optional requirement, must be passed if "iters_per_batch" is defined instead of "max_steps". \ Used to compute effective "max_steps". Returns ------- A dictionary containing the LR Scheduler implementation if the config was successfully parsed along with other \ parameters required by Pytorch Lightning, otherwise None. """ if scheduler_config is not None: scheduler_config = maybe_update_config_version(scheduler_config) # Build nested dictionary for convenience out of structured objects if isinstance(scheduler_config, DictConfig): scheduler_config = OmegaConf.to_container(scheduler_config, resolve=True) elif dataclasses.is_dataclass(scheduler_config): # Recursively transform data classes to basic dictionaries scheduler_config = OmegaConf.create(scheduler_config) scheduler_config = OmegaConf.to_container(scheduler_config, resolve=True) # Test to see if config follows above schema add_max_args_flag = True interval = "step" if scheduler_config is not None: if "args" in scheduler_config: scheduler_args = scheduler_config.pop("args") else: scheduler_args = copy.deepcopy(scheduler_config) # Remove extra parameters from scheduler_args nest # Assume all other parameters are to be passed into scheduler constructor if "name" in scheduler_args and scheduler_args[ "name"] == "ReduceLROnPlateau": add_max_args_flag = False interval = "epoch" scheduler_args.pop("name", None) scheduler_args.pop("t_max_epochs", None) scheduler_args.pop("t_accumulate_grad_batches", None) scheduler_args.pop("t_limit_train_batches", None) scheduler_args.pop("t_num_workers", None) scheduler_args.pop("monitor", None) scheduler_args.pop("reduce_on_plateau", None) else: # Return gracefully in case `sched` was not supplied; inform user logging.info( "Scheduler not initialized as no `sched` config supplied to setup_optimizer()" ) return None # Try instantiation of scheduler params from config class path if "_target_" in scheduler_args: scheduler_args_cfg = OmegaConf.create(scheduler_args) scheduler_conf = hydra.utils.instantiate(scheduler_args_cfg) scheduler_args = vars(scheduler_conf) # Get name of the scheduler scheduler_name = 
scheduler_conf.__class__.__name__ if "Params" in scheduler_name: scheduler_name = scheduler_name.replace("Params", "") else: # Class path instantiation failed; try resolving "name" component # Get name of the scheduler if "name" in scheduler_config: scheduler_name = scheduler_config["name"] else: logging.warning( "Could not resolve classpath for Scheduler Config, and `name` " "was not provided either. \n" "Scheduler cannot be instantiated !") return None # If class path was not provided, perhaps `name` is provided for resolution if "name" in scheduler_args: # If `auto` is passed as name for resolution of optimizer name, # then lookup optimizer name and resolve its parameter config if scheduler_args["name"] == "auto": scheduler_params_name = f"{scheduler_name}Params" else: scheduler_params_name = scheduler_args["name"] # Get override arguments provided in the config yaml file / Dict Config scheduler_params_override = scheduler_args.get("params", {}) # If params is itself a dict config object provided explicitly in Dict Config # Resolve to dictionary for convenience if isinstance(scheduler_params_override, DictConfig): scheduler_params_override = OmegaConf.to_container( scheduler_params_override, resolve=True) # Get and instantiate the Config dataclass for this scheduler scheduler_params_cls = get_scheduler_config( scheduler_params_name, **scheduler_params_override) scheduler_params = scheduler_params_cls # instantiate the parameters object scheduler_args = vars( scheduler_params ) # extract just the dictionary from the Config object # Extract value to monitor in losses, if provided. if "monitor" in scheduler_config: monitor = scheduler_config.get("monitor") else: # Default to train loss monitor = "loss" # Store exact max_steps if it is provided if "max_steps" in scheduler_config and scheduler_config[ "max_steps"] is not None: max_steps = scheduler_config["max_steps"] elif "t_max_epochs" in scheduler_config: # Compute effective max_steps if t_max_epochs is provided if train_dataloader is None: logging.warning( "As `t_max_epochs` is provided/computed, it is required to pass the train dataloader in order\n" "to compute effective maximum number of steps.\n" "Scheduler will not be instantiated !") return None # Raise exception if neither `max_steps` nor `t_max_epochs` is provided if scheduler_config.get("t_max_epochs", None) is None: logging.warning( "`t_max_epochs` cannot be None when `max_steps` is not not provided.\n" "This can occur when `train dataloader` is not available to correctly " "prepare the scheduler.\n" "Scheduler will not be instantiated !") return None # Get iters_per_batch max_epochs = scheduler_config.get("t_max_epochs") accumulate_grad_batches = scheduler_config.get( "t_accumulate_grad_batches") limit_train_batches = scheduler_config.get("t_limit_train_batches") num_workers = scheduler_config.get("t_num_workers") # Compute effective num max_steps num_samples = len(train_dataloader.dataset) # type: ignore # we may need to override ModelPT setup_optimization if train_dataloader.batch_size is not None: batch_size = train_dataloader.batch_size elif hasattr(train_dataloader, "batch_sampler" ) and train_dataloader.batch_sampler is not None: if train_dataloader.batch_sampler.micro_batch_size is not None: batch_size = train_dataloader.batch_sampler.micro_batch_size else: raise ValueError( f"Could not find batch_size from batch_sampler: {train_dataloader.batch_sampler}" ) else: raise ValueError( f"Could not find batch_size from train_dataloader: {train_dataloader}" ) drop_last = 
train_dataloader.drop_last max_steps = compute_max_steps( max_epochs=max_epochs, accumulate_grad_batches=accumulate_grad_batches, limit_train_batches=limit_train_batches, num_workers=num_workers, num_samples=num_samples, batch_size=batch_size, drop_last=drop_last, ) else: logging.warning( "Neither `max_steps` nor `iters_per_batch` were provided to `optim.sched`, " "cannot compute effective `max_steps` !\n" "Scheduler will not be instantiated !") return None # Inject max_steps (effective or provided) into the scheduler config if add_max_args_flag and scheduler_config.get("name", "") != "ExponentialLR": scheduler_args["max_steps"] = max_steps # Get the scheduler class from the config scheduler_cls = get_scheduler(scheduler_name, **scheduler_args) # Instantiate the LR schedule schedule = scheduler_cls(optimizer, **scheduler_args) logging.info( 'Scheduler "%s" \nwill be used during training (effective maximum steps = %d) - \nParameters : \n(%s)', str(schedule), max_steps, OmegaConf.to_yaml(OmegaConf.create(scheduler_args)), ) # Wrap the schedule in PTL arguments to perform stepwise computation # Rather than epoch level computation reduce_lr_on_plateau = isinstance(schedule, optim.lr_scheduler.ReduceLROnPlateau) return { "scheduler": schedule, "interval": interval, "frequency": 1, "monitor": monitor, "reduce_on_plateau": reduce_lr_on_plateau, }
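# Example of a flat `sched` configuration consumed by prepare_lr_scheduler(), following the schema documented in
# the docstring above. The scheduler name and hyper-parameter values are illustrative assumptions ("CosineAnnealing"
# is assumed to be registered with get_scheduler()); `max_steps` is given explicitly, so no train_dataloader is
# needed to compute it.
import torch
from omegaconf import OmegaConf

example_optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)

example_sched_cfg = OmegaConf.create(
    {
        "name": "CosineAnnealing",  # assumed registered scheduler name
        "max_steps": 1000,          # provided explicitly instead of t_max_epochs
        "warmup_steps": 100,
        "min_lr": 0.0,
        "monitor": "val_loss",
    }
)

scheduler_dict = prepare_lr_scheduler(
    optimizer=example_optimizer,
    scheduler_config=example_sched_cfg,
    train_dataloader=None,
)
# If parsing succeeds, scheduler_dict carries the keys "scheduler", "interval", "frequency", "monitor" and
# "reduce_on_plateau", ready to be returned from a LightningModule's configure_optimizers().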
def load_config_and_state_dict( self, calling_cls, restore_path: str, override_config_path: Optional[Union[OmegaConf, str]] = None, map_location: Optional[torch.device] = None, strict: bool = True, return_config: bool = False, trainer: Trainer = None, ): """ Restores model instance (weights and configuration) into .mridc file Parameters ---------- calling_cls: Class of the model to be restored. restore_path: path to .mridc file from which model should be instantiated override_config_path: path to a yaml config that will override the internal config file or an OmegaConf/DictConfig object representing the model config. map_location: Optional torch.device() to map the instantiated model to a device. By default (None), it will select a GPU if available, falling back to CPU otherwise. strict: Passed to load_state_dict. By default, True. return_config: If set to true, will return just the underlying config of the restored model as an OmegaConf DictConfig object without instantiating the model. trainer: Optional trainer object to be used for model parallelism. Example ------- ``` model = mridc.collections.asr.models.EncDecCTCModel.restore_from('asr.mridc') assert isinstance(model, mridc.collections.asr.models.EncDecCTCModel) ``` Returns ------- An instance of type cls or its underlying config (if return_config is set). """ # Get path where the command is executed - the artifacts will be "retrieved" there # (original .mridc behavior) cwd = os.getcwd() if map_location is None: if torch.cuda.is_available(): map_location = torch.device("cuda") else: map_location = torch.device("cpu") app_state = AppState() with tempfile.TemporaryDirectory() as tmpdir: try: # Check if self.model_extracted_dir is set, and is a valid path if self.model_extracted_dir is not None and os.path.isdir( self.model_extracted_dir): # Log that MRIDC will use the provided `model_extracted_dir` logging.info( "Restoration will occur within pre-extracted directory : " f"`{self.model_extracted_dir}`.") # Override `tmpdir` above with the pre-extracted `model_extracted_dir` tmpdir = self.model_extracted_dir else: # Extract the nemo file into the temporary directory self._unpack_mridc_file(path2file=restore_path, out_folder=tmpdir) # Change current working directory to the temporary directory os.chdir(tmpdir) if override_config_path is None: config_yaml = os.path.join(tmpdir, self.model_config_yaml) else: # can be str path or OmegaConf / DictConfig object config_yaml = override_config_path if not isinstance(config_yaml, (OmegaConf, DictConfig)): conf = OmegaConf.load(config_yaml) else: conf = config_yaml if override_config_path is not None: # Resolve the override config conf = OmegaConf.to_container(conf, resolve=True) conf = OmegaConf.create(conf) # If override is top level config, extract just `model` from it if "model" in conf: conf = conf.model if return_config: instance = conf return instance if app_state.model_parallel_rank is not None and app_state.model_parallel_size > 1: model_weights = self._inject_model_parallel_rank_for_ckpt( tmpdir, self.model_weights_ckpt) else: model_weights = os.path.join(tmpdir, self.model_weights_ckpt) OmegaConf.set_struct(conf, True) os.chdir(cwd) # get the class calling_cls._set_model_restore_state( is_being_restored=True, folder=tmpdir) # type: ignore instance = calling_cls.from_config_dict(config=conf, trainer=trainer) instance = instance.to(map_location) # add load_state_dict override if app_state.model_parallel_size is not None and app_state.model_parallel_size > 1: model_weights = 
self._inject_model_parallel_rank_for_ckpt( tmpdir, self.model_weights_ckpt) instance.load_state_dict(self._load_state_dict_from_disk( model_weights, map_location=map_location), strict=strict) logging.info( f"Model {instance.__class__.__name__} was successfully restored from {restore_path}." ) instance._set_model_restore_state( is_being_restored=False) # type: ignore finally: os.chdir(cwd) return instance
def extract_state_dict_from(self, restore_path: str, save_dir: str, split_by_module: bool = False): """ Extract the state dict(s) from a provided .mridc tarfile and save it to a directory. Parameters ---------- restore_path: path to .mridc file from which state dict(s) should be extracted save_dir: directory in which the saved state dict(s) should be stored split_by_module: bool flag, which determines whether the output checkpoint should be for the entire Model, or the individual module's that comprise the Model. Example ------- To convert the .mridc tarfile into a single Model level PyTorch checkpoint :: state_dict = mridc.collections.asr.models.EncDecCTCModel.extract_state_dict_from('asr.mridc', './asr_ckpts') To restore a model from a Model level checkpoint :: model = mridc.collections.asr.models.EncDecCTCModel(cfg) # or any other method of restoration model.load_state_dict(torch.load("./asr_ckpts/model_weights.ckpt")) To convert the .mridc tarfile into multiple Module level PyTorch checkpoints :: state_dict = mridc.collections.asr.models.EncDecCTCModel.extract_state_dict_from('asr.mridc', './asr_ckpts', split_by_module=True). To restore a module from a Module level checkpoint :: model = mridc.collections.asr.models.EncDecCTCModel(cfg) # or any other method of restoration # load the individual components model.preprocessor.load_state_dict(torch.load("./asr_ckpts/preprocessor.ckpt")) model.encoder.load_state_dict(torch.load("./asr_ckpts/encoder.ckpt")) model.decoder.load_state_dict(torch.load("./asr_ckpts/decoder.ckpt")) Returns ------- The state dict that was loaded from the original .mridc checkpoint. """ cwd = os.getcwd() save_dir = os.path.abspath(save_dir) if not os.path.exists(save_dir): os.makedirs(save_dir, exist_ok=True) with tempfile.TemporaryDirectory() as tmpdir: try: self._unpack_mridc_file(path2file=restore_path, out_folder=tmpdir) os.chdir(tmpdir) model_weights = os.path.join(tmpdir, self.model_weights_ckpt) state_dict = self._load_state_dict_from_disk(model_weights) if not split_by_module: filepath = os.path.join(save_dir, self.model_weights_ckpt) self._save_state_dict_to_disk(state_dict, filepath) else: key_set = {key.split(".")[0] for key in state_dict.keys()} for primary_key in key_set: inner_keys = [ key for key in state_dict.keys() if key.split(".")[0] == primary_key ] state_dict_subset = { ".".join(inner_key.split(".")[1:]): state_dict[inner_key] for inner_key in inner_keys } filepath = os.path.join(save_dir, f"{primary_key}.ckpt") self._save_state_dict_to_disk(state_dict_subset, filepath) logging.info( f"Checkpoints from {restore_path} were successfully extracted into {save_dir}." ) finally: os.chdir(cwd) return state_dict
def configure_loggers(
    trainer: Trainer,
    exp_dir: List[Union[Path, str]],
    name: str,
    version: str,
    create_tensorboard_logger: bool,
    summary_writer_kwargs: dict,
    create_wandb_logger: bool,
    wandb_kwargs: dict,
):
    """
    Creates a TensorBoardLogger and/or WandbLogger and attaches them to the trainer. Raises ValueError if
    summary_writer_kwargs or wandb_kwargs are misconfigured.

    Parameters
    ----------
    trainer: The trainer to attach the loggers to.
    exp_dir: The experiment directory.
    name: The name of the experiment.
    version: The version of the experiment.
    create_tensorboard_logger: Whether to create a TensorboardLogger.
    summary_writer_kwargs: The kwargs to pass to the TensorboardLogger.
    create_wandb_logger: Whether to create a Weights & Biases logger.
    wandb_kwargs: The kwargs to pass to the Weights & Biases logger.

    Returns
    -------
    None. The created logger (or LoggerList of loggers) is attached to the trainer.
    """
    # Potentially create tensorboard logger and/or WandBLogger
    logger_list = []
    if create_tensorboard_logger:
        if summary_writer_kwargs is None:
            summary_writer_kwargs = {}
        elif "log_dir" in summary_writer_kwargs:
            raise ValueError(
                "You cannot pass `log_dir` as part of `summary_writer_kwargs`. `log_dir` is handled by lightning's "
                "TensorBoardLogger logger."
            )
        tensorboard_logger = TensorBoardLogger(
            save_dir=exp_dir[0], name=name, version=version, **summary_writer_kwargs
        )
        logger_list.append(tensorboard_logger)
        logging.info("TensorboardLogger has been set up")

    if create_wandb_logger:
        if wandb_kwargs is None:
            wandb_kwargs = {}
        if "name" not in wandb_kwargs and "project" not in wandb_kwargs:
            raise ValueError("name and project are required for wandb_logger")
        wandb_logger = WandbLogger(save_dir=exp_dir[0], version=version, **wandb_kwargs)
        logger_list.append(wandb_logger)
        logging.info("WandBLogger has been set up")

    logger_list = (
        LoggerList(logger_list, mridc_name=name, mridc_version=version) if len(logger_list) > 1 else logger_list[0]
    )
    trainer._logger_connector.configure_logger(logger_list)
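# Illustrative kwargs for configure_loggers(). The project/run names are placeholders (assumptions). Note that
# `log_dir` must not appear in summary_writer_kwargs (it is supplied by lightning's TensorBoardLogger), and
# wandb_kwargs is expected to carry `name` and `project`.
example_summary_writer_kwargs = {"default_hp_metric": False}          # any TensorBoardLogger kwarg
example_wandb_kwargs = {"name": "example-run", "project": "example"}  # placeholders (assumptions)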
def check_resume(
    trainer: Trainer,
    log_dir: str,
    resume_past_end: bool = False,
    resume_ignore_no_checkpoint: bool = False,
):
    """
    Checks that resume=True was used correctly with the arguments passed to exp_manager. Sets
    trainer._checkpoint_connector.resume_from_checkpoint_fit_path as necessary.

    Parameters
    ----------
    trainer: The trainer that is being used.
    log_dir: The directory where the logs are being saved.
    resume_past_end: Whether to resume from the end of the experiment.
    resume_ignore_no_checkpoint: Whether to ignore if there is no checkpoint to resume from.

    Raises
    ------
    NotFoundError: If resume is True, resume_ignore_no_checkpoint is False, and no checkpoints could be found.
    ValueError: If resume is True and more than one checkpoint was found.
    """
    if not log_dir:
        raise ValueError(f"Resuming requires the log_dir {log_dir} to be passed to exp_manager")

    checkpoint_dir = Path(Path(log_dir) / "checkpoints")

    checkpoint = None
    end_checkpoints = list(checkpoint_dir.rglob("*end.ckpt"))
    last_checkpoints = list(checkpoint_dir.rglob("*last.ckpt"))

    if not checkpoint_dir.exists():
        if not resume_ignore_no_checkpoint:
            raise NotFoundError(
                f"There was no checkpoint folder at checkpoint_dir :{checkpoint_dir}. Cannot resume."
            )
        logging.warning(
            f"There was no checkpoint folder at checkpoint_dir :{checkpoint_dir}. Training from scratch."
        )
        return

    if end_checkpoints:
        if not resume_past_end:
            raise ValueError(
                f"Found {end_checkpoints[0]} indicating that the last training run has already completed."
            )
        if len(end_checkpoints) > 1:
            if "mp_rank" in str(end_checkpoints[0]):
                checkpoint = end_checkpoints[0]
            else:
                raise ValueError(f"Found multiple checkpoints {end_checkpoints} matching *end.ckpt.")
        logging.info(f"Resuming from {end_checkpoints[0]}")
    elif not last_checkpoints:
        if not resume_ignore_no_checkpoint:
            raise NotFoundError(f"There were no checkpoints found in {checkpoint_dir}. Cannot resume.")
        logging.warning(f"There were no checkpoints found in {checkpoint_dir}. Training from scratch.")
        return
    elif len(last_checkpoints) > 1:
        if "mp_rank" not in str(last_checkpoints[0]) and "tp_rank" not in str(last_checkpoints[0]):
            raise ValueError(f"Found multiple checkpoints {last_checkpoints} matching *last.ckpt.")
        checkpoint = last_checkpoints[0]
        checkpoint = mridc.utils.model_utils.uninject_model_parallel_rank(checkpoint)  # type: ignore
    else:
        logging.info(f"Resuming from {last_checkpoints[0]}")
        checkpoint = last_checkpoints[0]

    trainer._checkpoint_connector.resume_from_checkpoint_fit_path = str(checkpoint)

    if is_global_rank_zero():
        if files_to_move := [child for child in Path(log_dir).iterdir() if child.is_file()]:
            # Move old files to a new folder
            other_run_dirs = Path(log_dir).glob("run_*")
            run_count = sum(bool(fold.is_dir()) for fold in other_run_dirs)
            new_run_dir = Path(Path(log_dir) / f"run_{run_count}")
            new_run_dir.mkdir()
            for _file in files_to_move:
                move(str(_file), str(new_run_dir))
def exp_manager( trainer: Trainer, cfg: Optional[Union[DictConfig, Dict]] = None) -> Optional[Path]: """ exp_manager is a helper function used to manage folders for experiments. It follows the pytorch lightning \ paradigm of exp_dir/model_or_experiment_name/version. If the lightning trainer has a logger, exp_manager will \ get exp_dir, name, and version from the logger. Otherwise, it will use the exp_dir and name arguments to create \ the logging directory. exp_manager also allows for explicit folder creation via explicit_log_dir. The version can be a datetime string or an integer. Datetime version can be disabled if you use_datetime_version \ is set to False. It optionally creates TensorBoardLogger, WandBLogger, ModelCheckpoint objects from pytorch \ lightning. It copies sys.argv, and git information if available to the logging directory. It creates a log file \ for each process to log their output into. exp_manager additionally has a resume feature (resume_if_exists) which can be used to continuing training from \ the constructed log_dir. When you need to continue the training repeatedly (like on a cluster which you need \ multiple consecutive jobs), you need to avoid creating the version folders. Therefore, from v1.0.0, when \ resume_if_exists is set to True, creating the version folders is ignored. Parameters ---------- trainer: The lightning trainer object. cfg: Can have the following keys: - explicit_log_dir: Can be used to override exp_dir/name/version folder creation. Defaults to None, which \ will use exp_dir, name, and version to construct the logging directory. - exp_dir: The base directory to create the logging directory. Defaults to None, which logs to \ ./mridc_experiments. - name: The name of the experiment. Defaults to None which turns into "default" via name = name or "default". - version: The version of the experiment. Defaults to None which uses either a datetime string or lightning's \ TensorboardLogger system of using version_{int}. - use_datetime_version: Whether to use a datetime string for version. Defaults to True. - resume_if_exists: Whether this experiment is resuming from a previous run. If True, it sets \ trainer._checkpoint_connector.resume_from_checkpoint_fit_path so that the trainer should auto-resume. \ exp_manager will move files under log_dir to log_dir/run_{int}. Defaults to False. From v1.0.0, when \ resume_if_exists is True, we would not create version folders to make it easier to find the log folder for \ next runs. - resume_past_end: exp_manager errors out if resume_if_exists is True and a checkpoint matching \*end.ckpt \ indicating a previous training run fully completed. This behaviour can be disabled, in which case the \ \*end.ckpt will be loaded by setting resume_past_end to True. Defaults to False. - resume_ignore_no_checkpoint: exp_manager errors out if resume_if_exists is True and no checkpoint could be \ found. This behaviour can be disabled, in which case exp_manager will print a message and continue without \ restoring, by setting resume_ignore_no_checkpoint to True. Defaults to False. - create_tensorboard_logger: Whether to create a tensorboard logger and attach it to the pytorch lightning \ trainer. Defaults to True. - summary_writer_kwargs: A dictionary of kwargs that can be passed to lightning's TensorboardLogger class. \ Note that log_dir is passed by exp_manager and cannot exist in this dict. Defaults to None. - create_wandb_logger: Whether to create a Weights and Biases logger and attach it to the pytorch lightning \ trainer. 
Defaults to False. - wandb_logger_kwargs: A dictionary of kwargs that can be passed to lightning's WandBLogger class. Note that \ name and project are required parameters if create_wandb_logger is True. Defaults to None. - create_checkpoint_callback: Whether to create a ModelCheckpoint callback and attach it to the pytorch \ lightning trainer. The ModelCheckpoint saves the top 3 models with the best "val_loss", the most recent \ checkpoint under \*last.ckpt, and the final checkpoint after training completes under \*end.ckpt. \ Defaults to True. - files_to_copy: A list of files to copy to the experiment logging directory. Defaults to None which copies \ no files. - log_local_rank_0_only: Whether to only create log files for local rank 0. Defaults to False. Set this to \ True if you are using DDP with many GPUs and do not want many log files in your exp dir. - log_global_rank_0_only: Whether to only create log files for global rank 0. Defaults to False. Set this to \ True if you are using DDP with many GPUs and do not want many log files in your exp dir. Returns ------- The final logging directory where logging files are saved. Usually the concatenation of exp_dir, name, and version. """ # Add rank information to logger # Note: trainer.global_rank and trainer.is_global_zero are not set until trainer.fit, so have to hack around it local_rank = int(os.environ.get("LOCAL_RANK", 0)) global_rank = trainer.node_rank * trainer.num_devices + local_rank logging.rank = global_rank if cfg is None: logging.error( "exp_manager did not receive a cfg argument. It will be disabled.") return None if trainer.fast_dev_run: logging.info( "Trainer was called with fast_dev_run. exp_manager will return without any functionality." ) return None # Ensure passed cfg is compliant with ExpManagerConfig schema = OmegaConf.structured(ExpManagerConfig) if isinstance(cfg, dict): cfg = OmegaConf.create(cfg) elif not isinstance(cfg, DictConfig): raise ValueError( f"cfg was type: {type(cfg)}. 
Expected either a dict or a DictConfig" ) cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=True)) cfg = OmegaConf.merge(schema, cfg) error_checks( trainer, cfg ) # Ensures that trainer options are compliant with MRIDC and exp_manager arguments log_dir, exp_dir, name, version = get_log_dir( trainer=trainer, exp_dir=cfg.exp_dir, name=cfg.name, version=cfg.version, explicit_log_dir=cfg.explicit_log_dir, use_datetime_version=cfg.use_datetime_version, resume_if_exists=cfg.resume_if_exists, ) if cfg.resume_if_exists: check_resume(trainer, str(log_dir), cfg.resume_past_end, cfg.resume_ignore_no_checkpoint) checkpoint_name = name # If name returned from get_log_dir is "", use cfg.name for checkpointing if checkpoint_name is None or checkpoint_name == "": checkpoint_name = cfg.name or "default" cfg.name = name # Used for configure_loggers so that the log_dir is properly set even if name is "" cfg.version = version # update app_state with log_dir, exp_dir, etc app_state = AppState() app_state.log_dir = log_dir app_state.exp_dir = exp_dir app_state.name = name app_state.version = version app_state.checkpoint_name = checkpoint_name app_state.create_checkpoint_callback = cfg.create_checkpoint_callback app_state.checkpoint_callback_params = cfg.checkpoint_callback_params # Create the logging directory if it does not exist os.makedirs( log_dir, exist_ok=True ) # Cannot limit creation to global zero as all ranks write to own log file logging.info(f"Experiments will be logged at {log_dir}") trainer._default_root_dir = log_dir if cfg.log_local_rank_0_only is True and cfg.log_global_rank_0_only is True: raise ValueError( "Cannot set both log_local_rank_0_only and log_global_rank_0_only to True." "Please set either one or neither.") # This is set if the env var MRIDC_TESTING is set to True. mridc_testing = get_envbool(MRIDC_ENV_VARNAME_TESTING, False) log_file = log_dir / f"mridc_log_globalrank-{global_rank}_localrank-{local_rank}.txt" # Handle logging to file. Logs local rank 0 only if local_rank == 0 and cfg.log_local_rank_0_only and not mridc_testing: logging.add_file_handler(log_file) elif global_rank == 0 and cfg.log_global_rank_0_only and mridc_testing: logging.add_file_handler(log_file) else: logging.add_file_handler(log_file) # For some reason, LearningRateLogger requires trainer to have a logger. Safer to create logger on all ranks # not just global rank 0. 
if cfg.create_tensorboard_logger or cfg.create_wandb_logger: configure_loggers( trainer, [Path(exp_dir)], cfg.name, cfg.version, cfg.create_tensorboard_logger, cfg.summary_writer_kwargs, cfg.create_wandb_logger, cfg.wandb_logger_kwargs, ) # add loggers timing callbacks if cfg.log_step_timing: timing_callback = TimingCallback( timer_kwargs=cfg.step_timing_kwargs or {}) trainer.callbacks.insert(0, timing_callback) if cfg.create_checkpoint_callback: configure_checkpointing(trainer, log_dir, checkpoint_name, cfg.resume_if_exists, cfg.checkpoint_callback_params) if is_global_rank_zero(): # Move files_to_copy to folder and add git information if present if cfg.files_to_copy: for _file in cfg.files_to_copy: copy(Path(_file), log_dir) # Create files for cmd args and git info with open(log_dir / "cmd-args.log", "w", encoding="utf-8") as _file: _file.write(" ".join(sys.argv)) # Try to get git hash git_repo, git_hash = get_git_hash() if git_repo: with open(log_dir / "git-info.log", "w", encoding="utf-8") as _file: _file.write(f"commit hash: {git_hash}") _file.write(get_git_diff()) # Add err_file logging to global_rank zero logging.add_err_file_handler(log_dir / "mridc_error_log.txt") # Add lightning file logging to global_rank zero add_filehandlers_to_pl_logger(log_dir / "lightning_logs.txt", log_dir / "mridc_error_log.txt") return log_dir
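# Minimal exp_manager configuration sketch, using only keys documented in the docstring above. All concrete
# values are illustrative assumptions; unspecified keys fall back to the ExpManagerConfig defaults after the
# OmegaConf.merge(schema, cfg) step.
from omegaconf import OmegaConf

example_exp_cfg = OmegaConf.create(
    {
        "exp_dir": "./mridc_experiments",
        "name": "example_experiment",
        "use_datetime_version": True,
        "resume_if_exists": False,
        "create_tensorboard_logger": True,
        "create_checkpoint_callback": True,
    }
)
# log_dir = exp_manager(trainer, example_exp_cfg)  # with an already constructed pl.Trainer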
def export( self, output: str, input_example=None, verbose=False, export_params=True, do_constant_folding=True, onnx_opset_version=None, try_script: bool = False, training=TrainingMode.EVAL, check_trace: bool = False, use_dynamic_axes: bool = True, dynamic_axes=None, check_tolerance=0.01, ): """ Export the module to a file. Parameters ---------- output: The output file path. input_example: A dictionary of input names and values. verbose: If True, print out the export process. export_params: If True, export the parameters of the module. do_constant_folding: If True, do constant folding. onnx_opset_version: The ONNX opset version to use. try_script: If True, try to export as TorchScript. training: Training mode for the export. check_trace: If True, check the trace of the exported model. use_dynamic_axes: If True, use dynamic axes for the export. dynamic_axes: A dictionary of input names and dynamic axes. check_tolerance: The tolerance for the check_trace. """ my_args = locals().copy() my_args.pop("self") exportables = [] for m in self.modules(): # type: ignore if isinstance(m, Exportable): exportables.append(m) qual_name = self.__module__ + "." + self.__class__.__qualname__ format = get_export_format(output) output_descr = f"{qual_name} exported to {format}" # Pytorch's default for None is too low, can't pass None through if onnx_opset_version is None: onnx_opset_version = 13 try: # Disable typechecks typecheck.set_typecheck_enabled(enabled=False) # Allow user to completely override forward method to export forward_method, old_forward_method = wrap_forward_method(self) # Set module mode with torch.onnx.select_model_mode_for_export( self, training), torch.inference_mode( ), torch.jit.optimized_execution(True): if input_example is None: input_example = self.input_module.input_example() # Remove i/o examples from args we propagate to enclosed Exportables my_args.pop("output") my_args.pop("input_example") # Run (possibly overridden) prepare methods before calling forward() for ex in exportables: ex._prepare_for_export(**my_args, noreplace=True) self._prepare_for_export(output=output, input_example=input_example, **my_args) input_list, input_dict = parse_input_example(input_example) input_names = self.input_names output_names = self.output_names output_example = tuple(self.forward( *input_list, **input_dict)) # type: ignore jitted_model = None if try_script: try: jitted_model = torch.jit.script(self) except Exception as e: logging.error(f"jit.script() failed!\n{e}") if format == ExportFormat.TORCHSCRIPT: if jitted_model is None: jitted_model = torch.jit.trace_module( self, { "forward": tuple(input_list) + tuple(input_dict.values()) }, strict=True, check_trace=check_trace, check_tolerance=check_tolerance, ) if not self.training: # type: ignore jitted_model = torch.jit.optimize_for_inference( jitted_model) if verbose: logging.info(f"JIT code:\n{jitted_model.code}") jitted_model.save(output) elif format == ExportFormat.ONNX: if jitted_model is None: jitted_model = self # dynamic axis is a mapping from input/output_name => list of "dynamic" indices if dynamic_axes is None and use_dynamic_axes: dynamic_axes = get_dynamic_axes( self.input_module.input_types, input_names) dynamic_axes.update( get_dynamic_axes(self.output_module.output_types, output_names)) torch.onnx.export( jitted_model, input_example, output, input_names=input_names, output_names=output_names, verbose=verbose, export_params=export_params, do_constant_folding=do_constant_folding, dynamic_axes=dynamic_axes, 
opset_version=onnx_opset_version, ) if check_trace: verify_runtime(output, input_list, input_dict, input_names, output_names, output_example) else: raise ValueError( f"Encountered unknown export format {format}.") finally: typecheck.set_typecheck_enabled(enabled=True) if forward_method: type(self).forward = old_forward_method # type: ignore self._export_teardown() return [output], [output_descr]
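# Usage sketch for Exportable.export(). The model class and output file name are placeholders (assumptions);
# any module mixing in this Exportable interface should follow the same pattern. Assuming get_export_format()
# maps a ".onnx" suffix to the ONNX branch above, check_trace=True additionally runs verify_runtime() against
# onnxruntime when it is installed.
#
# model = SomeExportableModel(cfg)       # placeholder model class
# paths, descriptions = model.export(
#     output="some_model.onnx",          # placeholder output path
#     check_trace=True,                  # compare onnxruntime outputs against PyTorch
#     onnx_opset_version=13,             # the default used above when None is passed
# )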