Example #1
0
    def __init__(self,
                 model_fn=None,
                 input_fn=None,
                 predict_fn=None,
                 output_fn=None,
                 error_class=_errors.ClientError):
        """Set up the model, input, predict and output handlers.

        Every handler supplied by the caller is passed through
        ``_functions.error_wrapper`` together with ``error_class`` so that user errors can be
        told apart from framework errors. A handler that is omitted (or is otherwise falsy)
        falls back to the matching ``default_*_fn``, which is used without wrapping.

        Args:
            model_fn (fn): Function responsible for loading the model.
            input_fn (fn): Takes request data and de-serializes it into an object for prediction.
            predict_fn (fn): Function responsible for model predictions.
            output_fn (fn): Function responsible for serializing the prediction for the response.
            error_class (Exception): Error class used to separate framework and user errors.
        """
        # No model is loaded at construction time.
        self._model = None

        if model_fn:
            self._model_fn = _functions.error_wrapper(model_fn, error_class)
        else:
            self._model_fn = default_model_fn

        if input_fn:
            self._input_fn = _functions.error_wrapper(input_fn, error_class)
        else:
            self._input_fn = default_input_fn

        if predict_fn:
            self._predict_fn = _functions.error_wrapper(predict_fn, error_class)
        else:
            self._predict_fn = default_predict_fn

        if output_fn:
            self._output_fn = _functions.error_wrapper(output_fn, error_class)
        else:
            self._output_fn = default_output_fn

        self._error_class = error_class
Example #2
0
    def __init__(
        self,
        model_fn=None,
        input_fn=None,
        predict_fn=None,
        output_fn=None,
        transform_fn=None,
        error_class=_errors.ClientError,
    ):
        """Default constructor. Wraps any non-default framework function in an error class to
        isolate framework errors from user errors.

        Args:
            model_fn (fn): Function responsible for loading the model.
            input_fn (fn): Takes request data and de-serializes the data into an object for
                prediction.
            predict_fn (fn): Function responsible for model predictions.
            output_fn (fn): Function responsible for serializing the prediction for the response.
            transform_fn (fn): Function responsible for taking input data and returning a prediction
                as a serialized response. This function takes the place of ``input_fn``,
                ``predict_fn``, and ``output_fn``.
            error_class (Exception): Error class used to separate framework and user errors.

        Raises:
            ValueError: If ``transform_fn`` is supplied together with any of ``input_fn``,
                ``predict_fn`` or ``output_fn``.
        """
        self._model = None
        self._model_fn = (_functions.error_wrapper(model_fn, error_class)
                          if model_fn else default_model_fn)

        # Use the same ``is not None`` test as the installation branch below, so that any
        # transform_fn that would be installed is also subject to the exclusivity check
        # (a truthiness test would let a falsy-but-not-None callable slip through).
        if transform_fn is not None and (input_fn or predict_fn or output_fn):
            raise ValueError(
                "Cannot use transform_fn implementation with input_fn, predict_fn, and/or output_fn"
            )

        if transform_fn is not None:
            self._transform_fn = _functions.error_wrapper(
                transform_fn, error_class)
        else:
            self._transform_fn = self._default_transform_fn

        # User-supplied handlers are wrapped; the module-level defaults are used as-is.
        self._input_fn = (_functions.error_wrapper(input_fn, error_class)
                          if input_fn else default_input_fn)
        self._predict_fn = (_functions.error_wrapper(predict_fn, error_class)
                            if predict_fn else default_predict_fn)
        self._output_fn = (_functions.error_wrapper(output_fn, error_class)
                           if output_fn else default_output_fn)
        self._error_class = error_class
def test_error_wrapper_exception():
    """An error raised inside the wrapped callable surfaces as the given error class."""
    with pytest.raises(NotImplementedError) as excinfo:
        wrapped = _functions.error_wrapper(lambda x: x, NotImplementedError)
        # The lambda takes a single argument, so passing two makes the call fail.
        wrapped(2, 3)

    # The first argument of the raised error is expected to be the underlying TypeError.
    underlying = excinfo.value.args[0]
    assert type(underlying) is TypeError
def test_error_wrapper():
    """A wrapped callable that does not fail returns the same value as the original."""
    times_ten = _functions.error_wrapper(lambda x: x * 10, NotImplementedError)
    outcome = times_ten(3)
    assert outcome == 30