Example #1
0
def multi_inference(body):
    """Run a MultiInference RPC against the prediction service.

    Args:
        body: dict in protobuf JSON format describing a
            MultiInferenceRequest.

    Returns:
        The MultiInferenceResponse converted to a dict, keeping proto
        field names and including default-valued fields.
    """
    stub = _get_prediction_service_stub()
    inference_request = ParseDict(body, inference.MultiInferenceRequest())
    response = stub.MultiInference(inference_request)
    return MessageToDict(
        response,
        preserving_proto_field_name=True,
        including_default_value_fields=True,
    )
Example #2
0
  def testMultiInference(self):
    """Test PredictionService.MultiInference implementation."""
    model_path = self._GetSavedModelBundlePath()
    model_server_address = TensorflowModelServerTest.RunServer(
        'default', model_path)[1]

    # NOTE(review): replaced the Python 2-only `print` statement and
    # `xrange`, and the deprecated `assertEquals` alias, with forms that
    # are valid on Python 3 as well (matching the py3 variant of this
    # test elsewhere in this listing).
    print('Sending MultiInference request...')
    # Prepare a request with two tasks against the same 'default' model:
    # one regression and one classification over a single input example.
    request = inference_pb2.MultiInferenceRequest()
    request.tasks.add().model_spec.name = 'default'
    request.tasks[0].model_spec.signature_name = 'regress_x_to_y'
    request.tasks[0].method_name = 'tensorflow/serving/regress'
    request.tasks.add().model_spec.name = 'default'
    request.tasks[1].model_spec.signature_name = 'classify_x_to_y'
    request.tasks[1].method_name = 'tensorflow/serving/classify'

    example = request.input.example_list.examples.add()
    example.features.feature['x'].float_list.value.extend([2.0])

    # Send request
    result = self._MakeStub(model_server_address).MultiInference(
        request, RPC_TIMEOUT)

    # Verify response: both tasks ran, and each yields the expected value
    # (3.0 for input 2.0, per the test model's signatures).
    self.assertEqual(2, len(result.results))
    expected_output = 3.0
    self.assertEqual(expected_output,
                     result.results[0].regression_result.regressions[0].value)
    self.assertEqual(expected_output, result.results[
        1].classification_result.classifications[0].classes[0].score)
    for i in range(2):
      self._VerifyModelSpec(result.results[i].model_spec,
                            request.tasks[i].model_spec.name,
                            request.tasks[i].model_spec.signature_name,
                            self._GetModelVersion(model_path))
    def parse_request(self, serialized_data):
        """Deserialize `serialized_data` into the request proto that
        matches `self.prediction_type`.

        Args:
            serialized_data: bytes, a serialized protobuf request message.

        Returns:
            The parsed request message (PredictRequest,
            MultiInferenceRequest, ClassificationRequest, or
            RegressionRequest).

        Raises:
            KeyError: if `self.prediction_type` is not one of the known
                prediction types.
        """
        # Map each prediction type directly to its message class. The
        # zero-arg lambda wrappers in the original were unnecessary: the
        # message constructors are already the factories being wrapped.
        request_cls_map = {
            PREDICT: predict_pb2.PredictRequest,
            INFERENCE: inference_pb2.MultiInferenceRequest,
            CLASSIFY: classification_pb2.ClassificationRequest,
            REGRESSION: regression_pb2.RegressionRequest,
        }

        request = request_cls_map[self.prediction_type]()
        request.ParseFromString(serialized_data)

        return request
Example #4
0
    def testMultiInference(self):
        """Test PredictionService.MultiInference implementation."""
        model_path = self._GetSavedModelBundlePath()
        enable_batching = False

        atexit.register(self.TerminateProcs)
        model_server_address = self.RunServer(PickUnusedPort(), 'default',
                                              model_path, enable_batching)

        # NOTE(review): replaced the Python 2-only `print` statement and
        # `xrange`, and the deprecated `assertEquals` alias, with
        # Python 3-compatible equivalents. The beta gRPC API calls below
        # are left as-is to avoid changing runtime dependencies.
        print('Sending MultiInference request...')
        # Prepare a request with two tasks against the same 'default'
        # model: one regression, one classification, over one example.
        request = inference_pb2.MultiInferenceRequest()
        request.tasks.add().model_spec.name = 'default'
        request.tasks[0].model_spec.signature_name = 'regress_x_to_y'
        request.tasks[0].method_name = 'tensorflow/serving/regress'
        request.tasks.add().model_spec.name = 'default'
        request.tasks[1].model_spec.signature_name = 'classify_x_to_y'
        request.tasks[1].method_name = 'tensorflow/serving/classify'

        example = request.input.example_list.examples.add()
        example.features.feature['x'].float_list.value.extend([2.0])

        # Send request
        host, port = model_server_address.split(':')
        channel = implementations.insecure_channel(host, int(port))
        stub = prediction_service_pb2.beta_create_PredictionService_stub(
            channel)
        result = stub.MultiInference(request, RPC_TIMEOUT)  # 5 secs timeout

        # Verify response: both tasks ran, and each yields the expected
        # value (3.0 for input 2.0, per the test model's signatures).
        self.assertEqual(2, len(result.results))
        expected_output = 3.0
        self.assertEqual(
            expected_output,
            result.results[0].regression_result.regressions[0].value)
        self.assertEqual(
            expected_output, result.results[1].classification_result.
            classifications[0].classes[0].score)
        for i in range(2):
            self._VerifyModelSpec(result.results[i].model_spec,
                                  request.tasks[i].model_spec.name,
                                  request.tasks[i].model_spec.signature_name,
                                  self._GetModelVersion(model_path))
Example #5
0
    def _TestMultiInference(self, model_path):
        """Test PredictionService.MultiInference implementation."""
        server_address = TensorflowModelServerTest.RunServer(
            'default', model_path)[1]

        print('Sending MultiInference request...')
        # Build a request carrying one regression task and one
        # classification task, both against the 'default' model and the
        # same single input example.
        request = inference_pb2.MultiInferenceRequest()
        task_specs = (('regress_x_to_y', 'tensorflow/serving/regress'),
                      ('classify_x_to_y', 'tensorflow/serving/classify'))
        for signature_name, method_name in task_specs:
            task = request.tasks.add()
            task.model_spec.name = 'default'
            task.model_spec.signature_name = signature_name
            task.method_name = method_name

        example = request.input.example_list.examples.add()
        example.features.feature['x'].float_list.value.extend([2.0])

        # Issue the RPC over an insecure channel.
        channel = grpc.insecure_channel(server_address)
        stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
        result = stub.MultiInference(request, RPC_TIMEOUT)  # 5 secs timeout

        # Both tasks must have run and produced the expected value
        # (3.0 for input 2.0).
        self.assertEqual(2, len(result.results))
        expected_output = 3.0
        self.assertEqual(
            expected_output,
            result.results[0].regression_result.regressions[0].value)
        self.assertEqual(
            expected_output, result.results[1].classification_result.
            classifications[0].classes[0].score)
        for i, task in enumerate(request.tasks):
            self._VerifyModelSpec(result.results[i].model_spec,
                                  task.model_spec.name,
                                  task.model_spec.signature_name,
                                  self._GetModelVersion(model_path))