Example #1
0
def construct_clients_from_provider_config(provider_config):
    """
    Attempt to fetch and parse the JSON GCP credentials from the provider
    config yaml file.

    tpu resource (the last element of the tuple) will be None if
    `_has_tpus` in provider config is not set or False.

    Args:
        provider_config: Mapping holding the provider section of the cluster
            config. The optional "gcp_credentials" entry must contain a
            "type" ("service_account" or "credentials_token") and a
            "credentials" field.

    Returns:
        A 4-tuple ``(crm, iam, compute, tpu)`` built by ``_create_crm``,
        ``_create_iam``, ``_create_compute`` and ``_create_tpu``; ``tpu`` is
        None when the TPU provider flag is unset or falsy.

    Raises:
        AssertionError: If "gcp_credentials" lacks "type" or "credentials".
        RuntimeError: If service-account credentials are not valid JSON.
        ValueError: If the credentials type is not a supported one.
    """
    gcp_credentials = provider_config.get("gcp_credentials")
    if gcp_credentials is None:
        logger.debug("gcp_credentials not found in cluster yaml file. "
                     "Falling back to GOOGLE_APPLICATION_CREDENTIALS "
                     "environment variable.")
        tpu_resource = (_create_tpu() if provider_config.get(
            HAS_TPU_PROVIDER_FIELD, False) else None)
        # If gcp_credentials is None, then discovery.build will search for
        # credentials in the local environment.
        return _create_crm(), _create_iam(), _create_compute(), tpu_resource

    assert ("type" in gcp_credentials
            ), "gcp_credentials cluster yaml field missing 'type' field."
    assert (
        "credentials" in gcp_credentials
    ), "gcp_credentials cluster yaml field missing 'credentials' field."

    cred_type = gcp_credentials["type"]
    credentials_field = gcp_credentials["credentials"]

    if cred_type == "service_account":
        # If parsing the gcp_credentials failed, then the user likely made a
        # mistake in copying the credentials into the config yaml.
        try:
            service_account_info = json.loads(credentials_field)
        except json.decoder.JSONDecodeError as err:
            # Chain the parser error so the offending position stays visible.
            raise RuntimeError(
                "gcp_credentials found in cluster yaml file but "
                "formatted improperly.") from err
        credentials = service_account.Credentials.from_service_account_info(
            service_account_info)
    elif cred_type == "credentials_token":
        credentials = OAuthCredentials(credentials_field)
    else:
        # Fail here with a clear message instead of letting `credentials`
        # stay unbound and crash further down with an UnboundLocalError.
        raise ValueError(
            "gcp_credentials cluster yaml field has unsupported 'type' "
            f"{cred_type!r}; expected 'service_account' or "
            "'credentials_token'.")

    tpu_resource = (_create_tpu(credentials) if provider_config.get(
        HAS_TPU_PROVIDER_FIELD, False) else None)

    return (
        _create_crm(credentials),
        _create_iam(credentials),
        _create_compute(credentials),
        tpu_resource,
    )
Example #2
0
def grpc_create_channel(settings: HTTPConnectionSettings) -> Channel:
    """Open a gRPC channel to ``settings.host:settings.port``.

    The channel kind is chosen from the settings:

    * ``settings.oauth`` set: an OAuth-authorized secure channel. If
      ``settings.ssl_settings`` is also set, its certificates are used for
      the TLS layer instead of being ignored.
    * only ``settings.ssl_settings`` set: a TLS channel using the configured
      CA, private key and certificate chain.
    * neither set: an insecure channel.

    :param settings: Connection settings providing ``host``, ``port``,
        ``oauth`` and ``ssl_settings``.
    :return: The created gRPC channel.
    """
    target = f'{settings.host}:{settings.port}'
    # Both message-size limits are set to -1 (unlimited per the gRPC
    # channel-argument documentation).
    options = [('grpc.max_send_message_length', -1),
               ('grpc.max_receive_message_length', -1)]

    # Load the TLS material once so that both the OAuth and the plain TLS
    # paths honour a custom CA / client certificate.
    ssl_credentials = None
    if settings.ssl_settings:
        cert_chain = read_file_bytes(settings.ssl_settings.cert_file)
        cert = read_file_bytes(settings.ssl_settings.cert_key_file)
        ca_cert = read_file_bytes(settings.ssl_settings.ca_file)

        LOG.debug('Using a secure gRPC connection:')
        LOG.debug('    target: %s', target)
        LOG.debug('    root_certificates: contents of %s',
                  settings.ssl_settings.ca_file)
        LOG.debug('    private_key: contents of %s',
                  settings.ssl_settings.cert_key_file)
        LOG.debug('    certificate_chain: contents of %s',
                  settings.ssl_settings.cert_file)

        ssl_credentials = ssl_channel_credentials(root_certificates=ca_cert,
                                                  private_key=cert,
                                                  certificate_chain=cert_chain)

    if settings.oauth:
        # Imported lazily so google-auth is only required when OAuth is used.
        # noinspection PyPackageRequirements
        from google.auth.transport.grpc import secure_authorized_channel
        # noinspection PyPackageRequirements
        from google.auth.transport.requests import Request as RefreshRequester
        # noinspection PyPackageRequirements
        from google.oauth2.credentials import Credentials as OAuthCredentials

        LOG.debug('Using a secure gRPC connection over OAuth:')

        credentials = OAuthCredentials(
            token=settings.oauth.token,
            refresh_token=settings.oauth.refresh_token,
            id_token=settings.oauth.id_token,
            token_uri=settings.oauth.token_uri,
            client_id=settings.oauth.client_id,
            client_secret=settings.oauth.client_secret)
        # ssl_credentials=None keeps the library's default TLS credentials.
        return secure_authorized_channel(credentials,
                                         RefreshRequester(),
                                         target,
                                         ssl_credentials=ssl_credentials,
                                         options=options)

    if ssl_credentials is not None:
        return secure_channel(target, ssl_credentials, options)

    LOG.debug('Using an insecure gRPC connection...')
    return insecure_channel(target, options)
Example #3
0
def _load_ssl_channel_credentials(ssl_settings, target):
    """Read the TLS files named by *ssl_settings* and build channel credentials.

    :param ssl_settings: Object exposing ``cert_file``, ``cert_key_file`` and
        ``ca_file`` paths.
    :param target: The ``host:port`` string; used for debug logging only.
    :return: The result of ``ssl_channel_credentials`` for the loaded files.
    """
    cert_chain = read_file_bytes(ssl_settings.cert_file)
    cert = read_file_bytes(ssl_settings.cert_key_file)
    ca_cert = read_file_bytes(ssl_settings.ca_file)

    # Only the file names are logged, never the key material itself.
    LOG.debug("Using a secure gRPC connection:")
    LOG.debug("    target: %s", target)
    LOG.debug("    root_certificates: contents of %s", ssl_settings.ca_file)
    LOG.debug("    private_key: contents of %s", ssl_settings.cert_key_file)
    LOG.debug("    certificate_chain: contents of %s", ssl_settings.cert_file)

    return ssl_channel_credentials(
        root_certificates=ca_cert,
        private_key=cert,
        certificate_chain=cert_chain,
    )


def grpc_create_channel(settings: "HTTPConnectionSettings") -> Channel:
    """Open a gRPC channel to ``settings.host:settings.port``.

    The channel kind is chosen from the settings:

    * ``settings.oauth`` set: an OAuth-authorized secure channel, using the
      certificates from ``settings.ssl_settings`` when those are configured.
    * only ``settings.ssl_settings`` set: a TLS channel using the configured
      CA, private key and certificate chain.
    * neither set: an insecure channel.

    When ``settings.enable_http_proxy`` is falsy, the
    ``grpc.enable_http_proxy`` channel option is set to 0.

    :param settings: Connection settings providing ``host``, ``port``,
        ``enable_http_proxy``, ``oauth`` and ``ssl_settings``.
    :return: The created gRPC channel.
    """
    target = f"{settings.host}:{settings.port}"
    # Both message-size limits are set to -1 (unlimited per the gRPC
    # channel-argument documentation).
    options = [("grpc.max_send_message_length", -1),
               ("grpc.max_receive_message_length", -1)]

    if not settings.enable_http_proxy:
        options.append(("grpc.enable_http_proxy", 0))

    if settings.oauth:
        # Imported lazily so google-auth is only required when OAuth is used.
        from google.auth.transport.grpc import secure_authorized_channel
        from google.auth.transport.requests import Request as RefreshRequester
        from google.oauth2.credentials import Credentials as OAuthCredentials

        LOG.debug("Using a secure gRPC connection over OAuth:")

        credentials = OAuthCredentials(
            token=settings.oauth.token,
            refresh_token=settings.oauth.refresh_token,
            id_token=settings.oauth.id_token,
            token_uri=settings.oauth.token_uri,
            client_id=settings.oauth.client_id,
            client_secret=settings.oauth.client_secret,
        )

        # None keeps the library's default TLS credentials.
        ssl_credentials = None
        if settings.ssl_settings:
            ssl_credentials = _load_ssl_channel_credentials(
                settings.ssl_settings, target)
        return secure_authorized_channel(
            credentials,
            RefreshRequester(),
            target,
            ssl_credentials=ssl_credentials,
            options=options,
        )

    if settings.ssl_settings:
        credentials = _load_ssl_channel_credentials(
            settings.ssl_settings, target)
        return secure_channel(target, credentials, options)

    LOG.debug("Using an insecure gRPC connection...")
    return insecure_channel(target, options)
Example #4
0
 async def async_get_creds(self) -> Credentials:
     """Return creds for subscriber API.

     Awaits ``self.async_get_access_token()`` and wraps the resulting token
     in an ``OAuthCredentials`` object.
     """
     access_token = await self.async_get_access_token()
     creds = OAuthCredentials(token=access_token)
     return creds