예제 #1
    def on_train_end(self, trainer, pl_module):
        This is called at the end of training.

        trainer: The trainer object.
        pl_module: The PyTorch-Lightning module.
        if trainer.fast_dev_run:
            return None

        # Call parent on_train_end() to save the -last checkpoint
        super().on_train_end(trainer, pl_module)

        # Load the best model and then re-save it
        if self.save_best_model:
            # wait for all processes to finish
            if self.best_model_path == "":
                    f"{self} was told to save the best checkpoint at the end of training, but no saved checkpoints "
                    "were found. Saving latest model instead.")

        if self.save_mridc_on_train_end:
                save_path=os.path.join(self.dirpath, self.prefix +
예제 #2
def _convert_config(cfg: "OmegaConf"):
    """Recursive function converting the configuration from old hydra format to the new one."""
    if not _HAS_HYDRA:
            "This function requires Hydra/OmegaConf and it was not installed.")

    # Get rid of cls -> _target_.
    if "cls" in cfg and "_target_" not in cfg:
        cfg._target_ = cfg.pop("cls")  # type: ignore

    # Get rid of params.
    if "params" in cfg:
        params = cfg.pop("params")  # type: ignore
        for param_key, param_val in params.items():
            cfg[param_key] = param_val

    # Recursion.
        for _, sub_cfg in cfg.items():  # type: ignore
            if isinstance(sub_cfg, DictConfig):
                _convert_config(sub_cfg)  # type: ignore
    except OmegaConfBaseException as e:
            f"Skipped conversion for config/subconfig:\n{cfg}\n Reason: {e}.")
예제 #3
    def __init__(

        always_save_mridc: Whether to save the model even if it is not the best model. Default: False.
        save_mridc_on_train_end: Whether to save the model at the end of training. Default: True.
        save_best_model: Whether to save the model if it is the best model. Default: False.
        postfix: The postfix to add to the model name. Default: ".mridc".
        n_resume: Whether to resume training from a checkpoint. Default: False.
        model_parallel_size: The size of the model parallel group. Default: None.
        kwargs: The kwargs to pass to ModelCheckpoint.
        # Parse and store "extended" parameters: save_best model and postfix.
        self.always_save_mridc = always_save_mridc
        self.save_mridc_on_train_end = save_mridc_on_train_end
        self.save_best_model = save_best_model
        self.previous_model_path = None
        self.last_model_path: Union[Any, str] = None
        if self.save_best_model and not self.save_mridc_on_train_end:
                "Found save_best_model is True and save_mridc_on_train_end is False. "
                "Set save_mridc_on_train_end to True to automatically save the best model."
        self.postfix = postfix
        self.previous_best_path = ""
        self.model_parallel_size = model_parallel_size

        # `prefix` is deprecated
        self.prefix = kwargs.pop("prefix") if "prefix" in kwargs else ""
        # Call the parent class constructor with the remaining kwargs.

        if self.save_top_k != -1 and n_resume:
            logging.debug("Checking previous runs")
예제 #4
    def wrapped(x):
        Wrapper function.

        x: The class to be decorated.

        The decorated class with the experimental flag set.
            f"Module {x} is experimental, not ready for production and is not fully supported. Use at your own risk."

        return x
예제 #5
파일: lr_scheduler.py 프로젝트: wdika/mridc
def compute_max_steps(max_epochs, accumulate_grad_batches, limit_train_batches,
                      num_workers, num_samples, batch_size, drop_last):
    """Compute effective max_steps from the provided parameters.

    The number of steps per epoch is derived in four stages:

    1. Split the dataset across ``num_workers`` (treated as at least 1), rounding up.
    2. Divide that per-worker sample count into batches, flooring when ``drop_last``
       is set and ceiling otherwise.
    3. Optionally limit the result by ``limit_train_batches``.
    4. Divide by the gradient-accumulation factor (rounding up) and multiply by the
       number of epochs.

    Parameters
    ----------
    max_epochs: Number of training epochs.
    accumulate_grad_batches: Number of batches accumulated per optimizer step.
    limit_train_batches: If an ``int`` (or exactly ``0.0``), an absolute cap on the number of batches per epoch;
        otherwise a fraction of the batches per epoch to use.
    num_workers: Number of workers the samples are split across; values below 1 are treated as 1.
    num_samples: Total number of samples in the training dataset.
    batch_size: Number of samples per batch.
    drop_last: Whether the last, incomplete batch is dropped.

    Returns
    -------
    The effective maximum number of steps.
    """
    import warnings

    _round = math.floor if drop_last else math.ceil

    # Per-worker sample count, rounded up; guard against num_workers <= 0.
    sampler_num_samples = math.ceil(num_samples / max(1, num_workers))

    if drop_last and num_workers > 1:
        # Actually surface the caveat; a bare string expression here would be a silent no-op.
        warnings.warn(
            "Please note that drop_last is broken in pytorch 1.6.0. We will fix when pytorch 1.7.0 is released",
            stacklevel=2,
        )

    steps_per_epoch = _round(sampler_num_samples / batch_size)
    if isinstance(limit_train_batches, int) or limit_train_batches == 0.0:
        # An integer (or 0.0) limit is an absolute cap on the number of batches per epoch.
        steps_per_epoch = min(steps_per_epoch, int(limit_train_batches))
    elif steps_per_epoch != float("inf"):
        # limit_train_batches is a percentage of batches per epoch
        steps_per_epoch = int(steps_per_epoch * limit_train_batches)

    # One optimizer step per `accumulate_grad_batches` batches, rounded up, for every epoch.
    return math.ceil(steps_per_epoch / accumulate_grad_batches) * max_epochs
예제 #6
def check_explicit_log_dir(trainer: Trainer,
                           explicit_log_dir: List[Union[Path, str]],
                           exp_dir: str, name: str,
                           version: str) -> Tuple[Path, str, str, str]:
    Checks that the passed arguments are compatible with explicit_log_dir.

    trainer: The trainer to check.
    explicit_log_dir: The explicit log dir to check.
    exp_dir: The experiment directory to check.
    name: The experiment name to check.
    version: The experiment version to check.

    The log_dir, exp_dir, name, and version that should be used.

    if trainer.logger is not None:
        raise LoggerMisconfigurationError(
            "The pytorch lightning trainer that was passed to exp_manager contained a logger and explicit_log_dir: "
            f"{explicit_log_dir} was pass to exp_manager. Please remove the logger from the lightning trainer."
    # Checking only (explicit_log_dir) vs (exp_dir and version).
    # The `name` will be used as the actual name of checkpoint/archive.
    if exp_dir or version:
            f"exp_manager received explicit_log_dir: {explicit_log_dir} and at least one of exp_dir: {exp_dir}, "
            f"or version: {version}. Please note that exp_dir, name, and version will be ignored."
    if is_global_rank_zero() and Path(str(explicit_log_dir)).exists():
            f"Exp_manager is logging to {explicit_log_dir}, but it already exists."
    return Path(str(explicit_log_dir)), str(explicit_log_dir), "", ""
예제 #7
파일: deprecated.py 프로젝트: wdika/mridc
    def wrapper(_wrapped, args, kwargs):
        Prints the adequate warning (only once per function) when required and calls the function func, passing the
        original arguments, i.e. version and explanation.

        _wrapped: The function to be decorated.
        args: The arguments passed to the function to be decorated.
        kwargs: The keyword arguments passed to the function to be decorated.

        The decorated function.
        # Check if we already warned about that function.
        if _wrapped.__name__ not in _PRINTED_WARNING:
            # Add to list so we won't print it again.
            _PRINTED_WARNING[_wrapped.__name__] = True

            # Prepare the warning message.
            entity_name = "Class" if inspect.isclass(wrapped) else "Function"
            msg = f"{entity_name} '{_wrapped.__name__}' is deprecated."

            # Optionally, add version and explanation.
            if version is not None:
                msg = f"{msg} It is going to be removed in the {version} version."

            if explanation is not None:
                msg = f"{msg} {explanation}"

            # Display the deprecated warning.

        # Call the function.
        return _wrapped(*args, **kwargs)
예제 #8
def unique_names_check(name_list: Optional[List[str]]):
    Performs a uniqueness check on the name list resolved, so that it can warn users about non-unique keys.

    name_list: List of strings resolved for data loaders.
    if name_list is None:

    # Name uniqueness checks
    names = set()
    for name in name_list:
        if name in names:
                "Name resolution has found more than one data loader having the same name !\n"
                "In such cases, logs will nor be properly generated. "
                "Please rename the item to have unique names.\n"
                f"Resolved name : {name}")
            )  # we need just hash key check, value is just a placeholder
예제 #9
def verify_runtime(
    Verify runtime output with onnxrt.

    output: The output of the module.
    input_list: The input list of the module.
    input_dict: The input dict of the module.
    input_names: The input names of the module.
    output_names: The output names of the module.
    output_example: The output example of the module.
    check_tolerance: The tolerance for the check.

    The runtime output.
    # Verify the model can be read, and is valid
    onnx_model = onnx.load(output)
    input_names = [node.name for node in onnx_model.graph.input]
    # skipcq: PYL-W0622
    global ort_available
    if not ort_available:
            f"ONNX generated at {output}, not verified - please install onnxruntime_gpu package.\n"
        onnx.checker.check_model(onnx_model, full_check=True)

    onnx_session_opt = onnxruntime.SessionOptions()
    onnx_session_opt.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL

    sess = onnxruntime.InferenceSession(onnx_model.SerializeToString(),
    ort_out = sess.run(output_names,
                       to_onnxrt_input(input_names, input_dict, input_list))
    all_good = True

    for i, out in enumerate(ort_out[0]):
        expected = output_example[i]
        if torch.is_tensor(expected):
            tout = torch.from_numpy(out)
            if not torch.allclose(tout,
                                  atol=100 * check_tolerance):
                all_good = False
                    f"onnxruntime results mismatch! PyTorch(expected):\n{expected}\nONNXruntime:\n{tout}"
    status = "SUCCESS" if all_good else "FAIL"
        f"ONNX generated at {output} verified with onnxruntime : {status}")
    return all_good
예제 #10
    model: Top-level model to replace modules in.
    expansions: A dictionary of module class names to functions to replace them with.

    The model with replaced modules.
    mapping: Dict[str, nn.Module] = {}
    for name, m in model.named_modules():
        m_type = type(m).__name__
        if m_type in expansions:  # type: ignore
            if swapped := expansions[m_type](m):  # type: ignore
                mapping[name] = swapped
    logging.warning(f"Swapped {len(mapping)} modules")
    swap_modules(model, mapping)
    return model

default_replacements = {
    "BatchNorm1d": wrap_module(nn.BatchNorm1d, CastToFloat),
    "BatchNorm2d": wrap_module(nn.BatchNorm2d, CastToFloat),
    "LayerNorm": wrap_module(nn.LayerNorm, CastToFloat),

def replace_for_export(model: nn.Module) -> nn.Module:
    Top-level function to replace default set of modules in model
    NOTE: This occurs in place, if you want to preserve model then make sure to copy it first.
예제 #11
파일: lr_scheduler.py 프로젝트: wdika/mridc
def prepare_lr_scheduler(
    optimizer: optim.Optimizer,
    scheduler_config: Union[Dict[str, Any], DictConfig, None],
    train_dataloader: Optional[dataloader.DataLoader] = None,
) -> Optional[Dict[str, Any]]:
    Constructs an LR Scheduler (optionally) for a given optimizer, based on a config with the following schema.

    optimizer: The optimizer to use for the scheduler.
        name: <name of optimizer>

        lr: <maximal learning rate>

        # <additional optimizer arguments>


            name: auto  # special keyword, resolves to correct optimizer config for given optimizer name

            # cls: mridc.core.config.optimizers.NovogradParams  # explicit instantiation by class path

            params:  # optional override parameters for the optimizer config

                betas: [0.8, 0.5]

                weight_decay: 0.001

    scheduler_config: The scheduler config.

        name: <name of scheduler>

        iters_per_batch: null # computed at runtime; mandatory to have

        max_steps: null # computed at runtime or explicitly set here; mandatory to have

        # pytorch lightning args <mandatory>

        monitor: val_loss

        reduce_on_plateau: false

        # <scheduler config override>


            name: auto  # special keyword, resolves to correct optimizer config for given optimizer name

            # cls: mridc.core.config.schedulers.CosineAnnealingParams  # explicit instantiation by class path

            params:  # optional override parameters for the optimizer config

                warmup_steps: null

                warmup_ratio: null

                min_lr: 0.0

                last_epoch: -1

    train_dataloader: Optional requirement, must be passed if "iters_per_batch" is defined instead of "max_steps". \
    Used to compute effective "max_steps".

    A dictionary containing the LR Scheduler implementation if the config was successfully parsed along with other \
    parameters required by Pytorch Lightning, otherwise None.
    if scheduler_config is not None:
        scheduler_config = maybe_update_config_version(scheduler_config)

    # Build nested dictionary for convenience out of structured objects
    if isinstance(scheduler_config, DictConfig):
        scheduler_config = OmegaConf.to_container(scheduler_config,

    elif dataclasses.is_dataclass(scheduler_config):
        # Recursively transform data classes to basic dictionaries
        scheduler_config = OmegaConf.create(scheduler_config)
        scheduler_config = OmegaConf.to_container(scheduler_config,

    # Test to see if config follows above schema

    add_max_args_flag = True
    interval = "step"
    if scheduler_config is not None:
        if "args" in scheduler_config:
            scheduler_args = scheduler_config.pop("args")
            scheduler_args = copy.deepcopy(scheduler_config)

            # Remove extra parameters from scheduler_args nest
            # Assume all other parameters are to be passed into scheduler constructor

            if "name" in scheduler_args and scheduler_args[
                    "name"] == "ReduceLROnPlateau":
                add_max_args_flag = False
                interval = "epoch"

            scheduler_args.pop("name", None)
            scheduler_args.pop("t_max_epochs", None)
            scheduler_args.pop("t_accumulate_grad_batches", None)
            scheduler_args.pop("t_limit_train_batches", None)
            scheduler_args.pop("t_num_workers", None)
            scheduler_args.pop("monitor", None)
            scheduler_args.pop("reduce_on_plateau", None)

        # Return gracefully in case `sched` was not supplied; inform user
            "Scheduler not initialized as no `sched` config supplied to setup_optimizer()"
        return None

    # Try instantiation of scheduler params from config class path
    if "_target_" in scheduler_args:
        scheduler_args_cfg = OmegaConf.create(scheduler_args)
        scheduler_conf = hydra.utils.instantiate(scheduler_args_cfg)
        scheduler_args = vars(scheduler_conf)

        # Get name of the scheduler
        scheduler_name = scheduler_conf.__class__.__name__

        if "Params" in scheduler_name:
            scheduler_name = scheduler_name.replace("Params", "")

        # Class path instantiation failed; try resolving "name" component

        # Get name of the scheduler
        if "name" in scheduler_config:
            scheduler_name = scheduler_config["name"]
                "Could not resolve classpath for Scheduler Config, and `name` "
                "was not provided either. \n"
                "Scheduler cannot be instantiated !")
            return None

        # If class path was not provided, perhaps `name` is provided for resolution
        if "name" in scheduler_args:
            # If `auto` is passed as name for resolution of optimizer name,
            # then lookup optimizer name and resolve its parameter config
            if scheduler_args["name"] == "auto":
                scheduler_params_name = f"{scheduler_name}Params"
                scheduler_params_name = scheduler_args["name"]

            # Get override arguments provided in the config yaml file / Dict Config
            scheduler_params_override = scheduler_args.get("params", {})

            # If params is itself a dict config object provided explicitly in Dict Config
            # Resolve to dictionary for convenience
            if isinstance(scheduler_params_override, DictConfig):
                scheduler_params_override = OmegaConf.to_container(
                    scheduler_params_override, resolve=True)

            # Get and instantiate the Config dataclass for this scheduler
            scheduler_params_cls = get_scheduler_config(
                scheduler_params_name, **scheduler_params_override)
            scheduler_params = scheduler_params_cls  # instantiate the parameters object
            scheduler_args = vars(
            )  # extract just the dictionary from the Config object

    # Extract value to monitor in losses, if provided.
    if "monitor" in scheduler_config:
        monitor = scheduler_config.get("monitor")
        # Default to train loss
        monitor = "loss"

    # Store exact max_steps if it is provided
    if "max_steps" in scheduler_config and scheduler_config[
            "max_steps"] is not None:
        max_steps = scheduler_config["max_steps"]

    elif "t_max_epochs" in scheduler_config:
        # Compute effective max_steps if t_max_epochs is provided
        if train_dataloader is None:
                "As `t_max_epochs` is provided/computed, it is required to pass the train dataloader in order\n"
                "to compute effective maximum number of steps.\n"
                "Scheduler will not be instantiated !")
            return None

        # Raise exception if neither `max_steps` nor `t_max_epochs` is provided
        if scheduler_config.get("t_max_epochs", None) is None:
                "`t_max_epochs` cannot be None when `max_steps` is not not provided.\n"
                "This can occur when `train dataloader` is not available to correctly "
                "prepare the scheduler.\n"
                "Scheduler will not be instantiated !")
            return None

        # Get iters_per_batch
        max_epochs = scheduler_config.get("t_max_epochs")
        accumulate_grad_batches = scheduler_config.get(
        limit_train_batches = scheduler_config.get("t_limit_train_batches")
        num_workers = scheduler_config.get("t_num_workers")

        # Compute effective num max_steps
        num_samples = len(train_dataloader.dataset)  # type: ignore

        # we may need to override ModelPT setup_optimization
        if train_dataloader.batch_size is not None:
            batch_size = train_dataloader.batch_size
        elif hasattr(train_dataloader, "batch_sampler"
                     ) and train_dataloader.batch_sampler is not None:
            if train_dataloader.batch_sampler.micro_batch_size is not None:
                batch_size = train_dataloader.batch_sampler.micro_batch_size
                raise ValueError(
                    f"Could not find batch_size from batch_sampler: {train_dataloader.batch_sampler}"
            raise ValueError(
                f"Could not find batch_size from train_dataloader: {train_dataloader}"
        drop_last = train_dataloader.drop_last

        max_steps = compute_max_steps(

            "Neither `max_steps` nor `iters_per_batch` were provided to `optim.sched`, "
            "cannot compute effective `max_steps` !\n"
            "Scheduler will not be instantiated !")
        return None

    # Inject max_steps (effective or provided) into the scheduler config
    if add_max_args_flag and scheduler_config.get("name",
                                                  "") != "ExponentialLR":
        scheduler_args["max_steps"] = max_steps

    # Get the scheduler class from the config
    scheduler_cls = get_scheduler(scheduler_name, **scheduler_args)

    # Instantiate the LR schedule
    schedule = scheduler_cls(optimizer, **scheduler_args)

        'Scheduler "%s" \nwill be used during training (effective maximum steps = %d) - \nParameters : \n(%s)',

    # Wrap the schedule in PTL arguments to perform stepwise computation
    # Rather than epoch level computation
    reduce_lr_on_plateau = isinstance(schedule,

    return {
        "scheduler": schedule,
        "interval": interval,
        "frequency": 1,
        "monitor": monitor,
        "reduce_on_plateau": reduce_lr_on_plateau,
예제 #12
def configure_checkpointing(trainer: Trainer, log_dir: Path, name: str,
                            resume: bool, params: "DictConfig"):
    """Adds ModelCheckpoint to trainer. Raises CheckpointMisconfigurationError if trainer already has a ModelCheckpoint
    callback or if trainer.weights_save_path was passed to Trainer.
    for callback in trainer.callbacks:
        if isinstance(callback, ModelCheckpoint):
            raise CheckpointMisconfigurationError(
                "The pytorch lightning trainer that was passed to exp_manager contained a ModelCheckpoint "
                "and create_checkpoint_callback was set to True. Please either set create_checkpoint_callback "
                "to False, or remove ModelCheckpoint from the lightning trainer"
    if Path(trainer.weights_save_path) != Path.cwd():
        raise CheckpointMisconfigurationError(
            "The pytorch lightning was passed weights_save_path. This variable is ignored by exp_manager"

    # Create the callback and attach it to trainer
    if "filepath" in params:
        if params.filepath is not None:
                "filepath is deprecated. Please switch to dirpath and filename instead"
            if params.dirpath is None:
                params.dirpath = Path(params.filepath).parent
            if params.filename is None:
                params.filename = Path(params.filepath).name
        with open_dict(params):
            del params["filepath"]
    if params.dirpath is None:
        params.dirpath = Path(log_dir / "checkpoints")
    if params.filename is None:
        params.filename = f"{name}--{{{params.monitor}:.4f}}-{{epoch}}"
    if params.prefix is None:
        params.prefix = name
    MRIDCModelCheckpoint.CHECKPOINT_NAME_LAST = f"{params.filename}-last"


    if "val" in params.monitor:
        if (trainer.max_epochs is not None and trainer.max_epochs != -1
                and trainer.max_epochs < trainer.check_val_every_n_epoch):
                "The checkpoint callback was told to monitor a validation value but trainer.max_epochs("
                f"{trainer.max_epochs}) was less than trainer.check_val_every_n_epoch("
                f"{trainer.check_val_every_n_epoch}). It is very likely this run will fail with "
                f"ModelCheckpoint(monitor='{params.monitor}') not found in the returned metrics. Please ensure that "
                "validation is run within trainer.max_epochs.")
        elif trainer.max_steps is not None:
                "The checkpoint callback was told to monitor a validation value and trainer's max_steps was set to "
                f"{trainer.max_steps}. Please ensure that max_steps will run for at least "
                f"{trainer.check_val_every_n_epoch} epochs to ensure that checkpointing will not error out."

    checkpoint_callback = MRIDCModelCheckpoint(n_resume=resume, **params)
    checkpoint_callback.last_model_path = trainer._checkpoint_connector.resume_from_checkpoint_fit_path or ""
    if "mp_rank" in checkpoint_callback.last_model_path or "tp_rank" in checkpoint_callback.last_model_path:
        checkpoint_callback.last_model_path = mridc.utils.model_utils.uninject_model_parallel_rank(  # type: ignore
예제 #13
def get_log_dir(
    trainer: Trainer,
    exp_dir: str = None,
    name: str = None,
    version: str = None,
    explicit_log_dir: str = None,
    use_datetime_version: bool = True,
    resume_if_exists: bool = False,
) -> Tuple[Path, str, str, str]:
    Obtains the log_dir used for exp_manager.

    trainer: The trainer to check.
    exp_dir: The experiment directory to check.
    name: The experiment name to check.
    version: The experiment version to check.
    explicit_log_dir: The explicit log dir to check.
    use_datetime_version: Whether to use datetime versioning.
    resume_if_exists: Whether to resume if the log_dir already exists.

    LoggerMisconfigurationError: If trainer is incompatible with arguments
    NotFoundError: If resume is True, resume_ignore_no_checkpoint is False, and checkpoints could not be found.
    ValueError: If resume is True, and there were more than 1 checkpoint could found.
    if explicit_log_dir:  # If explicit log_dir was passed, short circuit
        return check_explicit_log_dir(trainer, [Path(explicit_log_dir)],
                                      exp_dir, name, version)  # type: ignore

    # Default exp_dir to ./mridc_experiments if None was passed
    _exp_dir = exp_dir
    if exp_dir is None:
        _exp_dir = str(Path.cwd() / "mridc_experiments")

    # If the user has already defined a logger for the trainer, use the logger defaults for logging directory
    if trainer.logger is not None:
        if trainer.logger.save_dir:
            if exp_dir:
                raise LoggerMisconfigurationError(
                    "The pytorch lightning trainer that was passed to exp_manager contained a logger, the logger's "
                    f"save_dir was not None, and exp_dir ({exp_dir}) was not None. If trainer.logger.save_dir "
                    "exists, exp_manager will use trainer.logger.save_dir as the logging directory and exp_dir "
                    "must be None.")
            _exp_dir = trainer.logger.save_dir
        if name:
            raise LoggerMisconfigurationError(
                "The pytorch lightning trainer that was passed to exp_manager contained a logger, and name: "
                f"{name} was also passed to exp_manager. If the trainer contains a "
                "logger, exp_manager will use trainer.logger.name, and name passed to exp_manager must be None."
        name = trainer.logger.name
        version = f"version_{trainer.logger.version}"
    # Use user-defined exp_dir, project_name, exp_name, and versioning options
        name = name or "default"
        version = version or os.environ.get(MRIDC_ENV_VARNAME_VERSION)

        if not version:
            if resume_if_exists:
                    "No version folders would be created under the log folder as 'resume_if_exists' is enabled."
                version = None
            elif is_global_rank_zero():
                if use_datetime_version:
                    version = time.strftime("%Y-%m-%d_%H-%M-%S")
                    tensorboard_logger = TensorBoardLogger(save_dir=_exp_dir,
                    version = f"version_{tensorboard_logger.version}"
                    MRIDC_ENV_VARNAME_VERSION] = "" if version is None else version

    log_dir = Path(str(_exp_dir)) / Path(
        str(name)) / Path("" if version is None else str(version))
    return log_dir, str(_exp_dir), str(name), str(version)
예제 #14
def check_resume(
    trainer: Trainer,
    log_dir: str,
    resume_past_end: bool = False,
    resume_ignore_no_checkpoint: bool = False,
    Checks that resume=True was used correctly with the arguments pass to exp_manager. Sets
    trainer._checkpoint_connector.resume_from_checkpoint_fit_path as necessary.

    trainer: The trainer that is being used.
    log_dir: The directory where the logs are being saved.
    resume_past_end: Whether to resume from the end of the experiment.
    resume_ignore_no_checkpoint: Whether to ignore if there is no checkpoint to resume from.

    NotFoundError: If resume is True, resume_ignore_no_checkpoint is False, and checkpoints could not be found.
    ValueError: If resume is True, and there were more than 1 checkpoint could found.
    if not log_dir:
        raise ValueError(
            f"Resuming requires the log_dir {log_dir} to be passed to exp_manager"

    checkpoint_dir = Path(Path(log_dir) / "checkpoints")

    checkpoint = None
    end_checkpoints = list(checkpoint_dir.rglob("*end.ckpt"))
    last_checkpoints = list(checkpoint_dir.rglob("*last.ckpt"))
    if not checkpoint_dir.exists():
        if not resume_ignore_no_checkpoint:
            raise NotFoundError(
                f"There was no checkpoint folder at checkpoint_dir :{checkpoint_dir}. Cannot resume."
            f"There was no checkpoint folder at checkpoint_dir :{checkpoint_dir}. Training from scratch."
    if end_checkpoints:
        if not resume_past_end:
            raise ValueError(
                f"Found {end_checkpoints[0]} indicating that the last training run has already completed."
        if len(end_checkpoints) > 1:
            if "mp_rank" in str(end_checkpoints[0]):
                checkpoint = end_checkpoints[0]
                raise ValueError(
                    f"Multiple checkpoints {end_checkpoints} that matches *end.ckpt."
        logging.info(f"Resuming from {end_checkpoints[0]}")
    elif not last_checkpoints:
        if not resume_ignore_no_checkpoint:
            raise NotFoundError(
                f"There were no checkpoints found in {checkpoint_dir}. Cannot resume."
            f"There were no checkpoints found in {checkpoint_dir}. Training from scratch."
    elif len(last_checkpoints) > 1:
        if "mp_rank" not in str(last_checkpoints[0]) and "tp_rank" not in str(
            raise ValueError(
                f"Multiple checkpoints {last_checkpoints} that matches *last.ckpt."
        checkpoint = last_checkpoints[0]
        checkpoint = mridc.utils.model_utils.uninject_model_parallel_rank(
            checkpoint)  # type: ignore
        logging.info(f"Resuming from {last_checkpoints[0]}")
        checkpoint = last_checkpoints[0]

    trainer._checkpoint_connector.resume_from_checkpoint_fit_path = str(

    if is_global_rank_zero():
        if files_to_move := [
                child for child in Path(log_dir).iterdir() if child.is_file()
            # Move old files to a new folder
            other_run_dirs = Path(log_dir).glob("run_*")
            run_count = sum(bool(fold.is_dir()) for fold in other_run_dirs)
            new_run_dir = Path(Path(log_dir) / f"run_{run_count}")
            for _file in files_to_move:
                move(str(_file), str(new_run_dir))