def download_and_install(uri, name=DEFAULT_MODULE_NAME, cache=True): # type: (str, str, bool) -> None """Download, prepare and install a compressed tar file from S3 or local directory as a module. The SageMaker Python SDK saves the user provided scripts as compressed tar files in S3. This function downloads this compressed file and, if provided, transforms it into a module before installing it. This method is the predecessor of :meth:`~sagemaker_containers.beta.framework.files.download_and_extract` and has been kept for backward-compatibility purposes. Args: name (str): name of the script or module. uri (str): the location of the module. cache (bool): defaults to True. It will not download and install the module again if it is already installed. """ should_use_cache = cache and exists(name) if not should_use_cache: with _files.tmpdir() as tmpdir: if uri.startswith('s3://'): dst = os.path.join(tmpdir, 'tar_file') _files.s3_download(uri, dst) module_path = os.path.join(tmpdir, 'module_dir') os.makedirs(module_path) with tarfile.open(name=dst, mode='r:gz') as t: t.extractall(path=module_path) else: module_path = uri prepare(module_path, name) install(module_path)
def test_s3_download(resource, url, bucket_name, key, dst): region = 'us-west-2' os.environ[_params.REGION_NAME_ENV] = region _files.s3_download(url, dst) chain = call('s3', region_name=region).Bucket(bucket_name).download_file(key, dst) assert resource.mock_calls == chain.call_list()
def s3_download(url, dst): # type: (str, str) -> None """Download a file from S3. This method acts as an alias for :meth:`~sagemaker_containers.beta.framework.files.s3_download` for backward-compatibility purposes. Args: url (str): the S3 URL of the file. dst (str): the destination where the file will be saved. """ _files.s3_download(url, dst)
def test_s3_download(resource, url, bucket_name, key, dst, endpoint): region = "us-west-2" os.environ[_params.REGION_NAME_ENV] = region if endpoint is not None: os.environ[_params.S3_ENDPOINT_URL] = endpoint _files.s3_download(url, dst) chain = (call("s3", region_name=region, endpoint_url=endpoint).Bucket(bucket_name).download_file( key, dst)) assert resource.mock_calls == chain.call_list()
def test_s3_download_wrong_scheme(): with pytest.raises( ValueError, message="Expecting 's3' scheme, got: c in c://my-bucket/my-file"): _files.s3_download("c://my-bucket/my-file", "/tmp/file")