Beispiel #1
0
    def sample(self, min_score, max_score, size):
        """Tag a random batch of unweighted rows with a sampling weight.

        Counts the rows whose ``poids`` is NULL and whose ``score_confiance``
        lies in ``[min_score, max_score]``, then picks at most ``size`` of
        them in random order and sets their ``poids`` to ``count / size``.

        :returns: the updated rows, as a list, restricted to the columns named
            by ``self.inline_tables()``.
        :raises InsufficientPopulationSize: when fewer than ``size`` rows were
            updated. The UPDATE has already been executed at that point.
        """
        tbl = self.table
        score = tbl.c.score_confiance

        eligible_count = self.connection.execute(
            select(func.count(tbl.c.id)).select_from(tbl).where(
                tbl.c.poids.is_(None),
                max_score >= score,
                score >= min_score,
            )).scalar()
        weight = eligible_count / size

        # Random subset of eligible ids, capped at ``size`` rows.
        picked_ids = (
            select(tbl.c.id)
            .select_from(tbl)
            .where(
                max_score >= score,
                score >= min_score,
                tbl.c.poids.is_(None),
            )
            .order_by(func.random())
            .limit(size)
        )

        update_stmt = (
            update(tbl)
            .where(tbl.c.id.in_(picked_ids))
            .values({tbl.c.poids: weight})
            .returning(*[column(name) for name in self.inline_tables()])
        )
        rows = list(self.connection.execute(update_stmt))

        got = len(rows)
        if got >= size:
            return rows
        raise InsufficientPopulationSize(
            f"You requested {size} elements, but only {got} are available",
            requested_size=size,
            actual_size=got,
        )
Beispiel #2
0
async def autoscroll(client, message):
    """Toggle auto-scroll for a chat and persist the setting.

    The target chat is the first command argument when one is given,
    otherwise the chat the command was sent in. If the chat id is already in
    the in-memory set ``f`` it is removed and its ``AutoScroll`` row (if any)
    is deleted. Otherwise it is added and an ``AutoScroll`` row is created
    when none exists. The confirmation message is deleted after 3 seconds.
    """
    command = message.command
    command.pop(0)  # drop the command name, keep only its arguments
    chat = message.chat.id
    if command:
        chat = command[0]

    try:
        chat, entity_client = await get_entity(client, chat)
    except Exception:
        # Narrower than a bare except: lets cancellation/KeyboardInterrupt through.
        await self_destruct("<code>Invalid chat or group</code>")
        return

    was_enabled = chat.id in f
    if was_enabled:
        f.remove(chat.id)
    else:
        f.add(chat.id)

    # Look up the persisted setting once. scalar_one_or_none() returns the
    # AutoScroll instance or None. Calling .one_or_none() on the instance
    # returned by Result.scalar(), as before, fails.
    result = await session.execute(
        select(AutoScroll).where(AutoScroll.id == chat.id))
    existing = result.scalar_one_or_none()

    if was_enabled:
        if existing is not None:
            await session.delete(existing)
        await message.edit(f"<code>Autoscroll disabled in {chat.title}</code>")
    else:
        if existing is None:
            # AsyncSession.add() is synchronous and must not be awaited.
            session.add(AutoScroll(chat.id))
        await message.edit(f"<code>Autoscroll enabled in {chat.title}</code>")
    await session.commit()
    await asyncio.sleep(3)
    await message.delete()
    def test_correlate_entity(self):
        """Core select() and legacy Query.statement correlate to User alike."""
        User, Address = self.classes.User, self.classes.Address

        expected = ("SELECT users.name, addresses.id, "
                    "(SELECT count(addresses.id) AS count_1 "
                    "FROM addresses WHERE users.id = addresses.user_id) AS anon_1 "
                    "FROM users, addresses")

        def count_subquery():
            # Build a fresh correlated scalar subquery on every call so the
            # two statements under test never share a construct.
            return (select(func.count(Address.id))
                    .where(User.id == Address.user_id)
                    .correlate(User)
                    .scalar_subquery())

        core_stmt = select(User.name, Address.id, count_subquery())
        orm_stmt = Session().query(
            User.name, Address.id, count_subquery()).statement

        for stmt in (core_stmt, orm_stmt):
            self.assert_compile(stmt, expected)
Beispiel #4
0
async def fetch_all_rubrics(session: AsyncSession,
                            user_id: int,
                            *,
                            with_links: bool = False) -> list[Rubric]:
    """
    Return every rubric owned by a user, ordered by name.

    :param session: open async database session
    :type session: AsyncSession
    :param user_id: owner whose rubrics are returned
    :type user_id: int
    :keyword with_links: when true, also load ``Rubric.links`` eagerly via
        ``selectinload``
    :type with_links: bool

    :return: the user's rubrics (possibly an empty list)
    :rtype: list[Rubric]
    """
    query = select(Rubric)
    if with_links:
        query = query.options(selectinload(Rubric.links))
    query = query.where(Rubric.user_id == user_id).order_by(Rubric.name)

    return list((await session.execute(query)).scalars())
Beispiel #5
0
    async def update(self, db: AsyncSession, *, db_obj: models.User,
                     obj_in: Union[UserUpdate, Dict[str, Any]]) -> models.User:
        """Apply a partial update to a user, resolving relations by key.

        A truthy ``password`` is hashed into ``hashed_password``. ``roles``
        (role names) and ``study_areas`` (study-area ids) are looked up and
        assigned as collections when present and not ``None``; an empty list
        clears the collection. All remaining keys are passed to the base
        class ``update``.

        The caller's dict (when ``obj_in`` is a dict) is not modified.
        """
        if isinstance(obj_in, dict):
            # Copy so the pops below don't mutate the caller's dict.
            update_data = dict(obj_in)
        else:
            update_data = obj_in.dict(exclude_unset=True)

        if update_data.get("password"):
            update_data["hashed_password"] = get_password_hash(
                update_data.pop("password"))

        # Relationship keys are always removed before delegating to the base
        # update. They are applied here only when a value was supplied. Values
        # are read from update_data, not obj_in, so dict input works too.
        role_names = update_data.pop("roles", None)
        if role_names is not None:
            roles = await db.execute(
                select(models.Role).filter(models.Role.name.in_(role_names)))
            db_obj.roles = roles.scalars().all()

        study_area_ids = update_data.pop("study_areas", None)
        if study_area_ids is not None:
            study_areas = await db.execute(
                select(models.StudyArea).filter(
                    models.StudyArea.id.in_(study_area_ids)))
            db_obj.study_areas = study_areas.scalars().all()

        return await super().update(db, db_obj=db_obj, obj_in=update_data)
Beispiel #6
0
async def fetch_one_link(session: AsyncSession,
                         link_id: int,
                         *,
                         with_rubric: bool = True) -> Link:
    """
    Fetch a single link by id.

    :param session: open async database session
    :type session: AsyncSession
    :param link_id: id of the link to fetch
    :type link_id: int
    :param with_rubric: when true (the default), load the link's rubric in
        the same query via ``joinedload``
    :type with_rubric: bool

    :return: the matching link, or ``None`` when no row matches
    :rtype: Link
    """
    query = select(Link)
    if with_rubric:
        query = query.options(joinedload(Link.rubric))
    query = query.where(Link.id == link_id)

    return (await session.execute(query)).scalar()
    def test_correlate_aliased_entity(self):
        """Correlating to an aliased entity renders the alias in both APIs."""
        User = self.classes.User
        Address = self.classes.Address
        uu = aliased(User, name="uu")

        def count_subquery():
            # Built fresh for each statement so neither shares a construct.
            return (select(func.count(Address.id))
                    .where(uu.id == Address.user_id)
                    .correlate(uu)
                    .scalar_subquery())

        statements = (
            select(uu.name, Address.id, count_subquery()),
            Session().query(uu.name, Address.id, count_subquery()).statement,
        )

        expected = ("SELECT uu.name, addresses.id, "
                    "(SELECT count(addresses.id) AS count_1 "
                    "FROM addresses WHERE uu.id = addresses.user_id) AS anon_1 "
                    "FROM users AS uu, addresses")

        for stmt in statements:
            self.assert_compile(stmt, expected)
Beispiel #8
0
async def test_create_sync_session() -> None:
    """create_sync_session works, and fails when its liveness query fails."""
    log = structlog.get_logger(__name__)

    # Start from a freshly reset schema, then release the async engine.
    async_engine = create_database_engine(TEST_DATABASE_URL,
                                          TEST_DATABASE_PASSWORD)
    await initialize_database(async_engine, log, schema=Base.metadata,
                              reset=True)
    await async_engine.dispose()

    sync_session = create_sync_session(
        TEST_DATABASE_URL,
        TEST_DATABASE_PASSWORD,
        log,
        statement=select(User),
    )
    with sync_session.begin():
        sync_session.add(User(username="******"))
    sync_session.remove()

    # Session creation must fail when the liveness statement queries a
    # table that does not exist.
    missing_table = Table("bad", MetaData(), Column("name", String(64)))
    with pytest.raises(ProgrammingError):
        create_sync_session(
            TEST_DATABASE_URL,
            TEST_DATABASE_PASSWORD,
            log,
            statement=select(missing_table),
        )
 def topic_data_find_with_aggregate(self, where, topic_name, aggregate):
     """Run one aggregate (sum / count / avg) over a topic table.

     :param where: filter passed to ``build_oracle_where_expression``
     :param topic_name: topic name; the table is ``topic_<topic_name>``
     :param aggregate: mapping of column name -> ``"sum"``, ``"count"`` or
         ``"avg"``. If several entries are given, only the last recognised
         one is used.
     :return: the aggregate value, or ``None`` when no row is fetched
     :raises ValueError: when ``aggregate`` has no recognised entry
     """
     table_name = 'topic_' + topic_name
     table = self.get_topic_table_by_name(table_name)
     stmt = None
     return_column_name = None
     # NOTE(review): ``key`` is interpolated into raw SQL through text(); it
     # must be a trusted column name, never user-supplied input.
     for key, value in aggregate.items():
         if value == "sum":
             stmt = select(text(f'sum({key.lower()}) as sum_{key.lower()}'))
             return_column_name = f'SUM_{key.upper()}'
         elif value == "count":
             # Raw SQL column expressions must be wrapped in text(); a plain
             # string is rejected by select() in SQLAlchemy 1.4+.
             stmt = select(text('count(*) as count'))
             return_column_name = 'COUNT'
         elif value == "avg":
             stmt = select(text(f'avg({key.lower()}) as avg_{key.lower()}'))
             return_column_name = f'AVG_{key.upper()}'
     if stmt is None:
         # Previously this surfaced as an UnboundLocalError on ``stmt``.
         raise ValueError(
             f"no supported aggregate (sum/count/avg) in {aggregate!r}")
     stmt = stmt.select_from(table)
     stmt = stmt.where(self.build_oracle_where_expression(table, where))
     with self.engine.connect() as conn:
         cursor = conn.execute(stmt).cursor
         columns = [col[0] for col in cursor.description]
         # Presumably a cx_Oracle cursor: rowfactory turns each row into a
         # dict keyed by the reported column names (upper-cased by Oracle for
         # unquoted labels, hence the upper-case return_column_name).
         cursor.rowfactory = lambda *args: dict(zip(columns, args))
         res = cursor.fetchone()
     if res is None:
         return None
     return res[return_column_name]
    def test_update(self):
        """ORM UPDATE built with lambda_stmt, executed with varying IN lists."""
        User, Address = self.classes("User", "Address")

        sess = Session(testing.db, future=True)

        def apply_update(ids, values):
            stmt = lambda_stmt(lambda: update(User).where(User.id.in_(ids)))
            # The lambda is currently just unrolled on the statement, so
            # lambda caching for updates helps little unless
            # synchronize_session is turned off; "evaluate" would behave
            # similarly but does not support IN yet.
            sess.execute(
                stmt,
                values,
                execution_options={"synchronize_session": "fetch"},
            )

        def current_rows():
            return sess.execute(
                select(User.id, User.name).order_by(User.id)).all()

        apply_update([1, 2], {"name": "jack2"})
        eq_(current_rows(),
            [(1, "jack2"), (2, "jack2"), (3, "jill"), (4, "jane")])

        apply_update([3], {"name": "jane2"})
        eq_(current_rows(),
            [(1, "jack2"), (2, "jack2"), (3, "jane2"), (4, "jane")])
async def delete_catalog(request: sanic.Request):
    """Delete a rule catalog together with every rule filed under it.

    Expects a JSON object body with an integer ``catalog_id``. Responds with
    an "invalid" response for a malformed body or id, a "failed" response
    when the catalog does not exist, and a "success" response after the
    deletion is committed.
    """
    body = request.json
    if not isinstance(body, dict):
        return json(Response.invalid('参数无效'))

    catalog_id = body.get('catalog_id', None)
    if not isinstance(catalog_id, int):
        return json(Response.invalid('参数无效'))

    db = request.ctx.db_session
    catalog = (await db.execute(
        select(HttpRuleCatalog).where(
            HttpRuleCatalog.catalog_id == catalog_id))).scalar()
    if catalog is None:
        return json(Response.failed('分类不存在'))

    # Delete the catalog and all the rules that belong to it.
    rules = (await db.execute(
        select(HttpRule).where(HttpRule.catalog_id == catalog_id))).scalars()
    for rule in rules:
        await db.delete(rule)
    await db.delete(catalog)
    await db.commit()
    return json(Response.success('删除成功'))
Beispiel #12
0
async def get_events(
    response: Response,
    authorization: Optional[str] = Header(None),
    status_code=status.HTTP_200_OK,
    user_info=None,
    limit: Optional[int] = 20,
    offset: Optional[int] = 0,
) -> List[dict]:
    """Return the current user's saved filters with their subjects.

    ``user_info`` must expose ``.id``; it is dereferenced unconditionally.
    ``limit`` and ``offset`` are accepted but not used by this body.
    """
    async with get_session() as s:
        user_filters = (await s.execute(
            select(UserFilter).filter(UserFilter.user_id == user_info.id)
        )).scalars().all()

        payload = []
        for user_filter in user_filters:
            # One query per filter; scalars() keeps only the Subject column.
            filter_subjects = (await s.execute(
                select(Subject, FilterSubject)
                .join(FilterSubject, Subject.id == FilterSubject.subject)
                .filter(FilterSubject.filter == user_filter.id)
            )).scalars().all()
            payload.append({
                "id": user_filter.id,
                "start_time": user_filter.start_time,
                "end_time": user_filter.end_time,
                "city": user_filter.city,
                "subjects": [{"id": subj.id, "name": subj.name}
                             for subj in filter_subjects],
            })
        return [UserFilterResponse.parse_obj(item) for item in payload]
Beispiel #13
0
async def system_log_list(request: sanic.Request):
    """Return one page of system logs, newest first.

    JSON body fields:
      * ``page`` -- 0-based page index (default 0), must be >= 0
      * ``page_size`` -- rows per page (default 35), must be > 0
      * ``filter`` -- must be a dict; it is not applied to the query

    Responds with an "invalid" response when the body or any paging field is
    malformed or out of range.
    """
    body = request.json
    if not isinstance(body, dict):
        return json(Response.invalid('无效参数'))

    page = body.get('page', 0)
    page_size = body.get('page_size', 35)
    filter_ = body.get('filter', {})  # renamed: don't shadow builtin filter()
    if not (isinstance(page, int) and isinstance(page_size, int)
            and isinstance(filter_, dict)):
        return json(Response.invalid('无效参数'))
    # page_size is a divisor below and page becomes the OFFSET, so reject
    # values that would raise ZeroDivisionError or yield a negative offset.
    if page < 0 or page_size <= 0:
        return json(Response.invalid('无效参数'))

    db = request.ctx.db_session
    log_rows = (await db.execute(
        select(SystemLog).order_by(SystemLog.log_id.desc()).offset(
            page * page_size).limit(page_size))).scalars()

    count = (await db.execute(
        select(func.count('1')).select_from(SystemLog))).scalar()

    system_logs = []
    for log in log_rows:
        # NOTE(review): this overwrites a session-managed attribute with a
        # string; if the session is flushed later the string could be
        # written back. Consider serialising a copy instead.
        log.log_time = log.log_time.strftime('%Y-%m-%d %H:%M:%S')
        system_logs.append(log)

    paged = PagedResponse(payload=system_logs,
                          total_page=ceil(count / page_size),
                          curr_page=page)
    return json(Response.success('', paged))
Beispiel #14
0
async def test_reconcile_chains(db, chains):
    """reconcile_chains adds configured chains and disables unlisted ones."""
    # Seed two chains directly in the database.
    first_seed = Chain(name="chain_1", type="evm", active=True)
    second_seed = Chain(name="chain_2", type="substrate", active=True)
    db.add_all([first_seed, second_seed])
    await db.commit()

    # Reconcile against the fixture, then read back every stored chain.
    chains = await Chain.reconcile_chains(db, chains)
    stored = (await db.execute(select(Chain))).scalars().all()

    # Expected: the seeded chains plus one Chain per entry of the Chains object.
    configured = [
        Chain(name=cfg.name, type=cfg.type, active=cfg.active)
        for _, cfg in chains.get_chains().items()
    ]

    def by_name(item):
        return item["name"]

    actual_rows = sorted([row.serialise() for row in stored], key=by_name)
    expected_rows = sorted(
        [model.serialise() for model in [first_seed, second_seed] + configured],
        key=by_name,
    )
    assert actual_rows == expected_rows

    # A seeded chain that is absent from the Chains object is deactivated.
    chain_1 = (
        await db.execute(select(Chain).where(Chain.name == "chain_1"))
    ).scalars().first()
    assert chain_1.active is False
Beispiel #15
0
async def test_database_init() -> None:
    """initialize_database keeps existing data unless asked to reset."""
    log = structlog.get_logger(__name__)
    engine = create_database_engine(TEST_DATABASE_URL, TEST_DATABASE_PASSWORD)

    async def check_usernames(expected):
        # Open a scoped session, compare all usernames, then release it.
        db_session = await create_async_session(engine, log)
        async with db_session.begin():
            found = await db_session.scalars(select(User.username))
            assert found.all() == expected
        await db_session.remove()

    # Fresh schema with a single user row.
    await initialize_database(engine, log, schema=Base.metadata, reset=True)
    db_session = await create_async_session(engine, log)
    async with db_session.begin():
        db_session.add(User(username="******"))
    await db_session.remove()

    # Re-initialising without reset must keep the row.
    await initialize_database(engine, log, schema=Base.metadata)
    await check_usernames(["someuser"])

    # Re-initialising with reset must drop it.
    await initialize_database(engine, log, schema=Base.metadata, reset=True)
    await check_usernames([])

    await engine.dispose()
Beispiel #16
0
async def set_stickers(client, message):
	"""Save the sticker pack the user picked from the temporary keyboard.

	Ignores users not in ``app_user_ids``. ``USER_SET["type"]`` selects the
	table (1: ``StickerSet``, 2: ``AnimatedStickerSet``); the user's row is
	updated, or created if missing. The keyboard state is then reset and a
	follow-up menu is sent through ``slave``.
	"""
	if message.from_user.id not in app_user_ids:
		return

	if not DB_AVAILABLE:
		await message.edit("Your database is not avaiable!")
		return

	global TEMP_KEYBOARD, USER_SET
	if message.text in TEMP_KEYBOARD:
		await client.delete_messages(message.chat.id, USER_SET[message.from_user.id])
		if USER_SET["type"] == 1:
			sticker = (await session.execute(select(StickerSet).where(StickerSet.id == message.from_user.id))).scalars().one_or_none()
			if sticker:
				sticker.sticker = message.text
			else:
				sticker = StickerSet(message.from_user.id, message.text)
				session.add(sticker)
		elif USER_SET["type"] == 2:
			# Async lookup only: AsyncSession has no legacy .query() API, so a
			# session.query(...).get(...) call here would raise AttributeError.
			sticker = (await session.execute(select(AnimatedStickerSet).where(AnimatedStickerSet.id == message.from_user.id))).scalars().one_or_none()
			if sticker:
				sticker.sticker = message.text
			else:
				sticker = AnimatedStickerSet(message.from_user.id, message.text)
				session.add(sticker)
		await session.commit()
		status = "Ok, sticker pack was set to <code>{}</code>.".format(message.text)
	else:
		status = "Invalid pack selected."

	TEMP_KEYBOARD = []
	USER_SET = {}
	button = InlineKeyboardMarkup([[InlineKeyboardButton("Set Sticker Pack", callback_data="setsticker")]])
	await slave.send_message(message.chat.id, f"{status}\nWhat else would you like to do?", reply_markup=button)
Beispiel #17
0
async def fetch_one_rubric(session: AsyncSession,
                           rubric_id: int,
                           *,
                           with_links: bool = False) -> Rubric:
    """
    Fetch a single rubric by id.

    :param session: open async database session
    :type session: AsyncSession
    :param rubric_id: id of the rubric to fetch
    :type rubric_id: int
    :keyword with_links: when true, also load ``Rubric.links`` eagerly via
        ``selectinload``
    :type with_links: bool

    :return: the matching rubric, or ``None`` when no row matches
    :rtype: Rubric
    """
    query = select(Rubric)
    if with_links:
        query = query.options(selectinload(Rubric.links))

    result = await session.execute(query.where(Rubric.id == rubric_id))
    return result.scalar()
Beispiel #18
0
async def async_main():
    """Main program function.

    Recreates the schema and inserts three ``A`` rows with related ``B``
    rows. It reads them back through a buffered result and a streaming
    result, modifies one row, and disposes of the engine when done (also on
    error).
    """
    engine = create_async_engine(
        "postgresql+asyncpg://scott:tiger@localhost/test",
        echo=True,
    )

    try:
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.drop_all)
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

        # expire_on_commit=False will prevent attributes from being expired
        # after commit.
        async_session = sessionmaker(engine,
                                     expire_on_commit=False,
                                     class_=AsyncSession)

        async with async_session() as session:
            async with session.begin():
                session.add_all([
                    A(bs=[B(), B()], data="a1"),
                    A(bs=[B()], data="a2"),
                    A(bs=[B(), B()], data="a3"),
                ])

            # for relationship loading, eager loading should be applied.
            stmt = select(A).options(selectinload(A.bs))

            # AsyncSession.execute() is used for 2.0 style ORM execution
            # (same as the synchronous API).
            result = await session.execute(stmt)

            # result is a buffered Result object.
            for a1 in result.scalars():
                print(a1)
                print(f"created at: {a1.create_date}")
                for b1 in a1.bs:
                    print(b1)

            # for streaming ORM results, AsyncSession.stream() may be used.
            result = await session.stream(stmt)

            # result is a streaming AsyncResult object.
            async for a1 in result.scalars():
                print(a1)
                for b1 in a1.bs:
                    print(b1)

            result = await session.execute(select(A).order_by(A.id))

            a1 = result.scalars().first()

            a1.data = "new data"

            await session.commit()
    finally:
        # An AsyncEngine created in function scope should be disposed
        # explicitly so its pooled connections are closed.
        await engine.dispose()
Beispiel #19
0
async def modify(request: sanic.Request):
    """Update fields of an existing HTTP rule.

    Expects a JSON object with an integer ``rule_id`` and any of: ``path``
    (must not be used by another rule), ``catalog_id`` (must reference an
    existing catalog), ``filename``, ``write_log``, ``send_mail``,
    ``comment`` and ``rule_type`` (must be in ``constants.RULE_TYPES``).
    Fields with an unexpected type are ignored. Responds "invalid" for a
    malformed body, "failed" for an unknown rule, a duplicate path or an
    unknown catalog, and "success" after committing.
    """
    body = request.json
    if not isinstance(body, dict):
        return json(Response.invalid('参数无效'))

    rule_id = body.get('rule_id', None)
    rule_type = body.get('rule_type', None)
    path = body.get('path', None)
    filename = body.get('filename', None)
    write_log = body.get('write_log', None)
    send_mail = body.get('send_mail', None)
    comment = body.get('comment', None)
    catalog_id = body.get('catalog_id', None)

    if not isinstance(rule_id, int):
        return json(Response.invalid('参数无效'))

    db = request.ctx.db_session
    rule = (await db.execute(
        select(HttpRule).where(HttpRule.rule_id == rule_id))).scalar()
    if rule is None:
        return json(Response.failed('规则不存在'))

    if isinstance(path, str):
        # A path may only be taken if no *other* rule already uses it.
        holder = (await db.execute(
            select(HttpRule).where(HttpRule.path == path))).scalar()
        if holder is not None and holder.rule_id != rule_id:
            return json(Response.failed('路径已经存在'))
        rule.path = path

    if isinstance(catalog_id, int):
        catalog = (await db.execute(
            select(HttpRuleCatalog).where(
                HttpRuleCatalog.catalog_id == catalog_id))).scalar()
        if catalog is None:
            return json(Response.failed('分类不存在'))
        rule.catalog_id = catalog_id

    # Plain optional fields: apply each only when it has the expected type.
    for attr, value, expected_type in (
            ('filename', filename, str),
            ('write_log', write_log, bool),
            ('send_mail', send_mail, bool),
            ('comment', comment, str),
    ):
        if isinstance(value, expected_type):
            setattr(rule, attr, value)
    if rule_type in constants.RULE_TYPES:
        rule.rule_type = rule_type

    db.add(rule)
    await db.commit()
    return json(Response.success('修改成功'))
Beispiel #20
0
async def async_main(engine=None):
    """Demo: create tables, insert sample rows and read them back.

    :param engine: an existing ``AsyncEngine`` to use. Previously this
        argument was silently replaced by a new engine. When omitted
        (``None``), an engine is created from ``config.POSTGRES_URI`` and
        disposed on exit. A caller-supplied engine is left for the caller to
        dispose.
    """
    owns_engine = engine is None
    if owns_engine:
        engine = create_async_engine(config.POSTGRES_URI, echo=False)

    try:
        async with engine.begin() as conn:
            # await conn.run_sync(BaseModel.metadata.drop_all)
            await conn.run_sync(BaseModel.metadata.create_all)

        async with AsyncSession(engine) as session:
            async with session.begin():
                session.add_all(
                    [
                        A(bs=[B(), B()], data="a1"),
                        A(bs=[B()], data="a2"),
                        A(bs=[B(), B()], data="a3"),
                    ]
                )

            # for relationship loading, eager loading should be applied.
            stmt = select(A).options(selectinload(A.bs))

            # AsyncSession.execute() is used for 2.0 style ORM execution
            # (same as the synchronous API).
            result = await session.execute(stmt)

            # result is a buffered Result object.
            for a1 in result.scalars():
                print(a1)
                for b1 in a1.bs:
                    print(b1)

            # for streaming ORM results, AsyncSession.stream() may be used.
            result = await session.stream(stmt)

            # result is a streaming AsyncResult object.
            async for a1 in result.scalars():
                print(a1)
                for b1 in a1.bs:
                    print(b1)

            result = await session.execute(select(A).order_by(A.id))

            a1 = result.scalars().first()

            a1.data = "new data"

            await session.commit()
    finally:
        # Dispose only an engine this function created, and only after the
        # session has been closed.
        if owns_engine:
            await engine.dispose()


# asyncio.run(async_main())
Beispiel #21
0
async def get_events(
    response: Response,
    authorization: Optional[str] = Header(None),
    status_code=status.HTTP_200_OK,
    user_info=None,
    city: Optional[int] = None,
    start_time: Optional[str] = None,
    end_time: Optional[str] = None,
    subjects: Optional[str] = None,
    limit: Optional[int] = 20,
    offset: Optional[int] = 0,
) -> List[dict]:
    """Return events, optionally filtered by city, subjects and exact times.

    ``subjects`` is a comma-separated list of subject ids; entries that are
    not integers are ignored. ``start_time`` / ``end_time`` are parsed with
    ``dateutil.parser.parse`` and matched exactly. ``limit`` / ``offset``
    page the joined (event, city, event-subject) rows, so an event linked to
    several subjects may appear more than once.
    """
    async with get_session() as s:
        subjects_stmt = select(Subject, EventSubject).join(
            EventSubject, EventSubject.subject == Subject.id)
        if subjects:
            subjects_list = []
            for raw_id in subjects.split(','):
                try:
                    subjects_list.append(int(raw_id))
                except ValueError:
                    # Ignore entries that are not integers.
                    pass
            subjects_stmt = subjects_stmt.filter(Subject.id.in_(subjects_list))
        subject_rows = (await s.execute(subjects_stmt)).fetchall()

        events = select(Event, City, EventSubject).join(
            City, City.id == Event.city)
        if subjects:
            events = events.join(
                EventSubject, (EventSubject.event == Event.id) &
                (EventSubject.subject.in_([subj.id for subj, _ in subject_rows])))
        else:
            events = events.join(EventSubject, EventSubject.event == Event.id)

        # Select is immutable: filter() returns a new statement, so it must
        # be reassigned or the criterion is silently dropped.
        if city:
            events = events.filter(City.id == city)
        if start_time:
            events = events.filter(
                Event.start_time == dateutil.parser.parse(start_time))
        if end_time:
            events = events.filter(
                Event.end_time == dateutil.parser.parse(end_time))
        rows = (await s.execute(events.limit(limit).offset(offset))).fetchall()

        res = []
        for event, event_city, _ in rows:
            item = event.as_dict()
            item["city"] = event_city.name if event_city else None
            item["subjects"] = [{
                "id": subj.id,
                "name": subj.name
            } for subj, link in subject_rows if item["id"] == link.event]
            res.append(item)
        return [EventResponse.parse_obj(item) for item in res]
async def get_items(query: str, offset: int):
    """Return up to 20 items ordered by name, starting at ``offset``.

    When ``query`` is non-empty, only items whose lower-cased name contains
    the lower-cased ``query`` are returned. LIKE wildcards (``%``, ``_``) in
    ``query`` are escaped so they match literally instead of acting as
    patterns.

    :return: the ``ScalarResult`` of the (buffered) query result
    """
    page_size = 20
    stmt = select(Items)
    if query:
        # contains(..., autoescape=True) builds '%' || :value || '%' with an
        # ESCAPE clause, so user-typed '%' / '_' are not wildcards.
        stmt = stmt.filter(
            func.lower(Items.name).contains(query.lower(), autoescape=True))
    stmt = stmt.order_by(Items.name).offset(offset).limit(page_size)

    async with async_session() as session:
        results = await session.execute(stmt)
        return results.scalars()
def topic_find_one_and_update(where, updates, name):
    """Upsert one row of table ``topic_<name>`` and return the stored row.

    Within a single transaction this locks the matching row(s) with
    ``SELECT ... FOR UPDATE``. If a row exists it is UPDATEd, otherwise
    ``updates`` is INSERTed. The row matching ``where`` is then re-read.

    :param where: filter passed to ``build_oracle_where_expression``
    :param updates: new values, converted with ``convert_to_dict``
    :param name: topic name (table ``topic_<name>``)
    :return: ``convert_dict_key(row, name)``, where ``row`` is the re-read
        row as a dict, or ``None`` if nothing matches ``where`` afterwards
    """
    table_name = 'topic_' + name
    table = get_topic_table_by_name(table_name)
    data_dict: dict = convert_to_dict(updates)

    select_for_update_stmt = select(table). \
        with_for_update(nowait=False). \
        where(build_oracle_where_expression(table, where))

    insert_stmt = insert(table).values(
        build_oracle_updates_expression_for_insert(table, data_dict))

    update_stmt = update(table).where(
        build_oracle_where_expression(table, where)).values(
            build_oracle_updates_expression_for_update(table, data_dict))

    select_new_stmt = select(table). \
        where(build_oracle_where_expression(table, where))

    # engine.begin() commits on success and rolls back on error.
    with engine.begin() as conn:
        row = conn.execute(select_for_update_stmt).fetchone()
        if row is not None:
            conn.execute(update_stmt)
        else:
            conn.execute(insert_stmt)
        # Re-read inside the same transaction, so the returned row reflects
        # this write rather than whatever committed in between.
        cursor = conn.execute(select_new_stmt).cursor
        columns = [col[0] for col in cursor.description]
        # Presumably a cx_Oracle cursor: rowfactory maps each row to a dict.
        cursor.rowfactory = lambda *args: dict(zip(columns, args))
        result = cursor.fetchone()

    return convert_dict_key(result, name)
Beispiel #24
0
    def countResults():
        """Return the row counts of the ``Race`` and ``HorseResult`` tables.

        :return: ``{'count': {'Race': <int>, 'HorseResult': <int>}}``
        """
        with engine.connect() as conn:
            # A COUNT query yields exactly one row; scalar() reads its single
            # column without building a throwaway list. func.count() with no
            # argument renders COUNT(*).
            race = conn.execute(
                select(func.count()).select_from(Race)).scalar()
            horse_results = conn.execute(
                select(func.count()).select_from(HorseResult)).scalar()

        counts = {'count': {'Race': race, 'HorseResult': horse_results}}
        return counts
Beispiel #25
0
    async def _boss_id_locator(self, boss_identifier: str) -> int:
        """Resolve a boss id from either a numeric id or an alias.

        :param boss_identifier: a decimal boss id, or a boss alias
        :return: the matching ``boss_id``, or ``-1`` when no boss matches
        """
        try:
            boss_id = int(boss_identifier)
        except ValueError:
            # Not a number: look the boss up by alias instead.
            stmt = select(Boss).filter(Boss.alias == boss_identifier)
        else:
            stmt = select(Boss).filter(Boss.boss_id == boss_id)

        async with self.__session.begin() as async_session:
            query = await async_session.stream(stmt)
            record = await query.scalars().first()
            # Read the id while the session is still open. The commit at the
            # end of begin() may expire loaded attributes (expire_on_commit),
            # and refreshing them after the session closes would fail.
            found_id = -1 if record is None else record.boss_id
        return found_id
Beispiel #26
0
 async def create(self, db: AsyncSession, *,
                  obj_in: UserCreate) -> models.User:
     """Create a user, hashing the password and resolving relations.

     ``obj_in.roles`` are matched against ``Role.name`` and
     ``obj_in.study_areas`` against ``StudyArea.id``; the matches are
     attached before the insert. The new row is committed and refreshed
     before being returned.
     """
     user = models.User.from_orm(obj_in)
     user.hashed_password = get_password_hash(obj_in.password)

     role_result = await db.execute(
         select(models.Role).filter(models.Role.name.in_(obj_in.roles)))
     user.roles = role_result.scalars().all()

     area_result = await db.execute(
         select(models.StudyArea).filter(
             models.StudyArea.id.in_(obj_in.study_areas)))
     user.study_areas = area_result.scalars().all()

     db.add(user)
     await db.commit()
     await db.refresh(user)
     return user
        def query(names):
            """Select (user name, email address) pairs for the given names.

            Built with ``lambda_stmt``: the base SELECT joins users to their
            addresses and filters on ``names``; a second lambda step adds the
            ORDER BY. ``s``, ``User`` and ``Address`` come from the enclosing
            scope, which is not visible here.
            """
            # Inside the second lambda, ``s`` is the statement being extended,
            # not the enclosing-scope ``s`` used for execute() below.
            stmt = lambda_stmt(
                lambda: select(User.name, Address.email_address).where(
                    User.name.in_(names)).join(User.addresses)) + (
                        lambda s: s.order_by(User.id, Address.id))

            return s.execute(stmt)
Beispiel #28
0
                def query(names):
                    """Select aliased ``User`` rows for ``names`` via lambda_stmt.

                    Addresses are eagerly loaded with ``selectinload`` and the
                    ORDER BY is added as a second lambda step. ``s`` and
                    ``User`` come from the enclosing scope (not visible here).
                    """
                    # A new alias is created on every call; the lambdas refer to
                    # it as a closure variable.
                    u1 = aliased(User)
                    stmt = lambda_stmt(lambda: select(u1).where(
                        u1.name.in_(names)).options(selectinload(u1.addresses))
                                       ) + (lambda s: s.order_by(u1.id))

                    return s.execute(stmt)
Beispiel #29
0
async def process_note_text(message: types.Message, chat: Chat,
                            state: FSMContext):
    """Save the message text as a note on the currently selected item.

    The project is the one with the lowest id for ``chat.id``. The item is
    ``data['items_ids'][data['item_in_list_pos']]`` from the FSM state. After
    the note is committed, its text is echoed back in a reply and the FSM
    state is finished.
    """
    created_at = message.date
    async with state.proxy() as data:
        async with OrmSession() as session:
            project_query = (
                select(Project)
                .where(Project.chat_id == chat.id)
                .order_by(Project.id)
            )
            first_project = (
                await session.execute(project_query)).scalars().first()

            position = data['item_in_list_pos']
            selected_item_id = data['items_ids'][position]
            note = ItemNote(project_id=first_project.id,
                            item_id=selected_item_id,
                            text=message.text,
                            created_dt=created_at)
            logging.info(note)
            session.add(note)
            await session.commit()

    reply_body = text(text('Все, так и запишу:'),
                      text('    :pencil2:', message.text),
                      sep='\n')
    await message.reply(emojize(reply_body), disable_web_page_preview=True)
    await state.finish()
Beispiel #30
0
async def async_main():
    """Run 30 random single-row SELECTs concurrently and print the rows.

    Recreates the schema and inserts 100 ``A`` rows. It then gathers
    ``run_out_of_band`` over 30 statements that each pick a random
    ``a_<n>`` value, and prints the results. The engine is disposed on exit
    (also on error).
    """
    engine = create_async_engine(
        "postgresql+asyncpg://scott:tiger@localhost/test",
        echo=True,
    )

    try:
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.drop_all)
            await conn.run_sync(Base.metadata.create_all)

        async_session = async_sessionmaker(engine, expire_on_commit=False)

        async with async_session() as session, session.begin():
            session.add_all([A(data="a_%d" % i) for i in range(100)])

        statements = [
            select(A).where(A.data == "a_%d" % random.choice(range(100)))
            for _ in range(30)
        ]

        results = await asyncio.gather(
            *(run_out_of_band(async_session, session, statement)
              for statement in statements))
        print(f"results: {[r.all() for r in results]}")
    finally:
        # An AsyncEngine created in function scope should be disposed
        # explicitly so its pooled connections are closed.
        await engine.dispose()