def main(_):
    if MODE.STATUS == FLAGS.mode:
        request = get_model_status_pb2.GetModelStatusRequest()
        request.model_spec.name = 'detection'
        request.model_spec.signature_name = 'serving_default'
    elif MODE.CONFIG == FLAGS.mode:
        request = model_management_pb2.ReloadConfigRequest()
        config = request.config.model_config_list.config.add()
        config.name = 'detection'
        config.base_path = '/models/detection/detection'
        config.model_platform = 'tensorflow'
        config.model_version_policy.specific.versions.append(5)
        config.model_version_policy.specific.versions.append(7)
        config2 = request.config.model_config_list.config.add()
        config2.name = 'pascal'
        config2.base_path = '/models/detection/pascal'
        config2.model_platform = 'tensorflow'
    elif MODE.ZOOKEEPER == FLAGS.mode:
        zk = KazooClient(hosts="10.10.67.225:2181")
        zk.start()
        zk.ensure_path('/serving/cunan')
        # note: kazoo's zk.set() expects a bytes value
        zk.set(
            '/serving/cunan',
            get_config('detection', 5, 224, 'serving_default',
                       ','.join(get_classes('model_data/cci.names')),
                       "10.12.102.32:8000"))
        return
    for address in FLAGS.addresses:
        channel = grpc.insecure_channel(address)
        stub = model_service_pb2_grpc.ModelServiceStub(channel)
        if MODE.STATUS == FLAGS.mode:
            result = stub.GetModelStatus(request)
        elif MODE.CONFIG == FLAGS.mode:
            result = stub.HandleReloadConfigRequest(request)
        print(result)
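        # A minimal sketch (field names per TF Serving's get_model_status.proto):
        # in STATUS mode, the per-version states in the response can be
        # inspected instead of printing the whole message.
        if MODE.STATUS == FLAGS.mode:
            for version_status in result.model_version_status:
                print('version {} is {}'.format(
                    version_status.version,
                    get_model_status_pb2.ModelVersionStatus.State.Name(
                        version_status.state)))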
    def delete_model(self, model_name):
        model_server_config = model_server_config_pb2.ModelServerConfig()
        config_list = model_server_config_pb2.ModelConfigList()

        with lock(DEFAULT_LOCK_FILE):
            try:
                config_file = self._read_model_config(MODEL_CONFIG_FILE)
                # NOTE: str.strip() removes a *set* of characters from both ends,
                # so this only works for simple, well-formed config files
                config_list_text = config_file.strip('\n').strip('}').strip('model_config_list: {')
                config_list = text_format.Parse(text=config_list_text, message=config_list)

                for config in config_list.config:
                    if config.name == model_name:
                        config_list.config.remove(config)
                        model_server_config.model_config_list.CopyFrom(config_list)
                        req = model_management_pb2.ReloadConfigRequest()
                        req.config.CopyFrom(model_server_config)
                        self.stub.HandleReloadConfigRequest(req)
                        self._delete_model_from_config_file(model_server_config)
                        return 'Model {} unloaded.'.format(model_name)

                # no such model exists
                raise Exception(404, '{} not loaded yet.'.format(model_name))
            except grpc.RpcError as e:
                raise Exception(e.code(), e.details())
    def delete_model(self, model_name):
        model_server_config = model_server_config_pb2.ModelServerConfig()
        config_list = model_server_config_pb2.ModelConfigList()

        with lock(DEFAULT_LOCK_FILE):
            try:
                config_file = self._read_model_config(MODEL_CONFIG_FILE)
                config_list_text = config_file.strip('\n').strip('}').strip(
                    'model_config_list: {')
                config_list = text_format.Parse(text=config_list_text,
                                                message=config_list)

                for config in config_list.config:
                    if config.name == model_name:
                        config_list.config.remove(config)
                        model_server_config.model_config_list.CopyFrom(
                            config_list)
                        req = model_management_pb2.ReloadConfigRequest()
                        req.config.CopyFrom(model_server_config)
                        self.stub.HandleReloadConfigRequest(
                            request=req,
                            timeout=GRPC_REQUEST_TIMEOUT_IN_SECONDS,
                            wait_for_ready=True)
                        return self._delete_model_from_config_file(
                            model_server_config)

                # no such model exists
                raise FileNotFoundError(
                    '{} not loaded yet.'.format(model_name))
            except grpc.RpcError as e:
                if e.code() is grpc.StatusCode.DEADLINE_EXCEEDED:
                    raise MultiModelException(408, e.details())
                raise MultiModelException(500, e.details())
Example #4
def read_add_rewrite_config():
    """
    Test: read a config file, add one new model config, reload TF Serving,
    and write the updated config back to the file.
    :return:
    """
    with open("model_config.ini", "r") as f:
        config_ini = f.read()
    #print(config_ini)
    model_server_config = model_server_config_pb2.ModelServerConfig()
    model_server_config = text_format.Parse(text=config_ini,
                                            message=model_server_config)
    one_config = model_server_config.model_config_list.config.add(
    )  # add one more config
    one_config.name = "test2"
    one_config.base_path = "/home/sparkingarthur/tools/tf-serving-custom/serving/my_models/test2"
    one_config.model_platform = "tensorflow"

    print(model_server_config)

    channel = grpc.insecure_channel('10.200.24.101:8009')
    stub = model_service_pb2_grpc.ModelServiceStub(channel)
    request = model_management_pb2.ReloadConfigRequest()
    request.config.CopyFrom(model_server_config)
    response = stub.HandleReloadConfigRequest(request, 10)
    if response.status.error_code == 0:
        print("reloaded successfully")
    else:
        print("reload error!")
        print(response.status.error_code)
        print(response.status.error_message)
    new_config = text_format.MessageToString(model_server_config)
    with open("model_config.ini", "w+") as f:
        f.write(new_config)
def main(config_path, host):

    models = load_config(config_path)

    channel = grpc.insecure_channel(host)
    stub = model_service_pb2_grpc.ModelServiceStub(channel)

    request = model_management_pb2.ReloadConfigRequest()

    model_server_config = model_server_config_pb2.ModelServerConfig()
    config_list = model_server_config_pb2.ModelConfigList()

    for model in models:
        image_config = config_list.config.add()
        image_config.name = model['config']['name']
        image_config.base_path = model['config']['base_path']
        image_config.model_platform = model['config']['model_platform']

    model_server_config.model_config_list.CopyFrom(config_list)
    request.config.CopyFrom(model_server_config)

    print(request.ListFields())
    print('Sending request')
    response = stub.HandleReloadConfigRequest(request, 30)
    if response.status.error_code == 0:
        print('Reload successful')
    else:
        print('Reload failed!')
        print(response.status.error_code)
        print(response.status.error_message)
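The loop above assumes load_config() returns entries with a nested 'config' mapping; a minimal sketch of such a structure (names and paths are illustrative):

# illustrative shape of the data load_config() is expected to return
models_example = [
    {
        'config': {
            'name': 'detection',
            'base_path': '/models/detection',
            'model_platform': 'tensorflow',
        }
    },
]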
Example #6
def run():
  channel = grpc.insecure_channel(sys.argv[1])
  stub = model_service_pb2_grpc.ModelServiceStub(channel)
  request = model_management_pb2.ReloadConfigRequest()  # message ReloadConfigRequest
  model_server_config = model_server_config_pb2.ModelServerConfig()

  config_list = model_server_config_pb2.ModelConfigList()  # message ModelConfigList

  one_config = config_list.config.add()  # add one model config
  one_config.name = "svm_cls"
  one_config.base_path = "/models/svm_cls"
  one_config.model_platform = "tensorflow"
  #one_config.model_version_policy.specific.versions.append(4)


  model_server_config.model_config_list.CopyFrom(config_list)
  request.config.CopyFrom(model_server_config)

  #print(request.IsInitialized())
  #print(request.ListFields())

  response = stub.HandleReloadConfigRequest(request, 10)
  if response.status.error_code == 0:
      print("Reload successful")
  else:
      print("Reload failed!")
      print(response.status.error_code)
      print(response.status.error_message)
def updateConfigurations():
    channel = grpc.insecure_channel(const.host + ":" + const.port)
    stub = model_service_pb2_grpc.ModelServiceStub(channel)
    request = model_management_pb2.ReloadConfigRequest()
    model_server_config = model_server_config_pb2.ModelServerConfig()

    # Parse the list of served models from the config file
    with open("models.conf", "r") as f:
        configurations = f.read()
    model_server_config = text_format.Parse(text=configurations,
                                            message=model_server_config)

    request.config.CopyFrom(model_server_config)

    print(request.IsInitialized())
    print(request.ListFields())

    response = stub.HandleReloadConfigRequest(request, 10)
    if response.status.error_code == 0:
        return {"status": 200, "message": "Reload successful"}
    else:
        return {
            "status": response.status.error_code,
            "message": response.status.error_message
        }
Example #8
def update():
    config_file = "models/model.config"
    host_port = "localhost:8500"
    
    channel = grpc.insecure_channel(host_port)
    stub = model_service_pb2_grpc.ModelServiceStub(channel)
    request = model_management_pb2.ReloadConfigRequest()
    
    # read config file
    with open(config_file, "r") as f:
        config_content = f.read()
    model_server_config = model_server_config_pb2.ModelServerConfig()
    model_server_config = text_format.Parse(text=config_content, message=model_server_config)
    
    #print model_server_config.model_config_list.config
    #print(request.IsInitialized())
    #print(request.ListFields())

    request.config.CopyFrom(model_server_config)
    request_response = stub.HandleReloadConfigRequest(request, 10)
    
    if request_response.status.error_code == 0:
        print("TF Serving config file updated.")
    else:
        print("Failed to update config file.")
        print(request_response.status.error_code)
        print(request_response.status.error_message)
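For reference, the file parsed by text_format.Parse() in these examples is a text-format ModelServerConfig; a minimal sketch with an illustrative model name and path:

# minimal text-format ModelServerConfig, as consumed by text_format.Parse()
EXAMPLE_MODEL_CONFIG = """
model_config_list {
  config {
    name: "half_plus_two"
    base_path: "/models/half_plus_two"
    model_platform: "tensorflow"
  }
}
"""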
    def add_model(self, model_name, base_path, model_platform='tensorflow'):
        # read model configs from existing model config file
        model_server_config = model_server_config_pb2.ModelServerConfig()
        config_list = model_server_config_pb2.ModelConfigList()

        with lock(DEFAULT_LOCK_FILE):
            try:
                config_file = self._read_model_config(MODEL_CONFIG_FILE)
                model_server_config = text_format.Parse(
                    text=config_file, message=model_server_config)

                new_model_config = config_list.config.add()
                new_model_config.name = model_name
                new_model_config.base_path = base_path
                new_model_config.model_platform = model_platform

                # send HandleReloadConfigRequest to tensorflow model server
                model_server_config.model_config_list.MergeFrom(config_list)
                req = model_management_pb2.ReloadConfigRequest()
                req.config.CopyFrom(model_server_config)

                self.stub.HandleReloadConfigRequest(
                    request=req,
                    timeout=GRPC_REQUEST_TIMEOUT_IN_SECONDS,
                    wait_for_ready=True)
                self._add_model_to_config_file(model_name, base_path,
                                               model_platform)
            except grpc.RpcError as e:
                if e.code() is grpc.StatusCode.INVALID_ARGUMENT:
                    raise MultiModelException(409, e.details())
                elif e.code() is grpc.StatusCode.DEADLINE_EXCEEDED:
                    raise MultiModelException(408, e.details())
                raise MultiModelException(500, e.details())

        return 'Successfully loaded model {}'.format(model_name)
Example #10
def parse_config_file():
    """
    Test: read the config file, then write it back.
    :return:
    """
    with open("model_config.ini", "r") as f:
        config_ini = f.read()
    print(config_ini)

    channel = grpc.insecure_channel('yourip:port')
    stub = model_service_pb2_grpc.ModelServiceStub(channel)
    request = model_management_pb2.ReloadConfigRequest()  # message ReloadConfigRequest
    model_server_config = model_server_config_pb2.ModelServerConfig()
    x = text_format.Parse(text=config_ini,
                          message=model_server_config)  # not an officially documented approach
    x = text_format.MessageToString(model_server_config)
    with open("x.txt", "w+") as f:
        f.write(x)
    print(x)
    print(model_server_config.IsInitialized())
    request.config.CopyFrom(model_server_config)
    # print(request.ListFields())
    response = stub.HandleReloadConfigRequest(request, 10)
    if response.status.error_code == 0:
        print("reloaded successfully")
    else:
        print("reload error!")
        print(response.status.error_code)
        print(response.status.error_message)
Example #11
def run_reload_one_config():
    """
    Test the gRPC ReloadConfig interface.
    :return:
    """
    channel = grpc.insecure_channel('yourip:port')
    stub = model_service_pb2_grpc.ModelServiceStub(channel)
    request = model_management_pb2.ReloadConfigRequest()  # message ReloadConfigRequest
    model_server_config = model_server_config_pb2.ModelServerConfig()

    config_list = model_server_config_pb2.ModelConfigList()  # message ModelConfigList

    # try to add one config
    one_config = config_list.config.add()
    one_config.name = "test1"
    one_config.base_path = "/home/model/model1"
    one_config.model_platform = "tensorflow"

    model_server_config.model_config_list.CopyFrom(config_list)  # one of

    request.config.CopyFrom(model_server_config)

    print(request.IsInitialized())
    print(request.ListFields())

    response = stub.HandleReloadConfigRequest(request, 10)
    if response.status.error_code == 0:
        print("reloaded successfully")
    else:
        print("reload error!")
        print(response.status.error_code)
        print(response.status.error_message)
Example #12
    def __initialize(self):
        setproctitle.setproctitle('tf-serving')
        tf_serving_config_file = './config/tf_serving.json'
        with open(tf_serving_config_file, 'r') as fp:
            serving_config = json.load(fp)
        grpc_port = serving_config['gprc_port']  # note: the key is spelled 'gprc_port' in the config file
        use_batch = serving_config['use_batch']

        options = [('grpc.max_send_message_length', 1024 * 1024 * 1024),
                   ('grpc.max_receive_message_length', 1024 * 1024 * 1024)]

        # the gRPC port is one below the configured REST port
        model_config_file = './config/tf_serving_model.conf'
        cmd = 'tensorflow_model_server --port={} --rest_api_port={} --model_config_file={}'.format(
            grpc_port - 1, grpc_port, model_config_file)
        if use_batch:
            batch_parameter_file = './config/batching.conf'
            cmd = cmd + ' --enable_batching=true --batching_parameters_file={}'.format(batch_parameter_file)
        self.serving_cmds = []
        self.serving_cmds.append(cmd)
        print(self.serving_cmds)
        system_env = os.environ.copy()
        self.process = subprocess.Popen(self.serving_cmds, env=system_env, shell=True)

        # start gRPC client
        self.grpc_channel = grpc.insecure_channel('127.0.0.1:{}'.format(grpc_port - 1), options=options)
        self.grpc_stub = model_service_pb2_grpc.ModelServiceStub(self.grpc_channel)

        # wait for TF Serving to start
        time.sleep(3)

        # reload models
        model_server_config = model_server_config_pb2.ModelServerConfig()
        config_list = model_server_config_pb2.ModelConfigList()

        model_config_json_file = './config/model.json'
        with open(model_config_json_file, 'r') as fp:
            model_configs = json.load(fp)
        model_dir = model_configs['model_dir']
        model_list = model_configs['models']
        for model_info in model_list:
            print(model_info)
            model_config = model_server_config_pb2.ModelConfig()
            model_config.name = model_info['name']
            model_config.base_path = os.path.abspath(os.path.join(model_dir, model_info['path']))
            model_config.model_platform = 'tensorflow'
            config_list.config.append(model_config)

        model_server_config.model_config_list.CopyFrom(config_list)
        request = model_management_pb2.ReloadConfigRequest()
        request.config.CopyFrom(model_server_config)

        grpc_response = self.grpc_stub.HandleReloadConfigRequest(request, 30)
        print(grpc_port)
        return
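The ./config/model.json read above is assumed to have the following shape, inferred from the keys the loop accesses (values are illustrative):

# sketch of ./config/model.json; only 'model_dir', 'models', 'name' and
# 'path' are actually read by __initialize
EXAMPLE_MODEL_JSON = {
    "model_dir": "./models",
    "models": [
        {"name": "detector", "path": "detector_v1"},
    ],
}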
Example #13
def run():
    channel = grpc.insecure_channel('192.168.199.198:8500')
    stub = model_service_pb2_grpc.ModelServiceStub(channel)
    # message ReloadConfigRequest
    request = model_management_pb2.ReloadConfigRequest()
    model_server_config = model_server_config_pb2.ModelServerConfig()

    # message ModelConfigList
    config_list = model_server_config_pb2.ModelConfigList()

    print(config_list)

    # add four configs that all point at the same SavedModel
    for i in range(1, 5):
        one_config = config_list.config.add()
        one_config.name = "saved_model_half{}".format(i)
        one_config.base_path = "/models/saved_model_half_plus_two_cpu"
        one_config.model_platform = "tensorflow"

    # one_config = config_list.config.add()
    # one_config.name = "saved_model_half5"
    # one_config.base_path = "/models/saved_model_half_plus_two_cpu_bak"
    # one_config.model_platform = "tensorflow"

    model_server_config.model_config_list.CopyFrom(config_list)  # one of

    request.config.CopyFrom(model_server_config)

    print(request.IsInitialized())
    print(request.ListFields())

    response = stub.HandleReloadConfigRequest(request, 10)
    if response.status.error_code == 0:
        print("reloaded successfully")
    else:
        print("reload error!")
        print(response.status.error_code)
        print(response.status.error_message)
Example #14
    def remove_models(
        self,
        model_names: List[str],
        model_versions: List[List[str]],
        timeout: Optional[float] = None,
    ) -> None:
        """
        Remove models to TFS.

        Args:
            model_names: List of model names to add.
            model_versions: List of lists - each element is a list of versions for a given model name.
        Raises:
            grpc.RpcError in case something bad happens while communicating.
                StatusCode.DEADLINE_EXCEEDED when timeout is encountered. StatusCode.UNAVAILABLE when the service is unreachable.
            cortex_internal.lib.exceptions.CortexException if a non-0 response code is returned (i.e. model couldn't be unloaded).
        """

        request = model_management_pb2.ReloadConfigRequest()
        model_server_config = model_server_config_pb2.ModelServerConfig()

        for model_name, versions in zip(model_names, model_versions):
            for model_version in versions:
                self._remove_model_from_dict(model_name, model_version)

        config_list = model_server_config_pb2.ModelConfigList()
        remaining_model_names = self._get_model_names()
        for model_name in remaining_model_names:
            versions, model_disk_path = self._get_model_info(model_name)
            versions = [int(version) for version in versions]
            model_config = config_list.config.add()
            model_config.name = model_name
            model_config.base_path = model_disk_path
            model_config.model_version_policy.CopyFrom(
                ServableVersionPolicy(specific=Specific(versions=versions))
            )
            model_config.model_platform = "tensorflow"

        model_server_config.model_config_list.CopyFrom(config_list)
        request.config.CopyFrom(model_server_config)

        response = self._service.HandleReloadConfigRequest(request, timeout)

        if not (response and response.status.error_code == 0):
            if response:
                raise CortexException(
                    "couldn't unload user-requested models {} - failed with error code {}: {}".format(
                        model_names, response.status.error_code, response.status.error_message
                    )
                )
            else:
                raise CortexException("couldn't unload user-requested models")
def main(_):
    channel = grpc.insecure_channel(':8500')
    stub = model_service_pb2_grpc.ModelServiceStub(channel)
    request = model_management_pb2.ReloadConfigRequest()

    model_server_config = model_server_config_pb2.ModelServerConfig()
    _, _, _, _, _, model_configs_text = build_from_config(os.getcwd())
    text_format.Parse(model_configs_text, model_server_config)
    request.config.CopyFrom(model_server_config)
    response = stub.HandleReloadConfigRequest(request, 10)
    if response.status.error_code == 0:
        print("model updated successfully")
    else:
        print("model update failed")
        print(response.status.error_code)
        print(response.status.error_message)
Example #16
class ModelLoader:
    """
    This class provides the interface for loading new models into TensorFlow Serving. It is implemented
    as a gRPC ReloadConfigRequest to the TFS API, which triggers TFS to look for directories matching
    the specified model name.
    """
    channel = Channel.service_channel()
    request = model_management_pb2.ReloadConfigRequest()
    model_server_config = model_server_config_pb2.ModelServerConfig()
    conf = model_server_config_pb2.ModelConfigList()

    @classmethod
    def set_config(cls, model_name: str) -> None:
        config = cls.conf.config.add()
        config.name = model_name
        config.base_path = '/models/' + model_name
        config.model_platform = 'tensorflow'
        cls.model_server_config.model_config_list.CopyFrom(cls.conf)
        cls.request.config.CopyFrom(cls.model_server_config)

    @classmethod
    def load(cls, model_name: str) -> None:
        """Load model

        This will send the gRPC request. In particular, it will open a gRPC channel and communicate with the
        ReloadConfigRequest api to inform TFS of a change in configuration

        Parameters
        ----------
        model_name : str
            Name of the model, as specified in the instantiated Learner class

        Returns
        -------
        None
        """
        cls.set_config(model_name)
        log.info(cls.request.IsInitialized())
        log.info(cls.request.ListFields())

        # Channel.service_channel() is assumed to expose HandleReloadConfigRequest
        # like a ModelServiceStub
        response = cls.channel.HandleReloadConfigRequest(
            cls.request, ServerManager.PREDICTION_TIMEOUT)
        if response.status.error_code == 0:
            p.print_success(f'Loaded model {model_name} successfully')
        else:
            p.print_error(
                f'Loading failed, {response.status.error_code}: {response.status.error_message}'
            )
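A hedged usage sketch, assuming a SavedModel directory already exists under /models/<name> on the TFS host ('resnet' is illustrative):

ModelLoader.load('resnet')

Note that request, model_server_config and conf are class-level attributes, so successive load() calls accumulate configs and each reload keeps the earlier models in the served list.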
Example #17
def main(_):
    channel = grpc.insecure_channel(FLAGS.address)

    stub = model_service_pb2_grpc.ModelServiceStub(channel)
    if MODE.STATUS == FLAGS.mode:
        request = get_model_status_pb2.GetModelStatusRequest()
        request.model_spec.name = 'pascal'
        request.model_spec.signature_name = 'serving_default'
        result = stub.GetModelStatus(request)
    elif MODE.CONFIG == FLAGS.mode:
        request = model_management_pb2.ReloadConfigRequest()
        config = request.config.model_config_list.config.add()
        config.name = 'detection'
        config.base_path = '/models/detection/detection'
        config.model_platform = 'tensorflow'
        config2 = request.config.model_config_list.config.add()
        config2.name = 'pascal'
        config2.base_path = '/models/detection/pascal'
        config2.model_platform = 'tensorflow'
        result = stub.HandleReloadConfigRequest(request)
    else:
        raise ValueError('unsupported mode: {}'.format(FLAGS.mode))

    print(result)
Example #18
def add_model_config(host, name):
    channel = grpc.insecure_channel(host)
    stub = model_service_pb2_grpc.ModelServiceStub(channel)
    request = model_management_pb2.ReloadConfigRequest()
    model_server_config = model_server_config_pb2.ModelServerConfig()

    config_list = model_server_config_pb2.ModelConfigList()
    one_config = config_list.config.add()
    one_config.name = name
    one_config.base_path = os.path.join('/models', name)
    one_config.model_platform = 'tensorflow'

    model_server_config.model_config_list.CopyFrom(config_list)

    request.config.CopyFrom(model_server_config)
    response = stub.HandleReloadConfigRequest(request, 10)

    if response.status.error_code == 0:
        print('Reload successful')
    else:
        print('Reload failed: {}: {}'.format(response.status.error_code,
                                             response.status.error_message))
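A minimal usage sketch (host and model name are illustrative); the function assumes the model directory is mounted under /models in the serving container:

add_model_config('localhost:8500', 'half_plus_two')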
Example #19
    def __init__(self, config=None, **kwargs):
        super().__init__(model_management_pb2.ReloadConfigRequest(),
                         config=config,
                         **kwargs)
Example #20
    def add_models(
        self,
        model_names: List[str],
        model_versions: List[List[str]],
        model_disk_paths: List[str],
        signature_keys: List[Optional[str]],
        skip_if_present: bool = False,
        timeout: Optional[float] = None,
        max_retries: int = 0,
    ) -> None:
        """
        Add models to TFS. If they can't be loaded, use remove_models to remove them from TFS.

        Args:
            model_names: List of model names to add.
            model_versions: List of lists - each element is a list of versions for a given model name.
            model_disk_paths: The common model disk path of multiple versioned models of the same model name (i.e. modelA/ for modelA/1 and modelA/2).
            skip_if_present: If the models are already loaded, don't make a new request to TFS.
            signature_keys: The signature keys as set in cortex_internal.yaml. If an element is set to None, then "predict" key will be assumed.
            max_retries: How many times to call ReloadConfig before giving up.
        Raises:
            grpc.RpcError in case something bad happens while communicating.
                StatusCode.DEADLINE_EXCEEDED when timeout is encountered. StatusCode.UNAVAILABLE when the service is unreachable.
            cortex_internal.lib.exceptions.CortexException if a non-0 response code is returned (i.e. model couldn't be loaded).
            cortex_internal.lib.exceptions.UserException when a model couldn't be validated for the signature def.
        """

        request = model_management_pb2.ReloadConfigRequest()
        model_server_config = model_server_config_pb2.ModelServerConfig()

        num_added_models = 0
        for model_name, versions, model_disk_path in zip(
            model_names, model_versions, model_disk_paths
        ):
            for model_version in versions:
                versioned_model_disk_path = os.path.join(model_disk_path, model_version)
                num_added_models += self._add_model_to_dict(
                    model_name, model_version, versioned_model_disk_path
                )

        if skip_if_present and num_added_models == 0:
            return

        config_list = model_server_config_pb2.ModelConfigList()
        current_model_names = self._get_model_names()
        for model_name in current_model_names:
            versions, model_disk_path = self._get_model_info(model_name)
            versions = [int(version) for version in versions]
            model_config = config_list.config.add()
            model_config.name = model_name
            model_config.base_path = model_disk_path
            model_config.model_version_policy.CopyFrom(
                ServableVersionPolicy(specific=Specific(versions=versions))
            )
            model_config.model_platform = "tensorflow"

        model_server_config.model_config_list.CopyFrom(config_list)
        request.config.CopyFrom(model_server_config)

        while max_retries >= 0:
            max_retries -= 1
            try:
                # to prevent HandleReloadConfigRequest from
                # throwing an exception (TFS has some race-condition bug)
                time.sleep(0.125)
                response = self._service.HandleReloadConfigRequest(request, timeout)
                break
            except grpc.RpcError:
                # sleep before re-raising so that an immediate retry by the
                # caller doesn't hit the same TFS race condition
                time.sleep(0.125)
                raise

        if not (response and response.status.error_code == 0):
            if response:
                raise CortexException(
                    "couldn't load user-requested models {} - failed with error code {}: {}".format(
                        model_names, response.status.error_code, response.status.error_message
                    )
                )
            else:
                raise CortexException("couldn't load user-requested models")

        # get models metadata
        for model_name, versions, signature_key in zip(model_names, model_versions, signature_keys):
            for model_version in versions:
                self._load_model_signatures(model_name, model_version, signature_key)
Example #21
    def add_models_config(self, names, base_paths, replace_models=False):
        request = model_management_pb2.ReloadConfigRequest()
        model_server_config = model_server_config_pb2.ModelServerConfig()

        # create model(s) configuration
        config_list = model_server_config_pb2.ModelConfigList()
        for i, name in enumerate(names):
            model_config = config_list.config.add()
            model_config.name = name
            model_config.base_path = base_paths[i]
            model_config.model_platform = self.model_platform

        if replace_models:
            model_server_config.model_config_list.CopyFrom(config_list)
            request.config.CopyFrom(model_server_config)
        else:
            model_server_config.model_config_list.MergeFrom(config_list)
            request.config.MergeFrom(model_server_config)

        loaded_models = threading.Event()

        def log_loading_models():
            while not loaded_models.is_set():
                time.sleep(5)
                cx_logger().warn("model(s) still loading ...")

        log_thread = threading.Thread(target=log_loading_models, daemon=True)
        log_thread.start()

        # request TFS to load models
        limit = 3
        response = None
        for i in range(limit):
            try:
                # this request doesn't return until all models have been successfully loaded
                response = self.stub.HandleReloadConfigRequest(
                    request, self.timeout)
                break
            except Exception as e:
                if not (isinstance(e, grpc.RpcError) and e.code() in [
                        grpc.StatusCode.UNAVAILABLE,
                        grpc.StatusCode.DEADLINE_EXCEEDED
                ]):
                    print(e)  # unexpected error
                time.sleep(1.0)

        loaded_models.set()
        log_thread.join()

        # report error or success
        if response and response.status.error_code == 0:
            cx_logger().info(
                "successfully loaded {} models into TF-Serving".format(names))
        else:
            if response:
                raise CortexException(
                    "couldn't load user-requested models - failed with error code {}: {}"
                    .format(response.status.error_code,
                            response.status.error_message))
            else:
                raise CortexException("couldn't load user-requested models")
# e-mail:[email protected]
# datetime:1993/12/01
# filename:grpc_client_hot_deploy.py
# software: PyCharm

from google.protobuf import text_format
from tensorflow_serving.apis import model_management_pb2
from tensorflow_serving.apis import model_service_pb2_grpc
from tensorflow_serving.config import model_server_config_pb2
import grpc

channel = grpc.insecure_channel('ip')

config_file = "model.config"
stub = model_service_pb2_grpc.ModelServiceStub(channel)
request = model_management_pb2.ReloadConfigRequest()

# read config file
with open(config_file, "r") as f:
    config_content = f.read()
model_server_config = model_server_config_pb2.ModelServerConfig()
model_server_config = text_format.Parse(text=config_content,
                                        message=model_server_config)
request.config.CopyFrom(model_server_config)
request_response = stub.HandleReloadConfigRequest(request, 10)

if request_response.status.error_code == 0:
    with open(config_file, "w") as f:
        f.write(str(request.config))
    print("TF Serving config file updated.")
else:
    print("Failed to update config file.")
    print(request_response.status.error_code)
    print(request_response.status.error_message)