def test_copy_batch(self):
    drs_urls = {
        # The `name` property disappeared from the DRS response :(
        # "NWD522743.b38.irc.v1.cram.crai": "drs://dg.4503/95cc4ae1-dee7-4266-8b97-77cf46d83d35",  # 1631686 bytes
        "NWD522743.b38.irc.v1.cram.crai": "drs://dg.4503/95cc4ae1-dee7-4266-8b97-77cf46d83d35",  # 1631686 bytes
        "data_phs001237.v2.p1.c1.avro.gz": "drs://dg.4503/26e11149-5deb-4cd7-a475-16997a825655",  # 1115092 bytes
        "RootStudyConsentSet_phs001237.TOPMed_WGS_WHI.v2.p1.c1.HMB-IRB.tar.gz":
            "drs://dg.4503/e9c2caf2-b2a1-446d-92eb-8d5389e99ee3",  # 332237 bytes
        # "NWD961306.freeze5.v1.vcf.gz": "drs://dg.4503/6e73a376-f7fd-47ed-ac99-0567bb5a5993",  # 2679331445 bytes
        # "NWD531899.freeze5.v1.vcf.gz": "drs://dg.4503/651a4ad1-06b5-4534-bb2c-1f8ed51134f6",  # 2679411265 bytes
    }
    pfx = f"test-batch-copy/{uuid4()}"
    bucket = gs.get_client().bucket(WORKSPACE_BUCKET)
    with self.subTest("gs bucket"):
        with mock.patch("terra_notebook_utils.drs.MULTIPART_THRESHOLD", 400000):
            drs.copy_batch(list(drs_urls.values()),
                           f"gs://fc-9169fcd1-92ce-4d60-9d2d-d19fd326ff10/{pfx}")
        for name in drs_urls:
            blob = bucket.get_blob(f"{pfx}/{name}")
            self.assertGreater(blob.size, 0)
    with self.subTest("local filesystem"):
        with tempfile.TemporaryDirectory() as dirname:
            drs.copy_batch(list(drs_urls.values()), dirname)
            names = [os.path.basename(path) for path in _list_tree(dirname)]
            self.assertEqual(sorted(names), sorted(drs_urls))

def copy_batch(drs_urls: Iterable[str], dst: str,
               workspace_name: Optional[str]=WORKSPACE_NAME,
               workspace_namespace: Optional[str]=WORKSPACE_GOOGLE_PROJECT):
    """
    Copy several DRS objects into `dst`, which may be a `gs://` prefix or a local directory.
    Objects at or below MULTIPART_THRESHOLD are copied concurrently through the oneshot pool;
    larger objects are copied with multipart copy.
    """
    enable_requester_pays(workspace_name, workspace_namespace)
    with ThreadPoolExecutor(max_workers=IO_CONCURRENCY) as oneshot_executor:
        oneshot_pool = async_collections.AsyncSet(oneshot_executor, concurrency=IO_CONCURRENCY)
        for drs_url in drs_urls:
            assert drs_url.startswith("drs://")
            src_client, src_info = resolve_drs_for_gs_storage(drs_url)
            src_bucket = src_client.bucket(src_info.bucket_name, user_project=workspace_namespace)
            src_blob = src_bucket.get_blob(src_info.key)
            basename = src_info.name or src_info.key.rsplit("/", 1)[-1]
            if dst.startswith("gs://"):
                if dst.endswith("/"):
                    raise ValueError("Bucket destination cannot end with '/'")
                dst_bucket_name, dst_pfx = _bucket_name_and_key(dst)
                dst_bucket = gs.get_client().bucket(dst_bucket_name)
                dst_key = f"{dst_pfx}/{basename}"
                if MULTIPART_THRESHOLD >= src_blob.size:
                    oneshot_pool.put(gs.oneshot_copy, src_bucket, dst_bucket, src_info.key, dst_key)
                else:
                    gs.multipart_copy(src_bucket, dst_bucket, src_info.key, dst_key)
            else:
                oneshot_pool.put(copy_to_local, drs_url, dst, workspace_name, workspace_namespace)
        for _ in oneshot_pool.consume():
            pass

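# Usage sketch (illustrative, not part of the library): assumes a Terra notebook
# environment where WORKSPACE_NAME and WORKSPACE_GOOGLE_PROJECT are configured.
# The DRS URIs and the destination prefix below are hypothetical placeholders.
from terra_notebook_utils import drs

example_uris = [
    "drs://dg.4503/00000000-0000-0000-0000-000000000001",  # placeholder URI
    "drs://dg.4503/00000000-0000-0000-0000-000000000002",  # placeholder URI
]
# The destination may be a "gs://bucket/prefix" (no trailing slash) or a local directory.
drs.copy_batch(example_uris, "gs://my-workspace-bucket/batch-copies")
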
def test_copy(self):
    with self.subTest("test local"):
        with tempfile.NamedTemporaryFile() as tf:
            self._test_cmd(terra_notebook_utils.cli.commands.drs.drs_copy,
                           drs_url=self.drs_url,
                           dst=tf.name,
                           workspace=WORKSPACE_NAME,
                           workspace_namespace=WORKSPACE_NAMESPACE)
            with open(tf.name, "rb") as fh:
                data = fh.read()
            self.assertEqual(_crc32c(data), self.expected_crc32c)
    with self.subTest("test gs"):
        key = "test-drs-cli-object"
        self._test_cmd(terra_notebook_utils.cli.commands.drs.drs_copy,
                       drs_url=self.drs_url,
                       dst=f"gs://{WORKSPACE_BUCKET}/{key}",
                       workspace=WORKSPACE_NAME,
                       workspace_namespace=WORKSPACE_NAMESPACE)
        blob = gs.get_client().bucket(WORKSPACE_BUCKET).get_blob(key)
        out = io.BytesIO()
        blob.download_to_file(out)
        # download_to_file causes the crc32c to change, for some reason. Reload the blob to recover it.
        blob.reload()
        self.assertEqual(self.expected_crc32c, blob.crc32c)
        self.assertEqual(_crc32c(out.getvalue()), blob.crc32c)

def remove_workflow_logs(bucket_name=WORKSPACE_BUCKET, submission_id: str=None) -> List[str]:
    """
    Experimental: do not use
    """
    bucket = gs.get_client().bucket(bucket_name)

    def _is_workflow_log(blob):
        fname = blob.name.rsplit("/", 1)[-1]
        return fname.startswith("workflow") and fname.endswith(".log")

    def _delete_blob(blob):
        blob.delete()
        return blob.name

    prefixes = {blob.name.split("/", 1)[0] for blob in bucket.list_blobs() if _is_workflow_log(blob)}
    if submission_id is not None:
        prefixes = {pfx for pfx in prefixes if pfx == submission_id}
    deleted_manifest: List[str] = []  # accumulate deleted object names across all prefixes
    for pfx in prefixes:
        blobs_to_delete = [blob for blob in bucket.list_blobs(prefix=pfx)]
        print(f"Deleting {len(blobs_to_delete)} objects for {pfx}")
        with ThreadPoolExecutor(max_workers=8) as e:
            futures = [e.submit(_delete_blob, blob) for blob in blobs_to_delete]
            deleted_manifest.extend(f.result() for f in as_completed(futures))
    return deleted_manifest

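# Usage sketch (illustrative, not part of the library): remove_workflow_logs is marked
# experimental, so this only shows the call shape. The submission id is a hypothetical
# placeholder, and the function is assumed to be in scope.
deleted = remove_workflow_logs(submission_id="00000000-0000-0000-0000-000000000000")
print(f"Removed {len(deleted)} objects")
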
def test_extract(self):
    with self.subTest("Test tarball extraction to local filesystem"):
        with tempfile.TemporaryDirectory() as tempdir:
            with open("tests/fixtures/test_archive.tar.gz", "rb") as fh:
                tar_gz.extract(fh, root=tempdir)
            for filename in glob.glob("tests/fixtures/test_archive/*"):
                with open(filename) as a:
                    with open(os.path.join(f"{tempdir}/test_archive", os.path.basename(filename))) as b:
                        self.assertEqual(a.read(), b.read())
    with self.subTest("Test tarball extraction to GS bucket"):
        start_time = time.time()
        client = gs.get_client()
        bucket = client.bucket(os.environ['WORKSPACE_BUCKET'][5:])  # strip the "gs://" scheme
        key_pfx = f"untar_test/{uuid4()}"
        with open("tests/fixtures/test_archive.tar.gz", "rb") as fh:
            tar_gz.extract(fh, root=f"gs://{bucket.name}/{key_pfx}")
        for filename in glob.glob("tests/fixtures/test_archive/*"):
            key = f"{key_pfx}/test_archive/{os.path.basename(filename)}"
            blob = bucket.get_blob(key)
            self.assertIsNotNone(blob)
            age = (datetime.now(pytz.utc) - blob.time_created).total_seconds()
            self.assertGreater(time.time() - start_time, age)

def _get_native_bucket(bucket: Union[str, GSNativeBucket],
                       credentials: Optional[dict]=None,
                       billing_project: Optional[str]=None) -> GSNativeBucket:
    if isinstance(bucket, str):
        kwargs = dict()
        if billing_project is not None:
            kwargs['user_project'] = billing_project
        bucket = gcp.get_client(credentials, billing_project).bucket(bucket, **kwargs)
    return bucket

def remove_workflow_logs(self):
    bucket = gs.get_client().bucket(WORKSPACE_BUCKET)
    submission_ids = [f"{uuid4()}", f"{uuid4()}"]
    manifests = [self._upload_workflow_logs(bucket, submission_id)
                 for submission_id in submission_ids]
    for submission_id, manifest in zip(submission_ids, manifests):
        deleted_manifest = workspace.remove_workflow_logs(submission_id=submission_id)
        self.assertEqual(sorted(manifest), sorted(deleted_manifest))

def _get_blob(path: str, google_project: str) -> google.cloud.storage.blob.Blob:
    if path.startswith("gs://"):
        from terra_notebook_utils import gs
        path = path.split("gs://", 1)[1]
        bucket_name, key = path.split("/", 1)
        blob = gs.get_client(project=google_project).bucket(bucket_name).get_blob(key)
    elif path.startswith("drs://"):
        from terra_notebook_utils import drs
        client, info = drs.resolve_drs_for_gs_storage(path)
        blob = client.bucket(info.bucket_name, user_project=google_project).get_blob(info.key)
    else:
        blob = None
    return blob

def test_copy_batch_manifest(self):
    drs_uris = {
        "CCDG_13607_B01_GRM_WGS_2019-02-19_chr2.recalibrated_variants.annotated.clinical.txt": DRS_URI_003_MB,
        "CCDG_13607_B01_GRM_WGS_2019-02-19_chr9.recalibrated_variants.annotated.clinical.txt": DRS_URI_370_KB,
        "CCDG_13607_B01_GRM_WGS_2019-02-19_chr3.recalibrated_variants.annotated.clinical.txt": DRS_URI_500_KB,
    }
    named_drs_uris = {
        f"{uuid4()}": DRS_URI_003_MB,
        f"{uuid4()}": DRS_URI_370_KB,
        f"{uuid4()}": DRS_URI_500_KB,
    }
    pfx = f"test-batch-copy/{uuid4()}"
    bucket = gs.get_client().bucket(TNU_TEST_GS_BUCKET)
    with tempfile.TemporaryDirectory() as dirname:
        # Create a mixed manifest with local and cloud destinations
        manifest = [dict(drs_uri=uri, dst=f"gs://{os.environ['TNU_BLOBSTORE_TEST_GS_BUCKET']}/{pfx}/")
                    for uri in drs_uris.values()]
        manifest.extend([dict(drs_uri=uri, dst=f"gs://{os.environ['TNU_BLOBSTORE_TEST_GS_BUCKET']}/{pfx}/{name}")
                         for name, uri in named_drs_uris.items()])
        manifest.extend([dict(drs_uri=uri, dst=dirname) for uri in drs_uris.values()])
        manifest.extend([dict(drs_uri=uri, dst=os.path.join(dirname, name))
                         for name, uri in named_drs_uris.items()])
        drs.copy_batch_manifest(manifest)
        for name in dict(**drs_uris, **named_drs_uris):
            blob = bucket.get_blob(f"{pfx}/{name}")
            self.assertGreater(blob.size, 0)
        names = [os.path.basename(path) for path in _list_tree(dirname)]
        self.assertEqual(sorted(names), sorted(dict(**drs_uris, **named_drs_uris).keys()))

    with self.subTest("malformed manifest"):
        manifest = [dict(a="b"), dict(drs_uri="drs://foo", dst=".")]
        with self.assertRaises(jsonschema.exceptions.ValidationError):
            drs.copy_batch_manifest(manifest)

def _get_fileobj(uri: str):
    if uri.startswith("gs://"):
        bucket_name, key = uri[5:].split("/", 1)
        blob = gs.get_client().bucket(bucket_name, user_project=WORKSPACE_GOOGLE_PROJECT).get_blob(key)
        fh = gscio.Reader(blob, chunk_size=1024**2)
    elif uri.startswith("drs://"):
        gs_client, drs_info = drs.resolve_drs_for_gs_storage(uri)
        bucket = gs_client.bucket(drs_info.bucket_name, user_project=WORKSPACE_GOOGLE_PROJECT)
        fh = gscio.Reader(bucket.get_blob(drs_info.key), chunk_size=1024**2)
    else:
        fh = open(uri, "rb")
    return fh

def _upload_workflow_logs(self, bucket, submission_id):
    with open("tests/fixtures/workflow_logs_manifest.txt") as fh:
        manifest = [line.strip().format(submission_id=submission_id) for line in fh]
    with ThreadPoolExecutor(max_workers=8) as e:
        futures = [e.submit(bucket.blob(key).upload_from_file, io.BytesIO(b""))
                   for key in manifest]
        for f in as_completed(futures):
            f.result()
    return manifest

def __init__(self, bucket_name: str, key: str,
             credentials: Optional[dict]=None,
             billing_project: Optional[str]=None):
    super().__init__()
    kwargs = dict()
    if billing_project is not None:
        kwargs['user_project'] = billing_project
    bucket = gcp.get_client(credentials, billing_project).bucket(bucket_name, **kwargs)
    self._executor = ThreadPoolExecutor(max_workers=IO_CONCURRENCY)
    async_set = gscio.async_collections.AsyncSet(self._executor, IO_CONCURRENCY)
    self._part_uploader = gscio.AsyncPartUploader(key, bucket, async_set)
    self._part_number = 0

def test_extract_tar_gz(self):
    expected_data = (
        b"\x1f\x8b\x08\x04\x00\x00\x00\x00\x00\xff\x06\x00BC\x02\x00\x90 \xed]kO\\I\x92\xfd\xcc\xfc"
        b"\x8a\xd2\xb4V\xfb2\xd7\xf9~\xac\x97\x910\xb6i$\x1b\xbb\r\xdb=\xd3_\x10\x862\xae\x1d\x0c"
        b"\x0cU\xb8\xa7G\xfe\xf1{\xe2\xc6\xc9\xa2\xc0\xb8\xdb\xdd\xf2,_\xd2R\xc8\x87"
    )
    # This test uses a hack property, `_extract_single_chunk`, to extract a small amount
    # of data from the cohort vcf pointed to by `drs://dg.4503/da8cb525-4532-4d0f-90a3-4d327817ec73`.
    with mock.patch("terra_notebook_utils.tar_gz._extract_single_chunk", True):
        drs_url = "drs://dg.4503/da8cb525-4532-4d0f-90a3-4d327817ec73"  # cohort VCF tarball
        pfx = f"test_cohort_extract_{uuid4()}"
        drs.extract_tar_gz(drs_url, pfx)
        for key in gs.list_bucket(pfx):
            blob = gs.get_client().bucket(WORKSPACE_BUCKET).get_blob(key)
            data = blob.download_as_bytes()[:len(expected_data)]
            self.assertEqual(data, expected_data)

def extract_tar_gz(drs_url: str,
                   dst_pfx: Optional[str]=None,
                   dst_bucket_name: Optional[str]=None,
                   workspace_name: Optional[str]=WORKSPACE_NAME,
                   workspace_namespace: Optional[str]=WORKSPACE_GOOGLE_PROJECT):
    """
    Extract a `.tar.gz` archive resolved by a DRS url into a Google Storage bucket.
    If `dst_bucket_name` is None, extract into the workspace bucket.
    """
    if dst_bucket_name is None:
        dst_bucket_name = WORKSPACE_BUCKET
    enable_requester_pays(workspace_name, workspace_namespace)
    src_client, src_info = resolve_drs_for_gs_storage(drs_url)
    src_bucket = src_client.bucket(src_info.bucket_name, user_project=workspace_namespace)
    dst_bucket = gs.get_client().bucket(dst_bucket_name)
    with ThreadPoolExecutor(max_workers=IO_CONCURRENCY) as e:
        async_queue = async_collections.AsyncQueue(e, IO_CONCURRENCY)
        with gscio.Reader(src_bucket.get_blob(src_info.key), async_queue=async_queue) as fh:
            tar_gz.extract(fh, dst_bucket, root=dst_pfx)

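# Usage sketch (illustrative, not part of the library): streams the archive from the
# source bucket and writes the extracted members under `dst_pfx` in the workspace
# bucket. The DRS URI and prefix below are hypothetical placeholders.
from terra_notebook_utils import drs

drs.extract_tar_gz("drs://dg.4503/00000000-0000-0000-0000-0000000000aa",
                   dst_pfx="extracted/my-archive")
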
def vcf_info(uri: str, workspace_namespace: Optional[str]=WORKSPACE_GOOGLE_PROJECT) -> VCFInfo:
    if uri.startswith("drs://"):
        client, drs_info = drs.resolve_drs_for_gs_storage(uri)
        blob = client.bucket(drs_info.bucket_name, user_project=workspace_namespace).get_blob(drs_info.key)
        return VCFInfo.with_blob(blob)
    elif uri.startswith("gs://"):
        bucket, key = uri[5:].split("/", 1)
        blob = gs.get_client().bucket(bucket, user_project=workspace_namespace).get_blob(key)
        return VCFInfo.with_blob(blob)
    elif uri.startswith("s3://"):
        raise ValueError("S3 URIs not supported")
    else:
        return VCFInfo.with_file(uri)

def resolve_drs_for_gs_storage(drs_url: str) -> Tuple[gs.Client, DRSInfo]:
    """
    Attempt to resolve the gs:// url and credentials for a DRS object, then instantiate
    and return the Google Storage client along with the resolved DRS info.
    """
    assert drs_url.startswith("drs://")
    info = resolve_drs_info_for_gs_storage(drs_url)
    if info.credentials is not None:
        project_id = info.credentials['project_id']
    else:
        project_id = None
    client = gs.get_client(info.credentials, project=project_id)
    return client, info

def copy_to_bucket(drs_url: str, dst_key: str, dst_bucket_name: Optional[str]=None,
                   workspace_name: Optional[str]=WORKSPACE_NAME,
                   workspace_namespace: Optional[str]=WORKSPACE_GOOGLE_PROJECT):
    """
    Resolve `drs_url` and copy it into the bucket named by `dst_bucket_name`.
    If `dst_bucket_name` is None, copy into the workspace bucket.
    """
    assert drs_url.startswith("drs://")
    enable_requester_pays(workspace_name, workspace_namespace)
    if dst_bucket_name is None:
        dst_bucket_name = WORKSPACE_BUCKET
    src_client, src_info = resolve_drs_for_gs_storage(drs_url)
    if not dst_key:
        dst_key = src_info.name or src_info.key.rsplit("/", 1)[-1]
    dst_client = gs.get_client()
    src_bucket = src_client.bucket(src_info.bucket_name, user_project=workspace_namespace)
    dst_bucket = dst_client.bucket(dst_bucket_name)
    logger.info(f"Beginning to copy from {src_bucket} to {dst_bucket}. "
                "This can take a while for large files...")
    gs.copy(src_bucket, dst_bucket, src_info.key, dst_key)

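# Usage sketch (illustrative, not part of the library): copies a DRS object into a
# named bucket; an empty dst_key lets the destination key default to the object's
# name. The DRS URI and bucket name below are hypothetical placeholders.
from terra_notebook_utils import drs

drs.copy_to_bucket("drs://dg.4503/00000000-0000-0000-0000-0000000000bb",
                   dst_key="",
                   dst_bucket_name="my-destination-bucket")
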
def _gs_obj_exists(self, key: str) -> bool:
    return gs.get_client().bucket(TNU_TEST_GS_BUCKET).blob(key).exists()

def list_bucket(prefix="", bucket=WORKSPACE_BUCKET):
    for blob in gs.get_client().bucket(bucket).list_blobs(prefix=prefix):
        yield blob.name

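# Usage sketch (illustrative, not part of the library): prints every object key under
# a prefix in the workspace bucket. The prefix is a hypothetical placeholder, and
# list_bucket is assumed to be in scope (imported from its defining module).
for key in list_bucket(prefix="notebooks/"):
    print(key)
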
def _gs_obj_exists(self, key: str) -> bool:
    return gs.get_client().bucket(WORKSPACE_BUCKET).blob(key).exists()

def test_get_client(self):
    gs.get_client()