Ejemplo n.º 1
0
def predict(request: Request):
    """Run the user's predictor on a request and wrap the result in a Response.

    The call is routed through the dynamic batcher when one is present in
    ``local_cache``; otherwise ``predictor_impl.predict`` is called directly.
    The prediction is converted to a response according to its type:

    * ``bytes``     -> ``application/octet-stream``
    * ``str``       -> ``text/plain``
    * ``Response``  -> returned unchanged
    * anything else -> serialized with ``json.dumps`` as ``application/json``

    If the predictor defines ``post_predict``, it is submitted to
    ``request_thread_pool`` and the response is returned without waiting for it.

    Raises:
        UserRuntimeException: if the prediction is not bytes/str/Response and
            cannot be serialized to JSON.
    """
    predictor_impl = local_cache["predictor_impl"]
    # "dynamic_batcher" may be absent from local_cache when server-side batching
    # is disabled, so read it defensively instead of risking a KeyError per request.
    dynamic_batcher = local_cache.get("dynamic_batcher")
    kwargs = build_predict_kwargs(request)

    if dynamic_batcher is not None:
        prediction = dynamic_batcher.predict(**kwargs)
    else:
        prediction = predictor_impl.predict(**kwargs)

    if isinstance(prediction, bytes):
        response = Response(content=prediction, media_type="application/octet-stream")
    elif isinstance(prediction, str):
        response = Response(content=prediction, media_type="text/plain")
    elif isinstance(prediction, Response):
        response = prediction
    else:
        try:
            json_string = json.dumps(prediction)
        except Exception as e:
            raise UserRuntimeException(
                str(e),
                "please return an object that is JSON serializable (including its nested fields), a bytes object, "
                "a string, or a starlette.response.Response object",
            ) from e
        response = Response(content=json_string, media_type="application/json")

    if util.has_method(predictor_impl, "post_predict"):
        # Fire-and-forget: post_predict runs on the thread pool and its result
        # is not awaited before the response is returned.
        kwargs = build_post_predict_kwargs(prediction, request)
        request_thread_pool.submit(predictor_impl.post_predict, **kwargs)

    return response
Ejemplo n.º 2
0
def predict(request: Request):
    """Run the user's predictor on a request, record monitoring data, and build a Response.

    The call is routed through the dynamic batcher when one is present in
    ``local_cache``; otherwise ``predictor_impl.predict`` is called directly.
    The prediction is converted to a response according to its type:

    * ``bytes``     -> ``application/octet-stream``
    * ``str``       -> ``text/plain``
    * ``Response``  -> returned unchanged
    * anything else -> serialized with ``json.dumps`` as ``application/json``

    When the provider is not ``"local"`` and the API has monitoring configured,
    the predicted value is extracted and posted as a metric. For classification
    models, a class not yet in ``local_cache["class_set"]`` is queued for
    ``api.upload_class`` as a response background task and added to the set.
    Monitoring is best-effort: any failure there is logged, not raised.

    If the predictor defines ``post_predict``, it is submitted to
    ``request_thread_pool`` and the response is returned without waiting for it.

    Raises:
        UserRuntimeException: if the prediction is not bytes/str/Response and
            cannot be serialized to JSON.
    """
    tasks = BackgroundTasks()
    api = local_cache["api"]
    predictor_impl = local_cache["predictor_impl"]
    # "dynamic_batcher" may be absent from local_cache when server-side batching
    # is disabled, so read it defensively instead of risking a KeyError per request.
    dynamic_batcher = local_cache.get("dynamic_batcher")
    kwargs = build_predict_kwargs(request)

    if dynamic_batcher is not None:
        prediction = dynamic_batcher.predict(**kwargs)
    else:
        prediction = predictor_impl.predict(**kwargs)

    if isinstance(prediction, bytes):
        response = Response(content=prediction,
                            media_type="application/octet-stream")
    elif isinstance(prediction, str):
        response = Response(content=prediction, media_type="text/plain")
    elif isinstance(prediction, Response):
        response = prediction
    else:
        try:
            json_string = json.dumps(prediction)
        except Exception as e:
            raise UserRuntimeException(
                str(e),
                "please return an object that is JSON serializable (including its nested fields), a bytes object, "
                "a string, or a starlette.response.Response object",
            ) from e
        response = Response(content=json_string, media_type="application/json")

    if local_cache["provider"] != "local" and api.monitoring is not None:
        try:
            predicted_value = api.monitoring.extract_predicted_value(
                prediction)
            api.post_monitoring_metrics(predicted_value)
            if (api.monitoring.model_type == "classification"
                    and predicted_value not in local_cache["class_set"]):
                tasks.add_task(api.upload_class, class_name=predicted_value)
                local_cache["class_set"].add(predicted_value)
        except Exception:
            # Best-effort: a monitoring failure must not fail the prediction
            # request. Catch Exception rather than a bare except, so that
            # SystemExit/KeyboardInterrupt still propagate.
            logger().warning("unable to record prediction metric", exc_info=True)

    if util.has_method(predictor_impl, "post_predict"):
        # Fire-and-forget: post_predict runs on the thread pool and its result
        # is not awaited before the response is returned.
        kwargs = build_post_predict_kwargs(prediction, request)
        request_thread_pool.submit(predictor_impl.post_predict, **kwargs)

    if tasks.tasks:
        # NOTE(review): this replaces any background task already set on a
        # Response returned by the user's predictor.
        response.background = tasks

    return response
Ejemplo n.º 3
0
def start_fn():
    """Initialize the predictor API from environment variables and return the app.

    Reads the required settings ``CORTEX_PROVIDER``, ``CORTEX_PROJECT_DIR`` and
    ``CORTEX_API_SPEC``, plus the optional ones. Then it:

    * optionally claims a TF Serving port from ``/run/used_ports.json``;
    * builds the API, client and predictor, and populates ``local_cache``;
    * optionally preloads the cached classes for classification monitoring;
    * registers the POST (``predict``) and GET (``get_summary``) routes.

    Any failure during the core initialization is logged and the process exits
    with status 1.

    Returns:
        The module-level ``app`` with the prediction routes registered.
    """
    provider = os.environ["CORTEX_PROVIDER"]
    project_dir = os.environ["CORTEX_PROJECT_DIR"]
    spec_path = os.environ["CORTEX_API_SPEC"]

    model_dir = os.getenv("CORTEX_MODEL_DIR")
    cache_dir = os.getenv("CORTEX_CACHE_DIR")
    region = os.getenv("AWS_REGION")

    tf_serving_port = os.getenv("CORTEX_TF_BASE_SERVING_PORT", "9000")
    tf_serving_host = os.getenv("CORTEX_TF_SERVING_HOST", "localhost")

    try:
        # Port claiming lives inside the try so that a failure here is reported
        # as "failed to start api", like every other startup error.
        if os.getenv("CORTEX_MULTIPLE_TF_SERVERS"):
            # /run/used_ports.json maps port -> in-use flag. Claim the first
            # free port and write the updated map back in place.
            # (LockedFile presumably serializes this read-modify-write across
            # processes — confirm against its implementation.)
            with LockedFile("/run/used_ports.json", "r+") as f:
                used_ports = json.load(f)
                for port, in_use in used_ports.items():
                    if not in_use:
                        tf_serving_port = port
                        used_ports[port] = True
                        break
                else:
                    logger().warning(
                        "no unused TF Serving port found in /run/used_ports.json; "
                        "falling back to port {}".format(tf_serving_port))
                f.seek(0)
                json.dump(used_ports, f)
                f.truncate()

        api = get_api(provider, spec_path, model_dir, cache_dir, region)

        client = api.predictor.initialize_client(
            tf_serving_host=tf_serving_host, tf_serving_port=tf_serving_port)

        with FileLock("/run/init_stagger.lock"):
            logger().info("loading the predictor from {}".format(
                api.predictor.path))
            predictor_impl = api.predictor.initialize_impl(project_dir, client)

        local_cache["api"] = api
        local_cache["provider"] = provider
        local_cache["client"] = client
        local_cache["predictor_impl"] = predictor_impl
        local_cache["predict_fn_args"] = inspect.getfullargspec(
            predictor_impl.predict).args

        if api.server_side_batching_enabled:
            dynamic_batching_config = api.api_spec["predictor"][
                "server_side_batching"]
            local_cache["dynamic_batcher"] = DynamicBatcher(
                predictor_impl,
                max_batch_size=dynamic_batching_config["max_batch_size"],
                batch_interval=dynamic_batching_config["batch_interval"] /
                NANOSECONDS_IN_SECOND,  # convert nanoseconds to seconds
            )
        else:
            # Always define the key so request handlers can read it safely.
            local_cache["dynamic_batcher"] = None

        if util.has_method(predictor_impl, "post_predict"):
            local_cache["post_predict_fn_args"] = inspect.getfullargspec(
                predictor_impl.post_predict).args

        local_cache["predict_route"] = "/" if provider == "local" else "/predict"
    except Exception:
        logger().exception("failed to start api")
        sys.exit(1)

    if (provider != "local" and api.monitoring is not None
            and api.monitoring.model_type == "classification"):
        try:
            local_cache["class_set"] = api.get_cached_classes()
        except Exception:
            # Non-fatal: the API still starts. NOTE(review): "class_set" then
            # stays unset, so class tracking in predict() will log on each call.
            logger().warning("an error occurred while attempting to load classes",
                             exc_info=True)

    app.add_api_route(local_cache["predict_route"], predict, methods=["POST"])
    app.add_api_route(local_cache["predict_route"],
                      get_summary,
                      methods=["GET"])

    return app
Ejemplo n.º 4
0
def start_fn():
    """Initialize the predictor API from environment variables and return the app.

    Reads the required settings ``CORTEX_PROVIDER``, ``CORTEX_PROJECT_DIR`` and
    ``CORTEX_API_SPEC``, plus the optional ones. Then it:

    * optionally claims a TF Serving port from ``/run/used_ports.json``;
    * builds the API, client and predictor;
    * starts a daemon thread that sends SIGQUIT to this process if any predictor
      cron thread stops;
    * populates ``local_cache`` and registers the POST (``predict``) and GET
      (``get_summary``) routes on ``/predict``.

    Any failure during initialization is logged and the process exits with
    status 1. Exceptions other than UserRuntimeException are also passed to
    ``capture_exception``.

    Returns:
        The module-level ``app`` with the prediction routes registered.
    """
    provider = os.environ["CORTEX_PROVIDER"]
    project_dir = os.environ["CORTEX_PROJECT_DIR"]
    spec_path = os.environ["CORTEX_API_SPEC"]

    model_dir = os.getenv("CORTEX_MODEL_DIR")
    cache_dir = os.getenv("CORTEX_CACHE_DIR")
    region = os.getenv("AWS_REGION")

    tf_serving_port = os.getenv("CORTEX_TF_BASE_SERVING_PORT", "9000")
    tf_serving_host = os.getenv("CORTEX_TF_SERVING_HOST", "localhost")

    try:
        if os.getenv("CORTEX_MULTIPLE_TF_SERVERS"):
            # /run/used_ports.json maps port -> in-use flag. Claim the first
            # free port and write the updated map back in place.
            # (LockedFile presumably serializes this read-modify-write across
            # processes — confirm against its implementation.)
            with LockedFile("/run/used_ports.json", "r+") as f:
                used_ports = json.load(f)
                for port, in_use in used_ports.items():
                    if not in_use:
                        tf_serving_port = port
                        used_ports[port] = True
                        break
                else:
                    logger.warning(
                        "no unused TF Serving port found in /run/used_ports.json; "
                        "falling back to port {}".format(tf_serving_port))
                f.seek(0)
                json.dump(used_ports, f)
                f.truncate()

        api = get_api(provider, spec_path, model_dir, cache_dir, region)

        client = api.predictor.initialize_client(
            tf_serving_host=tf_serving_host, tf_serving_port=tf_serving_port)

        with FileLock("/run/init_stagger.lock"):
            logger.info("loading the predictor from {}".format(
                api.predictor.path))
            predictor_impl = api.predictor.initialize_impl(project_dir, client)

        # Crons only stop if an unhandled exception occurs, so a dead cron is
        # treated as fatal. Poll once per second and send a single SIGQUIT per
        # tick, rather than one per dead cron.
        def check_if_crons_have_failed():
            while True:
                if any(not cron.is_alive() for cron in api.predictor.crons):
                    os.kill(os.getpid(), signal.SIGQUIT)
                time.sleep(1)

        threading.Thread(target=check_if_crons_have_failed,
                         daemon=True).start()

        local_cache["api"] = api
        local_cache["provider"] = provider
        local_cache["client"] = client
        local_cache["predictor_impl"] = predictor_impl
        local_cache["predict_fn_args"] = inspect.getfullargspec(
            predictor_impl.predict).args

        if api.python_server_side_batching_enabled:
            dynamic_batching_config = api.api_spec["predictor"][
                "server_side_batching"]
            local_cache["dynamic_batcher"] = DynamicBatcher(
                predictor_impl,
                max_batch_size=dynamic_batching_config["max_batch_size"],
                batch_interval=dynamic_batching_config["batch_interval"] /
                NANOSECONDS_IN_SECOND,  # convert nanoseconds to seconds
            )
        else:
            # Always define the key so request handlers can read it safely.
            local_cache["dynamic_batcher"] = None

        if util.has_method(predictor_impl, "post_predict"):
            local_cache["post_predict_fn_args"] = inspect.getfullargspec(
                predictor_impl.post_predict).args

        local_cache["predict_route"] = "/predict"

    except (UserRuntimeException, Exception) as err:
        # User errors are expected and are not reported to capture_exception;
        # everything else is.
        if not isinstance(err, UserRuntimeException):
            capture_exception(err)
        logger.exception("failed to start api")
        sys.exit(1)

    app.add_api_route(local_cache["predict_route"], predict, methods=["POST"])
    app.add_api_route(local_cache["predict_route"],
                      get_summary,
                      methods=["GET"])

    return app