Exemple #1
0
 def setUp(self):
     """Create an MLMD execution and start an ExecutionWatcher sidecar for it.

     Afterwards the test has `self._mlmd_connection`, the registered
     `self._execution`, a started `self.sidecar`, and `self.stub`, a service
     stub pointed at the sidecar's address.
     """
     super().setUp()

     # MLMD: SQLite store under a per-test directory.
     output_root = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                                  self.get_temp_dir())
     pipeline_root = os.path.join(output_root, self.id())
     db_file = os.path.join(pipeline_root, 'metadata', 'metadata.db')
     mlmd_config = metadata.sqlite_metadata_connection_config(db_file)
     mlmd_config.sqlite.SetInParent()
     self._mlmd_connection = metadata.Metadata(connection_config=mlmd_config)

     with self._mlmd_connection as handler:
         test_execution_type = metadata_store_pb2.ExecutionType(
             name='test_execution_type')
         self._execution = execution_publish_utils.register_execution(
             metadata_handler=handler,
             execution_type=test_execution_type,
             contexts=[],
             input_artifacts=[])

     # gRPC: serve the watcher on a free port and build a stub for it.
     watcher_port = portpicker.pick_unused_port()
     self.sidecar = execution_watcher.ExecutionWatcher(
         watcher_port,
         mlmd_connection=self._mlmd_connection,
         execution=self._execution,
         creds=grpc.local_server_credentials())
     self.sidecar.start()
     sidecar_address = self.sidecar.address
     self.stub = execution_watcher.generate_service_stub(
         sidecar_address, grpc.local_channel_credentials())
def error_sanitizing_interceptor_dummy_api(TestRpc):
    """Serve `TestRpc` behind ErrorSanitizationInterceptor and yield a caller.

    An in-process gRPC server with local server credentials is started on an
    ephemeral localhost port. It exposes `TestRpc` as the unary-unary method
    `/testing.Test/TestRpc`, which takes and returns `empty_pb2.Empty`. A
    callable for that method is yielded, and the server is stopped in a
    `finally` once the caller is done.
    """
    with futures.ThreadPoolExecutor(1) as pool:
        server = grpc.server(pool,
                             interceptors=[ErrorSanitizationInterceptor()])
        bound_port = server.add_secure_port("localhost:0",
                                            grpc.local_server_credentials())

        # Register the method by hand instead of through generated code.
        test_rpc_handler = grpc.unary_unary_rpc_method_handler(
            TestRpc,
            request_deserializer=empty_pb2.Empty.FromString,
            response_serializer=empty_pb2.Empty.SerializeToString,
        )
        service_handler = grpc.method_handlers_generic_handler(
            "testing.Test", {"TestRpc": test_rpc_handler})
        server.add_generic_rpc_handlers((service_handler, ))
        server.start()

        try:
            target = f"localhost:{bound_port}"
            with grpc.secure_channel(
                    target, grpc.local_channel_credentials()) as channel:
                yield channel.unary_unary(
                    "/testing.Test/TestRpc",
                    request_serializer=empty_pb2.Empty.SerializeToString,
                    response_deserializer=empty_pb2.Empty.FromString,
                )
        finally:
            server.stop(None).wait()
Exemple #3
0
    def _local_composite_credentials(self):
        """
        Build channel credentials for the local emulator connection.

        Local channel credentials are used so that the connection to the
        emulator is allowed. ``self._credentials`` (scoped via
        ``with_scopes_if_required`` with no scopes) is attached on top as
        per-call credentials through an ``AuthMetadataPlugin``, which inserts
        the authorization header.

        :return: grpc.ChannelCredentials
        """
        scoped_credentials = google.auth.credentials.with_scopes_if_required(
            self._credentials, None
        )
        auth_request = google.auth.transport.requests.Request()
        auth_plugin = google.auth.transport.grpc.AuthMetadataPlugin(
            scoped_credentials, auth_request
        )
        per_call_credentials = grpc.metadata_call_credentials(auth_plugin)
        transport_credentials = grpc.local_channel_credentials()
        return grpc.composite_channel_credentials(
            transport_credentials, per_call_credentials
        )
def _run(flags):
    """Runs the main uploader program given parsed flags.

    A revoke intent clears the stored credentials and returns. Any other
    intent reuses stored credentials, or runs the installed-app OAuth flow
    and stores the result when none exist. It then fetches server info and
    executes the intent over a secure gRPC channel to the advertised API
    server, authenticated with ID-token call credentials.

    Args:
      flags: An `argparse.Namespace`.

    Raises:
      base_plugin.FlagsError: if `flags.grpc_creds_type` is not "local",
        "ssl" or "ssl_dev".
    """

    def _channel_security(creds_type):
        # Map the --grpc_creds_type value to (channel credentials, options).
        if creds_type == "local":
            return grpc.local_channel_credentials(), None
        if creds_type == "ssl":
            return grpc.ssl_channel_credentials(), None
        if creds_type == "ssl_dev":
            # Dev SSL cert plus an override of the TLS target name to check.
            return (
                grpc.ssl_channel_credentials(dev_creds.DEV_SSL_CERT),
                [("grpc.ssl_target_name_override", "localhost")],
            )
        raise base_plugin.FlagsError(
            "Invalid --grpc_creds_type %s" % creds_type)

    logging.set_stderrthreshold(logging.WARNING)
    intent = _get_intent(flags)

    store = auth.CredentialsStore()
    if isinstance(intent, _AuthRevokeIntent):
        store.clear()
        sys.stderr.write("Logged out of uploader.\n")
        sys.stderr.flush()
        return

    # TODO(b/141723268): maybe reconfirm Google Account prior to reuse.
    credentials = store.read_credentials()
    if not credentials:
        _prompt_for_user_ack(intent)
        flow = auth.build_installed_app_flow(
            json.loads(auth.OAUTH_CLIENT_CONFIG))
        credentials = flow.run(force_console=flags.auth_force_console)
        sys.stderr.write("\n")  # Extra newline after auth flow messages.
        store.write_credentials(credentials)

    channel_creds, channel_options = _channel_security(flags.grpc_creds_type)

    try:
        server_info = _get_server_info(flags)
    except server_info_lib.CommunicationError as e:
        _die(str(e))
    _handle_server_info(server_info)

    api_endpoint = server_info.api_server.endpoint
    if not api_endpoint:
        logging.error("Server info response: %s", server_info)
        _die("Internal error: frontend did not specify an API server")

    authenticated_creds = grpc.composite_channel_credentials(
        channel_creds, auth.id_token_call_credentials(credentials)
    )

    # TODO(@nfelt): In the `_UploadIntent` case, consider waiting until
    # logdir exists to open channel.
    with grpc.secure_channel(
        api_endpoint,
        authenticated_creds,
        options=channel_options,
    ) as channel:
        intent.execute(server_info, channel)
Exemple #5
0
def test_tracing_interceptor_auth_cookies(db):
    """A cookie-authenticated call is traced as a non-API-key call for its user."""
    user, token = generate_user()

    servicer = Account()
    rpc_def = {
        "rpc": servicer.GetAccountInfo,
        "service_name": "org.couchers.api.account.Account",
        "method_name": "GetAccountInfo",
        "interceptors": [TracingInterceptor(), AuthValidatorInterceptor()],
        "request_type": empty_pb2.Empty,
        "response_type": account_pb2.GetAccountInfoRes,
    }
    cookie_metadata = (("cookie", f"couchers-sesh={token}"),)

    # Authenticate with the session cookie only.
    with interceptor_dummy_api(**rpc_def, creds=grpc.local_channel_credentials()) as call_rpc:
        response = call_rpc(empty_pb2.Empty(), metadata=cookie_metadata)
    assert response.username == user.username

    # scalar_one(): exactly one APICall row must have been recorded.
    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == "/org.couchers.api.account.Account/GetAccountInfo"
        assert not trace.status_code
        assert trace.user_id == user.id
        assert not trace.is_api_key
        assert len(trace.request) == 0
        assert not trace.traceback
Exemple #6
0
def auth_api_session():
    """
    Yield an Auth API stub together with a recorder of response headers.

    A real gRPC server is used because the Auth API plays around with
    headers. The yielded pair is ``(AuthStub, interceptor)``. After each
    unary-unary call, ``interceptor.latest_headers`` holds that call's
    initial metadata as a dict. The server is stopped on exit.
    """

    class _MetadataKeeperInterceptor(grpc.UnaryUnaryClientInterceptor):
        """Keeps the initial metadata of the most recent unary-unary call."""

        def __init__(self):
            self.latest_headers = {}

        def intercept_unary_unary(self, continuation, client_call_details,
                                  request):
            outcome = continuation(client_call_details, request)
            self.latest_headers = dict(outcome.initial_metadata())
            return outcome

    with futures.ThreadPoolExecutor(1) as pool:
        server = grpc.server(pool)
        bound_port = server.add_secure_port("localhost:0",
                                            grpc.local_server_credentials())
        auth_pb2_grpc.add_AuthServicer_to_server(Auth(), server)
        server.start()

        try:
            with grpc.secure_channel(
                    f"localhost:{bound_port}",
                    grpc.local_channel_credentials()) as raw_channel:
                header_keeper = _MetadataKeeperInterceptor()
                watched_channel = grpc.intercept_channel(raw_channel,
                                                         header_keeper)
                yield auth_pb2_grpc.AuthStub(watched_channel), header_keeper
        finally:
            server.stop(None).wait()
Exemple #7
0
def real_jail_session(token):
    """
    Yield a JailStub talking to a real Jail server over TCP sockets.

    The server runs Auth's interceptor with ``allow_jailed=True``. Every
    request carries ``token`` through ``CookieMetadataPlugin`` call
    credentials. The server is stopped when the session ends.
    """
    jailed_auth = Auth().get_auth_interceptor(allow_jailed=True)

    with futures.ThreadPoolExecutor(1) as pool:
        server = grpc.server(pool, interceptors=[jailed_auth])
        bound_port = server.add_secure_port("localhost:0",
                                            grpc.local_server_credentials())
        jail_pb2_grpc.add_JailServicer_to_server(Jail(), server)
        server.start()

        cookie_creds = grpc.metadata_call_credentials(
            CookieMetadataPlugin(token))
        channel_creds = grpc.composite_channel_credentials(
            grpc.local_channel_credentials(), cookie_creds)

        try:
            with grpc.secure_channel(f"localhost:{bound_port}",
                                     channel_creds) as channel:
                yield jail_pb2_grpc.JailStub(channel)
        finally:
            server.stop(None).wait()
Exemple #8
0
        async def coro():
            # Open an aio secure channel with LOCAL_TCP local credentials to
            # a secure test server and check that it is an aio.Channel.
            server_target, server = await start_test_server(secure=True)
            credentials = grpc.local_channel_credentials(
                grpc.LocalConnectionType.LOCAL_TCP)
            secure_channel = aio.secure_channel(server_target, credentials)
            try:
                self.assertIsInstance(secure_channel, aio.Channel)
            finally:
                # Release the channel and the server even if the assertion
                # fails, so neither outlives this test.
                await secure_channel.close()
                await server.stop(None)
Exemple #9
0
def interceptor_dummy_api(
    rpc,
    interceptors,
    service_name="testing.Test",
    method_name="TestRpc",
    request_type=empty_pb2.Empty,
    response_type=empty_pb2.Empty,
    creds=None,
):
    """Serve a single unary-unary `rpc` behind `interceptors` and yield a caller.

    An in-process gRPC server with local server credentials is started on an
    ephemeral localhost port. It exposes `rpc` as `/{service_name}/{method_name}`
    and (de)serializes messages with `request_type` / `response_type`. The
    yielded callable invokes that method over a secure channel that uses
    `creds`, falling back to local channel credentials when `creds` is falsy.
    The server is stopped in a `finally` once the caller is done.
    """
    full_method = f"/{service_name}/{method_name}"

    with futures.ThreadPoolExecutor(1) as pool:
        server = grpc.server(pool, interceptors=interceptors)
        bound_port = server.add_secure_port("localhost:0", grpc.local_server_credentials())

        # Register the method by hand instead of through generated servicer code.
        method_handler = grpc.unary_unary_rpc_method_handler(
            rpc,
            request_deserializer=request_type.FromString,
            response_serializer=response_type.SerializeToString,
        )
        service_handler = grpc.method_handlers_generic_handler(service_name, {method_name: method_handler})
        server.add_generic_rpc_handlers((service_handler,))
        server.start()

        try:
            channel_creds = creds or grpc.local_channel_credentials()
            with grpc.secure_channel(f"localhost:{bound_port}", channel_creds) as channel:
                yield channel.unary_unary(
                    full_method,
                    request_serializer=request_type.SerializeToString,
                    response_deserializer=response_type.FromString,
                )
        finally:
            server.stop(None).wait()
Exemple #10
0
 def _service(self):
     """Yield a stub for ``self._expansion_service`` while it is in use.

     Three forms of ``self._expansion_service`` are handled:
       * an address string: a channel with unlimited send/receive message
         sizes is opened to it, and an ``ExpansionAndArtifactRetrievalStub``
         over it is yielded. The channel is secure with local credentials
         when this grpc has ``local_channel_credentials``, insecure otherwise;
       * an object with an ``Expand`` attribute: yielded as-is;
       * anything else: entered as a context manager and its value yielded.
     """
     service = self._expansion_service
     if not isinstance(service, str):
         if hasattr(service, 'Expand'):
             yield service
         else:
             with service as stub:
                 yield stub
         return

     # -1 lifts gRPC's message size limits for these channel args.
     size_options = [("grpc.max_receive_message_length", -1),
                     ("grpc.max_send_message_length", -1)]
     if hasattr(grpc, 'local_channel_credentials'):
         # Some environments may not support insecure channels, so use a
         # secure channel with local credentials when possible.
         # TODO: update this to support secure non-local channels.
         channel = grpc.secure_channel(service,
                                       grpc.local_channel_credentials(),
                                       options=size_options)
     else:
         # local_channel_credentials is an experimental API missing from
         # older grpc versions that other dependencies may pull in.
         channel = grpc.insecure_channel(service, options=size_options)
     with channel:
         yield ExpansionAndArtifactRetrievalStub(channel)
Exemple #11
0
def test_tracing_interceptor_auth_api_key(db):
    """An API-key-authenticated call is traced with is_api_key set for its user."""
    _super_user, super_token = generate_user(is_superuser=True)
    user, _token = generate_user()

    # A superuser mints an API key for `user`; read the key back from the DB.
    with real_admin_session(super_token) as admin_api:
        admin_api.CreateApiKey(admin_pb2.CreateApiKeyReq(user=user.username))

    with session_scope() as session:
        # `== True` builds a SQL comparison here, so it must not become `is True`.
        key_query = select(UserSession).where(UserSession.is_api_key == True)  # noqa: E712
        api_key = session.execute(key_query).scalar_one().token

    servicer = Account()
    rpc_def = {
        "rpc": servicer.GetAccountInfo,
        "service_name": "org.couchers.api.account.Account",
        "method_name": "GetAccountInfo",
        "interceptors": [TracingInterceptor(), AuthValidatorInterceptor()],
        "request_type": empty_pb2.Empty,
        "response_type": account_pb2.GetAccountInfoRes,
    }
    bearer_metadata = (("authorization", f"Bearer {api_key}"),)

    # Authenticate with the API key as a bearer token.
    with interceptor_dummy_api(**rpc_def, creds=grpc.local_channel_credentials()) as call_rpc:
        response = call_rpc(empty_pb2.Empty(), metadata=bearer_metadata)
    assert response.username == user.username

    # scalar_one(): exactly one APICall row must have been recorded.
    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == "/org.couchers.api.account.Account/GetAccountInfo"
        assert not trace.status_code
        assert trace.user_id == user.id
        assert trace.is_api_key
        assert len(trace.request) == 0
        assert not trace.traceback
Exemple #12
0
 def __init__(self, test_name: Text):
     """Start one testing server per language and create gRPC stubs for them.

     For each language in LANGUAGES, a server subprocess is started on an
     unused port. Its stdout and stderr go to
     ``<TEST_UNDECLARED_OUTPUTS_DIR>/<test_name>-<lang>-server.log``. After
     all servers are launched, each channel gets 30 seconds to become ready.
     Then Metadata and Keyset stubs are created per language, plus one stub
     per primitive for every language in SUPPORTED_LANGUAGES_BY_PRIMITIVE.

     Args:
       test_name: Prefix used in the server log file names.

     Raises:
       RuntimeError: if TEST_UNDECLARED_OUTPUTS_DIR is not set, a log file
         cannot be opened, or a server's channel is not ready within 30s.
     """
     self._server = {}
     self._output_file = {}
     self._channel = {}
     self._metadata_stub = {}
     self._keyset_stub = {}
     self._aead_stub = {}
     self._daead_stub = {}
     self._streaming_aead_stub = {}
     self._hybrid_stub = {}
     self._mac_stub = {}
     self._signature_stub = {}
     self._prf_stub = {}
     self._jwt_stub = {}
     for lang in LANGUAGES:
         port = portpicker.pick_unused_port()
         cmd = _server_cmd(lang, port)
         logging.info('cmd = %s', cmd)
         try:
             output_dir = os.environ['TEST_UNDECLARED_OUTPUTS_DIR']
         except KeyError:
             raise RuntimeError(
                 'Could not start %s server, TEST_UNDECLARED_OUTPUTS_DIR '
                 'environment variable must be set' % lang) from None
         output_file = '%s-%s-%s' % (test_name, lang, 'server.log')
         output_path = os.path.join(output_dir, output_file)
         logging.info('writing server output to %s', output_path)
         try:
             # Not closed here: the server subprocess writes into it.
             self._output_file[lang] = open(output_path, 'w+')
         except IOError as e:
             logging.info('unable to open server output file %s',
                          output_path)
             raise RuntimeError('Could not start %s server' % lang) from e
         self._server[lang] = subprocess.Popen(
             cmd, stdout=self._output_file[lang], stderr=subprocess.STDOUT)
         logging.info('%s server started on port %d with pid: %d.', lang,
                      port, self._server[lang].pid)
         self._channel[lang] = grpc.secure_channel(
             '[::]:%d' % port, grpc.local_channel_credentials())
     for lang in LANGUAGES:
         try:
             grpc.channel_ready_future(
                 self._channel[lang]).result(timeout=30)
         except grpc.FutureTimeoutError as e:
             logging.info('Timeout while connecting to server %s', lang)
             self._server[lang].kill()
             # Reap the killed process. Its output went to the log file, so
             # communicate() has nothing to return; read the log instead.
             self._server[lang].communicate()
             log_file = self._output_file[lang]
             log_file.flush()
             log_file.seek(0)
             server_output = log_file.read()
             raise RuntimeError(
                 'Could not start %s server, output=%s' %
                 (lang, server_output)) from e
         self._metadata_stub[lang] = testing_api_pb2_grpc.MetadataStub(
             self._channel[lang])
         self._keyset_stub[lang] = testing_api_pb2_grpc.KeysetStub(
             self._channel[lang])
     for primitive in _PRIMITIVES:
         for lang in SUPPORTED_LANGUAGES_BY_PRIMITIVE[primitive]:
             stub_name = '_%s_stub' % primitive
             getattr(self, stub_name)[lang] = _PRIMITIVE_STUBS[primitive](
                 self._channel[lang])
Exemple #13
0
def _make_provider(addr):
    """Return a GrpcDataProvider for `addr` over a local-credentials channel.

    The channel's maximum receive message length is raised to 256 MiB.
    """
    max_receive_bytes = 256 * 1024 * 1024
    channel = grpc.secure_channel(
        addr,
        grpc.local_channel_credentials(),
        options=[("grpc.max_receive_message_length", max_receive_bytes)],
    )
    return grpc_provider.GrpcDataProvider(addr, grpc_provider.make_stub(channel))
Exemple #14
0
 def test_unary_unary_secure(self):
     """experimental.unary_unary over local credentials returns the request."""
     with _server(grpc.local_server_credentials()) as port:
         local_creds = grpc.local_channel_credentials()
         echoed = grpc.experimental.unary_unary(
             _REQUEST, f'localhost:{port}', _UNARY_UNARY,
             channel_credentials=local_creds)
         self.assertEqual(_REQUEST, echoed)
Exemple #15
0
def main(_):
  """Send one GetQuery request to FLAGS.server_address and log the reply.

  Waits up to 10 seconds for the channel to be ready and up to 10 seconds
  for the RPC. The channel is closed before the response is logged.
  """
  channel_creds = grpc.local_channel_credentials()
  # `with` closes the channel on every path, including timeouts.
  with grpc.secure_channel(FLAGS.server_address, channel_creds) as channel:
    grpc.channel_ready_future(channel).result(timeout=10)
    stub = environment_pb2_grpc.EnvironmentServiceStub(channel)

    request = environment_pb2.GetQueryRequest()
    response = stub.GetQuery(request, timeout=10)
  logging.info('\n\nReceived GetQueryResponse:\n%s\n', response)
Exemple #16
0
 def test_unary_stream(self):
     """Every response from experimental.unary_stream equals the request."""
     with _server(grpc.local_server_credentials()) as port:
         stream = grpc.experimental.unary_stream(
             _REQUEST,
             f'localhost:{port}',
             _UNARY_STREAM,
             channel_credentials=grpc.local_channel_credentials())
         for streamed in stream:
             self.assertEqual(_REQUEST, streamed)
Exemple #17
0
  def start_bundle(self):
    """Set up the environment, MuZero config and prediction-service stubs.

    Creates a non-training environment for task 42 and builds the MuZero
    config from flags, with `max_num_action_expansion` set to 100. It also
    opens one local-credentials channel each for the initial- and
    recurrent-inference model servers named by flags.
    """
    descriptor = env.get_descriptor()
    self.environment = env.create_environment(
        task=42,
        training=False,
        stop_after_seeing_new_results=FLAGS.stop_after_seeing_new_results)
    self.mzconfig = agent_lib.muzeroconfig_from_flags(
        env_descriptor=descriptor)
    self.mzconfig.max_num_action_expansion = 100

    def _prediction_stub(server_spec):
      channel = grpc.secure_channel(server_spec,
                                    grpc.local_channel_credentials())
      return prediction_service_pb2_grpc.PredictionServiceStub(channel)

    self.initial_inference_stub = _prediction_stub(
        FLAGS.initial_inference_model_server_spec)
    self.recurrent_inference_stub = _prediction_stub(
        FLAGS.recurrent_inference_model_server_spec)
def test_document_understanding_service_grpc_asyncio_transport_channel():
    """A channel passed to the asyncio transport is used as-is."""
    channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials())

    # Check that channel is used if provided.
    transport = transports.DocumentUnderstandingServiceGrpcAsyncIOTransport(
        host="squid.clam.whelk", channel=channel,
    )
    assert transport.grpc_channel is channel
    # The bare host is expected to get port 443 appended.
    assert transport._host == "squid.clam.whelk:443"
    # No SSL channel credentials are expected when a channel is supplied.
    assert transport._ssl_channel_credentials is None
def test_quota_controller_grpc_asyncio_transport_channel():
    """A channel passed to the asyncio transport is used as-is."""
    channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials())

    # Check that channel is used if provided.
    transport = transports.QuotaControllerGrpcAsyncIOTransport(
        host="squid.clam.whelk", channel=channel,
    )
    assert transport.grpc_channel is channel
    # The bare host is expected to get port 443 appended.
    assert transport._host == "squid.clam.whelk:443"
    # No SSL channel credentials are expected when a channel is supplied.
    assert transport._ssl_channel_credentials is None
Exemple #20
0
def test_image_annotator_grpc_transport_channel():
    """A channel passed to the gRPC transport is used as-is."""
    channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials())

    # Check that channel is used if provided.
    transport = transports.ImageAnnotatorGrpcTransport(
        host="squid.clam.whelk", channel=channel,
    )
    assert transport.grpc_channel is channel
    # The bare host is expected to get port 443 appended.
    assert transport._host == "squid.clam.whelk:443"
    # No SSL channel credentials are expected when a channel is supplied.
    assert transport._ssl_channel_credentials is None
def test_video_intelligence_service_grpc_asyncio_transport_channel():
    """A channel passed to the asyncio transport is used as-is."""
    channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials())

    # Check that channel is used if provided.
    transport = transports.VideoIntelligenceServiceGrpcAsyncIOTransport(
        host="squid.clam.whelk", channel=channel,
    )
    assert transport.grpc_channel is channel
    # The bare host is expected to get port 443 appended.
    assert transport._host == "squid.clam.whelk:443"
    # No SSL channel credentials are expected when a channel is supplied.
    assert transport._ssl_channel_credentials is None
Exemple #22
0
    def test_create_secure_channel_and_connect(self):
        """create_secure_channel_and_connect gives a non-None result locally."""
        server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
        port = server.add_secure_port('[::]:0',
                                      grpc.local_server_credentials())
        server.start()
        try:
            connection = (
                dm_env_rpc_connection.create_secure_channel_and_connect(
                    f'[::]:{port}', grpc.local_channel_credentials()))
            try:
                self.assertIsNotNone(connection)
            finally:
                # Close what we opened so it does not leak past this test.
                if connection is not None:
                    connection.close()
        finally:
            # Stop the server even when connecting or the assertion fails.
            server.stop(grace=None)
Exemple #23
0
def test_prediction_service_grpc_transport_channel():
    """A channel passed to the gRPC transport is used as-is."""
    channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials())

    # Check that channel is used if provided.
    transport = transports.PredictionServiceGrpcTransport(
        host="squid.clam.whelk", channel=channel,
    )
    assert transport.grpc_channel is channel
    # The bare host is expected to get port 443 appended.
    assert transport._host == "squid.clam.whelk:443"
    # No SSL channel credentials are expected when a channel is supplied.
    assert transport._ssl_channel_credentials is None
Exemple #24
0
 def test_insecure_sugar_mutually_exclusive(self):
     """insecure=True together with channel_credentials raises ValueError."""
     with _server(None) as port:
         with self.assertRaises(ValueError):
             grpc.experimental.unary_unary(
                 _REQUEST,
                 f'localhost:{port}',
                 _UNARY_UNARY,
                 insecure=True,
                 channel_credentials=grpc.local_channel_credentials())
def test_policy_tag_manager_serialization_grpc_asyncio_transport_channel():
    """A channel passed to the asyncio transport is used as-is."""
    channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials())

    # Check that channel is used if provided.
    transport = transports.PolicyTagManagerSerializationGrpcAsyncIOTransport(
        host="squid.clam.whelk", channel=channel,
    )
    assert transport.grpc_channel is channel
    # The bare host is expected to get port 443 appended.
    assert transport._host == "squid.clam.whelk:443"
    # No SSL channel credentials are expected when a channel is supplied.
    assert transport._ssl_channel_credentials is None
def test_customer_license_service_grpc_asyncio_transport_channel():
    """A channel passed to the asyncio transport is used as-is."""
    channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials())

    # Check that channel is used if provided.
    transport = transports.CustomerLicenseServiceGrpcAsyncIOTransport(
        host="squid.clam.whelk",
        channel=channel,
    )
    assert transport.grpc_channel is channel
    # The bare host is expected to get port 443 appended.
    assert transport._host == "squid.clam.whelk:443"
    # No SSL channel credentials are expected when a channel is supplied.
    assert transport._ssl_channel_credentials is None
def main(_):
    """Play the Catch environment by hand in a pygame window.

    Starts a local server on an unused port, then creates and joins a world
    over a local-credentials gRPC channel. Each loop iteration steps the
    environment with the paddle action chosen from the arrow keys (nothing
    otherwise) and redraws the board. Closing the window or pressing Escape
    ends the loop, after which the world is left and destroyed. The server
    is stopped in a `finally`, so it is shut down even if the session fails.
    """
    pygame.init()

    port = portpicker.pick_unused_port()
    server = _start_server(port)

    try:
        with grpc.secure_channel('localhost:{}'.format(port),
                                 grpc.local_channel_credentials()) as channel:
            grpc.channel_ready_future(channel).result()
            with dm_env_rpc_connection.Connection(channel) as connection:
                response = connection.send(dm_env_rpc_pb2.CreateWorldRequest())
                world_name = response.world_name
                response = connection.send(
                    dm_env_rpc_pb2.JoinWorldRequest(world_name=world_name))
                specs = response.specs

                with dm_env_adaptor.DmEnvAdaptor(connection, specs) as dm_env:
                    window_surface = pygame.display.set_mode((800, 600), 0, 32)
                    pygame.display.set_caption('Catch Human Agent')

                    keep_running = True
                    while keep_running:
                        requested_action = _ACTION_NOTHING

                        # Later key presses in the same batch override
                        # earlier ones; quit/escape stops reading events.
                        for event in pygame.event.get():
                            if event.type == pygame.QUIT:
                                keep_running = False
                                break
                            elif event.type == pygame.KEYDOWN:
                                if event.key == pygame.K_LEFT:
                                    requested_action = _ACTION_LEFT
                                elif event.key == pygame.K_RIGHT:
                                    requested_action = _ACTION_RIGHT
                                elif event.key == pygame.K_ESCAPE:
                                    keep_running = False
                                    break

                        actions = {_ACTION_PADDLE: requested_action}
                        step_result = dm_env.step(actions)

                        board = step_result.observation[_OBSERVATION_BOARD]
                        reward = step_result.observation[_OBSERVATION_REWARD]

                        _render_window(board, window_surface, reward)

                        pygame.display.update()

                        pygame.time.wait(_FRAME_DELAY_MS)

                connection.send(dm_env_rpc_pb2.LeaveWorldRequest())
                connection.send(
                    dm_env_rpc_pb2.DestroyWorldRequest(world_name=world_name))
    finally:
        server.stop(None)
def _connect_to_environment(port, settings):
  """Connect to a dm_memorytask environment served on localhost:`port`.

  Uses local channel credentials for the connection.

  Raises:
    ValueError: if `settings.level_name` is not in MEMORY_TASK_LEVEL_NAMES.
  """
  level_name = settings.level_name
  if level_name not in MEMORY_TASK_LEVEL_NAMES:
    raise ValueError(
        'Level named "{}" is not a valid dm_memorytask level.'.format(
            level_name))
  channel, connection = _create_channel_and_connection(
      'localhost:{}'.format(port), grpc.local_channel_credentials())
  return _make_environment_connection(channel, connection, settings)
def test_text_to_speech_grpc_transport_channel():
    """A channel passed to the gRPC transport is used as-is."""
    channel = grpc.secure_channel("http://localhost/",
                                  grpc.local_channel_credentials())

    # Check that channel is used if provided.
    transport = transports.TextToSpeechGrpcTransport(
        host="squid.clam.whelk",
        channel=channel,
    )
    assert transport.grpc_channel is channel
    # The bare host is expected to get port 443 appended.
    assert transport._host == "squid.clam.whelk:443"
    # No SSL channel credentials are expected when a channel is supplied.
    assert transport._ssl_channel_credentials is None
Exemple #30
0
def test_authorized_domains_grpc_asyncio_transport_channel():
    """A channel passed to the asyncio transport is used as-is."""
    channel = aio.secure_channel("http://localhost/",
                                 grpc.local_channel_credentials())

    # Check that channel is used if provided.
    transport = transports.AuthorizedDomainsGrpcAsyncIOTransport(
        host="squid.clam.whelk",
        channel=channel,
    )
    assert transport.grpc_channel is channel
    # The bare host is expected to get port 443 appended.
    assert transport._host == "squid.clam.whelk:443"
    # No SSL channel credentials are expected when a channel is supplied.
    assert transport._ssl_channel_credentials is None