Example #1
    def _handle_message(self, message, callback_fn, failure_callback_fn):
        receipt_handle = message["ReceiptHandle"]

        try:
            callback_fn(message)
        except Exception as err:
            if not isinstance(err, UserRuntimeException):
                capture_exception(err)

            failure_callback_fn(message)

            with self.receipt_handle_mutex:
                self.stop_renewal.add(receipt_handle)
                if self.dead_letter_queue_url is not None:
                    self.sqs_client.change_message_visibility(  # return message
                        QueueUrl=self.queue_url,
                        ReceiptHandle=receipt_handle,
                        VisibilityTimeout=0)
                else:
                    self.sqs_client.delete_message(
                        QueueUrl=self.queue_url, ReceiptHandle=receipt_handle)
        else:
            with self.receipt_handle_mutex:
                self.stop_renewal.add(receipt_handle)
                self.sqs_client.delete_message(QueueUrl=self.queue_url,
                                               ReceiptHandle=receipt_handle)
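For reference, a minimal standalone sketch of the same success/failure pattern using boto3 directly; the queue_url and process callback are placeholders and not part of the original handler.

import boto3

def handle(sqs_client, queue_url, message, process):
    receipt_handle = message["ReceiptHandle"]
    try:
        process(message)
    except Exception:
        # On failure, make the message visible again immediately so another
        # consumer (or a dead-letter redrive policy) can pick it up.
        sqs_client.change_message_visibility(
            QueueUrl=queue_url, ReceiptHandle=receipt_handle, VisibilityTimeout=0)
    else:
        # On success, remove the message from the queue.
        sqs_client.delete_message(QueueUrl=queue_url, ReceiptHandle=receipt_handle)

# sqs_client = boto3.client("sqs", region_name="us-east-1")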
Example #2
def main():
    address = sys.argv[1]
    threads_per_process = int(os.environ["CORTEX_THREADS_PER_PROCESS"])

    try:
        config = init()
    except Exception as err:
        if not isinstance(err, UserRuntimeException):
            capture_exception(err)
        logger.exception("failed to start api")
        sys.exit(1)

    module_proto_pb2 = config["module_proto_pb2"]
    module_proto_pb2_grpc = config["module_proto_pb2_grpc"]
    PredictorServicer = config["predictor_servicer"]

    api = config["api"]
    predictor_impl = config["predictor_impl"]
    predict_fn_args = config["predict_fn_args"]

    server = grpc.server(
        ThreadPoolExecutorWithRequestMonitor(
            post_latency_metrics_fn=api.post_latency_request_metrics,
            max_workers=threads_per_process,
        ),
        options=[("grpc.max_send_message_length", -1),
                 ("grpc.max_receive_message_length", -1)],
    )

    add_PredictorServicer_to_server = get_servicer_to_server_from_module(
        module_proto_pb2_grpc)
    add_PredictorServicer_to_server(
        PredictorServicer(predict_fn_args, predictor_impl, api), server)

    service_name = get_service_name_from_module(module_proto_pb2_grpc)
    SERVICE_NAMES = (
        module_proto_pb2.DESCRIPTOR.services_by_name[service_name].full_name,
        reflection.SERVICE_NAME,
    )
    reflection.enable_server_reflection(SERVICE_NAMES, server)

    server.add_insecure_port(address)
    server.start()

    time.sleep(5.0)
    open(f"/mnt/workspace/proc-{os.getpid()}-ready.txt", "a").close()
    server.wait_for_termination()
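A minimal sketch of a client connecting to a server started this way; the address and the generated stub names are hypothetical, since the actual stubs come from the user's .proto module loaded in init().

import grpc

channel = grpc.insecure_channel("localhost:9000")
# Block until the server above starts accepting connections (raises on timeout).
grpc.channel_ready_future(channel).result(timeout=10)
# stub = my_service_pb2_grpc.PredictorStub(channel)     # hypothetical generated stub
# response = stub.Predict(my_service_pb2.Request(...))  # hypothetical RPC call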
Example #3
def handle_batch_message(message):
    job_spec = local_cache["job_spec"]
    predictor_impl = local_cache["predictor_impl"]
    sqs_client = local_cache["sqs_client"]
    queue_url = job_spec["sqs_url"]
    receipt_handle = message["ReceiptHandle"]
    api = local_cache["api"]

    start_time = time.time()

    try:
        logger.info(f"processing batch {message['MessageId']}")
        payload = json.loads(message["Body"])
        batch_id = message["MessageId"]

        try:
            predictor_impl.predict(**build_predict_args(payload, batch_id))
        except Exception as err:
            raise UserRuntimeException from err

        api.post_metrics([
            success_counter_metric(),
            time_per_batch_metric(time.time() - start_time)
        ])
    except Exception as err:
        if not isinstance(err, UserRuntimeException):
            capture_exception(err)

        api.post_metrics([failed_counter_metric()])
        logger.exception(f"failed processing batch {message['MessageId']}")
        with receipt_handle_mutex:
            stop_renewal.add(receipt_handle)
            if job_spec.get("sqs_dead_letter_queue") is not None:
                sqs_client.change_message_visibility(  # return message
                    QueueUrl=queue_url,
                    ReceiptHandle=receipt_handle,
                    VisibilityTimeout=0)
            else:
                sqs_client.delete_message(QueueUrl=queue_url,
                                          ReceiptHandle=receipt_handle)
    else:
        with receipt_handle_mutex:
            stop_renewal.add(receipt_handle)
            sqs_client.delete_message(QueueUrl=queue_url,
                                      ReceiptHandle=receipt_handle)
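For illustration, a hypothetical SQS message of the shape handle_batch_message expects; the field values are made up, and local_cache must already be populated (see the start() examples below) before the handler can run.

import json

fake_message = {
    "MessageId": "batch-0001",
    "ReceiptHandle": "opaque-receipt-handle",
    "Body": json.dumps({"item_ids": [1, 2, 3]}),
}
# handle_batch_message(fake_message)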
Example #4
def start():
    while not pathlib.Path("/mnt/workspace/init_script_run.txt").is_file():
        time.sleep(0.2)

    cache_dir = os.environ["CORTEX_CACHE_DIR"]
    api_spec_path = os.environ["CORTEX_API_SPEC"]
    job_spec_path = os.environ["CORTEX_JOB_SPEC"]
    project_dir = os.environ["CORTEX_PROJECT_DIR"]

    model_dir = os.getenv("CORTEX_MODEL_DIR")
    tf_serving_port = os.getenv("CORTEX_TF_BASE_SERVING_PORT", "9000")
    tf_serving_host = os.getenv("CORTEX_TF_SERVING_HOST", "localhost")

    region = os.getenv("AWS_REGION")

    has_multiple_servers = os.getenv("CORTEX_MULTIPLE_TF_SERVERS")
    if has_multiple_servers:
        with LockedFile("/run/used_ports.json", "r+") as f:
            used_ports = json.load(f)
            for port in used_ports.keys():
                if not used_ports[port]:
                    tf_serving_port = port
                    used_ports[port] = True
                    break
            f.seek(0)
            json.dump(used_ports, f)
            f.truncate()

    api = get_api(api_spec_path, model_dir, cache_dir)
    with open(job_spec_path) as json_file:
        job_spec = json.load(json_file)

    sqs_client = boto3.client("sqs", region_name=region)

    client = api.predictor.initialize_client(
        tf_serving_host=tf_serving_host, tf_serving_port=tf_serving_port
    )

    try:
        log.info("loading the predictor from {}".format(api.predictor.path))
        metrics_client = MetricsClient(api.statsd)
        predictor_impl = api.predictor.initialize_impl(
            project_dir=project_dir,
            client=client,
            metrics_client=metrics_client,
            job_spec=job_spec,
        )
    except UserRuntimeException as err:
        err.wrap(f"failed to start job {job_spec['job_id']}")
        log.error(str(err), exc_info=True)
        sys.exit(1)
    except Exception as err:
        capture_exception(err)
        log.error(f"failed to start job {job_spec['job_id']}", exc_info=True)
        sys.exit(1)

    # crons only stop if an unhandled exception occurs
    def check_if_crons_have_failed():
        while True:
            for cron in api.predictor.crons:
                if not cron.is_alive():
                    os.kill(os.getpid(), signal.SIGQUIT)
            time.sleep(1)

    threading.Thread(target=check_if_crons_have_failed, daemon=True).start()

    local_cache["api"] = api
    local_cache["job_spec"] = job_spec
    local_cache["predictor_impl"] = predictor_impl
    local_cache["predict_fn_args"] = inspect.getfullargspec(predictor_impl.predict).args
    local_cache["sqs_client"] = sqs_client

    open("/mnt/workspace/api_readiness.txt", "a").close()

    log.info("polling for batches...")
    try:
        sqs_handler = SQSHandler(
            sqs_client=sqs_client,
            queue_url=job_spec["sqs_url"],
            renewal_period=MESSAGE_RENEWAL_PERIOD,
            visibility_timeout=INITIAL_MESSAGE_VISIBILITY,
            not_found_sleep_time=MESSAGE_NOT_FOUND_SLEEP,
            message_wait_time=SQS_POLL_WAIT_TIME,
            dead_letter_queue_url=job_spec.get("sqs_dead_letter_queue"),
            stop_if_no_messages=True,
        )
        sqs_handler.start(
            message_fn=handle_batch_message,
            message_failure_fn=handle_batch_failure,
            on_job_complete_fn=handle_on_job_complete,
        )
    except UserRuntimeException as err:
        err.wrap(f"failed to run job {job_spec['job_id']}")
        log.error(str(err), exc_info=True)
        sys.exit(1)
    except Exception as err:
        capture_exception(err)
        log.error(f"failed to run job {job_spec['job_id']}", exc_info=True)
        sys.exit(1)
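The /run/used_ports.json step above can be read in isolation as a simple port-claiming routine. Below is a standalone sketch assuming a file like {"9000": false, "9001": false}; the real code wraps the read-modify-write in a LockedFile to serialize access across processes.

import json

def claim_port(path="/run/used_ports.json", default="9000"):
    with open(path, "r+") as f:  # the original holds a file lock around this block
        used_ports = json.load(f)
        chosen = default
        for port, in_use in used_ports.items():
            if not in_use:
                chosen = port
                used_ports[port] = True
                break
        f.seek(0)
        json.dump(used_ports, f)
        f.truncate()
    return chosen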
Example #5
def start():
    while not pathlib.Path("/mnt/workspace/init_script_run.txt").is_file():
        time.sleep(0.2)

    cache_dir = os.environ["CORTEX_CACHE_DIR"]
    provider = os.environ["CORTEX_PROVIDER"]
    api_spec_path = os.environ["CORTEX_API_SPEC"]
    job_spec_path = os.environ["CORTEX_JOB_SPEC"]
    project_dir = os.environ["CORTEX_PROJECT_DIR"]

    model_dir = os.getenv("CORTEX_MODEL_DIR")
    tf_serving_port = os.getenv("CORTEX_TF_BASE_SERVING_PORT", "9000")
    tf_serving_host = os.getenv("CORTEX_TF_SERVING_HOST", "localhost")

    region = os.getenv("AWS_REGION")

    has_multiple_servers = os.getenv("CORTEX_MULTIPLE_TF_SERVERS")
    if has_multiple_servers:
        with LockedFile("/run/used_ports.json", "r+") as f:
            used_ports = json.load(f)
            for port in used_ports.keys():
                if not used_ports[port]:
                    tf_serving_port = port
                    used_ports[port] = True
                    break
            f.seek(0)
            json.dump(used_ports, f)
            f.truncate()

    api = get_api(provider, api_spec_path, model_dir, cache_dir, region)
    storage, _ = get_spec(provider, api_spec_path, cache_dir, region)
    job_spec = get_job_spec(storage, cache_dir, job_spec_path)

    client = api.predictor.initialize_client(tf_serving_host=tf_serving_host,
                                             tf_serving_port=tf_serving_port)

    try:
        logger.info("loading the predictor from {}".format(api.predictor.path))
        metrics_client = MetricsClient(api.statsd)
        predictor_impl = api.predictor.initialize_impl(
            project_dir=project_dir,
            client=client,
            metrics_client=metrics_client,
            job_spec=job_spec,
        )
    except UserRuntimeException as err:
        err.wrap(f"failed to start job {job_spec['job_id']}")
        logger.error(str(err), exc_info=True)
        sys.exit(1)
    except Exception as err:
        capture_exception(err)
        logger.error(f"failed to start job {job_spec['job_id']}",
                     exc_info=True)
        sys.exit(1)

    # crons only stop if an unhandled exception occurs
    def check_if_crons_have_failed():
        while True:
            for cron in api.predictor.crons:
                if not cron.is_alive():
                    os.kill(os.getpid(), signal.SIGQUIT)
            time.sleep(1)

    threading.Thread(target=check_if_crons_have_failed, daemon=True).start()

    local_cache["api"] = api
    local_cache["provider"] = provider
    local_cache["job_spec"] = job_spec
    local_cache["predictor_impl"] = predictor_impl
    local_cache["predict_fn_args"] = inspect.getfullargspec(
        predictor_impl.predict).args
    local_cache["sqs_client"] = boto3.client("sqs", region_name=region)

    open("/mnt/workspace/api_readiness.txt", "a").close()

    logger.info("polling for batches...")
    try:
        sqs_loop()
    except UserRuntimeException as err:
        err.wrap(f"failed to run job {job_spec['job_id']}")
        logger.error(str(err), exc_info=True)
        sys.exit(1)
    except Exception as err:
        capture_exception(err)
        logger.error(f"failed to run job {job_spec['job_id']}", exc_info=True)
        sys.exit(1)
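The cron watchdog used in these start functions can be expressed on its own as below; crons is a placeholder for api.predictor.crons, and SIGQUIT is sent so the surrounding process manager restarts the worker.

import os
import signal
import threading
import time

def watch_crons(crons, poll_interval=1.0):
    def _loop():
        while True:
            # Any dead background cron means the process can no longer be trusted.
            if any(not cron.is_alive() for cron in crons):
                os.kill(os.getpid(), signal.SIGQUIT)
            time.sleep(poll_interval)

    threading.Thread(target=_loop, daemon=True).start()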
Example #6
def start_fn():
    provider = os.environ["CORTEX_PROVIDER"]
    project_dir = os.environ["CORTEX_PROJECT_DIR"]
    spec_path = os.environ["CORTEX_API_SPEC"]

    model_dir = os.getenv("CORTEX_MODEL_DIR")
    cache_dir = os.getenv("CORTEX_CACHE_DIR")
    region = os.getenv("AWS_REGION")

    tf_serving_port = os.getenv("CORTEX_TF_BASE_SERVING_PORT", "9000")
    tf_serving_host = os.getenv("CORTEX_TF_SERVING_HOST", "localhost")

    try:
        has_multiple_servers = os.getenv("CORTEX_MULTIPLE_TF_SERVERS")
        if has_multiple_servers:
            with LockedFile("/run/used_ports.json", "r+") as f:
                used_ports = json.load(f)
                for port in used_ports.keys():
                    if not used_ports[port]:
                        tf_serving_port = port
                        used_ports[port] = True
                        break
                f.seek(0)
                json.dump(used_ports, f)
                f.truncate()

        api = get_api(provider, spec_path, model_dir, cache_dir, region)

        client = api.predictor.initialize_client(
            tf_serving_host=tf_serving_host, tf_serving_port=tf_serving_port)

        with FileLock("/run/init_stagger.lock"):
            logger.info("loading the predictor from {}".format(
                api.predictor.path))
            predictor_impl = api.predictor.initialize_impl(project_dir, client)

        # crons only stop if an unhandled exception occurs
        def check_if_crons_have_failed():
            while True:
                for cron in api.predictor.crons:
                    if not cron.is_alive():
                        os.kill(os.getpid(), signal.SIGQUIT)
                time.sleep(1)

        threading.Thread(target=check_if_crons_have_failed,
                         daemon=True).start()

        local_cache["api"] = api
        local_cache["provider"] = provider
        local_cache["client"] = client
        local_cache["predictor_impl"] = predictor_impl
        local_cache["predict_fn_args"] = inspect.getfullargspec(
            predictor_impl.predict).args

        if api.python_server_side_batching_enabled:
            dynamic_batching_config = api.api_spec["predictor"][
                "server_side_batching"]
            local_cache["dynamic_batcher"] = DynamicBatcher(
                predictor_impl,
                max_batch_size=dynamic_batching_config["max_batch_size"],
                batch_interval=dynamic_batching_config["batch_interval"] /
                NANOSECONDS_IN_SECOND,  # convert nanoseconds to seconds
            )

        if util.has_method(predictor_impl, "post_predict"):
            local_cache["post_predict_fn_args"] = inspect.getfullargspec(
                predictor_impl.post_predict).args

        predict_route = "/predict"
        local_cache["predict_route"] = predict_route

    except Exception as err:
        if not isinstance(err, UserRuntimeException):
            capture_exception(err)
        logger.exception("failed to start api")
        sys.exit(1)

    app.add_api_route(local_cache["predict_route"], predict, methods=["POST"])
    app.add_api_route(local_cache["predict_route"],
                      get_summary,
                      methods=["GET"])

    return app
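The route registration at the end of start_fn is plain FastAPI. A minimal sketch with placeholder handlers (the real predict and get_summary are defined elsewhere in the module):

from fastapi import FastAPI

app = FastAPI()

def predict():
    return {"status": "ok"}

def get_summary():
    return {"route": "/predict", "methods": ["GET", "POST"]}

app.add_api_route("/predict", predict, methods=["POST"])
app.add_api_route("/predict", get_summary, methods=["GET"])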
Example #7
def main():
    while not pathlib.Path("/mnt/workspace/init_script_run.txt").is_file():
        time.sleep(0.2)

    model_dir = os.getenv("CORTEX_MODEL_DIR")
    cache_dir = os.environ["CORTEX_CACHE_DIR"]
    api_spec_path = os.environ["CORTEX_API_SPEC"]
    workload_path = os.environ["CORTEX_ASYNC_WORKLOAD_PATH"]
    project_dir = os.environ["CORTEX_PROJECT_DIR"]
    readiness_file = os.getenv("CORTEX_READINESS_FILE",
                               "/mnt/workspace/api_readiness.txt")
    region = os.getenv("AWS_REGION")
    queue_url = os.environ["CORTEX_QUEUE_URL"]
    statsd_host = os.getenv("HOST_IP")
    statsd_port = os.getenv("CORTEX_STATSD_PORT", "9125")
    tf_serving_host = os.getenv("CORTEX_TF_SERVING_HOST")
    tf_serving_port = os.getenv("CORTEX_TF_BASE_SERVING_PORT")

    storage, api_spec = get_spec(api_spec_path, cache_dir, region)
    sqs_client = boto3.client("sqs", region_name=region)
    api = AsyncAPI(
        api_spec=api_spec,
        storage=storage,
        storage_path=workload_path,
        statsd_host=statsd_host,
        statsd_port=int(statsd_port),
        model_dir=model_dir,
    )

    try:
        log.info(f"loading the predictor from {api.path}")
        metrics_client = MetricsClient(api.statsd)
        predictor_impl = api.initialize_impl(
            project_dir,
            metrics_client,
            tf_serving_host=tf_serving_host,
            tf_serving_port=tf_serving_port,
        )
    except UserRuntimeException as err:
        err.wrap(f"failed to initialize predictor implementation")
        log.error(str(err), exc_info=True)
        sys.exit(1)
    except Exception as err:
        capture_exception(err)
        log.error(f"failed to initialize predictor implementation",
                  exc_info=True)
        sys.exit(1)

    local_cache["api"] = api
    local_cache["predictor_impl"] = predictor_impl
    local_cache["sqs_client"] = sqs_client
    local_cache["storage_client"] = storage
    local_cache["predict_fn_args"] = inspect.getfullargspec(
        predictor_impl.predict).args

    open(readiness_file, "a").close()

    log.info("polling for workloads...")
    try:
        sqs_handler = SQSHandler(
            sqs_client=sqs_client,
            queue_url=queue_url,
            renewal_period=MESSAGE_RENEWAL_PERIOD,
            visibility_timeout=INITIAL_MESSAGE_VISIBILITY,
            not_found_sleep_time=MESSAGE_NOT_FOUND_SLEEP,
            message_wait_time=SQS_POLL_WAIT_TIME,
        )
        sqs_handler.start(message_fn=handle_workload,
                          message_failure_fn=handle_workload_failure)
    except UserRuntimeException as err:
        log.error(str(err), exc_info=True)
        sys.exit(1)
    except Exception as err:
        capture_exception(err)
        log.error(str(err), exc_info=True)
        sys.exit(1)
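A rough sketch of the long-polling loop an SQSHandler like the one configured above is assumed to run; parameter names mirror the arguments passed in, but this is not the actual implementation.

def poll(sqs_client, queue_url, message_fn, message_wait_time=10, visibility_timeout=60):
    while True:
        response = sqs_client.receive_message(
            QueueUrl=queue_url,
            MaxNumberOfMessages=1,
            WaitTimeSeconds=message_wait_time,     # long poll
            VisibilityTimeout=visibility_timeout,  # hide the message while it is processed
        )
        for message in response.get("Messages", []):
            message_fn(message)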