Example #1
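The snippets below read as Python 2 test methods excerpted from the Cloud ML Engine online-prediction test suite (note the `unicode()` call and str-based `base64.b64encode` in later examples). Since the class bodies and imports are elided, the following is a minimal sketch of the imports they appear to assume; the internal modules (`mlprediction`, `server_lib`, `tf_prediction_server_lib`) and helpers such as `create_app_config`, `make_response`, and `example` live elsewhere in that codebase, so their import paths are not shown:

    import base64
    import json

    import mock
    import numpy as np
    import tensorflow as tf
    import webapp2
    import webtest

    from tensorflow.core.framework import types_pb2
    from tensorflow.core.protobuf import meta_graph_pb2
    from tensorflow_serving.apis import get_model_metadata_pb2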
    def testLocalPredictionTensorflowModelWithStrings(self,
                                                      mock_create_client):
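        # Fake a serving signature whose input and output are both string
        # tensors under a "_bytes" alias.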

        signature_def = meta_graph_pb2.SignatureDef()
        signature_def.outputs["x_bytes"].dtype = types_pb2.DT_STRING
        signature_def.inputs["x_bytes"].dtype = types_pb2.DT_STRING

        mock_client = mock.Mock()
        mock_client.signature_map = {
            tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
            signature_def
        }
        mock_client.predict.return_value = {"x_bytes": "to_encode"}
        mock_create_client.return_value = mock_client
        predictions = mlprediction.local_predict(model_dir=None,
                                                 instances=[{
                                                     "x_bytes": [1, 2, 3]
                                                 }])
        # Validate that the output is correctly base64 encoded (and only once)
        self.assertEqual(predictions, {
            "predictions": [{
                "x_bytes": {
                    "b64": base64.b64encode("to_encode")
                }
            }]
        })
Example #2
    def testModelWithBytesBasedOutput(self):
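        # A DT_STRING output under a "_bytes" alias should come back
        # base64-wrapped as {"b64": ...}.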
        mock_client = mock.Mock()
        mock_client.predict.return_value = {"x_bytes": "to_encode"}
        signature_def = meta_graph_pb2.SignatureDef()
        signature_def.outputs["x_bytes"].dtype = types_pb2.DT_STRING
        signature_def.inputs["input_key"].dtype = types_pb2.DT_STRING
        mock_client.signature_map = {
            tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
            signature_def
        }

        model = mlprediction.create_model(mock_client, "gs://tmp/foo")
        _, predictions = model.predict({"input_key": "foo"})
        self.assertEqual(predictions, [{
            "x_bytes": {
                "b64": base64.b64encode("to_encode")
            }
        }])
Example #3
 def setUp(self):
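     # Shared fixture: a two-input/one-output int32 serving signature and a
     # mock client whose predict() always returns c = [10, 20].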
     self._instances = [{"a": 1, "b": 2}, {"a": 2, "b": 4}]
     self._model_path = "gs://dummy/model/path"
     signature_def = meta_graph_pb2.SignatureDef()
     signature_def.inputs["a"].dtype = types_pb2.DT_INT32
     signature_def.inputs["b"].dtype = types_pb2.DT_INT32
     signature_def.outputs["c"].dtype = types_pb2.DT_INT32
     self._mock_client = mock.Mock()
     self._mock_client.predict.return_value = {"c": np.array([10, 20])}
     self._mock_client.signature_map = {
         tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
         signature_def
     }
     self._kwargs = {
         "signature_name":
         tf.saved_model.signature_constants.
         DEFAULT_SERVING_SIGNATURE_DEF_KEY
     }
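A test built on this fixture might look like the sketch below; the method name, the use of `mlprediction.create_model`, and the expected unbatched output are assumptions inferred from the other examples:

 def testPredictReturnsPerInstanceRows(self):
     # Hypothetical usage of the fixture above: predictions for the two
     # instances should be unbatched into one dict per instance.
     model = mlprediction.create_model(self._mock_client, self._model_path)
     _, predictions = model.predict(self._instances, **self._kwargs)
     self.assertEqual(list(predictions), [{"c": 10}, {"c": 20}])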
Example #4
    def testCreateTFModelFromModelServerClient(self, mock_model_server):
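        # End to end: create_tf_model should wrap a ModelServer client in a
        # TensorflowModel that base64-encodes its string "_bytes" outputs.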
        env_updates = {"prediction_engine": "MODEL_SERVER"}
        flag_values = {"tensorflow_session_parallelism": 3}
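        # A dynamically built class stands in for the parsed command-line
        # flags object.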
        pseudo_flags = type("Flags", (object, ), flag_values)

        # Create model's SignatureDef and expected response.
        expected_response = get_model_metadata_pb2.GetModelMetadataResponse()
        expected_response.model_spec.name = "default"

        in_bytes = meta_graph_pb2.TensorInfo(name="x",
                                             dtype=tf.string.as_datatype_enum)
        out_bytes = meta_graph_pb2.TensorInfo(name="y",
                                              dtype=tf.string.as_datatype_enum)

        inputs = {"in_bytes": in_bytes}
        outputs = {"out_bytes": out_bytes}
        signatures_def = meta_graph_pb2.SignatureDef(inputs=inputs,
                                                     outputs=outputs)
        signatures_def_map = get_model_metadata_pb2.SignatureDefMap()
        signatures_def_map.signature_def["serving_default"].CopyFrom(
            signatures_def)
        expected_response.metadata["signature_def"].Pack(signatures_def_map)
        mock_model_server.GetModelMetadata.return_value = expected_response

        with mock.patch.dict("os.environ", env_updates):
            with mock.patch.object(tf_prediction_server_lib,
                                   "_start_model_server"):
                tf_prediction_server_lib._start_model_server.return_value = (
                    None, mock_model_server)
                model = tf_prediction_server_lib.create_tf_model(
                    "/dummy/model/path", pseudo_flags)

        # model is a TensorflowModel instance with a ModelServer client.
        expected_predict_response = make_response({"out_bytes": ["to encode"]})
        mock_model_server.Predict.return_value = expected_predict_response

        dummy_instances = []
        _, predictions = model.predict(dummy_instances,
                                       stats=mlprediction.Stats())
        self.assertEqual(list(predictions), [{
            "out_bytes": {
                u"b64": base64.b64encode("to encode")
            }
        }])
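`make_response` is a helper defined outside this excerpt. Assuming TF Serving's `predict_pb2`, a plausible reconstruction is the sketch below (the real helper may differ):

    from tensorflow_serving.apis import predict_pb2

    def make_response(outputs):
        # Hypothetical: wrap each output list in a TensorProto inside a
        # TF Serving PredictResponse.
        response = predict_pb2.PredictResponse()
        for alias, values in outputs.items():
            response.outputs[alias].CopyFrom(tf.make_tensor_proto(values))
        return response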
Example #5
    def testModelWithAdditionalOptions(self):
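        # A non-default signature_name in the request body must be forwarded
        # to the model's predict() call.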
        # Setup the app.
        mock_model = mock.Mock()
        mock_model.predict.return_value = [], []
        mock_model.signature_map = {
            "custom_signature": meta_graph_pb2.SignatureDef()
        }

        config = create_app_config(model=mock_model)
        inference_app = webapp2.WSGIApplication(
            [("/", server_lib._InferenceHandler)], debug=True, config=config)
        test_app = webtest.app.TestApp(app=inference_app)

        serialized_examples = [
            example(1, 3).SerializeToString(),
            example(2, -4).SerializeToString(),
            example(0, 0).SerializeToString()
        ]

        # Act.
        instances = [
            {
                "b64": base64.b64encode(serialized_examples[0])
            },
            {
                "b64": base64.b64encode(serialized_examples[1])
            },
            {
                "b64": base64.b64encode(serialized_examples[2])
            },
        ]
        body = {"instances": instances, "signature_name": "custom_signature"}
        test_app.post(url="/",
                      params=json.dumps(body),
                      content_type="application/json")

        # Assert.
        mock_model.predict.assert_has_calls([
            mock.call(serialized_examples,
                      stats=mock.ANY,
                      signature_name="custom_signature")
        ])
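The `example(a, b)` helper is also defined outside this excerpt; given how it is used above, a plausible sketch is a two-feature `tf.train.Example` (the feature names and int64 types are assumptions):

    def example(a, b):
        # Hypothetical reconstruction of the example() helper used above.
        return tf.train.Example(features=tf.train.Features(feature={
            "a": tf.train.Feature(int64_list=tf.train.Int64List(value=[a])),
            "b": tf.train.Feature(int64_list=tf.train.Int64List(value=[b])),
        }))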
Example #6
    def testGetModelSignatureMissingDtype(self, mock_model_server):
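        # A signature whose input tensor is missing its dtype should be
        # dropped from the returned signature map.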
        expected_response = get_model_metadata_pb2.GetModelMetadataResponse()
        expected_response.model_spec.name = "default"

        in_bytes = meta_graph_pb2.TensorInfo(name="x")
        out_bytes = meta_graph_pb2.TensorInfo(name="y",
                                              dtype=tf.string.as_datatype_enum)

        inputs = {"in_bytes": in_bytes}
        outputs = {"out_bytes": out_bytes}
        signatures_def = meta_graph_pb2.SignatureDef(inputs=inputs,
                                                     outputs=outputs)
        signatures_def_map = get_model_metadata_pb2.SignatureDefMap()
        signatures_def_map.signature_def["serving_default"].CopyFrom(
            signatures_def)
        expected_response.metadata["signature_def"].Pack(signatures_def_map)
        mock_model_server.GetModelMetadata.return_value = expected_response
        received_signatures = tf_prediction_server_lib._get_model_signature_map(
            mock_model_server)
        self.assertTrue("serving_default" not in received_signatures)
Example #7
    def testAppDecodesBytesIn(self):
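        # Incoming {"b64": ...} values under "*_bytes" keys are base64-decoded
        # before the instances reach the model.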
        mock_model = mock.Mock()
        signature_def = meta_graph_pb2.SignatureDef()
        signature_def.outputs["z_bytes"].dtype = types_pb2.DT_STRING
        mock_model.signature_map = {
            tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
            signature_def
        }

        # Setup the app.
        config = create_app_config(model=mock_model)
        inference_app = webapp2.WSGIApplication(
            [("/", server_lib._InferenceHandler)], debug=True, config=config)
        test_app = webtest.app.TestApp(app=inference_app)

        instances = [{
            u"x_bytes": {
                "b64": unicode(base64.b64encode("first"))
            },
            u"y_bytes": {
                "b64": unicode(base64.b64encode("second"))
            }
        }]
        predictions = [{"z_bytes": "some binary string"}]
        mock_model.predict.return_value = instances, predictions

        body = {"instances": instances}
        response = test_app.post(url="/",
                                 params=json.dumps(body),
                                 content_type="application/json")

        # Assert.
        # Inputs are automatically decoded
        expected_instances = [{u"x_bytes": "first", u"y_bytes": "second"}]
        mock_model.predict.assert_has_calls(
            [mock.call(expected_instances, stats=mock.ANY)])

        expected_predictions = {"predictions": predictions}
        self.assertEqual(response.body, json.dumps(expected_predictions))
Example #8
    def testModelWithOutputCannotJsonEncode(self):
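        # "\xe1" is not valid UTF-8, so a DT_STRING output without a "_bytes"
        # alias cannot be JSON-encoded; the response should explain the fix.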
        # Setup the app.
        mock_model = mock.Mock()
        mock_model.predict.return_value = [], [{"x": "\xe1"}]
        signature_def = meta_graph_pb2.SignatureDef()
        signature_def.outputs["x"].dtype = types_pb2.DT_STRING
        mock_model.signature_map = {
            tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
            signature_def
        }
        config = create_app_config(model=mock_model)
        inference_app = webapp2.WSGIApplication(
            [("/", server_lib._InferenceHandler)], debug=True, config=config)
        test_app = webtest.app.TestApp(app=inference_app)

        # Act.
        body = {"instances": []}
        response = test_app.post(url="/",
                                 params=json.dumps(body),
                                 content_type="application/json")
        self.assertIn("Failed to json encode the prediction response",
                      response)
        self.assertIn("suffix the alias of your output tensor with _bytes",
                      response)
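`create_app_config` (used in Examples #5, #7, and #8) is another elided helper. Since webapp2 passes the `config` dict through to its handlers, a minimal sketch might be the following; the config key is an assumption:

    def create_app_config(model):
        # Hypothetical: expose the model under test to _InferenceHandler via
        # the webapp2 application config.
        return {"model": model}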