def _run_inference(self, model_input: Any, model_name: str, model_version: str) -> dict:
    """
    Resolve the requested model version and run a prediction through self._client.

    When processes_per_replica = 1 and caching enabled, check/load model and make prediction.
    When processes_per_replica > 0 and caching disabled, attempt to make prediction regardless.

    Args:
        model_input: Input to the model.
        model_name: Name of the model, as it's specified in predictor:models:paths or in the
            other case as they are named on disk.
        model_version: Version of the model, as it's found on disk. Can also infer the version
            number from the "latest" version tag.

    Returns:
        The prediction. None is returned when caching is enabled together with multiple
        processes, when the model download reports no date, or when the prediction call
        failed because the client no longer lists any version of the model.

    Raises:
        UserException: if the requested version/tag of the model cannot be found.
        UserRuntimeException: if the model name is missing from the models tree while
            resolving "latest", or if (re-)loading the model fails.
        grpc.RpcError: if the prediction fails while the client still reports available
            versions for the model.
    """

    # an empty tag means that an explicit version number was requested
    tag = "latest" if model_version == "latest" else ""

    if not self._caching_enabled:
        # determine model version
        if tag == "latest":
            versions = self._client.poll_available_model_versions(model_name)
            if not versions:
                raise UserException(
                    f"model '{model_name}' accessed with tag {tag} couldn't be found"
                )
            # compare numerically so that "10" wins over "9"
            model_version = str(max(map(int, versions)))

        return self._client.predict(model_input, model_name, model_version)

    if self._multiple_processes:
        # caching combined with multiple processes has no inference path here
        return None

    # from here on: caching enabled and a single process

    # determine model version
    if tag == "latest":
        try:
            model_version = self._get_latest_model_version_from_tree(
                model_name, self._models_tree.model_info(model_name)
            )
        except ValueError as err:
            # if model_name hasn't been found
            raise UserRuntimeException(
                f"'{model_name}' model of tag {tag} wasn't found in the list of available models"
            ) from err

    # grab shared access to model tree
    # NOTE(review): WithBreak is presumably swallowed by the lock context managers' __exit__;
    # the statements following each `with` block rely on that - confirm against their definition
    available_model = True
    logger.info(f"grabbing access to model {model_name} of version {model_version}")
    with LockedModelsTree(self._models_tree, "r", model_name, model_version):
        # check if the versioned model exists
        model_id = model_name + "-" + model_version
        if model_id not in self._models_tree:
            available_model = False
            logger.info(f"model {model_name} of version {model_version} is not available")
            raise WithBreak

        # retrieve model tree's metadata
        upstream_model = self._models_tree[model_id]
        current_upstream_ts = int(upstream_model["timestamp"].timestamp())
        logger.info(f"model {model_name} of version {model_version} is available")

    if not available_model:
        if tag == "":
            raise UserException(
                f"model '{model_name}' of version '{model_version}' couldn't be found"
            )
        raise UserException(f"model '{model_name}' accessed with tag '{tag}' couldn't be found")

    # grab shared access to models holder and retrieve model
    update_model = False
    prediction = None
    tfs_was_unresponsive = False
    with LockedModel(self._models, "r", model_name, model_version):
        logger.info(f"checking the {model_name} {model_version} status")
        status, local_ts = self._models.has_model(model_name, model_version)

        # anything not loaded, or loaded with a timestamp differing from upstream, is refreshed
        if status in ("not-available", "on-disk") or local_ts != current_upstream_ts:
            logger.info(
                f"model {model_name} of version {model_version} is not loaded (with status {status} or different timestamp)"
            )
            update_model = True
            raise WithBreak

        # run prediction
        logger.info(f"run the prediction on model {model_name} of version {model_version}")
        # the returned model is not needed here; only the call itself is kept
        self._models.get_model(model_name, model_version, tag)
        try:
            prediction = self._client.predict(model_input, model_name, model_version)
        except grpc.RpcError:
            # effectively when it got restarted
            if len(self._client.poll_available_model_versions(model_name)) > 0:
                raise
            tfs_was_unresponsive = True

    # remove model from disk and memory references if TFS gets unresponsive
    # NOTE(review): no retry follows this clean-up (update_model stays False), so the caller
    # receives None for this request - confirm this is the intended behaviour
    if tfs_was_unresponsive:
        with LockedModel(self._models, "w", model_name, model_version):
            available_versions = self._client.poll_available_model_versions(model_name)
            status, _ = self._models.has_model(model_name, model_version)
            if not (status == "in-memory" and model_version not in available_versions):
                raise WithBreak

            logger.info(
                f"removing model {model_name} of version {model_version} because TFS got unresponsive"
            )
            self._models.remove_model(model_name, model_version)

    # download, load into memory the model and retrieve it
    if update_model:
        # grab exclusive access to models holder
        with LockedModel(self._models, "w", model_name, model_version):
            # check model status again, it may have changed while waiting for the lock
            status, local_ts = self._models.has_model(model_name, model_version)

            # refresh disk model
            if status == "not-available" or (
                status in ("on-disk", "in-memory") and local_ts != current_upstream_ts
            ):
                # unload model from TFS
                if status == "in-memory":
                    try:
                        logger.info(
                            f"unloading model {model_name} of version {model_version} from TFS"
                        )
                        self._models.unload_model(model_name, model_version)
                    except Exception:
                        logger.info(
                            f"failed unloading model {model_name} of version {model_version} from TFS"
                        )
                        raise

                # remove model from disk and references
                if status in ("on-disk", "in-memory"):
                    logger.info(
                        f"removing model references from memory and from disk for model {model_name} of version {model_version}"
                    )
                    self._models.remove_model(model_name, model_version)

                # download model
                logger.info(
                    f"downloading model {model_name} of version {model_version} from the {upstream_model['provider']} upstream"
                )
                date = self._models.download_model(
                    upstream_model["provider"],
                    upstream_model["bucket"],
                    model_name,
                    model_version,
                    upstream_model["path"],
                )
                if not date:
                    # nothing was downloaded; leave the lock and return the None prediction
                    raise WithBreak
                current_upstream_ts = int(date.timestamp())

            # load model
            try:
                logger.info(f"loading model {model_name} of version {model_version} into memory")
                self._models.load_model(
                    model_name,
                    model_version,
                    current_upstream_ts,
                    [tag],
                    kwargs={
                        "model_name": model_name,
                        "model_version": model_version,
                        "signature_key": self._determine_model_signature_key(model_name),
                    },
                )
            except Exception as e:
                raise UserRuntimeException(
                    f"failed (re-)loading model {model_name} of version {model_version} (thread {td.get_ident()})",
                    str(e),
                ) from e

            # run prediction
            self._models.get_model(model_name, model_version, tag)
            prediction = self._client.predict(model_input, model_name, model_version)

    return prediction
def _get_model(self, model_name: str, model_version: str) -> Any:
    """
    Checks if versioned model is on disk, then checks if model is in memory,
    and if not, it loads it into memory, and returns the model.

    With caching disabled, the model status is read from a lock file found in self._lock_dir
    and the model is only (re)loaded into memory. With caching enabled and a single process,
    the models tree is consulted and the model is also downloaded when it is missing locally
    or when its local timestamp differs from the upstream one.

    Args:
        model_name: Name of the model, as it's specified in predictor:models:paths or in the
            other case as they are named on disk.
        model_version: Version of the model, as it's found on disk. Can also infer the version
            number from the "latest" tag.

    Raises:
        UserRuntimeException: if the model name is missing from the models tree while
            resolving "latest", or if (re-)loading the model fails.
        RuntimeError: NOTE(review) - documented previously as "if another thread tried to load
            the model at the very same time"; it is not raised directly in this method, so it
            would have to come from the models holder - confirm.

    Returns:
        The model, i.e. the first element of what self._models.get_model returns.
        None if the model wasn't found, if the download reported no date, or if caching is
        enabled together with multiple processes.
    """

    model = None
    # an empty tag means that an explicit version number was requested
    tag = "latest" if model_version == "latest" else ""

    if not self._caching_enabled:
        # determine model version
        if tag == "latest":
            model_version = self._get_latest_model_version_from_disk(model_name)
        model_id = model_name + "-" + model_version

        # grab shared access to versioned model
        # NOTE(review): WithBreak is presumably swallowed by the lock context managers'
        # __exit__; the statements following each `with` block rely on that - confirm
        resource = os.path.join(self._lock_dir, model_id + ".txt")
        with LockedFile(resource, "r", reader_lock=True) as f:
            # check model status
            file_status = f.read()
            if file_status in ("", "not-available"):
                raise WithBreak

            # the second space-separated field of the file is used as the upstream timestamp
            current_upstream_ts = int(file_status.split(" ")[1])
            update_model = False

            # grab shared access to models holder and retrieve model
            with LockedModel(self._models, "r", model_name, model_version):
                status, local_ts = self._models.has_model(model_name, model_version)
                if status == "not-available" or (
                    status == "in-memory" and local_ts != current_upstream_ts
                ):
                    update_model = True
                    raise WithBreak
                model, _ = self._models.get_model(model_name, model_version, tag)

            # load model into memory and retrieve it
            if update_model:
                with LockedModel(self._models, "w", model_name, model_version):
                    # re-read BOTH status and timestamp under the exclusive lock: the model may
                    # have been loaded by another thread since the shared-lock check, and
                    # comparing the earlier timestamp would trigger a needless reload
                    status, local_ts = self._models.has_model(model_name, model_version)
                    if status == "not-available" or (
                        status == "in-memory" and local_ts != current_upstream_ts
                    ):
                        if status == "not-available":
                            logger.info(
                                f"loading model {model_name} of version {model_version} (thread {td.get_ident()})"
                            )
                        else:
                            logger.info(
                                f"reloading model {model_name} of version {model_version} (thread {td.get_ident()})"
                            )
                        try:
                            self._models.load_model(
                                model_name,
                                model_version,
                                current_upstream_ts,
                                [tag],
                            )
                        except Exception as e:
                            raise UserRuntimeException(
                                f"failed (re-)loading model {model_name} of version {model_version} (thread {td.get_ident()})",
                                str(e),
                            ) from e
                    model, _ = self._models.get_model(model_name, model_version, tag)

        return model

    if self._multiple_processes:
        # caching combined with multiple processes has no retrieval path here
        return model

    # from here on: caching enabled and a single process

    # determine model version
    if tag == "latest":
        try:
            model_version = self._get_latest_model_version_from_tree(
                model_name, self._models_tree.model_info(model_name)
            )
        except ValueError as err:
            # if model_name hasn't been found
            raise UserRuntimeException(
                f"'{model_name}' model of tag latest wasn't found in the list of available models"
            ) from err

    # grab shared access to model tree
    available_model = True
    with LockedModelsTree(self._models_tree, "r", model_name, model_version):
        # check if the versioned model exists
        model_id = model_name + "-" + model_version
        if model_id not in self._models_tree:
            available_model = False
            raise WithBreak

        # retrieve model tree's metadata
        upstream_model = self._models_tree[model_id]
        current_upstream_ts = int(upstream_model["timestamp"].timestamp())

    if not available_model:
        return None

    # grab shared access to models holder and retrieve model
    update_model = False
    with LockedModel(self._models, "r", model_name, model_version):
        status, local_ts = self._models.has_model(model_name, model_version)

        # anything not loaded, or loaded with a timestamp differing from upstream, is refreshed
        if status in ("not-available", "on-disk") or local_ts != current_upstream_ts:
            update_model = True
            raise WithBreak
        model, _ = self._models.get_model(model_name, model_version, tag)

    # download, load into memory the model and retrieve it
    if update_model:
        # grab exclusive access to models holder
        with LockedModel(self._models, "w", model_name, model_version):
            # check model status again, it may have changed while waiting for the lock
            status, local_ts = self._models.has_model(model_name, model_version)

            # refresh disk model
            if status == "not-available" or (
                status in ("on-disk", "in-memory") and local_ts != current_upstream_ts
            ):
                if status == "not-available":
                    logger.info(
                        f"model {model_name} of version {model_version} not found locally; continuing with the download..."
                    )
                elif status == "on-disk":
                    logger.info(
                        f"found newer model {model_name} of vesion {model_version} on the {upstream_model['provider']} upstream than the one on the disk"
                    )
                else:
                    logger.info(
                        f"found newer model {model_name} of vesion {model_version} on the {upstream_model['provider']} upstream than the one loaded into memory"
                    )

                # remove model from disk and memory
                if status == "on-disk":
                    logger.info(
                        f"removing model from disk for model {model_name} of version {model_version}"
                    )
                    self._models.remove_model(model_name, model_version)
                elif status == "in-memory":
                    logger.info(
                        f"removing model from disk and memory for model {model_name} of version {model_version}"
                    )
                    self._models.remove_model(model_name, model_version)

                # download model
                logger.info(
                    f"downloading model {model_name} of version {model_version} from the {upstream_model['provider']} upstream"
                )
                date = self._models.download_model(
                    upstream_model["provider"],
                    upstream_model["bucket"],
                    model_name,
                    model_version,
                    upstream_model["path"],
                )
                if not date:
                    # nothing was downloaded; leave the lock and return None
                    raise WithBreak
                current_upstream_ts = int(date.timestamp())

            # load model
            try:
                logger.info(f"loading model {model_name} of version {model_version} into memory")
                self._models.load_model(
                    model_name,
                    model_version,
                    current_upstream_ts,
                    [tag],
                )
            except Exception as e:
                raise UserRuntimeException(
                    f"failed (re-)loading model {model_name} of version {model_version} (thread {td.get_ident()})",
                    str(e),
                ) from e

            # retrieve model
            model, _ = self._models.get_model(model_name, model_version, tag)

    return model