예제 #1
0
파일: api.py 프로젝트: wbj0110/cortex
def start(args):
    """Populate ``local_cache`` for one API and start serving it.

    Steps, in order:
      1. Build the ``Context`` from ``args.context``, ``args.cache_dir`` and
         ``args.workload_id`` and install the context's Python packages.
      2. Look up the API (``args.api``) and its model, set the TensorFlow
         logging verbosity, and store ctx/api/model in ``local_cache``.
      3. Download and extract the model zip when ``args.model_dir`` is not
         an existing directory.
      4. Cache the transformer implementation (and populated args, when the
         column declares any) for every transformed feature/target column.
      5. Create the prediction stub on ``localhost:args.tf_serve_port`` and
         cache the model's base input columns.
      6. Poll ``run_get_model_metadata`` once per second, up to 600 attempts;
         exit with status 1 if the last attempt still fails.
      7. Serve ``app`` on ``args.port``.
    """
    ctx = Context(
        s3_path=args.context,
        cache_dir=args.cache_dir,
        workload_id=args.workload_id,
    )
    package.install_packages(ctx.python_packages, ctx.bucket)

    api = ctx.apis_id_map[args.api]
    model = ctx.models[api["model_name"]]
    tf_lib.set_logging_verbosity(ctx.environment["log_level"]["tensorflow"])

    local_cache["ctx"] = ctx
    local_cache["api"] = api
    local_cache["model"] = model

    model_dir = args.model_dir
    if not os.path.isdir(model_dir):
        aws.download_and_extract_zip(model["key"], model_dir, ctx.bucket)

    columns_to_check = model["feature_columns"] + [model["target_column"]]
    for column_name in columns_to_check:
        # Only transformed columns need an implementation cached.
        if not ctx.is_transformed_column(column_name):
            continue

        impl, _ = ctx.get_transformer_impl(column_name)
        local_cache["trans_impls"][column_name] = impl

        args_schema = ctx.transformed_columns[column_name]["inputs"]["args"]
        # cache aggregates and constants in memory
        if args_schema is not None:
            populated = ctx.populate_args(args_schema)
            local_cache["transform_args_cache"][column_name] = populated

    channel = implementations.insecure_channel("localhost", args.tf_serve_port)
    stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)
    local_cache["stub"] = stub

    required_inputs = tf_lib.get_base_input_columns(model["name"], ctx)
    local_cache["required_inputs"] = required_inputs

    # wait a bit for tf serving to start before querying metadata
    max_attempts = 600
    for attempt in range(1, max_attempts + 1):
        try:
            local_cache["metadata"] = run_get_model_metadata()
        except Exception:
            # Failures are tolerated until the final attempt.
            if attempt == max_attempts:
                logger.exception(
                    "An error occurred, see `cx logs api {}` for more details."
                    .format(api["name"]))
                sys.exit(1)
        else:
            break

        time.sleep(1)

    logger.info("Serving model: {}".format(model["name"]))
    listen_address = "*:{}".format(args.port)
    serve(app, listen=listen_address)
예제 #2
0
파일: train.py 프로젝트: wbj0110/cortex
def train(args):
    """Train the model identified by ``args.model`` and upload its export.

    Builds the ``Context`` from ``args``, installs the context's Python
    packages, then works inside a temporary directory under
    ``ctx.cache_dir``: marks the resource as started, runs
    ``train_util.train``, marks the resource as succeeded, zips
    ``<model_dir>/export/estimator`` and uploads the zip to ``model["key"]``
    in ``ctx.bucket``.

    Note that the success status is uploaded before the archive is zipped
    and uploaded; a failure in those later steps then uploads a failed
    status from the exception handlers.

    Any exception raised inside the training/caching section marks the
    resource as failed, is logged, and terminates the process with exit
    status 1.
    """
    ctx = Context(
        s3_path=args.context,
        cache_dir=args.cache_dir,
        workload_id=args.workload_id,
    )

    package.install_packages(ctx.python_packages, ctx.bucket)

    model = ctx.models_id_map[args.model]

    logger.info("Training")

    with util.Tempdir(ctx.cache_dir) as workspace:
        training_dir = os.path.join(workspace, "model_dir")
        ctx.upload_resource_status_start(model)

        try:
            impl = ctx.get_model_impl(model["name"])
            train_util.train(model["name"], impl, ctx, training_dir)
            ctx.upload_resource_status_success(model)

            logger.info("Caching")
            logger.info("Caching model " + model["name"])
            export_dir = os.path.join(training_dir, "export", "estimator")
            archive_path = os.path.join(workspace, "model.zip")
            util.zip_dir(export_dir, archive_path)

            aws.upload_file_to_s3(
                local_path=archive_path,
                key=model["key"],
                bucket=ctx.bucket,
            )
            util.log_job_finished(ctx.workload_id)

        except CortexException as cortex_err:
            ctx.upload_resource_status_failed(model)
            # Project errors are wrapped and logged on their own line
            # before the generic hint with the traceback.
            cortex_err.wrap("error")
            logger.error(str(cortex_err))
            hint = (
                "An error occurred, see `cx logs model {}` for more details."
                .format(model["name"]))
            logger.exception(hint)
            sys.exit(1)
        except Exception:
            ctx.upload_resource_status_failed(model)
            hint = (
                "An error occurred, see `cx logs model {}` for more details."
                .format(model["name"]))
            logger.exception(hint)
            sys.exit(1)
예제 #3
0
파일: api.py 프로젝트: washingtonm/cortex
def start(args):
    """Load the API's model into an inference session and serve it.

    Builds the ``Context`` from ``args``, stores the API and context in
    ``local_cache``, and, when the API has a ``request_handler_impl_key``,
    installs the context's Python packages and caches the request handler
    implementation. The model file is downloaded to
    ``<args.model_dir>/<args.api>`` unless that path already exists. The
    session and its input/output metadata are cached in ``local_cache``
    before ``app`` is served on ``args.port``.
    """
    ctx = Context(
        s3_path=args.context,
        cache_dir=args.cache_dir,
        workload_id=args.workload_id,
    )
    api = ctx.apis_id_map[args.api]

    local_cache["api"] = api
    local_cache["ctx"] = ctx

    has_handler = api.get("request_handler_impl_key") is not None
    if has_handler:
        # Packages are installed before the handler implementation is loaded.
        package.install_packages(ctx.python_packages, ctx.storage)
        handler = ctx.get_request_handler_impl(api["name"])
        local_cache["request_handler"] = handler

    model_path = os.path.join(args.model_dir, args.api)
    already_cached = os.path.exists(model_path)
    if not already_cached:
        ctx.storage.download_file_external(api["model"], model_path)

    session = rt.InferenceSession(model_path)
    local_cache["sess"] = session
    local_cache["input_metadata"] = session.get_inputs()
    local_cache["output_metadata"] = session.get_outputs()

    served_model = util.remove_resource_ref(api["model"])
    logger.info("Serving model: {}".format(served_model))

    listen_address = "*:{}".format(args.port)
    serve(app, listen=listen_address)
예제 #4
0
파일: api.py 프로젝트: washingtonm/cortex
def start(args):
    """Populate ``local_cache`` for one API and start serving it.

    Reads ``args.context``, ``args.cache_dir``, ``args.workload_id``,
    ``args.api``, ``args.model_dir``, ``args.tf_serve_port`` and
    ``args.port``.

    Two model sources are handled:

    * ``api["model"]`` is not a resource reference: the external model is
      downloaded and unzipped into ``args.model_dir`` when that directory
      does not exist yet.
    * ``api["model"]`` is a resource reference: the referenced model, its
      estimator, target column information, transformer implementations,
      required inputs and (when present) the populated target vocabulary
      are cached in ``local_cache``, and the model is downloaded from the
      context's storage when ``args.model_dir`` does not exist yet.

    Python packages are installed *before* the request handler
    implementation is loaded, so that anything the handler imports is
    already available.

    The process exits with status 1 when ``validate_model_dir`` raises or
    when the model metadata cannot be fetched after 300 attempts (one
    second apart).
    """
    ctx = Context(s3_path=args.context,
                  cache_dir=args.cache_dir,
                  workload_id=args.workload_id)

    api = ctx.apis_id_map[args.api]
    local_cache["api"] = api
    local_cache["ctx"] = ctx

    model_is_resource_ref = util.is_resource_ref(api["model"])
    has_handler_impl = api.get("request_handler_impl_key") is not None
    has_request_handler = (api.get("request_handler") is not None
                           or has_handler_impl)

    # Install packages first: previously the handler implementation was
    # loaded before installation, so a handler importing a user package
    # could presumably fail to load.
    if model_is_resource_ref or has_request_handler:
        package.install_packages(ctx.python_packages, ctx.storage)

    if has_handler_impl:
        local_cache["request_handler"] = ctx.get_request_handler_impl(
            api["name"])

    if not model_is_resource_ref:
        if not os.path.isdir(args.model_dir):
            ctx.storage.download_and_unzip_external(api["model"],
                                                    args.model_dir)
    else:
        model_name = util.get_resource_ref(api["model"])
        model = ctx.models[model_name]
        estimator = ctx.estimators[model["estimator"]]

        local_cache["model"] = model
        local_cache["estimator"] = estimator
        local_cache["target_col"] = ctx.columns[util.get_resource_ref(
            model["target_column"])]
        local_cache["target_col_type"] = ctx.get_inferred_column_type(
            util.get_resource_ref(model["target_column"]))

        # Fall back to DEBUG when the environment does not configure a
        # TensorFlow log level.
        log_level = "DEBUG"
        if ctx.environment is not None and ctx.environment.get(
                "log_level") is not None:
            log_level = ctx.environment["log_level"].get("tensorflow", "DEBUG")
        tf_lib.set_logging_verbosity(log_level)

        if not os.path.isdir(args.model_dir):
            ctx.storage.download_and_unzip(model["key"], args.model_dir)

        for column_name in ctx.extract_column_names(
            [model["input"], model["target_column"]]):
            if ctx.is_transformed_column(column_name):
                trans_impl, _ = ctx.get_transformer_impl(column_name)
                local_cache["trans_impls"][column_name] = trans_impl
                transformed_column = ctx.transformed_columns[column_name]

                # cache aggregate values (the return value is discarded;
                # presumably ctx.get_obj memoizes the object -- TODO confirm)
                for resource_name in util.extract_resource_refs(
                        transformed_column["input"]):
                    if resource_name in ctx.aggregates:
                        ctx.get_obj(ctx.aggregates[resource_name]["key"])

        local_cache["required_inputs"] = tf_lib.get_base_input_columns(
            model["name"], ctx)

        if util.is_dict(model["input"]) and model["input"].get(
                "target_vocab") is not None:
            local_cache["target_vocab_populated"] = ctx.populate_values(
                model["input"]["target_vocab"], None, False)

    try:
        validate_model_dir(args.model_dir)
    except Exception as e:
        logger.exception(e)
        sys.exit(1)

    channel = grpc.insecure_channel("localhost:" + str(args.tf_serve_port))
    local_cache["stub"] = prediction_service_pb2_grpc.PredictionServiceStub(
        channel)

    # wait a bit for tf serving to start before querying metadata
    limit = 300
    for i in range(limit):
        try:
            local_cache["metadata"] = run_get_model_metadata()
            break
        except Exception:
            # Failures are tolerated until the final attempt.
            if i == limit - 1:
                logger.exception(
                    "An error occurred, see `cortex logs -v api {}` for more details."
                    .format(api["name"]))
                sys.exit(1)

        time.sleep(1)

    logger.info("Serving model: {}".format(
        util.remove_resource_ref(api["model"])))
    serve(app, listen="*:{}".format(args.port))