Example #1
    def spawn_offline_cross_val_classification_child_runs(self) -> None:
        """
        Trains and tests k models sequentially, one per cross-validation data split.
        Stores the validation-set results in the outputs directory of the parent run.
        """
        _config = self.model_config
        assert isinstance(_config, ScalarModelBase)
        parent_run_file_system = _config.file_system_config

        def _spawn_run(cross_val_split_index: int) -> None:
            split_model_config = copy.deepcopy(_config)
            assert isinstance(split_model_config, ScalarModelBase)
            split_model_config.cross_validation_split_index = cross_val_split_index

            _local_split_folder_name = str(cross_val_split_index)
            split_model_config.file_system_config = parent_run_file_system.add_subfolder(_local_split_folder_name)

            logging.info(f"Running model train and test on cross validation split: {cross_val_split_index}")
            split_ml_runner = MLRunner(model_config=split_model_config,
                                       azure_config=self.azure_config,
                                       project_root=self.project_root,
                                       post_cross_validation_hook=self.post_cross_validation_hook,
                                       model_deployment_hook=self.model_deployment_hook)
            split_ml_runner.run()

        for i in range(_config.number_of_cross_validation_splits):
            _spawn_run(i)

        config_and_files = get_config_and_results_for_offline_runs(self.model_config)
        plot_cross_validation_from_files(config_and_files, Path(config_and_files.config.outputs_directory))
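A minimal sketch (not part of the original code) of how the split indices map onto per-split output subfolders under the parent run; the folder name "outputs" and the count of 5 splits are assumptions, standing in for the parent run's file system config and _config.number_of_cross_validation_splits.

from pathlib import Path

number_of_cross_validation_splits = 5  # hypothetical value of _config.number_of_cross_validation_splits
parent_outputs = Path("outputs")       # hypothetical parent run outputs directory
for split_index in range(number_of_cross_validation_splits):
    # Mirrors add_subfolder(str(split_index)) above: each child run writes to its own subfolder.
    child_outputs = parent_outputs / str(split_index)
    print(f"Child run {split_index} stores its validation results in {child_outputs}")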
Example #2
def _test_result_aggregation_for_classification(files: List[RunResultFiles],
                                                plotting_config: PlotCrossValidationConfig,
                                                expected_aggregate_metrics: List[str],
                                                expected_epochs: Set[int]) -> None:
    """
    Test how metrics are aggregated for cross validation runs on classification models.
    """
    print(f"Writing aggregated metrics to {plotting_config.outputs_directory}")
    root_folder = Path(plotting_config.outputs_directory)
    plot_cross_validation_from_files(OfflineCrossvalConfigAndFiles(config=plotting_config, files=files),
                                     root_folder=root_folder)
    aggregates_file = root_folder / METRICS_AGGREGATES_FILE
    actual_aggregates = aggregates_file.read_text().splitlines()
    header_line = "prediction_target,area_under_roc_curve,area_under_pr_curve,accuracy_at_optimal_threshold," \
                  "false_positive_rate_at_optimal_threshold,false_negative_rate_at_optimal_threshold," \
                  "optimal_threshold,cross_entropy,accuracy_at_threshold_05,subject_count,data_split,epoch"
    expected_aggregate_metrics = [header_line] + expected_aggregate_metrics
    assert len(actual_aggregates) == len(expected_aggregate_metrics), "Number of lines in aggregated metrics file"
    for i, (actual, expected) in enumerate(zip(actual_aggregates, expected_aggregate_metrics)):
        assert actual == expected, f"Mismatch in aggregate metrics at index {i}"
    per_subject_metrics = pd.read_csv(root_folder / FULL_METRICS_DATAFRAME_FILE)
    assert LoggingColumns.Label.value in per_subject_metrics
    assert set(per_subject_metrics[LoggingColumns.Label.value].unique()) == {0.0, 1.0}
    assert LoggingColumns.ModelOutput.value in per_subject_metrics
    assert LoggingColumns.Patient.value in per_subject_metrics
    assert len(per_subject_metrics[LoggingColumns.Patient.value].unique()) == 356
    assert LoggingColumns.Epoch.value in per_subject_metrics
    assert set(per_subject_metrics[LoggingColumns.Epoch.value].unique()) == expected_epochs
    assert LoggingColumns.CrossValidationSplitIndex.value in per_subject_metrics
    assert set(per_subject_metrics[LoggingColumns.CrossValidationSplitIndex.value].unique()) == {0, 1}
    assert LoggingColumns.DataSplit.value in per_subject_metrics
    assert per_subject_metrics[LoggingColumns.DataSplit.value].unique() == ["Val"]
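A minimal sketch of how the aggregate metrics could be inspected once plot_cross_validation_from_files has run. The file name "MetricsAggregates.csv" is a placeholder for METRICS_AGGREGATES_FILE, and the column names are taken from the expected header line in the test above.

import pandas as pd
from pathlib import Path

# Assumption: the aggregates file exists under the folder passed as root_folder above.
aggregates = pd.read_csv(Path("outputs") / "MetricsAggregates.csv")

# One row per (prediction_target, data_split, epoch); pick the epoch with the best
# area_under_roc_curve for each prediction target and data split.
best_rows = (aggregates
             .sort_values("area_under_roc_curve", ascending=False)
             .groupby(["prediction_target", "data_split"], as_index=False)
             .head(1))
print(best_rows[["prediction_target", "data_split", "epoch", "area_under_roc_curve"]])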
Example #3
def test_invalid_number_of_cv_files() -> None:
    """
    Test that an error is raised if the expected number of cross-validation folds
    does not match the number of result files provided.
    
    """
    files, plotting_config = load_result_files_for_classification()
    plotting_config.number_of_cross_validation_splits = 4
    print(f"Writing aggregated metrics to {plotting_config.outputs_directory}")
    with pytest.raises(ValueError):
        plot_cross_validation_from_files(OfflineCrossvalConfigAndFiles(config=plotting_config, files=files),
                                         root_folder=plotting_config.outputs_directory)
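The ValueError is expected because number_of_cross_validation_splits is set to 4, which does not match the result files loaded by load_result_files_for_classification. A hypothetical sketch of that kind of consistency check, not the library's actual code:

from typing import List

def check_result_file_count(number_of_cross_validation_splits: int,
                            files: List[object]) -> None:
    # Hypothetical helper illustrating the check the test expects to fail:
    # one set of result files is expected per configured cross-validation split.
    if len(files) != number_of_cross_validation_splits:
        raise ValueError(f"Expected result files for {number_of_cross_validation_splits} splits, "
                         f"but got {len(files)}")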
Example #4
    def spawn_offline_cross_val_classification_child_runs(self) -> None:
        """
        Trains and tests k models sequentially, one per cross-validation data split.
        Stores the validation-set results in the outputs directory of the parent run.
        """
        _config = self.model_config
        assert isinstance(_config, ScalarModelBase)
        parent_run_file_system = _config.file_system_config

        def _spawn_run(cross_val_split_index: int,
                       cross_val_sub_fold_split_index: int) -> None:
            split_model_config = copy.deepcopy(_config)
            assert isinstance(split_model_config, ScalarModelBase)
            split_model_config.cross_validation_split_index = cross_val_split_index
            split_model_config.cross_validation_sub_fold_split_index = cross_val_sub_fold_split_index

            if cross_val_sub_fold_split_index == DEFAULT_CROSS_VALIDATION_SPLIT_INDEX:
                _local_split_folder_name = str(cross_val_split_index)
            else:
                _local_split_folder_name = \
                    str((cross_val_split_index * split_model_config.number_of_cross_validation_splits_per_fold)
                        + cross_val_sub_fold_split_index)

            split_model_config.file_system_config = parent_run_file_system.add_subfolder(
                _local_split_folder_name)

            logging.info(f"Running model train and test on cross validation split: {cross_val_split_index}, "
                         f"sub fold: {cross_val_sub_fold_split_index}")
            split_ml_runner = MLRunner(split_model_config, self.azure_config,
                                       self.project_root,
                                       self.model_deployment_hook,
                                       self.innereye_submodule_name)
            split_ml_runner.run()

        # For each top-level split, run either all sub-folds or just the default split
        # index when sub-fold cross validation is disabled.
        sub_fold_indices_per_split = [
            list(range(_config.number_of_cross_validation_splits_per_fold))
            if _config.perform_sub_fold_cross_validation else
            [DEFAULT_CROSS_VALIDATION_SPLIT_INDEX]
        ] * _config.number_of_cross_validation_splits

        for split_index, sub_fold_indices in enumerate(sub_fold_indices_per_split):
            for sub_fold_index in sub_fold_indices:
                _spawn_run(split_index, int(sub_fold_index))

        config_and_files = get_config_and_results_for_offline_runs(
            self.model_config)
        plot_cross_validation_from_files(
            config_and_files, Path(config_and_files.config.outputs_directory))
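The folder-name arithmetic above flattens each (split, sub-fold) pair into a single consecutive index. A small worked example with assumed counts (2 splits, 3 sub-folds each):

# Illustrative only: mirrors split_index * number_of_cross_validation_splits_per_fold + sub_fold_index.
number_of_cross_validation_splits = 2
number_of_cross_validation_splits_per_fold = 3
for split_index in range(number_of_cross_validation_splits):
    for sub_fold_index in range(number_of_cross_validation_splits_per_fold):
        folder_name = str(split_index * number_of_cross_validation_splits_per_fold + sub_fold_index)
        print(f"split={split_index}, sub_fold={sub_fold_index} -> subfolder '{folder_name}'")
# Produces subfolders "0" through "5", one per child run.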
Example #5
def test_aggregate_files_with_prediction_target(test_output_dirs: TestOutputDirectories) -> None:
    """
    For multi-week RNNs that predict at multiple sequence points: Test that the dataframes
    including the prediction_target column can be aggregated.
    """
    plotting_config = PlotCrossValidationConfig(
        run_recovery_id="foo",
        epoch=1,
        model_category=ModelCategory.Classification
    )
    files = create_run_result_file_list(plotting_config, "multi_label_sequence_in_crossval")

    root_folder = Path(test_output_dirs.root_dir)
    print(f"Writing result files to {root_folder}")
    plot_cross_validation_from_files(OfflineCrossvalConfigAndFiles(config=plotting_config, files=files),
                                     root_folder=root_folder)
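A small illustration, with hypothetical data, of what the prediction_target column means for aggregation: each subject contributes one row per sequence position, and metrics are aggregated separately per target. Column names and values below are made up for illustration.

import pandas as pd

# Hypothetical per-subject rows for a sequence model that predicts at positions "01" and "02".
rows = pd.DataFrame({
    "prediction_target": ["01", "01", "02", "02"],
    "subject": [1, 2, 1, 2],
    "model_output": [0.2, 0.9, 0.4, 0.7],
    "label": [0.0, 1.0, 0.0, 1.0],
})
# Aggregation is done separately for each prediction target.
print(rows.groupby("prediction_target")[["model_output", "label"]].mean())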