class MSSQLClient(ODBCClient):
    """ODBC client for Microsoft SQL Server.

    Connection handling comes from ``ODBCClient``. This class supplies the SQL
    Server query/schema/executor classes, builds the ODBC connection string and
    reads back the generated identity value after an INSERT.
    """

    query_class = MSSQLQuery
    schema_generator = MSSQLSchemaGenerator
    executor_class = MSSQLExecutor
    # Capability flags explicitly switched off for this dialect.
    capabilities = Capabilities(
        "mssql",
        support_update_limit_order_by=False,
        support_for_update=False,
    )

    def __init__(
        self,
        *,
        user: str,
        password: str,
        host: str,
        port: SupportsInt,
        driver: str,
        **kwargs: Any,
    ) -> None:
        """Store the ODBC connection string for the given server credentials.

        Keyword arguments not consumed here are forwarded to ``ODBCClient``.
        """
        super().__init__(**kwargs)
        connection_string = f"DRIVER={driver};SERVER={host},{port};UID={user};PWD={password};"
        self.dsn = connection_string

    def _in_transaction(self) -> "TransactionContext":
        """Return a pooled transaction context wrapping this client."""
        wrapper = TransactionWrapper(self)
        return TransactionContextPooled(wrapper)

    @translate_exceptions
    async def execute_insert(self, query: str, values: list) -> int:
        """Execute an INSERT and return the identity value it produced.

        The identity is read with ``SELECT @@IDENTITY`` on the same cursor,
        immediately after the insert.
        """
        # NOTE(review): @@IDENTITY is not limited to the current scope, so an
        # identity generated by a trigger would be returned instead — confirm
        # this is acceptable for the schemas in use.
        async with self.acquire_connection() as connection:
            self.log.debug("%s: %s", query, values)
            async with connection.cursor() as cursor:
                await cursor.execute(query, values)
                await cursor.execute("SELECT @@IDENTITY;")
                first_row = await cursor.fetchone()
                return first_row[0]
class OracleClient(ODBCClient):
    """ODBC client for Oracle.

    Here the "database" is an Oracle user (schema): ``db_create`` creates that
    user and ``db_delete`` drops it.
    """

    query_class = OracleQuery
    schema_generator = OracleSchemaGenerator
    executor_class = OracleExecutor
    capabilities = Capabilities(dialect="oracle")

    def __init__(
        self,
        *,
        user: str,
        password: str,
        host: str,
        port: SupportsInt,
        driver: str,
        **kwargs: Any,
    ) -> None:
        """Store the ODBC connection string; extra kwargs go to ``ODBCClient``."""
        super().__init__(**kwargs)
        # Kept on the instance because db_create reuses it in CREATE USER.
        self.password = password
        self.dsn = f"DRIVER={driver};DBQ={host}:{port};UID={user};PWD={password};"

    def _in_transaction(self) -> "TransactionContext":
        """Return a pooled transaction context wrapping this client."""
        return TransactionContextPooled(TransactionWrapper(self))

    def acquire_connection(self) -> Union["ConnectionWrapper", "PoolConnectionWrapper"]:
        """Return an Oracle-specific pooled connection wrapper for this client."""
        return OraclePoolConnectionWrapper(self)

    async def db_create(self) -> None:
        """Create the Oracle user named ``self.database`` and grant it all privileges.

        The temporary server-level connection is always closed, even if a
        statement fails.
        """
        await self.create_connection(with_db=False)
        try:
            await self.execute_script(
                f'CREATE USER "{self.database}" IDENTIFIED BY "{self.password}"'
            )
            await self.execute_script(f'GRANT ALL PRIVILEGES TO "{self.database}"')
        finally:
            await self.close()

    async def db_delete(self) -> None:
        """Drop the Oracle user named ``self.database`` together with its objects.

        A "does not exist" error is ignored. Any other ``pyodbc.Error`` is
        re-raised, and the connection is closed in every case.
        """
        await self.create_connection(with_db=False)
        try:
            await self.execute_script(f'DROP USER "{self.database}" CASCADE')
        except pyodbc.Error as e:
            if "does not exist" not in str(e):
                raise
        finally:
            await self.close()

    @translate_exceptions
    async def execute_script(self, query: str) -> None:
        """Execute each ';'-separated statement of ``query`` one at a time.

        Empty fragments are skipped.
        """
        # NOTE(review): the plain split on ';' also splits inside string
        # literals and PL/SQL blocks — confirm scripts never contain those.
        async with self.acquire_connection() as connection:
            self.log.debug(query)
            async with connection.cursor() as cursor:
                for q in query.split(";"):
                    if not q.strip():
                        continue
                    await cursor.execute(q)

    @translate_exceptions
    async def execute_insert(self, query: str, values: list) -> int:
        """Execute an INSERT. Always returns 0: no generated id is read back here."""
        async with self.acquire_connection() as connection:
            self.log.debug("%s: %s", query, values)
            await connection.execute(query, values)
            return 0
def __init__(self, file_path: str, **kwargs) -> None:
    """Set up a SQLite client for the database file ``file_path``.

    Only attributes are assigned here; ``_connection`` starts as ``None``.
    Remaining keyword arguments are passed to the parent initialiser.
    """
    super().__init__(**kwargs)
    self.filename = file_path
    # type() builds a fresh class on every call, mixing TransactionWrapper
    # into this client's concrete class.
    wrapper_bases = (TransactionWrapper, self.__class__)
    self._transaction_class = type('TransactionWrapper', wrapper_bases, {})
    self._connection = None  # type: Optional[aiosqlite.Connection]
    self.capabilities = Capabilities('sqlite', connection={'file': file_path})
def __init__(self, user: str, password: str, database: str, host: str,
             port: SupportsInt, **kwargs) -> None:
    """Store MySQL connection parameters without connecting.

    ``port`` may be any ``SupportsInt``. It is normalised to ``int`` once, and
    that same value is used both on the instance and in the advertised
    capabilities, so the two cannot disagree.
    Extra keyword arguments are passed to the parent initialiser.
    """
    super().__init__(**kwargs)
    self.user = user
    self.password = password
    self.database = database
    self.host = host
    self.port = int(port)  # make sure port is int type
    self._connection = None  # type: Optional[aiomysql.Connection]
    self._transaction_class = type('TransactionWrapper', (TransactionWrapper, self.__class__), {})
    self.capabilities = Capabilities('mysql', connection={
        'user': user,
        'database': database,
        'host': host,
        # Use the normalised int, not the raw argument (which may be e.g. "3306").
        'port': self.port,
    })
class MySQLClient(BaseDBAsyncClient):
    """aiomysql pool-backed client for MySQL."""

    query_class = MySQLQuery
    executor_class = MySQLExecutor
    schema_generator = MySQLSchemaGenerator
    capabilities = Capabilities("mysql", requires_limit=True, inline_comment=True)

    def __init__(
        self,
        *,
        user: str,
        password: str,
        database: str,
        host: str,
        port: SupportsInt,
        **kwargs: Any,
    ) -> None:
        """Store connection settings; no connection is opened here.

        Keyword arguments not consumed below are kept in ``self.extra`` and
        passed on to ``aiomysql.create_pool``.
        """
        super().__init__(**kwargs)
        self.user = user
        self.password = password
        self.database = database
        self.host = host
        self.port = int(port)  # make sure port is int type
        self.extra = kwargs.copy()
        self.storage_engine = self.extra.pop("storage_engine", "")
        # Keys the client consumes itself or always sets explicitly.
        self.extra.pop("connection_name", None)
        self.extra.pop("fetch_inserted", None)
        self.extra.pop("db", None)
        self.extra.pop("autocommit", None)
        self.extra.setdefault("sql_mode", "STRICT_TRANS_TABLES")
        self.charset = self.extra.pop("charset", "utf8mb4")
        self.pool_minsize = int(self.extra.pop("minsize", 1))
        self.pool_maxsize = int(self.extra.pop("maxsize", 5))

        self._template: dict = {}
        self._pool: Optional[aiomysql.Pool] = None
        self._connection = None

    @staticmethod
    def _session_time_zone() -> str:
        """Return the current UTC offset of ``timezone.now()`` as a MySQL time_zone value.

        The format is ``[+-]H:MM`` (e.g. ``+5:30``, ``-8:00``, ``+0:00``).
        ``total_seconds()`` is used because ``timedelta.seconds`` is always
        non-negative: UTC-5 would otherwise become ``+19:00``. A missing
        (naive) offset is treated as UTC.
        """
        offset = timezone.now().utcoffset()
        total_minutes = int(offset.total_seconds() / 60) if offset is not None else 0
        sign = "-" if total_minutes < 0 else "+"
        hours, minutes = divmod(abs(total_minutes), 60)
        return f"{sign}{hours}:{minutes:02d}"

    async def create_connection(self, with_db: bool) -> None:
        """Create the connection pool and apply session settings.

        :param with_db: connect to ``self.database`` if True, else to no database.
        :raises DBConnectionError: on an unknown charset or when MySQL refuses
            the connection.
        """
        if charset_by_name(self.charset) is None:  # type: ignore
            raise DBConnectionError(f"Unknown charset {self.charset}")
        self._template = {
            "host": self.host,
            "port": self.port,
            "user": self.user,
            "db": self.database if with_db else None,
            "autocommit": True,
            "charset": self.charset,
            "minsize": self.pool_minsize,
            "maxsize": self.pool_maxsize,
            **self.extra,
        }
        try:
            # The password goes in separately so it never appears in _template (which is logged).
            self._pool = await aiomysql.create_pool(password=self.password, **self._template)

            if isinstance(self._pool, aiomysql.Pool):
                # NOTE(review): these SET statements run on a single pooled
                # connection only; other connections in the pool do not get them.
                async with self.acquire_connection() as connection:
                    async with connection.cursor() as cursor:
                        if self.storage_engine:
                            await cursor.execute(
                                f"SET default_storage_engine='{self.storage_engine}';"
                            )
                            if self.storage_engine.lower() != "innodb":  # pragma: nobranch
                                # NOTE(review): ``capabilities`` is a class attribute, so
                                # this change is seen by every MySQLClient instance.
                                self.capabilities.__dict__["supports_transactions"] = False
                        tz = self._session_time_zone()
                        await cursor.execute(f"SET SESSION time_zone='{tz}';")
            self.log.debug("Created connection %s pool with params: %s", self._pool, self._template)
        except pymysql.err.OperationalError as exc:
            raise DBConnectionError(f"Can't connect to MySQL server: {self._template}") from exc

    async def _expire_connections(self) -> None:
        """Force every idle pooled connection to fail on its next read."""
        if self._pool:  # pragma: nobranch
            for conn in self._pool._free:
                conn._reader.set_exception(EOFError("EOF"))

    async def _close(self) -> None:
        """Close the pool, if any, and wait for it to shut down."""
        if self._pool:  # pragma: nobranch
            self._pool.close()
            await self._pool.wait_closed()
            # Log the pool itself; self._connection is never assigned in this class.
            self.log.debug("Closed connection %s with params: %s", self._pool, self._template)
            self._pool = None

    async def close(self) -> None:
        """Close the pool and forget the connection template."""
        await self._close()
        self._template.clear()

    async def db_create(self) -> None:
        """Create ``self.database`` using a server-level connection."""
        await self.create_connection(with_db=False)
        await self.execute_script(f"CREATE DATABASE {self.database}")
        await self.close()

    async def db_delete(self) -> None:
        """Drop ``self.database``, ignoring ``DatabaseError`` (e.g. it does not exist)."""
        await self.create_connection(with_db=False)
        try:
            await self.execute_script(f"DROP DATABASE {self.database}")
        except pymysql.err.DatabaseError:  # pragma: nocoverage
            pass
        await self.close()

    def acquire_connection(self) -> Union["ConnectionWrapper", "PoolConnectionWrapper"]:
        """Return a wrapper that takes a connection from the pool."""
        return PoolConnectionWrapper(self._pool)

    def _in_transaction(self) -> "TransactionContext":
        """Return a pooled transaction context wrapping this client."""
        return TransactionContextPooled(TransactionWrapper(self))

    @translate_exceptions
    async def execute_insert(self, query: str, values: list) -> int:
        """Execute an INSERT and return ``cursor.lastrowid``."""
        async with self.acquire_connection() as connection:
            self.log.debug("%s: %s", query, values)
            async with connection.cursor() as cursor:
                await cursor.execute(query, values)
                return cursor.lastrowid  # return auto-generated id

    @translate_exceptions
    async def execute_many(self, query: str, values: list) -> None:
        """Run ``executemany``, in an explicit transaction when transactions are supported."""
        async with self.acquire_connection() as connection:
            self.log.debug("%s: %s", query, values)
            async with connection.cursor() as cursor:
                if self.capabilities.supports_transactions:
                    await connection.begin()
                    try:
                        await cursor.executemany(query, values)
                    except Exception:
                        await connection.rollback()
                        raise
                    else:
                        await connection.commit()
                else:
                    await cursor.executemany(query, values)

    @translate_exceptions
    async def execute_query(
        self, query: str, values: Optional[list] = None
    ) -> Tuple[int, List[dict]]:
        """Execute ``query``; return ``(rowcount, rows as column->value dicts)``."""
        async with self.acquire_connection() as connection:
            self.log.debug("%s: %s", query, values)
            async with connection.cursor() as cursor:
                await cursor.execute(query, values)
                rows = await cursor.fetchall()
                if rows:
                    fields = [f.name for f in cursor._result.fields]
                    return cursor.rowcount, [dict(zip(fields, row)) for row in rows]
                return cursor.rowcount, []

    async def execute_query_dict(self, query: str, values: Optional[list] = None) -> List[dict]:
        """Execute ``query`` and return only the row dicts."""
        return (await self.execute_query(query, values))[1]

    @translate_exceptions
    async def execute_script(self, query: str) -> None:
        """Execute ``query`` as a single statement on a pooled connection."""
        async with self.acquire_connection() as connection:
            self.log.debug(query)
            async with connection.cursor() as cursor:
                await cursor.execute(query)
class SqliteClient(BaseDBAsyncClient):
    """aiosqlite client using one connection and one ``asyncio.Lock``.

    The same lock is handed to every ``ConnectionWrapper`` and transaction this
    client creates.
    """

    executor_class = SqliteExecutor
    schema_generator = SqliteSchemaGenerator
    capabilities = Capabilities("sqlite", daemon=False, requires_limit=True)

    def __init__(self, file_path: str, **kwargs) -> None:
        """Remember the database file. Other kwargs become ``PRAGMA`` settings."""
        super().__init__(**kwargs)
        self.filename = file_path
        # Same order as kwargs, minus the two client-level options.
        self.pragmas = {
            name: setting
            for name, setting in kwargs.items()
            if name not in ("connection_name", "fetch_inserted")
        }
        self._transaction_class = type(
            "TransactionWrapper", (TransactionWrapper, self.__class__), {}
        )
        self._connection = None  # type: Optional[aiosqlite.Connection]
        self._lock = asyncio.Lock()

    def _pragma_summary(self) -> str:
        """Render the configured pragmas as space-separated ``key=value`` pairs."""
        return " ".join(["{}={}".format(k, v) for k, v in self.pragmas.items()])

    async def create_connection(self, with_db: bool) -> None:
        """Open the connection (if not already open) and apply every pragma.

        ``with_db`` is accepted for interface compatibility and not used.
        """
        if self._connection:  # pragma: no branch
            return
        self._connection = aiosqlite.connect(self.filename, isolation_level=None)
        self._connection.start()
        await self._connection._connect()
        self._connection._conn.row_factory = sqlite3.Row
        for name, setting in self.pragmas.items():
            pragma_cursor = await self._connection.execute("PRAGMA {}={}".format(name, setting))
            await pragma_cursor.close()
        self.log.debug(
            "Created connection %s with params: filename=%s %s",
            self._connection,
            self.filename,
            self._pragma_summary(),
        )

    async def close(self) -> None:
        """Close the connection if one is open; otherwise do nothing."""
        if not self._connection:
            return
        await self._connection.close()
        self.log.debug(
            "Closed connection %s with params: filename=%s %s",
            self._connection,
            self.filename,
            self._pragma_summary(),
        )
        self._connection = None

    async def db_create(self) -> None:
        """No-op: opening the connection is enough to create the file."""
        pass

    async def db_delete(self) -> None:
        """Close the connection and remove the database file if it exists."""
        await self.close()
        try:
            os.remove(self.filename)
        except FileNotFoundError:  # pragma: nocoverage
            pass

    def acquire_connection(self) -> ConnectionWrapper:
        """Wrap the shared connection together with the shared lock."""
        return ConnectionWrapper(self._connection, self._lock)

    def _in_transaction(self) -> "TransactionWrapper":
        """Build a transaction wrapper sharing this client's connection and lock."""
        return self._transaction_class(
            connection_name=self.connection_name,
            connection=self._connection,
            lock=self._lock,
            fetch_inserted=self.fetch_inserted,
        )

    @translate_exceptions
    async def execute_insert(self, query: str, values: list) -> int:
        """Execute an INSERT and return the first column of the row aiosqlite reports."""
        async with self.acquire_connection() as connection:
            self.log.debug("%s: %s", query, values)
            inserted = await connection.execute_insert(query, values)
            return inserted[0]

    @translate_exceptions
    async def execute_many(self, query: str, values: List[list]) -> None:
        """Run ``executemany`` with each parameter list in ``values``."""
        async with self.acquire_connection() as connection:
            self.log.debug("%s: %s", query, values)
            # TODO: Ensure that this is wrapped by a transaction, will provide a big speedup
            await connection.executemany(query, values)

    @translate_exceptions
    async def execute_query(self, query: str) -> List[dict]:
        """Execute ``query`` and return every row converted to a dict."""
        async with self.acquire_connection() as connection:
            self.log.debug(query)
            fetched = await connection.execute_fetchall(query)
            rows_as_dicts = [dict(row) for row in fetched]
        return rows_as_dicts

    @translate_exceptions
    async def execute_script(self, query: str) -> None:
        """Run ``query`` via ``executescript`` (multiple statements allowed)."""
        async with self.acquire_connection() as connection:
            self.log.debug(query)
            await connection.executescript(query)
class MySQLClient(BaseDBAsyncClient):
    """aiomysql pool-backed MySQL client (variant with ``_copy`` and ``filter_class``)."""

    query_class = MySQLQuery
    filter_class = MySQLFilter
    executor_class = MySQLExecutor
    schema_generator = MySQLSchemaGenerator
    capabilities = Capabilities("mysql", requires_limit=True, inline_comment=True)

    def __init__(self, *, user: str, password: str, database: str, host: str,
                 port: SupportsInt, **kwargs) -> None:
        """Store connection settings; keyword extras are kept for ``create_pool``."""
        super().__init__(**kwargs)
        self.user = user
        self.password = password
        self.database = database
        self.host = host
        self.port = int(port)  # make sure port is int type

        options = kwargs.copy()
        self.storage_engine = options.pop("storage_engine", "")
        for reserved in ("connection_name", "db", "autocommit"):
            options.pop(reserved, None)
        options.setdefault("sql_mode", "STRICT_TRANS_TABLES")
        self.charset = options.pop("charset", "utf8mb4")
        self.pool_minsize = int(options.pop("minsize", 1))
        self.pool_maxsize = int(options.pop("maxsize", 5))
        self.extra = options
        self._pool: Optional[aiomysql.Pool] = None

    def _copy(self, base) -> None:
        """Copy connection settings and the (shared) pool from ``base``."""
        super()._copy(base)
        for attr in (
            "user",
            "password",
            "database",
            "host",
            "port",
            "extra",
            "storage_engine",
            "charset",
            "pool_minsize",
            "pool_maxsize",
            "_pool",
        ):
            setattr(self, attr, getattr(base, attr))

    async def create_connection(self, with_db: bool) -> None:
        """Create the aiomysql pool and apply the configured storage engine.

        :raises DBConnectionError: on an unknown charset or a refused connection.
        """
        if charset_by_name(self.charset) is None:  # type: ignore
            raise DBConnectionError(f"Unknown charset {self.charset}")
        pool_template = {
            "host": self.host,
            "port": self.port,
            "user": self.user,
            "db": self.database if with_db else None,
            "autocommit": True,
            "charset": self.charset,
            "minsize": self.pool_minsize,
            "maxsize": self.pool_maxsize,
            **self.extra,
        }
        try:
            self._pool = await aiomysql.create_pool(password=self.password, **pool_template)
            if isinstance(self._pool, aiomysql.Pool):
                async with self.acquire_connection() as connection:
                    async with connection.cursor() as cursor:
                        if self.storage_engine:
                            await cursor.execute(
                                f"SET default_storage_engine='{self.storage_engine}';"
                            )
                            uses_innodb = self.storage_engine.lower() == "innodb"
                            if not uses_innodb:  # pragma: nobranch
                                # NOTE(review): mutates the class-level Capabilities
                                # object, shared by all instances.
                                self.capabilities.__dict__["supports_transactions"] = False
            self.log.debug("Created connection %s pool with params: %s", self._pool, pool_template)
        except pymysql.err.OperationalError:
            raise DBConnectionError(
                f"Can't connect to MySQL server: {pool_template}")

    async def _expire_connections(self) -> None:
        """Make every idle pooled connection fail with EOF on its next read."""
        if not self._pool:  # pragma: nobranch
            return
        for idle in self._pool._free:
            idle._reader.set_exception(EOFError("EOF"))

    async def close(self) -> None:
        """Close the pool, if one exists, and wait until it has shut down."""
        if not self._pool:  # pragma: nobranch
            return
        self._pool.close()
        await self._pool.wait_closed()
        self.log.debug("Closed connection pool %s", self._pool)
        self._pool = None

    async def db_create(self) -> None:
        """Create ``self.database`` using a server-level connection."""
        await self.create_connection(with_db=False)
        await self.execute_script(f"CREATE DATABASE {self.database}")
        await self.close()

    async def db_delete(self) -> None:
        """Drop ``self.database``; a ``DatabaseError`` from the DROP is ignored."""
        await self.create_connection(with_db=False)
        try:
            await self.execute_script(f"DROP DATABASE {self.database}")
        except pymysql.err.DatabaseError:  # pragma: nocoverage
            pass
        await self.close()

    def acquire_connection(self) -> ConnectionWrapper:
        """Return the pool's ``acquire()`` context manager."""
        return self._pool.acquire()

    def in_transaction(self) -> "TransactionContext":
        """Return a lock-based transaction context wrapping this client."""
        return LockTransactionContext(TransactionWrapper(self))

    @translate_mysql_exceptions
    async def execute_insert(self, query: str, values: list) -> int:
        """Execute an INSERT and return the cursor's ``lastrowid``."""
        async with self.acquire_connection() as connection:
            self.log.debug("%s: %s", query, values)
            async with connection.cursor() as cursor:
                await cursor.execute(query, values)
                return cursor.lastrowid  # return auto-generated id

    @translate_mysql_exceptions
    async def execute_many(self, query: str, values: list) -> None:
        """Run ``executemany``; wrap it in begin/commit when transactions are supported."""
        async with self.acquire_connection() as connection:
            self.log.debug("%s: %s", query, values)
            async with connection.cursor() as cursor:
                if not self.capabilities.supports_transactions:
                    await cursor.executemany(query, values)
                    return
                await connection.begin()
                try:
                    await cursor.executemany(query, values)
                except Exception:
                    await connection.rollback()
                    raise
                await connection.commit()

    @translate_mysql_exceptions
    async def execute_query(
        self, query: str, values: Optional[list] = None
    ) -> Tuple[int, List[str], Sequence[Sequence[Any]]]:
        """Execute ``query``; return ``(rowcount, column names, raw rows)``.

        The column list is empty when no rows were returned.
        """
        async with self.acquire_connection() as connection:
            self.log.debug("%s: %s", query, values)
            async with connection.cursor() as cursor:
                await cursor.execute(query, values)
                rows = await cursor.fetchall()
                column_names = [f.name for f in cursor._result.fields] if rows else []
                return cursor.rowcount, column_names, rows

    @translate_mysql_exceptions
    async def execute_script(self, query: str) -> None:
        """Execute ``query`` as one statement on a pooled connection."""
        async with self.acquire_connection() as connection:
            self.log.debug(query)
            async with connection.cursor() as cursor:
                await cursor.execute(query)
class AsyncpgDBClient(BaseDBAsyncClient):
    """asyncpg pool-backed PostgreSQL client.

    When ``current_transaction`` holds a wrapper, statements run on that
    wrapper's connection instead of a fresh pooled one.
    """

    DSN_TEMPLATE = "postgres://{user}:{password}@{host}:{port}/{database}"
    query_class = PostgreSQLQuery
    executor_class = AsyncpgExecutor
    schema_generator = AsyncpgSchemaGenerator
    capabilities = Capabilities("postgres", pooling=True)

    def __init__(self, user: str, password: str, database: str, host: str,
                 port: SupportsInt, min_size: SupportsInt = 10,
                 max_size: SupportsInt = 10000,
                 max_inactive_connection_lifetime=0, **kwargs) -> None:
        """Store connection and pool-size settings; no connection is opened here."""
        super().__init__(**kwargs)
        self.user = user
        self.password = password
        self.database = database
        self.host = host
        self.port = int(port)  # make sure port is int type
        # Remaining kwargs are passed to asyncpg.create_pool, minus client-level keys.
        self.extra = kwargs.copy()
        self.extra.pop("connection_name", None)
        self.extra.pop("fetch_inserted", None)
        self.extra.pop("loop", None)
        self.extra.pop("connection_class", None)
        self._pool = None  # type: Optional[asyncpg.pool.Pool]
        self.min_size = int(min_size)
        self.max_size = int(max_size)
        self.max_inactive_connection_lifetime = int(max_inactive_connection_lifetime)
        self._template = {}  # type: dict
        self._transaction_class = type("TransactionWrapper", (TransactionWrapper, self.__class__), {})

    async def create_connection(self, with_db: bool) -> None:
        """Create the pool, targeting ``self.database`` only when ``with_db`` is true.

        :raises DBConnectionError: if the target database does not exist.
        """
        self._template = {
            "host": self.host,
            "port": self.port,
            "user": self.user,
            "database": self.database if with_db else None,
            "min_size": self.min_size,
            "max_size": self.max_size,
            "max_inactive_connection_lifetime": self.max_inactive_connection_lifetime,
            **self.extra,
        }
        try:
            # The password is passed separately so it is not included in the logged _template.
            self._pool = await asyncpg.create_pool(None, password=self.password, **self._template)
            self.log.debug("Created pool %s with params: %s", self._pool, self._template)
        except asyncpg.InvalidCatalogNameError:
            raise DBConnectionError(
                "Can't establish connection to database {}".format(self.database))

    async def _close(self) -> None:
        """Close the pool gracefully within 10 seconds, otherwise terminate it."""
        if self._pool:  # pragma: nobranch
            try:
                await asyncio.wait_for(self._pool.close(), 10)
            except asyncio.TimeoutError:
                # asyncpg's Pool.terminate() is a plain (non-async) method;
                # awaiting its None result would raise TypeError.
                self._pool.terminate()
            self.log.debug("Closed pool %s with params: %s", self._pool, self._template)
        self._template.clear()

    async def close(self) -> None:
        """Close the pool and drop the reference to it."""
        await self._close()
        self._pool = None

    async def db_create(self) -> None:
        """Create ``self.database`` owned by ``self.user``."""
        await self.create_connection(with_db=False)
        await self.execute_script('CREATE DATABASE "{}" OWNER "{}"'.format(
            self.database, self.user))
        await self.close()

    async def db_delete(self) -> None:
        """Drop ``self.database``; a missing database is ignored."""
        await self.create_connection(with_db=False)
        try:
            await self.execute_script('DROP DATABASE "{}"'.format(self.database))
        except asyncpg.InvalidCatalogNameError:  # pragma: nocoverage
            pass
        await self.close()

    def acquire_connection(self) -> "AsyncContextManager":
        """Return a dispatcher that hands out a connection from the pool."""
        return PoolConnectionDispatcher(self._pool)

    def _in_transaction(self) -> "TransactionWrapper":
        """Create a transaction wrapper and publish it via ``current_transaction``."""
        # NOTE(review): the context variable is set here but never reset in
        # this class — confirm the wrapper clears it when the transaction ends.
        transaction_wrapper = self._transaction_class(None, self)
        current_transaction.set(transaction_wrapper)
        return transaction_wrapper

    @translate_exceptions
    @retry_connection
    async def execute_insert(self, query: str, values: list) -> Optional[asyncpg.Record]:
        """Prepare and run an INSERT, returning its first result row (if any)."""
        self.log.debug("%s: %s", query, values)
        transaction_wrapper = current_transaction.get()
        if transaction_wrapper:
            stmt = await transaction_wrapper._connection.prepare(query)
            return await stmt.fetchrow(*values)
        async with self.acquire_connection() as connection:
            # TODO: Cache prepared statement
            stmt = await connection.prepare(query)
            return await stmt.fetchrow(*values)

    @translate_exceptions
    @retry_connection
    async def execute_many(self, query: str, values: list) -> None:
        """Run ``executemany`` exactly once, on the active transaction if there is one."""
        self.log.debug("%s: %s", query, values)
        transaction_wrapper = current_transaction.get()
        # TODO: Consider using copy_records_to_table instead
        if transaction_wrapper:
            await transaction_wrapper._connection.executemany(query, values)
            # Without this return the statement also ran a second time on a pooled connection.
            return
        async with self.acquire_connection() as connection:
            await connection.executemany(query, values)

    @translate_exceptions
    @retry_connection
    async def execute_query(self, query: str) -> List[dict]:
        """Fetch all rows for ``query``, on the active transaction if there is one."""
        self.log.debug(query)
        transaction_wrapper = current_transaction.get()
        if transaction_wrapper:
            return await transaction_wrapper._connection.fetch(query)
        async with self.acquire_connection() as connection:
            return await connection.fetch(query)

    @translate_exceptions
    @retry_connection
    async def execute_script(self, query: str) -> None:
        """Execute ``query`` exactly once, on the active transaction if there is one."""
        self.log.debug(query)
        transaction_wrapper = current_transaction.get()
        if transaction_wrapper:
            await transaction_wrapper._connection.execute(query)
            # Without this return the script also ran a second time outside the transaction.
            return
        async with self.acquire_connection() as connection:
            await connection.execute(query)
class AsyncpgDBClient(BaseDBAsyncClient):
    """asyncpg client that uses one connection built from ``DSN_TEMPLATE``."""

    DSN_TEMPLATE = 'postgres://{user}:{password}@{host}:{port}/{database}'
    query_class = PostgreSQLQuery
    executor_class = AsyncpgExecutor
    schema_generator = AsyncpgSchemaGenerator
    capabilities = Capabilities('postgres')

    def __init__(self, user: str, password: str, database: str, host: str,
                 port: SupportsInt, **kwargs) -> None:
        """Store connection parameters; the connection itself is opened later."""
        super().__init__(**kwargs)
        self.user = user
        self.password = password
        self.database = database
        self.host = host
        self.port = int(port)  # make sure port is int type
        self._connection = None  # type: Optional[asyncpg.Connection]
        self._transaction_class = type('TransactionWrapper', (TransactionWrapper, self.__class__), {})

    async def create_connection(self, with_db: bool) -> None:
        """Open the connection; with ``with_db`` false the DSN names no database.

        :raises DBConnectionError: if the target database does not exist.
        """
        # NOTE(review): credentials are put into the URL without escaping, so a
        # password containing '@', '/' or ':' would probably not parse — confirm.
        target_database = self.database if with_db else ''
        dsn = self.DSN_TEMPLATE.format(
            user=self.user,
            password=self.password,
            host=self.host,
            port=self.port,
            database=target_database,
        )
        try:
            self._connection = await asyncpg.connect(dsn)
            self.log.debug(
                'Created connection %s with params: user=%s database=%s host=%s port=%s',
                self._connection, self.user, self.database, self.host, self.port)
        except asyncpg.InvalidCatalogNameError:
            raise DBConnectionError(
                "Can't establish connection to database {}".format(self.database))

    async def close(self) -> None:
        """Close the connection if one is open."""
        if not self._connection:  # pragma: nobranch
            return
        await self._connection.close()
        self.log.debug(
            'Closed connection %s with params: user=%s database=%s host=%s port=%s',
            self._connection, self.user, self.database, self.host, self.port)
        self._connection = None

    async def db_create(self) -> None:
        """Create ``self.database`` owned by ``self.user``."""
        await self.create_connection(with_db=False)
        create_sql = 'CREATE DATABASE "{}" OWNER "{}"'.format(self.database, self.user)
        await self.execute_script(create_sql)
        await self.close()

    async def db_delete(self) -> None:
        """Drop ``self.database``; a missing database is ignored."""
        await self.create_connection(with_db=False)
        drop_sql = 'DROP DATABASE "{}"'.format(self.database)
        try:
            await self.execute_script(drop_sql)
        except asyncpg.InvalidCatalogNameError:  # pragma: nocoverage
            pass
        await self.close()

    def acquire_connection(self) -> ConnectionWrapper:
        """Wrap the single shared connection."""
        return ConnectionWrapper(self._connection)

    def _in_transaction(self) -> 'TransactionWrapper':
        """Build a transaction wrapper around the current connection."""
        return self._transaction_class(self.connection_name, self._connection)

    @translate_exceptions
    async def execute_insert(self, query: str, values: list) -> int:
        """Prepare and run an INSERT, returning the first column of the first row."""
        async with self.acquire_connection() as connection:
            self.log.debug('%s: %s', query, values)
            # TODO: Cache prepared statement
            prepared = await connection.prepare(query)
            return await prepared.fetchval(*values)

    @translate_exceptions
    async def execute_query(self, query: str) -> List[dict]:
        """Fetch every row produced by ``query``."""
        async with self.acquire_connection() as connection:
            self.log.debug(query)
            return await connection.fetch(query)

    @translate_exceptions
    async def execute_script(self, query: str) -> None:
        """Execute ``query`` without fetching results."""
        async with self.acquire_connection() as connection:
            self.log.debug(query)
            await connection.execute(query)
class SqliteClient(BaseDBAsyncClient):
    """aiosqlite client with one connection guarded by one ``asyncio.Lock``.

    Connection-level behaviour comes from ``PRAGMA`` settings, with defaults for
    journal mode, journal size limit and foreign keys.
    """

    executor_class = SqliteExecutor
    schema_generator = SqliteSchemaGenerator
    capabilities = Capabilities("sqlite", daemon=False, requires_limit=True, inline_comment=True)

    def __init__(self, file_path: str, **kwargs: Any) -> None:
        """Remember the file path and pragma settings; nothing is opened yet."""
        super().__init__(**kwargs)
        self.filename = file_path
        self.pragmas = {
            name: setting
            for name, setting in kwargs.items()
            if name not in ("connection_name", "fetch_inserted")
        }
        # Defaults apply only when the caller did not supply the pragma.
        for name, default in (
            ("journal_mode", "WAL"),
            ("journal_size_limit", 16384),
            ("foreign_keys", "ON"),
        ):
            self.pragmas.setdefault(name, default)
        self._connection: Optional[aiosqlite.Connection] = None
        self._lock = asyncio.Lock()

    def _pragma_summary(self) -> str:
        """Render the pragmas as space-separated ``key=value`` pairs for logging."""
        return " ".join([f"{k}={v}" for k, v in self.pragmas.items()])

    async def create_connection(self, with_db: bool) -> None:
        """Open the connection (once) and apply every pragma.

        ``with_db`` is part of the shared interface and not used here.
        """
        if self._connection:  # pragma: no branch
            return
        self._connection = aiosqlite.connect(self.filename, isolation_level=None)
        self._connection.start()
        await self._connection._connect()
        self._connection._conn.row_factory = sqlite3.Row
        for name, setting in self.pragmas.items():
            pragma_cursor = await self._connection.execute(f"PRAGMA {name}={setting}")
            await pragma_cursor.close()
        self.log.debug(
            "Created connection %s with params: filename=%s %s",
            self._connection,
            self.filename,
            self._pragma_summary(),
        )

    async def close(self) -> None:
        """Close the connection if it is open."""
        if not self._connection:
            return
        await self._connection.close()
        self.log.debug(
            "Closed connection %s with params: filename=%s %s",
            self._connection,
            self.filename,
            self._pragma_summary(),
        )
        self._connection = None

    async def db_create(self) -> None:
        """No-op: the database file is created when the connection opens."""
        pass

    async def db_delete(self) -> None:
        """Close the connection and remove the database file if present."""
        await self.close()
        try:
            os.remove(self.filename)
        except FileNotFoundError:  # pragma: nocoverage
            pass

    def acquire_connection(self) -> ConnectionWrapper:
        """Wrap the shared connection with the shared lock."""
        return ConnectionWrapper(self._connection, self._lock)

    def _in_transaction(self) -> "TransactionContext":
        """Return a transaction context around a wrapper of this client."""
        return TransactionContext(TransactionWrapper(self))

    @translate_exceptions
    async def execute_insert(self, query: str, values: list) -> int:
        """Execute an INSERT and return the first element of aiosqlite's result row."""
        async with self.acquire_connection() as connection:
            self.log.debug("%s: %s", query, values)
            inserted = await connection.execute_insert(query, values)
            return inserted[0]

    @translate_exceptions
    async def execute_many(self, query: str, values: List[list]) -> None:
        """Run ``executemany`` inside an explicit BEGIN ... COMMIT/ROLLBACK."""
        async with self.acquire_connection() as connection:
            self.log.debug("%s: %s", query, values)
            # This code is only ever called in AUTOCOMMIT mode
            await connection.execute("BEGIN")
            try:
                await connection.executemany(query, values)
            except Exception:
                await connection.rollback()
                raise
            await connection.commit()

    @translate_exceptions
    async def execute_query(
        self, query: str, values: Optional[list] = None
    ) -> Tuple[int, Sequence[dict]]:
        """Execute ``query``; return ``(affected-or-returned count, rows)``.

        The count is the change in ``total_changes``, or the number of rows
        fetched when nothing changed.
        """
        async with self.acquire_connection() as connection:
            self.log.debug("%s: %s", query, values)
            changes_before = connection.total_changes
            rows = await connection.execute_fetchall(query, values)
            changed = connection.total_changes - changes_before
            return changed or len(rows), rows

    @translate_exceptions
    async def execute_query_dict(self, query: str, values: Optional[list] = None) -> List[dict]:
        """Execute ``query`` and return each row as a dict."""
        async with self.acquire_connection() as connection:
            self.log.debug("%s: %s", query, values)
            fetched = await connection.execute_fetchall(query, values)
            return [dict(row) for row in fetched]

    @translate_exceptions
    async def execute_script(self, query: str) -> None:
        """Run ``query`` via ``executescript`` (multiple statements allowed)."""
        async with self.acquire_connection() as connection:
            self.log.debug(query)
            await connection.executescript(query)
class BasePostgresClient(BaseDBAsyncClient, abc.ABC):
    """Shared base for PostgreSQL clients.

    Implements option parsing, ``close``, ``db_create``/``db_delete``,
    ``acquire_connection`` and ``execute_script``. Driver-specific subclasses
    provide the connection, pool, transaction and query methods.
    """

    DSN_TEMPLATE = "postgres://{user}:{password}@{host}:{port}/{database}"
    query_class: Type[PostgreSQLQuery] = PostgreSQLQuery
    executor_class: Type[BasePostgresExecutor] = BasePostgresExecutor
    schema_generator: Type[BasePostgresSchemaGenerator] = BasePostgresSchemaGenerator
    capabilities = Capabilities("postgres", support_update_limit_order_by=False)
    connection_class = None
    loop = None
    _pool: Optional[Any] = None
    _connection: Optional[Any] = None

    def __init__(
        self,
        user: Optional[str] = None,
        password: Optional[str] = None,
        database: Optional[str] = None,
        host: Optional[str] = None,
        port: SupportsInt = 5432,
        **kwargs: Any,
    ) -> None:
        """Store connection settings; nothing is opened here.

        Recognised extra kwargs (``server_settings``, ``schema``,
        ``application_name``, ``loop``, ``connection_class``, ``minsize``,
        ``maxsize``) are removed from ``self.extra``. Everything else stays there
        for the driver.
        """
        super().__init__(**kwargs)
        self.user = user
        self.password = password
        self.database = database
        self.host = host
        self.port = int(port)  # make sure port is int type
        self.extra = kwargs.copy()  # we can't deep copy kwargs because of ssl context
        # since server_settings is a dict, we copy it again
        self.server_settings = (self.extra.pop("server_settings", None) or {}).copy()
        self.schema = self.extra.pop("schema", None)
        self.application_name = self.extra.pop("application_name", None)
        self.extra.pop("connection_name", None)
        self.extra.pop("fetch_inserted", None)
        self.loop = self.extra.pop("loop", None)
        self.connection_class = self.extra.pop("connection_class", self.connection_class)
        self.pool_minsize = int(self.extra.pop("minsize", 1))
        self.pool_maxsize = int(self.extra.pop("maxsize", 5))

        self._template: dict = {}
        self._pool = None
        self._connection = None

    @abc.abstractmethod
    async def create_connection(self, with_db: bool) -> None:
        raise NotImplementedError("create_connection is not implemented")

    @abc.abstractmethod
    async def create_pool(self, **kwargs):
        raise NotImplementedError("create_pool is not implemented")

    @abc.abstractmethod
    async def _expire_connections(self) -> None:
        raise NotImplementedError("_expire_connections is not implemented")

    @abc.abstractmethod
    async def _close(self) -> None:
        raise NotImplementedError("_close is not implemented")

    @abc.abstractmethod
    async def _translate_exceptions(self, func, *args, **kwargs) -> Exception:
        raise NotImplementedError("translate_exceptions is not implemented")

    async def close(self) -> None:
        """Close driver resources via ``_close`` and clear the template."""
        await self._close()
        self._template.clear()

    async def db_create(self) -> None:
        """Create ``self.database`` owned by ``self.user``.

        The server-level connection is closed even when CREATE DATABASE fails
        (e.g. the database already exists), matching ``db_delete``.
        """
        await self.create_connection(with_db=False)
        try:
            await self.execute_script(
                f'CREATE DATABASE "{self.database}" OWNER "{self.user}"')
        finally:
            await self.close()

    async def db_delete(self) -> None:
        """Drop ``self.database``; the connection is always closed afterwards."""
        await self.create_connection(with_db=False)
        try:
            await self.execute_script(f'DROP DATABASE "{self.database}"')
        finally:
            await self.close()

    def acquire_connection(self) -> Union["ConnectionWrapper", "PoolConnectionWrapper"]:
        """Return a wrapper that takes a connection from ``self._pool``."""
        return PoolConnectionWrapper(self._pool)

    @abc.abstractmethod
    def _in_transaction(self) -> "TransactionContext":
        raise NotImplementedError("_in_transaction is not implemented")

    @abc.abstractmethod
    async def execute_insert(self, query: str, values: list) -> Optional[Any]:
        raise NotImplementedError("execute_insert is not implemented")

    @abc.abstractmethod
    async def execute_many(self, query: str, values: list) -> None:
        raise NotImplementedError("execute_many is not implemented")

    @abc.abstractmethod
    async def execute_query(
        self, query: str, values: Optional[list] = None
    ) -> Tuple[int, List[dict]]:
        raise NotImplementedError("execute_query is not implemented")

    @abc.abstractmethod
    async def execute_query_dict(self, query: str, values: Optional[list] = None) -> List[dict]:
        raise NotImplementedError("execute_query_dict is not implemented")

    @translate_exceptions
    async def execute_script(self, query: str) -> None:
        """Execute ``query`` on a pooled connection without fetching results."""
        async with self.acquire_connection() as connection:
            self.log.debug(query)
            await connection.execute(query)
class MySQLClient(BaseDBAsyncClient):
    """MySQL client using one aiomysql connection shared behind ``self._lock``."""

    query_class = MySQLQuery
    executor_class = MySQLExecutor
    schema_generator = MySQLSchemaGenerator
    capabilities = Capabilities('mysql', safe_indexes=False, requires_limit=True)

    def __init__(self, user: str, password: str, database: str, host: str,
                 port: SupportsInt, **kwargs) -> None:
        """Store connection parameters; no connection is opened here."""
        super().__init__(**kwargs)
        self.user = user
        self.password = password
        self.database = database
        self.host = host
        self.port = int(port)  # make sure port is int type

        self._connection = None  # aiomysql.Connection once create_connection() ran
        self._lock = asyncio.Lock()

        # Subclass combining TransactionWrapper with this concrete client class.
        self._transaction_class = type(
            'TransactionWrapper', (TransactionWrapper, self.__class__), {}
        )

    async def create_connection(self, with_db: bool) -> None:
        """Open an autocommit connection; *with_db* selects ``self.database``.

        Raises:
            DBConnectionError: if aiomysql raises ``OperationalError``.
        """
        template = {
            'host': self.host,
            'port': self.port,
            'user': self.user,
            'password': self.password,
            'db': self.database if with_db else None,
            'autocommit': True,
        }

        try:
            self._connection = await aiomysql.connect(**template)
        except pymysql.err.OperationalError as exc:
            # Chain the driver error so the root cause stays in the traceback.
            raise DBConnectionError(
                "Can't connect to MySQL server: "
                'user={user} database={database} host={host} port={port}'.format(
                    user=self.user,
                    database=self.database,
                    host=self.host,
                    port=self.port
                )
            ) from exc
        self.log.debug(
            'Created connection %s with params: user=%s database=%s host=%s port=%s',
            self._connection, self.user, self.database, self.host, self.port
        )

    async def close(self) -> None:
        """Close the connection if one is open and drop the reference."""
        if self._connection:  # pragma: nobranch
            self._connection.close()
            self.log.debug(
                'Closed connection %s with params: user=%s database=%s host=%s port=%s',
                self._connection, self.user, self.database, self.host, self.port
            )
            self._connection = None

    async def db_create(self) -> None:
        """Create ``self.database``; the connection is closed even on failure."""
        await self.create_connection(with_db=False)
        try:
            await self.execute_script(
                'CREATE DATABASE {}'.format(self.database)
            )
        finally:
            await self.close()

    async def db_delete(self) -> None:
        """Drop ``self.database``; ``DatabaseError`` (e.g. missing DB) is ignored."""
        await self.create_connection(with_db=False)
        try:
            await self.execute_script('DROP DATABASE {}'.format(self.database))
        except pymysql.err.DatabaseError:  # pragma: nocoverage
            pass
        finally:
            await self.close()

    def acquire_connection(self) -> ConnectionWrapper:
        """Wrap the shared connection together with its lock."""
        return ConnectionWrapper(self._connection, self._lock)

    def _in_transaction(self):
        """Return a transaction wrapper sharing this client's connection and lock."""
        return self._transaction_class(self.connection_name, self._connection, self._lock)

    @translate_exceptions
    async def execute_insert(self, query: str, values: list) -> int:
        """Run an INSERT and return ``cursor.lastrowid``."""
        async with self.acquire_connection() as connection:
            self.log.debug('%s: %s', query, values)
            async with connection.cursor() as cursor:
                # TODO: Use prepared statement, and cache it
                await cursor.execute(query, values)
                return cursor.lastrowid  # return auto-generated id

    @translate_exceptions
    async def execute_query(self, query: str) -> List[aiomysql.DictCursor]:
        """Run *query* with a ``DictCursor`` and return ``fetchall()``."""
        async with self.acquire_connection() as connection:
            self.log.debug(query)
            async with connection.cursor(aiomysql.DictCursor) as cursor:
                await cursor.execute(query)
                return await cursor.fetchall()

    @translate_exceptions
    async def execute_script(self, query: str) -> None:
        """Run *query* verbatim without parameters."""
        async with self.acquire_connection() as connection:
            self.log.debug(query)
            async with connection.cursor() as cursor:
                await cursor.execute(query)
class AsyncpgDBClient(BaseDBAsyncClient):
    """PostgreSQL client using a single asyncpg connection guarded by a lock."""

    DSN_TEMPLATE = "postgres://{user}:{password}@{host}:{port}/{database}"
    query_class = PostgreSQLQuery
    executor_class = AsyncpgExecutor
    schema_generator = AsyncpgSchemaGenerator
    capabilities = Capabilities("postgres")

    def __init__(self, user: str, password: str, database: str, host: str, port: SupportsInt,
                 **kwargs) -> None:
        """Store connection parameters.

        ``schema`` is popped into ``self.schema``. ``connection_name``,
        ``fetch_inserted``, ``loop`` and ``connection_class`` are dropped. The
        remaining *kwargs* are passed to ``asyncpg.connect``.
        """
        super().__init__(**kwargs)
        self.user = user
        self.password = password
        self.database = database
        self.host = host
        self.port = int(port)  # make sure port is int type
        self.extra = kwargs.copy()
        self.schema = self.extra.pop("schema", None)
        self.extra.pop("connection_name", None)
        self.extra.pop("fetch_inserted", None)
        self.extra.pop("loop", None)
        self.extra.pop("connection_class", None)
        self._template: dict = {}
        self._connection: Optional[asyncpg.Connection] = None
        self._lock = asyncio.Lock()
        # Subclass combining TransactionWrapper with this concrete client class.
        self._transaction_class = type("TransactionWrapper", (TransactionWrapper, self.__class__), {})

    async def create_connection(self, with_db: bool) -> None:
        """Open the connection and apply ``search_path`` if a schema was given.

        Raises:
            DBConnectionError: if asyncpg reports ``InvalidCatalogNameError``.
        """
        self._template = {
            "host": self.host,
            "port": self.port,
            "user": self.user,
            "database": self.database if with_db else None,
            **self.extra,
        }
        try:
            # Password is passed separately so it never lands in the logged template.
            self._connection = await asyncpg.connect(None, password=self.password, **self._template)
        except asyncpg.InvalidCatalogNameError as exc:
            raise DBConnectionError(
                f"Can't establish connection to database {self.database}") from exc
        self.log.debug("Created connection %s with params: %s", self._connection, self._template)

        # Set post-connection variables
        if self.schema:
            await self.execute_script(f"SET search_path TO {self.schema}")

    async def _close(self) -> None:
        """Close the connection if open and clear the template (reference is kept)."""
        if self._connection:  # pragma: nobranch
            await self._connection.close()
            self.log.debug("Closed connection %s with params: %s", self._connection, self._template)
            self._template.clear()

    async def close(self) -> None:
        """Close the connection and drop the reference to it."""
        await self._close()
        self._connection = None

    async def db_create(self) -> None:
        """Create ``self.database`` owned by ``self.user``; always closes afterwards."""
        await self.create_connection(with_db=False)
        try:
            await self.execute_script(
                f'CREATE DATABASE "{self.database}" OWNER "{self.user}"')
        finally:
            await self.close()

    async def db_delete(self) -> None:
        """Drop ``self.database``; a missing database is ignored, always closes."""
        await self.create_connection(with_db=False)
        try:
            await self.execute_script(f'DROP DATABASE "{self.database}"')
        except asyncpg.InvalidCatalogNameError:  # pragma: nocoverage
            pass
        finally:
            await self.close()

    def acquire_connection(self) -> ConnectionWrapper:
        """Wrap the shared connection together with its lock."""
        return ConnectionWrapper(self._connection, self._lock)

    def _in_transaction(self) -> "TransactionWrapper":
        """Return a transaction wrapper built from this client."""
        return self._transaction_class(self)

    @translate_exceptions
    @retry_connection
    async def execute_insert(self, query: str, values: list) -> Optional[asyncpg.Record]:
        """Prepare and run an INSERT, returning its first row (or None)."""
        async with self.acquire_connection() as connection:
            self.log.debug("%s: %s", query, values)
            # TODO: Cache prepared statement
            stmt = await connection.prepare(query)
            return await stmt.fetchrow(*values)

    @translate_exceptions
    @retry_connection
    async def execute_many(self, query: str, values: list) -> None:
        """Run *query* once per parameter set via ``executemany``."""
        async with self.acquire_connection() as connection:
            self.log.debug("%s: %s", query, values)
            # TODO: Consider using copy_records_to_table instead
            await connection.executemany(query, values)

    @translate_exceptions
    @retry_connection
    async def execute_query(self, query: str, values: Optional[list] = None) -> List[dict]:
        """Run *query* and return the fetched records.

        With *values*, the statement is prepared first.
        """
        async with self.acquire_connection() as connection:
            self.log.debug("%s: %s", query, values)
            if values:
                # TODO: Cache prepared statement
                stmt = await connection.prepare(query)
                return await stmt.fetch(*values)
            return await connection.fetch(query)

    @translate_exceptions
    @retry_connection
    async def execute_script(self, query: str) -> None:
        """Run *query* verbatim without parameters."""
        async with self.acquire_connection() as connection:
            self.log.debug(query)
            await connection.execute(query)
class SqliteClient(BaseDBAsyncClient):
    """SQLite client that shares one aiosqlite connection behind an asyncio lock."""

    executor_class = SqliteExecutor
    schema_generator = SqliteSchemaGenerator
    capabilities = Capabilities('sqlite', requires_limit=True)

    def __init__(self, file_path: str, **kwargs) -> None:
        """Remember the database file; the connection is opened lazily."""
        super().__init__(**kwargs)
        self.filename = file_path
        # Subclass combining TransactionWrapper with this concrete client class.
        wrapper_bases = (TransactionWrapper, self.__class__)
        self._transaction_class = type('TransactionWrapper', wrapper_bases, {})
        self._connection = None  # type: Optional[aiosqlite.Connection]
        self._lock = asyncio.Lock()

    async def create_connection(self, with_db: bool) -> None:
        """Open the aiosqlite connection unless one already exists.

        *with_db* is ignored. Rows come back as ``sqlite3.Row`` and
        ``isolation_level=None`` is passed to aiosqlite.
        """
        if self._connection:  # pragma: no branch
            return
        conn = aiosqlite.connect(self.filename, isolation_level=None)
        self._connection = conn
        conn.start()
        # Private aiosqlite API: opens the underlying sqlite3 connection.
        await conn._connect()
        conn._conn.row_factory = sqlite3.Row
        self.log.debug(
            'Created connection %s with params: filename=%s', conn, self.filename
        )

    async def close(self) -> None:
        """Close the connection if there is one and forget it."""
        conn = self._connection
        if not conn:
            return
        await conn.close()
        self.log.debug(
            'Closed connection %s with params: filename=%s', conn, self.filename
        )
        self._connection = None

    async def db_create(self) -> None:
        """No-op: the database file is created when the connection opens it."""

    async def db_delete(self) -> None:
        """Close the connection and remove the database file if it exists."""
        await self.close()
        try:
            os.remove(self.filename)
        except FileNotFoundError:  # pragma: nocoverage
            return

    def acquire_connection(self) -> ConnectionWrapper:
        """Wrap the shared connection together with its lock."""
        return ConnectionWrapper(self._connection, self._lock)

    def _in_transaction(self) -> 'TransactionWrapper':
        """Return a transaction wrapper sharing this client's connection and lock."""
        return self._transaction_class(self.connection_name, self._connection, self._lock)

    @translate_exceptions
    async def execute_insert(self, query: str, values: list) -> int:
        """Run an INSERT and return the first column of aiosqlite's result row."""
        async with self.acquire_connection() as connection:
            self.log.debug('%s: %s', query, values)
            inserted = await connection.execute_insert(query, values)
            return inserted[0]

    @translate_exceptions
    async def execute_query(self, query: str) -> List[dict]:
        """Run *query* and return every row converted to a plain dict."""
        async with self.acquire_connection() as connection:
            self.log.debug(query)
            fetched = await connection.execute_fetchall(query)
            return list(map(dict, fetched))

    @translate_exceptions
    async def execute_script(self, query: str) -> None:
        """Run *query* with ``executescript`` (multiple statements allowed)."""
        async with self.acquire_connection() as connection:
            self.log.debug(query)
            await connection.executescript(query)
class MSSqlClient(BaseDBAsyncClient):
    """Microsoft SQL Server client backed by an ``aioodbc`` connection pool.

    The ODBC connection string is either the URL-decoded ``dsn`` keyword
    argument, or ``DSN=MYMSSQL`` combined with ``database``/``user``/``password``.
    """

    query_class = MSSQLQuery
    executor_class = MSSQLExecutor
    schema_generator = MSSQLSchemaGenerator
    capabilities = Capabilities("mssql", requires_limit=True, inline_comment=True)

    def __init__(
        self,
        *,
        user: str,
        password: str,
        database: str,
        host: str,
        port: SupportsInt,
        **kwargs: Any,
    ) -> None:
        """Store connection and pool settings and build ``self.dsn``.

        ``host`` and ``port`` are stored but not used when the DSN is
        assembled here.
        """
        super().__init__(**kwargs)
        self.user = user
        self.password = password
        self.database = database
        self.host = host
        self.port = int(port)  # make sure port is int type
        self.extra = kwargs.copy()
        # NOTE(review): storage_engine, sql_mode and charset look like MySQL
        # options; confirm they are meaningful for this MSSQL backend.
        self.storage_engine = self.extra.pop("storage_engine", "")
        self.extra.pop("connection_name", None)
        self.extra.pop("fetch_inserted", None)
        self.extra.pop("db", None)
        self.extra.pop("autocommit", None)
        self.extra.setdefault("sql_mode", "STRICT_TRANS_TABLES")
        self.charset = self.extra.pop("charset", "utf8mb4")
        self.pool_minsize = int(self.extra.pop("minsize", 1))
        self.pool_maxsize = int(self.extra.pop("maxsize", 5))
        self.pool_recycle = int(self.extra.pop("pool_recycle", 5))
        self.timeout = int(self.extra.pop("timeout", 5))
        dsn = self.extra.pop("dsn", "")
        if dsn:
            self.dsn = urllib.parse.unquote(dsn)
        else:
            self.dsn = "DSN=MYMSSQL;DATABASE={DB_DATABASE};UID={DB_USER};PWD={DB_PASSWORD}".format(
                DB_DATABASE=database, DB_USER=user, DB_PASSWORD=password)
        self._template: dict = {}
        self._pool: Optional[aioodbc.Pool] = None
        self._connection = None

    async def create_connection(self, with_db: bool) -> None:
        """Create the aioodbc pool from ``self.dsn``.

        *with_db* is not used: the pool always connects with ``self.dsn``.

        Raises:
            DBConnectionError: if pyodbc raises ``OperationalError``.
        """
        # NOTE(review): the template holds the DSN, which may contain PWD=...;
        # it is written to the debug log and to the error message below.
        self._template = {
            "dsn": self.dsn,
            "minsize": self.pool_minsize,
            "maxsize": self.pool_maxsize,
            "executor": _g_thread_executor,
            "pool_recycle": self.pool_recycle,
            "timeout": self.timeout,
        }
        try:
            self._pool = await aioodbc.create_pool(**self._template)
            if isinstance(self._pool, aioodbc.Pool):
                async with self.acquire_connection() as connection:
                    async with connection.cursor() as cursor:
                        if self.storage_engine:
                            await cursor.execute(
                                f"SET default_storage_engine='{self.storage_engine}';"
                            )
                            if self.storage_engine.lower() != "innodb":  # pragma: nobranch
                                # NOTE(review): this writes into the class-level
                                # ``capabilities`` object shared by every instance.
                                self.capabilities.__dict__["supports_transactions"] = False
            self.log.debug("Created connection %s pool with params: %s", self._pool, self._template)
        except pyodbc.OperationalError as exc:
            raise DBConnectionError(
                f"Can't connect to MSSQL server: {self._template}") from exc

    async def _expire_connections(self) -> None:
        """Poison every idle pooled connection so it gets discarded."""
        if self._pool:  # pragma: nobranch
            # NOTE(review): ``_free`` and ``_reader`` are private attributes;
            # ``_reader`` looks borrowed from the aiomysql client, so confirm
            # aioodbc connections actually expose it.
            for conn in self._pool._free:
                conn._reader.set_exception(EOFError("EOF"))

    async def _close(self) -> None:
        """Close the pool, wait for it to shut down and drop the reference."""
        if self._pool:  # pragma: nobranch
            self._pool.close()
            await self._pool.wait_closed()
            # Log the pool itself (``self._connection`` is never set in this class).
            self.log.debug("Closed connection %s with params: %s", self._pool, self._template)
            self._pool = None

    async def close(self) -> None:
        """Close the pool and forget the connection template."""
        await self._close()
        self._template.clear()

    async def db_create(self) -> None:
        """Create ``self.database``; the pool is closed even on failure."""
        await self.create_connection(with_db=False)
        try:
            await self.execute_script(f"CREATE DATABASE {self.database}")
        finally:
            await self.close()

    async def db_delete(self) -> None:
        """Drop ``self.database``; ``DatabaseError`` is ignored, always closes."""
        await self.create_connection(with_db=False)
        try:
            await self.execute_script(f"DROP DATABASE {self.database}")
        except pyodbc.DatabaseError:  # pragma: nocoverage
            pass
        finally:
            await self.close()

    def acquire_connection(
            self) -> Union["ConnectionWrapper", "PoolConnectionWrapper"]:
        """Return a wrapper that checks a connection out of ``self._pool``."""
        return PoolConnectionWrapper(self._pool)

    def _in_transaction(self) -> "TransactionContext":
        """Return a pooled transaction context bound to this client."""
        return TransactionContextPooled(TransactionWrapper(self))

    @translate_exceptions
    async def execute_insert(self, query: str, values: list) -> Optional[int]:
        """Run an INSERT and return the value produced by ``OUTPUT INSERTED``.

        Returns ``None`` when the query has no ``OUTPUT INSERTED`` clause or
        the clause yields no row.
        """
        async with self.acquire_connection() as connection:
            self.log.debug("%s: %s", query, values)
            async with connection.cursor() as cursor:
                await cursor.execute(query, values)
                # SQL keywords are case-insensitive, so match the clause that way.
                if "output inserted" not in query.lower():
                    return None
                row = await cursor.fetchone()
                return row[0] if row is not None else None

    @translate_exceptions
    async def execute_many(self, query: str, values: list) -> None:
        """Run *query* for each parameter set.

        When ``capabilities.supports_transactions`` is set, all executions run
        inside one explicit transaction.
        """
        async with self.acquire_connection() as connection:
            self.log.debug("%s: %s", query, values)
            async with connection.cursor() as cursor:
                if self.capabilities.supports_transactions:
                    await connection.begin()
                    try:
                        await cursor.executemany(query, values)
                    except Exception:
                        await connection.rollback()
                        raise
                    else:
                        await connection.commit()
                else:
                    await cursor.executemany(query, values)

    @translate_exceptions
    async def execute_query(
            self, query: str, values: Optional[list] = None) -> Tuple[int, List[dict]]:
        """Run *query* and return ``(cursor.rowcount, rows)``.

        Returns ``(rowcount, [])`` for UPDATE/DELETE statements and for any
        statement that leaves no result set. Otherwise each row becomes a dict
        keyed by the lower-cased column name.
        """
        async with self.acquire_connection() as connection:
            self.log.debug("%s: %s", query, values)
            async with connection.cursor() as cursor:
                await cursor.execute(query, values or tuple())
                # Match on the leading keyword only, so a SELECT that merely
                # mentions UPDATE/DELETE (e.g. in an identifier) keeps its rows.
                is_write = query.lstrip().upper().startswith(("UPDATE", "DELETE"))
                # description is None when the statement produced no result set;
                # fetchall() would raise in that case.
                if is_write or cursor.description is None:
                    return cursor.rowcount, []
                rows = await cursor.fetchall()
                if not rows:
                    return cursor.rowcount, []
                # Every row shares one description; build the key list once.
                columns = [column[0].lower() for column in rows[0].cursor_description]
                return cursor.rowcount, [dict(zip(columns, row)) for row in rows]

    async def execute_query_dict(self, query: str, values: Optional[list] = None) -> List[dict]:
        """Run *query* and return only the list of row dicts."""
        return (await self.execute_query(query, values))[1]

    @translate_exceptions
    async def execute_script(self, query: str) -> None:
        """Run *query* verbatim without parameters."""
        async with self.acquire_connection() as connection:
            self.log.debug(query)
            async with connection.cursor() as cursor:
                await cursor.execute(query)
class AsyncpgDBClient(BaseDBAsyncClient):
    """PostgreSQL client backed by an asyncpg connection pool."""

    DSN_TEMPLATE = "postgres://{user}:{password}@{host}:{port}/{database}"
    query_class = PostgreSQLQuery
    executor_class = AsyncpgExecutor
    schema_generator = AsyncpgSchemaGenerator
    capabilities = Capabilities("postgres")
    connection_class = asyncpg.connection.Connection
    loop = None

    def __init__(self, user: str, password: str, database: str, host: str, port: SupportsInt,
                 **kwargs: Any) -> None:
        """Store connection and pool settings.

        How *kwargs* is handled:
        - ``schema``, ``loop``, ``connection_class``, ``minsize`` and
          ``maxsize`` are popped into attributes.
        - ``connection_name`` and ``fetch_inserted`` are dropped.
        - The rest is passed to ``asyncpg.create_pool``.
        """
        super().__init__(**kwargs)
        self.user = user
        self.password = password
        self.database = database
        self.host = host
        self.port = int(port)  # make sure port is int type
        self.extra = kwargs.copy()
        self.schema = self.extra.pop("schema", None)
        self.extra.pop("connection_name", None)
        self.extra.pop("fetch_inserted", None)
        self.loop = self.extra.pop("loop", None)
        self.connection_class = self.extra.pop("connection_class", self.connection_class)
        self.pool_minsize = int(self.extra.pop("minsize", 1))
        self.pool_maxsize = int(self.extra.pop("maxsize", 5))
        self._template: dict = {}
        self._pool: Optional[asyncpg.pool.Pool] = None
        self._connection = None

    async def create_connection(self, with_db: bool) -> None:
        """Create the pool; a schema is applied via the ``search_path`` server setting.

        Raises:
            DBConnectionError: if asyncpg reports ``InvalidCatalogNameError``.
        """
        self._template = {
            "host": self.host,
            "port": self.port,
            "user": self.user,
            "database": self.database if with_db else None,
            "min_size": self.pool_minsize,
            "max_size": self.pool_maxsize,
            "connection_class": self.connection_class,
            "loop": self.loop,
            **self.extra,
        }
        if self.schema:
            self._template["server_settings"] = {"search_path": self.schema}
        try:
            # Password is passed separately so it never lands in the logged template.
            self._pool = await asyncpg.create_pool(None, password=self.password, **self._template)
        except asyncpg.InvalidCatalogNameError as exc:
            raise DBConnectionError(
                f"Can't establish connection to database {self.database}") from exc
        self.log.debug("Created connection pool %s with params: %s", self._pool, self._template)

    async def _expire_connections(self) -> None:
        """Ask asyncpg to expire every connection currently in the pool."""
        if self._pool:  # pragma: nobranch
            await self._pool.expire_connections()

    async def _close(self) -> None:
        """Close the pool gracefully; terminate it if closing takes over 10 seconds."""
        if self._pool:  # pragma: nobranch
            try:
                await asyncio.wait_for(self._pool.close(), 10)
            except asyncio.TimeoutError:  # pragma: nocoverage
                self._pool.terminate()
            # Log before dropping the reference, otherwise the log shows None.
            self.log.debug("Closed connection pool %s with params: %s", self._pool, self._template)
            self._pool = None

    async def close(self) -> None:
        """Close the pool and forget the connection template."""
        await self._close()
        self._template.clear()

    async def db_create(self) -> None:
        """Create ``self.database`` owned by ``self.user``; always closes afterwards."""
        await self.create_connection(with_db=False)
        try:
            await self.execute_script(
                f'CREATE DATABASE "{self.database}" OWNER "{self.user}"')
        finally:
            await self.close()

    async def db_delete(self) -> None:
        """Drop ``self.database``; a missing database is ignored, always closes."""
        await self.create_connection(with_db=False)
        try:
            await self.execute_script(f'DROP DATABASE "{self.database}"')
        except asyncpg.InvalidCatalogNameError:  # pragma: nocoverage
            pass
        finally:
            await self.close()

    def acquire_connection(
            self) -> Union["PoolConnectionWrapper", "ConnectionWrapper"]:
        """Return a wrapper that checks a connection out of ``self._pool``."""
        return PoolConnectionWrapper(self._pool)

    def _in_transaction(self) -> "TransactionContext":
        """Return a pooled transaction context bound to this client."""
        return TransactionContextPooled(TransactionWrapper(self))

    @translate_exceptions
    async def execute_insert(self, query: str, values: list) -> Optional[asyncpg.Record]:
        """Run an INSERT and return its first row (e.g. a RETURNING row) or None."""
        async with self.acquire_connection() as connection:
            self.log.debug("%s: %s", query, values)
            # TODO: Cache prepared statement
            return await connection.fetchrow(query, *values)

    @translate_exceptions
    async def execute_many(self, query: str, values: list) -> None:
        """Run *query* once per parameter set inside a single transaction."""
        async with self.acquire_connection() as connection:
            self.log.debug("%s: %s", query, values)
            # TODO: Consider using copy_records_to_table instead
            # The transaction context commits on success and rolls back on error.
            async with connection.transaction():
                await connection.executemany(query, values)

    @translate_exceptions
    async def execute_query(
            self, query: str, values: Optional[list] = None) -> Tuple[int, List[dict]]:
        """Run *query* and return ``(count, rows)``.

        For UPDATE/DELETE statements, *count* is parsed from asyncpg's status
        string (0 if it cannot be parsed) and *rows* is empty. Otherwise
        *count* is ``len(rows)``.
        """
        async with self.acquire_connection() as connection:
            self.log.debug("%s: %s", query, values)
            params = [query, *values] if values else [query]
            if query.startswith(("UPDATE", "DELETE")):
                res = await connection.execute(*params)
                try:
                    # Status strings look like "UPDATE 3" / "DELETE 1".
                    rows_affected = int(res.split(" ")[1])
                except (AttributeError, IndexError, ValueError):  # pragma: nocoverage
                    rows_affected = 0
                return rows_affected, []
            rows = await connection.fetch(*params)
            return len(rows), rows

    @translate_exceptions
    async def execute_query_dict(self, query: str, values: Optional[list] = None) -> List[dict]:
        """Run *query* and return each record converted to a plain dict."""
        async with self.acquire_connection() as connection:
            self.log.debug("%s: %s", query, values)
            if values:
                return list(map(dict, await connection.fetch(query, *values)))
            return list(map(dict, await connection.fetch(query)))

    @translate_exceptions
    async def execute_script(self, query: str) -> None:
        """Run *query* verbatim without parameters."""
        async with self.acquire_connection() as connection:
            self.log.debug(query)
            await connection.execute(query)
class MySQLClient(BaseDBAsyncClient):
    """MySQL client using one aiomysql connection shared behind ``self._lock``."""

    query_class = MySQLQuery
    executor_class = MySQLExecutor
    schema_generator = MySQLSchemaGenerator
    capabilities = Capabilities("mysql", safe_indexes=False, requires_limit=True,
                                inline_comment=True)

    def __init__(self, *, user: str, password: str, database: str, host: str,
                 port: SupportsInt, **kwargs) -> None:
        """Store connection parameters.

        ``charset`` is popped into ``self.charset``. ``connection_name``,
        ``fetch_inserted``, ``db`` and ``autocommit`` are dropped. The
        remaining *kwargs* are passed to ``aiomysql.connect``.
        """
        super().__init__(**kwargs)
        self.user = user
        self.password = password
        self.database = database
        self.host = host
        self.port = int(port)  # make sure port is int type
        self.extra = kwargs.copy()
        self.extra.pop("connection_name", None)
        self.extra.pop("fetch_inserted", None)
        self.extra.pop("db", None)
        self.extra.pop("autocommit", None)
        self.charset = self.extra.pop("charset", "")

        self._template: dict = {}
        self._connection: Optional[aiomysql.Connection] = None
        self._lock = asyncio.Lock()

        # Subclass combining TransactionWrapper with this concrete client class.
        self._transaction_class = type("TransactionWrapper", (TransactionWrapper, self.__class__), {})

    async def create_connection(self, with_db: bool) -> None:
        """Open an autocommit connection; *with_db* selects ``self.database``.

        Raises:
            DBConnectionError: if aiomysql raises ``OperationalError``.
        """
        self._template = {
            "host": self.host,
            "port": self.port,
            "user": self.user,
            "db": self.database if with_db else None,
            "autocommit": True,
            "charset": self.charset,
            **self.extra,
        }
        try:
            # Password is passed separately so it never lands in the logged template.
            self._connection = await aiomysql.connect(password=self.password, **self._template)
        except pymysql.err.OperationalError as exc:
            raise DBConnectionError(
                "Can't connect to MySQL server: {template}".format(template=self._template)
            ) from exc
        self.log.debug("Created connection %s with params: %s", self._connection, self._template)

    async def _close(self) -> None:
        """Close the connection if open and clear the template (reference is kept)."""
        if self._connection:  # pragma: nobranch
            self._connection.close()
            self.log.debug("Closed connection %s with params: %s", self._connection, self._template)
            self._template.clear()

    async def close(self) -> None:
        """Close the connection and drop the reference to it."""
        await self._close()
        self._connection = None

    async def db_create(self) -> None:
        """Create ``self.database``; the connection is closed even on failure."""
        await self.create_connection(with_db=False)
        try:
            await self.execute_script("CREATE DATABASE {}".format(self.database))
        finally:
            await self.close()

    async def db_delete(self) -> None:
        """Drop ``self.database``; ``DatabaseError`` is ignored, always closes."""
        await self.create_connection(with_db=False)
        try:
            await self.execute_script("DROP DATABASE {}".format(self.database))
        except pymysql.err.DatabaseError:  # pragma: nocoverage
            pass
        finally:
            await self.close()

    def acquire_connection(self) -> ConnectionWrapper:
        """Wrap the shared connection together with its lock."""
        return ConnectionWrapper(self._connection, self._lock)

    def _in_transaction(self):
        """Return a transaction wrapper built from this client."""
        return self._transaction_class(self)

    @translate_exceptions
    @retry_connection
    async def execute_insert(self, query: str, values: list) -> int:
        """Run an INSERT and return ``cursor.lastrowid``."""
        async with self.acquire_connection() as connection:
            self.log.debug("%s: %s", query, values)
            async with connection.cursor() as cursor:
                await cursor.execute(query, values)
                return cursor.lastrowid  # return auto-generated id

    @translate_exceptions
    @retry_connection
    async def execute_many(self, query: str, values: list) -> None:
        """Run *query* once per parameter set via ``executemany``."""
        async with self.acquire_connection() as connection:
            self.log.debug("%s: %s", query, values)
            async with connection.cursor() as cursor:
                await cursor.executemany(query, values)

    @translate_exceptions
    @retry_connection
    async def execute_query(
            self, query: str, values: Optional[list] = None) -> List[aiomysql.DictCursor]:
        """Run *query* with a ``DictCursor`` and return ``fetchall()``."""
        async with self.acquire_connection() as connection:
            self.log.debug("%s: %s", query, values)
            async with connection.cursor(aiomysql.DictCursor) as cursor:
                await cursor.execute(query, values)
                return await cursor.fetchall()

    @translate_exceptions
    @retry_connection
    async def execute_script(self, query: str) -> None:
        """Run *query* verbatim without parameters."""
        async with self.acquire_connection() as connection:
            self.log.debug(query)
            async with connection.cursor() as cursor:
                await cursor.execute(query)