def spawn_offline_cross_val_classification_child_runs(self) -> None:
    """
    Sequentially trains and tests one model per cross validation split.
    Results on the Validation set are written into the outputs directory of the parent run,
    and an aggregate cross-validation plot is produced at the end.
    """
    config = self.model_config
    assert isinstance(config, ScalarModelBase)
    parent_file_system = config.file_system_config

    def _run_single_split(split_index: int) -> None:
        # Deep-copy the parent config so per-split mutations cannot leak between runs.
        child_config = copy.deepcopy(config)
        assert isinstance(child_config, ScalarModelBase)
        child_config.cross_validation_split_index = split_index
        # Each split writes into its own numbered subfolder of the parent run.
        child_config.file_system_config = parent_file_system.add_subfolder(str(split_index))
        logging.info(f"Running model train and test on cross validation split: {split_index}")
        child_runner = MLRunner(model_config=child_config,
                                azure_config=self.azure_config,
                                project_root=self.project_root,
                                post_cross_validation_hook=self.post_cross_validation_hook,
                                model_deployment_hook=self.model_deployment_hook)
        child_runner.run()

    for split_index in range(config.number_of_cross_validation_splits):
        _run_single_split(split_index)
    config_and_files = get_config_and_results_for_offline_runs(self.model_config)
    plot_cross_validation_from_files(config_and_files,
                                     Path(config_and_files.config.outputs_directory))
def _test_result_aggregation_for_classification(files: List[RunResultFiles],
                                                plotting_config: PlotCrossValidationConfig,
                                                expected_aggregate_metrics: List[str],
                                                expected_epochs: Set[int]) -> None:
    """
    Test how metrics are aggregated for cross validation runs on classification models.
    """
    print(f"Writing aggregated metrics to {plotting_config.outputs_directory}")
    root_folder = Path(plotting_config.outputs_directory)
    config_and_files = OfflineCrossvalConfigAndFiles(config=plotting_config, files=files)
    plot_cross_validation_from_files(config_and_files, root_folder=root_folder)
    # Check the aggregated metrics file line-by-line against the expected contents.
    header_line = "prediction_target,area_under_roc_curve,area_under_pr_curve,accuracy_at_optimal_threshold," \
                  "false_positive_rate_at_optimal_threshold,false_negative_rate_at_optimal_threshold," \
                  "optimal_threshold,cross_entropy,accuracy_at_threshold_05,subject_count,data_split,epoch"
    expected_lines = [header_line] + expected_aggregate_metrics
    actual_lines = (root_folder / METRICS_AGGREGATES_FILE).read_text().splitlines()
    assert len(actual_lines) == len(expected_lines), "Number of lines in aggregated metrics file"
    for i, (actual, expected) in enumerate(zip(actual_lines, expected_lines)):
        assert actual == expected, f"Mismatch in aggregate metrics at index {i}"
    # Check the per-subject metrics dataframe: columns present and value ranges as expected.
    per_subject_metrics = pd.read_csv(root_folder / FULL_METRICS_DATAFRAME_FILE)
    assert LoggingColumns.Label.value in per_subject_metrics
    assert set(per_subject_metrics[LoggingColumns.Label.value].unique()) == {0.0, 1.0}
    assert LoggingColumns.ModelOutput.value in per_subject_metrics
    assert LoggingColumns.Patient.value in per_subject_metrics
    assert len(per_subject_metrics[LoggingColumns.Patient.value].unique()) == 356
    assert LoggingColumns.Epoch.value in per_subject_metrics
    assert set(per_subject_metrics[LoggingColumns.Epoch.value].unique()) == expected_epochs
    assert LoggingColumns.CrossValidationSplitIndex.value in per_subject_metrics
    assert set(per_subject_metrics[LoggingColumns.CrossValidationSplitIndex.value].unique()) == {0, 1}
    assert LoggingColumns.DataSplit.value in per_subject_metrics
    assert per_subject_metrics[LoggingColumns.DataSplit.value].unique() == ["Val"]
def test_invalid_number_of_cv_files() -> None:
    """
    Test that an error is raised if the expected number of cross validation fold is not equal
    to the number of results files provided.
    """
    files, plotting_config = load_result_files_for_classification()
    # Deliberately claim more splits than there are result files.
    plotting_config.number_of_cross_validation_splits = 4
    print(f"Writing aggregated metrics to {plotting_config.outputs_directory}")
    config_and_files = OfflineCrossvalConfigAndFiles(config=plotting_config, files=files)
    with pytest.raises(ValueError):
        plot_cross_validation_from_files(config_and_files,
                                         root_folder=plotting_config.outputs_directory)
def spawn_offline_cross_val_classification_child_runs(self) -> None:
    """
    Trains and Tests k models based on their respective data splits sequentially.
    Stores the results on the Validation set to the outputs directory of the parent run.

    When sub-fold cross validation is enabled, each top-level split is further divided into
    sub-folds, and one child run is spawned per (split, sub-fold) pair.
    """
    _config = self.model_config
    assert isinstance(_config, ScalarModelBase)
    parent_run_file_system = _config.file_system_config

    def _spawn_run(cross_val_split_index: int, cross_val_sub_fold_split_index: int) -> None:
        # Deep-copy the parent config so per-split mutations cannot leak between runs.
        split_model_config = copy.deepcopy(_config)
        assert isinstance(split_model_config, ScalarModelBase)
        split_model_config.cross_validation_split_index = cross_val_split_index
        split_model_config.cross_validation_sub_fold_split_index = cross_val_sub_fold_split_index
        if cross_val_sub_fold_split_index == DEFAULT_CROSS_VALIDATION_SPLIT_INDEX:
            # No sub-folds: folder is named after the top-level split index only.
            _local_split_folder_name = str(cross_val_split_index)
        else:
            # Flatten (split, sub-fold) into a single global index for the folder name.
            _local_split_folder_name = \
                str((cross_val_split_index * split_model_config.number_of_cross_validation_splits_per_fold)
                    + cross_val_sub_fold_split_index)
        split_model_config.file_system_config = parent_run_file_system.add_subfolder(
            _local_split_folder_name)
        # BUG FIX: the original interpolated the enclosing loop variable `x` (a list of
        # sub-fold indices, bound late via the closure), not the index of the current split.
        logging.info(
            f"Running model train and test on cross validation split: {cross_val_split_index}")
        split_ml_runner = MLRunner(split_model_config,
                                   self.azure_config,
                                   self.project_root,
                                   self.model_deployment_hook,
                                   self.innereye_submodule_name)
        split_ml_runner.run()

    # Each top-level split runs either all of its sub-folds, or a single default
    # (whole-fold) run when sub-fold cross validation is disabled.
    if _config.perform_sub_fold_cross_validation:
        sub_fold_indices = list(range(_config.number_of_cross_validation_splits_per_fold))
    else:
        sub_fold_indices = [DEFAULT_CROSS_VALIDATION_SPLIT_INDEX]
    for split_index in range(_config.number_of_cross_validation_splits):
        for sub_fold_index in sub_fold_indices:
            _spawn_run(split_index, int(sub_fold_index))
    config_and_files = get_config_and_results_for_offline_runs(self.model_config)
    plot_cross_validation_from_files(
        config_and_files, Path(config_and_files.config.outputs_directory))
def test_aggregate_files_with_prediction_target(test_output_dirs: TestOutputDirectories) -> None:
    """
    For multi-week RNNs that predict at multiple sequence points: Test that the dataframes
    including the prediction_target column can be aggregated.
    """
    plotting_config = PlotCrossValidationConfig(run_recovery_id="foo",
                                                epoch=1,
                                                model_category=ModelCategory.Classification)
    files = create_run_result_file_list(plotting_config, "multi_label_sequence_in_crossval")
    root_folder = Path(test_output_dirs.root_dir)
    print(f"Writing result files to {root_folder}")
    config_and_files = OfflineCrossvalConfigAndFiles(config=plotting_config, files=files)
    plot_cross_validation_from_files(config_and_files, root_folder=root_folder)