async def test_session_save_all__inserts_and_updates_multiple_types(
    async_db: AsyncDatabase,
):
    """Save a mixed batch of previously saved and brand-new users and groups.

    Half of the models are saved up front; the full batch (both model types,
    old and new instances together) is then passed to a single ``save_all()``
    call and the resulting row counts are checked per table.
    """
    saved_users = [User(name=name) for name in ("uA", "uB", "uC")]
    fresh_users = [User(name=name) for name in ("u1", "u2", "u3")]
    saved_groups = [Group(name=name) for name in ("gA", "gB", "gC")]
    fresh_groups = [Group(name=name) for name in ("g1", "g2", "g3")]

    # Seed the database so part of the later batch has already been saved.
    async with async_db.begin() as session:
        await session.save_all(saved_users)
        await session.save_all(saved_groups)

    every_user = saved_users + fresh_users
    every_group = saved_groups + fresh_groups
    mixed_batch = saved_users + saved_groups + fresh_users + fresh_groups  # type: ignore

    async with async_db.begin() as session:
        await session.save_all(mixed_batch)

    async with async_db.session() as session:
        user_count_stmt = sa.select(sa.func.count(User.id))
        assert await session.one(user_count_stmt) == len(every_user)

        group_count_stmt = sa.select(sa.func.count(Group.id))
        assert await session.one(group_count_stmt) == len(every_group)
async def test_async_database_begin__starts_new_transaction(
    async_db: AsyncDatabase,
):
    """Check that ``begin()`` yields an in-transaction AsyncSession, distinct per call."""
    async with async_db.begin() as first_session:
        assert isinstance(first_session, AsyncSession)
        assert first_session.in_transaction()

    # A second begin() must hand out a different session object.
    async with async_db.begin() as second_session:
        assert second_session is not first_session
async def test_session_save_all__raises_on_duplicate_primary_keys_in_list(
    async_db: AsyncDatabase,
):
    """Passing two models with the same primary key to ``save_all()`` raises TypeError."""
    clashing_users = [User(id=1) for _ in range(2)]

    with pytest.raises(TypeError):
        async with async_db.begin() as session:
            await session.save_all(clashing_users)
async def test_session_save_all__returns_list(async_db: AsyncDatabase):
    """``save_all()`` given a list returns a list equal to its input."""
    batch = [User() for _ in range(3)]

    async with async_db.begin() as session:
        saved = await session.save_all(batch)
        assert isinstance(saved, list)
        assert saved == batch

    # Every model in the batch should now be a row in the users table.
    async with async_db.session() as session:
        count_stmt = sa.select(sa.func.count(User.id))
        assert await session.one(count_stmt) == len(batch)
async def async_create_users(
    async_db: AsyncDatabase,
    count: int = 3,
    overrides: t.Optional[dict] = None,
) -> t.List[User]:
    """Add ``count`` users with ids ``1..count`` to the database and return them.

    Args:
        async_db: Database used to open the transaction the users are added in.
        count: Number of users to create.
        overrides: Extra keyword arguments passed to every ``User`` constructor.

    Returns:
        The created ``User`` instances, in id order.
    """
    extra_fields = overrides if overrides is not None else {}
    created = [User(id=pk, **extra_fields) for pk in range(1, count + 1)]

    # NOTE(review): expire_on_commit=False presumably keeps the returned
    # instances' attributes readable after the transaction ends; confirm.
    async with async_db.begin(expire_on_commit=False) as session:
        session.add_all(created)

    return created
async def test_session_save_all__accepts_single_model_and_returns_list(
    async_db: AsyncDatabase,
):
    """``save_all()`` given a bare model returns a one-element list holding it."""
    lone_user = User()

    async with async_db.begin() as session:
        saved = await session.save_all(lone_user)
        assert isinstance(saved, list)
        assert saved == [lone_user]

    async with async_db.session() as session:
        count_stmt = sa.select(sa.func.count(User.id))
        assert await session.one(count_stmt) == 1
async def test_session_save__inserts_new_without_pk(async_db: AsyncDatabase):
    """Saving a model created without an id gives it id 1 and adds one row."""
    unsaved_user = User(name="n")

    async with async_db.begin() as session:
        await session.save(unsaved_user)
        # The id was not supplied by the test; it must be populated by save().
        assert unsaved_user.id == 1
        assert unsaved_user.name == "n"

    async with async_db.session() as session:
        count_stmt = sa.select(sa.func.count(User.id))
        assert await session.one(count_stmt) == 1
async def test_session_save__updates_existing(async_db: AsyncDatabase):
    """Saving a new instance whose id matches an existing row must not add a row.

    A user with id 1 and no name is created first. A separate ``User`` instance
    with the same id and a name is then passed to ``save()``; afterwards the
    table must still contain exactly one user.
    """
    existing_user = (await async_create_users(async_db, count=1))[0]
    assert existing_user.id == 1
    assert existing_user.name is None

    new_user = User(id=1, name="n")
    async with async_db.begin() as session:
        # Save the *new* instance: saving ``existing_user`` again would leave
        # ``new_user`` untouched and the update path unexercised.
        await session.save(new_user)
        assert new_user.id == 1
        assert new_user.name == "n"

    async with async_db.session() as session:
        select_count = sa.select(sa.func.count(User.id))
        assert await session.one(select_count) == 1
async def test_session_all__uniquifies_joinedload_results(
    async_db: AsyncDatabase,
):
    """``all()`` with a joinedload of addresses returns one result per user.

    Three users with three addresses each are stored; the number of results
    must equal the number of users rather than the number of joined rows.
    """
    # User n owns addresses n1, n2 and n3 (e.g. user 2 -> 21, 22, 23).
    seeded_users = [
        User(
            id=user_id,
            addresses=[Address(id=user_id * 10 + offset) for offset in (1, 2, 3)],
        )
        for user_id in (1, 2, 3)
    ]

    async with async_db.begin() as session:
        session.add_all(seeded_users)

    joined_stmt = sa.select(User).options(joinedload(User.addresses))

    async with async_db.session() as session:
        fetched = await session.all(joined_stmt)

    assert len(fetched) == len(seeded_users)