Example #1
    def download_recovery_checkpoints_or_weights(self) -> None:
        """
        Download checkpoints from a run recovery object or from a weights url. Set the checkpoints path based on the
        run_recovery_object, weights_url or local_weights_path.
        This is called at the start of training.
        """
        if self.azure_config.run_recovery_id:
            run_to_recover = self.azure_config.fetch_run(
                self.azure_config.run_recovery_id.strip())
            self.run_recovery = RunRecovery.download_all_checkpoints_from_run(
                self.output_params, run_to_recover)
        else:
            self.run_recovery = None

        if self.azure_config.pretraining_run_recovery_id is not None:
            run_to_recover = self.azure_config.fetch_run(
                self.azure_config.pretraining_run_recovery_id.strip())
            run_recovery_object = RunRecovery.download_all_checkpoints_from_run(
                self.output_params, run_to_recover, EXTRA_RUN_SUBFOLDER)
            self.container.extra_downloaded_run_id = run_recovery_object
        else:
            self.container.extra_downloaded_run_id = None

        if self.container.weights_url or self.container.local_weights_path:
            self.local_weights_path = self.get_and_save_modified_weights()
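A minimal usage sketch for this entry point, based on the test helpers that appear in the later examples (get_default_checkpoint_handler, DummyClassification and OutputFolderForTests are taken from those tests); the sketch function name and the run recovery ID are illustrative placeholders:

def sketch_recover_checkpoints_before_training(test_output_dirs: OutputFolderForTests) -> None:
    # Same setup as the tests further down: a dummy config writing into a temporary output folder.
    config = DummyClassification()
    config.set_output_to(test_output_dirs.root_dir)
    checkpoint_handler = get_default_checkpoint_handler(
        model_config=config, project_root=test_output_dirs.root_dir)
    # Point the handler at an earlier AzureML run (placeholder ID), then pull its checkpoints
    # into the local checkpoint folder before training starts.
    checkpoint_handler.azure_config.run_recovery_id = "experiment_name:run_id"  # placeholder
    checkpoint_handler.download_recovery_checkpoints_or_weights()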
Example #2
def download_best_checkpoints_from_child_runs(config: OutputParams, run: Run) -> RunRecovery:
    """
    Downloads the best checkpoints from all child runs of the provided Hyperdrive parent run.
    The checkpoints for the sibling runs will go into folder 'OTHER_RUNS/<cross_validation_split>'
    in the checkpoint folder. The child run that is equal to the present AzureML run gets special treatment:
    its checkpoints are read from the checkpoint folder as-is.
    :param config: Model related configs.
    :param run: The Hyperdrive parent run to download from.
    :return: run recovery information
    """
    child_runs: List[Run] = fetch_child_runs(run)
    if not child_runs:
        raise ValueError(f"AzureML run {run.id} does not have any child runs.")
    logging.info(f"Run {run.id} has {len(child_runs)} child runs: {', '.join(c.id for c in child_runs)}")
    tag_to_use = 'cross_validation_split_index'
    can_use_split_indices = tag_values_all_distinct(child_runs, tag_to_use)
    # Collect the checkpoint root folder for each child run, downloading checkpoints where necessary
    child_runs_checkpoints_roots: List[Path] = []
    for child in child_runs:
        if child.id == RUN_CONTEXT.id:
            # We expect to find the file(s) we need in config.checkpoint_folder
            child_dst = config.checkpoint_folder
        else:
            subdir = str(child.tags[tag_to_use] if can_use_split_indices else child.number)
            child_dst = config.checkpoint_folder / OTHER_RUNS_SUBDIR_NAME / subdir
            download_run_output_file(
                blob_path=Path(CHECKPOINT_FOLDER) / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX,
                destination=child_dst,
                run=child
            )
        child_runs_checkpoints_roots.append(child_dst)
    return RunRecovery(checkpoints_roots=child_runs_checkpoints_roots)
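A hedged usage sketch, mirroring how test_download_best_checkpoints_ensemble_run (further down) drives this function for a Hyperdrive parent run; only names that already appear in these examples are used, and the sketch function name itself is illustrative:

def sketch_download_ensemble_checkpoints(config: OutputParams) -> RunRecovery:
    # Resolve the Hyperdrive parent run, falling back to a known run for local execution.
    run = get_most_recent_run(fallback_run_id_for_local_execution=FALLBACK_ENSEMBLE_RUN)
    run_recovery = RunRecovery.download_best_checkpoints_from_child_runs(config, run)
    # Each sibling run ends up under OTHER_RUNS/<cross_validation_split_index> in the checkpoint folder.
    for checkpoint_root in run_recovery.checkpoints_roots:
        assert checkpoint_root.is_dir(), f"Missing checkpoint folder: {checkpoint_root}"
    return run_recovery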
Example #3
def test_download_checkpoints(test_output_dirs: OutputFolderForTests, is_ensemble: bool,
                              runner_config: AzureConfig) -> None:
    output_dir = test_output_dirs.root_dir
    assert get_results_blob_path("some_run_id") == "azureml/ExperimentRun/dcid.some_run_id"
    # Any recent run ID from a PR build will do. Use a PR build because the checkpoint files are small there.
    config = ModelConfigBase(should_validate=False)
    config.set_output_to(output_dir)

    runner_config.run_recovery_id = DEFAULT_ENSEMBLE_RUN_RECOVERY_ID if is_ensemble else DEFAULT_RUN_RECOVERY_ID
    run_recovery = RunRecovery.download_checkpoints_from_recovery_run(runner_config, config)
    run_to_recover = fetch_run(workspace=runner_config.get_workspace(), run_recovery_id=runner_config.run_recovery_id)
    expected_checkpoint_file = "1" + CHECKPOINT_FILE_SUFFIX
    if is_ensemble:
        child_runs = fetch_child_runs(run_to_recover)
        expected_files = [config.checkpoint_folder
                          / OTHER_RUNS_SUBDIR_NAME
                          / str(x.get_tags()['cross_validation_split_index']) / expected_checkpoint_file
                          for x in child_runs]
    else:
        expected_files = [config.checkpoint_folder / run_to_recover.id / expected_checkpoint_file]

    checkpoint_paths = run_recovery.get_checkpoint_paths(1)
    if is_ensemble:
        assert len(run_recovery.checkpoints_roots) == len(expected_files)
        assert all([(x in [y.parent for y in expected_files]) for x in run_recovery.checkpoints_roots])
        assert len(checkpoint_paths) == len(expected_files)
        assert all([x in expected_files for x in checkpoint_paths])
    else:
        assert len(checkpoint_paths) == 1
        assert checkpoint_paths[0] == expected_files[0]

    assert all([expected_file.exists() for expected_file in expected_files])
Example #4
def download_all_checkpoints_from_run(config: OutputParams, run: Run,
                                      subfolder: Optional[str] = None,
                                      only_return_path: bool = False) -> RunRecovery:
    """
    Downloads all checkpoints of the provided run inside the checkpoints folder.
    :param config: Model related configs.
    :param run: Run whose checkpoints should be recovered
    :param subfolder: Optional subfolder name. If provided, the checkpoints will be downloaded to
    CHECKPOINT_FOLDER / subfolder. If None, the checkpoints are downloaded to CHECKPOINT_FOLDER of the current run.
    :param only_return_path: if True, return a RunRecovery object with the path to the checkpoint without actually
    downloading the checkpoints. This is useful to avoid duplicating checkpoint download when running on multiple
    nodes. If False, return the RunRecovery object and download the checkpoint to disk.
    :return: run recovery information
    """
    if fetch_child_runs(run):
        raise ValueError(f"AzureML run {run.id} has child runs, this method does not support those.")

    destination_folder = config.checkpoint_folder / subfolder if subfolder else config.checkpoint_folder

    if not only_return_path:
        download_run_outputs_by_prefix(
            blobs_prefix=Path(CHECKPOINT_FOLDER),
            destination=destination_folder,
            run=run
        )
    time.sleep(60)  # Needed because the AzureML download may not have fully completed yet
    return RunRecovery(checkpoints_roots=[destination_folder])
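Example #1 above shows the two ways this function is typically driven; the sketch below restates them and adds the only_return_path pattern described in the docstring (the is_rank_zero flag is an illustrative assumption, not a helper taken from these examples):

def sketch_recover_previous_runs(azure_config: AzureConfig, output_params: OutputParams,
                                 is_rank_zero: bool) -> RunRecovery:
    # Recover all checkpoints of an earlier run into the current checkpoint folder.
    # On multi-node jobs only one node needs to download; the others just resolve the path.
    run_to_recover = azure_config.fetch_run(azure_config.run_recovery_id.strip())
    run_recovery = RunRecovery.download_all_checkpoints_from_run(
        output_params, run_to_recover, only_return_path=not is_rank_zero)
    # A pretraining run can additionally be recovered into a dedicated subfolder.
    if azure_config.pretraining_run_recovery_id:
        pretraining_run = azure_config.fetch_run(azure_config.pretraining_run_recovery_id.strip())
        RunRecovery.download_all_checkpoints_from_run(output_params, pretraining_run, EXTRA_RUN_SUBFOLDER)
    return run_recovery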
Example #5
def test_recover_training_mean_teacher_model(
        test_output_dirs: OutputFolderForTests) -> None:
    """
    Tests that training can be recovered from a previous checkpoint.
    """
    config = DummyClassification()
    config.mean_teacher_alpha = 0.999
    config.recovery_checkpoint_save_interval = 1
    config.set_output_to(test_output_dirs.root_dir / "original")
    os.makedirs(str(config.outputs_folder))

    original_checkpoint_folder = config.checkpoint_folder

    # First round of training
    config.num_epochs = 2
    checkpoint_handler = get_default_checkpoint_handler(
        model_config=config, project_root=test_output_dirs.root_dir)
    model_train(config, checkpoint_handler=checkpoint_handler)
    assert len(list(config.checkpoint_folder.glob("*.*"))) == 2

    # Restart training from previous run
    config.start_epoch = 2
    config.num_epochs = 3
    config.set_output_to(test_output_dirs.root_dir / "recovered")
    os.makedirs(str(config.outputs_folder))
    # make it seem like run recovery objects have been downloaded
    checkpoint_root = config.checkpoint_folder / "old_run"
    shutil.copytree(str(original_checkpoint_folder), str(checkpoint_root))
    checkpoint_handler.run_recovery = RunRecovery([checkpoint_root])

    model_train(config, checkpoint_handler=checkpoint_handler)
    # remove recovery checkpoints
    shutil.rmtree(checkpoint_root)
    assert len(list(config.checkpoint_folder.glob("*.*"))) == 2
Example #6
    def download_checkpoints_from_hyperdrive_child_runs(self, hyperdrive_parent_run: Run) -> None:
        """
        Downloads the best checkpoints from all child runs of a Hyperdrive parent run. This is used to gather results
        for ensemble creation.
        """
        self.run_recovery = RunRecovery.download_best_checkpoints_from_child_runs(
            self.model_config, hyperdrive_parent_run)
        # Check paths are good, just in case
        for path in self.run_recovery.checkpoints_roots:
            if not path.is_dir():
                raise NotADirectoryError(f"Does not exist or is not a directory: {path}")
Example #7
def test_download_recovery_single_run(test_output_dirs: OutputFolderForTests,
                                      runner_config: AzureConfig) -> None:
    output_dir = test_output_dirs.root_dir
    config = ModelConfigBase(should_validate=False)
    config.set_output_to(output_dir)
    run = get_most_recent_run(fallback_run_id_for_local_execution=FALLBACK_SINGLE_RUN)
    run_recovery = RunRecovery.download_all_checkpoints_from_run(config, run)

    # This fails if there is no recovery checkpoint
    check_single_checkpoint(run_recovery.get_recovery_checkpoint_paths())
    check_single_checkpoint(run_recovery.get_best_checkpoint_paths())
Example #8
    def discover_and_download_checkpoint_from_sibling_runs(self) -> None:
        """
        Downloads checkpoints from sibling runs in a hyperdrive run. This is used to gather results from all
        splits in a hyperdrive run.
        """

        self.run_recovery = RunRecovery.download_checkpoints_from_run(
            self.model_config, self.run_context)
        # Check paths are good, just in case
        for path in self.run_recovery.checkpoints_roots:
            if not path.is_dir():
                raise NotADirectoryError(
                    f"Does not exist or is not a directory: {path}")
Example #9
    def discover_and_download_checkpoints_from_previous_runs(self) -> None:
        """
        Download checkpoints from a run recovery object or from a weights url. Set the checkpoints path based on the
        run_recovery_object, weights_url or local_weights_path
        """
        if self.azure_config.run_recovery_id:
            self.run_recovery = RunRecovery.download_checkpoints_from_recovery_run(
                self.azure_config, self.model_config, self.run_context)
        else:
            self.run_recovery = None

        if self.model_config.weights_url or self.model_config.local_weights_path:
            self.local_weights_path = self.get_and_save_modified_weights()
Example #10
def test_create_inference_pipeline(
        with_run_recovery: bool, config: ModelConfigBase,
        checkpoint_folder: str, inference_type: type, ensemble_type: type,
        test_output_dirs: TestOutputDirectories) -> None:
    config.set_output_to(test_output_dirs.root_dir)
    # Mimic the behaviour that checkpoints are downloaded from blob storage into the checkpoints folder.
    stored_checkpoints = full_ml_test_data_path(checkpoint_folder)
    shutil.copytree(str(stored_checkpoints), str(config.checkpoint_folder))

    if with_run_recovery:
        run_recovery: Optional[RunRecovery] = RunRecovery(
            checkpoints_roots=[stored_checkpoints])
    else:
        run_recovery = None
    assert isinstance(create_inference_pipeline(config, 1, run_recovery),
                      inference_type)

    # test for ensemble pipeline if run_recovery is enabled
    if with_run_recovery:
        run_recovery = RunRecovery(checkpoints_roots=[stored_checkpoints] * 2)
        assert isinstance(create_inference_pipeline(config, 1, run_recovery),
                          ensemble_type)
Example #11
def test_download_checkpoints_hyperdrive_run(test_output_dirs: OutputFolderForTests,
                                             runner_config: AzureConfig) -> None:
    output_dir = test_output_dirs.root_dir
    config = ModelConfigBase(should_validate=False)
    config.set_output_to(output_dir)
    runner_config.run_recovery_id = DEFAULT_ENSEMBLE_RUN_RECOVERY_ID
    child_runs = fetch_child_runs(run=fetch_run(runner_config.get_workspace(), DEFAULT_ENSEMBLE_RUN_RECOVERY_ID))
    # Also recover each child run separately, to test the hyperdrive child run recovery functionality
    expected_checkpoint_file = "1" + CHECKPOINT_FILE_SUFFIX
    for child in child_runs:
        expected_files = [config.checkpoint_folder / child.id / expected_checkpoint_file]
        run_recovery = RunRecovery.download_checkpoints_from_recovery_run(runner_config, config, child)
        assert all([x in expected_files for x in run_recovery.get_checkpoint_paths(epoch=1)])
        assert all([expected_file.exists() for expected_file in expected_files])
Example #12
    def download_recovery_checkpoints_or_weights(self) -> None:
        """
        Download checkpoints from a run recovery object or from a weights url. Set the checkpoints path based on the
        run_recovery_object, weights_url or local_weights_path.
        This is called at the start of training.
        """
        if self.azure_config.run_recovery_id:
            run_to_recover = self.azure_config.fetch_run(self.azure_config.run_recovery_id.strip())
            self.run_recovery = RunRecovery.download_all_checkpoints_from_run(self.model_config, run_to_recover)
        else:
            self.run_recovery = None

        if self.model_config.weights_url or self.model_config.local_weights_path:
            self.local_weights_path = self.get_and_save_modified_weights()
Example #13
    def create_ensemble_model(self) -> None:
        """
        Call MLRunner again after training cross-validation models, to create an ensemble model from them.
        """
        # Import only here in case of dependency issues in reduced environment
        from InnerEye.ML.utils.run_recovery import RunRecovery
        with logging_section("Downloading checkpoints from sibling runs"):
            run_recovery = RunRecovery.download_checkpoints_from_run(
                self.azure_config,
                self.model_config,
                PARENT_RUN_CONTEXT,
                output_subdir_name=OTHER_RUNS_SUBDIR_NAME)
            # Check paths are good, just in case
            for path in run_recovery.checkpoints_roots:
                if not path.is_dir():
                    raise NotADirectoryError(
                        f"Does not exist or is not a directory: {path}")
        # Adjust parameters
        self.azure_config.hyperdrive = False
        self.model_config.number_of_cross_validation_splits = 0
        self.model_config.is_train = False
        best_epoch = self.create_ml_runner().run_inference_and_register_model(
            run_recovery, model_proc=ModelProcessing.ENSEMBLE_CREATION)

        crossval_dir = self.plot_cross_validation_and_upload_results()
        Runner.generate_report(self.model_config, best_epoch,
                               ModelProcessing.ENSEMBLE_CREATION)
        # CrossValResults should have been uploaded to the parent run, so we don't need it here.
        remove_file_or_directory(crossval_dir)
        # We can also remove OTHER_RUNS under the root, as it is no longer useful and only contains copies of files
        # available elsewhere. However, first we need to upload relevant parts of OTHER_RUNS/ENSEMBLE.
        other_runs_dir = self.model_config.outputs_folder / OTHER_RUNS_SUBDIR_NAME
        other_runs_ensemble_dir = other_runs_dir / ENSEMBLE_SPLIT_NAME
        if PARENT_RUN_CONTEXT is not None:
            if other_runs_ensemble_dir.exists():
                    # Only keep baseline Wilcoxon results, scatterplots and reports
                for subdir in other_runs_ensemble_dir.glob("*"):
                    if subdir.name not in [
                            BASELINE_WILCOXON_RESULTS_FILE,
                            SCATTERPLOTS_SUBDIR_NAME, REPORT_HTML, REPORT_IPYNB
                    ]:
                        remove_file_or_directory(subdir)
                PARENT_RUN_CONTEXT.upload_folder(
                    name=BASELINE_COMPARISONS_FOLDER,
                    path=str(other_runs_ensemble_dir))
            else:
                logging.warning(
                    f"Directory not found for upload: {other_runs_ensemble_dir}"
                )
        remove_file_or_directory(other_runs_dir)
Example #14
def test_download_best_checkpoints_ensemble_run(test_output_dirs: OutputFolderForTests,
                                                runner_config: AzureConfig) -> None:
    output_dir = test_output_dirs.root_dir
    config = ModelConfigBase(should_validate=False)
    config.set_output_to(output_dir)

    run = get_most_recent_run(fallback_run_id_for_local_execution=FALLBACK_ENSEMBLE_RUN)
    run_recovery = RunRecovery.download_best_checkpoints_from_child_runs(config, run)
    other_runs_folder = config.checkpoint_folder / OTHER_RUNS_SUBDIR_NAME
    assert other_runs_folder.is_dir()
    for child in ["0", "1"]:
        assert (other_runs_folder / child).is_dir(), "Child run folder does not exist"
    for checkpoint in run_recovery.get_best_checkpoint_paths():
        assert checkpoint.is_file(), f"File {checkpoint} does not exist"
Example #15
def test_recover_training_mean_teacher_model(
        test_output_dirs: OutputFolderForTests) -> None:
    """
    Tests that training can be recovered from a previous checkpoint.
    """
    config = DummyClassification()
    config.mean_teacher_alpha = 0.999
    config.autosave_every_n_val_epochs = 1
    config.set_output_to(test_output_dirs.root_dir / "original")
    os.makedirs(str(config.outputs_folder))

    original_checkpoint_folder = config.checkpoint_folder

    # First round of training
    config.num_epochs = 4
    model_train_unittest(config, output_folder=test_output_dirs)
    assert len(list(config.checkpoint_folder.glob("*.*"))) == 1
    assert (config.checkpoint_folder /
            LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX).is_file()

    # Restart training from previous run
    config.num_epochs = 3
    config.set_output_to(test_output_dirs.root_dir / "recovered")
    os.makedirs(str(config.outputs_folder))
    # make it seem like run recovery objects have been downloaded
    checkpoint_root = config.checkpoint_folder / "old_run"
    shutil.copytree(str(original_checkpoint_folder), str(checkpoint_root))

    # Create a new checkpoint handler and set run_recovery to the copied checkpoints
    checkpoint_handler = get_default_checkpoint_handler(
        model_config=config, project_root=test_output_dirs.root_dir)
    checkpoint_handler.run_recovery = RunRecovery([checkpoint_root])

    model_train_unittest(config,
                         output_folder=test_output_dirs,
                         checkpoint_handler=checkpoint_handler)
    # remove recovery checkpoints
    shutil.rmtree(checkpoint_root)
    assert len(list(config.checkpoint_folder.glob("*.ckpt"))) == 1
Example #16
    def run(self) -> None:
        """
        Driver function to run a ML experiment. If an offline cross validation run is requested, then
        this function is recursively called for each cross validation split.
        """
        if self.is_offline_cross_val_parent_run():
            if self.model_config.is_segmentation_model:
                raise NotImplementedError(
                    "Offline cross validation is only supported for classification models."
                )
            self.spawn_offline_cross_val_classification_child_runs()
            return

        # Get the AzureML context in which the script is running
        if not self.model_config.is_offline_run and PARENT_RUN_CONTEXT is not None:
            logging.info("Setting tags from parent run.")
            self.set_run_tags_from_parent()

        self.save_build_info_for_dotnet_consumers()

        # Set data loader start method
        self.set_multiprocessing_start_method()

        # Configure run recovery, if a run recovery ID was provided
        run_recovery: Optional[RunRecovery] = None
        if self.azure_config.run_recovery_id:
            run_recovery = RunRecovery.download_checkpoints_from_recovery_run(
                self.azure_config, self.model_config, RUN_CONTEXT)
        # do training and inference, unless the "only register" switch is set (which requires a run_recovery
        # to be valid).
        if self.azure_config.register_model_only_for_epoch is None or run_recovery is None:
            # Set local_dataset to the mounted path specified in azure_runner.py, if any, or download it if that fails
            # and config.local_dataset was not already set.
            self.model_config.local_dataset = self.mount_or_download_dataset()
            self.model_config.write_args_file()
            logging.info(str(self.model_config))
            # Ensure that training runs are fully reproducible - setting random seeds alone is not enough!
            make_pytorch_reproducible()

            # Check for existing dataset.csv file in the correct locations. Skip that if a dataset has already been
            # loaded (typically only during tests)
            if self.model_config.dataset_data_frame is None:
                assert self.model_config.local_dataset is not None
                ml_util.validate_dataset_paths(self.model_config.local_dataset)

            # train a new model if required
            if self.azure_config.train:
                with logging_section("Model training"):
                    model_train(self.model_config, run_recovery)
            else:
                self.model_config.write_dataset_files()
                self.create_activation_maps()

            # log the number of epochs used for model training
            RUN_CONTEXT.log(name="Train epochs",
                            value=self.model_config.num_epochs)

        # We specify the ModelProcessing as DEFAULT here even if the run_recovery points to an ensemble run, because
        # the current run is a single one. See the documentation of ModelProcessing for more details.
        best_epoch = self.run_inference_and_register_model(
            run_recovery, ModelProcessing.DEFAULT)

        # Generate report
        if best_epoch:
            Runner.generate_report(self.model_config, best_epoch,
                                   ModelProcessing.DEFAULT)
        elif self.model_config.is_scalar_model and len(
                self.model_config.get_test_epochs()) == 1:
            # We don't register scalar models but still want to create a report if we have run inference.
            Runner.generate_report(self.model_config,
                                   self.model_config.get_test_epochs()[0],
                                   ModelProcessing.DEFAULT)
Example #17
def test_recover_testing_from_run_recovery(
        mean_teacher_model: bool,
        test_output_dirs: OutputFolderForTests) -> None:
    """
    Checks that inference results are the same whether from a checkpoint in the same run, from a run recovery or from a
    local_weights_path param.
    """
    # Train for 4 epochs
    config = DummyClassification()
    if mean_teacher_model:
        config.mean_teacher_alpha = 0.999
    config.set_output_to(test_output_dirs.root_dir / "original")
    os.makedirs(str(config.outputs_folder))
    config.save_start_epoch = 2
    config.save_step_epochs = 2

    checkpoint_handler = get_default_checkpoint_handler(
        model_config=config, project_root=test_output_dirs.root_dir)
    train_results = model_train(config, checkpoint_handler=checkpoint_handler)
    assert len(train_results.learning_rates_per_epoch) == config.num_epochs

    # Run inference on this
    test_results = model_test(config=config,
                              data_split=ModelExecutionMode.TEST,
                              checkpoint_handler=checkpoint_handler)
    assert isinstance(test_results, InferenceMetricsForClassification)
    assert list(test_results.epochs.keys()) == [config.num_epochs]

    # Mimic using a run recovery and see if it is the same
    config_run_recovery = DummyClassification()
    if mean_teacher_model:
        config_run_recovery.mean_teacher_alpha = 0.999
    config_run_recovery.set_output_to(test_output_dirs.root_dir /
                                      "run_recovery")
    os.makedirs(str(config_run_recovery.outputs_folder))

    checkpoint_handler_run_recovery = get_default_checkpoint_handler(
        model_config=config_run_recovery,
        project_root=test_output_dirs.root_dir)
    # make it seem like run recovery objects have been downloaded
    checkpoint_root = config_run_recovery.checkpoint_folder / "recovered"
    shutil.copytree(str(config.checkpoint_folder), str(checkpoint_root))
    checkpoint_handler_run_recovery.run_recovery = RunRecovery(
        [checkpoint_root])
    test_results_run_recovery = model_test(
        config_run_recovery,
        data_split=ModelExecutionMode.TEST,
        checkpoint_handler=checkpoint_handler_run_recovery)
    assert isinstance(test_results_run_recovery,
                      InferenceMetricsForClassification)
    assert list(test_results_run_recovery.epochs.keys()) == [config.num_epochs]
    assert test_results.epochs[config.num_epochs].values()[MetricType.CROSS_ENTROPY.value] == \
           test_results_run_recovery.epochs[config.num_epochs].values()[MetricType.CROSS_ENTROPY.value]

    # Run inference with the local checkpoints
    config_local_weights = DummyClassification()
    if mean_teacher_model:
        config_local_weights.mean_teacher_alpha = 0.999
    config_local_weights.set_output_to(test_output_dirs.root_dir /
                                       "local_weights_path")
    os.makedirs(str(config_local_weights.outputs_folder))

    local_weights_path = test_output_dirs.root_dir / "local_weights_file.pth"
    shutil.copyfile(
        str(
            create_checkpoint_path(config.checkpoint_folder,
                                   epoch=config.num_epochs)),
        local_weights_path)
    config_local_weights.local_weights_path = local_weights_path

    checkpoint_handler_local_weights = get_default_checkpoint_handler(
        model_config=config_local_weights,
        project_root=test_output_dirs.root_dir)
    checkpoint_handler_local_weights.discover_and_download_checkpoints_from_previous_runs(
    )
    test_results_local_weights = model_test(
        config_local_weights,
        data_split=ModelExecutionMode.TEST,
        checkpoint_handler=checkpoint_handler_local_weights)
    assert isinstance(test_results_local_weights,
                      InferenceMetricsForClassification)
    assert list(test_results_local_weights.epochs.keys()) == [0]
    assert test_results.epochs[config.num_epochs].values()[MetricType.CROSS_ENTROPY.value] == \
           test_results_local_weights.epochs[0].values()[MetricType.CROSS_ENTROPY.value]
Example #18
def test_recover_testing_from_run_recovery(
        mean_teacher_model: bool,
        test_output_dirs: OutputFolderForTests) -> None:
    """
    Checks that inference results are the same whether from a checkpoint in the same run, from a run recovery or from a
    local_weights_path param.
    """
    # Train for 4 epochs
    config = DummyClassification()
    if mean_teacher_model:
        config.mean_teacher_alpha = 0.999
    config.set_output_to(test_output_dirs.root_dir / "original")
    os.makedirs(str(config.outputs_folder))

    train_results, checkpoint_handler = model_train_unittest(
        config, output_folder=test_output_dirs)
    assert len(train_results.train_results_per_epoch()) == config.num_epochs

    # Run inference on this
    test_results = model_test(
        config=config,
        data_split=ModelExecutionMode.TEST,
        checkpoint_paths=checkpoint_handler.get_checkpoints_to_test())
    assert isinstance(test_results, InferenceMetricsForClassification)

    # Mimic using a run recovery and see if it is the same
    config_run_recovery = DummyClassification()
    if mean_teacher_model:
        config_run_recovery.mean_teacher_alpha = 0.999
    config_run_recovery.set_output_to(test_output_dirs.root_dir /
                                      "run_recovery")
    os.makedirs(str(config_run_recovery.outputs_folder))

    checkpoint_handler_run_recovery = get_default_checkpoint_handler(
        model_config=config_run_recovery,
        project_root=test_output_dirs.root_dir)
    # make it seem like run recovery objects have been downloaded
    checkpoint_root = config_run_recovery.checkpoint_folder / "recovered"
    shutil.copytree(str(config.checkpoint_folder), str(checkpoint_root))
    checkpoint_handler_run_recovery.run_recovery = RunRecovery(
        [checkpoint_root])
    test_results_run_recovery = model_test(
        config_run_recovery,
        data_split=ModelExecutionMode.TEST,
        checkpoint_paths=checkpoint_handler_run_recovery.
        get_checkpoints_to_test())
    assert isinstance(test_results_run_recovery,
                      InferenceMetricsForClassification)
    assert test_results.metrics.values()[MetricType.CROSS_ENTROPY.value] == \
           test_results_run_recovery.metrics.values()[MetricType.CROSS_ENTROPY.value]

    # Run inference with the local checkpoints
    config_local_weights = DummyClassification()
    if mean_teacher_model:
        config_local_weights.mean_teacher_alpha = 0.999
    config_local_weights.set_output_to(test_output_dirs.root_dir /
                                       "local_weights_path")
    os.makedirs(str(config_local_weights.outputs_folder))

    local_weights_path = test_output_dirs.root_dir / "local_weights_file.pth"
    shutil.copyfile(
        str(config.checkpoint_folder / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX),
        local_weights_path)
    config_local_weights.local_weights_path = [local_weights_path]

    checkpoint_handler_local_weights = get_default_checkpoint_handler(
        model_config=config_local_weights,
        project_root=test_output_dirs.root_dir)
    checkpoint_handler_local_weights.download_recovery_checkpoints_or_weights()
    test_results_local_weights = model_test(
        config_local_weights,
        data_split=ModelExecutionMode.TEST,
        checkpoint_paths=checkpoint_handler_local_weights.
        get_checkpoints_to_test())
    assert isinstance(test_results_local_weights,
                      InferenceMetricsForClassification)
    assert test_results.metrics.values()[MetricType.CROSS_ENTROPY.value] == \
           test_results_local_weights.metrics.values()[MetricType.CROSS_ENTROPY.value]