def test_copy_batch(self):
        drs_urls = {
            # name property disappeared from DRS response :(
            # "NWD522743.b38.irc.v1.cram.crai": "drs://dg.4503/95cc4ae1-dee7-4266-8b97-77cf46d83d35",  # 1631686 bytes
            "NWD522743.b38.irc.v1.cram.crai":
            "drs://dg.4503/95cc4ae1-dee7-4266-8b97-77cf46d83d35",
            "data_phs001237.v2.p1.c1.avro.gz":
            "drs://dg.4503/26e11149-5deb-4cd7-a475-16997a825655",  # 1115092 bytes
            "RootStudyConsentSet_phs001237.TOPMed_WGS_WHI.v2.p1.c1.HMB-IRB.tar.gz":
            "drs://dg.4503/e9c2caf2-b2a1-446d-92eb-8d5389e99ee3",  # 332237 bytes

            # "NWD961306.freeze5.v1.vcf.gz": "drs://dg.4503/6e73a376-f7fd-47ed-ac99-0567bb5a5993",  # 2679331445 bytes
            # "NWD531899.freeze5.v1.vcf.gz": "drs://dg.4503/651a4ad1-06b5-4534-bb2c-1f8ed51134f6",  # 2679411265 bytes
        }
        pfx = f"test-batch-copy/{uuid4()}"
        bucket = gs.get_client().bucket(WORKSPACE_BUCKET)
        with self.subTest("gs bucket"):
            with mock.patch("terra_notebook_utils.drs.MULTIPART_THRESHOLD",
                            400000):
                drs.copy_batch(
                    list(drs_urls.values()),
                    f"gs://fc-9169fcd1-92ce-4d60-9d2d-d19fd326ff10/{pfx}")
                for name in list(drs_urls.keys()):
                    blob = bucket.get_blob(f"{pfx}/{name}")
                    self.assertGreater(blob.size, 0)
        with self.subTest("local filesystem"):
            with tempfile.TemporaryDirectory() as dirname:
                drs.copy_batch(list(drs_urls.values()), dirname)
                names = [
                    os.path.basename(path) for path in _list_tree(dirname)
                ]
                self.assertEqual(sorted(names), sorted(list(drs_urls.keys())))
Example #2
def copy_batch(drs_urls: Iterable[str],
               dst: str,
               workspace_name: Optional[str]=WORKSPACE_NAME,
               workspace_namespace: Optional[str]=WORKSPACE_GOOGLE_PROJECT):
    """
    Resolve each DRS url in `drs_urls` and copy it into `dst`, which may be either a
    `gs://` prefix or a local directory. Objects at or below MULTIPART_THRESHOLD are
    copied on a oneshot thread pool; larger objects use a multipart copy.
    """
    enable_requester_pays(workspace_name, workspace_namespace)
    with ThreadPoolExecutor(max_workers=IO_CONCURRENCY) as oneshot_executor:
        oneshot_pool = async_collections.AsyncSet(oneshot_executor, concurrency=IO_CONCURRENCY)
        for drs_url in drs_urls:
            assert drs_url.startswith("drs://")
            src_client, src_info = resolve_drs_for_gs_storage(drs_url)
            src_bucket = src_client.bucket(src_info.bucket_name, user_project=workspace_namespace)
            src_blob = src_bucket.get_blob(src_info.key)
            basename = src_info.name or src_info.key.rsplit("/", 1)[-1]
            if dst.startswith("gs://"):
                if dst.endswith("/"):
                    raise ValueError("Bucket destination cannot end with '/'")
                dst_bucket_name, dst_pfx = _bucket_name_and_key(dst)
                dst_bucket = gs.get_client().bucket(dst_bucket_name)
                dst_key = f"{dst_pfx}/{basename}"
                if MULTIPART_THRESHOLD >= src_blob.size:
                    oneshot_pool.put(gs.oneshot_copy, src_bucket, dst_bucket, src_info.key, dst_key)
                else:
                    gs.multipart_copy(src_bucket, dst_bucket, src_info.key, dst_key)
            else:
                oneshot_pool.put(copy_to_local, drs_url, dst, workspace_name, workspace_namespace)
        for _ in oneshot_pool.consume():
            pass
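# A minimal usage sketch for copy_batch, mirroring how the test above calls it; the DRS
# URIs and destination prefix below are placeholders, not real objects.
from terra_notebook_utils import drs

example_drs_urls = [
    "drs://dg.4503/<object-id-1>",
    "drs://dg.4503/<object-id-2>",
]
# dst may be a gs:// prefix (must not end with "/") or a local directory.
drs.copy_batch(example_drs_urls, "gs://my-workspace-bucket/batch-copy")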
    def test_copy(self):
        with self.subTest("test local"):
            with tempfile.NamedTemporaryFile() as tf:
                self._test_cmd(terra_notebook_utils.cli.commands.drs.drs_copy,
                               drs_url=self.drs_url,
                               dst=tf.name,
                               workspace=WORKSPACE_NAME,
                               workspace_namespace=WORKSPACE_NAMESPACE)
                with open(tf.name, "rb") as fh:
                    data = fh.read()
                self.assertEqual(_crc32c(data), self.expected_crc32c)

        with self.subTest("test gs"):
            key = "test-drs-cli-object"
            self._test_cmd(terra_notebook_utils.cli.commands.drs.drs_copy,
                           drs_url=self.drs_url,
                           dst=f"gs://{WORKSPACE_BUCKET}/{key}",
                           workspace=WORKSPACE_NAME,
                           workspace_namespace=WORKSPACE_NAMESPACE)
            blob = gs.get_client().bucket(WORKSPACE_BUCKET).get_blob(key)
            out = io.BytesIO()
            blob.download_to_file(out)
            # download_to_file causes the crc32c to change, for some reason; reload the blob to recover.
            blob.reload()
            self.assertEqual(self.expected_crc32c, blob.crc32c)
            self.assertEqual(_crc32c(out.getvalue()), blob.crc32c)
Example #4
def remove_workflow_logs(bucket_name=WORKSPACE_BUCKET,
                         submission_id: str = None) -> List[str]:
    """
    Experimental: do not use
    """
    bucket = gs.get_client().bucket(bucket_name)

    def _is_workflow_log(blob):
        fname = blob.name.rsplit("/", 1)[-1]
        return fname.startswith("workflow") and fname.endswith(".log")

    def _delete_blob(blob):
        blob.delete()
        return blob.name

    prefixes = {
        blob.name.split("/", 1)[0]
        for blob in bucket.list_blobs() if _is_workflow_log(blob)
    }
    if submission_id is not None:
        prefixes = {pfx for pfx in prefixes if pfx == submission_id}
    deleted_manifest: List[str] = []
    for pfx in prefixes:
        blobs_to_delete = list(bucket.list_blobs(prefix=pfx))
        print(f"Deleting {len(blobs_to_delete)} objects for {pfx}")
        with ThreadPoolExecutor(max_workers=8) as e:
            futures = [e.submit(_delete_blob, blob) for blob in blobs_to_delete]
            deleted_manifest.extend(f.result() for f in as_completed(futures))

    return deleted_manifest
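# A hedged sketch of invoking the experimental remove_workflow_logs helper from the
# workspace module, as the test further below does; the submission id is a placeholder.
from terra_notebook_utils import workspace

deleted = workspace.remove_workflow_logs(submission_id="<submission-id>")
print(f"Removed {len(deleted)} workflow log objects")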
 def test_extract(self):
     with self.subTest("Test tarball extraction to local filesystem"):
         with tempfile.TemporaryDirectory() as tempdir:
             with open("tests/fixtures/test_archive.tar.gz", "rb") as fh:
                 tar_gz.extract(fh, root=tempdir)
             for filename in glob.glob("tests/fixtures/test_archive/*"):
                 with open(filename) as a:
                     with open(
                             os.path.join(f"{tempdir}/test_archive",
                                          os.path.basename(filename))) as b:
                         self.assertEqual(a.read(), b.read())
     with self.subTest("Test tarball extraction to GS bucket"):
         start_time = time.time()
         client = gs.get_client()
         bucket = client.bucket(os.environ['WORKSPACE_BUCKET'][5:])
         key_pfx = "untar_test/{uuid4()}"
         with open("tests/fixtures/test_archive.tar.gz", "rb") as fh:
             tar_gz.extract(fh, root=f"gs://{bucket.name}/{key_pfx}")
         for filename in glob.glob("tests/fixtures/test_archive/*"):
             key = f"{key_pfx}/test_archive/{os.path.basename(filename)}"
             blob = bucket.get_blob(key)
             self.assertIsNotNone(blob)
             age = (datetime.now(pytz.utc) -
                    blob.time_created).total_seconds()
             self.assertGreater(time.time() - start_time, age)
def _get_native_bucket(
        bucket: Union[str, GSNativeBucket],
        credentials: Optional[dict] = None,
        billing_project: Optional[str] = None) -> GSNativeBucket:
    if isinstance(bucket, str):
        kwargs = dict()
        if billing_project is not None:
            kwargs['user_project'] = billing_project
        bucket = gcp.get_client(credentials,
                                billing_project).bucket(bucket, **kwargs)
    return bucket
 def test_remove_workflow_logs(self):
     bucket = gs.get_client().bucket(WORKSPACE_BUCKET)
     submission_ids = [f"{uuid4()}", f"{uuid4()}"]
     manifests = [
         self._upload_workflow_logs(bucket, submission_id)
         for submission_id in submission_ids
     ]
     for submission_id, manifest in zip(submission_ids, manifests):
         deleted_manifest = workspace.remove_workflow_logs(
             submission_id=submission_id)
         self.assertEqual(sorted(manifest), sorted(deleted_manifest))
Example #8
def _get_blob(path: str, google_project: str) -> google.cloud.storage.Blob:
    if path.startswith("gs://"):
        from terra_notebook_utils import gs
        path = path.split("gs://", 1)[1]
        bucket_name, key = path.split("/", 1)
        blob = gs.get_client(project=google_project).bucket(bucket_name).get_blob(key)
    elif path.startswith("drs://"):
        from terra_notebook_utils import drs
        client, info = drs.resolve_drs_for_gs_storage(path)
        blob = client.bucket(info.bucket_name, user_project=google_project).get_blob(info.key)
    else:
        blob = None
    return blob
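# A small sketch of how _get_blob might be used; the path and billing project are
# placeholders. Paths that are neither gs:// nor drs:// yield None.
blob = _get_blob("gs://my-bucket/path/to/object", google_project="my-billing-project")
if blob is not None:
    print(blob.name, blob.size)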
    def test_copy_batch_manifest(self):
        drs_uris = {
            "CCDG_13607_B01_GRM_WGS_2019-02-19_chr2.recalibrated_variants.annotated.clinical.txt":
            DRS_URI_003_MB,
            "CCDG_13607_B01_GRM_WGS_2019-02-19_chr9.recalibrated_variants.annotated.clinical.txt":
            DRS_URI_370_KB,
            "CCDG_13607_B01_GRM_WGS_2019-02-19_chr3.recalibrated_variants.annotated.clinical.txt":
            DRS_URI_500_KB,
        }
        named_drs_uris = {
            f"{uuid4()}": DRS_URI_003_MB,
            f"{uuid4()}": DRS_URI_370_KB,
            f"{uuid4()}": DRS_URI_500_KB,
        }
        pfx = f"test-batch-copy/{uuid4()}"
        bucket = gs.get_client().bucket(TNU_TEST_GS_BUCKET)
        with tempfile.TemporaryDirectory() as dirname:
            # create a mixed manifest with local and cloud destinations
            manifest = [
                dict(drs_uri=uri,
                     dst=f"gs://{os.environ['TNU_BLOBSTORE_TEST_GS_BUCKET']}/{pfx}/")
                for uri in drs_uris.values()
            ]
            manifest.extend([
                dict(drs_uri=uri,
                     dst=f"gs://{os.environ['TNU_BLOBSTORE_TEST_GS_BUCKET']}/{pfx}/{name}")
                for name, uri in named_drs_uris.items()
            ])
            manifest.extend(
                [dict(drs_uri=uri, dst=dirname) for uri in drs_uris.values()])
            manifest.extend([
                dict(drs_uri=uri, dst=os.path.join(dirname, name))
                for name, uri in named_drs_uris.items()
            ])
            drs.copy_batch_manifest(manifest)
            for name in dict(**drs_uris, **named_drs_uris):
                blob = bucket.get_blob(f"{pfx}/{name}")
                self.assertGreater(blob.size, 0)
            names = [os.path.basename(path) for path in _list_tree(dirname)]
            self.assertEqual(
                sorted(names),
                sorted(list(dict(**drs_uris, **named_drs_uris).keys())))

        with self.subTest("malformed manifest"):
            manifest = [dict(a="b"), dict(drs_uri="drs://foo", dst=".")]
            with self.assertRaises(jsonschema.exceptions.ValidationError):
                drs.copy_batch_manifest(manifest)
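# The manifest accepted by drs.copy_batch_manifest, as exercised above, is a list of
# mappings with "drs_uri" and "dst" keys; a minimal sketch with placeholder values.
example_manifest = [
    dict(drs_uri="drs://dg.4503/<object-id>", dst="gs://my-bucket/some-prefix/"),
    dict(drs_uri="drs://dg.4503/<other-id>", dst="/tmp/downloads"),
]
drs.copy_batch_manifest(example_manifest)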
def _get_fileobj(uri: str):
    if uri.startswith("gs://"):
        bucket_name, key = uri[5:].split("/", 1)
        blob = gs.get_client().bucket(
            bucket_name, user_project=WORKSPACE_GOOGLE_PROJECT).get_blob(key)
        fh = gscio.Reader(blob, chunk_size=1024**2)
    elif uri.startswith("drs://"):
        gs_client, drs_info = drs.resolve_drs_for_gs_storage(uri)
        bucket = gs_client.bucket(drs_info.bucket_name,
                                  user_project=WORKSPACE_GOOGLE_PROJECT)
        fh = gscio.Reader(bucket.get_blob(drs_info.key), chunk_size=1024**2)
    else:
        fh = open(uri, "rb")
    return fh
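# _get_fileobj returns a readable file-like handle for gs://, drs://, or local paths;
# a short sketch with a placeholder URI.
with _get_fileobj("gs://my-bucket/path/to/object") as fh:
    header = fh.read(1024)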
 def _upload_workflow_logs(self, bucket, submission_id):
     with open("tests/fixtures/workflow_logs_manifest.txt") as fh:
         manifest = [
             line.strip().format(submission_id=submission_id) for line in fh
         ]
     with ThreadPoolExecutor(max_workers=8) as e:
         futures = [
             e.submit(bucket.blob(key).upload_from_file, io.BytesIO(b""))
             for key in manifest
         ]
         for f in as_completed(futures):
             f.result()
     return manifest
 def __init__(self,
              bucket_name: str,
              key: str,
              credentials: Optional[dict] = None,
              billing_project: Optional[str] = None):
     super().__init__()
     kwargs = dict()
     if billing_project is not None:
         kwargs['user_project'] = billing_project
     bucket = gcp.get_client(credentials,
                             billing_project).bucket(bucket_name, **kwargs)
     self._executor = ThreadPoolExecutor(max_workers=IO_CONCURRENCY)
     async_set = gscio.async_collections.AsyncSet(self._executor,
                                                  IO_CONCURRENCY)
     self._part_uploader = gscio.AsyncPartUploader(key, bucket, async_set)
     self._part_number = 0
 def test_extract_tar_gz(self):
     expected_data = (
         b"\x1f\x8b\x08\x04\x00\x00\x00\x00\x00\xff\x06\x00BC\x02\x00\x90 \xed]kO\\I\x92\xfd\xcc\xfc"
         b"\x8a\xd2\xb4V\xfb2\xd7\xf9~\xac\x97\x910\xb6i$\x1b\xbb\r\xdb=\xd3_\x10\x862\xae\x1d\x0c"
         b"\x0cU\xb8\xa7G\xfe\xf1{\xe2\xc6\xc9\xa2\xc0\xb8\xdb\xdd\xf2,_\xd2R\xc8\x87"
     )
     # This test uses a hack property, `_extract_single_chunk`, to extract a small amount
     # of data from the cohort vcf pointed to by `drs://dg.4503/da8cb525-4532-4d0f-90a3-4d327817ec73`.
     with mock.patch("terra_notebook_utils.tar_gz._extract_single_chunk",
                     True):
         drs_url = "drs://dg.4503/da8cb525-4532-4d0f-90a3-4d327817ec73"  # cohort VCF tarball
         pfx = f"test_cohort_extract_{uuid4()}"
         drs.extract_tar_gz(drs_url, pfx)
         for key in gs.list_bucket(pfx):
             blob = gs.get_client().bucket(WORKSPACE_BUCKET).get_blob(key)
             data = blob.download_as_bytes()[:len(expected_data)]
             self.assertEqual(data, expected_data)
Example #14
def extract_tar_gz(drs_url: str,
                   dst_pfx: Optional[str]=None,
                   dst_bucket_name: Optional[str]=None,
                   workspace_name: Optional[str]=WORKSPACE_NAME,
                   workspace_namespace: Optional[str]=WORKSPACE_GOOGLE_PROJECT):
    """
    Extract a `.tar.gz` archive resolved by a DRS url into a Google Storage bucket.
    """
    if dst_bucket_name is None:
        dst_bucket_name = WORKSPACE_BUCKET
    enable_requester_pays(workspace_name, workspace_namespace)
    src_client, src_info = resolve_drs_for_gs_storage(drs_url)
    src_bucket = src_client.bucket(src_info.bucket_name, user_project=workspace_namespace)
    dst_bucket = gs.get_client().bucket(dst_bucket_name)
    with ThreadPoolExecutor(max_workers=IO_CONCURRENCY) as e:
        async_queue = async_collections.AsyncQueue(e, IO_CONCURRENCY)
        with gscio.Reader(src_bucket.get_blob(src_info.key), async_queue=async_queue) as fh:
            tar_gz.extract(fh, dst_bucket, root=dst_pfx)
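# A hedged usage sketch for extract_tar_gz; the DRS URI and prefix are placeholders.
# With dst_bucket_name omitted, the archive is unpacked into the workspace bucket.
from terra_notebook_utils import drs

drs.extract_tar_gz("drs://dg.4503/<archive-id>", dst_pfx="extracted/my-archive")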
Example #15
def vcf_info(
        uri: str,
        workspace_namespace: Optional[str] = WORKSPACE_GOOGLE_PROJECT
) -> VCFInfo:
    if uri.startswith("drs://"):
        client, drs_info = drs.resolve_drs_for_gs_storage(uri)
        blob = client.bucket(drs_info.bucket_name,
                             user_project=workspace_namespace).get_blob(
                                 drs_info.key)
        return VCFInfo.with_blob(blob)
    elif uri.startswith("gs://"):
        bucket, key = uri[5:].split("/", 1)
        blob = gs.get_client().bucket(
            bucket, user_project=workspace_namespace).get_blob(key)
        return VCFInfo.with_blob(blob)
    elif uri.startswith("s3://"):
        raise ValueError("S3 URIs not supported")
    else:
        return VCFInfo.with_file(uri)
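# vcf_info accepts drs://, gs://, or local paths (s3:// raises ValueError); a minimal
# sketch with a placeholder URI.
info = vcf_info("gs://my-bucket/cohort.vcf.gz")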
Example #16
def resolve_drs_for_gs_storage(drs_url: str) -> Tuple[gs.Client, DRSInfo]:
    """
    Attempt to resolve gs:// url and credentials for a DRS object. Instantiate and return the Google Storage client.
    """
    assert drs_url.startswith("drs://")

    info = resolve_drs_info_for_gs_storage(drs_url)

    if info.credentials is not None:
        project_id = info.credentials['project_id']
    else:
        project_id = None

    client = gs.get_client(info.credentials, project=project_id)
    return client, info
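# A small sketch of resolving a DRS URI to a storage client plus object info; the URI is
# a placeholder. `info` carries bucket_name, key, name, and credentials as used by the
# callers above.
client, info = resolve_drs_for_gs_storage("drs://dg.4503/<object-id>")
bucket = client.bucket(info.bucket_name, user_project=WORKSPACE_GOOGLE_PROJECT)
blob = bucket.get_blob(info.key)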
Example #17
def copy_to_bucket(drs_url: str,
                   dst_key: str,
                   dst_bucket_name: Optional[str]=None,
                   workspace_name: Optional[str]=WORKSPACE_NAME,
                   workspace_namespace: Optional[str]=WORKSPACE_GOOGLE_PROJECT):
    """
    Resolve `drs_url` and copy it into the user-specified bucket `dst_bucket_name`.
    If `dst_bucket_name` is None, copy into the workspace bucket. If `dst_key` is
    empty, the object's own name (or the final component of its key) is used.
    """
    assert drs_url.startswith("drs://")
    enable_requester_pays(workspace_name, workspace_namespace)
    if dst_bucket_name is None:
        dst_bucket_name = WORKSPACE_BUCKET
    src_client, src_info = resolve_drs_for_gs_storage(drs_url)
    if not dst_key:
        dst_key = src_info.name or src_info.key.rsplit("/", 1)[-1]
    dst_client = gs.get_client()
    src_bucket = src_client.bucket(src_info.bucket_name, user_project=workspace_namespace)
    dst_bucket = dst_client.bucket(dst_bucket_name)
    logger.info(f"Beginning to copy from {src_bucket} to {dst_bucket}. This can take a while for large files...")
    gs.copy(src_bucket, dst_bucket, src_info.key, dst_key)
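# A hedged sketch of copying a single DRS object; the URI is a placeholder. With the
# default dst_bucket_name the workspace bucket is used, and an empty dst_key falls back
# to the object's own name.
copy_to_bucket("drs://dg.4503/<object-id>", dst_key="")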
 def _gs_obj_exists(self, key: str) -> bool:
     return gs.get_client().bucket(TNU_TEST_GS_BUCKET).blob(key).exists()
def list_bucket(prefix="", bucket=WORKSPACE_BUCKET):
    for blob in gs.get_client().bucket(bucket).list_blobs(prefix=prefix):
        yield blob.name
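# list_bucket yields object keys under a prefix, defaulting to the workspace bucket;
# the prefix here is a placeholder.
for key in list_bucket("test-batch-copy/"):
    print(key)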
 def _gs_obj_exists(self, key: str) -> bool:
     return gs.get_client().bucket(WORKSPACE_BUCKET).blob(key).exists()
 def test_get_client(self):
     gs.get_client()