Example #1
0
class MSSQLClient(ODBCClient):
    """Microsoft SQL Server client that connects through an ODBC driver."""

    query_class = MSSQLQuery
    schema_generator = MSSQLSchemaGenerator
    executor_class = MSSQLExecutor
    capabilities = Capabilities(
        "mssql",
        support_update_limit_order_by=False,
        support_for_update=False,
    )

    def __init__(
        self,
        *,
        user: str,
        password: str,
        host: str,
        port: SupportsInt,
        driver: str,
        **kwargs: Any,
    ) -> None:
        """Build the ODBC connection string for the given server.

        :param user: login name, written to ``UID``.
        :param password: login password, written to ``PWD``.
        :param host: server host, written to ``SERVER`` before the comma.
        :param port: server port, written to ``SERVER`` after the comma.
        :param driver: ODBC driver name, written to ``DRIVER``.
        :param kwargs: forwarded unchanged to the parent initializer.
        """
        super().__init__(**kwargs)
        # NOTE(review): values are interpolated verbatim; a ";" inside a value
        # (e.g. the password) would split the connection string — confirm
        # whether callers are expected to brace-quote such values.
        server = f"{host},{port}"
        self.dsn = f"DRIVER={driver};SERVER={server};UID={user};PWD={password};"

    def _in_transaction(self) -> "TransactionContext":
        """Return a pooled transaction context wrapping this client."""
        wrapper = TransactionWrapper(self)
        return TransactionContextPooled(wrapper)

    @translate_exceptions
    async def execute_insert(self, query: str, values: list) -> int:
        """Run an INSERT and return the identity value reported by the server.

        The identity is read with ``SELECT @@IDENTITY;`` on the same cursor,
        immediately after the insert.
        """
        async with self.acquire_connection() as connection:
            self.log.debug("%s: %s", query, values)
            async with connection.cursor() as cursor:
                await cursor.execute(query, values)
                await cursor.execute("SELECT @@IDENTITY;")
                identity_row = await cursor.fetchone()
                return identity_row[0]
Example #2
0
class OracleClient(ODBCClient):
    """Oracle client that connects through an ODBC driver.

    ``db_create``/``db_delete`` treat ``self.database`` (set outside this class,
    presumably by the parent initializer) as an Oracle *user* name: they create
    or drop a user with that name.
    """

    query_class = OracleQuery
    schema_generator = OracleSchemaGenerator
    executor_class = OracleExecutor
    capabilities = Capabilities(dialect="oracle")

    def __init__(
        self,
        *,
        user: str,
        password: str,
        host: str,
        port: SupportsInt,
        driver: str,
        **kwargs: Any,
    ) -> None:
        """Build the ODBC connection string for the given server.

        :param user: login name, written to ``UID``.
        :param password: login password, written to ``PWD`` and kept on
            ``self.password`` for ``db_create``.
        :param host: server host, written to ``DBQ`` before the colon.
        :param port: server port, written to ``DBQ`` after the colon.
        :param driver: ODBC driver name, written to ``DRIVER``.
        :param kwargs: forwarded unchanged to the parent initializer.
        """
        super().__init__(**kwargs)
        # Needed later by db_create for CREATE USER ... IDENTIFIED BY.
        self.password = password
        self.dsn = f"DRIVER={driver};DBQ={host}:{port};UID={user};PWD={password};"

    def _in_transaction(self) -> "TransactionContext":
        """Return a pooled transaction context wrapping this client."""
        return TransactionContextPooled(TransactionWrapper(self))

    def acquire_connection(
        self,
    ) -> Union["ConnectionWrapper", "PoolConnectionWrapper"]:
        """Return an Oracle-specific pooled connection wrapper for this client."""
        return OraclePoolConnectionWrapper(self)

    async def db_create(self) -> None:
        """Create the schema user and grant it all privileges.

        The temporary connection opened here is closed even if a statement fails.
        """
        await self.create_connection(with_db=False)
        try:
            await self.execute_script(
                f'CREATE USER "{self.database}" IDENTIFIED BY "{self.password}"'
            )
            await self.execute_script(f'GRANT ALL PRIVILEGES TO "{self.database}"')
        finally:
            await self.close()

    async def db_delete(self) -> None:
        """Drop the schema user (with CASCADE); a missing user is not an error.

        The temporary connection opened here is closed on every exit path,
        including when an unexpected error is re-raised.
        """
        await self.create_connection(with_db=False)
        try:
            await self.execute_script(f'DROP USER "{self.database}" CASCADE')
        except pyodbc.Error as e:
            # Only a "does not exist" failure is swallowed; anything else propagates.
            if "does not exist" not in str(e):
                raise
        finally:
            await self.close()

    @translate_exceptions
    async def execute_script(self, query: str) -> None:
        """Execute ``query`` statement by statement on one cursor.

        The text is split on every ";" and blank fragments are skipped.
        NOTE(review): the split is purely textual, so a ";" inside a string
        literal or PL/SQL block would also split the statement.
        """
        async with self.acquire_connection() as connection:
            self.log.debug(query)
            async with connection.cursor() as cursor:
                for q in query.split(";"):
                    if not q.strip():
                        continue
                    await cursor.execute(q)

    @translate_exceptions
    async def execute_insert(self, query: str, values: list) -> int:
        """Execute an INSERT with ``values``; always returns 0.

        No generated key is read back here.
        """
        async with self.acquire_connection() as connection:
            self.log.debug("%s: %s", query, values)
            await connection.execute(query, values)
            return 0
Example #3
0
 def __init__(self, file_path: str, **kwargs) -> None:
     """Remember the SQLite file path and set up per-instance state.

     :param file_path: path of the SQLite database file; also recorded in
         ``self.capabilities`` under ``connection['file']``.
     :param kwargs: forwarded unchanged to the parent initializer.
     """
     super().__init__(**kwargs)
     self.filename = file_path
     # A dynamically built subclass mixing TransactionWrapper into this class.
     wrapper_bases = (TransactionWrapper, self.__class__)
     self._transaction_class = type('TransactionWrapper', wrapper_bases, {})
     self._connection = None  # type: Optional[aiosqlite.Connection]
     connection_info = {'file': file_path}
     self.capabilities = Capabilities('sqlite', connection=connection_info)
Example #4
0
    def __init__(self, user: str, password: str, database: str, host: str,
                 port: SupportsInt, **kwargs) -> None:
        """Store MySQL connection settings; no connection is opened here.

        :param port: anything ``int()`` accepts; stored as ``int`` on ``self.port``.
        :param kwargs: forwarded unchanged to the parent initializer.
        """
        super().__init__(**kwargs)

        self.user = user
        self.password = password
        self.database = database
        self.host = host
        # Normalise to int: SupportsInt values need not be ints already.
        self.port = int(port)

        # Expected to hold an aiomysql connection once one is created.
        self._connection = None

        wrapper_bases = (TransactionWrapper, self.__class__)
        self._transaction_class = type('TransactionWrapper', wrapper_bases, {})

        # The raw ``port`` argument is recorded here, not the int-converted self.port.
        connection_info = dict(user=user, database=database, host=host, port=port)
        self.capabilities = Capabilities('mysql', connection=connection_info)
Example #5
0
class MySQLClient(BaseDBAsyncClient):
    """Pooled MySQL client built on aiomysql.

    ``__init__`` only stores settings; ``create_connection`` creates the pool and
    ``close`` tears it down.
    """

    query_class = MySQLQuery
    executor_class = MySQLExecutor
    schema_generator = MySQLSchemaGenerator
    capabilities = Capabilities("mysql", requires_limit=True, inline_comment=True)

    def __init__(
        self,
        *,
        user: str,
        password: str,
        database: str,
        host: str,
        port: SupportsInt,
        **kwargs: Any,
    ) -> None:
        """Store connection settings; no connection is opened here.

        :param port: anything ``int()`` accepts; stored as ``int``.
        :param kwargs: forwarded to the parent initializer. A copy is kept in
            ``self.extra`` and passed to ``aiomysql.create_pool``, minus the keys
            this client handles itself (``storage_engine``, ``charset``,
            ``minsize``, ``maxsize``) and the keys removed below.
        """
        super().__init__(**kwargs)

        self.user = user
        self.password = password
        self.database = database
        self.host = host
        self.port = int(port)  # make sure port is int type
        self.extra = kwargs.copy()
        self.storage_engine = self.extra.pop("storage_engine", "")
        # "db" and "autocommit" are set explicitly by create_connection; the other
        # two are presumably base-client options and must not reach the pool.
        self.extra.pop("connection_name", None)
        self.extra.pop("fetch_inserted", None)
        self.extra.pop("db", None)
        self.extra.pop("autocommit", None)
        self.extra.setdefault("sql_mode", "STRICT_TRANS_TABLES")
        self.charset = self.extra.pop("charset", "utf8mb4")
        self.pool_minsize = int(self.extra.pop("minsize", 1))
        self.pool_maxsize = int(self.extra.pop("maxsize", 5))

        self._template: dict = {}
        self._pool: Optional[aiomysql.Pool] = None
        self._connection = None

    @staticmethod
    def _mysql_time_zone(offset: Any) -> str:
        """Format a ``timedelta`` UTC offset as MySQL's ``[+-]H:MM`` time_zone value.

        Uses ``total_seconds()`` so negative offsets keep their sign: a -05:30
        offset gives "-5:30" (``timedelta.seconds`` alone would give "+18:30").
        Seconds below a whole minute are truncated.
        """
        total = int(offset.total_seconds())
        sign = "+" if total >= 0 else "-"
        hours, remainder = divmod(abs(total), 3600)
        return f"{sign}{hours}:{remainder // 60:02d}"

    async def create_connection(self, with_db: bool) -> None:
        """Create the aiomysql pool and run session setup on one pooled connection.

        :param with_db: select ``self.database`` when true; otherwise connect
            without a database (as ``db_create``/``db_delete`` do).
        :raises DBConnectionError: for an unknown charset, or when pool creation
            or setup raises ``pymysql.err.OperationalError``.
        """
        if charset_by_name(self.charset) is None:  # type: ignore
            raise DBConnectionError(f"Unknown charset {self.charset}")
        # The password is passed separately so it never lands in the template,
        # which is logged and embedded in error messages.
        self._template = {
            "host": self.host,
            "port": self.port,
            "user": self.user,
            "db": self.database if with_db else None,
            "autocommit": True,
            "charset": self.charset,
            "minsize": self.pool_minsize,
            "maxsize": self.pool_maxsize,
            **self.extra,
        }
        try:
            self._pool = await aiomysql.create_pool(password=self.password, **self._template)

            if isinstance(self._pool, aiomysql.Pool):
                async with self.acquire_connection() as connection:
                    async with connection.cursor() as cursor:
                        if self.storage_engine:
                            await cursor.execute(
                                f"SET default_storage_engine='{self.storage_engine}';"
                            )
                            if self.storage_engine.lower() != "innodb":  # pragma: nobranch
                                # NOTE(review): ``capabilities`` is a class attribute, so
                                # this write is seen by every instance sharing it.
                                self.capabilities.__dict__["supports_transactions"] = False
                        # NOTE(review): SET SESSION affects only the connection
                        # acquired here, not the pool's other connections — confirm.
                        offset = timezone.now().utcoffset()  # type: ignore
                        tz = self._mysql_time_zone(offset)
                        await cursor.execute(f"SET SESSION time_zone='{tz}';")
            self.log.debug("Created connection %s pool with params: %s", self._pool, self._template)
        except pymysql.err.OperationalError as e:
            raise DBConnectionError(f"Can't connect to MySQL server: {self._template}") from e

    async def _expire_connections(self) -> None:
        """Force every idle pooled connection to fail its next read with EOF.

        Relies on private aiomysql internals (``_free``, ``_reader``).
        """
        if self._pool:  # pragma: nobranch
            for conn in self._pool._free:
                conn._reader.set_exception(EOFError("EOF"))

    async def _close(self) -> None:
        """Close the pool, wait for it to shut down and drop the reference."""
        if self._pool:  # pragma: nobranch
            self._pool.close()
            await self._pool.wait_closed()
            self.log.debug("Closed connection %s with params: %s", self._pool, self._template)
            self._pool = None

    async def close(self) -> None:
        """Close the pool and clear the stored connection template."""
        await self._close()
        self._template.clear()

    async def db_create(self) -> None:
        """Create ``self.database``; the temporary pool is closed even on failure."""
        await self.create_connection(with_db=False)
        try:
            await self.execute_script(f"CREATE DATABASE {self.database}")
        finally:
            await self.close()

    async def db_delete(self) -> None:
        """Drop ``self.database``, ignoring ``pymysql.err.DatabaseError``.

        The temporary pool is closed on every exit path.
        """
        await self.create_connection(with_db=False)
        try:
            await self.execute_script(f"DROP DATABASE {self.database}")
        except pymysql.err.DatabaseError:  # pragma: nocoverage
            pass
        finally:
            await self.close()

    def acquire_connection(self) -> Union["ConnectionWrapper", "PoolConnectionWrapper"]:
        """Return a wrapper that acquires a connection from the pool."""
        return PoolConnectionWrapper(self._pool)

    def _in_transaction(self) -> "TransactionContext":
        """Return a pooled transaction context wrapping this client."""
        return TransactionContextPooled(TransactionWrapper(self))

    @translate_exceptions
    async def execute_insert(self, query: str, values: list) -> int:
        """Execute an INSERT and return the cursor's ``lastrowid``."""
        async with self.acquire_connection() as connection:
            self.log.debug("%s: %s", query, values)
            async with connection.cursor() as cursor:
                await cursor.execute(query, values)
                return cursor.lastrowid  # return auto-generated id

    @translate_exceptions
    async def execute_many(self, query: str, values: list) -> None:
        """Run ``executemany``; wrapped in begin/commit/rollback when transactions are supported."""
        async with self.acquire_connection() as connection:
            self.log.debug("%s: %s", query, values)
            async with connection.cursor() as cursor:
                if self.capabilities.supports_transactions:
                    await connection.begin()
                    try:
                        await cursor.executemany(query, values)
                    except Exception:
                        await connection.rollback()
                        raise
                    else:
                        await connection.commit()
                else:
                    await cursor.executemany(query, values)

    @translate_exceptions
    async def execute_query(
        self, query: str, values: Optional[list] = None
    ) -> Tuple[int, List[dict]]:
        """Execute ``query`` and return ``(rowcount, rows)`` with rows as column-name dicts.

        Column names come from the private aiomysql attribute ``cursor._result``.
        """
        async with self.acquire_connection() as connection:
            self.log.debug("%s: %s", query, values)
            async with connection.cursor() as cursor:
                await cursor.execute(query, values)
                rows = await cursor.fetchall()
                if rows:
                    fields = [f.name for f in cursor._result.fields]
                    return cursor.rowcount, [dict(zip(fields, row)) for row in rows]
                return cursor.rowcount, []

    async def execute_query_dict(self, query: str, values: Optional[list] = None) -> List[dict]:
        """Execute ``query`` and return only the list of row dicts."""
        return (await self.execute_query(query, values))[1]

    @translate_exceptions
    async def execute_script(self, query: str) -> None:
        """Execute ``query`` as a single ``cursor.execute`` call."""
        async with self.acquire_connection() as connection:
            self.log.debug(query)
            async with connection.cursor() as cursor:
                await cursor.execute(query)
Example #6
0
class SqliteClient(BaseDBAsyncClient):
    """SQLite client built on a single aiosqlite connection.

    Every ``ConnectionWrapper`` and transaction wrapper is handed the same
    ``asyncio.Lock`` (``self._lock``).
    """

    executor_class = SqliteExecutor
    schema_generator = SqliteSchemaGenerator
    capabilities = Capabilities("sqlite", daemon=False, requires_limit=True)

    def __init__(self, file_path: str, **kwargs) -> None:
        """Store the database path and collect pragmas from ``kwargs``.

        :param file_path: path of the SQLite database file.
        :param kwargs: forwarded to the parent initializer; every key except
            ``connection_name`` and ``fetch_inserted`` is later applied as
            ``PRAGMA key=value`` when the connection is opened.
        """
        super().__init__(**kwargs)
        self.filename = file_path

        pragmas = dict(kwargs)
        for client_option in ("connection_name", "fetch_inserted"):
            pragmas.pop(client_option, None)
        self.pragmas = pragmas

        self._transaction_class = type(
            "TransactionWrapper", (TransactionWrapper, self.__class__), {}
        )
        self._connection = None  # type: Optional[aiosqlite.Connection]
        self._lock = asyncio.Lock()

    def _describe_pragmas(self) -> str:
        """Render the configured pragmas as space-separated ``name=value`` pairs."""
        return " ".join("{}={}".format(name, value) for name, value in self.pragmas.items())

    async def create_connection(self, with_db: bool) -> None:
        """Open the shared connection unless one is already open.

        ``with_db`` is not used by this implementation. Rows are returned as
        ``sqlite3.Row`` so ``execute_query`` can convert them to dicts.
        """
        if self._connection:  # pragma: no branch
            return
        self._connection = aiosqlite.connect(self.filename, isolation_level=None)
        self._connection.start()
        await self._connection._connect()
        self._connection._conn.row_factory = sqlite3.Row
        for name, value in self.pragmas.items():
            pragma_cursor = await self._connection.execute("PRAGMA {}={}".format(name, value))
            await pragma_cursor.close()
        self.log.debug(
            "Created connection %s with params: filename=%s %s",
            self._connection,
            self.filename,
            self._describe_pragmas(),
        )

    async def close(self) -> None:
        """Close the shared connection, if any, and forget it."""
        if not self._connection:
            return
        await self._connection.close()
        self.log.debug(
            "Closed connection %s with params: filename=%s %s",
            self._connection,
            self.filename,
            self._describe_pragmas(),
        )
        self._connection = None

    async def db_create(self) -> None:
        """No-op: this client has no separate database-creation step."""

    async def db_delete(self) -> None:
        """Close the connection and delete the database file if it exists."""
        await self.close()
        try:
            os.remove(self.filename)
        except FileNotFoundError:  # pragma: nocoverage
            pass

    def acquire_connection(self) -> ConnectionWrapper:
        """Wrap the shared connection together with the shared lock."""
        return ConnectionWrapper(self._connection, self._lock)

    def _in_transaction(self) -> "TransactionWrapper":
        """Build a transaction wrapper bound to the shared connection and lock."""
        return self._transaction_class(
            connection_name=self.connection_name,
            connection=self._connection,
            lock=self._lock,
            fetch_inserted=self.fetch_inserted,
        )

    @translate_exceptions
    async def execute_insert(self, query: str, values: list) -> int:
        """Execute an INSERT and return the first element of aiosqlite's result."""
        async with self.acquire_connection() as connection:
            self.log.debug("%s: %s", query, values)
            inserted = await connection.execute_insert(query, values)
            return inserted[0]

    @translate_exceptions
    async def execute_many(self, query: str, values: List[list]) -> None:
        """Run ``executemany`` for ``query`` with each parameter list in ``values``."""
        async with self.acquire_connection() as connection:
            self.log.debug("%s: %s", query, values)
            # TODO: Ensure that this is wrapped by a transaction, will provide a big speedup
            await connection.executemany(query, values)

    @translate_exceptions
    async def execute_query(self, query: str) -> List[dict]:
        """Run ``query`` and return every row converted to a plain dict."""
        async with self.acquire_connection() as connection:
            self.log.debug(query)
            fetched = await connection.execute_fetchall(query)
            return [dict(record) for record in fetched]

    @translate_exceptions
    async def execute_script(self, query: str) -> None:
        """Run ``query`` through ``executescript``."""
        async with self.acquire_connection() as connection:
            self.log.debug(query)
            await connection.executescript(query)
Example #7
0
class MySQLClient(BaseDBAsyncClient):
    """Pooled MySQL client built on aiomysql.

    ``__init__`` only stores settings; ``create_connection`` creates the pool and
    ``close`` tears it down.
    """

    query_class = MySQLQuery
    filter_class = MySQLFilter
    executor_class = MySQLExecutor
    schema_generator = MySQLSchemaGenerator
    capabilities = Capabilities("mysql",
                                requires_limit=True,
                                inline_comment=True)

    def __init__(self, *, user: str, password: str, database: str, host: str,
                 port: SupportsInt, **kwargs) -> None:
        """Store connection settings; no connection is opened here.

        :param port: anything ``int()`` accepts; stored as ``int``.
        :param kwargs: forwarded to the parent initializer. A copy is kept in
            ``self.extra`` and passed to ``aiomysql.create_pool``, minus the keys
            this client handles itself and the keys removed below.
        """
        super().__init__(**kwargs)

        self.user = user
        self.password = password
        self.database = database
        self.host = host
        self.port = int(port)  # make sure port is int type
        self.extra = kwargs.copy()
        self.storage_engine = self.extra.pop("storage_engine", "")
        # "db" and "autocommit" are set explicitly by create_connection.
        self.extra.pop("connection_name", None)
        self.extra.pop("db", None)
        self.extra.pop("autocommit", None)
        self.extra.setdefault("sql_mode", "STRICT_TRANS_TABLES")
        self.charset = self.extra.pop("charset", "utf8mb4")
        self.pool_minsize = int(self.extra.pop("minsize", 1))
        self.pool_maxsize = int(self.extra.pop("maxsize", 5))

        self._pool: Optional[aiomysql.Pool] = None

    def _copy(self, base) -> None:
        """Copy settings from ``base``; the pool and ``extra`` dict are shared, not duplicated."""
        super()._copy(base)

        self.user = base.user
        self.password = base.password
        self.database = base.database
        self.host = base.host
        self.port = base.port
        self.extra = base.extra
        self.storage_engine = base.storage_engine
        self.charset = base.charset
        self.pool_minsize = base.pool_minsize
        self.pool_maxsize = base.pool_maxsize

        self._pool = base._pool

    async def create_connection(self, with_db: bool) -> None:
        """Create the aiomysql pool and apply the storage engine, if configured.

        :param with_db: select ``self.database`` when true; otherwise connect
            without a database (as ``db_create``/``db_delete`` do).
        :raises DBConnectionError: for an unknown charset, or when pool creation
            or setup raises ``pymysql.err.OperationalError``.
        """
        if charset_by_name(self.charset) is None:  # type: ignore
            raise DBConnectionError(f"Unknown charset {self.charset}")

        # The password is passed separately so it never lands in the template,
        # which is logged and embedded in error messages.
        pool_template = {
            "host": self.host,
            "port": self.port,
            "user": self.user,
            "db": self.database if with_db else None,
            "autocommit": True,
            "charset": self.charset,
            "minsize": self.pool_minsize,
            "maxsize": self.pool_maxsize,
            **self.extra,
        }

        try:
            self._pool = await aiomysql.create_pool(password=self.password,
                                                    **pool_template)

            if isinstance(self._pool, aiomysql.Pool):
                async with self.acquire_connection() as connection:
                    async with connection.cursor() as cursor:
                        if self.storage_engine:
                            await cursor.execute(
                                f"SET default_storage_engine='{self.storage_engine}';"
                            )
                            if self.storage_engine.lower(
                            ) != "innodb":  # pragma: nobranch
                                # NOTE(review): ``capabilities`` is a class attribute,
                                # so this write is seen by every instance sharing it.
                                self.capabilities.__dict__[
                                    "supports_transactions"] = False

            self.log.debug("Created connection %s pool with params: %s",
                           self._pool, pool_template)

        except pymysql.err.OperationalError as e:
            raise DBConnectionError(
                f"Can't connect to MySQL server: {pool_template}") from e

    async def _expire_connections(self) -> None:
        """Force every idle pooled connection to fail its next read with EOF.

        Relies on private aiomysql internals (``_free``, ``_reader``).
        """
        if self._pool:  # pragma: nobranch
            for conn in self._pool._free:
                conn._reader.set_exception(EOFError("EOF"))

    async def close(self) -> None:
        """Close the pool, wait for it to shut down and drop the reference."""
        if self._pool:  # pragma: nobranch
            self._pool.close()
            await self._pool.wait_closed()

            self.log.debug("Closed connection pool %s", self._pool)
            self._pool = None

    async def db_create(self) -> None:
        """Create ``self.database``; the temporary pool is closed even on failure."""
        await self.create_connection(with_db=False)
        try:
            await self.execute_script(f"CREATE DATABASE {self.database}")
        finally:
            await self.close()

    async def db_delete(self) -> None:
        """Drop ``self.database``, ignoring ``pymysql.err.DatabaseError``.

        The temporary pool is closed on every exit path.
        """
        await self.create_connection(with_db=False)
        try:
            await self.execute_script(f"DROP DATABASE {self.database}")
        except pymysql.err.DatabaseError:  # pragma: nocoverage
            pass
        finally:
            await self.close()

    def acquire_connection(self) -> ConnectionWrapper:
        """Return the pool's acquire context manager (requires an open pool)."""
        return self._pool.acquire()

    def in_transaction(self) -> "TransactionContext":
        """Return a lock-based transaction context wrapping this client."""
        return LockTransactionContext(TransactionWrapper(self))

    @translate_mysql_exceptions
    async def execute_insert(self, query: str, values: list) -> int:
        """Execute an INSERT and return the cursor's ``lastrowid``."""
        async with self.acquire_connection() as connection:
            self.log.debug("%s: %s", query, values)
            async with connection.cursor() as cursor:
                await cursor.execute(query, values)
                return cursor.lastrowid  # return auto-generated id

    @translate_mysql_exceptions
    async def execute_many(self, query: str, values: list) -> None:
        """Run ``executemany``; wrapped in begin/commit/rollback when transactions are supported."""
        async with self.acquire_connection() as connection:
            self.log.debug("%s: %s", query, values)
            async with connection.cursor() as cursor:
                if self.capabilities.supports_transactions:
                    await connection.begin()
                    try:
                        await cursor.executemany(query, values)
                    except Exception:
                        await connection.rollback()
                        raise
                    else:
                        await connection.commit()
                else:
                    await cursor.executemany(query, values)

    @translate_mysql_exceptions
    async def execute_query(
        self,
        query: str,
        values: Optional[list] = None
    ) -> Tuple[int, List[str], Sequence[Sequence[Any]]]:
        """Execute ``query`` and return ``(rowcount, column_names, rows)``.

        ``column_names`` is empty when no rows were fetched. Names come from the
        private aiomysql attribute ``cursor._result``.
        """
        async with self.acquire_connection() as connection:
            self.log.debug("%s: %s", query, values)
            async with connection.cursor() as cursor:
                await cursor.execute(query, values)
                rows = await cursor.fetchall()
                return cursor.rowcount, [
                    f.name for f in cursor._result.fields
                ] if rows else [], rows

    @translate_mysql_exceptions
    async def execute_script(self, query: str) -> None:
        """Execute ``query`` as a single ``cursor.execute`` call."""
        async with self.acquire_connection() as connection:
            self.log.debug(query)
            async with connection.cursor() as cursor:
                await cursor.execute(query)
Example #8
0
class AsyncpgDBClient(BaseDBAsyncClient):
    """Pooled PostgreSQL client built on asyncpg.

    When ``current_transaction`` holds a transaction wrapper, the execute_*
    methods run on that wrapper's connection instead of a pooled one.
    """

    DSN_TEMPLATE = "postgres://{user}:{password}@{host}:{port}/{database}"
    query_class = PostgreSQLQuery
    executor_class = AsyncpgExecutor
    schema_generator = AsyncpgSchemaGenerator
    capabilities = Capabilities("postgres", pooling=True)

    def __init__(self,
                 user: str,
                 password: str,
                 database: str,
                 host: str,
                 port: SupportsInt,
                 min_size: SupportsInt = 10,
                 max_size: SupportsInt = 10000,
                 max_inactive_connection_lifetime=0,
                 **kwargs) -> None:
        """Store connection and pool settings; no connection is opened here.

        :param min_size: pool ``min_size``, converted with ``int()``.
        :param max_size: pool ``max_size``, converted with ``int()``.
        :param max_inactive_connection_lifetime: forwarded to the pool, converted
            with ``int()``.
        :param kwargs: forwarded to the parent initializer; a copy minus the keys
            removed below is passed to ``asyncpg.create_pool``.
        """
        super().__init__(**kwargs)

        self.user = user
        self.password = password
        self.database = database
        self.host = host
        self.port = int(port)  # make sure port is int type
        self.extra = kwargs.copy()
        self.extra.pop("connection_name", None)
        self.extra.pop("fetch_inserted", None)
        self.extra.pop("loop", None)
        self.extra.pop("connection_class", None)
        self._pool = None  # type: Optional[asyncpg.pool.Pool]
        self.min_size = int(min_size)
        self.max_size = int(max_size)
        self.max_inactive_connection_lifetime = int(
            max_inactive_connection_lifetime)

        self._template = {}  # type: dict

        self._transaction_class = type("TransactionWrapper",
                                       (TransactionWrapper, self.__class__),
                                       {})

    async def create_connection(self, with_db: bool) -> None:
        """Create the asyncpg pool.

        :param with_db: connect to ``self.database`` when true; otherwise pass
            ``database=None`` (as ``db_create``/``db_delete`` do).
        :raises DBConnectionError: when asyncpg raises ``InvalidCatalogNameError``.
        """
        # The password is passed separately so it never lands in the template,
        # which is logged.
        self._template = {
            "host": self.host,
            "port": self.port,
            "user": self.user,
            "database": self.database if with_db else None,
            "min_size": self.min_size,
            "max_size": self.max_size,
            "max_inactive_connection_lifetime":
            self.max_inactive_connection_lifetime,
            **self.extra,
        }
        try:
            self._pool = await asyncpg.create_pool(None,
                                                   password=self.password,
                                                   **self._template)
            self.log.debug("Created pool %s with params: %s", self._pool,
                           self._template)
        except asyncpg.InvalidCatalogNameError as e:
            raise DBConnectionError(
                "Can't establish connection to database {}".format(
                    self.database)) from e

    async def _close(self) -> None:
        """Close the pool gracefully, terminating it after a 10 second timeout."""
        if self._pool:  # pragma: nobranch
            try:
                await asyncio.wait_for(self._pool.close(), 10)
            except asyncio.TimeoutError:
                # asyncpg's Pool.terminate() is a plain method, not a coroutine;
                # awaiting its None result would raise TypeError.
                self._pool.terminate()
            self.log.debug("Closed pool %s with params: %s", self._pool,
                           self._template)
            self._template.clear()

    async def close(self) -> None:
        """Close the pool and drop the reference to it."""
        await self._close()
        self._pool = None

    async def db_create(self) -> None:
        """Create ``self.database`` owned by ``self.user``; always closes the temporary pool."""
        await self.create_connection(with_db=False)
        try:
            await self.execute_script('CREATE DATABASE "{}" OWNER "{}"'.format(
                self.database, self.user))
        finally:
            await self.close()

    async def db_delete(self) -> None:
        """Drop ``self.database``, ignoring ``InvalidCatalogNameError``.

        The temporary pool is closed on every exit path.
        """
        await self.create_connection(with_db=False)
        try:
            await self.execute_script('DROP DATABASE "{}"'.format(
                self.database))
        except asyncpg.InvalidCatalogNameError:  # pragma: nocoverage
            pass
        finally:
            await self.close()

    def acquire_connection(self) -> "AsyncContextManager":
        """Return a dispatcher that acquires a connection from the pool."""
        return PoolConnectionDispatcher(self._pool)

    def _in_transaction(self) -> "TransactionWrapper":
        """Create a transaction wrapper and publish it via ``current_transaction``."""
        transaction_wrapper = self._transaction_class(None, self)
        current_transaction.set(transaction_wrapper)
        return transaction_wrapper

    @translate_exceptions
    @retry_connection
    async def execute_insert(self, query: str,
                             values: list) -> Optional[asyncpg.Record]:
        """Prepare ``query`` and return the first row produced with ``values``."""
        transaction_wrapper = current_transaction.get()
        if transaction_wrapper:
            self.log.debug("%s: %s", query, values)
            stmt = await transaction_wrapper._connection.prepare(query)
            return await stmt.fetchrow(*values)

        async with self.acquire_connection() as connection:
            self.log.debug("%s: %s", query, values)
            # TODO: Cache prepared statement
            stmt = await connection.prepare(query)
            return await stmt.fetchrow(*values)

    @translate_exceptions
    @retry_connection
    async def execute_many(self, query: str, values: list) -> None:
        """Run ``executemany`` on the active transaction's connection or a pooled one."""
        transaction_wrapper = current_transaction.get()
        if transaction_wrapper:
            self.log.debug("%s: %s", query, values)
            # TODO: Consider using copy_records_to_table instead
            await transaction_wrapper._connection.executemany(query, values)
            # Without this return the statements would run a second time below.
            return

        async with self.acquire_connection() as connection:
            self.log.debug("%s: %s", query, values)
            # TODO: Consider using copy_records_to_table instead
            await connection.executemany(query, values)

    @translate_exceptions
    @retry_connection
    async def execute_query(self, query: str) -> List[dict]:
        """Fetch all rows of ``query`` on the active transaction's connection or a pooled one."""
        transaction_wrapper = current_transaction.get()
        if transaction_wrapper:
            self.log.debug(query)
            return await transaction_wrapper._connection.fetch(query)

        async with self.acquire_connection() as connection:
            self.log.debug(query)
            return await connection.fetch(query)

    @translate_exceptions
    @retry_connection
    async def execute_script(self, query: str) -> None:
        """Execute ``query`` on the active transaction's connection or a pooled one."""
        transaction_wrapper = current_transaction.get()
        if transaction_wrapper:
            self.log.debug(query)
            await transaction_wrapper._connection.execute(query)
            # Without this return the script would run a second time below.
            return

        async with self.acquire_connection() as connection:
            self.log.debug(query)
            await connection.execute(query)
Example #9
0
class AsyncpgDBClient(BaseDBAsyncClient):
    """PostgreSQL client built on a single asyncpg connection."""

    DSN_TEMPLATE = 'postgres://{user}:{password}@{host}:{port}/{database}'
    query_class = PostgreSQLQuery
    executor_class = AsyncpgExecutor
    schema_generator = AsyncpgSchemaGenerator
    capabilities = Capabilities('postgres')

    def __init__(self, user: str, password: str, database: str, host: str,
                 port: SupportsInt, **kwargs) -> None:
        """Store connection settings; no connection is opened here.

        :param port: anything ``int()`` accepts; stored as ``int``.
        :param kwargs: forwarded unchanged to the parent initializer.
        """
        super().__init__(**kwargs)

        self.user = user
        self.password = password
        self.database = database
        self.host = host
        self.port = int(port)  # make sure port is int type

        # Holds the asyncpg connection once create_connection succeeds.
        self._connection = None

        self._transaction_class = type('TransactionWrapper',
                                       (TransactionWrapper, self.__class__),
                                       {})

    async def create_connection(self, with_db: bool) -> None:
        """Open the asyncpg connection.

        :param with_db: connect to ``self.database`` when true; otherwise leave
            the DSN database part empty (as ``db_create``/``db_delete`` do).
        :raises DBConnectionError: when asyncpg raises ``InvalidCatalogNameError``.
        """
        # The password goes through asyncpg's ``password`` keyword rather than
        # the URL: characters such as "@", "/" or ":" in a password would
        # otherwise corrupt DSN parsing. The explicit keyword takes precedence
        # over the (empty) password in the DSN.
        dsn = self.DSN_TEMPLATE.format(
            user=self.user,
            password='',
            host=self.host,
            port=self.port,
            database=self.database if with_db else '')
        try:
            self._connection = await asyncpg.connect(dsn, password=self.password)
            self.log.debug(
                'Created connection %s with params: user=%s database=%s host=%s port=%s',
                self._connection, self.user, self.database, self.host,
                self.port)
        except asyncpg.InvalidCatalogNameError as e:
            raise DBConnectionError(
                "Can't establish connection to database {}".format(
                    self.database)) from e

    async def close(self) -> None:
        """Close the connection, if any, and forget it."""
        if self._connection:  # pragma: nobranch
            await self._connection.close()
            self.log.debug(
                'Closed connection %s with params: user=%s database=%s host=%s port=%s',
                self._connection, self.user, self.database, self.host,
                self.port)
            self._connection = None

    async def db_create(self) -> None:
        """Create ``self.database`` owned by ``self.user``; always closes the temporary connection."""
        await self.create_connection(with_db=False)
        try:
            await self.execute_script('CREATE DATABASE "{}" OWNER "{}"'.format(
                self.database, self.user))
        finally:
            await self.close()

    async def db_delete(self) -> None:
        """Drop ``self.database``, ignoring ``InvalidCatalogNameError``.

        The temporary connection is closed on every exit path.
        """
        await self.create_connection(with_db=False)
        try:
            await self.execute_script('DROP DATABASE "{}"'.format(
                self.database))
        except asyncpg.InvalidCatalogNameError:  # pragma: nocoverage
            pass
        finally:
            await self.close()

    def acquire_connection(self) -> ConnectionWrapper:
        """Wrap the single shared connection."""
        return ConnectionWrapper(self._connection)

    def _in_transaction(self) -> 'TransactionWrapper':
        """Build a transaction wrapper bound to the shared connection."""
        return self._transaction_class(self.connection_name, self._connection)

    @translate_exceptions
    async def execute_insert(self, query: str, values: list) -> int:
        """Prepare ``query`` and return the first column of its first row."""
        async with self.acquire_connection() as connection:
            self.log.debug('%s: %s', query, values)
            # TODO: Cache prepared statement
            stmt = await connection.prepare(query)
            return await stmt.fetchval(*values)

    @translate_exceptions
    async def execute_query(self, query: str) -> List[dict]:
        """Fetch all rows produced by ``query``."""
        async with self.acquire_connection() as connection:
            self.log.debug(query)
            return await connection.fetch(query)

    @translate_exceptions
    async def execute_script(self, query: str) -> None:
        """Execute ``query`` without fetching results."""
        async with self.acquire_connection() as connection:
            self.log.debug(query)
            await connection.execute(query)
Example #10
0
class SqliteClient(BaseDBAsyncClient):
    """Async SQLite client built around a single ``aiosqlite`` connection.

    Every operation goes through ``ConnectionWrapper`` together with
    ``self._lock``. Keyword arguments not reserved for the base client
    are applied as ``PRAGMA name=value`` statements right after connecting.
    """

    executor_class = SqliteExecutor
    schema_generator = SqliteSchemaGenerator
    capabilities = Capabilities("sqlite", daemon=False, requires_limit=True, inline_comment=True)

    def __init__(self, file_path: str, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.filename = file_path

        # Whatever is left of kwargs after removing the base-client options
        # is treated as a PRAGMA name -> value mapping.
        pragmas = dict(kwargs)
        for reserved in ("connection_name", "fetch_inserted"):
            pragmas.pop(reserved, None)
        defaults = (
            ("journal_mode", "WAL"),
            ("journal_size_limit", 16384),
            ("foreign_keys", "ON"),
        )
        for name, value in defaults:
            pragmas.setdefault(name, value)
        self.pragmas = pragmas

        self._connection: Optional[aiosqlite.Connection] = None
        self._lock = asyncio.Lock()

    def _pragma_summary(self) -> str:
        """Render the configured pragmas as space-separated ``key=value`` pairs."""
        return " ".join(f"{name}={value}" for name, value in self.pragmas.items())

    async def create_connection(self, with_db: bool) -> None:
        """Open the connection once and apply all configured pragmas.

        ``with_db`` is not used. If a connection already exists, this does
        nothing.
        """
        if self._connection:  # pragma: no branch
            return
        conn = aiosqlite.connect(self.filename, isolation_level=None)
        self._connection = conn
        conn.start()
        await conn._connect()
        conn._conn.row_factory = sqlite3.Row
        for name, value in self.pragmas.items():
            cursor = await conn.execute(f"PRAGMA {name}={value}")
            await cursor.close()
        self.log.debug(
            "Created connection %s with params: filename=%s %s",
            conn,
            self.filename,
            self._pragma_summary(),
        )

    async def close(self) -> None:
        """Close the connection if open and forget it."""
        conn = self._connection
        if not conn:
            return
        await conn.close()
        self.log.debug(
            "Closed connection %s with params: filename=%s %s",
            conn,
            self.filename,
            self._pragma_summary(),
        )
        self._connection = None

    async def db_create(self) -> None:
        """No-op: the database file is created by connecting to it."""

    async def db_delete(self) -> None:
        """Close the connection and remove the database file if it exists."""
        await self.close()
        try:
            os.remove(self.filename)
        except FileNotFoundError:  # pragma: nocoverage
            pass

    def acquire_connection(self) -> ConnectionWrapper:
        """Wrap the shared connection together with the client lock."""
        return ConnectionWrapper(self._connection, self._lock)

    def _in_transaction(self) -> "TransactionContext":
        """Return a transaction context bound to this client."""
        wrapper = TransactionWrapper(self)
        return TransactionContext(wrapper)

    @translate_exceptions
    async def execute_insert(self, query: str, values: list) -> int:
        """Execute an INSERT and return the first element of aiosqlite's result."""
        async with self.acquire_connection() as conn:
            self.log.debug("%s: %s", query, values)
            inserted = await conn.execute_insert(query, values)
            return inserted[0]

    @translate_exceptions
    async def execute_many(self, query: str, values: List[list]) -> None:
        """Run ``query`` for every parameter list inside one explicit transaction."""
        async with self.acquire_connection() as conn:
            self.log.debug("%s: %s", query, values)
            # This code is only ever called in AUTOCOMMIT mode, so the batch
            # is wrapped in its own BEGIN/COMMIT.
            await conn.execute("BEGIN")
            try:
                await conn.executemany(query, values)
            except Exception:
                await conn.rollback()
                raise
            await conn.commit()

    @translate_exceptions
    async def execute_query(
        self, query: str, values: Optional[list] = None
    ) -> Tuple[int, Sequence[dict]]:
        """Return ``(affected, rows)``.

        ``affected`` is the change in ``total_changes`` caused by the query.
        When that is zero, it falls back to the number of fetched rows.
        """
        async with self.acquire_connection() as conn:
            self.log.debug("%s: %s", query, values)
            changes_before = conn.total_changes
            rows = await conn.execute_fetchall(query, values)
            changed = conn.total_changes - changes_before
            return (changed or len(rows)), rows

    @translate_exceptions
    async def execute_query_dict(self, query: str, values: Optional[list] = None) -> List[dict]:
        """Run ``query`` and return each row converted to a plain ``dict``."""
        async with self.acquire_connection() as conn:
            self.log.debug("%s: %s", query, values)
            rows = await conn.execute_fetchall(query, values)
            return [dict(row) for row in rows]

    @translate_exceptions
    async def execute_script(self, query: str) -> None:
        """Execute a multi-statement SQL script."""
        async with self.acquire_connection() as conn:
            self.log.debug(query)
            await conn.executescript(query)
Example #11
0
class BasePostgresClient(BaseDBAsyncClient, abc.ABC):
    """Shared option parsing and lifecycle helpers for PostgreSQL clients.

    Concrete subclasses provide connection/pool creation, transactions and
    query execution. This base parses the common keyword options and
    implements database create/drop on top of ``execute_script``.
    """
    DSN_TEMPLATE = "postgres://{user}:{password}@{host}:{port}/{database}"
    query_class: Type[PostgreSQLQuery] = PostgreSQLQuery
    executor_class: Type[BasePostgresExecutor] = BasePostgresExecutor
    schema_generator: Type[
        BasePostgresSchemaGenerator] = BasePostgresSchemaGenerator
    capabilities = Capabilities("postgres",
                                support_update_limit_order_by=False)
    connection_class = None
    loop = None
    _pool: Optional[Any] = None
    _connection: Optional[Any] = None

    def __init__(
        self,
        user: Optional[str] = None,
        password: Optional[str] = None,
        database: Optional[str] = None,
        host: Optional[str] = None,
        port: SupportsInt = 5432,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)

        self.user = user
        self.password = password
        self.database = database
        self.host = host
        self.port = int(port)  # make sure port is int type
        self.extra = kwargs.copy()
        # we can't deep copy kwargs because of ssl context
        # since server_settings is a dict, we copy it again
        self.server_settings = (self.extra.pop("server_settings", None)
                                or {}).copy()
        self.schema = self.extra.pop("schema", None)
        self.application_name = self.extra.pop("application_name", None)
        # Base-client options that must not leak into the driver kwargs.
        self.extra.pop("connection_name", None)
        self.extra.pop("fetch_inserted", None)
        self.loop = self.extra.pop("loop", None)
        self.connection_class = self.extra.pop("connection_class",
                                               self.connection_class)
        self.pool_minsize = int(self.extra.pop("minsize", 1))
        self.pool_maxsize = int(self.extra.pop("maxsize", 5))

        self._template: dict = {}
        self._pool = None
        self._connection = None

    @abc.abstractmethod
    async def create_connection(self, with_db: bool) -> None:
        """Open the driver connection or pool; ``with_db`` selects the database."""
        raise NotImplementedError("create_connection is not implemented")

    @abc.abstractmethod
    async def create_pool(self, **kwargs):
        """Create and return a driver-specific connection pool."""
        raise NotImplementedError("create_pool is not implemented")

    @abc.abstractmethod
    async def _expire_connections(self) -> None:
        """Mark pooled connections as expired."""
        raise NotImplementedError("_expire_connections is not implemented")

    @abc.abstractmethod
    async def _close(self) -> None:
        """Release the driver connection or pool."""
        raise NotImplementedError("_close is not implemented")

    @abc.abstractmethod
    async def _translate_exceptions(self, func, *args, **kwargs) -> Exception:
        """Call ``func`` and map driver exceptions to ORM exceptions."""
        raise NotImplementedError("translate_exceptions is not implemented")

    async def close(self) -> None:
        """Release driver resources and clear the cached connection template."""
        await self._close()
        self._template.clear()

    async def db_create(self) -> None:
        """Create ``self.database`` owned by ``self.user``.

        The server-level connection is closed in ``finally`` so that a failed
        ``CREATE DATABASE`` does not leave it open. This matches ``db_delete``.
        """
        await self.create_connection(with_db=False)
        try:
            await self.execute_script(
                f'CREATE DATABASE "{self.database}" OWNER "{self.user}"')
        finally:
            await self.close()

    async def db_delete(self) -> None:
        """Drop ``self.database``, always closing the connection afterwards."""
        await self.create_connection(with_db=False)
        try:
            await self.execute_script(f'DROP DATABASE "{self.database}"')
        finally:
            await self.close()

    def acquire_connection(
            self) -> Union["ConnectionWrapper", "PoolConnectionWrapper"]:
        """Return a context manager that checks a connection out of the pool."""
        return PoolConnectionWrapper(self._pool)

    @abc.abstractmethod
    def _in_transaction(self) -> "TransactionContext":
        """Return a transaction context for this client."""
        raise NotImplementedError("_in_transaction is not implemented")

    @abc.abstractmethod
    async def execute_insert(self, query: str, values: list) -> Optional[Any]:
        """Execute an INSERT and return the driver's inserted-row result."""
        raise NotImplementedError("execute_insert is not implemented")

    @abc.abstractmethod
    async def execute_many(self, query: str, values: list) -> None:
        """Execute ``query`` once per parameter set in ``values``."""
        raise NotImplementedError("execute_many is not implemented")

    @abc.abstractmethod
    async def execute_query(
            self,
            query: str,
            values: Optional[list] = None) -> Tuple[int, List[dict]]:
        """Execute ``query`` and return ``(affected_rows, rows)``."""
        raise NotImplementedError("execute_query is not implemented")

    @abc.abstractmethod
    async def execute_query_dict(self,
                                 query: str,
                                 values: Optional[list] = None) -> List[dict]:
        """Execute ``query`` and return the rows as dicts."""
        raise NotImplementedError("execute_query_dict is not implemented")

    @translate_exceptions
    async def execute_script(self, query: str) -> None:
        """Execute raw SQL on a pooled connection, discarding any result."""
        async with self.acquire_connection() as connection:
            self.log.debug(query)
            await connection.execute(query)
Example #12
0
class MySQLClient(BaseDBAsyncClient):
    """MySQL client sharing one autocommit ``aiomysql`` connection.

    ``self._lock`` is passed to every ``ConnectionWrapper`` and transaction
    wrapper created here, presumably to serialise use of the single
    connection (confirm against ``ConnectionWrapper``).
    """
    query_class = MySQLQuery
    executor_class = MySQLExecutor
    schema_generator = MySQLSchemaGenerator
    capabilities = Capabilities('mysql', safe_indexes=False, requires_limit=True)

    def __init__(self, user: str, password: str, database: str, host: str, port: SupportsInt,
                 **kwargs) -> None:
        super().__init__(**kwargs)

        self.user = user
        self.password = password
        self.database = database
        self.host = host
        self.port = int(port)  # make sure port is int type

        self._connection = None  # type: Optional[aiomysql.Connection]
        self._lock = asyncio.Lock()

        # Per-instance class deriving from both TransactionWrapper and this
        # client's class; instantiated by _in_transaction().
        self._transaction_class = type(
            'TransactionWrapper', (TransactionWrapper, self.__class__), {}
        )

    async def create_connection(self, with_db: bool) -> None:
        """Open the autocommit connection.

        Args:
            with_db: select ``self.database`` when true; otherwise connect with
                ``db=None`` (used by ``db_create``/``db_delete``).

        Raises:
            DBConnectionError: if pymysql reports an ``OperationalError``.
        """
        template = {
            'host': self.host,
            'port': self.port,
            'user': self.user,
            'password': self.password,
            'db': self.database if with_db else None,
            'autocommit': True,
        }

        try:
            self._connection = await aiomysql.connect(**template)
            self.log.debug(
                'Created connection %s with params: user=%s database=%s host=%s port=%s',
                self._connection, self.user, self.database, self.host, self.port
            )
        except pymysql.err.OperationalError as exc:
            raise DBConnectionError(
                "Can't connect to MySQL server: "
                'user={user} database={database} host={host} port={port}'.format(
                    user=self.user, database=self.database, host=self.host, port=self.port
                )
            ) from exc

    async def close(self) -> None:
        """Close the connection if open (``aiomysql`` close is synchronous)."""
        if self._connection:  # pragma: nobranch
            self._connection.close()
            self.log.debug(
                'Closed connection %s with params: user=%s database=%s host=%s port=%s',
                self._connection, self.user, self.database, self.host, self.port
            )
            self._connection = None

    async def db_create(self) -> None:
        """Create ``self.database``; the connection is closed even on failure."""
        await self.create_connection(with_db=False)
        try:
            await self.execute_script(
                'CREATE DATABASE {}'.format(self.database)
            )
        finally:
            await self.close()

    async def db_delete(self) -> None:
        """Drop ``self.database``, ignoring ``DatabaseError``; always closes."""
        await self.create_connection(with_db=False)
        try:
            await self.execute_script('DROP DATABASE {}'.format(self.database))
        except pymysql.err.DatabaseError:  # pragma: nocoverage
            pass
        finally:
            await self.close()

    def acquire_connection(self) -> ConnectionWrapper:
        """Wrap the shared connection together with the client lock."""
        return ConnectionWrapper(self._connection, self._lock)

    def _in_transaction(self) -> "TransactionWrapper":
        """Create a transaction wrapper bound to the shared connection and lock."""
        return self._transaction_class(self.connection_name, self._connection, self._lock)

    @translate_exceptions
    async def execute_insert(self, query: str, values: list) -> int:
        """Execute an INSERT and return ``cursor.lastrowid``."""
        async with self.acquire_connection() as connection:
            self.log.debug('%s: %s', query, values)
            async with connection.cursor() as cursor:
                # TODO: Use prepared statement, and cache it
                await cursor.execute(query, values)
                return cursor.lastrowid  # return auto-generated id

    @translate_exceptions
    async def execute_query(self, query: str) -> List[dict]:
        """Execute ``query`` and return all rows as dicts (via ``DictCursor``)."""
        async with self.acquire_connection() as connection:
            self.log.debug(query)
            async with connection.cursor(aiomysql.DictCursor) as cursor:
                await cursor.execute(query)
                return await cursor.fetchall()

    @translate_exceptions
    async def execute_script(self, query: str) -> None:
        """Execute raw SQL, discarding any result."""
        async with self.acquire_connection() as connection:
            self.log.debug(query)
            async with connection.cursor() as cursor:
                await cursor.execute(query)
Example #13
0
class AsyncpgDBClient(BaseDBAsyncClient):
    """PostgreSQL client sharing one ``asyncpg`` connection guarded by a lock.

    Keyword arguments not consumed here are forwarded verbatim to
    ``asyncpg.connect``. ``schema`` is applied as the ``search_path`` after
    connecting.
    """
    DSN_TEMPLATE = "postgres://{user}:{password}@{host}:{port}/{database}"
    query_class = PostgreSQLQuery
    executor_class = AsyncpgExecutor
    schema_generator = AsyncpgSchemaGenerator
    capabilities = Capabilities("postgres")

    def __init__(self, user: str, password: str, database: str, host: str,
                 port: SupportsInt, **kwargs) -> None:
        super().__init__(**kwargs)

        self.user = user
        self.password = password
        self.database = database
        self.host = host
        self.port = int(port)  # make sure port is int type
        self.extra = kwargs.copy()
        self.schema = self.extra.pop("schema", None)
        # Base-client options that must not be passed to asyncpg.connect.
        self.extra.pop("connection_name", None)
        self.extra.pop("fetch_inserted", None)
        self.extra.pop("loop", None)
        self.extra.pop("connection_class", None)

        self._template: dict = {}
        self._connection: Optional[asyncpg.Connection] = None
        self._lock = asyncio.Lock()

        # Per-instance class deriving from both TransactionWrapper and this
        # client's class; instantiated by _in_transaction().
        self._transaction_class = type("TransactionWrapper",
                                       (TransactionWrapper, self.__class__),
                                       {})

    async def create_connection(self, with_db: bool) -> None:
        """Open the connection and apply ``search_path`` if a schema is set.

        Args:
            with_db: connect to ``self.database`` when true; otherwise pass
                ``database=None`` to asyncpg.

        Raises:
            DBConnectionError: if the target database does not exist.
        """
        self._template = {
            "host": self.host,
            "port": self.port,
            "user": self.user,
            "database": self.database if with_db else None,
            **self.extra,
        }
        try:
            self._connection = await asyncpg.connect(None,
                                                     password=self.password,
                                                     **self._template)
            self.log.debug("Created connection %s with params: %s",
                           self._connection, self._template)
        except asyncpg.InvalidCatalogNameError as exc:
            raise DBConnectionError(
                f"Can't establish connection to database {self.database}"
            ) from exc
        # Set post-connection variables
        if self.schema:
            await self.execute_script(f"SET search_path TO {self.schema}")

    async def _close(self) -> None:
        """Close the connection and clear the template; keeps the reference."""
        if self._connection:  # pragma: nobranch
            await self._connection.close()
            self.log.debug("Closed connection %s with params: %s",
                           self._connection, self._template)
            self._template.clear()

    async def close(self) -> None:
        """Close the connection and drop the reference to it."""
        await self._close()
        self._connection = None

    async def db_create(self) -> None:
        """Create ``self.database``; the connection is closed even on failure."""
        await self.create_connection(with_db=False)
        try:
            await self.execute_script(
                f'CREATE DATABASE "{self.database}" OWNER "{self.user}"')
        finally:
            await self.close()

    async def db_delete(self) -> None:
        """Drop ``self.database``, tolerating a missing database; always closes."""
        await self.create_connection(with_db=False)
        try:
            await self.execute_script(f'DROP DATABASE "{self.database}"')
        except asyncpg.InvalidCatalogNameError:  # pragma: nocoverage
            pass
        finally:
            await self.close()

    def acquire_connection(self) -> ConnectionWrapper:
        """Wrap the shared connection together with the client lock."""
        return ConnectionWrapper(self._connection, self._lock)

    def _in_transaction(self) -> "TransactionWrapper":
        """Create a transaction wrapper bound to this client."""
        return self._transaction_class(self)

    @translate_exceptions
    @retry_connection
    async def execute_insert(self, query: str,
                             values: list) -> Optional[asyncpg.Record]:
        """Execute an INSERT via a prepared statement; return its first row."""
        async with self.acquire_connection() as connection:
            self.log.debug("%s: %s", query, values)
            # TODO: Cache prepared statement
            stmt = await connection.prepare(query)
            return await stmt.fetchrow(*values)

    @translate_exceptions
    @retry_connection
    async def execute_many(self, query: str, values: list) -> None:
        """Execute ``query`` once per parameter set in ``values``."""
        async with self.acquire_connection() as connection:
            self.log.debug("%s: %s", query, values)
            # TODO: Consider using copy_records_to_table instead
            await connection.executemany(query, values)

    @translate_exceptions
    @retry_connection
    async def execute_query(self,
                            query: str,
                            values: Optional[list] = None) -> List[dict]:
        """Execute ``query`` and return all fetched records.

        Parameterised queries go through a prepared statement; queries without
        values are fetched directly.
        """
        async with self.acquire_connection() as connection:
            self.log.debug("%s: %s", query, values)
            if values:
                # TODO: Cache prepared statement
                stmt = await connection.prepare(query)
                return await stmt.fetch(*values)
            return await connection.fetch(query)

    @translate_exceptions
    @retry_connection
    async def execute_script(self, query: str) -> None:
        """Execute raw SQL, discarding any result."""
        async with self.acquire_connection() as connection:
            self.log.debug(query)
            await connection.execute(query)
Example #14
0
class SqliteClient(BaseDBAsyncClient):
    """SQLite client sharing a single ``aiosqlite`` connection with a lock.

    The lock is handed to every ``ConnectionWrapper`` and transaction wrapper
    this client creates.
    """

    executor_class = SqliteExecutor
    schema_generator = SqliteSchemaGenerator
    capabilities = Capabilities('sqlite', requires_limit=True)

    def __init__(self, file_path: str, **kwargs) -> None:
        super().__init__(**kwargs)
        self.filename = file_path
        # Per-instance class deriving from both TransactionWrapper and this
        # client's class; instantiated by _in_transaction().
        self._transaction_class = type(
            'TransactionWrapper', (TransactionWrapper, self.__class__), {}
        )
        self._connection = None  # type: Optional[aiosqlite.Connection]
        self._lock = asyncio.Lock()

    async def create_connection(self, with_db: bool) -> None:
        """Open the connection once, with ``sqlite3.Row`` rows; ``with_db`` unused."""
        if self._connection:  # pragma: no branch
            return
        conn = aiosqlite.connect(self.filename, isolation_level=None)
        self._connection = conn
        conn.start()
        await conn._connect()
        conn._conn.row_factory = sqlite3.Row
        self.log.debug(
            'Created connection %s with params: filename=%s',
            conn, self.filename
        )

    async def close(self) -> None:
        """Close the connection if it is open and forget it."""
        conn = self._connection
        if not conn:
            return
        await conn.close()
        self.log.debug(
            'Closed connection %s with params: filename=%s',
            conn, self.filename
        )
        self._connection = None

    async def db_create(self) -> None:
        """No-op: connecting creates the database file."""

    async def db_delete(self) -> None:
        """Close the connection and delete the database file if present."""
        await self.close()
        try:
            os.remove(self.filename)
        except FileNotFoundError:  # pragma: nocoverage
            pass

    def acquire_connection(self) -> ConnectionWrapper:
        """Wrap the shared connection together with the client lock."""
        return ConnectionWrapper(self._connection, self._lock)

    def _in_transaction(self) -> 'TransactionWrapper':
        """Create a transaction wrapper bound to the shared connection and lock."""
        wrapper_cls = self._transaction_class
        return wrapper_cls(self.connection_name, self._connection, self._lock)

    @translate_exceptions
    async def execute_insert(self, query: str, values: list) -> int:
        """Execute an INSERT and return the first element of aiosqlite's result."""
        async with self.acquire_connection() as conn:
            self.log.debug('%s: %s', query, values)
            inserted = await conn.execute_insert(query, values)
            return inserted[0]

    @translate_exceptions
    async def execute_query(self, query: str) -> List[dict]:
        """Run ``query`` and return each row as a plain ``dict``."""
        async with self.acquire_connection() as conn:
            self.log.debug(query)
            rows = await conn.execute_fetchall(query)
            return list(map(dict, rows))

    @translate_exceptions
    async def execute_script(self, query: str) -> None:
        """Execute a multi-statement SQL script."""
        async with self.acquire_connection() as conn:
            self.log.debug(query)
            await conn.executescript(query)
Example #15
0
class MSSqlClient(BaseDBAsyncClient):
    """Async Microsoft SQL Server client backed by an ``aioodbc`` connection pool.

    Connections use an ODBC DSN. It is either the ``dsn`` keyword argument
    (URL-unquoted) or a default ``DSN=MYMSSQL`` string built from the
    database, user and password. Pool sizing and timeouts come from the
    ``minsize``, ``maxsize``, ``pool_recycle`` and ``timeout`` options.
    """
    query_class = MSSQLQuery
    executor_class = MSSQLExecutor
    schema_generator = MSSQLSchemaGenerator
    capabilities = Capabilities("mssql",
                                requires_limit=True,
                                inline_comment=True)

    def __init__(
        self,
        *,
        user: str,
        password: str,
        database: str,
        host: str,
        port: SupportsInt,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)

        self.user = user
        self.password = password
        self.database = database
        self.host = host
        self.port = int(port)  # make sure port is int type
        self.extra = kwargs.copy()
        self.storage_engine = self.extra.pop("storage_engine", "")
        self.extra.pop("connection_name", None)
        self.extra.pop("fetch_inserted", None)
        self.extra.pop("db", None)
        self.extra.pop("autocommit", None)
        # NOTE(review): sql_mode / charset / storage_engine look carried over
        # from the MySQL client; ``extra`` and ``charset`` are not used by the
        # pool parameters below.
        self.extra.setdefault("sql_mode", "STRICT_TRANS_TABLES")
        self.charset = self.extra.pop("charset", "utf8mb4")
        self.pool_minsize = int(self.extra.pop("minsize", 1))
        self.pool_maxsize = int(self.extra.pop("maxsize", 5))
        self.pool_recycle = int(self.extra.pop("pool_recycle", 5))
        self.timeout = int(self.extra.pop("timeout", 5))

        dsn = self.extra.pop("dsn", "")
        if dsn:
            self.dsn = urllib.parse.unquote(dsn)
        else:
            self.dsn = "DSN=MYMSSQL;DATABASE={DB_DATABASE};UID={DB_USER};PWD={DB_PASSWORD}".format(
                DB_DATABASE=database, DB_USER=user, DB_PASSWORD=password)
        self._template: dict = {}
        self._pool: Optional[aioodbc.Pool] = None
        self._connection = None

    def _loggable_template(self) -> dict:
        """Return the pool parameters without ``dsn``, which can embed ``PWD``."""
        return {key: value for key, value in self._template.items() if key != "dsn"}

    async def create_connection(self, with_db: bool) -> None:
        """Create the ``aioodbc`` pool and apply the optional storage engine.

        ``with_db`` is ignored; the database is whatever the DSN selects.

        Raises:
            DBConnectionError: if ODBC reports an ``OperationalError``.
        """
        import copy  # local: only used to detach the shared capabilities

        self._template = {
            "dsn": self.dsn,
            "minsize": self.pool_minsize,
            "maxsize": self.pool_maxsize,
            "executor": _g_thread_executor,
            "pool_recycle": self.pool_recycle,
            "timeout": self.timeout,
        }
        try:
            self._pool = await aioodbc.create_pool(**self._template)

            if isinstance(self._pool, aioodbc.Pool):
                async with self.acquire_connection() as connection:
                    async with connection.cursor() as cursor:
                        if self.storage_engine:
                            await cursor.execute(
                                f"SET default_storage_engine='{self.storage_engine}';"
                            )
                            if self.storage_engine.lower(
                            ) != "innodb":  # pragma: nobranch
                                # ``capabilities`` is a class attribute shared by
                                # every client; copy it before mutating so this
                                # instance's setting does not leak into others.
                                self.capabilities = copy.copy(self.capabilities)
                                self.capabilities.__dict__[
                                    "supports_transactions"] = False

            # The DSN is omitted from logs/errors because it may contain the password.
            self.log.debug("Created connection %s pool with params: %s",
                           self._pool, self._loggable_template())
        except pyodbc.OperationalError as exc:
            raise DBConnectionError(
                f"Can't connect to MSSQL server: {self._loggable_template()}"
            ) from exc

    async def _expire_connections(self) -> None:
        """Fail the reader of every idle pooled connection with ``EOFError``."""
        # NOTE(review): ``_reader`` is an aiomysql connection attribute; confirm
        # aioodbc pool connections expose it, otherwise this raises AttributeError.
        if self._pool:  # pragma: nobranch
            for conn in self._pool._free:
                conn._reader.set_exception(EOFError("EOF"))

    async def _close(self) -> None:
        """Close the pool, wait for it to shut down and drop the reference."""
        if self._pool:  # pragma: nobranch
            self._pool.close()
            await self._pool.wait_closed()
            self.log.debug("Closed connection %s with params: %s",
                           self._pool, self._loggable_template())
            self._pool = None

    async def close(self) -> None:
        """Close the pool and clear the cached pool parameters."""
        await self._close()
        self._template.clear()

    async def db_create(self) -> None:
        """Create ``self.database``; the pool is closed even on failure."""
        await self.create_connection(with_db=False)
        try:
            await self.execute_script(f"CREATE DATABASE {self.database}")
        finally:
            await self.close()

    async def db_delete(self) -> None:
        """Drop ``self.database``, ignoring ``DatabaseError``; always closes."""
        await self.create_connection(with_db=False)

        try:
            await self.execute_script(f"DROP DATABASE {self.database}")
        except pyodbc.DatabaseError:  # pragma: nocoverage
            pass
        finally:
            await self.close()

    def acquire_connection(
            self) -> Union["ConnectionWrapper", "PoolConnectionWrapper"]:
        """Return a context manager that checks a connection out of the pool."""
        return PoolConnectionWrapper(self._pool)

    def _in_transaction(self) -> "TransactionContext":
        """Return a pooled transaction context for this client."""
        return TransactionContextPooled(TransactionWrapper(self))

    @translate_exceptions
    async def execute_insert(self, query: str, values: list) -> Optional[int]:
        """Execute an INSERT and return the value of its ``OUTPUT INSERTED`` clause.

        When the query contains ``OUTPUT INSERTED`` (matched
        case-insensitively), the first column of the first returned row is
        returned. Otherwise the result is ``None``.
        """
        async with self.acquire_connection() as connection:
            self.log.debug("%s: %s", query, values)
            async with connection.cursor() as cursor:
                await cursor.execute(query, values)
                if "output inserted" in query.lower():
                    row = await cursor.fetchone()
                    return row[0]
                return None

    @translate_exceptions
    async def execute_many(self, query: str, values: list) -> None:
        """Execute ``query`` per parameter set, in a transaction when supported."""
        async with self.acquire_connection() as connection:
            self.log.debug("%s: %s", query, values)
            async with connection.cursor() as cursor:
                if self.capabilities.supports_transactions:
                    await connection.begin()
                    try:
                        await cursor.executemany(query, values)
                    except Exception:
                        await connection.rollback()
                        raise
                    else:
                        await connection.commit()
                else:
                    await cursor.executemany(query, values)

    @translate_exceptions
    async def execute_query(
            self,
            query: str,
            values: Optional[list] = None) -> Tuple[int, List[dict]]:
        """Execute ``query`` and return ``(cursor.rowcount, rows)``.

        Statements that produce no result set (``cursor.description`` is
        ``None``, e.g. UPDATE/DELETE) return ``(rowcount, [])`` without
        fetching. Otherwise each row becomes a dict keyed by the lower-cased
        column name.
        """
        async with self.acquire_connection() as connection:
            self.log.debug("%s: %s", query, values)
            async with connection.cursor() as cursor:
                await cursor.execute(query, values or ())
                # Ask the cursor whether a result set exists instead of searching
                # the SQL for "UPDATE"/"DELETE". The text search also matched
                # SELECTs that merely mention those words (e.g. "UPDATED_AT").
                if cursor.description is None:
                    return cursor.rowcount, []
                columns = [column[0].lower() for column in cursor.description]
                rows = await cursor.fetchall()
                return cursor.rowcount, [dict(zip(columns, row)) for row in rows]

    async def execute_query_dict(self,
                                 query: str,
                                 values: Optional[list] = None) -> List[dict]:
        """Execute ``query`` and return only the list of row dicts."""
        return (await self.execute_query(query, values))[1]

    @translate_exceptions
    async def execute_script(self, query: str) -> None:
        """Execute raw SQL, discarding any result."""
        async with self.acquire_connection() as connection:
            self.log.debug(query)
            async with connection.cursor() as cursor:
                await cursor.execute(query)
Example #16
0
class AsyncpgDBClient(BaseDBAsyncClient):
    """PostgreSQL client backed by an ``asyncpg`` connection pool.

    Keyword arguments not consumed here are forwarded verbatim to
    ``asyncpg.create_pool``. ``schema`` becomes the pool's ``search_path``
    server setting.
    """
    DSN_TEMPLATE = "postgres://{user}:{password}@{host}:{port}/{database}"
    query_class = PostgreSQLQuery
    executor_class = AsyncpgExecutor
    schema_generator = AsyncpgSchemaGenerator
    capabilities = Capabilities("postgres")
    connection_class = asyncpg.connection.Connection
    loop = None

    def __init__(self, user: str, password: str, database: str, host: str,
                 port: SupportsInt, **kwargs: Any) -> None:
        super().__init__(**kwargs)

        self.user = user
        self.password = password
        self.database = database
        self.host = host
        self.port = int(port)  # make sure port is int type
        self.extra = kwargs.copy()
        self.schema = self.extra.pop("schema", None)
        # Base-client options that must not be passed to asyncpg.
        self.extra.pop("connection_name", None)
        self.extra.pop("fetch_inserted", None)
        self.loop = self.extra.pop("loop", None)
        self.connection_class = self.extra.pop("connection_class",
                                               self.connection_class)
        self.pool_minsize = int(self.extra.pop("minsize", 1))
        self.pool_maxsize = int(self.extra.pop("maxsize", 5))

        self._template: dict = {}
        self._pool: Optional[asyncpg.Pool] = None
        self._connection = None

    async def create_connection(self, with_db: bool) -> None:
        """Create the connection pool.

        Args:
            with_db: connect to ``self.database`` when true; otherwise pass
                ``database=None`` to asyncpg.

        Raises:
            DBConnectionError: if the target database does not exist.
        """
        self._template = {
            "host": self.host,
            "port": self.port,
            "user": self.user,
            "database": self.database if with_db else None,
            "min_size": self.pool_minsize,
            "max_size": self.pool_maxsize,
            "connection_class": self.connection_class,
            "loop": self.loop,
            **self.extra,
        }
        if self.schema:
            self._template["server_settings"] = {"search_path": self.schema}
        try:
            self._pool = await asyncpg.create_pool(None,
                                                   password=self.password,
                                                   **self._template)
            self.log.debug("Created connection pool %s with params: %s",
                           self._pool, self._template)
        except asyncpg.InvalidCatalogNameError as exc:
            raise DBConnectionError(
                f"Can't establish connection to database {self.database}"
            ) from exc

    async def _expire_connections(self) -> None:
        """Expire every pooled connection so each is reopened on next use."""
        if self._pool:  # pragma: nobranch
            await self._pool.expire_connections()

    async def _close(self) -> None:
        """Close the pool, terminating it if graceful close exceeds 10 s."""
        if self._pool:  # pragma: nobranch
            try:
                await asyncio.wait_for(self._pool.close(), 10)
            except asyncio.TimeoutError:  # pragma: nocoverage
                self._pool.terminate()
            # Log before dropping the reference; logging afterwards always
            # reported the pool as ``None``.
            self.log.debug("Closed connection pool %s with params: %s",
                           self._pool, self._template)
            self._pool = None

    async def close(self) -> None:
        """Close the pool and clear the cached pool parameters."""
        await self._close()
        self._template.clear()

    async def db_create(self) -> None:
        """Create ``self.database``; the pool is closed even on failure."""
        await self.create_connection(with_db=False)
        try:
            await self.execute_script(
                f'CREATE DATABASE "{self.database}" OWNER "{self.user}"')
        finally:
            await self.close()

    async def db_delete(self) -> None:
        """Drop ``self.database``, tolerating a missing database; always closes."""
        await self.create_connection(with_db=False)
        try:
            await self.execute_script(f'DROP DATABASE "{self.database}"')
        except asyncpg.InvalidCatalogNameError:  # pragma: nocoverage
            pass
        finally:
            await self.close()

    def acquire_connection(
            self) -> Union["PoolConnectionWrapper", "ConnectionWrapper"]:
        """Return a context manager that checks a connection out of the pool."""
        return PoolConnectionWrapper(self._pool)

    def _in_transaction(self) -> "TransactionContext":
        """Return a pooled transaction context for this client."""
        return TransactionContextPooled(TransactionWrapper(self))

    @translate_exceptions
    async def execute_insert(self, query: str,
                             values: list) -> Optional[asyncpg.Record]:
        """Execute an INSERT and return its first row (e.g. ``RETURNING`` data)."""
        async with self.acquire_connection() as connection:
            self.log.debug("%s: %s", query, values)
            # TODO: Cache prepared statement
            return await connection.fetchrow(query, *values)

    @translate_exceptions
    async def execute_many(self, query: str, values: list) -> None:
        """Execute ``query`` per parameter set inside one transaction."""
        async with self.acquire_connection() as connection:
            self.log.debug("%s: %s", query, values)
            # TODO: Consider using copy_records_to_table instead
            transaction = connection.transaction()
            await transaction.start()
            try:
                await connection.executemany(query, values)
            except Exception:
                await transaction.rollback()
                raise
            else:
                await transaction.commit()

    @translate_exceptions
    async def execute_query(
            self,
            query: str,
            values: Optional[list] = None) -> Tuple[int, List[dict]]:
        """Execute ``query`` and return ``(affected_rows, rows)``.

        Queries starting with ``UPDATE``/``DELETE`` are run with ``execute``.
        The affected count is parsed from the status string (e.g.
        ``"UPDATE 3"``), and no rows are returned. Other queries are fetched,
        and ``affected_rows`` is the number of rows.
        """
        async with self.acquire_connection() as connection:
            self.log.debug("%s: %s", query, values)
            if values:
                params = [query, *values]
            else:
                params = [query]
            if query.startswith(("UPDATE", "DELETE")):
                res = await connection.execute(*params)
                try:
                    rows_affected = int(res.split(" ")[1])
                except Exception:  # pragma: nocoverage
                    rows_affected = 0
                return rows_affected, []

            rows = await connection.fetch(*params)
            return len(rows), rows

    @translate_exceptions
    async def execute_query_dict(self,
                                 query: str,
                                 values: Optional[list] = None) -> List[dict]:
        """Execute ``query`` and return each record converted to a ``dict``."""
        async with self.acquire_connection() as connection:
            self.log.debug("%s: %s", query, values)
            if values:
                return list(map(dict, await connection.fetch(query, *values)))
            return list(map(dict, await connection.fetch(query)))

    @translate_exceptions
    async def execute_script(self, query: str) -> None:
        """Execute raw SQL on a pooled connection, discarding any result."""
        async with self.acquire_connection() as connection:
            self.log.debug(query)
            await connection.execute(query)
Example #17
0
class MySQLClient(BaseDBAsyncClient):
    """Asynchronous MySQL client built on a single ``aiomysql`` connection.

    The connection is stored in ``self._connection`` and handed out through
    :class:`ConnectionWrapper` together with ``self._lock``.
    """

    query_class = MySQLQuery
    executor_class = MySQLExecutor
    schema_generator = MySQLSchemaGenerator
    capabilities = Capabilities("mysql",
                                safe_indexes=False,
                                requires_limit=True,
                                inline_comment=True)

    def __init__(self, *, user: str, password: str, database: str, host: str,
                 port: SupportsInt, **kwargs) -> None:
        super().__init__(**kwargs)

        self.user = user
        self.password = password
        self.database = database
        self.host = host
        self.port = int(port)  # make sure port is int type
        # Remaining kwargs are forwarded to aiomysql.connect() in
        # create_connection(). Drop keys that are meant for super().__init__
        # or that create_connection() sets itself ("db", "autocommit"), so they
        # cannot override the connection template.
        self.extra = kwargs.copy()
        for key in ("connection_name", "fetch_inserted", "db", "autocommit"):
            self.extra.pop(key, None)
        self.charset = self.extra.pop("charset", "")

        self._template = {}  # type: dict
        self._connection = None  # type: Optional[aiomysql.Connection]
        self._lock = asyncio.Lock()

        # Dynamic subclass mixing TransactionWrapper into this client class;
        # instantiated by _in_transaction().
        self._transaction_class = type("TransactionWrapper",
                                       (TransactionWrapper, self.__class__),
                                       {})

    async def create_connection(self, with_db: bool) -> None:
        """Open the aiomysql connection and store it on ``self._connection``.

        :param with_db: when False, connect without selecting a database
            (used by db_create/db_delete).
        :raises DBConnectionError: if pymysql raises an OperationalError while
            connecting.
        """
        self._template = {
            "host": self.host,
            "port": self.port,
            "user": self.user,
            "db": self.database if with_db else None,
            "autocommit": True,
            "charset": self.charset,
            **self.extra,
        }
        try:
            # The password is passed separately so it is never stored in
            # self._template, which is logged and put into error messages.
            self._connection = await aiomysql.connect(password=self.password,
                                                      **self._template)
        except pymysql.err.OperationalError as exc:
            raise DBConnectionError(
                f"Can't connect to MySQL server: {self._template}") from exc
        self.log.debug("Created connection %s with params: %s",
                       self._connection, self._template)

    async def _close(self) -> None:
        """Close the current connection (if any) and clear its params.

        Does not reset ``self._connection``; ``close()`` does that.
        """
        if self._connection:  # pragma: nobranch
            self._connection.close()
            self.log.debug("Closed connection %s with params: %s",
                           self._connection, self._template)
            self._template.clear()

    async def close(self) -> None:
        """Close the connection and drop the reference to it."""
        await self._close()
        self._connection = None

    async def db_create(self) -> None:
        """Create ``self.database`` on the server.

        The temporary connection is closed even if CREATE DATABASE fails.
        """
        await self.create_connection(with_db=False)
        try:
            await self.execute_script(f"CREATE DATABASE {self.database}")
        finally:
            await self.close()

    async def db_delete(self) -> None:
        """Drop ``self.database``; a pymysql DatabaseError is ignored.

        The temporary connection is closed whatever the outcome.
        """
        await self.create_connection(with_db=False)
        try:
            await self.execute_script(f"DROP DATABASE {self.database}")
        except pymysql.err.DatabaseError:  # pragma: nocoverage
            pass
        finally:
            await self.close()

    def acquire_connection(self) -> ConnectionWrapper:
        """Wrap the connection together with the client's lock."""
        return ConnectionWrapper(self._connection, self._lock)

    def _in_transaction(self):
        """Return a transaction-wrapped client bound to this client."""
        return self._transaction_class(self)

    @translate_exceptions
    @retry_connection
    async def execute_insert(self, query: str, values: list) -> int:
        """Run an INSERT and return the cursor's ``lastrowid``.

        :param query: INSERT statement with driver placeholders.
        :param values: bind parameters for ``query``.
        :return: ``cursor.lastrowid``, i.e. the auto-generated id.
        """
        async with self.acquire_connection() as connection:
            self.log.debug("%s: %s", query, values)
            async with connection.cursor() as cursor:
                await cursor.execute(query, values)
                return cursor.lastrowid  # return auto-generated id

    @translate_exceptions
    @retry_connection
    async def execute_many(self, query: str, values: list) -> None:
        """Run ``query`` once per parameter set in ``values`` via executemany."""
        async with self.acquire_connection() as connection:
            self.log.debug("%s: %s", query, values)
            async with connection.cursor() as cursor:
                await cursor.executemany(query, values)

    @translate_exceptions
    @retry_connection
    async def execute_query(
            self,
            query: str,
            values: Optional[list] = None) -> List[dict]:
        """Run ``query`` and return all rows.

        The rows are fetched through an ``aiomysql.DictCursor``, so each row is
        a dict.

        :param query: SQL text to execute.
        :param values: optional bind parameters for ``query``.
        """
        async with self.acquire_connection() as connection:
            self.log.debug("%s: %s", query, values)
            async with connection.cursor(aiomysql.DictCursor) as cursor:
                await cursor.execute(query, values)
                return await cursor.fetchall()

    @translate_exceptions
    @retry_connection
    async def execute_script(self, query: str) -> None:
        """Execute raw SQL without bind parameters, discarding any result."""
        async with self.acquire_connection() as connection:
            self.log.debug(query)
            async with connection.cursor() as cursor:
                await cursor.execute(query)