예제 #1
0
def test_flight_do_get_dicts():
    """Fetch the 'dicts' ticket and compare it with the reference table."""
    expected = simple_dicts_table()

    with ConstantFlightServer() as server:
        conn = flight.connect(('localhost', server.port))
        ticket = flight.Ticket(b'dicts')
        received = conn.do_get(ticket).read_all()
        assert received.equals(expected)
예제 #2
0
def test_flight_do_get_dicts():
    """Fetch the 'dicts' ticket and compare it with the reference table."""
    expected = simple_dicts_table()

    with flight_server(ConstantFlightServer) as location:
        conn = flight.FlightClient.connect(location)
        ticket = flight.Ticket(b'dicts')
        received = conn.do_get(ticket).read_all()
        assert received.equals(expected)
예제 #3
0
def client(location="grpc://35.168.111.94:8815", total=191):
    """Stream the 'molbeam' flight from a Flight server, showing progress.

    Every item of the stream is fetched and discarded; nothing is returned.

    Args:
        location: URI of the Flight server. Defaults to the endpoint that was
            previously hard-coded; pass e.g. "grpc://0.0.0.0:8815" to target
            a local server.
        total: expected number of items in the stream, used only as the
            tqdm progress-bar total.
    """
    flight_client = fl.connect(location)

    stream = flight_client.do_get(fl.Ticket('molbeam'))
    for _ in tqdm(stream, total=total):
        pass
예제 #4
0
def test_cancel_do_get_threaded():
    """Test canceling a DoGet operation from another thread.

    A background thread reads one chunk and then blocks. The main thread
    cancels the reader. The test passes only if the background thread's
    next read raises FlightCancelledError.
    """
    # NOTE(review): SlowFlightServer is defined elsewhere; presumably it
    # streams slowly so the cancel lands mid-stream. Confirm.
    with flight_server(SlowFlightServer) as server_location:
        client = flight.FlightClient.connect(server_location)
        reader = client.do_get(flight.Ticket(b'ints'))

        # Handshake signals between the main thread and the reader thread.
        read_first_message = threading.Event()
        stream_canceled = threading.Event()
        result_lock = threading.Lock()
        raised_proper_exception = threading.Event()

        def block_read():
            # Consume one chunk, then tell the main thread the stream is
            # live before waiting for it to cancel the reader.
            reader.read_chunk()
            read_first_message.set()
            stream_canceled.wait(timeout=5)
            # This read happens after cancel(). Only FlightCancelledError
            # sets the flag; any other outcome leaves it unset and the
            # assertion below fails.
            try:
                reader.read_chunk()
            except flight.FlightCancelledError:
                with result_lock:
                    raised_proper_exception.set()

        # daemon=True so a stuck reader thread cannot block interpreter exit.
        thread = threading.Thread(target=block_read, daemon=True)
        thread.start()
        # The waits/join below are bounded. Their return values are not
        # checked, so a timeout shows up as a failed assertion, not a hang.
        read_first_message.wait(timeout=5)
        reader.cancel()
        stream_canceled.set()
        thread.join(timeout=1)

        with result_lock:
            assert raised_proper_exception.is_set()
예제 #5
0
    def get_data(
        self,
        selector: SeriesSelector,
        start_date: datetime = None,
        end_date: datetime = None,
    ) -> pa.Table:
        """Get raw data for the time series selected by the SeriesSelector.

        The request is sent as a Flight ticket containing a JSON document
        with keys "query" ("get_data"), "selector" (source and name),
        "start_date" and "end_date" (both ISO 8601 strings).

        Args:
            selector: return data for the time series selected by this selector.
            start_date: the start date of the time range of data to return.
                Defaults to one year before now (UTC); on Feb 29 the default
                is Feb 28 of the previous year.
            end_date: the end date of the time range of data to return.
                Defaults to now (UTC).

        Returns:
            A pyarrow Table with two columns: 'ts' and 'value'.
        """
        if start_date is None or end_date is None:
            # Timezone-aware UTC "now". This is equivalent to the deprecated
            # datetime.utcnow() with a zero-offset tzinfo attached.
            now = datetime.now(tz=timezone.utc)
            if start_date is None:
                try:
                    start_date = now.replace(year=now.year - 1)
                except ValueError:
                    # Feb 29 does not exist in the previous (non-leap) year,
                    # so replace() rejects it; clamp to Feb 28 instead.
                    start_date = now.replace(year=now.year - 1, day=28)
            if end_date is None:
                end_date = now
        query = {
            "query": "get_data",
            "selector": {
                "source": selector.source,
                "name": selector.name,
            },
            "start_date": start_date.isoformat(),
            "end_date": end_date.isoformat(),
        }
        ticket = fl.Ticket(json.dumps(query))
        return self._get_client().do_get(ticket).read_all()
예제 #6
0
def test_roundtrip_types():
    """Check that serializable Flight types survive serialize/deserialize."""
    def roundtrip(obj, cls):
        # Serialize obj and rebuild it with cls.deserialize.
        return cls.deserialize(obj.serialize())

    ticket = flight.Ticket("foo")
    assert ticket == roundtrip(ticket, flight.Ticket)

    command_desc = flight.FlightDescriptor.for_command("test")
    assert command_desc == roundtrip(command_desc, flight.FlightDescriptor)

    path_desc = flight.FlightDescriptor.for_path("a", "b", "test.arrow")
    assert path_desc == roundtrip(path_desc, flight.FlightDescriptor)

    schema = pa.schema([('a', pa.int32())])
    endpoints = [
        flight.FlightEndpoint(b'', ['grpc://test']),
        flight.FlightEndpoint(
            b'',
            [flight.Location.for_grpc_tcp('localhost', 5005)],
        ),
    ]
    info = flight.FlightInfo(schema, path_desc, endpoints, -1, -1)
    restored = roundtrip(info, flight.FlightInfo)
    for attr in ('schema', 'descriptor', 'total_bytes', 'total_records',
                 'endpoints'):
        assert getattr(info, attr) == getattr(restored, attr)
예제 #7
0
def test_flight_do_get_ticket():
    """Check that the ticket bytes sent by the client reach the server."""
    ticket_bytes = b'the-ticket'
    column = pa.array([-10, -5, 0, 5, 10], type=pa.int32())
    expected = pa.Table.from_arrays([column], names=['a'])
    with CheckTicketFlightServer(expected_ticket=ticket_bytes) as server:
        conn = flight.connect(('localhost', server.port))
        received = conn.do_get(flight.Ticket(ticket_bytes)).read_all()
        assert received.equals(expected)
예제 #8
0
def test_do_get_ints_pandas():
    """Fetch the 'ints' ticket as a pandas DataFrame and check its values."""
    reference = simple_ints_table()

    with flight_server(ConstantFlightServer) as location:
        conn = flight.FlightClient.connect(location)
        frame = conn.do_get(flight.Ticket(b'ints')).read_pandas()
        expected_values = reference.column(0).to_pylist()
        assert list(frame['some_ints']) == expected_values
예제 #9
0
def test_cancel_do_get():
    """Cancel a DoGet stream client-side; the next read must fail."""
    with ConstantFlightServer() as server:
        conn = FlightClient(('localhost', server.port))
        stream = conn.do_get(flight.Ticket(b'ints'))
        stream.cancel()
        expect_cancel = pytest.raises(flight.FlightCancelledError,
                                      match=".*Cancel.*")
        with expect_cancel:
            stream.read_chunk()
예제 #10
0
def test_do_get_ints_pandas():
    """Fetch the 'ints' ticket as a pandas DataFrame and check its values."""
    reference = simple_ints_table()

    with ConstantFlightServer() as server:
        conn = flight.connect(('localhost', server.port))
        frame = conn.do_get(flight.Ticket(b'ints')).read_pandas()
        expected_values = reference.column(0).to_pylist()
        assert list(frame['some_ints']) == expected_values
예제 #11
0
def test_cancel_do_get():
    """Cancel a DoGet stream client-side; the next read must fail."""
    with flight_server(ConstantFlightServer) as location:
        conn = flight.FlightClient.connect(location)
        stream = conn.do_get(flight.Ticket(b'ints'))
        stream.cancel()
        expect_cancel = pytest.raises(flight.FlightCancelledError,
                                      match=".*Cancel.*")
        with expect_cancel:
            stream.read_chunk()
 def snapshot_table(self, table: Table):
     """Read the full current contents of ``table`` over Flight.

     The table's ticket is redeemed with a DoGet call that sends the
     session's gRPC metadata as headers. The whole stream is then read with
     ``read_all()``.

     Raises:
         DHError: wrapping any exception raised while building the call,
             issuing DoGet, or reading the stream.
     """
     try:
         call_options = paflight.FlightCallOptions(
             headers=self.session.grpc_metadata)
         ticket = paflight.Ticket(table.ticket.ticket)
         stream = self._flight_client.do_get(ticket, options=call_options)
         return stream.read_all()
     except Exception as e:
         raise DHError("failed to take a snapshot of the table.") from e
예제 #13
0
def test_flight_domain_socket():
    """Run do_get for two tickets over a Unix domain socket."""
    cases = ((b'ints', simple_ints_table), (b'dicts', simple_dicts_table))
    with tempfile.NamedTemporaryFile() as sock:
        # Closing the temp file deletes it (delete=True by default), which
        # leaves its unique path free for the server's socket.
        sock.close()
        location = flight.Location.for_grpc_unix(sock.name)
        with ConstantFlightServer(location=location):
            conn = FlightClient(location)

            for ticket_bytes, make_expected in cases:
                reader = conn.do_get(flight.Ticket(ticket_bytes))
                expected = make_expected()
                assert reader.schema.equals(expected.schema)
                assert reader.read_all().equals(expected)
예제 #14
0
def test_tls_override_hostname():
    """Overriding the TLS hostname with a bogus name must make do_get fail."""
    certs = example_tls_certs()
    server_certs = certs["certificates"]
    root_cert = certs["root_cert"]

    with ConstantFlightServer(tls_certificates=server_certs) as s:
        conn = flight.connect(
            ('localhost', s.port),
            tls_root_certs=root_cert,
            override_hostname="fakehostname",
        )
        ticket = flight.Ticket(b'ints')
        with pytest.raises(flight.FlightUnavailableError):
            conn.do_get(ticket)
예제 #15
0
def test_tls_do_get():
    """Perform a plain do_get over a TLS-secured connection."""
    expected = simple_ints_table()
    certs = example_tls_certs()
    server_certs, root_cert = certs["certificates"], certs["root_cert"]

    with ConstantFlightServer(tls_certificates=server_certs) as s:
        conn = FlightClient(('localhost', s.port), tls_root_certs=root_cert)
        received = conn.do_get(flight.Ticket(b'ints')).read_all()
        assert received.equals(expected)
예제 #16
0
def test_flight_generator_stream():
    """Upload a table, then download it back and compare.

    Per the original intent, EchoStreamFlightServer (defined elsewhere)
    serves the data back as a GeneratorStream of RecordBatches.
    """
    values = pa.array(range(0, 10 * 1024))
    data = pa.Table.from_arrays([values], names=['a'])

    with EchoStreamFlightServer() as server:
        conn = FlightClient(('localhost', server.port))
        descriptor = flight.FlightDescriptor.for_path('test')
        writer, _ = conn.do_put(descriptor, data.schema)
        writer.write_table(data)
        writer.close()
        echoed = conn.do_get(flight.Ticket(b'')).read_all()
        assert echoed.equals(data)
예제 #17
0
def pclient(location="grpc://35.168.111.94:8815", total=191):
    """Fan the 'molbeam' stream out to Ray tasks and print their results.

    Ray is initialized here. One remote task is submitted per item of the
    stream, and the list of task results is printed.

    Args:
        location: URI of the Flight server. Defaults to the endpoint that was
            previously hard-coded.
        total: expected number of items in the stream, used only as the
            tqdm progress-bar total.
    """
    import ray
    ray.init()

    @ray.remote
    def f(batch):
        # Placeholder task: ignores its input and returns 1.
        return 1

    flight_client = fl.connect(location)
    stream = flight_client.do_get(fl.Ticket('molbeam'))
    futures = [f.remote(chunk.data) for chunk in tqdm(stream, total=total)]
    print(ray.get(futures))
예제 #18
0
def test_tls_do_get():
    """Perform a plain do_get over a TLS-secured connection."""
    expected = simple_ints_table()
    certs = example_tls_certs()
    root_cert = certs["root_cert"]

    server_ctx = flight_server(
        ConstantFlightServer,
        tls_certificates=certs["certificates"],
        connect_args=dict(tls_root_certs=root_cert),
    )
    with server_ctx as server_location:
        conn = flight.FlightClient.connect(server_location,
                                           tls_root_certs=root_cert)
        received = conn.do_get(flight.Ticket(b'ints')).read_all()
        assert received.equals(expected)
예제 #19
0
def test_tls_override_hostname():
    """Overriding the TLS hostname with a bogus name must make do_get fail."""
    certs = example_tls_certs()
    root_cert = certs["root_cert"]

    server_ctx = flight_server(
        ConstantFlightServer,
        tls_certificates=certs["certificates"],
        connect_args=dict(tls_root_certs=root_cert),
    )
    with server_ctx as server_location:
        conn = flight.FlightClient.connect(
            server_location,
            tls_root_certs=root_cert,
            override_hostname="fakehostname",
        )
        with pytest.raises(flight.FlightUnavailableError):
            conn.do_get(flight.Ticket(b'ints'))
예제 #20
0
def test_tls_fails():
    """A client that cannot verify the server certificate must not connect."""
    certs = example_tls_certs()

    with ConstantFlightServer(tls_certificates=certs["certificates"]) as s:
        # The client is given no tls_root_certs, so certificate verification
        # is expected to fail. Slow: gRPC retries a few times first.
        uri = "grpc+tls://localhost:" + str(s.port)
        conn = FlightClient(uri)

        # gRPC's error text varies by version; only the error type is checked.
        with pytest.raises(flight.FlightUnavailableError):
            conn.do_get(flight.Ticket(b'ints')).read_all()
예제 #21
0
파일: test_flight.py 프로젝트: vikram/arrow
def test_mtls():
    """Exercise mutual TLS: the server requires and verifies a client cert."""
    certs = example_tls_certs()
    expected = simple_ints_table()
    identity = certs["certificates"][0]
    root_cert = certs["root_cert"]

    with ConstantFlightServer(tls_certificates=[identity],
                              verify_client=True,
                              root_certificates=root_cert) as s:
        conn = FlightClient(('localhost', s.port),
                            tls_root_certs=root_cert,
                            cert_chain=identity.cert,
                            private_key=identity.key)
        received = conn.do_get(flight.Ticket(b'ints')).read_all()
        assert received.equals(expected)
예제 #22
0
def test_tls_fails():
    """A client that cannot verify the server certificate must not connect."""
    certs = example_tls_certs()

    server_ctx = flight_server(
        ConstantFlightServer,
        tls_certificates=certs["certificates"],
        connect_args=dict(tls_root_certs=certs["root_cert"]),
    )
    with server_ctx as server_location:
        # Connect without tls_root_certs so certificate verification fails.
        # Slow: gRPC retries a few times before giving up.
        conn = flight.FlightClient.connect(server_location)
        # gRPC's error text varies by version; only the error type is checked.
        with pytest.raises(flight.FlightUnavailableError):
            conn.do_get(flight.Ticket(b'ints'))
예제 #23
0
def test_flight_large_message():
    """Send and receive a message larger than gRPC's default limit.

    See ARROW-4421: by default, gRPC won't allow us to send messages >
    4MiB in size.
    """
    row_count = 10 * 1024 * 1024
    data = pa.Table.from_arrays([pa.array(range(0, row_count))],
                                names=['a'])

    with EchoFlightServer(expected_schema=data.schema) as server:
        conn = FlightClient(('localhost', server.port))
        descriptor = flight.FlightDescriptor.for_path('test')
        writer, _ = conn.do_put(descriptor, data.schema)
        # A max chunk size equal to the row count writes one giant batch.
        writer.write_table(data, row_count)
        writer.close()
        echoed = conn.do_get(flight.Ticket(b'')).read_all()
        assert echoed.equals(data)
예제 #24
0
def test_middleware_mapping():
    """Check that client and server middleware record each RPC in order."""
    server_middleware = RecordingServerMiddlewareFactory()
    client_middleware = RecordingClientMiddlewareFactory()
    with FlightServerBase(middleware={"test": server_middleware}) as server:
        client = FlightClient(
            ('localhost', server.port),
            middleware=[client_middleware]
        )

        descriptor = flight.FlightDescriptor.for_command(b"")

        def put_and_close():
            writer, _ = client.do_put(descriptor, pa.schema([]))
            writer.close()

        def exchange_and_close():
            writer, _ = client.do_exchange(descriptor)
            writer.close()

        # One call per RPC, in the same order as `expected` below. Each is
        # expected to raise NotImplementedError against FlightServerBase.
        rpc_calls = [
            lambda: list(client.list_flights()),
            lambda: client.get_flight_info(descriptor),
            lambda: client.get_schema(descriptor),
            lambda: client.do_get(flight.Ticket(b"")),
            put_and_close,
            lambda: list(client.do_action(flight.Action(b"", b""))),
            lambda: list(client.list_actions()),
            exchange_and_close,
        ]
        for rpc in rpc_calls:
            with pytest.raises(NotImplementedError):
                rpc()

        expected = [
            flight.FlightMethod.LIST_FLIGHTS,
            flight.FlightMethod.GET_FLIGHT_INFO,
            flight.FlightMethod.GET_SCHEMA,
            flight.FlightMethod.DO_GET,
            flight.FlightMethod.DO_PUT,
            flight.FlightMethod.DO_ACTION,
            flight.FlightMethod.LIST_ACTIONS,
            flight.FlightMethod.DO_EXCHANGE,
        ]
        assert server_middleware.methods == expected
        assert client_middleware.methods == expected
예제 #25
0
def test_flight_do_get_metadata():
    """Check streamed batches and the per-batch index carried as metadata."""
    expected = pa.Table.from_arrays([pa.array([-10, -5, 0, 5, 10])],
                                    names=['a'])

    received = []
    with MetadataFlightServer() as server:
        conn = FlightClient(('localhost', server.port))
        reader = conn.do_get(flight.Ticket(b''))
        expected_index = 0
        while True:
            try:
                batch, metadata = reader.read_chunk()
            except StopIteration:
                break
            received.append(batch)
            # Each batch's metadata is a little-endian int32 that must match
            # the batch's position in the stream.
            (server_index,) = struct.unpack('<i', metadata.to_pybytes())
            assert server_index == expected_index
            expected_index += 1
        assert pa.Table.from_batches(received).equals(expected)
예제 #26
0
def test_timeout_passes():
    """A 5-second call timeout must not fire on a fast do_get."""
    call_options = flight.FlightCallOptions(timeout=5.0)
    with ConstantFlightServer() as server:
        conn = FlightClient(('localhost', server.port))
        stream = conn.do_get(flight.Ticket(b'ints'), options=call_options)
        stream.read_all()
예제 #27
0
def get_by_ticket(args, client):
    """Fetch the stream whose ticket is ``args.name`` and print it."""
    ticket = fl.Ticket(args.name)
    print_response(client.do_get(ticket).read_all())
예제 #28
0
def test_flight_invalid_generator_stream():
    """Reading a stream with mismatched schemas must raise an Arrow error."""
    with flight_server(InvalidStreamFlightServer) as location:
        conn = flight.FlightClient.connect(location)
        with pytest.raises(pa.ArrowException):
            reader = conn.do_get(flight.Ticket(b''))
            reader.read_all()
예제 #29
0
def test_flight_invalid_generator_stream():
    """Reading a stream with mismatched schemas must raise an Arrow error."""
    with InvalidStreamFlightServer() as server:
        conn = FlightClient(('localhost', server.port))
        with pytest.raises(pa.ArrowException):
            reader = conn.do_get(flight.Ticket(b''))
            reader.read_all()
예제 #30
0
def test_timeout_passes():
    """A 5-second call timeout must not fire on a fast do_get."""
    call_options = flight.FlightCallOptions(timeout=5.0)
    with flight_server(ConstantFlightServer) as location:
        conn = flight.FlightClient.connect(location)
        stream = conn.do_get(flight.Ticket(b'ints'), options=call_options)
        stream.read_all()