def predict(request: Request):
    """Handle one prediction request and wrap the result in a Response.

    The predictor implementation (or the dynamic batcher, when one is
    configured) is pulled from the process-wide cache and invoked with
    kwargs built from the request. The result is serialized by type:
    bytes -> application/octet-stream, str -> text/plain, Response ->
    passed through unchanged, anything else -> JSON (raising
    UserRuntimeException when it is not JSON-serializable).
    """
    predictor_impl = local_cache["predictor_impl"]
    dynamic_batcher = local_cache["dynamic_batcher"]

    kwargs = build_predict_kwargs(request)

    # Route through the batcher when server-side batching is enabled.
    prediction = (
        dynamic_batcher.predict(**kwargs)
        if dynamic_batcher
        else predictor_impl.predict(**kwargs)
    )

    if isinstance(prediction, bytes):
        response = Response(content=prediction, media_type="application/octet-stream")
    elif isinstance(prediction, str):
        response = Response(content=prediction, media_type="text/plain")
    elif isinstance(prediction, Response):
        response = prediction
    else:
        try:
            json_string = json.dumps(prediction)
        except Exception as e:
            raise UserRuntimeException(
                str(e),
                "please return an object that is JSON serializable (including its nested fields), a bytes object, "
                "a string, or a starlette.response.Response object",
            ) from e
        response = Response(content=json_string, media_type="application/json")

    # Fire-and-forget post-processing hook, if the predictor defines one.
    if util.has_method(predictor_impl, "post_predict"):
        kwargs = build_post_predict_kwargs(prediction, request)
        request_thread_pool.submit(predictor_impl.post_predict, **kwargs)

    return response
def predict(request: Request):
    """Handle one prediction request, record monitoring metrics, and respond.

    Invokes the cached predictor (or the dynamic batcher, when configured),
    serializes the result by type (bytes -> octet-stream, str -> text/plain,
    Response -> passed through, otherwise JSON), then — for non-local
    providers with monitoring enabled — posts the predicted value as a
    metric and, for classification models, uploads newly seen classes as a
    background task attached to the response.
    """
    tasks = BackgroundTasks()
    api = local_cache["api"]
    predictor_impl = local_cache["predictor_impl"]
    dynamic_batcher = local_cache["dynamic_batcher"]

    kwargs = build_predict_kwargs(request)

    if dynamic_batcher:
        prediction = dynamic_batcher.predict(**kwargs)
    else:
        prediction = predictor_impl.predict(**kwargs)

    if isinstance(prediction, bytes):
        response = Response(content=prediction, media_type="application/octet-stream")
    elif isinstance(prediction, str):
        response = Response(content=prediction, media_type="text/plain")
    elif isinstance(prediction, Response):
        response = prediction
    else:
        try:
            json_string = json.dumps(prediction)
        except Exception as e:
            raise UserRuntimeException(
                str(e),
                "please return an object that is JSON serializable (including its nested fields), a bytes object, "
                "a string, or a starlette.response.Response object",
            ) from e
        response = Response(content=json_string, media_type="application/json")

    if local_cache["provider"] != "local" and api.monitoring is not None:
        try:
            predicted_value = api.monitoring.extract_predicted_value(prediction)
            api.post_monitoring_metrics(predicted_value)
            if (
                api.monitoring.model_type == "classification"
                and predicted_value not in local_cache["class_set"]
            ):
                tasks.add_task(api.upload_class, class_name=predicted_value)
                local_cache["class_set"].add(predicted_value)
        except Exception:
            # BUGFIX: was a bare `except:`, which also swallowed SystemExit and
            # KeyboardInterrupt. Monitoring remains best-effort (a failure here
            # must not break the prediction response), but only for ordinary
            # exceptions. Also switched the deprecated Logger.warn alias to
            # Logger.warning.
            logger().warning("unable to record prediction metric", exc_info=True)

    # Fire-and-forget post-processing hook, if the predictor defines one.
    if util.has_method(predictor_impl, "post_predict"):
        kwargs = build_post_predict_kwargs(prediction, request)
        request_thread_pool.submit(predictor_impl.post_predict, **kwargs)

    if len(tasks.tasks) > 0:
        response.background = tasks

    return response
def start_fn():
    """Initialize the API server process and register the predict routes.

    Reads configuration from CORTEX_* environment variables, claims a TF
    serving port when multiple servers are configured, loads the user's
    predictor implementation (staggered across processes via a file lock),
    optionally configures server-side batching and classification
    monitoring, registers the POST/GET routes on `app`, and returns `app`.
    Exits the process with status 1 if startup fails.
    """
    provider = os.environ["CORTEX_PROVIDER"]
    project_dir = os.environ["CORTEX_PROJECT_DIR"]
    spec_path = os.environ["CORTEX_API_SPEC"]
    model_dir = os.getenv("CORTEX_MODEL_DIR")
    cache_dir = os.getenv("CORTEX_CACHE_DIR")
    region = os.getenv("AWS_REGION")

    tf_serving_port = os.getenv("CORTEX_TF_BASE_SERVING_PORT", "9000")
    tf_serving_host = os.getenv("CORTEX_TF_SERVING_HOST", "localhost")

    has_multiple_servers = os.getenv("CORTEX_MULTIPLE_TF_SERVERS")
    if has_multiple_servers:
        # Claim the first free TF serving port under a file lock so that
        # concurrent workers don't pick the same one.
        with LockedFile("/run/used_ports.json", "r+") as f:
            used_ports = json.load(f)
            for port in used_ports.keys():
                if not used_ports[port]:
                    tf_serving_port = port
                    used_ports[port] = True
                    break
            f.seek(0)
            json.dump(used_ports, f)
            f.truncate()

    try:
        api = get_api(provider, spec_path, model_dir, cache_dir, region)
        client = api.predictor.initialize_client(
            tf_serving_host=tf_serving_host, tf_serving_port=tf_serving_port
        )

        # Stagger predictor initialization across processes.
        with FileLock("/run/init_stagger.lock"):
            logger().info("loading the predictor from {}".format(api.predictor.path))
            predictor_impl = api.predictor.initialize_impl(project_dir, client)

        local_cache["api"] = api
        local_cache["provider"] = provider
        local_cache["client"] = client
        local_cache["predictor_impl"] = predictor_impl
        local_cache["predict_fn_args"] = inspect.getfullargspec(predictor_impl.predict).args

        if api.server_side_batching_enabled:
            dynamic_batching_config = api.api_spec["predictor"]["server_side_batching"]
            local_cache["dynamic_batcher"] = DynamicBatcher(
                predictor_impl,
                max_batch_size=dynamic_batching_config["max_batch_size"],
                batch_interval=dynamic_batching_config["batch_interval"]
                / NANOSECONDS_IN_SECOND,  # convert nanoseconds to seconds
            )

        if util.has_method(predictor_impl, "post_predict"):
            local_cache["post_predict_fn_args"] = inspect.getfullargspec(
                predictor_impl.post_predict
            ).args

        predict_route = "/"
        if provider != "local":
            predict_route = "/predict"
        local_cache["predict_route"] = predict_route
    except Exception:
        # BUGFIX: was a bare `except:`; `except Exception` still logs and
        # aborts on any startup failure, but no longer traps SystemExit /
        # KeyboardInterrupt.
        logger().exception("failed to start api")
        sys.exit(1)

    if (
        provider != "local"
        and api.monitoring is not None
        and api.monitoring.model_type == "classification"
    ):
        try:
            local_cache["class_set"] = api.get_cached_classes()
        except Exception:
            # Best-effort: missing cached classes is not fatal. BUGFIX: was a
            # bare `except:`, and the deprecated Logger.warn alias.
            logger().warning(
                "an error occurred while attempting to load classes", exc_info=True
            )

    app.add_api_route(local_cache["predict_route"], predict, methods=["POST"])
    app.add_api_route(local_cache["predict_route"], get_summary, methods=["GET"])

    return app
def start_fn():
    """Initialize the API server process and register the predict routes.

    Reads configuration from CORTEX_* environment variables, claims a TF
    serving port when multiple servers are configured, loads the user's
    predictor implementation (staggered across processes via a file lock),
    starts a watchdog thread that kills the process if any predictor cron
    dies, optionally configures server-side batching, registers the
    POST/GET routes on `app`, and returns `app`. Exits the process with
    status 1 if startup fails (non-user errors are also reported via
    capture_exception).
    """
    provider = os.environ["CORTEX_PROVIDER"]
    project_dir = os.environ["CORTEX_PROJECT_DIR"]
    spec_path = os.environ["CORTEX_API_SPEC"]
    model_dir = os.getenv("CORTEX_MODEL_DIR")
    cache_dir = os.getenv("CORTEX_CACHE_DIR")
    region = os.getenv("AWS_REGION")

    tf_serving_port = os.getenv("CORTEX_TF_BASE_SERVING_PORT", "9000")
    tf_serving_host = os.getenv("CORTEX_TF_SERVING_HOST", "localhost")

    try:
        has_multiple_servers = os.getenv("CORTEX_MULTIPLE_TF_SERVERS")
        if has_multiple_servers:
            # Claim the first free TF serving port under a file lock so that
            # concurrent workers don't pick the same one.
            with LockedFile("/run/used_ports.json", "r+") as f:
                used_ports = json.load(f)
                for port in used_ports.keys():
                    if not used_ports[port]:
                        tf_serving_port = port
                        used_ports[port] = True
                        break
                f.seek(0)
                json.dump(used_ports, f)
                f.truncate()

        api = get_api(provider, spec_path, model_dir, cache_dir, region)
        client = api.predictor.initialize_client(
            tf_serving_host=tf_serving_host, tf_serving_port=tf_serving_port
        )

        # Stagger predictor initialization across processes.
        with FileLock("/run/init_stagger.lock"):
            logger.info("loading the predictor from {}".format(api.predictor.path))
            predictor_impl = api.predictor.initialize_impl(project_dir, client)

        # crons only stop if an unhandled exception occurs
        def check_if_crons_have_failed():
            while True:
                for cron in api.predictor.crons:
                    if not cron.is_alive():
                        os.kill(os.getpid(), signal.SIGQUIT)
                time.sleep(1)

        threading.Thread(target=check_if_crons_have_failed, daemon=True).start()

        local_cache["api"] = api
        local_cache["provider"] = provider
        local_cache["client"] = client
        local_cache["predictor_impl"] = predictor_impl
        local_cache["predict_fn_args"] = inspect.getfullargspec(predictor_impl.predict).args

        if api.python_server_side_batching_enabled:
            dynamic_batching_config = api.api_spec["predictor"]["server_side_batching"]
            local_cache["dynamic_batcher"] = DynamicBatcher(
                predictor_impl,
                max_batch_size=dynamic_batching_config["max_batch_size"],
                batch_interval=dynamic_batching_config["batch_interval"]
                / NANOSECONDS_IN_SECOND,  # convert nanoseconds to seconds
            )

        if util.has_method(predictor_impl, "post_predict"):
            local_cache["post_predict_fn_args"] = inspect.getfullargspec(
                predictor_impl.post_predict
            ).args

        predict_route = "/predict"
        local_cache["predict_route"] = predict_route
    except Exception as err:
        # Simplified from `except (UserRuntimeException, Exception)` — the
        # tuple was redundant, since UserRuntimeException is itself an
        # Exception; the isinstance check below is what distinguishes
        # user errors (not reported upstream) from internal ones.
        if not isinstance(err, UserRuntimeException):
            capture_exception(err)
        logger.exception("failed to start api")
        sys.exit(1)

    app.add_api_route(local_cache["predict_route"], predict, methods=["POST"])
    app.add_api_route(local_cache["predict_route"], get_summary, methods=["GET"])

    return app