예제 #1
0
class GCSClient:
    """Read gokart task logs and their parameters from a GCS workspace.

    The underlying client comes from ``GCSConfig().get_gcs_client()``
    (must set $GCS_CREDENTIAL).
    """

    def __init__(self,
                 workspace_directory: str = '',
                 task_filters: List[str] = [],
                 tqdm_disable: bool = False):
        """must set $GCS_CREDENTIAL

        Args:
            workspace_directory: workspace root; logs are listed under
                ``<workspace_directory>/log/task_log``.
            task_filters: substrings; when non-empty, only log files whose name
                contains at least one of them are loaded.
            tqdm_disable: disable the tqdm progress bar in ``get_tasks``.
        """
        self.workspace_directory = workspace_directory
        # The ``[]`` default is kept for interface compatibility, but it is one
        # list shared by every call; copy so it is never stored on (and mutated
        # through) an instance. Falsy values such as None become [].
        self.task_filters = list(task_filters) if task_filters else []
        self.tqdm_disable = tqdm_disable
        self.gcs_client = GCSConfig().get_gcs_client()

    def get_tasks(self) -> Dict[int, Dict[str, Any]]:
        """Load all task_log from GCS.

        Returns:
            A dict keyed by each file's position in the unfiltered listing
            (keys can have gaps when ``task_filters`` skips files). Each value
            holds ``task_name``, ``task_params``, ``task_log``,
            ``last_modified`` and ``task_hash``.
        """
        files = self._get_gcs_objects()
        tasks = {}
        for i, path in enumerate(tqdm(files, disable=self.tqdm_disable)):
            file_name = path.split('/')[-1]
            if self.task_filters and not any(f in file_name for f in self.task_filters):
                continue
            # The last '_'-separated part (minus its extension) is the hash;
            # everything before it is the task name, which may contain '_'.
            name_parts = file_name.split('_')
            meta = self._get_gcs_object_info(path)
            tasks[i] = {
                'task_name': '_'.join(name_parts[:-1]),
                # NOTE: unpickling can execute arbitrary code; only point this
                # client at workspaces whose contents you trust.
                'task_params': pickle.load(
                    self.gcs_client.download(path.replace('task_log', 'task_params'))),
                'task_log': pickle.load(self.gcs_client.download(path)),
                'last_modified': self._parse_gcs_timestamp(meta['updated']),
                'task_hash': name_parts[-1].split('.')[0],
            }
        return tasks

    @staticmethod
    def _parse_gcs_timestamp(updated: str) -> datetime:
        """Parse a GCS ``updated`` value such as ``2020-01-02T03:04:05.678Z``.

        Fractional seconds and a trailing ``Z`` are dropped, so values with or
        without milliseconds both parse; the result is a naive datetime.
        """
        return datetime.strptime(updated.split('.')[0].rstrip('Z'),
                                 '%Y-%m-%dT%H:%M:%S')

    def _get_gcs_objects(self) -> List[str]:
        """List the log files under ``<workspace_directory>/log/task_log``.

        Materialized into a list to match the annotation: ``listdir`` may hand
        back a one-shot iterator, and a list also lets tqdm show a total.
        """
        return list(self.gcs_client.listdir(
            os.path.join(self.workspace_directory, 'log/task_log')))

    def _get_gcs_object_info(self, x: str) -> Dict[str, str]:
        """Fetch the object's metadata via the underlying API client's ``objects().get``."""
        # relies on the luigi client's private path splitter
        bucket, obj = self.gcs_client._path_to_bucket_and_key(x)
        return self.gcs_client.client.objects().get(bucket=bucket,
                                                    object=obj).execute()

    def to_absolute_path(self, x: str) -> str:
        """Normalize ``x`` to a path relative to the workspace.

        Leading '.' and '/' characters are stripped, then the first path
        segment is dropped if it equals the workspace directory's last segment.
        """
        x = x.lstrip('.').lstrip('/')
        if self.workspace_directory.rstrip('/').split('/')[-1] == x.split('/')[0]:
            x = '/'.join(x.split('/')[1:])
        return x
예제 #2
0
 def __init__(self,
              workspace_directory: str = '',
              task_filters: List[str] = [],
              tqdm_disable: bool = False):
     """must set $GCS_CREDENTIAL

     Args:
         workspace_directory: stored as ``self.workspace_directory``.
         task_filters: substrings used to select log files; stored as a
             fresh list (empty when falsy).
         tqdm_disable: stored as ``self.tqdm_disable``.
     """
     self.workspace_directory = workspace_directory
     # The ``[]`` default is kept for interface compatibility, but it is one
     # list shared by every call; copy so it is never stored on (and mutated
     # through) an instance. Falsy values such as None become [].
     self.task_filters = list(task_filters) if task_filters else []
     self.tqdm_disable = tqdm_disable
     self.gcs_client = GCSConfig().get_gcs_client()
예제 #3
0
 def get_timestamp(path: str) -> datetime:
     """Return the last-modified timestamp of an ``s3://`` or ``gs://`` object.

     Raises:
         NotImplementedError: for any other path. It subclasses RuntimeError,
             which is what the previous bare ``raise`` produced when no
             exception was active, so existing handlers keep working.
     """
     if path.startswith('s3://'):
         return S3Config().get_s3_client().get_key(path).last_modified
     if path.startswith('gs://'):
         # for gcs object: read the metadata through the underlying API client
         # (original note: should PR to luigi). Fetch the client once.
         gcs_client = GCSConfig().get_gcs_client()
         bucket, obj = gcs_client._path_to_bucket_and_key(path)
         result = gcs_client.client.objects().get(bucket=bucket, object=obj).execute()
         # NOTE(review): this is the raw 'updated' metadata value (presumably an
         # RFC 3339 string), not a datetime as annotated — confirm what callers expect.
         return result['updated']
     # A bare ``raise`` here would re-raise whatever exception the caller is
     # currently handling (or a confusing RuntimeError if none); be explicit.
     raise NotImplementedError(f'unsupported object storage path: {path!r}')
예제 #4
0
 def __init__(self,
              workspace_directory: str = '',
              task_filters: List[str] = [],
              tqdm_disable: bool = False,
              use_cache: bool = True):
     """must set $GCS_CREDENTIAL

     Args:
         workspace_directory: stored as ``self.workspace_directory`` and also
             passed to ``LocalCache``.
         task_filters: substrings used to select log files; stored as a
             fresh list (empty when falsy).
         tqdm_disable: stored as ``self.tqdm_disable``.
         use_cache: passed to ``LocalCache`` and stored as ``self.use_cache``.
     """
     self.workspace_directory = workspace_directory
     # The ``[]`` default is kept for interface compatibility, but it is one
     # list shared by every call; copy so it is never stored on (and mutated
     # through) an instance. Falsy values such as None become [].
     self.task_filters = list(task_filters) if task_filters else []
     self.tqdm_disable = tqdm_disable
     self.gcs_client = GCSConfig().get_gcs_client()
     self.local_cache = LocalCache(workspace_directory, use_cache)
     self.use_cache = use_cache
예제 #5
0
 def exists(path: str) -> bool:
     """Return whether the ``s3://`` or ``gs://`` object at ``path`` exists.

     Raises:
         NotImplementedError: for any other path. It subclasses RuntimeError,
             which is what the previous bare ``raise`` produced when no
             exception was active, so existing handlers keep working.
     """
     if path.startswith('s3://'):
         return S3Config().get_s3_client().exists(path)
     if path.startswith('gs://'):
         return GCSConfig().get_gcs_client().exists(path)
     # A bare ``raise`` here would re-raise whatever exception the caller is
     # currently handling (or a confusing RuntimeError if none); be explicit.
     raise NotImplementedError(f'unsupported object storage path: {path!r}')
예제 #6
0
 def get_object_storage_target(path: str, format: Format) -> luigi.Target:
     """Build a luigi target for an ``s3://`` or ``gs://`` path with the configured client.

     Raises:
         NotImplementedError: for any other path. It subclasses RuntimeError,
             which is what the previous bare ``raise`` produced when no
             exception was active, so existing handlers keep working.
     """
     if path.startswith('s3://'):
         return luigi.contrib.s3.S3Target(path, client=S3Config().get_s3_client(), format=format)
     if path.startswith('gs://'):
         return luigi.contrib.gcs.GCSTarget(path, client=GCSConfig().get_gcs_client(), format=format)
     # A bare ``raise`` here would re-raise whatever exception the caller is
     # currently handling (or a confusing RuntimeError if none); be explicit.
     raise NotImplementedError(f'unsupported object storage path: {path!r}')
예제 #7
0
 def test_get_gcs_client_with_json(self):
     """A JSON credential in the env var is parsed and passed to ``from_service_account_info``."""
     mock = MagicMock()
     json_str = '{"test": 1}'
     # patch.dict restores os.environ on exit so 'env_name' does not leak into other tests
     with patch.dict(os.environ, {'env_name': json_str}):
         with patch('luigi.contrib.gcs.GCSClient'):
             with patch('google.oauth2.service_account.Credentials.from_service_account_info', mock):
                 GCSConfig(gcs_credential_name='env_name')._get_gcs_client()
                 self.assertEqual(dict(test=1), mock.call_args[0][0])
예제 #8
0
class GCSZipClient(ZipClient):
    """Zip client that keeps an archive of a local temporary directory on GCS.

    The archive format is taken from the extension of ``file_path``; the
    client comes from ``GCSConfig().get_gcs_client()``.
    """

    def __init__(self, file_path: str, temporary_directory: str) -> None:
        self._file_path = file_path
        self._temporary_directory = temporary_directory
        self._client = GCSConfig().get_gcs_client()

    def exists(self) -> bool:
        """Return whether the archive object exists on GCS."""
        return self._client.exists(self._file_path)

    def make_archive(self) -> None:
        """Archive the temporary directory locally, then upload it to ``file_path``."""
        extension = os.path.splitext(self._file_path)[1]
        # Use the same (trailing-slash-stripped) base name that
        # _temporary_file_path() uses, so the file shutil writes is exactly the
        # one uploaded below. With a raw 'dir/' base name, newer shutil versions
        # write 'dir/.zip' while we would try to upload 'dir.zip'.
        shutil.make_archive(base_name=self._temporary_base_name(),
                            format=extension[1:],
                            root_dir=self._temporary_directory)
        self._client.put(self._temporary_file_path(), self._file_path)

    def unpack_archive(self) -> None:
        """Download the archive next to the temporary directory and extract it into it."""
        os.makedirs(self._temporary_directory, exist_ok=True)
        self._client.get(self._file_path, self._temporary_file_path())
        _unzip_file(filename=self._temporary_file_path(),
                    extract_dir=self._temporary_directory)

    def remove(self) -> None:
        """Delete the archive object from GCS."""
        self._client.remove(self._file_path)

    @property
    def path(self) -> str:
        return self._file_path

    def _temporary_base_name(self) -> str:
        """Temporary directory path with one trailing '/' removed."""
        base_name = self._temporary_directory
        if base_name.endswith('/'):
            base_name = base_name[:-1]
        return base_name

    def _temporary_file_path(self) -> str:
        """Local archive path: the temporary base name plus ``file_path``'s extension."""
        return self._temporary_base_name() + os.path.splitext(self._file_path)[1]
예제 #9
0
 def test_get_gcs_client_with_file_path(self):
     """A credential env var holding a file path is loaded via ``from_service_account_file``."""
     mock = MagicMock()
     file_path = 'test.json'
     # patch.dict restores os.environ on exit so 'env_name' does not leak into other tests
     with patch.dict(os.environ, {'env_name': file_path}):
         with patch('luigi.contrib.gcs.GCSClient'):
             with patch('google.oauth2.service_account.Credentials.from_service_account_file', mock):
                 # pretend the credential file exists without creating it
                 with patch('os.path.isfile', return_value=True):
                     GCSConfig(gcs_credential_name='env_name')._get_gcs_client()
                     self.assertEqual(file_path, mock.call_args[0][0])
예제 #10
0
 def test_get_gcs_client_without_gcs_credential_name(self):
     """With an empty credential env var, GCSClient gets no credentials and the cached descriptor."""
     import tempfile

     mock = MagicMock()
     # Write the discovery cache into a temp dir so the test leaves nothing in the working directory.
     with tempfile.TemporaryDirectory() as tmp_dir:
         discover_path = os.path.join(tmp_dir, 'discover_cache.json')
         with open(discover_path, 'w') as f:
             f.write('{}')
         # patch.dict restores os.environ on exit so these variables do not leak into other tests
         with patch.dict(os.environ, {'env_name': '', 'discover_path': discover_path}):
             with patch('luigi.contrib.gcs.GCSClient', mock):
                 with patch('fcntl.flock'):
                     GCSConfig(gcs_credential_name='env_name', discover_cache_local_path=discover_path).get_gcs_client()
                     self.assertEqual(dict(oauth_credentials=None, descriptor='{}'), mock.call_args[1])
예제 #11
0
 def test_get_gcs_client_with_json(self):
     """A JSON credential in the env var is parsed and passed to ``from_service_account_info``."""
     import tempfile

     mock = MagicMock()
     json_str = '{"test": 1}'
     # Write the discovery cache into a temp dir so the test leaves nothing in the working directory.
     with tempfile.TemporaryDirectory() as tmp_dir:
         discover_path = os.path.join(tmp_dir, 'discover_cache.json')
         with open(discover_path, 'w') as f:
             f.write('{}')
         # patch.dict restores os.environ on exit so these variables do not leak into other tests
         with patch.dict(os.environ, {'env_name': json_str, 'discover_path': discover_path}):
             with patch('luigi.contrib.gcs.GCSClient'):
                 with patch('google.oauth2.service_account.Credentials.from_service_account_info', mock):
                     GCSConfig(gcs_credential_name='env_name', discover_cache_local_path=discover_path).get_gcs_client()
                     self.assertEqual(dict(test=1), mock.call_args[0][0])
예제 #12
0
 def test_get_gcs_client_with_file_path(self):
     """A credential env var holding a file path is loaded via ``from_service_account_file``."""
     import tempfile

     mock = MagicMock()
     file_path = 'test.json'
     # Write the discovery cache into a temp dir so the test leaves nothing in the working directory.
     with tempfile.TemporaryDirectory() as tmp_dir:
         discover_path = os.path.join(tmp_dir, 'discover_cache.json')
         with open(discover_path, 'w') as f:
             f.write('{}')
         # patch.dict restores os.environ on exit so these variables do not leak into other tests
         with patch.dict(os.environ, {'env_name': file_path, 'discover_path': discover_path}):
             with patch('luigi.contrib.gcs.GCSClient'):
                 with patch('google.oauth2.service_account.Credentials.from_service_account_file', mock):
                     # pretend the credential file exists without creating it
                     with patch('os.path.isfile', return_value=True):
                         GCSConfig(gcs_credential_name='env_name', discover_cache_local_path=discover_path).get_gcs_client()
                         self.assertEqual(file_path, mock.call_args[0][0])
예제 #13
0
 def test_get_gcs_client_without_gcs_credential_name(self):
     """With an empty credential env var, GCSClient is built with ``oauth_credentials=None``."""
     mock = MagicMock()
     # patch.dict restores os.environ on exit so 'env_name' does not leak into other tests
     with patch.dict(os.environ, {'env_name': ''}):
         with patch('luigi.contrib.gcs.GCSClient', mock):
             GCSConfig(gcs_credential_name='env_name')._get_gcs_client()
             self.assertEqual(dict(oauth_credentials=None), mock.call_args[1])
예제 #14
0
 def __init__(self, file_path: str, temporary_directory: str) -> None:
     """Store the archive path and temporary directory, then obtain a GCS client.

     Args:
         file_path: stored as ``self._file_path``.
         temporary_directory: stored as ``self._temporary_directory``.
     """
     self._file_path, self._temporary_directory = file_path, temporary_directory
     # the client comes from the shared GCS configuration
     self._client = GCSConfig().get_gcs_client()
예제 #15
0
class GCSClient:
    """Read gokart task logs and their parameters from a GCS workspace.

    Loaded records can be cached locally through ``LocalCache``. The
    underlying client comes from ``GCSConfig().get_gcs_client()``
    (must set $GCS_CREDENTIAL).
    """

    def __init__(self,
                 workspace_directory: str = '',
                 task_filters: List[str] = [],
                 tqdm_disable: bool = False,
                 use_cache: bool = True):
        """must set $GCS_CREDENTIAL

        Args:
            workspace_directory: workspace root; logs are listed under
                ``<workspace_directory>/log/task_log``.
            task_filters: substrings; when non-empty, only log files whose name
                contains at least one of them are loaded.
            tqdm_disable: disable the tqdm progress bar in ``get_tasks``.
            use_cache: read from / write to ``LocalCache`` in ``get_tasks``.
        """
        self.workspace_directory = workspace_directory
        # The ``[]`` default is kept for interface compatibility, but it is one
        # list shared by every call; copy so it is never stored on (and mutated
        # through) an instance. Falsy values such as None become [].
        self.task_filters = list(task_filters) if task_filters else []
        self.tqdm_disable = tqdm_disable
        self.gcs_client = GCSConfig().get_gcs_client()
        self.local_cache = LocalCache(workspace_directory, use_cache)
        self.use_cache = use_cache

    def get_tasks(self) -> List[Dict[str, Any]]:
        """Load all task_log from GCS, reusing the local cache when enabled.

        Files matching none of ``task_filters`` are skipped. Files that fail
        to load are skipped as well (best effort). If fewer logs were loaded
        than files targeted, a single ``warnings.warn`` summary is emitted.

        Returns:
            One dict per loaded log with ``task_name``, ``task_params``,
            ``task_log``, ``last_modified`` and ``task_hash``.
        """
        files = self._get_gcs_objects()
        tasks_list = list()
        # Count only files that pass the filter: those are the logs we expect
        # to load. (The old check re-read ``files`` after tqdm had exhausted
        # the listing iterator, so it always compared against 0.)
        target_count = 0
        for x in tqdm(files, disable=self.tqdm_disable):
            file_name = x.split('/')[-1]
            if self.task_filters and not any(f in file_name for f in self.task_filters):
                continue
            target_count += 1
            # The last '_'-separated part (minus its extension) is the hash;
            # everything before it is the task name, which may contain '_'.
            name_parts = file_name.split('_')

            if self.use_cache:
                cache = self.local_cache.get(x)
                if cache:
                    tasks_list.append(cache)
                    continue

            try:
                meta = self._get_gcs_object_info(x)
                params = {
                    'task_name': '_'.join(name_parts[:-1]),
                    # NOTE: unpickling can execute arbitrary code; only point
                    # this client at workspaces whose contents you trust.
                    'task_params': pickle.load(
                        self.gcs_client.download(x.replace('task_log', 'task_params'))),
                    'task_log': pickle.load(self.gcs_client.download(x)),
                    'last_modified': self._parse_gcs_timestamp(meta['updated']),
                    'task_hash': name_parts[-1].split('.')[0],
                }
                tasks_list.append(params)
                if self.use_cache:
                    self.local_cache.dump(x, params)
            except Exception:
                # Best effort: unreadable logs are skipped and reflected in the
                # summary warning below instead of aborting the whole load.
                continue

        if len(tasks_list) != target_count:
            warnings.warn(
                f'[NOT FOUND LOGS] target file: {target_count}, found log file: {len(tasks_list)}'
            )

        return tasks_list

    @staticmethod
    def _parse_gcs_timestamp(updated: str) -> datetime:
        """Parse a GCS ``updated`` value such as ``2020-01-02T03:04:05.678Z``.

        Fractional seconds and a trailing ``Z`` are dropped, so values with or
        without milliseconds both parse; the result is a naive datetime.
        """
        return datetime.strptime(updated.split('.')[0].rstrip('Z'),
                                 '%Y-%m-%dT%H:%M:%S')

    def _get_gcs_objects(self) -> List[str]:
        """List the log files under ``<workspace_directory>/log/task_log``.

        Materialized into a list to match the annotation: ``listdir`` may hand
        back a one-shot iterator, and a list also lets tqdm show a total.
        """
        return list(self.gcs_client.listdir(
            os.path.join(self.workspace_directory, 'log/task_log')))

    def _get_gcs_object_info(self, x: str) -> Dict[str, str]:
        """Fetch the object's metadata via the underlying API client's ``objects().get``."""
        # relies on the luigi client's private path splitter
        bucket, obj = self.gcs_client._path_to_bucket_and_key(x)
        return self.gcs_client.client.objects().get(bucket=bucket,
                                                    object=obj).execute()

    def to_absolute_path(self, x: str) -> str:
        """Normalize ``x`` to a path relative to the workspace.

        Leading '.' and '/' characters are stripped, then the first path
        segment is dropped if it equals the workspace directory's last segment.
        """
        x = x.lstrip('.').lstrip('/')
        if self.workspace_directory.rstrip('/').split('/')[-1] == x.split('/')[0]:
            x = '/'.join(x.split('/')[1:])
        return x