Example #1
File: api.py Project: xinhen/cortex
    def post_request_metrics(self,
                             status_code,
                             start_time,
                             prediction_value=None):
        metrics_list = []
        metrics_list.append(self.status_code_metric(status_code))

        if prediction_value is not None:
            metrics_list.append(self.prediction_metrics(prediction_value))

        metrics_list.append(self.latency_metric(start_time))

        try:
            if self.statsd is None:
                raise CortexException(
                    "statsd client not initialized")  # unexpected

            for metric in metrics_list:
                tags = [
                    "{}:{}".format(dim["Name"], dim["Value"])
                    for dim in metric["Dimensions"]
                ]
                if metric.get("Unit") == "Count":
                    self.statsd.increment(metric["MetricName"],
                                          value=metric["Value"],
                                          tags=tags)
                else:
                    self.statsd.histogram(metric["MetricName"],
                                          value=metric["Value"],
                                          tags=tags)
        except:
            cx_logger().warn("failure encountered while publishing metrics",
                             exc_info=True)
Example #2
def predict(
        request: Any = Body(..., media_type="application/json"), debug=False):
    api = local_cache["api"]
    predictor_impl = local_cache["predictor_impl"]

    debug_obj("payload", request, debug)
    prediction = predictor_impl.predict(request)
    debug_obj("prediction", prediction, debug)

    try:
        json_string = json.dumps(prediction)
    except:
        json_string = util.json_tricks_encoder().encode(prediction)

    response = Response(content=json_string, media_type="application/json")

    if api.tracker is not None:
        try:
            predicted_value = api.tracker.extract_predicted_value(prediction)
            api.post_tracker_metrics(predicted_value)
            if (api.tracker.model_type == "classification"
                    and predicted_value not in local_cache["class_set"]):
                tasks = BackgroundTasks()
                tasks.add_task(api.upload_class, class_name=predicted_value)
                local_cache["class_set"].add(predicted_value)
                response.background = tasks
        except:
            cx_logger().warn("unable to record prediction metric",
                             exc_info=True)

    return response
Example #3
def after_request(response):
    response.headers["Access-Control-Allow-Origin"] = "*"
    response.headers["Access-Control-Allow-Headers"] = request.headers.get(
        "Access-Control-Request-Headers", "*"
    )

    if not (request.path == "/predict" and request.method == "POST"):
        return response

    api = local_cache["api"]

    cx_logger().info(response.status)

    prediction = None
    if "prediction" in g:
        prediction = g.prediction

    try:
        if api.tracker is not None:
            predicted_value = api.tracker.extract_predicted_value(prediction)
            api.post_request_metrics(response.status_code, g.start_time, predicted_value)
            if predicted_value is not None and predicted_value not in local_cache["class_set"]:
                api.upload_class(predicted_value)
                local_cache["class_set"].add(predicted_value)
    except Exception as e:
        cx_logger().warn("unable to record prediction metric", exc_info=True)

    return response
Example #4
def predict():
    debug = request.args.get("debug", "false").lower() == "true"

    try:
        payload = request.get_json()
    except:
        return "malformed json", status.HTTP_400_BAD_REQUEST

    api = local_cache["api"]
    predictor = local_cache["predictor"]

    try:
        try:
            debug_obj("payload", payload, debug)
            output = predictor.predict(payload, api["predictor"]["metadata"])
            debug_obj("prediction", output, debug)
        except Exception as e:
            raise UserRuntimeException(api["predictor"]["path"], "predict",
                                       str(e)) from e
    except Exception as e:
        cx_logger().exception("prediction failed")
        return prediction_failed(str(e))

    g.prediction = output
    return jsonify(output)
Example #5
def start(args):
    download = json.loads(base64.urlsafe_b64decode(args.download))
    for download_arg in download:
        from_path = download_arg["from"]
        to_path = download_arg["to"]
        item_name = download_arg.get("item_name", "")
        bucket_name, prefix = S3.deconstruct_s3_path(from_path)
        s3_client = S3(bucket_name, client_config={})

        if item_name != "":
            cx_logger().info("downloading {} from {}".format(item_name, from_path))
        s3_client.download(prefix, to_path)

        if download_arg.get("unzip", False):
            if item_name != "":
                cx_logger().info("unzipping {}".format(item_name))
            util.extract_zip(
                os.path.join(to_path, os.path.basename(from_path)), delete_zip_file=True
            )

        if download_arg.get("tf_model_version_rename", "") != "":
            dest = util.trim_suffix(download_arg["tf_model_version_rename"], "/")
            dir_path = os.path.dirname(dest)
            entries = os.listdir(dir_path)
            if len(entries) == 1:
                src = os.path.join(dir_path, entries[0])
                os.rename(src, dest)
Example #6
def post_request_metrics(ctx, api, response, prediction_payload, start_time,
                         class_set):
    api_name = api["name"]
    api_dimensions = api_metric_dimensions(ctx, api_name)
    metrics_list = []
    metrics_list += status_code_metric(api_dimensions, response.status_code)

    if prediction_payload is not None:
        if api.get("tracker") is not None:
            try:
                prediction = extract_prediction(api, prediction_payload)

                if api["tracker"]["model_type"] == "classification":
                    cache_classes(ctx, api, prediction, class_set)

                metrics_list += prediction_metrics(api_dimensions, api,
                                                   prediction)
            except Exception as e:
                cx_logger().warn("unable to record prediction metric",
                                 exc_info=True)

    metrics_list += latency_metric(api_dimensions, start_time)
    try:
        ctx.publish_metrics(metrics_list)
    except Exception as e:
        cx_logger().warn("failure encountered while publishing metrics",
                         exc_info=True)
Example #7
    def initialize_client(self, tf_serving_host=None, tf_serving_port=None):
        signature_message = None
        if self.type == "onnx":
            from cortex.lib.client.onnx import ONNXClient

            client = ONNXClient(self.models)
            if self.models[0].name == consts.SINGLE_MODEL_NAME:
                signature_message = "ONNX model signature: {}".format(
                    client.input_signatures[consts.SINGLE_MODEL_NAME]
                )
            else:
                signature_message = "ONNX model signatures: {}".format(client.input_signatures)
            cx_logger().info(signature_message)
            return client
        elif self.type == "tensorflow":
            from cortex.lib.client.tensorflow import TensorFlowClient

            for model in self.models:
                validate_model_dir(model.base_path)

            tf_serving_address = tf_serving_host + ":" + tf_serving_port
            client = TensorFlowClient(tf_serving_address, self.models)
            if self.models[0].name == consts.SINGLE_MODEL_NAME:
                signature_message = "TensorFlow model signature: {}".format(
                    client.input_signatures[consts.SINGLE_MODEL_NAME]
                )
            else:
                signature_message = "TensorFlow model signatures: {}".format(
                    client.input_signatures
                )
            cx_logger().info(signature_message)
            return client

        return None
Example #8
def predict(
        request: Any = Body(..., media_type="application/json"), debug=False):
    api = local_cache["api"]
    predictor_impl = local_cache["predictor_impl"]

    debug_obj("payload", request, debug)
    prediction = predictor_impl.predict(request)

    try:
        json_string = json.dumps(prediction)
    except Exception as e:
        raise UserRuntimeException(
            f"the return value of predict() or one of its nested values is not JSON serializable",
            str(e),
        ) from e

    debug_obj("prediction", json_string, debug)

    response = Response(content=json_string, media_type="application/json")

    if api.tracker is not None:
        try:
            predicted_value = api.tracker.extract_predicted_value(prediction)
            api.post_tracker_metrics(predicted_value)
            if (api.tracker.model_type == "classification"
                    and predicted_value not in local_cache["class_set"]):
                tasks = BackgroundTasks()
                tasks.add_task(api.upload_class, class_name=predicted_value)
                local_cache["class_set"].add(predicted_value)
                response.background = tasks
        except:
            cx_logger().warn("unable to record prediction metric",
                             exc_info=True)

    return response
Example #9
def start():
    cache_dir = os.environ["CORTEX_CACHE_DIR"]
    provider = os.environ["CORTEX_PROVIDER"]
    api_spec_path = os.environ["CORTEX_API_SPEC"]
    job_spec_path = os.environ["CORTEX_JOB_SPEC"]
    project_dir = os.environ["CORTEX_PROJECT_DIR"]

    model_dir = os.getenv("CORTEX_MODEL_DIR")
    tf_serving_port = os.getenv("CORTEX_TF_BASE_SERVING_PORT", "9000")
    tf_serving_host = os.getenv("CORTEX_TF_SERVING_HOST", "localhost")

    storage = S3(bucket=os.environ["CORTEX_BUCKET"],
                 region=os.environ["AWS_REGION"])

    has_multiple_servers = os.getenv("CORTEX_MULTIPLE_TF_SERVERS")
    if has_multiple_servers:
        with FileLock("/run/used_ports.json.lock"):
            with open("/run/used_ports.json", "r+") as f:
                used_ports = json.load(f)
                for port in used_ports.keys():
                    if not used_ports[port]:
                        tf_serving_port = port
                        used_ports[port] = True
                        break
                f.seek(0)
                json.dump(used_ports, f)
                f.truncate()

    raw_api_spec = get_spec(provider, storage, cache_dir, api_spec_path)
    job_spec = get_job_spec(storage, cache_dir, job_spec_path)

    api = API(provider=provider,
              storage=storage,
              model_dir=model_dir,
              cache_dir=cache_dir,
              **raw_api_spec)

    client = api.predictor.initialize_client(tf_serving_host=tf_serving_host,
                                             tf_serving_port=tf_serving_port)
    cx_logger().info("loading the predictor from {}".format(
        api.predictor.path))
    predictor_impl = api.predictor.initialize_impl(project_dir, client,
                                                   raw_api_spec, job_spec)

    local_cache["api_spec"] = api
    local_cache["provider"] = provider
    local_cache["job_spec"] = job_spec
    local_cache["predictor_impl"] = predictor_impl
    local_cache["predict_fn_args"] = inspect.getfullargspec(
        predictor_impl.predict).args
    local_cache["sqs_client"] = boto3.client(
        "sqs", region_name=os.environ["AWS_REGION"])

    open("/mnt/workspace/api_readiness.txt", "a").close()

    cx_logger().info("polling for batches...")
    sqs_loop()
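
A note on the port bookkeeping above: /run/used_ports.json is read under a file lock, and the first port still marked false is claimed and flipped to true. A minimal sketch of that claiming step on an in-memory dict (the file contents here are hypothetical, not from the source):

# port -> already claimed by another TF Serving worker?
used_ports = {"9000": False, "9001": False}
tf_serving_port = None
for port, taken in used_ports.items():
    if not taken:
        tf_serving_port = port
        used_ports[port] = True
        break
# tf_serving_port == "9000"; used_ports == {"9000": True, "9001": False}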
Example #10
def extract_waitress_params(config):
    waitress_kwargs = {}
    if config is not None:
        for key, value in config.items():
            if key.startswith("waitress_"):
                waitress_kwargs[key[len("waitress_") :]] = value

    if len(waitress_kwargs) > 0:
        cx_logger().info("waitress parameters: {}".format(waitress_kwargs))

    return waitress_kwargs
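
For reference, a hypothetical config and the result extract_waitress_params would produce for it (keys and values are illustrative):

# Keys prefixed with "waitress_" are forwarded to waitress.serve() with the
# prefix stripped; all other keys are ignored.
config = {"waitress_threads": 8, "waitress_connection_limit": 200, "batch_size": 32}
# extract_waitress_params(config) -> {"threads": 8, "connection_limit": 200}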
Example #11
def predict():
    debug = request.args.get("debug", "false").lower() == "true"

    try:
        payload = request.get_json()
    except Exception as e:
        return "malformed json", status.HTTP_400_BAD_REQUEST

    sess = local_cache["sess"]
    api = local_cache["api"]
    ctx = local_cache["ctx"]
    request_handler = local_cache.get("request_handler")
    input_metadata = local_cache["input_metadata"]
    output_metadata = local_cache["output_metadata"]

    try:
        debug_obj("payload", payload, debug)

        prepared_payload = payload
        if request_handler is not None and util.has_function(
                request_handler, "pre_inference"):
            try:
                prepared_payload = request_handler.pre_inference(
                    payload, input_metadata, api["onnx"]["metadata"])
                debug_obj("pre_inference", prepared_payload, debug)
            except Exception as e:
                raise UserRuntimeException(api["onnx"]["request_handler"],
                                           "pre_inference request handler",
                                           str(e)) from e

        inference_input = convert_to_onnx_input(prepared_payload,
                                                input_metadata)
        model_output = sess.run([], inference_input)

        debug_obj("inference", model_output, debug)
        result = model_output
        if request_handler is not None and util.has_function(
                request_handler, "post_inference"):
            try:
                result = request_handler.post_inference(
                    model_output, output_metadata, api["onnx"]["metadata"])
            except Exception as e:
                raise UserRuntimeException(api["onnx"]["request_handler"],
                                           "post_inference request handler",
                                           str(e)) from e

            debug_obj("post_inference", result, debug)
    except Exception as e:
        cx_logger().exception("prediction failed")
        return prediction_failed(str(e))

    g.prediction = result
    return jsonify(result)
Example #12
def start(args):
    api = None
    try:
        ctx = Context(s3_path=args.context,
                      cache_dir=args.cache_dir,
                      workload_id=args.workload_id)
        api = ctx.apis_id_map[args.api]
        local_cache["api"] = api
        local_cache["ctx"] = ctx

        if api["predictor"]["type"] != "python":
            raise CortexException(api["name"], "predictor type is not python")

        cx_logger().info("loading the predictor from {}".format(
            api["predictor"]["path"]))
        predictor_class = ctx.get_predictor_class(api["name"],
                                                  args.project_dir)

        try:
            local_cache["predictor"] = predictor_class(
                api["predictor"]["config"])
        except Exception as e:
            raise UserRuntimeException(api["predictor"]["path"], "__init__",
                                       str(e)) from e
        finally:
            refresh_logger()
    except:
        cx_logger().exception("failed to start api")
        sys.exit(1)

    if api.get("tracker") is not None and api["tracker"].get(
            "model_type") == "classification":
        try:
            local_cache["class_set"] = api_utils.get_classes(ctx, api["name"])
        except Exception as e:
            cx_logger().warn(
                "an error occurred while attempting to load classes",
                exc_info=True)

    waitress_kwargs = {}
    if api["predictor"].get("config") is not None:
        for key, value in api["predictor"]["config"].items():
            if key.startswith("waitress_"):
                waitress_kwargs[key[len("waitress_"):]] = value

    if len(waitress_kwargs) > 0:
        cx_logger().info("waitress parameters: {}".format(waitress_kwargs))

    waitress_kwargs["listen"] = "*:{}".format(args.port)

    cx_logger().info("{} api is live".format(api["name"]))
    open("/health_check.txt", "a").close()
    serve(app, **waitress_kwargs)
Example #13
    def post_metrics(self, metrics):
        try:
            if self.statsd is None:
                raise CortexException("statsd client not initialized")  # unexpected

            for metric in metrics:
                tags = ["{}:{}".format(dim["Name"], dim["Value"]) for dim in metric["Dimensions"]]
                if metric.get("Unit") == "Count":
                    self.statsd.increment(metric["MetricName"], value=metric["Value"], tags=tags)
                else:
                    self.statsd.histogram(metric["MetricName"], value=metric["Value"], tags=tags)
        except:
            cx_logger().warn("failure encountered while publishing metrics", exc_info=True)
Example #14
def start(args):
    api = None
    try:
        ctx = Context(s3_path=args.context,
                      cache_dir=args.cache_dir,
                      workload_id=args.workload_id)
        api = ctx.apis_id_map[args.api]
        local_cache["api"] = api
        local_cache["ctx"] = ctx

        if api.get("predictor") is None:
            raise CortexException(api["name"], "predictor key not configured")

        cx_logger().info("loading the predictor from {}".format(
            api["predictor"]["path"]))
        local_cache["predictor"] = ctx.get_predictor_impl(
            api["name"], args.project_dir)

        if util.has_function(local_cache["predictor"], "init"):
            try:
                model_path = None
                if api["predictor"].get("model") is not None:
                    _, prefix = ctx.storage.deconstruct_s3_path(
                        api["predictor"]["model"])
                    model_path = os.path.join(
                        args.model_dir,
                        os.path.basename(os.path.normpath(prefix)))

                cx_logger().info("calling the predictor's init() function")
                local_cache["predictor"].init(model_path,
                                              api["predictor"]["metadata"])
            except Exception as e:
                raise UserRuntimeException(api["predictor"]["path"], "init",
                                           str(e)) from e
            finally:
                refresh_logger()
    except:
        cx_logger().exception("failed to start api")
        sys.exit(1)

    if api.get("tracker") is not None and api["tracker"].get(
            "model_type") == "classification":
        try:
            local_cache["class_set"] = api_utils.get_classes(ctx, api["name"])
        except Exception as e:
            cx_logger().warn(
                "an error occurred while attempting to load classes",
                exc_info=True)

    cx_logger().info("{} api is live".format(api["name"]))
    serve(app, listen="*:{}".format(args.port))
Example #15
def start(args):
    assert_api_version()
    storage = S3(bucket=os.environ["CORTEX_BUCKET"], region=os.environ["AWS_REGION"])
    try:
        raw_api_spec = get_spec(args.cache_dir, args.spec)
        api = API(storage=storage, cache_dir=args.cache_dir, **raw_api_spec)
        client = api.predictor.initialize_client(args)
        cx_logger().info("loading the predictor from {}".format(api.predictor.path))
        predictor_impl = api.predictor.initialize_impl(args.project_dir, client)

        local_cache["api"] = api
        local_cache["client"] = client
        local_cache["predictor_impl"] = predictor_impl
    except:
        cx_logger().exception("failed to start api")
        sys.exit(1)

    if api.tracker is not None and api.tracker.model_type == "classification":
        try:
            local_cache["class_set"] = api.get_cached_classes()
        except Exception as e:
            cx_logger().warn("an error occurred while attempting to load classes", exc_info=True)

    waitress_kwargs = extract_waitress_params(api.predictor.config)
    waitress_kwargs["listen"] = "*:{}".format(args.port)

    open("/health_check.txt", "a").close()
    cx_logger().info("{} api is live".format(api.name))
    serve(app, **waitress_kwargs)
Example #16
def start():
    cache_dir = os.environ["CORTEX_CACHE_DIR"]
    spec = os.environ["CORTEX_API_SPEC"]
    project_dir = os.environ["CORTEX_PROJECT_DIR"]
    model_dir = os.getenv("CORTEX_MODEL_DIR", None)
    tf_serving_port = os.getenv("CORTEX_TF_SERVING_PORT", None)
    storage = S3(bucket=os.environ["CORTEX_BUCKET"],
                 region=os.environ["AWS_REGION"])

    try:
        raw_api_spec = get_spec(storage, cache_dir, spec)
        api = API(storage=storage, cache_dir=cache_dir, **raw_api_spec)
        client = api.predictor.initialize_client(model_dir, tf_serving_port)
        cx_logger().info("loading the predictor from {}".format(
            api.predictor.path))
        predictor_impl = api.predictor.initialize_impl(project_dir, client)

        local_cache["api"] = api
        local_cache["client"] = client
        local_cache["predictor_impl"] = predictor_impl
    except:
        cx_logger().exception("failed to start api")
        sys.exit(1)

    if api.tracker is not None and api.tracker.model_type == "classification":
        try:
            local_cache["class_set"] = api.get_cached_classes()
        except Exception as e:
            cx_logger().warn(
                "an error occurred while attempting to load classes",
                exc_info=True)

    cx_logger().info("{} api is live".format(api.name))
    return app
Example #17
def handle_on_complete(message):
    job_spec = local_cache["job_spec"]
    predictor_impl = local_cache["predictor_impl"]
    sqs_client = local_cache["sqs_client"]
    queue_url = job_spec["sqs_url"]
    receipt_handle = message["ReceiptHandle"]

    try:
        if not getattr(predictor_impl, "on_job_complete", None):
            sqs_client.delete_message(QueueUrl=queue_url,
                                      ReceiptHandle=receipt_handle)
            return True

        should_run_on_job_complete = False

        while True:
            visible_count, not_visible_count = get_total_messages_in_queue()

            # if there are other messages that are visible, release this message and get the other ones (should rarely happen for FIFO)
            if visible_count > 0:
                sqs_client.change_message_visibility(
                    QueueUrl=queue_url,
                    ReceiptHandle=receipt_handle,
                    VisibilityTimeout=0)
                return False

            if should_run_on_job_complete:
                # double check that the queue is still empty (except for the job_complete message)
                if not_visible_count <= 1:
                    cx_logger().info("executing on_job_complete")
                    predictor_impl.on_job_complete()
                    sqs_client.delete_message(QueueUrl=queue_url,
                                              ReceiptHandle=receipt_handle)
                    return True
                else:
                    should_run_on_job_complete = False

            if not_visible_count <= 1:
                should_run_on_job_complete = True

            time.sleep(20)
    except:
        sqs_client.delete_message(QueueUrl=queue_url,
                                  ReceiptHandle=receipt_handle)
        raise
Example #18
def predict(request: Request):
    tasks = BackgroundTasks()
    api = local_cache["api"]
    predictor_impl = local_cache["predictor_impl"]
    kwargs = build_predict_kwargs(request)

    prediction = predictor_impl.predict(**kwargs)

    if isinstance(prediction, bytes):
        response = Response(content=prediction, media_type="application/octet-stream")
    elif isinstance(prediction, str):
        response = Response(content=prediction, media_type="text/plain")
    elif isinstance(prediction, Response):
        response = prediction
    else:
        try:
            json_string = json.dumps(prediction)
        except Exception as e:
            raise UserRuntimeException(
                str(e),
                "please return an object that is JSON serializable (including its nested fields), a bytes object, a string, or a starlette.response.Response object",
            ) from e
        response = Response(content=json_string, media_type="application/json")

    if local_cache["provider"] != "local" and api.monitoring is not None:
        try:
            predicted_value = api.monitoring.extract_predicted_value(prediction)
            api.post_monitoring_metrics(predicted_value)
            if (
                api.monitoring.model_type == "classification"
                and predicted_value not in local_cache["class_set"]
            ):
                tasks.add_task(api.upload_class, class_name=predicted_value)
                local_cache["class_set"].add(predicted_value)
        except:
            cx_logger().warn("unable to record prediction metric", exc_info=True)

    if util.has_method(predictor_impl, "post_predict"):
        kwargs = build_post_predict_kwargs(prediction, request)
        request_thread_pool.submit(predictor_impl.post_predict, **kwargs)

    if len(tasks.tasks) > 0:
        response.background = tasks

    return response
Example #19
    def initialize_client(self, args):
        if self.type == "onnx":
            from cortex.lib.client.onnx import ONNXClient

            _, prefix = self.storage.deconstruct_s3_path(self.model)
            model_path = os.path.join(args.model_dir, os.path.basename(prefix))
            client = ONNXClient(model_path)
            cx_logger().info("ONNX model signature: {}".format(client.input_signature))
            return client
        elif self.type == "tensorflow":
            from cortex.lib.client.tensorflow import TensorFlowClient

            validate_model_dir(args.model_dir)
            client = TensorFlowClient("localhost:" + str(args.tf_serve_port), self.signature_key)
            cx_logger().info("TensorFlow model signature: {}".format(client.input_signature))
            return client

        return None
Example #20
def after_request(response):
    response.headers["Access-Control-Allow-Origin"] = "*"
    response.headers["Access-Control-Allow-Headers"] = "*"

    if not (request.path == "/predict" and request.method == "POST"):
        return response

    api = local_cache["api"]
    ctx = local_cache["ctx"]

    cx_logger().info(response.status)

    prediction = None
    if "prediction" in g:
        prediction = g.prediction

    api_utils.post_request_metrics(ctx, api, response, prediction,
                                   g.start_time, local_cache["class_set"])

    return response
Example #21
def get_signature_def(stub):
    limit = 60
    for i in range(limit):
        try:
            request = create_get_model_metadata_request()
            resp = stub.GetModelMetadata(request, timeout=10.0)
            sigAny = resp.metadata["signature_def"]
            signature_def_map = get_model_metadata_pb2.SignatureDefMap()
            sigAny.Unpack(signature_def_map)
            sigmap = json_format.MessageToDict(signature_def_map)
            return sigmap["signatureDef"]
        except:
            if i > 6:
                cx_logger().warn(
                    "unable to read model metadata - model is still loading, retrying..."
                )

        time.sleep(5)

    raise CortexException("timeout: unable to read model metadata")
Example #22
def predict(request: Any = Body(..., media_type="application/json"), debug=False):
    api = local_cache["api"]
    predictor_impl = local_cache["predictor_impl"]

    debug_obj("payload", request, debug)
    prediction = predictor_impl.predict(request)

    if isinstance(prediction, bytes):
        response = Response(content=prediction, media_type="application/octet-stream")
    elif isinstance(prediction, str):
        response = Response(content=prediction, media_type="text/plain")
        debug_obj("prediction", prediction, debug)
    elif isinstance(prediction, Response):
        response = prediction
    else:
        try:
            json_string = json.dumps(prediction)
            debug_obj("prediction", prediction, debug)
        except Exception as e:
            raise UserRuntimeException(
                str(e),
                "please return an object that is JSON serializable (including its nested fields), a bytes object, a string, or a starlette.response.Response object",
            ) from e
        response = Response(content=json_string, media_type="application/json")

    if local_cache["provider"] != "local" and api.tracker is not None:
        try:
            predicted_value = api.tracker.extract_predicted_value(prediction)
            api.post_tracker_metrics(predicted_value)
            if (
                api.tracker.model_type == "classification"
                and predicted_value not in local_cache["class_set"]
            ):
                tasks = BackgroundTasks()
                tasks.add_task(api.upload_class, class_name=predicted_value)
                local_cache["class_set"].add(predicted_value)
                response.background = tasks
        except:
            cx_logger().warn("unable to record prediction metric", exc_info=True)

    return response
Example #23
def predict():
    debug = request.args.get("debug", "false").lower() == "true"

    try:
        sample = request.get_json()
    except Exception as e:
        return "Malformed JSON", status.HTTP_400_BAD_REQUEST

    ctx = local_cache["ctx"]
    api = local_cache["api"]

    response = {}

    try:
        result = run_predict(sample, debug)
    except Exception as e:
        cx_logger().exception("prediction failed")
        return prediction_failed(str(e))

    g.prediction = result

    return jsonify(result)
Example #24
def get_signature_def(stub, model):
    limit = 2
    for i in range(limit):
        try:
            request = create_get_model_metadata_request(model.name)
            resp = stub.GetModelMetadata(request, timeout=10.0)
            sigAny = resp.metadata["signature_def"]
            signature_def_map = get_model_metadata_pb2.SignatureDefMap()
            sigAny.Unpack(signature_def_map)
            sigmap = json_format.MessageToDict(signature_def_map)
            return sigmap["signatureDef"]
        except Exception as e:
            print(e)
            cx_logger().warn(
                "unable to read model metadata for model '{}' - retrying ...".
                format(model.name))

        time.sleep(5)

    raise CortexException(
        "timeout: unable to read model metadata for model '{}'".format(
            model.name))
Example #25
def start():
    cache_dir = os.environ["CORTEX_CACHE_DIR"]
    provider = os.environ["CORTEX_PROVIDER"]
    spec_path = os.environ["CORTEX_API_SPEC"]
    project_dir = os.environ["CORTEX_PROJECT_DIR"]
    model_dir = os.getenv("CORTEX_MODEL_DIR", None)
    tf_serving_port = os.getenv("CORTEX_TF_SERVING_PORT", "9000")
    tf_serving_host = os.getenv("CORTEX_TF_SERVING_HOST", "localhost")

    if provider == "local":
        storage = LocalStorage(os.getenv("CORTEX_CACHE_DIR"))
    else:
        storage = S3(bucket=os.environ["CORTEX_BUCKET"],
                     region=os.environ["AWS_REGION"])

    try:
        raw_api_spec = get_spec(provider, storage, cache_dir, spec_path)
        api = API(provider=provider,
                  storage=storage,
                  cache_dir=cache_dir,
                  **raw_api_spec)
        client = api.predictor.initialize_client(
            model_dir,
            tf_serving_host=tf_serving_host,
            tf_serving_port=tf_serving_port)
        cx_logger().info("loading the predictor from {}".format(
            api.predictor.path))
        predictor_impl = api.predictor.initialize_impl(project_dir, client)

        local_cache["api"] = api
        local_cache["provider"] = provider
        local_cache["client"] = client
        local_cache["predictor_impl"] = predictor_impl
        local_cache["predict_fn_args"] = inspect.getfullargspec(
            predictor_impl.predict).args
        predict_route = "/"
        if provider != "local":
            predict_route = "/predict"
        local_cache["predict_route"] = predict_route
    except:
        cx_logger().exception("failed to start api")
        sys.exit(1)

    if (provider != "local" and api.monitoring is not None
            and api.monitoring.model_type == "classification"):
        try:
            local_cache["class_set"] = api.get_cached_classes()
        except:
            cx_logger().warn(
                "an error occurred while attempting to load classes",
                exc_info=True)

    app.add_api_route(local_cache["predict_route"], predict, methods=["POST"])
    app.add_api_route(local_cache["predict_route"],
                      get_summary,
                      methods=["GET"])

    return app
Example #26
    def initialize_client(self,
                          model_dir=None,
                          tf_serving_host=None,
                          tf_serving_port=None):
        if self.type == "onnx":
            from cortex.lib.client.onnx import ONNXClient

            model_path = os.path.join(model_dir, os.path.basename(self.model))
            client = ONNXClient(model_path)
            cx_logger().info("ONNX model signature: {}".format(
                client.input_signature))
            return client
        elif self.type == "tensorflow":
            from cortex.lib.client.tensorflow import TensorFlowClient

            tf_serving_address = tf_serving_host + ":" + tf_serving_port
            validate_model_dir(model_dir)
            client = TensorFlowClient(tf_serving_address, self.signature_key)
            cx_logger().info("TensorFlow model signature: {}".format(
                client.input_signature))
            return client

        return None
Example #27
def get_signature_def(stub):
    limit = 60
    for i in range(limit):
        try:
            request = create_get_model_metadata_request()
            resp = stub.GetModelMetadata(request, timeout=10.0)
            sigAny = resp.metadata["signature_def"]
            signature_def_map = get_model_metadata_pb2.SignatureDefMap()
            sigAny.Unpack(signature_def_map)
            sigmap = json_format.MessageToDict(signature_def_map)
            return sigmap["signatureDef"]
        except Exception as e:
            if isinstance(e, grpc.RpcError) and e.code() == grpc.StatusCode.UNAVAILABLE:
                if i > 6:  # only start logging this after 30 seconds
                    cx_logger().warn(
                        "unable to read model metadata - model is still loading, retrying..."
                    )
            else:
                print(e)  # unexpected error
                cx_logger().warn("unable to read model metadata - retrying...")

        time.sleep(5)

    raise CortexException("timeout: unable to read model metadata")
Example #28
def extract_signature(signature_def, signature_key, model_name):
    cx_logger().info("signature defs found in model '{}': {}".format(model_name, signature_def))

    available_keys = list(signature_def.keys())
    if len(available_keys) == 0:
        raise UserException("unable to find signature defs in model '{}'".format(model_name))

    if signature_key is None:
        if len(available_keys) == 1:
            cx_logger().info(
                "signature_key was not configured by user, using signature key '{}' for model '{}' (found in the signature def map)".format(
                    available_keys[0], model_name,
                )
            )
            signature_key = available_keys[0]
        elif "predict" in signature_def:
            cx_logger().info(
                "signature_key was not configured by user, using signature key 'predict' for model '{}' (found in the signature def map)".format(
                    model_name
                )
            )
            signature_key = "predict"
        else:
            raise UserException(
                "signature_key was not configured by user, please specify one the following keys '{}' for model '{}' (found in the signature def map)".format(
                    ", ".join(available_keys), model_name
                )
            )
    else:
        if signature_def.get(signature_key) is None:
            possibilities_str = "key: '{}'".format(available_keys[0])
            if len(available_keys) > 1:
                possibilities_str = "keys: '{}'".format("', '".join(available_keys))

            raise UserException(
                "signature_key '{}' was not found in signature def map for model '{}', but found the following {}".format(
                    signature_key, model_name, possibilities_str
                )
            )

    signature_def_val = signature_def.get(signature_key)

    if signature_def_val.get("inputs") is None:
        raise UserException(
            "unable to find 'inputs' in signature def '{}' for model '{}'".format(
                signature_key, model_name
            )
        )

    parsed_signature = {}
    for input_name, input_metadata in signature_def_val["inputs"].items():
        parsed_signature[input_name] = {
            "shape": [int(dim["size"]) for dim in input_metadata["tensorShape"]["dim"]],
            "type": DTYPE_TO_TF_TYPE[input_metadata["dtype"]].name,
        }
    return signature_key, parsed_signature
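
The key-selection rule in extract_signature can be summarized in isolation. A simplified sketch, assuming signature_def is the dict returned by the TF Serving metadata call shown in the earlier examples (the helper name is illustrative, not from the source):

def pick_signature_key(signature_def, signature_key=None):
    # an explicit key wins if present; otherwise use the only key, then "predict"
    keys = list(signature_def.keys())
    if signature_key is not None:
        if signature_key not in signature_def:
            raise ValueError("signature_key '{}' not found; available keys: {}".format(
                signature_key, ", ".join(keys)))
        return signature_key
    if len(keys) == 1:
        return keys[0]
    if "predict" in signature_def:
        return "predict"
    raise ValueError("specify one of the following signature keys: {}".format(", ".join(keys)))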
Example #29
def start(args):
    api = None
    try:
        ctx = Context(s3_path=args.context,
                      cache_dir=args.cache_dir,
                      workload_id=args.workload_id)
        api = ctx.apis_id_map[args.api]
        local_cache["api"] = api
        local_cache["ctx"] = ctx

        if api.get("tensorflow") is None:
            raise CortexException(api["name"], "tensorflow key not configured")

        if api["tensorflow"].get("request_handler") is not None:
            cx_logger().info("loading the request handler from {}".format(
                api["tensorflow"]["request_handler"]))
            local_cache["request_handler"] = ctx.get_request_handler_impl(
                api["name"], args.project_dir)
        request_handler = local_cache.get("request_handler")

        if request_handler is not None and util.has_function(
                request_handler, "pre_inference"):
            cx_logger().info(
                "using pre_inference request handler defined in {}".format(
                    api["tensorflow"]["request_handler"]))
        else:
            cx_logger().info("pre_inference request handler not defined")

        if request_handler is not None and util.has_function(
                request_handler, "post_inference"):
            cx_logger().info(
                "using post_inference request handler defined in {}".format(
                    api["tensorflow"]["request_handler"]))
        else:
            cx_logger().info("post_inference request handler not defined")

    except Exception as e:
        cx_logger().exception("failed to start api")
        sys.exit(1)

    try:
        validate_model_dir(args.model_dir)
    except Exception as e:
        cx_logger().exception("failed to validate model")
        sys.exit(1)

    if api.get("tracker") is not None and api["tracker"].get(
            "model_type") == "classification":
        try:
            local_cache["class_set"] = api_utils.get_classes(ctx, api["name"])
        except Exception as e:
            cx_logger().warn(
                "an error occurred while attempting to load classes",
                exc_info=True)

    channel = grpc.insecure_channel("localhost:" + str(args.tf_serve_port))
    local_cache["stub"] = prediction_service_pb2_grpc.PredictionServiceStub(
        channel)

    # wait a bit for tf serving to start before querying metadata
    limit = 60
    for i in range(limit):
        try:
            local_cache["model_metadata"] = run_get_model_metadata()
            break
        except Exception as e:
            if i > 6:
                cx_logger().warn(
                    "unable to read model metadata - model is still loading. Retrying..."
                )
            if i == limit - 1:
                cx_logger().exception("retry limit exceeded")
                sys.exit(1)

        time.sleep(5)

    signature_key, parsed_signature = extract_signature(
        local_cache["model_metadata"]["signatureDef"],
        api["tensorflow"]["signature_key"])

    local_cache["signature_key"] = signature_key
    local_cache["parsed_signature"] = parsed_signature
    cx_logger().info("model_signature: {}".format(
        local_cache["parsed_signature"]))

    cx_logger().info("{} API is live".format(api["name"]))
    serve(app, listen="*:{}".format(args.port))
Example #30
def exceptions(e):
    cx_logger().exception(e)
    return jsonify(error=str(e)), 500
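
In the Flask-based examples above, a catch-all handler like this would typically be attached to the app; the registration itself is not shown in the source, so the lines below are an assumption:

# Hypothetical registration (not shown above):
# @app.errorhandler(Exception)
# def exceptions(e): ...
# or equivalently: app.register_error_handler(Exception, exceptions)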