def plot_cross_validation(config: PlotCrossValidationConfig) -> Path:
    """
    Collects results from an AzureML cross validation run, and writes aggregate metrics files.
    :param config: The settings for plotting cross validation results.
    :return: The path with all cross validation result files.
    """
    logging_to_stdout(logging.INFO)
    with logging_section("Downloading cross-validation results"):
        result_files, root_folder = download_crossval_result_files(config)
    config_and_files = OfflineCrossvalConfigAndFiles(config=config, files=result_files)
    with logging_section("Plotting cross-validation results"):
        plot_cross_validation_from_files(config_and_files, root_folder)
    return root_folder
def register_model(self,
                   checkpoint_paths: List[Path],
                   model_description: str,
                   model_proc: ModelProcessing) -> None:
    """
    Registers the model in AzureML, with the given set of checkpoints. The AzureML run's tags are updated
    with information about ensemble creation and the parent run ID.
    :param checkpoint_paths: The set of Pytorch checkpoints that should be included.
    :param model_description: A string description of the model, usually containing accuracy numbers.
    :param model_proc: The type of model that is registered (single or ensemble)
    """
    if not checkpoint_paths:
        # No point continuing, since no checkpoints were found
        logging.warning("Abandoning model registration - no valid checkpoint paths found")
        return
    if not self.model_config.is_offline_run:
        split_index = RUN_CONTEXT.get_tags().get(CROSS_VALIDATION_SPLIT_INDEX_TAG_KEY, None)
        if split_index == DEFAULT_CROSS_VALIDATION_SPLIT_INDEX:
            RUN_CONTEXT.tag(IS_ENSEMBLE_KEY_NAME, str(model_proc == ModelProcessing.ENSEMBLE_CREATION))
        elif PARENT_RUN_CONTEXT is not None:
            RUN_CONTEXT.tag(PARENT_RUN_ID_KEY_NAME, str(PARENT_RUN_CONTEXT.id))
    if isinstance(self.model_config, SegmentationModelBase):
        with logging_section(f"Registering {model_proc.value} model"):
            self.register_segmentation_model(checkpoint_paths=checkpoint_paths,
                                             model_description=model_description,
                                             model_proc=model_proc)
    else:
        logging.info(f"No deployment done for this type of model: {type(self.model_config)}")
def plot_cross_validation(config: PlotCrossValidationConfig) -> Path:
    """
    Collects results from an AzureML cross validation run, and writes aggregate metrics files.
    :param config: The settings for plotting cross validation results.
    :return: The path with all cross validation result files.
    """
    logging_to_stdout(logging.INFO)
    with logging_section("Downloading cross-validation results"):
        result_files, root_folder = download_crossval_result_files(config)
    config_and_files = OfflineCrossvalConfigAndFiles(config=config, files=result_files)
    with logging_section("Plotting cross-validation results"):
        plot_cross_validation_from_files(config_and_files, root_folder)
    return root_folder
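For context, a minimal usage sketch of the function above. The exact constructor arguments of PlotCrossValidationConfig (here run_recovery_id and epoch) are assumptions for illustration only; consult the config class for the real parameter names.

# Hypothetical usage sketch: aggregate and plot results for a completed cross validation run.
# The field names run_recovery_id and epoch are assumed and may differ in the actual config class.
crossval_config = PlotCrossValidationConfig(
    run_recovery_id="experiment_name:run_id",  # hypothetical recovery id of the parent Hyperdrive run
    epoch=100)                                 # hypothetical epoch whose results should be aggregated
root_folder = plot_cross_validation(crossval_config)
print(f"Aggregated cross validation results written to {root_folder}")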
def download_dataset(azure_dataset_id: str,
                     target_folder: Path,
                     dataset_csv: str,
                     azure_config: AzureConfig) -> Path:
    """
    Downloads or checks for an existing dataset on the executing machine. The dataset specified by the
    azure_dataset_id is downloaded from the AzureML dataset attached to the given AzureML workspace. The dataset
    is downloaded into the `target_folder`, in a subfolder that has the same name as the dataset. If there already
    appears to be such a folder, and the folder contains a dataset csv file, no download is started.
    :param azure_dataset_id: The name of a dataset that is registered in the AzureML workspace.
    :param target_folder: The folder in which to download the dataset from Azure.
    :param dataset_csv: Name of the csv file describing the dataset.
    :param azure_config: All Azure-related configuration options.
    :return: A path on the local machine that contains the dataset.
    """
    logging.info("Trying to download dataset via AzureML datastore now.")
    azure_dataset = get_or_create_dataset(azure_config, azure_dataset_id)
    if not isinstance(azure_dataset, FileDataset):
        raise ValueError(f"Expected to get a FileDataset, but got {type(azure_dataset)}")
    # The downloaded dataset may already exist from a previous run.
    expected_dataset_path = target_folder / azure_dataset_id
    expected_dataset_file = expected_dataset_path / dataset_csv
    logging.info(f"Model training will use dataset '{azure_dataset_id}' in Azure.")
    if expected_dataset_path.is_dir() and expected_dataset_file.is_file():
        logging.info(f"The dataset appears to be downloaded already in {expected_dataset_path}. Skipping.")
        return expected_dataset_path
    logging.info("Starting to download the dataset - WARNING, this could take very long!")
    with logging_section("Downloading dataset"):
        t0 = time.perf_counter()
        azure_dataset.download(target_path=str(expected_dataset_path), overwrite=False)
        t1 = time.perf_counter() - t0
        logging.info(f"Azure dataset '{azure_dataset_id}' downloaded in {t1} seconds")
    logging.info(f"Azure dataset '{azure_dataset_id}' is now available in {expected_dataset_path}")
    return expected_dataset_path
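A hedged sketch of how download_dataset might be called directly. It reuses get_default_azure_config from the tests further below; the dataset name and target folder are placeholders.

# Sketch only: download a registered AzureML dataset into a local "datasets" folder.
# "my-dataset" and the target folder are placeholders; get_default_azure_config is the test helper used below.
azure_config = get_default_azure_config()
dataset_root = download_dataset(azure_dataset_id="my-dataset",
                                target_folder=Path.cwd() / "datasets",
                                dataset_csv="dataset.csv",
                                azure_config=azure_config)
logging.info(f"Dataset is available at {dataset_root}")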
def register_model_for_epoch(self,
                             run_context: Run,
                             checkpoint_handler: CheckpointHandler,
                             best_epoch: int,
                             best_epoch_dice: float,
                             model_proc: ModelProcessing) -> None:
    checkpoint_path_and_epoch = checkpoint_handler.get_checkpoint_from_epoch(epoch=best_epoch)
    if not checkpoint_path_and_epoch or not checkpoint_path_and_epoch.checkpoint_paths:
        # No point continuing, since no checkpoints were found
        logging.warning("Abandoning model registration - no valid checkpoint paths found")
        return
    if not self.model_config.is_offline_run:
        split_index = run_context.get_tags().get(CROSS_VALIDATION_SPLIT_INDEX_TAG_KEY, None)
        if split_index == DEFAULT_CROSS_VALIDATION_SPLIT_INDEX:
            update_run_tags(run_context,
                            {IS_ENSEMBLE_KEY_NAME: model_proc == ModelProcessing.ENSEMBLE_CREATION})
        elif PARENT_RUN_CONTEXT is not None:
            update_run_tags(run_context, {PARENT_RUN_ID_KEY_NAME: PARENT_RUN_CONTEXT.id})
    with logging_section(f"Registering {model_proc.value} model"):
        self.register_segmentation_model(run=run_context,
                                         best_epoch=best_epoch,
                                         best_epoch_dice=best_epoch_dice,
                                         checkpoint_paths=checkpoint_path_and_epoch.checkpoint_paths,
                                         model_proc=model_proc)
def model_test(config: ModelConfigBase,
               data_split: ModelExecutionMode,
               run_recovery: Optional[RunRecovery] = None,
               model_proc: ModelProcessing = ModelProcessing.DEFAULT) -> Optional[InferenceMetrics]:
    """
    Runs model inference on segmentation or classification models, using a given dataset (that could be training,
    test or validation set). The inference results and metrics will be stored and logged in a way that may differ
    for model categories (classification, segmentation).
    :param config: The configuration of the model
    :param data_split: Indicates which of the 3 sets (training, test, or validation) is being processed.
    :param run_recovery: Run recovery data if applicable.
    :param model_proc: Whether we are testing an ensemble or single model; this affects where results are written.
    :return: The metrics that the model achieved on the given data set, or None if the data set is empty.
    """
    if len(config.get_dataset_splits()[data_split]) == 0:
        logging.info(f"Skipping inference on empty data split {data_split}")
        return None
    if config.avoid_process_spawn_in_data_loaders and is_linux():
        logging.warning("Not performing any inference because avoid_process_spawn_in_data_loaders is set "
                        "and additional data loaders are likely to block.")
        return None
    with logging_section(f"Running {model_proc.value} model on {data_split.name.lower()} set"):
        if isinstance(config, SegmentationModelBase):
            return segmentation_model_test(config, data_split, run_recovery, model_proc)
        if isinstance(config, ScalarModelBase):
            return classification_model_test(config, data_split, run_recovery)
    raise ValueError(f"There is no testing code for models of type {type(config)}")
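A hedged example of running model_test on the held-out test set after training; MySegmentationModel stands in for a concrete SegmentationModelBase subclass and is not a class from the repository.

# Sketch: run inference for a trained segmentation config on the test split.
# MySegmentationModel is a placeholder for a concrete SegmentationModelBase subclass.
config = MySegmentationModel()
metrics = model_test(config,
                     data_split=ModelExecutionMode.TEST,
                     model_proc=ModelProcessing.DEFAULT)
if metrics is None:
    logging.info("The test split was empty, so no metrics were produced.")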
def wait_for_runs_to_finish(self, delay: int = 60) -> None:
    """
    Wait for the cross validation sibling runs (apart from the current one) to finish.
    :param delay: How long to wait between polls to AML to get status of child runs
    """
    with logging_section("Waiting for sibling runs"):
        while not self.are_sibling_runs_finished():
            time.sleep(delay)
def try_compare_scores_against_baselines(self, model_proc: ModelProcessing) -> None:
    """
    Attempt comparison of scores against baseline scores and scatterplot creation if possible.
    :param model_proc: The type of model processing (single or ensemble).
    """
    if not isinstance(self.model_config, SegmentationModelBase):  # keep type checker happy
        return
    from InnerEye.ML.baselines_util import compare_scores_against_baselines
    with logging_section("Comparing scores against baselines"):
        compare_scores_against_baselines(self.model_config, self.azure_config, model_proc)
def wait_for_cross_val_runs_to_finish_and_aggregate(self, delay: int = 60) -> None:
    """
    Wait for cross val runs (apart from the current one) to finish and then aggregate results of all.
    :param delay: How long to wait between polls to AML to get status of child runs
    """
    with logging_section("Waiting for sibling runs"):
        while self.wait_until_cross_val_splits_are_ready_for_aggregation():
            time.sleep(delay)
    assert PARENT_RUN_CONTEXT, "This function should only be called in a Hyperdrive run"
    self.create_ensemble_model()
def before_training_on_global_rank_zero(self) -> None:
    # Save the dataset files for later use in cross validation analysis
    self.config.write_dataset_files()
    if isinstance(self.config, SegmentationModelBase):
        with logging_section("Visualizing the effect of sampling random crops for training"):
            visualize_random_crops_for_dataset(self.config)
    # Print out a detailed breakdown of layers, memory consumption and time.
    assert isinstance(self.model, InnerEyeLightning)
    generate_and_print_model_summary(self.config, self.model.model)
def create_ensemble_model(self) -> None:
    """
    Call MLRunner again after training cross-validation models, to create an ensemble model from them.
    """
    # Import only here in case of dependency issues in reduced environment
    from InnerEye.ML.utils.checkpoint_handling import CheckpointHandler
    # Adjust parameters
    self.azure_config.hyperdrive = False
    self.model_config.number_of_cross_validation_splits = 0
    self.model_config.is_train = False
    with logging_section("Downloading checkpoints from sibling runs"):
        checkpoint_handler = CheckpointHandler(model_config=self.model_config,
                                               azure_config=self.azure_config,
                                               project_root=self.project_root,
                                               run_context=PARENT_RUN_CONTEXT)
        checkpoint_handler.discover_and_download_checkpoint_from_sibling_runs(
            output_subdir_name=OTHER_RUNS_SUBDIR_NAME)
    best_epoch = self.create_ml_runner().run_inference_and_register_model(
        checkpoint_handler=checkpoint_handler, model_proc=ModelProcessing.ENSEMBLE_CREATION)
    crossval_dir = self.plot_cross_validation_and_upload_results()
    Runner.generate_report(self.model_config, best_epoch, ModelProcessing.ENSEMBLE_CREATION)
    # CrossValResults should have been uploaded to the parent run, so we don't need it here.
    remove_file_or_directory(crossval_dir)
    # We can also remove OTHER_RUNS under the root, as it is no longer useful and only contains copies of files
    # available elsewhere. However, first we need to upload relevant parts of OTHER_RUNS/ENSEMBLE.
    other_runs_dir = self.model_config.outputs_folder / OTHER_RUNS_SUBDIR_NAME
    other_runs_ensemble_dir = other_runs_dir / ENSEMBLE_SPLIT_NAME
    if PARENT_RUN_CONTEXT is not None:
        if other_runs_ensemble_dir.exists():
            # Only keep baseline Wilcoxon results and scatterplots and reports
            for subdir in other_runs_ensemble_dir.glob("*"):
                if subdir.name not in [BASELINE_WILCOXON_RESULTS_FILE,
                                       SCATTERPLOTS_SUBDIR_NAME,
                                       REPORT_HTML,
                                       REPORT_IPYNB]:
                    remove_file_or_directory(subdir)
            PARENT_RUN_CONTEXT.upload_folder(name=BASELINE_COMPARISONS_FOLDER,
                                             path=str(other_runs_ensemble_dir))
        else:
            logging.warning(f"Directory not found for upload: {other_runs_ensemble_dir}")
    remove_file_or_directory(other_runs_dir)
def download_dataset(azure_dataset_id: str,
                     target_folder: Path,
                     azure_config: AzureConfig) -> Path:
    """
    Downloads or checks for an existing dataset on the executing machine. The dataset specified by the
    azure_dataset_id is downloaded from the AzureML dataset attached to the given AzureML workspace. The dataset
    is downloaded into the `target_folder`, in a subfolder that has the same name as the dataset. If there already
    appears to be such a folder, and the folder contains a dataset.csv file, no download is started.
    :param azure_dataset_id: The name of a dataset that is registered in the AzureML workspace.
    :param target_folder: The folder in which to download the dataset from Azure.
    :param azure_config: All Azure-related configuration options.
    :return: A path on the local machine that contains the dataset.
    """
    workspace = azure_config.get_workspace()
    try:
        downloaded_via_blobxfer = download_dataset_via_blobxfer(dataset_id=azure_dataset_id,
                                                                azure_config=azure_config,
                                                                target_folder=target_folder)
        if downloaded_via_blobxfer:
            return downloaded_via_blobxfer
    except Exception as ex:
        print_exception(ex, message="Unable to download dataset via blobxfer.")
    logging.info("Trying to download dataset via AzureML datastore now.")
    azure_dataset = get_or_create_dataset(workspace, azure_dataset_id)
    if not isinstance(azure_dataset, FileDataset):
        raise ValueError(f"Expected to get a FileDataset, but got {type(azure_dataset)}")
    # The downloaded dataset may already exist from a previous run.
    expected_dataset_path = target_folder / azure_dataset_id
    expected_dataset_file = expected_dataset_path / DATASET_CSV_FILE_NAME
    logging.info(f"Model training will use dataset '{azure_dataset_id}' in Azure.")
    if expected_dataset_path.is_dir() and expected_dataset_file.is_file():
        logging.info(f"The dataset appears to be downloaded already in {expected_dataset_path}. Skipping.")
        return expected_dataset_path
    logging.info("Starting to download the dataset - WARNING, this could take very long!")
    with logging_section("Downloading dataset"):
        azure_dataset.download(target_path=str(expected_dataset_path), overwrite=False)
    logging.info(f"Azure dataset '{azure_dataset_id}' is now available in {expected_dataset_path}")
    return expected_dataset_path
def test_download_azureml_dataset(test_output_dirs: OutputFolderForTests) -> None:
    dataset_name = "test-dataset"
    config = DummyModel()
    config.local_dataset = None
    config.azure_dataset_id = ""
    azure_config = get_default_azure_config()
    runner = MLRunner(config, azure_config=azure_config)
    # If the model has neither local_dataset nor azure_dataset_id, mount_or_download_dataset should fail.
    # This mounting call must happen before any other operations on the container, because already the model
    # creation may need access to the dataset.
    with pytest.raises(ValueError) as ex:
        runner.setup()
    assert ex.value.args[0] == "The model must contain either local_dataset or azure_dataset_id."
    runner.project_root = test_output_dirs.root_dir
    # Pointing the model to a dataset folder that does not exist should raise an Exception
    fake_folder = runner.project_root / "foo"
    runner.container.local_dataset = fake_folder
    with pytest.raises(FileNotFoundError):
        runner.mount_or_download_dataset(runner.container.azure_dataset_id, runner.container.local_dataset)
    # If the local dataset folder exists, mount_or_download_dataset should not do anything.
    fake_folder.mkdir()
    local_dataset = runner.mount_or_download_dataset(runner.container.azure_dataset_id,
                                                     runner.container.local_dataset)
    assert local_dataset == fake_folder
    # Pointing the model to a dataset in Azure should trigger a download
    runner.container.local_dataset = None
    runner.container.azure_dataset_id = dataset_name
    with logging_section("Starting download"):
        result_path = runner.mount_or_download_dataset(runner.container.azure_dataset_id,
                                                       runner.container.local_dataset)
    # Download goes into <project_root> / "datasets" / "test_dataset"
    expected_path = runner.project_root / fixed_paths.DATASETS_DIR_NAME / dataset_name
    assert result_path == expected_path
    assert result_path.is_dir()
    dataset_csv = Path(result_path) / DATASET_CSV_FILE_NAME
    assert dataset_csv.is_file()
    # Check that each individual file in the dataset is present
    for folder in [1, *range(10, 20)]:
        sub_folder = result_path / str(folder)
        assert sub_folder.is_dir()
        for file in ["ct", "esophagus", "heart", "lung_l", "lung_r", "spinalcord"]:
            f = (sub_folder / file).with_suffix(".nii.gz")
            assert f.is_file()
def create_ensemble_model(self) -> None:
    """
    Create an ensemble model from the results of the sibling runs of the present run. The present run here
    will be cross validation child run 0.
    """
    assert PARENT_RUN_CONTEXT, "This function should only be called in a Hyperdrive run"
    with logging_section("Downloading checkpoints from sibling runs"):
        checkpoint_handler = CheckpointHandler(model_config=self.model_config,
                                               azure_config=self.azure_config,
                                               project_root=self.project_root,
                                               run_context=PARENT_RUN_CONTEXT)
        checkpoint_handler.download_checkpoints_from_hyperdrive_child_runs(PARENT_RUN_CONTEXT)
    self.run_inference_and_register_model(checkpoint_handler=checkpoint_handler,
                                          model_proc=ModelProcessing.ENSEMBLE_CREATION)
    crossval_dir = self.plot_cross_validation_and_upload_results()
    self.generate_report(ModelProcessing.ENSEMBLE_CREATION)
    # CrossValResults should have been uploaded to the parent run, so we don't need it here.
    remove_file_or_directory(crossval_dir)
    # We can also remove OTHER_RUNS under the root, as it is no longer useful and only contains copies of files
    # available elsewhere. However, first we need to upload relevant parts of OTHER_RUNS/ENSEMBLE.
    other_runs_dir = self.model_config.outputs_folder / OTHER_RUNS_SUBDIR_NAME
    other_runs_ensemble_dir = other_runs_dir / ENSEMBLE_SPLIT_NAME
    if PARENT_RUN_CONTEXT is not None:
        if other_runs_ensemble_dir.exists():
            # Only keep baseline Wilcoxon results and scatterplots and reports
            for subdir in other_runs_ensemble_dir.glob("*"):
                if subdir.name not in [BASELINE_WILCOXON_RESULTS_FILE,
                                       SCATTERPLOTS_SUBDIR_NAME,
                                       REPORT_HTML,
                                       REPORT_IPYNB]:
                    remove_file_or_directory(subdir)
            PARENT_RUN_CONTEXT.upload_folder(name=BASELINE_COMPARISONS_FOLDER,
                                             path=str(other_runs_ensemble_dir))
        else:
            logging.warning(f"Directory not found for upload: {other_runs_ensemble_dir}")
    remove_file_or_directory(other_runs_dir)
def test_download_azureml_dataset(test_output_dirs: OutputFolderForTests) -> None:
    dataset_name = "test-dataset"
    config = ModelConfigBase(should_validate=False)
    azure_config = get_default_azure_config()
    runner = MLRunner(config, azure_config)
    runner.project_root = test_output_dirs.root_dir
    # If the model has neither local_dataset nor azure_dataset_id, mount_or_download_dataset should fail.
    with pytest.raises(ValueError):
        runner.mount_or_download_dataset()
    # Pointing the model to a dataset folder that does not exist should raise an Exception
    fake_folder = runner.project_root / "foo"
    runner.model_config.local_dataset = fake_folder
    with pytest.raises(FileNotFoundError):
        runner.mount_or_download_dataset()
    # If the local dataset folder exists, mount_or_download_dataset should not do anything.
    fake_folder.mkdir()
    local_dataset = runner.mount_or_download_dataset()
    assert local_dataset == fake_folder
    # Pointing the model to a dataset in Azure should trigger a download
    runner.model_config.local_dataset = None
    runner.model_config.azure_dataset_id = dataset_name
    with logging_section("Starting download"):
        result_path = runner.mount_or_download_dataset()
    # Download goes into <project_root> / "datasets" / "test_dataset"
    expected_path = runner.project_root / fixed_paths.DATASETS_DIR_NAME / dataset_name
    assert result_path == expected_path
    assert result_path.is_dir()
    dataset_csv = Path(result_path) / DATASET_CSV_FILE_NAME
    assert dataset_csv.is_file()
    # Check that each individual file in the dataset is present
    for folder in [1, *range(10, 20)]:
        sub_folder = result_path / str(folder)
        assert sub_folder.is_dir()
        for file in ["ct", "esophagus", "heart", "lung_l", "lung_r", "spinalcord"]:
            f = (sub_folder / file).with_suffix(".nii.gz")
            assert f.is_file()
def register_model_for_epoch(self,
                             run_context: Run,
                             run_recovery: Optional[RunRecovery],
                             best_epoch: int,
                             best_epoch_dice: float,
                             model_proc: ModelProcessing) -> None:
    checkpoint_paths = [self.model_config.get_path_to_checkpoint(best_epoch)] if not run_recovery \
        else run_recovery.get_checkpoint_paths(best_epoch)
    if not self.model_config.is_offline_run:
        split_index = run_context.get_tags().get(CROSS_VALIDATION_SPLIT_INDEX_TAG_KEY, None)
        if split_index == DEFAULT_CROSS_VALIDATION_SPLIT_INDEX:
            update_run_tags(run_context,
                            {IS_ENSEMBLE_KEY_NAME: model_proc == ModelProcessing.ENSEMBLE_CREATION})
        elif PARENT_RUN_CONTEXT is not None:
            update_run_tags(run_context, {PARENT_RUN_ID_KEY_NAME: PARENT_RUN_CONTEXT.id})
    # Discard any checkpoint paths that do not exist - they will make registration fail. This can happen
    # when some child runs fail; it may still be worth registering the model.
    valid_checkpoint_paths = []
    for path in checkpoint_paths:
        if path.exists():
            valid_checkpoint_paths.append(path)
        else:
            logging.warning(f"Discarding non-existent checkpoint path {path}")
    if not valid_checkpoint_paths:
        # No point continuing
        logging.warning("Abandoning model registration - no valid checkpoint paths found")
        return
    with logging_section(f"Registering {model_proc.value} model"):
        self.register_segmentation_model(run=run_context,
                                         best_epoch=best_epoch,
                                         best_epoch_dice=best_epoch_dice,
                                         checkpoint_paths=valid_checkpoint_paths,
                                         model_proc=model_proc)
def download_dataset_via_blobxfer(dataset_id: str,
                                  azure_config: AzureConfig,
                                  target_folder: Path) -> Optional[Path]:
    """
    Attempts to download a dataset from the Azure storage account for datasets, with the download happening via
    blobxfer. This is only possible if the datasets storage account and account key are present in the
    `azure_config`. The function returns None if the required settings were not present.
    :param dataset_id: The folder of the dataset, expected in the container given by azure_config.datasets_container.
    :param azure_config: The object with all Azure-related settings.
    :param target_folder: The local folder into which the dataset should be downloaded.
    :return: The folder that contains the downloaded dataset. Returns None if the datasets account name or account
        key were not present.
    """
    datasets_account_key = azure_config.get_dataset_storage_account_key()
    if not datasets_account_key:
        logging.info("No account key for the dataset storage account was found.")
        logging.info(f"We checked in environment variables and in the file {PROJECT_SECRETS_FILE}")
        return None
    if (not azure_config.datasets_container) or (not azure_config.datasets_storage_account):
        logging.info("Datasets storage account or container missing.")
        return None
    target_folder.mkdir(exist_ok=True)
    result_folder = target_folder / dataset_id
    # Only download if the dataset has not already been downloaded.
    if result_folder.is_dir():
        logging.info(f"Folder already exists, skipping download: {result_folder}")
        return result_folder
    with logging_section(f"Downloading dataset {dataset_id}"):
        download_blobs(
            account=azure_config.datasets_storage_account,
            account_key=datasets_account_key,
            # When specifying the blobs root path, ensure that there is a slash at the end, otherwise
            # all datasets with that dataset_id as a prefix get downloaded.
            blobs_root_path=f"{azure_config.datasets_container}/{dataset_id}/",
            destination=result_folder)
    return result_folder
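A minimal sketch of calling the blobxfer path directly and falling back if the storage settings are missing; the dataset name is a placeholder and get_default_azure_config is the test helper used elsewhere in this collection.

# Sketch: try the blobxfer download first, then fall back to the AzureML datastore download.
azure_config = get_default_azure_config()
downloaded = download_dataset_via_blobxfer(dataset_id="my-dataset",
                                           azure_config=azure_config,
                                           target_folder=Path.cwd() / "datasets")
if downloaded is None:
    logging.info("Blobxfer settings not available, falling back to the AzureML datastore download.")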
def run(self) -> None:
    """
    Driver function to run a ML experiment. If an offline cross validation run is requested, then
    this function is recursively called for each cross validation split.
    """
    if self.is_offline_cross_val_parent_run():
        if self.model_config.is_segmentation_model:
            raise NotImplementedError("Offline cross validation is only supported for classification models.")
        self.spawn_offline_cross_val_classification_child_runs()
        return
    # Get the AzureML context in which the script is running
    if not self.model_config.is_offline_run and PARENT_RUN_CONTEXT is not None:
        logging.info("Setting tags from parent run.")
        self.set_run_tags_from_parent()
    self.save_build_info_for_dotnet_consumers()
    # Set data loader start method
    self.set_multiprocessing_start_method()
    # configure recovery container if provided
    checkpoint_handler = CheckpointHandler(model_config=self.model_config,
                                           azure_config=self.azure_config,
                                           project_root=self.project_root,
                                           run_context=RUN_CONTEXT)
    checkpoint_handler.download_recovery_checkpoints_or_weights()
    # do training and inference, unless the "only register" switch is set (which requires a run_recovery
    # to be valid).
    if not self.azure_config.only_register_model:
        # Set local_dataset to the mounted path specified in azure_runner.py, if any, or download it if that fails
        # and config.local_dataset was not already set.
        self.model_config.local_dataset = self.mount_or_download_dataset()
        # Check for existing dataset.csv file in the correct locations. Skip that if a dataset has already been
        # loaded (typically only during tests)
        if self.model_config.dataset_data_frame is None:
            assert self.model_config.local_dataset is not None
            ml_util.validate_dataset_paths(self.model_config.local_dataset, self.model_config.dataset_csv)
        # train a new model if required
        if self.azure_config.train:
            with logging_section("Model training"):
                model_train(self.model_config, checkpoint_handler, num_nodes=self.azure_config.num_nodes)
        else:
            self.model_config.write_dataset_files()
            self.create_activation_maps()
        # log the number of epochs used for model training
        RUN_CONTEXT.log(name="Train epochs", value=self.model_config.num_epochs)
    # We specify the ModelProcessing as DEFAULT here even if the run_recovery points to an ensemble run, because
    # the current run is a single one. See the documentation of ModelProcessing for more details.
    self.run_inference_and_register_model(checkpoint_handler, ModelProcessing.DEFAULT)
    if self.model_config.generate_report:
        self.generate_report(ModelProcessing.DEFAULT)
    # If this is a cross validation run, and the present run is child run 0, then wait for the sibling runs,
    # build the ensemble model, and write a report for that.
    if self.model_config.number_of_cross_validation_splits > 0:
        if self.model_config.should_wait_for_other_cross_val_child_runs():
            self.wait_for_runs_to_finish()
            self.create_ensemble_model()
def run(self) -> None:
    """
    Driver function to run a ML experiment. If an offline cross validation run is requested, then
    this function is recursively called for each cross validation split.
    """
    if self.is_offline_cross_val_parent_run():
        if self.model_config.is_segmentation_model:
            raise NotImplementedError("Offline cross validation is only supported for classification models.")
        self.spawn_offline_cross_val_classification_child_runs()
        return
    # Get the AzureML context in which the script is running
    if not self.model_config.is_offline_run and PARENT_RUN_CONTEXT is not None:
        logging.info("Setting tags from parent run.")
        self.set_run_tags_from_parent()
    self.save_build_info_for_dotnet_consumers()
    # Set data loader start method
    self.set_multiprocessing_start_method()
    # configure recovery container if provided
    checkpoint_handler = CheckpointHandler(model_config=self.model_config,
                                           azure_config=self.azure_config,
                                           project_root=self.project_root,
                                           run_context=RUN_CONTEXT)
    checkpoint_handler.discover_and_download_checkpoints_from_previous_runs()
    # do training and inference, unless the "only register" switch is set (which requires a run_recovery
    # to be valid).
    if not self.azure_config.register_model_only_for_epoch:
        # Set local_dataset to the mounted path specified in azure_runner.py, if any, or download it if that fails
        # and config.local_dataset was not already set.
        self.model_config.local_dataset = self.mount_or_download_dataset()
        self.model_config.write_args_file()
        logging.info(str(self.model_config))
        # Ensure that training runs are fully reproducible - setting random seeds alone is not enough!
        make_pytorch_reproducible()
        # Check for existing dataset.csv file in the correct locations. Skip that if a dataset has already been
        # loaded (typically only during tests)
        if self.model_config.dataset_data_frame is None:
            assert self.model_config.local_dataset is not None
            ml_util.validate_dataset_paths(self.model_config.local_dataset)
        # train a new model if required
        if self.azure_config.train:
            with logging_section("Model training"):
                model_train(self.model_config, checkpoint_handler)
        else:
            self.model_config.write_dataset_files()
            self.create_activation_maps()
        # log the number of epochs used for model training
        RUN_CONTEXT.log(name="Train epochs", value=self.model_config.num_epochs)
    # We specify the ModelProcessing as DEFAULT here even if the run_recovery points to an ensemble run, because
    # the current run is a single one. See the documentation of ModelProcessing for more details.
    best_epoch = self.run_inference_and_register_model(checkpoint_handler, ModelProcessing.DEFAULT)
    # Generate report
    if best_epoch:
        Runner.generate_report(self.model_config, best_epoch, ModelProcessing.DEFAULT)
    elif self.model_config.is_scalar_model and len(self.model_config.get_test_epochs()) == 1:
        # We don't register scalar models but still want to create a report if we have run inference.
        Runner.generate_report(self.model_config, self.model_config.get_test_epochs()[0],
                               ModelProcessing.DEFAULT)
def model_train(config: ModelConfigBase,
                checkpoint_handler: CheckpointHandler) -> ModelTrainingResults:
    """
    The main training loop. It creates the model, dataset, optimizer_type, and criterion, then proceeds
    to train the model. If a checkpoint was specified, then it loads the checkpoint before resuming training.
    :param config: The arguments which specify all required information.
    :param checkpoint_handler: Checkpoint handler object to find checkpoint paths for model initialization
    :raises TypeError: If the arguments are of the wrong type.
    :raises ValueError: When there are issues loading a previous checkpoint.
    """
    # Save the dataset files for later use in cross validation analysis
    config.write_dataset_files()
    # set the random seed for all libraries
    ml_util.set_random_seed(config.get_effective_random_seed(), "Patch visualization")
    # Visualize how patches are sampled for segmentation models. This changes the random generator, but we don't
    # want training to depend on how many patients we visualized, and hence set the random seed again right after.
    with logging_section("Visualizing the effect of sampling random crops for training"):
        visualize_random_crops_for_dataset(config)
    ml_util.set_random_seed(config.get_effective_random_seed(), "Model training")
    logging.debug("Creating the PyTorch model.")
    # Create the train loader and validation loader to load images from the dataset
    data_loaders = config.create_data_loaders()
    # Get the path to the checkpoint to recover from
    checkpoint_path = checkpoint_handler.get_recovery_path_train()
    models_and_optimizer = ModelAndInfo(config=config,
                                        model_execution_mode=ModelExecutionMode.TRAIN,
                                        checkpoint_path=checkpoint_path)
    # Create the main model
    # If continuing from a previous run at a specific epoch, then load the previous model.
    model_loaded = models_and_optimizer.try_create_model_and_load_from_checkpoint()
    if not model_loaded:
        raise ValueError("There was no checkpoint file available for the model for given start_epoch {}"
                         .format(config.start_epoch))
    # Print out a detailed breakdown of layers, memory consumption and time.
    generate_and_print_model_summary(config, models_and_optimizer.model)
    # Move model to GPU and adjust for multiple GPUs
    models_and_optimizer.adjust_model_for_gpus()
    # Create the mean teacher model and move to GPU
    if config.compute_mean_teacher_model:
        mean_teacher_model_loaded = models_and_optimizer.try_create_mean_teacher_model_load_from_checkpoint_and_adjust()
        if not mean_teacher_model_loaded:
            raise ValueError("There was no checkpoint file available for the mean teacher model "
                             f"for given start_epoch {config.start_epoch}")
    # Create optimizer
    models_and_optimizer.create_optimizer()
    if checkpoint_handler.should_load_optimizer_checkpoint():
        optimizer_loaded = models_and_optimizer.try_load_checkpoint_for_optimizer()
        if not optimizer_loaded:
            raise ValueError(f"There was no checkpoint file available for the optimizer for given start_epoch "
                             f"{config.start_epoch}")
    # Create checkpoint directory for this run if it doesn't already exist
    logging.info(f"Models are saved at {config.checkpoint_folder}")
    if not config.checkpoint_folder.is_dir():
        config.checkpoint_folder.mkdir()
    # Create the SummaryWriters for Tensorboard
    writers = create_summary_writers(config)
    config.create_dataframe_loggers()
    # Create LR scheduler
    l_rate_scheduler = SchedulerWithWarmUp(config, models_and_optimizer.optimizer)
    # Training loop
    logging.info("Starting training")
    train_results_per_epoch, val_results_per_epoch, learning_rates_per_epoch = [], [], []
    resource_monitor = None
    if config.monitoring_interval_seconds > 0:
        # initialize and start GPU monitoring
        diagnostics_events = config.logs_folder / "diagnostics"
        logging.info(f"Starting resource monitor, outputting to {diagnostics_events}")
        resource_monitor = ResourceMonitor(interval_seconds=config.monitoring_interval_seconds,
                                           tensorboard_folder=diagnostics_events)
        resource_monitor.start()
    gradient_scaler = GradScaler() if config.use_gpu and config.use_mixed_precision else None
    optimal_temperature_scale_values = []
    for epoch in config.get_train_epochs():
        logging.info("Starting epoch {}".format(epoch))
        save_epoch = config.should_save_epoch(epoch) and models_and_optimizer.optimizer is not None
        # store the learning rates used for each epoch
        epoch_lrs = l_rate_scheduler.get_last_lr()
        learning_rates_per_epoch.append(epoch_lrs)
        train_val_params: TrainValidateParameters = \
            TrainValidateParameters(data_loader=data_loaders[ModelExecutionMode.TRAIN],
                                    model=models_and_optimizer.model,
                                    mean_teacher_model=models_and_optimizer.mean_teacher_model,
                                    epoch=epoch,
                                    optimizer=models_and_optimizer.optimizer,
                                    gradient_scaler=gradient_scaler,
                                    epoch_learning_rate=epoch_lrs,
                                    summary_writers=writers,
                                    dataframe_loggers=config.metrics_data_frame_loggers,
                                    in_training_mode=True)
        training_steps = create_model_training_steps(config, train_val_params)
        train_epoch_results = train_or_validate_epoch(training_steps)
        train_results_per_epoch.append(train_epoch_results.metrics)
        metrics.validate_and_store_model_parameters(writers.train, epoch, models_and_optimizer.model)
        # Run without adjusting weights on the validation set
        train_val_params.in_training_mode = False
        train_val_params.data_loader = data_loaders[ModelExecutionMode.VAL]
        # if temperature scaling is enabled then do not save validation metrics for the checkpoint epochs
        # as these will be re-computed after performing temperature scaling on the validation set.
        if isinstance(config, SequenceModelBase):
            train_val_params.save_metrics = not (save_epoch and config.temperature_scaling_config)
        training_steps = create_model_training_steps(config, train_val_params)
        val_epoch_results = train_or_validate_epoch(training_steps)
        val_results_per_epoch.append(val_epoch_results.metrics)
        if config.is_segmentation_model:
            metrics.store_epoch_stats_for_segmentation(config.outputs_folder,
                                                       epoch,
                                                       epoch_lrs,
                                                       train_epoch_results.metrics,
                                                       val_epoch_results.metrics)
        if save_epoch:
            # perform temperature scaling if required
            if isinstance(config, SequenceModelBase) and config.temperature_scaling_config:
                optimal_temperature, scaled_val_results = \
                    temperature_scaling_steps(config, train_val_params, val_epoch_results)
                optimal_temperature_scale_values.append(optimal_temperature)
                # overwrite the metrics for the epoch with the metrics from the temperature scaled model
                val_results_per_epoch[-1] = scaled_val_results.metrics
            models_and_optimizer.save_checkpoint(epoch)
        # Updating the learning rate should happen at the end of the training loop, so that the
        # initial learning rate will be used for the very first epoch.
        l_rate_scheduler.step()
    model_training_results = ModelTrainingResults(
        train_results_per_epoch=train_results_per_epoch,
        val_results_per_epoch=val_results_per_epoch,
        learning_rates_per_epoch=learning_rates_per_epoch,
        optimal_temperature_scale_values_per_checkpoint_epoch=optimal_temperature_scale_values)
    logging.info("Finished training")
    # Since we have trained the model further, let the checkpoint_handler object know so it can handle
    # checkpoints correctly.
    checkpoint_handler.additional_training_done()
    # Upload visualization directory to AML run context to be able to see it
    # in the Azure UI.
    if config.max_batch_grad_cam > 0 and config.visualization_folder.exists():
        RUN_CONTEXT.upload_folder(name=VISUALIZATION_FOLDER, path=str(config.visualization_folder))
    writers.close_all()
    config.metrics_data_frame_loggers.close_all()
    if resource_monitor:
        # stop the resource monitoring process
        logging.info("Shutting down the resource monitor process. Aggregate resource utilization:")
        for name, value in resource_monitor.read_aggregate_metrics():
            logging.info(f"{name}: {value}")
            if not is_offline_run_context(RUN_CONTEXT):
                RUN_CONTEXT.log(name, value)
        resource_monitor.kill()
    return model_training_results
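A hedged sketch of driving this training loop outside of the MLRunner driver. MyModelConfig is a placeholder config class; the CheckpointHandler construction mirrors the runner code elsewhere in this collection.

# Sketch only: train a model directly. MyModelConfig is a placeholder for a concrete ModelConfigBase subclass.
config = MyModelConfig()
azure_config = get_default_azure_config()
checkpoint_handler = CheckpointHandler(model_config=config,
                                       azure_config=azure_config,
                                       project_root=Path.cwd(),
                                       run_context=RUN_CONTEXT)
training_results = model_train(config, checkpoint_handler)
logging.info(f"Validation metrics of the final epoch: {training_results.val_results_per_epoch[-1]}")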
def model_train(config: ModelConfigBase,
                checkpoint_handler: CheckpointHandler,
                num_nodes: int = 1) -> ModelTrainingResults:
    """
    The main training loop. It creates the Pytorch model based on the configuration options passed in, creates a
    Pytorch Lightning trainer, and trains the model. If a checkpoint was specified, then it loads the checkpoint
    before resuming training.
    :param config: The arguments which specify all required information.
    :param checkpoint_handler: Checkpoint handler object to find checkpoint paths for model initialization
    :param num_nodes: The number of nodes to use in distributed training.
    """
    # Get the path to the checkpoint to recover from
    checkpoint_path = checkpoint_handler.get_recovery_path_train()
    # This reads the dataset file, and possibly sets required pre-processing objects, like one-hot encoder
    # for categorical features, that need to be available before creating the model.
    config.read_dataset_if_needed()
    # Create the trainer object. Backup the environment variables before doing that, in case we need to run a
    # second training in the unit tests.
    old_environ = dict(os.environ)
    seed_everything(config.get_effective_random_seed())
    trainer, storing_logger = create_lightning_trainer(config, checkpoint_path, num_nodes=num_nodes)
    logging.info(f"GLOBAL_RANK: {os.getenv('GLOBAL_RANK')}, LOCAL_RANK {os.getenv('LOCAL_RANK')}. "
                 f"trainer.global_rank: {trainer.global_rank}")
    logging.debug("Creating the PyTorch model.")
    lightning_model = create_lightning_model(config)
    lightning_model.storing_logger = storing_logger
    resource_monitor = None
    # Execute some bookkeeping tasks only once if running distributed:
    if is_rank_zero():
        config.write_args_file()
        logging.info(str(config))
        # Save the dataset files for later use in cross validation analysis
        config.write_dataset_files()
        logging.info(f"Model checkpoints are saved at {config.checkpoint_folder}")
        # set the random seed for all libraries
        ml_util.set_random_seed(config.get_effective_random_seed(), "Patch visualization")
        # Visualize how patches are sampled for segmentation models. This changes the random generator, but we
        # don't want training to depend on how many patients we visualized, and hence set the random seed again
        # right after.
        with logging_section("Visualizing the effect of sampling random crops for training"):
            visualize_random_crops_for_dataset(config)
        # Print out a detailed breakdown of layers, memory consumption and time.
        generate_and_print_model_summary(config, lightning_model.model)
        if config.monitoring_interval_seconds > 0:
            # initialize and start GPU monitoring
            diagnostics_events = config.logs_folder / "diagnostics"
            logging.info(f"Starting resource monitor, outputting to {diagnostics_events}")
            resource_monitor = ResourceMonitor(interval_seconds=config.monitoring_interval_seconds,
                                               tensorboard_folder=diagnostics_events)
            resource_monitor.start()
    # Training loop
    logging.info("Starting training")
    lightning_data = TrainingAndValidationDataLightning(config)  # type: ignore
    # When trying to store the config object in the constructor, it does not appear to get stored at all, and
    # later references to the object simply fail. Hence, have to set it explicitly here.
    lightning_data.config = config
    trainer.fit(lightning_model, datamodule=lightning_data)
    trainer.logger.close()  # type: ignore
    lightning_model.close_all_loggers()
    world_size = getattr(trainer, "world_size", 0)
    is_azureml_run = not config.is_offline_run
    # Per-subject model outputs for regression models are written per rank, and need to be aggregated here.
    # Each thread per rank will come here, and upload its files to the run outputs. Rank 0 will later download them.
    if is_azureml_run and world_size > 1 and isinstance(lightning_model, ScalarLightning):
        upload_output_file_as_temp(lightning_model.train_subject_outputs_logger.csv_path, config.outputs_folder)
        upload_output_file_as_temp(lightning_model.val_subject_outputs_logger.csv_path, config.outputs_folder)
    # DDP will start multiple instances of the runner, one for each GPU. Those should terminate here after training.
    # We can now use the global_rank of the Lightning model, rather than environment variables, because DDP has set
    # all necessary properties.
    if lightning_model.global_rank != 0:
        logging.info(f"Terminating training thread with rank {lightning_model.global_rank}.")
        sys.exit()
    logging.info("Choosing the best checkpoint and removing redundant files.")
    cleanup_checkpoint_folder(config.checkpoint_folder)
    # Lightning modifies a ton of environment variables. If we first run training and then the test suite,
    # those environment variables will mislead the training runs in the test suite, and make them crash.
    # Hence, restore the original environment after training.
    os.environ.clear()
    os.environ.update(old_environ)
    if world_size and isinstance(lightning_model, ScalarLightning):
        if is_azureml_run and world_size > 1:
            # In a DDP run on the local box, all ranks will write to local disk, hence no download needed.
            # In a multi-node DDP, each rank would upload to AzureML, and rank 0 will now download all results and
            # concatenate
            for rank in range(world_size):
                for mode in [ModelExecutionMode.TRAIN, ModelExecutionMode.VAL]:
                    file = mode.value + "/" + get_subject_output_file_per_rank(rank)
                    RUN_CONTEXT.download_file(name=TEMP_PREFIX + file,
                                              output_file_path=config.outputs_folder / file)
        # Concatenate all temporary files per execution mode
        for mode in [ModelExecutionMode.TRAIN, ModelExecutionMode.VAL]:
            temp_files = (config.outputs_folder / mode.value).rglob(SUBJECT_OUTPUT_PER_RANK_PREFIX + "*")
            result_file = config.outputs_folder / mode.value / SUBJECT_METRICS_FILE_NAME
            # Write the result file once, appending each per-rank file in turn, so that the contents written for
            # earlier ranks are not overwritten by later ones.
            with result_file.open("w") as out_file:
                for i, file in enumerate(temp_files):
                    temp_file_contents = file.read_text()
                    if i == 0:
                        # Copy the first file as-is, including the first line with the column headers
                        out_file.write(temp_file_contents)
                    else:
                        # For all files but the first one, cut off the header line.
                        out_file.write(os.linesep.join(temp_file_contents.splitlines()[1:]))
    model_training_results = ModelTrainingResults(
        train_results_per_epoch=list(storing_logger.to_metrics_dicts(prefix_filter=TRAIN_PREFIX).values()),
        val_results_per_epoch=list(storing_logger.to_metrics_dicts(prefix_filter=VALIDATION_PREFIX).values()),
        train_diagnostics=lightning_model.train_diagnostics,
        val_diagnostics=lightning_model.val_diagnostics,
        optimal_temperature_scale_values_per_checkpoint_epoch=[])
    logging.info("Finished training")
    # Since we have trained the model further, let the checkpoint_handler object know so it can handle
    # checkpoints correctly.
    checkpoint_handler.additional_training_done()
    # Upload visualization directory to AML run context to be able to see it
    # in the Azure UI.
    if config.max_batch_grad_cam > 0 and config.visualization_folder.exists():
        RUN_CONTEXT.upload_folder(name=VISUALIZATION_FOLDER, path=str(config.visualization_folder))
    if resource_monitor:
        # stop the resource monitoring process
        logging.info("Shutting down the resource monitor process. Aggregate resource utilization:")
        for name, value in resource_monitor.read_aggregate_metrics():
            logging.info(f"{name}: {value}")
            if not config.is_offline_run:
                RUN_CONTEXT.log(name, value)
        resource_monitor.kill()
    return model_training_results
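The per-rank CSV handling above can be read as a small standalone helper: keep the header row of the first file and strip it from all later ones. This is an illustrative restatement of that logic, not a function that exists in the repository.

from pathlib import Path
from typing import List


def concatenate_per_rank_csvs(per_rank_files: List[Path], result_file: Path) -> None:
    """Illustrative helper: merge per-rank CSV files, keeping only the first file's header row."""
    with result_file.open("w") as out_file:
        for i, file in enumerate(per_rank_files):
            lines = file.read_text().splitlines()
            # The first file contributes its header and data; later files contribute data rows only.
            out_file.write("\n".join(lines if i == 0 else lines[1:]) + "\n")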