Example #1
 def _test_get_foreign_keys(self, schema=None):
     meta = MetaData(testing.db)
     (users, addresses) = createTables(meta, schema)
     meta.create_all()
     insp = Inspector(meta.bind)
     try:
         expected_schema = schema
         # users
         users_fkeys = insp.get_foreign_keys(users.name,
                                             schema=schema)
         fkey1 = users_fkeys[0]
         self.assert_(fkey1['name'] is not None)
         eq_(fkey1['referred_schema'], expected_schema)
         eq_(fkey1['referred_table'], users.name)
         eq_(fkey1['referred_columns'], ['user_id', ])
         eq_(fkey1['constrained_columns'], ['parent_user_id'])
         #addresses
         addr_fkeys = insp.get_foreign_keys(addresses.name,
                                            schema=schema)
         fkey1 = addr_fkeys[0]
         self.assert_(fkey1['name'] is not None)
         eq_(fkey1['referred_schema'], expected_schema)
         eq_(fkey1['referred_table'], users.name)
         eq_(fkey1['referred_columns'], ['user_id', ])
         eq_(fkey1['constrained_columns'], ['remote_user_id'])
     finally:
         addresses.drop()
         users.drop()
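Both this example and Example #2 below assume a createTables fixture that is not shown. Here is a minimal sketch of what it plausibly builds, inferred only from the assertions above (a self-referencing foreign key on users, and a foreign key from the addresses table back to users.user_id); the real fixture may differ, and the addresses table name is an assumption:

from sqlalchemy import Column, ForeignKey, Integer, String, Table

def createTables(meta, schema=None):
    prefix = schema + '.' if schema else ''
    users = Table(
        'users', meta,
        Column('user_id', Integer, primary_key=True),
        # asserted above: fkey1 on users is parent_user_id -> users.user_id
        Column('parent_user_id', Integer,
               ForeignKey('%susers.user_id' % prefix)),
        schema=schema,
    )
    addresses = Table(
        'email_addresses', meta,  # hypothetical name; never asserted
        Column('address_id', Integer, primary_key=True),
        # asserted above: fkey1 on addresses is remote_user_id -> users.user_id
        Column('remote_user_id', Integer,
               ForeignKey('%susers.user_id' % prefix)),
        Column('email_address', String(20)),
        schema=schema,
    )
    return (users, addresses)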
Example #2
 def _test_get_foreign_keys(self, schema=None):
     meta = MetaData(testing.db)
     (users, addresses) = createTables(meta, schema)
     meta.create_all()
     insp = Inspector(meta.bind)
     try:
         expected_schema = schema
         if schema is None:
             try:
                 expected_schema = meta.bind.dialect.get_default_schema_name(
                                 meta.bind)
             except NotImplementedError:
                 expected_schema = None
         # users
         users_fkeys = insp.get_foreign_keys(users.name,
                                             schema=schema)
         fkey1 = users_fkeys[0]
         self.assert_(fkey1['name'] is not None)
         self.assertEqual(fkey1['referred_schema'], expected_schema)
         self.assertEqual(fkey1['referred_table'], users.name)
         self.assertEqual(fkey1['referred_columns'], ['user_id', ])
         self.assertEqual(fkey1['constrained_columns'], ['parent_user_id'])
         #addresses
         addr_fkeys = insp.get_foreign_keys(addresses.name,
                                            schema=schema)
         fkey1 = addr_fkeys[0]
         self.assert_(fkey1['name'] is not None)
         self.assertEqual(fkey1['referred_schema'], expected_schema)
         self.assertEqual(fkey1['referred_table'], users.name)
         self.assertEqual(fkey1['referred_columns'], ['user_id', ])
         self.assertEqual(fkey1['constrained_columns'], ['remote_user_id'])
     finally:
         addresses.drop()
         users.drop()
Example #3
 def test_compare_get_foreign_keys_for_sql_and_odbc(self, schema, table,
                                                    engine_name):
     with self.engine_map[engine_name].begin() as c:
         if schema is None:
             c.execute("OPEN SCHEMA %s" % self.schema_2)
         dialect = Inspector(c).dialect
         foreign_keys_fallback = dialect.get_foreign_keys(
             connection=c,
             table_name=table,
             schema=schema,
             use_sql_fallback=True)
         foreign_keys_odbc = dialect.get_foreign_keys(connection=c,
                                                      table_name=table,
                                                      schema=schema)
         assert str(foreign_keys_fallback) == str(foreign_keys_odbc)
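The assertion above compares stringified lists, which is sensitive to element order. A hedged, order-insensitive alternative that normalizes each reflected FK dict (the key names follow the dict shape shown later in Example #8) before comparing:

def fks_equal(left, right):
    def fk_key(fk):
        # reduce one reflected FK dict to a sortable tuple
        return (fk['name'], tuple(fk['constrained_columns']),
                fk.get('referred_schema'), fk['referred_table'],
                tuple(fk['referred_columns']))
    return sorted(map(fk_key, left)) == sorted(map(fk_key, right))

# e.g.: assert fks_equal(foreign_keys_fallback, foreign_keys_odbc)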
Example #4
 def test_get_foreign_keys_table_name_none(self, use_sql_fallback,
                                           engine_name):
     with self.engine_map[engine_name].begin() as c:
         dialect = Inspector(c).dialect
         foreign_keys = dialect.get_foreign_keys(
             connection=c,
             schema=self.schema,
             table_name=None,
             use_sql_fallback=use_sql_fallback)
         assert foreign_keys == []
Example #5
def generic_find_fk_constraint_name(  # pylint: disable=invalid-name
        table: str, columns: Set[str], referenced: str,
        insp: Inspector) -> Optional[str]:
    """Utility to find a foreign-key constraint name in alembic migrations"""
    for fk in insp.get_foreign_keys(table):
        if (fk["referred_table"] == referenced
                and set(fk["referred_columns"]) == columns):
            return fk["name"]

    return None
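A hedged usage sketch for this helper inside an Alembic migration; the table and column names are placeholders, and Inspector.from_engine follows the older SQLAlchemy API used throughout these examples:

from alembic import op
from sqlalchemy.engine.reflection import Inspector

def upgrade():
    bind = op.get_bind()
    insp = Inspector.from_engine(bind)
    # FK constraint names vary across databases, so discover the name first
    fk_name = generic_find_fk_constraint_name(
        'tables', {'id'}, 'datasources', insp)
    if fk_name:
        op.drop_constraint(fk_name, 'tables', type_='foreignkey')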
Example #6
def generic_find_fk_constraint_names(table: str, columns: Set[str],
                                     referenced: str,
                                     insp: Inspector) -> Set[str]:
    """Utility to find foreign-key constraint names in alembic migrations"""
    names = set()

    for fk in insp.get_foreign_keys(table):
        if (fk["referred_table"] == referenced
                and set(fk["referred_columns"]) == columns):
            names.add(fk["name"])

    return names
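The plural variant helps when duplicate constraints have accumulated. A sketch, under the same assumptions as the previous note, that drops every match:

from alembic import op

def drop_matching_fks(table, columns, referenced, insp):
    # illustrative helper: remove all FKs pointing at the same referred columns
    for name in generic_find_fk_constraint_names(table, columns, referenced,
                                                 insp):
        op.drop_constraint(name, table, type_='foreignkey')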
Example #7
 def _get_foreign_keys(self, dataset_urn: str, inspector: Inspector,
                       schema: str,
                       table: str) -> List[ForeignKeyConstraint]:
     try:
         foreign_keys = [
             self.get_foreign_key_metadata(dataset_urn, schema, fk_rec,
                                           inspector)
             for fk_rec in inspector.get_foreign_keys(table, schema)
         ]
     except KeyError:
         # certain databases like MySQL cause issues due to lower-case/upper-case irregularities
         logger.debug(
             f"{dataset_urn}: failure in foreign key extraction... skipping"
         )
         foreign_keys = []
     return foreign_keys
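The same defensive pattern can stand alone outside the ingestion class. A self-contained sketch using only public SQLAlchemy APIs (sqlalchemy.inspect); the URL argument is a placeholder:

import logging
from typing import Any, Dict, List, Optional

from sqlalchemy import create_engine, inspect

logger = logging.getLogger(__name__)

def safe_foreign_keys(url: str, table: str,
                      schema: Optional[str] = None) -> List[Dict[str, Any]]:
    insp = inspect(create_engine(url))
    try:
        return insp.get_foreign_keys(table, schema=schema)
    except KeyError:
        # e.g. MySQL lower-/upper-case irregularities, as noted above
        logger.debug("%s.%s: failure in foreign key extraction, skipping",
                     schema, table)
        return []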
Example #8
    def test_get_foreign_keys(self, use_sql_fallback, engine_name):
        with self.engine_map[engine_name].begin() as c:
            dialect = Inspector(c).dialect
            foreign_keys = dialect.get_foreign_keys(
                connection=c,
                schema=self.schema,
                table_name="s",
                use_sql_fallback=use_sql_fallback)
            expected = [{
                'name': 'fk_test',
                'constrained_columns': ['fid1', 'fid2'],
                'referred_schema': 'test_get_metadata_functions_schema',
                'referred_table': 't',
                'referred_columns': ['pid1', 'pid2']
            }]

            assert foreign_keys == expected
Example #9
class Db(object):
    def __init__(self, sqla_conn, args, schemas=None):
        if schemas is None:
            schemas = [None]
        self.args = args
        self.sqla_conn = sqla_conn
        self.schemas = schemas
        self.engine = sa.create_engine(sqla_conn)
        self.inspector = Inspector(bind=self.engine)
        self.conn = self.engine.connect()
        self.tables = OrderedDict()

        for schema in self.schemas:
            meta = sa.MetaData(bind=self.engine)  # excised schema=schema to prevent errors
            meta.reflect(schema=schema)
            for tbl in meta.sorted_tables:
                if args.tables and not _table_matches_any_pattern(tbl.schema, tbl.name, self.args.tables):
                    continue
                if _table_matches_any_pattern(tbl.schema, tbl.name, self.args.exclude_tables):
                    continue
                tbl.db = self
                # TODO: Replace all these monkeypatches with an instance assignment
                tbl.find_n_rows = types.MethodType(_find_n_rows, tbl)
                tbl.random_row_func = types.MethodType(_random_row_func, tbl)
                tbl.fks = self.inspector.get_foreign_keys(tbl.name, schema=tbl.schema)
                tbl.pk = self.inspector.get_primary_keys(tbl.name, schema=tbl.schema)
                if not tbl.pk:
                    tbl.pk = [d["name"] for d in self.inspector.get_columns(tbl.name)]
                tbl.filtered_by = types.MethodType(_filtered_by, tbl)
                tbl.by_pk = types.MethodType(_by_pk, tbl)
                tbl.pk_val = types.MethodType(_pk_val, tbl)
                tbl.child_fks = []
                estimate_rows = not _table_matches_any_pattern(tbl.schema, tbl.name, self.args.full_tables)
                tbl.find_n_rows(estimate=estimate_rows)
                self.tables[(tbl.schema, tbl.name)] = tbl
        all_constraints = args.config.get("constraints", {})
        for ((tbl_schema, tbl_name), tbl) in self.tables.items():
            qualified = "{}.{}".format(tbl_schema, tbl_name)
            if qualified in all_constraints:
                constraints = all_constraints[qualified]
            else:
                constraints = all_constraints.get(tbl_name, [])
            tbl.constraints = constraints
            for fk in tbl.fks + constraints:
                fk["constrained_schema"] = tbl_schema
                fk["constrained_table"] = tbl_name  # TODO: check against constrained_table
                self.tables[(fk["referred_schema"], fk["referred_table"])].child_fks.append(fk)

    def __repr__(self):
        return "Db('%s')" % self.sqla_conn

    def assign_target(self, target_db):
        for ((tbl_schema, tbl_name), tbl) in self.tables.items():
            tbl._random_row_gen_fn = types.MethodType(_random_row_gen_fn, tbl)
            tbl.random_rows = tbl._random_row_gen_fn()
            tbl.next_row = types.MethodType(_next_row, tbl)
            target = target_db.tables[(tbl_schema, tbl_name)]
            target.requested = deque()
            target.required = deque()
            target.pending = dict()
            target.done = set()
            target.fetch_all = False
            if _table_matches_any_pattern(tbl.schema, tbl.name, self.args.full_tables):
                target.n_rows_desired = tbl.n_rows
                target.fetch_all = True
            else:
                if tbl.n_rows:
                    if self.args.logarithmic:
                        target.n_rows_desired = int(math.pow(10, math.log10(tbl.n_rows) * self.args.fraction)) or 1
                    else:
                        target.n_rows_desired = int(tbl.n_rows * self.args.fraction) or 1
                else:
                    target.n_rows_desired = 0
            target.source = tbl
            tbl.target = target
            target.completeness_score = types.MethodType(_completeness_score, target)
            logging.debug("assigned methods to %s" % target.name)

    def confirm(self):
        message = []
        for (tbl_schema, tbl_name) in sorted(self.tables, key=lambda t: t[1]):
            tbl = self.tables[(tbl_schema, tbl_name)]
            message.append(
                "Create %d rows from %d in %s.%s" % (tbl.target.n_rows_desired, tbl.n_rows, tbl_schema or "", tbl_name)
            )
        print("\n".join(sorted(message)))
        if self.args.yes:
            return True
        response = input("Proceed? (Y/n) ").strip().lower()
        return (not response) or (response[0] == "y")

    def create_row_in(self, source_row, target_db, target, prioritized=False):
        logging.debug("create_row_in %s:%s " % (target.name, target.pk_val(source_row)))

        pks = tuple((source_row[key] for key in target.pk))
        row_exists = pks in target.pending or pks in target.done
        logging.debug("Row exists? %s" % str(row_exists))
        if row_exists and not prioritized:
            return

        if not row_exists:
            # make sure that all required rows are in parent table(s)
            for fk in target.fks:
                target_parent = target_db.tables[(fk["referred_schema"], fk["referred_table"])]
                slct = sa.sql.select([target_parent])
                any_non_null_key_columns = False
                for (parent_col, child_col) in zip(fk["referred_columns"], fk["constrained_columns"]):
                    slct = slct.where(target_parent.c[parent_col] == source_row[child_col])
                    if source_row[child_col] is not None:
                        any_non_null_key_columns = True
                        break
                if any_non_null_key_columns:
                    target_parent_row = target_db.conn.execute(slct).first()
                    if not target_parent_row:
                        source_parent_row = self.conn.execute(slct).first()
                        self.create_row_in(source_parent_row, target_db, target_parent)

            # make sure that all referenced rows are in referenced table(s)
            for constraint in target.constraints:
                target_referred = target_db.tables[(constraint["referred_schema"], constraint["referred_table"])]
                slct = sa.sql.select([target_referred])
                any_non_null_key_columns = False
                for (referred_col, constrained_col) in zip(
                    constraint["referred_columns"], constraint["constrained_columns"]
                ):
                    slct = slct.where(target_referred.c[referred_col] == source_row[constrained_col])
                    if source_row[constrained_col] is not None:
                        any_non_null_key_columns = True
                        break
                if any_non_null_key_columns:
                    target_referred_row = target_db.conn.execute(slct).first()
                    if not target_referred_row:
                        source_referred_row = self.conn.execute(slct).first()
                        # because constraints aren't enforced like real FKs, the referred row isn't guaranteed to exist
                        if source_referred_row:
                            self.create_row_in(source_referred_row, target_db, target_referred)

            pks = tuple((source_row[key] for key in target.pk))
            target.pending[pks] = source_row
            target.n_rows += 1

        for child_fk in target.child_fks:
            child = self.tables[(child_fk["constrained_schema"], child_fk["constrained_table"])]
            slct = sa.sql.select([child])
            for (child_col, this_col) in zip(child_fk["constrained_columns"], child_fk["referred_columns"]):
                slct = slct.where(child.c[child_col] == source_row[this_col])
            if not prioritized:
                slct = slct.limit(self.args.children)
            for (n, desired_row) in enumerate(self.conn.execute(slct)):
                if prioritized:
                    child.target.required.append((desired_row, prioritized))
                elif n == 0:
                    child.target.requested.appendleft((desired_row, prioritized))
                else:
                    child.target.requested.append((desired_row, prioritized))

    @property
    def pending(self):
        return functools.reduce(lambda count, table: count + len(table.pending), self.tables.values(), 0)

    def flush(self):
        for table in self.tables.values():
            if not table.pending:
                continue
            self.conn.execute(table.insert(), list(table.pending.values()))
            table.done = table.done.union(table.pending.keys())
            table.pending = dict()

    def create_subset_in(self, target_db):

        for (tbl_name, pks) in self.args.force_rows.items():
            if "." in tbl_name:
                (tbl_schema, tbl_name) = tbl_name.split(".", 1)
            else:
                tbl_schema = None
            source = self.tables[(tbl_schema, tbl_name)]
            for pk in pks:
                source_row = source.by_pk(pk)
                if source_row:
                    self.create_row_in(source_row, target_db, source.target, prioritized=True)
                else:
                    logging.warn("requested %s:%s not found in source db," "could not create" % (source.name, pk))

        while True:
            targets = sorted(target_db.tables.values(), key=lambda t: t.completeness_score())
            try:
                target = targets.pop(0)
                while not target.source.n_rows:
                    target = targets.pop(0)
            except IndexError:  # pop failure, no more tables
                break
            logging.debug("total n_rows in target: %d" % sum((t.n_rows for t in target_db.tables.values())))
            logging.debug(
                "target tables with 0 n_rows: %s" % ", ".join(t.name for t in target_db.tables.values() if not t.n_rows)
            )
            logging.info("lowest completeness score (in %s) at %f" % (target.name, target.completeness_score()))
            if target.completeness_score() > 0.97:
                break
            (source_row, prioritized) = target.source.next_row()
            self.create_row_in(source_row, target_db, target, prioritized=prioritized)

            if target_db.pending > self.args.buffer:
                target_db.flush()

        target_db.flush()
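The Db class above relies on several helpers that are not shown. A plausible sketch of _table_matches_any_pattern, assuming fnmatch-style patterns tested against both the bare and the schema-qualified table name (the real helper may differ):

import fnmatch

def _table_matches_any_pattern(schema, name, patterns):
    qualified = '%s.%s' % (schema, name) if schema else name
    return any(fnmatch.fnmatch(qualified, pattern)
               or fnmatch.fnmatch(name, pattern)
               for pattern in (patterns or []))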
Example #10
        def verify_thd(conn):
            metadata = sa.MetaData()
            metadata.bind = conn

            # Verify database contents.

            # 'workers' table contents.
            workers = sautils.Table('workers', metadata, autoload=True)
            c = workers.c
            q = sa.select(
                [c.id, c.name, c.info]
            ).order_by(c.id)
            self.assertEqual(
                q.execute().fetchall(), [
                    (30, 'worker-1', '{}'),
                    (31, 'worker-2', '{"a": 1}'),
                ])

            # 'builds' table contents.
            builds = sautils.Table('builds', metadata, autoload=True)
            c = builds.c
            q = sa.select(
                [c.id, c.number, c.builderid, c.buildrequestid, c.workerid,
                 c.masterid, c.started_at, c.complete_at, c.state_string,
                 c.results]
            ).order_by(c.id)
            self.assertEqual(
                q.execute().fetchall(), [
                    (40, 1, None, 20, 30, 10, 1000, None, 'state', None),
                    (41, 2, 50, 21, None, 11, 2000, 3000, 'state 2', 9),
                ])

            # 'configured_workers' table contents.
            configured_workers = sautils.Table(
                'configured_workers', metadata, autoload=True)
            c = configured_workers.c
            q = sa.select(
                [c.id, c.buildermasterid, c.workerid]
            ).order_by(c.id)
            self.assertEqual(
                q.execute().fetchall(), [
                    (60, 70, 30),
                    (61, 71, 31),
                ])

            # 'connected_workers' table contents.
            connected_workers = sautils.Table(
                'connected_workers', metadata, autoload=True)
            c = connected_workers.c
            q = sa.select(
                [c.id, c.masterid, c.workerid]
            ).order_by(c.id)
            self.assertEqual(
                q.execute().fetchall(), [
                    (80, 10, 30),
                    (81, 11, 31),
                ])

            # Verify that there is no "slave"-named items in schema.
            inspector = Inspector(conn)

            def check_name(name, table_name, item_type):
                if not name:
                    return
                self.assertTrue(
                    "slave" not in name.lower(),
                    msg="'slave'-named {type} in table '{table}': "
                        "'{name}'".format(
                        type=item_type, table=table_name,
                        name=name))

            # Check every table.
            for table_name in inspector.get_table_names():
                # Check table name.
                check_name(table_name, table_name, "table name")

                # Check column names.
                for column_info in inspector.get_columns(table_name):
                    check_name(column_info['name'], table_name, "column")

                # Check foreign key names.
                for fk_info in inspector.get_foreign_keys(table_name):
                    check_name(fk_info['name'], table_name, "foreign key")

                # Check indexes names.
                for index_info in inspector.get_indexes(table_name):
                    check_name(index_info['name'], table_name, "index")

                # Check primary keys constraints names.
                pk_info = inspector.get_pk_constraint(table_name)
                check_name(pk_info.get('name'), table_name, "primary key")

            # Test that no "slave"-named items present in schema
            for name in inspector.get_schema_names():
                self.assertTrue("slave" not in name.lower())
Example #11
class Db(object):
    def __init__(self, sqla_conn, args, schema=None):
        self.args = args
        self.sqla_conn = sqla_conn
        self.schema = schema
        self.engine = sa.create_engine(sqla_conn)
        self.meta = sa.MetaData(
            bind=self.engine)  # excised schema=schema to prevent errors
        self.meta.reflect(schema=self.schema)
        self.inspector = Inspector(bind=self.engine)
        self.conn = self.engine.connect()
        self.tables = OrderedDict()
        for tbl in self.meta.sorted_tables:
            tbl.db = self
            # TODO: Replace all these monkeypatches with an instance assignment
            tbl.find_n_rows = types.MethodType(_find_n_rows, tbl)
            tbl.random_row_func = types.MethodType(_random_row_func, tbl)
            tbl.fks = self.inspector.get_foreign_keys(tbl.name,
                                                      schema=tbl.schema)
            tbl.pk = self.inspector.get_primary_keys(tbl.name,
                                                     schema=tbl.schema)
            tbl.filtered_by = types.MethodType(_filtered_by, tbl)
            tbl.by_pk = types.MethodType(_by_pk, tbl)
            tbl.pk_val = types.MethodType(_pk_val, tbl)
            tbl.exists = types.MethodType(_exists, tbl)
            tbl.child_fks = []
            tbl.find_n_rows(estimate=True)
            self.tables[(tbl.schema, tbl.name)] = tbl
        for ((tbl_schema, tbl_name), tbl) in self.tables.items():
            for fk in tbl.fks:
                fk['constrained_schema'] = tbl_schema
                fk['constrained_table'] = tbl_name  # TODO: check against constrained_table
                self.tables[(fk['referred_schema'],
                             fk['referred_table'])].child_fks.append(fk)

    def __repr__(self):
        return "Db('%s')" % self.sqla_conn

    def assign_target(self, target_db):
        for ((tbl_schema, tbl_name), tbl) in self.tables.items():
            tbl._random_row_gen_fn = types.MethodType(_random_row_gen_fn, tbl)
            tbl.random_rows = tbl._random_row_gen_fn()
            tbl.next_row = types.MethodType(_next_row, tbl)
            target = target_db.tables[(tbl_schema, tbl_name)]
            target.requested = deque()
            target.required = deque()
            if tbl.n_rows:
                if self.args.logarithmic:
                    target.n_rows_desired = int(
                        math.pow(
                            10,
                            math.log10(tbl.n_rows) * self.args.fraction)) or 1
                else:
                    target.n_rows_desired = int(
                        tbl.n_rows * self.args.fraction) or 1
            else:
                target.n_rows_desired = 0
            target.source = tbl
            tbl.target = target
            target.completeness_score = types.MethodType(
                _completeness_score, target)
            logging.debug("assigned methods to %s" % target.name)

    def confirm(self):
        message = []
        for (tbl_schema, tbl_name) in sorted(self.tables, key=lambda t: t[1]):
            tbl = self.tables[(tbl_schema, tbl_name)]
            message.append("Create %d rows from %d in %s.%s" %
                           (tbl.target.n_rows_desired, tbl.n_rows, tbl_schema
                            or '', tbl_name))
        print("\n".join(sorted(message)))
        if self.args.yes:
            return True
        response = input("Proceed? (Y/n) ").strip().lower()
        return (not response) or (response[0] == 'y')

    def create_row_in(self, source_row, target_db, target, prioritized=False):
        logging.debug('create_row_in %s:%s ' %
                      (target.name, target.pk_val(source_row)))

        row_exists = target.exists(**(dict(source_row)))
        logging.debug("Row exists? %s" % str(row_exists))
        if row_exists and not prioritized:
            return

        if not row_exists:
            # make sure that all required rows are in parent table(s)
            for fk in target.fks:
                target_parent = target_db.tables[(fk['referred_schema'],
                                                  fk['referred_table'])]
                slct = sa.sql.select([
                    target_parent,
                ])
                any_non_null_key_columns = False
                for (parent_col, child_col) in zip(fk['referred_columns'],
                                                   fk['constrained_columns']):
                    slct = slct.where(
                        target_parent.c[parent_col] == source_row[child_col])
                    if source_row[child_col] is not None:
                        any_non_null_key_columns = True
                if any_non_null_key_columns:
                    target_parent_row = target_db.conn.execute(slct).first()
                    if not target_parent_row:
                        source_parent_row = self.conn.execute(slct).first()
                        self.create_row_in(source_parent_row, target_db,
                                           target_parent)

            ins = target.insert().values(**source_row)
            target_db.conn.execute(ins)
            target.n_rows += 1

        for child_fk in target.child_fks:
            child = self.tables[(child_fk['constrained_schema'],
                                 child_fk['constrained_table'])]
            slct = sa.sql.select([
                child,
            ])
            for (child_col, this_col) in zip(child_fk['constrained_columns'],
                                             child_fk['referred_columns']):
                slct = slct.where(child.c[child_col] == source_row[this_col])
            if not prioritized:
                slct = slct.limit(self.args.children)
            for (n, desired_row) in enumerate(self.conn.execute(slct)):
                if prioritized:
                    child.target.required.append((desired_row, prioritized))
                elif (n == 0):
                    child.target.requested.appendleft(
                        (desired_row, prioritized))
                else:
                    child.target.requested.append((desired_row, prioritized))

    def create_subset_in(self, target_db):

        for (tbl_name, pks) in self.args.force_rows.items():
            if '.' in tbl_name:
                (tbl_schema, tbl_name) = tbl_name.split('.', 1)
            else:
                tbl_schema = None
            source = self.tables[(tbl_schema, tbl_name)]
            for pk in pks:
                source_row = source.by_pk(pk)
                if source_row:
                    self.create_row_in(source_row,
                                       target_db,
                                       source.target,
                                       prioritized=True)
                else:
                    logging.warn("requested %s:%s not found in source db,"
                                 "could not create" % (source.name, pk))

        while True:
            targets = sorted(target_db.tables.values(),
                             key=lambda t: t.completeness_score())
            try:
                target = targets.pop(0)
                while not target.source.n_rows:
                    target = targets.pop(0)
            except IndexError:  # pop failure, no more tables
                return
            logging.debug("total n_rows in target: %d" % sum(
                (t.n_rows for t in target_db.tables.values())))
            logging.debug(
                "target tables with 0 n_rows: %s" %
                ", ".join(t.name
                          for t in target_db.tables.values() if not t.n_rows))
            logging.info("lowest completeness score (in %s) at %f" %
                         (target.name, target.completeness_score()))
            if target.completeness_score() > 0.97:
                return
            (source_row, prioritized) = target.source.next_row()
            self.create_row_in(source_row,
                               target_db,
                               target,
                               prioritized=prioritized)
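Unlike the pending/done bookkeeping in Examples #9 and #14, this variant checks row existence through a monkeypatched _exists helper. A hedged sketch of such a helper, matching on every supplied column value (the real implementation is not shown):

import sqlalchemy as sa

def _exists(self, **kw):
    # self is a reflected Table; count rows matching the given column values
    clauses = [self.c[key] == value
               for (key, value) in kw.items() if key in self.c]
    slct = sa.sql.select([sa.func.count()]).select_from(self)
    if clauses:
        slct = slct.where(sa.and_(*clauses))
    return bool(self.db.conn.execute(slct).scalar())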
Example #12
        def verify_thd(conn):
            metadata = sa.MetaData()
            metadata.bind = conn

            # Verify database contents.

            # 'workers' table contents.
            workers = sautils.Table('workers', metadata, autoload=True)
            c = workers.c
            q = sa.select([c.id, c.name, c.info]).order_by(c.id)
            self.assertEqual(q.execute().fetchall(), [
                (30, u'worker-1', u'{}'),
                (31, u'worker-2', u'{"a": 1}'),
            ])

            # 'builds' table contents.
            builds = sautils.Table('builds', metadata, autoload=True)
            c = builds.c
            q = sa.select([
                c.id, c.number, c.builderid, c.buildrequestid, c.workerid,
                c.masterid, c.started_at, c.complete_at, c.state_string,
                c.results
            ]).order_by(c.id)
            self.assertEqual(q.execute().fetchall(), [
                (40, 1, None, 20, 30, 10, 1000, None, u'state', None),
                (41, 2, 50, 21, None, 11, 2000, 3000, u'state 2', 9),
            ])

            # 'configured_workers' table contents.
            configured_workers = sautils.Table('configured_workers',
                                               metadata,
                                               autoload=True)
            c = configured_workers.c
            q = sa.select([c.id, c.buildermasterid, c.workerid]).order_by(c.id)
            self.assertEqual(q.execute().fetchall(), [
                (60, 70, 30),
                (61, 71, 31),
            ])

            # 'connected_workers' table contents.
            connected_workers = sautils.Table('connected_workers',
                                              metadata,
                                              autoload=True)
            c = connected_workers.c
            q = sa.select([c.id, c.masterid, c.workerid]).order_by(c.id)
            self.assertEqual(q.execute().fetchall(), [
                (80, 10, 30),
                (81, 11, 31),
            ])

            # Verify that there is no "slave"-named items in schema.
            inspector = Inspector(conn)

            def check_name(name, table_name, item_type):
                if not name:
                    return
                self.assertTrue(
                    u"slave" not in name.lower(),
                    msg=u"'slave'-named {type} in table '{table}': "
                    u"'{name}'".format(type=item_type,
                                       table=table_name,
                                       name=name))

            # Check every table.
            for table_name in inspector.get_table_names():
                # Check table name.
                check_name(table_name, table_name, u"table name")

                # Check column names.
                for column_info in inspector.get_columns(table_name):
                    check_name(column_info['name'], table_name, u"column")

                # Check foreign key names.
                for fk_info in inspector.get_foreign_keys(table_name):
                    check_name(fk_info['name'], table_name, u"foreign key")

                # Check indexes names.
                for index_info in inspector.get_indexes(table_name):
                    check_name(index_info['name'], table_name, u"index")

                # Check primary keys constraints names.
                pk_info = inspector.get_pk_constraint(table_name)
                check_name(pk_info.get('name'), table_name, u"primary key")

            # Test that no "slave"-named items present in schema
            for name in inspector.get_schema_names():
                self.assertTrue(u"slave" not in name.lower())
Example #13
    def loop_tables(
        self,
        inspector: Inspector,
        schema: str,
        sql_config: SQLAlchemyConfig,
    ) -> Iterable[SqlWorkUnit]:
        tables_seen: Set[str] = set()
        for table in inspector.get_table_names(schema):
            schema, table = self.standardize_schema_table_names(schema=schema,
                                                                entity=table)
            dataset_name = self.get_identifier(schema=schema,
                                               entity=table,
                                               inspector=inspector)
            if dataset_name not in tables_seen:
                tables_seen.add(dataset_name)
            else:
                logger.debug(
                    f"{dataset_name} has already been seen, skipping...")
                continue

            self.report.report_entity_scanned(dataset_name, ent_type="table")

            if not sql_config.table_pattern.allowed(dataset_name):
                self.report.report_dropped(dataset_name)
                continue

            columns = []
            try:
                columns = inspector.get_columns(table, schema)
                if len(columns) == 0:
                    self.report.report_warning(dataset_name,
                                               "missing column information")
            except Exception as e:
                self.report.report_warning(
                    dataset_name,
                    f"unable to get column information due to an error -> {e}",
                )

            try:
                # SQLAlchemy stubs are incomplete and missing this method.
                # PR: https://github.com/dropbox/sqlalchemy-stubs/pull/223.
                table_info: dict = inspector.get_table_comment(
                    table, schema)  # type: ignore
            except NotImplementedError:
                description: Optional[str] = None
                properties: Dict[str, str] = {}
            except ProgrammingError as pe:
                # Snowflake needs schema names quoted when fetching table comments.
                logger.debug(
                    f"Encountered ProgrammingError. Retrying with quoted "
                    f"schema name for schema {schema} and table {table}: {pe}"
                )
                table_info = inspector.get_table_comment(
                    table, f'"{schema}"')  # type: ignore
                description = table_info["text"]
                properties = table_info.get("properties", {})
            else:
                description = table_info["text"]

                # The "properties" field is a non-standard addition to SQLAlchemy's interface.
                properties = table_info.get("properties", {})

            dataset_urn = make_dataset_urn(self.platform, dataset_name,
                                           self.config.env)
            dataset_snapshot = DatasetSnapshot(
                urn=dataset_urn,
                aspects=[StatusClass(removed=False)],
            )
            if self.is_stateful_ingestion_configured():
                cur_checkpoint = self.get_current_checkpoint(
                    self.get_default_ingestion_job_id())
                if cur_checkpoint is not None:
                    checkpoint_state = cast(BaseSQLAlchemyCheckpointState,
                                            cur_checkpoint.state)
                    checkpoint_state.add_table_urn(dataset_urn)

            if description is not None or properties:
                dataset_properties = DatasetPropertiesClass(
                    description=description,
                    customProperties=properties,
                )
                dataset_snapshot.aspects.append(dataset_properties)

            pk_constraints: dict = inspector.get_pk_constraint(table, schema)
            try:
                foreign_keys = [
                    self.get_foreign_key_metadata(dataset_urn, schema, fk_rec,
                                                  inspector)
                    for fk_rec in inspector.get_foreign_keys(table, schema)
                ]
            except KeyError:
                # certain databases like MySQL cause issues due to lower-case/upper-case irregularities
                logger.debug(
                    f"{dataset_urn}: failure in foreign key extraction... skipping"
                )
                foreign_keys = []

            schema_fields = self.get_schema_fields(dataset_name, columns,
                                                   pk_constraints)

            schema_metadata = get_schema_metadata(
                self.report,
                dataset_name,
                self.platform,
                columns,
                pk_constraints,
                foreign_keys,
                schema_fields,
            )
            dataset_snapshot.aspects.append(schema_metadata)

            mce = MetadataChangeEvent(proposedSnapshot=dataset_snapshot)
            wu = SqlWorkUnit(id=dataset_name, mce=mce)
            self.report.report_workunit(wu)
            yield wu
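standardize_schema_table_names is assumed by the loop above. A hedged sketch of its likely behavior, splitting entity names that some dialects return already qualified as "schema.table" (the real method may validate more strictly):

from typing import Tuple

def standardize_schema_table_names(self, schema: str,
                                   entity: str) -> Tuple[str, str]:
    if '.' in entity:
        # some dialects qualify table names with their schema; split it back
        schema, entity = entity.split('.', 1)
    return schema, entity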
Example #14
class Db(object):
    def __init__(self, sqla_conn, args, schemas=None):
        if schemas is None:
            schemas = [None]
        self.args = args
        self.sqla_conn = sqla_conn
        self.schemas = schemas
        self.engine = sa.create_engine(sqla_conn)
        self.inspector = Inspector(bind=self.engine)
        self.conn = self.engine.connect()
        self.tables = OrderedDict()

        for schema in self.schemas:
            meta = sa.MetaData(
                bind=self.engine)  # excised schema=schema to prevent errors
            meta.reflect(schema=schema)
            for tbl in meta.sorted_tables:
                if args.tables and not _table_matches_any_pattern(
                        tbl.schema, tbl.name, self.args.tables):
                    continue
                if _table_matches_any_pattern(tbl.schema, tbl.name,
                                              self.args.exclude_tables):
                    continue
                tbl.db = self

                if self.engine.name == 'postgresql':
                    fix_postgres_array_of_enum(self.conn, tbl)

                # TODO: Replace all these monkeypatches with an instance assignment
                tbl.find_n_rows = types.MethodType(_find_n_rows, tbl)
                tbl.random_row_func = types.MethodType(_random_row_func, tbl)
                tbl.fks = self.inspector.get_foreign_keys(tbl.name,
                                                          schema=tbl.schema)
                tbl.pk = self.inspector.get_primary_keys(tbl.name,
                                                         schema=tbl.schema)
                if not tbl.pk:
                    tbl.pk = [
                        d['name']
                        for d in self.inspector.get_columns(tbl.name,
                                                            schema=tbl.schema)
                    ]
                tbl.filtered_by = types.MethodType(_filtered_by, tbl)
                tbl.by_pk = types.MethodType(_by_pk, tbl)
                tbl.pk_val = types.MethodType(_pk_val, tbl)
                tbl.child_fks = []
                estimate_rows = not _table_matches_any_pattern(
                    tbl.schema, tbl.name, self.args.full_tables)
                tbl.find_n_rows(estimate=estimate_rows)
                self.tables[(tbl.schema, tbl.name)] = tbl
        all_constraints = args.config.get('constraints', {})
        for ((tbl_schema, tbl_name), tbl) in self.tables.items():
            qualified = "{}.{}".format(tbl_schema, tbl_name)
            if qualified in all_constraints:
                constraints = all_constraints[qualified]
            else:
                constraints = all_constraints.get(tbl_name, [])
            tbl.constraints = constraints
            for fk in (tbl.fks + constraints):
                fk['constrained_schema'] = tbl_schema
                fk['constrained_table'] = tbl_name  # TODO: check against constrained_table

                self.tables[(fk['referred_schema'],
                             fk['referred_table'])].child_fks.append(fk)

    def __repr__(self):
        return "Db('%s')" % self.sqla_conn

    def assign_target(self, target_db):
        for ((tbl_schema, tbl_name), tbl) in self.tables.items():
            tbl._random_row_gen_fn = types.MethodType(_random_row_gen_fn, tbl)
            tbl.random_rows = tbl._random_row_gen_fn()
            tbl.next_row = types.MethodType(_next_row, tbl)
            target = target_db.tables[(tbl_schema, tbl_name)]
            target.requested = deque()
            target.required = deque()
            target.pending = dict()
            target.done = set()
            target.fetch_all = False
            if _table_matches_any_pattern(tbl.schema, tbl.name,
                                          self.args.full_tables):
                target.n_rows_desired = tbl.n_rows
                target.fetch_all = True
            else:
                if tbl.n_rows:
                    if self.args.logarithmic:
                        target.n_rows_desired = int(
                            math.pow(
                                10,
                                math.log10(tbl.n_rows) *
                                self.args.fraction)) or 1
                    else:
                        target.n_rows_desired = int(
                            tbl.n_rows * self.args.fraction) or 1
                else:
                    target.n_rows_desired = 0
            target.source = tbl
            tbl.target = target
            target.completeness_score = types.MethodType(
                _completeness_score, target)
            logging.debug("assigned methods to %s" % target.name)

    def confirm(self):
        message = []
        for (tbl_schema, tbl_name) in sorted(self.tables, key=lambda t: t[1]):
            tbl = self.tables[(tbl_schema, tbl_name)]
            message.append("Create %d rows from %d in %s.%s" %
                           (tbl.target.n_rows_desired, tbl.n_rows, tbl_schema
                            or '', tbl_name))
        print("\n".join(sorted(message)))
        if self.args.yes:
            return True
        response = input("Proceed? (Y/n) ").strip().lower()
        return (not response) or (response[0] == 'y')

    def create_row_in(self, source_row, target_db, target, prioritized=False):
        logging.debug('create_row_in %s:%s ' %
                      (target.name, target.pk_val(source_row)))

        pks = hashable((source_row[key] for key in target.pk))
        row_exists = pks in target.pending or pks in target.done
        logging.debug("Row exists? %s" % str(row_exists))
        if row_exists and not prioritized:
            return

        if not row_exists:
            # make sure that all required rows are in parent table(s)
            for fk in target.fks:
                target_parent = target_db.tables[(fk['referred_schema'],
                                                  fk['referred_table'])]
                slct = sa.sql.select([
                    target_parent,
                ])
                any_non_null_key_columns = False
                for (parent_col, child_col) in zip(fk['referred_columns'],
                                                   fk['constrained_columns']):
                    slct = slct.where(
                        target_parent.c[parent_col] == source_row[child_col])
                    if source_row[child_col] is not None:
                        any_non_null_key_columns = True
                        break
                if any_non_null_key_columns:
                    target_parent_row = target_db.conn.execute(slct).first()
                    if not target_parent_row:
                        source_parent_row = self.conn.execute(slct).first()
                        self.create_row_in(source_parent_row, target_db,
                                           target_parent)

            pks = hashable((source_row[key] for key in target.pk))
            target.n_rows += 1

            # insert source row here to prevent recursion
            if self.args.buffer == 0:
                target_db.insert_one(target, pks, source_row)
            else:
                target.pending[pks] = source_row

            # make sure that all referenced rows are in referenced table(s)
            for constraint in target.constraints:
                target_referred = target_db.tables[(
                    constraint['referred_schema'],
                    constraint['referred_table'])]
                slct = sa.sql.select([
                    target_referred,
                ])
                any_non_null_key_columns = False
                for (referred_col, constrained_col) in zip(
                        constraint['referred_columns'],
                        constraint['constrained_columns']):
                    slct = slct.where(target_referred.c[referred_col] ==
                                      source_row[constrained_col])
                    if source_row[constrained_col] is not None:
                        any_non_null_key_columns = True
                        break
                if any_non_null_key_columns:
                    target_referred_row = target_db.conn.execute(slct).first()
                    if not target_referred_row:
                        source_referred_row = self.conn.execute(slct).first()
                        # because constraints aren't enforced like real FKs, the referred row isn't guaranteed to exist
                        if source_referred_row:
                            self.create_row_in(source_referred_row, target_db,
                                               target_referred)

            # defer signal?
            signal(SIGNAL_ROW_ADDED).send(self,
                                          source_row=source_row,
                                          target_db=target_db,
                                          target_table=target,
                                          prioritized=prioritized)

        for child_fk in target.child_fks:
            child = self.tables[(child_fk['constrained_schema'],
                                 child_fk['constrained_table'])]
            slct = sa.sql.select([child])
            for (child_col, this_col) in zip(child_fk['constrained_columns'],
                                             child_fk['referred_columns']):
                slct = slct.where(child.c[child_col] == source_row[this_col])
            if not prioritized:
                slct = slct.limit(self.args.children)
            for (n, desired_row) in enumerate(self.conn.execute(slct)):
                if prioritized:
                    child.target.required.append((desired_row, prioritized))
                elif n == 0:
                    child.target.requested.appendleft(
                        (desired_row, prioritized))
                else:
                    child.target.requested.append((desired_row, prioritized))

    @property
    def pending(self):
        return functools.reduce(
            lambda count, table: count + len(table.pending),
            self.tables.values(), 0)

    def insert_one(self, table, pk, values):
        self.conn.execute(table.insert(), values)
        table.done.add(pk)

    def flush(self):
        for table in self.tables.values():
            if not table.pending:
                continue
            self.conn.execute(table.insert(), list(table.pending.values()))
            table.done = table.done.union(table.pending.keys())
            table.pending = dict()

    def create_subset_in(self, target_db):

        for (tbl_name, pks) in self.args.force_rows.items():
            if '.' in tbl_name:
                (tbl_schema, tbl_name) = tbl_name.split('.', 1)
            else:
                tbl_schema = None
            source = self.tables[(tbl_schema, tbl_name)]
            for pk in pks:
                source_row = source.by_pk(pk)
                if source_row:
                    self.create_row_in(source_row,
                                       target_db,
                                       source.target,
                                       prioritized=True)
                else:
                    logging.warn("requested %s:%s not found in source db,"
                                 "could not create" % (source.name, pk))

        while True:
            targets = sorted(target_db.tables.values(),
                             key=lambda t: t.completeness_score())
            try:
                target = targets.pop(0)
                while not target.source.n_rows:
                    target = targets.pop(0)
            except IndexError:  # pop failure, no more tables
                break
            logging.debug("total n_rows in target: %d" % sum(
                (t.n_rows for t in target_db.tables.values())))
            logging.debug(
                "target tables with 0 n_rows: %s" %
                ", ".join(t.name
                          for t in target_db.tables.values() if not t.n_rows))
            logging.info("lowest completeness score (in %s) at %f" %
                         (target.name, target.completeness_score()))
            if target.completeness_score() > 0.97:
                break
            (source_row, prioritized) = target.source.next_row()
            self.create_row_in(source_row,
                               target_db,
                               target,
                               prioritized=prioritized)

            if target_db.pending > self.args.buffer > 0:
                target_db.flush()

        if self.args.buffer > 0:
            target_db.flush()
Example #15
    def loop_tables(
        self,
        inspector: Inspector,
        schema: str,
        sql_config: SQLAlchemyConfig,
    ) -> Iterable[SqlWorkUnit]:
        for table in inspector.get_table_names(schema):
            schema, table = self.standardize_schema_table_names(schema=schema,
                                                                entity=table)
            dataset_name = self.get_identifier(schema=schema,
                                               entity=table,
                                               inspector=inspector)
            self.report.report_entity_scanned(dataset_name, ent_type="table")

            if not sql_config.table_pattern.allowed(dataset_name):
                self.report.report_dropped(dataset_name)
                continue

            columns = inspector.get_columns(table, schema)
            if len(columns) == 0:
                self.report.report_warning(dataset_name,
                                           "missing column information")

            try:
                # SQLAlchemy stubs are incomplete and missing this method.
                # PR: https://github.com/dropbox/sqlalchemy-stubs/pull/223.
                table_info: dict = inspector.get_table_comment(
                    table, schema)  # type: ignore
            except NotImplementedError:
                description: Optional[str] = None
                properties: Dict[str, str] = {}
            else:
                description = table_info["text"]

                # The "properties" field is a non-standard addition to SQLAlchemy's interface.
                properties = table_info.get("properties", {})

            datasetUrn = f"urn:li:dataset:(urn:li:dataPlatform:{self.platform},{dataset_name},{self.config.env})"
            dataset_snapshot = DatasetSnapshot(
                urn=datasetUrn,
                aspects=[],
            )

            if description is not None or properties:
                dataset_properties = DatasetPropertiesClass(
                    description=description,
                    customProperties=properties,
                )
                dataset_snapshot.aspects.append(dataset_properties)

            pk_constraints: dict = inspector.get_pk_constraint(table, schema)
            try:
                foreign_keys = [
                    self.get_foreign_key_metadata(datasetUrn, fk_rec,
                                                  inspector)
                    for fk_rec in inspector.get_foreign_keys(table, schema)
                ]
            except KeyError:
                # certain databases like MySQL cause issues due to lower-case/upper-case irregularities
                logger.debug(
                    f"{datasetUrn}: failure in foreign key extraction... skipping"
                )
                foreign_keys = []

            schema_metadata = get_schema_metadata(
                self.report,
                dataset_name,
                self.platform,
                columns,
                pk_constraints,
                foreign_keys,
            )
            dataset_snapshot.aspects.append(schema_metadata)

            mce = MetadataChangeEvent(proposedSnapshot=dataset_snapshot)
            wu = SqlWorkUnit(id=dataset_name, mce=mce)
            self.report.report_workunit(wu)
            yield wu
Example #16
class Db(object):

    def __init__(self, sqla_conn, args, schema=None):
        self.args = args
        self.sqla_conn = sqla_conn
        self.schema = schema
        self.engine = sa.create_engine(sqla_conn)
        self.meta = sa.MetaData(bind=self.engine) # excised schema=schema to prevent errors
        self.meta.reflect(schema=self.schema)
        self.inspector = Inspector(bind=self.engine)
        self.conn = self.engine.connect()
        self.tables = OrderedDict()
        for tbl in self.meta.sorted_tables:
            tbl.db = self
            # TODO: Replace all these monkeypatches with an instance assignment
            tbl.find_n_rows = types.MethodType(_find_n_rows, tbl)
            tbl.random_row_func = types.MethodType(_random_row_func, tbl)
            tbl.fks = self.inspector.get_foreign_keys(tbl.name, schema=tbl.schema)
            tbl.pk = self.inspector.get_primary_keys(tbl.name, schema=tbl.schema)
            tbl.filtered_by = types.MethodType(_filtered_by, tbl)
            tbl.by_pk = types.MethodType(_by_pk, tbl)
            tbl.pk_val = types.MethodType(_pk_val, tbl)
            tbl.exists = types.MethodType(_exists, tbl)
            tbl.child_fks = []
            tbl.find_n_rows(estimate=True)
            self.tables[(tbl.schema, tbl.name)] = tbl
        for ((tbl_schema, tbl_name), tbl) in self.tables.items():
            for fk in tbl.fks:
                fk['constrained_schema'] = tbl_schema
                fk['constrained_table'] = tbl_name  # TODO: check against constrained_table
                self.tables[(fk['referred_schema'], fk['referred_table'])].child_fks.append(fk)

    def __repr__(self):
        return "Db('%s')" % self.sqla_conn

    def assign_target(self, target_db):
        for ((tbl_schema, tbl_name), tbl) in self.tables.items():
            tbl._random_row_gen_fn = types.MethodType(_random_row_gen_fn, tbl)
            tbl.random_rows = tbl._random_row_gen_fn()
            tbl.next_row = types.MethodType(_next_row, tbl)
            target = target_db.tables[(tbl_schema, tbl_name)]
            target.requested = deque()
            target.required = deque()
            if tbl.n_rows:
                if self.args.logarithmic:
                    target.n_rows_desired = int(math.pow(10, math.log10(tbl.n_rows)
                                                * self.args.fraction)) or 1
                else:
                    target.n_rows_desired = int(tbl.n_rows * self.args.fraction) or 1
            else:
                target.n_rows_desired = 0
            target.source = tbl
            tbl.target = target
            target.completeness_score = types.MethodType(_completeness_score, target)
            logging.debug("assigned methods to %s" % target.name)

    def confirm(self):
        message = []
        for (tbl_schema, tbl_name) in sorted(self.tables, key=lambda t: t[1]):
            tbl = self.tables[(tbl_schema, tbl_name)]
            message.append("Create %d rows from %d in %s.%s" %
                           (tbl.target.n_rows_desired, tbl.n_rows,
                            tbl_schema or '', tbl_name))
        print("\n".join(sorted(message)))
        if self.args.yes:
            return True
        response = input("Proceed? (Y/n) ").strip().lower()
        return (not response) or (response[0] == 'y')


    def create_row_in(self, source_row, target_db, target, prioritized=False):
        logging.debug('create_row_in %s:%s ' %
                      (target.name, target.pk_val(source_row)))

        row_exists = target.exists(**(dict(source_row)))
        logging.debug("Row exists? %s" % str(row_exists))
        if row_exists and not prioritized:
            return

        if not row_exists:
            # make sure that all required rows are in parent table(s)
            for fk in target.fks:
                target_parent = target_db.tables[(fk['referred_schema'], fk['referred_table'])]
                slct = sa.sql.select([target_parent,])
                any_non_null_key_columns = False
                for (parent_col, child_col) in zip(fk['referred_columns'],
                                                   fk['constrained_columns']):
                    slct = slct.where(target_parent.c[parent_col] ==
                                      source_row[child_col])
                    if source_row[child_col] is not None:
                        any_non_null_key_columns = True
                if any_non_null_key_columns:
                    target_parent_row = target_db.conn.execute(slct).first()
                    if not target_parent_row:
                        source_parent_row = self.conn.execute(slct).first()
                        self.create_row_in(source_parent_row, target_db, target_parent)

            ins = target.insert().values(**source_row)
            target_db.conn.execute(ins)
            target.n_rows += 1

        for child_fk in target.child_fks:
            child = self.tables[(child_fk['constrained_schema'], child_fk['constrained_table'])]
            slct = sa.sql.select([child,])
            for (child_col, this_col) in zip(child_fk['constrained_columns'],
                                             child_fk['referred_columns']):
                slct = slct.where(child.c[child_col] == source_row[this_col])
            if not prioritized:
                slct = slct.limit(self.args.children)
            for (n, desired_row) in enumerate(self.conn.execute(slct)):
                if prioritized:
                    child.target.required.append((desired_row, prioritized))
                elif (n == 0):
                    child.target.requested.appendleft((desired_row, prioritized))
                else:
                    child.target.requested.append((desired_row, prioritized))

    def create_subset_in(self, target_db):

        for (tbl_name, pks) in self.args.force_rows.items():
            if '.' in tbl_name:
                (tbl_schema, tbl_name) = tbl_name.split('.', 1)
            else:
                tbl_schema = None
            source = self.tables[(tbl_schema, tbl_name)]
            for pk in pks:
                source_row = source.by_pk(pk)
                if source_row:
                    self.create_row_in(source_row, target_db,
                                       source.target, prioritized=True)
                else:
                    logging.warn("requested %s:%s not found in source db,"
                                 "could not create" % (source.name, pk))


        while True:
            targets = sorted(target_db.tables.values(),
                             key=lambda t: t.completeness_score())
            try:
                target = targets.pop(0)
                while not target.source.n_rows:
                    target = targets.pop(0)
            except IndexError: # pop failure, no more tables
                return
            logging.debug("total n_rows in target: %d" %
                          sum((t.n_rows for t in target_db.tables.values())))
            logging.debug("target tables with 0 n_rows: %s" %
                          ", ".join(t.name for t in target_db.tables.values()
                                    if not t.n_rows))
            logging.info("lowest completeness score (in %s) at %f" %
                         (target.name, target.completeness_score()))
            if target.completeness_score() > 0.97:
                return
            (source_row, prioritized) = target.source.next_row()
            self.create_row_in(source_row, target_db, target,
                               prioritized=prioritized)
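A closing compatibility note: the Db examples (#9, #11, #14, #16) call Inspector.get_primary_keys(), which SQLAlchemy deprecated in favor of get_pk_constraint() and later removed. An equivalent lookup on modern versions:

from sqlalchemy import inspect

def primary_key_columns(engine, table, schema=None):
    # get_pk_constraint() returns {'constrained_columns': [...], 'name': ...}
    return inspect(engine).get_pk_constraint(
        table, schema=schema)['constrained_columns']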