def default_output_fn(self, prediction, accept):  # pylint: disable=no-self-use
    """Serialize a prediction into the first supported type the client accepts.

    Candidate content types are taken from ``utils.parse_accept(accept)`` in
    the order it yields them; the first one that appears in
    ``encoder.SUPPORTED_CONTENT_TYPES`` is used for encoding.

    Args:
        prediction (obj): prediction result returned by the predict_fn.
        accept (str): accept header expected by the client.

    Returns:
        tuple: ``(encoded_prediction, content_type)`` -- the value returned by
        ``encoder.encode`` together with the content type that was chosen.

    Raises:
        errors.UnsupportedFormatError: if none of the accepted content types
            is supported by the encoder.
    """
    # Private sentinel so "no match" can never be confused with a real value.
    not_found = object()
    chosen_type = next(
        (
            candidate
            for candidate in utils.parse_accept(accept)
            if candidate in encoder.SUPPORTED_CONTENT_TYPES
        ),
        not_found,
    )
    if chosen_type is not_found:
        raise errors.UnsupportedFormatError(accept)
    return encoder.encode(prediction, chosen_type), chosen_type
def default_output_fn(self, prediction, accept):
    """A default output_fn for PyTorch.

    Serializes the result of ``predict_fn`` into the first content type from
    ``accept`` (as parsed by ``utils.parse_accept``) that is listed in
    ``encoder.SUPPORTED_CONTENT_TYPES``.

    Args:
        prediction: a prediction result from predict_fn. Any ``torch.Tensor``
            (including subclasses such as ``torch.nn.Parameter``) is detached,
            moved to CPU and converted to nested Python lists before encoding.
        accept: accept header describing the type(s) the output data needs
            to be serialized to.

    Returns:
        The serialized prediction. For CSV, the encoder's string output is
        additionally encoded to UTF-8 ``bytes``; for every other supported
        content type the value returned by ``encoder.encode`` is passed
        through unchanged. Only the data is returned, not the content type.

    Raises:
        errors.UnsupportedFormatError: if no content type in ``accept`` is
            supported by the encoder.
    """
    # isinstance rather than an exact type() comparison, so Tensor subclasses
    # (e.g. nn.Parameter) are converted too instead of reaching the encoder
    # as raw tensors. detach() drops the autograd graph, which numpy()
    # otherwise rejects for tensors that require grad.
    if isinstance(prediction, torch.Tensor):
        prediction = prediction.detach().cpu().numpy().tolist()

    for content_type in utils.parse_accept(accept):
        if content_type in encoder.SUPPORTED_CONTENT_TYPES:
            encoded_prediction = encoder.encode(prediction, content_type)
            if content_type == content_types.CSV:
                # The CSV encoder's output is text; hand back UTF-8 bytes.
                encoded_prediction = encoded_prediction.encode("utf-8")
            return encoded_prediction

    raise errors.UnsupportedFormatError(accept)
def test_parse_accept(input, expected):
    """parse_accept should turn the raw ``input`` value into ``expected``."""
    # `input` shadows the builtin, but test parameters are bound by name
    # (presumably via a pytest parametrize decorator outside this view), so
    # the name is kept.
    assert parse_accept(input) == expected