Code Example #1
    def Predict(self, request, context):
        """
        Predict -- provides access to loaded TensorFlow model.
        """
        # check if requested model
        # is available on server with proper version
        model_name = request.model_spec.name
        requested_version = request.model_spec.version.value
        valid_model_spec, version = check_availability_of_requested_model(
            models=self.models, requested_version=requested_version,
            model_name=model_name)

        if not valid_model_spec:
            context.set_code(StatusCode.NOT_FOUND)
            context.set_details(WRONG_MODEL_METADATA.format(model_name,
                                                            requested_version))
            logger.debug("PREDICT, invalid model spec from request, {} - {}"
                         .format(model_name, requested_version))
            return predict_pb2.PredictResponse()
        start_time = datetime.datetime.now()
        occurred_problem, inference_input, batch_size, code = \
            prepare_input_data(models=self.models, model_name=model_name,
                               version=version, data=request.inputs)
        deserialization_end_time = datetime.datetime.now()
        duration = (deserialization_end_time - start_time)\
            .total_seconds() * 1000
        logger.debug("PREDICT; input deserialization completed; {}; {}; {}ms"
                     .format(model_name, version, duration))
        # on failure, prepare_input_data returns the error message in
        # inference_input together with a matching status code
        if occurred_problem:
            context.set_code(code)
            context.set_details(inference_input)
            logger.debug("PREDICT, problem with input data. Exit code {}"
                         .format(code))
            return predict_pb2.PredictResponse()

        inference_start_time = datetime.datetime.now()
        inference_output = self.models[model_name].engines[version] \
            .infer(inference_input, batch_size)
        inference_end_time = datetime.datetime.now()
        duration = (inference_end_time - inference_start_time)\
            .total_seconds() * 1000
        logger.debug("PREDICT; inference execution completed; {}; {}; {}ms"
                     .format(model_name, version, duration))
        response = prepare_output_as_list(
            inference_output=inference_output,
            model_available_outputs=self.models[model_name]
            .engines[version].model_keys['outputs'])
        response.model_spec.name = model_name
        response.model_spec.version.value = version
        response.model_spec.signature_name = SIGNATURE_NAME
        serialization_end_time = datetime.datetime.now()
        duration = (serialization_end_time - inference_end_time)\
            .total_seconds() * 1000
        logger.debug("PREDICT; inference results serialization completed;"
                     " {}; {}; {}ms".format(model_name, version, duration))
        return response
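Every handler in these examples gates on check_availability_of_requested_model, whose body is not shown. A minimal sketch consistent with how it is called here, and with the expectations in Code Example #4 below (a requested version of 0 resolves to the model's default version; unknown names or versions fail validation), might look like this:

# Hypothetical sketch of the availability check used by every handler
# here; the real implementation lives in the serving codebase. It assumes
# version 0 means "use the default version", matching Code Example #4.
def check_availability_of_requested_model(models, model_name,
                                          requested_version):
    if model_name not in models:
        return False, 0           # unknown model name
    model = models[model_name]
    if requested_version == 0:
        return True, model.default_version
    if requested_version in model.versions:
        return True, requested_version
    return False, 0               # version not served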
Code Example #2
    def GetModelMetadata(self, request, context):

        # check if requested model
        # is available on server with proper version
        logger.debug("MODEL_METADATA, get request: {}".format(request))
        model_name = request.model_spec.name
        requested_version = request.model_spec.version.value
        valid_model_spec, version = check_availability_of_requested_model(
            models=self.models,
            requested_version=requested_version,
            model_name=model_name)

        if not valid_model_spec:
            context.set_code(StatusCode.NOT_FOUND)
            context.set_details(
                WRONG_MODEL_SPEC.format(model_name, requested_version))
            logger.debug("MODEL_METADATA, invalid model spec from request")
            return get_model_metadata_pb2.GetModelMetadataResponse()
        target_engine = self.models[model_name].engines[version]
        target_engine.in_use.acquire()
        metadata_signature_requested = request.metadata_field[0]
        if metadata_signature_requested != 'signature_def':
            context.set_code(StatusCode.INVALID_ARGUMENT)
            context.set_details(
                INVALID_METADATA_FIELD.format(metadata_signature_requested))
            logger.debug("MODEL_METADATA, invalid signature def")
            target_engine.in_use.release()
            return get_model_metadata_pb2.GetModelMetadataResponse()

        inputs = target_engine.net.inputs
        outputs = target_engine.net.outputs

        signature_def = prepare_get_metadata_output(
            inputs=inputs,
            outputs=outputs,
            model_keys=target_engine.model_keys)
        response = get_model_metadata_pb2.GetModelMetadataResponse()

        model_data_map = get_model_metadata_pb2.SignatureDefMap()
        model_data_map.signature_def['serving_default'].CopyFrom(signature_def)
        response.metadata['signature_def'].Pack(model_data_map)
        response.model_spec.name = model_name
        response.model_spec.version.value = version
        logger.debug("MODEL_METADATA created a response for {} - {}".format(
            model_name, version))
        target_engine.in_use.release()
        return response
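For context, a client call that exercises this handler could look as follows; the endpoint address and model name are assumptions, and the standard TensorFlow Serving gRPC API is used:

import grpc
from tensorflow_serving.apis import get_model_metadata_pb2
from tensorflow_serving.apis import prediction_service_pb2_grpc

# Hypothetical client; 'localhost:9000' and 'resnet' are placeholders.
channel = grpc.insecure_channel('localhost:9000')
stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
request = get_model_metadata_pb2.GetModelMetadataRequest()
request.model_spec.name = 'resnet'
# The handler above rejects any metadata field other than 'signature_def'.
request.metadata_field.append('signature_def')
response = stub.GetModelMetadata(request, timeout=10.0)
print(response.model_spec.name, response.model_spec.version.value)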
Code Example #3
    def on_get(self, req, resp, model_name, requested_version=0):
        logger.debug("MODEL_METADATA, get request: {}, {}".format(
            model_name, requested_version))
        valid_model_spec, version = check_availability_of_requested_model(
            models=self.models,
            requested_version=requested_version,
            model_name=model_name)

        if not valid_model_spec:
            resp.status = falcon.HTTP_NOT_FOUND
            logger.debug("MODEL_METADATA, invalid model spec from request")
            err_out_json = {
                'error': WRONG_MODEL_SPEC.format(model_name, requested_version)
            }
            resp.body = json.dumps(err_out_json)
            return

        target_engine = self.models[model_name].engines[version]
        target_engine.in_use.acquire()

        inputs = target_engine.net.inputs
        outputs = target_engine.net.outputs

        signature_def = prepare_get_metadata_output(
            inputs=inputs,
            outputs=outputs,
            model_keys=target_engine.model_keys)
        response = get_model_metadata_pb2.GetModelMetadataResponse()

        model_data_map = get_model_metadata_pb2.SignatureDefMap()
        model_data_map.signature_def['serving_default'].CopyFrom(signature_def)
        response.metadata['signature_def'].Pack(model_data_map)
        response.model_spec.name = model_name
        response.model_spec.version.value = version
        logger.debug("MODEL_METADATA created a response for {} - {}".format(
            model_name, version))
        target_engine.in_use.release()
        resp.status = falcon.HTTP_200
        resp.body = MessageToJson(response)
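A matching client request might look like the following; the route is an assumption based on the TensorFlow Serving REST layout, since the Falcon routing is not shown in this excerpt:

import requests

# Hypothetical call; host, port and route are assumptions.
result = requests.get(
    'http://localhost:5555/v1/models/resnet/versions/1/metadata')
print(result.status_code)   # 200, or 404 for an unknown model/version
print(result.json())        # MessageToJson output of the metadata response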
Code Example #4
def test_check_availability_of_requested_model(mocker, requested_model,
                                               requested_ver, expected_ver,
                                               expected_validation):

    resnet_model_object = {
        'name': 'resnet',
        'versions': [1, 2, 8],
        'default_version': 8
    }
    inception_model_object = {
        'name': 'inception',
        'versions': [3, 4, 5],
        'default_version': 5
    }
    xception_model_object = {
        'name': 'Xception',
        'versions': [1, 6, 8],
        'default_version': 8
    }
    models = [
        resnet_model_object, inception_model_object, xception_model_object
    ]

    available_models = {"resnet": None, 'inception': None, 'Xception': None}
    for x in models:
        model_mocker = mocker.patch('ie_serving.models.model.Model')
        model_mocker.versions = x['versions']
        model_mocker.default_version = x['default_version']
        available_models[x['name']] = model_mocker

    validation, version = service_utils.check_availability_of_requested_model(
        models=available_models,
        model_name=requested_model,
        requested_version=requested_ver)
    assert expected_validation == validation
    assert expected_ver == version
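The test's extra arguments (requested_model, requested_ver, expected_ver, expected_validation) imply a pytest.mark.parametrize decorator that this excerpt omits. A hypothetical set of cases consistent with the fixtures above could be:

import pytest

# Hypothetical parameter sets; the original decorator is not shown.
# They assume version 0 resolves to default_version and that failed
# validation reports version 0.
@pytest.mark.parametrize(
    "requested_model, requested_ver, expected_ver, expected_validation", [
        ('resnet', 0, 8, True),       # default version of a served model
        ('inception', 3, 3, True),    # explicitly served version
        ('resnet', 9, 0, False),      # version not served
        ('mobilenet', 1, 0, False),   # model not served
    ])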
Code Example #5
    def on_post(self, req, resp, model_name, requested_version=0):
        valid_model_spec, version = check_availability_of_requested_model(
            models=self.models,
            requested_version=requested_version,
            model_name=model_name)

        if not valid_model_spec:
            resp.status = falcon.HTTP_NOT_FOUND
            logger.debug("PREDICT, invalid model spec from request, "
                         "{} - {}".format(model_name, requested_version))
            err_out_json = {
                'error': WRONG_MODEL_SPEC.format(model_name, requested_version)
            }
            resp.body = json.dumps(err_out_json)
            return
        body = req.media
        if not isinstance(body, dict):
            resp.status = falcon.HTTP_400
            resp.body = json.dumps({'error': 'Invalid JSON in request body'})
            return
        input_format = get_input_format(
            body, self.models[model_name].engines[version].input_key_names)
        if input_format == INVALID_FORMAT:
            resp.status = falcon.HTTP_400
            resp.body = json.dumps(
                {'error': 'Invalid inputs in request body'})
            return

        inputs = preprocess_json_request(
            body, input_format,
            self.models[model_name].engines[version].input_key_names)

        start_time = datetime.datetime.now()
        occurred_problem, inference_input, batch_size, code = \
            prepare_input_data(models=self.models, model_name=model_name,
                               version=version, data=inputs, rest=True)
        deserialization_end_time = datetime.datetime.now()
        duration = \
            (deserialization_end_time - start_time).total_seconds() * 1000
        logger.debug(
            "PREDICT; input deserialization completed; {}; {}; {}ms".format(
                model_name, version, duration))
        # on failure, prepare_input_data returns the error message in
        # inference_input
        if occurred_problem:
            resp.status = code
            err_out_json = {'error': inference_input}
            logger.debug(
                "PREDICT, problem with input data. Exit code {}".format(code))
            resp.body = json.dumps(err_out_json)
            return
        self.models[model_name].engines[version].in_use.acquire()
        inference_start_time = datetime.datetime.now()
        try:
            inference_output = self.models[model_name].engines[version] \
                .infer(inference_input, batch_size)
        except ValueError as error:
            resp.status = falcon.HTTP_400
            err_out_json = {'error': 'Malformed input data'}
            logger.debug("PREDICT, problem with inference. "
                         "Corrupted input: {}".format(error))
            self.models[model_name].engines[version].in_use.release()
            resp.body = json.dumps(err_out_json)
            return
        inference_end_time = datetime.datetime.now()
        self.models[model_name].engines[version].in_use.release()
        duration = \
            (inference_end_time - inference_start_time).total_seconds() * 1000
        logger.debug(
            "PREDICT; inference execution completed; {}; {}; {}ms".format(
                model_name, version, duration))
        for key, value in inference_output.items():
            inference_output[key] = value.tolist()

        response = prepare_json_response(
            OUTPUT_REPRESENTATION[input_format], inference_output,
            self.models[model_name].engines[version].model_keys['outputs'])

        resp.status = falcon.HTTP_200
        resp.body = json.dumps(response)
        serialization_end_time = datetime.datetime.now()
        duration = \
            (serialization_end_time -
             inference_end_time).total_seconds() * 1000
        logger.debug("PREDICT; inference results serialization completed;"
                     " {}; {}; {}ms".format(model_name, version, duration))
        return
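get_input_format and OUTPUT_REPRESENTATION suggest the handler accepts more than one JSON layout. Assuming the TensorFlow-Serving-style row and column formats, the two request bodies would look like this:

# Hypothetical request bodies; the exact accepted layouts depend on
# get_input_format, which is not shown here.
row_format = {
    'instances': [                     # one object per example
        {'input': [1.0, 2.0, 3.0]},
        {'input': [4.0, 5.0, 6.0]},
    ]
}
column_format = {
    'inputs': {                        # one list per named tensor
        'input': [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
    }
}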
Code Example #6
    def on_post(self, req, resp, model_name, requested_version=0):
        valid_model_spec, version = check_availability_of_requested_model(
            models=self.models,
            requested_version=requested_version,
            model_name=model_name)

        if not valid_model_spec:
            resp.status = falcon.HTTP_NOT_FOUND
            logger.debug("PREDICT, invalid model spec from request, "
                         "{} - {}".format(model_name, requested_version))
            err_out_json = {
                'error': WRONG_MODEL_SPEC.format(model_name, requested_version)
            }
            resp.body = json.dumps(err_out_json)
            return
        body = req.media
        if not isinstance(body, dict):
            resp.status = falcon.HTTP_400
            resp.body = json.dumps({'error': 'Invalid JSON in request body'})
            return

        target_engine = self.models[model_name].engines[version]
        input_format = get_input_format(body, target_engine.input_key_names)
        if input_format == INVALID_FORMAT:
            resp.status = falcon.HTTP_400
            resp.body = json.dumps(
                {'error': 'Invalid inputs in request body'})
            return

        inputs = preprocess_json_request(body, input_format,
                                         target_engine.input_key_names)

        start_time = datetime.datetime.now()
        inference_input, error_message = \
            prepare_input_data(target_engine=target_engine, data=inputs,
                               service_type=REST)
        deserialization_end_time = datetime.datetime.now()
        duration = \
            (deserialization_end_time - start_time).total_seconds() * 1000
        logger.debug(
            "PREDICT; input deserialization completed; {}; {}; {}ms".format(
                model_name, version, duration))
        if error_message is not None:
            code = statusCodes['invalid_arg'][REST]
            resp.status = code
            err_out_json = {'error': error_message}
            logger.debug(
                "PREDICT, problem with input data. Exit code {}".format(code))
            resp.body = json.dumps(err_out_json)
            return
        target_engine.in_use.acquire()
        ###############################################
        # Reshape network inputs if needed
        reshape_param = target_engine.detect_shapes_incompatibility(
            inference_input)
        if reshape_param is not None:
            error_message = target_engine.reshape(reshape_param)
            if error_message is not None:
                resp.status = falcon.HTTP_400
                err_out_json = {'error': error_message}
                resp.body = json.dumps(err_out_json)
                target_engine.in_use.release()
                return
        ##############################################
        inference_start_time = datetime.datetime.now()
        inference_output, error_message = target_engine.infer(inference_input)
        if error_message is not None:
            resp.status = falcon.HTTP_400
            err_out_json = {'error': error_message}
            resp.body = json.dumps(err_out_json)
            target_engine.in_use.release()
            return
        inference_end_time = datetime.datetime.now()
        target_engine.in_use.release()
        duration = \
            (inference_end_time - inference_start_time).total_seconds() * 1000
        logger.debug(
            "PREDICT; inference execution completed; {}; {}; {}ms".format(
                model_name, version, duration))
        for key, value in inference_output.items():
            inference_output[key] = value.tolist()

        response = prepare_json_response(OUTPUT_REPRESENTATION[input_format],
                                         inference_output,
                                         target_engine.model_keys['outputs'])

        resp.status = falcon.HTTP_200
        resp.body = json.dumps(response)
        serialization_end_time = datetime.datetime.now()
        duration = \
            (serialization_end_time -
             inference_end_time).total_seconds() * 1000
        logger.debug("PREDICT; inference results serialization completed;"
                     " {}; {}; {}ms".format(model_name, version, duration))
        return
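The reshape block in the middle relies on two engine methods that are not shown. A rough sketch of what detect_shapes_incompatibility might do, assuming the engine exposes the network's current input shapes, is:

def detect_shapes_incompatibility(self, inference_input):
    # Hypothetical sketch: compare each incoming tensor's shape with the
    # network's current input shape; return a name->shape dict suitable
    # as a reshape parameter, or None when everything already matches.
    changed_shapes = {}
    for input_name, data in inference_input.items():
        current_shape = tuple(self.net.inputs[input_name].shape)
        if tuple(data.shape) != current_shape:
            changed_shapes[input_name] = data.shape
    return changed_shapes or None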
Code Example #7
    def Predict(self, request, context):
        """
        Predict -- provides access to loaded TensorFlow model.
        """
        # check if requested model
        # is available on server with proper version
        model_name = request.model_spec.name
        requested_version = request.model_spec.version.value
        valid_model_spec, version = check_availability_of_requested_model(
            models=self.models,
            requested_version=requested_version,
            model_name=model_name)

        if not valid_model_spec:
            context.set_code(StatusCode.NOT_FOUND)
            context.set_details(
                WRONG_MODEL_SPEC.format(model_name, requested_version))
            logger.debug(
                "PREDICT, invalid model spec from request, {} - {}".format(
                    model_name, requested_version))
            return predict_pb2.PredictResponse()

        target_engine = self.models[model_name].engines[version]

        deserialization_start_time = datetime.datetime.now()
        inference_input, error_message = \
            prepare_input_data(target_engine=target_engine,
                               data=request.inputs,
                               service_type=GRPC)
        duration = (datetime.datetime.now() -
                    deserialization_start_time).total_seconds() * 1000
        logger.debug(
            "PREDICT; input deserialization completed; {}; {}; {} ms".format(
                model_name, version, duration))
        if error_message is not None:
            code = statusCodes['invalid_arg'][GRPC]
            context.set_code(code)
            context.set_details(error_message)
            logger.debug(
                "PREDICT, problem with input data. Exit code {}".format(code))
            return predict_pb2.PredictResponse()

        inference_request = Request(inference_input)
        target_engine.requests_queue.put(inference_request)
        inference_output, used_ireq_index = inference_request.wait_for_result()
        # the worker reports inference errors as a message string
        if isinstance(inference_output, str):
            code = statusCodes['invalid_arg'][GRPC]
            context.set_code(code)
            context.set_details(inference_output)
            logger.debug("PREDICT, problem during inference execution. Exit "
                         "code {}".format(code))
            target_engine.free_ireq_index_queue.put(used_ireq_index)
            return predict_pb2.PredictResponse()
        serialization_start_time = datetime.datetime.now()
        response = prepare_output(
            inference_output=inference_output,
            model_available_outputs=target_engine.model_keys['outputs'])
        response.model_spec.name = model_name
        response.model_spec.version.value = version
        response.model_spec.signature_name = SIGNATURE_NAME
        duration = (datetime.datetime.now() -
                    serialization_start_time).total_seconds() * 1000
        logger.debug("PREDICT; inference results serialization completed;"
                     " {}; {}; {} ms".format(model_name, version, duration))
        target_engine.free_ireq_index_queue.put(used_ireq_index)
        return response
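This variant hands inference off to a worker through requests_queue instead of calling infer directly. A minimal sketch of the Request object it relies on, assuming an event-based handshake with the worker thread, could be:

import threading

class Request:
    # Hypothetical sketch of the queued-request handshake used above.
    def __init__(self, inference_input):
        self.inference_input = inference_input
        self.result = None
        self.ireq_index = None
        self._done = threading.Event()

    def set_result(self, result, ireq_index):
        # Called by the inference worker; result is either the output
        # dict or an error string, which is why the handler treats a
        # string result as an error.
        self.result = result
        self.ireq_index = ireq_index
        self._done.set()

    def wait_for_result(self):
        # Blocks the handler until the worker signals completion.
        self._done.wait()
        return self.result, self.ireq_index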
Code Example #8
    def Predict(self, request, context):
        """
        Predict -- provides access to loaded TensorFlow model.
        """
        # check if requested model
        # is available on server with proper version
        model_name = request.model_spec.name
        requested_version = request.model_spec.version.value
        valid_model_spec, version = check_availability_of_requested_model(
            models=self.models,
            requested_version=requested_version,
            model_name=model_name)

        if not valid_model_spec:
            context.set_code(StatusCode.NOT_FOUND)
            context.set_details(
                WRONG_MODEL_SPEC.format(model_name, requested_version))
            logger.debug(
                "PREDICT, invalid model spec from request, {} - {}".format(
                    model_name, requested_version))
            return predict_pb2.PredictResponse()

        target_engine = self.models[model_name].engines[version]
        start_time = datetime.datetime.now()
        inference_input, error_message = \
            prepare_input_data(target_engine=target_engine,
                               data=request.inputs,
                               service_type=GRPC)
        deserialization_end_time = datetime.datetime.now()
        duration = \
            (deserialization_end_time - start_time).total_seconds() * 1000
        logger.debug(
            "PREDICT; input deserialization completed; {}; {}; {}ms".format(
                model_name, version, duration))
        if error_message is not None:
            code = statusCodes['invalid_arg'][GRPC]
            context.set_code(code)
            context.set_details(error_message)
            logger.debug(
                "PREDICT, problem with input data. Exit code {}".format(code))
            return predict_pb2.PredictResponse()
        target_engine.in_use.acquire()
        ################################################
        # Reshape network inputs if needed
        reshape_param = target_engine.detect_shapes_incompatibility(
            inference_input)
        if reshape_param is not None:
            error_message = target_engine.reshape(reshape_param)
            if error_message is not None:
                code = statusCodes['invalid_arg'][GRPC]
                context.set_code(code)
                context.set_details(error_message)
                target_engine.in_use.release()
                return predict_pb2.PredictResponse()
        ################################################
        inference_start_time = datetime.datetime.now()
        inference_output, error_message = target_engine.infer(inference_input)
        if error_message is not None:
            code = statusCodes['invalid_arg'][GRPC]
            context.set_code(code)
            context.set_details(error_message)
            target_engine.in_use.release()
            return predict_pb2.PredictResponse()
        inference_end_time = datetime.datetime.now()
        target_engine.in_use.release()
        duration = \
            (inference_end_time - inference_start_time).total_seconds() * 1000
        logger.debug(
            "PREDICT; inference execution completed; {}; {}; {}ms".format(
                model_name, version, duration))
        response = prepare_output_as_list(
            inference_output=inference_output,
            model_available_outputs=target_engine.model_keys['outputs'])
        response.model_spec.name = model_name
        response.model_spec.version.value = version
        response.model_spec.signature_name = SIGNATURE_NAME
        serialization_end_time = datetime.datetime.now()
        duration = \
            (serialization_end_time -
             inference_end_time).total_seconds() * 1000
        logger.debug("PREDICT; inference results serialization completed;"
                     " {}; {}; {}ms".format(model_name, version, duration))

        return response
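Finally, a client-side Predict call against either gRPC variant could look as follows; the address, model name, input key and tensor shape are assumptions:

import grpc
import numpy as np
from tensorflow import make_tensor_proto
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2_grpc

# Hypothetical client; 'localhost:9000', 'resnet' and 'input' are
# placeholders, not taken from the server code above.
channel = grpc.insecure_channel('localhost:9000')
stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
request = predict_pb2.PredictRequest()
request.model_spec.name = 'resnet'
request.model_spec.version.value = 1
data = np.zeros((1, 3, 224, 224), dtype=np.float32)
request.inputs['input'].CopyFrom(make_tensor_proto(data, shape=data.shape))
response = stub.Predict(request, timeout=10.0)
print(response.outputs)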