Esempio n. 1
0
 def _set_rpsl_query_fields(self):
     """
     Build the GraphQL argument list used by RPSL object queries and store
     it, comma-separated, in ``self.rpsl_query_fields``.

     Every lookup field name plus ``rpsl_pk``, ``sources`` and
     ``object_class`` becomes a camelCase ``[String!]`` argument, ordered by
     the sorted snake_case names. A fixed list of special-purpose arguments
     (IP/ASN filters, statuses, text search, limits, tracing) follows.
     """
     base_field_names = {'rpsl_pk', 'sources', 'object_class'}
     list_field_names = sorted(base_field_names.union(lookup_field_names()))
     string_list_params = [
         snake_to_camel_case(field_name) + ': [String!]'
         for field_name in list_field_names
     ]
     special_params = [
         'ipExact: IP',
         'ipLessSpecific: IP',
         'ipLessSpecificOneLevel: IP',
         'ipMoreSpecific: IP',
         'ipAny: IP',
         'asn: [ASN!]',
         'rpkiStatus: [RPKIStatus!]',
         'scopeFilterStatus: [ScopeFilterStatus!]',
         'textSearch: String',
         'recordLimit: Int',
         'sqlTrace: Boolean',
     ]
     self.rpsl_query_fields = ', '.join(string_list_params + special_params)
Esempio n. 2
0
 def __table_args__(cls):  # noqa
     """
     Build the SQLAlchemy ``__table_args__`` tuple for the RPSL objects table.

     The tuple holds, in order:
     - a unique constraint on (rpsl_pk, source, object_class),
     - btree indexes on the IP range (both column orders) and the ASN range,
     - a GiST index on ``prefix`` using ``inet_ops``,
     - one GIN index per lookup field, on the expression
       ``parsed_data->'<field>'`` (dashes become underscores in the name).
     """
     table_args = [
         sa.UniqueConstraint(
             'rpsl_pk', 'source', 'object_class',
             name='rpsl_objects_rpsl_pk_source_class_unique',
         ),
         sa.Index('ix_rpsl_objects_ip_first_ip_last', 'ip_first', 'ip_last'),
         sa.Index('ix_rpsl_objects_ip_last_ip_first', 'ip_last', 'ip_first'),
         sa.Index('ix_rpsl_objects_asn_first_asn_last', 'asn_first', 'asn_last'),
         sa.Index(
             'ix_rpsl_objects_prefix_gist',
             sa.text('prefix inet_ops'),
             postgresql_using='gist',
         ),
     ]
     table_args.extend(
         sa.Index(
             'ix_rpsl_objects_parsed_data_' + field_name.replace('-', '_'),
             sa.text(f"(parsed_data->'{field_name}')"),
             postgresql_using='gin',
         )
         for field_name in lookup_field_names()
     )
     return tuple(table_args)
Esempio n. 3
0
 def __table_args__(cls):  # noqa
     """
     Build the SQLAlchemy ``__table_args__`` tuple.

     The first element is a unique constraint on (rpsl_pk, source). It is
     followed by one GIN index per lookup field, on the expression
     ``parsed_data->'<field>'`` (dashes become underscores in the index name).
     """
     unique_pk_source = sa.UniqueConstraint(
         'rpsl_pk', 'source', name='rpsl_objects_rpsl_pk_source_unique')
     gin_indexes = [
         sa.Index(
             'ix_rpsl_objects_parsed_data_' + field_name.replace('-', '_'),
             sa.text(f"(parsed_data->'{field_name}')"),
             postgresql_using="gin",
         )
         for field_name in lookup_field_names()
     ]
     return (unique_pk_source, *gin_indexes)
Esempio n. 4
0
class RPSLDatabaseQuery(BaseRPSLObjectDatabaseQuery):
    """
    RPSL data query builder for retrieving RPSL objects.

    Offers various ways to filter, which are always constructed in an AND query.
    For example:
        q = RPSLDatabaseQuery().sources(['NTTCOM']).asn_less_specific(65537)
    would match all objects that refer to or include AS65537 (i.e. aut-num, route,
    as-block, route6) from the NTTCOM source.

    For methods taking a prefix or IP address, this should be an IPy.IP object.
    Filter methods return the query object itself, so calls can be chained.
    """
    table = RPSLDatabaseObject.__table__
    columns = RPSLDatabaseObject.__table__.c
    # Evaluated once, when the class is created; used to validate attribute
    # names passed to lookup_attrs_in().
    lookup_field_names = lookup_field_names()

    def __init__(self):
        super().__init__()
        # Base SELECT; every filter method narrows it with additional WHERE clauses.
        self.statement = sa.select([
            self.columns.pk,
            self.columns.object_class,
            self.columns.rpsl_pk,
            self.columns.parsed_data,
            self.columns.object_text,
            self.columns.source,
        ])
        # Incremented for every text() filter so that bound parameter names
        # stay unique when several such filters are added to one statement.
        self._lookup_attr_counter = 0

    def lookup_attr(self, attr_name: str, attr_value: str):
        """
        Filter on a lookup attribute, e.g. mnt-by.
        At least one of the values for the lookup attribute must match attr_value.
        Matching is case-insensitive.
        """
        return self.lookup_attrs_in([attr_name], [attr_value])

    def lookup_attrs_in(self, attr_names: List[str], attr_values: List[str]):
        """
        Filter on one or more lookup attributes, e.g. mnt-by, or ['admin-c', 'tech-c'].

        At least one of the values for at least one of the lookup attributes must
        match one of the items in attr_values. Matching is case-insensitive:
        attribute names are lowercased and values are uppercased before use.
        If attr_names or attr_values is empty, nothing can match, so the
        query will return no rows.

        :raises ValueError: if any attribute name is not a lookup field.
        """
        attr_names = [attr_name.lower() for attr_name in attr_names]
        for attr_name in attr_names:
            if attr_name not in self.lookup_field_names:
                raise ValueError(f"Invalid lookup attribute: {attr_name}")
        self._check_query_frozen()

        value_filters = []
        statement_params = {}
        for attr_name in attr_names:
            for attr_value in attr_values:
                counter = self._lookup_attr_counter
                self._lookup_attr_counter += 1
                # Presumably parsed_data is jsonb, with `?` testing whether the
                # value occurs in that attribute's value list - TODO confirm.
                value_filters.append(sa.text(f"parsed_data->:lookup_attr_name{counter} ? :lookup_attr_value{counter}"))
                statement_params[f"lookup_attr_name{counter}"] = attr_name
                statement_params[f"lookup_attr_value{counter}"] = attr_value.upper()

        if not value_filters:
            # An empty sa.or_() contributes no condition at all, which would turn
            # "match any of no values" into "match everything". Match nothing instead.
            self.statement = self.statement.where(sa.false())
            return self

        fltr = sa.or_(*value_filters)
        self.statement = self.statement.where(fltr).params(**statement_params)

        return self

    def ip_exact(self, ip: IP):
        """
        Filter on an exact prefix or address.

        The provided ip should be an IPy.IP class, and can be a prefix or
        an address.
        """
        fltr = sa.and_(
            self.columns.ip_first == str(ip.net()),
            self.columns.ip_last == str(ip.broadcast()),
            self.columns.ip_version == ip.version()
        )
        return self._filter(fltr)

    def ip_less_specific(self, ip: IP):
        """Filter any less specifics or exact matches of a prefix."""
        fltr = sa.and_(
            self.columns.ip_first <= str(ip.net()),
            self.columns.ip_last >= str(ip.broadcast()),
            self.columns.ip_version == ip.version()
        )
        return self._filter(fltr)

    def ip_less_specific_one_level(self, ip: IP):
        """
        Filter one level less specific of a prefix.

        Due to implementation details around filtering, this must
        always be the last call on a query object, or unpredictable
        results may occur. The query is marked frozen afterwards.
        """
        self._check_query_frozen()
        # One level less specific could still have multiple objects.
        # A subquery determines the smallest possible size less specific object,
        # and this is then used to filter for any objects with that size.
        fltr = sa.and_(
            self.columns.ip_first <= str(ip.net()),
            self.columns.ip_last >= str(ip.broadcast()),
            self.columns.ip_version == ip.version(),
            sa.not_(sa.and_(self.columns.ip_first == str(ip.net()), self.columns.ip_last == str(ip.broadcast()))),
        )
        self.statement = self.statement.where(fltr)

        # The subquery reuses every filter applied so far (including the one
        # above), which is why later filters would not be reflected in it.
        size_subquery = self.statement.with_only_columns([self.columns.ip_size])
        size_subquery = size_subquery.order_by(self.columns.ip_size.asc())
        size_subquery = size_subquery.limit(1)

        self.statement = self.statement.where(self.columns.ip_size.in_(size_subquery))
        self._query_frozen = True
        return self

    def ip_more_specific(self, ip: IP):
        """Filter any more specifics of a prefix, not including exact matches.

        Note that this only finds full more specifics: objects for which their
        IP range is fully encompassed by the ip parameter.
        """
        fltr = sa.and_(
            self.columns.ip_first >= str(ip.net()),
            self.columns.ip_first <= str(ip.broadcast()),
            self.columns.ip_last <= str(ip.broadcast()),
            self.columns.ip_last >= str(ip.net()),
            self.columns.ip_version == ip.version(),
            sa.not_(sa.and_(self.columns.ip_first == str(ip.net()), self.columns.ip_last == str(ip.broadcast()))),
        )
        return self._filter(fltr)

    def asn(self, asn: int):
        """
        Filter for exact matches on an ASN.
        """
        fltr = sa.and_(self.columns.asn_first == asn, self.columns.asn_last == asn)
        return self._filter(fltr)

    def asn_less_specific(self, asn: int):
        """
        Filter for a specific ASN, or any less specific matches.

        This will match all objects that refer to this ASN, or a block
        encompassing it - including route, route6, aut-num and as-block.
        """
        fltr = sa.and_(self.columns.asn_first <= asn, self.columns.asn_last >= asn)
        return self._filter(fltr)

    def text_search(self, value: str):
        """
        Search the database for a specific free text.

        In order, this attempts:
        - If the value is a valid AS number, return all as-block, as-set, aut-num objects
          relating or including that AS number.
        - If the value is a valid IP address or network, return all objects that relate to
          that resource and any less specifics.
        - Otherwise, return all objects where the RPSL primary key is exactly this value
          (case-insensitively, by uppercasing), or it matches part of a person/role name
          (not nic-hdl, their actual person/role attribute value).
        """
        self._check_query_frozen()
        try:
            _, asn = parse_as_number(value)
            return self.object_classes(['as-block', 'as-set', 'aut-num']).asn_less_specific(asn)
        except ValidationError:
            pass

        try:
            ip = IP(value)
            return self.ip_less_specific(ip)
        except ValueError:
            pass

        counter = self._lookup_attr_counter
        self._lookup_attr_counter += 1
        # Both ILIKE clauses share one bound parameter.
        fltr = sa.or_(
            self.columns.rpsl_pk == value.upper(),
            sa.and_(
                self.columns.object_class == 'person',
                sa.text(f"parsed_data->>'person' ILIKE :lookup_attr_text_search{counter}")
            ),
            sa.and_(
                self.columns.object_class == 'role',
                sa.text(f"parsed_data->>'role' ILIKE :lookup_attr_text_search{counter}")
            ),
        )
        self.statement = self.statement.where(fltr).params(
            **{f'lookup_attr_text_search{counter}': '%' + value + '%'}
        )
        return self

    def __repr__(self):
        return f"RPSLDatabaseQuery: {self.statement}\nPARAMS: {self.statement.compile().params}"
Esempio n. 5
0
class WhoisQueryParser:
    """
    Parser for all whois-style queries.

    This parser distinguishes RIPE-style, e.g. "-K 192.0.2.1" or "-i mnt-by FOO"
    from IRRD-style, e.g. "!oFOO".

    Some query flags, particularly -k/!! and -s/!s retain state across queries,
    so a single instance of this object should be created per session, with
    handle_query() being called for each individual query.
    """
    lookup_field_names = lookup_field_names()
    database_handler: DatabaseHandler

    def __init__(self, peer: str) -> None:
        """
        :param peer: description of the connected client; used as a prefix
                     in log messages.
        """
        self.all_valid_sources = list(get_setting('sources').keys())
        self.sources: List[str] = []
        self.object_classes: List[str] = []
        self.user_agent: Optional[str] = None
        self.multiple_command_mode = False
        self.key_fields_only = False
        self.peer = peer
        # Source of the root set object during !i resolution; reset for every !i query.
        self._current_set_priority_source: Optional[str] = None

    def handle_query(self, query: str) -> WhoisQueryResponse:
        """
        Process a single query and return a WhoisQueryResponse.

        Queries starting with '!' are handled as IRRD-style, anything else as
        RIPE-style. A WhoisQueryParserException is logged and turned into an
        ERROR response; other exceptions propagate. A new DatabaseHandler is
        created for every query and always closed afterwards.
        """
        # These flags are reset with every query.
        self.database_handler = DatabaseHandler()
        self.key_fields_only = False
        self.object_classes = []

        if query.startswith('!'):
            try:
                return self.handle_irrd_command(query[1:])
            except WhoisQueryParserException as exc:
                logger.info(
                    f'{self.peer}: encountered parsing error while parsing query {query}: {exc}'
                )
                return WhoisQueryResponse(
                    response_type=WhoisQueryResponseType.ERROR,
                    mode=WhoisQueryResponseMode.IRRD,
                    result=str(exc))
            finally:
                self.database_handler.close()

        try:
            return self.handle_ripe_command(query)
        except WhoisQueryParserException as exc:
            logger.info(
                f'{self.peer}: encountered parsing error while parsing query {query}: {exc}'
            )
            return WhoisQueryResponse(
                response_type=WhoisQueryResponseType.ERROR,
                mode=WhoisQueryResponseMode.RIPE,
                result=str(exc))
        finally:
            self.database_handler.close()

    def handle_irrd_command(self, full_command: str) -> WhoisQueryResponse:
        """Handle an IRRD-style query. full_command should not include the first exclamation mark. """
        if not full_command:
            raise WhoisQueryParserException('Missing IRRD command')
        command = full_command[0].upper()
        parameter = full_command[1:]
        response_type = WhoisQueryResponseType.SUCCESS
        result = None

        if command == '!':
            self.multiple_command_mode = True
            result = None
        elif command == 'G':
            result = self.handle_irrd_routes_for_origin_v4(parameter)
            if not result:
                response_type = WhoisQueryResponseType.KEY_NOT_FOUND
        elif command == '6':
            result = self.handle_irrd_routes_for_origin_v6(parameter)
            if not result:
                response_type = WhoisQueryResponseType.KEY_NOT_FOUND
        elif command == 'I':
            result = self.handle_irrd_set_members(parameter)
            if not result:
                response_type = WhoisQueryResponseType.KEY_NOT_FOUND
        elif command == 'J':
            result = self.handle_irrd_database_serial_range(parameter)
        elif command == 'M':
            result = self.handle_irrd_exact_key(parameter)
            if not result:
                response_type = WhoisQueryResponseType.KEY_NOT_FOUND
        elif command == 'N':
            self.handle_user_agent(parameter)
        elif command == 'O':
            result = self.handle_inverse_attr_search('mnt-by', parameter)
            if not result:
                response_type = WhoisQueryResponseType.KEY_NOT_FOUND
        elif command == 'R':
            result = self.handle_irrd_route_search(parameter)
            if not result:
                response_type = WhoisQueryResponseType.KEY_NOT_FOUND
        elif command == 'S':
            result = self.handle_irrd_sources_list(parameter)
        elif command == 'V':
            result = self.handle_irrd_version()
        else:
            raise WhoisQueryParserException(f'Unrecognised command: {command}')

        return WhoisQueryResponse(
            response_type=response_type,
            mode=WhoisQueryResponseMode.IRRD,
            result=result,
        )

    def handle_irrd_routes_for_origin_v4(self, origin: str) -> str:
        """!g query - find all originating IPv4 prefixes from an origin, e.g. !gAS65537"""
        return self._routes_for_origin('route', origin)

    def handle_irrd_routes_for_origin_v6(self, origin: str) -> str:
        """!6 query - find all originating IPv6 prefixes from an origin, e.g. !6as65537"""
        return self._routes_for_origin('route6', origin)

    def _routes_for_origin(self, object_class: str, origin: str) -> str:
        """
        Resolve all route(6)s for an origin, returning a space-separated list
        of all originating prefixes, not including duplicates.

        :raises WhoisQueryParserException: if origin is not a valid AS number.
        """
        try:
            _, asn = parse_as_number(origin)
        except ValidationError as ve:
            raise WhoisQueryParserException(str(ve)) from ve

        query = self._prepare_query().object_classes([object_class]).asn(asn)
        query_result = self.database_handler.execute_query(query)

        prefixes = [r['parsed_data'][object_class] for r in query_result]
        # dict.fromkeys() de-duplicates while keeping first-seen order.
        unique_prefixes = list(dict.fromkeys(prefixes))

        return ' '.join(unique_prefixes)

    def handle_irrd_set_members(self, parameter: str) -> str:
        """
        !i query - find all members of an as-set or route-set, possibly recursively.
        e.g. !iAS-FOO for non-recursive, !iAS-FOO,1 for recursive

        Returns the members sorted and space-separated.
        """
        recursive = False
        if parameter.endswith(',1'):
            recursive = True
            parameter = parameter[:-2]

        self._current_set_priority_source = None
        if not recursive:
            members, leaf_members = self._find_set_members({parameter})
            members.update(leaf_members)
        else:
            members = self._recursive_set_resolve({parameter})
            # The queried set itself is not one of its own members.
            members.discard(parameter)
        return ' '.join(sorted(members))

    def _recursive_set_resolve(self,
                               members: Set[str],
                               sets_seen: Optional[Set[str]] = None) -> Set[str]:
        """
        Resolve all members of a number of sets, recursively.

        For each set in members, determines whether it has been seen already (to prevent
        infinite recursion), ignores it if already seen, and then either adds
        it directly or adds it to a set that requires further resolving.
        Members that parse as an IP prefix or an AS number are final; anything
        else is treated as a further set name to resolve.
        """
        if sets_seen is None:
            sets_seen = set()

        if all(member in sets_seen for member in members):
            return set()
        sets_seen.update(members)

        set_members = set()
        sub_members, leaf_members = self._find_set_members(members)
        set_members.update(leaf_members)

        for sub_member in sub_members:
            try:
                IP(sub_member)
                set_members.add(sub_member)
                continue
            except ValueError:
                pass
            try:
                parse_as_number(sub_member)
                set_members.add(sub_member)
                continue
            except (ValueError, ValidationError):
                # parse_as_number() signals invalid input with ValidationError
                # (as handled elsewhere in this class); such names are treated
                # as set references and resolved further below.
                pass

        further_resolving_required = sub_members - set_members - sets_seen
        new_members = self._recursive_set_resolve(further_resolving_required,
                                                  sets_seen)
        set_members.update(new_members)

        return set_members

    def _find_set_members(self,
                          set_names: Set[str]) -> Tuple[Set[str], Set[str]]:
        """
        Find all members of a number of route-sets or as-sets. Includes both
        direct members listed in members attribute, but also
        members included by mbrs-by-ref/member-of.

        Returns a tuple of two sets:
        - members found of the sets included in set_names, both
          references to other sets and direct AS numbers, etc.
        - leaf members that were included in set_names, i.e.
          names for which no further data could be found - for
          example references to non-existent other sets
        """
        members: Set[str] = set()
        sets_already_resolved: Set[str] = set()

        query = self._prepare_query().object_classes(['as-set', 'route-set'
                                                      ]).rpsl_pks(set_names)
        if self._current_set_priority_source:
            query.prioritise_source(self._current_set_priority_source)
        query_result = list(self.database_handler.execute_query(query))

        if not query_result:
            # No sub-members are found, and apparently all inputs were leaf members.
            return set(), set_names

        # Track the source of the root object set
        if not self._current_set_priority_source:
            self._current_set_priority_source = query_result[0]['source']

        for result in query_result:
            rpsl_pk = result['rpsl_pk']

            # The same PK may occur in multiple sources, but we are
            # only interested in the first matching object, prioritised
            # to look for the same source as the root object. This priority
            # is part of the query ORDER BY, so basically we only process
            # an RPSL pk once.
            if rpsl_pk in sets_already_resolved:
                continue
            sets_already_resolved.add(rpsl_pk)

            object_class = result['object_class']
            object_data = result['parsed_data']
            mbrs_by_ref = object_data.get('mbrs-by-ref', None)
            for members_attr in ['members', 'mp-members']:
                if members_attr in object_data:
                    members.update(set(object_data[members_attr]))

            if not rpsl_pk or not object_class or not mbrs_by_ref:
                continue

            # If mbrs-by-ref is set, find any objects with member-of pointing to the route/as-set
            # under query, and include a maintainer listed in mbrs-by-ref, unless mbrs-by-ref
            # is set to ANY.
            query_object_class = [
                'route', 'route6'
            ] if object_class == 'route-set' else ['aut-num']
            query = self._prepare_query().object_classes(query_object_class)
            query = query.lookup_attrs_in(['member-of'], [rpsl_pk])

            if 'ANY' not in [m.strip().upper() for m in mbrs_by_ref]:
                query = query.lookup_attrs_in(['mnt-by'], mbrs_by_ref)

            referring_objects = self.database_handler.execute_query(query)

            for referring_object in referring_objects:
                member_object_class = referring_object['object_class']
                members.add(referring_object['parsed_data'][member_object_class])

        leaf_members = set_names - sets_already_resolved
        return members, leaf_members

    def handle_irrd_database_serial_range(self, parameter: str) -> str:
        """
        !j query - database serial range, e.g. !jNTTCOM,RIPE or !j-* for all sources.

        Returns one line per source; unknown sources are reported with 'X'.
        """
        if parameter == '-*':
            sources = self.all_valid_sources
        else:
            sources = [s.upper() for s in parameter.split(',')]
        invalid_sources = [
            s for s in sources if s not in self.all_valid_sources
        ]
        query = RPSLDatabaseStatusQuery().sources(sources)
        query_results = self.database_handler.execute_query(query)

        result_txt = ''
        for query_result in query_results:
            source = query_result['source'].upper()
            keep_journal = 'Y' if get_setting(
                f'sources.{source}.keep_journal') else 'N'
            serial_oldest = query_result['serial_oldest_journal']
            serial_newest = query_result['serial_newest_journal']
            fields = [
                source,
                keep_journal,
                f'{serial_oldest}-{serial_newest}'
                if serial_oldest and serial_newest else '-',
            ]
            if query_result['serial_last_dump']:
                fields.append(str(query_result['serial_last_dump']))
            result_txt += ':'.join(fields) + '\n'

        for invalid_source in invalid_sources:
            result_txt += f'{invalid_source.upper()}:X:Database unknown\n'
        return result_txt.strip()

    def handle_irrd_exact_key(self, parameter: str) -> str:
        """!m query - exact object key lookup, e.g. !maut-num,AS65537"""
        try:
            object_class, rpsl_pk = parameter.split(',', maxsplit=1)
        except ValueError:
            raise WhoisQueryParserException(
                f'Invalid argument for object lookup: {parameter}')
        query = self._prepare_query().object_classes(
            [object_class]).rpsl_pk(rpsl_pk).first_only()
        return self._execute_query_flatten_output(query)

    def handle_irrd_route_search(self, parameter: str) -> str:
        """
        !r query - route search with various options:
           !r192.0.2.0/24 returns all exact matching objects
           !r192.0.2.0/24,o returns space-separated origins of all exact matching objects
           !r192.0.2.0/24,l returns all one-level less specific objects, not including exact
           !r192.0.2.0/24,L returns all less specific objects, including exact
           !r192.0.2.0/24,M returns all more specific objects, not including exact
        """
        option: Optional[str] = None
        if ',' in parameter:
            # maxsplit=1: extra commas end up in the option and are rejected
            # below, rather than raising an uncaught unpacking ValueError.
            address_text, option = parameter.split(',', maxsplit=1)
        else:
            address_text = parameter
        try:
            address = IP(address_text)
        except ValueError:
            raise WhoisQueryParserException(
                f'Invalid input for route search: {parameter}')

        query = self._prepare_query().object_classes(['route', 'route6'])
        if option is None or option == 'o':
            query = query.ip_exact(address)
        elif option == 'l':
            query = query.ip_less_specific_one_level(address)
        elif option == 'L':
            query = query.ip_less_specific(address)
        elif option == 'M':
            query = query.ip_more_specific(address)
        else:
            raise WhoisQueryParserException(
                f'Invalid route search option: {option}')

        if option == 'o':
            query_result = self.database_handler.execute_query(query)
            prefixes = [r['parsed_data']['origin'] for r in query_result]
            return ' '.join(prefixes)
        return self._execute_query_flatten_output(query)

    def handle_irrd_sources_list(self, parameter: str) -> Optional[str]:
        """
        !s query - set used sources
           !s-lc returns all enabled sources, comma separated
           !sripe,nttcom limits sources to ripe and nttcom
        """
        if parameter == '-lc':
            return ','.join(self.all_valid_sources)
        if parameter:
            sources = parameter.upper().split(',')
            if not all(
                [source in self.all_valid_sources for source in sources]):
                raise WhoisQueryParserException(
                    "One or more selected sources are unavailable.")
            self.sources = sources
        else:
            raise WhoisQueryParserException(
                "One or more selected sources are unavailable.")
        return None

    def handle_irrd_version(self):
        """!v query - return version"""
        return f'IRRD -- version {__version__}'

    def handle_ripe_command(self, full_query: str) -> WhoisQueryResponse:
        """
        Process RIPE-style queries. Any query that is not explicitly an IRRD-style
        query (i.e. starts with exclamation mark) is presumed to be a RIPE query.

        Flags are matched exactly; combined flags (e.g. "-lL") or a bare "-"
        are rejected as unrecognised.
        """
        full_query = re.sub(' +', ' ', full_query)
        components = full_query.strip().split(' ')
        result = None
        response_type = WhoisQueryResponseType.SUCCESS

        while components:
            component = components.pop(0)
            if component.startswith('-'):
                command = component[1:]
                try:
                    if command == 'k':
                        self.multiple_command_mode = True
                    elif command in ('l', 'L', 'M', 'x'):
                        result = self.handle_ripe_route_search(
                            command, components.pop(0))
                        if not result:
                            response_type = WhoisQueryResponseType.KEY_NOT_FOUND
                        break
                    elif command == 'i':
                        result = self.handle_inverse_attr_search(
                            components.pop(0), components.pop(0))
                        if not result:
                            response_type = WhoisQueryResponseType.KEY_NOT_FOUND
                        break
                    elif command == 's':
                        self.handle_ripe_sources_list(components.pop(0))
                    elif command == 'a':
                        self.handle_ripe_sources_list(None)
                    elif command == 'T':
                        self.handle_ripe_restrict_object_class(
                            components.pop(0))
                    elif command == 't':
                        result = self.handle_ripe_request_object_template(
                            components.pop(0))
                        break
                    elif command == 'K':
                        self.handle_ripe_key_fields_only()
                    elif command == 'V':
                        self.handle_user_agent(components.pop(0))
                    elif command in ('F', 'r'):
                        continue  # These flags disable recursion, but IRRd never performs recursion anyways
                    else:
                        raise WhoisQueryParserException(
                            f'Unrecognised flag/search: {command}')
                except IndexError:
                    raise WhoisQueryParserException(
                        f'Missing argument for flag/search: {command}')
            else:  # assume query to be a free text search
                result = self.handle_ripe_text_search(component)

        return WhoisQueryResponse(
            response_type=response_type,
            mode=WhoisQueryResponseMode.RIPE,
            result=result,
        )

    def handle_ripe_route_search(self, command: str, parameter: str) -> str:
        """
        -l/L/M/x query - route search for:
           -x 192.0.2.0/24 returns all exact matching objects
           -l 192.0.2.0/24 returns all one-level less specific objects, not including exact
           -L 192.0.2.0/24 returns all less specific objects, including exact
           -M 192.0.2.0/24 returns all more specific objects, not including exact

        :raises WhoisQueryParserException: on an invalid address or command.
        """
        try:
            address = IP(parameter)
        except ValueError:
            raise WhoisQueryParserException(
                f'Invalid input for route search: {parameter}')

        query = self._prepare_query().object_classes(['route', 'route6'])
        if command == 'x':
            query = query.ip_exact(address)
        elif command == 'l':
            query = query.ip_less_specific_one_level(address)
        elif command == 'L':
            query = query.ip_less_specific(address)
        elif command == 'M':
            query = query.ip_more_specific(address)
        else:
            # Without this, an unknown command would run an unfiltered query
            # over all route/route6 objects.
            raise WhoisQueryParserException(
                f'Invalid route search option: {command}')

        return self._execute_query_flatten_output(query)

    def handle_ripe_sources_list(self, sources_list: Optional[str]) -> None:
        """-s/-a parameter - set sources list. Empty list enables all sources. """
        if sources_list:
            sources = sources_list.upper().split(',')
            if not all(
                [source in self.all_valid_sources for source in sources]):
                raise WhoisQueryParserException(
                    "One or more selected sources are unavailable.")
            self.sources = sources
        else:
            self.sources = []

    def handle_ripe_restrict_object_class(self, object_classes) -> None:
        """-T parameter - restrict object classes for this query, comma-separated"""
        self.object_classes = object_classes.split(',')

    def handle_ripe_request_object_template(self, object_class) -> str:
        """-t query - return the RPSL template for an object class"""
        try:
            return OBJECT_CLASS_MAPPING[object_class]().generate_template()
        except KeyError:
            raise WhoisQueryParserException(
                f'Unknown object class: {object_class}')

    def handle_ripe_key_fields_only(self) -> None:
        """-K parameter - only return primary key and members fields"""
        self.key_fields_only = True

    def handle_ripe_text_search(self, value: str) -> str:
        """Free text search - see RPSLDatabaseQuery.text_search()."""
        query = self._prepare_query().text_search(value)
        return self._execute_query_flatten_output(query)

    def handle_user_agent(self, user_agent: str):
        """-V/!n parameter/query - set a user agent for the client"""
        self.user_agent = user_agent
        logger.info(f'{self.peer}: user agent set to: {user_agent}')

    def handle_inverse_attr_search(self, attribute: str, value: str) -> str:
        """
        -i/!o query - inverse search for attribute values
        e.g. `-i mnt-by FOO` finds all objects where (one of the) maintainer(s) is FOO,
        as does `!oFOO`. Restricted to designated lookup fields.
        """
        if attribute not in self.lookup_field_names:
            readable_lookup_field_names = ", ".join(self.lookup_field_names)
            msg = (
                f'Inverse attribute search not supported for {attribute}, ' +
                f'only supported for attributes: {readable_lookup_field_names}'
            )
            raise WhoisQueryParserException(msg)
        query = self._prepare_query().lookup_attr(attribute, value)
        return self._execute_query_flatten_output(query)

    def _prepare_query(self) -> RPSLDatabaseQuery:
        """Prepare an RPSLDatabaseQuery by applying relevant sources/class filters."""
        query = RPSLDatabaseQuery()
        if self.sources:
            query.sources(self.sources)
        if self.object_classes:
            query.object_classes(self.object_classes)
        return query

    def _execute_query_flatten_output(self, query: RPSLDatabaseQuery) -> str:
        """
        Execute an RPSLDatabaseQuery, and flatten the output into a string with object text
        for easy passing to a WhoisQueryResponse.
        """
        query_response = self.database_handler.execute_query(query)
        if self.key_fields_only:
            result = self._filter_key_fields(query_response)
        else:
            result = ''.join(obj['object_text'] + '\n' for obj in query_response)
        return result.strip('\n\r')

    def _filter_key_fields(self, query_response) -> str:
        """
        Render only the primary key fields and members/mp-members/member-of
        of each object, one "name: value" line per value, with a blank line
        after each object.
        """
        result = ''
        for obj in query_response:
            rpsl_object_class = OBJECT_CLASS_MAPPING[obj['object_class']]
            fields_included = rpsl_object_class.pk_fields + [
                'members', 'mp-members', 'member-of'
            ]

            for field_name in fields_included:
                field_data = obj['parsed_data'].get(field_name)
                if field_data:
                    if isinstance(field_data, list):
                        for item in field_data:
                            result += f'{field_name}: {item}\n'
                    else:
                        result += f'{field_name}: {field_data}\n'
            result += '\n'
        return result
Esempio n. 6
0
            sa.Index('ix_roa_objects_prefix_gist',
                     sa.text('prefix inet_ops'),
                     postgresql_using='gist')
        ]
        return tuple(args)

    def __repr__(self):
        """Represent the object compactly as ``<prefix/asn>``."""
        return '<{}/{}>'.format(self.prefix, self.asn)


# Before you update this, please check the storage documentation for changing lookup fields.
# Import-time guard: raises RuntimeError when lookup_field_names() no longer
# matches this fixed list; per the error message, indexes may then be missing.
expected_lookup_field_names = {
    'admin-c',
    'tech-c',
    'zone-c',
    'member-of',
    'mnt-by',
    'role',
    'members',
    'person',
    'mp-members',
    'origin',
    'mbrs-by-ref',
}
if sorted(expected_lookup_field_names) != sorted(lookup_field_names()):  # pragma: no cover
    raise RuntimeError(
        'Field names of lookup fields do not match expected set. Indexes may be missing. '
        + f'Expected: {expected_lookup_field_names}, actual: {lookup_field_names()}'
    )
Esempio n. 7
0
from irrd.rpki.status import RPKIStatus
from irrd.rpsl.rpsl_objects import OBJECT_CLASS_MAPPING, lookup_field_names
from irrd.scopefilter.status import ScopeFilterStatus
from irrd.server.access_check import is_client_permitted
from irrd.storage.queries import RPSLDatabaseQuery, RPSLDatabaseJournalQuery
from irrd.utils.text import snake_to_camel_case, remove_auth_hashes
from .schema_generator import SchemaGenerator
from ..query_resolver import QueryResolver
"""
Resolvers resolve GraphQL queries, usually by translating them
to a database query and then translating the results to an
appropriate format for GraphQL.
"""

# Module-level SchemaGenerator instance, created once when this module is imported.
schema = SchemaGenerator()
# Lookup field names as returned by lookup_field_names(), computed once at import time.
lookup_fields = lookup_field_names()


def resolve_rpsl_object_type(obj: Dict[str, str], *_) -> str:
    """
    Return the GraphQL type name for an object, based on its object class.

    The class is read from the 'objectClass' key, falling back to
    'object_class' (and finally ''). GraphQL type names are the names of
    the RPSL classes in OBJECT_CLASS_MAPPING.
    """
    fallback_class = obj.get('object_class', '')
    object_class = obj.get('objectClass', fallback_class)
    return OBJECT_CLASS_MAPPING[object_class].__name__


@ariadne.convert_kwargs_to_snake_case
def resolve_rpsl_objects(_, info: GraphQLResolveInfo, **kwargs):
    """
    Resolve a `rpslObjects` query. This query has a considerable
Esempio n. 8
0
class WhoisQueryParser:
    """
    Parser for all whois-style queries.

    This parser distinguishes RIPE-style, e.g. '-K 192.0.2.1' or '-i mnt-by FOO'
    from IRRD-style, e.g. '!oFOO'.

    Some query flags, particularly -k/!! and -s/!s retain state across queries,
    so a single instance of this object should be created per session, with
    handle_query() being called for each individual query.
    """
    # Attribute names allowed for inverse searches (-i / !o), computed once
    # when the class is defined.
    lookup_field_names = lookup_field_names()
    database_handler: DatabaseHandler
    # Object class ('as-set' / 'route-set') of the root set during set
    # resolving, or None when it is not yet known.
    _current_set_root_object_class: Optional[str]

    def __init__(self, client_ip: str, client_str: str) -> None:
        """
        Initialise per-session state from the current settings.

        :param client_ip: IP address of the client, used for NRTM access checks.
        :param client_str: client description, used in log messages.
        """
        self.all_valid_sources = list(get_setting('sources', {}).keys())
        self.sources_default = get_setting('sources_default')
        # NOTE(review): without sources_default, self.sources is the same list
        # object as self.all_valid_sources, so the append below also adds the
        # RPKI pseudo-source to self.sources. _prepare_query relies on the
        # two comparing equal in that case - keep the statement order.
        self.sources: List[
            str] = self.sources_default if self.sources_default else self.all_valid_sources
        if get_setting('rpki.roa_source'):
            self.all_valid_sources.append(RPKI_IRR_PSEUDO_SOURCE)
        self.object_classes: List[str] = []
        self.user_agent: Optional[str] = None
        self.multiple_command_mode = False
        self.rpki_aware = bool(get_setting('rpki.roa_source'))
        self.rpki_invalid_filter_enabled = self.rpki_aware
        self.timeout = 30
        self.key_fields_only = False
        self.client_ip = client_ip
        self.client_str = client_str
        self.preloader = Preloader()
        self._preloaded_query_count = 0

    def handle_query(self, query: str) -> WhoisQueryResponse:
        """
        Process a single query. Always returns a WhoisQueryResponse object.
        Not thread safe - only one call must be made to this method at the same time.

        Queries starting with '!' are handled as IRRD-style, anything else as
        RIPE-style. Parser errors become ERROR responses with the error text;
        any other exception is logged and reported as an internal error.
        The database handler is closed after every query.
        """
        # These flags are reset with every query.
        self.database_handler = DatabaseHandler()
        self.key_fields_only = False
        self.object_classes = []

        if query.startswith('!'):
            try:
                return self.handle_irrd_command(query[1:])
            except WhoisQueryParserException as exc:
                logger.info(
                    f'{self.client_str}: encountered parsing error while parsing query "{query}": {exc}'
                )
                return WhoisQueryResponse(
                    response_type=WhoisQueryResponseType.ERROR,
                    mode=WhoisQueryResponseMode.IRRD,
                    result=str(exc))
            except Exception as exc:
                logger.error(
                    f'An exception occurred while processing whois query "{query}": {exc}',
                    exc_info=exc)
                return WhoisQueryResponse(
                    response_type=WhoisQueryResponseType.ERROR,
                    mode=WhoisQueryResponseMode.IRRD,
                    result=
                    'An internal error occurred while processing this query.')
            finally:
                self.database_handler.close()

        try:
            return self.handle_ripe_command(query)
        except WhoisQueryParserException as exc:
            logger.info(
                f'{self.client_str}: encountered parsing error while parsing query "{query}": {exc}'
            )
            return WhoisQueryResponse(
                response_type=WhoisQueryResponseType.ERROR,
                mode=WhoisQueryResponseMode.RIPE,
                result=str(exc))
        except Exception as exc:
            logger.error(
                f'An exception occurred while processing whois query "{query}": {exc}',
                exc_info=exc)
            return WhoisQueryResponse(
                response_type=WhoisQueryResponseType.ERROR,
                mode=WhoisQueryResponseMode.RIPE,
                result='An internal error occurred while processing this query.'
            )
        finally:
            self.database_handler.close()

    def handle_irrd_command(self, full_command: str) -> WhoisQueryResponse:
        """
        Handle an IRRD-style query. full_command should not include the first exclamation mark.

        The first character selects the command (case-insensitive), the rest is
        its parameter. Empty results of lookup commands are reported as
        KEY_NOT_FOUND.

        :raises WhoisQueryParserException: for an empty, unknown or
            parameterless command that requires a parameter.
        """
        if not full_command:
            raise WhoisQueryParserException('Missing IRRD command')
        command = full_command[0].upper()
        parameter = full_command[1:]
        response_type = WhoisQueryResponseType.SUCCESS
        result = None

        # A is not tested here because it is already handled in handle_irrd_routes_for_as_set
        queries_with_parameter = list('TG6IJMNORS')
        if command in queries_with_parameter and not parameter:
            raise WhoisQueryParserException(
                f'Missing parameter for {command} query')

        if command == '!':
            self.multiple_command_mode = True
            result = None
            response_type = WhoisQueryResponseType.NO_RESPONSE
        elif full_command.upper() == 'FNO-RPKI-FILTER':
            self.rpki_invalid_filter_enabled = False
            result = 'Filtering out RPKI invalids is disabled for !r and RIPE style ' \
                     'queries for the rest of this connection.'
        elif command == 'V':
            result = self.handle_irrd_version()
        elif command == 'T':
            self.handle_irrd_timeout_update(parameter)
        elif command == 'G':
            result = self.handle_irrd_routes_for_origin_v4(parameter)
            if not result:
                response_type = WhoisQueryResponseType.KEY_NOT_FOUND
        elif command == '6':
            result = self.handle_irrd_routes_for_origin_v6(parameter)
            if not result:
                response_type = WhoisQueryResponseType.KEY_NOT_FOUND
        elif command == 'A':
            result = self.handle_irrd_routes_for_as_set(parameter)
            if not result:
                response_type = WhoisQueryResponseType.KEY_NOT_FOUND
        elif command == 'I':
            result = self.handle_irrd_set_members(parameter)
            if not result:
                response_type = WhoisQueryResponseType.KEY_NOT_FOUND
        elif command == 'J':
            result = self.handle_irrd_database_serial_range(parameter)
        elif command == 'M':
            result = self.handle_irrd_exact_key(parameter)
            if not result:
                response_type = WhoisQueryResponseType.KEY_NOT_FOUND
        elif command == 'N':
            self.handle_user_agent(parameter)
        elif command == 'O':
            result = self.handle_inverse_attr_search('mnt-by', parameter)
            if not result:
                response_type = WhoisQueryResponseType.KEY_NOT_FOUND
        elif command == 'R':
            result = self.handle_irrd_route_search(parameter)
            if not result:
                response_type = WhoisQueryResponseType.KEY_NOT_FOUND
        elif command == 'S':
            result = self.handle_irrd_sources_list(parameter)
        else:
            raise WhoisQueryParserException(f'Unrecognised command: {command}')

        return WhoisQueryResponse(
            response_type=response_type,
            mode=WhoisQueryResponseMode.IRRD,
            result=result,
        )

    def handle_irrd_timeout_update(self, timeout: str) -> None:
        """
        !timeout query - update timeout in connection.

        :raises WhoisQueryParserException: if timeout is not an integer
            in the range 1-1000.
        """
        try:
            timeout_value = int(timeout)
        except ValueError:
            raise WhoisQueryParserException(
                f'Invalid value for timeout: {timeout}')

        if 0 < timeout_value <= 1000:
            self.timeout = timeout_value
        else:
            raise WhoisQueryParserException(
                f'Invalid value for timeout: {timeout}')

    def handle_irrd_routes_for_origin_v4(self, origin: str) -> str:
        """!g query - find all originating IPv4 prefixes from an origin, e.g. !gAS65537"""
        return self._routes_for_origin(origin, 4)

    def handle_irrd_routes_for_origin_v6(self, origin: str) -> str:
        """!6 query - find all originating IPv6 prefixes from an origin, e.g. !6as65537"""
        return self._routes_for_origin(origin, 6)

    def _routes_for_origin(self,
                           origin: str,
                           ip_version: Optional[int] = None) -> str:
        """
        Resolve all route(6)s prefixes for an origin, returning a space-separated list
        of all originating prefixes, not including duplicates.

        :raises WhoisQueryParserException: if origin is not a valid AS number.
        """
        try:
            origin_formatted, _ = parse_as_number(origin)
        except ValidationError as ve:
            raise WhoisQueryParserException(str(ve))

        self._preloaded_query_called()
        prefixes = self.preloader.routes_for_origins([origin_formatted],
                                                     self.sources,
                                                     ip_version=ip_version)
        return ' '.join(prefixes)

    def handle_irrd_routes_for_as_set(self, set_name: str) -> str:
        """
        !a query - find all originating prefixes for all members of an AS-set, e.g. !a4AS-FOO or !a6AS-FOO

        A leading '4' or '6' restricts the result to that IP version.

        :raises WhoisQueryParserException: if no set name remains after the
            optional version prefix.
        """
        ip_version: Optional[int] = None
        if set_name.startswith('4'):
            set_name = set_name[1:]
            ip_version = 4
        elif set_name.startswith('6'):
            set_name = set_name[1:]
            ip_version = 6

        if not set_name:
            raise WhoisQueryParserException(
                'Missing required set name for A query')

        self._preloaded_query_called()
        # Fix resolving to as-sets only: as-sets can only refer to other as-sets.
        self._current_set_root_object_class = 'as-set'
        members = self._recursive_set_resolve({set_name})
        prefixes = self.preloader.routes_for_origins(members,
                                                     self.sources,
                                                     ip_version=ip_version)
        return ' '.join(prefixes)

    def handle_irrd_set_members(self, parameter: str) -> str:
        """
        !i query - find all members of an as-set or route-set, possibly recursively.
        e.g. !iAS-FOO for non-recursive, !iAS-FOO,1 for recursive

        Returns the members sorted and space separated, excluding the queried
        set name itself. With compatibility.ipv4_only_route_set_members
        enabled, prefixes that are not valid IPv4 prefixes are dropped.
        """
        self._preloaded_query_called()
        recursive = False
        if parameter.endswith(',1'):
            recursive = True
            parameter = parameter[:-2]

        self._current_set_root_object_class = None
        if not recursive:
            members, leaf_members = self._find_set_members({parameter})
            members.update(leaf_members)
        else:
            members = self._recursive_set_resolve({parameter})
        if parameter in members:
            members.remove(parameter)

        if get_setting('compatibility.ipv4_only_route_set_members'):
            original_members = set(members)
            for member in original_members:
                try:
                    IP(member)
                except ValueError:
                    continue  # This is not a prefix, ignore.
                try:
                    IP(member, ipversion=4)
                except ValueError:
                    # This was a valid prefix, but not a valid IPv4 prefix,
                    # and should be removed.
                    members.remove(member)

        return ' '.join(sorted(members))

    def _recursive_set_resolve(self,
                               members: Set[str],
                               sets_seen: Optional[Set[str]] = None) -> Set[str]:
        """
        Resolve all members of a number of sets, recursively.

        For each set in members, determines whether it has been seen already (to prevent
        infinite recursion), ignores it if already seen, and then either adds
        it directly or adds it to a set that requires further resolving.

        :param members: names of the sets to resolve.
        :param sets_seen: names already visited; shared across the recursion.
        :return: the resolved members (prefixes and/or AS numbers).
        """
        if sets_seen is None:
            sets_seen = set()

        if members.issubset(sets_seen):
            return set()
        sets_seen.update(members)

        set_members = set()
        resolved_as_members = set()
        sub_members, leaf_members = self._find_set_members(members)

        for sub_member in sub_members:
            if self._current_set_root_object_class is None or self._current_set_root_object_class == 'route-set':
                try:
                    IP(sub_member)
                    set_members.add(sub_member)
                    continue
                except ValueError:
                    pass
            # AS numbers are permitted in route-sets and as-sets, per RFC 2622 5.3.
            # When an AS number is encountered as part of route-set resolving,
            # the prefixes originating from that AS should be added to the response.
            try:
                as_number_formatted, _ = parse_as_number(sub_member)
                if self._current_set_root_object_class == 'route-set':
                    set_members.update(
                        self.preloader.routes_for_origins(
                            [as_number_formatted], self.sources))
                    resolved_as_members.add(sub_member)
                else:
                    set_members.add(sub_member)
                continue
            except ValueError:
                pass

        further_resolving_required = sub_members - set_members - sets_seen - resolved_as_members
        new_members = self._recursive_set_resolve(further_resolving_required,
                                                  sets_seen)
        set_members.update(new_members)

        return set_members

    def _find_set_members(self,
                          set_names: Set[str]) -> Tuple[Set[str], Set[str]]:
        """
        Find all members of a number of route-sets or as-sets. Includes both
        direct members listed in members attribute, but also
        members included by mbrs-by-ref/member-of.

        Returns a tuple of two sets:
        - members found of the sets included in set_names, both
          references to other sets and direct AS numbers, etc.
        - leaf members that were included in set_names, i.e.
          names for which no further data could be found - for
          example references to non-existent other sets
        """
        members: Set[str] = set()
        sets_already_resolved: Set[str] = set()

        columns = ['parsed_data', 'rpsl_pk', 'source', 'object_class']
        query = self._prepare_query(column_names=columns)

        object_classes = ['as-set', 'route-set']
        # Per RFC 2622 5.3, route-sets can refer to as-sets,
        # but as-sets can only refer to other as-sets.
        if self._current_set_root_object_class == 'as-set':
            object_classes = [self._current_set_root_object_class]

        query = query.object_classes(object_classes).rpsl_pks(set_names)
        query_result = list(self.database_handler.execute_query(query))

        if not query_result:
            # No sub-members are found, and apparently all inputs were leaf members.
            return set(), set_names

        # Track the object class of the root object set.
        # In one case self._current_set_root_object_class may already be set
        # on the first run: when the set resolving should be fixed to one
        # type of set object.
        if not self._current_set_root_object_class:
            self._current_set_root_object_class = query_result[0][
                'object_class']

        for result in query_result:
            rpsl_pk = result['rpsl_pk']

            # The same PK may occur in multiple sources, but we are
            # only interested in the first matching object, prioritised
            # according to the source order. This priority is part of the
            # query ORDER BY, so basically we only process an RPSL pk once.
            if rpsl_pk in sets_already_resolved:
                continue
            sets_already_resolved.add(rpsl_pk)

            object_class = result['object_class']
            object_data = result['parsed_data']
            mbrs_by_ref = object_data.get('mbrs-by-ref', None)
            for members_attr in ['members', 'mp-members']:
                if members_attr in object_data:
                    members.update(set(object_data[members_attr]))

            if not rpsl_pk or not object_class or not mbrs_by_ref:
                continue

            # If mbrs-by-ref is set, find any objects with member-of pointing to the route/as-set
            # under query, and include a maintainer listed in mbrs-by-ref, unless mbrs-by-ref
            # is set to ANY.
            query_object_class = [
                'route', 'route6'
            ] if object_class == 'route-set' else ['aut-num']
            query = self._prepare_query(
                column_names=columns).object_classes(query_object_class)
            query = query.lookup_attrs_in(['member-of'], [rpsl_pk])

            if 'ANY' not in [m.strip().upper() for m in mbrs_by_ref]:
                query = query.lookup_attrs_in(['mnt-by'], mbrs_by_ref)

            referring_objects = self.database_handler.execute_query(query)

            # Distinct name, so the outer loop's `result` is not shadowed.
            for referring_object in referring_objects:
                member_object_class = referring_object['object_class']
                members.add(referring_object['parsed_data'][member_object_class])

        leaf_members = set_names - sets_already_resolved
        return members, leaf_members

    def handle_irrd_database_serial_range(self, parameter: str) -> str:
        """
        !j query - database serial range.

        parameter is '-*' for the default (or all) sources, or a comma-separated
        list of sources. Each known source produces a line
        'SOURCE:<Y|N keep_journal>:<oldest-newest or ->[:last_export]';
        each unknown source produces 'SOURCE:X:Database unknown'.
        """
        if parameter == '-*':
            sources = self.sources_default if self.sources_default else self.all_valid_sources
        else:
            sources = [s.upper() for s in parameter.split(',')]
        invalid_sources = [
            s for s in sources if s not in self.all_valid_sources
        ]
        query = DatabaseStatusQuery().sources(sources)
        query_results = self.database_handler.execute_query(query)

        result_txt = ''
        for query_result in query_results:
            source = query_result['source'].upper()
            keep_journal = 'Y' if get_setting(
                f'sources.{source}.keep_journal') else 'N'
            serial_oldest = query_result['serial_oldest_seen']
            serial_newest = query_result['serial_newest_seen']
            fields = [
                source,
                keep_journal,
                f'{serial_oldest}-{serial_newest}'
                if serial_oldest and serial_newest else '-',
            ]
            if query_result['serial_last_export']:
                fields.append(str(query_result['serial_last_export']))
            result_txt += ':'.join(fields) + '\n'

        for invalid_source in invalid_sources:
            result_txt += f'{invalid_source.upper()}:X:Database unknown\n'
        return result_txt.strip()

    def handle_irrd_exact_key(self, parameter: str) -> str:
        """
        !m query - exact object key lookup, e.g. !maut-num,AS65537

        :raises WhoisQueryParserException: if parameter has no comma.
        """
        try:
            object_class, rpsl_pk = parameter.split(',', maxsplit=1)
        except ValueError:
            raise WhoisQueryParserException(
                f'Invalid argument for object lookup: {parameter}')
        query = self._prepare_query().object_classes(
            [object_class]).rpsl_pk(rpsl_pk).first_only()
        return self._execute_query_flatten_output(query)

    def handle_irrd_route_search(self, parameter: str) -> str:
        """
        !r query - route search with various options:
           !r192.0.2.0/24 returns all exact matching objects
           !r192.0.2.0/24,o returns space-separated origins of all exact matching objects
           !r192.0.2.0/24,l returns all one-level less specific objects, not including exact
           !r192.0.2.0/24,L returns all less specific objects, including exact
           !r192.0.2.0/24,M returns all more specific objects, not including exact

        :raises WhoisQueryParserException: for an invalid address or option.
        """
        option: Optional[str] = None
        if ',' in parameter:
            # Split on the first comma only: any further commas end up in the
            # option and are rejected below as an invalid option, instead of
            # causing an unpacking ValueError (reported as an internal error).
            address, option = parameter.split(',', maxsplit=1)
        else:
            address = parameter
        try:
            address = IP(address)
        except ValueError:
            raise WhoisQueryParserException(
                f'Invalid input for route search: {parameter}')

        query = self._prepare_query(ordered_by_sources=False).object_classes(
            ['route', 'route6'])
        if option is None or option == 'o':
            query = query.ip_exact(address)
        elif option == 'l':
            query = query.ip_less_specific_one_level(address)
        elif option == 'L':
            query = query.ip_less_specific(address)
        elif option == 'M':
            query = query.ip_more_specific(address)
        else:
            raise WhoisQueryParserException(
                f'Invalid route search option: {option}')

        if option == 'o':
            query_result = self.database_handler.execute_query(query)
            prefixes = [r['parsed_data']['origin'] for r in query_result]
            return ' '.join(prefixes)
        return self._execute_query_flatten_output(query)

    def handle_irrd_sources_list(self, parameter: str) -> Optional[str]:
        """
        !s query - set used sources
           !s-lc returns all enabled sources, comma separated
           !sripe,nttcom limits sources to ripe and nttcom

        :raises WhoisQueryParserException: if any selected source is unknown.
        """
        if parameter == '-lc':
            return ','.join(self.sources)

        sources = parameter.upper().split(',')
        if not all(source in self.all_valid_sources for source in sources):
            raise WhoisQueryParserException(
                'One or more selected sources are unavailable.')
        self.sources = sources

        return None

    def handle_irrd_version(self):
        """!v query - return version"""
        return f'IRRd -- version {__version__}'

    def handle_ripe_command(self, full_query: str) -> WhoisQueryResponse:
        """
        Process RIPE-style queries. Any query that is not explicitly an IRRD-style
        query (i.e. starts with exclamation mark) is presumed to be a RIPE query.

        The query is split on spaces and processed left to right. Flags may
        consume following components as arguments; search flags end processing.
        Components that are not flags are handled as free text searches.

        :raises WhoisQueryParserException: for unknown flags or missing arguments.
        """
        full_query = re.sub(' +', ' ', full_query)
        components = full_query.strip().split(' ')
        result = None
        response_type = WhoisQueryResponseType.SUCCESS

        while len(components):
            component = components.pop(0)
            if component.startswith('-'):
                command = component[1:]
                try:
                    if command == 'k':
                        self.multiple_command_mode = True
                    elif command in ['l', 'L', 'M', 'x']:
                        result = self.handle_ripe_route_search(
                            command, components.pop(0))
                        if not result:
                            response_type = WhoisQueryResponseType.KEY_NOT_FOUND
                        break
                    elif command == 'i':
                        result = self.handle_inverse_attr_search(
                            components.pop(0), components.pop(0))
                        if not result:
                            response_type = WhoisQueryResponseType.KEY_NOT_FOUND
                        break
                    elif command == 's':
                        self.handle_ripe_sources_list(components.pop(0))
                    elif command == 'a':
                        self.handle_ripe_sources_list(None)
                    elif command == 'T':
                        self.handle_ripe_restrict_object_class(
                            components.pop(0))
                    elif command == 't':
                        result = self.handle_ripe_request_object_template(
                            components.pop(0))
                        break
                    elif command == 'K':
                        self.handle_ripe_key_fields_only()
                    elif command == 'V':
                        self.handle_user_agent(components.pop(0))
                    elif command == 'g':
                        result = self.handle_nrtm_request(components.pop(0))
                    elif command in ['F', 'r']:
                        continue  # These flags disable recursion, but IRRd never performs recursion anyways
                    else:
                        raise WhoisQueryParserException(
                            f'Unrecognised flag/search: {command}')
                except IndexError:
                    # A flag that takes an argument was the last component.
                    raise WhoisQueryParserException(
                        f'Missing argument for flag/search: {command}')
            else:  # assume query to be a free text search
                result = self.handle_ripe_text_search(component)

        return WhoisQueryResponse(
            response_type=response_type,
            mode=WhoisQueryResponseMode.RIPE,
            result=result,
        )

    def handle_ripe_route_search(self, command: str, parameter: str) -> str:
        """
        -l/L/M/x query - route search for:
           -x 192.0.2.0/24 returns all exact matching objects
           -l 192.0.2.0/24 returns all one-level less specific objects, not including exact
           -L 192.0.2.0/24 returns all less specific objects, including exact
           -M 192.0.2.0/24 returns all more specific objects, not including exact

        :raises WhoisQueryParserException: if parameter is not a valid address.
        """
        try:
            address = IP(parameter)
        except ValueError:
            raise WhoisQueryParserException(
                f'Invalid input for route search: {parameter}')

        query = self._prepare_query(ordered_by_sources=False).object_classes(
            ['route', 'route6'])
        if command == 'x':
            query = query.ip_exact(address)
        elif command == 'l':
            query = query.ip_less_specific_one_level(address)
        elif command == 'L':
            query = query.ip_less_specific(address)
        elif command == 'M':
            query = query.ip_more_specific(address)

        return self._execute_query_flatten_output(query)

    def handle_ripe_sources_list(self, sources_list: Optional[str]) -> None:
        """
        -s/-a parameter - set sources list. An empty list resets to the
        default sources (or all sources when no default is configured).

        :raises WhoisQueryParserException: if any selected source is unknown.
        """
        if sources_list:
            sources = sources_list.upper().split(',')
            if not all(source in self.all_valid_sources for source in sources):
                raise WhoisQueryParserException(
                    'One or more selected sources are unavailable.')
            self.sources = sources
        else:
            self.sources = self.sources_default if self.sources_default else self.all_valid_sources

    def handle_ripe_restrict_object_class(self, object_classes) -> None:
        """-T parameter - restrict object classes for this query, comma-separated"""
        self.object_classes = object_classes.split(',')

    def handle_ripe_request_object_template(self, object_class) -> str:
        """
        -t query - return the RPSL template for an object class

        :raises WhoisQueryParserException: for an unknown object class.
        """
        try:
            return OBJECT_CLASS_MAPPING[object_class]().generate_template()
        except KeyError:
            raise WhoisQueryParserException(
                f'Unknown object class: {object_class}')

    def handle_ripe_key_fields_only(self) -> None:
        """-K parameter - only return primary key and members fields"""
        self.key_fields_only = True

    def handle_ripe_text_search(self, value: str) -> str:
        """Free text search - return all objects matching the text search for value."""
        query = self._prepare_query(
            ordered_by_sources=False).text_search(value)
        return self._execute_query_flatten_output(query)

    def handle_user_agent(self, user_agent: str):
        """-V/!n parameter/query - set a user agent for the client"""
        self.user_agent = user_agent
        logger.info(f'{self.client_str}: user agent set to: {user_agent}')

    def handle_nrtm_request(self, param):
        """
        -g query - NRTM request in the form SOURCE:VERSION:FIRST-LAST,
        where LAST may be 'LAST' for an open-ended range.

        :raises WhoisQueryParserException: for malformed parameters, an
            unsupported NRTM version, an unknown source, a client not
            permitted by the source's nrtm_access_list, or an error from
            the NRTM generator.
        """
        try:
            source, version, serial_range = param.split(':')
        except ValueError:
            raise WhoisQueryParserException(
                'Invalid parameter: must contain three elements')

        try:
            serial_start, serial_end = serial_range.split('-')
            serial_start = int(serial_start)
            if serial_end == 'LAST':
                serial_end = None
            else:
                serial_end = int(serial_end)
        except ValueError:
            raise WhoisQueryParserException(
                f'Invalid serial range: {serial_range}')

        if version not in ['1', '3']:
            raise WhoisQueryParserException(f'Invalid NRTM version: {version}')

        source = source.upper()
        if source not in self.all_valid_sources:
            raise WhoisQueryParserException(f'Unknown source: {source}')

        if not is_client_permitted(self.client_ip,
                                   f'sources.{source}.nrtm_access_list'):
            raise WhoisQueryParserException('Access denied')

        try:
            return NRTMGenerator().generate(source, version, serial_start,
                                            serial_end, self.database_handler)
        except NRTMGeneratorException as nge:
            raise WhoisQueryParserException(str(nge))

    def handle_inverse_attr_search(self, attribute: str, value: str) -> str:
        """
        -i/!o query - inverse search for attribute values
        e.g. `-i mnt-by FOO` finds all objects where (one of the) maintainer(s) is FOO,
        as does `!oFOO`. Restricted to designated lookup fields.

        :raises WhoisQueryParserException: if attribute is not a lookup field.
        """
        if attribute not in self.lookup_field_names:
            readable_lookup_field_names = ', '.join(self.lookup_field_names)
            msg = (
                f'Inverse attribute search not supported for {attribute},' +
                f'only supported for attributes: {readable_lookup_field_names}'
            )
            raise WhoisQueryParserException(msg)
        query = self._prepare_query(ordered_by_sources=False).lookup_attr(
            attribute, value)
        return self._execute_query_flatten_output(query)

    def _prepare_query(self,
                       column_names=None,
                       ordered_by_sources=True) -> RPSLDatabaseQuery:
        """
        Prepare an RPSLDatabaseQuery by applying relevant sources/class filters.

        An explicit source selection (different from all valid sources) is
        applied as-is; otherwise the configured sources_default, if any, is
        applied. Object class and RPKI-invalid filters are added when active.
        """
        query = RPSLDatabaseQuery(column_names, ordered_by_sources)
        if self.sources and self.sources != self.all_valid_sources:
            query.sources(self.sources)
        else:
            default = list(get_setting('sources_default', []))
            if default:
                query.sources(list(default))
        if self.object_classes:
            query.object_classes(self.object_classes)
        if self.rpki_invalid_filter_enabled:
            query.rpki_status([RPKIStatus.not_found, RPKIStatus.valid])
        return query

    def _execute_query_flatten_output(self, query: RPSLDatabaseQuery) -> str:
        """
        Execute an RPSLDatabaseQuery, and flatten the output into a string with object text
        for easy passing to a WhoisQueryResponse.

        When RPKI-aware, an rpki-ov-state line is added to objects of RPKI-relevant
        classes that do not come from the RPKI pseudo-source.
        """
        query_response = self.database_handler.execute_query(query)
        if self.key_fields_only:
            result = self._filter_key_fields(query_response)
        else:
            result = ''
            for obj in query_response:
                result += obj['object_text']
                if (self.rpki_aware and obj['source'] != RPKI_IRR_PSEUDO_SOURCE
                        and obj['object_class']
                        in RPKI_RELEVANT_OBJECT_CLASSES):
                    comment = ''
                    if obj['rpki_status'] == RPKIStatus.not_found:
                        comment = ' # No ROAs found, or RPKI validation not enabled for source'
                    result += f'rpki-ov-state:  {obj["rpki_status"].name}{comment}\n'
                result += '\n'
        return result.strip('\n\r')

    def _filter_key_fields(self, query_response) -> str:
        """
        Render only the primary key fields and members/mp-members of each
        object, as 'name: value' lines. Identical renderings are emitted once,
        in first-seen order, separated by empty lines.
        """
        results: OrderedSet[str] = OrderedSet()
        for obj in query_response:
            result = ''
            rpsl_object_class = OBJECT_CLASS_MAPPING[obj['object_class']]
            fields_included = rpsl_object_class.pk_fields + [
                'members', 'mp-members'
            ]

            for field_name in fields_included:
                field_data = obj['parsed_data'].get(field_name)
                if field_data:
                    if isinstance(field_data, list):
                        for item in field_data:
                            result += f'{field_name}: {item}\n'
                    else:
                        result += f'{field_name}: {field_data}\n'
            results.add(result)
        return '\n'.join(results)

    def _preloaded_query_called(self):
        """
        Called each time the user runs a query that can be preloaded.
        From the 6th such query onwards, load the preload store into memory
        to speed up expected further queries.
        """
        self._preloaded_query_count += 1
        if self._preloaded_query_count > 5:
            self.preloader.load_routes_into_memory()
Esempio n. 9
0
class QueryResolver:
    """
    Resolver for all RPSL queries.

    Some aspects like setting sources retain state, so a single instance
    should not be shared across unrelated query sessions.

    State kept between calls:

    - ``sources``: the sources used for queries, changed through
      ``set_query_sources``.
    - ``object_class_filter``: set by ``set_object_class_filter_next_query``
      and cleared again by ``_prepare_query``, so it affects one query only.
    - the RPKI-invalid and out-of-scope filters; this class only offers
      methods to switch them off (``disable_rpki_filter``,
      ``disable_out_of_scope_filter``).
    - SQL tracing (``enable_sql_trace`` / ``retrieve_sql_trace``).
    """
    lookup_field_names = lookup_field_names()
    database_handler: DatabaseHandler
    # Per-resolve state, (re)initialised by routes_for_as_set and
    # members_for_set and consumed by the private set resolving methods.
    _current_set_root_object_class: Optional[str]
    _current_excluded_sets: Set[str]
    _current_set_maximum_depth: int

    def __init__(self, preloader: Preloader,
                 database_handler: DatabaseHandler) -> None:
        """
        :param preloader: used for origin to route prefix lookups.
        :param database_handler: used to execute all database queries.
        """
        self.all_valid_sources = list(get_setting('sources', {}).keys())
        self.sources_default = list(get_setting('sources_default', []))
        # When no sources_default is configured, self.sources is the very
        # same list object as self.all_valid_sources. The RPKI pseudo source
        # appended below therefore also ends up in the default query
        # sources, which matches what set_query_sources(None) selects.
        self.sources: List[
            str] = self.sources_default if self.sources_default else self.all_valid_sources
        if get_setting('rpki.roa_source'):
            self.all_valid_sources.append(RPKI_IRR_PSEUDO_SOURCE)
        self.object_class_filter: List[str] = []
        self.rpki_aware = bool(get_setting('rpki.roa_source'))
        # The RPKI invalid filter starts enabled only when an RPKI ROA
        # source is configured.
        self.rpki_invalid_filter_enabled = self.rpki_aware
        self.out_scope_filter_enabled = True
        self.user_agent: Optional[str] = None
        self.preloader = preloader
        self.database_handler = database_handler
        self.sql_queries: List[str] = []
        self.sql_trace = False

    def set_query_sources(self, sources: Optional[List[str]]) -> None:
        """
        Set the sources for future queries.

        If sources is None, the configured sources_default is used, or all
        valid sources if no default is configured.

        :raises InvalidQueryException: if any of the sources is not valid.
        """
        if sources is None:
            sources = self.sources_default if self.sources_default else self.all_valid_sources
        elif not all(source in self.all_valid_sources for source in sources):
            raise InvalidQueryException(
                'One or more selected sources are unavailable.')
        self.sources = sources

    def disable_rpki_filter(self) -> None:
        """Stop filtering out RPKI-invalid objects in future queries."""
        self.rpki_invalid_filter_enabled = False

    def disable_out_of_scope_filter(self) -> None:
        """Stop filtering out objects that are not in scope in future queries."""
        self.out_scope_filter_enabled = False

    def set_object_class_filter_next_query(self,
                                           object_classes: List[str]) -> None:
        """Restrict the object classes for the next query only."""
        self.object_class_filter = object_classes

    def key_lookup(self, object_class: str,
                   rpsl_pk: str) -> RPSLDatabaseResponse:
        """RPSL exact key lookup, returning at most the first match."""
        query = self._prepare_query().object_classes(
            [object_class]).rpsl_pk(rpsl_pk).first_only()
        return self._execute_query(query)

    def rpsl_text_search(self, value: str) -> RPSLDatabaseResponse:
        """Free text search over RPSL objects, not ordered by source."""
        query = self._prepare_query(
            ordered_by_sources=False).text_search(value)
        return self._execute_query(query)

    def route_search(self, address: IP,
                     lookup_type: RouteLookupType) -> RPSLDatabaseResponse:
        """
        Route(6) object search for an address, supporting exact/less/more specific.

        Lookup types not present in the mapping below raise a KeyError.
        """
        query = self._prepare_query(ordered_by_sources=False).object_classes(
            ['route', 'route6'])
        lookup_queries = {
            RouteLookupType.EXACT: query.ip_exact,
            RouteLookupType.LESS_SPECIFIC_ONE_LEVEL:
            query.ip_less_specific_one_level,
            RouteLookupType.LESS_SPECIFIC_WITH_EXACT: query.ip_less_specific,
            RouteLookupType.MORE_SPECIFIC_WITHOUT_EXACT:
            query.ip_more_specific,
        }
        query = lookup_queries[lookup_type](address)
        return self._execute_query(query)

    def rpsl_attribute_search(self, attribute: str,
                              value: str) -> RPSLDatabaseResponse:
        """
        -i/!o query - inverse search for attribute values
        e.g. `-i mnt-by FOO` finds all objects where (one of the) maintainer(s) is FOO,
        as does `!oFOO`. Restricted to designated lookup fields.

        :raises InvalidQueryException: if attribute is not a lookup field.
        """
        if attribute not in self.lookup_field_names:
            readable_lookup_field_names = ', '.join(self.lookup_field_names)
            msg = (
                f'Inverse attribute search not supported for {attribute}, '
                f'only supported for attributes: {readable_lookup_field_names}'
            )
            raise InvalidQueryException(msg)
        query = self._prepare_query(ordered_by_sources=False).lookup_attr(
            attribute, value)
        return self._execute_query(query)

    def routes_for_origin(self,
                          origin: str,
                          ip_version: Optional[int] = None) -> Set[str]:
        """
        Resolve all route(6)s prefixes for an origin, returning a set
        of all prefixes. Origin must be in 'ASxxx' format.
        """
        prefixes = self.preloader.routes_for_origins([origin],
                                                     self.sources,
                                                     ip_version=ip_version)
        return prefixes

    def routes_for_as_set(self,
                          set_name: str,
                          ip_version: Optional[int] = None,
                          exclude_sets: Optional[Set[str]] = None) -> Set[str]:
        """
        Find all originating prefixes for all members of an AS-set. May be restricted
        to IPv4 or IPv6. Returns a set of all prefixes.

        Resolving is fixed to as-set objects and has no depth limit. Sets
        named in exclude_sets are not resolved any further.
        """
        self._current_set_root_object_class = 'as-set'
        self._current_excluded_sets = exclude_sets if exclude_sets else set()
        # A depth of 0 never reaches 0 again after decrementing: unlimited.
        self._current_set_maximum_depth = 0
        members = self._recursive_set_resolve({set_name})
        return self.preloader.routes_for_origins(members,
                                                 self.sources,
                                                 ip_version=ip_version)

    def members_for_set_per_source(self,
                                   parameter: str,
                                   exclude_sets: Optional[Set[str]] = None,
                                   depth=0,
                                   recursive=False) -> Dict[str, List[str]]:
        """
        Find all members of an as-set or route-set, possibly recursively, distinguishing
        between multiple root objects in different sources with the same name.
        Returns a dict with sources as keys, list of all members, including leaf members,
        as values.
        """
        query = self._prepare_query(column_names=['source'])
        object_classes = ['as-set', 'route-set']
        query = query.object_classes(object_classes).rpsl_pk(parameter)
        set_sources = [row['source'] for row in self._execute_query(query)]

        return {
            source: self.members_for_set(
                parameter=parameter,
                exclude_sets=exclude_sets,
                depth=depth,
                recursive=recursive,
                root_source=source,
            )
            for source in set_sources
        }

    def members_for_set(self,
                        parameter: str,
                        exclude_sets: Optional[Set[str]] = None,
                        depth=0,
                        recursive=False,
                        root_source: Optional[str] = None) -> List[str]:
        """
        Find all members of an as-set or route-set, possibly recursively.
        Returns a sorted list of all members, including leaf members.
        If root_source is set, the root object is only looked for in that source -
        resolving is then continued using the currently set sources.

        :param depth: maximum recursion depth for recursive resolving;
            0 means unlimited.
        """
        self._current_set_root_object_class = None
        self._current_excluded_sets = exclude_sets if exclude_sets else set()
        self._current_set_maximum_depth = depth
        if not recursive:
            members, leaf_members = self._find_set_members(
                {parameter}, limit_source=root_source)
            members.update(leaf_members)
        else:
            members = self._recursive_set_resolve({parameter},
                                                  root_source=root_source)
        # The queried set itself is never reported as its own member.
        members.discard(parameter)

        if get_setting('compatibility.ipv4_only_route_set_members'):
            # Compatibility mode: drop every member that is an IPv6 prefix.
            # Members may carry a range operator (e.g. '2001:db8::/32^+'),
            # so that is stripped before parsing, consistent with the
            # prefix detection in _recursive_set_resolve.
            for member in set(members):
                prefix = member.split('^')[0]
                try:
                    IP(prefix)
                except ValueError:
                    continue  # This is not a prefix, ignore.
                try:
                    IP(prefix, ipversion=4)
                except ValueError:
                    # This was a valid prefix, but not a valid IPv4 prefix,
                    # and should be removed.
                    members.remove(member)

        return sorted(members)

    def _recursive_set_resolve(self,
                               members: Set[str],
                               sets_seen=None,
                               root_source: Optional[str] = None) -> Set[str]:
        """
        Resolve all members of a number of sets, recursively.

        For each set in members, determines whether it has been seen already (to prevent
        infinite recursion), ignores it if already seen, and then either adds
        it directly or adds it to a set that requires further resolving.
        If root_source is set, the root object is only looked for in that source -
        resolving is then continued using the currently set sources.

        Each recursion level decrements self._current_set_maximum_depth. When
        it reaches 0, the members found so far (including unresolved set names
        and leaf members) are returned without resolving further. Otherwise,
        names that do not resolve to any set object are not included.
        """
        if not sets_seen:
            sets_seen = set()

        if all(member in sets_seen for member in members):
            return set()
        sets_seen.update(members)

        set_members = set()

        resolved_as_members = set()
        sub_members, leaf_members = self._find_set_members(
            members, limit_source=root_source)

        for sub_member in sub_members:
            # Prefixes (optionally with a range operator) are only valid
            # members when not resolving a pure as-set.
            if self._current_set_root_object_class in (None, 'route-set'):
                try:
                    IP(sub_member.split('^')[0])
                    set_members.add(sub_member)
                    continue
                except ValueError:
                    pass
            # AS numbers are permitted in route-sets and as-sets, per RFC 2622 5.3.
            # When an AS number is encountered as part of route-set resolving,
            # the prefixes originating from that AS should be added to the response.
            try:
                as_number_formatted, _ = parse_as_number(sub_member)
                if self._current_set_root_object_class == 'route-set':
                    set_members.update(
                        self.preloader.routes_for_origins(
                            [as_number_formatted], self.sources))
                    resolved_as_members.add(sub_member)
                else:
                    set_members.add(sub_member)
            except ValueError:
                pass

        self._current_set_maximum_depth -= 1
        if self._current_set_maximum_depth == 0:
            return set_members | sub_members | leaf_members

        # Anything not recognised as a prefix or AS number is assumed to be
        # a set name that still needs resolving, unless seen or excluded.
        further_resolving_required = (
            sub_members - set_members - sets_seen - resolved_as_members
            - self._current_excluded_sets)
        new_members = self._recursive_set_resolve(further_resolving_required,
                                                  sets_seen)
        set_members.update(new_members)

        return set_members

    def _find_set_members(
            self,
            set_names: Set[str],
            limit_source: Optional[str] = None) -> Tuple[Set[str], Set[str]]:
        """
        Find all members of a number of route-sets or as-sets. Includes both
        direct members listed in members attribute, but also
        members included by mbrs-by-ref/member-of.
        If limit_source is set, the set_names are only looked for in that source.

        Returns a tuple of two sets:
        - members found of the sets included in set_names, both
          references to other sets and direct AS numbers, etc.
        - leaf members that were included in set_names, i.e.
          names for which no further data could be found - for
          example references to non-existent other sets
        """
        members: Set[str] = set()
        sets_already_resolved: Set[str] = set()

        columns = ['parsed_data', 'rpsl_pk', 'source', 'object_class']
        query = self._prepare_query(column_names=columns)

        object_classes = ['as-set', 'route-set']
        # Per RFC 2622 5.3, route-sets can refer to as-sets,
        # but as-sets can only refer to other as-sets.
        if self._current_set_root_object_class == 'as-set':
            object_classes = [self._current_set_root_object_class]

        query = query.object_classes(object_classes).rpsl_pks(set_names)
        if limit_source:
            query = query.sources([limit_source])
        query_result = list(self._execute_query(query))

        if not query_result:
            # No sub-members are found, and apparently all inputs were leaf members.
            return set(), set_names

        # Track the object class of the root object set.
        # In one case self._current_set_root_object_class may already be set
        # on the first run: when the set resolving should be fixed to one
        # type of set object.
        if not self._current_set_root_object_class:
            self._current_set_root_object_class = query_result[0][
                'object_class']

        for result in query_result:
            rpsl_pk = result['rpsl_pk']

            # The same PK may occur in multiple sources, but we are
            # only interested in the first matching object, prioritised
            # according to the source order. This priority is part of the
            # query ORDER BY, so basically we only process an RPSL pk once.
            if rpsl_pk in sets_already_resolved:
                continue
            sets_already_resolved.add(rpsl_pk)

            object_class = result['object_class']
            object_data = result['parsed_data']
            mbrs_by_ref = object_data.get('mbrs-by-ref', None)
            for members_attr in ['members', 'mp-members']:
                if members_attr in object_data:
                    members.update(object_data[members_attr])

            if not rpsl_pk or not object_class or not mbrs_by_ref:
                continue

            # If mbrs-by-ref is set, find any objects with member-of pointing to the route/as-set
            # under query, and include a maintainer listed in mbrs-by-ref, unless mbrs-by-ref
            # is set to ANY.
            query_object_class = [
                'route', 'route6'
            ] if object_class == 'route-set' else ['aut-num']
            query = self._prepare_query(
                column_names=columns).object_classes(query_object_class)
            query = query.lookup_attrs_in(['member-of'], [rpsl_pk])

            if 'ANY' not in [m.strip().upper() for m in mbrs_by_ref]:
                query = query.lookup_attrs_in(['mnt-by'], mbrs_by_ref)

            referring_objects = self._execute_query(query)

            # The member value is the object's primary attribute, i.e. the
            # attribute named after its object class (route/route6/aut-num).
            for referring_object in referring_objects:
                member_object_class = referring_object['object_class']
                members.add(
                    referring_object['parsed_data'][member_object_class])

        leaf_members = set_names - sets_already_resolved
        return members, leaf_members

    def database_status(
        self,
        sources: Optional[List[str]] = None
    ) -> 'OrderedDict[str, OrderedDict[str, Any]]':
        """
        Database status. If sources is None, use the default sources, or all
        valid sources if no default is configured.

        Returns an OrderedDict keyed by upper-cased source name. Requested
        sources that are not valid get an ``{'error': 'Unknown source'}`` entry.
        """
        if sources is None:
            sources = self.sources_default if self.sources_default else self.all_valid_sources
        invalid_sources = [
            s for s in sources if s not in self.all_valid_sources
        ]
        query = DatabaseStatusQuery().sources(sources)
        query_results = self._execute_query(query)

        results: OrderedDict[str, OrderedDict[str, Any]] = OrderedDict()
        for query_result in query_results:
            source = query_result['source'].upper()
            results[source] = OrderedDict()
            results[source]['authoritative'] = get_setting(
                f'sources.{source}.authoritative', False)
            object_class_filter = get_setting(
                f'sources.{source}.object_class_filter')
            results[source]['object_class_filter'] = list(
                object_class_filter) if object_class_filter else None
            results[source]['rpki_rov_filter'] = bool(
                get_setting('rpki.roa_source')
                and not get_setting(f'sources.{source}.rpki_excluded'))
            results[source]['scopefilter_enabled'] = bool(
                get_setting('scopefilter')
            ) and not get_setting(f'sources.{source}.scopefilter_excluded')
            results[source]['local_journal_kept'] = get_setting(
                f'sources.{source}.keep_journal', False)
            results[source]['serial_oldest_journal'] = query_result[
                'serial_oldest_journal']
            results[source]['serial_newest_journal'] = query_result[
                'serial_newest_journal']
            results[source]['serial_last_export'] = query_result[
                'serial_last_export']
            results[source]['serial_newest_mirror'] = query_result[
                'serial_newest_mirror']
            results[source]['last_update'] = query_result[
                'updated'].astimezone(timezone('UTC')).isoformat()
            results[source]['synchronised_serials'] = is_serial_synchronised(
                self.database_handler, source)

        for invalid_source in invalid_sources:
            results[invalid_source.upper()] = OrderedDict(
                {'error': 'Unknown source'})
        return results

    def rpsl_object_template(self, object_class) -> str:
        """
        Return the RPSL template for an object class.

        :raises InvalidQueryException: if the object class is unknown.
        """
        # Only the mapping lookup is guarded, so that a KeyError raised while
        # generating the template is not misreported as an unknown class.
        try:
            object_class_type = OBJECT_CLASS_MAPPING[object_class]
        except KeyError:
            raise InvalidQueryException(
                f'Unknown object class: {object_class}')
        return object_class_type().generate_template()

    def enable_sql_trace(self):
        """Start recording the repr() of every executed query."""
        self.sql_trace = True

    def retrieve_sql_trace(self) -> List[str]:
        """Return the recorded queries, then disable tracing and clear the record."""
        trace = self.sql_queries
        self.sql_trace = False
        self.sql_queries = []
        return trace

    def _prepare_query(self,
                       column_names=None,
                       ordered_by_sources=True) -> RPSLDatabaseQuery:
        """
        Prepare an RPSLDatabaseQuery by applying relevant sources/class filters.

        Also clears the object class filter, so that it applies to one query only.
        """
        query = RPSLDatabaseQuery(column_names, ordered_by_sources)
        if self.sources:
            query.sources(self.sources)
        if self.object_class_filter:
            query.object_classes(self.object_class_filter)
        if self.rpki_invalid_filter_enabled:
            query.rpki_status([RPKIStatus.not_found, RPKIStatus.valid])
        if self.out_scope_filter_enabled:
            query.scopefilter_status([ScopeFilterStatus.in_scope])
        self.object_class_filter = []
        return query

    def _execute_query(self, query) -> RPSLDatabaseResponse:
        """Execute a query, recording its repr() first if SQL tracing is on."""
        if self.sql_trace:
            self.sql_queries.append(repr(query))
        return self.database_handler.execute_query(query,
                                                   refresh_on_error=True)
Esempio n. 10
0
    # Surrogate primary key: a UUID generated server-side by PostgreSQL's
    # gen_random_uuid(), exposed to Python as uuid.UUID (as_uuid=True).
    pk = sa.Column(pg.UUID(as_uuid=True), server_default=sa.text('gen_random_uuid()'), primary_key=True)

    # ROA payload. prefix is a PostgreSQL CIDR value; asn presumably is the
    # authorised origin AS and max_length the maximum prefix length (confirm).
    prefix = sa.Column(pg.CIDR, nullable=False)
    asn = sa.Column(sa.BigInteger, nullable=False)
    max_length = sa.Column(sa.Integer, nullable=False)
    # No nullable=False here, so this column accepts NULL.
    trust_anchor = sa.Column(sa.String)
    # Indexed separately (index=True); presumably 4 or 6 - TODO confirm.
    ip_version = sa.Column(sa.Integer, nullable=False, index=True)

    # Timezone-aware creation timestamp, defaulted server-side to now().
    created = sa.Column(sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False)

    @declared_attr
    def __table_args__(cls):  # noqa
        """
        Table-level constraints and indexes for the ROA objects table.

        - A unique constraint over (prefix, asn, max_length, trust_anchor).
          NOTE(review): trust_anchor is a nullable column, and PostgreSQL
          treats NULLs as distinct in unique constraints, so rows with a
          NULL trust_anchor are not deduplicated by this constraint.
        - A GiST index on prefix with the inet_ops operator class.
        """
        prefix_asn_maxlength_unique = sa.UniqueConstraint(
            'prefix', 'asn', 'max_length', 'trust_anchor',
            name='roa_object_prefix_asn_maxlength_unique',
        )
        prefix_gist_index = sa.Index(
            'ix_roa_objects_prefix_gist',
            sa.text('prefix inet_ops'),
            postgresql_using='gist',
        )
        return (prefix_asn_maxlength_unique, prefix_gist_index)

    def __repr__(self):
        return f'<{self.prefix}/{self.asn}>'


# Before you update this, please check the storage documentation for changing lookup fields.
# At import time this set is compared against lookup_field_names(); any
# difference aborts the import, as indexes on lookup fields may be missing.
expected_lookup_field_names = {
    'admin-c',
    'tech-c',
    'zone-c',
    'member-of',
    'mnt-by',
    'role',
    'members',
    'person',
    'mp-members',
    'origin',
    'mbrs-by-ref',
}
if sorted(expected_lookup_field_names) != sorted(lookup_field_names()):  # pragma: no cover
    raise RuntimeError(
        'Field names of lookup fields do not match expected set. Indexes may be missing. '
        f'Expected: {expected_lookup_field_names}, actual: {lookup_field_names()}'
    )