Example #1
def plot_cross_validation(config: PlotCrossValidationConfig) -> Path:
    """
    Collects results from an AzureML cross validation run, and writes aggregate metrics files.
    :param config: The settings for plotting cross validation results.
    :return: The path with all cross validation result files.
    """
    logging_to_stdout(logging.INFO)
    with logging_section("Downloading cross-validation results"):
        result_files, root_folder = download_crossval_result_files(config)
    config_and_files = OfflineCrossvalConfigAndFiles(config=config, files=result_files)
    with logging_section("Plotting cross-validation results"):
        plot_cross_validation_from_files(config_and_files, root_folder)
    return root_folder
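# A minimal usage sketch (not part of the source): driving plot_cross_validation from a script.
# The config fields shown (`run_recovery_id`, `outputs_directory`) are assumptions for illustration;
# check PlotCrossValidationConfig for the exact parameter names.
if __name__ == "__main__":
    crossval_config = PlotCrossValidationConfig(
        run_recovery_id="experiment_name:run_id",  # hypothetical AzureML run identifier
        outputs_directory=Path("outputs/crossval_analysis"))  # hypothetical local output folder
    results_root = plot_cross_validation(crossval_config)
    print(f"Aggregated cross-validation files are in {results_root}")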
Example #2
    def register_model(self,
                       checkpoint_paths: List[Path],
                       model_description: str,
                       model_proc: ModelProcessing) -> None:
        """
        Registers the model in AzureML, with the given set of checkpoints. The AzureML run's tags are updated
        with information about ensemble creation and the parent run ID.
        :param checkpoint_paths: The set of Pytorch checkpoints that should be included.
        :param model_description: A string description of the model, usually containing accuracy numbers.
        :param model_proc: The type of model that is registered (single or ensemble)
        """
        if not checkpoint_paths:
            # No point continuing, since no checkpoints were found
            logging.warning("Abandoning model registration - no valid checkpoint paths found")
            return

        if not self.model_config.is_offline_run:
            split_index = RUN_CONTEXT.get_tags().get(CROSS_VALIDATION_SPLIT_INDEX_TAG_KEY, None)
            if split_index == DEFAULT_CROSS_VALIDATION_SPLIT_INDEX:
                RUN_CONTEXT.tag(IS_ENSEMBLE_KEY_NAME, str(model_proc == ModelProcessing.ENSEMBLE_CREATION))
            elif PARENT_RUN_CONTEXT is not None:
                RUN_CONTEXT.tag(PARENT_RUN_ID_KEY_NAME, str(PARENT_RUN_CONTEXT.id))
        if isinstance(self.model_config, SegmentationModelBase):
            with logging_section(f"Registering {model_proc.value} model"):
                self.register_segmentation_model(
                    checkpoint_paths=checkpoint_paths,
                    model_description=model_description,
                    model_proc=model_proc)
        else:
            logging.info(f"No deployment done for this type of model: {type(self.model_config)}")
Example #3
def plot_cross_validation(config: PlotCrossValidationConfig) -> Path:
    """
    Collects results from an AzureML cross validation run, and writes aggregate metrics files.
    If the analysis includes the parent (ensemble) run, N+1 result files are expected; if it only concerns the
    cross validation child runs, the expected number of files is N.
    :param config: The settings for plotting cross validation results.
    :return: The path with all cross validation result files.
    """
    logging_to_stdout(logging.INFO)
    with logging_section("Downloading cross-validation results"):
        result_files, root_folder = download_crossval_result_files(config)
    config_and_files = OfflineCrossvalConfigAndFiles(config=config,
                                                     files=result_files)
    with logging_section("Plotting cross-validation results"):
        plot_cross_validation_from_files(config_and_files, root_folder)
    return root_folder
Example #4
def download_dataset(azure_dataset_id: str,
                     target_folder: Path,
                     dataset_csv: str,
                     azure_config: AzureConfig) -> Path:
    """
    Downloads a dataset from AzureML, or checks for an existing download on the executing machine. The dataset
    specified by azure_dataset_id is downloaded from the AzureML dataset attached to the given AzureML workspace,
    into the `target_folder`, in a subfolder that has the same name as the dataset. If there already appears to be
    such a folder, and the folder contains a dataset csv file, no download is started.
    :param azure_dataset_id: The name of a dataset that is registered in the AzureML workspace.
    :param target_folder: The folder in which to download the dataset from Azure.
    :param dataset_csv: Name of the csv file describing the dataset.
    :param azure_config: All Azure-related configuration options.
    :return: A path on the local machine that contains the dataset.
    """
    logging.info("Trying to download dataset via AzureML datastore now.")
    azure_dataset = get_or_create_dataset(azure_config, azure_dataset_id)
    if not isinstance(azure_dataset, FileDataset):
        raise ValueError(f"Expected to get a FileDataset, but got {type(azure_dataset)}")
    # The downloaded dataset may already exist from a previous run.
    expected_dataset_path = target_folder / azure_dataset_id
    expected_dataset_file = expected_dataset_path / dataset_csv
    logging.info(f"Model training will use dataset '{azure_dataset_id}' in Azure.")
    if expected_dataset_path.is_dir() and expected_dataset_file.is_file():
        logging.info(f"The dataset appears to be downloaded already in {expected_dataset_path}. Skipping.")
        return expected_dataset_path
    logging.info("Starting to download the dataset - WARNING, this could take very long!")
    with logging_section("Downloading dataset"):
        t0 = time.perf_counter()
        azure_dataset.download(target_path=str(expected_dataset_path), overwrite=False)
        t1 = time.perf_counter() - t0
        logging.info(f"Azure dataset '{azure_dataset_id}' downloaded in {t1} seconds")
    logging.info(f"Azure dataset '{azure_dataset_id}' is now available in {expected_dataset_path}")
    return expected_dataset_path
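# Usage sketch (dataset name and folder are placeholders): download a registered AzureML dataset
# into ./datasets/<dataset-name>, unless a previous download is already present there.
# get_default_azure_config() is the helper used in the test examples further down.
azure_config = get_default_azure_config()
dataset_root = download_dataset(azure_dataset_id="my-dataset",
                                target_folder=Path("datasets"),
                                dataset_csv="dataset.csv",
                                azure_config=azure_config)
print(f"Dataset is available at {dataset_root}")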
Example #5
    def register_model_for_epoch(self, run_context: Run,
                                 checkpoint_handler: CheckpointHandler,
                                 best_epoch: int, best_epoch_dice: float,
                                 model_proc: ModelProcessing) -> None:
        """
        Registers the model checkpoint(s) for the given epoch in AzureML, and updates the run's tags with
        information about ensemble creation and the parent run ID.
        :param run_context: The AzureML run in which the model is registered.
        :param checkpoint_handler: The checkpoint handler that locates the checkpoint files for the given epoch.
        :param best_epoch: The epoch for which the model should be registered.
        :param best_epoch_dice: The Dice score that the model achieved at that epoch.
        :param model_proc: The type of model that is registered (single or ensemble).
        """
        checkpoint_path_and_epoch = checkpoint_handler.get_checkpoint_from_epoch(
            epoch=best_epoch)
        if not checkpoint_path_and_epoch or not checkpoint_path_and_epoch.checkpoint_paths:
            # No point continuing, since no checkpoints were found
            logging.warning(
                "Abandoning model registration - no valid checkpoint paths found"
            )
            return

        if not self.model_config.is_offline_run:
            split_index = run_context.get_tags().get(
                CROSS_VALIDATION_SPLIT_INDEX_TAG_KEY, None)
            if split_index == DEFAULT_CROSS_VALIDATION_SPLIT_INDEX:
                update_run_tags(
                    run_context, {
                        IS_ENSEMBLE_KEY_NAME:
                        model_proc == ModelProcessing.ENSEMBLE_CREATION
                    })
            elif PARENT_RUN_CONTEXT is not None:
                update_run_tags(
                    run_context,
                    {PARENT_RUN_ID_KEY_NAME: PARENT_RUN_CONTEXT.id})
        with logging_section(f"Registering {model_proc.value} model"):
            self.register_segmentation_model(
                run=run_context,
                best_epoch=best_epoch,
                best_epoch_dice=best_epoch_dice,
                checkpoint_paths=checkpoint_path_and_epoch.checkpoint_paths,
                model_proc=model_proc)
Example #6
def model_test(
    config: ModelConfigBase,
    data_split: ModelExecutionMode,
    run_recovery: Optional[RunRecovery] = None,
    model_proc: ModelProcessing = ModelProcessing.DEFAULT
) -> Optional[InferenceMetrics]:
    """
    Runs model inference on segmentation or classification models, using a given dataset (that could be training,
    test or validation set). The inference results and metrics will be stored and logged in a way that may
    differ for model categories (classification, segmentation).
    :param config: The configuration of the model
    :param data_split: Indicates which of the 3 sets (training, test, or validation) is being processed.
    :param run_recovery: Run recovery data if applicable.
    :param model_proc: whether we are testing an ensemble or single model; this affects where results are written.
    :return: The metrics that the model achieved on the given data set, or None if the data set is empty.
    """
    if len(config.get_dataset_splits()[data_split]) == 0:
        logging.info(f"Skipping inference on empty data split {data_split}")
        return None
    if config.avoid_process_spawn_in_data_loaders and is_linux():
        logging.warning(
            "Not performing any inference because avoid_process_spawn_in_data_loaders is set "
            "and additional data loaders are likely to block.")
        return None
    with logging_section(
            f"Running {model_proc.value} model on {data_split.name.lower()} set"
    ):
        if isinstance(config, SegmentationModelBase):
            return segmentation_model_test(config, data_split, run_recovery,
                                           model_proc)
        if isinstance(config, ScalarModelBase):
            return classification_model_test(config, data_split, run_recovery)
    raise ValueError(
        f"There is no testing code for models of type {type(config)}")
Example #7
 def wait_for_runs_to_finish(self, delay: int = 60) -> None:
     """
     Wait for the sibling cross validation runs (apart from the current one) to finish.
     :param delay: How long to wait between polls to AML to get status of child runs
     """
     with logging_section("Waiting for sibling runs"):
         while not self.are_sibling_runs_finished():
             time.sleep(delay)
Example #8
 def try_compare_scores_against_baselines(self, model_proc: ModelProcessing) -> None:
     """
     Attempt comparison of scores against baseline scores and scatterplot creation if possible.
     """
     if not isinstance(self.model_config, SegmentationModelBase):  # keep type checker happy
         return
     from InnerEye.ML.baselines_util import compare_scores_against_baselines
     with logging_section("Comparing scores against baselines"):
         compare_scores_against_baselines(self.model_config, self.azure_config, model_proc)
Example #9
 def wait_for_cross_val_runs_to_finish_and_aggregate(self, delay: int = 60) -> None:
     """
     Wait for cross val runs (apart from the current one) to finish and then aggregate results of all.
     :param delay: How long to wait between polls to AML to get status of child runs
     """
     with logging_section("Waiting for sibling runs"):
         while self.wait_until_cross_val_splits_are_ready_for_aggregation():
             time.sleep(delay)
     assert PARENT_RUN_CONTEXT, "This function should only be called in a Hyperdrive run"
     self.create_ensemble_model()
Example #10
    def before_training_on_global_rank_zero(self) -> None:
        # Save the dataset files for later use in cross validation analysis
        self.config.write_dataset_files()
        if isinstance(self.config, SegmentationModelBase):
            with logging_section(
                    "Visualizing the effect of sampling random crops for training"
            ):
                visualize_random_crops_for_dataset(self.config)

        # Print out a detailed breakdown of layers, memory consumption and time.
        assert isinstance(self.model, InnerEyeLightning)
        generate_and_print_model_summary(self.config, self.model.model)
Example #11
    def create_ensemble_model(self) -> None:
        """
        Call MLRunner again after training cross-validation models, to create an ensemble model from them.
        """
        # Import only here in case of dependency issues in reduced environment
        from InnerEye.ML.utils.checkpoint_handling import CheckpointHandler
        # Adjust parameters
        self.azure_config.hyperdrive = False
        self.model_config.number_of_cross_validation_splits = 0
        self.model_config.is_train = False

        with logging_section("Downloading checkpoints from sibling runs"):
            checkpoint_handler = CheckpointHandler(
                model_config=self.model_config,
                azure_config=self.azure_config,
                project_root=self.project_root,
                run_context=PARENT_RUN_CONTEXT)
            checkpoint_handler.discover_and_download_checkpoint_from_sibling_runs(
                output_subdir_name=OTHER_RUNS_SUBDIR_NAME)

        best_epoch = self.create_ml_runner().run_inference_and_register_model(
            checkpoint_handler=checkpoint_handler,
            model_proc=ModelProcessing.ENSEMBLE_CREATION)

        crossval_dir = self.plot_cross_validation_and_upload_results()
        Runner.generate_report(self.model_config, best_epoch,
                               ModelProcessing.ENSEMBLE_CREATION)
        # CrossValResults should have been uploaded to the parent run, so we don't need it here.
        remove_file_or_directory(crossval_dir)
        # We can also remove OTHER_RUNS under the root, as it is no longer useful and only contains copies of files
        # available elsewhere. However, first we need to upload relevant parts of OTHER_RUNS/ENSEMBLE.
        other_runs_dir = self.model_config.outputs_folder / OTHER_RUNS_SUBDIR_NAME
        other_runs_ensemble_dir = other_runs_dir / ENSEMBLE_SPLIT_NAME
        if PARENT_RUN_CONTEXT is not None:
            if other_runs_ensemble_dir.exists():
                # Only keep baseline Wilcoxon results and scatterplots and reports
                for subdir in other_runs_ensemble_dir.glob("*"):
                    if subdir.name not in [
                            BASELINE_WILCOXON_RESULTS_FILE,
                            SCATTERPLOTS_SUBDIR_NAME, REPORT_HTML, REPORT_IPYNB
                    ]:
                        remove_file_or_directory(subdir)
                PARENT_RUN_CONTEXT.upload_folder(
                    name=BASELINE_COMPARISONS_FOLDER,
                    path=str(other_runs_ensemble_dir))
            else:
                logging.warning(
                    f"Directory not found for upload: {other_runs_ensemble_dir}"
                )
        remove_file_or_directory(other_runs_dir)
Example #12
def download_dataset(azure_dataset_id: str, target_folder: Path,
                     azure_config: AzureConfig) -> Path:
    """
    Downloads a dataset from Azure, or checks for an existing download on the executing machine. The dataset
    specified by azure_dataset_id is downloaded, first via blobxfer if possible, otherwise from the AzureML
    dataset attached to the given AzureML workspace. The dataset is downloaded into the `target_folder`,
    in a subfolder that has the same name as the dataset. If there already appears to be such a folder, and the
    folder contains a dataset.csv file, no download is started.
    :param azure_dataset_id: The name of a dataset that is registered in the AzureML workspace.
    :param target_folder: The folder in which to download the dataset from Azure.
    :param azure_config: All Azure-related configuration options.
    :return: A path on the local machine that contains the dataset.
    """
    workspace = azure_config.get_workspace()
    try:
        downloaded_via_blobxfer = download_dataset_via_blobxfer(
            dataset_id=azure_dataset_id,
            azure_config=azure_config,
            target_folder=target_folder)
        if downloaded_via_blobxfer:
            return downloaded_via_blobxfer
    except Exception as ex:
        print_exception(ex, message="Unable to download dataset via blobxfer.")
    logging.info("Trying to download dataset via AzureML datastore now.")
    azure_dataset = get_or_create_dataset(workspace, azure_dataset_id)
    if not isinstance(azure_dataset, FileDataset):
        raise ValueError(
            f"Expected to get a FileDataset, but got {type(azure_dataset)}")
    # The downloaded dataset may already exist from a previous run.
    expected_dataset_path = target_folder / azure_dataset_id
    expected_dataset_file = expected_dataset_path / DATASET_CSV_FILE_NAME
    logging.info(
        f"Model training will use dataset '{azure_dataset_id}' in Azure.")
    if expected_dataset_path.is_dir() and expected_dataset_file.is_file():
        logging.info(
            f"The dataset appears to be downloaded already in {expected_dataset_path}. Skipping."
        )
        return expected_dataset_path
    logging.info(
        "Starting to download the dataset - WARNING, this could take very long!"
    )
    with logging_section("Downloading dataset"):
        azure_dataset.download(target_path=str(expected_dataset_path),
                               overwrite=False)
    logging.info(
        f"Azure dataset '{azure_dataset_id}' is now available in {expected_dataset_path}"
    )
    return expected_dataset_path
Example #13
def test_download_azureml_dataset(test_output_dirs: OutputFolderForTests) -> None:
    dataset_name = "test-dataset"
    config = DummyModel()
    config.local_dataset = None
    config.azure_dataset_id = ""
    azure_config = get_default_azure_config()
    runner = MLRunner(config, azure_config=azure_config)
    # If the model has neither local_dataset nor azure_dataset_id, mount_or_download_dataset should fail.
    # This mounting call must happen before any other operations on the container, because the model creation
    # itself may already need access to the dataset.
    with pytest.raises(ValueError) as ex:
        runner.setup()
    assert ex.value.args[0] == "The model must contain either local_dataset or azure_dataset_id."
    runner.project_root = test_output_dirs.root_dir

    # Pointing the model to a dataset folder that does not exist should raise an Exception
    fake_folder = runner.project_root / "foo"
    runner.container.local_dataset = fake_folder
    with pytest.raises(FileNotFoundError):
        runner.mount_or_download_dataset(runner.container.azure_dataset_id, runner.container.local_dataset)

    # If the local dataset folder exists, mount_or_download_dataset should not do anything.
    fake_folder.mkdir()
    local_dataset = runner.mount_or_download_dataset(runner.container.azure_dataset_id, runner.container.local_dataset)
    assert local_dataset == fake_folder

    # Pointing the model to a dataset in Azure should trigger a download
    runner.container.local_dataset = None
    runner.container.azure_dataset_id = dataset_name
    with logging_section("Starting download"):
        result_path = runner.mount_or_download_dataset(runner.container.azure_dataset_id,
                                                       runner.container.local_dataset)
    # Download goes into <project_root> / "datasets" / "test-dataset"
    expected_path = runner.project_root / fixed_paths.DATASETS_DIR_NAME / dataset_name
    assert result_path == expected_path
    assert result_path.is_dir()
    dataset_csv = Path(result_path) / DATASET_CSV_FILE_NAME
    assert dataset_csv.is_file()
    # Check that each individual file in the dataset is present
    for folder in [1, *range(10, 20)]:
        sub_folder = result_path / str(folder)
        assert sub_folder.is_dir()
        for file in ["ct", "esophagus", "heart", "lung_l", "lung_r", "spinalcord"]:
            f = (sub_folder / file).with_suffix(".nii.gz")
            assert f.is_file()
Example #14
    def create_ensemble_model(self) -> None:
        """
        Create an ensemble model from the results of the sibling runs of the present run. The present run here will
        be cross validation child run 0.
        """
        assert PARENT_RUN_CONTEXT, "This function should only be called in a Hyperdrive run"
        with logging_section("Downloading checkpoints from sibling runs"):
            checkpoint_handler = CheckpointHandler(
                model_config=self.model_config,
                azure_config=self.azure_config,
                project_root=self.project_root,
                run_context=PARENT_RUN_CONTEXT)
            checkpoint_handler.download_checkpoints_from_hyperdrive_child_runs(
                PARENT_RUN_CONTEXT)

        self.run_inference_and_register_model(
            checkpoint_handler=checkpoint_handler,
            model_proc=ModelProcessing.ENSEMBLE_CREATION)

        crossval_dir = self.plot_cross_validation_and_upload_results()
        self.generate_report(ModelProcessing.ENSEMBLE_CREATION)
        # CrossValResults should have been uploaded to the parent run, so we don't need it here.
        remove_file_or_directory(crossval_dir)
        # We can also remove OTHER_RUNS under the root, as it is no longer useful and only contains copies of files
        # available elsewhere. However, first we need to upload relevant parts of OTHER_RUNS/ENSEMBLE.
        other_runs_dir = self.model_config.outputs_folder / OTHER_RUNS_SUBDIR_NAME
        other_runs_ensemble_dir = other_runs_dir / ENSEMBLE_SPLIT_NAME
        if PARENT_RUN_CONTEXT is not None:
            if other_runs_ensemble_dir.exists():
                # Only keep baseline Wilcoxon results and scatterplots and reports
                for subdir in other_runs_ensemble_dir.glob("*"):
                    if subdir.name not in [
                            BASELINE_WILCOXON_RESULTS_FILE,
                            SCATTERPLOTS_SUBDIR_NAME, REPORT_HTML, REPORT_IPYNB
                    ]:
                        remove_file_or_directory(subdir)
                PARENT_RUN_CONTEXT.upload_folder(
                    name=BASELINE_COMPARISONS_FOLDER,
                    path=str(other_runs_ensemble_dir))
            else:
                logging.warning(
                    f"Directory not found for upload: {other_runs_ensemble_dir}"
                )
        remove_file_or_directory(other_runs_dir)
Example #15
def test_download_azureml_dataset(
        test_output_dirs: OutputFolderForTests) -> None:
    dataset_name = "test-dataset"
    config = ModelConfigBase(should_validate=False)
    azure_config = get_default_azure_config()
    runner = MLRunner(config, azure_config)
    runner.project_root = test_output_dirs.root_dir

    # If the model has neither local_dataset nor azure_dataset_id, mount_or_download_dataset should fail.
    with pytest.raises(ValueError):
        runner.mount_or_download_dataset()

    # Pointing the model to a dataset folder that does not exist should raise an Exception
    fake_folder = runner.project_root / "foo"
    runner.model_config.local_dataset = fake_folder
    with pytest.raises(FileNotFoundError):
        runner.mount_or_download_dataset()

    # If the local dataset folder exists, mount_or_download_dataset should not do anything.
    fake_folder.mkdir()
    local_dataset = runner.mount_or_download_dataset()
    assert local_dataset == fake_folder

    # Pointing the model to a dataset in Azure should trigger a download
    runner.model_config.local_dataset = None
    runner.model_config.azure_dataset_id = dataset_name
    with logging_section("Starting download"):
        result_path = runner.mount_or_download_dataset()
    # Download goes into <project_root> / "datasets" / "test-dataset"
    expected_path = runner.project_root / fixed_paths.DATASETS_DIR_NAME / dataset_name
    assert result_path == expected_path
    assert result_path.is_dir()
    dataset_csv = Path(result_path) / DATASET_CSV_FILE_NAME
    assert dataset_csv.is_file()
    # Check that each individual file in the dataset is present
    for folder in [1, *range(10, 20)]:
        sub_folder = result_path / str(folder)
        assert sub_folder.is_dir()
        for file in [
                "ct", "esophagus", "heart", "lung_l", "lung_r", "spinalcord"
        ]:
            f = (sub_folder / file).with_suffix(".nii.gz")
            assert f.is_file()
Example #16
 def register_model_for_epoch(self, run_context: Run,
                              run_recovery: Optional[RunRecovery],
                              best_epoch: int, best_epoch_dice: float,
                              model_proc: ModelProcessing) -> None:
     checkpoint_paths = [self.model_config.get_path_to_checkpoint(best_epoch)] if not run_recovery \
         else run_recovery.get_checkpoint_paths(best_epoch)
     if not self.model_config.is_offline_run:
         split_index = run_context.get_tags().get(
             CROSS_VALIDATION_SPLIT_INDEX_TAG_KEY, None)
         if split_index == DEFAULT_CROSS_VALIDATION_SPLIT_INDEX:
             update_run_tags(
                 run_context, {
                     IS_ENSEMBLE_KEY_NAME:
                     model_proc == ModelProcessing.ENSEMBLE_CREATION
                 })
         elif PARENT_RUN_CONTEXT is not None:
             update_run_tags(
                 run_context,
                 {PARENT_RUN_ID_KEY_NAME: PARENT_RUN_CONTEXT.id})
     # Discard any checkpoint paths that do not exist - they will make registration fail. This can happen
     # when some child runs fail; it may still be worth registering the model.
     valid_checkpoint_paths = []
     for path in checkpoint_paths:
         if path.exists():
             valid_checkpoint_paths.append(path)
         else:
             logging.warning(
                 f"Discarding non-existent checkpoint path {path}")
     if not valid_checkpoint_paths:
         # No point continuing
         logging.warning(
             "Abandoning model registration - no valid checkpoint paths found"
         )
         return
     with logging_section(f"Registering {model_proc.value} model"):
         self.register_segmentation_model(
             run=run_context,
             best_epoch=best_epoch,
             best_epoch_dice=best_epoch_dice,
             checkpoint_paths=valid_checkpoint_paths,
             model_proc=model_proc)
Example #17
def download_dataset_via_blobxfer(dataset_id: str, azure_config: AzureConfig,
                                  target_folder: Path) -> Optional[Path]:
    """
    Attempts to download a dataset from the Azure storage account for datasets, with the download happening via
    blobxfer. This is only possible if the datasets storage account and its access key are present in the `azure_config`.
    The function returns None if the required settings were not present.
    :param dataset_id: The folder of the dataset, expected in the container given by azure_config.datasets_container.
    :param azure_config: The object with all Azure-related settings.
    :param target_folder: The local folder into which the dataset should be downloaded.
    :return: The folder that contains the downloaded dataset. Returns None if the datasets storage account name or
    account key were not present.
    """
    datasets_account_key = azure_config.get_dataset_storage_account_key()
    if not datasets_account_key:
        logging.info(
            "No account key for the dataset storage account was found.")
        logging.info(
            f"We checked in environment variables and in the file {PROJECT_SECRETS_FILE}"
        )
        return None
    if (not azure_config.datasets_container) or (
            not azure_config.datasets_storage_account):
        logging.info("Datasets storage account or container missing.")
        return None
    target_folder.mkdir(exist_ok=True)
    result_folder = target_folder / dataset_id
    # only download if the dataset hasn't already been downloaded
    if result_folder.is_dir():
        logging.info(
            f"Folder already exists, skipping download: {result_folder}")
        return result_folder
    with logging_section(f"Downloading dataset {dataset_id}"):
        download_blobs(
            account=azure_config.datasets_storage_account,
            account_key=datasets_account_key,
            # When specifying the blobs root path, ensure that there is a slash at the end, otherwise
            # all datasets with that dataset_id as a prefix get downloaded.
            blobs_root_path=f"{azure_config.datasets_container}/{dataset_id}/",
            destination=result_folder)
    return result_folder
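# Small self-contained illustration (hypothetical blob names) of why the trailing slash in
# blobs_root_path matters: without it, prefix matching also picks up sibling datasets that merely
# share the dataset_id as a prefix.
blob_names = ["datasets/dataset1/dataset.csv", "datasets/dataset10/dataset.csv"]
assert [b for b in blob_names if b.startswith("datasets/dataset1")] == blob_names  # both match
assert [b for b in blob_names if b.startswith("datasets/dataset1/")] == ["datasets/dataset1/dataset.csv"]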
Example #18
    def run(self) -> None:
        """
        Driver function to run a ML experiment. If an offline cross validation run is requested, then
        this function is recursively called for each cross validation split.
        """
        if self.is_offline_cross_val_parent_run():
            if self.model_config.is_segmentation_model:
                raise NotImplementedError("Offline cross validation is only supported for classification models.")
            self.spawn_offline_cross_val_classification_child_runs()
            return

        # Get the AzureML context in which the script is running
        if not self.model_config.is_offline_run and PARENT_RUN_CONTEXT is not None:
            logging.info("Setting tags from parent run.")
            self.set_run_tags_from_parent()

        self.save_build_info_for_dotnet_consumers()

        # Set data loader start method
        self.set_multiprocessing_start_method()

        # configure recovery container if provided
        checkpoint_handler = CheckpointHandler(model_config=self.model_config,
                                               azure_config=self.azure_config,
                                               project_root=self.project_root,
                                               run_context=RUN_CONTEXT)
        checkpoint_handler.download_recovery_checkpoints_or_weights()
        # do training and inference, unless the "only register" switch is set (which requires a run_recovery
        # to be valid).
        if not self.azure_config.only_register_model:
            # Set local_dataset to the mounted path specified in azure_runner.py, if any, or download it if that fails
            # and config.local_dataset was not already set.
            self.model_config.local_dataset = self.mount_or_download_dataset()
            # Check for existing dataset.csv file in the correct locations. Skip that if a dataset has already been
            # loaded (typically only during tests)
            if self.model_config.dataset_data_frame is None:
                assert self.model_config.local_dataset is not None
                ml_util.validate_dataset_paths(
                    self.model_config.local_dataset,
                    self.model_config.dataset_csv)

            # train a new model if required
            if self.azure_config.train:
                with logging_section("Model training"):
                    model_train(self.model_config, checkpoint_handler, num_nodes=self.azure_config.num_nodes)
            else:
                self.model_config.write_dataset_files()
                self.create_activation_maps()

            # log the number of epochs used for model training
            RUN_CONTEXT.log(name="Train epochs", value=self.model_config.num_epochs)

        # We specify the ModelProcessing as DEFAULT here even if the run_recovery points to an ensemble run, because
        # the current run is a single one. See the documentation of ModelProcessing for more details.
        self.run_inference_and_register_model(checkpoint_handler, ModelProcessing.DEFAULT)

        if self.model_config.generate_report:
            self.generate_report(ModelProcessing.DEFAULT)

        # If this is a cross validation run, and the present run is child run 0, then wait for the sibling runs,
        # build the ensemble model, and write a report for that.
        if self.model_config.number_of_cross_validation_splits > 0:
            if self.model_config.should_wait_for_other_cross_val_child_runs():
                self.wait_for_runs_to_finish()
                self.create_ensemble_model()
Example #19
    def run(self) -> None:
        """
        Driver function to run a ML experiment. If an offline cross validation run is requested, then
        this function is recursively called for each cross validation split.
        """
        if self.is_offline_cross_val_parent_run():
            if self.model_config.is_segmentation_model:
                raise NotImplementedError(
                    "Offline cross validation is only supported for classification models."
                )
            self.spawn_offline_cross_val_classification_child_runs()
            return

        # Get the AzureML context in which the script is running
        if not self.model_config.is_offline_run and PARENT_RUN_CONTEXT is not None:
            logging.info("Setting tags from parent run.")
            self.set_run_tags_from_parent()

        self.save_build_info_for_dotnet_consumers()

        # Set data loader start method
        self.set_multiprocessing_start_method()

        # configure recovery container if provided
        checkpoint_handler = CheckpointHandler(model_config=self.model_config,
                                               azure_config=self.azure_config,
                                               project_root=self.project_root,
                                               run_context=RUN_CONTEXT)
        checkpoint_handler.discover_and_download_checkpoints_from_previous_runs(
        )
        # do training and inference, unless the "only register" switch is set (which requires a run_recovery
        # to be valid).
        if not self.azure_config.register_model_only_for_epoch:
            # Set local_dataset to the mounted path specified in azure_runner.py, if any, or download it if that fails
            # and config.local_dataset was not already set.
            self.model_config.local_dataset = self.mount_or_download_dataset()
            self.model_config.write_args_file()
            logging.info(str(self.model_config))
            # Ensure that training runs are fully reproducible - setting random seeds alone is not enough!
            make_pytorch_reproducible()

            # Check for existing dataset.csv file in the correct locations. Skip that if a dataset has already been
            # loaded (typically only during tests)
            if self.model_config.dataset_data_frame is None:
                assert self.model_config.local_dataset is not None
                ml_util.validate_dataset_paths(self.model_config.local_dataset)

            # train a new model if required
            if self.azure_config.train:
                with logging_section("Model training"):
                    model_train(self.model_config, checkpoint_handler)
            else:
                self.model_config.write_dataset_files()
                self.create_activation_maps()

            # log the number of epochs used for model training
            RUN_CONTEXT.log(name="Train epochs",
                            value=self.model_config.num_epochs)

        # We specify the ModelProcessing as DEFAULT here even if the run_recovery points to an ensemble run, because
        # the current run is a single one. See the documentation of ModelProcessing for more details.
        best_epoch = self.run_inference_and_register_model(
            checkpoint_handler, ModelProcessing.DEFAULT)

        # Generate report
        if best_epoch:
            Runner.generate_report(self.model_config, best_epoch,
                                   ModelProcessing.DEFAULT)
        elif self.model_config.is_scalar_model and len(
                self.model_config.get_test_epochs()) == 1:
            # We don't register scalar models but still want to create a report if we have run inference.
            Runner.generate_report(self.model_config,
                                   self.model_config.get_test_epochs()[0],
                                   ModelProcessing.DEFAULT)
Example #20
def model_train(config: ModelConfigBase,
                checkpoint_handler: CheckpointHandler) -> ModelTrainingResults:
    """
    The main training loop. It creates the model, dataset, optimizer_type, and criterion, then proceeds
    to train the model. If a checkpoint was specified, then it loads the checkpoint before resuming training.

    :param config: The arguments which specify all required information.
    :param checkpoint_handler: Checkpoint handler object to find checkpoint paths for model initialization
    :raises TypeError: If the arguments are of the wrong type.
    :raises ValueError: When there are issues loading a previous checkpoint.
    """
    # Save the dataset files for later use in cross validation analysis
    config.write_dataset_files()

    # set the random seed for all libraries
    ml_util.set_random_seed(config.get_effective_random_seed(),
                            "Patch visualization")
    # Visualize how patches are sampled for segmentation models. This changes the random generator, but we don't
    # want training to depend on how many patients we visualized, and hence set the random seed again right after.
    with logging_section(
            "Visualizing the effect of sampling random crops for training"):
        visualize_random_crops_for_dataset(config)
    ml_util.set_random_seed(config.get_effective_random_seed(),
                            "Model training")

    logging.debug("Creating the PyTorch model.")

    # Create the train loader and validation loader to load images from the dataset
    data_loaders = config.create_data_loaders()

    # Get the path to the checkpoint to recover from
    checkpoint_path = checkpoint_handler.get_recovery_path_train()

    models_and_optimizer = ModelAndInfo(
        config=config,
        model_execution_mode=ModelExecutionMode.TRAIN,
        checkpoint_path=checkpoint_path)

    # Create the main model
    # If continuing from a previous run at a specific epoch, then load the previous model.
    model_loaded = models_and_optimizer.try_create_model_and_load_from_checkpoint(
    )
    if not model_loaded:
        raise ValueError(
            "There was no checkpoint file available for the model for given start_epoch {}"
            .format(config.start_epoch))

    # Print out a detailed breakdown of layers, memory consumption and time.
    generate_and_print_model_summary(config, models_and_optimizer.model)

    # Move model to GPU and adjust for multiple GPUs
    models_and_optimizer.adjust_model_for_gpus()

    # Create the mean teacher model and move to GPU
    if config.compute_mean_teacher_model:
        mean_teacher_model_loaded = models_and_optimizer.try_create_mean_teacher_model_load_from_checkpoint_and_adjust(
        )
        if not mean_teacher_model_loaded:
            raise ValueError(
                "There was no checkpoint file available for the mean teacher model "
                f"for given start_epoch {config.start_epoch}")

    # Create optimizer
    models_and_optimizer.create_optimizer()
    if checkpoint_handler.should_load_optimizer_checkpoint():
        optimizer_loaded = models_and_optimizer.try_load_checkpoint_for_optimizer(
        )
        if not optimizer_loaded:
            raise ValueError(
                f"There was no checkpoint file available for the optimizer for given start_epoch "
                f"{config.start_epoch}")

    # Create checkpoint directory for this run if it doesn't already exist
    logging.info(f"Models are saved at {config.checkpoint_folder}")
    if not config.checkpoint_folder.is_dir():
        config.checkpoint_folder.mkdir()

    # Create the SummaryWriters for Tensorboard
    writers = create_summary_writers(config)
    config.create_dataframe_loggers()

    # Create LR scheduler
    l_rate_scheduler = SchedulerWithWarmUp(config,
                                           models_and_optimizer.optimizer)

    # Training loop
    logging.info("Starting training")
    train_results_per_epoch, val_results_per_epoch, learning_rates_per_epoch = [], [], []

    resource_monitor = None
    if config.monitoring_interval_seconds > 0:
        # initialize and start GPU monitoring
        diagnostics_events = config.logs_folder / "diagnostics"
        logging.info(
            f"Starting resource monitor, outputting to {diagnostics_events}")
        resource_monitor = ResourceMonitor(
            interval_seconds=config.monitoring_interval_seconds,
            tensorboard_folder=diagnostics_events)
        resource_monitor.start()

    gradient_scaler = GradScaler(
    ) if config.use_gpu and config.use_mixed_precision else None
    optimal_temperature_scale_values = []
    for epoch in config.get_train_epochs():
        logging.info("Starting epoch {}".format(epoch))
        save_epoch = config.should_save_epoch(
            epoch) and models_and_optimizer.optimizer is not None

        # store the learning rates used for each epoch
        epoch_lrs = l_rate_scheduler.get_last_lr()
        learning_rates_per_epoch.append(epoch_lrs)

        train_val_params: TrainValidateParameters = \
            TrainValidateParameters(data_loader=data_loaders[ModelExecutionMode.TRAIN],
                                    model=models_and_optimizer.model,
                                    mean_teacher_model=models_and_optimizer.mean_teacher_model,
                                    epoch=epoch,
                                    optimizer=models_and_optimizer.optimizer,
                                    gradient_scaler=gradient_scaler,
                                    epoch_learning_rate=epoch_lrs,
                                    summary_writers=writers,
                                    dataframe_loggers=config.metrics_data_frame_loggers,
                                    in_training_mode=True)
        training_steps = create_model_training_steps(config, train_val_params)
        train_epoch_results = train_or_validate_epoch(training_steps)
        train_results_per_epoch.append(train_epoch_results.metrics)

        metrics.validate_and_store_model_parameters(writers.train, epoch,
                                                    models_and_optimizer.model)
        # Run without adjusting weights on the validation set
        train_val_params.in_training_mode = False
        train_val_params.data_loader = data_loaders[ModelExecutionMode.VAL]
        # if temperature scaling is enabled then do not save validation metrics for the checkpoint epochs
        # as these will be re-computed after performing temperature scaling on the validation set.
        if isinstance(config, SequenceModelBase):
            train_val_params.save_metrics = not (
                save_epoch and config.temperature_scaling_config)

        training_steps = create_model_training_steps(config, train_val_params)
        val_epoch_results = train_or_validate_epoch(training_steps)
        val_results_per_epoch.append(val_epoch_results.metrics)

        if config.is_segmentation_model:
            metrics.store_epoch_stats_for_segmentation(
                config.outputs_folder, epoch, epoch_lrs,
                train_epoch_results.metrics, val_epoch_results.metrics)

        if save_epoch:
            # perform temperature scaling if required
            if isinstance(
                    config,
                    SequenceModelBase) and config.temperature_scaling_config:
                optimal_temperature, scaled_val_results = \
                    temperature_scaling_steps(config, train_val_params, val_epoch_results)
                optimal_temperature_scale_values.append(optimal_temperature)
                # overwrite the metrics for the epoch with the metrics from the temperature scaled model
                val_results_per_epoch[-1] = scaled_val_results.metrics

            models_and_optimizer.save_checkpoint(epoch)

        # Updating the learning rate should happen at the end of the training loop, so that the
        # initial learning rate will be used for the very first epoch.
        l_rate_scheduler.step()

    model_training_results = ModelTrainingResults(
        train_results_per_epoch=train_results_per_epoch,
        val_results_per_epoch=val_results_per_epoch,
        learning_rates_per_epoch=learning_rates_per_epoch,
        optimal_temperature_scale_values_per_checkpoint_epoch=
        optimal_temperature_scale_values)

    logging.info("Finished training")

    # Since we have trained the model further, let the checkpoint_handler object know so it can handle
    # checkpoints correctly.
    checkpoint_handler.additional_training_done()

    # Upload visualization directory to AML run context to be able to see it
    # in the Azure UI.
    if config.max_batch_grad_cam > 0 and config.visualization_folder.exists():
        RUN_CONTEXT.upload_folder(name=VISUALIZATION_FOLDER,
                                  path=str(config.visualization_folder))

    writers.close_all()
    config.metrics_data_frame_loggers.close_all()
    if resource_monitor:
        # stop the resource monitoring process
        logging.info(
            "Shutting down the resource monitor process. Aggregate resource utilization:"
        )
        for name, value in resource_monitor.read_aggregate_metrics():
            logging.info(f"{name}: {value}")
            if not is_offline_run_context(RUN_CONTEXT):
                RUN_CONTEXT.log(name, value)
        resource_monitor.kill()

    return model_training_results
Example #21
def model_train(config: ModelConfigBase,
                checkpoint_handler: CheckpointHandler,
                num_nodes: int = 1) -> ModelTrainingResults:
    """
    The main training loop. It creates the Pytorch model based on the configuration options passed in,
    creates a Pytorch Lightning trainer, and trains the model.
    If a checkpoint was specified, then it loads the checkpoint before resuming training.
    :param config: The arguments which specify all required information.
    :param checkpoint_handler: Checkpoint handler object to find checkpoint paths for model initialization
    :param num_nodes: The number of nodes to use in distributed training.
    """
    # Get the path to the checkpoint to recover from
    checkpoint_path = checkpoint_handler.get_recovery_path_train()
    # This reads the dataset file, and possibly sets required pre-processing objects, like one-hot encoder
    # for categorical features, that need to be available before creating the model.
    config.read_dataset_if_needed()

    # Create the trainer object. Back up the environment variables before doing that, in case we need to run a second
    # training in the unit tests.
    old_environ = dict(os.environ)
    seed_everything(config.get_effective_random_seed())
    trainer, storing_logger = create_lightning_trainer(config, checkpoint_path, num_nodes=num_nodes)

    logging.info(f"GLOBAL_RANK: {os.getenv('GLOBAL_RANK')}, LOCAL_RANK {os.getenv('LOCAL_RANK')}. "
                 f"trainer.global_rank: {trainer.global_rank}")
    logging.debug("Creating the PyTorch model.")
    lightning_model = create_lightning_model(config)
    lightning_model.storing_logger = storing_logger

    resource_monitor = None
    # Execute some bookkeeping tasks only once if running distributed:
    if is_rank_zero():
        config.write_args_file()
        logging.info(str(config))
        # Save the dataset files for later use in cross validation analysis
        config.write_dataset_files()
        logging.info(f"Model checkpoints are saved at {config.checkpoint_folder}")

        # set the random seed for all libraries
        ml_util.set_random_seed(config.get_effective_random_seed(), "Patch visualization")
        # Visualize how patches are sampled for segmentation models. This changes the random generator, but we don't
        # want training to depend on how many patients we visualized, and hence set the random seed again right after.
        with logging_section("Visualizing the effect of sampling random crops for training"):
            visualize_random_crops_for_dataset(config)

        # Print out a detailed breakdown of layers, memory consumption and time.
        generate_and_print_model_summary(config, lightning_model.model)

        if config.monitoring_interval_seconds > 0:
            # initialize and start GPU monitoring
            diagnostics_events = config.logs_folder / "diagnostics"
            logging.info(f"Starting resource monitor, outputting to {diagnostics_events}")
            resource_monitor = ResourceMonitor(interval_seconds=config.monitoring_interval_seconds,
                                               tensorboard_folder=diagnostics_events)
            resource_monitor.start()

    # Training loop
    logging.info("Starting training")

    lightning_data = TrainingAndValidationDataLightning(config)  # type: ignore
    # When trying to store the config object in the constructor, it does not appear to get stored at all; later
    # references to the object simply fail. Hence, it has to be set explicitly here.
    lightning_data.config = config
    trainer.fit(lightning_model, datamodule=lightning_data)
    trainer.logger.close()  # type: ignore
    lightning_model.close_all_loggers()
    world_size = getattr(trainer, "world_size", 0)
    is_azureml_run = not config.is_offline_run
    # Per-subject model outputs for regression models are written per rank, and need to be aggregated here.
    # Each thread per rank will come here, and upload its files to the run outputs. Rank 0 will later download them.
    if is_azureml_run and world_size > 1 and isinstance(lightning_model, ScalarLightning):
        upload_output_file_as_temp(lightning_model.train_subject_outputs_logger.csv_path, config.outputs_folder)
        upload_output_file_as_temp(lightning_model.val_subject_outputs_logger.csv_path, config.outputs_folder)
    # DDP will start multiple instances of the runner, one for each GPU. Those should terminate here after training.
    # We can now use the global_rank of the Lightning model, rather than environment variables, because DDP has set
    # all necessary properties.
    if lightning_model.global_rank != 0:
        logging.info(f"Terminating training thread with rank {lightning_model.global_rank}.")
        sys.exit()

    logging.info("Choosing the best checkpoint and removing redundant files.")
    cleanup_checkpoint_folder(config.checkpoint_folder)
    # Lightning modifies a ton of environment variables. If we first run training and then the test suite,
    # those environment variables will mislead the training runs in the test suite, and make them crash.
    # Hence, restore the original environment after training.
    os.environ.clear()
    os.environ.update(old_environ)

    if world_size and isinstance(lightning_model, ScalarLightning):
        if is_azureml_run and world_size > 1:
            # In a DDP run on the local box, all ranks will write to local disk, hence no download needed.
            # In a multi-node DDP run, each rank has uploaded its results to AzureML, and rank 0 now downloads
            # all of them so that they can be concatenated below.
            for rank in range(world_size):
                for mode in [ModelExecutionMode.TRAIN, ModelExecutionMode.VAL]:
                    file = mode.value + "/" + get_subject_output_file_per_rank(rank)
                    RUN_CONTEXT.download_file(name=TEMP_PREFIX + file, output_file_path=config.outputs_folder / file)
        # Concatenate all temporary files per execution mode
        for mode in [ModelExecutionMode.TRAIN, ModelExecutionMode.VAL]:
            temp_files = (config.outputs_folder / mode.value).rglob(SUBJECT_OUTPUT_PER_RANK_PREFIX + "*")
            result_file = config.outputs_folder / mode.value / SUBJECT_METRICS_FILE_NAME
            all_lines = []
            for i, file in enumerate(temp_files):
                lines = file.read_text().splitlines()
                if i == 0:
                    # Keep the first file as-is, including the first line with the column headers
                    all_lines.extend(lines)
                else:
                    # For all files but the first one, cut off the header line.
                    all_lines.extend(lines[1:])
            # Write the result once; calling write_text per input file would overwrite rather than concatenate.
            result_file.write_text(os.linesep.join(all_lines))

    model_training_results = ModelTrainingResults(
        train_results_per_epoch=list(storing_logger.to_metrics_dicts(prefix_filter=TRAIN_PREFIX).values()),
        val_results_per_epoch=list(storing_logger.to_metrics_dicts(prefix_filter=VALIDATION_PREFIX).values()),
        train_diagnostics=lightning_model.train_diagnostics,
        val_diagnostics=lightning_model.val_diagnostics,
        optimal_temperature_scale_values_per_checkpoint_epoch=[]
    )

    logging.info("Finished training")

    # Since we have trained the model further, let the checkpoint_handler object know so it can handle
    # checkpoints correctly.
    checkpoint_handler.additional_training_done()

    # Upload visualization directory to AML run context to be able to see it
    # in the Azure UI.
    if config.max_batch_grad_cam > 0 and config.visualization_folder.exists():
        RUN_CONTEXT.upload_folder(name=VISUALIZATION_FOLDER, path=str(config.visualization_folder))

    if resource_monitor:
        # stop the resource monitoring process
        logging.info("Shutting down the resource monitor process. Aggregate resource utilization:")
        for name, value in resource_monitor.read_aggregate_metrics():
            logging.info(f"{name}: {value}")
            if not config.is_offline_run:
                RUN_CONTEXT.log(name, value)
        resource_monitor.kill()

    return model_training_results
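# Usage sketch, mirroring the call site in the run() driver of Example #18: all objects here are
# assumed to be fully configured elsewhere (config, azure_config, project_root).
checkpoint_handler = CheckpointHandler(model_config=config,
                                       azure_config=azure_config,
                                       project_root=project_root,
                                       run_context=RUN_CONTEXT)
checkpoint_handler.download_recovery_checkpoints_or_weights()
with logging_section("Model training"):
    training_results = model_train(config, checkpoint_handler, num_nodes=azure_config.num_nodes)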