Example #1
def is_async_sqla_obj(obj):
    """
    Returns True if `obj` is an asynchronous sqlalchemy connectable (engine or connection),
    otherwise False.
    """
    # sqla < 1.4 does not support asynchronous connectables
    if not _sqla_gt14():
        return False
    from sqlalchemy.ext.asyncio.engine import AsyncConnection, AsyncEngine
    return isinstance(obj, (AsyncConnection, AsyncEngine))
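A minimal usage sketch (assuming sqlalchemy >= 1.4 and the aiosqlite driver are installed):

from sqlalchemy import create_engine
from sqlalchemy.ext.asyncio import create_async_engine

sync_engine = create_engine('sqlite://')
async_engine = create_async_engine('sqlite+aiosqlite://')
assert not is_async_sqla_obj(sync_engine)  # a plain Engine is not async
assert is_async_sqla_obj(async_engine)     # an AsyncEngine is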
Example #2
    def table_exists(self, connection=None) -> bool:
        """
        Returns True if the table defined in the given instance
        of PandasSpecialEngine exists, else returns False.
        """
        con = self.connection if connection is None else connection
        insp = sa.inspect(con)
        if _sqla_gt14():
            return insp.has_table(schema=self.schema, table_name=self.table.name)
        else:
            # this is not particularly efficient but AFAIK it's the best we can do at connection level
            return self.table.name in insp.get_table_names(schema=self.schema)
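A usage sketch with an in-memory SQLite database (reusing the constructor shown in Example #11):

import pandas as pd
from sqlalchemy import create_engine

engine = create_engine('sqlite://')
df = pd.DataFrame({'name': ['Albert'], 'profileid': [10]}).set_index('profileid')
with engine.connect() as connection:
    pse = PandasSpecialEngine(connection=connection, df=df, table_name='example')
    print(pse.table_exists())  # False until the table has been created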
Example #3
    def schema_exists(self, connection=None) -> bool:
        """
        Returns True if the PostgreSQL schema defined in the given instance
        of PandasSpecialEngine exists, else returns False.
        """
        self._raise_no_schema_feature()
        con = self.connection if connection is None else connection
        if _sqla_gt14():
            insp = sa.inspect(con)
            return self.schema in insp.get_schema_names()
        else:
            return con.dialect.has_schema(con, self.schema)
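A usage sketch (requires a running PostgreSQL instance; the connection string below is a placeholder):

import pandas as pd
from sqlalchemy import create_engine

engine = create_engine('postgresql://username:password@localhost:5432/postgres')  # placeholder credentials
df = pd.DataFrame({'name': ['Albert'], 'profileid': [10]}).set_index('profileid')
with engine.connect() as connection:
    pse = PandasSpecialEngine(connection=connection, df=df, table_name='example')
    print(pse.schema_exists())  # True: the schema defaults to "public", which always exists in PostgreSQL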
Example #4
    def _verify_connection_like_object(connection):
        # handle easy cases first
        if isinstance(connection, Connection):
            return True

        # maybe we are dealing with an asynchronous connection
        is_connection = False  # until proven otherwise
        if _sqla_gt14():
            from sqlalchemy.ext.asyncio.engine import AsyncConnection
            is_connection = isinstance(connection, AsyncConnection)

        # raise if not connection like
        if not is_connection:
            raise TypeError(
                f'Expected a Connection or AsyncConnection object. Got {type(connection)} instead'
            )
        return True  # consistent with the early return for sync connections above
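A sketch of the expected behaviour (calling it as a plain function here, since it takes no self):

from sqlalchemy import create_engine

engine = create_engine('sqlite://')
with engine.connect() as connection:
    _verify_connection_like_object(connection)  # OK: a Connection passes silently

try:
    _verify_connection_like_object(engine)  # an Engine is not connection-like
except TypeError as e:
    print(e)  # Expected a Connection or AsyncConnection object. Got <class '...Engine'> instead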
Example #5
    def get_db_columns_names(self, connection=None) -> List[str]:
        """
        Gets the column names of the SQL table defined
        in the given instance of PandasSpecialEngine.
        """
        con = self.connection if connection is None else connection
        if _sqla_gt14():
            insp = sa.inspect(con)
            columns_info = insp.get_columns(schema=self.schema, table_name=self.table.name)
        else:
            columns_info = con.dialect.get_columns(connection=con, schema=self.schema,
                                                   table_name=self.table.name)
        db_columns_names = [col_info["name"] for col_info in columns_info]
        # handle the case of SQLite where no error is raised for a missing table
        # but instead 0 columns are returned by sqlalchemy
        assert len(db_columns_names) > 0
        return db_columns_names
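A usage sketch; create_table_if_not_exists is assumed to be the sync counterpart of the acreate_table_if_not_exists method used in the async test below:

import pandas as pd
from sqlalchemy import create_engine

engine = create_engine('sqlite://')
df = pd.DataFrame({'name': ['Albert'], 'profileid': [10]}).set_index('profileid')
with engine.connect() as connection:
    pse = PandasSpecialEngine(connection=connection, df=df, table_name='example')
    pse.create_table_if_not_exists()  # assumed sync method, see note above
    print(pse.get_db_columns_names())  # ['profileid', 'name']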
Example #6
async def test_change_column_type_if_column_empty_async(
        engine, schema, caplog, new_empty_column_value=None):
    print(new_empty_column_value)
    # store arguments we will use for multiple PandasSpecialEngine instances
    table_name = TableNames.CHANGE_EMPTY_COL_TYPE
    common_kwargs = dict(schema=schema, table_name=table_name)
    # MySQL requires a length for text primary keys
    common_kwargs['dtype'] = ({'profileid': VARCHAR(5)}
                              if 'mysql' in engine.dialect.dialect_description else None)

    # json like will not work for sqlalchemy < 1.4
    # also skip sqlite as it does not support such alteration
    json_like = isinstance(new_empty_column_value, (dict, list))
    if json_like and not _sqla_gt14():
        pytest.skip('JSON like values will not work for sqlalchemy < 1.4')
    elif 'sqlite' in engine.dialect.dialect_description:
        pytest.skip('such column alteration is not possible with SQLite')

    # create our example table
    df = pd.DataFrame({
        'profileid': ['foo'],
        'empty_col': [None]
    }).set_index('profileid')
    async with engine.connect() as connection:
        pse = PandasSpecialEngine(connection=connection,
                                  df=df,
                                  **common_kwargs)
        await pse.acreate_table_if_not_exists()
        await connection.commit()
        assert await pse.atable_exists()

    # recreate an instance of PandasSpecialEngine with a new df (so the model gets refreshed)
    # the line below is a "hack" to set any type of element as a column value
    # without pandas trying to broadcast it. This is useful when passing a list or such
    df['empty_col'] = df.index.map(lambda x: new_empty_column_value)
    async with engine.connect() as connection:
        pse = PandasSpecialEngine(connection=connection,
                                  df=df,
                                  **common_kwargs)
        with caplog.at_level(logging.INFO, logger='pangres'):
            await pse.aadapt_dtype_of_empty_db_columns()
        assert len(caplog.records) == 1
        assert 'Changed type of column empty_col' in caplog.text
    caplog.clear()
Example #7
    async def __aenter__(self):
        # make sure the sqlalchemy version allows for async usage
        # we only need to do this on entry of the context manager
        if not _sqla_gt14():
            raise NotImplementedError('Async usage of sqlalchemy requires version >= 1.4')
        from sqlalchemy.ext.asyncio.engine import AsyncEngine, AsyncConnection, AsyncConnectable

        # similar procedure to __enter__ with different object types
        if isinstance(self.connectable, AsyncEngine):
            self.connection = await self.connectable.connect()
        elif isinstance(self.connectable, AsyncConnection):
            self.connection = self.connectable
        else:
            raise TypeError(f'Expected an async sqlalchemy connectable object ({AsyncConnectable}). '
                            f'Got {type(self.connectable)}')

        if isinstance(self.connectable, AsyncEngine):
            try:
                self.transaction = await self.connection.begin()
            except Exception as e:  # pragma: no cover
                self._close_resources()
                raise e
        return self
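A usage sketch of the async context manager. The class name TransactionHandler and its connectable argument are assumptions based on the reference in Example #11; requires the aiosqlite driver:

import asyncio
from sqlalchemy.ext.asyncio import create_async_engine
from pangres.transaction import TransactionHandler  # assumed import path, see Example #11

async def main():
    engine = create_async_engine('sqlite+aiosqlite://')
    async with TransactionHandler(connectable=engine) as trans:  # assumed constructor argument
        ...  # use trans.connection here

asyncio.run(main())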
Example #8
def create_sync_or_async_engine(conn_string, **kwargs):
    """
    Automatically creates an appropriate engine (synchronous or asynchronous)
    for the given connection string.

    Examples
    --------
    >>> # sync
    >>> engine = create_sync_or_async_engine("sqlite://")
    >>> # async
    >>> engine = create_sync_or_async_engine("postgresql+asyncpg://username:password@localhost:5432/postgres")  # doctest: +SKIP
    """
    # if we see any known async drivers we will create an async engine
    if any(s in conn_string.split('/')[0] for s in async_to_sync_drivers_dict):
        if not _sqla_gt14():
            raise NotImplementedError(
                'Asynchronous engines require sqlalchemy >= 1.4')

        from sqlalchemy.ext.asyncio import create_async_engine
        return create_async_engine(conn_string)
    # otherwise we will just assume we have to create a sync engine
    else:
        return create_engine(conn_string)
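A usage sketch of the driver detection (assuming 'aiosqlite' is among the keys of async_to_sync_drivers_dict and the driver is installed):

engine = create_sync_or_async_engine('sqlite://')
print(is_async_sqla_obj(engine))   # False: falls through to a regular sync engine

aengine = create_sync_or_async_engine('sqlite+aiosqlite://')
print(is_async_sqla_obj(aengine))  # True, if 'aiosqlite' is a known async driver here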
Example #9
def pytest_generate_tests(metafunc):
    # this is called for every test
    # if we see the parameters "engine" and "schema" in a function
    # then we will repeat the test for each engine
    func_params = signature(metafunc.function).parameters
    if not ('engine' in func_params and 'schema' in func_params):
        # I could not find any other way than to add a dummy parameter
        # for executing a test only once (parametrize needs arguments)
        metafunc.parametrize('_', [''], scope='module')
        return

    # tests that we need to repeat for each engine + options (e.g. future)
    conn_strings = {
        'sqlite': metafunc.config.option.sqlite_conn,
        'async_sqlite_conn': metafunc.config.option.async_sqlite_conn,
        'pg': metafunc.config.option.pg_conn,
        'asyncpg': metafunc.config.option.async_pg_conn,
        'mysql': metafunc.config.option.mysql_conn,
        'async_mysql_conn': metafunc.config.option.async_mysql_conn
    }
    if not any(v is not None for v in conn_strings.values()):
        raise ValueError(
            'You must provide at least one connection string (e.g. argument --sqlite_conn)!'
        )

    engines, schemas, ids = [], [], []
    for db_type, conn_string in conn_strings.items():

        # skip engines for which no connection string was provided
        if conn_string is None:
            continue

        # get engine and schema
        schema = metafunc.config.option.pg_schema if db_type in (
            'pg', 'asyncpg') else None
        engine = create_sync_or_async_engine(conn_string)

        # skip async tests for sync engines
        test_func_info = TestFunctionInfo(
            module_namespace=metafunc.module.__name__,
            function_name=metafunc.function.__name__)
        is_async_engine = is_async_sqla_obj(engine)
        if test_func_info.is_async:
            if not is_async_engine:
                continue

        # skip sync tests for async engines when a tests module has an async variant
        elif test_func_info.has_async_variant and is_async_engine:
            continue

        # generate tests
        schemas.append(schema)
        engines.append(engine)
        schema_id = '' if schema is None else f'_schema:{schema}'
        ids.append(f'{engine.url.drivername}{schema_id}')
        # for sqlalchemy 1.4+ use future=True to try the future sqlalchemy 2.0
        # do not do this for async engines which already implement 2.0 functionalities
        if _sqla_gt14() and not is_async_engine:
            future_engine = create_engine(conn_string, future=True)
            schemas.append(schema)
            engines.append(future_engine)
            ids.append(f'{engine.url.drivername}{schema_id}_future')
    assert len(engines) == len(schemas) == len(ids)
    metafunc.parametrize("engine, schema",
                         list(zip(engines, schemas)),
                         ids=ids,
                         scope='module')
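With this hook in a conftest.py, any test declaring both the engine and schema parameters gets repeated per configured backend; a minimal sketch of such a test:

def test_round_trip(engine, schema):
    # generated once per engine configured via the command line options above
    # (e.g. --sqlite_conn), plus a "_future" variant on sqlalchemy 1.4+
    assert engine is not None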
Example #10
def table_exists(connection, schema, table_name) -> bool:
    insp = sa.inspect(connection)
    if _sqla_gt14():
        return insp.has_table(schema=schema, table_name=table_name)
    else:
        return table_name in insp.get_table_names(schema=schema)
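A quick usage sketch against an empty in-memory SQLite database:

import sqlalchemy as sa

engine = sa.create_engine('sqlite://')
with engine.connect() as connection:
    print(table_exists(connection=connection, schema=None, table_name='example'))  # False: nothing created yet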
Example #11
    def __init__(self,
                 connection: Connection,
                 df: pd.DataFrame,
                 table_name: str,
                 schema: Optional[str] = None,
                 dtype: Optional[dict] = None):
        """
        Interacts with SQL tables via pandas and SQLalchemy table models.

        Attributes
        ----------
        connection : sqlalchemy.engine.base.Connection
            Connection provided during class instantiation
        df : pd.DataFrame
            DataFrame provided during class instantiation
        table_name : str
            Table name provided during class instantiation
        schema : str or None
            SQL schema provided during class instantiation
        table : sqlalchemy.sql.schema.Table
            Sqlalchemy table model for df

        Parameters
        ----------
        connection
            A connection created, for example, directly from a sqlalchemy
            engine (see https://docs.sqlalchemy.org/en/13/core/engines.html
            and the examples below) or obtained from pangres' transaction
            handler class (pangres.transaction.TransactionHandler)
        df
            A pandas DataFrame
        table_name
            Name of the SQL table
        schema
            Name of the schema that contains/will contain the table.
            For PostgreSQL, defaults to "public" if not provided.
        dtype : None or dict {str:SQL_TYPE}, default None
            Similar to the dtype argument of pd.to_sql.
            This is especially useful for MySQL, where the length of
            text primary keys has to be provided (see Examples)

        Examples
        --------
        >>> from sqlalchemy import create_engine
        >>>
        >>> engine = create_engine("sqlite://")
        >>> df = pd.DataFrame({'name':['Albert', 'Toto'],
        ...                    'profileid':[10, 11]}).set_index('profileid')
        >>>
        >>> with engine.connect() as connection: # doctest: +SKIP
        ...     pse = PandasSpecialEngine(connection=connection, df=df, table_name='example')
        ...     print(pse)
        PandasSpecialEngine (id 123456, hexid 0x123456)
        * connection: <sqlalchemy.engine.base.Connection...>
        * schema: None
        * table: example
        * SQLalchemy table model:
        Table('example', MetaData(bind=<sqlalchemy.engine.base.Connection...>),
              Column('profileid', BigInteger(), table=<example>, primary_key=True, nullable=False),
              Column('name', Text(), table=<example>), schema=None)
        * df.head():
        |   profileid | name   |
        |------------:|:-------|
        |          10 | Albert |
        |          11 | Toto   |
        """
        self._db_type = self._detect_db_type(connection)
        if self._db_type == "postgres":
            schema = 'public' if schema is None else schema
            # raise if we find columns with "(", ")" or "%"
            bad_col_names = [col for col in df.columns if RE_BAD_COL_NAME.search(col)]
            if len(bad_col_names) > 0:
                err = ("psycopg2 (Python postgres driver) does not seem to support "
                       "column names with '%', '(' or ')' "
                       "(see https://github.com/psycopg/psycopg2/issues/167). You need to fix "
                       f"these names: {bad_col_names}")
                raise BadColumnNamesException(err)

        # VERIFY ARGUMENTS
        # all index levels have names
        index_names = list(df.index.names)
        if any(ix_name is None for ix_name in index_names):
            raise UnnamedIndexLevelsException("All index levels must be named!")

        # index is unique
        if not df.index.is_unique:
            err = ("The index must be unique since it is used "
                   "as primary key.\n"
                   "Check duplicates using this code (assuming df "
                   " is the DataFrame you want to upsert):\n"
                   ">>> df.index[df.index.duplicated(keep=False)]")
            raise DuplicateValuesInIndexException(err)

        # there are no duplicated names
        fields = list(df.index.names) + df.columns.tolist()
        if len(set(fields)) != len(fields):
            duplicated_labels = [c for c in fields if fields.count(c) > 1]
            raise DuplicateLabelsException("Found duplicates across index "
                                           f"and columns: {duplicated_labels}")

        # detect json columns
        def is_json(col: str):
            s = df[col].dropna()
            return (not s.empty and
                    s.map(lambda x: isinstance(x, (list, dict))).all())
        json_cols = [col for col in df.columns if is_json(col)]

        # merge with dtype from user
        new_dtype = {c:JSON for c in json_cols}
        if dtype is not None:
            new_dtype.update(dtype)
        new_dtype = None if new_dtype == {} else new_dtype

        # create sqlalchemy table model via pandas
        pandas_sql_engine = pd.io.sql.SQLDatabase(engine=connection, schema=schema)
        pandas_table = pd.io.sql.SQLTable(name=table_name,
                                          pandas_sql_engine=pandas_sql_engine,
                                          frame=df,
                                          dtype=new_dtype)

        # turn pandas table into a pure sqlalchemy table
        # inspired from https://github.com/pandas-dev/pandas/blob/main/pandas/io/sql.py#L815-L821
        metadata = MetaData(bind=connection)
        if _sqla_gt14():
            table = pandas_table.table.to_metadata(metadata)
        else:
            table = pandas_table.table.tometadata(metadata)

        # add PK
        constraint = PrimaryKeyConstraint(*[table.columns[name] for name in df.index.names])
        table.append_constraint(constraint)

        # FIX: make sure there is no autoincrement
        # see https://docs.sqlalchemy.org/en/14/dialects/mysql.html#auto-increment-behavior
        for name in df.index.names:
            table.columns[name].autoincrement = False

        # add remaining attributes
        self.connection = connection
        self.df = df
        self.schema = schema
        self.table = table
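A small sketch checking the resulting table model, reusing the setup from the docstring example above: the index levels become the primary key and autoincrement is disabled:

import pandas as pd
from sqlalchemy import create_engine

engine = create_engine('sqlite://')
df = pd.DataFrame({'name': ['Albert'], 'profileid': [10]}).set_index('profileid')
with engine.connect() as connection:
    pse = PandasSpecialEngine(connection=connection, df=df, table_name='example')
    assert [col.name for col in pse.table.primary_key.columns] == ['profileid']
    assert pse.table.columns['profileid'].autoincrement is False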