Example #1
def model_metadata_response(response):
    signature_def = response.metadata['signature_def']
    signature_map = get_model_metadata_pb2.SignatureDefMap()
    signature_map.ParseFromString(signature_def.value)
    serving_default = signature_map.ListFields()[0][1]['serving_default']
    serving_inputs = serving_default.inputs
    input_blobs_keys = {key: {} for key in serving_inputs.keys()}
    tensor_shape = {
        key: serving_inputs[key].tensor_shape
        for key in serving_inputs.keys()
    }
    for input_blob in input_blobs_keys:
        inputs_shape = [d.size for d in tensor_shape[input_blob].dim]
        tensor_dtype = serving_inputs[input_blob].dtype
        input_blobs_keys[input_blob].update({'shape': inputs_shape})
        input_blobs_keys[input_blob].update({'dtype': tensor_dtype})

    serving_outputs = serving_default.outputs
    output_blobs_keys = {key: {} for key in serving_outputs.keys()}
    tensor_shape = {
        key: serving_outputs[key].tensor_shape
        for key in serving_outputs.keys()
    }
    for output_blob in output_blobs_keys:
        outputs_shape = [d.size for d in tensor_shape[output_blob].dim]
        tensor_dtype = serving_outputs[output_blob].dtype
        output_blobs_keys[output_blob].update({'shape': outputs_shape})
        output_blobs_keys[output_blob].update({'dtype': tensor_dtype})

    return input_blobs_keys, output_blobs_keys
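Example #1 assumes a GetModelMetadataResponse has already been obtained from a running TensorFlow Serving instance. A minimal sketch of how such a response might be fetched, following the channel, stub, and request construction shown in the later examples; the host, port, and model name below are placeholder assumptions, not values from the original project:

import grpc
from tensorflow_serving.apis import get_model_metadata_pb2, prediction_service_pb2_grpc

# Placeholder endpoint and model name -- adjust for the actual serving deployment.
channel = grpc.insecure_channel("localhost:8500")
stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)

# Ask TF Serving for the model's signature_def metadata.
request = get_model_metadata_pb2.GetModelMetadataRequest()
request.model_spec.name = "my_model"
request.metadata_field.append("signature_def")

response = stub.GetModelMetadata(request, timeout=10.0)
inputs, outputs = model_metadata_response(response)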
Example #2
    def cache_prediction_metadata(self):
        channel = implementations.insecure_channel(self.host,
                                                   self.tf_serving_port)
        stub = prediction_service_pb2.beta_create_PredictionService_stub(
            channel)
        request = get_model_metadata_pb2.GetModelMetadataRequest()

        request.model_spec.name = self.model_name
        request.metadata_field.append('signature_def')
        result = stub.GetModelMetadata(request, self.request_timeout)

        _logger.info(
            '---------------------------Model Spec---------------------------')
        _logger.info(json_format.MessageToJson(result))
        _logger.info(
            '----------------------------------------------------------------')

        signature_def = result.metadata['signature_def']
        signature_map = get_model_metadata_pb2.SignatureDefMap()
        signature_map.ParseFromString(signature_def.value)

        serving_default = signature_map.ListFields()[0][1]['serving_default']
        serving_inputs = serving_default.inputs

        self.input_type_map = {
            key: serving_inputs[key].dtype
            for key in serving_inputs.keys()
        }
        self.prediction_type = serving_default.method_name
Example #3
    def cache_prediction_metadata(self):
        channel = grpc.insecure_channel('{}:{}'.format(self.host,
                                                       self.tf_serving_port),
                                        options=[
                                            ('grpc.max_send_message_length',
                                             MAX_GRPC_MESSAGE_SIZE),
                                            ('grpc.max_receive_message_length',
                                             MAX_GRPC_MESSAGE_SIZE)
                                        ])
        stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
        request = get_model_metadata_pb2.GetModelMetadataRequest()

        request.model_spec.name = self.model_name
        request.metadata_field.append('signature_def')
        result = stub.GetModelMetadata(request, self.request_timeout)

        _logger.info(
            '---------------------------Model Spec---------------------------')
        _logger.info(json_format.MessageToJson(result))
        _logger.info(
            '----------------------------------------------------------------')

        signature_def = result.metadata['signature_def']
        signature_map = get_model_metadata_pb2.SignatureDefMap()
        signature_map.ParseFromString(signature_def.value)

        serving_default = signature_map.ListFields()[0][1]['serving_default']
        serving_inputs = serving_default.inputs

        self.input_type_map = {
            key: serving_inputs[key].dtype
            for key in serving_inputs.keys()
        }
        self.prediction_type = serving_default.method_name
        self.prediction_service_stub = stub
Example #4
def run_get_model_metadata():
    request = create_get_model_metadata_request()
    resp = local_cache["stub"].GetModelMetadata(request, timeout=10.0)
    sigAny = resp.metadata["signature_def"]
    signature_def_map = get_model_metadata_pb2.SignatureDefMap()
    sigAny.Unpack(signature_def_map)
    sigmap = json_format.MessageToDict(signature_def_map)
    return sigmap
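Examples #4, #10, #11, and #13 rely on a create_get_model_metadata_request() helper that is not reproduced here. Based on the request construction shown in Examples #5 and #6, it presumably looks something like the sketch below; the default model name argument is an assumption:

from tensorflow_serving.apis import get_model_metadata_pb2

def create_get_model_metadata_request(model_name="model"):
    # Build a GetModelMetadataRequest asking for the model's signature_def.
    request = get_model_metadata_pb2.GetModelMetadataRequest()
    request.model_spec.name = model_name
    request.metadata_field.append("signature_def")
    return request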
Example #5
    def _load_model_signatures(
        self, model_name: str, model_version: str, signature_key: Optional[str] = None
    ) -> None:
        """
        Queries the signature defs from TFS.

        Args:
            model_name: Name of the model.
            model_version: Version of the model.
            signature_key: Signature key of the model as passed in with predictor:signature_key, predictor:models:paths:signature_key or predictor:models:signature_key.
                When set to None, "predict" is the assumed key.

        Raises:
            cortex_internal.lib.exceptions.UserException when the signature def can't be validated.
        """

        # create model metadata request
        request = get_model_metadata_pb2.GetModelMetadataRequest()
        request.model_spec.name = model_name
        request.model_spec.version.value = int(model_version)
        request.metadata_field.append("signature_def")

        # get signature def
        last_idx = 0
        for times in range(100):
            try:
                resp = self._pred.GetModelMetadata(request)
                break
            except grpc.RpcError as e:
                # it has been observed that it can take a little while
                # before a model becomes accessible with TFS (even though it is already loaded in)
                time.sleep(0.3)
            last_idx = times
        if last_idx == 99:
            raise UserException(
                "couldn't find model '{}' of version '{}' to extract the signature def".format(
                    model_name, model_version
                )
            )

        sigAny = resp.metadata["signature_def"]
        signature_def_map = get_model_metadata_pb2.SignatureDefMap()
        sigAny.Unpack(signature_def_map)
        sigmap = json_format.MessageToDict(signature_def_map)
        signature_def = sigmap["signatureDef"]

        # extract signature key and input signature
        signature_key, input_signatures = self._extract_signatures(
            signature_def, signature_key, model_name, model_version
        )

        model_id = f"{model_name}-{model_version}"
        self.models[model_id]["signature_def"] = signature_def
        self.models[model_id]["signature_key"] = signature_key
        self.models[model_id]["input_signatures"] = input_signatures
Example #6
    def get_metadata(self, model_name, signature_name, timeout):
        field = 'signature_def'
        request = get_model_metadata_pb2.GetModelMetadataRequest()
        request.model_spec.name = model_name
        request.metadata_field.append(field)
        response = self.stub.GetModelMetadata(request, timeout)
        print(response.model_spec)
        raw_value = response.metadata[field].value
        signature_map = get_model_metadata_pb2.SignatureDefMap()
        signature_map.MergeFromString(raw_value)
        print(signature_map.signature_def[signature_name])
Example #7
    def get_io(self, sub_network):
        metadata_request = get_model_metadata_pb2.GetModelMetadataRequest()
        metadata_request.model_spec.name = sub_network
        metadata_request.metadata_field.append("signature_def")
        result = self.prediction_service.GetModelMetadata(metadata_request, self.timeout)

        signature_def_map = get_model_metadata_pb2.SignatureDefMap()
        result.metadata['signature_def'].Unpack(signature_def_map)
        default_signature_def = signature_def_map.signature_def['serving_default']
        return (list(default_signature_def.inputs),
                [(output_name, [dim.size for dim in metadata.tensor_shape.dim])
                 for output_name, metadata in sorted(default_signature_def.outputs.items(),
                                                     key=lambda output: output[1].name)])
Example #8
    def GetModelMetadata(self, request, context):

        # check whether the requested model is available
        # on the server with the proper version
        logger.debug("MODEL_METADATA, get request: {}".format(request))
        model_name = request.model_spec.name
        requested_version = request.model_spec.version.value
        valid_model_spec, version = check_availability_of_requested_model(
            models=self.models,
            requested_version=requested_version,
            model_name=model_name)

        if not valid_model_spec:
            context.set_code(StatusCode.NOT_FOUND)
            context.set_details(
                WRONG_MODEL_SPEC.format(model_name, requested_version))
            logger.debug("MODEL_METADATA, invalid model spec from request")
            return get_model_metadata_pb2.GetModelMetadataResponse()
        target_engine = self.models[model_name].engines[version]
        target_engine.in_use.acquire()
        metadata_signature_requested = request.metadata_field[0]
        if 'signature_def' != metadata_signature_requested:
            context.set_code(StatusCode.INVALID_ARGUMENT)
            context.set_details(
                INVALID_METADATA_FIELD.format(metadata_signature_requested))
            logger.debug("MODEL_METADATA, invalid signature def")
            target_engine.in_use.release()
            return get_model_metadata_pb2.GetModelMetadataResponse()

        inputs = target_engine.net.inputs
        outputs = target_engine.net.outputs

        signature_def = prepare_get_metadata_output(
            inputs=inputs,
            outputs=outputs,
            model_keys=target_engine.model_keys)
        response = get_model_metadata_pb2.GetModelMetadataResponse()

        model_data_map = get_model_metadata_pb2.SignatureDefMap()
        model_data_map.signature_def['serving_default'].CopyFrom(signature_def)
        response.metadata['signature_def'].Pack(model_data_map)
        response.model_spec.name = model_name
        response.model_spec.version.value = version
        logger.debug("MODEL_METADATA created a response for {} - {}".format(
            model_name, version))
        target_engine.in_use.release()
        return response
Example #9
def get_signature_map(model_server_stub, model_name):
    """
  Gets tensorflow signature map from the model server stub.

  Args:
    model_server_stub: The grpc stub to call GetModelMetadata.
    model_name: The model name.

  Returns:
    The signature map of the model.
  """
    request = get_model_metadata_pb2.GetModelMetadataRequest()
    request.model_spec.name = model_name
    request.metadata_field.append("signature_def")
    try:
        response = model_server_stub.GetModelMetadata(
            request, MODEL_SERVER_METADATA_TIMEOUT_SEC)
    except grpc.RpcError as rpc_error:
        logging.exception(
            "GetModelMetadata call to model server failed with code "
            "%s and message %s", rpc_error.code(), rpc_error.details())
        return None

    signature_def_map_proto = get_model_metadata_pb2.SignatureDefMap()
    response.metadata["signature_def"].Unpack(signature_def_map_proto)
    signature_def_map = signature_def_map_proto.signature_def
    if not signature_def_map:
        logging.error("Graph has no signatures.")

    # Delete incomplete signatures without input dtypes.
    invalid_signatures = []
    for signature_name in signature_def_map:
        for tensor in signature_def_map[signature_name].inputs.itervalues():
            if not tensor.dtype:
                logging.warn(
                    "Signature %s has incomplete dtypes, removing from "
                    "usable signatures", signature_name)
                invalid_signatures.append(signature_name)
                break
    for signature_name in invalid_signatures:
        del signature_def_map[signature_name]

    return signature_def_map
Example #10
def get_signature_def(stub):
    limit = 60
    for i in range(limit):
        try:
            request = create_get_model_metadata_request()
            resp = stub.GetModelMetadata(request, timeout=10.0)
            sigAny = resp.metadata["signature_def"]
            signature_def_map = get_model_metadata_pb2.SignatureDefMap()
            sigAny.Unpack(signature_def_map)
            sigmap = json_format.MessageToDict(signature_def_map)
            return sigmap["signatureDef"]
        except:
            if i > 6:
                cx_logger().warn(
                    "unable to read model metadata - model is still loading, retrying..."
                )

        time.sleep(5)

    raise CortexException("timeout: unable to read model metadata")
Example #11
def get_signature_def(stub, model):
    limit = 2
    for i in range(limit):
        try:
            request = create_get_model_metadata_request(model.name)
            resp = stub.GetModelMetadata(request, timeout=10.0)
            sigAny = resp.metadata["signature_def"]
            signature_def_map = get_model_metadata_pb2.SignatureDefMap()
            sigAny.Unpack(signature_def_map)
            sigmap = json_format.MessageToDict(signature_def_map)
            return sigmap["signatureDef"]
        except Exception as e:
            print(e)
            cx_logger().warn(
                "unable to read model metadata for model '{}' - retrying ...".
                format(model.name))

        time.sleep(5)

    raise CortexException(
        "timeout: unable to read model metadata for model '{}'".format(
            model.name))
Example #12
    def on_get(self, req, resp, model_name, requested_version=0):
        logger.debug("MODEL_METADATA, get request: {}, {}".format(
            model_name, requested_version))
        valid_model_spec, version = check_availability_of_requested_model(
            models=self.models,
            requested_version=requested_version,
            model_name=model_name)

        if not valid_model_spec:
            resp.status = falcon.HTTP_NOT_FOUND
            logger.debug("MODEL_METADATA, invalid model spec from request")
            err_out_json = {
                'error': WRONG_MODEL_SPEC.format(model_name, requested_version)
            }
            resp.body = json.dumps(err_out_json)
            return

        target_engine = self.models[model_name].engines[version]
        target_engine.in_use.acquire()

        inputs = target_engine.net.inputs
        outputs = target_engine.net.outputs

        signature_def = prepare_get_metadata_output(
            inputs=inputs,
            outputs=outputs,
            model_keys=target_engine.model_keys)
        response = get_model_metadata_pb2.GetModelMetadataResponse()

        model_data_map = get_model_metadata_pb2.SignatureDefMap()
        model_data_map.signature_def['serving_default'].CopyFrom(signature_def)
        response.metadata['signature_def'].Pack(model_data_map)
        response.model_spec.name = model_name
        response.model_spec.version.value = version
        logger.debug("MODEL_METADATA created a response for {} - {}".format(
            model_name, version))
        target_engine.in_use.release()
        resp.status = falcon.HTTP_200
        resp.body = MessageToJson(response)
Example #13
def get_signature_def(stub):
    limit = 60
    for i in range(limit):
        try:
            request = create_get_model_metadata_request()
            resp = stub.GetModelMetadata(request, timeout=10.0)
            sigAny = resp.metadata["signature_def"]
            signature_def_map = get_model_metadata_pb2.SignatureDefMap()
            sigAny.Unpack(signature_def_map)
            sigmap = json_format.MessageToDict(signature_def_map)
            return sigmap["signatureDef"]
        except Exception as e:
            if isinstance(e, grpc.RpcError) and e.code() == grpc.StatusCode.UNAVAILABLE:
                if i > 6:  # only start logging this after 30 seconds
                    cx_logger().warn(
                        "unable to read model metadata - model is still loading, retrying..."
                    )
            else:
                print(e)  # unexpected error
                cx_logger().warn("unable to read model metadata - retrying...")

        time.sleep(5)

    raise CortexException("timeout: unable to read model metadata")