def get_engine():
    """
    Return the module-wide SQLAlchemy engine, creating it on first use.

    The connection pool size is three times the sum of the configured
    server.whois.max_connections and the number of configured sources;
    up to 30 extra overflow connections are allowed.
    """
    global engine
    if engine:
        return engine

    whois_connections = get_setting('server.whois.max_connections')
    source_count = len(get_setting('sources', []))
    engine = sa.create_engine(
        get_setting('database_url'),
        pool_size=(whois_connections + source_count) * 3,
        max_overflow=30,
        json_deserializer=ujson.loads,
    )
    return engine
def send_email(recipient, subject, body) -> None:
    """
    Send a plain-text email through the configured SMTP server.

    The configured email.footer and a line naming the IRRd version and the
    local hostname are appended to ``body`` before sending.

    :param recipient: value for the To header
    :param subject: value for the Subject header
    :param body: message body text
    """
    body += get_setting('email.footer')
    hostname = socket.gethostname()
    # Spelled "IRRd", matching the project name used elsewhere.
    body += f'\n\nGenerated by IRRd version {__version__} on {hostname}'

    msg = MIMEText(body)
    msg['Subject'] = subject
    msg['From'] = get_setting('email.from')
    msg['To'] = recipient

    s = SMTP(get_setting('email.smtp'))
    try:
        s.send_message(msg)
    finally:
        # End the SMTP session even when sending failed, so the connection
        # is not left open.
        s.quit()
def run_http_server(config_path: str):
    """
    Run the HTTP interface under uvicorn; blocks until uvicorn returns.

    The configuration path is handed to the uvicorn worker processes via
    the ENV_UVICORN_WORKER_CONFIG_PATH environment variable. Listen address,
    port, worker count and trusted forwarding IPs come from server.http.*.
    """
    setproctitle('irrd-http-server-manager')
    configuration = get_configuration()
    assert configuration
    os.environ[ENV_UVICORN_WORKER_CONFIG_PATH] = config_path

    server_header = ['Server', f'IRRd {__version__}']
    uvicorn.run(
        app="irrd.server.http.app:app",
        host=get_setting('server.http.interface'),
        port=get_setting('server.http.port'),
        workers=get_setting('server.http.workers'),
        forwarded_allow_ips=get_setting('server.http.forwarded_allowed_ips'),
        headers=[server_header],
        log_config=configuration.logging_config,
    )
def __init__(self):
    """
    Load per-source import settings for ``self.source``.

    Sets ``object_class_filter`` to a list of stripped, lower-cased object
    class names, or None when no filter is configured. The setting may be
    a list, or a single string holding one or more comma-separated class
    names. Also sets ``strict_validation_key_cert`` from
    strict_import_keycert_objects (default False).
    """
    object_class_filter = get_setting(
        f'sources.{self.source}.object_class_filter')
    if object_class_filter:
        if isinstance(object_class_filter, str):
            # Object class names never contain commas, so splitting keeps a
            # single-name string working and also supports "route,route6".
            object_class_filter = object_class_filter.split(',')
        self.object_class_filter = [
            c.strip().lower() for c in object_class_filter
        ]
    else:
        self.object_class_filter = None

    self.strict_validation_key_cert = get_setting(
        f'sources.{self.source}.strict_import_keycert_objects', False)
def run(self) -> None:
    """
    Start any RPKI, scope filter, mirror import or export runs that are due.

    The RPKI import (if rpki.roa_source is set) and the scope filter update
    (if its configuration changed) are considered first. Then each
    configured source is considered in turn. Once MAX_SIMULTANEOUS_RUNS
    sources have had an import or export started during this call, the
    remaining sources are skipped until a later call.
    """
    if get_setting('rpki.roa_source'):
        roa_interval = int(get_setting('rpki.roa_import_timer'))
        self.run_if_relevant(RPKI_IRR_PSEUDO_SOURCE, ROAImportRunner, roa_interval)

    if self._check_scopefilter_change():
        self.run_if_relevant('scopefilter', ScopeFilterUpdateRunner, 0)

    started_count = 0
    for source in get_setting('sources', {}).keys():
        if started_count >= MAX_SIMULTANEOUS_RUNS:
            break

        prefix = f'sources.{source}'
        mirror_configured = (get_setting(f'{prefix}.import_source')
                             or get_setting(f'{prefix}.nrtm_host'))
        import_interval = int(get_setting(f'{prefix}.import_timer', DEFAULT_SOURCE_IMPORT_TIMER))
        import_started = False
        if mirror_configured:
            import_started = self.run_if_relevant(source, RPSLMirrorImportUpdateRunner, import_interval)

        export_configured = get_setting(f'{prefix}.export_destination')
        export_interval = int(get_setting(f'{prefix}.export_timer', DEFAULT_SOURCE_EXPORT_TIMER))
        export_started = False
        if export_configured:
            export_started = self.run_if_relevant(source, SourceExportRunner, export_interval)

        if import_started or export_started:
            started_count += 1
def run(self) -> None:
    """
    Start a mirror update thread for every mirrored source that is due.

    A source is a mirror when it has an import_source or nrtm_host. It is
    due when its import_timer (default 300 seconds) has elapsed since its
    last start and no update thread for it is still running. With no
    sources configured this is a no-op.
    """
    # Default to {} like the other source loops, so a configuration without
    # sources does not raise AttributeError on None.keys().
    for source in get_setting('sources', {}).keys():
        is_mirror = get_setting(f'sources.{source}.import_source') or get_setting(f'sources.{source}.nrtm_host')
        import_timer = int(get_setting(f'sources.{source}.import_timer', 300))

        current_time = time.time()
        has_expired = (self.last_started_time[source] + import_timer) < current_time
        if is_mirror and has_expired and not self._is_thread_running(source):
            logger.debug(f'Started new thread for mirror update for {source}')
            initiator = MirrorUpdateRunner(source=source)
            thread = threading.Thread(target=initiator.run, name=f'Thread-MirrorUpdateRunner-{source}')
            self.threads[source] = thread
            thread.start()
            self.last_started_time[source] = int(current_time)
def test_send_email(self, monkeypatch):
    """send_email hands a correctly addressed message to SMTP, then quits."""
    smtp = Mock()
    monkeypatch.setattr('irrd.utils.email.SMTP', lambda server: smtp)
    recipient = 'Sasha <*****@*****.**>'

    send_email(recipient, 'subject', 'body')

    first_call = smtp.mock_calls[0]
    assert first_call[0] == 'send_message'
    sent = first_call[1][0]
    assert sent['From'] == get_setting('email.from')
    assert sent['To'] == recipient
    assert sent['Subject'] == 'subject'

    payload = sent.get_payload()
    for expected in ('body', 'IRRd version', get_setting('email.footer')):
        assert expected in payload

    assert smtp.mock_calls[1][0] == 'quit'
def run(self, database_handler: DatabaseHandler, serial_newest_seen: Optional[int] = None, force_reload=False):
    """
    Run a full import of ``self.source`` from its configured import_source(s).

    import_source may be a single location or a list. When
    import_serial_source is set, the serial is read from it first. The
    import is then cancelled if that serial is not newer than
    ``serial_newest_seen``, unless ``force_reload`` is set. Otherwise all
    existing objects of the source are deleted (with journal), and each
    retrieved file is imported with journaling disabled. Retrieved files
    flagged for deletion are removed after their import.
    """
    setting_prefix = f'sources.{self.source}'
    import_sources = get_setting(f'{setting_prefix}.import_source')
    if isinstance(import_sources, str):
        import_sources = [import_sources]
    import_serial_source = get_setting(f'{setting_prefix}.import_serial_source')

    if not import_sources:
        logger.info(f'Skipping full import for {self.source}, import_source not set.')
        return

    logger.info(
        f'Running full import of {self.source} from {import_sources}, serial from {import_serial_source}'
    )

    import_serial = None
    if import_serial_source:
        serial_text = self._retrieve_file(import_serial_source, return_contents=True)[0]
        import_serial = int(serial_text)
        if not (force_reload or serial_newest_seen is None or import_serial > serial_newest_seen):
            logger.info(
                f'Current newest serial seen for {self.source} is '
                f'{serial_newest_seen}, import_serial is {import_serial}, cancelling import.'
            )
            return

    database_handler.delete_all_rpsl_objects_with_journal(self.source)
    import_data = [
        self._retrieve_file(location, return_contents=False)
        for location in import_sources
    ]

    database_handler.disable_journaling()
    for filename, should_delete in import_data:
        parser = MirrorFileImportParser(
            source=self.source,
            filename=filename,
            serial=import_serial,
            database_handler=database_handler,
        )
        parser.run_import()
        if should_delete:
            os.unlink(filename)
def get_engine():
    """
    Return the module-wide SQLAlchemy engine, creating it on first use.

    Two event listeners are registered on a new engine. One records the
    creating process id on each new DB-API connection. The other refuses to
    check out a pooled connection in a different process: it discards the
    connection and raises DisconnectionError. This follows
    https://docs.sqlalchemy.org/en/13/core/pooling.html#using-connection-pools-with-multiprocessing
    """
    global engine
    if not engine:
        engine = sa.create_engine(
            translate_url(get_setting('database_url')),
            pool_size=2,
            json_deserializer=ujson.loads,
        )

        @sa.event.listens_for(engine, "connect")
        def _record_owner_pid(dbapi_connection, connection_record):
            connection_record.info['pid'] = os.getpid()

        @sa.event.listens_for(engine, "checkout")
        def _reject_foreign_pid(dbapi_connection, connection_record, connection_proxy):
            current_pid = os.getpid()
            owner_pid = connection_record.info['pid']
            if owner_pid != current_pid:  # pragma: no cover
                connection_record.connection = connection_proxy.connection = None
                raise sa.exc.DisconnectionError(
                    "Connection record belongs to pid %s, "
                    "attempting to check out in pid %s" % (owner_pid, current_pid))

    return engine
def set_last_modified():
    """
    Re-render and store the text of every object in authoritative sources.

    Each object's text is re-rendered via render_rpsl_text() with its
    ``updated`` value and written back by primary key. Objects that parse
    with errors are reported and skipped. All updates are committed
    together at the end.
    """
    database_handler = DatabaseHandler()
    auth_sources = [
        name for name, source_config in get_setting('sources').items()
        if source_config.get('authoritative')
    ]
    query = RPSLDatabaseQuery(
        column_names=['pk', 'object_text', 'updated'], enable_ordering=False
    ).sources(auth_sources)
    rows = list(database_handler.execute_query(query))
    print(f'Updating {len(rows)} objects in sources {auth_sources}')

    table = RPSLDatabaseObject.__table__
    for row in rows:
        rpsl_obj = rpsl_object_from_text(row['object_text'], strict_validation=False)
        if rpsl_obj.messages.errors():  # pragma: no cover
            print(f'Failed to process {rpsl_obj}: {rpsl_obj.messages.errors()}')
            continue

        rendered = rpsl_obj.render_rpsl_text(row['updated'])
        statement = table.update().where(table.c.pk == row['pk']).values(object_text=rendered)
        database_handler.execute_statement(statement)

    database_handler.commit()
    database_handler.close()
def record_operation(self, operation: DatabaseOperation, rpsl_pk: str, source: str, object_class: str,
                     object_text: str, forced_serial: Optional[int]) -> None:
    """
    Record a change to an object in the journal.

    A journal row is only written when ``self.journaling_enabled`` is set
    and sources.SOURCE.keep_journal is set. Without a forced serial, the
    journal table is locked in EXCLUSIVE mode, and the serial is computed
    as one more than the current maximum for the source. This keeps NRTM
    serials gapless. The serial written (or the forced serial, if no row
    is written) is added to ``self._new_serials_per_source[source]``.
    """
    if not (self.journaling_enabled and get_setting(f'sources.{source}.keep_journal')):
        if forced_serial is not None:
            self._new_serials_per_source[source].add(forced_serial)
        return

    journal_table = RPSLDatabaseJournal.__table__
    if forced_serial is not None:
        serial_nrtm = forced_serial
    else:
        lock_statement = f'LOCK TABLE {RPSLDatabaseJournal.__tablename__} IN EXCLUSIVE MODE'
        self.database_handler.execute_statement(lock_statement)
        next_serial = sa.select([sa.text('COALESCE(MAX(serial_nrtm), 0) + 1')])
        serial_nrtm = next_serial.where(journal_table.c.source == source).as_scalar()

    insert_statement = journal_table.insert().values(
        rpsl_pk=rpsl_pk,
        source=source,
        operation=operation,
        object_class=object_class,
        object_text=object_text,
        serial_nrtm=serial_nrtm,
    ).returning(self.c_journal.serial_nrtm)
    insert_result = self.database_handler.execute_statement(insert_statement)
    inserted_serial = insert_result.fetchone()['serial_nrtm']
    self._new_serials_per_source[source].add(inserted_serial)
def main():  # pragma: no cover
    """Command-line entry point: update a source from an RPSL file."""
    parser = argparse.ArgumentParser(description="""Update a database based on a RPSL file.""")
    parser.add_argument(
        '--config',
        dest='config_file_path',
        type=str,
        help=f'use a different IRRd config file (default: {CONFIG_PATH_DEFAULT})',
    )
    parser.add_argument(
        '--source',
        dest='source',
        type=str,
        required=True,
        help='name of the source, e.g. NTTCOM',
    )
    parser.add_argument('input_file', type=str, help='the name of a file to read')
    args = parser.parse_args()

    config_init(args.config_file_path)
    if get_setting('database_readonly'):
        print('Unable to run, because database_readonly is set')
        sys.exit(-1)

    exit_status = update(args.source, args.input_file)
    sys.exit(exit_status)
def handle_irrd_set_members(self, parameter: str) -> str:
    """
    !i query - find all members of an as-set or route-set.

    A trailing ",1" requests recursive resolution: !iAS-FOO is
    non-recursive, !iAS-FOO,1 is recursive. The queried set itself is never
    part of the result. When compatibility.ipv4_only_route_set_members is
    enabled, members that parse as prefixes but not as IPv4 prefixes are
    dropped. Returns the members sorted and space-separated.
    """
    self._preloaded_query_called()

    recursive = parameter.endswith(',1')
    if recursive:
        parameter = parameter[:-2]

    self._current_set_root_object_class = None
    if recursive:
        members = self._recursive_set_resolve({parameter})
    else:
        members, leaf_members = self._find_set_members({parameter})
        members.update(leaf_members)

    if parameter in members:
        members.remove(parameter)

    if get_setting('compatibility.ipv4_only_route_set_members'):
        for member in set(members):
            try:
                IP(member)
            except ValueError:
                continue  # Not a prefix at all; keep it.
            try:
                IP(member, ipversion=4)
            except ValueError:
                # Valid prefix, but not an IPv4 prefix: drop it.
                members.remove(member)

    return ' '.join(sorted(members))
def validate(self, source: str, prefix: Optional[IP] = None, asn: Optional[int] = None) -> ScopeFilterStatus:
    """
    Validate a prefix and/or ASN against the scope filter for a source.

    Returns a single ScopeFilterStatus (no explanation string):
    - in_scope if the source has scopefilter_excluded set, or nothing matches;
    - out_scope_prefix if the prefix overlaps a filtered prefix of the same
      IP version (checked before the ASN);
    - out_scope_as if the ASN is filtered individually or lies within a
      filtered (start, end) range, bounds inclusive.

    Raises ValueError if neither prefix nor asn is provided.
    """
    # Compare against None explicitly, as for asn, rather than relying on
    # the truthiness of the IP object.
    if prefix is None and asn is None:
        raise ValueError(
            'Scope Filter validator must be provided asn or prefix')
    if get_setting(f'sources.{source}.scopefilter_excluded'):
        return ScopeFilterStatus.in_scope

    if prefix is not None:
        for filtered_prefix in self.filtered_prefixes:
            if prefix.version() == filtered_prefix.version() and filtered_prefix.overlaps(prefix):
                return ScopeFilterStatus.out_scope_prefix

    if asn is not None:
        if asn in self.filtered_asns:
            return ScopeFilterStatus.out_scope_as
        for range_start, range_end in self.filtered_asn_ranges:
            if range_start <= asn <= range_end:
                return ScopeFilterStatus.out_scope_as

    return ScopeFilterStatus.in_scope
def handle_irrd_database_serial_range(self, parameter: str) -> str:
    """
    !j query - database serial range.

    ``parameter`` is "-*" for all valid sources, or a comma-separated list
    of source names (case-insensitive). Each status row yields
    SOURCE:Y|N:OLDEST-NEWEST, with "-" when either journal serial is unset,
    followed by :LAST_DUMP when a dump serial is known. Each unknown source
    yields SOURCE:X:Database unknown. Lines are newline-separated.
    """
    if parameter == '-*':
        sources = self.all_valid_sources
    else:
        sources = [name.upper() for name in parameter.split(',')]
    invalid_sources = [name for name in sources if name not in self.all_valid_sources]

    status_query = RPSLDatabaseStatusQuery().sources(sources)
    rows = self.database_handler.execute_query(status_query)

    lines = []
    for row in rows:
        source = row['source'].upper()
        journal_flag = 'Y' if get_setting(f'sources.{source}.keep_journal') else 'N'
        oldest = row['serial_oldest_journal']
        newest = row['serial_newest_journal']
        serial_range = f'{oldest}-{newest}' if oldest and newest else '-'
        fields = [source, journal_flag, serial_range]
        if row['serial_last_dump']:
            fields.append(str(row['serial_last_dump']))
        lines.append(':'.join(fields))

    lines.extend(f'{name.upper()}:X:Database unknown' for name in invalid_sources)
    return '\n'.join(lines).strip()
def run(self) -> None:
    """
    Run one mirror update for ``self.source`` with a fresh database handler.

    A full import runs when a reload is forced, no serial has been seen yet
    (falsy), or the source has no nrtm_host. Otherwise the NRTM update
    stream runs from the newest seen serial. The transaction is committed
    only on success. OSError is logged without a traceback and other
    exceptions with one. The handler is always closed.
    """
    self.database_handler = DatabaseHandler()
    try:
        serial_newest_seen, force_reload = self._status()
        nrtm_enabled = bool(
            get_setting(f'sources.{self.source}.nrtm_host'))
        # The two f-strings are concatenated; the trailing space keeps
        # "force_reload" from running into the serial.
        logger.debug(
            f'Most recent serial seen for {self.source}: {serial_newest_seen}, '
            f'force_reload: {force_reload}, nrtm enabled: {nrtm_enabled}')
        if force_reload or not serial_newest_seen or not nrtm_enabled:
            self.full_import_runner.run(
                database_handler=self.database_handler,
                serial_newest_seen=serial_newest_seen,
                force_reload=force_reload)
        else:
            self.update_stream_runner.run(
                serial_newest_seen, database_handler=self.database_handler)
        self.database_handler.commit()
    except OSError as ose:
        # I/O errors can occur and should not log a full traceback (#177)
        logger.error(
            f'An error occurred while attempting a mirror update or initial import '
            f'for {self.source}: {ose}')
    except Exception as exc:
        logger.error(
            f'An exception occurred while attempting a mirror update or initial import '
            f'for {self.source}: {exc}', exc_info=exc)
    finally:
        self.database_handler.close()
def __init__(self, preloader: Preloader, database_handler: DatabaseHandler) -> None:
    """
    Initialise query state from the current configuration.

    ``self.sources`` defaults to sources_default or, when that is empty, to
    all valid sources. In the latter case it is the same list object as
    ``self.all_valid_sources``, so the RPKI pseudo-source appended when
    rpki.roa_source is set also appears in ``self.sources``.
    """
    self.all_valid_sources = list(get_setting('sources', {}).keys())
    self.sources_default = list(get_setting('sources_default', []))
    self.sources: List[str] = self.sources_default or self.all_valid_sources

    rpki_enabled = bool(get_setting('rpki.roa_source'))
    if rpki_enabled:
        self.all_valid_sources.append(RPKI_IRR_PSEUDO_SOURCE)

    self.object_class_filter: List[str] = []
    self.rpki_aware = rpki_enabled
    self.rpki_invalid_filter_enabled = rpki_enabled
    self.out_scope_filter_enabled = True
    self.user_agent: Optional[str] = None
    self.preloader = preloader
    self.database_handler = database_handler
    self.sql_queries: List[str] = []
    self.sql_trace = False
def __init__(self, peer: str) -> None:
    """Initialise per-connection query state for the client ``peer``."""
    configured_sources = get_setting('sources')
    self.all_valid_sources = list(configured_sources.keys())
    self.sources: List[str] = []
    self.object_classes: List[str] = []
    self.user_agent: Optional[str] = None
    self.multiple_command_mode = self.key_fields_only = False
    self.peer = peer
def load_filters(self):
    """
    (Re)load the local cache of the configured scope filters.

    Also called by __init__. scopefilter.prefixes become IP objects in
    ``filtered_prefixes``. Each scopefilter.asns entry whose text contains
    "-" is parsed as a "start-end" pair into ``filtered_asn_ranges``; every
    other entry is parsed as a single ASN into ``filtered_asns``.
    """
    self.filtered_prefixes = [IP(p) for p in get_setting('scopefilter.prefixes', [])]
    self.filtered_asns = set()
    self.filtered_asn_ranges = set()

    for entry in get_setting('scopefilter.asns', []):
        if '-' not in str(entry):
            self.filtered_asns.add(int(entry))
            continue
        low, high = entry.split('-')
        self.filtered_asn_ranges.add((int(low), int(high)))
def send_email(recipient, subject, body) -> None:
    """
    Send a plain-text email through the configured SMTP server.

    If email.recipient_override is set, it replaces ``recipient``. The
    configured email.footer and a line naming the IRRd version and local
    hostname are appended to ``body``.

    :param recipient: value for the To header, unless overridden
    :param subject: value for the Subject header
    :param body: message body text
    """
    recipient_override = get_setting('email.recipient_override')
    if recipient_override:
        recipient = recipient_override

    logger.debug(f'Sending email to {recipient}, subject {subject}')
    body += get_setting('email.footer')
    hostname = socket.gethostname()
    body += f'\n\nGenerated by IRRd version {__version__} on {hostname}'

    msg = MIMEText(body)
    msg['Subject'] = subject
    msg['From'] = get_setting('email.from')
    msg['To'] = recipient

    s = SMTP(get_setting('email.smtp'))
    try:
        s.send_message(msg)
    finally:
        # End the SMTP session even when sending failed, so the connection
        # is not left open.
        s.quit()
def __init__(self, source: str, nrtm_data: str, database_handler: DatabaseHandler) -> None:
    """
    Parse ``nrtm_data`` for ``source`` into ``self.operations``.

    ``rpki_aware`` is set before the parent initialiser runs; the stream is
    only split after it has run.
    """
    self.source = source
    self.database_handler = database_handler
    roa_source = get_setting('rpki.roa_source')
    self.rpki_aware = bool(roa_source)
    super().__init__()

    operations: List[NRTMOperation] = []
    self.operations = operations
    self._split_stream(nrtm_data)
def validate_pgp_signature(
    message: str, detached_signature: Optional[str] = None
) -> Tuple[Optional[str], Optional[str]]:
    """
    Verify a PGP signature in a message.

    Returns ``(signed_part, fingerprint)`` for a valid signature, or
    ``(None, None)`` otherwise. ``signed_part`` is None when the entire
    message was signed. For an inline signature it is the extracted signed
    body, since only part of the message may be signed.

    If ``detached_signature`` is set, it must hold the PGP signature block
    for ``message`` (PGP/MIME). In that case ``message`` should be the
    whole text/plain part of the signed message, including content-type
    headers. Otherwise, a message containing exactly one inline PGP signed
    block is verified. Messages with zero or several inline signed blocks
    return ``(None, None)``.

    A signature only counts as valid if gnupg reports it valid and reports
    no key_status. The signing key must already be in the keyring at
    auth.gnupg_keyring; key-cert imports normally add it during
    validation, in RPSLKeyCert.clean().
    """
    gpg = gnupg.GPG(gnupghome=get_setting('auth.gnupg_keyring'))
    signed_part = None

    if detached_signature:
        # The signed data is handed to gpg.verify() as a file name, so write
        # the message to a temporary file that exists during verification.
        with NamedTemporaryFile() as data_file:
            data_file.write(message.encode(gpg.encoding))
            data_file.flush()
            result = gpg.verify(detached_signature, data_filename=data_file.name)
    elif message.count('BEGIN PGP SIGNED MESSAGE') == 1:
        result = gpg.verify(message)
        normalised = message.replace('\r\n', '\n').replace('\r', '\n')
        match = pgp_inline_re.search(normalised)
        if not match:  # pragma: no cover
            logger.info(f'message contained an inline PGP signature, but regular expression failed to extract body: {message}')
            return None, None
        signed_part = match.group(2) + '\n'
    else:
        return None, None

    log_message = result.stderr.replace('\n', ' -- ').replace('gpg: ', '')
    logger.info(f'checked PGP signature, response: {log_message}')

    if not (result.valid and result.key_status is None):
        return None, None
    logger.info(f'Found valid PGP signature, fingerprint {result.fingerprint}')
    return signed_part, result.fingerprint
def is_client_permitted(ip: str, access_list_setting: str, default_deny=True, log=True) -> bool:
    """
    Determine whether a client may access an interface.

    The access list name is read from the setting ``access_list_setting``.
    Its entries are read from access_lists.<name>. If either is unset or
    empty, the client is denied when ``default_deny`` is True and permitted
    otherwise. An unreadable client IP is always denied. IPv6-mapped IPv4
    addresses are unmapped to plain IPv4 before matching. Set ``log`` to
    False to suppress logging of rejections.
    """
    try:
        client_ip = IP(ip)
    except (ValueError, AttributeError) as e:
        if log:
            logger.error(
                f'Rejecting request as client IP could not be read from '
                f'{ip}: {e}')
        return False

    if client_ip.version() == 6:
        try:
            client_ip = client_ip.v46map()
        except ValueError:
            # v46map() could not map this address; keep it as IPv6.
            pass

    access_list_name = get_setting(access_list_setting)
    access_list = get_setting(f'access_lists.{access_list_name}')
    if not access_list_name or not access_list:
        if default_deny:
            if log:
                logger.info(f'Rejecting request, access list empty or undefined: {client_ip}')
            return False
        return True

    # A generator lets any() stop at the first matching entry instead of
    # parsing every access list entry on each request.
    allowed = any(client_ip in IP(entry) for entry in access_list)
    if not allowed and log:
        logger.info(f'Rejecting request, IP not in access list {access_list_name}: {client_ip}')
    return allowed
def test_send_email_with_recipient_override(self, monkeypatch, config_override):
    """email.recipient_override replaces the To address passed to send_email."""
    override_address = '*****@*****.**'
    config_override({'email': {'recipient_override': override_address}})
    smtp = Mock()
    monkeypatch.setattr('irrd.utils.email.SMTP', lambda server: smtp)

    send_email('Sasha <*****@*****.**>', 'subject', 'body')

    first_call = smtp.mock_calls[0]
    assert first_call[0] == 'send_message'
    sent = first_call[1][0]
    assert sent['From'] == get_setting('email.from')
    assert sent['To'] == override_address
    assert sent['Subject'] == 'subject'
def __init__(self):
    """
    Load the object class filter for ``self.source``.

    Sets ``object_class_filter`` to a list of stripped, lower-cased object
    class names, or None when no filter is configured. The setting may be
    a comma-separated string or a list of class names.
    """
    object_class_filter = get_setting(
        f'sources.{self.source}.object_class_filter')
    if not object_class_filter:
        self.object_class_filter = None
        return

    if isinstance(object_class_filter, str):
        object_class_filter = object_class_filter.split(',')
    # Lists (e.g. from a YAML sequence) are used as-is, instead of failing
    # on .split().
    self.object_class_filter = [
        c.strip().lower() for c in object_class_filter
    ]
def __init__(self, client_ip: str, client_str: str) -> None:
    """
    Initialise whois query state for a client connection.

    ``self.sources`` defaults to sources_default or, when that is unset or
    empty, to all valid sources. In the latter case it is the same list
    object as ``self.all_valid_sources``, so the RPKI pseudo-source appended
    when rpki.roa_source is set also appears in ``self.sources``. The query
    timeout starts at 30.
    """
    self.all_valid_sources = list(get_setting('sources', {}).keys())
    self.sources_default = get_setting('sources_default')
    self.sources: List[str] = self.sources_default or self.all_valid_sources

    rpki_enabled = bool(get_setting('rpki.roa_source'))
    if rpki_enabled:
        self.all_valid_sources.append(RPKI_IRR_PSEUDO_SOURCE)

    self.object_classes: List[str] = []
    self.user_agent: Optional[str] = None
    self.multiple_command_mode = False
    self.rpki_aware = rpki_enabled
    self.rpki_invalid_filter_enabled = rpki_enabled
    self.timeout = 30
    self.key_fields_only = False
    self.client_ip = client_ip
    self.client_str = client_str
    self.preloader = Preloader()
    self._preloaded_query_count = 0
def send_notification_target_reports(self) -> None:
    """
    Email change reports to the notification targets of saved or auth-failed results.

    Only results whose status is UpdateRequestStatus.SAVED or
    UpdateRequestStatus.ERROR_AUTH are reported. For each of their
    notification targets, notification_target_report() is collected per
    status, and the source of the result's new object is added to the
    subject line. Each recipient then gets one email containing, in order:
    - the configured email.notification_header (formatted with sources_str);
    - the request details from _request_meta_str();
    - the failed-authorisation section, if any;
    - the saved section, if any.
    """
    # First key is e-mail address of recipient, second is UpdateRequestStatus.SAVED
    # or UpdateRequestStatus.ERROR_AUTH
    reports_per_recipient: Dict[str, Dict[UpdateRequestStatus, OrderedSet]] = defaultdict(dict)
    sources: OrderedSet[str] = OrderedSet()

    for result in self.results:
        for target in result.notification_targets():
            # Results with any other status are not reported.
            if result.status in [
                UpdateRequestStatus.SAVED, UpdateRequestStatus.ERROR_AUTH
            ]:
                if result.status not in reports_per_recipient[target]:
                    reports_per_recipient[target][
                        result.status] = OrderedSet()
                reports_per_recipient[target][result.status].add(
                    result.notification_target_report())
                sources.add(result.rpsl_obj_new.source())

    sources_str = '/'.join(sources)
    subject = f'Notification of {sources_str} database changes'
    header = get_setting('email.notification_header', '').format(sources_str=sources_str)
    header += '\nThis message is auto-generated.\n'
    header += 'The request was made with the following details:\n'
    header_saved = textwrap.dedent(""" ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Some objects in which you are referenced have been created, deleted or changed. """)
    header_failed = textwrap.dedent(""" ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Some objects in which you are referenced were requested to be created, deleted or changed, but *failed* the proper authorisation for any of the referenced maintainers.
""")

    for recipient, reports_per_status in reports_per_recipient.items():
        user_report = header + self._request_meta_str()
        # Failed-authorisation reports are listed before saved changes.
        if UpdateRequestStatus.ERROR_AUTH in reports_per_status:
            user_report += header_failed
            for report in reports_per_status[
                    UpdateRequestStatus.ERROR_AUTH]:
                user_report += f'---\n{report}\n'
            user_report += '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n'
        if UpdateRequestStatus.SAVED in reports_per_status:
            user_report += header_saved
            for report in reports_per_status[UpdateRequestStatus.SAVED]:
                user_report += f'---\n{report}\n'
            user_report += '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n'
        email.send_email(recipient, subject, user_report)
def __init__(self, server_address, bind_and_activate=True):  # noqa: N803
    """
    Set up the listening server and start the whois worker processes.

    The socket address family follows the IP version of the listen
    address. One WhoisWorker is started for each of
    server.whois.max_connections. All workers share the same
    ``connection_queue``.
    """
    listen_is_ipv6 = IP(server_address[0]).version() == 6
    self.address_family = socket.AF_INET6 if listen_is_ipv6 else socket.AF_INET
    super().__init__(server_address, None, bind_and_activate)

    self.connection_queue = mp.Queue()
    self.workers = []
    for _ in range(int(get_setting('server.whois.max_connections'))):
        worker = WhoisWorker(self.connection_queue)
        worker.start()
        self.workers.append(worker)
def _import_roas(self):
    """
    Replace all ROA data and RPKI pseudo-IRR objects with a fresh import.

    Existing ROA objects, and all objects (with journal) of the RPKI
    pseudo-source, are deleted. The optional SLURM file is then retrieved,
    and the ROA file is parsed with ROADataImporter. A temporary ROA file
    is removed afterwards, also when reading or importing fails.

    :return: the ROA objects created by the importer
    """
    roa_source = get_setting('rpki.roa_source')
    slurm_source = get_setting('rpki.slurm_source')
    logger.info(f'Running full ROA import from: {roa_source}, SLURM {slurm_source}')

    self.database_handler.delete_all_roa_objects()
    self.database_handler.delete_all_rpsl_objects_with_journal(RPKI_IRR_PSEUDO_SOURCE)

    slurm_data = None
    if slurm_source:
        slurm_data, _ = self._retrieve_file(slurm_source, return_contents=True)

    roa_filename, roa_to_delete = self._retrieve_file(roa_source, return_contents=False)
    try:
        with open(roa_filename) as fh:
            roa_importer = ROADataImporter(fh.read(), slurm_data, self.database_handler)
    finally:
        # Remove the downloaded file even if the import raised, so failed
        # imports do not leave temporary files behind.
        if roa_to_delete:
            os.unlink(roa_filename)

    logger.info(f'ROA import from {roa_source}, SLURM {slurm_source}, imported {len(roa_importer.roa_objs)} ROAs, running validator')
    return roa_importer.roa_objs
def is_client_permitted(self, peer):
    """
    Return whether the whois client at ``peer`` may connect.

    ``peer.host`` is matched against the access list named by
    server.whois.access_list. An unset or empty access list permits every
    client. A client address that cannot be read is rejected.
    """
    try:
        client_ip = IP(peer.host)
    except (ValueError, AttributeError) as e:
        logger.error(
            f'Rejecting request as whois client IP could not be read from '
            f'{peer}: {e}')
        return False

    access_list_name = get_setting('server.whois.access_list')
    access_list = get_setting(f'access_lists.{access_list_name}')
    if not access_list:
        return True

    # A generator lets any() stop at the first matching entry instead of
    # parsing every access list entry for each connection.
    allowed = any(client_ip in IP(entry) for entry in access_list)
    if not allowed:
        logger.info(
            f'Rejecting whois request, IP not in access list: {client_ip}')
    return allowed