Beispiel #1
0
    def search(self,
               prefix="",
               suffix="") -> Tuple[List[str], List[datetime.datetime]]:
        """
        Lists the keys selected by the given prefix/suffix together with their timestamps.

        Args:
            prefix: Forwarded to the key generator to select keys.
            suffix: Forwarded to the key generator to select keys.

        Returns:
            Two parallel lists: the keys kept by util.remove_non_empty_directory_paths and, at the same index, the timestamp that was reported for each key.
        """
        # If the generator reports the same key more than once, the last timestamp wins.
        key_to_timestamp = {
            key: ts
            for key, ts in self._get_matching_s3_keys_generator(prefix, suffix)
        }

        # Materialize the filtered keys once so both result lists share the same order.
        paths = list(
            util.remove_non_empty_directory_paths(list(key_to_timestamp)))
        timestamps = [key_to_timestamp[key] for key in paths]

        return paths, timestamps
Beispiel #2
0
def model_downloader(
    predictor_type: PredictorType,
    bucket_provider: str,
    bucket_name: str,
    model_name: str,
    model_version: str,
    model_path: str,
    temp_dir: str,
    model_dir: str,
) -> Optional[datetime.datetime]:
    """
    Downloads model to disk. Validates the cloud model path and the downloaded model as well.

    The model is first downloaded to a temporary directory, validated there, and only
    then moved to its final location, replacing any previously stored copy of the same version.

    Args:
        predictor_type: The predictor type as implemented by the API.
        bucket_provider: Provider for the bucket. Can be "s3" or "gs".
        bucket_name: Name of the bucket where the model is stored.
        model_name: Name of the model. Is part of the model's local path.
        model_version: Version of the model. Is part of the model's local path.
        model_path: Model prefix of the versioned model.
        temp_dir: Where to temporarily store the model for validation.
        model_dir: The top directory of where all models are stored locally.

    Returns:
        The model's timestamp (the latest of the timestamps returned by the bucket search).
        None if the model didn't pass the validation, if it doesn't exist or if there are not enough permissions.

    Raises:
        ValueError: If bucket_provider is neither "s3" nor "gs".
    """

    logger().info(
        f"downloading from bucket {bucket_name}/{model_path}, model {model_name} of version {model_version}, temporarily to {temp_dir} and then finally to {model_dir}"
    )

    if bucket_provider == "s3":
        client = S3(bucket_name)
    elif bucket_provider == "gs":
        client = GCS(bucket_name)
    else:
        # Fail with a clear message instead of an UnboundLocalError further down.
        raise ValueError(
            f"unsupported bucket provider {bucket_provider!r}; expected \"s3\" or \"gs\""
        )

    # validate upstream cloud model
    sub_paths, ts = client.search(model_path)
    try:
        validate_model_paths(sub_paths, predictor_type, model_path)
    except CortexException:
        logger().info(
            f"failed validating model {model_name} of version {model_version}")
        return None

    # download model to temp dir
    temp_dest = os.path.join(temp_dir, model_name, model_version)
    try:
        client.download_dir_contents(model_path, temp_dest)
    except CortexException:
        logger().info(
            f"failed downloading model {model_name} of version {model_version} to temp dir {temp_dest}"
        )
        # The download may have failed before anything was written, so the
        # directory is not guaranteed to exist; cleanup must not raise.
        shutil.rmtree(temp_dest, ignore_errors=True)
        return None

    # validate model
    # Only look inside temp_dest itself: a bare "<temp_dest>*" prefix match would
    # also pick up sibling versions such as "10" when validating version "1".
    # The path is escaped so glob metacharacters in names are taken literally.
    model_contents = glob.glob(os.path.join(glob.escape(temp_dest), "**"),
                               recursive=True)
    model_contents = util.remove_non_empty_directory_paths(model_contents)
    try:
        validate_model_paths(model_contents, predictor_type, temp_dest)
    except CortexException:
        logger().info(
            f"failed validating model {model_name} of version {model_version} from temp dir"
        )
        shutil.rmtree(temp_dest, ignore_errors=True)
        return None

    # move model to dest dir
    model_top_dir = os.path.join(model_dir, model_name)
    ondisk_model_version = os.path.join(model_top_dir, model_version)
    logger().info(
        f"moving model {model_name} of version {model_version} to final dir {ondisk_model_version}"
    )
    # Remove any previously stored copy so the move replaces it rather than
    # nesting the new directory inside the old one.
    if os.path.isdir(ondisk_model_version):
        shutil.rmtree(ondisk_model_version)
    shutil.move(temp_dest, ondisk_model_version)

    return max(ts)