Exemplo n.º 1
0
def sample_n_rows_with_repeat(query: Query, num_of_rows: int) -> Iterator[VocabularyTerm]:
    """Sample ``num_of_rows`` rows from ``query`` at random, with replacement.

    Each draw picks a random offset in ``[0, count)`` and fetches the row at
    that offset with ``first()``, so the same row may appear several times.

    :param query: query whose result rows are sampled. NOTE(review): without an
        ``order_by`` the row at a given offset is not guaranteed to be stable.
    :param num_of_rows: number of draws; values <= 0 yield no rows.
    :return: a list (despite the ``Iterator`` annotation) with one row per draw,
        or an empty list when the query matches no rows.
    """
    count = query.count()
    if count == 0 or num_of_rows <= 0:
        # Without this guard every draw on an empty result hits offset 0 and
        # ``first()`` returns None, filling the output with Nones.
        return []
    # randrange(count) is uniform over [0, count), like floor(random() * count).
    return [query.offset(random.randrange(count)).first() for _ in range(num_of_rows)]
Exemplo n.º 2
0
 def response(self, query: orm.Query) -> PaginatedResponse:
     """Build the paginated payload for ``query`` and set the pagination headers.

     The total row count is computed once and shared by the header builder
     and the paginator; it is also returned under ``"count"``.
     """
     total = query.count()
     pagination_headers = self.headers(query, total)
     self._response.headers.update(pagination_headers)
     page_items = list(self.paginate(query, total))
     return {"results": page_items, "count": total}
Exemplo n.º 3
0
def paginate(model, query: Query,
             params: Optional[PaginationParams]) -> BasePage:
    """Return one page of ``query`` wrapped in the project's page type.

    :param model: mapped class whose attribute named ``params.sort_by`` is used
        for ordering when ``params.order`` is set.
    :param query: query to paginate; the total is counted before ordering/paging.
    :param params: pagination settings. ``params.order == 'desc'`` sorts
        descending, any other truthy value ascending. The offset is
        ``page_size * page``, so ``page`` appears to be zero-based.
        NOTE(review): despite ``Optional``, ``None`` is not supported and ends up
        as the HTTP 500 ``CustomException`` below.
    :return: ``PageType.get().create(...)`` built from the page rows and metadata.
    :raises CustomException: (HTTP 500) when counting, ordering or fetching
        fails; the original exception is chained as ``__cause__``.
    """
    code = '200'
    message = 'Success'

    try:
        total = query.count()

        page_query = query
        if params.order:
            direction = desc if params.order == 'desc' else asc
            page_query = page_query.order_by(
                direction(getattr(model, params.sort_by)))

        # Single limit/offset chain shared by the ordered and unordered paths.
        data = page_query.limit(params.page_size) \
            .offset(params.page_size * params.page) \
            .all()

        metadata = MetadataSchema(current_page=params.page,
                                  page_size=params.page_size,
                                  total_items=total)

    except Exception as e:
        raise CustomException(http_code=500, code='500', message=str(e)) from e

    return PageType.get().create(code, message, data, metadata)
Exemplo n.º 4
0
def paginate(query: Query, params: Optional[AbstractParams] = None) -> AbstractPage:
    """Paginate ``query`` and wrap the result in a page object.

    ``params`` (possibly ``None``) is normalised by ``resolve_params``; the
    total passed to ``create_page`` is the unpaginated row count of ``query``.
    """
    resolved = resolve_params(params)
    total_count = query.count()
    page_items = paginate_query(query, resolved).all()
    return create_page(page_items, total_count, resolved)
Exemplo n.º 5
0
def paginate_mods(mods: Query, page_size: int = 30) -> Tuple[List[Mod], int, int]:
    """Return ``(page_query, page, total_pages)`` for the page chosen by ``get_page()``.

    The requested page is clamped into ``[1, total_pages]``, falling back to
    page 1 when there are no pages at all. Note the first element is the
    offset/limit query itself, not a materialised list.
    """
    total_pages = math.ceil(mods.count() / page_size)
    page = max(1, min(get_page(), total_pages))
    first_row = page_size * (page - 1)
    return mods.offset(first_row).limit(page_size), page, total_pages
Exemplo n.º 6
0
def paginate(query: Query, page: int, per_page: int, include_total=False):
    """Slice ``query`` down to the 1-based ``page`` of ``per_page`` rows.

    :param query: query to slice.
    :param page: 1-based page number; values below 1 are treated as page 1
        instead of producing a negative offset.
    :param per_page: number of rows per page.
    :param include_total: when true, also return the total number of pages.
    :return: the sliced query, or ``(sliced_query, total_pages)`` when
        ``include_total`` is true.
    """
    start = max(page - 1, 0) * per_page
    page_query = query.slice(start, start + per_page)

    if not include_total:
        # Skip the COUNT round-trip entirely when the caller does not need it.
        return page_query

    return page_query, ceil(query.count() / per_page)
Exemplo n.º 7
0
    def test_3_queries(self):
        """Check row counts for channels, the UserChannel join table and joins.

        Querying the three entities with no join condition gives 18 rows;
        constraining them through UserChannel gives 3.
        """
        session = self.get_session()
        first_user = session.query(User).first()
        print(f"{first_user.name}")

        channels = session.query(Channel).all()
        self.assertEqual(len(channels), 3)

        # The join (association) table must be populated too.
        memberships = session.query(UserChannel).all()
        self.assertEqual(len(memberships), 3)

        # No join condition between the entities: a bad (too large) result.
        unjoined = Query([User, Channel, UserChannel], session=session)
        self.assertEqual(unjoined.count(), 18)

        joined = (
            Query([User, Channel, UserChannel], session=session)
            .filter(User.user_id == UserChannel.user_id)
            .filter(Channel.channel_id == UserChannel.channel_id)
        )
        self.assertEqual(joined.count(), 3)
Exemplo n.º 8
0
def sample_n_rows_no_repeat(query: Query, num_of_rows: int) -> Iterable[VocabularyTerm]:  # Assumption: primary key
    """Sample up to ``num_of_rows`` distinct rows from ``query`` at random.

    The query is re-run with a ``func.random() < pct`` filter until enough
    distinct rows have been collected. Rows are de-duplicated by ``row.id``,
    which is assumed to be the primary key.

    :param query: query whose rows are sampled.
    :param num_of_rows: desired number of rows; capped at the query's row count.
    :return: a dict values view of the sampled rows (empty when the query
        matches nothing or ``num_of_rows`` <= 0).
    """
    out = {}
    count = query.count()
    # Asking for more distinct rows than exist used to loop forever.
    target = min(num_of_rows, count)
    if target <= 0:
        # Also avoids dividing by zero below when the query is empty.
        return out.values()
    pct = target / count
    while len(out) < target:
        for row in query.filter(func.random() < pct):
            out[row.id] = row
            if len(out) >= target:
                break
    return out.values()
def cancel_expired_bookings(query: Query, batch_size: int = 500):
    """Cancel the bookings selected by ``query`` in batches of ``batch_size``.

    Each batch is marked cancelled (``BookingStatus.CANCELLED``, reason
    ``EXPIRED``, cancellation date = ``datetime.utcnow()``). The denormalized
    booked quantity of the affected stocks is then recomputed and the session
    committed. The next batch is re-fetched from ``query``, and the loop ends
    when no ids come back.

    NOTE(review): termination relies on cancelled bookings no longer being
    returned by ``find_expiring_booking_ids_from_query(query)``; otherwise the
    same batch would be fetched forever. Confirm against that helper.

    :param query: query selecting the bookings to cancel.
    :param batch_size: maximum number of bookings updated per commit.
    """
    expiring_bookings_count = query.count()
    logger.info("[cancel_expired_bookings] %d expiring bookings to cancel",
                expiring_bookings_count)
    if expiring_bookings_count == 0:
        return

    def fetch_next_batch():
        # Rows whose first column is the booking id (read as ``row[0]`` below).
        return bookings_repository.find_expiring_booking_ids_from_query(
            query).limit(batch_size).all()

    updated_total = 0
    expiring_booking_ids = fetch_next_batch()

    # we commit here to make sure there is no unexpected objects in SQLA cache before the update,
    # as we use synchronize_session=False
    db.session.commit()

    # max_id is computed inside the loop so that an empty first batch (rows
    # changed between the count above and the fetch) ends the loop instead of
    # raising IndexError.
    while expiring_booking_ids:
        # NOTE(review): assumes ids come back in ascending order, so the last
        # one is the maximum of the batch.
        max_id = expiring_booking_ids[-1][0]
        updated = (Booking.query.filter(Booking.id <= max_id).filter(
            Booking.id.in_(expiring_booking_ids)).update(
                {
                    "isCancelled": True,
                    "status": BookingStatus.CANCELLED,
                    "cancellationReason": BookingCancellationReasons.EXPIRED,
                    "cancellationDate": datetime.datetime.utcnow(),
                },
                synchronize_session=False,
            ))
        # Recompute denormalized stock quantity
        stocks_to_recompute = [
            row[0] for row in db.session.query(Booking.stockId).filter(
                Booking.id.in_(expiring_booking_ids)).distinct().all()
        ]
        recompute_dnBookedQuantity(stocks_to_recompute)
        db.session.commit()

        updated_total += updated
        expiring_booking_ids = fetch_next_batch()
        logger.info(
            "[cancel_expired_bookings] %d Bookings have been cancelled in this batch",
            updated,
        )

    logger.info(
        "[cancel_expired_bookings] %d Bookings have been cancelled",
        updated_total,
    )
Exemplo n.º 10
0
def paginate_query(query: Query, page: int = 1, page_count: Optional[int] = None) -> (list, int, int):
    """Return ``(titles, page, page_count)`` for the 1-based ``page`` of ``query``.

    ``page_count`` may be supplied by the caller to skip the COUNT query. It is
    recomputed when it is missing, when ``page`` equals it, or (via the
    recursive call) when the requested page turns out to be empty. In that
    case ``page`` is clamped to the last page.

    :param query: query of titles to paginate, ``PAGE_SIZE`` rows per page.
    :param page: requested page; 0 is accepted and treated as page 1.
    :param page_count: previously known number of pages, if any.
    :return: the page's rows, the effective page number and the page count.
        An empty result gives ``([], page, 0)``.
    :raises ValueError: if ``page`` or ``page_count`` is negative.
    """
    if page < 0 or (page_count and page_count < 0):
        raise ValueError('page and page count values should be non-negative')
    # Page 0 passes the check above but used to produce a negative start index.
    # That gave an empty slice and infinite recursion whenever rows existed.
    page = max(page, 1)
    if page_count is None or page == page_count:
        title_count = query.count()
        page_count = 1 + (title_count - 1) // PAGE_SIZE
        page = min(page, page_count)
    if page_count == 0:
        return [], page, page_count
    last_index = page * PAGE_SIZE
    first_index = last_index - PAGE_SIZE
    titles = query[first_index:last_index]
    if not titles:
        # The supplied page_count was stale: recount and clamp the page.
        return paginate_query(query, page)
    return titles, page, page_count
Exemplo n.º 11
0
def _fetch_dag_runs(
    query: Query,
    *,
    end_date_gte: Optional[str],
    end_date_lte: Optional[str],
    execution_date_gte: Optional[str],
    execution_date_lte: Optional[str],
    start_date_gte: Optional[str],
    start_date_lte: Optional[str],
    limit: Optional[int],
    offset: Optional[int],
    order_by: str,
) -> Tuple[List[DagRun], int]:
    """Filter ``query`` by date bounds, then return one sorted page and the total.

    Every bound is applied only when truthy: ``*_gte`` adds ``column >= value``
    and ``*_lte`` adds ``column <= value`` on the matching ``DagRun`` column.
    The total is counted after filtering but before sorting and pagination.
    Sorting is delegated to ``apply_sorting`` with a ``dag_run_id`` ->
    ``run_id`` rename and a list of allowed attributes.

    :return: ``(page_of_dag_runs, total_entries)``.
    """
    # (column, lower bound, upper bound), in the order the filters are applied.
    date_bounds = (
        (DagRun.start_date, start_date_gte, start_date_lte),
        (DagRun.execution_date, execution_date_gte, execution_date_lte),
        (DagRun.end_date, end_date_gte, end_date_lte),
    )
    for column, lower, upper in date_bounds:
        if lower:
            query = query.filter(column >= lower)
        if upper:
            query = query.filter(column <= upper)

    total_entries = query.count()
    sortable_attrs = [
        "id",
        "state",
        "dag_id",
        "execution_date",
        "dag_run_id",
        "start_date",
        "end_date",
        "external_trigger",
        "conf",
    ]
    sorted_query = apply_sorting(query, order_by, {"dag_run_id": "run_id"}, sortable_attrs)
    page = sorted_query.offset(offset).limit(limit).all()
    return page, total_entries
Exemplo n.º 12
0
    def create_paginated_response(
        self,
        items_per_page: int,
        query: Query,
        schema: SchemaMeta,
        page_index: Optional[int] = None,
        start_index: Optional[int] = None,
    ) -> dict:
        """Serialize one page of ``query`` into a paginated API response.

        :param items_per_page: page size.
        :param query: query whose rows are paginated.
        :param schema: ``BaseSchema`` subclass used to dump the page's rows.
        :param page_index: requested page, resolved by ``_get_pagination_params``.
        :param start_index: requested 1-based start row, resolved the same way.
        :return: ``APIPaginatedResponseSchema().dump(...)`` of the response,
            including item counts and the serialized items.

        NOTE(review): argument validation uses ``assert``, which is stripped
        under ``python -O``; kept to preserve the AssertionError callers may see.
        """
        # Validate
        assert isinstance(query, Query)
        assert issubclass(schema, BaseSchema)
        assert isinstance(items_per_page, int)

        total_items = query.count()
        start_index, stop_index, page_index, total_pages = self._get_pagination_params(
            page_index, start_index, items_per_page, total_items)

        assert isinstance(start_index, int)

        # Slice (start_index is 1-based). Materialize the page once: counting
        # the sliced query separately cost an extra COUNT round-trip and
        # executed the slice twice.
        page_rows = query.slice(start_index - 1, stop_index).all()
        current_item_count = len(page_rows)

        # Serialize
        items = schema(many=True).dump(page_rows)

        # Prepare result
        data = dict(
            items_per_page=items_per_page,
            current_item_count=current_item_count,
            page_index=page_index,
            start_index=start_index,
            total_items=total_items,
            total_pages=total_pages,
            items=items,
        )

        response = self._create_response(data=data)

        return APIPaginatedResponseSchema().dump(response)
Exemplo n.º 13
0
    def execute_query(self, query: Query, use_list=True, count=False):
        """
        Run a database query and return its result.

        :param query: The SQLAlchemy query; rebound to this object's session
            when one is set
        :param use_list: When true (and ``count`` is false) return all rows;
            otherwise return only the first row
        :param count: When true return the number of matching rows; takes
            precedence over ``use_list``
        :return: The row count, the list of rows, or the first row
        :raises DBInternalError: If SQLAlchemy raises while running the query
        """
        try:
            session = self._session_object
            bound_query = query if session is None else query.with_session(session)
            if count:
                return bound_query.count()
            if use_list:
                return bound_query.all()
            return bound_query.first()
        except exc.SQLAlchemyError as e:
            self.app.logger.error(
                "Can't execute the requested query. Details: %s", str(e))
            raise DBInternalError("Can't execute the requested query")
Exemplo n.º 14
0
def paginated_update(
    query: Query,
    print_page_progress: Optional[Union[Callable[[int, int], None],
                                        bool]] = None,
    batch_size: int = DEFAULT_BATCH_SIZE,
) -> Iterator[Any]:
    """
    Yield the objects of ``query`` in batches, merging and committing each batch.

    After the caller resumes the generator, each yielded object is merged back
    into the query's session. The session is committed once per batch, so the
    whole result set never has to be loaded into memory.

    ``print_page_progress``: ``None`` or ``True`` prints ``current/total`` on one
    line (carriage return); ``False`` disables reporting; a callable receives
    ``(processed, total)`` after every batch.
    """
    def _print_progress(current: int, total: int) -> None:
        print(f"    {current}/{total}", end="\r")

    offset = 0
    total_rows = query.count()
    session: Session = inspect(query).session

    report = print_page_progress
    if report is None or report is True:
        report = _print_progress

    while offset < total_rows:
        stop = min(offset + batch_size, total_rows)
        batch = query[offset:stop]
        for obj in batch:
            yield obj
            session.merge(obj)
        session.commit()
        if report:
            report(stop, total_rows)
        offset += batch_size
Exemplo n.º 15
0
def _query_with_filters(
    response: Response,
    query: Query,
    range: Optional[List[int]] = None,
    sort: Optional[List[str]] = None,
    filters: Optional[List[str]] = None,
) -> List:
    """Apply subscription filters, sorting and an optional row window to ``query``.

    :param response: response on which ``Content-Range`` is set when ``range``
        is given.
    :param query: base subscription query. The product/tag filters and sorts
        reference ``ProductTable``, so presumably the query already joins it.
        Confirm with callers.
    :param range: ``[start, end]`` window passed to ``Query.slice``; ``start``
        must be lower than ``end``. (The name shadows the builtin but is kept
        for keyword callers.)
    :param sort: flat ``[field, direction, ...]`` list. Direction ``DESC`` (any
        case) sorts descending, anything else ascending. ``product``/``tag``
        sort on ``ProductTable.name``/``ProductTable.tag``.
    :param filters: flat ``[field, value, ...]`` list:

        * A field ending in ``_gt``/``_gte``/``_lte``/``_lt``/``_ne`` compares
          the ``SubscriptionTable`` column named by the prefix.
        * A few field names have dedicated handling.
        * Any other ``SubscriptionTable`` attribute gets a case-insensitive
          substring match.
    :return: the resulting rows (``query.all()``).

    An invalid ``organisation`` UUID or ``range`` is reported through
    ``raise_status(HTTPStatus.BAD_REQUEST, ...)``, which is assumed to raise.
    """
    import operator

    # Suffix -> comparison. A field ending in one suffix never ends in another
    # ("x_gte" does not end with "_gt"), so the order here does not matter.
    comparisons = (
        ("_gt", operator.gt),
        ("_gte", operator.ge),
        ("_lte", operator.le),
        ("_lt", operator.lt),
        ("_ne", operator.ne),
    )

    if filters is not None:
        for pair in chunked(filters, 2):
            if not pair or len(pair) != 2:
                continue
            field, value = pair
            # Check for None before calling any method on ``value``; it used
            # to be lower-cased before this check.
            if value is None:
                continue
            comparison = next(
                ((suffix, op) for suffix, op in comparisons if field.endswith(suffix)),
                None,
            )
            if comparison is not None:
                suffix, op = comparison
                column = SubscriptionTable.__dict__[field[:-len(suffix)]]
                query = query.filter(op(column, value))
            elif field == "insync":
                value_as_bool = value.lower() in ("yes", "y", "ye", "true",
                                                  "1", "ja", "insync")
                query = query.filter(
                    SubscriptionTable.insync.is_(value_as_bool))
            elif field in ("tags", "tag"):
                # "tags": node and port selector form widgets; "tag": React table 7.
                sub_values = value.split("-")
                query = query.filter(
                    func.lower(ProductTable.tag).in_(
                        [s.lower() for s in sub_values]))
            elif field == "product":
                sub_values = value.split("-")
                query = query.filter(
                    func.lower(ProductTable.name).in_(
                        [s.lower() for s in sub_values]))
            elif field in ("status", "statuses"):
                # "status": React table 7; "statuses": port subscriptions.
                statuses = value.split("-")
                query = query.filter(
                    SubscriptionTable.status.in_(
                        [s.lower() for s in statuses]))
            elif field == "organisation":
                try:
                    value_as_uuid = UUID(value)
                except (ValueError, AttributeError):
                    # Was missing the ``f`` prefix, so the literal text
                    # "{value}" was reported instead of the bad value.
                    msg = f"Not a valid customer_id, must be a UUID: '{value}'"
                    logger.exception(msg)
                    raise_status(HTTPStatus.BAD_REQUEST, msg)
                query = query.filter(
                    SubscriptionTable.customer_id == value_as_uuid)
            elif field == "tsv":
                logger.debug("Running full-text search query.",
                             value=value)
                query = query.search(value)
            elif field in SubscriptionTable.__dict__:
                query = query.filter(
                    cast(SubscriptionTable.__dict__[field],
                         String).ilike("%" + value + "%"))

    if sort is not None and len(sort) >= 2:
        for pair in chunked(sort, 2):
            if not pair or len(pair) != 2:
                continue
            name, direction = pair
            if name in ["product", "tag"]:
                column = ProductTable.__dict__["name" if name == "product" else "tag"]
            else:
                column = SubscriptionTable.__dict__[name]
            order = expression.desc if direction.upper() == "DESC" else expression.asc
            query = query.order_by(order(column))

    if range is not None and len(range) == 2:
        try:
            range_start = int(range[0])
            range_end = int(range[1])
            if range_start >= range_end:
                raise ValueError("range start must be lower than end")
        except (TypeError, ValueError):
            msg = "Invalid range parameters"
            logger.exception(msg)
            raise_status(HTTPStatus.BAD_REQUEST, msg)
        total = query.count()
        query = query.slice(range_start, range_end)

        response.headers[
            "Content-Range"] = f"subscriptions {range_start}-{range_end}/{total}"

    return query.all()
Exemplo n.º 16
0
 def get_count(self, query: Query):
     """Return the total number of rows matched by ``query`` (``query.count()``)."""
     total_items = query.count()
     return total_items
Exemplo n.º 17
0
def q_all_tq(query: Query):
    """Load every row of ``query`` and wrap the list in a ``tqdm`` progress bar.

    The bar's total is the length of the loaded list. This avoids a separate
    COUNT round-trip and always matches the number of items iterated. The
    local is no longer named ``all``, which shadowed the builtin.
    """
    rows = query.all()
    return tqdm(rows, total=len(rows), smoothing=0.05)
Exemplo n.º 18
0
def query_tq(query: Query):
    """Wrap ``query_iterator(query)`` in a ``tqdm`` bar sized by ``query.count()``."""
    expected_rows = query.count()
    return tqdm(query_iterator(query), total=expected_rows, smoothing=0.05)
Exemplo n.º 19
0
    def _paginated_query(
        self,
        query: Query,
        cursor: Optional[HistoryCursor],
        limit: Optional[int],
    ) -> PaginatedHistory[TokenChangeHistoryEntry]:
        """Run a paginated query (one with a limit or a cursor).

        Parameters
        ----------
        query
            Base query. It is used unchanged for the total ``count``, so the
            count ignores both the cursor and the limit.
        cursor
            Optional cursor, applied via ``_apply_cursor``. For a previous
            cursor the rows are fetched in ascending order and reversed
            afterwards.
        limit
            Optional maximum number of entries to return. One extra row is
            fetched to detect whether another page exists.

        Returns
        -------
        PaginatedHistory[TokenChangeHistoryEntry]
            Entries sorted by descending ``(event_time, id)``, the total
            count, and the previous/next cursors where applicable.
        """
        limited_query = query

        # Apply the cursor, if there is one.
        if cursor:
            limited_query = self._apply_cursor(limited_query, cursor)

        # When retrieving a previous set of results using a previous
        # cursor, we have to reverse the sort algorithm so that the cursor
        # boundary can be applied correctly.  We'll then later reverse the
        # result set to return it in proper forward-sorted order.
        if cursor and cursor.previous:
            limited_query = limited_query.order_by(
                TokenChangeHistory.event_time, TokenChangeHistory.id
            )
        else:
            limited_query = limited_query.order_by(
                TokenChangeHistory.event_time.desc(),
                TokenChangeHistory.id.desc(),
            )

        # Grab one more element than the query limit so that we know whether
        # to create a cursor (because there are more elements) and what the
        # cursor value should be (for forward cursors).
        if limit:
            limited_query = limited_query.limit(limit + 1)

        # Execute the query twice, once to get the next batch of results and
        # once to get the count of all entries without pagination.
        entries = limited_query.all()
        count = query.count()

        # Calculate the cursors, remove the extra element we asked for, and
        # reverse the results again if we did a reverse sort because we were
        # using a previous cursor.
        prev_cursor = None
        next_cursor = None
        if cursor and cursor.previous:
            if limit:
                next_cursor = HistoryCursor.invert(cursor)
                if len(entries) > limit:
                    # The previous cursor is built from the last row kept;
                    # the extra (limit + 1)-th row is then dropped.
                    prev_cursor = self._build_prev_cursor(entries[limit - 1])
                    entries = entries[:limit]
            entries.reverse()
        elif limit:
            if cursor:
                prev_cursor = HistoryCursor.invert(cursor)
            if len(entries) > limit:
                # The next cursor is built from the extra row, which is dropped.
                next_cursor = self._build_next_cursor(entries[limit])
                entries = entries[:limit]

        # Return the results.
        return PaginatedHistory[TokenChangeHistoryEntry](
            entries=[TokenChangeHistoryEntry.from_orm(e) for e in entries],
            count=count,
            prev_cursor=prev_cursor,
            next_cursor=next_cursor,
        )
Exemplo n.º 20
0
 def get_count(self, query: Query):
     """Count the items ``query`` would return, delegating to ``query.count()``."""
     item_total = query.count()
     return item_total
Exemplo n.º 21
0
def _query_with_filters(
    response: Response,
    model: BaseModel,
    query: Query,
    range: Optional[List[int]] = None,
    sort: Optional[List[str]] = None,
    filters: Optional[List[str]] = None,
) -> List:
    """Apply generic filters, sorting and an optional row window to ``query``.

    :param response: response on which ``Content-Range`` is set when ``range``
        is given.
    :param model: mapped class whose ``__dict__`` provides the filter/sort columns.
    :param query: base query over ``model``.
    :param range: ``[start, end]`` window passed to ``Query.slice``; ``start``
        must be lower than ``end``. (The name shadows the builtin but is kept
        for keyword callers.)
    :param sort: flat ``[field, direction, ...]`` list. Direction ``DESC`` (any
        case) sorts descending, anything else ascending.
    :param filters: flat ``[field, value, ...]`` list:

        * A field ending in ``_gt``/``_gte``/``_lte``/``_lt``/``_ne`` compares
          the column named by the prefix.
        * ``tsv`` runs ``query.search``.
        * Any other ``model`` attribute gets a case-insensitive substring match.
    :return: the resulting rows (``query.all()``).

    An invalid ``range`` is reported through
    ``raise_status(HTTPStatus.BAD_REQUEST, ...)``, which is assumed to raise.
    """
    import operator

    # Suffix -> comparison. A field ending in one suffix never ends in another
    # ("x_gte" does not end with "_gt"), so the order here does not matter.
    comparisons = (
        ("_gt", operator.gt),
        ("_gte", operator.ge),
        ("_lte", operator.le),
        ("_lt", operator.lt),
        ("_ne", operator.ne),
    )

    if filters is not None:
        for pair in chunked(filters, 2):
            if not pair or len(pair) != 2:
                continue
            field, value = pair
            # The value used to be lower-cased (into an unused variable)
            # before this None check, which crashed on None.
            if value is None:
                continue
            comparison = next(
                ((suffix, op) for suffix, op in comparisons if field.endswith(suffix)),
                None,
            )
            if comparison is not None:
                suffix, op = comparison
                query = query.filter(op(model.__dict__[field[:-len(suffix)]], value))
            elif field == "tsv":
                logger.debug("Running full-text search query.",
                             value=value)
                query = query.search(value)
            elif field in model.__dict__:
                query = query.filter(
                    cast(model.__dict__[field],
                         String).ilike("%" + value + "%"))

    if sort is not None and len(sort) >= 2:
        # The loop variable no longer shadows the ``sort`` parameter.
        for pair in chunked(sort, 2):
            if not pair or len(pair) != 2:
                continue
            name, direction = pair
            order = expression.desc if direction.upper() == "DESC" else expression.asc
            query = query.order_by(order(model.__dict__[name]))

    if range is not None and len(range) == 2:
        try:
            range_start = int(range[0])
            range_end = int(range[1])
            # Explicit check instead of ``assert``, which ``python -O`` strips.
            if range_start >= range_end:
                raise ValueError("range start must be lower than end")
        except (TypeError, ValueError):
            msg = "Invalid range parameters"
            logger.exception(msg)
            raise_status(HTTPStatus.BAD_REQUEST, msg)
        total = query.count()
        query = query.slice(range_start, range_end)

        response.headers[
            "Content-Range"] = f"items {range_start}-{range_end}/{total}"

    return query.all()