def test_do_put_independent_read_write():
    """Check that a DoPut can be read and written from separate threads."""
    # ARROW-6063: closing the writer while another thread was still
    # reading used to make gRPC abort, or hang forever.
    columns = [pa.array([-10, -5, 0, 5, 10])]
    table = pa.Table.from_arrays(columns, names=['a'])

    with MetadataFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        descriptor = flight.FlightDescriptor.for_path('')
        writer, metadata_reader = client.do_put(descriptor, table.schema)

        received = 0

        def _reader_thread():
            # Count metadata messages until read() returns None.
            nonlocal received
            while metadata_reader.read() is not None:
                received += 1

        thread = threading.Thread(target=_reader_thread)
        thread.start()

        batches = table.to_batches(max_chunksize=1)
        with writer:
            for position, batch in enumerate(batches):
                writer.write_with_metadata(batch, struct.pack('<i', position))
            # Ends the call on the server side, so it stops writing...
            writer.done_writing()
            # ...which lets the reader thread fall out of its loop.
            thread.join()
        # The reader thread has finished by the time the writer is closed,
        # so closing it cannot race with a concurrent read.
        assert received == len(batches)
def test_cancel_do_get_threaded():
    """Cancel a DoGet stream while another thread is reading from it."""
    with SlowFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        reader = client.do_get(flight.Ticket(b'ints'))

        first_chunk_read = threading.Event()
        cancel_issued = threading.Event()
        outcome_guard = threading.Lock()
        got_cancelled_error = threading.Event()

        def _read_until_cancelled():
            # Read one chunk, signal the main thread, then wait for the
            # cancellation before attempting the second read.
            reader.read_chunk()
            first_chunk_read.set()
            cancel_issued.wait(timeout=5)
            try:
                reader.read_chunk()
            except flight.FlightCancelledError:
                with outcome_guard:
                    got_cancelled_error.set()

        worker = threading.Thread(target=_read_until_cancelled, daemon=True)
        worker.start()
        first_chunk_read.wait(timeout=5)
        reader.cancel()
        cancel_issued.set()
        worker.join(timeout=1)

        with outcome_guard:
            assert got_cancelled_error.is_set()
def test_doexchange_echo():
    """Exercise the DoExchange "echo" command with metadata and data."""
    columns = [pa.array(range(10 * 1024))]
    data = pa.Table.from_arrays(columns, names=["a"])
    batches = data.to_batches(max_chunksize=512)

    with ExchangeFlightServer() as server:
        client = FlightClient(("localhost", server.port))
        echo = flight.FlightDescriptor.for_command(b"echo")
        writer, reader = client.do_exchange(echo)
        with writer:
            # Phase 1: metadata-only messages, before any data is sent.
            for n in range(10):
                payload = str(n).encode("utf-8")
                writer.write_metadata(payload)
                echoed = reader.read_chunk()
                assert echoed.data is None
                assert echoed.app_metadata == payload

            # Phase 2: data batches without metadata.
            writer.begin(data.schema)
            for batch in batches:
                writer.write_batch(batch)
                assert reader.schema == data.schema
                echoed = reader.read_chunk()
                assert echoed.data == batch
                assert echoed.app_metadata is None

            # Phase 3: data batches, each with its own metadata.
            for n, batch in enumerate(batches):
                payload = str(n).encode("utf-8")
                writer.write_with_metadata(batch, payload)
                echoed = reader.read_chunk()
                assert echoed.data == batch
                assert echoed.app_metadata == payload
def test_cancel_do_get():
    """Cancel a DoGet stream on the client, then try to read from it."""
    with ConstantFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        stream = client.do_get(flight.Ticket(b'ints'))
        stream.cancel()

        # Reading after the cancellation must raise FlightCancelledError.
        expect_cancelled = pytest.raises(flight.FlightCancelledError,
                                         match=".*Cancel.*")
        with expect_cancelled:
            stream.read_chunk()
def test_http_basic_unauth():
    """Check that an action is rejected when the client never authenticates."""
    with EchoStreamFlightServer(auth_handler=basic_auth_handler) as server:
        client = FlightClient(('localhost', server.port))
        who_am_i = flight.Action("who-am-i", b"")

        expect_unauthenticated = pytest.raises(
            flight.FlightUnauthenticatedError,
            match=".*unauthenticated.*")
        with expect_unauthenticated:
            # Drain the result iterator completely.
            list(client.do_action(who_am_i))
def test_http_basic_auth():
    """Authenticate via the Python HTTP basic auth handlers."""
    with EchoStreamFlightServer(auth_handler=basic_auth_handler) as server:
        client = FlightClient(('localhost', server.port))
        who_am_i = flight.Action("who-am-i", b"")

        credentials = HttpBasicClientAuthHandler('test', 'p4ssw0rd')
        client.authenticate(credentials)

        # The first action result is expected to carry the user name.
        first_result = next(client.do_action(who_am_i))
        assert first_result.body.to_pybytes() == b'test'
def test_token_auth():
    """Authenticate with a handshake-based (token) auth mechanism."""
    with EchoStreamFlightServer(auth_handler=token_auth_handler) as server:
        client = FlightClient(('localhost', server.port))
        who_am_i = flight.Action("who-am-i", b"")

        credentials = TokenClientAuthHandler('test', 'p4ssw0rd')
        client.authenticate(credentials)

        # The first action result is expected to carry the user name.
        first_result = next(client.do_action(who_am_i))
        assert first_result.body.to_pybytes() == b'test'
def test_http_basic_auth_invalid_password():
    """Check that authentication is rejected for a wrong password."""
    with EchoStreamFlightServer(auth_handler=basic_auth_handler) as server:
        client = FlightClient(('localhost', server.port))
        who_am_i = flight.Action("who-am-i", b"")

        expect_rejected = pytest.raises(flight.FlightUnauthenticatedError,
                                        match=".*wrong password.*")
        # Both calls stay inside the block: the error may surface from
        # either of them.
        with expect_rejected:
            bad_credentials = HttpBasicClientAuthHandler('test', 'wrong')
            client.authenticate(bad_credentials)
            next(client.do_action(who_am_i))
def test_tls_do_get():
    """Run a simple do_get call against a TLS-enabled server."""
    expected = simple_ints_table()
    certs = example_tls_certs()

    with ConstantFlightServer(tls_certificates=certs["certificates"]) as s:
        # The client trusts the example root certificate.
        client = FlightClient(('localhost', s.port),
                              tls_root_certs=certs["root_cert"])
        reader = client.do_get(flight.Ticket(b'ints'))
        received = reader.read_all()
        assert received.equals(expected)
def test_server_middleware_same_thread():
    """Check that server middleware runs on the same thread as the RPC."""
    middleware = {"test": HeaderServerMiddlewareFactory()}

    with HeaderFlightServer(middleware=middleware) as server:
        client = FlightClient(('localhost', server.port))
        action = flight.Action(b"test", b"")
        results = list(client.do_action(action))

        # Exactly one result, whose body is the expected marker value.
        assert len(results) == 1
        body = results[0].body.to_pybytes()
        assert b"right value" == body
def test_extra_info():
    """Check that extra error info sent by the server reaches the client.

    The "protobuf" action is expected to fail with
    FlightUnauthorizedError carrying a bytes payload in ``extra_info``.
    """
    with ErrorFlightServer() as server:
        client = FlightClient(('localhost', server.port))

        # pytest.raises fails the test with a clear "DID NOT RAISE"
        # message if the call unexpectedly succeeds.
        with pytest.raises(flight.FlightUnauthorizedError) as exc_info:
            list(client.do_action(flight.Action("protobuf", b"")))

        extra_info = exc_info.value.extra_info
        assert extra_info is not None
        assert extra_info == b'this is an error message'
def test_do_action_result_convenience():
    """Check the shorthand argument forms accepted by do_action."""
    # NOTE(review): a later function in this file is defined with the same
    # name, so at import time that definition replaces this one.
    with ConvenienceServer() as server:
        client = FlightClient(('localhost', server.port))

        # Shorthand 1: only the action type, no body.
        simple_results = client.do_action('simple-action')
        bodies = [result.body for result in simple_results]
        assert bodies == server.simple_action_results

        # Shorthand 2: a (type, body) tuple.
        payload = b'the-body'
        echo_results = client.do_action(('echo', payload))
        bodies = [result.body for result in echo_results]
        assert bodies == [payload]
def test_flight_generator_stream():
    """Try downloading a flight of RecordBatches in a GeneratorStream.

    Uploads a table with do_put, then downloads it with do_get and checks
    that the round-tripped table equals the original.
    """
    data = pa.Table.from_arrays([pa.array(range(0, 10 * 1024))],
                                names=['a'])

    with EchoStreamFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        writer, _ = client.do_put(flight.FlightDescriptor.for_path('test'),
                                  data.schema)
        # Use the writer as a context manager (as other tests in this file
        # do) so the stream is closed even if the upload raises.
        with writer:
            writer.write_table(data)

        result = client.do_get(flight.Ticket(b'')).read_all()
        assert result.equals(data)
def test_nicer_server_exceptions():
    """Check the error type and message surfaced for failing actions."""
    with ConvenienceServer() as server:
        client = FlightClient(('localhost', server.port))

        expect_type_error = pytest.raises(
            flight.FlightServerError,
            match="a bytes-like object is required")
        with expect_type_error:
            list(client.do_action('bad-action'))

        # Flight/C++ does send the original status code across, but it is
        # deliberately not mapped back to the equivalent exception here,
        # so that client-side and server-side errors stay distinguishable.
        expect_memory_error = pytest.raises(flight.FlightServerError,
                                            match="ArrowMemoryError")
        with expect_memory_error:
            list(client.do_action('arrow-exception'))
def test_doexchange_get():
    """Use the DoExchange "get" command the way DoGet would be used."""
    columns = [pa.array(range(10 * 1024))]
    expected = pa.Table.from_arrays(columns, names=["a"])

    with ExchangeFlightServer() as server:
        client = FlightClient(("localhost", server.port))
        get_command = flight.FlightDescriptor.for_command(b"get")
        writer, reader = client.do_exchange(get_command)
        # Nothing is written; the writer is only held open while reading.
        with writer:
            received = reader.read_all()
        assert expected == received
def test_timeout_fires():
    """Check that a call timeout is raised for a slow request."""
    # The call below runs directly on the test thread, so it relies on
    # the timeout firing to return.
    with SlowFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        slow_action = flight.Action("", b"")
        call_options = flight.FlightCallOptions(timeout=0.2)

        # Only the exception type is checked: gRPC error messages differ
        # between versions.
        with pytest.raises(flight.FlightTimedOutError):
            results = client.do_action(slow_action, options=call_options)
            list(results)
def test_tls_fails():
    """Check that a client cannot connect when cert verification fails."""
    certs = example_tls_certs()

    with ConstantFlightServer(tls_certificates=certs["certificates"]) as s:
        # No root certificate is given to the client, so verification of
        # the server certificate is expected to fail. (This test is slow
        # because gRPC retries a few times.)
        uri = "grpc+tls://localhost:" + str(s.port)
        client = FlightClient(uri)

        # Only the exception type is checked: gRPC error messages differ
        # between versions.
        with pytest.raises(flight.FlightUnavailableError):
            reader = client.do_get(flight.Ticket(b'ints'))
            reader.read_all()
def test_flight_get_info():
    """Check the FlightInfo returned by the server, including endpoints.

    Per the original intent: FlightEndpoint should accept locations given
    both as string URIs and as Location objects.
    """
    with GetInfoFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        descriptor = flight.FlightDescriptor.for_command(b'')
        info = client.get_flight_info(descriptor)

        assert info.total_records == -1
        assert info.total_bytes == -1
        assert info.schema == pa.schema([('a', pa.int32())])

        endpoints = info.endpoints
        assert len(endpoints) == 2
        assert len(endpoints[0].locations) == 1
        assert endpoints[0].locations[0] == flight.Location('grpc://test')

        expected_tcp = flight.Location.for_grpc_tcp('localhost', 5005)
        assert endpoints[1].locations[0] == expected_tcp
def test_mtls():
    """Run a do_get call over gRPC with mutual TLS (mTLS)."""
    certs = example_tls_certs()
    expected = simple_ints_table()
    # The same certificate/key pair is used by the server and the client.
    cert_pair = certs["certificates"][0]
    root_cert = certs["root_cert"]

    with ConstantFlightServer(
            tls_certificates=[cert_pair],
            verify_client=True,
            root_certificates=root_cert) as s:
        client = FlightClient(
            ('localhost', s.port),
            tls_root_certs=root_cert,
            cert_chain=cert_pair.cert,
            private_key=cert_pair.key)
        reader = client.do_get(flight.Ticket(b'ints'))
        received = reader.read_all()
        assert received.equals(expected)
def test_middleware_multi_header():
    """Send and receive multiple (binary-valued) headers via middleware."""
    server_middleware = {"test": MultiHeaderServerMiddlewareFactory()}

    with MultiHeaderFlightServer(middleware=server_middleware) as server:
        client_factory = MultiHeaderClientMiddlewareFactory()
        client = FlightClient(('localhost', server.port),
                              middleware=[client_factory])
        response = next(client.do_action(flight.Action(b"", b"")))

        # The response body is the repr of the headers the server saw;
        # parse it back into a Python object.
        body_text = response.body.to_pybytes().decode("utf-8")
        seen_by_server = ast.literal_eval(body_text)

        # Check only the expected headers rather than comparing whole
        # mappings: gRPC may add headers of its own (e.g. User-Agent).
        expected = MultiHeaderClientMiddleware.EXPECTED
        for name, values in expected.items():
            assert seen_by_server.get(name) == values
            assert client_factory.last_headers.get(name) == values
def test_do_action_result_convenience():
    """Check the shorthand argument forms accepted by do_action.

    Covers passing just an action type, passing a (type, body) tuple, and
    the error raised for an action that fails on the server.
    """
    # NOTE(review): an earlier function in this file is defined with the
    # same name; this later definition is the one bound at import time.
    with ConvenienceServer() as server:
        client = FlightClient(('localhost', server.port))

        # do_action as action type without body
        results = [x.body for x in client.do_action('simple-action')]
        assert results == server.simple_action_results

        # do_action with tuple of type and body
        body = b'the-body'
        results = [x.body for x in client.do_action(('echo', body))]
        assert results == [body]

        # ARROW-6884: assert the specific server error rather than a bare
        # Exception, which would also accept unrelated failures. This is
        # the type test_nicer_server_exceptions expects for 'bad-action'.
        with pytest.raises(flight.FlightServerError):
            list(client.do_action('bad-action'))
def test_flight_do_put_metadata():
    """Run a do_put call where every batch carries application metadata."""
    columns = [pa.array([-10, -5, 0, 5, 10])]
    table = pa.Table.from_arrays(columns, names=['a'])

    with MetadataFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        descriptor = flight.FlightDescriptor.for_path('')
        writer, metadata_reader = client.do_put(descriptor, table.schema)

        with writer:
            single_row_batches = table.to_batches(max_chunksize=1)
            for sent_idx, batch in enumerate(single_row_batches):
                # Send the batch index as a little-endian int32 and
                # expect the server to answer with the same value.
                writer.write_with_metadata(batch,
                                           struct.pack('<i', sent_idx))
                reply = metadata_reader.read()
                assert reply is not None
                (server_idx,) = struct.unpack('<i', reply.to_pybytes())
                assert sent_idx == server_idx
def test_doexchange_put():
    """Use the DoExchange "put" command the way DoPut would be used."""
    columns = [pa.array(range(10 * 1024))]
    data = pa.Table.from_arrays(columns, names=["a"])
    batches = data.to_batches(max_chunksize=512)

    with ExchangeFlightServer() as server:
        client = FlightClient(("localhost", server.port))
        put_command = flight.FlightDescriptor.for_command(b"put")
        writer, reader = client.do_exchange(put_command)
        with writer:
            writer.begin(data.schema)
            for batch in batches:
                writer.write_batch(batch)
            writer.done_writing()

            # The reply is a metadata-only chunk holding the number of
            # batches as a UTF-8 decimal string.
            reply = reader.read_chunk()
            assert reply.data is None
            batch_count = str(len(batches)).encode("utf-8")
            assert reply.app_metadata == batch_count
def test_flight_large_message():
    """Try sending/receiving a large message via Flight.

    See ARROW-4421: by default, gRPC won't allow us to send messages >
    4MiB in size.
    """
    data = pa.Table.from_arrays([pa.array(range(0, 10 * 1024 * 1024))],
                                names=['a'])

    with EchoFlightServer(expected_schema=data.schema) as server:
        client = FlightClient(('localhost', server.port))
        writer, _ = client.do_put(flight.FlightDescriptor.for_path('test'),
                                  data.schema)
        # Use the writer as a context manager (as other tests in this file
        # do) so the stream is closed even if the upload raises.
        with writer:
            # Write a single giant chunk
            writer.write_table(data, 10 * 1024 * 1024)

        result = client.do_get(flight.Ticket(b'')).read_all()
        assert result.equals(data)
def test_client_wait_for_available():
    """Check that wait_for_available blocks until a late server is up.

    A server is started on a background thread after a fixed delay; the
    client's wait must therefore take at least that long.
    """
    location = ('localhost', find_free_port())
    startup_delay = 0.5

    def serve():
        # Delay startup so the client really has to wait.
        time.sleep(startup_delay)
        # A plain local keeps the server alive for as long as serve()
        # runs; no module-level global is needed.
        server = FlightServerBase(location)
        server.serve()

    client = FlightClient(location)
    thread = threading.Thread(target=serve, daemon=True)

    # Start the clock before the thread so the measurement always spans
    # the whole startup delay, and use a clock that is unaffected by
    # system time adjustments.
    started = time.perf_counter()
    thread.start()
    client.wait_for_available(timeout=5)
    elapsed = time.perf_counter() - started
    assert elapsed >= startup_delay
def test_flight_domain_socket():
    """Run simple do_get calls over a Unix domain socket."""
    with tempfile.NamedTemporaryFile() as sock:
        # Closing the temporary file removes it; only its unique path is
        # reused, as the address of the domain socket.
        sock.close()
        location = flight.Location.for_grpc_unix(sock.name)

        with ConstantFlightServer(location=location):
            client = FlightClient(location)

            # Each ticket is paired with the helper that builds the table
            # expected for it.
            cases = (
                (b'ints', simple_ints_table),
                (b'dicts', simple_dicts_table),
            )
            for ticket_bytes, make_expected in cases:
                reader = client.do_get(flight.Ticket(ticket_bytes))
                expected = make_expected()
                assert reader.schema.equals(expected.schema)
                received = reader.read_all()
                assert received.equals(expected)
def test_flight_do_get_metadata():
    """Run a do_get call where every chunk carries application metadata."""
    columns = [pa.array([-10, -5, 0, 5, 10])]
    expected = pa.Table.from_arrays(columns, names=['a'])

    received_batches = []
    with MetadataFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        reader = client.do_get(flight.Ticket(b''))

        expected_idx = 0
        while True:
            # read_chunk signals the end of the stream with StopIteration.
            try:
                batch, metadata = reader.read_chunk()
            except StopIteration:
                break
            received_batches.append(batch)
            # The metadata is a little-endian int32 holding the chunk's
            # position in the stream.
            (server_idx,) = struct.unpack('<i', metadata.to_pybytes())
            assert expected_idx == server_idx
            expected_idx += 1

        received = pa.Table.from_batches(received_batches)
        assert received.equals(expected)
def test_doexchange_transform():
    """Send a table through the DoExchange "transform" command."""
    input_columns = [
        pa.array(range(0, 1024)),
        pa.array(range(1, 1025)),
        pa.array(range(2, 1026)),
    ]
    data = pa.Table.from_arrays(input_columns, names=["a", "b", "c"])
    # Row i of the inputs is (i, i + 1, i + 2), which sums to 3 * i + 3.
    sums = pa.array(range(3, 1024 * 3 + 3, 3))
    expected = pa.Table.from_arrays([sums], names=["sum"])

    with ExchangeFlightServer() as server:
        client = FlightClient(("localhost", server.port))
        transform = flight.FlightDescriptor.for_command(b"transform")
        writer, reader = client.do_exchange(transform)
        with writer:
            writer.begin(data.schema)
            writer.write_table(data)
            writer.done_writing()
            received = reader.read_all()
        assert expected == received
def test_roundtrip_errors():
    """Check that Flight errors raised on the server reach the client."""
    # Action type sent to the server -> exception expected on the client.
    action_errors = (
        ("internal", flight.FlightInternalError),
        ("timedout", flight.FlightTimedOutError),
        ("cancel", flight.FlightCancelledError),
        ("unauthenticated", flight.FlightUnauthenticatedError),
        ("unauthorized", flight.FlightUnauthorizedError),
    )

    with ErrorFlightServer() as server:
        client = FlightClient(('localhost', server.port))

        for action_type, error_cls in action_errors:
            with pytest.raises(error_cls, match=".*foo.*"):
                list(client.do_action(flight.Action(action_type, b"")))

        # Errors must also propagate from a call other than do_action.
        with pytest.raises(flight.FlightInternalError, match=".*foo.*"):
            list(client.list_flights())
def test_list_actions():
    """Check that the return type of ListActions is validated."""
    # ARROW-6392
    with ListActionsErrorFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        expect_unknown_error = pytest.raises(pa.ArrowException,
                                             match=".*unknown error.*")
        with expect_unknown_error:
            list(client.list_actions())

    with ListActionsFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        actions = list(client.list_actions())
        assert actions == ListActionsFlightServer.expected_actions()