def is_async_sqla_obj(obj):
    """
    Returns True if `obj` is an asynchronous sqlalchemy connectable
    (engine or connection), otherwise False.
    """
    # sqla < 1.4 does not support asynchronous connectables
    if not _sqla_gt14():
        return False
    from sqlalchemy.ext.asyncio.engine import AsyncConnection, AsyncEngine
    return isinstance(obj, (AsyncConnection, AsyncEngine))
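# Usage sketch (illustrative, not part of the library): a plain sync engine
# should yield False; the async case is skipped here since it needs an async
# driver such as aiosqlite installed.
# >>> from sqlalchemy import create_engine
# >>> is_async_sqla_obj(create_engine("sqlite://"))
# False
# >>> from sqlalchemy.ext.asyncio import create_async_engine  # doctest: +SKIP
# >>> is_async_sqla_obj(create_async_engine("sqlite+aiosqlite://"))  # doctest: +SKIP
# True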
def table_exists(self, connection=None) -> bool:
    """
    Returns True if the table defined in the given instance of
    PandasSpecialEngine exists, otherwise returns False.
    """
    con = self.connection if connection is None else connection
    insp = sa.inspect(con)
    if _sqla_gt14():
        return insp.has_table(schema=self.schema, table_name=self.table.name)
    else:
        # this is not particularly efficient but AFAIK it's the best
        # we can do at connection level
        return self.table.name in insp.get_table_names(schema=self.schema)
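# Usage sketch (illustrative): assuming `engine` and `df` are built as in the
# PandasSpecialEngine docstring and the table has not been created yet.
# >>> with engine.connect() as connection:  # doctest: +SKIP
# ...     pse = PandasSpecialEngine(connection=connection, df=df, table_name='example')
# ...     pse.table_exists()
# False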
def schema_exists(self, connection=None) -> bool:
    """
    Returns True if the PostgreSQL schema defined in the given instance
    of PandasSpecialEngine exists, otherwise returns False.
    """
    self._raise_no_schema_feature()
    con = self.connection if connection is None else connection
    if _sqla_gt14():
        insp = sa.inspect(con)
        return self.schema in insp.get_schema_names()
    else:
        return con.dialect.has_schema(con, self.schema)
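# Usage sketch (illustrative): only meaningful for PostgreSQL, since
# `_raise_no_schema_feature` raises for databases without a schema feature.
# `pg_engine` is an assumed name for an engine pointing to a postgres database.
# >>> with pg_engine.connect() as connection:  # doctest: +SKIP
# ...     pse = PandasSpecialEngine(connection=connection, df=df,
# ...                               table_name='example', schema='public')
# ...     pse.schema_exists()
# True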
def _verify_connection_like_object(connection):
    # handle easy cases first
    if isinstance(connection, Connection):
        return True
    # maybe we are in the presence of an asynchronous connection
    is_connection = False  # until proven otherwise
    if _sqla_gt14():
        from sqlalchemy.ext.asyncio.engine import AsyncConnection
        is_connection = isinstance(connection, AsyncConnection)
    # raise if not connection like
    if not is_connection:
        raise TypeError(f'Expected a Connection or AsyncConnection object. '
                        f'Got {type(connection)} instead')
    # at this point we know we have an AsyncConnection
    return True
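# Behavior sketch (illustrative): only (Async)Connection objects pass the
# check; other connectables such as engines are rejected.
# >>> from sqlalchemy import create_engine
# >>> engine = create_engine("sqlite://")
# >>> with engine.connect() as connection:
# ...     _verify_connection_like_object(connection)
# True
# >>> _verify_connection_like_object(engine)  # doctest: +IGNORE_EXCEPTION_DETAIL
# Traceback (most recent call last):
# ...
# TypeError: Expected a Connection or AsyncConnection object. ...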
def get_db_columns_names(self, connection=None) -> List[str]:
    """
    Gets the column names of the SQL table defined in the given instance
    of PandasSpecialEngine.
    """
    con = self.connection if connection is None else connection
    if _sqla_gt14():
        insp = sa.inspect(con)
        columns_info = insp.get_columns(schema=self.schema,
                                        table_name=self.table.name)
    else:
        columns_info = con.dialect.get_columns(connection=con,
                                               schema=self.schema,
                                               table_name=self.table.name)
    db_columns_names = [col_info["name"] for col_info in columns_info]
    # handle the case of SQLite where no error is raised for a missing table;
    # instead sqlalchemy returns 0 columns
    assert len(db_columns_names) > 0
    return db_columns_names
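# Usage sketch (illustrative): the index column (primary key) is returned
# along with the regular columns. Assumes `create_table_if_not_exists` is the
# sync counterpart of the async `acreate_table_if_not_exists` used in the tests.
# >>> with engine.connect() as connection:  # doctest: +SKIP
# ...     pse = PandasSpecialEngine(connection=connection, df=df, table_name='example')
# ...     pse.create_table_if_not_exists()
# ...     pse.get_db_columns_names()
# ['profileid', 'name']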
async def test_change_column_type_if_column_empty_async(engine, schema, caplog,
                                                        new_empty_column_value=None):
    print(new_empty_column_value)
    # store arguments we will use for multiple PandasSpecialEngine instances
    table_name = TableNames.CHANGE_EMPTY_COL_TYPE
    common_kwargs = dict(schema=schema, table_name=table_name)
    common_kwargs['dtype'] = {'profileid': VARCHAR(5)} if 'mysql' in engine.dialect.dialect_description else None

    # JSON-like values will not work for sqlalchemy < 1.4
    # also skip SQLite as it does not support such an alteration
    json_like = isinstance(new_empty_column_value, (dict, list))
    if json_like and not _sqla_gt14():
        pytest.skip('JSON-like values will not work for sqlalchemy < 1.4')
    elif 'sqlite' in engine.dialect.dialect_description:
        pytest.skip('such a column alteration is not possible with SQLite')

    # create our example table
    df = pd.DataFrame({'profileid': ['foo'], 'empty_col': [None]}).set_index('profileid')
    async with engine.connect() as connection:
        pse = PandasSpecialEngine(connection=connection, df=df, **common_kwargs)
        await pse.acreate_table_if_not_exists()
        await connection.commit()
        assert await pse.atable_exists()

    # recreate an instance of PandasSpecialEngine with a new df (so the model gets refreshed)
    # the line below is a "hack" to set any type of element as a column value
    # without pandas trying to broadcast it. This is useful when passing a list or such
    df['empty_col'] = df.index.map(lambda x: new_empty_column_value)
    async with engine.connect() as connection:
        pse = PandasSpecialEngine(connection=connection, df=df, **common_kwargs)
        with caplog.at_level(logging.INFO, logger='pangres'):
            await pse.aadapt_dtype_of_empty_db_columns()
        assert len(caplog.records) == 1
        assert 'Changed type of column empty_col' in caplog.text
        caplog.clear()
async def __aenter__(self):
    # make sure the sqlalchemy version allows for async usage
    # we only need to do this on entry of the context manager
    if not _sqla_gt14():
        raise NotImplementedError('Async usage of sqlalchemy requires version >= 1.4')
    from sqlalchemy.ext.asyncio.engine import AsyncEngine, AsyncConnection, AsyncConnectable

    # similar procedure to __enter__ but with different object types
    if isinstance(self.connectable, AsyncEngine):
        self.connection = await self.connectable.connect()
    elif isinstance(self.connectable, AsyncConnection):
        self.connection = self.connectable
    else:
        raise TypeError(f'Expected an async sqlalchemy connectable object ({AsyncConnectable}). '
                        f'Got {type(self.connectable)}')

    if isinstance(self.connectable, AsyncEngine):
        try:
            self.transaction = await self.connection.begin()
        except Exception as e:  # pragma: no cover
            self._close_resources()
            raise e
    return self
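# Usage sketch (illustrative): the PandasSpecialEngine docstring points to
# pangres.transaction.TransactionHandler as the transaction handler class;
# the `connectable` keyword below is an assumption based on the attribute
# name used in the method above.
# >>> import asyncio
# >>> from sqlalchemy.ext.asyncio import create_async_engine  # doctest: +SKIP
# >>> async def main():
# ...     engine = create_async_engine("postgresql+asyncpg://user:pw@localhost:5432/postgres")
# ...     async with TransactionHandler(connectable=engine) as trans:
# ...         pass  # use trans.connection here; the transaction ends on exit
# >>> asyncio.run(main())  # doctest: +SKIP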
def create_sync_or_async_engine(conn_string, **kwargs):
    """
    Automatically creates an appropriate engine (synchronous or asynchronous)
    for the given connection string.

    Examples
    --------
    >>> # sync
    >>> engine = create_sync_or_async_engine("sqlite://")

    >>> # async
    >>> engine = create_sync_or_async_engine("postgresql+asyncpg://username:password@localhost:5432/postgres")  # doctest: +SKIP
    """
    # if we see any known async driver we will create an async engine
    if any(s in conn_string.split('/')[0] for s in async_to_sync_drivers_dict):
        if not _sqla_gt14():
            raise NotImplementedError('Asynchronous engines require sqlalchemy >= 1.4')
        from sqlalchemy.ext.asyncio import create_async_engine
        return create_async_engine(conn_string, **kwargs)  # pass any engine kwargs through
    # otherwise we just assume we have to create a sync engine
    else:
        return create_engine(conn_string, **kwargs)
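# Detection sketch (illustrative): only the part of the URL before the first
# "/" is searched for known async driver names (the keys of
# `async_to_sync_drivers_dict`, a mapping assumed to be defined elsewhere in
# this module).
# >>> "postgresql+asyncpg://user:pw@localhost:5432/postgres".split('/')[0]
# 'postgresql+asyncpg:'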
def pytest_generate_tests(metafunc):
    # this is called for every test
    # if we see the parameters "engine" and "schema" in a function
    # then we will repeat the test for each engine
    func_params = signature(metafunc.function).parameters
    if not ('engine' in func_params and 'schema' in func_params):
        # I could not find any other way than to add a dummy parameter
        # for executing a test only once (parametrize needs arguments)
        metafunc.parametrize('_', [''], scope='module')
        return

    # tests that we need to repeat for each engine + options (e.g. future)
    conn_strings = {'sqlite': metafunc.config.option.sqlite_conn,
                    'async_sqlite_conn': metafunc.config.option.async_sqlite_conn,
                    'pg': metafunc.config.option.pg_conn,
                    'asyncpg': metafunc.config.option.async_pg_conn,
                    'mysql': metafunc.config.option.mysql_conn,
                    'async_mysql_conn': metafunc.config.option.async_mysql_conn}
    if not any(v is not None for v in conn_strings.values()):
        raise ValueError('You must provide at least one connection string (e.g. argument --sqlite_conn)!')

    engines, schemas, ids = [], [], []
    for db_type, conn_string in conn_strings.items():
        # skip tests generation when no connection string was provided for this database type
        if conn_string is None:
            continue

        # get engine and schema
        schema = metafunc.config.option.pg_schema if db_type in ('pg', 'asyncpg') else None
        engine = create_sync_or_async_engine(conn_string)

        # skip async tests for sync engines
        test_func_info = TestFunctionInfo(module_namespace=metafunc.module.__name__,
                                          function_name=metafunc.function.__name__)
        is_async_engine = is_async_sqla_obj(engine)
        if test_func_info.is_async:
            if not is_async_engine:
                continue
        # skip sync tests for async engines when a test module has an async variant
        elif test_func_info.has_async_variant and is_async_engine:
            continue

        # generate tests
        schemas.append(schema)
        engines.append(engine)
        schema_id = '' if schema is None else f'_schema:{schema}'
        ids.append(f'{engine.url.drivername}{schema_id}')

        # for sqlalchemy 1.4+ use future=True to try out the future sqlalchemy 2.0 behavior
        # do not do this for async engines which already implement 2.0 functionalities
        if _sqla_gt14() and not is_async_engine:
            future_engine = create_engine(conn_string, future=True)
            schemas.append(schema)
            engines.append(future_engine)
            ids.append(f'{future_engine.url.drivername}{schema_id}_future')

    assert len(engines) == len(schemas) == len(ids)
    metafunc.parametrize("engine, schema", list(zip(engines, schemas)), ids=ids, scope='module')
def table_exists(connection, schema, table_name) -> bool:
    """
    Returns True if the given table exists in the database, otherwise False.
    """
    insp = sa.inspect(connection)
    if _sqla_gt14():
        return insp.has_table(schema=schema, table_name=table_name)
    else:
        return table_name in insp.get_table_names(schema=schema)
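# Usage sketch (illustrative): this helper works at connection level with any
# of the supported databases.
# >>> from sqlalchemy import create_engine
# >>> engine = create_engine("sqlite://")
# >>> with engine.connect() as connection:
# ...     table_exists(connection=connection, schema=None, table_name='example')
# False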
def __init__(self, connection:Connection, df:pd.DataFrame, table_name:str,
             schema:Optional[str]=None, dtype:Optional[dict]=None):
    """
    Interacts with SQL tables via pandas and SQLalchemy table models.

    Attributes
    ----------
    connection : sqlalchemy.engine.base.Connection
        Connection provided during class instantiation
    df : pd.DataFrame
        DataFrame provided during class instantiation
    table_name : str
        Table name provided during class instantiation
    schema : str or None
        SQL schema provided during class instantiation
    table : sqlalchemy.sql.schema.Table
        Sqlalchemy table model for df

    Parameters
    ----------
    connection
        A connection that was for example directly created from a
        sqlalchemy engine (see https://docs.sqlalchemy.org/en/13/core/engines.html
        and examples below) or from pangres's transaction handler class
        (pangres.transaction.TransactionHandler)
    df
        A pandas DataFrame
    table_name
        Name of the SQL table
    schema
        Name of the schema that contains/will contain the table.
        For postgres defaults to "public" if not provided.
    dtype : None or dict {str:SQL_TYPE}, default None
        Similar to pd.to_sql dtype argument.
        This is especially useful for MySQL where the length of
        primary keys with text has to be provided (see Examples)

    Examples
    --------
    >>> from sqlalchemy import create_engine
    >>>
    >>> engine = create_engine("sqlite://")
    >>> df = pd.DataFrame({'name':['Albert', 'Toto'],
    ...                    'profileid':[10, 11]}).set_index('profileid')
    >>>
    >>> with engine.connect() as connection:  # doctest: +SKIP
    ...     pse = PandasSpecialEngine(connection=connection, df=df, table_name='example')
    ...     print(pse)
    PandasSpecialEngine (id 123456, hexid 0x123456)
    * connection: <sqlalchemy.engine.base.Connection...>
    * schema: None
    * table: example
    * SQLalchemy table model:
    Table('example', MetaData(bind=<sqlalchemy.engine.base.Connection...>),
          Column('profileid', BigInteger(), table=<example>, primary_key=True, nullable=False),
          Column('name', Text(), table=<example>), schema=None)
    * df.head():
    |   profileid | name   |
    |------------:|:-------|
    |          10 | Albert |
    |          11 | Toto   |
    """
    self._db_type = self._detect_db_type(connection)
    if self._db_type == "postgres":
        schema = 'public' if schema is None else schema
        # raise if we find columns with "(", ")" or "%"
        bad_col_names = [col for col in df.columns if RE_BAD_COL_NAME.search(col)]
        if len(bad_col_names) > 0:
            err = ("psycopg2 (Python postgres driver) does not seem to support "
                   "column names with '%', '(' or ')' "
                   "(see https://github.com/psycopg/psycopg2/issues/167).\n"
                   f"You need to fix these names: {bad_col_names}")
            raise BadColumnNamesException(err)

    # VERIFY ARGUMENTS
    # all index levels have names
    index_names = list(df.index.names)
    if any(ix_name is None for ix_name in index_names):
        raise UnnamedIndexLevelsException("All index levels must be named!")

    # index is unique
    if not df.index.is_unique:
        err = ("The index must be unique since it is used "
               "as primary key.\n"
               "Check duplicates using this code (assuming df "
               "is the DataFrame you want to upsert):\n"
               ">>> df.index[df.index.duplicated(keep=False)]")
        raise DuplicateValuesInIndexException(err)

    # there are no duplicated names
    fields = list(df.index.names) + df.columns.tolist()
    if len(set(fields)) != len(fields):
        duplicated_labels = [c for c in fields if fields.count(c) > 1]
        raise DuplicateLabelsException("Found duplicates across index "
                                       f"and columns: {duplicated_labels}")

    # detect json columns
    def is_json(col:str):
        s = df[col].dropna()
        return (not s.empty and
                s.map(lambda x: isinstance(x, (list, dict))).all())
    json_cols = [col for col in df.columns if is_json(col)]

    # merge with dtype from user
    new_dtype = {c:JSON for c in json_cols}
    if dtype is not None:
        new_dtype.update(dtype)
    new_dtype = None if new_dtype == {} else new_dtype

    # create sqlalchemy table model via pandas
    pandas_sql_engine = pd.io.sql.SQLDatabase(engine=connection, schema=schema)
    pandas_table = pd.io.sql.SQLTable(name=table_name,
                                      pandas_sql_engine=pandas_sql_engine,
                                      frame=df,
                                      dtype=new_dtype)

    # turn the pandas table into a pure sqlalchemy table
    # inspired from https://github.com/pandas-dev/pandas/blob/main/pandas/io/sql.py#L815-L821
    metadata = MetaData(bind=connection)
    if _sqla_gt14():
        table = pandas_table.table.to_metadata(metadata)
    else:
        table = pandas_table.table.tometadata(metadata)

    # add PK
    constraint = PrimaryKeyConstraint(*[table.columns[name] for name in df.index.names])
    table.append_constraint(constraint)

    # FIX: make sure there is no autoincrement
    # see https://docs.sqlalchemy.org/en/14/dialects/mysql.html#auto-increment-behavior
    for name in df.index.names:
        table.columns[name].autoincrement = False

    # add remaining attributes
    self.connection = connection
    self.df = df
    self.schema = schema
    self.table = table