Beispiel #1
0
def test_s3_file_given_folder(mock_check_output):
    """Downloading S3 objects and S3 "folders" yields the requested destination.

    ``mock_check_output`` is a fixture provided elsewhere in the test
    module; it is not referenced directly in this test body.
    """
    client = boto3.client('s3', region_name='us-east-1')
    client.create_bucket(Bucket='mybucket')

    def _assert_downloads_to(url, local_name):
        # Each download gets its own scratch directory, as before.
        with tempfile.TemporaryDirectory() as scratch:
            target = os.path.join(scratch, local_name)
            with downloader.download_manager(url, target) as downloaded:
                assert downloaded == target

    # A single object at the bucket root.
    client.put_object(Bucket='mybucket', Key="some_file.txt", Body="CONTENT")
    _assert_downloads_to("s3://mybucket/some_file.txt", "some_file.txt")

    # An object nested under a prefix, then the prefix itself with and
    # without a trailing slash.
    client.put_object(Bucket='mybucket', Key="some_folder/some_other_file.txt", Body="CONTENT")
    cases = (
        ("s3://mybucket/some_folder/some_other_file.txt", "some_other_file.txt"),
        ("s3://mybucket/some_folder", "some_folder"),
        ("s3://mybucket/some_folder/", "some_folder"),
    )
    for url, local_name in cases:
        _assert_downloads_to(url, local_name)
Beispiel #2
0
def test_edge_cases():
    """``None`` must raise TypeError and the empty string ValueError."""
    bad_inputs = (
        (None, TypeError),
        ('', ValueError),
    )
    for bad_url, expected_error in bad_inputs:
        with pytest.raises(expected_error):
            with downloader.download_manager(bad_url) as _:
                pass
Beispiel #3
0
    def process_resources(
            self, resources: Dict[str, Union[str, ClusterResource]],
            folder: str) -> Dict[str, Union[str, ClusterResource]]:
        """Download resources that are not tagged with '!cluster'
        into a given directory.

        Entries whose value is a ``ClusterResource`` are copied to the
        result unchanged. Every other value is passed to
        ``download_manager`` with destination ``folder/<key>``, and the
        path it yields is stored under the same key.

        Parameters
        ----------
        resources: Dict[str, Union[str, ClusterResource]]
            The resources dict
        folder: str
            The directory where the remote resources
            will be downloaded.

        Returns
        -------
        Dict[str, Union[str, ClusterResource]]
            A new dict with the same keys, where the non-cluster
            entries now point to the local path yielded by
            ``download_manager``.

        """
        processed: Dict[str, Union[str, ClusterResource]] = {}
        for name, resource in resources.items():
            if isinstance(resource, ClusterResource):
                processed[name] = resource
                continue
            destination = os.path.join(folder, name)
            with download_manager(resource, destination) as local_path:
                processed[name] = local_path
        return processed
Beispiel #4
0
def test_s3_inexistent_path():
    """A URL into a missing S3 bucket raises ValueError naming the URL."""
    missing_url = "s3://inexistent_bucket/file.txt"
    expected_msg = f"S3 url: '{missing_url}' is not available"

    with pytest.raises(ValueError) as raised:
        with downloader.download_manager(missing_url) as downloaded:
            _ = downloaded

    assert expected_msg in str(raised.value)
Beispiel #5
0
def test_invalid_local_file():
    """A local path that does not exist makes download_manager raise ValueError.

    If the (deliberately absurd) path happens to exist on the machine,
    the test is reported as skipped instead of silently passing.
    """
    path = "/some/unexistent/path/!@#$%RVMCDOCMSxxxxoemdow"

    if os.path.exists(path):
        # Surface the precondition failure instead of a vacuous pass.
        pytest.skip(f"{path} unexpectedly exists on this machine")

    with pytest.raises(ValueError) as excinfo:
        with downloader.download_manager(path) as p:
            assert path == p

    assert 'does not exist locally.' in str(excinfo.value)
Beispiel #6
0
def load_state_from_file(path: str,
                         map_location=None,
                         pickle_module=dill,
                         **pickle_load_args) -> State:
    """Load state from the given path

    Loads a flambe save directory, pickled save object, or a compressed
    version of one of these two formats (a tar archive, e.g. tar + gz).
    Will automatically infer the type of save format and if the
    directory structure is used, the serialization protocol version as
    well.

    Parameters
    ----------
    path : str
        Path to the save file or directory. It is first resolved through
        ``download_manager``, so anything that helper accepts may be
        passed.
    map_location : type
        Location (device) where items will be moved. ONLY used when the
        directory save format is used. See torch.load documentation for
        more details (the default is None).
    pickle_module : type
        Pickle module that has load and dump methods; dump should
        accept a pickle_protocol parameter (the default is dill).
    **pickle_load_args : type
        Additional args that `pickle_module` should use to load; see
        torch.load documentation for more details

    Returns
    -------
    State
        state_dict that can be loaded into a compatible Component

    Raises
    ------
    Exception
        If a saved directory declares a serialization protocol version
        higher than ``HIGHEST_SERIALIZATION_PROTOCOL_VERSION``.

    Notes
    -----
    Loading unpickles data (``torch.load`` / ``pickle_module.load``) and
    extracts tar archives; only load saves from trusted sources.

    """
    with download_manager(path) as local_path:
        state = State()
        state._metadata = OrderedDict({FLAMBE_DIRECTORIES_KEY: set()})
        temp = None
        try:
            if not os.path.isdir(local_path) and tarfile.is_tarfile(local_path):
                temp = tempfile.TemporaryDirectory()
                # 'r:*' lets tarfile detect the compression itself, so
                # uncompressed archives (which is_tarfile also accepts)
                # no longer fail here; gzip archives behave as before.
                with tarfile.open(local_path, 'r:*') as archive:
                    # NOTE(review): extractall trusts member paths in the
                    # archive; a malicious save could write outside
                    # temp.name. Consider filter='data' once the minimum
                    # supported Python provides extraction filters.
                    archive.extractall(path=temp.name)
                    # The first member is taken to be the save's root.
                    expected_name = archive.getnames()[0]
                local_path = os.path.join(temp.name, expected_name)
            if os.path.isdir(local_path):
                # Every directory in the tree is read as one component's
                # save: protocol version, state, version, source, config
                # and stash files are all opened unconditionally.
                for current_dir, _, _ in os.walk(local_path):
                    prefix = _extract_prefix(local_path, current_dir)
                    protocol_version_file = os.path.join(
                        current_dir, PROTOCOL_VERSION_FILE_NAME)
                    with open(protocol_version_file) as f_proto:
                        saved_protocol_version = int(f_proto.read())
                        if saved_protocol_version > HIGHEST_SERIALIZATION_PROTOCOL_VERSION:
                            raise Exception(
                                'This version of Flambe only supports serialization '
                                'protocol versions <= '
                                f'{HIGHEST_SERIALIZATION_PROTOCOL_VERSION}. '
                                'Found version '
                                f'{saved_protocol_version} at {protocol_version_file}'
                            )
                    component_state = torch.load(
                        os.path.join(current_dir, STATE_FILE_NAME),
                        map_location, pickle_module, **pickle_load_args)
                    with open(os.path.join(current_dir,
                                           VERSION_FILE_NAME)) as f_version:
                        version_info = f_version.read()
                        # Stored as "<class_name>:<version>".
                        class_name, version = version_info.split(':')
                    with open(os.path.join(current_dir,
                                           SOURCE_FILE_NAME)) as f_source:
                        source = f_source.read()
                    with open(os.path.join(current_dir,
                                           CONFIG_FILE_NAME)) as f_config:
                        config = f_config.read()
                    with open(os.path.join(current_dir, STASH_FILE_NAME),
                              'rb') as f_stash:
                        stash = torch.load(f_stash, map_location,
                                           pickle_module, **pickle_load_args)
                    local_metadata = {
                        VERSION_KEY: version,
                        FLAMBE_CLASS_KEY: class_name,
                        FLAMBE_SOURCE_KEY: source,
                        FLAMBE_CONFIG_KEY: config
                    }
                    # Only record a stash when it actually holds something.
                    if len(stash) > 0:
                        local_metadata[FLAMBE_STASH_KEY] = stash
                    # The root directory has an empty prefix and gets no
                    # delimiter; nested components are namespaced.
                    full_prefix = prefix + STATE_DICT_DELIMETER if prefix != '' else prefix
                    _prefix_keys(component_state, full_prefix)
                    state.update(component_state)
                    if hasattr(component_state, '_metadata'):
                        _prefix_keys(component_state._metadata, full_prefix)
                        # Load torch.nn.Module metadata
                        state._metadata.update(component_state._metadata)
                    # Load flambe.nn.Module metadata
                    state._metadata[prefix] = local_metadata
                    state._metadata[FLAMBE_DIRECTORIES_KEY].add(prefix)
            else:
                # Single pickled save object.
                with open(local_path, 'rb') as f_pkl:
                    state = pickle_module.load(f_pkl)
        finally:
            # Remove the extraction directory even if loading failed.
            if temp is not None:
                temp.cleanup()
        return state
Beispiel #7
0
def test_s3_inexistent_path():
    """download_manager rejects an S3 URL whose bucket does not exist."""
    url = "s3://inexistent_bucket/file.txt"
    with pytest.raises(ValueError):
        with downloader.download_manager(url) as local_path:
            _ = local_path
Beispiel #8
0
def test_invalid_protocol():
    """Unsupported URL schemes such as sftp:// raise ValueError."""
    unsupported_url = "sftp://something"
    with pytest.raises(ValueError):
        with downloader.download_manager(unsupported_url) as local_path:
            _ = local_path
Beispiel #9
0
def test_local_file():
    """An existing local path is yielded back unchanged."""
    local_path = __file__
    with downloader.download_manager(local_path) as resolved:
        assert resolved == local_path
Beispiel #10
0
def test_invalid_protocol():
    """Unsupported schemes raise ValueError with an explanatory message."""
    expected_msg = 'Only S3 and http/https URLs are supported.'
    with pytest.raises(ValueError) as raised:
        with downloader.download_manager("sftp://something") as local_path:
            _ = local_path
    assert expected_msg in str(raised.value)