def test_validate_user_module_and_set_functions(find_spec, import_module):
    """Validation wires the default inference handler's functions into the transformer.

    Expects the environment's module name to be looked up (``find_spec``) and
    imported (``import_module``) exactly once. Each of ``_model_fn``,
    ``_input_fn``, ``_predict_fn`` and ``_output_fn`` must end up being the
    matching ``default_*_fn`` of the handler passed to ``Transformer``.

    NOTE(review): what ``find_spec`` / ``import_module`` return is configured
    outside this function (presumably via patch decorators); confirm there.
    """
    model_fn, input_fn, predict_fn, output_fn = object(), object(), object(), object()

    handler = Mock(
        default_model_fn=model_fn,
        default_input_fn=input_fn,
        default_predict_fn=predict_fn,
        default_output_fn=output_fn,
    )
    env = Mock(module_name="foo_module")

    subject = Transformer(handler)
    subject._environment = env
    subject._validate_user_module_and_set_functions()

    find_spec.assert_called_once_with(env.module_name)
    import_module.assert_called_once_with(env.module_name)

    assert subject._default_inference_handler == handler
    assert subject._environment == env
    assert subject._model_fn == model_fn
    assert subject._input_fn == input_fn
    assert subject._predict_fn == predict_fn
    assert subject._output_fn == output_fn
def _assert_value_error_raised():
    """Assert that validating the user module fails with the transform_fn conflict error.

    Builds a fresh ``Transformer`` with a Mock environment and expects
    ``_validate_user_module_and_set_functions`` to raise ``ValueError`` whose
    text mentions the transform_fn / input_fn / predict_fn / output_fn clash.

    NOTE(review): the conflicting user module must be arranged by the caller
    (presumably by patching the module import); that setup is not visible here.
    """
    expected_message = (
        'Cannot use transform_fn implementation in conjunction with input_fn, predict_fn, '
        'and/or output_fn implementation'
    )

    with pytest.raises(ValueError) as excinfo:
        subject = Transformer()
        subject._environment = Mock()
        subject._validate_user_module_and_set_functions()

    # Checked after the ``with`` block: inside it, this line would never run
    # once the expected exception is raised.
    assert expected_message in str(excinfo.value)
# NOTE(review): a second function with exactly this name (taking
# ``find_spec, import_module``) is defined right after this one in the file.
# That later definition rebinds the module-level name, so this version is
# shadowed and pytest never collects it (flake8/ruff F811). It looks like a
# stale copy. Consider deleting it instead of renaming it: renaming would start
# running a test whose ``_environment.module_name`` is a bare Mock.
def test_validate_user_module_and_set_functions_transform_fn(import_module):
    """Check that the user module's ``transform_fn`` is adopted by the transformer.

    Asserts that ``_transform_fn`` equals ``import_module.return_value.transform_fn``,
    i.e. the attribute on the object the (presumably patched) import call returns.
    """
    # NOTE(review): this sets an attribute on the import callable itself, not
    # on its return value, so it does not change the module object the
    # transformer inspects. ``import_module.return_value`` was likely intended.
    import_module.transform_fn = Mock()
    transformer = Transformer()
    transformer._environment = Mock()
    transformer._validate_user_module_and_set_functions()
    assert transformer._transform_fn == import_module.return_value.transform_fn
def test_validate_user_module_and_set_functions_transform_fn(find_spec, import_module):
    """A user module that defines only ``transform_fn`` has that function adopted.

    Expects the environment's module name to be looked up (``find_spec``) and
    imported (``import_module``) exactly once. The transformer's
    ``_transform_fn`` must then be the ``transform_fn`` of the imported module.
    """
    mock_env = Mock()
    mock_env.module_name = "foo_module"

    # The object returned by ``import_module(...)`` plays the user module, so
    # configure *that*, not an attribute on the import callable itself.
    # ``spec=["transform_fn"]`` makes every other attribute lookup
    # (input_fn, predict_fn, output_fn, ...) raise AttributeError. The module
    # therefore genuinely defines transform_fn alone, instead of auto-creating
    # every attribute as a plain Mock would.
    import_module.return_value = Mock(spec=["transform_fn"])

    transformer = Transformer()
    transformer._environment = mock_env
    transformer._validate_user_module_and_set_functions()

    find_spec.assert_called_once_with(mock_env.module_name)
    import_module.assert_called_once_with(mock_env.module_name)
    assert transformer._transform_fn == import_module.return_value.transform_fn
def test_transform_no_accept(validate, retrieve_content_type_header):
    """Without an Accept value in the request properties, the environment default is used.

    The request processor reports empty properties. The test expects the user
    ``transform_fn`` to be called once with the model, the request body,
    ``CONTENT_TYPE`` and the environment's ``default_accept``. It also expects
    ``validate`` to be called once with no arguments.

    NOTE(review): ``CONTENT_TYPE`` presumably comes from the patched
    ``retrieve_content_type_header``; its return value is configured outside this function.
    """
    processor = Mock()
    processor.get_request_properties.return_value = {}

    ctx = Mock()
    ctx.request_processor = [processor]

    env = Mock()
    env.default_accept = DEFAULT_ACCEPT

    user_transform_fn = Mock()

    subject = Transformer()
    subject._model = MODEL
    subject._transform_fn = user_transform_fn
    subject._environment = env

    payload = [{"body": INPUT_DATA}]
    subject.transform(payload, ctx)

    validate.assert_called_once_with()
    user_transform_fn.assert_called_once_with(MODEL, INPUT_DATA, CONTENT_TYPE, DEFAULT_ACCEPT)