async def clean(self, ctx: Context):
    """
    Cleans out old colourme roles.

    A stored colourme role is deleted when its owner is no longer a member of
    the guild, or when the ``colourme_modchoice_enabled`` setting lookup
    returns a truthy value. Roles that can no longer be found are skipped.
    The UserColour rows of the deleted roles are then removed and the number
    of deleted roles is reported back to the channel.
    """
    async with threadpool():
        with ctx.bot.database.get_session() as sess:
            assert isinstance(sess, Session)
            colour_rows = sess.query(UserColour).filter(
                UserColour.guild_id == ctx.guild.id).all()

    # NOTE(review): only the truthiness of this value is checked; if get_setting
    # returns a Setting object rather than its value, every colourme role is
    # purged whenever the setting exists — confirm.
    modchoice_enabled = await ctx.bot.database.get_setting(
        ctx.guild, "colourme_modchoice_enabled")

    deleted_roles = []
    async with ctx.channel.typing():
        for entry in colour_rows:
            owner_present = ctx.guild.get_member(entry.user_id) is not None
            if owner_present and not modchoice_enabled:
                # Owner is still here and mod choice is off: keep this role.
                continue

            role = ctx.guild.roles.find | (
                lambda r: r.id == entry.role_id)
            if not role:
                # NOTE(review): the UserColour row of a role that can no longer be
                # found is not removed below — confirm that is acceptable.
                continue

            await role.delete()
            deleted_roles.append(role)

    deleted_ids = [gone.id for gone in deleted_roles]
    async with threadpool():
        with ctx.bot.database.get_session() as sess:
            assert isinstance(sess, Session)
            sess.query(UserColour).filter(
                UserColour.role_id.in_(deleted_ids)).delete()

    await ctx.send(":heavy_check_mark: Deleted `{}` roles.".format(
        len(deleted_roles)))
async def ban_unauth_user_by_query(self, guild_id, placer_id, username, discriminator):
    """
    Add the newest guest user matching ``username`` to the guild's ban list.

    ``username`` is matched case-insensitively as a substring; when
    ``discriminator`` is truthy it must match exactly as well. A previously
    lifted ban (one with a ``lifter_id``) is replaced by a fresh ban.

    :return: A human-readable status message describing the outcome.
    """
    async with threadpool():
        with self.get_session() as session:
            guest_query = session.query(UnauthenticatedUsers) \
                .filter(UnauthenticatedUsers.guild_id == guild_id) \
                .filter(UnauthenticatedUsers.username.ilike("%" + username + "%"))
            if discriminator:
                guest_query = guest_query.filter(
                    UnauthenticatedUsers.discriminator == discriminator)
            # Most recently created guest row wins.
            guest = guest_query.order_by(UnauthenticatedUsers.id.desc()).first()
            if not guest:
                return "Ban error! Guest user cannot be found."

            existing_ban = session.query(UnauthenticatedBans) \
                .filter(UnauthenticatedBans.guild_id == guild_id) \
                .filter(UnauthenticatedBans.last_username == guest.username) \
                .filter(UnauthenticatedBans.last_discriminator == guest.discriminator) \
                .first()
            if existing_ban is not None:
                # No lifter recorded means the ban is still in force.
                if existing_ban.lifter_id is None:
                    return "Ban error! Guest user, **{}#{}**, has already been banned.".format(
                        existing_ban.last_username, existing_ban.last_discriminator)
                # The old ban was lifted: drop it so it can be re-issued below.
                session.delete(existing_ban)

            new_ban = UnauthenticatedBans(guild_id, guest.ip_address, guest.username,
                                          guest.discriminator, "", placer_id)
            session.add(new_ban)
            session.commit()
            return "Guest user, **{}#{}**, has successfully been added to the ban list!".format(
                new_ban.last_username, new_ban.last_discriminator)
async def flag_unactive_bans(self, guild_id, guildbans):
    """
    Flag every user in ``guildbans`` as banned in ``GuildMembers``.

    Only rows marked inactive are looked up. When a user has no inactive row,
    a new inactive, banned row is inserted for them. The session is committed
    only if ``guildbans`` yielded at least one user.
    """
    async with threadpool():
        with self.get_session() as session:
            processed_any = False
            for banned_user in guildbans:
                inactive_row = session.query(GuildMembers) \
                    .filter(GuildMembers.guild_id == int(guild_id)) \
                    .filter(GuildMembers.user_id == int(banned_user.id)) \
                    .filter(GuildMembers.active == False).first()  # noqa: E712 (SQL expression)
                processed_any = True

                if not inactive_row:
                    # NOTE(review): a user whose only row is *active* also lands here
                    # and gets a second row — confirm that is intended.
                    session.add(GuildMembers(
                        int(guild_id), int(banned_user.id), banned_user.name,
                        banned_user.discriminator, None, banned_user.avatar,
                        False, True, "[]"
                    ))
                    continue

                inactive_row.banned = True

            if processed_any:
                session.commit()
async def revoke_unauth_user_by_query(self, guild_id, username, discriminator):
    """
    Mark the newest guest user matching ``username`` as revoked (kicked).

    ``username`` is matched case-insensitively as a substring; when
    ``discriminator`` is truthy it must match exactly as well.

    :return: A human-readable status message describing the outcome.
    """
    async with threadpool():
        with self.get_session() as session:
            lookup = session.query(UnauthenticatedUsers) \
                .filter(UnauthenticatedUsers.guild_id == guild_id) \
                .filter(UnauthenticatedUsers.username.ilike("%" + username + "%"))
            if discriminator:
                lookup = lookup.filter(
                    UnauthenticatedUsers.discriminator == discriminator)
            guest = lookup.order_by(UnauthenticatedUsers.id.desc()).first()

            if not guest:
                return "Kick error! Guest user cannot be found."
            if guest.revoked:
                return "Kick error! Guest user **{}#{}** has already been kicked!".format(
                    guest.username, guest.discriminator)

            guest.revoked = True
            session.commit()
            return "Successfully kicked **{}#{}**!".format(
                guest.username, guest.discriminator)
async def sample_handler(request):
    """
    Greet the user whose id appears in the route's ``id`` segment.

    Falls back to id 0 when the route has no ``id`` segment (the previous,
    hard-coded behaviour), and greets "Anonymous" when no such user exists.

    :raises web.HTTPBadRequest: if the ``id`` segment is not an integer.
    """
    raw_id = request.match_info.get('id', 0)
    try:
        user_id = int(raw_id)
    except (TypeError, ValueError):
        raise web.HTTPBadRequest(text="Invalid user id") from None

    async with threadpool():
        with db_session:
            user = User.get(id=user_id)
            # Read the attribute while the Pony db_session is still open.
            name = user.name if user else 'Anonymous'
    return web.Response(text="Hello, %s" % name)
async def commit(self):
    """
    Commit the pending transaction on this connection.

    The lock is held for the whole operation, and the blocking commit call
    runs in a worker thread via ``threadpool()``.
    """
    lock = self._lock
    async with lock:
        async with threadpool():
            self.connection.commit()
async def get_reminder(self, id: int) -> Reminder:
    """
    Look up a single reminder.

    :param id: The reminder's ``id`` value.
    :return: The matching Reminder, or None if there is none.
    """
    async with threadpool():
        with self.get_session() as sess:
            matches = sess.query(Reminder).filter(Reminder.id == id)
            return matches.first()
async def update_guild(self, guild):
    """
    Create or refresh the stored Guilds row for ``guild``.

    Roles, channels, webhooks and emojis are stored as JSON strings. Webhooks
    are only fetched when the bot has the ``manage_webhooks`` permission;
    otherwise an empty webhook list is stored.
    """
    can_manage_webhooks = guild.me.server_permissions.manage_webhooks
    server_webhooks = (await self.bot.get_server_webhooks(guild)
                       if can_manage_webhooks else [])

    async with threadpool():
        with self.get_session() as session:
            stored = session.query(Guilds).filter(
                Guilds.guild_id == guild.id).first()

            roles_json = json.dumps(get_roles_list(guild.roles))
            channels_json = json.dumps(get_channels_list(guild.channels))
            webhooks_json = json.dumps(get_webhooks_list(server_webhooks))
            emojis_json = json.dumps(get_emojis_list(guild.emojis))

            if stored:
                stored.name = guild.name
                stored.roles = roles_json
                stored.channels = channels_json
                stored.webhooks = webhooks_json
                stored.emojis = emojis_json
                stored.owner_id = guild.owner_id
                stored.icon = guild.icon
            else:
                session.add(Guilds(
                    guild.id,
                    guild.name,
                    roles_json,
                    channels_json,
                    webhooks_json,
                    emojis_json,
                    guild.owner_id,
                    guild.icon))
            session.commit()
async def update_event_setting(self, guild: discord.Guild, event: str, *,
                               enabled: bool = None, message: str = None,
                               channel: discord.TextChannel = None) -> EventSetting:
    """
    Create or update the EventSetting for ``event`` in ``guild``.

    Only the keyword arguments that are not None are applied.

    :return: The new or updated EventSetting.
    """
    setting = await self.get_event_setting(guild, event)
    db_guild = await self.get_or_create_guild(guild)

    async with threadpool():
        with self.get_session() as sess:
            if setting is None:
                setting = EventSetting(event=event)
                # Add immediately; SQLAlchemy reportedly misbehaves otherwise.
                sess.add(setting)

            # NOTE(review): a pre-existing setting comes from get_event_setting and
            # is never added to this session; whether the changes below persist
            # depends on get_session — confirm.
            # (Re)assign the guild relationship so backrefs stay in sync.
            setting.guild = db_guild

            for attr, new_value in (("enabled", enabled), ("message", message)):
                if new_value is not None:
                    setattr(setting, attr, new_value)
            if channel is not None:
                setting.channel_id = channel.id

            return setting
async def get_user_stocks(self, user: discord.Member, *,
                          guild: discord.Guild = None) -> typing.Sequence[UserStock]:
    """
    Return the UserStock rows owned by ``user``.

    When ``guild`` is given, only stocks whose Stock belongs to that guild are
    returned.
    """
    # Always make sure the user row exists.
    user = await self.get_or_create_user(user)
    if guild:
        # The result was never used; the call is kept for its side effect.
        await self.get_or_create_guild(guild)

    async with threadpool():
        with self.get_session() as sess:
            assert isinstance(sess, Session)
            query = sess.query(UserStock)
            if guild is not None:
                # Filtering by guild needs the owning Stock, hence the join.
                query = query.join(UserStock.stock) \
                    .filter((UserStock.user_id == user.id) & (Stock.guild_id == guild.id))
            else:
                query = query.filter(UserStock.user_id == user.id)
            return list(query.all())
async def save_rolestate(self, member: discord.Member) -> RoleState:
    """
    Store ``member``'s current nickname and role IDs as their RoleState.

    An existing RoleState for the member/guild pair is updated; otherwise a new
    one is created. The guild's default role is left out of the stored roles.
    """
    db_guild = await self.get_or_create_guild(member.guild)
    db_user = await self.get_or_create_user(member)

    async with threadpool():
        with self.get_session() as session:
            assert isinstance(session, Session)
            state = session.query(RoleState) \
                .filter((RoleState.user_id == member.id) &
                        (RoleState.guild_id == member.guild.id)) \
                .first()
            if state is None:
                state = RoleState(user_id=member.id)
                state.guild = db_guild

            state.nick = member.nick
            # Role IDs are stored directly as an array.
            state.roles = [
                role.id for role in member.roles
                if not (role == member.guild.default_role)
            ]

            session.merge(db_user)
            session.add(state)
            return state
async def sql(self, ctx: Context, *, s: str):
    """
    Executes some raw SQL and replies with the results as paginated tables.

    Surrounding triple-backtick fences are stripped first. Statements that do
    not return rows (INSERT/UPDATE/DDL, ...) are still committed and reported
    as "executed without results". Database errors are echoed back in a code
    block.

    NOTE(review): this runs arbitrary SQL; presumably it is restricted to bot
    owners by a check outside this view — confirm.
    """
    # Strip the Markdown code fences.
    if s.startswith("```"):
        s = s[3:]
    if s.endswith("```"):
        s = s[:-3]

    # Wrap the query in a TextClause so SQLAlchemy executes it as written.
    statement = text(s)

    try:
        async with ctx.channel.typing():
            async with threadpool():
                with ctx.bot.database.get_session() as sess:
                    results = sess.execute(statement)  # type: ResultProxy
                    # fetchall() raises ResourceClosedError (not a DatabaseError) for
                    # statements that return no rows, which used to abort before the
                    # commit; only fetch when there is something to fetch.
                    if results.returns_rows:
                        headers = results.keys()
                        all_values = results.fetchall()
                    else:
                        headers, all_values = [], []
                    sess.commit()
    except DatabaseError as e:
        await ctx.send("```sql\n{}```".format(e.args[0]))
        return

    if not all_values:
        await ctx.send("Query executed without results.")
        return

    tables = paginate_table(all_values, headers)
    for tbl in tables:
        await ctx.send(tbl)
async def change_stock(self, channel: discord.TextChannel, *,
                       amount: int = None, price: int = None) -> Stock:
    """
    Changes the stock for the specified channel.

    If the stock already exists, the given properties are updated; otherwise it
    is created.

    :param amount: The amount of stocks to create. Not assigned when None.
    :param price: The price of this stock. Not assigned when None.
    :return: The Stock instance produced by ``Session.merge``, so it reflects the
        stored row as well as the updates applied here.
    """
    guild = await self.get_or_create_guild(channel.guild)

    async with threadpool():
        with self.get_session() as sess:
            assert isinstance(sess, Session)
            stock = Stock()
            # Update the appropriate fields.
            stock.channel_id = channel.id
            stock.guild = guild
            if amount is not None:
                stock.amount = amount
            if price is not None:
                stock.price = price

            # merge() copies the state above onto the stored row (or a new pending
            # one) and returns *that* instance; the transient object passed in is
            # left untouched, so return the merged instance instead of it.
            stock = sess.merge(stock)
            return stock
async def update_guild_member(self, member, active=True, banned=False):
    """
    Create or refresh the GuildMembers row for ``member`` in its server.

    When several rows exist for the same server/user pair, only the one with
    the lowest ``id`` is kept and the rest are deleted. Roles are stored as the
    JSON-encoded result of ``list_role_ids(member.roles)``.
    """
    async with threadpool():
        with self.get_session() as session:
            rows = session.query(GuildMembers) \
                .filter(GuildMembers.guild_id == member.server.id) \
                .filter(GuildMembers.user_id == member.id) \
                .order_by(GuildMembers.id).all()

            if rows:
                survivor, *duplicates = rows
                for duplicate in duplicates:
                    session.delete(duplicate)
                survivor.banned = banned
                survivor.active = active
                survivor.username = member.name
                survivor.discriminator = member.discriminator
                survivor.nickname = member.nick
                survivor.avatar = member.avatar
                survivor.roles = json.dumps(list_role_ids(member.roles))
            else:
                session.add(GuildMembers(
                    member.server.id,
                    member.id,
                    member.name,
                    member.discriminator,
                    member.nick,
                    member.avatar,
                    active,
                    banned,
                    json.dumps(list_role_ids(member.roles))))
            session.commit()
async def save_tag(self, guild: discord.Guild, name: str, content: str, *,
                   owner: discord.Member = None, lua: bool = False) -> Tag:
    """
    Saves a tag to the database, creating it if it does not exist yet.

    ``name`` is resolved with ``get_tag``, which also matches tag aliases. When
    it resolves through an alias, the aliased tag's content is updated but its
    own name is kept; previously the tag was renamed to the alias name.

    :param owner: If given, its id is stored as the tag's ``user_id``.
    :param lua: Stored as the tag's ``lua`` flag.
    :return: The created or updated Tag.
    """
    guild = await self.get_or_create_guild(guild)
    tag = await self.get_tag(guild, name)

    async with threadpool():
        with self.get_session() as sess:
            if tag is None:
                tag = Tag()
                # Add it first, otherwise SQLAlchemy complains.
                sess.add(tag)
                # Only a brand-new tag takes ``name``: an existing tag found
                # directly already has it, and one found via an alias must keep
                # its real name.
                tag.name = name

            # NOTE(review): an existing tag comes from get_tag's session and is not
            # added to this one; whether these updates persist depends on
            # get_session — confirm.
            tag.content = content
            tag.last_modified = datetime.datetime.now()
            if owner is not None:
                tag.user_id = owner.id
            tag.guild_id = guild.id
            tag.lua = lua

            return tag
async def get_tag(
        self, guild: discord.Guild, name: str, return_alias: bool = False
) -> typing.Union[Tag, typing.Tuple[Tag, TagAlias]]:
    """
    Look up a tag by name, falling back to the guild's tag aliases.

    :param return_alias: When True, return ``(tag, alias)``, where ``alias`` is
        the TagAlias used for the lookup, or None if the name matched a tag
        directly.
    :return: The Tag (None if nothing matched), or the ``(tag, alias)`` pair.
    """
    async with threadpool():
        with self.get_session() as sess:
            alias = None
            tag = sess.query(Tag) \
                .filter((Tag.name == name) & (Tag.guild_id == guild.id)) \
                .first()

            if tag is None:
                alias = sess.query(TagAlias) \
                    .filter((TagAlias.alias_name == name) & (TagAlias.guild_id == guild.id)) \
                    .first()
                # Resolve the alias target while the session is still open.
                tag = alias.tag if alias is not None else None

            return (tag, alias) if return_alias else tag
async def set_colourme_role(self, member: discord.Member, role: discord.Role) -> UserColour:
    """
    Record ``role`` as ``member``'s colourme role in their guild.

    The existing UserColour row for the member/guild pair is reused when
    present; otherwise a new one is created.
    """
    db_guild = await self.get_or_create_guild(member.guild)
    db_user = await self.get_or_create_user(member)

    async with threadpool():
        with self.get_session() as sess:
            colour = sess.query(UserColour) \
                .filter((UserColour.user_id == member.id) &
                        (UserColour.guild_id == member.guild.id)) \
                .first()  # type: UserColour
            if colour is None:
                colour = UserColour()

            # Always (re)write these fields, even for an existing row.
            colour.role_id = role.id
            colour.guild = db_guild
            colour.user = db_user

            sess.add(colour)
            return colour
async def set_setting(self, guild: discord.Guild, setting_name: str, value: dict = None,
                      **kwargs) -> Setting:
    """
    Create or overwrite the setting ``setting_name`` for ``guild``.

    The stored value is ``value`` (an empty dict when omitted) combined with
    any extra keyword arguments; keyword arguments win on key clashes.
    """
    db_guild = await self.get_or_create_guild(guild)

    async with threadpool():
        with self.get_session() as session:
            setting = session.query(Setting) \
                .filter((Setting.guild_id == guild.id) & (Setting.name == setting_name)) \
                .first()
            if setting is None:
                setting = Setting(name=setting_name)
                setting.guild = db_guild

            base = {} if value is None else value
            setting.value = {**base, **kwargs}

            session.add(setting)
            return setting
async def fetch_row(self) -> typing.Optional[typing.Mapping[str, typing.Any]]:
    """
    Fetches one row.

    :return: The next row wrapped in a DictRow, or None when the cursor's
        ``fetchone()`` returns None (the annotation now says so).
    """
    async with threadpool():
        # fetchone() blocks, so it runs in a worker thread.
        row = self.cursor.fetchone()
    return DictRow(row) if row is not None else None
async def get_all_tags_for_guild(self, guild: discord.Guild) -> typing.Sequence[Tag]:
    """
    Return every tag whose ``guild_id`` matches ``guild``.

    ``get_or_create_guild`` is awaited first (presumably so the guild row is
    guaranteed to exist).
    """
    await self.get_or_create_guild(guild)

    async with threadpool():
        with self.get_session() as sess:
            guild_tags = sess.query(Tag).filter(Tag.guild_id == guild.id).all()
            return list(guild_tags)
async def test_threadpool_contextmanager(self):
    """Test that threadpool() without an argument works as a context manager.

    The block must run outside the event loop thread, and execution must be
    back on the event loop thread once the block exits.
    """
    event_loop_thread = threading.current_thread()
    async with threadpool():
        func_thread = threading.current_thread()

    assert threading.current_thread() is event_loop_thread
    assert func_thread is not event_loop_thread
async def fetch_many(
        self, n: int) -> typing.List[typing.Mapping[str, typing.Any]]:
    """
    Fetch up to ``n`` rows from the cursor.

    The blocking ``fetchmany`` call runs in a worker thread; each row is
    wrapped in a DictRow and any ``None`` entries are dropped.
    """
    async with threadpool():
        raw_rows = self.cursor.fetchmany(size=n)

    wrapped = []
    for raw in raw_rows:
        if raw is not None:
            wrapped.append(DictRow(raw))
    return wrapped
async def test_threadpool_await_in_thread(self):
    """Awaiting anything inside a threadpool() block must raise RuntimeError."""
    never_done = Future()

    with pytest.raises(RuntimeError) as excinfo:
        async with threadpool():
            await never_done

    assert str(excinfo.value) == 'attempted to "await" in a worker thread'
async def fetch_last_data():
    """
    Return one ``Bangladesh`` entity using ``select().first()``.

    The Pony query runs inside a ``db_session`` in a worker thread.

    NOTE(review): the name says "last", but the query applies no explicit
    ordering, so which row ``first()`` yields is up to Pony's default — confirm
    this really returns the most recent record.

    :return: The selected entity, or None if the table is empty.
    """
    async with threadpool():
        with db_session:
            query = Bangladesh.select()
            return query.first()
async def unban_server_user(self, user, server):
    """
    Clear the ``banned`` flag on ``user``'s GuildMembers row for ``server``.

    Nothing is written when no such row exists.
    """
    async with threadpool():
        with self.get_session() as session:
            member_row = session.query(GuildMembers) \
                .filter(GuildMembers.guild_id == server.id) \
                .filter(GuildMembers.user_id == user.id).first()
            if not member_row:
                return
            member_row.banned = False
            session.commit()
async def get_multiple_guilds(self, *guilds: discord.Guild) -> typing.Sequence[Guild]:
    """
    Gets multiple guilds.

    :param guilds: The Discord guilds to look up. Each positional argument is
        one guild; the old annotation wrongly declared each one a List.
    :return: The stored Guild rows whose id matches one of ``guilds``. Guilds
        without a stored row are absent. The result is empty when no guilds
        are given.
    """
    guild_ids = [g.id for g in guilds]
    if not guild_ids:
        # Skip the round-trip: an empty IN () matches nothing, and older
        # SQLAlchemy versions warn about it.
        return []

    async with threadpool():
        with self.get_session() as sess:
            return list(sess.query(Guild).filter(Guild.id.in_(guild_ids)).all())
async def remove_guild(self, guild):
    """
    Delete the stored Guilds row for ``guild`` together with its Messages rows.

    Does nothing if the guild has no stored row.
    """
    async with threadpool():
        with self.get_session() as session:
            guild_id = int(guild.id)
            stored = session.query(Guilds).filter(Guilds.guild_id == guild_id).first()
            if not stored:
                return

            # Messages are deleted one by one through the ORM rather than with a
            # bulk DELETE (presumably so ORM-level cascades/events apply).
            guild_messages = session.query(Messages).filter(Messages.guild_id == guild_id).all()
            for message in guild_messages:
                session.delete(message)
            session.delete(stored)
            session.commit()
async def test_threadpool_contextmanager_exception(self):
    """An exception raised inside a threadpool() block reaches the caller,
    and execution resumes on the event loop thread."""
    loop_thread = threading.current_thread()

    with pytest.raises(ValueError) as excinfo:
        async with threadpool():
            raise ValueError('foo')

    assert threading.current_thread() is loop_thread
    assert str(excinfo.value) == 'foo'
async def connect(self, *args, **kwargs):
    """
    Fill this pool's queue with new connections.

    One connection is created for each of the ``self.queue.maxsize`` slots,
    in a worker thread. ``args``/``kwargs`` are accepted but unused.

    :return: This pool.
    """
    async with threadpool():
        for _ in range(self.queue.maxsize):
            self.queue.put_nowait(self._new_connection())
    return self
async def remove_tag_alias(self, guild: discord.Guild, alias: TagAlias):
    """
    Delete ``alias`` through a fresh session and hand it back to the caller.

    ``get_or_create_guild`` is awaited first for ``guild``.
    """
    await self.get_or_create_guild(guild)

    async with threadpool():
        with self.get_session() as sess:
            sess.delete(alias)

    return alias
async def dialplan(request):
    """
    Dial the devices registered for the called extension, one after another.

    The device lookup runs in a worker thread. The AGI commands are awaited
    afterwards on the event loop, because awaiting inside a ``threadpool()``
    block raises RuntimeError. Dialling stops at the first status that is not
    in ``STATUS_CONTINUE``.
    """
    number = request.headers['agi_extension']
    async with threadpool():
        peers = get_devices_for_number(number)

    for peer in peers:
        pprint(peer)
        # The command template previously had an unfilled "{}" placeholder.
        # NOTE(review): assumes str(peer) is a valid Dial target (e.g. "SIP/1001")
        # — confirm against get_devices_for_number.
        await request.send_command('EXEC Dial {}'.format(peer))
        # NOTE(review): dial_status is called synchronously; confirm it is not a
        # coroutine function.
        status = dial_status(request)
        pprint(status)
        if status not in STATUS_CONTINUE:
            break
def threadpool(self, executor: Union[Executor, str] = None):
    """
    Build an asynchronous context manager whose block runs in a (thread pool) executor.

    :param executor: an :class:`~concurrent.futures.Executor` instance, the resource
        name of one (looked up with ``self.require_resource``), or ``None`` to use
        the event loop's default executor
    :return: an asynchronous context manager
    """
    assert check_argument_types()
    target = (self.require_resource(Executor, executor)
              if isinstance(executor, str) else executor)
    return asyncio_extras.threadpool(target)
def executor(arg: Union[Executor, str, Callable] = None):
    """
    Decorate a function so that it runs in an :class:`~concurrent.futures.Executor`.

    If a resource name is given, a :class:`~.Context` must be among the first two
    positional arguments of every call (the second slot presumably accommodates
    ``self`` on methods); the named executor resource is looked up from it.

    Usage::

        @executor
        def should_run_in_executor():
            ...

    With a resource name::

        @executor('resourcename')
        def should_run_in_executor(ctx):
            ...

    :param arg: a callable to decorate, an :class:`~concurrent.futures.Executor`
        instance, the resource name of one or ``None`` to use the event loop's
        default executor
    :return: the wrapped function
    """
    if not isinstance(arg, str):
        # A callable, an Executor or None: asyncio_extras handles all of these.
        return asyncio_extras.threadpool(arg)

    resource_name = arg

    def outer_wrapper(func: Callable):
        @wraps(func)
        def inner_wrapper(*args, **kwargs):
            try:
                ctx = next(candidate for candidate in args[:2]
                           if isinstance(candidate, Context))
            except StopIteration:
                raise RuntimeError('the first positional argument to {}() has to be a Context '
                                   'instance'.format(callable_name(func))) from None

            target_executor = ctx.require_resource(Executor, resource_name)
            return asyncio_extras.call_in_executor(func, *args, executor=target_executor,
                                                   **kwargs)

        return inner_wrapper

    return outer_wrapper