示例#1
0
def do_inference(hostport, work_dir, concurrency, num_tests,
                 model_name='mnist', signature_name='predict_images',
                 timeout=5.0):
    """Tests PredictionService with concurrent requests.

    Sends ``num_tests`` single-image Predict requests asynchronously over one
    gRPC channel and returns the error rate reported by ``_ResultCounter``.

    Args:
      hostport: Host:port address of the PredictionService.
      work_dir: The full path of working directory for test data set.
      concurrency: Maximum number of concurrent requests.
      num_tests: Number of test images to use.
      model_name: Served model name placed in ``model_spec.name``.
      signature_name: SignatureDef key to invoke; it must accept the
        ``images`` input built below.
      timeout: Per-request deadline in seconds.

    Returns:
      The classification error rate.

    Raises:
      IOError: An error occurred processing test data set.
    """
    test_data_set = mnist_input_data.read_data_sets(work_dir).test
    channel = grpc.insecure_channel(hostport)
    try:
        stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
        result_counter = _ResultCounter(num_tests, concurrency)
        for _ in range(num_tests):
            request = predict_pb2.PredictRequest()
            request.model_spec.name = model_name
            request.model_spec.signature_name = signature_name
            image, label = test_data_set.next_batch(1)
            # Send one flattened image as a [1, num_pixels] batch.
            request.inputs['images'].CopyFrom(
                tf.contrib.util.make_tensor_proto(image[0],
                                                  shape=[1, image[0].size]))
            # Presumably waits until fewer than `concurrency` requests are in
            # flight (implementation not visible here).
            result_counter.throttle()
            result_future = stub.Predict.future(request, timeout)
            result_future.add_done_callback(
                _create_rpc_callback(label[0], result_counter))
        # NOTE(review): assumes get_error_rate() blocks until every callback
        # has reported (as in the upstream TF Serving MNIST client), so the
        # channel below is only closed once no RPC is still pending.
        return result_counter.get_error_rate()
    finally:
        # Release the channel's sockets/threads instead of leaking them.
        channel.close()
# Fallback base directory for the exported SavedModel, used when no export
# directory is given as the last positional command-line argument.
_DEFAULT_EXPORT_PATH_BASE = "/Users/xingoo/PycharmProjects/ml-in-action/实践-tensorflow/01-官方文档-学习和使用ML/save_model"


def main(_):
    """Train a softmax-regression MNIST classifier and export it as a SavedModel.

    Reads MNIST from ``FLAGS.work_dir``, runs ``FLAGS.training_iteration``
    gradient-descent steps on batches of 50, prints the accuracy on the test
    split, then writes a SavedModel under
    ``<export_path_base>/<FLAGS.model_version>``.

    Args:
      _: Command-line arguments remaining after flag parsing (e.g. as passed by
        ``tf.app.run``; element 0 is the program name). If there is more than
        one and the last does not start with '-', it is used as the export
        base directory; otherwise ``_DEFAULT_EXPORT_PATH_BASE`` is used.
    """
    # Validate flags before doing any expensive work.
    if FLAGS.training_iteration <= 0:
        print('Please specify a positive value for training iteration.')
        sys.exit(-1)
    if FLAGS.model_version <= 0:
        print('Please specify a positive value for version number.')
        sys.exit(-1)

    argv = _ or ()
    if len(argv) > 1 and not argv[-1].startswith('-'):
        export_path_base = argv[-1]
    else:
        export_path_base = _DEFAULT_EXPORT_PATH_BASE

    # Train model
    print('Training model...')

    mnist = mnist_input_data.read_data_sets(FLAGS.work_dir, one_hot=True)

    sess = tf.InteractiveSession()
    try:
        # Serving input: serialized tf.Example protos with a 784-float 'x'.
        serialized_tf_example = tf.placeholder(tf.string, name='tf_example')
        feature_configs = {
            'x': tf.FixedLenFeature(shape=[784], dtype=tf.float32),
        }
        tf_example = tf.parse_example(serialized_tf_example, feature_configs)
        x = tf.identity(tf_example['x'], name='x')  # use tf.identity() to assign name
        y_ = tf.placeholder('float', shape=[None, 10])
        w = tf.Variable(tf.zeros([784, 10]))
        b = tf.Variable(tf.zeros([10]))
        sess.run(tf.global_variables_initializer())
        logits = tf.matmul(x, w) + b
        y = tf.nn.softmax(logits, name='y')
        # Same quantity as -sum(y_ * log(softmax(logits))), but computed from
        # the logits so a zero probability cannot yield log(0) -> NaN.
        cross_entropy = tf.reduce_sum(
            tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_, logits=logits))
        train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
        # Scores and indices of all 10 classes, highest score first.
        values, indices = tf.nn.top_k(y, 10)
        # Map class index i to the string label str(i) for the classify outputs.
        table = tf.contrib.lookup.index_to_string_table_from_tensor(
            tf.constant([str(i) for i in range(10)]))
        prediction_classes = table.lookup(tf.to_int64(indices))
        for _step in range(FLAGS.training_iteration):
            batch = mnist.train.next_batch(50)
            train_step.run(feed_dict={x: batch[0], y_: batch[1]})
        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, 'float'))
        # Note: despite the label, this accuracy is measured on the test split.
        print('training accuracy %g' % sess.run(
            accuracy, feed_dict={
                x: mnist.test.images,
                y_: mnist.test.labels
            }))
        print('Done training!')

        # Export model
        # WARNING(break-tutorial-inline-code): The following code snippet is
        # in-lined in tutorials, please update tutorial documents accordingly
        # whenever code changes.
        export_path = os.path.join(
            tf.compat.as_bytes(export_path_base),
            tf.compat.as_bytes(str(FLAGS.model_version)))
        print('Exporting trained model to', export_path)
        # Configure the export directory and create the SavedModel builder.
        builder = tf.saved_model.builder.SavedModelBuilder(export_path)

        # Build the signature_def_map.
        # TensorInfo protos record each tensor's dtype, shape and name.
        classification_inputs = tf.saved_model.utils.build_tensor_info(
            serialized_tf_example)
        classification_outputs_classes = tf.saved_model.utils.build_tensor_info(
            prediction_classes)
        classification_outputs_scores = tf.saved_model.utils.build_tensor_info(
            values)

        # Classification signature: CLASSIFY method, serialized tf.Example in,
        # class labels and their scores out.
        classification_signature = (
            tf.saved_model.signature_def_utils.build_signature_def(
                inputs={
                    tf.saved_model.signature_constants.CLASSIFY_INPUTS:
                        classification_inputs
                },
                outputs={
                    tf.saved_model.signature_constants.CLASSIFY_OUTPUT_CLASSES:
                        classification_outputs_classes,
                    tf.saved_model.signature_constants.CLASSIFY_OUTPUT_SCORES:
                        classification_outputs_scores
                },
                method_name=tf.saved_model.signature_constants.CLASSIFY_METHOD_NAME))

        tensor_info_x = tf.saved_model.utils.build_tensor_info(x)
        tensor_info_y = tf.saved_model.utils.build_tensor_info(y)

        # Prediction signature: raw float images x in, softmax scores y out.
        prediction_signature = (
            tf.saved_model.signature_def_utils.build_signature_def(
                inputs={'images': tensor_info_x},
                outputs={'scores': tensor_info_y},
                method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))

        # Save the graph and variables under the SERVING tag with two signatures:
        #   'predict_images'  -> the PREDICT signature (float 'images' -> 'scores').
        #   serving_default   -> the CLASSIFY signature (serialized tf.Example ->
        #                        classes/scores); used when a client does not
        #                        name a signature.
        # main_op runs when the SavedModel is loaded; tables_initializer()
        # initializes the index->label lookup table used by the classify outputs.
        # strip_default_attrs removes default-valued op attributes so that older
        # TensorFlow/TF Serving binaries can still load the graph.
        builder.add_meta_graph_and_variables(
            sess, [tf.saved_model.tag_constants.SERVING],
            signature_def_map={
                'predict_images':
                    prediction_signature,
                tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                    classification_signature,
            },
            main_op=tf.tables_initializer(),
            strip_default_attrs=True)

        # Write the SavedModel to disk.
        builder.save()
    finally:
        # InteractiveSession is not a context manager; close it explicitly.
        sess.close()

    print('Done exporting!')
import numpy as np
import sys
import threading

from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2_grpc

# Command-line flags for the gRPC inference client.
# NOTE(review): main() reads FLAGS.training_iteration and FLAGS.model_version,
# which are not defined among the flags below -- confirm they are defined
# elsewhere before running main().
tf.app.flags.DEFINE_integer('concurrency', 1,
                            'maximum number of concurrent inference requests')
tf.app.flags.DEFINE_integer('num_tests', 100, 'Number of test images')
tf.app.flags.DEFINE_string('server', 'localhost:8500',
                           'PredictionService host:port')
tf.app.flags.DEFINE_string('work_dir', './tmp', 'Working directory. ')
FLAGS = tf.app.flags.FLAGS

# NOTE(review): these statements run at import time: importing this module reads
# the MNIST test set from FLAGS.work_dir and opens a gRPC channel to
# FLAGS.server, which is never closed here. do_inference() creates its own data
# set, channel and stub, so these globals look redundant -- verify nothing else
# in the file uses them before removing.
test_data_set = mnist_input_data.read_data_sets(FLAGS.work_dir).test
channel = grpc.insecure_channel(FLAGS.server)
stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)


class _ResultCounter(object):
    """Counter for the prediction results."""
    def __init__(self, num_tests: int, concurrency: int) -> None:
        """Initialize all counters to zero.

        Args:
          num_tests: Total number of prediction requests the caller will issue
            (do_inference passes its num_tests argument).
          concurrency: Maximum number of concurrent requests (do_inference
            passes its concurrency argument).
        """
        self._num_tests = num_tests
        self._concurrency = concurrency
        # Presumably incremented by inc_error() for each failed/wrong prediction.
        self._error = 0
        # Presumably the number of completed requests -- confirm in unseen methods.
        self._done = 0
        # Presumably the number of requests currently in flight -- confirm.
        self._active = 0
        # Condition variable; presumably guards the counters above and is used
        # to wait for free slots / completion -- confirm in unseen methods.
        self._condition = threading.Condition()

    def inc_error(self):
示例#4
0
import requests
import tensorflow as tf
import basic.mnist_input_data as mnist_input_data

# Send MNIST test images to a TensorFlow Serving REST endpoint and print replies.
headers = {"content-type": "application/json"}
url = 'http://localhost:8501/v1/models/mnist:predict'

test_data_set = mnist_input_data.read_data_sets('./tmp').test
print(type(test_data_set))

for _ in range(10):
    image, label = test_data_set.next_batch(1)
    print(tf.contrib.util.make_tensor_proto(image[0], shape=[1,
                                                             image[0].size]))

    # The body must be real JSON carrying the pixel values. numpy arrays are
    # not JSON-serializable, so convert to a nested Python list; `json=` makes
    # requests serialize the dict.
    body = {
        # NOTE(review): assumes the served 'mnist' model is the TF Serving MNIST
        # export, whose default signature takes serialized tf.Example strings;
        # 'predict_images' is the signature that accepts raw float images.
        "signature_name": "predict_images",
        # Columnar format: a single input tensor of shape [1, num_pixels].
        "inputs": [image[0].tolist()],
    }
    # Bound the wait, mirroring the 5-second deadline of the gRPC client.
    json_response = requests.post(url, json=body, headers=headers, timeout=5.0)
    print(json_response.text)