def __init__(
    self,
    model_fn=None,
    input_fn=None,
    predict_fn=None,
    output_fn=None,
    error_class=_errors.ClientError,
):
    """Build the transformer from user-supplied functions or framework defaults.

    Every function the caller supplies is wrapped with
    ``_functions.error_wrapper`` using ``error_class``, which isolates
    failures in user code from framework errors. Any function that is not
    supplied falls back to the matching ``default_*_fn``, which is used
    as-is without wrapping.

    Args:
        model_fn (fn): Function responsible to load the model.
        input_fn (fn): Takes request data and de-serializes the data into an
            object for prediction.
        predict_fn (fn): Function responsible for model predictions.
        output_fn (fn): Function responsible to serialize the prediction for
            the response.
        error_class (Exception): Error class used to separate framework and
            user errors.
    """

    def _pick(user_fn, fallback):
        # A falsy argument (e.g. the None default) selects the framework
        # default, which is deliberately left unwrapped.
        if not user_fn:
            return fallback
        return _functions.error_wrapper(user_fn, error_class)

    self._model = None
    self._model_fn = _pick(model_fn, default_model_fn)
    self._input_fn = _pick(input_fn, default_input_fn)
    self._predict_fn = _pick(predict_fn, default_predict_fn)
    self._output_fn = _pick(output_fn, default_output_fn)
    self._error_class = error_class
def __init__(
    self,
    model_fn=None,
    input_fn=None,
    predict_fn=None,
    output_fn=None,
    transform_fn=None,
    error_class=_errors.ClientError,
):
    """Build the transformer from user-supplied functions or framework defaults.

    Wraps any non-default framework function in an error class to isolate
    framework errors from user errors. Functions that are not supplied fall
    back to the matching ``default_*_fn``. An omitted ``transform_fn`` falls
    back to ``self._default_transform_fn``.

    Args:
        model_fn (fn): Function responsible to load the model.
        input_fn (fn): Takes request data and de-serializes the data into an
            object for prediction.
        predict_fn (fn): Function responsible for model predictions.
        output_fn (fn): Function responsible to serialize the prediction for
            the response.
        transform_fn (fn): Function responsible for taking input data and
            returning a prediction as a serialized response. This function
            takes the place of ``input_fn``, ``predict_fn``, and
            ``output_fn``.
        error_class (Exception): Error class used to separate framework and
            user errors.

    Raises:
        ValueError: If ``transform_fn`` is supplied together with any of
            ``input_fn``, ``predict_fn`` or ``output_fn``.
    """
    # Reject the invalid combination before touching instance state, so a
    # failed construction does no wrapping work and leaves no partial setup.
    if transform_fn and (input_fn or predict_fn or output_fn):
        raise ValueError(
            "Cannot use transform_fn implementation with input_fn, predict_fn, and/or output_fn"
        )

    self._model = None
    self._model_fn = (
        _functions.error_wrapper(model_fn, error_class)
        if model_fn
        else default_model_fn
    )

    if transform_fn is not None:
        self._transform_fn = _functions.error_wrapper(transform_fn, error_class)
    else:
        self._transform_fn = self._default_transform_fn

    self._input_fn = (
        _functions.error_wrapper(input_fn, error_class)
        if input_fn
        else default_input_fn
    )
    self._predict_fn = (
        _functions.error_wrapper(predict_fn, error_class)
        if predict_fn
        else default_predict_fn
    )
    self._output_fn = (
        _functions.error_wrapper(output_fn, error_class)
        if output_fn
        else default_output_fn
    )
    self._error_class = error_class
def test_error_wrapper_exception():
    """A failing wrapped call surfaces as the configured error class.

    Calling the one-argument lambda with two arguments raises ``TypeError``.
    The test expects ``error_wrapper`` to turn this into ``NotImplementedError``
    whose first argument is the original ``TypeError``.
    """
    wrapped = _functions.error_wrapper(lambda x: x, NotImplementedError)

    # Only the call itself should be inside the raises-block; building the
    # wrapper must not be allowed to satisfy the expectation by accident.
    with pytest.raises(NotImplementedError) as exc_info:
        wrapped(2, 3)

    assert isinstance(exc_info.value.args[0], TypeError)
def test_error_wrapper():
    """A wrapped function that does not raise returns its normal result."""
    multiply_by_ten = _functions.error_wrapper(
        lambda value: value * 10, NotImplementedError
    )
    result = multiply_by_ten(3)
    assert result == 30