def diff_gcs_directories(
        base_directory_url: str,
        target_directory_url: str) -> Tuple[List[str], List[str], List[str]]:
    """
    Compare objects under different GCS prefixes.

    :param base_directory_url: URL for base directory
    :param target_directory_url: URL for target directory
    :returns: Tuple with 3 elements:
        List of objects in base directory that are not present in target directory
        List of objects in target directory that are not present in base directory
        List of objects with different content in base and target directory
    """
    base = urlparse(base_directory_url)
    target = urlparse(target_directory_url)

    if base.scheme != "gs":
        raise ValueError("base_directory_url must be a gs:// URL")
    if target.scheme != "gs":
        raise ValueError("target_directory_url must be a gs:// URL")

    client = Client(project=None)

    base_blobs = client.list_blobs(base.hostname, prefix=base.path.strip("/") + "/")
    base_blobs = {
        _remove_prefix(blob.name, base.path.strip("/")): blob for blob in base_blobs
    }
    missing_objects = set(base_blobs.keys())
    extra_objects = []
    changed_objects = []

    target_blobs = client.list_blobs(target.hostname,
                                     prefix=target.path.strip("/") + "/")
    for blob in target_blobs:
        key = _remove_prefix(blob.name, target.path.strip("/"))
        missing_objects.discard(key)
        try:
            if blob.md5_hash != base_blobs[key].md5_hash:
                changed_objects.append(key)
        except KeyError:
            extra_objects.append(key)

    return GCSDiffResult(list(missing_objects), extra_objects, changed_objects)
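# A minimal usage sketch for diff_gcs_directories above. The bucket and
# prefixes are placeholders, and the helper assumes _remove_prefix and the
# GCSDiffResult named tuple are defined alongside the function.
missing, extra, changed = diff_gcs_directories(
    "gs://example-bucket/releases/v1",
    "gs://example-bucket/releases/v2",
)
for name in changed:
    print(f"content differs: {name}")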
class GS(Base):
    _creds: service_account.Credentials = None
    _project: str = None
    _bucket: Bucket = None

    def __init__(self, bucket: str, creds_path: Optional[str] = None):
        super().__init__()
        if creds_path is not None:
            self._creds = service_account.Credentials.from_service_account_file(
                creds_path)
            with open(creds_path, 'rt') as f:
                self._project = json.loads(f.read())['project_id']
            self._bucket = Client(self._project, self._creds).bucket(bucket)
        else:
            self._bucket = Client().bucket(bucket)

    def get(self, path: str) -> bytes:
        return self._bucket.get_blob(path).download_as_string()

    def put(self, path: str, content: bytes):
        self._bucket.blob(path).upload_from_string(content)

    def exists(self, path: str) -> bool:
        return self._bucket.get_blob(path) is not None

    def delete(self, path: str):
        blobs = self._bucket.list_blobs(prefix=path)
        for blob in blobs:
            blob.delete()
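# Usage sketch for the GS wrapper above. The bucket name, object key, and
# service-account path are placeholders, and the Base superclass is assumed
# to be importable from the same module.
store = GS("example-bucket", creds_path="/path/to/service-account.json")
store.put("reports/summary.txt", b"hello world")
assert store.exists("reports/summary.txt")
print(store.get("reports/summary.txt"))
store.delete("reports/")  # removes every object under the prefix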
def remove_oldest_backlog_item(
    gcs_client: storage.Client,
    bkt: storage.Bucket,
    table_prefix: str,
) -> bool:
    """
    Remove the oldest pointer in the backlog if the backlog is not empty.

    Args:
        gcs_client: storage.Client
        bkt: storage.Bucket that this cloud function is ingesting data for.
        table_prefix: the prefix for the table whose backlog should be checked.
    Returns:
        bool: True if we removed the oldest blob. False if the backlog was empty.
    """
    backlog_blobs = gcs_client.list_blobs(bkt, prefix=f"{table_prefix}/_backlog/")
    # Backlog items will be lexicographically sorted
    # https://cloud.google.com/storage/docs/json_api/v1/objects/list
    blob: storage.Blob
    for blob in backlog_blobs:
        blob.delete(client=gcs_client)
        return True  # Return after deleting the first blob in the iterator
    return False
def bucket_lister(config: ConfigParser, gcs: Client, bucket: Bucket,
                  prefix: str, bucket_number: int, total_buckets: int,
                  stats: dict) -> bool:
    """List a bucket, sending each page of the listing into an executor pool
    for processing.

    Arguments:
        config {ConfigParser} -- The program config.
        gcs {Client} -- A GCS client object.
        bucket {Bucket} -- A GCS Bucket object to list.
        prefix {str} -- The object name prefix to list under.
        bucket_number {int} -- The number of this bucket (out of the total).
        total_buckets {int} -- The total number of buckets that will be listed.
        stats {dict} -- A dictionary of bucket_name (str): blob_count (int)
    """
    LOG.info("Listing %s. %s of %s total buckets", bucket.name, bucket_number,
             total_buckets)
    stats[bucket] = 0
    # Use remaining configured workers, or at least 2, for this part
    workers = max(config.getint('RUNTIME', 'WORKERS') - 2, 2)
    size = int(config.getint('RUNTIME', 'WORK_QUEUE_SIZE') * 0.75)
    with BoundedThreadPoolExecutor(max_workers=workers,
                                   queue_size=size) as sub_executor:
        blobs = gcs.list_blobs(bucket, prefix=prefix)
        for page in blobs.pages:
            sub_executor.submit(page_outputter, config, bucket, page, stats)
            sleep(0.02)  # small offset to avoid thundering herd
def get_locked_integrations(integrations: list, storage_client: storage.Client) -> dict:
    """
    Get all locked integration files.

    Args:
        integrations: Integrations that we want to get lock files for
        storage_client: The GCP storage client

    Returns:
        A dict of the form {<integration-name>: <integration-blob-object>} for all
        integrations that have a blob object.
    """
    # Listing all files in lock folder
    # Wrapping in 'list' because list_blobs returns an iterator which can only be consumed once
    lock_files_ls = list(
        storage_client.list_blobs(BUCKET_NAME, prefix=f'{LOCKS_PATH}'))
    current_integrations_lock_files = {}
    # Getting all existing files details for integrations that we want to lock
    for integration in integrations:
        current_integrations_lock_files.update({
            integration: [
                lock_file_blob for lock_file_blob in lock_files_ls
                if lock_file_blob.name == f'{LOCKS_PATH}/{integration}'
            ]
        })
    # Filtering 'current_integrations_lock_files' from integrations with no files
    current_integrations_lock_files = {
        integration: blob_files[0]
        for integration, blob_files in current_integrations_lock_files.items()
        if blob_files
    }
    return current_integrations_lock_files
def load_stage(dst_dataset: Dataset, bq_client: Client, bucket_name: str,
               gcs_client: storage.Client) -> List[LoadJob]:
    """
    Stage files from a bucket to a dataset

    :param dst_dataset: reference to destination dataset object
    :param bq_client: a BigQuery client object
    :param bucket_name: the location in GCS containing the vocabulary files
    :param gcs_client: a Cloud Storage client object
    :return: list of completed load jobs
    """
    blobs = list(gcs_client.list_blobs(bucket_name))

    table_blobs = [_filename_to_table_name(blob.name) for blob in blobs]
    missing_blobs = [
        table for table in VOCABULARY_TABLES if table not in table_blobs
    ]
    if missing_blobs:
        raise RuntimeError(
            f'Bucket {bucket_name} is missing files for tables {missing_blobs}')

    load_jobs = []
    for blob in blobs:
        table_name = _filename_to_table_name(blob.name)
        # ignore any non-vocabulary files
        if table_name not in VOCABULARY_TABLES:
            continue
        destination = dst_dataset.table(table_name)
        safe_schema = safe_schema_for(table_name)
        job_config = LoadJobConfig()
        job_config.schema = safe_schema
        job_config.skip_leading_rows = 1
        job_config.field_delimiter = FIELD_DELIMITER
        job_config.max_bad_records = MAX_BAD_RECORDS
        job_config.source_format = 'CSV'
        job_config.quote_character = ''
        source_uri = f'gs://{bucket_name}/{blob.name}'
        load_job = bq_client.load_table_from_uri(source_uri,
                                                 destination,
                                                 job_config=job_config)
        LOGGER.info(f'table:{destination} job_id:{load_job.job_id}')
        load_jobs.append(load_job)
        load_job.result()
    return load_jobs
def load_folder(dst_dataset: str, bq_client: BQClient, bucket_name: str,
                prefix: str, gcs_client: GCSClient,
                hpo_id: str) -> List[LoadJob]:
    """
    Stage files from a bucket to a dataset

    :param dst_dataset: Identifies the destination dataset
    :param bq_client: a BigQuery client object
    :param bucket_name: the bucket in GCS containing the archive files
    :param prefix: prefix of the filepath URI
    :param gcs_client: a Cloud Storage client object
    :param hpo_id: Identifies the HPO site
    :return: list of completed load jobs
    """
    blobs = list(gcs_client.list_blobs(bucket_name, prefix=prefix))

    load_jobs = []
    for blob in blobs:
        table_name = _filename_to_table_name(blob.name)
        if table_name not in AOU_REQUIRED:
            LOGGER.debug(f'Skipping file for {table_name}')
            continue
        schema = get_table_schema(table_name)
        hpo_table_name = f'{hpo_id}_{table_name}'
        fq_hpo_table = f'{bq_client.project}.{dst_dataset}.{hpo_table_name}'
        destination = Table(fq_hpo_table, schema=schema)
        destination = bq_client.create_table(destination)
        job_config = LoadJobConfig()
        job_config.schema = schema
        job_config.skip_leading_rows = 1
        job_config.source_format = 'CSV'
        source_uri = f'gs://{bucket_name}/{blob.name}'
        load_job = bq_client.load_table_from_uri(
            source_uri,
            destination,
            job_config=job_config,
            job_id_prefix=f"{__file__.split('/')[-1].split('.')[0]}_")
        LOGGER.info(f'table:{destination} job_id:{load_job.job_id}')
        load_jobs.append(load_job)
        load_job.result()
    return load_jobs
class GCStorage(RemoteStorageABC):
    def __init__(self, project: str = "mathieu-tricicl", bucket_name: str = "tricicl-public"):
        self.client = Client(project=project)
        self.bucket = self.client.bucket(bucket_name)

    def upload_file(self, sync_path: SyncPath):
        blob = self.bucket.blob(str(sync_path.remote), chunk_size=10 * 1024 * 1024)
        blob.upload_from_filename(str(sync_path.local), timeout=60 * 5)

    def download_file(self, sync_path: SyncPath):
        sync_path.local.parent.mkdir(exist_ok=True, parents=True)
        blob = self.bucket.get_blob(str(sync_path.remote))
        if blob is None:
            raise FileNotFoundError(f"{sync_path.remote} is not on gcloud bucket")
        blob.download_to_filename(str(sync_path.local), timeout=60 * 5)

    def list_files(self, remote_path: PurePath, suffix: str = "") -> List[PurePath]:
        return [
            PurePath(b.name)
            for b in self.client.list_blobs(self.bucket, prefix=str(remote_path))
            if not b.name.endswith("/") and b.name.endswith(suffix)
        ]
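# Usage sketch for GCStorage.list_files above. The remote prefix and suffix
# are placeholders, and the default project/bucket arguments of the class are
# assumed to be reachable with the caller's credentials.
storage_backend = GCStorage()
for remote in storage_backend.list_files(PurePath("models/unet"), suffix=".ckpt"):
    print(remote)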
def list_bucket_files(
    client: Client,
    bucket_name: str,
    prefix=None,
    file_extension=None,
    is_dir=False,
) -> Iterable[str]:
    blob_list = client.list_blobs(bucket_name, prefix=prefix)
    for blob in blob_list:
        # filter specific extensions
        if file_extension is not None:
            blob_extension_name = blob.name.split(".")[-1]
            if file_extension.strip(".") != blob_extension_name:
                continue
        # Excludes top level directory that matches prefix
        if is_dir and blob.name[-1] == "/":
            continue
        yield f"gs://{blob.bucket.name}/{blob.name}"
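# Usage sketch for the list_bucket_files generator above; the bucket name and
# prefix are placeholders chosen for illustration.
client = Client()
for uri in list_bucket_files(client, "example-bucket",
                             prefix="exports/2021/",
                             file_extension=".csv"):
    print(uri)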
def trigger_processing(storage_client: storage.Client, parent_directory: str):
    bucket = storage_client.get_bucket(BUCKET_NAME)
    blob = bucket.blob(f'{parent_directory}/{FILING_MANIFEST_FILENAME}')
    manifest_file = blob.download_as_string().decode('utf8')
    expected_count = len(manifest_file.split('\n'))

    actual_count = 0
    blobs = storage_client.list_blobs(
        BUCKET_NAME, prefix=f'{parent_directory}/{XML_DIRECTORY_NAME}/')
    for _ in blobs:
        actual_count += 1

    if actual_count == expected_count:
        logger.info('All %d files downloaded. Starting processing.', actual_count)
        publisher = pubsub_v1.PublisherClient()
        topic_path = publisher.topic_path(PROJECT_ID, 'process-netfile-filings')  # pylint: disable=no-member
        publisher.publish(topic_path, b'', directory=parent_directory)
    else:
        logger.debug(
            'Only %d of %d filings downloaded. Waiting to start processing.',
            actual_count, expected_count)
def get_next_backlog_item(
    gcs_client: storage.Client,
    bkt: storage.Bucket,
    table_prefix: str,
) -> Optional[storage.Blob]:
    """
    Get next blob in the backlog if the backlog is not empty.

    Args:
        gcs_client: storage.Client
        bkt: storage.Bucket that this cloud function is ingesting data for.
        table_prefix: the prefix for the table whose backlog should be checked.
    Returns:
        storage.Blob: pointer to a SUCCESS file in the backlog
    """
    backlog_blobs = gcs_client.list_blobs(bkt, prefix=f"{table_prefix}/_backlog/")
    # Backlog items will be lexicographically sorted
    # https://cloud.google.com/storage/docs/json_api/v1/objects/list
    for blob in backlog_blobs:
        return blob  # Return first item in iterator
    return None
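# Usage sketch combining get_next_backlog_item with remove_oldest_backlog_item
# (defined earlier); the bucket name and table prefix are placeholders.
gcs_client = storage.Client()
bkt = gcs_client.bucket("example-ingest-bucket")
next_item = get_next_backlog_item(gcs_client, bkt, "my_table")
if next_item is not None:
    print(f"next backlog pointer: {next_item.name}")
    remove_oldest_backlog_item(gcs_client, bkt, "my_table")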
def get_batches_for_gsurl(gcs_client: storage.Client,
                          gsurl: str,
                          recursive=True) -> List[List[str]]:
    """
    This function creates batches of GCS URIs for a given gsurl.
    By default, it will recursively search for blobs in all sub-folders of the
    given gsurl.
    The function will ignore URIs of objects which match the following:
     - filenames which are present in constants.ACTION_FILENAMES
     - filenames that start with a dot (.)
     - _bqlock file created for ordered loads
     - filenames with any constants.SPECIAL_GCS_DIRECTORY_NAMES in their path
    Returns a list of batches (each batch is a list of GCS URIs).
    """
    batches: List[List[str]] = []
    parsed_url = urlparse(gsurl)
    bucket_name: str = parsed_url.netloc
    prefix_path: str = parsed_url.path.lstrip('/')
    bucket: storage.Bucket = cached_get_bucket(gcs_client, bucket_name)
    folders: Set[str] = get_folders_in_gcs_path_prefix(gcs_client,
                                                       bucket,
                                                       prefix_path,
                                                       recursive=recursive)
    folders.add(prefix_path)
    print(
        json.dumps(
            dict(message="Searching for blobs to load in"
                 " prefix path and sub-folders",
                 search_folders=list(folders),
                 severity="INFO")))
    blobs: List[storage.Blob] = []
    for folder in folders:
        blobs += (list(
            gcs_client.list_blobs(bucket, prefix=folder, delimiter="/")))

    cumulative_bytes = 0
    max_batch_size = int(
        os.getenv("MAX_BATCH_BYTES", constants.DEFAULT_MAX_BATCH_BYTES))
    batch: List[str] = []
    for blob in blobs:
        # The following blobs will be ignored:
        # - filenames which are present in constants.ACTION_FILENAMES
        # - filenames that start with a dot (.)
        # - _bqlock file created for ordered loads
        # - filenames with constants.SPECIAL_GCS_DIRECTORY_NAMES in their path
        if (os.path.basename(blob.name) not in constants.ACTION_FILENAMES
                and not os.path.basename(blob.name).startswith(".")
                and os.path.basename(blob.name) != "_bqlock"
                and not any(blob_dir_name in constants.SPECIAL_GCS_DIRECTORY_NAMES
                            for blob_dir_name in blob.name.split('/'))):
            if blob.size == 0:  # ignore empty files
                print(f"ignoring empty file: gs://{bucket.name}/{blob.name}")
                continue
            cumulative_bytes += blob.size
            # keep adding until we reach threshold
            if cumulative_bytes <= max_batch_size or len(
                    batch) > constants.MAX_SOURCE_URIS_PER_LOAD:
                batch.append(f"gs://{bucket_name}/{blob.name}")
            else:
                batches.append(batch.copy())
                batch.clear()
                batch.append(f"gs://{bucket_name}/{blob.name}")
                cumulative_bytes = blob.size
    # pick up remaining files in the final batch
    if len(batch) > 0:
        batches.append(batch.copy())
        batch.clear()
    print(
        json.dumps(
            dict(message="Logged batches of blobs to load in jsonPayload.",
                 batches=batches)))
    if len(batches) > 1:
        print(f"split into {len(batches)} batches.")
    elif len(batches) < 1:
        raise google.api_core.exceptions.NotFound(
            f"No files to load at {gsurl}!")
    return batches
def process_post_request(request_dict, user_id):
    # First, let's check if the user requesting the logs has access to the DAG run's account #
    user_accounts = None
    try:
        user_accounts = get_user_accounts(user_id=user_id)
    except Exception as ex:
        print("Exception occurred while getting user accounts.\n{}".format(ex))

    if user_accounts is None:
        message = "No user accounts found."
        print(message)
        return ({"data": message}, 403)

    # Check that the DAG run requested belongs to the user's account #
    check_dag_run = False
    try:
        check_dag_run, dag_run_data = check_dag_run_vs_accounts(
            accounts=user_accounts,
            dag_type=request_dict["data"]["dagType"].strip(),
            dag_id=request_dict["data"]["dagId"].strip(),
            dag_run_id=request_dict["data"]["dagRunId"].strip())
    except Exception as ex:
        print("Exception occurred while checking DAG run VS User account.\n{}".format(ex))

    if check_dag_run is False:
        message = "User does not have access to DAG account."
        print(message)
        return ({"data": message}, 403)

    # Check for Cloud Function logs first #
    if dag_run_data is not None:
        try:
            if ("gcp_cf_execution_id" in dag_run_data.keys()) and ("gcp_cf_name" in dag_run_data.keys()):
                print("Processing new logs ...")
                ret, message = process_cf_logs(dag_run_data=dag_run_data)

            # Specific processing for TTT in direct execution mode #
            elif ("configuration_context" in dag_run_data.keys()) \
                    and ("configuration" in dag_run_data["configuration_context"].keys()) \
                    and ("direct_execution" in dag_run_data["configuration_context"]["configuration"].keys()) \
                    and (dag_run_data["configuration_context"]["configuration"]["direct_execution"] is True):
                ret, message = process_direct_execution_ttt_logs(
                    execution_id=request_dict["data"]["dagRunId"].strip(),
                    task_id=request_dict["data"]["taskId"].strip())

            # Processing response #
            if ret is False:
                message = "Error while processing CF logs : {}".format(message)
                print(message)
                return ({"data": message}, 500)
            else:
                data = {}
                data["data"] = {}
                try:
                    data["data"]["logs"] = str(
                        base64.b64encode(bytes(message, "utf-8")), "utf-8")
                except Exception as ex:
                    print("Error while encoding logs: \n{}".format(ex))
                    data["data"]["error"] = str(
                        base64.b64encode(b"Error while retrieving logs."), "utf-8")
                return data, 200

        except Exception as ex:
            print("Check for CF failed : \n{}".format(ex))

    try:
        # Read the data from Google Cloud Storage
        read_storage_client = Client()

        # Set buckets and filenames
        bucket_name = "europe-west1-fd-io-composer-e5bff15e-bucket"

        # get bucket with name
        bucket = read_storage_client.get_bucket(bucket_name)

        blob_prefix = "logs/{}/{}/{}/".format(
            request_dict["data"]["dagId"].strip(),
            request_dict["data"]["taskId"].strip(),
            request_dict["data"]["dagExecutionDate"].strip())
        print("Getting logs : {}".format(blob_prefix))

        blobs = read_storage_client.list_blobs(bucket, prefix=blob_prefix)

        data = {}
        data["data"] = {}

        for blob in blobs:
            try:
                blob_short_name = ((blob.name).strip()).rpartition("/")[2]
                if not blob_short_name:
                    print("Not a file name : {}".format(blob.name))
                    continue
                print("Log filename : {}".format(blob_short_name))
            except Exception:
                # Not a filename.
                # print("Not a file name : {}".format(blob.name))
                continue

            try:
                data["data"][blob_short_name] = str(
                    base64.b64encode(blob.download_as_string()), "utf-8")
            except Exception as ex:
                print("Error while processing : {}\n{}".format(blob.name, ex))
                data["data"][blob_short_name] = str(
                    base64.b64encode(b"Error while retrieving logs."), "utf-8")

        if not data["data"]:
            # Empty => not found
            return data, 404
        else:
            return data, 200

    except Exception as ex:
        print("Error while processing : {}".format(ex))
        return {"message": "Error while fetching log files."}, 500
def _get_gcs_blobs(storage_client: storage.Client, bucket: str,
                   target_path: str) -> Dict[str, datetime]:
    """Return all blobs in the GCS location with their last modified timestamp."""
    blobs = storage_client.list_blobs(bucket, prefix=target_path + "/")
    return {blob.name.replace(".json", ""): blob.updated for blob in blobs}
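# Usage sketch for _get_gcs_blobs above; the bucket and target path are
# placeholders chosen for illustration.
client = storage.Client()
last_modified = _get_gcs_blobs(client, "example-bucket", "snapshots/daily")
for name, updated in last_modified.items():
    print(name, updated.isoformat())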
class BucketClientGCS(BucketClient):
    client: Optional[GCSNativeClient]

    def __init__(self, client: Optional[GCSNativeClient] = None):
        try:
            self.client = GCSNativeClient() if GCSNativeClient else None
        except (BaseException, DefaultCredentialsError):
            self.client = None

    def make_uri(self, path: PurePathy) -> str:
        return str(path)

    def create_bucket(self, path: PurePathy) -> Bucket:
        assert self.client is not None, _MISSING_DEPS
        return self.client.create_bucket(path.root)

    def delete_bucket(self, path: PurePathy) -> None:
        assert self.client is not None, _MISSING_DEPS
        bucket = self.client.get_bucket(path.root)
        bucket.delete()

    def exists(self, path: PurePathy) -> bool:
        # Because we want all the parents of a valid blob (e.g. "directory" in
        # "directory/foo.file") to return True, we enumerate the blobs with a prefix
        # and compare the object names to see if they match a substring of the path
        key_name = str(path.key)
        try:
            for obj in self.list_blobs(path):
                if obj.name == key_name:
                    return True
                if obj.name.startswith(key_name + path._flavour.sep):
                    return True
        except gcs_errors.ClientError:
            return False
        return False

    def lookup_bucket(self, path: PurePathy) -> Optional[BucketGCS]:
        assert self.client is not None, _MISSING_DEPS
        try:
            native_bucket = self.client.bucket(path.root)
            if native_bucket is not None:
                return BucketGCS(str(path.root), bucket=native_bucket)
        except gcs_errors.ClientError as err:
            print(err)
        return None

    def get_bucket(self, path: PurePathy) -> BucketGCS:
        assert self.client is not None, _MISSING_DEPS
        try:
            native_bucket = self.client.bucket(path.root)
            if native_bucket is not None:
                return BucketGCS(str(path.root), bucket=native_bucket)
            raise FileNotFoundError(f"Bucket {path.root} does not exist!")
        except gcs_errors.ClientError as e:
            raise ClientError(message=e.message, code=e.code)

    def list_buckets(
            self, **kwargs: Dict[str, Any]) -> Generator[GCSNativeBucket, None, None]:
        assert self.client is not None, _MISSING_DEPS
        return self.client.list_buckets(**kwargs)  # type:ignore

    def scandir(  # type:ignore[override]
        self,
        path: Optional[PurePathy] = None,
        prefix: Optional[str] = None,
        delimiter: Optional[str] = None,
    ) -> Generator[BucketEntryGCS, None, None]:  # type:ignore[override]
        assert self.client is not None, _MISSING_DEPS
        continuation_token = None
        if path is None or not path.root:
            gcs_bucket: GCSNativeBucket
            for gcs_bucket in self.list_buckets():
                yield BucketEntryGCS(gcs_bucket.name, is_dir=True, raw=None)
            return
        sep = path._flavour.sep
        bucket = self.lookup_bucket(path)
        if bucket is None:
            return
        while True:
            if continuation_token:
                response = self.client.list_blobs(
                    bucket.name,
                    prefix=prefix,
                    delimiter=sep,
                    page_token=continuation_token,
                )
            else:
                response = self.client.list_blobs(bucket.name,
                                                  prefix=prefix,
                                                  delimiter=sep)
            for page in response.pages:
                for folder in list(page.prefixes):
                    full_name = folder[:-1] if folder.endswith(sep) else folder
                    name = full_name.split(sep)[-1]
                    if name:
                        yield BucketEntryGCS(name, is_dir=True, raw=None)
                for item in page:
                    name = item.name.split(sep)[-1]
                    if name:
                        yield BucketEntryGCS(
                            name=name,
                            is_dir=False,
                            size=item.size,
                            last_modified=item.updated.timestamp(),
                            raw=item,
                        )
            if response.next_page_token is None:
                break
            continuation_token = response.next_page_token

    def list_blobs(
        self,
        path: PurePathy,
        prefix: Optional[str] = None,
        delimiter: Optional[str] = None,
        include_dirs: bool = False,
    ) -> Generator[BlobGCS, None, None]:
        assert self.client is not None, _MISSING_DEPS
        continuation_token = None
        bucket = self.lookup_bucket(path)
        if bucket is None:
            return
        while True:
            if continuation_token:
                response = self.client.list_blobs(
                    path.root,
                    prefix=prefix,
                    delimiter=delimiter,
                    page_token=continuation_token,
                )
            else:
                response = self.client.list_blobs(path.root,
                                                  prefix=prefix,
                                                  delimiter=delimiter)
            for page in response.pages:
                for item in page:
                    yield BlobGCS(
                        bucket=bucket,
                        owner=item.owner,
                        name=item.name,
                        raw=item,
                        size=item.size,
                        updated=item.updated.timestamp(),
                    )
            if response.next_page_token is None:
                break
            continuation_token = response.next_page_token
def load(self, continue_from_blob: Optional[str] = None) -> bool:
    current_date = datetime.now().strftime('%Y.%m.%d %H.%M.%S')
    file_name = f'loader {current_date}.log'
    set_log_handler(logger=self.logger, file_name=file_name)
    client = Client(project=self.project_name)
    self.logger.info(f'Created client for project {self.project_name}')
    bucket = self._get_bucket(client)
    if not bucket:
        return False
    self.logger.info(f'Found bucket {self.bucket_name}')
    for last_blob in client.list_blobs(bucket, versions=False):
        last_blob: Blob
        if last_blob.name.endswith('/') or (
                continue_from_blob and last_blob.name < continue_from_blob):
            continue
        self.logger.info(f'Downloading versions of object {last_blob.name}')
        try:
            last_version = self._parse_version(last_blob)
            object_path = self._get_version_path(last_version)
            if object_path.exists():
                self.logger.info('No new versions found')
                continue
            versions = []
            for blob in client.list_blobs(bucket,
                                          prefix=last_blob.name,
                                          versions=True):
                self.logger.info(blob.name)
                file_version = self._parse_version(blob)
                object_path = self._get_version_path(file_version)
                self.logger.info(
                    f'fileId={file_version.fileId}, timestamp={file_version.timestamp}, object_path={object_path}'
                )
                versions.append(file_version)
            new_versions = 0
            versions_to_load = []
            for version in sorted(versions, key=lambda v: -v.timestamp):
                object_path = self._get_version_path(version)
                new_versions += not object_path.exists()
                versions_to_load.append(version)
                if object_path.exists():
                    break
            # We also load the last version among those that are already loaded,
            # as it could have been damaged, for example due to an unexpected
            # interrupt of a loader script.
            for version in reversed(versions_to_load):
                object_path = self._get_version_path(version)
                patch_list = PatchList(patches=version.patches)
                self._store_data(object_path, patch_list.SerializeToString())
            if versions_to_load:
                self._store_data(
                    self._get_content_path(versions_to_load[0]),
                    versions_to_load[0].content)
            self.logger.info(
                f'{new_versions} new versions of object {last_blob.name} found')
        except Exception as error:
            self.logger.error(
                f'An unexpected error occurred while downloading objects '
                f'from bucket {bucket.name}: {error}')
    return True
def _list_remote_keys(self):
    client = Client(project=self.project)
    for blob in client.list_blobs(self.bucket, prefix=self.prefix):
        yield blob.name[len(self.prefix) + 1:]
class BucketClientGCS(BucketClient):
    client: GCSNativeClient

    @property
    def client_params(self) -> Any:
        return dict(client=self.client)

    def __init__(self, **kwargs: Any) -> None:
        self.recreate(**kwargs)

    def recreate(self, **kwargs: Any) -> None:
        creds = kwargs["credentials"] if "credentials" in kwargs else None
        if creds is not None:
            kwargs["project"] = creds.project_id
        self.client = GCSNativeClient(**kwargs)

    def make_uri(self, path: PurePathy) -> str:
        return str(path)

    def create_bucket(  # type:ignore[override]
            self, path: PurePathy) -> GCSNativeBucket:
        return self.client.create_bucket(path.root)  # type:ignore

    def delete_bucket(self, path: PurePathy) -> None:
        bucket = self.client.get_bucket(path.root)  # type:ignore
        bucket.delete()  # type:ignore

    def exists(self, path: PurePathy) -> bool:
        # Because we want all the parents of a valid blob (e.g. "directory" in
        # "directory/foo.file") to return True, we enumerate the blobs with a prefix
        # and compare the object names to see if they match a substring of the path
        key_name = str(path.key)
        for obj in self.list_blobs(path):
            if obj.name.startswith(key_name + path._flavour.sep):  # type:ignore
                return True
        return False

    def lookup_bucket(self, path: PurePathy) -> Optional[BucketGCS]:
        try:
            return self.get_bucket(path)
        except FileNotFoundError:
            return None

    def get_bucket(self, path: PurePathy) -> BucketGCS:
        native_bucket: Any = self.client.bucket(path.root)  # type:ignore
        try:
            if native_bucket.exists():
                return BucketGCS(str(path.root), bucket=native_bucket)
        except BadRequest:
            pass
        raise FileNotFoundError(f"Bucket {path.root} does not exist!")

    def list_buckets(  # type:ignore[override]
            self, **kwargs: Dict[str, Any]) -> Generator[GCSNativeBucket, None, None]:
        return self.client.list_buckets(**kwargs)  # type:ignore

    def scandir(  # type:ignore[override]
        self,
        path: Optional[PurePathy] = None,
        prefix: Optional[str] = None,
        delimiter: Optional[str] = None,
    ) -> PathyScanDir:
        return ScanDirGCS(client=self, path=path, prefix=prefix, delimiter=delimiter)

    def list_blobs(
        self,
        path: PurePathy,
        prefix: Optional[str] = None,
        delimiter: Optional[str] = None,
    ) -> Generator[BlobGCS, None, None]:
        bucket = self.lookup_bucket(path)
        if bucket is None:
            return
        response: Any = self.client.list_blobs(  # type:ignore
            path.root, prefix=prefix, delimiter=delimiter)
        for page in response.pages:  # type:ignore
            for item in page:
                yield BlobGCS(
                    bucket=bucket,
                    owner=item.owner,
                    name=item.name,
                    raw=item,
                    size=item.size,
                    updated=item.updated.timestamp(),
                )
class GCSClient:
    """ This class is used to download data from a GCS location and perform
    functions such as downloading the dataset and checksum validation. """

    GCS_PREFIX = "^gs://"
    KEY_SEPARATOR = "/"

    def __init__(self, **kwargs):
        """ Initialize a client to google cloud storage (GCS). """
        self.client = Client(**kwargs)

    def download(self, *, url=None, local_path=None, bucket=None, key=None):
        """ This method is used to download the dataset from GCS.

        Args:
            url (str): This is the downloader-uri that indicates where
                the dataset should be downloaded from.
            local_path (str): This is the path to the directory where the
                download will store the dataset.
            bucket (str): gcs bucket name
            key (str): object key path

        Examples:
            >>> url = "gs://bucket/folder or gs://bucket/folder/data.zip"
            >>> local_path = "/tmp/folder"
            >>> bucket = "bucket"
            >>> key = "folder/data.zip" or "folder"
        """
        if not (bucket and key) and url:
            bucket, key = self._parse(url)
        bucket_obj = self.client.get_bucket(bucket)
        if self._is_file(bucket_obj, key):
            self._download_file(bucket_obj, key, local_path)
        else:
            self._download_folder(bucket_obj, key, local_path)

    def _download_folder(self, bucket, key, local_path):
        """ download all files from directory """
        blobs = bucket.list_blobs(prefix=key)
        for blob in blobs:
            local_file_path = blob.name.replace(key, local_path)
            self._download_validate(blob, local_file_path)

    def _download_file(self, bucket, key, local_path):
        """ download single file """
        blob = bucket.get_blob(key)
        key_suffix = key.replace("/" + basename(key), "")
        local_file_path = blob.name.replace(key_suffix, local_path)
        self._download_validate(blob, local_file_path)

    def _download_validate(self, blob, local_file_path):
        """ download file and validate checksum """
        self._download_blob(blob, local_file_path)
        self._checksum(blob, local_file_path)

    def _download_blob(self, blob, local_file_path):
        """ download blob from gcs

        Raises:
            NotFound: This will raise when object not found
        """
        dst_dir = local_file_path.replace("/" + basename(local_file_path), "")
        key = blob.name
        if not isdir(dst_dir):
            makedirs(dst_dir)
        logger.info(f"Downloading from {key} to {local_file_path}.")
        blob.download_to_filename(local_file_path)

    def _checksum(self, blob, filename):
        """ validate checksum and delete file if checksum does not match

        Raises:
            ChecksumError: This will be raised if the checksum doesn't match
        """
        expected_checksum = blob.md5_hash
        if expected_checksum:
            expected_checksum_hex = self._md5_hex(expected_checksum)
            try:
                validate_checksum(
                    filename, expected_checksum_hex, algorithm="MD5"
                )
            except ChecksumError as e:
                logger.exception(
                    "Checksum mismatch. Delete the downloaded files."
                )
                os.remove(filename)
                raise e

    def _is_file(self, bucket, key):
        """Check if the key is a file or directory"""
        blob = bucket.get_blob(key)
        return blob and blob.name == key

    def _md5_hex(self, checksum):
        """fix the missing padding if required and convert into hex"""
        missing_padding = len(checksum) % 4
        if missing_padding != 0:
            checksum += "=" * (4 - missing_padding)
        return base64.b64decode(checksum).hex()

    def _parse(self, url):
        """Split a GCS-prefixed URL into bucket and path."""
        match = re.search(self.GCS_PREFIX, url)
        if not match:
            raise ValueError(
                f"Specified destination prefix: {url} does not start "
                f"with {self.GCS_PREFIX}"
            )
        url = url[len(self.GCS_PREFIX) - 1:]
        if self.KEY_SEPARATOR not in url:
            raise ValueError(
                f"Specified destination prefix: {self.GCS_PREFIX + url} does "
                f"not have object key "
            )
        idx = url.index(self.KEY_SEPARATOR)
        bucket = url[:idx]
        path = url[(idx + 1):]
        return bucket, path

    def upload(
        self, *, local_path=None, bucket=None, key=None, url=None, pattern="*"
    ):
        """ Upload a file or list of files from a directory to GCS

        Args:
            url (str): This is the gcs location that indicates where
                the dataset should be uploaded.
            local_path (str): This is the path to the directory or file
                where the data is stored.
            bucket (str): gcs bucket name
            key (str): object key path
            pattern: Unix glob patterns. Use **/* for recursive glob.

        Examples:
            For file upload:
            >>> url = "gs://bucket/folder/data.zip"
            >>> local_path = "/tmp/folder/data.zip"
            >>> bucket = "bucket"
            >>> key = "folder/data.zip"

            For directory upload:
            >>> url = "gs://bucket/folder"
            >>> local_path = "/tmp/folder"
            >>> bucket = "bucket"
            >>> key = "folder"
            >>> key = "**/*"
        """
        if not (bucket and key) and url:
            bucket, key = self._parse(url)
        bucket_obj = self.client.get_bucket(bucket)
        if isdir(local_path):
            self._upload_folder(
                local_path=local_path,
                bucket=bucket_obj,
                key=key,
                pattern=pattern,
            )
        else:
            self._upload_file(local_path=local_path, bucket=bucket_obj, key=key)

    def _upload_file(self, local_path=None, bucket=None, key=None):
        """ Upload a single object to GCS """
        blob = bucket.blob(key)
        logger.info(f"Uploading from {local_path} to {key}.")
        blob.upload_from_filename(local_path)

    def _upload_folder(
        self, local_path=None, bucket=None, key=None, pattern="*"
    ):
        """Upload all files from a folder to GCS based on pattern"""
        for path in Path(local_path).glob(pattern):
            if path.is_dir():
                continue
            full_path = str(path)
            relative_path = str(path.relative_to(local_path))
            object_key = os.path.join(key, relative_path)
            self._upload_file(
                local_path=full_path, bucket=bucket, key=object_key
            )

    def get_most_recent_blob(self, url=None, bucket_name=None, key=None):
        """ Get the last updated blob in a given bucket under given prefix

        Args:
            bucket_name (str): gcs bucket name
            key (str): object key path
        """
        if not (bucket_name and key) and url:
            bucket_name, key = self._parse(url)
        bucket = self.client.get_bucket(bucket_name)
        if self._is_file(bucket, key):
            # Called on file, return file
            return bucket.get_blob(key)
        else:
            logger.debug(
                f"Cloud path not a file. Checking for most recent file in {url}"
            )
            # Return the blob with the max update time (most recent)
            blobs = self._list_blobs(bucket, prefix=key)
            return max(
                blobs, key=lambda blob: bucket.get_blob(blob.name).updated
            )

    def _list_blobs(self, bucket_name=None, prefix=None):
        """List all blobs with given prefix"""
        blobs = self.client.list_blobs(bucket_name, prefix=prefix)
        blob_list = list(blobs)
        logger.debug(f"Blobs in {bucket_name} under prefix {prefix}:")
        logger.debug(blob_list)
        return blob_list
class GSPath:
    path: PurePath
    bucket: Bucket
    blob: Blob

    @classmethod
    def from_blob(cls, blob: Blob) -> GSPath:
        return cls(blob.bucket, blob.name)

    @classmethod
    def from_url(cls, url_str: str) -> GSPath:
        url = urlparse(url_str)
        if url.scheme != 'gs':
            raise ValueError('Wrong url scheme')
        return cls(url.netloc, url.path[1:])

    def __init__(self, bucket: Union[str, Bucket], path: Union[PurePath, str]) -> None:
        self.path = PurePath(str(path))
        if isinstance(bucket, str):
            bucket = Client().bucket(bucket)
        self.bucket = bucket
        self.blob = self.bucket.blob(str(self.path))

    def __getstate__(self) -> Dict[str, Any]:
        return {'path': self.path, 'bucket': self.bucket.name}

    def __setstate__(self, data: Dict[str, Any]) -> None:
        self.path = data['path']
        self.bucket = Client().bucket(data['bucket'])
        self.blob = self.bucket.blob(str(self.path))

    # Otherwise pylint thinks GSPath is undefined
    # pylint: disable=undefined-variable
    def __truediv__(self, other: Union[str, PurePath]) -> GSPath:
        return GSPath(self.bucket, self.path / str(other))

    def __repr__(self) -> str:
        url = f'gs://{self.bucket.name}/{self.path}'
        return f'GSPath.from_url({url!r})'

    @property
    def parent(self) -> GSPath:
        return GSPath(self.bucket, self.path.parent)

    def mkdir(self, mode: int = 0, parents: bool = True, exist_ok: bool = True) -> None:
        # no notion of 'directories' in GS
        pass

    def rmtree(self) -> None:
        for path in self.iterdir():
            path.unlink()

    def exists(self) -> bool:
        # print(f'{self.blob.name} exists? {self.blob.exists()}')
        return self.blob.exists()

    def unlink(self) -> None:
        self.blob.delete()

    def iterdir(self) -> Iterable[GSPath]:
        for blob in self.bucket.list_blobs(prefix=f'{self.path!s}/'):
            yield GSPath.from_blob(blob)

    def open(self, mode: str = 'r', encoding: str = 'UTF-8',
             newline: str = '\n') -> IO[Any]:
        if 'w' in mode:
            if 'b' in mode:
                return WGSFile(self, mode)
            else:
                return io.TextIOWrapper(WGSFile(self, mode),
                                        encoding=encoding,
                                        newline=newline)
        elif 'r' in mode:
            if 'b' in mode:
                return io.BytesIO(self.blob.download_as_string())
            else:
                return io.StringIO(
                    self.blob.download_as_string().decode(encoding=encoding),
                    newline)
        else:
            raise RuntimeError(f'Flag {mode} not supported')

    def public_path(self) -> str:
        return f'https://storage.googleapis.com/{self.bucket.name}/{self.path}'
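# Usage sketch for GSPath above. The bucket and object names are placeholders,
# and WGSFile (used for write modes) is assumed to be defined alongside GSPath.
path = GSPath.from_url("gs://example-bucket/data/config.json")
if path.exists():
    with path.open("r") as handle:
        print(handle.read())
for child in path.parent.iterdir():
    print(child)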
class BucketClientGCS(BucketClient):
    client: Optional[GCSNativeClient]

    @property
    def client_params(self) -> Any:
        return dict(client=self.client)

    def __init__(self, **kwargs: Any) -> None:
        self.recreate(**kwargs)

    def recreate(self, **kwargs: Any) -> None:
        creds = kwargs["credentials"] if "credentials" in kwargs else None
        if creds is not None:
            kwargs["project"] = creds.project_id
        try:
            self.client = GCSNativeClient(**kwargs)
        except TypeError:
            # TypeError is raised if the imports for GCSNativeClient fail and are
            # assigned to Any, which is not callable.
            self.client = None

    def make_uri(self, path: PurePathy) -> str:
        return str(path)

    def create_bucket(self, path: PurePathy) -> Bucket:
        assert self.client is not None, _MISSING_DEPS
        return self.client.create_bucket(path.root)

    def delete_bucket(self, path: PurePathy) -> None:
        assert self.client is not None, _MISSING_DEPS
        bucket = self.client.get_bucket(path.root)
        bucket.delete()

    def exists(self, path: PurePathy) -> bool:
        # Because we want all the parents of a valid blob (e.g. "directory" in
        # "directory/foo.file") to return True, we enumerate the blobs with a prefix
        # and compare the object names to see if they match a substring of the path
        key_name = str(path.key)
        try:
            for obj in self.list_blobs(path):
                if obj.name == key_name:
                    return True
                if obj.name.startswith(key_name + path._flavour.sep):
                    return True
        except gcs_errors.ClientError:
            return False
        return False

    def lookup_bucket(self, path: PurePathy) -> Optional[BucketGCS]:
        assert self.client is not None, _MISSING_DEPS
        try:
            native_bucket = self.client.bucket(path.root)
            if native_bucket is not None:
                return BucketGCS(str(path.root), bucket=native_bucket)
        except gcs_errors.ClientError as err:
            print(err)
        return None

    def get_bucket(self, path: PurePathy) -> BucketGCS:
        assert self.client is not None, _MISSING_DEPS
        try:
            native_bucket = self.client.bucket(path.root)
            if native_bucket is not None:
                return BucketGCS(str(path.root), bucket=native_bucket)
            raise FileNotFoundError(f"Bucket {path.root} does not exist!")
        except gcs_errors.ClientError as e:
            raise ClientError(message=e.message, code=e.code)

    def list_buckets(
        self, **kwargs: Dict[str, Any]
    ) -> Generator[GCSNativeBucket, None, None]:
        assert self.client is not None, _MISSING_DEPS
        return self.client.list_buckets(**kwargs)  # type:ignore

    def scandir(  # type:ignore[override]
        self,
        path: Optional[PurePathy] = None,
        prefix: Optional[str] = None,
        delimiter: Optional[str] = None,
    ) -> PathyScanDir:
        return _GCSScanDir(client=self, path=path, prefix=prefix, delimiter=delimiter)

    def list_blobs(
        self,
        path: PurePathy,
        prefix: Optional[str] = None,
        delimiter: Optional[str] = None,
        include_dirs: bool = False,
    ) -> Generator[BlobGCS, None, None]:
        assert self.client is not None, _MISSING_DEPS
        continuation_token = None
        bucket = self.lookup_bucket(path)
        if bucket is None:
            return
        while True:
            if continuation_token:
                response = self.client.list_blobs(
                    path.root,
                    prefix=prefix,
                    delimiter=delimiter,
                    page_token=continuation_token,
                )
            else:
                response = self.client.list_blobs(
                    path.root, prefix=prefix, delimiter=delimiter
                )
            for page in response.pages:
                for item in page:
                    yield BlobGCS(
                        bucket=bucket,
                        owner=item.owner,
                        name=item.name,
                        raw=item,
                        size=item.size,
                        updated=item.updated.timestamp(),
                    )
            if response.next_page_token is None:
                break
            continuation_token = response.next_page_token
def test_read_gcs(gcs: storage.Client, gcs_src: Dict[str, Any]):
    """Test reading a GCS blob in source configs."""
    blobs = gcs.list_blobs(gcs_src["bucket"], prefix=gcs_src["prefix"])
    assert sum(1 for _ in blobs) > 0