def init_request(address='192.168.10.100:8500', model_name="catdog",
                 signature_name="catdog_classification"):
    """Create a PredictionService stub and a PredictRequest template.

    Args:
        address: ``host:port`` of the TensorFlow Serving gRPC endpoint.
        model_name: value written to ``request.model_spec.name``.
        signature_name: value written to ``request.model_spec.signature_name``.

    Returns:
        ``(request, stub)``: a ``PredictRequest`` with only ``model_spec``
        filled in (callers must add inputs) and a ``PredictionServiceStub``
        bound to an insecure (plaintext) gRPC channel.
    """
    # Prints "Connecting to TensorFlow Serving...".
    print('正在连接Tensorflow Serving...')
    # NOTE: per the gRPC docs, channels connect lazily, so reaching the
    # "connected" message below does not prove the server is reachable;
    # errors surface on the first RPC.
    channel = grpc.insecure_channel(address)
    stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
    request = predict_pb2.PredictRequest()
    request.model_spec.name = model_name
    request.model_spec.signature_name = signature_name
    # Prints "Finished connecting to TensorFlow Serving...".
    print('完成连接Tensorflow Serving...')
    return request, stub
    def make_prediction(self, input_data, input_tensor_name, timeout=10.0, model_name=None, dtype='uint8'):
        """Send one Predict request and return every output as a NumPy array.

        Args:
            input_data: array-like exposing ``.astype`` (e.g. a NumPy array);
                it is cast to ``dtype`` before being sent.
            input_tensor_name: key of ``request.inputs`` to populate.
            timeout: forwarded to ``self.execute`` as the RPC timeout.
            model_name: model to query; ``'model'`` is used when falsy.
            dtype: element type the input is cast to and sent as. Defaults to
                ``'uint8'``. That cast drops fractional parts and does not
                preserve values outside 0..255, so pass e.g. ``'float32'``
                for normalised float inputs.

        Returns:
            dict mapping each output name in the response to the array
            produced by ``tf.make_ndarray`` for that TensorProto.
        """
        request = predict_pb2.PredictRequest()
        request.model_spec.name = model_name or 'model'

        tensor_proto = tf.make_tensor_proto(input_data.astype(dtype), dtype=dtype)
        copy_message(tensor_proto, request.inputs[input_tensor_name])
        response = self.execute(request, timeout=timeout)

        return {key: tf.make_ndarray(response.outputs[key]) for key in response.outputs}
# Example #3
# 0
def main():
    """Classify one grayscale image through a TensorFlow Serving gRPC endpoint.

    Reads host, port, model name/version and request timeout from ``FLAGS``,
    then does the following:

    1. Shows ``/ts/example3.bmp``.
    2. Sends it, scaled to [0, 1] and reshaped to 28x28, as the ``'image'``
       input of the ``serving_default`` signature.
    3. Prints the ``'classes'`` output (``int64_val``).

    Raises:
        IOError: if OpenCV cannot read the image file.
    """
    host = FLAGS.host
    port = FLAGS.port
    model_name = FLAGS.model_name
    model_version = FLAGS.model_version
    request_timeout = FLAGS.request_timeout

    filename = "example3.bmp"
    fullpath = os.path.join("/ts/", filename)
    # The context manager releases the file handle PIL keeps open.
    with Image.open(fullpath, 'r') as src_img:
        print('TF Processing source/reference image  %dx%d - %s.' % (src_img.size[0], src_img.size[1], src_img.format))
        src_img.show()

    # Create gRPC client and request
    channel = grpc.insecure_channel('%s:%s' % (host, port))
    stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
    request = predict_pb2.PredictRequest()
    request.model_spec.name = model_name
    # Use the requested model version; default to version 1 when the flag is unset.
    # NOTE(review): assumes FLAGS.model_version is an int or numeric string.
    request.model_spec.version.value = int(model_version) if model_version is not None else 1
    request.model_spec.signature_name = 'serving_default'

    # Flag 0 loads the image as single-channel grayscale. cv2.imread returns
    # None on failure instead of raising, so check explicitly.
    raw = cv2.imread(fullpath, 0)
    if raw is None:
        raise IOError('cv2.imread could not read %s' % fullpath)
    # Scale 8-bit pixels to [0, 1]; reshape raises unless there are 28*28 pixels.
    image = (numpy.asarray(raw, dtype=numpy.float32) / 255.).reshape(28, 28)
    print(image)
    copy_message(tf.contrib.util.make_tensor_proto(image), request.inputs['image'])
    print(request)

    # Send request (stub.Predict blocks until the response or the timeout).
    print('waiting response....')
    result = stub.Predict(request, request_timeout)
    print(' response received \r\n %s ' % (result))
    response = numpy.array(result.outputs['classes'].int64_val)
    print('prediction is %s ' % (response))
# @Author  : RIO
# @desc: Generates the invocation (client) script. You must create the /Data/muc/serving/tensorflow_serving/test folder yourself.
from grpc.beta import implementations
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import flags
from tensorflow_serving_client.protos import predict_pb2
from tensorflow_serving_client.protos import prediction_service_pb2
from tensorflow.python.framework import dtypes

flags.DEFINE_string('server', '127.0.0.1:9005', 'PredictionService host:port')
FLAGS = flags.FLAGS

# Number of rows sent to the model in a single request.
n_samples = 100

# NOTE(review): with absl-backed flags (later TF 1.x), reading FLAGS.server
# before the flags are parsed raises UnparsedFlagAccessError; confirm this
# works with the TF version in use.
# rsplit keeps bracketed IPv6 hosts such as "[::1]:9005" intact.
host, port = FLAGS.server.rsplit(':', 1)
channel = implementations.insecure_channel(host, int(port))
stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)

# Input for the 'deprecate' model: the float32 column vector
# [[0.], [1.], ..., [n_samples - 1]] with shape (n_samples, 1).
x_data = np.arange(n_samples, dtype=np.float32).reshape(n_samples, 1)

# Send request
request = predict_pb2.PredictRequest()
request.model_spec.name = 'deprecate'
# CopyFrom copies the TensorProto directly; no serialize/parse round-trip is needed.
# The declared shape is derived from n_samples so it always matches x_data.
request.inputs['x'].CopyFrom(
    tf.contrib.util.make_tensor_proto(x_data,
                                      dtype=dtypes.float32,
                                      shape=[n_samples, 1]))
result = stub.Predict(request, 10.0)  # 10 secs timeout
print(result)