示例#1
0
def start(args):
    """Initialise the API process state and begin serving requests.

    Builds a ``Context`` from ``args``, installs the context's Python
    packages, stores the context, API and model in ``local_cache``,
    downloads the model if ``args.model_dir`` is not already a directory,
    caches transformer implementations (and their populated arguments) for
    every transformed feature/target column, creates the prediction service
    stub, polls for model metadata and finally calls ``serve``.

    Args:
        args: Object exposing ``context``, ``cache_dir``, ``workload_id``,
            ``api``, ``model_dir``, ``tf_serve_port`` and ``port``.

    Note:
        Calls ``sys.exit(1)`` when model metadata still cannot be fetched on
        the 600th attempt.
    """
    context = Context(
        s3_path=args.context,
        cache_dir=args.cache_dir,
        workload_id=args.workload_id,
    )
    package.install_packages(context.python_packages, context.storage)

    api_spec = context.apis_id_map[args.api]
    model_spec = context.models[api_spec["model_name"]]
    tf_lib.set_logging_verbosity(
        context.environment["log_level"]["tensorflow"])

    local_cache["ctx"] = context
    local_cache["api"] = api_spec
    local_cache["model"] = model_spec

    # Only fetch the model when the target directory does not exist yet.
    model_dir_present = os.path.isdir(args.model_dir)
    if not model_dir_present:
        context.storage.download_and_unzip(model_spec["key"], args.model_dir)

    columns_to_inspect = (
        model_spec["feature_columns"] + [model_spec["target_column"]])
    for name in columns_to_inspect:
        if not context.is_transformed_column(name):
            continue

        impl, _ = context.get_transformer_impl(name)
        local_cache["trans_impls"][name] = impl

        args_schema = context.transformed_columns[name]["inputs"]["args"]
        if args_schema is None:
            continue
        # Keep the populated transformer arguments in memory.
        populated_args = context.populate_args(args_schema)
        local_cache["transform_args_cache"][name] = populated_args

    serving_channel = implementations.insecure_channel(
        "localhost", args.tf_serve_port)
    local_cache["stub"] = (
        prediction_service_pb2.beta_create_PredictionService_stub(
            serving_channel))

    local_cache["required_inputs"] = tf_lib.get_base_input_columns(
        model_spec["name"], context)

    # Poll once per second so the model server has time to come up before
    # its metadata is queried; give up after the final attempt fails.
    max_attempts = 600
    for attempts_left in reversed(range(max_attempts)):
        try:
            local_cache["metadata"] = run_get_model_metadata()
        except Exception:
            if attempts_left == 0:
                logger.exception(
                    "An error occurred, see `cx logs api {}` for more details."
                    .format(api_spec["name"]))
                sys.exit(1)
            time.sleep(1)
        else:
            break

    logger.info("Serving model: {}".format(model_spec["name"]))
    listen_address = "*:{}".format(args.port)
    serve(app, listen=listen_address)
示例#2
0
文件: api.py 项目: washingtonm/cortex
def start(args):
    """Initialise the API process state and begin serving requests.

    Builds a ``Context`` from ``args`` and stores it and the API in
    ``local_cache``. When ``api["model"]`` is a resource reference, the
    referenced model, its estimator, target column details, transformer
    implementations and required inputs are cached and the model is
    downloaded; otherwise the external model is downloaded directly. In both
    cases the download is skipped when ``args.model_dir`` already is a
    directory. The model directory is then validated, the prediction service
    stub is created, model metadata is polled for and ``serve`` is called.

    Args:
        args: Object exposing ``context``, ``cache_dir``, ``workload_id``,
            ``api``, ``model_dir``, ``tf_serve_port`` and ``port``.

    Note:
        Calls ``sys.exit(1)`` when ``validate_model_dir`` raises, or when
        model metadata still cannot be fetched on the 300th attempt.
    """
    context = Context(
        s3_path=args.context,
        cache_dir=args.cache_dir,
        workload_id=args.workload_id,
    )

    api_spec = context.apis_id_map[args.api]
    local_cache["api"] = api_spec
    local_cache["ctx"] = context

    if api_spec.get("request_handler_impl_key") is not None:
        handler_impl = context.get_request_handler_impl(api_spec["name"])
        local_cache["request_handler"] = handler_impl

    if util.is_resource_ref(api_spec["model"]):
        package.install_packages(context.python_packages, context.storage)
        referenced_model_name = util.get_resource_ref(api_spec["model"])
        model_spec = context.models[referenced_model_name]
        estimator_spec = context.estimators[model_spec["estimator"]]

        local_cache["model"] = model_spec
        local_cache["estimator"] = estimator_spec
        target_column_ref = model_spec["target_column"]
        local_cache["target_col"] = context.columns[
            util.get_resource_ref(target_column_ref)]
        local_cache["target_col_type"] = context.get_inferred_column_type(
            util.get_resource_ref(target_column_ref))

        # Fall back to "DEBUG" unless the environment configures a level.
        environment = context.environment
        if environment is None or environment.get("log_level") is None:
            tf_log_level = "DEBUG"
        else:
            tf_log_level = environment["log_level"].get("tensorflow", "DEBUG")
        tf_lib.set_logging_verbosity(tf_log_level)

        if not os.path.isdir(args.model_dir):
            context.storage.download_and_unzip(
                model_spec["key"], args.model_dir)

        referenced_columns = context.extract_column_names(
            [model_spec["input"], model_spec["target_column"]])
        for name in referenced_columns:
            if not context.is_transformed_column(name):
                continue

            impl, _ = context.get_transformer_impl(name)
            local_cache["trans_impls"][name] = impl
            column_spec = context.transformed_columns[name]

            # Fetch every aggregate the transformed column's input refers to;
            # the returned objects are not used here.
            input_refs = util.extract_resource_refs(column_spec["input"])
            for ref_name in input_refs:
                if ref_name not in context.aggregates:
                    continue
                context.get_obj(context.aggregates[ref_name]["key"])

        local_cache["required_inputs"] = tf_lib.get_base_input_columns(
            model_spec["name"], context)

        model_input = model_spec["input"]
        if util.is_dict(model_input):
            target_vocab = model_input.get("target_vocab")
            if target_vocab is not None:
                local_cache["target_vocab_populated"] = (
                    context.populate_values(target_vocab, None, False))
    else:
        # Packages are installed here only when a request handler is set.
        if api_spec.get("request_handler") is not None:
            package.install_packages(context.python_packages, context.storage)
        if not os.path.isdir(args.model_dir):
            context.storage.download_and_unzip_external(
                api_spec["model"], args.model_dir)

    try:
        validate_model_dir(args.model_dir)
    except Exception as e:
        logger.exception(e)
        sys.exit(1)

    serving_target = "localhost:" + str(args.tf_serve_port)
    serving_channel = grpc.insecure_channel(serving_target)
    local_cache["stub"] = prediction_service_pb2_grpc.PredictionServiceStub(
        serving_channel)

    # Poll once per second so the model server has time to come up before
    # its metadata is queried; give up after the final attempt fails.
    max_attempts = 300
    for attempts_left in reversed(range(max_attempts)):
        try:
            local_cache["metadata"] = run_get_model_metadata()
        except Exception:
            if attempts_left == 0:
                logger.exception(
                    "An error occurred, see `cortex logs -v api {}` for more details."
                    .format(api_spec["name"]))
                sys.exit(1)
            time.sleep(1)
        else:
            break

    served_model_name = util.remove_resource_ref(api_spec["model"])
    logger.info("Serving model: {}".format(served_model_name))
    listen_address = "*:{}".format(args.port)
    serve(app, listen=listen_address)