Example No. 1
def validate_models_dir_paths(
        paths: List[str], predictor_type: PredictorType,
        common_prefix: str) -> Tuple[List[str], List[List[int]]]:
    """
    Validate the model paths found under a common directory prefix.

    To be used when predictor:models:dir in cortex.yaml is used.

    Args:
        paths: Every path found for a given s3/local prefix. Must be underneath the common prefix.
        predictor_type: The predictor type.
        common_prefix: The common prefix of the directory which holds all models. AKA predictor:models:dir.

    Returns:
        List with the prefix of each model that's valid.
        List with the OneOfAllPlaceholder IDs validated for each valid model.
    """
    if not paths:
        raise CortexException(
            f"{predictor_type} predictor at '{common_prefix}'",
            "model top path can't be empty")

    # keep only paths that really live under the common prefix
    relative = (os.path.relpath(p, common_prefix) for p in paths)
    rel_paths = [p for p in relative if not p.startswith("../")]

    # each model is identified by the first path component under the prefix;
    # a set removes duplicates
    model_names = {util.get_leftmost_part_of_path(p) for p in rel_paths}

    valid_model_prefixes = []
    ooa_valid_key_ids = []
    for name in model_names:
        try:
            key_ids = validate_model_paths(rel_paths, predictor_type, name)
        except CortexException as e:
            # invalid layout for this model; skip it and keep scanning
            logger.debug(f"failed validating model {name}: {str(e)}")
            continue
        ooa_valid_key_ids.append(key_ids)
        valid_model_prefixes.append(os.path.join(common_prefix, name))

    return valid_model_prefixes, ooa_valid_key_ids
Example No. 2
def find_all_cloud_models(
    is_dir_used: bool,
    models_dir: str,
    predictor_type: PredictorType,
    cloud_paths: List[str],
    cloud_model_names: List[str],
) -> Tuple[List[str], Dict[str, List[str]], List[str], List[List[str]],
           List[List[datetime.datetime]], List[str], List[str], ]:
    """
    Get updated information on all models that are currently present on the cloud upstreams.
    Information on the available models, versions, last edit times, the subpaths of each model, and so on.

    Args:
        is_dir_used: Whether predictor:models:dir is used or not.
        models_dir: The value of predictor:models:dir in case it's present. Ignored when not required.
        predictor_type: The predictor type.
        cloud_paths: The cloud model paths as they are specified in predictor:models:path/predictor:models:paths/predictor:models:dir is used. Ignored when not required.
        cloud_model_names: The cloud model names as they are specified in predictor:models:paths:name when predictor:models:paths is used or the default name of the model when predictor:models:path is used. Ignored when not required.

    Returns: The tuple with the following elements:
        model_names - a list with the names of the models (i.e. bert, gpt-2, etc) and they are unique
        versions - a dictionary with the keys representing the model names and the values being lists of versions that each model has.
          For non-versioned model paths ModelVersion.NOT_PROVIDED, the list will be empty.
        model_paths - a list with the prefix of each model.
        sub_paths - a list of filepaths lists for each file of each model.
        timestamps - a list of timestamps lists representing the last edit time of each versioned model.
        bucket_providers - a list of the bucket providers for each model. Can be "s3" or "gs".
        bucket_names - a list of the bucket names of each model.
    """

    # validate models stored in cloud (S3 or GS) that were specified with predictor:models:dir field
    if is_dir_used:
        # determine provider and client once up front (previously the
        # is_valid_*_path checks were re-run a second time below just to
        # build bucket_providers)
        if S3.is_valid_s3_path(models_dir):
            provider = "s3"
            bucket_name, models_path = S3.deconstruct_s3_path(models_dir)
            client = S3(bucket_name)
        elif GCS.is_valid_gcs_path(models_dir):
            provider = "gs"
            bucket_name, models_path = GCS.deconstruct_gcs_path(models_dir)
            client = GCS(bucket_name)
        # NOTE(review): if models_dir is neither a valid S3 nor GCS path,
        # `client` is unbound and the next line raises NameError — same
        # failure mode as the original; presumably prevented by upstream
        # config validation, confirm.

        sub_paths, timestamps = client.search(models_path)

        model_paths, ooa_ids = validate_models_dir_paths(
            sub_paths, predictor_type, models_path)
        model_names = [
            os.path.basename(model_path) for model_path in model_paths
        ]

        # bug fix: removed the former filter
        #   [p for p in model_paths if os.path.basename(p) in model_names]
        # which was a no-op (model_names was just derived from model_paths,
        # so the membership test was always true).

        # ensure each model prefix ends with exactly one trailing slash
        model_paths = [
            model_path if model_path.endswith("/") else model_path + "/"
            for model_path in model_paths
        ]

        bucket_providers = len(model_paths) * [provider]
        bucket_names = len(model_paths) * [bucket_name]
        # in the dir case every model shares the same search results
        sub_paths = len(model_paths) * [sub_paths]
        timestamps = len(model_paths) * [timestamps]

    # validate models stored in cloud (S3 or GS) that were specified with predictor:models:paths field
    else:
        sub_paths = []
        ooa_ids = []
        model_paths = []
        model_names = []
        timestamps = []
        bucket_providers = []
        bucket_names = []
        for idx, path in enumerate(cloud_paths):
            if S3.is_valid_s3_path(path):
                provider = "s3"
                bucket_name, model_path = S3.deconstruct_s3_path(path)
                client = S3(bucket_name)
            elif GCS.is_valid_gcs_path(path):
                provider = "gs"
                bucket_name, model_path = GCS.deconstruct_gcs_path(path)
                client = GCS(bucket_name)
            else:
                # path belongs to neither supported provider; skip it
                continue

            sb, model_path_ts = client.search(model_path)
            try:
                ooa_ids.append(
                    validate_model_paths(sb, predictor_type, model_path))
            except CortexException:
                # model layout doesn't validate; best-effort scan, skip it
                continue
            model_paths.append(model_path)
            model_names.append(cloud_model_names[idx])
            bucket_names.append(bucket_name)
            bucket_providers.append(provider)
            sub_paths += [sb]
            timestamps += [model_path_ts]

    # determine the detected versions for each cloud model
    # if the model was not versioned, then leave the version list empty
    versions = {}
    for model_path, model_name, model_ooa_ids, bucket_sub_paths in zip(
            model_paths, model_names, ooa_ids, sub_paths):
        if ModelVersion.PROVIDED not in model_ooa_ids:
            versions[model_name] = []
            continue

        # version names are the first path component under the model prefix
        model_sub_paths = [
            os.path.relpath(sub_path, model_path)
            for sub_path in bucket_sub_paths
        ]
        model_versions_paths = [
            path for path in model_sub_paths if not path.startswith("../")
        ]
        model_versions = [
            util.get_leftmost_part_of_path(model_version_path)
            for model_version_path in model_versions_paths
        ]
        model_versions = list(set(model_versions))
        versions[model_name] = model_versions

    # pick up the max timestamp for each versioned model
    aux_timestamps = []
    for model_path, model_name, bucket_sub_paths, sub_path_timestamps in zip(
            model_paths, model_names, sub_paths, timestamps):
        model_ts = []
        # bug fix: the previous prefix expression indexed model_path[-1] and
        # crashed on an empty model path; endswith is safe for "".
        prefix = model_path if model_path.endswith("/") else model_path + "/"
        if len(versions[model_name]) == 0:
            # non-versioned model: a single timestamp, the newest file under it
            masks = [
                sub_path.startswith(prefix) for sub_path in bucket_sub_paths
            ]
            # NOTE(review): max() raises ValueError when no file matches the
            # prefix — assumed non-empty because validation passed; confirm.
            model_ts = [max(itertools.compress(sub_path_timestamps, masks))]

        for version in versions[model_name]:
            version_prefix = os.path.join(model_path, version) + "/"
            masks = [
                sub_path.startswith(version_prefix)
                for sub_path in bucket_sub_paths
            ]
            model_ts.append(max(itertools.compress(sub_path_timestamps,
                                                   masks)))

        aux_timestamps.append(model_ts)

    timestamps = aux_timestamps  # type: List[List[datetime.datetime]]

    return model_names, versions, model_paths, sub_paths, timestamps, bucket_providers, bucket_names
Example No. 3
    def _validate_model_paths(pattern: Any, paths: List[str], common_prefix: str) -> List[int]:
        """
        Recursively validate `paths` under `common_prefix` against the
        placeholder template `pattern`.

        NOTE(review): this is a closure — it relies on names from the enclosing
        scope that are not visible here: `predictor_type`,
        `paths_by_prefix_cache`, the placeholder classes, and the
        `_validate_*_placeholder` helpers. The return annotation was `None`
        but every exit returns a list of validated OneOfAllPlaceholder IDs
        (matching the sibling `Tuple[List[str], List[List[int]]]` signature),
        so it has been corrected to `List[int]`.

        Raises:
            CortexException: when the paths do not match the template.
        """
        # narrow the path list to those under the prefix, memoized per prefix
        if common_prefix not in paths_by_prefix_cache:
            paths_by_prefix_cache[common_prefix] = util.get_paths_with_prefix(paths, common_prefix)
        paths = paths_by_prefix_cache[common_prefix]

        # keep only paths that truly live under the prefix
        rel_paths = [os.path.relpath(path, common_prefix) for path in paths]
        rel_paths = [path for path in rel_paths if not path.startswith("../")]

        # the immediate children (first path component) of the prefix
        objects = [util.get_leftmost_part_of_path(path) for path in rel_paths]
        objects = list(set(objects))
        # per-object marker; validators presumably overwrite False with the
        # index of the key that matched the object — confirm against the
        # _validate_*_placeholder helpers
        visited_objects = len(objects) * [False]

        ooa_valid_key_ids = []  # OneOfAllPlaceholder IDs that are valid

        # a None pattern means "leaf": the prefix itself must be a file
        # (relpath of the prefix to itself is ".")
        if pattern is None:
            if len(objects) == 1 and objects[0] == ".":
                return ooa_valid_key_ids
            raise CortexException(
                f"{predictor_type} predictor at '{common_prefix}'",
                "template doesn't specify a substructure for the given path",
            )
        # normalize a single placeholder to a one-entry dict with no substructure
        if not isinstance(pattern, dict):
            pattern = {pattern: None}

        # evaluate placeholders in priority order
        keys = list(pattern.keys())
        keys.sort(key=operator.attrgetter("priority"))

        try:
            # OneOfAllPlaceholder keys are all-or-nothing: they cannot be
            # mixed with other placeholder kinds at the same level
            if any(isinstance(x, OneOfAllPlaceholder) for x in keys) and not all(
                isinstance(x, OneOfAllPlaceholder) for x in keys
            ):
                raise CortexException(
                    f"{predictor_type} predictor at '{common_prefix}'",
                    f"{OneOfAllPlaceholder()} is a mutual-exclusive key with all other keys",
                )
            elif all(isinstance(x, OneOfAllPlaceholder) for x in keys):
                # track how many of the alternative templates fail; at least
                # one must succeed (checked after the loop)
                num_keys = len(keys)
                num_validation_failures = 0

            # dispatch each placeholder key to its dedicated validator, which
            # also marks the objects it accounts for in visited_objects
            for key_id, key in enumerate(keys):
                if key == IntegerPlaceholder:
                    _validate_integer_placeholder(keys, key_id, objects, visited_objects)
                elif key == AnyPlaceholder:
                    _validate_any_placeholder(keys, key_id, objects, visited_objects)
                elif key == SinglePlaceholder:
                    _validate_single_placeholder(keys, key_id, objects, visited_objects)
                elif isinstance(key, GenericPlaceholder):
                    _validate_generic_placeholder(keys, key_id, objects, visited_objects, key)
                elif isinstance(key, PlaceholderGroup):
                    _validate_group_placeholder(keys, key_id, objects, visited_objects)
                elif isinstance(key, OneOfAllPlaceholder):
                    # try the whole alternative sub-template; failures are
                    # tolerated as long as one alternative validates
                    try:
                        _validate_model_paths(pattern[key], paths, common_prefix)
                        ooa_valid_key_ids.append(key.ID)
                    except CortexException:
                        num_validation_failures += 1
                else:
                    raise CortexException("found a non-placeholder object in model template")

        except CortexException as e:
            # re-raise with the predictor/prefix context prepended
            raise CortexException(f"{predictor_type} predictor at '{common_prefix}'", str(e))

        # in the all-OneOfAll case, at least one alternative must have passed
        if (
            all(isinstance(x, OneOfAllPlaceholder) for x in keys)
            and num_validation_failures == num_keys
        ):
            raise CortexException(
                f"couldn't validate for any of the {OneOfAllPlaceholder()} placeholders"
            )
        if all(isinstance(x, OneOfAllPlaceholder) for x in keys):
            return ooa_valid_key_ids

        # every object at this level must have been claimed by some placeholder;
        # collect the full relative paths of any that were not, for the error
        unvisited_paths = []
        for idx, visited in enumerate(visited_objects):
            if visited is False:
                untraced_common_prefix = os.path.join(common_prefix, objects[idx])
                untraced_paths = [os.path.relpath(path, untraced_common_prefix) for path in paths]
                untraced_paths = [
                    os.path.join(objects[idx], path)
                    for path in untraced_paths
                    if not path.startswith("../")
                ]
                unvisited_paths += untraced_paths
        if len(unvisited_paths) > 0:
            raise CortexException(
                f"{predictor_type} predictor model at '{common_prefix}'",
                "unexpected path(s) for " + str(unvisited_paths),
            )

        # recurse into each matched object with its sub-template
        # NOTE(review): each visited_objects entry is treated as the index of
        # the key that matched the object (all objects are visited at this
        # point thanks to the check above); a leftover False would silently
        # index keys[0] — confirm the validators always store the key index
        new_common_prefixes = []
        sub_patterns = []
        paths_by_prefix = {}
        for obj_id, key_id in enumerate(visited_objects):
            obj = objects[obj_id]
            key = keys[key_id]
            # AnyPlaceholder accepts anything below it, so no recursion needed
            if key != AnyPlaceholder:
                new_common_prefixes.append(os.path.join(common_prefix, obj))
                sub_patterns.append(pattern[key])

        if len(new_common_prefixes) > 0:
            paths_by_prefix = util.get_paths_by_prefixes(paths, new_common_prefixes)

        # aggregate the OneOfAll IDs validated in the subtrees
        aggregated_ooa_valid_key_ids = []
        for sub_pattern, new_common_prefix in zip(sub_patterns, new_common_prefixes):
            aggregated_ooa_valid_key_ids += _validate_model_paths(
                sub_pattern, paths_by_prefix[new_common_prefix], new_common_prefix
            )

        return aggregated_ooa_valid_key_ids