Example #1
def test_non_image_encoder(
        test_output_dirs: OutputFolderForTests,
        hidden_layer_num_feature_channels: Optional[int]) -> None:
    """
    Test that we can build a simple MLP model that uses only non-imaging features.
    """
    dataset_folder = Path(test_output_dirs.make_sub_dir("dataset"))
    dataset_contents = _get_fake_dataset_contents()
    (dataset_folder / DATASET_CSV_FILE_NAME).write_text(dataset_contents)
    config = NonImageEncoder(
        should_validate=False,
        hidden_layer_num_feature_channels=hidden_layer_num_feature_channels)
    config.local_dataset = dataset_folder
    config.set_output_to(test_output_dirs.root_dir)
    config.max_batch_grad_cam = 1
    config.validate()
    # run model training
    _, checkpoint_handler = model_train_unittest(
        config, output_folder=test_output_dirs)
    # run model inference
    runner = MLRunner(config)
    runner.setup()
    runner.model_inference_train_and_test(
        checkpoint_paths=checkpoint_handler.get_checkpoints_to_test())
    assert config.get_total_number_of_non_imaging_features() == 18
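
A hedged sketch of how this test could be driven through pytest parametrization; the wrapper name and the channel values are assumptions, not taken from the original test module.

import pytest


@pytest.mark.parametrize("hidden_layer_num_feature_channels", [None, 2])
def test_non_image_encoder_variants(
        test_output_dirs: OutputFolderForTests,
        hidden_layer_num_feature_channels: Optional[int]) -> None:
    # Run the test above for each channel setting (None presumably meaning no extra hidden layer).
    test_non_image_encoder(test_output_dirs, hidden_layer_num_feature_channels)
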
Example #2
def run_model_inference_train_and_test(
        test_output_dirs: OutputFolderForTests,
        perform_cross_validation: bool,
        inference_on_train_set: Optional[bool] = None,
        inference_on_val_set: Optional[bool] = None,
        inference_on_test_set: Optional[bool] = None,
        ensemble_inference_on_train_set: Optional[bool] = None,
        ensemble_inference_on_val_set: Optional[bool] = None,
        ensemble_inference_on_test_set: Optional[bool] = None,
        model_proc: ModelProcessing = ModelProcessing.DEFAULT) -> None:
    """
    Test that running inference produces the expected output metrics, files, folders, and calls to upload_folder.

    :param test_output_dirs: Test output directories.
    :param perform_cross_validation: Whether to test with cross validation.
    :param inference_on_train_set: Override for inference on train data sets.
    :param inference_on_val_set: Override for inference on validation data sets.
    :param inference_on_test_set: Override for inference on test data sets.
    :param ensemble_inference_on_train_set: Override for ensemble inference on train data sets.
    :param ensemble_inference_on_val_set: Override for ensemble inference on validation data sets.
    :param ensemble_inference_on_test_set: Override for ensemble inference on test data sets.
    :param model_proc: Model processing to test.
    :return: None.
    """
    dummy_model = DummyModel()

    config = PassThroughModel()
    # Copy settings from DummyModel
    config.image_channels = dummy_model.image_channels
    config.ground_truth_ids = dummy_model.ground_truth_ids
    config.ground_truth_ids_display_names = dummy_model.ground_truth_ids_display_names
    config.colours = dummy_model.colours
    config.fill_holes = dummy_model.fill_holes
    config.roi_interpreted_types = dummy_model.roi_interpreted_types

    config.test_crop_size = (16, 16, 16)
    config.number_of_cross_validation_splits = 2 if perform_cross_validation else 0
    config.inference_on_train_set = inference_on_train_set
    config.inference_on_val_set = inference_on_val_set
    config.inference_on_test_set = inference_on_test_set
    config.ensemble_inference_on_train_set = ensemble_inference_on_train_set
    config.ensemble_inference_on_val_set = ensemble_inference_on_val_set
    config.ensemble_inference_on_test_set = ensemble_inference_on_test_set
    # Plotting crashes with random TCL errors on Windows, disable that for Windows PR builds.
    config.is_plotting_enabled = common_util.is_linux()

    config.set_output_to(test_output_dirs.root_dir)
    train_and_test_data_small_dir = test_output_dirs.root_dir / "train_and_test_data_small"
    config.local_dataset = create_train_and_test_data_small_dataset(
        config.test_crop_size, full_ml_test_data_path(), "train_and_test_data",
        train_and_test_data_small_dir, "data")

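    # Store a model checkpoint and mark training as done, so that inference can run
    # without an actual training step.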
    checkpoint_path = config.checkpoint_folder / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
    create_model_and_store_checkpoint(config, checkpoint_path)
    checkpoint_handler = get_default_checkpoint_handler(
        model_config=config, project_root=test_output_dirs.root_dir)
    checkpoint_handler.additional_training_done()

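    # Create a mock AzureML run so that calls to upload_folder can be asserted on later.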
    mock_upload_path = test_output_dirs.root_dir / "mock_upload"
    mock_upload_path.mkdir()

    run = create_mock_run(mock_upload_path, config)

    azure_config = Mock(name='mock_azure_config')
    azure_config.fetch_run.return_value = run

    runner = MLRunner(model_config=config, azure_config=azure_config)

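    # Run inference with the AzureML parent run context patched to be the mock run.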
    with mock.patch("InnerEye.ML.model_testing.PARENT_RUN_CONTEXT", run):
        metrics = runner.model_inference_train_and_test(
            checkpoint_paths=checkpoint_handler.get_checkpoints_to_test(),
            model_proc=model_proc)

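    # For ensemble creation, also exercise cross-validation result plotting and report generation.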
    if model_proc == ModelProcessing.ENSEMBLE_CREATION:
        # Create a fake ensemble dataset.csv
        dataset_df = create_dataset_df()
        dataset_df.to_csv(config.outputs_folder / DATASET_CSV_FILE_NAME)

        with mock.patch.object(PlotCrossValidationConfig,
                               'azure_config',
                               return_value=azure_config):
            with mock.patch("InnerEye.Azure.azure_util.PARENT_RUN_CONTEXT",
                            run):
                with mock.patch("InnerEye.ML.run_ml.PARENT_RUN_CONTEXT", run):
                    runner.plot_cross_validation_and_upload_results()
                    runner.generate_report(ModelProcessing.ENSEMBLE_CREATION)

    if model_proc == ModelProcessing.DEFAULT:
        named_metrics = {
            ModelExecutionMode.TRAIN: inference_on_train_set,
            ModelExecutionMode.TEST: inference_on_test_set,
            ModelExecutionMode.VAL: inference_on_val_set
        }
    else:
        named_metrics = {
            ModelExecutionMode.TRAIN: ensemble_inference_on_train_set,
            ModelExecutionMode.TEST: ensemble_inference_on_test_set,
            ModelExecutionMode.VAL: ensemble_inference_on_val_set
        }

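    # Verify that the returned metrics, the per-mode output folders, and the calls to
    # upload_folder all match the expected (possibly defaulted) inference flags.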
    error = ''
    expected_upload_folder_count = 0
    for mode, flag in named_metrics.items():
        if mode in metrics:
            metric = metrics[mode]
            assert isinstance(metric, InferenceMetricsForSegmentation)

        if flag is None:
            # No override supplied, calculate the expected default:
            if model_proc == ModelProcessing.DEFAULT:
                if not perform_cross_validation:
                    # If a "normal" run then default to val or test.
                    flag = mode == ModelExecutionMode.TEST
                else:
                    # If an ensemble child then default to never.
                    flag = False
            else:
                # If an ensemble then default to test only.
                flag = mode == ModelExecutionMode.TEST
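        # Net effect of the defaults: only the test set for single (non-cross-validation) runs
        # and for ensembles; no inference at all for cross-validation child runs.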

        if mode in metrics and not flag:
            error = error + f"Error: unexpected inference metrics present for {mode.value}. "
        elif mode not in metrics and flag:
            error = error + f"Error: expected inference metrics missing for {mode.value}. "
        results_folder = config.outputs_folder / get_best_epoch_results_path(
            mode, model_proc)
        folder_exists = results_folder.is_dir()
        assert folder_exists == flag
        if flag and model_proc == ModelProcessing.ENSEMBLE_CREATION:
            expected_upload_folder_count = expected_upload_folder_count + 1
            expected_name = get_best_epoch_results_path(
                mode, ModelProcessing.DEFAULT)
            run.upload_folder.assert_any_call(name=str(expected_name),
                                              path=str(results_folder))
    if error:
        raise ValueError(error)

    if model_proc == ModelProcessing.ENSEMBLE_CREATION:
        # The generated report should also have been uploaded via the mocked run.upload_folder.
        expected_upload_folder_count = expected_upload_folder_count + 1

    assert run.upload_folder.call_count == expected_upload_folder_count
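
A hedged sketch of how this helper could be invoked from individual tests; the test names, the parametrization, and the flag choices are illustrative assumptions rather than the repository's actual callers.

import pytest


@pytest.mark.parametrize("perform_cross_validation", [False, True])
def test_inference_defaults(test_output_dirs: OutputFolderForTests,
                            perform_cross_validation: bool) -> None:
    # Leave all inference overrides as None to exercise the defaulting logic checked above.
    run_model_inference_train_and_test(test_output_dirs, perform_cross_validation)


def test_inference_ensemble_creation(test_output_dirs: OutputFolderForTests) -> None:
    # Exercise ensemble-creation processing; cross-validation is enabled here as an assumption.
    run_model_inference_train_and_test(
        test_output_dirs,
        perform_cross_validation=True,
        model_proc=ModelProcessing.ENSEMBLE_CREATION)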