Example #1
    def get_payload(self, request_id: str) -> Union[Dict, str, bytes]:
        key = f"{self.storage_path}/{request_id}/payload"
        obj = self.storage.get_object(key)
        status_code = obj["ResponseMetadata"]["HTTPStatusCode"]
        if status_code != HTTPStatus.OK:
            raise CortexException(
                f"failed to retrieve async payload (request_id: {request_id}, status_code: {status_code})"
            )

        content_type: str = obj["ResponseMetadata"]["HTTPHeaders"]["content-type"]
        payload_bytes: bytes = obj["Body"].read()

        # decode payload
        if content_type.startswith("application/json"):
            try:
                return json.loads(payload_bytes)
            except Exception as err:
                raise UserRuntimeException(
                    f"the uploaded payload, with content-type {content_type}, could not be decoded to JSON"
                ) from err
        elif content_type.startswith("text/plain"):
            try:
                return payload_bytes.decode("utf-8")
            except Exception as err:
                raise UserRuntimeException(
                    f"the uploaded payload, with content-type {content_type}, could not be decoded to a utf-8 string"
                ) from err
        else:
            return payload_bytes
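
A minimal standalone sketch of the same decoding rules, with the storage layer stubbed out (the function name and sample payloads below are illustrative, not part of the code above):

import json

def decode_payload(raw: bytes, content_type: str):
    if content_type.startswith("application/json"):
        return json.loads(raw)            # dict (or list) for JSON payloads
    if content_type.startswith("text/plain"):
        return raw.decode("utf-8")        # str for plain text
    return raw                            # raw bytes for anything else

print(decode_payload(b'{"a": 1}', "application/json; charset=utf-8"))  # {'a': 1}
print(decode_payload(b"hello", "text/plain"))                          # hello
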
Example #2
    def upload_result(self, request_id: str, result: Dict[str, Any]):
        if not isinstance(result, dict):
            raise UserRuntimeException(
                f"user response must be json serializable dictionary, got {type(result)} instead"
            )

        try:
            result_json = json.dumps(result)
        except Exception as err:
            raise UserRuntimeException("user response is not json serializable") from err

        self.storage.put_str(result_json, f"{self.storage_path}/{request_id}/result.json")
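
Note that the isinstance check alone does not guarantee serializability, which is why json.dumps is also guarded. A quick illustration (the payload here is made up):

import json

payload = {"score": 0.92, "raw": b"\x00"}  # a dict, but bytes values aren't JSON serializable
try:
    json.dumps(payload)
except TypeError as err:
    print(f"not JSON serializable: {err}")
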
Example #3
    def _validate_model_args(
        self, model_name: Optional[str] = None, model_version: str = "latest"
    ) -> Tuple[str, str]:
        """
        Validate the model name and model version.

        Args:
            model_name: Name of the model.
            model_version: Model version to use. Can also be "latest" for picking the highest version.

        Returns:
            The (model_name, model_version) tuple, modified if necessary.

        Raises:
            UserRuntimeException if the validation fails.
        """

        if model_version != "latest" and not model_version.isnumeric():
            raise UserRuntimeException(
                "model_version must be either a parse-able numeric value or 'latest'"
            )

        # when predictor:models:path or predictor:models:paths is specified
        if not self._models_dir:

            # when predictor:models:path is provided
            if consts.SINGLE_MODEL_NAME in self._spec_model_names:
                return consts.SINGLE_MODEL_NAME, model_version

            # when predictor:models:paths is specified
            if model_name is None:
                raise UserRuntimeException(
                    f"model_name was not specified, choose one of the following: {self._spec_model_names}"
                )

            if model_name not in self._spec_model_names:
                raise UserRuntimeException(
                    f"'{model_name}' model wasn't found in the list of available models"
                )

        # when predictor:models:dir is specified
        if self._models_dir:
            if model_name is None:
                raise UserRuntimeException("model_name was not specified")
            if not self._caching_enabled:
                available_models = find_ondisk_models_with_lock(self._lock_dir)
                if model_name not in available_models:
                    raise UserRuntimeException(
                        f"'{model_name}' model wasn't found in the list of available models"
                    )

        return model_name, model_version
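
For reference, str.isnumeric() accepts only strings of numeric characters, so dotted or signed versions are rejected by the check above, while "latest" is special-cased:

for v in ["2", "10", "1.5", "-1", "latest"]:
    print(v, v.isnumeric())  # 2 True, 10 True, 1.5 False, -1 False, latest False
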
Example #4
    def predict(self,
                model_input: Any,
                model_name: Optional[str] = None,
                model_version: str = "latest") -> dict:
        """
        Validate model_input, convert it to a Prediction Proto, and make a request to TensorFlow Serving.

        Args:
            model_input: Input to the model.
            model_name (optional): Name of the model to retrieve (when multiple models are deployed in an API).
                When predictor.models.paths is specified, model_name should be the name of one of the models listed in the API config.
                When predictor.models.dir is specified, model_name should be the name of a top-level directory in the models dir.
            model_version (string, optional): Version of the model to retrieve. Can be omitted or set to "latest" to select the highest version.

        Returns:
            dict: TensorFlow Serving response converted to a dictionary.
        """

        if model_version != "latest" and not model_version.isnumeric():
            raise UserRuntimeException(
                "model_version must be either a parse-able numeric value or 'latest'"
            )

        # when predictor:models:path or predictor:models:paths is specified
        if not self._models_dir:

            # when predictor:models:path is provided
            if consts.SINGLE_MODEL_NAME in self._spec_model_names:
                return self._run_inference(model_input,
                                           consts.SINGLE_MODEL_NAME,
                                           model_version)

            # when predictor:models:paths is specified
            if model_name is None:
                raise UserRuntimeException(
                    f"model_name was not specified, choose one of the following: {self._spec_model_names}"
                )

            if model_name not in self._spec_model_names:
                raise UserRuntimeException(
                    f"'{model_name}' model wasn't found in the list of available models"
                )

        # when predictor:models:dir is specified
        if self._models_dir and model_name is None:
            raise UserRuntimeException("model_name was not specified")

        return self._run_inference(model_input, model_name, model_version)
Example #5
def predict(request: Request):
    predictor_impl = local_cache["predictor_impl"]
    dynamic_batcher = local_cache["dynamic_batcher"]
    kwargs = build_predict_kwargs(request)

    if dynamic_batcher:
        prediction = dynamic_batcher.predict(**kwargs)
    else:
        prediction = predictor_impl.predict(**kwargs)

    if isinstance(prediction, bytes):
        response = Response(content=prediction, media_type="application/octet-stream")
    elif isinstance(prediction, str):
        response = Response(content=prediction, media_type="text/plain")
    elif isinstance(prediction, Response):
        response = prediction
    else:
        try:
            json_string = json.dumps(prediction)
        except Exception as e:
            raise UserRuntimeException(
                str(e),
                "please return an object that is JSON serializable (including its nested fields), a bytes object, "
                "a string, or a starlette.response.Response object",
            ) from e
        response = Response(content=json_string, media_type="application/json")

    if util.has_method(predictor_impl, "post_predict"):
        kwargs = build_post_predict_kwargs(prediction, request)
        request_thread_pool.submit(predictor_impl.post_predict, **kwargs)

    return response
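
The branching on the prediction's type can be read as a standalone dispatch helper; a hypothetical version of it (assuming starlette is installed):

import json
from starlette.responses import Response

def to_response(prediction) -> Response:
    # mirrors the dispatch above: bytes -> octet-stream, str -> plain text,
    # Response -> passthrough, anything else -> JSON
    if isinstance(prediction, bytes):
        return Response(content=prediction, media_type="application/octet-stream")
    if isinstance(prediction, str):
        return Response(content=prediction, media_type="text/plain")
    if isinstance(prediction, Response):
        return prediction
    return Response(content=json.dumps(prediction), media_type="application/json")
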
Example #6
    def initialize_impl(
        self,
        project_dir: str,
        metrics_client: MetricsClient,
        tf_serving_host: str = None,
        tf_serving_port: str = None,
    ):
        predictor_impl = self._get_impl(project_dir)
        constructor_args = inspect.getfullargspec(predictor_impl.__init__).args
        config = deepcopy(self.config)

        args = {}
        if "config" in constructor_args:
            args["config"] = config
        if "metrics_client" in constructor_args:
            args["metrics_client"] = metrics_client

        if self.type in [
                TensorFlowPredictorType, TensorFlowNeuronPredictorType
        ]:
            tf_serving_address = tf_serving_host + ":" + tf_serving_port
            tf_client = TensorFlowClient(
                tf_serving_url=tf_serving_address,
                api_spec=self.api_spec,
            )
            tf_client.sync_models(lock_dir=self.lock_dir)
            args["tensorflow_client"] = tf_client

        try:
            predictor = predictor_impl(**args)
        except Exception as e:
            raise UserRuntimeException(self.path, "__init__", str(e)) from e

        return predictor
Example #7
    def _batch_engine(self):
        while True:
            # wait until the previous batch's predictions have been consumed
            if len(self.predictions) > 0:
                time.sleep(0.001)
                continue

            try:
                self.barrier.wait(self.batch_interval)
            except td.BrokenBarrierError:
                pass

            self.predictions = {}
            sample_ids = self._get_sample_ids(self.batch_max_size)
            try:
                if self.samples:
                    batch = self._make_batch(sample_ids)

                    predictions = self.predictor_impl.predict(**batch)
                    if not isinstance(predictions, list):
                        raise UserRuntimeException(
                            f"please return a list when using server side batching, got {type(predictions)}"
                        )

                    if self.test_mode:
                        self._test_batch_lengths.append(len(predictions))

                    self.predictions = dict(zip(sample_ids, predictions))
            except Exception as e:
                self.predictions = {sample_id: e for sample_id in sample_ids}
                logger.error(traceback.format_exc())
            finally:
                for sample_id in sample_ids:
                    del self.samples[sample_id]
                self.barrier.reset()
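
The coordination here relies on threading.Barrier semantics: a batch is released either when enough waiters arrive or when batch_interval elapses, whichever comes first. A small sketch of just that trigger (the party count and timeout are arbitrary):

import threading

barrier = threading.Barrier(parties=4)

def enqueue(i: int) -> None:
    try:
        barrier.wait(timeout=0.5)    # released early if all 4 parties arrive
    except threading.BrokenBarrierError:
        pass                         # timeout path: proceed with a partial batch
    print(f"sample {i} released into a batch")

threads = [threading.Thread(target=enqueue, args=(i,)) for i in range(3)]
for t in threads:
    t.start()
for t in threads:
    t.join()   # with only 3 of 4 parties, all three release after ~0.5s via the timeout
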
Example #8
    def get_model(self, model_name: Optional[str] = None, model_version: str = "latest") -> Any:
        """
        Retrieve a model for inference.

        Args:
            model_name (optional): Name of the model to retrieve (when multiple models are deployed in an API).
                When predictor.models.paths is specified, model_name should be the name of one of the models listed in the API config.
                When predictor.models.dir is specified, model_name should be the name of a top-level directory in the models dir.
            model_version (string, optional): Version of the model to retrieve. Can be omitted or set to "latest" to select the highest version.

        Returns:
            The value that's returned by your predictor's load_model() method.
        """

        if model_version != "latest" and not model_version.isnumeric():
            raise UserRuntimeException(
                "model_version must be either a parse-able numeric value or 'latest'"
            )

        # when predictor:models:path or predictor:models:paths is specified
        if not self._models_dir:

            # when predictor:models:path is provided
            if consts.SINGLE_MODEL_NAME in self._spec_model_names:
                model_name = consts.SINGLE_MODEL_NAME
                model = self._get_model(model_name, model_version)
                if model is None:
                    raise UserRuntimeException(
                        f"model {model_name} of version {model_version} wasn't found"
                    )
                return model

            # when predictor:models:paths is specified
            if model_name is None:
                raise UserRuntimeException(
                    f"model_name was not specified, choose one of the following: {self._spec_model_names}"
                )

            if model_name not in self._spec_model_names:
                raise UserRuntimeException(
                    f"'{model_name}' model wasn't found in the list of available models"
                )

        # when predictor:models:dir is specified
        if self._models_dir:
            if model_name is None:
                raise UserRuntimeException("model_name was not specified")
            if not self._caching_enabled:
                available_models = find_ondisk_models_with_lock(self._lock_dir)
                if model_name not in available_models:
                    raise UserRuntimeException(
                        f"'{model_name}' model wasn't found in the list of available models"
                    )

        model = self._get_model(model_name, model_version)
        if model is None:
            raise UserRuntimeException(
                f"model {model_name} of version {model_version} wasn't found"
            )
        return model
Example #9
    def _run_inference(self, model_input: Any, model_name: str, model_version: str) -> Any:
        """
        Run the inference on model model_name of version model_version.
        """

        model = self._get_model(model_name, model_version)
        if model is None:
            raise UserRuntimeException(
                f"model {model_name} of version {model_version} wasn't found"
            )

        try:
            input_dict = convert_to_onnx_input(model_input, model["signatures"], model_name)
            return model["session"].run([], input_dict)
        except Exception as e:
            raise UserRuntimeException(
                f"failed inference with model {model_name} of version {model_version}", str(e)
            ) from e
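
For reference, the underlying onnxruntime calls look roughly like this (the model path and input shape are placeholders, and the session would normally be created once at model-load time):

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("/path/to/model.onnx")    # placeholder path
input_name = session.get_inputs()[0].name                # first input's name
input_dict = {input_name: np.zeros((1, 4), dtype=np.float32)}
outputs = session.run(None, input_dict)                  # None selects all outputs
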
Example #10
    def _get_latest_model_version_from_disk(self, model_name: str) -> str:
        """
        Get the highest version for a specific model name.
        Must only be used when processes_per_replica > 0 and caching disabled.
        """
        versions, timestamps = find_ondisk_model_info(self._lock_dir, model_name)
        if len(versions) == 0:
            raise UserRuntimeException(
                "'{}' model's versions have been removed; add at least a version to the model to resume operations".format(
                    model_name
                )
            )
        return str(max(map(int, versions)))
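
The int conversion matters: comparing version strings lexicographically would order them incorrectly. For example:

versions = ["1", "3", "10"]
print(max(versions))                  # "3"  -- lexicographic, wrong
print(str(max(map(int, versions))))   # "10" -- numeric, what the method returns
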
Example #11
def predict(request: Request):
    tasks = BackgroundTasks()
    api = local_cache["api"]
    predictor_impl = local_cache["predictor_impl"]
    dynamic_batcher = local_cache["dynamic_batcher"]
    kwargs = build_predict_kwargs(request)

    if dynamic_batcher:
        prediction = dynamic_batcher.predict(**kwargs)
    else:
        prediction = predictor_impl.predict(**kwargs)

    if isinstance(prediction, bytes):
        response = Response(content=prediction,
                            media_type="application/octet-stream")
    elif isinstance(prediction, str):
        response = Response(content=prediction, media_type="text/plain")
    elif isinstance(prediction, Response):
        response = prediction
    else:
        try:
            json_string = json.dumps(prediction)
        except Exception as e:
            raise UserRuntimeException(
                str(e),
                "please return an object that is JSON serializable (including its nested fields), a bytes object, "
                "a string, or a starlette.response.Response object",
            ) from e
        response = Response(content=json_string, media_type="application/json")

    if local_cache["provider"] != "local" and api.monitoring is not None:
        try:
            predicted_value = api.monitoring.extract_predicted_value(
                prediction)
            api.post_monitoring_metrics(predicted_value)
            if (api.monitoring.model_type == "classification"
                    and predicted_value not in local_cache["class_set"]):
                tasks.add_task(api.upload_class, class_name=predicted_value)
                local_cache["class_set"].add(predicted_value)
        except:
            logger().warn("unable to record prediction metric", exc_info=True)

    if util.has_method(predictor_impl, "post_predict"):
        kwargs = build_post_predict_kwargs(prediction, request)
        request_thread_pool.submit(predictor_impl.post_predict, **kwargs)

    if len(tasks.tasks) > 0:
        response.background = tasks

    return response
Example #12
    def initialize_impl(self, project_dir: str, metrics_client: MetricsClient):
        predictor_impl = self._get_impl(project_dir)
        constructor_args = inspect.getfullargspec(predictor_impl.__init__).args
        config = deepcopy(self.config)

        args = {}
        if "config" in constructor_args:
            args["config"] = config
        if "metrics_client" in constructor_args:
            args["metrics_client"] = metrics_client

        try:
            predictor = predictor_impl(**args)
        except Exception as e:
            raise UserRuntimeException(self.path, "__init__", str(e)) from e

        return predictor
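
The getfullargspec introspection means a user's predictor class only receives the constructor arguments it declares. A minimal sketch with a hypothetical predictor class:

import inspect

class MyPredictor:
    def __init__(self, config):          # declares config, but not metrics_client
        self.config = config

constructor_args = inspect.getfullargspec(MyPredictor.__init__).args
args = {}
if "config" in constructor_args:
    args["config"] = {"model_path": "s3://bucket/model"}   # illustrative config
if "metrics_client" in constructor_args:
    args["metrics_client"] = object()    # skipped: not in the signature

predictor = MyPredictor(**args)          # only declared args are injected
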
Example #13
    def get_model(self, model_name: Optional[str] = None, model_version: str = "latest") -> dict:
        """
        Validate input and then return the model loaded into a dictionary.
        Tag call counts are recorded by this method (just as they are for the predict method).

        Args:
            model_name: Model to use when multiple models are deployed in a single API.
            model_version: Model version to use. Can also be "latest" for picking the highest version.

        Returns:
            The model as returned by _load_model method.

        Raises:
            UserRuntimeException if the validation fails.
        """
        model_name, model_version = self._validate_model_args(model_name, model_version)
        model = self._get_model(model_name, model_version)
        if model is None:
            raise UserRuntimeException(
                f"model {model_name} of version {model_version} wasn't found"
            )
        return model
Example #14
    def _run_inference(self, model_input: Any, model_name: str,
                       model_version: str) -> dict:
        """
        When processes_per_replica = 1 and caching enabled, check/load model and make prediction.
        When processes_per_replica > 0 and caching disabled, attempt to make prediction regardless.

        Args:
            model_input: Input to the model.
            model_name: Name of the model, as specified in predictor:models:paths, or, otherwise, as named on disk.
            model_version: Version of the model, as found on disk. The version number can also be inferred from the "latest" version tag.

        Returns:
            The prediction.
        """

        model = None
        tag = ""
        if model_version == "latest":
            tag = model_version

        if not self._caching_enabled:

            # determine model version
            if tag == "latest":
                versions = self._client.poll_available_model_versions(
                    model_name)
                if len(versions) == 0:
                    raise UserException(
                        f"model '{model_name}' accessed with tag {tag} couldn't be found"
                    )
                model_version = str(max(map(lambda x: int(x), versions)))
            model_id = model_name + "-" + model_version

            return self._client.predict(model_input, model_name, model_version)

        if not self._multiple_processes and self._caching_enabled:

            # determine model version
            try:
                if tag == "latest":
                    model_version = self._get_latest_model_version_from_tree(
                        model_name, self._models_tree.model_info(model_name))
            except ValueError:
                # if model_name hasn't been found
                raise UserRuntimeException(
                    f"'{model_name}' model of tag {tag} wasn't found in the list of available models"
                )

            # grab shared access to model tree
            available_model = True
            logger.info(
                f"grabbing access to model {model_name} of version {model_version}"
            )
            with LockedModelsTree(self._models_tree, "r", model_name,
                                  model_version):

                # check if the versioned model exists
                model_id = model_name + "-" + model_version
                if model_id not in self._models_tree:
                    available_model = False
                    logger.info(
                        f"model {model_name} of version {model_version} is not available"
                    )
                    raise WithBreak

                # retrieve model tree's metadata
                upstream_model = self._models_tree[model_id]
                current_upstream_ts = int(
                    upstream_model["timestamp"].timestamp())
                logger.info(
                    f"model {model_name} of version {model_version} is available"
                )

            if not available_model:
                if tag == "":
                    raise UserException(
                        f"model '{model_name}' of version '{model_version}' couldn't be found"
                    )
                raise UserException(
                    f"model '{model_name}' accessed with tag '{tag}' couldn't be found"
                )

            # grab shared access to models holder and retrieve model
            update_model = False
            prediction = None
            tfs_was_unresponsive = False
            with LockedModel(self._models, "r", model_name, model_version):
                logger.info(
                    f"checking the {model_name} {model_version} status")
                status, local_ts = self._models.has_model(
                    model_name, model_version)
                if status in ["not-available", "on-disk"
                              ] or (status != "not-available"
                                    and local_ts != current_upstream_ts):
                    logger.info(
                        f"model {model_name} of version {model_version} is not loaded (with status {status} or different timestamp)"
                    )
                    update_model = True
                    raise WithBreak

                # run prediction
                logger.info(
                    f"run the prediction on model {model_name} of version {model_version}"
                )
                self._models.get_model(model_name, model_version, tag)
                try:
                    prediction = self._client.predict(model_input, model_name,
                                                      model_version)
                except grpc.RpcError:
                    # effectively when TFS got restarted
                    if len(self._client.poll_available_model_versions(model_name)) > 0:
                        raise
                    tfs_was_unresponsive = True

            # remove model from disk and memory references if TFS gets unresponsive
            if tfs_was_unresponsive:
                with LockedModel(self._models, "w", model_name, model_version):
                    available_versions = self._client.poll_available_model_versions(
                        model_name)
                    status, _ = self._models.has_model(model_name,
                                                       model_version)
                    if not (status == "in-memory"
                            and model_version not in available_versions):
                        raise WithBreak

                    logger.info(
                        f"removing model {model_name} of version {model_version} because TFS got unresponsive"
                    )
                    self._models.remove_model(model_name, model_version)

            # download, load into memory the model and retrieve it
            if update_model:
                # grab exclusive access to models holder
                with LockedModel(self._models, "w", model_name, model_version):

                    # check model status
                    status, local_ts = self._models.has_model(
                        model_name, model_version)

                    # refresh disk model
                    if status == "not-available" or (
                            status in ["on-disk", "in-memory"]
                            and local_ts != current_upstream_ts):
                        # unload model from TFS
                        if status == "in-memory":
                            try:
                                logger.info(
                                    f"unloading model {model_name} of version {model_version} from TFS"
                                )
                                self._models.unload_model(
                                    model_name, model_version)
                            except Exception:
                                logger.info(
                                    f"failed unloading model {model_name} of version {model_version} from TFS"
                                )
                                raise

                        # remove model from disk and references
                        if status in ["on-disk", "in-memory"]:
                            logger.info(
                                f"removing model references from memory and from disk for model {model_name} of version {model_version}"
                            )
                            self._models.remove_model(model_name,
                                                      model_version)

                        # download model
                        logger.info(
                            f"downloading model {model_name} of version {model_version} from the {upstream_model['provider']} upstream"
                        )
                        date = self._models.download_model(
                            upstream_model["provider"],
                            upstream_model["bucket"],
                            model_name,
                            model_version,
                            upstream_model["path"],
                        )
                        if not date:
                            raise WithBreak
                        current_upstream_ts = int(date.timestamp())

                    # load model
                    try:
                        logger.info(
                            f"loading model {model_name} of version {model_version} into memory"
                        )
                        self._models.load_model(
                            model_name,
                            model_version,
                            current_upstream_ts,
                            [tag],
                            kwargs={
                                "model_name": model_name,
                                "model_version": model_version,
                                "signature_key": self._determine_model_signature_key(model_name),
                            },
                        )
                    except Exception as e:
                        raise UserRuntimeException(
                            f"failed (re-)loading model {model_name} of version {model_version} (thread {td.get_ident()})",
                            str(e),
                        ) from e

                    # run prediction
                    self._models.get_model(model_name, model_version, tag)
                    prediction = self._client.predict(model_input, model_name,
                                                      model_version)

            return prediction
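
The recurring raise WithBreak is a pattern for exiting a with block early: the locking context managers presumably swallow this one exception type in __exit__. A minimal sketch of how such a context manager could behave (an assumption about WithBreak's mechanics, not its actual implementation):

class WithBreak(Exception):
    """Sentinel exception used only to break out of a `with` block."""

class LockedResource:
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # returning True suppresses the exception; anything else propagates
        return exc_type is WithBreak

with LockedResource():
    print("before the break")
    raise WithBreak          # exits the block early; the lock is still released
print("execution continues here")
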
Example #15
    def initialize_impl(
        self,
        project_dir: str,
        client: Union[PythonClient, TensorFlowClient, ONNXClient],
        job_spec: Optional[dict] = None,
    ):
        """
        Initialize predictor class as provided by the user.

        job_spec is a dictionary when the "kind" of the API is set to "BatchAPI". Otherwise, it's None.
        """

        # build args
        class_impl = self.class_impl(project_dir)
        constructor_args = inspect.getfullargspec(class_impl.__init__).args
        config = deepcopy(self.config)
        args = {}
        if job_spec is not None and job_spec.get("config") is not None:
            util.merge_dicts_in_place_overwrite(config, job_spec["config"])
        if "config" in constructor_args:
            args["config"] = config
        if "job_spec" in constructor_args:
            args["job_spec"] = job_spec

        # initialize predictor class
        try:
            if self.type == PythonPredictorType:
                if _are_models_specified(self.api_spec):
                    args["python_client"] = client
                    initialized_impl = class_impl(**args)
                    client.set_load_method(initialized_impl.load_model)
                else:
                    initialized_impl = class_impl(**args)
            if self.type in [
                    TensorFlowPredictorType, TensorFlowNeuronPredictorType
            ]:
                args["tensorflow_client"] = client
                initialized_impl = class_impl(**args)
            if self.type == ONNXPredictorType:
                args["onnx_client"] = client
                initialized_impl = class_impl(**args)
        except Exception as e:
            raise UserRuntimeException(self.path, "__init__", str(e)) from e
        finally:
            refresh_logger()

        # initialize the crons if models have been specified and if the API kind is RealtimeAPI
        if _are_models_specified(
                self.api_spec) and self.api_spec["kind"] == "RealtimeAPI":
            if not self.multiple_processes and self.caching_enabled:
                self.crons += [
                    ModelTreeUpdater(
                        interval=10,
                        api_spec=self.api_spec,
                        tree=self.models_tree,
                        ondisk_models_dir=self.model_dir,
                    ),
                    ModelsGC(
                        interval=10,
                        api_spec=self.api_spec,
                        models=self.models,
                        tree=self.models_tree,
                    ),
                ]

            if not self.caching_enabled and self.type in [
                    PythonPredictorType, ONNXPredictorType
            ]:
                self.crons += [
                    FileBasedModelsGC(interval=10,
                                      models=self.models,
                                      download_dir=self.model_dir)
                ]

        for cron in self.crons:
            cron.start()

        return initialized_impl
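
util.merge_dicts_in_place_overwrite is not shown here; presumably it overlays the job-submission config on top of the API config. A plausible sketch of such a merge (assumed semantics, not the actual util):

def merge_dicts_in_place_overwrite(dst: dict, src: dict) -> None:
    # src values win; nested dicts are merged recursively (assumed behavior)
    for key, value in src.items():
        if isinstance(value, dict) and isinstance(dst.get(key), dict):
            merge_dicts_in_place_overwrite(dst[key], value)
        else:
            dst[key] = value

config = {"batch_size": 8, "model": {"name": "a", "version": "1"}}
merge_dicts_in_place_overwrite(config, {"model": {"version": "2"}})
print(config)  # {'batch_size': 8, 'model': {'name': 'a', 'version': '2'}}
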
Example #16
    def _get_model(self, model_name: str, model_version: str) -> Any:
        """
        Checks if versioned model is on disk, then checks if model is in memory,
        and if not, it loads it into memory, and returns the model.

        Args:
            model_name: Name of the model, as it's specified in predictor:models:paths or in the other case as they are named on disk.
            model_version: Version of the model, as it's found on disk. Can also infer the version number from the "latest" tag.

        Exceptions:
            RuntimeError: if another thread tried to load the model at the very same time.

        Returns:
            The model as returned by self._load_model method.
            None if the model wasn't found or if it didn't pass the validation.
        """

        model = None
        tag = ""
        if model_version == "latest":
            tag = model_version

        if not self._caching_enabled:
            # determine model version
            if tag == "latest":
                model_version = self._get_latest_model_version_from_disk(
                    model_name)
            model_id = model_name + "-" + model_version

            # grab shared access to versioned model
            resource = os.path.join(self._lock_dir, model_id + ".txt")
            with LockedFile(resource, "r", reader_lock=True) as f:

                # check model status
                file_status = f.read()
                if file_status == "" or file_status == "not-available":
                    raise WithBreak

                current_upstream_ts = int(file_status.split(" ")[1])
                update_model = False

                # grab shared access to models holder and retrieve model
                with LockedModel(self._models, "r", model_name, model_version):
                    status, local_ts = self._models.has_model(
                        model_name, model_version)
                    if status == "not-available" or (
                            status == "in-memory"
                            and local_ts != current_upstream_ts):
                        update_model = True
                        raise WithBreak
                    model, _ = self._models.get_model(model_name,
                                                      model_version, tag)

                # load model into memory and retrieve it
                if update_model:
                    with LockedModel(self._models, "w", model_name,
                                     model_version):
                        status, _ = self._models.has_model(
                            model_name, model_version)
                        if status == "not-available" or (
                                status == "in-memory"
                                and local_ts != current_upstream_ts):
                            if status == "not-available":
                                logger.info(
                                    f"loading model {model_name} of version {model_version} (thread {td.get_ident()})"
                                )
                            else:
                                logger.info(
                                    f"reloading model {model_name} of version {model_version} (thread {td.get_ident()})"
                                )
                            try:
                                self._models.load_model(
                                    model_name,
                                    model_version,
                                    current_upstream_ts,
                                    [tag],
                                )
                            except Exception as e:
                                raise UserRuntimeException(
                                    f"failed (re-)loading model {model_name} of version {model_version} (thread {td.get_ident()})",
                                    str(e),
                                )
                        model, _ = self._models.get_model(
                            model_name, model_version, tag)

        if not self._multiple_processes and self._caching_enabled:
            # determine model version
            try:
                if tag == "latest":
                    model_version = self._get_latest_model_version_from_tree(
                        model_name, self._models_tree.model_info(model_name))
            except ValueError:
                # if model_name hasn't been found
                raise UserRuntimeException(
                    f"'{model_name}' model of tag latest wasn't found in the list of available models"
                )

            # grab shared access to model tree
            available_model = True
            with LockedModelsTree(self._models_tree, "r", model_name,
                                  model_version):

                # check if the versioned model exists
                model_id = model_name + "-" + model_version
                if model_id not in self._models_tree:
                    available_model = False
                    raise WithBreak

                # retrieve model tree's metadata
                upstream_model = self._models_tree[model_id]
                current_upstream_ts = int(
                    upstream_model["timestamp"].timestamp())

            if not available_model:
                return None

            # grab shared access to models holder and retrieve model
            update_model = False
            with LockedModel(self._models, "r", model_name, model_version):
                status, local_ts = self._models.has_model(
                    model_name, model_version)
                if status in ["not-available", "on-disk"
                              ] or (status != "not-available"
                                    and local_ts != current_upstream_ts):
                    update_model = True
                    raise WithBreak
                model, _ = self._models.get_model(model_name, model_version,
                                                  tag)

            # download, load into memory the model and retrieve it
            if update_model:
                # grab exclusive access to models holder
                with LockedModel(self._models, "w", model_name, model_version):

                    # check model status
                    status, local_ts = self._models.has_model(
                        model_name, model_version)

                    # refresh disk model
                    if status == "not-available" or (
                            status in ["on-disk", "in-memory"]
                            and local_ts != current_upstream_ts):
                        if status == "not-available":
                            logger.info(
                                f"model {model_name} of version {model_version} not found locally; continuing with the download..."
                            )
                        elif status == "on-disk":
                            logger.info(
                                f"found newer model {model_name} of vesion {model_version} on the {upstream_model['provider']} upstream than the one on the disk"
                            )
                        else:
                            logger.info(
                                f"found newer model {model_name} of vesion {model_version} on the {upstream_model['provider']} upstream than the one loaded into memory"
                            )

                        # remove model from disk and memory
                        if status == "on-disk":
                            logger.info(
                                f"removing model from disk for model {model_name} of version {model_version}"
                            )
                            self._models.remove_model(model_name,
                                                      model_version)
                        if status == "in-memory":
                            logger.info(
                                f"removing model from disk and memory for model {model_name} of version {model_version}"
                            )
                            self._models.remove_model(model_name,
                                                      model_version)

                        # download model
                        logger.info(
                            f"downloading model {model_name} of version {model_version} from the {upstream_model['provider']} upstream"
                        )
                        date = self._models.download_model(
                            upstream_model["provider"],
                            upstream_model["bucket"],
                            model_name,
                            model_version,
                            upstream_model["path"],
                        )
                        if not date:
                            raise WithBreak
                        current_upstream_ts = int(date.timestamp())

                    # load model
                    try:
                        logger.info(
                            f"loading model {model_name} of version {model_version} into memory"
                        )
                        self._models.load_model(
                            model_name,
                            model_version,
                            current_upstream_ts,
                            [tag],
                        )
                    except Exception as e:
                        raise UserRuntimeException(
                            f"failed (re-)loading model {model_name} of version {model_version} (thread {td.get_ident()})",
                            str(e),
                        )

                    # retrieve model
                    model, _ = self._models.get_model(model_name,
                                                      model_version, tag)

        return model
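
The read-lock check followed by a write-lock re-check is the classic double-checked pattern: between releasing the shared lock and acquiring the exclusive one, another thread may already have loaded the model. Stripped of the model-tree machinery, the shape is roughly (hypothetical names):

import threading

_lock = threading.Lock()
_cache: dict = {}

def get_model(model_id: str):
    model = _cache.get(model_id)         # optimistic read
    if model is not None:
        return model
    with _lock:                          # exclusive access
        model = _cache.get(model_id)     # re-check: another thread may have loaded it
        if model is None:
            model = f"<model {model_id}>"    # placeholder for the real loader
            _cache[model_id] = model
        return model
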
Example #17
    def initialize_impl(
        self,
        project_dir: str,
        client: Union[PythonClient, TensorFlowClient],
        metrics_client: DogStatsd,
        job_spec: Optional[Dict[str, Any]] = None,
        proto_module_pb2: Optional[Any] = None,
    ):
        """
        Initialize predictor class as provided by the user.

        job_spec is a dictionary when the "kind" of the API is set to "BatchAPI". Otherwise, it's None.
        proto_module_pb2 is a module of the compiled proto when grpc is enabled for the "RealtimeAPI" kind. Otherwise, it's None.

        Can raise UserRuntimeException/UserException/CortexException.
        """

        # build args
        class_impl = self.class_impl(project_dir)
        constructor_args = inspect.getfullargspec(class_impl.__init__).args
        config = deepcopy(self.config)
        args = {}
        if job_spec is not None and job_spec.get("config") is not None:
            util.merge_dicts_in_place_overwrite(config, job_spec["config"])
        if "config" in constructor_args:
            args["config"] = config
        if "job_spec" in constructor_args:
            args["job_spec"] = job_spec
        if "metrics_client" in constructor_args:
            args["metrics_client"] = metrics_client
        if "proto_module_pb2" in constructor_args:
            args["proto_module_pb2"] = proto_module_pb2

        # initialize predictor class
        try:
            if self.type == PythonPredictorType:
                if are_models_specified(self.api_spec):
                    args["python_client"] = client
                    # set load method to enable the use of the client in the constructor
                    # setting/getting from self in load_model won't work because self will be set to None
                    client.set_load_method(
                        lambda model_path: class_impl.load_model(None, model_path)
                    )
                    initialized_impl = class_impl(**args)
                    client.set_load_method(initialized_impl.load_model)
                else:
                    initialized_impl = class_impl(**args)
            if self.type in [
                    TensorFlowPredictorType, TensorFlowNeuronPredictorType
            ]:
                args["tensorflow_client"] = client
                initialized_impl = class_impl(**args)
        except Exception as e:
            raise UserRuntimeException(self.path, "__init__", str(e)) from e

        # initialize the crons if models have been specified and if the API kind is RealtimeAPI
        if are_models_specified(
                self.api_spec) and self.api_spec["kind"] == "RealtimeAPI":
            if not self.multiple_processes and self.caching_enabled:
                self.crons += [
                    ModelTreeUpdater(
                        interval=10,
                        api_spec=self.api_spec,
                        tree=self.models_tree,
                        ondisk_models_dir=self.model_dir,
                    ),
                    ModelsGC(
                        interval=10,
                        api_spec=self.api_spec,
                        models=self.models,
                        tree=self.models_tree,
                    ),
                ]

            if not self.caching_enabled and self.type == PythonPredictorType:
                self.crons += [
                    FileBasedModelsGC(interval=10,
                                      models=self.models,
                                      download_dir=self.model_dir)
                ]

        for cron in self.crons:
            cron.start()

        return initialized_impl
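
The temporary load method set before construction deserves a note: class_impl.load_model(None, model_path) calls the method unbound, passing None for self, which only works because load_model is expected not to touch self at that point. A toy illustration (hypothetical class):

class Predictor:
    def load_model(self, model_path: str):
        # must not use `self`: during construction it may be called with self=None
        return f"<model loaded from {model_path}>"

loader = lambda model_path: Predictor.load_model(None, model_path)
print(loader("/models/a/1"))                 # works because load_model ignores self

instance = Predictor()
print(instance.load_model("/models/a/2"))    # the bound method is swapped in afterwards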