Example #1
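# FastAPI route handler: runs the user's predictor on the request and adapts
# whatever it returns (bytes, str, Response, or a JSON-serializable object)
# into an HTTP response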
def predict(request: Request):
    tasks = BackgroundTasks()
    api = local_cache["api"]
    predictor_impl = local_cache["predictor_impl"]
    kwargs = build_predict_kwargs(request)

    prediction = predictor_impl.predict(**kwargs)

    if isinstance(prediction, bytes):
        response = Response(content=prediction, media_type="application/octet-stream")
    elif isinstance(prediction, str):
        response = Response(content=prediction, media_type="text/plain")
    elif isinstance(prediction, Response):
        response = prediction
    else:
        try:
            json_string = json.dumps(prediction)
        except Exception as e:
            raise UserRuntimeException(
                str(e),
                "please return an object that is JSON serializable (including its nested fields), a bytes object, a string, or a starlette.responses.Response object",
            ) from e
        response = Response(content=json_string, media_type="application/json")

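    # report the predicted value for monitoring (skipped for the local provider)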
    if local_cache["provider"] != "local" and api.monitoring is not None:
        try:
            predicted_value = api.monitoring.extract_predicted_value(prediction)
            api.post_monitoring_metrics(predicted_value)
            if (
                api.monitoring.model_type == "classification"
                and predicted_value not in local_cache["class_set"]
            ):
                tasks.add_task(api.upload_class, class_name=predicted_value)
                local_cache["class_set"].add(predicted_value)
        except Exception:
            cx_logger().warning("unable to record prediction metric", exc_info=True)

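    # run the user's optional post_predict hook off the request path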
    if util.has_method(predictor_impl, "post_predict"):
        kwargs = build_post_predict_kwargs(prediction, request)
        request_thread_pool.submit(predictor_impl.post_predict, **kwargs)

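    # attach queued background tasks; Starlette runs them after the response is sent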
    if len(tasks.tasks) > 0:
        response.background = tasks

    return response
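
The handler above only adapts whatever the predictor returns; the predictor itself is user code. Below is a minimal sketch of a class this handler could serve. The PythonPredictor name and the predict/post_predict hooks mirror what the handler calls, but the config keys, the post_predict signature, and the scoring logic are illustrative assumptions rather than the actual interface.

class PythonPredictor:
    def __init__(self, config):
        # config would come from the API spec; "threshold" is a hypothetical key
        self.threshold = float(config.get("threshold", 0.5))

    def predict(self, payload):
        # returning a dict exercises the JSON branch of the handler above;
        # bytes, str, or a starlette Response would take the other branches
        score = float(payload["value"])
        return {"label": "positive" if score > self.threshold else "negative"}

    def post_predict(self, response, payload):
        # optional hook; the handler submits it to a thread pool when defined
        print("served:", response)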
Example #2
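# application factory: reads configuration from the environment, loads the API
# spec and the user's predictor, and registers the serving routes on the app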
def start_fn():
    cache_dir = os.environ["CORTEX_CACHE_DIR"]
    provider = os.environ["CORTEX_PROVIDER"]
    spec_path = os.environ["CORTEX_API_SPEC"]
    project_dir = os.environ["CORTEX_PROJECT_DIR"]

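    # optional settings: model directory and the TensorFlow Serving sidecar address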
    model_dir = os.getenv("CORTEX_MODEL_DIR")
    tf_serving_port = os.getenv("CORTEX_TF_BASE_SERVING_PORT", "9000")
    tf_serving_host = os.getenv("CORTEX_TF_SERVING_HOST", "localhost")

    if provider == "local":
        storage = LocalStorage(cache_dir)
    else:
        storage = S3(bucket=os.environ["CORTEX_BUCKET"], region=os.environ["AWS_REGION"])

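    # CORTEX_MULTIPLE_TF_SERVERS is truthy when set to any non-empty value;
    # in that case, claim a free TF serving port from the lock-protected registry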
    has_multiple_servers = os.getenv("CORTEX_MULTIPLE_TF_SERVERS")
    if has_multiple_servers:
        with FileLock("/run/used_ports.json.lock"):
            with open("/run/used_ports.json", "r+") as f:
                used_ports = json.load(f)
                for port, in_use in used_ports.items():
                    if not in_use:
                        tf_serving_port = port
                        used_ports[port] = True
                        break
                f.seek(0)
                json.dump(used_ports, f)
                f.truncate()

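    # build the API from its spec, initialize the user's predictor, and cache the results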
    try:
        raw_api_spec = get_spec(provider, storage, cache_dir, spec_path)
        api = API(
            provider=provider,
            storage=storage,
            model_dir=model_dir,
            cache_dir=cache_dir,
            **raw_api_spec,
        )
        client = api.predictor.initialize_client(
            tf_serving_host=tf_serving_host, tf_serving_port=tf_serving_port
        )
        cx_logger().info("loading the predictor from {}".format(api.predictor.path))
        predictor_impl = api.predictor.initialize_impl(project_dir, client, raw_api_spec, None)

        local_cache["api"] = api
        local_cache["provider"] = provider
        local_cache["client"] = client
        local_cache["predictor_impl"] = predictor_impl
        local_cache["predict_fn_args"] = inspect.getfullargspec(predictor_impl.predict).args
        if util.has_method(predictor_impl, "post_predict"):
            local_cache["post_predict_fn_args"] = inspect.getfullargspec(
                predictor_impl.post_predict
            ).args

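        # deployed providers serve predictions on /predict; the local provider uses /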
        predict_route = "/"
        if provider != "local":
            predict_route = "/predict"
        local_cache["predict_route"] = predict_route
    except Exception:
        cx_logger().exception("failed to start api")
        sys.exit(1)

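    # for classification APIs, pre-load the classes that have already been recorded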
    if (
        provider != "local"
        and api.monitoring is not None
        and api.monitoring.model_type == "classification"
    ):
        try:
            local_cache["class_set"] = api.get_cached_classes()
        except Exception:
            cx_logger().warning("an error occurred while attempting to load classes", exc_info=True)

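    # register the prediction (POST) and summary (GET) endpoints on the same route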
    app.add_api_route(local_cache["predict_route"], predict, methods=["POST"])
    app.add_api_route(local_cache["predict_route"], get_summary, methods=["GET"])

    return app
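
Since start_fn() builds and returns the FastAPI app, it can be handed to an ASGI server as an application factory. A minimal sketch, assuming the module is saved as serve.py and uvicorn is the server (neither is stated in the example):

import uvicorn

if __name__ == "__main__":
    # factory=True tells uvicorn to call start_fn() to obtain the app
    uvicorn.run("serve:start_fn", factory=True, host="0.0.0.0", port=8080)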