Пример #1
0
def test_select_results(
    connection: sa.engine.Connection,
    query_object: QueryObjectDict,
    expected_results: list[dict],
):
    """ Load `query_object` against one real row and compare with `expected_results`

    The model exposes a readcode-annotated @property (`abc`) and a hybrid property
    (`abch`) so that both kinds of computed fields can be selected.
    """
    # A fresh declarative base: only the model below is registered on it
    Base = sacompat.declarative_base()

    class Model(IdManyFieldsMixin, Base):
        __tablename__ = 'a'

        @property
        @loads_attributes_readcode()
        def abc(self):
            return ' '.join((self.a, self.b, self.c))

        @sa.ext.hybrid.hybrid_property
        def abch(self):
            # Not evaluated in Python: the query selects this through the SQL expression below
            raise NotImplementedError

        @abch.expression
        def abch(cls):
            return cls.a + ' ' + cls.b + ' ' + cls.c

    with created_tables(connection, Base):
        # A single row is enough for this test
        single_row = id_manyfields('m', 1)
        insert(connection, Model, single_row)

        typical_test_query_results(connection, query_object, Model, expected_results)
Пример #2
0
def test_filter_sql(connection: sa.engine.Connection, query_object: QueryObjectDict, expected_query_lines: list[str]):
    """ Compare the SQL text generated for a filtering `query_object` with `expected_query_lines`

    No rows are loaded; the model offers an ARRAY column, a hybrid property and a
    one-to-many relationship so that filters can target each of them.
    """
    Base = sacompat.declarative_base()

    class Model(IdManyFieldsMixin, Base):
        __tablename__ = 'a'

        # pg.ARRAY is Postgres-specific; its comparator provides .contains() and .overlaps()
        tags = sa.Column(pg.ARRAY(sa.String))

        # Hybrid property: filters can use it through its SQL expression
        @sa.ext.hybrid.hybrid_property
        def awow(self):
            pass  # presumably never evaluated in Python by this SQL-only test

        @awow.expression
        def awow(cls):
            return cls.a + '!'

        related = sa.orm.relationship('Related', back_populates='parent')

    class Related(IdManyFieldsMixin, Base):
        __tablename__ = 'r'

        parent_id = sa.Column(sa.ForeignKey(Model.id))
        parent = sa.orm.relationship(Model, back_populates='related')

    # Only SQL text is compared, so `connection` is not used by this body
    typical_test_sql_query_text(query_object, Model, expected_query_lines)
Пример #3
0
def test_joined_sort(
    connection: sa.engine.Connection,
    query_object: QueryObjectDict,
    expected_query_lines: list[str],
    expected_results: list[dict],
):
    """ Sorting with JOINs: check both the generated SQL lines and the loaded results """
    Base = sacompat.declarative_base()

    # User.articles (one-to-many) <-> Article.author (many-to-one)
    class User(IdManyFieldsMixin, Base):
        __tablename__ = 'u'

        articles = sa.orm.relationship('Article', back_populates='author')

    class Article(IdManyFieldsMixin, Base):
        __tablename__ = 'a'

        user_id = sa.Column(sa.ForeignKey(User.id))
        author = sa.orm.relationship(User, back_populates='articles')

    with created_tables(connection, Base):
        # One user, who wrote all three articles
        insert(connection, User, id_manyfields('u', 1))
        insert(connection, Article, *(
            id_manyfields('a', article_no, user_id=1)
            for article_no in (1, 2, 3)
        ))

        typical_test_query_text_and_results(connection, query_object, User, expected_query_lines, expected_results)
Пример #4
0
def test_skiplimit_sql(connection: sa.engine.Connection, query_object: QueryObjectDict, expected_query_lines: list[str]):
    """ Compare the SQL text generated for a skip/limit `query_object` with `expected_query_lines` """
    Base = sacompat.declarative_base()

    class Model(IdManyFieldsMixin, Base):
        __tablename__ = 'a'

        # Postgres ARRAY column: its comparator provides .contains() and .overlaps()
        tags = sa.Column(pg.ARRAY(sa.String))

    # SQL-only comparison; `connection` is not used by this body
    typical_test_sql_query_text(query_object, Model, expected_query_lines)
Пример #5
0
def test_query_customize_statements(connection: sa.engine.Connection, query_object: QueryObjectDict, expected_columns: list[str]):
    """ Query.customize_statements: a hook that adds per-user security conditions to every statement

    The hook is consulted once per load path (the root Article query and each joined
    relationship) and must add a WHERE condition restricting rows to one user.
    """
    Base = sacompat.declarative_base()

    class User(IdManyFieldsMixin, Base):
        __tablename__ = 'u'

    class Article(IdManyFieldsMixin, Base):
        __tablename__ = 'a'

        user_id = sa.Column(sa.ForeignKey(User.id))

        author = sa.orm.relationship(User)
        comments = sa.orm.relationship('Comment')

    class Comment(IdManyFieldsMixin, Base):
        __tablename__ = 'c'

        article_id = sa.Column(sa.ForeignKey(Article.id))
        user_id = sa.Column(sa.ForeignKey(User.id))

    q = Query(query_object, Article)

    def security(q: Query, stmt: sa.sql.Select) -> sa.sql.Select:
        """ Restrict `stmt` to rows that belong to the allowed user; unknown load paths are an error """
        ALLOWED_USER_ID = 1

        path = q.load_path
        # Articles and comments are owned through their `user_id` column
        if path in ((Article, ), (Article, 'comments', Comment)):
            return stmt.where(q.Model.user_id == ALLOWED_USER_ID)
        # The author row *is* the user, so match it by primary key
        if path == (Article, 'author', User):
            return stmt.where(q.Model.id == ALLOWED_USER_ID)
        raise NotImplementedError

    q.customize_statements.append(security)

    # Compare the generated statements
    assert_query_statements_lines(q, *expected_columns)
Пример #6
0
def test_sort_results(
    connection: sa.engine.Connection,
    query_object: QueryObjectDict,
    expected_results: list[dict],
):
    """ Load a sorting `query_object` against three real rows and compare with `expected_results` """
    Base = sacompat.declarative_base()

    class Model(IdManyFieldsMixin, Base):
        __tablename__ = 'a'

    with created_tables(connection, Base):
        # Rows id_manyfields('m', 1..3), inserted in that order
        insert(connection, Model, *(id_manyfields('m', n) for n in range(1, 4)))

        typical_test_query_results(connection, query_object, Model, expected_results)
Пример #7
0
def test_select_sql(
    connection: sa.engine.Connection,
    query_object: QueryObjectDict,
    expected_columns: list[str],
):
    """ Compare the selected SQL columns with `expected_columns`

    The model has a @loads_attributes_readcode property whose body is inspected to find
    which columns it needs; the test first verifies that detection, then the SQL.
    """
    Base = sacompat.declarative_base()

    class Model(IdManyFieldsMixin, Base):
        __tablename__ = 'a'

        @property
        @loads_attributes_readcode()
        def abc(self):
            self.a, self.b, self.c  # readcode will get this
            raise NotImplementedError  # we don't care in this test

    # The property must be recognized, and its attribute reads extracted in source order
    abc_property = Model.abc
    assert is_annotated_with_loads(abc_property)
    assert get_property_loads_attribute_names(abc_property) == ('a', 'b', 'c')

    typical_test_sql_selected_columns(query_object, Model, expected_columns)
Пример #8
0
def test_query_object_with_sa_model(query_str: str, expected_query_object: dict):
    """ Build a Query Object from a GraphQL query and rewrite it against a real SqlAlchemy model

    Args:
        query_str: GraphQL query text
        expected_query_object: expected `.dict()` of the rewritten Query Object
    """
    Base = sacompat.declarative_base()

    class Model(IdManyFieldsMixin, Base):
        __tablename__ = 'models'

        # Self-referencing foreign keys, each backing one relationship
        object_id = sa.Column(sa.ForeignKey('models.id'))
        object_ids = sa.Column(sa.ForeignKey('models.id'))

        object = sa.orm.relationship('Model', foreign_keys=object_id)
        objects = sa.orm.relationship('Model', foreign_keys=object_ids)

    # GraphQL schema + parsed query document
    qctx = prepare_graphql_query_for(schema_prepare(), query_str)

    def to_snake_case(name: str) -> str:
        # Stand-in for ariadne's converter (not available here): only `objectId` is renamed
        return 'object_id' if name == 'objectId' else name

    sa_rewriter = rewrite.RewriteSAModel(
        rewrite.Transform(to_snake_case),
        Model=Model,
    )
    qsets = jessiql.QuerySettings(rewriter=sa_rewriter)

    # GraphQL-level Query Object -> model-level Query Object
    api_query_object = query_object_for(qctx.info, runtime_type='Model')
    rewritten = qsets.rewriter.query_object(api_query_object)
    assert rewritten.dict() == expected_query_object
Пример #9
0
def test_joined_select(connection: sa.engine.Connection, model: str, query_object: QueryObjectDict, expected_query_lines: list[str], expected_results: list[dict]):
    """ Typical test: JOINs, SQL and results

    Builds a User/Article/Comment/Tag schema covering one-to-many, many-to-one and
    many-to-many relationships, fills it with a small fixed dataset, then runs
    `query_object` against the model class named by `model` and compares both the
    generated SQL lines and the loaded results.

    Args:
        model: name of one of the model classes defined in this test (e.g. 'User', 'Article')
    """
    # Models
    Base = sacompat.declarative_base()

    # One-to-Many, FK on remote side:
    #   User.articles: User -> Article (Article.user_id)
    #   User.comments: User -> Comment (Comment.user_id)
    #   Article.comments: Article -> Comment (Comment.article_id)
    # Many-to-One, FK on local side:
    #   Comment.article: Comment -> Article (Comment.article_id)
    #   Comment.author: Comment -> User (Comment.user_id)
    #   Article.author: Article -> User (Article.user_id)
    # Many-to-Many:
    #   Article.tags: Article -> (m2m) -> Tag

    class User(IdManyFieldsMixin, Base):
        __tablename__ = 'u'

        articles = sa.orm.relationship('Article', back_populates='author')
        comments = sa.orm.relationship('Comment', back_populates='author')

    class Tag(IdManyFieldsMixin, Base):
        __tablename__ = 't'

    class Article(IdManyFieldsMixin, Base):
        __tablename__ = 'a'

        user_id = sa.Column(sa.ForeignKey(User.id))
        author = sa.orm.relationship(User, back_populates='articles')

        comments = sa.orm.relationship('Comment', back_populates='article')
        # `secondary` is a lambda because ArticleTagLink is declared further below
        tags = sa.orm.relationship(Tag, secondary=lambda: ArticleTagLink.__table__, order_by=Tag.a, lazy='selectin')

        @property
        @loads_attributes_readcode()
        def abc(self):
            return ' '.join((self.a, self.b, self.c))

    class Comment(IdManyFieldsMixin, Base):
        __tablename__ = 'c'

        article_id = sa.Column(sa.ForeignKey(Article.id))
        article = sa.orm.relationship(Article, back_populates='comments')

        user_id = sa.Column(sa.ForeignKey(User.id))
        author = sa.orm.relationship(User, back_populates='comments')

    class ArticleTagLink(Base):  # Many to Many
        __tablename__ = 'at'

        id = sa.Column(sa.Integer, primary_key=True, nullable=False)
        article_id = sa.Column(Article.id.type, sa.ForeignKey(Article.id, ondelete='CASCADE'), nullable=False)
        tag_id = sa.Column(Tag.id.type, sa.ForeignKey(Tag.id, ondelete='CASCADE'), nullable=False)

    # Explicit name -> class table for the `model` parameter.
    # An explicit mapping only accepts real model names, unlike a `locals()` lookup,
    # which would also resolve e.g. 'connection', and it does not depend on `locals()` semantics.
    models_by_name = {
        'User': User,
        'Tag': Tag,
        'Article': Article,
        'Comment': Comment,
        'ArticleTagLink': ArticleTagLink,
    }

    # Data
    with created_tables(connection, Base):
        # Insert some rows
        insert(connection, User,
            id_manyfields('u', 1),
            id_manyfields('u', 2),
            id_manyfields('u', 3),
        )
        insert(connection, Article,
            # 2 articles from User(id=1)
            id_manyfields('a', 1, user_id=1),
            id_manyfields('a', 2, user_id=1),
            # 1 article from User(id=2)
            id_manyfields('a', 3, user_id=2),
            # 1 article from User(id=3)
            id_manyfields('a', 4, user_id=3),
            # article with no user
            # this is a potential stumbling block for conditions that fail to filter it out
            id_manyfields('a', 5, user_id=None),
        )
        insert(connection, Comment,
            # User(id=1), User(id=2), User(id=3) commented on Article(id=1)
            id_manyfields('c', 1, user_id=1, article_id=1),
            id_manyfields('c', 2, user_id=2, article_id=1),
            id_manyfields('c', 3, user_id=3, article_id=1),
            # User(id=1) commented on Article(id=2)
            id_manyfields('c', 4, user_id=1, article_id=2),
            # User(id=1) commented on Article(id=3)
            id_manyfields('c', 5, user_id=1, article_id=3),
            # comment with no user/article
            # this is a potential stumbling block for conditions that fail to filter it out
            id_manyfields('c', 6, user_id=None, article_id=None),
        )
        insert(connection, Tag,
            id_manyfields('t', 1),
            id_manyfields('t', 2),
            id_manyfields('t', 3),
        )
        insert(connection, ArticleTagLink,
            # Article id=1: 3 tags
            dict(id=1, article_id=1, tag_id=1),
            dict(id=2, article_id=1, tag_id=2),
            dict(id=3, article_id=1, tag_id=3),
            # Article id=2: one tag
            dict(id=4, article_id=2, tag_id=1),
            # Article id=3: one tag
            dict(id=5, article_id=3, tag_id=1),
        )

        # Test
        Model = models_by_name[model]
        typical_test_query_text_and_results(connection, query_object, Model, expected_query_lines, expected_results)
Пример #10
0
def test_skiplimit_cursor_pagination(connection: sa.engine.Connection):
    """ Test pagination with cursors

    Covers:
    * page_links() raising RuntimeError before results are fetched;
    * skip/limit cursors, navigating forward and backward;
    * the "no next page" cases near the end of the result set;
    * keyset cursors (backward navigation is still a TODO);
    * skip/limit applied to a joined relationship.
    """
    def main():
        # ### Test: wrong usage
        # Test: cannot get a link before results are fetched
        q = jessiql.Query(dict(limit=2), User)

        # Not possible to generate links before results are fetched
        with pytest.raises(RuntimeError):
            q.page_links()

        # Fetch results. Now possible.
        q.fetchall(connection)
        q.page_links()  # no error

        # ### Test: skip/limit pages

        # Test: Page 0
        # No prev page, have next page
        q, res = load(select=['id'], sort=['a'], limit=2)

        assert ids(res) == [1, 2]
        assert decode_links(q.page_links()) == (None, dict(skip=2, limit=2))

        # Test: next page
        # Have both prev & next pages
        q, res = load(select=['id'], sort=['a'], after=q.page_links().next)

        assert ids(res) == [3, 4]
        assert decode_links(q.page_links()) == (dict(skip=0, limit=2),
                                                dict(skip=4, limit=2))

        # Test: next page (last page)
        q, res = load(select=['id'], sort=['a'], after=q.page_links().next)

        assert ids(res) == [5]
        assert decode_links(q.page_links()) == (dict(skip=2, limit=2), None)

        # Test: prev page
        q, res = load(select=['id'], sort=['a'], before=q.page_links().prev)

        assert ids(res) == [3, 4]
        assert decode_links(q.page_links()) == (dict(skip=0, limit=2),
                                                dict(skip=4, limit=2))

        # Test: prev page (first)
        q, res = load(select=['id'], sort=['a'], before=q.page_links().prev)

        assert ids(res) == [1, 2]
        assert decode_links(q.page_links()) == (None, dict(skip=2, limit=2))

        # ### Test: approaching the end
        # Because this is the end, there should be no next page

        # Case 1. Got no rows => No next page.
        cursor = SkipCursorData(5, 2).encode()
        q, res = load(select=['id'], sort=['a'], after=cursor)
        assert ids(res) == []
        assert q.page_links().next is None

        # Case 2: Got one row, result set incomplete => No next page.
        cursor = SkipCursorData(4, 2).encode()
        q, res = load(select=['id'], sort=['a'], after=cursor)
        assert ids(res) == [5]
        assert q.page_links().next is None

        # Case 3: Got two rows, but there's nothing beyond that => No next page.
        cursor = SkipCursorData(3, 2).encode()
        q, res = load(select=['id'], sort=['a'], after=cursor)
        assert ids(res) == [4, 5]
        assert q.page_links().next is None

        # ### Test: keyset

        # Test: first page
        q, res = load(select=['id', 'a'], sort=['a', 'id'], limit=2)

        assert ids(res) == [1, 2]
        assert decode_links(q.page_links()) == (None,
                                                dict(cols=['a', 'id'],
                                                     limit=2,
                                                     op='>',
                                                     val=['u-2-a', 2]))

        # Test: next page
        q, res = load(select=['id', 'a'],
                      sort=['a', 'id'],
                      after=q.page_links().next)

        assert ids(res) == [3, 4]
        assert decode_links(q.page_links()) == (dict(cols=['a', 'id'],
                                                     limit=2,
                                                     op='<',
                                                     val=['u-3-a', 3]),
                                                dict(cols=['a', 'id'],
                                                     limit=2,
                                                     op='>',
                                                     val=['u-4-a', 4]))

        # Test: next page
        q, res = load(select=['id', 'a'],
                      sort=['a', 'id'],
                      after=q.page_links().next)

        assert ids(res) == [5]
        assert decode_links(q.page_links()) == (dict(cols=['a', 'id'],
                                                     limit=2,
                                                     op='<',
                                                     val=['u-5-a', 5]), None)

        # Test: prev page
        # The load still runs (it must not fail), but its results are not checked yet.
        q, res = load(select=['id', 'a'],
                      sort=['a', 'id'],
                      before=q.page_links().prev)

        # assert ids(res) == [3, 4]  # TODO: backward navigation does not work yet
        # assert decode_links(q.page_links()) == (dict(cols=['a', 'id'], limit=2, op='<', val=['u-3-a', 3]),
        #                                         dict(cols=['a', 'id'], limit=2, op='>', val=['u-4-a', 4]))

        # ### Test: related pagination
        # User(id=1) has articles 1, 2, 3; skip=1 + limit=1 must leave only article 2
        q, res = load(select=['id'],
                      join={
                          'articles': {
                              'select': ['id'],
                              'skip': 1,
                              'limit': 1,
                              'sort': ['id+'],
                          }
                      },
                      filter={'id': 1})
        assert res == [{
            'id': 1,
            'articles': [
                # skipped: {'user_id': 1, 'id': 1}
                {'user_id': 1, 'id': 2},
                # beyond the limit: {'user_id': 1, 'id': 3}
            ],
        }]

    # Models
    Base = sacompat.declarative_base()

    class User(IdManyFieldsMixin, Base):
        __tablename__ = 'u'

        articles = sa.orm.relationship('Article', back_populates='author')

    class Article(IdManyFieldsMixin, Base):
        __tablename__ = 'a'

        user_id = sa.Column(sa.ForeignKey(User.id))
        author = sa.orm.relationship(User, back_populates='articles')

    # Helpers
    def load(**query_object) -> tuple[jessiql.Query, list[dict]]:
        """ Given a Query Object, load results, return (Query, result) """
        q = jessiql.Query(query_object, User)
        res = q.fetchall(connection)
        return q, res

    def ids(row_dicts: list[dict]) -> list[int]:
        """ Extract the 'id' value of every row dict, preserving order """
        return [row['id'] for row in row_dicts]

    # Data
    with created_tables(connection, Base):
        # Users 1..5; only User(id=1) has articles
        insert(connection, User,
               *(id_manyfields('u', user_no) for user_no in range(1, 6)))
        insert(
            connection,
            Article,
            id_manyfields('a', 1, user_id=1),
            id_manyfields('a', 2, user_id=1),
            id_manyfields('a', 3, user_id=1),
        )

        # Test
        main()