Esempio n. 1
0
def view_proposals(vendor: str = None, profile: str = None):
    """Render the current user's proposals, split into open and processed lists.

    Both lists start from the same query: the user's own Vulnerability/Nvd
    rows, highest Nvd id first. "Processed" means ARCHIVED or PUBLISHED;
    everything else is "open". Each list is paginated on its own, using the
    ``proposal_p`` and ``proposal_processed_p`` bookmark parameters.
    ``vendor`` and ``profile`` are accepted but not used here.
    """
    page_size = 10

    def typeset_page(query, bookmark_param):
        # Read this list's bookmark, then fetch and wrap one keyset page.
        bookmark = parse_pagination_param(bookmark_param)
        page = get_page(query.options(default_nvd_view_options),
                        page_size,
                        page=bookmark)
        return VulnViewTypesetPaginationObjectWrapper(page.paging)

    own_entries = (
        db.session.query(Vulnerability, Nvd)
        .filter(Vulnerability.creator == g.user)
        .outerjoin(Vulnerability, Nvd.cve_id == Vulnerability.cve_id)
        .order_by(desc(Nvd.id))
    )

    is_processed = Vulnerability.state.in_(
        [VulnerabilityState.ARCHIVED, VulnerabilityState.PUBLISHED])

    open_page = typeset_page(own_entries.filter(~is_processed), "proposal_p")
    processed_page = typeset_page(own_entries.filter(is_processed),
                                  "proposal_processed_p")

    return render_template(
        "profile/proposals_view.html",
        proposal_vulns=open_page,
        proposal_vulns_processed=processed_page,
    )
Esempio n. 2
0
def test_orm_bad_page(dburl):
    """get_page rejects page tuples that do not fit the query's ordering."""
    bad_pages = [
        # malformed page tuple (three elements instead of place/backwards)
        ((1,), False, "Potatoes"),
        # one order column, so a place with two elements must fail
        ((1, 1), False),
    ]
    with S(dburl, echo=ECHO) as s:
        q = s.query(Book).order_by(Book.name)
        for bad_page in bad_pages:
            with pytest.raises(InvalidPage):
                get_page(q, per_page=10, page=bad_page)
Esempio n. 3
0
    async def get(self, request):
        """Handler for all GET requests.

        Parses the query string with ``self.args_schema``. When ``all`` is
        set, returns every row of ``self.meta.table``. Otherwise returns one
        keyset page selected by the ``per_page`` and ``page`` arguments.

        Raises:
            ValidationError: if ``get_page`` rejects the page token with a
                ``ValueError``. The original error is chained as the cause.
        """
        db = get_database()

        log.info(request.query_params)
        args = await parser.parse(self.args_schema,
                                  request,
                                  location="querystring")
        log.info(args)

        # NOTE: do not return early when the query string is empty; a former
        # early return for that case is what made responses come back blank.

        q = db.session.query(self.meta.table)
        if args["all"]:
            res = q.all()
            return APIResponse(res, total_count=len(res), to_dict=True)

        try:
            res = get_page(q, per_page=args["per_page"], page=args["page"])
        except ValueError as err:
            # Keep the underlying parsing error visible in tracebacks.
            raise ValidationError("Invalid page token.") from err

        return APIResponse(res, total_count=q.count(), to_dict=True)
Esempio n. 4
0
def product_view(vendor: str = None, product: str = None):
    """Render the vulnerability page for a single ``vendor``/``product`` pair.

    Collects the distinct Nvd JSON ids referenced by matching Cpe rows.
    Shows a keyset-paginated list of those Nvd entries, each joined with any
    Vulnerability (bookmark parameter ``product_p``). Also lists the unique
    repository URLs found in the entries' commits.
    """
    matching_nvd_ids = (
        db.session.query(Cpe.nvd_json_id)
        .filter(and_(Cpe.vendor == vendor, Cpe.product == product))
        .distinct()
    )
    vuln_count = matching_nvd_ids.count()

    vuln_query = (
        db.session.query(Vulnerability, Nvd)
        .filter(Nvd.id.in_(matching_nvd_ids))
        .with_labels()
        .outerjoin(Vulnerability, Nvd.cve_id == Vulnerability.cve_id)
        .order_by(desc(Nvd.id))
    )

    page_size = 10
    page_bookmark = parse_pagination_param("product_p")
    page = get_page(vuln_query.options(default_nvd_view_options),
                    page_size,
                    page=page_bookmark)
    paged_vulns = VulnViewTypesetPaginationObjectWrapper(page.paging)

    # Commits are collected from the query without the view load options.
    unique_repo_urls = get_unique_repo_urls(get_entries_commits(vuln_query))

    return render_template("product/view.html",
                           vendor=vendor,
                           product=product,
                           product_vulns=paged_vulns,
                           repo_urls=unique_repo_urls,
                           number_vulns=vuln_count)
Esempio n. 5
0
def main():
    """Recreate the ``single`` table, load SINGLES into it and print it page by page.

    Each line of SINGLES is split from the right into title, year and peak
    position, so titles may contain spaces. Rows are then read back in
    keyset-paginated pages of PER_PAGE until no next page exists.
    """
    # Schema reset.
    with S(DB, echo=False) as s:
        s.execute("""
            drop table if exists single;
        """)

        s.execute("""
            create table if not exists
                single(id serial, title text, year int, peak_position int)
        """)

    # Data load.
    with S(DB, echo=False) as s:
        for row in SINGLES.splitlines():
            name, released, chart_peak = row.rsplit(' ', 2)
            s.add(Single(title=name, year=released, peak_position=chart_peak))

    # Paged read-back.
    with S(DB, echo=False) as s:
        ordering = (
            Single.peak_position,
            desc(Single.year),
            Single.title,
            desc(Single.id),
        )
        query = s.query(Single).order_by(*ordering)

        bookmark = None
        more_pages = True
        while more_pages:
            page = get_page(query, per_page=PER_PAGE, page=bookmark)
            print_page(page)
            bookmark = page.paging.bookmark_next
            more_pages = page.paging.has_next
Esempio n. 6
0
def get_pending_proposals_paged():
    """Return one page of proposals that are not yet published.

    Rows are Vulnerability/Nvd pairs ordered by ascending state, then by
    highest Nvd id first. The page is selected by the ``review_p`` bookmark
    parameter and wrapped for the typeset pagination view.
    """
    unpublished = (
        db.session.query(Vulnerability, Nvd)
        .filter(Vulnerability.state != VulnerabilityState.PUBLISHED)
        .outerjoin(Vulnerability, Nvd.cve_id == Vulnerability.cve_id)
        .order_by(asc(Vulnerability.state), desc(Nvd.id))
    )
    bookmark = parse_pagination_param("review_p")
    page_size = 10
    page = get_page(unpublished.options(default_nvd_view_options),
                    page_size,
                    page=bookmark)
    return VulnViewTypesetPaginationObjectWrapper(page.paging)
Esempio n. 7
0
def test_marker_and_bookmark_per_item(dburl):
    """Per-item markers and bookmarks track book ids on forward and backward pages."""
    with S(dburl, echo=ECHO) as s:
        q = s.query(Book).order_by(Book.id)

        # Forward page: the first three books by id.
        first = get_page(q, per_page=3)
        fwd = first.paging
        assert len(first) == 3
        for idx in range(3):
            assert fwd.get_marker_at(idx) == ((idx + 1, ), False)

        marker_pairs = [*fwd.items()]
        assert len(marker_pairs) == 3
        for idx, (marker, book) in enumerate(marker_pairs):
            assert marker == ((idx + 1, ), False)
            assert book.id == idx + 1

        for idx, expected in enumerate([">i:1", ">i:2", ">i:3"]):
            assert fwd.get_bookmark_at(idx) == expected

        bookmark_pairs = [*fwd.bookmark_items()]
        assert len(bookmark_pairs) == 3
        for idx, (bookmark, book) in enumerate(bookmark_pairs):
            assert bookmark == ">i:%d" % (idx + 1)
            assert book.id == idx + 1

        # Backward page before the third book: books 2 then 1, flagged backwards.
        third_place, _ = fwd.get_marker_at(2)
        back = get_page(q, per_page=3, before=third_place)
        rev = back.paging
        assert len(back) == 2
        for idx, expected_id in enumerate((2, 1)):
            assert rev.get_marker_at(idx) == ((expected_id, ), True)

        for idx, expected in enumerate(["<i:2", "<i:1"]):
            assert rev.get_bookmark_at(idx) == expected
Esempio n. 8
0
def list(vendor: str = None, profile: str = None):
    """Render the review list of proposals that are not yet published.

    Shows Vulnerability/Nvd rows whose state is not PUBLISHED, highest Nvd
    id first, paginated through the ``review_p`` bookmark parameter.
    ``vendor`` and ``profile`` are accepted but unused.

    NOTE: this public name shadows the builtin ``list`` in its module.
    """
    pending = (
        db.session.query(Vulnerability, Nvd)
        .filter(Vulnerability.state != VulnerabilityState.PUBLISHED)
        .outerjoin(Vulnerability, Nvd.cve_id == Vulnerability.cve_id)
        .order_by(desc(Nvd.id))
    )
    bookmark = parse_pagination_param("review_p")
    page_size = 10
    page = get_page(pending.options(default_nvd_view_options),
                    page_size,
                    page=bookmark)
    return render_template(
        "review/list.html",
        review_vulns=VulnViewTypesetPaginationObjectWrapper(page.paging))
Esempio n. 9
0
    async def get(self, request):
        """Handler for all GET requests.

        Returns an empty JSON object when the query string is empty.
        Otherwise it:

        - queries (file_hash, file_mtime, basename, type_id) of
          ``self.model``, ordered by ``file_mtime``;
        - applies every filter in ``self._filters``;
        - returns either all matching rows (``all``) or one keyset page
          (``per_page``/``page``).

        Raises:
            ValidationError: if ``get_page`` rejects the page token.
            ApplicationError: wrapping any other error raised while querying.
        """
        args = await parser.parse(self.args_schema,
                                  request,
                                  location="querystring")

        db = get_database()

        if not len(request.query_params.keys()):
            return JSONResponse({})

        # Wrap entire query infrastructure in error-handling block.
        # We should probably make this a "with" statement or something
        # to use throughout our API code.
        with db.session_scope(commit=False):

            try:
                DataFile = self.model

                q = db.session.query(
                    DataFile.file_hash,
                    DataFile.file_mtime,
                    DataFile.basename,
                    DataFile.type_id,
                ).order_by(DataFile.file_mtime)

                # Apply filters before the "all" shortcut so that requesting
                # all rows still honours the other query arguments.
                for _filter in self._filters:
                    q = _filter(q, args)

                if args["all"]:
                    res = q.all()
                    return APIResponse(res,
                                       schema=self.schema(many=True),
                                       total_count=len(res))

                try:
                    res = get_page(q,
                                   per_page=args["per_page"],
                                   page=args["page"])
                except ValueError as err:
                    raise ValidationError("Invalid page token.") from err

                # Note: we don't need to use a schema to serialize here. but it is nice if we have it
                return APIResponse(res, schema=self.schema(many=True))
            except ValidationError:
                # A bad page token is a client error: let it through unchanged
                # instead of re-wrapping it as a generic ApplicationError.
                raise
            except Exception as err:
                raise ApplicationError(str(err)) from err
Esempio n. 10
0
    async def get(self, request):
        """Handler for all GET requests.

        - If an ``id`` path parameter is present, delegates to ``get_single``.
        - If the query string is empty, returns the API docs.
        - Otherwise queries the schema's model, applies ``self._filters``, and
          hides rows with a non-null ``embargo_date`` from unauthenticated
          users. It then returns either every row (``all``) or one keyset
          page ordered by the primary key(s), descending.

        Raises:
            ValidationError: if ``get_page`` rejects the page token. The
                original error is chained as the cause.
        """
        if request.path_params.get("id") is not None:
            # Pass off to the single-item handler
            return await self.get_single(request)
        args = await self.parse_querystring(request, self.args_schema)
        db = get_database()

        schema = self.meta.schema(many=True, allowed_nests=args["nest"])
        model = schema.opts.model

        if not len(request.query_params.keys()):
            return await self.api_docs(request, schema)

        # Wrap entire query infrastructure in error-handling block.
        # We should probably make this a "with" statement or something
        # to use throughout our API code.
        with db.session_scope(commit=False):

            q = db.session.query(model)

            for _filter in self._filters:
                q = _filter(q, args)

            # Rows under embargo (non-null embargo_date) are only visible to
            # authenticated users.
            if not request.user.is_authenticated and hasattr(
                    model, "embargo_date"):
                q = q.filter(model.embargo_date.is_(None))

            q = q.options(*schema.query_options(max_depth=None))

            if args["all"]:
                res = q.all()
                return APIResponse(res, schema=schema, total_count=len(res))

            # Keyset pagination needs a deterministic ordering. By default we
            # order by the primary key(s), highest first. This is probably not
            # what we want in most cases.
            q = q.order_by(*(desc(p) for p in get_primary_keys(model)))
            # https://github.com/djrobstep/sqlakeyset
            try:
                res = get_page(q, per_page=args["per_page"], page=args["page"])
            except ValueError as err:
                raise ValidationError("Invalid page token.") from err

            return APIResponse(res, schema=schema, total_count=q.count())
Esempio n. 11
0
def check_paging_orm(q):
    """Exhaustively page through ORM query ``q`` and check sqlakeyset's paging state.

    For each direction (forwards, then backwards) and each page size 1-11:

    - walk all pages via bookmarks that are serialized and unserialized
      before use;
    - check that every non-final page is full and reports ``has_further``;
    - check the paging state of the terminating empty page;
    - finally, check that the pages concatenated in query order equal
      ``q.all()``.
    """
    item_counts = range(1, 12)

    unpaged = q.all()

    for backwards in [False, True]:
        for per_page in item_counts:
            gathered = []

            # No place yet: start at the first page in this direction.
            page = None, backwards

            while True:
                # Round-trip through the string form so the bookmark
                # serializer is exercised together with paging.
                serialized_page = serialize_bookmark(page)
                page = unserialize_bookmark(serialized_page)

                page_with_paging = get_page(q,
                                            per_page=per_page,
                                            page=serialized_page)
                paging = page_with_paging.paging

                assert paging.current == page

                # Backwards pages are prepended so that ``gathered`` ends up
                # in the same order as ``unpaged`` (see the final assert).
                if backwards:
                    gathered = page_with_paging + gathered
                else:
                    gathered = gathered + page_with_paging

                page = paging.further

                if len(gathered) < len(unpaged):
                    # Ensure each page is the correct size
                    assert paging.has_further
                    assert len(page_with_paging) == per_page
                else:
                    assert not paging.has_further

                # An empty page ends the walk; its paging must point nowhere new.
                if not page_with_paging:
                    assert not paging.has_further
                    assert paging.further == paging.current
                    assert paging.current_opposite == (None,
                                                       not paging.backwards)
                    break

            # Ensure union of pages is original q.all()
            assert gathered == unpaged
Esempio n. 12
0
 def get_paginated_messages_for_user(
     self,
     user_id: int,
     read_status: ReadStatus,
     exclude_author_ids: Optional[List[int]] = None,
     include_event_types: Optional[List[EventTypeDatabaseParameters]] = None,
     exclude_event_types: Optional[List[EventTypeDatabaseParameters]] = None,
     count: Optional[int] = DEFAULT_NB_ITEM_PAGINATION,
     page_token: Optional[str] = None,
 ) -> Page:
     """Return one keyset page of messages for ``user_id``.

     The query is built by ``self._base_query`` from the read-status, author
     and event-type filters, and is ordered by ``Message.event_id``
     descending.

     :param count: page size, passed to ``get_page`` as ``per_page``.
         NOTE(review): annotated Optional, but None is passed straight
         through; confirm get_page tolerates it.
     :param page_token: bookmark of the page to fetch. A falsy value is
         replaced by ``False`` (presumably meaning "first page").
         NOTE(review): previously annotated ``Optional[int]``, but sqlakeyset
         bookmarks are strings; confirm against callers.
     :return: the ``Page`` produced by ``get_page``.
     """
     query = self._base_query(
         user_id=user_id,
         read_status=read_status,
         include_event_types=include_event_types,
         exclude_event_types=exclude_event_types,
         exclude_author_ids=exclude_author_ids,
     ).order_by(Message.event_id.desc())
     return get_page(query, per_page=count, page=page_token or False)
Esempio n. 13
0
def view_proposals(vendor: str = None, profile: str = None):
    """Render the current user's proposals that are not yet published.

    Pages through the user's own Vulnerability/Nvd rows whose state is not
    PUBLISHED, highest Nvd id first, using the ``proposal_p`` bookmark
    parameter. ``vendor`` and ``profile`` are accepted but unused.
    """
    own_unpublished = (
        db.session.query(Vulnerability, Nvd)
        .filter(Vulnerability.creator == g.user,
                Vulnerability.state != VulnerabilityState.PUBLISHED)
        .outerjoin(Vulnerability, Nvd.cve_id == Vulnerability.cve_id)
        .order_by(desc(Nvd.id))
    )

    bookmark = parse_pagination_param("proposal_p")
    page_size = 10
    page = get_page(own_unpublished.options(default_nvd_view_options),
                    page_size,
                    page=bookmark)
    return render_template(
        "profile/proposals_view.html",
        proposal_vulns=VulnViewTypesetPaginationObjectWrapper(page.paging))
Esempio n. 14
0
    def get(self, site):
        """Return one page of public links for ``site`` as JSON.

        The optional ``page`` request argument is a URL-quoted bookmark from
        a previous response. Links with ``public=True`` are ordered by
        ``created`` descending, then by ``id``, five per page.

        The response is ``{'links': [...], 'page': <next bookmark>}``, where
        ``page`` is ``""`` when there is no next page. The database session
        is always closed, even if querying fails.
        """
        page = request.args.get('page')
        currentPage = ''
        if page:
            currentPage = urllib.parse.unquote(page)

        session = Session()
        try:
            links = get_page(
                session.query(Link)
                .filter_by(site=site)
                .filter_by(public=True)
                .order_by(desc(Link.created), Link.id),
                per_page=5,
                page=currentPage)
            next_page = (links.paging.bookmark_next
                         if links.paging.has_next else "")
            # publicJSON() may touch lazily loaded attributes, so serialize
            # while the session is still open.
            linksJS = [link.publicJSON() for link in links]
        finally:
            session.close()

        return jsonify({'links': linksJS, 'page': next_page})
Esempio n. 15
0
def check_paging_orm(q):
    """Page through ``q`` in both directions and assert the pages rebuild ``q.all()``.

    Runs for page sizes 1 through 11. Each page is requested with a bookmark
    that has been serialized and unserialized, and ``paging.current`` must
    match it. The walk stops at the first empty page, whose paging state is
    also checked.
    """
    expected = q.all()

    for backwards in (False, True):
        for size in range(1, 12):
            collected = []
            bookmark = None, backwards

            while True:
                token = serialize_bookmark(bookmark)
                bookmark = unserialize_bookmark(token)

                result = get_page(q, per_page=size, page=token)
                paging = result.paging
                assert paging.current == bookmark

                # Going backwards, each new page precedes the rows already seen.
                if backwards:
                    collected = result + collected
                else:
                    collected = collected + result

                bookmark = paging.further

                if not result:
                    assert not paging.has_further
                    assert paging.further == paging.current
                    assert paging.current_opposite == (None,
                                                       not paging.backwards)
                    break

            assert collected == expected
Esempio n. 16
0
    def post_search(self, search_request: STACSearch,
                    **kwargs) -> Dict[str, Any]:
        """POST search catalog.

        Builds an item query from ``search_request``: collections, sort, ids,
        intersects/bbox, datetime and property ``query`` filters. When ids
        are given, only the collection filter and sort are applied. Returns
        one keyset page of items as a GeoJSON FeatureCollection with:

        - next/previous POST links;
        - a bbox covering the returned items (None when there are none);
        - when the ContextExtension is enabled, a ``context`` object with the
          returned, limit and matched counts.
        """
        with self.session.reader.context_session() as session:
            token = (self.get_token(search_request.token)
                     if search_request.token else False)
            query = session.query(self.item_table)

            # Filter by collection
            count = None
            if search_request.collections:
                query = query.join(self.collection_table).filter(
                    sa.or_(*[
                        self.collection_table.id == col_id
                        for col_id in search_request.collections
                    ]))

            # Sort
            if search_request.sortby:
                sort_fields = [
                    getattr(self.item_table.get_field(sort.field),
                            sort.direction.value)()
                    for sort in search_request.sortby
                ]
                # Item id is appended as a final tie-breaker for the sort.
                sort_fields.append(self.item_table.id)
                query = query.order_by(*sort_fields)
            else:
                # Default sort is date
                query = query.order_by(self.item_table.datetime.desc(),
                                       self.item_table.id)

            # Ignore other parameters if ID is present
            if search_request.ids:
                id_filter = sa.or_(
                    *[self.item_table.id == i for i in search_request.ids])
                items = query.filter(id_filter).order_by(self.item_table.id)
                page = get_page(items,
                                per_page=search_request.limit,
                                page=token)
                if self.extension_is_enabled(ContextExtension):
                    count = len(search_request.ids)
            else:
                # Spatial query
                poly = None
                if search_request.intersects is not None:
                    poly = shape(search_request.intersects)
                elif search_request.bbox:
                    poly = ShapelyPolygon.from_bounds(*search_request.bbox)

                if poly:
                    filter_geom = ga.shape.from_shape(poly, srid=4326)
                    query = query.filter(
                        ga.func.ST_Intersects(self.item_table.geometry,
                                              filter_geom))

                # Temporal query. When neither bound is open (".."), the
                # "between" filter and the two one-sided filters below all
                # apply; they describe the same inclusive range.
                if search_request.datetime:
                    # Two tailed query (between)
                    if ".." not in search_request.datetime:
                        query = query.filter(
                            self.item_table.datetime.between(
                                *search_request.datetime))
                    # All items after the start date
                    if search_request.datetime[0] != "..":
                        query = query.filter(self.item_table.datetime >=
                                             search_request.datetime[0])
                    # All items before the end date
                    if search_request.datetime[1] != "..":
                        query = query.filter(self.item_table.datetime <=
                                             search_request.datetime[1])

                # Query fields
                if search_request.query:
                    for (field_name, expr) in search_request.query.items():
                        field = self.item_table.get_field(field_name)
                        for (op, value) in expr.items():
                            query = query.filter(op.operator(field, value))

                if self.extension_is_enabled(ContextExtension):
                    count_query = query.statement.with_only_columns(
                        [func.count()]).order_by(None)
                    count = query.session.execute(count_query).scalar()
                page = get_page(query,
                                per_page=search_request.limit,
                                page=token)

            # Create dynamic attributes for the page (same for both branches)
            page.next = (self.insert_token(
                keyset=page.paging.bookmark_next)
                         if page.paging.has_next else None)
            page.previous = (self.insert_token(
                keyset=page.paging.bookmark_previous)
                             if page.paging.has_previous else None)

            links = []
            if page.next:
                links.append(
                    PaginationLink(
                        rel=Relations.next,
                        type="application/geo+json",
                        href=f"{kwargs['request'].base_url}search",
                        method="POST",
                        body={"token": page.next},
                        merge=True,
                    ))
            if page.previous:
                links.append(
                    PaginationLink(
                        rel=Relations.previous,
                        type="application/geo+json",
                        href=f"{kwargs['request'].base_url}search",
                        method="POST",
                        body={"token": page.previous},
                        merge=True,
                    ))

            response_features = []
            filter_kwargs = {}
            if self.extension_is_enabled(FieldsExtension):
                if search_request.query is not None:
                    query_include: Set[str] = {
                        k if k in Settings.get().indexed_fields else
                        f"properties.{k}" for k in search_request.query.keys()
                    }
                    if not search_request.field.include:
                        search_request.field.include = query_include
                    else:
                        # union() returns a new set; assign it back, otherwise
                        # the queried properties are silently dropped.
                        search_request.field.include = (
                            search_request.field.include.union(query_include))

                filter_kwargs = search_request.field.filter_fields

            xvals = []
            yvals = []
            for item in page:
                item.base_url = str(kwargs["request"].base_url)
                item_model = schemas.Item.from_orm(item)
                xvals += [item_model.bbox[0], item_model.bbox[2]]
                yvals += [item_model.bbox[1], item_model.bbox[3]]
                response_features.append(item_model.to_dict(**filter_kwargs))

        # min()/max() raise ValueError on an empty page: no bbox then.
        try:
            bbox = (min(xvals), min(yvals), max(xvals), max(yvals))
        except ValueError:
            bbox = None

        context_obj = None
        if self.extension_is_enabled(ContextExtension):
            context_obj = {
                "returned": len(page),
                "limit": search_request.limit,
                "matched": count,
            }

        return {
            "type": "FeatureCollection",
            "context": context_obj,
            "features": response_features,
            "links": links,
            "bbox": bbox,
        }
Esempio n. 17
0
    def __init__(self):
        """Build the VCDB and NVD listings for this view from the request args.

        - ``keyword`` filters both listings: by CVE id, by VCDB id, or by a
          LIKE match on Nvd descriptions / Vulnerability comments.
        - ``vcdb_p`` is the page number for the offset-paginated VCDB list.
        - ``nvd_p`` is a keyset bookmark for the NVD entries that have no
          Vulnerability row yet.
        """
        self.keyword = None
        # TODO: Look into enabling this once public contributions are enabled.
        # self.top_contributors = []
        # self.fetch_top_contributors()

        vcdb_entries = db.session.query(Vulnerability, Nvd)
        vcdb_entries = vcdb_entries.outerjoin(
            Nvd, Vulnerability.cve_id == Nvd.cve_id)
        vcdb_entries = vcdb_entries.options(default_nvd_view_options)
        vcdb_entries = vcdb_entries.order_by(
            asc(Vulnerability.date_created), desc(Vulnerability.id))
        self.vcdb_entries = vcdb_entries

        # NVD entries without a matching Vulnerability row.
        nvd_entries = db.session.query(Nvd)
        nvd_entries = nvd_entries.outerjoin(Vulnerability,
                                            Nvd.cve_id == Vulnerability.cve_id)
        nvd_entries = nvd_entries.options(default_nvd_view_options)
        nvd_entries = nvd_entries.filter(Vulnerability.cve_id.is_(None))
        nvd_entries = nvd_entries.order_by(
            desc(Nvd.published_date), desc(Nvd.id))
        self.nvd_entries = nvd_entries

        self.keyword = request.args.get("keyword", None, type=str)

        apply_filter = None
        if self.keyword:
            # TODO: Make the filtering work with fulltext search as well.
            if VulnerabilityDetails.is_cve_id(self.keyword):
                apply_filter = or_(False, Nvd.cve_id == self.keyword)
            elif VulnerabilityDetails.is_vcdb_id(self.keyword):
                apply_filter = or_(False, Vulnerability.id == self.keyword)
            else:
                # NOTE: only "%" is stripped; "_" still acts as a LIKE wildcard.
                escaped_keyword = self.keyword.replace("%", "")
                # Attention: We can't use FullText search here because of some buggy
                # Mysql 5.7 behavior (using FullText on Join results seems to do bad
                # things). We might need to apply the filter before joining below.
                # apply_filter = or_(
                #     FullTextSearch(escaped_keyword, Nvd, FullTextMode.BOOLEAN),
                #     FullTextSearch(escaped_keyword, Vulnerability, FullTextMode.BOOLEAN))
                apply_filter = or_(
                    Nvd.descriptions.any(
                        Description.value.like("%" + escaped_keyword + "%")),
                    Vulnerability.comment.like("%" + escaped_keyword + "%"),
                )

            # TODO: add product search support.
            # apply_filter = or_(apply_filter, Cpe.product == keyword)

        if apply_filter is not None:
            self.vcdb_entries = self.vcdb_entries.filter(apply_filter)
            self.nvd_entries = self.nvd_entries.filter(apply_filter)

        per_page = 7
        vcdb_page = request.args.get("vcdb_p", 1, type=int)
        self.vcdb_pagination = VulnViewSqlalchemyPaginationObjectWrapper(
            self.vcdb_entries.paginate(vcdb_page, per_page=per_page))

        nvd_bookmarked_page = request.args.get('nvd_p', None)
        if nvd_bookmarked_page:
            # Keep only allow-listed characters of the user-supplied bookmark
            # before unserializing it. Raw string: same pattern as before,
            # without the invalid-escape warning for "\d" / "\-".
            nvd_bookmarked_page = unserialize_bookmark(
                re.sub(r'[^a-zA-Z\d\- <>:~]', '', nvd_bookmarked_page))

        self.nvd_pagination = get_page(
            self.nvd_entries, per_page, page=nvd_bookmarked_page)
        self.nvd_pagination = VulnViewTypesetPaginationObjectWrapper(
            self.nvd_pagination.paging)
        # Estimate only: assumes every Vulnerability row matches an Nvd row.
        num_nvd_entries = db.session.query(Nvd).count()
        num_vuln_entries = db.session.query(Vulnerability).count()
        num_unique_nvd_estimate = num_nvd_entries - num_vuln_entries
        self.nvd_pagination.set_total(num_unique_nvd_estimate)
Esempio n. 18
0
    def item_collection(self,
                        id: str,
                        limit: int = 10,
                        token: str = None,
                        **kwargs) -> ItemCollection:
        """Read an item collection from the database.

        Returns up to ``limit`` items of collection ``id``, ordered by
        ``datetime`` descending and then item id, starting at the page
        identified by ``token``. The result includes GET links to the
        next/previous pages and, when the ContextExtension is enabled, the
        returned, limit and matched counts.
        """
        with self.session.reader.context_session() as session:
            children = (
                session.query(self.item_table)
                .join(self.collection_table)
                .filter(self.collection_table.id == id)
                .order_by(self.item_table.datetime.desc(), self.item_table.id)
            )

            matched = None
            if self.extension_is_enabled(ContextExtension):
                count_stmt = children.statement.with_only_columns(
                    [func.count()]).order_by(None)
                matched = children.session.execute(count_stmt).scalar()

            if token:
                token = self.get_token(token)
            page = get_page(children, per_page=limit, page=(token or False))

            # Dynamic attributes holding the stored next/previous tokens.
            page.next = (self.insert_token(keyset=page.paging.bookmark_next)
                         if page.paging.has_next else None)
            page.previous = (self.insert_token(
                keyset=page.paging.bookmark_previous)
                             if page.paging.has_previous else None)

            links = []
            for relation, target in ((Relations.next, page.next),
                                     (Relations.previous, page.previous)):
                if not target:
                    continue
                links.append(
                    PaginationLink(
                        rel=relation,
                        type="application/geo+json",
                        href=
                        f"{kwargs['request'].base_url}collections/{id}/items?token={target}&limit={limit}",
                        method="GET",
                    ))

            features = []
            for item in page:
                item.base_url = str(kwargs["request"].base_url)
                features.append(schemas.Item.from_orm(item))

            context_obj = ({
                "returned": len(page),
                "limit": limit,
                "matched": matched
            } if self.extension_is_enabled(ContextExtension) else None)

            return ItemCollection(
                type="FeatureCollection",
                context=context_obj,
                features=features,
                links=links,
            )
Esempio n. 19
0
    def item_collection(self,
                        id: str,
                        limit: int = 10,
                        token: str = None,
                        **kwargs) -> ItemCollection:
        """Read an item collection from the database.

        Returns up to ``limit`` items of collection ``id``, ordered by
        ``datetime`` descending and then item id, starting at the page
        identified by ``token``. Items are serialized with
        ``self.item_serializer``. Next/previous GET links are built as plain
        dicts, and a context object is added when the ContextExtension is
        enabled.
        """
        base_url = str(kwargs["request"].base_url)
        with self.session.reader.context_session() as session:
            children = (
                session.query(self.item_table)
                .join(self.collection_table)
                .filter(self.collection_table.id == id)
                .order_by(self.item_table.datetime.desc(), self.item_table.id)
            )

            matched = None
            if self.extension_is_enabled(ContextExtension):
                count_stmt = children.statement.with_only_columns(
                    [func.count()]).order_by(None)
                matched = children.session.execute(count_stmt).scalar()

            if token:
                token = self.get_token(token)
            page = get_page(children, per_page=limit, page=(token or False))

            # Dynamic attributes holding the stored next/previous tokens.
            page.next = (self.insert_token(keyset=page.paging.bookmark_next)
                         if page.paging.has_next else None)
            page.previous = (self.insert_token(
                keyset=page.paging.bookmark_previous)
                             if page.paging.has_previous else None)

            links = [
                {
                    "rel": relation.value,
                    "type": "application/geo+json",
                    "href":
                    f"{base_url}collections/{id}/items?token={target}&limit={limit}",
                    "method": "GET",
                }
                for relation, target in ((Relations.next, page.next),
                                         (Relations.previous, page.previous))
                if target
            ]

            features = [
                self.item_serializer.db_to_stac(item, base_url=base_url)
                for item in page
            ]

            context_obj = None
            if self.extension_is_enabled(ContextExtension):
                context_obj = {
                    "returned": len(page),
                    "limit": limit,
                    "matched": matched,
                }

            # TODO: return stac_extensions
            return ItemCollection(
                type="FeatureCollection",
                stac_version=STAC_VERSION,
                features=features,
                links=links,
                context=context_obj,
            )
Esempio n. 20
0
    def post_search(self, search_request: SQLAlchemySTACSearch,
                    **kwargs) -> ItemCollection:
        """POST search catalog.

        Translates ``search_request`` into a query over ``self.item_table``
        and returns one keyset-paginated page of the matching items.

        Filtering:
            * ``collections`` restricts items to the listed collections.
            * ``ids``, when present, restricts items to those IDs. The
              spatial, temporal and property filters below are then ignored.
            * ``intersects`` (preferred) or ``bbox`` adds an
              ``ST_Intersects`` filter against a geometry in SRID 4326.
            * ``datetime`` is either a single instant (exact match) or a
              ``start/end`` interval; ``..`` or an empty string marks an
              open end.
            * ``query`` applies one comparison per ``{field: {op: value}}``.

        Sorting follows ``sortby`` (always with ``id`` appended as a
        tie-breaker) or defaults to ``datetime`` descending, then ``id``.

        Args:
            search_request: The parsed search body.
            **kwargs: Must contain ``request``; its ``base_url`` prefixes
                item and pagination links.

        Returns:
            An ``ItemCollection`` with the page's features, POST pagination
            links, and a context object when the Context extension is
            enabled.

        Raises:
            ValueError: if ``datetime`` contains more than one ``/``.
        """
        base_url = str(kwargs["request"].base_url)
        with self.session.reader.context_session() as session:
            token = (self.get_token(search_request.token)
                     if search_request.token else False)
            query = session.query(self.item_table)
            count = None

            # Filter by collection
            if search_request.collections:
                query = query.join(self.collection_table).filter(
                    sa.or_(*[
                        self.collection_table.id == col_id
                        for col_id in search_request.collections
                    ]))

            # Sort
            if search_request.sortby:
                sort_fields = [
                    getattr(
                        self.item_table.get_field(sort.field),
                        sort.direction.value,
                    )() for sort in search_request.sortby
                ]
                # Unique tie-breaker so keyset pagination is deterministic.
                sort_fields.append(self.item_table.id)
                query = query.order_by(*sort_fields)
            else:
                # Default sort is date
                query = query.order_by(self.item_table.datetime.desc(),
                                       self.item_table.id)

            if search_request.ids:
                # Ignore the remaining parameters if IDs are present.
                id_filter = sa.or_(
                    *[self.item_table.id == i for i in search_request.ids])
                query = query.filter(id_filter).order_by(self.item_table.id)
            else:
                # Spatial query
                poly = None
                if search_request.intersects is not None:
                    poly = shape(search_request.intersects)
                elif search_request.bbox:
                    poly = ShapelyPolygon.from_bounds(*search_request.bbox)

                if poly:
                    filter_geom = ga.shape.from_shape(poly, srid=4326)
                    query = query.filter(
                        ga.func.ST_Intersects(self.item_table.geometry,
                                              filter_geom))

                # Temporal query
                if search_request.datetime:
                    bounds = search_request.datetime.split("/")
                    if len(bounds) == 1:
                        # A single instant: exact match. (Previously this
                        # crashed in `between(*dts)` / `dts[1]`.)
                        query = query.filter(
                            self.item_table.datetime == bounds[0])
                    else:
                        # Closed or half-open interval; ".." or "" marks an
                        # open end. Unpacking raises ValueError on >2 parts.
                        start, end = bounds
                        if start not in ("..", ""):
                            query = query.filter(
                                self.item_table.datetime >= start)
                        if end not in ("..", ""):
                            query = query.filter(
                                self.item_table.datetime <= end)

                # Query fields
                if search_request.query:
                    for (field_name, expr) in search_request.query.items():
                        field = self.item_table.get_field(field_name)
                        for (op, value) in expr.items():
                            query = query.filter(op.operator(field, value))

            # Count the rows the query actually matches. For an `ids` search
            # this can be fewer than len(ids): IDs may be missing, duplicated
            # or excluded by the collection filter.
            if self.extension_is_enabled(ContextExtension):
                count_query = query.statement.with_only_columns(
                    [func.count()]).order_by(None)
                count = query.session.execute(count_query).scalar()

            page = get_page(query,
                            per_page=search_request.limit,
                            page=token)
            # Create dynamic attributes for each page
            page.next = (self.insert_token(keyset=page.paging.bookmark_next)
                         if page.paging.has_next else None)
            page.previous = (self.insert_token(
                keyset=page.paging.bookmark_previous)
                             if page.paging.has_previous else None)

            links = []
            if page.next:
                links.append({
                    "rel": Relations.next.value,
                    "type": "application/geo+json",
                    "href": f"{base_url}search",
                    "method": "POST",
                    "body": {
                        "token": page.next
                    },
                    "merge": True,
                })
            if page.previous:
                links.append({
                    "rel": Relations.previous.value,
                    "type": "application/geo+json",
                    "href": f"{base_url}search",
                    "method": "POST",
                    "body": {
                        "token": page.previous
                    },
                    "merge": True,
                })

            response_features = [
                self.item_serializer.db_to_stac(item, base_url=base_url)
                for item in page
            ]

            # Use pydantic includes/excludes syntax to implement fields extension
            if self.extension_is_enabled(FieldsExtension):
                if search_request.query is not None:
                    # Always include the fields referenced by `query`.
                    indexed_fields = Settings.get().indexed_fields
                    query_include: Set[str] = {
                        k if k in indexed_fields else f"properties.{k}"
                        for k in search_request.query.keys()
                    }
                    if not search_request.field.include:
                        search_request.field.include = query_include
                    else:
                        # `union` returns a new set; it must be assigned back
                        # (the previous code discarded the result).
                        search_request.field.include = (
                            search_request.field.include.union(query_include))

                filter_kwargs = search_request.field.filter_fields
                # Need to pass through `.json()` for proper serialization
                # of datetime
                response_features = [
                    json.loads(
                        stac_pydantic.Item(**feat).json(**filter_kwargs))
                    for feat in response_features
                ]

        context_obj = None
        if self.extension_is_enabled(ContextExtension):
            context_obj = {
                "returned": len(page),
                "limit": search_request.limit,
                "matched": count,
            }

        # TODO: return stac_extensions
        return ItemCollection(
            type="FeatureCollection",
            stac_version=STAC_VERSION,
            features=response_features,
            links=links,
            context=context_obj,
        )
Esempio n. 21
0
    def item_collection(
        self, id: str, limit: int = 10, token: str = None, **kwargs
    ) -> ItemCollection:
        """Read an item collection from the database.

        Returns one keyset-paginated page of the items belonging to the
        collection ``id``, ordered by ``datetime`` descending with ``id`` as a
        tie-breaker.

        Args:
            id: Identifier of the parent collection.
            limit: Maximum number of items on the returned page.
            token: Pagination token from a previous ``next``/``previous``
                link, or ``None`` for the first page.
            **kwargs: Must contain ``request``; its ``base_url`` prefixes
                item and pagination links.

        Returns:
            An ``ItemCollection`` with the page's items, ``next``/``previous``
            GET links when neighbouring pages exist, and a context object
            when the context API extension is enabled.

        Raises:
            errors.NotFoundError: re-raised unchanged if raised while
                querying.
            errors.DatabaseError: for any other error while querying or
                paginating; the original exception is chained as the cause.
        """
        try:
            collection_children = (
                self.reader_session.query(self.table)
                .join(self.collection_table)
                .filter(self.collection_table.id == id)
                .order_by(self.table.datetime.desc(), self.table.id)
            )
            # Evaluate once and reuse below. The later check previously used
            # a bare `ApiExtensions` name, unlike this `config.` reference.
            context_enabled = config.settings.api_extension_is_enabled(
                config.ApiExtensions.context
            )
            count = None
            if context_enabled:
                count_query = collection_children.statement.with_only_columns(
                    [func.count()]
                ).order_by(None)
                count = collection_children.session.execute(count_query).scalar()
            token = self.pagination_client.get(token) if token else token
            page = get_page(collection_children, per_page=limit, page=(token or False))
            # Create dynamic attributes for each page
            page.next = (
                self.pagination_client.insert(keyset=page.paging.bookmark_next)
                if page.paging.has_next
                else None
            )
            page.previous = (
                self.pagination_client.insert(keyset=page.paging.bookmark_previous)
                if page.paging.has_previous
                else None
            )
        except errors.NotFoundError:
            raise
        except Exception as e:
            logger.error(e, exc_info=True)
            raise errors.DatabaseError(
                "Unhandled database error when getting collection children"
            ) from e

        # Computed outside the try so a missing `request` is not reported as
        # a database error.
        base_url = str(kwargs["request"].base_url)

        links = []
        if page.next:
            links.append(
                PaginationLink(
                    rel=Relations.next,
                    type="application/geo+json",
                    href=f"{base_url}collections/{id}/items?token={page.next}&limit={limit}",
                    method="GET",
                )
            )
        if page.previous:
            links.append(
                PaginationLink(
                    rel=Relations.previous,
                    type="application/geo+json",
                    href=f"{base_url}collections/{id}/items?token={page.previous}&limit={limit}",
                    method="GET",
                )
            )

        response_features = []
        for item in page:
            # The schema reads `base_url` off the ORM object when serializing.
            item.base_url = base_url
            response_features.append(schemas.Item.from_orm(item))

        context_obj = None
        if context_enabled:
            context_obj = {"returned": len(page), "limit": limit, "matched": count}

        return ItemCollection(
            type="FeatureCollection",
            context=context_obj,
            features=response_features,
            links=links,
        )
Esempio n. 22
0
    def __init__(self):
        """Build the VCDB and NVD listings for the current request.

        Reads the optional ``keyword`` argument from ``request.args``. When it
        is set, both listings are filtered by it:

        * a CVE id matches ``Nvd.cve_id`` exactly,
        * a VCDB id matches ``Vulnerability.id`` exactly,
        * anything else is a literal (wildcard-free) substring search over
          NVD descriptions and vulnerability comments.

        Page bookmarks come from the ``vcdb_p`` and ``nvd_p`` parameters
        (via ``parse_pagination_param``); each listing shows 7 entries per
        page. Sets ``keyword``, ``top_contributors``, ``vcdb_entries``,
        ``nvd_entries``, ``vcdb_pagination`` and ``nvd_pagination``.
        """
        self.keyword = None
        self.top_contributors = []
        # TODO: Look into enabling this once public contributions are enabled.
        # self.fetch_top_contributors()

        # Published VCDB entries: annotated ones first, then oldest first.
        has_annotations_col = Vulnerability.has_annotations
        vcdb_entries = db.session.query(Vulnerability, Nvd,
                                        has_annotations_col)
        vcdb_entries = vcdb_entries.filter(
            Vulnerability.state == VulnerabilityState.PUBLISHED)
        vcdb_entries = vcdb_entries.outerjoin(
            Nvd, Vulnerability.cve_id == Nvd.cve_id)
        vcdb_entries = vcdb_entries.options(default_nvd_view_options)
        vcdb_entries = vcdb_entries.from_self()
        vcdb_entries = vcdb_entries.order_by(
            desc(has_annotations_col),
            asc(Vulnerability.date_created),
            desc(Vulnerability.id),
        )
        self.vcdb_entries = vcdb_entries

        # NVD entries with no Vulnerability row, newest published first.
        nvd_entries = db.session.query(Nvd)
        nvd_entries = nvd_entries.outerjoin(Vulnerability,
                                            Nvd.cve_id == Vulnerability.cve_id)
        nvd_entries = nvd_entries.options(default_nvd_view_options)
        nvd_entries = nvd_entries.filter(Vulnerability.cve_id.is_(None))
        nvd_entries = nvd_entries.order_by(desc(Nvd.published_date),
                                           desc(Nvd.id))
        self.nvd_entries = nvd_entries

        self.keyword = request.args.get("keyword", None, type=str)

        apply_filter = None
        if self.keyword:
            # TODO: Make the filtering work with fulltext search as well.
            if VulnerabilityDetails.is_cve_id(self.keyword):
                apply_filter = or_(False, Nvd.cve_id == self.keyword)
            elif VulnerabilityDetails.is_vcdb_id(self.keyword):
                apply_filter = or_(False, Vulnerability.id == self.keyword)
            else:
                # Attention: We can't use FullText search here because of
                # some buggy MySQL 5.7 behavior (FullText on join results
                # seems to misbehave). We might need to apply the filter
                # before joining instead.
                #
                # The keyword comes from the request, so LIKE metacharacters
                # ("%", "_") and the escape character itself are escaped to
                # make the search match the keyword literally. Previously "%"
                # was stripped and "_" still acted as a wildcard.
                like_escape = "!"
                literal = (self.keyword
                           .replace(like_escape, like_escape * 2)
                           .replace("%", like_escape + "%")
                           .replace("_", like_escape + "_"))
                like_pattern = f"%{literal}%"
                apply_filter = or_(
                    Nvd.descriptions.any(
                        Description.value.like(like_pattern,
                                               escape=like_escape)),
                    Vulnerability.comment.like(like_pattern,
                                               escape=like_escape),
                )

            # TODO: add product search support.
            # apply_filter = or_(apply_filter, Cpe.product == keyword)

        if apply_filter is not None:
            self.vcdb_entries = self.vcdb_entries.filter(apply_filter)
            self.nvd_entries = self.nvd_entries.filter(apply_filter)

        per_page = 7
        vcdb_bookmarked_page = parse_pagination_param("vcdb_p")
        # Replace a sqlakeyset function to support our use case. Note this
        # rebinds a module attribute, so it affects every sqlakeyset caller.
        # TODO: File a PR for this?
        sqlakeyset.paging.value_from_thing = custom_value_from_thing
        self.vcdb_pagination = get_page(self.vcdb_entries,
                                        per_page,
                                        page=vcdb_bookmarked_page)
        self.vcdb_pagination = VulnViewTypesetPaginationObjectWrapper(
            self.vcdb_pagination.paging)
        # NOTE(review): this total counts every Vulnerability row, regardless
        # of the PUBLISHED state filter and keyword above — confirm intended.
        num_vuln_entries = db.session.query(func.count(
            Vulnerability.id)).scalar()
        self.vcdb_pagination.set_total(num_vuln_entries)

        nvd_bookmarked_page = parse_pagination_param("nvd_p")
        self.nvd_pagination = get_page(self.nvd_entries,
                                       per_page,
                                       page=nvd_bookmarked_page)
        self.nvd_pagination = VulnViewTypesetPaginationObjectWrapper(
            self.nvd_pagination.paging)
        num_nvd_entries = db.session.query(func.count(Nvd.id)).scalar()
        # Rough estimate: assumes each Vulnerability shadows one Nvd row.
        num_unique_nvd_estimate = num_nvd_entries - num_vuln_entries
        self.nvd_pagination.set_total(num_unique_nvd_estimate)
Esempio n. 23
0
def test_warn_when_sorting_by_nullable(dburl):
    """Paging a query ordered by ``Book.a`` must emit a ``UserWarning``.

    Per the test name, ``Book.a`` is a nullable column, which keyset
    pagination warns about.
    """
    with pytest.warns(UserWarning), S(dburl, echo=ECHO) as session:
        ordered_books = session.query(Book).order_by(Book.a, Book.id)
        get_page(ordered_books, per_page=10, page=(None, False))