示例#1
0
文件: dataset.py 项目: vipadm/airflow
    def get_dataset_service_client(self, region: Optional[str] = None) -> DatasetServiceClient:
        """Build a DatasetServiceClient, optionally pinned to a regional endpoint.

        :param region: Vertex AI region. When it is set and is not ``'global'``,
            the client targets ``{region}-aiplatform.googleapis.com:443``;
            otherwise an empty ``ClientOptions`` is passed.
        :return: A new DatasetServiceClient.
        """
        use_regional_endpoint = bool(region) and region != 'global'
        options = (
            ClientOptions(api_endpoint=f'{region}-aiplatform.googleapis.com:443')
            if use_regional_endpoint
            else ClientOptions()
        )
        return DatasetServiceClient(
            credentials=self._get_credentials(),
            client_info=CLIENT_INFO,
            client_options=options,
        )
    def get_job_service_client(self, region: Optional[str] = None) -> JobServiceClient:
        """Build a JobServiceClient for the requested region.

        :param region: Vertex AI region. An empty value or ``'global'`` yields
            default ``ClientOptions``; any other value selects the regional
            endpoint ``{region}-aiplatform.googleapis.com:443``.
        :return: A new JobServiceClient.
        """
        if not region or region == 'global':
            options = ClientOptions()
        else:
            options = ClientOptions(api_endpoint=f'{region}-aiplatform.googleapis.com:443')
        return JobServiceClient(
            credentials=self.get_credentials(),
            client_info=self.client_info,
            client_options=options,
        )
示例#3
0
 def spanner_api(self):
     """Helper for session-related API calls.

     Lazily builds a ``SpannerClient`` on first access and caches it on
     ``self._spanner_api``; later calls return the cached client.

     * If the instance has an ``emulator_host``, the client talks to that host
       over an insecure gRPC channel and no credentials are passed.
     * Otherwise the parent client's credentials are used, re-scoped to
       ``SPANNER_DATA_SCOPE`` when they are ``Scoped``.
     * When the env var ``GOOGLE_CLOUD_SPANNER_ENABLE_RESOURCE_BASED_ROUTING``
       equals ``"true"``, the instance's endpoint is looked up via the
       instance admin API (memoized in the parent client's ``_endpoint_cache``)
       and used as ``api_endpoint``. A ``PermissionDenied`` from that lookup
       emits a ``ResourceRoutingPermissionsWarning`` and keeps the parent
       client's options.
     """
     if self._spanner_api is None:
         client_info = self._instance._client._client_info
         client_options = self._instance._client._client_options
         # Emulator: plaintext gRPC channel, no credentials.
         if self._instance.emulator_host is not None:
             transport = spanner_grpc_transport.SpannerGrpcTransport(
                 channel=grpc.insecure_channel(
                     self._instance.emulator_host))
             self._spanner_api = SpannerClient(
                 client_info=client_info,
                 client_options=client_options,
                 transport=transport,
             )
             return self._spanner_api
         credentials = self._instance._client.credentials
         if isinstance(credentials, google.auth.credentials.Scoped):
             credentials = credentials.with_scopes((SPANNER_DATA_SCOPE, ))
         # Opt-in resource-based routing: prefer an instance-specific endpoint.
         if (os.getenv("GOOGLE_CLOUD_SPANNER_ENABLE_RESOURCE_BASED_ROUTING")
                 == "true"):
             endpoint_cache = self._instance._client._endpoint_cache
             if self._instance.name in endpoint_cache:
                 client_options = ClientOptions(
                     api_endpoint=endpoint_cache[self._instance.name])
             else:
                 try:
                     api = self._instance._client.instance_admin_api
                     # Only the endpoint_uris field is requested.
                     resp = api.get_instance(
                         self._instance.name,
                         field_mask={"paths": ["endpoint_uris"]},
                         metadata=_metadata_with_prefix(self.name),
                     )
                     endpoints = resp.endpoint_uris
                     if endpoints:
                         # Only the first advertised endpoint is cached and used.
                         endpoint_cache[self._instance.name] = list(
                             endpoints)[0]
                         client_options = ClientOptions(
                             api_endpoint=endpoint_cache[
                                 self._instance.name])
                     # If there are no endpoints, use default endpoint.
                 except PermissionDenied:
                     # Best-effort: warn and fall back to the existing options.
                     warnings.warn(
                         _RESOURCE_ROUTING_PERMISSIONS_WARNING,
                         ResourceRoutingPermissionsWarning,
                         stacklevel=2,
                     )
         self._spanner_api = SpannerClient(
             credentials=credentials,
             client_info=client_info,
             client_options=client_options,
         )
     return self._spanner_api
示例#4
0
def create_dataset(project_id):
    """Creates a dataset for the given Google Cloud project.

    Args:
        project_id: ID of the Google Cloud project that will own the dataset.

    Returns:
        The dataset returned by ``DataLabelingServiceClient.create_dataset``.
    """
    from google.cloud import datalabeling_v1beta1 as datalabeling

    client = datalabeling.DataLabelingServiceClient()
    # [END datalabeling_create_dataset_beta]
    # If provided, use a provided test endpoint - this will prevent tests on
    # this snippet from triggering any action by a real human
    if "DATALABELING_ENDPOINT" in os.environ:
        opts = ClientOptions(api_endpoint=os.getenv("DATALABELING_ENDPOINT"))
        client = datalabeling.DataLabelingServiceClient(client_options=opts)
    # [START datalabeling_create_dataset_beta]

    formatted_project_name = f"projects/{project_id}"

    dataset = datalabeling.Dataset(
        display_name="YOUR_DATASET_SET_DISPLAY_NAME", description="YOUR_DESCRIPTION"
    )

    response = client.create_dataset(
        request={"parent": formatted_project_name, "dataset": dataset}
    )

    # The format of resource name:
    # projects/{project_id}/datasets/{dataset_id}
    print("The dataset resource name: {}".format(response.name))
    print("Display name: {}".format(response.display_name))
    print("Description: {}".format(response.description))
    # Convert the creation time to a protobuf Timestamp once and reuse it.
    create_time = response.create_time.timestamp_pb()
    print("Create time:")
    print("\tseconds: {}".format(create_time.seconds))
    print("\tnanos: {}\n".format(create_time.nanos))

    return response
 def __init__(self, project, region, model_name, version):
   """Build an AI Platform (``ml`` v1) discovery client for one model version.

   Args:
     project: Google Cloud project ID that owns the model.
     region: Region used to form the endpoint ``https://{region}-ml.googleapis.com``.
     model_name: Name of the deployed model.
     version: Version of the model to target.
   """
   regional_options = ClientOptions(
     api_endpoint=f'https://{region}-ml.googleapis.com')
   self.service = googleapiclient.discovery.build(
     serviceName='ml', version='v1', client_options=regional_options)
   self.name = (
     f'projects/{project}/models/{model_name}/versions/{version}')
   print(f'Embedding lookup service {self.name} is initialized.')
示例#6
0
    def __init__(
        self,
        region: CloudRegion,
        credentials: Optional[Credentials] = None,
        transport: Optional[str] = None,
        client_options: Optional[ClientOptions] = None,
    ):
        """
        Create a new AdminClient.

        Args:
            region: The cloud region to connect to.
            credentials: The credentials to use when connecting.
            transport: The transport to use.
            client_options: The client options to use when connecting. If used, must explicitly set `api_endpoint`.
                When omitted, options whose ``api_endpoint`` is
                ``regional_endpoint(region)`` are used.
        """
        options = (
            client_options
            if client_options is not None
            else ClientOptions(api_endpoint=regional_endpoint(region))
        )
        service_client = AdminServiceClient(
            client_options=options,
            transport=transport,
            credentials=credentials,
        )
        self._impl = AdminClientImpl(service_client, region)
示例#7
0
def list_datasets(project_id):
    """Lists datasets for the given Google Cloud project.

    Prints the resource name, display name, description and creation time of
    each dataset returned for ``projects/{project_id}``.

    Args:
        project_id: ID of the Google Cloud project whose datasets are listed.
    """
    from google.cloud import datalabeling_v1beta1 as datalabeling

    client = datalabeling.DataLabelingServiceClient()
    # [END datalabeling_list_datasets_beta]
    # If provided, use a provided test endpoint - this will prevent tests on
    # this snippet from triggering any action by a real human
    if "DATALABELING_ENDPOINT" in os.environ:
        opts = ClientOptions(api_endpoint=os.getenv("DATALABELING_ENDPOINT"))
        client = datalabeling.DataLabelingServiceClient(client_options=opts)
    # [START datalabeling_list_datasets_beta]

    formatted_project_name = f"projects/{project_id}"

    response = client.list_datasets(request={"parent": formatted_project_name})
    for element in response:
        # The format of resource name:
        # projects/{project_id}/datasets/{dataset_id}
        print("The dataset resource name: {}\n".format(element.name))
        print("Display name: {}".format(element.display_name))
        print("Description: {}".format(element.description))
        # Convert the creation time once instead of once per printed field.
        create_time = element.create_time.timestamp_pb()
        print("Create time:")
        print("\tseconds: {}".format(create_time.seconds))
        print("\tnanos: {}".format(create_time.nanos))
示例#8
0
def get_cloud_tasks_client():
    """
    Get an instance of a Google CloudTasksClient.

    When ``GAE_ENV`` is ``"standard"`` a default client is returned. Otherwise
    the client is pointed at a local Cloud Tasks emulator over an insecure
    gRPC channel; the emulator address comes from ``TASKS_EMULATOR_HOST`` and
    defaults to ``127.0.0.1:9022``.

    Note. Nested imports are to allow for things not to
    force the google cloud tasks dependency if you're not
    using it
    """
    from google.cloud.tasks import CloudTasksClient

    if os.environ.get("GAE_ENV") == "standard":
        return CloudTasksClient()

    # Running locally, try to connect to the emulator. The transport class
    # lives in a different module depending on the google-cloud-tasks version.
    try:
        # google-cloud-tasks < 2.0.0 has this here
        from google.cloud.tasks_v2.gapic.transports.cloud_tasks_grpc_transport import CloudTasksGrpcTransport
    except ImportError:
        from google.cloud.tasks_v2.services.cloud_tasks.transports.grpc import CloudTasksGrpcTransport

    from google.api_core.client_options import ClientOptions

    emulator_host = os.environ.get("TASKS_EMULATOR_HOST", "127.0.0.1:9022")
    emulator_channel = grpc.insecure_channel(emulator_host)
    return CloudTasksClient(
        transport=CloudTasksGrpcTransport(channel=emulator_channel),
        client_options=ClientOptions(api_endpoint=emulator_host),
    )
示例#9
0
    def get_prediction(self, sent):
        '''
        Run the AutoML model on a single input sentence.

        A new PredictionServiceClient targeting ``automl.googleapis.com`` is
        created on every call.

        Args: sent (string) - input sentence

        Return: the response of ``PredictionServiceClient.predict`` for
        ``self.model_name`` (prediction output)
        '''
        predictor = automl_v1.PredictionServiceClient(
            client_options=ClientOptions(api_endpoint='automl.googleapis.com'))
        payload = self.inline_text_payload(sent)
        # No extra prediction params are sent.
        return predictor.predict(self.model_name, payload, {})
示例#10
0
    def test_instance_admin_api_emulator_env(self, mock_em):
        # mock_em presumably patches the emulator-host lookup; with a host
        # reported, the admin client must get a transport and no credentials.
        from google.api_core.client_options import ClientOptions

        mock_em.return_value = "emulator.host"
        creds = _make_credentials()
        expected_info = mock.Mock()
        expected_options = ClientOptions(api_endpoint="endpoint")
        client = self._make_one(
            project=self.PROJECT,
            credentials=creds,
            client_info=expected_info,
            client_options=expected_options,
        )

        target = "google.cloud.spanner_v1.client.InstanceAdminClient"
        with mock.patch(target) as admin_cls:
            api = client.instance_admin_api

        self.assertIs(api, admin_cls.return_value)
        # A second access must return the cached instance.
        self.assertIs(client.instance_admin_api, api)

        self.assertEqual(len(admin_cls.call_args_list), 1)
        args, kwargs = admin_cls.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs["client_info"], expected_info)
        self.assertEqual(kwargs["client_options"], expected_options)
        self.assertIn("transport", kwargs)
        self.assertNotIn("credentials", kwargs)
示例#11
0
def create_dataset(project_id):
    """Creates a dataset for the given Google Cloud project.

    Args:
        project_id: ID of the Google Cloud project that will own the dataset.

    Returns:
        The dataset returned by ``DataLabelingServiceClient.create_dataset``.
    """
    from google.cloud import datalabeling_v1beta1 as datalabeling
    client = datalabeling.DataLabelingServiceClient()
    # [END datalabeling_create_dataset_beta]
    # If provided, use a provided test endpoint - this will prevent tests on
    # this snippet from triggering any action by a real human
    if 'DATALABELING_ENDPOINT' in os.environ:
        opts = ClientOptions(api_endpoint=os.getenv('DATALABELING_ENDPOINT'))
        client = datalabeling.DataLabelingServiceClient(client_options=opts)
    # [START datalabeling_create_dataset_beta]

    formatted_project_name = client.project_path(project_id)

    dataset = datalabeling.types.Dataset(
        display_name='YOUR_DATASET_SET_DISPLAY_NAME',
        description='YOUR_DESCRIPTION')

    response = client.create_dataset(formatted_project_name, dataset)

    # The format of resource name:
    # projects/{project_id}/datasets/{dataset_id}
    print('The dataset resource name: {}'.format(response.name))
    print('Display name: {}'.format(response.display_name))
    print('Description: {}'.format(response.description))
    print('Create time:')
    print('\tseconds: {}'.format(response.create_time.seconds))
    print('\tnanos: {}\n'.format(response.create_time.nanos))

    return response
示例#12
0
def _create_dummy_storage_client():
    """Build a ``storage.Client`` wired to a fake GCS server (test helper).

    The server host is read from ``STORAGE_PORT_4443_TCP_ADDR``. The URL
    templates in ``storage.blob`` are overwritten so uploads/downloads go to
    that server, TLS verification is disabled, and a bucket named by
    ``_get_bucket_name()`` is created when the server has no buckets.

    Returns:
        The configured ``storage.Client``.
    """
    fake_host = os.getenv('STORAGE_PORT_4443_TCP_ADDR')
    external_url = 'https://{}:4443'.format(fake_host)
    # NOTE: these assignments patch the storage library module-wide.
    storage.blob._API_ACCESS_ENDPOINT = 'https://storage.gcs.{}.nip.io:4443'.format(fake_host)
    storage.blob._DOWNLOAD_URL_TEMPLATE = (
        "%s/download/storage/v1{path}?alt=media" % external_url
    )
    storage.blob._BASE_UPLOAD_TEMPLATE = (
        "%s/upload/storage/v1{bucket_path}/o?uploadType=" % external_url
    )
    storage.blob._MULTIPART_URL_TEMPLATE = storage.blob._BASE_UPLOAD_TEMPLATE + "multipart"
    storage.blob._RESUMABLE_URL_TEMPLATE = storage.blob._BASE_UPLOAD_TEMPLATE + "resumable"
    my_http = requests.Session()
    my_http.verify = False  # disable SSL validation
    urllib3.disable_warnings(
        urllib3.exceptions.InsecureRequestWarning
    )  # disable https warnings for https insecure certs

    storage_client = storage.Client(
        credentials=AnonymousCredentials(),
        project='test',
        _http=my_http,
        client_options=ClientOptions(api_endpoint=external_url))

    # Only need to know whether at least one bucket exists; stop at the first
    # item instead of materializing (and paging through) the whole listing.
    if next(iter(storage_client.list_buckets()), None) is None:
        storage_client.create_bucket(_get_bucket_name())

    return storage_client
示例#13
0
    def test_database_admin_api_emulator_code(self):
        from google.auth.credentials import AnonymousCredentials
        from google.api_core.client_options import ClientOptions

        # Anonymous credentials plus an explicit api_endpoint: the database
        # admin client is expected to get a transport and no credentials.
        anon = AnonymousCredentials()
        info = mock.Mock()
        options = ClientOptions(api_endpoint="emulator.host")
        client = self._make_one(
            project=self.PROJECT,
            credentials=anon,
            client_info=info,
            client_options=options,
        )

        target = "google.cloud.spanner_v1.client.DatabaseAdminClient"
        with mock.patch(target) as admin_cls:
            api = client.database_admin_api

        self.assertIs(api, admin_cls.return_value)
        # The API object is cached: a second access returns the same object.
        self.assertIs(client.database_admin_api, api)

        self.assertEqual(len(admin_cls.call_args_list), 1)
        args, kwargs = admin_cls.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs["client_info"], info)
        self.assertEqual(kwargs["client_options"], options)
        self.assertIn("transport", kwargs)
        self.assertNotIn("credentials", kwargs)
示例#14
0
def label_video(dataset_resource_name, instruction_resource_name,
                annotation_spec_set_resource_name):
    """Labels a video dataset.

    Starts an OBJECT_TRACKING human labeling request for the dataset and
    prints the name of the returned operation.

    Args:
        dataset_resource_name: Resource name of the dataset to label.
        instruction_resource_name: Resource name of the labeling instruction.
        annotation_spec_set_resource_name: Resource name of the annotation
            spec set used by the object tracking config.

    Returns:
        The value returned by ``client.label_video``.
    """
    from google.cloud import datalabeling_v1beta1 as datalabeling
    client = datalabeling.DataLabelingServiceClient()
    # [END datalabeling_label_video_beta]
    # If provided, use a provided test endpoint - this will prevent tests on
    # this snippet from triggering any action by a real human
    if 'DATALABELING_ENDPOINT' in os.environ:
        opts = ClientOptions(api_endpoint=os.getenv('DATALABELING_ENDPOINT'))
        client = datalabeling.DataLabelingServiceClient(client_options=opts)
    # [START datalabeling_label_video_beta]

    human_config = datalabeling.types.HumanAnnotationConfig(
        instruction=instruction_resource_name,
        annotated_dataset_display_name='YOUR_ANNOTATED_DATASET_DISPLAY_NAME',
        label_group='YOUR_LABEL_GROUP',
        replica_count=1,
    )
    tracking_feature = (
        datalabeling.enums.LabelVideoRequest.Feature.OBJECT_TRACKING)
    tracking_config = datalabeling.types.ObjectTrackingConfig(
        annotation_spec_set=annotation_spec_set_resource_name)

    operation = client.label_video(
        dataset_resource_name,
        human_config,
        tracking_feature,
        object_tracking_config=tracking_config,
    )

    print('Label_video operation name: {}'.format(operation.operation.name))
    return operation
示例#15
0
def export_data(dataset_resource_name, annotated_dataset_resource_name,
                export_gcs_uri):
    """Exports a dataset from the given Google Cloud project.

    Requests a CSV export of the annotated dataset to ``export_gcs_uri``,
    blocks on the operation's ``result()`` and prints a summary.

    Args:
        dataset_resource_name: Resource name of the dataset to export.
        annotated_dataset_resource_name: Resource name of the annotated
            dataset whose labels are exported.
        export_gcs_uri: GCS URI the CSV output is written to.

    Returns:
        The result of the export operation.
    """
    from google.cloud import datalabeling_v1beta1 as datalabeling

    client = datalabeling.DataLabelingServiceClient()
    # [END datalabeling_export_data_beta]
    # If provided, use a provided test endpoint - this will prevent tests on
    # this snippet from triggering any action by a real human
    if "DATALABELING_ENDPOINT" in os.environ:
        opts = ClientOptions(api_endpoint=os.getenv("DATALABELING_ENDPOINT"))
        client = datalabeling.DataLabelingServiceClient(client_options=opts)
    # [START datalabeling_export_data_beta]

    gcs_destination = datalabeling.GcsDestination(output_uri=export_gcs_uri,
                                                  mime_type="text/csv")

    output_config = datalabeling.OutputConfig(gcs_destination=gcs_destination)

    response = client.export_data(
        request={
            "name": dataset_resource_name,
            "annotated_dataset": annotated_dataset_resource_name,
            "output_config": output_config,
        })

    # Wait for the operation once and reuse its result.
    result = response.result()

    print("Dataset ID: {}\n".format(result.dataset))
    print("Output config:")
    print("\tGcs destination:")
    print("\t\tOutput URI: {}\n".format(
        result.output_config.gcs_destination.output_uri))

    return result
示例#16
0
def create_storage_client(test):
    """Return a ``storage.Client``, wired to a local fake GCS server in test mode.

    Args:
        test: When truthy, the ``storage.blob`` URL templates are repointed at
            ``https://127.0.0.1:4443`` and an anonymous client that skips TLS
            verification is built. Otherwise a default ``storage.Client()``
            is returned.
    """
    if not test:
        return storage.Client()

    external_url = "https://127.0.0.1:4443"
    public_host = "storage.gcs.127.0.0.1.nip.io:4443"
    upload_base = u"%s/upload/storage/v1{bucket_path}/o?uploadType=" % external_url

    storage.blob._API_ACCESS_ENDPOINT = "https://" + public_host
    storage.blob._DOWNLOAD_URL_TEMPLATE = (
        u"%s/download/storage/v1{path}?alt=media" % external_url)
    storage.blob._BASE_UPLOAD_TEMPLATE = upload_base
    storage.blob._MULTIPART_URL_TEMPLATE = upload_base + u"multipart"
    storage.blob._RESUMABLE_URL_TEMPLATE = upload_base + u"resumable"

    session = requests.Session()
    session.verify = False  # disable SSL validation
    # disable https warnings for https insecure certs
    urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

    return storage.Client(
        credentials=AnonymousCredentials(),
        project="test",
        _http=session,
        client_options=ClientOptions(api_endpoint=external_url),
    )
示例#17
0
def stock_tweet_classifier(tweet_string):
    """Passes tweet into trained AutoML model, outputs classification on whether it is stock-related.

    Args:
        tweet_string: Text of the tweet to classify.

    Returns:
        True if the first annotation in the prediction payload has display
        name ``'stock'``; False otherwise, including when the model returns no
        annotations.
    """
    options = ClientOptions(api_endpoint='automl.googleapis.com')
    model_name = 'projects/313817029040/locations/us-central1/models/TCN8645127876691099648'
    # NOTE(review): the key file path is resolved relative to the working directory.
    credentials = service_account.Credentials.from_service_account_file(
        'AutoMLAuth.json')
    prediction_client = automl_v1.PredictionServiceClient(
        client_options=options, credentials=credentials)

    text_snip = {
        'text_snippet': {
            'content': tweet_string,
            'mime_type': 'text/plain'
        }
    }
    payload = automl_v1.ExamplePayload(text_snip)
    response = prediction_client.predict(name=model_name, payload=payload)

    # An empty payload previously raised IndexError; treat it as "not stock".
    if not response.payload:
        return False
    return response.payload[0].display_name == 'stock'
def classify_doc(bucket, filename):
    """Classify a GCS document with the AutoML sort model.

    Args:
        bucket: Name of the GCS bucket holding the document.
        filename: Object name of the document inside ``bucket``.

    Returns:
        The display name of the highest-scoring label when its score exceeds
        the threshold (``SORT_MODEL_THRESHOLD`` env var, default 0.7), else
        None. Also None for unsupported file types, a missing payload, or a
        prediction with no labels.
    """
    _, ext = os.path.splitext(filename)
    # splitext keeps the leading dot (".txt"), so the allowed lists must too;
    # compare case-insensitively so ".PDF" behaves like ".pdf".
    normalized_ext = ext.lower()
    if normalized_ext in (".pdf", ".txt", ".html"):
        payload = _gcs_payload(bucket, filename)
    elif normalized_ext in (".tif", ".tiff", ".png", ".jpeg", ".jpg"):
        payload = _img_payload(bucket, filename)
    else:
        print(
            f"Could not sort document gs://{bucket}/{filename}, unsupported file type {ext}")
        return None
    if not payload:
        print(
            f"Missing document gs://{bucket}/{filename} payload, cannot sort")
        return None

    # Only build the client once there is something to classify.
    options = ClientOptions(api_endpoint='automl.googleapis.com')
    prediction_client = automl_v1.PredictionServiceClient(
        client_options=options)
    response = prediction_client.predict(
        os.environ["SORT_MODEL_NAME"], payload, {})
    if not response.payload:
        print(f"No labels returned for document gs://{bucket}/{filename}, cannot sort")
        return None
    label = max(response.payload, key=lambda x: x.classification.score)
    # Unset/empty env var falls back to 0.7 (float(None) used to raise);
    # an explicit "0" is now honoured instead of silently becoming 0.7.
    threshold = float(os.environ.get('SORT_MODEL_THRESHOLD') or 0.7)
    display_name = label.display_name if label.classification.score > threshold else None
    print(f"Labeled document gs://{bucket}/{filename} as {display_name}")
    return display_name
示例#19
0
def import_data(dataset_resource_name, data_type, input_gcs_uri):
    """Imports data to the given Google Cloud project and dataset.

    Args:
        dataset_resource_name: Resource name of the dataset to import into.
        data_type: Data type passed to ``InputConfig``.
        input_gcs_uri: GCS URI of the CSV file to import.

    Returns:
        The result of the import operation (obtained via ``response.result()``).
    """
    from google.cloud import datalabeling_v1beta1 as datalabeling
    client = datalabeling.DataLabelingServiceClient()
    # [END datalabeling_import_data_beta]
    # If provided, use a provided test endpoint - this will prevent tests on
    # this snippet from triggering any action by a real human
    if 'DATALABELING_ENDPOINT' in os.environ:
        opts = ClientOptions(api_endpoint=os.getenv('DATALABELING_ENDPOINT'))
        client = datalabeling.DataLabelingServiceClient(client_options=opts)
    # [START datalabeling_import_data_beta]

    gcs_source = datalabeling.types.GcsSource(input_uri=input_gcs_uri,
                                              mime_type='text/csv')

    csv_input_config = datalabeling.types.InputConfig(data_type=data_type,
                                                      gcs_source=gcs_source)

    response = client.import_data(dataset_resource_name, csv_input_config)

    result = response.result()

    # The format of resource name:
    # projects/{project_id}/datasets/{dataset_id}
    print('Dataset resource name: {}\n'.format(result.dataset))

    return result
示例#20
0
def label_image(dataset_resource_name, instruction_resource_name,
                annotation_spec_set_resource_name):
    """Labels an image dataset.

    Starts a CLASSIFICATION human labeling request (single label per image,
    answers aggregated by MAJORITY_VOTE) and prints the operation name.

    Args:
        dataset_resource_name: Resource name of the dataset to label.
        instruction_resource_name: Resource name of the labeling instruction.
        annotation_spec_set_resource_name: Resource name of the annotation
            spec set used for classification.

    Returns:
        The value returned by ``client.label_image``.
    """
    from google.cloud import datalabeling_v1beta1 as datalabeling
    client = datalabeling.DataLabelingServiceClient()
    # [END datalabeling_label_image_beta]
    # If provided, use a provided test endpoint - this will prevent tests on
    # this snippet from triggering any action by a real human
    if 'DATALABELING_ENDPOINT' in os.environ:
        opts = ClientOptions(api_endpoint=os.getenv('DATALABELING_ENDPOINT'))
        client = datalabeling.DataLabelingServiceClient(client_options=opts)
    # [START datalabeling_label_image_beta]

    enums = datalabeling.enums
    human_config = datalabeling.types.HumanAnnotationConfig(
        instruction=instruction_resource_name,
        annotated_dataset_display_name='YOUR_ANNOTATED_DATASET_DISPLAY_NAME',
        label_group='YOUR_LABEL_GROUP',
        replica_count=1,
    )
    classification_feature = enums.LabelImageRequest.Feature.CLASSIFICATION
    classification_config = datalabeling.types.ImageClassificationConfig(
        annotation_spec_set=annotation_spec_set_resource_name,
        allow_multi_label=False,
        answer_aggregation_type=enums.StringAggregationType.MAJORITY_VOTE,
    )

    operation = client.label_image(
        dataset_resource_name,
        human_config,
        classification_feature,
        image_classification_config=classification_config,
    )

    print('Label_image operation name: {}'.format(operation.operation.name))
    return operation
示例#21
0
    def test_database_admin_api(self, mock_em):
        from google.cloud.spanner_v1.client import SPANNER_ADMIN_SCOPE
        from google.api_core.client_options import ClientOptions

        # No emulator host: the credentials must be re-scoped to the admin
        # scope and passed through to DatabaseAdminClient.
        mock_em.return_value = None
        creds = _make_credentials()
        info = mock.Mock()
        options = ClientOptions(quota_project_id="QUOTA-PROJECT")
        client = self._make_one(
            project=self.PROJECT,
            credentials=creds,
            client_info=info,
            client_options=options,
        )

        target = "google.cloud.spanner_v1.client.DatabaseAdminClient"
        with mock.patch(target) as admin_cls:
            api = client.database_admin_api

        self.assertIs(api, admin_cls.return_value)
        # Cached: a second access returns the very same object.
        self.assertIs(client.database_admin_api, api)

        admin_cls.assert_called_once_with(
            credentials=mock.ANY, client_info=info, client_options=options
        )
        creds.with_scopes.assert_called_once_with((SPANNER_ADMIN_SCOPE,))
示例#22
0
    def test_constructor_w_explicit_inputs(self):
        # Every explicit constructor argument must be stored on the client.
        from google.api_core.client_options import ClientOptions

        other = "other"
        namespace = "namespace"
        creds = _make_credentials()
        client_info = mock.Mock()
        client_options = ClientOptions(api_endpoint="endpoint")
        http = object()
        client = self._make_one(
            project=other,
            namespace=namespace,
            credentials=creds,
            client_info=client_info,
            client_options=client_options,
            _http=http,
        )
        self.assertEqual(client.project, other)
        self.assertEqual(client.namespace, namespace)
        self.assertIs(client._credentials, creds)
        self.assertIs(client._client_info, client_info)
        self.assertIs(client._http_internal, http)
        self.assertIsNone(client.current_batch)
        # Compare strings by value: identity of equal strings depends on
        # interning and breaks if the client copies/rebuilds the URL.
        self.assertEqual(client._base_url, "endpoint")
        self.assertEqual(list(client._batch_stack), [])
示例#23
0
    def get_dataplex_client(self) -> DataplexServiceClient:
        """Return a DataplexServiceClient targeting ``dataplex.googleapis.com:443``."""
        endpoint_options = ClientOptions(api_endpoint='dataplex.googleapis.com:443')
        return DataplexServiceClient(
            credentials=self._get_credentials(),
            client_info=self.client_info,
            client_options=endpoint_options,
        )
示例#24
0
    def get_dataproc_metastore_client(self) -> DataprocMetastoreClient:
        """Return a DataprocMetastoreClient targeting ``metastore.googleapis.com:443``."""
        endpoint_options = ClientOptions(api_endpoint='metastore.googleapis.com:443')
        return DataprocMetastoreClient(
            credentials=self.get_credentials(),
            client_info=CLIENT_INFO,
            client_options=endpoint_options,
        )
    def test_ctor_w_empty_client_options(self):
        from google.api_core.client_options import ClientOptions

        # An empty ClientOptions must leave the connection on its default endpoint.
        client = self._make_one(_http=object(), client_options=ClientOptions())
        connection = client._connection
        self.assertEqual(connection.API_BASE_URL,
                         connection.DEFAULT_API_ENDPOINT)
def automl_create_dataset_for_nlp(
  gcp_project_id: str,
  gcp_region: str,
  dataset_display_name: str,
  api_endpoint: str = None,
) -> NamedTuple('Outputs', [('dataset_path', str), ('dataset_status', str), ('dataset_id', str)]):
  """Create (or reuse) an AutoML text-classification dataset.

  Installs pinned client libraries at run time, then scans the datasets in
  ``gcp_project_id``/``gcp_region`` for one whose display name equals
  ``dataset_display_name``. A matching dataset with zero examples is returned
  as-is with status ``'existed but empty'``. Otherwise a new MULTICLASS
  dataset is created; its status is ``'created but existed'`` if a non-empty
  match was seen, else ``'created'``.

  Args:
    gcp_project_id: Google Cloud project ID.
    gcp_region: Region used to build the AutoML location path.
    dataset_display_name: Display name to look up / create.
    api_endpoint: Optional AutoML API endpoint override.

  Returns:
    Tuple of (dataset resource name, status string, dataset ID).

  Raises:
    google.api_core.exceptions.GoogleAPICallError: if dataset creation fails.
  """
  # All imports are inside the function, presumably so it stays self-contained.
  import os
  import sys
  import subprocess
  # Extend (rather than replace) the current environment so PATH, HOME,
  # proxy settings etc. remain visible to pip.
  pip_env = dict(os.environ, PIP_DISABLE_PIP_VERSION_CHECK='1')
  subprocess.run([sys.executable, '-m', 'pip', 'install', 'googleapis-common-protos==1.6.0',
      '--no-warn-script-location'],
      env=pip_env, check=True)
  subprocess.run([sys.executable, '-m', 'pip', 'install', 'google-cloud-automl==0.9.0',
      '--quiet', '--no-warn-script-location'],
      env=pip_env, check=True)

  import logging
  from google.api_core.client_options import ClientOptions
  from google.api_core.exceptions import GoogleAPICallError
  from google.cloud import automl

  logging.getLogger().setLevel(logging.INFO)  # TODO: make level configurable

  if api_endpoint:
    client_options = ClientOptions(api_endpoint=api_endpoint)
    client = automl.AutoMlClient(client_options=client_options)
  else:
    client = automl.AutoMlClient()

  status = 'created'
  project_location = client.location_path(gcp_project_id, gcp_region)
  # Check whether a dataset with this display name already exists.
  for element in client.list_datasets(project_location):
    if element.display_name == dataset_display_name:
      status = 'created but existed'
      if element.example_count == 0:
        # Reuse an existing empty dataset instead of creating another one.
        status = 'existed but empty'
        return (element.name, status, element.name.rsplit('/', 1)[-1])
  try:
    metadata = automl.types.TextClassificationDatasetMetadata(classification_type=automl.enums.ClassificationType.MULTICLASS)
    dataset = automl.types.Dataset(display_name=dataset_display_name, text_classification_dataset_metadata=metadata,)
    # Create a dataset with the given display name
    response = client.create_dataset(project_location, dataset)
    created_dataset = response.result()
  except GoogleAPICallError as e:
    logging.warning(e)
    raise

  # Log the dataset returned by the API, not the local request object
  # (whose example_count/create_time were never populated).
  dataset_id = created_dataset.name.rsplit('/', 1)[-1]
  logging.info("Dataset name: {}".format(created_dataset.name))
  logging.info("Dataset id: {}".format(dataset_id))
  logging.info("Dataset display name: {}".format(created_dataset.display_name))
  logging.info("Dataset example count: {}".format(created_dataset.example_count))
  logging.info("Dataset create time:")
  logging.info("\tseconds: {}".format(created_dataset.create_time.seconds))
  logging.info("\tnanos: {}".format(created_dataset.create_time.nanos))

  return (created_dataset.name, status, dataset_id)
def make_async_subscriber(
    subscription: SubscriptionPath,
    transport: str,
    per_partition_flow_control_settings: FlowControlSettings,
    nack_handler: Optional[NackHandler] = None,
    message_transformer: Optional[MessageTransformer] = None,
    fixed_partitions: Optional[Set[Partition]] = None,
    credentials: Optional[Credentials] = None,
    client_options: Optional[ClientOptions] = None,
    metadata: Optional[Mapping[str, str]] = None,
) -> AsyncSingleSubscriber:
    """
    Make a Pub/Sub Lite AsyncSubscriber.

    Partitions come either from ``fixed_partitions`` or from a dynamic
    assigner; each assigned partition gets its own subscriber built by the
    partition subscriber factory.

    Args:
      subscription: The subscription to subscribe to.
      transport: The transport type to use.
      per_partition_flow_control_settings: The flow control settings for each partition subscribed to. Note that these
        settings apply to each partition individually, not in aggregate.
      nack_handler: An optional handler for when nack() is called on a Message. The default will fail the client.
      message_transformer: An optional transformer from Pub/Sub Lite messages to Cloud Pub/Sub messages.
      fixed_partitions: A fixed set of partitions to subscribe to. If not present, will instead use auto-assignment.
      credentials: The credentials to use to connect. GOOGLE_DEFAULT_CREDENTIALS is used if None.
      client_options: Other options to pass to the client. Note that if you pass any you must set api_endpoint.
        When omitted, the regional endpoint of the subscription's region is used.
      metadata: Additional metadata to send with the RPC.

    Returns:
      A new AsyncSubscriber.
    """
    metadata = merge_metadata(pubsub_context(framework="CLOUD_PUBSUB_SHIM"), metadata)
    if client_options is None:
        client_options = ClientOptions(
            api_endpoint=regional_endpoint(subscription.location.region)
        )

    # The factories read the (final) values of the enclosing names when called.
    if fixed_partitions:
        def assigner_factory() -> Assigner:
            return FixedSetAssigner(fixed_partitions)
    else:
        def assigner_factory() -> Assigner:
            return _make_dynamic_assigner(
                subscription, transport, client_options, credentials, metadata,
            )

    nack_handler = DefaultNackHandler() if nack_handler is None else nack_handler
    if message_transformer is None:
        message_transformer = MessageTransformer.of_callable(to_cps_subscribe_message)

    partition_subscriber_factory = _make_partition_subscriber_factory(
        subscription,
        transport,
        client_options,
        credentials,
        metadata,
        per_partition_flow_control_settings,
        nack_handler,
        message_transformer,
    )
    return AssigningSingleSubscriber(assigner_factory, partition_subscriber_factory)
示例#28
0
    def get_batch_client(self, region: Optional[str] = None) -> BatchControllerClient:
        """Return a BatchControllerClient.

        :param region: Dataproc region. Unless it is empty or ``'global'``, the
            client targets ``{region}-dataproc.googleapis.com:443``; otherwise
            ``client_options`` is ``None`` and the library default applies.
        """
        use_default_endpoint = not region or region == 'global'
        client_options = (
            None
            if use_default_endpoint
            else ClientOptions(api_endpoint=f'{region}-dataproc.googleapis.com:443')
        )
        return BatchControllerClient(
            credentials=self._get_credentials(),
            client_info=CLIENT_INFO,
            client_options=client_options,
        )
示例#29
0
    def get_template_client(self, region: Optional[str] = None) -> WorkflowTemplateServiceClient:
        """Return a WorkflowTemplateServiceClient.

        :param region: Dataproc region. A non-empty region other than
            ``'global'`` selects ``{region}-dataproc.googleapis.com:443``;
            otherwise no client options are passed.
        """
        if not region or region == 'global':
            options = None
        else:
            options = ClientOptions(api_endpoint=f'{region}-dataproc.googleapis.com:443')
        return WorkflowTemplateServiceClient(
            credentials=self._get_credentials(),
            client_info=CLIENT_INFO,
            client_options=options,
        )
示例#30
0
def predict(input, model_name):
  """Send one plain-text snippet to an AutoML model via the EU endpoint.

  Args:
    input: Text content to send as the snippet.
    model_name: Full resource name of the AutoML model.

  Returns:
    The response of ``PredictionServiceClient.predict``.
  """
  client = automl_v1.PredictionServiceClient(
      client_options=ClientOptions(api_endpoint='eu-automl.googleapis.com'))
  request = automl_v1.PredictRequest(
      name=model_name,
      payload={'text_snippet': {'content': input, 'mime_type': 'text/plain'}},
      params={},
  )
  return client.predict(request)