コード例 #1
0
    def test_attributes(self, async_engine):
        """Check that public ``_AsyncSession`` names are exposed by the scoped proxy.

        Every name defined on any class in ``_AsyncSession``'s MRO that does
        not start with an underscore must be available on an
        ``async_scoped_session`` instance, apart from the explicitly exempted
        names below.
        """
        from asyncio import current_task

        # Gather public names class by class, in MRO order; a name defined on
        # several classes appears once per defining class.
        public_names = []
        for klass in _AsyncSession.mro():
            public_names.extend(
                attr for attr in vars(klass) if not attr.startswith("_")
            )

        # Names the scoped session is not expected to proxy.
        exempt = frozenset(
            (
                "dispatch",
                "sync_session_class",
                "run_sync",
                "get_transaction",
                "get_nested_transaction",
                "in_transaction",
                "in_nested_transaction",
            )
        )

        factory = sessionmaker(async_engine, class_=_AsyncSession)
        scoped = async_scoped_session(factory, current_task)

        missing = []
        for attr in public_names:
            if not hasattr(scoped, attr) and attr not in exempt:
                missing.append(attr)
        eq_(missing, [])
コード例 #2
0
    async def test_basic(self, async_engine):
        """Exercise a ``current_task``-scoped ``async_scoped_session`` end to end.

        Checks that two calls within the same task yield the same session
        bound to ``async_engine``. It then adds, flushes and deletes a
        ``User`` through the proxy, confirming the row count via the
        session's connection after each flush.
        """
        from asyncio import current_task

        AsyncSession = async_scoped_session(
            sa.orm.sessionmaker(async_engine, class_=_AsyncSession),
            scopefunc=current_task,
        )

        try:
            some_async_session = AsyncSession()
            some_other_async_session = AsyncSession()

            # Same task -> same registry key -> same session object.
            is_(some_async_session, some_other_async_session)
            is_(some_async_session.bind, async_engine)

            User = self.classes.User

            async with AsyncSession.begin():
                user_name = "scoped_async_session_u1"
                u1 = User(name=user_name)

                AsyncSession.add(u1)

                await AsyncSession.flush()

                conn = await AsyncSession.connection()
                stmt = select(func.count(User.id)).where(
                    User.name == user_name
                )
                eq_(await conn.scalar(stmt), 1)

                await AsyncSession.delete(u1)
                await AsyncSession.flush()
                eq_(await conn.scalar(stmt), 0)
        finally:
            # With current_task as the scope function, the registry keeps the
            # task and its session alive until remove() is called; release
            # them even if an assertion above failed.
            await AsyncSession.remove()
コード例 #3
0
ファイル: sessionmakers.py プロジェクト: kkirsche/sqlalchemy
def async_scoped_session_factory(
    engine: AsyncEngine,
) -> async_scoped_session[MyAsyncSession]:
    """Return an ``async_scoped_session`` that produces ``MyAsyncSession`` objects.

    The scope function always returns ``None``, so every caller maps to the
    same registry key and therefore shares one session.
    """

    def _shared_scope() -> None:
        return None

    factory = async_sessionmaker(engine, class_=MyAsyncSession)
    return async_scoped_session(factory, scopefunc=_shared_scope)
コード例 #4
0
ファイル: database.py プロジェクト: lsst-sqre/safir
async def create_async_session(
    engine: AsyncEngine,
    logger: Optional[BoundLogger] = None,
    *,
    statement: Optional[Select] = None,
) -> async_scoped_session:
    """Create a new async database session.

    Optionally checks that the database is available and retries in a loop for
    10s if it is not.  This should be used during application startup to wait
    for any network setup or database proxy sidecar.

    Parameters
    ----------
    engine : `sqlalchemy.ext.asyncio.AsyncEngine`
        Database engine to use for the session.
    logger : ``structlog.stdlib.BoundLogger``, optional
        Logger for reporting errors.  Used only if a statement is provided.
    statement : `sqlalchemy.sql.expression.Select`, optional
        If provided, statement to run to check database connectivity.  This
        will be modified with ``limit(1)`` before execution.  If not provided,
        database connectivity will not be checked.

    Returns
    -------
    session : `sqlalchemy.ext.asyncio.async_scoped_session`
        The database session proxy.  This is an asyncio scoped session that is
        scoped to the current task, which means that it will materialize new
        AsyncSession objects for each asyncio task (and thus each web
        request).  ``await session.remove()`` should be called when the caller
        is done with the session.

    Raises
    ------
    Exception
        Whatever the final connectivity check raises (for example
        ``OperationalError`` or ``OSError``) after five failed retries.  The
        current task's session is removed from the registry before the
        exception propagates.
    """
    factory = sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
    session = async_scoped_session(factory, scopefunc=asyncio.current_task)

    # If no statement was provided, just return the async_scoped_session.
    if statement is None:
        return session

    # The check never changes between attempts, so build it once.
    check = statement.limit(1)

    # A statement was provided, so we want to check connectivity and retry for
    # up to ten seconds before returning the session.  ConnectionRefusedError
    # is a subclass of OSError, so catching OSError already covers it.
    for _ in range(5):
        try:
            async with session.begin():
                await session.execute(check)
        except (OperationalError, OSError):
            if logger:
                logger.info("database not ready, waiting two seconds")
            # Drop this task's failed session so the next attempt starts fresh.
            await session.remove()
            await asyncio.sleep(2)
        else:
            return session

    # If we got here, we failed five times.  Try one last time without
    # swallowing exceptions so that the appropriate exception reaches our
    # caller, but still drop this task's failed session (as the retry loop
    # does) so it is not left behind in the scoped registry.
    try:
        async with session.begin():
            await session.execute(check)
    except Exception:
        await session.remove()
        raise
    return session
コード例 #5
0
 def __init__(self, Model: T):
     """Set up synchronous and asynchronous session machinery for ``Model``.

     Attributes assigned on the instance, in this order:

     * ``Model`` -- the model passed in.
     * ``engine`` -- a sync engine for ``setting.POSTGRES_DATABASE_URI``.
     * ``session`` -- a ``scoped_session`` over a non-autoflushing
       ``sessionmaker`` bound to ``engine``.
     * ``AsyncSession`` -- a ``sessionmaker`` producing ``AsyncSession``
       objects (``expire_on_commit=False``) bound to an async engine for
       ``setting.POSTGRES_DATABASE_URI_ASYNC``.
     * ``async_session`` -- an ``async_scoped_session`` over that factory,
       keyed by ``current_task``.

     Both engines are created with ``echo=True``.
     """
     self.Model = Model

     # Synchronous engine and thread-scoped session registry.
     sync_engine = create_engine(setting.POSTGRES_DATABASE_URI, echo=True)
     self.engine = sync_engine
     sync_factory = sessionmaker(
         bind=sync_engine,
         autocommit=False,
         autoflush=False,
     )
     self.session = scoped_session(sync_factory)  # type: ignore

     # Asynchronous engine and task-scoped session registry.
     async_engine = create_async_engine(
         setting.POSTGRES_DATABASE_URI_ASYNC,
         future=True,
         echo=True,
         json_serializer=json_serializer,
     )
     async_factory = sessionmaker(
         async_engine,
         expire_on_commit=False,
         class_=AsyncSession,
     )
     self.AsyncSession = async_factory  # type: ignore
     self.async_session = async_scoped_session(  # type: ignore
         async_factory, scopefunc=current_task)
コード例 #6
0
def get_db_session(engine, scope_func):
    """Build an ``async_scoped_session`` bound to *engine*.

    The sessions it hands out are ``AsyncSession`` instances created with
    ``expire_on_commit=False``.  The registry is keyed by whatever
    ``scope_func()`` returns.
    """
    return async_scoped_session(
        sessionmaker(engine, class_=AsyncSession, expire_on_commit=False),
        scopefunc=scope_func,
    )