Exemplo n.º 1
0
def run_predict(sample):
    """Run one sample through the optional request handler and TF Serving.

    Flow: optional ``pre_inference`` hook -> ``validate_sample`` -> (for
    Cortex-managed models) upcasting + ``transform_sample`` -> gRPC ``Predict``
    (10 s timeout) -> optional ``post_inference`` hook.

    Args:
        sample: One entry from the request's ``samples`` list.

    Returns:
        The parsed prediction, passed through ``post_inference`` when the
        request handler defines it. For Cortex-managed models (``api["model"]``
        is a resource reference) the result also carries the
        ``"transformed_sample"`` key.

    Raises:
        Whatever ``validate_sample`` raises when the prepared sample is missing
        a required input or has a value of the wrong type.
    """
    ctx = local_cache["ctx"]
    request_handler = local_cache.get("request_handler")

    logger.info("sample: " + util.pp_str_flat(sample))

    prepared_sample = sample
    if request_handler is not None and util.has_function(
            request_handler, "pre_inference"):
        prepared_sample = request_handler.pre_inference(
            sample, local_cache["metadata"]["signatureDef"])
        logger.info("pre_inference: " + util.pp_str_flat(prepared_sample))

    # Validate what is actually sent to the model. Validating the raw `sample`
    # would reject inputs that pre_inference legitimately reshapes, and would
    # let through inputs that pre_inference breaks (which then fail later with
    # a bare KeyError instead of a validation error).
    validate_sample(prepared_sample)

    if util.is_resource_ref(local_cache["api"]["model"]):
        # Upcast into a shallow copy so the caller's sample dict is not
        # modified as a side effect of running a prediction.
        prepared_sample = dict(prepared_sample)
        for column in local_cache["required_inputs"]:
            column_type = ctx.get_inferred_column_type(column["name"])
            prepared_sample[column["name"]] = util.upcast(
                prepared_sample[column["name"]], column_type)

        transformed_sample = transform_sample(prepared_sample)
        logger.info("transformed_sample: " +
                    util.pp_str_flat(transformed_sample))

        prediction_request = create_prediction_request(transformed_sample)
        response_proto = local_cache["stub"].Predict(prediction_request,
                                                     timeout=10.0)
        result = parse_response_proto(response_proto)

        result["transformed_sample"] = transformed_sample
        logger.info("inference: " + util.pp_str_flat(result))
    else:
        # External model: the prepared sample is forwarded without transforms.
        prediction_request = create_raw_prediction_request(prepared_sample)
        response_proto = local_cache["stub"].Predict(prediction_request,
                                                     timeout=10.0)
        result = parse_response_proto_raw(response_proto)

        logger.info("inference: " + util.pp_str_flat(result))

    if request_handler is not None and util.has_function(
            request_handler, "post_inference"):
        result = request_handler.post_inference(
            result, local_cache["metadata"]["signatureDef"])
        logger.info("post_inference: " + util.pp_str_flat(result))

    return result
Exemplo n.º 2
0
def validate_sample(sample):
    """Check that ``sample`` provides every input the model needs.

    For Cortex-managed models (``api["model"]`` is a resource reference), each
    column in ``local_cache["required_inputs"]`` must be present in ``sample``
    and pass ``util.CORTEX_TYPE_TO_VALIDATOR`` for its inferred column type.
    For external models, every input name returned by ``extract_signature()``
    must be present; values are not type-checked.

    Raises:
        UserException: if a required key is missing or, for Cortex-managed
            models, a value fails its type validator.
    """
    api = local_cache["api"]
    if util.is_resource_ref(api["model"]):
        ctx = local_cache["ctx"]
        for column in local_cache["required_inputs"]:
            column_name = column["name"]
            if column_name not in sample:
                raise UserException('missing key "{}"'.format(column_name))
            column_type = ctx.get_inferred_column_type(column_name)
            validator = util.CORTEX_TYPE_TO_VALIDATOR[column_type]
            if not validator(sample[column_name]):
                raise UserException('key "{}"'.format(column_name),
                                    "expected type " + column_type)
    else:
        # Only presence is checked here, so the per-input signature metadata
        # is not needed; iterating the mapping yields the input names.
        for input_name in extract_signature():
            if input_name not in sample:
                raise UserException('missing key "{}"'.format(input_name))
Exemplo n.º 3
0
def run_predict(sample):
    """Send one sample to TF Serving and return the prediction.

    The request handler's ``pre_inference`` hook (when defined) prepares the
    sample, and its ``post_inference`` hook (when defined) post-processes the
    parsed response. For Cortex-managed models (``api["model"]`` is a resource
    reference) the prepared sample is transformed first, and the transformed
    values are attached to the result under ``"transformed_sample"``. The input
    and prediction are logged before the post-processing hook runs.
    """
    handler = local_cache.get("request_handler")

    def handler_has(hook_name):
        # A hook runs only when a handler is loaded and implements it.
        return handler is not None and util.has_function(handler, hook_name)

    model_input = sample
    if handler_has("pre_inference"):
        model_input = handler.pre_inference(
            sample, local_cache["metadata"]["signatureDef"])

    managed_model = util.is_resource_ref(local_cache["api"]["model"])
    if managed_model:
        transformed = transform_sample(model_input)
        grpc_request = create_prediction_request(transformed)
        grpc_response = local_cache["stub"].Predict(grpc_request,
                                                    timeout=10.0)
        result = parse_response_proto(grpc_response)
        log_sections = (
            ("Raw sample:", sample),
            ("Transformed sample:", transformed),
            ("Prediction:", result),
        )
    else:
        grpc_request = create_raw_prediction_request(model_input)
        grpc_response = local_cache["stub"].Predict(grpc_request,
                                                    timeout=10.0)
        result = parse_response_proto_raw(grpc_response)
        log_sections = (
            ("Sample:", sample),
            ("Prediction:", result),
        )

    for heading, value in log_sections:
        util.log_indent(heading, indent=4)
        util.log_pretty_flat(value, indent=6)

    if managed_model:
        # Attached only after logging, so the logged prediction excludes it.
        result["transformed_sample"] = transformed

    if handler_has("post_inference"):
        result = handler.post_inference(
            result, local_cache["metadata"]["signatureDef"])

    return result
Exemplo n.º 4
0
def create_transformer_inputs_from_map(input, col_value_map):
    """Recursively substitute resource references using ``col_value_map``.

    A string that is a resource reference is replaced by
    ``col_value_map[<resolved resource name>]``; any other string is returned
    unchanged. Lists and dicts are rebuilt with every element (and, for dicts,
    every key as well as every value) substituted the same way. All other
    values are returned as-is.

    Raises:
        KeyError: (for a plain dict ``col_value_map``) when a referenced
            resource name has no entry.
    """
    def substitute(node):
        return create_transformer_inputs_from_map(node, col_value_map)

    if util.is_str(input):
        if not util.is_resource_ref(input):
            return input
        return col_value_map[util.get_resource_ref(input)]

    if util.is_list(input):
        return [substitute(item) for item in input]

    if util.is_dict(input):
        # A tuple generator guarantees the key is processed before its value
        # on every Python version.
        return dict((substitute(key), substitute(val))
                    for key, val in input.items())

    return input
Exemplo n.º 5
0
def start(args):
    """Initialize the TF Serving proxy API and serve HTTP requests.

    Steps:
      1. Load the deployment ``Context`` and cache the API config and context
         in ``local_cache``.
      2. Install the user's Python packages (always for Cortex-managed models,
         only when a request handler is configured for external models), then
         load the request handler implementation if one is configured.
      3. Download/unzip the model into ``args.model_dir`` unless that directory
         already exists. For Cortex-managed models, also cache the model,
         estimator, target column/type and transformer implementations, fetch
         referenced aggregate values, compute the required base input columns,
         and populate ``target_vocab`` when the model input defines it.
      4. Validate the model directory, create a gRPC stub to TF Serving on
         ``localhost:<args.tf_serve_port>`` and poll model metadata (up to 300
         attempts, sleeping 1 s between failures) until TF Serving answers.
      5. Serve ``app`` on ``*:<args.port>``.

    Any failure during setup is logged and the process exits with status 1.
    """
    ctx = Context(s3_path=args.context,
                  cache_dir=args.cache_dir,
                  workload_id=args.workload_id)

    api = ctx.apis_id_map[args.api]
    local_cache["api"] = api
    local_cache["ctx"] = ctx

    try:
        is_cortex_model = util.is_resource_ref(api["model"])

        # Install packages BEFORE loading the request handler. The handler is
        # a user module (its hooks are probed with util.has_function), and
        # loading it presumably runs its imports, which may depend on these
        # packages. Previously the handler was loaded first.
        if is_cortex_model or api.get("request_handler") is not None:
            package.install_packages(ctx.python_packages, ctx.storage)

        if api.get("request_handler_impl_key") is not None:
            local_cache["request_handler"] = ctx.get_request_handler_impl(
                api["name"])

        if not is_cortex_model:
            if not os.path.isdir(args.model_dir):
                ctx.storage.download_and_unzip_external(
                    api["model"], args.model_dir)
        else:
            model_name = util.get_resource_ref(api["model"])
            model = ctx.models[model_name]
            estimator = ctx.estimators[model["estimator"]]

            local_cache["model"] = model
            local_cache["estimator"] = estimator
            target_col_name = util.get_resource_ref(model["target_column"])
            local_cache["target_col"] = ctx.columns[target_col_name]
            local_cache["target_col_type"] = ctx.get_inferred_column_type(
                target_col_name)

            # TensorFlow verbosity comes from environment.log_level.tensorflow,
            # defaulting to DEBUG when unset.
            log_level = "DEBUG"
            if ctx.environment is not None and ctx.environment.get(
                    "log_level") is not None:
                log_level = ctx.environment["log_level"].get(
                    "tensorflow", "DEBUG")
            tf_lib.set_logging_verbosity(log_level)

            # Reuse the model if args.model_dir already exists.
            if not os.path.isdir(args.model_dir):
                ctx.storage.download_and_unzip(model["key"], args.model_dir)

            for column_name in ctx.extract_column_names(
                    [model["input"], model["target_column"]]):
                if ctx.is_transformed_column(column_name):
                    trans_impl, _ = ctx.get_transformer_impl(column_name)
                    # NOTE(review): assumes local_cache["trans_impls"] is
                    # initialized to a dict elsewhere in the module.
                    local_cache["trans_impls"][column_name] = trans_impl
                    transformed_column = ctx.transformed_columns[column_name]

                    # cache aggregate values
                    for resource_name in util.extract_resource_refs(
                            transformed_column["input"]):
                        if resource_name in ctx.aggregates:
                            ctx.get_obj(ctx.aggregates[resource_name]["key"])

            local_cache["required_inputs"] = tf_lib.get_base_input_columns(
                model["name"], ctx)

            if util.is_dict(model["input"]) and model["input"].get(
                    "target_vocab") is not None:
                local_cache["target_vocab_populated"] = ctx.populate_values(
                    model["input"]["target_vocab"], None, False)
    except CortexException as e:
        e.wrap("error")
        logger.error(str(e))
        logger.exception(
            "An error occurred, see `cortex logs -v api {}` for more details.".
            format(api["name"]))
        sys.exit(1)
    except Exception:
        logger.exception(
            "An error occurred, see `cortex logs -v api {}` for more details.".
            format(api["name"]))
        sys.exit(1)

    try:
        validate_model_dir(args.model_dir)
    except Exception as e:
        logger.exception(e)
        sys.exit(1)

    channel = grpc.insecure_channel("localhost:" + str(args.tf_serve_port))
    local_cache["stub"] = prediction_service_pb2_grpc.PredictionServiceStub(
        channel)

    # wait a bit for tf serving to start before querying metadata
    limit = 300
    for i in range(limit):
        try:
            local_cache["metadata"] = run_get_model_metadata()
            break
        except Exception:
            # Expected while TF Serving is still starting; only the final
            # failed attempt is reported.
            if i == limit - 1:
                logger.exception(
                    "An error occurred, see `cortex logs -v api {}` for more details."
                    .format(api["name"]))
                sys.exit(1)

        time.sleep(1)

    serve(app, listen="*:{}".format(args.port))
Exemplo n.º 6
0
    def populate_values(self, input, input_schema, preserve_column_refs):
        """Resolve resource references in ``input`` and validate it against a schema.

        A string that is a resource reference is resolved as follows:
          * constant  -> its ``value`` (or the JSON loaded from its ``path``),
            recursively populated against ``input_schema``;
          * aggregate -> the object stored under the aggregate's ``key``,
            recursively populated;
          * column    -> the reference itself (``preserve_column_refs``) or the
            bare column name, after checking the inferred column type is in
            ``input_schema["_type"]`` (when a schema is given).

        Lists and maps are processed element-wise. When ``input_schema`` is
        given, list/map shapes and ``_min_count``/``_max_count`` are enforced,
        fixed maps get ``_optional``/``_default`` handling, and scalars are cast
        with ``cast_compound_type``. With ``input_schema=None``, values are
        resolved but not validated.

        Args:
            input: The raw value to populate (scalar, list, dict or None).
            input_schema: Schema dict with a ``_type`` entry and optional
                ``_allow_null``/``_min_count``/``_max_count``/``_optional``/
                ``_default`` entries, or None to skip validation.
            preserve_column_refs: Keep column references as-is instead of
                replacing them with the bare column name.

        Returns:
            The populated (and, with a schema, validated/cast) value.

        Raises:
            UserException: on a disallowed null, a type mismatch, a list/map
                length outside its bounds, or a missing required map key.
                CortexExceptions raised for nested values are wrapped with the
                constant/aggregate name, list index or map key they came from.
        """
        if input is None:
            if input_schema is None:
                return None
            if input_schema.get("_allow_null") == True:
                return None
            raise UserException("Null value is not allowed")

        if util.is_resource_ref(input):
            res_name = util.get_resource_ref(input)
            if res_name in self.constants:
                constant = self.constants[res_name]
                # An inline `value` wins; otherwise load the JSON at `path`.
                # A constant with neither now resolves to None (handled by the
                # null rules above) instead of raising UnboundLocalError.
                const_val = constant.get("value")
                if const_val is None and constant.get("path") is not None:
                    const_val = self.storage.get_json_external(constant["path"])
                try:
                    return self.populate_values(const_val, input_schema, preserve_column_refs)
                except CortexException as e:
                    e.wrap("constant " + res_name)
                    raise

            if res_name in self.aggregates:
                agg_val = self.get_obj(self.aggregates[res_name]["key"])
                try:
                    return self.populate_values(agg_val, input_schema, preserve_column_refs)
                except CortexException as e:
                    e.wrap("aggregate " + res_name)
                    raise

            if res_name in self.columns:
                if input_schema is not None:
                    col_type = self.get_inferred_column_type(res_name)
                    if col_type not in input_schema["_type"]:
                        raise UserException(
                            "column {}: unsupported input type (expected type {}, got type {})".format(
                                res_name,
                                util.data_type_str(input_schema["_type"]),
                                util.data_type_str(col_type),
                            )
                        )
                if preserve_column_refs:
                    return input
                else:
                    return res_name

        if util.is_list(input):
            elem_schema = None
            if input_schema is not None:
                if not util.is_list(input_schema["_type"]):
                    raise UserException(
                        "unsupported input type (expected type {}, got {})".format(
                            util.data_type_str(input_schema["_type"]), util.user_obj_str(input)
                        )
                    )
                elem_schema = input_schema["_type"][0]

                min_count = input_schema.get("_min_count")
                if min_count is not None and len(input) < min_count:
                    raise UserException(
                        "list has length {}, but the minimum allowed length is {}".format(
                            len(input), min_count
                        )
                    )

                max_count = input_schema.get("_max_count")
                if max_count is not None and len(input) > max_count:
                    raise UserException(
                        "list has length {}, but the maximum allowed length is {}".format(
                            len(input), max_count
                        )
                    )

            casted = []
            for i, elem in enumerate(input):
                try:
                    casted.append(self.populate_values(elem, elem_schema, preserve_column_refs))
                except CortexException as e:
                    # `i` is an int: concatenating it directly raised TypeError
                    # here and hid the original error.
                    e.wrap("index " + str(i))
                    raise
            return casted

        if util.is_dict(input):
            if input_schema is None:
                casted = {}
                for key, val in input.items():
                    key_casted = self.populate_values(key, None, preserve_column_refs)
                    try:
                        val_casted = self.populate_values(val, None, preserve_column_refs)
                    except CortexException as e:
                        e.wrap(util.user_obj_str(key))
                        raise
                    casted[key_casted] = val_casted
                return casted

            if not util.is_dict(input_schema["_type"]):
                raise UserException(
                    "unsupported input type (expected type {}, got {})".format(
                        util.data_type_str(input_schema["_type"]), util.user_obj_str(input)
                    )
                )

            min_count = input_schema.get("_min_count")
            if min_count is not None and len(input) < min_count:
                raise UserException(
                    "map has length {}, but the minimum allowed length is {}".format(
                        len(input), min_count
                    )
                )

            max_count = input_schema.get("_max_count")
            if max_count is not None and len(input) > max_count:
                raise UserException(
                    "map has length {}, but the maximum allowed length is {}".format(
                        len(input), max_count
                    )
                )

            # A single compound-type key (e.g. a type name rather than a fixed
            # field name) marks a generic map: every entry shares one
            # key schema and one value schema.
            is_generic_map = False
            if len(input_schema["_type"]) == 1:
                input_type_key = next(iter(input_schema["_type"].keys()))
                if is_compound_type(input_type_key):
                    is_generic_map = True
                    generic_map_key_schema = input_schema_from_type_schema(input_type_key)
                    generic_map_value = input_schema["_type"][input_type_key]

            if is_generic_map:
                casted = {}
                for key, val in input.items():
                    key_casted = self.populate_values(
                        key, generic_map_key_schema, preserve_column_refs
                    )
                    try:
                        val_casted = self.populate_values(
                            val, generic_map_value, preserve_column_refs
                        )
                    except CortexException as e:
                        e.wrap(util.user_obj_str(key))
                        raise
                    casted[key_casted] = val_casted
                return casted

            # fixed map: only keys declared in the schema are kept; missing
            # optional keys take their `_default` or are omitted.
            casted = {}
            for key, val_schema in input_schema["_type"].items():
                if key in input:
                    val = input[key]
                else:
                    if val_schema.get("_optional") is not True:
                        raise UserException("missing key: " + util.user_obj_str(key))
                    if val_schema.get("_default") is None:
                        continue
                    val = val_schema["_default"]

                try:
                    val_casted = self.populate_values(val, val_schema, preserve_column_refs)
                except CortexException as e:
                    e.wrap(util.user_obj_str(key))
                    raise
                casted[key] = val_casted
            return casted

        if input_schema is None:
            return input
        if not util.is_str(input_schema["_type"]):
            raise UserException(
                "unsupported input type (expected type {}, got {})".format(
                    util.data_type_str(input_schema["_type"]), util.user_obj_str(input)
                )
            )
        return cast_compound_type(input, input_schema["_type"])
Exemplo n.º 7
0
def predict(deployment_name, api_name):
    """HTTP handler: predict every sample in the JSON request body.

    Expects a JSON object with a ``samples`` list. Each sample is validated
    and upcast (for Cortex-managed models), then passed to ``run_predict``.
    The response is a JSON object with ``predictions`` (one per sample, in
    order) and ``resource_id``. An unparsable body returns HTTP 400. A missing
    or non-list ``samples``, an invalid sample or a failed prediction returns
    ``prediction_failed(...)``. The first failing sample aborts the whole
    request. ``deployment_name`` and ``api_name`` are not used here.
    """
    try:
        payload = request.get_json()
    except Exception:
        return "Malformed JSON", status.HTTP_400_BAD_REQUEST

    ctx = local_cache["ctx"]
    api = local_cache["api"]

    if not util.is_dict(payload) or "samples" not in payload:
        util.log_pretty_flat(payload, logging_func=logger.error)
        return prediction_failed(
            payload, "top level `samples` key not found in request")

    samples = payload["samples"]
    # Check the type before calling len(): a non-list value such as a number
    # would otherwise raise TypeError instead of returning this error.
    if not util.is_list(samples):
        util.log_pretty_flat(samples, logging_func=logger.error)
        return prediction_failed(
            payload,
            "expected the value of key `samples` to be a list of json objects")

    logger.info("Predicting " +
                util.pluralize(len(samples), "sample", "samples"))

    predictions = []
    for i, sample in enumerate(samples):
        util.log_indent("sample {}".format(i + 1), 2)

        if util.is_resource_ref(api["model"]):
            is_valid, reason = is_valid_sample(sample)
            if not is_valid:
                return prediction_failed(sample, reason)

            # Upcast each required input (in place) to its inferred type.
            for column in local_cache["required_inputs"]:
                column_type = ctx.get_inferred_column_type(column["name"])
                sample[column["name"]] = util.upcast(sample[column["name"]],
                                                     column_type)

        try:
            result = run_predict(sample)
        except CortexException as e:
            e.wrap("error", "sample {}".format(i + 1))
            logger.error(str(e))
            logger.exception(
                "An error occurred, see `cortex logs -v api {}` for more details."
                .format(api["name"]))
            return prediction_failed(sample, str(e))
        except Exception as e:
            logger.exception(
                "An error occurred, see `cortex logs -v api {}` for more details."
                .format(api["name"]))

            # Show signature def for external models (since we don't validate input)
            schema_str = ""
            signature_def = local_cache["metadata"]["signatureDef"]
            if (not util.is_resource_ref(api["model"])
                    and signature_def.get("predict") is not None
                    and signature_def["predict"].get("inputs") is not None):
                schema_str = "\n\nExpected schema:\n" + util.pp_str(
                    signature_def["predict"]["inputs"])

            return prediction_failed(sample, str(e) + schema_str)

        predictions.append(result)

    return jsonify({"predictions": predictions, "resource_id": api["id"]})