def setUp(self):
    """Create an MLMD-backed ExecutionWatcher sidecar and a stub to reach it."""
    super().setUp()

    # The SQLite MLMD database lives under the test's output directory
    # (TEST_UNDECLARED_OUTPUTS_DIR when set, otherwise a temp dir), inside a
    # subdirectory named after the test id.
    output_root = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                                 self.get_temp_dir())
    db_path = os.path.join(output_root, self.id(), 'metadata', 'metadata.db')
    mlmd_config = metadata.sqlite_metadata_connection_config(db_path)
    mlmd_config.sqlite.SetInParent()
    self._mlmd_connection = metadata.Metadata(connection_config=mlmd_config)

    # Register a single execution (no contexts, no inputs) for the watcher.
    with self._mlmd_connection as handle:
        self._execution = execution_publish_utils.register_execution(
            metadata_handler=handle,
            execution_type=metadata_store_pb2.ExecutionType(
                name='test_execution_type'),
            contexts=[],
            input_artifacts=[])

    # Start the gRPC sidecar on a free port and connect a stub to it.
    self.sidecar = execution_watcher.ExecutionWatcher(
        portpicker.pick_unused_port(),
        mlmd_connection=self._mlmd_connection,
        execution=self._execution,
        creds=grpc.local_server_credentials())
    self.sidecar.start()
    self.stub = execution_watcher.generate_service_stub(
        self.sidecar.address, grpc.local_channel_credentials())
def error_sanitizing_interceptor_dummy_api(TestRpc):
    """Serve ``TestRpc`` behind ErrorSanitizationInterceptor and yield a caller.

    A throwaway single-worker server is bound to an ephemeral localhost port
    with local credentials. ``TestRpc`` is registered as the unary-unary method
    ``/testing.Test/TestRpc`` (Empty -> Empty). The yielded callable invokes it
    over a secure local channel. The server is stopped when the generator is
    closed.
    """
    with futures.ThreadPoolExecutor(1) as pool:
        server = grpc.server(pool, interceptors=[ErrorSanitizationInterceptor()])
        bound_port = server.add_secure_port("localhost:0", grpc.local_server_credentials())

        # Register the handler by hand instead of via generated servicer code.
        method_handler = grpc.unary_unary_rpc_method_handler(
            TestRpc,
            request_deserializer=empty_pb2.Empty.FromString,
            response_serializer=empty_pb2.Empty.SerializeToString,
        )
        service_handler = grpc.method_handlers_generic_handler(
            "testing.Test", {"TestRpc": method_handler}
        )
        server.add_generic_rpc_handlers((service_handler,))
        server.start()

        try:
            with grpc.secure_channel(
                f"localhost:{bound_port}", grpc.local_channel_credentials()
            ) as channel:
                yield channel.unary_unary(
                    "/testing.Test/TestRpc",
                    request_serializer=empty_pb2.Empty.SerializeToString,
                    response_deserializer=empty_pb2.Empty.FromString,
                )
        finally:
            server.stop(None).wait()
def _local_composite_credentials(self):
    """
    Build channel credentials suitable for a local emulator.

    Local transport credentials are combined with per-call credentials that
    attach the authorization header produced from ``self._credentials``.

    :return: grpc.ChannelCredentials
    """
    scoped_credentials = google.auth.credentials.with_scopes_if_required(
        self._credentials, None
    )
    auth_plugin = google.auth.transport.grpc.AuthMetadataPlugin(
        scoped_credentials, google.auth.transport.requests.Request()
    )
    per_call_credentials = grpc.metadata_call_credentials(auth_plugin)

    # Local credentials let the channel reach the emulator without TLS.
    return grpc.composite_channel_credentials(
        grpc.local_channel_credentials(), per_call_credentials
    )
def _run(flags):
    """Run the uploader program for already-parsed command-line flags.

    Args:
      flags: An `argparse.Namespace`.
    """
    logging.set_stderrthreshold(logging.WARNING)
    intent = _get_intent(flags)

    store = auth.CredentialsStore()
    if isinstance(intent, _AuthRevokeIntent):
        store.clear()
        sys.stderr.write("Logged out of uploader.\n")
        sys.stderr.flush()
        return

    # TODO(b/141723268): maybe reconfirm Google Account prior to reuse.
    credentials = store.read_credentials()
    if not credentials:
        _prompt_for_user_ack(intent)
        flow = auth.build_installed_app_flow(json.loads(auth.OAUTH_CLIENT_CONFIG))
        credentials = flow.run(force_console=flags.auth_force_console)
        sys.stderr.write("\n")  # Extra newline after auth flow messages.
        store.write_credentials(credentials)

    # Each factory returns (channel credentials, channel options).
    creds_factories = {
        "local": lambda: (grpc.local_channel_credentials(), None),
        "ssl": lambda: (grpc.ssl_channel_credentials(), None),
        "ssl_dev": lambda: (
            grpc.ssl_channel_credentials(dev_creds.DEV_SSL_CERT),
            [("grpc.ssl_target_name_override", "localhost")],
        ),
    }
    creds_factory = creds_factories.get(flags.grpc_creds_type)
    if creds_factory is None:
        raise base_plugin.FlagsError(
            "Invalid --grpc_creds_type %s" % flags.grpc_creds_type
        )
    channel_creds, channel_options = creds_factory()

    try:
        server_info = _get_server_info(flags)
    except server_info_lib.CommunicationError as e:
        _die(str(e))
    _handle_server_info(server_info)

    if not server_info.api_server.endpoint:
        logging.error("Server info response: %s", server_info)
        _die("Internal error: frontend did not specify an API server")

    # Transport credentials plus an ID-token call credential on every RPC.
    composite_channel_creds = grpc.composite_channel_credentials(
        channel_creds, auth.id_token_call_credentials(credentials)
    )

    # TODO(@nfelt): In the `_UploadIntent` case, consider waiting until
    # logdir exists to open channel.
    with grpc.secure_channel(
        server_info.api_server.endpoint,
        composite_channel_creds,
        options=channel_options,
    ) as channel:
        intent.execute(server_info, channel)
def test_tracing_interceptor_auth_cookies(db):
    """A cookie-authenticated call is traced with the user and no API-key flag."""
    user, token = generate_user()

    account_servicer = Account()
    rpc_def = dict(
        rpc=account_servicer.GetAccountInfo,
        service_name="org.couchers.api.account.Account",
        method_name="GetAccountInfo",
        interceptors=[TracingInterceptor(), AuthValidatorInterceptor()],
        request_type=empty_pb2.Empty,
        response_type=account_pb2.GetAccountInfoRes,
    )

    # Authenticate through the session cookie.
    cookie_metadata = (("cookie", f"couchers-sesh={token}"),)
    with interceptor_dummy_api(**rpc_def, creds=grpc.local_channel_credentials()) as call_rpc:
        reply = call_rpc(empty_pb2.Empty(), metadata=cookie_metadata)

    assert reply.username == user.username

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == "/org.couchers.api.account.Account/GetAccountInfo"
        assert not trace.status_code
        assert trace.user_id == user.id
        assert not trace.is_api_key
        assert len(trace.request) == 0
        assert not trace.traceback
def auth_api_session():
    """
    Yield an Auth API stub backed by a real in-process gRPC server.

    A real server is needed because the tests inspect response headers. The
    yielded pair is ``(AuthStub, interceptor)``, where
    ``interceptor.latest_headers`` holds the initial metadata of the most
    recent unary-unary call.
    """
    class _MetadataKeeperInterceptor(grpc.UnaryUnaryClientInterceptor):
        """Client interceptor recording the initial metadata of the last call."""

        def __init__(self):
            self.latest_headers = {}

        def intercept_unary_unary(self, continuation, client_call_details, request):
            outcome = continuation(client_call_details, request)
            self.latest_headers = dict(outcome.initial_metadata())
            return outcome

    with futures.ThreadPoolExecutor(1) as pool:
        server = grpc.server(pool)
        bound_port = server.add_secure_port("localhost:0", grpc.local_server_credentials())
        auth_pb2_grpc.add_AuthServicer_to_server(Auth(), server)
        server.start()
        try:
            with grpc.secure_channel(
                f"localhost:{bound_port}", grpc.local_channel_credentials()
            ) as raw_channel:
                header_recorder = _MetadataKeeperInterceptor()
                intercepted_channel = grpc.intercept_channel(raw_channel, header_recorder)
                yield auth_pb2_grpc.AuthStub(intercepted_channel), header_recorder
        finally:
            server.stop(None).wait()
def real_jail_session(token):
    """
    Yield a Jail service stub talking to a real server over TCP.

    Every call carries ``token`` via CookieMetadataPlugin, and the server runs
    Auth's interceptor with ``allow_jailed=True``.
    """
    jailed_auth_interceptor = Auth().get_auth_interceptor(allow_jailed=True)
    with futures.ThreadPoolExecutor(1) as pool:
        server = grpc.server(pool, interceptors=[jailed_auth_interceptor])
        bound_port = server.add_secure_port("localhost:0", grpc.local_server_credentials())
        jail_pb2_grpc.add_JailServicer_to_server(Jail(), server)
        server.start()

        # Local transport credentials plus a per-call plugin adding the cookie.
        cookie_creds = grpc.metadata_call_credentials(CookieMetadataPlugin(token))
        channel_creds = grpc.composite_channel_credentials(
            grpc.local_channel_credentials(), cookie_creds
        )

        try:
            with grpc.secure_channel(f"localhost:{bound_port}", channel_creds) as channel:
                yield jail_pb2_grpc.JailStub(channel)
        finally:
            server.stop(None).wait()
async def coro():
    """Check that aio.secure_channel with LOCAL_TCP credentials is an aio.Channel.

    The test server and the channel are both shut down afterwards. Previously
    the server handle was discarded and the channel never closed, so both
    leaked into later tests.
    """
    server_target, server = await start_test_server(secure=True)
    credentials = grpc.local_channel_credentials(
        grpc.LocalConnectionType.LOCAL_TCP)
    secure_channel = aio.secure_channel(server_target, credentials)
    try:
        self.assertIsInstance(secure_channel, aio.Channel)
    finally:
        await secure_channel.close()
        await server.stop(None)
def interceptor_dummy_api(
    rpc,
    interceptors,
    service_name="testing.Test",
    method_name="TestRpc",
    request_type=empty_pb2.Empty,
    response_type=empty_pb2.Empty,
    creds=None,
):
    """Serve one unary-unary ``rpc`` behind ``interceptors`` and yield a caller.

    The handler is registered as ``/{service_name}/{method_name}`` on a
    throwaway single-worker server. The server is bound to an ephemeral
    localhost port with local server credentials. The yielded callable reaches
    it through a secure channel using ``creds``, or local channel credentials
    when ``creds`` is falsy. The server is stopped when the generator is closed.
    """
    with futures.ThreadPoolExecutor(1) as pool:
        server = grpc.server(pool, interceptors=interceptors)
        bound_port = server.add_secure_port("localhost:0", grpc.local_server_credentials())

        # Register the handler by hand instead of via generated servicer code.
        method_handler = grpc.unary_unary_rpc_method_handler(
            rpc,
            request_deserializer=request_type.FromString,
            response_serializer=response_type.SerializeToString,
        )
        service_handler = grpc.method_handlers_generic_handler(
            service_name, {method_name: method_handler}
        )
        server.add_generic_rpc_handlers((service_handler,))
        server.start()

        try:
            channel_creds = creds or grpc.local_channel_credentials()
            with grpc.secure_channel(f"localhost:{bound_port}", channel_creds) as channel:
                yield channel.unary_unary(
                    f"/{service_name}/{method_name}",
                    request_serializer=request_type.SerializeToString,
                    response_deserializer=response_type.FromString,
                )
        finally:
            server.stop(None).wait()
def _service(self):
    """Yield something that can be used as an expansion service stub.

    ``self._expansion_service`` is handled in one of three ways:
    * an address string: a channel is opened to it and a stub is yielded;
    * an object with an ``Expand`` attribute: it is yielded as-is;
    * anything else: it is entered as a context manager and the result yielded.
    """
    service = self._expansion_service

    if isinstance(service, str):
        # Unlimited message sizes in both directions.
        unlimited = [("grpc.max_receive_message_length", -1),
                     ("grpc.max_send_message_length", -1)]
        if hasattr(grpc, 'local_channel_credentials'):
            # Some environments may not support insecure channels, so prefer a
            # secure channel with local credentials.
            # TODO: update this to support secure non-local channels.
            channel = grpc.secure_channel(
                service, grpc.local_channel_credentials(), options=unlimited)
        else:
            # local_channel_credentials is an experimental API missing from
            # older grpc versions that other dependencies may pull in.
            channel = grpc.insecure_channel(service, options=unlimited)
        with channel:
            yield ExpansionAndArtifactRetrievalStub(channel)
        return

    if hasattr(service, 'Expand'):
        yield service
        return

    with service as stub:
        yield stub
def test_tracing_interceptor_auth_api_key(db):
    """An API-key-authenticated call is traced with the user and the API-key flag."""
    super_user, super_token = generate_user(is_superuser=True)
    user, token = generate_user()

    # An admin mints an API key for the regular user; read it back from the DB.
    with real_admin_session(super_token) as admin_api:
        admin_api.CreateApiKey(admin_pb2.CreateApiKeyReq(user=user.username))

    with session_scope() as session:
        key_session = session.execute(
            select(UserSession).where(UserSession.is_api_key == True)  # noqa: E712 (SQL expression)
        ).scalar_one()
        api_key = key_session.token

    account_servicer = Account()
    rpc_def = dict(
        rpc=account_servicer.GetAccountInfo,
        service_name="org.couchers.api.account.Account",
        method_name="GetAccountInfo",
        interceptors=[TracingInterceptor(), AuthValidatorInterceptor()],
        request_type=empty_pb2.Empty,
        response_type=account_pb2.GetAccountInfoRes,
    )

    # Authenticate with the API key as a bearer token.
    bearer_metadata = (("authorization", f"Bearer {api_key}"),)
    with interceptor_dummy_api(**rpc_def, creds=grpc.local_channel_credentials()) as call_rpc:
        reply = call_rpc(empty_pb2.Empty(), metadata=bearer_metadata)

    assert reply.username == user.username

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == "/org.couchers.api.account.Account/GetAccountInfo"
        assert not trace.status_code
        assert trace.user_id == user.id
        assert trace.is_api_key
        assert len(trace.request) == 0
        assert not trace.traceback
def __init__(self, test_name: Text):
    """Start one testing server per language and connect gRPC stubs to it.

    For every language in LANGUAGES, a server subprocess is launched on a
    free port. Its combined stdout/stderr goes to
    ``$TEST_UNDECLARED_OUTPUTS_DIR/<test_name>-<lang>-server.log``.

    A secure channel with local credentials is opened to each server. Once
    every channel is ready, the Metadata/Keyset stubs are created, along with
    the per-primitive stubs for the languages listed in
    SUPPORTED_LANGUAGES_BY_PRIMITIVE.

    Args:
      test_name: Prefix used for the server log file names.

    Raises:
      RuntimeError: if TEST_UNDECLARED_OUTPUTS_DIR is unset, a log file cannot
        be opened, or a server is not reachable within 30 seconds.
    """
    self._server = {}
    self._output_file = {}
    self._channel = {}
    self._metadata_stub = {}
    self._keyset_stub = {}
    self._aead_stub = {}
    self._daead_stub = {}
    self._streaming_aead_stub = {}
    self._hybrid_stub = {}
    self._mac_stub = {}
    self._signature_stub = {}
    self._prf_stub = {}
    self._jwt_stub = {}

    for lang in LANGUAGES:
        port = portpicker.pick_unused_port()
        cmd = _server_cmd(lang, port)
        logging.info('cmd = %s', cmd)
        try:
            output_dir = os.environ['TEST_UNDECLARED_OUTPUTS_DIR']
        except KeyError:
            # The message used to leave '%s' unformatted and glued
            # "environment" to "variable".
            raise RuntimeError(
                'Could not start %s server, TEST_UNDECLARED_OUTPUTS_DIR '
                'environment variable must be set' % lang) from None
        output_file = '%s-%s-%s' % (test_name, lang, 'server.log')
        output_path = os.path.join(output_dir, output_file)
        logging.info('writing server output to %s', output_path)
        try:
            self._output_file[lang] = open(output_path, 'w+')
        except IOError as e:
            logging.info('unable to open server output file %s', output_path)
            raise RuntimeError('Could not start %s server' % lang) from e
        self._server[lang] = subprocess.Popen(
            cmd, stdout=self._output_file[lang], stderr=subprocess.STDOUT)
        logging.info('%s server started on port %d with pid: %d.',
                     lang, port, self._server[lang].pid)
        self._channel[lang] = grpc.secure_channel(
            '[::]:%d' % port, grpc.local_channel_credentials())

    for lang in LANGUAGES:
        try:
            grpc.channel_ready_future(self._channel[lang]).result(timeout=30)
        except Exception as e:  # Normally grpc.FutureTimeoutError.
            logging.info('Timeout while connecting to server %s', lang)
            self._server[lang].kill()
            self._server[lang].communicate()  # Reap the killed process.
            # stdout/stderr were redirected to the log file, so communicate()
            # returns (None, None). Report the captured log contents instead.
            log_file = self._output_file[lang]
            try:
                log_file.seek(0)
                output = log_file.read()
            except (OSError, ValueError):  # Includes UnicodeDecodeError.
                output = '<unreadable, see %s>' % log_file.name
            raise RuntimeError(
                'Could not start %s server, output=%s' % (lang, output)) from e
        self._metadata_stub[lang] = testing_api_pb2_grpc.MetadataStub(
            self._channel[lang])
        self._keyset_stub[lang] = testing_api_pb2_grpc.KeysetStub(
            self._channel[lang])

    # Per-primitive stubs live in attributes named '_<primitive>_stub'.
    for primitive in _PRIMITIVES:
        for lang in SUPPORTED_LANGUAGES_BY_PRIMITIVE[primitive]:
            stub_name = '_%s_stub' % primitive
            getattr(self, stub_name)[lang] = _PRIMITIVE_STUBS[primitive](
                self._channel[lang])
def _make_provider(addr):
    """Return a GrpcDataProvider for ``addr`` over a local-credentials channel.

    The channel accepts inbound messages of up to 256 MiB.
    """
    max_receive_bytes = 1024 * 1024 * 256
    channel = grpc.secure_channel(
        addr,
        grpc.local_channel_credentials(),
        options=[("grpc.max_receive_message_length", max_receive_bytes)],
    )
    return grpc_provider.GrpcDataProvider(addr, grpc_provider.make_stub(channel))
def test_unary_unary_secure(self):
    """A unary-unary call with local credentials echoes the request back."""
    with _server(grpc.local_server_credentials()) as port:
        echoed = grpc.experimental.unary_unary(
            _REQUEST,
            'localhost:{}'.format(port),
            _UNARY_UNARY,
            channel_credentials=grpc.local_channel_credentials())
        self.assertEqual(_REQUEST, echoed)
def main(_):
    """Send one GetQuery request to the environment service and log the reply.

    Connects to FLAGS.server_address with local channel credentials. It waits
    up to 10 seconds for the channel to become ready and issues the RPC with a
    10-second deadline. The channel is closed before logging (it was
    previously never closed).
    """
    channel_creds = grpc.local_channel_credentials()
    with grpc.secure_channel(FLAGS.server_address, channel_creds) as channel:
        grpc.channel_ready_future(channel).result(timeout=10)
        stub = environment_pb2_grpc.EnvironmentServiceStub(channel)
        request = environment_pb2.GetQueryRequest()
        response = stub.GetQuery(request, timeout=10)
    logging.info('\n\nReceived GetQueryResponse:\n%s\n', response)
def test_unary_stream(self):
    """Every message of a unary-stream call echoes the request.

    The responses are collected first and checked to be non-empty, so an empty
    stream can no longer make the loop pass vacuously.
    """
    with _server(grpc.local_server_credentials()) as port:
        target = f'localhost:{port}'
        responses = list(grpc.experimental.unary_stream(
            _REQUEST,
            target,
            _UNARY_STREAM,
            channel_credentials=grpc.local_channel_credentials()))
        self.assertTrue(responses, 'unary_stream yielded no responses')
        for response in responses:
            self.assertEqual(_REQUEST, response)
def start_bundle(self):
    """Create the environment, MuZero config and prediction-service stubs."""
    env_descriptor = env.get_descriptor()
    self.environment = env.create_environment(
        task=42,
        training=False,
        stop_after_seeing_new_results=FLAGS.stop_after_seeing_new_results)
    self.mzconfig = agent_lib.muzeroconfig_from_flags(
        env_descriptor=env_descriptor)
    self.mzconfig.max_num_action_expansion = 100

    def _prediction_stub(server_spec):
        # Each model server gets its own local-credentials channel.
        return prediction_service_pb2_grpc.PredictionServiceStub(
            grpc.secure_channel(server_spec, grpc.local_channel_credentials()))

    self.initial_inference_stub = _prediction_stub(
        FLAGS.initial_inference_model_server_spec)
    self.recurrent_inference_stub = _prediction_stub(
        FLAGS.recurrent_inference_model_server_spec)
def test_document_understanding_service_grpc_asyncio_transport_channel():
    """A caller-supplied aio channel is used as-is by the AsyncIO transport."""
    channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials())

    # Check that channel is used if provided.
    transport = transports.DocumentUnderstandingServiceGrpcAsyncIOTransport(
        host="squid.clam.whelk",
        channel=channel,
    )
    assert transport.grpc_channel == channel
    assert transport._host == "squid.clam.whelk:443"
    # Compare with the None singleton by identity (PEP 8, E711).
    assert transport._ssl_channel_credentials is None
def test_quota_controller_grpc_asyncio_transport_channel():
    """A caller-supplied aio channel is used as-is by the AsyncIO transport."""
    channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials())

    # Check that channel is used if provided.
    transport = transports.QuotaControllerGrpcAsyncIOTransport(
        host="squid.clam.whelk",
        channel=channel,
    )
    assert transport.grpc_channel == channel
    assert transport._host == "squid.clam.whelk:443"
    # Compare with the None singleton by identity (PEP 8, E711).
    assert transport._ssl_channel_credentials is None
def test_image_annotator_grpc_transport_channel():
    """A caller-supplied channel is used as-is by the gRPC transport."""
    channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials())

    # Check that channel is used if provided.
    transport = transports.ImageAnnotatorGrpcTransport(
        host="squid.clam.whelk",
        channel=channel,
    )
    assert transport.grpc_channel == channel
    assert transport._host == "squid.clam.whelk:443"
    # Compare with the None singleton by identity (PEP 8, E711).
    assert transport._ssl_channel_credentials is None
def test_video_intelligence_service_grpc_asyncio_transport_channel():
    """A caller-supplied aio channel is used as-is by the AsyncIO transport."""
    channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials())

    # Check that channel is used if provided.
    transport = transports.VideoIntelligenceServiceGrpcAsyncIOTransport(
        host="squid.clam.whelk",
        channel=channel,
    )
    assert transport.grpc_channel == channel
    assert transport._host == "squid.clam.whelk:443"
    # Compare with the None singleton by identity (PEP 8, E711).
    assert transport._ssl_channel_credentials is None
def test_create_secure_channel_and_connect(self):
    """Connecting with local credentials to a running local server succeeds.

    The server is stopped in ``finally`` so it no longer leaks when the
    connection attempt or the assertion fails.
    """
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    port = server.add_secure_port('[::]:0', grpc.local_server_credentials())
    server.start()
    try:
        self.assertIsNotNone(
            dm_env_rpc_connection.create_secure_channel_and_connect(
                f'[::]:{port}', grpc.local_channel_credentials()))
    finally:
        server.stop(grace=None)
def test_prediction_service_grpc_transport_channel():
    """A caller-supplied channel is used as-is by the gRPC transport."""
    channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials())

    # Check that channel is used if provided.
    transport = transports.PredictionServiceGrpcTransport(
        host="squid.clam.whelk",
        channel=channel,
    )
    assert transport.grpc_channel == channel
    assert transport._host == "squid.clam.whelk:443"
    # Compare with the None singleton by identity (PEP 8, E711).
    assert transport._ssl_channel_credentials is None
def test_insecure_sugar_mutually_exclusive(self):
    """Passing both insecure=True and channel_credentials raises ValueError."""
    with _server(None) as port:
        target = f'localhost:{port}'
        with self.assertRaises(ValueError):
            # The call is expected to raise, so its result is not bound.
            grpc.experimental.unary_unary(
                _REQUEST,
                target,
                _UNARY_UNARY,
                insecure=True,
                channel_credentials=grpc.local_channel_credentials())
def test_policy_tag_manager_serialization_grpc_asyncio_transport_channel():
    """A caller-supplied aio channel is used as-is by the AsyncIO transport."""
    channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials())

    # Check that channel is used if provided.
    transport = transports.PolicyTagManagerSerializationGrpcAsyncIOTransport(
        host="squid.clam.whelk",
        channel=channel,
    )
    assert transport.grpc_channel == channel
    assert transport._host == "squid.clam.whelk:443"
    # Compare with the None singleton by identity (PEP 8, E711).
    assert transport._ssl_channel_credentials is None
def test_customer_license_service_grpc_asyncio_transport_channel():
    """A caller-supplied aio channel is used as-is by the AsyncIO transport."""
    channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials())

    # Check that channel is used if provided.
    transport = transports.CustomerLicenseServiceGrpcAsyncIOTransport(
        host="squid.clam.whelk",
        channel=channel,
    )
    assert transport.grpc_channel == channel
    assert transport._host == "squid.clam.whelk:443"
    # Compare with the None singleton by identity (PEP 8, E711).
    assert transport._ssl_channel_credentials is None
def main(_):
    """Play the Catch environment interactively in a pygame window.

    Starts the environment server on a free local port, then creates and
    joins a world. Each frame it steps the environment with the paddle action
    chosen by the arrow keys, rendering the board and reward, until the window
    is closed or Escape is pressed. The world is then left and destroyed.

    The server is now stopped in ``finally``, so it is shut down even if the
    session raises.
    """
    pygame.init()

    port = portpicker.pick_unused_port()
    server = _start_server(port)

    try:
        with grpc.secure_channel('localhost:{}'.format(port),
                                 grpc.local_channel_credentials()) as channel:
            grpc.channel_ready_future(channel).result()
            with dm_env_rpc_connection.Connection(channel) as connection:
                response = connection.send(dm_env_rpc_pb2.CreateWorldRequest())
                world_name = response.world_name
                response = connection.send(
                    dm_env_rpc_pb2.JoinWorldRequest(world_name=world_name))
                specs = response.specs

                with dm_env_adaptor.DmEnvAdaptor(connection, specs) as dm_env:
                    window_surface = pygame.display.set_mode((800, 600), 0, 32)
                    pygame.display.set_caption('Catch Human Agent')

                    keep_running = True
                    while keep_running:
                        requested_action = _ACTION_NOTHING

                        for event in pygame.event.get():
                            if event.type == pygame.QUIT:
                                keep_running = False
                                break
                            elif event.type == pygame.KEYDOWN:
                                if event.key == pygame.K_LEFT:
                                    requested_action = _ACTION_LEFT
                                elif event.key == pygame.K_RIGHT:
                                    requested_action = _ACTION_RIGHT
                                elif event.key == pygame.K_ESCAPE:
                                    keep_running = False
                                    break

                        actions = {_ACTION_PADDLE: requested_action}
                        step_result = dm_env.step(actions)

                        board = step_result.observation[_OBSERVATION_BOARD]
                        reward = step_result.observation[_OBSERVATION_REWARD]

                        _render_window(board, window_surface, reward)

                        pygame.display.update()

                        pygame.time.wait(_FRAME_DELAY_MS)

                connection.send(dm_env_rpc_pb2.LeaveWorldRequest())
                connection.send(
                    dm_env_rpc_pb2.DestroyWorldRequest(world_name=world_name))
    finally:
        server.stop(None)
def _connect_to_environment(port, settings):
    """Connect to a dm_memorytask environment already listening on ``port``.

    Raises:
      ValueError: if ``settings.level_name`` is not in MEMORY_TASK_LEVEL_NAMES.
    """
    level_name = settings.level_name
    if level_name not in MEMORY_TASK_LEVEL_NAMES:
        raise ValueError(
            'Level named "{}" is not a valid dm_memorytask level.'.format(
                level_name))

    channel, connection = _create_channel_and_connection(
        'localhost:{}'.format(port), grpc.local_channel_credentials())
    return _make_environment_connection(channel, connection, settings)
def test_text_to_speech_grpc_transport_channel():
    """A caller-supplied channel is used as-is by the gRPC transport."""
    channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials())

    # Check that channel is used if provided.
    transport = transports.TextToSpeechGrpcTransport(
        host="squid.clam.whelk",
        channel=channel,
    )
    assert transport.grpc_channel == channel
    assert transport._host == "squid.clam.whelk:443"
    # Compare with the None singleton by identity (PEP 8, E711).
    assert transport._ssl_channel_credentials is None
def test_authorized_domains_grpc_asyncio_transport_channel():
    """A caller-supplied aio channel is used as-is by the AsyncIO transport."""
    channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials())

    # Check that channel is used if provided.
    transport = transports.AuthorizedDomainsGrpcAsyncIOTransport(
        host="squid.clam.whelk",
        channel=channel,
    )
    assert transport.grpc_channel == channel
    assert transport._host == "squid.clam.whelk:443"
    # Compare with the None singleton by identity (PEP 8, E711).
    assert transport._ssl_channel_credentials is None