Esempio n. 1
0
def test_roundtrip_errors():
    """Check that each Flight error raised by the server reaches the client
    as the corresponding client-side exception, with its message intact."""
    with flight_server(ErrorFlightServer) as server_location:
        client = flight.FlightClient.connect(server_location)
        # Action type sent to the server -> exception the client should see.
        expected_errors = [
            ("internal", flight.FlightInternalError),
            ("timedout", flight.FlightTimedOutError),
            ("cancel", flight.FlightCancelledError),
            ("unauthenticated", flight.FlightUnauthenticatedError),
            ("unauthorized", flight.FlightUnauthorizedError),
        ]
        for action_type, error_cls in expected_errors:
            with pytest.raises(error_cls, match=".*foo.*"):
                list(client.do_action(flight.Action(action_type, b"")))
        # list_flights is expected to fail on this server as well.
        with pytest.raises(flight.FlightInternalError, match=".*foo.*"):
            list(client.list_flights())
Esempio n. 2
0
def test_http_basic_unauth():
    """Calling an action without authenticating first must be rejected."""
    with EchoStreamFlightServer(auth_handler=basic_auth_handler) as server:
        client = FlightClient(('localhost', server.port))
        who_am_i = flight.Action("who-am-i", b"")
        expect_rejection = pytest.raises(flight.FlightUnauthenticatedError,
                                         match=".*unauthenticated.*")
        with expect_rejection:
            list(client.do_action(who_am_i))
Esempio n. 3
0
def test_token_auth():
    """Authenticate through a token handshake, then confirm the server
    reports the authenticated identity back to the client."""
    with EchoStreamFlightServer(auth_handler=token_auth_handler) as server:
        client = FlightClient(('localhost', server.port))
        who_am_i = flight.Action("who-am-i", b"")
        handler = TokenClientAuthHandler('test', 'p4ssw0rd')
        client.authenticate(handler)
        first_result = next(client.do_action(who_am_i))
        reported_identity = first_result.body.to_pybytes()
        assert reported_identity == b'test'
Esempio n. 4
0
def test_http_basic_auth():
    """Authenticate with a Python HTTP-basic handler, then confirm the server
    reports the authenticated identity back to the client."""
    with EchoStreamFlightServer(auth_handler=basic_auth_handler) as server:
        client = FlightClient(('localhost', server.port))
        who_am_i = flight.Action("who-am-i", b"")
        handler = HttpBasicClientAuthHandler('test', 'p4ssw0rd')
        client.authenticate(handler)
        first_result = next(client.do_action(who_am_i))
        reported_identity = first_result.body.to_pybytes()
        assert reported_identity == b'test'
Esempio n. 5
0
def test_http_basic_unauth():
    """Calling an action without authenticating first must fail."""
    server_ctx = flight_server(EchoStreamFlightServer,
                               auth_handler=basic_auth_handler)
    with server_ctx as location:
        client = flight.FlightClient.connect(location)
        who_am_i = flight.Action("who-am-i", b"")
        with pytest.raises(pa.ArrowException, match=".*unauthenticated.*"):
            list(client.do_action(who_am_i))
Esempio n. 6
0
def test_http_basic_auth_invalid_password():
    """A wrong password must produce an unauthenticated error, raised either
    by the handshake itself or by the first action afterwards."""
    with EchoStreamFlightServer(auth_handler=basic_auth_handler) as server:
        client = FlightClient(('localhost', server.port))
        who_am_i = flight.Action("who-am-i", b"")
        expect_rejection = pytest.raises(flight.FlightUnauthenticatedError,
                                         match=".*wrong password.*")
        with expect_rejection:
            client.authenticate(HttpBasicClientAuthHandler('test', 'wrong'))
            next(client.do_action(who_am_i))
Esempio n. 7
0
def test_http_basic_auth_invalid_password():
    """After authenticating with a wrong password, the first action must
    fail with a 'wrong password' error."""
    server_ctx = flight_server(EchoStreamFlightServer,
                               auth_handler=basic_auth_handler)
    with server_ctx as location:
        client = flight.FlightClient.connect(location)
        who_am_i = flight.Action("who-am-i", b"")
        client.authenticate(HttpBasicClientAuthHandler('test', 'wrong'))
        with pytest.raises(pa.ArrowException, match=".*wrong password.*"):
            next(client.do_action(who_am_i))
Esempio n. 8
0
def test_extra_info():
    """Check that the binary ``extra_info`` payload attached to a server-side
    error is delivered on the client's FlightUnauthorizedError."""
    with ErrorFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        # pytest.raises fails the test with a clear "DID NOT RAISE" message
        # when no exception occurs, replacing the manual try/assert False.
        with pytest.raises(flight.FlightUnauthorizedError) as excinfo:
            list(client.do_action(flight.Action("protobuf", b"")))
        # Equality with a bytes literal also rules out extra_info being None.
        assert excinfo.value.extra_info == b'this is an error message'
Esempio n. 9
0
def test_server_middleware_same_thread():
    """Server middleware must execute on the same thread that serves the RPC."""
    factories = {"test": HeaderServerMiddlewareFactory()}
    with HeaderFlightServer(middleware=factories) as server:
        client = FlightClient(('localhost', server.port))
        responses = list(client.do_action(flight.Action(b"test", b"")))
        assert len(responses) == 1
        payload = responses[0].body.to_pybytes()
        assert payload == b"right value"
Esempio n. 10
0
def test_middleware_reject():
    """Server middleware can refuse RPCs that arrive without credentials."""
    factories = {"test": SelectiveAuthServerMiddlewareFactory()}
    with HeaderFlightServer(middleware=factories) as server:
        client = FlightClient(('localhost', server.port))

        # list_actions is let through without auth; the server then fails it
        # with ArrowNotImplementedError rather than an auth error.
        with pytest.raises(pa.ArrowNotImplementedError):
            list(client.list_actions())

        # Any other RPC is refused for an unauthenticated client.
        with pytest.raises(flight.FlightUnauthenticatedError):
            list(client.do_action(flight.Action(b"", b"")))

        # With the client-side auth middleware installed, the call succeeds.
        auth_middleware = [SelectiveAuthClientMiddlewareFactory()]
        client = FlightClient(('localhost', server.port),
                              middleware=auth_middleware)
        reply = next(client.do_action(flight.Action(b"", b"")))
        assert reply.body.to_pybytes() == b"password"
Esempio n. 11
0
def test_timeout_fires():
    """Make sure a per-call timeout fires on a slow request."""
    # NOTE(review): this runs on the test's own thread. If the timeout is
    # not honored, the call to the slow server may block the test run.
    with flight_server(SlowFlightServer) as server_location:
        client = flight.FlightClient.connect(server_location)
        action = flight.Action("", b"")
        # A short deadline (seconds) that the slow action should exceed.
        options = flight.FlightCallOptions(timeout=0.2)
        # gRPC error messages change based on version, so don't look
        # for a particular error
        with pytest.raises(flight.FlightTimedOutError):
            list(client.do_action(action, options=options))
Esempio n. 12
0
def test_middleware_multi_header():
    """Round-trip several binary-valued headers through client and server
    middleware, checking both what the server saw and what the client got."""
    factories = {"test": MultiHeaderServerMiddlewareFactory()}
    with MultiHeaderFlightServer(middleware=factories) as server:
        client_factory = MultiHeaderClientMiddlewareFactory()
        client = FlightClient(('localhost', server.port),
                              middleware=[client_factory])
        reply = next(client.do_action(flight.Action(b"", b"")))
        # The server echoes the headers it received, as a Python literal.
        echoed_text = reply.body.to_pybytes().decode("utf-8")
        echoed = ast.literal_eval(echoed_text)
        # gRPC may add headers of its own (e.g. User-Agent), so check only
        # the expected ones instead of comparing whole dicts.
        for name, wanted in MultiHeaderClientMiddleware.EXPECTED.items():
            assert echoed.get(name) == wanted
            assert client_factory.last_headers.get(name) == wanted
Esempio n. 13
0
def main():
    """Run the trace-propagation example as either a Flight server or client.

    Sub-commands:
      server  -- serve on ``--listen``, optionally delegating to ``--delegate``.
      client  -- call the ``get-trace-id`` action on ``server`` and print the
                 results, using ``--request-id`` (or a fixed default) as the
                 trace ID.

    Returns:
        1 when no sub-command was given (after printing help), else 0.
    """
    parser = argparse.ArgumentParser()

    subparsers = parser.add_subparsers(dest="command")
    # Distinct names so the parsers are not shadowed by the FlightClient /
    # FlightServer objects created below.
    client_parser = subparsers.add_parser("client", help="Run the client.")
    client_parser.add_argument("server")
    client_parser.add_argument("--request-id", default=None)

    server_parser = subparsers.add_parser("server", help="Run the server.")
    server_parser.add_argument(
        "--listen",
        required=True,
        help="The location to listen on (example: grpc://localhost:5050)",
    )
    server_parser.add_argument(
        "--delegate",
        required=False,
        default=None,
        help=("A location to delegate to. That is, this server will "
              "simply call the given server for the response. Demonstrates "
              "propagation of the trace ID between servers."),
    )

    args = parser.parse_args()
    # With dest="command", args.command is None when no sub-command is given.
    if not args.command:
        parser.print_help()
        return 1

    if args.command == "server":
        server = FlightServer(
            args.delegate,
            location=args.listen,
            middleware={"trace": TracingServerMiddlewareFactory()})
        server.serve()
    elif args.command == "client":
        client = flight.connect(
            args.server,
            middleware=(TracingClientMiddlewareFactory(),))
        # Use the caller's ID when given (non-empty), else a fixed default.
        TraceContext.set_trace_id(args.request_id or "client-chosen-id")

        for result in client.do_action(flight.Action("get-trace-id", b"")):
            print(result.body.to_pybytes())
    return 0
Esempio n. 14
0
def test_middleware_mapping():
    """Check that client and server middleware record the method of every RPC,
    in the order the RPCs were made."""
    server_middleware = RecordingServerMiddlewareFactory()
    client_middleware = RecordingClientMiddlewareFactory()
    with FlightServerBase(middleware={"test": server_middleware}) as server:
        client = FlightClient(
            ('localhost', server.port),
            middleware=[client_middleware]
        )
        descriptor = flight.FlightDescriptor.for_command(b"")

        def put():
            writer, _ = client.do_put(descriptor, pa.schema([]))
            writer.close()

        def exchange():
            writer, _ = client.do_exchange(descriptor)
            writer.close()

        # Each RPC paired with the method the middleware should record for it,
        # in call order. Every call is expected to raise NotImplementedError.
        rpcs = [
            (flight.FlightMethod.LIST_FLIGHTS,
             lambda: list(client.list_flights())),
            (flight.FlightMethod.GET_FLIGHT_INFO,
             lambda: client.get_flight_info(descriptor)),
            (flight.FlightMethod.GET_SCHEMA,
             lambda: client.get_schema(descriptor)),
            (flight.FlightMethod.DO_GET,
             lambda: client.do_get(flight.Ticket(b""))),
            (flight.FlightMethod.DO_PUT, put),
            (flight.FlightMethod.DO_ACTION,
             lambda: list(client.do_action(flight.Action(b"", b"")))),
            (flight.FlightMethod.LIST_ACTIONS,
             lambda: list(client.list_actions())),
            (flight.FlightMethod.DO_EXCHANGE, exchange),
        ]
        for _, invoke in rpcs:
            with pytest.raises(NotImplementedError):
                invoke()

        expected = [method for method, _ in rpcs]
        assert server_middleware.methods == expected
        assert client_middleware.methods == expected
Esempio n. 15
0
 def list_tables(self):
     """Return the table names reported by the ``list-tables`` action,
     each result body decoded as UTF-8."""
     request = flight.Action('list-tables', b'')
     responses = self.con.do_action(request)
     return [resp.body.to_pybytes().decode('utf8') for resp in responses]