Example #1
    def _get_class_impl(self, module_name, impl_path, target_class_name):
        if impl_path.endswith(".pickle"):
            try:
                with open(impl_path, "rb") as pickle_file:
                    return dill.load(pickle_file)
            except Exception as e:
                raise UserException("unable to load pickle", str(e)) from e

        try:
            impl = imp.load_source(module_name, impl_path)
        except Exception as e:
            raise UserException(str(e)) from e

        classes = inspect.getmembers(impl, inspect.isclass)
        predictor_class = None
        for class_df in classes:
            if class_df[0] == target_class_name:
                if predictor_class is not None:
                    raise UserException(
                        f"multiple definitions for {target_class_name} class found; please check your imports and class definitions and ensure that there is only one Predictor class definition"
                    )
                predictor_class = class_df[1]
        if predictor_class is None:
            raise UserException(f"{target_class_name} class is not defined")

        return predictor_class
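As a hedged usage sketch (the file path, module name, and predictor file contents below are illustrative, not from the source), the loader is driven with a path to the user's implementation file. Note that `imp.load_source` is deprecated and removed in Python 3.12; `importlib` is the modern replacement.

# contents of a hypothetical /mnt/project/predictor.py:
#
#     class PythonPredictor:
#         def __init__(self, config):
#             self.config = config
#
# loading it would then look roughly like:
predictor_class = self._get_class_impl(
    "cortex_predictor",           # module name to register the loaded source under
    "/mnt/project/predictor.py",  # path to the user-provided implementation
    "PythonPredictor",            # class to look up among the module's members
)
predictor = predictor_class(config={})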
Example #2
    def _read_impl(module_name: str, impl_path: str, target_class_name: str):
        if impl_path.endswith(".pickle"):
            try:
                with open(impl_path, "rb") as pickle_file:
                    return dill.load(pickle_file)
            except Exception as e:
                raise UserException("unable to load pickle", str(e)) from e

        try:
            impl = imp.load_source(module_name, impl_path)
        except Exception as e:
            raise UserException(str(e)) from e

        classes = inspect.getmembers(impl, inspect.isclass)

        if len(classes) > 0:
            task_class = None
            for class_df in classes:
                if class_df[0] == target_class_name:
                    if task_class is not None:
                        raise UserException(
                            f"multiple definitions for {target_class_name} class found; please check "
                            f"your imports and class definitions and ensure that there is only one "
                            f"task class definition")
                    task_class = class_df[1]
            if task_class is None:
                raise UserException(
                    f"{target_class_name} class is not defined")
            return task_class
        else:
            raise UserException("no callable class was provided")
Example #3
def validate_predictor_with_grpc(impl, api_spec):
    if not is_grpc_enabled(api_spec):
        return

    target_class_name = impl.__name__
    constructor = getattr(impl, "__init__")
    constructor_arg_spec = inspect.getfullargspec(constructor)
    if "proto_module_pb2" not in constructor_arg_spec.args:
        raise UserException(
            f"class {target_class_name}",
            f'invalid signature for method "__init__"',
            f'"proto_module_pb2" is a required argument, but was not provided',
            f"when a protobuf is specified in the api spec, then that means the grpc protocol is enabled, "
            f'which means that adding the "proto_module_pb2" argument is required',
        )

    predictor = getattr(impl, "predict")
    predictor_arg_spec = inspect.getfullargspec(predictor)
    disallowed_params = list(
        set(["query_params", "headers",
             "batch_id"]).intersection(predictor_arg_spec.args))
    if len(disallowed_params) > 0:
        raise UserException(
            f"class {target_class_name}",
            f'invalid signature for method "predict"',
            f'{util.string_plural_with_s("argument", len(disallowed_params))} {util.and_list_with_quotes(disallowed_params)} cannot be used when the grpc protocol is enabled',
        )

    if getattr(impl, "post_predict", None):
        raise UserException(
            f"class {target_class_name}",
            f"post_predict method is not supported when the grpc protocol is enabled",
        )
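For reference, a minimal predictor shape that passes this validation when gRPC is enabled (the class body and message name are illustrative):

class PythonPredictor:
    def __init__(self, config, proto_module_pb2):
        # proto_module_pb2 (the compiled protobuf module) is required here
        self.proto_module_pb2 = proto_module_pb2

    # predict must not declare query_params, headers, or batch_id,
    # and a post_predict method must not be defined at all
    def predict(self, payload):
        return self.proto_module_pb2.Response()  # message name illustrative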
Example #4
    def _create_prediction_request(
        self,
        signature_def: dict,
        signature_key: str,
        model_name: str,
        model_version: int,
        model_input: Any,
    ) -> predictRequestClass:
        prediction_request = predict_pb2.PredictRequest()
        prediction_request.model_spec.name = model_name
        prediction_request.model_spec.version.value = int(model_version)
        prediction_request.model_spec.signature_name = signature_key

        for column_name, value in model_input.items():
            if signature_def[signature_key]["inputs"][column_name]["tensorShape"] == {}:
                shape = "scalar"
            elif signature_def[signature_key]["inputs"][column_name]["tensorShape"].get(
                "unknownRank", False
            ):
                # unknownRank is set to True if the model input has no rank;
                # checking the key only for presence (and not its value) could
                # lead to undefined behavior, so the value is tested as well
                shape = "unknown"
            else:
                shape = []
                for dim in signature_def[signature_key]["inputs"][column_name]["tensorShape"][
                    "dim"
                ]:
                    shape.append(int(dim["size"]))

            sig_type = signature_def[signature_key]["inputs"][column_name]["dtype"]

            try:
                tensor_proto = tf.compat.v1.make_tensor_proto(
                    value, dtype=DTYPE_TO_TF_TYPE[sig_type]
                )
                prediction_request.inputs[column_name].CopyFrom(tensor_proto)
            except Exception as e:
                if shape == "scalar":
                    raise UserException(
                        'key "{}"'.format(column_name), "expected to be a scalar", str(e)
                    ) from e
                elif shape == "unknown":
                    raise UserException(
                        'key "{}"'.format(column_name), "can be of any rank and shape", str(e)
                    ) from e
                else:
                    raise UserException(
                        'key "{}"'.format(column_name), "expected shape {}".format(shape), str(e)
                    ) from e

        return prediction_request
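The `signature_def` argument is the signature def map queried from TFS and converted to a dict with `json_format.MessageToDict` (see `_load_model_signatures` below), which is why dim sizes arrive as strings. For a model with one tensor input and one scalar input it might look roughly like this (all names and values illustrative):

signature_def = {
    "predict": {
        "inputs": {
            "input_1": {
                # known rank: a list of dims, with "-1" for unknown sizes
                "tensorShape": {"dim": [{"size": "-1"}, {"size": "28"}, {"size": "28"}]},
                "dtype": "DT_FLOAT",
            },
            "flag": {
                # an empty tensorShape denotes a scalar input
                "tensorShape": {},
                "dtype": "DT_BOOL",
            },
        }
    }
}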
Example #5
def transform_to_numpy(input_pyobj, input_metadata, model_name):
    target_dtype = ONNX_TO_NP_TYPE[input_metadata.type]
    target_shape = input_metadata.shape

    try:
        for idx, dim in enumerate(target_shape):
            if type(dim) is str:
                target_shape[idx] = -1
            elif type(dim) is not int:
                target_shape[idx] = 1

        if type(input_pyobj) is np.ndarray:
            np_arr = input_pyobj
            if np.issubdtype(np_arr.dtype, np.number) == np.issubdtype(target_dtype, np.number):
                if str(np_arr.dtype) != target_dtype:
                    np_arr = np_arr.astype(target_dtype)
            else:
                raise ValueError(
                    "expected dtype '{}' but found '{}' for model '{}'".format(
                        target_dtype, np_arr.dtype, model_name
                    )
                )
        else:
            np_arr = np.array(input_pyobj, dtype=target_dtype)

        # can only infer the size for up to 1 unknown dimension
        if target_shape.count(-1) <= 1:
            np_arr = np_arr.reshape(target_shape)

        return np_arr
    except Exception as e:
        raise UserException(
            "failed to convert to numpy array for model '{}'".format(model_name), str(e)
        ) from e
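`ONNX_TO_NP_TYPE` maps ONNX type strings (e.g. "tensor(float)") to numpy dtype names, and `input_metadata` is an `onnxruntime` `NodeArg` exposing `.name`, `.type`, and `.shape`. A minimal sketch, using a stand-in for `NodeArg`:

import numpy as np

class FakeNodeArg:  # illustrative stand-in for onnxruntime's NodeArg
    name = "input_1"
    type = "tensor(float)"     # assumed to map to "float32" in ONNX_TO_NP_TYPE
    shape = ["batch", 28, 28]  # the string dim is rewritten to -1 and inferred

arr = transform_to_numpy([[[0.0] * 28] * 28], FakeNodeArg(), "my-model")
assert arr.shape == (1, 28, 28) and arr.dtype == np.float32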
Example #6
def _validate_impl(impl):
    if inspect.isclass(impl):
        validate_class_impl(impl, TASK_CLASS_VALIDATION)
    else:
        callable_fn = impl
        argspec = inspect.getfullargspec(callable_fn)
        if not (len(argspec.args) == 1 and argspec.args[0] == "config"):
            raise UserException(
                'callable function must have the "config" parameter in its signature',
            )
Example #7
    def _load_model_signatures(
        self, model_name: str, model_version: str, signature_key: Optional[str] = None
    ) -> None:
        """
        Queries the signature defs from TFS.

        Args:
            model_name: Name of the model.
            model_version: Version of the model.
            signature_key: Signature key of the model as passed in with predictor:signature_key, predictor:models:paths:signature_key or predictor:models:signature_key.
                When set to None, "predict" is the assumed key.

        Raises:
            cortex_internal.lib.exceptions.UserException when the signature def can't be validated.
        """

        # create model metadata request
        request = get_model_metadata_pb2.GetModelMetadataRequest()
        request.model_spec.name = model_name
        request.model_spec.version.value = int(model_version)
        request.metadata_field.append("signature_def")

        # get signature def; TFS may need a moment before a freshly loaded
        # model's metadata becomes accessible, so retry with a short sleep
        resp = None
        for _ in range(100):
            try:
                resp = self._pred.GetModelMetadata(request)
                break
            except grpc.RpcError:
                time.sleep(0.3)
        if resp is None:
            raise UserException(
                "couldn't find model '{}' of version '{}' to extract the signature def".format(
                    model_name, model_version
                )
            )

        sigAny = resp.metadata["signature_def"]
        signature_def_map = get_model_metadata_pb2.SignatureDefMap()
        sigAny.Unpack(signature_def_map)
        sigmap = json_format.MessageToDict(signature_def_map)
        signature_def = sigmap["signatureDef"]

        # extract signature key and input signature
        signature_key, input_signatures = self._extract_signatures(
            signature_def, signature_key, model_name, model_version
        )

        model_id = f"{model_name}-{model_version}"
        self.models[model_id]["signature_def"] = signature_def
        self.models[model_id]["signature_key"] = signature_key
        self.models[model_id]["input_signatures"] = input_signatures
Example #8
def convert_to_onnx_input(model_input, input_metadata_list, model_name):
    input_dict = {}
    if len(input_metadata_list) == 1:
        input_metadata = input_metadata_list[0]
        if util.is_dict(model_input):
            if model_input.get(input_metadata.name) is None:
                raise UserException(
                    "missing key '{}' for model '{}'".format(input_metadata.name, model_name)
                )
            input_dict[input_metadata.name] = transform_to_numpy(
                model_input[input_metadata.name], input_metadata, model_name
            )
        else:
            try:
                input_dict[input_metadata.name] = transform_to_numpy(
                    model_input, input_metadata, model_name
                )
            except CortexException as e:
                e.wrap("key '{}' for model '{}'".format(input_metadata.name, model_name))
                raise
    else:
        for input_metadata in input_metadata_list:
            if not util.is_dict(model_input):
                expected_keys = [metadata.name for metadata in input_metadata_list]
                raise UserException(
                    "expected model_input to be a dictionary with keys '{}' for model '{}'".format(
                        ", ".join('"' + key + '"' for key in expected_keys), model_name
                    )
                )

            if model_input.get(input_metadata.name) is None:
                raise UserException(
                    "missing key '{}' for model '{}'".format(input_metadata.name, model_name)
                )
            try:
                input_dict[input_metadata.name] = transform_to_numpy(
                    model_input[input_metadata.name], input_metadata, model_name
                )
            except CortexException as e:
                e.wrap("key '{}' for model '{}'".format(input_metadata.name, model_name))
                raise
    return input_dict
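A sketch of the accepted payload shapes, reusing the stand-in pattern from example #5 (the 2-element vector input is illustrative):

class VecArg:  # illustrative stand-in for onnxruntime's NodeArg
    name = "input_1"
    type = "tensor(float)"
    shape = ["batch", 2]

# single-input model: both a bare value and a keyed dict are accepted
convert_to_onnx_input([[1.0, 2.0]], [VecArg()], "my-model")
convert_to_onnx_input({"input_1": [[1.0, 2.0]]}, [VecArg()], "my-model")
# multi-input model: the payload must be a dict with one key per input;
# anything else raises a UserException listing the expected keys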
Example #9
def _validate_python_predictor_with_models(impl, api_spec):
    target_class_name = impl.__name__

    if _are_models_specified(api_spec):
        constructor = getattr(impl, "__init__")
        constructor_arg_spec = inspect.getfullargspec(constructor)
        if "python_client" not in constructor_arg_spec.args:
            raise UserException(
                f"class {target_class_name}",
                f'invalid signature for method "__init__"',
                f'"python_client" is a required argument, but was not provided',
                f'when the python predictor type is used and models are specified in the api spec, adding the "python_client" argument is required',
            )

        if getattr(impl, "load_model", None) is None:
            raise UserException(
                f"class {target_class_name}",
                f'required method "load_model" is not defined',
                f'when the python predictor type is used and models are specified in the api spec, adding the "load_model" method is required',
            )
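A predictor shape that satisfies this validation when models are specified (bodies are illustrative):

class PythonPredictor:
    def __init__(self, config, python_client):
        # python_client is injected when models are specified in the api spec
        self.client = python_client

    def load_model(self, model_path):
        # called to (re)load a model from disk; the body is up to the user
        raise NotImplementedError

    def predict(self, payload):
        ...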
Example #10
    def _metric_validation(self,
                           metric: str,
                           value: float,
                           tags: Dict[str, str] = None):
        # `fn` is the metrics method wrapped by the enclosing decorator
        internal_prefixes = ("cortex_", "istio_")
        if metric.startswith(internal_prefixes):
            raise UserException(
                f"Metric name ({metric}) is invalid because it starts with a cortex-exclusive prefix.\n"
                f"The following prefixes are exclusive to cortex: {internal_prefixes}."
            )
        return fn(self, metric=metric, value=value, tags=tags)
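`fn` is not defined in the snippet itself; it is the method wrapped by an enclosing decorator. A sketch of the likely surrounding pattern (the decorator name `validate_metric` is an assumption, not from the source):

import functools

def validate_metric(fn):  # hypothetical name for the enclosing decorator
    @functools.wraps(fn)
    def _metric_validation(self, metric, value, tags=None):
        ...  # the validation body shown above, ending in fn(self, ...)
    return _metric_validation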
Example #11
    def predict(self,
                model_input: Any,
                model_name: str,
                model_version: str,
                timeout: float = 300.0) -> Any:
        """
        Args:
            model_input: The input to run the prediction on - as passed by the user.
            model_name: Name of the model.
            model_version: Version of the model.
            timeout: How many seconds to wait for the prediction to run before timing out.

        Raises:
            UserException when the model input is not valid or when the model's shape doesn't match that of the input's.
            grpc.RpcError in case something bad happens while communicating - should not happen.

        Returns:
            The prediction.
        """

        model_id = f"{model_name}-{model_version}"

        signature_def = self.models[model_id]["signature_def"]
        signature_key = self.models[model_id]["signature_key"]
        input_signatures = self.models[model_id]["input_signatures"]

        # validate model input
        for input_name, _ in input_signatures.items():
            if input_name not in model_input:
                raise UserException(
                    "missing key '{}' for model '{}' of version '{}'".format(
                        input_name, model_name, model_version))

        # create prediction request
        prediction_request = self._create_prediction_request(
            signature_def, signature_key, model_name, model_version,
            model_input)

        # run prediction
        response_proto = self._pred.Predict(prediction_request,
                                            timeout=timeout)

        # interpret response message
        results_dict = json_format.MessageToDict(response_proto)
        outputs = results_dict["outputs"]
        outputs_simplified = {}
        for key in outputs:
            value_key = DTYPE_TO_VALUE_KEY[outputs[key]["dtype"]]
            outputs_simplified[key] = outputs[key][value_key]

        # return parsed response
        return outputs_simplified
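`DTYPE_TO_VALUE_KEY` maps a tensor dtype to the `TensorProto` value field that carries the data once the response is converted with `MessageToDict`. A representative subset along these lines (the authoritative mapping lives in the source constants, not shown here):

DTYPE_TO_VALUE_KEY = {
    "DT_FLOAT": "floatVal",
    "DT_DOUBLE": "doubleVal",
    "DT_INT32": "intVal",
    "DT_INT64": "int64Val",
    "DT_STRING": "stringVal",
    "DT_BOOL": "boolVal",
}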
Example #12
    def _validate_impl(self, impl):
        if inspect.isclass(impl):
            target_class_name = impl.__name__

            constructor_fn = getattr(impl, "__init__", None)
            if constructor_fn:
                argspec = inspect.getfullargspec(constructor_fn)
                if not (len(argspec.args) == 1 and argspec.args[0] == "self"):
                    raise UserException(
                        f"class {target_class_name}",
                        f'invalid signature for method "__init__"',
                        f'only "self" parameter must be present in method\'s signature',
                    )

            callable_fn = getattr(impl, "__call__", None)
            if callable_fn:
                argspec = inspect.getfullargspec(callable_fn)
                if not (len(argspec.args) == 2 and argspec.args[0] == "self"
                        and argspec.args[1] == "config"):
                    raise UserException(
                        f"class {target_class_name}",
                        f'invalid signature for method "__call__"',
                        f'the following parameters must be present in method\'s signature: "self", "config"',
                    )
            else:
                raise UserException(
                    f"class {target_class_name}",
                    f'"__call__" method not defined',
                )
        else:
            callable_fn = impl
            argspec = inspect.getfullargspec(callable_fn)
            if not (len(argspec.args) == 1 and argspec.args[0] == "config"):
                raise UserException(
                    f'callable function must have the "config" parameter in its signature',
                )
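Two implementation shapes that pass this validator, for quick reference (bodies are illustrative):

# class form: a no-argument __init__ (besides self) and __call__(self, config)
class Task:
    def __init__(self):
        pass

    def __call__(self, config):
        print(config)

# function form: exactly one parameter, and it must be named "config"
def task(config):
    print(config)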
Example #13
def _validate_required_fn_args(impl, func_signature, api_spec):
    target_class_name = impl.__name__

    fn = getattr(impl, func_signature["name"], None)
    if not fn:
        raise UserException(
            f"class {target_class_name}",
            f'required method "{func_signature["name"]}" is not defined',
        )

    if not callable(fn):
        raise UserException(
            f"class {target_class_name}",
            f'"{func_signature["name"]}" is defined, but is not a method',
        )

    required_args = func_signature.get("required_args", [])
    optional_args = func_signature.get("optional_args", [])

    argspec = inspect.getfullargspec(fn)
    fn_str = f'{func_signature["name"]}({", ".join(argspec.args)})'

    for arg_name in required_args:
        if arg_name not in argspec.args:
            raise UserException(
                f"class {target_class_name}",
                f'invalid signature for method "{fn_str}"',
                f'"{arg_name}" is a required argument, but was not provided',
            )

        if arg_name == "self":
            if argspec.args[0] != "self":
                raise UserException(
                    f"class {target_class_name}",
                    f'invalid signature for method "{fn_str}"',
                    f'"self" must be the first argument',
                )

    seen_args = []
    for arg_name in argspec.args:
        if arg_name not in required_args and arg_name not in optional_args:
            raise UserException(
                f"class {target_class_name}",
                f'invalid signature for method "{fn_str}"',
                f'"{arg_name}" is not a supported argument',
            )

        if arg_name in seen_args:
            raise UserException(
                f"class {target_class_name}",
                f'invalid signature for method "{fn_str}"',
                f'"{arg_name}" is duplicated',
            )

        seen_args.append(arg_name)
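`func_signature` is a plain dict describing one required method. A sketch of what an entry and its use might look like for a predictor's `predict` method (the argument names are illustrative, not the source's validation tables):

func_signature = {
    "name": "predict",
    "required_args": ["self"],
    "optional_args": ["payload", "query_params", "headers"],
}
_validate_required_fn_args(PythonPredictor, func_signature, api_spec)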
Example #14
    def _run_inference(self, model_input: Any, model_name: str,
                       model_version: str) -> dict:
        """
        When processes_per_replica = 1 and caching enabled, check/load model and make prediction.
        When processes_per_replica > 0 and caching disabled, attempt to make prediction regardless.

        Args:
            model_input: Input to the model.
            model_name: Name of the model, as it's specified in predictor:models:paths or in the other case as they are named on disk.
            model_version: Version of the model, as it's found on disk. Can also infer the version number from the "latest" version tag.

        Returns:
            The prediction.
        """

        model = None
        tag = ""
        if model_version == "latest":
            tag = model_version

        if not self._caching_enabled:

            # determine model version
            if tag == "latest":
                versions = self._client.poll_available_model_versions(
                    model_name)
                if len(versions) == 0:
                    raise UserException(
                        f"model '{model_name}' accessed with tag {tag} couldn't be found"
                    )
                model_version = str(max(map(lambda x: int(x), versions)))
            model_id = model_name + "-" + model_version

            return self._client.predict(model_input, model_name, model_version)

        if not self._multiple_processes and self._caching_enabled:

            # determine model version
            try:
                if tag == "latest":
                    model_version = self._get_latest_model_version_from_tree(
                        model_name, self._models_tree.model_info(model_name))
            except ValueError:
                # if model_name hasn't been found
                raise UserRuntimeException(
                    f"'{model_name}' model of tag {tag} wasn't found in the list of available models"
                )

            # grab shared access to model tree
            available_model = True
            logger.info(
                f"grabbing access to model {model_name} of version {model_version}"
            )
            with LockedModelsTree(self._models_tree, "r", model_name,
                                  model_version):

                # check if the versioned model exists
                model_id = model_name + "-" + model_version
                if model_id not in self._models_tree:
                    available_model = False
                    logger.info(
                        f"model {model_name} of version {model_version} is not available"
                    )
                    raise WithBreak

                # retrieve model tree's metadata
                upstream_model = self._models_tree[model_id]
                current_upstream_ts = int(
                    upstream_model["timestamp"].timestamp())
                logger.info(
                    f"model {model_name} of version {model_version} is available"
                )

            if not available_model:
                if tag == "":
                    raise UserException(
                        f"model '{model_name}' of version '{model_version}' couldn't be found"
                    )
                raise UserException(
                    f"model '{model_name}' accessed with tag '{tag}' couldn't be found"
                )

            # grab shared access to models holder and retrieve model
            update_model = False
            prediction = None
            tfs_was_unresponsive = False
            with LockedModel(self._models, "r", model_name, model_version):
                logger.info(
                    f"checking the {model_name} {model_version} status")
                status, local_ts = self._models.has_model(
                    model_name, model_version)
                if status in ["not-available", "on-disk"
                              ] or (status != "not-available"
                                    and local_ts != current_upstream_ts):
                    logger.info(
                        f"model {model_name} of version {model_version} is not loaded (with status {status} or different timestamp)"
                    )
                    update_model = True
                    raise WithBreak

                # run prediction
                logger.info(
                    f"run the prediction on model {model_name} of version {model_version}"
                )
                self._models.get_model(model_name, model_version, tag)
                try:
                    prediction = self._client.predict(model_input, model_name,
                                                      model_version)
                except grpc.RpcError as e:
                    # effectively when it got restarted
                    if len(self._client.poll_available_model_versions(model_name)) > 0:
                        raise
                    tfs_was_unresponsive = True

            # remove model from disk and memory references if TFS gets unresponsive
            if tfs_was_unresponsive:
                with LockedModel(self._models, "w", model_name, model_version):
                    available_versions = self._client.poll_available_model_versions(
                        model_name)
                    status, _ = self._models.has_model(model_name,
                                                       model_version)
                    if not (status == "in-memory"
                            and model_version not in available_versions):
                        raise WithBreak

                    logger.info(
                        f"removing model {model_name} of version {model_version} because TFS got unresponsive"
                    )
                    self._models.remove_model(model_name, model_version)

            # download, load into memory the model and retrieve it
            if update_model:
                # grab exclusive access to models holder
                with LockedModel(self._models, "w", model_name, model_version):

                    # check model status
                    status, local_ts = self._models.has_model(
                        model_name, model_version)

                    # refresh disk model
                    if status == "not-available" or (
                            status in ["on-disk", "in-memory"]
                            and local_ts != current_upstream_ts):
                        # unload model from TFS
                        if status == "in-memory":
                            try:
                                logger.info(
                                    f"unloading model {model_name} of version {model_version} from TFS"
                                )
                                self._models.unload_model(
                                    model_name, model_version)
                            except Exception:
                                logger.info(
                                    f"failed unloading model {model_name} of version {model_version} from TFS"
                                )
                                raise

                        # remove model from disk and references
                        if status in ["on-disk", "in-memory"]:
                            logger.info(
                                f"removing model references from memory and from disk for model {model_name} of version {model_version}"
                            )
                            self._models.remove_model(model_name,
                                                      model_version)

                        # download model
                        logger.info(
                            f"downloading model {model_name} of version {model_version} from the {upstream_model['provider']} upstream"
                        )
                        date = self._models.download_model(
                            upstream_model["provider"],
                            upstream_model["bucket"],
                            model_name,
                            model_version,
                            upstream_model["path"],
                        )
                        if not date:
                            raise WithBreak
                        current_upstream_ts = int(date.timestamp())

                    # load model
                    try:
                        logger.info(
                            f"loading model {model_name} of version {model_version} into memory"
                        )
                        self._models.load_model(
                            model_name,
                            model_version,
                            current_upstream_ts,
                            [tag],
                            kwargs={
                                "model_name": model_name,
                                "model_version": model_version,
                                "signature_key": self._determine_model_signature_key(model_name),
                            },
                        )
                    except Exception as e:
                        raise UserRuntimeException(
                            f"failed (re-)loading model {model_name} of version {model_version} (thread {td.get_ident()})",
                            str(e),
                        )

                    # run prediction
                    self._models.get_model(model_name, model_version, tag)
                    prediction = self._client.predict(model_input, model_name,
                                                      model_version)

            return prediction
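`WithBreak` together with the `Locked*` context managers implements an early exit from a `with` block: raising `WithBreak` releases the lock and is swallowed by the context manager, so execution resumes after the block. A minimal sketch of the idea (not the source implementation):

class WithBreak(Exception):
    """Raised inside a with block to exit it early."""

class LockedModel:
    def __init__(self, models, mode, model_name, model_version):
        self._args = (models, mode, model_name, model_version)

    def __enter__(self):
        # acquire the shared ("r") or exclusive ("w") lock here
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # release the lock, then suppress WithBreak (returning True
        # swallows the exception) so control continues past the block
        return exc_type is not None and issubclass(exc_type, WithBreak)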
Example #15
    def _extract_signatures(
        self, signature_def, signature_key, model_name: str, model_version: str
    ):
        logger.info(
            "signature defs found in model '{}' for version '{}': {}".format(
                model_name, model_version, signature_def
            )
        )

        available_keys = list(signature_def.keys())
        if len(available_keys) == 0:
            raise UserException(
                "unable to find signature defs in model '{}' of version '{}'".format(
                    model_name, model_version
                )
            )

        if signature_key is None:
            if len(available_keys) == 1:
                logger.info(
                    "signature_key was not configured by user, using signature key '{}' for model '{}' of version '{}' (found in the signature def map)".format(
                        available_keys[0],
                        model_name,
                        model_version,
                    )
                )
                signature_key = available_keys[0]
            elif "predict" in signature_def:
                logger.info(
                    "signature_key was not configured by user, using signature key 'predict' for model '{}' of version '{}' (found in the signature def map)".format(
                        model_name,
                        model_version,
                    )
                )
                signature_key = "predict"
            else:
                raise UserException(
                    "signature_key was not configured by user, please specify one the following keys '{}' for model '{}' of version '{}' (found in the signature def map)".format(
                        ", ".join(available_keys), model_name, model_version
                    )
                )
        else:
            if signature_def.get(signature_key) is None:
                possibilities_str = "key: '{}'".format(available_keys[0])
                if len(available_keys) > 1:
                    possibilities_str = "keys: '{}'".format("', '".join(available_keys))

                raise UserException(
                    "signature_key '{}' was not found in signature def map for model '{}' of version '{}', but found the following {}".format(
                        signature_key, model_name, model_version, possibilities_str
                    )
                )

        signature_def_val = signature_def.get(signature_key)

        if signature_def_val.get("inputs") is None:
            raise UserException(
                "unable to find 'inputs' in signature def '{}' for model '{}'".format(
                    signature_key, model_name
                )
            )

        parsed_signatures = {}
        for input_name, input_metadata in signature_def_val["inputs"].items():
            if input_metadata["tensorShape"] == {}:
                # a scalar with rank 0 and empty shape
                shape = "scalar"
            elif input_metadata["tensorShape"].get("unknownRank", False):
                # unknown rank and shape
                #
                # unknownRank is set to True if the model input has no rank;
                # checking the key only for presence (and not its value) could
                # lead to undefined behavior, so the value is tested as well
                shape = "unknown"
            elif input_metadata["tensorShape"].get("dim", None):
                # known rank and known/unknown shape
                shape = [int(dim["size"]) for dim in input_metadata["tensorShape"]["dim"]]
            else:
                raise UserException(
                    "invalid 'tensorShape' specification for input '{}' in signature key '{}' for model '{}'".format(
                        input_name, signature_key, model_name
                    )
                )

            parsed_signatures[input_name] = {
                "shape": shape if type(shape) == list else [shape],
                "type": DTYPE_TO_TF_TYPE[input_metadata["dtype"]].name,
            }
        return signature_key, parsed_signatures
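Fed the illustrative signature def from example #4 with no configured key, the single available key would be selected and the parsed signatures would come out roughly as:

signature_key, parsed = self._extract_signatures(signature_def, None, "my-model", "1")
# signature_key == "predict"
# parsed == {
#     "input_1": {"shape": [-1, 28, 28], "type": "float32"},
#     "flag": {"shape": ["scalar"], "type": "bool"},
# }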