예제 #1
0
def update(source, filename) -> int:
    """
    Apply the contents of a file to the database for a single source.

    The command refuses to run when either import_source or
    import_serial_source is configured for the source. Otherwise the
    file is processed by MirrorUpdateFileImportParser; the transaction
    is committed when the parser reports no error, and rolled back
    when it does.

    Returns 0 on success, 1 when the parser reported an error (which
    is printed), or 2 when the source settings do not permit the command.
    """
    # Both settings are looked up before deciding, as a list (not a
    # short-circuiting `or`), so each lookup always happens.
    if any([
            get_setting(f'sources.{source}.import_source'),
            get_setting(f'sources.{source}.import_serial_source')
    ]):
        print(
            f'Error: to use this command, import_source and import_serial_source '
            f'for source {source} must not be set.')
        return 2

    dh = DatabaseHandler()
    try:
        roa_validator = BulkRouteROAValidator(dh)
        parser = MirrorUpdateFileImportParser(source,
                                              filename,
                                              database_handler=dh,
                                              direct_error_return=True,
                                              roa_validator=roa_validator)
        error = parser.run_import()
        if error:
            dh.rollback()
        else:
            dh.commit()
    finally:
        # Release the database handler even when the import raises.
        dh.close()

    if error:
        print(f'Error occurred while processing object:\n{error}')
        return 1
    return 0
예제 #2
0
def load(source, filename, serial) -> int:
    """
    Load a file into the database for a single source, replacing
    what is currently stored for it.

    The command refuses to run when either import_source or
    import_serial_source is configured for the source. Otherwise it
    calls delete_all_rpsl_objects_with_journal for the source, disables
    journaling, and imports the file with MirrorFileImportParser. The
    transaction is committed when the parser reports no error, and
    rolled back when it does.

    Returns 0 on success, 1 when the parser reported an error (which
    is printed), or 2 when the source settings do not permit the command.
    """
    if any([
        get_setting(f'sources.{source}.import_source'),
        get_setting(f'sources.{source}.import_serial_source')
    ]):
        print(f'Error: to use this command, import_source and import_serial_source '
              f'for source {source} must not be set.')
        return 2

    dh = DatabaseHandler()
    try:
        roa_validator = BulkRouteROAValidator(dh)
        dh.delete_all_rpsl_objects_with_journal(source)
        dh.disable_journaling()
        parser = MirrorFileImportParser(
            source=source, filename=filename, serial=serial, database_handler=dh,
            direct_error_return=True, roa_validator=roa_validator)
        error = parser.run_import()
        if error:
            dh.rollback()
        else:
            dh.commit()
    finally:
        # Release the database handler even when the import raises.
        dh.close()

    if error:
        print(f'Error occurred while processing object:\n{error}')
        return 1
    return 0
예제 #3
0
def set_last_modified():
    """
    Re-render the stored object text of all objects in authoritative sources.

    For every object in a source whose settings mark it as authoritative,
    the text is parsed in non-strict mode, rendered again with
    render_rpsl_text() using the stored `updated` value, and written back
    to the object_text column. Objects that fail to parse are reported
    and skipped. All updates are committed in a single transaction.
    """
    dh = DatabaseHandler()
    try:
        auth_sources = [
            k for k, v in get_setting('sources').items() if v.get('authoritative')
        ]
        q = RPSLDatabaseQuery(column_names=['pk', 'object_text', 'updated'],
                              enable_ordering=False)
        q = q.sources(auth_sources)

        # Materialise the results up front: the count is printed, and
        # update statements are executed on the same handler below.
        results = list(dh.execute_query(q))
        print(f'Updating {len(results)} objects in sources {auth_sources}')
        for result in results:
            rpsl_obj = rpsl_object_from_text(result['object_text'],
                                             strict_validation=False)
            if rpsl_obj.messages.errors():  # pragma: no cover
                print(
                    f'Failed to process {rpsl_obj}: {rpsl_obj.messages.errors()}')
                continue
            new_text = rpsl_obj.render_rpsl_text(result['updated'])
            stmt = RPSLDatabaseObject.__table__.update().where(
                RPSLDatabaseObject.__table__.c.pk == result['pk']).values(
                    object_text=new_text, )
            dh.execute_statement(stmt)
        dh.commit()
    finally:
        # Release the database handler even when processing raises.
        dh.close()
예제 #4
0
class SourceExportRunner:
    """
    This SourceExportRunner is the entry point for the export process
    for a single source.

    A gzipped file will be created in the export_destination directory
    with the contents of the source, along with a CURRENTSERIAL file.

    The contents of the source are first written to a temporary file, and
    then moved in place.
    """
    def __init__(self, source: str) -> None:
        self.source = source

    def run(self) -> None:
        """
        Run the export for this source.

        Exceptions are logged rather than propagated, and the database
        handler is always closed.
        """
        self.database_handler = DatabaseHandler()
        try:
            export_destination = get_setting(f'sources.{self.source}.export_destination')
            logger.info(f'Starting a source export for {self.source} to {export_destination}')
            self._export(export_destination)

            self.database_handler.commit()
        except Exception as exc:
            logger.error(f'An exception occurred while attempting to run an export '
                         f'for {self.source}: {exc}', exc_info=exc)
        finally:
            self.database_handler.close()

    def _export(self, export_destination):
        """
        Write the gzipped export and the CURRENTSERIAL file to export_destination.

        Logs an error and returns without writing anything when the
        database has no status record for the source.
        """
        filename_export = Path(export_destination) / f'{self.source.lower()}.db.gz'
        filename_serial = Path(export_destination) / f'{self.source.upper()}.CURRENTSERIAL'

        query = DatabaseStatusQuery().source(self.source)

        try:
            serial = next(self.database_handler.execute_query(query))['serial_newest_seen']
        except StopIteration:
            logger.error(f'Unable to run export for {self.source}, internal database status is empty.')
            return

        # The temporary file is only created once it is certain the export
        # will run, so the early return above leaves nothing behind.
        # delete=False keeps the file on disk after it is closed; the `with`
        # closes (and thereby flushes) it before it is moved, which
        # gzip does not do for a file object it was handed.
        with NamedTemporaryFile(delete=False) as export_tmpfile:
            with gzip.open(export_tmpfile, 'wb') as fh:
                query = RPSLDatabaseQuery().sources([self.source])
                for obj in self.database_handler.execute_query(query):
                    object_bytes = remove_auth_hashes(obj['object_text']).encode('utf-8')
                    fh.write(object_bytes + b'\n')

        if filename_export.exists():
            os.unlink(filename_export)
        if filename_serial.exists():
            os.unlink(filename_serial)
        shutil.move(export_tmpfile.name, filename_export)

        if serial is not None:
            with open(filename_serial, 'w') as fh:
                fh.write(str(serial))

        self.database_handler.record_serial_exported(self.source, serial)
        logger.info(f'Export for {self.source} complete, stored in {filename_export} / {filename_serial}')
예제 #5
0
class RPSLParse:
    """
    Parse RPSL objects from a file or stdin and print a report.

    Objects with messages are printed along with those messages, and
    counters are kept for parsed objects, objects with errors and
    objects of an unknown class. Optionally, objects without errors
    are upserted into the database.
    """
    obj_parsed = 0
    obj_errors = 0
    obj_unknown = 0
    unknown_object_classes: Set[str] = set()
    database_handler = None

    def __init__(self) -> None:
        # The class-level set is mutated in place by parse_object(), so
        # without a per-instance set, all instances would share (and
        # accumulate into) the same object.
        self.unknown_object_classes = set()

    def main(self, filename, strict_validation, database, show_info=True):
        """
        Parse all objects in `filename` ('-' reads from stdin) and print a summary.

        When `database` is truthy, objects without errors are stored through
        a DatabaseHandler with journaling disabled, which is committed once
        all input has been processed. When `show_info` is false, objects are
        only printed if they have errors.
        """
        self.show_info = show_info
        if database:
            self.database_handler = DatabaseHandler(journaling_enabled=False)

        try:
            if filename == '-':  # pragma: no cover
                self._parse_stream(sys.stdin, strict_validation)
            else:
                # The context manager makes sure the input file is closed,
                # also when parsing raises.
                with open(filename, encoding='utf-8', errors='backslashreplace') as f:
                    self._parse_stream(f, strict_validation)

            print(
                f'Processed {self.obj_parsed} objects, {self.obj_errors} with errors'
            )
            if self.obj_unknown:
                unknown_formatted = ', '.join(self.unknown_object_classes)
                print(
                    f'Ignored {self.obj_unknown} objects due to unknown object classes: {unknown_formatted}'
                )

            if self.database_handler:
                self.database_handler.commit()
        finally:
            # Release the database handler even when parsing raises.
            if self.database_handler:
                self.database_handler.close()

    def _parse_stream(self, stream, strict_validation):
        """Parse every RPSL paragraph found in an open text stream."""
        for paragraph in split_paragraphs_rpsl(stream):
            self.parse_object(paragraph, strict_validation)

    def parse_object(self, rpsl_text, strict_validation):
        """
        Parse a single object, print it when it has messages worth showing,
        and upsert it into the database when enabled and free of errors.

        Objects of an unknown class are counted and otherwise ignored. Any
        other exception is re-raised after printing the offending text.
        """
        try:
            self.obj_parsed += 1
            obj = rpsl_object_from_text(rpsl_text.strip(),
                                        strict_validation=strict_validation)
            if (obj.messages.messages()
                    and self.show_info) or obj.messages.errors():
                if obj.messages.errors():
                    self.obj_errors += 1

                print(rpsl_text.strip())
                print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
                print(obj.messages)
                print('\n=======================================\n')

            if self.database_handler and obj and not obj.messages.errors():
                self.database_handler.upsert_rpsl_object(obj)

        except UnknownRPSLObjectClassException as e:
            self.obj_unknown += 1
            # The class name is taken from the part of the exception
            # message after the first colon.
            self.unknown_object_classes.add(str(e).split(':')[1].strip())
        except Exception:  # pragma: no cover
            print('=======================================')
            print(rpsl_text)
            print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
            raise
예제 #6
0
class RPSLMirrorImportUpdateRunner:
    """
    Entry point for bringing the mirror of a single source up to date.

    Depending on the stored status, this either delegates to
    RPSLMirrorFullImportRunner for an import from full export files,
    or to NRTMImportUpdateStreamRunner to retrieve new updates from NRTM.
    """
    def __init__(self, source: str) -> None:
        self.source = source
        self.full_import_runner = RPSLMirrorFullImportRunner(source)
        self.update_stream_runner = NRTMImportUpdateStreamRunner(source)

    def run(self) -> None:
        """
        Run a full import or an NRTM update for this source and commit.

        Exceptions are logged rather than propagated, and the database
        handler is always closed.
        """
        self.database_handler = DatabaseHandler()

        try:
            serial_newest_mirror, force_reload = self._status()
            nrtm_host = get_setting(f'sources.{self.source}.nrtm_host')
            nrtm_enabled = bool(nrtm_host)
            logger.debug(
                f'Most recent mirrored serial for {self.source}: {serial_newest_mirror}, '
                f'force_reload: {force_reload}, nrtm enabled: {nrtm_enabled}')

            # NRTM updates are only possible without a forced reload, with
            # a known mirrored serial, and with NRTM configured. Anything
            # else results in a full import.
            if not force_reload and serial_newest_mirror and nrtm_enabled:
                self.update_stream_runner.run(
                    serial_newest_mirror,
                    database_handler=self.database_handler)
            else:
                self.full_import_runner.run(
                    database_handler=self.database_handler,
                    serial_newest_mirror=serial_newest_mirror,
                    force_reload=force_reload)

            self.database_handler.commit()
        except OSError as ose:
            # I/O errors can occur and should not log a full traceback (#177)
            message = (
                f'An error occurred while attempting a mirror update or initial import '
                f'for {self.source}: {ose}')
            logger.error(message)
        except Exception as exc:
            message = (
                f'An exception occurred while attempting a mirror update or initial import '
                f'for {self.source}: {exc}')
            logger.error(message, exc_info=exc)
        finally:
            self.database_handler.close()

    def _status(self) -> Tuple[Optional[int], Optional[bool]]:
        """
        Return (serial_newest_mirror, force_reload) from the database status
        of this source, or (None, None) when there is no status record.
        """
        status_query = DatabaseStatusQuery().source(self.source)
        rows = self.database_handler.execute_query(status_query)
        try:
            row = next(rows)
        except StopIteration:
            return None, None
        return row['serial_newest_mirror'], row['force_reload']
예제 #7
0
    def generate_status(self) -> str:
        """
        Generate a human-readable overview of database status.

        The results of the statistics and status queries are stored on
        the instance (presumably for use by the _generate_* helpers), and
        the header, statistics table and source detail sections are
        joined with blank lines in between.
        """
        database_handler = DatabaseHandler()
        try:
            statistics_query = RPSLDatabaseObjectStatisticsQuery()
            self.statistics_results = list(database_handler.execute_query(statistics_query))
            status_query = DatabaseStatusQuery()
            self.status_results = list(database_handler.execute_query(status_query))

            results = [self._generate_header(), self._generate_statistics_table(), self._generate_source_detail()]
        finally:
            # Release the database handler even when a query or one of
            # the section generators raises.
            database_handler.close()
        return '\n\n'.join(results)
예제 #8
0
def load(source, filename, serial) -> int:
    """
    Load a file into the database for a single source, replacing
    what is currently stored for it.

    Calls delete_all_rpsl_objects_with_journal for the source, disables
    journaling, and imports the file with MirrorFileImportParser. The
    transaction is committed when the parser reports no error, and
    rolled back when it does.

    Returns 0 on success, or 1 when the parser reported an error
    (which is printed).
    """
    dh = DatabaseHandler()
    try:
        dh.delete_all_rpsl_objects_with_journal(source)
        dh.disable_journaling()
        parser = MirrorFileImportParser(source, filename, serial=serial, database_handler=dh, direct_error_return=True)
        error = parser.run_import()
        if error:
            dh.rollback()
        else:
            dh.commit()
    finally:
        # Release the database handler even when the import raises.
        dh.close()

    if error:
        print(f'Error occurred while processing object:\n{error}')
        return 1
    return 0
예제 #9
0
def load_pgp_keys(source: str) -> None:
    """
    Parse every key-cert object of a source in strict mode.

    Each key-cert's primary key is printed before parsing, and any
    parse errors are printed afterwards. Errors in one key-cert do not
    stop processing of the others.
    """
    dh = DatabaseHandler()
    try:
        query = RPSLDatabaseQuery(column_names=['rpsl_pk', 'object_text'])
        query = query.sources([source]).object_classes(['key-cert'])
        keycerts = dh.execute_query(query)

        for keycert in keycerts:
            rpsl_pk = keycert["rpsl_pk"]
            print(f'Loading key-cert {rpsl_pk}')
            # Parsing the keycert in strict mode will load it into the GPG keychain
            result = rpsl_object_from_text(keycert['object_text'], strict_validation=True)
            if result.messages.errors():
                print(f'Errors in PGP key {rpsl_pk}: {result.messages.errors()}')

        print('All valid key-certs loaded into the GnuPG keychain.')
    finally:
        # Release the database handler even when the query or parsing raises.
        dh.close()
예제 #10
0
class MirrorUpdateRunner:
    """
    Entry point for bringing the mirror of a single source up to date.

    Depending on the stored status, this either delegates to
    MirrorFullImportRunner for an import from full export files,
    or to NRTMUpdateStreamRunner to retrieve new updates from NRTM.
    """
    def __init__(self, source: str) -> None:
        self.source = source
        self.full_import_runner = MirrorFullImportRunner(source)
        self.update_stream_runner = NRTMUpdateStreamRunner(source)

    def run(self) -> None:
        """
        Run a full import or an NRTM update for this source and commit.

        Exceptions are logged rather than propagated, and the database
        handler is always closed.
        """
        self.database_handler = DatabaseHandler()

        try:
            serial_newest_seen, force_reload = self._status()
            logger.debug(
                f'Most recent serial seen for {self.source}: {serial_newest_seen}, force_reload: {force_reload}'
            )

            # NRTM updates require a known serial and no forced reload;
            # anything else results in a full import.
            if serial_newest_seen and not force_reload:
                self.update_stream_runner.run(
                    serial_newest_seen, database_handler=self.database_handler)
            else:
                self.full_import_runner.run(
                    database_handler=self.database_handler)

            self.database_handler.commit()
        except Exception as exc:
            message = (
                f'An exception occurred while attempting a mirror update or initial import '
                f'for {self.source}: {exc}')
            logger.critical(message, exc_info=exc)
        finally:
            self.database_handler.close()

    def _status(self) -> Tuple[Optional[int], Optional[bool]]:
        """
        Return (serial_newest_seen, force_reload) from the database status
        of this source, or (None, None) when there is no status record.
        """
        status_query = DatabaseStatusQuery().source(self.source)
        rows = self.database_handler.execute_query(status_query)
        try:
            row = next(rows)
        except StopIteration:
            return None, None
        return row['serial_newest_seen'], row['force_reload']
예제 #11
0
class ScopeFilterUpdateRunner:
    """
    Runner that refreshes the scope filter status of all objects.

    Nothing is imported by this runner: the scope filter itself
    comes from the configuration.
    """

    def __init__(self, source=None):
        # The source parameter only exists for API consistency with the
        # other importers; it is ignored.
        pass

    def run(self):
        """
        Validate all objects against the scope filter, store the changed
        statuses, and commit.

        Exceptions are logged rather than propagated, and the database
        handler is always closed.
        """
        self.database_handler = DatabaseHandler()

        try:
            now_in_scope, now_out_scope_as, now_out_scope_prefix = \
                ScopeFilterValidator().validate_all_rpsl_objects(self.database_handler)
            self.database_handler.update_scopefilter_status(
                rpsl_objs_now_in_scope=now_in_scope,
                rpsl_objs_now_out_scope_as=now_out_scope_as,
                rpsl_objs_now_out_scope_prefix=now_out_scope_prefix,
            )
            self.database_handler.commit()

            summary = (
                f'Scopefilter status updated for all routes, '
                f'{len(now_in_scope)} newly in scope, '
                f'{len(now_out_scope_as)} newly out of scope AS, '
                f'{len(now_out_scope_prefix)} newly out of scope prefix'
            )
            logger.info(summary)

        except Exception as exc:
            failure = f'An exception occurred while attempting a scopefilter status update: {exc}'
            logger.error(failure, exc_info=exc)
        finally:
            self.database_handler.close()
예제 #12
0
class ROAImportRunner(FileImportRunnerBase):
    """
    This runner performs a full import of ROA objects.
    The URL file for the ROA export in JSON format is provided
    in the configuration.
    """
    # API consistency with other importers, source is actually ignored
    def __init__(self, source=None):
        pass

    def run(self):
        """
        Import all ROAs, then update the RPKI status of all routes and
        notify the contacts of newly invalid objects.

        Exceptions are logged rather than propagated, and the database
        handler is always closed.
        """
        self.database_handler = DatabaseHandler()

        try:
            self.database_handler.disable_journaling()
            roa_objs = self._import_roas()
            # Do an early commit to make the new ROAs available to other processes.
            self.database_handler.commit()
            # The ROA import does not use journaling, but updating the RPKI
            # status may create journal entries.
            self.database_handler.enable_journaling()

            validator = BulkRouteROAValidator(self.database_handler, roa_objs)
            objs_now_valid, objs_now_invalid, objs_now_not_found = validator.validate_all_routes()
            self.database_handler.update_rpki_status(
                rpsl_objs_now_valid=objs_now_valid,
                rpsl_objs_now_invalid=objs_now_invalid,
                rpsl_objs_now_not_found=objs_now_not_found,
            )
            self.database_handler.commit()
            notified = notify_rpki_invalid_owners(self.database_handler, objs_now_invalid)
            logger.info(f'RPKI status updated for all routes, {len(objs_now_valid)} newly valid, '
                        f'{len(objs_now_invalid)} newly invalid, '
                        f'{len(objs_now_not_found)} newly not_found routes, '
                        f'{notified} emails sent to contacts of newly invalid authoritative objects')

        except OSError as ose:
            # I/O errors can occur and should not log a full traceback (#177)
            logger.error(f'An error occurred while attempting a ROA import: {ose}')
        except ROAParserException as rpe:
            logger.error(f'An exception occurred while attempting a ROA import: {rpe}')
        except Exception as exc:
            logger.error(f'An exception occurred while attempting a ROA import: {exc}', exc_info=exc)
        finally:
            self.database_handler.close()

    def _import_roas(self):
        """
        Replace all stored ROAs with those from the configured ROA source,
        optionally combined with SLURM data, and return the imported ROA objects.
        """
        roa_source = get_setting('rpki.roa_source')
        slurm_source = get_setting('rpki.slurm_source')
        logger.info(f'Running full ROA import from: {roa_source}, SLURM {slurm_source}')

        self.database_handler.delete_all_roa_objects()
        self.database_handler.delete_all_rpsl_objects_with_journal(RPKI_IRR_PSEUDO_SOURCE)

        slurm_data = None
        if slurm_source:
            slurm_data, _ = self._retrieve_file(slurm_source, return_contents=True)

        # NOTE(review): _retrieve_file appears to return a filename plus a
        # flag telling whether that file should be removed afterwards -
        # confirm against FileImportRunnerBase.
        roa_filename, roa_to_delete = self._retrieve_file(roa_source, return_contents=False)
        try:
            with open(roa_filename) as fh:
                roa_importer = ROADataImporter(fh.read(), slurm_data, self.database_handler)
        finally:
            # Also clean up the retrieved file when reading or importing
            # fails, so failed runs do not leave files behind.
            if roa_to_delete:
                os.unlink(roa_filename)
        logger.info(f'ROA import from {roa_source}, SLURM {slurm_source}, imported {len(roa_importer.roa_objs)} ROAs, running validator')
        return roa_importer.roa_objs
예제 #13
0
class ChangeSubmissionHandler:
    """
    The ChangeSubmissionHandler handles the text of one or more requested RPSL changes
    (create, modify or delete), parses, validates and eventually saves
    them. This includes validating references between objects, including
    those part of the same message, and checking authentication.

    One of the load_* methods must be called first; each returns the
    handler itself, after which the status and report methods can be used.
    """
    def load_text_blob(self,
                       object_texts_blob: str,
                       pgp_fingerprint: Optional[str] = None,
                       request_meta: Optional[Dict[str, Optional[str]]] = None):
        """
        Process change requests from a blob of text with one or more objects.

        If a PGP fingerprint is given, it is resolved to a key ID that is
        passed to the auth validator. Changes are committed on success;
        the database handler is closed in all cases.
        """
        self.database_handler = DatabaseHandler()
        try:
            self.request_meta = request_meta if request_meta else {}
            self._pgp_key_id = self._resolve_pgp_key_id(
                pgp_fingerprint) if pgp_fingerprint else None

            reference_validator = ReferenceValidator(self.database_handler)
            auth_validator = AuthValidator(self.database_handler, self._pgp_key_id)
            change_requests = parse_change_requests(object_texts_blob,
                                                    self.database_handler,
                                                    auth_validator,
                                                    reference_validator)

            self._handle_change_requests(change_requests, reference_validator,
                                         auth_validator)
            self.database_handler.commit()
        finally:
            # Release the database handler even when processing raises.
            self.database_handler.close()
        return self

    def load_change_submission(self,
                               data: RPSLChangeSubmission,
                               delete=False,
                               request_meta: Optional[Dict[str, Optional[str]]] = None):
        """
        Process change requests from a structured submission.

        Passwords and the override from the submission are set on the
        auth validator. When `delete` is set, the submission's delete
        reason is passed to every change request. Changes are committed
        on success; the database handler is closed in all cases.
        """
        self.database_handler = DatabaseHandler()
        try:
            self.request_meta = request_meta if request_meta else {}

            reference_validator = ReferenceValidator(self.database_handler)
            auth_validator = AuthValidator(self.database_handler)
            change_requests: List[Union[ChangeRequest, SuspensionRequest]] = []

            delete_reason = None
            if delete:
                delete_reason = data.delete_reason

            auth_validator.passwords = data.passwords
            auth_validator.overrides = [data.override] if data.override else []

            for rpsl_obj in data.objects:
                object_text = rpsl_obj.object_text
                if rpsl_obj.attributes:
                    # We don't have a neat way to process individual attribute pairs,
                    # so construct a pseudo-object by appending the text.
                    composite_object = []
                    for attribute in rpsl_obj.attributes:
                        composite_object.append(attribute.name + ': ' +
                                                attribute.value)  # type: ignore
                    object_text = '\n'.join(composite_object) + '\n'

                assert object_text  # enforced by pydantic
                change_requests.append(
                    ChangeRequest(object_text, self.database_handler,
                                  auth_validator, reference_validator,
                                  delete_reason))

            self._handle_change_requests(change_requests, reference_validator,
                                         auth_validator)
            self.database_handler.commit()
        finally:
            # Release the database handler even when processing raises.
            self.database_handler.close()
        return self

    def load_suspension_submission(self,
                                   data: RPSLSuspensionSubmission,
                                   request_meta: Optional[Dict[str, Optional[str]]] = None):
        """
        Process suspension requests from a structured submission.

        The override from the submission is set on the auth validator.
        Changes are committed on success; the database handler is closed
        in all cases.
        """
        self.database_handler = DatabaseHandler()
        try:
            self.request_meta = request_meta if request_meta else {}

            reference_validator = ReferenceValidator(self.database_handler)
            auth_validator = AuthValidator(self.database_handler)
            change_requests: List[Union[ChangeRequest, SuspensionRequest]] = []

            auth_validator.overrides = [data.override] if data.override else []

            for rpsl_obj in data.objects:
                # We don't have a neat way to process individual attribute pairs,
                # so construct a pseudo-object by appending the text.
                object_text = f"mntner: {rpsl_obj.mntner}\nsource: {rpsl_obj.source}\n"
                change_requests.append(
                    SuspensionRequest(
                        object_text,
                        self.database_handler,
                        auth_validator,
                        rpsl_obj.request_type.value,
                    ))

            self._handle_change_requests(change_requests, reference_validator,
                                         auth_validator)
            self.database_handler.commit()
        finally:
            # Release the database handler even when processing raises.
            self.database_handler.close()
        return self

    def _handle_change_requests(self,
                                change_requests: List[Union[ChangeRequest, SuspensionRequest]],
                                reference_validator: ReferenceValidator,
                                auth_validator: AuthValidator) -> None:
        """
        Repeatedly validate the requests until the set of valid ones is
        stable, save the valid ones, and store all requests in self.results.

        Raises ValueError when the validity does not stabilise within
        len(change_requests) + 10 rounds.
        """
        objects = ', '.join([
            f'{request.rpsl_obj_new} (request {id(request)})'
            for request in change_requests
        ])
        logger.info(
            f'Processing change requests for {objects}, metadata is {self.request_meta}'
        )
        # When an object references another object, e.g. tech-c referring a person or mntner,
        # an add/update is only valid if those referred objects exist. To complicate matters,
        # the object referred to may be part of this very same submission. For this reason, the
        # reference validator can be provided with all new objects to be added in this submission.
        # However, a possible scenario is that A, B and C are submitted. Object A refers to B,
        # B refers to C, C refers to D and D does not exist - or C fails authentication.
        # At a first scan, A is valid because B exists, B is valid because C exists. C
        # becomes invalid on the first scan, which is why another scan is performed, which
        # will mark B invalid due to the reference to an invalid C, etc. This continues until
        # all references are resolved and repeated scans lead to the same conclusions.
        valid_changes = [r for r in change_requests if r.is_valid()]
        previous_valid_changes: List[Union[ChangeRequest,
                                           SuspensionRequest]] = []
        loop_count = 0
        loop_max = len(change_requests) + 10

        while valid_changes != previous_valid_changes:
            previous_valid_changes = valid_changes
            reference_validator.preload(valid_changes)
            valid_potential_new_mntners = [
                r.rpsl_obj_new for r in valid_changes
                if r.request_type == UpdateRequestType.CREATE
                and isinstance(r.rpsl_obj_new, RPSLMntner)
            ]
            auth_validator.pre_approve(valid_potential_new_mntners)

            for result in valid_changes:
                result.validate()
            valid_changes = [r for r in change_requests if r.is_valid()]

            loop_count += 1
            if loop_count > loop_max:  # pragma: no cover
                msg = f'Update validity resolver ran an excessive amount of loops, may be stuck, aborting ' \
                      f'processing. Message metadata: {self.request_meta}'
                logger.error(msg)
                raise ValueError(msg)

        for result in change_requests:
            if result.is_valid():
                result.save()

        self.results = change_requests

    def _resolve_pgp_key_id(self, pgp_fingerprint: str) -> Optional[str]:
        """
        Find a PGP key ID for a given fingerprint.
        This method looks for an actual matching object in the database,
        and then returns the object's PK.
        Returns None when no key-cert with a matching fingerprint is found.
        """
        clean_fingerprint = pgp_fingerprint.replace(' ', '')
        # The key ID is derived from the last 8 characters of the fingerprint.
        key_id = 'PGPKEY-' + clean_fingerprint[-8:]
        query = RPSLDatabaseQuery().object_classes(['key-cert'
                                                    ]).rpsl_pk(key_id)
        results = list(self.database_handler.execute_query(query))

        for result in results:
            if result['parsed_data'].get('fingerpr',
                                         '').replace(' ',
                                                     '') == clean_fingerprint:
                return key_id
        logger.info(
            f'Message was signed with key {key_id}, but key was not found in the database. Treating message '
            f'as unsigned. Message metadata: {self.request_meta}')
        return None

    def status(self) -> str:
        """Provide a simple SUCCESS/FAILED string based - former used if all objects were saved."""
        if all(result.status == UpdateRequestStatus.SAVED
               for result in self.results):
            return 'SUCCESS'
        return 'FAILED'

    def _summary_counts(self) -> Dict[str, int]:
        """
        Count the results by outcome and request type.

        A result counts as successful when its status is SAVED, and as
        failed otherwise. The returned dict is shared by the human and
        JSON submitter reports.
        """
        successful = [
            r for r in self.results if r.status == UpdateRequestStatus.SAVED
        ]
        failed = [
            r for r in self.results if r.status != UpdateRequestStatus.SAVED
        ]

        def count(results, request_type) -> int:
            return sum(1 for r in results if r.request_type == request_type)

        return {
            'objects_found': len(self.results),
            'successful': len(successful),
            'successful_create': count(successful, UpdateRequestType.CREATE),
            'successful_modify': count(successful, UpdateRequestType.MODIFY),
            'successful_delete': count(successful, UpdateRequestType.DELETE),
            'failed': len(failed),
            'failed_create': count(failed, UpdateRequestType.CREATE),
            'failed_modify': count(failed, UpdateRequestType.MODIFY),
            'failed_delete': count(failed, UpdateRequestType.DELETE),
        }

    def submitter_report_human(self) -> str:
        """Produce a human-readable report for the submitter."""
        # flake8: noqa: W293
        counts = self._summary_counts()

        user_report = self._request_meta_str() + textwrap.dedent(f"""
        SUMMARY OF UPDATE:

        Number of objects found:                  {counts['objects_found']:3}
        Number of objects processed successfully: {counts['successful']:3}
            Create:      {counts['successful_create']:3}
            Modify:      {counts['successful_modify']:3}
            Delete:      {counts['successful_delete']:3}
        Number of objects processed with errors:  {counts['failed']:3}
            Create:      {counts['failed_create']:3}
            Modify:      {counts['failed_modify']:3}
            Delete:      {counts['failed_delete']:3}
            
        DETAILED EXPLANATION:
        
        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        """)
        for result in self.results:
            user_report += '---\n'
            user_report += result.submitter_report_human()
            user_report += '\n'
        user_report += '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n'
        return user_report

    def submitter_report_json(self):
        """Produce a JSON-ready report for the submitter."""
        return {
            'request_meta': self.request_meta,
            'summary': self._summary_counts(),
            'objects':
            [result.submitter_report_json() for result in self.results],
        }

    def send_notification_target_reports(self):
        """
        Send an e-mail to every notification target of the processed
        results, covering results that were saved or failed authentication.
        """
        # First key is e-mail address of recipient, second is UpdateRequestStatus.SAVED
        # or UpdateRequestStatus.ERROR_AUTH
        reports_per_recipient: Dict[str, Dict[UpdateRequestStatus,
                                              OrderedSet]] = defaultdict(dict)
        sources: OrderedSet[str] = OrderedSet()

        for result in self.results:
            for target in result.notification_targets():
                if result.status in [
                        UpdateRequestStatus.SAVED,
                        UpdateRequestStatus.ERROR_AUTH
                ]:
                    if result.status not in reports_per_recipient[target]:
                        reports_per_recipient[target][
                            result.status] = OrderedSet()
                    reports_per_recipient[target][result.status].add(
                        result.notification_target_report())
                    sources.add(result.rpsl_obj_new.source())

        sources_str = '/'.join(sources)
        subject = f'Notification of {sources_str} database changes'
        header = get_setting('email.notification_header',
                             '').format(sources_str=sources_str)
        header += '\nThis message is auto-generated.\n'
        header += 'The request was made with the following details:\n'
        header_saved = textwrap.dedent("""
            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            Some objects in which you are referenced have been created,
            deleted or changed.
            
        """)

        header_failed = textwrap.dedent("""
            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            Some objects in which you are referenced were requested
            to be created, deleted or changed, but *failed* the 
            proper authorisation for any of the referenced maintainers.
            
        """)

        for recipient, reports_per_status in reports_per_recipient.items():
            user_report = header + self._request_meta_str()
            # Authentication failures are listed before saved changes.
            if UpdateRequestStatus.ERROR_AUTH in reports_per_status:
                user_report += header_failed
                for report in reports_per_status[
                        UpdateRequestStatus.ERROR_AUTH]:
                    user_report += f'---\n{report}\n'
                user_report += '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n'
            if UpdateRequestStatus.SAVED in reports_per_status:
                user_report += header_saved
                for report in reports_per_status[UpdateRequestStatus.SAVED]:
                    user_report += f'---\n{report}\n'
                user_report += '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n'

            email.send_email(recipient, subject, user_report)

    def _request_meta_str(self):
        """
        Format the request metadata as quoted '> key: value' lines,
        skipping entries without a value. Returns an empty string when
        there is nothing to show.
        """
        request_meta_str = '\n'.join(
            [f'> {k}: {v}' for k, v in self.request_meta.items() if v])
        if request_meta_str:
            request_meta_str = '\n' + request_meta_str + '\n\n'
        return request_meta_str
예제 #14
0
class WhoisQueryParser:
    """
    Parser for all whois-style queries.

    This parser distinguishes RIPE-style, e.g. '-K 192.0.2.1' or '-i mnt-by FOO'
    from IRRD-style, e.g. '!oFOO'.

    Some query flags, particularly -k/!! and -s/!s retain state across queries,
    so a single instance of this object should be created per session, with
    handle_query() being called for each individual query.
    """
    lookup_field_names = lookup_field_names()
    database_handler: DatabaseHandler
    _current_set_root_object_class: Optional[str]

    def __init__(self, client_ip: str, client_str: str) -> None:
        """
        Set up the state for one client session.

        client_ip is used for the access list check on NRTM requests,
        client_str identifies the client in log messages.
        """
        self.all_valid_sources = list(get_setting('sources', {}).keys())
        self.sources_default = get_setting('sources_default')
        # Without a configured default, self.sources is the very same list
        # object as self.all_valid_sources, so the RPKI pseudo source that
        # may be appended below is visible through both attributes.
        self.sources: List[str] = self.sources_default or self.all_valid_sources
        if get_setting('rpki.roa_source'):
            self.all_valid_sources.append(RPKI_IRR_PSEUDO_SOURCE)
        self.rpki_aware = bool(get_setting('rpki.roa_source'))
        self.rpki_invalid_filter_enabled = self.rpki_aware

        # State that is kept for the whole session.
        self.client_ip, self.client_str = client_ip, client_str
        self.user_agent: Optional[str] = None
        self.multiple_command_mode = False
        self.timeout = 30

        # Flags that only apply to the query being processed.
        self.object_classes: List[str] = []
        self.key_fields_only = False

        self.preloader = Preloader()
        self._preloaded_query_count = 0

    def handle_query(self, query: str) -> WhoisQueryResponse:
        """
        Process a single query and return a WhoisQueryResponse.

        A query starting with an exclamation mark is handled as an IRRD-style
        command, anything else as a RIPE-style query. Parsing errors and
        unexpected exceptions raised by the handlers are both turned into
        ERROR responses in the matching response mode.
        Not thread safe - only one call must be made to this method at the same time.
        """
        self.database_handler = DatabaseHandler()
        # These flags only ever apply to a single query.
        self.key_fields_only = False
        self.object_classes = []

        if query.startswith('!'):
            mode = WhoisQueryResponseMode.IRRD
            handler, handler_input = self.handle_irrd_command, query[1:]
        else:
            mode = WhoisQueryResponseMode.RIPE
            handler, handler_input = self.handle_ripe_command, query

        try:
            return handler(handler_input)
        except WhoisQueryParserException as exc:
            logger.info(
                f'{self.client_str}: encountered parsing error while parsing query "{query}": {exc}'
            )
            return WhoisQueryResponse(
                response_type=WhoisQueryResponseType.ERROR,
                mode=mode,
                result=str(exc),
            )
        except Exception as exc:
            logger.error(
                f'An exception occurred while processing whois query "{query}": {exc}',
                exc_info=exc)
            return WhoisQueryResponse(
                response_type=WhoisQueryResponseType.ERROR,
                mode=mode,
                result='An internal error occurred while processing this query.',
            )
        finally:
            # Runs after the return value has been determined, on every path.
            self.database_handler.close()

    def handle_irrd_command(self, full_command: str) -> WhoisQueryResponse:
        """
        Dispatch an IRRD-style query to its handler.

        full_command is the query without the first exclamation mark. Its
        first character, uppercased, selects the command and the remainder is
        the parameter. Raises WhoisQueryParserException for an empty or
        unrecognised command, or when a required parameter is missing.
        """
        if not full_command:
            raise WhoisQueryParserException('Missing IRRD command')
        command, parameter = full_command[0].upper(), full_command[1:]

        # A is not listed here because handle_irrd_routes_for_as_set already
        # checks for a missing set name itself.
        if command in frozenset('TG6IJMNORS') and not parameter:
            raise WhoisQueryParserException(
                f'Missing parameter for {command} query')

        # Lookup commands, for which an empty result means KEY_NOT_FOUND.
        lookups = {
            'G': self.handle_irrd_routes_for_origin_v4,
            '6': self.handle_irrd_routes_for_origin_v6,
            'A': self.handle_irrd_routes_for_as_set,
            'I': self.handle_irrd_set_members,
            'M': self.handle_irrd_exact_key,
            'O': lambda value: self.handle_inverse_attr_search('mnt-by', value),
            'R': self.handle_irrd_route_search,
        }

        response_type = WhoisQueryResponseType.SUCCESS
        result = None

        if command == '!':
            self.multiple_command_mode = True
            response_type = WhoisQueryResponseType.NO_RESPONSE
        elif full_command.upper() == 'FNO-RPKI-FILTER':
            self.rpki_invalid_filter_enabled = False
            result = 'Filtering out RPKI invalids is disabled for !r and RIPE style ' \
                     'queries for the rest of this connection.'
        elif command in lookups:
            result = lookups[command](parameter)
            if not result:
                response_type = WhoisQueryResponseType.KEY_NOT_FOUND
        elif command == 'V':
            result = self.handle_irrd_version()
        elif command == 'T':
            self.handle_irrd_timeout_update(parameter)
        elif command == 'J':
            result = self.handle_irrd_database_serial_range(parameter)
        elif command == 'N':
            self.handle_user_agent(parameter)
        elif command == 'S':
            result = self.handle_irrd_sources_list(parameter)
        else:
            raise WhoisQueryParserException(f'Unrecognised command: {command}')

        return WhoisQueryResponse(
            response_type=response_type,
            mode=WhoisQueryResponseMode.IRRD,
            result=result,
        )

    def handle_irrd_timeout_update(self, timeout: str) -> None:
        """!timeout query - update timeout in connection"""
        try:
            timeout_value = int(timeout)
        except ValueError:
            raise WhoisQueryParserException(
                f'Invalid value for timeout: {timeout}')

        if timeout_value > 0 and timeout_value <= 1000:
            self.timeout = timeout_value
        else:
            raise WhoisQueryParserException(
                f'Invalid value for timeout: {timeout}')

    def handle_irrd_routes_for_origin_v4(self, origin: str) -> str:
        """!g query - find all originating IPv4 prefixes from an origin, e.g. !gAS65537"""
        return self._routes_for_origin(origin, 4)

    def handle_irrd_routes_for_origin_v6(self, origin: str) -> str:
        """!6 query - find all originating IPv6 prefixes from an origin, e.g. !6as65537"""
        return self._routes_for_origin(origin, 6)

    def _routes_for_origin(self,
                           origin: str,
                           ip_version: Optional[int] = None) -> str:
        """
        Look up the prefixes originated by a single AS.

        The origin is validated with parse_as_number(); a validation failure
        is re-raised as WhoisQueryParserException. The prefixes are taken
        from the preloader for the currently selected sources and the given
        IP version, and returned as one space-separated string.
        """
        try:
            origin_formatted, _ = parse_as_number(origin)
        except ValidationError as validation_error:
            raise WhoisQueryParserException(str(validation_error))

        self._preloaded_query_called()
        return ' '.join(
            self.preloader.routes_for_origins(
                [origin_formatted], self.sources, ip_version=ip_version))

    def handle_irrd_routes_for_as_set(self, set_name: str) -> str:
        """
        !a query - find all originating prefixes for all members of an AS-set, e.g. !a4AS-FOO or !a6AS-FOO
        """
        ip_version: Optional[int] = None
        if set_name.startswith('4'):
            set_name = set_name[1:]
            ip_version = 4
        elif set_name.startswith('6'):
            set_name = set_name[1:]
            ip_version = 6

        if not set_name:
            raise WhoisQueryParserException(
                f'Missing required set name for A query')

        self._preloaded_query_called()
        self._current_set_root_object_class = 'as-set'
        members = self._recursive_set_resolve({set_name})
        prefixes = self.preloader.routes_for_origins(members,
                                                     self.sources,
                                                     ip_version=ip_version)
        return ' '.join(prefixes)

    def handle_irrd_set_members(self, parameter: str) -> str:
        """
        !i query - list the members of an as-set or route-set.

        !iAS-FOO returns the direct members only, !iAS-FOO,1 resolves the set
        recursively. The queried set itself is never part of the output.
        With the compatibility.ipv4_only_route_set_members setting enabled,
        members that are valid prefixes but not valid IPv4 prefixes are
        dropped. The result is sorted and space-separated.
        """
        self._preloaded_query_called()
        recursive = parameter.endswith(',1')
        if recursive:
            parameter = parameter[:-2]

        # Let the resolver determine the root object class from the data.
        self._current_set_root_object_class = None
        if recursive:
            members = self._recursive_set_resolve({parameter})
        else:
            members, leaf_members = self._find_set_members({parameter})
            members.update(leaf_members)
        members.discard(parameter)

        if get_setting('compatibility.ipv4_only_route_set_members'):
            # Iterate over a copy, as members is modified along the way.
            for candidate in list(members):
                try:
                    IP(candidate)
                except ValueError:
                    continue  # Not a prefix at all, so it is kept.
                try:
                    IP(candidate, ipversion=4)
                except ValueError:
                    # A valid prefix, but not a valid IPv4 prefix.
                    members.remove(candidate)

        return ' '.join(sorted(members))

    def _recursive_set_resolve(self,
                               members: Set[str],
                               sets_seen=None) -> Set[str]:
        """
        Resolve all members of a number of sets, recursively.

        For each set in members, determines whether it has been seen already (to prevent
        infinite recursion), ignores it if already seen, and then either adds
        it directly or adds it to a set that requires further resolving.

        members holds the names to resolve in this round; sets_seen collects
        every name that was passed in as members in this or an earlier round
        and is shared between the recursive calls. Returns the set of
        resolved members: prefixes and AS numbers, or, when resolving a
        route-set, the prefixes originating from the AS numbers encountered.
        """
        # A missing sets_seen is replaced by a fresh set; an empty one
        # passed in is replaced as well, which makes no difference.
        if not sets_seen:
            sets_seen = set()

        # Nothing new to resolve, which includes an empty members set:
        # this is what ends the recursion.
        if all([member in sets_seen for member in members]):
            return set()
        sets_seen.update(members)

        set_members = set()
        resolved_as_members = set()
        # leaf_members is not used any further in this method.
        sub_members, leaf_members = self._find_set_members(members)

        for sub_member in sub_members:
            # Prefixes are only accepted as members when the root object class
            # is a route-set, or has not been determined.
            if self._current_set_root_object_class is None or self._current_set_root_object_class == 'route-set':
                try:
                    IP(sub_member)
                    set_members.add(sub_member)
                    continue
                except ValueError:
                    pass
            # AS numbers are permitted in route-sets and as-sets, per RFC 2622 5.3.
            # When an AS number is encountered as part of route-set resolving,
            # the prefixes originating from that AS should be added to the response.
            try:
                as_number_formatted, _ = parse_as_number(sub_member)
                if self._current_set_root_object_class == 'route-set':
                    set_members.update(
                        self.preloader.routes_for_origins(
                            [as_number_formatted], self.sources))
                    resolved_as_members.add(sub_member)
                else:
                    set_members.add(sub_member)
                continue
            except ValueError:
                # Not an AS number either: presumably a set name, which is
                # left for the next round of resolving below.
                pass

        # Whatever was neither added as a member, nor seen before, nor
        # expanded into prefixes, is resolved in the next round.
        further_resolving_required = sub_members - set_members - sets_seen - resolved_as_members
        new_members = self._recursive_set_resolve(further_resolving_required,
                                                  sets_seen)
        set_members.update(new_members)

        return set_members

    def _find_set_members(self,
                          set_names: Set[str]) -> Tuple[Set[str], Set[str]]:
        """
        Find the members of a number of route-sets or as-sets.

        This covers the members listed in the members and mp-members
        attributes, as well as objects that refer to the set with member-of,
        as far as permitted by the mbrs-by-ref attribute of the set.

        Returns a tuple of two sets:
        - the members found for the sets in set_names, which can be
          references to other sets, AS numbers, etc.
        - the leaf members: names from set_names for which no set object
          was found, e.g. references to non-existent sets
        """
        columns = ['parsed_data', 'rpsl_pk', 'source', 'object_class']
        set_query = self._prepare_query(column_names=columns)

        # Per RFC 2622 5.3, route-sets can refer to as-sets,
        # but as-sets can only refer to other as-sets.
        if self._current_set_root_object_class == 'as-set':
            set_classes = ['as-set']
        else:
            set_classes = ['as-set', 'route-set']

        set_query = set_query.object_classes(set_classes).rpsl_pks(set_names)
        set_rows = list(self.database_handler.execute_query(set_query))

        if not set_rows:
            # No set objects were found, so apparently all inputs were leaf members.
            return set(), set_names

        # Remember the object class of the root set. It may already be set
        # on the first run, when resolving is fixed to one type of set object.
        if not self._current_set_root_object_class:
            self._current_set_root_object_class = set_rows[0]['object_class']

        members: Set[str] = set()
        resolved_pks: Set[str] = set()

        for set_row in set_rows:
            rpsl_pk = set_row['rpsl_pk']

            # The same PK may occur in multiple sources, but only the first
            # matching object is of interest. The source priority is part of
            # the query ORDER BY, so each RPSL pk is processed only once.
            if rpsl_pk in resolved_pks:
                continue
            resolved_pks.add(rpsl_pk)

            object_class = set_row['object_class']
            object_data = set_row['parsed_data']
            for members_attr in ('members', 'mp-members'):
                if members_attr in object_data:
                    members.update(object_data[members_attr])

            mbrs_by_ref = object_data.get('mbrs-by-ref')
            if not (rpsl_pk and object_class and mbrs_by_ref):
                continue

            # With mbrs-by-ref set, include the objects that have a member-of
            # pointing to this set and a maintainer listed in mbrs-by-ref.
            # The maintainer is not checked when mbrs-by-ref contains ANY.
            if object_class == 'route-set':
                referring_classes = ['route', 'route6']
            else:
                referring_classes = ['aut-num']
            referring_query = self._prepare_query(
                column_names=columns).object_classes(referring_classes)
            referring_query = referring_query.lookup_attrs_in(['member-of'], [rpsl_pk])

            if 'ANY' not in [m.strip().upper() for m in mbrs_by_ref]:
                referring_query = referring_query.lookup_attrs_in(['mnt-by'], mbrs_by_ref)

            for referring_row in self.database_handler.execute_query(referring_query):
                referring_class = referring_row['object_class']
                # The member is the value of the attribute named after the
                # object class of the referring object.
                members.add(referring_row['parsed_data'][referring_class])

        return members, set_names - resolved_pks

    def handle_irrd_database_serial_range(self, parameter: str) -> str:
        """
        !j query - report the serial range known for each source.

        The parameter is either '-*', for the default sources (or all valid
        sources when no default is configured), or a comma-separated list of
        source names. Each source found in the database status produces a
        line with the source, Y/N for keep_journal, the oldest-newest serial
        range (or '-') and, when set, the serial of the last export, joined
        by colons. Requested sources that are not valid are reported as
        'SOURCE:X:Database unknown'.
        """
        if parameter == '-*':
            sources = self.sources_default or self.all_valid_sources
        else:
            sources = [name.upper() for name in parameter.split(',')]
        unknown_sources = [
            name for name in sources if name not in self.all_valid_sources
        ]

        status_query = DatabaseStatusQuery().sources(sources)
        lines = []
        for status in self.database_handler.execute_query(status_query):
            source = status['source'].upper()
            keep_journal = 'Y' if get_setting(
                f'sources.{source}.keep_journal') else 'N'
            serial_oldest = status['serial_oldest_seen']
            serial_newest = status['serial_newest_seen']
            if serial_oldest and serial_newest:
                serial_range = f'{serial_oldest}-{serial_newest}'
            else:
                serial_range = '-'

            fields = [source, keep_journal, serial_range]
            if status['serial_last_export']:
                fields.append(str(status['serial_last_export']))
            lines.append(':'.join(fields))

        lines.extend(
            f'{name.upper()}:X:Database unknown' for name in unknown_sources)
        return '\n'.join(lines).strip()

    def handle_irrd_exact_key(self, parameter: str):
        """
        !m query - exact object key lookup, e.g. !maut-num,AS65537

        The parameter is split on its first comma into the object class and
        the primary key; a parameter without a comma raises
        WhoisQueryParserException. Only the first match is returned.
        """
        object_class, separator, rpsl_pk = parameter.partition(',')
        if not separator:
            raise WhoisQueryParserException(
                f'Invalid argument for object lookup: {parameter}')

        query = self._prepare_query().object_classes([object_class])
        query = query.rpsl_pk(rpsl_pk).first_only()
        return self._execute_query_flatten_output(query)

    def handle_irrd_route_search(self, parameter: str):
        """
        !r query - route search with various options:
           !r192.0.2.0/24 returns all exact matching objects
           !r192.0.2.0/24,o returns space-separated origins of all exact matching objects
           !r192.0.2.0/24,l returns all one-level less specific objects, not including exact
           !r192.0.2.0/24,L returns all less specific objects, including exact
           !r192.0.2.0/24,M returns all more specific objects, not including exact

        Raises WhoisQueryParserException when the address is not a valid IP
        address or prefix, or when the option is not one of the above.
        """
        option: Optional[str] = None
        address_text = parameter
        if ',' in parameter:
            # Split on the first comma only. Unpacking an unlimited split
            # fails with a bare ValueError for input with multiple commas,
            # which would be reported as an internal error. This way, the
            # remainder ends up in the option and is rejected below.
            address_text, option = parameter.split(',', maxsplit=1)

        try:
            address = IP(address_text)
        except ValueError:
            raise WhoisQueryParserException(
                f'Invalid input for route search: {parameter}')

        query = self._prepare_query(ordered_by_sources=False).object_classes(
            ['route', 'route6'])
        if option is None or option == 'o':
            query = query.ip_exact(address)
        elif option == 'l':
            query = query.ip_less_specific_one_level(address)
        elif option == 'L':
            query = query.ip_less_specific(address)
        elif option == 'M':
            query = query.ip_more_specific(address)
        else:
            raise WhoisQueryParserException(
                f'Invalid route search option: {option}')

        if option == 'o':
            # Only the origins are wanted, not the full object text.
            query_result = self.database_handler.execute_query(query)
            origins = [r['parsed_data']['origin'] for r in query_result]
            return ' '.join(origins)
        return self._execute_query_flatten_output(query)

    def handle_irrd_sources_list(self, parameter: str) -> Optional[str]:
        """
        !s query - set used sources
           !s-lc returns all enabled sources, space separated
           !sripe,nttcom limits sources to ripe and nttcom
        """
        if parameter == '-lc':
            return ','.join(self.sources)

        sources = parameter.upper().split(',')
        if not all([source in self.all_valid_sources for source in sources]):
            raise WhoisQueryParserException(
                'One or more selected sources are unavailable.')
        self.sources = sources

        return None

    def handle_irrd_version(self) -> str:
        """!v query - return a one-line version string built from __version__"""
        return f'IRRd -- version {__version__}'

    def handle_ripe_command(self, full_query: str) -> WhoisQueryResponse:
        """
        Process RIPE-style queries. Any query that is not explicitly an IRRD-style
        query (i.e. starts with exclamation mark) is presumed to be a RIPE query.

        The query is split on spaces and processed left to right. A component
        starting with a dash is a flag, which may consume the following
        component(s) as its argument(s); any other component is treated as a
        free text search. Raises WhoisQueryParserException for an unknown
        flag or a flag with a missing argument.
        """
        # Collapse runs of spaces, so splitting produces no empty components
        # in the middle of the query.
        full_query = re.sub(' +', ' ', full_query)
        components = full_query.strip().split(' ')
        result = None
        response_type = WhoisQueryResponseType.SUCCESS

        while len(components):
            component = components.pop(0)
            if component.startswith('-'):
                command = component[1:]
                try:
                    if command == 'k':
                        self.multiple_command_mode = True
                    elif command in ['l', 'L', 'M', 'x']:
                        result = self.handle_ripe_route_search(
                            command, components.pop(0))
                        if not result:
                            response_type = WhoisQueryResponseType.KEY_NOT_FOUND
                        # A search ends the processing of this query:
                        # any remaining components are ignored.
                        break
                    elif command == 'i':
                        # Arguments are evaluated left to right: first the
                        # attribute name, then the value to search for.
                        result = self.handle_inverse_attr_search(
                            components.pop(0), components.pop(0))
                        if not result:
                            response_type = WhoisQueryResponseType.KEY_NOT_FOUND
                        break
                    elif command == 's':
                        self.handle_ripe_sources_list(components.pop(0))
                    elif command == 'a':
                        self.handle_ripe_sources_list(None)
                    elif command == 'T':
                        self.handle_ripe_restrict_object_class(
                            components.pop(0))
                    elif command == 't':
                        result = self.handle_ripe_request_object_template(
                            components.pop(0))
                        break
                    elif command == 'K':
                        self.handle_ripe_key_fields_only()
                    elif command == 'V':
                        self.handle_user_agent(components.pop(0))
                    elif command == 'g':
                        # Unlike the searches above, this does not end the
                        # loop: later components are still processed.
                        result = self.handle_nrtm_request(components.pop(0))
                    elif command in ['F', 'r']:
                        continue  # These flags disable recursion, but IRRd never performs recursion anyways
                    else:
                        raise WhoisQueryParserException(
                            f'Unrecognised flag/search: {command}')
                except IndexError:
                    # Raised by pop(0) when the argument of a flag is missing.
                    # NOTE(review): an IndexError escaping from one of the
                    # handlers themselves would be reported the same way.
                    raise WhoisQueryParserException(
                        f'Missing argument for flag/search: {command}')
            else:  # assume query to be a free text search
                # With several free text components, each one is searched
                # and only the result of the last one is kept.
                result = self.handle_ripe_text_search(component)

        return WhoisQueryResponse(
            response_type=response_type,
            mode=WhoisQueryResponseMode.RIPE,
            result=result,
        )

    def handle_ripe_route_search(self, command: str, parameter: str) -> str:
        """
        -l/L/M/x query - search route(6) objects relative to a prefix:
           -x 192.0.2.0/24 returns all exact matching objects
           -l 192.0.2.0/24 returns all one-level less specific objects, not including exact
           -L 192.0.2.0/24 returns all less specific objects, including exact
           -M 192.0.2.0/24 returns all more specific objects, not including exact

        For any other command no prefix filter is applied to the query.
        Raises WhoisQueryParserException when the parameter is not a valid
        IP address or prefix.
        """
        try:
            address = IP(parameter)
        except ValueError:
            raise WhoisQueryParserException(
                f'Invalid input for route search: {parameter}')

        # Name of the query method that applies the filter for each flag.
        filter_method_names = {
            'x': 'ip_exact',
            'l': 'ip_less_specific_one_level',
            'L': 'ip_less_specific',
            'M': 'ip_more_specific',
        }

        query = self._prepare_query(ordered_by_sources=False).object_classes(
            ['route', 'route6'])
        method_name = filter_method_names.get(command)
        if method_name is not None:
            query = getattr(query, method_name)(address)

        return self._execute_query_flatten_output(query)

    def handle_ripe_sources_list(self, sources_list: Optional[str]) -> None:
        """-s/-a parameter - set sources list. Empty list enables all sources. """
        if sources_list:
            sources = sources_list.upper().split(',')
            if not all(
                [source in self.all_valid_sources for source in sources]):
                raise WhoisQueryParserException(
                    'One or more selected sources are unavailable.')
            self.sources = sources
        else:
            self.sources = self.sources_default if self.sources_default else self.all_valid_sources

    def handle_ripe_restrict_object_class(self, object_classes: str) -> None:
        """-T parameter - restrict object classes for this query, comma-separated"""
        # Stored as a list; _prepare_query() applies it to the queries it builds.
        self.object_classes = object_classes.split(',')

    def handle_ripe_request_object_template(self, object_class) -> str:
        """
        -t query - return the RPSL template for an object class.

        A KeyError raised while looking up the class or generating its
        template is reported as an unknown object class.
        """
        try:
            rpsl_class = OBJECT_CLASS_MAPPING[object_class]
            template = rpsl_class().generate_template()
        except KeyError:
            raise WhoisQueryParserException(
                f'Unknown object class: {object_class}')
        return template

    def handle_ripe_key_fields_only(self) -> None:
        """-K parameter - only return primary key and members fields"""
        # Read by _execute_query_flatten_output() when formatting results.
        self.key_fields_only = True

    def handle_ripe_text_search(self, value: str) -> str:
        query = self._prepare_query(
            ordered_by_sources=False).text_search(value)
        return self._execute_query_flatten_output(query)

    def handle_user_agent(self, user_agent: str):
        """-V/!n parameter/query - store the user agent of the client and log it"""
        self.user_agent = user_agent
        message = f'{self.client_str}: user agent set to: {user_agent}'
        logger.info(message)

    def handle_nrtm_request(self, param):
        """
        -g query - generate NRTM output for a serial range of a source.

        param has the form source:version:start-end, in which end is either
        a serial or LAST. The version must be 1 or 3, the source must be
        valid, and the client IP must be permitted by the nrtm_access_list
        of the source. Any violation, as well as an error from the NRTM
        generator, raises WhoisQueryParserException.
        """
        try:
            source, version, serial_range = param.split(':')
        except ValueError:
            raise WhoisQueryParserException(
                'Invalid parameter: must contain three elements')

        try:
            start_text, end_text = serial_range.split('-')
            serial_start = int(start_text)
            # LAST leaves the end of the range open.
            serial_end = None if end_text == 'LAST' else int(end_text)
        except ValueError:
            raise WhoisQueryParserException(
                f'Invalid serial range: {serial_range}')

        if version not in ('1', '3'):
            raise WhoisQueryParserException(f'Invalid NRTM version: {version}')

        source = source.upper()
        if source not in self.all_valid_sources:
            raise WhoisQueryParserException(f'Unknown source: {source}')

        access_list_setting = f'sources.{source}.nrtm_access_list'
        if not is_client_permitted(self.client_ip, access_list_setting):
            raise WhoisQueryParserException('Access denied')

        try:
            return NRTMGenerator().generate(source, version, serial_start,
                                            serial_end, self.database_handler)
        except NRTMGeneratorException as nge:
            raise WhoisQueryParserException(str(nge))

    def handle_inverse_attr_search(self, attribute: str, value: str) -> str:
        """
        -i/!o query - inverse search for attribute values
        e.g. `-i mnt-by FOO` finds all objects where (one of the) maintainer(s) is FOO,
        as does `!oFOO`. Restricted to designated lookup fields.

        Raises WhoisQueryParserException when attribute is not one of the
        lookup fields. Returns the flattened output of the matching objects.
        """
        if attribute not in self.lookup_field_names:
            readable_lookup_field_names = ', '.join(self.lookup_field_names)
            # Note the space after the comma: without it, the two parts of
            # the message run together.
            raise WhoisQueryParserException(
                f'Inverse attribute search not supported for {attribute}, '
                f'only supported for attributes: {readable_lookup_field_names}')
        query = self._prepare_query(ordered_by_sources=False).lookup_attr(
            attribute, value)
        return self._execute_query_flatten_output(query)

    def _prepare_query(self,
                       column_names=None,
                       ordered_by_sources=True) -> RPSLDatabaseQuery:
        """
        Build an RPSLDatabaseQuery with the filters of this session applied.

        The sources are limited to the current selection when that differs
        from all valid sources, and otherwise to the sources_default setting
        when that is not empty. The object class restriction and the RPKI
        invalid filter are applied when they are active.
        """
        query = RPSLDatabaseQuery(column_names, ordered_by_sources)

        explicit_selection = bool(self.sources) and self.sources != self.all_valid_sources
        if explicit_selection:
            query.sources(self.sources)
        else:
            default_sources = list(get_setting('sources_default', []))
            if default_sources:
                query.sources(default_sources)

        if self.object_classes:
            query.object_classes(self.object_classes)
        if self.rpki_invalid_filter_enabled:
            # Only objects that are not RPKI invalid are returned.
            query.rpki_status([RPKIStatus.not_found, RPKIStatus.valid])
        return query

    def _execute_query_flatten_output(self, query: RPSLDatabaseQuery) -> str:
        """
        Run an RPSLDatabaseQuery and flatten the result into a single string,
        for easy passing to a WhoisQueryResponse.

        In key fields only mode, the output of _filter_key_fields() is used.
        Otherwise the object texts are concatenated, each followed by an
        empty line. When RPKI aware, objects of an RPKI relevant class that
        are not from the RPKI pseudo source get an rpki-ov-state line.
        Leading and trailing newlines are stripped from the result.
        """
        query_response = self.database_handler.execute_query(query)
        if self.key_fields_only:
            return self._filter_key_fields(query_response).strip('\n\r')

        chunks = []
        for obj in query_response:
            chunks.append(obj['object_text'])
            needs_rpki_state = (
                self.rpki_aware
                and obj['source'] != RPKI_IRR_PSEUDO_SOURCE
                and obj['object_class'] in RPKI_RELEVANT_OBJECT_CLASSES
            )
            if needs_rpki_state:
                rpki_status = obj['rpki_status']
                comment = ''
                if rpki_status == RPKIStatus.not_found:
                    comment = ' # No ROAs found, or RPKI validation not enabled for source'
                chunks.append(f'rpki-ov-state:  {rpki_status.name}{comment}\n')
            chunks.append('\n')
        return ''.join(chunks).strip('\n\r')

    def _filter_key_fields(self, query_response) -> str:
        """
        Reduce each object in query_response to its primary key fields plus
        the members and mp-members fields, as 'name: value' lines.

        List values produce one line per item, fields without data are
        skipped. The per-object texts are collected in an OrderedSet and
        joined by newlines.
        """
        results: OrderedSet[str] = OrderedSet()
        for obj in query_response:
            rpsl_object_class = OBJECT_CLASS_MAPPING[obj['object_class']]
            field_names = rpsl_object_class.pk_fields + ['members', 'mp-members']

            lines = []
            for field_name in field_names:
                field_data = obj['parsed_data'].get(field_name)
                if not field_data:
                    continue
                values = field_data if isinstance(field_data, list) else [field_data]
                lines.extend(f'{field_name}: {value}\n' for value in values)
            results.add(''.join(lines))
        return '\n'.join(results)

    def _preloaded_query_called(self):
        """
        Called each time the user runs a query that can be preloaded.
        After the 5th, load the preload store into memory to speed
        up expected further queries.
        """
        self._preloaded_query_count += 1
        if self._preloaded_query_count > 5:
            self.preloader.load_routes_into_memory()
예제 #15
0
def set_force_reload(source) -> None:
    """
    Set the force-reload flag for `source` and commit the change.

    The database handler is created with preload updates disabled. It is
    always closed, also when setting the flag or committing raises; in
    that case the exception propagates to the caller.
    """
    dh = DatabaseHandler(enable_preload_update=False)
    try:
        dh.set_force_reload(source)
        dh.commit()
    finally:
        # Do not leak the handler when one of the calls above fails.
        dh.close()
예제 #16
0
파일: handler.py 프로젝트: bitcynth/irrd4
class ChangeSubmissionHandler:
    """
    The ChangeSubmissionHandler handles the text of one or more requested RPSL changes
    (create, modify or delete), parses, validates and eventually saves
    them. This includes validating references between objects, including
    those part of the same message, and checking authentication.
    """

    def __init__(self, object_texts: str, pgp_fingerprint: Optional[str]=None,
                 request_meta: Optional[Dict[str, Optional[str]]]=None) -> None:
        """
        Process the submitted changes; all work happens in the constructor.

        :param object_texts: text of one or more requested RPSL changes
        :param pgp_fingerprint: fingerprint of the PGP key the submission
            was signed with, if any; resolved to a key ID before processing
        :param request_meta: metadata about the request, used in log
            messages and reports; defaults to an empty dict

        The outcome is available in self.results afterwards. Changes are
        committed only when processing completed without an exception; the
        database handler is closed in every case and any exception
        propagates to the caller.
        """
        self.database_handler = DatabaseHandler()
        self.request_meta = request_meta if request_meta else {}
        try:
            self._pgp_key_id = self._resolve_pgp_key_id(pgp_fingerprint) if pgp_fingerprint else None
            self._handle_object_texts(object_texts)
            self.database_handler.commit()
        finally:
            # Release the handler even when key resolution or change
            # processing raised; nothing is committed in that case.
            self.database_handler.close()

    def _handle_object_texts(self, object_texts: str) -> None:
        """
        Parse the submitted text into change requests, settle which of
        them are valid, save the valid ones and store every request in
        self.results.

        Raises ValueError when the validity of the requests does not
        stabilise within a bounded number of validation rounds.
        """
        reference_validator = ReferenceValidator(self.database_handler)
        auth_validator = AuthValidator(self.database_handler, self._pgp_key_id)
        results = parse_change_requests(object_texts, self.database_handler, auth_validator, reference_validator)

        # An object may only be added/updated when the objects it refers to
        # (e.g. tech-c pointing at a person or mntner) exist, and those may
        # be part of this same submission. The reference validator is
        # therefore told about all new objects that are still considered
        # valid. Validity can cascade: if A refers to B, B to C, and C turns
        # out to be invalid (missing reference, failed authentication), the
        # first round only rejects C, the next round rejects B, and so on.
        # Rounds are repeated until the set of valid changes stops changing.
        max_rounds = len(results) + 10
        rounds_run = 0
        currently_valid = [change for change in results if change.is_valid()]
        valid_last_round: List[ChangeRequest] = []

        while currently_valid != valid_last_round:
            reference_validator.preload(currently_valid)
            auth_validator.pre_approve(currently_valid)
            for change in currently_valid:
                change.validate()

            valid_last_round = currently_valid
            currently_valid = [change for change in results if change.is_valid()]

            rounds_run += 1
            if rounds_run > max_rounds:  # pragma: no cover
                msg = (f'Update validity resolver ran an excessive amount of loops, may be stuck, aborting '
                       f'processing. Message metadata: {self.request_meta}')
                logger.error(msg)
                raise ValueError(msg)

        for change in results:
            if change.is_valid():
                change.save(self.database_handler)

        self.results = results

    def _resolve_pgp_key_id(self, pgp_fingerprint: str) -> Optional[str]:
        """
        Find the PGP key ID that belongs to a fingerprint.

        The candidate key ID is "PGPKEY-" followed by the last eight
        characters of the fingerprint (spaces removed). It is only returned
        when a key-cert object with that PK exists in the database whose
        fingerpr attribute equals the full fingerprint. Otherwise this is
        logged and None is returned.
        """
        fingerprint = pgp_fingerprint.replace(' ', '')
        key_id = "PGPKEY-" + fingerprint[-8:]
        query = RPSLDatabaseQuery().object_classes(['key-cert']).rpsl_pk(key_id)
        candidates = list(self.database_handler.execute_query(query))

        fingerprint_known = any(
            candidate['parsed_data'].get('fingerpr', '').replace(' ', '') == fingerprint
            for candidate in candidates
        )
        if fingerprint_known:
            return key_id

        logger.info(f'Message was signed with key {key_id}, but key was not found in the database. Treating message '
                    f'as unsigned. Message metadata: {self.request_meta}')
        return None

    def status(self) -> str:
        """Return "SUCCESS" when every result was saved, "FAILED" otherwise."""
        everything_saved = all(
            result.status == UpdateRequestStatus.SAVED for result in self.results
        )
        return "SUCCESS" if everything_saved else "FAILED"

    def submitter_report(self) -> str:
        """
        Produce a human-readable report for the submitter.

        The report starts with the request metadata, followed by a summary
        counting saved and not-saved requests per request type, and ends
        with the individual report of every request.
        """
        # flake8: noqa: W293
        saved = []
        not_saved = []
        for result in self.results:
            if result.status == UpdateRequestStatus.SAVED:
                saved.append(result)
            else:
                not_saved.append(result)

        def count_of(requests, request_type):
            # Number of requests in `requests` with the given request type.
            return len([r for r in requests if r.request_type == request_type])

        created_ok = count_of(saved, UpdateRequestType.CREATE)
        modified_ok = count_of(saved, UpdateRequestType.MODIFY)
        deleted_ok = count_of(saved, UpdateRequestType.DELETE)
        created_failed = count_of(not_saved, UpdateRequestType.CREATE)
        modified_failed = count_of(not_saved, UpdateRequestType.MODIFY)
        deleted_failed = count_of(not_saved, UpdateRequestType.DELETE)

        report_parts = [self._request_meta_str() + textwrap.dedent(f"""
        SUMMARY OF UPDATE:

        Number of objects found:                  {len(self.results):3}
        Number of objects processed successfully: {len(saved):3}
            Create:      {created_ok:3}
            Modify:      {modified_ok:3}
            Delete:      {deleted_ok:3}
        Number of objects processed with errors:  {len(not_saved):3}
            Create:      {created_failed:3}
            Modify:      {modified_failed:3}
            Delete:      {deleted_failed:3}

        DETAILED EXPLANATION:

        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        """)]
        for result in self.results:
            report_parts.append("---\n" + result.submitter_report() + "\n")
        report_parts.append('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n')
        return ''.join(report_parts)

    def send_notification_target_reports(self):
        """
        E-mail every notification target a report of the changes it is
        referenced in.

        Only results that were saved or that failed authorisation are
        reported. Each recipient receives a single message; reports on
        authorisation failures come first, followed by reports on saved
        changes.
        """
        # First key is e-mail address of recipient, second is UpdateRequestStatus.SAVED
        # or UpdateRequestStatus.ERROR_AUTH
        reports_per_recipient: Dict[str, Dict[UpdateRequestStatus, OrderedSet]] = defaultdict(dict)
        sources: OrderedSet[str] = OrderedSet()
        reported_statuses = [UpdateRequestStatus.SAVED, UpdateRequestStatus.ERROR_AUTH]

        for result in self.results:
            for target in result.notification_targets():
                if result.status not in reported_statuses:
                    continue
                reports_for_target = reports_per_recipient[target]
                if result.status not in reports_for_target:
                    reports_for_target[result.status] = OrderedSet()
                reports_for_target[result.status].add(result.notification_target_report())
                sources.add(result.rpsl_obj_new.source())

        sources_str = '/'.join(sources)
        subject = f'Notification of {sources_str} database changes'
        header = textwrap.dedent(f"""
            This is to notify you of changes in the {sources_str} database
            or object authorisation failures.

            You may receive this message because you are listed in
            the notify attribute on the changed object(s), or because
            you are listed in the mnt-nfy or upd-to attribute on a maintainer
            of the object(s).

            This message is auto-generated.
            The request was made by email, with the following details:
        """)
        # The two section headers are static text, spelled out line by line
        # so that significant whitespace (the space after "*failed* the")
        # is visible.
        header_saved = (
            '\n'
            '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n'
            'Some objects in which you are referenced have been created,\n'
            'deleted or changed.\n'
            '\n'
        )
        header_failed = (
            '\n'
            '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n'
            'Some objects in which you are referenced were requested\n'
            'to be created, deleted or changed, but *failed* the \n'
            'proper authorisation for any of the referenced maintainers.\n'
            '\n'
        )
        # Order matters: authorisation failures are listed before saved changes.
        sections = (
            (UpdateRequestStatus.ERROR_AUTH, header_failed),
            (UpdateRequestStatus.SAVED, header_saved),
        )

        for recipient, reports_per_status in reports_per_recipient.items():
            user_report = header + self._request_meta_str()
            for status, section_header in sections:
                if status not in reports_per_status:
                    continue
                user_report += section_header
                for report in reports_per_status[status]:
                    user_report += f"---\n{report}\n"
                user_report += '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n'

            email.send_email(recipient, subject, user_report)

    def _request_meta_str(self):
        request_meta_str = '\n'.join([f"> {k}: {v}" for k, v in self.request_meta.items() if v])
        if request_meta_str:
            request_meta_str = "\n" + request_meta_str + "\n\n"
        return request_meta_str
예제 #17
0
class WhoisQueryParser:
    """
    Parser for all whois-style queries.

    This parser distinguishes RIPE-style, e.g. "-K 192.0.2.1" or "-i mnt-by FOO"
    from IRRD-style, e.g. "!oFOO".

    Some query flags, particularly -k/!! and -s/!s retain state across queries,
    so a single instance of this object should be created per session, with
    handle_query() being called for each individual query.
    """
    lookup_field_names = lookup_field_names()
    database_handler: DatabaseHandler

    def __init__(self, peer: str) -> None:
        """
        Set up the per-session state for the client identified by `peer`.

        The valid sources are read from the 'sources' setting once, at
        construction time.
        """
        self.peer = peer
        self.user_agent: Optional[str] = None
        self.all_valid_sources = list(get_setting('sources').keys())

        # Query state; sources and multiple command mode persist across
        # queries within the session.
        self.sources: List[str] = []
        self.object_classes: List[str] = []
        self.key_fields_only = False
        self.multiple_command_mode = False

        # The WhoisQueryParser itself should not run concurrently,
        # as answers could arrive out of order. The event starts out set,
        # meaning a query may run.
        query_gate = threading.Event()
        query_gate.set()
        self.safe_to_run_query = query_gate

    def handle_query(self, query: str) -> WhoisQueryResponse:
        """
        Process a single query and return a WhoisQueryResponse.

        Queries starting with an exclamation mark are handled as IRRD-style
        commands, everything else as a RIPE-style query. Parsing errors and
        unexpected exceptions raised while handling the query are logged
        and turned into an ERROR response in the matching mode.

        Only one query runs at a time per parser: the call blocks until an
        earlier query has finished. A new DatabaseHandler is opened for the
        query and closed afterwards. If opening it fails, that exception
        propagates to the caller, but the parser is released so that later
        queries are not blocked.
        """
        self.safe_to_run_query.wait()
        self.safe_to_run_query.clear()
        try:
            self.database_handler = DatabaseHandler()
        except BaseException:
            # The event was already cleared: without setting it again every
            # later query on this session would block forever in wait().
            self.safe_to_run_query.set()
            raise

        # These flags are reset with every query.
        self.key_fields_only = False
        self.object_classes = []

        is_irrd_query = query.startswith('!')
        mode = WhoisQueryResponseMode.IRRD if is_irrd_query else WhoisQueryResponseMode.RIPE

        try:
            if is_irrd_query:
                return self.handle_irrd_command(query[1:])
            return self.handle_ripe_command(query)
        except WhoisQueryParserException as exc:
            logger.info(f'{self.peer}: encountered parsing error while parsing query "{query}": {exc}')
            return WhoisQueryResponse(
                response_type=WhoisQueryResponseType.ERROR,
                mode=mode,
                result=str(exc)
            )
        except Exception as exc:
            logger.critical(f'An exception occurred while processing whois query "{query}": {exc}', exc_info=exc)
            return WhoisQueryResponse(
                response_type=WhoisQueryResponseType.ERROR,
                mode=mode,
                result='An internal error occurred while processing this query.'
            )
        finally:
            try:
                self.database_handler.close()
            finally:
                # Release the parser even when closing the handler fails.
                self.safe_to_run_query.set()

    def handle_irrd_command(self, full_command: str) -> WhoisQueryResponse:
        """
        Handle an IRRD-style query and wrap the outcome in a response.

        full_command should not include the first exclamation mark. Its
        first character, case-insensitive, selects the command; the rest is
        the parameter. Raises WhoisQueryParserException for an empty or
        unrecognised command.
        """
        if not full_command:
            raise WhoisQueryParserException('Missing IRRD command')
        command, parameter = full_command[0].upper(), full_command[1:]

        # Lookups: an empty result is reported as KEY_NOT_FOUND.
        lookups = {
            'G': self.handle_irrd_routes_for_origin_v4,
            '6': self.handle_irrd_routes_for_origin_v6,
            'I': self.handle_irrd_set_members,
            'M': self.handle_irrd_exact_key,
            'O': lambda value: self.handle_inverse_attr_search('mnt-by', value),
            'R': self.handle_irrd_route_search,
        }
        # Commands whose result is returned as-is, always as SUCCESS.
        plain_commands = {
            'J': self.handle_irrd_database_serial_range,
            'S': self.handle_irrd_sources_list,
        }

        result = None
        response_type = WhoisQueryResponseType.SUCCESS

        if command in lookups:
            result = lookups[command](parameter)
            if not result:
                response_type = WhoisQueryResponseType.KEY_NOT_FOUND
        elif command in plain_commands:
            result = plain_commands[command](parameter)
        elif command == 'V':
            result = self.handle_irrd_version()
        elif command == 'N':
            self.handle_user_agent(parameter)
        elif command == '!':
            self.multiple_command_mode = True
        else:
            raise WhoisQueryParserException(f'Unrecognised command: {command}')

        return WhoisQueryResponse(
            response_type=response_type,
            mode=WhoisQueryResponseMode.IRRD,
            result=result,
        )

    def handle_irrd_routes_for_origin_v4(self, origin: str) -> str:
        """
        !g query - prefixes of all route objects originated by an AS,
        e.g. !gAS65537
        """
        return self._routes_for_origin(object_class='route', origin=origin)

    def handle_irrd_routes_for_origin_v6(self, origin: str) -> str:
        """
        !6 query - prefixes of all route6 objects originated by an AS,
        e.g. !6as65537
        """
        return self._routes_for_origin(object_class='route6', origin=origin)

    def _routes_for_origin(self, object_class: str, origin: str) -> str:
        """
        Look up all objects of object_class (route or route6) for an origin
        AS and return their prefixes as a space-separated string, without
        duplicates.

        Raises WhoisQueryParserException when origin is not a valid AS
        number.
        """
        try:
            _formatted, asn = parse_as_number(origin)
        except ValidationError as ve:
            raise WhoisQueryParserException(str(ve))

        query = self._prepare_query().object_classes([object_class]).asn(asn)
        prefixes = []
        for row in self.database_handler.execute_query(query):
            prefixes.append(row['parsed_data'][object_class])
        unique_prefixes = OrderedSet(prefixes)
        return ' '.join(unique_prefixes)

    def handle_irrd_set_members(self, parameter: str) -> str:
        """
        !i query - find all members of an as-set or route-set.

        A ",1" suffix requests recursive resolving: !iAS-FOO is
        non-recursive, !iAS-FOO,1 is recursive. The members are returned
        sorted and space-separated; the queried set itself is never listed
        as its own member.
        """
        recursive = parameter.endswith(',1')
        if recursive:
            parameter = parameter[:-2]

        # Each query starts without a preferred source for set objects.
        self._current_set_priority_source = None
        if recursive:
            members = self._recursive_set_resolve({parameter})
        else:
            members, leaf_members = self._find_set_members({parameter})
            members.update(leaf_members)

        members.discard(parameter)
        return ' '.join(sorted(members))

    def _recursive_set_resolve(self, members: Set[str], sets_seen=None) -> Set[str]:
        """
        Resolve all members of a number of sets, recursively.

        Names that were all seen before are not resolved again, which
        prevents infinite recursion. Members found for the given sets that
        parse as an IP/prefix or as an AS number are part of the result
        directly; all other members that were not seen yet are treated as
        sets and resolved further.
        """
        if not sets_seen:
            sets_seen = set()

        if all(member in sets_seen for member in members):
            return set()
        sets_seen.update(members)

        sub_members, _leaf_members = self._find_set_members(members)

        resolved: Set[str] = set()
        for candidate in sub_members:
            # A candidate is final as soon as one of the parsers accepts it.
            for parser in (IP, parse_as_number):
                try:
                    parser(candidate)
                except ValueError:
                    continue
                resolved.add(candidate)
                break

        still_to_resolve = sub_members - resolved - sets_seen
        resolved.update(self._recursive_set_resolve(still_to_resolve, sets_seen))
        return resolved

    def _find_set_members(self, set_names: Set[str]) -> Tuple[Set[str], Set[str]]:
        """
        Find all members of a number of route-sets or as-sets, both those
        listed directly in members/mp-members and those included through
        mbrs-by-ref/member-of.

        Returns a tuple of two sets:
        - the members found for the sets in set_names, which can be
          references to other sets as well as direct AS numbers, etc.
        - the leaf members: names from set_names for which no set object
          was found, for example references to non-existent sets
        """
        query = self._prepare_query().object_classes(['as-set', 'route-set']).rpsl_pks(set_names)
        if self._current_set_priority_source:
            query.prioritise_source(self._current_set_priority_source)
        set_objects = list(self.database_handler.execute_query(query))

        if not set_objects:
            # Nothing was found, so apparently all inputs were leaf members.
            return set(), set_names

        # The first object found for the root set determines the source
        # that is prioritised in the further lookups of this query.
        if not self._current_set_priority_source:
            self._current_set_priority_source = set_objects[0]['source']

        members: Set[str] = set()
        resolved_pks: Set[str] = set()

        for set_object in set_objects:
            rpsl_pk = set_object['rpsl_pk']

            # The same PK may exist in several sources. Only the first
            # match is used; the ordering of the query results already
            # puts the prioritised source first.
            if rpsl_pk in resolved_pks:
                continue
            resolved_pks.add(rpsl_pk)

            object_class = set_object['object_class']
            object_data = set_object['parsed_data']
            mbrs_by_ref = object_data.get('mbrs-by-ref', None)
            for members_attr in ('members', 'mp-members'):
                if members_attr in object_data:
                    members.update(object_data[members_attr])

            if not (rpsl_pk and object_class and mbrs_by_ref):
                continue

            # mbrs-by-ref is set: also include objects whose member-of
            # points at this set. Unless mbrs-by-ref contains ANY, those
            # objects must have one of the listed maintainers as mnt-by.
            if object_class == 'route-set':
                referring_classes = ['route', 'route6']
            else:
                referring_classes = ['aut-num']
            referring_query = self._prepare_query().object_classes(referring_classes)
            referring_query = referring_query.lookup_attrs_in(['member-of'], [rpsl_pk])

            normalised_mbrs_by_ref = [m.strip().upper() for m in mbrs_by_ref]
            if 'ANY' not in normalised_mbrs_by_ref:
                referring_query = referring_query.lookup_attrs_in(['mnt-by'], mbrs_by_ref)

            for referring_object in self.database_handler.execute_query(referring_query):
                referring_class = referring_object['object_class']
                members.add(referring_object['parsed_data'][referring_class])

        leaf_members = set_names - resolved_pks
        return members, leaf_members

    def handle_irrd_database_serial_range(self, parameter: str) -> str:
        """
        !j query - database serial range.

        The parameter is a comma-separated list of sources, or "-*" for the
        default sources (all valid sources when no default is configured).
        Each source with a status record yields a line
        "SOURCE:Y|N:oldest-newest[:last export]"; each requested source
        that is not valid yields "SOURCE:X:Database unknown".
        """
        if parameter == '-*':
            sources = get_setting('sources_default') or self.all_valid_sources
        else:
            sources = [name.upper() for name in parameter.split(',')]
        invalid_sources = [name for name in sources if name not in self.all_valid_sources]

        query = DatabaseStatusQuery().sources(sources)
        status_rows = self.database_handler.execute_query(query)

        lines = []
        for status_row in status_rows:
            source = status_row['source'].upper()
            keep_journal = 'Y' if get_setting(f'sources.{source}.keep_journal') else 'N'
            oldest = status_row['serial_oldest_journal']
            newest = status_row['serial_newest_journal']
            # Without both journal serials there is no range to report.
            serial_range = f'{oldest}-{newest}' if oldest and newest else '-'
            fields = [source, keep_journal, serial_range]
            if status_row['serial_last_export']:
                fields.append(str(status_row['serial_last_export']))
            lines.append(':'.join(fields) + '\n')

        for invalid_source in invalid_sources:
            lines.append(f'{invalid_source.upper()}:X:Database unknown\n')
        return ''.join(lines).strip()

    def handle_irrd_exact_key(self, parameter: str):
        """
        !m query - exact object key lookup, e.g. !maut-num,AS65537

        The parameter is "<object class>,<primary key>"; everything after
        the first comma is the key. Only the first matching object is
        returned. Raises WhoisQueryParserException when there is no comma.
        """
        object_class, comma, rpsl_pk = parameter.partition(',')
        if not comma:
            raise WhoisQueryParserException(f'Invalid argument for object lookup: {parameter}')
        query = self._prepare_query().object_classes([object_class])
        query = query.rpsl_pk(rpsl_pk).first_only()
        return self._execute_query_flatten_output(query)

    def handle_irrd_route_search(self, parameter: str):
        """
        !r query - route search with various options:
           !r192.0.2.0/24 returns all exact matching objects
           !r192.0.2.0/24,o returns space-separated origins of all exact matching objects
           !r192.0.2.0/24,l returns all one-level less specific objects, not including exact
           !r192.0.2.0/24,L returns all less specific objects, including exact
           !r192.0.2.0/24,M returns all more specific objects, not including exact

        Raises WhoisQueryParserException when the address is not a valid
        IP address/prefix or when the option is not one of the above.
        """
        # Split on the first comma only: with additional commas the
        # remainder ends up in the option and is rejected below as an
        # invalid option, rather than failing with an unhandled unpacking
        # error.
        address_text, comma, option_text = parameter.partition(',')
        option: Optional[str] = option_text if comma else None

        try:
            address = IP(address_text)
        except ValueError:
            raise WhoisQueryParserException(f'Invalid input for route search: {parameter}')

        query = self._prepare_query().object_classes(['route', 'route6'])
        if option is None or option == 'o':
            query = query.ip_exact(address)
        elif option == 'l':
            query = query.ip_less_specific_one_level(address)
        elif option == 'L':
            query = query.ip_less_specific(address)
        elif option == 'M':
            query = query.ip_more_specific(address)
        else:
            raise WhoisQueryParserException(f'Invalid route search option: {option}')

        if option == 'o':
            query_result = self.database_handler.execute_query(query)
            origins = [row['parsed_data']['origin'] for row in query_result]
            return ' '.join(origins)
        return self._execute_query_flatten_output(query)

    def handle_irrd_sources_list(self, parameter: str) -> Optional[str]:
        """
        !s query - set or list used sources
           !s-lc returns the default sources, or all valid sources when no
                 default is configured, comma-separated
           !sripe,nttcom limits sources to ripe and nttcom

        Raises WhoisQueryParserException when the parameter is empty or
        names a source that is not valid.
        """
        if parameter == '-lc':
            sources_selected = get_setting('sources_default') or self.all_valid_sources
            return ','.join(sources_selected)

        if not parameter:
            raise WhoisQueryParserException("One or more selected sources are unavailable.")

        requested = parameter.upper().split(',')
        for source in requested:
            if source not in self.all_valid_sources:
                raise WhoisQueryParserException("One or more selected sources are unavailable.")
        self.sources = requested
        return None

    def handle_irrd_version(self):
        """!v query - return the version banner of this IRRD instance"""
        banner = f'IRRD -- version {__version__}'
        return banner

    def handle_ripe_command(self, full_query: str) -> WhoisQueryResponse:
        """
        Process RIPE-style queries. Any query that is not explicitly an IRRD-style
        query (i.e. starts with exclamation mark) is presumed to be a RIPE query.

        The query is split on spaces and processed left to right. Components
        starting with a dash are flags, which may consume the components
        that follow as arguments; a search flag (-l/-L/-M/-x, -i, -t) ends
        processing. Any other component is run as a free text search.
        Raises WhoisQueryParserException for an unknown flag or a flag
        that is missing its argument.
        """
        components = re.sub(' +', ' ', full_query).strip().split(' ')
        response_type = WhoisQueryResponseType.SUCCESS
        result = None

        while components:
            component = components.pop(0)
            if not component.startswith('-'):
                # assume query to be a free text search
                result = self.handle_ripe_text_search(component)
                continue

            flag = component[1:]
            try:
                # Search flags produce the result and stop further processing.
                if flag in ('l', 'L', 'M', 'x'):
                    result = self.handle_ripe_route_search(flag, components.pop(0))
                    if not result:
                        response_type = WhoisQueryResponseType.KEY_NOT_FOUND
                    break
                if flag == 'i':
                    attribute = components.pop(0)
                    value = components.pop(0)
                    result = self.handle_inverse_attr_search(attribute, value)
                    if not result:
                        response_type = WhoisQueryResponseType.KEY_NOT_FOUND
                    break
                if flag == 't':
                    result = self.handle_ripe_request_object_template(components.pop(0))
                    break

                # The remaining flags only change state for this or later queries.
                if flag == 'k':
                    self.multiple_command_mode = True
                elif flag == 's':
                    self.handle_ripe_sources_list(components.pop(0))
                elif flag == 'a':
                    self.handle_ripe_sources_list(None)
                elif flag == 'T':
                    self.handle_ripe_restrict_object_class(components.pop(0))
                elif flag == 'K':
                    self.handle_ripe_key_fields_only()
                elif flag == 'V':
                    self.handle_user_agent(components.pop(0))
                elif flag not in ('F', 'r'):
                    # -F and -r disable recursion, but IRRd never performs
                    # recursion anyways, so those are accepted and ignored.
                    raise WhoisQueryParserException(f'Unrecognised flag/search: {flag}')
            except IndexError:
                raise WhoisQueryParserException(f'Missing argument for flag/search: {flag}')

        return WhoisQueryResponse(
            response_type=response_type,
            mode=WhoisQueryResponseMode.RIPE,
            result=result,
        )

    def handle_ripe_route_search(self, command: str, parameter: str) -> str:
        """
        -l/L/M/x query - route search for:
           -x 192.0.2.0/24 returns all exact matching objects
           -l 192.0.2.0/24 returns all one-level less specific objects, not including exact
           -L 192.0.2.0/24 returns all less specific objects, including exact
           -M 192.0.2.0/24 returns all more specific objects, not including exact

        Raises WhoisQueryParserException when parameter is not a valid IP
        address/prefix.
        """
        try:
            address = IP(parameter)
        except ValueError:
            raise WhoisQueryParserException(f'Invalid input for route search: {parameter}')

        # Name of the query filter method for each search flag.
        filter_for_command = {
            'x': 'ip_exact',
            'l': 'ip_less_specific_one_level',
            'L': 'ip_less_specific',
            'M': 'ip_more_specific',
        }
        query = self._prepare_query().object_classes(['route', 'route6'])
        if command in filter_for_command:
            apply_filter = getattr(query, filter_for_command[command])
            query = apply_filter(address)

        return self._execute_query_flatten_output(query)

    def handle_ripe_sources_list(self, sources_list: Optional[str]) -> None:
        """
        -s/-a parameter - set sources list. Empty list enables all sources.

        sources_list is a comma-separated list of source names, matched
        case-insensitively. Raises WhoisQueryParserException when one of
        them is not a valid source.
        """
        if not sources_list:
            self.sources = []
            return

        requested = sources_list.upper().split(',')
        unavailable = [source for source in requested if source not in self.all_valid_sources]
        if unavailable:
            raise WhoisQueryParserException("One or more selected sources are unavailable.")
        self.sources = requested

    def handle_ripe_restrict_object_class(self, object_classes) -> None:
        """-T parameter - restrict object classes for this query, comma-separated"""
        requested_classes = object_classes.split(',')
        self.object_classes = requested_classes

    def handle_ripe_request_object_template(self, object_class) -> str:
        """
        -t query - return the RPSL template for an object class.

        Raises WhoisQueryParserException for an unknown object class.
        """
        try:
            rpsl_class = OBJECT_CLASS_MAPPING[object_class]
            return rpsl_class().generate_template()
        except KeyError:
            raise WhoisQueryParserException(f'Unknown object class: {object_class}')

    def handle_ripe_key_fields_only(self) -> None:
        """
        -K parameter - only return primary key and members fields.

        Sets the key_fields_only flag, which is consulted when query
        output is flattened.
        """
        self.key_fields_only = True

    def handle_ripe_text_search(self, value: str) -> str:
        """Free text search - return the flattened objects matching value."""
        text_query = self._prepare_query().text_search(value)
        return self._execute_query_flatten_output(text_query)

    def handle_user_agent(self, user_agent: str):
        """
        -V/!n parameter/query - set a user agent for the client.

        The user agent is stored on the parser and logged together with
        the peer.
        """
        self.user_agent = user_agent
        message = f'{self.peer}: user agent set to: {user_agent}'
        logger.info(message)

    def handle_inverse_attr_search(self, attribute: str, value: str) -> str:
        """
        -i/!o query - inverse search for attribute values
        e.g. `-i mnt-by FOO` finds all objects where (one of the) maintainer(s) is FOO,
        as does `!oFOO`. Restricted to designated lookup fields.

        Returns the flattened text of the matching objects. Raises
        WhoisQueryParserException when attribute is not a lookup field.
        """
        if attribute not in self.lookup_field_names:
            readable_lookup_field_names = ", ".join(self.lookup_field_names)
            # Note the space after the comma: the two parts used to be
            # glued together in the message shown to the user.
            msg = (f'Inverse attribute search not supported for {attribute}, '
                   f'only supported for attributes: {readable_lookup_field_names}')
            raise WhoisQueryParserException(msg)
        query = self._prepare_query().lookup_attr(attribute, value)
        return self._execute_query_flatten_output(query)

    def _prepare_query(self) -> RPSLDatabaseQuery:
        """
        Create an RPSLDatabaseQuery with the session's filters applied.

        The sources selected for the session are used if there are any,
        otherwise the sources_default setting if that is set. The object
        class restriction of the current query is applied when present.
        """
        query = RPSLDatabaseQuery()

        # The default is only looked up when the session selected no sources.
        selected_sources = self.sources or list(get_setting('sources_default') or [])
        if selected_sources:
            query.sources(selected_sources)

        if self.object_classes:
            query.object_classes(self.object_classes)
        return query

    def _execute_query_flatten_output(self, query: RPSLDatabaseQuery) -> str:
        """
        Run an RPSLDatabaseQuery and flatten the outcome into one string,
        for easy passing to a WhoisQueryResponse.

        With key_fields_only set, only the key fields of each object are
        included; otherwise the full object texts, each followed by a
        newline. Leading and trailing newlines are stripped from the result.
        """
        rows = self.database_handler.execute_query(query)
        if self.key_fields_only:
            flattened = self._filter_key_fields(rows)
        else:
            flattened = ''.join(row['object_text'] + '\n' for row in rows)
        return flattened.strip('\n\r')

    def _filter_key_fields(self, query_response) -> str:
        """
        Render only the key fields of every object in query_response.

        The key fields are the primary key fields of the object's class
        plus members and mp-members. Each value is written as a
        "name: value" line, one line per item for list values; empty values
        are skipped. The renderings are collected in an OrderedSet and
        joined with a newline.
        """
        rendered = OrderedSet()
        for row in query_response:
            rpsl_class = OBJECT_CLASS_MAPPING[row['object_class']]
            key_field_names = rpsl_class.pk_fields + ['members', 'mp-members']

            object_lines = []
            for name in key_field_names:
                data = row['parsed_data'].get(name)
                if not data:
                    continue
                if isinstance(data, list):
                    object_lines.extend(f'{name}: {item}\n' for item in data)
                else:
                    object_lines.append(f'{name}: {data}\n')
            rendered.add(''.join(object_lines))
        return '\n'.join(rendered)
예제 #18
0
def set_force_reload(source) -> None:
    """
    Call `set_force_reload` for `source` on a new DatabaseHandler and
    commit the change.

    The handler is closed even when the update or the commit raises, so
    a failure does not leave it open.
    """
    dh = DatabaseHandler()
    try:
        dh.set_force_reload(source)
        dh.commit()
    finally:
        dh.close()
예제 #19
0
class SourceExportRunner:
    """
    This SourceExportRunner is the entry point for the export process
    for a single source.

    A gzipped file will be created in the export_destination directory
    with the contents of the source, along with a CURRENTSERIAL file.

    The contents of the source are first written to a temporary file, and
    then moved in place.
    """
    def __init__(self, source: str) -> None:
        self.source = source

    def run(self) -> None:
        """
        Run the exports configured for this source.

        An export is made to `export_destination` and/or
        `export_destination_unfiltered` when the respective setting has a
        value; the unfiltered export keeps auth hashes in the object text.
        Exceptions are logged instead of propagated, and the database
        handler is always closed afterwards.
        """
        self.database_handler = DatabaseHandler()
        try:
            export_destination = get_setting(
                f'sources.{self.source}.export_destination')
            if export_destination:
                logger.info(
                    f'Starting a source export for {self.source} to {export_destination}'
                )
                self._export(export_destination)

            export_destination_unfiltered = get_setting(
                f'sources.{self.source}.export_destination_unfiltered')
            if export_destination_unfiltered:
                logger.info(
                    f'Starting an unfiltered source export for {self.source} '
                    f'to {export_destination_unfiltered}')
                self._export(export_destination_unfiltered,
                             remove_auth_hashes=False)

            self.database_handler.commit()
        except Exception as exc:
            logger.error(
                f'An exception occurred while attempting to run an export '
                f'for {self.source}: {exc}',
                exc_info=exc)
        finally:
            self.database_handler.close()

    def _export(self, export_destination, remove_auth_hashes=True):
        """
        Write the gzipped export and the CURRENTSERIAL file for this source
        into `export_destination`.

        :param export_destination: directory in which the export files
            are placed.
        :param remove_auth_hashes: when true, each object's text is passed
            through remove_auth_hashes_func before it is written.

        The serial is read from the first row of a DatabaseStatusQuery for
        the source; when there is no such row, no CURRENTSERIAL file is
        written. In either case the serial (possibly None) is passed to
        `record_serial_exported`.
        """
        filename_export = Path(
            export_destination) / f'{self.source.lower()}.db.gz'
        filename_serial = Path(
            export_destination) / f'{self.source.upper()}.CURRENTSERIAL'

        query = DatabaseStatusQuery().source(self.source)

        try:
            serial = next(self.database_handler.execute_query(
                query))['serial_newest_seen']
        except StopIteration:
            serial = None

        # Only the name of the temporary file is needed, as gzip reopens it
        # by path. Close the handle right away so its descriptor is not leaked.
        with NamedTemporaryFile(delete=False) as export_tmpfile:
            tmp_filename = export_tmpfile.name

        try:
            with gzip.open(tmp_filename, 'wb') as fh:
                query = RPSLDatabaseQuery().sources([self.source])
                query = query.rpki_status([RPKIStatus.not_found, RPKIStatus.valid])
                query = query.scopefilter_status([ScopeFilterStatus.in_scope])
                for obj in self.database_handler.execute_query(query):
                    object_text = obj['object_text']
                    if remove_auth_hashes:
                        object_text = remove_auth_hashes_func(object_text)
                    object_bytes = object_text.encode('utf-8')
                    fh.write(object_bytes + b'\n')
                fh.write(b'# EOF\n')

            os.chmod(tmp_filename, EXPORT_PERMISSIONS)
            if filename_export.exists():
                os.unlink(filename_export)
            if filename_serial.exists():
                os.unlink(filename_serial)
            shutil.move(tmp_filename, filename_export)
        finally:
            # After a successful move the temporary file no longer exists.
            # If anything failed before that point, do not leave it behind.
            if os.path.exists(tmp_filename):
                os.unlink(tmp_filename)

        if serial is not None:
            with open(filename_serial, 'w') as fh:
                fh.write(str(serial))
            os.chmod(filename_serial, EXPORT_PERMISSIONS)

        self.database_handler.record_serial_exported(self.source, serial)
        logger.info(
            f'Export for {self.source} complete at serial {serial}, stored in {filename_export} / {filename_serial}'
        )