def init_request(address='192.168.10.100:8500', model_name="catdog",
                 signature_name="catdog_classification"):
    """Build a PredictRequest and a PredictionService stub for TensorFlow Serving.

    Defaults reproduce the previously hard-coded endpoint and model, so
    existing ``init_request()`` callers behave exactly as before.

    Args:
        address: ``host:port`` of the TensorFlow Serving gRPC endpoint.
        model_name: value written to ``request.model_spec.name``.
        signature_name: value written to ``request.model_spec.signature_name``.
            An alternative is the SavedModel default serving signature key
            (``tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY``).

    Returns:
        A ``(request, stub)`` tuple. ``request.inputs`` is still empty and must
        be filled by the caller before calling ``stub.Predict``. The channel
        itself is not returned; it is only reachable through the stub.
    """
    print('正在连接Tensorflow Serving...')
    # NOTE: gRPC channels connect lazily, so creating the channel does not
    # contact the server. The "done" message below therefore does not prove
    # that the server is reachable.
    channel = grpc.insecure_channel(address)
    stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
    request = predict_pb2.PredictRequest()
    request.model_spec.name = model_name
    request.model_spec.signature_name = signature_name
    print('完成连接Tensorflow Serving...')
    return request, stub
def make_prediction(self, input_data, input_tensor_name, timeout=10.0, model_name=None):
    """Send one Predict RPC carrying ``input_data`` and return all outputs as arrays.

    Args:
        input_data: array-like object exposing ``astype``. It is cast to uint8
            before being packed into a TensorProto, so fractional or
            out-of-range values are not preserved.
        input_tensor_name: key under which the tensor is stored in
            ``request.inputs``.
        timeout: RPC timeout forwarded to ``self.execute``.
        model_name: served model name; ``'model'`` is used when it is falsy.

    Returns:
        dict mapping every key of ``response.outputs`` to the NumPy array
        produced by ``tf.make_ndarray`` for that output.
    """
    target_model = model_name if model_name else 'model'
    predict_request = predict_pb2.PredictRequest()
    predict_request.model_spec.name = target_model

    as_bytes = input_data.astype(dtype=np.uint8)
    input_proto = tf.make_tensor_proto(as_bytes, dtype='uint8')
    copy_message(input_proto, predict_request.inputs[input_tensor_name])

    reply = self.execute(predict_request, timeout=timeout)
    return {name: tf.make_ndarray(reply.outputs[name]) for name in reply.outputs}
def main():
    """Classify one grayscale image with TensorFlow Serving over gRPC.

    Connection and model settings come from FLAGS (host, port, model_name,
    model_version, request_timeout). The image ``/ts/example3.bmp`` is
    displayed, scaled to [0, 1], reshaped to 28x28, and sent as the
    ``'image'`` input of the ``serving_default`` signature. The
    ``'classes'`` output is printed.

    Raises:
        IOError: if OpenCV cannot decode the image file.
        ValueError: if the image does not contain exactly 28*28 pixels.
    """
    host = FLAGS.host
    port = FLAGS.port
    model_name = FLAGS.model_name
    model_version = FLAGS.model_version
    request_timeout = FLAGS.request_timeout

    filename = "example3.bmp"
    fullpath = os.path.join("/ts/", filename)

    # The context manager releases the file handle that Image.open keeps open.
    with Image.open(fullpath, 'r') as src_img:
        print('TF Processing source/reference image %dx%d - %s.' % (src_img.size[0], src_img.size[1], src_img.format))
        src_img.show()

    # Create gRPC client and request
    channel = grpc.insecure_channel('%s:%s' % (host, port))
    stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
    request = predict_pb2.PredictRequest()
    request.model_spec.name = model_name
    # Honor FLAGS.model_version; fall back to version 1 when the flag is unset.
    request.model_spec.version.value = 1 if model_version is None else int(model_version)
    request.model_spec.signature_name = 'serving_default'

    # cv2 read flag 0 loads the image as a single-channel (grayscale) array.
    raw = cv2.imread(fullpath, 0)
    if raw is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise IOError('Failed to read image: %s' % fullpath)
    data = numpy.asarray(raw, dtype=numpy.float32) / 255.
    # Presumably the model's expected input size; reshape raises ValueError
    # when the pixel count is not 28*28.
    flat = data.reshape(28, 28)
    print(flat)
    copy_message(tf.make_tensor_proto(flat), request.inputs['image'])
    print(request)

    # Send request. Predict blocks until the reply arrives, so announce the
    # wait before issuing the call rather than after it has returned.
    print('waiting response....')
    result = stub.Predict(request, timeout=request_timeout)
    print(' response received \r\n %s ' % (result))
    response = numpy.array(result.outputs['classes'].int64_val)
    print('prediction is %s ' % (response))
# @Author : RIO
# @desc: Generate the client invocation script (the directory
#        /Data/muc/serving/tensorflow_serving/test must be created manually first).
from grpc.beta import implementations
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import flags
from tensorflow_serving_client.protos import predict_pb2
from tensorflow_serving_client.protos import prediction_service_pb2
from tensorflow.python.framework import dtypes

flags.DEFINE_string('server', '127.0.0.1:9005', 'PredictionService host:port')
FLAGS = flags.FLAGS

n_samples = 100

host, port = FLAGS.server.split(':')
channel = implementations.insecure_channel(host, int(port))
stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)

# Generate input data for the 'deprecate' model: a float32 column vector
# holding 0, 1, ..., n_samples - 1.
x_data = np.arange(n_samples, step=1, dtype=np.float32)
x_data = np.reshape(x_data, (n_samples, 1))

# Send request
request = predict_pb2.PredictRequest()
request.model_spec.name = 'deprecate'
# The shape is derived from n_samples so it always matches x_data.
# The TensorProto built by TensorFlow and the one in tensorflow_serving_client's
# protos may be distinct Python classes, so copy via serialize/parse rather
# than CopyFrom.
request.inputs['x'].ParseFromString(
    tf.contrib.util.make_tensor_proto(
        x_data, dtype=dtypes.float32, shape=[n_samples, 1]).SerializeToString())
result = stub.Predict(request, 10.0)  # 10 secs timeout