def Predict(self, request, context):
    """
    Predict -- provides access to loaded TensorFlow model.
    """
    # check if requested model
    # is available on server with proper version
    model_name = request.model_spec.name
    requested_version = request.model_spec.version.value
    valid_model_spec, version = check_availability_of_requested_model(
        models=self.models, requested_version=requested_version,
        model_name=model_name)

    if not valid_model_spec:
        context.set_code(StatusCode.NOT_FOUND)
        context.set_details(WRONG_MODEL_METADATA.format(model_name,
                                                        requested_version))
        logger.debug("PREDICT, invalid model spec from request, {} - {}"
                     .format(model_name, requested_version))
        return predict_pb2.PredictResponse()

    start_time = datetime.datetime.now()
    occurred_problem, inference_input, batch_size, code = \
        prepare_input_data(models=self.models, model_name=model_name,
                           version=version, data=request.inputs)
    deserialization_end_time = datetime.datetime.now()
    duration = (deserialization_end_time - start_time)\
        .total_seconds() * 1000
    logger.debug("PREDICT; input deserialization completed; {}; {}; {}ms"
                 .format(model_name, version, duration))

    if occurred_problem:
        context.set_code(code)
        context.set_details(inference_input)
        logger.debug("PREDICT, problem with input data. Exit code {}"
                     .format(code))
        return predict_pb2.PredictResponse()

    inference_start_time = datetime.datetime.now()
    inference_output = self.models[model_name].engines[version] \
        .infer(inference_input, batch_size)
    inference_end_time = datetime.datetime.now()
    duration = (inference_end_time - inference_start_time)\
        .total_seconds() * 1000
    logger.debug("PREDICT; inference execution completed; {}; {}; {}ms"
                 .format(model_name, version, duration))

    response = prepare_output_as_list(
        inference_output=inference_output,
        model_available_outputs=self.models[model_name].engines[version]
        .model_keys['outputs'])
    response.model_spec.name = model_name
    response.model_spec.version.value = version
    response.model_spec.signature_name = SIGNATURE_NAME
    serialization_end_time = datetime.datetime.now()
    duration = (serialization_end_time - inference_end_time)\
        .total_seconds() * 1000
    logger.debug("PREDICT; inference results serialization completed;"
                 " {}; {}; {}ms".format(model_name, version, duration))
    return response
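
# Illustration only (not part of the server code): a minimal gRPC client
# sketch for the Predict endpoint above. It assumes the server listens on
# localhost:9000 and serves a model named "model" with a single input named
# "input"; the address, model name, input name, shape and dtype are all
# hypothetical and should be adjusted to the actual deployment.
import grpc
import numpy as np
from tensorflow import make_tensor_proto
from tensorflow_serving.apis import predict_pb2, prediction_service_pb2_grpc


def grpc_predict_example():
    channel = grpc.insecure_channel("localhost:9000")  # hypothetical address
    stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
    request = predict_pb2.PredictRequest()
    request.model_spec.name = "model"                  # hypothetical model name
    data = np.zeros((1, 3, 224, 224), dtype=np.float32)
    request.inputs["input"].CopyFrom(
        make_tensor_proto(data, shape=data.shape))
    # blocking unary call; returns a PredictResponse built by the server
    return stub.Predict(request, timeout=10.0)
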
def on_post(self, req, resp, model_name, requested_version=0):
    valid_model_spec, version = check_availability_of_requested_model(
        models=self.models, requested_version=requested_version,
        model_name=model_name)

    if not valid_model_spec:
        resp.status = falcon.HTTP_NOT_FOUND
        logger.debug("PREDICT, invalid model spec from request, "
                     "{} - {}".format(model_name, requested_version))
        err_out_json = {
            'error': WRONG_MODEL_SPEC.format(model_name, requested_version)
        }
        resp.body = json.dumps(err_out_json)
        return

    body = req.media
    if type(body) is not dict:
        resp.status = falcon.HTTP_400
        resp.body = json.dumps({'error': 'Invalid JSON in request body'})
        return

    input_format = get_input_format(
        body, self.models[model_name].engines[version].input_key_names)
    if input_format == INVALID_FORMAT:
        resp.status = falcon.HTTP_400
        resp.body = json.dumps({'error': 'Invalid inputs in request body'})
        return

    inputs = preprocess_json_request(
        body, input_format,
        self.models[model_name].engines[version].input_key_names)

    start_time = datetime.datetime.now()
    occurred_problem, inference_input, batch_size, code = \
        prepare_input_data(models=self.models, model_name=model_name,
                           version=version, data=inputs, rest=True)
    deserialization_end_time = datetime.datetime.now()
    duration = \
        (deserialization_end_time - start_time).total_seconds() * 1000
    logger.debug(
        "PREDICT; input deserialization completed; {}; {}; {}ms".format(
            model_name, version, duration))

    if occurred_problem:
        resp.status = code
        err_out_json = {'error': inference_input}
        logger.debug(
            "PREDICT, problem with input data. Exit code {}".format(code))
        resp.body = json.dumps(err_out_json)
        return

    self.models[model_name].engines[version].in_use.acquire()
    inference_start_time = datetime.datetime.now()
    try:
        inference_output = self.models[model_name].engines[version] \
            .infer(inference_input, batch_size)
    except ValueError as error:
        resp.status = falcon.HTTP_400
        err_out_json = {'error': 'Malformed input data'}
        logger.debug("PREDICT, problem with inference. "
                     "Corrupted input: {}".format(error))
        self.models[model_name].engines[version].in_use.release()
        resp.body = json.dumps(err_out_json)
        return
    inference_end_time = datetime.datetime.now()
    self.models[model_name].engines[version].in_use.release()
    duration = \
        (inference_end_time - inference_start_time).total_seconds() * 1000
    logger.debug(
        "PREDICT; inference execution completed; {}; {}; {}ms".format(
            model_name, version, duration))

    for key, value in inference_output.items():
        inference_output[key] = value.tolist()

    response = prepare_json_response(
        OUTPUT_REPRESENTATION[input_format], inference_output,
        self.models[model_name].engines[version].model_keys['outputs'])

    resp.status = falcon.HTTP_200
    resp.body = json.dumps(response)
    serialization_end_time = datetime.datetime.now()
    duration = \
        (serialization_end_time - inference_end_time).total_seconds() * 1000
    logger.debug("PREDICT; inference results serialization completed;"
                 " {}; {}; {}ms".format(model_name, version, duration))
    return
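
# Illustration only (not part of the server code): a minimal REST client
# sketch for the on_post handler above, using a TensorFlow Serving-style
# row ("instances") JSON body. The address, port, URL path, model name and
# payload are hypothetical; adjust them to the actual deployment.
import json
import requests


def rest_predict_example():
    url = "http://localhost:8000/v1/models/model:predict"  # hypothetical URL
    payload = {"instances": [[0.0, 1.0, 2.0, 3.0]]}         # hypothetical input
    response = requests.post(url, data=json.dumps(payload))
    response.raise_for_status()
    # the handler returns a JSON document with the model outputs
    return response.json()
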
def Predict(self, request, context):
    """
    Predict -- provides access to loaded TensorFlow model.
    """
    # check if requested model
    # is available on server with proper version
    model_name = request.model_spec.name
    requested_version = request.model_spec.version.value
    valid_model_spec, version = check_availability_of_requested_model(
        models=self.models, requested_version=requested_version,
        model_name=model_name)

    if not valid_model_spec:
        context.set_code(StatusCode.NOT_FOUND)
        context.set_details(
            WRONG_MODEL_SPEC.format(model_name, requested_version))
        logger.debug(
            "PREDICT, invalid model spec from request, {} - {}".format(
                model_name, requested_version))
        return predict_pb2.PredictResponse()

    target_engine = self.models[model_name].engines[version]

    deserialization_start_time = datetime.datetime.now()
    inference_input, error_message = \
        prepare_input_data(target_engine=target_engine,
                           data=request.inputs, service_type=GRPC)
    duration = (datetime.datetime.now() -
                deserialization_start_time).total_seconds() * 1000
    logger.debug(
        "PREDICT; input deserialization completed; {}; {}; {} ms".format(
            model_name, version, duration))

    if error_message is not None:
        code = statusCodes['invalid_arg'][GRPC]
        context.set_code(code)
        context.set_details(error_message)
        logger.debug(
            "PREDICT, problem with input data. Exit code {}".format(code))
        return predict_pb2.PredictResponse()

    target_engine = self.models[model_name].engines[version]
    inference_request = Request(inference_input)
    target_engine.requests_queue.put(inference_request)
    inference_output, used_ireq_index = inference_request.wait_for_result()

    if type(inference_output) is str:
        code = statusCodes['invalid_arg'][GRPC]
        context.set_code(code)
        context.set_details(inference_output)
        logger.debug("PREDICT, problem during inference execution. Exit "
                     "code {}".format(code))
        target_engine.free_ireq_index_queue.put(used_ireq_index)
        return predict_pb2.PredictResponse()

    serialization_start_time = datetime.datetime.now()
    response = prepare_output(
        inference_output=inference_output,
        model_available_outputs=target_engine.model_keys['outputs'])
    response.model_spec.name = model_name
    response.model_spec.version.value = version
    response.model_spec.signature_name = SIGNATURE_NAME
    duration = (datetime.datetime.now() -
                serialization_start_time).total_seconds() * 1000
    logger.debug("PREDICT; inference results serialization completed;"
                 " {}; {}; {} ms".format(model_name, version, duration))
    target_engine.free_ireq_index_queue.put(used_ireq_index)
    return response
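
# Illustration only: the Request/requests_queue pattern above hands the
# prepared input to an engine worker thread and blocks until a result (and
# the index of the inference request slot to recycle) is available. The real
# Request class is not shown in this section; the sketch below is one
# plausible event-based implementation inferred from how the object is used,
# not the actual server code.
import threading


class ExampleRequest:
    def __init__(self, inference_input):
        self.inference_input = inference_input
        self.result = None
        self.ireq_index = None
        self._done = threading.Event()

    def set_result(self, result, ireq_index):
        # called by the engine worker thread once inference finishes
        self.result = result
        self.ireq_index = ireq_index
        self._done.set()

    def wait_for_result(self):
        # called by the request handler thread; blocks until the worker signals
        self._done.wait()
        return self.result, self.ireq_index
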
def on_post(self, req, resp, model_name, requested_version=0):
    valid_model_spec, version = check_availability_of_requested_model(
        models=self.models, requested_version=requested_version,
        model_name=model_name)

    if not valid_model_spec:
        resp.status = falcon.HTTP_NOT_FOUND
        logger.debug("PREDICT, invalid model spec from request, "
                     "{} - {}".format(model_name, requested_version))
        err_out_json = {
            'error': WRONG_MODEL_SPEC.format(model_name, requested_version)
        }
        resp.body = json.dumps(err_out_json)
        return

    body = req.media
    if type(body) is not dict:
        resp.status = falcon.HTTP_400
        resp.body = json.dumps({'error': 'Invalid JSON in request body'})
        return

    target_engine = self.models[model_name].engines[version]

    input_format = get_input_format(body, target_engine.input_key_names)
    if input_format == INVALID_FORMAT:
        resp.status = falcon.HTTP_400
        resp.body = json.dumps({'error': 'Invalid inputs in request body'})
        return

    inputs = preprocess_json_request(body, input_format,
                                     target_engine.input_key_names)

    start_time = datetime.datetime.now()
    inference_input, error_message = \
        prepare_input_data(target_engine=target_engine, data=inputs,
                           service_type=REST)
    deserialization_end_time = datetime.datetime.now()
    duration = \
        (deserialization_end_time - start_time).total_seconds() * 1000
    logger.debug(
        "PREDICT; input deserialization completed; {}; {}; {}ms".format(
            model_name, version, duration))

    if error_message is not None:
        resp.status = code = statusCodes['invalid_arg'][REST]
        err_out_json = {'error': error_message}
        logger.debug(
            "PREDICT, problem with input data. Exit code {}".format(code))
        resp.body = json.dumps(err_out_json)
        return

    target_engine.in_use.acquire()

    ###############################################
    # Reshape network inputs if needed
    reshape_param = target_engine.detect_shapes_incompatibility(
        inference_input)
    if reshape_param is not None:
        error_message = target_engine.reshape(reshape_param)
        if error_message is not None:
            resp.status = falcon.HTTP_400
            err_out_json = {'error': error_message}
            resp.body = json.dumps(err_out_json)
            target_engine.in_use.release()
            return
    ##############################################

    inference_start_time = datetime.datetime.now()
    inference_output, error_message = target_engine.infer(inference_input)
    if error_message is not None:
        resp.status = falcon.HTTP_400
        err_out_json = {'error': error_message}
        resp.body = json.dumps(err_out_json)
        target_engine.in_use.release()
        return
    inference_end_time = datetime.datetime.now()
    target_engine.in_use.release()
    duration = \
        (inference_end_time - inference_start_time).total_seconds() * 1000
    logger.debug(
        "PREDICT; inference execution completed; {}; {}; {}ms".format(
            model_name, version, duration))

    for key, value in inference_output.items():
        inference_output[key] = value.tolist()

    response = prepare_json_response(OUTPUT_REPRESENTATION[input_format],
                                     inference_output,
                                     target_engine.model_keys['outputs'])

    resp.status = falcon.HTTP_200
    resp.body = json.dumps(response)
    serialization_end_time = datetime.datetime.now()
    duration = \
        (serialization_end_time - inference_end_time).total_seconds() * 1000
    logger.debug("PREDICT; inference results serialization completed;"
                 " {}; {}; {}ms".format(model_name, version, duration))
    return
def Predict(self, request, context):
    """
    Predict -- provides access to loaded TensorFlow model.
    """
    # check if requested model
    # is available on server with proper version
    model_name = request.model_spec.name
    requested_version = request.model_spec.version.value
    valid_model_spec, version = check_availability_of_requested_model(
        models=self.models, requested_version=requested_version,
        model_name=model_name)

    if not valid_model_spec:
        context.set_code(StatusCode.NOT_FOUND)
        context.set_details(
            WRONG_MODEL_SPEC.format(model_name, requested_version))
        logger.debug(
            "PREDICT, invalid model spec from request, {} - {}".format(
                model_name, requested_version))
        return predict_pb2.PredictResponse()

    target_engine = self.models[model_name].engines[version]

    start_time = datetime.datetime.now()
    inference_input, error_message = \
        prepare_input_data(target_engine=target_engine,
                           data=request.inputs, service_type=GRPC)
    deserialization_end_time = datetime.datetime.now()
    duration = \
        (deserialization_end_time - start_time).total_seconds() * 1000
    logger.debug(
        "PREDICT; input deserialization completed; {}; {}; {}ms".format(
            model_name, version, duration))

    if error_message is not None:
        code = statusCodes['invalid_arg'][GRPC]
        context.set_code(code)
        context.set_details(error_message)
        logger.debug(
            "PREDICT, problem with input data. Exit code {}".format(code))
        return predict_pb2.PredictResponse()

    target_engine = self.models[model_name].engines[version]
    target_engine.in_use.acquire()

    ################################################
    # Reshape network inputs if needed
    reshape_param = target_engine.detect_shapes_incompatibility(
        inference_input)
    if reshape_param is not None:
        error_message = target_engine.reshape(reshape_param)
        if error_message is not None:
            code = statusCodes['invalid_arg'][GRPC]
            context.set_code(code)
            context.set_details(error_message)
            target_engine.in_use.release()
            return predict_pb2.PredictResponse()
    ################################################

    inference_start_time = datetime.datetime.now()
    inference_output, error_message = target_engine.infer(inference_input)
    if error_message is not None:
        code = statusCodes['invalid_arg'][GRPC]
        context.set_code(code)
        context.set_details(error_message)
        target_engine.in_use.release()
        return predict_pb2.PredictResponse()
    inference_end_time = datetime.datetime.now()
    target_engine.in_use.release()
    duration = \
        (inference_end_time - inference_start_time).total_seconds() * 1000
    logger.debug(
        "PREDICT; inference execution completed; {}; {}; {}ms".format(
            model_name, version, duration))

    response = prepare_output_as_list(
        inference_output=inference_output,
        model_available_outputs=target_engine.model_keys['outputs'])
    response.model_spec.name = model_name
    response.model_spec.version.value = version
    response.model_spec.signature_name = SIGNATURE_NAME
    serialization_end_time = datetime.datetime.now()
    duration = \
        (serialization_end_time - inference_end_time).total_seconds() * 1000
    logger.debug("PREDICT; inference results serialization completed;"
                 " {}; {}; {}ms".format(model_name, version, duration))
    return response
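
# Illustration only: a sketch of the shape-compatibility check that drives the
# "Reshape network inputs if needed" block in the two handlers above. The real
# detect_shapes_incompatibility/reshape methods belong to the engine and are
# not shown in this section; this standalone function only demonstrates the
# idea of comparing the shapes of the incoming tensors with the shapes the
# loaded network currently expects and returning a reshape parameter when
# they differ. All names and shapes here are hypothetical.
def example_detect_shapes_incompatibility(network_input_shapes,
                                          inference_input):
    # network_input_shapes: dict of input name -> shape currently set on the
    # network, e.g. {"input": (1, 3, 224, 224)}
    # inference_input: dict of input name -> numpy array from the request
    reshape_param = {}
    for name, array in inference_input.items():
        expected = tuple(network_input_shapes.get(name, ()))
        if tuple(array.shape) != expected:
            reshape_param[name] = tuple(array.shape)
    # None means the shapes already match and no reshape is required
    return reshape_param or None
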