def do_get(self, context, ticket: fl.Ticket):
    """Stream the dataset described by ``ticket`` back to the client.

    The ticket payload is decoded with ``AFMTicket.fromJSON`` and the asset
    it names is resolved from the configuration at ``self.config_path``.

    * Assets whose ``connection_type`` is ``"flight"`` are fetched by
      forwarding ``context`` and the unmodified ``ticket`` to
      ``asset.flight.do_get``. When the ticket lists at least one column a
      ``FilterColumns`` action is added to the asset.
    * Every other asset is read through ``self._read_asset`` with the
      ticket's column list.

    In both cases the asset's actions are applied to the schema and to the
    batches before they are wrapped in a ``fl.GeneratorStream``.

    Args:
        context: Call context; only forwarded to ``asset.flight.do_get``.
        ticket: Flight ticket whose ``ticket`` attribute holds the
            serialized ``AFMTicket``.

    Returns:
        ``fl.GeneratorStream`` built from the transformed schema and batches.

    Raises:
        ValueError: If the decoded ticket has ``columns`` set to ``None``.
    """
    request: AFMTicket = AFMTicket.fromJSON(ticket.ticket)
    wanted_columns = request.columns
    # An empty list is accepted; only a missing column list is rejected.
    if wanted_columns is None:
        raise ValueError("Columns must be specified in ticket")

    log_fields = {
        'ticket': ticket.ticket,
        DataSetID: request.asset_name,
    }
    logger.info('retrieving dataset', extra=log_fields)

    with Config(self.config_path) as config:
        asset = asset_from_config(
            config,
            request.asset_name,
            partition_path=request.partition_path)

        if asset.connection_type != "flight":
            schema, batches = self._read_asset(asset, wanted_columns)
        else:
            schema, batches = asset.flight.do_get(context, ticket)
            if wanted_columns:
                # The column selection is expressed as one more asset action,
                # so it takes effect in the transform calls below.
                column_filter = actions.FilterColumns(
                    columns=wanted_columns,
                    description="filter columns",
                    options=None)
                asset.add_action(column_filter)

        schema = transform_schema(asset.actions, schema)
        batches = transform(asset.actions, batches)
        return fl.GeneratorStream(schema, batches)
def authenticate(self, outgoing, incoming):
    """Check a basic-auth handshake against ``self.creds``.

    One message is read from ``incoming`` and deserialized as
    ``flight.BasicAuth``. The decoded username is logged, looked up in
    ``self.creds`` and the decoded password is compared with the stored
    value. On success the raw username bytes are written to ``outgoing``.

    Args:
        outgoing: Handshake channel the username is written to on success.
        incoming: Handshake channel the serialized credentials are read from.

    Raises:
        flight.FlightUnauthenticatedError: With message ``"unknown user"``
            when the username is not present in ``self.creds`` (or no
            credentials are configured), or ``"wrong password"`` when the
            password does not match.
    """
    # Local import keeps this method self-contained.
    import hmac

    auth = flight.BasicAuth.deserialize(incoming.read())
    # Decode once and reuse for logging, lookup and comparison.
    username = auth.username.decode()
    logger.info('basic authentication', extra={'username': username})

    # ``self.creds`` may be None or empty when no credentials were
    # configured; reject the user instead of failing with a TypeError.
    if not self.creds or username not in self.creds:
        raise flight.FlightUnauthenticatedError("unknown user")

    expected = self.creds[username]
    supplied = auth.password.decode()
    # Constant-time comparison so response timing does not reveal how much of
    # the password matched. A stored value that is not a string can never
    # equal the decoded password, so it is rejected as before.
    password_ok = isinstance(expected, str) and hmac.compare_digest(
        expected.encode(), supplied.encode())
    if not password_ok:
        raise flight.FlightUnauthenticatedError("wrong password")

    outgoing.write(auth.username)
def __init__(self, auth_config):
    """Select the authentication handler described by ``auth_config``.

    * A falsy ``auth_config`` selects ``NoopAuthHandler``.
    * An ``auth_config`` containing the key ``'basic'`` selects
      ``HttpBasicServerAuthHandler``, constructed with the ``'credentials'``
      entry of that section (``None`` when the entry is absent).

    The chosen handler is stored in ``self.auth_handler``.

    Args:
        auth_config: Authentication section of the configuration; indexed
            like a mapping.

    Raises:
        NotImplementedError: If ``auth_config`` is non-empty but has no
            ``'basic'`` key.
    """
    super().__init__()

    if not auth_config:
        logger.info(
            "no authentication configuration. Using NoopAuthHandler")
        self.auth_handler = NoopAuthHandler()
        return

    if 'basic' not in auth_config:
        # NOTE(review): the message is misspelled ("authenticaion"); it is
        # kept as-is because it is runtime text.
        raise NotImplementedError("Unknown authenticaion type")

    logger.info(
        "basic authentication configuration. Using HttpBasicServerAuthHandler"
    )
    basic_section = auth_config['basic']
    credentials = basic_section.get('credentials')
    self.auth_handler = HttpBasicServerAuthHandler(credentials)
def get_flight_info(self, context, descriptor):
    """Describe how the dataset named in ``descriptor`` can be retrieved.

    ``descriptor.command`` is parsed into an ``AFMCommand``. The asset and
    the worker list are resolved from the configuration at
    ``self.config_path``.

    The schema comes from the upstream ``asset.flight.get_flight_info()``
    when the asset's ``connection_type`` is ``'flight'`` and from
    ``self._infer_schema`` otherwise. It is narrowed to ``cmd.columns`` when
    those are given and then passed through ``transform_schema`` with the
    asset's actions.

    One ``AFMTicket`` is produced per upstream endpoint (flight assets) or
    per data file returned by ``self._infer_schema`` (all other assets); each
    ticket carries the names of the final schema as its column list.

    Args:
        context: Call context; not used by this method.
        descriptor: Flight descriptor whose ``command`` identifies the asset.

    Returns:
        ``fl.FlightInfo`` built from the final schema, the original
        descriptor and the generated endpoints. The two trailing arguments
        are passed as ``-1``.
    """
    cmd = AFMCommand(descriptor.command)
    log_fields = {
        'command': descriptor.command,
        DataSetID: cmd.asset_name,
    }
    logger.info('getting flight information', extra=log_fields)

    with Config(self.config_path) as config:
        asset = asset_from_config(config, cmd.asset_name)
        workers = workers_from_config(config.workers)

        is_passthrough = asset.connection_type == 'flight'
        if is_passthrough:
            upstream_info = asset.flight.get_flight_info()
            schema = upstream_info.schema
        else:
            # No upstream server to ask: infer the schema from the asset.
            schema, data_files = self._infer_schema(asset)

        if cmd.columns:
            schema = self._filter_columns(schema, cmd.columns)
        schema = transform_schema(asset.actions, schema)

        # Endpoints point at this server's locations.
        locations = self._get_locations(workers)

        if is_passthrough:
            tickets = [
                AFMTicket(cmd.asset_name, schema.names,
                          upstream_endpoint.ticket.ticket.decode())
                for upstream_endpoint in upstream_info.endpoints
            ]
        else:
            tickets = [
                AFMTicket(cmd.asset_name, schema.names,
                          partition_path=data_file)
                for data_file in data_files
            ]

        endpoints = self._get_endpoints(tickets, locations)
        return fl.FlightInfo(schema, descriptor, endpoints, -1, -1)
def do_put(self, context, descriptor, reader, writer):
    """Write the uploaded stream to the asset named in ``descriptor``.

    ``descriptor.command`` is parsed as JSON and its ``'asset'`` entry names
    the target asset, which is resolved from the configuration at
    ``self.config_path`` and handed to ``self._write_asset`` together with
    ``reader``.

    Args:
        context: Call context; not used by this method.
        descriptor: Flight descriptor whose ``command`` is a JSON document
            with an ``'asset'`` key.
        reader: Source of the uploaded data, passed on to
            ``self._write_asset``.
        writer: Not used by this method.
    """
    asset_name = json.loads(descriptor.command)['asset']
    logger.info('writing dataset', extra={DataSetID: asset_name})

    with Config(self.config_path) as config:
        self._write_asset(asset_from_config(config, asset_name), reader)
# SPDX-License-Identifier: Apache-2.0
#
import argparse

from afm.server import AFMFlightServer
from afm.logging import logger


def _parse_cli():
    """Parse the command-line options of the AFM Flight Server.

    Returns:
        The ``argparse`` namespace with ``port`` (int, default 8080),
        ``config`` (str, default ``/etc/conf/conf.yaml``) and ``loglevel``
        (one of the listed choices, default ``warning``).
    """
    cli = argparse.ArgumentParser(description='AFM Flight Server')
    cli.add_argument(
        '-p', '--port',
        type=int,
        default=8080,
        help='Listening port')
    cli.add_argument(
        '-c', '--config',
        type=str,
        default='/etc/conf/conf.yaml',
        help='Path to config file')
    cli.add_argument(
        '-l', '--loglevel',
        type=str,
        default='warning',
        help='logging level',
        choices=['trace', 'info', 'debug', 'warning', 'error', 'critical'])
    return cli.parse_args()


if __name__ == '__main__':
    options = _parse_cli()
    # The log level is handed to the server in upper case.
    server = AFMFlightServer(options.config, options.port,
                             options.loglevel.upper())
    # Logged before serve() is called.
    logger.info('AFMFlightServer started')
    server.serve()