コード例 #1
0
class GCSFileManager(FileManager):
    """File manager that stores files in a Google Cloud Storage bucket.

    Files written through this manager are uploaded to ``gcs_bucket`` under
    ``<gcs_base_key>/<uuid4>[.<ext>]``. Reading a handle downloads the object
    into a local temporary file obtained from a ``TempfileManager``; the local
    path is cached per ``gcs_path`` so repeated reads do not re-download.
    """

    def __init__(self, client, gcs_bucket, gcs_base_key):
        self._client = check.inst_param(client, "client", storage.client.Client)
        self._gcs_bucket = check.str_param(gcs_bucket, "gcs_bucket")
        self._gcs_base_key = check.str_param(gcs_base_key, "gcs_base_key")
        # gcs_path -> local temp file path holding the downloaded contents.
        self._local_handle_cache = {}
        self._temp_file_manager = TempfileManager()

    def copy_handle_to_local_temp(self, file_handle):
        """Download ``file_handle`` if needed and return the local temp path."""
        self._download_if_not_cached(file_handle)
        return self._get_local_path(file_handle)

    def _download_if_not_cached(self, file_handle):
        """Download the blob behind ``file_handle`` unless already cached.

        Returns the same ``file_handle``.
        """
        if not self._file_handle_cached(file_handle):
            # instigate download
            temp_file_obj = self._temp_file_manager.tempfile()
            temp_name = temp_file_obj.name
            bucket_obj = self._client.bucket(file_handle.gcs_bucket)
            bucket_obj.blob(file_handle.gcs_key).download_to_file(temp_file_obj)
            # The temp file object is left open here, and everything downstream
            # re-opens the file by *name*. Flush so that bytes still held in the
            # writer's buffer reach the file first; otherwise small objects can
            # read back empty or truncated.
            temp_file_obj.flush()
            self._local_handle_cache[file_handle.gcs_path] = temp_name

        return file_handle

    @contextmanager
    def read(self, file_handle, mode="rb"):
        """Yield an open file object over the local copy of ``file_handle``.

        Args:
            file_handle (GCSFileHandle): The remote file to read.
            mode (str): ``"r"`` for text or ``"rb"`` for bytes.
        """
        check.inst_param(file_handle, "file_handle", GCSFileHandle)
        check.str_param(mode, "mode")
        check.param_invariant(mode in {"r", "rb"}, "mode")

        self._download_if_not_cached(file_handle)

        with open(self._get_local_path(file_handle), mode) as file_obj:
            yield file_obj

    def _file_handle_cached(self, file_handle):
        """Return True if ``file_handle`` already has a local temp copy."""
        return file_handle.gcs_path in self._local_handle_cache

    def _get_local_path(self, file_handle):
        """Return the cached local path for ``file_handle`` (KeyError if absent)."""
        return self._local_handle_cache[file_handle.gcs_path]

    def read_data(self, file_handle):
        """Return the full contents of ``file_handle`` as bytes."""
        with self.read(file_handle, mode="rb") as file_obj:
            return file_obj.read()

    def write_data(self, data, ext=None):
        """Upload a ``bytes`` payload and return its ``GCSFileHandle``."""
        check.inst_param(data, "data", bytes)
        return self.write(io.BytesIO(data), mode="wb", ext=ext)

    def write(self, file_obj, mode="wb", ext=None):  # pylint: disable=unused-argument
        """Upload ``file_obj`` under a fresh UUID key and return its handle.

        Args:
            file_obj: A file-like object whose contents are uploaded.
            mode (str): Accepted for interface compatibility; not used.
            ext (Optional[str]): Extension appended to the key as ``.<ext>``.
        """
        check_file_like_obj(file_obj)
        gcs_key = self.get_full_key(str(uuid.uuid4()) + (("." + ext) if ext is not None else ""))
        bucket_obj = self._client.bucket(self._gcs_bucket)
        bucket_obj.blob(gcs_key).upload_from_file(file_obj)
        return GCSFileHandle(self._gcs_bucket, gcs_key)

    def get_full_key(self, file_key):
        """Return ``file_key`` prefixed with this manager's base key."""
        return "{base_key}/{file_key}".format(base_key=self._gcs_base_key, file_key=file_key)

    def delete_local_temp(self):
        """Release the temp files created for downloads via the TempfileManager."""
        self._temp_file_manager.close()
コード例 #2
0
class ADLS2FileManager(FileManager):
    """FileManager backed by an Azure Data Lake Storage Gen2 file system.

    Uploads are written to ``<prefix>/<uuid4>[.<ext>]`` inside ``file_system``.
    Downloads are materialized into local temp files, and the resulting paths
    are memoized by each handle's ``adls2_path``.
    """

    def __init__(self, adls2_client, file_system, prefix):
        self._client = adls2_client
        self._file_system = check.str_param(file_system, 'file_system')
        self._prefix = check.str_param(prefix, 'prefix')
        # adls2_path -> local temp file containing the downloaded bytes.
        self._local_handle_cache = {}
        self._temp_file_manager = TempfileManager()

    def copy_handle_to_local_temp(self, file_handle):
        """Ensure a local copy of ``file_handle`` exists and return its path."""
        self._download_if_not_cached(file_handle)
        local_path = self._get_local_path(file_handle)
        return local_path

    def _download_if_not_cached(self, file_handle):
        """Fetch the remote file into a fresh temp file on first access only."""
        if self._file_handle_cached(file_handle):
            return file_handle

        scratch = self._temp_file_manager.tempfile()
        scratch_path = scratch.name
        remote = self._client.get_file_client(
            file_system=file_handle.file_system,
            file_path=file_handle.key,
        )
        stream = remote.download_file()
        with open(scratch_path, 'wb') as sink:
            stream.readinto(sink)
        self._local_handle_cache[file_handle.adls2_path] = scratch_path
        return file_handle

    @contextmanager
    def read(self, file_handle, mode='rb'):
        """Yield an open file object over the local copy of ``file_handle``.

        ``mode`` must be ``'r'`` (text) or ``'rb'`` (binary).
        """
        check.inst_param(file_handle, 'file_handle', ADLS2FileHandle)
        check.str_param(mode, 'mode')
        check.param_invariant(mode in {'r', 'rb'}, 'mode')

        self._download_if_not_cached(file_handle)
        local_path = self._get_local_path(file_handle)
        with open(local_path, mode) as handle:
            yield handle

    def _file_handle_cached(self, file_handle):
        """Return True when ``file_handle`` already has a local temp copy."""
        return file_handle.adls2_path in self._local_handle_cache

    def _get_local_path(self, file_handle):
        """Look up the cached local path for ``file_handle`` (KeyError if absent)."""
        return self._local_handle_cache[file_handle.adls2_path]

    def read_data(self, file_handle):
        """Return the complete contents of ``file_handle`` as bytes."""
        with self.read(file_handle, mode='rb') as handle:
            contents = handle.read()
        return contents

    def write_data(self, data, ext=None):
        """Upload an in-memory ``bytes`` payload and return its handle."""
        check.inst_param(data, 'data', bytes)
        buffer = io.BytesIO(data)
        return self.write(buffer, mode='wb', ext=ext)

    def write(self, file_obj, mode='wb', ext=None):  # pylint: disable=unused-argument
        """Upload ``file_obj`` under a fresh UUID key; return an ``ADLS2FileHandle``.

        ``mode`` is accepted for interface compatibility and is not used.
        Existing files at the generated path are overwritten.
        """
        check_file_like_obj(file_obj)
        file_name = str(uuid.uuid4())
        if ext is not None:
            file_name += '.' + ext
        adls2_key = self.get_full_key(file_name)
        target = self._client.get_file_client(
            file_system=self._file_system,
            file_path=adls2_key,
        )
        target.upload_data(file_obj, overwrite=True)
        return ADLS2FileHandle(self._client.account_name, self._file_system, adls2_key)

    def get_full_key(self, file_key):
        """Prefix ``file_key`` with this manager's configured ``prefix``."""
        return '{base_key}/{file_key}'.format(file_key=file_key, base_key=self._prefix)

    def delete_local_temp(self):
        """Release the download temp files by closing the TempfileManager."""
        self._temp_file_manager.close()
コード例 #3
0
class GCSFileManager(FileManager):
    """File manager that stores files in a Google Cloud Storage bucket.

    Files written through this manager are uploaded to ``gcs_bucket`` under
    ``<gcs_base_key>/<uuid4>[.<ext>]``. Reading a handle downloads the object
    into a local temporary file obtained from a ``TempfileManager``; the local
    path is cached per ``gcs_path`` so repeated reads do not re-download.
    """

    def __init__(self, client, gcs_bucket, gcs_base_key):
        self._client = check.inst_param(client, 'client',
                                        storage.client.Client)
        self._gcs_bucket = check.str_param(gcs_bucket, 'gcs_bucket')
        self._gcs_base_key = check.str_param(gcs_base_key, 'gcs_base_key')
        # gcs_path -> local temp file path holding the downloaded contents.
        self._local_handle_cache = {}
        self._temp_file_manager = TempfileManager()

    def copy_handle_to_local_temp(self, file_handle):
        """Download ``file_handle`` if needed and return the local temp path."""
        self._download_if_not_cached(file_handle)
        return self._get_local_path(file_handle)

    def _download_if_not_cached(self, file_handle):
        """Download the blob behind ``file_handle`` unless already cached.

        Returns the same ``file_handle``.
        """
        if not self._file_handle_cached(file_handle):
            # instigate download
            temp_file_obj = self._temp_file_manager.tempfile()
            temp_name = temp_file_obj.name
            # Client.bucket() builds a local Bucket reference without the extra
            # metadata request (and bucket-get permission) get_bucket() needs.
            bucket_obj = self._client.bucket(file_handle.gcs_bucket)
            bucket_obj.blob(
                file_handle.gcs_key).download_to_file(temp_file_obj)
            # The temp file object stays open and readers re-open it by name;
            # flush so buffered bytes are on disk before anyone reads them.
            temp_file_obj.flush()
            self._local_handle_cache[file_handle.gcs_path] = temp_name

        return file_handle

    @contextmanager
    def read(self, file_handle, mode='rb'):
        """Yield an open file object over the local copy of ``file_handle``.

        Args:
            file_handle (GCSFileHandle): The remote file to read.
            mode (str): ``'r'`` for text or ``'rb'`` for bytes.
        """
        check.inst_param(file_handle, 'file_handle', GCSFileHandle)
        check.str_param(mode, 'mode')
        check.param_invariant(mode in {'r', 'rb'}, 'mode')

        self._download_if_not_cached(file_handle)

        with open(self._get_local_path(file_handle), mode) as file_obj:
            yield file_obj

    def _file_handle_cached(self, file_handle):
        """Return True if ``file_handle`` already has a local temp copy."""
        return file_handle.gcs_path in self._local_handle_cache

    def _get_local_path(self, file_handle):
        """Return the cached local path for ``file_handle`` (KeyError if absent)."""
        return self._local_handle_cache[file_handle.gcs_path]

    def read_data(self, file_handle):
        """Return the full contents of ``file_handle`` as bytes."""
        with self.read(file_handle, mode='rb') as file_obj:
            return file_obj.read()

    def write_data(self, data, ext=None):
        """Upload a ``bytes`` payload and return its ``GCSFileHandle``.

        ``ext`` (optional) is appended to the generated key as ``.<ext>``.
        """
        check.inst_param(data, 'data', bytes)
        return self.write(io.BytesIO(data), mode='wb', ext=ext)

    def write(self, file_obj, mode='wb', ext=None):  # pylint: disable=unused-argument
        """Upload ``file_obj`` under a fresh UUID key and return its handle.

        Args:
            file_obj: A file-like object whose contents are uploaded.
            mode (str): Accepted for interface compatibility; not used.
            ext (Optional[str]): Extension appended to the key as ``.<ext>``.
        """
        check_file_like_obj(file_obj)
        gcs_key = self.get_full_key(
            str(uuid.uuid4()) + (('.' + ext) if ext is not None else ''))
        bucket_obj = self._client.bucket(self._gcs_bucket)
        bucket_obj.blob(gcs_key).upload_from_file(file_obj)
        return GCSFileHandle(self._gcs_bucket, gcs_key)

    def get_full_key(self, file_key):
        """Return ``file_key`` prefixed with this manager's base key."""
        return '{base_key}/{file_key}'.format(base_key=self._gcs_base_key,
                                              file_key=file_key)

    def delete_local_temp(self):
        """Release the temp files created for downloads via the TempfileManager."""
        self._temp_file_manager.close()
コード例 #4
0
ファイル: file_manager.py プロジェクト: zuik/dagster
class S3FileManager(FileManager):
    """FileManager that keeps its files in an S3 bucket.

    New files are stored at ``<s3_base_key>/<uuid4>[.<ext>]`` in ``s3_bucket``.
    Reading a handle downloads it once into a local temp file; the temp path is
    remembered per ``s3_path`` for later reads.
    """

    def __init__(self, s3_session, s3_bucket, s3_base_key):
        self._s3_session = s3_session
        self._s3_bucket = check.str_param(s3_bucket, "s3_bucket")
        self._s3_base_key = check.str_param(s3_base_key, "s3_base_key")
        # s3_path -> local temp file containing the downloaded object.
        self._local_handle_cache = {}
        self._temp_file_manager = TempfileManager()

    def copy_handle_to_local_temp(self, file_handle):
        """Make sure ``file_handle`` is available locally and return its path."""
        self._download_if_not_cached(file_handle)
        local_path = self._get_local_path(file_handle)
        return local_path

    def _download_if_not_cached(self, file_handle):
        """Download the S3 object into a new temp file unless already cached."""
        if self._file_handle_cached(file_handle):
            return file_handle

        scratch = self._temp_file_manager.tempfile()
        scratch_path = scratch.name
        self._s3_session.download_file(
            Bucket=file_handle.s3_bucket,
            Key=file_handle.s3_key,
            Filename=scratch_path,
        )
        self._local_handle_cache[file_handle.s3_path] = scratch_path
        return file_handle

    @contextmanager
    def read(self, file_handle, mode="rb"):
        """Yield an open file object over the local copy of ``file_handle``.

        ``mode`` must be ``"r"`` (text) or ``"rb"`` (binary).
        """
        check.inst_param(file_handle, "file_handle", S3FileHandle)
        check.str_param(mode, "mode")
        check.param_invariant(mode in {"r", "rb"}, "mode")

        self._download_if_not_cached(file_handle)
        local_path = self._get_local_path(file_handle)
        with open(local_path, mode) as handle:
            yield handle

    def _file_handle_cached(self, file_handle):
        """Return True when a local copy of ``file_handle`` already exists."""
        return file_handle.s3_path in self._local_handle_cache

    def _get_local_path(self, file_handle):
        """Look up the cached local path for ``file_handle`` (KeyError if absent)."""
        return self._local_handle_cache[file_handle.s3_path]

    def read_data(self, file_handle):
        """Return the complete contents of ``file_handle`` as bytes."""
        with self.read(file_handle, mode="rb") as handle:
            contents = handle.read()
        return contents

    def write_data(self, data, ext=None):
        """Upload an in-memory ``bytes`` payload and return its handle."""
        check.inst_param(data, "data", bytes)
        buffer = io.BytesIO(data)
        return self.write(buffer, mode="wb", ext=ext)

    def write(self, file_obj, mode="wb", ext=None):
        """Upload ``file_obj`` under a fresh UUID key and return an ``S3FileHandle``.

        ``mode`` is accepted for interface compatibility and is not used.
        """
        check_file_like_obj(file_obj)
        object_name = str(uuid.uuid4())
        if ext is not None:
            object_name += "." + ext
        s3_key = self.get_full_key(object_name)
        self._s3_session.put_object(
            Body=file_obj,
            Bucket=self._s3_bucket,
            Key=s3_key,
        )
        return S3FileHandle(self._s3_bucket, s3_key)

    def get_full_key(self, file_key):
        """Prefix ``file_key`` with this manager's ``s3_base_key``."""
        return "{base_key}/{file_key}".format(file_key=file_key, base_key=self._s3_base_key)

    def delete_local_temp(self):
        """Release the download temp files by closing the TempfileManager."""
        self._temp_file_manager.close()
コード例 #5
0
class S3FileManager(FileManager):
    """File manager that stores files in an S3 bucket.

    Files written through this manager are uploaded to ``s3_bucket`` under
    ``<s3_base_key>/<uuid4>[.<ext>]``. Reading a handle downloads the object
    into a local temporary file obtained from a ``TempfileManager``; the local
    path is cached per ``s3_path`` so repeated reads do not re-download.
    """

    def __init__(self, s3_session, s3_bucket, s3_base_key):
        self._s3_session = s3_session
        self._s3_bucket = check.str_param(s3_bucket, 's3_bucket')
        self._s3_base_key = check.str_param(s3_base_key, 's3_base_key')
        # s3_path -> local temp file path holding the downloaded contents.
        self._local_handle_cache = {}
        self._temp_file_manager = TempfileManager()

    def copy_handle_to_local_temp(self, file_handle):
        """Download ``file_handle`` if needed and return the local temp path."""
        self._download_if_not_cached(file_handle)
        return self._get_local_path(file_handle)

    def _download_if_not_cached(self, file_handle):
        """Download the object behind ``file_handle`` unless already cached.

        Returns the same ``file_handle``.
        """
        if not self._file_handle_cached(file_handle):
            # instigate download
            temp_file_obj = self._temp_file_manager.tempfile()
            temp_name = temp_file_obj.name
            self._s3_session.download_file(Bucket=file_handle.s3_bucket,
                                           Key=file_handle.s3_key,
                                           Filename=temp_name)
            self._local_handle_cache[file_handle.s3_path] = temp_name

        return file_handle

    @contextmanager
    def read(self, file_handle, mode='rb'):
        """Yield an open file object over the local copy of ``file_handle``.

        Args:
            file_handle (S3FileHandle): The remote file to read.
            mode (str): ``'r'`` for text or ``'rb'`` for bytes.
        """
        check.inst_param(file_handle, 'file_handle', S3FileHandle)
        check.str_param(mode, 'mode')
        check.param_invariant(mode in {'r', 'rb'}, 'mode')

        self._download_if_not_cached(file_handle)

        with open(self._get_local_path(file_handle), mode) as file_obj:
            yield file_obj

    def _file_handle_cached(self, file_handle):
        """Return True if ``file_handle`` already has a local temp copy."""
        return file_handle.s3_path in self._local_handle_cache

    def _get_local_path(self, file_handle):
        """Return the cached local path for ``file_handle`` (KeyError if absent)."""
        return self._local_handle_cache[file_handle.s3_path]

    def read_data(self, file_handle):
        """Return the full contents of ``file_handle`` as bytes."""
        with self.read(file_handle, mode='rb') as file_obj:
            return file_obj.read()

    def write_data(self, data, ext=None):
        """Upload a ``bytes`` payload and return its ``S3FileHandle``.

        ``ext`` (optional) is appended to the generated key as ``.<ext>``.
        """
        check.inst_param(data, 'data', bytes)
        return self.write(io.BytesIO(data), mode='wb', ext=ext)

    def write(self, file_obj, mode='wb', ext=None):  # pylint: disable=unused-argument
        """Upload ``file_obj`` under a fresh UUID key and return its handle.

        Args:
            file_obj: A file-like object passed as the ``put_object`` body.
            mode (str): Accepted for interface compatibility; not used.
            ext (Optional[str]): Extension appended to the key as ``.<ext>``.
        """
        check_file_like_obj(file_obj)
        s3_key = self.get_full_key(
            str(uuid.uuid4()) + (('.' + ext) if ext is not None else ''))
        self._s3_session.put_object(Body=file_obj,
                                    Bucket=self._s3_bucket,
                                    Key=s3_key)
        return S3FileHandle(self._s3_bucket, s3_key)

    def get_full_key(self, file_key):
        """Return ``file_key`` prefixed with this manager's base key."""
        return '{base_key}/{file_key}'.format(base_key=self._s3_base_key,
                                              file_key=file_key)

    def delete_local_temp(self):
        """Release the temp files created for downloads via the TempfileManager."""
        self._temp_file_manager.close()