Exemple #1
0
    def test_with_standalone_aliased_join(self):
        """A VALUES wrapped with the standalone ``alias()`` joins to a table.

        The alias name is expected to show up both as the qualifier of the
        selected columns and in the ``AS bookcases (...)`` clause.
        """
        people = self.tables.people
        bookcase_rows = [(1, 1), (2, 1), (3, 2), (3, 3)]

        unnamed = Values(
            column("bookcase_id", Integer),
            column("bookcase_owner_id", Integer),
        ).data(bookcase_rows)
        bookcases = alias(unnamed, "bookcases")

        # Operand order is kept as-is; the expected SQL below depends on how
        # this comparison is rendered.
        onclause = bookcases.c.bookcase_owner_id == people.c.people_id
        stmt = select(people, bookcases).select_from(
            people.join(bookcases, onclause)
        )

        expected_sql = (
            "SELECT people.people_id, people.age, people.name, "
            "bookcases.bookcase_id, bookcases.bookcase_owner_id FROM people "
            "JOIN (VALUES (:param_1, :param_2), (:param_3, :param_4), "
            "(:param_5, :param_6), (:param_7, :param_8)) AS bookcases "
            "(bookcase_id, bookcase_owner_id) "
            "ON people.people_id = bookcases.bookcase_owner_id"
        )
        # Bind parameters are numbered row by row, left to right.
        flat_values = [value for row in bookcase_rows for value in row]
        expected_params = {
            "param_%d" % position: value
            for position, value in enumerate(flat_values, 1)
        }

        self.assert_compile(stmt, expected_sql, checkparams=expected_params)
Exemple #2
0
 def test_with_join_unnamed(self):
     """An unnamed VALUES joined to a table renders without an AS clause."""
     people = self.tables.people

     column_defs = (
         column("column1", Integer),
         column("column2", Integer),
     )
     values = Values(*column_defs).data([(1, 1), (2, 1), (3, 2), (3, 3)])

     joined = people.join(values, values.c.column2 == people.c.people_id)
     stmt = select(people, values).select_from(joined)

     # With no name on the VALUES, its columns are expected unqualified.
     expected_sql = (
         "SELECT people.people_id, people.age, people.name, column1, "
         "column2 FROM people JOIN (VALUES (:param_1, :param_2), "
         "(:param_3, :param_4), (:param_5, :param_6), "
         "(:param_7, :param_8)) "
         "ON people.people_id = column2"
     )
     expected_params = dict(
         param_1=1,
         param_2=1,
         param_3=2,
         param_4=1,
         param_5=3,
         param_6=2,
         param_7=3,
         param_8=3,
     )
     self.assert_compile(stmt, expected_sql, checkparams=expected_params)
Exemple #3
0
 def test_lateral(self):
     """A named VALUES turned into a LATERAL can be joined ``ON true``."""
     people = self.tables.people

     bookcases = Values(
         column("bookcase_id", Integer),
         column("bookcase_owner_id", Integer),
         name="bookcases",
     )
     bookcases = bookcases.data([(1, 1), (2, 1), (3, 2), (3, 3)])
     lateral_bookcases = bookcases.lateral()

     stmt = select(people, lateral_bookcases).select_from(
         people.join(lateral_bookcases, true())
     )

     expected_sql = (
         "SELECT people.people_id, people.age, people.name, "
         "bookcases.bookcase_id, bookcases.bookcase_owner_id FROM people "
         "JOIN LATERAL (VALUES (:param_1, :param_2), (:param_3, :param_4), "
         "(:param_5, :param_6), (:param_7, :param_8)) AS bookcases "
         "(bookcase_id, bookcase_owner_id) "
         "ON true"
     )
     expected_params = dict(
         param_1=1,
         param_2=1,
         param_3=2,
         param_4=1,
         param_5=3,
         param_6=2,
         param_7=3,
         param_8=3,
     )
     self.assert_compile(stmt, expected_sql, checkparams=expected_params)
Exemple #4
0
def get_obj_id_values(obj_ids):
    """Build a VALUES table pairing each Obj ID with its list position.

    Parameters
    ----------
    obj_ids: `list`
        Ordered list of Obj IDs.

    Returns
    -------
    values_table: `sqlalchemy.sql.expression.FromClause`
        VALUES construct aliased as ``values_table`` with the columns
        ``id`` (String) and ``ordering`` (Integer); ``ordering`` is the
        zero-based index of the ID within ``obj_ids``.
    """
    id_column = column("id", String)
    ordering_column = column("ordering", Integer)

    # Each row is (obj_id, position) so the original order can be recovered.
    ordered_rows = [
        (obj_id, position) for position, obj_id in enumerate(obj_ids)
    ]

    unnamed_values = Values(id_column, ordering_column).data(ordered_rows)
    return unnamed_values.alias("values_table")
Exemple #5
0
 def go(literal_binds):
     """Return the two-row ``myvalues`` VALUES fixture.

     ``literal_binds`` is forwarded unchanged to the ``Values`` constructor.
     """
     column_defs = (
         column("mykey", Integer),
         column("mytext", String),
         column("myint", Integer),
     )
     rows = [(1, "textA", 99), (2, "textB", 88)]
     myvalues = Values(
         *column_defs, name="myvalues", literal_binds=literal_binds
     )
     return myvalues.data(rows)
Exemple #6
0
    def test_wrong_number_of_elements(self):
        """Stringifying a VALUES whose rows are too long raises ArgumentError.

        Two columns are declared but every data tuple carries three elements.
        """
        two_columns = (
            column("CaseSensitive", Integer),
            column("has spaces", String),
        )
        three_element_rows = [(1, "textA", 99), (2, "textB", 88)]
        v1 = Values(*two_columns, name="Spaces and Cases").data(
            three_element_rows
        )

        expected_message = (
            r"Wrong number of elements for 2-tuple: \(1, 'textA', 99\)"
        )
        with expect_raises_message(exc.ArgumentError, expected_message):
            str(v1)
Exemple #7
0
        def go(literal_binds, omit=None):
            """Return the ``myvalues`` fixture, optionally with untyped columns.

            ``omit`` is an iterable of column positions whose typed column is
            replaced by a same-named column created without a type.
            """
            cols = [
                column(col_name, col_type)
                for col_name, col_type in (
                    ("mykey", Integer),
                    ("mytext", String),
                    ("myint", Integer),
                )
            ]
            # A falsy ``omit`` (None or empty) leaves every column typed.
            for position in omit or ():
                cols[position] = column(cols[position].name)

            rows = [(1, "textA", 99), (2, "textB", 88)]
            myvalues = Values(
                *cols, name="myvalues", literal_binds=literal_binds
            )
            return myvalues.data(rows)
Exemple #8
0
    def test_from_linting_unnamed(self):
        """FROM linting warns about an unjoined, unnamed VALUES element.

        The regex accepts either ordering of the two FROM elements in the
        warning text.
        """
        people = self.tables.people
        column_defs = (
            column("bookcase_id", Integer),
            column("bookcase_owner_id", Integer),
        )
        values = Values(*column_defs).data([(1, 1), (2, 1), (3, 2), (3, 3)])

        # No join condition between the two FROM elements.
        stmt = select(people, values)

        cartesian_warning = (
            r"SELECT statement has a cartesian product between FROM "
            r'element\(s\) "(?:\(unnamed VALUES element\)|people)" and '
            r'FROM element "(?:people|\(unnamed VALUES element\))"'
        )
        with testing.expect_warnings(cartesian_warning):
            stmt.compile(linting=FROM_LINTING)
Exemple #9
0
def get_values_table_and_condition(df):
    """Build a VALUES table from a photometry dataframe plus a match condition.

    The VALUES table holds one row per dataframe row. The condition equates
    the Photometry columns with the corresponding VALUES columns so the two
    can be cross-matched.

    Parameters
    ----------
    df: `pandas.DataFrame`
        Dataframe with the columns 'obj_id', 'instrument_id', 'origin',
        'mjd', 'standardized_fluxerr', 'standardized_flux'.

    Returns
    -------
    values_table: `sqlalchemy.sql.expression.FromClause`
        VALUES construct aliased as ``values_table``. Its ``pdidx`` column
        carries the dataframe index; ``fluxerr`` and ``flux`` carry the
        standardized values.

    condition: SQLAlchemy boolean clause
        AND of equality comparisons between Photometry and the VALUES table
        on obj_id, instrument_id, origin, mjd, fluxerr and flux (``pdidx``
        is not part of the comparison).
    """
    column_defs = (
        column("pdidx", sa.Integer),
        column("obj_id", sa.String),
        column("instrument_id", sa.Integer),
        column("origin", sa.String),
        column("mjd", sa.Float),
        column("fluxerr", sa.Float),
        column("flux", sa.Float),
    )

    # Numeric fields are passed through float() before being bound.
    data_rows = []
    for row in df.itertuples():
        data_rows.append(
            (
                row.Index,
                row.obj_id,
                row.instrument_id,
                row.origin,
                float(row.mjd),
                float(row.standardized_fluxerr),
                float(row.standardized_flux),
            )
        )

    values_table = Values(*column_defs).data(data_rows).alias("values_table")

    # make sure no duplicate data are posted using the index
    matched_fields = (
        "obj_id",
        "instrument_id",
        "origin",
        "mjd",
        "fluxerr",
        "flux",
    )
    comparisons = [
        getattr(Photometry, field) == getattr(values_table.c, field)
        for field in matched_fields
    ]
    condition = and_(*comparisons)

    return values_table, condition
Exemple #10
0
    def _foreign_key_view_statement(self, tables, values=None):
        """Build the query listing foreign-key columns per table.

        The base query reads ``information_schema`` and aggregates, per table
        in ``tables``, the column names of its FOREIGN KEY constraints into a
        ``foreign_keys`` array. When ``values`` is truthy, a VALUES select
        built from its ``(table_name, foreign_keys)`` pairs is unioned on.

        SAWarnings are suppressed for the whole construction.
        """
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', category=sa.exc.SAWarning)

            constraints = self.model(
                'table_constraints', 'information_schema')
            key_usage = self.model(
                'key_column_usage', 'information_schema')
            column_usage = self.model(
                'constraint_column_usage', 'information_schema')

            foreign_keys = sa.func.ARRAY_AGG(
                sa.cast(key_usage.c.column_name, sa.TEXT),
            ).label('foreign_keys')

            # Both joins match on constraint name and schema.
            key_usage_on = sa.and_(
                key_usage.c.constraint_name == constraints.c.constraint_name,
                key_usage.c.table_schema == constraints.c.table_schema,
            )
            column_usage_on = sa.and_(
                column_usage.c.constraint_name ==
                constraints.c.constraint_name,
                column_usage.c.table_schema == constraints.c.table_schema,
            )

            query = sa.select([constraints.c.table_name, foreign_keys])
            query = query.join(key_usage, key_usage_on)
            query = query.join(column_usage, column_usage_on)
            query = query.where(
                constraints.c.table_name.in_(tables),
                constraints.c.constraint_type == 'FOREIGN KEY',
            )
            query = query.group_by(constraints.c.table_name)

            if values:
                extra_rows = [
                    (value[0], array(value[1])) for value in values
                ]
                extra_table = Values(
                    sa.column('table_name'),
                    sa.column('foreign_keys'),
                ).data(extra_rows).alias('t')
                query = query.union(sa.select(extra_table))
        return query
Exemple #11
0
 def test_column_quoting(self):
     """Names needing quoting are quoted in the column list and alias.

     The VALUES declares three columns to match the three-element data
     tuples; a VALUES whose column count differs from the tuple length is
     rejected with ``ArgumentError`` (see ``test_wrong_number_of_elements``).
     """
     v1 = Values(
         column("CaseSensitive", Integer),
         column("has spaces", String),
         # Third column so the column count matches the 3-element rows.
         column("number", Integer),
         name="Spaces and Cases",
     ).data([(1, "textA", 99), (2, "textB", 88)])
     self.assert_compile(
         select(v1),
         'SELECT "Spaces and Cases"."CaseSensitive", '
         '"Spaces and Cases"."has spaces", "Spaces and Cases".number FROM '
         "(VALUES (:param_1, :param_2, :param_3), "
         "(:param_4, :param_5, :param_6)) "
         'AS "Spaces and Cases" ("CaseSensitive", "has spaces", number)',
     )
Exemple #12
0
        def go(literal_binds, omit=None):
            """Return the ``myvalues`` fixture built from wrapper-typed data.

            ``omit`` is an iterable of column positions whose typed column is
            replaced by a same-named column created without a type.
            """
            cols = [
                column(col_name, col_type)
                for col_name, col_type in (
                    ("mykey", Integer),
                    ("mytext", String),
                    ("myenum", Enum(SomeEnum)),
                )
            ]
            # A falsy ``omit`` (None or empty) leaves every column typed.
            for position in omit or ():
                cols[position] = column(cols[position].name)

            rows = [
                (MumPyNumber(1), MumPyString("textA"), one),
                (MumPyNumber(2), MumPyString("textB"), two),
            ]
            myvalues = Values(
                *cols, name="myvalues", literal_binds=literal_binds
            )
            return myvalues.data(rows)
Exemple #13
0
 def test_anon_alias(self):
     """A VALUES aliased without a name is rendered as ``anon_1``."""
     people = self.tables.people

     unnamed = Values(
         column("bookcase_id", Integer),
         column("bookcase_owner_id", Integer),
     ).data([(1, 1), (2, 1), (3, 2), (3, 3)])
     anon_values = unnamed.alias()

     onclause = anon_values.c.bookcase_owner_id == people.c.people_id
     stmt = select(people, anon_values).select_from(
         people.join(anon_values, onclause)
     )

     expected_sql = (
         "SELECT people.people_id, people.age, people.name, "
         "anon_1.bookcase_id, anon_1.bookcase_owner_id FROM people "
         "JOIN (VALUES (:param_1, :param_2), (:param_3, :param_4), "
         "(:param_5, :param_6), (:param_7, :param_8)) AS anon_1 "
         "(bookcase_id, bookcase_owner_id) "
         "ON people.people_id = anon_1.bookcase_owner_id"
     )
     self.assert_compile(stmt, expected_sql)
Exemple #14
0
    def test_values_in_cte_literal_binds(self):
        """VALUES with ``literal_binds=True`` renders inline inside a CTE.

        Only the WHERE comparison in the second CTE is expected to remain a
        bound parameter.
        """
        temp_table = Values(
            column("col1", String),
            column("col2", Integer),
            name="temp_table",
            literal_binds=True,
        ).data([("a", 2), ("b", 3)])

        cte1 = select(temp_table).cte("cte1")
        filtered = select(cte1.c.col1).where(cte1.c.col1 == "q")
        cte2 = filtered.cte("cte2")
        stmt = select(cte2.c.col1)

        expected_sql = (
            "WITH cte1 AS (SELECT temp_table.col1 AS col1, "
            "temp_table.col2 AS col2 FROM (VALUES ('a', 2), ('b', 3)) "
            "AS temp_table (col1, col2)), "
            "cte2 AS "
            "(SELECT cte1.col1 AS col1 FROM cte1 WHERE cte1.col1 = :col1_1) "
            "SELECT cte2.col1 FROM cte2"
        )
        self.assert_compile(stmt, expected_sql, checkparams={"col1_1": "q"})
Exemple #15
0
    def test_values_in_cte_params(self):
        """VALUES inside a CTE compiles with numeric and with literal binds.

        The same statement is compiled twice against a positional "numeric"
        dialect: once checking the positional parameter order, once with
        ``literal_binds=True`` so every value is rendered inline.
        """
        temp_table = Values(
            column("col1", String),
            column("col2", Integer),
            name="temp_table",
        ).data([("a", 2), ("b", 3)])

        cte1 = select(temp_table).cte("cte1")
        filtered = select(cte1.c.col1).where(cte1.c.col1 == "q")
        cte2 = filtered.cte("cte2")
        stmt = select(cte2.c.col1)

        numeric_dialect = default.DefaultDialect()
        numeric_dialect.positional = True
        numeric_dialect.paramstyle = "numeric"

        # VALUES parameters come first, the WHERE parameter last.
        numeric_sql = (
            "WITH cte1 AS (SELECT temp_table.col1 AS col1, "
            "temp_table.col2 AS col2 FROM (VALUES (:1, :2), (:3, :4)) AS "
            "temp_table (col1, col2)), "
            "cte2 AS "
            "(SELECT cte1.col1 AS col1 FROM cte1 WHERE cte1.col1 = :5) "
            "SELECT cte2.col1 FROM cte2"
        )
        self.assert_compile(
            stmt,
            numeric_sql,
            checkpositional=("a", 2, "b", 3, "q"),
            dialect=numeric_dialect,
        )

        inline_sql = (
            "WITH cte1 AS (SELECT temp_table.col1 AS col1, "
            "temp_table.col2 AS col2 FROM (VALUES ('a', 2), ('b', 3)) "
            "AS temp_table (col1, col2)), "
            "cte2 AS "
            "(SELECT cte1.col1 AS col1 FROM cte1 WHERE cte1.col1 = 'q') "
            "SELECT cte2.col1 FROM cte2"
        )
        self.assert_compile(
            stmt,
            inline_sql,
            literal_binds=True,
            dialect=numeric_dialect,
        )
Exemple #16
0
    def _primary_key_view_statement(self, schema, tables, views):
        """Build the query listing primary-key columns per table.

        The base query reads the ``pg_catalog`` tables and aggregates, per
        indexed relation, the attribute names of its primary-key index. If
        PRIMARY_KEY_VIEW is in ``views``, its current rows are fetched, the
        view is dropped, and any fetched rows are unioned onto the query as
        a VALUES select.
        """
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', category=sa.exc.SAWarning)
            pg_class = self.model('pg_class', 'pg_catalog')
            pg_index = self.model('pg_index', 'pg_catalog')
            pg_attribute = self.model('pg_attribute', 'pg_catalog')
            pg_namespace = self.model('pg_namespace', 'pg_catalog')

        def indrelid_as_text():
            # indrelid cast to regclass and then to text; a fresh expression
            # is built on every call.
            return sa.cast(
                sa.cast(
                    pg_index.c.indrelid,
                    sa.dialects.postgresql.REGCLASS,
                ),
                sa.Text,
            )

        pg_class_x = pg_class.alias('x')

        selected = [
            indrelid_as_text().label('table_name'),
            sa.func.ARRAY_AGG(pg_attribute.c.attname).label('primary_keys'),
        ]
        conditions = [
            pg_namespace.c.nspname.notin_(['pg_catalog', 'pg_toast']),
            pg_index.c.indisprimary,
            indrelid_as_text().in_(tables),
            pg_attribute.c.attnum == sa.any_(pg_index.c.indkey),
        ]

        statement = sa.select(selected)
        statement = statement.join(
            pg_attribute,
            pg_attribute.c.attrelid == pg_index.c.indrelid,
        )
        statement = statement.join(
            pg_class,
            pg_class.c.oid == pg_index.c.indexrelid,
        )
        statement = statement.join(
            pg_class_x,
            pg_class_x.c.oid == pg_index.c.indrelid,
        )
        statement = statement.join(
            pg_namespace,
            pg_namespace.c.oid == pg_class.c.relnamespace,
        )
        statement = statement.where(*conditions)
        statement = statement.group_by(pg_index.c.indrelid)

        if PRIMARY_KEY_VIEW not in views:
            return statement

        # Read the rows before dropping the view so they can be carried over.
        existing_select = sa.select([
            sa.column('table_name'),
            sa.column('primary_keys'),
        ]).select_from(sa.text(PRIMARY_KEY_VIEW))
        values = self.fetchall(existing_select)
        self.__engine.execute(DropView(schema, PRIMARY_KEY_VIEW))

        if values:
            carried_rows = [(value[0], array(value[1])) for value in values]
            carried_table = Values(
                sa.column('table_name'),
                sa.column('primary_keys'),
            ).data(carried_rows).alias('t')
            statement = statement.union(sa.select(carried_table))

        return statement
Exemple #17
0
    def sync(
        self,
        filters: Optional[dict] = None,
        txmin: Optional[int] = None,
        txmax: Optional[int] = None,
        extra: Optional[dict] = None,
        ctid: Optional[dict] = None,
    ) -> Generator:
        """Build the per-node queries and yield one search document per row.

        Parameters
        ----------
        filters:
            Passed to ``self._build_filters`` for every node; defaults to an
            empty dict.
        txmin, txmax:
            When truthy, restrict the root node to rows whose ``xmin``
            (cast to text, then to BigInteger) is ``>= txmin`` / ``< txmax``.
        extra:
            When truthy, must provide the keys ``"table"`` and ``"column"``;
            a ``0`` is appended to ``row[META][table][column]`` for each row.
        ctid:
            Mapping of page -> iterable of row values (it is iterated with
            ``.items()``). Each entry becomes a ``ctid = ANY(ARRAY(...))``
            filter on the root node; the filters are OR-ed together.

        Yields
        ------
        dict
            Document with ``_id``, ``_index`` and ``_source`` plus, depending
            on configuration, ``_routing``, ``_type`` and ``pipeline``.
        """
        if filters is None:
            filters: dict = {}

        root: Node = self.tree.build(self.nodes)

        self.query_builder.isouter: bool = True

        for node in root.traverse_post_order():

            self._build_filters(filters, node)

            if node.is_root:

                if ctid is not None:
                    # One subquery per page: every row value `s` is turned
                    # into the text '(page,s)' and cast to a tuple identifier.
                    subquery = []
                    for page, rows in ctid.items():
                        subquery.append(
                            sa.select([
                                sa.cast(
                                    sa.literal_column(f"'({page},'").concat(
                                        sa.column("s")).concat(")"),
                                    TupleIdentifierType,
                                )
                            ]).select_from(
                                Values(sa.column("s"), ).data([
                                    (row, ) for row in rows
                                ]).alias("s")))
                    if subquery:
                        node._filters.append(
                            sa.or_(*[
                                node.model.c.ctid == sa.any_(
                                    sa.func.ARRAY(q.scalar_subquery()))
                                for q in subquery
                            ]))

                # NOTE: truthiness checks, so a txmin/txmax of 0 is treated
                # the same as None.
                if txmin:
                    node._filters.append(
                        sa.cast(
                            sa.cast(
                                node.model.c.xmin,
                                sa.Text,
                            ),
                            sa.BigInteger,
                        ) >= txmin)
                if txmax:
                    node._filters.append(
                        sa.cast(
                            sa.cast(
                                node.model.c.xmin,
                                sa.Text,
                            ),
                            sa.BigInteger,
                        ) < txmax)

            try:
                self.query_builder.build_queries(node)
            except Exception as e:
                # Log with traceback, then let the error propagate.
                logger.exception(f"Exception {e}")
                raise

        # From here on `node` is the last node produced by the traversal
        # (presumably the root, since the traversal is post-order -- confirm
        # against Node.traverse_post_order).
        if self.verbose:
            compiled_query(node._subquery, "Query")

        count: int = self.fetchcount(node._subquery)

        with click.progressbar(
                length=count,
                show_pos=True,
                show_percent=True,
                show_eta=True,
                fill_char="=",
                empty_char="-",
                width=50,
        ) as bar:

            for i, (keys, row,
                    primary_keys) in enumerate(self.fetchmany(node._subquery)):
                bar.update(1)

                row: dict = transform(row, self.nodes)

                row[META] = get_private_keys(keys)
                if extra:
                    # Ensure row[META][table][column] exists as a list
                    # before appending to it.
                    if extra["table"] not in row[META]:
                        row[META][extra["table"]] = {}
                    if extra["column"] not in row[META][extra["table"]]:
                        row[META][extra["table"]][extra["column"]] = []
                    row[META][extra["table"]][extra["column"]].append(0)

                if self.verbose:
                    print(f"{(i+1)})")
                    print(f"Pkeys: {primary_keys}")
                    pprint.pprint(row)
                    print("-" * 10)

                doc: dict = {
                    "_id": self.get_doc_id(primary_keys),
                    "_index": self.index,
                    "_source": row,
                }

                if self.routing:
                    doc["_routing"] = row[self.routing]

                # "_type" is only set for major versions below 7 that are
                # not OpenSearch.
                if self.es.major_version < 7 and not self.es.is_opensearch:
                    doc["_type"] = "_doc"

                if self._plugins:
                    # Plugins may replace the document; a falsy result means
                    # the row is skipped.
                    doc = next(self._plugins.transform([doc]))
                    if not doc:
                        continue

                if self.pipeline:
                    doc["pipeline"] = self.pipeline

                yield doc
Exemple #18
0
    def _foreign_key_view_statement(
        self,
        schema,
        tables,
        views,
        user_defined_fkey_tables=None,
    ):
        """Build the query listing foreign-key columns per table.

        The base query reads ``information_schema`` and aggregates, per table
        in ``tables``, the column names of its FOREIGN KEY constraints into a
        ``foreign_keys`` array.

        Two optional sources are unioned onto that base query:

        * if FOREIGN_KEY_VIEW is in ``views``, its current rows are fetched
          and the view is dropped; non-empty rows are carried over;
        * if ``user_defined_fkey_tables`` is truthy, its entries are added.
          Each entry is indexed as ``entry[0]`` (table name) and ``entry[1]``
          (foreign keys), so an iterable of pairs is assumed -- confirm
          against the caller.

        Returns the resulting selectable.
        """
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', category=sa.exc.SAWarning)
            table_constraints = self.model(
                'table_constraints',
                'information_schema',
            )
            key_column_usage = self.model(
                'key_column_usage',
                'information_schema',
            )
            constraint_column_usage = self.model(
                'constraint_column_usage',
                'information_schema',
            )

        def values_select(pairs, alias_name):
            # SELECT over a VALUES table of (table_name, foreign_keys) rows.
            return sa.select(
                Values(
                    sa.column('table_name'),
                    sa.column('foreign_keys'),
                ).data(
                    [(value[0], array(value[1])) for value in pairs]
                ).alias(alias_name)
            )

        statement = sa.select([
            table_constraints.c.table_name,
            sa.func.ARRAY_AGG(
                sa.cast(
                    key_column_usage.c.column_name,
                    sa.TEXT,
                )
            ).label('foreign_keys'),
        ]).join(
            key_column_usage,
            sa.and_(
                key_column_usage.c.constraint_name ==
                table_constraints.c.constraint_name,
                key_column_usage.c.table_schema ==
                table_constraints.c.table_schema,
            )
        ).join(
            constraint_column_usage,
            sa.and_(
                constraint_column_usage.c.constraint_name ==
                table_constraints.c.constraint_name,
                constraint_column_usage.c.table_schema ==
                table_constraints.c.table_schema,
            )
        ).where(*[
            table_constraints.c.table_name.in_(tables),
            table_constraints.c.constraint_type == 'FOREIGN KEY',
        ]).group_by(
            table_constraints.c.table_name
        )

        unions = []
        if FOREIGN_KEY_VIEW in views:
            # Read the rows before dropping the view so they can be carried
            # over into the new statement.
            values = self.fetchall(sa.select([
                sa.column('table_name'),
                sa.column('foreign_keys'),
            ]).select_from(
                sa.text(FOREIGN_KEY_VIEW)
            ))
            self.__engine.execute(DropView(schema, FOREIGN_KEY_VIEW))
            if values:
                unions.append(values_select(values, 'u'))

        if user_defined_fkey_tables:
            unions.append(values_select(user_defined_fkey_tables, 'v'))

        if unions:
            # Keep the catalog query as the first member of the union.
            # Previously ``sa.union(*unions)`` replaced it outright, so the
            # foreign keys just read from information_schema were dropped
            # whenever carried-over or user-defined rows existed.
            statement = sa.union(statement, *unions)

        return statement
Exemple #19
0
    def _primary_key_view_statement(self, schema, tables, views):
        """
        Table name and primary keys association where the table name
        is the index.
        This is only called once on bootstrap.
        It is used within the trigger function to determine what payload
        values to send to pg_notify.

        Since views cannot be modified, we query the existing view for
        existing rows and union this to the next query.

        So if 'specie' was the only row before, and the next query returns
        'unit' and 'structure', we want to end up with the result below.

        table_name | primary_keys
        ------------+--------------
        specie     | {id}
        unit       | {id}
        structure  | {id}
        """

        with warnings.catch_warnings():
            warnings.simplefilter('ignore', category=sa.exc.SAWarning)
            pg_class = self.model('pg_class', 'pg_catalog')
            pg_index = self.model('pg_index', 'pg_catalog')
            pg_attribute = self.model('pg_attribute', 'pg_catalog')
            pg_namespace = self.model('pg_namespace', 'pg_catalog')

        regclass = sa.dialects.postgresql.REGCLASS
        indexed_table = pg_class.alias('x')

        # indrelid rendered as a table name: regclass first, then text.
        table_name = sa.cast(
            sa.cast(pg_index.c.indrelid, regclass),
            sa.Text,
        ).label('table_name')
        primary_keys = sa.func.ARRAY_AGG(
            pg_attribute.c.attname,
        ).label('primary_keys')

        join_steps = (
            (pg_attribute, pg_attribute.c.attrelid == pg_index.c.indrelid),
            (pg_class, pg_class.c.oid == pg_index.c.indexrelid),
            (indexed_table, indexed_table.c.oid == pg_index.c.indrelid),
            (pg_namespace, pg_namespace.c.oid == pg_class.c.relnamespace),
        )

        statement = sa.select([table_name, primary_keys])
        for target, onclause in join_steps:
            statement = statement.join(target, onclause)

        statement = statement.where(
            pg_namespace.c.nspname.notin_(['pg_catalog', 'pg_toast']),
            pg_index.c.indisprimary,
            sa.cast(
                sa.cast(pg_index.c.indrelid, regclass),
                sa.Text,
            ).in_(tables),
            pg_attribute.c.attnum == sa.any_(pg_index.c.indkey),
        ).group_by(pg_index.c.indrelid)

        if PRIMARY_KEY_VIEW in views:
            # Fetch the current rows first; the view is dropped right after.
            current_rows_query = sa.select([
                sa.column('table_name'),
                sa.column('primary_keys'),
            ]).select_from(sa.text(PRIMARY_KEY_VIEW))
            values = self.fetchall(current_rows_query)
            self.__engine.execute(DropView(schema, PRIMARY_KEY_VIEW))

            if values:
                previous = Values(
                    sa.column('table_name'),
                    sa.column('primary_keys'),
                ).data(
                    [(value[0], array(value[1])) for value in values]
                ).alias('t')
                statement = statement.union(sa.select(previous))

        return statement
Exemple #20
0
    def create_views(self, schema, tables, user_defined_fkey_tables):
        """(Re)create the primary-key and foreign-key views and their indexes.

        The primary-key view is created from
        ``_primary_key_view_statement``. For the foreign-key view, rows
        already present in an existing view, rows from
        ``_foreign_key_view_statement`` and ``user_defined_fkey_tables`` are
        merged into one dict, the old view is dropped, and a new view is
        created from a VALUES select over the merged rows. Each view gets
        its index dropped and recreated on ``table_name``.
        """
        views = sa.inspect(self.engine).get_view_names(schema)

        # -- primary keys ---------------------------------------------------
        logger.debug(f'Creating view: {schema}.{PRIMARY_KEY_VIEW}')
        primary_key_view = CreateView(
            schema,
            PRIMARY_KEY_VIEW,
            self._primary_key_view_statement(schema, tables, views),
        )
        self.__engine.execute(primary_key_view)
        self.__engine.execute(DropIndex(PRIMARY_KEY_INDEX))
        primary_key_index = CreateIndex(
            PRIMARY_KEY_INDEX,
            schema,
            PRIMARY_KEY_VIEW,
            ['table_name'],
        )
        self.__engine.execute(primary_key_index)
        logger.debug(f'Created view: {schema}.{PRIMARY_KEY_VIEW}')

        # -- foreign keys ---------------------------------------------------
        logger.debug(f'Creating view: {schema}.{FOREIGN_KEY_VIEW}')

        foreign_key_view_exists = FOREIGN_KEY_VIEW in views

        # if view exists, query it first
        rows = {}
        if foreign_key_view_exists:
            unnested_fkeys = sa.func.unnest(
                sa.column('foreign_keys')).alias('fkeys')
            existing_query = sa.select([
                sa.column('table_name').label('table_name'),
                sa.func.ARRAY_AGG(sa.column('fkeys')).label('foreign_keys'),
            ]).select_from(
                sa.text(FOREIGN_KEY_VIEW),
                unnested_fkeys,
            ).group_by(sa.column('table_name'))
            rows = _rows_to_dict(self.fetchall(existing_query))

        fresh_rows = self.fetchall(self._foreign_key_view_statement(tables))
        if fresh_rows:
            rows = _merge_dict(rows, _rows_to_dict(fresh_rows))

        if user_defined_fkey_tables:
            rows = _merge_dict(rows, user_defined_fkey_tables)

        # if view exists, drop it first, we have all the existing data
        if foreign_key_view_exists:
            self.__engine.execute(DropView(schema, FOREIGN_KEY_VIEW))

        merged_data = [(key, array(value)) for key, value in rows.items()]
        merged_table = Values(
            sa.column('table_name'),
            sa.column('foreign_keys'),
        ).data(merged_data).alias('v')
        statement = sa.select(merged_table)

        self.__engine.execute(
            CreateView(schema, FOREIGN_KEY_VIEW, statement))
        self.__engine.execute(DropIndex(FOREIGN_KEY_INDEX))
        foreign_key_index = CreateIndex(
            FOREIGN_KEY_INDEX,
            schema,
            FOREIGN_KEY_VIEW,
            ['table_name'],
        )
        self.__engine.execute(foreign_key_index)
        logger.debug(f'Created view: {schema}.{FOREIGN_KEY_VIEW}')
Exemple #21
0
def create_view(
        engine,
        schema: str,
        tables: list,
        user_defined_fkey_tables: dict,
        base: "Base",  # noqa F821
) -> None:
    """
    View describing primary_keys and foreign_keys for each table
    with an index on table_name

    This is only called once on bootstrap.
    It is used within the trigger function to determine what payload
    values to send to pg_notify.

    Since views cannot be modified, we read the rows of the existing view,
    drop it, merge those rows with the freshly queried keys and the
    user-defined foreign keys, and recreate the view from the merged rows.

    So if 'specie' was the only row before, and the next query returns
    'unit' and 'structure', we want to end up with the result below.

    table_name | primary_keys | foreign_keys
    -----------+--------------+--------------
    specie     | {id}         | {id, user_id}
    unit       | {id}         | {id, profile_id}
    structure  | {id}         | {id}

    ``tables`` may be any iterable of table names. When ``schema`` differs
    from DEFAULT_SCHEMA the schema-qualified names are added as well; a
    ``set`` argument is still extended in place.
    """

    views: list = sa.inspect(engine).get_view_names(schema)

    rows: dict = {}

    def entry_for(table_name):
        # Return the row for table_name, creating an empty one if missing.
        return rows.setdefault(
            table_name,
            {"primary_keys": set(), "foreign_keys": set()},
        )

    if MATERIALIZED_VIEW in views:
        for table_name, primary_keys, foreign_keys in base.fetchall(
                sa.select(["*"]).select_from(
                    sa.text(f"{schema}.{MATERIALIZED_VIEW}"))):
            entry = entry_for(table_name)
            if primary_keys:
                entry["primary_keys"] = set(primary_keys)
            if foreign_keys:
                entry["foreign_keys"] = set(foreign_keys)

        engine.execute(DropView(schema, MATERIALIZED_VIEW))

    if schema != DEFAULT_SCHEMA:
        qualified = {f"{schema}.{table}" for table in tables}
        if isinstance(tables, set):
            # Same in-place extension as before for callers passing a set.
            tables.update(qualified)
        else:
            # The old code called ``tables.add`` unconditionally, which
            # raised AttributeError for the annotated ``list`` type.
            tables = set(tables) | qualified

    for table_name, columns in base.fetchall(base._primary_keys(
            schema, tables)):
        entry = entry_for(table_name)
        if columns:
            entry["primary_keys"] |= set(columns)

    for table_name, columns in base.fetchall(base._foreign_keys(
            schema, tables)):
        entry = entry_for(table_name)
        if columns:
            entry["foreign_keys"] |= set(columns)

    if user_defined_fkey_tables:
        for table_name, columns in user_defined_fkey_tables.items():
            entry = entry_for(table_name)
            if columns:
                entry["foreign_keys"] |= set(columns)

    if not rows:
        # Insert a single placeholder row so the VALUES clause is not empty.
        entry_for(None)

    statement = sa.select(
        Values(
            sa.column("table_name"),
            sa.column("primary_keys"),
            sa.column("foreign_keys"),
        ).data([(
            table_name,
            # Empty key sets are stored as NULL rather than an empty array.
            array(fields["primary_keys"])
            if fields.get("primary_keys") else None,
            array(fields.get("foreign_keys"))
            if fields.get("foreign_keys") else None,
        ) for table_name, fields in rows.items()]).alias("t"))
    logger.debug(f"Creating view: {schema}.{MATERIALIZED_VIEW}")
    engine.execute(CreateView(schema, MATERIALIZED_VIEW, statement))
    engine.execute(DropIndex("_idx"))
    engine.execute(
        CreateIndex("_idx", schema, MATERIALIZED_VIEW, ["table_name"]))
    logger.debug(f"Created view: {schema}.{MATERIALIZED_VIEW}")