Example #1
    def __init__(
        self,
        application_credentials: Optional[Union[str, os.PathLike]] = None,
        credentials: Optional["Credentials"] = None,
        project: Optional[str] = None,
        storage_client: Optional["StorageClient"] = None,
        local_cache_dir: Optional[Union[str, os.PathLike]] = None,
    ):
        """Class constructor. Sets up a [`Storage
        Client`](https://googleapis.dev/python/storage/latest/client.html).
        Supports the following authentication methods of `Storage Client`.

        - Environment variable `"GOOGLE_APPLICATION_CREDENTIALS"` containing a
          path to a JSON credentials file for a Google service account. See
          [Authenticating as a Service
          Account](https://cloud.google.com/docs/authentication/production).
        - File path to a JSON credentials file for a Google service account.
        - OAuth2 Credentials object and a project name.
        - Instantiated and already authenticated `Storage Client`.

        If multiple methods are used, priority order is reverse of list above
        (later in list takes priority).

        Args:
            application_credentials (Optional[Union[str, os.PathLike]]): Path to Google service
                account credentials file.
            credentials (Optional[Credentials]): The OAuth2 Credentials to use for this client.
                See documentation for [`StorageClient`](
                https://googleapis.dev/python/storage/latest/client.html).
            project (Optional[str]): The project which the client acts on behalf of. See
                documentation for [`StorageClient`](
                https://googleapis.dev/python/storage/latest/client.html).
            storage_client (Optional[StorageClient]): Instantiated [`StorageClient`](
                https://googleapis.dev/python/storage/latest/client.html).
            local_cache_dir (Optional[Union[str, os.PathLike]]): Path to directory to use as cache
                for downloaded files. If None, will use a temporary directory.
        """

        if storage_client is not None:
            self.client = storage_client
        elif credentials is not None:
            self.client = StorageClient(credentials=credentials, project=project)
        elif application_credentials is not None:
            self.client = StorageClient.from_service_account_json(application_credentials)
        else:
            self.client = StorageClient.from_service_account_json(
                os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
            )

        super().__init__(local_cache_dir=local_cache_dir)
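A brief usage sketch of the constructor above. The enclosing class name GSClient is an assumption (not shown in the excerpt); the key-file path and project are placeholders.

from google.cloud.storage import Client as StorageClient
from google.oauth2 import service_account

# 1. Point at a service-account key file directly (GSClient name assumed):
client_a = GSClient(application_credentials="service_account.json")

# 2. Pass an OAuth2 credentials object plus an explicit project:
creds = service_account.Credentials.from_service_account_file("service_account.json")
client_b = GSClient(credentials=creds, project="my-project")

# 3. Hand over an already-authenticated Storage Client (highest priority):
client_c = GSClient(storage_client=StorageClient())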
Example #2
    def lock(self):
        """
        This is the best we can do. It is impossible to acquire the lock reliably without
        using any additional services. test-and-set is impossible to implement.
        :return:
        """
        log = self._log
        log.info("Locking the bucket...")

        # Client should be imported here because grpc starts threads during import,
        # and if you fork after that, a child process can hang during exit.
        from google.cloud.storage import Client

        if self.credentials:
            client = Client.from_service_account_json(self.credentials)
        else:
            client = Client()
        bucket = client.get_bucket(self.bucket_name)
        self._bucket = bucket
        sentinel = bucket.blob("index.lock")
        try:
            while sentinel.exists():
                log.warning("Failed to acquire the lock, waiting...")
                time.sleep(1)
            sentinel.upload_from_string(b"")
            # Several agents can get here. No test-and-set, sorry!
            yield None
        finally:
            self._bucket = None
            if sentinel is not None:
                try:
                    sentinel.delete()
                except Exception:  # best-effort cleanup of the sentinel blob
                    pass
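The lock() method above is a generator and is meant to be consumed as a context manager. Below is a self-contained sketch of the same sentinel-blob pattern using contextlib.contextmanager; the function name and parameters are assumptions, not part of the excerpt.

import time
from contextlib import contextmanager
from typing import Optional


@contextmanager
def acquire_bucket_lock(bucket_name: str, credentials_path: Optional[str] = None):
    # Deferred import, mirroring the fork-safety comment in the example above.
    from google.cloud.storage import Client

    client = (Client.from_service_account_json(credentials_path)
              if credentials_path else Client())
    bucket = client.get_bucket(bucket_name)
    sentinel = bucket.blob("index.lock")
    while sentinel.exists():          # best effort only: GCS has no test-and-set
        time.sleep(1)
    sentinel.upload_from_string(b"")
    try:
        yield bucket
    finally:
        try:
            sentinel.delete()         # cleanup is best effort as well
        except Exception:
            pass


# Usage:
# with acquire_bucket_lock("my-bucket", "service_account.json"):
#     ...  # critical section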
Example #3
 def open_gcs_url(config, logger, storage, url):
     reader_impl = SourceFile.extract_reader_impl(config)
     use_gcs_service_account = "service_account_json" in config["provider"] and storage == "gs://"
     file_to_close = None
     if reader_impl == "gcsfs":
         if use_gcs_service_account:
             try:
                 token_dict = json.loads(config["provider"]["service_account_json"])
             except json.decoder.JSONDecodeError as err:
                 logger.error(f"Failed to parse gcs service account json: {repr(err)}\n{traceback.format_exc()}")
                 raise err
         else:
             token_dict = "anon"
         fs = gcsfs.GCSFileSystem(token=token_dict)
         file_to_close = fs.open(f"gs://{url}")
         result = file_to_close
     else:
         if use_gcs_service_account:
             try:
                 credentials = json.dumps(json.loads(config["provider"]["service_account_json"]))
                 tmp_service_account = tempfile.NamedTemporaryFile(delete=False)
                 with open(tmp_service_account.name, "w") as f:
                     f.write(credentials)
                 tmp_service_account.close()
                 client = Client.from_service_account_json(tmp_service_account.name)
                 result = open(f"gs://{url}", transport_params=dict(client=client))
                 os.remove(tmp_service_account.name)
             except json.decoder.JSONDecodeError as err:
                 logger.error(f"Failed to parse gcs service account json: {repr(err)}\n{traceback.format_exc()}")
                 raise err
         else:
             client = Client.create_anonymous_client()
             result = open(f"{storage}{url}", transport_params=dict(client=client))
     return result, file_to_close
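The open(..., transport_params=...) calls in the non-gcsfs branch only make sense if open is smart_open's open; the built-in open does not accept transport_params. A minimal sketch of that assumption, with placeholder bucket and object names:

from smart_open import open as smart_open_open
from google.cloud.storage import Client

client = Client.from_service_account_json("service_account.json")
with smart_open_open("gs://my-bucket/data.csv", transport_params=dict(client=client)) as f:
    first_line = f.readline()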
Example #4
    def _get_native_gcp_handle() -> typing.Any:
        if Config.BLOBSTORE_CONNECT_TIMEOUT is None and Config.BLOBSTORE_READ_TIMEOUT is None:
            client = Client.from_service_account_json(
                os.environ['GOOGLE_APPLICATION_CREDENTIALS'], )
        else:
            # GCP has no direct interface to configure retries and timeouts. However, it makes use of Python's
            # stdlib `requests` package, which has straightforward timeout usage.
            class SessionWithTimeouts(AuthorizedSession):
                def request(self, *args, **kwargs):
                    kwargs['timeout'] = (Config.BLOBSTORE_CONNECT_TIMEOUT,
                                         Config.BLOBSTORE_READ_TIMEOUT)
                    return super().request(*args, **kwargs)

            credentials = service_account.Credentials.from_service_account_file(
                os.environ['GOOGLE_APPLICATION_CREDENTIALS'],
                scopes=Client.SCOPE)

            # _http is a "private" parameter, and we may need to re-visit GCP timeout retry
            # strategies in the future.
            client = Client(_http=SessionWithTimeouts(credentials),
                            credentials=credentials)

        adapter_kwargs = dict(pool_maxsize=max(DEFAULT_POOLSIZE, 20))
        if Config.BLOBSTORE_RETRIES is not None:
            adapter_kwargs['max_retries'] = Retry(
                total=Config.BLOBSTORE_RETRIES,
                backoff_factor=0.3,
                status_forcelist=(500, 502, 504))
        adapter = HTTPAdapter(**adapter_kwargs)
        # _http is a "private" parameter, and we may need to re-visit GCP timeout retry
        # strategies in the future.
        client._http.mount('https://', adapter)
        client._http.mount('http://', adapter)
        return client
Example #5
    def _get_native_gcp_handle() -> typing.Any:
        if Config.BLOBSTORE_GS_MAX_CUMULATIVE_RETRY is not None:
            google.resumable_media.common.MAX_CUMULATIVE_RETRY = Config.BLOBSTORE_GS_MAX_CUMULATIVE_RETRY

        if Config.BLOBSTORE_CONNECT_TIMEOUT is None and Config.BLOBSTORE_READ_TIMEOUT is None:
            return Client.from_service_account_json(
                os.environ['GOOGLE_APPLICATION_CREDENTIALS'], )
        else:
            # GCP has no direct interface to configure retries and timeouts. However, it makes use of Python's
            # stdlib `requests` package, which has straightforward timeout usage.
            class SessionWithTimeouts(
                    google.auth.transport.requests.AuthorizedSession):
                def request(self, *args, **kwargs):
                    kwargs['timeout'] = (Config.BLOBSTORE_CONNECT_TIMEOUT,
                                         Config.BLOBSTORE_READ_TIMEOUT)
                    return super().request(*args, **kwargs)

            credentials = service_account.Credentials.from_service_account_file(
                os.environ['GOOGLE_APPLICATION_CREDENTIALS'],
                scopes=Client.SCOPE)

            # _http is a "private" parameter, and we may need to re-visit GCP timeout retry
            # strategies in the future.
            return Client(_http=SessionWithTimeouts(credentials),
                          credentials=credentials)
Example #6
 def get_native_handle(replica: "Replica") -> typing.Any:
     if replica == Replica.aws:
         return boto3.client("s3")
     elif replica == Replica.gcp:
         credentials = os.environ["GOOGLE_APPLICATION_CREDENTIALS"]
         return Client.from_service_account_json(credentials)
     raise NotImplementedError(
         f"Replica `{replica.name}` is not implemented!")
Example #7
    def gs(self):
        from google.cloud.storage import Client

        return (
            Client.from_service_account_json(self.credentialpath)
            if self.credentialpath
            else Client(self.projectname)
        )
Example #8
    def _client(self):
        from google.cloud.storage import Client

        if isinstance(self._credentials, str):
            return Client.from_service_account_json(self._credentials)
        else:
            return Client(credentials=self._credentials,
                          project=self.project_name)
Example #9
 def _create_default_client(
         self,
         service_account_credentials_path=settings.GCS_STORAGE_SERVICE_ACCOUNT_KEY_PATH):
     if service_account_credentials_path:
         return Client.from_service_account_json(
             service_account_credentials_path)
     else:
         return Client()
Example #10
 def create_client(self):
     # Client should be imported here because grpc starts threads during import
     # and if you fork after that, a child process can hang during exit
     from google.cloud.storage import Client
     if self.credentials:
         client = Client.from_service_account_json(self.credentials)
     else:
         client = Client()
     return client
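A sketch of why the deferred import matters in the method above: keep google.cloud.storage out of the parent process and import it only in code that runs after the fork. The bucket name is a placeholder.

import multiprocessing


def list_some_blobs(bucket_name: str) -> None:
    # Import after the fork so grpc's background threads are created in the child only.
    from google.cloud.storage import Client

    client = Client()  # uses application-default credentials
    for blob in client.list_blobs(bucket_name, max_results=5):
        print(blob.name)


if __name__ == "__main__":
    proc = multiprocessing.Process(target=list_some_blobs, args=("my-bucket",))
    proc.start()
    proc.join()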
Example #11
File: clients.py Project: luke-zhu/blueno
 def __init__(self):
     client = self.client_cache.get('gcs')
     if client is None:
         self.client = GCSClient.from_service_account_json(
             env.GOOGLE_APPLICATION_CREDENTIALS)
     else:
         self.client: GCSClient = client
         self.client_cache.get('gcs')
     self.client_cache.put('gcs', self.client)
Example #12
File: auth.py Project: trisongz/hfsync
 def create_auth(self):
     if self.auth_params['service_account']:
         self.client = Client.from_service_account_json(
             self.auth_params['service_account'])
     elif self.auth_params['token']:
         self.sess = Credentials(token=self.auth_params['token'])
         self.client = Client(credentials=self.sess)
     else:
         self.client = None
Example #13
 def __init__(self, credentials_file_path, bucket_name):
     self._credentials_file_path = credentials_file_path
     self._bucket_name = bucket_name
     try:
         self._client = Client.from_service_account_json(
             self._credentials_file_path,
         )
         self._bucket = self._client.get_bucket(self._bucket_name)
     except (GoogleAuthError, GoogleAPICallError) as e:
         raise StorageException(f'Failed to initialize GCSClient: {e}')
Example #14
    def _prepare_client(self):
        keyfile_filename = "gcp_key_file.json"

        keyfile_path = os.path.join(config.get("tasks_path"), keyfile_filename)

        try:
            client = Client.from_service_account_json(keyfile_path)
        except TypeError:
            sys.exit(1)

        return client
Example #15
 def client(self):
     """
     :return: used instance of :class:`google.cloud.storage.Client`.
     """
     if self._client is not None:
         return self._client
     if not self.project:
         self._client = GSClient()
     else:
         self._client = GSClient.from_service_account_json(
             self.keyfile, project=self.project)
     return self._client
Example #16
 def client(self):
     """
     :return: used instance of :class:`google.cloud.storage.Client`.
     """
     try:
         return self._client
     except AttributeError:
         if not self.project:
             self._client = GSClient()
         else:
             self._client = GSClient.from_service_account_json(
                 self.keyfile, project=self.project)
         return self._client
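Examples #15 and #16 implement the same lazy, memoized client by hand. A sketch of the equivalent with functools.cached_property; the class and attribute names mirror the examples but are assumptions.

from functools import cached_property

from google.cloud.storage import Client as GSClient


class GSStorage:
    def __init__(self, keyfile: str = "", project: str = ""):
        self.keyfile = keyfile
        self.project = project

    @cached_property
    def client(self) -> GSClient:
        """Instance of google.cloud.storage.Client, created on first access."""
        if not self.project:
            return GSClient()
        return GSClient.from_service_account_json(self.keyfile, project=self.project)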
Example #17
def setUpModule():
    if os.environ.get("GOOGLE_APPLICATION_CREDENTIALS"):
        GS.client = Client.from_service_account_json(
            os.environ['GOOGLE_APPLICATION_CREDENTIALS'])
    elif os.environ.get("GSCIO_TEST_CREDENTIALS"):
        import json
        import base64
        from google.oauth2.service_account import Credentials
        creds_info = json.loads(
            base64.b64decode(os.environ.get("GSCIO_TEST_CREDENTIALS")))
        creds = Credentials.from_service_account_info(creds_info)
        GS.client = Client(credentials=creds)
    else:
        GS.client = Client()
    GS.bucket = GS.client.bucket(os.environ['GSCIO_GOOGLE_TEST_BUCKET'])
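A sketch of how the base64-encoded GSCIO_TEST_CREDENTIALS value consumed above could be produced; the key-file path is a placeholder.

import base64

with open("service_account.json", "rb") as f:
    encoded = base64.b64encode(f.read()).decode()
print(encoded)  # export GSCIO_TEST_CREDENTIALS=<this value>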
Example #18
    def retrieve_from_google(self, bucket, object):
        """Retrieves file from Google Cloud Storage.

        Args:
            bucket: Bucket to retrieve file from.
            object: File object to retrieve.

        Returns:
            A byte string containing the file content.
        """
        global gcs_client
        if gcs_client is None:
            gcs_client = Client.from_service_account_json(
                conf.remote_cfg["google_application_credentials"])
        return gcs_client.get_bucket(bucket).get_blob(
            object).download_as_string()
Example #19
File: gcs.py Project: clustree/modelkit
    def __init__(
        self,
        bucket: Optional[str] = None,
        service_account_path: Optional[str] = None,
        client: Optional[Client] = None,
    ):
        self.bucket = bucket or os.environ.get("MODELKIT_STORAGE_BUCKET") or ""
        if not self.bucket:
            raise ValueError("Bucket needs to be set for GCS storage driver")

        if client:
            self.client = client
        elif service_account_path:  # pragma: no cover
            self.client = Client.from_service_account_json(
                service_account_path)
        else:
            self.client = Client()
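Hypothetical instantiation of the driver above; the class name GCSStorageDriver is assumed, and the bucket and key-file path are placeholders.

import os

# Explicit bucket and service-account key file:
driver = GCSStorageDriver(bucket="my-models", service_account_path="service_account.json")

# Or rely on the environment variable and application-default credentials:
os.environ["MODELKIT_STORAGE_BUCKET"] = "my-models"
driver_from_env = GCSStorageDriver()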
Example #20
File: fquery.py Project: nmatare/putils
    def __init__(self,
                 service_file,
                 dataset,
                 export_as="csv.gz",
                 use_legacy_sql=False,  # assumed parameter; referenced below but absent from the excerpt
                 verbose=True,
                 temp_path=create_temp_directory()[1]):
        self.bq_client = BQclient.from_service_account_json(service_file)
        self.gcs_client = GCSclient.from_service_account_json(service_file)

        self._use_legacy_sql = use_legacy_sql
        self._verbose = verbose
        self._temp_path = temp_path
        self._base_name = self._temp_path.replace("/tmp", "")
        self.temp_table = self.bq_client.dataset(dataset).table(
            self._base_name + "_temp_table")

        assert export_as in ["csv.gz", "csv", "avro"]
        self._export_as = export_as
Example #21
    def __init__(self, config):

        super().__init__()

        self._webdav_enabled = config.getBool('Upload', 'webdav_enable')
        self._gcp_enabled = config.getBool('Upload', 'gcp_enable')

        if self._webdav_enabled:
            self._baseurl = config.get('Upload', 'webdav_url')
            if config.getBool('Upload', 'webdav_use_auth'):
                self._auth = (config.get('Upload', 'webdav_user'),
                              config.get('Upload', 'webdav_password'))
            else:
                self._auth = None
        if self._gcp_enabled:
            print("Initialized GCP!")
            self._bucket_name = config.get('Upload', 'gcp_bucket')
            self._service_account_location = config.get(
                'Upload', 'gcp_service_account_path')
            try:
                self._client = Client.from_service_account_json(
                    self._service_account_location)
            except Exception:  # fall back to no client if the key file cannot be loaded
                self._client = None
Example #22
def write_to_gcs(out_data: str, out_path: str, creds_file: Optional[str]):
    from google.cloud.exceptions import Forbidden
    from google.cloud.storage import Client
    from urllib.parse import urlparse

    if creds_file is not None:
        client = Client.from_service_account_json(creds_file)
    else:
        client = Client()
    url = urlparse(out_path)
    bucket = client.bucket(url.netloc)
    # Stripping the leading /
    blob: "Blob"
    blob = bucket.blob(url.path[1:])
    try:
        blob.upload_from_string(out_data, content_type="application/json")
    except Forbidden as e:
        click.secho(
            f"Unable to write to {out_path}, permission denied:\n"
            f"{e.response.json()['error']['message']}",
            err=True,
        )
        sys.exit(1)
    blob.make_public()
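Hypothetical call to write_to_gcs above; the bucket, object path, and payload are placeholders.

import json

payload = json.dumps({"status": "ok"})
write_to_gcs(
    out_data=payload,
    out_path="gs://my-bucket/reports/latest.json",  # netloc -> bucket, path -> blob name
    creds_file=None,  # None falls back to application-default credentials
)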
Example #23
    def _gc(self):
        from google.cloud.storage import Client

        return Client.from_service_account_json(TEST_GCP_CREDS_FILE)
Example #24
    def load_dataframes(config, logger, skip_data=False) -> List:
        """From an Airbyte Configuration file, load and return the appropriate pandas dataframe.

        :param skip_data: limit reading data
        :param config:
        :param logger:
        :return: a list of dataframe loaded from files described in the configuration
        """
        storage = SourceFile.get_storage_scheme(logger, config["storage"], config["url"])
        url = SourceFile.get_simple_url(config["url"])

        gcs_file = None
        use_gcs_service_account = "service_account_json" in config and storage == "gs://"
        use_aws_account = "aws_access_key_id" in config and "aws_secret_access_key" in config and storage == "s3://"

        # default format reader
        reader_format = "csv"
        if "format" in config:
            reader_format = config["format"]
        reader_options: dict = {}
        if "reader_options" in config:
            try:
                reader_options = json.loads(config["reader_options"])
            except json.decoder.JSONDecodeError as err:
                logger.error(f"Failed to parse reader options {repr(err)}\n{config['reader_options']}\n{traceback.format_exc()}")
        if skip_data and reader_format == "csv":
            reader_options["nrows"] = 0
            reader_options["index_col"] = 0

        # default reader impl
        reader_impl = ""
        if "reader_impl" in config:
            reader_impl = config["reader_impl"]

        if reader_impl == "gcsfs":
            if use_gcs_service_account:
                try:
                    token_dict = json.loads(config["service_account_json"])
                    fs = gcsfs.GCSFileSystem(token=token_dict)
                    gcs_file = fs.open(f"gs://{url}")
                    url = gcs_file
                except json.decoder.JSONDecodeError as err:
                    logger.error(f"Failed to parse gcs service account json: {repr(err)}\n{traceback.format_exc()}")
                    raise err
            else:
                url = open(f"{storage}{url}")
        elif reader_impl == "s3fs":
            if use_aws_account:
                aws_access_key_id = None
                if "aws_access_key_id" in config:
                    aws_access_key_id = config["aws_access_key_id"]
                aws_secret_access_key = None
                if "aws_secret_access_key" in config:
                    aws_secret_access_key = config["aws_secret_access_key"]
                s3 = S3FileSystem(anon=False, key=aws_access_key_id, secret=aws_secret_access_key)
                url = s3.open(f"s3://{url}", mode="r")
            else:
                url = open(f"{storage}{url}")
        else:  # using smart_open
            if use_gcs_service_account:
                try:
                    credentials = json.dumps(json.loads(config["service_account_json"]))
                    tmp_service_account = tempfile.NamedTemporaryFile(delete=False)
                    with open(tmp_service_account.name, "w") as f:
                        f.write(credentials)
                    tmp_service_account.close()
                    client = Client.from_service_account_json(tmp_service_account.name)
                    url = open(f"gs://{url}", transport_params=dict(client=client))
                    os.remove(tmp_service_account.name)
                except json.decoder.JSONDecodeError as err:
                    logger.error(f"Failed to parse gcs service account json: {repr(err)}\n{traceback.format_exc()}")
                    raise err
            elif use_aws_account:
                aws_access_key_id = ""
                if "aws_access_key_id" in config:
                    aws_access_key_id = config["aws_access_key_id"]
                aws_secret_access_key = ""
                if "aws_secret_access_key" in config:
                    aws_secret_access_key = config["aws_secret_access_key"]
                url = open(f"s3://{aws_access_key_id}:{aws_secret_access_key}@{url}")
            elif storage == "webhdfs://":
                host = config["host"]
                port = config["port"]
                url = open(f"webhdfs://{host}:{port}/{url}")
            elif storage == "ssh://" or storage == "scp://" or storage == "sftp://":
                user = config["user"]
                host = config["host"]
                if "password" in config:
                    password = config["password"]
                    # Explicitly turn off ssh keys stored in ~/.ssh
                    transport_params = {"connect_kwargs": {"look_for_keys": False}}
                    url = open(f"{storage}{user}:{password}@{host}/{url}", transport_params=transport_params)
                else:
                    url = open(f"{storage}{user}@{host}/{url}")
            else:
                url = open(f"{storage}{url}")
        try:
            result = SourceFile.parse_file(logger, reader_format, url, reader_options)
        finally:
            if gcs_file:
                gcs_file.close()
        return result
Example #25
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-n', '--network-name', type=str, required=True)
    parser.add_argument('-d', '--dataset-name', type=str, required=True)
    parser.add_argument('-p', '--pretrained', type=bool, default=False)
    parser.add_argument('-dp', '--data-parallel', type=bool, default=True)
    parser.add_argument('--train-images-path',
                        type=str,
                        default=default_train_images)
    parser.add_argument('--train-csv-path', type=str, default=default_csv)
    parser.add_argument(
        '-l',
        '--load',
        help=
        'if using load, must be path to .pth file containing serialized model state dict'
    )
    parser.add_argument('--batchSz', type=int, default=32)  # 64
    parser.add_argument('--nEpochs', type=int, default=1)  # 300
    parser.add_argument('--sEpoch', type=int, default=1)
    parser.add_argument('--unfreeze-epoch', type=int, default=-1)
    parser.add_argument('--nSubsample', type=int, default=0)
    parser.add_argument('--use-cuda', type=str, default='no')
    parser.add_argument('--nGPU', type=int, default=0)
    parser.add_argument('--save')
    parser.add_argument('--seed', type=int, default=50)
    parser.add_argument('--opt',
                        type=str,
                        default='sgd',
                        choices=('sgd', 'adam', 'rmsprop'))
    parser.add_argument('--crit',
                        type=str,
                        default='bce',
                        choices=('bce', 'f1', 'crl'))
    parser.add_argument(
        '--distributed',
        type=bool,
        default=False,
        help='If True, use distributed data parallel training (default, False).'
    )
    args = parser.parse_args()

    if args.use_cuda == 'yes' and not torch.cuda.is_available():
        raise ValueError('Use cuda requires cuda devices and ' + \
                         'drivers to be installed. Please make ' + \
                         'sure both are installed.'
                        )
    elif args.use_cuda == 'yes' and torch.cuda.is_available():
        args.cuda = True
    else:
        args.cuda = False

    if args.cuda and args.nGPU == 0:
        nGPU = 1
    else:
        nGPU = args.nGPU

    main_proc = True
    if args.distributed:
        dist.init_process_group(backend='gloo')
        main_proc = dist.get_rank() == 0
        init_print(dist.get_rank(), dist.get_world_size())

    print("using cuda ", args.cuda)

    args.save = args.save or 'work/%s/%s' % \
                                (args.network_name, args.dataset_name)
    setproctitle.setproctitle(args.save)

    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    os.makedirs(args.save, exist_ok=True)

    kwargs = {'batch_size': args.batchSz}

    trainLoader, devLoader = get_train_test_split(args, **kwargs)

    # net = get_network(args)
    net = NETWORKS_DICT[args.network_name](args.pretrained)

    if args.load:
        print("Loading network: {}".format(args.load))
        load_model(args, net)

    if args.distributed:
        net = DistributedDataParallel(net)
    elif args.data_parallel:
        net = torch.nn.DataParallel(net)

    print('  + Number of params: {}'.format(
        sum([p.data.nelement() for p in net.parameters()])))

    if args.cuda:
        net = net.cuda()

    if args.opt == 'sgd':
        optimizer = torch.optim.SGD(net.parameters(),
                                    lr=1e-3,
                                    momentum=0.9,
                                    weight_decay=1e-4)
    elif args.opt == 'adam':
        optimizer = torch.optim.Adam(net.parameters(), weight_decay=1e-4)
    elif args.opt == 'rmsprop':
        optimizer = torch.optim.RMSprop(net.parameters(), weight_decay=1e-4)
    else:
        raise ModuleNotFoundError('optimiser not found')

    if args.crit == 'crl':
        lf_args = [0.5, 8.537058595265812e-06, args.batchSz, 5, True, True]
    else:
        lf_args = None

    criterion = get_loss_function(args.crit, lf_args)

    sched_args = [10, 1e-4, 1.1, .5, -1]
    scheduler = CosineAnnealingRestartsLR(optimizer, *sched_args)

    trainF = open(os.path.join(args.save, 'train.csv'), 'a')
    testF = open(os.path.join(args.save, 'test.csv'), 'a')

    for epoch in range(args.sEpoch, args.nEpochs + args.sEpoch):
        # adjust_opt(args, epoch, optimizer)
        scheduler.step()
        unfreeze_weights(args, epoch, net)
        train(args, epoch, net, trainLoader, criterion, optimizer, trainF)
        test(args, epoch, net, devLoader, criterion, optimizer, testF)
        if main_proc:
            save_model(args, epoch, net)

    trainF.close()
    testF.close()

    if len(CLOUD_STORAGE_BUCKET) != 0:
        storage_client = Client.from_service_account_json(CREDENTIALS)
        bucket = storage_client.get_bucket(CLOUD_STORAGE_BUCKET)
        all_files = [name for name in os.listdir(args.save) \
                            if os.path.isfile(os.path.join(args.save, name))]
        if len(all_files) > 0:
            print("uploading weights...")
            for file in all_files:
                blob = bucket.blob(os.path.join(args.save, file))
                blob.upload_from_filename(os.path.join(args.save, file))
            print('upload complete')
Example #26
    def __init__(self, json_keyfile: str) -> None:
        super(GSBlobStore, self).__init__()

        self.gcp_client = Client.from_service_account_json(json_keyfile)
        self.bucket_map = dict()  # type: typing.MutableMapping[str, Bucket]
Example #27
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-n', '--network-name', type=str, required=True)
    parser.add_argument('-d', '--dataset-name', type=str, required=True)
    parser.add_argument('-p', '--pretrained', type=bool, default=False)
    parser.add_argument('-dp', '--data-parallel', type=bool, default=True)
    parser.add_argument('--test-images-path',
                        type=str,
                        default=default_test_images)
    parser.add_argument('-l', '--load')
    parser.add_argument('--batchSz', type=int, default=32)
    parser.add_argument('--save')
    parser.add_argument('--thresholds', type=str, default=None)
    parser.add_argument('--sigmoid', type=bool, default=True)
    args = parser.parse_args()

    args.cuda = torch.cuda.is_available()

    print("using cuda: ", args.cuda)

    args.save = args.save or 'work/%s/%s' % \
                                (args.network_name, args.dataset_name)
    setproctitle.setproctitle('work/%s/%s-test' % \
                                (args.network_name, args.dataset_name))

    if not os.path.exists(args.save):
        raise ValueError('save directory not found')

    kwargs = {'batch_size': args.batchSz}

    testLoader = get_testloader(args, **kwargs)

    net = NETWORKS_DICT[args.network_name](args.pretrained)
    if args.load:
        print("Loading network: {}".format(args.load))
    else:
        load_path = 'work/%s/%s' % (args.network_name, args.dataset_name)
        files = [f for f in os.listdir(load_path) if \
                            os.path.isfile(os.path.join(load_path, f)) \
                            and '.pth' in f]
        current = max([int(i.replace('.pth', '')) for i in files])
        args.load = os.path.join(load_path, str(current) + '.pth')
        print(args.load)
    load_model(args, net)

    if args.cuda:
        net = net.cuda()

    now = datetime.datetime.now(
        tz=pytz.timezone("US/Mountain")).strftime("%Y-%m-%d___%H:%M:%S")
    predict_csv_path = os.path.join(
        args.save, '{}_{}_predict.csv'.format(BRANCH_NAME, now))

    predF = open(predict_csv_path, 'a')

    predict(args, net, testLoader, predF)

    predF.close()

    if len(CLOUD_STORAGE_BUCKET) != 0:
        storage_client = Client.from_service_account_json(CREDENTIALS)
        print('client authenticated')
        bucket = storage_client.get_bucket(CLOUD_STORAGE_BUCKET)
        blob = bucket.blob(predict_csv_path)

        blob.upload_from_filename(predict_csv_path)
Example #28
 def from_auth_credentials(cls, json_keyfile_path: str) -> "GSBlobStore":
     return cls(Client.from_service_account_json(json_keyfile_path))
Example #29
 def __init__(self, local_root: str, bucket_name: str) -> None:
     super(GSUploader, self).__init__(local_root)
     credentials = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
     self.gcp_client = Client.from_service_account_json(credentials)
     self.bucket = self.gcp_client.bucket(bucket_name)