Code Example #1
    def create_ensemble_model(self) -> None:
        """
        Call MLRunner again after training cross-validation models, to create an ensemble model from them.
        """
        # Import only here in case of dependency issues in reduced environment
        from InnerEye.ML.utils.checkpoint_handling import CheckpointHandler
        # Adjust parameters
        self.azure_config.hyperdrive = False
        self.model_config.number_of_cross_validation_splits = 0
        self.model_config.is_train = False

        with logging_section("Downloading checkpoints from sibling runs"):
            checkpoint_handler = CheckpointHandler(
                model_config=self.model_config,
                azure_config=self.azure_config,
                project_root=self.project_root,
                run_context=PARENT_RUN_CONTEXT)
            checkpoint_handler.discover_and_download_checkpoint_from_sibling_runs(
                output_subdir_name=OTHER_RUNS_SUBDIR_NAME)

        best_epoch = self.create_ml_runner().run_inference_and_register_model(
            checkpoint_handler=checkpoint_handler,
            model_proc=ModelProcessing.ENSEMBLE_CREATION)

        crossval_dir = self.plot_cross_validation_and_upload_results()
        Runner.generate_report(self.model_config, best_epoch,
                               ModelProcessing.ENSEMBLE_CREATION)
        # CrossValResults should have been uploaded to the parent run, so we don't need it here.
        remove_file_or_directory(crossval_dir)
        # We can also remove OTHER_RUNS under the root, as it is no longer useful and only contains copies of files
        # available elsewhere. However, first we need to upload relevant parts of OTHER_RUNS/ENSEMBLE.
        other_runs_dir = self.model_config.outputs_folder / OTHER_RUNS_SUBDIR_NAME
        other_runs_ensemble_dir = other_runs_dir / ENSEMBLE_SPLIT_NAME
        if PARENT_RUN_CONTEXT is not None:
            if other_runs_ensemble_dir.exists():
                # Only keep baseline Wilcoxon results and scatterplots and reports
                for subdir in other_runs_ensemble_dir.glob("*"):
                    if subdir.name not in [
                            BASELINE_WILCOXON_RESULTS_FILE,
                            SCATTERPLOTS_SUBDIR_NAME, REPORT_HTML, REPORT_IPYNB
                    ]:
                        remove_file_or_directory(subdir)
                PARENT_RUN_CONTEXT.upload_folder(
                    name=BASELINE_COMPARISONS_FOLDER,
                    path=str(other_runs_ensemble_dir))
            else:
                logging.warning(
                    f"Directory not found for upload: {other_runs_ensemble_dir}"
                )
        remove_file_or_directory(other_runs_dir)
Code Example #2
    def register_model_for_epoch(self, run_context: Run,
                                 checkpoint_handler: CheckpointHandler,
                                 best_epoch: int, best_epoch_dice: float,
                                 model_proc: ModelProcessing) -> None:

        checkpoint_path_and_epoch = checkpoint_handler.get_checkpoint_from_epoch(
            epoch=best_epoch)
        if not checkpoint_path_and_epoch or not checkpoint_path_and_epoch.checkpoint_paths:
            # No point continuing, since no checkpoints were found
            logging.warning(
                "Abandoning model registration - no valid checkpoint paths found"
            )
            return

        if not self.model_config.is_offline_run:
            split_index = run_context.get_tags().get(
                CROSS_VALIDATION_SPLIT_INDEX_TAG_KEY, None)
            if split_index == DEFAULT_CROSS_VALIDATION_SPLIT_INDEX:
                update_run_tags(
                    run_context, {
                        IS_ENSEMBLE_KEY_NAME:
                        model_proc == ModelProcessing.ENSEMBLE_CREATION
                    })
            elif PARENT_RUN_CONTEXT is not None:
                update_run_tags(
                    run_context,
                    {PARENT_RUN_ID_KEY_NAME: PARENT_RUN_CONTEXT.id})
        with logging_section(f"Registering {model_proc.value} model"):
            self.register_segmentation_model(
                run=run_context,
                best_epoch=best_epoch,
                best_epoch_dice=best_epoch_dice,
                checkpoint_paths=checkpoint_path_and_epoch.checkpoint_paths,
                model_proc=model_proc)
Code Example #3
def model_train_unittest(config: Optional[DeepLearningConfig],
                         dirs: OutputFolderForTests,
                         checkpoint_handler: Optional[CheckpointHandler] = None,
                         lightning_container: Optional[LightningContainer] = None) -> \
        Tuple[StoringLogger, CheckpointHandler]:
    """
    A shortcut for running model training in the unit test suite. It runs training for the given config, with the
    default checkpoint handler initialized to point to the test output folder specified in dirs.
    :param config: The configuration of the model to train.
    :param dirs: The test fixture that provides an output folder for the test.
    :param checkpoint_handler: The checkpoint handler that should be used for training. If not provided, it will be
    created via get_default_checkpoint_handler.
    :param lightning_container: An optional LightningContainer object that will be passed through to the training routine.
    :return: A tuple of (StoringLogger, CheckpointHandler) that were used for training.
    """
    runner = MLRunner(model_config=config, container=lightning_container)
    # Setup sets random seeds before model creation and stores the created model in the container; we then use that
    # initialized container below.
    # For all tests running in AzureML, we need to skip the downloading of datasets that would otherwise happen,
    # because all unit test configs come with their own local dataset already.
    runner.setup(use_mount_or_download_dataset=False)
    if checkpoint_handler is None:
        azure_config = get_default_azure_config()
        checkpoint_handler = CheckpointHandler(azure_config=azure_config,
                                               container=runner.container,
                                               project_root=dirs.root_dir)
    _, storing_logger = model_train(checkpoint_handler=checkpoint_handler,
                                    container=runner.container)
    return storing_logger, checkpoint_handler  # type: ignore
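A hedged usage sketch (not part of the original test suite): this helper would typically be driven by a lightweight config and the standard output-folder fixture, as in the other examples on this page. DummyClassification and the OutputFolderForTests fixture are borrowed from Code Example #13; the assertions are illustrative only.

def test_train_dummy_classification(test_output_dirs: OutputFolderForTests) -> None:
    # Hypothetical test built only from pieces shown elsewhere on this page.
    config = DummyClassification()
    config.set_output_to(test_output_dirs.root_dir)
    storing_logger, checkpoint_handler = model_train_unittest(config, dirs=test_output_dirs)
    # The StoringLogger records one entry per trained epoch.
    assert len(list(storing_logger.epochs)) == config.num_epochs
    # After training, the handler should be able to locate checkpoints for testing.
    assert checkpoint_handler.get_checkpoints_to_test()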
Code Example #4
    def run_inference_and_register_model(self, checkpoint_handler: CheckpointHandler,
                                         model_proc: ModelProcessing) -> None:
        """
        Run inference as required, and register the model, but not necessarily in that order:
        if we can identify the epoch to register at without running inference, we register first.
        :param checkpoint_handler: Checkpoint handler object to find checkpoint paths for model initialization
        :param model_proc: whether we are running an ensemble model from within a child run with index 0. If we are,
        then outputs will be written to OTHER_RUNS/ENSEMBLE under the main outputs directory.
        """

        if self.should_register_model():
            checkpoint_paths = checkpoint_handler.get_checkpoints_to_test()
            if not checkpoint_paths:
                raise ValueError("Model registration failed: No checkpoints found")

            model_description = "Registering model."
            checkpoint_paths = checkpoint_paths
            self.register_model(checkpoint_paths, model_description, model_proc)

        if not self.azure_config.only_register_model:
            # Run full image inference on the existing or newly trained model on the training, validation, and test sets
            test_metrics, val_metrics, _ = self.model_inference_train_and_test(checkpoint_handler=checkpoint_handler,
                                                                               model_proc=model_proc)

            self.try_compare_scores_against_baselines(model_proc)
Code Example #5
    def register_model_for_best_epoch(
            self, checkpoint_handler: CheckpointHandler,
            test_metrics: Optional[InferenceMetricsForSegmentation],
            val_metrics: Optional[InferenceMetricsForSegmentation],
            model_proc: ModelProcessing) -> int:
        if val_metrics is not None:
            best_epoch = val_metrics.get_best_epoch()
            num_epochs = len(val_metrics.epochs)
            model_description = f"Epoch {best_epoch} has best validation set metrics (out of {num_epochs} epochs " \
                                f"available). Validation set Dice: {val_metrics.epochs[best_epoch]}. "
            if test_metrics:
                model_description += f"Test set Dice: {test_metrics.epochs[best_epoch]}."
            else:
                model_description += "Test set metrics not available."
        elif test_metrics is not None:
            # We should normally not get here. We presently always run inference on both validation and test set
            # together.
            best_epoch = test_metrics.get_best_epoch()
            num_epochs = len(test_metrics.epochs)
            model_description = f"Epoch {best_epoch} has best test set metrics (out of {num_epochs} epochs " \
                                f"available). Test set Dice: {test_metrics.epochs[best_epoch]}"
        else:
            best_epoch = self.model_config.get_test_epochs()[-1]
            model_description = f"Model for epoch {best_epoch}. No validation or test set metrics were available."
        checkpoint_paths = checkpoint_handler.get_checkpoint_paths_from_epoch_or_fail(
            best_epoch)
        self.register_model_for_epoch(checkpoint_paths, model_description,
                                      model_proc)
        return best_epoch
Code Example #6
File: util.py Project: mmachua/InnerEye-DeepLearning
def get_default_checkpoint_handler(model_config: DeepLearningConfig,
                                   project_root: Path) -> CheckpointHandler:
    """
    Gets a checkpoint handler, using the given model config and the default azure configuration.
    """
    azure_config = get_default_azure_config()
    return CheckpointHandler(azure_config=azure_config,
                             model_config=model_config,
                             project_root=project_root)
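A minimal usage sketch, assuming a DeepLearningConfig instance (config) and a test output folder (test_output_dirs) are already in scope, as in the unit-test examples on this page:

checkpoint_handler = get_default_checkpoint_handler(model_config=config,
                                                    project_root=test_output_dirs.root_dir)
# The handler can then be passed to model_train, segmentation_model_test, or classification_model_test,
# all of which accept a checkpoint_handler argument in the examples shown here.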
Code Example #7
    def create_ensemble_model(self) -> None:
        """
        Create an ensemble model from the results of the sibling runs of the present run. The present run here will
        be cross validation child run 0.
        """
        assert PARENT_RUN_CONTEXT, "This function should only be called in a Hyperdrive run"
        with logging_section("Downloading checkpoints from sibling runs"):
            checkpoint_handler = CheckpointHandler(
                model_config=self.model_config,
                azure_config=self.azure_config,
                project_root=self.project_root,
                run_context=PARENT_RUN_CONTEXT)
            checkpoint_handler.download_checkpoints_from_hyperdrive_child_runs(
                PARENT_RUN_CONTEXT)

        self.run_inference_and_register_model(
            checkpoint_handler=checkpoint_handler,
            model_proc=ModelProcessing.ENSEMBLE_CREATION)

        crossval_dir = self.plot_cross_validation_and_upload_results()
        self.generate_report(ModelProcessing.ENSEMBLE_CREATION)
        # CrossValResults should have been uploaded to the parent run, so we don't need it here.
        remove_file_or_directory(crossval_dir)
        # We can also remove OTHER_RUNS under the root, as it is no longer useful and only contains copies of files
        # available elsewhere. However, first we need to upload relevant parts of OTHER_RUNS/ENSEMBLE.
        other_runs_dir = self.model_config.outputs_folder / OTHER_RUNS_SUBDIR_NAME
        other_runs_ensemble_dir = other_runs_dir / ENSEMBLE_SPLIT_NAME
        if PARENT_RUN_CONTEXT is not None:
            if other_runs_ensemble_dir.exists():
                # Only keep baseline Wilcoxon results and scatterplots and reports
                for subdir in other_runs_ensemble_dir.glob("*"):
                    if subdir.name not in [
                            BASELINE_WILCOXON_RESULTS_FILE,
                            SCATTERPLOTS_SUBDIR_NAME, REPORT_HTML, REPORT_IPYNB
                    ]:
                        remove_file_or_directory(subdir)
                PARENT_RUN_CONTEXT.upload_folder(
                    name=BASELINE_COMPARISONS_FOLDER,
                    path=str(other_runs_ensemble_dir))
            else:
                logging.warning(
                    f"Directory not found for upload: {other_runs_ensemble_dir}"
                )
        remove_file_or_directory(other_runs_dir)
Code Example #8
def segmentation_model_test(
    config: SegmentationModelBase,
    data_split: ModelExecutionMode,
    checkpoint_handler: CheckpointHandler,
    model_proc: ModelProcessing = ModelProcessing.DEFAULT
) -> InferenceMetricsForSegmentation:
    """
    The main testing loop for segmentation models.
    It loads the model and datasets, then proceeds to test the model for all requested checkpoints.
    :param config: The arguments object which has a valid random seed attribute.
    :param data_split: Indicates which of the 3 sets (training, test, or validation) is being processed.
    :param checkpoint_handler: Checkpoint handler object to find checkpoint paths for model initialization
    :param model_proc: whether we are testing an ensemble or single model
    :return: An InferenceMetricsForSegmentation object that contains the metrics for all of the checkpoint epochs.
    """
    results: Dict[int, float] = {}
    checkpoints_to_test = checkpoint_handler.get_checkpoints_to_test()

    if not checkpoints_to_test:
        raise ValueError(
            "There were no checkpoints available for model testing.")

    for checkpoint_paths_and_epoch in checkpoints_to_test:
        epoch = checkpoint_paths_and_epoch.epoch
        epoch_results_folder = config.outputs_folder / get_epoch_results_path(
            epoch, data_split, model_proc)
        # save the datasets.csv used
        config.write_dataset_files(root=epoch_results_folder)
        epoch_and_split = "epoch {} {} set".format(epoch, data_split.value)
        epoch_dice_per_image = segmentation_model_test_epoch(
            config=copy.deepcopy(config),
            data_split=data_split,
            checkpoint_paths=checkpoint_paths_and_epoch.checkpoint_paths,
            results_folder=epoch_results_folder,
            epoch_and_split=epoch_and_split)
        if epoch_dice_per_image is None:
            logging.warning(
                "There is no checkpoint file for epoch {}".format(epoch))
        else:
            epoch_average_dice: float = np.mean(
                epoch_dice_per_image) if len(epoch_dice_per_image) > 0 else 0
            results[epoch] = epoch_average_dice
            logging.info("Epoch: {:3} | Mean Dice: {:4f}".format(
                epoch, epoch_average_dice))
            if model_proc == ModelProcessing.ENSEMBLE_CREATION:
                # For the upload, we want the path without the "OTHER_RUNS/ENSEMBLE" prefix.
                name = str(
                    get_epoch_results_path(epoch, data_split,
                                           ModelProcessing.DEFAULT))
                PARENT_RUN_CONTEXT.upload_folder(
                    name=name, path=str(epoch_results_folder))
    if len(results) == 0:
        raise ValueError(
            "There was no single checkpoint file available for model testing.")
    return InferenceMetricsForSegmentation(data_split=data_split,
                                           epochs=results)
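For context, a sketch of how this loop might be invoked after training. config and checkpoint_handler are assumed to come from the surrounding MLRunner code, and ModelExecutionMode.TEST is assumed to be the enum member for the test split (only TRAIN and VAL appear verbatim above).

test_result = segmentation_model_test(config=config,
                                      data_split=ModelExecutionMode.TEST,
                                      checkpoint_handler=checkpoint_handler,
                                      model_proc=ModelProcessing.DEFAULT)
# InferenceMetricsForSegmentation.epochs maps each tested epoch to its mean Dice score.
for epoch, mean_dice in test_result.epochs.items():
    logging.info(f"Epoch {epoch}: mean test Dice {mean_dice:.4f}")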
Code Example #9
def test_download_model_weights(test_output_dirs: OutputFolderForTests) -> None:
    # Download a sample ResNet model from a URL given in the Pytorch docs
    result_path = CheckpointHandler.download_weights(urls=[EXTERNAL_WEIGHTS_URL_EXAMPLE],
                                                     download_folder=test_output_dirs.root_dir)
    assert len(result_path) == 1
    assert result_path[0] == test_output_dirs.root_dir / os.path.basename(urlparse(EXTERNAL_WEIGHTS_URL_EXAMPLE).path)
    assert result_path[0].is_file()

    modified_time = result_path[0].stat().st_mtime

    result_path = CheckpointHandler.download_weights(urls=[EXTERNAL_WEIGHTS_URL_EXAMPLE, EXTERNAL_WEIGHTS_URL_EXAMPLE],
                                                     download_folder=test_output_dirs.root_dir)
    assert len(result_path) == 2
    assert len(list(test_output_dirs.root_dir.glob("*"))) == 1
    assert result_path[0].samefile(result_path[1])
    assert result_path[0] == test_output_dirs.root_dir / os.path.basename(urlparse(EXTERNAL_WEIGHTS_URL_EXAMPLE).path)
    assert result_path[0].is_file()
    # This call should not re-download the files, just return the existing ones
    assert result_path[0].stat().st_mtime == modified_time
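Outside the test suite, the same static method could be used to pull pretrained weights into an arbitrary folder. This is a sketch only; the URL is a placeholder standing in for EXTERNAL_WEIGHTS_URL_EXAMPLE.

from pathlib import Path
from InnerEye.ML.utils.checkpoint_handling import CheckpointHandler

weights_url = "https://example.com/pretrained/resnet18.pth"  # placeholder URL
downloaded = CheckpointHandler.download_weights(urls=[weights_url],
                                                download_folder=Path("outputs") / "pretrained_weights")
print(f"Weights downloaded to {downloaded[0]}")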
Code Example #10
def test_get_checkpoints_from_model_ensemble_run(test_output_dirs: OutputFolderForTests) -> None:
    model_id = get_most_recent_model_id(fallback_run_id_for_local_execution=FALLBACK_ENSEMBLE_RUN)

    downloaded_checkpoints = CheckpointHandler.get_checkpoints_from_model(model_id=model_id,
                                                                          workspace=get_default_workspace(),
                                                                          download_path=test_output_dirs.root_dir)
    # Check that all the ensemble checkpoints have been downloaded
    expected_model_root = test_output_dirs.root_dir / FINAL_ENSEMBLE_MODEL_FOLDER
    assert expected_model_root.is_dir()
    model_inference_config = read_model_inference_config(expected_model_root / MODEL_INFERENCE_JSON_FILE_NAME)
    expected_paths = [expected_model_root / x for x in model_inference_config.checkpoint_paths]

    assert len(expected_paths) == len(downloaded_checkpoints)
    assert set(expected_paths) == set(downloaded_checkpoints)
    for expected_path in expected_paths:
        assert expected_path.is_file()
Code Example #11
def test_get_checkpoints_from_model_single_run(test_output_dirs: OutputFolderForTests) -> None:
    model_id = get_most_recent_model_id(fallback_run_id_for_local_execution=FALLBACK_SINGLE_RUN)

    downloaded_checkpoints = CheckpointHandler.get_checkpoints_from_model(model_id=model_id,
                                                                          workspace=get_default_workspace(),
                                                                          download_path=test_output_dirs.root_dir)
    # Check a single checkpoint has been downloaded
    expected_model_root = test_output_dirs.root_dir / FINAL_MODEL_FOLDER
    assert expected_model_root.is_dir()
    model_inference_config = read_model_inference_config(expected_model_root / MODEL_INFERENCE_JSON_FILE_NAME)
    expected_paths = [expected_model_root / x for x in model_inference_config.checkpoint_paths]

    assert len(expected_paths) == 1  # A registered model for a non-ensemble run should contain only one checkpoint
    assert len(downloaded_checkpoints) == 1
    assert expected_paths[0] == downloaded_checkpoints[0]
    assert expected_paths[0].is_file()
Code Example #12
    def run_inference_and_register_model(
            self, checkpoint_handler: CheckpointHandler,
            model_proc: ModelProcessing) -> Optional[int]:
        """
        Run inference as required, and register the model, but not necessarily in that order:
        if we can identify the epoch to register at without running inference, we register first.
        :param checkpoint_handler: Checkpoint handler object to find checkpoint paths for model initialization
        :param model_proc: whether we are running an ensemble model from within a child run with index 0. If we are,
        then outputs will be written to OTHER_RUNS/ENSEMBLE under the main outputs directory.
        :return: The epoch for which the model was registered, or None if no model was registered.
        """
        registration_epoch = self.decide_registration_epoch_without_evaluating()
        if registration_epoch is not None:
            model_description = f"Registering model for epoch {registration_epoch} without considering metrics."
            checkpoint_paths = checkpoint_handler.get_checkpoint_paths_from_epoch_or_fail(
                registration_epoch)
            self.register_model_for_epoch(checkpoint_paths, model_description,
                                          model_proc)
            if self.azure_config.register_model_only_for_epoch is not None:
                return self.azure_config.register_model_only_for_epoch

        # Run full image inference on the existing or newly trained model on the training, validation, and test sets
        test_metrics, val_metrics, _ = self.model_inference_train_and_test(
            checkpoint_handler=checkpoint_handler, model_proc=model_proc)

        # register the generated model from the run if we haven't already done so
        if self.model_config.is_segmentation_model and (
                not self.model_config.is_offline_run):
            if registration_epoch is None:
                if self.should_register_model():
                    assert test_metrics is None or isinstance(
                        test_metrics, InferenceMetricsForSegmentation)
                    assert val_metrics is None or isinstance(
                        val_metrics, InferenceMetricsForSegmentation)
                    registration_epoch = self.register_model_for_best_epoch(
                        checkpoint_handler, test_metrics, val_metrics,
                        model_proc)
            self.try_compare_scores_against_baselines(model_proc)
        else:
            logging.warning("Couldn't register model in offline mode")

        return registration_epoch
Code Example #13
def test_runner_restart(test_output_dirs: OutputFolderForTests) -> None:
    """
    Test if starting training from a folder where the checkpoints folder already has recovery checkpoints picks up
    that it is a recovery run. Also checks that we update the start epoch in the config at loading time.
    """
    model_config = DummyClassification()
    model_config.set_output_to(test_output_dirs.root_dir)
    model_config.num_epochs = FIXED_EPOCH + 2
    # We save all checkpoints - if recovery works as expected, we should get new checkpoints for
    # epochs FIXED_EPOCH and FIXED_EPOCH + 1.
    model_config.recovery_checkpoint_save_interval = 1
    model_config.recovery_checkpoints_save_last_k = -1
    runner = MLRunner(model_config=model_config)
    runner.setup(use_mount_or_download_dataset=False)
    # Epochs are 0 based for saving
    create_model_and_store_checkpoint(model_config,
                                      runner.container.checkpoint_folder /
                                      f"{RECOVERY_CHECKPOINT_FILE_NAME}_epoch="
                                      f"{FIXED_EPOCH - 1}{CHECKPOINT_SUFFIX}",
                                      weights_only=False)
    azure_config = get_default_azure_config()
    checkpoint_handler = CheckpointHandler(
        azure_config=azure_config,
        container=runner.container,
        project_root=test_output_dirs.root_dir)
    _, storing_logger = model_train(checkpoint_handler=checkpoint_handler,
                                    container=runner.container)
    # We expect to have 4 checkpoints: the recovery checkpoint for FIXED_EPOCH - 1 created above, plus
    # FIXED_EPOCH, FIXED_EPOCH + 1, and best.
    assert len(os.listdir(runner.container.checkpoint_folder)) == 4
    assert (runner.container.checkpoint_folder /
            f"{RECOVERY_CHECKPOINT_FILE_NAME}_epoch="
            f"{FIXED_EPOCH - 1}{CHECKPOINT_SUFFIX}").exists()
    assert (runner.container.checkpoint_folder /
            f"{RECOVERY_CHECKPOINT_FILE_NAME}_epoch="
            f"{FIXED_EPOCH}{CHECKPOINT_SUFFIX}").exists()
    assert (runner.container.checkpoint_folder /
            f"{RECOVERY_CHECKPOINT_FILE_NAME}_epoch="
            f"{FIXED_EPOCH + 1}{CHECKPOINT_SUFFIX}").exists()
    assert (runner.container.checkpoint_folder /
            BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX).exists()
    # Check that training really restarted from epoch FIXED_EPOCH.
    assert list(storing_logger.epochs) == [FIXED_EPOCH,
                                           FIXED_EPOCH + 1]  # type: ignore
Code Example #14
    def run(self) -> None:
        """
        Driver function to run a ML experiment. If an offline cross validation run is requested, then
        this function is recursively called for each cross validation split.
        """
        if self.is_offline_cross_val_parent_run():
            if self.model_config.is_segmentation_model:
                raise NotImplementedError(
                    "Offline cross validation is only supported for classification models."
                )
            self.spawn_offline_cross_val_classification_child_runs()
            return

        # Get the AzureML context in which the script is running
        if not self.model_config.is_offline_run and PARENT_RUN_CONTEXT is not None:
            logging.info("Setting tags from parent run.")
            self.set_run_tags_from_parent()

        self.save_build_info_for_dotnet_consumers()

        # Set data loader start method
        self.set_multiprocessing_start_method()

        # configure recovery container if provided
        checkpoint_handler = CheckpointHandler(model_config=self.model_config,
                                               azure_config=self.azure_config,
                                               project_root=self.project_root,
                                               run_context=RUN_CONTEXT)
        checkpoint_handler.discover_and_download_checkpoints_from_previous_runs()
        # do training and inference, unless the "only register" switch is set (which requires a run_recovery
        # to be valid).
        if not self.azure_config.register_model_only_for_epoch:
            # Set local_dataset to the mounted path specified in azure_runner.py, if any, or download it if that fails
            # and config.local_dataset was not already set.
            self.model_config.local_dataset = self.mount_or_download_dataset()
            self.model_config.write_args_file()
            logging.info(str(self.model_config))
            # Ensure that training runs are fully reproducible - setting random seeds alone is not enough!
            make_pytorch_reproducible()

            # Check for existing dataset.csv file in the correct locations. Skip that if a dataset has already been
            # loaded (typically only during tests)
            if self.model_config.dataset_data_frame is None:
                assert self.model_config.local_dataset is not None
                ml_util.validate_dataset_paths(self.model_config.local_dataset)

            # train a new model if required
            if self.azure_config.train:
                with logging_section("Model training"):
                    model_train(self.model_config, checkpoint_handler)
            else:
                self.model_config.write_dataset_files()
                self.create_activation_maps()

            # log the number of epochs used for model training
            RUN_CONTEXT.log(name="Train epochs",
                            value=self.model_config.num_epochs)

        # We specify the ModelProcessing as DEFAULT here even if the run_recovery points to an ensemble run, because
        # the current run is a single one. See the documentation of ModelProcessing for more details.
        best_epoch = self.run_inference_and_register_model(
            checkpoint_handler, ModelProcessing.DEFAULT)

        # Generate report
        if best_epoch:
            Runner.generate_report(self.model_config, best_epoch,
                                   ModelProcessing.DEFAULT)
        elif self.model_config.is_scalar_model and len(
                self.model_config.get_test_epochs()) == 1:
            # We don't register scalar models but still want to create a report if we have run inference.
            Runner.generate_report(self.model_config,
                                   self.model_config.get_test_epochs()[0],
                                   ModelProcessing.DEFAULT)
Code Example #15
def model_train(config: ModelConfigBase,
                checkpoint_handler: CheckpointHandler) -> ModelTrainingResults:
    """
    The main training loop. It creates the model, dataset, optimizer_type, and criterion, then proceeds
    to train the model. If a checkpoint was specified, then it loads the checkpoint before resuming training.

    :param config: The arguments which specify all required information.
    :param checkpoint_handler: Checkpoint handler object to find checkpoint paths for model initialization
    :raises TypeError: If the arguments are of the wrong type.
    :raises ValueError: When there are issues loading a previous checkpoint.
    """
    # Save the dataset files for later use in cross validation analysis
    config.write_dataset_files()

    # set the random seed for all libraries
    ml_util.set_random_seed(config.get_effective_random_seed(),
                            "Patch visualization")
    # Visualize how patches are sampled for segmentation models. This changes the random generator, but we don't
    # want training to depend on how many patients we visualized, and hence set the random seed again right after.
    with logging_section(
            "Visualizing the effect of sampling random crops for training"):
        visualize_random_crops_for_dataset(config)
    ml_util.set_random_seed(config.get_effective_random_seed(),
                            "Model training")

    logging.debug("Creating the PyTorch model.")

    # Create the train loader and validation loader to load images from the dataset
    data_loaders = config.create_data_loaders()

    # Get the path to the checkpoint to recover from
    checkpoint_path = checkpoint_handler.get_recovery_path_train()

    models_and_optimizer = ModelAndInfo(
        config=config,
        model_execution_mode=ModelExecutionMode.TRAIN,
        checkpoint_path=checkpoint_path)

    # Create the main model
    # If continuing from a previous run at a specific epoch, then load the previous model.
    model_loaded = models_and_optimizer.try_create_model_and_load_from_checkpoint()
    if not model_loaded:
        raise ValueError(
            "There was no checkpoint file available for the model for given start_epoch {}"
            .format(config.start_epoch))

    # Print out a detailed breakdown of layers, memory consumption and time.
    generate_and_print_model_summary(config, models_and_optimizer.model)

    # Move model to GPU and adjust for multiple GPUs
    models_and_optimizer.adjust_model_for_gpus()

    # Create the mean teacher model and move to GPU
    if config.compute_mean_teacher_model:
        mean_teacher_model_loaded = models_and_optimizer.try_create_mean_teacher_model_load_from_checkpoint_and_adjust()
        if not mean_teacher_model_loaded:
            raise ValueError(
                "There was no checkpoint file available for the mean teacher model "
                f"for given start_epoch {config.start_epoch}")

    # Create optimizer
    models_and_optimizer.create_optimizer()
    if checkpoint_handler.should_load_optimizer_checkpoint():
        optimizer_loaded = models_and_optimizer.try_load_checkpoint_for_optimizer()
        if not optimizer_loaded:
            raise ValueError(
                f"There was no checkpoint file available for the optimizer for given start_epoch "
                f"{config.start_epoch}")

    # Create checkpoint directory for this run if it doesn't already exist
    logging.info(f"Models are saved at {config.checkpoint_folder}")
    if not config.checkpoint_folder.is_dir():
        config.checkpoint_folder.mkdir()

    # Create the SummaryWriters for Tensorboard
    writers = create_summary_writers(config)
    config.create_dataframe_loggers()

    # Create LR scheduler
    l_rate_scheduler = SchedulerWithWarmUp(config,
                                           models_and_optimizer.optimizer)

    # Training loop
    logging.info("Starting training")
    train_results_per_epoch, val_results_per_epoch, learning_rates_per_epoch = [], [], []

    resource_monitor = None
    if config.monitoring_interval_seconds > 0:
        # initialize and start GPU monitoring
        diagnostics_events = config.logs_folder / "diagnostics"
        logging.info(
            f"Starting resource monitor, outputting to {diagnostics_events}")
        resource_monitor = ResourceMonitor(
            interval_seconds=config.monitoring_interval_seconds,
            tensorboard_folder=diagnostics_events)
        resource_monitor.start()

    gradient_scaler = GradScaler() if config.use_gpu and config.use_mixed_precision else None
    optimal_temperature_scale_values = []
    for epoch in config.get_train_epochs():
        logging.info("Starting epoch {}".format(epoch))
        save_epoch = config.should_save_epoch(
            epoch) and models_and_optimizer.optimizer is not None

        # store the learning rates used for each epoch
        epoch_lrs = l_rate_scheduler.get_last_lr()
        learning_rates_per_epoch.append(epoch_lrs)

        train_val_params: TrainValidateParameters = \
            TrainValidateParameters(data_loader=data_loaders[ModelExecutionMode.TRAIN],
                                    model=models_and_optimizer.model,
                                    mean_teacher_model=models_and_optimizer.mean_teacher_model,
                                    epoch=epoch,
                                    optimizer=models_and_optimizer.optimizer,
                                    gradient_scaler=gradient_scaler,
                                    epoch_learning_rate=epoch_lrs,
                                    summary_writers=writers,
                                    dataframe_loggers=config.metrics_data_frame_loggers,
                                    in_training_mode=True)
        training_steps = create_model_training_steps(config, train_val_params)
        train_epoch_results = train_or_validate_epoch(training_steps)
        train_results_per_epoch.append(train_epoch_results.metrics)

        metrics.validate_and_store_model_parameters(writers.train, epoch,
                                                    models_and_optimizer.model)
        # Run without adjusting weights on the validation set
        train_val_params.in_training_mode = False
        train_val_params.data_loader = data_loaders[ModelExecutionMode.VAL]
        # if temperature scaling is enabled then do not save validation metrics for the checkpoint epochs
        # as these will be re-computed after performing temperature scaling on the validation set.
        if isinstance(config, SequenceModelBase):
            train_val_params.save_metrics = not (
                save_epoch and config.temperature_scaling_config)

        training_steps = create_model_training_steps(config, train_val_params)
        val_epoch_results = train_or_validate_epoch(training_steps)
        val_results_per_epoch.append(val_epoch_results.metrics)

        if config.is_segmentation_model:
            metrics.store_epoch_stats_for_segmentation(
                config.outputs_folder, epoch, epoch_lrs,
                train_epoch_results.metrics, val_epoch_results.metrics)

        if save_epoch:
            # perform temperature scaling if required
            if isinstance(
                    config,
                    SequenceModelBase) and config.temperature_scaling_config:
                optimal_temperature, scaled_val_results = \
                    temperature_scaling_steps(config, train_val_params, val_epoch_results)
                optimal_temperature_scale_values.append(optimal_temperature)
                # overwrite the metrics for the epoch with the metrics from the temperature scaled model
                val_results_per_epoch[-1] = scaled_val_results.metrics

            models_and_optimizer.save_checkpoint(epoch)

        # Updating the learning rate should happen at the end of the training loop, so that the
        # initial learning rate will be used for the very first epoch.
        l_rate_scheduler.step()

    model_training_results = ModelTrainingResults(
        train_results_per_epoch=train_results_per_epoch,
        val_results_per_epoch=val_results_per_epoch,
        learning_rates_per_epoch=learning_rates_per_epoch,
        optimal_temperature_scale_values_per_checkpoint_epoch=
        optimal_temperature_scale_values)

    logging.info("Finished training")

    # Since we have trained the model further, let the checkpoint_handler object know so it can handle
    # checkpoints correctly.
    checkpoint_handler.additional_training_done()

    # Upload visualization directory to AML run context to be able to see it
    # in the Azure UI.
    if config.max_batch_grad_cam > 0 and config.visualization_folder.exists():
        RUN_CONTEXT.upload_folder(name=VISUALIZATION_FOLDER,
                                  path=str(config.visualization_folder))

    writers.close_all()
    config.metrics_data_frame_loggers.close_all()
    if resource_monitor:
        # stop the resource monitoring process
        logging.info(
            "Shutting down the resource monitor process. Aggregate resource utilization:"
        )
        for name, value in resource_monitor.read_aggregate_metrics():
            logging.info(f"{name}: {value}")
            if not is_offline_run_context(RUN_CONTEXT):
                RUN_CONTEXT.log(name, value)
        resource_monitor.kill()

    return model_training_results
Code Example #16
def model_train(config: ModelConfigBase,
                checkpoint_handler: CheckpointHandler,
                num_nodes: int = 1) -> ModelTrainingResults:
    """
    The main training loop. It creates the Pytorch model based on the configuration options passed in,
    creates a Pytorch Lightning trainer, and trains the model.
    If a checkpoint was specified, then it loads the checkpoint before resuming training.
    :param config: The arguments which specify all required information.
    :param checkpoint_handler: Checkpoint handler object to find checkpoint paths for model initialization
    :param num_nodes: The number of nodes to use in distributed training.
    """
    # Get the path to the checkpoint to recover from
    checkpoint_path = checkpoint_handler.get_recovery_path_train()
    # This reads the dataset file, and possibly sets required pre-processing objects, like one-hot encoder
    # for categorical features, that need to be available before creating the model.
    config.read_dataset_if_needed()

    # Create the trainer object. Back up the environment variables before doing that, in case we need to run a second
    # training in the unit tests.
    old_environ = dict(os.environ)
    seed_everything(config.get_effective_random_seed())
    trainer, storing_logger = create_lightning_trainer(config, checkpoint_path, num_nodes=num_nodes)

    logging.info(f"GLOBAL_RANK: {os.getenv('GLOBAL_RANK')}, LOCAL_RANK {os.getenv('LOCAL_RANK')}. "
                 f"trainer.global_rank: {trainer.global_rank}")
    logging.debug("Creating the PyTorch model.")
    lightning_model = create_lightning_model(config)
    lightning_model.storing_logger = storing_logger

    resource_monitor = None
    # Execute some bookkeeping tasks only once if running distributed:
    if is_rank_zero():
        config.write_args_file()
        logging.info(str(config))
        # Save the dataset files for later use in cross validation analysis
        config.write_dataset_files()
        logging.info(f"Model checkpoints are saved at {config.checkpoint_folder}")

        # set the random seed for all libraries
        ml_util.set_random_seed(config.get_effective_random_seed(), "Patch visualization")
        # Visualize how patches are sampled for segmentation models. This changes the random generator, but we don't
        # want training to depend on how many patients we visualized, and hence set the random seed again right after.
        with logging_section("Visualizing the effect of sampling random crops for training"):
            visualize_random_crops_for_dataset(config)

        # Print out a detailed breakdown of layers, memory consumption and time.
        generate_and_print_model_summary(config, lightning_model.model)

        if config.monitoring_interval_seconds > 0:
            # initialize and start GPU monitoring
            diagnostics_events = config.logs_folder / "diagnostics"
            logging.info(f"Starting resource monitor, outputting to {diagnostics_events}")
            resource_monitor = ResourceMonitor(interval_seconds=config.monitoring_interval_seconds,
                                               tensorboard_folder=diagnostics_events)
            resource_monitor.start()

    # Training loop
    logging.info("Starting training")

    lightning_data = TrainingAndValidationDataLightning(config)  # type: ignore
    # When trying to store the config object in the constructor, it does not appear to get stored at all, later
    # reference of the object simply fail. Hence, have to set explicitly here.
    lightning_data.config = config
    trainer.fit(lightning_model, datamodule=lightning_data)
    trainer.logger.close()  # type: ignore
    lightning_model.close_all_loggers()
    world_size = getattr(trainer, "world_size", 0)
    is_azureml_run = not config.is_offline_run
    # Per-subject model outputs for regression models are written per rank, and need to be aggregated here.
    # Each thread per rank will come here, and upload its files to the run outputs. Rank 0 will later download them.
    if is_azureml_run and world_size > 1 and isinstance(lightning_model, ScalarLightning):
        upload_output_file_as_temp(lightning_model.train_subject_outputs_logger.csv_path, config.outputs_folder)
        upload_output_file_as_temp(lightning_model.val_subject_outputs_logger.csv_path, config.outputs_folder)
    # DDP will start multiple instances of the runner, one for each GPU. Those should terminate here after training.
    # We can now use the global_rank of the Lightning model, rather than environment variables, because DDP has set
    # all necessary properties.
    if lightning_model.global_rank != 0:
        logging.info(f"Terminating training thread with rank {lightning_model.global_rank}.")
        sys.exit()

    logging.info("Choosing the best checkpoint and removing redundant files.")
    cleanup_checkpoint_folder(config.checkpoint_folder)
    # Lightning modifies a ton of environment variables. If we first run training and then the test suite,
    # those environment variables will mislead the training runs in the test suite, and make them crash.
    # Hence, restore the original environment after training.
    os.environ.clear()
    os.environ.update(old_environ)

    if world_size and isinstance(lightning_model, ScalarLightning):
        if is_azureml_run and world_size > 1:
            # In a DDP run on the local box, all ranks will write to local disk, hence no download needed.
            # In a multi-node DDP, each rank would upload to AzureML, and rank 0 will now download all results and
            # concatenate
            for rank in range(world_size):
                for mode in [ModelExecutionMode.TRAIN, ModelExecutionMode.VAL]:
                    file = mode.value + "/" + get_subject_output_file_per_rank(rank)
                    RUN_CONTEXT.download_file(name=TEMP_PREFIX + file, output_file_path=config.outputs_folder / file)
        # Concatenate all temporary files per execution mode
        for mode in [ModelExecutionMode.TRAIN, ModelExecutionMode.VAL]:
            temp_files = (config.outputs_folder / mode.value).rglob(SUBJECT_OUTPUT_PER_RANK_PREFIX + "*")
            result_file = config.outputs_folder / mode.value / SUBJECT_METRICS_FILE_NAME
            for i, file in enumerate(temp_files):
                temp_file_contents = file.read_text()
                if i == 0:
                    # Copy the first file as-is, including the first line with the column headers
                    result_file.write_text(temp_file_contents)
                else:
                    # For all files but the first one, cut off the header line.
                    result_file.write_text(os.linesep.join(temp_file_contents.splitlines()[1:]))

    model_training_results = ModelTrainingResults(
        train_results_per_epoch=list(storing_logger.to_metrics_dicts(prefix_filter=TRAIN_PREFIX).values()),
        val_results_per_epoch=list(storing_logger.to_metrics_dicts(prefix_filter=VALIDATION_PREFIX).values()),
        train_diagnostics=lightning_model.train_diagnostics,
        val_diagnostics=lightning_model.val_diagnostics,
        optimal_temperature_scale_values_per_checkpoint_epoch=[]
    )

    logging.info("Finished training")

    # Since we have trained the model further, let the checkpoint_handler object know so it can handle
    # checkpoints correctly.
    checkpoint_handler.additional_training_done()

    # Upload visualization directory to AML run context to be able to see it
    # in the Azure UI.
    if config.max_batch_grad_cam > 0 and config.visualization_folder.exists():
        RUN_CONTEXT.upload_folder(name=VISUALIZATION_FOLDER, path=str(config.visualization_folder))

    if resource_monitor:
        # stop the resource monitoring process
        logging.info("Shutting down the resource monitor process. Aggregate resource utilization:")
        for name, value in resource_monitor.read_aggregate_metrics():
            logging.info(f"{name}: {value}")
            if not config.is_offline_run:
                RUN_CONTEXT.log(name, value)
        resource_monitor.kill()

    return model_training_results
Code Example #17
    def run(self) -> None:
        """
        Driver function to run a ML experiment. If an offline cross validation run is requested, then
        this function is recursively called for each cross validation split.
        """
        if self.is_offline_cross_val_parent_run():
            if self.model_config.is_segmentation_model:
                raise NotImplementedError("Offline cross validation is only supported for classification models.")
            self.spawn_offline_cross_val_classification_child_runs()
            return

        # Get the AzureML context in which the script is running
        if not self.model_config.is_offline_run and PARENT_RUN_CONTEXT is not None:
            logging.info("Setting tags from parent run.")
            self.set_run_tags_from_parent()

        self.save_build_info_for_dotnet_consumers()

        # Set data loader start method
        self.set_multiprocessing_start_method()

        # configure recovery container if provided
        checkpoint_handler = CheckpointHandler(model_config=self.model_config,
                                               azure_config=self.azure_config,
                                               project_root=self.project_root,
                                               run_context=RUN_CONTEXT)
        checkpoint_handler.download_recovery_checkpoints_or_weights()
        # do training and inference, unless the "only register" switch is set (which requires a run_recovery
        # to be valid).
        if not self.azure_config.only_register_model:
            # Set local_dataset to the mounted path specified in azure_runner.py, if any, or download it if that fails
            # and config.local_dataset was not already set.
            self.model_config.local_dataset = self.mount_or_download_dataset()
            # Check for existing dataset.csv file in the correct locations. Skip that if a dataset has already been
            # loaded (typically only during tests)
            if self.model_config.dataset_data_frame is None:
                assert self.model_config.local_dataset is not None
                ml_util.validate_dataset_paths(
                    self.model_config.local_dataset,
                    self.model_config.dataset_csv)

            # train a new model if required
            if self.azure_config.train:
                with logging_section("Model training"):
                    model_train(self.model_config, checkpoint_handler, num_nodes=self.azure_config.num_nodes)
            else:
                self.model_config.write_dataset_files()
                self.create_activation_maps()

            # log the number of epochs used for model training
            RUN_CONTEXT.log(name="Train epochs", value=self.model_config.num_epochs)

        # We specify the ModelProcessing as DEFAULT here even if the run_recovery points to an ensemble run, because
        # the current run is a single one. See the documentation of ModelProcessing for more details.
        self.run_inference_and_register_model(checkpoint_handler, ModelProcessing.DEFAULT)

        if self.model_config.generate_report:
            self.generate_report(ModelProcessing.DEFAULT)

        # If this is a cross validation run, and the present run is child run 0, then wait for the sibling runs,
        # build the ensemble model, and write a report for that.
        if self.model_config.number_of_cross_validation_splits > 0:
            if self.model_config.should_wait_for_other_cross_val_child_runs():
                self.wait_for_runs_to_finish()
                self.create_ensemble_model()
Code Example #18
def classification_model_test(config: ScalarModelBase,
                              data_split: ModelExecutionMode,
                              checkpoint_handler: CheckpointHandler,
                              model_proc: ModelProcessing) -> InferenceMetricsForClassification:
    """
    The main testing loop for classification models. It runs a loop over all epochs for which testing should be done.
    It loads the model and datasets, then proceeds to test the model for all requested checkpoints.
    :param config: The model configuration.
    :param data_split: The name of the folder to store the results inside each epoch folder in the outputs_dir,
                       used mainly in model evaluation using different dataset splits.
    :param checkpoint_handler: Checkpoint handler object to find checkpoint paths for model initialization
    :param model_proc: whether we are testing an ensemble or single model
    :return: InferenceMetricsForClassification object that contains the metrics for all of the checkpoint epochs.
    """

    def test_epoch(checkpoint_paths: List[Path]) -> Optional[MetricsDict]:
        pipeline = create_inference_pipeline(config=config,
                                             checkpoint_paths=checkpoint_paths)

        if pipeline is None:
            return None

        # for mypy
        assert isinstance(pipeline, ScalarInferencePipelineBase)

        ml_util.set_random_seed(config.get_effective_random_seed(), "Model Testing")
        ds = config.get_torch_dataset_for_inference(data_split).as_data_loader(
            shuffle=False,
            batch_size=1,
            num_dataload_workers=0
        )

        logging.info(f"Starting to evaluate model on {data_split.value} set.")
        metrics_dict = create_metrics_dict_for_scalar_models(config)
        for sample in ds:
            result = pipeline.predict(sample)
            model_output = result.posteriors
            label = result.labels.to(device=model_output.device)
            sample_id = result.subject_ids[0]
            compute_scalar_metrics(metrics_dict,
                                   subject_ids=[sample_id],
                                   model_output=model_output,
                                   labels=label,
                                   loss_type=config.loss_type)
            logging.debug(f"Example {sample_id}: {metrics_dict.to_string()}")

        average = metrics_dict.average(across_hues=False)
        logging.info(average.to_string())

        return metrics_dict

    checkpoints_to_test = checkpoint_handler.get_checkpoints_to_test()

    if not checkpoints_to_test:
        raise ValueError("There were no checkpoints available for model testing.")

    result = test_epoch(checkpoint_paths=checkpoints_to_test)
    if result is None:
        raise ValueError("There was no single checkpoint file available for model testing.")
    else:
        if isinstance(result, ScalarMetricsDict):
            results_folder = config.outputs_folder / get_epoch_results_path(data_split, model_proc)
            csv_file = results_folder / SUBJECT_METRICS_FILE_NAME

            logging.info(f"Writing {data_split.value} metrics to file {str(csv_file)}")

            # If we are running inference after a training run, the validation set metrics may have been written
            # during train time. If this is not the case, or we are running on the test set, create the metrics
            # file.
            if not csv_file.exists():
                os.makedirs(str(results_folder), exist_ok=False)
                df_logger = DataframeLogger(csv_file)

                # cross validation split index not relevant during test time
                result.store_metrics_per_subject(df_logger=df_logger,
                                                 mode=data_split)
                # write to disk
                df_logger.flush()

    return InferenceMetricsForClassification(metrics=result)
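A possible call site, sketched under the assumption that config and checkpoint_handler are prepared by the surrounding MLRunner code as in the other examples; the .metrics attribute and its to_string() method are inferred from the constructor call and usage shown above.

val_result = classification_model_test(config=config,
                                       data_split=ModelExecutionMode.VAL,
                                       checkpoint_handler=checkpoint_handler,
                                       model_proc=ModelProcessing.DEFAULT)
logging.info(f"Validation metrics: {val_result.metrics.to_string()}")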
Code Example #19
def model_train(checkpoint_handler: CheckpointHandler,
                container: LightningContainer,
                num_nodes: int = 1) -> Tuple[Trainer, Optional[StoringLogger]]:
    """
    The main training loop. It creates the Pytorch model based on the configuration options passed in,
    creates a Pytorch Lightning trainer, and trains the model.
    If a checkpoint was specified, then it loads the checkpoint before resuming training.
    :param checkpoint_handler: Checkpoint handler object to find checkpoint paths for model initialization
    :param container: A container object that holds the training data in PyTorch Lightning format
    and the model to train.
    :param num_nodes: The number of nodes to use in distributed training.
    :return: A tuple of [Trainer, StoringLogger]. Trainer is the Lightning Trainer object that was used for fitting
    the model. The StoringLogger object is returned when training an InnerEye built-in model, this is None when
    fitting other models.
    """
    # Get the path to the checkpoint to recover from
    checkpoint_path = checkpoint_handler.get_recovery_path_train()
    lightning_model = container.model

    resource_monitor: Optional[ResourceMonitor] = None
    # Execute some bookkeeping tasks only once if running distributed:
    if is_global_rank_zero():
        logging.info(
            f"Model checkpoints are saved at {container.checkpoint_folder}")
        write_args_file(container.config if isinstance(
            container, InnerEyeContainer) else container,
                        outputs_folder=container.outputs_folder)
        if container.monitoring_interval_seconds > 0:
            resource_monitor = start_resource_monitor(container)

    # Run all of the container-related operations consistently with changed outputs folder, even ones that
    # should not rely on the current working directory, like get_data_module.
    with change_working_directory(container.outputs_folder):
        data_module = container.get_data_module()
        if is_global_rank_zero():
            container.before_training_on_global_rank_zero()
        if is_local_rank_zero():
            container.before_training_on_local_rank_zero()
        container.before_training_on_all_ranks()

    # Create the trainer object. Back up the environment variables before doing that, in case we need to run a second
    # training in the unit tests.
    old_environ = dict(os.environ)
    # Set random seeds just before training. For segmentation models, we have
    # something that changes the random seed in the before_training_on_rank_zero hook.
    seed_everything(container.get_effective_random_seed())
    trainer, storing_logger = create_lightning_trainer(
        container,
        checkpoint_path,
        num_nodes=num_nodes,
        **container.get_trainer_arguments())
    rank_info = ", ".join(
        f"{env}: {os.getenv(env)}"
        for env in [ENV_GLOBAL_RANK, ENV_LOCAL_RANK, ENV_NODE_RANK])
    logging.info(
        f"Environment variables: {rank_info}. trainer.global_rank: {trainer.global_rank}"
    )
    # InnerEye models use this logger for diagnostics
    if isinstance(lightning_model, InnerEyeLightning):
        if storing_logger is None:
            raise ValueError(
                "InnerEye models require the storing_logger for diagnostics")
        lightning_model.storing_logger = storing_logger

    logging.info("Starting training")
    # When training models that are not built-in InnerEye models, we have no guarantee that they write
    # files to the right folder. Best guess is to change the current working directory to where files should go.
    with change_working_directory(container.outputs_folder):
        trainer.fit(lightning_model, datamodule=data_module)
        trainer.logger.close()  # type: ignore
    world_size = getattr(trainer, "world_size", 0)
    is_azureml_run = not is_offline_run_context(RUN_CONTEXT)
    # Per-subject model outputs for scalar models are written per rank, and need to be aggregated here.
    # Each rank will come here and upload its files to the run outputs; rank 0 will later download them.
    if is_azureml_run and world_size > 1 and isinstance(
            lightning_model, ScalarLightning):
        upload_output_file_as_temp(
            lightning_model.train_subject_outputs_logger.csv_path,
            container.outputs_folder)
        upload_output_file_as_temp(
            lightning_model.val_subject_outputs_logger.csv_path,
            container.outputs_folder)
    # DDP will start multiple instances of the runner, one for each GPU. Those should terminate here after training.
    # We can now use the global_rank of the Lightning model, rather than environment variables, because DDP has set
    # all necessary properties.
    if lightning_model.global_rank != 0:
        logging.info(
            f"Terminating training thread with rank {lightning_model.global_rank}."
        )
        sys.exit()

    logging.info("Choosing the best checkpoint and removing redundant files.")
    create_best_checkpoint(container.checkpoint_folder)
    # Lightning modifies a ton of environment variables. If we first run training and then the test suite,
    # those environment variables will mislead the training runs in the test suite, and make them crash.
    # Hence, restore the original environment after training.
    os.environ.clear()
    os.environ.update(old_environ)

    if world_size and isinstance(lightning_model, ScalarLightning):
        if is_azureml_run and world_size > 1:
            # In a DDP run on the local box, all ranks write to local disk, hence no download is needed.
            # In an AzureML DDP run, each rank has uploaded its results, and rank 0 now downloads all of them and
            # concatenates them.
            for rank in range(world_size):
                for mode in [ModelExecutionMode.TRAIN, ModelExecutionMode.VAL]:
                    file = mode.value + "/" + get_subject_output_file_per_rank(
                        rank)
                    RUN_CONTEXT.download_file(
                        name=TEMP_PREFIX + file,
                        output_file_path=container.outputs_folder / file)
        # Concatenate all temporary per-rank files for each execution mode
        aggregate_and_create_subject_metrics_file(container.outputs_folder)

    logging.info("Finished training")

    # Since we have trained the model further, let the checkpoint_handler object know so it can handle
    # checkpoints correctly.
    checkpoint_handler.additional_training_done()

    # Upload visualization directory to AML run context to be able to see it in the Azure UI.
    if isinstance(container, InnerEyeContainer):
        if container.config.max_batch_grad_cam > 0 and container.visualization_folder.exists(
        ):
            RUN_CONTEXT.upload_folder(name=VISUALIZATION_FOLDER,
                                      path=str(container.visualization_folder))

    if resource_monitor:
        logging.info("Shutting down the resource monitor process.")
        if is_azureml_run:
            for gpu_name, metrics_per_gpu in resource_monitor.read_aggregate_metrics(
            ).items():
                # Log as a table, with GPU being the first column
                RUN_CONTEXT.log_row("GPU utilization",
                                    GPU=gpu_name,
                                    **metrics_per_gpu)
        resource_monitor.kill()

    return trainer, storing_logger
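
For orientation, here is a minimal usage sketch of model_train. The names my_container and my_checkpoint_handler are illustrative assumptions (a fully configured LightningContainer and an already constructed CheckpointHandler); neither is defined in the example above.

# Illustrative usage sketch, not part of the original example.
# `my_container` is assumed to be a configured LightningContainer subclass instance,
# and `my_checkpoint_handler` a CheckpointHandler built from matching model/Azure configs.
trainer, storing_logger = model_train(checkpoint_handler=my_checkpoint_handler,
                                      container=my_container,
                                      num_nodes=1)
if storing_logger is None:
    # Only InnerEye built-in models populate the StoringLogger; it is None for other Lightning models.
    logging.info("Training finished without a StoringLogger (non-InnerEye model).")
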
Code example #20
0
def classification_model_test(config: ScalarModelBase,
                              data_split: ModelExecutionMode,
                              checkpoint_handler: CheckpointHandler,
                              model_proc: ModelProcessing,
                              cross_val_split_index: int) -> InferenceMetricsForClassification:
    """
    The main testing loop for classification models. It runs a loop over all epochs for which testing should be done.
    It loads the model and datasets, then proceeds to test the model for all requested checkpoints.
    :param config: The model configuration.
    :param data_split: The name of the folder to store the results inside each epoch folder in the outputs_dir,
                       used mainly for model evaluation on different dataset splits.
    :param checkpoint_handler: Checkpoint handler object to find checkpoint paths for model initialization.
    :param model_proc: Whether we are testing an ensemble or a single model.
    :param cross_val_split_index: The cross-validation split index to record against each prediction.
    :return: InferenceMetricsForClassification object that contains metrics for all of the checkpoint epochs.
    """
    posthoc_label_transform = config.get_posthoc_label_transform()

    checkpoint_paths = checkpoint_handler.get_checkpoints_to_test()
    if not checkpoint_paths:
        raise ValueError("There were no checkpoints available for model testing.")

    pipeline = create_inference_pipeline(config=config,
                                         checkpoint_paths=checkpoint_paths)
    if pipeline is None:
        raise ValueError("Inference pipeline could not be created.")

    # for mypy
    assert isinstance(pipeline, ScalarInferencePipelineBase)

    ml_util.set_random_seed(config.get_effective_random_seed(), "Model Testing")
    ds = config.get_torch_dataset_for_inference(data_split).as_data_loader(
        shuffle=False,
        batch_size=1,
        num_dataload_workers=0
    )

    logging.info(f"Starting to evaluate model on {data_split.value} set.")
    results_folder = config.outputs_folder / get_best_epoch_results_path(data_split, model_proc)
    os.makedirs(str(results_folder), exist_ok=True)
    metrics_dict = create_metrics_dict_for_scalar_models(config)
    if not isinstance(config, SequenceModelBase):
        output_logger: Optional[DataframeLogger] = DataframeLogger(csv_path=results_folder / MODEL_OUTPUT_CSV)
    else:
        output_logger = None

    for sample in ds:
        result = pipeline.predict(sample)
        model_output = result.posteriors
        label = result.labels.to(device=model_output.device)
        label = posthoc_label_transform(label)
        sample_id = result.subject_ids[0]
        if output_logger:
            for i in range(len(config.target_names)):
                output_logger.add_record({LoggingColumns.Patient.value: sample_id,
                                          LoggingColumns.Hue.value: config.target_names[i],
                                          LoggingColumns.Label.value: label[0][i].item(),
                                          LoggingColumns.ModelOutput.value: model_output[0][i].item(),
                                          LoggingColumns.CrossValidationSplitIndex.value: cross_val_split_index})

        compute_scalar_metrics(metrics_dict,
                               subject_ids=[sample_id],
                               model_output=model_output,
                               labels=label,
                               loss_type=config.loss_type)
        logging.debug(f"Example {sample_id}: {metrics_dict.to_string()}")

    average = metrics_dict.average(across_hues=False)
    logging.info(average.to_string())

    if isinstance(metrics_dict, ScalarMetricsDict):
        csv_file = results_folder / SUBJECT_METRICS_FILE_NAME

        logging.info(f"Writing {data_split.value} metrics to file {str(csv_file)}")

        # If we are running inference after a training run, the validation set metrics may have been written
        # during train time. If this is not the case, or we are running on the test set, create the metrics
        # file.
        if not csv_file.exists():
            df_logger = DataframeLogger(csv_file)
            # When testing an ensemble, use the default split index; otherwise record which fold produced this prediction
            cv_index = DEFAULT_CROSS_VALIDATION_SPLIT_INDEX if model_proc == ModelProcessing.ENSEMBLE_CREATION \
                else cross_val_split_index
            metrics_dict.store_metrics_per_subject(df_logger=df_logger,
                                                   mode=data_split,
                                                   cross_validation_split_index=cv_index,
                                                   epoch=BEST_EPOCH_FOLDER_NAME)
            # write to disk
            df_logger.flush()

    if output_logger:
        output_logger.flush()

    return InferenceMetricsForClassification(metrics=metrics_dict)
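
For orientation, a minimal usage sketch of classification_model_test follows. The objects my_config and my_checkpoint_handler are illustrative assumptions (a configured ScalarModelBase and a CheckpointHandler that already holds checkpoints to test); the enum values and constants used below appear in the example above or in the surrounding codebase.

# Illustrative usage sketch, not part of the original example.
# `my_config` and `my_checkpoint_handler` are hypothetical, pre-configured objects.
metrics = classification_model_test(config=my_config,
                                     data_split=ModelExecutionMode.TEST,
                                     checkpoint_handler=my_checkpoint_handler,
                                     model_proc=ModelProcessing.DEFAULT,
                                     cross_val_split_index=DEFAULT_CROSS_VALIDATION_SPLIT_INDEX)
# InferenceMetricsForClassification wraps the ScalarMetricsDict computed above; log the averaged metrics.
logging.info(metrics.metrics.average(across_hues=False).to_string())
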