def test_predict_ensemble(batch_size: int) -> None:
    """
    Check that ScalarEnsemblePipeline averages the predictions of its member pipelines.

    The ensemble holds several ScalarInferencePipeline instances. Some wrap a model built from
    ConstantScalarConfig(0.), and the rest wrap a model built from ConstantScalarConfig(1.).
    The test checks that the ensemble passes subject ids and labels through unchanged. It also
    checks that the averaged model output matches the value derived from the member counts.

    :param batch_size: Number of identical samples in the batch fed to the ensemble
        (presumably supplied by a pytest parametrization outside this view).
    """
    num_returning_0 = 3
    num_returning_1 = 2

    config_returns_0 = ConstantScalarConfig(0.)
    model_and_info_returns_0 = ModelAndInfo(config=config_returns_0, model_execution_mode=ModelExecutionMode.TEST,
                                            is_mean_teacher=False, checkpoint_path=None)
    # Keep the side-effecting load outside the assert so it still runs under `python -O`.
    model_0_loaded = model_and_info_returns_0.try_create_model_load_from_checkpoint_and_adjust()
    assert model_0_loaded
    model_returns_0 = model_and_info_returns_0.model

    config_returns_1 = ConstantScalarConfig(1.)
    model_and_info_returns_1 = ModelAndInfo(config=config_returns_1, model_execution_mode=ModelExecutionMode.TEST,
                                            is_mean_teacher=False, checkpoint_path=None)
    model_1_loaded = model_and_info_returns_1.try_create_model_load_from_checkpoint_and_adjust()
    assert model_1_loaded
    model_returns_1 = model_and_info_returns_1.model

    # Construct the pipelines in index order 0..4, exactly as before: the first group wraps the
    # "0" model and the second group wraps the "1" model.
    pipelines = [ScalarInferencePipeline(model_returns_0, config_returns_0, 0, index)
                 for index in range(num_returning_0)]
    pipelines += [ScalarInferencePipeline(model_returns_1, config_returns_1, 0, index)
                  for index in range(num_returning_0, num_returning_0 + num_returning_1)]
    ensemble_pipeline = ScalarEnsemblePipeline(pipelines, config_returns_0, EnsembleAggregationType.Average)

    data = {"metadata": [GeneralSampleMetadata(id='2')] * batch_size,
            "label": torch.zeros((batch_size, 1)),
            "images": torch.zeros(((batch_size, 1) + config_returns_0.expected_image_size_zyx)),
            "numerical_non_image_features": torch.tensor([]),
            "categorical_non_image_features": torch.tensor([]),
            "segmentations": torch.tensor([])}

    results = ensemble_pipeline.predict(data)
    ids, labels, predicted = results.subject_ids, results.labels, results.model_outputs
    assert ids == ['2'] * batch_size
    assert torch.equal(labels, torch.zeros((batch_size, 1)))
    # The expected ensemble output is the mean of the sigmoid of each member's output:
    # (3 * sigmoid(0) + 2 * sigmoid(1)) / 5 ~= 0.592423431. Derive it rather than hard-coding it.
    expected = (num_returning_0 * torch.sigmoid(torch.tensor(0.)) +
                num_returning_1 * torch.sigmoid(torch.tensor(1.))) / (num_returning_0 + num_returning_1)
    assert torch.allclose(predicted, torch.full((batch_size, 1), expected.item()))
# Example #2
# 0
def test_predict_ensemble(batch_size: int) -> None:
    """
    Check that ScalarEnsemblePipeline averages the predictions of its member pipelines.

    Two ScalarLightning models are created, one from ConstantScalarConfig(0.) and one from
    ConstantScalarConfig(1.). Several ScalarInferencePipeline instances wrap the first model,
    and the rest wrap the second. The test checks that the ensemble passes subject ids and labels
    through unchanged. It also checks that the averaged posteriors match the value derived from
    the member counts.

    :param batch_size: Number of identical samples in the batch fed to the ensemble
        (presumably supplied by a pytest parametrization outside this view).
    """
    num_returning_0 = 3
    num_returning_1 = 2

    config_returns_0 = ConstantScalarConfig(0.)
    model_returns_0 = create_lightning_model(config_returns_0, set_optimizer_and_scheduler=False)
    assert isinstance(model_returns_0, ScalarLightning)

    config_returns_1 = ConstantScalarConfig(1.)
    model_returns_1 = create_lightning_model(config_returns_1, set_optimizer_and_scheduler=False)
    assert isinstance(model_returns_1, ScalarLightning)

    # Construct the pipelines in index order 0..4, exactly as before: the first group wraps the
    # "0" model and the second group wraps the "1" model.
    pipelines = [ScalarInferencePipeline(model_returns_0, config_returns_0, index)
                 for index in range(num_returning_0)]
    pipelines += [ScalarInferencePipeline(model_returns_1, config_returns_1, index)
                  for index in range(num_returning_0, num_returning_0 + num_returning_1)]
    ensemble_pipeline = ScalarEnsemblePipeline(pipelines, config_returns_0, EnsembleAggregationType.Average)

    data = {
        "metadata": [GeneralSampleMetadata(id='2')] * batch_size,
        "label": torch.zeros((batch_size, 1)),
        "images": torch.zeros(((batch_size, 1) + config_returns_0.expected_image_size_zyx)),
        "numerical_non_image_features": torch.tensor([]),
        "categorical_non_image_features": torch.tensor([]),
        "segmentations": torch.tensor([]),
    }

    results = ensemble_pipeline.predict(data)
    ids, labels, predicted = results.subject_ids, results.labels, results.posteriors
    assert ids == ['2'] * batch_size
    assert torch.equal(labels, torch.zeros((batch_size, 1)))
    # The expected ensemble output is the mean of the sigmoid of each member's output:
    # (3 * sigmoid(0) + 2 * sigmoid(1)) / 5 ~= 0.592423431. Derive it rather than hard-coding it.
    expected = (num_returning_0 * torch.sigmoid(torch.tensor(0.)) +
                num_returning_1 * torch.sigmoid(torch.tensor(1.))) / (num_returning_0 + num_returning_1)
    assert torch.allclose(predicted, torch.full((batch_size, 1), expected.item()))
def test_predict_ensemble(batch_size: int) -> None:
    """
    Check that ScalarEnsemblePipeline averages the predictions of its member pipelines.

    Two ScalarOnesModel instances are built, with constructor values 0. and 1., and each is
    passed through update_model_for_multiple_gpus. Several ScalarInferencePipeline instances
    wrap the first model, and the rest wrap the second. The test checks that the ensemble passes
    subject ids and labels through unchanged. It also checks that the averaged model output
    matches the value derived from the member counts.

    :param batch_size: Number of identical samples in the batch fed to the ensemble
        (presumably supplied by a pytest parametrization outside this view).
    """
    num_returning_0 = 3
    num_returning_1 = 2

    config = ClassificationModelForTesting()
    # Create both raw models first, then wrap them, preserving the original construction order.
    raw_model_returns_0 = ScalarOnesModel(config.expected_image_size_zyx, 0.)
    raw_model_returns_1 = ScalarOnesModel(config.expected_image_size_zyx, 1.)
    model_returns_0 = update_model_for_multiple_gpus(ModelAndInfo(raw_model_returns_0),
                                                     args=config,
                                                     execution_mode=ModelExecutionMode.TEST).model
    model_returns_1 = update_model_for_multiple_gpus(ModelAndInfo(raw_model_returns_1),
                                                     args=config,
                                                     execution_mode=ModelExecutionMode.TEST).model

    # Construct the pipelines in index order 0..4, exactly as before: the first group wraps the
    # "0" model and the second group wraps the "1" model.
    pipelines = [ScalarInferencePipeline(model_returns_0, config, 0, index)
                 for index in range(num_returning_0)]
    pipelines += [ScalarInferencePipeline(model_returns_1, config, 0, index)
                  for index in range(num_returning_0, num_returning_0 + num_returning_1)]
    ensemble_pipeline = ScalarEnsemblePipeline(pipelines, config, EnsembleAggregationType.Average)

    data = {
        "metadata": [GeneralSampleMetadata(id='2')] * batch_size,
        "label": torch.zeros((batch_size, 1)),
        "images": torch.zeros(((batch_size, 1) + config.expected_image_size_zyx)),
        "numerical_non_image_features": torch.tensor([]),
        "categorical_non_image_features": torch.tensor([]),
        "segmentations": torch.tensor([]),
    }

    results = ensemble_pipeline.predict(data)
    ids, labels, predicted = results.subject_ids, results.labels, results.model_outputs
    assert ids == ['2'] * batch_size
    assert torch.equal(labels, torch.zeros((batch_size, 1)))
    # The expected ensemble output is the mean of the sigmoid of each member's output:
    # (3 * sigmoid(0) + 2 * sigmoid(1)) / 5 ~= 0.592423431. Derive it rather than hard-coding it.
    expected = (num_returning_0 * torch.sigmoid(torch.tensor(0.)) +
                num_returning_1 * torch.sigmoid(torch.tensor(1.))) / (num_returning_0 + num_returning_1)
    assert torch.allclose(predicted, torch.full((batch_size, 1), expected.item()))