def test_tls_fails():
    """Check that a client cannot connect when cert verification fails.

    The server is started with TLS certificates while the client is created
    from a bare ``grpc+tls://`` location string, so the ``do_get`` call is
    expected to raise ``FlightUnavailableError``.
    """
    certificates = example_tls_certs()["certificates"]

    with ConstantFlightServer(tls_certificates=certificates) as server:
        # This is a slow test: gRPC retries a few times before giving up
        # on the connection whose certificate verification fails.
        location = "grpc+tls://localhost:" + str(server.port)
        client = FlightClient(location)
        ticket = flight.Ticket(b'ints')

        # Only the exception type is checked; gRPC error messages differ
        # between versions.
        with pytest.raises(flight.FlightUnavailableError):
            client.do_get(ticket).read_all()
def test_cancel_do_get_threaded():
    """Cancel a DoGet stream while another thread is reading from it.

    A worker thread reads one chunk, waits until the main thread has
    cancelled the stream, then reads again. The test passes only if that
    second read raises ``FlightCancelledError``.
    """
    with SlowFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        reader = client.do_get(flight.Ticket(b'ints'))

        first_chunk_read = threading.Event()
        cancel_issued = threading.Event()
        outcome_lock = threading.Lock()
        got_cancelled_error = threading.Event()

        def read_until_cancelled():
            # First read: tells the main thread the stream is live.
            reader.read_chunk()
            first_chunk_read.set()
            # Hold off the second read until cancel() has been called.
            cancel_issued.wait(timeout=5)
            try:
                reader.read_chunk()
            except flight.FlightCancelledError:
                with outcome_lock:
                    got_cancelled_error.set()

        worker = threading.Thread(target=read_until_cancelled, daemon=True)
        worker.start()
        first_chunk_read.wait(timeout=5)
        reader.cancel()
        cancel_issued.set()
        worker.join(timeout=1)

        with outcome_lock:
            assert got_cancelled_error.is_set()
def test_cancel_do_get():
    """Cancel a DoGet stream on the client and check the next read fails."""
    with ConstantFlightServer() as server:
        address = ('localhost', server.port)
        client = FlightClient(address)
        ticket = flight.Ticket(b'ints')
        reader = client.do_get(ticket)

        reader.cancel()

        # Reading after cancel() must raise a cancellation error.
        with pytest.raises(flight.FlightCancelledError, match=".*Cancel.*"):
            reader.read_chunk()
def test_flight_domain_socket():
    """Run simple do_get calls against a server on a Unix domain socket.

    Both the ``ints`` and ``dicts`` tickets are fetched and compared
    (schema first, then data) against the matching reference table.
    """
    with tempfile.NamedTemporaryFile() as sock:
        # Only the temporary file's name is needed, as the socket path.
        sock.close()
        location = flight.Location.for_grpc_unix(sock.name)

        with ConstantFlightServer(location=location):
            client = FlightClient(location)
            cases = (
                (b'ints', simple_ints_table),
                (b'dicts', simple_dicts_table),
            )
            for ticket_bytes, make_expected in cases:
                reader = client.do_get(flight.Ticket(ticket_bytes))
                expected = make_expected()
                assert reader.schema.equals(expected.schema)
                received = reader.read_all()
                assert received.equals(expected)
def test_tls_do_get():
    """Fetch the ``ints`` table over a TLS connection.

    The client is given the root certificate from the example TLS certs, and
    the result must equal the reference ints table.
    """
    expected = simple_ints_table()
    tls = example_tls_certs()

    with ConstantFlightServer(tls_certificates=tls["certificates"]) as server:
        address = ('localhost', server.port)
        client = FlightClient(address, tls_root_certs=tls["root_cert"])
        reader = client.do_get(flight.Ticket(b'ints'))
        received = reader.read_all()
        assert received.equals(expected)
def test_middleware_mapping():
    """Check that middleware records the Flight method of each call.

    Every RPC is issued once, in a fixed order, against a base server. Each
    call is expected to raise ``NotImplementedError``. Afterwards both the
    server-side and client-side recording middleware must list the methods
    in the order the calls were made.
    """
    server_middleware = RecordingServerMiddlewareFactory()
    client_middleware = RecordingClientMiddlewareFactory()

    with FlightServerBase(middleware={"test": server_middleware}) as server:
        client = FlightClient(
            ('localhost', server.port),
            middleware=[client_middleware]
        )
        descriptor = flight.FlightDescriptor.for_command(b"")

        def _put_and_close():
            writer, _ = client.do_put(descriptor, pa.schema([]))
            writer.close()

        def _exchange_and_close():
            writer, _ = client.do_exchange(descriptor)
            writer.close()

        # (method the middleware should record, call that triggers it);
        # the order here is the order of the RPCs.
        calls = [
            (flight.FlightMethod.LIST_FLIGHTS,
             lambda: list(client.list_flights())),
            (flight.FlightMethod.GET_FLIGHT_INFO,
             lambda: client.get_flight_info(descriptor)),
            (flight.FlightMethod.GET_SCHEMA,
             lambda: client.get_schema(descriptor)),
            (flight.FlightMethod.DO_GET,
             lambda: client.do_get(flight.Ticket(b""))),
            (flight.FlightMethod.DO_PUT, _put_and_close),
            (flight.FlightMethod.DO_ACTION,
             lambda: list(client.do_action(flight.Action(b"", b"")))),
            (flight.FlightMethod.LIST_ACTIONS,
             lambda: list(client.list_actions())),
            (flight.FlightMethod.DO_EXCHANGE, _exchange_and_close),
        ]

        for _, invoke in calls:
            with pytest.raises(NotImplementedError):
                invoke()

        expected = [method for method, _ in calls]
        assert server_middleware.methods == expected
        assert client_middleware.methods == expected
def test_flight_generator_stream():
    """Upload a table, then download it again and compare.

    A single-column table of ``10 * 1024`` integers is sent with ``do_put``
    and read back with ``do_get`` from an ``EchoStreamFlightServer``.
    """
    column = pa.array(range(0, 10 * 1024))
    table = pa.Table.from_arrays([column], names=['a'])

    with EchoStreamFlightServer() as server:
        client = FlightClient(('localhost', server.port))

        descriptor = flight.FlightDescriptor.for_path('test')
        writer, _ = client.do_put(descriptor, table.schema)
        writer.write_table(table)
        writer.close()

        reader = client.do_get(flight.Ticket(b''))
        echoed = reader.read_all()
        assert echoed.equals(table)
def test_mtls():
    """Fetch the ``ints`` table over mutual TLS (mTLS) with gRPC.

    The server verifies clients against the example root certificate. The
    client presents the first example certificate/key pair, which is also
    the pair the server itself uses.
    """
    tls = example_tls_certs()
    expected = simple_ints_table()
    cert_pair = tls["certificates"][0]
    root_cert = tls["root_cert"]

    with ConstantFlightServer(
            tls_certificates=[cert_pair],
            verify_client=True,
            root_certificates=root_cert) as server:
        client = FlightClient(
            ('localhost', server.port),
            tls_root_certs=root_cert,
            cert_chain=cert_pair.cert,
            private_key=cert_pair.key)
        reader = client.do_get(flight.Ticket(b'ints'))
        received = reader.read_all()
        assert received.equals(expected)
def test_flight_large_message():
    """Send and receive a large message through Flight.

    See ARROW-4421: by default, gRPC won't allow us to send messages >
    4MiB in size. The whole table is written as one chunk and must come
    back unchanged.
    """
    num_rows = 10 * 1024 * 1024
    table = pa.Table.from_arrays([pa.array(range(0, num_rows))], names=['a'])

    with EchoFlightServer(expected_schema=table.schema) as server:
        client = FlightClient(('localhost', server.port))

        descriptor = flight.FlightDescriptor.for_path('test')
        writer, _ = client.do_put(descriptor, table.schema)
        # Write a single giant chunk.
        writer.write_table(table, num_rows)
        writer.close()

        reader = client.do_get(flight.Ticket(b''))
        echoed = reader.read_all()
        assert echoed.equals(table)
def test_flight_do_get_metadata():
    """Read a do_get stream chunk by chunk and check per-chunk metadata.

    Each chunk's metadata is unpacked as a little-endian 32-bit integer and
    must equal the zero-based position of the chunk in the stream. The
    collected batches must reassemble into the expected table.
    """
    expected = pa.Table.from_arrays(
        [pa.array([-10, -5, 0, 5, 10])], names=['a'])
    received_batches = []

    with MetadataFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        reader = client.do_get(flight.Ticket(b''))

        position = 0
        while True:
            try:
                batch, metadata = reader.read_chunk()
            except StopIteration:
                # End of stream.
                break
            else:
                received_batches.append(batch)
                server_position, = struct.unpack(
                    '<i', metadata.to_pybytes())
                assert position == server_position
                position += 1

        received = pa.Table.from_batches(received_batches)
        assert received.equals(expected)
def test_timeout_passes():
    """Check that a 5 second call timeout does not fire on a fast request."""
    with ConstantFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        call_options = flight.FlightCallOptions(timeout=5.0)
        ticket = flight.Ticket(b'ints')
        # The test passes as long as the call and the read do not raise.
        reader = client.do_get(ticket, options=call_options)
        reader.read_all()
def test_flight_invalid_generator_stream():
    """Check that streaming data with mismatched schemas raises an error.

    Reading the stream served by ``InvalidStreamFlightServer`` is expected
    to fail with ``pa.ArrowException``.
    """
    with InvalidStreamFlightServer() as server:
        address = ('localhost', server.port)
        client = FlightClient(address)
        ticket = flight.Ticket(b'')
        with pytest.raises(pa.ArrowException):
            client.do_get(ticket).read_all()