Example #1
def start():
    while not pathlib.Path("/mnt/workspace/init_script_run.txt").is_file():
        time.sleep(0.2)

    cache_dir = os.environ["CORTEX_CACHE_DIR"]
    provider = os.environ["CORTEX_PROVIDER"]
    api_spec_path = os.environ["CORTEX_API_SPEC"]
    job_spec_path = os.environ["CORTEX_JOB_SPEC"]
    project_dir = os.environ["CORTEX_PROJECT_DIR"]

    model_dir = os.getenv("CORTEX_MODEL_DIR")
    tf_serving_port = os.getenv("CORTEX_TF_BASE_SERVING_PORT", "9000")
    tf_serving_host = os.getenv("CORTEX_TF_SERVING_HOST", "localhost")

    bucket = os.getenv("CORTEX_BUCKET")
    region = os.getenv("AWS_REGION")

    has_multiple_servers = os.getenv("CORTEX_MULTIPLE_TF_SERVERS")
    if has_multiple_servers:
        with LockedFile("/run/used_ports.json", "r+") as f:
            used_ports = json.load(f)
            for port in used_ports.keys():
                if not used_ports[port]:
                    tf_serving_port = port
                    used_ports[port] = True
                    break
            f.seek(0)
            json.dump(used_ports, f)
            f.truncate()

    api = get_api(provider, api_spec_path, model_dir, cache_dir, bucket,
                  region)
    storage, api_spec = get_spec(provider, api_spec_path, cache_dir, bucket,
                                 region)
    job_spec = get_job_spec(storage, cache_dir, job_spec_path)

    client = api.predictor.initialize_client(tf_serving_host=tf_serving_host,
                                             tf_serving_port=tf_serving_port)
    logger().info("loading the predictor from {}".format(api.predictor.path))
    predictor_impl = api.predictor.initialize_impl(project_dir, client,
                                                   job_spec)

    local_cache["api_spec"] = api
    local_cache["provider"] = provider
    local_cache["job_spec"] = job_spec
    local_cache["predictor_impl"] = predictor_impl
    local_cache["predict_fn_args"] = inspect.getfullargspec(
        predictor_impl.predict).args
    local_cache["sqs_client"] = boto3.client("sqs", region_name=region)

    open("/mnt/workspace/api_readiness.txt", "a").close()

    logger().info("polling for batches...")
    sqs_loop()
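
# Illustrative sketch, not part of the original example: the port-claiming
# block above assumes /run/used_ports.json maps candidate TF Serving ports to
# an "in use" flag, seeded before the workers start. The helper below shows
# one way such a file could be initialized (the base port and port count are
# assumptions):
import json

def seed_used_ports(path="/run/used_ports.json", base_port=9000, count=4):
    # mark every candidate port as free; each worker later claims the first
    # free port and flips its flag to True under the LockedFile lock above
    ports = {str(base_port + i): False for i in range(count)}
    with open(path, "w") as f:
        json.dump(ports, f)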
Example #2
    def garbage_collect(
            self,
            exclude_disk_model_ids: List[str] = [],
            dry_run: bool = False) -> Tuple[bool, List[str], List[str]]:
        """
        Removes stale in-memory and on-disk models based on LRU policy.
        Also calls the "remove" callback before removing the models from this object. The callback must not raise any exceptions.

        Must be called with a write lock unless dry_run is set to true.

        Args:
            exclude_disk_model_ids: Model IDs to exclude from removing from disk. Necessary for locally-provided models.
            dry_run: Just test if there are any models to remove. If set to true, this method can then be called with a read lock.

        Returns:
            A 3-element tuple. First element tells whether models had to be collected. The 2nd and 3rd elements contain the model IDs that were removed from memory and disk respectively.
        """
        collected = False
        if self._mem_cache_size <= 0 or self._disk_cache_size <= 0:
            return collected, [], []

        stale_mem_model_ids = self._lru_model_ids(self._mem_cache_size,
                                                  filter_in_mem=True)
        stale_disk_model_ids = self._lru_model_ids(self._disk_cache_size -
                                                   len(exclude_disk_model_ids),
                                                   filter_in_mem=False)

        if self._remove_callback and not dry_run:
            self._remove_callback(stale_mem_model_ids)

        # don't delete excluded model IDs from disk
        stale_disk_model_ids = list(
            set(stale_disk_model_ids) - set(exclude_disk_model_ids))
        stale_disk_model_ids = stale_disk_model_ids[len(stale_disk_model_ids) -
                                                    self._disk_cache_size:]

        if not dry_run:
            logger().info(
                f"unloading models {stale_mem_model_ids} from memory using the garbage collector"
            )
            logger().info(
                f"unloading models {stale_disk_model_ids} from disk using the garbage collector"
            )
            for model_id in stale_mem_model_ids:
                self.remove_model_by_id(model_id, mem=True, disk=False)
            for model_id in stale_disk_model_ids:
                self.remove_model_by_id(model_id, mem=False, disk=True)

        if len(stale_mem_model_ids) > 0 or len(stale_disk_model_ids) > 0:
            collected = True

        return collected, stale_mem_model_ids, stale_disk_model_ids
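
    # Illustrative sketch, not part of the original example: one way a caller
    # could honor the locking contract described in the docstring above (the
    # lock objects and the helper name are assumptions; only the dry_run /
    # write-lock rule comes from the docstring).
    def _collect_stale_models_example(self, read_lock, write_lock):
        # probe with dry_run=True, which only requires a read lock
        with read_lock:
            needs_gc, _, _ = self.garbage_collect(dry_run=True)
        if not needs_gc:
            return [], []
        # perform the actual eviction under a write lock
        with write_lock:
            _, removed_mem, removed_disk = self.garbage_collect()
        return removed_mem, removed_disk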
Example #3
    def post_metrics(self, metrics):
        try:
            if self.statsd is None:
                raise CortexException("statsd client not initialized")  # unexpected

            for metric in metrics:
                tags = ["{}:{}".format(dim["Name"], dim["Value"]) for dim in metric["Dimensions"]]
                if metric.get("Unit") == "Count":
                    self.statsd.increment(metric["MetricName"], value=metric["Value"], tags=tags)
                else:
                    self.statsd.histogram(metric["MetricName"], value=metric["Value"], tags=tags)
        except:
            logger().warn("failure encountered while publishing metrics", exc_info=True)
Example #4
def start(args):
    download_config = json.loads(base64.urlsafe_b64decode(args.download))
    for download_arg in download_config["download_args"]:
        from_path = download_arg["from"]
        to_path = download_arg["to"]
        item_name = download_arg.get("item_name", "")
        bucket_name, prefix = S3.deconstruct_s3_path(from_path)
        s3_client = S3(bucket_name, client_config={})

        if item_name != "":
            if download_arg.get("hide_from_log", False):
                logger().info("downloading {}".format(item_name))
            else:
                logger().info("downloading {} from {}".format(
                    item_name, from_path))

        if download_arg.get("to_file", False):
            s3_client.download_file(prefix, to_path)
        else:
            s3_client.download(prefix, to_path)

        if download_arg.get("unzip", False):
            if item_name != "" and not download_arg.get(
                    "hide_unzipping_log", False):
                logger().info("unzipping {}".format(item_name))
            if download_arg.get("to_file", False):
                util.extract_zip(to_path, delete_zip_file=True)
            else:
                util.extract_zip(os.path.join(to_path,
                                              os.path.basename(from_path)),
                                 delete_zip_file=True)

    if download_config.get("last_log", "") != "":
        logger().info(download_config["last_log"])
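
# Illustrative sketch, not part of the original example: start(args) decodes
# args.download from URL-safe base64-encoded JSON. Building such a payload
# could look like this (the bucket, paths, and item name are made up):
import base64
import json

def build_download_payload():
    download_config = {
        "download_args": [
            {
                "from": "s3://my-bucket/project.zip",
                "to": "/mnt/project",
                "item_name": "the project code",
                "unzip": True,  # extracted, then the zip file is deleted
            }
        ],
        "last_log": "downloads finished",
    }
    # the encoded string is what start(args) expects in args.download
    return base64.urlsafe_b64encode(json.dumps(download_config).encode()).decode()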
Example #5
def handle_on_complete(message):
    job_spec = local_cache["job_spec"]
    predictor_impl = local_cache["predictor_impl"]
    sqs_client = local_cache["sqs_client"]
    queue_url = job_spec["sqs_url"]
    receipt_handle = message["ReceiptHandle"]

    try:
        if not getattr(predictor_impl, "on_job_complete", None):
            sqs_client.delete_message(QueueUrl=queue_url,
                                      ReceiptHandle=receipt_handle)
            return True

        should_run_on_job_complete = False

        while True:
            visible_count, not_visible_count = get_total_messages_in_queue()

            # if there are other messages that are visible, release this message and get the other ones (should rarely happen for FIFO)
            if visible_count > 0:
                sqs_client.change_message_visibility(
                    QueueUrl=queue_url,
                    ReceiptHandle=receipt_handle,
                    VisibilityTimeout=0)
                return False

            if should_run_on_job_complete:
                # double check that the queue is still empty (except for the job_complete message)
                if not_visible_count <= 1:
                    logger().info("executing on_job_complete")
                    predictor_impl.on_job_complete()
                    sqs_client.delete_message(QueueUrl=queue_url,
                                              ReceiptHandle=receipt_handle)
                    return True
                else:
                    should_run_on_job_complete = False

            if not_visible_count <= 1:
                should_run_on_job_complete = True

            time.sleep(20)
    except:
        sqs_client.delete_message(QueueUrl=queue_url,
                                  ReceiptHandle=receipt_handle)
        raise
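
# Illustrative sketch, not part of the original example: handle_on_complete()
# relies on a get_total_messages_in_queue() helper that is not shown in this
# snippet. One plausible implementation (an assumption) polls the approximate
# SQS counters for the job queue:
def get_total_messages_in_queue():
    sqs_client = local_cache["sqs_client"]
    queue_url = local_cache["job_spec"]["sqs_url"]
    attributes = sqs_client.get_queue_attributes(
        QueueUrl=queue_url,
        AttributeNames=[
            "ApproximateNumberOfMessages",
            "ApproximateNumberOfMessagesNotVisible",
        ],
    )["Attributes"]
    visible_count = int(attributes["ApproximateNumberOfMessages"])
    not_visible_count = int(attributes["ApproximateNumberOfMessagesNotVisible"])
    return visible_count, not_visible_count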
Example #6
def predict(request: Request):
    tasks = BackgroundTasks()
    api = local_cache["api"]
    predictor_impl = local_cache["predictor_impl"]
    kwargs = build_predict_kwargs(request)

    prediction = predictor_impl.predict(**kwargs)

    if isinstance(prediction, bytes):
        response = Response(content=prediction,
                            media_type="application/octet-stream")
    elif isinstance(prediction, str):
        response = Response(content=prediction, media_type="text/plain")
    elif isinstance(prediction, Response):
        response = prediction
    else:
        try:
            json_string = json.dumps(prediction)
        except Exception as e:
            raise UserRuntimeException(
                str(e),
                "please return an object that is JSON serializable (including its nested fields), a bytes object, a string, or a starlette.response.Response object",
            ) from e
        response = Response(content=json_string, media_type="application/json")

    if local_cache["provider"] != "local" and api.monitoring is not None:
        try:
            predicted_value = api.monitoring.extract_predicted_value(
                prediction)
            api.post_monitoring_metrics(predicted_value)
            if (api.monitoring.model_type == "classification"
                    and predicted_value not in local_cache["class_set"]):
                tasks.add_task(api.upload_class, class_name=predicted_value)
                local_cache["class_set"].add(predicted_value)
        except:
            logger().warn("unable to record prediction metric", exc_info=True)

    if util.has_method(predictor_impl, "post_predict"):
        kwargs = build_post_predict_kwargs(prediction, request)
        request_thread_pool.submit(predictor_impl.post_predict, **kwargs)

    if len(tasks.tasks) > 0:
        response.background = tasks

    return response
Example #7
def validate_models_dir_paths(
        paths: List[str], predictor_type: PredictorType,
        common_prefix: str) -> Tuple[List[str], List[List[int]]]:
    """
    Validates the models paths based on the given predictor type.
    To be used when predictor:models:dir in cortex.yaml is used.

    Args:
        paths: A list of all paths for a given S3/local prefix. Must be underneath the common prefix.
        predictor_type: The predictor type.
        common_prefix: The common prefix of the directory which holds all models. AKA predictor:models:dir.

    Returns:
        List with the prefix of each model that's valid.
        List with the OneOfAllPlaceholder IDs validated for each valid model.
    """
    if len(paths) == 0:
        raise CortexException(
            f"{predictor_type} predictor at '{common_prefix}'",
            "model top path can't be empty")

    rel_paths = [
        os.path.relpath(top_path, common_prefix) for top_path in paths
    ]
    rel_paths = [path for path in rel_paths if not path.startswith("../")]

    model_names = [util.get_leftmost_part_of_path(path) for path in rel_paths]
    model_names = list(set(model_names))

    valid_model_prefixes = []
    ooa_valid_key_ids = []
    for model_name in model_names:
        try:
            ooa_valid_key_ids.append(
                validate_model_paths(rel_paths, predictor_type, model_name))
            valid_model_prefixes.append(os.path.join(common_prefix,
                                                     model_name))
        except CortexException as e:
            logger().debug(f"failed validating model {model_name}: {str(e)}")
            continue

    return valid_model_prefixes, ooa_valid_key_ids
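
# Illustrative sketch, not part of the original example: expected behavior of
# validate_models_dir_paths() for a directory holding two models. The layout,
# model names, and predictor type below are assumptions.
def _example_models_dir_validation(predictor_type):
    # hypothetical listing under predictor:models:dir = "models"
    paths = [
        "models/iris/1/saved_model.pb",
        "models/iris/2/saved_model.pb",
        "models/sentiment/1/saved_model.pb",
    ]
    valid_prefixes, ooa_ids = validate_models_dir_paths(paths, predictor_type,
                                                        "models")
    # valid_prefixes is expected to hold "models/iris" and "models/sentiment",
    # assuming each versioned directory passes validate_model_paths(); a model
    # that fails validation is skipped and logged at debug level
    return valid_prefixes, ooa_ids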
Example #8
    def _remove_models(self, model_ids: List[str]) -> None:
        """
        Remove models from TFS.
        Must only be used when caching enabled.
        """
        logger().info(f"unloading models with model IDs {model_ids} from TFS")

        models = {}
        for model_id in model_ids:
            model_name, model_version = model_id.rsplit("-", maxsplit=1)
            if model_name not in models:
                models[model_name] = [model_version]
            else:
                models[model_name].append(model_version)

        model_names = []
        model_versions = []
        for model_name, versions in models.items():
            model_names.append(model_name)
            model_versions.append(versions)

        self._client.remove_models(model_names, model_versions)
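
    # Illustrative sketch, not part of the original example: the grouping loop
    # above splits "name-version" IDs on the last dash, so model names that
    # contain dashes are preserved. Equivalent standalone helper (the model IDs
    # in the comment are made up):
    @staticmethod
    def _group_model_ids_example(model_ids: List[str]) -> dict:
        models = {}
        for model_id in model_ids:
            model_name, model_version = model_id.rsplit("-", maxsplit=1)
            models.setdefault(model_name, []).append(model_version)
        # e.g. ["text-gen-1", "text-gen-2", "iris-1"]
        #      -> {"text-gen": ["1", "2"], "iris": ["1"]}
        return models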
Example #9
def start(args):
    download_config = json.loads(base64.urlsafe_b64decode(args.download))
    for download_arg in download_config["download_args"]:
        from_path = download_arg["from"]
        to_path = download_arg["to"]
        item_name = download_arg.get("item_name", "")

        if from_path.startswith("s3://"):
            bucket_name, prefix = S3.deconstruct_s3_path(from_path)
            client = S3(bucket_name, client_config={})
        elif from_path.startswith("gs://"):
            bucket_name, prefix = GCS.deconstruct_gcs_path(from_path)
            client = GCS(bucket_name)
        else:
            raise ValueError(
                '"from" download arg can either have the "s3://" or "gs://" prefixes'
            )

        if item_name != "":
            if download_arg.get("hide_from_log", False):
                logger().info("downloading {}".format(item_name))
            else:
                logger().info("downloading {} from {}".format(
                    item_name, from_path))

        if download_arg.get("to_file", False):
            client.download_file(prefix, to_path)
        else:
            client.download(prefix, to_path)

        if download_arg.get("unzip", False):
            if item_name != "" and not download_arg.get(
                    "hide_unzipping_log", False):
                logger().info("unzipping {}".format(item_name))
            if download_arg.get("to_file", False):
                util.extract_zip(to_path, delete_zip_file=True)
            else:
                util.extract_zip(os.path.join(to_path,
                                              os.path.basename(from_path)),
                                 delete_zip_file=True)

    if download_config.get("last_log", "") != "":
        logger().info(download_config["last_log"])
Example #10
    def _get_model(self, model_name: str, model_version: str) -> Any:
        """
        Checks if versioned model is on disk, then checks if model is in memory,
        and if not, it loads it into memory, and returns the model.

        Args:
            model_name: Name of the model, as it's specified in predictor:models:paths or in the other case as they are named on disk.
            model_version: Version of the model, as it's found on disk. Can also infer the version number from the "latest" version tag.

        Exceptions:
            RuntimeError: if another thread tried to load the model at the very same time.

        Returns:
            The model as returned by self._load_model method.
            None if the model wasn't found or if it didn't pass the validation.
        """

        model = None
        tag = ""
        if model_version == "latest":
            tag = model_version

        if not self._caching_enabled:
            # determine model version
            if tag == "latest":
                model_version = self._get_latest_model_version_from_disk(
                    model_name)
            model_id = model_name + "-" + model_version

            # grab shared access to versioned model
            resource = os.path.join(self._lock_dir, model_id + ".txt")
            with LockedFile(resource, "r", reader_lock=True) as f:

                # check model status
                file_status = f.read()
                if file_status == "" or file_status == "not-available":
                    raise WithBreak

                current_upstream_ts = int(file_status.split(" ")[1])
                update_model = False

                # grab shared access to models holder and retrieve model
                with LockedModel(self._models, "r", model_name, model_version):
                    status, local_ts = self._models.has_model(
                        model_name, model_version)
                    if status == "not-available" or (
                            status == "in-memory"
                            and local_ts != current_upstream_ts):
                        update_model = True
                        raise WithBreak
                    model, _ = self._models.get_model(model_name,
                                                      model_version, tag)

                # load model into memory and retrieve it
                if update_model:
                    with LockedModel(self._models, "w", model_name,
                                     model_version):
                        status, _ = self._models.has_model(
                            model_name, model_version)
                        if status == "not-available" or (
                                status == "in-memory"
                                and local_ts != current_upstream_ts):
                            if status == "not-available":
                                logger().info(
                                    f"loading model {model_name} of version {model_version} (thread {td.get_ident()})"
                                )
                            else:
                                logger().info(
                                    f"reloading model {model_name} of version {model_version} (thread {td.get_ident()})"
                                )
                            try:
                                self._models.load_model(
                                    model_name,
                                    model_version,
                                    current_upstream_ts,
                                    [tag],
                                )
                            except Exception as e:
                                raise UserRuntimeException(
                                    f"failed (re-)loading model {model_name} of version {model_version} (thread {td.get_ident()})",
                                    str(e),
                                )
                        model, _ = self._models.get_model(
                            model_name, model_version, tag)

        if not self._multiple_processes and self._caching_enabled:
            # determine model version
            try:
                if tag == "latest":
                    model_version = self._get_latest_model_version_from_tree(
                        model_name, self._models_tree.model_info(model_name))
            except ValueError:
                # if model_name hasn't been found
                raise UserRuntimeException(
                    f"'{model_name}' model of tag {tag} wasn't found in the list of available models"
                )

            # grab shared access to model tree
            available_model = True
            with LockedModelsTree(self._models_tree, "r", model_name,
                                  model_version):

                # check if the versioned model exists
                model_id = model_name + "-" + model_version
                if model_id not in self._models_tree:
                    available_model = False
                    raise WithBreak

                # retrieve model tree's metadata
                upstream_model = self._models_tree[model_id]
                current_upstream_ts = int(
                    upstream_model["timestamp"].timestamp())

            if not available_model:
                return None

            # grab shared access to models holder and retrieve model
            update_model = False
            with LockedModel(self._models, "r", model_name, model_version):
                status, local_ts = self._models.has_model(
                    model_name, model_version)
                if status in [
                        "not-available", "on-disk"
                ] or (status != "not-available"
                      and local_ts != current_upstream_ts
                      and not (status == "in-memory" and model_name
                               in self._spec_local_model_names)):
                    update_model = True
                    raise WithBreak
                model, _ = self._models.get_model(model_name, model_version,
                                                  tag)

            # download, load into memory the model and retrieve it
            if update_model:
                # grab exclusive access to models holder
                with LockedModel(self._models, "w", model_name, model_version):

                    # check model status
                    status, local_ts = self._models.has_model(
                        model_name, model_version)

                    # refresh disk model
                    if model_name not in self._spec_local_model_names and (
                            status == "not-available" or
                        (status in ["on-disk", "in-memory"]
                         and local_ts != current_upstream_ts)):
                        if status == "not-available":
                            logger().info(
                                f"model {model_name} of version {model_version} not found locally; continuing with the download..."
                            )
                        elif status == "on-disk":
                            logger().info(
                                f"found newer model {model_name} of vesion {model_version} on the {upstream_model['provider']} upstream than the one on the disk"
                            )
                        else:
                            logger().info(
                                f"found newer model {model_name} of vesion {model_version} on the {upstream_model['provider']} upstream than the one loaded into memory"
                            )

                        # remove model from disk and memory
                        if status == "on-disk":
                            logger().info(
                                f"removing model from disk for model {model_name} of version {model_version}"
                            )
                            self._models.remove_model(model_name,
                                                      model_version)
                        if status == "in-memory":
                            logger().info(
                                f"removing model from disk and memory for model {model_name} of version {model_version}"
                            )
                            self._models.remove_model(model_name,
                                                      model_version)

                        # download model
                        logger().info(
                            f"downloading model {model_name} of version {model_version} from the {upstream_model['provider']} upstream"
                        )
                        date = self._models.download_model(
                            upstream_model["provider"],
                            upstream_model["bucket"],
                            model_name,
                            model_version,
                            upstream_model["path"],
                        )
                        if not date:
                            raise WithBreak
                        current_upstream_ts = int(date.timestamp())

                    # give the local model a timestamp initialized at start time
                    if model_name in self._spec_local_model_names:
                        current_upstream_ts = self._local_model_ts

                    # load model
                    try:
                        logger().info(
                            f"loading model {model_name} of version {model_version} into memory"
                        )
                        self._models.load_model(
                            model_name,
                            model_version,
                            current_upstream_ts,
                            [tag],
                        )
                    except Exception as e:
                        raise UserRuntimeException(
                            f"failed (re-)loading model {model_name} of version {model_version} (thread {td.get_ident()})",
                            str(e),
                        )

                    # retrieve model
                    model, _ = self._models.get_model(model_name,
                                                      model_version, tag)

        return model
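
    # Illustrative sketch, not part of the original example: the "latest" tag
    # above is resolved by _get_latest_model_version_from_disk(), which is not
    # shown in this snippet. One plausible implementation (an assumption; it
    # presumes an on-disk layout of <models dir>/<model name>/<version>/) picks
    # the numerically highest version directory:
    def _get_latest_model_version_from_disk_example(self, models_dir: str,
                                                    model_name: str) -> str:
        model_dir = os.path.join(models_dir, model_name)
        versions = [
            d for d in os.listdir(model_dir)
            if os.path.isdir(os.path.join(model_dir, d)) and d.isdigit()
        ]
        if len(versions) == 0:
            raise ValueError(f"no versions found on disk for model {model_name}")
        return str(max(int(v) for v in versions))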
Example #11
def sqs_loop():
    job_spec = local_cache["job_spec"]
    api_spec = local_cache["api_spec"]
    predictor_impl = local_cache["predictor_impl"]
    sqs_client = local_cache["sqs_client"]

    queue_url = job_spec["sqs_url"]

    no_messages_found_in_previous_iteration = False

    while True:
        response = sqs_client.receive_message(
            QueueUrl=queue_url,
            MaxNumberOfMessages=1,
            WaitTimeSeconds=10,
            VisibilityTimeout=MAXIMUM_MESSAGE_VISIBILITY,
            MessageAttributeNames=["All"],
        )

        if response.get("Messages") is None or len(response["Messages"]) == 0:
            if no_messages_found_in_previous_iteration:
                logger().info("no batches left in queue, exiting...")
                return
            else:
                no_messages_found_in_previous_iteration = True
                continue
        else:
            no_messages_found_in_previous_iteration = False

        message = response["Messages"][0]

        receipt_handle = message["ReceiptHandle"]

        if "MessageAttributes" in message and "job_complete" in message[
                "MessageAttributes"]:
            handled_on_complete = handle_on_complete(message)
            if handled_on_complete:
                logger().info(
                    "no batches left in queue, job has been completed")
                return
            else:
                # sometimes on_job_complete message will be released if there are other messages still to be processed
                continue

        try:
            logger().info(f"processing batch {message['MessageId']}")

            start_time = time.time()

            payload = json.loads(message["Body"])
            batch_id = message["MessageId"]
            predictor_impl.predict(**build_predict_args(payload, batch_id))

            api_spec.post_metrics([
                success_counter_metric(),
                time_per_batch_metric(time.time() - start_time)
            ])
        except Exception:
            api_spec.post_metrics([
                failed_counter_metric(),
                time_per_batch_metric(time.time() - start_time)
            ])
            logger().exception("failed to process batch")
        finally:
            sqs_client.delete_message(QueueUrl=queue_url,
                                      ReceiptHandle=receipt_handle)
Example #12
def model_downloader(
    predictor_type: PredictorType,
    bucket_provider: str,
    bucket_name: str,
    model_name: str,
    model_version: str,
    model_path: str,
    temp_dir: str,
    model_dir: str,
) -> Optional[datetime.datetime]:
    """
    Downloads model to disk. Validates the cloud model path and the downloaded model as well.

    Args:
        predictor_type: The predictor type as implemented by the API.
        bucket_provider: Provider for the bucket. Can be "s3" or "gs".
        bucket_name: Name of the bucket where the model is stored.
        model_name: Name of the model. Is part of the model's local path.
        model_version: Version of the model. Is part of the model's local path.
        model_path: Model prefix of the versioned model.
        temp_dir: Where to temporarily store the model for validation.
        model_dir: The top directory of where all models are stored locally.

    Returns:
        The model's timestamp. None if the model didn't pass the validation, if it doesn't exist or if there are not enough permissions.
    """

    logger().info(
        f"downloading from bucket {bucket_name}/{model_path}, model {model_name} of version {model_version}, temporarily to {temp_dir} and then finally to {model_dir}"
    )

    if bucket_provider == "s3":
        client = S3(bucket_name)
    elif bucket_provider == "gs":
        client = GCS(bucket_name)

    # validate upstream cloud model
    sub_paths, ts = client.search(model_path)
    try:
        validate_model_paths(sub_paths, predictor_type, model_path)
    except CortexException:
        logger().info(f"failed validating model {model_name} of version {model_version}")
        return None

    # download model to temp dir
    temp_dest = os.path.join(temp_dir, model_name, model_version)
    try:
        client.download_dir_contents(model_path, temp_dest)
    except CortexException:
        logger().info(
            f"failed downloading model {model_name} of version {model_version} to temp dir {temp_dest}"
        )
        shutil.rmtree(temp_dest)
        return None

    # validate model
    model_contents = glob.glob(temp_dest + "*/**", recursive=True)
    model_contents = util.remove_non_empty_directory_paths(model_contents)
    try:
        validate_model_paths(model_contents, predictor_type, temp_dest)
    except CortexException:
        logger().info(
            f"failed validating model {model_name} of version {model_version} from temp dir"
        )
        shutil.rmtree(temp_dest)
        return None

    # move model to dest dir
    model_top_dir = os.path.join(model_dir, model_name)
    ondisk_model_version = os.path.join(model_top_dir, model_version)
    logger().info(
        f"moving model {model_name} of version {model_version} to final dir {ondisk_model_version}"
    )
    if os.path.isdir(ondisk_model_version):
        shutil.rmtree(ondisk_model_version)
    shutil.move(temp_dest, ondisk_model_version)

    return max(ts)
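
# Illustrative sketch, not part of the original example: a possible call site
# for model_downloader(); the temp/destination directories and all argument
# values are assumptions about a typical setup.
def download_model_example(predictor_type, provider, bucket, name, version, path):
    ts = model_downloader(
        predictor_type=predictor_type,
        bucket_provider=provider,  # "s3" or "gs"
        bucket_name=bucket,
        model_name=name,
        model_version=version,
        model_path=path,
        temp_dir="/tmp/cortex-models",
        model_dir="/mnt/model",
    )
    if ts is None:
        # the model failed validation, does not exist, or is not accessible
        return None
    return int(ts.timestamp())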
Example #13
    def _extract_signatures(
        self, signature_def, signature_key, model_name: str, model_version: str
    ):
        logger().info(
            "signature defs found in model '{}' for version '{}': {}".format(
                model_name, model_version, signature_def
            )
        )

        available_keys = list(signature_def.keys())
        if len(available_keys) == 0:
            raise UserException(
                "unable to find signature defs in model '{}' of version '{}'".format(
                    model_name, model_version
                )
            )

        if signature_key is None:
            if len(available_keys) == 1:
                logger().info(
                    "signature_key was not configured by user, using signature key '{}' for model '{}' of version '{}' (found in the signature def map)".format(
                        available_keys[0],
                        model_name,
                        model_version,
                    )
                )
                signature_key = available_keys[0]
            elif "predict" in signature_def:
                logger().info(
                    "signature_key was not configured by user, using signature key 'predict' for model '{}' of version '{}' (found in the signature def map)".format(
                        model_name,
                        model_version,
                    )
                )
                signature_key = "predict"
            else:
                raise UserException(
                    "signature_key was not configured by user, please specify one the following keys '{}' for model '{}' of version '{}' (found in the signature def map)".format(
                        ", ".join(available_keys), model_name, model_version
                    )
                )
        else:
            if signature_def.get(signature_key) is None:
                possibilities_str = "key: '{}'".format(available_keys[0])
                if len(available_keys) > 1:
                    possibilities_str = "keys: '{}'".format("', '".join(available_keys))

                raise UserException(
                    "signature_key '{}' was not found in signature def map for model '{}' of version '{}', but found the following {}".format(
                        signature_key, model_name, model_version, possibilities_str
                    )
                )

        signature_def_val = signature_def.get(signature_key)

        if signature_def_val.get("inputs") is None:
            raise UserException(
                "unable to find 'inputs' in signature def '{}' for model '{}'".format(
                    signature_key, model_name
                )
            )

        parsed_signatures = {}
        for input_name, input_metadata in signature_def_val["inputs"].items():
            if input_metadata["tensorShape"] == {}:
                # a scalar with rank 0 and empty shape
                shape = "scalar"
            elif input_metadata["tensorShape"].get("unknownRank", False):
                # unknown rank and shape
                #
                # unknownRank is set to True if the model input has no rank
                # it may lead to an undefined behavior if unknownRank is only checked for its presence
                # so it also gets to be tested against its value
                shape = "unknown"
            elif input_metadata["tensorShape"].get("dim", None):
                # known rank and known/unknown shape
                shape = [int(dim["size"]) for dim in input_metadata["tensorShape"]["dim"]]
            else:
                raise UserException(
                    "invalid 'tensorShape' specification for input '{}' in signature key '{}' for model '{}'".format(
                        input_name, signature_key, model_name
                    )
                )

            parsed_signatures[input_name] = {
                "shape": shape if type(shape) == list else [shape],
                "type": DTYPE_TO_TF_TYPE[input_metadata["dtype"]].name,
            }
        return signature_key, parsed_signatures
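
    # Illustrative sketch, not part of the original example: how the shape
    # parsing above handles a single "predict" signature. The metadata mimics
    # TF Serving's JSON metadata format; the dtype, input name, and dimensions
    # are assumptions for illustration.
    def _example_signature_def(self):
        signature_def = {
            "predict": {
                "inputs": {
                    "images": {
                        "dtype": "DT_FLOAT",
                        "tensorShape": {
                            "dim": [
                                {"size": "-1"},
                                {"size": "28"},
                                {"size": "28"},
                                {"size": "1"},
                            ]
                        },
                    }
                }
            }
        }
        # with signature_key=None, the method falls back to the "predict" key
        # and is expected to return parsed signatures along the lines of:
        #     {"images": {"shape": [-1, 28, 28, 1], "type": "float32"}}
        return self._extract_signatures(signature_def, None, "my-model", "1")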
Example #14
    def _run_inference(self, model_input: Any, model_name: str,
                       model_version: str) -> dict:
        """
        When processes_per_replica = 1 and caching enabled, check/load model and make prediction.
        When processes_per_replica > 0 and caching disabled, attempt to make prediction regardless.

        Args:
            model_input: Input to the model.
            model_name: Name of the model, as it's specified in predictor:models:paths or in the other case as they are named on disk.
            model_version: Version of the model, as it's found on disk. Can also infer the version number from the "latest" version tag.

        Returns:
            The prediction.
        """

        model = None
        tag = ""
        if model_version == "latest":
            tag = model_version

        if not self._caching_enabled:

            # determine model version
            if tag == "latest":
                versions = self._client.poll_available_model_versions(
                    model_name)
                if len(versions) == 0:
                    raise UserException(
                        f"model '{model_name}' accessed with tag {tag} couldn't be found"
                    )
                model_version = str(max(map(lambda x: int(x), versions)))
            model_id = model_name + "-" + model_version

            return self._client.predict(model_input, model_name, model_version)

        if not self._multiple_processes and self._caching_enabled:

            # determine model version
            try:
                if tag == "latest":
                    model_version = self._get_latest_model_version_from_tree(
                        model_name, self._models_tree.model_info(model_name))
            except ValueError:
                # if model_name hasn't been found
                raise UserRuntimeException(
                    f"'{model_name}' model of tag {tag} wasn't found in the list of available models"
                )

            models_stats = []
            for model_id in self._models.get_model_ids():
                models_stats.append(self._models.has_model_id(model_id))

            # grab shared access to model tree
            available_model = True
            logger().info(
                f"grabbing access to model {model_name} of version {model_version}"
            )
            with LockedModelsTree(self._models_tree, "r", model_name,
                                  model_version):

                # check if the versioned model exists
                model_id = model_name + "-" + model_version
                if model_id not in self._models_tree:
                    available_model = False
                    logger().info(
                        f"model {model_name} of version {model_version} is not available"
                    )
                    raise WithBreak

                # retrieve model tree's metadata
                upstream_model = self._models_tree[model_id]
                current_upstream_ts = int(
                    upstream_model["timestamp"].timestamp())
                logger().info(
                    f"model {model_name} of version {model_version} is available"
                )

            if not available_model:
                if tag == "":
                    raise UserException(
                        f"model '{model_name}' of version '{model_version}' couldn't be found"
                    )
                raise UserException(
                    f"model '{model_name}' accessed with tag '{tag}' couldn't be found"
                )

            # grab shared access to models holder and retrieve model
            update_model = False
            prediction = None
            tfs_was_unresponsive = False
            with LockedModel(self._models, "r", model_name, model_version):
                logger().info(
                    f"checking the {model_name} {model_version} status")
                status, local_ts = self._models.has_model(
                    model_name, model_version)
                if status in ["not-available", "on-disk"
                              ] or (status != "not-available"
                                    and local_ts != current_upstream_ts):
                    logger().info(
                        f"model {model_name} of version {model_version} is not loaded (with status {status} or different timestamp)"
                    )
                    update_model = True
                    raise WithBreak

                # run prediction
                logger().info(
                    f"run the prediction on model {model_name} of version {model_version}"
                )
                self._models.get_model(model_name, model_version, tag)
                try:
                    prediction = self._client.predict(model_input, model_name,
                                                      model_version)
                except grpc.RpcError as e:
                    # effectively when it got restarted
                    if len(
                            self._client.poll_available_model_versions(
                                model_name)) > 0:
                        raise
                    tfs_was_unresponsive = True

            # remove model from disk and memory references if TFS gets unresponsive
            if tfs_was_unresponsive:
                with LockedModel(self._models, "w", model_name, model_version):
                    available_versions = self._client.poll_available_model_versions(
                        model_name)
                    status, _ = self._models.has_model(model_name,
                                                       model_version)
                    if not (status == "in-memory"
                            and model_version not in available_versions):
                        raise WithBreak

                    logger().info(
                        f"removing model {model_name} of version {model_version} because TFS got unresponsive"
                    )
                    self._models.remove_model(model_name, model_version)

            # download, load into memory the model and retrieve it
            if update_model:
                # grab exclusive access to models holder
                with LockedModel(self._models, "w", model_name, model_version):

                    # check model status
                    status, local_ts = self._models.has_model(
                        model_name, model_version)

                    # refresh disk model
                    if status == "not-available" or (
                            status in ["on-disk", "in-memory"]
                            and local_ts != current_upstream_ts):
                        # unload model from TFS
                        if status == "in-memory":
                            try:
                                logger().info(
                                    f"unloading model {model_name} of version {model_version} from TFS"
                                )
                                self._models.unload_model(
                                    model_name, model_version)
                            except Exception:
                                logger().info(
                                    f"failed unloading model {model_name} of version {model_version} from TFS"
                                )
                                raise

                        # remove model from disk and references
                        if status in ["on-disk", "in-memory"]:
                            logger().info(
                                f"removing model references from memory and from disk for model {model_name} of version {model_version}"
                            )
                            self._models.remove_model(model_name,
                                                      model_version)

                        # download model
                        logger().info(
                            f"downloading model {model_name} of version {model_version} from the {upstream_model['provider']} upstream"
                        )
                        date = self._models.download_model(
                            upstream_model["provider"],
                            upstream_model["bucket"],
                            model_name,
                            model_version,
                            upstream_model["path"],
                        )
                        if not date:
                            raise WithBreak
                        current_upstream_ts = int(date.timestamp())

                    # load model
                    try:
                        logger().info(
                            f"loading model {model_name} of version {model_version} into memory"
                        )
                        self._models.load_model(
                            model_name,
                            model_version,
                            current_upstream_ts,
                            [tag],
                            kwargs={
                                "model_name":
                                model_name,
                                "model_version":
                                model_version,
                                "signature_key":
                                self._determine_model_signature_key(
                                    model_name),
                            },
                        )
                    except Exception as e:
                        raise UserRuntimeException(
                            f"failed (re-)loading model {model_name} of version {model_version} (thread {td.get_ident()})",
                            str(e),
                        )

                    # run prediction
                    self._models.get_model(model_name, model_version, tag)
                    prediction = self._client.predict(model_input, model_name,
                                                      model_version)

            return prediction
Example #15
def start_fn():
    provider = os.environ["CORTEX_PROVIDER"]
    project_dir = os.environ["CORTEX_PROJECT_DIR"]
    spec_path = os.environ["CORTEX_API_SPEC"]

    model_dir = os.getenv("CORTEX_MODEL_DIR")
    cache_dir = os.getenv("CORTEX_CACHE_DIR")
    region = os.getenv("AWS_REGION")

    tf_serving_port = os.getenv("CORTEX_TF_BASE_SERVING_PORT", "9000")
    tf_serving_host = os.getenv("CORTEX_TF_SERVING_HOST", "localhost")

    has_multiple_servers = os.getenv("CORTEX_MULTIPLE_TF_SERVERS")
    if has_multiple_servers:
        with LockedFile("/run/used_ports.json", "r+") as f:
            used_ports = json.load(f)
            for port in used_ports.keys():
                if not used_ports[port]:
                    tf_serving_port = port
                    used_ports[port] = True
                    break
            f.seek(0)
            json.dump(used_ports, f)
            f.truncate()

    try:
        api = get_api(provider, spec_path, model_dir, cache_dir, region)

        client = api.predictor.initialize_client(
            tf_serving_host=tf_serving_host, tf_serving_port=tf_serving_port)

        with FileLock("/run/init_stagger.lock"):
            logger().info("loading the predictor from {}".format(
                api.predictor.path))
            predictor_impl = api.predictor.initialize_impl(project_dir, client)

        local_cache["api"] = api
        local_cache["provider"] = provider
        local_cache["client"] = client
        local_cache["predictor_impl"] = predictor_impl
        local_cache["predict_fn_args"] = inspect.getfullargspec(
            predictor_impl.predict).args
        if util.has_method(predictor_impl, "post_predict"):
            local_cache["post_predict_fn_args"] = inspect.getfullargspec(
                predictor_impl.post_predict).args

        predict_route = "/"
        if provider != "local":
            predict_route = "/predict"
        local_cache["predict_route"] = predict_route
    except:
        logger().exception("failed to start api")
        sys.exit(1)

    if (provider != "local" and api.monitoring is not None
            and api.monitoring.model_type == "classification"):
        try:
            local_cache["class_set"] = api.get_cached_classes()
        except:
            logger().warn("an error occurred while attempting to load classes",
                          exc_info=True)

    app.add_api_route(local_cache["predict_route"], predict, methods=["POST"])
    app.add_api_route(local_cache["predict_route"],
                      get_summary,
                      methods=["GET"])

    return app