async def test_session_save__does_not_commit(async_db: AsyncDatabase):
    """save() inside a plain session() does not persist without an explicit commit."""
    async with async_db.session() as uncommitted:
        await uncommitted.save(User())

    # A brand-new session must not see the row the first session never committed.
    user_count_stmt = sa.select(sa.func.count(User.id))
    async with async_db.session() as reader:
        assert await reader.one(user_count_stmt) == 0
async def test_session_save_all__inserts_and_updates_multiple_types(
    async_db: AsyncDatabase,
):
    """save_all() accepts mixed model types, updating persisted ones and inserting new ones."""
    existing_users = [User(name=label) for label in ("uA", "uB", "uC")]
    new_users = [User(name=label) for label in ("u1", "u2", "u3")]
    existing_groups = [Group(name=label) for label in ("gA", "gB", "gC")]
    new_groups = [Group(name=label) for label in ("g1", "g2", "g3")]

    # Persist the "existing" models first so the second save_all() must update them.
    async with async_db.begin() as session:
        for batch in (existing_users, existing_groups):
            await session.save_all(batch)

    mixed_models = [*existing_users, *existing_groups, *new_users, *new_groups]  # type: ignore
    expected_users = [*existing_users, *new_users]
    expected_groups = [*existing_groups, *new_groups]

    async with async_db.begin() as session:
        await session.save_all(mixed_models)

    # Each table should hold exactly the existing + new rows, with no duplicates.
    async with async_db.session() as session:
        for model, expected in ((User, expected_users), (Group, expected_groups)):
            count_stmt = sa.select(sa.func.count(model.id))
            assert await session.one(count_stmt) == len(expected)
async def test_async_database_begin__starts_new_transaction(
    async_db: AsyncDatabase,
):
    """begin() yields an AsyncSession already in a transaction; each call yields a new session."""
    async with async_db.begin() as outer:
        assert isinstance(outer, AsyncSession)
        assert outer.in_transaction()

        async with async_db.begin() as inner:
            assert inner is not outer
def test_async_database_session__returns_new_session_object(
    async_db: AsyncDatabase,
):
    """session() returns a fresh AsyncSession that has not started a transaction."""
    first = async_db.session()
    assert isinstance(first, AsyncSession)
    assert not first.in_transaction()

    # A second call must not hand back the same object.
    second = async_db.session()
    assert second is not first
async def test_session_save_all__returns_list(async_db: AsyncDatabase):
    """save_all() returns the saved models as a list and persists every one of them."""
    models = [User() for _ in range(3)]
    async with async_db.begin() as session:
        saved = await session.save_all(models)
        assert isinstance(saved, list)
        assert saved == models

    async with async_db.session() as session:
        total = await session.one(sa.select(sa.func.count(User.id)))
        assert total == len(models)
async def test_session_save__inserts_new_without_pk(async_db: AsyncDatabase):
    """save() on a model without a primary key inserts it and assigns id 1 on an empty table."""
    fresh = User(name="n")
    async with async_db.begin() as session:
        await session.save(fresh)
        # Read attributes while still inside the transaction block.
        assert fresh.id == 1
        assert fresh.name == "n"

    async with async_db.session() as session:
        total = await session.one(sa.select(sa.func.count(User.id)))
        assert total == 1
async def test_session_save_all__accepts_single_model_and_returns_list(
    async_db: AsyncDatabase,
):
    """save_all() given a single model (not a list) still returns a one-element list."""
    lone_user = User()
    async with async_db.begin() as session:
        saved = await session.save_all(lone_user)
        assert isinstance(saved, list)
        assert saved == [lone_user]

    async with async_db.session() as session:
        total = await session.one(sa.select(sa.func.count(User.id)))
        assert total == 1
def test_async_database_session__can_override_default_options(
    async_db: AsyncDatabase, option, default_value, override_value
):
    """session(**options) overrides a default session option on the underlying sync session.

    ``option``/``default_value``/``override_value`` are presumably supplied by a
    parametrize decorator outside this view.
    """
    # First check the untouched default, then the explicit override.
    cases = (({}, default_value), ({option: override_value}, override_value))
    for session_kwargs, expected in cases:
        session = async_db.session(**session_kwargs)
        actual = getattr(session.sync_session, option)
        assert (
            actual == expected
        ), f"Expected session.{option} to be {expected!r}, not {actual!r}"
async def test_session_save__updates_existing_in_same_session(
    async_db: AsyncDatabase,
):
    """save() on a model already committed in the same session updates instead of inserting."""
    async with async_db.session() as session:
        tracked = User(id=1)
        session.add(tracked)
        await session.commit()

        tracked.name = "n"
        await session.save(tracked)
        await session.commit()

    count_stmt = sa.select(sa.func.count(User.id))
    async with async_db.session() as session:
        assert await session.one(count_stmt) == 1
async def test_async_database_create_all():
    """create_all() creates exactly the tables declared by the model collection.

    The locally created database is closed in ``finally`` so its engine is not
    leaked, mirroring how the ``async_db`` fixture tears down.
    """
    model_collection = create_model_collection()
    model_class = model_collection["model_class"]
    model_tables = model_collection["tables"]
    async_db = AsyncDatabase("sqlite+aiosqlite://", model_class=model_class)
    try:
        await async_db.create_all()
        async with async_db.connect() as conn:
            result = await conn.execute(sa.text("SELECT name FROM sqlite_master"))
            table_names = result.scalars().all()
            assert len(table_names) == len(model_tables)
            for table_name in table_names:
                assert table_name in async_db.tables
    finally:
        await async_db.close()
async def test_session_save__updates_existing(async_db: AsyncDatabase):
    """save() on a detached model whose primary key already exists updates that row.

    Previously the test saved ``existing_user`` instead of ``new_user``, so the
    update path was never exercised and the ``new_user`` asserts were trivially true.
    """
    existing_user = (await async_create_users(async_db, count=1))[0]
    assert existing_user.id == 1
    assert existing_user.name is None

    # A new, detached instance that reuses the existing primary key.
    new_user = User(id=1, name="n")
    async with async_db.begin() as session:
        await session.save(new_user)
        assert new_user.id == 1
        assert new_user.name == "n"

    async with async_db.session() as session:
        # Still exactly one row: the save updated rather than inserted.
        select_count = sa.select(sa.func.count(User.id))
        assert await session.one(select_count) == 1

        # The stored row must now carry the new name.
        stored = await session.one_or_none(sa.select(User).where(User.id == 1))
        assert stored is not None
        assert stored.name == "n"
async def test_session_save_all__raises_on_duplicate_primary_keys_in_list(
    async_db: AsyncDatabase,
):
    """save_all() raises TypeError when two models in the list share a primary key."""
    duplicates = [User(id=1) for _ in range(2)]
    with pytest.raises(TypeError):
        async with async_db.begin() as session:
            await session.save_all(duplicates)
async def async_db() -> t.AsyncGenerator[AsyncDatabase, None]:
    """Yield an AsyncDatabase on an in-memory aiosqlite URI with all tables created.

    The database is closed after the consuming test finishes.
    """
    database = AsyncDatabase("sqlite+aiosqlite://", model_class=Model, echo=False)
    await database.create_all()
    yield database
    await database.close()
async def test_session_one_or_none__returns_one_model_or_none_or_raises(
    async_db: AsyncDatabase,
):
    """one_or_none() returns a model for one match, None for none, and raises for several."""
    first_user, *_ = await async_create_users(async_db)

    by_id = sa.select(User).where(User.id == first_user.id)
    no_match = sa.select(User).where(User.id == 0)
    async with async_db.session() as session:
        found = await session.one_or_none(by_id)
        assert isinstance(found, User)
        assert found.id == first_user.id

        assert await session.one_or_none(no_match) is None

    # An unfiltered select matches every user, which must be rejected.
    with pytest.raises(MultipleResultsFound):
        async with async_db.session() as session:
            await session.one_or_none(sa.select(User))
async def async_filedb(
    tmp_path: Path,
) -> t.AsyncGenerator[AsyncDatabase, None]:
    """Yield an AsyncDatabase backed by a SQLite file under ``tmp_path``, with tables created."""
    db_file = tmp_path / "test_async.db"
    database = AsyncDatabase(f"sqlite+aiosqlite:///{db_file}", model_class=Model)
    await database.create_all()
    yield database
    await database.close()
async def test_async_session_all__returns_all_models(async_db: AsyncDatabase):
    """all() on an ORM select returns one User instance per stored user."""
    created = await async_create_users(async_db)
    async with async_db.session() as session:
        models = await session.all(sa.select(User))
        assert len(models) == len(created)
        assert all(isinstance(model, User) for model in models)
def test_async_database_settings__accepts_session_options_dict():
    """Keys in ``session_options`` take precedence over the same keys passed as kwargs."""
    session_options = {"autoflush": False}
    other_options = {"expire_on_commit": True, "autoflush": True}

    # Later update wins, so session_options overrides other_options.
    expected_options = dict(other_options)
    expected_options.update(session_options)

    async_db = AsyncDatabase(
        "sqlite+aiosqlite://", session_options=session_options, **other_options
    )
    assert async_db.settings.get_session_options() == expected_options
def test_async_database_settings__accepts_engine_options_dict():
    """Keys in ``engine_options`` take precedence over the same keys passed as kwargs."""
    engine_options = {"echo": True}
    other_options = {"encoding": "utf8", "echo": False}

    # Later update wins, so engine_options overrides other_options.
    expected_options = dict(other_options)
    expected_options.update(engine_options)

    async_db = AsyncDatabase(
        "sqlite+aiosqlite://", engine_options=engine_options, **other_options
    )
    assert async_db.settings.get_engine_options() == expected_options
async def test_session_first__returns_first_result_or_none(
    async_db: AsyncDatabase,
):
    """first() returns a model when rows match and None when nothing matches."""
    await async_create_users(async_db)
    async with async_db.session() as session:
        head = await session.first(sa.select(User))
        assert isinstance(head, User)

        nothing = await session.first(sa.select(User).where(User.id == 0))
        assert nothing is None
async def async_create_users(
    async_db: AsyncDatabase,
    count: int = 3,
    overrides: t.Optional[dict] = None,
) -> t.List[User]:
    """Insert ``count`` users with ids ``1..count`` in one transaction and return them.

    Every user is built with ``overrides`` as extra keyword arguments. The
    transaction uses ``expire_on_commit=False`` so callers can still read the
    returned instances' attributes after the commit.
    """
    extra_fields = overrides or {}
    created = [User(id=user_id, **extra_fields) for user_id in range(1, count + 1)]
    async with async_db.begin(expire_on_commit=False) as session:
        session.add_all(created)
    return created
def test_async_database_models():
    """AsyncDatabase.models exposes exactly the model classes from its model_class registry."""
    collection = create_model_collection()
    expected_models = collection["models"]
    async_db = AsyncDatabase(
        "sqlite+aiosqlite://", model_class=collection["model_class"]
    )

    assert len(async_db.models) == len(expected_models)
    for model in async_db.models.values():
        # Must be the very same class object, not merely an equal one.
        assert model is expected_models[expected_models.index(model)]
async def test_session_all__uniquifies_joinedload_results(
    async_db: AsyncDatabase,
):
    """all() returns each parent once even though joinedload yields one row per child."""
    # Users 1..3, each with addresses <uid>1, <uid>2, <uid>3 (e.g. 11, 12, 13).
    users = [
        User(id=uid, addresses=[Address(id=uid * 10 + n) for n in range(1, 4)])
        for uid in range(1, 4)
    ]
    async with async_db.begin() as session:
        session.add_all(users)

    eager_stmt = sa.select(User).options(joinedload(User.addresses))
    async with async_db.session() as session:
        loaded = await session.all(eager_stmt)
        assert len(loaded) == len(users)
async def test_session_one_or_none__returns_one_row_or_none_or_raises(
    async_db: AsyncDatabase,
):
    """one_or_none() on a text query returns a raw row, None, or raises on multiple rows."""
    target, *_ = await async_create_users(async_db)
    by_id = sa.text("SELECT * FROM users WHERE id = :id")

    async with async_db.session() as session:
        row = await session.one_or_none(by_id, params={"id": target.id})
        # Text queries return plain rows, not ORM instances.
        assert not isinstance(row, User)
        assert row.id == target.id
        assert row.name == target.name

        assert await session.one_or_none(by_id, params={"id": 0}) is None

    with pytest.raises(MultipleResultsFound):
        async with async_db.session() as session:
            await session.one_or_none(sa.text("SELECT * FROM users"))
def test_async_database_tables():
    """AsyncDatabase.tables maps each table name to the matching table from the models."""
    collection = create_model_collection()
    expected_tables = collection["tables"]
    async_db = AsyncDatabase(
        "sqlite+aiosqlite://", model_class=collection["model_class"]
    )

    assert len(async_db.tables) == len(expected_tables)
    for name, table in async_db.tables.items():
        assert table in expected_tables
        matching = expected_tables[expected_tables.index(table)]
        assert matching.name == name
async def test_async_database_drop_all():
    """drop_all() removes every table that create_all() created.

    The connection is opened before create_all(), as before, but is now managed
    by ``async with`` so it is always closed. The database is closed in
    ``finally`` so neither the connection nor the engine leaks.
    """
    model_collection = create_model_collection()
    model_class = model_collection["model_class"]
    model_tables = model_collection["tables"]
    async_db = AsyncDatabase("sqlite+aiosqlite://", model_class=model_class)
    count_tables = sa.text("SELECT COUNT(name) FROM sqlite_master")
    try:
        async with async_db.connect() as conn:
            await async_db.create_all()
            assert (await conn.execute(count_tables)).scalar_one() == len(model_tables)

            await async_db.drop_all()
            assert (await conn.execute(count_tables)).scalar_one() == 0
    finally:
        await async_db.close()
def test_async_database_settings(key: str, value: t.Any, kind: str):
    """A single setting is exposed via attribute, item and dict access, and routed by kind.

    ``key``/``value``/``kind`` are presumably supplied by a parametrize decorator
    outside this view; ``kind`` selects session vs engine options.
    """
    settings = {key: value}
    # NOTE: Using poolclass so that pool_* options can be used with sqlite.
    settings.setdefault("poolclass", QueuePool)
    async_db = AsyncDatabase("sqlite+aiosqlite://", **settings)
    db_settings = async_db.settings

    assert getattr(db_settings, key) == value
    assert db_settings[key] == value
    assert dict(db_settings)[key] == value

    options_by_kind = {
        "session": db_settings.get_session_options,
        "engine": db_settings.get_engine_options,
    }
    if kind not in options_by_kind:
        raise RuntimeError(f"kind must be one of 'session' or 'engine', not {kind!r}")
    assert options_by_kind[kind]()[key] == value
async def test_session_first__returns_first_result_or_none_from_non_orm_query(
    async_db: AsyncDatabase,
):
    """first() on a text query returns the first raw row, or None when nothing matches."""
    created = await async_create_users(async_db)
    expected = created[0]

    active_stmt = sa.text("SELECT * FROM users WHERE active = :active ORDER BY id")
    missing_stmt = sa.text("SELECT * FROM users WHERE id = :id ORDER BY id")
    async with async_db.session() as session:
        row = await session.first(active_stmt, params={"active": True})
        # Text queries return plain rows, not ORM instances.
        assert not isinstance(row, User)
        assert row.id == expected.id
        assert row.name == expected.name

        missing = await session.first(missing_stmt, params={"id": 0})
        assert missing is None
async def test_async_database_reflect(tmp_path: Path):
    """reflect() loads every table from an existing SQLite file into an empty AsyncDatabase.

    Both databases are closed in ``finally``. Previously the schema-creating
    instance was discarded unclosed and the reflecting instance was never
    closed, leaking engines and file handles.
    """
    db_file = tmp_path / "reflect_async.db"
    uri = f"sqlite+aiosqlite:///{db_file}"
    model_collection = create_model_collection()
    model_class = model_collection["model_class"]
    model_tables = model_collection["tables"]

    # Build the schema on disk with a model-aware database, then release it.
    setup_db = AsyncDatabase(uri, model_class=model_class)
    try:
        await setup_db.create_all()
    finally:
        await setup_db.close()

    # A database without models starts empty and learns tables only via reflect().
    async_db = AsyncDatabase(uri)
    try:
        assert len(async_db.tables) == 0
        await async_db.reflect()
        assert len(async_db.tables) == len(model_tables)

        model_table_names = {table.name for table in model_tables}
        for table_name in async_db.tables:
            assert table_name in model_table_names
    finally:
        await async_db.close()
async def test_session_all__returns_all_rows_from_non_orm_query(
    async_db: AsyncDatabase,
):
    """all() on a text query returns one raw row per matching user, with matching columns."""
    created = await async_create_users(async_db, overrides={"active": False})
    inactive_users_by_id = {user.id: user for user in created}

    async with async_db.session() as session:
        rows = await session.all(
            sa.text("SELECT * FROM users WHERE active = :active"),
            params={"active": False},
        )
        # Without this, an empty result would pass every check in the loop below.
        assert len(rows) == len(inactive_users_by_id)

        for row in rows:
            # Text queries return plain rows, not ORM instances.
            assert not isinstance(row, User)
            inactive_user = inactive_users_by_id.get(row.id)
            assert inactive_user
            for key, value in dict(inactive_user).items():
                assert row[key] == value
async def test_session_save_all__raises_on_invalid_type(
    async_db: AsyncDatabase,
    value: t.Any,
    exception: t.Type[Exception],
):
    """save_all() raises ``exception`` when handed the unsupported ``value``.

    ``value``/``exception`` are presumably supplied by a parametrize decorator
    outside this view.
    """
    with pytest.raises(exception):
        async with async_db.session() as sess:
            await sess.save_all(value)