예제 #1
0
def test_set_model_config_attributes() -> None:
    """
    Check that deriving model properties sets the inference stride from the config.

    With no stride assigned explicitly, calling set_derived_model_properties with an
    identity model is expected to leave inference_stride_size equal to the test crop size.
    """
    identity_model = IdentityModel()
    expected_test_size = (7, 7, 7)
    config = SegmentationModelBase(
        crop_size=(3, 5, 3),
        test_crop_size=expected_test_size,
        should_validate=False,
    )

    config.set_derived_model_properties(identity_model)

    assert config.inference_stride_size == expected_test_size
예제 #2
0
def test_get_output_size() -> None:
    """
    Check get_output_size for the TRAIN and TEST execution modes.

    Before set_derived_model_properties is called, both modes report None.
    Afterwards, each mode reports the crop size the config was built with
    (crop_size for TRAIN, test_crop_size for TEST).
    """
    train_size = (5, 5, 5)
    test_size = (7, 7, 7)
    expected_per_mode = [
        (ModelExecutionMode.TRAIN, train_size),
        (ModelExecutionMode.TEST, test_size),
    ]
    config = SegmentationModelBase(crop_size=train_size,
                                   test_crop_size=test_size,
                                   should_validate=False)

    # get_output_size yields None until model properties have been derived.
    for mode, _ in expected_per_mode:
        assert config.get_output_size(execution_mode=mode) is None

    config.set_derived_model_properties(IdentityModel())

    for mode, expected in expected_per_mode:
        assert config.get_output_size(execution_mode=mode) == expected
예제 #3
0
def test_inference_stride_size_setter() -> None:
    """
    Exercise the inference_stride_size setter.

    Checks that:
    - an explicitly assigned stride is stored and survives set_derived_model_properties;
    - after resetting the stride to None, deriving properties sets it to the test crop size;
    - assigning a stride larger than the test crop size raises ValueError.
    """
    crop = (7, 3, 5)
    stride = (3, 3, 3)
    too_large_stride = (1, 1, 9)  # last dimension 9 exceeds the crop's 5
    identity_model = IdentityModel()
    config = SegmentationModelBase(test_crop_size=crop, should_validate=False)

    # An explicitly assigned stride is stored as-is...
    config.inference_stride_size = stride
    assert config.inference_stride_size == stride
    # ...and is not overwritten when model properties are derived.
    config.set_derived_model_properties(identity_model)
    assert config.inference_stride_size == stride

    # Clearing the stride lets the derived properties fill it with the test crop size.
    config.inference_stride_size = None
    config.set_derived_model_properties(identity_model)
    assert config.inference_stride_size == crop

    with pytest.raises(ValueError):
        config.inference_stride_size = too_large_stride