Example #1
def __init__(self,
             workspace_directory: str = '',
             task_filters: List[str] = [],
             tqdm_disable: bool = False,
             use_cache: bool = True):
    self.workspace_directory = os.path.abspath(workspace_directory)
    self.task_filters = task_filters
    self.tqdm_disable = tqdm_disable
    self.local_cache = LocalCache(workspace_directory, use_cache)
    self.use_cache = use_cache
Example #2
def __init__(self, workspace_directory: str = '', task_filters: List[str] = [], tqdm_disable: bool = False, use_cache: bool = True):
    self.workspace_directory = workspace_directory
    self.task_filters = task_filters
    self.tqdm_disable = tqdm_disable
    self.bucket_name = workspace_directory.replace('s3://', '').split('/')[0]
    self.prefix = '/'.join(workspace_directory.replace('s3://', '').split('/')[1:])
    self.resource = boto3.resource('s3')
    self.s3client = Session().client('s3')
    self.local_cache = LocalCache(workspace_directory, use_cache)
    self.use_cache = use_cache
Example #3
def __init__(self,
             workspace_directory: str = '',
             task_filters: List[str] = [],
             tqdm_disable: bool = False,
             use_cache: bool = True):
    """Requires the $GCS_CREDENTIAL environment variable to be set."""
    self.workspace_directory = workspace_directory
    self.task_filters = task_filters
    self.tqdm_disable = tqdm_disable
    self.gcs_client = GCSConfig().get_gcs_client()
    self.local_cache = LocalCache(workspace_directory, use_cache)
    self.use_cache = use_cache
Example #4
class TestLocalCache(unittest.TestCase):
    def setUp(self):
        self.base_path = './resources'
        self.local_cache = LocalCache(self.base_path, True)

    def test_init(self):
        self.assertTrue(os.path.exists('./.thunderbolt'))

    def test_dump_and_get(self):
        target = {'foo': 'bar'}
        self.local_cache.dump('test.pkl', target)
        output = self.local_cache.get('test.pkl')
        self.assertDictEqual(target, output)

    def test_convert_file_path(self):
        output = self.local_cache._convert_file_path('test.pkl')
        target = Path(os.path.join(os.getcwd(), '.thunderbolt', self.base_path.split('/')[-1], 'test.pkl'))
        self.assertEqual(target, output)

    def tearDown(self):
        self.local_cache.clear()
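
The LocalCache class exercised by these tests is not shown in the snippets above. A minimal sketch that would satisfy them is given below; the directory layout and method behavior are inferred from the assertions, so treat it as an illustration rather than the library's actual implementation.

import os
import pickle
import shutil
from pathlib import Path
from typing import Any


class LocalCache:
    """Pickle-on-disk cache sketch; layout inferred from the tests above."""

    def __init__(self, workspace_directory: str = '', use_cache: bool = True):
        # Cache files live under ./.thunderbolt/<last component of the workspace path>.
        self.cache_dir = Path(os.getcwd()) / '.thunderbolt' / workspace_directory.rstrip('/').split('/')[-1]
        if use_cache:
            os.makedirs(self.cache_dir, exist_ok=True)

    def _convert_file_path(self, file_path: str) -> Path:
        # Map a source file path to its location inside the cache directory.
        return self.cache_dir / file_path.split('/')[-1]

    def get(self, file_path: str) -> Any:
        path = self._convert_file_path(file_path)
        if not path.exists():
            return None
        with open(path, 'rb') as f:
            return pickle.load(f)

    def dump(self, file_path: str, obj: Any) -> None:
        path = self._convert_file_path(file_path)
        os.makedirs(path.parent, exist_ok=True)
        with open(path, 'wb') as f:
            pickle.dump(obj, f)

    def clear(self) -> None:
        # Used in tearDown: remove the whole cache directory.
        shutil.rmtree(self.cache_dir, ignore_errors=True)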
Example #5
class LocalDirectoryClient:
    def __init__(self,
                 workspace_directory: str = '',
                 task_filters: List[str] = [],
                 tqdm_disable: bool = False,
                 use_cache: bool = True):
        self.workspace_directory = os.path.abspath(workspace_directory)
        self.task_filters = task_filters
        self.tqdm_disable = tqdm_disable
        self.local_cache = LocalCache(workspace_directory, use_cache)
        self.use_cache = use_cache

    def get_tasks(self) -> List[Dict[str, Any]]:
        """Load all task_log from workspace_directory."""
        files = {
            str(path)
            for path in Path(
                os.path.join(self.workspace_directory, 'log/task_log')).rglob(
                    '*')
        }
        tasks_list = list()
        for x in tqdm(files, disable=self.tqdm_disable):
            n = x.split('/')[-1]
            if self.task_filters and not [
                    f for f in self.task_filters if f in n
            ]:
                continue
            n = n.split('_')

            if self.use_cache:
                cache = self.local_cache.get(x)
                if cache:
                    tasks_list.append(cache)
                    continue

            try:
                modified = datetime.fromtimestamp(os.stat(x).st_mtime)
                with open(x, 'rb') as f:
                    task_log = pickle.load(f)
                with open(x.replace('task_log', 'task_params'), 'rb') as f:
                    task_params = pickle.load(f)

                params = {
                    'task_name': '_'.join(n[:-1]),
                    'task_params': task_params,
                    'task_log': task_log,
                    'last_modified': modified,
                    'task_hash': n[-1].split('.')[0],
                }
                tasks_list.append(params)
                if self.use_cache:
                    self.local_cache.dump(x, params)
            except Exception:
                continue

        if len(tasks_list) != len(files):
            warnings.warn(
                f'[NOT FOUND LOGS] target file: {len(files)}, found log file: {len(tasks_list)}'
            )

        return tasks_list

    def to_absolute_path(self, x: str) -> str:
        """get file path"""
        x = x.lstrip('.').lstrip('/')
        if self.workspace_directory.rstrip('/').split('/')[-1] == x.split(
                '/')[0]:
            x = '/'.join(x.split('/')[1:])
        x = os.path.join(self.workspace_directory, x)
        return os.path.abspath(x)
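
For orientation, a short usage sketch follows; the workspace path and filter value are placeholders, and LocalDirectoryClient is assumed to be importable from the surrounding module.

# Point the client at a gokart/luigi workspace that contains log/task_log and log/task_params.
client = LocalDirectoryClient(workspace_directory='./resources',  # placeholder path
                              task_filters=['MyTask'],            # keep only file names containing 'MyTask'
                              use_cache=True)
for task in client.get_tasks():
    # Each entry carries the keys built in get_tasks() above.
    print(task['task_name'], task['task_hash'], task['last_modified'])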
Example #6
def setUp(self):
    self.base_path = './resources'
    self.local_cache = LocalCache(self.base_path, True)
Example #7
class S3Client:
    def __init__(self, workspace_directory: str = '', task_filters: List[str] = [], tqdm_disable: bool = False, use_cache: bool = True):
        self.workspace_directory = workspace_directory
        self.task_filters = task_filters
        self.tqdm_disable = tqdm_disable
        self.bucket_name = workspace_directory.replace('s3://', '').split('/')[0]
        self.prefix = '/'.join(workspace_directory.replace('s3://', '').split('/')[1:])
        self.resource = boto3.resource('s3')
        self.s3client = Session().client('s3')
        self.local_cache = LocalCache(workspace_directory, use_cache)
        self.use_cache = use_cache

    def get_tasks(self) -> List[Dict[str, Any]]:
        """Load all task_log from S3"""
        files = self._get_s3_keys([], '')
        tasks_list = list()
        for x in tqdm(files, disable=self.tqdm_disable):
            n = x['Key'].split('/')[-1]
            if self.task_filters and not [x for x in self.task_filters if x in n]:
                continue
            n = n.split('_')

            if self.use_cache:
                cache = self.local_cache.get(x['Key'])
                if cache:
                    tasks_list.append(cache)
                    continue

            try:
                params = {
                    'task_name': '_'.join(n[:-1]),
                    'task_params': pickle.loads(self.resource.Object(self.bucket_name, x['Key'].replace('task_log', 'task_params')).get()['Body'].read()),
                    'task_log': pickle.loads(self.resource.Object(self.bucket_name, x['Key']).get()['Body'].read()),
                    'last_modified': x['LastModified'],
                    'task_hash': n[-1].split('.')[0]
                }
                tasks_list.append(params)
                if self.use_cache:
                    self.local_cache.dump(x['Key'], params)
            except Exception:
                continue

        if len(tasks_list) != len(files):
            warnings.warn(f'[NOT FOUND LOGS] target file: {len(files)}, found log file: {len(tasks_list)}')

        return tasks_list

    def _get_s3_keys(self, keys: List[Dict[str, Any]] = [], marker: str = '') -> List[Dict[str, Any]]:
        """Recursively get Key from S3.

        Using s3client api by boto module.
        Reference: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3.html

        Args:
            keys: The object key to get. Increases with recursion.
            marker: S3 marker. The recursion ends when this is gone.

        Returns:
            Object keys from S3. For example: ['hoge', 'piyo', ...]
        """
        response = self.s3client.list_objects(Bucket=self.bucket_name, Prefix=os.path.join(self.prefix, 'log/task_log'), Marker=marker)
        if 'Contents' in response:
            keys.extend([{'Key': content['Key'], 'LastModified': content['LastModified']} for content in response['Contents']])
            if response.get('IsTruncated'):
                return self._get_s3_keys(keys=keys, marker=keys[-1]['Key'])
        return keys

    def to_absolute_path(self, x: str) -> str:
        """get S3 file path"""
        x = x.lstrip('.').lstrip('/')
        if self.workspace_directory.rstrip('/').split('/')[-1] == x.split('/')[0]:
            x = '/'.join(x.split('/')[1:])
        return x
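
A brief usage sketch for this client follows; the bucket and prefix are placeholders, and boto3 credentials are assumed to be configured in the environment.

# List and unpickle every entry under s3://<bucket>/<prefix>/log/task_log/.
client = S3Client(workspace_directory='s3://my-bucket/workspace/',  # placeholder bucket/prefix
                  tqdm_disable=True)
tasks = client.get_tasks()
print(len(tasks), 'task logs found')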
Example #8
class GCSClient:
    def __init__(self,
                 workspace_directory: str = '',
                 task_filters: List[str] = [],
                 tqdm_disable: bool = False,
                 use_cache: bool = True):
        """must set $GCS_CREDENTIAL"""
        self.workspace_directory = workspace_directory
        self.task_filters = task_filters
        self.tqdm_disable = tqdm_disable
        self.gcs_client = GCSConfig().get_gcs_client()
        self.local_cache = LocalCache(workspace_directory, use_cache)
        self.use_cache = use_cache

    def get_tasks(self) -> List[Dict[str, Any]]:
        """Load all task_log from GCS"""
        files = list(self._get_gcs_objects())
        tasks_list = list()
        for x in tqdm(files, disable=self.tqdm_disable):
            n = x.split('/')[-1]
            if self.task_filters and not [
                    f for f in self.task_filters if f in n
            ]:
                continue
            n = n.split('_')

            if self.use_cache:
                cache = self.local_cache.get(x)
                if cache:
                    tasks_list.append(cache)
                    continue

            try:
                meta = self._get_gcs_object_info(x)
                params = {
                    'task_name': '_'.join(n[:-1]),
                    'task_params': pickle.load(
                        self.gcs_client.download(x.replace('task_log', 'task_params'))),
                    'task_log': pickle.load(self.gcs_client.download(x)),
                    'last_modified': datetime.strptime(meta['updated'].split('.')[0],
                                                       '%Y-%m-%dT%H:%M:%S'),
                    'task_hash': n[-1].split('.')[0]
                }
                tasks_list.append(params)
                if self.use_cache:
                    self.local_cache.dump(x, params)
            except Exception:
                continue

        if len(tasks_list) != len(files):
            warnings.warn(
                f'[NOT FOUND LOGS] target file: {len(files)}, found log file: {len(tasks_list)}'
            )

        return tasks_list

    def _get_gcs_objects(self) -> List[str]:
        """get GCS objects"""
        return self.gcs_client.listdir(
            os.path.join(self.workspace_directory, 'log/task_log'))

    def _get_gcs_object_info(self, x: str) -> Dict[str, str]:
        """get GCS object meta data"""
        bucket, obj = self.gcs_client._path_to_bucket_and_key(x)
        return self.gcs_client.client.objects().get(bucket=bucket,
                                                    object=obj).execute()

    def to_absolute_path(self, x: str) -> str:
        """get GCS file path"""
        x = x.lstrip('.').lstrip('/')
        if self.workspace_directory.rstrip('/').split('/')[-1] == x.split(
                '/')[0]:
            x = '/'.join(x.split('/')[1:])
        return x
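
Finally, a usage sketch for the GCS client; the bucket path is a placeholder, and $GCS_CREDENTIAL must point to a valid credential, as the constructor's docstring notes.

# List and unpickle every entry under gs://<bucket>/<workspace>/log/task_log/.
client = GCSClient(workspace_directory='gs://my-bucket/workspace/')  # placeholder path
for task in client.get_tasks():
    print(task['task_name'], task['last_modified'])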