def on_train_end(self, trainer, pl_module):
    """
    This is called at the end of training.

    Parameters
    ----------
    trainer: The trainer object.
    pl_module: The PyTorch-Lightning module.
    """
    if trainer.fast_dev_run:
        return None

    # Call parent on_train_end() to save the -last checkpoint
    super().on_train_end(trainer, pl_module)

    # Load the best model and then re-save it
    if self.save_best_model:
        # wait for all processes to finish
        trainer.training_type_plugin.barrier("SaveBestCheckpointConnector.resume_end")

        if self.best_model_path == "":
            logging.warning(
                f"{self} was told to save the best checkpoint at the end of training, but no saved checkpoints "
                "were found. Saving latest model instead."
            )
        else:
            trainer._checkpoint_connector.restore(self.best_model_path)

    if self.save_mridc_on_train_end:
        pl_module.save_to(save_path=os.path.join(self.dirpath, self.prefix + self.postfix))

def _convert_config(cfg: "OmegaConf"):
    """Recursive function converting the configuration from old hydra format to the new one."""
    if not _HAS_HYDRA:
        logging.error("This function requires Hydra/OmegaConf and it was not installed.")
        sys.exit(1)

    # Get rid of cls -> _target_.
    if "cls" in cfg and "_target_" not in cfg:
        cfg._target_ = cfg.pop("cls")  # type: ignore

    # Get rid of params.
    if "params" in cfg:
        params = cfg.pop("params")  # type: ignore
        for param_key, param_val in params.items():
            cfg[param_key] = param_val

    # Recursion.
    try:
        for _, sub_cfg in cfg.items():  # type: ignore
            if isinstance(sub_cfg, DictConfig):
                _convert_config(sub_cfg)  # type: ignore
    except OmegaConfBaseException as e:
        logging.warning(f"Skipped conversion for config/subconfig:\n{cfg}\n Reason: {e}.")

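
# Hedged usage sketch for `_convert_config`: the config below uses the old Hydra layout
# (`cls` plus a nested `params` block); after conversion it is expected to carry `_target_`
# with the parameters promoted to the top level. The class path is purely illustrative,
# and the sketch assumes Hydra/OmegaConf are installed (i.e. `_HAS_HYDRA` is True).
def _example_convert_config():
    from omegaconf import OmegaConf

    old_style = OmegaConf.create(
        {"optim": {"cls": "torch.optim.Adam", "params": {"lr": 1e-3, "weight_decay": 0.01}}}
    )
    _convert_config(old_style)
    # Expected result:
    # optim:
    #   _target_: torch.optim.Adam
    #   lr: 0.001
    #   weight_decay: 0.01
    print(OmegaConf.to_yaml(old_style))
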
def __init__(
    self,
    always_save_mridc=False,
    save_mridc_on_train_end=True,
    save_best_model=False,
    postfix=".mridc",
    n_resume=False,
    model_parallel_size=None,
    **kwargs,
):
    """
    Parameters
    ----------
    always_save_mridc: Whether to save the model even if it is not the best model. Default: False.
    save_mridc_on_train_end: Whether to save the model at the end of training. Default: True.
    save_best_model: Whether to save the model if it is the best model. Default: False.
    postfix: The postfix to add to the model name. Default: ".mridc".
    n_resume: Whether to resume training from a checkpoint. Default: False.
    model_parallel_size: The size of the model parallel group. Default: None.
    kwargs: The kwargs to pass to ModelCheckpoint.
    """
    # Parse and store "extended" parameters: save_best model and postfix.
    self.always_save_mridc = always_save_mridc
    self.save_mridc_on_train_end = save_mridc_on_train_end
    self.save_best_model = save_best_model
    self.previous_model_path = None
    self.last_model_path: Union[Any, str] = None

    if self.save_best_model and not self.save_mridc_on_train_end:
        logging.warning(
            "Found save_best_model is True and save_mridc_on_train_end is False. "
            "Set save_mridc_on_train_end to True to automatically save the best model."
        )

    self.postfix = postfix
    self.previous_best_path = ""
    self.model_parallel_size = model_parallel_size

    # `prefix` is deprecated
    self.prefix = kwargs.pop("prefix") if "prefix" in kwargs else ""

    # Call the parent class constructor with the remaining kwargs.
    super().__init__(**kwargs)

    if self.save_top_k != -1 and n_resume:
        logging.debug("Checking previous runs")
        self.mridc_topk_check_previous_run()

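
# Hedged usage sketch: constructing the checkpoint callback and handing it to a PyTorch Lightning
# Trainer. The monitored metric name ("val_loss") and the directory are illustrative assumptions,
# not values mandated by this class; any remaining kwargs are forwarded to ModelCheckpoint.
def _example_checkpoint_callback():
    from pytorch_lightning import Trainer

    callback = MRIDCModelCheckpoint(
        dirpath="checkpoints/",  # standard ModelCheckpoint kwarg
        monitor="val_loss",      # assumed metric name logged by the model
        save_top_k=3,
        save_best_model=True,    # extended flag handled by this subclass
        save_mridc_on_train_end=True,
        postfix=".mridc",
    )
    return Trainer(callbacks=[callback], max_epochs=1)
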
def wrapped(x):
    """
    Wrapper function.

    Parameters
    ----------
    x: The class to be decorated.

    Returns
    -------
    The decorated class, returned unchanged after the experimental warning is emitted.
    """
    logging.warning(
        f"Module {x} is experimental, not ready for production and is not fully supported. Use at your own risk."
    )
    return x

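
# Hedged sketch: how the enclosing decorator would be used. The outer function is not shown in
# this chunk, so the name `experimental` is an assumption; decorating a class only emits the
# warning above and returns the class as-is.
def _example_experimental_decorator():
    @experimental  # assumed name of the enclosing decorator
    class PrototypeModule:
        """Placeholder class used only to illustrate the decorator."""

    return PrototypeModule
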
def compute_max_steps(
    max_epochs, accumulate_grad_batches, limit_train_batches, num_workers, num_samples, batch_size, drop_last
):
    """Compute effective max_steps from the provided parameters."""
    _round = math.floor if drop_last else math.ceil

    sampler_num_samples = math.ceil(num_samples / max(1, num_workers))

    if drop_last and num_workers > 1:
        logging.warning(
            "Please note that drop_last is broken in pytorch 1.6.0. We will fix when pytorch 1.7.0 is released"
        )

    steps_per_epoch = _round(sampler_num_samples / batch_size)
    if isinstance(limit_train_batches, int) or limit_train_batches == 0.0:
        steps_per_epoch = min(steps_per_epoch, int(limit_train_batches))
    elif steps_per_epoch != float("inf"):
        # limit_train_batches is a percentage of batches per epoch
        steps_per_epoch = int(steps_per_epoch * limit_train_batches)

    return math.ceil(steps_per_epoch / accumulate_grad_batches) * max_epochs

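
# Worked example (illustrative numbers only): 10,000 samples read by 4 workers gives
# ceil(10000 / 4) = 2500 samples per worker, ceil(2500 / 32) = 79 steps per epoch,
# scaled by limit_train_batches=0.5 -> 39, divided by accumulate_grad_batches=2 ->
# ceil(39 / 2) = 20, times 10 epochs = 200 effective steps.
def _example_compute_max_steps():
    steps = compute_max_steps(
        max_epochs=10,
        accumulate_grad_batches=2,
        limit_train_batches=0.5,
        num_workers=4,
        num_samples=10_000,
        batch_size=32,
        drop_last=False,
    )
    assert steps == 200
    return steps
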
def check_explicit_log_dir(
    trainer: Trainer, explicit_log_dir: List[Union[Path, str]], exp_dir: str, name: str, version: str
) -> Tuple[Path, str, str, str]:
    """
    Checks that the passed arguments are compatible with explicit_log_dir.

    Parameters
    ----------
    trainer: The trainer to check.
    explicit_log_dir: The explicit log dir to check.
    exp_dir: The experiment directory to check.
    name: The experiment name to check.
    version: The experiment version to check.

    Returns
    -------
    The log_dir, exp_dir, name, and version that should be used.

    Raises
    ------
    LoggerMisconfigurationError
    """
    if trainer.logger is not None:
        raise LoggerMisconfigurationError(
            "The pytorch lightning trainer that was passed to exp_manager contained a logger and explicit_log_dir: "
            f"{explicit_log_dir} was passed to exp_manager. Please remove the logger from the lightning trainer."
        )
    # Checking only (explicit_log_dir) vs (exp_dir and version).
    # The `name` will be used as the actual name of checkpoint/archive.
    if exp_dir or version:
        logging.error(
            f"exp_manager received explicit_log_dir: {explicit_log_dir} and at least one of exp_dir: {exp_dir}, "
            f"or version: {version}. Please note that exp_dir, name, and version will be ignored."
        )
    if is_global_rank_zero() and Path(str(explicit_log_dir)).exists():
        logging.warning(f"Exp_manager is logging to {explicit_log_dir}, but it already exists.")
    return Path(str(explicit_log_dir)), str(explicit_log_dir), "", ""

def wrapper(_wrapped, args, kwargs):
    """
    Prints the adequate warning (only once per function) when required and calls the function func, passing the
    original arguments, i.e. version and explanation.

    Parameters
    ----------
    _wrapped: The function to be decorated.
    args: The arguments passed to the function to be decorated.
    kwargs: The keyword arguments passed to the function to be decorated.

    Returns
    -------
    The decorated function.
    """
    # Check if we already warned about that function.
    if _wrapped.__name__ not in _PRINTED_WARNING:
        # Add to list so we won't print it again.
        _PRINTED_WARNING[_wrapped.__name__] = True

        # Prepare the warning message.
        entity_name = "Class" if inspect.isclass(_wrapped) else "Function"
        msg = f"{entity_name} '{_wrapped.__name__}' is deprecated."

        # Optionally, add version and explanation.
        if version is not None:
            msg = f"{msg} It is going to be removed in the {version} version."
        if explanation is not None:
            msg = f"{msg} {explanation}"

        # Display the deprecated warning.
        logging.warning(msg)

    # Call the function.
    return _wrapped(*args, **kwargs)

def unique_names_check(name_list: Optional[List[str]]):
    """
    Performs a uniqueness check on the name list resolved, so that it can warn users about non-unique keys.

    Parameters
    ----------
    name_list: List of strings resolved for data loaders.
    """
    if name_list is None:
        return

    # Name uniqueness checks
    names = set()
    for name in name_list:
        if name in names:
            logging.warning(
                "Name resolution has found more than one data loader having the same name !\n"
                "In such cases, logs will not be properly generated. "
                "Please rename the item to have unique names.\n"
                f"Resolved name : {name}"
            )
        else:
            names.add(name)  # we need just hash key check, value is just a placeholder

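
# Hedged usage sketch: the duplicate "val" entry below triggers the warning above, while a fully
# unique list passes silently. The data loader names are illustrative.
def _example_unique_names_check():
    unique_names_check(["train", "val", "val"])   # warns: duplicate "val"
    unique_names_check(["train", "val", "test"])  # no warning
    unique_names_check(None)                      # no-op
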
def verify_runtime(
    output,
    input_list,
    input_dict,
    input_names,
    output_names,
    output_example,
    check_tolerance=0.01,
):
    """
    Verify runtime output with onnxrt.

    Parameters
    ----------
    output: Path to the exported ONNX model (as loaded by onnx.load).
    input_list: The input list of the module.
    input_dict: The input dict of the module.
    input_names: The input names of the module.
    output_names: The output names of the module.
    output_example: The output example of the module.
    check_tolerance: The tolerance for the check.

    Returns
    -------
    True if the onnxruntime outputs match the PyTorch outputs within tolerance, otherwise False.
    Returns None if onnxruntime is not available (the model is only checked with onnx.checker).
    """
    # Verify the model can be read, and is valid
    onnx_model = onnx.load(output)
    input_names = [node.name for node in onnx_model.graph.input]  # skipcq: PYL-W0622

    global ort_available

    if not ort_available:
        logging.warning(
            f"ONNX generated at {output}, not verified - please install onnxruntime_gpu package.\n"
        )
        onnx.checker.check_model(onnx_model, full_check=True)
        return

    onnx_session_opt = onnxruntime.SessionOptions()
    onnx_session_opt.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    sess = onnxruntime.InferenceSession(
        onnx_model.SerializeToString(), sess_options=onnx_session_opt, providers=["CUDAExecutionProvider"]
    )
    ort_out = sess.run(output_names, to_onnxrt_input(input_names, input_dict, input_list))
    all_good = True

    for i, out in enumerate(ort_out[0]):
        expected = output_example[i]
        if torch.is_tensor(expected):
            tout = torch.from_numpy(out)
            if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=100 * check_tolerance):
                all_good = False
                logging.info(f"onnxruntime results mismatch! PyTorch(expected):\n{expected}\nONNXruntime:\n{tout}")

    status = "SUCCESS" if all_good else "FAIL"
    logging.info(f"ONNX generated at {output} verified with onnxruntime : {status}")
    return all_good

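
# Hedged usage sketch: exporting a tiny module with torch.onnx.export and then checking the
# onnxruntime output against the PyTorch output. The input/output names and the temporary file
# path are illustrative; the comparison is skipped (with a warning) when onnxruntime is not
# installed, and the CUDA execution provider used above requires onnxruntime-gpu.
def _example_verify_runtime(onnx_path="model.onnx"):
    import torch

    model = torch.nn.Linear(4, 2).eval()
    example_input = torch.randn(1, 4)
    expected = model(example_input)

    torch.onnx.export(model, example_input, onnx_path, input_names=["input"], output_names=["output"])

    return verify_runtime(
        output=onnx_path,
        input_list=[example_input],
        input_dict={},
        input_names=["input"],
        output_names=["output"],
        output_example=(expected,),
        check_tolerance=0.01,
    )
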
    Parameters
    ----------
    model: Top-level model to replace modules in.
    expansions: A dictionary of module class names to functions to replace them with.

    Returns
    -------
    The model with replaced modules.
    """
    mapping: Dict[str, nn.Module] = {}
    for name, m in model.named_modules():
        m_type = type(m).__name__
        if m_type in expansions:  # type: ignore
            if swapped := expansions[m_type](m):  # type: ignore
                mapping[name] = swapped

    logging.warning(f"Swapped {len(mapping)} modules")
    swap_modules(model, mapping)
    return model


default_replacements = {
    "BatchNorm1d": wrap_module(nn.BatchNorm1d, CastToFloat),
    "BatchNorm2d": wrap_module(nn.BatchNorm2d, CastToFloat),
    "LayerNorm": wrap_module(nn.LayerNorm, CastToFloat),
}


def replace_for_export(model: nn.Module) -> nn.Module:
    """
    Top-level function to replace the default set of modules in a model.
    NOTE: This occurs in place; if you want to preserve the model, make sure to copy it first.
def prepare_lr_scheduler(
    optimizer: optim.Optimizer,
    scheduler_config: Union[Dict[str, Any], DictConfig, None],
    train_dataloader: Optional[dataloader.DataLoader] = None,
) -> Optional[Dict[str, Any]]:
    """
    Constructs an LR Scheduler (optionally) for a given optimizer, based on a config with the following schema.

    Parameters
    ----------
    optimizer: The optimizer to use for the scheduler.
        name: <name of optimizer>
        lr: <maximal learning rate>
        # <additional optimizer arguments>
        args:
            name: auto  # special keyword, resolves to correct optimizer config for given optimizer name
            # cls: mridc.core.config.optimizers.NovogradParams  # explicit instantiation by class path
            params:  # optional override parameters for the optimizer config
                betas: [0.8, 0.5]
                weight_decay: 0.001
    scheduler_config: The scheduler config.
        name: <name of scheduler>
        iters_per_batch: null  # computed at runtime; mandatory to have
        max_steps: null  # computed at runtime or explicitly set here; mandatory to have
        # pytorch lightning args <mandatory>
        monitor: val_loss
        reduce_on_plateau: false
        # <scheduler config override>
        args:
            name: auto  # special keyword, resolves to correct optimizer config for given optimizer name
            # cls: mridc.core.config.schedulers.CosineAnnealingParams  # explicit instantiation by class path
            params:  # optional override parameters for the optimizer config
                warmup_steps: null
                warmup_ratio: null
                min_lr: 0.0
                last_epoch: -1
    train_dataloader: Optional requirement, must be passed if "iters_per_batch" is defined instead of "max_steps". \
        Used to compute effective "max_steps".

    Returns
    -------
    A dictionary containing the LR Scheduler implementation if the config was successfully parsed along with other \
    parameters required by Pytorch Lightning, otherwise None.
    """
    if scheduler_config is not None:
        scheduler_config = maybe_update_config_version(scheduler_config)

    # Build nested dictionary for convenience out of structured objects
    if isinstance(scheduler_config, DictConfig):
        scheduler_config = OmegaConf.to_container(scheduler_config, resolve=True)
    elif dataclasses.is_dataclass(scheduler_config):
        # Recursively transform data classes to basic dictionaries
        scheduler_config = OmegaConf.create(scheduler_config)
        scheduler_config = OmegaConf.to_container(scheduler_config, resolve=True)

    # Test to see if config follows above schema
    add_max_args_flag = True
    interval = "step"
    if scheduler_config is not None:
        if "args" in scheduler_config:
            scheduler_args = scheduler_config.pop("args")
        else:
            scheduler_args = copy.deepcopy(scheduler_config)

            # Remove extra parameters from scheduler_args nest
            # Assume all other parameters are to be passed into scheduler constructor
            if "name" in scheduler_args and scheduler_args["name"] == "ReduceLROnPlateau":
                add_max_args_flag = False
                interval = "epoch"

            scheduler_args.pop("name", None)
            scheduler_args.pop("t_max_epochs", None)
            scheduler_args.pop("t_accumulate_grad_batches", None)
            scheduler_args.pop("t_limit_train_batches", None)
            scheduler_args.pop("t_num_workers", None)
            scheduler_args.pop("monitor", None)
            scheduler_args.pop("reduce_on_plateau", None)
    else:
        # Return gracefully in case `sched` was not supplied; inform user
        logging.info("Scheduler not initialized as no `sched` config supplied to setup_optimizer()")
        return None

    # Try instantiation of scheduler params from config class path
    if "_target_" in scheduler_args:
        scheduler_args_cfg = OmegaConf.create(scheduler_args)
        scheduler_conf = hydra.utils.instantiate(scheduler_args_cfg)
        scheduler_args = vars(scheduler_conf)

        # Get name of the scheduler
        scheduler_name = scheduler_conf.__class__.__name__

        if "Params" in scheduler_name:
            scheduler_name = scheduler_name.replace("Params", "")
    else:
        # Class path instantiation failed; try resolving "name" component

        # Get name of the scheduler
        if "name" in scheduler_config:
            scheduler_name = scheduler_config["name"]
        else:
            logging.warning(
                "Could not resolve classpath for Scheduler Config, and `name` "
                "was not provided either. \n"
                "Scheduler cannot be instantiated !"
            )
            return None

        # If class path was not provided, perhaps `name` is provided for resolution
        if "name" in scheduler_args:
            # If `auto` is passed as name for resolution of optimizer name,
            # then lookup optimizer name and resolve its parameter config
            if scheduler_args["name"] == "auto":
                scheduler_params_name = f"{scheduler_name}Params"
            else:
                scheduler_params_name = scheduler_args["name"]

            # Get override arguments provided in the config yaml file / Dict Config
            scheduler_params_override = scheduler_args.get("params", {})

            # If params is itself a dict config object provided explicitly in Dict Config
            # Resolve to dictionary for convenience
            if isinstance(scheduler_params_override, DictConfig):
                scheduler_params_override = OmegaConf.to_container(scheduler_params_override, resolve=True)

            # Get and instantiate the Config dataclass for this scheduler
            scheduler_params_cls = get_scheduler_config(scheduler_params_name, **scheduler_params_override)
            scheduler_params = scheduler_params_cls  # instantiate the parameters object
            scheduler_args = vars(scheduler_params)  # extract just the dictionary from the Config object

    # Extract value to monitor in losses, if provided.
    if "monitor" in scheduler_config:
        monitor = scheduler_config.get("monitor")
    else:
        # Default to train loss
        monitor = "loss"

    # Store exact max_steps if it is provided
    if "max_steps" in scheduler_config and scheduler_config["max_steps"] is not None:
        max_steps = scheduler_config["max_steps"]
    elif "t_max_epochs" in scheduler_config:
        # Compute effective max_steps if t_max_epochs is provided
        if train_dataloader is None:
            logging.warning(
                "As `t_max_epochs` is provided/computed, it is required to pass the train dataloader in order\n"
                "to compute effective maximum number of steps.\n"
                "Scheduler will not be instantiated !"
            )
            return None

        # Raise exception if neither `max_steps` nor `t_max_epochs` is provided
        if scheduler_config.get("t_max_epochs", None) is None:
            logging.warning(
                "`t_max_epochs` cannot be None when `max_steps` is not provided.\n"
                "This can occur when `train dataloader` is not available to correctly "
                "prepare the scheduler.\n"
                "Scheduler will not be instantiated !"
            )
            return None

        # Get iters_per_batch
        max_epochs = scheduler_config.get("t_max_epochs")
        accumulate_grad_batches = scheduler_config.get("t_accumulate_grad_batches")
        limit_train_batches = scheduler_config.get("t_limit_train_batches")
        num_workers = scheduler_config.get("t_num_workers")

        # Compute effective num max_steps
        num_samples = len(train_dataloader.dataset)  # type: ignore
        # we may need to override ModelPT setup_optimization
        if train_dataloader.batch_size is not None:
            batch_size = train_dataloader.batch_size
        elif hasattr(train_dataloader, "batch_sampler") and train_dataloader.batch_sampler is not None:
            if train_dataloader.batch_sampler.micro_batch_size is not None:
                batch_size = train_dataloader.batch_sampler.micro_batch_size
            else:
                raise ValueError(
                    f"Could not find batch_size from batch_sampler: {train_dataloader.batch_sampler}"
                )
        else:
            raise ValueError(f"Could not find batch_size from train_dataloader: {train_dataloader}")
        drop_last = train_dataloader.drop_last

        max_steps = compute_max_steps(
            max_epochs=max_epochs,
            accumulate_grad_batches=accumulate_grad_batches,
            limit_train_batches=limit_train_batches,
            num_workers=num_workers,
            num_samples=num_samples,
            batch_size=batch_size,
            drop_last=drop_last,
        )
    else:
        logging.warning(
            "Neither `max_steps` nor `iters_per_batch` were provided to `optim.sched`, "
            "cannot compute effective `max_steps` !\n"
            "Scheduler will not be instantiated !"
        )
        return None

    # Inject max_steps (effective or provided) into the scheduler config
    if add_max_args_flag and scheduler_config.get("name", "") != "ExponentialLR":
        scheduler_args["max_steps"] = max_steps

    # Get the scheduler class from the config
    scheduler_cls = get_scheduler(scheduler_name, **scheduler_args)

    # Instantiate the LR schedule
    schedule = scheduler_cls(optimizer, **scheduler_args)

    logging.info(
        'Scheduler "%s" \nwill be used during training (effective maximum steps = %d) - \nParameters : \n(%s)',
        str(schedule),
        max_steps,
        OmegaConf.to_yaml(OmegaConf.create(scheduler_args)),
    )

    # Wrap the schedule in PTL arguments to perform stepwise computation
    # Rather than epoch level computation
    reduce_lr_on_plateau = isinstance(schedule, optim.lr_scheduler.ReduceLROnPlateau)

    return {
        "scheduler": schedule,
        "interval": interval,
        "frequency": 1,
        "monitor": monitor,
        "reduce_on_plateau": reduce_lr_on_plateau,
    }

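
# Hedged usage sketch: a minimal `sched` config resolved by name, with `max_steps` given
# explicitly so no train dataloader is required. The scheduler name "CosineAnnealing" and the
# metric "val_loss" are illustrative assumptions; the exact set of registered schedulers is
# defined by `get_scheduler` in this module.
def _example_prepare_lr_scheduler():
    import torch
    from omegaconf import OmegaConf

    model = torch.nn.Linear(8, 1)
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)

    sched_cfg = OmegaConf.create(
        {
            "name": "CosineAnnealing",  # assumed to be a registered scheduler name
            "max_steps": 1000,          # provided explicitly, so no dataloader is needed
            "monitor": "val_loss",
            "min_lr": 1e-6,
        }
    )
    return prepare_lr_scheduler(optimizer=opt, scheduler_config=sched_cfg, train_dataloader=None)
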
def configure_checkpointing(trainer: Trainer, log_dir: Path, name: str, resume: bool, params: "DictConfig"):
    """
    Adds ModelCheckpoint to trainer. Raises CheckpointMisconfigurationError if trainer already has a ModelCheckpoint
    callback or if trainer.weights_save_path was passed to Trainer.
    """
    for callback in trainer.callbacks:
        if isinstance(callback, ModelCheckpoint):
            raise CheckpointMisconfigurationError(
                "The pytorch lightning trainer that was passed to exp_manager contained a ModelCheckpoint "
                "and create_checkpoint_callback was set to True. Please either set create_checkpoint_callback "
                "to False, or remove ModelCheckpoint from the lightning trainer"
            )

    if Path(trainer.weights_save_path) != Path.cwd():
        raise CheckpointMisconfigurationError(
            "The pytorch lightning trainer was passed weights_save_path. This variable is ignored by exp_manager"
        )

    # Create the callback and attach it to trainer
    if "filepath" in params:
        if params.filepath is not None:
            logging.warning("filepath is deprecated. Please switch to dirpath and filename instead")
            if params.dirpath is None:
                params.dirpath = Path(params.filepath).parent
            if params.filename is None:
                params.filename = Path(params.filepath).name
        with open_dict(params):
            del params["filepath"]

    if params.dirpath is None:
        params.dirpath = Path(log_dir / "checkpoints")
    if params.filename is None:
        params.filename = f"{name}--{{{params.monitor}:.4f}}-{{epoch}}"
    if params.prefix is None:
        params.prefix = name

    MRIDCModelCheckpoint.CHECKPOINT_NAME_LAST = f"{params.filename}-last"

    logging.debug(params.dirpath)
    logging.debug(params.filename)
    logging.debug(params.prefix)

    if "val" in params.monitor:
        if (
            trainer.max_epochs is not None
            and trainer.max_epochs != -1
            and trainer.max_epochs < trainer.check_val_every_n_epoch
        ):
            logging.error(
                "The checkpoint callback was told to monitor a validation value but trainer.max_epochs("
                f"{trainer.max_epochs}) was less than trainer.check_val_every_n_epoch("
                f"{trainer.check_val_every_n_epoch}). It is very likely this run will fail with "
                f"ModelCheckpoint(monitor='{params.monitor}') not found in the returned metrics. Please ensure that "
                "validation is run within trainer.max_epochs."
            )
        elif trainer.max_steps is not None:
            logging.warning(
                "The checkpoint callback was told to monitor a validation value and trainer's max_steps was set to "
                f"{trainer.max_steps}. Please ensure that max_steps will run for at least "
                f"{trainer.check_val_every_n_epoch} epochs to ensure that checkpointing will not error out."
            )

    checkpoint_callback = MRIDCModelCheckpoint(n_resume=resume, **params)
    checkpoint_callback.last_model_path = trainer._checkpoint_connector.resume_from_checkpoint_fit_path or ""
    if "mp_rank" in checkpoint_callback.last_model_path or "tp_rank" in checkpoint_callback.last_model_path:
        checkpoint_callback.last_model_path = mridc.utils.model_utils.uninject_model_parallel_rank(  # type: ignore
            checkpoint_callback.last_model_path
        )
    trainer.callbacks.append(checkpoint_callback)

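
# Hedged usage sketch: the checkpointing `params` that exp_manager would pass in, expressed
# directly as a DictConfig. The keys mirror ModelCheckpoint arguments plus the fields consumed
# above (dirpath/filename/prefix/monitor); the values and the log directory are illustrative.
def _example_configure_checkpointing(trainer):
    from pathlib import Path
    from omegaconf import OmegaConf

    params = OmegaConf.create(
        {
            "monitor": "val_loss",
            "save_top_k": 3,
            "mode": "min",
            "dirpath": None,   # filled in from log_dir / "checkpoints"
            "filename": None,  # filled in from the experiment name and monitor
            "prefix": None,    # filled in from the experiment name
            "save_last": True,
        }
    )
    configure_checkpointing(trainer, Path("mridc_experiments/default"), "default", resume=False, params=params)
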
def get_log_dir(
    trainer: Trainer,
    exp_dir: str = None,
    name: str = None,
    version: str = None,
    explicit_log_dir: str = None,
    use_datetime_version: bool = True,
    resume_if_exists: bool = False,
) -> Tuple[Path, str, str, str]:
    """
    Obtains the log_dir used for exp_manager.

    Parameters
    ----------
    trainer: The trainer to check.
    exp_dir: The experiment directory to check.
    name: The experiment name to check.
    version: The experiment version to check.
    explicit_log_dir: The explicit log dir to check.
    use_datetime_version: Whether to use datetime versioning.
    resume_if_exists: Whether to resume if the log_dir already exists.

    Raises
    ------
    LoggerMisconfigurationError: If trainer is incompatible with arguments.
    NotFoundError: If resume is True, resume_ignore_no_checkpoint is False, and checkpoints could not be found.
    ValueError: If resume is True, and more than one checkpoint was found.
    """
    if explicit_log_dir:  # If explicit log_dir was passed, short circuit
        return check_explicit_log_dir(trainer, [Path(explicit_log_dir)], exp_dir, name, version)  # type: ignore

    # Default exp_dir to ./mridc_experiments if None was passed
    _exp_dir = exp_dir
    if exp_dir is None:
        _exp_dir = str(Path.cwd() / "mridc_experiments")

    # If the user has already defined a logger for the trainer, use the logger defaults for logging directory
    if trainer.logger is not None:
        if trainer.logger.save_dir:
            if exp_dir:
                raise LoggerMisconfigurationError(
                    "The pytorch lightning trainer that was passed to exp_manager contained a logger, the logger's "
                    f"save_dir was not None, and exp_dir ({exp_dir}) was not None. If trainer.logger.save_dir "
                    "exists, exp_manager will use trainer.logger.save_dir as the logging directory and exp_dir "
                    "must be None."
                )
            _exp_dir = trainer.logger.save_dir
        if name:
            raise LoggerMisconfigurationError(
                "The pytorch lightning trainer that was passed to exp_manager contained a logger, and name: "
                f"{name} was also passed to exp_manager. If the trainer contains a "
                "logger, exp_manager will use trainer.logger.name, and name passed to exp_manager must be None."
            )
        name = trainer.logger.name
        version = f"version_{trainer.logger.version}"
    # Use user-defined exp_dir, project_name, exp_name, and versioning options
    else:
        name = name or "default"
        version = version or os.environ.get(MRIDC_ENV_VARNAME_VERSION)

        if not version:
            if resume_if_exists:
                logging.warning(
                    "No version folders would be created under the log folder as 'resume_if_exists' is enabled."
                )
                version = None
            elif is_global_rank_zero():
                if use_datetime_version:
                    version = time.strftime("%Y-%m-%d_%H-%M-%S")
                else:
                    tensorboard_logger = TensorBoardLogger(save_dir=_exp_dir, name=name, version=version)
                    version = f"version_{tensorboard_logger.version}"
                os.environ[MRIDC_ENV_VARNAME_VERSION] = "" if version is None else version

    log_dir = Path(str(_exp_dir)) / Path(str(name)) / Path("" if version is None else str(version))
    return log_dir, str(_exp_dir), str(name), str(version)

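
# Hedged usage sketch: resolving the logging directory for a bare Trainer with no logger attached.
# With datetime versioning the result is expected to look like
# ./mridc_experiments/my_experiment/2023-01-01_12-00-00 (the timestamp varies); the experiment
# name is illustrative.
def _example_get_log_dir():
    from pytorch_lightning import Trainer

    trainer = Trainer(logger=False)
    log_dir, exp_dir, name, version = get_log_dir(
        trainer,
        exp_dir=None,          # defaults to ./mridc_experiments
        name="my_experiment",  # illustrative experiment name
        use_datetime_version=True,
    )
    return log_dir
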
def check_resume(
    trainer: Trainer,
    log_dir: str,
    resume_past_end: bool = False,
    resume_ignore_no_checkpoint: bool = False,
):
    """
    Checks that resume=True was used correctly with the arguments passed to exp_manager. Sets
    trainer._checkpoint_connector.resume_from_checkpoint_fit_path as necessary.

    Parameters
    ----------
    trainer: The trainer that is being used.
    log_dir: The directory where the logs are being saved.
    resume_past_end: Whether to resume from the end of the experiment.
    resume_ignore_no_checkpoint: Whether to ignore if there is no checkpoint to resume from.

    Raises
    ------
    NotFoundError: If resume is True, resume_ignore_no_checkpoint is False, and checkpoints could not be found.
    ValueError: If resume is True, and more than one checkpoint was found.
    """
    if not log_dir:
        raise ValueError(f"Resuming requires the log_dir {log_dir} to be passed to exp_manager")

    checkpoint_dir = Path(Path(log_dir) / "checkpoints")

    checkpoint = None
    end_checkpoints = list(checkpoint_dir.rglob("*end.ckpt"))
    last_checkpoints = list(checkpoint_dir.rglob("*last.ckpt"))

    if not checkpoint_dir.exists():
        if not resume_ignore_no_checkpoint:
            raise NotFoundError(
                f"There was no checkpoint folder at checkpoint_dir :{checkpoint_dir}. Cannot resume."
            )
        logging.warning(
            f"There was no checkpoint folder at checkpoint_dir :{checkpoint_dir}. Training from scratch."
        )
        return

    if end_checkpoints:
        if not resume_past_end:
            raise ValueError(
                f"Found {end_checkpoints[0]} indicating that the last training run has already completed."
            )
        if len(end_checkpoints) > 1:
            if "mp_rank" in str(end_checkpoints[0]):
                checkpoint = end_checkpoints[0]
            else:
                raise ValueError(f"Multiple checkpoints {end_checkpoints} match *end.ckpt.")
        logging.info(f"Resuming from {end_checkpoints[0]}")
    elif not last_checkpoints:
        if not resume_ignore_no_checkpoint:
            raise NotFoundError(f"There were no checkpoints found in {checkpoint_dir}. Cannot resume.")
        logging.warning(f"There were no checkpoints found in {checkpoint_dir}. Training from scratch.")
        return
    elif len(last_checkpoints) > 1:
        if "mp_rank" not in str(last_checkpoints[0]) and "tp_rank" not in str(last_checkpoints[0]):
            raise ValueError(f"Multiple checkpoints {last_checkpoints} match *last.ckpt.")
        checkpoint = last_checkpoints[0]
        checkpoint = mridc.utils.model_utils.uninject_model_parallel_rank(checkpoint)  # type: ignore
    else:
        logging.info(f"Resuming from {last_checkpoints[0]}")
        checkpoint = last_checkpoints[0]

    trainer._checkpoint_connector.resume_from_checkpoint_fit_path = str(checkpoint)

    if is_global_rank_zero():
        if files_to_move := [child for child in Path(log_dir).iterdir() if child.is_file()]:
            # Move old files to a new folder
            other_run_dirs = Path(log_dir).glob("run_*")
            run_count = sum(bool(fold.is_dir()) for fold in other_run_dirs)
            new_run_dir = Path(Path(log_dir) / f"run_{run_count}")
            new_run_dir.mkdir()
            for _file in files_to_move:
                move(str(_file), str(new_run_dir))

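
# Hedged usage sketch: resuming a run whose logs live under an existing experiment directory.
# The path is illustrative; the checkpoints folder is expected to contain at most one
# "*last.ckpt" unless model parallelism is in use.
def _example_check_resume(trainer):
    check_resume(
        trainer,
        log_dir="mridc_experiments/default/2023-01-01_12-00-00",  # illustrative path
        resume_past_end=False,
        resume_ignore_no_checkpoint=True,  # fall back to training from scratch if nothing is found
    )
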