def _validate_model_paths(pattern: Any, paths: List[str], common_prefix: str) -> List[int]:
    """Recursively validate the paths under ``common_prefix`` against a model template.

    ``pattern`` is either ``None`` (nothing may exist below ``common_prefix``), a single
    placeholder, or a dict mapping placeholders to sub-patterns. Every top-level entry
    under ``common_prefix`` must be claimed by one of the placeholder keys. Claimed
    entries are then validated recursively against the sub-pattern of the claiming key;
    entries claimed by ``AnyPlaceholder`` are not descended into.

    When every key is a ``OneOfAllPlaceholder``, each alternative is tried against the
    same prefix and at least one of them must validate.

    NOTE(review): ``paths_by_prefix_cache`` and ``predictor_type`` are not parameters;
    they are read from an enclosing scope (presumably this is a closure — confirm).

    Args:
        pattern: The (sub-)template to validate against.
        paths: Candidate paths; only those under ``common_prefix`` are considered.
        common_prefix: The directory prefix being validated at this level.

    Returns:
        At a level whose keys are all ``OneOfAllPlaceholder``: the IDs of the keys whose
        alternative validated. Otherwise: the IDs collected from the recursive
        validation of each claimed sub-directory (empty for a leaf).

    Raises:
        CortexException: if the paths do not conform to ``pattern``.
    """
    if common_prefix not in paths_by_prefix_cache:
        paths_by_prefix_cache[common_prefix] = util.get_paths_with_prefix(paths, common_prefix)
    paths = paths_by_prefix_cache[common_prefix]

    rel_paths = [os.path.relpath(path, common_prefix) for path in paths]
    rel_paths = [path for path in rel_paths if not path.startswith("../")]

    # Deduplicate the top-level entries. Sorting keeps the iteration order (and thus
    # error messages and the order of returned IDs) independent of string hashing.
    objects = sorted({util.get_leftmost_part_of_path(path) for path in rel_paths})
    # Entries stay False until claimed; later code uses the stored values as indices
    # into ``keys``, so the placeholder validators presumably store the claiming key_id.
    visited_objects = len(objects) * [False]

    ooa_valid_key_ids = []  # OneOfAllPlaceholder IDs that are valid

    if pattern is None:
        # Leaf of the template: the only acceptable entry is the prefix itself (".").
        if len(objects) == 1 and objects[0] == ".":
            return ooa_valid_key_ids
        raise CortexException(
            f"{predictor_type} predictor at '{common_prefix}'",
            "template doesn't specify a substructure for the given path",
        )
    if not isinstance(pattern, dict):
        pattern = {pattern: None}

    keys = sorted(pattern.keys(), key=operator.attrgetter("priority"))

    is_ooa = [isinstance(key, OneOfAllPlaceholder) for key in keys]
    all_ooa = all(is_ooa)
    num_keys = len(keys)
    num_validation_failures = 0

    try:
        if any(is_ooa) and not all_ooa:
            raise CortexException(
                f"{predictor_type} predictor at '{common_prefix}'",
                f"{OneOfAllPlaceholder()} is a mutual-exclusive key with all other keys",
            )

        for key_id, key in enumerate(keys):
            if key == IntegerPlaceholder:
                _validate_integer_placeholder(keys, key_id, objects, visited_objects)
            elif key == AnyPlaceholder:
                _validate_any_placeholder(keys, key_id, objects, visited_objects)
            elif key == SinglePlaceholder:
                _validate_single_placeholder(keys, key_id, objects, visited_objects)
            elif isinstance(key, GenericPlaceholder):
                _validate_generic_placeholder(keys, key_id, objects, visited_objects, key)
            elif isinstance(key, PlaceholderGroup):
                _validate_group_placeholder(keys, key_id, objects, visited_objects)
            elif isinstance(key, OneOfAllPlaceholder):
                # Each alternative is validated against the same prefix; a failure of
                # one alternative is only counted, not propagated.
                try:
                    _validate_model_paths(pattern[key], paths, common_prefix)
                    ooa_valid_key_ids.append(key.ID)
                except CortexException:
                    num_validation_failures += 1
            else:
                raise CortexException("found a non-placeholder object in model template")
    except CortexException as e:
        raise CortexException(f"{predictor_type} predictor at '{common_prefix}'", str(e)) from e

    if all_ooa:
        if num_validation_failures == num_keys:
            raise CortexException(
                f"couldn't validate for any of the {OneOfAllPlaceholder()} placeholders"
            )
        return ooa_valid_key_ids

    # Every entry no placeholder claimed is reported with all the paths below it.
    unvisited_paths = []
    for obj, visited in zip(objects, visited_objects):
        if visited is False:
            untraced_common_prefix = os.path.join(common_prefix, obj)
            for path in paths:
                rel_path = os.path.relpath(path, untraced_common_prefix)
                if not rel_path.startswith("../"):
                    unvisited_paths.append(os.path.join(obj, rel_path))
    if unvisited_paths:
        raise CortexException(
            f"{predictor_type} predictor model at '{common_prefix}'",
            "unexpected path(s) for " + str(unvisited_paths),
        )

    # Descend into each claimed entry using the sub-pattern of the key that claimed it.
    new_common_prefixes = []
    sub_patterns = []
    for obj, key_id in zip(objects, visited_objects):
        key = keys[key_id]
        if key != AnyPlaceholder:
            new_common_prefixes.append(os.path.join(common_prefix, obj))
            sub_patterns.append(pattern[key])

    paths_by_prefix = {}
    if new_common_prefixes:
        paths_by_prefix = util.get_paths_by_prefixes(paths, new_common_prefixes)

    aggregated_ooa_valid_key_ids = []
    for sub_pattern, new_common_prefix in zip(sub_patterns, new_common_prefixes):
        aggregated_ooa_valid_key_ids += _validate_model_paths(
            sub_pattern, paths_by_prefix[new_common_prefix], new_common_prefix
        )

    return aggregated_ooa_valid_key_ids
def add_models(
    self,
    model_names: List[str],
    model_versions: List[List[str]],
    model_disk_paths: List[str],
    signature_keys: List[Optional[str]],
    skip_if_present: bool = False,
    timeout: Optional[float] = None,
    max_retries: int = 0,
) -> None:
    """
    Add models to TFS. If they can't be loaded, use remove_models to remove them from TFS.

    Args:
        model_names: List of model names to add.
        model_versions: List of lists - each element is a list of versions for a given model name.
        model_disk_paths: The common model disk path of multiple versioned models of the same model name (i.e. modelA/ for modelA/1 and modelA/2).
        signature_keys: The signature keys as set in cortex_internal.yaml. If an element is set to None, then "predict" key will be assumed.
        skip_if_present: If the models are already loaded, don't make a new request to TFS.
        timeout: Timeout forwarded to each HandleReloadConfigRequest call.
        max_retries: How many additional times to call ReloadConfig after a gRPC error
            before giving up. 0 means a single attempt; negative values are treated as 0.

    Raises:
        grpc.RpcError in case something bad happens while communicating
            (raised after the last attempt fails).
            StatusCode.DEADLINE_EXCEEDED when timeout is encountered.
            StatusCode.UNAVAILABLE when the service is unreachable.
        cortex_internal.lib.exceptions.CortexException if a non-0 response code is returned (i.e. model couldn't be loaded).
        cortex_internal.lib.exceptions.UserException when a model couldn't be validated for the signature def.
    """
    request = model_management_pb2.ReloadConfigRequest()
    model_server_config = model_server_config_pb2.ModelServerConfig()

    # Register every requested (model, version) pair; the helper's return values are
    # summed to detect whether anything new was added.
    num_added_models = 0
    for model_name, versions, model_disk_path in zip(model_names, model_versions, model_disk_paths):
        for model_version in versions:
            versioned_model_disk_path = os.path.join(model_disk_path, model_version)
            num_added_models += self._add_model_to_dict(
                model_name, model_version, versioned_model_disk_path
            )

    if skip_if_present and num_added_models == 0:
        return

    # The reload config lists every model currently tracked (via _get_model_names),
    # not only the ones added above.
    config_list = model_server_config_pb2.ModelConfigList()
    for model_name in self._get_model_names():
        versions, model_disk_path = self._get_model_info(model_name)
        int_versions = [int(version) for version in versions]

        model_config = config_list.config.add()
        model_config.name = model_name
        model_config.base_path = model_disk_path
        model_config.model_version_policy.CopyFrom(
            ServableVersionPolicy(specific=Specific(versions=int_versions))
        )
        model_config.model_platform = "tensorflow"

    model_server_config.model_config_list.CopyFrom(config_list)
    request.config.CopyFrom(model_server_config)

    # One initial attempt plus max_retries retries; the gRPC error is only re-raised
    # once the final attempt has failed. Every other exit path breaks with a response.
    num_attempts = max(max_retries, 0) + 1
    for attempt in range(num_attempts):
        try:
            # to prevent HandleReloadConfigRequest from
            # throwing an exception (TFS has some race-condition bug)
            time.sleep(0.125)
            response = self._service.HandleReloadConfigRequest(request, timeout)
            break
        except grpc.RpcError:
            # to prevent HandleReloadConfigRequest from
            # throwing another exception on the next run
            time.sleep(0.125)
            if attempt == num_attempts - 1:
                raise

    if not response:
        raise CortexException("couldn't load user-requested models")
    if response.status.error_code != 0:
        raise CortexException(
            "couldn't load user-requested models {} - failed with error code {}: {}".format(
                model_names, response.status.error_code, response.status.error_message
            )
        )

    # get models metadata
    for model_name, versions, signature_key in zip(model_names, model_versions, signature_keys):
        for model_version in versions:
            self._load_model_signatures(model_name, model_version, signature_key)