Example #1
def model_test(
    config: ModelConfigBase,
    data_split: ModelExecutionMode,
    run_recovery: Optional[RunRecovery] = None,
    model_proc: ModelProcessing = ModelProcessing.DEFAULT
) -> Optional[InferenceMetrics]:
    """
    Runs model inference on segmentation or classification models, using a given dataset (which can be the training,
    test or validation set). The inference results and metrics are stored and logged in a way that may
    differ between model categories (classification, segmentation).
    :param config: The configuration of the model.
    :param data_split: Indicates which of the three sets (training, test, or validation) is being processed.
    :param run_recovery: Run recovery data, if applicable.
    :param model_proc: Whether we are testing an ensemble or a single model; this affects where results are written.
    :return: The metrics that the model achieved on the given data set, or None if the data set is empty.
    """
    if len(config.get_dataset_splits()[data_split]) == 0:
        logging.info(f"Skipping inference on empty data split {data_split}")
        return None
    if config.avoid_process_spawn_in_data_loaders and is_linux():
        logging.warning(
            "Not performing any inference because avoid_process_spawn_in_data_loaders is set "
            "and additional data loaders are likely to block.")
        return None
    with logging_section(
            f"Running {model_proc.value} model on {data_split.name.lower()} set"
    ):
        if isinstance(config, SegmentationModelBase):
            return segmentation_model_test(config, data_split, run_recovery,
                                           model_proc)
        if isinstance(config, ScalarModelBase):
            return classification_model_test(config, data_split, run_recovery)
    raise ValueError(
        f"There is no testing code for models of type {type(config)}")


def may_initialize_rpdb() -> None:
    """
    On Linux only, import and initialize rpdb, to enable remote debugging if necessary.
    """
    # rpdb signal trapping does not work on Windows, as there is no SIGTRAP:
    if not is_linux():
        return
    import rpdb
    rpdb_port = 4444
    rpdb.handle_trap(port=rpdb_port)
    # For some reason, os.getpid() does not return the ID of what appears to be the currently running process.
    logging.info("rpdb is handling traps. To debug: identify the main runner.py process, then as root: "
                 f"kill -TRAP <process_id>; nc 127.0.0.1 {rpdb_port}")
Example #3
def test_dataset_consumption2() -> None:
    """
    Creating datasets, case 2: Azure datasets, local folders and mount points given
    """
    azure_config = get_default_azure_config()
    datasets = create_dataset_configs(azure_config,
                                      all_azure_dataset_ids=["1", "2"],
                                      all_dataset_mountpoints=["mp1", "mp2"],
                                      all_local_datasets=[Path("l1"), Path("l2")])
    assert len(datasets) == 2
    assert datasets[0].name == "1"
    assert datasets[1].name == "2"
    assert datasets[0].local_folder == Path("l1")
    assert datasets[1].local_folder == Path("l2")
    if is_linux():
        # PosixPath cannot be instantiated on Windows
        assert datasets[0].target_folder == PosixPath("mp1")
        assert datasets[1].target_folder == PosixPath("mp2")
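
The is_linux() guard around the PosixPath assertions exists because pathlib resolves Path(...) to the platform-native flavour, so a comparison against PosixPath is only meaningful on Linux; as the comment in the test notes, PosixPath cannot be instantiated on Windows. A small illustrative check, not part of the InnerEye code base:

from pathlib import Path, PosixPath, WindowsPath

p = Path("mp1")
# Path() dispatches to the concrete class for the current OS:
# PosixPath on Linux/macOS, WindowsPath on Windows.
assert isinstance(p, (PosixPath, WindowsPath))
print(type(p).__name__)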
Example #4
def test_model_inference_train_and_test(
        test_output_dirs: OutputFolderForTests, perform_cross_validation: bool,
        perform_training_set_inference: bool) -> None:
    config = DummyModel()
    config.number_of_cross_validation_splits = 2 if perform_cross_validation else 0
    config.perform_training_set_inference = perform_training_set_inference
    # Plotting crashes with random TCL errors on Windows, disable that for Windows PR builds.
    config.is_plotting_enabled = common_util.is_linux()

    config.set_output_to(test_output_dirs.root_dir)
    config.local_dataset = full_ml_test_data_path()

    # To make it seem like there was a training run before this, copy checkpoints into the checkpoints folder.
    stored_checkpoints = full_ml_test_data_path("checkpoints")
    shutil.copytree(str(stored_checkpoints), str(config.checkpoint_folder))

    checkpoint_handler = get_default_checkpoint_handler(
        model_config=config, project_root=test_output_dirs.root_dir)
    checkpoint_handler.additional_training_done()
    result, _, _ = MLRunner(config).model_inference_train_and_test(
        checkpoint_handler=checkpoint_handler)
    if result is None:
        raise ValueError("Error result cannot be None")
    assert isinstance(result, InferenceMetricsForSegmentation)
    for key, _ in result.epochs.items():
        epoch_folder_name = common_util.epoch_folder_name(key)
        for folder in [
                ModelExecutionMode.TRAIN.value, ModelExecutionMode.VAL.value,
                ModelExecutionMode.TEST.value
        ]:
            results_folder = config.outputs_folder / epoch_folder_name / folder
            folder_exists = results_folder.is_dir()
            if folder in [
                    ModelExecutionMode.TRAIN.value,
                    ModelExecutionMode.VAL.value
            ]:
                if perform_training_set_inference:
                    assert folder_exists
            else:
                assert folder_exists
def test_model_inference_train_and_test(
        test_output_dirs: TestOutputDirectories,
        perform_cross_validation: bool,
        perform_training_set_inference: bool) -> None:
    config = DummyModel()
    config.number_of_cross_validation_splits = 2 if perform_cross_validation else 0
    config.perform_training_set_inference = perform_training_set_inference
    # Plotting crashes with random TCL errors on Windows, disable that for Windows PR builds.
    config.is_plotting_enabled = common_util.is_linux()

    config.set_output_to(test_output_dirs.root_dir)
    config.local_dataset = full_ml_test_data_path()

    # Mimic the behaviour that checkpoints are downloaded from blob storage into the checkpoints folder.
    stored_checkpoints = full_ml_test_data_path("checkpoints")
    shutil.copytree(str(stored_checkpoints), str(config.checkpoint_folder))
    result, _, _ = MLRunner(config).model_inference_train_and_test()
    if result is None:
        raise ValueError("Error result cannot be None")
    assert isinstance(result, InferenceMetricsForSegmentation)
    for key, _ in result.epochs.items():
        epoch_folder_name = common_util.epoch_folder_name(key)
        for folder in [
                ModelExecutionMode.TRAIN.value, ModelExecutionMode.VAL.value,
                ModelExecutionMode.TEST.value
        ]:
            results_folder = config.outputs_folder / epoch_folder_name / folder
            folder_exists = results_folder.is_dir()
            if folder in [
                    ModelExecutionMode.TRAIN.value,
                    ModelExecutionMode.VAL.value
            ]:
                if perform_training_set_inference:
                    assert folder_exists
            else:
                assert folder_exists
Example #6
def test_model_inference_train_and_test(
        test_output_dirs: OutputFolderForTests, perform_cross_validation: bool,
        perform_training_set_inference: bool) -> None:
    config = DummyModel()
    config.number_of_cross_validation_splits = 2 if perform_cross_validation else 0
    config.perform_training_set_inference = perform_training_set_inference
    # Plotting crashes with random TCL errors on Windows, disable that for Windows PR builds.
    config.is_plotting_enabled = common_util.is_linux()

    config.set_output_to(test_output_dirs.root_dir)
    config.local_dataset = full_ml_test_data_path()

    checkpoint_path = config.checkpoint_folder / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
    create_model_and_store_checkpoint(config, checkpoint_path)
    checkpoint_handler = get_default_checkpoint_handler(
        model_config=config, project_root=test_output_dirs.root_dir)
    checkpoint_handler.additional_training_done()
    result, _, _ = MLRunner(config).model_inference_train_and_test(
        checkpoint_handler=checkpoint_handler)
    if result is None:
        raise ValueError("Error result cannot be None")
    assert isinstance(result, InferenceMetricsForSegmentation)
    epoch_folder_name = common_util.BEST_EPOCH_FOLDER_NAME
    for folder in [
            ModelExecutionMode.TRAIN.value, ModelExecutionMode.VAL.value,
            ModelExecutionMode.TEST.value
    ]:
        results_folder = config.outputs_folder / epoch_folder_name / folder
        folder_exists = results_folder.is_dir()
        if folder in [
                ModelExecutionMode.TRAIN.value, ModelExecutionMode.VAL.value
        ]:
            if perform_training_set_inference:
                assert folder_exists
        else:
            assert folder_exists
Example #7
def num_dataload_workers() -> int:
    """PyTorch support for multiple dataloader workers is flaky on Windows (so return 0)"""
    return 4 if common_util.is_linux() else 0
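
A typical use of such a helper is as the num_workers argument of a PyTorch DataLoader. The following self-contained sketch inlines the platform check so it can run without InnerEye; the dataset and batch size are illustrative only:

import platform

import torch
from torch.utils.data import DataLoader, TensorDataset


def num_dataload_workers() -> int:
    """PyTorch support for multiple dataloader workers is flaky on Windows (so return 0)."""
    return 4 if platform.system() == "Linux" else 0


# A toy dataset of 32 samples with 3 features each and binary labels.
dataset = TensorDataset(torch.randn(32, 3), torch.randint(0, 2, (32,)))
loader = DataLoader(dataset, batch_size=8, num_workers=num_dataload_workers())
for batch_features, batch_labels in loader:
    pass  # training or inference would happen here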
Example #8
def run_model_inference_train_and_test(
        test_output_dirs: OutputFolderForTests,
        perform_cross_validation: bool,
        inference_on_train_set: Optional[bool] = None,
        inference_on_val_set: Optional[bool] = None,
        inference_on_test_set: Optional[bool] = None,
        ensemble_inference_on_train_set: Optional[bool] = None,
        ensemble_inference_on_val_set: Optional[bool] = None,
        ensemble_inference_on_test_set: Optional[bool] = None,
        model_proc: ModelProcessing = ModelProcessing.DEFAULT) -> None:
    """
    Test that running inference produces the expected output metrics, files, folders and calls to upload_folder.

    :param test_output_dirs: Test output directories.
    :param perform_cross_validation: Whether to test with cross validation.
    :param inference_on_train_set: Override for inference on train data sets.
    :param inference_on_val_set: Override for inference on validation data sets.
    :param inference_on_test_set: Override for inference on test data sets.
    :param ensemble_inference_on_train_set: Override for ensemble inference on train data sets.
    :param ensemble_inference_on_val_set: Override for ensemble inference on validation data sets.
    :param ensemble_inference_on_test_set: Override for ensemble inference on test data sets.
    :param model_proc: Model processing to test.
    :return: None.
    """
    dummy_model = DummyModel()

    config = PassThroughModel()
    # Copy settings from DummyModel
    config.image_channels = dummy_model.image_channels
    config.ground_truth_ids = dummy_model.ground_truth_ids
    config.ground_truth_ids_display_names = dummy_model.ground_truth_ids_display_names
    config.colours = dummy_model.colours
    config.fill_holes = dummy_model.fill_holes
    config.roi_interpreted_types = dummy_model.roi_interpreted_types

    config.test_crop_size = (16, 16, 16)
    config.number_of_cross_validation_splits = 2 if perform_cross_validation else 0
    config.inference_on_train_set = inference_on_train_set
    config.inference_on_val_set = inference_on_val_set
    config.inference_on_test_set = inference_on_test_set
    config.ensemble_inference_on_train_set = ensemble_inference_on_train_set
    config.ensemble_inference_on_val_set = ensemble_inference_on_val_set
    config.ensemble_inference_on_test_set = ensemble_inference_on_test_set
    # Plotting crashes with random TCL errors on Windows, disable that for Windows PR builds.
    config.is_plotting_enabled = common_util.is_linux()

    config.set_output_to(test_output_dirs.root_dir)
    train_and_test_data_small_dir = test_output_dirs.root_dir / "train_and_test_data_small"
    config.local_dataset = create_train_and_test_data_small_dataset(
        config.test_crop_size, full_ml_test_data_path(), "train_and_test_data",
        train_and_test_data_small_dir, "data")

    checkpoint_path = config.checkpoint_folder / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
    create_model_and_store_checkpoint(config, checkpoint_path)
    checkpoint_handler = get_default_checkpoint_handler(
        model_config=config, project_root=test_output_dirs.root_dir)
    checkpoint_handler.additional_training_done()

    mock_upload_path = test_output_dirs.root_dir / "mock_upload"
    mock_upload_path.mkdir()

    run = create_mock_run(mock_upload_path, config)

    azure_config = Mock(name='mock_azure_config')
    azure_config.fetch_run.return_value = run

    runner = MLRunner(model_config=config, azure_config=azure_config)

    with mock.patch("InnerEye.ML.model_testing.PARENT_RUN_CONTEXT", run):
        metrics = runner.model_inference_train_and_test(
            checkpoint_paths=checkpoint_handler.get_checkpoints_to_test(),
            model_proc=model_proc)

    if model_proc == ModelProcessing.ENSEMBLE_CREATION:
        # Create a fake ensemble dataset.csv
        dataset_df = create_dataset_df()
        dataset_df.to_csv(config.outputs_folder / DATASET_CSV_FILE_NAME)

        with mock.patch.object(PlotCrossValidationConfig,
                               'azure_config',
                               return_value=azure_config):
            with mock.patch("InnerEye.Azure.azure_util.PARENT_RUN_CONTEXT",
                            run):
                with mock.patch("InnerEye.ML.run_ml.PARENT_RUN_CONTEXT", run):
                    runner.plot_cross_validation_and_upload_results()
                    runner.generate_report(ModelProcessing.ENSEMBLE_CREATION)

    if model_proc == ModelProcessing.DEFAULT:
        named_metrics = {
            ModelExecutionMode.TRAIN: inference_on_train_set,
            ModelExecutionMode.TEST: inference_on_test_set,
            ModelExecutionMode.VAL: inference_on_val_set
        }
    else:
        named_metrics = {
            ModelExecutionMode.TRAIN: ensemble_inference_on_train_set,
            ModelExecutionMode.TEST: ensemble_inference_on_test_set,
            ModelExecutionMode.VAL: ensemble_inference_on_val_set
        }

    error = ''
    expected_upload_folder_count = 0
    for mode, flag in named_metrics.items():
        if mode in metrics:
            metric = metrics[mode]
            assert isinstance(metric, InferenceMetricsForSegmentation)

        if flag is None:
            # No override supplied, calculate the expected default:
            if model_proc == ModelProcessing.DEFAULT:
                if not perform_cross_validation:
                    # If a "normal" run then default to val or test.
                    flag = mode == ModelExecutionMode.TEST
                else:
                    # If an ensemble child then default to never.
                    flag = False
            else:
                # If an ensemble then default to test only.
                flag = mode == ModelExecutionMode.TEST

        if mode in metrics and not flag:
            error = error + f"Error: {mode.value} metrics should not have been produced."
        elif mode not in metrics and flag:
            error = error + f"Error: {mode.value} metrics are missing."
        results_folder = config.outputs_folder / get_best_epoch_results_path(
            mode, model_proc)
        folder_exists = results_folder.is_dir()
        assert folder_exists == flag
        if flag and model_proc == ModelProcessing.ENSEMBLE_CREATION:
            expected_upload_folder_count = expected_upload_folder_count + 1
            expected_name = get_best_epoch_results_path(
                mode, ModelProcessing.DEFAULT)
            run.upload_folder.assert_any_call(name=str(expected_name),
                                              path=str(results_folder))
    if len(error):
        raise ValueError(error)

    if model_proc == ModelProcessing.ENSEMBLE_CREATION:
        # The report should have been mock uploaded
        expected_upload_folder_count = expected_upload_folder_count + 1

    assert run.upload_folder.call_count == expected_upload_folder_count
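
The expected-default computation embedded in the assertions above (used when no inference_on_*_set override is supplied) can be condensed into a small pure function. This sketch only restates the logic the test encodes, using plain booleans instead of the InnerEye enums; the function name is illustrative, not part of the InnerEye API:

def expected_inference_default(is_test_set: bool,
                               is_ensemble: bool,
                               perform_cross_validation: bool) -> bool:
    """Default inference behaviour when no explicit override is supplied,
    as encoded by the assertions in the test above."""
    if is_ensemble:
        # Ensemble creation defaults to inference on the test set only.
        return is_test_set
    if perform_cross_validation:
        # Cross-validation child runs default to no inference at all.
        return False
    # A plain single-model run also defaults to inference on the test set only.
    return is_test_set


# Spot checks of the defaults:
assert expected_inference_default(is_test_set=True, is_ensemble=False, perform_cross_validation=False)
assert not expected_inference_default(is_test_set=False, is_ensemble=False, perform_cross_validation=False)
assert not expected_inference_default(is_test_set=True, is_ensemble=False, perform_cross_validation=True)
assert expected_inference_default(is_test_set=True, is_ensemble=True, perform_cross_validation=True)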
#  ------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#  ------------------------------------------------------------------------------------------
import numpy as np
import pytest

from InnerEye.Common import common_util
from InnerEye.ML.visualizers.reliability_curve import plot_reliability_curve


@pytest.mark.skipif(not common_util.is_linux(),
                    reason="Test execution time is longer on Windows")
def test_plot_reliability_curve() -> None:
    prediction = [np.random.rand(250, 1), np.random.rand(200, 1)]
    target = [
        np.random.randint(2, size=(250, 1)),
        np.random.randint(2, size=(200, 1))
    ]
    plot_reliability_curve(y_predict=prediction,
                           y_true=target,
                           num_bins=10,
                           normalise=True)
Example #10
            plot_file_name=Path("x"))
    assert "Combination of input arguments is not recognized" in str(ex)


def compare_files(actual: List[Path], expected: List[str]) -> None:
    assert len(actual) == len(expected)
    for (f, e) in zip(actual, expected):
        assert f.exists()
        full_expected = full_ml_test_data_path(e)
        assert full_expected.exists()
        assert str(f).endswith(e)
        assert file_as_bytes(f) == file_as_bytes(full_expected)


@pytest.mark.skipif(
    common_util.is_linux(),
    reason="Rendering of the graph is slightly different on Linux")
def test_plot_normalization_result(
        test_output_dirs: TestOutputDirectories) -> None:
    """
    Tests plotting of before/after histograms in photometric normalization.
    :return:
    """
    size = (3, 3, 3)
    image = np.zeros((1, ) + size)
    for i, (z, y, x) in enumerate(
            itertools.product(range(size[0]), range(size[1]), range(size[2]))):
        image[0, z, y, x] = i
    labels = np.zeros((2, ) + size)
    labels[1, 1, 1, 1] = 1
    sample = Sample(image=image,
Example #11
def test_is_cross_validation_child_run_ensemble_run() -> None:
    """
    Test that cross validation child runs are identified correctly.
    """
    # check for offline run
    assert not is_cross_validation_child_run(Run.get_context())
    # check for online runs
    run = get_most_recent_run(
        fallback_run_id_for_local_execution=FALLBACK_ENSEMBLE_RUN)
    assert not is_cross_validation_child_run(run)
    assert all(
        [is_cross_validation_child_run(x) for x in fetch_child_runs(run)])


@pytest.mark.skipif(
    is_linux(),
    reason="Spurious file read/write errors on linux build agents.")
def test_merge_conda(test_output_dirs: OutputFolderForTests) -> None:
    """
    Tests the logic for merging Conda environment files.
    """
    env1 = """
channels:
  - defaults
  - pytorch
dependencies:
  - conda1=1.0
  - conda2=2.0
  - conda_both=3.0
  - pip:
      - azureml-sdk==1.7.0