Example #1
0
    def _get_channel(self,
                     config: metadata_store_pb2.MetadataStoreClientConfig):
        """Configures the channel, which could be secure or insecure.

        It returns a channel that can be specified to be secure or insecure,
        depending on whether ssl_config is specified in the config.

        Args:
          config: metadata_store_pb2.MetadataStoreClientConfig. Its `host` and
            `port` fields are joined as `host:port` to form the channel target.

        Returns:
          an initialized gRPC channel.
        """
        target = ':'.join([config.host, str(config.port)])

        if not config.HasField('ssl_config'):
            return grpc.insecure_channel(target)

        ssl_config = config.ssl_config

        def _pem_bytes(field_name):
            # Returns the named ssl_config field as bytes, or None when unset.
            if not ssl_config.HasField(field_name):
                return None
            value = getattr(ssl_config, field_name)
            # grpc.ssl_channel_credentials only accepts `bytes` (or None):
            # text PEM data is ASCII-encoded, bytes are passed through as-is
            # instead of being turned into their `str()` repr.
            if isinstance(value, bytes):
                return value
            return value.encode('ascii')

        # NOTE(review): `server_cert` is supplied as gRPC's `certificate_chain`
        # argument (same mapping as before); confirm this is the intended field.
        credentials = grpc.ssl_channel_credentials(
            root_certificates=_pem_bytes('custom_ca'),
            private_key=_pem_bytes('client_key'),
            certificate_chain=_pem_bytes('server_cert'))
        return grpc.secure_channel(target, credentials)
Example #2
0
    def _get_channel(self,
                     config: metadata_store_pb2.MetadataStoreClientConfig,
                     target: Text):
        """Configures the channel, which could be secure or insecure.

        It returns a channel that can be specified to be secure or insecure,
        depending on whether ssl_config is specified in the config.

        Args:
          config: metadata_store_pb2.MetadataStoreClientConfig.
          target: target host with port.

        Returns:
          an initialized gRPC channel.
        """
        if not config.HasField('ssl_config'):
            return grpc.insecure_channel(target)

        ssl_config = config.ssl_config

        def _pem_bytes(field_name):
            # Returns the named ssl_config field as bytes, or None when unset.
            if not ssl_config.HasField(field_name):
                return None
            value = getattr(ssl_config, field_name)
            # grpc.ssl_channel_credentials rejects `str` with a TypeError and
            # only accepts `bytes` (or None), so text PEM data must be encoded
            # before it is handed over.
            if isinstance(value, bytes):
                return value
            return value.encode('ascii')

        # NOTE(review): `server_cert` is supplied as gRPC's `certificate_chain`
        # argument (same mapping as before); confirm this is the intended field.
        credentials = grpc.ssl_channel_credentials(
            root_certificates=_pem_bytes('custom_ca'),
            private_key=_pem_bytes('client_key'),
            certificate_chain=_pem_bytes('server_cert'))
        return grpc.secure_channel(target, credentials)
Example #3
0
    def _get_channel(self,
                     config: metadata_store_pb2.MetadataStoreClientConfig):
        """Configures the channel, which could be secure or insecure.

        It returns a channel that can be specified to be secure or insecure,
        depending on whether ssl_config is specified in the config. As a side
        effect, `self._grpc_timeout_sec` is set from `config.client_timeout_sec`
        when that field is present.

        Args:
          config: metadata_store_pb2.MetadataStoreClientConfig. Its `host` and
            `port` fields are joined as `host:port` to form the channel target;
            `channel_arguments.max_receive_message_length`, when set, is passed
            to the channel as the `grpc.max_receive_message_length` option.

        Returns:
          an initialized gRPC channel.
        """
        target = ':'.join([config.host, str(config.port)])

        if config.HasField('client_timeout_sec'):
            self._grpc_timeout_sec = config.client_timeout_sec

        # None leaves gRPC's channel options at their defaults.
        options = None
        if (config.HasField('channel_arguments')
                and config.channel_arguments.HasField(
                    'max_receive_message_length')):
            options = [('grpc.max_receive_message_length',
                        config.channel_arguments.max_receive_message_length)]

        if not config.HasField('ssl_config'):
            return grpc.insecure_channel(target, options=options)

        ssl_config = config.ssl_config

        def _pem_bytes(field_name):
            # Returns the named ssl_config field as bytes, or None when unset.
            if not ssl_config.HasField(field_name):
                return None
            value = getattr(ssl_config, field_name)
            # grpc.ssl_channel_credentials only accepts `bytes` (or None):
            # text PEM data is ASCII-encoded, bytes are passed through as-is
            # instead of being turned into their `str()` repr.
            if isinstance(value, bytes):
                return value
            return value.encode('ascii')

        # NOTE(review): `server_cert` is supplied as gRPC's `certificate_chain`
        # argument (same mapping as before); confirm this is the intended field.
        credentials = grpc.ssl_channel_credentials(
            root_certificates=_pem_bytes('custom_ca'),
            private_key=_pem_bytes('client_key'),
            certificate_chain=_pem_bytes('server_cert'))
        return grpc.secure_channel(target, credentials, options=options)
Example #4
0
    def _connect(*, max_attempts=100, retry_interval_sec=1):
        """Connects to the Metadata (MLMD) store and verifies it is reachable.

        The gRPC host, port and maximum receive message length are read from
        environment variables, falling back to the module-level defaults. The
        connection is verified by issuing `get_context_types()`; each failed
        attempt is logged and retried.

        Args:
          max_attempts: maximum number of verification requests to make.
          retry_interval_sec: seconds to sleep between consecutive attempts.

        Returns:
          A `metadata_store.MetadataStore` that successfully answered a request.

        Raises:
          RuntimeError: if no attempt succeeds. The exception from the last
            failed attempt, if any, is attached as `__cause__`.
          ValueError: if the port or max-message-length environment variable
            holds a non-integer value.
        """
        metadata_service_host = os.environ.get(
            METADATA_GRPC_SERVICE_SERVICE_HOST_ENV,
            DEFAULT_METADATA_GRPC_SERVICE_SERVICE_HOST)
        metadata_service_port = int(
            os.environ.get(METADATA_GRPC_SERVICE_SERVICE_PORT_ENV,
                           DEFAULT_METADATA_GRPC_SERVICE_SERVICE_PORT))
        metadata_service_max_msg = int(
            os.environ.get(METADATA_GRPC_MAX_RECEIVE_MESSAGE_LENGTH_ENV,
                           DEFAULT_METADATA_GRPC_MAX_RECEIVE_MESSAGE_LENGTH))

        metadata_service_channel_args = GrpcChannelArguments(
            max_receive_message_length=metadata_service_max_msg)

        mlmd_connection_config = MetadataStoreClientConfig(
            host=metadata_service_host,
            port=metadata_service_port,
            channel_arguments=metadata_service_channel_args)
        mlmd_store = metadata_store.MetadataStore(mlmd_connection_config)

        # We ensure that the connection to MLMD is established by retrying a
        # number of times and sleeping between the tries. The default counts
        # are taken from the MetadataWriter implementation.
        last_error = None
        for attempt in range(max_attempts):
            if attempt:
                # Sleep only between attempts, never after the final one.
                time.sleep(retry_interval_sec)
            try:
                mlmd_store.get_context_types()
            except Exception as e:  # pylint: disable=broad-except
                # Any failure counts as "not reachable yet"; keep retrying.
                last_error = e
                log.warning(
                    "Failed to access the Metadata store. Exception:"
                    " '%s'", str(e))
            else:
                return mlmd_store

        raise RuntimeError(
            "Could not connect to the Metadata store.") from last_error