Esempio n. 1
0
def run_reload_one_config():
    """
    Exercise TF Serving's gRPC reload-config endpoint with a single model.

    Builds a ModelServerConfig containing one ModelConfig entry and sends it
    via ModelService.HandleReloadConfigRequest, then prints the server's verdict.
    :return: None
    """
    channel = grpc.insecure_channel('yourip:port')  # NOTE(review): placeholder address — fill in before use
    stub = model_service_pb2_grpc.ModelServiceStub(channel)
    request = model_management_pb2.ReloadConfigRequest(
    )  ##message ReloadConfigRequest
    model_server_config = model_server_config_pb2.ModelServerConfig()

    config_list = model_server_config_pb2.ModelConfigList(
    )  ##message ModelConfigList
    ####try to add one config
    one_config = config_list.config.add()  # new ModelConfig entry inside the list
    one_config.name = "test1"
    one_config.base_path = "/home/model/model1"
    one_config.model_platform = "tensorflow"

    model_server_config.model_config_list.CopyFrom(config_list)  #one of (the ModelServerConfig 'config' oneof)

    request.config.CopyFrom(model_server_config)

    print(request.IsInitialized())
    print(request.ListFields())

    responese = stub.HandleReloadConfigRequest(request, 10)  # 10s gRPC deadline
    if responese.status.error_code == 0:
        print("reload sucessfully")
    else:
        print("reload error!")
        print(responese.status.error_code)
        print(responese.status.error_message)
Esempio n. 2
0
def run():
  """Reload TF Serving's model config over gRPC, serving one SVM model.

  The target server address is taken from the first command-line argument.
  """
  # Stub bound to the address supplied on the command line.
  stub = model_service_pb2_grpc.ModelServiceStub(
      grpc.insecure_channel(sys.argv[1]))

  # Build a ModelConfigList holding exactly one model entry.
  model_list = model_server_config_pb2.ModelConfigList()  ##message ModelConfigList
  entry = model_list.config.add()
  entry.name = "svm_cls"
  entry.base_path = "/models/svm_cls"
  entry.model_platform = "tensorflow"
  #entry.model_version_policy.specific.versions.append(4)

  server_config = model_server_config_pb2.ModelServerConfig()
  server_config.model_config_list.CopyFrom(model_list)

  request = model_management_pb2.ReloadConfigRequest()  ##message ReloadConfigRequest
  request.config.CopyFrom(server_config)

  reply = stub.HandleReloadConfigRequest(request, 10)  # 10s gRPC deadline
  if reply.status.error_code == 0:
      print("Reload sucessfully")
  else:
      print("Reload failed!")
      print(reply.status.error_code)
      print(reply.status.error_message)
Esempio n. 3
0
def parse_config_file():
    """
    Read a text-format model config file, push it to TF Serving via gRPC,
    and write the re-serialized config out to x.txt.
    :return: None
    """
    with open("model_config.ini", "r") as f:
        config_ini = f.read()
    print(config_ini)

    channel = grpc.insecure_channel('yourip:port')  # NOTE(review): placeholder address — fill in before use
    stub = model_service_pb2_grpc.ModelServiceStub(channel)
    request = model_management_pb2.ReloadConfigRequest(
    )  ##message ReloadConfigRequest
    model_server_config = model_server_config_pb2.ModelServerConfig()
    x = text_format.Parse(text=config_ini,
                          message=model_server_config)  # not an officially documented approach
    x = text_format.MessageToString(model_server_config)  # round-trip back to text form
    with open("x.txt", "w+") as f:
        f.write(x)
    print(x)
    print(model_server_config.IsInitialized())
    request.config.CopyFrom(model_server_config)
    # print(request.ListFields())
    responese = stub.HandleReloadConfigRequest(request, 10)  # 10s gRPC deadline
    if responese.status.error_code == 0:
        print("reload sucessfully")
    else:
        print("reload error!")
        print(responese.status.error_code)
        print(responese.status.error_message)
Esempio n. 4
0
def updateConfigurations():
    """Push the model config from ``models.conf`` to a running TF Serving instance.

    Reads a text-format ModelServerConfig from ``models.conf`` and sends it via
    the gRPC ``HandleReloadConfigRequest`` API.

    Returns:
        dict: ``{"status": 200, "message": ...}`` on success, otherwise the
        error code and message reported by the server.
    """
    channel = grpc.insecure_channel(const.host + ":" + const.port)
    stub = model_service_pb2_grpc.ModelServiceStub(channel)
    request = model_management_pb2.ReloadConfigRequest()
    model_server_config = model_server_config_pb2.ModelServerConfig()

    # Read the desired configuration; a context manager guarantees the file
    # handle is closed (the original `open(...).read()` leaked it). The unused
    # ModelConfigList local was also dropped.
    with open("models.conf", "r") as f:
        configurations = f.read()
    model_server_config = text_format.Parse(text=configurations,
                                            message=model_server_config)

    request.config.CopyFrom(model_server_config)

    print(request.IsInitialized())
    print(request.ListFields())

    response = stub.HandleReloadConfigRequest(request, 10)  # 10s gRPC deadline
    if response.status.error_code == 0:
        return {"status": 200, "message": "Reload sucessfully"}
    else:
        return {
            "status": response.status.error_code,
            "message": response.status.error_message
        }
Esempio n. 5
0
 def __init__(self, endpoint: Text, model_name: Text):
   """Create gRPC stubs for TF Serving's model and prediction services.

   Args:
     endpoint: "host:port" address of the model server.
     model_name: Name of the model this client targets.
   """
   # Note that the channel instance is automatically closed (unsubscribed) on
   # deletion, so we don't have to manually close this on __del__.
   self._channel = grpc.insecure_channel(endpoint)
   self._model_name = model_name
   self._model_service = model_service_pb2_grpc.ModelServiceStub(self._channel)
   self._prediction_service = prediction_service_pb2_grpc.PredictionServiceStub(self._channel)  # pylint: disable=line-too-long
Esempio n. 6
0
def read_add_rewrite_config():
    """
    Read a config file, append one new model entry, reload TF Serving,
    and write the updated config back to the same file.

    test: read a config_file and add a new model, then reload and rewrite to the config_file
    :return: None
    """
    with open("model_config.ini", "r") as f:
        config_ini = f.read()
    #print(config_ini)
    model_server_config = model_server_config_pb2.ModelServerConfig()
    model_server_config = text_format.Parse(text=config_ini,
                                            message=model_server_config)
    one_config = model_server_config.model_config_list.config.add(
    )  # add one more config
    one_config.name = "test2"
    one_config.base_path = "/home/sparkingarthur/tools/tf-serving-custom/serving/my_models/test2"
    one_config.model_platform = "tensorflow"

    print(model_server_config)

    channel = grpc.insecure_channel('10.200.24.101:8009')
    stub = model_service_pb2_grpc.ModelServiceStub(channel)
    request = model_management_pb2.ReloadConfigRequest()
    request.config.CopyFrom(model_server_config)
    responese = stub.HandleReloadConfigRequest(request, 10)  # 10s gRPC deadline
    if responese.status.error_code == 0:
        print("reload sucessfully")
    else:
        print("reload error!")
        print(responese.status.error_code)
        print(responese.status.error_message)
    # Rewrite the file with the new entry regardless of the reload outcome.
    new_config = text_format.MessageToString(model_server_config)
    with open("model_config.ini", "w+") as f:
        f.write(new_config)
Esempio n. 7
0
    def __init__(self, address: str):
        """
        TensorFlow Serving API for loading/unloading/reloading TF models and for running predictions.

        Extra arguments passed to the tensorflow/serving container:
            * --max_num_load_retries=0
            * --load_retry_interval_micros=30000000 # 30 seconds
            * --grpc_channel_arguments="grpc.max_concurrent_streams=<processes-per-api-replica>*<threads-per-process>" when inf == 0, otherwise
            * --grpc_channel_arguments="grpc.max_concurrent_streams=<threads-per-process>" when inf > 0.

        Args:
            address: An address with the "host:port" format.

        Raises:
            NameError: If the tensorflow_serving_api / tensorflow packages
                are not installed.
        """

        if not tensorflow_dependencies_installed:
            raise NameError("tensorflow_serving_api and tensorflow packages not installed")

        self.address = address
        self.models = (
            {}
        )  # maps the model ID to the model metadata (signature def, signature key and so on)

        # One shared insecure channel backs both service stubs.
        self.channel = grpc.insecure_channel(self.address)
        self._service = model_service_pb2_grpc.ModelServiceStub(self.channel)
        self._pred = prediction_service_pb2_grpc.PredictionServiceStub(self.channel)  # pylint: disable=line-too-long
Esempio n. 8
0
def update():
    """Reload TF Serving's model config from ``models/model.config``.

    Reads the text-format ModelServerConfig from disk, sends it to the server
    at localhost:8500 via ``HandleReloadConfigRequest``, and prints the outcome.
    """
    config_file = "models/model.config"
    host_port = "localhost:8500"

    channel = grpc.insecure_channel(host_port)
    stub = model_service_pb2_grpc.ModelServiceStub(channel)
    request = model_management_pb2.ReloadConfigRequest()

    # Read config file; a context manager guarantees the handle is closed
    # (the original `open(...).read()` leaked it).
    with open(config_file, "r") as f:
        config_content = f.read()
    model_server_config = model_server_config_pb2.ModelServerConfig()
    model_server_config = text_format.Parse(text=config_content, message=model_server_config)

    request.config.CopyFrom(model_server_config)
    request_response = stub.HandleReloadConfigRequest(request, 10)  # 10s gRPC deadline

    if request_response.status.error_code == 0:
        print("TF Serving config file updated.")
    else:
        print("Failed to update config file.")
        print(request_response.status.error_code)
        print(request_response.status.error_message)
def main(_):
    """Dispatch on FLAGS.mode: query model status, push a new model config,
    or publish serving metadata to ZooKeeper (which returns early).
    """
    if MODE.STATUS == FLAGS.mode:
        request = get_model_status_pb2.GetModelStatusRequest()
        request.model_spec.name = 'detection'
        request.model_spec.signature_name = 'serving_default'
    elif MODE.CONFIG == FLAGS.mode:
        request = model_management_pb2.ReloadConfigRequest()
        config = request.config.model_config_list.config.add()
        config.name = 'detection'
        config.base_path = '/models/detection/detection'
        config.model_platform = 'tensorflow'
        # Pin the detection model to exactly versions 5 and 7.
        config.model_version_policy.specific.versions.append(5)
        config.model_version_policy.specific.versions.append(7)
        config2 = request.config.model_config_list.config.add()
        config2.name = 'pascal'
        config2.base_path = '/models/detection/pascal'
        config2.model_platform = 'tensorflow'
    elif MODE.ZOOKEEPER == FLAGS.mode:
        # Publish connection/config info to ZooKeeper instead of calling TF Serving.
        zk = KazooClient(hosts="10.10.67.225:2181")
        zk.start()
        zk.ensure_path('/serving/cunan')
        zk.set(
            '/serving/cunan',
            get_config('detection', 5, 224, 'serving_default',
                       ','.join(get_classes('model_data/cci.names')),
                       "10.12.102.32:8000"))
        return
    # NOTE(review): if FLAGS.mode matches none of the branches above, `request`
    # and `result` are unbound below — confirm the mode flag is validated upstream.
    for address in FLAGS.addresses:
        channel = grpc.insecure_channel(address)
        stub = model_service_pb2_grpc.ModelServiceStub(channel)
        if MODE.STATUS == FLAGS.mode:
            result = stub.GetModelStatus(request)
        elif MODE.CONFIG == FLAGS.mode:
            result = stub.HandleReloadConfigRequest(request)
        print(result)
def main(config_path, host):
    """Load model definitions from a config file and push them to TF Serving.

    Args:
        config_path: Path handed to load_config(); yields model dicts with a
            'config' entry (name / base_path / model_platform).
        host: "host:port" of the serving instance.
    """
    model_entries = load_config(config_path)

    stub = model_service_pb2_grpc.ModelServiceStub(grpc.insecure_channel(host))

    # One ModelConfig per loaded entry.
    config_list = model_server_config_pb2.ModelConfigList()
    for entry in model_entries:
        cfg = entry['config']
        added = config_list.config.add()
        added.name = cfg['name']
        added.base_path = cfg['base_path']
        added.model_platform = cfg['model_platform']

    server_config = model_server_config_pb2.ModelServerConfig()
    server_config.model_config_list.CopyFrom(config_list)

    request = model_management_pb2.ReloadConfigRequest()
    request.config.CopyFrom(server_config)

    print(request.ListFields())
    print('Sending request')
    response = stub.HandleReloadConfigRequest(request, 30)  # 30s gRPC deadline
    if response.status.error_code == 0:
        print('Reload successful')
    else:
        print('Reload failed!')
        print(response.status.error_code)
        print(response.status.error_message)
Esempio n. 11
0
 def _create_channel(address: str, service: int):
     """Return a stub for the requested service over a fresh insecure channel.

     ``service`` selects between MODEL_SERVICE and PREDICTION_SERVICE; any
     other value yields None (the channel is still created either way, as in
     the original).
     """
     channel = grpc.insecure_channel(address)
     stub_factories = {
         MODEL_SERVICE: model_service_pb2_grpc.ModelServiceStub,
         PREDICTION_SERVICE: prediction_service_pb2_grpc.PredictionServiceStub,
     }
     factory = stub_factories.get(service)
     return factory(channel) if factory is not None else None
Esempio n. 12
0
    def __initialize(self):
        """Launch a tensorflow_model_server subprocess and load models into it.

        Reads ./config/tf_serving.json for ports and batching settings, starts
        the server via subprocess, then pushes a ModelServerConfig built from
        ./config/model.json through HandleReloadConfigRequest.
        """
        setproctitle.setproctitle('tf-serving')
        tf_serving_config_file = './config/tf_serving.json'
        with open(tf_serving_config_file, 'r') as fp:
            serving_config = json.load(fp)
        grpc_port = serving_config['gprc_port']  # NOTE(review): JSON key is spelled 'gprc_port' — confirm intentional
        use_batch = serving_config['use_batch']

        # Allow up to 1 GiB gRPC messages in either direction.
        options = [('grpc.max_send_message_length', 1024 * 1024 * 1024), ('grpc.max_receive_message_length', 1024 * 1024 * 1024)]


        model_config_file = './config/tf_serving_model.conf'
        # gRPC listens on grpc_port-1; the REST API gets grpc_port itself.
        cmd = 'tensorflow_model_server --port={} --rest_api_port={} --model_config_file={}'.format(grpc_port-1, grpc_port, model_config_file)
        if use_batch:
            batch_parameter_file = './config/batching.conf'
            cmd = cmd + ' --enable_batching=true --batching_parameters_file={}'.format(batch_parameter_file)
        self.serving_cmds = []
        self.serving_cmds.append(cmd)
        print(self.serving_cmds)
        system_env = os.environ.copy()
        self.process = subprocess.Popen(self.serving_cmds, env=system_env, shell=True)

        ## start grpc
        self.grpc_channel = grpc.insecure_channel('127.0.0.1:{}'.format(grpc_port-1), options=options)
        self.grpc_stub = model_service_pb2_grpc.ModelServiceStub(self.grpc_channel)

        # wait for start tf serving (fixed 3s delay; no readiness probe)
        time.sleep(3)

        # reload model: build a ModelServerConfig from ./config/model.json
        model_server_config = model_server_config_pb2.ModelServerConfig()
        config_list = model_server_config_pb2.ModelConfigList()

        model_config_json_file = './config/model.json'
        with open(model_config_json_file, 'r') as fp:
            model_configs = json.load(fp)
        model_dir = model_configs['model_dir']
        model_list = model_configs['models']
        for model_info in model_list:
            print(model_info)
            model_config = model_server_config_pb2.ModelConfig()
            model_config.name = model_info['name']
            model_config.base_path = os.path.abspath(os.path.join(model_dir, model_info['path']))
            model_config.model_platform = 'tensorflow'
            config_list.config.append(model_config)

        model_server_config.model_config_list.CopyFrom(config_list)
        request = model_management_pb2.ReloadConfigRequest()
        request.config.CopyFrom(model_server_config)

        grpc_response = self.grpc_stub.HandleReloadConfigRequest(request, 30)  # 30s gRPC deadline
        print(grpc_port)
        return
Esempio n. 13
0
def run():
    """Reload TF Serving with four aliases of the half-plus-two model.

    Registers model entries "saved_model_half1" through "saved_model_half4",
    all pointing at the same SavedModel base path.
    """
    channel = grpc.insecure_channel('192.168.199.198:8500')
    stub = model_service_pb2_grpc.ModelServiceStub(channel)
    # message ReloadConfigRequest
    request = model_management_pb2.ReloadConfigRequest()
    model_server_config = model_server_config_pb2.ModelServerConfig()

    # message ModelConfigList
    config_list = model_server_config_pb2.ModelConfigList()

    print(config_list)

    # Same SavedModel registered under four numbered names.
    for idx in range(1, 5):
        entry = config_list.config.add()
        entry.name = "saved_model_half%d" % idx
        entry.base_path = "/models/saved_model_half_plus_two_cpu"
        entry.model_platform = "tensorflow"

    # one_config = config_list.config.add()
    # one_config.name = "saved_model_half5"
    # one_config.base_path = "/models/saved_model_half_plus_two_cpu_bak"
    # one_config.model_platform = "tensorflow"

    model_server_config.model_config_list.CopyFrom(config_list)  # one of

    request.config.CopyFrom(model_server_config)

    print(request.IsInitialized())
    print(request.ListFields())

    response = stub.HandleReloadConfigRequest(request, 10)  # 10s gRPC deadline
    if response.status.error_code == 0:
        print("reload sucessfully")
    else:
        print("reload error!")
        print(response.status.error_code)
        print(response.status.error_message)
Esempio n. 14
0
def main(args):
    """Run an HTTP metrics endpoint that exposes TF Serving model status.

    NOTE(review): REGISTRY / start_http_server look like prometheus_client
    APIs — confirm against the file's imports.
    """
    # apply logging config
    logging.basicConfig(format='[%(asctime)s] %(levelname)s %(message)s',
                        level=args.log_level)

    # create channel and stub
    channel = grpc.insecure_channel('{}:{}'.format(args.tf_host, args.tf_port))
    stub = model_service_pb2_grpc.ModelServiceStub(channel)

    # register tf_serving exporter
    REGISTRY.register(TFServingExporter(stub, args.model_name, args.timeout))

    start_http_server(args.port)
    logging.info('Server started on port:{}'.format(args.port))
    # Keep the main thread alive; presumably the HTTP server runs in
    # background threads — confirm.
    while True:
        time.sleep(1)
Esempio n. 15
0
def main(_):
    """Rebuild the model config from the working directory and reload TF Serving."""
    stub = model_service_pb2_grpc.ModelServiceStub(
        grpc.insecure_channel(':8500'))

    # build_from_config yields the text-format config as its last element.
    _, _, _, _, _, model_configs_text = build_from_config(os.getcwd())
    server_config = model_server_config_pb2.ModelServerConfig()
    text_format.Parse(model_configs_text, server_config)

    request = model_management_pb2.ReloadConfigRequest()
    request.config.CopyFrom(server_config)
    response = stub.HandleReloadConfigRequest(request, 10)  # 10s gRPC deadline
    if response.status.error_code == 0:
        print("successful update model")
    else:
        print("fail")
        print(response.status.error_code)
        print(response.status.error_message)
    def health_check(self, name, signature_name, version=None):
        """Check whether a served model is reachable via GetModelStatus.

        Args:
            name: Model name to query.
            signature_name: Signature name to set on the model spec.
            version: Optional specific model version to query.

        Returns:
            True if the server reports at least one model version status,
            False otherwise (including on RPC errors).
        """
        request = get_model_status_pb2.GetModelStatusRequest()
        request.model_spec.name = name
        request.model_spec.signature_name = signature_name
        if version:
            request.model_spec.version.value = version

        stub = model_service_pb2_grpc.ModelServiceStub(self.channel)
        try:
            response = stub.GetModelStatus(request, 10)  # 10s gRPC deadline
            if len(response.model_version_status) > 0:
                return True
            # Fix: previously fell through and returned None when the status
            # list was empty; make the negative result explicit (still falsy,
            # so callers are unaffected).
            return False
        except Exception as err:
            logging.exception(err)
            return False
Esempio n. 17
0
    def testGetModelStatus(self):
        """Test ModelService.GetModelStatus implementation."""
        model_path = self._GetSavedModelBundlePath()
        # RunServer returns a tuple; element [1] is the server's address.
        model_server_address = TensorflowModelServerTest.RunServer(
            'default', model_path)[1]

        print('Sending GetModelStatus request...')
        # Send request
        request = get_model_status_pb2.GetModelStatusRequest()
        request.model_spec.name = 'default'
        channel = grpc.insecure_channel(model_server_address)
        stub = model_service_pb2_grpc.ModelServiceStub(channel)
        result = stub.GetModelStatus(request, RPC_TIMEOUT)  # 5 secs timeout
        # Verify response: exactly one version is reported, numbered 123.
        self.assertEqual(1, len(result.model_version_status))
        self.assertEqual(123, result.model_version_status[0].version)
        # OK error code (0) indicates no error occurred
        self.assertEqual(0, result.model_version_status[0].status.error_code)
Esempio n. 18
0
def main(_):
    """Query model status or push a two-model config, based on FLAGS.mode.

    Raises:
        ValueError: If FLAGS.mode is neither MODE.STATUS nor MODE.CONFIG.
            (Previously an unrecognized mode crashed later with
            UnboundLocalError on `result`.)
    """
    channel = grpc.insecure_channel(FLAGS.address)

    stub = model_service_pb2_grpc.ModelServiceStub(channel)
    if MODE.STATUS == FLAGS.mode:
        request = get_model_status_pb2.GetModelStatusRequest()
        request.model_spec.name = 'pascal'
        request.model_spec.signature_name = 'serving_default'
        result = stub.GetModelStatus(request)
    elif MODE.CONFIG == FLAGS.mode:
        request = model_management_pb2.ReloadConfigRequest()
        config = request.config.model_config_list.config.add()
        config.name = 'detection'
        config.base_path = '/models/detection/detection'
        config.model_platform = 'tensorflow'
        config2 = request.config.model_config_list.config.add()
        config2.name = 'pascal'
        config2.base_path = '/models/detection/pascal'
        config2.model_platform = 'tensorflow'
        result = stub.HandleReloadConfigRequest(request)
    else:
        raise ValueError('Unsupported mode: {}'.format(FLAGS.mode))

    print(result)
Esempio n. 19
0
def add_model_config(host, name):
    """Register a single model with TF Serving via a config reload.

    Args:
        host: "host:port" of the serving instance.
        name: Model name; its base path is assumed to be /models/<name>.
    """
    stub = model_service_pb2_grpc.ModelServiceStub(grpc.insecure_channel(host))

    # Config list containing exactly one model entry.
    config_list = model_server_config_pb2.ModelConfigList()
    entry = config_list.config.add()
    entry.name = name
    entry.base_path = os.path.join('/models', name)
    entry.model_platform = 'tensorflow'

    server_config = model_server_config_pb2.ModelServerConfig()
    server_config.model_config_list.CopyFrom(config_list)

    request = model_management_pb2.ReloadConfigRequest()
    request.config.CopyFrom(server_config)
    response = stub.HandleReloadConfigRequest(request, 10)  # 10s gRPC deadline

    if response.status.error_code == 0:
        print('Reload successful')
    else:
        print('Reload failed: {}: {}'.format(response.status.error_code,
                                             response.status.error_message))
Esempio n. 20
0
def prepare_stub_and_request(address,
                             model_name,
                             model_version=None,
                             creds=None,
                             opts=None,
                             request_type=INFERENCE_REQUEST):
    """Create a gRPC stub and a matching, pre-filled request for TF Serving.

    Args:
        address: "host:port" of the serving instance.
        model_name: Value for request.model_spec.name.
        model_version: Optional value for request.model_spec.version.value.
        creds: Optional channel credentials; if given, a secure channel is used.
        opts: Optional ssl_target_name_override string.
        request_type: MODEL_STATUS_REQUEST or INFERENCE_REQUEST.

    Returns:
        A (stub, request) tuple.

    Raises:
        ValueError: For an unrecognized request_type. (Previously `request`
            stayed None and the function crashed with AttributeError on
            `request.model_spec`.)
    """
    if opts is not None:
        opts = (('grpc.ssl_target_name_override', opts), )
    if creds is not None:
        channel = grpc.secure_channel(address, creds, options=opts)
    else:
        channel = grpc.insecure_channel(address, options=opts)
    if request_type == MODEL_STATUS_REQUEST:
        request = get_model_status_pb2.GetModelStatusRequest()
        stub = model_service_pb2_grpc.ModelServiceStub(channel)
    elif request_type == INFERENCE_REQUEST:
        stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
        request = predict_pb2.PredictRequest()
    else:
        raise ValueError('Unknown request_type: {!r}'.format(request_type))
    request.model_spec.name = model_name
    if model_version is not None:
        request.model_spec.version.value = model_version
    return stub, request
Esempio n. 21
0
 def service_channel(cls):
     """Return a ModelService stub bound to the class-level channel."""
     return model_service_pb2_grpc.ModelServiceStub(cls.channel)
 def __init__(self, port, host='0.0.0.0'):
     """Open an insecure channel to host:port and build a ModelService stub."""
     self.channel = grpc.insecure_channel('{}:{}'.format(host, port))
     self.stub = model_service_pb2_grpc.ModelServiceStub(self.channel)
Esempio n. 23
0
                    required=False,
                    default=9000,
                    help='Specify port to grpc service. default: 9000')
parser.add_argument('--model_name',
                    default='resnet',
                    help='Model name to query. default: resnet',
                    dest='model_name')
parser.add_argument(
    '--model_version',
    type=int,
    help='Model version to query. Lists all versions if omitted',
    dest='model_version')
args = vars(parser.parse_args())

# Connect to the serving endpoint given on the command line.
channel = grpc.insecure_channel("{}:{}".format(args['grpc_address'],
                                               args['grpc_port']))

stub = model_service_pb2_grpc.ModelServiceStub(channel)

print('Getting model status for model:', args.get('model_name'))

# Ask about one specific version if provided, otherwise all versions.
request = get_model_status_pb2.GetModelStatusRequest()
request.model_spec.name = args.get('model_name')
if args.get('model_version') is not None:
    request.model_spec.version.value = args.get('model_version')

result = stub.GetModelStatus(
    request, 10.0)  # 10s gRPC deadline

print_status_response(response=result)
Esempio n. 24
0
 def __init__(self, address):
     """Store the serving address and build a ModelService stub.

     Args:
         address: "host:port" of the TF Serving instance.
     """
     self.address = address
     self.model_platform = "tensorflow"  # platform label used for model configs
     self.channel = grpc.insecure_channel(self.address)
     self.stub = model_service_pb2_grpc.ModelServiceStub(self.channel)
     self.timeout = 600  # gRPC timeout in seconds
Esempio n. 25
0
 def __init__(self, server):
     """Initialize the base client, then attach a ModelService stub.

     Assumes the superclass __init__ sets up self.channel from *server*
     — confirm against the base class.
     """
     super().__init__(server)
     self.stub = model_service_pb2_grpc.ModelServiceStub(self.channel)
Esempio n. 26
0
def getConfigurations():
    """Return a (currently empty) ModelServerConfig.

    NOTE(review): the channel and stub below are created but never used, so
    this always returns a fresh, empty message — no RPC is made to fetch the
    live server config. Confirm whether this is intentional.
    """
    channel = grpc.insecure_channel(const.host + ":" + const.port)
    stub = model_service_pb2_grpc.ModelServiceStub(channel)
    model_server_config = model_server_config_pb2.ModelServerConfig()
    return model_server_config
Esempio n. 27
0
def create_channel_for_model_ver_pol_server_status():
    """Return a ModelService stub connected to localhost:9006."""
    return model_service_pb2_grpc.ModelServiceStub(
        grpc.insecure_channel('localhost:9006'))
Esempio n. 28
0
def create_channel_for_update_flow_specific_status():
    """Return a ModelService stub connected to localhost:9008."""
    return model_service_pb2_grpc.ModelServiceStub(
        grpc.insecure_channel('localhost:9008'))