from pathlib import Path
from typing import List, Set, Tuple

import pandas as pd
import pytest

# The remaining names (PlotCrossValidationConfig, OfflineCrossvalConfigAndFiles, RunResultFiles,
# plot_cross_validation_from_files, check_result_file_counts, ModelExecutionMode, ModelCategory,
# LoggingColumns, METRICS_AGGREGATES_FILE, FULL_METRICS_DATAFRAME_FILE, TestOutputDirectories,
# load_result_files_for_classification and create_run_result_file_list) come from the project
# under test and its test helpers; the exact module paths depend on the repository layout.


def test_invalid_number_of_cv_files() -> None:
    """
    Test that an error is raised if the expected number of cross validation folds is not equal
    to the number of result files provided.
    """
    files, plotting_config = load_result_files_for_classification()
    plotting_config.number_of_cross_validation_splits = 4
    print(f"Writing aggregated metrics to {plotting_config.outputs_directory}")
    root_folder = Path(plotting_config.outputs_directory)
    with pytest.raises(ValueError):
        plot_cross_validation_from_files(OfflineCrossvalConfigAndFiles(config=plotting_config, files=files),
                                         root_folder=root_folder)
def test_check_result_file_counts() -> None:
    """
    More tests on the function that checks the number of files of each ModelExecutionMode.
    """
    val_files, plotting_config = load_result_files_for_classification()
    # This test assumes that the loaded val_files all have mode Val
    assert all(file.execution_mode == ModelExecutionMode.VAL for file in val_files)
    plotting_config.number_of_cross_validation_splits = len(val_files)
    # Check that when just the Val files are present, the check does not throw
    config_and_files1 = OfflineCrossvalConfigAndFiles(config=plotting_config, files=val_files)
    check_result_file_counts(config_and_files1)
    # Check that when we add the same number of Test files, the check does not throw
    test_files = [RunResultFiles(execution_mode=ModelExecutionMode.TEST,
                                 metrics_file=file.metrics_file,
                                 dataset_csv_file=file.dataset_csv_file,
                                 run_recovery_id=file.run_recovery_id,
                                 split_index=file.split_index)
                  for file in val_files]
    config_and_files2 = OfflineCrossvalConfigAndFiles(config=plotting_config, files=val_files + test_files)
    check_result_file_counts(config_and_files2)
    # Check that when we have the same number of files as the number of splits, but they are from
    # a mixture of modes, the check does throw
    config_and_files3 = OfflineCrossvalConfigAndFiles(config=plotting_config, files=val_files[:1] + test_files[1:])
    with pytest.raises(ValueError):
        check_result_file_counts(config_and_files3)
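# A minimal sketch of what the load_result_files_for_classification helper used above might look
# like, assuming two cross-validation splits of checked-in Val results; the fixture folder name,
# the per-split file layout, and the type of split_index are hypothetical, not the project's
# actual test data. The real helper presumably also points plotting_config.outputs_directory at
# a per-test output folder, since the tests above write aggregated metrics there.
def _sketch_load_result_files_for_classification() -> Tuple[List[RunResultFiles], PlotCrossValidationConfig]:
    plotting_config = PlotCrossValidationConfig(run_recovery_id="foo",
                                                epoch=1,
                                                model_category=ModelCategory.Classification)
    test_data = Path("test_data") / "classification_crossval"  # hypothetical fixture folder
    files = [RunResultFiles(execution_mode=ModelExecutionMode.VAL,
                            metrics_file=test_data / str(split) / "metrics.csv",
                            dataset_csv_file=test_data / str(split) / "dataset.csv",
                            run_recovery_id=f"foo_{split}",
                            split_index=str(split))  # str assumed; may be int in the real helper
             for split in range(2)]
    return files, plotting_config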
def test_aggregate_files_with_prediction_target(test_output_dirs: TestOutputDirectories) -> None:
    """
    For multi-week RNNs that predict at multiple sequence points: Test that the dataframes
    including the prediction_target column can be aggregated.
    """
    plotting_config = PlotCrossValidationConfig(run_recovery_id="foo",
                                                epoch=1,
                                                model_category=ModelCategory.Classification)
    files = create_run_result_file_list(plotting_config, "multi_label_sequence_in_crossval")
    root_folder = Path(test_output_dirs.root_dir)
    print(f"Writing result files to {root_folder}")
    plot_cross_validation_from_files(OfflineCrossvalConfigAndFiles(config=plotting_config, files=files),
                                     root_folder=root_folder)
def _test_result_aggregation_for_classification(files: List[RunResultFiles],
                                                plotting_config: PlotCrossValidationConfig,
                                                expected_aggregate_metrics: List[str],
                                                expected_epochs: Set[int]) -> None:
    """
    Test how metrics are aggregated for cross validation runs on classification models.
    """
    print(f"Writing aggregated metrics to {plotting_config.outputs_directory}")
    root_folder = Path(plotting_config.outputs_directory)
    plot_cross_validation_from_files(OfflineCrossvalConfigAndFiles(config=plotting_config, files=files),
                                     root_folder=root_folder)
    # Compare the aggregates file line-by-line against the expected rows, including the header
    aggregates_file = root_folder / METRICS_AGGREGATES_FILE
    actual_aggregates = aggregates_file.read_text().splitlines()
    header_line = "prediction_target,area_under_roc_curve,area_under_pr_curve,accuracy_at_optimal_threshold," \
                  "false_positive_rate_at_optimal_threshold,false_negative_rate_at_optimal_threshold," \
                  "optimal_threshold,cross_entropy,accuracy_at_threshold_05,subject_count,data_split,epoch"
    expected_aggregate_metrics = [header_line] + expected_aggregate_metrics
    assert len(actual_aggregates) == len(expected_aggregate_metrics), "Number of lines in aggregated metrics file"
    for i, (actual, expected) in enumerate(zip(actual_aggregates, expected_aggregate_metrics)):
        assert actual == expected, f"Mismatch in aggregate metrics at index {i}"
    # Check the per-subject metrics dataframe that is written alongside the aggregates
    per_subject_metrics = pd.read_csv(root_folder / FULL_METRICS_DATAFRAME_FILE)
    assert LoggingColumns.Label.value in per_subject_metrics
    assert set(per_subject_metrics[LoggingColumns.Label.value].unique()) == {0.0, 1.0}
    assert LoggingColumns.ModelOutput.value in per_subject_metrics
    assert LoggingColumns.Patient.value in per_subject_metrics
    assert len(per_subject_metrics[LoggingColumns.Patient.value].unique()) == 356
    assert LoggingColumns.Epoch.value in per_subject_metrics
    assert set(per_subject_metrics[LoggingColumns.Epoch.value].unique()) == expected_epochs
    assert LoggingColumns.CrossValidationSplitIndex.value in per_subject_metrics
    assert set(per_subject_metrics[LoggingColumns.CrossValidationSplitIndex.value].unique()) == {0, 1}
    assert LoggingColumns.DataSplit.value in per_subject_metrics
    assert per_subject_metrics[LoggingColumns.DataSplit.value].unique() == ["Val"]