Ejemplo n.º 1
0
    def test_pos_spec(self, conn):
        """Positional placeholders may be referenced in any order."""
        expected = 'select "field" from "table"'
        in_order = sql.SQL("select {0} from {1}").format(
            sql.Identifier("field"), sql.Identifier("table")
        )
        swapped = sql.SQL("select {1} from {0}").format(
            sql.Identifier("table"), sql.Identifier("field")
        )
        for composed in (in_order, swapped):
            rendered = composed.as_string(conn)
            assert isinstance(rendered, str)
            assert rendered == expected
Ejemplo n.º 2
0
    def _get_type_name(self, tx, schema, value):
        """Return the SQL type (as a Composable) to use for *value*.

        A ``(list, str)`` schema is special-cased to ``text[]``, because such
        values are passed as unknown and come back as text. Otherwise the
        dumper's oid is looked up in the connection's type registry, falling
        back to ``text``. When the oid is the array oid of the looked-up type,
        ``name[]`` is returned.
        """
        if schema == (list, str):
            return sql.SQL("text[]")

        types = self.conn.adapters.types
        dumper = tx.get_dumper(value, self.format)
        # Dumping once loads the oid if it is dynamic (e.g. for arrays).
        dumper.dump(value)
        oid = dumper.oid
        info = types.get(oid) or types.get("text")
        type_name = sql.Identifier(info.name)
        if oid != info.array_oid:
            return type_name
        return sql.SQL("{}[]").format(type_name)
Ejemplo n.º 3
0
 def _make_columns_clause(columns: List[str] = None) -> psql.Composable:
     """Create the column list of a SELECT statement in psql syntax.

     A falsy *columns* (``None`` or an empty list) selects every column with
     ``*``. Otherwise each name is quoted as an identifier and the names are
     joined with commas.

     Raises:
         TypeError: if *columns* is a single string rather than a list.
     """
     if isinstance(columns, str):
         raise TypeError(
             f"'columns' must be a list of column names. Got '{columns}'")
     if not columns:
         return psql.SQL("*")
     identifiers = [psql.Identifier(name) for name in columns]
     return psql.SQL(",").join(identifiers)
Ejemplo n.º 4
0
def main() -> None:
    """Create table ``testdec`` and fill it with random numeric rows.

    Uses module-level names defined outside this function: ``ncols``
    (number of columns), ``nrows`` (number of rows), ``test`` (``"copy"``
    or ``"insert"``) and ``format``. ``format`` is presumably a psycopg
    ``Format`` enum: its ``.name`` and its integer value are both used here.
    TODO confirm.

    NOTE(review): the connection is never committed or closed here; confirm
    this is intended for the benchmark.

    Raises:
        Exception: if ``test`` is neither ``"copy"`` nor ``"insert"``.
    """
    cnn = psycopg.connect()

    column_defs = [
        sql.SQL("{} numeric(10,2)").format(sql.Identifier(f"t{i}"))
        for i in range(ncols)
    ]
    cnn.execute(
        sql.SQL("create table testdec ({})").format(
            sql.SQL(", ").join(column_defs)))
    cur = cnn.cursor()

    def random_row():
        # One row of ncols random values with two decimal places.
        return [Decimal(randrange(10000000000)) / 100 for _ in range(ncols)]

    if test == "copy":
        statement = f"copy testdec from stdin (format {format.name})"
        with cur.copy(statement) as copy:
            for _ in range(nrows):
                copy.write_row(random_row())
    elif test == "insert":
        # %t / %b request text / binary parameter adaptation.
        placeholder = ("%t", "%b")[format]
        values = ", ".join(placeholder for _ in range(ncols))
        cur.executemany(
            "insert into testdec values (%s)" % values,
            (random_row() for _ in range(nrows)),
        )
    else:
        raise Exception(f"bad test: {test}")
Ejemplo n.º 5
0
    def test_join(self, conn):
        """SQL.join accepts a list or a Composed; an empty input gives an empty Composed."""
        parts = [sql.Identifier("foo"), sql.SQL("bar"), sql.Literal(42)]
        separator = sql.SQL(", ")

        for seq in (list(parts), sql.Composed(list(parts))):
            joined = separator.join(seq)
            assert isinstance(joined, sql.Composed)
            assert joined.as_string(conn) == '"foo", bar, 42'

        assert separator.join([]) == sql.Composed([])
Ejemplo n.º 6
0
def test_json_load_copy(conn, val, jtype, fmt_out):
    """A JSON value read back through COPY TO STDOUT equals json.loads(val)."""
    cur = conn.cursor()
    query = sql.SQL("copy (select {}::{}) to stdout (format {})").format(
        val, sql.Identifier(jtype), sql.SQL(fmt_out.name)
    )
    with cur.copy(query) as copy:
        copy.set_types([jtype])
        # The query selects one column, so each row has exactly one field.
        [got] = copy.read_row()

    assert got == json.loads(val)
Ejemplo n.º 7
0
    def __init__(self, connection):
        # Database connection this object works with.
        self.conn = connection
        # Format used for dumping values.
        self.format = PyFormat.BINARY
        self.records = []

        # Presumably filled in later by other methods; start empty.
        self._schema = self._types = self._types_names = None
        self._makers = {}
        self.table_name = sql.Identifier("fake_table")
Ejemplo n.º 8
0
 def func(self, function_name: str, *args: Any) -> Generator:
     """Call a database function.

     Runs ``SELECT * FROM function_name(%s, ...)`` with one placeholder per
     argument, through ``self.query``. Dict arguments are wrapped in
     ``Jsonb``; every other argument is passed unchanged.
     """
     params = [
         psycopg.types.json.Jsonb(arg) if isinstance(arg, dict) else arg
         for arg in args
     ]
     query = sql.SQL("SELECT * FROM {}({});").format(
         sql.Identifier(function_name),
         sql.SQL(", ").join(sql.Placeholder() * len(args)),
     )
     return self.query(query, params)
Ejemplo n.º 9
0
    def test_copy(self, conn):
        """Composed statements work with COPY FROM and COPY TO, including quoted names."""
        cur = conn.cursor()
        cur.execute("""
            create table test_compose (
                id serial primary key,
                foo text, bar text, "ba'z" text)
            """)
        table = sql.Identifier("test_compose")
        quoted_col = sql.Identifier("ba'z")

        copy_in = sql.SQL("copy {t} (id, foo, bar, {f}) from stdin").format(
            t=table, f=quoted_col)
        with cur.copy(copy_in) as copy:
            for row in ((10, "a", "b", "c"), (20, "d", "e", "f")):
                copy.write_row(row)

        copy_out = sql.SQL(
            "copy (select {f} from {t} order by id) to stdout").format(
                t=table, f=quoted_col)
        with cur.copy(copy_out) as copy:
            assert list(copy) == [b"c\n", b"f\n"]
Ejemplo n.º 10
0
    def write_append(self, target: str, dataframe: Union[pd.DataFrame,
                                                         Dict[str, List]]):
        """Write data in append-mode to a Postgres table

        Args:
            target: The database table to write to.
            dataframe: The data to write as pandas.DataFrame or as a Python dictionary
                in the format `column_name: [column_data]`

        Missing values in a DataFrame (as detected by ``notnull``) are sent as
        ``None``, i.e. SQL NULL, rather than as ``NaN``.

        Raises:
            TypeError: When the dataframe is neither a pandas.DataFrame nor a dictionary
        """
        table = psql.Identifier(target)
        if isinstance(dataframe, pd.DataFrame):
            column_names = list(dataframe.columns)
            # ``where(cond)`` on its own refills the masked cells with NaN, so
            # missing values would reach Postgres as 'NaN'. Casting to object
            # first lets those cells actually hold ``None``.
            rows = map(
                tuple,
                dataframe.astype(object).where(dataframe.notnull(), None).values,
            )
        elif isinstance(dataframe, dict):
            column_names = list(dataframe)
            # Transpose the column lists into one tuple per row.
            rows = zip(*dataframe.values())
        else:
            raise TypeError(
                "dataframe must either be a pandas DataFrame or a dict of lists"
            )
        query = psql.SQL("INSERT INTO {table}({columns}) VALUES ({values})").format(
            table=table,
            columns=psql.SQL(",").join(
                [psql.Identifier(c) for c in column_names]),
            # One %s placeholder per column.
            values=psql.SQL(", ").join(psql.Placeholder() * len(column_names)),
        )
        with self._connection.cursor() as cursor:
            cursor.executemany(query, rows)
Ejemplo n.º 11
0
 def test_init(self):
     """Identifier accepts one or more strings and rejects anything else."""
     assert isinstance(sql.Identifier("foo"), sql.Identifier)
     assert isinstance(sql.Identifier("foo", "bar", "baz"), sql.Identifier)
     # No arguments, or a non-string argument, must raise TypeError.
     for bad_args in ((), (10,), (dt.date(2016, 12, 31),)):
         with pytest.raises(TypeError):
             sql.Identifier(*bad_args)
Ejemplo n.º 12
0
    def test_executemany(self, conn):
        """A Composed mixing %s and sql.Placeholder works with executemany."""
        cur = conn.cursor()
        cur.execute("""
            create table test_compose (
                id serial primary key,
                foo text, bar text, "ba'z" text)
            """)
        rows = [(10, "a", "b", "c"), (20, "d", "e", "f")]
        columns = sql.SQL(", ").join(
            [sql.Identifier(name) for name in ("foo", "bar", "ba'z")])
        placeholders = (sql.Placeholder() * 3).join(", ")
        query = sql.SQL("insert into {0} (id, {1}) values (%s, {2})").format(
            sql.Identifier("test_compose"), columns, placeholders)
        cur.executemany(query, rows)

        cur.execute("select * from test_compose")
        assert cur.fetchall() == rows
Ejemplo n.º 13
0
    def _make_psql_query(
        self,
        source: str,
        columns: List[str] = None,
        row_filter: str = None,
        limit: int = -1,
        sample: int = -1,
        drop_duplicates: bool = False,
    ) -> psql.Composed:
        """Compose a full SQL query from the information given in the arguments.

        Args:
            source: The table name. It may be qualified with dots
                (e.g. ``"schema.table"``); each dotted part is quoted as a
                separate identifier.
            columns: List of column names to limit the reading to
            row_filter: Filter expression for selecting rows
            limit: Maximum number of rows to return (limit to first n rows)
            sample: Size of a random sample to return
            drop_duplicates: Use ``SELECT DISTINCT`` instead of ``SELECT``

        Returns:
            Prepared query for psycopg

        When both ``limit`` and ``sample`` are given, only a sample of
        ``min(limit, sample)`` rows is requested and no limit clause is built.
        """
        if limit != -1 and sample != -1:
            sample = min(limit, sample)
            limit = -1
        # Identifier("a.b") would render as the single quoted name "a.b";
        # splitting gives the qualified name "a"."b" that the docstring promises.
        table = psql.Identifier(*source.split("."))
        select = psql.SQL("SELECT DISTINCT") if drop_duplicates else psql.SQL(
            "SELECT")
        columns_clause = self._make_columns_clause(columns)
        where_clause = self._make_where_clause(row_filter)
        limit_clause = self._make_limit_clause(limit)
        sample_clause = self._make_sample_clause(sample)
        parts = [
            select,
            columns_clause,
            psql.SQL("FROM"),
            table,
            where_clause,
            sample_clause,
            limit_clause,
        ]
        # Falsy parts (presumably None from helpers with nothing to add) are
        # dropped.
        return psql.SQL(" ").join([part for part in parts if part])
Ejemplo n.º 14
0
    def write_replace(self, target: str, dataframe: Union[pd.DataFrame,
                                                          Dict[str, List]]):
        """Replace the whole content of a Postgres table with new data.

        Every existing row of *target* is deleted, and the data is then
        written with ``self.write_append``.

        Args:
            target: The database table to write to.
            dataframe: The data to write as pandas.DataFrame or as a Python dictionary
                in the format `column_name: [column_data]`

        Raises:
            TypeError: When the dataframe is neither a pandas.DataFrame nor a dictionary
        """
        is_supported = isinstance(dataframe, (pd.DataFrame, dict))
        if not is_supported:
            raise TypeError("dataframe must either be a pandas DataFrame "
                            f"or a dict of lists but was {dataframe}")
        delete_all = psql.SQL("DELETE FROM {table}").format(
            table=psql.Identifier(target))
        with self._connection.cursor() as cursor:
            cursor.execute(delete_all)
        self.write_append(target, dataframe)
Ejemplo n.º 15
0
 def test_repr(self):
     """repr(Composed) lists its parts; str() must give the same text."""
     composed = sql.Composed([sql.Literal("foo"), sql.Identifier("b'ar")])
     expected = """Composed([Literal('foo'), Identifier("b'ar")])"""
     assert repr(composed) == expected
     assert str(composed) == repr(composed)
Ejemplo n.º 16
0
import pytest

import psycopg
from psycopg import sql
from psycopg.pq import TransactionStatus
from psycopg.types import TypeInfo


@pytest.mark.parametrize("name", ["text", sql.Identifier("text")])
@pytest.mark.parametrize("status", ["IDLE", "INTRANS"])
def test_fetch(conn, name, status):
    """TypeInfo.fetch finds ``text`` and leaves the transaction status unchanged."""
    expected_status = getattr(TransactionStatus, status)
    if expected_status == TransactionStatus.INTRANS:
        # Run a statement so the connection is inside a transaction.
        conn.execute("select 1")

    assert conn.info.transaction_status == expected_status
    info = TypeInfo.fetch(conn, name)
    assert conn.info.transaction_status == expected_status

    text_type = psycopg.adapters.types["text"]
    assert info.name == "text"
    # TODO: add the schema?
    # assert info.schema == "pg_catalog"
    assert info.oid == text_type.oid
    assert info.array_oid == text_type.array_oid
    assert info.alt_name == "text"


@pytest.mark.asyncio
@pytest.mark.parametrize("name", ["text", sql.Identifier("text")])
@pytest.mark.parametrize("status", ["IDLE", "INTRANS"])
Ejemplo n.º 17
0
 def test_dict(self, conn):
     """Named placeholders are filled from keyword arguments."""
     template = sql.SQL("select {f} from {t}")
     rendered = template.format(
         t=sql.Identifier("table"), f=sql.Identifier("field")
     ).as_string(conn)
     assert isinstance(rendered, str)
     assert rendered == 'select "field" from "table"'
Ejemplo n.º 18
0
    def load_partition(
        self,
        partition: Partition,
        items: Iterable[Dict[str, Any]],
        insert_mode: Optional[Methods] = Methods.insert,
    ) -> None:
        """Load items data for a single partition.

        If ``partition.requires_update`` is set, the ``partitions`` row named
        ``partition.name`` is first locked with ``FOR UPDATE``. The partition's
        datetime ranges are then upserted into ``partitions`` in one
        transaction, and the flag is cleared.

        The items are written in a second transaction according to
        ``insert_mode``:

        * ``None`` / ``Methods.insert``: ``COPY`` directly into the partition
          table.
        * ``Methods.insert_ignore`` / ``Methods.ignore``: ``COPY`` into the
          temporary table ``items_ingest_temp``, then insert into the
          partition with ``ON CONFLICT DO NOTHING``.
        * ``Methods.upsert``: like the above, but rows that conflict on ``id``
          are updated when they differ.
        * ``Methods.delsert``: rows of ``items`` that match an incoming
          ``(id, collection)`` but differ from it are deleted. Only those
          incoming rows are then inserted into the partition table.

        Each item dict has its ``"partition"`` key popped (so the caller's
        dicts are mutated, and a missing key raises ``KeyError``).

        Raises:
            Exception: if ``insert_mode`` is not one of the supported modes.
        """
        conn = self.db.connect()
        t = time.perf_counter()

        logger.debug(f"Loading data for partition: {partition}.")
        with conn.cursor() as cur:
            if partition.requires_update:
                with conn.transaction():
                    # Lock the partition's row (if it exists) until the
                    # transaction ends.
                    cur.execute(
                        "SELECT * FROM partitions WHERE name = %s FOR UPDATE;",
                        (partition.name,),
                    )
                    cur.execute(
                        """
                        INSERT INTO partitions
                        (collection, datetime_range, end_datetime_range)
                        VALUES
                            (%s, tstzrange(%s, %s, '[]'), tstzrange(%s,%s, '[]'))
                        ON CONFLICT (name) DO UPDATE SET
                            datetime_range = EXCLUDED.datetime_range,
                            end_datetime_range = EXCLUDED.end_datetime_range
                        ;
                    """,
                        (
                            partition.collection,
                            partition.datetime_range_min,
                            partition.datetime_range_max,
                            partition.end_datetime_range_min,
                            partition.end_datetime_range_max,
                        ),
                    )
                    logger.debug(
                        f"Adding or updating partition {partition.name} "
                        f"took {time.perf_counter() - t}s"
                    )
                partition.requires_update = False
            else:
                logger.debug(f"Partition {partition.name} does not require an update.")

            with conn.transaction():
                # Restart the timer: the final log measures only the data load.
                t = time.perf_counter()
                if insert_mode in (
                    None,
                    Methods.insert,
                ):
                    with cur.copy(
                        sql.SQL(
                            """
                            COPY {}
                            (id, collection, datetime, end_datetime, geometry, content)
                            FROM stdin;
                            """
                        ).format(sql.Identifier(partition.name))
                    ) as copy:
                        for item in items:
                            item.pop("partition")
                            copy.write_row(
                                (
                                    item["id"],
                                    item["collection"],
                                    item["datetime"],
                                    item["end_datetime"],
                                    item["geometry"],
                                    item["content"],
                                )
                            )
                    logger.debug(cur.statusmessage)
                    logger.debug(f"Rows affected: {cur.rowcount}")
                elif insert_mode in (
                    Methods.insert_ignore,
                    Methods.upsert,
                    Methods.delsert,
                    Methods.ignore,
                ):
                    # Stage the rows in a temp table shaped like ``items``. It
                    # is dropped at commit.
                    cur.execute(
                        """
                        DROP TABLE IF EXISTS items_ingest_temp;
                        CREATE TEMP TABLE items_ingest_temp
                        ON COMMIT DROP AS SELECT * FROM items LIMIT 0;
                        """
                    )
                    with cur.copy(
                        """
                        COPY items_ingest_temp
                        (id, collection, datetime, end_datetime, geometry, content)
                        FROM stdin;
                        """
                    ) as copy:
                        for item in items:
                            item.pop("partition")
                            copy.write_row(
                                (
                                    item["id"],
                                    item["collection"],
                                    item["datetime"],
                                    item["end_datetime"],
                                    item["geometry"],
                                    item["content"],
                                )
                            )
                    logger.debug(cur.statusmessage)
                    logger.debug(f"Copied rows: {cur.rowcount}")

                    cur.execute(
                        sql.SQL(
                            """
                                LOCK TABLE ONLY {} IN EXCLUSIVE MODE;
                            """
                        ).format(sql.Identifier(partition.name))
                    )
                    if insert_mode in (
                        Methods.ignore,
                        Methods.insert_ignore,
                    ):
                        cur.execute(
                            sql.SQL(
                                """
                                INSERT INTO {}
                                SELECT *
                                FROM items_ingest_temp ON CONFLICT DO NOTHING;
                                """
                            ).format(sql.Identifier(partition.name))
                        )
                        logger.debug(cur.statusmessage)
                        logger.debug(f"Rows affected: {cur.rowcount}")
                    elif insert_mode == Methods.upsert:
                        # Only rewrite rows whose content actually changed.
                        cur.execute(
                            sql.SQL(
                                """
                                INSERT INTO {} AS t SELECT * FROM items_ingest_temp
                                ON CONFLICT (id) DO UPDATE
                                SET
                                    datetime = EXCLUDED.datetime,
                                    end_datetime = EXCLUDED.end_datetime,
                                    geometry = EXCLUDED.geometry,
                                    collection = EXCLUDED.collection,
                                    content = EXCLUDED.content
                                WHERE t IS DISTINCT FROM EXCLUDED
                                ;
                            """
                            ).format(sql.Identifier(partition.name))
                        )
                        logger.debug(cur.statusmessage)
                        logger.debug(f"Rows affected: {cur.rowcount}")
                    elif insert_mode == Methods.delsert:
                        # NOTE(review): only incoming rows that matched and
                        # replaced an existing, different row are inserted.
                        # Brand-new items are not inserted by this branch;
                        # confirm this is intended.
                        cur.execute(
                            sql.SQL(
                                """
                                WITH deletes AS (
                                    DELETE FROM items i USING items_ingest_temp s
                                        WHERE
                                            i.id = s.id
                                            AND i.collection = s.collection
                                            AND i IS DISTINCT FROM s
                                    RETURNING i.id, i.collection
                                )
                                INSERT INTO {}
                                SELECT s.* FROM
                                    items_ingest_temp s
                                    JOIN deletes d
                                    USING (id, collection);
                                ;
                            """
                            ).format(sql.Identifier(partition.name))
                        )
                        logger.debug(cur.statusmessage)
                        logger.debug(f"Rows affected: {cur.rowcount}")
                else:
                    raise Exception(
                        "Available modes are insert, insert_ignore, ignore, upsert, "
                        f"and delsert. You entered {insert_mode}."
                    )
        logger.debug(
            f"Copying data for {partition} took {time.perf_counter() - t} seconds"
        )
Ejemplo n.º 19
0
 def fields_names(self):
     """Return the identifiers ``fld_0 .. fld_{n-1}``, one per schema entry."""
     count = len(self.schema)
     return [sql.Identifier(f"fld_{idx}") for idx in range(count)]
Ejemplo n.º 20
0
 def test_join(self):
     """Identifier must not expose a join() method."""
     ident = sql.Identifier("foo")
     assert not hasattr(ident, "join")
Ejemplo n.º 21
0
 def test_as_bytes(self, conn, args, want, enc):
     """as_bytes encodes the quoted identifier in the client encoding."""
     expected = want.encode(enc)
     pg_encoding = py2pgenc(enc).decode()
     conn.execute(f"set client_encoding to {pg_encoding}")
     assert sql.Identifier(*args).as_bytes(conn) == expected
Ejemplo n.º 22
0
 def test_as_string(self, conn, args, want):
     """as_string renders the identifier built from *args* as *want*."""
     identifier = sql.Identifier(*args)
     assert identifier.as_string(conn) == want
Ejemplo n.º 23
0
 def test_eq(self):
     """Identifiers are equal only to Identifiers with the same parts."""
     foo = sql.Identifier("foo")
     assert foo == sql.Identifier("foo")
     assert sql.Identifier("foo", "bar") == sql.Identifier("foo", "bar")
     assert foo != sql.Identifier("bar")
     assert foo != "foo"
     assert foo != sql.SQL("foo")
Ejemplo n.º 24
0
 def test_eq(self):
     """Composed objects compare by their parts."""
     parts = [sql.Literal("foo"), sql.Identifier("b'ar")]
     other_parts = [sql.Literal("foo"), sql.Literal("b'ar")]
     # Equal content in a different list object is still equal.
     assert sql.Composed(parts) == sql.Composed(parts[:])
     # A Composed never equals a bare list...
     assert sql.Composed(parts) != parts
     # ...nor a Composed whose parts differ in type.
     assert sql.Composed(parts) != sql.Composed(other_parts)
Ejemplo n.º 25
0
 def test_join(self, conn):
     """Composed.join accepts a plain string separator."""
     composed = sql.Composed([sql.Literal("foo"), sql.Identifier("b'ar")])
     joined = composed.join(", ")
     assert isinstance(joined, sql.Composed)
     assert noe(joined.as_string(conn)) == "'foo', \"b'ar\""
Ejemplo n.º 26
0
 def _quote_ident(value, conn):
     """Return *value* rendered as a quoted SQL identifier for *conn*."""
     ident = sql.Identifier(value)
     return ident.as_string(conn)
Ejemplo n.º 27
0
 def test_as_str(self, conn):
     """Identifier parts are double-quoted and joined with dots."""
     cases = [
         (("foo",), '"foo"'),
         (("foo", "bar"), '"foo"."bar"'),
         # Single quotes are kept as-is; double quotes are doubled.
         (("fo'o", 'ba"r'), '"fo\'o"."ba""r"'),
     ]
     for parts, expected in cases:
         assert sql.Identifier(*parts).as_string(conn) == expected