示例#1
0
async def test_object_life_cycle(bucket_name, creds, uploaded_data,
                                 expected_data, file_extension):
    """Round-trip one object through upload, download, copy and delete.

    The object is uploaded under a random name, read back both through a
    ``Bucket``/``Blob`` handle and through ``Storage.download``, copied to a
    second name in the same bucket, and finally both names are deleted.
    Downloading either name afterwards must raise ``ResponseError``.
    """
    source_name = f'{uuid.uuid4().hex}/{uuid.uuid4().hex}.{file_extension}'
    copy_name = f'copyof_{source_name}'

    async with Session() as session:
        storage = Storage(service_file=creds, session=session)
        await storage.upload(bucket_name, source_name, uploaded_data)

        # Read back through the object-style API ...
        blob = await storage.get_bucket(bucket_name).get_blob(source_name)
        via_blob = await blob.download()
        assert via_blob == expected_data

        # ... and through the flat Storage API.
        via_storage = await storage.download(bucket_name, source_name)
        assert via_storage == expected_data

        await storage.copy(bucket_name, source_name, bucket_name,
                           new_name=copy_name)

        via_storage = await storage.download(bucket_name, copy_name)
        assert via_storage == expected_data

        # Remove both objects first, then confirm each one is really gone.
        for name in (source_name, copy_name):
            await storage.delete(bucket_name, name)

        for name in (source_name, copy_name):
            with pytest.raises(ResponseError):
                await storage.download(bucket_name, name)
示例#2
0
 async def initialize(self) -> None:
     """Load the card database JSON, parse it, and flag initialization done.

     When ``FLAGS.carddb_local_file`` is set the JSON text is read from that
     local path; otherwise ``CARDDB_DB_FILE`` is downloaded from the
     ``CARDDB_BUCKET`` bucket. The loaded payload is passed to
     ``self._parse_db_json`` and then ``self._is_initialized.set()`` is
     called.
     """
     logging.vlog(1, 'Initializing CardDb.')
     local_path = FLAGS.carddb_local_file
     if not local_path:
         # No local override configured: fetch the database from the bucket.
         logging.info('Initializing CardDb from cloud file: %s/%s',
                      CARDDB_BUCKET, CARDDB_DB_FILE)
         storage = Storage()
         cloud_blob = await storage.get_bucket(CARDDB_BUCKET).get_blob(
             CARDDB_DB_FILE)
         db_json = await cloud_blob.download()
         logging.info('Loaded cloud file.')
     else:
         logging.info('Initializing CardDb from local file: %s', local_path)
         with open(local_path, 'r') as db_file:
             db_json = db_file.read()
     await self._parse_db_json(db_json)
     self._is_initialized.set()
示例#3
0
async def main(loop, balrog_api, bucket_name, limit_to, concurrency):
    """Process every release listed by the Balrog API and print a summary.

    The release list is fetched from ``{balrog_api}/releases``. One
    ``process_release`` task is created per release that does not match a
    skip pattern, and the per-release counts the tasks return are merged
    and reported on stdout.

    Args:
        loop: Event loop used for the HTTP session and to create the tasks.
        balrog_api: Base URL of the Balrog API.
        bucket_name: Name of the bucket handed to ``Storage.get_bucket``.
        limit_to: Falsy for no limit; otherwise the maximum number of
            listed releases to examine. Releases that are later skipped by
            a skip pattern still count towards the limit.
        concurrency: Initial value of the semaphore passed to every
            ``process_release`` call.
    """
    # limit the number of connections at any one time
    sem = asyncio.Semaphore(concurrency)
    releases = defaultdict(int)
    uploads = defaultdict(lambda: defaultdict(int))
    tasks = []

    n = 0

    async with aiohttp.ClientSession(loop=loop) as session:
        storage = Storage(session=session)
        bucket = storage.get_bucket(bucket_name)

        response = await session.get("{}/releases".format(balrog_api))
        to_process = (await response.json())["releases"]
        for r in to_process:
            release_name = r["name"]

            if limit_to and n >= limit_to:
                break

            n += 1

            if any(pat in release_name for pat in skip_patterns):
                print("Skipping {} because it matches a skip pattern".format(release_name), flush=True)
                continue

            tasks.append(loop.create_task(process_release(release_name, session, balrog_api, bucket, sem, loop)))

        # The tasks are already bound to `loop`, so gather() needs no explicit
        # loop argument; that keyword was removed in Python 3.10.
        for processed_releases, processed_uploads in await asyncio.gather(*tasks):
            for rel in processed_releases:
                releases[rel] += processed_releases[rel]
            for u in processed_uploads:
                uploads[u]["uploaded"] += processed_uploads[u]["uploaded"]
                uploads[u]["existing"] += processed_uploads[u]["existing"]

    for r in releases:
        # `uploads` is a defaultdict: indexing it with a missing key inserts
        # that key, which would make the "does not exist in GCS" check below
        # always false. Test membership first and read the counts with .get().
        in_gcs = r in uploads
        counts = uploads.get(r, {})
        existing = counts.get("existing", 0)
        uploaded = counts.get("uploaded", 0)
        revs_in_gcs = uploaded + existing
        print("INFO: {}: Found {} existing revisions, uploaded {} new ones".format(r, existing, uploaded))
        if not in_gcs:
            print("WARNING: {} was found in the Balrog API but does not exist in GCS".format(r))
        elif releases[r] != revs_in_gcs:
            print("WARNING: {} has a data version of {} in the Balrog API, but {} revisions exist in GCS".format(r, releases[r], revs_in_gcs))
示例#4
0
class GCSLogFilemanager(LogFileManager):
    """Log file manager that stores gzip-compressed logs in a GCS bucket.

    Objects are named ``<pkg>/<run_id>/<name>.gz`` inside the bucket whose
    name is the host part of the ``location`` URL.
    """

    def __init__(self, location, creds_path=None, trace_configs=None):
        """Create the HTTP session and the storage/bucket handles.

        Args:
            location: URL whose host component is used as the bucket name.
            creds_path: Passed to ``Storage`` as ``service_file``.
            trace_configs: Passed through to ``ClientSession``.
        """
        from gcloud.aio.storage import Storage

        self.bucket_name = URL(location).host
        self.session = ClientSession(trace_configs=trace_configs)
        self.storage = Storage(service_file=creds_path, session=self.session)
        self.bucket = self.storage.get_bucket(self.bucket_name)

    def _get_object_name(self, pkg, run_id, name):
        """Return the object name under which a log is stored."""
        parts = (pkg, run_id, name)
        return "%s/%s/%s.gz" % parts

    async def has_log(self, pkg, run_id, name):
        """Return the result of ``blob_exists`` for the log's object name."""
        return await self.bucket.blob_exists(
            self._get_object_name(pkg, run_id, name), self.session
        )

    async def get_log(self, pkg, run_id, name, timeout=30):
        """Download a log and return its decompressed content as ``BytesIO``.

        Raises:
            FileNotFoundError: the download failed with HTTP status 404.
            ServiceUnavailable: the download failed with any other HTTP
                status, or the server disconnected.
        """
        key = self._get_object_name(pkg, run_id, name)
        try:
            compressed = await self.storage.download(
                self.bucket_name, key, session=self.session, timeout=timeout
            )
        except ClientResponseError as e:
            # Only 404 is reported as "missing"; every other status is
            # treated as the service being unavailable.
            if e.status != 404:
                raise ServiceUnavailable()
            raise FileNotFoundError(name)
        except ServerDisconnectedError:
            raise ServiceUnavailable()
        else:
            return BytesIO(gzip.decompress(compressed))

    async def import_log(self, pkg, run_id, orig_path, timeout=360):
        """Gzip the file at ``orig_path`` and upload it for the given run.

        The object name is derived from the base name of ``orig_path``.

        Raises:
            ServiceUnavailable: the upload failed with HTTP status 503.
            ClientResponseError: the upload failed with any other status.
        """
        key = self._get_object_name(pkg, run_id, os.path.basename(orig_path))
        with open(orig_path, "rb") as src:
            payload = gzip.compress(src.read())
        try:
            await self.storage.upload(
                self.bucket_name, key, payload, timeout=timeout
            )
        except ClientResponseError as e:
            if e.status != 503:
                raise
            raise ServiceUnavailable()
async def test_object_life_cycle(bucket_name, creds, project, uploaded_data,
                                 expected_data, file_extension):
    """Upload an object, read it back two ways, delete it, confirm it's gone.

    The content is fetched once through a ``Bucket``/``Blob`` handle and
    once through ``Storage.download``; both must equal ``expected_data``.
    After deletion a download must raise ``ClientResponseError``.
    """
    blob_path = f'{uuid.uuid4().hex}/{uuid.uuid4().hex}.{file_extension}'

    async with aiohttp.ClientSession() as session:
        storage = Storage(project, creds, session=session)
        await storage.upload(bucket_name, blob_path, uploaded_data)

        # Object-style access via the bucket handle.
        blob = await storage.get_bucket(bucket_name).get_blob(blob_path)
        via_blob = await blob.download()
        assert via_blob == expected_data

        # Flat access straight from the storage client.
        via_storage = await storage.download(bucket_name, blob_path)
        assert via_storage == expected_data

        await storage.delete(bucket_name, blob_path)

        with pytest.raises(aiohttp.client_exceptions.ClientResponseError):
            await storage.download(bucket_name, blob_path)
async def doaj_trio(request):
    """Score the similarity between user text and a stored journal blob.

    ``request.data`` is decoded and parsed as JSON. The handler reads three
    keys from it: ``"t"`` (compared against ``settings.token``), ``"f"``
    (name of the blob to download from ``bucket_name``) and ``"d"`` (the
    user text).

    Returns:
        The similarity as a string on success, a 403 ``Response`` when the
        token does not match, or a 500 ``Response`` when any step raises.
    """
    try:
        data = json.loads(request.data.decode())
        if data["t"] != settings.token:
            return Response("Forbidden", status=403, mimetype="text/plain")

        async with Session() as session:
            storage = Storage(session=session)
            bucket = storage.get_bucket(bucket_name)
            blob = data["f"]
            print(blob)
            blob_object = await bucket.get_blob(blob)
            raw_data = await blob_object.download()

        # NOTE(review): download() presumably returns bytes - confirm against
        # the installed gcloud-aio-storage version. str() on bytes yields the
        # "b'...'" repr with escape sequences, not the text, so decode bytes
        # explicitly and keep str() only for non-bytes payloads.
        if isinstance(raw_data, (bytes, bytearray)):
            journal_text = raw_data.decode("utf-8", errors="replace")
        else:
            journal_text = str(raw_data)

        journal_nlp = nlp(journal_text[:100000])
        user_nlp = nlp(data["d"])
        sim = user_nlp.similarity(journal_nlp)
        return str(sim)
    except Exception:
        # Deliberately broad at this request boundary, but not a bare
        # `except:` - task cancellation and interpreter-exit signals must
        # propagate instead of being turned into a 500 response.
        return Response("Error", status=500, mimetype="text/plain")
示例#7
0
async def test_object_life_cycle(uploaded_data, expected_data):
    """Upload, read back as a string two ways, delete, and confirm removal.

    Runs against the ``talkiq-integration-test`` bucket using a random
    object name ending in ``.txt``. After deletion a further download must
    raise ``ClientResponseError``.
    """
    bucket_name = 'talkiq-integration-test'
    blob_path = f'{uuid.uuid4().hex}/{uuid.uuid4().hex}.txt'

    async with aiohttp.ClientSession() as session:
        storage = Storage(PROJECT, CREDS, session=session)
        await storage.upload(bucket_name, blob_path, uploaded_data)

        # Object-style access via the bucket handle.
        blob = await storage.get_bucket(bucket_name).get_blob(blob_path)
        constructed_result = await blob.download_as_string()
        assert constructed_result == expected_data

        # Flat access straight from the storage client.
        direct_result = await storage.download_as_string(bucket_name,
                                                         blob_path)
        assert direct_result == expected_data

        await storage.delete(bucket_name, blob_path)

        with pytest.raises(aiohttp.client_exceptions.ClientResponseError):
            await storage.download_as_string(bucket_name, blob_path)
示例#8
0
async def main(loop, balrog_api, bucket_name, limit_to, concurrency,
               skip_toplevel_keys, whitelist):
    """Process the releases listed by the Balrog API and print a summary.

    The release list is fetched from ``{balrog_api}/releases``. One
    ``process_release`` task is created per release that survives the
    filters below, and the per-release counts the tasks return are merged
    and reported on stdout.

    Args:
        loop: Event loop used for the HTTP session and to create the tasks.
        balrog_api: Base URL of the Balrog API.
        bucket_name: Name of the bucket to inspect and upload into.
        limit_to: Falsy for no limit; otherwise the maximum number of
            releases to examine. Releases dropped by the top-level-key or
            whitelist filter do not count; releases dropped by a skip
            pattern do.
        concurrency: Initial value of the semaphore passed to every
            ``process_release`` call.
        skip_toplevel_keys: When truthy, releases whose name already exists
            as a top-level prefix in the bucket are skipped.
        whitelist: When truthy, only releases whose name is contained in it
            are processed.
    """
    # limit the number of connections at any one time
    sem = asyncio.Semaphore(concurrency)
    releases = defaultdict(int)
    uploads = defaultdict(lambda: defaultdict(int))
    tasks = []

    n = 0

    async with aiohttp.ClientSession(loop=loop) as session:
        storage = Storage(session=session)
        bucket = storage.get_bucket(bucket_name)

        # Only used for membership tests, so a set avoids a linear scan per
        # release.
        toplevel_keys = set()
        if skip_toplevel_keys:
            params = {"delimiter": "/"}
            while True:
                batch = await storage.list_objects(bucket_name, params=params)
                if not batch:
                    break
                # A page may carry no "prefixes" entry at all; treat that as
                # an empty page rather than iterating over None.
                toplevel_keys.update(
                    name.rstrip("/") for name in batch.get("prefixes") or [])
                page_token = batch.get("nextPageToken")
                if not page_token:
                    break
                params = {"delimiter": "/", "pageToken": page_token}

        response = await session.get("{}/releases".format(balrog_api))
        to_process = (await response.json())["releases"]
        for r in to_process:
            release_name = r["name"]

            if skip_toplevel_keys and release_name in toplevel_keys:
                print("Skipping {} because it is an existing toplevel key".
                      format(release_name),
                      flush=True)
                continue

            if whitelist and release_name not in whitelist:
                print("Skipping {} because it is not in the whitelist".format(
                    release_name),
                      flush=True)
                continue

            if limit_to and n >= limit_to:
                break

            n += 1

            if any(pat in release_name for pat in skip_patterns):
                print("Skipping {} because it matches a skip pattern".format(
                    release_name),
                      flush=True)
                continue

            tasks.append(
                loop.create_task(
                    process_release(release_name, session, balrog_api, bucket,
                                    sem, loop)))

        # The tasks are already bound to `loop`, so gather() needs no explicit
        # loop argument; that keyword was removed in Python 3.10.
        for processed_releases, processed_uploads in await asyncio.gather(
                *tasks):
            for rel in processed_releases:
                releases[rel] += processed_releases[rel]
            for u in processed_uploads:
                uploads[u]["uploaded"] += processed_uploads[u]["uploaded"]
                uploads[u]["existing"] += processed_uploads[u]["existing"]

    for r in releases:
        # `uploads` is a defaultdict: indexing it with a missing key inserts
        # that key, which would make the "does not exist in GCS" check below
        # always false. Test membership first and read the counts with .get().
        in_gcs = r in uploads
        counts = uploads.get(r, {})
        existing = counts.get("existing", 0)
        uploaded = counts.get("uploaded", 0)
        revs_in_gcs = uploaded + existing
        print("INFO: {}: Found {} existing revisions, uploaded {} new ones".
              format(r, existing, uploaded))
        if not in_gcs:
            print(
                "WARNING: {} was found in the Balrog API but does not exist in GCS"
                .format(r))
        elif releases[r] != revs_in_gcs:
            print(
                "WARNING: {} has a data version of {} in the Balrog API, but {} revisions exist in GCS"
                .format(r, releases[r], revs_in_gcs))
示例#9
0
class GCSArtifactManager(ArtifactManager):
    """Artifact manager that stores run artifacts in a GCS bucket.

    Objects are named ``<run_id>/<file name>`` inside the bucket whose name
    is the host part of the ``location`` URL. The HTTP session and storage
    client are created in ``__aenter__``, so the instance must be entered
    as an async context manager before any other method is used.
    """

    def __init__(self, location, creds_path=None, trace_configs=None):
        """Record the configuration; no network resources are created here.

        Args:
            location: URL whose host component is used as the bucket name.
            creds_path: Later passed to ``Storage`` as ``service_file``.
            trace_configs: Later passed through to ``ClientSession``.
        """
        self.bucket_name = URL(location).host
        self.creds_path = creds_path
        self.trace_configs = trace_configs

    def __repr__(self):
        return "%s(%r)" % (type(self).__name__, "gs://%s/" % self.bucket_name)

    async def __aenter__(self):
        """Open the HTTP session and create the storage and bucket handles.

        Returns:
            This manager, so ``async with GCSArtifactManager(...) as m``
            binds the manager rather than ``None``.
        """
        from gcloud.aio.storage import Storage

        self.session = ClientSession(trace_configs=self.trace_configs)
        await self.session.__aenter__()
        self.storage = Storage(service_file=self.creds_path,
                               session=self.session)
        self.bucket = self.storage.get_bucket(self.bucket_name)
        return self

    async def __aexit__(self, exc_type, exc, tb):
        """Close the HTTP session; exceptions are never suppressed."""
        await self.session.__aexit__(exc_type, exc, tb)
        return False

    async def store_artifacts(self,
                              run_id,
                              local_path,
                              names=None,
                              timeout=None):
        """Upload files from ``local_path`` as artifacts of ``run_id``.

        Args:
            run_id: Prefix under which the objects are stored.
            local_path: Directory containing the files.
            names: File names to upload; defaults to every entry of
                ``local_path``. Nothing happens when it is empty.
            timeout: Per-upload timeout; defaults to ``DEFAULT_GCS_TIMEOUT``.

        Raises:
            ServiceUnavailable: an upload failed with HTTP status 503.
            ClientResponseError: an upload failed with any other status.
        """
        if timeout is None:
            timeout = DEFAULT_GCS_TIMEOUT
        if names is None:
            names = os.listdir(local_path)
        if not names:
            return
        # Read every file before creating any upload coroutine, so that an
        # unreadable file cannot leave already-created coroutines un-awaited.
        payloads = []
        for name in names:
            with open(os.path.join(local_path, name), "rb") as f:
                payloads.append((name, f.read()))
        todo = [
            self.storage.upload(
                self.bucket_name,
                "%s/%s" % (run_id, name),
                uploaded_data,
                timeout=timeout,
            ) for name, uploaded_data in payloads
        ]
        try:
            await asyncio.gather(*todo)
        except ClientResponseError as e:
            if e.status == 503:
                raise ServiceUnavailable()
            raise
        logging.info("Uploaded %r to run %s in bucket %s.", names, run_id,
                     self.bucket_name)

    async def iter_ids(self):
        """Yield each distinct first path component of the bucket's blobs.

        Every id is yielded once, the first time it is encountered.
        """
        ids = set()
        for name in await self.bucket.list_blobs():
            log_id = name.split("/")[0]
            if log_id not in ids:
                yield log_id
            ids.add(log_id)

    async def retrieve_artifacts(self,
                                 run_id,
                                 local_path,
                                 filter_fn=None,
                                 timeout=None):
        """Download the artifacts of ``run_id`` into ``local_path``.

        Args:
            run_id: Run whose ``<run_id>/`` prefix is listed.
            local_path: Directory that receives one file per blob, named
                after the blob's base name.
            filter_fn: Optional predicate on the base name; only blobs for
                which it returns a truthy value are downloaded.
            timeout: Per-download timeout; defaults to
                ``DEFAULT_GCS_TIMEOUT``.

        Raises:
            ArtifactsMissing: no blob exists under the run's prefix.
        """
        if timeout is None:
            timeout = DEFAULT_GCS_TIMEOUT
        names = await self.bucket.list_blobs(prefix=run_id + "/")
        if not names:
            raise ArtifactsMissing(run_id)

        async def download_blob(name):
            # Fetch first and open the destination only once the data is in
            # hand, so a failed download does not leave behind an empty file
            # or truncate an existing one.
            data = await self.storage.download(bucket=self.bucket_name,
                                               object_name=name,
                                               timeout=timeout)
            with open(os.path.join(local_path, os.path.basename(name)),
                      "wb+") as f:
                f.write(data)

        await asyncio.gather(*[
            download_blob(name) for name in names
            if filter_fn is None or filter_fn(os.path.basename(name))
        ])

    async def get_artifact(self,
                           run_id,
                           filename,
                           timeout=DEFAULT_GCS_TIMEOUT):
        """Return one artifact's content wrapped in a ``BytesIO``.

        Raises:
            ServiceUnavailable: the download failed with HTTP status 503.
            FileNotFoundError: the download failed with HTTP status 404.
            ClientResponseError: the download failed with any other status.
        """
        try:
            return BytesIO(await self.storage.download(
                bucket=self.bucket_name,
                object_name="%s/%s" % (run_id, filename),
                timeout=timeout,
            ))
        except ClientResponseError as e:
            if e.status == 503:
                raise ServiceUnavailable()
            if e.status == 404:
                raise FileNotFoundError
            raise