Esempio n. 1
0
    def __init__(
        self,
        region: str = JUMPSTART_DEFAULT_REGION_NAME,
        max_s3_cache_items: int = JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS,
        s3_cache_expiration_horizon: datetime.timedelta =
        JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON,
        max_semantic_version_cache_items: int =
        JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS,
        semantic_version_cache_expiration_horizon: datetime.timedelta =
        JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON,
        manifest_file_s3_key: str =
        JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY,
        s3_bucket_name: Optional[str] = None,
        s3_client_config: Optional[botocore.config.Config] = None,
    ) -> None:  # fmt: on
        """Set up a ``JumpStartModelsCache`` with its two LRU caches and an S3 client.

        Args:
            region (str): AWS region the cache is bound to.
                Default: ``JUMPSTART_DEFAULT_REGION_NAME``.
            max_s3_cache_items (int): Capacity of the S3 content cache.
                Default: ``JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS``.
            s3_cache_expiration_horizon (datetime.timedelta): How long an entry may
                live in the S3 content cache before it is invalidated.
                Default: ``JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON``.
            max_semantic_version_cache_items (int): Capacity of the cache that maps
                (model id, semantic version) requests to manifest keys.
                Default: ``JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS``.
            semantic_version_cache_expiration_horizon (datetime.timedelta): How long
                an entry may live in the semantic-version cache before invalidation.
                Default: ``JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON``.
            manifest_file_s3_key (str): S3 key of the SDK metadata manifest.
                Default: ``JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY``.
            s3_bucket_name (Optional[str]): S3 bucket backing the cache. When ``None``,
                ``utils.get_jumpstart_content_bucket(region)`` is used.
            s3_client_config (Optional[botocore.config.Config]): Optional config for the
                S3 client; when falsy the client is created without a ``config``.
        """

        self._region = region

        # Cache of raw S3 content; misses are filled by ``_get_file_from_s3``.
        s3_cache_type = LRUCache[JumpStartCachedS3ContentKey, JumpStartCachedS3ContentValue]
        self._s3_cache = s3_cache_type(
            max_cache_items=max_s3_cache_items,
            expiration_horizon=s3_cache_expiration_horizon,
            retrieval_function=self._get_file_from_s3,
        )

        # Cache resolving a versioned model id to its concrete manifest key.
        version_cache_type = LRUCache[JumpStartVersionedModelId, JumpStartVersionedModelId]
        self._model_id_semantic_version_manifest_key_cache = version_cache_type(
            max_cache_items=max_semantic_version_cache_items,
            expiration_horizon=semantic_version_cache_expiration_horizon,
            retrieval_function=self._get_manifest_key_from_model_id_semantic_version,
        )

        self._manifest_file_s3_key = manifest_file_s3_key

        if s3_bucket_name is None:
            s3_bucket_name = utils.get_jumpstart_content_bucket(self._region)
        self.s3_bucket_name = s3_bucket_name

        # Only forward ``config`` when one was supplied, so the no-config call
        # stays exactly ``boto3.client("s3", region_name=...)``.
        client_kwargs = {"region_name": self._region}
        if s3_client_config:
            client_kwargs["config"] = s3_client_config
        self._s3_client = boto3.client("s3", **client_kwargs)
Esempio n. 2
0
def _retrieve_script_uri(
    model_id: str,
    model_version: str,
    script_scope: Optional[str],
    region: Optional[str],
    tolerate_vulnerable_model: bool,
    tolerate_deprecated_model: bool,
):
    """Retrieves the script S3 URI associated with the model matching the given arguments.

    Args:
        model_id (str): JumpStart model ID of the JumpStart model for which to
            retrieve the script S3 URI.
        model_version (str): Version of the JumpStart model for which to
            retrieve the model script S3 URI.
        script_scope (Optional[str]): The script type, i.e. what it is used for.
            Valid values: "training" and "inference".
        region (Optional[str]): Region for which to retrieve model script S3 URI.
            Defaults to ``JUMPSTART_DEFAULT_REGION_NAME`` when ``None``.
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised). If False, raises an
            exception if the script used by this version of the model has dependencies with known
            security vulnerabilities.
        tolerate_deprecated_model (bool): True if deprecated versions of model
            specifications should be tolerated (exception not raised). If False, raises
            an exception if the version of the model is deprecated.
    Returns:
        str: the model script URI for the corresponding model, of the form
        ``s3://<jumpstart content bucket>/<script key>``.

    Raises:
        ValueError: If the combination of arguments specified is not supported,
            including a ``script_scope`` that is neither inference nor training.
        VulnerableJumpStartModelError: If any of the dependencies required by the script have
            known security vulnerabilities.
        DeprecatedJumpStartModelError: If the version of the model is deprecated.
    """
    if region is None:
        region = JUMPSTART_DEFAULT_REGION_NAME

    model_specs = verify_model_region_and_return_specs(
        model_id=model_id,
        version=model_version,
        scope=script_scope,
        region=region,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
    )

    if script_scope == JumpStartScriptScope.INFERENCE:
        model_script_key = model_specs.hosting_script_key
    elif script_scope == JumpStartScriptScope.TRAINING:
        model_script_key = model_specs.training_script_key
    else:
        # Without this branch an unrecognised scope left ``model_script_key``
        # unbound and surfaced as an UnboundLocalError instead of the
        # documented ValueError.
        raise ValueError(f"Unsupported script scope: {script_scope!r}")

    bucket = get_jumpstart_content_bucket(region)

    return f"s3://{bucket}/{model_script_key}"
Esempio n. 3
0
def test_get_jumpstart_content_bucket_override():
    """The bucket-override env var takes precedence and its use is logged once."""
    override_bucket = "some-val"
    env_override = {ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE: override_bucket}

    with patch.dict(os.environ, env_override), patch(
        "logging.Logger.info"
    ) as mocked_info_log:
        resolved_bucket = utils.get_jumpstart_content_bucket("random_region")
        assert resolved_bucket == override_bucket
        mocked_info_log.assert_called_once_with(
            "Using JumpStart bucket override: '%s'",
            override_bucket,
        )
Esempio n. 4
0
def download_inference_assets():
    """Fetch every asset in ``TEST_ASSETS_SPECS`` into ``TMP_DIRECTORY_PATH``.

    Each asset is stored at ``TMP_DIRECTORY_PATH/str(asset.value)``; files that
    already exist locally are skipped. Missing ones are downloaded via
    ``download_file`` from the JumpStart content bucket of
    ``JUMPSTART_DEFAULT_REGION_NAME`` using the spec's S3 key.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(TMP_DIRECTORY_PATH, exist_ok=True)

    # Created lazily and reused: nothing is constructed when every asset is
    # already cached, and at most one client/bucket lookup happens otherwise.
    s3_client = None
    bucket = None

    for asset, s3_key in TEST_ASSETS_SPECS.items():
        file_path = os.path.join(TMP_DIRECTORY_PATH, str(asset.value))
        if os.path.exists(file_path):
            continue
        if s3_client is None:
            bucket = get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)
            s3_client = boto3.client("s3")
        download_file(
            file_path,
            bucket,
            s3_key,
            s3_client,
        )
Esempio n. 5
0
def test_get_jumpstart_content_bucket():
    """An unknown region is rejected with ``ValueError``."""
    unsupported_region = "bad_region"
    # Guard: the fixture value must really fall outside the supported regions.
    assert unsupported_region not in JUMPSTART_REGION_NAME_SET

    with pytest.raises(ValueError):
        utils.get_jumpstart_content_bucket(unsupported_region)