Ejemplo n.º 1
0
    def collect(self):
        """Yield one gauge family describing the state of each configured model.

        For every name in ``self.model_name`` a GetModelStatus RPC is issued
        (asynchronously via ``.future`` with ``self.timeout``) and awaited.
        Each model version in a successful response contributes one sample
        labelled ``(model_name, model_version)`` with value
        ``int(model_available(state))``. Models whose RPC fails are logged and
        skipped, so one failing model does not hide the others.

        Yields:
            The single ``tf_serving_model_state`` GaugeMetricFamily, exactly once.
        """
        metric = GaugeMetricFamily(name='tf_serving_model_state',
                                   documentation='model state on tf_serving',
                                   labels=['model_name', 'model_version'])

        for n in self.model_name:
            # create request
            request = get_model_status_pb2.GetModelStatusRequest()
            request.model_spec.name = n
            try:
                result_future = self.stub.GetModelStatus.future(
                    request, self.timeout)
                model_version_status = result_future.result(
                ).model_version_status
            except AbortionError as e:
                logging.exception(
                    'AbortionError on GetModelStatus of %s: %s', n, e.details)
            except Exception as e:
                # Python 3 exceptions have no ``.message`` attribute; reading it
                # here raised AttributeError and aborted the whole collection.
                logging.exception(
                    'Exception on GetModelStatus of %s: %s', n, e)
            else:
                # success to connect to serving
                for model in model_version_status:
                    metric.add_metric(labels=[n, str(model.version)],
                                      value=int(model_available(model.state)))
                    logging.debug(
                        'Add metric: name:%s, version:%s, state:%s',
                        n, model.version, model.state)

        # Yield the family once, after every sample has been added. Yielding it
        # inside the loop emitted the same family (and its samples) repeatedly.
        yield metric
def main(_):
    """Query or reconfigure TF Serving, or publish a config to ZooKeeper.

    Behaviour depends on ``FLAGS.mode``:

    * ``MODE.STATUS``: send a GetModelStatus request for model
      ``'detection'`` to every address in ``FLAGS.addresses``.
    * ``MODE.CONFIG``: send a ReloadConfigRequest listing ``'detection'``
      (pinned to versions 5 and 7) and ``'pascal'`` to every address.
    * ``MODE.ZOOKEEPER``: write the detection config to ``/serving/cunan``
      on the hard-coded ZooKeeper host, then return.

    Every RPC response is printed.

    Raises:
        ValueError: if ``FLAGS.mode`` is none of the modes above.
    """
    if MODE.STATUS == FLAGS.mode:
        request = get_model_status_pb2.GetModelStatusRequest()
        request.model_spec.name = 'detection'
        request.model_spec.signature_name = 'serving_default'
    elif MODE.CONFIG == FLAGS.mode:
        request = model_management_pb2.ReloadConfigRequest()
        config = request.config.model_config_list.config.add()
        config.name = 'detection'
        config.base_path = '/models/detection/detection'
        config.model_platform = 'tensorflow'
        config.model_version_policy.specific.versions.append(5)
        config.model_version_policy.specific.versions.append(7)
        config2 = request.config.model_config_list.config.add()
        config2.name = 'pascal'
        config2.base_path = '/models/detection/pascal'
        config2.model_platform = 'tensorflow'
    elif MODE.ZOOKEEPER == FLAGS.mode:
        zk = KazooClient(hosts="10.10.67.225:2181")
        zk.start()
        try:
            zk.ensure_path('/serving/cunan')
            zk.set(
                '/serving/cunan',
                get_config('detection', 5, 224, 'serving_default',
                           ','.join(get_classes('model_data/cci.names')),
                           "10.12.102.32:8000"))
        finally:
            # Always end the ZooKeeper session and free client resources,
            # even if the write fails.
            zk.stop()
            zk.close()
        return
    else:
        # Fail clearly instead of later hitting an unbound ``result``.
        raise ValueError('Unsupported mode: {}'.format(FLAGS.mode))

    for address in FLAGS.addresses:
        channel = grpc.insecure_channel(address)
        try:
            stub = model_service_pb2_grpc.ModelServiceStub(channel)
            if MODE.STATUS == FLAGS.mode:
                result = stub.GetModelStatus(request)
            else:
                result = stub.HandleReloadConfigRequest(request)
        finally:
            # One channel per address; release it before moving on.
            channel.close()
        print(result)
Ejemplo n.º 3
0
    def _GetModelStatus(self) -> get_model_status_pb2.GetModelStatusResponse:
        """Fetch the status of ``self._model_name`` from the model service.

        Issues a single ModelService.GetModelStatus RPC, see
        https://github.com/tensorflow/serving/blob/master/tensorflow_serving/apis/model_service.proto

        Returns:
          The GetModelStatusResponse produced by the RPC.
        """
        spec = model_pb2.ModelSpec(name=self._model_name)
        status_request = get_model_status_pb2.GetModelStatusRequest(
            model_spec=spec)
        return self._model_service.GetModelStatus(status_request)
Ejemplo n.º 4
0
    def is_tfs_accessible(self) -> bool:
        """
        Report whether the TFS model service responds to RPCs.

        Sends a status query for a placeholder model name with a 10 second
        timeout. Only UNAVAILABLE and DEADLINE_EXCEEDED count as "not
        accessible"; a successful reply or any other gRPC error code means
        the server answered.
        """
        probe = get_model_status_pb2.GetModelStatusRequest()
        probe.model_spec.name = "test-model-name"

        try:
            self._service.GetModelStatus(probe, timeout=10.0)
        except grpc.RpcError as rpc_err:
            unreachable = (grpc.StatusCode.UNAVAILABLE, grpc.StatusCode.DEADLINE_EXCEEDED)
            return rpc_err.code() not in unreachable
        return True
    def health_check(self, name, signature_name, version=None):
        """
        Check whether the serving endpoint reports any version of a model.

        Args:
            name: Model name placed in the request's model spec.
            signature_name: Signature name placed in the request's model spec.
            version: Optional specific model version; when ``None`` no version
                is set on the request.

        Returns:
            True if the GetModelStatus response lists at least one model
            version; False if it lists none or the RPC raises (the exception
            is logged).
        """
        request = get_model_status_pb2.GetModelStatusRequest()
        request.model_spec.name = name
        request.model_spec.signature_name = signature_name
        # ``is not None`` so an explicit version 0 is still sent.
        if version is not None:
            request.model_spec.version.value = version

        stub = model_service_pb2_grpc.ModelServiceStub(self.channel)
        try:
            response = stub.GetModelStatus(request, 10)
        except Exception as err:
            # A health check reports failure rather than propagating errors.
            logging.exception(err)
            return False
        # Always return a bool; previously an empty status list yielded None.
        return len(response.model_version_status) > 0
Ejemplo n.º 6
0
    def testGetModelStatus(self):
        """ModelService.GetModelStatus reports one OK version (123) of 'default'."""
        bundle_path = self._GetSavedModelBundlePath()
        _, server_address = TensorflowModelServerTest.RunServer(
            'default', bundle_path)[:2]

        print('Sending GetModelStatus request...')
        status_request = get_model_status_pb2.GetModelStatusRequest()
        status_request.model_spec.name = 'default'
        grpc_channel = grpc.insecure_channel(server_address)
        model_stub = model_service_pb2_grpc.ModelServiceStub(grpc_channel)
        # RPC_TIMEOUT is the deadline (seconds) for this call.
        response = model_stub.GetModelStatus(status_request, RPC_TIMEOUT)

        statuses = response.model_version_status
        self.assertEqual(1, len(statuses))
        only_status = statuses[0]
        self.assertEqual(123, only_status.version)
        # error_code 0 is OK: the version reported no error.
        self.assertEqual(0, only_status.status.error_code)
Ejemplo n.º 7
0
def main(_):
    """Query or reload models on the TF Serving instance at ``FLAGS.address``.

    ``MODE.STATUS`` prints the GetModelStatus response for model ``'pascal'``
    (signature ``'serving_default'``). ``MODE.CONFIG`` sends a
    ReloadConfigRequest listing models ``'detection'`` and ``'pascal'`` and
    prints the response.

    Raises:
        ValueError: if ``FLAGS.mode`` is neither ``MODE.STATUS`` nor
            ``MODE.CONFIG``.
    """
    channel = grpc.insecure_channel(FLAGS.address)
    try:
        stub = model_service_pb2_grpc.ModelServiceStub(channel)
        if MODE.STATUS == FLAGS.mode:
            request = get_model_status_pb2.GetModelStatusRequest()
            request.model_spec.name = 'pascal'
            request.model_spec.signature_name = 'serving_default'
            result = stub.GetModelStatus(request)
        elif MODE.CONFIG == FLAGS.mode:
            request = model_management_pb2.ReloadConfigRequest()
            config = request.config.model_config_list.config.add()
            config.name = 'detection'
            config.base_path = '/models/detection/detection'
            config.model_platform = 'tensorflow'
            config2 = request.config.model_config_list.config.add()
            config2.name = 'pascal'
            config2.base_path = '/models/detection/pascal'
            config2.model_platform = 'tensorflow'
            result = stub.HandleReloadConfigRequest(request)
        else:
            # Without this, an unknown mode crashed later on an unbound ``result``.
            raise ValueError('Unsupported mode: {}'.format(FLAGS.mode))
    finally:
        channel.close()

    print(result)
Ejemplo n.º 8
0
    def poll_available_model_versions(self, model_name: str) -> List[str]:
        """
        Gets the versions of a model that TFS reports in the AVAILABLE state.

        Args:
            model_name: The model name to check for versions.

        Returns:
            The AVAILABLE versions for the given model, as strings, in the
            order TFS lists them. Empty if the GetModelStatus RPC fails.
        """
        request = get_model_status_pb2.GetModelStatusRequest()
        request.model_spec.name = model_name

        try:
            statuses = self._service.GetModelStatus(request).model_version_status
        except grpc.RpcError as err:
            # Best-effort by design: a failed status query is treated as "no
            # available versions", but record the reason instead of hiding it.
            import logging
            logging.getLogger(__name__).debug(
                "GetModelStatus for model %r failed: %s", model_name, err)
            return []

        available = get_model_status_pb2.ModelVersionStatus.AVAILABLE
        return [str(status.version) for status in statuses if status.state == available]
Ejemplo n.º 9
0
def prepare_stub_and_request(address,
                             model_name,
                             model_version=None,
                             creds=None,
                             opts=None,
                             request_type=INFERENCE_REQUEST):
    """Build a gRPC stub and a request pre-filled with a model spec.

    Args:
        address: Target passed to ``grpc.secure_channel``/``insecure_channel``.
        model_name: Written to ``request.model_spec.name``.
        model_version: Optional; written to ``request.model_spec.version.value``
            when not None.
        creds: Optional channel credentials. When given, a secure channel is
            opened; otherwise an insecure one.
        opts: Optional value for the ``grpc.ssl_target_name_override`` channel
            option.
        request_type: ``MODEL_STATUS_REQUEST`` for a ModelService stub with a
            GetModelStatusRequest, or ``INFERENCE_REQUEST`` for a
            PredictionService stub with a PredictRequest.

    Returns:
        Tuple ``(stub, request)``.

    Raises:
        ValueError: if ``request_type`` is not one of the supported values.
            Previously this crashed with AttributeError on a ``None`` request.
    """
    # Validate up front so no channel is opened for an invalid request type.
    if request_type != MODEL_STATUS_REQUEST and request_type != INFERENCE_REQUEST:
        raise ValueError('Unsupported request_type: {!r}'.format(request_type))

    if opts is not None:
        opts = (('grpc.ssl_target_name_override', opts), )
    if creds is not None:
        channel = grpc.secure_channel(address, creds, options=opts)
    else:
        channel = grpc.insecure_channel(address, options=opts)

    if request_type == MODEL_STATUS_REQUEST:
        request = get_model_status_pb2.GetModelStatusRequest()
        stub = model_service_pb2_grpc.ModelServiceStub(channel)
    else:
        stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
        request = predict_pb2.PredictRequest()

    request.model_spec.name = model_name
    if model_version is not None:
        request.model_spec.version.value = model_version
    return stub, request
Ejemplo n.º 10
0
def get_fake_model_status_request(model_name, version=None):
    """Return a GetModelStatusRequest targeting ``model_name``.

    When ``version`` is not None it is stored in ``model_spec.version.value``;
    otherwise the version field is left untouched.
    """
    fake_request = get_model_status_pb2.GetModelStatusRequest()
    spec = fake_request.model_spec
    spec.name = model_name
    if version is None:
        return fake_request
    spec.version.value = version
    return fake_request
Ejemplo n.º 11
0
                    required=False,
                    default=9000,
                    help='Specify port to grpc service. default: 9000')
parser.add_argument('--model_name',
                    default='resnet',
                    help='Model name to query. default: resnet',
                    dest='model_name')
parser.add_argument(
    '--model_version',
    type=int,
    help='Model version to query. Lists all versions if omitted',
    dest='model_version')
args = vars(parser.parse_args())

channel = grpc.insecure_channel("{}:{}".format(args['grpc_address'],
                                               args['grpc_port']))

stub = model_service_pb2_grpc.ModelServiceStub(channel)

print('Getting model status for model:', args.get('model_name'))

request = get_model_status_pb2.GetModelStatusRequest()
request.model_spec.name = args.get('model_name')
# Without an explicit version, the request asks about all versions.
if args.get('model_version') is not None:
    request.model_spec.version.value = args.get('model_version')

try:
    # 10 s deadline. The result is a GetModelStatusResponse describing the
    # model's version status entries, not model outputs.
    result = stub.GetModelStatus(request, 10.0)
finally:
    channel.close()

print_status_response(response=result)
Ejemplo n.º 12
0
 def __init__(self, model_spec=None):
     """Initialise the base class with a fresh, empty GetModelStatusRequest.

     ``model_spec`` is forwarded unchanged to the base-class constructor.
     """
     status_request = get_model_status_pb2.GetModelStatusRequest()
     super().__init__(status_request, model_spec=model_spec)