示例#1
0
def test_do_put_independent_read_write():
    """Check that a DoPut can be written and read from separate threads.

    The main thread writes record batches (each tagged with its index as
    metadata) while a background thread drains the metadata stream coming
    back from the server. Every written batch must be acknowledged.
    """
    # ARROW-6063: gRPC used to abort when the writer was closed while
    # another thread was reading, or the call would hang forever.
    table = pa.Table.from_arrays([pa.array([-10, -5, 0, 5, 10])],
                                 names=['a'])

    with MetadataFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        descriptor = flight.FlightDescriptor.for_path('')
        writer, metadata_reader = client.do_put(descriptor, table.schema)

        acks_received = 0

        def drain_metadata():
            nonlocal acks_received
            while metadata_reader.read() is not None:
                acks_received += 1

        reader_thread = threading.Thread(target=drain_metadata)
        reader_thread.start()

        batches = table.to_batches(max_chunksize=1)
        with writer:
            for idx, batch in enumerate(batches):
                writer.write_with_metadata(batch, struct.pack('<i', idx))
            # Finishing the write side causes the server to stop writing
            # and end the call, which ends the reader thread's loop.
            writer.done_writing()
            reader_thread.join()
        # The writer is closed on leaving the `with`; that is only safe
        # once the reader thread has stopped.
        assert acks_received == len(batches)
示例#2
0
def test_cancel_do_get_threaded():
    """Cancel a DoGet from the main thread while a worker thread reads it.

    The worker reads one chunk, then waits for the main thread to cancel
    the stream; its next read is expected to raise FlightCancelledError.
    """
    with SlowFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        reader = client.do_get(flight.Ticket(b'ints'))

        first_chunk_read = threading.Event()
        cancel_issued = threading.Event()
        lock = threading.Lock()
        got_cancel_error = threading.Event()

        def reader_worker():
            reader.read_chunk()
            first_chunk_read.set()
            # Hold off the second read until the main thread has cancelled.
            cancel_issued.wait(timeout=5)
            try:
                reader.read_chunk()
            except flight.FlightCancelledError:
                with lock:
                    got_cancel_error.set()

        worker = threading.Thread(target=reader_worker, daemon=True)
        worker.start()

        first_chunk_read.wait(timeout=5)
        reader.cancel()
        cancel_issued.set()
        worker.join(timeout=1)

        with lock:
            assert got_cancel_error.is_set()
示例#3
0
def test_doexchange_echo():
    """Round-trip metadata and record batches through a DoExchange echo."""
    table = pa.Table.from_arrays([pa.array(range(0, 10 * 1024))],
                                 names=["a"])
    batches = table.to_batches(max_chunksize=512)

    with ExchangeFlightServer() as server:
        client = FlightClient(("localhost", server.port))
        writer, reader = client.do_exchange(
            flight.FlightDescriptor.for_command(b"echo"))
        with writer:
            # Phase 1: metadata-only messages, sent before any schema.
            for n in range(10):
                payload = str(n).encode("utf-8")
                writer.write_metadata(payload)
                echoed = reader.read_chunk()
                assert echoed.data is None
                assert echoed.app_metadata == payload

            # Phase 2: record batches without metadata.
            writer.begin(table.schema)
            for batch in batches:
                writer.write_batch(batch)
                assert reader.schema == table.schema
                echoed = reader.read_chunk()
                assert echoed.data == batch
                assert echoed.app_metadata is None

            # Phase 3: record batches, each paired with metadata.
            for n, batch in enumerate(batches):
                payload = str(n).encode("utf-8")
                writer.write_with_metadata(batch, payload)
                echoed = reader.read_chunk()
                assert echoed.data == batch
                assert echoed.app_metadata == payload
示例#4
0
def test_cancel_do_get():
    """Cancelling a DoGet on the client makes the next read fail."""
    with ConstantFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        stream = client.do_get(flight.Ticket(b'ints'))
        stream.cancel()
        expect_cancelled = pytest.raises(flight.FlightCancelledError,
                                         match=".*Cancel.*")
        with expect_cancelled:
            stream.read_chunk()
示例#5
0
def test_http_basic_unauth():
    """Calling a basic-auth server without authenticating is rejected."""
    with EchoStreamFlightServer(auth_handler=basic_auth_handler) as server:
        client = FlightClient(('localhost', server.port))
        who_am_i = flight.Action("who-am-i", b"")
        with pytest.raises(flight.FlightUnauthenticatedError,
                           match=".*unauthenticated.*"):
            # Drain the whole result stream so any error raised while
            # iterating the results is observed inside this block.
            for _ in client.do_action(who_am_i):
                pass
示例#6
0
def test_http_basic_auth():
    """Authenticate with Python HTTP basic auth handlers, then whoami."""
    with EchoStreamFlightServer(auth_handler=basic_auth_handler) as server:
        client = FlightClient(('localhost', server.port))
        client.authenticate(HttpBasicClientAuthHandler('test', 'p4ssw0rd'))
        results = client.do_action(flight.Action("who-am-i", b""))
        first = next(results)
        assert first.body.to_pybytes() == b'test'
示例#7
0
def test_token_auth():
    """Authenticate via a token handshake, then ask the server who we are."""
    with EchoStreamFlightServer(auth_handler=token_auth_handler) as server:
        client = FlightClient(('localhost', server.port))
        client.authenticate(TokenClientAuthHandler('test', 'p4ssw0rd'))
        results = client.do_action(flight.Action("who-am-i", b""))
        first = next(results)
        assert first.body.to_pybytes() == b'test'
示例#8
0
def test_http_basic_auth_invalid_password():
    """Basic auth with a wrong password is rejected."""
    with EchoStreamFlightServer(auth_handler=basic_auth_handler) as server:
        client = FlightClient(('localhost', server.port))
        who_am_i = flight.Action("who-am-i", b"")
        bad_credentials = HttpBasicClientAuthHandler('test', 'wrong')
        with pytest.raises(flight.FlightUnauthenticatedError,
                           match=".*wrong password.*"):
            # The test accepts the error from either the handshake or
            # the first call that follows it.
            client.authenticate(bad_credentials)
            next(client.do_action(who_am_i))
示例#9
0
def test_tls_do_get():
    """Fetch the ints table over a TLS-secured connection."""
    expected = simple_ints_table()
    certs = example_tls_certs()
    server_certs = certs["certificates"]
    root_cert = certs["root_cert"]

    with ConstantFlightServer(tls_certificates=server_certs) as s:
        client = FlightClient(('localhost', s.port), tls_root_certs=root_cert)
        reader = client.do_get(flight.Ticket(b'ints'))
        assert reader.read_all().equals(expected)
示例#10
0
def test_server_middleware_same_thread():
    """Server middleware must run on the same thread as the RPC handler."""
    middleware = {"test": HeaderServerMiddlewareFactory()}
    with HeaderFlightServer(middleware=middleware) as server:
        client = FlightClient(('localhost', server.port))
        action = flight.Action(b"test", b"")
        results = list(client.do_action(action))
        assert len(results) == 1
        assert results[0].body.to_pybytes() == b"right value"
示例#11
0
def test_extra_info():
    """Server errors carry their ``extra_info`` payload to the client.

    The "protobuf" action of ErrorFlightServer is expected to fail with
    FlightUnauthorizedError whose ``extra_info`` holds the raw bytes below.
    """
    with ErrorFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        # Use pytest.raises rather than try / ``assert False``: it fails the
        # test when nothing is raised even under ``python -O``, which strips
        # bare asserts.
        with pytest.raises(flight.FlightUnauthorizedError) as excinfo:
            list(client.do_action(flight.Action("protobuf", b"")))
        extra_info = excinfo.value.extra_info
        assert extra_info is not None
        assert extra_info == b'this is an error message'
示例#12
0
def test_do_action_result_convenience():
    """do_action accepts a bare action type or a (type, body) tuple."""
    # NOTE(review): another test_do_action_result_convenience appears later
    # in this listing; if both live in one module, pytest only collects the
    # last definition and this one never runs.
    with ConvenienceServer() as server:
        client = FlightClient(('localhost', server.port))

        # A bare string is used as the action type, with no body.
        simple = [result.body
                  for result in client.do_action('simple-action')]
        assert simple == server.simple_action_results

        # A (type, body) tuple sends the body along with the action.
        payload = b'the-body'
        echoed = [result.body
                  for result in client.do_action(('echo', payload))]
        assert echoed == [payload]
示例#13
0
def test_flight_generator_stream():
    """Try downloading a flight of RecordBatches in a GeneratorStream.

    The table is first uploaded with DoPut and then read back with DoGet;
    the downloaded table must equal the uploaded one.
    """
    data = pa.Table.from_arrays([pa.array(range(0, 10 * 1024))], names=['a'])

    with EchoStreamFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        writer, _ = client.do_put(flight.FlightDescriptor.for_path('test'),
                                  data.schema)
        # The context manager closes the writer even if write_table raises,
        # so a failed upload does not leave the stream open.
        with writer:
            writer.write_table(data)
        result = client.do_get(flight.Ticket(b'')).read_all()
        assert result.equals(data)
示例#14
0
def test_nicer_server_exceptions():
    """Server-side exceptions reach the client as FlightServerError."""
    with ConvenienceServer() as server:
        client = FlightClient(('localhost', server.port))

        expectations = [
            ('bad-action', "a bytes-like object is required"),
            # Flight/C++ transmits the original status code, but it is
            # deliberately not mapped back to the equivalent exception
            # here, so client- and server-side errors stay
            # distinguishable.
            ('arrow-exception', "ArrowMemoryError"),
        ]
        for action_type, message in expectations:
            with pytest.raises(flight.FlightServerError, match=message):
                list(client.do_action(action_type))
示例#15
0
def test_doexchange_get():
    """Use DoExchange like DoGet: only read, never write any data."""
    expected = pa.Table.from_arrays([pa.array(range(0, 10 * 1024))],
                                    names=["a"])

    with ExchangeFlightServer() as server:
        client = FlightClient(("localhost", server.port))
        get_command = flight.FlightDescriptor.for_command(b"get")
        writer, reader = client.do_exchange(get_command)
        with writer:
            received = reader.read_all()
        assert expected == received
示例#16
0
def test_timeout_fires():
    """Make sure timeouts fire on slow requests."""
    # The call runs directly on the test thread (no helper thread); the
    # short per-call timeout is what keeps the slow server from stalling
    # the test.
    with SlowFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        action = flight.Action("", b"")
        options = flight.FlightCallOptions(timeout=0.2)
        # gRPC error messages change based on version, so don't look
        # for a particular error
        with pytest.raises(flight.FlightTimedOutError):
            list(client.do_action(action, options=options))
示例#17
0
def test_tls_fails():
    """Clients must not connect when certificate verification fails."""
    certs = example_tls_certs()
    server_certs = certs["certificates"]

    with ConstantFlightServer(tls_certificates=server_certs) as s:
        # No tls_root_certs are supplied, so the client presumably cannot
        # verify the example server certificate. This test is slow because
        # gRPC retries a few times before giving up.
        uri = "grpc+tls://localhost:" + str(s.port)
        client = FlightClient(uri)

        # Error text differs across gRPC versions; check only the type.
        with pytest.raises(flight.FlightUnavailableError):
            client.do_get(flight.Ticket(b'ints')).read_all()
示例#18
0
def test_flight_get_info():
    """Check that FlightEndpoint accepts both string and object URIs."""
    with GetInfoFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        descriptor = flight.FlightDescriptor.for_command(b'')
        info = client.get_flight_info(descriptor)

        assert info.total_records == -1
        assert info.total_bytes == -1
        assert info.schema == pa.schema([('a', pa.int32())])

        endpoints = info.endpoints
        assert len(endpoints) == 2
        first_locations = endpoints[0].locations
        assert len(first_locations) == 1
        assert first_locations[0] == flight.Location('grpc://test')
        expected_second = flight.Location.for_grpc_tcp('localhost', 5005)
        assert endpoints[1].locations[0] == expected_second
示例#19
0
def test_mtls():
    """Mutual TLS: the server also verifies the client's certificate."""
    certs = example_tls_certs()
    table = simple_ints_table()
    identity = certs["certificates"][0]
    root_cert = certs["root_cert"]

    with ConstantFlightServer(tls_certificates=[identity],
                              verify_client=True,
                              root_certificates=root_cert) as s:
        client = FlightClient(('localhost', s.port),
                              tls_root_certs=root_cert,
                              cert_chain=identity.cert,
                              private_key=identity.key)
        result = client.do_get(flight.Ticket(b'ints')).read_all()
        assert result.equals(table)
示例#20
0
def test_middleware_multi_header():
    """Round-trip several binary-valued headers through middleware."""
    server_middleware = {"test": MultiHeaderServerMiddlewareFactory()}
    with MultiHeaderFlightServer(middleware=server_middleware) as server:
        client_factory = MultiHeaderClientMiddlewareFactory()
        client = FlightClient(('localhost', server.port),
                              middleware=[client_factory])
        response = next(client.do_action(flight.Action(b"", b"")))
        # The server echoes the headers it received, encoded as a Python
        # literal.
        echoed = ast.literal_eval(response.body.to_pybytes().decode("utf-8"))
        # gRPC may add extra headers (e.g. User-Agent), so check only the
        # expected ones instead of comparing whole mappings.
        for name, expected in MultiHeaderClientMiddleware.EXPECTED.items():
            assert echoed.get(name) == expected
            assert client_factory.last_headers.get(name) == expected
示例#21
0
def test_do_action_result_convenience():
    """do_action accepts a bare action type or a (type, body) tuple.

    Also checks that a failing server-side action surfaces on the client
    as a specific, helpful error (ARROW-6884).
    """
    with ConvenienceServer() as server:
        client = FlightClient(('localhost', server.port))

        # do_action as action type without body
        results = [x.body for x in client.do_action('simple-action')]
        assert results == server.simple_action_results

        # do_action with tuple of type and body
        body = b'the-body'
        results = [x.body for x in client.do_action(('echo', body))]
        assert results == [body]

        # ARROW-6884: pin the specific exception type and message.
        # ``pytest.raises(Exception)`` would also accept unrelated failures
        # (e.g. a connection error), defeating the purpose of the check.
        with pytest.raises(flight.FlightServerError,
                           match="a bytes-like object is required"):
            list(client.do_action('bad-action'))
示例#22
0
def test_flight_do_put_metadata():
    """DoPut with per-batch metadata; the server echoes each batch index."""
    table = pa.Table.from_arrays([pa.array([-10, -5, 0, 5, 10])],
                                 names=['a'])

    with MetadataFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        writer, metadata_reader = client.do_put(
            flight.FlightDescriptor.for_path(''), table.schema)
        with writer:
            batches = table.to_batches(max_chunksize=1)
            for expected_idx, batch in enumerate(batches):
                sent = struct.pack('<i', expected_idx)
                writer.write_with_metadata(batch, sent)
                ack = metadata_reader.read()
                assert ack is not None
                (acked_idx,) = struct.unpack('<i', ack.to_pybytes())
                assert acked_idx == expected_idx
示例#23
0
def test_doexchange_put():
    """Use DoExchange like DoPut: upload batches, get a count back."""
    table = pa.Table.from_arrays([pa.array(range(0, 10 * 1024))],
                                 names=["a"])
    batches = table.to_batches(max_chunksize=512)

    with ExchangeFlightServer() as server:
        client = FlightClient(("localhost", server.port))
        writer, reader = client.do_exchange(
            flight.FlightDescriptor.for_command(b"put"))
        with writer:
            writer.begin(table.schema)
            for batch in batches:
                writer.write_batch(batch)
            writer.done_writing()

            # The reply is a metadata-only message carrying the number of
            # batches as a decimal string.
            reply = reader.read_chunk()
            assert reply.data is None
            assert reply.app_metadata == str(len(batches)).encode("utf-8")
示例#24
0
def test_flight_large_message():
    """Try sending/receiving a large message via Flight.

    See ARROW-4421: by default, gRPC won't allow us to send messages >
    4MiB in size.
    """
    data = pa.Table.from_arrays([pa.array(range(0, 10 * 1024 * 1024))],
                                names=['a'])

    with EchoFlightServer(expected_schema=data.schema) as server:
        client = FlightClient(('localhost', server.port))
        writer, _ = client.do_put(flight.FlightDescriptor.for_path('test'),
                                  data.schema)
        # Write the whole table as a single giant chunk (the chunk-size
        # argument equals the row count). The context manager closes the
        # writer even if the write fails.
        with writer:
            writer.write_table(data, 10 * 1024 * 1024)
        result = client.do_get(flight.Ticket(b'')).read_all()
        assert result.equals(data)
示例#25
0
def test_client_wait_for_available():
    """wait_for_available blocks until a server is started.

    A background thread starts a server after a fixed delay. The client's
    wait must therefore take at least that long. The server is shut down
    afterwards so it does not outlive the test.
    """
    location = ('localhost', find_free_port())
    startup_delay = 0.5
    server = None
    server_created = threading.Event()

    def serve():
        # ``nonlocal`` (not ``global``) so the test body sees this server
        # and can shut it down.
        nonlocal server
        time.sleep(startup_delay)
        server = FlightServerBase(location)
        server_created.set()
        server.serve()

    client = FlightClient(location)
    thread = threading.Thread(target=serve, daemon=True)

    # Start timing before the thread starts. The sleep then provably
    # begins after ``started``, so the lower bound below cannot flake.
    # monotonic() is immune to wall-clock adjustments.
    started = time.monotonic()
    thread.start()
    try:
        client.wait_for_available(timeout=5)
        elapsed = time.monotonic() - started
        assert elapsed >= startup_delay
    finally:
        if server_created.wait(timeout=5):
            server.shutdown()
        thread.join(timeout=5)
示例#26
0
def test_flight_domain_socket():
    """Try simple do_get calls over a Unix domain socket."""
    with tempfile.NamedTemporaryFile() as sock:
        # Only a unique path is needed. Closing the temporary file removes
        # it, leaving the name free for the socket.
        sock.close()
        location = flight.Location.for_grpc_unix(sock.name)
        with ConstantFlightServer(location=location):
            client = FlightClient(location)
            cases = [
                (b'ints', simple_ints_table),
                (b'dicts', simple_dicts_table),
            ]
            for ticket, make_table in cases:
                reader = client.do_get(flight.Ticket(ticket))
                expected = make_table()
                assert reader.schema.equals(expected.schema)
                assert reader.read_all().equals(expected)
示例#27
0
def test_flight_do_get_metadata():
    """DoGet where each batch carries its index as app metadata."""
    table = pa.Table.from_arrays([pa.array([-10, -5, 0, 5, 10])],
                                 names=['a'])

    received = []
    with MetadataFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        reader = client.do_get(flight.Ticket(b''))
        expected_idx = 0
        while True:
            try:
                batch, metadata = reader.read_chunk()
            except StopIteration:
                # read_chunk signals the end of the stream this way.
                break
            received.append(batch)
            (server_idx,) = struct.unpack('<i', metadata.to_pybytes())
            assert server_idx == expected_idx
            expected_idx += 1
        result = pa.Table.from_batches(received)
        assert result.equals(table)
示例#28
0
def test_doexchange_transform():
    """Send a three-column table; get back the per-row sums."""
    n_rows = 1024
    data = pa.Table.from_arrays(
        [pa.array(range(offset, n_rows + offset)) for offset in range(3)],
        names=["a", "b", "c"])
    # Row i holds (i, i + 1, i + 2), whose sum is 3 * i + 3.
    expected = pa.Table.from_arrays(
        [pa.array(range(3, n_rows * 3 + 3, 3))], names=["sum"])

    with ExchangeFlightServer() as server:
        client = FlightClient(("localhost", server.port))
        transform = flight.FlightDescriptor.for_command(b"transform")
        writer, reader = client.do_exchange(transform)
        with writer:
            writer.begin(data.schema)
            writer.write_table(data)
            writer.done_writing()
            result = reader.read_all()
        assert expected == result
示例#29
0
def test_roundtrip_errors():
    """Ensure that Flight errors propagate from server to client."""
    # Each action name is expected to make the server fail with the
    # paired Flight error type.
    expected_errors = [
        ("internal", flight.FlightInternalError),
        ("timedout", flight.FlightTimedOutError),
        ("cancel", flight.FlightCancelledError),
        ("unauthenticated", flight.FlightUnauthenticatedError),
        ("unauthorized", flight.FlightUnauthorizedError),
    ]
    with ErrorFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        for action_type, error_cls in expected_errors:
            with pytest.raises(error_cls, match=".*foo.*"):
                list(client.do_action(flight.Action(action_type, b"")))
        # Errors from non-action RPCs must propagate as well.
        with pytest.raises(flight.FlightInternalError, match=".*foo.*"):
            list(client.list_flights())
示例#30
0
def test_list_actions():
    """Make sure the return type of ListActions is validated (ARROW-6392)."""
    # A server whose ListActions output is presumably invalid must surface
    # as an ArrowException on the client.
    with ListActionsErrorFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        with pytest.raises(pa.ArrowException, match=".*unknown error.*"):
            list(client.list_actions())

    # A well-behaved server's actions come back as expected.
    with ListActionsFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        actions = list(client.list_actions())
        assert actions == ListActionsFlightServer.expected_actions()