class Database(object):
    """Owns the Redis and MySQL connections used by the application.

    Constructing an instance schedules :meth:`create` on the given event
    loop; the ``redis`` and ``mysql`` attributes only exist once that task
    has run.
    """

    def __init__(self, redis_url, mysql_url, loop):
        """Store the connection settings and schedule the connection setup.

        :param redis_url: URL handed to ``parse_redis_url``.
        :param mysql_url: URL handed to ``DatabaseInterface``.
        :param loop: Event loop on which :meth:`create` is scheduled.
        """
        self.loop = loop
        self.redis_url = redis_url
        self.mysql_url = mysql_url

        # Resolve the address *before* scheduling create(): create() reads
        # it, and a bad URL should fail here without leaving a task behind.
        self.redis_address = parse_redis_url(redis_url)

        # Keep a strong reference to the task so it cannot be garbage
        # collected before it finishes (the event loop only holds tasks
        # weakly).
        self._create_task = self.loop.create_task(self.create())

    async def create(self):
        """Connect to Redis and MySQL, then bind the ORM tables to MySQL."""
        log.debug('Creating Redis instance')
        self.redis = await aioredis.create_redis(
            self.redis_address,
            encoding='utf8',
            db=0,
        )

        log.debug('Creating MySQL instance')
        self.mysql = DatabaseInterface(self.mysql_url)
        await self.mysql.connect()

        # Imported at call time rather than at module import time.
        # NOTE(review): presumably to avoid a circular import - confirm.
        from silph.models.base import Table

        log.debug('Binding MySQL to ORM')
        self.mysql.bind_tables(Table.metadata)

    async def close(self):
        """Close the MySQL interface, then quit the Redis connection.

        The Redis connection is released even when closing MySQL raises;
        the MySQL error still propagates to the caller.
        """
        try:
            await self.mysql.close()
        finally:
            await self.redis.quit()
async def test_merge(db: DatabaseInterface, table: Table):
    """Insert a row with raw SQL, then merge a model instance with the same id."""
    row_id = 100
    insert_sql = "insert into {} values ({}, 'test', '')".format(
        table.__tablename__, row_id)

    async with db.get_session() as sess:
        await sess.execute(insert_sql)

    # A second session merges an instance that only carries the primary key.
    async with db.get_session() as sess:
        await sess.merge(table(id=row_id))
async def test_update(db: DatabaseInterface, table: Table):
    """Rename every row with id > 10 and check that each one reads back renamed."""
    new_name = "test2"

    async with db.get_session() as sess:
        await sess.update(table).set(table.name, new_name).where(table.id > 10)

    async with db.get_session() as sess:
        query = sess.select(table).where(table.id > 10)
        rows = await query.all()
        # Iterate while the session is still open.
        async for row in rows:
            assert row.name == new_name
async def test_create_table(db: DatabaseInterface):
    """Create the DDL test table and check that selecting from it fetches nothing."""
    columns = (
        Column.with_name("id", Integer(), primary_key=True),
        Column.with_name("name", String(128)),
        Column.with_name("balance", Real()),
    )
    async with db.get_ddl_session() as sess:
        await sess.create_table(table_name, *columns)

    select_sql = "select * from {}".format(table_name)
    async with db.get_session() as sess:
        row = await sess.fetch(select_sql)
        assert row is None
async def test_alter_column_type(db: DatabaseInterface):
    """Change ``age`` to a Real column and check a fractional value round-trips."""
    async with db.get_ddl_session() as sess:
        await sess.alter_column_type(table_name, "age", Real())

    insert_sql = "insert into {} values (1, 'Drizzt Do''Urden', 1.5)".format(
        table_name)
    select_sql = "select age from {}".format(table_name)

    async with db.get_session() as sess:
        await sess.execute(insert_sql)
        result = await sess.fetch(select_sql)

    assert result['age'] == 1.5
async def test_upsert(db: DatabaseInterface, table: Table):
    """Upsert a row with id 1, updating only ``name`` on conflict.

    The inserted instance carries a new ``name`` and a new ``email``, but the
    conflict clause only updates ``name``. The row read back must therefore
    have the new name and must not have the new email.
    """
    async with db.get_session() as sess:
        query = sess.insert.rows(table(id=1, name="upsert", email="notupdated"))
        query = query.on_conflict(table.id).update(table.name)
        await query.run()

    async with db.get_session() as sess:
        res = await sess.select(table).where(table.id == 1).first()

    # Assert on the fetched row. The class-level columns are used to build
    # query conditions when compared (as in the where() call above), so
    # asserting on ``table.name`` / ``table.email`` would not check the
    # stored values at all.
    assert res.name == "upsert"
    assert res.email != "notupdated"
async def test_rollback(db: DatabaseInterface, table: Table):
    """Delete every row, roll back, and check that the rows are still there."""
    # Use the table's own name for both statements so the count query
    # always looks at the table the delete targeted.
    tablename = table.__tablename__

    sess = db.get_session()
    try:
        await sess.start()
        await sess.execute('delete from {} where 1=1'.format(tablename))
        await sess.rollback()
    finally:
        # Always release the manually started session.
        await sess.close()

    async with db.get_session() as sess:
        res = await sess.cursor('select count(*) from {}'.format(tablename))
        assert (await res.fetch_row())[0] > 0
async def test_create_unique_index(db: DatabaseInterface):
    """Add a unique index on ``age`` and check that a duplicate age is rejected."""
    # sqlite3 doesn't index primary keys, so it reports one index fewer.
    on_sqlite = isinstance(db.dialect, sqlite3.Sqlite3Dialect)
    num_indexes = 2 if on_sqlite else 3

    async with db.get_ddl_session() as sess:
        await sess.create_index(table_name, "index_age", "age", unique=True)
    assert await get_num_indexes(db) == num_indexes

    # The row id stays a placeholder; the age is the same (10) for both rows.
    fmt = "insert into {} values ({{}}, 'test', 10);".format(table_name)
    first_insert = fmt.format(100)
    duplicate_insert = fmt.format(101)

    async with db.get_session() as sess:
        await sess.execute(first_insert)
        with pytest.raises(DatabaseException):
            await sess.execute(duplicate_insert)
async def create(self):
    """Connect to Redis and MySQL, then bind the ORM table metadata to MySQL."""
    log.debug('Creating Redis instance')
    redis_options = {'encoding': 'utf8', 'db': 0}
    self.redis = await aioredis.create_redis(self.redis_address, **redis_options)

    log.debug('Creating MySQL instance')
    interface = DatabaseInterface(self.mysql_url)
    self.mysql = interface
    await interface.connect()

    # Imported at call time rather than at module import time.
    # NOTE(review): presumably to avoid a circular import - confirm.
    from silph.models.base import Table

    log.debug('Binding MySQL to ORM')
    interface.bind_tables(Table.metadata)
async def test_create_index(db: DatabaseInterface):
    """Create a plain index on ``name`` and check the resulting index count."""
    # sqlite3 doesn't index primary keys, so it reports one index fewer.
    on_sqlite = isinstance(db.dialect, sqlite3.Sqlite3Dialect)
    num_indexes = 1 if on_sqlite else 2

    async with db.get_ddl_session() as sess:
        await sess.create_index(table_name, "index_name", "name")

    assert await get_num_indexes(db) == num_indexes
async def test_transaction_use(db: DatabaseInterface):
    """Run a trivial statement inside a transaction and roll it back."""
    transaction = db.get_transaction()
    await transaction.begin()

    # Nothing is asserted: the statement only has to run without raising.
    await transaction.execute("SELECT 1 + 1;")

    await transaction.rollback()
    await transaction.close()
async def test_insert(db: DatabaseInterface, table: Table):
    """Insert 50 rows whose name and email are built from the shared templates."""
    async with db.get_session() as sess:
        name_template = kwargs["name"]
        email_template = kwargs["email"]
        # One instance per id in 0..49; each template is formatted with the id.
        rows = [
            table(id=idx,
                  name=name_template.format(idx),
                  email=email_template.format(idx))
            for idx in range(50)
        ]
        await sess.insert.rows(*rows)
async def test_upsert_multiple_constriants(db: DatabaseInterface, table: Table):
    """Check ``ON CONFLICT ... DO NOTHING`` against a multi-column unique index.

    A unique index over ``(email, name)`` is created, then two rows with
    different ids (51 and 52) but the same name and email are inserted. The
    second insert conflicts on the index and must be skipped, so no row with
    id 52 may exist afterwards.
    """
    idx_name = "{}_email_idx".format(table.__tablename__)
    async with db.get_ddl_session() as sess:
        await sess.create_index(table.__tablename__, idx_name,
                                "email", "name", unique=True)

    for i in range(51, 53):
        async with db.get_session() as sess:
            # Use the loop id: with a fixed id the second insert would never
            # try to create row 52 and the final assertion would be vacuous.
            query = sess.insert.rows(
                table(id=i, name="test1", email="*****@*****.**"))
            query = query.on_conflict(table.name, table.email).nothing()
            await query.run()

    async with db.get_session() as sess:
        res = await sess.select(table).where(table.id == 52).first()

    assert res is None
async def test_transaction_fetch_one(db: DatabaseInterface):
    """Fetch a single row from a cursor opened on a transaction."""
    transaction = db.get_transaction()
    await transaction.begin()

    result_cursor = await transaction.cursor("SELECT 1 + 1;")
    async with result_cursor:
        first_row = await result_cursor.fetch_row()  # rowdict

    # The only column of the only row holds the sum.
    assert first_row[0] == 2

    await transaction.rollback()
    await transaction.close()
async def test_transaction_fetch_many(db: DatabaseInterface):
    """Fetch two rows at once from a cursor opened on a transaction."""
    transaction = db.get_transaction()
    await transaction.begin()

    result_cursor = await transaction.cursor(
        'SELECT 1 AS result UNION ALL SELECT 2;')
    async with result_cursor:
        rows = await result_cursor.fetch_many(n=2)

    # Rows are expected in the order the query lists them.
    for index, expected in enumerate((1, 2)):
        assert rows[index]["result"] == expected

    await transaction.rollback()
    await transaction.close()
async def test_transaction_fetch_multiple(db: DatabaseInterface):
    """Iterate a transaction cursor and check the results strictly increase."""
    transaction = db.get_transaction()
    await transaction.begin()

    result_cursor = await transaction.cursor(
        'SELECT 1 AS result UNION ALL SELECT 2;')
    last_seen = 0
    async with result_cursor:
        async for row in result_cursor:
            current = row["result"]
            assert current > last_seen
            last_seen = current

    await transaction.rollback()
    await transaction.close()
async def test_transaction_with_error(db: DatabaseInterface):
    """A failed statement must raise, and the transaction must be reusable after."""
    transaction = db.get_transaction()
    await transaction.begin()

    bad_query = "SELECT nonexistant FROM nosuchtable;"
    with pytest.raises(DatabaseException):
        try:
            # This query is meant to fail.
            await transaction.execute(bad_query)
        finally:
            # Clean up before the exception reaches pytest.raises.
            await transaction.rollback()
            await transaction.close(has_error=True)

    # Second round: a valid statement on the same transaction object.
    await transaction.begin()
    await transaction.execute("SELECT 1+1;")
    await transaction.rollback()
    await transaction.close()
async def get_num_columns(db: DatabaseInterface) -> int:
    """Return how many columns the DDL test table currently has."""
    async with db.get_ddl_session() as sess:
        columns = await sess.get_columns(table_name)
        # Count by iterating; only iterability of the result is relied on.
        total = sum(1 for _ in columns)
    return total
async def test_add_column(db: DatabaseInterface):
    """Add an integer ``age`` column and check the table then has four columns."""
    age_column = Column.with_name("age", Integer())
    async with db.get_ddl_session() as sess:
        await sess.add_column(table_name, age_column)

    column_count = await get_num_columns(db)
    assert column_count == 4
async def test_select(db: DatabaseInterface, table: Table):
    """Select the row with id 1 and compare it against the shared templates."""
    async with db.get_session() as sess:
        query = sess.select(table).where(table.id == 1)
        res = await query.first()

    for attr, template in kwargs.items():
        # A fresh object() default never equals a string, so a missing
        # attribute fails the comparison instead of raising.
        actual = getattr(res, attr, object())
        assert actual == template.format(res.id)
async def test_drop_column(db: DatabaseInterface):
    """Drop the ``balance`` column and check that three columns remain."""
    async with db.get_ddl_session() as sess:
        await sess.drop_column(table_name, "balance")

    column_count = await get_num_columns(db)
    assert column_count == 3
async def test_fetch(db: DatabaseInterface, table: Table):
    """Fetch one row with raw SQL and compare it against the shared templates."""
    select_sql = 'select * from {}'.format(table.__tablename__)
    async with db.get_session() as sess:
        res = await sess.fetch(select_sql)

    for attr, template in kwargs.items():
        expected = template.format(res["id"])
        assert res[attr] == expected
async def test_acquire_transaction(db: DatabaseInterface):
    """The interface must hand out a ``BaseTransaction`` instance."""
    assert isinstance(db.get_transaction(), BaseTransaction)
async def test_delete(db: DatabaseInterface, table: Table):
    """Delete the row with id 1 and check that it can no longer be selected."""
    async with db.get_session() as sess:
        await sess.delete(table).where(table.id == 1)

    # Look the row up again in a separate session.
    async with db.get_session() as sess:
        lookup = sess.select(table).where(table.id == 1)
        res = await lookup.first()

    assert res is None
""" py.test configuration """
import asyncio
import os

import pytest

from asyncqlio import DatabaseInterface
from asyncqlio.orm.schema.table import table_base, Table
from asyncqlio.orm.schema.column import Column
from asyncqlio.orm.schema.types import Integer, String

# global so it can be accessed in other fixtures
# The DSN is read from the ASQL_DSN environment variable at import time, so
# importing this module raises KeyError when the variable is not set.
iface = DatabaseInterface(dsn=os.environ["ASQL_DSN"])


@pytest.fixture(scope="module")
async def db() -> DatabaseInterface:
    # Module-scoped fixture: connect the shared interface, hand it to the
    # tests, and close it once the module's tests are done.
    await iface.connect()
    yield iface
    await iface.close()


@pytest.fixture(scope="module")
async def table() -> Table:
    # Module-scoped fixture that declares the model used by the ORM tests.
    # NOTE(review): only the class declaration is visible here; the rest of
    # this fixture (whatever it yields or returns) is not shown in this view.
    class Test(table_base()):
        # Integer primary key plus two 64-character string columns.
        id = Column(Integer(), primary_key=True)
        name = Column(String(64))
        email = Column(String(64))
async def test_truncate(db: DatabaseInterface, table: Table):
    """Truncate the table and check that selecting from it finds no row."""
    async with db.get_session() as sess:
        await sess.truncate(table)

    # Read back in a separate session.
    async with db.get_session() as sess:
        first_row = await sess.select(table).first()

    assert first_row is None
async def test_drop_table(db: DatabaseInterface):
    """Drop the DDL test table and check that selecting from it then fails."""
    async with db.get_ddl_session() as sess:
        await sess.drop_table(table_name)

    select_sql = "select * from {}".format(table_name)
    async with db.get_session() as sess:
        # Querying the dropped table is expected to raise.
        with pytest.raises(DatabaseException):
            await sess.execute(select_sql)