Example #1
0
def test_validate_user_module_and_set_functions(find_spec, import_module):
    """Hooks resolve to the custom default handler's functions after validation.

    ``find_spec`` and ``import_module`` are injected mocks (presumably supplied
    by ``mock.patch`` decorators not visible here -- confirm).
    """
    env = Mock()
    env.module_name = "foo_module"

    # One unique sentinel per hook so each identity check below is unambiguous.
    sentinels = {
        "model": object(),
        "input": object(),
        "predict": object(),
        "output": object(),
    }

    handler = Mock()
    for hook, sentinel in sentinels.items():
        setattr(handler, "default_{}_fn".format(hook), sentinel)

    transformer = Transformer(handler)
    transformer._environment = env
    transformer._validate_user_module_and_set_functions()

    find_spec.assert_called_once_with(env.module_name)
    import_module.assert_called_once_with(env.module_name)
    assert transformer._default_inference_handler == handler
    assert transformer._environment == env
    for hook, sentinel in sentinels.items():
        assert getattr(transformer, "_{}_fn".format(hook)) == sentinel
Example #2
0
def _assert_value_error_raised():
    """Assert that user-module validation fails with the transform_fn-conflict ValueError.

    A no-argument ``Transformer`` with a bare ``Mock`` environment is validated,
    and the expected text must appear in the raised error's message.
    """
    expected = ('Cannot use transform_fn implementation in conjunction with input_fn, predict_fn, '
                'and/or output_fn implementation')

    with pytest.raises(ValueError) as excinfo:
        transformer = Transformer()
        transformer._environment = Mock()
        transformer._validate_user_module_and_set_functions()

    assert expected in str(excinfo.value)
Example #3
0
def test_validate_user_module_and_set_functions_transform_fn(import_module):
    """A ``transform_fn`` exposed by the user module is adopted by the transformer.

    NOTE(review): a later function in this file has exactly this name. At import
    time that later definition replaces this one, so only it is collected.
    """
    # ``import_module(...)`` returns ``import_module.return_value``, which is the
    # object the assertion below reads ``transform_fn`` from. The hook must be
    # configured there. Assigning ``import_module.transform_fn`` only set an
    # attribute on the patched function itself.
    user_transform_fn = Mock()
    import_module.return_value.transform_fn = user_transform_fn

    transformer = Transformer()
    transformer._environment = Mock()

    transformer._validate_user_module_and_set_functions()

    assert transformer._transform_fn == user_transform_fn
def test_validate_user_module_and_set_functions_transform_fn(find_spec, import_module):
    """A ``transform_fn`` exposed by the user module is adopted by the transformer.

    Also checks that the module named by the environment is looked up and
    imported exactly once.
    """
    mock_env = Mock()
    mock_env.module_name = "foo_module"

    # Configure the hook on ``import_module.return_value``, the object returned
    # by ``import_module(...)`` and the one the final assertion reads from.
    # Assigning ``import_module.transform_fn`` only set an attribute on the
    # patched function itself and did not affect the imported module.
    user_transform_fn = Mock()
    import_module.return_value.transform_fn = user_transform_fn

    transformer = Transformer()
    transformer._environment = mock_env

    transformer._validate_user_module_and_set_functions()

    find_spec.assert_called_once_with(mock_env.module_name)
    import_module.assert_called_once_with(mock_env.module_name)
    assert transformer._transform_fn == user_transform_fn
def test_transform_tuple(validate, retrieve_content_type_header):
    """A ``(result, accept)`` tuple from ``transform_fn`` is split into body and content type."""
    data = [{"body": INPUT_DATA}]
    context = Mock()
    request_processor = Mock()
    transform_fn = Mock(return_value=(RESULT, ACCEPT))

    context.request_processor = [request_processor]
    request_processor.get_request_properties.return_value = {"accept": ACCEPT}

    transformer = Transformer()
    transformer._model = MODEL
    transformer._transform_fn = transform_fn

    result = transformer.transform(data, context)

    transform_fn.assert_called_once_with(MODEL, INPUT_DATA, CONTENT_TYPE, ACCEPT)
    # Compare against the known tuple members directly. Calling ``transform_fn()``
    # here again would record extra calls on the mock under test.
    context.set_response_content_type.assert_called_once_with(0, ACCEPT)
    assert isinstance(result, list)
    assert result[0] == RESULT
    def _user_module_transformer():
        """Choose a Transformer suited to the user module's model.

        If the user module defines ``transform_fn``, a Transformer with the MXNet
        default handler is returned. Otherwise the model is loaded through the
        user's ``model_fn`` (falling back to the MXNet default) and the result
        is chosen by the model's type.

        Raises:
            ValueError: if the loaded model is neither an MXNet ``BaseModule``
                nor a Gluon ``Block``.
        """
        module_name = environment.Environment().module_name
        user_module = importlib.import_module(module_name)

        if hasattr(user_module, 'transform_fn'):
            return Transformer(default_inference_handler=DefaultMXNetInferenceHandler())

        # The fallback is built eagerly, exactly as getattr's default argument would be.
        fallback_model_fn = DefaultMXNetInferenceHandler().default_model_fn
        load_model = getattr(user_module, 'model_fn', fallback_model_fn)
        model = load_model(environment.model_dir)

        if isinstance(model, mx.module.BaseModule):
            return MXNetModuleTransformer()
        if isinstance(model, mx.gluon.block.Block):
            return Transformer(default_inference_handler=DefaultGluonBlockInferenceHandler())

        raise ValueError('Unsupported model type: {}'.format(model.__class__.__name__))
def test_default_transformer():
    """A bare ``Transformer()`` gets a ``DefaultInferenceHandler`` and starts uninitialized."""
    transformer = Transformer()

    assert isinstance(transformer._default_inference_handler, DefaultInferenceHandler)
    assert transformer._initialized is False

    # All remaining slots should still be None right after construction.
    empty_slots = ("_environment", "_model", "_model_fn", "_transform_fn",
                   "_input_fn", "_predict_fn", "_output_fn")
    for slot in empty_slots:
        assert getattr(transformer, slot) is None
Example #8
0
def test_handle_validate_and_initialize_user_error(env, validate_user_module):
    """A toolkit user error raised during validation is turned into an error response.

    The response must contain the error's message, and the context must receive
    the error's status code.
    """
    status = http_client.FORBIDDEN
    message = "Foo"

    class FooUserError(BaseInferenceToolkitError):
        # Minimal error carrying status_code/message/phrase, presumably read by
        # the transformer when building the response -- confirm.
        def __init__(self, status_code, message):
            self.status_code = status_code
            self.message = message
            self.phrase = "Foo"

    context = Mock()

    transformer = Transformer()
    transformer._model = MODEL
    transformer._transform_fn = Mock()
    transformer._model_fn = Mock()

    validate_user_module.side_effect = FooUserError(status, message)

    assert transformer._initialized is False

    response = transformer.transform([{"body": INPUT_DATA}], context)

    assert message in str(response)
    context.set_response_status.assert_called_with(code=http_client.FORBIDDEN,
                                                   phrase=message)
Example #9
0
def test_handle_validate_and_initialize_error(env, validate_user_module):
    """A generic exception during validation yields a 500 response carrying its message."""
    error_text = "Foo"

    request_processor = Mock()
    context = Mock()
    context.request_processor = [request_processor]

    transformer = Transformer()
    transformer._model = MODEL
    transformer._transform_fn = Mock()
    transformer._model_fn = Mock()

    validate_user_module.side_effect = ValueError(error_text)
    assert transformer._initialized is False

    response = transformer.transform([{"body": INPUT_DATA}], context)

    assert error_text in str(response)
    context.set_response_status.assert_called_with(
        code=http_client.INTERNAL_SERVER_ERROR, phrase=error_text)
def test_transform(validate, retrieve_content_type_header, accept_key):
    """``transform`` validates, reads the content type, runs ``transform_fn`` and sets Accept.

    ``accept_key`` is a parametrized header key (presumably case variants of
    "accept" -- confirm against the parametrize decorator).
    """
    properties = {accept_key: ACCEPT}
    request_processor = Mock()
    request_processor.get_request_properties.return_value = properties

    context = Mock()
    context.request_processor = [request_processor]

    transform_fn = Mock(return_value=RESULT)
    transformer = Transformer()
    transformer._model = MODEL
    transformer._transform_fn = transform_fn

    result = transformer.transform([{"body": INPUT_DATA}], context)

    validate.assert_called_once_with()
    retrieve_content_type_header.assert_called_once_with(properties)
    transform_fn.assert_called_once_with(MODEL, INPUT_DATA, CONTENT_TYPE, ACCEPT)
    context.set_response_content_type.assert_called_once_with(0, ACCEPT)
    assert isinstance(result, list)
    assert result[0] == RESULT
Example #11
0
def test_transformer_with_custom_default_inference_handler():
    """A handler passed positionally is stored as-is; all other state starts empty."""
    custom_handler = Mock()

    transformer = Transformer(custom_handler)

    assert transformer._default_inference_handler == custom_handler
    assert transformer._initialized is False
    unset_slots = ("_environment", "_model", "_model_fn", "_transform_fn",
                   "_input_fn", "_predict_fn", "_output_fn")
    for slot in unset_slots:
        assert getattr(transformer, slot) is None
Example #12
0
def test_validate_and_initialize(env, validate_user_module):
    """``validate_and_initialize`` runs its setup once; a second call does not repeat it."""
    loader = Mock()

    transformer = Transformer()
    transformer._model_fn = loader
    assert transformer._initialized is False

    transformer.validate_and_initialize()
    assert transformer._initialized is True

    # Second call: every collaborator below must still have been hit only once.
    transformer.validate_and_initialize()

    loader.assert_called_once_with(environment.model_dir)
    env.assert_called_once_with()
    validate_user_module.assert_called_once_with()
def test_transform_no_accept(validate, retrieve_content_type_header):
    """With no Accept property, the environment's ``default_accept`` is passed to ``transform_fn``."""
    # Named mock_env so it does not shadow the module-level ``environment`` name.
    mock_env = Mock()
    mock_env.default_accept = DEFAULT_ACCEPT

    request_processor = Mock()
    request_processor.get_request_properties.return_value = dict()
    context = Mock()
    context.request_processor = [request_processor]

    transform_fn = Mock()
    transformer = Transformer()
    transformer._model = MODEL
    transformer._transform_fn = transform_fn
    transformer._environment = mock_env

    transformer.transform([{"body": INPUT_DATA}], context)

    validate.assert_called_once_with()
    transform_fn.assert_called_once_with(MODEL, INPUT_DATA, CONTENT_TYPE, DEFAULT_ACCEPT)
def test_default_transform_fn():
    """The default transform chains input_fn -> predict_fn -> output_fn."""
    transformer = Transformer()

    stage_in = Mock(return_value=PREPROCESSED_DATA)
    stage_predict = Mock(return_value=PREDICT_RESULT)
    stage_out = Mock(return_value=PROCESSED_RESULT)
    transformer._input_fn, transformer._predict_fn, transformer._output_fn = (
        stage_in, stage_predict, stage_out)

    result = transformer._default_transform_fn(MODEL, INPUT_DATA, CONTENT_TYPE, ACCEPT)

    # Each stage must receive the previous stage's output.
    stage_in.assert_called_once_with(INPUT_DATA, CONTENT_TYPE)
    stage_predict.assert_called_once_with(PREPROCESSED_DATA, MODEL)
    stage_out.assert_called_once_with(PREDICT_RESULT, ACCEPT)
    assert result == PROCESSED_RESULT
def test_transform_decode(validate, retrieve_content_type_header, content_type):
    """The request body is decoded as UTF-8 before being passed to ``transform_fn``.

    ``content_type`` is parametrized (presumably the content types for which the
    transformer decodes the payload -- confirm against the decorator).
    """
    raw_body = Mock()
    raw_body.decode.return_value = INPUT_DATA
    retrieve_content_type_header.return_value = content_type

    request_processor = Mock()
    request_processor.get_request_properties.return_value = {"accept": ACCEPT}
    context = Mock()
    context.request_processor = [request_processor]

    transform_fn = Mock()
    transformer = Transformer()
    transformer._model = MODEL
    transformer._transform_fn = transform_fn

    transformer.transform([{"body": raw_body}], context)

    raw_body.decode.assert_called_once_with("utf-8")
    transform_fn.assert_called_once_with(MODEL, INPUT_DATA, content_type, ACCEPT)
Example #16
0
 def __init__(self):
     """Give the base service a Transformer backed by the PyTorch default handler."""
     pytorch_handler = DefaultPytorchInferenceHandler()
     super(HandlerService, self).__init__(
         transformer=Transformer(default_inference_handler=pytorch_handler))
Example #17
0
 def __init__(self, transformer=None):
     """Store the transformer that will service requests.

     Args:
         transformer: transformer to use. When omitted (``None``), a default
             ``Transformer()`` is created.
     """
     # Compare against None explicitly. A caller-supplied transformer that happens
     # to be falsy (e.g. it defines __len__ or __bool__) must not be silently
     # replaced by the default.
     self._service = transformer if transformer is not None else Transformer()
Example #18
0
    def __init__(self):
        """Mark the service uninitialized, then hand the base class an XGBoost-backed Transformer."""
        # The flag is set before the base-class constructor runs, as originally.
        self._initialized = False

        xgb_handler = DefaultXGBoostInferenceHandler()
        super(HandlerService, self).__init__(
            transformer=Transformer(default_inference_handler=xgb_handler))
Example #19
0
 def __init__(self):
     """Give the base service a Transformer backed by the TensorFlow inference handler."""
     tf_handler = TensorFlowInferenceHandler()
     super(HandlerService, self).__init__(
         transformer=Transformer(default_inference_handler=tf_handler))
Example #20
0
 def __init__(self):
     """Give the base service a Transformer using ``self.DefaultSKLearnUserModuleInferenceHandler``."""
     sklearn_handler = self.DefaultSKLearnUserModuleInferenceHandler()
     super(HandlerService, self).__init__(
         transformer=Transformer(default_inference_handler=sklearn_handler))