def delete_model(self, model_name):
        model_server_config = model_server_config_pb2.ModelServerConfig()
        config_list = model_server_config_pb2.ModelConfigList()

        with lock(DEFAULT_LOCK_FILE):
            try:
                config_file = self._read_model_config(MODEL_CONFIG_FILE)
                config_list_text = config_file.strip('\n').strip('}').strip('model_config_list: {')
                config_list = text_format.Parse(text=config_list_text, message=config_list)

                deleted = False
                for config in config_list.config:
                    if config.name == model_name:
                        config_list.config.remove(config)
                        model_server_config.model_config_list.CopyFrom(config_list)
                        req = model_management_pb2.ReloadConfigRequest()
                        req.config.CopyFrom(model_server_config)
                        self.stub.HandleReloadConfigRequest(req)
                        self._delete_model_from_config_file(model_server_config)
                        deleted = True
                        break

                if not deleted:
                    # no such model exists
                    raise Exception(404, '{} not loaded yet.'.format(model_name))
            except grpc.RpcError as e:
                raise Exception(e.code(), e.details())

        return 'Model {} unloaded.'.format(model_name)
def main(config_path, host):

    models = load_config(config_path)

    channel = grpc.insecure_channel(host)
    stub = model_service_pb2_grpc.ModelServiceStub(channel)

    request = model_management_pb2.ReloadConfigRequest()

    model_server_config = model_server_config_pb2.ModelServerConfig()
    config_list = model_server_config_pb2.ModelConfigList()

    for model in models:
        image_config = config_list.config.add()
        image_config.name = model['config']['name']
        image_config.base_path = model['config']['base_path']
        image_config.model_platform = model['config']['model_platform']

    model_server_config.model_config_list.CopyFrom(config_list)
    request.config.CopyFrom(model_server_config)

    print(request.ListFields())
    print('Sending request')
    response = stub.HandleReloadConfigRequest(request, 30)
    if response.status.error_code == 0:
        print('Reload successful')
    else:
        print('Reload failed!')
        print(response.status.error_code)
        print(response.status.error_message)
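For reference, a minimal sketch of the structure main() assumes load_config(config_path) returns; the loader itself is not shown here, and the keys below simply mirror what the loop reads (names and paths are placeholders):

# Hypothetical shape of load_config()'s return value; main() reads
# model['config']['name'], ['base_path'], and ['model_platform'].
models = [
    {
        "config": {
            "name": "half_plus_two",
            "base_path": "/models/half_plus_two",
            "model_platform": "tensorflow",
        }
    },
]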
Example #3
def read_add_rewrite_config():
    """
    测试读取一个config文件,新加一个配置信息,并且reload过后回写
	test:read a config_file and add a new model, then reload and rewrite to the config_file
    :return:
    """
    with open("model_config.ini", "r") as f:
        config_ini = f.read()
    #print(config_ini)
    model_server_config = model_server_config_pb2.ModelServerConfig()
    model_server_config = text_format.Parse(text=config_ini,
                                            message=model_server_config)
    one_config = model_server_config.model_config_list.config.add(
    )  # add one more config
    one_config.name = "test2"
    one_config.base_path = "/home/sparkingarthur/tools/tf-serving-custom/serving/my_models/test2"
    one_config.model_platform = "tensorflow"

    print(model_server_config)

    channel = grpc.insecure_channel('10.200.24.101:8009')
    stub = model_service_pb2_grpc.ModelServiceStub(channel)
    request = model_management_pb2.ReloadConfigRequest()
    request.config.CopyFrom(model_server_config)
    response = stub.HandleReloadConfigRequest(request, 10)
    if response.status.error_code == 0:
        print("reloaded successfully")
    else:
        print("reload error!")
        print(response.status.error_code)
        print(response.status.error_message)
    new_config = text_format.MessageToString(model_server_config)
    with open("model_config.ini", "w+") as f:
        f.write(new_config)
Example #4
def parse_config_file():
    """
    测试读取config_file并回写
    :return:
    """
    with open("model_config.ini", "r") as f:
        config_ini = f.read()
    print(config_ini)

    channel = grpc.insecure_channel('yourip:port')
    stub = model_service_pb2_grpc.ModelServiceStub(channel)
    request = model_management_pb2.ReloadConfigRequest()  # message ReloadConfigRequest
    model_server_config = model_server_config_pb2.ModelServerConfig()
    x = text_format.Parse(text=config_ini,
                          message=model_server_config)  # not an officially documented approach
    x = text_format.MessageToString(model_server_config)
    with open("x.txt", "w+") as f:
        f.write(x)
    print(x)
    print(model_server_config.IsInitialized())
    request.config.CopyFrom(model_server_config)
    # print(request.ListFields())
    response = stub.HandleReloadConfigRequest(request, 10)
    if response.status.error_code == 0:
        print("reloaded successfully")
    else:
        print("reload error!")
        print(response.status.error_code)
        print(response.status.error_message)
Example #5
def run():
  channel = grpc.insecure_channel(sys.argv[1])
  stub = model_service_pb2_grpc.ModelServiceStub(channel)
  request = model_management_pb2.ReloadConfigRequest()  # message ReloadConfigRequest
  model_server_config = model_server_config_pb2.ModelServerConfig()

  config_list = model_server_config_pb2.ModelConfigList()  # message ModelConfigList

  one_config = config_list.config.add()  # try to add one model config
  one_config.name = "svm_cls"
  one_config.base_path = "/models/svm_cls"
  one_config.model_platform = "tensorflow"
  #one_config.model_version_policy.specific.versions.append(4)


  model_server_config.model_config_list.CopyFrom(config_list)
  request.config.CopyFrom(model_server_config)

  #print(request.IsInitialized())
  #print(request.ListFields())

  response = stub.HandleReloadConfigRequest(request, 10)
  if response.status.error_code == 0:
      print("Reloaded successfully")
  else:
      print("Reload failed!")
      print(response.status.error_code)
      print(response.status.error_message)
Example #6
def update():
    config_file = "models/model.config"
    host_port = "localhost:8500"
    
    channel = grpc.insecure_channel(host_port)
    stub = model_service_pb2_grpc.ModelServiceStub(channel)
    request = model_management_pb2.ReloadConfigRequest()
    
    # read config file
    config_content = open(config_file, "r").read()
    model_server_config = model_server_config_pb2.ModelServerConfig()
    model_server_config = text_format.Parse(text=config_content, message=model_server_config)
    
    #print model_server_config.model_config_list.config
    #print(request.IsInitialized())
    #print(request.ListFields())

    request.config.CopyFrom(model_server_config)
    request_response = stub.HandleReloadConfigRequest(request, 10)
    
    if request_response.status.error_code == 0:
        print("TF Serving config file updated.")
    else:
        print("Failed to update config file.")
        print(request_response.status.error_code)
        print(request_response.status.error_message)
Example #7
def updateConfigurations():
    channel = grpc.insecure_channel(const.host + ":" + const.port)
    stub = model_service_pb2_grpc.ModelServiceStub(channel)
    request = model_management_pb2.ReloadConfigRequest()
    model_server_config = model_server_config_pb2.ModelServerConfig()

    # Create a config to add to the list of served models
    configurations = open("models.conf", "r").read()
    model_server_config = text_format.Parse(text=configurations,
                                            message=model_server_config)

    request.config.CopyFrom(model_server_config)

    print(request.IsInitialized())
    print(request.ListFields())

    response = stub.HandleReloadConfigRequest(request, 10)
    if response.status.error_code == 0:
        return {"status": 200, "message": "Reload sucessfully"}
    else:
        return {
            "status": response.status.error_code,
            "message": response.status.error_message
        }
    def add_model(self, model_name, base_path, model_platform='tensorflow'):
        # read model configs from existing model config file
        model_server_config = model_server_config_pb2.ModelServerConfig()
        config_list = model_server_config_pb2.ModelConfigList()

        with lock(DEFAULT_LOCK_FILE):
            try:
                config_file = self._read_model_config(MODEL_CONFIG_FILE)
                model_server_config = text_format.Parse(
                    text=config_file, message=model_server_config)

                new_model_config = config_list.config.add()
                new_model_config.name = model_name
                new_model_config.base_path = base_path
                new_model_config.model_platform = model_platform

                # send HandleReloadConfigRequest to tensorflow model server
                model_server_config.model_config_list.MergeFrom(config_list)
                req = model_management_pb2.ReloadConfigRequest()
                req.config.CopyFrom(model_server_config)

                self.stub.HandleReloadConfigRequest(
                    request=req,
                    timeout=GRPC_REQUEST_TIMEOUT_IN_SECONDS,
                    wait_for_ready=True)
                self._add_model_to_config_file(model_name, base_path,
                                               model_platform)
            except grpc.RpcError as e:
                if e.code() is grpc.StatusCode.INVALID_ARGUMENT:
                    raise MultiModelException(409, e.details())
                elif e.code() is grpc.StatusCode.DEADLINE_EXCEEDED:
                    raise MultiModelException(408, e.details())
                raise MultiModelException(500, e.details())

        return 'Successfully loaded model {}'.format(model_name)
    def delete_model(self, model_name):
        model_server_config = model_server_config_pb2.ModelServerConfig()
        config_list = model_server_config_pb2.ModelConfigList()

        with lock(DEFAULT_LOCK_FILE):
            try:
                config_file = self._read_model_config(MODEL_CONFIG_FILE)
                config_list_text = config_file.strip('\n').strip('}').strip(
                    'model_config_list: {')
                config_list = text_format.Parse(text=config_list_text,
                                                message=config_list)

                for config in config_list.config:
                    if config.name == model_name:
                        config_list.config.remove(config)
                        model_server_config.model_config_list.CopyFrom(
                            config_list)
                        req = model_management_pb2.ReloadConfigRequest()
                        req.config.CopyFrom(model_server_config)
                        self.stub.HandleReloadConfigRequest(
                            request=req,
                            timeout=GRPC_REQUEST_TIMEOUT_IN_SECONDS,
                            wait_for_ready=True)
                        return self._delete_model_from_config_file(
                            model_server_config)

                # no such model exists
                raise FileNotFoundError
            except grpc.RpcError as e:
                if e.code() is grpc.StatusCode.DEADLINE_EXCEEDED:
                    raise MultiModelException(408, e.details())
                raise MultiModelException(500, e.details())

        return 'Model {} unloaded.'.format(model_name)
Example #10
def run_reload_one_config():
    """
  测试gRPC的reloadconfig接口
  :return:
  """
    channel = grpc.insecure_channel('yourip:port')
    stub = model_service_pb2_grpc.ModelServiceStub(channel)
    request = model_management_pb2.ReloadConfigRequest()  # message ReloadConfigRequest
    model_server_config = model_server_config_pb2.ModelServerConfig()

    config_list = model_server_config_pb2.ModelConfigList()  # message ModelConfigList

    # try to add one config
    one_config = config_list.config.add()
    one_config.name = "test1"
    one_config.base_path = "/home/model/model1"
    one_config.model_platform = "tensorflow"

    model_server_config.model_config_list.CopyFrom(config_list)  # one of the 'config' oneof fields

    request.config.CopyFrom(model_server_config)

    print(request.IsInitialized())
    print(request.ListFields())

    response = stub.HandleReloadConfigRequest(request, 10)
    if response.status.error_code == 0:
        print("reloaded successfully")
    else:
        print("reload error!")
        print(response.status.error_code)
        print(response.status.error_message)
    def initialise_blank_config(self):
        """Initialises the class instance with a blank model configuration"""
        message = model_server_config_pb2.ModelServerConfig(
            model_config_list=model_server_config_pb2.ModelConfigList(
                config=[]))
        self.config: dict = json_format.MessageToDict(message)
        self.config["modelConfigList"]["config"] = []
        self._models: List[dict] = self.config["modelConfigList"]["config"]
    def parse_config_file(self, filepath: str) -> dict:
        """Parses the protobuf message and returns it as a dictionary

        Returns:
            dict: Dict representation of the protobuf message.
        """
        with open(filepath, "r") as file:
            config_file = file.read()
        message_format = model_server_config_pb2.ModelServerConfig()
        message = text_format.Parse(text=config_file, message=message_format)
        self.config: dict = json_format.MessageToDict(message)
        self._models: List[dict] = self.config["modelConfigList"]["config"]
    def to_proto(self) -> ModelServerConfig:
        """Returns the dict representation of the model config file as a
        protobuf message

        Returns:
            ModelServerConfig: Protobuf representation of the config file
        """
        return json_format.Parse(
            json.dumps(self.config),
            message=model_server_config_pb2.ModelServerConfig(),
            ignore_unknown_fields=False,
        )
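A sketch of how these helpers might be chained, assuming they belong to the TFServingModelServerConfig class seen in a later example and that a text-format config file exists at the given path; note that json_format.MessageToDict produces lowerCamelCase keys:

# Hypothetical round trip: text proto -> dict -> mutate -> protobuf message.
cfg = TFServingModelServerConfig()
cfg.parse_config_file("models.conf")
cfg.config["modelConfigList"]["config"].append({
    "name": "new_model",
    "basePath": "/models/new_model",      # json_format uses lowerCamelCase keys
    "modelPlatform": "tensorflow",
})
proto = cfg.to_proto()  # back to a ModelServerConfig message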
Example #14
    def __initialize(self):
        setproctitle.setproctitle('tf-serving')
        tf_serving_config_file = './config/tf_serving.json'
        with open(tf_serving_config_file, 'r') as fp:
            serving_config = json.load(fp)
        grpc_port = serving_config['gprc_port']  
        use_batch = serving_config['use_batch'] 

        options = [('grpc.max_send_message_length', 1024 * 1024 * 1024), ('grpc.max_receive_message_length', 1024 * 1024 * 1024)]
        
 
        model_config_file = './config/tf_serving_model.conf'
        cmd = 'tensorflow_model_server --port={} --rest_api_port={} --model_config_file={}'.format(grpc_port-1, grpc_port, model_config_file)
        if use_batch:
            batch_parameter_file = './config/batching.conf'
            cmd = cmd + ' --enable_batching=true --batching_parameters_file={}'.format(batch_parameter_file)
        self.serving_cmds = []
        self.serving_cmds.append(cmd)
        print(self.serving_cmds)
        system_env = os.environ.copy()
        self.process = subprocess.Popen(cmd, env=system_env, shell=True)  # pass the command string when shell=True

        ## start grpc 
        self.grpc_channel = grpc.insecure_channel('127.0.0.1:{}'.format(grpc_port-1), options=options)
        self.grpc_stub = model_service_pb2_grpc.ModelServiceStub(self.grpc_channel)

        # wait for TF Serving to start
        time.sleep(3)

        # reload model
        model_server_config = model_server_config_pb2.ModelServerConfig() 
        config_list = model_server_config_pb2.ModelConfigList()

        model_config_json_file = './config/model.json'
        with open(model_config_json_file, 'r') as fp:
            model_configs = json.load(fp)
        model_dir = model_configs['model_dir']
        model_list = model_configs['models']
        for model_info in model_list:
            print(model_info)
            model_config = model_server_config_pb2.ModelConfig()
            model_config.name = model_info['name']
            model_config.base_path = os.path.abspath(os.path.join(model_dir, model_info['path']))
            model_config.model_platform = 'tensorflow'
            config_list.config.append(model_config)

        model_server_config.model_config_list.CopyFrom(config_list)
        request = model_management_pb2.ReloadConfigRequest()
        request.config.CopyFrom(model_server_config)

        grpc_response = self.grpc_stub.HandleReloadConfigRequest(request, 30)
        print(grpc_port)
        return 
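For reference, hypothetical contents of the two JSON files this method reads; the keys mirror exactly what the code accesses (including the 'gprc_port' spelling), while the values are placeholders:

# ./config/tf_serving.json (hypothetical):
#   {"gprc_port": 8501, "use_batch": false}
#   -> REST API on 8501, gRPC on 8500 (gprc_port - 1)
# ./config/model.json (hypothetical):
#   {"model_dir": "./models",
#    "models": [{"name": "my_model", "path": "my_model"}]}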
Example #15
    def _load_model_config(
            self, model_config_path: str,
            source: SourceBase) -> model_server_config_pb2.ModelServerConfig:
        """
        returns ModelServerConfig
        """
        # load and parse
        model_config_data = source.load_object(model_config_path)
        models_config = model_server_config_pb2.ModelServerConfig()
        # assumes text-format config data; requires `from google.protobuf import text_format`
        text_format.Parse(model_config_data, models_config)

        return models_config
Example #16
def run():
    channel = grpc.insecure_channel('192.168.199.198:8500')
    stub = model_service_pb2_grpc.ModelServiceStub(channel)
    # message ReloadConfigRequest
    request = model_management_pb2.ReloadConfigRequest()
    model_server_config = model_server_config_pb2.ModelServerConfig()

    # message ModelConfigList
    config_list = model_server_config_pb2.ModelConfigList()

    print(config_list)

    # try to add several model configs
    one_config = config_list.config.add()
    one_config.name = "saved_model_half1"
    one_config.base_path = "/models/saved_model_half_plus_two_cpu"
    one_config.model_platform = "tensorflow"

    one_config = config_list.config.add()
    one_config.name = "saved_model_half2"
    one_config.base_path = "/models/saved_model_half_plus_two_cpu"
    one_config.model_platform = "tensorflow"

    one_config = config_list.config.add()
    one_config.name = "saved_model_half3"
    one_config.base_path = "/models/saved_model_half_plus_two_cpu"
    one_config.model_platform = "tensorflow"

    one_config = config_list.config.add()
    one_config.name = "saved_model_half4"
    one_config.base_path = "/models/saved_model_half_plus_two_cpu"
    one_config.model_platform = "tensorflow"

    # one_config = config_list.config.add()
    # one_config.name = "saved_model_half5"
    # one_config.base_path = "/models/saved_model_half_plus_two_cpu_bak"
    # one_config.model_platform = "tensorflow"

    model_server_config.model_config_list.CopyFrom(config_list)  # one of the 'config' oneof fields

    request.config.CopyFrom(model_server_config)

    print(request.IsInitialized())
    print(request.ListFields())

    response = stub.HandleReloadConfigRequest(request, 10)
    if response.status.error_code == 0:
        print("reload sucessfully")
    else:
        print("reload error!")
        print(response.status.error_code)
        print(response.status.error_message)
Example #17
    def remove_models(
        self,
        model_names: List[str],
        model_versions: List[List[str]],
        timeout: Optional[float] = None,
    ) -> None:
        """
        Remove models to TFS.

        Args:
            model_names: List of model names to add.
            model_versions: List of lists - each element is a list of versions for a given model name.
        Raises:
            grpc.RpcError in case something bad happens while communicating.
                StatusCode.DEADLINE_EXCEEDED when timeout is encountered. StatusCode.UNAVAILABLE when the service is unreachable.
            cortex_internal.lib.exceptions.CortexException if a non-0 response code is returned (i.e. model couldn't be unloaded).
        """

        request = model_management_pb2.ReloadConfigRequest()
        model_server_config = model_server_config_pb2.ModelServerConfig()

        for model_name, versions in zip(model_names, model_versions):
            for model_version in versions:
                self._remove_model_from_dict(model_name, model_version)

        config_list = model_server_config_pb2.ModelConfigList()
        remaining_model_names = self._get_model_names()
        for model_name in remaining_model_names:
            versions, model_disk_path = self._get_model_info(model_name)
            versions = [int(version) for version in versions]
            model_config = config_list.config.add()
            model_config.name = model_name
            model_config.base_path = model_disk_path
            model_config.model_version_policy.CopyFrom(
                ServableVersionPolicy(specific=Specific(versions=versions))
            )
            model_config.model_platform = "tensorflow"

        model_server_config.model_config_list.CopyFrom(config_list)
        request.config.CopyFrom(model_server_config)

        response = self._service.HandleReloadConfigRequest(request, timeout)

        if not (response and response.status.error_code == 0):
            if response:
                raise CortexException(
                    "couldn't unload user-requested models {} - failed with error code {}: {}".format(
                        model_names, response.status.error_code, response.status.error_message
                    )
                )
            else:
                raise CortexException("couldn't unload user-requested models")
Example #18
    def ReloadRequest(self,
                      model_server_address,
                      model_config_file):
        """Send AdminService.Reload request."""
        print('Sending Reload request...')
        # Prepare request
        config = model_server_config_pb2.ModelServerConfig()
        with open(model_config_file, "r") as f:
            text_format.Parse(f.read(), config)
        # Send request
        host, port = model_server_address.split(':')
        channel = implementations.insecure_channel(host, int(port))
        stub = admin_service_pb2.beta_create_AdminService_stub(channel)
        result = stub.Reload(config, RPC_TIMEOUT)
        print('result of reload', result)
    def load(path: str) -> model_server_config_pb2.ModelServerConfig:
        """
        read model config in tf format https://www.tensorflow.org/tfx/serving/serving_config
        """
        logger.info('loading model config: {}'.format(path))

        # read serving config
        with open(path) as f:
            config_str = f.read()

        # parse
        model_server_config = model_server_config_pb2.ModelServerConfig()
        text_format.Parse(config_str, model_server_config)

        return model_server_config
Example #20
def main(_):
    channel = grpc.insecure_channel(':8500')
    stub = model_service_pb2_grpc.ModelServiceStub(channel)
    request = model_management_pb2.ReloadConfigRequest()

    model_server_config = model_server_config_pb2.ModelServerConfig()
    _, _, _, _, _, model_configs_text = build_from_config(os.getcwd())
    text_format.Parse(model_configs_text, model_server_config)
    request.config.CopyFrom(model_server_config)
    response = stub.HandleReloadConfigRequest(request, 10)
    if response.status.error_code == 0:
        print("model updated successfully")
    else:
        print("model update failed")
        print(response.status.error_code)
        print(response.status.error_message)
Example #21
class ModelLoader:
    """
    This class provides the interface to load new models into TensorFlow Serving. This is implemented through a
    gRPC call to the TFS api which triggers it to look for directories matching the name of the model specified
    """
    channel = Channel.service_channel()
    request = model_management_pb2.ReloadConfigRequest()
    model_server_config = model_server_config_pb2.ModelServerConfig()
    conf = model_server_config_pb2.ModelConfigList()

    @classmethod
    def set_config(cls, model_name: str) -> None:
        config = cls.conf.config.add()
        config.name = model_name
        config.base_path = '/models/' + model_name
        config.model_platform = 'tensorflow'
        cls.model_server_config.model_config_list.CopyFrom(cls.conf)
        cls.request.config.CopyFrom(cls.model_server_config)

    @classmethod
    def load(cls, model_name: str) -> None:
        """Load model

        This will send the gRPC request. In particular, it will open a gRPC channel and communicate with the
        ReloadConfigRequest api to inform TFS of a change in configuration

        Parameters
        ----------
        model_name : str
            Name of the model, as specified in the instantiated Learner class

        Returns
        -------
        None
        """
        cls.set_config(model_name)
        log.info(cls.request.IsInitialized())
        log.info(cls.request.ListFields())

        response = cls.channel.HandleReloadConfigRequest(
            cls.request, ServerManager.PREDICTION_TIMEOUT)
        if response.status.error_code == 0:
            p.print_success(f'Loaded model {model_name} successfully')
        else:
            p.print_error(
                f'Loading failed, {response.status.error_code}: {response.status.error_message}'
            )
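A hypothetical call site, assuming TFS can already see a saved model directory under /models/my_model:

ModelLoader.load("my_model")  # asks TFS to begin serving /models/my_model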
Example #22
    def setUp(cls):
        cls.models = [
            {
                "name": "model_A",
                "base_path": "/path/to/model/A/"
            },
            {
                "name": "B",
                "base_path": "/path/to/model/B/"
            },
        ]
        model_config = model_server_config_pb2.ModelServerConfig(
            model_config_list=model_server_config_pb2.ModelConfigList(config=[
                model_server_config_pb2.ModelConfig(name=m["name"],
                                                    base_path=m["base_path"])
                for m in cls.models
            ]))
        cls.model_config_file = tempfile.NamedTemporaryFile()
        with open(cls.model_config_file.name, "w") as file:
            file.write(str(model_config))
        cls.t = TFServingModelServerConfig()
Example #23
def add_model_config(host, name):
    channel = grpc.insecure_channel(host)
    stub = model_service_pb2_grpc.ModelServiceStub(channel)
    request = model_management_pb2.ReloadConfigRequest()
    model_server_config = model_server_config_pb2.ModelServerConfig()

    config_list = model_server_config_pb2.ModelConfigList()
    one_config = config_list.config.add()
    one_config.name = name
    one_config.base_path = os.path.join('/models', name)
    one_config.model_platform = 'tensorflow'

    model_server_config.model_config_list.CopyFrom(config_list)

    request.config.CopyFrom(model_server_config)
    response = stub.HandleReloadConfigRequest(request, 10)

    if response.status.error_code == 0:
        print('Reload successful')
    else:
        print('Reload failed: {}: {}'.format(response.status.error_code,
                                             response.status.error_message))
Example #24
    def __init__(self, model_config_list=None, custom_model_config=None):
        super().__init__(model_server_config_pb2.ModelServerConfig(),
                         model_config_list=model_config_list,
                         custom_model_config=custom_model_config)
Example #25
def getConfigurations():
    channel = grpc.insecure_channel(const.host + ":" + const.port)
    stub = model_service_pb2_grpc.ModelServiceStub(channel)
    model_server_config = model_server_config_pb2.ModelServerConfig()
    return model_server_config
Example #26
    def add_models(
        self,
        model_names: List[str],
        model_versions: List[List[str]],
        model_disk_paths: List[str],
        signature_keys: List[Optional[str]],
        skip_if_present: bool = False,
        timeout: Optional[float] = None,
        max_retries: int = 0,
    ) -> None:
        """
        Add models to TFS. If they can't be loaded, use remove_models to remove them from TFS.

        Args:
            model_names: List of model names to add.
            model_versions: List of lists - each element is a list of versions for a given model name.
            model_disk_paths: The common model disk path of multiple versioned models of the same model name (i.e. modelA/ for modelA/1 and modelA/2).
            skip_if_present: If the models are already loaded, don't make a new request to TFS.
            signature_keys: The signature keys as set in cortex_internal.yaml. If an element is set to None, then "predict" key will be assumed.
            max_retries: How many times to call ReloadConfig before giving up.
        Raises:
            grpc.RpcError in case something bad happens while communicating.
                StatusCode.DEADLINE_EXCEEDED when timeout is encountered. StatusCode.UNAVAILABLE when the service is unreachable.
            cortex_internal.lib.exceptions.CortexException if a non-0 response code is returned (i.e. model couldn't be loaded).
            cortex_internal.lib.exceptions.UserException when a model couldn't be validated for the signature def.
        """

        request = model_management_pb2.ReloadConfigRequest()
        model_server_config = model_server_config_pb2.ModelServerConfig()

        num_added_models = 0
        for model_name, versions, model_disk_path in zip(
            model_names, model_versions, model_disk_paths
        ):
            for model_version in versions:
                versioned_model_disk_path = os.path.join(model_disk_path, model_version)
                num_added_models += self._add_model_to_dict(
                    model_name, model_version, versioned_model_disk_path
                )

        if skip_if_present and num_added_models == 0:
            return

        config_list = model_server_config_pb2.ModelConfigList()
        current_model_names = self._get_model_names()
        for model_name in current_model_names:
            versions, model_disk_path = self._get_model_info(model_name)
            versions = [int(version) for version in versions]
            model_config = config_list.config.add()
            model_config.name = model_name
            model_config.base_path = model_disk_path
            model_config.model_version_policy.CopyFrom(
                ServableVersionPolicy(specific=Specific(versions=versions))
            )
            model_config.model_platform = "tensorflow"

        model_server_config.model_config_list.CopyFrom(config_list)
        request.config.CopyFrom(model_server_config)

        while max_retries >= 0:
            max_retries -= 1
            try:
                # to prevent HandleReloadConfigRequest from
                # throwing an exception (TFS has some race-condition bug)
                time.sleep(0.125)
                response = self._service.HandleReloadConfigRequest(request, timeout)
                break
            except grpc.RpcError:
                # to prevent HandleReloadConfigRequest from
                # throwing another exception on the next run
                time.sleep(0.125)
                # only give up once the retry budget is exhausted
                if max_retries < 0:
                    raise

        if not (response and response.status.error_code == 0):
            if response:
                raise CortexException(
                    "couldn't load user-requested models {} - failed with error code {}: {}".format(
                        model_names, response.status.error_code, response.status.error_message
                    )
                )
            else:
                raise CortexException("couldn't load user-requested models")

        # get models metadata
        for model_name, versions, signature_key in zip(model_names, model_versions, signature_keys):
            for model_version in versions:
                self._load_model_signatures(model_name, model_version, signature_key)
Example #27
    def add_models_config(self, names, base_paths, replace_models=False):
        request = model_management_pb2.ReloadConfigRequest()
        model_server_config = model_server_config_pb2.ModelServerConfig()

        # create model(s) configuration
        config_list = model_server_config_pb2.ModelConfigList()
        for i, name in enumerate(names):
            model_config = config_list.config.add()
            model_config.name = name
            model_config.base_path = base_paths[i]
            model_config.model_platform = self.model_platform

        if replace_models:
            model_server_config.model_config_list.CopyFrom(config_list)
            request.config.CopyFrom(model_server_config)
        else:
            model_server_config.model_config_list.MergeFrom(config_list)
            request.config.MergeFrom(model_server_config)

        loaded_models = threading.Event()

        def log_loading_models():
            while not loaded_models.is_set():
                time.sleep(5)
                cx_logger().warn("model(s) still loading ...")

        log_thread = threading.Thread(target=log_loading_models, daemon=True)
        log_thread.start()

        # request TFS to load models
        limit = 3
        response = None
        for i in range(limit):
            try:
                # this request doesn't return until all models have been successfully loaded
                response = self.stub.HandleReloadConfigRequest(
                    request, self.timeout)
                break
            except Exception as e:
                if not (isinstance(e, grpc.RpcError) and e.code() in [
                        grpc.StatusCode.UNAVAILABLE,
                        grpc.StatusCode.DEADLINE_EXCEEDED
                ]):
                    print(e)  # unexpected error
                time.sleep(1.0)

        loaded_models.set()
        log_thread.join()

        # report error or success
        if response and response.status.error_code == 0:
            cx_logger().info(
                "successfully loaded {} models into TF-Serving".format(names))
        else:
            if response:
                raise CortexException(
                    "couldn't load user-requested models - failed with error code {}: {}"
                    .format(response.status.error_code,
                            response.status.error_message))
            else:
                raise CortexException("couldn't load user-requested models")
# datetime:1993/12/01
# filename:grpc_client_hot_deploy.py
# software: PyCharm

from google.protobuf import text_format
from tensorflow_serving.apis import model_management_pb2
from tensorflow_serving.apis import model_service_pb2_grpc
from tensorflow_serving.config import model_server_config_pb2
import grpc

channel = grpc.insecure_channel('ip')

config_file = "model.config"
stub = model_service_pb2_grpc.ModelServiceStub(channel)
request = model_management_pb2.ReloadConfigRequest()

# read config file
config_content = open(config_file, "r").read()
model_server_config = model_server_config_pb2.ModelServerConfig()
model_server_config = text_format.Parse(text=config_content,
                                        message=model_server_config)
request.config.CopyFrom(model_server_config)
request_response = stub.HandleReloadConfigRequest(request, 10)

if request_response.status.error_code == 0:
    open(config_file, "w").write(str(request.config))
    print("TF Serving config file updated.")
else:
    print("Failed to update config file.")
    print(request_response.status.error_code)
    print(request_response.status.error_message)