def test_roundtrip_types():
    """Check that Flight wire types survive serialize() -> deserialize()."""
    ticket = flight.Ticket("foo")
    command_desc = flight.FlightDescriptor.for_command("test")
    path_desc = flight.FlightDescriptor.for_path("a", "b", "test.arrow")

    # These types are compared whole: the decoded object must equal
    # the object that was serialized.
    for cls, original in (
        (flight.Ticket, ticket),
        (flight.FlightDescriptor, command_desc),
        (flight.FlightDescriptor, path_desc),
    ):
        assert original == cls.deserialize(original.serialize())

    schema = pa.schema([('a', pa.int32())])
    endpoints = [
        # One endpoint located by a plain URI string ...
        flight.FlightEndpoint(b'', ['grpc://test']),
        # ... and one located by a Location object.
        flight.FlightEndpoint(
            b'', [flight.Location.for_grpc_tcp('localhost', 5005)]
        ),
    ]
    info = flight.FlightInfo(schema, path_desc, endpoints, -1, -1)
    decoded = flight.FlightInfo.deserialize(info.serialize())

    # FlightInfo is checked attribute by attribute.
    for field in (
        "schema",
        "descriptor",
        "total_bytes",
        "total_records",
        "endpoints",
    ):
        assert getattr(info, field) == getattr(decoded, field)
def list_flights(self, context, criteria):
    """Yield a single FlightInfo, then fail with FlightInternalError.

    The yielded info has an empty schema, the path descriptor '/foo',
    no endpoints, and -1 for both totals. The exception is raised only
    once the consumer asks for a second item.
    """
    empty_schema = pa.schema([])
    descriptor = flight.FlightDescriptor.for_path('/foo')
    yield flight.FlightInfo(empty_schema, descriptor, [], -1, -1)
    raise flight.FlightInternalError("foo")
def list_flights(self, context, criteria):
    """Yield one FlightInfo when ``criteria`` equals ``self.CRITERIA``.

    Any other criteria produces an empty listing. The yielded info has
    an empty schema, the path descriptor '/foo', no endpoints, and -1
    for both totals.
    """
    criteria_match = criteria == self.CRITERIA
    if not criteria_match:
        return

    empty_schema = pa.schema([])
    descriptor = flight.FlightDescriptor.for_path('/foo')
    yield flight.FlightInfo(empty_schema, descriptor, [], -1, -1)
def get_flight_info(self, context, descriptor):
    """Return the FlightInfo for the table named by ``descriptor.path``.

    The path components (bytes) are concatenated into a single ticket
    name, which is looked up in ``self.tables``.

    Returns:
        A FlightInfo carrying the table's schema, the caller's
        descriptor, one endpoint whose ticket is the ticket name, the
        table's row count, and -1 for the byte total.

    Raises:
        KeyError: if no table is registered under the ticket name.
    """
    ticket_name = b''.join(descriptor.path)
    if ticket_name not in self.tables:
        raise KeyError("Unknown ticket name: {}".format(ticket_name))

    table = self.tables[ticket_name]
    endpoints = [
        fl.FlightEndpoint(ticket_name, ["grpc://0.0.0.0:8815"])
    ]
    # The byte size is not computed here, so report it as unknown (-1).
    # Reporting 0 would claim the flight contains no data at all.
    return fl.FlightInfo(
        table.schema, descriptor, endpoints, table.num_rows, -1
    )
def get_flight_info(self, context, descriptor):
    """Return a fixed FlightInfo that echoes back ``descriptor``.

    The info always has a schema with a single int32 column 'a', two
    endpoints with empty tickets (one located by the URI string
    'grpc://test', one by a Location for localhost:5005), and -1 for
    both totals.
    """
    schema = pa.schema([('a', pa.int32())])
    uri_endpoint = flight.FlightEndpoint(b'', ['grpc://test'])
    location_endpoint = flight.FlightEndpoint(
        b'', [flight.Location.for_grpc_tcp('localhost', 5005)]
    )
    return flight.FlightInfo(
        schema,
        descriptor,
        [uri_endpoint, location_endpoint],
        -1,
        -1,
    )
def get_flight_info(self, context, descriptor):
    """Build the FlightInfo for the asset named in ``descriptor.command``.

    The command is parsed into an AFMCommand, and the asset and worker
    list are read from the config file at ``self.config_path``. For a
    'flight' asset the schema and endpoints come from the asset's own
    flight info; otherwise ``self._infer_schema`` supplies the schema
    and the data files. The schema is narrowed to ``cmd.columns`` when
    columns were requested, then passed through ``transform_schema``
    with the asset's actions. One ticket is created per upstream
    endpoint or per data file.

    Returns:
        A FlightInfo with the final schema, the caller's descriptor,
        the endpoints from ``self._get_endpoints``, and -1 for both
        totals.
    """
    cmd = AFMCommand(descriptor.command)
    log_fields = {
        'command': descriptor.command,
        DataSetID: cmd.asset_name,
    }
    logger.info('getting flight information', extra=log_fields)

    # NOTE(review): assumes only the config lookups run inside the
    # Config context — confirm against the original layout.
    with Config(self.config_path) as config:
        asset = asset_from_config(config, cmd.asset_name)
        workers = workers_from_config(config.workers)

    if asset.connection_type == 'flight':
        upstream_info = asset.flight.get_flight_info()
        schema = upstream_info.schema
    else:
        # Infer schema
        schema, data_files = self._infer_schema(asset)

    if cmd.columns:
        schema = self._filter_columns(schema, cmd.columns)
    schema = transform_schema(asset.actions, schema)

    # Build endpoint to this server
    locations = self._get_locations(workers)
    if asset.connection_type == 'flight':
        # One ticket per upstream endpoint, carrying its decoded ticket.
        tickets = [
            AFMTicket(
                cmd.asset_name,
                schema.names,
                upstream_endpoint.ticket.ticket.decode(),
            )
            for upstream_endpoint in upstream_info.endpoints
        ]
    else:
        # One ticket per data file.
        tickets = [
            AFMTicket(cmd.asset_name, schema.names, partition_path=data_file)
            for data_file in data_files
        ]

    endpoints = self._get_endpoints(tickets, locations)
    return fl.FlightInfo(schema, descriptor, endpoints, -1, -1)
def get_flight_info(self, context, descriptor):
    """Build the FlightInfo for the asset named in ``descriptor.command``.

    The command is parsed into an AFMCommand, and the asset is read
    from the config file at ``self.config_path``. The schema comes from
    ``self._infer_schema``, is narrowed to ``cmd.columns`` when columns
    were requested, and is then passed through ``transform_schema``
    with the asset's actions.

    Returns:
        A FlightInfo with the final schema, the caller's descriptor, a
        single endpoint whose ticket is the JSON form of an AFMTicket,
        and -1 for both totals. The endpoint lists one location,
        ``grpc://<MY_POD_IP>:<self.port>``, when the MY_POD_IP
        environment variable is set, and no locations otherwise.
    """
    cmd = AFMCommand(descriptor.command)

    # NOTE(review): assumes only the asset lookup runs inside the
    # Config context — confirm against the original layout.
    with Config(self.config_path) as config:
        asset = asset_from_config(config, cmd.asset_name)

    # Infer schema
    schema = self._infer_schema(asset)
    if cmd.columns:
        schema = self._filter_columns(schema, cmd.columns)
    schema = transform_schema(asset.actions, schema)

    # Build endpoint to this server
    ticket = AFMTicket(cmd.asset_name, schema.names)
    locations = []
    local_address = os.getenv("MY_POD_IP")
    if local_address:
        # Must be append: `locations += "<str>"` extends the list with
        # the URI's individual characters instead of adding one URI.
        locations.append("grpc://{}:{}".format(local_address, self.port))

    endpoints = [fl.FlightEndpoint(ticket.toJSON(), locations)]
    return fl.FlightInfo(schema, descriptor, endpoints, -1, -1)