コード例 #1
ファイル: filter.py プロジェクト: hy144328/steins-feed
def filter_dates(
    q: sqla.sql.Select,
    start: datetime.datetime = None,
    finish: datetime.datetime = None,
) -> sqla.sql.Select:
    """Restrict ``q`` to items published in the half-open range ``[start, finish)``.

    Args:
        q: The SELECT statement to extend.
        start: Inclusive lower bound on ``orm_items.Item.Published``.
            ``None`` (the default) leaves the range open at the bottom.
        finish: Exclusive upper bound on ``orm_items.Item.Published``.
            ``None`` (the default) leaves the range open at the top.

    Returns:
        ``q`` with zero, one or two additional WHERE criteria.
    """
    # The bounds are optional values, not flags: test for the None sentinel
    # explicitly (PEP 8) instead of relying on the object's truthiness.
    if start is not None:
        q = q.where(orm_items.Item.Published >= start)

    if finish is not None:
        q = q.where(orm_items.Item.Published < finish)

    return q
コード例 #2
    def security(q: Query, stmt: sa.sql.Select) -> sa.sql.Select:
        """ Security: make sure that the user can only access their own data

        Adds a WHERE condition that ties the statement to the hard-coded
        ALLOWED_USER_ID. The column used depends on which load path is being queried.

        Raises:
            NotImplementedError: for any load path without a rule. Previously the raise
                was unreachable, so unknown paths silently returned None.
        """
        ALLOWED_USER_ID = 1

        path = q.load_path
        if path == (Article, ):
            # Articles: filter on their `user_id` column
            return stmt.where(q.Model.user_id == ALLOWED_USER_ID)
        elif path == (Article, 'author', User):
            # The author is a User row: filter on its `id` column
            return stmt.where(q.Model.id == ALLOWED_USER_ID)
        elif path == (Article, 'comments', Comment):
            # Comments: filter on their `user_id` column
            return stmt.where(q.Model.user_id == ALLOWED_USER_ID)
        else:
            raise NotImplementedError(f'No security rule for load path {path!r}')
コード例 #3
    def apply_to_statement(self, query: QueryObject, target_Model: SAModelOrAlias, stmt: sa.sql.Select) -> sa.sql.Select:
        """ Apply cursor-based pagination to `stmt`.

        Without a cursor (`self.cursor_value is None`), the filter expression is `True`
        and `self.limit` is used. With a cursor:
        * the current sort field names must equal the cursor's columns;
        * rows are filtered by comparing the sort columns with the cursor's comparison
          operator ('>' or '<');
        * the cursor's own limit is used.

        One extra row is always requested (`limit + 1`) so the caller can detect
        whether a next page exists. When the effective limit is None, `stmt` is
        returned unchanged.

        Raises:
            exc.QueryObjectError: if the sort fields differ from the cursor's columns.

        NOTE(review): this copy looks truncated. The cursor branch dereferences `cursor`
        right after `if cursor is None:`, so an `else:` line seems to be missing. The
        `op(` call is never closed and its second operand (the cursor value) is absent.
        The two `return` statements under `if SA_14:` share one branch, so an `else:`
        is missing there too. Restore from upstream before relying on this block.
        """
        # Prepare the filter expression
        cursor = self.cursor_value

        if cursor is None:
            # No cursor: presumably `True` acts as "no restriction" in the WHERE clause
            filter_expression = True
            limit = self.limit
            # NOTE(review): an `else:` appears to be missing above this line; as written the
            # code below dereferences `cursor` while it is None.
            # Make sure the columns are still the same
            if set(cursor.cols) != query.sort.names:
                raise exc.QueryObjectError('You cannot adjust "sort" fields while using cursor-based pagination.')

            # Filter
            # Map the cursor's direction to a Python comparison operator; SQLAlchemy
            # columns overload these operators to build SQL expressions.
            op = {'>': operator.gt, '<': operator.lt}[cursor.op]
            # NOTE(review): the first operand below presumably was wrapped in a tuple
            # constructor, and the comparison value plus the closing parentheses are
            # missing from this copy.
            filter_expression = op(
                    resolve_column_by_name(field.name, target_Model, where='skip')
                    for field in query.sort.fields
            limit = cursor.limit

        # No limit at all: leave the statement untouched
        if limit is None:
            return stmt

        # Paginate
        # We will always load one more row to check if there's a next page
        if SA_14:
            return stmt.filter(filter_expression).limit(limit + 1)
            # NOTE(review): unreachable as written; presumably the SQLAlchemy 1.3 `else:` branch
            return stmt.where(filter_expression).limit(limit + 1)
コード例 #4
def filter_languages(
    q: sqla.sql.Select,
    langs: typing.List[schema_feeds.Language],
) -> sqla.sql.Select:
    """Keep only rows whose feed language is one of ``langs``.

    Each entry of ``langs`` is matched by its ``.name`` against the
    ``orm_feeds.Feed.Language`` column via SQL ``IN``.
    """
    wanted_names = [lang.name for lang in langs]
    return q.where(orm_feeds.Feed.Language.in_(wanted_names))
コード例 #5
def filter_like(
    q: sqla.sql.Select,
    score: schema_items.Like,
    user: orm_users.User,
) -> sqla.sql.Select:
    """Keep only items that ``user`` has rated with ``score``.

    ``Item.likes`` is joined with the relationship restricted to ``user``'s
    rows, then the joined ``Like.Score`` must equal ``score.name``.
    """
    own_likes = orm_items.Item.likes.and_(orm_items.Like.UserID == user.UserID)
    return q.join(own_likes).where(orm_items.Like.Score == score.name)
コード例 #6
    def apply_to_statement(self, stmt: sa.sql.Select) -> sa.sql.Select:
        """ Modify the Select statement: add the WHERE clause

        Every condition in `self.query.filter.conditions` is compiled with
        `self._compile_condition()`, and the results are AND-ed together.

        Returns:
            `stmt` with the compiled conditions added, or `stmt` unchanged
            when there are no conditions.
        """
        # Compile the conditions.
        # Use a list, not a generator: a generator can only be unpacked once.
        conditions = [self._compile_condition(condition)
                      for condition in self.query.filter.conditions]

        # Nothing to filter on: leave the statement alone.
        # (Calling sa.and_() with no arguments is deprecated in SQLAlchemy 1.4.)
        if not conditions:
            return stmt

        # Add the WHERE clause
        if SA_13:
            # Core Select.filter() only exists from SQLAlchemy 1.4 on; AND explicitly
            stmt = stmt.where(sa.and_(*conditions))
        else:
            # Select.filter() AND-s multiple criteria together
            stmt = stmt.filter(*conditions)

        # Done
        return stmt
コード例 #7
def filter_magic(
    q: sqla.sql.Select,
    user: orm_users.User,
    unscored: bool = True,
) -> sqla.sql.Select:
    """Filter items by whether ``user`` has a ``Magic`` row for them.

    The statement is LEFT OUTER JOINed to ``Item.magic``, restricted to
    ``user``'s rows. With ``unscored=True`` (the default), only items with
    no such row are kept. With ``unscored=False``, only items that have one
    are kept.
    """
    own_magic = orm_items.Item.magic.and_(orm_items.Magic.UserID == user.UserID)
    q = q.join(own_magic, isouter=True)

    has_magic = orm_items.Item.magic.any(orm_items.Magic.UserID == user.UserID)
    return q.where(~has_magic if unscored else has_magic)
コード例 #8
def filter_tags(
    q: sqla.sql.Select,
    tags: typing.List[orm_feeds.Tag],
    user: orm_users.User,
) -> sqla.sql.Select:
    """Restrict ``q`` to rows whose joined tag is one of ``tags``.

    ``Feed.tags`` is joined with the relationship restricted to ``user``'s
    tags. The joined ``Tag`` rows are then filtered by the ``TagID`` values
    of ``tags`` via SQL ``IN``.
    """
    feed_tags = orm_feeds.Feed.tags.and_(
        orm_feeds.Tag.UserID == user.UserID,
    )
    q = q.join(feed_tags)

    tag_ids = [tag_it.TagID for tag_it in tags]
    q = q.where(orm_feeds.Tag.TagID.in_(tag_ids))

    return q
コード例 #9
    def prepare_query(self, q: sa.sql.Select) -> sa.sql.Select:
        """ Prepare the statement for loading: add columns to select, add filter condition

            q: SELECT statement prepared by QueryExecutor.statement().
               It has no columns yet, but has a select_from(self.target_model), unaliased.
               NOTE: we never alias the target model: the one we're loading. It would've made things too complicated.

            The logic is adapted from SQLAlchemy's SelectInLoader; the `# [o]` comments
            appear to quote the upstream code being replaced.
            * Without `load_with_join`: the primary-key columns are adapted to
              self.target_model, selected as-is, and the statement selects from
              self.target_model.
            * With `load_with_join`: the columns are selected under "<source table>." labels
              (the prefix is stored in self.fk_label_prefix). The statement selects from the
              parent alias, joined to self.target_model through the relationship `self.key`.
            * In both cases, an `IN :primary_keys` condition is added.

            Returns:
                the modified statement.
        """
        # Use SelectInLoader
        # self.query_info: primary key columns, the IN expression, etc
        # self._parent_alias: used with JOINed relationships where our table has to be joined to an alias of the parent table
        query_info = self.loader._query_info
        parent_alias = self.loader._parent_alias if query_info.load_with_join else NotImplemented

        # [ADDED] Adapt pk_cols
        # [o] pk_cols = query_info.pk_cols
        # [o] in_expr = query_info.in_expr
        pk_cols = query_info.pk_cols
        in_expr = query_info.in_expr

        # [o] if not query_info.load_with_join:
        if not query_info.load_with_join:
            # [o] if effective_entity.is_aliased_class:
            # [o]     pk_cols = [ effective_entity._adapt_element(col) for col in pk_cols ]
            # [o]     in_expr = effective_entity._adapt_element(in_expr)
            adapter = SimpleColumnsAdapter(self.target_model)
            pk_cols = adapter.replace_many(pk_cols)
            in_expr = adapter.replace(in_expr)

        # [o] bundle_ent = orm_util.Bundle("pk", *pk_cols)
        # [o] entity_sql = effective_entity.__clause_element__()
        # [o] q = Select._create_raw_select(
        # [o]     _raw_columns=[bundle_sql, entity_sql],
        # [o]     _label_style=LABEL_STYLE_TABLENAME_PLUS_COL,
        # [o] )
        # [CUSTOMIZED]
        if not query_info.load_with_join:
            q = add_columns(q, pk_cols)  # [CUSTOMIZED]
        else:
            # NOTE: we cannot always add our FK columns: when `load_with_join` is used, these columns
            # may actually refer to columns from a M2M table with conflicting names!
            # Example:
            #   SELECT articles.id, tags.id
            #   FROM articles JOIN ... JOIN tags
            # So we have to rename them. We use "table.column" aliases because this horrible "." makes it clear
            # it's not just another column
            # label_prefix = self.source_model.__table__.name + '.'
            self.fk_label_prefix = self.source_model.__tablename__ + '.'  # type: ignore[union-attr]
            q = add_columns(q, [  # [CUSTOMIZED]
                col.label(self.fk_label_prefix + col.key)
                for col in pk_cols
            ])

        # Effective entity
        # This is the class that we select from
        # [o] if not query_info.load_with_join:
        # [o]     q = q.select_from(effective_entity)
        # [o] else:
        # [o]     q = q.select_from(self._parent_alias).join(...)
        # [CUSTOMIZED]
        if not query_info.load_with_join:
            q = q.select_from(self.target_model)
        else:
            if SA_13:
                q = q.select_from(
                    sa.orm.join(parent_alias, self.target_model, onclause=getattr(parent_alias, self.key).of_type(self.target_model))
                )
            else:
                q = q.select_from(parent_alias).join(
                    getattr(parent_alias, self.key).of_type(self.target_model)
                )

        # [o] q = q.filter(in_expr.in_(sql.bindparam("primary_keys")))
        if SA_13:
            # SQLAlchemy 1.3: a list-valued bind parameter must be declared `expanding` explicitly
            q = q.where(in_expr.in_(sa.sql.bindparam("primary_keys", expanding=True)))
        else:
            q = q.filter(in_expr.in_(sa.sql.bindparam("primary_keys")))

        return q
コード例 #10
    def _apply_window_over_foreign_key_pagination(
            self, stmt: sa.sql.Select, *,
            fk_columns: list[SAAttribute]) -> sa.sql.Select:
        """ Instead of the usual limit, use a window function over the given columns.

        This method is used with the selectin-load loading strategy to load a limited number of related
        items per primary entity. Instead of using LIMIT, we group rows over `fk_columns`
        and impose a limit per group.

        This is achieved using a Window Function:

            SELECT *, row_number() OVER(PARTITION BY author_id) AS group_row_n
            FROM articles
            WHERE group_row_n < 10

            This will result in the following table:

            id  |   author_id   |   group_row_n
            1       1               1
            2       1               2
            3       2               1
            4       2               2
            5       2               3
            6       3               1
            7       3               2

        The SQL above is conceptual. Postgres does not allow a SELECT alias in WHERE, so the
        statement is first wrapped in a subquery. The row number is then filtered on the outside:
        `__group_row_n > skip` when `skip` is set, and `__group_row_n <= (skip or 0) + limit`
        when `limit` is set. The statement is returned unchanged when neither is set.

        NOTE(review): this copy looks truncated. No closing triple-quote for this docstring
        is visible. The row-counter expression and the SA_14 subquery are mostly missing.
        The `sa.select([` call is never closed, and the `else:` lines between the SA_14 and
        pre-1.4 variants are absent. Restore from upstream before relying on this block.
        skip, limit = self.skip, self.limit

        # Apply it only when there's a limit
        # (neither skip nor limit set: nothing to paginate)
        if not skip and not limit:
            return stmt

        # First, add a row counter
        # NOTE(review): the expression below is incomplete in this copy. Judging by the remaining
        # fragments and the later filters, it presumably builds a row_number() window
        # partitioned by `fk_columns` (adapted via `adapter`), ordered like the outer query,
        # and labelled '__group_row_n'.
        adapter = SimpleColumnsAdapter(self.target_Model)
        row_counter_col = (
                # Groups are partitioned by self._window_over_columns,
                    fk_columns),  # type: ignore[arg-type]
                # We have to apply the same ordering from the outside query;
                # otherwise, the numbering will be undetermined
                        self.target_Model))  # type: ignore[arg-type]
            # give it a name that we can use later
        stmt = add_columns(stmt, [row_counter_col])

        # Wrap ourselves into a subquery.
        # This is necessary because Postgres does not let you reference SELECT aliases in the WHERE clause.
        # Reason: WHERE clause is executed before SELECT
        # NOTE(review): the SA_14 subquery body and the `else:` for the older-SQLAlchemy
        # variant appear to be missing here.
        if SA_14:
            subquery = (
                # Taken from: Query.from_self()
                )  # type: ignore[attr-defined]
            subquery = (stmt.correlate(None).alias())

        # Re-select every column of the subquery except the helper row counter
        stmt = sa.select([
            column for column in subquery.c if column.key !=
            '__group_row_n'  # skip this column. We don't need it.

        # Apply the LIMIT condition using row numbers
        # These two statements simulate skip/limit using window functions
        # NOTE(review): in each pair below, the second statement presumably belongs to a
        # missing `else:` branch (pre-1.4 SQLAlchemy).
        if skip:
            # Drop the first `skip` rows of every group
            if SA_14:
                stmt = stmt.filter(
                    sa.sql.literal_column('__group_row_n') > skip)
                stmt = stmt.where(
                    sa.sql.literal_column('__group_row_n') > skip)
        if limit:
            # Keep at most `limit` rows per group, counted after the skipped ones
            if SA_14:
                stmt = stmt.filter(
                    sa.sql.literal_column('__group_row_n') <= (
                        (skip or 0) + limit))
                stmt = stmt.where(
                    sa.sql.literal_column('__group_row_n') <= (
                        (skip or 0) + limit))

        # Done
        return stmt