Example #1
    def predict(self, model_input, model_name=None):
        """Validate input, convert it to a dictionary of input_name to numpy.ndarray, and make a prediction.

        Args:
            model_input: Input to the model.
            model_name: Model to use when multiple models are deployed in a single API.

        Returns:
            numpy.ndarray: The prediction returned from the model.
        """
        if consts.SINGLE_MODEL_NAME in self._model_names:
            return self._run_inference(model_input, consts.SINGLE_MODEL_NAME)

        if model_name is None:
            raise UserRuntimeException(
                "model_name was not specified, choose one of the following: {}".format(
                    self._model_names
                )
            )

        if model_name not in self._model_names:
            raise UserRuntimeException(
                "'{}' model wasn't found in the list of available models: {}".format(
                    model_name, self._model_names
                )
            )

        return self._run_inference(model_input, model_name)
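
A minimal usage sketch for the client method above. The `python_client` handle, the input payload, and the model names "iris" and "mnist" are assumptions for illustration, not part of the example: omitting model_name on a multi-model API raises UserRuntimeException, while passing one of the deployed names runs inference.

# hypothetical caller; assumes `python_client` exposes the predict() shown above
# and that the API deploys two models named "iris" and "mnist"
try:
    python_client.predict({"input": [5.1, 3.5, 1.4, 0.2]})  # model_name omitted
except UserRuntimeException as e:
    print(e)  # model_name was not specified, choose one of the following: ['iris', 'mnist']

prediction = python_client.predict({"input": [5.1, 3.5, 1.4, 0.2]}, model_name="iris")
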
Example #2
    def predict(self, model_input, model_name=None):
        """Validate model_input, convert it to a Prediction Proto, and make a request to TensorFlow Serving.

        Args:
            model_input: Input to the model.
            model_name: Model to use when multiple models are deployed in a single API.

        Returns:
            dict: TensorFlow Serving response converted to a dictionary.
        """
        if consts.SINGLE_MODEL_NAME in self._model_names:
            return self._run_inference(model_input, consts.SINGLE_MODEL_NAME)

        if model_name is None:
            raise UserRuntimeException(
                "model_name was not specified, choose one of the following: {}".format(
                    self._model_names
                )
            )

        if model_name not in self._model_names:
            raise UserRuntimeException(
                "'{}' model wasn't found in the list of available models: {}".format(
                    model_name, self._model_names
                )
            )

        return self._run_inference(model_input, model_name)
Example #3
File: api.py Project: databill86/cortex-1
def predict(app_name, api_name):
    debug = request.args.get("debug", "false").lower() == "true"

    try:
        sample = request.get_json()
    except Exception as e:
        return "malformed json", status.HTTP_400_BAD_REQUEST

    sess = local_cache["sess"]
    api = local_cache["api"]
    ctx = local_cache["ctx"]
    request_handler = local_cache.get("request_handler")
    input_metadata = local_cache["input_metadata"]
    output_metadata = local_cache["output_metadata"]

    try:
        debug_obj("sample", sample, debug)

        prepared_sample = sample
        if request_handler is not None and util.has_function(
                request_handler, "pre_inference"):
            try:
                prepared_sample = request_handler.pre_inference(
                    sample, input_metadata)
                debug_obj("pre_inference", prepared_sample, debug)
            except Exception as e:
                raise UserRuntimeException(api["request_handler"],
                                           "pre_inference request handler",
                                           str(e)) from e

        inference_input = convert_to_onnx_input(prepared_sample,
                                                input_metadata)
        model_outputs = sess.run([], inference_input)
        result = []
        for model_output in model_outputs:
            if type(model_output) is np.ndarray:
                result.append(model_output.tolist())
            else:
                result.append(model_output)

        debug_obj("inference", result, debug)

        if request_handler is not None and util.has_function(
                request_handler, "post_inference"):
            try:
                result = request_handler.post_inference(
                    result, output_metadata)
            except Exception as e:
                raise UserRuntimeException(api["request_handler"],
                                           "post_inference request handler",
                                           str(e)) from e

            debug_obj("post_inference", result, debug)
    except Exception as e:
        logger.exception("prediction failed")
        return prediction_failed(str(e))

    g.prediction = result
    return jsonify(result)
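
For context, a hedged sketch of how a client might call this Flask endpoint. The host, port, and route are placeholders (the route registration is not shown in the example); only the JSON body and the ?debug=true query flag correspond to what the handler reads.

import requests

# placeholder URL; only the JSON body and the debug flag map to the handler above
resp = requests.post(
    "http://localhost:8080/my-app/my-api",  # hypothetical host/route
    params={"debug": "true"},
    json={"feature_1": 1.2, "feature_2": "a"},
)
print(resp.status_code, resp.json())
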
Example #4
File: api.py Project: rogervaas/cortex
def predict():
    debug = request.args.get("debug", "false").lower() == "true"

    try:
        payload = request.get_json()
    except Exception as e:
        return "malformed json", status.HTTP_400_BAD_REQUEST

    sess = local_cache["sess"]
    api = local_cache["api"]
    ctx = local_cache["ctx"]
    request_handler = local_cache.get("request_handler")
    input_metadata = local_cache["input_metadata"]
    output_metadata = local_cache["output_metadata"]

    try:
        debug_obj("payload", payload, debug)

        prepared_payload = payload
        if request_handler is not None and util.has_function(
                request_handler, "pre_inference"):
            try:
                prepared_payload = request_handler.pre_inference(
                    payload, input_metadata, api["onnx"]["metadata"])
                debug_obj("pre_inference", prepared_payload, debug)
            except Exception as e:
                raise UserRuntimeException(api["onnx"]["request_handler"],
                                           "pre_inference request handler",
                                           str(e)) from e

        inference_input = convert_to_onnx_input(prepared_payload,
                                                input_metadata)
        model_output = sess.run([], inference_input)

        debug_obj("inference", model_output, debug)
        result = model_output
        if request_handler is not None and util.has_function(
                request_handler, "post_inference"):
            try:
                result = request_handler.post_inference(
                    model_output, output_metadata, api["onnx"]["metadata"])
            except Exception as e:
                raise UserRuntimeException(api["onnx"]["request_handler"],
                                           "post_inference request handler",
                                           str(e)) from e

            debug_obj("post_inference", result, debug)
    except Exception as e:
        cx_logger().exception("prediction failed")
        return prediction_failed(str(e))

    g.prediction = result
    return jsonify(result)
Example #5
    def _validate_model_args(self,
                             model_name: Optional[str] = None,
                             model_version: str = "latest") -> Tuple[str, str]:
        """
        Validate the model name and model version.

        Args:
            model_name: Name of the model.
            model_version: Model version to use. Can also be "latest" for picking the highest version.

        Returns:
            The validated (model_name, model_version) tuple, possibly adjusted during validation.

        Raises:
            UserRuntimeException if the validation fails.
        """

        if model_version != "latest" and not model_version.isnumeric():
            raise UserRuntimeException(
                "model_version must be either a parse-able numeric value or 'latest'"
            )

        # when predictor:model_path or predictor:models:paths is specified
        if not self._models_dir:

            # when predictor:model_path is provided
            if consts.SINGLE_MODEL_NAME in self._spec_model_names:
                return consts.SINGLE_MODEL_NAME, model_version

            # when predictor:models:paths is specified
            if model_name is None:
                raise UserRuntimeException(
                    f"model_name was not specified, choose one of the following: {self._spec_model_names}"
                )

            if model_name not in self._spec_model_names:
                raise UserRuntimeException(
                    f"'{model_name}' model wasn't found in the list of available models"
                )

        # when predictor:models:dir is specified
        if self._models_dir:
            if model_name is None:
                raise UserRuntimeException("model_name was not specified")
            if not self._caching_enabled:
                available_models = find_ondisk_models_with_lock(self._lock_dir)
                if model_name not in available_models:
                    raise UserRuntimeException(
                        f"'{model_name}' model wasn't found in the list of available models"
                    )

        return model_name, model_version
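
A few hypothetical outcomes of the validation above, assuming predictor:models:paths lists two models named "a" and "b" (the names are assumptions, not from the example):

# self._validate_model_args("a", "2")    -> ("a", "2")
# self._validate_model_args("a")         -> ("a", "latest")
# self._validate_model_args(None)        -> raises UserRuntimeException ("model_name was not specified, ...")
# self._validate_model_args("c")         -> raises UserRuntimeException ("'c' model wasn't found ...")
# self._validate_model_args("a", "v2")   -> raises UserRuntimeException (non-numeric model_version)
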
Example #6
    def predict(self,
                model_input: Any,
                model_name: Optional[str] = None,
                model_version: str = "latest") -> dict:
        """
        Validate model_input, convert it to a Prediction Proto, and make a request to TensorFlow Serving.

        Args:
            model_input: Input to the model.
            model_name (optional): Name of the model to retrieve (when multiple models are deployed in an API).
                When predictor.models.paths is specified, model_name should be the name of one of the models listed in the API config.
                When predictor.models.dir is specified, model_name should be the name of a top-level directory in the models dir.
            model_version (string, optional): Version of the model to retrieve. Can be omitted or set to "latest" to select the highest version.

        Returns:
            dict: TensorFlow Serving response converted to a dictionary.
        """

        if model_version != "latest" and not model_version.isnumeric():
            raise UserRuntimeException(
                "model_version must be either a parse-able numeric value or 'latest'"
            )

        # when predictor:model_path or predictor:models:paths is specified
        if not self._models_dir:

            # when predictor:model_path is provided
            if consts.SINGLE_MODEL_NAME in self._spec_model_names:
                return self._run_inference(model_input,
                                           consts.SINGLE_MODEL_NAME,
                                           model_version)

            # when predictor:models:paths is specified
            if model_name is None:
                raise UserRuntimeException(
                    f"model_name was not specified, choose one of the following: {self._spec_model_names}"
                )

            if model_name not in self._spec_model_names:
                raise UserRuntimeException(
                    f"'{model_name}' model wasn't found in the list of available models"
                )

        # when predictor:models:dir is specified
        if self._models_dir and model_name is None:
            raise UserRuntimeException("model_name was not specified")

        return self._run_inference(model_input, model_name, model_version)
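
A sketch of how a user-defined predictor might call the client method above. The predictor class, the model name "resnet50", and the payload are assumptions: "latest" resolves to the highest numeric version, and a specific version must be passed as a numeric string.

# hypothetical predictor; assumes `tensorflow_client` is the client shown above
class Predictor:
    def __init__(self, tensorflow_client, config):
        self.client = tensorflow_client

    def predict(self, payload):
        # pin a version with a numeric string (e.g. "2"), or use "latest" (the default)
        return self.client.predict(payload, model_name="resnet50", model_version="latest")
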
Example #7
File: api.py Project: rogervaas/cortex
def predict():
    debug = request.args.get("debug", "false").lower() == "true"

    try:
        payload = request.get_json()
    except:
        return "malformed json", status.HTTP_400_BAD_REQUEST

    api = local_cache["api"]
    predictor = local_cache["predictor"]

    try:
        try:
            debug_obj("payload", payload, debug)
            output = predictor.predict(payload, api["predictor"]["metadata"])
            debug_obj("prediction", output, debug)
        except Exception as e:
            raise UserRuntimeException(api["predictor"]["path"], "predict",
                                       str(e)) from e
    except Exception as e:
        cx_logger().exception("prediction failed")
        return prediction_failed(str(e))

    g.prediction = output
    return jsonify(output)
Example #8
    def initialize_impl(self, project_dir, client=None, api_spec=None, job_spec=None):
        class_impl = self.class_impl(project_dir)
        constructor_args = inspect.getfullargspec(class_impl.__init__).args

        args = {}

        config = deepcopy(api_spec["predictor"]["config"])
        if job_spec is not None and job_spec.get("config") is not None:
            util.merge_dicts_in_place_overwrite(config, job_spec["config"])

        if "config" in constructor_args:
            args["config"] = config
        if "job_spec" in constructor_args:
            args["job_spec"] = job_spec

        try:
            if self.type == "onnx":
                args["onnx_client"] = client
                return class_impl(**args)
            elif self.type == "tensorflow":
                args["tensorflow_client"] = client
                return class_impl(**args)
            else:
                return class_impl(**args)
        except Exception as e:
            raise UserRuntimeException(self.path, "__init__", str(e)) from e
        finally:
            refresh_logger()
Example #9
File: spark_util.py Project: kwecht/cortex
def run_custom_aggregator(aggregate, df, ctx, spark):
    aggregator = ctx.aggregators[aggregate["aggregator"]]
    aggregator_impl, _ = ctx.get_aggregator_impl(aggregate["name"])

    try:
        input = ctx.populate_values(aggregate["input"],
                                    aggregator["input"],
                                    preserve_column_refs=False)
    except CortexException as e:
        e.wrap("aggregate " + aggregate["name"], "input")
        raise

    try:
        result = aggregator_impl.aggregate_spark(df, input)
    except Exception as e:
        raise UserRuntimeException(
            "aggregate " + aggregate["name"],
            "aggregator " + aggregator["name"],
            "function aggregate_spark",
        ) from e

    if aggregator.get(
            "output_type") is not None and not util.validate_output_type(
                result, aggregator["output_type"]):
        raise UserException(
            "aggregate " + aggregate["name"],
            "aggregator " + aggregator["name"],
            "unsupported return type (expected type {}, got {})".format(
                util.data_type_str(aggregator["output_type"]),
                util.user_obj_str(result)),
        )

    result = util.cast_output_type(result, aggregator["output_type"])
    ctx.store_aggregate_result(result, aggregate)
    return result
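
For reference, a hedged sketch of the aggregate_spark() implementation that the call above expects. The mean aggregation and the shape of `input` (a dict whose value resolves to a column name) are assumptions, not taken from the example.

from pyspark.sql import functions as F

# hypothetical aggregator module; the signature matches aggregator_impl.aggregate_spark(df, input)
def aggregate_spark(data, input):
    # `input["col"]` is assumed to resolve to a column name after ctx.populate_values()
    return data.agg(F.mean(input["col"])).collect()[0][0]
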
Example #10
def predict(
        request: Any = Body(..., media_type="application/json"), debug=False):
    api = local_cache["api"]
    predictor_impl = local_cache["predictor_impl"]

    debug_obj("payload", request, debug)
    prediction = predictor_impl.predict(request)

    try:
        json_string = json.dumps(prediction)
    except Exception as e:
        raise UserRuntimeException(
            f"the return value of predict() or one of its nested values is not JSON serializable",
            str(e),
        ) from e

    debug_obj("prediction", json_string, debug)

    response = Response(content=json_string, media_type="application/json")

    if api.tracker is not None:
        try:
            predicted_value = api.tracker.extract_predicted_value(prediction)
            api.post_tracker_metrics(predicted_value)
            if (api.tracker.model_type == "classification"
                    and predicted_value not in local_cache["class_set"]):
                tasks = BackgroundTasks()
                tasks.add_task(api.upload_class, class_name=predicted_value)
                local_cache["class_set"].add(predicted_value)
                response.background = tasks
        except:
            cx_logger().warn("unable to record prediction metric",
                             exc_info=True)

    return response
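
The json.dumps() guard above is what turns a non-serializable return value into a UserRuntimeException; a small illustration with a numpy array (the value is an assumption):

import json
import numpy as np

prediction = np.array([0.1, 0.9])
# json.dumps(prediction)          -> TypeError: Object of type ndarray is not JSON serializable,
#                                    which the handler above re-raises as UserRuntimeException
json.dumps(prediction.tolist())   # -> "[0.1, 0.9]"
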
Example #11
    def get_model(self, model_name: Optional[str] = None, model_version: str = "latest") -> Any:
        """
        Retrieve a model for inference.

        Args:
            model_name (optional): Name of the model to retrieve (when multiple models are deployed in an API).
                When predictor.models.paths is specified, model_name should be the name of one of the models listed in the API config.
                When predictor.models.dir is specified, model_name should be the name of a top-level directory in the models dir.
            model_version (string, optional): Version of the model to retrieve. Can be omitted or set to "latest" to select the highest version.

        Returns:
            The model as loaded by the load_model() method.
        """

        if model_version != "latest" and not model_version.isnumeric():
            raise UserRuntimeException(
                "model_version must be either a parse-able numeric value or 'latest'"
            )

        # when predictor:model_path or predictor:models:paths is specified
        if not self._models_dir:

            # when predictor:model_path is provided
            if consts.SINGLE_MODEL_NAME in self._spec_model_names:
                model_name = consts.SINGLE_MODEL_NAME
                model = self._get_model(model_name, model_version)
                if model is None:
                    raise UserRuntimeException(
                        f"model {model_name} of version {model_version} wasn't found"
                    )
                return model

            # when predictor:models:paths is specified
            if model_name is None:
                raise UserRuntimeException(
                    f"model_name was not specified, choose one of the following: {self._spec_model_names}"
                )

            if model_name not in self._spec_model_names:
                raise UserRuntimeException(
                    f"'{model_name}' model wasn't found in the list of available models"
                )

        # when predictor:models:dir is specified
        if self._models_dir:
            if model_name is None:
                raise UserRuntimeException("model_name was not specified")
            if not self._caching_enabled:
                available_models = find_ondisk_models_with_lock(self._lock_dir)
                if model_name not in available_models:
                    raise UserRuntimeException(
                        f"'{model_name}' model wasn't found in the list of available models"
                    )

        model = self._get_model(model_name, model_version)
        if model is None:
            raise UserRuntimeException(
                f"model {model_name} of version {model_version} wasn't found"
            )
        return model
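
A usage sketch for get_model() above. The `python_client` handle and the model name are assumptions: omit model_version (or pass "latest") to pick the highest version, or pin it with a numeric string.

# hypothetical caller; assumes `python_client` exposes the get_model() shown above
model = python_client.get_model(model_name="text-classifier", model_version="2")
latest = python_client.get_model(model_name="text-classifier")  # same as model_version="latest"
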
Example #12
    def _run_inference(self, model_input: Any, model_name: str,
                       model_version: str) -> Any:
        """
        Run the inference on model model_name of version model_version.
        """

        model = self._get_model(model_name, model_version)
        if model is None:
            raise UserRuntimeException(
                f"model {model_name} of version {model_version} wasn't found")

        try:
            input_dict = convert_to_onnx_input(model_input,
                                               model["signatures"], model_name)
            return model["session"].run([], input_dict)
        except Exception as e:
            raise UserRuntimeException(
                f"failed inference with model {model_name} of version {model_version}",
                str(e))
Example #13
File: api.py Project: rogervaas/cortex
def run_predict(payload, debug=False):
    ctx = local_cache["ctx"]
    api = local_cache["api"]
    request_handler = local_cache.get("request_handler")

    prepared_payload = payload

    debug_obj("payload", payload, debug)
    if request_handler is not None and util.has_function(
            request_handler, "pre_inference"):
        try:
            prepared_payload = request_handler.pre_inference(
                payload,
                local_cache["model_metadata"]["signatureDef"],
                api["tensorflow"]["metadata"],
            )
            debug_obj("pre_inference", prepared_payload, debug)
        except Exception as e:
            raise UserRuntimeException(api["tensorflow"]["request_handler"],
                                       "pre_inference request handler",
                                       str(e)) from e

    validate_payload(prepared_payload)

    prediction_request = create_prediction_request(prepared_payload)
    response_proto = local_cache["stub"].Predict(prediction_request,
                                                 timeout=300.0)
    result = parse_response_proto(response_proto)
    debug_obj("inference", result, debug)

    if request_handler is not None and util.has_function(
            request_handler, "post_inference"):
        try:
            result = request_handler.post_inference(
                result, local_cache["model_metadata"]["signatureDef"],
                api["tensorflow"]["metadata"])
            debug_obj("post_inference", result, debug)
        except Exception as e:
            raise UserRuntimeException(api["tensorflow"]["request_handler"],
                                       "post_inference request handler",
                                       str(e)) from e

    return result
Example #14
    def initialize_impl(self, project_dir, client=None):
        class_impl = self.class_impl(project_dir)
        try:
            if self.type == "python":
                return class_impl(self.config)
            else:
                return class_impl(client, self.config)
        except Exception as e:
            raise UserRuntimeException(self.path, "__init__", str(e)) from e
        finally:
            refresh_logger()
Example #15
def start(args):
    api = None
    try:
        ctx = Context(s3_path=args.context, cache_dir=args.cache_dir, workload_id=args.workload_id)
        api = ctx.apis_id_map[args.api]
        local_cache["api"] = api
        local_cache["ctx"] = ctx

        if api["predictor"]["type"] != "onnx":
            raise CortexException(api["name"], "predictor type is not onnx")

        cx_logger().info("loading the predictor from {}".format(api["predictor"]["path"]))

        _, prefix = ctx.storage.deconstruct_s3_path(api["predictor"]["model"])
        model_path = os.path.join(args.model_dir, os.path.basename(prefix))
        local_cache["client"] = ONNXClient(model_path)

        predictor_class = ctx.get_predictor_class(api["name"], args.project_dir)

        try:
            local_cache["predictor"] = predictor_class(
                local_cache["client"], api["predictor"]["config"]
            )
        except Exception as e:
            raise UserRuntimeException(api["predictor"]["path"], "__init__", str(e)) from e
        finally:
            refresh_logger()
    except Exception as e:
        cx_logger().exception("failed to start api")
        sys.exit(1)

    if api.get("tracker") is not None and api["tracker"].get("model_type") == "classification":
        try:
            local_cache["class_set"] = api_utils.get_classes(ctx, api["name"])
        except Exception as e:
            cx_logger().warn("an error occurred while attempting to load classes", exc_info=True)

    cx_logger().info("ONNX model signature: {}".format(local_cache["client"].input_signature))

    waitress_kwargs = {}
    if api["predictor"].get("config") is not None:
        for key, value in api["predictor"]["config"].items():
            if key.startswith("waitress_"):
                waitress_kwargs[key[len("waitress_") :]] = value

    if len(waitress_kwargs) > 0:
        cx_logger().info("waitress parameters: {}".format(waitress_kwargs))

    waitress_kwargs["listen"] = "*:{}".format(args.port)

    cx_logger().info("{} api is live".format(api["name"]))
    open("/health_check.txt", "a").close()
    serve(app, **waitress_kwargs)
Example #16
    def _get_latest_model_version_from_disk(self, model_name: str) -> str:
        """
        Get the highest version of a specific model name.
        Must only be used when caching is disabled and processes_per_replica > 0.
        """
        versions, timestamps = find_ondisk_model_info(self._lock_dir, model_name)
        if len(versions) == 0:
            raise UserRuntimeException(
                "'{}' model's versions have been removed; add at least a version to the model to resume operations"
                .format(model_name))
        return str(max(map(lambda x: int(x), versions)))
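
The last line picks the highest version numerically rather than lexicographically; a small illustration of why the int() mapping matters:

# e.g. with versions == ["1", "10", "2"]:
assert max(["1", "10", "2"]) == "2"                  # plain string comparison picks the wrong version
assert str(max(map(int, ["1", "10", "2"]))) == "10"  # numeric comparison picks 10
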
Example #17
    def initialize_impl(self, project_dir, client=None):
        class_impl = self.class_impl(project_dir)
        try:
            if self.type == "onnx":
                return class_impl(onnx_client=client, config=self.config)
            elif self.type == "tensorflow":
                return class_impl(tensorflow_client=client, config=self.config)
            else:
                return class_impl(config=self.config)
        except Exception as e:
            raise UserRuntimeException(self.path, "__init__", str(e)) from e
        finally:
            refresh_logger()
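
The try/except above converts any error raised by the user's constructor into a UserRuntimeException carrying the predictor path and the "__init__" label. A hedged sketch of a failing constructor (the class and message are assumptions, and the exact rendering of the exception text may differ):

# hypothetical user predictor whose constructor fails
class Predictor:
    def __init__(self, config):
        if "model_path" not in config:
            raise ValueError("missing 'model_path' in config")

# initialize_impl() would re-raise this roughly as:
#   UserRuntimeException(<predictor path>, "__init__", "missing 'model_path' in config")
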
Example #18
File: api.py Project: rogervaas/cortex
def start(args):
    api = None
    try:
        ctx = Context(s3_path=args.context,
                      cache_dir=args.cache_dir,
                      workload_id=args.workload_id)
        api = ctx.apis_id_map[args.api]
        local_cache["api"] = api
        local_cache["ctx"] = ctx

        if api.get("predictor") is None:
            raise CortexException(api["name"], "predictor key not configured")

        cx_logger().info("loading the predictor from {}".format(
            api["predictor"]["path"]))
        local_cache["predictor"] = ctx.get_predictor_impl(
            api["name"], args.project_dir)

        if util.has_function(local_cache["predictor"], "init"):
            try:
                model_path = None
                if api["predictor"].get("model") is not None:
                    _, prefix = ctx.storage.deconstruct_s3_path(
                        api["predictor"]["model"])
                    model_path = os.path.join(
                        args.model_dir,
                        os.path.basename(os.path.normpath(prefix)))

                cx_logger().info("calling the predictor's init() function")
                local_cache["predictor"].init(model_path,
                                              api["predictor"]["metadata"])
            except Exception as e:
                raise UserRuntimeException(api["predictor"]["path"], "init",
                                           str(e)) from e
            finally:
                refresh_logger()
    except:
        cx_logger().exception("failed to start api")
        sys.exit(1)

    if api.get("tracker") is not None and api["tracker"].get(
            "model_type") == "classification":
        try:
            local_cache["class_set"] = api_utils.get_classes(ctx, api["name"])
        except Exception as e:
            cx_logger().warn(
                "an error occurred while attempting to load classes",
                exc_info=True)

    cx_logger().info("{} api is live".format(api["name"]))
    serve(app, listen="*:{}".format(args.port))
Example #19
File: spark_util.py Project: kwecht/cortex
def transform_column(column_name, df, ctx, spark):
    if not ctx.is_transformed_column(column_name):
        return df
    if column_name in df.columns:
        return df

    transformed_column = ctx.transformed_columns[column_name]
    trans_impl, _ = ctx.get_transformer_impl(column_name)

    if hasattr(trans_impl, "transform_spark"):
        try:
            df = execute_transform_spark(column_name, df, ctx, spark)
            return df.withColumn(
                column_name,
                F.col(column_name).cast(CORTEX_TYPE_TO_SPARK_TYPE[
                    ctx.get_inferred_column_type(column_name)]),
            )
        except CortexException as e:
            raise UserRuntimeException(
                "transformed column " + column_name,
                transformed_column["transformer"] + ".transform_spark",
            ) from e
    elif hasattr(trans_impl, "transform_python"):
        try:
            return execute_transform_python(column_name, df, ctx, spark)
        except Exception as e:
            raise UserRuntimeException(
                "transformed column " + column_name,
                transformed_column["transformer"] + ".transform_python",
            ) from e
    else:
        raise UserException(
            "transformed column " + column_name,
            "transformer " + transformed_column["transformer"],
            "transform_spark(), transform_python(), or both must be defined",
        )
Example #20
def predict(request: Request):
    tasks = BackgroundTasks()
    api = local_cache["api"]
    predictor_impl = local_cache["predictor_impl"]
    kwargs = build_predict_kwargs(request)

    prediction = predictor_impl.predict(**kwargs)

    if isinstance(prediction, bytes):
        response = Response(content=prediction, media_type="application/octet-stream")
    elif isinstance(prediction, str):
        response = Response(content=prediction, media_type="text/plain")
    elif isinstance(prediction, Response):
        response = prediction
    else:
        try:
            json_string = json.dumps(prediction)
        except Exception as e:
            raise UserRuntimeException(
                str(e),
                "please return an object that is JSON serializable (including its nested fields), a bytes object, a string, or a starlette.response.Response object",
            ) from e
        response = Response(content=json_string, media_type="application/json")

    if local_cache["provider"] != "local" and api.monitoring is not None:
        try:
            predicted_value = api.monitoring.extract_predicted_value(prediction)
            api.post_monitoring_metrics(predicted_value)
            if (
                api.monitoring.model_type == "classification"
                and predicted_value not in local_cache["class_set"]
            ):
                tasks.add_task(api.upload_class, class_name=predicted_value)
                local_cache["class_set"].add(predicted_value)
        except:
            cx_logger().warn("unable to record prediction metric", exc_info=True)

    if util.has_method(predictor_impl, "post_predict"):
        kwargs = build_post_predict_kwargs(prediction, request)
        request_thread_pool.submit(predictor_impl.post_predict, **kwargs)

    if len(tasks.tasks) > 0:
        response.background = tasks

    return response
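
The branching above maps the predictor's return type onto a response media type; a sketch of the three user-facing return shapes (the predictor class and payload fields are assumptions):

# hypothetical user predictor illustrating the return types handled above
class Predictor:
    def predict(self, payload):
        if payload.get("as_bytes"):
            return b"\x89PNG..."                      # bytes -> application/octet-stream
        if payload.get("as_text"):
            return "positive"                         # str   -> text/plain
        return {"label": "positive", "score": 0.97}   # JSON-serializable -> application/json
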
Example #21
File: api.py Project: jaytoday/cortex
def reverse_transform(value):
    ctx = local_cache["ctx"]
    model = local_cache["model"]
    target_col = local_cache["target_col"]

    trans_impl = local_cache["trans_impls"].get(target_col["name"])
    if not (trans_impl and hasattr(trans_impl, "reverse_transform_python")):
        return None

    input = ctx.populate_values(target_col["input"],
                                None,
                                preserve_column_refs=False)
    try:
        result = trans_impl.reverse_transform_python(value, input)
    except Exception as e:
        raise UserRuntimeException("transformer " + target_col["transformer"],
                                   "function reverse_transform_python") from e

    return result
Example #22
File: serve.py Project: zongzhenh/cortex
def predict(request: Any = Body(..., media_type="application/json"), debug=False):
    api = local_cache["api"]
    predictor_impl = local_cache["predictor_impl"]

    debug_obj("payload", request, debug)
    prediction = predictor_impl.predict(request)

    if isinstance(prediction, bytes):
        response = Response(content=prediction, media_type="application/octet-stream")
    elif isinstance(prediction, str):
        response = Response(content=prediction, media_type="text/plain")
        debug_obj("prediction", prediction, debug)
    elif isinstance(prediction, Response):
        response = prediction
    else:
        try:
            json_string = json.dumps(prediction)
            debug_obj("prediction", prediction, debug)
        except Exception as e:
            raise UserRuntimeException(
                str(e),
                "please return an object that is JSON serializable (including its nested fields), a bytes object, a string, or a starlette.response.Response object",
            ) from e
        response = Response(content=json_string, media_type="application/json")

    if local_cache["provider"] != "local" and api.tracker is not None:
        try:
            predicted_value = api.tracker.extract_predicted_value(prediction)
            api.post_tracker_metrics(predicted_value)
            if (
                api.tracker.model_type == "classification"
                and predicted_value not in local_cache["class_set"]
            ):
                tasks = BackgroundTasks()
                tasks.add_task(api.upload_class, class_name=predicted_value)
                local_cache["class_set"].add(predicted_value)
                response.background = tasks
        except:
            cx_logger().warn("unable to record prediction metric", exc_info=True)

    return response
Example #23
File: spark_util.py Project: kwecht/cortex
def execute_transform_spark(column_name, df, ctx, spark):
    trans_impl, trans_impl_path = ctx.get_transformer_impl(column_name)
    transformed_column = ctx.transformed_columns[column_name]
    transformer = ctx.transformers[transformed_column["transformer"]]

    if trans_impl_path not in ctx.spark_uploaded_impls:
        spark.sparkContext.addPyFile(
            trans_impl_path)  # Executor pods need this because of the UDF
        ctx.spark_uploaded_impls[trans_impl_path] = True

    try:
        input = ctx.populate_values(transformed_column["input"],
                                    transformer["input"],
                                    preserve_column_refs=False)
    except CortexException as e:
        e.wrap("input")
        raise

    try:
        return trans_impl.transform_spark(df, input, column_name)
    except Exception as e:
        raise UserRuntimeException("function transform_spark") from e
Example #24
    def get_model(self,
                  model_name: Optional[str] = None,
                  model_version: str = "latest") -> dict:
        """
        Validate input and then return the model loaded into a dictionary.
        The counting of tag calls is recorded with this method (just like with the predict method).

        Args:
            model_name: Model to use when multiple models are deployed in a single API.
            model_version: Model version to use. Can also be "latest" for picking the highest version.

        Returns:
            The model as returned by the _load_model method.

        Raises:
            UserRuntimeException if the validation fails.
        """
        model_name, model_version = self._validate_model_args(
            model_name, model_version)
        model = self._get_model(model_name, model_version)
        if model is None:
            raise UserRuntimeException(
                f"model {model_name} of version {model_version} wasn't found")
        return model
Example #25
File: spark_util.py Project: kwecht/cortex
def validate_transformer(column_name, test_df, ctx, spark):
    transformed_column = ctx.transformed_columns[column_name]
    transformer = ctx.transformers[transformed_column["transformer"]]
    trans_impl, _ = ctx.get_transformer_impl(column_name)

    inferred_python_type = None
    inferred_spark_type = None

    if hasattr(trans_impl, "transform_python"):
        try:
            if transformer["output_type"] == consts.COLUMN_TYPE_INFERRED:
                sample_df = test_df.collect()
                sample = sample_df[0]
                try:
                    input = ctx.populate_values(transformed_column["input"],
                                                transformer["input"],
                                                preserve_column_refs=True)
                except CortexException as e:
                    e.wrap("input")
                    raise
                transformer_input = create_transformer_inputs_from_map(
                    input, sample)
                initial_transformed_value = trans_impl.transform_python(
                    transformer_input)
                inferred_python_type = infer_python_type(
                    initial_transformed_value)

                for row in sample_df:
                    transformer_input = create_transformer_inputs_from_map(
                        input, row)
                    transformed_value = trans_impl.transform_python(
                        transformer_input)
                    if inferred_python_type != infer_python_type(
                            transformed_value):
                        raise UserException(
                            "transformed column " + column_name,
                            "type inference failed, mixed data types in dataframe.",
                            'expected type of "' + str(transformed_value) +
                            '" to be ' + inferred_python_type,
                        )

                ctx.write_metadata(transformed_column["id"],
                                   {"type": inferred_python_type})

            transform_python_collect = execute_transform_python(
                column_name, test_df, ctx, spark, validate=True).collect()
        except Exception as e:
            raise UserRuntimeException(
                "transformed column " + column_name,
                transformed_column["transformer"] + ".transform_python",
            ) from e

    if hasattr(trans_impl, "transform_spark"):
        try:
            transform_spark_df = execute_transform_spark(
                column_name, test_df, ctx, spark)

            # check that the return object is a dataframe
            if type(transform_spark_df) is not DataFrame:
                raise UserException(
                    "expected pyspark.sql.dataframe.DataFrame but got type {}".
                    format(type(transform_spark_df)))

            # check that a column is added with the expected name
            if column_name not in transform_spark_df.columns:
                logger.error("schema of output dataframe:")
                log_df_schema(transform_spark_df, logger.error)

                raise UserException(
                    "output dataframe after running transformer does not have column {}"
                    .format(column_name))

            if transformer["output_type"] == consts.COLUMN_TYPE_INFERRED:
                inferred_spark_type = SPARK_TYPE_TO_CORTEX_TYPE[
                    transform_spark_df.select(column_name).schema[0].dataType]
                ctx.write_metadata(transformed_column["id"],
                                   {"type": inferred_spark_type})

            # check that transformer run on data
            try:
                transform_spark_df.select(column_name).collect()
            except Exception as e:
                raise UserRuntimeException("function transform_spark") from e

            # check that expected output column has the correct data type
            if transformer["output_type"] != consts.COLUMN_TYPE_INFERRED:
                actual_structfield = transform_spark_df.select(
                    column_name).schema.fields[0]
                if (actual_structfield.dataType
                        not in CORTEX_TYPE_TO_ACCEPTABLE_SPARK_TYPES[
                            transformer["output_type"]]):
                    raise UserException(
                        "incorrect column type: expected {}, got {}.".format(
                            " or ".join(
                                str(t)
                                for t in CORTEX_TYPE_TO_ACCEPTABLE_SPARK_TYPES[
                                    transformer["output_type"]]),
                            actual_structfield.dataType,
                        ))

            # perform the necessary casting for the column
            transform_spark_df = transform_spark_df.withColumn(
                column_name,
                F.col(column_name).cast(CORTEX_TYPE_TO_SPARK_TYPE[
                    ctx.get_inferred_column_type(column_name)]),
            )

            # check that the function doesn't modify the schema of the other columns in the input dataframe
            if set(transform_spark_df.columns) - set([column_name]) != set(
                    test_df.columns):
                logger.error("expected schema:")

                log_df_schema(test_df, logger.error)

                logger.error(
                    "found schema (with {} dropped):".format(column_name))
                log_df_schema(transform_spark_df.drop(column_name),
                              logger.error)

                raise UserException(
                    "a column besides {} was modifed in the output dataframe".
                    format(column_name))
        except CortexException as e:
            raise UserRuntimeException(
                "transformed column " + column_name,
                transformed_column["transformer"] + ".transform_spark",
            ) from e

    if hasattr(trans_impl, "transform_spark") and hasattr(
            trans_impl, "transform_python"):
        if (transformer["output_type"] == consts.COLUMN_TYPE_INFERRED
                and inferred_spark_type != inferred_python_type):
            raise UserException(
                "transformed column " + column_name,
                "type inference failed, transform_spark and transform_python had differing types.",
                "transform_python: " + inferred_python_type,
                "transform_spark: " + inferred_spark_type,
            )

        name_type_map = [(s.name, s.dataType)
                         for s in transform_spark_df.schema]
        transform_spark_collect = transform_spark_df.collect()

        for tp_row, ts_row in zip(transform_python_collect,
                                  transform_spark_collect):
            tp_dict = tp_row.asDict()
            ts_dict = ts_row.asDict()

            for name, dataType in name_type_map:
                if tp_dict[name] == ts_dict[name]:
                    continue
                elif dataType == FloatType() and util.isclose(
                        tp_dict[name], ts_dict[name], FLOAT_PRECISION):
                    continue
                raise UserException(
                    column_name,
                    "{0}.transform_spark and {0}.transform_python had differing values"
                    .format(transformed_column["transformer"]),
                    "{} != {}".format(ts_row, tp_row),
                )
Example #26
    def _get_model(self, model_name: str, model_version: str) -> Any:
        """
        Check if the versioned model is on disk, then check if it is in memory;
        if not, load it into memory and return the model.

        Args:
            model_name: Name of the model, as specified in predictor:models:paths, or, when predictor:models:dir is used, as named on disk.
            model_version: Version of the model, as found on disk. The "latest" tag resolves to the highest version number.

        Exceptions:
            RuntimeError: if another thread tried to load the model at the very same time.

        Returns:
            The model as returned by self._load_model().
            None if the model wasn't found or if it didn't pass the validation.
        """

        model = None
        tag = ""
        if model_version == "latest":
            tag = model_version

        if not self._caching_enabled:
            # determine model version
            if tag == "latest":
                model_version = self._get_latest_model_version_from_disk(
                    model_name)
            model_id = model_name + "-" + model_version

            # grab shared access to versioned model
            resource = os.path.join(self._lock_dir, model_id + ".txt")
            with LockedFile(resource, "r", reader_lock=True) as f:

                # check model status
                file_status = f.read()
                if file_status == "" or file_status == "not-available":
                    raise WithBreak

                current_upstream_ts = int(file_status.split(" ")[1])
                update_model = False

                # grab shared access to models holder and retrieve model
                with LockedModel(self._models, "r", model_name, model_version):
                    status, local_ts = self._models.has_model(
                        model_name, model_version)
                    if status == "not-available" or (
                            status == "in-memory"
                            and local_ts != current_upstream_ts):
                        update_model = True
                        raise WithBreak
                    model, _ = self._models.get_model(model_name,
                                                      model_version, tag)

                # load model into memory and retrieve it
                if update_model:
                    with LockedModel(self._models, "w", model_name,
                                     model_version):
                        status, _ = self._models.has_model(
                            model_name, model_version)
                        if status == "not-available" or (
                                status == "in-memory"
                                and local_ts != current_upstream_ts):
                            if status == "not-available":
                                logger().info(
                                    f"loading model {model_name} of version {model_version} (thread {td.get_ident()})"
                                )
                            else:
                                logger().info(
                                    f"reloading model {model_name} of version {model_version} (thread {td.get_ident()})"
                                )
                            try:
                                self._models.load_model(
                                    model_name,
                                    model_version,
                                    current_upstream_ts,
                                    [tag],
                                )
                            except Exception as e:
                                raise UserRuntimeException(
                                    f"failed (re-)loading model {model_name} of version {model_version} (thread {td.get_ident()})",
                                    str(e),
                                )
                        model, _ = self._models.get_model(
                            model_name, model_version, tag)

        if not self._multiple_processes and self._caching_enabled:
            # determine model version
            try:
                if tag == "latest":
                    model_version = self._get_latest_model_version_from_tree(
                        model_name, self._models_tree.model_info(model_name))
            except ValueError:
                # if model_name hasn't been found
                raise UserRuntimeException(
                    f"'{model_name}' model of tag {tag} wasn't found in the list of available models"
                )

            # grab shared access to model tree
            available_model = True
            with LockedModelsTree(self._models_tree, "r", model_name,
                                  model_version):

                # check if the versioned model exists
                model_id = model_name + "-" + model_version
                if model_id not in self._models_tree:
                    available_model = False
                    raise WithBreak

                # retrieve model tree's metadata
                upstream_model = self._models_tree[model_id]
                current_upstream_ts = int(
                    upstream_model["timestamp"].timestamp())

            if not available_model:
                return None

            # grab shared access to models holder and retrieve model
            update_model = False
            with LockedModel(self._models, "r", model_name, model_version):
                status, local_ts = self._models.has_model(
                    model_name, model_version)
                if status in [
                        "not-available", "on-disk"
                ] or (status != "not-available"
                      and local_ts != current_upstream_ts
                      and not (status == "in-memory" and model_name
                               in self._spec_local_model_names)):
                    update_model = True
                    raise WithBreak
                model, _ = self._models.get_model(model_name, model_version,
                                                  tag)

            # download, load into memory the model and retrieve it
            if update_model:
                # grab exclusive access to models holder
                with LockedModel(self._models, "w", model_name, model_version):

                    # check model status
                    status, local_ts = self._models.has_model(
                        model_name, model_version)

                    # refresh disk model
                    if model_name not in self._spec_local_model_names and (
                            status == "not-available" or
                        (status in ["on-disk", "in-memory"]
                         and local_ts != current_upstream_ts)):
                        if status == "not-available":
                            logger().info(
                                f"model {model_name} of version {model_version} not found locally; continuing with the download..."
                            )
                        elif status == "on-disk":
                            logger().info(
                                f"found newer model {model_name} of vesion {model_version} on the {upstream_model['provider']} upstream than the one on the disk"
                            )
                        else:
                            logger().info(
                                f"found newer model {model_name} of vesion {model_version} on the {upstream_model['provider']} upstream than the one loaded into memory"
                            )

                        # remove model from disk and memory
                        if status == "on-disk":
                            logger().info(
                                f"removing model from disk for model {model_name} of version {model_version}"
                            )
                            self._models.remove_model(model_name,
                                                      model_version)
                        if status == "in-memory":
                            logger().info(
                                f"removing model from disk and memory for model {model_name} of version {model_version}"
                            )
                            self._models.remove_model(model_name,
                                                      model_version)

                        # download model
                        logger().info(
                            f"downloading model {model_name} of version {model_version} from the {upstream_model['provider']} upstream"
                        )
                        date = self._models.download_model(
                            upstream_model["provider"],
                            upstream_model["bucket"],
                            model_name,
                            model_version,
                            upstream_model["path"],
                        )
                        if not date:
                            raise WithBreak
                        current_upstream_ts = int(date.timestamp())

                    # give the local model a timestamp initialized at start time
                    if model_name in self._spec_local_model_names:
                        current_upstream_ts = self._local_model_ts

                    # load model
                    try:
                        logger().info(
                            f"loading model {model_name} of version {model_version} into memory"
                        )
                        self._models.load_model(
                            model_name,
                            model_version,
                            current_upstream_ts,
                            [tag],
                        )
                    except Exception as e:
                        raise UserRuntimeException(
                            f"failed (re-)loading model {model_name} of version {model_version} (thread {td.get_ident()})",
                            str(e),
                        )

                    # retrieve model
                    model, _ = self._models.get_model(model_name,
                                                      model_version, tag)

        return model
Example #27
    def _run_inference(self, model_input: Any, model_name: str,
                       model_version: str) -> dict:
        """
        When processes_per_replica = 1 and caching enabled, check/load model and make prediction.
        When processes_per_replica > 0 and caching disabled, attempt to make prediction regardless.

        Args:
            model_input: Input to the model.
            model_name: Name of the model, as specified in predictor:models:paths, or, when predictor:models:dir is used, as named on disk.
            model_version: Version of the model, as found on disk. The "latest" tag resolves to the highest version number.

        Returns:
            The prediction.
        """

        model = None
        tag = ""
        if model_version == "latest":
            tag = model_version

        if not self._caching_enabled:

            # determine model version
            if tag == "latest":
                versions = self._client.poll_available_model_versions(
                    model_name)
                if len(versions) == 0:
                    raise UserException(
                        f"model '{model_name}' accessed with tag {tag} couldn't be found"
                    )
                model_version = str(max(map(lambda x: int(x), versions)))
            model_id = model_name + "-" + model_version

            return self._client.predict(model_input, model_name, model_version)

        if not self._multiple_processes and self._caching_enabled:

            # determine model version
            try:
                if tag == "latest":
                    model_version = self._get_latest_model_version_from_tree(
                        model_name, self._models_tree.model_info(model_name))
            except ValueError:
                # if model_name hasn't been found
                raise UserRuntimeException(
                    f"'{model_name}' model of tag {tag} wasn't found in the list of available models"
                )

            models_stats = []
            for model_id in self._models.get_model_ids():
                models_stats.append(self._models.has_model_id(model_id))

            # grab shared access to model tree
            available_model = True
            logger().info(
                f"grabbing access to model {model_name} of version {model_version}"
            )
            with LockedModelsTree(self._models_tree, "r", model_name,
                                  model_version):

                # check if the versioned model exists
                model_id = model_name + "-" + model_version
                if model_id not in self._models_tree:
                    available_model = False
                    logger().info(
                        f"model {model_name} of version {model_version} is not available"
                    )
                    raise WithBreak

                # retrieve model tree's metadata
                upstream_model = self._models_tree[model_id]
                current_upstream_ts = int(
                    upstream_model["timestamp"].timestamp())
                logger().info(
                    f"model {model_name} of version {model_version} is available"
                )

            if not available_model:
                if tag == "":
                    raise UserException(
                        f"model '{model_name}' of version '{model_version}' couldn't be found"
                    )
                raise UserException(
                    f"model '{model_name}' accessed with tag '{tag}' couldn't be found"
                )

            # grab shared access to models holder and retrieve model
            update_model = False
            prediction = None
            tfs_was_unresponsive = False
            with LockedModel(self._models, "r", model_name, model_version):
                logger().info(
                    f"checking the {model_name} {model_version} status")
                status, local_ts = self._models.has_model(
                    model_name, model_version)
                if status in ["not-available", "on-disk"
                              ] or (status != "not-available"
                                    and local_ts != current_upstream_ts):
                    logger().info(
                        f"model {model_name} of version {model_version} is not loaded (with status {status} or different timestamp)"
                    )
                    update_model = True
                    raise WithBreak

                # run prediction
                logger().info(
                    f"run the prediction on model {model_name} of version {model_version}"
                )
                self._models.get_model(model_name, model_version, tag)
                try:
                    prediction = self._client.predict(model_input, model_name,
                                                      model_version)
                except grpc.RpcError as e:
                    # if TFS still reports versions for this model, the error
                    # wasn't caused by a restart, so propagate it
                    if len(self._client.poll_available_model_versions(model_name)) > 0:
                        raise
                    tfs_was_unresponsive = True

            # remove model from disk and memory references if TFS gets unresponsive
            if tfs_was_unresponsive:
                with LockedModel(self._models, "w", model_name, model_version):
                    available_versions = self._client.poll_available_model_versions(
                        model_name)
                    status, _ = self._models.has_model(model_name,
                                                       model_version)
                    if not (status == "in-memory"
                            and model_version not in available_versions):
                        raise WithBreak

                    logger().info(
                        f"removing model {model_name} of version {model_version} because TFS got unresponsive"
                    )
                    self._models.remove_model(model_name, model_version)

            # download, load into memory the model and retrieve it
            if update_model:
                # grab exclusive access to models holder
                with LockedModel(self._models, "w", model_name, model_version):

                    # check model status
                    status, local_ts = self._models.has_model(
                        model_name, model_version)

                    # refresh disk model
                    if status == "not-available" or (
                            status in ["on-disk", "in-memory"]
                            and local_ts != current_upstream_ts):
                        # unload model from TFS
                        if status == "in-memory":
                            try:
                                logger().info(
                                    f"unloading model {model_name} of version {model_version} from TFS"
                                )
                                self._models.unload_model(
                                    model_name, model_version)
                            except Exception:
                                logger().info(
                                    f"failed unloading model {model_name} of version {model_version} from TFS"
                                )
                                raise

                        # remove model from disk and references
                        if status in ["on-disk", "in-memory"]:
                            logger().info(
                                f"removing model references from memory and from disk for model {model_name} of version {model_version}"
                            )
                            self._models.remove_model(model_name,
                                                      model_version)

                        # download model
                        logger().info(
                            f"downloading model {model_name} of version {model_version} from the {upstream_model['provider']} upstream"
                        )
                        date = self._models.download_model(
                            upstream_model["provider"],
                            upstream_model["bucket"],
                            model_name,
                            model_version,
                            upstream_model["path"],
                        )
                        if not date:
                            raise WithBreak
                        current_upstream_ts = int(date.timestamp())

                    # load model
                    try:
                        logger().info(
                            f"loading model {model_name} of version {model_version} into memory"
                        )
                        self._models.load_model(
                            model_name,
                            model_version,
                            current_upstream_ts,
                            [tag],
                            kwargs={
                                "model_name": model_name,
                                "model_version": model_version,
                                "signature_key": self._determine_model_signature_key(model_name),
                            },
                        )
                    except Exception as e:
                        raise UserRuntimeException(
                            f"failed (re-)loading model {model_name} of version {model_version} (thread {td.get_ident()})",
                            str(e),
                        )

                    # run prediction
                    self._models.get_model(model_name, model_version, tag)
                    prediction = self._client.predict(model_input, model_name,
                                                      model_version)

            return prediction
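A minimal usage sketch for the cached predict path above, as it would typically be reached from a user predictor class that receives the client; the class name, the "iris-classifier" model name, and the field layout are assumptions made for illustration.

class TensorFlowPredictor:
    # hypothetical user-side predictor; only the client's
    # predict(model_input, model_name, model_version) call mirrors the code above
    def __init__(self, tensorflow_client, config):
        self.client = tensorflow_client

    def predict(self, payload):
        # "latest" resolves to the highest version available for the model
        return self.client.predict(payload, "iris-classifier", "latest")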
Example #28
def train(model_name, estimator_impl, ctx, model_dir):
    model = ctx.models[model_name]

    # ensure the parent directories exist and clear any previous model_dir
    util.mkdir_p(model_dir)
    util.rm_dir(model_dir)

    tf_lib.set_logging_verbosity(ctx.environment["log_level"]["tensorflow"])

    run_config = tf.estimator.RunConfig(
        tf_random_seed=model["training"]["tf_random_seed"],
        save_summary_steps=model["training"]["save_summary_steps"],
        save_checkpoints_secs=model["training"]["save_checkpoints_secs"],
        save_checkpoints_steps=model["training"]["save_checkpoints_steps"],
        log_step_count_steps=model["training"]["log_step_count_steps"],
        keep_checkpoint_max=model["training"]["keep_checkpoint_max"],
        keep_checkpoint_every_n_hours=model["training"]["keep_checkpoint_every_n_hours"],
        model_dir=model_dir,
    )

    train_input_fn = generate_input_fn(model_name, ctx, "training",
                                       estimator_impl)
    eval_input_fn = generate_input_fn(model_name, ctx, "evaluation",
                                      estimator_impl)
    serving_input_fn = generate_json_serving_input_fn(model_name, ctx,
                                                      estimator_impl)
    exporter = tf.estimator.FinalExporter("estimator",
                                          serving_input_fn,
                                          as_text=False)

    train_num_steps = model["training"]["num_steps"]
    dataset_metadata = ctx.get_metadata(model["dataset"]["id"])
    if model["training"]["num_epochs"]:
        train_num_steps = (math.ceil(dataset_metadata["training_size"] /
                                     float(model["training"]["batch_size"])) *
                           model["training"]["num_epochs"])

    train_spec = tf.estimator.TrainSpec(train_input_fn,
                                        max_steps=train_num_steps)

    eval_num_steps = model["evaluation"]["num_steps"]
    if model["evaluation"]["num_epochs"]:
        eval_num_steps = (math.ceil(dataset_metadata["eval_size"] /
                                    float(model["evaluation"]["batch_size"])) *
                          model["evaluation"]["num_epochs"])

    eval_spec = tf.estimator.EvalSpec(
        eval_input_fn,
        steps=eval_num_steps,
        exporters=[exporter],
        name="estimator-eval",
        start_delay_secs=model["evaluation"]["start_delay_secs"],
        throttle_secs=model["evaluation"]["throttle_secs"],
    )

    model_config = ctx.model_config(model_name)

    try:
        tf_estimator = estimator_impl.create_estimator(run_config,
                                                       model_config)
    except Exception as e:
        raise UserRuntimeException("model " + model_name) from e

    target_col_name = util.get_resource_ref(model["target_column"])
    if ctx.get_inferred_column_type(target_col_name) == consts.COLUMN_TYPE_FLOAT:
        tf_estimator = tf.contrib.estimator.add_metrics(
            tf_estimator, get_regression_eval_metrics)

    tf.estimator.train_and_evaluate(tf_estimator, train_spec, eval_spec)

    return model_dir
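To make the epochs-to-steps conversion in train() concrete, here is a small standalone sketch of the same arithmetic; the dataset size, batch size, and epoch count are made-up numbers.

import math

# how train() derives max_steps when num_epochs is set
training_size = 10_000   # rows in the training dataset (example value)
batch_size = 64
num_epochs = 3

steps_per_epoch = math.ceil(training_size / float(batch_size))   # 157
train_num_steps = steps_per_epoch * num_epochs                    # 471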
Example #29
    def initialize_impl(
        self,
        project_dir: str,
        client: Union[PythonClient, TensorFlowClient, ONNXClient],
        job_spec: Optional[dict] = None,
    ):
        """
        Initialize predictor class as provided by the user.

        job_spec is a dictionary when the "kind" of the API is set to "BatchAPI". Otherwise, it's None.
        """

        # build args
        class_impl = self.class_impl(project_dir)
        constructor_args = inspect.getfullargspec(class_impl.__init__).args
        config = deepcopy(self.config)
        args = {}
        if job_spec is not None and job_spec.get("config") is not None:
            util.merge_dicts_in_place_overwrite(config, job_spec["config"])
        if "config" in constructor_args:
            args["config"] = config
        if "job_spec" in constructor_args:
            args["job_spec"] = job_spec

        # initialize predictor class
        try:
            if self.type == PythonPredictorType:
                if _are_models_specified(self.api_spec):
                    args["python_client"] = client
                    initialized_impl = class_impl(**args)
                    client.set_load_method(initialized_impl.load_model)
                else:
                    initialized_impl = class_impl(**args)
            if self.type in [TensorFlowPredictorType, TensorFlowNeuronPredictorType]:
                args["tensorflow_client"] = client
                initialized_impl = class_impl(**args)
            if self.type == ONNXPredictorType:
                args["onnx_client"] = client
                initialized_impl = class_impl(**args)
        except Exception as e:
            raise UserRuntimeException(self.path, "__init__", str(e)) from e
        finally:
            refresh_logger()

        # initialize the crons if models have been specified and if the API kind is RealtimeAPI
        if _are_models_specified(self.api_spec) and self.api_spec["kind"] == "RealtimeAPI":
            if not self.multiple_processes and self.caching_enabled:
                self.crons += [
                    ModelTreeUpdater(
                        interval=10,
                        api_spec=self.api_spec,
                        tree=self.models_tree,
                        ondisk_models_dir=self.model_dir,
                    ),
                    ModelsGC(
                        interval=10,
                        api_spec=self.api_spec,
                        models=self.models,
                        tree=self.models_tree,
                    ),
                ]

            if not self.caching_enabled and self.type in [PythonPredictorType, ONNXPredictorType]:
                self.crons += [
                    FileBasedModelsGC(interval=10, models=self.models, download_dir=self.model_dir)
                ]

        for cron in self.crons:
            cron.start()

        return initialized_impl
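initialize_impl only injects the constructor arguments that the user class actually declares. A minimal sketch of a compatible class follows; everything apart from the injected names (config, job_spec, python_client) and the load_model hook is an assumption for illustration.

class PythonPredictor:
    # hypothetical user-provided class compatible with initialize_impl() above
    def __init__(self, config, python_client, job_spec=None):
        # "config" arrives merged with job_spec["config"] for BatchAPI jobs,
        # "job_spec" is None for RealtimeAPI, and "python_client" is passed
        # only when models are specified in the API spec
        self.client = python_client
        self.threshold = config.get("threshold", 0.5)

    def load_model(self, model_path):
        # registered via client.set_load_method() in initialize_impl()
        # (signature assumed here for illustration)
        ...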