Exemplo n.º 1
0
async def test_session_save__does_not_commit(async_db: AsyncDatabase):
    """Check that ``save()`` in a session that is never committed stores nothing.

    A user is saved inside a session opened with ``async_db.session()``;
    a second session must then count zero users.
    """
    unsaved_user = User()
    async with async_db.session() as write_session:
        await write_session.save(unsaved_user)

    user_count_stmt = sa.select(sa.func.count(User.id))
    async with async_db.session() as read_session:
        user_count = await read_session.one(user_count_stmt)
        assert user_count == 0
Exemplo n.º 2
0
async def test_session_save_all__inserts_and_updates_multiple_types(
        async_db: AsyncDatabase):
    """Check ``save_all()`` with a mixed list of already-saved and new models.

    Three users and three groups are saved first. A second ``save_all()``
    call then receives those same objects together with three new users and
    three new groups. Afterwards each table must hold exactly six rows.
    """
    existing_users = [User(name=name) for name in ("uA", "uB", "uC")]
    new_users = [User(name=name) for name in ("u1", "u2", "u3")]

    existing_groups = [Group(name=name) for name in ("gA", "gB", "gC")]
    new_groups = [Group(name=name) for name in ("g1", "g2", "g3")]

    # First pass: persist only the "existing" records.
    async with async_db.begin() as session:
        await session.save_all(existing_users)
        await session.save_all(existing_groups)

    expected_user_count = len(existing_users) + len(new_users)
    expected_group_count = len(existing_groups) + len(new_groups)
    mixed_models = [
        *existing_users,
        *existing_groups,
        *new_users,
        *new_groups,
    ]  # type: ignore

    # Second pass: one call carrying both model types, saved and unsaved.
    async with async_db.begin() as session:
        await session.save_all(mixed_models)

    async with async_db.session() as session:
        user_count = await session.one(sa.select(sa.func.count(User.id)))
        assert user_count == expected_user_count

        group_count = await session.one(sa.select(sa.func.count(Group.id)))
        assert group_count == expected_group_count
Exemplo n.º 3
0
async def test_async_database_begin__starts_new_transaction(
        async_db: AsyncDatabase):
    """Check that ``begin()`` yields an ``AsyncSession`` already in a transaction.

    A nested ``begin()`` call must hand out a different session object.
    """
    async with async_db.begin() as outer_session:
        assert isinstance(outer_session, AsyncSession)
        assert outer_session.in_transaction()

        async with async_db.begin() as inner_session:
            assert inner_session is not outer_session
Exemplo n.º 4
0
def test_async_database_session__returns_new_session_object(
        async_db: AsyncDatabase):
    """Check that each ``session()`` call returns a distinct ``AsyncSession``.

    The first session must not be in a transaction right after creation.
    """
    first_session = async_db.session()
    assert isinstance(first_session, AsyncSession)
    assert not first_session.in_transaction()

    second_session = async_db.session()
    assert second_session is not first_session
Exemplo n.º 5
0
async def test_session_save_all__returns_list(async_db: AsyncDatabase):
    """Check that ``save_all()`` returns a list equal to its input list.

    The number of stored users must also match the number of models passed.
    """
    users = [User() for _ in range(3)]
    async with async_db.begin() as session:
        saved = await session.save_all(users)
        assert isinstance(saved, list)
        assert saved == users

    count_stmt = sa.select(sa.func.count(User.id))
    async with async_db.session() as session:
        stored_count = await session.one(count_stmt)
        assert stored_count == len(users)
Exemplo n.º 6
0
async def test_session_save__inserts_new_without_pk(async_db: AsyncDatabase):
    """Check that ``save()`` inserts a model created without a primary key.

    After the transaction the model must carry id 1 and its original name,
    and the users table must contain exactly one row.
    """
    fresh_user = User(name="n")
    async with async_db.begin() as session:
        await session.save(fresh_user)

    assert fresh_user.id == 1
    assert fresh_user.name == "n"

    count_stmt = sa.select(sa.func.count(User.id))
    async with async_db.session() as session:
        stored_count = await session.one(count_stmt)
        assert stored_count == 1
Exemplo n.º 7
0
async def test_session_save_all__accepts_single_model_and_returns_list(
        async_db: AsyncDatabase):
    """Check that ``save_all()`` accepts a single model instead of a list.

    The return value must be a one-element list holding that model, and one
    user row must be stored.
    """
    lone_user = User()
    async with async_db.begin() as session:
        saved = await session.save_all(lone_user)
        assert isinstance(saved, list)
        assert saved == [lone_user]

    count_stmt = sa.select(sa.func.count(User.id))
    async with async_db.session() as session:
        stored_count = await session.one(count_stmt)
        assert stored_count == 1
Exemplo n.º 8
0
def test_async_database_session__can_override_default_options(
        async_db: AsyncDatabase, option, default_value, override_value):
    """Check that a session option can be overridden per ``session()`` call.

    The option is read from ``session.sync_session``: first on a session
    created with no arguments (expects ``default_value``), then on one
    created with ``option=override_value`` (expects ``override_value``).
    """
    default_session = async_db.session()
    option_value = getattr(default_session.sync_session, option)
    default_message = (
        f"Expected session.{option} to be {default_value!r}, not {option_value!r}"
    )
    assert option_value == default_value, default_message

    overridden_session = async_db.session(**{option: override_value})
    option_value = getattr(overridden_session.sync_session, option)
    override_message = (
        f"Expected session.{option} to be {override_value!r}, not {option_value!r}"
    )
    assert option_value == override_value, override_message
Exemplo n.º 9
0
async def test_session_save__updates_existing_in_same_session(
        async_db: AsyncDatabase):
    """Check that re-saving a committed model in the same session adds no row.

    A user is added and committed, modified, saved again and committed again
    within one session. Only one user row may exist afterwards.
    """
    async with async_db.session() as session:
        tracked_user = User(id=1)
        session.add(tracked_user)
        await session.commit()

        # Change the committed object and save it through the same session.
        tracked_user.name = "n"
        await session.save(tracked_user)
        await session.commit()

    count_stmt = sa.select(sa.func.count(User.id))
    async with async_db.session() as session:
        stored_count = await session.one(count_stmt)
        assert stored_count == 1
Exemplo n.º 10
0
async def test_async_database_create_all():
    """Check that ``create_all()`` creates the model collection's tables.

    The names listed in ``sqlite_master`` must match the declared tables in
    number, and each name must be present in ``async_db.tables``.
    """
    collection = create_model_collection()
    declared_tables = collection["tables"]

    async_db = AsyncDatabase(
        "sqlite+aiosqlite://", model_class=collection["model_class"])
    await async_db.create_all()

    list_names = sa.text("SELECT name FROM sqlite_master")
    async with async_db.connect() as conn:
        created_names = (await conn.execute(list_names)).scalars().all()

    assert len(created_names) == len(declared_tables)
    for created_name in created_names:
        assert created_name in async_db.tables
Exemplo n.º 11
0
async def test_session_save__updates_existing(async_db: AsyncDatabase):
    """Check that saving a model whose primary key already exists updates it.

    A user with id 1 and no name is created first. A second ``User`` instance
    with the same id and a name is then saved. The saved instance must keep
    its id and name, and the users table must still hold exactly one row.
    """
    existing_user = (await async_create_users(async_db, count=1))[0]
    assert existing_user.id == 1
    assert existing_user.name is None

    new_user = User(id=1, name="n")
    async with async_db.begin() as session:
        # Save the *new* instance. Re-saving ``existing_user`` would leave
        # ``new_user`` unused and never exercise the update path.
        await session.save(new_user)

    assert new_user.id == 1
    assert new_user.name == "n"

    async with async_db.session() as session:
        select_count = sa.select(sa.func.count(User.id))
        assert await session.one(select_count) == 1
Exemplo n.º 12
0
async def test_session_save_all__raises_on_duplicate_primary_keys_in_list(
        async_db: AsyncDatabase):
    """Check that ``save_all()`` raises ``TypeError`` for duplicate primary keys.

    Two distinct ``User`` objects sharing id 1 are passed in one list.
    """
    duplicate_pk_users = [User(id=1) for _ in range(2)]

    with pytest.raises(TypeError):
        async with async_db.begin() as session:
            await session.save_all(duplicate_pk_users)
Exemplo n.º 13
0
async def async_db() -> t.AsyncGenerator[AsyncDatabase, None]:
    """Yield an ``AsyncDatabase`` on the ``sqlite+aiosqlite://`` URI.

    All tables are created before the database is yielded, and the database
    is closed once the generator resumes.
    """
    database = AsyncDatabase(
        "sqlite+aiosqlite://",
        model_class=Model,
        echo=False,
    )
    await database.create_all()
    yield database
    await database.close()
Exemplo n.º 14
0
async def test_session_one_or_none__returns_one_model_or_none_or_raises(
        async_db: AsyncDatabase):
    """Check the three outcomes of ``one_or_none()`` on ORM selects.

    A select matching one user returns that ``User``, a select matching no
    rows returns ``None``, and a select matching several users raises
    ``MultipleResultsFound``.
    """
    target_user = (await async_create_users(async_db))[0]

    async with async_db.session() as session:
        matching_stmt = sa.select(User).where(User.id == target_user.id)
        found = await session.one_or_none(matching_stmt)
        assert isinstance(found, User)
        assert found.id == target_user.id

        # No user is created with id 0, so this select matches nothing.
        missing_stmt = sa.select(User).where(User.id == 0)
        missing = await session.one_or_none(missing_stmt)
        assert missing is None

    with pytest.raises(MultipleResultsFound):
        async with async_db.session() as session:
            await session.one_or_none(sa.select(User))
Exemplo n.º 15
0
async def async_filedb(
        tmp_path: Path) -> t.AsyncGenerator[AsyncDatabase, None]:
    """Yield an ``AsyncDatabase`` stored in ``test_async.db`` under ``tmp_path``.

    All tables are created before the database is yielded, and the database
    is closed once the generator resumes.
    """
    database_file = tmp_path / "test_async.db"
    database = AsyncDatabase(
        f"sqlite+aiosqlite:///{database_file}", model_class=Model)
    await database.create_all()
    yield database
    await database.close()
Exemplo n.º 16
0
async def test_async_session_all__returns_all_models(async_db: AsyncDatabase):
    """Check that ``all()`` on ``select(User)`` returns every user as a ``User``."""
    created_users = await async_create_users(async_db)

    async with async_db.session() as session:
        fetched = await session.all(sa.select(User))
        assert len(fetched) == len(created_users)
        for fetched_model in fetched:
            assert isinstance(fetched_model, User)
Exemplo n.º 17
0
def test_async_database_settings__accepts_session_options_dict():
    """Check how ``session_options`` combines with session keyword arguments.

    The expected result is the keyword arguments updated with the
    ``session_options`` dict, so the dict wins for ``autoflush``.
    """
    session_options = {"autoflush": False}
    keyword_options = {"expire_on_commit": True, "autoflush": True}

    expected_options = dict(keyword_options)
    expected_options.update(session_options)

    async_db = AsyncDatabase(
        "sqlite+aiosqlite://",
        session_options=session_options,
        **keyword_options,
    )
    assert async_db.settings.get_session_options() == expected_options
Exemplo n.º 18
0
def test_async_database_settings__accepts_engine_options_dict():
    """Check how ``engine_options`` combines with engine keyword arguments.

    The expected result is the keyword arguments updated with the
    ``engine_options`` dict, so the dict wins for ``echo``.
    """
    engine_options = {"echo": True}
    keyword_options = {"encoding": "utf8", "echo": False}

    expected_options = dict(keyword_options)
    expected_options.update(engine_options)

    async_db = AsyncDatabase(
        "sqlite+aiosqlite://",
        engine_options=engine_options,
        **keyword_options,
    )
    assert async_db.settings.get_engine_options() == expected_options
Exemplo n.º 19
0
async def test_session_first__returns_first_result_or_none(
        async_db: AsyncDatabase):
    """Check that ``first()`` returns a ``User`` when rows exist, else ``None``."""
    await async_create_users(async_db)

    async with async_db.session() as session:
        first_match = await session.first(sa.select(User))
        assert isinstance(first_match, User)

        # No user is created with id 0, so this select matches nothing.
        no_match_stmt = sa.select(User).where(User.id == 0)
        no_match = await session.first(no_match_stmt)
        assert no_match is None
Exemplo n.º 20
0
async def async_create_users(
        async_db: AsyncDatabase,
        count: int = 3,
        overrides: t.Optional[dict] = None) -> t.List[User]:
    """Insert ``count`` users with ids ``1..count`` and return them.

    Args:
        async_db: Database the users are added to.
        count: Number of users to create.
        overrides: Extra keyword arguments passed to every ``User``
            constructor call. ``None`` means no extra arguments.

    Returns:
        The list of ``User`` objects that were added.
    """
    extra_fields = {} if overrides is None else overrides
    user_ids = range(1, count + 1)
    created = [User(id=user_id, **extra_fields) for user_id in user_ids]

    async with async_db.begin(expire_on_commit=False) as session:
        session.add_all(created)

    return created
Exemplo n.º 21
0
def test_async_database_models():
    """Check that ``async_db.models`` exposes the collection's model classes.

    The mapping must have one entry per declared model, and each value must
    be the identical class object from the collection.
    """
    collection = create_model_collection()
    declared_models = collection["models"]

    async_db = AsyncDatabase(
        "sqlite+aiosqlite://", model_class=collection["model_class"])
    assert len(async_db.models) == len(declared_models)

    for registered_model in async_db.models.values():
        position = declared_models.index(registered_model)
        assert registered_model is declared_models[position]
Exemplo n.º 22
0
async def test_session_all__uniquifies_joinedload_results(
        async_db: AsyncDatabase):
    """Check that ``all()`` returns one result per user under ``joinedload``.

    Three users with three addresses each are stored. Selecting users with
    ``joinedload(User.addresses)`` must yield as many results as users.
    """
    users = []
    for user_id in (1, 2, 3):
        # Address ids are 11-13, 21-23, 31-33: user id * 10 + offset.
        addresses = [
            Address(id=user_id * 10 + offset) for offset in (1, 2, 3)
        ]
        users.append(User(id=user_id, addresses=addresses))

    async with async_db.begin() as session:
        session.add_all(users)

    joined_stmt = sa.select(User).options(joinedload(User.addresses))
    async with async_db.session() as session:
        fetched = await session.all(joined_stmt)
        assert len(fetched) == len(users)
Exemplo n.º 23
0
async def test_session_one_or_none__returns_one_row_or_none_or_raises(
        async_db: AsyncDatabase):
    """Check the three outcomes of ``one_or_none()`` on raw SQL text.

    A query matching one row returns a non-``User`` result exposing ``id``
    and ``name``, a query matching no rows returns ``None``, and a query
    matching several rows raises ``MultipleResultsFound``.
    """
    target_user = (await async_create_users(async_db))[0]

    async with async_db.session() as session:
        row = await session.one_or_none(
            sa.text("SELECT * FROM users WHERE id = :id"),
            params={"id": target_user.id},
        )
        assert not isinstance(row, User)
        assert row.id == target_user.id
        assert row.name == target_user.name

        # No user is created with id 0, so this query matches nothing.
        missing_row = await session.one_or_none(
            sa.text("SELECT * FROM users WHERE id = :id"),
            params={"id": 0},
        )
        assert missing_row is None

    with pytest.raises(MultipleResultsFound):
        async with async_db.session() as session:
            await session.one_or_none(sa.text("SELECT * FROM users"))
Exemplo n.º 24
0
def test_async_database_tables():
    """Check that ``async_db.tables`` maps table names to the declared tables.

    The mapping must have one entry per declared table, each value must be
    among the declared tables, and each key must equal that table's name.
    """
    collection = create_model_collection()
    declared_tables = collection["tables"]

    async_db = AsyncDatabase(
        "sqlite+aiosqlite://", model_class=collection["model_class"])
    assert len(async_db.tables) == len(declared_tables)

    for registered_name, registered_table in async_db.tables.items():
        assert registered_table in declared_tables
        position = declared_tables.index(registered_table)
        assert declared_tables[position].name == registered_name
Exemplo n.º 25
0
async def test_async_database_drop_all():
    """Check that ``drop_all()`` removes every table made by ``create_all()``.

    The number of entries in ``sqlite_master`` must equal the number of
    declared tables after ``create_all()`` and drop to zero after
    ``drop_all()``.
    """
    model_collection = create_model_collection()
    model_class = model_collection["model_class"]
    model_tables = model_collection["tables"]

    async_db = AsyncDatabase("sqlite+aiosqlite://", model_class=model_class)
    count_tables = sa.text("SELECT COUNT(name) FROM sqlite_master")

    # The connection is opened before create_all(), as before, but now in a
    # context manager so it is closed even when an assertion fails.
    async with async_db.connect() as conn:
        await async_db.create_all()
        assert (
            await conn.execute(count_tables)
        ).scalar_one() == len(model_tables)

        await async_db.drop_all()
        assert (await conn.execute(count_tables)).scalar_one() == 0
Exemplo n.º 26
0
def test_async_database_settings(key: str, value: t.Any, kind: str):
    """Check that a setting passed to ``AsyncDatabase`` is exposed everywhere.

    The value must be readable as an attribute, by item access, and through
    ``dict(settings)``. It must also appear in the session or engine options,
    depending on ``kind``.

    Raises:
        RuntimeError: If ``kind`` is neither ``"session"`` nor ``"engine"``.
    """
    settings = {key: value}
    # NOTE: Using poolclass so that pool_* options can be used with sqlite.
    if "poolclass" not in settings:
        settings["poolclass"] = QueuePool
    async_db = AsyncDatabase("sqlite+aiosqlite://", **settings)

    db_settings = async_db.settings
    assert getattr(db_settings, key) == value
    assert db_settings[key] == value
    assert dict(db_settings)[key] == value

    if kind not in ("session", "engine"):
        raise RuntimeError(
            f"kind must be one of 'session' or 'engine', not {kind!r}")

    if kind == "session":
        kind_options = async_db.settings.get_session_options()
    else:
        kind_options = async_db.settings.get_engine_options()
    assert kind_options[key] == value
Exemplo n.º 27
0
async def test_session_first__returns_first_result_or_none_from_non_orm_query(
    async_db: AsyncDatabase, ):
    """Check ``first()`` on raw SQL text.

    A query with matching rows returns a non-``User`` result whose ``id``
    and ``name`` equal those of the first created user. A query with no
    matching rows returns ``None``.
    """
    expected_user = (await async_create_users(async_db))[0]

    active_stmt = sa.text(
        "SELECT * FROM users WHERE active = :active ORDER BY id")
    missing_stmt = sa.text("SELECT * FROM users WHERE id = :id ORDER BY id")

    async with async_db.session() as session:
        row = await session.first(active_stmt, params={"active": True})
        assert not isinstance(row, User)
        assert row.id == expected_user.id
        assert row.name == expected_user.name

        # No user is created with id 0, so this query matches nothing.
        missing_row = await session.first(missing_stmt, params={"id": 0})
        assert missing_row is None
Exemplo n.º 28
0
async def test_async_database_reflect(tmp_path: Path):
    """Check that ``reflect()`` fills ``tables`` from an existing database file.

    One ``AsyncDatabase`` creates the tables in a file under ``tmp_path``.
    A second one, built without a model class, starts with no tables and
    must hold every declared table name after ``reflect()``.
    """
    database_file = tmp_path / "reflect_async.db"
    uri = f"sqlite+aiosqlite:///{database_file}"

    collection = create_model_collection()
    declared_tables = collection["tables"]

    # Create the schema on disk with a throwaway database object.
    schema_creator = AsyncDatabase(uri, model_class=collection["model_class"])
    await schema_creator.create_all()

    async_db = AsyncDatabase(uri)
    assert len(async_db.tables) == 0

    await async_db.reflect()

    assert len(async_db.tables) == len(declared_tables)
    declared_by_name = {table.name: table for table in declared_tables}

    for reflected_name in async_db.tables:
        assert reflected_name in declared_by_name
Exemplo n.º 29
0
async def test_session_all__returns_all_rows_from_non_orm_query(
        async_db: AsyncDatabase):
    """Check that ``all()`` on raw SQL text returns every matching row.

    Users are created with ``active=False`` and selected by that flag. One
    row per created user must come back. No row may be a ``User`` instance,
    and each row must match the corresponding created user column by column.
    """
    inactive_users_by_id = {
        user.id: user
        for user in (
            await async_create_users(async_db, overrides={"active": False}))
    }

    async with async_db.session() as session:
        results = await session.all(
            sa.text("SELECT * FROM users WHERE active = :active"),
            params={"active": False})

        # Without this check an empty result would skip every assertion in
        # the loop below and let the test pass vacuously.
        assert len(results) == len(inactive_users_by_id)

        for user in results:
            assert not isinstance(user, User)

            inactive_user = inactive_users_by_id.get(user.id)
            assert inactive_user

            for key, value in dict(inactive_user).items():
                assert user[key] == value
Exemplo n.º 30
0
async def test_session_save_all__raises_on_invalid_type(
        async_db: AsyncDatabase, value: t.Any, exception: t.Type[Exception]):
    """Check that ``save_all(value)`` raises ``exception`` for the given input."""
    expected_failure = pytest.raises(exception)
    with expected_failure:
        async with async_db.session() as session:
            await session.save_all(value)