Example #1
def renew_message_visibility(receipt_handle: str):
    queue_url = local_cache["job_spec"]["sqs_url"]
    interval = MESSAGE_RENEWAL_PERIOD
    new_timeout = INITIAL_MESSAGE_VISIBILITY
    cur_time = time.time()

    while True:
        time.sleep(max(0, (cur_time + interval) - time.time()))  # clamp to avoid a negative sleep if the loop fell behind
        cur_time += interval
        new_timeout += interval

        with receipt_handle_mutex:
            if receipt_handle in stop_renewal:
                stop_renewal.remove(receipt_handle)
                break

            try:
                local_cache["sqs_client"].change_message_visibility(
                    QueueUrl=queue_url,
                    ReceiptHandle=receipt_handle,
                    VisibilityTimeout=new_timeout)
            except botocore.exceptions.ClientError as e:
                if e.response["Error"]["Code"] == "InvalidParameterValue":
                    # unexpected; this error is thrown when attempting to renew a message that has been deleted
                    continue
                elif e.response["Error"][
                        "Code"] == "AWS.SimpleQueueService.NonExistentQueue":
                    # there may be a delay between the cron may deleting the queue and this worker stopping
                    logger().info(
                        "failed to renew message visibility because the queue was not found"
                    )
                else:
                    stop_renewal.remove(receipt_handle)
                    raise e
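
The renewal helper leans on module-level state that it shares with the SQS polling loop (see Example #6): a mutex-guarded set of receipt handles whose renewal should stop, plus two visibility constants. A minimal sketch of that assumed shared state; the names are taken from the snippets, but the concrete values are assumptions:

import threading

# shared between sqs_loop() and renew_message_visibility()
receipt_handle_mutex = threading.Lock()
stop_renewal = set()  # receipt handles whose renewal thread should stop

# visibility settings; the values here are illustrative assumptions
INITIAL_MESSAGE_VISIBILITY = 30  # seconds granted when a message is first received
MESSAGE_RENEWAL_PERIOD = 15  # seconds between visibility extensions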
Example #2
def handle_on_job_complete(message):
    job_spec = local_cache["job_spec"]
    predictor_impl = local_cache["predictor_impl"]
    sqs_client = local_cache["sqs_client"]
    queue_url = job_spec["sqs_url"]
    receipt_handle = message["ReceiptHandle"]

    should_run_on_job_complete = False
    try:
        while True:
            visible_messages, invisible_messages = get_total_messages_in_queue()
            total_messages = visible_messages + invisible_messages
            if total_messages > 1:
                new_message_id = uuid.uuid4()
                time.sleep(JOB_COMPLETE_MESSAGE_RENEWAL)
                sqs_client.send_message(
                    QueueUrl=queue_url,
                    MessageBody='"job_complete"',
                    MessageAttributes={
                        "job_complete": {
                            "StringValue": "true",
                            "DataType": "String"
                        },
                        "api_name": {
                            "StringValue": job_spec["api_name"],
                            "DataType": "String"
                        },
                        "job_id": {
                            "StringValue": job_spec["job_id"],
                            "DataType": "String"
                        },
                    },
                    MessageDeduplicationId=str(new_message_id),
                    MessageGroupId=str(new_message_id),
                )
                break
            else:
                if should_run_on_job_complete:
                    if getattr(predictor_impl, "on_job_complete", None):
                        logger().info("executing on_job_complete")
                        predictor_impl.on_job_complete()
                    break
                should_run_on_job_complete = True
            time.sleep(10)  # verify that the queue is empty one more time
    except:
        logger().exception("failed to handle on_job_complete")
        raise
    finally:
        with receipt_handle_mutex:
            stop_renewal.add(receipt_handle)
            sqs_client.delete_message(QueueUrl=queue_url,
                                      ReceiptHandle=receipt_handle)
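
Both branches above rely on get_total_messages_in_queue(), which is not part of these snippets. A plausible sketch, assuming it wraps SQS GetQueueAttributes and returns the approximate visible and in-flight counts:

def get_total_messages_in_queue():
    sqs_client = local_cache["sqs_client"]
    queue_url = local_cache["job_spec"]["sqs_url"]
    attributes = sqs_client.get_queue_attributes(
        QueueUrl=queue_url,
        AttributeNames=[
            "ApproximateNumberOfMessages",
            "ApproximateNumberOfMessagesNotVisible",
        ],
    )["Attributes"]
    visible_count = int(attributes["ApproximateNumberOfMessages"])
    not_visible_count = int(attributes["ApproximateNumberOfMessagesNotVisible"])
    return visible_count, not_visible_count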
Example #3
    def garbage_collect(
            self,
            exclude_disk_model_ids: List[str] = [],
            dry_run: bool = False) -> Tuple[bool, List[str], List[str]]:
        """
        Removes stale in-memory and on-disk models based on LRU policy.
        Also calls the "remove" callback before removing the models from this object. The callback must not raise any exceptions.

        Must be called with a write lock unless dry_run is set to true.

        Args:
            exclude_disk_model_ids: Model IDs to exclude from removing from disk. Necessary for locally-provided models.
            dry_run: Just test if there are any models to remove. If set to true, this method can then be called with a read lock.

        Returns:
            A 3-element tuple. First element tells whether models had to be collected. The 2nd and 3rd elements contain the model IDs that were removed from memory and disk respectively.
        """
        collected = False
        if self._mem_cache_size <= 0 or self._disk_cache_size <= 0:
            return collected, [], []

        stale_mem_model_ids = self._lru_model_ids(self._mem_cache_size,
                                                  filter_in_mem=True)
        stale_disk_model_ids = self._lru_model_ids(self._disk_cache_size -
                                                   len(exclude_disk_model_ids),
                                                   filter_in_mem=False)

        if self._remove_callback and not dry_run:
            self._remove_callback(stale_mem_model_ids)

        # don't delete excluded model IDs from disk
        stale_disk_model_ids = list(
            set(stale_disk_model_ids) - set(exclude_disk_model_ids))
        stale_disk_model_ids = stale_disk_model_ids[len(stale_disk_model_ids) -
                                                    self._disk_cache_size:]

        if not dry_run:
            logger().info(
                f"unloading models {stale_mem_model_ids} from memory using the garbage collector"
            )
            logger().info(
                f"unloading models {stale_disk_model_ids} from disk using the garbage collector"
            )
            for model_id in stale_mem_model_ids:
                self.remove_model_by_id(model_id, mem=True, disk=False)
            for model_id in stale_disk_model_ids:
                self.remove_model_by_id(model_id, mem=False, disk=True)

        if len(stale_mem_model_ids) > 0 or len(stale_disk_model_ids) > 0:
            collected = True

        return collected, stale_mem_model_ids, stale_disk_model_ids
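
Because dry_run only inspects the LRU state, the method can be probed before taking the write lock. A usage sketch under those assumptions; lock acquisition is omitted, and models_holder and local_model_ids are hypothetical names for the cache instance and the locally-provided model IDs:

# probe while holding only a read lock: does anything need to be evicted?
needs_gc, _, _ = models_holder.garbage_collect(dry_run=True)

if needs_gc:
    # re-run for real while holding the write lock
    collected, removed_mem, removed_disk = models_holder.garbage_collect(
        exclude_disk_model_ids=local_model_ids)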
Example #4
def start():
    while not pathlib.Path("/mnt/workspace/init_script_run.txt").is_file():
        time.sleep(0.2)

    cache_dir = os.environ["CORTEX_CACHE_DIR"]
    provider = os.environ["CORTEX_PROVIDER"]
    api_spec_path = os.environ["CORTEX_API_SPEC"]
    job_spec_path = os.environ["CORTEX_JOB_SPEC"]
    project_dir = os.environ["CORTEX_PROJECT_DIR"]

    model_dir = os.getenv("CORTEX_MODEL_DIR")
    tf_serving_port = os.getenv("CORTEX_TF_BASE_SERVING_PORT", "9000")
    tf_serving_host = os.getenv("CORTEX_TF_SERVING_HOST", "localhost")

    region = os.getenv("AWS_REGION")

    has_multiple_servers = os.getenv("CORTEX_MULTIPLE_TF_SERVERS")
    if has_multiple_servers:
        with LockedFile("/run/used_ports.json", "r+") as f:
            used_ports = json.load(f)
            for port in used_ports.keys():
                if not used_ports[port]:
                    tf_serving_port = port
                    used_ports[port] = True
                    break
            f.seek(0)
            json.dump(used_ports, f)
            f.truncate()

    api = get_api(provider, api_spec_path, model_dir, cache_dir, region)
    storage, api_spec = get_spec(provider, api_spec_path, cache_dir, region)
    job_spec = get_job_spec(storage, cache_dir, job_spec_path)

    client = api.predictor.initialize_client(tf_serving_host=tf_serving_host,
                                             tf_serving_port=tf_serving_port)
    logger().info("loading the predictor from {}".format(api.predictor.path))
    predictor_impl = api.predictor.initialize_impl(project_dir, client,
                                                   job_spec)

    local_cache["api_spec"] = api
    local_cache["provider"] = provider
    local_cache["job_spec"] = job_spec
    local_cache["predictor_impl"] = predictor_impl
    local_cache["predict_fn_args"] = inspect.getfullargspec(
        predictor_impl.predict).args
    local_cache["sqs_client"] = boto3.client("sqs", region_name=region)

    open("/mnt/workspace/api_readiness.txt", "a").close()

    logger().info("polling for batches...")
    sqs_loop()
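
When CORTEX_MULTIPLE_TF_SERVERS is set, /run/used_ports.json acts as a small port registry: a JSON object mapping each TensorFlow Serving port to a boolean that records whether a worker has claimed it. An illustrative view of the file, plus a sketch of how a claimed port could be handed back; the release step is an assumption that simply mirrors the claim logic above:

# /run/used_ports.json, illustrative contents:
# {"9000": true, "9001": false, "9002": false}

with LockedFile("/run/used_ports.json", "r+") as f:
    used_ports = json.load(f)
    used_ports[str(tf_serving_port)] = False  # hand the port back
    f.seek(0)
    json.dump(used_ports, f)
    f.truncate()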
Example #5
def predict(request: Request):
    tasks = BackgroundTasks()
    api = local_cache["api"]
    predictor_impl = local_cache["predictor_impl"]
    dynamic_batcher = local_cache["dynamic_batcher"]
    kwargs = build_predict_kwargs(request)

    if dynamic_batcher:
        prediction = dynamic_batcher.predict(**kwargs)
    else:
        prediction = predictor_impl.predict(**kwargs)

    if isinstance(prediction, bytes):
        response = Response(content=prediction,
                            media_type="application/octet-stream")
    elif isinstance(prediction, str):
        response = Response(content=prediction, media_type="text/plain")
    elif isinstance(prediction, Response):
        response = prediction
    else:
        try:
            json_string = json.dumps(prediction)
        except Exception as e:
            raise UserRuntimeException(
                str(e),
                "please return an object that is JSON serializable (including its nested fields), a bytes object, "
                "a string, or a starlette.response.Response object",
            ) from e
        response = Response(content=json_string, media_type="application/json")

    if local_cache["provider"] != "local" and api.monitoring is not None:
        try:
            predicted_value = api.monitoring.extract_predicted_value(
                prediction)
            api.post_monitoring_metrics(predicted_value)
            if (api.monitoring.model_type == "classification"
                    and predicted_value not in local_cache["class_set"]):
                tasks.add_task(api.upload_class, class_name=predicted_value)
                local_cache["class_set"].add(predicted_value)
        except:
            logger().warn("unable to record prediction metric", exc_info=True)

    if util.has_method(predictor_impl, "post_predict"):
        kwargs = build_post_predict_kwargs(prediction, request)
        request_thread_pool.submit(predictor_impl.post_predict, **kwargs)

    if len(tasks.tasks) > 0:
        response.background = tasks

    return response
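
predict() above only assumes that the loaded predictor exposes a predict() method and, optionally, post_predict(). A minimal sketch of such an implementation; the class name and argument names are assumptions, and only the two method names come from the snippet:

class PythonPredictor:
    def __init__(self, config):
        # load model weights, tokenizers, etc.
        self.config = config

    def predict(self, payload):
        # may return bytes, a str, a Response, or any JSON-serializable object
        return {"echo": payload}

    def post_predict(self, response, payload):
        # submitted to the request thread pool after the response has been built
        pass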
Example #6
def sqs_loop():
    job_spec = local_cache["job_spec"]
    api_spec = local_cache["api_spec"]
    predictor_impl = local_cache["predictor_impl"]
    sqs_client = local_cache["sqs_client"]

    queue_url = job_spec["sqs_url"]

    no_messages_found_in_previous_iteration = False

    while True:
        response = sqs_client.receive_message(
            QueueUrl=queue_url,
            MaxNumberOfMessages=1,
            WaitTimeSeconds=SQS_POLL_WAIT_TIME,
            VisibilityTimeout=INITIAL_MESSAGE_VISIBILITY,
            MessageAttributeNames=["All"],
        )

        if response.get("Messages") is None or len(response["Messages"]) == 0:
            visible_messages, invisible_messages = get_total_messages_in_queue()
            if visible_messages + invisible_messages == 0:
                if no_messages_found_in_previous_iteration:
                    logger().info("no batches left in queue, exiting...")
                    return
                no_messages_found_in_previous_iteration = True

            time.sleep(MESSAGE_NOT_FOUND_SLEEP)
            continue

        no_messages_found_in_previous_iteration = False
        message = response["Messages"][0]
        receipt_handle = message["ReceiptHandle"]

        renewer = threading.Thread(target=renew_message_visibility,
                                   args=(receipt_handle, ),
                                   daemon=True)
        renewer.start()

        if is_on_job_complete(message):
            handle_on_job_complete(message)
        else:
            handle_batch_message(message)
Example #7
def handle_on_complete(message):
    job_spec = local_cache["job_spec"]
    predictor_impl = local_cache["predictor_impl"]
    sqs_client = local_cache["sqs_client"]
    queue_url = job_spec["sqs_url"]
    receipt_handle = message["ReceiptHandle"]

    try:
        if not getattr(predictor_impl, "on_job_complete", None):
            sqs_client.delete_message(QueueUrl=queue_url, ReceiptHandle=receipt_handle)
            return True

        should_run_on_job_complete = False

        while True:
            visible_count, not_visible_count = get_total_messages_in_queue()

            # if there are other messages that are visible, release this message and get the other ones (should rarely happen for FIFO)
            if visible_count > 0:
                sqs_client.change_message_visibility(
                    QueueUrl=queue_url, ReceiptHandle=receipt_handle, VisibilityTimeout=0
                )
                return False

            if should_run_on_job_complete:
                # double check that the queue is still empty (except for the job_complete message)
                if not_visible_count <= 1:
                    logger().info("executing on_job_complete")
                    predictor_impl.on_job_complete()
                    sqs_client.delete_message(QueueUrl=queue_url, ReceiptHandle=receipt_handle)
                    return True
                else:
                    should_run_on_job_complete = False

            if not_visible_count <= 1:
                should_run_on_job_complete = True

            time.sleep(20)
    except:
        sqs_client.delete_message(QueueUrl=queue_url, ReceiptHandle=receipt_handle)
        raise
Example #8
    def post_metrics(self, metrics):
        try:
            if self.statsd is None:
                raise CortexException(
                    "statsd client not initialized")  # unexpected

            for metric in metrics:
                tags = [
                    "{}:{}".format(dim["Name"], dim["Value"])
                    for dim in metric["Dimensions"]
                ]
                if metric.get("Unit") == "Count":
                    self.statsd.increment(metric["MetricName"],
                                          value=metric["Value"],
                                          tags=tags)
                else:
                    self.statsd.histogram(metric["MetricName"],
                                          value=metric["Value"],
                                          tags=tags)
        except:
            logger().warn("failure encountered while publishing metrics",
                          exc_info=True)
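
post_metrics expects CloudWatch-style metric dictionaries: each entry carries MetricName, Dimensions, Value and, optionally, a Unit, which decides whether the metric is emitted as a statsd counter or a histogram. An illustrative payload (the metric and dimension names are made up):

metrics = [
    {
        "MetricName": "BatchesSucceeded",
        "Dimensions": [{"Name": "APIName", "Value": "my-api"}],
        "Value": 1,
        "Unit": "Count",  # routed to statsd.increment
    },
    {
        "MetricName": "TimePerBatch",
        "Dimensions": [{"Name": "APIName", "Value": "my-api"}],
        "Value": 3.7,  # no "Count" unit, so routed to statsd.histogram
    },
]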
Example #9
    def _remove_models(self, model_ids: List[str]) -> None:
        """
        Remove models from TFS.
        Must only be used when caching is enabled.
        """
        logger().info(f"unloading models with model IDs {model_ids} from TFS")

        models = {}
        for model_id in model_ids:
            model_name, model_version = model_id.rsplit("-", maxsplit=1)
            if model_name not in models:
                models[model_name] = [model_version]
            else:
                models[model_name].append(model_version)

        model_names = []
        model_versions = []
        for model_name, versions in models.items():
            model_names.append(model_name)
            model_versions.append(versions)

        self._client.remove_models(model_names, model_versions)
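
Model IDs have the form '<name>-<version>'; splitting on the last dash keeps model names that themselves contain dashes intact. The grouping in isolation:

models = {}
for model_id in ["resnet-50-1", "resnet-50-2", "bert-3"]:
    name, version = model_id.rsplit("-", maxsplit=1)
    models.setdefault(name, []).append(version)

print(models)  # {'resnet-50': ['1', '2'], 'bert': ['3']}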
Example #10
def handle_batch_message(message):
    job_spec = local_cache["job_spec"]
    predictor_impl = local_cache["predictor_impl"]
    sqs_client = local_cache["sqs_client"]
    queue_url = job_spec["sqs_url"]
    receipt_handle = message["ReceiptHandle"]
    api_spec = local_cache["api_spec"]

    start_time = time.time()

    try:
        logger().info(f"processing batch {message['MessageId']}")
        payload = json.loads(message["Body"])
        batch_id = message["MessageId"]
        predictor_impl.predict(**build_predict_args(payload, batch_id))

        api_spec.post_metrics([
            success_counter_metric(),
            time_per_batch_metric(time.time() - start_time)
        ])
    except:
        api_spec.post_metrics([failed_counter_metric()])
        logger().exception(f"failed processing batch {message['MessageId']}")
        with receipt_handle_mutex:
            stop_renewal.add(receipt_handle)
            if job_spec.get("sqs_dead_letter_queue") is not None:
                sqs_client.change_message_visibility(  # return message
                    QueueUrl=queue_url,
                    ReceiptHandle=receipt_handle,
                    VisibilityTimeout=0)
            else:
                sqs_client.delete_message(QueueUrl=queue_url,
                                          ReceiptHandle=receipt_handle)
    else:
        with receipt_handle_mutex:
            stop_renewal.add(receipt_handle)
            sqs_client.delete_message(QueueUrl=queue_url,
                                      ReceiptHandle=receipt_handle)
Example #11
def start(args):
    download_config = json.loads(base64.urlsafe_b64decode(args.download))
    for download_arg in download_config["download_args"]:
        from_path = download_arg["from"]
        to_path = download_arg["to"]
        item_name = download_arg.get("item_name", "")

        if from_path.startswith("s3://"):
            bucket_name, prefix = S3.deconstruct_s3_path(from_path)
            client = S3(bucket_name, client_config={})
        elif from_path.startswith("gs://"):
            bucket_name, prefix = GCS.deconstruct_gcs_path(from_path)
            client = GCS(bucket_name)
        else:
            raise ValueError(
                'the "from" download arg must start with either the "s3://" or "gs://" prefix'
            )

        if item_name != "":
            if download_arg.get("hide_from_log", False):
                logger().info("downloading {}".format(item_name))
            else:
                logger().info("downloading {} from {}".format(
                    item_name, from_path))

        if download_arg.get("to_file", False):
            client.download_file(prefix, to_path)
        else:
            client.download(prefix, to_path)

        if download_arg.get("unzip", False):
            if item_name != "" and not download_arg.get(
                    "hide_unzipping_log", False):
                logger().info("unzipping {}".format(item_name))
            if download_arg.get("to_file", False):
                util.extract_zip(to_path, delete_zip_file=True)
            else:
                util.extract_zip(os.path.join(to_path,
                                              os.path.basename(from_path)),
                                 delete_zip_file=True)

    if download_config.get("last_log", "") != "":
        logger().info(download_config["last_log"])
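
The download argument (args.download) is a base64-encoded JSON document. Reconstructing its shape from the keys read above, with illustrative values:

import base64
import json

download_config = {
    "last_log": "downloads completed",
    "download_args": [
        {
            "from": "s3://my-bucket/project.zip",  # or a gs:// path
            "to": "/mnt/project",
            "item_name": "the project code",
            "to_file": True,   # use download_file() instead of download()
            "unzip": True,     # extract the archive and delete it afterwards
            "hide_from_log": False,
            "hide_unzipping_log": False,
        }
    ],
}
# encoded for the process argument:
encoded = base64.urlsafe_b64encode(json.dumps(download_config).encode())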
Example #12
def model_downloader(
    predictor_type: PredictorType,
    bucket_provider: str,
    bucket_name: str,
    model_name: str,
    model_version: str,
    model_path: str,
    temp_dir: str,
    model_dir: str,
) -> Optional[datetime.datetime]:
    """
    Downloads model to disk. Validates the cloud model path and the downloaded model as well.

    Args:
        predictor_type: The predictor type as implemented by the API.
        bucket_provider: Provider for the bucket. Can be "s3" or "gs".
        bucket_name: Name of the bucket where the model is stored.
        model_name: Name of the model. Is part of the model's local path.
        model_version: Version of the model. Is part of the model's local path.
        model_path: Model prefix of the versioned model.
        temp_dir: Where to temporarily store the model for validation.
        model_dir: The top directory of where all models are stored locally.

    Returns:
        The model's timestamp. None if the model didn't pass the validation, if it doesn't exist or if there are not enough permissions.
    """

    logger().info(
        f"downloading from bucket {bucket_name}/{model_path}, model {model_name} of version {model_version}, temporarily to {temp_dir} and then finally to {model_dir}"
    )

    if bucket_provider == "s3":
        client = S3(bucket_name)
    if bucket_provider == "gs":
        client = GCS(bucket_name)

    # validate upstream cloud model
    sub_paths, ts = client.search(model_path)
    try:
        validate_model_paths(sub_paths, predictor_type, model_path)
    except CortexException:
        logger().info(
            f"failed validating model {model_name} of version {model_version}")
        return None

    # download model to temp dir
    temp_dest = os.path.join(temp_dir, model_name, model_version)
    try:
        client.download_dir_contents(model_path, temp_dest)
    except CortexException:
        logger().info(
            f"failed downloading model {model_name} of version {model_version} to temp dir {temp_dest}"
        )
        shutil.rmtree(temp_dest)
        return None

    # validate model
    model_contents = glob.glob(temp_dest + "*/**", recursive=True)
    model_contents = util.remove_non_empty_directory_paths(model_contents)
    try:
        validate_model_paths(model_contents, predictor_type, temp_dest)
    except CortexException:
        logger().info(
            f"failed validating model {model_name} of version {model_version} from temp dir"
        )
        shutil.rmtree(temp_dest)
        return None

    # move model to dest dir
    model_top_dir = os.path.join(model_dir, model_name)
    ondisk_model_version = os.path.join(model_top_dir, model_version)
    logger().info(
        f"moving model {model_name} of version {model_version} to final dir {ondisk_model_version}"
    )
    if os.path.isdir(ondisk_model_version):
        shutil.rmtree(ondisk_model_version)
    shutil.move(temp_dest, ondisk_model_version)

    return max(ts)
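
A usage sketch with illustrative argument values; the predictor type constant is an assumption about how PredictorType is exposed:

ts = model_downloader(
    predictor_type=PredictorType.TENSORFLOW,  # assumed member name
    bucket_provider="s3",
    bucket_name="my-models-bucket",
    model_name="resnet50",
    model_version="2",
    model_path="models/resnet50/2/",
    temp_dir="/tmp/model-cache",
    model_dir="/mnt/model",
)
if ts is None:
    logger().info("model was skipped (missing, inaccessible, or failed validation)")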
Example #13
    def _run_inference(self, model_input: Any, model_name: str,
                       model_version: str) -> dict:
        """
        When processes_per_replica = 1 and caching is enabled, check/load the model and make the prediction.
        When processes_per_replica > 0 and caching is disabled, attempt to make the prediction regardless.

        Args:
            model_input: Input to the model.
            model_name: Name of the model, as specified in predictor:models:paths or, if not specified there, as named on disk.
            model_version: Version of the model, as it's found on disk. Can also infer the version number from the "latest" version tag.

        Returns:
            The prediction.
        """

        model = None
        tag = ""
        if model_version == "latest":
            tag = model_version

        if not self._caching_enabled:

            # determine model version
            if tag == "latest":
                versions = self._client.poll_available_model_versions(
                    model_name)
                if len(versions) == 0:
                    raise UserException(
                        f"model '{model_name}' accessed with tag {tag} couldn't be found"
                    )
                model_version = str(max(map(lambda x: int(x), versions)))
            model_id = model_name + "-" + model_version

            return self._client.predict(model_input, model_name, model_version)

        if not self._multiple_processes and self._caching_enabled:

            # determine model version
            try:
                if tag == "latest":
                    model_version = self._get_latest_model_version_from_tree(
                        model_name, self._models_tree.model_info(model_name))
            except ValueError:
                # if model_name hasn't been found
                raise UserRuntimeException(
                    f"'{model_name}' model of tag {tag} wasn't found in the list of available models"
                )

            models_stats = []
            for model_id in self._models.get_model_ids():
                models_stats.append(self._models.has_model_id(model_id))

            # grab shared access to model tree
            available_model = True
            logger().info(
                f"grabbing access to model {model_name} of version {model_version}"
            )
            with LockedModelsTree(self._models_tree, "r", model_name,
                                  model_version):

                # check if the versioned model exists
                model_id = model_name + "-" + model_version
                if model_id not in self._models_tree:
                    available_model = False
                    logger().info(
                        f"model {model_name} of version {model_version} is not available"
                    )
                    raise WithBreak

                # retrieve model tree's metadata
                upstream_model = self._models_tree[model_id]
                current_upstream_ts = int(
                    upstream_model["timestamp"].timestamp())
                logger().info(
                    f"model {model_name} of version {model_version} is available"
                )

            if not available_model:
                if tag == "":
                    raise UserException(
                        f"model '{model_name}' of version '{model_version}' couldn't be found"
                    )
                raise UserException(
                    f"model '{model_name}' accessed with tag '{tag}' couldn't be found"
                )

            # grab shared access to models holder and retrieve model
            update_model = False
            prediction = None
            tfs_was_unresponsive = False
            with LockedModel(self._models, "r", model_name, model_version):
                logger().info(
                    f"checking the {model_name} {model_version} status")
                status, local_ts = self._models.has_model(
                    model_name, model_version)
                if status in ["not-available", "on-disk"
                              ] or (status != "not-available"
                                    and local_ts != current_upstream_ts):
                    logger().info(
                        f"model {model_name} of version {model_version} is not loaded (with status {status} or different timestamp)"
                    )
                    update_model = True
                    raise WithBreak

                # run prediction
                logger().info(
                    f"run the prediction on model {model_name} of version {model_version}"
                )
                self._models.get_model(model_name, model_version, tag)
                try:
                    prediction = self._client.predict(model_input, model_name,
                                                      model_version)
                except grpc.RpcError as e:
                    # effectively when it got restarted
                    if len(
                            self._client.poll_available_model_versions(
                                model_name)) > 0:
                        raise
                    tfs_was_unresponsive = True

            # remove model from disk and memory references if TFS gets unresponsive
            if tfs_was_unresponsive:
                with LockedModel(self._models, "w", model_name, model_version):
                    available_versions = self._client.poll_available_model_versions(
                        model_name)
                    status, _ = self._models.has_model(model_name,
                                                       model_version)
                    if not (status == "in-memory"
                            and model_version not in available_versions):
                        raise WithBreak

                    logger().info(
                        f"removing model {model_name} of version {model_version} because TFS got unresponsive"
                    )
                    self._models.remove_model(model_name, model_version)

            # download, load into memory the model and retrieve it
            if update_model:
                # grab exclusive access to models holder
                with LockedModel(self._models, "w", model_name, model_version):

                    # check model status
                    status, local_ts = self._models.has_model(
                        model_name, model_version)

                    # refresh disk model
                    if status == "not-available" or (
                            status in ["on-disk", "in-memory"]
                            and local_ts != current_upstream_ts):
                        # unload model from TFS
                        if status == "in-memory":
                            try:
                                logger().info(
                                    f"unloading model {model_name} of version {model_version} from TFS"
                                )
                                self._models.unload_model(
                                    model_name, model_version)
                            except Exception:
                                logger().info(
                                    f"failed unloading model {model_name} of version {model_version} from TFS"
                                )
                                raise

                        # remove model from disk and references
                        if status in ["on-disk", "in-memory"]:
                            logger().info(
                                f"removing model references from memory and from disk for model {model_name} of version {model_version}"
                            )
                            self._models.remove_model(model_name,
                                                      model_version)

                        # download model
                        if model_name not in self._spec_models.get_local_model_names(
                        ):
                            logger().info(
                                f"downloading model {model_name} of version {model_version} from the {upstream_model['provider']} upstream"
                            )
                            date = self._models.download_model(
                                upstream_model["provider"],
                                upstream_model["bucket"],
                                model_name,
                                model_version,
                                upstream_model["path"],
                            )
                            if not date:
                                raise WithBreak
                            current_upstream_ts = int(date.timestamp())

                    # load model
                    try:
                        logger().info(
                            f"loading model {model_name} of version {model_version} into memory"
                        )
                        self._models.load_model(
                            model_name,
                            model_version,
                            current_upstream_ts,
                            [tag],
                            kwargs={
                                "model_name": model_name,
                                "model_version": model_version,
                                "signature_key": self._determine_model_signature_key(model_name),
                            },
                        )
                    except Exception as e:
                        raise UserRuntimeException(
                            f"failed (re-)loading model {model_name} of version {model_version} (thread {td.get_ident()})",
                            str(e),
                        )

                    # run prediction
                    self._models.get_model(model_name, model_version, tag)
                    prediction = self._client.predict(model_input, model_name,
                                                      model_version)

            return prediction
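
Resolving the "latest" tag, as done in the non-caching branch above, amounts to picking the numerically largest version string. In isolation:

versions = ["1", "3", "10"]  # e.g. from poll_available_model_versions(model_name)
latest = str(max(map(int, versions)))  # "10"; a lexicographic max would wrongly pick "3"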
Example #14
    def _extract_signatures(self, signature_def, signature_key,
                            model_name: str, model_version: str):
        logger().info(
            "signature defs found in model '{}' for version '{}': {}".format(
                model_name, model_version, signature_def))

        available_keys = list(signature_def.keys())
        if len(available_keys) == 0:
            raise UserException(
                "unable to find signature defs in model '{}' of version '{}'".
                format(model_name, model_version))

        if signature_key is None:
            if len(available_keys) == 1:
                logger().info(
                    "signature_key was not configured by user, using signature key '{}' for model '{}' of version '{}' (found in the signature def map)"
                    .format(
                        available_keys[0],
                        model_name,
                        model_version,
                    ))
                signature_key = available_keys[0]
            elif "predict" in signature_def:
                logger().info(
                    "signature_key was not configured by user, using signature key 'predict' for model '{}' of version '{}' (found in the signature def map)"
                    .format(
                        model_name,
                        model_version,
                    ))
                signature_key = "predict"
            else:
                raise UserException(
                    "signature_key was not configured by user, please specify one the following keys '{}' for model '{}' of version '{}' (found in the signature def map)"
                    .format(", ".join(available_keys), model_name,
                            model_version))
        else:
            if signature_def.get(signature_key) is None:
                possibilities_str = "key: '{}'".format(available_keys[0])
                if len(available_keys) > 1:
                    possibilities_str = "keys: '{}'".format(
                        "', '".join(available_keys))

                raise UserException(
                    "signature_key '{}' was not found in signature def map for model '{}' of version '{}', but found the following {}"
                    .format(signature_key, model_name, model_version,
                            possibilities_str))

        signature_def_val = signature_def.get(signature_key)

        if signature_def_val.get("inputs") is None:
            raise UserException(
                "unable to find 'inputs' in signature def '{}' for model '{}'".
                format(signature_key, model_name))

        parsed_signatures = {}
        for input_name, input_metadata in signature_def_val["inputs"].items():
            if input_metadata["tensorShape"] == {}:
                # a scalar with rank 0 and empty shape
                shape = "scalar"
            elif input_metadata["tensorShape"].get("unknownRank", False):
                # unknown rank and shape
                #
                # unknownRank is set to True when the model input has no rank;
                # checking only for the key's presence (rather than its value) could be misleading,
                # so the value is tested as well
                shape = "unknown"
            elif input_metadata["tensorShape"].get("dim", None):
                # known rank and known/unknown shape
                shape = [
                    int(dim["size"])
                    for dim in input_metadata["tensorShape"]["dim"]
                ]
            else:
                raise UserException(
                    "invalid 'tensorShape' specification for input '{}' in signature key '{}' for model '{}'"
                    .format(input_name, signature_key, model_name))

            parsed_signatures[input_name] = {
                "shape": shape if type(shape) == list else [shape],
                "type": DTYPE_TO_TF_TYPE[input_metadata["dtype"]].name,
            }
        return signature_key, parsed_signatures
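
The signature def being parsed follows the JSON form of TensorFlow Serving's model metadata. An illustrative value; the DT_FLOAT dtype string is an assumption about the DTYPE_TO_TF_TYPE keys. With signature_key=None this resolves to "serving_default" and parses into shape [-1, 224, 224, 3] (and, assuming DT_FLOAT maps to tf.float32, type "float32"):

signature_def = {
    "serving_default": {
        "inputs": {
            "images": {
                "dtype": "DT_FLOAT",
                "tensorShape": {
                    "dim": [
                        {"size": "-1"},
                        {"size": "224"},
                        {"size": "224"},
                        {"size": "3"},
                    ]
                },
            }
        }
        # "outputs" and "methodName" omitted for brevity
    }
}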
Example #15
def start_fn():
    provider = os.environ["CORTEX_PROVIDER"]
    project_dir = os.environ["CORTEX_PROJECT_DIR"]
    spec_path = os.environ["CORTEX_API_SPEC"]

    model_dir = os.getenv("CORTEX_MODEL_DIR")
    cache_dir = os.getenv("CORTEX_CACHE_DIR")
    region = os.getenv("AWS_REGION")

    tf_serving_port = os.getenv("CORTEX_TF_BASE_SERVING_PORT", "9000")
    tf_serving_host = os.getenv("CORTEX_TF_SERVING_HOST", "localhost")

    has_multiple_servers = os.getenv("CORTEX_MULTIPLE_TF_SERVERS")
    if has_multiple_servers:
        with LockedFile("/run/used_ports.json", "r+") as f:
            used_ports = json.load(f)
            for port in used_ports.keys():
                if not used_ports[port]:
                    tf_serving_port = port
                    used_ports[port] = True
                    break
            f.seek(0)
            json.dump(used_ports, f)
            f.truncate()

    try:
        api = get_api(provider, spec_path, model_dir, cache_dir, region)

        client = api.predictor.initialize_client(
            tf_serving_host=tf_serving_host, tf_serving_port=tf_serving_port)

        with FileLock("/run/init_stagger.lock"):
            logger().info("loading the predictor from {}".format(
                api.predictor.path))
            predictor_impl = api.predictor.initialize_impl(project_dir, client)

        local_cache["api"] = api
        local_cache["provider"] = provider
        local_cache["client"] = client
        local_cache["predictor_impl"] = predictor_impl
        local_cache["predict_fn_args"] = inspect.getfullargspec(
            predictor_impl.predict).args

        if api.server_side_batching_enabled:
            dynamic_batching_config = api.api_spec["predictor"][
                "server_side_batching"]
            local_cache["dynamic_batcher"] = DynamicBatcher(
                predictor_impl,
                max_batch_size=dynamic_batching_config["max_batch_size"],
                batch_interval=dynamic_batching_config["batch_interval"] /
                NANOSECONDS_IN_SECOND,  # convert nanoseconds to seconds
            )

        if util.has_method(predictor_impl, "post_predict"):
            local_cache["post_predict_fn_args"] = inspect.getfullargspec(
                predictor_impl.post_predict).args

        predict_route = "/"
        if provider != "local":
            predict_route = "/predict"
        local_cache["predict_route"] = predict_route
    except:
        logger().exception("failed to start api")
        sys.exit(1)

    if (provider != "local" and api.monitoring is not None
            and api.monitoring.model_type == "classification"):
        try:
            local_cache["class_set"] = api.get_cached_classes()
        except:
            logger().warn("an error occurred while attempting to load classes",
                          exc_info=True)

    app.add_api_route(local_cache["predict_route"], predict, methods=["POST"])
    app.add_api_route(local_cache["predict_route"],
                      get_summary,
                      methods=["GET"])

    return app
Example #16
def sqs_loop():
    job_spec = local_cache["job_spec"]
    api_spec = local_cache["api_spec"]
    predictor_impl = local_cache["predictor_impl"]
    sqs_client = local_cache["sqs_client"]

    queue_url = job_spec["sqs_url"]

    no_messages_found_in_previous_iteration = False

    while True:
        response = sqs_client.receive_message(
            QueueUrl=queue_url,
            MaxNumberOfMessages=1,
            WaitTimeSeconds=10,
            VisibilityTimeout=MAXIMUM_MESSAGE_VISIBILITY,
            MessageAttributeNames=["All"],
        )

        if response.get("Messages") is None or len(response["Messages"]) == 0:
            if no_messages_found_in_previous_iteration:
                logger().info("no batches left in queue, exiting...")
                return
            else:
                no_messages_found_in_previous_iteration = True
                continue
        else:
            no_messages_found_in_previous_iteration = False

        message = response["Messages"][0]

        receipt_handle = message["ReceiptHandle"]

        if "MessageAttributes" in message and "job_complete" in message["MessageAttributes"]:
            handled_on_complete = handle_on_complete(message)
            if handled_on_complete:
                logger().info("no batches left in queue, job has been completed")
                return
            else:
                # sometimes on_job_complete message will be released if there are other messages still to be processed
                continue

        try:
            logger().info(f"processing batch {message['MessageId']}")

            start_time = time.time()

            payload = json.loads(message["Body"])
            batch_id = message["MessageId"]
            predictor_impl.predict(**build_predict_args(payload, batch_id))

            api_spec.post_metrics(
                [success_counter_metric(), time_per_batch_metric(time.time() - start_time)]
            )
        except Exception:
            api_spec.post_metrics(
                [failed_counter_metric(), time_per_batch_metric(time.time() - start_time)]
            )
            logger().exception("failed to process batch")
        finally:
            sqs_client.delete_message(QueueUrl=queue_url, ReceiptHandle=receipt_handle)
Example #17
    def _get_model(self, model_name: str, model_version: str) -> Any:
        """
        Checks if versioned model is on disk, then checks if model is in memory,
        and if not, it loads it into memory, and returns the model.

        Args:
            model_name: Name of the model, as specified in predictor:models:paths or, if not specified there, as named on disk.
            model_version: Version of the model, as it's found on disk. Can also infer the version number from the "latest" version tag.

        Exceptions:
            RuntimeError: if another thread tried to load the model at the very same time.

        Returns:
            The model as returned by self._load_model method.
            None if the model wasn't found or if it didn't pass the validation.
        """

        model = None
        tag = ""
        if model_version == "latest":
            tag = model_version

        if not self._caching_enabled:
            # determine model version
            if tag == "latest":
                model_version = self._get_latest_model_version_from_disk(model_name)
            model_id = model_name + "-" + model_version

            # grab shared access to versioned model
            resource = os.path.join(self._lock_dir, model_id + ".txt")
            with LockedFile(resource, "r", reader_lock=True) as f:

                # check model status
                file_status = f.read()
                if file_status == "" or file_status == "not-available":
                    raise WithBreak

                current_upstream_ts = int(file_status.split(" ")[1])
                update_model = False

                # grab shared access to models holder and retrieve model
                with LockedModel(self._models, "r", model_name, model_version):
                    status, local_ts = self._models.has_model(model_name, model_version)
                    if status == "not-available" or (
                        status == "in-memory" and local_ts != current_upstream_ts
                    ):
                        update_model = True
                        raise WithBreak
                    model, _ = self._models.get_model(model_name, model_version, tag)

                # load model into memory and retrieve it
                if update_model:
                    with LockedModel(self._models, "w", model_name, model_version):
                        status, _ = self._models.has_model(model_name, model_version)
                        if status == "not-available" or (
                            status == "in-memory" and local_ts != current_upstream_ts
                        ):
                            if status == "not-available":
                                logger().info(
                                    f"loading model {model_name} of version {model_version} (thread {td.get_ident()})"
                                )
                            else:
                                logger().info(
                                    f"reloading model {model_name} of version {model_version} (thread {td.get_ident()})"
                                )
                            try:
                                self._models.load_model(
                                    model_name,
                                    model_version,
                                    current_upstream_ts,
                                    [tag],
                                )
                            except Exception as e:
                                raise UserRuntimeException(
                                    f"failed (re-)loading model {model_name} of version {model_version} (thread {td.get_ident()})",
                                    str(e),
                                )
                        model, _ = self._models.get_model(model_name, model_version, tag)

        if not self._multiple_processes and self._caching_enabled:
            # determine model version
            try:
                if tag == "latest":
                    model_version = self._get_latest_model_version_from_tree(
                        model_name, self._models_tree.model_info(model_name)
                    )
            except ValueError:
                # if model_name hasn't been found
                raise UserRuntimeException(
                    f"'{model_name}' model of tag {tag} wasn't found in the list of available models"
                )

            # grab shared access to model tree
            available_model = True
            with LockedModelsTree(self._models_tree, "r", model_name, model_version):

                # check if the versioned model exists
                model_id = model_name + "-" + model_version
                if model_id not in self._models_tree:
                    available_model = False
                    raise WithBreak

                # retrieve model tree's metadata
                upstream_model = self._models_tree[model_id]
                current_upstream_ts = int(upstream_model["timestamp"].timestamp())

            if not available_model:
                return None

            # grab shared access to models holder and retrieve model
            update_model = False
            with LockedModel(self._models, "r", model_name, model_version):
                status, local_ts = self._models.has_model(model_name, model_version)
                if status in ["not-available", "on-disk"] or (
                    status != "not-available"
                    and local_ts != current_upstream_ts
                    and not (status == "in-memory" and model_name in self._spec_local_model_names)
                ):
                    update_model = True
                    raise WithBreak
                model, _ = self._models.get_model(model_name, model_version, tag)

            # download, load into memory the model and retrieve it
            if update_model:
                # grab exclusive access to models holder
                with LockedModel(self._models, "w", model_name, model_version):

                    # check model status
                    status, local_ts = self._models.has_model(model_name, model_version)

                    # refresh disk model
                    if model_name not in self._spec_local_model_names and (
                        status == "not-available"
                        or (status in ["on-disk", "in-memory"] and local_ts != current_upstream_ts)
                    ):
                        if status == "not-available":
                            logger().info(
                                f"model {model_name} of version {model_version} not found locally; continuing with the download..."
                            )
                        elif status == "on-disk":
                            logger().info(
                                f"found newer model {model_name} of vesion {model_version} on the {upstream_model['provider']} upstream than the one on the disk"
                            )
                        else:
                            logger().info(
                                f"found newer model {model_name} of vesion {model_version} on the {upstream_model['provider']} upstream than the one loaded into memory"
                            )

                        # remove model from disk and memory
                        if status == "on-disk":
                            logger().info(
                                f"removing model from disk for model {model_name} of version {model_version}"
                            )
                            self._models.remove_model(model_name, model_version)
                        if status == "in-memory":
                            logger().info(
                                f"removing model from disk and memory for model {model_name} of version {model_version}"
                            )
                            self._models.remove_model(model_name, model_version)

                        # download model
                        logger().info(
                            f"downloading model {model_name} of version {model_version} from the {upstream_model['provider']} upstream"
                        )
                        date = self._models.download_model(
                            upstream_model["provider"],
                            upstream_model["bucket"],
                            model_name,
                            model_version,
                            upstream_model["path"],
                        )
                        if not date:
                            raise WithBreak
                        current_upstream_ts = int(date.timestamp())

                    # give the local model a timestamp initialized at start time
                    if model_name in self._spec_local_model_names:
                        current_upstream_ts = self._local_model_ts

                    # load model
                    try:
                        logger().info(
                            f"loading model {model_name} of version {model_version} into memory"
                        )
                        self._models.load_model(
                            model_name,
                            model_version,
                            current_upstream_ts,
                            [tag],
                        )
                    except Exception as e:
                        raise UserRuntimeException(
                            f"failed (re-)loading model {model_name} of version {model_version} (thread {td.get_ident()})",
                            str(e),
                        )

                    # retrieve model
                    model, _ = self._models.get_model(model_name, model_version, tag)

        return model