def test_s3_file_given_folder(mock_check_output):
    """S3 files and folders download to the explicitly given destination.

    Covers: a top-level file, a nested file, and a folder url both with
    and without a trailing slash.
    """
    client = boto3.client('s3', region_name='us-east-1')
    client.create_bucket(Bucket='mybucket')

    client.put_object(Bucket='mybucket', Key="some_file.txt", Body="CONTENT")
    with tempfile.TemporaryDirectory() as workdir:
        target = os.path.join(workdir, "some_file.txt")
        with downloader.download_manager("s3://mybucket/some_file.txt", target) as result:
            assert result == target

    client.put_object(Bucket='mybucket', Key="some_folder/some_other_file.txt", Body="CONTENT")
    with tempfile.TemporaryDirectory() as workdir:
        target = os.path.join(workdir, "some_other_file.txt")
        with downloader.download_manager(
                "s3://mybucket/some_folder/some_other_file.txt", target) as result:
            assert result == target

    # A folder url must behave the same with and without a trailing slash.
    for url in ("s3://mybucket/some_folder", "s3://mybucket/some_folder/"):
        with tempfile.TemporaryDirectory() as workdir:
            target = os.path.join(workdir, "some_folder")
            with downloader.download_manager(url, target) as result:
                assert result == target
def test_edge_cases():
    """None is rejected with TypeError; the empty string with ValueError."""
    with pytest.raises(TypeError), downloader.download_manager(None) as _:
        pass
    with pytest.raises(ValueError), downloader.download_manager('') as _:
        pass
def process_resources(
        self,
        resources: Dict[str, Union[str, ClusterResource]],
        folder: str) -> Dict[str, Union[str, ClusterResource]]:
    """Download every resource not tagged with '!cluster' into ``folder``.

    Parameters
    ----------
    resources: Dict[str, Union[str, ClusterResource]]
        The resources dict
    folder: str
        The directory where the remote resources will be downloaded.

    Returns
    -------
    Dict[str, Union[str, ClusterResource]]
        A new dict where each non-cluster entry points to the local
        path the resource was downloaded to; ClusterResource entries
        are passed through untouched.

    """
    processed: Dict[str, Union[str, ClusterResource]] = {}
    for name, resource in resources.items():
        if isinstance(resource, ClusterResource):
            # Cluster-tagged resources are resolved remotely; keep as-is.
            processed[name] = resource
        else:
            with download_manager(resource, os.path.join(folder, name)) as local_path:
                processed[name] = local_path
    return processed
def test_s3_inexistent_path():
    """A url to a bucket that does not exist raises ValueError naming the url."""
    url = "s3://inexistent_bucket/file.txt"
    with pytest.raises(ValueError) as err:
        with downloader.download_manager(url) as downloaded:
            _ = downloaded
    assert f"S3 url: '{url}' is not available" in str(err.value)
def test_invalid_local_file():
    """A nonexistent local path raises ValueError mentioning the cause."""
    path = "/some/unexistent/path/!@#$%RVMCDOCMSxxxxoemdow"
    if not os.path.exists(path):
        with pytest.raises(ValueError) as err, downloader.download_manager(path) as p:
            assert path == p
        assert 'does not exist locally.' in str(err.value)
def load_state_from_file(path: str,
                         map_location=None,
                         pickle_module=dill,
                         **pickle_load_args) -> State:
    """Load state from the given path

    Loads a flambe save directory, pickled save object, or a compressed
    version of one of these two formats (using tar + gz). Will
    automatically infer the type of save format and if the directory
    structure is used, the serialization protocol version as well.

    Parameters
    ----------
    path : str
        Path to the save file or directory
    map_location : type
        Location (device) where items will be moved. ONLY used when the
        directory save format is used. See torch.load documentation for
        more details (the default is None).
    pickle_module : type
        Pickle module that has load and dump methods; dump should
        accept a pickle_protocol parameter (the default is dill).
    **pickle_load_args : type
        Additional args that `pickle_module` should use to load; see
        torch.load documentation for more details

    Returns
    -------
    State
        state_dict that can be loaded into a compatible Component

    """
    # `path` may be remote; download_manager yields a local path for it.
    with download_manager(path) as path:
        state = State()
        state._metadata = OrderedDict({FLAMBE_DIRECTORIES_KEY: set()})
        # Holds the TemporaryDirectory used for tar extraction, if any,
        # so it can be cleaned up in the `finally` clause below.
        temp = None
        try:
            if not os.path.isdir(path) and tarfile.is_tarfile(path):
                # Compressed save: extract into a temp dir, then point
                # `path` at the single top-level entry of the archive.
                # NOTE(review): extractall on an untrusted archive is
                # vulnerable to path traversal — confirm inputs are
                # trusted or add member sanitization.
                temp = tempfile.TemporaryDirectory()
                with tarfile.open(path, 'r:gz') as tar_gz:
                    tar_gz.extractall(path=temp.name)
                    expected_name = tar_gz.getnames()[0]
                    path = os.path.join(temp.name, expected_name)
            if os.path.isdir(path):
                # Directory save format: each subdirectory is one
                # component; walk them all and merge into `state`.
                for current_dir, subdirs, files in os.walk(path):
                    # Prefix of this component relative to the root dir.
                    prefix = _extract_prefix(path, current_dir)
                    protocol_version_file = os.path.join(
                        current_dir,
                        PROTOCOL_VERSION_FILE_NAME)
                    with open(protocol_version_file) as f_proto:
                        saved_protocol_version = int(f_proto.read())
                    if saved_protocol_version > HIGHEST_SERIALIZATION_PROTOCOL_VERSION:
                        # NOTE(review): the first two adjacent string
                        # literals render with no separating space
                        # ("...serializationprotocol...") — likely a
                        # missing trailing space; kept byte-identical.
                        raise Exception(
                            'This version of Flambe only supports serialization'
                            f'protocol versions <= '
                            f'{HIGHEST_SERIALIZATION_PROTOCOL_VERSION}. '
                            'Found version '
                            f'{saved_protocol_version} at {protocol_version_file}'
                        )
                    # The component's torch state dict.
                    component_state = torch.load(
                        os.path.join(current_dir, STATE_FILE_NAME),
                        map_location,
                        pickle_module,
                        **pickle_load_args)
                    # Sidecar metadata files written next to the state.
                    with open(os.path.join(current_dir, VERSION_FILE_NAME)) as f_version:
                        version_info = f_version.read()
                        class_name, version = version_info.split(':')
                    with open(os.path.join(current_dir, SOURCE_FILE_NAME)) as f_source:
                        source = f_source.read()
                    with open(os.path.join(current_dir, CONFIG_FILE_NAME)) as f_config:
                        config = f_config.read()
                    with open(os.path.join(current_dir, STASH_FILE_NAME), 'rb') as f_stash:
                        stash = torch.load(f_stash, map_location, pickle_module,
                                           **pickle_load_args)
                    local_metadata = {VERSION_KEY: version,
                                      FLAMBE_CLASS_KEY: class_name,
                                      FLAMBE_SOURCE_KEY: source,
                                      FLAMBE_CONFIG_KEY: config}
                    # Only record a stash when there is something in it.
                    if len(stash) > 0:
                        local_metadata[FLAMBE_STASH_KEY] = stash
                    # Namespace this component's keys under its prefix
                    # before merging into the aggregate state.
                    full_prefix = prefix + STATE_DICT_DELIMETER if prefix != '' else prefix
                    _prefix_keys(component_state, full_prefix)
                    state.update(component_state)
                    if hasattr(component_state, '_metadata'):
                        _prefix_keys(component_state._metadata, full_prefix)
                        # Load torch.nn.Module metadata
                        state._metadata.update(component_state._metadata)
                    # Load flambe.nn.Module metadata
                    state._metadata[prefix] = local_metadata
                    state._metadata[FLAMBE_DIRECTORIES_KEY].add(prefix)
            else:
                # Single-file save: the whole State object was pickled.
                with open(path, 'rb') as f_pkl:
                    state = pickle_module.load(f_pkl)
        except Exception as e:
            # NOTE(review): no-op re-raise; kept for byte-identical
            # behavior (resets traceback origin to this frame).
            raise e
        finally:
            # Always remove the extraction dir, even on failure.
            if temp is not None:
                temp.cleanup()
        return state
def test_s3_inexistent_path():
    """Downloading from a bucket that does not exist raises ValueError."""
    with pytest.raises(ValueError), \
            downloader.download_manager("s3://inexistent_bucket/file.txt") as p:
        _ = p
def test_invalid_protocol():
    """Unsupported url schemes (e.g. sftp) are rejected with ValueError."""
    with pytest.raises(ValueError), \
            downloader.download_manager("sftp://something") as p:
        _ = p
def test_local_file():
    """An existing local path is yielded back unchanged."""
    local = __file__
    with downloader.download_manager(local) as yielded:
        assert yielded == local
def test_invalid_protocol():
    """An unsupported scheme raises ValueError with an explanatory message."""
    with pytest.raises(ValueError) as err:
        with downloader.download_manager("sftp://something") as p:
            _ = p
    assert 'Only S3 and http/https URLs are supported.' in str(err.value)