Ejemplo n.º 1
0
    def default_output_fn(self, prediction, accept):  # pylint: disable=no-self-use
        """Serialize a prediction into the first supported type the client accepts.

        The accept header is parsed with ``utils.parse_accept`` and its entries
        are tried in the order returned; the first one that appears in
        ``encoder.SUPPORTED_CONTENT_TYPES`` is used.

        Args:
            prediction (obj): prediction result returned by the predict_fn.
            accept (str): accept header expected by the client.

        Returns:
            tuple: ``(encoded_prediction, content_type)`` where
                ``encoded_prediction`` is the result of ``encoder.encode`` and
                ``content_type`` is the accept entry that was selected.

        Raises:
            errors.UnsupportedFormatError: if no parsed accept entry is a
                supported content type.

        """
        for candidate in utils.parse_accept(accept):
            # Skip anything the encoder cannot produce and keep looking.
            if candidate not in encoder.SUPPORTED_CONTENT_TYPES:
                continue
            payload = encoder.encode(prediction, candidate)
            return payload, candidate

        # Nothing in the accept header matched a supported type.
        raise errors.UnsupportedFormatError(accept)
    def default_output_fn(self, prediction, accept):
        """Serialize a prediction from predict_fn to an accepted content type.

        The accept header is parsed with ``utils.parse_accept`` and its entries
        are tried in the order returned; the first one found in
        ``encoder.SUPPORTED_CONTENT_TYPES`` determines the output format.

        Args:
            prediction: a prediction result from predict_fn. A ``torch.Tensor``
                (or any subclass of it) is detached, moved to the CPU and
                converted to nested Python lists before encoding; any other
                object is passed to the encoder unchanged.
            accept: accept header value naming the type(s) the output data
                needs to be serialized to.

        Returns:
            The value produced by ``encoder.encode`` for the selected content
            type. For ``content_types.CSV`` that value is additionally passed
            through ``.encode("utf-8")`` before being returned.

        Raises:
            errors.UnsupportedFormatError: if no parsed accept entry is a
                supported content type.
        """
        # isinstance instead of an exact type comparison, so Tensor subclasses
        # (e.g. torch.nn.Parameter) are converted as well.
        if isinstance(prediction, torch.Tensor):
            # detach() drops autograd tracking and cpu() makes the data
            # host-resident; both are required before numpy() can be called.
            prediction = prediction.detach().cpu().numpy().tolist()

        for content_type in utils.parse_accept(accept):
            if content_type in encoder.SUPPORTED_CONTENT_TYPES:
                encoded_prediction = encoder.encode(prediction, content_type)
                if content_type == content_types.CSV:
                    encoded_prediction = encoded_prediction.encode("utf-8")
                return encoded_prediction

        # Nothing in the accept header matched a supported type.
        raise errors.UnsupportedFormatError(accept)
def test_parse_accept(input, expected):
    """Check that ``parse_accept`` maps ``input`` to ``expected``.

    NOTE(review): the arguments are presumably supplied by a pytest
    parametrize decorator outside this excerpt - confirm.
    """
    assert parse_accept(input) == expected