Esempio n. 1
0
  def testBuildRequests_KerasModel(self):
    """Checks request building for the Keras-model signatures fixture.

    Reading one example should produce exactly one PredictRequest whose
    model_spec carries the model name 'foo' and the 'serving_default'
    signature key.
    """
    keras_signatures = self._GetKerasModelSignature()
    builder = request_builder._TFServingRpcRequestBuilder(
        model_name='foo', signatures=keras_signatures)
    builder.ReadExamplesArtifact(self._examples, num_examples=1)

    requests = builder.BuildRequests()

    self.assertEqual(len(requests), 1)
    request = requests[0]
    self.assertIsInstance(request, predict_pb2.PredictRequest)
    spec = request.model_spec
    self.assertEqual(spec.name, 'foo')
    self.assertEqual(spec.signature_name, 'serving_default')
Esempio n. 2
0
  def testBuildRequests_EstimatorModel_ServingDefault(self):
    """Checks request building for the default Estimator-model signatures.

    With the fixture's default signature set and one example read, the
    builder should emit a single ClassificationRequest addressed to model
    'foo' under the 'serving_default' signature key.
    """
    estimator_signatures = self._GetEstimatorModelSignature()
    builder = request_builder._TFServingRpcRequestBuilder(
        model_name='foo', signatures=estimator_signatures)
    builder.ReadExamplesArtifact(self._examples, num_examples=1)

    requests = builder.BuildRequests()

    self.assertEqual(len(requests), 1)
    request = requests[0]
    self.assertIsInstance(request, classification_pb2.ClassificationRequest)
    spec = request.model_spec
    self.assertEqual(spec.name, 'foo')
    self.assertEqual(spec.signature_name, 'serving_default')
Esempio n. 3
0
  def testBuildRequests_EstimatorModel_Regression(self):
    """Checks request building when only the 'regression' signature exists.

    The builder should emit a single RegressionRequest addressed to model
    'foo' with signature key 'regression'.
    """
    regression_only = self._GetEstimatorModelSignature(
        signature_names=['regression'])
    builder = request_builder._TFServingRpcRequestBuilder(
        model_name='foo', signatures=regression_only)
    builder.ReadExamplesArtifact(self._examples, num_examples=1)

    requests = builder.BuildRequests()

    self.assertEqual(len(requests), 1)
    request = requests[0]
    self.assertIsInstance(request, regression_pb2.RegressionRequest)
    spec = request.model_spec
    self.assertEqual(spec.name, 'foo')
    self.assertEqual(spec.signature_name, 'regression')
Esempio n. 4
0
    def testBuildRequests_PredictMethod_FailOnInvalidSignature(self):
        """A predict signature without a serialized-example input is rejected.

        The signature's only input is a DT_FLOAT tensor of shape (-1, 784).
        That is not the accepted form for serialized tf.Example (dtype
        DT_STRING with shape (None,)). BuildRequests must therefore raise
        ValueError.
        """

        def tensor_shape(*sizes):
            # Produces the {'dim': [{'size': ...}, ...]} mapping consumed by
            # _make_signature_def, one entry per dimension in order.
            return {'dim': [{'size': size} for size in sizes]}

        float_input = {
            'name': 'serving_default_input:0',
            'dtype': 'DT_FLOAT',
            'tensor_shape': tensor_shape(-1, 784),
        }
        float_output = {
            'name': 'StatefulPartitionedCall:0',
            'dtype': 'DT_FLOAT',
            'tensor_shape': tensor_shape(-1, 10),
        }
        invalid_signature = _make_signature_def({
            'method_name': 'tensorflow/serving/predict',
            'inputs': {'x': float_input},
            'outputs': {'y': float_output},
        })

        builder = request_builder._TFServingRpcRequestBuilder(
            model_name='foo',
            signatures={'serving_default': invalid_signature})
        builder.ReadExamplesArtifact(self._examples, num_examples=1)

        with self.assertRaisesRegex(
                ValueError,
                'Unable to find valid input key from SignatureDef'):
            builder.BuildRequests()
Esempio n. 5
0
    def testBuildRequests_PredictMethod(self):
        """A predict signature with one serialized-example input is accepted.

        The signature declares a single DT_STRING input 'x' of shape (-1,).
        The builder should produce exactly one PredictRequest that:
          * targets model 'foo' under the 'serving_default' signature key,
          * populates only the declared input 'x',
          * holds that input as a string tensor.
        """
        builder = request_builder._TFServingRpcRequestBuilder(
            model_name='foo',
            signatures={
                # Has only one argument with dtype=DT_STRING and shape=(None,).
                # This is the only valid form that InfraValidator accepts today.
                'serving_default':
                _make_signature_def({
                    'method_name': 'tensorflow/serving/predict',
                    'inputs': {
                        'x': {
                            'name': 'serving_default_examples:0',
                            'dtype': 'DT_STRING',
                            'tensor_shape': {
                                'dim': [
                                    {
                                        'size': -1
                                    },
                                ]
                            }
                        }
                    },
                    'outputs': {
                        'y': {
                            'name': 'StatefulPartitionedCall:0',
                            'dtype': 'DT_FLOAT',
                            'tensor_shape': {
                                'dim': [
                                    {
                                        'size': -1
                                    },
                                    {
                                        'size': 10
                                    },
                                ]
                            }
                        }
                    },
                })
            })
        builder.ReadExamplesArtifact(self._examples, num_examples=1)

        result = builder.BuildRequests()

        self.assertEqual(len(result), 1)
        request = result[0]
        self.assertIsInstance(request, predict_pb2.PredictRequest)
        # The builder should propagate the model name and signature key into
        # model_spec, as it does for every other request type.
        self.assertEqual(request.model_spec.name, 'foo')
        self.assertEqual(request.model_spec.signature_name, 'serving_default')
        # Assert the key set before indexing: indexing a message-valued proto
        # map with a missing key silently inserts a default entry instead of
        # failing with a clear message.
        self.assertCountEqual(request.inputs.keys(), ['x'])
        self.assertEqual(request.inputs['x'].dtype,
                         tf.dtypes.string.as_datatype_enum)
Esempio n. 6
0
  def testBuildRequests_EstimatorModel_Predict(self):
    """Checks request building when only the 'predict' signature exists.

    The builder should emit one PredictRequest addressed to model 'foo' with
    signature key 'predict'. That request should carry exactly one input,
    whose tensor dtype is string. The input's key name is not pinned here.
    """
    predict_only = self._GetEstimatorModelSignature(
        signature_names=['predict'])
    builder = request_builder._TFServingRpcRequestBuilder(
        model_name='foo', signatures=predict_only)
    builder.ReadExamplesArtifact(self._examples, num_examples=1)

    requests = builder.BuildRequests()

    self.assertEqual(len(requests), 1)
    request = requests[0]
    self.assertIsInstance(request, predict_pb2.PredictRequest)
    spec = request.model_spec
    self.assertEqual(spec.name, 'foo')
    self.assertEqual(spec.signature_name, 'predict')
    self.assertEqual(len(request.inputs), 1)
    # Exactly one key exists (asserted above); take it without copying the keys.
    sole_key = next(iter(request.inputs))
    self.assertEqual(request.inputs[sole_key].dtype,
                     tf.dtypes.string.as_datatype_enum)