def handle_query(self, query: str) -> WhoisQueryResponse:
    """Process a single query. Always returns a WhoisQueryResponse object."""
    # Per-query state: a fresh database handler, and the -K / -T
    # restrictions are cleared (they do not persist across queries).
    self.database_handler = DatabaseHandler()
    self.key_fields_only = False
    self.object_classes = []

    # Queries starting with '!' are IRRD-style (mark stripped before
    # dispatch); anything else is treated as a RIPE-style query.
    if query.startswith('!'):
        handler, handler_arg = self.handle_irrd_command, query[1:]
        mode = WhoisQueryResponseMode.IRRD
    else:
        handler, handler_arg = self.handle_ripe_command, query
        mode = WhoisQueryResponseMode.RIPE

    try:
        return handler(handler_arg)
    except WhoisQueryParserException as exc:
        logger.info(
            f'{self.peer}: encountered parsing error while parsing query {query}: {exc}'
        )
        return WhoisQueryResponse(
            response_type=WhoisQueryResponseType.ERROR,
            mode=mode,
            result=str(exc))
    finally:
        # The handler is closed whether the query succeeded or failed.
        self.database_handler.close()
def run(self, database_handler: DatabaseHandler):
    """
    Run a full import for this source: download the dump file(s) and the
    serial, delete all current objects for the source, and re-import.

    Fixes over the previous revision:
    - settings are validated *before* any data is deleted, so an
      unconfigured source no longer wipes its existing objects;
    - a missing ``dump_source`` setting no longer raises AttributeError
      from calling ``.split()`` on ``None``;
    - all files are retrieved before the destructive delete, so a failed
      download leaves the existing data untouched.
    """
    dump_sources_setting = get_setting(f'sources.{self.source}.dump_source')
    dump_serial_source = get_setting(
        f'sources.{self.source}.dump_serial_source')
    if not dump_sources_setting or not dump_serial_source:
        logger.debug(
            f'Skipping full import for {self.source}, dump_source or dump_serial_source not set.'
        )
        return
    dump_sources = dump_sources_setting.split(',')

    logger.info(
        f'Running full import of {self.source} from {dump_sources}, serial from {dump_serial_source}'
    )

    # Retrieve everything first; only then start destructive operations.
    dump_serial = int(
        self._retrieve_file(dump_serial_source, use_tempfile=False))
    dump_filenames = [
        self._retrieve_file(dump_source, use_tempfile=True)
        for dump_source in dump_sources
    ]

    database_handler.delete_all_rpsl_objects(self.source)
    database_handler.disable_journaling()
    for dump_filename in dump_filenames:
        # NOTE(review): MirrorFullImportParser appears to perform the
        # import as a side effect of construction (instance is discarded),
        # matching the original usage — confirm against its definition.
        MirrorFullImportParser(source=self.source,
                               filename=dump_filename,
                               serial=dump_serial,
                               database_handler=database_handler)
        os.unlink(dump_filename)
def main(self, filename, strict_validation, database, show_info=True):
    """
    Parse all RPSL objects from ``filename`` ('-' reads stdin), print a
    summary, and — when ``database`` is set — upsert valid objects through
    a non-journaling DatabaseHandler.

    Fix: the input file is now always closed (previously leaked); stdin
    is deliberately never closed.
    """
    self.show_info = show_info
    if database:
        self.database_handler = DatabaseHandler(journaling_enabled=False)

    if filename == '-':  # pragma: no cover
        f = sys.stdin
    else:
        f = open(filename, encoding="utf-8", errors='backslashreplace')
    try:
        for paragraph in split_paragraphs_rpsl(f):
            self.parse_object(paragraph, strict_validation)
    finally:
        # Close only files we opened ourselves; never close stdin.
        if f is not sys.stdin:
            f.close()

    print(
        f"Processed {self.obj_parsed} objects, {self.obj_errors} with errors"
    )
    if self.obj_unknown:
        unknown_formatted = ', '.join(self.unknown_object_classes)
        print(
            f"Ignored {self.obj_unknown} objects due to unknown object classes: {unknown_formatted}"
        )

    if self.database_handler:
        self.database_handler.commit()
        self.database_handler.close()
def save(self, database_handler: DatabaseHandler) -> None:
    """Save the update to the database."""
    # Saving is only legal for a fully processed request with a parsed object.
    not_saveable = (self.status != UpdateRequestStatus.PROCESSING
                    or not self.rpsl_obj_new)
    if not_saveable:
        raise ValueError(
            "UpdateRequest can only be saved in status PROCESSING")

    is_deletion = (self.request_type == UpdateRequestType.DELETE
                   and self.rpsl_obj_current is not None)
    if is_deletion:
        # Deletions operate on the object currently in the database.
        database_handler.delete_rpsl_object(self.rpsl_obj_current)
    else:
        database_handler.upsert_rpsl_object(self.rpsl_obj_new)
    self.status = UpdateRequestStatus.SAVED
def __init__(self, object_texts: str, pgp_fingerprint: Optional[str] = None,
             request_meta: Optional[Dict[str, Optional[str]]] = None) -> None:
    """
    Process one or more RPSL object texts as an update request.

    :param object_texts: raw text of the submitted RPSL object(s)
    :param pgp_fingerprint: optional PGP fingerprint of a valid signature
                            on the submission, resolved to a key id
    :param request_meta: optional metadata about the request (shown in reports)

    Fixes over the previous revision: parameters that default to None are
    annotated Optional, and the DatabaseHandler is now closed even when
    processing raises (previously it leaked on error).
    """
    self.database_handler = DatabaseHandler()
    self.request_meta = request_meta if request_meta else {}
    self._pgp_key_id = self._resolve_pgp_key_id(
        pgp_fingerprint) if pgp_fingerprint else None
    try:
        self._handle_object_texts(object_texts)
        # Commit only after all objects were processed successfully.
        self.database_handler.commit()
    finally:
        self.database_handler.close()
def run(self) -> None:
    """
    Run an update for this mirror: a full import when no serial has been
    seen yet (or a reload is forced), otherwise an NRTM stream update.

    Fix: the DatabaseHandler is now closed even when the import or stream
    update raises (previously it leaked on error). The commit stays inside
    the try so a failed run is not committed.
    """
    self.database_handler = DatabaseHandler()
    try:
        serial_newest_seen, force_reload = self._status()
        logger.debug(
            f'Most recent serial seen for {self.source}: {serial_newest_seen}, force_reload: {force_reload}'
        )
        if not serial_newest_seen or force_reload:
            self.full_import_runner.run(database_handler=self.database_handler)
        else:
            self.update_stream_runner.run(
                serial_newest_seen, database_handler=self.database_handler)
        self.database_handler.commit()
    finally:
        self.database_handler.close()
class RPSLParse:
    """
    Parse RPSL objects from a file or stdin, print validation results,
    and optionally store valid objects in the database.

    Fixes over the previous revision:
    - counters and ``unknown_object_classes`` were mutable *class*
      attributes, so state leaked between instances; they are now
      initialised per instance in ``__init__`` (class-level defaults are
      kept for backward compatibility with class-level reads);
    - the input file in ``main`` is now always closed (stdin excepted).
    """
    # Backward-compatible class-level defaults; instances shadow these.
    obj_parsed = 0
    obj_errors = 0
    obj_unknown = 0
    unknown_object_classes: Set[str] = set()
    database_handler = None

    def __init__(self) -> None:
        # Per-instance state, so multiple RPSLParse instances are independent.
        self.obj_parsed = 0          # total objects seen
        self.obj_errors = 0          # objects with parser errors
        self.obj_unknown = 0         # objects of unknown class
        self.unknown_object_classes: Set[str] = set()
        self.database_handler = None
        self.show_info = True

    def main(self, filename, strict_validation, database, show_info=True):
        """
        Parse all objects from ``filename`` ('-' reads stdin), print a
        summary, and — when ``database`` is set — upsert valid objects
        through a non-journaling DatabaseHandler.
        """
        self.show_info = show_info
        if database:
            self.database_handler = DatabaseHandler(journaling_enabled=False)

        if filename == '-':  # pragma: no cover
            f = sys.stdin
        else:
            f = open(filename, encoding="utf-8", errors='backslashreplace')
        try:
            for paragraph in split_paragraphs_rpsl(f):
                self.parse_object(paragraph, strict_validation)
        finally:
            # Close only files we opened ourselves; never close stdin.
            if f is not sys.stdin:
                f.close()

        print(
            f"Processed {self.obj_parsed} objects, {self.obj_errors} with errors"
        )
        if self.obj_unknown:
            unknown_formatted = ', '.join(self.unknown_object_classes)
            print(
                f"Ignored {self.obj_unknown} objects due to unknown object classes: {unknown_formatted}"
            )

        if self.database_handler:
            self.database_handler.commit()
            self.database_handler.close()

    def parse_object(self, rpsl_text, strict_validation):
        """Parse a single RPSL paragraph, update counters, print messages,
        and upsert the object when a database handler is active."""
        try:
            self.obj_parsed += 1
            obj = rpsl_object_from_text(rpsl_text.strip(),
                                        strict_validation=strict_validation)
            # Print details when there are errors, or infos with show_info on.
            if (obj.messages.messages()
                    and self.show_info) or obj.messages.errors():
                if obj.messages.errors():
                    self.obj_errors += 1
                print(rpsl_text.strip())
                print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
                print(obj.messages)
                print("\n=======================================\n")
            if self.database_handler and obj and not obj.messages.errors():
                self.database_handler.upsert_rpsl_object(obj)
        except UnknownRPSLObjectClassException as e:
            self.obj_unknown += 1
            # Exception text has the form "...: <class>"; record the class.
            self.unknown_object_classes.add(str(e).split(":")[1].strip())
        except Exception as e:  # pragma: no cover
            print("=======================================")
            print(rpsl_text)
            print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
            raise e
class MirrorUpdateRunner:
    """
    This MirrorUpdateRunner is the entry point for updating a single
    database mirror, depending on current state.

    If there is no current mirrored data, will call MirrorFullImportRunner
    to run a new import from full dump files. Otherwise, will call
    NRTMUpdateStreamRunner to retrieve new updates from NRTM.
    """
    def __init__(self, source: str) -> None:
        self.source = source
        self.full_import_runner = MirrorFullImportRunner(source)
        self.update_stream_runner = NRTMUpdateStreamRunner(source)

    def run(self) -> None:
        """
        Run the appropriate update (full import or NRTM stream).

        Fix: the DatabaseHandler is now closed even when the update
        raises (previously it leaked on error); a failed run is not
        committed, as the commit stays inside the try.
        """
        self.database_handler = DatabaseHandler()
        try:
            serial_newest_seen, force_reload = self._status()
            logger.debug(
                f'Most recent serial seen for {self.source}: {serial_newest_seen}, force_reload: {force_reload}'
            )
            if not serial_newest_seen or force_reload:
                self.full_import_runner.run(database_handler=self.database_handler)
            else:
                self.update_stream_runner.run(
                    serial_newest_seen, database_handler=self.database_handler)
            self.database_handler.commit()
        finally:
            self.database_handler.close()

    def _status(self) -> Tuple[Optional[int], Optional[bool]]:
        """Return (serial_newest_seen, force_reload) from the status table,
        or (None, None) when this source has no status row yet."""
        query = RPSLDatabaseStatusQuery().source(self.source)
        result = self.database_handler.execute_query(query)
        try:
            status = next(result)
            return status['serial_newest_seen'], status['force_reload']
        except StopIteration:
            return None, None
def save(self, database_handler: DatabaseHandler) -> bool:
    """
    Apply this NRTM operation to the database.

    Returns True when the operation was applied, False when it was
    ignored (unknown object class, filtered class, parser errors, or a
    source mismatch). Error cases are also recorded as mirror errors.
    """
    try:
        obj = rpsl_object_from_text(self.object_text.strip(),
                                    strict_validation=False)
    except UnknownRPSLObjectClassException as exc:
        # Unknown classes are logged and skipped, not treated as fatal.
        logger.warning(
            f'Ignoring NRTM from {self.source} operation {self.serial}/{self.operation.value}: {exc}'
        )
        return False

    # Honour the per-source object class filter, if configured.
    if self.object_class_filter and obj.rpsl_object_class.lower(
    ) not in self.object_class_filter:
        return False

    if obj.messages.errors():
        # Objects with parse errors are skipped; a later corrected NRTM
        # operation for the same object is expected to resolve the gap.
        errors = '; '.join(obj.messages.errors())
        logger.critical(
            f'Parsing errors occurred while processing NRTM from {self.source}, '
            f'operation {self.serial}/{self.operation.value}. This operation is ignored, '
            f'causing potential data inconsistencies. A new operation for this update, without errors, '
            f'will still be processed and cause the inconsistency to be resolved. '
            f'Parser error messages: {errors}; original object text follows:\n{self.object_text}'
        )
        database_handler.record_mirror_error(
            self.source, f'Parsing errors: {obj.messages.errors()}, '
            f'original object text follows:\n{self.object_text}')
        return False

    # Objects whose source attribute disagrees with the stream's source
    # are rejected to avoid cross-source contamination.
    if 'source' in obj.parsed_data and obj.parsed_data['source'].upper(
    ) != self.source:
        msg = (
            f'Incorrect source in NRTM object: stream has source {self.source}, found object with '
            f'source {obj.source()} in operation {self.serial}/{self.operation.value}/{obj.pk()}. '
            f'This operation is ignored, causing potential data inconsistencies.'
        )
        database_handler.record_mirror_error(self.source, msg)
        logger.critical(msg)
        return False

    if self.operation == DatabaseOperation.add_or_update:
        database_handler.upsert_rpsl_object(obj, self.serial)
    elif self.operation == DatabaseOperation.delete:
        database_handler.delete_rpsl_object(obj, self.serial)
    logger.info(
        f'Completed NRTM operation in {self.source}: {self.serial}/{self.operation.value}/{obj.pk()}'
    )
    return True
class WhoisQueryParser:
    """
    Parser for all whois-style queries.

    This parser distinguishes RIPE-style, e.g. "-K 192.0.2.1" or
    "-i mnt-by FOO" from IRRD-style, e.g. "!oFOO".

    Some query flags, particularly -k/!! and -s/!s retain state across
    queries, so a single instance of this object should be created per
    session, with handle_query() being called for each individual query.
    """
    lookup_field_names = lookup_field_names()
    # Set per query in handle_query(); used by all query handlers below.
    database_handler: DatabaseHandler

    def __init__(self, peer: str) -> None:
        self.all_valid_sources = list(get_setting('sources').keys())
        self.sources: List[str] = []          # persists across queries (-s/!s)
        self.object_classes: List[str] = []   # reset per query (-T)
        self.user_agent: Optional[str] = None
        self.multiple_command_mode = False    # set by -k / !!
        self.key_fields_only = False          # reset per query (-K)
        self.peer = peer

    def handle_query(self, query: str) -> WhoisQueryResponse:
        """Process a single query. Always returns a WhoisQueryResponse object."""
        # These flags are reset with every query.
        self.database_handler = DatabaseHandler()
        self.key_fields_only = False
        self.object_classes = []

        if query.startswith('!'):
            try:
                return self.handle_irrd_command(query[1:])
            except WhoisQueryParserException as exc:
                logger.info(
                    f'{self.peer}: encountered parsing error while parsing query {query}: {exc}'
                )
                return WhoisQueryResponse(
                    response_type=WhoisQueryResponseType.ERROR,
                    mode=WhoisQueryResponseMode.IRRD,
                    result=str(exc))
            finally:
                # Runs before the return above: the handler is always closed.
                self.database_handler.close()

        try:
            return self.handle_ripe_command(query)
        except WhoisQueryParserException as exc:
            logger.info(
                f'{self.peer}: encountered parsing error while parsing query {query}: {exc}'
            )
            return WhoisQueryResponse(
                response_type=WhoisQueryResponseType.ERROR,
                mode=WhoisQueryResponseMode.RIPE,
                result=str(exc))
        finally:
            self.database_handler.close()

    def handle_irrd_command(self, full_command: str) -> WhoisQueryResponse:
        """
        Handle an IRRD-style query.

        full_command should not include the first exclamation mark.
        """
        if not full_command:
            raise WhoisQueryParserException(f'Missing IRRD command')
        command = full_command[0].upper()
        parameter = full_command[1:]
        response_type = WhoisQueryResponseType.SUCCESS
        result = None

        if command == '!':
            # !! - keep the connection open for multiple commands.
            self.multiple_command_mode = True
            result = None
        elif command == 'G':
            result = self.handle_irrd_routes_for_origin_v4(parameter)
            if not result:
                response_type = WhoisQueryResponseType.KEY_NOT_FOUND
        elif command == '6':
            result = self.handle_irrd_routes_for_origin_v6(parameter)
            if not result:
                response_type = WhoisQueryResponseType.KEY_NOT_FOUND
        elif command == 'I':
            result = self.handle_irrd_set_members(parameter)
            if not result:
                response_type = WhoisQueryResponseType.KEY_NOT_FOUND
        elif command == 'J':
            result = self.handle_irrd_database_serial_range(parameter)
        elif command == 'M':
            result = self.handle_irrd_exact_key(parameter)
            if not result:
                response_type = WhoisQueryResponseType.KEY_NOT_FOUND
        elif command == 'N':
            self.handle_user_agent(parameter)
        elif command == 'O':
            result = self.handle_inverse_attr_search('mnt-by', parameter)
            if not result:
                response_type = WhoisQueryResponseType.KEY_NOT_FOUND
        elif command == 'R':
            result = self.handle_irrd_route_search(parameter)
            if not result:
                response_type = WhoisQueryResponseType.KEY_NOT_FOUND
        elif command == 'S':
            result = self.handle_irrd_sources_list(parameter)
        elif command == 'V':
            result = self.handle_irrd_version()
        else:
            raise WhoisQueryParserException(f'Unrecognised command: {command}')

        return WhoisQueryResponse(
            response_type=response_type,
            mode=WhoisQueryResponseMode.IRRD,
            result=result,
        )

    def handle_irrd_routes_for_origin_v4(self, origin: str) -> str:
        """!g query - find all originating IPv4 prefixes from an origin, e.g. !gAS65537"""
        return self._routes_for_origin('route', origin)

    def handle_irrd_routes_for_origin_v6(self, origin: str) -> str:
        """!6 query - find all originating IPv6 prefixes from an origin, e.g. !6as65537"""
        return self._routes_for_origin('route6', origin)

    def _routes_for_origin(self, object_class: str, origin: str) -> str:
        """
        Resolve all route(6)s for an origin, returning a space-separated list
        of all originating prefixes, not including duplicates.
        """
        try:
            _, asn = parse_as_number(origin)
        except ValidationError as ve:
            raise WhoisQueryParserException(str(ve))

        query = self._prepare_query().object_classes([object_class]).asn(asn)
        query_result = self.database_handler.execute_query(query)
        prefixes = [r['parsed_data'][object_class] for r in query_result]
        # Deduplicate while preserving the original result order.
        unique_prefixes: List[str] = []
        for prefix in prefixes:
            if prefix not in unique_prefixes:
                unique_prefixes.append(prefix)
        return ' '.join(unique_prefixes)

    def handle_irrd_set_members(self, parameter: str) -> str:
        """
        !i query - find all members of an as-set or route-set, possibly
        recursively. e.g. !iAS-FOO for non-recursive, !iAS-FOO,1 for recursive
        """
        recursive = False
        if parameter.endswith(',1'):
            recursive = True
            parameter = parameter[:-2]

        # Reset source prioritisation; _find_set_members() tracks the
        # source of the root set here.
        self._current_set_priority_source = None
        if not recursive:
            members, leaf_members = self._find_set_members({parameter})
            members.update(leaf_members)
        else:
            members = self._recursive_set_resolve({parameter})
        # The set itself should not appear in its own member list.
        if parameter in members:
            members.remove(parameter)
        return ' '.join(sorted(members))

    def _recursive_set_resolve(self, members: Set[str],
                               sets_seen=None) -> Set[str]:
        """
        Resolve all members of a number of sets, recursively.

        For each set in members, determines whether it has been seen already
        (to prevent infinite recursion), ignores it if already seen, and then
        either adds it directly or adds it to a set that requires further
        resolving.
        """
        if not sets_seen:
            sets_seen = set()

        # All inputs already seen: recursion bottoms out here.
        if all([member in sets_seen for member in members]):
            return set()
        sets_seen.update(members)

        set_members = set()
        sub_members, leaf_members = self._find_set_members(members)
        set_members.update(leaf_members)

        for sub_member in sub_members:
            # Prefixes and AS numbers are terminal values: add them
            # directly, they can not be resolved further.
            try:
                IP(sub_member)
                set_members.add(sub_member)
                continue
            except ValueError:
                pass
            try:
                parse_as_number(sub_member)
                set_members.add(sub_member)
                continue
            except ValueError:
                pass

        # Whatever is neither terminal nor already seen is another set
        # name that needs recursive resolving.
        further_resolving_required = sub_members - set_members - sets_seen
        new_members = self._recursive_set_resolve(further_resolving_required,
                                                  sets_seen)
        set_members.update(new_members)

        return set_members

    def _find_set_members(self,
                          set_names: Set[str]) -> Tuple[Set[str], Set[str]]:
        """
        Find all members of a number of route-sets or as-sets. Includes both
        direct members listed in members attribute, but also members included
        by mbrs-by-ref/member-of.

        Returns a tuple of two sets:
        - members found of the sets included in set_names, both references to
          other sets and direct AS numbers, etc.
        - leaf members that were included in set_names, i.e. names for which no
          further data could be found - for example references to non-existent
          other sets
        """
        members: Set[str] = set()
        sets_already_resolved: Set[str] = set()

        query = self._prepare_query().object_classes(
            ['as-set', 'route-set']).rpsl_pks(set_names)
        if self._current_set_priority_source:
            query.prioritise_source(self._current_set_priority_source)
        query_result = list(self.database_handler.execute_query(query))

        if not query_result:
            # No sub-members are found, and apparently all inputs were leaf members.
            return set(), set_names

        # Track the source of the root object set
        if not self._current_set_priority_source:
            self._current_set_priority_source = query_result[0]['source']

        for result in query_result:
            rpsl_pk = result['rpsl_pk']

            # The same PK may occur in multiple sources, but we are only
            # interested in the first matching object, prioritised to look
            # for the same source as the root object. This priority is part
            # of the query ORDER BY, so basically we only process an RPSL
            # pk once.
            if rpsl_pk in sets_already_resolved:
                continue
            sets_already_resolved.add(rpsl_pk)

            object_class = result['object_class']
            object_data = result['parsed_data']
            mbrs_by_ref = object_data.get('mbrs-by-ref', None)
            for members_attr in ['members', 'mp-members']:
                if members_attr in object_data:
                    members.update(set(object_data[members_attr]))

            if not rpsl_pk or not object_class or not mbrs_by_ref:
                continue

            # If mbrs-by-ref is set, find any objects with member-of pointing
            # to the route/as-set under query, and include a maintainer listed
            # in mbrs-by-ref, unless mbrs-by-ref is set to ANY.
            query_object_class = [
                'route', 'route6'
            ] if object_class == 'route-set' else ['aut-num']
            query = self._prepare_query().object_classes(query_object_class)
            query = query.lookup_attrs_in(['member-of'], [rpsl_pk])

            if 'ANY' not in [m.strip().upper() for m in mbrs_by_ref]:
                query = query.lookup_attrs_in(['mnt-by'], mbrs_by_ref)

            referring_objects = self.database_handler.execute_query(query)

            for result in referring_objects:
                member_object_class = result['object_class']
                members.add(result['parsed_data'][member_object_class])

        leaf_members = set_names - sets_already_resolved
        return members, leaf_members

    def handle_irrd_database_serial_range(self, parameter: str) -> str:
        """!j query - database serial range"""
        if parameter == '-*':
            sources = self.all_valid_sources
        else:
            sources = [s.upper() for s in parameter.split(',')]
        invalid_sources = [
            s for s in sources if s not in self.all_valid_sources
        ]
        query = RPSLDatabaseStatusQuery().sources(sources)
        query_results = self.database_handler.execute_query(query)

        result_txt = ''
        for query_result in query_results:
            source = query_result['source'].upper()
            keep_journal = 'Y' if get_setting(
                f'sources.{source}.keep_journal') else 'N'
            serial_oldest = query_result['serial_oldest_journal']
            serial_newest = query_result['serial_newest_journal']
            fields = [
                source,
                keep_journal,
                f'{serial_oldest}-{serial_newest}'
                if serial_oldest and serial_newest else '-',
            ]
            if query_result['serial_last_dump']:
                fields.append(str(query_result['serial_last_dump']))
            result_txt += ':'.join(fields) + '\n'

        for invalid_source in invalid_sources:
            result_txt += f'{invalid_source.upper()}:X:Database unknown\n'
        return result_txt.strip()

    def handle_irrd_exact_key(self, parameter: str):
        """!m query - exact object key lookup, e.g. !maut-num,AS65537"""
        try:
            object_class, rpsl_pk = parameter.split(',', maxsplit=1)
        except ValueError:
            raise WhoisQueryParserException(
                f'Invalid argument for object lookup: {parameter}')
        query = self._prepare_query().object_classes(
            [object_class]).rpsl_pk(rpsl_pk).first_only()
        return self._execute_query_flatten_output(query)

    def handle_irrd_route_search(self, parameter: str):
        """
        !r query - route search with various options:
           !r192.0.2.0/24 returns all exact matching objects
           !r192.0.2.0/24,o returns space-separated origins of all exact matching objects
           !r192.0.2.0/24,l returns all one-level less specific objects, not including exact
           !r192.0.2.0/24,L returns all less specific objects, including exact
           !r192.0.2.0/24,M returns all more specific objects, not including exact
        """
        option: Optional[str] = None
        if ',' in parameter:
            address, option = parameter.split(',')
        else:
            address = parameter
        try:
            address = IP(address)
        except ValueError:
            raise WhoisQueryParserException(
                f'Invalid input for route search: {parameter}')

        query = self._prepare_query().object_classes(['route', 'route6'])
        if option is None or option == 'o':
            query = query.ip_exact(address)
        elif option == 'l':
            query = query.ip_less_specific_one_level(address)
        elif option == 'L':
            query = query.ip_less_specific(address)
        elif option == 'M':
            query = query.ip_more_specific(address)
        else:
            raise WhoisQueryParserException(
                f'Invalid route search option: {option}')

        if option == 'o':
            # 'o' returns only the origins of matches, not object text.
            query_result = self.database_handler.execute_query(query)
            prefixes = [r['parsed_data']['origin'] for r in query_result]
            return ' '.join(prefixes)
        return self._execute_query_flatten_output(query)

    def handle_irrd_sources_list(self, parameter: str) -> Optional[str]:
        """
        !s query - set used sources
           !s-lc returns all enabled sources, comma separated
           !sripe,nttcom limits sources to ripe and nttcom
        """
        if parameter == '-lc':
            return ','.join(self.all_valid_sources)
        if parameter:
            sources = parameter.upper().split(',')
            if not all(
                [source in self.all_valid_sources for source in sources]):
                raise WhoisQueryParserException(
                    "One or more selected sources are unavailable.")
            self.sources = sources
        else:
            raise WhoisQueryParserException(
                "One or more selected sources are unavailable.")
        return None

    def handle_irrd_version(self):
        """!v query - return version"""
        return f'IRRD -- version {__version__}'

    def handle_ripe_command(self, full_query: str) -> WhoisQueryResponse:
        """
        Process RIPE-style queries. Any query that is not explicitly an IRRD-style
        query (i.e. starts with exclamation mark) is presumed to be a RIPE query.
        """
        full_query = re.sub(' +', ' ', full_query)
        components = full_query.strip().split(' ')
        result = None
        response_type = WhoisQueryResponseType.SUCCESS

        while len(components):
            component = components.pop(0)
            if component.startswith('-'):
                command = component[1:]
                try:
                    if command == 'k':
                        self.multiple_command_mode = True
                    elif command in 'lLMx':
                        result = self.handle_ripe_route_search(
                            command, components.pop(0))
                        if not result:
                            response_type = WhoisQueryResponseType.KEY_NOT_FOUND
                        break
                    elif command == 'i':
                        result = self.handle_inverse_attr_search(
                            components.pop(0), components.pop(0))
                        if not result:
                            response_type = WhoisQueryResponseType.KEY_NOT_FOUND
                        break
                    elif command == 's':
                        self.handle_ripe_sources_list(components.pop(0))
                    elif command == 'a':
                        self.handle_ripe_sources_list(None)
                    elif command == 'T':
                        self.handle_ripe_restrict_object_class(
                            components.pop(0))
                    elif command == 't':
                        result = self.handle_ripe_request_object_template(
                            components.pop(0))
                        break
                    elif command == 'K':
                        self.handle_ripe_key_fields_only()
                    elif command in 'V':
                        self.handle_user_agent(components.pop(0))
                    elif command in 'Fr':
                        continue  # These flags disable recursion, but IRRd never performs recursion anyway
                    else:
                        raise WhoisQueryParserException(
                            f'Unrecognised flag/search: {command}')
                except IndexError:
                    # components.pop(0) ran out of arguments for the flag.
                    raise WhoisQueryParserException(
                        f'Missing argument for flag/search: {command}')
            else:  # assume query to be a free text search
                result = self.handle_ripe_text_search(component)

        return WhoisQueryResponse(
            response_type=response_type,
            mode=WhoisQueryResponseMode.RIPE,
            result=result,
        )

    def handle_ripe_route_search(self, command: str, parameter: str) -> str:
        """
        -l/L/M/x query - route search for:
           -x 192.0.2.0/2 returns all exact matching objects
           -l 192.0.2.0/2 returns all one-level less specific objects, not including exact
           -L 192.0.2.0/2 returns all less specific objects, including exact
           -M 192.0.2.0/2 returns all more specific objects, not including exact
        """
        try:
            address = IP(parameter)
        except ValueError:
            raise WhoisQueryParserException(
                f'Invalid input for route search: {parameter}')

        query = self._prepare_query().object_classes(['route', 'route6'])
        if command == 'x':
            query = query.ip_exact(address)
        elif command == 'l':
            query = query.ip_less_specific_one_level(address)
        elif command == 'L':
            query = query.ip_less_specific(address)
        elif command == 'M':
            query = query.ip_more_specific(address)

        return self._execute_query_flatten_output(query)

    def handle_ripe_sources_list(self, sources_list: Optional[str]) -> None:
        """-s/-a parameter - set sources list. Empty list enables all sources. """
        if sources_list:
            sources = sources_list.upper().split(',')
            if not all(
                [source in self.all_valid_sources for source in sources]):
                raise WhoisQueryParserException(
                    "One or more selected sources are unavailable.")
            self.sources = sources
        else:
            self.sources = []

    def handle_ripe_restrict_object_class(self, object_classes) -> None:
        """-T parameter - restrict object classes for this query, comma-separated"""
        self.object_classes = object_classes.split(',')

    def handle_ripe_request_object_template(self, object_class) -> str:
        """-t query - return the RPSL template for an object class"""
        try:
            return OBJECT_CLASS_MAPPING[object_class]().generate_template()
        except KeyError:
            raise WhoisQueryParserException(
                f'Unknown object class: {object_class}')

    def handle_ripe_key_fields_only(self) -> None:
        """-K parameter - only return primary key and members fields"""
        self.key_fields_only = True

    def handle_ripe_text_search(self, value: str) -> str:
        """Free text search over all lookup fields."""
        query = self._prepare_query().text_search(value)
        return self._execute_query_flatten_output(query)

    def handle_user_agent(self, user_agent: str):
        """-V/!n parameter/query - set a user agent for the client"""
        self.user_agent = user_agent
        logger.info(f'{self.peer}: user agent set to: {user_agent}')

    def handle_inverse_attr_search(self, attribute: str, value: str) -> str:
        """
        -i/!o query - inverse search for attribute values
        e.g. `-i mnt-by FOO` finds all objects where (one of the) maintainer(s) is FOO,
        as does `!oFOO`. Restricted to designated lookup fields.
        """
        if attribute not in self.lookup_field_names:
            readable_lookup_field_names = ", ".join(self.lookup_field_names)
            msg = (
                f'Inverse attribute search not supported for {attribute},' +
                f'only supported for attributes: {readable_lookup_field_names}'
            )
            raise WhoisQueryParserException(msg)

        query = self._prepare_query().lookup_attr(attribute, value)
        return self._execute_query_flatten_output(query)

    def _prepare_query(self) -> RPSLDatabaseQuery:
        """Prepare an RPSLDatabaseQuery by applying relevant sources/class filters."""
        query = RPSLDatabaseQuery()
        if self.sources:
            query.sources(self.sources)
        if self.object_classes:
            query.object_classes(self.object_classes)
        return query

    def _execute_query_flatten_output(self, query: RPSLDatabaseQuery) -> str:
        """
        Execute an RPSLDatabaseQuery, and flatten the output into a string
        with object text for easy passing to a WhoisQueryResponse.
        """
        query_response = self.database_handler.execute_query(query)
        if self.key_fields_only:
            result = self._filter_key_fields(query_response)
        else:
            result = ''
            for obj in query_response:
                result += obj['object_text'] + '\n'
        return result.strip('\n\r')

    def _filter_key_fields(self, query_response) -> str:
        """Render only primary key and (mp-)members/member-of fields of each
        object, for -K mode."""
        result = ''
        for obj in query_response:
            rpsl_object_class = OBJECT_CLASS_MAPPING[obj['object_class']]
            fields_included = rpsl_object_class.pk_fields + [
                'members', 'mp-members', 'member-of'
            ]

            for field_name in fields_included:
                field_data = obj['parsed_data'].get(field_name)
                if field_data:
                    if isinstance(field_data, list):
                        for item in field_data:
                            result += f'{field_name}: {item}\n'
                    else:
                        result += f'{field_name}: {field_data}\n'
            result += '\n'
        return result
class UpdateRequestHandler:
    """
    The UpdateRequestHandler handles the text of one or more RPSL updates
    (create, modify or delete), parses, validates and eventually saves
    them. This includes validating references between objects, including
    those part of the same update message, and checking authentication.
    """

    def __init__(self, object_texts: str, pgp_fingerprint: Optional[str] = None,
                 request_meta: Optional[Dict[str, Optional[str]]] = None) -> None:
        # Process the submitted object texts end-to-end: parse, validate,
        # save, then commit and close the database handler.
        self.database_handler = DatabaseHandler()
        self.request_meta = request_meta if request_meta else {}
        self._pgp_key_id = self._resolve_pgp_key_id(
            pgp_fingerprint) if pgp_fingerprint else None
        self._handle_object_texts(object_texts)
        self.database_handler.commit()
        self.database_handler.close()

    def _handle_object_texts(self, object_texts: str) -> None:
        # Parse all submitted objects and resolve their validity iteratively;
        # results are stored in self.results.
        reference_validator = ReferenceValidator(self.database_handler)
        auth_validator = AuthValidator(self.database_handler, self._pgp_key_id)
        results = parse_update_requests(object_texts, self.database_handler,
                                        auth_validator, reference_validator)

        # When an object references another object, e.g. tech-c referring a person
        # or mntner, an add/update is only valid if those referred objects exist.
        # To complicate matters, the object referred to may be part of this very
        # same update. For this reason, the reference validator can be provided
        # with all new objects to be added in this update. However, a possible
        # scenario is that A, B and C are submitted. Object A refers to B,
        # B refers to C, C refers to D and D does not exist - or C fails
        # authentication. At a first scan, A is valid because B exists, B is valid
        # because C exists. C becomes invalid on the first scan, which is why
        # another scan is performed, which will mark B invalid due to the
        # reference to an invalid C, etc. This continues until all references are
        # resolved and repeated scans lead to the same conclusions.
        valid_updates = [r for r in results if r.is_valid()]
        previous_valid_updates: List[UpdateRequest] = []
        loop_count = 0
        # Safety bound against a non-converging validity resolution loop.
        loop_max = len(results) + 10

        while valid_updates != previous_valid_updates:
            previous_valid_updates = valid_updates
            reference_validator.preload(valid_updates)
            auth_validator.pre_approve(valid_updates)
            for result in valid_updates:
                result.validate()
            valid_updates = [r for r in results if r.is_valid()]

            loop_count += 1
            if loop_count > loop_max:  # pragma: no cover
                msg = 'Update validity resolver ran an excessive amount of loops, may be stuck, aborting processing'
                raise ValueError(msg)

        for result in results:
            if result.is_valid():
                result.save(self.database_handler)

        self.results = results

    def _resolve_pgp_key_id(self, pgp_fingerprint: str) -> Optional[str]:
        """
        Find a PGP key ID for a given fingerprint.
        This method looks for an actual matching object in the database,
        and then returns the object's PK.
        """
        clean_fingerprint = pgp_fingerprint.replace(' ', '')
        # Key ID is derived from the last 8 hex digits of the fingerprint.
        key_id = "PGPKEY-" + clean_fingerprint[-8:]
        query = RPSLDatabaseQuery().object_classes(['key-cert'
                                                    ]).rpsl_pk(key_id)
        results = list(self.database_handler.execute_query(query))

        # Confirm the full fingerprint matches, not just the derived key ID.
        for result in results:
            if result['parsed_data'].get('fingerpr',
                                         '').replace(' ',
                                                     '') == clean_fingerprint:
                return key_id
        return None

    def status(self) -> str:
        """Return "SUCCESS" when every submitted object was saved, "FAILED" otherwise."""
        if all([
                result.status == UpdateRequestStatus.SAVED
                for result in self.results
        ]):
            return "SUCCESS"
        return "FAILED"

    def user_report(self) -> str:
        """Produce a human-readable report for the user."""
        # flake8: noqa: W293
        successful = [
            r for r in self.results if r.status == UpdateRequestStatus.SAVED
        ]
        failed = [
            r for r in self.results if r.status != UpdateRequestStatus.SAVED
        ]
        number_successful_create = len([
            r for r in successful if r.request_type == UpdateRequestType.CREATE
        ])
        number_successful_modify = len([
            r for r in successful if r.request_type == UpdateRequestType.MODIFY
        ])
        number_successful_delete = len([
            r for r in successful if r.request_type == UpdateRequestType.DELETE
        ])
        number_failed_create = len(
            [r for r in failed if r.request_type == UpdateRequestType.CREATE])
        number_failed_modify = len(
            [r for r in failed if r.request_type == UpdateRequestType.MODIFY])
        number_failed_delete = len(
            [r for r in failed if r.request_type == UpdateRequestType.DELETE])

        request_meta_str = '\n'.join(
            [f"> {k}: {v}" for k, v in self.request_meta.items() if v])
        if request_meta_str:
            request_meta_str = "\n" + request_meta_str + "\n\n"

        # NOTE(review): the original indentation inside this dedent literal
        # was lost in the source; layout reconstructed — confirm against
        # expected report output.
        user_report = request_meta_str + textwrap.dedent(f"""
        SUMMARY OF UPDATE:

        Number of objects found:                  {len(self.results):3}
        Number of objects processed successfully: {len(successful):3}
            Create:      {number_successful_create:3}
            Modify:      {number_successful_modify:3}
            Delete:      {number_successful_delete:3}
        Number of objects processed with errors:  {len(failed):3}
            Create:      {number_failed_create:3}
            Modify:      {number_failed_modify:3}
            Delete:      {number_failed_delete:3}

        DETAILED EXPLANATION:

        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        """)
        for result in self.results:
            user_report += "---\n"
            user_report += result.user_report()
            user_report += "\n"
        user_report += '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n'
        return user_report