Ejemplo n.º 1
0
    def _validate_model_paths(pattern: Any, paths: List[str],
                              common_prefix: str) -> List[Any]:
        """Validate the paths found under ``common_prefix`` against a template.

        Recursively walks ``pattern`` and checks that every object found
        directly below ``common_prefix`` is claimed by one of its keys. An
        object is the leftmost component, via
        ``util.get_leftmost_part_of_path``, of each path taken relative to
        ``common_prefix``.

        Args:
            pattern: Template sub-structure to validate against. ``None`` means
                no substructure is allowed below ``common_prefix``. A non-dict
                value is treated as ``{pattern: None}``. Otherwise a dict that
                maps placeholder keys to sub-patterns.
            paths: Candidate paths. They are narrowed with
                ``util.get_paths_with_prefix`` and memoized per prefix in
                ``paths_by_prefix_cache``.
            common_prefix: The directory prefix currently being validated.

        Returns:
            The IDs (``key.ID``) of the ``OneOfAllPlaceholder`` keys that
            validated successfully, aggregated over the sub-tree.

        Raises:
            CortexException: When the paths don't conform to the template.
        """
        if common_prefix not in paths_by_prefix_cache:
            paths_by_prefix_cache[common_prefix] = util.get_paths_with_prefix(
                paths, common_prefix)
        paths = paths_by_prefix_cache[common_prefix]

        rel_paths = [os.path.relpath(path, common_prefix) for path in paths]
        # Paths that relpath() resolves outside of common_prefix are ignored.
        rel_paths = [path for path in rel_paths if not path.startswith("../")]

        objects = [util.get_leftmost_part_of_path(path) for path in rel_paths]
        objects = list(set(objects))
        # visited_objects[i] stays False until a placeholder helper claims
        # objects[i]. The value stored there is used below as an index into
        # `keys`, so it must be compared with `is False` (index 0 is falsy).
        visited_objects = len(objects) * [False]

        ooa_valid_key_ids = []  # OneOfAllPlaceholder IDs that are valid

        if pattern is None:
            # Leaf of the template: only the prefix itself (relpath ".") is
            # acceptable here.
            if len(objects) == 1 and objects[0] == ".":
                return ooa_valid_key_ids
            raise CortexException(
                f"{predictor_type} predictor at '{common_prefix}'",
                "template doesn't specify a substructure for the given path",
            )
        if not isinstance(pattern, dict):
            pattern = {pattern: None}

        keys = list(pattern.keys())
        keys.sort(key=operator.attrgetter("priority"))

        # OneOfAllPlaceholder keys are mutually exclusive with all other keys.
        # This check runs outside the try block below so the predictor prefix
        # is not prepended to the message a second time by its handler.
        is_ooa_key = [isinstance(key, OneOfAllPlaceholder) for key in keys]
        only_ooa_keys = all(is_ooa_key)
        if any(is_ooa_key) and not only_ooa_keys:
            raise CortexException(
                f"{predictor_type} predictor at '{common_prefix}'",
                f"{OneOfAllPlaceholder()} is a mutual-exclusive key with all other keys",
            )

        num_keys = len(keys)
        num_validation_failures = 0

        try:
            for key_id, key in enumerate(keys):
                if key == IntegerPlaceholder:
                    _validate_integer_placeholder(keys, key_id, objects,
                                                  visited_objects)
                elif key == AnyPlaceholder:
                    _validate_any_placeholder(keys, key_id, objects,
                                              visited_objects)
                elif key == SinglePlaceholder:
                    _validate_single_placeholder(keys, key_id, objects,
                                                 visited_objects)
                elif isinstance(key, GenericPlaceholder):
                    _validate_generic_placeholder(keys, key_id, objects,
                                                  visited_objects, key)
                elif isinstance(key, PlaceholderGroup):
                    _validate_group_placeholder(keys, key_id, objects,
                                                visited_objects)
                elif isinstance(key, OneOfAllPlaceholder):
                    # Each alternative is validated against the same prefix;
                    # a failure only counts against this alternative.
                    # NOTE(review): the IDs returned by this nested call are
                    # discarded - confirm nested OneOfAll keys are not expected.
                    try:
                        _validate_model_paths(pattern[key], paths,
                                              common_prefix)
                        ooa_valid_key_ids.append(key.ID)
                    except CortexException:
                        num_validation_failures += 1
                else:
                    raise CortexException(
                        "found a non-placeholder object in model template")

        except CortexException as e:
            raise CortexException(
                f"{predictor_type} predictor at '{common_prefix}'",
                str(e)) from e

        if only_ooa_keys:
            # NOTE(review): an empty pattern dict also lands here (all() of an
            # empty sequence is True) and is reported as an OneOfAll failure.
            if num_validation_failures == num_keys:
                raise CortexException(
                    f"couldn't validate for any of the {OneOfAllPlaceholder()} placeholders"
                )
            return ooa_valid_key_ids

        # Collect every path under an object that no placeholder claimed.
        unvisited_paths = []
        for obj, visited in zip(objects, visited_objects):
            if visited is False:
                untraced_common_prefix = os.path.join(common_prefix, obj)
                untraced_paths = [
                    os.path.relpath(path, untraced_common_prefix)
                    for path in paths
                ]
                unvisited_paths += [
                    os.path.join(obj, path) for path in untraced_paths
                    if not path.startswith("../")
                ]
        if unvisited_paths:
            raise CortexException(
                f"{predictor_type} predictor model at '{common_prefix}'",
                "unexpected path(s) for " + str(unvisited_paths),
            )

        # Descend into each claimed object with the sub-pattern of the key
        # that claimed it. Sub-trees claimed by AnyPlaceholder are not
        # descended into.
        new_common_prefixes = []
        sub_patterns = []
        for obj, key_id in zip(objects, visited_objects):
            key = keys[key_id]
            if key != AnyPlaceholder:
                new_common_prefixes.append(os.path.join(common_prefix, obj))
                sub_patterns.append(pattern[key])

        paths_by_prefix = {}
        if new_common_prefixes:
            paths_by_prefix = util.get_paths_by_prefixes(
                paths, new_common_prefixes)

        aggregated_ooa_valid_key_ids = []
        for sub_pattern, new_common_prefix in zip(sub_patterns,
                                                  new_common_prefixes):
            aggregated_ooa_valid_key_ids += _validate_model_paths(
                sub_pattern, paths_by_prefix[new_common_prefix],
                new_common_prefix)

        return aggregated_ooa_valid_key_ids
Ejemplo n.º 2
0
    def add_models(
        self,
        model_names: List[str],
        model_versions: List[List[str]],
        model_disk_paths: List[str],
        signature_keys: List[Optional[str]],
        skip_if_present: bool = False,
        timeout: Optional[float] = None,
        max_retries: int = 0,
    ) -> None:
        """
        Add models to TFS. If they can't be loaded, use remove_models to remove them from TFS.

        Args:
            model_names: List of model names to add.
            model_versions: List of lists - each element is a list of versions for a given model name.
            model_disk_paths: The common model disk path of multiple versioned models of the same model name (i.e. modelA/ for modelA/1 and modelA/2).
            signature_keys: The signature keys as set in cortex_internal.yaml. If an element is set to None, then "predict" key will be assumed.
            skip_if_present: If the models are already loaded, don't make a new request to TFS.
            timeout: Passed as-is to every HandleReloadConfigRequest call.
            max_retries: How many extra times to call ReloadConfig after it fails with a grpc.RpcError
                before giving up. 0 means a single attempt. A negative value makes no call at all,
                and a CortexException is raised.
        Raises:
            grpc.RpcError in case something bad happens while communicating (after all retries are exhausted).
                StatusCode.DEADLINE_EXCEEDED when timeout is encountered. StatusCode.UNAVAILABLE when the service is unreachable.
            cortex_internal.lib.exceptions.CortexException if a non-0 response code is returned (i.e. model couldn't be loaded).
            cortex_internal.lib.exceptions.UserException when a model couldn't be validated for the signature def.
        """

        request = model_management_pb2.ReloadConfigRequest()
        model_server_config = model_server_config_pb2.ModelServerConfig()

        # Register every requested version. _add_model_to_dict's return value
        # is summed to decide whether anything new was added.
        num_added_models = 0
        for model_name, versions, model_disk_path in zip(
                model_names, model_versions, model_disk_paths):
            for model_version in versions:
                versioned_model_disk_path = os.path.join(
                    model_disk_path, model_version)
                num_added_models += self._add_model_to_dict(
                    model_name, model_version, versioned_model_disk_path)

        if skip_if_present and num_added_models == 0:
            return

        # The reload request lists every model currently tracked, not only the
        # ones just added (presumably because TFS replaces its whole model
        # config on reload).
        config_list = model_server_config_pb2.ModelConfigList()
        current_model_names = self._get_model_names()
        for model_name in current_model_names:
            versions, model_disk_path = self._get_model_info(model_name)
            versions = [int(version) for version in versions]
            model_config = config_list.config.add()
            model_config.name = model_name
            model_config.base_path = model_disk_path
            model_config.model_version_policy.CopyFrom(
                ServableVersionPolicy(specific=Specific(versions=versions)))
            model_config.model_platform = "tensorflow"

        model_server_config.model_config_list.CopyFrom(config_list)
        request.config.CopyFrom(model_server_config)

        # Bound up front so a negative max_retries (no attempt made) reaches
        # the CortexException below instead of an UnboundLocalError.
        response = None
        retries_left = max_retries
        while retries_left >= 0:
            retries_left -= 1
            try:
                # to prevent HandleReloadConfigRequest from
                # throwing an exception (TFS has some race-condition bug)
                time.sleep(0.125)
                response = self._service.HandleReloadConfigRequest(
                    request, timeout)
                break
            except grpc.RpcError:
                # to prevent HandleReloadConfigRequest from
                # throwing another exception on the next run
                time.sleep(0.125)
                # Only give up once the retry budget is spent; otherwise
                # loop around and try again.
                if retries_left < 0:
                    raise

        if not (response and response.status.error_code == 0):
            if response:
                raise CortexException(
                    "couldn't load user-requested models {} - failed with error code {}: {}"
                    .format(model_names, response.status.error_code,
                            response.status.error_message))
            else:
                raise CortexException("couldn't load user-requested models")

        # get models metadata
        for model_name, versions, signature_key in zip(model_names,
                                                       model_versions,
                                                       signature_keys):
            for model_version in versions:
                self._load_model_signatures(model_name, model_version,
                                            signature_key)