def run_reload_one_config():
    """Exercise TF Serving's gRPC ReloadConfig endpoint with one model config.

    Builds a ReloadConfigRequest containing a single ModelConfig entry and
    sends it to the ModelService over an insecure channel, printing the
    outcome. (Docstring translated from Chinese: "test the gRPC reloadconfig
    interface".)
    """
    channel = grpc.insecure_channel('yourip:port')
    stub = model_service_pb2_grpc.ModelServiceStub(channel)
    # message ReloadConfigRequest
    request = model_management_pb2.ReloadConfigRequest()
    model_server_config = model_server_config_pb2.ModelServerConfig()
    # message ModelConfigList
    config_list = model_server_config_pb2.ModelConfigList()
    # Add one model config. Fix: the name assignment was commented out in the
    # original, producing a ModelConfig with an empty name, which TF Serving
    # rejects on reload.
    one_config = config_list.config.add()
    one_config.name = "test1"
    one_config.base_path = "/home/model/model1"
    one_config.model_platform = "tensorflow"
    # model_config_list is a oneof member of ModelServerConfig.
    model_server_config.model_config_list.CopyFrom(config_list)
    request.config.CopyFrom(model_server_config)
    print(request.IsInitialized())
    print(request.ListFields())
    responese = stub.HandleReloadConfigRequest(request, 10)  # 10s timeout
    if responese.status.error_code == 0:
        print("reload sucessfully")
    else:
        print("reload error!")
        print(responese.status.error_code)
        print(responese.status.error_message)
def run():
    """Reload TF Serving's model list with a single "svm_cls" model.

    The serving address ("host:port") is taken from the first CLI argument.
    Prints the reload outcome.
    """
    stub = model_service_pb2_grpc.ModelServiceStub(
        grpc.insecure_channel(sys.argv[1]))
    # Assemble a ModelServerConfig holding exactly one ModelConfig entry.
    model_list = model_server_config_pb2.ModelConfigList()  ##message ModelConfigList
    entry = model_list.config.add()
    entry.name = "svm_cls"
    entry.base_path = "/models/svm_cls"
    entry.model_platform = "tensorflow"
    #entry.model_version_policy.specific.versions.append(4)
    server_config = model_server_config_pb2.ModelServerConfig()
    server_config.model_config_list.CopyFrom(model_list)
    request = model_management_pb2.ReloadConfigRequest()  ##message ReloadConfigRequest
    request.config.CopyFrom(server_config)
    #print(request.IsInitialized())
    #print(request.ListFields())
    responese = stub.HandleReloadConfigRequest(request, 10)
    if responese.status.error_code == 0:
        print("Reload sucessfully")
    else:
        print("Reload failed!")
        print(responese.status.error_code)
        print(responese.status.error_message)
def parse_config_file():
    """Round-trip a model config file through protobuf text format and reload.

    Reads "model_config.ini", parses it into a ModelServerConfig, serializes
    it back out to "x.txt" as a sanity check, then sends the parsed config to
    TF Serving as a ReloadConfigRequest. (Docstring translated from Chinese:
    "test reading config_file and writing it back".)
    """
    with open("model_config.ini", "r") as f:
        raw_text = f.read()
    print(raw_text)
    stub = model_service_pb2_grpc.ModelServiceStub(
        grpc.insecure_channel('yourip:port'))
    request = model_management_pb2.ReloadConfigRequest()  ##message ReloadConfigRequest
    model_server_config = model_server_config_pb2.ModelServerConfig()
    # Parse the textual config into the protobuf message, then serialize it
    # back to text (unofficial round-trip check).
    round_trip = text_format.Parse(text=raw_text, message=model_server_config)
    round_trip = text_format.MessageToString(model_server_config)
    with open("x.txt", "w+") as f:
        f.write(round_trip)
    print(round_trip)
    print(model_server_config.IsInitialized())
    request.config.CopyFrom(model_server_config)
    # print(request.ListFields())
    responese = stub.HandleReloadConfigRequest(request, 10)
    if responese.status.error_code == 0:
        print("reload sucessfully")
    else:
        print("reload error!")
        print(responese.status.error_code)
        print(responese.status.error_message)
def updateConfigurations():
    """Reload the TF Serving model configuration from "models.conf".

    Reads the server's model config (protobuf text format) from disk and
    sends it via ModelService.HandleReloadConfigRequest.

    Returns:
        dict with "status" (200 on success, otherwise the gRPC error code)
        and a human-readable "message".
    """
    channel = grpc.insecure_channel(const.host + ":" + const.port)
    stub = model_service_pb2_grpc.ModelServiceStub(channel)
    request = model_management_pb2.ReloadConfigRequest()
    model_server_config = model_server_config_pb2.ModelServerConfig()
    # Fix: the original used open(...).read() without closing the file
    # handle; a context manager guarantees it is closed. The unused
    # ModelConfigList local was also dropped.
    with open("models.conf", "r") as f:
        configurations = f.read()
    model_server_config = text_format.Parse(text=configurations,
                                            message=model_server_config)
    request.config.CopyFrom(model_server_config)
    print(request.IsInitialized())
    print(request.ListFields())
    response = stub.HandleReloadConfigRequest(request, 10)
    if response.status.error_code == 0:
        return {"status": 200, "message": "Reload sucessfully"}
    return {
        "status": response.status.error_code,
        "message": response.status.error_message
    }
def __init__(self, endpoint: Text, model_name: Text):
    """Create gRPC stubs for TF Serving's model and prediction services.

    Args:
        endpoint: "host:port" address of the TensorFlow Serving instance.
        model_name: Name of the served model this client targets.
    """
    # Note that the channel instance is automatically closed (unsubscribed) on
    # deletion, so we don't have to manually close this on __del__.
    self._channel = grpc.insecure_channel(endpoint)
    self._model_name = model_name
    self._model_service = model_service_pb2_grpc.ModelServiceStub(self._channel)
    self._prediction_service = prediction_service_pb2_grpc.PredictionServiceStub(self._channel)  # pylint: disable=line-too-long
def read_add_rewrite_config():
    """Read a config file, append one model entry, reload, then write back.

    test: read a config_file and add a new model, then reload and rewrite to
    the config_file. (Translated from the original Chinese docstring.)
    """
    with open("model_config.ini", "r") as f:
        config_ini = f.read()
    #print(config_ini)
    server_config = model_server_config_pb2.ModelServerConfig()
    server_config = text_format.Parse(text=config_ini, message=server_config)
    # Append one extra model entry to the parsed config.
    extra = server_config.model_config_list.config.add()
    extra.name = "test2"
    extra.base_path = "/home/sparkingarthur/tools/tf-serving-custom/serving/my_models/test2"
    extra.model_platform = "tensorflow"
    print(server_config)
    stub = model_service_pb2_grpc.ModelServiceStub(
        grpc.insecure_channel('10.200.24.101:8009'))
    request = model_management_pb2.ReloadConfigRequest()
    request.config.CopyFrom(server_config)
    responese = stub.HandleReloadConfigRequest(request, 10)
    if responese.status.error_code == 0:
        print("reload sucessfully")
    else:
        print("reload error!")
        print(responese.status.error_code)
        print(responese.status.error_message)
    # Persist the updated config (including the new entry) back to disk.
    new_config = text_format.MessageToString(server_config)
    with open("model_config.ini", "w+") as f:
        f.write(new_config)
def __init__(self, address: str):
    """
    TensorFlow Serving API for loading/unloading/reloading TF models and for running predictions.

    Extra arguments passed to the tensorflow/serving container:
    * --max_num_load_retries=0
    * --load_retry_interval_micros=30000000 # 30 seconds
    * --grpc_channel_arguments="grpc.max_concurrent_streams=<processes-per-api-replica>*<threads-per-process>" when inf == 0, otherwise
    * --grpc_channel_arguments="grpc.max_concurrent_streams=<threads-per-process>" when inf > 0.

    Args:
        address: An address with the "host:port" format.
    """
    # Fail fast when the optional TF packages are absent; this flag is
    # presumably set at import time — TODO confirm where it is defined.
    if not tensorflow_dependencies_installed:
        raise NameError("tensorflow_serving_api and tensorflow packages not installed")
    self.address = address
    self.models = (
        {}
    )  # maps the model ID to the model metadata (signature def, signature key and so on)
    # One shared insecure channel backs both service stubs.
    self.channel = grpc.insecure_channel(self.address)
    self._service = model_service_pb2_grpc.ModelServiceStub(self.channel)
    self._pred = prediction_service_pb2_grpc.PredictionServiceStub(self.channel)
def update():
    """Push the on-disk TF Serving model config to the local server.

    Reads "models/model.config" (protobuf text format) and sends it to the
    ModelService at localhost:8500 via HandleReloadConfigRequest, printing
    the outcome.
    """
    config_file = "models/model.config"
    host_port = "localhost:8500"
    channel = grpc.insecure_channel(host_port)
    stub = model_service_pb2_grpc.ModelServiceStub(channel)
    request = model_management_pb2.ReloadConfigRequest()
    # read config file
    # Fix: the original used open(...).read() without closing the handle.
    with open(config_file, "r") as f:
        config_content = f.read()
    model_server_config = model_server_config_pb2.ModelServerConfig()
    model_server_config = text_format.Parse(text=config_content,
                                            message=model_server_config)
    #print model_server_config.model_config_list.config
    #print(request.IsInitialized())
    #print(request.ListFields())
    request.config.CopyFrom(model_server_config)
    request_response = stub.HandleReloadConfigRequest(request, 10)
    if request_response.status.error_code == 0:
        print("TF Serving config file updated.")
    else:
        print("Failed to update config file.")
        print(request_response.status.error_code)
        print(request_response.status.error_message)
def main(_):
    """Dispatch on FLAGS.mode: build a model-status query, a config-reload
    request, or push serving metadata to ZooKeeper, then (for the first two
    modes) send the request to every address in FLAGS.addresses.
    """
    if MODE.STATUS == FLAGS.mode:
        # Query the status of the 'detection' model.
        request = get_model_status_pb2.GetModelStatusRequest()
        request.model_spec.name = 'detection'
        request.model_spec.signature_name = 'serving_default'
    elif MODE.CONFIG == FLAGS.mode:
        # Reload with two models: 'detection' pinned to versions 5 and 7,
        # plus 'pascal' with the default version policy.
        request = model_management_pb2.ReloadConfigRequest()
        config = request.config.model_config_list.config.add()
        config.name = 'detection'
        config.base_path = '/models/detection/detection'
        config.model_platform = 'tensorflow'
        config.model_version_policy.specific.versions.append(5)
        config.model_version_policy.specific.versions.append(7)
        config2 = request.config.model_config_list.config.add()
        config2.name = 'pascal'
        config2.base_path = '/models/detection/pascal'
        config2.model_platform = 'tensorflow'
    elif MODE.ZOOKEEPER == FLAGS.mode:
        # ZooKeeper mode writes serving metadata and returns without any RPC.
        zk = KazooClient(hosts="10.10.67.225:2181")
        zk.start()
        zk.ensure_path('/serving/cunan')
        zk.set(
            '/serving/cunan',
            get_config('detection', 5, 224, 'serving_default',
                       ','.join(get_classes('model_data/cci.names')),
                       "10.12.102.32:8000"))
        return
    # Fan the request out to every configured serving address.
    for address in FLAGS.addresses:
        channel = grpc.insecure_channel(address)
        stub = model_service_pb2_grpc.ModelServiceStub(channel)
        if MODE.STATUS == FLAGS.mode:
            result = stub.GetModelStatus(request)
        elif MODE.CONFIG == FLAGS.mode:
            result = stub.HandleReloadConfigRequest(request)
        print(result)
def main(config_path, host):
    """Load a model list from a config file and reload TF Serving with it.

    Args:
        config_path: Path passed to load_config(); yields model descriptors.
        host: "host:port" address of the TF Serving gRPC endpoint.
    """
    models = load_config(config_path)
    stub = model_service_pb2_grpc.ModelServiceStub(grpc.insecure_channel(host))
    config_list = model_server_config_pb2.ModelConfigList()
    # One ModelConfig entry per model described by the config file.
    for model in models:
        cfg = config_list.config.add()
        cfg.name = model['config']['name']
        cfg.base_path = model['config']['base_path']
        cfg.model_platform = model['config']['model_platform']
    server_config = model_server_config_pb2.ModelServerConfig()
    server_config.model_config_list.CopyFrom(config_list)
    request = model_management_pb2.ReloadConfigRequest()
    request.config.CopyFrom(server_config)
    print(request.ListFields())
    print('Sending request')
    response = stub.HandleReloadConfigRequest(request, 30)
    if response.status.error_code == 0:
        print('Reload successful')
    else:
        print('Reload failed!')
        print(response.status.error_code)
        print(response.status.error_message)
def _create_channel(address: str, service: int):
    """Open an insecure channel to *address* wrapped in the stub class for
    *service*.

    Returns a ModelServiceStub or PredictionServiceStub, or None when
    *service* matches neither known constant.
    """
    channel = grpc.insecure_channel(address)
    # Table-driven dispatch instead of an if/elif chain.
    stub_classes = {
        MODEL_SERVICE: model_service_pb2_grpc.ModelServiceStub,
        PREDICTION_SERVICE: prediction_service_pb2_grpc.PredictionServiceStub,
    }
    factory = stub_classes.get(service)
    return factory(channel) if factory is not None else None
def __initialize(self):
    """Start a local tensorflow_model_server subprocess, connect to it over
    gRPC, and reload its model list from ./config/model.json."""
    setproctitle.setproctitle('tf-serving')
    # Serving settings (port, batching flag) come from a JSON config file.
    tf_serving_config_file = './config/tf_serving.json'
    with open(tf_serving_config_file, 'r') as fp:
        serving_config = json.load(fp)
    # NOTE(review): the key really is spelled 'gprc_port' in the config file
    # this code reads — confirm before "fixing" the spelling anywhere.
    grpc_port = serving_config['gprc_port']
    use_batch = serving_config['use_batch']
    # Allow very large (1 GiB) gRPC messages in both directions.
    options = [('grpc.max_send_message_length', 1024 * 1024 * 1024), ('grpc.max_receive_message_length', 1024 * 1024 * 1024)]
    model_config_file = './config/tf_serving_model.conf'
    # gRPC listens on grpc_port-1; the REST API takes grpc_port itself.
    cmd = 'tensorflow_model_server --port={} --rest_api_port={} --model_config_file={}'.format(grpc_port-1, grpc_port, model_config_file)
    if use_batch:
        batch_parameter_file = './config/batching.conf'
        cmd = cmd + ' --enable_batching=true --batching_parameters_file={}'.format(batch_parameter_file)
    self.serving_cmds = []
    self.serving_cmds.append(cmd)
    print(self.serving_cmds)
    system_env = os.environ.copy()
    # shell=True: the command is a single formatted string, not an argv list.
    self.process = subprocess.Popen(self.serving_cmds, env=system_env, shell=True)
    ## start grpc
    self.grpc_channel = grpc.insecure_channel('127.0.0.1:{}'.format(grpc_port-1), options=options)
    self.grpc_stub = model_service_pb2_grpc.ModelServiceStub(self.grpc_channel)
    # wait for start tf serving
    # NOTE(review): a fixed 3s sleep is a race — the server may take longer
    # to come up; consider polling GetModelStatus instead.
    time.sleep(3)
    # reload model
    model_server_config = model_server_config_pb2.ModelServerConfig()
    config_list = model_server_config_pb2.ModelConfigList()
    model_config_json_file = './config/model.json'
    with open(model_config_json_file, 'r') as fp:
        model_configs = json.load(fp)
    model_dir = model_configs['model_dir']
    model_list = model_configs['models']
    # Build one ModelConfig per entry; base paths are made absolute because
    # tensorflow_model_server requires absolute model base paths.
    for model_info in model_list:
        print(model_info)
        model_config = model_server_config_pb2.ModelConfig()
        model_config.name = model_info['name']
        model_config.base_path = os.path.abspath(os.path.join(model_dir, model_info['path']))
        model_config.model_platform = 'tensorflow'
        config_list.config.append(model_config)
    model_server_config.model_config_list.CopyFrom(config_list)
    request = model_management_pb2.ReloadConfigRequest()
    request.config.CopyFrom(model_server_config)
    grpc_response = self.grpc_stub.HandleReloadConfigRequest(request, 30)
    print(grpc_port)
    return
def run():
    """Reload TF Serving with four aliases of the half_plus_two model.

    Registers saved_model_half1..saved_model_half4, all pointing at the same
    base path, and prints the reload outcome.
    """
    channel = grpc.insecure_channel('192.168.199.198:8500')
    stub = model_service_pb2_grpc.ModelServiceStub(channel)
    # message ReloadConfigRequest
    request = model_management_pb2.ReloadConfigRequest()
    model_server_config = model_server_config_pb2.ModelServerConfig()
    # message ModelConfigList
    config_list = model_server_config_pb2.ModelConfigList()
    print(config_list)
    # Fix (readability): the original repeated the same three-line stanza
    # four times with only the trailing digit changing; this loop produces
    # the identical config list.
    for i in range(1, 5):
        one_config = config_list.config.add()
        one_config.name = "saved_model_half{}".format(i)
        one_config.base_path = "/models/saved_model_half_plus_two_cpu"
        one_config.model_platform = "tensorflow"
    # one_config = config_list.config.add()
    # one_config.name = "saved_model_half5"
    # one_config.base_path = "/models/saved_model_half_plus_two_cpu_bak"
    # one_config.model_platform = "tensorflow"
    model_server_config.model_config_list.CopyFrom(config_list)  # one of
    request.config.CopyFrom(model_server_config)
    print(request.IsInitialized())
    print(request.ListFields())
    response = stub.HandleReloadConfigRequest(request, 10)
    if response.status.error_code == 0:
        print("reload sucessfully")
    else:
        print("reload error!")
        print(response.status.error_code)
        print(response.status.error_message)
def main(args):
    """Run a Prometheus exporter that scrapes TF Serving model status.

    Registers a TFServingExporter against a ModelService stub and serves
    metrics over HTTP forever.
    """
    # apply logging config
    logging.basicConfig(format='[%(asctime)s] %(levelname)s %(message)s',
                        level=args.log_level)
    # create channel and stub
    target = '{}:{}'.format(args.tf_host, args.tf_port)
    stub = model_service_pb2_grpc.ModelServiceStub(grpc.insecure_channel(target))
    # register tf_serving exporter
    REGISTRY.register(TFServingExporter(stub, args.model_name, args.timeout))
    start_http_server(args.port)
    logging.info('Server started on port:{}'.format(args.port))
    # Block the main thread forever; the HTTP server runs in the background.
    while True:
        time.sleep(1)
def main(_):
    """Rebuild the model config from the working directory and reload it."""
    stub = model_service_pb2_grpc.ModelServiceStub(grpc.insecure_channel(':8500'))
    server_config = model_server_config_pb2.ModelServerConfig()
    # Only the final return value (the textual model config) is used here.
    _, _, _, _, _, model_configs_text = build_from_config(os.getcwd())
    text_format.Parse(model_configs_text, server_config)
    request = model_management_pb2.ReloadConfigRequest()
    request.config.CopyFrom(server_config)
    responese = stub.HandleReloadConfigRequest(request, 10)
    if responese.status.error_code == 0:
        print("successful update model")
    else:
        print("fail")
        print(responese.status.error_code)
        print(responese.status.error_message)
def health_check(self, name, signature_name, version=None):
    """Return True when the named model reports at least one version status.

    Args:
        name: Model name to query.
        signature_name: Signature name set on the model_spec.
        version: Optional specific model version to check.

    Returns:
        True when GetModelStatus succeeds and reports one or more versions;
        False on an empty status list or any RPC/runtime error (logged).
    """
    request = get_model_status_pb2.GetModelStatusRequest()
    request.model_spec.name = name
    request.model_spec.signature_name = signature_name
    if version:
        request.model_spec.version.value = version
    stub = model_service_pb2_grpc.ModelServiceStub(self.channel)
    try:
        response = stub.GetModelStatus(request, 10)  # 10s timeout
        # Idiomatic truthiness check instead of len(...) > 0.
        if response.model_version_status:
            return True
    except Exception as err:
        # Best-effort health probe: any failure is logged and reported as
        # "unhealthy" rather than propagated to the caller.
        logging.exception(err)
    return False
def testGetModelStatus(self):
    """Test ModelService.GetModelStatus implementation."""
    # Spin up a model server over the saved-model bundle fixture; index [1]
    # of RunServer's return value is the server address.
    model_path = self._GetSavedModelBundlePath()
    model_server_address = TensorflowModelServerTest.RunServer(
        'default', model_path)[1]
    print('Sending GetModelStatus request...')
    # Send request
    request = get_model_status_pb2.GetModelStatusRequest()
    request.model_spec.name = 'default'
    channel = grpc.insecure_channel(model_server_address)
    stub = model_service_pb2_grpc.ModelServiceStub(channel)
    result = stub.GetModelStatus(request, RPC_TIMEOUT)  # 5 secs timeout
    # Verify response: exactly one version, numbered 123.
    self.assertEqual(1, len(result.model_version_status))
    self.assertEqual(123, result.model_version_status[0].version)
    # OK error code (0) indicates no error occurred
    self.assertEqual(0, result.model_version_status[0].status.error_code)
def main(_):
    """Send either a model-status query or a config reload, per FLAGS.mode."""
    stub = model_service_pb2_grpc.ModelServiceStub(
        grpc.insecure_channel(FLAGS.address))
    if MODE.STATUS == FLAGS.mode:
        # Query the 'pascal' model's serving status.
        request = get_model_status_pb2.GetModelStatusRequest()
        request.model_spec.name = 'pascal'
        request.model_spec.signature_name = 'serving_default'
        result = stub.GetModelStatus(request)
    elif MODE.CONFIG == FLAGS.mode:
        # Reload with two models: 'detection' and 'pascal'.
        request = model_management_pb2.ReloadConfigRequest()
        for model_name, path in (('detection', '/models/detection/detection'),
                                 ('pascal', '/models/detection/pascal')):
            cfg = request.config.model_config_list.config.add()
            cfg.name = model_name
            cfg.base_path = path
            cfg.model_platform = 'tensorflow'
        result = stub.HandleReloadConfigRequest(request)
    # Like the original, an unknown mode leaves `result` unbound and raises
    # NameError here.
    print(result)
def add_model_config(host, name):
    """Register one model named *name* (base path /models/<name>) with the
    TF Serving instance at *host*, printing the reload outcome."""
    stub = model_service_pb2_grpc.ModelServiceStub(grpc.insecure_channel(host))
    config_list = model_server_config_pb2.ModelConfigList()
    entry = config_list.config.add()
    entry.name = name
    entry.base_path = os.path.join('/models', name)
    entry.model_platform = 'tensorflow'
    server_config = model_server_config_pb2.ModelServerConfig()
    server_config.model_config_list.CopyFrom(config_list)
    request = model_management_pb2.ReloadConfigRequest()
    request.config.CopyFrom(server_config)
    response = stub.HandleReloadConfigRequest(request, 10)
    if response.status.error_code == 0:
        print('Reload successful')
    else:
        print('Reload failed: {}: {}'.format(response.status.error_code,
                                             response.status.error_message))
def prepare_stub_and_request(address, model_name, model_version=None,
                             creds=None, opts=None,
                             request_type=INFERENCE_REQUEST):
    """Build a gRPC stub plus a matching empty request message.

    Args:
        address: "host:port" of the serving endpoint.
        model_name: Name written into request.model_spec.name.
        model_version: Optional version written into the model_spec.
        creds: Optional channel credentials; when given a secure channel is
            opened, otherwise an insecure one.
        opts: Optional SSL target-name override string.
        request_type: MODEL_STATUS_REQUEST or INFERENCE_REQUEST.

    Returns:
        (stub, request) tuple for the chosen request type.
    """
    if opts is not None:
        # grpc expects channel options as a tuple of (key, value) pairs.
        opts = (('grpc.ssl_target_name_override', opts), )
    if creds is None:
        channel = grpc.insecure_channel(address, options=opts)
    else:
        channel = grpc.secure_channel(address, creds, options=opts)
    stub, request = None, None
    if request_type == MODEL_STATUS_REQUEST:
        stub = model_service_pb2_grpc.ModelServiceStub(channel)
        request = get_model_status_pb2.GetModelStatusRequest()
    elif request_type == INFERENCE_REQUEST:
        stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
        request = predict_pb2.PredictRequest()
    # As in the original, an unknown request_type leaves request as None and
    # raises AttributeError on the next line.
    request.model_spec.name = model_name
    if model_version is not None:
        request.model_spec.version.value = model_version
    return stub, request
def service_channel(cls):
    # Wrap the class-level gRPC channel in a ModelService stub.
    # NOTE(review): despite the name, this returns a stub, not a channel.
    return model_service_pb2_grpc.ModelServiceStub(cls.channel)
def __init__(self, port, host='0.0.0.0'):
    """Connect to a TF Serving ModelService at host:port over an insecure
    channel."""
    address = '{}:{}'.format(host, port)
    self.channel = grpc.insecure_channel(address)
    self.stub = model_service_pb2_grpc.ModelServiceStub(self.channel)
required=False, default=9000, help='Specify port to grpc service. default: 9000') parser.add_argument('--model_name', default='resnet', help='Model name to query. default: resnet', dest='model_name') parser.add_argument( '--model_version', type=int, help='Model version to query. Lists all versions if omitted', dest='model_version') args = vars(parser.parse_args()) channel = grpc.insecure_channel("{}:{}".format(args['grpc_address'], args['grpc_port'])) stub = model_service_pb2_grpc.ModelServiceStub(channel) print('Getting model status for model:', args.get('model_name')) request = get_model_status_pb2.GetModelStatusRequest() request.model_spec.name = args.get('model_name') if args.get('model_version') is not None: request.model_spec.version.value = args.get('model_version') result = stub.GetModelStatus( request, 10.0) # result includes a dictionary with all model outputs print_status_response(response=result)
def __init__(self, address):
    """Client for TF Serving's ModelService at *address* ("host:port")."""
    self.address = address
    # All models managed through this client use the TensorFlow platform.
    self.model_platform = "tensorflow"
    self.channel = grpc.insecure_channel(self.address)
    self.stub = model_service_pb2_grpc.ModelServiceStub(self.channel)
    self.timeout = 600  # gRPC timeout in seconds
def __init__(self, server):
    """Initialize the base client, then attach a ModelService stub.

    NOTE(review): assumes the base-class __init__ sets self.channel from
    *server* — confirm against the parent class.
    """
    super().__init__(server)
    self.stub = model_service_pb2_grpc.ModelServiceStub(self.channel)
def getConfigurations():
    """Return an (always empty) ModelServerConfig.

    NOTE(review): the channel and stub are created but never used, and the
    returned config is always a fresh empty message — this looks unfinished.
    The standard ModelService API exposes no "get config" RPC, so there is
    no obvious call to wire in here; confirm the intended behavior.
    """
    channel = grpc.insecure_channel(const.host + ":" + const.port)
    stub = model_service_pb2_grpc.ModelServiceStub(channel)
    model_server_config = model_server_config_pb2.ModelServerConfig()
    return model_server_config
def create_channel_for_model_ver_pol_server_status():
    """Return a ModelService stub connected to localhost:9006 (insecure)."""
    return model_service_pb2_grpc.ModelServiceStub(
        grpc.insecure_channel('localhost:9006'))
def create_channel_for_update_flow_specific_status():
    """Return a ModelService stub connected to localhost:9008 (insecure)."""
    return model_service_pb2_grpc.ModelServiceStub(
        grpc.insecure_channel('localhost:9008'))