def test_set_model_config_attributes() -> None:
    """
    Check that deriving model properties fills in the inference stride.

    After ``set_derived_model_properties`` is called with an identity model,
    the config's ``inference_stride_size`` is expected to equal the
    configured test crop size.
    """
    crop_for_training = (3, 5, 3)
    crop_for_testing = (7, 7, 7)
    config = SegmentationModelBase(
        crop_size=crop_for_training,
        test_crop_size=crop_for_testing,
        should_validate=False,
    )
    identity = IdentityModel()
    config.set_derived_model_properties(identity)
    assert config.inference_stride_size == crop_for_testing
def test_get_output_size() -> None:
    """
    Check ``get_output_size`` before and after model properties are derived.

    Before ``set_derived_model_properties`` runs, the output size for both
    the TRAIN and TEST execution modes is ``None``. Afterwards (using an
    identity model), TRAIN reports the training crop size and TEST reports
    the test crop size.
    """
    crop_for_training = (5, 5, 5)
    crop_for_testing = (7, 7, 7)
    config = SegmentationModelBase(
        crop_size=crop_for_training,
        test_crop_size=crop_for_testing,
        should_validate=False,
    )

    # Nothing is known about the output size until the model is attached.
    for mode in (ModelExecutionMode.TRAIN, ModelExecutionMode.TEST):
        assert config.get_output_size(execution_mode=mode) is None

    config.set_derived_model_properties(IdentityModel())

    # An identity model keeps the spatial size, so output == crop per mode.
    expectations = (
        (ModelExecutionMode.TRAIN, crop_for_training),
        (ModelExecutionMode.TEST, crop_for_testing),
    )
    for mode, expected_size in expectations:
        assert config.get_output_size(execution_mode=mode) == expected_size
def test_inference_stride_size_setter() -> None:
    """
    Exercise the ``inference_stride_size`` setter.

    Checks that:
    * an explicitly set stride survives ``set_derived_model_properties``;
    * resetting the stride to ``None`` makes the derivation fall back to the
      test crop size;
    * setting a stride with a dimension larger than the test output patch
      raises ``ValueError``.
    """
    output_patch = (7, 3, 5)
    valid_stride = (3, 3, 3)
    oversized_stride = (1, 1, 9)  # last dim 9 > 5 in output_patch
    identity = IdentityModel()
    config = SegmentationModelBase(test_crop_size=output_patch, should_validate=False)

    # An explicit stride is stored as given ...
    config.inference_stride_size = valid_stride
    assert config.inference_stride_size == valid_stride
    # ... and is not overwritten when model properties are derived.
    config.set_derived_model_properties(identity)
    assert config.inference_stride_size == valid_stride

    # With no stride set, derivation falls back to the test output patch size.
    config.inference_stride_size = None
    config.set_derived_model_properties(identity)
    assert config.inference_stride_size == output_patch

    # A stride exceeding the output patch in any dimension is rejected.
    with pytest.raises(ValueError):
        config.inference_stride_size = oversized_stride