async def introduce_tables(pool: asyncpg.pool.Pool, table_collection: list[dict]) -> None:
    """Create every well-formed table described in ``table_collection``.

    Each entry is read as a dict with a ``"name"`` string plus ``"fields"``
    and ``"other_params"`` lists of strings.  Entries that do not have exactly
    that shape are skipped without any message.  A failure while executing a
    generated statement is printed and does not stop the remaining tables.

    The statement is built by string interpolation, so the descriptions must
    come from trusted code.
    """
    separator = "," + chr(10) + " "
    async with pool.acquire() as connection:
        for table in table_collection:
            name = table.get("name")
            fields = table.get("fields")
            other_params = table.get("other_params")

            # Exact type comparisons are kept, so subclasses of str/list are
            # still rejected.
            if type(name) is not str:
                continue
            if type(fields) is not list or type(other_params) is not list:
                continue
            definitions = fields + other_params
            if any(type(definition) is not str for definition in definitions):
                continue

            statement = f"""CREATE TABLE IF NOT EXISTS {name} ( {separator.join(definitions)} ) """
            try:
                await connection.execute(statement)
            except Exception:
                print(
                    f"Something went wrong creating the table {name} with \n{statement}"
                )
async def update_cache(url: str, result: str, pool: asyncpg.pool.Pool):
    """Overwrite the cached check result stored for ``url``.

    Sets ``data`` to ``result`` and ``created_at`` to ``NOW()`` on the
    ``past_checks`` rows whose ``url`` matches.  This is a plain UPDATE, so
    nothing is written when no such row exists.
    """
    statement = " UPDATE past_checks SET data = $1, created_at = NOW() WHERE url = $2 "
    async with pool.acquire() as conn:
        await conn.execute(statement, result, url)
async def any_banned_systems(pool: asyncpg.pool.Pool, sid: int) -> bool:
    """Report whether ``banned_systems`` holds at least one row for server ``sid``."""
    query = "select exists(select 1 from banned_systems where server_id = $1)"
    async with pool.acquire() as conn:
        has_banned = await conn.fetchval(query, sid)
    return has_banned
async def remove_banned_system(pool: asyncpg.pool.Pool, sid: int, system_id: str):
    """Delete every ``banned_systems`` row for ``system_id`` on server ``sid``.

    The table stores one row per (server, user, system), so this removes all
    user entries recorded for that system on that server.
    """
    statement = "DELETE FROM banned_systems WHERE server_id = $1 AND system_id = $2"
    async with pool.acquire() as conn:
        await conn.execute(statement, sid, system_id)
async def upgrade(self, pool: asyncpg.pool.Pool) -> None:
    """Bring the database schema up to the latest registered version.

    Ensures the single-column ``version`` table exists, reads the stored
    version (a missing row counts as 0), runs each pending upgrade in order
    and finally stores the new version.

    ``self.upgrades[n]`` is the coroutine that migrates the schema *to*
    version ``n`` (index 0 is never invoked), so the newest version this code
    knows about is ``len(self.upgrades) - 1``.

    Raises:
        UnsupportedDatabaseVersion: the stored version is newer than the
            latest known one and ``self.allow_unsupported`` is false.  When
            it is true, a warning is logged and nothing is changed.
    """
    async with pool.acquire() as conn:
        await conn.execute("""CREATE TABLE IF NOT EXISTS version ( version INTEGER PRIMARY KEY )""")
        row: asyncpg.Record = await conn.fetchrow("SELECT version FROM version LIMIT 1")
        version = row["version"] if row else 0

        # Compare against the highest reachable version, not the list length;
        # clamp at 0 so an empty upgrade list still means "nothing to do".
        latest = max(len(self.upgrades) - 1, 0)

        if version > latest:
            error = (f"Unsupported database version v{version} "
                     f"(latest known is v{latest})")
            if not self.allow_unsupported:
                raise UnsupportedDatabaseVersion(error)
            self.log.warning(error)
            return
        if version == latest:
            self.log.debug(f"Database at v{version}, not upgrading")
            return

        for new_version in range(version + 1, latest + 1):
            upgrade = self.upgrades[new_version]
            desc = getattr(upgrade, "__mau_db_upgrade_description__", None)
            suffix = f": {desc}" if desc else ""
            self.log.debug(f"Upgrading database from v{version} to v{new_version}{suffix}")
            await upgrade(conn)
            version = new_version

        # Replace the stored version atomically so the table never ends up
        # empty or holding two rows.
        async with conn.transaction():
            self.log.debug(f"Saving current version (v{version}) to database")
            await conn.execute("DELETE FROM version")
            await conn.execute("INSERT INTO version (version) VALUES ($1)", version)
async def insert_migrate_column_if_not_exists(
        pool: asyncpg.pool.Pool,
        table_name: str,
        column_name: str,
        column_params: str = None,
        rename_from: str = None,
        create_if_not_exists: bool = True) -> None:
    """Make sure ``table_name`` has a column called ``column_name``.

    If the column is missing it is either renamed from ``rename_from`` (when
    that is a string) or, when ``create_if_not_exists`` is true, added with
    the type/constraints given in ``column_params``.  Any error is printed
    rather than raised, so a failed migration never aborts the caller.

    Args:
        pool: Connection pool to borrow a connection from.
        table_name: Table to inspect and alter.
        column_name: Column that must exist afterwards.
        column_params: SQL fragment following the column name in
            ``ALTER TABLE ... ADD`` (for example ``"INTEGER DEFAULT 0"``).
            Required when a column has to be added.
        rename_from: Existing column to rename instead of adding a new one.
        create_if_not_exists: Whether to add the column when no rename applies.

    The ALTER statements interpolate identifiers, which cannot be bound as
    query parameters; pass only trusted names.
    """
    try:
        async with pool.acquire() as connection:
            # The lookup compares a value, so bind it instead of formatting
            # it into the SQL text.
            rows = await connection.fetch(
                "SELECT column_name FROM information_schema.columns WHERE table_name=$1",
                table_name)
            available_columns = {row["column_name"] for row in rows}
            if column_name in available_columns:
                return

            if isinstance(rename_from, str):
                print(
                    f"Renaming column {rename_from} to {column_name} in {table_name}"
                )
                await connection.execute(
                    f"ALTER TABLE {table_name} RENAME COLUMN {rename_from} TO {column_name}"
                )
            elif create_if_not_exists:
                if column_params is None:
                    # Previously this sent "ADD <column> None" to the server.
                    raise ValueError(
                        "column_params is required to add a new column")
                print(f"Adding {column_name} to {table_name}")
                await connection.execute(
                    f"ALTER TABLE {table_name} ADD {column_name} {column_params}"
                )
    except Exception as e:
        print(
            f"{e}\nThe above occurred when trying to insert {column_name} into {table_name}"
        )
async def get_guild_settings(pool: asyncpg.pool.Pool, guild_id: int) -> Optional[GuildDBSettings]:
    """Load the ``guild_settings`` row for ``guild_id``.

    Returns a ``GuildDBSettings`` built from the row's columns as keyword
    arguments, or ``None`` when no row matches.
    """
    query = " SELECT * from guild_settings WHERE guild_id = $1"
    async with pool.acquire() as conn:
        conn: asyncpg.connection.Connection
        record = await conn.fetchrow(query, guild_id)
    if record is None:
        return None
    return GuildDBSettings(**record)
async def update_interview_type_msg_id(pool: asyncpg.pool.Pool, cid: int, mid: int, interview_type_msg_id: int):
    """Set ``interview_type_msg_id`` on the interview matching channel ``cid`` and member ``mid``."""
    statement = (
        "UPDATE interviews SET interview_type_msg_id = $1 "
        "WHERE channel_id = $2 AND member_id = $3"
    )
    async with pool.acquire() as conn:
        conn: asyncpg.connection.Connection
        await conn.execute(statement, interview_type_msg_id, cid, mid)
async def put(pool: asyncpg.pool.Pool, fid: int, content: bytes):
    """Store ``content`` under id ``fid`` in ``story``, replacing any existing content."""
    upsert = (
        " insert into story (id, content) values ($1, $2) "
        "ON CONFLICT (id) DO UPDATE SET content = excluded.content "
    )
    async with pool.acquire() as conn:
        await conn.execute(upsert, fid, content)
async def update_interview_question_number(pool: asyncpg.pool.Pool, cid: int, mid: int, question_number: int):
    """Set ``question_number`` on the interview matching channel ``cid`` and member ``mid``."""
    statement = (
        "UPDATE interviews SET question_number = $1 "
        "WHERE channel_id = $2 AND member_id = $3"
    )
    async with pool.acquire() as conn:
        conn: asyncpg.connection.Connection
        await conn.execute(statement, question_number, cid, mid)
async def delete_webrisk(pool: asyncpg.pool.Pool, url: str):
    """Delete the ``web_risk`` rows stored for ``url``."""
    statement = " DELETE FROM web_risk WHERE url = $1 "
    async with pool.acquire() as conn:
        await conn.execute(statement, url)
async def delete_member(pool: asyncpg.pool.Pool, pk_sid: str, pk_mid: str):
    """Delete the ``members`` rows identified by ``pk_sid`` and ``pk_mid``."""
    statement = "DELETE FROM members WHERE pk_sid = $1 AND pk_mid = $2"
    async with pool.acquire() as conn:
        conn: asyncpg.connection.Connection
        await conn.execute(statement, pk_sid, pk_mid)
async def insert_hsts(pool: asyncpg.pool.Pool, url: str, status: str):
    """Insert a new ``hsts`` row recording ``status`` for ``url``."""
    statement = " INSERT INTO hsts (url, status) VALUES($1, $2) "
    async with pool.acquire() as conn:
        await conn.execute(statement, url, status)
async def add_banned_system(pool: asyncpg.pool.Pool, sid: int, system_id: str, user_id: int):
    """Record ``system_id`` and its Discord ``user_id`` as banned on server ``sid``.

    Note the column order in the INSERT is (server_id, user_id, system_id),
    which differs from the parameter order of this function.
    """
    statement = "INSERT INTO banned_systems(server_id, user_id, system_id) VALUES($1, $2, $3)"
    async with pool.acquire() as conn:
        await conn.execute(statement, sid, user_id, system_id)
async def add_linked_discord_account(pool: asyncpg.pool.Pool, pk_sid: str, dis_uid: int):
    """Insert an ``accounts`` row linking Discord user ``dis_uid`` to system ``pk_sid``."""
    statement = "INSERT INTO accounts(dis_uid, pk_sid) VALUES($1, $2)"
    async with pool.acquire() as conn:
        conn: asyncpg.connection.Connection
        # Column order is (dis_uid, pk_sid), the reverse of the parameters.
        await conn.execute(statement, dis_uid, pk_sid)
async def update_interview_read_rules(pool: asyncpg.pool.Pool, cid: int, mid: int, read_rules: bool):
    """Set ``read_rules`` on the interview matching channel ``cid`` and member ``mid``."""
    statement = (
        "UPDATE interviews SET read_rules = $1 "
        "WHERE channel_id = $2 AND member_id = $3"
    )
    async with pool.acquire() as conn:
        conn: asyncpg.connection.Connection
        await conn.execute(statement, read_rules, cid, mid)
async def init_db(pool: asyncpg.pool.Pool) -> None:
    """Create the schema and apply any pending version migrations.

    Runs ``create_db``, reads the stored schema version (falling back to
    ``"v0.0.0"`` when the ``version`` table does not exist) and then walks
    the ``queries`` mapping.  Each entry supplies an optional ``"query"`` to
    run and the ``"new_version"`` it leads to; an entry whose
    ``"new_version"`` is ``None`` ends the chain.  When at least one step was
    pending, the ``version`` table is rewritten with the version reached.
    """
    async with pool.acquire() as conn:
        await conn.execute(create_db)

        try:
            rows = await conn.fetch(" SELECT version FROM version; ")
            version = rows[0].get("version")
        except asyncpg.exceptions.UndefinedTableError:
            version = "v0.0.0"
        print(f"Current database version: {version}")

        step = queries[version]
        if step["new_version"] is not None:
            next_version = step["new_version"]
            # The iteration cap bounds the walk even if the chain loops back
            # on itself.
            for _ in range(len(queries) * 2):
                if step["new_version"] is None:
                    break
                if step["query"] is not None:
                    await conn.execute(step["query"])
                next_version = step["new_version"]
                step = queries[next_version]

            await conn.execute("truncate table version;")
            await conn.execute(
                " INSERT INTO version Values(0, $1) ",
                next_version,
            )
            print(f"Database updated to {next_version}")
            await asyncio.sleep(1)
async def run_worker(
    worker_id: int, queue: asyncio.Queue[Response], pool: asyncpg.pool.Pool
) -> None:
    """Consume responses from ``queue`` forever and store each in PostgreSQL.

    Every response becomes one row in the ``monitoring`` table.  Failures are
    logged with the worker id and the item is dropped, so one bad row or a
    database outage never stops the worker.

    Args:
        worker_id: Identifier used only to tag log messages.
        queue: Source of responses; ``task_done()`` is called exactly once
            for every item taken, whether or not the insert succeeded.
        pool: Connection pool a connection is borrowed from per item.
    """
    insert_response = "INSERT INTO monitoring VALUES(DEFAULT, $1, $2, $3, $4, $5, $6)"
    while True:
        response = await queue.get()
        try:
            # Acquiring is inside the try: that is where connection problems
            # actually surface, and an unhandled error here would otherwise
            # kill the worker and leave queue.join() waiting forever.
            async with pool.acquire() as connection:
                await connection.execute(
                    insert_response,
                    response.url,
                    response.load_time,
                    response.status_code,
                    response.ok,
                    response.error,
                    response.request_time,
                )
        except asyncpg.exceptions.PostgresConnectionError:
            logger.error(f"[{worker_id}] cannot connect to postgresql")
        except asyncpg.exceptions.DataError as err:
            logger.error(f"[{worker_id}] invalid postgres query {err}")
        except Exception as err:  # pylint: disable=broad-except
            logger.error(f"[{worker_id}] unexpected error {err}")
        finally:
            queue.task_done()
async def add_new_interview(pool: asyncpg.pool.Pool,
                            sid: int,
                            member_id: int,
                            username: str,
                            channel_id: int,
                            question_number: int = 0,
                            interview_finished: bool = False,
                            paused: bool = False,
                            interview_type: str = 'unknown',
                            read_rules: bool = False,
                            join_ts: datetime = None,
                            interview_type_msg_id=None):
    """Insert a new row into ``interviews``.

    ``join_ts`` is stored as the float returned by ``datetime.timestamp()``;
    when it is omitted, ``datetime.utcnow()`` is used as the join time.
    All other arguments are written to the columns of the same meaning, with
    ``sid`` going to ``guild_id`` and ``username`` to ``user_name``.
    """
    statement = (
        "INSERT INTO interviews(guild_id, member_id, user_name, channel_id, "
        "question_number, interview_finished, paused, interview_type, "
        "read_rules, join_ts, interview_type_msg_id) "
        "VALUES($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)"
    )
    async with pool.acquire() as conn:
        conn: asyncpg.connection.Connection
        # Reduce the join time to a numeric timestamp for storage.
        joined_at = datetime.utcnow() if join_ts is None else join_ts
        ts = joined_at.timestamp()
        await conn.execute(
            statement,
            sid, member_id, username, channel_id,
            question_number, interview_finished, paused, interview_type,
            read_rules, ts, interview_type_msg_id)
async def move_role(pool: asyncpg.pool.Pool, gid: int, role_id: int, new_cat_id: int):
    """Reassign role ``role_id`` of guild ``gid`` to category ``new_cat_id``."""
    statement = "UPDATE allowed_roles SET cat_id = $1 WHERE role_id = $2 AND guild_id = $3"
    async with pool.acquire() as conn:
        conn: asyncpg.connection.Connection
        await conn.execute(statement, new_cat_id, role_id, gid)
async def move_role_cat(pool: asyncpg.pool.Pool, cat_id: int, cat_pos: int):
    """Set ``cat_position`` of role category ``cat_id`` to ``cat_pos``."""
    statement = " UPDATE role_categories SET cat_position = $1 WHERE cat_id = $2 "
    async with pool.acquire() as conn:
        conn: asyncpg.connection.Connection
        await conn.execute(statement, cat_pos, cat_id)
async def fetch_cached_result(url: str, pool: asyncpg.pool.Pool):
    """Return the ``data`` value cached in ``past_checks`` for ``url``.

    Yields ``None`` when no row matches (``fetchval`` returns ``None`` for an
    empty result).
    """
    query = " SELECT data FROM past_checks WHERE url = $1 "
    async with pool.acquire() as conn:
        cached = await conn.fetchval(query, url)
    return cached
async def get_roles_in_guild(pool: asyncpg.pool.Pool, gid: int) -> List[AllowedRole]:
    """Return every ``allowed_roles`` row of guild ``gid`` as an ``AllowedRole``.

    Each row's columns are passed to ``AllowedRole`` as keyword arguments;
    the list is empty when the guild has no rows.
    """
    query = "SELECT * from allowed_roles where guild_id = $1"
    async with pool.acquire() as conn:
        conn: asyncpg.connection.Connection
        records = await conn.fetch(query, gid)
        roles = [AllowedRole(**record) for record in records]
    return roles
async def rename_role_cat(pool: asyncpg.pool.Pool, cat_id: int, cat_name: str):
    """Set ``cat_name`` of role category ``cat_id`` to ``cat_name``."""
    statement = " UPDATE role_categories SET cat_name = $1 WHERE cat_id = $2 "
    async with pool.acquire() as conn:
        conn: asyncpg.connection.Connection
        await conn.execute(statement, cat_name, cat_id)
async def query_db(cls, db: asyncpg.pool.Pool, query: str) -> list:
    """Run ``query`` and return its rows after ``cls.process_record``.

    Rows are streamed through a server-side cursor, which asyncpg only
    allows inside a transaction, and collected in result order.
    """
    async with db.acquire() as conn:
        async with conn.transaction():
            return [
                cls.process_record(row)
                async for row in conn.cursor(query)
            ]
async def fetch_updated_at(pool: asyncpg.pool.Pool, url: str):
    """Return the ``created_at`` value stored in ``past_checks`` for ``url``.

    Yields ``None`` when no row matches.
    """
    query = " SELECT created_at FROM past_checks WHERE url = $1 "
    async with pool.acquire() as conn:
        stamp = await conn.fetchval(query, url)
    return stamp
async def insert_blank_webrisk(pool: asyncpg.pool.Pool, url: str):
    """Insert a ``web_risk`` row for ``url`` that expires 20 minutes from now.

    The expiry is computed from the naive local ``datetime.datetime.now()``.
    """
    statement = " INSERT INTO web_risk (url, expire_time) VALUES ($1, $2) "
    async with pool.acquire() as conn:
        expires_at = datetime.datetime.now() + datetime.timedelta(minutes=20)
        await conn.execute(statement, url, expires_at)
async def test_select(self, proxy, db_pool: asyncpg.pool.Pool):
    """Insert a proxy, read it back by host and port, then delete the row.

    ``proxy`` arrives as a URL string and is parsed with
    ``Proxy.create_from_url`` before use.
    """
    parsed = Proxy.create_from_url(url=proxy)
    storage = ProxyDb(db_connect=db_pool, table_proxy=proxy_table)
    await storage.insert_proxy(**parsed.as_dict())

    fetched = await storage.select_proxy_pm(host=parsed.host, port=parsed.port)
    assert fetched['port'] == parsed.port and str(fetched['host']) == parsed.host

    # Clean up the inserted row so the test leaves the table unchanged.
    async with db_pool.acquire() as conn:
        cleanup = sqlalchemy.text('delete from proxy where (host = $1 and port = $2)')
        await conn.execute(cleanup, parsed.host, parsed.port)
async def create_table_books(pool: asyncpg.pool.Pool, logger: logging.Logger):
    """Create the table named by ``get_table_name()`` unless it already exists.

    Existence is decided by ``check_table``; the outcome is reported on
    ``logger`` at debug level either way.
    """
    async with pool.acquire() as connection:
        connection: asyncpg.connection.Connection
        already_there = await check_table(get_table_name(), connection)
        if already_there:
            logger.debug(f"SUCCESS table: {get_table_name()}")
            return

        logger.debug(f"CREATE table: {get_table_name()}")
        create_statement = CREATE_TABLE.format(table_name=get_table_name())
        await connection.execute(query=create_statement)
async def update_db_hsts(pool: asyncpg.pool.Pool, url: str, status: str):
    """Set ``status`` and refresh ``updated_at`` on the ``hsts`` rows for ``url``."""
    statement = " UPDATE hsts SET status = $1, updated_at = NOW() WHERE url = $2 "
    async with pool.acquire() as conn:
        await conn.execute(statement, status, url)
async def check_database(pool: asyncpg.pool.Pool):
    """Check the schema version and create the tables when the schema is empty.

    A version of 0 or below is treated as an empty schema: the tables are
    created and the version is set to 1.

    Returns:
        ``True`` on success, ``False`` when PostgreSQL reports insufficient
        privileges (the error is logged).
    """
    log.info("Checking database version...")
    try:
        async with pool.acquire() as con:
            version = await get_version(con)
            if version > 0:
                log.info("\tVersion 1 found.")
            else:
                log.info("Schema is empty, creating tables.")
                await create_database(con)
                await set_version(con, 1)
    except asyncpg.InsufficientPrivilegeError as e:
        log.error(f"PostgreSQL error: {e}")
        return False
    return True
async def drop_tables(pool: asyncpg.pool.Pool):
    """Drop every table and function in the current schema.

    Two anonymous PL/pgSQL blocks do the work: the first drops each table
    listed in ``pg_tables`` for the current schema, the second drops each
    function listed in ``information_schema.routines``.  Both use CASCADE.
    """
    drop_all_tables = (
        " DO $$ DECLARE r RECORD; BEGIN "
        "FOR r IN (SELECT tablename FROM pg_tables WHERE schemaname = current_schema()) LOOP "
        "EXECUTE 'DROP TABLE IF EXISTS ' || quote_ident(r.tablename) || ' CASCADE'; "
        "END LOOP; END $$;"
    )
    drop_all_functions = (
        " DO $$ DECLARE r RECORD; BEGIN "
        "FOR r IN (SELECT routine_name FROM information_schema.routines "
        "WHERE routine_type='FUNCTION' AND specific_schema=current_schema()) LOOP "
        "EXECUTE 'DROP FUNCTION ' || quote_ident(r.routine_name) || ' CASCADE'; "
        "END LOOP; END $$;"
    )
    async with pool.acquire() as con:
        log.debug("Dropping tables")
        for block in (drop_all_tables, drop_all_functions):
            await con.execute(block)
        log.debug("Tables dropped")
async def import_legacy_db(pool: asyncpg.pool.Pool, path):
    """Import a legacy SQLite database file into the PostgreSQL database.

    The file at ``path`` is validated with ``check_sql_database`` and cleaned
    with ``clean_up_old_db``.  Characters, server properties, roles, events
    and ignored channels are then copied over one PostgreSQL connection.
    Problems with the file are logged and the function returns without
    importing anything.

    Args:
        pool: Connection pool for the target PostgreSQL database.
        path: Filesystem path of the SQLite database to import.
    """
    if not os.path.isfile(path):
        log.error("Database file doesn't exist or path is invalid.")
        return

    legacy_conn = sqlite3.connect(path)
    # sqlite3 connections are not closed by garbage-collection-free paths or
    # by their own context manager, so release the file handle explicitly on
    # every exit, including validation failure and import errors.
    try:
        log.info("Checking old database...")
        if not check_sql_database(legacy_conn):
            log.error("Can't import sqlite database.")
            return

        log.info("Importing SQLite rows")
        start = time.time()
        c = legacy_conn.cursor()
        clean_up_old_db(c)

        # Dictionary that maps SQL IDs to their PSQL ID
        new_ids = {}
        async with pool.acquire() as conn:
            await import_characters(conn, c, new_ids)
            await import_server_properties(conn, c)
            await import_roles(conn, c)
            await import_events(conn, c, new_ids)
            await import_ignored_channels(conn, c)
        log.info(f"Importing finished in {time.time()-start:,.2f} seconds.")
    finally:
        legacy_conn.close()