def setUp(self):
    """
    Create the ``tmp`` and ``hello`` schemas and seed two booking tables.

    ``tmp.raw_booking`` (left) receives four rows while ``hello.booking``
    (right) receives three of them, so the two tables intentionally differ.
    """
    self.left_table_name = "raw_booking"
    self.right_table_name = "booking"
    self.ts_nodash = (
        FakedDatetime.now().isoformat().replace("-", "").replace(":", "")
    )
    self.now = FakedDatetime.now()

    schema_statements = [
        "DROP SCHEMA if exists tmp CASCADE;",
        "CREATE SCHEMA IF NOT EXISTS tmp;",
        "CREATE SCHEMA IF NOT EXISTS hello;",
    ]
    table_statements = [
        f"""
        CREATE TABLE IF NOT EXISTS tmp.{self.left_table_name}(
            id SERIAL PRIMARY KEY,
            src text,
            dst text,
            price int,
            turnover_after_refunds double precision,
            initial_price double precision,
            created_at timestamptz
        )
        """,
        f"""
        CREATE TABLE IF NOT EXISTS hello.{self.right_table_name}(
            id SERIAL PRIMARY KEY,
            src text,
            dst text,
            price int,
            turnover_after_refunds double precision,
            initial_price double precision,
            created_at timestamptz
        )
        """,
    ]
    data_statements = [
        f"""
        INSERT INTO tmp.{self.left_table_name}
            (src, dst, price, turnover_after_refunds, initial_price, created_at)
        VALUES
            ('BTS', NULL, 1, 100, 11, '2018-09-12T11:50:00'),
            (NULL, 'PEK', 33, 1.1, 13, '2018-01-12T15:50:00'),
            ('VIE', 'JFK', 4, 5.5, 23.4, '2018-09-11T11:50:00'),
            ('VIE', 'VIE', 4, 0.0, 0.0, '2018-09-11T11:50:00')
        """,
        f"""
        INSERT INTO hello.{self.right_table_name}
            (src, dst, price, turnover_after_refunds, initial_price, created_at)
        VALUES
            ('BTS', NULL, 1, 100, 11, '2018-09-12T11:50:00'),
            (NULL, 'PEK', 33, 1.1, 13, '2018-01-12T15:50:00'),
            ('VIE', 'JFK', 4, 5.5, 23.4, '2018-09-11T11:50:00')
        """,
    ]

    self.conn = Connector(TEST_DB_URI)
    for statement in schema_statements + table_statements + data_statements:
        self.conn.execute(statement)
    self.consistency_checker = ConsistencyChecker(TEST_DB_URI)
def setUp(self):
    """
    Recreate ``DATA_QUALITY_SCHEMA`` with an Alembic version table.

    Drops and recreates the schema, creates the two quality-check tables and
    creates ``MIGRATION_TABLE`` holding the single applied revision
    ``'0.1.4-hash'``.
    """
    sql = [
        # Keep DROP and CREATE as separate list elements (a missing comma
        # here would silently concatenate the two strings into one).
        f"DROP SCHEMA IF EXISTS {DATA_QUALITY_SCHEMA} CASCADE;",
        f"CREATE SCHEMA IF NOT EXISTS {DATA_QUALITY_SCHEMA};",
        get_quality_table_creation_script(DATA_QUALITY_SCHEMA, DATA_QUALITY_TABLE_1),
        get_quality_table_creation_script(DATA_QUALITY_SCHEMA, DATA_QUALITY_TABLE_2),
        f"""
        create table {DATA_QUALITY_SCHEMA}.{MIGRATION_TABLE}
        (
            version_num varchar(32) not null
                constraint alembic_version_pkc
                    primary key
        );
        INSERT INTO {DATA_QUALITY_SCHEMA}.{MIGRATION_TABLE} (version_num)
        VALUES ('0.1.4-hash');
        """,
    ]
    self.conn = Connector(TEST_DB_URI)
    for s in sql:
        self.conn.execute(s)
def test_execute_consistency_diff(self):
    """
    DIFF check on two tables whose columns are declared in a different order.

    Both tables hold the same single row, so the stored result must be
    ``valid``.
    """
    self.inconsistent_colums_left = "user"
    self.inconsistent_colums_right = "user_inconsistent"
    self.conn = Connector(TEST_DB_URI)
    self.consistency_checker = ConsistencyChecker(TEST_DB_URI)

    left_name = self.inconsistent_colums_left
    right_name = self.inconsistent_colums_right
    statements = (
        f"""
        CREATE TABLE IF NOT EXISTS tmp.{left_name}(
            id SERIAL PRIMARY KEY,
            name text
        )
        """,
        f"""
        CREATE TABLE IF NOT EXISTS tmp.{right_name}(
            name text,
            id SERIAL PRIMARY KEY
        )
        """,
        f"""
        INSERT INTO tmp.{left_name}
        VALUES (1, 'John Doe')
        """,
        f"""
        INSERT INTO tmp.{right_name}
        VALUES ('John Doe', 1)
        """,
    )
    for statement in statements:
        self.conn.execute(statement)

    checker = self.consistency_checker
    checker.run(
        checker.DIFF,
        left_check_table={"schema_name": "tmp", "table_name": left_name},
        right_check_table={"schema_name": "tmp", "table_name": right_name},
        result_table={
            "schema_name": "data_quality",
            "table_name": self.result_table_name,
        },
        context={"task_ts": self.now},
    )

    records = self.conn.get_records(
        f"""
        SELECT * from data_quality.consistency_check_{self.result_table_name}
        order by created_at
        """
    )
    first_record = records.fetchone()
    self.assertEqual(first_record["status"], "valid")
def __init__(self, left_conn_uri_or_engine, right_conn_uri_or_engine=None):
    """
    Set up connectors for both sides of the comparison.

    :param left_conn_uri_or_engine: URI or engine for the left database.
    :param right_conn_uri_or_engine: URI or engine for the right database.
        When ``None``, the right side reuses the left URI/engine and the very
        same ``Connector`` instance.
    """
    self.left_conn_uri_or_engine = left_conn_uri_or_engine
    self.left_conn = Connector(left_conn_uri_or_engine)

    reuse_left = right_conn_uri_or_engine is None
    self.right_conn_uri_or_engine = (
        left_conn_uri_or_engine if reuse_left else right_conn_uri_or_engine
    )
    self.right_conn = (
        self.left_conn if reuse_left else Connector(right_conn_uri_or_engine)
    )
def __init__(self, migrations_map, package_version, url, schema):
    """
    :param migrations_map: mapping of package versions to their migrations,
        e.g. ``{'0.1.4': 'A', '0.1.5': 'B'}``
    :param package_version: the package version that is planned to be migrated to
    :param url: URL of the database where the Alembic migration table lives
        (or is planned to be created)
    :param schema: schema where the Alembic migration table lives
        (or is planned to be created)
    """
    self.versions_migrations = migrations_map
    self.package_version = package_version
    self.url, self.schema = url, schema
    # Connector bound to the migration database URL.
    self.conn = Connector(url)
def setUp(self):
    """
    Recreate ``DATA_QUALITY_SCHEMA`` with the two quality-check tables.

    No Alembic migration table is created.
    """
    sql = [
        # Keep DROP and CREATE as separate list elements (a missing comma
        # here would silently concatenate the two strings into one).
        f"DROP SCHEMA IF EXISTS {DATA_QUALITY_SCHEMA} CASCADE;",
        f"CREATE SCHEMA IF NOT EXISTS {DATA_QUALITY_SCHEMA};",
        get_quality_table_creation_script(DATA_QUALITY_SCHEMA, DATA_QUALITY_TABLE_1),
        get_quality_table_creation_script(DATA_QUALITY_SCHEMA, DATA_QUALITY_TABLE_2),
    ]
    self.conn = Connector(TEST_DB_URI)
    for s in sql:
        self.conn.execute(s)
def test_set_medians(conn: Connector, monkeypatch):
    """
    ``set_medians`` aggregates only the recent rows.

    The row dated 2018-07-12 is outside the window and must be ignored.
    """
    DQBase.metadata.clear()
    result_table = ResultTable(schema_name="data_quality", table_name="t")
    check_cls = create_default_quality_check_class(result_table)
    check_cls.__table__.create(conn.engine)
    check = check_cls()

    conn.execute("""
        insert into data_quality.quality_check_t(failed, passed, task_ts)
        values
            (10, 200, '2018-09-11T13:00:00'),
            (3, 22, '2018-09-10T13:00:00'),
            (11, 110, '2018-09-09T13:00:00'),
            (55, 476, '2018-09-08T13:00:00'),
            (77, 309, '2018-07-12T13:00:00') -- should not be taken
    """)
    monkeypatch.setattr("contessa.models.datetime", FakedDatetime)

    check.set_medians(conn)

    assert check.median_30_day_failed == 10.5
    assert check.median_30_day_passed == 155
def test_set_medians(conn: Connector, monkeypatch):
    """
    ``set_medians`` aggregates only the recent rows.

    The row dated 2018-07-12 is outside the window and must be ignored.
    """
    DQBase.metadata.clear()
    result_table = ResultTable(
        schema_name="data_quality", table_name="t", model_cls=QualityCheck
    )
    check_cls = create_default_check_class(result_table)
    check_cls.__table__.create(conn.engine)
    check = check_cls()

    conn.execute("""
        insert into data_quality.quality_check_t(attribute, rule_name, rule_type, failed, passed, task_ts, time_filter)
        values
            ('a', 'b', 'not_null', 10, 200, '2018-09-11T13:00:00', 'not_set'),
            ('a', 'b', 'not_null', 3, 22, '2018-09-10T13:00:00', 'not_set'),
            ('a', 'b', 'not_null', 11, 110, '2018-09-09T13:00:00', 'not_set'),
            ('a', 'b', 'not_null', 55, 476, '2018-09-08T13:00:00', 'not_set'),
            ('a', 'b', 'not_null', 77, 309, '2018-07-12T13:00:00', 'not_set') -- should not be taken
    """)
    monkeypatch.setattr("contessa.models.datetime", FakedDatetime)

    check.set_medians(conn)

    assert check.median_30_day_failed == 10.5
    assert check.median_30_day_passed == 155
def apply(self, conn: Connector):
    """
    Execute ``self.sql_with_where`` and turn its rows into a boolean series.

    Every row must consist of exactly one value that is a ``bool`` or ``None``.

    :param conn: connector used to run the query.
    :return: pd.Series of bools, with ``None`` mapped to ``False``.
    :raises ValueError: if any row is not a single bool/None value.
    """
    # Materialize the records so they can be validated and then converted.
    rows = list(conn.get_records(self.sql_with_where))
    accepted_types = (bool, type(None))
    for row in rows:
        if len(row) != 1 or not isinstance(row[0], accepted_types):
            raise ValueError(
                f"Your query for rule `{self.name}` does not return list of booleans or Nones."
            )
    return pd.Series([bool(row[0]) for row in rows])
def test_upsert_qualitycheck(conn: Connector):
    """
    Upserting the same object twice must update the row, not duplicate it.

    The table has a unique constraint on ``name``; the second upsert changes
    ``price`` and the table must still hold exactly one row.
    """
    from sqlalchemy import Column, DateTime, text, UniqueConstraint
    from sqlalchemy.dialects.postgresql import TEXT, INTEGER, BIGINT

    class A(DQBase):
        id = Column(BIGINT, primary_key=True)
        name = Column(TEXT, nullable=False)
        price = Column(INTEGER)
        created_at = Column(
            DateTime(timezone=True),
            server_default=text("NOW()"),
            nullable=False,
            index=True,
        )

        __tablename__ = "my_table"
        __table_args__ = (UniqueConstraint("name", name="unique_constraint_test"),)

    conn.ensure_table(A.__table__)

    instance = A(name="hello", price=13)
    conn.upsert(objs=[instance])

    s = conn.make_session()

    def fetch_rows():
        """Read all rows and end the transaction so the next read is fresh."""
        rows = s.query(A.__table__).all()
        s.expunge_all()
        s.commit()
        return rows

    # Close the session even when an assertion fails.
    try:
        # check if inserted
        row = fetch_rows()
        assert len(row) == 1
        assert row[0].price == 13

        # change data and insert again - should upsert
        instance.price = 42
        conn.upsert(objs=[instance])

        row = fetch_rows()
        assert len(row) == 1
        assert row[0].price == 42
    finally:
        s.close()
def conn():
    """
    Yield a ``Connector`` to the test database with scratch schemas present.

    The ``tmp``, ``temporary``, ``data_quality`` and ``raw`` schemas are
    created before yielding and dropped with CASCADE afterwards.
    """
    connector = Connector(TEST_DB_URI)
    schema_names = ("tmp", "temporary", "data_quality", "raw")
    for schema_name in schema_names:
        connector.execute(f"create schema if not exists {schema_name}")
    yield connector
    for schema_name in schema_names:
        connector.execute(f"drop schema if exists {schema_name} cascade")
def set_medians(self, conn: Connector, days=30):
    """
    Calculate the medians of ``failed`` and ``passed`` over recent checks.

    Rows of this model's table whose ``task_ts`` lies between ``days`` days
    ago and today are used. Both bounds are passed as date strings
    ('YYYY-MM-DD'); presumably the database treats them as midnight — verify.
    The medians are stored on ``median_30_day_failed`` and
    ``median_30_day_passed`` (``None`` when no rows match). The attribute
    names keep the ``30_day`` suffix even when another ``days`` is given.

    :param conn: connector used to open the ORM session.
    :param days: length of the look-back window in days.
    """
    now = datetime.today().date()
    past = now - timedelta(days=days)
    cls = self.__class__

    session = conn.make_session()
    try:
        checks = (
            session.query(cls.failed, cls.passed)
            .filter(and_(cls.task_ts <= str(now), cls.task_ts >= str(past)))
            .all()
        )
        session.expunge_all()
        session.commit()
    finally:
        # Release the session even if the query or commit fails.
        session.close()

    failed = [ch.failed for ch in checks]
    self.median_30_day_failed = median(failed) if failed else None
    passed = [ch.passed for ch in checks]
    self.median_30_day_passed = median(passed) if passed else None
class TestReturnResults(unittest.TestCase):
    """ConsistencyChecker.run without a result table returns a result object."""

    def setUp(self):
        """
        Create the ``tmp`` and ``hello`` schemas and seed two booking tables.

        ``tmp.raw_booking`` (left) receives four rows while ``hello.booking``
        (right) receives three of them, so the two tables intentionally differ.
        """
        self.left_table_name = "raw_booking"
        self.right_table_name = "booking"
        self.ts_nodash = (
            FakedDatetime.now().isoformat().replace("-", "").replace(":", "")
        )
        self.now = FakedDatetime.now()

        schema_statements = [
            "DROP SCHEMA if exists tmp CASCADE;",
            "CREATE SCHEMA IF NOT EXISTS tmp;",
            "CREATE SCHEMA IF NOT EXISTS hello;",
        ]
        table_statements = [
            f"""
            CREATE TABLE IF NOT EXISTS tmp.{self.left_table_name}(
                id SERIAL PRIMARY KEY,
                src text,
                dst text,
                price int,
                turnover_after_refunds double precision,
                initial_price double precision,
                created_at timestamptz
            )
            """,
            f"""
            CREATE TABLE IF NOT EXISTS hello.{self.right_table_name}(
                id SERIAL PRIMARY KEY,
                src text,
                dst text,
                price int,
                turnover_after_refunds double precision,
                initial_price double precision,
                created_at timestamptz
            )
            """,
        ]
        data_statements = [
            f"""
            INSERT INTO tmp.{self.left_table_name}
                (src, dst, price, turnover_after_refunds, initial_price, created_at)
            VALUES
                ('BTS', NULL, 1, 100, 11, '2018-09-12T11:50:00'),
                (NULL, 'PEK', 33, 1.1, 13, '2018-01-12T15:50:00'),
                ('VIE', 'JFK', 4, 5.5, 23.4, '2018-09-11T11:50:00'),
                ('VIE', 'VIE', 4, 0.0, 0.0, '2018-09-11T11:50:00')
            """,
            f"""
            INSERT INTO hello.{self.right_table_name}
                (src, dst, price, turnover_after_refunds, initial_price, created_at)
            VALUES
                ('BTS', NULL, 1, 100, 11, '2018-09-12T11:50:00'),
                (NULL, 'PEK', 33, 1.1, 13, '2018-01-12T15:50:00'),
                ('VIE', 'JFK', 4, 5.5, 23.4, '2018-09-11T11:50:00')
            """,
        ]

        self.conn = Connector(TEST_DB_URI)
        for statement in schema_statements + table_statements + data_statements:
            self.conn.execute(statement)
        self.consistency_checker = ConsistencyChecker(TEST_DB_URI)

    def tearDown(self):
        """Drop all created schemas and reset the shared SQLAlchemy metadata."""
        for schema_name in ("tmp", "hello"):
            self.conn.execute(f"DROP schema {schema_name} CASCADE;")
        DQBase.metadata.clear()

    @mock.patch("contessa.executor.datetime", FakedDatetime)
    def test_execute_consistency(self):
        """A COUNT comparison of 4 vs 3 rows is reported as invalid."""
        checker = self.consistency_checker
        outcome = checker.run(
            checker.COUNT,
            left_check_table={
                "schema_name": "tmp",
                "table_name": self.left_table_name,
            },
            right_check_table={
                "schema_name": "hello",
                "table_name": self.right_table_name,
            },
            context={"task_ts": self.now},
        )

        self.assertEqual(outcome.status, "invalid")
        self.assertEqual(outcome.context["left_table_name"], "tmp.raw_booking")
        self.assertEqual(outcome.context["right_table_name"], "hello.booking")
class TestMigrationsResolverInit(unittest.TestCase):
    """MigrationsResolver behaviour when no Alembic version table exists yet."""

    def setUp(self):
        """Recreate ``DATA_QUALITY_SCHEMA`` without a migration table."""
        sql = [
            # Separate list elements: a missing comma would concatenate them.
            f"DROP SCHEMA IF EXISTS {DATA_QUALITY_SCHEMA} CASCADE;",
            f"CREATE SCHEMA IF NOT EXISTS {DATA_QUALITY_SCHEMA};",
            get_quality_table_creation_script(DATA_QUALITY_SCHEMA, DATA_QUALITY_TABLE_1),
            get_quality_table_creation_script(DATA_QUALITY_SCHEMA, DATA_QUALITY_TABLE_2),
        ]
        self.conn = Connector(TEST_DB_URI)
        for s in sql:
            self.conn.execute(s)

    def tearDown(self):
        """Drop all created tables."""
        self.conn.execute(f"DROP schema {DATA_QUALITY_SCHEMA} CASCADE;")
        DQBase.metadata.clear()

    @staticmethod
    def _resolver(versions_migrations, package_version):
        """Build a resolver pointed at the test schema."""
        return MigrationsResolver(
            versions_migrations, package_version, SQLALCHEMY_URL, DATA_QUALITY_SCHEMA
        )

    def test_migration_table_exists_init(self):
        versions_migrations = {"0.1.4": "0.1.4-hash", "0.1.5": "0.1.5-hash"}
        m = self._resolver(versions_migrations, "0.1.4")
        migration_table_exists = m.migrations_table_exists()
        assert migration_table_exists is False

    def test_get_current_migration_init(self):
        versions_migrations = {"0.1.4": "0.1.4-hash", "0.1.5": "0.1.5-hash"}
        m = self._resolver(versions_migrations, "0.1.4")
        current = m.get_applied_migration()
        assert current is None

    def test_is_on_head_init(self):
        versions_migrations = {"0.1.4": "0.1.4-hash", "0.1.5": "0.1.5-hash"}
        m = self._resolver(versions_migrations, "0.1.4")
        is_on_head = m.is_on_head()
        # Previously a bare expression without ``assert``, so nothing was checked.
        assert is_on_head is False

    def test_get_migrations_to_head__is_before_init(self):
        versions_migrations = {
            "0.1.2": "0.1.2-hash",
            "0.1.3": "0.1.3-hash",
            "0.1.4": "0.1.4-hash",
            "0.1.5": "0.1.5-hash",
            "0.1.6": "0.1.6-hash",
            "0.1.7": "0.1.7-hash",
        }
        m = self._resolver(versions_migrations, "0.1.7")
        migrations = m.get_migration_to_head()
        # Compare strings by value; ``is`` checks identity, not equality.
        assert migrations[0] == "upgrade"
        assert migrations[1] == "0.1.7-hash"
class ContessaRunner:
    """Run data-quality rules against a table and optionally persist the results."""

    model_cls = QualityCheck

    def __init__(self, conn_uri_or_engine, special_qc_map=None):
        """
        :param conn_uri_or_engine: URI or engine passed to ``Connector``; used
            both for the checked table and for the result table.
        :param special_qc_map: optional ``{result_table_fullname: class}``
            overrides consulted by ``get_quality_check_class``.
        """
        self.conn_uri_or_engine = conn_uri_or_engine
        self.conn = Connector(conn_uri_or_engine)
        # todo - allow cfg
        self.special_qc_map = special_qc_map or {}

    def run(
        self,
        raw_rules: List[Dict[str, str]],
        check_table: Dict,
        result_table: Optional[
            Dict] = None,  # todo - docs for quality name, maybe defaults..
        context: Optional[Dict] = None,
    ) -> List[Union[CheckResult, QualityCheck]]:
        """
        Execute ``raw_rules`` against ``check_table``.

        :param raw_rules: rule definitions; normalized and then instantiated.
        :param check_table: keyword arguments for ``Table``.
        :param result_table: keyword arguments for ``ResultTable``. When given,
            ``ensure_table`` is called for the quality-check table and the
            results are upserted into it.
        :param context: values overriding the default executor context.
        :return: a list with one result object per rule, in rule order. The
            objects are instances of the quality-check class when
            ``result_table`` is given, otherwise ``CheckResult``.
        """
        check_table = Table(**check_table)
        context = self.get_context(check_table, context)

        normalized_rules = self.normalize_rules(raw_rules)
        refresh_executors(check_table, self.conn, context)

        if result_table:
            result_table = ResultTable(**result_table, model_cls=self.model_cls)
            quality_check_class = self.get_quality_check_class(result_table)
            self.conn.ensure_table(quality_check_class.__table__)
        else:
            quality_check_class = CheckResult

        rules = self.build_rules(normalized_rules)
        objs = self.do_quality_checks(quality_check_class, rules, context)

        if result_table:
            self.conn.upsert(objs)
        return objs

    @staticmethod
    def get_context(check_table: Table, context: Optional[Dict] = None) -> Dict:
        """
        Construct context to pass to executors. User context overrides defaults.
        """
        ctx_defaults = {
            "table_fullname": check_table.fullname,
            "task_ts": datetime.now(),  # todo - is now() ok ?
        }
        if context:
            ctx_defaults.update(context)
        return ctx_defaults

    def normalize_rules(self, raw_rules):
        """Normalize raw user rule definitions via ``RuleNormalizer``."""
        return RuleNormalizer.normalize(raw_rules)

    def do_quality_checks(
        self, dq_cls, rules: List[Rule], context: Optional[Dict] = None
    ):
        """
        Run quality check for all rules. Use `dq_cls` to construct the objects
        that will be inserted afterwards.

        :return: list of ``dq_cls`` instances, one per rule, in rule order.
        """
        return [self.apply_rule(context, dq_cls, rule) for rule in rules]

    def apply_rule(self, context, dq_cls, rule):
        """Execute one rule and wrap its results in a new ``dq_cls`` row."""
        e = get_executor(rule)
        logging.info(f"Executing rule `{rule}`.")
        results = e.execute(rule)
        obj = dq_cls()
        obj.init_row(rule, results, self.conn, context)
        return obj

    @staticmethod
    def build_rules(normalized_rules):
        """
        Construct rules classes from user definition that are dicts.
        Raises if there are bad arguments for a certain rule.
        :return: list of Rule objects
        """
        ret = []
        for rule_def in normalized_rules:
            rule_cls = ContessaRunner.pick_rule_cls(rule_def)
            try:
                r = rule_cls(**rule_def)
            except Exception as e:
                # ``e.args`` may be empty; indexing it blindly would raise an
                # IndexError here and hide the original exception.
                detail = e.args[0] if e.args else repr(e)
                logging.error(f"For rule `{rule_cls.__name__}`. {detail}")
                raise
            ret.append(r)
        return ret

    @staticmethod
    def pick_rule_cls(rule_def):
        """
        Get rule class based on its type that was input by user.
        :param rule_def: dict
        :return: Rule class
        """
        return get_rule_cls(rule_def["type"])

    def get_quality_check_class(self, result_table: ResultTable):
        """
        QualityCheck can be different, e.g. `special_table` has specific quality_check.
        Or kind of generic one that computes number of passed/failed objects etc.
        So determine if is special or not and return the class.
        :return: QualityCheck cls
        """
        if result_table.fullname in self.special_qc_map:
            quality_check_class = self.special_qc_map[result_table.fullname]
            logging.info(
                f"Using {quality_check_class.__name__} as quality check class."
            )
        else:
            quality_check_class = create_default_check_class(result_table)
            logging.info("Using default QualityCheck class.")
        return quality_check_class
class TestMigrationsResolver(unittest.TestCase):
    """MigrationsResolver behaviour when revision '0.1.4-hash' is applied."""

    def setUp(self):
        """
        Recreate ``DATA_QUALITY_SCHEMA`` with an Alembic version table.

        The version table holds the single applied revision ``'0.1.4-hash'``.
        """
        sql = [
            # Separate list elements: a missing comma would concatenate them.
            f"DROP SCHEMA IF EXISTS {DATA_QUALITY_SCHEMA} CASCADE;",
            f"CREATE SCHEMA IF NOT EXISTS {DATA_QUALITY_SCHEMA};",
            get_quality_table_creation_script(DATA_QUALITY_SCHEMA, DATA_QUALITY_TABLE_1),
            get_quality_table_creation_script(DATA_QUALITY_SCHEMA, DATA_QUALITY_TABLE_2),
            f"""
            create table {DATA_QUALITY_SCHEMA}.{MIGRATION_TABLE}
            (
                version_num varchar(32) not null
                    constraint alembic_version_pkc
                        primary key
            );
            INSERT INTO {DATA_QUALITY_SCHEMA}.{MIGRATION_TABLE} (version_num)
            VALUES ('0.1.4-hash');
            """,
        ]
        self.conn = Connector(TEST_DB_URI)
        for s in sql:
            self.conn.execute(s)

    def tearDown(self):
        """
        Reset the shared SQLAlchemy metadata.

        The schema is intentionally not dropped here; ``setUp`` drops and
        recreates it before every test.
        """
        DQBase.metadata.clear()

    @staticmethod
    def _resolver(versions_migrations, package_version, schema=None):
        """Build a resolver; defaults to the test schema."""
        return MigrationsResolver(
            versions_migrations,
            package_version,
            SQLALCHEMY_URL,
            DATA_QUALITY_SCHEMA if schema is None else schema,
        )

    def test_schema_exists(self):
        versions_migrations = {"0.1.4": "0.1.4-hash", "0.1.5": "0.1.5-hash"}
        m = self._resolver(versions_migrations, "0.1.4")
        schema_exists = m.schema_exists()
        assert schema_exists

        m = self._resolver(versions_migrations, "0.1.4", "not_existing_schema")
        schema_exists = m.schema_exists()
        assert schema_exists is False

    def test_migration_table_exists(self):
        versions_migrations = {"0.1.4": "0.1.4-hash", "0.1.5": "0.1.5-hash"}
        m = self._resolver(versions_migrations, "0.1.4")
        migration_table_exists = m.migrations_table_exists()
        assert migration_table_exists

    def test_get_current_migration(self):
        versions_migrations = {"0.1.4": "0.1.4-hash", "0.1.5": "0.1.5-hash"}
        m = self._resolver(versions_migrations, "0.1.4")
        current = m.get_applied_migration()
        assert current == "0.1.4-hash"

    def test_is_on_head(self):
        versions_migrations = {"0.1.4": "0.1.4-hash", "0.1.5": "0.1.5-hash"}
        m = self._resolver(versions_migrations, "0.1.4")
        is_on_head = m.is_on_head()
        assert is_on_head

    def test_is_on_head_no_on_head(self):
        versions_migrations = {"0.1.4": "0.1.4-hash", "0.1.5": "0.1.5-hash"}
        m = self._resolver(versions_migrations, "0.1.5")
        is_on_head = m.is_on_head()
        assert is_on_head is False

    def test_is_on_head_with_fallback(self):
        versions_migrations = {"0.1.4": "0.1.4-hash", "0.1.6": "0.1.6-hash"}
        m = self._resolver(versions_migrations, "0.1.5")
        is_on_head = m.is_on_head()
        assert is_on_head

    def test_get_migrations_to_head__already_on_head(self):
        versions_migrations = {"0.1.4": "0.1.4-hash", "0.1.5": "0.1.5-hash"}
        m = self._resolver(versions_migrations, "0.1.4")
        migrations = m.get_migration_to_head()
        assert migrations is None

    def test_get_migrations_to_head__package_greather_than_map_max(self):
        versions_migrations = {"0.1.4": "0.1.4-hash", "0.1.5": "0.1.5-hash"}
        m = self._resolver(versions_migrations, "0.1.6")
        migrations = m.get_migration_to_head()
        # Compare strings by value; ``is`` checks identity, not equality.
        assert migrations[0] == "upgrade"
        assert migrations[1] == "0.1.5-hash"

    def test_get_migrations_to_head__is_down_from_head(self):
        versions_migrations = {
            "0.1.2": "0.1.2-hash",
            "0.1.3": "0.1.3-hash",
            "0.1.4": "0.1.4-hash",
            "0.1.5": "0.1.5-hash",
            "0.1.6": "0.1.6-hash",
            "0.1.7": "0.1.7-hash",
        }
        m = self._resolver(versions_migrations, "0.1.7")
        migrations = m.get_migration_to_head()
        assert migrations[0] == "upgrade"
        assert migrations[1] == "0.1.7-hash"

    def test_get_migrations_to_head__is_down_from_head_with_fallback(self):
        versions_migrations = {
            "0.1.2": "0.1.2-hash",
            "0.1.3": "0.1.3-hash",
            "0.1.4": "0.1.4-hash",
            "0.1.5": "0.1.5-hash",
            "0.1.8": "0.1.8-hash",
            "0.1.9": "0.1.9-hash",
        }
        m = self._resolver(versions_migrations, "0.1.7")
        migrations = m.get_migration_to_head()
        assert migrations[0] == "upgrade"
        assert migrations[1] == "0.1.5-hash"

    def test_get_migrations_to_head__is_up_from_head(self):
        versions_migrations = {
            "0.1.2": "0.1.2-hash",
            "0.1.3": "0.1.3-hash",
            "0.1.4": "0.1.4-hash",
            "0.1.5": "0.1.5-hash",
            "0.1.6": "0.1.6-hash",
            "0.1.7": "0.1.7-hash",
        }
        m = self._resolver(versions_migrations, "0.1.2")
        migrations = m.get_migration_to_head()
        assert migrations[0] == "downgrade"
        assert migrations[1] == "0.1.2-hash"

    def test_get_migrations_to_head__is_up_from_head_with_fallback(self):
        versions_migrations = {
            "0.1.1": "0.1.1-hash",
            "0.1.3": "0.1.3-hash",
            "0.1.4": "0.1.4-hash",
            "0.1.5": "0.1.5-hash",
            "0.1.6": "0.1.6-hash",
            "0.1.7": "0.1.7-hash",
        }
        m = self._resolver(versions_migrations, "0.1.2")
        migrations = m.get_migration_to_head()
        assert migrations[0] == "downgrade"
        assert migrations[1] == "0.1.1-hash"
def setUp(self):
    """
    Create fresh ``tmp``/``data_quality``/``hello`` schemas and seed data.

    ``tmp.<table_name>`` and ``tmp.<table_name>_<ts_nodash>`` share the same
    layout and each receive four booking rows (one dated 2018-01-12).
    """
    self.table_name = "booking_all_v2"
    self.ts_nodash = (
        FakedDatetime.now().isoformat().replace("-", "").replace(":", "")
    )
    self.tmp_table_name = f"{self.table_name}_{self.ts_nodash}"
    self.now = FakedDatetime.now()

    schema_statements = [
        "DROP SCHEMA if exists tmp CASCADE;",
        "DROP SCHEMA if exists data_quality CASCADE;",
        "CREATE SCHEMA IF NOT EXISTS tmp;",
        "CREATE SCHEMA IF NOT EXISTS data_quality;",
        "CREATE SCHEMA IF NOT EXISTS hello;",
    ]
    # Both tables share the same layout: the plain one first, then the tmp one.
    table_statements = [
        f"""
        CREATE TABLE IF NOT EXISTS tmp.{name}(
            id SERIAL PRIMARY KEY,
            src text,
            dst text,
            price int,
            turnover_after_refunds double precision,
            initial_price double precision,
            created_at timestamptz
        )
        """
        for name in (self.table_name, self.tmp_table_name)
    ]
    data_statements = [
        f"""
        INSERT INTO tmp.{self.tmp_table_name}
            (src, dst, price, turnover_after_refunds, initial_price, created_at)
        VALUES
            ('BTS', NULL, 1, 100, 11, '2018-09-12T11:50:00'),
            -- this is older than 30 days.
            -- not in stats when time_filter = `created_at`
            (NULL, 'PEK', 33, 1.1, 13, '2018-01-12T15:50:00'),
            ('VIE', 'JFK', 4, 5.5, 23.4, '2018-09-11T11:50:00'),
            ('VIE', 'VIE', 4, 0.0, 0.0, '2018-09-11T11:50:00')
        """,
        f"""
        INSERT INTO tmp.{self.table_name}
            (src, dst, price, turnover_after_refunds, initial_price, created_at)
        VALUES
            ('BTS', NULL, 1, 100, 11, '2018-09-12T13:00:00'),
            -- this is older than 30 days.
            -- not in stats when time_filter = `created_at`
            (NULL, 'PEK', 33, 1.1, 13, '2018-01-12T13:00:00'),
            ('VIE', 'JFK', 4, 5.5, 23.4, '2018-09-11T13:00:00'),
            ('VIE', 'VIE', 4, 0.0, 0.0, '2018-09-11T13:00:00')
        """,
    ]

    self.conn = Connector(TEST_DB_URI)
    for statement in schema_statements + table_statements + data_statements:
        self.conn.execute(statement)
    self.contessa_runner = ContessaRunner(TEST_DB_URI)
class TestDataQualityOperator(unittest.TestCase):
    """End-to-end runs of ContessaRunner that store results in Postgres."""

    def setUp(self):
        """
        Create fresh ``tmp``/``data_quality``/``hello`` schemas and seed data.

        ``tmp.<table_name>`` and ``tmp.<table_name>_<ts_nodash>`` share the
        same layout and each receive four booking rows (one dated 2018-01-12).
        """
        self.table_name = "booking_all_v2"
        self.ts_nodash = (
            FakedDatetime.now().isoformat().replace("-", "").replace(":", "")
        )
        self.tmp_table_name = f"{self.table_name}_{self.ts_nodash}"
        self.now = FakedDatetime.now()

        schema_statements = [
            "DROP SCHEMA if exists tmp CASCADE;",
            "DROP SCHEMA if exists data_quality CASCADE;",
            "CREATE SCHEMA IF NOT EXISTS tmp;",
            "CREATE SCHEMA IF NOT EXISTS data_quality;",
            "CREATE SCHEMA IF NOT EXISTS hello;",
        ]
        # Both tables share the same layout: the plain one first, then the tmp one.
        table_statements = [
            f"""
            CREATE TABLE IF NOT EXISTS tmp.{name}(
                id SERIAL PRIMARY KEY,
                src text,
                dst text,
                price int,
                turnover_after_refunds double precision,
                initial_price double precision,
                created_at timestamptz
            )
            """
            for name in (self.table_name, self.tmp_table_name)
        ]
        data_statements = [
            f"""
            INSERT INTO tmp.{self.tmp_table_name}
                (src, dst, price, turnover_after_refunds, initial_price, created_at)
            VALUES
                ('BTS', NULL, 1, 100, 11, '2018-09-12T11:50:00'),
                -- this is older than 30 days.
                -- not in stats when time_filter = `created_at`
                (NULL, 'PEK', 33, 1.1, 13, '2018-01-12T15:50:00'),
                ('VIE', 'JFK', 4, 5.5, 23.4, '2018-09-11T11:50:00'),
                ('VIE', 'VIE', 4, 0.0, 0.0, '2018-09-11T11:50:00')
            """,
            f"""
            INSERT INTO tmp.{self.table_name}
                (src, dst, price, turnover_after_refunds, initial_price, created_at)
            VALUES
                ('BTS', NULL, 1, 100, 11, '2018-09-12T13:00:00'),
                -- this is older than 30 days.
                -- not in stats when time_filter = `created_at`
                (NULL, 'PEK', 33, 1.1, 13, '2018-01-12T13:00:00'),
                ('VIE', 'JFK', 4, 5.5, 23.4, '2018-09-11T13:00:00'),
                ('VIE', 'VIE', 4, 0.0, 0.0, '2018-09-11T13:00:00')
            """,
        ]

        self.conn = Connector(TEST_DB_URI)
        for statement in schema_statements + table_statements + data_statements:
            self.conn.execute(statement)
        self.contessa_runner = ContessaRunner(TEST_DB_URI)

    def tearDown(self):
        """Drop all created schemas and reset the shared SQLAlchemy metadata."""
        for schema_name in ("tmp", "data_quality", "hello"):
            self.conn.execute(f"DROP schema {schema_name} CASCADE;")
        DQBase.metadata.clear()

    def _stored_results(self):
        """Return all rows of the data_quality result table, oldest first."""
        return self.conn.get_records(f"""
            SELECT * from data_quality.quality_check_{self.table_name}
            order by created_at
        """).fetchall()

    @mock.patch("contessa.executor.datetime", FakedDatetime)
    def test_execute_tmp(self):
        sql = """
            SELECT
              CASE WHEN src = 'BTS' and dst is null THEN false ELSE true END as res
            from {{ table_fullname }}
        """
        rules = [
            {
                "name": "not_null_name",
                "type": "not_null",
                "column": "dst",
                "time_filter": "created_at",
            },
            {
                "name": "gt_name",
                "type": "gt",
                "column": "price",
                "value": 10,
                "time_filter": "created_at",
            },
            {
                "name": "sql_name",
                "type": "sql",
                "sql": sql,
                "column": "src_dst",
                "description": "test sql rule",
            },
            {
                "name": "not_name",
                "type": "not",
                "column": "src",
                "value": "dst",
            },
        ]
        self.contessa_runner.run(
            check_table={"schema_name": "tmp", "table_name": self.tmp_table_name},
            result_table={"schema_name": "data_quality", "table_name": self.table_name},
            raw_rules=rules,
            context={"task_ts": self.now},
        )

        rows = self._stored_results()
        self.assertEqual(len(rows), 4)

        # (failed, passed, attribute) per rule, in rule order.
        expectations = (
            (1, 2, "dst"),
            (3, 0, "price"),
            (1, 3, "src_dst"),
            (1, 3, "src"),
        )
        for row, (failed, passed, attribute) in zip(rows, expectations):
            self.assertEqual(row["failed"], failed)
            self.assertEqual(row["passed"], passed)
            self.assertEqual(row["attribute"], attribute)
            if attribute == "dst":
                self.assertEqual(row["task_ts"].timestamp(), self.now.timestamp())

    @mock.patch("contessa.executor.datetime", FakedDatetime)
    def test_execute_dst(self):
        sql = """
            SELECT
              CASE WHEN src = 'BTS' AND dst is null THEN false ELSE true END as res
            FROM {{ table_fullname }}
            WHERE created_at BETWEEN
                timestamptz '{{task_ts}}' - INTERVAL '1 hour'
                AND
                timestamptz '{{task_ts}}' + INTERVAL '1 hour'
        """
        rules = [
            {
                "name": "not_null_name",
                "type": "not_null",
                "column": "dst",
                "time_filter": "created_at",
            },
            {
                "name": "sql_name",
                "type": "sql",
                "sql": sql,
                "column": "src_dst",
                "description": "test sql rule",
            },
        ]
        self.contessa_runner.run(
            check_table={"schema_name": "tmp", "table_name": self.tmp_table_name},
            result_table={"schema_name": "data_quality", "table_name": self.table_name},
            raw_rules=rules,
            context={"task_ts": self.now},
        )

        rows = self._stored_results()
        self.assertEqual(len(rows), 2)
        not_null_row, sql_row = rows

        self.assertEqual(not_null_row["failed"], 1)
        self.assertEqual(not_null_row["passed"], 2)
        self.assertEqual(not_null_row["attribute"], "dst")
        self.assertEqual(not_null_row["task_ts"].timestamp(), self.now.timestamp())

        self.assertEqual(sql_row["failed"], 1)
        self.assertEqual(sql_row["passed"], 0)
        self.assertEqual(sql_row["attribute"], "src_dst")

    def test_different_schema(self):
        rules = [
            {
                "name": "not_nul_name",
                "type": "not_null",
                "column": "dst",
                "time_filter": "created_at",
            }
        ]
        self.contessa_runner.run(
            check_table={"schema_name": "tmp", "table_name": self.tmp_table_name},
            result_table={"schema_name": "hello", "table_name": "abcde"},
            raw_rules=rules,
            context={"task_ts": self.now},
        )
        stored = self.conn.get_records("""
            SELECT 1 from hello.quality_check_abcde
        """).fetchall()
        self.assertEqual(len(stored), 1)
class ConsistencyChecker:
    """
    Checks consistency of the sync between two tables.
    """

    model_cls = ConsistencyCheck

    COUNT = "count"
    DIFF = "difference"

    def __init__(self, left_conn_uri_or_engine, right_conn_uri_or_engine=None):
        """
        :param left_conn_uri_or_engine: URI or engine for the left database.
        :param right_conn_uri_or_engine: URI or engine for the right database;
            when ``None`` the left URI/engine and ``Connector`` are reused.
        """
        self.left_conn_uri_or_engine = left_conn_uri_or_engine
        self.left_conn = Connector(left_conn_uri_or_engine)
        if right_conn_uri_or_engine is None:
            self.right_conn_uri_or_engine = self.left_conn_uri_or_engine
            self.right_conn = self.left_conn
        else:
            self.right_conn_uri_or_engine = right_conn_uri_or_engine
            self.right_conn = Connector(right_conn_uri_or_engine)

    def run(
        self,
        method: str,
        left_check_table: Dict,
        right_check_table: Dict,
        result_table: Optional[Dict] = None,
        columns: Optional[List[str]] = None,
        time_filter: str = TIME_FILTER_DEFAULT,
        left_custom_sql: Optional[str] = None,
        right_custom_sql: Optional[str] = None,
        context: Optional[Dict] = None,
    ) -> Union[CheckResult, ConsistencyCheck]:
        """
        Compare two tables using ``method`` (``COUNT`` or ``DIFF``).

        :param method: ``self.COUNT`` or ``self.DIFF``.
        :param left_check_table: keyword arguments for the left ``Table``.
        :param right_check_table: keyword arguments for the right ``Table``.
        :param result_table: keyword arguments for ``ResultTable``; when given
            the result is upserted into it via the right connector.
        :param columns: columns to count/compare instead of the defaults.
        :param time_filter: passed to ``compose_where_time_filter``.
        :param left_custom_sql: Jinja SQL used instead of the default left query.
        :param right_custom_sql: Jinja SQL used instead of the default right query.
        :param context: values overriding the default query context.
        :return: with ``result_table``, the result dict that was upserted;
            otherwise a ``CheckResult`` initialised from that dict.
        :raises ValueError: if both custom SQLs are given together with
            ``columns`` or a truthy ``time_filter``.
        :raises NotImplementedError: for an unknown ``method``.
        """
        if left_custom_sql and right_custom_sql:
            # NOTE(review): ``time_filter`` defaults to TIME_FILTER_DEFAULT; if
            # that constant is truthy, custom SQLs can only be used when the
            # caller passes a falsy ``time_filter`` explicitly — confirm intent.
            if columns or time_filter:
                raise ValueError(
                    "When using custom sqls you cannot change 'columns' or 'time_filter' attribute"
                )

        left_check_table = Table(**left_check_table)
        right_check_table = Table(**right_check_table)
        context = self.get_context(left_check_table, right_check_table, context)

        result = self.do_consistency_check(
            method,
            columns,
            time_filter,
            left_check_table,
            right_check_table,
            left_custom_sql,
            right_custom_sql,
            context,
        )

        if result_table:
            result_table = ResultTable(**result_table, model_cls=self.model_cls)
            quality_check_class = create_default_check_class(result_table)
            self.right_conn.ensure_table(quality_check_class.__table__)
            self.upsert(quality_check_class, result)
            # NOTE(review): the annotation suggests a ConsistencyCheck object,
            # but the raw dict is returned here — kept for compatibility.
            return result

        obj = CheckResult()
        obj.init_row_consistency(**result)
        return obj

    @staticmethod
    def get_context(
        left_check_table: Table,
        right_check_table: Table,
        context: Optional[Dict] = None,
    ) -> Dict:
        """
        Construct context to pass to executors. User context overrides defaults.
        """
        ctx_defaults = {
            "left_table_fullname": left_check_table.fullname,
            "right_table_fullname": right_check_table.fullname,
            "task_ts": datetime.now(),
        }
        if context:
            ctx_defaults.update(context)
        return ctx_defaults

    def do_consistency_check(
        self,
        method: str,
        columns: Optional[List[str]],
        time_filter: str,
        left_check_table: Table,
        right_check_table: Table,
        left_sql: Optional[str] = None,
        right_sql: Optional[str] = None,
        context: Optional[Dict] = None,
    ):
        """
        Query both tables and compare the results with ``method``.

        A side without custom SQL gets a default ``SELECT``: for ``COUNT`` it
        selects ``count(<columns>)`` (or ``count(*)``); for ``DIFF`` it selects
        ``columns`` or, if none are given, all columns of the right table sorted
        by name, so a different column order in the two tables does not matter.

        :return: dict passed as keyword arguments to the result model's
            ``init_row`` / ``CheckResult.init_row_consistency``.
        :raises NotImplementedError: for an unknown ``method``.
        """
        if not left_sql or not right_sql:
            if method == self.COUNT:
                if columns:
                    column = f"count({', '.join(columns)})"
                else:
                    column = "count(*)"
            elif method == self.DIFF:
                if columns:
                    column = ", ".join(columns)
                else:
                    # List the columns explicitly in case column order of compared tables is not the same.
                    column = ", ".join(
                        sorted(
                            self.right_conn.get_column_names(
                                right_check_table.fullname)))
            else:
                raise NotImplementedError(f"Method {method} not implemented")

        if not left_sql:
            left_sql = self.construct_default_query(
                left_check_table.fullname, column, time_filter, context)
        left_result = self.run_query(self.left_conn, left_sql, context)
        if not right_sql:
            right_sql = self.construct_default_query(
                right_check_table.fullname, column, time_filter, context)
        right_result = self.run_query(self.right_conn, right_sql, context)
        valid, passed, failed = self.compare_results(left_result, right_result,
                                                     method)

        return {
            "check": {
                "type": method,
                "description": "",
                "name": "consistency",
                "passed": passed,
                "failed": failed,
            },
            "status": "valid" if valid else "invalid",
            "left_table_name": left_check_table.fullname,
            "right_table_name": right_check_table.fullname,
            "time_filter": time_filter,
            "context": context,
        }

    def compare_results(self, left_result, right_result, method):
        """
        Compare the query results of both sides.

        :return: ``(valid, passed, failed)``. For ``COUNT`` the first column of
            the first row of each side is the count; ``passed`` is the smaller
            count and ``failed`` the absolute difference. For ``DIFF`` rows are
            compared as sets (duplicates collapse); ``passed`` is the number of
            common rows and ``failed`` the number of rows present on one side only.
        :raises NotImplementedError: for an unknown ``method``.
        """
        if method == self.COUNT:
            left_count = left_result[0][0]
            right_count = right_result[0][0]
            passed = min(left_count, right_count)
            # Sum the surplus of both sides (one term is always 0) so ``failed``
            # is never negative when the right table has more rows.
            failed = (left_count - passed) + (right_count - passed)
            return failed == 0, passed, failed
        elif method == self.DIFF:
            left_set = set(left_result)
            right_set = set(right_result)
            common = left_set.intersection(right_set)
            passed = len(common)
            failed = (len(left_set) - len(common)) + (len(right_set) - len(common))
            return failed == 0, passed, failed
        else:
            raise NotImplementedError(f"Method {method} not implemented")

    def construct_default_query(self, table_name, column, time_filter, context):
        """Build ``SELECT <column> FROM <table>`` with an optional time WHERE clause."""
        time_filter = compose_where_time_filter(time_filter, context["task_ts"])
        query = f"""
            SELECT {column}
            FROM {table_name}
            {f'WHERE {time_filter}' if time_filter else ''}
        """
        return query

    def render_sql(self, sql, context):
        """
        Replace some parameters in query.
        :return str, formatted sql
        """
        t = jinja2.Template(sql)
        rendered = t.render(**context)
        return rendered

    def run_query(self, conn: Connector, query: str, context):
        """Render ``query`` with ``context``, execute it and return the raw rows."""
        query = self.render_sql(query, context)
        logging.info(query)
        # NOTE(review): ``_row`` is a private attribute of the SQLAlchemy row
        # object; it is relied on here so the rows are hashable for set() in DIFF.
        result = [r._row for r in conn.get_records(query)]
        return result

    def upsert(self, dc_cls, result):
        """Build a ``dc_cls`` row from ``result`` and upsert it via the right connector."""
        obj = dc_cls()
        obj.init_row(**result)
        self.right_conn.upsert([
            obj,
        ])
class MigrationsResolver:
    """
    Migrations helper class for the Contessa migrations.

    Works out which Alembic command (upgrade/downgrade plus revision hash)
    brings the target schema to the migration matching the package version.
    """

    def __init__(self, migrations_map, package_version, url, schema):
        """
        :param migrations_map: map of package versions and their migrations.
            In form of dictionary {'0.1.4':'A', '0.1.5':'B'}
        :param package_version: the version of the package planned to be migrated
        :param url: the database url where the Alembic migration table is present or planned to be created
        :param schema: the database schema where the Alembic migration table is present or planned to be created
        """
        self.versions_migrations = migrations_map
        self.package_version = package_version
        self.url = url
        self.schema = schema
        self.conn = Connector(self.url)

    def schema_exists(self):
        """
        Check if schema with the Alembic migration table exists.
        :return: Return true if schema with the Alembic migration exists.
        """
        result = self.conn.get_records(f"""
            SELECT EXISTS (
                SELECT 1
                FROM information_schema.schemata
                WHERE schema_name = '{self.schema}'
            );
            """)
        return result.first()[0]

    def migrations_table_exists(self):
        """
        Check if the Alembic versions table exists.
        """
        result = self.conn.get_records(f"""
            SELECT EXISTS (
                SELECT 1
                FROM information_schema.tables
                WHERE table_schema = '{self.schema}'
                AND table_name = '{MIGRATION_TABLE}'
            );
            """)
        return result.first()[0]

    def get_applied_migration(self):
        """
        Get the current applied migration in the target schema.

        :return: the revision stored in the Alembic version table, or None when
            the table does not exist or contains no row.
        """
        if not self.migrations_table_exists():
            return None
        row = self.conn.get_records(
            f"select * from {self.schema}.{MIGRATION_TABLE}").first()
        # An existing but empty version table means nothing is applied yet.
        return None if row is None else row[0]

    def is_on_head(self):
        """
        Check if the current applied migration is valid for the Contessa version.
        """
        # get_applied_migration already covers the "table missing" case.
        current = self.get_applied_migration()
        if current is None:
            return False
        return self.versions_migrations[self.get_fallback_version()] == current

    def get_fallback_version(self):
        """
        Get fallback version in the case of non existing migration for the Contessa package version.
        Returns the last package version containing the migration.

        In the case we have this migrations versions map
            versions_migrations = {
                "0.1.4": "54f8985b0ee5",
                "0.1.5": "480e6618700d",
                "0.1.8": "3w4er8y50yyd",
                "0.1.9": "034hfa8943hr",
            }
        and we ask for the fallback version of our current version e.g. 0.1.7., we get 0.1.5. as a result
        because it is the last version with specified migration before ours.

        A version older than every key falls back to the oldest key.
        """
        if self.package_version in self.versions_migrations:
            return self.package_version

        # Sort by parsed version so the result does not depend on the
        # insertion order of the migrations map.
        keys = sorted(self.versions_migrations, key=pv)
        current = pv(self.package_version)

        result = keys[0]
        for k in keys:
            if pv(k) <= current:
                result = k
            else:
                break
        return result

    def get_migration_to_head(self):
        """
        Get the migration command for alembic. Migration command is a tupple of type of migration and migration hash.
        E.g. ('upgrade', 'dfgdfg5b0ee5') or ('downgrade', 'dfgdfg5b0ee5')

        :return: the command tuple, or None when the schema is already on head.
        """
        if self.is_on_head():
            return None

        fallback_version = self.get_fallback_version()
        target = self.versions_migrations[fallback_version]

        applied_migration = self.get_applied_migration()
        if applied_migration is None:
            # No version table or an empty one: migrate up to the target.
            return "upgrade", target

        migrations_versions = {
            migration: version
            for version, migration in self.versions_migrations.items()
        }
        applied_package = migrations_versions[applied_migration]

        if pv(applied_package) < pv(fallback_version):
            return "upgrade", target
        if pv(applied_package) > pv(fallback_version):
            return "downgrade", target
        return None
def run_query(self, conn: Connector, query: str, context):
    """
    Render `query` with `context`, execute it on `conn` and return the rows.

    :return: list of tuples, one per result row. Tuples are hashable, so the
        result can be compared as a set.
    """
    query = self.render_sql(query, context)
    logging.info(query)
    # Use the public row API instead of the private `_row` attribute.
    result = [tuple(r.values()) for r in conn.get_records(query)]
    return result
class ContessaRunner:
    """
    todo - rewrite comments

    Runs data-quality rules against a check table and inserts one quality-check
    row per rule into a result table, creating that table when needed.
    """

    def __init__(self, conn_uri_or_engine, special_qc_map=None):
        """
        :param conn_uri_or_engine: database URI or engine handed to `Connector`.
        :param special_qc_map: optional mapping of result-table fullname to a
            custom QualityCheck class.
        """
        self.conn_uri_or_engine = conn_uri_or_engine
        self.conn = Connector(conn_uri_or_engine)

        # todo - allow cfg
        self.special_qc_map = special_qc_map or {}

    def run(
        self,
        raw_rules: List[Dict[str, str]],
        check_table: Dict,
        result_table: Dict,  # todo - docs for quality name, maybe defaults..
        context: Optional[Dict] = None,
    ):
        """
        Normalize and build the rules, execute them and insert the results.
        """
        check_table = Table(**check_table)
        result_table = ResultTable(**result_table)
        context = self.get_context(check_table, context)

        normalized_rules = self.normalize_rules(raw_rules)
        refresh_executors(check_table, self.conn, context)
        quality_check_class = self.get_quality_check_class(result_table)
        self.ensure_table(quality_check_class)

        rules = self.build_rules(normalized_rules)
        objs = self.do_quality_checks(quality_check_class, rules, context)

        self.insert(objs)

    @staticmethod
    def get_context(check_table: Table, context: Optional[Dict] = None) -> Dict:
        """
        Construct context to pass to executors. User context overrides defaults.
        """
        ctx_defaults = {
            "table_fullname": check_table.fullname,
            "task_ts": datetime.now(),  # todo - is now() ok ?
        }
        # `context` is optional; without this guard dict.update(None) raises TypeError.
        if context:
            ctx_defaults.update(context)
        return ctx_defaults

    def normalize_rules(self, raw_rules):
        """Normalize user rule definitions via `RuleNormalizer`."""
        return RuleNormalizer.normalize(raw_rules)

    def do_quality_checks(self, dq_cls, rules: List[Rule], context: Dict = None):
        """
        Run quality check for all rules. Use `qc_cls` to construct objects that will be inserted afterwards.
        """
        return [self.apply_rule(context, dq_cls, rule) for rule in rules]

    def apply_rule(self, context, dq_cls, rule):
        """Execute one rule with its executor and wrap the results in a `dq_cls` row."""
        e = get_executor(rule)
        logging.info(f"Executing rule `{rule}`.")
        results = e.execute(rule)
        obj = dq_cls()
        obj.init_row(rule, results, self.conn, context)
        return obj

    def insert(self, objs):
        """
        Insert QualityCheck objects using sqlalchemy. If there is integrity error, skip it.
        """
        logging.info(f"Inserting {len(objs)} results.")
        session = self.conn.make_session()
        try:
            session.add_all(objs)
            session.commit()
        except sqlalchemy.exc.IntegrityError:
            ts = objs[0].task_ts
            logging.info(
                f"This quality check ({ts}) was already done. Skipping it this time."
            )
            session.rollback()
        finally:
            session.close()

    def ensure_table(self, qc_cls):
        """
        Create table for QualityCheck class if it doesn't exists. E.g. quality_check_
        """
        try:
            qc_cls.__table__.create(bind=self.conn.engine)
            logging.info(f"Created table {qc_cls.__tablename__}.")
        except sqlalchemy.exc.ProgrammingError:
            logging.info(
                f"Table {qc_cls.__tablename__} already exists. Skipping creation."
            )

    def build_rules(self, normalized_rules):
        """
        Construct rules classes from user definition that are dicts.
        Raises if there are bad arguments for a certain rule.
        :return: list of Rule objects
        """
        ret = []
        for rule_def in normalized_rules:
            rule_cls = self.pick_rule_cls(rule_def)
            try:
                r = rule_cls(**rule_def)
            except Exception as e:
                # Not every exception carries args; fall back to the exception
                # itself so logging cannot mask the original error.
                reason = e.args[0] if e.args else e
                logging.error(f"For rule `{rule_cls.__name__}`. {reason}")
                raise
            else:
                ret.append(r)
        return ret

    def pick_rule_cls(self, rule_def):
        """
        Get rule class based on its name that was input by user.
        :param rule_def: dict
        :return: Rule class
        """
        return get_rule_cls(rule_def["name"])

    def get_quality_check_class(self, result_table: ResultTable):
        """
        QualityCheck can be different, e.g. `special_table` has specific quality_check.
        Or kind of generic one that computes number of passed/failed objects etc.
        So determine if is special or not and return the class.

        :return: QualityCheck cls
        """
        if result_table.fullname in self.special_qc_map:
            quality_check_class = self.special_qc_map[result_table.fullname]
            logging.info(
                f"Using {quality_check_class.__name__} as quality check class."
            )
        else:
            quality_check_class = create_default_quality_check_class(
                result_table)
            logging.info("Using default QualityCheck class.")
        return quality_check_class
class TestConsistencyChecker(unittest.TestCase):
    """Integration tests for `ConsistencyChecker` against the test database."""

    def setUp(self):
        """
        Init temporary schemas and tables with some data.

        `tmp.raw_booking` (left) holds one row more than `hello.booking` (right).
        """
        self.left_table_name = "raw_booking"
        self.right_table_name = "booking"
        self.result_table_name = "booking"
        self.ts_nodash = (
            FakedDatetime.now().isoformat().replace("-", "").replace(":", "")
        )
        self.now = FakedDatetime.now()

        sql = [
            "DROP SCHEMA if exists tmp CASCADE;",
            "DROP SCHEMA if exists data_quality CASCADE;",
            # `hello` must be reset as well: its table is created with IF NOT EXISTS
            # and then inserted into, so leftovers from an aborted run would
            # duplicate rows and break the count expectations.
            "DROP SCHEMA if exists hello CASCADE;",
            "CREATE SCHEMA IF NOT EXISTS tmp;",
            "CREATE SCHEMA IF NOT EXISTS data_quality;",
            "CREATE SCHEMA IF NOT EXISTS hello;",
            f"""
            CREATE TABLE IF NOT EXISTS tmp.{self.left_table_name}(
              id SERIAL PRIMARY KEY,
              src text,
              dst text,
              price int,
              turnover_after_refunds double precision,
              initial_price double precision,
              created_at timestamptz
            )
            """,
            f"""
            CREATE TABLE IF NOT EXISTS hello.{self.right_table_name}(
              id SERIAL PRIMARY KEY,
              src text,
              dst text,
              price int,
              turnover_after_refunds double precision,
              initial_price double precision,
              created_at timestamptz
            )
            """,
            f"""
            INSERT INTO tmp.{self.left_table_name}
              (src, dst, price, turnover_after_refunds, initial_price, created_at)
            VALUES
              ('BTS', NULL, 1, 100, 11, '2018-09-12T11:50:00'),
              (NULL, 'PEK', 33, 1.1, 13, '2018-01-12T15:50:00'),
              ('VIE', 'JFK', 4, 5.5, 23.4, '2018-09-11T11:50:00'),
              ('VIE', 'VIE', 4, 0.0, 0.0, '2018-09-11T11:50:00')
            """,
            f"""
            INSERT INTO hello.{self.right_table_name}
              (src, dst, price, turnover_after_refunds, initial_price, created_at)
            VALUES
              ('BTS', NULL, 1, 100, 11, '2018-09-12T11:50:00'),
              (NULL, 'PEK', 33, 1.1, 13, '2018-01-12T15:50:00'),
              ('VIE', 'JFK', 4, 5.5, 23.4, '2018-09-11T11:50:00')
            """,
        ]
        self.conn = Connector(TEST_DB_URI)
        for s in sql:
            self.conn.execute(s)

        self.consistency_checker = ConsistencyChecker(TEST_DB_URI)

    def tearDown(self):
        """
        Drop all created schemas and clear the ORM metadata.
        """
        self.conn.execute("DROP schema tmp CASCADE;")
        self.conn.execute("DROP schema data_quality CASCADE;")
        self.conn.execute("DROP schema hello CASCADE;")

        DQBase.metadata.clear()

    def _first_result_status(self):
        """Return `status` of the oldest row in the consistency-check result table."""
        rows = self.conn.get_records(
            f"""
            SELECT * from data_quality.consistency_check_{self.result_table_name}
            order by created_at
            """
        )
        return rows.fetchone()["status"]

    @mock.patch("contessa.executor.datetime", FakedDatetime)
    def test_execute_consistency_false(self):
        self.consistency_checker.run(
            self.consistency_checker.COUNT,
            left_check_table={"schema_name": "tmp", "table_name": self.left_table_name},
            right_check_table={
                "schema_name": "hello",
                "table_name": self.right_table_name,
            },
            result_table={
                "schema_name": "data_quality",
                "table_name": self.result_table_name,
            },
            context={"task_ts": self.now},
        )
        self.assertEqual(self._first_result_status(), "invalid")

    @mock.patch("contessa.executor.datetime", FakedDatetime)
    def test_execute_consistency_true(self):
        # add missing record to the right table
        self.conn.execute(
            f"""
            INSERT INTO hello.{self.right_table_name}
              (src, dst, price, turnover_after_refunds, initial_price, created_at)
            VALUES
              ('VIE', 'VIE', 4, 0.0, 0.0, '2018-09-11T11:50:00')
            """
        )

        self.consistency_checker.run(
            self.consistency_checker.COUNT,
            left_check_table={"schema_name": "tmp", "table_name": self.left_table_name},
            right_check_table={
                "schema_name": "hello",
                "table_name": self.right_table_name,
            },
            result_table={
                "schema_name": "data_quality",
                "table_name": self.result_table_name,
            },
            context={"task_ts": self.now},
        )
        self.assertEqual(self._first_result_status(), "valid")

    @mock.patch("contessa.executor.datetime", FakedDatetime)
    def test_execute_consistency_diff(self):
        """ Test scenario where table diff is done on tables with different column order."""
        self.inconsistent_colums_left = "user"
        self.inconsistent_colums_right = "user_inconsistent"
        # setUp already provides `self.conn` and `self.consistency_checker`;
        # re-creating them here would only leak an extra connection.
        sql = [
            f"""
            CREATE TABLE IF NOT EXISTS tmp.{self.inconsistent_colums_left}(
              id SERIAL PRIMARY KEY,
              name text
            )
            """,
            f"""
            CREATE TABLE IF NOT EXISTS tmp.{self.inconsistent_colums_right}(
              name text,
              id SERIAL PRIMARY KEY
            )
            """,
            f"""
            INSERT INTO tmp.{self.inconsistent_colums_left}
            VALUES (1, 'John Doe')
            """,
            f"""
            INSERT INTO tmp.{self.inconsistent_colums_right}
            VALUES ('John Doe', 1)
            """,
        ]
        for s in sql:
            self.conn.execute(s)

        self.consistency_checker.run(
            self.consistency_checker.DIFF,
            left_check_table={
                "schema_name": "tmp",
                "table_name": self.inconsistent_colums_left,
            },
            right_check_table={
                "schema_name": "tmp",
                "table_name": self.inconsistent_colums_right,
            },
            result_table={
                "schema_name": "data_quality",
                "table_name": self.result_table_name,
            },
            context={"task_ts": self.now},
        )
        self.assertEqual(self._first_result_status(), "valid")

    @mock.patch("contessa.executor.datetime", FakedDatetime)
    def test_execute_consistency_sqls(self):
        result = self.consistency_checker.run(
            self.consistency_checker.DIFF,
            left_check_table={"schema_name": "tmp", "table_name": self.left_table_name},
            right_check_table={
                "schema_name": "hello",
                "table_name": self.right_table_name,
            },
            left_custom_sql="SELECT 16349;",
            right_custom_sql="SELECT 16349;",
            context={"task_ts": self.now},
        )
        self.assertEqual("valid", result.status)

        result = self.consistency_checker.run(
            self.consistency_checker.DIFF,
            left_check_table={"schema_name": "tmp", "table_name": self.left_table_name},
            right_check_table={
                "schema_name": "hello",
                "table_name": self.right_table_name,
            },
            left_custom_sql="SELECT 42;",
            right_custom_sql="SELECT 16349;",
            context={"task_ts": self.now},
        )
        self.assertEqual("invalid", result.status)
def setUpClass(cls) -> None:
    """
    Prepare shared fixtures: a test DB connection, the Alembic config loaded
    from ALEMBIC_INI_PATH, and a `Table` for the configured `version_table`
    inside DATA_QUALITY_SCHEMA.
    """
    cls.conn = Connector(TEST_DB_URI)
    cls.alembic_cfg = Config(ALEMBIC_INI_PATH)
    cls.migration_table = Table(
        DATA_QUALITY_SCHEMA,
        cls.alembic_cfg.get_main_option("version_table"),
    )
def run_query(self, conn: Connector, query: str, context):
    """
    Render `query` with `context`, execute it on `conn` and return all rows
    as plain tuples.
    """
    rendered_query = self.render_sql(query, context)
    logging.debug(rendered_query)
    records = conn.get_records(rendered_query)
    return [tuple(record.values()) for record in records]
def __init__(self, conn_uri_or_engine, special_qc_map=None):
    """
    :param conn_uri_or_engine: database URI or engine handed to `Connector`.
    :param special_qc_map: optional mapping of result-table fullname to a
        custom QualityCheck class; a falsy value is replaced by an empty dict.
    """
    self.conn_uri_or_engine = conn_uri_or_engine
    self.conn = Connector(conn_uri_or_engine)
    # todo - allow cfg
    self.special_qc_map = special_qc_map if special_qc_map else {}
class ConsistencyChecker:
    """
    Checks consistency of the sync between two tables.

    Supported methods:
      * ``COUNT`` - compare ``count(*)`` (or ``count(<columns>)``) of both tables.
      * ``DIFF`` - compare the sets of selected rows of both tables.
    """

    model_cls = ConsistencyCheck

    COUNT = "count"
    DIFF = "difference"

    def __init__(self, left_conn_uri_or_engine, right_conn_uri_or_engine=None):
        """
        :param left_conn_uri_or_engine: URI or engine of the left table's database.
        :param right_conn_uri_or_engine: URI or engine of the right table's
            database; the left connection is reused when omitted.
        """
        self.left_conn_uri_or_engine = left_conn_uri_or_engine
        self.left_conn = Connector(left_conn_uri_or_engine)
        if right_conn_uri_or_engine is None:
            self.right_conn_uri_or_engine = self.left_conn_uri_or_engine
            self.right_conn = self.left_conn
        else:
            self.right_conn_uri_or_engine = right_conn_uri_or_engine
            self.right_conn = Connector(right_conn_uri_or_engine)

    def run(
        self,
        method: str,
        left_check_table: Dict,
        right_check_table: Dict,
        result_table: Optional[Dict] = None,
        columns: Optional[List[str]] = None,
        time_filter: Optional[Union[str, List[Dict], TimeFilter]] = None,
        left_custom_sql: str = None,
        right_custom_sql: str = None,
        context: Optional[Dict] = None,
        example_selector: ExampleSelector = default_example_selector,
    ) -> Union[CheckResult, ConsistencyCheck]:
        """
        Run a consistency check between the left and right table.

        :param method: `COUNT` or `DIFF`.
        :param left_check_table: kwargs for `Table` describing the left table.
        :param right_check_table: kwargs for `Table` describing the right table.
        :param result_table: kwargs for `ResultTable`; when given, the result is
            upserted into that table.
        :param columns: columns to count/compare.
        :param time_filter: anything `parse_time_filter` accepts.
        :param left_custom_sql: custom SQL for the left side.
        :param right_custom_sql: custom SQL for the right side.
        :param context: values used to render SQL; overrides defaults.
        :param example_selector: picks examples of differing rows for DIFF.
        :raises ValueError: when both custom sqls are combined with
            `columns` or `time_filter`.
        :return: the raw result dict when `result_table` is given (after the
            upsert), otherwise a `CheckResult`.
            NOTE(review): the annotation mentions ConsistencyCheck but the
            result_table branch returns a dict - confirm intended contract.
        """
        if left_custom_sql and right_custom_sql:
            if columns or time_filter:
                raise ValueError(
                    "When using custom sqls you cannot change 'columns' or 'time_filter' attribute"
                )

        time_filter = parse_time_filter(time_filter)

        left_check_table = Table(**left_check_table)
        right_check_table = Table(**right_check_table)
        context = self.get_context(left_check_table, right_check_table, context)

        result = self.do_consistency_check(
            method,
            columns,
            time_filter,
            left_check_table,
            right_check_table,
            left_custom_sql,
            right_custom_sql,
            context,
            example_selector,
        )

        if result_table:
            result_table = ResultTable(**result_table, model_cls=self.model_cls)
            quality_check_class = create_default_check_class(result_table)
            self.right_conn.ensure_table(quality_check_class.__table__)
            self.upsert(quality_check_class, result)
            return result

        obj = CheckResult()
        obj.init_row_consistency(**result)
        return obj

    @staticmethod
    def get_context(
        left_check_table: Table,
        right_check_table: Table,
        context: Optional[Dict] = None,
    ) -> Dict:
        """
        Construct context to pass to executors. User context overrides defaults.
        """
        ctx_defaults = {
            "left_table_fullname": left_check_table.fullname,
            "right_table_fullname": right_check_table.fullname,
            "task_ts": datetime.now(),
        }
        if context:
            ctx_defaults.update(context)
        return ctx_defaults

    def do_consistency_check(
        self,
        method: str,
        columns: Optional[List[str]],
        time_filter: Optional[TimeFilter],
        left_check_table: Table,
        right_check_table: Table,
        left_sql: str = None,
        right_sql: str = None,
        context: Dict = None,
        example_selector: ExampleSelector = default_example_selector,
    ):
        """
        Query both sides, compare the results with `method` and build a result dict.

        A default query is generated for each side without custom SQL.

        :raises NotImplementedError: for an unknown method when a default query
            has to be generated.
        :return: dict with the check description, aggregated results, table
            names, time filter and context.
        """
        if not left_sql or not right_sql:
            if method == self.COUNT:
                if columns:
                    column = f"count({', '.join(columns)})"
                else:
                    column = "count(*)"
            elif method == self.DIFF:
                if columns:
                    column = ", ".join(columns)
                else:
                    # List the columns explicitly in case column order of compared tables is not the same.
                    column = ", ".join(
                        sorted(
                            self.right_conn.get_column_names(
                                right_check_table.fullname)))
            else:
                raise NotImplementedError(f"Method {method} not implemented")

        if not left_sql:
            left_sql = self.construct_default_query(left_check_table.fullname,
                                                    column, time_filter,
                                                    context)
        left_result = self.run_query(self.left_conn, left_sql, context)

        if not right_sql:
            right_sql = self.construct_default_query(
                right_check_table.fullname, column, time_filter, context)
        right_result = self.run_query(self.right_conn, right_sql, context)
        results = self.compare_results(left_result, right_result, method,
                                       example_selector)

        return {
            "check": {
                "type": method,
                "description": "",
                "name": "consistency",
            },
            "results": results,
            "left_table_name": left_check_table.fullname,
            "right_table_name": right_check_table.fullname,
            "time_filter": time_filter,
            "context": context,
        }

    def compare_results(self, left_result, right_result, method,
                        example_selector):
        """
        Aggregate the comparison of both query results.

        :raises NotImplementedError: for an unknown method.
        :return: `AggregatedResult` with ``passed + failed == total_records``.
        """
        if method == self.COUNT:
            left_count = left_result[0][0]
            right_count = right_result[0][0]
            passed = min(left_count, right_count)
            # The surplus of the bigger side is what failed; abs() keeps it
            # non-negative whichever side is bigger.
            failed = abs(left_count - right_count)
            return AggregatedResult(
                total_records=max(left_count, right_count),
                failed=failed,
                passed=passed,
            )
        elif method == self.DIFF:
            # Rows are compared as sets, so duplicates within one side collapse.
            left_set = set(left_result)
            right_set = set(right_result)
            differing = left_set.symmetric_difference(right_set)
            passed = len(left_set.intersection(right_set))
            failed = len(differing)
            failed_examples = example_selector.select_examples(differing)
            return AggregatedResult(
                total_records=failed + passed,
                failed=failed,
                passed=passed,
                failed_example=list(failed_examples),
            )
        else:
            raise NotImplementedError(f"Method {method} not implemented")

    def construct_default_query(
        self,
        table_name: str,
        column: str,
        time_filter: Optional[TimeFilter],
        context: Dict,
    ):
        """
        Build ``SELECT <column> FROM <table_name>`` with an optional WHERE clause.

        When ``context["task_ts"]`` is set it is assigned to ``time_filter.now``
        (this mutates the passed filter) before its SQL is taken.
        """
        where_clause = None
        if time_filter:
            if context.get("task_ts"):
                time_filter.now = context["task_ts"]
            where_clause = time_filter.sql

        query = f"""
            SELECT {column}
            FROM {table_name}
            {f'WHERE {where_clause}' if where_clause else ''}
        """
        return query

    def render_sql(self, sql, context):
        """
        Replace some parameters in query.
        :return str, formatted sql
        """
        rendered = render_jinja_sql(sql, context)
        return rendered

    def run_query(self, conn: Connector, query: str, context):
        """Render `query`, run it on `conn` and return rows as hashable tuples."""
        query = self.render_sql(query, context)
        logging.debug(query)
        result = [tuple(r.values()) for r in conn.get_records(query)]
        return result

    def upsert(self, dc_cls, result):
        """Build a `dc_cls` row from the `result` dict and upsert it via the right connection."""
        obj = dc_cls()
        obj.init_row(**result)
        self.right_conn.upsert([
            obj,
        ])

    def construct_automatic_time_filter(
        self,
        left_check_table: Dict,
        created_at_column=None,
        updated_at_column=None,
    ) -> TimeFilter:
        """
        Build a time filter starting at the minimum of the time column in the
        left table (`updated_at_column` preferred over `created_at_column`).

        :raises ValueError: when neither time column is given.
        """
        left_check_table = Table(**left_check_table)

        if created_at_column is None and updated_at_column is None:
            raise ValueError(
                "Automatic time filter need at least one time column")

        since_column = updated_at_column or created_at_column
        since_sql = f"SELECT min({since_column}) FROM {left_check_table.fullname}"
        logging.info(since_sql)
        since = self.left_conn.get_records(since_sql).scalar()

        return TimeFilter(
            columns=[
                TimeFilterColumn(since_column, since=since),
            ],
            conjunction=TimeFilterConjunction.AND,
        )