예제 #1
0
def test_roundtrip_types():
    """Check that Flight's serializable value types survive a round-trip.

    Each object is serialized with ``serialize()`` and rebuilt with the
    matching ``deserialize()`` classmethod; the rebuilt object must compare
    equal to the original. Covered: ``Ticket``, ``FlightDescriptor`` built by
    both ``for_command`` and ``for_path``, and a ``FlightInfo`` holding two
    endpoints (one whose location is a URI string, one whose location comes
    from ``Location.for_grpc_tcp``).
    """
    original_ticket = flight.Ticket("foo")
    assert original_ticket == flight.Ticket.deserialize(
        original_ticket.serialize())

    command_descriptor = flight.FlightDescriptor.for_command("test")
    path_descriptor = flight.FlightDescriptor.for_path("a", "b", "test.arrow")
    for candidate in (command_descriptor, path_descriptor):
        restored = flight.FlightDescriptor.deserialize(candidate.serialize())
        assert candidate == restored

    endpoint_list = [
        flight.FlightEndpoint(b'', ['grpc://test']),
        flight.FlightEndpoint(
            b'', [flight.Location.for_grpc_tcp('localhost', 5005)]),
    ]
    int_schema = pa.schema([('a', pa.int32())])
    # The path descriptor is the one embedded in the FlightInfo.
    original_info = flight.FlightInfo(
        int_schema, path_descriptor, endpoint_list, -1, -1)
    restored_info = flight.FlightInfo.deserialize(original_info.serialize())

    compared_fields = (
        "schema", "descriptor", "total_bytes", "total_records", "endpoints",
    )
    for field_name in compared_fields:
        assert getattr(original_info, field_name) == getattr(
            restored_info, field_name)
예제 #2
0
 def _get_endpoints(self, tickets, locations):
     """Wrap each ticket in its own ``fl.FlightEndpoint``.

     If ``locations`` is non-empty, locations are handed out round-robin:
     the k-th ticket gets ``locations[k % len(locations)]`` as its only
     location. If ``locations`` is empty (falsy), every endpoint gets an
     empty location list.

     Each ticket is converted with its ``toJSON()`` method and that value is
     passed to ``FlightEndpoint`` as the ticket payload (its exact type is
     whatever ``toJSON`` returns — presumably str/bytes; verify).

     Returns:
         list: one endpoint per ticket, in ticket order.
     """
     if not locations:
         return [fl.FlightEndpoint(ticket.toJSON(), []) for ticket in tickets]
     return [
         fl.FlightEndpoint(
             ticket.toJSON(), [locations[index % len(locations)]])
         for index, ticket in enumerate(tickets)
     ]
예제 #3
0
 def get_flight_info(self, context, descriptor):
     """Return a fixed FlightInfo that echoes back ``descriptor``.

     The info describes a single int32 column named ``a`` and lists two
     endpoints, both with an empty ticket: one located by the URI string
     ``'grpc://test'`` and one by ``Location.for_grpc_tcp('localhost', 5005)``.
     Record and byte totals are both -1 (presumably Flight's "unknown"
     sentinel). ``context`` is not used.
     """
     string_endpoint = flight.FlightEndpoint(b'', ['grpc://test'])
     tcp_location = flight.Location.for_grpc_tcp('localhost', 5005)
     tcp_endpoint = flight.FlightEndpoint(b'', [tcp_location])
     column_schema = pa.schema([('a', pa.int32())])
     return flight.FlightInfo(
         column_schema,
         descriptor,
         [string_endpoint, tcp_endpoint],
         -1,
         -1,
     )
    def get_flight_info(self, context, descriptor):
        """Describe the in-memory table stored under ``descriptor``'s path.

        The lookup key is the descriptor's path components concatenated with
        no separator (``b''.join``), looked up in ``self.tables``.

        Returns:
            An ``fl.FlightInfo`` with the table's schema, the request
            descriptor, a single endpoint (ticket = the lookup key, location
            ``grpc://0.0.0.0:8815``), ``table.num_rows`` as the record count,
            and -1 (unknown) as the byte count.

        Raises:
            KeyError: if ``descriptor.path`` is None (a descriptor without a
                path, e.g. a command descriptor) or if no table is stored
                under the key.
        """
        path = descriptor.path
        if path is None:
            # b''.join(None) would raise an unrelated TypeError; report a
            # path-less descriptor the same way as any other missing flight.
            raise KeyError(
                "Descriptor has no path; only path descriptors are supported")
        ticket_name = b''.join(path)
        if ticket_name in self.tables:
            table = self.tables[ticket_name]
            endpoints = [
                fl.FlightEndpoint(ticket_name, ["grpc://0.0.0.0:8815"])
            ]
            # The serialized size is not computed here, so report the byte
            # total as -1 ("unknown") rather than a misleading 0.
            return fl.FlightInfo(table.schema, descriptor, endpoints,
                                 table.num_rows, -1)

        raise KeyError("Unknown ticket name: {}".format(ticket_name))
예제 #5
0
    def get_flight_info(self, context, descriptor):
        """Describe the asset named by the AFM command in ``descriptor``.

        ``descriptor.command`` is parsed into an ``AFMCommand``, and the asset
        it names is loaded from the config at ``self.config_path``. The schema
        is inferred from the asset, narrowed to ``cmd.columns`` when any are
        given, and then passed through ``transform_schema`` with the asset's
        actions.

        Returns:
            An ``fl.FlightInfo`` with the transformed schema, the request
            descriptor, and a single endpoint. The endpoint's ticket is an
            ``AFMTicket`` (asset name + final column names) converted with
            ``toJSON()``. If the ``MY_POD_IP`` environment variable is set, the
            endpoint's one location is ``grpc://<MY_POD_IP>:<self.port>``;
            otherwise it has no locations. Record and byte totals are -1.
        """
        cmd = AFMCommand(descriptor.command)

        with Config(self.config_path) as config:
            asset = asset_from_config(config, cmd.asset_name)

        # Infer schema, then apply optional column filtering and transforms.
        schema = self._infer_schema(asset)
        if cmd.columns:
            schema = self._filter_columns(schema, cmd.columns)
        schema = transform_schema(asset.actions, schema)

        # Build endpoint to this server
        ticket = AFMTicket(cmd.asset_name, schema.names)
        locations = []
        local_address = os.getenv("MY_POD_IP")
        if local_address:
            # Must be append: `list += str` extends the list with the URI's
            # individual characters instead of adding one location string.
            locations.append("grpc://{}:{}".format(local_address, self.port))
        endpoints = [fl.FlightEndpoint(ticket.toJSON(), locations)]

        return fl.FlightInfo(schema, descriptor, endpoints, -1, -1)