def start():
    while not pathlib.Path("/mnt/workspace/init_script_run.txt").is_file():
        time.sleep(0.2)

    cache_dir = os.environ["CORTEX_CACHE_DIR"]
    provider = os.environ["CORTEX_PROVIDER"]
    api_spec_path = os.environ["CORTEX_API_SPEC"]
    job_spec_path = os.environ["CORTEX_JOB_SPEC"]
    project_dir = os.environ["CORTEX_PROJECT_DIR"]

    model_dir = os.getenv("CORTEX_MODEL_DIR")
    tf_serving_port = os.getenv("CORTEX_TF_BASE_SERVING_PORT", "9000")
    tf_serving_host = os.getenv("CORTEX_TF_SERVING_HOST", "localhost")

    bucket = os.getenv("CORTEX_BUCKET")
    region = os.getenv("AWS_REGION")

    has_multiple_servers = os.getenv("CORTEX_MULTIPLE_TF_SERVERS")
    if has_multiple_servers:
        with LockedFile("/run/used_ports.json", "r+") as f:
            used_ports = json.load(f)
            for port in used_ports.keys():
                if not used_ports[port]:
                    tf_serving_port = port
                    used_ports[port] = True
                    break
            f.seek(0)
            json.dump(used_ports, f)
            f.truncate()

    api = get_api(provider, api_spec_path, model_dir, cache_dir, bucket, region)
    storage, api_spec = get_spec(provider, api_spec_path, cache_dir, bucket, region)
    job_spec = get_job_spec(storage, cache_dir, job_spec_path)

    client = api.predictor.initialize_client(
        tf_serving_host=tf_serving_host, tf_serving_port=tf_serving_port
    )

    logger().info("loading the predictor from {}".format(api.predictor.path))
    predictor_impl = api.predictor.initialize_impl(project_dir, client, job_spec)

    local_cache["api_spec"] = api
    local_cache["provider"] = provider
    local_cache["job_spec"] = job_spec
    local_cache["predictor_impl"] = predictor_impl
    local_cache["predict_fn_args"] = inspect.getfullargspec(predictor_impl.predict).args
    local_cache["sqs_client"] = boto3.client("sqs", region_name=region)

    open("/mnt/workspace/api_readiness.txt", "a").close()

    logger().info("polling for batches...")
    sqs_loop()
def garbage_collect(
    self, exclude_disk_model_ids: List[str] = [], dry_run: bool = False
) -> Tuple[bool, List[str], List[str]]:
    """
    Removes stale in-memory and on-disk models based on LRU policy.
    Also calls the "remove" callback before removing the models from this object.
    The callback must not raise any exceptions.

    Must be called with a write lock unless dry_run is set to true.

    Args:
        exclude_disk_model_ids: Model IDs to exclude from removing from disk. Necessary for locally-provided models.
        dry_run: Just test if there are any models to remove. If set to true, this method can then be called with a read lock.

    Returns:
        A 3-element tuple. First element tells whether models had to be collected.
        The 2nd and 3rd elements contain the model IDs that were removed from memory and disk respectively.
    """
    collected = False
    if self._mem_cache_size <= 0 or self._disk_cache_size <= 0:
        return collected, [], []

    stale_mem_model_ids = self._lru_model_ids(self._mem_cache_size, filter_in_mem=True)
    stale_disk_model_ids = self._lru_model_ids(
        self._disk_cache_size - len(exclude_disk_model_ids), filter_in_mem=False
    )

    if self._remove_callback and not dry_run:
        self._remove_callback(stale_mem_model_ids)

    # don't delete excluded model IDs from disk
    stale_disk_model_ids = list(set(stale_disk_model_ids) - set(exclude_disk_model_ids))
    stale_disk_model_ids = stale_disk_model_ids[
        len(stale_disk_model_ids) - self._disk_cache_size :
    ]

    if not dry_run:
        logger().info(
            f"unloading models {stale_mem_model_ids} from memory using the garbage collector"
        )
        logger().info(
            f"unloading models {stale_disk_model_ids} from disk using the garbage collector"
        )
        for model_id in stale_mem_model_ids:
            self.remove_model_by_id(model_id, mem=True, disk=False)
        for model_id in stale_disk_model_ids:
            self.remove_model_by_id(model_id, mem=False, disk=True)

    if len(stale_mem_model_ids) > 0 or len(stale_disk_model_ids) > 0:
        collected = True

    return collected, stale_mem_model_ids, stale_disk_model_ids
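
# Illustrative usage sketch (names are hypothetical): per the docstring above, a caller can
# probe with dry_run=True while holding only a read lock, and perform the actual collection
# only under a write lock.
#
#   would_collect, _, _ = models_holder.garbage_collect(dry_run=True)  # read lock is enough here
#   if would_collect:
#       # acquire the write lock before this call
#       collected, removed_mem_ids, removed_disk_ids = models_holder.garbage_collect(
#           exclude_disk_model_ids=["iris-1"],  # hypothetical locally-provided model ID to keep on disk
#       )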
def post_metrics(self, metrics):
    try:
        if self.statsd is None:
            raise CortexException("statsd client not initialized")  # unexpected

        for metric in metrics:
            tags = ["{}:{}".format(dim["Name"], dim["Value"]) for dim in metric["Dimensions"]]
            if metric.get("Unit") == "Count":
                self.statsd.increment(metric["MetricName"], value=metric["Value"], tags=tags)
            else:
                self.statsd.histogram(metric["MetricName"], value=metric["Value"], tags=tags)
    except:
        logger().warn("failure encountered while publishing metrics", exc_info=True)
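
# Illustrative example (hypothetical metric values): post_metrics above expects
# CloudWatch-style metric dicts, e.g.
#
#   {
#       "MetricName": "cortex_batch_succeeded",                     # hypothetical name
#       "Dimensions": [{"Name": "APIName", "Value": "my-api"}],     # hypothetical dimension
#       "Unit": "Count",
#       "Value": 1,
#   }
#
# Because Unit == "Count", this would be forwarded as
# statsd.increment("cortex_batch_succeeded", value=1, tags=["APIName:my-api"]);
# any other unit goes through statsd.histogram instead.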
def start(args):
    download_config = json.loads(base64.urlsafe_b64decode(args.download))
    for download_arg in download_config["download_args"]:
        from_path = download_arg["from"]
        to_path = download_arg["to"]
        item_name = download_arg.get("item_name", "")
        bucket_name, prefix = S3.deconstruct_s3_path(from_path)
        s3_client = S3(bucket_name, client_config={})

        if item_name != "":
            if download_arg.get("hide_from_log", False):
                logger().info("downloading {}".format(item_name))
            else:
                logger().info("downloading {} from {}".format(item_name, from_path))

        if download_arg.get("to_file", False):
            s3_client.download_file(prefix, to_path)
        else:
            s3_client.download(prefix, to_path)

        if download_arg.get("unzip", False):
            if item_name != "" and not download_arg.get("hide_unzipping_log", False):
                logger().info("unzipping {}".format(item_name))
            if download_arg.get("to_file", False):
                util.extract_zip(to_path, delete_zip_file=True)
            else:
                util.extract_zip(
                    os.path.join(to_path, os.path.basename(from_path)), delete_zip_file=True
                )

    if download_config.get("last_log", "") != "":
        logger().info(download_config["last_log"])
def handle_on_complete(message):
    job_spec = local_cache["job_spec"]
    predictor_impl = local_cache["predictor_impl"]
    sqs_client = local_cache["sqs_client"]
    queue_url = job_spec["sqs_url"]
    receipt_handle = message["ReceiptHandle"]

    try:
        if not getattr(predictor_impl, "on_job_complete", None):
            sqs_client.delete_message(QueueUrl=queue_url, ReceiptHandle=receipt_handle)
            return True

        should_run_on_job_complete = False

        while True:
            visible_count, not_visible_count = get_total_messages_in_queue()

            # if there are other messages that are visible, release this message and get the other ones (should rarely happen for FIFO)
            if visible_count > 0:
                sqs_client.change_message_visibility(
                    QueueUrl=queue_url, ReceiptHandle=receipt_handle, VisibilityTimeout=0
                )
                return False

            if should_run_on_job_complete:
                # double check that the queue is still empty (except for the job_complete message)
                if not_visible_count <= 1:
                    logger().info("executing on_job_complete")
                    predictor_impl.on_job_complete()
                    sqs_client.delete_message(QueueUrl=queue_url, ReceiptHandle=receipt_handle)
                    return True
                else:
                    should_run_on_job_complete = False

            if not_visible_count <= 1:
                should_run_on_job_complete = True

            time.sleep(20)
    except:
        sqs_client.delete_message(QueueUrl=queue_url, ReceiptHandle=receipt_handle)
        raise
def predict(request: Request):
    tasks = BackgroundTasks()
    api = local_cache["api"]
    predictor_impl = local_cache["predictor_impl"]
    kwargs = build_predict_kwargs(request)

    prediction = predictor_impl.predict(**kwargs)

    if isinstance(prediction, bytes):
        response = Response(content=prediction, media_type="application/octet-stream")
    elif isinstance(prediction, str):
        response = Response(content=prediction, media_type="text/plain")
    elif isinstance(prediction, Response):
        response = prediction
    else:
        try:
            json_string = json.dumps(prediction)
        except Exception as e:
            raise UserRuntimeException(
                str(e),
                "please return an object that is JSON serializable (including its nested fields), a bytes object, a string, or a starlette.response.Response object",
            ) from e
        response = Response(content=json_string, media_type="application/json")

    if local_cache["provider"] != "local" and api.monitoring is not None:
        try:
            predicted_value = api.monitoring.extract_predicted_value(prediction)
            api.post_monitoring_metrics(predicted_value)
            if (
                api.monitoring.model_type == "classification"
                and predicted_value not in local_cache["class_set"]
            ):
                tasks.add_task(api.upload_class, class_name=predicted_value)
                local_cache["class_set"].add(predicted_value)
        except:
            logger().warn("unable to record prediction metric", exc_info=True)

    if util.has_method(predictor_impl, "post_predict"):
        kwargs = build_post_predict_kwargs(prediction, request)
        request_thread_pool.submit(predictor_impl.post_predict, **kwargs)

    if len(tasks.tasks) > 0:
        response.background = tasks

    return response
def validate_models_dir_paths(
    paths: List[str], predictor_type: PredictorType, common_prefix: str
) -> Tuple[List[str], List[List[int]]]:
    """
    Validates the models paths based on the given predictor type.
    To be used when predictor:models:dir in cortex.yaml is used.

    Args:
        paths: A list of all paths for a given S3/local prefix. Must be underneath the common prefix.
        predictor_type: The predictor type.
        common_prefix: The common prefix of the directory which holds all models. AKA predictor:models:dir.

    Returns:
        List with the prefix of each model that's valid.
        List with the OneOfAllPlaceholder IDs validated for each valid model.
    """
    if len(paths) == 0:
        raise CortexException(
            f"{predictor_type} predictor at '{common_prefix}'", "model top path can't be empty"
        )

    rel_paths = [os.path.relpath(top_path, common_prefix) for top_path in paths]
    rel_paths = [path for path in rel_paths if not path.startswith("../")]

    model_names = [util.get_leftmost_part_of_path(path) for path in rel_paths]
    model_names = list(set(model_names))

    valid_model_prefixes = []
    ooa_valid_key_ids = []
    for model_name in model_names:
        try:
            ooa_valid_key_ids.append(validate_model_paths(rel_paths, predictor_type, model_name))
            valid_model_prefixes.append(os.path.join(common_prefix, model_name))
        except CortexException as e:
            logger().debug(f"failed validating model {model_name}: {str(e)}")
            continue

    return valid_model_prefixes, ooa_valid_key_ids
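
# Illustrative sketch (hypothetical paths and layout): given a predictor:models:dir prefix
# of "models" that contains two models,
#
#   models/iris/1/saved_model.pb
#   models/sentiment/2/saved_model.pb
#
# validate_models_dir_paths would be called roughly as
#
#   valid_prefixes, ooa_ids = validate_models_dir_paths(
#       paths=["models/iris/1/saved_model.pb", "models/sentiment/2/saved_model.pb"],
#       predictor_type=predictor_type,  # whichever PredictorType the API uses
#       common_prefix="models",
#   )
#
# and, assuming both models pass validate_model_paths, valid_prefixes would be
# ["models/iris", "models/sentiment"] with one list of validated placeholder IDs per model.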
def _remove_models(self, model_ids: List[str]) -> None:
    """
    Remove models from TFS.
    Must only be used when caching enabled.
    """
    logger().info(f"unloading models with model IDs {model_ids} from TFS")

    models = {}
    for model_id in model_ids:
        model_name, model_version = model_id.rsplit("-", maxsplit=1)
        if model_name not in models:
            models[model_name] = [model_version]
        else:
            models[model_name].append(model_version)

    model_names = []
    model_versions = []

    for model_name, versions in models.items():
        model_names.append(model_name)
        model_versions.append(versions)

    self._client.remove_models(model_names, model_versions)
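
# Illustrative example (hypothetical IDs): _remove_models groups "name-version" IDs by model
# name before handing them to the TFS client, e.g.
#
#   ["iris-1", "iris-2", "sentiment-3"]
#
# becomes
#
#   model_names    = ["iris", "sentiment"]
#   model_versions = [["1", "2"], ["3"]]
#
# Note that rsplit("-", maxsplit=1) splits on the last dash only, so model names may
# themselves contain dashes (e.g. "text-gen-5" -> ("text-gen", "5")).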
def start(args):
    download_config = json.loads(base64.urlsafe_b64decode(args.download))
    for download_arg in download_config["download_args"]:
        from_path = download_arg["from"]
        to_path = download_arg["to"]
        item_name = download_arg.get("item_name", "")

        if from_path.startswith("s3://"):
            bucket_name, prefix = S3.deconstruct_s3_path(from_path)
            client = S3(bucket_name, client_config={})
        elif from_path.startswith("gs://"):
            bucket_name, prefix = GCS.deconstruct_gcs_path(from_path)
            client = GCS(bucket_name)
        else:
            raise ValueError('"from" download arg can either have the "s3://" or "gs://" prefixes')

        if item_name != "":
            if download_arg.get("hide_from_log", False):
                logger().info("downloading {}".format(item_name))
            else:
                logger().info("downloading {} from {}".format(item_name, from_path))

        if download_arg.get("to_file", False):
            client.download_file(prefix, to_path)
        else:
            client.download(prefix, to_path)

        if download_arg.get("unzip", False):
            if item_name != "" and not download_arg.get("hide_unzipping_log", False):
                logger().info("unzipping {}".format(item_name))
            if download_arg.get("to_file", False):
                util.extract_zip(to_path, delete_zip_file=True)
            else:
                util.extract_zip(
                    os.path.join(to_path, os.path.basename(from_path)), delete_zip_file=True
                )

    if download_config.get("last_log", "") != "":
        logger().info(download_config["last_log"])
def _get_model(self, model_name: str, model_version: str) -> Any:
    """
    Checks if versioned model is on disk, then checks if model is in memory,
    and if not, it loads it into memory, and returns the model.

    Args:
        model_name: Name of the model, as it's specified in predictor:models:paths or in the other case as they are named on disk.
        model_version: Version of the model, as it's found on disk. Can also infer the version number from the "latest" version tag.

    Exceptions:
        RuntimeError: if another thread tried to load the model at the very same time.

    Returns:
        The model as returned by self._load_model method.
        None if the model wasn't found or if it didn't pass the validation.
    """
    model = None
    tag = ""
    if model_version == "latest":
        tag = model_version

    if not self._caching_enabled:
        # determine model version
        if tag == "latest":
            model_version = self._get_latest_model_version_from_disk(model_name)
        model_id = model_name + "-" + model_version

        # grab shared access to versioned model
        resource = os.path.join(self._lock_dir, model_id + ".txt")
        with LockedFile(resource, "r", reader_lock=True) as f:
            # check model status
            file_status = f.read()
            if file_status == "" or file_status == "not-available":
                raise WithBreak

            current_upstream_ts = int(file_status.split(" ")[1])
            update_model = False

            # grab shared access to models holder and retrieve model
            with LockedModel(self._models, "r", model_name, model_version):
                status, local_ts = self._models.has_model(model_name, model_version)
                if status == "not-available" or (
                    status == "in-memory" and local_ts != current_upstream_ts
                ):
                    update_model = True
                    raise WithBreak
                model, _ = self._models.get_model(model_name, model_version, tag)

            # load model into memory and retrieve it
            if update_model:
                with LockedModel(self._models, "w", model_name, model_version):
                    status, _ = self._models.has_model(model_name, model_version)
                    if status == "not-available" or (
                        status == "in-memory" and local_ts != current_upstream_ts
                    ):
                        if status == "not-available":
                            logger().info(
                                f"loading model {model_name} of version {model_version} (thread {td.get_ident()})"
                            )
                        else:
                            logger().info(
                                f"reloading model {model_name} of version {model_version} (thread {td.get_ident()})"
                            )
                        try:
                            self._models.load_model(
                                model_name,
                                model_version,
                                current_upstream_ts,
                                [tag],
                            )
                        except Exception as e:
                            raise UserRuntimeException(
                                f"failed (re-)loading model {model_name} of version {model_version} (thread {td.get_ident()})",
                                str(e),
                            )
                    model, _ = self._models.get_model(model_name, model_version, tag)

    if not self._multiple_processes and self._caching_enabled:
        # determine model version
        try:
            if tag == "latest":
                model_version = self._get_latest_model_version_from_tree(
                    model_name, self._models_tree.model_info(model_name)
                )
        except ValueError:
            # if model_name hasn't been found
            raise UserRuntimeException(
                f"'{model_name}' model of tag {tag} wasn't found in the list of available models"
            )

        # grab shared access to model tree
        available_model = True
        with LockedModelsTree(self._models_tree, "r", model_name, model_version):
            # check if the versioned model exists
            model_id = model_name + "-" + model_version
            if model_id not in self._models_tree:
                available_model = False
                raise WithBreak

            # retrieve model tree's metadata
            upstream_model = self._models_tree[model_id]
            current_upstream_ts = int(upstream_model["timestamp"].timestamp())

        if not available_model:
            return None

        # grab shared access to models holder and retrieve model
        update_model = False
        with LockedModel(self._models, "r", model_name, model_version):
            status, local_ts = self._models.has_model(model_name, model_version)
            if status in ["not-available", "on-disk"] or (
                status != "not-available"
                and local_ts != current_upstream_ts
                and not (status == "in-memory" and model_name in self._spec_local_model_names)
            ):
                update_model = True
                raise WithBreak
            model, _ = self._models.get_model(model_name, model_version, tag)

        # download, load into memory the model and retrieve it
        if update_model:
            # grab exclusive access to models holder
            with LockedModel(self._models, "w", model_name, model_version):
                # check model status
                status, local_ts = self._models.has_model(model_name, model_version)

                # refresh disk model
                if model_name not in self._spec_local_model_names and (
                    status == "not-available"
                    or (status in ["on-disk", "in-memory"] and local_ts != current_upstream_ts)
                ):
                    if status == "not-available":
                        logger().info(
                            f"model {model_name} of version {model_version} not found locally; continuing with the download..."
                        )
                    elif status == "on-disk":
                        logger().info(
                            f"found newer model {model_name} of version {model_version} on the {upstream_model['provider']} upstream than the one on the disk"
                        )
                    else:
                        logger().info(
                            f"found newer model {model_name} of version {model_version} on the {upstream_model['provider']} upstream than the one loaded into memory"
                        )

                    # remove model from disk and memory
                    if status == "on-disk":
                        logger().info(
                            f"removing model from disk for model {model_name} of version {model_version}"
                        )
                        self._models.remove_model(model_name, model_version)
                    if status == "in-memory":
                        logger().info(
                            f"removing model from disk and memory for model {model_name} of version {model_version}"
                        )
                        self._models.remove_model(model_name, model_version)

                    # download model
                    logger().info(
                        f"downloading model {model_name} of version {model_version} from the {upstream_model['provider']} upstream"
                    )
                    date = self._models.download_model(
                        upstream_model["provider"],
                        upstream_model["bucket"],
                        model_name,
                        model_version,
                        upstream_model["path"],
                    )
                    if not date:
                        raise WithBreak
                    current_upstream_ts = int(date.timestamp())

                # give the local model a timestamp initialized at start time
                if model_name in self._spec_local_model_names:
                    current_upstream_ts = self._local_model_ts

                # load model
                try:
                    logger().info(
                        f"loading model {model_name} of version {model_version} into memory"
                    )
                    self._models.load_model(
                        model_name,
                        model_version,
                        current_upstream_ts,
                        [tag],
                    )
                except Exception as e:
                    raise UserRuntimeException(
                        f"failed (re-)loading model {model_name} of version {model_version} (thread {td.get_ident()})",
                        str(e),
                    )

                # retrieve model
                model, _ = self._models.get_model(model_name, model_version, tag)

    return model
def sqs_loop():
    job_spec = local_cache["job_spec"]
    api_spec = local_cache["api_spec"]
    predictor_impl = local_cache["predictor_impl"]
    sqs_client = local_cache["sqs_client"]

    queue_url = job_spec["sqs_url"]

    no_messages_found_in_previous_iteration = False

    while True:
        response = sqs_client.receive_message(
            QueueUrl=queue_url,
            MaxNumberOfMessages=1,
            WaitTimeSeconds=10,
            VisibilityTimeout=MAXIMUM_MESSAGE_VISIBILITY,
            MessageAttributeNames=["All"],
        )

        if response.get("Messages") is None or len(response["Messages"]) == 0:
            if no_messages_found_in_previous_iteration:
                logger().info("no batches left in queue, exiting...")
                return
            else:
                no_messages_found_in_previous_iteration = True
                continue
        else:
            no_messages_found_in_previous_iteration = False

        message = response["Messages"][0]
        receipt_handle = message["ReceiptHandle"]

        if "MessageAttributes" in message and "job_complete" in message["MessageAttributes"]:
            handled_on_complete = handle_on_complete(message)
            if handled_on_complete:
                logger().info("no batches left in queue, job has been completed")
                return
            else:
                # sometimes on_job_complete message will be released if there are other messages still to be processed
                continue

        try:
            logger().info(f"processing batch {message['MessageId']}")

            start_time = time.time()

            payload = json.loads(message["Body"])
            batch_id = message["MessageId"]
            predictor_impl.predict(**build_predict_args(payload, batch_id))

            api_spec.post_metrics(
                [success_counter_metric(), time_per_batch_metric(time.time() - start_time)]
            )
        except Exception:
            api_spec.post_metrics(
                [failed_counter_metric(), time_per_batch_metric(time.time() - start_time)]
            )
            logger().exception("failed to process batch")
        finally:
            sqs_client.delete_message(QueueUrl=queue_url, ReceiptHandle=receipt_handle)
def model_downloader(
    predictor_type: PredictorType,
    bucket_provider: str,
    bucket_name: str,
    model_name: str,
    model_version: str,
    model_path: str,
    temp_dir: str,
    model_dir: str,
) -> Optional[datetime.datetime]:
    """
    Downloads model to disk. Validates the cloud model path and the downloaded model as well.

    Args:
        predictor_type: The predictor type as implemented by the API.
        bucket_provider: Provider for the bucket. Can be "s3" or "gs".
        bucket_name: Name of the bucket where the model is stored.
        model_name: Name of the model. Is part of the model's local path.
        model_version: Version of the model. Is part of the model's local path.
        model_path: Model prefix of the versioned model.
        temp_dir: Where to temporarily store the model for validation.
        model_dir: The top directory of where all models are stored locally.

    Returns:
        The model's timestamp. None if the model didn't pass the validation, if it doesn't exist or if there are not enough permissions.
    """
    logger().info(
        f"downloading from bucket {bucket_name}/{model_path}, model {model_name} of version {model_version}, temporarily to {temp_dir} and then finally to {model_dir}"
    )

    if bucket_provider == "s3":
        client = S3(bucket_name)
    if bucket_provider == "gs":
        client = GCS(bucket_name)

    # validate upstream cloud model
    sub_paths, ts = client.search(model_path)
    try:
        validate_model_paths(sub_paths, predictor_type, model_path)
    except CortexException:
        logger().info(f"failed validating model {model_name} of version {model_version}")
        return None

    # download model to temp dir
    temp_dest = os.path.join(temp_dir, model_name, model_version)
    try:
        client.download_dir_contents(model_path, temp_dest)
    except CortexException:
        logger().info(
            f"failed downloading model {model_name} of version {model_version} to temp dir {temp_dest}"
        )
        shutil.rmtree(temp_dest)
        return None

    # validate model
    model_contents = glob.glob(temp_dest + "*/**", recursive=True)
    model_contents = util.remove_non_empty_directory_paths(model_contents)
    try:
        validate_model_paths(model_contents, predictor_type, temp_dest)
    except CortexException:
        logger().info(
            f"failed validating model {model_name} of version {model_version} from temp dir"
        )
        shutil.rmtree(temp_dest)
        return None

    # move model to dest dir
    model_top_dir = os.path.join(model_dir, model_name)
    ondisk_model_version = os.path.join(model_top_dir, model_version)
    logger().info(
        f"moving model {model_name} of version {model_version} to final dir {ondisk_model_version}"
    )
    if os.path.isdir(ondisk_model_version):
        shutil.rmtree(ondisk_model_version)
    shutil.move(temp_dest, ondisk_model_version)

    return max(ts)
def _extract_signatures(
    self, signature_def, signature_key, model_name: str, model_version: str
):
    logger().info(
        "signature defs found in model '{}' for version '{}': {}".format(
            model_name, model_version, signature_def
        )
    )

    available_keys = list(signature_def.keys())
    if len(available_keys) == 0:
        raise UserException(
            "unable to find signature defs in model '{}' of version '{}'".format(
                model_name, model_version
            )
        )

    if signature_key is None:
        if len(available_keys) == 1:
            logger().info(
                "signature_key was not configured by user, using signature key '{}' for model '{}' of version '{}' (found in the signature def map)".format(
                    available_keys[0],
                    model_name,
                    model_version,
                )
            )
            signature_key = available_keys[0]
        elif "predict" in signature_def:
            logger().info(
                "signature_key was not configured by user, using signature key 'predict' for model '{}' of version '{}' (found in the signature def map)".format(
                    model_name,
                    model_version,
                )
            )
            signature_key = "predict"
        else:
            raise UserException(
                "signature_key was not configured by user, please specify one of the following keys '{}' for model '{}' of version '{}' (found in the signature def map)".format(
                    ", ".join(available_keys), model_name, model_version
                )
            )
    else:
        if signature_def.get(signature_key) is None:
            possibilities_str = "key: '{}'".format(available_keys[0])
            if len(available_keys) > 1:
                possibilities_str = "keys: '{}'".format("', '".join(available_keys))

            raise UserException(
                "signature_key '{}' was not found in signature def map for model '{}' of version '{}', but found the following {}".format(
                    signature_key, model_name, model_version, possibilities_str
                )
            )

    signature_def_val = signature_def.get(signature_key)

    if signature_def_val.get("inputs") is None:
        raise UserException(
            "unable to find 'inputs' in signature def '{}' for model '{}'".format(
                signature_key, model_name
            )
        )

    parsed_signatures = {}
    for input_name, input_metadata in signature_def_val["inputs"].items():
        if input_metadata["tensorShape"] == {}:
            # a scalar with rank 0 and empty shape
            shape = "scalar"
        elif input_metadata["tensorShape"].get("unknownRank", False):
            # unknown rank and shape
            #
            # unknownRank is set to True if the model input has no rank
            # it may lead to an undefined behavior if unknownRank is only checked for its presence
            # so it also gets to be tested against its value
            shape = "unknown"
        elif input_metadata["tensorShape"].get("dim", None):
            # known rank and known/unknown shape
            shape = [int(dim["size"]) for dim in input_metadata["tensorShape"]["dim"]]
        else:
            raise UserException(
                "invalid 'tensorShape' specification for input '{}' in signature key '{}' for model '{}'".format(
                    input_name, signature_key, model_name
                )
            )

        parsed_signatures[input_name] = {
            "shape": shape if type(shape) == list else [shape],
            "type": DTYPE_TO_TF_TYPE[input_metadata["dtype"]].name,
        }
    return signature_key, parsed_signatures
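
# Illustrative example (hypothetical signature def): for a TF Serving metadata entry such as
#
#   {"serving_default": {"inputs": {"input_1": {
#       "dtype": "DT_FLOAT",
#       "tensorShape": {"dim": [{"size": "-1"}, {"size": "28"}, {"size": "28"}]}}}}}
#
# _extract_signatures above would pick "serving_default" (the only available key) and,
# assuming DTYPE_TO_TF_TYPE maps "DT_FLOAT" to tf.float32, return roughly
#
#   ("serving_default", {"input_1": {"shape": [-1, 28, 28], "type": "float32"}})
#
# where -1 marks the unknown batch dimension.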
def _run_inference(self, model_input: Any, model_name: str, model_version: str) -> dict:
    """
    When processes_per_replica = 1 and caching enabled, check/load model and make prediction.
    When processes_per_replica > 0 and caching disabled, attempt to make prediction regardless.

    Args:
        model_input: Input to the model.
        model_name: Name of the model, as it's specified in predictor:models:paths or in the other case as they are named on disk.
        model_version: Version of the model, as it's found on disk. Can also infer the version number from the "latest" version tag.

    Returns:
        The prediction.
    """
    model = None
    tag = ""
    if model_version == "latest":
        tag = model_version

    if not self._caching_enabled:
        # determine model version
        if tag == "latest":
            versions = self._client.poll_available_model_versions(model_name)
            if len(versions) == 0:
                raise UserException(
                    f"model '{model_name}' accessed with tag {tag} couldn't be found"
                )
            model_version = str(max(map(lambda x: int(x), versions)))
        model_id = model_name + "-" + model_version

        return self._client.predict(model_input, model_name, model_version)

    if not self._multiple_processes and self._caching_enabled:
        # determine model version
        try:
            if tag == "latest":
                model_version = self._get_latest_model_version_from_tree(
                    model_name, self._models_tree.model_info(model_name)
                )
        except ValueError:
            # if model_name hasn't been found
            raise UserRuntimeException(
                f"'{model_name}' model of tag {tag} wasn't found in the list of available models"
            )

        models_stats = []
        for model_id in self._models.get_model_ids():
            models_stats = self._models.has_model_id(model_id)

        # grab shared access to model tree
        available_model = True
        logger().info(f"grabbing access to model {model_name} of version {model_version}")
        with LockedModelsTree(self._models_tree, "r", model_name, model_version):

            # check if the versioned model exists
            model_id = model_name + "-" + model_version
            if model_id not in self._models_tree:
                available_model = False
                logger().info(f"model {model_name} of version {model_version} is not available")
                raise WithBreak

            # retrieve model tree's metadata
            upstream_model = self._models_tree[model_id]
            current_upstream_ts = int(upstream_model["timestamp"].timestamp())
            logger().info(f"model {model_name} of version {model_version} is available")

        if not available_model:
            if tag == "":
                raise UserException(
                    f"model '{model_name}' of version '{model_version}' couldn't be found"
                )
            raise UserException(
                f"model '{model_name}' accessed with tag '{tag}' couldn't be found"
            )

        # grab shared access to models holder and retrieve model
        update_model = False
        prediction = None
        tfs_was_unresponsive = False
        with LockedModel(self._models, "r", model_name, model_version):
            logger().info(f"checking the {model_name} {model_version} status")
            status, local_ts = self._models.has_model(model_name, model_version)
            if status in ["not-available", "on-disk"] or (
                status != "not-available" and local_ts != current_upstream_ts
            ):
                logger().info(
                    f"model {model_name} of version {model_version} is not loaded (with status {status} or different timestamp)"
                )
                update_model = True
                raise WithBreak

            # run prediction
            logger().info(
                f"run the prediction on model {model_name} of version {model_version}"
            )
            self._models.get_model(model_name, model_version, tag)
            try:
                prediction = self._client.predict(model_input, model_name, model_version)
            except grpc.RpcError as e:
                # effectively when it got restarted
                if len(self._client.poll_available_model_versions(model_name)) > 0:
                    raise
                tfs_was_unresponsive = True

        # remove model from disk and memory references if TFS gets unresponsive
        if tfs_was_unresponsive:
            with LockedModel(self._models, "w", model_name, model_version):
                available_versions = self._client.poll_available_model_versions(model_name)
                status, _ = self._models.has_model(model_name, model_version)
                if not (status == "in-memory" and model_version not in available_versions):
                    raise WithBreak

                logger().info(
                    f"removing model {model_name} of version {model_version} because TFS got unresponsive"
                )
                self._models.remove_model(model_name, model_version)

        # download, load into memory the model and retrieve it
        if update_model:
            # grab exclusive access to models holder
            with LockedModel(self._models, "w", model_name, model_version):

                # check model status
                status, local_ts = self._models.has_model(model_name, model_version)

                # refresh disk model
                if status == "not-available" or (
                    status in ["on-disk", "in-memory"] and local_ts != current_upstream_ts
                ):
                    # unload model from TFS
                    if status == "in-memory":
                        try:
                            logger().info(
                                f"unloading model {model_name} of version {model_version} from TFS"
                            )
                            self._models.unload_model(model_name, model_version)
                        except Exception:
                            logger().info(
                                f"failed unloading model {model_name} of version {model_version} from TFS"
                            )
                            raise

                    # remove model from disk and references
                    if status in ["on-disk", "in-memory"]:
                        logger().info(
                            f"removing model references from memory and from disk for model {model_name} of version {model_version}"
                        )
                        self._models.remove_model(model_name, model_version)

                    # download model
                    logger().info(
                        f"downloading model {model_name} of version {model_version} from the {upstream_model['provider']} upstream"
                    )
                    date = self._models.download_model(
                        upstream_model["provider"],
                        upstream_model["bucket"],
                        model_name,
                        model_version,
                        upstream_model["path"],
                    )
                    if not date:
                        raise WithBreak
                    current_upstream_ts = int(date.timestamp())

                # load model
                try:
                    logger().info(
                        f"loading model {model_name} of version {model_version} into memory"
                    )
                    self._models.load_model(
                        model_name,
                        model_version,
                        current_upstream_ts,
                        [tag],
                        kwargs={
                            "model_name": model_name,
                            "model_version": model_version,
                            "signature_key": self._determine_model_signature_key(model_name),
                        },
                    )
                except Exception as e:
                    raise UserRuntimeException(
                        f"failed (re-)loading model {model_name} of version {model_version} (thread {td.get_ident()})",
                        str(e),
                    )

                # run prediction
                self._models.get_model(model_name, model_version, tag)
                prediction = self._client.predict(model_input, model_name, model_version)

        return prediction
def start_fn():
    provider = os.environ["CORTEX_PROVIDER"]
    project_dir = os.environ["CORTEX_PROJECT_DIR"]
    spec_path = os.environ["CORTEX_API_SPEC"]

    model_dir = os.getenv("CORTEX_MODEL_DIR")
    cache_dir = os.getenv("CORTEX_CACHE_DIR")
    region = os.getenv("AWS_REGION")

    tf_serving_port = os.getenv("CORTEX_TF_BASE_SERVING_PORT", "9000")
    tf_serving_host = os.getenv("CORTEX_TF_SERVING_HOST", "localhost")

    has_multiple_servers = os.getenv("CORTEX_MULTIPLE_TF_SERVERS")
    if has_multiple_servers:
        with LockedFile("/run/used_ports.json", "r+") as f:
            used_ports = json.load(f)
            for port in used_ports.keys():
                if not used_ports[port]:
                    tf_serving_port = port
                    used_ports[port] = True
                    break
            f.seek(0)
            json.dump(used_ports, f)
            f.truncate()

    try:
        api = get_api(provider, spec_path, model_dir, cache_dir, region)

        client = api.predictor.initialize_client(
            tf_serving_host=tf_serving_host, tf_serving_port=tf_serving_port
        )

        with FileLock("/run/init_stagger.lock"):
            logger().info("loading the predictor from {}".format(api.predictor.path))
            predictor_impl = api.predictor.initialize_impl(project_dir, client)

        local_cache["api"] = api
        local_cache["provider"] = provider
        local_cache["client"] = client
        local_cache["predictor_impl"] = predictor_impl
        local_cache["predict_fn_args"] = inspect.getfullargspec(predictor_impl.predict).args

        if util.has_method(predictor_impl, "post_predict"):
            local_cache["post_predict_fn_args"] = inspect.getfullargspec(
                predictor_impl.post_predict
            ).args

        predict_route = "/"
        if provider != "local":
            predict_route = "/predict"
        local_cache["predict_route"] = predict_route
    except:
        logger().exception("failed to start api")
        sys.exit(1)

    if (
        provider != "local"
        and api.monitoring is not None
        and api.monitoring.model_type == "classification"
    ):
        try:
            local_cache["class_set"] = api.get_cached_classes()
        except:
            logger().warn("an error occurred while attempting to load classes", exc_info=True)

    app.add_api_route(local_cache["predict_route"], predict, methods=["POST"])
    app.add_api_route(local_cache["predict_route"], get_summary, methods=["GET"])

    return app