Beispiel #1
0
def test_do_action_result_convenience():
    """Check the shorthand argument forms accepted by ``do_action``.

    The action is passed first as a bare type string and then as a
    ``(type, body)`` tuple.  The bodies of the returned results are
    compared against ``server.simple_action_results`` and against the
    body that was sent, respectively.
    """
    with ConvenienceServer() as server:
        client = FlightClient(('localhost', server.port))

        # Shorthand 1: only the action type, no body.
        simple_bodies = []
        for result in client.do_action('simple-action'):
            simple_bodies.append(result.body)
        assert simple_bodies == server.simple_action_results

        # Shorthand 2: a (type, body) tuple.
        payload = b'the-body'
        echoed_bodies = []
        for result in client.do_action(('echo', payload)):
            echoed_bodies.append(result.body)
        assert echoed_bodies == [payload]
Beispiel #2
0
def test_nicer_server_exceptions():
    """Check the error type and message the client sees for failing actions.

    Both the 'bad-action' and the 'arrow-exception' actions are expected
    to surface on the client as ``flight.FlightServerError`` with a
    message matching the given pattern.
    """
    with ConvenienceServer() as server:
        address = ('localhost', server.port)
        client = FlightClient(address)

        expect_bytes_error = pytest.raises(
            flight.FlightServerError,
            match="a bytes-like object is required")
        with expect_bytes_error:
            list(client.do_action('bad-action'))

        # While Flight/C++ sends across the original status code, it
        # doesn't get mapped to the equivalent code here, since we
        # want to be able to distinguish between client- and server-
        # side errors.
        expect_memory_error = pytest.raises(
            flight.FlightServerError, match="ArrowMemoryError")
        with expect_memory_error:
            list(client.do_action('arrow-exception'))
Beispiel #3
0
def test_roundtrip_errors():
    """Ensure that Flight errors propagate from server to client."""
    with ErrorFlightServer() as server:
        client = FlightClient(('localhost', server.port))

        # (action type, exception class the client is expected to raise),
        # exercised in this order.
        action_errors = (
            ("internal", flight.FlightInternalError),
            ("timedout", flight.FlightTimedOutError),
            ("cancel", flight.FlightCancelledError),
            ("unauthenticated", flight.FlightUnauthenticatedError),
            ("unauthorized", flight.FlightUnauthorizedError),
        )
        for action_type, error_cls in action_errors:
            with pytest.raises(error_cls, match=".*foo.*"):
                list(client.do_action(flight.Action(action_type, b"")))

        # Errors are also expected to propagate for list_flights.
        with pytest.raises(flight.FlightInternalError, match=".*foo.*"):
            list(client.list_flights())
Beispiel #4
0
def test_do_action_result_convenience():
    """Check the shorthand argument forms accepted by ``do_action``.

    Covers a bare action type string, a ``(type, body)`` tuple, and an
    action ('bad-action') that is expected to raise on the client.
    """
    with ConvenienceServer() as server:
        client = FlightClient(('localhost', server.port))

        # Shorthand 1: only the action type, no body.
        simple_bodies = []
        for result in client.do_action('simple-action'):
            simple_bodies.append(result.body)
        assert simple_bodies == server.simple_action_results

        # Shorthand 2: a (type, body) tuple.
        payload = b'the-body'
        echoed_bodies = []
        for result in client.do_action(('echo', payload)):
            echoed_bodies.append(result.body)
        assert echoed_bodies == [payload]

        # ARROW-6884: a more specific and helpful exception should be
        # raised here; for now only the base Exception type is checked.
        expect_failure = pytest.raises(Exception)
        with expect_failure:
            list(client.do_action('bad-action'))
Beispiel #5
0
def test_token_auth():
    """Test an auth mechanism that uses a handshake.

    After authenticating as 'test', the body of the first "who-am-i"
    result is expected to be ``b'test'``.
    """
    with EchoStreamFlightServer(auth_handler=token_auth_handler) as server:
        client = FlightClient(('localhost', server.port))
        who_am_i = flight.Action("who-am-i", b"")
        credentials = TokenClientAuthHandler('test', 'p4ssw0rd')
        client.authenticate(credentials)

        # Only the first result of the action stream is inspected.
        first_result = next(client.do_action(who_am_i))
        reported_identity = first_result.body.to_pybytes()
        assert reported_identity == b'test'
Beispiel #6
0
def test_http_basic_auth():
    """Test a Python implementation of HTTP basic authentication.

    After authenticating as 'test', the body of the first "who-am-i"
    result is expected to be ``b'test'``.
    """
    with EchoStreamFlightServer(auth_handler=basic_auth_handler) as server:
        client = FlightClient(('localhost', server.port))
        who_am_i = flight.Action("who-am-i", b"")
        credentials = HttpBasicClientAuthHandler('test', 'p4ssw0rd')
        client.authenticate(credentials)

        # Only the first result of the action stream is inspected.
        first_result = next(client.do_action(who_am_i))
        reported_identity = first_result.body.to_pybytes()
        assert reported_identity == b'test'
Beispiel #7
0
def test_http_basic_unauth():
    """Test that auth fails when not authenticated."""
    with EchoStreamFlightServer(auth_handler=basic_auth_handler) as server:
        client = FlightClient(('localhost', server.port))
        who_am_i = flight.Action("who-am-i", b"")

        # No call to client.authenticate() is made before the action.
        expect_unauthenticated = pytest.raises(
            flight.FlightUnauthenticatedError,
            match=".*unauthenticated.*")
        with expect_unauthenticated:
            list(client.do_action(who_am_i))
Beispiel #8
0
def test_http_basic_auth_invalid_password():
    """Test that auth fails with the wrong password."""
    with EchoStreamFlightServer(auth_handler=basic_auth_handler) as server:
        client = FlightClient(('localhost', server.port))
        who_am_i = flight.Action("who-am-i", b"")

        expect_wrong_password = pytest.raises(
            flight.FlightUnauthenticatedError,
            match=".*wrong password.*")
        # NOTE(review): both calls stay inside the raises block; which of
        # the two actually raises is not visible from this test.
        with expect_wrong_password:
            bad_credentials = HttpBasicClientAuthHandler('test', 'wrong')
            client.authenticate(bad_credentials)
            next(client.do_action(who_am_i))
Beispiel #9
0
def test_extra_info():
    """Check that ``extra_info`` attached to a Flight error reaches the client.

    The "protobuf" action is expected to fail with
    ``flight.FlightUnauthorizedError`` whose ``extra_info`` attribute
    equals ``b'this is an error message'``.
    """
    with ErrorFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        # pytest.raises fails the test when nothing is raised; unlike a
        # bare `assert False`, it is not stripped under `python -O`.
        with pytest.raises(flight.FlightUnauthorizedError) as exc_info:
            list(client.do_action(flight.Action("protobuf", b"")))

        extra_info = exc_info.value.extra_info
        assert extra_info is not None
        assert extra_info == b'this is an error message'
Beispiel #10
0
def test_server_middleware_same_thread():
    """Ensure that server middleware run on the same thread as the RPC.

    The single result returned by the "test" action is expected to carry
    the body ``b"right value"``.
    """
    server_middleware = {"test": HeaderServerMiddlewareFactory()}
    with HeaderFlightServer(middleware=server_middleware) as server:
        client = FlightClient(('localhost', server.port))
        test_action = flight.Action(b"test", b"")
        results = list(client.do_action(test_action))

        # Exactly one result is expected.
        assert len(results) == 1
        body_bytes = results[0].body.to_pybytes()
        assert body_bytes == b"right value"
Beispiel #11
0
def test_middleware_reject():
    """Test rejecting an RPC with server middleware."""
    server_middleware = {"test": SelectiveAuthServerMiddlewareFactory()}
    with HeaderFlightServer(middleware=server_middleware) as server:
        client = FlightClient(('localhost', server.port))

        # list_actions gets past the middleware without auth, so the
        # error seen is the "not implemented" one.
        expect_not_implemented = pytest.raises(pa.ArrowNotImplementedError)
        with expect_not_implemented:
            list(client.list_actions())

        # Any other RPC without auth is expected to be rejected.
        expect_unauthenticated = pytest.raises(
            flight.FlightUnauthenticatedError)
        with expect_unauthenticated:
            list(client.do_action(flight.Action(b"", b"")))

        # With the client-side auth middleware installed the same action
        # is expected to succeed.
        client_middleware = [SelectiveAuthClientMiddlewareFactory()]
        client = FlightClient(
            ('localhost', server.port),
            middleware=client_middleware)
        first_result = next(client.do_action(flight.Action(b"", b"")))
        assert b"password" == first_result.body.to_pybytes()
Beispiel #12
0
def test_timeout_fires():
    """Make sure timeouts fire on slow requests."""
    # NOTE(review): an earlier comment here said this runs in a separate
    # thread so a failure cannot hang the test process; no thread is
    # created in this function -- confirm whether that is still intended.
    with SlowFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        slow_action = flight.Action("", b"")
        call_options = flight.FlightCallOptions(timeout=0.2)

        # Only the exception type is checked: gRPC error messages change
        # based on version.
        expect_timeout = pytest.raises(flight.FlightTimedOutError)
        with expect_timeout:
            list(client.do_action(slow_action, options=call_options))
Beispiel #13
0
def test_middleware_multi_header():
    """Test sending/receiving multiple (binary-valued) headers."""
    server_middleware = {"test": MultiHeaderServerMiddlewareFactory()}
    with MultiHeaderFlightServer(middleware=server_middleware) as server:
        client_factory = MultiHeaderClientMiddlewareFactory()
        client = FlightClient(
            ('localhost', server.port), middleware=[client_factory])
        first_result = next(client.do_action(flight.Action(b"", b"")))

        # The response body is the repr of the headers the server
        # received; decode it and parse it back into a Python object.
        body_text = first_result.body.to_pybytes().decode("utf-8")
        echoed_headers = ast.literal_eval(body_text)

        # Check only the expected headers rather than comparing whole
        # mappings: gRPC may add headers like User-Agent.
        expected_headers = MultiHeaderClientMiddleware.EXPECTED
        for name, expected_values in expected_headers.items():
            assert echoed_headers.get(name) == expected_values
            assert client_factory.last_headers.get(name) == expected_values
Beispiel #14
0
def test_middleware_mapping():
    """Test that middleware records methods correctly.

    Every client RPC is issued once, in a fixed order, against a
    ``FlightServerBase`` instance; each call is expected to raise
    ``NotImplementedError``.  Afterwards the ``methods`` lists of the
    server-side and client-side recording middleware are compared with
    the ``flight.FlightMethod`` values listed in that same order.
    """
    server_middleware = RecordingServerMiddlewareFactory()
    client_middleware = RecordingClientMiddlewareFactory()
    with FlightServerBase(middleware={"test": server_middleware}) as server:
        client = FlightClient(
            ('localhost', server.port),
            middleware=[client_middleware]
        )

        # The call order below must stay in sync with `expected` further
        # down, since the recorded methods are compared as ordered lists.
        descriptor = flight.FlightDescriptor.for_command(b"")
        with pytest.raises(NotImplementedError):
            list(client.list_flights())
        with pytest.raises(NotImplementedError):
            client.get_flight_info(descriptor)
        with pytest.raises(NotImplementedError):
            client.get_schema(descriptor)
        with pytest.raises(NotImplementedError):
            client.do_get(flight.Ticket(b""))
        # For do_put the error may come from either the call itself or
        # from closing the writer, so both are inside the raises block.
        with pytest.raises(NotImplementedError):
            writer, _ = client.do_put(descriptor, pa.schema([]))
            writer.close()
        with pytest.raises(NotImplementedError):
            list(client.do_action(flight.Action(b"", b"")))
        with pytest.raises(NotImplementedError):
            list(client.list_actions())
        # Same pattern as do_put: open the exchange, then close the writer.
        with pytest.raises(NotImplementedError):
            writer, _ = client.do_exchange(descriptor)
            writer.close()

        # One entry per RPC issued above, in the order they were issued.
        expected = [
            flight.FlightMethod.LIST_FLIGHTS,
            flight.FlightMethod.GET_FLIGHT_INFO,
            flight.FlightMethod.GET_SCHEMA,
            flight.FlightMethod.DO_GET,
            flight.FlightMethod.DO_PUT,
            flight.FlightMethod.DO_ACTION,
            flight.FlightMethod.LIST_ACTIONS,
            flight.FlightMethod.DO_EXCHANGE,
        ]
        assert server_middleware.methods == expected
        assert client_middleware.methods == expected
Beispiel #15
0
def test_nicer_server_exceptions():
    """Check the error type and message the client sees for 'bad-action'.

    The action is expected to surface on the client as
    ``flight.FlightServerError`` with a message that includes the
    original exception type name (``TypeError``).
    """
    with ConvenienceServer() as server:
        address = ('localhost', server.port)
        client = FlightClient(address)

        expect_type_error = pytest.raises(
            flight.FlightServerError,
            match="TypeError: a bytes-like object is required")
        with expect_type_error:
            list(client.do_action('bad-action'))