def get_payload(self, request_id: str) -> Union[Dict, str, bytes]:
    key = f"{self.storage_path}/{request_id}/payload"
    obj = self.storage.get_object(key)
    status_code = obj["ResponseMetadata"]["HTTPStatusCode"]
    if status_code != HTTPStatus.OK:
        raise CortexException(
            f"failed to retrieve async payload (request_id: {request_id}, status_code: {status_code})"
        )

    content_type: str = obj["ResponseMetadata"]["HTTPHeaders"]["content-type"]
    payload_bytes: bytes = obj["Body"].read()

    # decode payload
    if content_type.startswith("application/json"):
        try:
            return json.loads(payload_bytes)
        except Exception as err:
            raise UserRuntimeException(
                f"the uploaded payload, with content-type {content_type}, could not be decoded to JSON"
            ) from err
    elif content_type.startswith("text/plain"):
        try:
            return payload_bytes.decode("utf-8")
        except Exception as err:
            raise UserRuntimeException(
                f"the uploaded payload, with content-type {content_type}, could not be decoded to a utf-8 string"
            ) from err
    else:
        return payload_bytes
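
# A minimal, self-contained sketch of the content-type dispatch above;
# decode_payload and the sample inputs are hypothetical, for illustration only:
import json

def decode_payload(payload_bytes: bytes, content_type: str):
    if content_type.startswith("application/json"):
        return json.loads(payload_bytes)
    if content_type.startswith("text/plain"):
        return payload_bytes.decode("utf-8")
    return payload_bytes  # anything else is passed through as raw bytes

assert decode_payload(b'{"a": 1}', "application/json") == {"a": 1}
assert decode_payload(b"hello", "text/plain") == "hello"
assert decode_payload(b"\x00\x01", "application/octet-stream") == b"\x00\x01"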
def upload_result(self, request_id: str, result: Dict[str, Any]):
    if not isinstance(result, dict):
        raise UserRuntimeException(
            f"user response must be a json serializable dictionary, got {type(result)} instead"
        )
    try:
        result_json = json.dumps(result)
    except Exception as err:
        raise UserRuntimeException("user response is not json serializable") from err
    self.storage.put_str(result_json, f"{self.storage_path}/{request_id}/result.json")
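
# The two failure modes guarded against above, in isolation
# (hypothetical values):
import json

json.dumps({"label": "cat", "score": 0.97})  # ok: a JSON-serializable dict
# json.dumps({"x": object()})                # would raise TypeError: not JSON serializable
# a non-dict return value (e.g. a list) fails the isinstance(result, dict) check first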
def _validate_model_args(
    self, model_name: Optional[str] = None, model_version: str = "latest"
) -> Tuple[str, str]:
    """
    Validate the model name and model version.

    Args:
        model_name: Name of the model.
        model_version: Model version to use. Can also be "latest" for picking the highest version.

    Returns:
        The validated (model_name, model_version) tuple, modified if necessary.

    Raises:
        UserRuntimeException if the validation fails.
    """

    if model_version != "latest" and not model_version.isnumeric():
        raise UserRuntimeException(
            "model_version must be either a parse-able numeric value or 'latest'"
        )

    # when predictor:models:path or predictor:models:paths is specified
    if not self._models_dir:

        # when predictor:models:path is provided
        if consts.SINGLE_MODEL_NAME in self._spec_model_names:
            return consts.SINGLE_MODEL_NAME, model_version

        # when predictor:models:paths is specified
        if model_name is None:
            raise UserRuntimeException(
                f"model_name was not specified, choose one of the following: {self._spec_model_names}"
            )
        if model_name not in self._spec_model_names:
            raise UserRuntimeException(
                f"'{model_name}' model wasn't found in the list of available models"
            )

    # when predictor:models:dir is specified
    if self._models_dir:
        if model_name is None:
            raise UserRuntimeException("model_name was not specified")
        if not self._caching_enabled:
            available_models = find_ondisk_models_with_lock(self._lock_dir)
            if model_name not in available_models:
                raise UserRuntimeException(
                    f"'{model_name}' model wasn't found in the list of available models"
                )

    return model_name, model_version
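
# The model_version rule above in isolation: "latest" or a purely numeric
# string passes, anything else is rejected (hypothetical values):
for v in ["latest", "3", "10", "v2", "1.0"]:
    ok = v == "latest" or v.isnumeric()
    print(f"{v!r} -> {'ok' if ok else 'rejected'}")
# 'latest' -> ok, '3' -> ok, '10' -> ok, 'v2' -> rejected, '1.0' -> rejected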
def predict(
    self, model_input: Any, model_name: Optional[str] = None, model_version: str = "latest"
) -> dict:
    """
    Validate model_input, convert it to a Prediction Proto, and make a request to TensorFlow Serving.

    Args:
        model_input: Input to the model.
        model_name (optional): Name of the model to retrieve (when multiple models are deployed in an API).
            When predictor.models.paths is specified, model_name should be the name of one of the models listed in the API config.
            When predictor.models.dir is specified, model_name should be the name of a top-level directory in the models dir.
        model_version (string, optional): Version of the model to retrieve. Can be omitted or set to "latest" to select the highest version.

    Returns:
        dict: TensorFlow Serving response converted to a dictionary.
    """

    if model_version != "latest" and not model_version.isnumeric():
        raise UserRuntimeException(
            "model_version must be either a parse-able numeric value or 'latest'"
        )

    # when predictor:models:path or predictor:models:paths is specified
    if not self._models_dir:

        # when predictor:models:path is provided
        if consts.SINGLE_MODEL_NAME in self._spec_model_names:
            return self._run_inference(model_input, consts.SINGLE_MODEL_NAME, model_version)

        # when predictor:models:paths is specified
        if model_name is None:
            raise UserRuntimeException(
                f"model_name was not specified, choose one of the following: {self._spec_model_names}"
            )
        if model_name not in self._spec_model_names:
            raise UserRuntimeException(
                f"'{model_name}' model wasn't found in the list of available models"
            )

    # when predictor:models:dir is specified
    if self._models_dir and model_name is None:
        raise UserRuntimeException("model_name was not specified")

    return self._run_inference(model_input, model_name, model_version)
def predict(request: Request):
    predictor_impl = local_cache["predictor_impl"]
    dynamic_batcher = local_cache["dynamic_batcher"]
    kwargs = build_predict_kwargs(request)

    if dynamic_batcher:
        prediction = dynamic_batcher.predict(**kwargs)
    else:
        prediction = predictor_impl.predict(**kwargs)

    if isinstance(prediction, bytes):
        response = Response(content=prediction, media_type="application/octet-stream")
    elif isinstance(prediction, str):
        response = Response(content=prediction, media_type="text/plain")
    elif isinstance(prediction, Response):
        response = prediction
    else:
        try:
            json_string = json.dumps(prediction)
        except Exception as e:
            raise UserRuntimeException(
                str(e),
                "please return an object that is JSON serializable (including its nested fields), a bytes object, "
                "a string, or a starlette.responses.Response object",
            ) from e
        response = Response(content=json_string, media_type="application/json")

    if util.has_method(predictor_impl, "post_predict"):
        kwargs = build_post_predict_kwargs(prediction, request)
        request_thread_pool.submit(predictor_impl.post_predict, **kwargs)

    return response
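
# Simplified sketch of the media-type dispatch above; to_response_args is
# hypothetical and stands in for the Response construction:
import json

def to_response_args(prediction):
    if isinstance(prediction, bytes):
        return prediction, "application/octet-stream"
    if isinstance(prediction, str):
        return prediction, "text/plain"
    return json.dumps(prediction), "application/json"

assert to_response_args(b"\x00")[1] == "application/octet-stream"
assert to_response_args("ok")[1] == "text/plain"
assert to_response_args({"label": "cat"})[1] == "application/json"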
def initialize_impl(
    self,
    project_dir: str,
    metrics_client: MetricsClient,
    tf_serving_host: Optional[str] = None,
    tf_serving_port: Optional[str] = None,
):
    predictor_impl = self._get_impl(project_dir)
    constructor_args = inspect.getfullargspec(predictor_impl.__init__).args
    config = deepcopy(self.config)

    args = {}
    if "config" in constructor_args:
        args["config"] = config
    if "metrics_client" in constructor_args:
        args["metrics_client"] = metrics_client

    if self.type in [TensorFlowPredictorType, TensorFlowNeuronPredictorType]:
        tf_serving_address = tf_serving_host + ":" + tf_serving_port
        tf_client = TensorFlowClient(
            tf_serving_url=tf_serving_address,
            api_spec=self.api_spec,
        )
        tf_client.sync_models(lock_dir=self.lock_dir)
        args["tensorflow_client"] = tf_client

    try:
        predictor = predictor_impl(**args)
    except Exception as e:
        raise UserRuntimeException(self.path, "__init__", str(e)) from e

    return predictor
def _batch_engine(self):
    while True:
        # wait until the previous batch's predictions have been consumed
        if len(self.predictions) > 0:
            time.sleep(0.001)
            continue

        try:
            self.barrier.wait(self.batch_interval)
        except td.BrokenBarrierError:
            pass

        self.predictions = {}
        sample_ids = self._get_sample_ids(self.batch_max_size)
        try:
            if self.samples:
                batch = self._make_batch(sample_ids)

                predictions = self.predictor_impl.predict(**batch)
                if not isinstance(predictions, list):
                    raise UserRuntimeException(
                        f"please return a list when using server side batching, got {type(predictions)}"
                    )

                if self.test_mode:
                    self._test_batch_lengths.append(len(predictions))

                self.predictions = dict(zip(sample_ids, predictions))
        except Exception as e:
            # fan the same exception out to every sample in the batch
            self.predictions = {sample_id: e for sample_id in sample_ids}
            logger.error(traceback.format_exc())
        finally:
            for sample_id in sample_ids:
                del self.samples[sample_id]
            self.barrier.reset()
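
# How a successful batch and a failed batch map back to waiting requests
# (hypothetical ids/values; mirrors the zip and the except-branch above):
sample_ids = [7, 9]
assert dict(zip(sample_ids, ["cat", "dog"])) == {7: "cat", 9: "dog"}

err = RuntimeError("model failed")
fanned_out = {sample_id: err for sample_id in sample_ids}
assert all(p is err for p in fanned_out.values())  # every caller sees the same error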
def get_model(self, model_name: Optional[str] = None, model_version: str = "latest") -> Any:
    """
    Retrieve a model for inference.

    Args:
        model_name (optional): Name of the model to retrieve (when multiple models are deployed in an API).
            When predictor.models.paths is specified, model_name should be the name of one of the models listed in the API config.
            When predictor.models.dir is specified, model_name should be the name of a top-level directory in the models dir.
        model_version (string, optional): Version of the model to retrieve. Can be omitted or set to "latest" to select the highest version.

    Returns:
        The value that's returned by your predictor's load_model() method.
    """

    if model_version != "latest" and not model_version.isnumeric():
        raise UserRuntimeException(
            "model_version must be either a parse-able numeric value or 'latest'"
        )

    # when predictor:models:path or predictor:models:paths is specified
    if not self._models_dir:

        # when predictor:models:path is provided
        if consts.SINGLE_MODEL_NAME in self._spec_model_names:
            model_name = consts.SINGLE_MODEL_NAME
            model = self._get_model(model_name, model_version)
            if model is None:
                raise UserRuntimeException(
                    f"model {model_name} of version {model_version} wasn't found"
                )
            return model

        # when predictor:models:paths is specified
        if model_name is None:
            raise UserRuntimeException(
                f"model_name was not specified, choose one of the following: {self._spec_model_names}"
            )
        if model_name not in self._spec_model_names:
            raise UserRuntimeException(
                f"'{model_name}' model wasn't found in the list of available models"
            )

    # when predictor:models:dir is specified
    if self._models_dir:
        if model_name is None:
            raise UserRuntimeException("model_name was not specified")
        if not self._caching_enabled:
            available_models = find_ondisk_models_with_lock(self._lock_dir)
            if model_name not in available_models:
                raise UserRuntimeException(
                    f"'{model_name}' model wasn't found in the list of available models"
                )

    model = self._get_model(model_name, model_version)
    if model is None:
        raise UserRuntimeException(
            f"model {model_name} of version {model_version} wasn't found"
        )
    return model
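
# Hypothetical end-to-end shape of get_model usage; PythonClientStub stands in
# for the real client and everything here is illustrative only:
class PythonClientStub:
    def __init__(self):
        self._models = {("iris-classifier", "1"): (lambda x: "setosa")}

    def get_model(self, model_name=None, model_version="latest"):
        if model_version == "latest":
            model_version = "1"  # the real client resolves the highest version
        return self._models[(model_name, model_version)]

client = PythonClientStub()
model = client.get_model("iris-classifier", "latest")
print(model([5.1, 3.5, 1.4, 0.2]))  # -> "setosa" (whatever load_model() returned)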
def _run_inference(self, model_input: Any, model_name: str, model_version: str) -> Any:
    """
    Run inference on the model_name model of version model_version.
    """

    model = self._get_model(model_name, model_version)
    if model is None:
        raise UserRuntimeException(
            f"model {model_name} of version {model_version} wasn't found"
        )

    try:
        input_dict = convert_to_onnx_input(model_input, model["signatures"], model_name)
        return model["session"].run([], input_dict)
    except Exception as e:
        raise UserRuntimeException(
            f"failed inference with model {model_name} of version {model_version}", str(e)
        )
def _get_latest_model_version_from_disk(self, model_name: str) -> str:
    """
    Get the highest version for a specific model name.
    Must only be used when processes_per_replica > 0 and caching is disabled.
    """
    versions, timestamps = find_ondisk_model_info(self._lock_dir, model_name)
    if len(versions) == 0:
        raise UserRuntimeException(
            f"'{model_name}' model's versions have been removed; add at least a version to the model to resume operations"
        )
    return str(max(map(int, versions)))
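
# Why the versions are compared as integers, not strings
# (hypothetical on-disk versions):
versions = ["1", "3", "10", "2"]
assert str(max(map(int, versions))) == "10"   # numeric comparison picks 10
assert max(versions) == "3"                   # lexicographic comparison would be wrong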
def predict(request: Request):
    tasks = BackgroundTasks()
    api = local_cache["api"]
    predictor_impl = local_cache["predictor_impl"]
    dynamic_batcher = local_cache["dynamic_batcher"]
    kwargs = build_predict_kwargs(request)

    if dynamic_batcher:
        prediction = dynamic_batcher.predict(**kwargs)
    else:
        prediction = predictor_impl.predict(**kwargs)

    if isinstance(prediction, bytes):
        response = Response(content=prediction, media_type="application/octet-stream")
    elif isinstance(prediction, str):
        response = Response(content=prediction, media_type="text/plain")
    elif isinstance(prediction, Response):
        response = prediction
    else:
        try:
            json_string = json.dumps(prediction)
        except Exception as e:
            raise UserRuntimeException(
                str(e),
                "please return an object that is JSON serializable (including its nested fields), a bytes object, "
                "a string, or a starlette.responses.Response object",
            ) from e
        response = Response(content=json_string, media_type="application/json")

    if local_cache["provider"] != "local" and api.monitoring is not None:
        try:
            predicted_value = api.monitoring.extract_predicted_value(prediction)
            api.post_monitoring_metrics(predicted_value)
            if (
                api.monitoring.model_type == "classification"
                and predicted_value not in local_cache["class_set"]
            ):
                tasks.add_task(api.upload_class, class_name=predicted_value)
                local_cache["class_set"].add(predicted_value)
        except Exception:
            logger().warning("unable to record prediction metric", exc_info=True)

    if util.has_method(predictor_impl, "post_predict"):
        kwargs = build_post_predict_kwargs(prediction, request)
        request_thread_pool.submit(predictor_impl.post_predict, **kwargs)

    if len(tasks.tasks) > 0:
        response.background = tasks

    return response
def initialize_impl(self, project_dir: str, metrics_client: MetricsClient):
    predictor_impl = self._get_impl(project_dir)
    constructor_args = inspect.getfullargspec(predictor_impl.__init__).args
    config = deepcopy(self.config)

    args = {}
    if "config" in constructor_args:
        args["config"] = config
    if "metrics_client" in constructor_args:
        args["metrics_client"] = metrics_client

    try:
        predictor = predictor_impl(**args)
    except Exception as e:
        raise UserRuntimeException(self.path, "__init__", str(e)) from e

    return predictor
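
# The optional-kwarg injection pattern used above, in isolation: only the
# arguments a user's __init__ actually declares get passed in
# (ExamplePredictor is hypothetical):
import inspect

class ExamplePredictor:
    def __init__(self, config):  # declares config, but not metrics_client
        self.config = config

constructor_args = inspect.getfullargspec(ExamplePredictor.__init__).args
args = {}
if "config" in constructor_args:
    args["config"] = {"threshold": 0.5}
if "metrics_client" in constructor_args:
    args["metrics_client"] = object()  # skipped: not declared

predictor = ExamplePredictor(**args)
assert predictor.config == {"threshold": 0.5}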
def get_model(self, model_name: Optional[str] = None, model_version: str = "latest") -> dict:
    """
    Validate the input and then return the model loaded into a dictionary.
    Calls to this method are counted per tag (just like with the predict method).

    Args:
        model_name: Model to use when multiple models are deployed in a single API.
        model_version: Model version to use. Can also be "latest" for picking the highest version.

    Returns:
        The model as returned by the _load_model method.

    Raises:
        UserRuntimeException if the validation fails.
    """
    model_name, model_version = self._validate_model_args(model_name, model_version)
    model = self._get_model(model_name, model_version)
    if model is None:
        raise UserRuntimeException(
            f"model {model_name} of version {model_version} wasn't found"
        )
    return model
def _run_inference(self, model_input: Any, model_name: str, model_version: str) -> dict:
    """
    When processes_per_replica = 1 and caching is enabled, check/load the model and make the prediction.
    When processes_per_replica > 0 and caching is disabled, attempt the prediction regardless.

    Args:
        model_input: Input to the model.
        model_name: Name of the model, as it's specified in predictor:models:paths or, in the other case, as it's named on disk.
        model_version: Version of the model, as it's found on disk. Can also infer the version number from the "latest" version tag.

    Returns:
        The prediction.
    """

    tag = ""
    if model_version == "latest":
        tag = model_version

    if not self._caching_enabled:
        # determine model version
        if tag == "latest":
            versions = self._client.poll_available_model_versions(model_name)
            if len(versions) == 0:
                raise UserException(
                    f"model '{model_name}' accessed with tag {tag} couldn't be found"
                )
            model_version = str(max(map(int, versions)))

        return self._client.predict(model_input, model_name, model_version)

    if not self._multiple_processes and self._caching_enabled:
        # determine model version
        try:
            if tag == "latest":
                model_version = self._get_latest_model_version_from_tree(
                    model_name, self._models_tree.model_info(model_name)
                )
        except ValueError:
            # if model_name hasn't been found
            raise UserRuntimeException(
                f"'{model_name}' model of tag {tag} wasn't found in the list of available models"
            )

        # grab shared access to the model tree
        available_model = True
        logger.info(f"grabbing access to model {model_name} of version {model_version}")
        with LockedModelsTree(self._models_tree, "r", model_name, model_version):

            # check if the versioned model exists
            model_id = model_name + "-" + model_version
            if model_id not in self._models_tree:
                available_model = False
                logger.info(f"model {model_name} of version {model_version} is not available")
                raise WithBreak

            # retrieve the model tree's metadata
            upstream_model = self._models_tree[model_id]
            current_upstream_ts = int(upstream_model["timestamp"].timestamp())
            logger.info(f"model {model_name} of version {model_version} is available")

        if not available_model:
            if tag == "":
                raise UserException(
                    f"model '{model_name}' of version '{model_version}' couldn't be found"
                )
            raise UserException(
                f"model '{model_name}' accessed with tag '{tag}' couldn't be found"
            )

        # grab shared access to the models holder and retrieve the model
        update_model = False
        prediction = None
        tfs_was_unresponsive = False
        with LockedModel(self._models, "r", model_name, model_version):
            logger.info(f"checking the {model_name} {model_version} status")
            status, local_ts = self._models.has_model(model_name, model_version)
            if status in ["not-available", "on-disk"] or (
                status != "not-available" and local_ts != current_upstream_ts
            ):
                logger.info(
                    f"model {model_name} of version {model_version} is not loaded (with status {status} or different timestamp)"
                )
                update_model = True
                raise WithBreak

            # run the prediction
            logger.info(f"run the prediction on model {model_name} of version {model_version}")
            self._models.get_model(model_name, model_version, tag)
            try:
                prediction = self._client.predict(model_input, model_name, model_version)
            except grpc.RpcError:
                # effectively when TFS got restarted
                if len(self._client.poll_available_model_versions(model_name)) > 0:
                    raise
                tfs_was_unresponsive = True

        # remove model from disk and memory references if TFS got unresponsive
        if tfs_was_unresponsive:
            with LockedModel(self._models, "w", model_name, model_version):
                available_versions = self._client.poll_available_model_versions(model_name)
                status, _ = self._models.has_model(model_name, model_version)
                if not (status == "in-memory" and model_version not in available_versions):
                    raise WithBreak

                logger.info(
                    f"removing model {model_name} of version {model_version} because TFS got unresponsive"
                )
                self._models.remove_model(model_name, model_version)

        # download the model, load it into memory, and retrieve it
        if update_model:
            # grab exclusive access to the models holder
            with LockedModel(self._models, "w", model_name, model_version):

                # check the model status
                status, local_ts = self._models.has_model(model_name, model_version)

                # refresh the disk model
                if status == "not-available" or (
                    status in ["on-disk", "in-memory"] and local_ts != current_upstream_ts
                ):
                    # unload the model from TFS
                    if status == "in-memory":
                        try:
                            logger.info(
                                f"unloading model {model_name} of version {model_version} from TFS"
                            )
                            self._models.unload_model(model_name, model_version)
                        except Exception:
                            logger.info(
                                f"failed unloading model {model_name} of version {model_version} from TFS"
                            )
                            raise

                    # remove the model from disk and references
                    if status in ["on-disk", "in-memory"]:
                        logger.info(
                            f"removing model references from memory and from disk for model {model_name} of version {model_version}"
                        )
                        self._models.remove_model(model_name, model_version)

                    # download the model
                    logger.info(
                        f"downloading model {model_name} of version {model_version} from the {upstream_model['provider']} upstream"
                    )
                    date = self._models.download_model(
                        upstream_model["provider"],
                        upstream_model["bucket"],
                        model_name,
                        model_version,
                        upstream_model["path"],
                    )
                    if not date:
                        raise WithBreak
                    current_upstream_ts = int(date.timestamp())

                # load the model
                try:
                    logger.info(
                        f"loading model {model_name} of version {model_version} into memory"
                    )
                    self._models.load_model(
                        model_name,
                        model_version,
                        current_upstream_ts,
                        [tag],
                        kwargs={
                            "model_name": model_name,
                            "model_version": model_version,
                            "signature_key": self._determine_model_signature_key(model_name),
                        },
                    )
                except Exception as e:
                    raise UserRuntimeException(
                        f"failed (re-)loading model {model_name} of version {model_version} (thread {td.get_ident()})",
                        str(e),
                    )

                # run the prediction
                self._models.get_model(model_name, model_version, tag)
                prediction = self._client.predict(model_input, model_name, model_version)

        return prediction
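
# The staleness rule used in the read-lock check above, in isolation: reload
# when the model is missing or only on disk, or when the local timestamp no
# longer matches the upstream timestamp (hypothetical values):
current_upstream_ts = 1700000100
for status, local_ts in [("in-memory", 1700000100), ("in-memory", 1700000000), ("on-disk", 1700000100)]:
    update_model = status in ["not-available", "on-disk"] or (
        status != "not-available" and local_ts != current_upstream_ts
    )
    print(status, local_ts, "-> update" if update_model else "-> serve as-is")
# only in-memory with a matching timestamp is served as-is; everything else triggers an update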
def initialize_impl(
    self,
    project_dir: str,
    client: Union[PythonClient, TensorFlowClient, ONNXClient],
    job_spec: Optional[dict] = None,
):
    """
    Initialize predictor class as provided by the user.

    job_spec is a dictionary when the "kind" of the API is set to "BatchAPI". Otherwise, it's None.
    """

    # build args
    class_impl = self.class_impl(project_dir)
    constructor_args = inspect.getfullargspec(class_impl.__init__).args
    config = deepcopy(self.config)
    args = {}
    if job_spec is not None and job_spec.get("config") is not None:
        util.merge_dicts_in_place_overwrite(config, job_spec["config"])
    if "config" in constructor_args:
        args["config"] = config
    if "job_spec" in constructor_args:
        args["job_spec"] = job_spec

    # initialize predictor class
    try:
        if self.type == PythonPredictorType:
            if _are_models_specified(self.api_spec):
                args["python_client"] = client
                initialized_impl = class_impl(**args)
                client.set_load_method(initialized_impl.load_model)
            else:
                initialized_impl = class_impl(**args)
        if self.type in [TensorFlowPredictorType, TensorFlowNeuronPredictorType]:
            args["tensorflow_client"] = client
            initialized_impl = class_impl(**args)
        if self.type == ONNXPredictorType:
            args["onnx_client"] = client
            initialized_impl = class_impl(**args)
    except Exception as e:
        raise UserRuntimeException(self.path, "__init__", str(e)) from e
    finally:
        refresh_logger()

    # initialize the crons if models have been specified and if the API kind is RealtimeAPI
    if _are_models_specified(self.api_spec) and self.api_spec["kind"] == "RealtimeAPI":
        if not self.multiple_processes and self.caching_enabled:
            self.crons += [
                ModelTreeUpdater(
                    interval=10,
                    api_spec=self.api_spec,
                    tree=self.models_tree,
                    ondisk_models_dir=self.model_dir,
                ),
                ModelsGC(
                    interval=10,
                    api_spec=self.api_spec,
                    models=self.models,
                    tree=self.models_tree,
                ),
            ]

        if not self.caching_enabled and self.type in [PythonPredictorType, ONNXPredictorType]:
            self.crons += [
                FileBasedModelsGC(interval=10, models=self.models, download_dir=self.model_dir)
            ]

    for cron in self.crons:
        cron.start()

    return initialized_impl
def _get_model(self, model_name: str, model_version: str) -> Any:
    """
    Checks if the versioned model is on disk, then checks if the model is in memory,
    and if it's not, it loads it into memory and returns the model.

    Args:
        model_name: Name of the model, as it's specified in predictor:models:paths or, in the other case, as it's named on disk.
        model_version: Version of the model, as it's found on disk. Can also infer the version number from the "latest" tag.

    Exceptions:
        RuntimeError: if another thread tried to load the model at the very same time.

    Returns:
        The model as returned by the self._load_model method.
        None if the model wasn't found or if it didn't pass the validation.
    """

    model = None
    tag = ""
    if model_version == "latest":
        tag = model_version

    if not self._caching_enabled:
        # determine model version
        if tag == "latest":
            model_version = self._get_latest_model_version_from_disk(model_name)
        model_id = model_name + "-" + model_version

        # grab shared access to the versioned model
        resource = os.path.join(self._lock_dir, model_id + ".txt")
        with LockedFile(resource, "r", reader_lock=True) as f:

            # check model status
            file_status = f.read()
            if file_status == "" or file_status == "not-available":
                raise WithBreak

            current_upstream_ts = int(file_status.split(" ")[1])
            update_model = False

            # grab shared access to the models holder and retrieve the model
            with LockedModel(self._models, "r", model_name, model_version):
                status, local_ts = self._models.has_model(model_name, model_version)
                if status == "not-available" or (
                    status == "in-memory" and local_ts != current_upstream_ts
                ):
                    update_model = True
                    raise WithBreak
                model, _ = self._models.get_model(model_name, model_version, tag)

            # load the model into memory and retrieve it
            if update_model:
                with LockedModel(self._models, "w", model_name, model_version):
                    status, _ = self._models.has_model(model_name, model_version)
                    if status == "not-available" or (
                        status == "in-memory" and local_ts != current_upstream_ts
                    ):
                        if status == "not-available":
                            logger.info(
                                f"loading model {model_name} of version {model_version} (thread {td.get_ident()})"
                            )
                        else:
                            logger.info(
                                f"reloading model {model_name} of version {model_version} (thread {td.get_ident()})"
                            )
                        try:
                            self._models.load_model(
                                model_name,
                                model_version,
                                current_upstream_ts,
                                [tag],
                            )
                        except Exception as e:
                            raise UserRuntimeException(
                                f"failed (re-)loading model {model_name} of version {model_version} (thread {td.get_ident()})",
                                str(e),
                            )
                    model, _ = self._models.get_model(model_name, model_version, tag)

    if not self._multiple_processes and self._caching_enabled:
        # determine model version
        try:
            if tag == "latest":
                model_version = self._get_latest_model_version_from_tree(
                    model_name, self._models_tree.model_info(model_name)
                )
        except ValueError:
            # if model_name hasn't been found
            raise UserRuntimeException(
                f"'{model_name}' model of tag latest wasn't found in the list of available models"
            )

        # grab shared access to the model tree
        available_model = True
        with LockedModelsTree(self._models_tree, "r", model_name, model_version):

            # check if the versioned model exists
            model_id = model_name + "-" + model_version
            if model_id not in self._models_tree:
                available_model = False
                raise WithBreak

            # retrieve the model tree's metadata
            upstream_model = self._models_tree[model_id]
            current_upstream_ts = int(upstream_model["timestamp"].timestamp())

        if not available_model:
            return None

        # grab shared access to the models holder and retrieve the model
        update_model = False
        with LockedModel(self._models, "r", model_name, model_version):
            status, local_ts = self._models.has_model(model_name, model_version)
            if status in ["not-available", "on-disk"] or (
                status != "not-available" and local_ts != current_upstream_ts
            ):
                update_model = True
                raise WithBreak
            model, _ = self._models.get_model(model_name, model_version, tag)

        # download the model, load it into memory, and retrieve it
        if update_model:
            # grab exclusive access to the models holder
            with LockedModel(self._models, "w", model_name, model_version):

                # check the model status
                status, local_ts = self._models.has_model(model_name, model_version)

                # refresh the disk model
                if status == "not-available" or (
                    status in ["on-disk", "in-memory"] and local_ts != current_upstream_ts
                ):
                    if status == "not-available":
                        logger.info(
                            f"model {model_name} of version {model_version} not found locally; continuing with the download..."
                        )
                    elif status == "on-disk":
                        logger.info(
                            f"found newer model {model_name} of version {model_version} on the {upstream_model['provider']} upstream than the one on the disk"
                        )
                    else:
                        logger.info(
                            f"found newer model {model_name} of version {model_version} on the {upstream_model['provider']} upstream than the one loaded into memory"
                        )

                    # remove the model from disk and memory
                    if status == "on-disk":
                        logger.info(
                            f"removing model from disk for model {model_name} of version {model_version}"
                        )
                        self._models.remove_model(model_name, model_version)
                    if status == "in-memory":
                        logger.info(
                            f"removing model from disk and memory for model {model_name} of version {model_version}"
                        )
                        self._models.remove_model(model_name, model_version)

                    # download the model
                    logger.info(
                        f"downloading model {model_name} of version {model_version} from the {upstream_model['provider']} upstream"
                    )
                    date = self._models.download_model(
                        upstream_model["provider"],
                        upstream_model["bucket"],
                        model_name,
                        model_version,
                        upstream_model["path"],
                    )
                    if not date:
                        raise WithBreak
                    current_upstream_ts = int(date.timestamp())

                # load the model
                try:
                    logger.info(
                        f"loading model {model_name} of version {model_version} into memory"
                    )
                    self._models.load_model(
                        model_name,
                        model_version,
                        current_upstream_ts,
                        [tag],
                    )
                except Exception as e:
                    raise UserRuntimeException(
                        f"failed (re-)loading model {model_name} of version {model_version} (thread {td.get_ident()})",
                        str(e),
                    )

                # retrieve the model
                model, _ = self._models.get_model(model_name, model_version, tag)

    return model
def initialize_impl(
    self,
    project_dir: str,
    client: Union[PythonClient, TensorFlowClient],
    metrics_client: DogStatsd,
    job_spec: Optional[Dict[str, Any]] = None,
    proto_module_pb2: Optional[Any] = None,
):
    """
    Initialize predictor class as provided by the user.

    job_spec is a dictionary when the "kind" of the API is set to "BatchAPI". Otherwise, it's None.
    proto_module_pb2 is a module of the compiled proto when grpc is enabled for the "RealtimeAPI" kind. Otherwise, it's None.

    Can raise UserRuntimeException/UserException/CortexException.
    """

    # build args
    class_impl = self.class_impl(project_dir)
    constructor_args = inspect.getfullargspec(class_impl.__init__).args
    config = deepcopy(self.config)
    args = {}
    if job_spec is not None and job_spec.get("config") is not None:
        util.merge_dicts_in_place_overwrite(config, job_spec["config"])
    if "config" in constructor_args:
        args["config"] = config
    if "job_spec" in constructor_args:
        args["job_spec"] = job_spec
    if "metrics_client" in constructor_args:
        args["metrics_client"] = metrics_client
    if "proto_module_pb2" in constructor_args:
        args["proto_module_pb2"] = proto_module_pb2

    # initialize predictor class
    try:
        if self.type == PythonPredictorType:
            if are_models_specified(self.api_spec):
                args["python_client"] = client
                # set the load method to enable the use of the client in the constructor;
                # setting/getting from self in load_model won't work because self will be set to None
                client.set_load_method(
                    lambda model_path: class_impl.load_model(None, model_path)
                )
                initialized_impl = class_impl(**args)
                client.set_load_method(initialized_impl.load_model)
            else:
                initialized_impl = class_impl(**args)
        if self.type in [TensorFlowPredictorType, TensorFlowNeuronPredictorType]:
            args["tensorflow_client"] = client
            initialized_impl = class_impl(**args)
    except Exception as e:
        raise UserRuntimeException(self.path, "__init__", str(e)) from e

    # initialize the crons if models have been specified and if the API kind is RealtimeAPI
    if are_models_specified(self.api_spec) and self.api_spec["kind"] == "RealtimeAPI":
        if not self.multiple_processes and self.caching_enabled:
            self.crons += [
                ModelTreeUpdater(
                    interval=10,
                    api_spec=self.api_spec,
                    tree=self.models_tree,
                    ondisk_models_dir=self.model_dir,
                ),
                ModelsGC(
                    interval=10,
                    api_spec=self.api_spec,
                    models=self.models,
                    tree=self.models_tree,
                ),
            ]

        if not self.caching_enabled and self.type == PythonPredictorType:
            self.crons += [
                FileBasedModelsGC(interval=10, models=self.models, download_dir=self.model_dir)
            ]

    for cron in self.crons:
        cron.start()

    return initialized_impl
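
# Sketch of the "pre-init load method" trick above: before the user's class is
# instantiated, the client's load method calls the *unbound* load_model with
# self=None, so models can already be loaded from inside __init__; it's then
# rebound to the real instance afterwards (UserPredictor is hypothetical):
class UserPredictor:
    def __init__(self, load):
        load("s3://bucket/model-a")  # works even though no instance exists yet

    def load_model(self, model_path):
        # note: self is None when called through the pre-init lambda
        print(f"loading {model_path} (self={self})")

load = lambda model_path: UserPredictor.load_model(None, model_path)
predictor = UserPredictor(load)
load = predictor.load_model  # mirrors client.set_load_method(initialized_impl.load_model)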