def post_request_metrics(self, status_code, start_time, prediction_value=None):
    metrics_list = []
    metrics_list.append(self.status_code_metric(status_code))

    if prediction_value is not None:
        metrics_list.append(self.prediction_metrics(prediction_value))

    metrics_list.append(self.latency_metric(start_time))

    try:
        if self.statsd is None:
            raise CortexException("statsd client not initialized")  # unexpected

        for metric in metrics_list:
            tags = ["{}:{}".format(dim["Name"], dim["Value"]) for dim in metric["Dimensions"]]
            if metric.get("Unit") == "Count":
                self.statsd.increment(metric["MetricName"], value=metric["Value"], tags=tags)
            else:
                self.statsd.histogram(metric["MetricName"], value=metric["Value"], tags=tags)
    except:
        cx_logger().warn("failure encountered while publishing metrics", exc_info=True)
def predict(request: Any = Body(..., media_type="application/json"), debug=False):
    api = local_cache["api"]
    predictor_impl = local_cache["predictor_impl"]

    debug_obj("payload", request, debug)
    prediction = predictor_impl.predict(request)
    debug_obj("prediction", prediction, debug)

    try:
        json_string = json.dumps(prediction)
    except:
        json_string = util.json_tricks_encoder().encode(prediction)

    response = Response(content=json_string, media_type="application/json")

    if api.tracker is not None:
        try:
            predicted_value = api.tracker.extract_predicted_value(prediction)
            api.post_tracker_metrics(predicted_value)
            if (
                api.tracker.model_type == "classification"
                and predicted_value not in local_cache["class_set"]
            ):
                tasks = BackgroundTasks()
                tasks.add_task(api.upload_class, class_name=predicted_value)
                local_cache["class_set"].add(predicted_value)
                response.background = tasks
        except:
            cx_logger().warn("unable to record prediction metric", exc_info=True)

    return response
def after_request(response):
    response.headers["Access-Control-Allow-Origin"] = "*"
    response.headers["Access-Control-Allow-Headers"] = request.headers.get(
        "Access-Control-Request-Headers", "*"
    )

    if not (request.path == "/predict" and request.method == "POST"):
        return response

    api = local_cache["api"]

    cx_logger().info(response.status)

    prediction = None
    if "prediction" in g:
        prediction = g.prediction

    try:
        if api.tracker is not None:
            predicted_value = api.tracker.extract_predicted_value(prediction)
            api.post_request_metrics(response.status_code, g.start_time, predicted_value)
            if predicted_value is not None and predicted_value not in local_cache["class_set"]:
                api.upload_class(predicted_value)
                local_cache["class_set"].add(predicted_value)
    except Exception as e:
        cx_logger().warn("unable to record prediction metric", exc_info=True)

    return response
def predict():
    debug = request.args.get("debug", "false").lower() == "true"

    try:
        payload = request.get_json()
    except:
        return "malformed json", status.HTTP_400_BAD_REQUEST

    api = local_cache["api"]
    predictor = local_cache["predictor"]

    try:
        try:
            debug_obj("payload", payload, debug)
            output = predictor.predict(payload, api["predictor"]["metadata"])
            debug_obj("prediction", output, debug)
        except Exception as e:
            raise UserRuntimeException(api["predictor"]["path"], "predict", str(e)) from e
    except Exception as e:
        cx_logger().exception("prediction failed")
        return prediction_failed(str(e))

    g.prediction = output

    return jsonify(output)
def start(args):
    download = json.loads(base64.urlsafe_b64decode(args.download))

    for download_arg in download:
        from_path = download_arg["from"]
        to_path = download_arg["to"]
        item_name = download_arg.get("item_name", "")

        bucket_name, prefix = S3.deconstruct_s3_path(from_path)
        s3_client = S3(bucket_name, client_config={})

        if item_name != "":
            cx_logger().info("downloading {} from {}".format(item_name, from_path))
        s3_client.download(prefix, to_path)

        if download_arg.get("unzip", False):
            if item_name != "":
                cx_logger().info("unzipping {}".format(item_name))
            util.extract_zip(
                os.path.join(to_path, os.path.basename(from_path)), delete_zip_file=True
            )

        if download_arg.get("tf_model_version_rename", "") != "":
            dest = util.trim_suffix(download_arg["tf_model_version_rename"], "/")
            dir_path = os.path.dirname(dest)
            entries = os.listdir(dir_path)
            if len(entries) == 1:
                src = os.path.join(dir_path, entries[0])
                os.rename(src, dest)
def post_request_metrics(ctx, api, response, prediction_payload, start_time, class_set):
    api_name = api["name"]
    api_dimensions = api_metric_dimensions(ctx, api_name)

    metrics_list = []
    metrics_list += status_code_metric(api_dimensions, response.status_code)

    if prediction_payload is not None:
        if api.get("tracker") is not None:
            try:
                prediction = extract_prediction(api, prediction_payload)
                if api["tracker"]["model_type"] == "classification":
                    cache_classes(ctx, api, prediction, class_set)
                metrics_list += prediction_metrics(api_dimensions, api, prediction)
            except Exception as e:
                cx_logger().warn("unable to record prediction metric", exc_info=True)

    metrics_list += latency_metric(api_dimensions, start_time)

    try:
        ctx.publish_metrics(metrics_list)
    except Exception as e:
        cx_logger().warn("failure encountered while publishing metrics", exc_info=True)
def initialize_client(self, tf_serving_host=None, tf_serving_port=None):
    signature_message = None

    if self.type == "onnx":
        from cortex.lib.client.onnx import ONNXClient

        client = ONNXClient(self.models)
        if self.models[0].name == consts.SINGLE_MODEL_NAME:
            signature_message = "ONNX model signature: {}".format(
                client.input_signatures[consts.SINGLE_MODEL_NAME]
            )
        else:
            signature_message = "ONNX model signatures: {}".format(client.input_signatures)
        cx_logger().info(signature_message)
        return client
    elif self.type == "tensorflow":
        from cortex.lib.client.tensorflow import TensorFlowClient

        for model in self.models:
            validate_model_dir(model.base_path)

        tf_serving_address = tf_serving_host + ":" + tf_serving_port
        client = TensorFlowClient(tf_serving_address, self.models)
        if self.models[0].name == consts.SINGLE_MODEL_NAME:
            signature_message = "TensorFlow model signature: {}".format(
                client.input_signatures[consts.SINGLE_MODEL_NAME]
            )
        else:
            signature_message = "TensorFlow model signatures: {}".format(client.input_signatures)
        cx_logger().info(signature_message)
        return client

    return None
def predict(request: Any = Body(..., media_type="application/json"), debug=False):
    api = local_cache["api"]
    predictor_impl = local_cache["predictor_impl"]

    debug_obj("payload", request, debug)
    prediction = predictor_impl.predict(request)

    try:
        json_string = json.dumps(prediction)
    except Exception as e:
        raise UserRuntimeException(
            "the return value of predict() or one of its nested values is not JSON serializable",
            str(e),
        ) from e

    debug_obj("prediction", json_string, debug)

    response = Response(content=json_string, media_type="application/json")

    if api.tracker is not None:
        try:
            predicted_value = api.tracker.extract_predicted_value(prediction)
            api.post_tracker_metrics(predicted_value)
            if (
                api.tracker.model_type == "classification"
                and predicted_value not in local_cache["class_set"]
            ):
                tasks = BackgroundTasks()
                tasks.add_task(api.upload_class, class_name=predicted_value)
                local_cache["class_set"].add(predicted_value)
                response.background = tasks
        except:
            cx_logger().warn("unable to record prediction metric", exc_info=True)

    return response
def start():
    cache_dir = os.environ["CORTEX_CACHE_DIR"]
    provider = os.environ["CORTEX_PROVIDER"]
    api_spec_path = os.environ["CORTEX_API_SPEC"]
    job_spec_path = os.environ["CORTEX_JOB_SPEC"]
    project_dir = os.environ["CORTEX_PROJECT_DIR"]

    model_dir = os.getenv("CORTEX_MODEL_DIR")
    tf_serving_port = os.getenv("CORTEX_TF_BASE_SERVING_PORT", "9000")
    tf_serving_host = os.getenv("CORTEX_TF_SERVING_HOST", "localhost")

    storage = S3(bucket=os.environ["CORTEX_BUCKET"], region=os.environ["AWS_REGION"])

    has_multiple_servers = os.getenv("CORTEX_MULTIPLE_TF_SERVERS")
    if has_multiple_servers:
        with FileLock("/run/used_ports.json.lock"):
            with open("/run/used_ports.json", "r+") as f:
                used_ports = json.load(f)
                for port in used_ports.keys():
                    if not used_ports[port]:
                        tf_serving_port = port
                        used_ports[port] = True
                        break
                f.seek(0)
                json.dump(used_ports, f)
                f.truncate()

    raw_api_spec = get_spec(provider, storage, cache_dir, api_spec_path)
    job_spec = get_job_spec(storage, cache_dir, job_spec_path)

    api = API(
        provider=provider,
        storage=storage,
        model_dir=model_dir,
        cache_dir=cache_dir,
        **raw_api_spec,
    )

    client = api.predictor.initialize_client(
        tf_serving_host=tf_serving_host, tf_serving_port=tf_serving_port
    )
    cx_logger().info("loading the predictor from {}".format(api.predictor.path))
    predictor_impl = api.predictor.initialize_impl(project_dir, client, raw_api_spec, job_spec)

    local_cache["api_spec"] = api
    local_cache["provider"] = provider
    local_cache["job_spec"] = job_spec
    local_cache["predictor_impl"] = predictor_impl
    local_cache["predict_fn_args"] = inspect.getfullargspec(predictor_impl.predict).args
    local_cache["sqs_client"] = boto3.client("sqs", region_name=os.environ["AWS_REGION"])

    open("/mnt/workspace/api_readiness.txt", "a").close()

    cx_logger().info("polling for batches...")
    sqs_loop()
def extract_waitress_params(config):
    waitress_kwargs = {}
    if config is not None:
        for key, value in config.items():
            if key.startswith("waitress_"):
                waitress_kwargs[key[len("waitress_") :]] = value

    if len(waitress_kwargs) > 0:
        cx_logger().info("waitress parameters: {}".format(waitress_kwargs))

    return waitress_kwargs
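For illustration, a minimal usage sketch of extract_waitress_params (the config keys below are hypothetical; any "waitress_"-prefixed key is stripped of its prefix and forwarded to waitress's serve()):

# hypothetical predictor config
config = {"waitress_threads": 8, "waitress_connection_limit": 200, "some_other_key": True}

waitress_kwargs = extract_waitress_params(config)
# waitress_kwargs == {"threads": 8, "connection_limit": 200}

# callers then add the listen address and start the server, e.g.:
# waitress_kwargs["listen"] = "*:8080"
# serve(app, **waitress_kwargs)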
def predict():
    debug = request.args.get("debug", "false").lower() == "true"

    try:
        payload = request.get_json()
    except Exception as e:
        return "malformed json", status.HTTP_400_BAD_REQUEST

    sess = local_cache["sess"]
    api = local_cache["api"]
    ctx = local_cache["ctx"]
    request_handler = local_cache.get("request_handler")
    input_metadata = local_cache["input_metadata"]
    output_metadata = local_cache["output_metadata"]

    try:
        debug_obj("payload", payload, debug)
        prepared_payload = payload

        if request_handler is not None and util.has_function(request_handler, "pre_inference"):
            try:
                prepared_payload = request_handler.pre_inference(
                    payload, input_metadata, api["onnx"]["metadata"]
                )
                debug_obj("pre_inference", prepared_payload, debug)
            except Exception as e:
                raise UserRuntimeException(
                    api["onnx"]["request_handler"], "pre_inference request handler", str(e)
                ) from e

        inference_input = convert_to_onnx_input(prepared_payload, input_metadata)
        model_output = sess.run([], inference_input)
        debug_obj("inference", model_output, debug)

        result = model_output
        if request_handler is not None and util.has_function(request_handler, "post_inference"):
            try:
                result = request_handler.post_inference(
                    model_output, output_metadata, api["onnx"]["metadata"]
                )
            except Exception as e:
                raise UserRuntimeException(
                    api["onnx"]["request_handler"], "post_inference request handler", str(e)
                ) from e
            debug_obj("post_inference", result, debug)
    except Exception as e:
        cx_logger().exception("prediction failed")
        return prediction_failed(str(e))

    g.prediction = result

    return jsonify(result)
def start(args):
    api = None
    try:
        ctx = Context(s3_path=args.context, cache_dir=args.cache_dir, workload_id=args.workload_id)
        api = ctx.apis_id_map[args.api]
        local_cache["api"] = api
        local_cache["ctx"] = ctx

        if api["predictor"]["type"] != "python":
            raise CortexException(api["name"], "predictor type is not python")

        cx_logger().info("loading the predictor from {}".format(api["predictor"]["path"]))
        predictor_class = ctx.get_predictor_class(api["name"], args.project_dir)

        try:
            local_cache["predictor"] = predictor_class(api["predictor"]["config"])
        except Exception as e:
            raise UserRuntimeException(api["predictor"]["path"], "__init__", str(e)) from e
        finally:
            refresh_logger()
    except:
        cx_logger().exception("failed to start api")
        sys.exit(1)

    if api.get("tracker") is not None and api["tracker"].get("model_type") == "classification":
        try:
            local_cache["class_set"] = api_utils.get_classes(ctx, api["name"])
        except Exception as e:
            cx_logger().warn("an error occurred while attempting to load classes", exc_info=True)

    waitress_kwargs = {}
    if api["predictor"].get("config") is not None:
        for key, value in api["predictor"]["config"].items():
            if key.startswith("waitress_"):
                waitress_kwargs[key[len("waitress_") :]] = value

    if len(waitress_kwargs) > 0:
        cx_logger().info("waitress parameters: {}".format(waitress_kwargs))

    waitress_kwargs["listen"] = "*:{}".format(args.port)

    cx_logger().info("{} api is live".format(api["name"]))
    open("/health_check.txt", "a").close()
    serve(app, **waitress_kwargs)
def post_metrics(self, metrics):
    try:
        if self.statsd is None:
            raise CortexException("statsd client not initialized")  # unexpected

        for metric in metrics:
            tags = ["{}:{}".format(dim["Name"], dim["Value"]) for dim in metric["Dimensions"]]
            if metric.get("Unit") == "Count":
                self.statsd.increment(metric["MetricName"], value=metric["Value"], tags=tags)
            else:
                self.statsd.histogram(metric["MetricName"], value=metric["Value"], tags=tags)
    except:
        cx_logger().warn("failure encountered while publishing metrics", exc_info=True)
def start(args):
    api = None
    try:
        ctx = Context(s3_path=args.context, cache_dir=args.cache_dir, workload_id=args.workload_id)
        api = ctx.apis_id_map[args.api]
        local_cache["api"] = api
        local_cache["ctx"] = ctx

        if api.get("predictor") is None:
            raise CortexException(api["name"], "predictor key not configured")

        cx_logger().info("loading the predictor from {}".format(api["predictor"]["path"]))
        local_cache["predictor"] = ctx.get_predictor_impl(api["name"], args.project_dir)

        if util.has_function(local_cache["predictor"], "init"):
            try:
                model_path = None
                if api["predictor"].get("model") is not None:
                    _, prefix = ctx.storage.deconstruct_s3_path(api["predictor"]["model"])
                    model_path = os.path.join(
                        args.model_dir, os.path.basename(os.path.normpath(prefix))
                    )

                cx_logger().info("calling the predictor's init() function")
                local_cache["predictor"].init(model_path, api["predictor"]["metadata"])
            except Exception as e:
                raise UserRuntimeException(api["predictor"]["path"], "init", str(e)) from e
            finally:
                refresh_logger()
    except:
        cx_logger().exception("failed to start api")
        sys.exit(1)

    if api.get("tracker") is not None and api["tracker"].get("model_type") == "classification":
        try:
            local_cache["class_set"] = api_utils.get_classes(ctx, api["name"])
        except Exception as e:
            cx_logger().warn("an error occurred while attempting to load classes", exc_info=True)

    cx_logger().info("{} api is live".format(api["name"]))
    serve(app, listen="*:{}".format(args.port))
def start(args):
    assert_api_version()
    storage = S3(bucket=os.environ["CORTEX_BUCKET"], region=os.environ["AWS_REGION"])

    try:
        raw_api_spec = get_spec(args.cache_dir, args.spec)
        api = API(storage=storage, cache_dir=args.cache_dir, **raw_api_spec)
        client = api.predictor.initialize_client(args)
        cx_logger().info("loading the predictor from {}".format(api.predictor.path))
        predictor_impl = api.predictor.initialize_impl(args.project_dir, client)

        local_cache["api"] = api
        local_cache["client"] = client
        local_cache["predictor_impl"] = predictor_impl
    except:
        cx_logger().exception("failed to start api")
        sys.exit(1)

    if api.tracker is not None and api.tracker.model_type == "classification":
        try:
            local_cache["class_set"] = api.get_cached_classes()
        except Exception as e:
            cx_logger().warn("an error occurred while attempting to load classes", exc_info=True)

    waitress_kwargs = extract_waitress_params(api.predictor.config)
    waitress_kwargs["listen"] = "*:{}".format(args.port)

    open("/health_check.txt", "a").close()
    cx_logger().info("{} api is live".format(api.name))
    serve(app, **waitress_kwargs)
def start():
    cache_dir = os.environ["CORTEX_CACHE_DIR"]
    spec = os.environ["CORTEX_API_SPEC"]
    project_dir = os.environ["CORTEX_PROJECT_DIR"]
    model_dir = os.getenv("CORTEX_MODEL_DIR", None)
    tf_serving_port = os.getenv("CORTEX_TF_SERVING_PORT", None)
    storage = S3(bucket=os.environ["CORTEX_BUCKET"], region=os.environ["AWS_REGION"])

    try:
        raw_api_spec = get_spec(storage, cache_dir, spec)
        api = API(storage=storage, cache_dir=cache_dir, **raw_api_spec)
        client = api.predictor.initialize_client(model_dir, tf_serving_port)
        cx_logger().info("loading the predictor from {}".format(api.predictor.path))
        predictor_impl = api.predictor.initialize_impl(project_dir, client)

        local_cache["api"] = api
        local_cache["client"] = client
        local_cache["predictor_impl"] = predictor_impl
    except:
        cx_logger().exception("failed to start api")
        sys.exit(1)

    if api.tracker is not None and api.tracker.model_type == "classification":
        try:
            local_cache["class_set"] = api.get_cached_classes()
        except Exception as e:
            cx_logger().warn("an error occurred while attempting to load classes", exc_info=True)

    cx_logger().info("{} api is live".format(api.name))
    return app
def handle_on_complete(message):
    job_spec = local_cache["job_spec"]
    predictor_impl = local_cache["predictor_impl"]
    sqs_client = local_cache["sqs_client"]
    queue_url = job_spec["sqs_url"]
    receipt_handle = message["ReceiptHandle"]

    try:
        if not getattr(predictor_impl, "on_job_complete", None):
            sqs_client.delete_message(QueueUrl=queue_url, ReceiptHandle=receipt_handle)
            return True

        should_run_on_job_complete = False

        while True:
            visible_count, not_visible_count = get_total_messages_in_queue()

            # if there are other messages that are visible, release this message and get the other ones (should rarely happen for FIFO)
            if visible_count > 0:
                sqs_client.change_message_visibility(
                    QueueUrl=queue_url, ReceiptHandle=receipt_handle, VisibilityTimeout=0
                )
                return False

            if should_run_on_job_complete:
                # double check that the queue is still empty (except for the job_complete message)
                if not_visible_count <= 1:
                    cx_logger().info("executing on_job_complete")
                    predictor_impl.on_job_complete()
                    sqs_client.delete_message(QueueUrl=queue_url, ReceiptHandle=receipt_handle)
                    return True
                else:
                    should_run_on_job_complete = False

            if not_visible_count <= 1:
                should_run_on_job_complete = True

            time.sleep(20)
    except:
        sqs_client.delete_message(QueueUrl=queue_url, ReceiptHandle=receipt_handle)
        raise
def predict(request: Request):
    tasks = BackgroundTasks()
    api = local_cache["api"]
    predictor_impl = local_cache["predictor_impl"]
    kwargs = build_predict_kwargs(request)

    prediction = predictor_impl.predict(**kwargs)

    if isinstance(prediction, bytes):
        response = Response(content=prediction, media_type="application/octet-stream")
    elif isinstance(prediction, str):
        response = Response(content=prediction, media_type="text/plain")
    elif isinstance(prediction, Response):
        response = prediction
    else:
        try:
            json_string = json.dumps(prediction)
        except Exception as e:
            raise UserRuntimeException(
                str(e),
                "please return an object that is JSON serializable (including its nested fields), a bytes object, a string, or a starlette.response.Response object",
            ) from e
        response = Response(content=json_string, media_type="application/json")

    if local_cache["provider"] != "local" and api.monitoring is not None:
        try:
            predicted_value = api.monitoring.extract_predicted_value(prediction)
            api.post_monitoring_metrics(predicted_value)
            if (
                api.monitoring.model_type == "classification"
                and predicted_value not in local_cache["class_set"]
            ):
                tasks.add_task(api.upload_class, class_name=predicted_value)
                local_cache["class_set"].add(predicted_value)
        except:
            cx_logger().warn("unable to record prediction metric", exc_info=True)

    if util.has_method(predictor_impl, "post_predict"):
        kwargs = build_post_predict_kwargs(prediction, request)
        request_thread_pool.submit(predictor_impl.post_predict, **kwargs)

    if len(tasks.tasks) > 0:
        response.background = tasks

    return response
def initialize_client(self, args):
    if self.type == "onnx":
        from cortex.lib.client.onnx import ONNXClient

        _, prefix = self.storage.deconstruct_s3_path(self.model)
        model_path = os.path.join(args.model_dir, os.path.basename(prefix))
        client = ONNXClient(model_path)
        cx_logger().info("ONNX model signature: {}".format(client.input_signature))
        return client
    elif self.type == "tensorflow":
        from cortex.lib.client.tensorflow import TensorFlowClient

        validate_model_dir(args.model_dir)
        client = TensorFlowClient("localhost:" + str(args.tf_serve_port), self.signature_key)
        cx_logger().info("TensorFlow model signature: {}".format(client.input_signature))
        return client

    return None
def after_request(response):
    response.headers["Access-Control-Allow-Origin"] = "*"
    response.headers["Access-Control-Allow-Headers"] = "*"

    if not (request.path == "/predict" and request.method == "POST"):
        return response

    api = local_cache["api"]
    ctx = local_cache["ctx"]

    cx_logger().info(response.status)

    prediction = None
    if "prediction" in g:
        prediction = g.prediction

    api_utils.post_request_metrics(
        ctx, api, response, prediction, g.start_time, local_cache["class_set"]
    )

    return response
def get_signature_def(stub):
    limit = 60
    for i in range(limit):
        try:
            request = create_get_model_metadata_request()
            resp = stub.GetModelMetadata(request, timeout=10.0)
            sigAny = resp.metadata["signature_def"]
            signature_def_map = get_model_metadata_pb2.SignatureDefMap()
            sigAny.Unpack(signature_def_map)
            sigmap = json_format.MessageToDict(signature_def_map)
            return sigmap["signatureDef"]
        except:
            if i > 6:
                cx_logger().warn(
                    "unable to read model metadata - model is still loading, retrying..."
                )

        time.sleep(5)

    raise CortexException("timeout: unable to read model metadata")
def predict(request: Any = Body(..., media_type="application/json"), debug=False):
    api = local_cache["api"]
    predictor_impl = local_cache["predictor_impl"]

    debug_obj("payload", request, debug)
    prediction = predictor_impl.predict(request)

    if isinstance(prediction, bytes):
        response = Response(content=prediction, media_type="application/octet-stream")
    elif isinstance(prediction, str):
        response = Response(content=prediction, media_type="text/plain")
        debug_obj("prediction", prediction, debug)
    elif isinstance(prediction, Response):
        response = prediction
    else:
        try:
            json_string = json.dumps(prediction)
            debug_obj("prediction", prediction, debug)
        except Exception as e:
            raise UserRuntimeException(
                str(e),
                "please return an object that is JSON serializable (including its nested fields), a bytes object, a string, or a starlette.response.Response object",
            ) from e
        response = Response(content=json_string, media_type="application/json")

    if local_cache["provider"] != "local" and api.tracker is not None:
        try:
            predicted_value = api.tracker.extract_predicted_value(prediction)
            api.post_tracker_metrics(predicted_value)
            if (
                api.tracker.model_type == "classification"
                and predicted_value not in local_cache["class_set"]
            ):
                tasks = BackgroundTasks()
                tasks.add_task(api.upload_class, class_name=predicted_value)
                local_cache["class_set"].add(predicted_value)
                response.background = tasks
        except:
            cx_logger().warn("unable to record prediction metric", exc_info=True)

    return response
def predict():
    debug = request.args.get("debug", "false").lower() == "true"

    try:
        sample = request.get_json()
    except Exception as e:
        return "Malformed JSON", status.HTTP_400_BAD_REQUEST

    ctx = local_cache["ctx"]
    api = local_cache["api"]

    response = {}

    try:
        result = run_predict(sample, debug)
    except Exception as e:
        cx_logger().exception("prediction failed")
        return prediction_failed(str(e))

    g.prediction = result

    return jsonify(result)
def get_signature_def(stub, model):
    limit = 2
    for i in range(limit):
        try:
            request = create_get_model_metadata_request(model.name)
            resp = stub.GetModelMetadata(request, timeout=10.0)
            sigAny = resp.metadata["signature_def"]
            signature_def_map = get_model_metadata_pb2.SignatureDefMap()
            sigAny.Unpack(signature_def_map)
            sigmap = json_format.MessageToDict(signature_def_map)
            return sigmap["signatureDef"]
        except Exception as e:
            print(e)
            cx_logger().warn(
                "unable to read model metadata for model '{}' - retrying ...".format(model.name)
            )

        time.sleep(5)

    raise CortexException(
        "timeout: unable to read model metadata for model '{}'".format(model.name)
    )
def start():
    cache_dir = os.environ["CORTEX_CACHE_DIR"]
    provider = os.environ["CORTEX_PROVIDER"]
    spec_path = os.environ["CORTEX_API_SPEC"]
    project_dir = os.environ["CORTEX_PROJECT_DIR"]
    model_dir = os.getenv("CORTEX_MODEL_DIR", None)
    tf_serving_port = os.getenv("CORTEX_TF_SERVING_PORT", "9000")
    tf_serving_host = os.getenv("CORTEX_TF_SERVING_HOST", "localhost")

    if provider == "local":
        storage = LocalStorage(os.getenv("CORTEX_CACHE_DIR"))
    else:
        storage = S3(bucket=os.environ["CORTEX_BUCKET"], region=os.environ["AWS_REGION"])

    try:
        raw_api_spec = get_spec(provider, storage, cache_dir, spec_path)
        api = API(provider=provider, storage=storage, cache_dir=cache_dir, **raw_api_spec)
        client = api.predictor.initialize_client(
            model_dir, tf_serving_host=tf_serving_host, tf_serving_port=tf_serving_port
        )
        cx_logger().info("loading the predictor from {}".format(api.predictor.path))
        predictor_impl = api.predictor.initialize_impl(project_dir, client)

        local_cache["api"] = api
        local_cache["provider"] = provider
        local_cache["client"] = client
        local_cache["predictor_impl"] = predictor_impl
        local_cache["predict_fn_args"] = inspect.getfullargspec(predictor_impl.predict).args

        predict_route = "/"
        if provider != "local":
            predict_route = "/predict"
        local_cache["predict_route"] = predict_route
    except:
        cx_logger().exception("failed to start api")
        sys.exit(1)

    if (
        provider != "local"
        and api.monitoring is not None
        and api.monitoring.model_type == "classification"
    ):
        try:
            local_cache["class_set"] = api.get_cached_classes()
        except:
            cx_logger().warn("an error occurred while attempting to load classes", exc_info=True)

    app.add_api_route(local_cache["predict_route"], predict, methods=["POST"])
    app.add_api_route(local_cache["predict_route"], get_summary, methods=["GET"])

    return app
def initialize_client(self, model_dir=None, tf_serving_host=None, tf_serving_port=None):
    if self.type == "onnx":
        from cortex.lib.client.onnx import ONNXClient

        model_path = os.path.join(model_dir, os.path.basename(self.model))
        client = ONNXClient(model_path)
        cx_logger().info("ONNX model signature: {}".format(client.input_signature))
        return client
    elif self.type == "tensorflow":
        from cortex.lib.client.tensorflow import TensorFlowClient

        tf_serving_address = tf_serving_host + ":" + tf_serving_port
        validate_model_dir(model_dir)
        client = TensorFlowClient(tf_serving_address, self.signature_key)
        cx_logger().info("TensorFlow model signature: {}".format(client.input_signature))
        return client

    return None
def get_signature_def(stub):
    limit = 60
    for i in range(limit):
        try:
            request = create_get_model_metadata_request()
            resp = stub.GetModelMetadata(request, timeout=10.0)
            sigAny = resp.metadata["signature_def"]
            signature_def_map = get_model_metadata_pb2.SignatureDefMap()
            sigAny.Unpack(signature_def_map)
            sigmap = json_format.MessageToDict(signature_def_map)
            return sigmap["signatureDef"]
        except Exception as e:
            if isinstance(e, grpc.RpcError) and e.code() == grpc.StatusCode.UNAVAILABLE:
                if i > 6:  # only start logging this after 30 seconds
                    cx_logger().warn(
                        "unable to read model metadata - model is still loading, retrying..."
                    )
            else:
                print(e)  # unexpected error
                cx_logger().warn("unable to read model metadata - retrying...")

        time.sleep(5)

    raise CortexException("timeout: unable to read model metadata")
def extract_signature(signature_def, signature_key, model_name):
    cx_logger().info("signature defs found in model '{}': {}".format(model_name, signature_def))

    available_keys = list(signature_def.keys())
    if len(available_keys) == 0:
        raise UserException("unable to find signature defs in model '{}'".format(model_name))

    if signature_key is None:
        if len(available_keys) == 1:
            cx_logger().info(
                "signature_key was not configured by user, using signature key '{}' for model '{}' (found in the signature def map)".format(
                    available_keys[0], model_name
                )
            )
            signature_key = available_keys[0]
        elif "predict" in signature_def:
            cx_logger().info(
                "signature_key was not configured by user, using signature key 'predict' for model '{}' (found in the signature def map)".format(
                    model_name
                )
            )
            signature_key = "predict"
        else:
            raise UserException(
                "signature_key was not configured by user, please specify one of the following keys '{}' for model '{}' (found in the signature def map)".format(
                    ", ".join(available_keys), model_name
                )
            )
    else:
        if signature_def.get(signature_key) is None:
            possibilities_str = "key: '{}'".format(available_keys[0])
            if len(available_keys) > 1:
                possibilities_str = "keys: '{}'".format("', '".join(available_keys))

            raise UserException(
                "signature_key '{}' was not found in signature def map for model '{}', but found the following {}".format(
                    signature_key, model_name, possibilities_str
                )
            )

    signature_def_val = signature_def.get(signature_key)

    if signature_def_val.get("inputs") is None:
        raise UserException(
            "unable to find 'inputs' in signature def '{}' for model '{}'".format(
                signature_key, model_name
            )
        )

    parsed_signature = {}
    for input_name, input_metadata in signature_def_val["inputs"].items():
        parsed_signature[input_name] = {
            "shape": [int(dim["size"]) for dim in input_metadata["tensorShape"]["dim"]],
            "type": DTYPE_TO_TF_TYPE[input_metadata["dtype"]].name,
        }

    return signature_key, parsed_signature
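For illustration, a minimal sketch of calling extract_signature on a hand-written signature def. The model name, signature key, and tensor metadata below are hypothetical, and the exact "type" string depends on the DTYPE_TO_TF_TYPE mapping defined elsewhere in the module:

# hypothetical signature def, shaped like the dict produced by json_format.MessageToDict()
signature_def = {
    "predict": {
        "inputs": {
            "images": {
                "dtype": "DT_FLOAT",
                "tensorShape": {"dim": [{"size": "-1"}, {"size": "28"}, {"size": "28"}]},
            }
        }
    }
}

# signature_key=None -> the single available key ("predict") is selected
key, parsed = extract_signature(signature_def, None, "example-model")
# key == "predict"
# parsed == {"images": {"shape": [-1, 28, 28], "type": "float32"}}
#   (assuming DTYPE_TO_TF_TYPE["DT_FLOAT"] maps to tf.float32)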
def start(args):
    api = None
    try:
        ctx = Context(s3_path=args.context, cache_dir=args.cache_dir, workload_id=args.workload_id)
        api = ctx.apis_id_map[args.api]
        local_cache["api"] = api
        local_cache["ctx"] = ctx

        if api.get("tensorflow") is None:
            raise CortexException(api["name"], "tensorflow key not configured")

        if api["tensorflow"].get("request_handler") is not None:
            cx_logger().info(
                "loading the request handler from {}".format(api["tensorflow"]["request_handler"])
            )
            local_cache["request_handler"] = ctx.get_request_handler_impl(
                api["name"], args.project_dir
            )
        request_handler = local_cache.get("request_handler")

        if request_handler is not None and util.has_function(request_handler, "pre_inference"):
            cx_logger().info(
                "using pre_inference request handler defined in {}".format(
                    api["tensorflow"]["request_handler"]
                )
            )
        else:
            cx_logger().info("pre_inference request handler not defined")

        if request_handler is not None and util.has_function(request_handler, "post_inference"):
            cx_logger().info(
                "using post_inference request handler defined in {}".format(
                    api["tensorflow"]["request_handler"]
                )
            )
        else:
            cx_logger().info("post_inference request handler not defined")
    except Exception as e:
        cx_logger().exception("failed to start api")
        sys.exit(1)

    try:
        validate_model_dir(args.model_dir)
    except Exception as e:
        cx_logger().exception("failed to validate model")
        sys.exit(1)

    if api.get("tracker") is not None and api["tracker"].get("model_type") == "classification":
        try:
            local_cache["class_set"] = api_utils.get_classes(ctx, api["name"])
        except Exception as e:
            cx_logger().warn("an error occurred while attempting to load classes", exc_info=True)

    channel = grpc.insecure_channel("localhost:" + str(args.tf_serve_port))
    local_cache["stub"] = prediction_service_pb2_grpc.PredictionServiceStub(channel)

    # wait a bit for tf serving to start before querying metadata
    limit = 60
    for i in range(limit):
        try:
            local_cache["model_metadata"] = run_get_model_metadata()
            break
        except Exception as e:
            if i > 6:
                cx_logger().warn(
                    "unable to read model metadata - model is still loading. Retrying..."
                )
            if i == limit - 1:
                cx_logger().exception("retry limit exceeded")
                sys.exit(1)
            time.sleep(5)

    signature_key, parsed_signature = extract_signature(
        local_cache["model_metadata"]["signatureDef"], api["tensorflow"]["signature_key"]
    )
    local_cache["signature_key"] = signature_key
    local_cache["parsed_signature"] = parsed_signature
    cx_logger().info("model_signature: {}".format(local_cache["parsed_signature"]))

    cx_logger().info("{} API is live".format(api["name"]))
    serve(app, listen="*:{}".format(args.port))
def exceptions(e):
    cx_logger().exception(e)
    return jsonify(error=str(e)), 500